mistralrs_macros/
lib.rs

1//! Proc macros for ergonomic tool definition in mistral.rs
2//!
3//! This crate provides the `#[tool]` attribute macro for defining tools
4//! that can be used with the mistral.rs agentic loop.
5//!
6//! # Example
7//!
8//! ```ignore
9//! use mistralrs_macros::tool;
10//! use schemars::JsonSchema;
11//! use serde::{Deserialize, Serialize};
12//!
13//! #[derive(Debug, Serialize, Deserialize, JsonSchema)]
14//! struct WeatherInfo {
15//!     temperature: f32,
16//!     conditions: String,
17//! }
18//!
19//! #[tool(description = "Get the current weather for a location")]
20//! fn get_weather(
21//!     #[description = "The city name"]
22//!     city: String,
23//! ) -> anyhow::Result<WeatherInfo> {
24//!     Ok(WeatherInfo {
25//!         temperature: 22.5,
26//!         conditions: "Sunny".to_string(),
27//!     })
28//! }
29//!
30//! // This generates:
31//! // - get_weather_tool() -> Tool
32//! // - get_weather_callback() -> Arc<ToolCallback>
//! // - get_weather_tool_with_callback() -> (Tool, ToolCallbackType)
34//! ```
35
36use darling::{ast::NestedMeta, FromMeta};
37use proc_macro::TokenStream;
38use proc_macro2::TokenStream as TokenStream2;
39use quote::{format_ident, quote};
40use syn::{parse_macro_input, Expr, FnArg, ItemFn, Lit, Meta, Pat, PatType, Type};
41
/// Arguments for the `#[tool]` attribute
///
/// Populated from the attribute's meta list through the derived `FromMeta`
/// implementation (see the `tool` entry point).
#[derive(Debug, FromMeta)]
struct ToolArgs {
    /// Description of what the tool does; emitted as the description of the
    /// generated tool's function.
    description: String,
    /// Optional override for the tool name (defaults to function name)
    #[darling(default)]
    name: Option<String>,
}
51
/// Arguments for parameter-level attributes
///
/// Built by `ParamArgs::from_attrs`; a field stays `None` when the
/// corresponding attribute is absent from the parameter.
#[derive(Debug, Default)]
struct ParamArgs {
    /// Text of `#[description = "..."]`, when written as a string literal.
    description: Option<String>,
    /// Right-hand side of `#[default = ...]`, kept as an unevaluated expression.
    default: Option<Expr>,
}
58
59impl ParamArgs {
60    fn from_attrs(attrs: &[syn::Attribute]) -> Self {
61        let mut args = ParamArgs::default();
62
63        for attr in attrs {
64            if attr.path().is_ident("description") {
65                if let Meta::NameValue(nv) = &attr.meta {
66                    if let Expr::Lit(expr_lit) = &nv.value {
67                        if let Lit::Str(lit_str) = &expr_lit.lit {
68                            args.description = Some(lit_str.value());
69                        }
70                    }
71                }
72            } else if attr.path().is_ident("default") {
73                if let Meta::NameValue(nv) = &attr.meta {
74                    args.default = Some(nv.value.clone());
75                }
76            }
77        }
78
79        args
80    }
81}
82
83/// The `#[tool]` attribute macro for defining tools.
84///
85/// This macro transforms a regular Rust function into a tool that can be
86/// used with the mistral.rs agentic loop. It generates:
87///
88/// - `{fn_name}_tool()` - Returns the `Tool` definition
89/// - `{fn_name}_callback()` - Returns an `Arc<ToolCallback>` that wraps the function
90/// - `{fn_name}_tool_with_callback()` - Returns both as a tuple
91///
92/// # Attributes
93///
94/// - `description` (required): A description of what the tool does
95/// - `name` (optional): Override the tool name (defaults to function name)
96///
97/// # Parameter Attributes
98///
99/// - `#[description = "..."]`: Description of the parameter
100/// - `#[default = value]`: Default value if parameter is optional
101///
102/// # Requirements
103///
104/// - All parameter types must implement `serde::Deserialize` and `schemars::JsonSchema`
105/// - The return type must be `Result<T>` or `anyhow::Result<T>` where `T: Serialize`
106/// - For async functions, ensure a tokio runtime is available
107#[proc_macro_attribute]
108pub fn tool(attr: TokenStream, item: TokenStream) -> TokenStream {
109    let attr_args = match NestedMeta::parse_meta_list(attr.into()) {
110        Ok(v) => v,
111        Err(e) => return TokenStream::from(e.into_compile_error()),
112    };
113
114    let tool_args = match ToolArgs::from_list(&attr_args) {
115        Ok(v) => v,
116        Err(e) => return TokenStream::from(e.write_errors()),
117    };
118
119    let input_fn = parse_macro_input!(item as ItemFn);
120
121    match generate_tool_impl(tool_args, input_fn) {
122        Ok(tokens) => tokens.into(),
123        Err(e) => e.into_compile_error().into(),
124    }
125}
126
127fn generate_tool_impl(args: ToolArgs, input_fn: ItemFn) -> syn::Result<TokenStream2> {
128    let fn_name = &input_fn.sig.ident;
129    let fn_vis = &input_fn.vis;
130    let is_async = input_fn.sig.asyncness.is_some();
131
132    let tool_name = args.name.unwrap_or_else(|| fn_name.to_string());
133    let description = &args.description;
134
135    // Generate helper function names
136    let tool_fn_name = format_ident!("{}_tool", fn_name);
137    let callback_fn_name = format_ident!("{}_callback", fn_name);
138    let combined_fn_name = format_ident!("{}_tool_with_callback", fn_name);
139    let args_struct_name = format_ident!("__{}Args", fn_name);
140
141    // Create a stripped version of the function without our custom attributes
142    let mut stripped_fn = input_fn.clone();
143    for arg in &mut stripped_fn.sig.inputs {
144        if let FnArg::Typed(pat_type) = arg {
145            // Remove #[description] and #[default] attributes
146            pat_type.attrs.retain(|attr| {
147                !attr.path().is_ident("description") && !attr.path().is_ident("default")
148            });
149        }
150    }
151
152    // Collect parameter information
153    let mut param_names = Vec::new();
154    let mut param_types = Vec::new();
155    let mut param_descriptions = Vec::new();
156    let mut param_defaults = Vec::new();
157    let mut required_params = Vec::new();
158
159    for arg in &input_fn.sig.inputs {
160        if let FnArg::Typed(PatType { pat, ty, attrs, .. }) = arg {
161            if let Pat::Ident(pat_ident) = pat.as_ref() {
162                let param_name = &pat_ident.ident;
163                let param_args = ParamArgs::from_attrs(attrs);
164
165                param_names.push(param_name.clone());
166                param_types.push(ty.as_ref().clone());
167                param_descriptions.push(param_args.description);
168                param_defaults.push(param_args.default);
169
170                // Check if the type is Option<T>
171                let is_optional = is_option_type(ty);
172                if !is_optional && param_defaults.last().unwrap().is_none() {
173                    required_params.push(param_name.to_string());
174                }
175            }
176        }
177    }
178
179    // Generate the Args struct fields with serde attributes
180    let args_struct_fields: Vec<TokenStream2> = param_names
181        .iter()
182        .zip(param_types.iter())
183        .zip(param_defaults.iter())
184        .map(|((name, ty), default)| {
185            if default.is_some() {
186                let default_fn_name_str = format!("__default_{}", name);
187                quote! {
188                    #[serde(default = #default_fn_name_str)]
189                    pub #name: #ty
190                }
191            } else {
192                quote! {
193                    pub #name: #ty
194                }
195            }
196        })
197        .collect();
198
199    // Generate default functions for parameters with defaults
200    let default_fns: Vec<TokenStream2> = param_names
201        .iter()
202        .zip(param_types.iter())
203        .zip(param_defaults.iter())
204        .filter_map(|((name, ty), default)| {
205            default.as_ref().map(|default_expr| {
206                let default_fn_name = format_ident!("__default_{}", name);
207                // If the type is Option<T>, wrap the default value in Some()
208                let value_expr = if is_option_type(ty) {
209                    quote! { Some(#default_expr.into()) }
210                } else {
211                    quote! { #default_expr }
212                };
213                quote! {
214                    fn #default_fn_name() -> #ty {
215                        #value_expr
216                    }
217                }
218            })
219        })
220        .collect();
221
222    // Generate property schema for each parameter
223    let property_schemas: Vec<TokenStream2> = param_names
224        .iter()
225        .zip(param_types.iter())
226        .zip(param_descriptions.iter())
227        .map(|((name, ty), desc)| {
228            let name_str = name.to_string();
229            // Extract inner type if Option<T>
230            let schema_type = extract_option_inner_type(ty).unwrap_or(ty);
231            let desc_insert = if let Some(d) = desc {
232                quote! {
233                    if let Some(obj) = prop_schema.as_object_mut() {
234                        obj.insert("description".to_string(), serde_json::json!(#d));
235                    }
236                }
237            } else {
238                quote! {}
239            };
240            quote! {
241                {
242                    let schema = schemars::schema_for!(#schema_type);
243                    let mut prop_schema = serde_json::to_value(&schema.schema).unwrap_or(serde_json::json!({}));
244                    #desc_insert
245                    properties.insert(#name_str.to_string(), prop_schema);
246                }
247            }
248        })
249        .collect();
250
251    // Generate required array
252    let required_array: Vec<TokenStream2> = required_params
253        .iter()
254        .map(|name| quote! { #name.to_string() })
255        .collect();
256
257    // Generate the function call
258    let call_args: Vec<TokenStream2> = param_names
259        .iter()
260        .map(|name| quote! { args.#name })
261        .collect();
262
263    // Build the output based on whether function is async or sync
264    let output = if is_async {
265        // Async function: generate AsyncToolCallback
266        quote! {
267            // Original function preserved (with custom attributes stripped)
268            #stripped_fn
269
270            // Default value functions (if any)
271            #(#default_fns)*
272
273            // Arguments struct for deserialization
274            #[derive(serde::Deserialize)]
275            #[allow(non_camel_case_types)]
276            struct #args_struct_name {
277                #(#args_struct_fields),*
278            }
279
280            /// Returns the Tool definition for this function
281            #fn_vis fn #tool_fn_name() -> mistralrs::Tool {
282                let mut properties = std::collections::HashMap::<String, serde_json::Value>::new();
283
284                #(#property_schemas)*
285
286                let required: Vec<String> = vec![#(#required_array),*];
287
288                let parameters: std::collections::HashMap<String, serde_json::Value> = serde_json::from_value(
289                    serde_json::json!({
290                        "type": "object",
291                        "properties": properties,
292                        "required": required,
293                    })
294                ).expect("Failed to create tool parameters");
295
296                mistralrs::Tool {
297                    tp: mistralrs::ToolType::Function,
298                    function: mistralrs::Function {
299                        description: Some(#description.to_string()),
300                        name: #tool_name.to_string(),
301                        parameters: Some(parameters),
302                    },
303                }
304            }
305
306            /// Returns an async callback that wraps this function for tool execution
307            #fn_vis fn #callback_fn_name() -> std::sync::Arc<mistralrs::AsyncToolCallback> {
308                std::sync::Arc::new(|called: mistralrs::CalledFunction| {
309                    Box::pin(async move {
310                        let args: #args_struct_name = serde_json::from_str(&called.arguments)
311                            .map_err(|e| anyhow::anyhow!("Failed to parse tool arguments: {}", e))?;
312
313                        let result = #fn_name(#(#call_args),*).await?;
314
315                        serde_json::to_string(&result)
316                            .map_err(|e| anyhow::anyhow!("Failed to serialize tool result: {}", e))
317                    })
318                })
319            }
320
321            /// Returns both the Tool definition and callback as a tuple
322            #fn_vis fn #combined_fn_name() -> (mistralrs::Tool, mistralrs::ToolCallbackType) {
323                (#tool_fn_name(), mistralrs::ToolCallbackType::Async(#callback_fn_name()))
324            }
325        }
326    } else {
327        // Sync function: generate ToolCallback
328        quote! {
329            // Original function preserved (with custom attributes stripped)
330            #stripped_fn
331
332            // Default value functions (if any)
333            #(#default_fns)*
334
335            // Arguments struct for deserialization
336            #[derive(serde::Deserialize)]
337            #[allow(non_camel_case_types)]
338            struct #args_struct_name {
339                #(#args_struct_fields),*
340            }
341
342            /// Returns the Tool definition for this function
343            #fn_vis fn #tool_fn_name() -> mistralrs::Tool {
344                let mut properties = std::collections::HashMap::<String, serde_json::Value>::new();
345
346                #(#property_schemas)*
347
348                let required: Vec<String> = vec![#(#required_array),*];
349
350                let parameters: std::collections::HashMap<String, serde_json::Value> = serde_json::from_value(
351                    serde_json::json!({
352                        "type": "object",
353                        "properties": properties,
354                        "required": required,
355                    })
356                ).expect("Failed to create tool parameters");
357
358                mistralrs::Tool {
359                    tp: mistralrs::ToolType::Function,
360                    function: mistralrs::Function {
361                        description: Some(#description.to_string()),
362                        name: #tool_name.to_string(),
363                        parameters: Some(parameters),
364                    },
365                }
366            }
367
368            /// Returns a sync callback that wraps this function for tool execution
369            #fn_vis fn #callback_fn_name() -> std::sync::Arc<mistralrs::ToolCallback> {
370                std::sync::Arc::new(|called: &mistralrs::CalledFunction| {
371                    let args: #args_struct_name = serde_json::from_str(&called.arguments)
372                        .map_err(|e| anyhow::anyhow!("Failed to parse tool arguments: {}", e))?;
373
374                    let result = #fn_name(#(#call_args),*)?;
375
376                    serde_json::to_string(&result)
377                        .map_err(|e| anyhow::anyhow!("Failed to serialize tool result: {}", e))
378                })
379            }
380
381            /// Returns both the Tool definition and callback as a tuple
382            #fn_vis fn #combined_fn_name() -> (mistralrs::Tool, mistralrs::ToolCallbackType) {
383                (#tool_fn_name(), mistralrs::ToolCallbackType::Sync(#callback_fn_name()))
384            }
385        }
386    };
387
388    Ok(output)
389}
390
391/// Check if a type is Option<T>
392fn is_option_type(ty: &Type) -> bool {
393    if let Type::Path(type_path) = ty {
394        if let Some(segment) = type_path.path.segments.last() {
395            return segment.ident == "Option";
396        }
397    }
398    false
399}
400
401/// Extract the inner type from Option<T>, returning None if not an Option
402fn extract_option_inner_type(ty: &Type) -> Option<&Type> {
403    if let Type::Path(type_path) = ty {
404        if let Some(segment) = type_path.path.segments.last() {
405            if segment.ident == "Option" {
406                if let syn::PathArguments::AngleBracketed(args) = &segment.arguments {
407                    if let Some(syn::GenericArgument::Type(inner)) = args.args.first() {
408                        return Some(inner);
409                    }
410                }
411            }
412        }
413    }
414    None
415}