mistralrs_core/tools/
mod.rs

1mod request;
2mod response;
3
4use candle_core::Result;
5use regex::Regex;
6pub use request::*;
7pub use response::*;
8use serde::de::{self, Deserializer, MapAccess, Visitor};
9use serde_json::{Map, Value};
10use std::collections::HashMap;
11use std::fmt;
12use std::sync::{Arc, OnceLock};
13use uuid::Uuid;
14
15use crate::Pipeline;
16use mistralrs_mcp::CalledFunction;
17
18// Re-export the types so they're accessible as tools::Type
19pub use mistralrs_mcp::{ToolCallback, ToolCallbackWithTool};
20
/// Collection of callbacks keyed by tool name.
///
/// Each callback is stored behind an [`Arc`], so cloning the map shares the callbacks instead
/// of copying them.
pub type ToolCallbacks = HashMap<String, Arc<ToolCallback>>;

/// Collection of callbacks with their tool definitions keyed by tool name.
pub type ToolCallbacksWithTools = HashMap<String, ToolCallbackWithTool>;
26
/// Returns `true` if `prefix` contains a marker that some model family uses to introduce a
/// tool call:
/// - `<tool_call>` (Qwen),
/// - `<|tool▁call▁begin|>` / `<｜tool▁call▁begin｜>` (DeepSeek; its special tokens use the
///   full-width bar `｜`, but the ASCII-bar spelling is accepted as well),
/// - `<|python_tag|>` (Llama),
/// - `[TOOL_CALLS]` (Mistral).
fn contains_tool_call_prefix(prefix: &str) -> bool {
    const MARKERS: [&str; 5] = [
        "<tool_call>",
        "<|tool▁call▁begin|>",
        "<｜tool▁call▁begin｜>",
        "<|python_tag|>",
        "[TOOL_CALLS]",
    ];
    MARKERS.iter().any(|marker| prefix.contains(marker))
}
33
34fn process_model_specific_message(message: &str) -> Result<String> {
35    static DEEPSEEK_REGEX: OnceLock<Regex> = OnceLock::new();
36    static QWEN_REGEX: OnceLock<Regex> = OnceLock::new();
37
38    // These are reasoning models so we need a regex.
39    let deepseek_regex = DEEPSEEK_REGEX.get_or_init(|| Regex::new(
40        r"(?s)<|tool▁call▁begin|>function<|tool▁sep|>(?P<name>[^\n]+)\n```json\n(?P<json>.+?)\n```<|tool▁call▁end|>",
41    ).unwrap());
42    let qwen_regex = QWEN_REGEX
43        .get_or_init(|| Regex::new(r"(?s)<tool_call>(?P<inner>.*?)</tool_call>").unwrap());
44
45    if let Some(message) = message.strip_prefix("<|python_tag|>") {
46        // Llama case
47        Ok(message.to_string())
48    } else if qwen_regex.is_match(message) {
49        if let Some(caps) = qwen_regex.captures(message) {
50            let inner = caps.name("inner").unwrap().as_str();
51            return Ok(inner.trim().to_string());
52        }
53        Ok(message.to_string())
54    } else if let Some(message) = message
55        .strip_prefix("[TOOL_CALLS][")
56        .and_then(|s| s.strip_suffix("]"))
57    {
58        // Mistral Nemo case
59        Ok(message.to_string())
60    } else if deepseek_regex.find(message).is_some() {
61        #[derive(Clone, Debug, serde::Serialize, serde::Deserialize)]
62        struct ToolCall {
63            name: String,
64            arguments: Value,
65        }
66        let mut calls = Vec::new();
67        for caps in deepseek_regex.captures_iter(message) {
68            let name = caps
69                .name("name")
70                .ok_or("Could not capture function name")
71                .map_err(candle_core::Error::msg)?
72                .as_str()
73                .trim()
74                .to_string();
75            let json_str = caps
76                .name("json")
77                .ok_or("Could not capture JSON arguments")
78                .map_err(candle_core::Error::msg)?
79                .as_str()
80                .trim();
81            let arguments: Value =
82                serde_json::from_str(json_str).map_err(candle_core::Error::msg)?;
83            calls.push(ToolCall { name, arguments });
84        }
85        Ok(serde_json::to_string(&calls).map_err(candle_core::Error::msg)?)
86    } else {
87        Ok(message.to_string())
88    }
89}
90
/// Detects and extracts tool calls from model output according to the configured
/// [`ToolChoice`].
pub struct ToolCallingMatcher {
    /// `ToolChoice::None` disables matching entirely. `ToolChoice::Tool(_)` makes
    /// `get_call` fail when the output contains no tool call.
    tool_choice: ToolChoice,
}
94
/// A tool call as emitted by a model. It mirrors `CalledFunction`, but accepts the different
/// key names that models use.
#[derive(Clone, Debug, serde::Serialize, serde::Deserialize)]
pub struct CalledFunctionParameters {
    /// Name of the function to call. It is also accepted under the key `"function"` when
    /// deserializing.
    #[serde(alias = "function")]
    pub name: String,
    /// Call arguments. They are also accepted under the key `"arguments"`, either as a JSON
    /// object or as a string containing JSON (see `flexible_args`). The field is required (no
    /// default) and is serialized under the name `"parameters"`.
    #[serde(alias = "arguments", deserialize_with = "flexible_args")]
    pub parameters: Value,
}
103
104// Accept either `{...}` **or** a `"stringified { ... }"`
105fn flexible_args<'de, D>(d: D) -> std::result::Result<Value, D::Error>
106where
107    D: Deserializer<'de>,
108{
109    struct ArgVisitor;
110
111    impl<'de> Visitor<'de> for ArgVisitor {
112        type Value = Value;
113
114        fn expecting(&self, f: &mut fmt::Formatter) -> fmt::Result {
115            f.write_str("an object or a JSON-encoded string containing an object")
116        }
117
118        // Case 1 – the good case: already a JSON object
119        fn visit_map<M>(self, mut m: M) -> std::result::Result<Self::Value, M::Error>
120        where
121            M: MapAccess<'de>,
122        {
123            let mut map = Map::new();
124            while let Some((k, v)) = m.next_entry()? {
125                map.insert(k, v);
126            }
127            Ok(Value::Object(map))
128        }
129
130        // Case 2 – got a *string*; try parsing it as JSON
131        fn visit_str<E>(self, s: &str) -> std::result::Result<Self::Value, E>
132        where
133            E: de::Error,
134        {
135            serde_json::from_str(s).map_err(|e| E::custom(format!("inner JSON error: {e}")))
136        }
137    }
138
139    d.deserialize_any(ArgVisitor)
140}
141
142/// Fixup potentially broken JSON
143/// 1) allow/handle arguments as maps in quotations
144fn fix_broken_json(raw: &str) -> anyhow::Result<String> {
145    // Only apply the fix if the first pattern matches - otherwise we might corrupt valid JSON
146    // where arguments is a properly escaped string containing `}`
147    if raw.contains(r#""arguments":"{"#) {
148        // 1) Delete the opening quote that shouldn't be there
149        let tmp = raw.replacen(r#""arguments":"{"#, r#""arguments":{"#, 1);
150        // 2) Delete the closing quote that matches it
151        let fixed = tmp.replacen(r#"}"}"#, r#"}}"#, 1);
152        Ok(fixed)
153    } else {
154        Ok(raw.to_string())
155    }
156}
157
158impl ToolCallingMatcher {
159    pub fn new(tool_choice: ToolChoice) -> anyhow::Result<Self> {
160        Ok(Self { tool_choice })
161    }
162
163    // Checks if the `message_prefix` could be a tool call. If false, either
164    // [`ToolChoice::None`] was selected, or the prefix could not match.
165    //
166    // If the start of a message could be a tool call, then it looks like an incomplete JSON of a given structure, e.g. `{"name": "foo", "param`.
167    //
168    // Returns a tuple of `(could_be_tool, is_complete_tool)`.
169    pub fn prefix_could_be_tool(
170        &self,
171        _pipeline: &dyn Pipeline,
172        message_prefix: &str,
173    ) -> Result<(bool, bool)> {
174        if matches!(self.tool_choice, ToolChoice::None) {
175            return Ok((false, false));
176        }
177        let message_prefix = process_model_specific_message(message_prefix)?;
178        let message_prefix = fix_broken_json(&message_prefix).unwrap();
179
180        // Check if the prefix could be a JSON serialization of any of the following types.
181        Ok([
182            could_be_json::<CalledFunctionParameters>,
183            could_be_json::<Vec<CalledFunctionParameters>>,
184        ]
185        .iter()
186        .find_map(|check| {
187            let (could_be_tool, is_complete_tool) = check(&message_prefix);
188            if could_be_tool || is_complete_tool {
189                Some((could_be_tool, is_complete_tool))
190            } else {
191                None
192            }
193        })
194        .unwrap_or((contains_tool_call_prefix(&message_prefix), false)))
195    }
196
197    pub fn get_call(
198        &self,
199        _pipeline: &dyn Pipeline,
200        message: &str,
201    ) -> anyhow::Result<Vec<ToolCallResponse>> {
202        if matches!(self.tool_choice, ToolChoice::None) {
203            return Ok(Vec::new());
204        }
205        let message = process_model_specific_message(message)?;
206        let message = fix_broken_json(&message).unwrap();
207
208        if let Ok(deser) = serde_json::from_str::<CalledFunctionParameters>(&message) {
209            let id = format!("call-{}", Uuid::new_v4());
210            Ok(vec![ToolCallResponse {
211                index: 0,
212                id,
213                tp: ToolCallType::Function,
214                function: CalledFunction {
215                    name: deser.name,
216                    arguments: serde_json::to_string(&deser.parameters)?,
217                },
218            }])
219        } else if let Ok(deser) = serde_json::from_str::<Vec<CalledFunctionParameters>>(&message) {
220            Ok(deser
221                .into_iter()
222                .enumerate()
223                .map(|(idx, deser)| {
224                    let id = format!("call-{}", Uuid::new_v4());
225                    Ok(ToolCallResponse {
226                        index: idx,
227                        id,
228                        tp: ToolCallType::Function,
229                        function: CalledFunction {
230                            name: deser.name,
231                            arguments: serde_json::to_string(&deser.parameters)?,
232                        },
233                    })
234                })
235                .collect::<anyhow::Result<Vec<_>>>()?)
236        } else {
237            if matches!(self.tool_choice, ToolChoice::Tool(_)) {
238                anyhow::bail!("Tool choice was required but no tools were called.")
239            }
240            Ok(Vec::new())
241        }
242    }
243}
244
245/// Checks if the given prefix could be the start of, or the entire JSON serialization of a given type, `T`.
246///
247/// Returns a tuple of `(could_be_tool, is_entire_tool)`.
248fn could_be_json<T>(text_prefix: &str) -> (bool, bool)
249where
250    T: serde::de::DeserializeOwned,
251{
252    if text_prefix.trim().is_empty() {
253        return (false, false);
254    }
255    match serde_json::from_str::<T>(text_prefix) {
256        Ok(_) => (false, true),
257        // EOF show that JSON parsing was successful up to the end of the entire string.
258        Err(e) if e.is_eof() => (true, false),
259        _ => (false, false),
260    }
261}
262
263/// Takes raw UTf8 text and parses any possible tool calls from it.
264pub fn parse_text_tools<'a>(
265    pipeline: &dyn Pipeline,
266    raw_text: &'a str,
267    matcher: Option<Arc<ToolCallingMatcher>>,
268) -> anyhow::Result<(Option<&'a str>, Vec<ToolCallResponse>)> {
269    let mut tool_calls = Vec::new();
270    let mut text_new = Some(raw_text);
271
272    if let Some(ref matcher) = matcher {
273        let calls = matcher
274            .get_call(pipeline, raw_text)
275            .map_err(candle_core::Error::msg)?;
276        if !calls.is_empty() {
277            text_new = None;
278            tool_calls = calls;
279        }
280    };
281    Ok((text_new, tool_calls))
282}