//! Tool definitions and tool-callback types (source file: mistralrs_mcp/tools.rs).

use serde::{Deserialize, Serialize};
use serde_json::Value;
use std::collections::HashMap;
use std::fmt;
use std::sync::Arc;

6/// Callback used for custom tool functions. Receives the called function
7/// (name and JSON arguments) and returns the tool output as a string.
8pub type ToolCallback = dyn Fn(&CalledFunction) -> anyhow::Result<String> + Send + Sync;
9
10/// A tool callback with its associated Tool definition.
11#[derive(Clone)]
12pub struct ToolCallbackWithTool {
13    pub callback: Arc<ToolCallback>,
14    pub tool: Tool,
15}
16
17/// Collection of callbacks keyed by tool name.
18pub type ToolCallbacks = HashMap<String, Arc<ToolCallback>>;
19
20/// Collection of callbacks with their tool definitions keyed by tool name.
21pub type ToolCallbacksWithTools = HashMap<String, ToolCallbackWithTool>;
22
23/// Type of tool
24#[cfg_attr(feature = "utoipa", derive(utoipa::ToSchema))]
25#[derive(Clone, Debug, Deserialize, Serialize)]
26pub enum ToolType {
27    #[serde(rename = "function")]
28    Function,
29}
30
31/// Function definition for a tool
32#[cfg_attr(feature = "utoipa", derive(utoipa::ToSchema))]
33#[derive(Clone, Debug, Deserialize, Serialize)]
34pub struct Function {
35    pub description: Option<String>,
36    pub name: String,
37    #[serde(alias = "arguments")]
38    pub parameters: Option<HashMap<String, Value>>,
39}
40
41/// Tool definition
42#[cfg_attr(feature = "utoipa", derive(utoipa::ToSchema))]
43#[derive(Clone, Debug, Deserialize, Serialize)]
44pub struct Tool {
45    #[serde(rename = "type")]
46    pub tp: ToolType,
47    pub function: Function,
48}
49
50/// Called function with name and arguments
51#[cfg_attr(feature = "pyo3_macros", pyo3::pyclass)]
52#[cfg_attr(feature = "pyo3_macros", pyo3(get_all))]
53#[derive(Clone, Debug, Serialize, Deserialize)]
54pub struct CalledFunction {
55    pub name: String,
56    pub arguments: String,
57}