mistralrs_core/utils/
mod.rs1pub(crate) mod debug;
2pub(crate) mod gguf_metadata;
3pub(crate) mod memory_usage;
4pub(crate) mod model_config;
5pub(crate) mod normal;
6pub(crate) mod progress;
7pub(crate) mod tiktoken;
8pub(crate) mod tokenizer;
9pub(crate) mod tokens;
10pub(crate) mod unvarbuilder;
11pub(crate) mod varbuilder_utils;
12
/// Acquires the lock behind `$thing` by retrying `try_lock()` until it succeeds, and
/// evaluates to the resulting guard.
///
/// Between failed attempts the current OS thread yields via `std::thread::yield_now()`, so
/// this is a busy-wait that blocks the calling thread (it never `.await`s) and never gives up.
/// `$thing` is re-evaluated on every attempt.
#[doc(hidden)]
#[macro_export]
macro_rules! get_mut_arcmutex {
    ($thing:expr) => {
        loop {
            match $thing.try_lock() {
                Ok(guard) => break guard,
                // Lock is busy: give other threads a chance to release it, then retry.
                Err(_) => std::thread::yield_now(),
            }
        }
    };
}
28
/// Unwraps `$fallible`, or reports its error through `$response` and returns from the
/// enclosing function.
///
/// On `Err(e)`:
/// - `e` is converted with `.into()` and sent as `Response::InternalError`.
/// - If that send fails, "Receiver disconnected" is logged at warn level.
/// - The enclosing function then executes a bare `return`.
///
/// Because the send is `.await`ed, this must be used inside an `async` function returning `()`.
#[doc(hidden)]
#[macro_export]
macro_rules! handle_seq_error {
    ($fallible:expr, $response:expr) => {
        match $fallible {
            Ok(value) => value,
            Err(err) => {
                use $crate::response::Response;
                if $response
                    .send(Response::InternalError(err.into()))
                    .await
                    .is_err()
                {
                    tracing::warn!("Receiver disconnected");
                }
                return;
            }
        }
    };
}
45
/// Variant of `handle_seq_error!` for functions whose return type accepts `Ok(())`.
///
/// Evaluates to the `Ok` value of `$fallible`. On `Err(e)`:
/// - `e` is converted with `.into()` and sent through `$response` as `Response::InternalError`.
/// - A failed send is logged as "Receiver disconnected".
/// - The enclosing (`async`) function returns `Ok(())`.
#[doc(hidden)]
#[macro_export]
macro_rules! handle_seq_error_ok {
    ($fallible:expr, $response:expr) => {
        match $fallible {
            Ok(value) => value,
            Err(err) => {
                use $crate::response::Response;
                let receiver_gone = $response
                    .send(Response::InternalError(err.into()))
                    .await
                    .is_err();
                if receiver_gone {
                    tracing::warn!("Receiver disconnected");
                }
                return Ok(());
            }
        }
    };
}
62
/// Unwraps `$fallible`; on error, reports it to the sequence and marks the sequence as failed.
///
/// On `Err(e)`:
/// - `e` is converted with `.into()` and sent as `Response::InternalError` through
///   `$seq.responder()`. A failed send is logged as "Receiver disconnected".
/// - `$seq.set_state(SequenceState::Error)` is called.
/// - The enclosing (`async`) function returns `Ok(())`.
///
/// `$seq` is evaluated twice: once for the send and once for the state update.
#[doc(hidden)]
#[macro_export]
macro_rules! handle_seq_error_stateaware_ok {
    ($fallible:expr, $seq:expr) => {
        match $fallible {
            Ok(value) => value,
            Err(err) => {
                use $crate::response::Response;
                use $crate::sequence::SequenceState;
                if $seq
                    .responder()
                    .send(Response::InternalError(err.into()))
                    .await
                    .is_err()
                {
                    tracing::warn!("Receiver disconnected");
                }
                $seq.set_state(SequenceState::Error);
                return Ok(());
            }
        }
    };
}
85
/// Unwraps the result of a batched pipeline step, failing the whole batch on error.
///
/// On `Ok(v)` the macro evaluates to `v`. On `Err(e)` it:
/// 1. logs `e` at error level, prefixed with `$stage`;
/// 2. for every sequence, decodes its tokens from `prompt_tokens()` onward and adds that text
///    to the sequence's group as a choice with `finish_reason: "error"`. The text is an empty
///    string if the pipeline has no tokenizer or decoding fails;
/// 3. sends every sequence a `Response::ModelError` (chat) or `Response::CompletionModelError`
///    (completion) carrying a partial response built from its group. A failed send is logged
///    as "Receiver disconnected" instead of panicking, so cleanup still runs for the batch;
/// 4. sets every sequence to `SequenceState::Error`;
/// 5. calls `set_none_cache($seq_slice, true, true, false)` on the pipeline, then evicts all
///    prefix caches (panicking if eviction returns `Err`);
/// 6. runs `continue $label`.
///
/// It must therefore be invoked in an `async` context, inside a loop labelled `$label`.
///
/// NOTE: `Choice`, `ResponseMessage`, `CompletionChoice`, `ChatCompletionResponse` and
/// `CompletionResponse` are not imported by the macro and must be in scope at the call site.
#[doc(hidden)]
#[macro_export]
macro_rules! handle_pipeline_forward_error {
    ($stage: tt, $fallible:expr, $seq_slice:expr, $pipeline:expr, $label:tt, $prefix_cacher:expr) => {
        match $fallible {
            Ok(v) => v,
            Err(e) => {
                use $crate::response::Response;
                use $crate::response::SYSTEM_FINGERPRINT;
                use $crate::sequence::SequenceState;
                use tracing::error;

                // Copy what we need out of the pipeline and release its lock right away.
                // The `$crate::` path keeps the helper macro resolvable when this macro is
                // invoked from outside the defining crate.
                let (tokenizer, pipeline_name) = {
                    let pipeline = $crate::get_mut_arcmutex!($pipeline);
                    let pipeline_name = pipeline.name();
                    let tokenizer = pipeline.tokenizer();
                    (tokenizer, pipeline_name)
                };

                error!("{} - Model failed with error: {:?}", $stage, &e);
                let err_msg = e.to_string();

                // Pass 1: record each sequence's partial output on its group. This finishes
                // for the whole batch before anything is sent, because each response below
                // reports every choice collected in the group. Presumably a group can hold
                // several sequences of this batch.
                for seq in $seq_slice.iter_mut() {
                    let start = seq.prompt_tokens().min(seq.get_toks().len());
                    let res = tokenizer
                        .as_ref()
                        .and_then(|tok| tok.decode(&seq.get_toks()[start..], false).ok())
                        .unwrap_or_default();

                    if seq.get_mut_group().is_chat {
                        let choice = Choice {
                            finish_reason: "error".to_string(),
                            index: seq.get_response_index(),
                            message: ResponseMessage {
                                content: Some(res),
                                role: "assistant".to_string(),
                                tool_calls: None,
                                reasoning_content: None,
                            },
                            logprobs: None,
                        };
                        seq.add_choice_to_group(choice);
                    } else {
                        let choice = CompletionChoice {
                            finish_reason: "error".to_string(),
                            index: seq.get_response_index(),
                            text: res,
                            logprobs: None,
                        };
                        seq.add_completion_choice_to_group(choice);
                    }
                }

                // Pass 2: report the failure to every sequence.
                for seq in $seq_slice.iter_mut() {
                    // Build the response under the group lock, but drop the guard before the
                    // `.await` below so it is never held across a suspension point.
                    let response = {
                        let group = seq.get_mut_group();
                        if group.is_chat {
                            Response::ModelError(
                                err_msg.clone(),
                                ChatCompletionResponse {
                                    id: seq.id().to_string(),
                                    choices: group.get_choices().to_vec(),
                                    created: seq.creation_time(),
                                    model: pipeline_name.clone(),
                                    system_fingerprint: SYSTEM_FINGERPRINT.to_string(),
                                    object: "chat.completion".to_string(),
                                    usage: group.get_usage(),
                                },
                            )
                        } else {
                            Response::CompletionModelError(
                                err_msg.clone(),
                                CompletionResponse {
                                    id: seq.id().to_string(),
                                    choices: group.get_completion_choices().to_vec(),
                                    created: seq.creation_time(),
                                    model: pipeline_name.clone(),
                                    system_fingerprint: SYSTEM_FINGERPRINT.to_string(),
                                    object: "text_completion".to_string(),
                                    usage: group.get_usage(),
                                },
                            )
                        }
                    };

                    // A closed receiver must not abort cleanup for the rest of the batch.
                    if seq.responder().send(response).await.is_err() {
                        tracing::warn!("Receiver disconnected");
                    }
                }

                for seq in $seq_slice.iter_mut() {
                    seq.set_state(SequenceState::Error);
                }

                // The pipeline guard is a temporary released at the end of this statement, so
                // this macro never holds the pipeline and prefix-cacher locks at the same time.
                $crate::get_mut_arcmutex!($pipeline).set_none_cache($seq_slice, true, true, false);
                $crate::get_mut_arcmutex!($prefix_cacher).evict_all_caches().unwrap();

                continue $label;
            }
        }
    };
}
197
/// Acquires the lock on the `group` field of `$this` by retrying `try_lock()` until it
/// succeeds, and evaluates to the resulting guard.
///
/// The current OS thread yields between failed attempts. This is a blocking busy-wait with
/// no timeout.
#[doc(hidden)]
#[macro_export]
macro_rules! get_mut_group {
    ($this:expr) => {
        loop {
            match $this.group.try_lock() {
                Ok(guard) => break guard,
                // Group is locked elsewhere: yield, then try again.
                Err(_) => std::thread::yield_now(),
            }
        }
    };
}
211
/// Defines a private zero-argument function `$fn_name` that returns `$value` as a `$ty`.
///
/// The value expression is evaluated on every call. Such functions can be referenced from
/// serde's `#[serde(default = "...")]` attribute.
#[doc(hidden)]
#[macro_export]
macro_rules! serde_default_fn {
    ($ty:ty, $fn_name:ident, $value:expr) => {
        fn $fn_name() -> $ty {
            $value
        }
    };
}
221
/// Reports whether paged attention was compiled in.
///
/// Returns `true` when built with the `metal` feature, or with the `cuda` feature on a Unix
/// target family. Otherwise it returns `false`. The answer is fixed at compile time, so the
/// function is usable in `const` contexts.
pub const fn paged_attn_supported() -> bool {
    cfg!(any(all(feature = "cuda", target_family = "unix"), feature = "metal"))
}
233
/// Reports whether a flash-attention backend was compiled in.
///
/// Returns `true` when built with either the `flash-attn` or the `flash-attn-v3` feature.
/// The answer is fixed at compile time, so the function is usable in `const` contexts.
pub const fn using_flash_attn() -> bool {
    cfg!(any(feature = "flash-attn", feature = "flash-attn-v3"))
}