mistralrs_core/utils/
mod.rs1pub(crate) mod debug;
2pub(crate) mod gguf_metadata;
3pub(crate) mod memory_usage;
4pub(crate) mod model_config;
5pub(crate) mod normal;
6pub(crate) mod progress;
7pub(crate) mod tokenizer;
8pub(crate) mod tokens;
9pub(crate) mod unvarbuilder;
10pub(crate) mod varbuilder_utils;
11
/// Busy-waits until `$thing.try_lock()` succeeds and evaluates to the guard
/// carried in its `Ok` value.
///
/// `$thing` can be any expression whose value exposes a `try_lock()` method
/// returning a `Result`. Every `Err` is treated as "not acquired yet" and the
/// attempt is repeated, so the expansion never yields until the lock is taken.
#[doc(hidden)]
#[macro_export]
macro_rules! get_mut_arcmutex {
    ($thing:expr) => {
        loop {
            if let Ok(inner) = $thing.try_lock() {
                break inner;
            }
            // Signal the CPU that this is a spin-wait so it can relax the
            // core between attempts instead of hammering the lock word.
            ::std::hint::spin_loop();
        }
    };
}
23
/// Unwraps `$fallible`, or reports the error and bails out of the enclosing
/// function.
///
/// * `Ok(v)`  - the macro evaluates to `v`.
/// * `Err(e)` - `e` is converted with `.into()`, wrapped in
///   `Response::InternalError`, sent through `$response` (awaited), and the
///   enclosing function then executes a bare `return`. A failed send is only
///   logged as a warning.
///
/// Must be expanded inside an `async` context, since the send is awaited.
#[doc(hidden)]
#[macro_export]
macro_rules! handle_seq_error {
    ($fallible:expr, $response:expr) => {
        match $fallible {
            Ok(value) => value,
            Err(err) => {
                if $response
                    .send($crate::response::Response::InternalError(err.into()))
                    .await
                    .is_err()
                {
                    tracing::warn!("Receiver disconnected");
                }
                return;
            }
        }
    };
}
40
/// Unwraps `$fallible`, or reports the error and makes the enclosing function
/// return `Ok(())`.
///
/// * `Ok(v)`  - the macro evaluates to `v`.
/// * `Err(e)` - `e` is converted with `.into()`, wrapped in
///   `Response::InternalError`, and sent through `$response` (awaited). If the
///   send fails a warning is logged. Either way the enclosing function then
///   returns `Ok(())`, i.e. the failure is reported over the channel rather
///   than through the function's own return value.
///
/// Must be expanded inside an `async` context, since the send is awaited.
#[doc(hidden)]
#[macro_export]
macro_rules! handle_seq_error_ok {
    ($fallible:expr, $response:expr) => {
        match $fallible {
            Ok(value) => value,
            Err(err) => {
                let delivery = $response
                    .send($crate::response::Response::InternalError(err.into()))
                    .await;
                if delivery.is_err() {
                    tracing::warn!("Receiver disconnected");
                }
                return Ok(());
            }
        }
    };
}
57
/// Unwraps `$fallible`, or reports the error through the sequence's responder,
/// marks the sequence as errored, and makes the enclosing function return
/// `Ok(())`.
///
/// * `Ok(v)`  - the macro evaluates to `v`.
/// * `Err(e)` - `Response::InternalError(e.into())` is sent via
///   `$seq.responder()` (awaited; a failed send only logs a warning), then
///   `$seq.set_state(SequenceState::Error)` is called and the enclosing
///   function returns `Ok(())`.
///
/// Must be expanded inside an `async` context, since the send is awaited.
#[doc(hidden)]
#[macro_export]
macro_rules! handle_seq_error_stateaware_ok {
    ($fallible:expr, $seq:expr) => {
        match $fallible {
            Ok(value) => value,
            Err(err) => {
                let delivery = $seq
                    .responder()
                    .send($crate::response::Response::InternalError(err.into()))
                    .await;
                if delivery.is_err() {
                    tracing::warn!("Receiver disconnected");
                }
                // The state is flipped only after the report has been attempted.
                $seq.set_state($crate::sequence::SequenceState::Error);
                return Ok(());
            }
        }
    };
}
80
/// Unwraps the result of a pipeline forward pass, or fails every sequence in
/// the batch and jumps to the next iteration of the loop labelled `$label`.
///
/// Arguments:
/// * `$stage`         - value printed as the prefix of the error log line.
/// * `$fallible`      - the `Result` to unwrap; `Ok(v)` makes the macro
///   evaluate to `v`.
/// * `$seq_slice`     - the sequences of the failed batch; iterated with
///   `iter_mut()`.
/// * `$pipeline`      - lockable pipeline handle (locked via
///   `get_mut_arcmutex!`).
/// * `$label`         - loop label targeted by the final `continue`.
/// * `$prefix_cacher` - lockable prefix cacher whose caches are evicted.
///
/// On `Err(e)` the expansion, in order:
/// 1. reads the tokenizer and model name from the pipeline and logs the error;
/// 2. decodes whatever each sequence generated past its prompt and records it
///    as a choice with `finish_reason: "error"` on the sequence's group;
/// 3. sends each sequence's responder a `ModelError` / `CompletionModelError`
///    carrying the partial response;
/// 4. marks every sequence as `SequenceState::Error`;
/// 5. resets the pipeline cache and evicts all prefix caches;
/// 6. executes `continue $label`.
///
/// A disconnected receiver in step 3 is logged as a warning (matching the
/// other `handle_seq_error*` macros) instead of panicking, so steps 4-6 still
/// run and the remaining sequences are still notified.
///
/// The call site must have `Choice`, `ResponseMessage`, `CompletionChoice`,
/// `ChatCompletionResponse`, `CompletionResponse` and `get_mut_arcmutex!` in
/// scope, and must be inside an `async` context.
#[doc(hidden)]
#[macro_export]
macro_rules! handle_pipeline_forward_error {
    ($stage: tt, $fallible:expr, $seq_slice:expr, $pipeline:expr, $label:tt, $prefix_cacher:expr) => {
        match $fallible {
            Ok(v) => v,
            Err(e) => {
                use $crate::response::Response;
                use $crate::response::SYSTEM_FINGERPRINT;
                use $crate::sequence::SequenceState;
                use tracing::error;

                // Take what we need from the pipeline inside a nested scope so
                // the lock is released before any sequence work starts.
                let (tokenizer, pipeline_name) = {
                    let pipeline = get_mut_arcmutex!($pipeline);
                    let pipeline_name = pipeline.name();
                    let tokenizer = pipeline.tokenizer();
                    (tokenizer, pipeline_name)
                };
                error!("{} - Model failed with error: {:?}", $stage, &e);

                // Pass 1: record a partial, error-terminated choice for every
                // sequence before any response is assembled.
                for seq in $seq_slice.iter_mut() {
                    // Clamp so the slice below cannot start past the end of
                    // the token buffer.
                    let start = seq.prompt_tokens().min(seq.get_toks().len());
                    // Best effort: no tokenizer or a decode failure yields an
                    // empty partial text.
                    let res = match &tokenizer {
                        Some(tok) => tok
                            .decode(&seq.get_toks()[start..], false)
                            .unwrap_or_default(),
                        None => String::new(),
                    };

                    if seq.get_mut_group().is_chat {
                        let choice = Choice {
                            finish_reason: "error".to_string(),
                            index: seq.get_response_index(),
                            message: ResponseMessage {
                                content: Some(res),
                                role: "assistant".to_string(),
                                tool_calls: None,
                            },
                            logprobs: None,
                        };
                        seq.add_choice_to_group(choice);
                    } else {
                        let choice = CompletionChoice {
                            finish_reason: "error".to_string(),
                            index: seq.get_response_index(),
                            text: res,
                            logprobs: None,
                        };
                        seq.add_completion_choice_to_group(choice);
                    }
                }

                // Pass 2: send each sequence the error together with the
                // partial response built from its group's choices.
                for seq in $seq_slice.iter_mut() {
                    let group = seq.get_mut_group();

                    if group.is_chat {
                        let partial_completion_response = ChatCompletionResponse {
                            id: seq.id().to_string(),
                            choices: group.get_choices().to_vec(),
                            created: seq.creation_time(),
                            model: pipeline_name.clone(),
                            system_fingerprint: SYSTEM_FINGERPRINT.to_string(),
                            object: "chat.completion".to_string(),
                            usage: group.get_usage(),
                        };

                        if seq
                            .responder()
                            .send(Response::ModelError(
                                e.to_string(),
                                partial_completion_response,
                            ))
                            .await
                            .is_err()
                        {
                            // The client went away; keep cleaning up.
                            tracing::warn!("Receiver disconnected");
                        }
                    } else {
                        let partial_completion_response = CompletionResponse {
                            id: seq.id().to_string(),
                            choices: group.get_completion_choices().to_vec(),
                            created: seq.creation_time(),
                            model: pipeline_name.clone(),
                            system_fingerprint: SYSTEM_FINGERPRINT.to_string(),
                            object: "text_completion".to_string(),
                            usage: group.get_usage(),
                        };

                        if seq
                            .responder()
                            .send(Response::CompletionModelError(
                                e.to_string(),
                                partial_completion_response,
                            ))
                            .await
                            .is_err()
                        {
                            // The client went away; keep cleaning up.
                            tracing::warn!("Receiver disconnected");
                        }
                    }
                }

                // Pass 3: flag every sequence in the batch as errored.
                for seq in $seq_slice.iter_mut() {
                    seq.set_state(SequenceState::Error);
                }

                // Reset the pipeline cache for this batch, then evict every
                // prefix cache. A failed eviction still panics, as before.
                let p = get_mut_arcmutex!($pipeline);
                p.set_none_cache($seq_slice, true, true, false);
                get_mut_arcmutex!($prefix_cacher).evict_all_caches().unwrap();

                continue $label;
            }
        }
    };
}
191
/// Busy-waits until `$this.group.try_lock()` succeeds and evaluates to the
/// guard carried in its `Ok` value.
///
/// `$this` can be any expression with a `group` field whose value exposes a
/// `try_lock()` method returning a `Result`. Every `Err` is treated as "not
/// acquired yet" and the attempt is repeated.
#[doc(hidden)]
#[macro_export]
macro_rules! get_mut_group {
    ($this:expr) => {
        loop {
            if let Ok(inner) = $this.group.try_lock() {
                break inner;
            }
            // Signal the CPU that this is a spin-wait so it can relax the
            // core between attempts instead of hammering the lock word.
            ::std::hint::spin_loop();
        }
    };
}
203
/// Generates a private, zero-argument function named `$name` that returns the
/// expression `$v` typed as `$t`.
///
/// Invocation shape: `serde_default_fn!(ReturnType, fn_name, value_expr);`.
/// The expression is evaluated afresh on every call of the generated function.
#[doc(hidden)]
#[macro_export]
macro_rules! serde_default_fn {
    ($ret:ty, $fn_name:ident, $value:expr) => {
        fn $fn_name() -> $ret {
            $value
        }
    };
}
213
/// Compile-time flag for paged-attention support in this build.
///
/// Returns `true` when the crate is compiled with the `cuda` feature on a
/// Unix-family target, or with the `metal` feature; `false` for every other
/// configuration.
pub const fn paged_attn_supported() -> bool {
    cfg!(any(
        all(feature = "cuda", target_family = "unix"),
        feature = "metal"
    ))
}
225
/// Compile-time flag for flash-attention usage in this build.
///
/// Returns `true` when the crate is compiled with either the `flash-attn` or
/// the `flash-attn-v3` feature, and `false` when neither is enabled.
pub const fn using_flash_attn() -> bool {
    cfg!(any(feature = "flash-attn", feature = "flash-attn-v3"))
}