mistralrs_core/tools/
mod.rs

1mod request;
2mod response;
3
4use candle_core::Result;
5use regex::Regex;
6pub use request::*;
7pub use response::*;
8use serde_json::Value;
9use std::{
10    collections::HashMap,
11    sync::{Arc, OnceLock},
12};
13use uuid::Uuid;
14
15use crate::Pipeline;
16
17fn process_model_specific_message(message: &str) -> Result<String> {
18    static DEEPSEEK_REGEX: OnceLock<Regex> = OnceLock::new();
19    let deepseek_regex = DEEPSEEK_REGEX.get_or_init(|| Regex::new(
20        r"<|tool▁call▁begin|>function<|tool▁sep|>(?P<name>[^\n]+)\n```json\n(?P<json>.+?)\n```<|tool▁call▁end|>",
21    ).unwrap());
22
23    if let Some(message) = message.strip_prefix("<|python_tag|>") {
24        // Llama case
25        Ok(message.to_string())
26    } else if let Some(message) = message
27        .strip_prefix("<tool_call>")
28        .and_then(|s| s.strip_suffix("</tool_call>"))
29    {
30        // Hermes case
31        Ok(message.to_string())
32    } else if let Some(message) = message
33        .strip_prefix("[TOOL_CALLS][")
34        .and_then(|s| s.strip_suffix("]"))
35    {
36        // Mistral Nemo case
37        Ok(message.to_string())
38    } else if deepseek_regex.find(message).is_some() {
39        #[derive(Clone, Debug, serde::Serialize, serde::Deserialize)]
40        struct ToolCall {
41            name: String,
42            arguments: Value,
43        }
44        let mut calls = Vec::new();
45        for caps in deepseek_regex.captures_iter(message) {
46            let name = caps
47                .name("name")
48                .ok_or("Could not capture function name")
49                .map_err(candle_core::Error::msg)?
50                .as_str()
51                .trim()
52                .to_string();
53            let json_str = caps
54                .name("json")
55                .ok_or("Could not capture JSON arguments")
56                .map_err(candle_core::Error::msg)?
57                .as_str()
58                .trim();
59            let arguments: Value =
60                serde_json::from_str(json_str).map_err(candle_core::Error::msg)?;
61            calls.push(ToolCall { name, arguments });
62        }
63        Ok(serde_json::to_string(&calls).map_err(candle_core::Error::msg)?)
64    } else {
65        Ok(message.to_string())
66    }
67}
68
69pub struct ToolCallingMatcher {
70    tool_choice: ToolChoice,
71}
72
73// Same as CalledFunction, but has different cases for variations on the names
74#[derive(Clone, Debug, serde::Serialize, serde::Deserialize)]
75pub struct CalledFunctionParameters {
76    #[serde(alias = "function")]
77    pub name: String,
78    #[serde(alias = "arguments")]
79    pub parameters: HashMap<String, Value>,
80}
81
82impl ToolCallingMatcher {
83    pub fn new(tool_choice: ToolChoice) -> anyhow::Result<Self> {
84        Ok(Self { tool_choice })
85    }
86
87    // Checks if the the `message_prefix` could be a tool call. If false, either
88    // [`ToolChoice::None`] was selected, or the prefix could not match.
89    //
90    // If the start of a message could be a tool call, then it looks like an incomplete JSON of a given structure, e.g. `{"name": "foo", "param`.
91    //
92    // Returns a tuple of `(could_be_tool, is_complete_tool)`.
93    pub fn prefix_could_be_tool(
94        &self,
95        _pipeline: &dyn Pipeline,
96        message_prefix: &str,
97    ) -> Result<(bool, bool)> {
98        if matches!(self.tool_choice, ToolChoice::None) {
99            return Ok((false, false));
100        }
101        let message_prefix = process_model_specific_message(message_prefix)?;
102
103        // Check if the prefix could be a JSON serialization of any of the following types.
104        Ok([
105            could_be_json::<CalledFunctionParameters>,
106            could_be_json::<Vec<CalledFunctionParameters>>,
107        ]
108        .iter()
109        .find_map(|check| {
110            let (could_be_tool, is_complete_tool) = check(&message_prefix);
111            if could_be_tool || is_complete_tool {
112                Some((could_be_tool, is_complete_tool))
113            } else {
114                None
115            }
116        })
117        .unwrap_or_default())
118    }
119
120    pub fn get_call(
121        &self,
122        _pipeline: &dyn Pipeline,
123        message: &str,
124    ) -> anyhow::Result<Vec<ToolCallResponse>> {
125        if matches!(self.tool_choice, ToolChoice::None) {
126            return Ok(Vec::new());
127        }
128        let message = process_model_specific_message(message)?;
129
130        if let Ok(deser) = serde_json::from_str::<CalledFunctionParameters>(&message) {
131            let id = format!("call-{}", Uuid::new_v4());
132            Ok(vec![ToolCallResponse {
133                id,
134                tp: ToolCallType::Function,
135                function: CalledFunction {
136                    name: deser.name,
137                    arguments: serde_json::to_string(&deser.parameters)?,
138                },
139            }])
140        } else if let Ok(deser) = serde_json::from_str::<Vec<CalledFunctionParameters>>(&message) {
141            Ok(deser
142                .into_iter()
143                .map(|deser| {
144                    let id = format!("call-{}", Uuid::new_v4());
145                    Ok(ToolCallResponse {
146                        id,
147                        tp: ToolCallType::Function,
148                        function: CalledFunction {
149                            name: deser.name,
150                            arguments: serde_json::to_string(&deser.parameters)?,
151                        },
152                    })
153                })
154                .collect::<anyhow::Result<Vec<_>>>()?)
155        } else {
156            if matches!(self.tool_choice, ToolChoice::Tool(_)) {
157                anyhow::bail!("Tool choice was required but no tools were called.")
158            }
159            Ok(Vec::new())
160        }
161    }
162}
163
164/// Checks if the given prefix could be the start of, or the entire JSON serialization of a given type, `T`.
165///
166/// Returns a tuple of `(could_be_tool, is_entire_tool)`.
167fn could_be_json<T>(text_prefix: &str) -> (bool, bool)
168where
169    T: serde::de::DeserializeOwned,
170{
171    if text_prefix.is_empty() {
172        return (false, false);
173    }
174    match serde_json::from_str::<T>(text_prefix) {
175        Ok(_) => (false, true),
176        // EOF show that JSON parsing was successful up to the end of the entire string.
177        Err(e) if e.is_eof() => (true, false),
178        _ => (false, false),
179    }
180}
181
182/// Takes raw UTf8 text and parses any possible tool calls from it.
183pub fn parse_text_tools<'a>(
184    pipeline: &dyn Pipeline,
185    raw_text: &'a str,
186    matcher: Option<Arc<ToolCallingMatcher>>,
187) -> anyhow::Result<(Option<&'a str>, Vec<ToolCallResponse>)> {
188    let mut tool_calls = Vec::new();
189    let mut text_new = Some(raw_text);
190
191    if let Some(ref matcher) = matcher {
192        let calls = matcher
193            .get_call(pipeline, raw_text)
194            .map_err(candle_core::Error::msg)?;
195        if !calls.is_empty() {
196            text_new = None;
197            tool_calls = calls;
198        }
199    };
200    Ok((text_new, tool_calls))
201}