mistralrs_core/tools/
mod.rs1mod request;
2mod response;
3
4use candle_core::Result;
5use regex::Regex;
6pub use request::*;
7pub use response::*;
8use serde::de::{self, Deserializer, MapAccess, Visitor};
9use serde_json::{Map, Value};
10use std::collections::HashMap;
11use std::fmt;
12use std::sync::{Arc, OnceLock};
13use uuid::Uuid;
14
15use crate::Pipeline;
16
/// Unsized callable type for a tool implementation: it is handed a reference
/// to the [`CalledFunction`] and returns either a `String` result or an
/// `anyhow` error. The `Send + Sync` bounds allow it to be shared across
/// threads.
pub type ToolCallback = dyn Fn(&CalledFunction) -> anyhow::Result<String> + Send + Sync;

/// Reference-counted [`ToolCallback`]s stored in a map under `String` keys.
// NOTE(review): the keys are presumably tool (function) names — confirm against callers.
pub type ToolCallbacks = HashMap<String, Arc<ToolCallback>>;
23
/// Reports whether `prefix` contains one of the known tool-call start markers.
///
/// The DeepSeek marker is listed twice: DeepSeek's chat template spells it
/// with the fullwidth vertical bar (U+FF5C), while the ASCII-bar spelling is
/// kept so that text which was already recognised keeps being recognised.
fn contains_tool_call_prefix(prefix: &str) -> bool {
    const MARKERS: [&str; 5] = [
        "<tool_call>",
        "<|tool▁call▁begin|>",
        "<｜tool▁call▁begin｜>",
        "<|python_tag|>",
        "[TOOL_CALLS]",
    ];

    MARKERS.iter().any(|marker| prefix.contains(marker))
}
30
31fn process_model_specific_message(message: &str) -> Result<String> {
32 static DEEPSEEK_REGEX: OnceLock<Regex> = OnceLock::new();
33 static QWEN_REGEX: OnceLock<Regex> = OnceLock::new();
34
35 let deepseek_regex = DEEPSEEK_REGEX.get_or_init(|| Regex::new(
37 r"(?s)<|tool▁call▁begin|>function<|tool▁sep|>(?P<name>[^\n]+)\n```json\n(?P<json>.+?)\n```<|tool▁call▁end|>",
38 ).unwrap());
39 let qwen_regex = QWEN_REGEX
40 .get_or_init(|| Regex::new(r"(?s)<tool_call>(?P<inner>.*?)</tool_call>").unwrap());
41
42 if let Some(message) = message.strip_prefix("<|python_tag|>") {
43 Ok(message.to_string())
45 } else if qwen_regex.is_match(message) {
46 if let Some(caps) = qwen_regex.captures(message) {
47 let inner = caps.name("inner").unwrap().as_str();
48 return Ok(inner.trim().to_string());
49 }
50 Ok(message.to_string())
51 } else if let Some(message) = message
52 .strip_prefix("[TOOL_CALLS][")
53 .and_then(|s| s.strip_suffix("]"))
54 {
55 Ok(message.to_string())
57 } else if deepseek_regex.find(message).is_some() {
58 #[derive(Clone, Debug, serde::Serialize, serde::Deserialize)]
59 struct ToolCall {
60 name: String,
61 arguments: Value,
62 }
63 let mut calls = Vec::new();
64 for caps in deepseek_regex.captures_iter(message) {
65 let name = caps
66 .name("name")
67 .ok_or("Could not capture function name")
68 .map_err(candle_core::Error::msg)?
69 .as_str()
70 .trim()
71 .to_string();
72 let json_str = caps
73 .name("json")
74 .ok_or("Could not capture JSON arguments")
75 .map_err(candle_core::Error::msg)?
76 .as_str()
77 .trim();
78 let arguments: Value =
79 serde_json::from_str(json_str).map_err(candle_core::Error::msg)?;
80 calls.push(ToolCall { name, arguments });
81 }
82 Ok(serde_json::to_string(&calls).map_err(candle_core::Error::msg)?)
83 } else {
84 Ok(message.to_string())
85 }
86}
87
/// Detects and extracts tool calls from model output text.
pub struct ToolCallingMatcher {
    /// Tool-choice setting supplied at construction; the matcher's methods
    /// short-circuit when it is `ToolChoice::None`.
    tool_choice: ToolChoice,
}
91
/// Deserialization target for a single tool call found in model output.
#[derive(Clone, Debug, serde::Serialize, serde::Deserialize)]
pub struct CalledFunctionParameters {
    /// Name of the called function; the key `function` is accepted as an
    /// alias when deserializing.
    #[serde(alias = "function")]
    pub name: String,
    /// Call arguments; the key `arguments` is accepted as an alias, and the
    /// value is read through `flexible_args` rather than the default
    /// `Value` deserializer.
    #[serde(alias = "arguments", deserialize_with = "flexible_args")]
    pub parameters: Value,
}
100
101fn flexible_args<'de, D>(d: D) -> std::result::Result<Value, D::Error>
103where
104 D: Deserializer<'de>,
105{
106 struct ArgVisitor;
107
108 impl<'de> Visitor<'de> for ArgVisitor {
109 type Value = Value;
110
111 fn expecting(&self, f: &mut fmt::Formatter) -> fmt::Result {
112 f.write_str("an object or a JSON-encoded string containing an object")
113 }
114
115 fn visit_map<M>(self, mut m: M) -> std::result::Result<Self::Value, M::Error>
117 where
118 M: MapAccess<'de>,
119 {
120 let mut map = Map::new();
121 while let Some((k, v)) = m.next_entry()? {
122 map.insert(k, v);
123 }
124 Ok(Value::Object(map))
125 }
126
127 fn visit_str<E>(self, s: &str) -> std::result::Result<Self::Value, E>
129 where
130 E: de::Error,
131 {
132 serde_json::from_str(s).map_err(|e| E::custom(format!("inner JSON error: {e}")))
133 }
134 }
135
136 d.deserialize_any(ArgVisitor)
137}
138
139fn fix_broken_json(raw: &str) -> anyhow::Result<String> {
142 let tmp = raw.replacen(r#""arguments":"{"#, r#""arguments":{"#, 1);
144 let fixed = tmp.replacen(r#"}"}"#, r#"}}"#, 1);
146
147 Ok(fixed)
148}
149
150impl ToolCallingMatcher {
151 pub fn new(tool_choice: ToolChoice) -> anyhow::Result<Self> {
152 Ok(Self { tool_choice })
153 }
154
155 pub fn prefix_could_be_tool(
162 &self,
163 _pipeline: &dyn Pipeline,
164 message_prefix: &str,
165 ) -> Result<(bool, bool)> {
166 if matches!(self.tool_choice, ToolChoice::None) {
167 return Ok((false, false));
168 }
169 let message_prefix = process_model_specific_message(message_prefix)?;
170 let message_prefix = fix_broken_json(&message_prefix).unwrap();
171
172 Ok([
174 could_be_json::<CalledFunctionParameters>,
175 could_be_json::<Vec<CalledFunctionParameters>>,
176 ]
177 .iter()
178 .find_map(|check| {
179 let (could_be_tool, is_complete_tool) = check(&message_prefix);
180 if could_be_tool || is_complete_tool {
181 Some((could_be_tool, is_complete_tool))
182 } else {
183 None
184 }
185 })
186 .unwrap_or((contains_tool_call_prefix(&message_prefix), false)))
187 }
188
189 pub fn get_call(
190 &self,
191 _pipeline: &dyn Pipeline,
192 message: &str,
193 ) -> anyhow::Result<Vec<ToolCallResponse>> {
194 if matches!(self.tool_choice, ToolChoice::None) {
195 return Ok(Vec::new());
196 }
197 let message = process_model_specific_message(message)?;
198 let message = fix_broken_json(&message).unwrap();
199
200 if let Ok(deser) = serde_json::from_str::<CalledFunctionParameters>(&message) {
201 let id = format!("call-{}", Uuid::new_v4());
202 Ok(vec![ToolCallResponse {
203 id,
204 tp: ToolCallType::Function,
205 function: CalledFunction {
206 name: deser.name,
207 arguments: serde_json::to_string(&deser.parameters)?,
208 },
209 }])
210 } else if let Ok(deser) = serde_json::from_str::<Vec<CalledFunctionParameters>>(&message) {
211 Ok(deser
212 .into_iter()
213 .map(|deser| {
214 let id = format!("call-{}", Uuid::new_v4());
215 Ok(ToolCallResponse {
216 id,
217 tp: ToolCallType::Function,
218 function: CalledFunction {
219 name: deser.name,
220 arguments: serde_json::to_string(&deser.parameters)?,
221 },
222 })
223 })
224 .collect::<anyhow::Result<Vec<_>>>()?)
225 } else {
226 if matches!(self.tool_choice, ToolChoice::Tool(_)) {
227 anyhow::bail!("Tool choice was required but no tools were called.")
228 }
229 Ok(Vec::new())
230 }
231 }
232}
233
234fn could_be_json<T>(text_prefix: &str) -> (bool, bool)
238where
239 T: serde::de::DeserializeOwned,
240{
241 if text_prefix.is_empty() {
242 return (false, false);
243 }
244 match serde_json::from_str::<T>(text_prefix) {
245 Ok(_) => (false, true),
246 Err(e) if e.is_eof() => (true, false),
248 _ => (false, false),
249 }
250}
251
252pub fn parse_text_tools<'a>(
254 pipeline: &dyn Pipeline,
255 raw_text: &'a str,
256 matcher: Option<Arc<ToolCallingMatcher>>,
257) -> anyhow::Result<(Option<&'a str>, Vec<ToolCallResponse>)> {
258 let mut tool_calls = Vec::new();
259 let mut text_new = Some(raw_text);
260
261 if let Some(ref matcher) = matcher {
262 let calls = matcher
263 .get_call(pipeline, raw_text)
264 .map_err(candle_core::Error::msg)?;
265 if !calls.is_empty() {
266 text_new = None;
267 tool_calls = calls;
268 }
269 };
270 Ok((text_new, tool_calls))
271}