mistralrs_server_core/
lib.rs

//! > **mistral.rs server core**
//!
//! ## About
//!
//! This crate powers the mistral.rs server. It exposes the underlying
//! functionality so that others can implement and extend the server.
//!
//! ### Features
//! 1. Incorporate the mistral.rs server into another axum project.
//! 2. Hook into the mistral.rs server lifecycle.
//!
//! ### Example
//! ```no_run
//! use std::sync::Arc;
//!
//! use axum::{
//!     extract::State,
//!     routing::{get, post},
//!     Json, Router,
//! };
//! use utoipa::OpenApi;
//! use utoipa_swagger_ui::SwaggerUi;
//!
//! use mistralrs_core::{
//!     initialize_logging, AutoDeviceMapParams, ChatCompletionChunkResponse, ModelDType, ModelSelected,
//! };
//! use mistralrs_server_core::{
//!     chat_completion::{
//!         create_streamer, handle_error, parse_request, process_non_streaming_response,
//!         ChatCompletionOnChunkCallback, ChatCompletionOnDoneCallback, ChatCompletionResponder,
//!     },
//!     handler_core::{create_response_channel, send_request},
//!     mistralrs_for_server_builder::MistralRsForServerBuilder,
//!     mistralrs_server_router_builder::MistralRsServerRouterBuilder,
//!     openai::ChatCompletionRequest,
//!     openapi_doc::get_openapi_doc,
//!     types::SharedMistralRsState,
//! };
//!
//! #[derive(OpenApi)]
//! #[openapi(
//!     paths(root, custom_chat),
//!     tags(
//!         (name = "hello", description = "Hello world endpoints")
//!     ),
//!     info(
//!         title = "Hello World API",
//!         version = "1.0.0",
//!         description = "A simple API that responds with a greeting"
//!     )
//! )]
//! struct ApiDoc;
//!
//! #[derive(Clone)]
//! pub struct AppState {
//!     pub mistralrs_state: SharedMistralRsState,
//!     pub db_create: fn(),
//! }
//!
//! #[tokio::main]
//! async fn main() {
//!     initialize_logging();
//!
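//!     // Model configuration: the optional settings are left unset (None) here.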
//!     let plain_model_id = String::from("meta-llama/Llama-3.2-1B-Instruct");
//!     let tokenizer_json = None;
//!     let arch = None;
//!     let organization = None;
//!     let write_uqff = None;
//!     let from_uqff = None;
//!     let imatrix = None;
//!     let calibration_file = None;
//!     let hf_cache_path = None;
//!
//!     let dtype = ModelDType::Auto;
//!     let topology = None;
//!     let max_seq_len = AutoDeviceMapParams::DEFAULT_MAX_SEQ_LEN;
//!     let max_batch_size = AutoDeviceMapParams::DEFAULT_MAX_BATCH_SIZE;
//!     let matformer_config_path = None;
//!     let matformer_slice_name = None;
//!
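//!     // Select a plain (non-adapter) Hugging Face model by model ID.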
//!     let model = ModelSelected::Plain {
//!         model_id: plain_model_id,
//!         tokenizer_json,
//!         arch,
//!         dtype,
//!         topology,
//!         organization,
//!         write_uqff,
//!         from_uqff,
//!         imatrix,
//!         calibration_file,
//!         max_seq_len,
//!         max_batch_size,
//!         hf_cache_path,
//!         matformer_config_path,
//!         matformer_slice_name,
//!     };
//!
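//!     // Build the shared mistral.rs state with in-situ quantization ("8") and
//!     // paged attention enabled.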
//!     let shared_mistralrs = MistralRsForServerBuilder::new()
//!         .with_model(model)
//!         .with_in_situ_quant("8".to_string())
//!         .set_paged_attn(Some(true))
//!         .build()
//!         .await
//!         .unwrap();
//!
//!     let mistralrs_base_path = "/api/mistral";
//!
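//!     // Build the mistral.rs routes without their Swagger routes so they can be
//!     // nested under a custom base path.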
//!     let mistralrs_routes = MistralRsServerRouterBuilder::new()
//!         .with_mistralrs(shared_mistralrs.clone())
//!         .with_include_swagger_routes(false)
//!         .with_base_path(mistralrs_base_path)
//!         .build()
//!         .await
//!         .unwrap();
//!
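//!     // Merge the mistral.rs OpenAPI doc into this app's doc so both are served
//!     // from a single Swagger UI.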
//!     let mistralrs_doc = get_openapi_doc(Some(mistralrs_base_path));
//!     let mut api_docs = ApiDoc::openapi();
//!     api_docs.merge(mistralrs_doc);
//!
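//!     // App-wide state: the shared mistral.rs handle plus app-specific dependencies.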
//!     let app_state = Arc::new(AppState {
//!         mistralrs_state: shared_mistralrs,
//!         db_create: mock_db_call,
//!     });
//!
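//!     // Compose the final router: custom routes, nested mistral.rs routes, and Swagger UI.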
//!     let app = Router::new()
//!         .route("/", get(root))
//!         .route("/chat", post(custom_chat))
//!         .with_state(app_state.clone())
//!         .nest(mistralrs_base_path, mistralrs_routes)
//!         .merge(SwaggerUi::new("/api-docs").url("/api-docs/openapi.json", api_docs));
//!
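//!     // Log before serving: axum::serve runs until the process is stopped.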
//!     let listener = tokio::net::TcpListener::bind("0.0.0.0:3000").await.unwrap();
//!     println!("Listening on 0.0.0.0:3000");
//!
//!     axum::serve(listener, app).await.unwrap();
//! }
//!
//! #[utoipa::path(
//!     get,
//!     path = "/",
//!     tag = "hello",
//!     responses(
//!         (status = 200, description = "Successful response with greeting message", body = String)
//!     )
//! )]
//! async fn root() -> &'static str {
//!     "Hello, World!"
//! }
//!
//! #[utoipa::path(
//!     post,
//!     tag = "Custom",
//!     path = "/chat",
//!     request_body = ChatCompletionRequest,
//!     responses((status = 200, description = "Chat completions"))
//! )]
//! pub async fn custom_chat(
//!     State(state): State<Arc<AppState>>,
//!     Json(oai_request): Json<ChatCompletionRequest>,
//! ) -> ChatCompletionResponder {
//!     let mistralrs_state = state.mistralrs_state.clone();
//!     let (tx, mut rx) = create_response_channel(None);
//!
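//!     // parse_request converts the OpenAI-style request into an internal request
//!     // and reports whether the client asked for a streaming response.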
//!     let (request, is_streaming) =
//!         match parse_request(oai_request, mistralrs_state.clone(), tx).await {
//!             Ok(x) => x,
//!             Err(e) => return handle_error(mistralrs_state, e.into()),
//!         };
//!
//!     dbg!(&request);
//!
//!     if let Err(e) = send_request(&mistralrs_state, request).await {
//!         return handle_error(mistralrs_state, e.into());
//!     }
//!
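//!     // Streaming responses go out as SSE, with optional hooks that run on each
//!     // chunk and once the stream completes.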
//!     if is_streaming {
//!         let db_fn = state.db_create;
//!
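//!         // The on-chunk hook can inspect or rewrite each chunk before it is
//!         // sent to the client.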
//!         let on_chunk: ChatCompletionOnChunkCallback =
//!             Box::new(move |mut chunk: ChatCompletionChunkResponse| {
//!                 dbg!(&chunk);
//!
//!                 if let Some(original_content) = &chunk.choices[0].delta.content {
//!                     chunk.choices[0].delta.content = Some(format!("CHANGED! {}", original_content));
//!                 }
//!
//!                 chunk
//!             });
//!
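//!         // The on-done hook runs after the stream finishes and receives the
//!         // collected chunks.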
//!         let on_done: ChatCompletionOnDoneCallback =
//!             Box::new(move |chunks: &[ChatCompletionChunkResponse]| {
//!                 dbg!(chunks);
//!                 (db_fn)();
//!             });
//!
//!         let streamer = create_streamer(rx, mistralrs_state.clone(), Some(on_chunk), Some(on_done));
//!
//!         ChatCompletionResponder::Sse(streamer)
//!     } else {
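//!         // Non-streaming: wait for the complete response, then run any side effects.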
//!         let response = process_non_streaming_response(&mut rx, mistralrs_state.clone()).await;
//!
//!         match &response {
//!             ChatCompletionResponder::Json(json_response) => {
//!                 dbg!(json_response);
//!                 (state.db_create)();
//!             }
//!             _ => {}
//!         }
//!
//!         response
//!     }
//! }
//!
//! pub fn mock_db_call() {
//!     println!("Saving to DB");
//! }
//! ```

pub mod cached_responses;
pub mod chat_completion;
mod completion_core;
pub mod completions;
pub mod handler_core;
mod handlers;
pub mod image_generation;
pub mod mistralrs_for_server_builder;
pub mod mistralrs_server_router_builder;
pub mod openai;
pub mod openapi_doc;
pub mod responses;
pub mod speech_generation;
pub mod streaming;
pub mod types;
pub mod util;