// mistralrs_server_core/lib.rs

1//! > **mistral.rs server core**
2//!
3//! ## About
4//!
5//! This crate powers mistral.rs server. It exposes the underlying functionality
6//! allowing others to implement and extend the server implementation.
7//!
8//! ### Features
//! 1. Incorporate the mistral.rs server into another axum project.
10//! 2. Hook into the mistral.rs server lifecycle.
11//!
12//! ### Example
13//! ```no_run
14//! use std::sync::Arc;
15//!
16//! use axum::{
17//!     extract::State,
18//!     routing::{get, post},
19//!     Json, Router,
20//! };
21//! use utoipa::OpenApi;
22//! use utoipa_swagger_ui::SwaggerUi;
23//!
24//! use mistralrs_core::{
25//!     initialize_logging, AutoDeviceMapParams, ChatCompletionChunkResponse, ModelDType, ModelSelected,
26//! };
27//! use mistralrs_server_core::{
28//!     chat_completion::{
29//!         create_streamer, handle_error, parse_request, process_non_streaming_response,
30//!         ChatCompletionOnChunkCallback, ChatCompletionOnDoneCallback, ChatCompletionResponder,
31//!     },
32//!     handler_core::{create_response_channel, send_request},
33//!     mistralrs_for_server_builder::MistralRsForServerBuilder,
34//!     mistralrs_server_router_builder::MistralRsServerRouterBuilder,
35//!     openai::ChatCompletionRequest,
36//!     openapi_doc::get_openapi_doc,
37//!     types::SharedMistralRsState,
38//! };
39//!
40//! #[derive(OpenApi)]
41//! #[openapi(
42//!     paths(root, custom_chat),
43//!     tags(
44//!         (name = "hello", description = "Hello world endpoints")
45//!     ),
46//!     info(
47//!         title = "Hello World API",
48//!         version = "1.0.0",
49//!         description = "A simple API that responds with a greeting"
50//!     )
51//! )]
52//! struct ApiDoc;
53//!
54//! #[derive(Clone)]
55//! pub struct AppState {
56//!     pub mistralrs_state: SharedMistralRsState,
57//!     pub db_create: fn(),
58//! }
59//!
60//! #[tokio::main]
61//! async fn main() {
62//!     initialize_logging();
63//!
64//!     let plain_model_id = String::from("meta-llama/Llama-3.2-1B-Instruct");
65//!     let tokenizer_json = None;
66//!     let arch = None;
67//!     let organization = None;
68//!     let write_uqff = None;
69//!     let from_uqff = None;
70//!     let imatrix = None;
71//!     let calibration_file = None;
72//!     let hf_cache_path = None;
73//!
74//!     let dtype = ModelDType::Auto;
75//!     let topology = None;
76//!     let max_seq_len = AutoDeviceMapParams::DEFAULT_MAX_SEQ_LEN;
77//!     let max_batch_size = AutoDeviceMapParams::DEFAULT_MAX_BATCH_SIZE;
78//!
79//!     let model = ModelSelected::Plain {
80//!         model_id: plain_model_id,
81//!         tokenizer_json,
82//!         arch,
83//!         dtype,
84//!         topology,
85//!         organization,
86//!         write_uqff,
87//!         from_uqff,
88//!         imatrix,
89//!         calibration_file,
90//!         max_seq_len,
91//!         max_batch_size,
92//!         hf_cache_path,
93//!     };
94//!
95//!     let shared_mistralrs = MistralRsForServerBuilder::new()
96//!         .with_model(model)
97//!         .with_in_situ_quant("8".to_string())
98//!         .set_paged_attn(Some(true))
99//!         .build()
100//!         .await
101//!         .unwrap();
102//!
103//!     let mistralrs_base_path = "/api/mistral";
104//!
105//!     let mistralrs_routes = MistralRsServerRouterBuilder::new()
106//!         .with_mistralrs(shared_mistralrs.clone())
107//!         .with_include_swagger_routes(false)
108//!         .with_base_path(mistralrs_base_path)
109//!         .build()
110//!         .await
111//!         .unwrap();
112//!
113//!     let mistralrs_doc = get_openapi_doc(Some(mistralrs_base_path));
114//!     let mut api_docs = ApiDoc::openapi();
115//!     api_docs.merge(mistralrs_doc);
116//!
117//!     let app_state = Arc::new(AppState {
118//!         mistralrs_state: shared_mistralrs,
119//!         db_create: mock_db_call,
120//!     });
121//!
122//!     let app = Router::new()
123//!         .route("/", get(root))
124//!         .route("/chat", post(custom_chat))
125//!         .with_state(app_state.clone())
126//!         .nest(mistralrs_base_path, mistralrs_routes)
127//!         .merge(SwaggerUi::new("/api-docs").url("/api-docs/openapi.json", api_docs));
128//!
//!     let listener = tokio::net::TcpListener::bind("0.0.0.0:3000").await.unwrap();
//!
//!     println!("Listening on 0.0.0.0:3000");
//!
//!     axum::serve(listener, app).await.unwrap();
133//! }
134//!
135//! #[utoipa::path(
136//!     get,
137//!     path = "/",
138//!     tag = "hello",
139//!     responses(
140//!         (status = 200, description = "Successful response with greeting message", body = String)
141//!     )
142//! )]
143//! async fn root() -> &'static str {
144//!     "Hello, World!"
145//! }
146//!
147//! #[utoipa::path(
148//!     post,
149//!     tag = "Custom",
150//!     path = "/chat",
151//!     request_body = ChatCompletionRequest,
152//!     responses((status = 200, description = "Chat completions"))
153//! )]
154//! pub async fn custom_chat(
155//!     State(state): State<Arc<AppState>>,
156//!     Json(oai_request): Json<ChatCompletionRequest>,
157//! ) -> ChatCompletionResponder {
158//!     let mistralrs_state = state.mistralrs_state.clone();
159//!     let (tx, mut rx) = create_response_channel(None);
160//!
161//!     let (request, is_streaming) =
162//!         match parse_request(oai_request, mistralrs_state.clone(), tx).await {
163//!             Ok(x) => x,
164//!             Err(e) => return handle_error(mistralrs_state, e.into()),
165//!         };
166//!
//!     dbg!(&request);
168//!
169//!     if let Err(e) = send_request(&mistralrs_state, request).await {
170//!         return handle_error(mistralrs_state, e.into());
171//!     }
172//!
173//!     if is_streaming {
174//!         let db_fn = state.db_create;
175//!
176//!         let on_chunk: ChatCompletionOnChunkCallback =
177//!             Box::new(move |mut chunk: ChatCompletionChunkResponse| {
178//!                 dbg!(&chunk);
179//!
180//!                 if let Some(original_content) = &chunk.choices[0].delta.content {
181//!                     chunk.choices[0].delta.content = Some(format!("CHANGED! {}", original_content));
182//!                 }
183//!
184//!                 chunk.clone()
185//!             });
186//!
187//!         let on_done: ChatCompletionOnDoneCallback =
188//!             Box::new(move |chunks: &[ChatCompletionChunkResponse]| {
189//!                 dbg!(chunks);
190//!                 (db_fn)();
191//!             });
192//!
193//!         let streamer = create_streamer(rx, mistralrs_state.clone(), Some(on_chunk), Some(on_done));
194//!
195//!         ChatCompletionResponder::Sse(streamer)
196//!     } else {
197//!         let response = process_non_streaming_response(&mut rx, mistralrs_state.clone()).await;
198//!
199//!         match &response {
200//!             ChatCompletionResponder::Json(json_response) => {
201//!                 dbg!(json_response);
202//!                 (state.db_create)();
203//!             }
204//!             _ => {
205//!                 //
206//!             }
207//!         }
208//!
209//!         response
210//!     }
211//! }
212//!
213//! pub fn mock_db_call() {
214//!     println!("Saving to DB");
215//! }
216//! ```
217
// Crate module tree. `pub` modules form this crate's public API surface;
// `completion_core` and `handlers` are internal implementation details.

/// Chat-completion endpoint plumbing: request parsing (`parse_request`),
/// streaming via `create_streamer` with on-chunk/on-done callbacks,
/// non-streaming via `process_non_streaming_response`, plus
/// `ChatCompletionResponder` and `handle_error` (see the crate-level example).
pub mod chat_completion;
mod completion_core;
pub mod completions;
/// Shared request/response helpers such as `create_response_channel` and
/// `send_request`.
pub mod handler_core;
mod handlers;
pub mod image_generation;
/// `MistralRsForServerBuilder` — builds the shared mistral.rs state
/// (model selection, in-situ quant, paged attention, …).
pub mod mistralrs_for_server_builder;
/// `MistralRsServerRouterBuilder` — builds an axum `Router` exposing the
/// server endpoints, optionally nested under a base path.
pub mod mistralrs_server_router_builder;
/// OpenAI-compatible API types (e.g. `ChatCompletionRequest`).
pub mod openai;
/// `get_openapi_doc` — OpenAPI document for the server routes, mergeable
/// into a host application's own `utoipa` docs.
pub mod openapi_doc;
pub mod speech_generation;
pub mod streaming;
/// Shared type aliases such as `SharedMistralRsState`.
pub mod types;
pub mod util;