mistralrs/agent.rs
1//! Agentic loop implementation for mistral.rs
2//!
3//! This module provides an `Agent` that runs an agentic loop with tool calling.
4//! The agent takes a model, registers tools, and automatically handles the
5//! tool calling loop until the model produces a final response.
6//!
7//! # Features
8//!
9//! - **Async tools**: Native support for async tool functions
10//! - **Parallel execution**: Execute multiple tool calls concurrently
11//! - **Streaming**: Stream assistant responses and tool execution events
12//!
13//! # Example
14//!
15//! ```ignore
//! use mistralrs::{tool, AgentBuilder, AgentEvent, TextModelBuilder};
17//!
18//! // Async tool - runs natively async
19//! #[tool(description = "Fetch a URL")]
20//! async fn fetch_url(url: String) -> Result<String> {
21//! reqwest::get(&url).await?.text().await.map_err(Into::into)
22//! }
23//!
24//! // Sync tool
25//! #[tool(description = "Get weather")]
26//! fn get_weather(city: String) -> Result<WeatherInfo> {
27//! Ok(WeatherInfo { temperature: 22.5 })
28//! }
29//!
30//! #[tokio::main]
31//! async fn main() -> anyhow::Result<()> {
32//! let model = TextModelBuilder::new("model-id").build().await?;
33//!
34//! let agent = AgentBuilder::new(model)
35//! .with_system_prompt("You are helpful.")
36//! .register_tool(fetch_url_tool_with_callback())
37//! .register_tool(get_weather_tool_with_callback())
38//! .build();
39//!
40//! // Streaming execution
41//! let mut stream = agent.run_stream("What's the weather?").await?;
42//! while let Some(event) = stream.next().await {
43//! match event {
44//! AgentEvent::TextDelta(text) => print!("{}", text),
45//! AgentEvent::Complete(response) => println!("\nDone!"),
46//! _ => {}
47//! }
48//! }
49//! Ok(())
50//! }
51//! ```
52
53use std::collections::HashMap;
54use std::future::Future;
55use std::pin::Pin;
56use std::sync::Arc;
57
58use crate::{
59 CalledFunction, ChatCompletionChunkResponse, ChatCompletionResponse, ChunkChoice, Delta, Model,
60 RequestBuilder, Response, TextMessageRole, Tool, ToolCallResponse, ToolCallback, ToolChoice,
61};
62
63/// Async tool callback type for native async tool support
64pub type AsyncToolCallback = dyn Fn(CalledFunction) -> Pin<Box<dyn Future<Output = anyhow::Result<String>> + Send>>
65 + Send
66 + Sync;
67
68/// Unified tool callback that can be sync or async
69#[derive(Clone)]
70pub enum ToolCallbackType {
71 /// Synchronous callback (runs in spawn_blocking for parallel execution)
72 Sync(Arc<ToolCallback>),
73 /// Asynchronous callback (runs natively async)
74 Async(Arc<AsyncToolCallback>),
75}
76
77/// Configuration for the agentic loop
78#[derive(Clone, Debug)]
79pub struct AgentConfig {
80 /// Maximum number of iterations before stopping (default: 10)
81 pub max_iterations: usize,
82 /// Tool choice strategy (default: Auto)
83 pub tool_choice: ToolChoice,
84 /// Optional system prompt for the agent
85 pub system_prompt: Option<String>,
86 /// Whether to execute multiple tool calls in parallel (default: true)
87 pub parallel_tool_execution: bool,
88}
89
90impl Default for AgentConfig {
91 fn default() -> Self {
92 Self {
93 max_iterations: 10,
94 tool_choice: ToolChoice::Auto,
95 system_prompt: None,
96 parallel_tool_execution: true,
97 }
98 }
99}
100
101/// Represents a single step in the agent execution
102#[derive(Debug, Clone)]
103pub struct AgentStep {
104 /// The model's response for this step
105 pub response: ChatCompletionResponse,
106 /// Tool calls made in this step (if any)
107 pub tool_calls: Vec<ToolCallResponse>,
108 /// Results from tool executions
109 pub tool_results: Vec<ToolResult>,
110}
111
/// Outcome of executing a single tool call.
///
/// Implements `PartialEq`/`Eq` so results can be compared by value.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ToolResult {
    /// ID of the tool call this result answers
    pub tool_call_id: String,
    /// Name of the tool that was invoked
    pub tool_name: String,
    /// `Ok(output)` on success, `Err(error_message)` on failure
    pub result: Result<String, String>,
}
122
123/// Final response from the agent
124#[derive(Debug, Clone)]
125pub struct AgentResponse {
126 /// All steps taken during execution
127 pub steps: Vec<AgentStep>,
128 /// Final text response (if any)
129 pub final_response: Option<String>,
130 /// Total number of iterations performed
131 pub iterations: usize,
132 /// Why the agent stopped
133 pub stop_reason: AgentStopReason,
134}
135
/// Why an agent run ended.
///
/// Implements `Eq` in addition to `PartialEq`: every variant payload is `Eq`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum AgentStopReason {
    /// The model answered with text and requested no tool calls
    TextResponse,
    /// The configured iteration limit was reached
    MaxIterations,
    /// The model produced neither tool calls nor text
    NoAction,
    /// Execution failed; the payload is the error message
    Error(String),
}
148
149/// Events yielded during agent streaming
150#[derive(Debug, Clone)]
151pub enum AgentEvent {
152 /// Text content delta from the model
153 TextDelta(String),
154 /// Model is calling tools (with the tool calls)
155 ToolCallsStart(Vec<ToolCallResponse>),
156 /// A single tool completed execution
157 ToolResult(ToolResult),
158 /// All tools completed, continuing to next iteration
159 ToolCallsComplete,
160 /// Agent finished with final response
161 Complete(AgentResponse),
162}
163
164/// Internal state for the agent stream
165enum AgentStreamState {
166 /// Currently streaming model response
167 Streaming {
168 messages: RequestBuilder,
169 iteration: usize,
170 accumulated_content: String,
171 accumulated_tool_calls: Vec<ToolCallResponse>,
172 steps: Vec<AgentStep>,
173 },
174 /// Executing tool calls
175 ExecutingTools {
176 messages: RequestBuilder,
177 iteration: usize,
178 response: ChatCompletionResponse,
179 tool_calls: Vec<ToolCallResponse>,
180 tool_results: Vec<ToolResult>,
181 pending_indices: Vec<usize>,
182 steps: Vec<AgentStep>,
183 },
184 /// Agent has completed
185 Done,
186}
187
188/// Stream of agent events during execution
189pub struct AgentStream<'a> {
190 agent: &'a Agent,
191 state: AgentStreamState,
192 model_stream: Option<crate::model::Stream<'a>>,
193}
194
195impl<'a> AgentStream<'a> {
196 /// Get the next event from the agent stream
197 pub async fn next(&mut self) -> Option<AgentEvent> {
198 loop {
199 match &mut self.state {
200 AgentStreamState::Done => return None,
201
202 AgentStreamState::Streaming {
203 messages,
204 iteration,
205 accumulated_content,
206 accumulated_tool_calls,
207 steps,
208 } => {
209 // Get next chunk from model stream
210 if let Some(ref mut stream) = self.model_stream {
211 if let Some(response) = stream.next().await {
212 match response {
213 Response::Chunk(ChatCompletionChunkResponse {
214 choices, ..
215 }) => {
216 if let Some(ChunkChoice {
217 delta:
218 Delta {
219 content,
220 tool_calls,
221 ..
222 },
223 finish_reason,
224 ..
225 }) = choices.first()
226 {
227 // Accumulate content
228 if let Some(text) = content {
229 accumulated_content.push_str(text);
230 return Some(AgentEvent::TextDelta(text.clone()));
231 }
232
233 // Accumulate tool calls
234 if let Some(calls) = tool_calls {
235 accumulated_tool_calls.extend(calls.clone());
236 }
237
238 // Check if done
239 if finish_reason.is_some() {
240 self.model_stream = None;
241
242 if accumulated_tool_calls.is_empty() {
243 // No tool calls - we're done
244 let final_response =
245 if accumulated_content.is_empty() {
246 None
247 } else {
248 Some(accumulated_content.clone())
249 };
250
251 let stop_reason = if final_response.is_some() {
252 AgentStopReason::TextResponse
253 } else {
254 AgentStopReason::NoAction
255 };
256
257 let response = AgentResponse {
258 steps: steps.clone(),
259 final_response,
260 iterations: *iteration + 1,
261 stop_reason,
262 };
263
264 self.state = AgentStreamState::Done;
265 return Some(AgentEvent::Complete(response));
266 } else {
267 // Transition to executing tools
268 let tool_calls = accumulated_tool_calls.clone();
269 let event =
270 AgentEvent::ToolCallsStart(tool_calls.clone());
271
272 // Create a placeholder response for the step
273 let placeholder_response = ChatCompletionResponse {
274 id: String::new(),
275 choices: vec![],
276 created: 0,
277 model: String::new(),
278 system_fingerprint: String::new(),
279 object: String::new(),
280 usage: crate::Usage {
281 completion_tokens: 0,
282 prompt_tokens: 0,
283 total_tokens: 0,
284 avg_tok_per_sec: 0.0,
285 avg_prompt_tok_per_sec: 0.0,
286 avg_compl_tok_per_sec: 0.0,
287 total_time_sec: 0.0,
288 total_prompt_time_sec: 0.0,
289 total_completion_time_sec: 0.0,
290 },
291 };
292
293 self.state = AgentStreamState::ExecutingTools {
294 messages: messages.clone(),
295 iteration: *iteration,
296 response: placeholder_response,
297 tool_calls: tool_calls.clone(),
298 tool_results: Vec::new(),
299 pending_indices: (0..tool_calls.len())
300 .collect(),
301 steps: steps.clone(),
302 };
303
304 return Some(event);
305 }
306 }
307 }
308 }
309 Response::Done(response) => {
310 self.model_stream = None;
311 let tool_calls = response
312 .choices
313 .first()
314 .and_then(|c| c.message.tool_calls.clone())
315 .unwrap_or_default();
316
317 if tool_calls.is_empty() {
318 let final_response = response
319 .choices
320 .first()
321 .and_then(|c| c.message.content.clone());
322 let stop_reason = if final_response.is_some() {
323 AgentStopReason::TextResponse
324 } else {
325 AgentStopReason::NoAction
326 };
327
328 let agent_response = AgentResponse {
329 steps: steps.clone(),
330 final_response,
331 iterations: *iteration + 1,
332 stop_reason,
333 };
334
335 self.state = AgentStreamState::Done;
336 return Some(AgentEvent::Complete(agent_response));
337 } else {
338 let event = AgentEvent::ToolCallsStart(tool_calls.clone());
339
340 self.state = AgentStreamState::ExecutingTools {
341 messages: messages.clone(),
342 iteration: *iteration,
343 response: response.clone(),
344 tool_calls: tool_calls.clone(),
345 tool_results: Vec::new(),
346 pending_indices: (0..tool_calls.len()).collect(),
347 steps: steps.clone(),
348 };
349
350 return Some(event);
351 }
352 }
353 _ => continue,
354 }
355 }
356 }
357
358 // Stream ended unexpectedly
359 self.state = AgentStreamState::Done;
360 return None;
361 }
362
363 AgentStreamState::ExecutingTools {
364 messages,
365 iteration,
366 response,
367 tool_calls,
368 tool_results,
369 pending_indices,
370 steps,
371 } => {
372 if pending_indices.is_empty() {
373 // All tools executed - prepare for next iteration
374 let mut new_messages = messages.clone();
375
376 // Add assistant message with tool calls
377 new_messages = new_messages.add_message_with_tool_call(
378 TextMessageRole::Assistant,
379 response
380 .choices
381 .first()
382 .and_then(|c| c.message.content.clone())
383 .unwrap_or_default(),
384 tool_calls.clone(),
385 );
386
387 // Add tool results
388 for result in tool_results.iter() {
389 let result_str = match &result.result {
390 Ok(s) => s.clone(),
391 Err(e) => format!("Error: {}", e),
392 };
393 new_messages =
394 new_messages.add_tool_message(&result_str, &result.tool_call_id);
395 }
396
397 // Record step
398 let step = AgentStep {
399 response: response.clone(),
400 tool_calls: tool_calls.clone(),
401 tool_results: tool_results.clone(),
402 };
403 let mut new_steps = steps.clone();
404 new_steps.push(step);
405
406 let new_iteration = *iteration + 1;
407
408 // Check max iterations
409 if new_iteration >= self.agent.config.max_iterations {
410 let agent_response = AgentResponse {
411 steps: new_steps,
412 final_response: None,
413 iterations: new_iteration,
414 stop_reason: AgentStopReason::MaxIterations,
415 };
416 self.state = AgentStreamState::Done;
417 return Some(AgentEvent::Complete(agent_response));
418 }
419
420 // Start new model request
421 let request = new_messages
422 .clone()
423 .set_tools(self.agent.tools.clone())
424 .set_tool_choice(self.agent.config.tool_choice.clone());
425
426 match self.agent.model.stream_chat_request(request).await {
427 Ok(stream) => {
428 self.model_stream = Some(stream);
429 self.state = AgentStreamState::Streaming {
430 messages: new_messages,
431 iteration: new_iteration,
432 accumulated_content: String::new(),
433 accumulated_tool_calls: Vec::new(),
434 steps: new_steps,
435 };
436 return Some(AgentEvent::ToolCallsComplete);
437 }
438 Err(e) => {
439 let agent_response = AgentResponse {
440 steps: new_steps,
441 final_response: None,
442 iterations: new_iteration,
443 stop_reason: AgentStopReason::Error(e.to_string()),
444 };
445 self.state = AgentStreamState::Done;
446 return Some(AgentEvent::Complete(agent_response));
447 }
448 }
449 }
450
451 // Execute next pending tool
452 let idx = pending_indices.remove(0);
453 let tool_call = &tool_calls[idx];
454 let result = self.agent.execute_tool_async(tool_call).await;
455 let event = AgentEvent::ToolResult(result.clone());
456 tool_results.push(result);
457 return Some(event);
458 }
459 }
460 }
461 }
462}
463
464/// An agent that runs an agentic loop with tool calling
465pub struct Agent {
466 model: Model,
467 tools: Vec<Tool>,
468 callbacks: HashMap<String, ToolCallbackType>,
469 config: AgentConfig,
470}
471
472impl Agent {
473 /// Create a new agent with the given model and configuration
474 pub fn new(model: Model, config: AgentConfig) -> Self {
475 Self {
476 model,
477 tools: Vec::new(),
478 callbacks: HashMap::new(),
479 config,
480 }
481 }
482
483 /// Add a tool with its callback
484 pub fn with_tool(mut self, tool: Tool, callback: ToolCallbackType) -> Self {
485 let name = tool.function.name.clone();
486 self.tools.push(tool);
487 self.callbacks.insert(name, callback);
488 self
489 }
490
491 /// Run the agentic loop with the given user message
492 ///
493 /// This method will:
494 /// 1. Send the user message to the model
495 /// 2. If the model returns tool calls, execute them and send results back
496 /// 3. Repeat until the model returns a text response or max iterations is reached
497 pub async fn run(&self, user_message: impl ToString) -> anyhow::Result<AgentResponse> {
498 let mut steps = Vec::new();
499 let mut messages = RequestBuilder::new();
500
501 // Add system prompt if configured
502 if let Some(ref system) = self.config.system_prompt {
503 messages = messages.add_message(TextMessageRole::System, system);
504 }
505
506 // Add initial user message
507 messages = messages.add_message(TextMessageRole::User, user_message.to_string());
508
509 for iteration in 0..self.config.max_iterations {
510 // Configure tools for this request
511 let request = messages
512 .clone()
513 .set_tools(self.tools.clone())
514 .set_tool_choice(self.config.tool_choice.clone());
515
516 // Send request to model
517 let response = self.model.send_chat_request(request).await?;
518
519 let choice = response
520 .choices
521 .first()
522 .ok_or_else(|| anyhow::anyhow!("No choices in response"))?;
523
524 // Check for tool calls
525 let tool_calls = choice.message.tool_calls.clone().unwrap_or_default();
526
527 if tool_calls.is_empty() {
528 // No tool calls - we're done
529 let final_text = choice.message.content.clone();
530 steps.push(AgentStep {
531 response: response.clone(),
532 tool_calls: vec![],
533 tool_results: vec![],
534 });
535
536 let stop_reason = if final_text.is_some() {
537 AgentStopReason::TextResponse
538 } else {
539 AgentStopReason::NoAction
540 };
541
542 return Ok(AgentResponse {
543 steps,
544 final_response: final_text,
545 iterations: iteration + 1,
546 stop_reason,
547 });
548 }
549
550 // Execute tool calls (parallel or sequential based on config)
551 let tool_results = if self.config.parallel_tool_execution {
552 self.execute_tools_parallel(&tool_calls).await
553 } else {
554 let mut results = Vec::new();
555 for tool_call in &tool_calls {
556 results.push(self.execute_tool_async(tool_call).await);
557 }
558 results
559 };
560
561 // Add assistant message with tool calls
562 messages = messages.add_message_with_tool_call(
563 TextMessageRole::Assistant,
564 choice.message.content.clone().unwrap_or_default(),
565 tool_calls.clone(),
566 );
567
568 // Add tool results to messages
569 for result in &tool_results {
570 let result_str = match &result.result {
571 Ok(s) => s.clone(),
572 Err(e) => format!("Error: {}", e),
573 };
574 messages = messages.add_tool_message(&result_str, &result.tool_call_id);
575 }
576
577 steps.push(AgentStep {
578 response: response.clone(),
579 tool_calls: tool_calls.clone(),
580 tool_results,
581 });
582 }
583
584 // Max iterations reached
585 Ok(AgentResponse {
586 steps,
587 final_response: None,
588 iterations: self.config.max_iterations,
589 stop_reason: AgentStopReason::MaxIterations,
590 })
591 }
592
593 /// Run the agent with streaming output
594 ///
595 /// Returns a stream of `AgentEvent` that can be used to observe
596 /// the agent's progress in real-time.
597 pub async fn run_stream(&self, user_message: impl ToString) -> anyhow::Result<AgentStream<'_>> {
598 let mut messages = RequestBuilder::new();
599
600 // Add system prompt if configured
601 if let Some(ref system) = self.config.system_prompt {
602 messages = messages.add_message(TextMessageRole::System, system);
603 }
604
605 // Add initial user message
606 messages = messages.add_message(TextMessageRole::User, user_message.to_string());
607
608 // Configure tools for this request
609 let request = messages
610 .clone()
611 .set_tools(self.tools.clone())
612 .set_tool_choice(self.config.tool_choice.clone());
613
614 // Start streaming
615 let stream = self.model.stream_chat_request(request).await?;
616
617 Ok(AgentStream {
618 agent: self,
619 state: AgentStreamState::Streaming {
620 messages,
621 iteration: 0,
622 accumulated_content: String::new(),
623 accumulated_tool_calls: Vec::new(),
624 steps: Vec::new(),
625 },
626 model_stream: Some(stream),
627 })
628 }
629
630 /// Execute multiple tool calls in parallel
631 async fn execute_tools_parallel(&self, tool_calls: &[ToolCallResponse]) -> Vec<ToolResult> {
632 let futures: Vec<_> = tool_calls
633 .iter()
634 .map(|tc| self.execute_tool_async(tc))
635 .collect();
636
637 futures::future::join_all(futures).await
638 }
639
640 /// Execute a single tool call (async-compatible)
641 async fn execute_tool_async(&self, tool_call: &ToolCallResponse) -> ToolResult {
642 let tool_name = &tool_call.function.name;
643
644 let result = match self.callbacks.get(tool_name) {
645 Some(ToolCallbackType::Sync(callback)) => {
646 // Run sync callback in spawn_blocking to not block async runtime
647 let callback = Arc::clone(callback);
648 let function = tool_call.function.clone();
649 tokio::task::spawn_blocking(move || callback(&function))
650 .await
651 .map_err(|e| anyhow::anyhow!("Task join error: {}", e))
652 .and_then(|r| r)
653 .map_err(|e| e.to_string())
654 }
655 Some(ToolCallbackType::Async(callback)) => {
656 let function = tool_call.function.clone();
657 callback(function).await.map_err(|e| e.to_string())
658 }
659 None => Err(format!("Unknown tool: {}", tool_name)),
660 };
661
662 ToolResult {
663 tool_call_id: tool_call.id.clone(),
664 tool_name: tool_name.clone(),
665 result,
666 }
667 }
668
669 /// Get a reference to the underlying model
670 pub fn model(&self) -> &Model {
671 &self.model
672 }
673
674 /// Get a reference to the registered tools
675 pub fn tools(&self) -> &[Tool] {
676 &self.tools
677 }
678
679 /// Get a reference to the agent configuration
680 pub fn config(&self) -> &AgentConfig {
681 &self.config
682 }
683}
684
685/// Builder for creating agents with a fluent API
686pub struct AgentBuilder {
687 model: Model,
688 tools: Vec<Tool>,
689 callbacks: HashMap<String, ToolCallbackType>,
690 config: AgentConfig,
691}
692
693impl AgentBuilder {
694 /// Create a new agent builder with the given model
695 pub fn new(model: Model) -> Self {
696 Self {
697 model,
698 tools: Vec::new(),
699 callbacks: HashMap::new(),
700 config: AgentConfig::default(),
701 }
702 }
703
704 /// Set the maximum number of iterations
705 pub fn with_max_iterations(mut self, max: usize) -> Self {
706 self.config.max_iterations = max;
707 self
708 }
709
710 /// Set the system prompt for the agent
711 pub fn with_system_prompt(mut self, prompt: impl ToString) -> Self {
712 self.config.system_prompt = Some(prompt.to_string());
713 self
714 }
715
716 /// Set the tool choice strategy
717 pub fn with_tool_choice(mut self, choice: ToolChoice) -> Self {
718 self.config.tool_choice = choice;
719 self
720 }
721
722 /// Enable or disable parallel tool execution (default: true)
723 pub fn with_parallel_tool_execution(mut self, enabled: bool) -> Self {
724 self.config.parallel_tool_execution = enabled;
725 self
726 }
727
728 /// Add a sync tool with its callback
729 pub fn with_sync_tool(mut self, tool: Tool, callback: Arc<ToolCallback>) -> Self {
730 let name = tool.function.name.clone();
731 self.tools.push(tool);
732 self.callbacks
733 .insert(name, ToolCallbackType::Sync(callback));
734 self
735 }
736
737 /// Add an async tool with its callback
738 pub fn with_async_tool(mut self, tool: Tool, callback: Arc<AsyncToolCallback>) -> Self {
739 let name = tool.function.name.clone();
740 self.tools.push(tool);
741 self.callbacks
742 .insert(name, ToolCallbackType::Async(callback));
743 self
744 }
745
746 /// Register a tool using a tuple of (Tool, ToolCallbackType)
747 ///
748 /// This is designed to work with the `_tool_with_callback()` functions
749 /// generated by the `#[tool]` macro.
750 pub fn register_tool(mut self, (tool, callback): (Tool, ToolCallbackType)) -> Self {
751 let name = tool.function.name.clone();
752 self.tools.push(tool);
753 self.callbacks.insert(name, callback);
754 self
755 }
756
757 /// Build the agent
758 pub fn build(self) -> Agent {
759 Agent {
760 model: self.model,
761 tools: self.tools,
762 callbacks: self.callbacks,
763 config: self.config,
764 }
765 }
766}
767
#[cfg(test)]
mod tests {
    use super::*;

