mistralrs_core/tools/
mod.rs

1mod request;
2mod response;
3
4use candle_core::Result;
5use regex::Regex;
6pub use request::*;
7pub use response::*;
8use serde::de::{self, Deserializer, MapAccess, Visitor};
9use serde_json::{Map, Value};
10use std::collections::HashMap;
11use std::fmt;
12use std::sync::{Arc, OnceLock};
13use uuid::Uuid;
14
15use crate::Pipeline;
16use mistralrs_mcp::CalledFunction;
17
18// Re-export the types so they're accessible as tools::Type
19pub use mistralrs_mcp::{ToolCallback, ToolCallbackWithTool};
20
21/// Collection of callbacks keyed by tool name.
22pub type ToolCallbacks = HashMap<String, Arc<ToolCallback>>;
23
24/// Collection of callbacks with their tool definitions keyed by tool name.
25pub type ToolCallbacksWithTools = HashMap<String, ToolCallbackWithTool>;
26
/// Returns `true` when `prefix` contains any of the known markers that models
/// emit to introduce a tool call.
///
/// The DeepSeek marker is listed twice: once with ASCII vertical bars and once
/// with fullwidth vertical bars (U+FF5C), so the check works whichever
/// spelling appears in the decoded text.
fn contains_tool_call_prefix(prefix: &str) -> bool {
    const TOOL_CALL_MARKERS: [&str; 5] = [
        "<tool_call>",
        "<|tool▁call▁begin|>",
        "<｜tool▁call▁begin｜>",
        "<|python_tag|>",
        "[TOOL_CALLS]",
    ];

    TOOL_CALL_MARKERS
        .iter()
        .any(|marker| prefix.contains(marker))
}
33
34fn process_model_specific_message(message: &str) -> Result<String> {
35    static DEEPSEEK_REGEX: OnceLock<Regex> = OnceLock::new();
36    static QWEN_REGEX: OnceLock<Regex> = OnceLock::new();
37
38    // These are reasoning models so we need a regex.
39    let deepseek_regex = DEEPSEEK_REGEX.get_or_init(|| Regex::new(
40        r"(?s)<|tool▁call▁begin|>function<|tool▁sep|>(?P<name>[^\n]+)\n```json\n(?P<json>.+?)\n```<|tool▁call▁end|>",
41    ).unwrap());
42    let qwen_regex = QWEN_REGEX
43        .get_or_init(|| Regex::new(r"(?s)<tool_call>(?P<inner>.*?)</tool_call>").unwrap());
44
45    if let Some(message) = message.strip_prefix("<|python_tag|>") {
46        // Llama case
47        Ok(message.to_string())
48    } else if qwen_regex.is_match(message) {
49        if let Some(caps) = qwen_regex.captures(message) {
50            let inner = caps.name("inner").unwrap().as_str();
51            return Ok(inner.trim().to_string());
52        }
53        Ok(message.to_string())
54    } else if let Some(message) = message
55        .strip_prefix("[TOOL_CALLS][")
56        .and_then(|s| s.strip_suffix("]"))
57    {
58        // Mistral Nemo case
59        Ok(message.to_string())
60    } else if deepseek_regex.find(message).is_some() {
61        #[derive(Clone, Debug, serde::Serialize, serde::Deserialize)]
62        struct ToolCall {
63            name: String,
64            arguments: Value,
65        }
66        let mut calls = Vec::new();
67        for caps in deepseek_regex.captures_iter(message) {
68            let name = caps
69                .name("name")
70                .ok_or("Could not capture function name")
71                .map_err(candle_core::Error::msg)?
72                .as_str()
73                .trim()
74                .to_string();
75            let json_str = caps
76                .name("json")
77                .ok_or("Could not capture JSON arguments")
78                .map_err(candle_core::Error::msg)?
79                .as_str()
80                .trim();
81            let arguments: Value =
82                serde_json::from_str(json_str).map_err(candle_core::Error::msg)?;
83            calls.push(ToolCall { name, arguments });
84        }
85        Ok(serde_json::to_string(&calls).map_err(candle_core::Error::msg)?)
86    } else {
87        Ok(message.to_string())
88    }
89}
90
91pub struct ToolCallingMatcher {
92    tool_choice: ToolChoice,
93}
94
95// Same as CalledFunction, but has different cases for variations on the names
96#[derive(Clone, Debug, serde::Serialize, serde::Deserialize)]
97pub struct CalledFunctionParameters {
98    #[serde(alias = "function")]
99    pub name: String,
100    #[serde(alias = "arguments", deserialize_with = "flexible_args")]
101    pub parameters: Value,
102}
103
104// Accept either `{...}` **or** a `"stringified { ... }"`
105fn flexible_args<'de, D>(d: D) -> std::result::Result<Value, D::Error>
106where
107    D: Deserializer<'de>,
108{
109    struct ArgVisitor;
110
111    impl<'de> Visitor<'de> for ArgVisitor {
112        type Value = Value;
113
114        fn expecting(&self, f: &mut fmt::Formatter) -> fmt::Result {
115            f.write_str("an object or a JSON-encoded string containing an object")
116        }
117
118        // Case 1 – the good case: already a JSON object
119        fn visit_map<M>(self, mut m: M) -> std::result::Result<Self::Value, M::Error>
120        where
121            M: MapAccess<'de>,
122        {
123            let mut map = Map::new();
124            while let Some((k, v)) = m.next_entry()? {
125                map.insert(k, v);
126            }
127            Ok(Value::Object(map))
128        }
129
130        // Case 2 – got a *string*; try parsing it as JSON
131        fn visit_str<E>(self, s: &str) -> std::result::Result<Self::Value, E>
132        where
133            E: de::Error,
134        {
135            serde_json::from_str(s).map_err(|e| E::custom(format!("inner JSON error: {e}")))
136        }
137    }
138
139    d.deserialize_any(ArgVisitor)
140}
141
142/// Fixup potentially broken JSON
143/// 1) allow/handle arguments as maps in quotations
144fn fix_broken_json(raw: &str) -> anyhow::Result<String> {
145    // 1) Delete the opening quote that shouldn’t be there
146    let tmp = raw.replacen(r#""arguments":"{"#, r#""arguments":{"#, 1);
147    // 2) Delete the closing quote that matches it
148    let fixed = tmp.replacen(r#"}"}"#, r#"}}"#, 1);
149
150    Ok(fixed)
151}
152
153impl ToolCallingMatcher {
154    pub fn new(tool_choice: ToolChoice) -> anyhow::Result<Self> {
155        Ok(Self { tool_choice })
156    }
157
158    // Checks if the `message_prefix` could be a tool call. If false, either
159    // [`ToolChoice::None`] was selected, or the prefix could not match.
160    //
161    // If the start of a message could be a tool call, then it looks like an incomplete JSON of a given structure, e.g. `{"name": "foo", "param`.
162    //
163    // Returns a tuple of `(could_be_tool, is_complete_tool)`.
164    pub fn prefix_could_be_tool(
165        &self,
166        _pipeline: &dyn Pipeline,
167        message_prefix: &str,
168    ) -> Result<(bool, bool)> {
169        if matches!(self.tool_choice, ToolChoice::None) {
170            return Ok((false, false));
171        }
172        let message_prefix = process_model_specific_message(message_prefix)?;
173        let message_prefix = fix_broken_json(&message_prefix).unwrap();
174
175        // Check if the prefix could be a JSON serialization of any of the following types.
176        Ok([
177            could_be_json::<CalledFunctionParameters>,
178            could_be_json::<Vec<CalledFunctionParameters>>,
179        ]
180        .iter()
181        .find_map(|check| {
182            let (could_be_tool, is_complete_tool) = check(&message_prefix);
183            if could_be_tool || is_complete_tool {
184                Some((could_be_tool, is_complete_tool))
185            } else {
186                None
187            }
188        })
189        .unwrap_or((contains_tool_call_prefix(&message_prefix), false)))
190    }
191
192    pub fn get_call(
193        &self,
194        _pipeline: &dyn Pipeline,
195        message: &str,
196    ) -> anyhow::Result<Vec<ToolCallResponse>> {
197        if matches!(self.tool_choice, ToolChoice::None) {
198            return Ok(Vec::new());
199        }
200        let message = process_model_specific_message(message)?;
201        let message = fix_broken_json(&message).unwrap();
202
203        if let Ok(deser) = serde_json::from_str::<CalledFunctionParameters>(&message) {
204            let id = format!("call-{}", Uuid::new_v4());
205            Ok(vec![ToolCallResponse {
206                index: 0,
207                id,
208                tp: ToolCallType::Function,
209                function: CalledFunction {
210                    name: deser.name,
211                    arguments: serde_json::to_string(&deser.parameters)?,
212                },
213            }])
214        } else if let Ok(deser) = serde_json::from_str::<Vec<CalledFunctionParameters>>(&message) {
215            Ok(deser
216                .into_iter()
217                .enumerate()
218                .map(|(idx, deser)| {
219                    let id = format!("call-{}", Uuid::new_v4());
220                    Ok(ToolCallResponse {
221                        index: idx,
222                        id,
223                        tp: ToolCallType::Function,
224                        function: CalledFunction {
225                            name: deser.name,
226                            arguments: serde_json::to_string(&deser.parameters)?,
227                        },
228                    })
229                })
230                .collect::<anyhow::Result<Vec<_>>>()?)
231        } else {
232            if matches!(self.tool_choice, ToolChoice::Tool(_)) {
233                anyhow::bail!("Tool choice was required but no tools were called.")
234            }
235            Ok(Vec::new())
236        }
237    }
238}
239
240/// Checks if the given prefix could be the start of, or the entire JSON serialization of a given type, `T`.
241///
242/// Returns a tuple of `(could_be_tool, is_entire_tool)`.
243fn could_be_json<T>(text_prefix: &str) -> (bool, bool)
244where
245    T: serde::de::DeserializeOwned,
246{
247    if text_prefix.is_empty() {
248        return (false, false);
249    }
250    match serde_json::from_str::<T>(text_prefix) {
251        Ok(_) => (false, true),
252        // EOF show that JSON parsing was successful up to the end of the entire string.
253        Err(e) if e.is_eof() => (true, false),
254        _ => (false, false),
255    }
256}
257
258/// Takes raw UTf8 text and parses any possible tool calls from it.
259pub fn parse_text_tools<'a>(
260    pipeline: &dyn Pipeline,
261    raw_text: &'a str,
262    matcher: Option<Arc<ToolCallingMatcher>>,
263) -> anyhow::Result<(Option<&'a str>, Vec<ToolCallResponse>)> {
264    let mut tool_calls = Vec::new();
265    let mut text_new = Some(raw_text);
266
267    if let Some(ref matcher) = matcher {
268        let calls = matcher
269            .get_call(pipeline, raw_text)
270            .map_err(candle_core::Error::msg)?;
271        if !calls.is_empty() {
272            text_new = None;
273            tool_calls = calls;
274        }
275    };
276    Ok((text_new, tool_calls))
277}