mistralrs_core/utils/mod.rs

pub(crate) mod debug;
pub(crate) mod gguf_metadata;
pub(crate) mod memory_usage;
pub(crate) mod model_config;
pub(crate) mod normal;
pub(crate) mod progress;
pub(crate) mod tiktoken;
pub(crate) mod tokenizer;
pub(crate) mod tokens;
pub(crate) mod unvarbuilder;
pub(crate) mod varbuilder_utils;

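// Spin on `try_lock` until the mutex behind the `Arc` can be acquired, then break out of the
// loop with the guard. A hedged usage sketch (`self.pipeline` is an illustrative name, not
// one defined in this module):
//
//     let pipeline = get_mut_arcmutex!(self.pipeline);
//     let name = pipeline.name();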
#[doc(hidden)]
#[macro_export]
macro_rules! get_mut_arcmutex {
    ($thing:expr) => {
        loop {
            if let Ok(inner) = $thing.try_lock() {
                break inner;
            }
        }
    };
}

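// Unwrap a fallible value; on `Err`, report `Response::InternalError` over the response
// channel and `return` from the surrounding unit-returning async fn. A hedged sketch
// (`run_step()` and `response_tx` are hypothetical names):
//
//     let logits = handle_seq_error!(run_step(), response_tx);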
#[doc(hidden)]
#[macro_export]
macro_rules! handle_seq_error {
    ($fallible:expr, $response:expr) => {
        match $fallible {
            Ok(v) => v,
            Err(e) => {
                use $crate::response::Response;
                if let Err(_) = $response.send(Response::InternalError(e.into())).await {
                    tracing::warn!("Receiver disconnected");
                }
                return;
            }
        }
    };
}

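// Same shape as `handle_seq_error!`, but for async fns returning `Result<(), _>`: on
// failure it reports the error and returns `Ok(())` so the caller's loop can keep running.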
#[doc(hidden)]
#[macro_export]
macro_rules! handle_seq_error_ok {
    ($fallible:expr, $response:expr) => {
        match $fallible {
            Ok(v) => v,
            Err(e) => {
                use $crate::response::Response;
                if let Err(_) = $response.send(Response::InternalError(e.into())).await {
                    tracing::warn!("Receiver disconnected");
                }
                return Ok(());
            }
        }
    };
}

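// Variant of `handle_seq_error_ok!` that takes a sequence instead of a raw sender: the
// error is sent through `$seq.responder()` and the sequence is also moved into
// `SequenceState::Error` before returning `Ok(())`.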
#[doc(hidden)]
#[macro_export]
macro_rules! handle_seq_error_stateaware_ok {
    ($fallible:expr, $seq:expr) => {
        match $fallible {
            Ok(v) => v,
            Err(e) => {
                use $crate::response::Response;
                use $crate::sequence::SequenceState;
                if let Err(_) = $seq
                    .responder()
                    .send(Response::InternalError(e.into()))
                    .await
                {
                    tracing::warn!("Receiver disconnected");
                }
                $seq.set_state(SequenceState::Error);
                return Ok(());
            }
        }
    };
}

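// Error path for a pipeline forward/step inside a labeled scheduler loop. On failure it
// decodes whatever completion tokens each sequence produced so far, packages them as
// partial chat or text-completion responses with finish_reason "error", sends
// `Response::ModelError` / `Response::CompletionModelError`, marks every sequence as
// `SequenceState::Error`, resets the pipeline cache and the prefix cache, and finally
// `continue`s the labeled loop `$label`.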
#[doc(hidden)]
#[macro_export]
macro_rules! handle_pipeline_forward_error {
    ($stage:tt, $fallible:expr, $seq_slice:expr, $pipeline:expr, $label:tt, $prefix_cacher:expr) => {
        match $fallible {
            Ok(v) => v,
            Err(e) => {
                let (tokenizer, pipeline_name) = {
                    let pipeline = get_mut_arcmutex!($pipeline);
                    let pipeline_name = pipeline.name();
                    let tokenizer = pipeline.tokenizer();
                    (tokenizer, pipeline_name)
                };
                use $crate::response::Response;
                use $crate::sequence::SequenceState;
                use $crate::response::SYSTEM_FINGERPRINT;
                use tracing::error;
                error!("{} - Model failed with error: {:?}", $stage, &e);
                for seq in $seq_slice.iter_mut() {
                    let start = seq.prompt_tokens().min(seq.get_toks().len());
                    let res = match &tokenizer {
                        Some(tok) => match tok.decode(&seq.get_toks()[start..], false) {
                            Ok(t) => t,
                            Err(_) => "".to_string(),
                        },
                        None => "".to_string(),
                    };

                    if seq.get_mut_group().is_chat {
                        let choice = Choice {
                            finish_reason: "error".to_string(),
                            index: seq.get_response_index(),
                            message: ResponseMessage {
                                content: Some(res),
                                role: "assistant".to_string(),
                                tool_calls: None,
                            },
                            logprobs: None,
                        };
                        seq.add_choice_to_group(choice);
                    } else {
                        let choice = CompletionChoice {
                            finish_reason: "error".to_string(),
                            index: seq.get_response_index(),
                            text: res,
                            logprobs: None,
                        };
                        seq.add_completion_choice_to_group(choice);
                    }
                }
                for seq in $seq_slice.iter_mut() {
                    let group = seq.get_mut_group();

                    if group.is_chat {
                        let partial_completion_response = ChatCompletionResponse {
                            id: seq.id().to_string(),
                            choices: group.get_choices().to_vec(),
                            created: seq.creation_time(),
                            model: pipeline_name.clone(),
                            system_fingerprint: SYSTEM_FINGERPRINT.to_string(),
                            object: "chat.completion".to_string(),
                            usage: group.get_usage(),
                        };

                        seq.responder()
                            .send(Response::ModelError(
                                e.to_string(),
                                partial_completion_response
                            ))
                            .await
                            .unwrap();
                    } else {
                        let partial_completion_response = CompletionResponse {
                            id: seq.id().to_string(),
                            choices: group.get_completion_choices().to_vec(),
                            created: seq.creation_time(),
                            model: pipeline_name.clone(),
                            system_fingerprint: SYSTEM_FINGERPRINT.to_string(),
                            object: "text_completion".to_string(),
                            usage: group.get_usage(),
                        };

                        seq.responder()
                            .send(Response::CompletionModelError(
                                e.to_string(),
                                partial_completion_response
                            ))
                            .await
                            .unwrap();
                    }
                }
                for seq in $seq_slice.iter_mut() {
                    seq.set_state(SequenceState::Error);
                }

                let p = get_mut_arcmutex!($pipeline);
                p.set_none_cache($seq_slice, true, true, false);
                get_mut_arcmutex!($prefix_cacher).evict_all_caches().unwrap();

                continue $label;
            }
        }
    };
}

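// Same spin-on-`try_lock` pattern as `get_mut_arcmutex!`, specialized to the sequence
// group mutex stored in `$this.group`.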
#[doc(hidden)]
#[macro_export]
macro_rules! get_mut_group {
    ($this:expr) => {
        loop {
            if let Ok(inner) = $this.group.try_lock() {
                break inner;
            }
        }
    };
}

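// Generate a free function that returns a constant, for use with `#[serde(default = "...")]`.
// A hedged usage sketch (the struct, field, and value are illustrative only):
//
//     serde_default_fn!(usize, default_max_len, 2048);
//
//     #[derive(serde::Deserialize)]
//     struct SamplingConfig {
//         #[serde(default = "default_max_len")]
//         max_len: usize,
//     }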
#[doc(hidden)]
#[macro_export]
macro_rules! serde_default_fn {
    ($t:ty, $name:ident, $v:expr) => {
        fn $name() -> $t {
            $v
        }
    };
}

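// Compile-time probe for PagedAttention support: true only for CUDA builds on unix or for
// Metal builds, matching the cfg gates on the two definitions below.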
#[cfg(any(all(feature = "cuda", target_family = "unix"), feature = "metal"))]
pub const fn paged_attn_supported() -> bool {
    true
}

#[cfg(not(any(all(feature = "cuda", target_family = "unix"), feature = "metal")))]
pub const fn paged_attn_supported() -> bool {
    false
}

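// Analogous compile-time probe for FlashAttention: true when either the `flash-attn` or
// `flash-attn-v3` feature is enabled.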
#[cfg(not(any(feature = "flash-attn", feature = "flash-attn-v3")))]
pub const fn using_flash_attn() -> bool {
    false
}

#[cfg(any(feature = "flash-attn", feature = "flash-attn-v3"))]
pub const fn using_flash_attn() -> bool {
    true
}