1pub(crate) mod auto_device_map;
2mod diffusion_loaders;
3mod normal_loaders;
4mod vision_loaders;
5pub use auto_device_map::AutoDeviceMapParams;
6use auto_device_map::NonMappedSubModel;
7
8use std::{
9 fmt::{self, Debug},
10 path::PathBuf,
11 str::FromStr,
12 sync::Arc,
13};
14
15use anyhow::Result;
16use as_any::AsAny;
17use candle_core::{DType, Device};
18use mistralrs_quant::{IsqType, QuantizedConfig};
19use serde::Deserialize;
20use tokio::sync::Mutex;
21
22pub use normal_loaders::{
23 AutoNormalLoader, DeepSeekV2Loader, DeepSeekV3Loader, GLM4Loader, Gemma2Loader, GemmaLoader,
24 LlamaLoader, MistralLoader, MixtralLoader, NormalLoaderType, NormalLoadingMetadata,
25 NormalModel, NormalModelLoader, Phi2Loader, Phi3Loader, Phi3_5MoELoader, Qwen2Loader,
26 Qwen3Loader, Qwen3MoELoader, Starcoder2Loader,
27};
28
29pub use vision_loaders::{
30 AutoVisionLoader, Gemma3Loader, Idefics2Loader, Idefics3Loader, LLaVALoader, LLaVANextLoader,
31 MiniCpmOLoader, Mistral3Loader, Phi3VLoader, Phi4MMLoader, Qwen2VLLoader, Qwen2_5VLLoader,
32 VLlama4Loader, VLlamaLoader, VisionLoaderType, VisionModel, VisionModelLoader,
33};
34
35pub use diffusion_loaders::{
36 DiffusionLoaderType, DiffusionModel, DiffusionModelLoader, DiffusionModelPaths,
37 DiffusionModelPathsInner, FluxLoader,
38};
39
40use crate::{
41 paged_attention::ModelConfigLike, DeviceMapMetadata, DeviceMapSetting, PagedAttentionConfig,
42 TryIntoDType,
43};
44
45use super::{paths::AdapterPaths, Pipeline};
46
/// Resolved locations of the files needed to load a model.
///
/// Every accessor borrows from the implementor. The in-crate implementation,
/// `LocalModelPaths<PathBuf>`, simply exposes one field per accessor.
pub trait ModelPaths: AsAny + Debug + Send + Sync {
    /// Paths of the model weight files.
    fn get_weight_filenames(&self) -> &[PathBuf];

    /// Path of the model configuration file.
    fn get_config_filename(&self) -> &PathBuf;

    /// Path of the tokenizer file.
    fn get_tokenizer_filename(&self) -> &PathBuf;

    /// Optional path of the template file, if one was resolved.
    fn get_template_filename(&self) -> &Option<PathBuf>;

    /// Optional path of the generation config file.
    ///
    /// Unlike the other optional accessors, this returns `Option<&PathBuf>`
    /// rather than `&Option<PathBuf>`.
    fn get_gen_conf_filename(&self) -> Option<&PathBuf>;

    /// Optional path of the preprocessor config file.
    fn get_preprocessor_config(&self) -> &Option<PathBuf>;

    /// Optional path of the processor config file.
    fn get_processor_config(&self) -> &Option<PathBuf>;

    /// Optional path of an explicitly supplied chat template file.
    fn get_chat_template_explicit(&self) -> &Option<PathBuf>;

