// mistralrs_server_core/lib.rs
1//! > **mistral.rs server core**
2//!
3//! ## About
4//!
5//! This crate powers mistral.rs server. It exposes the underlying functionality
6//! allowing others to implement and extend the server implementation.
7//!
8//! ### Features
9//! 1. Incorporate mistral.rs server into another axum project.
10//! 2. Hook into the mistral.rs server lifecycle.
11//!
12//! ### Example
13//! ```no_run
14//! use std::sync::Arc;
15//!
16//! use axum::{
17//! extract::State,
18//! routing::{get, post},
19//! Json, Router,
20//! };
21//! use utoipa::OpenApi;
22//! use utoipa_swagger_ui::SwaggerUi;
23//!
24//! use mistralrs_core::{
25//! initialize_logging, AutoDeviceMapParams, ChatCompletionChunkResponse, ModelDType, ModelSelected,
26//! };
27//! use mistralrs_server_core::{
28//! chat_completion::{
29//! create_streamer, handle_error, parse_request, process_non_streaming_response,
30//! ChatCompletionOnChunkCallback, ChatCompletionOnDoneCallback, ChatCompletionResponder,
31//! },
32//! handler_core::{create_response_channel, send_request},
33//! mistralrs_for_server_builder::MistralRsForServerBuilder,
34//! mistralrs_server_router_builder::MistralRsServerRouterBuilder,
35//! openai::ChatCompletionRequest,
36//! openapi_doc::get_openapi_doc,
37//! types::SharedMistralRsState,
38//! };
39//!
40//! #[derive(OpenApi)]
41//! #[openapi(
42//! paths(root, custom_chat),
43//! tags(
44//! (name = "hello", description = "Hello world endpoints")
45//! ),
46//! info(
47//! title = "Hello World API",
48//! version = "1.0.0",
49//! description = "A simple API that responds with a greeting"
50//! )
51//! )]
52//! struct ApiDoc;
53//!
54//! #[derive(Clone)]
55//! pub struct AppState {
56//! pub mistralrs_state: SharedMistralRsState,
57//! pub db_create: fn(),
58//! }
59//!
60//! #[tokio::main]
61//! async fn main() {
62//! initialize_logging();
63//!
64//! let plain_model_id = String::from("meta-llama/Llama-3.2-1B-Instruct");
65//! let tokenizer_json = None;
66//! let arch = None;
67//! let organization = None;
68//! let write_uqff = None;
69//! let from_uqff = None;
70//! let imatrix = None;
71//! let calibration_file = None;
72//! let hf_cache_path = None;
73//!
74//! let dtype = ModelDType::Auto;
75//! let topology = None;
76//! let max_seq_len = AutoDeviceMapParams::DEFAULT_MAX_SEQ_LEN;
77//! let max_batch_size = AutoDeviceMapParams::DEFAULT_MAX_BATCH_SIZE;
78//! let matformer_config_path = None;
79//! let matformer_slice_name = None;
80//!
81//! let model = ModelSelected::Plain {
82//! model_id: plain_model_id,
83//! tokenizer_json,
84//! arch,
85//! dtype,
86//! topology,
87//! organization,
88//! write_uqff,
89//! from_uqff,
90//! imatrix,
91//! calibration_file,
92//! max_seq_len,
93//! max_batch_size,
94//! hf_cache_path,
95//! matformer_config_path,
96//! matformer_slice_name,
97//! };
98//!
99//! let shared_mistralrs = MistralRsForServerBuilder::new()
100//! .with_model(model)
101//! .with_in_situ_quant("8".to_string())
102//! .set_paged_attn(Some(true))
103//! .build()
104//! .await
105//! .unwrap();
106//!
107//! let mistralrs_base_path = "/api/mistral";
108//!
109//! let mistralrs_routes = MistralRsServerRouterBuilder::new()
110//! .with_mistralrs(shared_mistralrs.clone())
111//! .with_include_swagger_routes(false)
112//! .with_base_path(mistralrs_base_path)
113//! .build()
114//! .await
115//! .unwrap();
116//!
117//! let mistralrs_doc = get_openapi_doc(Some(mistralrs_base_path));
118//! let mut api_docs = ApiDoc::openapi();
119//! api_docs.merge(mistralrs_doc);
120//!
121//! let app_state = Arc::new(AppState {
122//! mistralrs_state: shared_mistralrs,
123//! db_create: mock_db_call,
124//! });
125//!
126//! let app = Router::new()
127//! .route("/", get(root))
128//! .route("/chat", post(custom_chat))
129//! .with_state(app_state.clone())
130//! .nest(mistralrs_base_path, mistralrs_routes)
131//! .merge(SwaggerUi::new("/api-docs").url("/api-docs/openapi.json", api_docs));
132//!
133//! let listener = tokio::net::TcpListener::bind("0.0.0.0:3000").await.unwrap();
134//! println!("Listening on 0.0.0.0:3000");
135//!
136//! axum::serve(listener, app).await.unwrap();
137//! }
138//!
139//! #[utoipa::path(
140//! get,
141//! path = "/",
142//! tag = "hello",
143//! responses(
144//! (status = 200, description = "Successful response with greeting message", body = String)
145//! )
146//! )]
147//! async fn root() -> &'static str {
148//! "Hello, World!"
149//! }
150//!
151//! #[utoipa::path(
152//! post,
153//! tag = "Custom",
154//! path = "/chat",
155//! request_body = ChatCompletionRequest,
156//! responses((status = 200, description = "Chat completions"))
157//! )]
158//! pub async fn custom_chat(
159//! State(state): State<Arc<AppState>>,
160//! Json(oai_request): Json<ChatCompletionRequest>,
161//! ) -> ChatCompletionResponder {
162//! let mistralrs_state = state.mistralrs_state.clone();
163//! let (tx, mut rx) = create_response_channel(None);
164//!
165//! let (request, is_streaming) =
166//! match parse_request(oai_request, mistralrs_state.clone(), tx).await {
167//! Ok(x) => x,
168//! Err(e) => return handle_error(mistralrs_state, e.into()),
169//! };
170//!
171//! dbg!(&request);
172//!
173//! if let Err(e) = send_request(&mistralrs_state, request).await {
174//! return handle_error(mistralrs_state, e.into());
175//! }
176//!
177//! if is_streaming {
178//! let db_fn = state.db_create;
179//!
180//! let on_chunk: ChatCompletionOnChunkCallback =
181//! Box::new(move |mut chunk: ChatCompletionChunkResponse| {
182//! dbg!(&chunk);
183//!
184//! if let Some(original_content) = &chunk.choices[0].delta.content {
185//! chunk.choices[0].delta.content = Some(format!("CHANGED! {}", original_content));
186//! }
187//!
188//! chunk
189//! });
190//!
191//! let on_done: ChatCompletionOnDoneCallback =
192//! Box::new(move |chunks: &[ChatCompletionChunkResponse]| {
193//! dbg!(chunks);
194//! (db_fn)();
195//! });
196//!
197//! let streamer = create_streamer(rx, mistralrs_state.clone(), Some(on_chunk), Some(on_done));
198//!
199//! ChatCompletionResponder::Sse(streamer)
200//! } else {
201//! let response = process_non_streaming_response(&mut rx, mistralrs_state.clone()).await;
202//!
203//! match &response {
204//! ChatCompletionResponder::Json(json_response) => {
205//! dbg!(json_response);
206//! (state.db_create)();
207//! }
208//! _ => {
209//! //
210//! }
211//! }
212//!
213//! response
214//! }
215//! }
216//!
217//! pub fn mock_db_call() {
218//! println!("Saving to DB");
219//! }
220//! ```
221
222pub mod cached_responses;
223pub mod chat_completion;
224mod completion_core;
225pub mod completions;
226pub mod handler_core;
227mod handlers;
228pub mod image_generation;
229pub mod mistralrs_for_server_builder;
230pub mod mistralrs_server_router_builder;
231pub mod openai;
232pub mod openapi_doc;
233pub mod responses;
234pub mod speech_generation;
235pub mod streaming;
236pub mod types;
237pub mod util;