mistralrs_core/tools/
mod.rs1mod request;
2mod response;
3
4use candle_core::Result;
5use regex::Regex;
6pub use request::*;
7pub use response::*;
8use serde::de::{self, Deserializer, MapAccess, Visitor};
9use serde_json::{Map, Value};
10use std::collections::HashMap;
11use std::fmt;
12use std::sync::{Arc, OnceLock};
13use uuid::Uuid;
14
15use crate::Pipeline;
16use mistralrs_mcp::CalledFunction;
17
18pub use mistralrs_mcp::{ToolCallback, ToolCallbackWithTool};
20
21pub type ToolCallbacks = HashMap<String, Arc<ToolCallback>>;
23
24pub type ToolCallbacksWithTools = HashMap<String, ToolCallbackWithTool>;
26
/// Returns `true` when `prefix` contains any of the markers that models emit to open a tool call.
///
/// The DeepSeek marker is checked in two spellings: with the ASCII bar `|` and with the fullwidth
/// bar `｜` (U+FF5C). NOTE(review): DeepSeek tokenizers are understood to use the fullwidth form;
/// accepting both keeps detection working whichever one the model actually produces.
fn contains_tool_call_prefix(prefix: &str) -> bool {
    const TOOL_CALL_MARKERS: [&str; 5] = [
        "<tool_call>",
        "<|tool▁call▁begin|>",
        "<｜tool▁call▁begin｜>",
        "<|python_tag|>",
        "[TOOL_CALLS]",
    ];

    TOOL_CALL_MARKERS
        .iter()
        .any(|marker| prefix.contains(*marker))
}
33
34fn process_model_specific_message(message: &str) -> Result<String> {
35 static DEEPSEEK_REGEX: OnceLock<Regex> = OnceLock::new();
36 static QWEN_REGEX: OnceLock<Regex> = OnceLock::new();
37
38 let deepseek_regex = DEEPSEEK_REGEX.get_or_init(|| Regex::new(
40 r"(?s)<|tool▁call▁begin|>function<|tool▁sep|>(?P<name>[^\n]+)\n```json\n(?P<json>.+?)\n```<|tool▁call▁end|>",
41 ).unwrap());
42 let qwen_regex = QWEN_REGEX
43 .get_or_init(|| Regex::new(r"(?s)<tool_call>(?P<inner>.*?)</tool_call>").unwrap());
44
45 if let Some(message) = message.strip_prefix("<|python_tag|>") {
46 Ok(message.to_string())
48 } else if qwen_regex.is_match(message) {
49 if let Some(caps) = qwen_regex.captures(message) {
50 let inner = caps.name("inner").unwrap().as_str();
51 return Ok(inner.trim().to_string());
52 }
53 Ok(message.to_string())
54 } else if let Some(message) = message
55 .strip_prefix("[TOOL_CALLS][")
56 .and_then(|s| s.strip_suffix("]"))
57 {
58 Ok(message.to_string())
60 } else if deepseek_regex.find(message).is_some() {
61 #[derive(Clone, Debug, serde::Serialize, serde::Deserialize)]
62 struct ToolCall {
63 name: String,
64 arguments: Value,
65 }
66 let mut calls = Vec::new();
67 for caps in deepseek_regex.captures_iter(message) {
68 let name = caps
69 .name("name")
70 .ok_or("Could not capture function name")
71 .map_err(candle_core::Error::msg)?
72 .as_str()
73 .trim()
74 .to_string();
75 let json_str = caps
76 .name("json")
77 .ok_or("Could not capture JSON arguments")
78 .map_err(candle_core::Error::msg)?
79 .as_str()
80 .trim();
81 let arguments: Value =
82 serde_json::from_str(json_str).map_err(candle_core::Error::msg)?;
83 calls.push(ToolCall { name, arguments });
84 }
85 Ok(serde_json::to_string(&calls).map_err(candle_core::Error::msg)?)
86 } else {
87 Ok(message.to_string())
88 }
89}
90
91pub struct ToolCallingMatcher {
92 tool_choice: ToolChoice,
93}
94
95#[derive(Clone, Debug, serde::Serialize, serde::Deserialize)]
97pub struct CalledFunctionParameters {
98 #[serde(alias = "function")]
99 pub name: String,
100 #[serde(alias = "arguments", deserialize_with = "flexible_args")]
101 pub parameters: Value,
102}
103
104fn flexible_args<'de, D>(d: D) -> std::result::Result<Value, D::Error>
106where
107 D: Deserializer<'de>,
108{
109 struct ArgVisitor;
110
111 impl<'de> Visitor<'de> for ArgVisitor {
112 type Value = Value;
113
114 fn expecting(&self, f: &mut fmt::Formatter) -> fmt::Result {
115 f.write_str("an object or a JSON-encoded string containing an object")
116 }
117
118 fn visit_map<M>(self, mut m: M) -> std::result::Result<Self::Value, M::Error>
120 where
121 M: MapAccess<'de>,
122 {
123 let mut map = Map::new();
124 while let Some((k, v)) = m.next_entry()? {
125 map.insert(k, v);
126 }
127 Ok(Value::Object(map))
128 }
129
130 fn visit_str<E>(self, s: &str) -> std::result::Result<Self::Value, E>
132 where
133 E: de::Error,
134 {
135 serde_json::from_str(s).map_err(|e| E::custom(format!("inner JSON error: {e}")))
136 }
137 }
138
139 d.deserialize_any(ArgVisitor)
140}
141
142fn fix_broken_json(raw: &str) -> anyhow::Result<String> {
145 if raw.contains(r#""arguments":"{"#) {
148 let tmp = raw.replacen(r#""arguments":"{"#, r#""arguments":{"#, 1);
150 let fixed = tmp.replacen(r#"}"}"#, r#"}}"#, 1);
152 Ok(fixed)
153 } else {
154 Ok(raw.to_string())
155 }
156}
157
158impl ToolCallingMatcher {
159 pub fn new(tool_choice: ToolChoice) -> anyhow::Result<Self> {
160 Ok(Self { tool_choice })
161 }
162
163 pub fn prefix_could_be_tool(
170 &self,
171 _pipeline: &dyn Pipeline,
172 message_prefix: &str,
173 ) -> Result<(bool, bool)> {
174 if matches!(self.tool_choice, ToolChoice::None) {
175 return Ok((false, false));
176 }
177 let message_prefix = process_model_specific_message(message_prefix)?;
178 let message_prefix = fix_broken_json(&message_prefix).unwrap();
179
180 Ok([
182 could_be_json::<CalledFunctionParameters>,
183 could_be_json::<Vec<CalledFunctionParameters>>,
184 ]
185 .iter()
186 .find_map(|check| {
187 let (could_be_tool, is_complete_tool) = check(&message_prefix);
188 if could_be_tool || is_complete_tool {
189 Some((could_be_tool, is_complete_tool))
190 } else {
191 None
192 }
193 })
194 .unwrap_or((contains_tool_call_prefix(&message_prefix), false)))
195 }
196
197 pub fn get_call(
198 &self,
199 _pipeline: &dyn Pipeline,
200 message: &str,
201 ) -> anyhow::Result<Vec<ToolCallResponse>> {
202 if matches!(self.tool_choice, ToolChoice::None) {
203 return Ok(Vec::new());
204 }
205 let message = process_model_specific_message(message)?;
206 let message = fix_broken_json(&message).unwrap();
207
208 if let Ok(deser) = serde_json::from_str::<CalledFunctionParameters>(&message) {
209 let id = format!("call-{}", Uuid::new_v4());
210 Ok(vec![ToolCallResponse {
211 index: 0,
212 id,
213 tp: ToolCallType::Function,
214 function: CalledFunction {
215 name: deser.name,
216 arguments: serde_json::to_string(&deser.parameters)?,
217 },
218 }])
219 } else if let Ok(deser) = serde_json::from_str::<Vec<CalledFunctionParameters>>(&message) {
220 Ok(deser
221 .into_iter()
222 .enumerate()
223 .map(|(idx, deser)| {
224 let id = format!("call-{}", Uuid::new_v4());
225 Ok(ToolCallResponse {
226 index: idx,
227 id,
228 tp: ToolCallType::Function,
229 function: CalledFunction {
230 name: deser.name,
231 arguments: serde_json::to_string(&deser.parameters)?,
232 },
233 })
234 })
235 .collect::<anyhow::Result<Vec<_>>>()?)
236 } else {
237 if matches!(self.tool_choice, ToolChoice::Tool(_)) {
238 anyhow::bail!("Tool choice was required but no tools were called.")
239 }
240 Ok(Vec::new())
241 }
242 }
243}
244
245fn could_be_json<T>(text_prefix: &str) -> (bool, bool)
249where
250 T: serde::de::DeserializeOwned,
251{
252 if text_prefix.trim().is_empty() {
253 return (false, false);
254 }
255 match serde_json::from_str::<T>(text_prefix) {
256 Ok(_) => (false, true),
257 Err(e) if e.is_eof() => (true, false),
259 _ => (false, false),
260 }
261}
262
263pub fn parse_text_tools<'a>(
265 pipeline: &dyn Pipeline,
266 raw_text: &'a str,
267 matcher: Option<Arc<ToolCallingMatcher>>,
268) -> anyhow::Result<(Option<&'a str>, Vec<ToolCallResponse>)> {
269 let mut tool_calls = Vec::new();
270 let mut text_new = Some(raw_text);
271
272 if let Some(ref matcher) = matcher {
273 let calls = matcher
274 .get_call(pipeline, raw_text)
275 .map_err(candle_core::Error::msg)?;
276 if !calls.is_empty() {
277 text_new = None;
278 tool_calls = calls;
279 }
280 };
281 Ok((text_new, tool_calls))
282}