mistralrs_core/tools/
mod.rs1mod request;
2mod response;
3
4use candle_core::Result;
5use regex::Regex;
6pub use request::*;
7pub use response::*;
8use serde::de::{self, Deserializer, MapAccess, Visitor};
9use serde_json::{Map, Value};
10use std::collections::HashMap;
11use std::fmt;
12use std::sync::{Arc, OnceLock};
13use uuid::Uuid;
14
15use crate::Pipeline;
16use mistralrs_mcp::CalledFunction;
17
18pub use mistralrs_mcp::{ToolCallback, ToolCallbackWithTool};
20
21pub type ToolCallbacks = HashMap<String, Arc<ToolCallback>>;
23
24pub type ToolCallbacksWithTools = HashMap<String, ToolCallbackWithTool>;
26
/// Returns `true` if `prefix` contains any marker that opens a tool call in
/// one of the output formats handled by `process_model_specific_message`.
///
/// The DeepSeek marker is accepted with both the ASCII bar `|` and the
/// fullwidth bar `｜` (U+FF5C) used by DeepSeek's special tokens.
fn contains_tool_call_prefix(prefix: &str) -> bool {
    const MARKERS: [&str; 5] = [
        "<tool_call>",
        "<|tool▁call▁begin|>",
        "<｜tool▁call▁begin｜>",
        "<|python_tag|>",
        "[TOOL_CALLS]",
    ];
    MARKERS.iter().any(|&marker| prefix.contains(marker))
}
33
34fn process_model_specific_message(message: &str) -> Result<String> {
35 static DEEPSEEK_REGEX: OnceLock<Regex> = OnceLock::new();
36 static QWEN_REGEX: OnceLock<Regex> = OnceLock::new();
37
38 let deepseek_regex = DEEPSEEK_REGEX.get_or_init(|| Regex::new(
40 r"(?s)<|tool▁call▁begin|>function<|tool▁sep|>(?P<name>[^\n]+)\n```json\n(?P<json>.+?)\n```<|tool▁call▁end|>",
41 ).unwrap());
42 let qwen_regex = QWEN_REGEX
43 .get_or_init(|| Regex::new(r"(?s)<tool_call>(?P<inner>.*?)</tool_call>").unwrap());
44
45 if let Some(message) = message.strip_prefix("<|python_tag|>") {
46 Ok(message.to_string())
48 } else if qwen_regex.is_match(message) {
49 if let Some(caps) = qwen_regex.captures(message) {
50 let inner = caps.name("inner").unwrap().as_str();
51 return Ok(inner.trim().to_string());
52 }
53 Ok(message.to_string())
54 } else if let Some(message) = message
55 .strip_prefix("[TOOL_CALLS][")
56 .and_then(|s| s.strip_suffix("]"))
57 {
58 Ok(message.to_string())
60 } else if deepseek_regex.find(message).is_some() {
61 #[derive(Clone, Debug, serde::Serialize, serde::Deserialize)]
62 struct ToolCall {
63 name: String,
64 arguments: Value,
65 }
66 let mut calls = Vec::new();
67 for caps in deepseek_regex.captures_iter(message) {
68 let name = caps
69 .name("name")
70 .ok_or("Could not capture function name")
71 .map_err(candle_core::Error::msg)?
72 .as_str()
73 .trim()
74 .to_string();
75 let json_str = caps
76 .name("json")
77 .ok_or("Could not capture JSON arguments")
78 .map_err(candle_core::Error::msg)?
79 .as_str()
80 .trim();
81 let arguments: Value =
82 serde_json::from_str(json_str).map_err(candle_core::Error::msg)?;
83 calls.push(ToolCall { name, arguments });
84 }
85 Ok(serde_json::to_string(&calls).map_err(candle_core::Error::msg)?)
86 } else {
87 Ok(message.to_string())
88 }
89}
90
91pub struct ToolCallingMatcher {
92 tool_choice: ToolChoice,
93}
94
95#[derive(Clone, Debug, serde::Serialize, serde::Deserialize)]
97pub struct CalledFunctionParameters {
98 #[serde(alias = "function")]
99 pub name: String,
100 #[serde(alias = "arguments", deserialize_with = "flexible_args")]
101 pub parameters: Value,
102}
103
104fn flexible_args<'de, D>(d: D) -> std::result::Result<Value, D::Error>
106where
107 D: Deserializer<'de>,
108{
109 struct ArgVisitor;
110
111 impl<'de> Visitor<'de> for ArgVisitor {
112 type Value = Value;
113
114 fn expecting(&self, f: &mut fmt::Formatter) -> fmt::Result {
115 f.write_str("an object or a JSON-encoded string containing an object")
116 }
117
118 fn visit_map<M>(self, mut m: M) -> std::result::Result<Self::Value, M::Error>
120 where
121 M: MapAccess<'de>,
122 {
123 let mut map = Map::new();
124 while let Some((k, v)) = m.next_entry()? {
125 map.insert(k, v);
126 }
127 Ok(Value::Object(map))
128 }
129
130 fn visit_str<E>(self, s: &str) -> std::result::Result<Self::Value, E>
132 where
133 E: de::Error,
134 {
135 serde_json::from_str(s).map_err(|e| E::custom(format!("inner JSON error: {e}")))
136 }
137 }
138
139 d.deserialize_any(ArgVisitor)
140}
141
142fn fix_broken_json(raw: &str) -> anyhow::Result<String> {
145 let tmp = raw.replacen(r#""arguments":"{"#, r#""arguments":{"#, 1);
147 let fixed = tmp.replacen(r#"}"}"#, r#"}}"#, 1);
149
150 Ok(fixed)
151}
152
153impl ToolCallingMatcher {
154 pub fn new(tool_choice: ToolChoice) -> anyhow::Result<Self> {
155 Ok(Self { tool_choice })
156 }
157
158 pub fn prefix_could_be_tool(
165 &self,
166 _pipeline: &dyn Pipeline,
167 message_prefix: &str,
168 ) -> Result<(bool, bool)> {
169 if matches!(self.tool_choice, ToolChoice::None) {
170 return Ok((false, false));
171 }
172 let message_prefix = process_model_specific_message(message_prefix)?;
173 let message_prefix = fix_broken_json(&message_prefix).unwrap();
174
175 Ok([
177 could_be_json::<CalledFunctionParameters>,
178 could_be_json::<Vec<CalledFunctionParameters>>,
179 ]
180 .iter()
181 .find_map(|check| {
182 let (could_be_tool, is_complete_tool) = check(&message_prefix);
183 if could_be_tool || is_complete_tool {
184 Some((could_be_tool, is_complete_tool))
185 } else {
186 None
187 }
188 })
189 .unwrap_or((contains_tool_call_prefix(&message_prefix), false)))
190 }
191
192 pub fn get_call(
193 &self,
194 _pipeline: &dyn Pipeline,
195 message: &str,
196 ) -> anyhow::Result<Vec<ToolCallResponse>> {
197 if matches!(self.tool_choice, ToolChoice::None) {
198 return Ok(Vec::new());
199 }
200 let message = process_model_specific_message(message)?;
201 let message = fix_broken_json(&message).unwrap();
202
203 if let Ok(deser) = serde_json::from_str::<CalledFunctionParameters>(&message) {
204 let id = format!("call-{}", Uuid::new_v4());
205 Ok(vec![ToolCallResponse {
206 index: 0,
207 id,
208 tp: ToolCallType::Function,
209 function: CalledFunction {
210 name: deser.name,
211 arguments: serde_json::to_string(&deser.parameters)?,
212 },
213 }])
214 } else if let Ok(deser) = serde_json::from_str::<Vec<CalledFunctionParameters>>(&message) {
215 Ok(deser
216 .into_iter()
217 .enumerate()
218 .map(|(idx, deser)| {
219 let id = format!("call-{}", Uuid::new_v4());
220 Ok(ToolCallResponse {
221 index: idx,
222 id,
223 tp: ToolCallType::Function,
224 function: CalledFunction {
225 name: deser.name,
226 arguments: serde_json::to_string(&deser.parameters)?,
227 },
228 })
229 })
230 .collect::<anyhow::Result<Vec<_>>>()?)
231 } else {
232 if matches!(self.tool_choice, ToolChoice::Tool(_)) {
233 anyhow::bail!("Tool choice was required but no tools were called.")
234 }
235 Ok(Vec::new())
236 }
237 }
238}
239
240fn could_be_json<T>(text_prefix: &str) -> (bool, bool)
244where
245 T: serde::de::DeserializeOwned,
246{
247 if text_prefix.is_empty() {
248 return (false, false);
249 }
250 match serde_json::from_str::<T>(text_prefix) {
251 Ok(_) => (false, true),
252 Err(e) if e.is_eof() => (true, false),
254 _ => (false, false),
255 }
256}
257
258pub fn parse_text_tools<'a>(
260 pipeline: &dyn Pipeline,
261 raw_text: &'a str,
262 matcher: Option<Arc<ToolCallingMatcher>>,
263) -> anyhow::Result<(Option<&'a str>, Vec<ToolCallResponse>)> {
264 let mut tool_calls = Vec::new();
265 let mut text_new = Some(raw_text);
266
267 if let Some(ref matcher) = matcher {
268 let calls = matcher
269 .get_call(pipeline, raw_text)
270 .map_err(candle_core::Error::msg)?;
271 if !calls.is_empty() {
272 text_new = None;
273 tool_calls = calls;
274 }
275 };
276 Ok((text_new, tool_calls))
277}