mistralrs_core/tools/
mod.rs1mod request;
2mod response;
3
4use candle_core::Result;
5use regex::Regex;
6pub use request::*;
7pub use response::*;
8use serde_json::Value;
9use std::{
10 collections::HashMap,
11 sync::{Arc, OnceLock},
12};
13use uuid::Uuid;
14
15use crate::Pipeline;
16
17fn process_model_specific_message(message: &str) -> Result<String> {
18 static DEEPSEEK_REGEX: OnceLock<Regex> = OnceLock::new();
19 let deepseek_regex = DEEPSEEK_REGEX.get_or_init(|| Regex::new(
20 r"<|tool▁call▁begin|>function<|tool▁sep|>(?P<name>[^\n]+)\n```json\n(?P<json>.+?)\n```<|tool▁call▁end|>",
21 ).unwrap());
22
23 if let Some(message) = message.strip_prefix("<|python_tag|>") {
24 Ok(message.to_string())
26 } else if let Some(message) = message
27 .strip_prefix("<tool_call>")
28 .and_then(|s| s.strip_suffix("</tool_call>"))
29 {
30 Ok(message.to_string())
32 } else if let Some(message) = message
33 .strip_prefix("[TOOL_CALLS][")
34 .and_then(|s| s.strip_suffix("]"))
35 {
36 Ok(message.to_string())
38 } else if deepseek_regex.find(message).is_some() {
39 #[derive(Clone, Debug, serde::Serialize, serde::Deserialize)]
40 struct ToolCall {
41 name: String,
42 arguments: Value,
43 }
44 let mut calls = Vec::new();
45 for caps in deepseek_regex.captures_iter(message) {
46 let name = caps
47 .name("name")
48 .ok_or("Could not capture function name")
49 .map_err(candle_core::Error::msg)?
50 .as_str()
51 .trim()
52 .to_string();
53 let json_str = caps
54 .name("json")
55 .ok_or("Could not capture JSON arguments")
56 .map_err(candle_core::Error::msg)?
57 .as_str()
58 .trim();
59 let arguments: Value =
60 serde_json::from_str(json_str).map_err(candle_core::Error::msg)?;
61 calls.push(ToolCall { name, arguments });
62 }
63 Ok(serde_json::to_string(&calls).map_err(candle_core::Error::msg)?)
64 } else {
65 Ok(message.to_string())
66 }
67}
68
69pub struct ToolCallingMatcher {
70 tool_choice: ToolChoice,
71}
72
73#[derive(Clone, Debug, serde::Serialize, serde::Deserialize)]
75pub struct CalledFunctionParameters {
76 #[serde(alias = "function")]
77 pub name: String,
78 #[serde(alias = "arguments")]
79 pub parameters: HashMap<String, Value>,
80}
81
82impl ToolCallingMatcher {
83 pub fn new(tool_choice: ToolChoice) -> anyhow::Result<Self> {
84 Ok(Self { tool_choice })
85 }
86
87 pub fn prefix_could_be_tool(
94 &self,
95 _pipeline: &dyn Pipeline,
96 message_prefix: &str,
97 ) -> Result<(bool, bool)> {
98 if matches!(self.tool_choice, ToolChoice::None) {
99 return Ok((false, false));
100 }
101 let message_prefix = process_model_specific_message(message_prefix)?;
102
103 Ok([
105 could_be_json::<CalledFunctionParameters>,
106 could_be_json::<Vec<CalledFunctionParameters>>,
107 ]
108 .iter()
109 .find_map(|check| {
110 let (could_be_tool, is_complete_tool) = check(&message_prefix);
111 if could_be_tool || is_complete_tool {
112 Some((could_be_tool, is_complete_tool))
113 } else {
114 None
115 }
116 })
117 .unwrap_or_default())
118 }
119
120 pub fn get_call(
121 &self,
122 _pipeline: &dyn Pipeline,
123 message: &str,
124 ) -> anyhow::Result<Vec<ToolCallResponse>> {
125 if matches!(self.tool_choice, ToolChoice::None) {
126 return Ok(Vec::new());
127 }
128 let message = process_model_specific_message(message)?;
129
130 if let Ok(deser) = serde_json::from_str::<CalledFunctionParameters>(&message) {
131 let id = format!("call-{}", Uuid::new_v4());
132 Ok(vec![ToolCallResponse {
133 id,
134 tp: ToolCallType::Function,
135 function: CalledFunction {
136 name: deser.name,
137 arguments: serde_json::to_string(&deser.parameters)?,
138 },
139 }])
140 } else if let Ok(deser) = serde_json::from_str::<Vec<CalledFunctionParameters>>(&message) {
141 Ok(deser
142 .into_iter()
143 .map(|deser| {
144 let id = format!("call-{}", Uuid::new_v4());
145 Ok(ToolCallResponse {
146 id,
147 tp: ToolCallType::Function,
148 function: CalledFunction {
149 name: deser.name,
150 arguments: serde_json::to_string(&deser.parameters)?,
151 },
152 })
153 })
154 .collect::<anyhow::Result<Vec<_>>>()?)
155 } else {
156 if matches!(self.tool_choice, ToolChoice::Tool(_)) {
157 anyhow::bail!("Tool choice was required but no tools were called.")
158 }
159 Ok(Vec::new())
160 }
161 }
162}
163
164fn could_be_json<T>(text_prefix: &str) -> (bool, bool)
168where
169 T: serde::de::DeserializeOwned,
170{
171 if text_prefix.is_empty() {
172 return (false, false);
173 }
174 match serde_json::from_str::<T>(text_prefix) {
175 Ok(_) => (false, true),
176 Err(e) if e.is_eof() => (true, false),
178 _ => (false, false),
179 }
180}
181
182pub fn parse_text_tools<'a>(
184 pipeline: &dyn Pipeline,
185 raw_text: &'a str,
186 matcher: Option<Arc<ToolCallingMatcher>>,
187) -> anyhow::Result<(Option<&'a str>, Vec<ToolCallResponse>)> {
188 let mut tool_calls = Vec::new();
189 let mut text_new = Some(raw_text);
190
191 if let Some(ref matcher) = matcher {
192 let calls = matcher
193 .get_call(pipeline, raw_text)
194 .map_err(candle_core::Error::msg)?;
195 if !calls.is_empty() {
196 text_new = None;
197 tool_calls = calls;
198 }
199 };
200 Ok((text_new, tool_calls))
201}