    /// Paths of any adapters attached to the model.
    fn get_adapter_paths(&self) -> &AdapterPaths;
}
83
/// Plain holder of resolved model file paths.
///
/// Generic over the path type `P`. The [`ModelPaths`] implementation is
/// provided for `P = PathBuf`, with one accessor per field.
#[derive(Clone, Debug)]
pub struct LocalModelPaths<P: Debug> {
    /// Tokenizer file (exposed as `get_tokenizer_filename`).
    pub tokenizer_filename: P,
    /// Model config file (exposed as `get_config_filename`).
    pub config_filename: P,
    /// Template file (exposed as `get_template_filename`). Always `Some` when
    /// the value is built through `LocalModelPaths::new`.
    pub template_filename: Option<P>,
    /// Weight files (exposed as `get_weight_filenames`).
    pub filenames: Vec<P>,
    /// Adapter paths (exposed as `get_adapter_paths`).
    pub adapter_paths: AdapterPaths,
    /// Optional generation config file (exposed as `get_gen_conf_filename`).
    pub gen_conf: Option<P>,
    /// Optional preprocessor config file (exposed as `get_preprocessor_config`).
    pub preprocessor_config: Option<P>,
    /// Optional processor config file (exposed as `get_processor_config`).
    pub processor_config: Option<P>,
    /// Optional explicit chat template file (exposed as `get_chat_template_explicit`).
    pub chat_template_json_filename: Option<P>,
}
97
98impl<P: Debug> LocalModelPaths<P> {
99 #[allow(clippy::too_many_arguments)]
100 pub fn new(
101 tokenizer_filename: P,
102 config_filename: P,
103 template_filename: P,
104 filenames: Vec<P>,
105 adapter_paths: AdapterPaths,
106 gen_conf: Option<P>,
107 preprocessor_config: Option<P>,
108 processor_config: Option<P>,
109 chat_template_json_filename: Option<P>,
110 ) -> Self {
111 Self {
112 tokenizer_filename,
113 config_filename,
114 template_filename: Some(template_filename),
115 filenames,
116 adapter_paths,
117 gen_conf,
118 preprocessor_config,
119 processor_config,
120 chat_template_json_filename,
121 }
122 }
123}
124
125impl ModelPaths for LocalModelPaths<PathBuf> {
126 fn get_config_filename(&self) -> &PathBuf {
127 &self.config_filename
128 }
129 fn get_tokenizer_filename(&self) -> &PathBuf {
130 &self.tokenizer_filename
131 }
132 fn get_weight_filenames(&self) -> &[PathBuf] {
133 &self.filenames
134 }
135 fn get_template_filename(&self) -> &Option<PathBuf> {
136 &self.template_filename
137 }
138 fn get_gen_conf_filename(&self) -> Option<&PathBuf> {
139 self.gen_conf.as_ref()
140 }
141 fn get_preprocessor_config(&self) -> &Option<PathBuf> {
142 &self.preprocessor_config
143 }
144 fn get_processor_config(&self) -> &Option<PathBuf> {
145 &self.processor_config
146 }
147 fn get_chat_template_explicit(&self) -> &Option<PathBuf> {
148 &self.chat_template_json_filename
149 }
150 fn get_adapter_paths(&self) -> &AdapterPaths {
151 &self.adapter_paths
152 }
153}
154
/// Where the Hugging Face token comes from.
///
/// Parsed from, and formatted back to, a `kind[:value]` string:
/// - `literal:<token>` -> [`TokenSource::Literal`]
/// - `env[:<VAR>]` -> [`TokenSource::EnvVar`] (defaults to `HUGGING_FACE_HUB_TOKEN`)
/// - `path:<file>` -> [`TokenSource::Path`]
/// - `cache` -> [`TokenSource::CacheToken`]
/// - `none` -> [`TokenSource::None`]
///
/// Resolving a source into an actual token happens elsewhere.
// NOTE(review): both the derived `Debug` and the `Display` impl print a
// `Literal` token verbatim, so avoid logging a `TokenSource` that may hold one.
#[derive(Debug, Clone)]
pub enum TokenSource {
    /// The token value itself (`literal:<token>`).
    Literal(String),
    /// Name of the environment variable holding the token (`env[:<VAR>]`).
    EnvVar(String),
    /// Path associated with the token (`path:<file>`).
    Path(String),
    /// `cache`.
    CacheToken,
    /// `none`.
    None,
}

impl FromStr for TokenSource {
    type Err = String;