    /// Every default documented on `AgentConfig` is pinned, including the tool choice.
    #[test]
    fn test_agent_config_default() {
        let config = AgentConfig::default();
        assert_eq!(config.max_iterations, 10);
        assert!(matches!(config.tool_choice, ToolChoice::Auto));
        assert!(config.system_prompt.is_none());
        assert!(config.parallel_tool_execution);
    }

    /// Stop reasons compare by variant and, for `Error`, by message.
    #[test]
    fn test_agent_stop_reason_equality() {
        assert_eq!(AgentStopReason::TextResponse, AgentStopReason::TextResponse);
        assert_eq!(
            AgentStopReason::MaxIterations,
            AgentStopReason::MaxIterations
        );
        assert_ne!(
            AgentStopReason::TextResponse,
            AgentStopReason::MaxIterations
        );
        assert_eq!(
            AgentStopReason::Error("boom".to_string()),
            AgentStopReason::Error("boom".to_string())
        );
        assert_ne!(
            AgentStopReason::Error("boom".to_string()),
            AgentStopReason::Error("other".to_string())
        );
        assert_ne!(
            AgentStopReason::Error(String::new()),
            AgentStopReason::NoAction
        );
    }

    /// Cloning keeps the variant and shares the callback instead of copying it.
    #[test]
    fn test_tool_callback_type_clone() {
        let sync_cb: Arc<ToolCallback> = Arc::new(|_| Ok("test".to_string()));
        let original = ToolCallbackType::Sync(Arc::clone(&sync_cb));
        let cloned = original.clone();

        assert!(matches!(cloned, ToolCallbackType::Sync(_)));
        // `sync_cb`, `original` and `cloned` all point at the same allocation.
        assert_eq!(Arc::strong_count(&sync_cb), 3);
    }
}