// mistralrs_core/utils/mod.rs

// Crate-internal utility submodules. All are `pub(crate)`: they support the
// rest of mistralrs_core and are not part of the public API.
pub(crate) mod debug;
pub(crate) mod gguf_metadata;
pub(crate) mod memory_usage;
pub(crate) mod model_config;
pub(crate) mod normal;
pub(crate) mod progress;
pub(crate) mod tiktoken;
pub(crate) mod tokenizer;
pub(crate) mod tokens;
pub(crate) mod unvarbuilder;
pub(crate) mod varbuilder_utils;
12
/// Spin-acquires `$thing` (a mutex-like value with `try_lock`) without ever
/// blocking, evaluating to the lock guard.
///
/// Yielding between attempts lets the current holder — possibly a task that
/// keeps the guard across an await point — get scheduled and release the lock,
/// avoiding a busy-loop deadlock.
#[doc(hidden)]
#[macro_export]
macro_rules! get_mut_arcmutex {
    ($thing:expr) => {
        loop {
            match $thing.try_lock() {
                Ok(guard) => break guard,
                // Lock is currently held elsewhere: give other threads a
                // chance to run and release it, then retry.
                Err(_) => std::thread::yield_now(),
            }
        }
    };
}
28
/// Unwraps `$fallible` inside an async request handler.
///
/// On `Ok(v)` the macro evaluates to `v`. On `Err(e)` it sends
/// `Response::InternalError` over `$response` (an async sender) and `return`s
/// from the enclosing function, which must therefore return `()`.
#[doc(hidden)]
#[macro_export]
macro_rules! handle_seq_error {
    ($fallible:expr, $response:expr) => {
        match $fallible {
            Ok(v) => v,
            Err(e) => {
                use $crate::response::Response;
                // `.is_err()` instead of `if let Err(_)` (clippy
                // redundant_pattern_matching). A dropped receiver just means
                // the client went away: log it and bail out of the handler.
                if $response.send(Response::InternalError(e.into())).await.is_err() {
                    tracing::warn!("Receiver disconnected");
                }
                return;
            }
        }
    };
}
45
/// Like `handle_seq_error!`, but for enclosing functions that return
/// `Result<(), _>`: on `Err(e)` it reports `Response::InternalError` to the
/// client and `return Ok(())`s, treating the request as handled.
#[doc(hidden)]
#[macro_export]
macro_rules! handle_seq_error_ok {
    ($fallible:expr, $response:expr) => {
        match $fallible {
            Ok(v) => v,
            Err(e) => {
                use $crate::response::Response;
                // `.is_err()` instead of `if let Err(_)` (clippy
                // redundant_pattern_matching). A disconnected receiver is not
                // fatal; the error was already turned into a response.
                if $response.send(Response::InternalError(e.into())).await.is_err() {
                    tracing::warn!("Receiver disconnected");
                }
                return Ok(());
            }
        }
    };
}
62
/// Like `handle_seq_error_ok!`, but takes a sequence instead of a raw
/// responder: on `Err(e)` it reports `Response::InternalError` via the
/// sequence's responder, moves the sequence into `SequenceState::Error`, and
/// `return Ok(())`s from the enclosing function.
#[doc(hidden)]
#[macro_export]
macro_rules! handle_seq_error_stateaware_ok {
    ($fallible:expr, $seq:expr) => {
        match $fallible {
            Ok(v) => v,
            Err(e) => {
                use $crate::response::Response;
                use $crate::sequence::SequenceState;
                // `.is_err()` instead of `if let Err(_)` (clippy
                // redundant_pattern_matching). Even if the client is gone,
                // still flag the sequence as errored below.
                if $seq
                    .responder()
                    .send(Response::InternalError(e.into()))
                    .await
                    .is_err()
                {
                    tracing::warn!("Receiver disconnected");
                }
                $seq.set_state(SequenceState::Error);
                return Ok(());
            }
        }
    };
}
85
/// Unwraps the result of a pipeline forward pass inside a scheduler loop.
///
/// On `Ok(v)` the macro evaluates to `v`. On `Err(e)` it:
/// 1. decodes whatever tokens each sequence in `$seq_slice` generated so far,
/// 2. sends every sequence's client a partial `ModelError` /
///    `CompletionModelError` response carrying that partial text,
/// 3. marks every sequence as `SequenceState::Error`,
/// 4. resets the pipeline cache and evicts all prefix caches,
/// 5. `continue`s the loop labeled `$label`.
///
/// NOTE(review): `Choice`, `ResponseMessage`, `CompletionChoice`,
/// `ChatCompletionResponse` and `CompletionResponse` are used unqualified and
/// must already be in scope at the call site.
#[doc(hidden)]
#[macro_export]
macro_rules! handle_pipeline_forward_error {
    ($stage: tt, $fallible:expr, $seq_slice:expr, $pipeline:expr, $label:tt, $prefix_cacher:expr) => {
        match $fallible {
            Ok(v) => v,
            Err(e) => {
                // Take what we need from the pipeline in a short scope so the
                // spin-acquired guard is dropped before the loops below.
                let (tokenizer, pipeline_name) = {
                    let pipeline = get_mut_arcmutex!($pipeline);
                    let pipeline_name = pipeline.name();
                    let tokenizer = pipeline.tokenizer();
                    (tokenizer, pipeline_name)
                };
                use $crate::response::Response;
                use $crate::sequence::SequenceState;
                use $crate::response::SYSTEM_FINGERPRINT;
                use tracing::error;
                error!("{} - Model failed with error: {:?}", $stage, &e);
                for seq in $seq_slice.iter_mut() {
                    // Step 1: Add all choices to groups
                    // Decode only the completion portion; `min` guards the
                    // slice when fewer tokens than the prompt length exist.
                    let start = seq.prompt_tokens().min(seq.get_toks().len());
                    let res = match &tokenizer {
                        Some(tok) => match tok.decode(&seq.get_toks()[start..], false) {
                            Ok(t) => t,
                            // Best effort: an undecodable tail becomes an
                            // empty completion rather than a second failure.
                            Err(_) => "".to_string(),
                        },
                        None => "".to_string(),
                    };

                    if seq.get_mut_group().is_chat {
                        let choice = Choice {
                            finish_reason: "error".to_string(),
                            index: seq.get_response_index(),
                            message: ResponseMessage {
                                content: Some(res),
                                role: "assistant".to_string(),
                                tool_calls: None,
                            },
                            logprobs: None,
                        };
                        seq.add_choice_to_group(choice);
                    } else {
                        let choice = CompletionChoice {
                            finish_reason: "error".to_string(),
                            index: seq.get_response_index(),
                            text: res,
                            logprobs: None,
                        };
                        seq.add_completion_choice_to_group(choice);
                    }
                }
                for seq in $seq_slice.iter_mut() {
                    // Step 2: Respond with all groups
                    let group = seq.get_mut_group();

                    if group.is_chat {
                        let partial_completion_response = ChatCompletionResponse {
                            id: seq.id().to_string(),
                            choices: group.get_choices().to_vec(),
                            created: seq.creation_time(),
                            model: pipeline_name.clone(),
                            system_fingerprint: SYSTEM_FINGERPRINT.to_string(),
                            object: "chat.completion".to_string(),
                            usage: group.get_usage(),
                        };

                        // NOTE(review): unlike the handle_seq_error* macros,
                        // a disconnected receiver here panics via `unwrap()`
                        // — confirm this is intentional.
                        seq.responder()
                            .send(Response::ModelError(
                                e.to_string(),
                                partial_completion_response
                            ))
                            .await
                            .unwrap();
                    } else {
                        let partial_completion_response = CompletionResponse {
                            id: seq.id().to_string(),
                            choices: group.get_completion_choices().to_vec(),
                            created: seq.creation_time(),
                            model: pipeline_name.clone(),
                            system_fingerprint: SYSTEM_FINGERPRINT.to_string(),
                            object: "text_completion".to_string(),
                            usage: group.get_usage(),
                        };

                        seq.responder()
                            .send(Response::CompletionModelError(
                                e.to_string(),
                                partial_completion_response
                            ))
                            .await
                            .unwrap();
                    }
                }
                for seq in $seq_slice.iter_mut() {
                    // Step 3: Set state - This cannot be done in Step 2 as `group` is locking the refcell
                    seq.set_state(SequenceState::Error);
                }

                let p = get_mut_arcmutex!($pipeline);
                // Also reset non granular state because:
                // - The sequence is gone
                // - We should reset the state then, including draft.
                p.set_none_cache($seq_slice, true, true, false);
                get_mut_arcmutex!($prefix_cacher).evict_all_caches().unwrap();

                continue $label;
            }
        }
    };
}
196
/// Spin-acquires `$this.group` (a mutex-like field with `try_lock`) without
/// blocking, evaluating to the lock guard. Yields between attempts so the
/// current holder can run and release the lock.
#[doc(hidden)]
#[macro_export]
macro_rules! get_mut_group {
    ($this:expr) => {
        loop {
            match $this.group.try_lock() {
                Ok(guard) => break guard,
                // Lock busy: let other threads progress, then retry.
                Err(_) => std::thread::yield_now(),
            }
        }
    };
}
210
/// Generates a zero-argument function returning a default value — the shape
/// serde expects for `#[serde(default = "...")]` attributes.
#[doc(hidden)]
#[macro_export]
macro_rules! serde_default_fn {
    ($ret_ty:ty, $fn_name:ident, $default:expr) => {
        fn $fn_name() -> $ret_ty {
            $default
        }
    };
}
220
/// `true` if built with CUDA (requires Unix) /Metal
///
/// Collapses the former pair of `#[cfg]`-gated definitions into a single
/// function: `cfg!` evaluates the identical predicate to a compile-time
/// boolean, so behavior is unchanged while the duplication is removed.
pub const fn paged_attn_supported() -> bool {
    cfg!(any(all(feature = "cuda", target_family = "unix"), feature = "metal"))
}
232
/// `true` if built with the `flash-attn` or `flash-attn-v3` features, false otherwise.
///
/// Collapses the former pair of `#[cfg]`-gated definitions into a single
/// function: `cfg!` evaluates the identical feature predicate at compile
/// time, so behavior is unchanged while the duplication is removed.
pub const fn using_flash_attn() -> bool {
    cfg!(any(feature = "flash-attn", feature = "flash-attn-v3"))
}