    /// Parses a `kind[:value]` token-source string.
    ///
    /// Only the first `:` separates the kind from the value, so values may
    /// themselves contain `:`. An empty value (e.g. `"env:"`) is treated as if
    /// no value had been given.
    ///
    /// # Errors
    /// - `literal` or `path` without a non-empty value.
    /// - Any unrecognised kind.
    fn from_str(s: &str) -> Result<Self, Self::Err> {
        let (kind, value) = match s.split_once(':') {
            Some((kind, value)) => (kind, Some(value)),
            None => (s, None),
        };
        // "literal:" / "path:" must not yield an empty token or path, and "env:"
        // should fall back to the default variable rather than an empty name.
        let value = value.filter(|v| !v.is_empty());

        match kind {
            "literal" => value
                .map(|v| TokenSource::Literal(v.to_string()))
                .ok_or_else(|| "Expected a value for 'literal'".to_string()),
            "env" => Ok(TokenSource::EnvVar(
                value.unwrap_or("HUGGING_FACE_HUB_TOKEN").to_string(),
            )),
            "path" => value
                .map(|v| TokenSource::Path(v.to_string()))
                .ok_or_else(|| "Expected a value for 'path'".to_string()),
            "cache" => Ok(TokenSource::CacheToken),
            "none" => Ok(TokenSource::None),
            _ => Err("Invalid token source format".to_string()),
        }
    }
}

impl fmt::Display for TokenSource {
    /// Formats the source in the `kind[:value]` form accepted by `FromStr`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            TokenSource::Literal(value) => write!(f, "literal:{value}"),
            TokenSource::EnvVar(value) => write!(f, "env:{value}"),
            TokenSource::Path(value) => write!(f, "path:{value}"),
            TokenSource::CacheToken => write!(f, "cache"),
            TokenSource::None => write!(f, "none"),
        }
    }
}
203
/// How a loaded model is composed.
///
/// A model is plain, GGUF-quantized, adapter-augmented, or a wrapper
/// (speculative / AnyMoE) around other `ModelKind`s.
///
/// `Display` (via `strum`) renders the string given in each variant's
/// `to_string` attribute. `derive_more::From` is also derived, which generates
/// `From` impls from the variants' fields.
#[derive(Clone, Default, derive_more::From, strum::Display)]
pub enum ModelKind {
    /// A plain model without adapters. This is the `Default`.
    #[default]
    #[strum(to_string = "normal (no adapters)")]
    Normal,

    /// A quantized model loaded from the `quant` format, without adapters.
    #[strum(to_string = "gguf quantized from {quant} (no adapters)")]
    GgufQuantized { quant: QuantizationKind },

    /// An unquantized model with an `adapter` applied.
    #[strum(to_string = "{adapter}")]
    Adapter { adapter: AdapterKind },

    /// A quantized model (`quant` format) with an `adapter` applied.
    #[strum(to_string = "{adapter}, gguf quantized from {quant}")]
    GgufAdapter {
        adapter: AdapterKind,
        quant: QuantizationKind,
    },

    /// Speculative setup pairing a `target` model with a `draft` model.
    #[strum(to_string = "speculative: target: `{target}`, draft: `{draft}`")]
    Speculative {
        target: Box<ModelKind>,
        draft: Box<ModelKind>,
    },

