// mistralrs_core/utils/mod.rs

1pub(crate) mod debug;
2pub(crate) mod gguf_metadata;
3pub(crate) mod memory_usage;
4pub(crate) mod model_config;
5pub(crate) mod normal;
6pub(crate) mod progress;
7pub(crate) mod tiktoken;
8pub(crate) mod tokenizer;
9pub(crate) mod tokens;
10pub(crate) mod unvarbuilder;
11pub(crate) mod varbuilder_utils;
12
#[doc(hidden)]
#[macro_export]
/// Busy-wait on `try_lock` until the mutex is acquired, then evaluate to the
/// guard. Works with any `$thing` exposing `try_lock() -> Result<Guard, _>`
/// (e.g. a tokio `Mutex`, where this avoids an `.await`).
macro_rules! get_mut_arcmutex {
    ($thing:expr) => {
        loop {
            match $thing.try_lock() {
                // Break out of the loop with the guard as the value.
                Ok(guard) => break guard,
                // Lock held elsewhere: spin and retry.
                Err(_) => {}
            }
        }
    };
}
24
#[doc(hidden)]
#[macro_export]
/// Unwraps `$fallible`; on error, sends an `InternalError` over the
/// `$response` channel (logging if the receiver is gone) and `return`s
/// from the enclosing async function. Must be used in an async context
/// whose return type is `()`.
macro_rules! handle_seq_error {
    ($fallible:expr, $response:expr) => {
        match $fallible {
            Ok(v) => v,
            Err(e) => {
                use $crate::response::Response;
                // `.is_err()` replaces the redundant `if let Err(_) = ...`
                // pattern (clippy: redundant_pattern_matching); behavior is
                // unchanged.
                if $response.send(Response::InternalError(e.into())).await.is_err() {
                    tracing::warn!("Receiver disconnected");
                }
                return;
            }
        }
    };
}
41
#[doc(hidden)]
#[macro_export]
/// Like `handle_seq_error!`, but for enclosing async functions that return
/// `Result<(), _>`: on error it reports an `InternalError` over `$response`
/// (logging if the receiver is gone) and `return Ok(())`.
macro_rules! handle_seq_error_ok {
    ($fallible:expr, $response:expr) => {
        match $fallible {
            Ok(v) => v,
            Err(e) => {
                use $crate::response::Response;
                // `.is_err()` replaces the redundant `if let Err(_) = ...`
                // pattern (clippy: redundant_pattern_matching); behavior is
                // unchanged.
                if $response.send(Response::InternalError(e.into())).await.is_err() {
                    tracing::warn!("Receiver disconnected");
                }
                return Ok(());
            }
        }
    };
}
58
#[doc(hidden)]
#[macro_export]
/// Like `handle_seq_error_ok!`, but takes a sequence instead of a raw
/// channel: on error it reports an `InternalError` via `$seq.responder()`,
/// marks the sequence as `SequenceState::Error`, and `return Ok(())`.
macro_rules! handle_seq_error_stateaware_ok {
    ($fallible:expr, $seq:expr) => {
        match $fallible {
            Ok(v) => v,
            Err(e) => {
                use $crate::response::Response;
                use $crate::sequence::SequenceState;
                // `.is_err()` replaces the redundant `if let Err(_) = ...`
                // pattern (clippy: redundant_pattern_matching); behavior is
                // unchanged.
                if $seq
                    .responder()
                    .send(Response::InternalError(e.into()))
                    .await
                    .is_err()
                {
                    tracing::warn!("Receiver disconnected");
                }
                $seq.set_state(SequenceState::Error);
                return Ok(());
            }
        }
    };
}
81
// Handles a failed pipeline forward pass for a whole batch of sequences:
// logs the error, decodes whatever tokens each sequence produced so far,
// sends every waiting client a partial ModelError/CompletionModelError
// response, marks the sequences errored, resets the pipeline cache, and
// `continue`s the labeled scheduler loop.
//
// Arguments:
// - $stage:         display label for the failing stage (used in the log line)
// - $fallible:      the Result under test; on Ok(v) the macro evaluates to v
// - $seq_slice:     mutable collection of the sequences in the failed batch
// - $pipeline:      shared pipeline handle, locked via get_mut_arcmutex!
// - $label:         loop label to `continue` after cleanup
// - $prefix_cacher: shared prefix cache, fully evicted on error
#[doc(hidden)]
#[macro_export]
macro_rules! handle_pipeline_forward_error {
    ($stage: tt, $fallible:expr, $seq_slice:expr, $pipeline:expr, $label:tt, $prefix_cacher:expr) => {
        match $fallible {
            Ok(v) => v,
            Err(e) => {
                // Narrow scope: grab the tokenizer and name, then drop the
                // pipeline lock before doing any (potentially slow) decoding
                // and response sending below.
                let (tokenizer, pipeline_name) = {
                    let pipeline = get_mut_arcmutex!($pipeline);
                    let pipeline_name = pipeline.name();
                    let tokenizer = pipeline.tokenizer();
                    (tokenizer, pipeline_name)
                };
                use $crate::response::Response;
                use $crate::sequence::SequenceState;
                use $crate::response::SYSTEM_FINGERPRINT;
                use tracing::error;
                error!("{} - Model failed with error: {:?}", $stage, &e);
                for seq in $seq_slice.iter_mut() {
                    // Step 1: Add all choices to groups
                    // Decode only the completion part; `min` guards against a
                    // token buffer shorter than the recorded prompt length.
                    let start = seq.prompt_tokens().min(seq.get_toks().len());
                    // Best-effort decode: any tokenizer failure (or absent
                    // tokenizer) degrades to empty text rather than erroring.
                    let res = match &tokenizer {
                        Some(tok) => match tok.decode(&seq.get_toks()[start..], false) {
                            Ok(t) => t,
                            Err(_) => "".to_string(),
                        },
                        None => "".to_string(),
                    };

                    // Chat and completion requests use different choice types.
                    if seq.get_mut_group().is_chat {
                        let choice = Choice {
                            finish_reason: "error".to_string(),
                            index: seq.get_response_index(),
                            message: ResponseMessage {
                                content: Some(res),
                                role: "assistant".to_string(),
                                tool_calls: None,
                            },
                            logprobs: None,
                        };
                        seq.add_choice_to_group(choice);
                    } else {
                        let choice = CompletionChoice {
                            finish_reason: "error".to_string(),
                            index: seq.get_response_index(),
                            text: res,
                            logprobs: None,
                        };
                        seq.add_completion_choice_to_group(choice);
                    }
                }
                for seq in $seq_slice.iter_mut() {
                    // Step 2: Respond with all groups
                    let group = seq.get_mut_group();

                    if group.is_chat {
                        let partial_completion_response = ChatCompletionResponse {
                            id: seq.id().to_string(),
                            choices: group.get_choices().to_vec(),
                            created: seq.creation_time(),
                            model: pipeline_name.clone(),
                            system_fingerprint: SYSTEM_FINGERPRINT.to_string(),
                            object: "chat.completion".to_string(),
                            usage: group.get_usage(),
                        };

                        // NOTE(review): send failure panics here (unlike the
                        // handle_seq_error! macros, which only warn) — confirm
                        // this is intentional for the scheduler path.
                        seq.responder()
                            .send(Response::ModelError(
                                e.to_string(),
                                partial_completion_response
                            ))
                            .await
                            .unwrap();
                    } else {
                        let partial_completion_response = CompletionResponse {
                            id: seq.id().to_string(),
                            choices: group.get_completion_choices().to_vec(),
                            created: seq.creation_time(),
                            model: pipeline_name.clone(),
                            system_fingerprint: SYSTEM_FINGERPRINT.to_string(),
                            object: "text_completion".to_string(),
                            usage: group.get_usage(),
                        };

                        seq.responder()
                            .send(Response::CompletionModelError(
                                e.to_string(),
                                partial_completion_response
                            ))
                            .await
                            .unwrap();
                    }
                }
                for seq in $seq_slice.iter_mut() {
                    // Step 3: Set state - This cannot be done in Step 2 as `group` is locking the refcell
                    seq.set_state(SequenceState::Error);
                }

                let p = get_mut_arcmutex!($pipeline);
                // Also reset non granular state because:
                // - The sequence is gone
                // - We should reset the state then, including draft.
                p.set_none_cache($seq_slice, true, true, false);
                get_mut_arcmutex!($prefix_cacher).evict_all_caches().unwrap();

                // Resume the labeled scheduler loop; the failed batch has been
                // fully reported and torn down.
                continue $label;
            }
        }
    };
}
192
#[doc(hidden)]
#[macro_export]
/// Busy-wait on `$this.group.try_lock()` until the group mutex is acquired,
/// then evaluate to the guard. Same spin strategy as `get_mut_arcmutex!`,
/// specialized to the `group` field.
macro_rules! get_mut_group {
    ($this:expr) => {
        loop {
            match $this.group.try_lock() {
                // Acquired: the loop evaluates to the guard.
                Ok(guard) => break guard,
                // Contended: spin and retry.
                Err(_) => {}
            }
        }
    };
}
204
#[doc(hidden)]
#[macro_export]
/// Generates a zero-argument function named `$fn_name` returning `$default`
/// of type `$ret` — the shape serde expects for `#[serde(default = "...")]`.
macro_rules! serde_default_fn {
    ($ret:ty, $fn_name:ident, $default:expr) => {
        fn $fn_name() -> $ret {
            $default
        }
    };
}
214
/// `true` if built with CUDA (requires Unix) or Metal.
///
/// Consolidates the former duplicated `#[cfg]`-gated pair into one
/// definition: `cfg!` expands to a `true`/`false` literal at compile time,
/// so this stays a `const fn` and yields the same value for every feature
/// combination as the two-function version did.
pub const fn paged_attn_supported() -> bool {
    cfg!(any(all(feature = "cuda", target_family = "unix"), feature = "metal"))
}
226
/// `true` if built with the `flash-attn` or `flash-attn-v3` features, false otherwise.
///
/// Consolidates the former duplicated `#[cfg]`-gated pair into one
/// definition: `cfg!` expands to a `true`/`false` literal at compile time,
/// so this stays a `const fn` and yields the same value for every feature
/// combination as the two-function version did.
pub const fn using_flash_attn() -> bool {
    cfg!(any(feature = "flash-attn", feature = "flash-attn-v3"))
}