mistralrs_core/utils/
mod.rs1pub(crate) mod debug;
2pub(crate) mod gguf_metadata;
3pub(crate) mod memory_usage;
4pub(crate) mod model_config;
5pub(crate) mod normal;
6pub(crate) mod progress;
7pub(crate) mod tokenizer;
8pub(crate) mod tokens;
9pub(crate) mod unvarbuilder;
10pub(crate) mod varbuilder_utils;
11
/// Spin until `$thing.try_lock()` succeeds and evaluate to the resulting guard.
///
/// `$thing` is re-evaluated on every attempt. Any `Err` returned by
/// `try_lock()` is treated as "try again", so if the lock type can report a
/// permanent failure this macro never finishes.
///
/// Between failed attempts the macro calls [`std::hint::spin_loop`] so the CPU
/// is told it is inside a busy-wait loop.
#[doc(hidden)]
#[macro_export]
macro_rules! get_mut_arcmutex {
    ($thing:expr) => {
        loop {
            if let Ok(guard) = $thing.try_lock() {
                break guard;
            }
            // Busy-wait hint; does not block or yield to the scheduler.
            ::std::hint::spin_loop();
        }
    };
}
23
/// Unwrap `$fallible`, or report the error through `$response` and bail out.
///
/// * `Ok(value)` — the macro evaluates to `value`.
/// * `Err(err)` — `Response::InternalError(err.into())` is sent on
///   `$response`, the send is awaited, and the enclosing function returns
///   with a bare `return;`.
///
/// Must be expanded inside an async context whose return type is `()`.
///
/// # Panics
///
/// Panics with "Expected receiver." if the send fails.
#[doc(hidden)]
#[macro_export]
macro_rules! handle_seq_error {
    ($fallible:expr, $response:expr) => {
        match $fallible {
            Err(err) => {
                $response
                    .send($crate::response::Response::InternalError(err.into()))
                    .await
                    .expect("Expected receiver.");
                return;
            }
            Ok(value) => value,
        }
    };
}
41
/// Unwrap `$fallible`, or report the error through `$response` and return `Ok(())`.
///
/// * `Ok(value)` — the macro evaluates to `value`.
/// * `Err(err)` — `Response::InternalError(err.into())` is sent on
///   `$response`, the send is awaited, and the enclosing function returns
///   `Ok(())`.
///
/// Must be expanded inside an async context whose return type accepts
/// `Ok(())`.
///
/// # Panics
///
/// Panics with "Expected receiver." if the send fails.
#[doc(hidden)]
#[macro_export]
macro_rules! handle_seq_error_ok {
    ($fallible:expr, $response:expr) => {
        match $fallible {
            Err(err) => {
                $response
                    .send($crate::response::Response::InternalError(err.into()))
                    .await
                    .expect("Expected receiver.");
                return Ok(());
            }
            Ok(value) => value,
        }
    };
}
59
/// Unwrap `$fallible`, or report the error via `$seq` and mark it as errored.
///
/// * `Ok(value)` — the macro evaluates to `value`.
/// * `Err(err)` — `Response::InternalError(err.into())` is sent through
///   `$seq.responder()` and awaited, then `$seq` is moved to
///   `SequenceState::Error`, and the enclosing function returns `Ok(())`.
///
/// Must be expanded inside an async context whose return type accepts
/// `Ok(())`.
///
/// # Panics
///
/// Panics with "Expected receiver." if the send fails; in that case the
/// sequence state is not updated.
#[doc(hidden)]
#[macro_export]
macro_rules! handle_seq_error_stateaware_ok {
    ($fallible:expr, $seq:expr) => {
        match $fallible {
            Err(err) => {
                // Notify first, then flip the state: same order as always.
                $seq.responder()
                    .send($crate::response::Response::InternalError(err.into()))
                    .await
                    .expect("Expected receiver.");
                $seq.set_state($crate::sequence::SequenceState::Error);
                return Ok(());
            }
            Ok(value) => value,
        }
    };
}
79
/// Unwrap the result of a pipeline forward pass, or fail every sequence in
/// the batch and `continue` the labelled loop.
///
/// Arguments:
/// * `$stage` — value printed as the prefix of the error log line.
/// * `$fallible` — the `Result` to inspect; on `Ok(v)` the macro evaluates to `v`.
/// * `$seq_slice` — collection of sequences; iterated with `iter_mut()` three
///   times and passed to `set_none_cache`.
/// * `$pipeline` — lockable pipeline handle (locked via `get_mut_arcmutex!`).
/// * `$label` — loop label that is `continue`d after error handling.
/// * `$prefix_cacher` — lockable handle on which `evict_all_caches()` is called.
///
/// On `Err(e)` the macro, in order:
/// 1. reads the pipeline name and tokenizer under the pipeline lock,
/// 2. logs the error through `tracing::error!`,
/// 3. adds an `"error"`-finished choice to each sequence's group,
/// 4. sends `Response::ModelError` (chat) or `Response::CompletionModelError`
///    (completion) with the partial response to each sequence's responder,
/// 5. sets every sequence to `SequenceState::Error`,
/// 6. calls `set_none_cache` on the pipeline and `evict_all_caches` on the
///    prefix cacher, then `continue $label`.
///
/// `Choice`, `ResponseMessage`, `CompletionChoice`, `ChatCompletionResponse`
/// and `CompletionResponse` are referenced unqualified and must be in scope at
/// the call site. Must be expanded inside an async context.
///
/// # Panics
///
/// Panics if sending a response fails or if `evict_all_caches()` returns an
/// error (both are `unwrap()`ed).
#[doc(hidden)]
#[macro_export]
macro_rules! handle_pipeline_forward_error {
    ($stage: tt, $fallible:expr, $seq_slice:expr, $pipeline:expr, $label:tt, $prefix_cacher:expr) => {
        match $fallible {
            Ok(v) => v,
            Err(e) => {
                // Take what we need from the pipeline in a short scope so the
                // lock guard is dropped before any `.await` below.
                let (tokenizer, pipeline_name) = {
                    // `$crate::` so the helper macro resolves regardless of
                    // what is in scope where this macro is expanded.
                    let pipeline = $crate::get_mut_arcmutex!($pipeline);
                    let pipeline_name = pipeline.name();
                    let tokenizer = pipeline.tokenizer();
                    (tokenizer, pipeline_name)
                };
                use $crate::response::Response;
                use $crate::sequence::SequenceState;
                use $crate::response::SYSTEM_FINGERPRINT;
                use tracing::error;
                error!("{} - Model failed with error: {:?}", $stage, &e);

                // Pass 1: record an "error" choice on every sequence's group.
                for seq in $seq_slice.iter_mut() {
                    // Decode the tokens from index `prompt_tokens()` onward;
                    // fall back to an empty string when there is no tokenizer
                    // or decoding fails.
                    let res = match &tokenizer {
                        Some(tok) => tok
                            .decode(&seq.get_toks()[seq.prompt_tokens()..], false)
                            .unwrap_or_default(),
                        None => String::new(),
                    };

                    if seq.get_mut_group().is_chat {
                        let choice = Choice {
                            finish_reason: "error".to_string(),
                            index: seq.get_response_index(),
                            message: ResponseMessage {
                                content: Some(res),
                                role: "assistant".to_string(),
                                tool_calls: None,
                            },
                            logprobs: None,
                        };
                        seq.add_choice_to_group(choice);
                    } else {
                        let choice = CompletionChoice {
                            finish_reason: "error".to_string(),
                            index: seq.get_response_index(),
                            text: res,
                            logprobs: None,
                        };
                        seq.add_completion_choice_to_group(choice);
                    }
                }

                // Pass 2: runs only after pass 1 has finished for all
                // sequences, so each response is built from the group's
                // choices as they stand at that point.
                for seq in $seq_slice.iter_mut() {
                    let group = seq.get_mut_group();

                    if group.is_chat {
                        let partial_completion_response = ChatCompletionResponse {
                            id: seq.id().to_string(),
                            choices: group.get_choices().to_vec(),
                            created: seq.creation_time(),
                            model: pipeline_name.clone(),
                            system_fingerprint: SYSTEM_FINGERPRINT.to_string(),
                            object: "chat.completion".to_string(),
                            usage: group.get_usage(),
                        };

                        seq.responder()
                            .send(Response::ModelError(
                                e.to_string(),
                                partial_completion_response
                            ))
                            .await
                            .unwrap();
                    } else {
                        let partial_completion_response = CompletionResponse {
                            id: seq.id().to_string(),
                            choices: group.get_completion_choices().to_vec(),
                            created: seq.creation_time(),
                            model: pipeline_name.clone(),
                            system_fingerprint: SYSTEM_FINGERPRINT.to_string(),
                            object: "text_completion".to_string(),
                            usage: group.get_usage(),
                        };

                        seq.responder()
                            .send(Response::CompletionModelError(
                                e.to_string(),
                                partial_completion_response
                            ))
                            .await
                            .unwrap();
                    }
                }

                // Pass 3: mark every sequence as errored.
                for seq in $seq_slice.iter_mut() {
                    seq.set_state(SequenceState::Error);
                }

                // Reset pipeline cache state for these sequences and evict the
                // prefix caches before moving on to the next loop iteration.
                let p = $crate::get_mut_arcmutex!($pipeline);
                p.set_none_cache($seq_slice, true, true, false);
                $crate::get_mut_arcmutex!($prefix_cacher).evict_all_caches().unwrap();

                continue $label;
            }
        }
    };
}
190
/// Spin until `$this.group.try_lock()` succeeds and evaluate to the guard.
///
/// `$this` is re-evaluated on every attempt. Any `Err` returned by
/// `try_lock()` is treated as "try again", so if the lock type can report a
/// permanent failure this macro never finishes.
///
/// Between failed attempts the macro calls [`std::hint::spin_loop`] so the CPU
/// is told it is inside a busy-wait loop.
#[doc(hidden)]
#[macro_export]
macro_rules! get_mut_group {
    ($this:expr) => {
        loop {
            if let Ok(guard) = $this.group.try_lock() {
                break guard;
            }
            // Busy-wait hint; does not block or yield to the scheduler.
            ::std::hint::spin_loop();
        }
    };
}
202
/// Generate a private, zero-argument function `$name` that returns `$v` as a
/// value of type `$t`.
///
/// Invocation shape: `serde_default_fn!(Type, fn_name, value_expr);`
/// The expression is evaluated each time the generated function is called.
#[doc(hidden)]
#[macro_export]
macro_rules! serde_default_fn {
    ($ret_ty:ty, $fn_name:ident, $value:expr) => {
        fn $fn_name() -> $ret_ty {
            $value
        }
    };
}
212
/// Compile-time check for paged attention support in this build.
///
/// Returns `true` when the crate is compiled with the `cuda` feature on a
/// `unix` target family, or with the `metal` feature; otherwise `false`.
pub const fn paged_attn_supported() -> bool {
    cfg!(any(
        all(feature = "cuda", target_family = "unix"),
        feature = "metal"
    ))
}
224
/// Compile-time check for flash attention in this build.
///
/// Returns `true` when the crate is compiled with either the `flash-attn` or
/// the `flash-attn-v3` feature; otherwise `false`.
pub const fn using_flash_attn() -> bool {
    cfg!(any(feature = "flash-attn", feature = "flash-attn-v3"))
}