mistralrs_core/utils/
mod.rs

1pub(crate) mod debug;
2pub(crate) mod gguf_metadata;
3pub(crate) mod memory_usage;
4pub(crate) mod model_config;
5pub(crate) mod normal;
6pub(crate) mod progress;
7pub(crate) mod tokenizer;
8pub(crate) mod tokens;
9pub(crate) mod unvarbuilder;
10pub(crate) mod varbuilder_utils;
11
/// Acquires a lock by repeatedly calling `try_lock()` on `$thing` until it
/// succeeds, and evaluates to the resulting guard.
///
/// `$thing` is re-evaluated on every attempt, so it should be a cheap,
/// side-effect-free expression (typically a variable or field access).
/// Every `Err` returned by `try_lock()` is treated as "not yet available" and
/// retried; the macro never gives up, so it only terminates once `try_lock()`
/// returns `Ok`.
#[doc(hidden)]
#[macro_export]
macro_rules! get_mut_arcmutex {
    ($thing:expr) => {
        loop {
            match $thing.try_lock() {
                Ok(guard) => break guard,
                // Tell the CPU we are busy-waiting so it can relax the pipeline
                // (and be friendlier to a sibling hyper-thread) between attempts.
                Err(_) => ::std::hint::spin_loop(),
            }
        }
    };
}
23
/// Unwraps `$fallible` inside an `async` body that returns `()`.
///
/// * `Ok(value)`: the macro evaluates to `value`.
/// * `Err(err)`: `err` is converted with `.into()`, wrapped in
///   `Response::InternalError`, sent through `$response` (the send is awaited),
///   and the enclosing function returns.
///
/// # Panics
///
/// Panics with `"Expected receiver."` if the send itself fails.
#[doc(hidden)]
#[macro_export]
macro_rules! handle_seq_error {
    ($fallible:expr, $response:expr) => {
        match $fallible {
            Ok(value) => value,
            Err(err) => {
                let delivery = $response
                    .send($crate::response::Response::InternalError(err.into()))
                    .await;
                delivery.expect("Expected receiver.");
                return;
            }
        }
    };
}
41
/// Unwraps `$fallible` inside an `async` body that returns `Result<(), _>`.
///
/// * `Ok(value)`: the macro evaluates to `value`.
/// * `Err(err)`: `err` is converted with `.into()`, wrapped in
///   `Response::InternalError`, sent through `$response` (the send is awaited),
///   and the enclosing function returns `Ok(())` because the failure has
///   already been reported over the response channel.
///
/// # Panics
///
/// Panics with `"Expected receiver."` if the send itself fails.
#[doc(hidden)]
#[macro_export]
macro_rules! handle_seq_error_ok {
    ($fallible:expr, $response:expr) => {
        match $fallible {
            Ok(value) => value,
            Err(err) => {
                let delivery = $response
                    .send($crate::response::Response::InternalError(err.into()))
                    .await;
                delivery.expect("Expected receiver.");
                return Ok(());
            }
        }
    };
}
59
/// Unwraps `$fallible` inside an `async` body that returns `Result<(), _>`,
/// flagging the sequence as failed when an error occurs.
///
/// * `Ok(value)`: the macro evaluates to `value`.
/// * `Err(err)`: `err` is converted with `.into()`, wrapped in
///   `Response::InternalError` and sent through `$seq.responder()` (awaited);
///   afterwards `$seq` is moved to `SequenceState::Error` and the enclosing
///   function returns `Ok(())`.
///
/// # Panics
///
/// Panics with `"Expected receiver."` if the send itself fails; in that case
/// the sequence state is not updated.
#[doc(hidden)]
#[macro_export]
macro_rules! handle_seq_error_stateaware_ok {
    ($fallible:expr, $seq:expr) => {
        match $fallible {
            Ok(value) => value,
            Err(err) => {
                let delivery = $seq
                    .responder()
                    .send($crate::response::Response::InternalError(err.into()))
                    .await;
                delivery.expect("Expected receiver.");
                $seq.set_state($crate::sequence::SequenceState::Error);
                return Ok(());
            }
        }
    };
}
79
/// Unwraps the result of a pipeline forward step, or reports the failure to
/// every affected sequence and skips to the next iteration of `$label`.
///
/// On `Ok(v)` the macro evaluates to `v`. On `Err(e)` it:
///
/// 1. Reads the pipeline's name and tokenizer, then logs `e` together with
///    `$stage` via `tracing::error!`.
/// 2. For each sequence in `$seq_slice`, decodes the tokens after the prompt
///    (an empty string is used when there is no tokenizer or decoding fails)
///    and adds a choice with `finish_reason: "error"` to the sequence's group.
/// 3. Sends `Response::ModelError` (chat) or `Response::CompletionModelError`
///    (completion) carrying the error text and the partial response through
///    each sequence's responder.
/// 4. Sets every sequence to `SequenceState::Error`.
/// 5. Calls `set_none_cache($seq_slice, true, true, false)` on the pipeline and
///    `evict_all_caches()` on `$prefix_cacher`.
/// 6. Executes `continue $label`.
///
/// The call site must be inside an `async` context and inside the loop
/// labelled `$label`, and must have `Choice`, `ResponseMessage`,
/// `CompletionChoice`, `ChatCompletionResponse` and `CompletionResponse` in
/// scope, since the expansion names them unqualified.
///
/// # Panics
///
/// Panics with `"Expected receiver."` if sending a response fails, and if
/// `evict_all_caches()` does not succeed.
// NOTE(review): the slice `get_toks()[prompt_tokens()..]` would also panic if
// `prompt_tokens()` ever exceeded the token count — presumably impossible; confirm.
#[doc(hidden)]
#[macro_export]
macro_rules! handle_pipeline_forward_error {
    ($stage: tt, $fallible:expr, $seq_slice:expr, $pipeline:expr, $label:tt, $prefix_cacher:expr) => {
        match $fallible {
            Ok(v) => v,
            Err(e) => {
                // Scoped so the pipeline guard is dropped before any `.await` below.
                let (tokenizer, pipeline_name) = {
                    // `$crate::` keeps this exported macro usable even where
                    // `get_mut_arcmutex!` is not in textual scope.
                    let pipeline = $crate::get_mut_arcmutex!($pipeline);
                    let pipeline_name = pipeline.name();
                    let tokenizer = pipeline.tokenizer();
                    (tokenizer, pipeline_name)
                };
                use $crate::response::Response;
                use $crate::sequence::SequenceState;
                use $crate::response::SYSTEM_FINGERPRINT;
                use tracing::error;
                error!("{} - Model failed with error: {:?}", $stage, &e);
                for seq in $seq_slice.iter_mut() {
                    // Step 1: Add all choices to groups
                    // Best effort: whatever was generated so far is returned as
                    // the partial text; a missing tokenizer or a decode failure
                    // yields an empty string.
                    let res = match &tokenizer {
                        Some(tok) => tok
                            .decode(&seq.get_toks()[seq.prompt_tokens()..], false)
                            .unwrap_or_default(),
                        None => String::new(),
                    };

                    if seq.get_mut_group().is_chat {
                        let choice = Choice {
                            finish_reason: "error".to_string(),
                            index: seq.get_response_index(),
                            message: ResponseMessage {
                                content: Some(res),
                                role: "assistant".to_string(),
                                tool_calls: None,
                            },
                            logprobs: None,
                        };
                        seq.add_choice_to_group(choice);
                    } else {
                        let choice = CompletionChoice {
                            finish_reason: "error".to_string(),
                            index: seq.get_response_index(),
                            text: res,
                            logprobs: None,
                        };
                        seq.add_completion_choice_to_group(choice);
                    }
                }
                for seq in $seq_slice.iter_mut() {
                    // Step 2: Respond with all groups
                    let group = seq.get_mut_group();

                    if group.is_chat {
                        let partial_completion_response = ChatCompletionResponse {
                            id: seq.id().to_string(),
                            choices: group.get_choices().to_vec(),
                            created: seq.creation_time(),
                            model: pipeline_name.clone(),
                            system_fingerprint: SYSTEM_FINGERPRINT.to_string(),
                            object: "chat.completion".to_string(),
                            usage: group.get_usage(),
                        };

                        seq.responder()
                            .send(Response::ModelError(
                                e.to_string(),
                                partial_completion_response
                            ))
                            .await
                            .expect("Expected receiver.");
                    } else {
                        let partial_completion_response = CompletionResponse {
                            id: seq.id().to_string(),
                            choices: group.get_completion_choices().to_vec(),
                            created: seq.creation_time(),
                            model: pipeline_name.clone(),
                            system_fingerprint: SYSTEM_FINGERPRINT.to_string(),
                            object: "text_completion".to_string(),
                            usage: group.get_usage(),
                        };

                        seq.responder()
                            .send(Response::CompletionModelError(
                                e.to_string(),
                                partial_completion_response
                            ))
                            .await
                            .expect("Expected receiver.");
                    }
                }
                for seq in $seq_slice.iter_mut() {
                    // Step 3: Set state - This cannot be done in Step 2 as `group` is locking the refcell
                    seq.set_state(SequenceState::Error);
                }

                let p = $crate::get_mut_arcmutex!($pipeline);
                // Also reset non granular state because:
                // - The sequence is gone
                // - We should reset the state then, including draft.
                p.set_none_cache($seq_slice, true, true, false);
                $crate::get_mut_arcmutex!($prefix_cacher).evict_all_caches().unwrap();

                continue $label;
            }
        }
    };
}
190
/// Acquires the lock stored in the `group` field of `$this` by repeatedly
/// calling `try_lock()` until it succeeds, and evaluates to the guard.
///
/// `$this` is re-evaluated on every attempt, so it should be a cheap,
/// side-effect-free expression. Every `Err` from `try_lock()` is retried; the
/// macro only terminates once `try_lock()` returns `Ok`.
#[doc(hidden)]
#[macro_export]
macro_rules! get_mut_group {
    ($this:expr) => {
        loop {
            match $this.group.try_lock() {
                Ok(guard) => break guard,
                // Tell the CPU we are busy-waiting so it can relax the pipeline
                // (and be friendlier to a sibling hyper-thread) between attempts.
                Err(_) => ::std::hint::spin_loop(),
            }
        }
    };
}
202
/// Declares a private, zero-argument function `$fn_name` that returns the
/// expression `$default` typed as `$ret`.
///
/// The macro name suggests the generated function is meant to be referenced
/// from `#[serde(default = "...")]`, which takes a function path rather than
/// a literal value.
#[doc(hidden)]
#[macro_export]
macro_rules! serde_default_fn {
    ($ret:ty, $fn_name:ident, $default:expr) => {
        fn $fn_name() -> $ret {
            $default
        }
    };
}
212
/// Compile-time check for paged-attention support.
///
/// Returns `true` when the crate is built with the `metal` feature, or with
/// the `cuda` feature on a Unix-family target; returns `false` otherwise.
pub const fn paged_attn_supported() -> bool {
    cfg!(any(
        all(feature = "cuda", target_family = "unix"),
        feature = "metal"
    ))
}
224
/// Compile-time check for flash attention.
///
/// Returns `true` when the crate is built with the `flash-attn` or
/// `flash-attn-v3` feature, and `false` otherwise.
pub const fn using_flash_attn() -> bool {
    cfg!(any(feature = "flash-attn", feature = "flash-attn-v3"))
}