    /// AnyMoE wrapper around a `target` model.
    #[strum(to_string = "anymoe: target: `{target}`")]
    AnyMoe { target: Box<ModelKind> },
}
232
233#[derive(Clone, Copy, strum::Display, strum::EnumIs, strum::EnumMessage)]
234#[strum(serialize_all = "kebab-case")]
235pub enum QuantizationKind {
236 Ggml,
238 Gguf,
240 Gptq,
242}
243
244#[derive(Clone, Copy, strum::Display, strum::EnumIs, strum::EnumMessage)]
245#[strum(serialize_all = "kebab-case")]
246pub enum AdapterKind {
247 Lora,
249 XLora,
251}
252
253pub trait PrettyName: strum::EnumMessage + ToString {
255 fn pretty_name(&self) -> String {
256 match self.get_documentation() {
257 Some(s) => s.to_string(),
258 None => self.to_string(),
261 }
262 }
263}
264
// Both kinds use the default `pretty_name`, which reads each variant's doc
// comment through `strum::EnumMessage::get_documentation`.
impl PrettyName for AdapterKind {}
impl PrettyName for QuantizationKind {}
267
268impl ModelKind {
269 pub fn is_quantized(&self) -> bool {
271 self.quantized_kind().iter().any(|q| q.is_some())
272 }
273
274 pub fn is_quantized_and(&self, mut f: impl FnMut(QuantizationKind) -> bool) -> bool {
275 self.quantized_kind().iter().any(|q| q.is_some_and(&mut f))
276 }
277
278 pub fn quantized_kind(&self) -> Vec<Option<QuantizationKind>> {
279 use ModelKind::*;
280
281 match self {
282 Normal | Adapter { .. } => vec![None],
283 GgufQuantized { quant } | GgufAdapter { quant, .. } => vec![Some(*quant)],
284 Speculative { target, draft } => {
285 let t = *target.clone();
286 let d = *draft.clone();
287
288 [t.quantized_kind(), d.quantized_kind()].concat()
289 }
290 AnyMoe { target } => target.quantized_kind(),
291 }
292 }
293
294 pub fn is_adapted(&self) -> bool {
296 self.adapted_kind().iter().any(|a| a.is_some())
297 }
298
299 pub fn is_adapted_and(&self, mut f: impl FnMut(AdapterKind) -> bool) -> bool {
300 self.adapted_kind().iter().any(|a| a.is_some_and(&mut f))
301 }
302
303 pub fn adapted_kind(&self) -> Vec<Option<AdapterKind>> {
304 use ModelKind::*;
305
306 match self {
307 Normal | GgufQuantized { .. } => vec![None],
308 Adapter { adapter } | GgufAdapter { adapter, .. } => vec![Some(*adapter)],
309 Speculative { target, draft } => {
310 let t = *target.clone();
311 let d = *draft.clone();
312
313 [t.adapted_kind(), d.adapted_kind()].concat()
314 }
315 AnyMoe { target } => target.adapted_kind(),
316 }
317 }
318}
319
/// Minimal view of a model config JSON, used only to extract its optional
/// `quantization_config` section.
///
/// No `deny_unknown_fields` is set, so serde ignores all other keys.
#[derive(Deserialize)]
pub struct QuantizationConfigShim {
    quantization_config: Option<QuantizedConfig>,
}
324
325impl QuantizationConfigShim {
326 pub fn get_quant_config_pack_factor(config: &str, dtype: DType) -> Result<usize> {
327 let QuantizationConfigShim {
328 quantization_config,
329 } = serde_json::from_str(config)?;
330
331 if let Some(quantization_config) = quantization_config {
332 Ok(quantization_config.pack_factor(dtype))
333 } else {
334 Ok(1)
335 }
336 }
337}
338
/// Size and shape queries a loader answers so that model layers can be
/// distributed across devices automatically (see `auto_device_map`).
///
/// `config` is the model configuration as a raw string that each implementor
/// parses itself. It is presumably the JSON config contents; confirm per loader.
pub trait DeviceMappedModelLoader {
    /// Maximum activation size, in elements, for the part of the model that is
    /// not device-mapped.
    /// NOTE(review): the exact definition is up to each implementor; confirm.
    fn non_mapped_max_act_size_elems(
        &self,
        config: &str,
        params: &AutoDeviceMapParams,
    ) -> Result<usize>;
    /// Maximum activation size, in elements, for the device-mapped layers when
    /// prompts are processed in chunks of `prompt_chunksize`.
    /// NOTE(review): the exact definition is up to each implementor; confirm.
    fn mapped_max_act_size_elems(
        &self,
        config: &str,
        params: &AutoDeviceMapParams,
        prompt_chunksize: usize,
    ) -> Result<usize>;
    /// Size in bytes of the weights that are not device-mapped, given the
    /// weight `dtype` and `weight_pack_factor`. The pack factor is presumably
    /// the value from `QuantizationConfigShim::get_quant_config_pack_factor`;
    /// confirm at call sites.
    fn non_mapped_size_in_bytes(
        &self,
        config: &str,
        dtype: DType,
        weight_pack_factor: usize,
    ) -> Result<usize>;
    /// Per-layer weight sizes in bytes, given the weight `dtype` and
    /// `weight_pack_factor`. Presumably there is one entry per layer counted by
    /// `num_layers`; confirm.
    fn layer_sizes_in_bytes(
        &self,
        config: &str,
        dtype: DType,
        weight_pack_factor: usize,
    ) -> Result<Vec<usize>>;
    /// Sub-models that are kept out of device mapping, if any. Defaults to `None`.
    fn non_mapped_sub_models(&self) -> Option<Vec<NonMappedSubModel>> {
        None
    }
    /// Number of layers described by `config`.
    fn num_layers(&self, config: &str) -> Result<usize>;
    /// The parsed `config` as a paged-attention `ModelConfigLike`.
    fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>>;

