mistralrs_core/utils/mod.rs

pub(crate) mod debug;
pub(crate) mod gguf_metadata;
pub(crate) mod memory_usage;
pub(crate) mod model_config;
pub(crate) mod normal;
pub(crate) mod progress;
pub(crate) mod tiktoken;
pub(crate) mod tokenizer;
pub(crate) mod tokens;
pub(crate) mod unvarbuilder;
pub(crate) mod varbuilder_utils;

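/// Spin on `try_lock` until the mutex behind an `Arc<Mutex<..>>` can be acquired,
/// yielding the thread between attempts, and evaluate to the acquired guard.
/// Used below, e.g. `get_mut_arcmutex!($pipeline)` in `handle_pipeline_forward_error!`.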
#[doc(hidden)]
#[macro_export]
macro_rules! get_mut_arcmutex {
    ($thing:expr) => {
        loop {
            if let Ok(inner) = $thing.try_lock() {
                break inner;
            }
            std::thread::yield_now();
        }
    };
}

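/// Unwrap a fallible expression inside an async request handler: on `Err`, send
/// `Response::InternalError` over `$response` (logging a warning if the receiver
/// has disconnected) and `return` from the enclosing function.
///
/// Hypothetical call site (names are illustrative, not from this file):
/// `let prompt = handle_seq_error!(tokenizer.encode(text, true), responder);`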
#[doc(hidden)]
#[macro_export]
macro_rules! handle_seq_error {
    ($fallible:expr, $response:expr) => {
        match $fallible {
            Ok(v) => v,
            Err(e) => {
                use $crate::response::Response;
                if $response.send(Response::InternalError(e.into())).await.is_err() {
                    tracing::warn!("Receiver disconnected");
                }
                return;
            }
        }
    };
}

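/// Like `handle_seq_error!`, but for enclosing functions that return a `Result`:
/// the error path reports the internal error and then ends with `return Ok(())`.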
#[doc(hidden)]
#[macro_export]
macro_rules! handle_seq_error_ok {
    ($fallible:expr, $response:expr) => {
        match $fallible {
            Ok(v) => v,
            Err(e) => {
                use $crate::response::Response;
                if $response.send(Response::InternalError(e.into())).await.is_err() {
                    tracing::warn!("Receiver disconnected");
                }
                return Ok(());
            }
        }
    };
}

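/// Like `handle_seq_error_ok!`, but takes the sequence itself: the error is sent
/// via `$seq.responder()` and the sequence is moved into `SequenceState::Error`
/// before returning `Ok(())`.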
#[doc(hidden)]
#[macro_export]
macro_rules! handle_seq_error_stateaware_ok {
    ($fallible:expr, $seq:expr) => {
        match $fallible {
            Ok(v) => v,
            Err(e) => {
                use $crate::response::Response;
                use $crate::sequence::SequenceState;
                if $seq
                    .responder()
                    .send(Response::InternalError(e.into()))
                    .await
                    .is_err()
                {
                    tracing::warn!("Receiver disconnected");
                }
                $seq.set_state(SequenceState::Error);
                return Ok(());
            }
        }
    };
}

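/// Handle a failed pipeline forward pass inside the engine loop. On `Err`, the
/// error is logged, the tokens each sequence has generated so far (after the
/// prompt) are decoded and sent back as a partial `ModelError` /
/// `CompletionModelError` response, every sequence is marked
/// `SequenceState::Error`, the pipeline cache is reset and the prefix cache is
/// evicted, and the labeled engine loop `$label` is continued.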
#[doc(hidden)]
#[macro_export]
macro_rules! handle_pipeline_forward_error {
    ($stage: tt, $fallible:expr, $seq_slice:expr, $pipeline:expr, $label:tt, $prefix_cacher:expr) => {
        match $fallible {
            Ok(v) => v,
            Err(e) => {
                let (tokenizer, pipeline_name) = {
                    let pipeline = get_mut_arcmutex!($pipeline);
                    let pipeline_name = pipeline.name();
                    let tokenizer = pipeline.tokenizer();
                    (tokenizer, pipeline_name)
                };
                use $crate::response::Response;
                use $crate::sequence::SequenceState;
                use $crate::response::SYSTEM_FINGERPRINT;
                use tracing::error;
                error!("{} - Model failed with error: {:?}", $stage, &e);
                for seq in $seq_slice.iter_mut() {
                    let start = seq.prompt_tokens().min(seq.get_toks().len());
                    let res = match &tokenizer {
                        Some(tok) => match tok.decode(&seq.get_toks()[start..], false) {
                            Ok(t) => t,
                            Err(_) => "".to_string(),
                        },
                        None => "".to_string(),
                    };

                    if seq.get_mut_group().is_chat {
                        let choice = Choice {
                            finish_reason: "error".to_string(),
                            index: seq.get_response_index(),
                            message: ResponseMessage {
                                content: Some(res),
                                role: "assistant".to_string(),
                                tool_calls: None,
                            },
                            logprobs: None,
                        };
                        seq.add_choice_to_group(choice);
                    } else {
                        let choice = CompletionChoice {
                            finish_reason: "error".to_string(),
                            index: seq.get_response_index(),
                            text: res,
                            logprobs: None,
                        };
                        seq.add_completion_choice_to_group(choice);
                    }
                }
                for seq in $seq_slice.iter_mut() {
                    let group = seq.get_mut_group();

                    if group.is_chat {
                        let partial_completion_response = ChatCompletionResponse {
                            id: seq.id().to_string(),
                            choices: group.get_choices().to_vec(),
                            created: seq.creation_time(),
                            model: pipeline_name.clone(),
                            system_fingerprint: SYSTEM_FINGERPRINT.to_string(),
                            object: "chat.completion".to_string(),
                            usage: group.get_usage(),
                        };

                        seq.responder()
                            .send(Response::ModelError(
                                e.to_string(),
                                partial_completion_response,
                            ))
                            .await
                            .unwrap();
                    } else {
                        let partial_completion_response = CompletionResponse {
                            id: seq.id().to_string(),
                            choices: group.get_completion_choices().to_vec(),
                            created: seq.creation_time(),
                            model: pipeline_name.clone(),
                            system_fingerprint: SYSTEM_FINGERPRINT.to_string(),
                            object: "text_completion".to_string(),
                            usage: group.get_usage(),
                        };

                        seq.responder()
                            .send(Response::CompletionModelError(
                                e.to_string(),
                                partial_completion_response,
                            ))
                            .await
                            .unwrap();
                    }
                }
                for seq in $seq_slice.iter_mut() {
                    seq.set_state(SequenceState::Error);
                }

                let p = get_mut_arcmutex!($pipeline);
                p.set_none_cache($seq_slice, true, true, false);
                get_mut_arcmutex!($prefix_cacher).evict_all_caches().unwrap();

                continue $label;
            }
        }
    };
}

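/// Spin on `try_lock` until the sequence's `group` mutex can be acquired,
/// yielding the thread between attempts, and evaluate to the acquired guard.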
#[doc(hidden)]
#[macro_export]
macro_rules! get_mut_group {
    ($this:expr) => {
        loop {
            if let Ok(inner) = $this.group.try_lock() {
                break inner;
            }
            std::thread::yield_now();
        }
    };
}

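/// Define a zero-argument function returning a constant value, intended for use
/// with `#[serde(default = "...")]`.
///
/// Hypothetical usage (not from this file):
/// `serde_default_fn!(f64, default_temperature, 0.7);` expands to
/// `fn default_temperature() -> f64 { 0.7 }`, which can then be referenced as
/// `#[serde(default = "default_temperature")]`.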
#[doc(hidden)]
#[macro_export]
macro_rules! serde_default_fn {
    ($t:ty, $name:ident, $v:expr) => {
        fn $name() -> $t {
            $v
        }
    };
}

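/// Whether this build supports PagedAttention: true when compiled with the
/// `cuda` feature on a Unix target, or with the `metal` feature.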
#[cfg(any(all(feature = "cuda", target_family = "unix"), feature = "metal"))]
pub const fn paged_attn_supported() -> bool {
    true
}

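/// Fallback: PagedAttention is not available in this build configuration
/// (no CUDA-on-Unix and no Metal).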
#[cfg(not(any(all(feature = "cuda", target_family = "unix"), feature = "metal")))]
pub const fn paged_attn_supported() -> bool {
    false
}

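/// Fallback: FlashAttention is not compiled in (neither `flash-attn` nor
/// `flash-attn-v3` is enabled).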
#[cfg(not(any(feature = "flash-attn", feature = "flash-attn-v3")))]
pub const fn using_flash_attn() -> bool {
    false
}

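/// Whether this build uses FlashAttention: true when compiled with the
/// `flash-attn` or `flash-attn-v3` feature.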
#[cfg(any(feature = "flash-attn", feature = "flash-attn-v3"))]
pub const fn using_flash_attn() -> bool {
    true
}