mistralrs_core/utils/
mod.rs

1pub(crate) mod debug;
2pub(crate) mod gguf_metadata;
3pub(crate) mod log;
4pub(crate) mod memory_usage;
5pub(crate) mod model_config;
6pub(crate) mod normal;
7pub(crate) mod progress;
8pub(crate) mod tokenizer;
9pub(crate) mod tokens;
10pub(crate) mod unvarbuilder;
11pub(crate) mod varbuilder_utils;
12
/// Acquire the lock behind `$thing` by polling `try_lock()` until it succeeds.
///
/// The macro evaluates to whatever guard `try_lock()` yields in its `Ok` case.
/// It never yields to a scheduler: it busy-waits, so the calling thread is
/// occupied until the lock becomes available.
///
/// Notes:
/// - `$thing` is re-evaluated on every iteration, so pass a plain place
///   expression (a variable or field), not something with side effects.
/// - Every `Err` from `try_lock()` is retried. If the lock type can report a
///   permanent error (for example a poisoned `std::sync::Mutex`), this loops
///   forever.
#[doc(hidden)]
#[macro_export]
macro_rules! get_mut_arcmutex {
    ($thing:expr) => {
        loop {
            if let Ok(inner) = $thing.try_lock() {
                break inner;
            }
            // Signal the CPU that we are in a busy-wait loop so it can reduce
            // power use / contention while the lock is held elsewhere.
            ::std::hint::spin_loop();
        }
    };
}
24
/// Unwrap `$fallible`, or report the error to `$response` and bail out.
///
/// - On `Ok`, the macro evaluates to the contained value.
/// - On `Err`, the error is converted with `.into()`, wrapped in
///   `Response::InternalError`, and sent through `$response.send(..)`. The
///   send is awaited, so this must be expanded inside an async context.
///   The macro then executes a bare `return;`, so the enclosing function
///   must return `()`.
///
/// Panics with "Expected receiver." if the send itself reports an error.
#[doc(hidden)]
#[macro_export]
macro_rules! handle_seq_error {
    ($fallible:expr, $response:expr) => {
        match $fallible {
            Err(err) => {
                $response
                    .send($crate::response::Response::InternalError(err.into()))
                    .await
                    .expect("Expected receiver.");
                return;
            }
            Ok(value) => value,
        }
    };
}
42
/// Unwrap `$fallible`, or report the error to `$response` and return `Ok(())`.
///
/// - On `Ok`, the macro evaluates to the contained value.
/// - On `Err`, the error is converted with `.into()`, wrapped in
///   `Response::InternalError`, and sent through `$response.send(..)`. The
///   send is awaited, so this must be expanded inside an async context.
///   The macro then executes `return Ok(());`, so the enclosing function
///   must return a `Result` whose success type is `()`.
///
/// Panics with "Expected receiver." if the send itself reports an error.
#[doc(hidden)]
#[macro_export]
macro_rules! handle_seq_error_ok {
    ($fallible:expr, $response:expr) => {
        match $fallible {
            Err(err) => {
                $response
                    .send($crate::response::Response::InternalError(err.into()))
                    .await
                    .expect("Expected receiver.");
                return Ok(());
            }
            Ok(value) => value,
        }
    };
}
60
/// Unwrap `$fallible`, or report the error through the sequence, mark the
/// sequence as errored, and return `Ok(())`.
///
/// - On `Ok`, the macro evaluates to the contained value.
/// - On `Err`, the steps run in this order:
///   1. the error is sent as `Response::InternalError` via
///      `$seq.responder().send(..)` and awaited (async context required);
///   2. `$seq.set_state(SequenceState::Error)` is called;
///   3. the enclosing function returns `Ok(())`.
///
/// Panics with "Expected receiver." if the send itself reports an error.
#[doc(hidden)]
#[macro_export]
macro_rules! handle_seq_error_stateaware_ok {
    ($fallible:expr, $seq:expr) => {
        match $fallible {
            Err(err) => {
                $seq.responder()
                    .send($crate::response::Response::InternalError(err.into()))
                    .await
                    .expect("Expected receiver.");
                $seq.set_state($crate::sequence::SequenceState::Error);
                return Ok(());
            }
            Ok(value) => value,
        }
    };
}
80
/// Unwrap the result of a pipeline forward pass, or fail every sequence in
/// `$seq_slice` and `continue $label`.
///
/// Arguments:
/// - `$stage`: value printed in the error log line to identify the stage.
/// - `$fallible`: the `Result` to unwrap.
/// - `$seq_slice`: the sequences affected; iterated with `iter_mut()`.
/// - `$pipeline`: lockable pipeline handle (locked via `get_mut_arcmutex!`).
/// - `$label`: loop label to `continue` to after the error is handled.
/// - `$prefix_cacher`: lockable prefix cacher; `evict_all_to_cpu()` is called on it.
///
/// On `Ok` the macro evaluates to the contained value. On `Err` it:
/// 1. logs the error;
/// 2. decodes each sequence's tokens after the prompt (empty string if there
///    is no tokenizer or decoding fails) and records an `"error"` choice on
///    the sequence's group;
/// 3. sends a `ModelError` / `CompletionModelError` response carrying the
///    partial completion to each sequence's responder;
/// 4. sets every sequence to `SequenceState::Error`;
/// 5. resets the pipeline cache and evicts the prefix cache to CPU;
/// 6. executes `continue $label`.
///
/// The call site must have `Choice`, `ResponseMessage`, `CompletionChoice`,
/// `ChatCompletionResponse` and `CompletionResponse` in scope, and must be in
/// an async context because responses are awaited.
///
/// Panics with "Expected receiver." if sending a response fails, and panics
/// if `evict_all_to_cpu()` returns an error.
#[doc(hidden)]
#[macro_export]
macro_rules! handle_pipeline_forward_error {
    ($stage: tt, $fallible:expr, $seq_slice:expr, $pipeline:expr, $label:tt, $prefix_cacher:expr) => {
        match $fallible {
            Ok(v) => v,
            Err(e) => {
                // Grab what we need from the pipeline inside a block so the
                // lock guard is dropped before any sequence is touched.
                let (tokenizer, pipeline_name) = {
                    // `$crate::` keeps this expansion independent of whether
                    // the caller has `get_mut_arcmutex!` in scope.
                    let pipeline = $crate::get_mut_arcmutex!($pipeline);
                    let pipeline_name = pipeline.name();
                    let tokenizer = pipeline.tokenizer();
                    (tokenizer, pipeline_name)
                };
                use $crate::response::Response;
                use $crate::sequence::SequenceState;
                use $crate::response::SYSTEM_FINGERPRINT;
                use tracing::error;
                error!("{} - Model failed with error: {:?}", $stage, &e);
                for seq in $seq_slice.iter_mut() {
                    // Step 1: Add all choices to groups
                    // Best effort: a missing tokenizer or a decode failure
                    // yields an empty partial text rather than another error.
                    let res = match &tokenizer {
                        Some(tok) => tok
                            .decode(&seq.get_toks()[seq.prompt_tokens()..], false)
                            .unwrap_or_default(),
                        None => String::new(),
                    };

                    if seq.get_mut_group().is_chat {
                        let choice = Choice {
                            finish_reason: "error".to_string(),
                            index: seq.get_response_index(),
                            message: ResponseMessage {
                                content: Some(res),
                                role: "assistant".to_string(),
                                tool_calls: None,
                            },
                            logprobs: None,
                        };
                        seq.add_choice_to_group(choice);
                    } else {
                        let choice = CompletionChoice {
                            finish_reason: "error".to_string(),
                            index: seq.get_response_index(),
                            text: res,
                            logprobs: None,
                        };
                        seq.add_completion_choice_to_group(choice);
                    }
                }
                for seq in $seq_slice.iter_mut() {
                    // Step 2: Respond with all groups
                    let group = seq.get_mut_group();

                    if group.is_chat {
                        let partial_completion_response = ChatCompletionResponse {
                            id: seq.id().to_string(),
                            choices: group.get_choices().to_vec(),
                            created: seq.creation_time(),
                            model: pipeline_name.clone(),
                            system_fingerprint: SYSTEM_FINGERPRINT.to_string(),
                            object: "chat.completion".to_string(),
                            usage: group.get_usage(),
                        };

                        seq.responder()
                            .send(Response::ModelError(
                                e.to_string(),
                                partial_completion_response
                            ))
                            .await
                            .expect("Expected receiver.");
                    } else {
                        let partial_completion_response = CompletionResponse {
                            id: seq.id().to_string(),
                            choices: group.get_completion_choices().to_vec(),
                            created: seq.creation_time(),
                            model: pipeline_name.clone(),
                            system_fingerprint: SYSTEM_FINGERPRINT.to_string(),
                            object: "text_completion".to_string(),
                            usage: group.get_usage(),
                        };

                        seq.responder()
                            .send(Response::CompletionModelError(
                                e.to_string(),
                                partial_completion_response
                            ))
                            .await
                            .expect("Expected receiver.");
                    }
                }
                for seq in $seq_slice.iter_mut() {
                    // Step 3: Set state. Kept as a separate pass because the
                    // group guard taken in Step 2 is still held there.
                    seq.set_state(SequenceState::Error);
                }

                let locked_pipeline = $crate::get_mut_arcmutex!($pipeline);
                // Also reset non granular state because:
                // - The sequence is gone
                // - We should reset the state then, including draft.
                locked_pipeline.set_none_cache($seq_slice, true, true, false);
                $crate::get_mut_arcmutex!($prefix_cacher).evict_all_to_cpu().unwrap();

                continue $label;
            }
        }
    };
}
191
/// Acquire the lock on `$this.group` by polling `try_lock()` until it succeeds.
///
/// The macro evaluates to the guard from the `Ok` case of `try_lock()`.
/// It busy-waits rather than yielding, and retries on every `Err`; if the
/// lock type can report a permanent error this never terminates.
/// `$this` is re-evaluated on each iteration, so pass a plain place
/// expression.
#[doc(hidden)]
#[macro_export]
macro_rules! get_mut_group {
    ($this:expr) => {
        loop {
            if let Ok(inner) = $this.group.try_lock() {
                break inner;
            }
            // Busy-wait hint: lets the CPU back off while another holder
            // still owns the group lock.
            ::std::hint::spin_loop();
        }
    };
}
203
/// Define a zero-argument function `$name` that returns `$v` as type `$t`.
///
/// Intended for use as the target of `#[serde(default = "...")]`, which
/// needs a function path rather than a literal value.
///
/// Arguments: the return type, the function name, and the expression to
/// return. A trailing comma after the expression is accepted.
#[doc(hidden)]
#[macro_export]
macro_rules! serde_default_fn {
    ($t:ty, $name:ident, $v:expr $(,)?) => {
        fn $name() -> $t {
            $v
        }
    };
}
213
/// Reports whether this build can use paged attention.
///
/// Returns `true` when compiled with the `metal` feature, or with the `cuda`
/// feature on a Unix-family target; returns `false` otherwise. The answer is
/// fixed at compile time.
pub const fn paged_attn_supported() -> bool {
    cfg!(any(
        all(feature = "cuda", target_family = "unix"),
        feature = "metal"
    ))
}
225
/// Reports whether this build was compiled with flash attention.
///
/// Returns `true` when either the `flash-attn` or the `flash-attn-v3` feature
/// is enabled, and `false` otherwise. The answer is fixed at compile time.
pub const fn using_flash_attn() -> bool {
    cfg!(any(feature = "flash-attn", feature = "flash-attn-v3"))
}