    /// Produces the `DeviceMapMetadata` for this model across `devices`.
    ///
    /// The default implementation forwards `self` and every argument unchanged
    /// to `auto_device_map::get_device_layers`. Because of the `Self: Sized`
    /// bound, this default cannot be called through a
    /// `dyn DeviceMappedModelLoader`.
    // The argument list mirrors `auto_device_map::get_device_layers`.
    #[allow(clippy::too_many_arguments)]
    fn get_device_layers(
        &self,
        config: &str,
        num_layers: usize,
        layer_sizes_in_bytes: Vec<usize>,
        non_mapped_size_in_bytes: usize,
        total_model_size_in_bytes: usize,
        devices: &[Device],
        dtype: DType,
        params: &AutoDeviceMapParams,
        prompt_chunksize: usize,
        paged_attn_config: Option<&PagedAttentionConfig>,
    ) -> Result<DeviceMapMetadata>
    where
        Self: Sized,
    {
        auto_device_map::get_device_layers(
            self,
            config,
            num_layers,
            layer_sizes_in_bytes,
            non_mapped_size_in_bytes,
            total_model_size_in_bytes,
            devices,
            dtype,
            params,
            prompt_chunksize,
            paged_attn_config,
        )
    }
}
406
/// A model loader that builds a [`Pipeline`] either from remotely resolved
/// files (`load_model_from_hf`) or from already-resolved paths
/// (`load_model_from_path`).
///
/// The pipeline is returned as `Arc<Mutex<dyn Pipeline + Send + Sync>>`, where
/// `Mutex` is `tokio::sync::Mutex`.
pub trait Loader: Send + Sync {
    /// Loads the model, presumably resolving its files from the Hugging Face
    /// Hub at `revision` and authenticating via `token_source`; confirm per
    /// implementor.
    ///
    /// - `dtype`: requested dtype, resolved through `TryIntoDType`.
    /// - `device`: target device.
    /// - `silent`: presumably suppresses progress output; confirm.
    /// - `mapper`: device-mapping setting for distributing layers.
    /// - `in_situ_quant`: optional in-situ quantization type to apply.
    /// - `paged_attn_config`: optional PagedAttention configuration.
    // The return type and parameter count are inherent to this entry point.
    #[allow(clippy::type_complexity, clippy::too_many_arguments)]
    fn load_model_from_hf(
        &self,
        revision: Option<String>,
        token_source: TokenSource,
        dtype: &dyn TryIntoDType,
        device: &Device,
        silent: bool,
        mapper: DeviceMapSetting,
        in_situ_quant: Option<IsqType>,
        paged_attn_config: Option<PagedAttentionConfig>,
    ) -> Result<Arc<Mutex<dyn Pipeline + Send + Sync>>>;

    /// Loads the model from already-resolved `paths`.
    ///
    /// The remaining parameters have the same meaning as in
    /// `load_model_from_hf`.
    // `&Box<dyn ModelPaths>` (clippy::borrowed_box) is kept because changing it
    // would change the signature every implementor must match.
    #[allow(
        clippy::type_complexity,
        clippy::too_many_arguments,
        clippy::borrowed_box
    )]
    fn load_model_from_path(
        &self,
        paths: &Box<dyn ModelPaths>,
        dtype: &dyn TryIntoDType,
        device: &Device,
        silent: bool,
        mapper: DeviceMapSetting,
        in_situ_quant: Option<IsqType>,
        paged_attn_config: Option<PagedAttentionConfig>,
    ) -> Result<Arc<Mutex<dyn Pipeline + Send + Sync>>>;

    /// Identifier of the model this loader loads.
    fn get_id(&self) -> String;
    /// The [`ModelKind`] this loader produces.
    fn get_kind(&self) -> ModelKind;
}