mistralrs_core/utils/
mod.rs

1pub(crate) mod debug;
2pub(crate) mod gguf_metadata;
3pub(crate) mod memory_usage;
4pub(crate) mod model_config;
5pub(crate) mod normal;
6pub(crate) mod progress;
7pub(crate) mod tiktoken;
8pub(crate) mod tokenizer;
9pub(crate) mod tokens;
10pub(crate) mod unvarbuilder;
11pub(crate) mod varbuilder_utils;
12
// Spin until the mutex behind `$thing` can be acquired, then yield the guard.
// Intended for `Arc<Mutex<..>>` handles shared with async tasks: the owner may
// hold the lock across an await point, so we must not block the thread.
#[doc(hidden)]
#[macro_export]
macro_rules! get_mut_arcmutex {
    ($thing:expr) => {
        loop {
            match $thing.try_lock() {
                // Lock acquired: the loop evaluates to the guard.
                Ok(guard) => break guard,
                // Contended: yield so the holder's thread can run and release
                // the lock, avoiding a deadlock-prone busy spin.
                Err(_) => std::thread::yield_now(),
            }
        }
    };
}
28
// Unwrap `$fallible`; on error, report it to the client over `$response` and
// bail out of the enclosing async fn (which must return `()`).
#[doc(hidden)]
#[macro_export]
macro_rules! handle_seq_error {
    ($fallible:expr, $response:expr) => {
        match $fallible {
            Ok(v) => v,
            Err(e) => {
                use $crate::response::Response;
                // `is_err()` instead of `if let Err(_)` (clippy
                // `redundant_pattern_matching`). A send failure means the
                // receiver was dropped, so there is nobody left to notify.
                if $response.send(Response::InternalError(e.into())).await.is_err() {
                    tracing::warn!("Receiver disconnected");
                }
                return;
            }
        }
    };
}
45
// Like `handle_seq_error!`, but for enclosing async fns returning `Result<()>`:
// on error, notify the client and return `Ok(())` (the error has been handled).
#[doc(hidden)]
#[macro_export]
macro_rules! handle_seq_error_ok {
    ($fallible:expr, $response:expr) => {
        match $fallible {
            Ok(v) => v,
            Err(e) => {
                use $crate::response::Response;
                // `is_err()` instead of `if let Err(_)` (clippy
                // `redundant_pattern_matching`). A send failure means the
                // receiver was dropped, so there is nobody left to notify.
                if $response.send(Response::InternalError(e.into())).await.is_err() {
                    tracing::warn!("Receiver disconnected");
                }
                return Ok(());
            }
        }
    };
}
62
// Like `handle_seq_error_ok!`, but takes a sequence instead of a raw sender:
// on error, notify the client via the sequence's responder, mark the sequence
// as errored, and return `Ok(())` from the enclosing async fn.
#[doc(hidden)]
#[macro_export]
macro_rules! handle_seq_error_stateaware_ok {
    ($fallible:expr, $seq:expr) => {
        match $fallible {
            Ok(v) => v,
            Err(e) => {
                use $crate::response::Response;
                use $crate::sequence::SequenceState;
                // `is_err()` instead of `if let Err(_)` (clippy
                // `redundant_pattern_matching`). A send failure means the
                // receiver was dropped, so there is nobody left to notify.
                if $seq
                    .responder()
                    .send(Response::InternalError(e.into()))
                    .await
                    .is_err()
                {
                    tracing::warn!("Receiver disconnected");
                }
                // Flag the sequence so the scheduler stops driving it.
                $seq.set_state(SequenceState::Error);
                return Ok(());
            }
        }
    };
}
85
// Unwrap `$fallible` in the engine's scheduling loop; on a model forward error,
// broadcast a partial response (with whatever tokens were generated so far) to
// every sequence in `$seq_slice`, mark them all errored, reset pipeline cache
// state, and `continue` the labeled loop `$label` to keep the engine alive.
#[doc(hidden)]
#[macro_export]
macro_rules! handle_pipeline_forward_error {
    ($stage: tt, $fallible:expr, $seq_slice:expr, $pipeline:expr, $label:tt, $prefix_cacher:expr) => {
        match $fallible {
            Ok(v) => v,
            Err(e) => {
                // Grab what we need from the pipeline inside a scope so the
                // lock guard is dropped before we iterate over the sequences.
                let (tokenizer, pipeline_name) = {
                    let pipeline = get_mut_arcmutex!($pipeline);
                    let pipeline_name = pipeline.name();
                    let tokenizer = pipeline.tokenizer();
                    (tokenizer, pipeline_name)
                };
                use $crate::response::Response;
                use $crate::sequence::SequenceState;
                use $crate::response::SYSTEM_FINGERPRINT;
                use tracing::error;
                error!("{} - Model failed with error: {:?}", $stage, &e);
                for seq in $seq_slice.iter_mut() {
                    // Step 1: Add all choices to groups
                    // Decode only the completion tokens (skip the prompt);
                    // `min` guards against a prompt length exceeding the
                    // token buffer. Decode failures degrade to empty text.
                    let start = seq.prompt_tokens().min(seq.get_toks().len());
                    let res = match &tokenizer {
                        Some(tok) => match tok.decode(&seq.get_toks()[start..], false) {
                            Ok(t) => t,
                            Err(_) => "".to_string(),
                        },
                        None => "".to_string(),
                    };

                    if seq.get_mut_group().is_chat {
                        let choice = Choice {
                            finish_reason: "error".to_string(),
                            index: seq.get_response_index(),
                            message: ResponseMessage {
                                content: Some(res),
                                role: "assistant".to_string(),
                                tool_calls: None,
                                reasoning_content: None,
                            },
                            logprobs: None,
                        };
                        seq.add_choice_to_group(choice);
                    } else {
                        let choice = CompletionChoice {
                            finish_reason: "error".to_string(),
                            index: seq.get_response_index(),
                            text: res,
                            logprobs: None,
                        };
                        seq.add_completion_choice_to_group(choice);
                    }
                }
                for seq in $seq_slice.iter_mut() {
                    // Step 2: Respond with all groups
                    let group = seq.get_mut_group();

                    if group.is_chat {
                        let partial_completion_response = ChatCompletionResponse {
                            id: seq.id().to_string(),
                            choices: group.get_choices().to_vec(),
                            created: seq.creation_time(),
                            model: pipeline_name.clone(),
                            system_fingerprint: SYSTEM_FINGERPRINT.to_string(),
                            object: "chat.completion".to_string(),
                            usage: group.get_usage(),
                        };

                        // NOTE(review): `.unwrap()` panics if the receiver
                        // disconnected, unlike `handle_seq_error!` which only
                        // warns — confirm this inconsistency is intentional.
                        seq.responder()
                            .send(Response::ModelError(
                                e.to_string(),
                                partial_completion_response
                            ))
                            .await
                            .unwrap();
                    } else {
                        let partial_completion_response = CompletionResponse {
                            id: seq.id().to_string(),
                            choices: group.get_completion_choices().to_vec(),
                            created: seq.creation_time(),
                            model: pipeline_name.clone(),
                            system_fingerprint: SYSTEM_FINGERPRINT.to_string(),
                            object: "text_completion".to_string(),
                            usage: group.get_usage(),
                        };

                        seq.responder()
                            .send(Response::CompletionModelError(
                                e.to_string(),
                                partial_completion_response
                            ))
                            .await
                            .unwrap();
                    }
                }
                for seq in $seq_slice.iter_mut() {
                    // Step 3: Set state - This cannot be done in Step 2 as `group` is locking the refcell
                    seq.set_state(SequenceState::Error);
                }

                let p = get_mut_arcmutex!($pipeline);
                // Also reset non granular state because:
                // - The sequence is gone
                // - We should reset the state then, including draft.
                p.set_none_cache($seq_slice, true, true, false);
                // Cached prefixes may reference the now-dead sequences; drop them all.
                get_mut_arcmutex!($prefix_cacher).evict_all_caches().unwrap();

                // Resume the engine's outer loop rather than propagating the error.
                continue $label;
            }
        }
    };
}
197
// Spin until the sequence group mutex (`$this.group`) can be acquired, then
// yield the guard. Same yielding-spin strategy as `get_mut_arcmutex!`.
#[doc(hidden)]
#[macro_export]
macro_rules! get_mut_group {
    ($this:expr) => {
        loop {
            match $this.group.try_lock() {
                // Lock acquired: the loop evaluates to the guard.
                Ok(guard) => break guard,
                // Contended: yield so the holder can release the lock.
                Err(_) => std::thread::yield_now(),
            }
        }
    };
}
211
// Generates a zero-argument function `$name` returning `$v` (of type `$t`).
// Serde's `#[serde(default = "path")]` field attribute requires a function
// *path*, not an expression; this macro produces the function to point it at.
#[doc(hidden)]
#[macro_export]
macro_rules! serde_default_fn {
    ($t:ty, $name:ident, $v:expr) => {
        fn $name() -> $t {
            $v
        }
    };
}
221
/// Returns `true` if PagedAttention is supported by this build: compiled with
/// the `cuda` feature on a Unix target, or with the `metal` feature.
#[cfg(any(all(feature = "cuda", target_family = "unix"), feature = "metal"))]
pub const fn paged_attn_supported() -> bool {
    true
}
227
/// Returns `true` if PagedAttention is supported by this build: compiled with
/// the `cuda` feature on a Unix target, or with the `metal` feature.
/// This variant covers all other builds and always returns `false`.
#[cfg(not(any(all(feature = "cuda", target_family = "unix"), feature = "metal")))]
pub const fn paged_attn_supported() -> bool {
    false
}
233
/// Returns `true` if built with the `flash-attn` or `flash-attn-v3` features,
/// `false` otherwise. This variant covers builds without either feature.
#[cfg(not(any(feature = "flash-attn", feature = "flash-attn-v3")))]
pub const fn using_flash_attn() -> bool {
    false
}
239
/// Returns `true` if built with the `flash-attn` or `flash-attn-v3` features,
/// `false` otherwise. This variant covers builds with either feature enabled.
#[cfg(any(feature = "flash-attn", feature = "flash-attn-v3"))]
pub const fn using_flash_attn() -> bool {
    true
}