// mistralrs_core/tools/response.rs

1#[cfg_attr(feature = "pyo3_macros", pyo3::pyclass(eq, eq_int))]
2#[cfg_attr(feature = "pyo3_macros", pyo3(get_all))]
3#[derive(Clone, Debug, serde::Serialize, PartialEq)]
4#[serde(rename_all = "snake_case")]
5pub enum ToolCallType {
6    Function,
7}
8
9impl std::fmt::Display for ToolCallType {
10    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
11        match self {
12            ToolCallType::Function => write!(f, "function"),
13        }
14    }
15}
16
17use mistralrs_mcp::CalledFunction;
18
19#[cfg_attr(feature = "pyo3_macros", pyo3::pyclass)]
20#[cfg_attr(feature = "pyo3_macros", pyo3(get_all))]
21#[derive(Clone, Debug, serde::Serialize)]
22pub struct ToolCallResponse {
23    pub index: usize,
24    pub id: String,
25    #[serde(rename = "type")]
26    pub tp: ToolCallType,
27    pub function: CalledFunction,
28}
29
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a representative response shared by the serialization tests.
    fn sample_response() -> ToolCallResponse {
        ToolCallResponse {
            index: 0,
            id: "call-1".to_string(),
            tp: ToolCallType::Function,
            function: CalledFunction {
                name: "foo".to_string(),
                arguments: "{}".to_string(),
            },
        }
    }

    #[test]
    fn serializes_index_field() {
        let json = serde_json::to_value(sample_response()).expect("response serializes");
        assert_eq!(json.get("index").and_then(|v| v.as_u64()), Some(0));
    }

    #[test]
    fn serializes_tp_under_type_key() {
        let json = serde_json::to_value(sample_response()).expect("response serializes");
        // `tp` must be renamed to `type`, with the variant rendered in snake_case.
        assert_eq!(json.get("type").and_then(|v| v.as_str()), Some("function"));
        assert!(json.get("tp").is_none());
        assert_eq!(json.get("id").and_then(|v| v.as_str()), Some("call-1"));
    }

    #[test]
    fn display_matches_serialized_type() {
        // `Display` is written by hand; guard against it drifting from serde's spelling.
        let json = serde_json::to_value(ToolCallType::Function).expect("type serializes");
        let displayed = ToolCallType::Function.to_string();
        assert_eq!(json.as_str(), Some(displayed.as_str()));
    }
}