mistralrs_core/utils/
mod.rs1pub(crate) mod debug;
2pub(crate) mod gguf_metadata;
3pub(crate) mod log;
4pub(crate) mod memory_usage;
5pub(crate) mod model_config;
6pub(crate) mod normal;
7pub(crate) mod progress;
8pub(crate) mod tokenizer;
9pub(crate) mod tokens;
10pub(crate) mod unvarbuilder;
11pub(crate) mod varbuilder_utils;
12
/// Busy-waits until `$thing.try_lock()` succeeds and evaluates to the value it yields.
///
/// `$thing` only needs a `try_lock()` method returning a `Result`. Any `Err` is
/// treated as "not available yet" and retried, so this expression does not complete
/// for as long as `try_lock()` keeps failing. `$thing` is re-evaluated on every attempt.
#[doc(hidden)]
#[macro_export]
macro_rules! get_mut_arcmutex {
    ($thing:expr) => {
        loop {
            match $thing.try_lock() {
                Ok(guard) => break guard,
                // Signal the CPU that this is a spin-wait so it can relax the core
                // (e.g. emit a pause instruction) between attempts.
                Err(_) => ::std::hint::spin_loop(),
            }
        }
    };
}
24
/// Unwraps `$fallible`, reporting a failure through `$response` and bailing out.
///
/// * `Ok(value)` makes the whole expression evaluate to `value`.
/// * `Err(err)` sends `Response::InternalError(err.into())` via `$response.send(..)`,
///   awaits the send, panics with "Expected receiver." if the send fails, and then
///   executes a bare `return;`.
///
/// Because of the `.await` and the bare `return`, this must be expanded inside an
/// async body that returns `()`.
#[doc(hidden)]
#[macro_export]
macro_rules! handle_seq_error {
    ($fallible:expr, $response:expr) => {
        match $fallible {
            Err(err) => {
                $response
                    .send($crate::response::Response::InternalError(err.into()))
                    .await
                    .expect("Expected receiver.");
                return;
            }
            Ok(value) => value,
        }
    };
}
42
/// Unwraps `$fallible`, reporting a failure through `$response` and returning `Ok(())`.
///
/// * `Ok(value)` makes the whole expression evaluate to `value`.
/// * `Err(err)` sends `Response::InternalError(err.into())` via `$response.send(..)`,
///   awaits the send, panics with "Expected receiver." if the send fails, and then
///   executes `return Ok(());`.
///
/// Because of the `.await` and the `return Ok(())`, this must be expanded inside an
/// async body whose return type accepts `Ok(())`.
#[doc(hidden)]
#[macro_export]
macro_rules! handle_seq_error_ok {
    ($fallible:expr, $response:expr) => {
        match $fallible {
            Err(err) => {
                $response
                    .send($crate::response::Response::InternalError(err.into()))
                    .await
                    .expect("Expected receiver.");
                return Ok(());
            }
            Ok(value) => value,
        }
    };
}
60
/// Unwraps `$fallible`; on failure notifies the sequence's responder, flags the
/// sequence as errored and returns `Ok(())`.
///
/// * `Ok(value)` makes the whole expression evaluate to `value`.
/// * `Err(err)` sends `Response::InternalError(err.into())` through
///   `$seq.responder()`, awaits the send (panicking with "Expected receiver." if it
///   fails), calls `$seq.set_state(SequenceState::Error)` and then executes
///   `return Ok(());`.
///
/// Because of the `.await` and the `return Ok(())`, this must be expanded inside an
/// async body whose return type accepts `Ok(())`.
#[doc(hidden)]
#[macro_export]
macro_rules! handle_seq_error_stateaware_ok {
    ($fallible:expr, $seq:expr) => {
        match $fallible {
            Err(err) => {
                $seq.responder()
                    .send($crate::response::Response::InternalError(err.into()))
                    .await
                    .expect("Expected receiver.");
                $seq.set_state($crate::sequence::SequenceState::Error);
                return Ok(());
            }
            Ok(value) => value,
        }
    };
}
80
/// Unwraps the result of a pipeline forward step; on failure reports the error to
/// every sequence in the batch, resets cache state and continues the labelled loop.
///
/// Arguments:
/// * `$stage` - value printed (via `{}`) at the start of the error log line.
/// * `$fallible` - the `Result` to unwrap; `Ok(v)` makes the expression evaluate to `v`.
/// * `$seq_slice` - the batch; iterated with `.iter_mut()` and passed to
///   `set_none_cache`.
/// * `$pipeline` - lockable (see `get_mut_arcmutex!`) value exposing `name()`,
///   `tokenizer()` and `set_none_cache(..)`.
/// * `$label` - label of the enclosing loop that is `continue`d after handling.
/// * `$prefix_cacher` - lockable value exposing `evict_all_to_cpu()`.
///
/// On `Err(e)` the macro:
/// 1. logs the error,
/// 2. adds an `"error"`-finished choice to each sequence's group, containing whatever
///    text decodes from the tokens after the prompt (empty if there is no tokenizer
///    or decoding fails),
/// 3. sends each sequence a `Response::ModelError` / `Response::CompletionModelError`
///    carrying the partial response (panics if a send fails),
/// 4. marks every sequence as `SequenceState::Error`,
/// 5. calls `set_none_cache($seq_slice, true, true, false)` on the pipeline and
///    `evict_all_to_cpu()` on the prefix cacher (panics if eviction fails),
/// 6. executes `continue $label`.
///
/// The call site must be an async context and must have `Choice`, `ResponseMessage`,
/// `CompletionChoice`, `ChatCompletionResponse`, `CompletionResponse` and the
/// `tracing` crate in scope, since they are referenced unqualified here.
#[doc(hidden)]
#[macro_export]
macro_rules! handle_pipeline_forward_error {
    ($stage:tt, $fallible:expr, $seq_slice:expr, $pipeline:expr, $label:tt, $prefix_cacher:expr) => {
        match $fallible {
            Ok(v) => v,
            Err(e) => {
                // Grab what we need from the pipeline in a short scope so the lock is
                // released before any `.await` below.
                let (tokenizer, pipeline_name) = {
                    let pipeline = $crate::get_mut_arcmutex!($pipeline);
                    let pipeline_name = pipeline.name();
                    let tokenizer = pipeline.tokenizer();
                    (tokenizer, pipeline_name)
                };
                use $crate::response::Response;
                use $crate::response::SYSTEM_FINGERPRINT;
                use $crate::sequence::SequenceState;
                use tracing::error;
                error!("{} - Model failed with error: {:?}", $stage, &e);

                // Step 1: add a choice for every sequence before responding, so each
                // response built in step 2 sees all choices of its group.
                for seq in $seq_slice.iter_mut() {
                    // Clamp the start index: slicing past the end would panic while we
                    // are already handling an error.
                    let start = seq.prompt_tokens().min(seq.get_toks().len());
                    let res = match &tokenizer {
                        Some(tok) => tok
                            .decode(&seq.get_toks()[start..], false)
                            .unwrap_or_default(),
                        None => String::new(),
                    };

                    if seq.get_mut_group().is_chat {
                        let choice = Choice {
                            finish_reason: "error".to_string(),
                            index: seq.get_response_index(),
                            message: ResponseMessage {
                                content: Some(res),
                                role: "assistant".to_string(),
                                tool_calls: None,
                            },
                            logprobs: None,
                        };
                        seq.add_choice_to_group(choice);
                    } else {
                        let choice = CompletionChoice {
                            finish_reason: "error".to_string(),
                            index: seq.get_response_index(),
                            text: res,
                            logprobs: None,
                        };
                        seq.add_completion_choice_to_group(choice);
                    }
                }

                // Step 2: send the error together with the partial response.
                for seq in $seq_slice.iter_mut() {
                    let group = seq.get_mut_group();

                    if group.is_chat {
                        let partial_completion_response = ChatCompletionResponse {
                            id: seq.id().to_string(),
                            choices: group.get_choices().to_vec(),
                            created: seq.creation_time(),
                            model: pipeline_name.clone(),
                            system_fingerprint: SYSTEM_FINGERPRINT.to_string(),
                            object: "chat.completion".to_string(),
                            usage: group.get_usage(),
                        };

                        seq.responder()
                            .send(Response::ModelError(
                                e.to_string(),
                                partial_completion_response,
                            ))
                            .await
                            .unwrap();
                    } else {
                        let partial_completion_response = CompletionResponse {
                            id: seq.id().to_string(),
                            choices: group.get_completion_choices().to_vec(),
                            created: seq.creation_time(),
                            model: pipeline_name.clone(),
                            system_fingerprint: SYSTEM_FINGERPRINT.to_string(),
                            object: "text_completion".to_string(),
                            usage: group.get_usage(),
                        };

                        seq.responder()
                            .send(Response::CompletionModelError(
                                e.to_string(),
                                partial_completion_response,
                            ))
                            .await
                            .unwrap();
                    }
                }

                // Step 3: flag every sequence in the batch as errored.
                for seq in $seq_slice.iter_mut() {
                    seq.set_state(SequenceState::Error);
                }

                // Step 4: reset pipeline cache state and evict the prefix cache.
                let p = $crate::get_mut_arcmutex!($pipeline);
                p.set_none_cache($seq_slice, true, true, false);
                $crate::get_mut_arcmutex!($prefix_cacher)
                    .evict_all_to_cpu()
                    .unwrap();

                continue $label;
            }
        }
    };
}
191
/// Busy-waits until `$this.group.try_lock()` succeeds and evaluates to the value it
/// yields.
///
/// Any `Err` from `try_lock()` is treated as "not available yet" and retried, so this
/// expression does not complete for as long as `try_lock()` keeps failing. `$this` is
/// re-evaluated on every attempt.
#[doc(hidden)]
#[macro_export]
macro_rules! get_mut_group {
    ($this:expr) => {
        loop {
            match $this.group.try_lock() {
                Ok(guard) => break guard,
                // Signal the CPU that this is a spin-wait so it can relax the core
                // (e.g. emit a pause instruction) between attempts.
                Err(_) => ::std::hint::spin_loop(),
            }
        }
    };
}
203
/// Generates a private, zero-argument function that returns a fixed expression.
///
/// `serde_default_fn!(Type, name, expr)` expands to `fn name() -> Type { expr }`.
/// The expression is evaluated anew on every call of the generated function.
#[doc(hidden)]
#[macro_export]
macro_rules! serde_default_fn {
    ($ret:ty, $fn_name:ident, $value:expr) => {
        fn $fn_name() -> $ret {
            $value
        }
    };
}
213
/// Compile-time flag: `true` when built with the `metal` feature, or with the `cuda`
/// feature on a `unix`-family target; `false` for every other configuration.
pub const fn paged_attn_supported() -> bool {
    cfg!(any(
        all(feature = "cuda", target_family = "unix"),
        feature = "metal"
    ))
}
225
/// Compile-time flag: `true` when built with either the `flash-attn` or the
/// `flash-attn-v3` feature; `false` otherwise.
pub const fn using_flash_attn() -> bool {
    cfg!(any(feature = "flash-attn", feature = "flash-attn-v3"))
}