// mistralrs_core/tools/mod.rs

1mod request;
2mod response;
3
4use candle_core::Result;
5use regex::Regex;
6pub use request::*;
7pub use response::*;
8use serde::de::{self, Deserializer, MapAccess, Visitor};
9use serde_json::{Map, Value};
10use std::collections::HashMap;
11use std::fmt;
12use std::sync::{Arc, OnceLock};
13use uuid::Uuid;
14
15use crate::Pipeline;
16
17/// Callback used for custom tool functions. Receives the called function
18/// (name and JSON arguments) and returns the tool output as a string.
19pub type ToolCallback = dyn Fn(&CalledFunction) -> anyhow::Result<String> + Send + Sync;
20
21/// Collection of callbacks keyed by tool name.
22pub type ToolCallbacks = HashMap<String, Arc<ToolCallback>>;
23
/// Returns `true` if `prefix` contains any known tool-call start marker.
///
/// Recognized markers:
/// - Qwen-style `<tool_call>`
/// - Llama's `<|python_tag|>`
/// - Mistral's `[TOOL_CALLS]`
/// - DeepSeek's tool-call begin marker, in both the ASCII-bar spelling and the full-width-bar
///   (`｜`, U+FF5C) spelling used in DeepSeek chat templates
///   (NOTE(review): confirm against the tokenizer in use).
fn contains_tool_call_prefix(prefix: &str) -> bool {
    const MARKERS: [&str; 5] = [
        "<tool_call>",
        "<|tool▁call▁begin|>",
        "<｜tool▁call▁begin｜>",
        "<|python_tag|>",
        "[TOOL_CALLS]",
    ];
    MARKERS.iter().any(|&marker| prefix.contains(marker))
}
30
31fn process_model_specific_message(message: &str) -> Result<String> {
32    static DEEPSEEK_REGEX: OnceLock<Regex> = OnceLock::new();
33    static QWEN_REGEX: OnceLock<Regex> = OnceLock::new();
34
35    // These are reasoning models so we need a regex.
36    let deepseek_regex = DEEPSEEK_REGEX.get_or_init(|| Regex::new(
37        r"(?s)<|tool▁call▁begin|>function<|tool▁sep|>(?P<name>[^\n]+)\n```json\n(?P<json>.+?)\n```<|tool▁call▁end|>",
38    ).unwrap());
39    let qwen_regex = QWEN_REGEX
40        .get_or_init(|| Regex::new(r"(?s)<tool_call>(?P<inner>.*?)</tool_call>").unwrap());
41
42    if let Some(message) = message.strip_prefix("<|python_tag|>") {
43        // Llama case
44        Ok(message.to_string())
45    } else if qwen_regex.is_match(message) {
46        if let Some(caps) = qwen_regex.captures(message) {
47            let inner = caps.name("inner").unwrap().as_str();
48            return Ok(inner.trim().to_string());
49        }
50        Ok(message.to_string())
51    } else if let Some(message) = message
52        .strip_prefix("[TOOL_CALLS][")
53        .and_then(|s| s.strip_suffix("]"))
54    {
55        // Mistral Nemo case
56        Ok(message.to_string())
57    } else if deepseek_regex.find(message).is_some() {
58        #[derive(Clone, Debug, serde::Serialize, serde::Deserialize)]
59        struct ToolCall {
60            name: String,
61            arguments: Value,
62        }
63        let mut calls = Vec::new();
64        for caps in deepseek_regex.captures_iter(message) {
65            let name = caps
66                .name("name")
67                .ok_or("Could not capture function name")
68                .map_err(candle_core::Error::msg)?
69                .as_str()
70                .trim()
71                .to_string();
72            let json_str = caps
73                .name("json")
74                .ok_or("Could not capture JSON arguments")
75                .map_err(candle_core::Error::msg)?
76                .as_str()
77                .trim();
78            let arguments: Value =
79                serde_json::from_str(json_str).map_err(candle_core::Error::msg)?;
80            calls.push(ToolCall { name, arguments });
81        }
82        Ok(serde_json::to_string(&calls).map_err(candle_core::Error::msg)?)
83    } else {
84        Ok(message.to_string())
85    }
86}
87
88pub struct ToolCallingMatcher {
89    tool_choice: ToolChoice,
90}
91
92// Same as CalledFunction, but has different cases for variations on the names
93#[derive(Clone, Debug, serde::Serialize, serde::Deserialize)]
94pub struct CalledFunctionParameters {
95    #[serde(alias = "function")]
96    pub name: String,
97    #[serde(alias = "arguments", deserialize_with = "flexible_args")]
98    pub parameters: Value,
99}
100
101// Accept either `{...}` **or** a `"stringified { ... }"`
102fn flexible_args<'de, D>(d: D) -> std::result::Result<Value, D::Error>
103where
104    D: Deserializer<'de>,
105{
106    struct ArgVisitor;
107
108    impl<'de> Visitor<'de> for ArgVisitor {
109        type Value = Value;
110
111        fn expecting(&self, f: &mut fmt::Formatter) -> fmt::Result {
112            f.write_str("an object or a JSON-encoded string containing an object")
113        }
114
115        // Case 1 – the good case: already a JSON object
116        fn visit_map<M>(self, mut m: M) -> std::result::Result<Self::Value, M::Error>
117        where
118            M: MapAccess<'de>,
119        {
120            let mut map = Map::new();
121            while let Some((k, v)) = m.next_entry()? {
122                map.insert(k, v);
123            }
124            Ok(Value::Object(map))
125        }
126
127        // Case 2 – got a *string*; try parsing it as JSON
128        fn visit_str<E>(self, s: &str) -> std::result::Result<Self::Value, E>
129        where
130            E: de::Error,
131        {
132            serde_json::from_str(s).map_err(|e| E::custom(format!("inner JSON error: {e}")))
133        }
134    }
135
136    d.deserialize_any(ArgVisitor)
137}
138
139/// Fixup potentially broken JSON
140/// 1) allow/handle arguments as maps in quotations
141fn fix_broken_json(raw: &str) -> anyhow::Result<String> {
142    // 1) Delete the opening quote that shouldn’t be there
143    let tmp = raw.replacen(r#""arguments":"{"#, r#""arguments":{"#, 1);
144    // 2) Delete the closing quote that matches it
145    let fixed = tmp.replacen(r#"}"}"#, r#"}}"#, 1);
146
147    Ok(fixed)
148}
149
150impl ToolCallingMatcher {
151    pub fn new(tool_choice: ToolChoice) -> anyhow::Result<Self> {
152        Ok(Self { tool_choice })
153    }
154
155    // Checks if the `message_prefix` could be a tool call. If false, either
156    // [`ToolChoice::None`] was selected, or the prefix could not match.
157    //
158    // If the start of a message could be a tool call, then it looks like an incomplete JSON of a given structure, e.g. `{"name": "foo", "param`.
159    //
160    // Returns a tuple of `(could_be_tool, is_complete_tool)`.
161    pub fn prefix_could_be_tool(
162        &self,
163        _pipeline: &dyn Pipeline,
164        message_prefix: &str,
165    ) -> Result<(bool, bool)> {
166        if matches!(self.tool_choice, ToolChoice::None) {
167            return Ok((false, false));
168        }
169        let message_prefix = process_model_specific_message(message_prefix)?;
170        let message_prefix = fix_broken_json(&message_prefix).unwrap();
171
172        // Check if the prefix could be a JSON serialization of any of the following types.
173        Ok([
174            could_be_json::<CalledFunctionParameters>,
175            could_be_json::<Vec<CalledFunctionParameters>>,
176        ]
177        .iter()
178        .find_map(|check| {
179            let (could_be_tool, is_complete_tool) = check(&message_prefix);
180            if could_be_tool || is_complete_tool {
181                Some((could_be_tool, is_complete_tool))
182            } else {
183                None
184            }
185        })
186        .unwrap_or((contains_tool_call_prefix(&message_prefix), false)))
187    }
188
189    pub fn get_call(
190        &self,
191        _pipeline: &dyn Pipeline,
192        message: &str,
193    ) -> anyhow::Result<Vec<ToolCallResponse>> {
194        if matches!(self.tool_choice, ToolChoice::None) {
195            return Ok(Vec::new());
196        }
197        let message = process_model_specific_message(message)?;
198        let message = fix_broken_json(&message).unwrap();
199
200        if let Ok(deser) = serde_json::from_str::<CalledFunctionParameters>(&message) {
201            let id = format!("call-{}", Uuid::new_v4());
202            Ok(vec![ToolCallResponse {
203                id,
204                tp: ToolCallType::Function,
205                function: CalledFunction {
206                    name: deser.name,
207                    arguments: serde_json::to_string(&deser.parameters)?,
208                },
209            }])
210        } else if let Ok(deser) = serde_json::from_str::<Vec<CalledFunctionParameters>>(&message) {
211            Ok(deser
212                .into_iter()
213                .map(|deser| {
214                    let id = format!("call-{}", Uuid::new_v4());
215                    Ok(ToolCallResponse {
216                        id,
217                        tp: ToolCallType::Function,
218                        function: CalledFunction {
219                            name: deser.name,
220                            arguments: serde_json::to_string(&deser.parameters)?,
221                        },
222                    })
223                })
224                .collect::<anyhow::Result<Vec<_>>>()?)
225        } else {
226            if matches!(self.tool_choice, ToolChoice::Tool(_)) {
227                anyhow::bail!("Tool choice was required but no tools were called.")
228            }
229            Ok(Vec::new())
230        }
231    }
232}
233
234/// Checks if the given prefix could be the start of, or the entire JSON serialization of a given type, `T`.
235///
236/// Returns a tuple of `(could_be_tool, is_entire_tool)`.
237fn could_be_json<T>(text_prefix: &str) -> (bool, bool)
238where
239    T: serde::de::DeserializeOwned,
240{
241    if text_prefix.is_empty() {
242        return (false, false);
243    }
244    match serde_json::from_str::<T>(text_prefix) {
245        Ok(_) => (false, true),
246        // EOF show that JSON parsing was successful up to the end of the entire string.
247        Err(e) if e.is_eof() => (true, false),
248        _ => (false, false),
249    }
250}
251
252/// Takes raw UTf8 text and parses any possible tool calls from it.
253pub fn parse_text_tools<'a>(
254    pipeline: &dyn Pipeline,
255    raw_text: &'a str,
256    matcher: Option<Arc<ToolCallingMatcher>>,
257) -> anyhow::Result<(Option<&'a str>, Vec<ToolCallResponse>)> {
258    let mut tool_calls = Vec::new();
259    let mut text_new = Some(raw_text);
260
261    if let Some(ref matcher) = matcher {
262        let calls = matcher
263            .get_call(pipeline, raw_text)
264            .map_err(candle_core::Error::msg)?;
265        if !calls.is_empty() {
266            text_new = None;
267            tool_calls = calls;
268        }
269    };
270    Ok((text_new, tool_calls))
271}