1use super::cache_manager::{FullCacheManager, NormalCacheManager};
2use super::isq::ImatrixDataSource;
3use super::isq::UqffFullSer;
4use super::{
5 get_model_paths, get_xlora_paths, AdapterKind, AnyMoePipelineMixin, CacheManager,
6 CacheManagerMixin, EitherCache, ForwardInputsResult, Gemma3Loader, GeneralMetadata,
7 IsqPipelineMixin, Loader, MetadataMixin, MiniCpmOLoader, ModelCategory, ModelKind, ModelPaths,
8 Phi4MMLoader, PreProcessingMixin, Processor, Qwen2VLLoader, TokenSource, VLlamaLoader,
9 VisionModel, VisionModelLoader, VisionPromptPrefixer,
10};
11use super::{
12 Idefics2Loader, Idefics3Loader, LLaVALoader, LLaVANextLoader, Mistral3Loader, Phi3VLoader,
13 Qwen2_5VLLoader, VisionLoaderType,
14};
15use crate::device_map::{self, DeviceMapper};
16use crate::distributed::{self, WorkerTransferData};
17use crate::paged_attention::{calculate_cache_config, AttentionImplementation, CacheEngine};
18use crate::pipeline::chat_template::{calculate_eos_tokens, GenerationConfig};
19use crate::pipeline::llg::build_tok_env;
20use crate::pipeline::sampling::sample_and_add_toks;
21use crate::pipeline::text_models_inputs_processor::make_prompt_chunk;
22use crate::pipeline::{get_chat_template, ChatTemplate, IsqOrganization, LocalModelPaths};
23use crate::prefix_cacher::PrefixCacheManagerV2;
24use crate::sequence::Sequence;
25use crate::utils::tokenizer::get_tokenizer;
26use crate::utils::varbuilder_utils::DeviceForLoadTensor;
27use crate::utils::{tokens::get_token, varbuilder_utils::from_mmaped_safetensors};
28use crate::vision_models::preprocessor_config::PreProcessorConfig;
29use crate::vision_models::processor_config::ProcessorConfig;
30use crate::vision_models::ModelInputs;
31use crate::{
32 api_dir_list, api_get_file, get_paths, get_uqff_paths, vision_normal_model_loader,
33 vision_normal_model_loader_sharded, AnyMoeExpertType, DeviceMapSetting, Ordering,
34 PagedAttentionConfig, Pipeline, Topology, TryIntoDType, GLOBAL_HF_CACHE,
35};
36use anyhow::Result;
37use candle_core::{Device, Tensor, Var};
38use hf_hub::Cache;
39use hf_hub::{api::sync::ApiBuilder, Repo, RepoType};
40use indicatif::MultiProgress;
41use mistralrs_quant::{AfqLayer, GgufMatMul, HqqLayer, IsqType, QuantizedSerdeType};
42use rand_isaac::Isaac64Rng;
43use regex_automata::meta::Regex;
44use std::any::Any;
45use std::borrow::Cow;
46use std::num::NonZeroUsize;
47use std::path::{Path, PathBuf};
48use std::str::FromStr;
49use std::sync::{Arc, RwLock};
50use std::time::Instant;
51use std::{env, fs};
52use tokenizers::Tokenizer;
53use tokio::sync::Mutex;
54use tracing::{info, warn};
55
/// A fully loaded vision-language pipeline: the model plus everything needed
/// to tokenize, template, preprocess and device-map requests against it.
pub struct VisionPipeline {
    model: Box<dyn VisionModel + Send + Sync>,
    tokenizer: Arc<Tokenizer>,
    chat_template: Arc<ChatTemplate>,
    model_id: String,
    metadata: Arc<GeneralMetadata>,
    /// Input processor (chat -> model inputs) for this architecture.
    processor: Arc<dyn Processor + Send + Sync>,
    preprocessor_config: Arc<PreProcessorConfig>,
    topology: Option<Topology>,
    silent: bool,
    /// Produces the vision prompt prefix for this model family.
    prefixer: Arc<dyn VisionPromptPrefixer>,
    mapper: Box<dyn DeviceMapper + Send + Sync>,

    // Retained artifacts so the pipeline can be re-quantized and serialized
    // to UQFF later (see `IsqPipelineMixin::re_isq_model`).
    template_filename: Option<PathBuf>,
    generation_config: Option<PathBuf>,
    config: String,
    processor_filename: Option<PathBuf>,
    preprocessor_filename: Option<PathBuf>,
    imatrix: Option<PathBuf>,
}
77
/// Loader for vision models: resolves artifacts (locally or from the Hugging
/// Face Hub) and builds a [`VisionPipeline`].
pub struct VisionLoader {
    /// Architecture-specific loading logic selected in `VisionLoaderBuilder::build`.
    inner: Box<dyn VisionModelLoader>,
    model_id: String,
    config: VisionSpecificConfig,
    kind: ModelKind,
    chat_template: Option<String>,
    tokenizer_json: Option<String>,
    xlora_model_id: Option<String>,
    xlora_order: Option<Ordering>,
    // Interior mutability: these are filled in during `load_model_from_hf`
    // through a `&self` receiver.
    token_source: RwLock<Option<TokenSource>>,
    revision: RwLock<Option<String>>,
    /// Resolved path of the UQFF artifact to load from, once known.
    from_uqff: RwLock<Option<PathBuf>>,
    jinja_explicit: Option<String>,
    hf_cache_path: Option<PathBuf>,
    lora_adapter_ids: Option<Vec<String>>,
}
95
/// Builder for [`VisionLoader`]; collect options, then call
/// [`VisionLoaderBuilder::build`] with the desired [`VisionLoaderType`].
#[derive(Default)]
pub struct VisionLoaderBuilder {
    model_id: Option<String>,
    config: VisionSpecificConfig,
    kind: ModelKind,
    chat_template: Option<String>,
    tokenizer_json: Option<String>,
    /// Explicit Jinja chat-template override, if any.
    jinja_explicit: Option<String>,
    hf_cache_path: Option<PathBuf>,
    lora_adapter_ids: Option<Vec<String>>,
}
108
/// Configuration knobs specific to loading vision models.
#[derive(Clone, Default)]
pub struct VisionSpecificConfig {
    pub use_flash_attn: bool,
    /// Chunk size used when chunking long prompts; `None` disables chunking.
    pub prompt_chunksize: Option<NonZeroUsize>,
    pub topology: Option<Topology>,
    /// Where to serialize the quantized model in UQFF format, if requested.
    pub write_uqff: Option<PathBuf>,
    /// UQFF artifact to load instead of quantizing in place.
    pub from_uqff: Option<PathBuf>,
    /// Maximum image edge length for preprocessing, if limited.
    pub max_edge: Option<u32>,
    // `imatrix` and `calibration_file` are mutually exclusive; this is
    // validated in `load_model_from_path`.
    pub imatrix: Option<PathBuf>,
    pub calibration_file: Option<PathBuf>,
    pub hf_cache_path: Option<PathBuf>,
}
122
123impl VisionLoaderBuilder {
124 pub fn new(
125 config: VisionSpecificConfig,
126 chat_template: Option<String>,
127 tokenizer_json: Option<String>,
128 model_id: Option<String>,
129 jinja_explicit: Option<String>,
130 ) -> Self {
131 Self {
132 config,
133 chat_template,
134 tokenizer_json,
135 model_id,
136 jinja_explicit,
137 kind: ModelKind::Normal,
138 hf_cache_path: None,
139 ..Default::default()
140 }
141 }
142
143 pub fn hf_cache_path(mut self, hf_cache_path: PathBuf) -> Self {
144 self.hf_cache_path = Some(hf_cache_path);
145 self
146 }
147
148 pub fn with_lora(mut self, lora_adapter_ids: Vec<String>) -> Self {
149 self.kind = ModelKind::Adapter {
150 adapter: AdapterKind::Lora,
151 };
152 self.lora_adapter_ids = Some(lora_adapter_ids);
153 self
154 }
155
156 pub fn build(self, loader: VisionLoaderType) -> Box<dyn Loader> {
157 let loader: Box<dyn VisionModelLoader> = match loader {
158 VisionLoaderType::Phi3V => Box::new(Phi3VLoader),
159 VisionLoaderType::Idefics2 => Box::new(Idefics2Loader),
160 VisionLoaderType::LLaVANext => Box::new(LLaVANextLoader),
161 VisionLoaderType::LLaVA => Box::new(LLaVALoader),
162 VisionLoaderType::VLlama => Box::new(VLlamaLoader),
163 VisionLoaderType::Qwen2VL => Box::new(Qwen2VLLoader),
164 VisionLoaderType::Idefics3 => Box::new(Idefics3Loader),
165 VisionLoaderType::MiniCpmO => Box::new(MiniCpmOLoader),
166 VisionLoaderType::Phi4MM => Box::new(Phi4MMLoader),
167 VisionLoaderType::Qwen2_5VL => Box::new(Qwen2_5VLLoader),
168 VisionLoaderType::Gemma3 => Box::new(Gemma3Loader),
169 VisionLoaderType::Mistral3 => Box::new(Mistral3Loader),
170 };
171 Box::new(VisionLoader {
172 inner: loader,
173 model_id: self.model_id.unwrap(),
174 config: self.config,
175 kind: self.kind,
176 chat_template: self.chat_template,
177 tokenizer_json: self.tokenizer_json,
178 xlora_model_id: None,
179 xlora_order: None,
180 jinja_explicit: self.jinja_explicit,
181 token_source: RwLock::new(None),
182 revision: RwLock::new(None),
183 from_uqff: RwLock::new(None),
184 hf_cache_path: self.hf_cache_path,
185 lora_adapter_ids: self.lora_adapter_ids,
186 })
187 }
188}
189
impl Loader for VisionLoader {
    /// Resolve all model artifacts (config, tokenizer, weights, optional
    /// UQFF) from the Hugging Face Hub (or local cache), record the token
    /// source and revision, then delegate to `load_model_from_path`.
    #[allow(clippy::type_complexity, clippy::too_many_arguments)]
    fn load_model_from_hf(
        &self,
        revision: Option<String>,
        token_source: TokenSource,
        dtype: &dyn TryIntoDType,
        device: &Device,
        silent: bool,
        mapper: DeviceMapSetting,
        in_situ_quant: Option<IsqType>,
        paged_attn_config: Option<PagedAttentionConfig>,
    ) -> Result<Arc<Mutex<dyn Pipeline + Send + Sync>>> {
        // Install the (possibly custom) HF cache globally. `get_or_init`
        // means a previously installed cache wins over this one.
        let cache = self
            .hf_cache_path
            .clone()
            .map(Cache::new)
            .unwrap_or_default();
        GLOBAL_HF_CACHE.get_or_init(|| cache);

        let paths: anyhow::Result<Box<dyn ModelPaths>> = get_paths!(
            LocalModelPaths,
            &token_source,
            revision.clone(),
            self,
            None,
            None,
            silent,
            self.config.from_uqff.is_some()
        );
        // Resolve the UQFF artifact path now so `load_model_from_path` can
        // read it via the `from_uqff` lock.
        if let Some(from_uqff) = self.config.from_uqff.clone() {
            *self.from_uqff.write().unwrap() = Some(get_uqff_paths!(&from_uqff, self, silent));
        }
        *self
            .token_source
            .write()
            .expect("Failed to write to token source") = Some(token_source);
        *self.revision.write().expect("Failed to write to revision") = revision;
        self.load_model_from_path(
            &paths?,
            dtype,
            device,
            silent,
            mapper,
            in_situ_quant,
            paged_attn_config,
        )
    }

    /// Build the full [`VisionPipeline`] from already-resolved paths:
    /// device mapping, (optional) NCCL sharding, ISQ/UQFF quantization,
    /// calibration-based imatrix collection, and PagedAttention setup.
    #[allow(clippy::type_complexity, clippy::too_many_arguments)]
    fn load_model_from_path(
        &self,
        paths: &Box<dyn ModelPaths>,
        dtype: &dyn TryIntoDType,
        device: &Device,
        silent: bool,
        mut mapper: DeviceMapSetting,
        in_situ_quant: Option<IsqType>,
        mut paged_attn_config: Option<PagedAttentionConfig>,
    ) -> Result<Arc<Mutex<dyn Pipeline + Send + Sync>>> {
        let config = std::fs::read_to_string(paths.get_config_filename())?;

        // Architectures without PagedAttention support silently fall back
        // to eager attention.
        if !self.inner.supports_paged_attention() {
            paged_attn_config = None;
        }

        let use_nccl = mistralrs_quant::distributed::use_nccl();

        // Device selection: a distributed daemon worker pins to the CUDA
        // device derived from its rank; NCCL pins to device 0; otherwise
        // gather all devices similar to the requested one for mapping.
        let available_devices = if let Ok(payload) = env::var(distributed::IS_DAEMON_FLAG) {
            let payload: WorkerTransferData = serde_json::from_str(&payload)?;
            let WorkerTransferData::Init { id: _, worker_rank } = payload;
            vec![candle_core::Device::new_cuda_with_stream(worker_rank + 1)?]
        } else if use_nccl {
            vec![candle_core::Device::new_cuda_with_stream(0)?]
        } else {
            device_map::get_all_similar_devices(device)?
        };
        let device = if use_nccl {
            available_devices[0].clone()
        } else {
            device.clone()
        };

        if use_nccl {
            // NCCL does its own sharding; the mapper becomes a dummy.
            mapper = DeviceMapSetting::DummyNccl {
                nm_device: available_devices[0].clone(),
            };
        } else if let DeviceMapSetting::Auto(params) = mapper.clone() {
            // Automatic device mapping: estimate per-layer and non-mapped
            // sizes (accounting for quantization pack factors) and let the
            // model loader distribute layers across devices.
            let dtype = dtype.try_into_dtype(&available_devices.iter().collect::<Vec<_>>())?;

            let (layer_sizes_in_bytes, non_mapped_size_in_bytes, total_model_size_in_bytes) =
                if let Some(serialized) = &*self.from_uqff.read().unwrap() {
                    // UQFF case: derive an *average* pack factor over all
                    // serialized tensors from the artifact headers.
                    let weight_pack_factor = {
                        let ser_artifacts = unsafe {
                            candle_core::safetensors::MmapedSafetensors::new(serialized)?
                        };
                        let mut total_pack_factors = 0;
                        let total_tensors = ser_artifacts.tensors().len();
                        for (_, artifact) in ser_artifacts.tensors() {
                            let artifact = artifact.data();
                            // The quant type is stored at a fixed offset in
                            // each UQFF tensor blob.
                            let isq_type = artifact[mistralrs_quant::UQFF_QUANT_TYPE_OFFSET];
                            let pack_factor = match QuantizedSerdeType::try_from(isq_type as usize)?
                            {
                                QuantizedSerdeType::Hqq => {
                                    HqqLayer::get_isq_type_from_uqff(Cow::Borrowed(artifact))?
                                        .pack_factor(dtype)
                                }
                                QuantizedSerdeType::Gguf => {
                                    GgufMatMul::get_isq_type_from_uqff(Cow::Borrowed(artifact))?
                                        .pack_factor(dtype)
                                }
                                QuantizedSerdeType::Fp8 => IsqType::F8E4M3.pack_factor(dtype),
                                QuantizedSerdeType::Unquant => 1,
                                QuantizedSerdeType::Afq => {
                                    AfqLayer::get_isq_type_from_uqff(Cow::Borrowed(artifact))?
                                        .pack_factor(dtype)
                                }
                            };
                            total_pack_factors += pack_factor;
                        }

                        total_pack_factors / total_tensors
                    };

                    let layer_sizes_in_bytes =
                        self.inner
                            .layer_sizes_in_bytes(&config, dtype, weight_pack_factor)?;
                    let non_mapped_size_in_bytes =
                        self.inner
                            .non_mapped_size_in_bytes(&config, dtype, weight_pack_factor)?;
                    let layer_sizes_sum = layer_sizes_in_bytes.iter().sum::<usize>();
                    (
                        layer_sizes_in_bytes,
                        non_mapped_size_in_bytes,
                        layer_sizes_sum + non_mapped_size_in_bytes,
                    )
                } else if let Some(isq) = in_situ_quant {
                    // ISQ case: one uniform pack factor from the target type.
                    let weight_pack_factor = isq.pack_factor(dtype);
                    let layer_sizes_in_bytes =
                        self.inner
                            .layer_sizes_in_bytes(&config, dtype, weight_pack_factor)?;
                    let non_mapped_size_in_bytes =
                        self.inner
                            .non_mapped_size_in_bytes(&config, dtype, weight_pack_factor)?;
                    let layer_sizes_sum = layer_sizes_in_bytes.iter().sum::<usize>();
                    (
                        layer_sizes_in_bytes,
                        non_mapped_size_in_bytes,
                        layer_sizes_sum + non_mapped_size_in_bytes,
                    )
                } else {
                    // Unquantized case: pack factor 1.
                    let layer_sizes_in_bytes =
                        self.inner.layer_sizes_in_bytes(&config, dtype, 1)?;
                    let non_mapped_size_in_bytes =
                        self.inner.non_mapped_size_in_bytes(&config, dtype, 1)?;
                    let layer_sizes_sum = layer_sizes_in_bytes.iter().sum::<usize>();
                    (
                        layer_sizes_in_bytes,
                        non_mapped_size_in_bytes,
                        layer_sizes_sum + non_mapped_size_in_bytes,
                    )
                };

            let new = self.inner.get_device_layers(
                &config,
                self.inner.num_layers(&config)?,
                layer_sizes_in_bytes,
                non_mapped_size_in_bytes,
                total_model_size_in_bytes,
                &available_devices,
                dtype,
                &params,
                params.max_seq_len(),
                paged_attn_config.as_ref(),
            )?;
            mapper = DeviceMapSetting::Map(new);
        }

        // Two mappers from the same setting: one kept by the pipeline, one
        // consumed during loading below.
        let pipeline_mapper = mapper.into_mapper(
            self.inner.num_layers(&config)?,
            &device,
            self.config.topology.as_ref(),
        )?;
        let mapper = mapper.into_mapper(
            self.inner.num_layers(&config)?,
            &device,
            self.config.topology.as_ref(),
        )?;
        let mut layer_devices = Vec::new();
        for layer in 0..self.inner.num_layers(&config)? {
            let device = mapper.device_for(layer, false).cloned();
            layer_devices.push(device);
        }
        let dtype = mapper.get_min_dtype(dtype)?;

        // PagedAttention has no CPU implementation, so any CPU involvement
        // in the mapping disables it.
        let mapping_uses_cpu = mapper.get_unique_devices().iter().any(Device::is_cpu);
        if mapping_uses_cpu {
            warn!("Device mapping contains a mix of GPU and CPU. There is no CPU support for PagedAttention, disabling PagedAttention.");
            paged_attn_config = None;
        }

        info!(
            "Model config: {:?}",
            self.inner
                .get_config_repr(&config, self.config.use_flash_attn)?
        );

        // ISQ is in play if requested directly, loading from UQFF, or any
        // topology layer specifies a quant type.
        let mut loading_isq = in_situ_quant.is_some() || self.config.from_uqff.is_some();
        if let Some(ref topology) = self.config.topology {
            loading_isq |= topology
                .0
                .iter()
                .any(|layer| layer.as_ref().is_some_and(|layer| layer.isq.is_some()));
        }

        if self.config.imatrix.is_some() && self.config.calibration_file.is_some() {
            anyhow::bail!(
                "`imatrix` and `calibration_file` were both specified, this is not allowed."
            );
        }

        // Load to CPU first when quantizing in place (weights are moved
        // after quantization); calibration requires loading on-device.
        let load_device = if !loading_isq || self.config.calibration_file.is_some() {
            loading_isq = false;
            device.clone()
        } else {
            Device::Cpu
        };

        let attention_mechanism = if paged_attn_config.is_some() {
            AttentionImplementation::PagedAttention
        } else {
            AttentionImplementation::Eager
        };

        let multi_progress = Arc::new(MultiProgress::new());

        let mut model = if use_nccl {
            // Sharded (NCCL) load path.
            let (mapper, sharded_vb) = distributed::prepare_distributed_mapper(
                dtype,
                &device,
                &load_device,
                &available_devices,
                &config,
                loading_isq,
                self.config.from_uqff.is_some(),
                IsqOrganization::Default,
                &*self.inner,
                paths.as_ref(),
            )?;

            match self.kind {
                ModelKind::Normal => vision_normal_model_loader_sharded!(
                    sharded_vb,
                    config,
                    self.inner,
                    self.config.use_flash_attn,
                    mapper,
                    loading_isq,
                    device.clone(),
                    attention_mechanism,
                    multi_progress.clone(),
                ),
                _ => unreachable!(),
            }
        } else {
            // Single-process (possibly device-mapped) load path.
            match self.kind {
                ModelKind::Normal => vision_normal_model_loader!(
                    paths,
                    Some(dtype),
                    &load_device,
                    layer_devices.clone(),
                    config,
                    self.inner,
                    self.config.use_flash_attn,
                    silent,
                    mapper,
                    loading_isq,
                    self.config.from_uqff.is_some(),
                    device.clone(),
                    attention_mechanism,
                    multi_progress,
                ),
                _ => unreachable!(),
            }
        };

        let preprocessor_config: PreProcessorConfig = match paths.get_preprocessor_config().as_ref()
        {
            Some(preprocessor_config) => {
                serde_json::from_str(&fs::read_to_string(preprocessor_config).unwrap()).unwrap()
            }
            None => PreProcessorConfig::default(),
        };
        let processor_config: Option<ProcessorConfig> = paths
            .get_processor_config()
            .as_ref()
            .map(|f| serde_json::from_str(&fs::read_to_string(f).unwrap()).unwrap());

        let processor = self.inner.get_processor(
            &config,
            processor_config,
            preprocessor_config.clone(),
            self.config.max_edge,
        );
        let tokenizer = get_tokenizer(
            paths.get_tokenizer_filename(),
            Some(processor.get_special_tokens()),
        )?;

        let gen_conf: Option<GenerationConfig> = paths.get_gen_conf_filename().map(|f| {
            serde_json::from_str(&fs::read_to_string(f).unwrap())
                .expect("bos_token_id/eos_token_id missing in generation_config.json")
        });
        let chat_template = get_chat_template(
            paths,
            &self.jinja_explicit,
            &paths
                .get_chat_template_explicit()
                .as_ref()
                .map(|x| x.to_string_lossy().to_string())
                .clone(),
            &self.chat_template,
            None,
        );

        // Calibration: run the model over the calibration text in chunks to
        // collect imatrix statistics, resetting the KV cache between chunks.
        if let Some(calibration_file) = &self.config.calibration_file {
            let calibration_data = std::fs::read_to_string(calibration_file)?;
            let tokens = tokenizer
                .encode_fast(calibration_data, false)
                .map_err(anyhow::Error::msg)?
                .get_ids()
                .to_vec();
            info!(
                "Collecting imatrix from calibration file `{}` of {} tokens.",
                calibration_file.display(),
                tokens.len()
            );
            let bos_toks = chat_template.bos_tok().map(|b| vec![b]).unwrap_or_default();
            let bos_tok_id = tokenizer
                .token_to_id(&bos_toks[0])
                .expect("Somehow the bos token is not present.");

            model.begin_track_stats()?;

            const CHUNK_SIZE: usize = 1024;
            let n_chunks: usize = tokens.len().div_ceil(CHUNK_SIZE);
            let start = Instant::now();
            for (i, chunk) in tokens.chunks(CHUNK_SIZE).enumerate() {
                // Each chunk is prefixed with the BOS token.
                let chunk = [vec![bos_tok_id], chunk.to_vec()].concat();
                let chunk_len = chunk.len();

                let start = Instant::now();
                let inputs =
                    make_prompt_chunk(0, vec![chunk], &[0], &load_device, None, false, None, None)?;
                let _ = model.forward(
                    &inputs.input,
                    None,
                    &inputs.positions,
                    inputs.context_lens,
                    inputs.position_ids,
                    model.default_model_specific_args(&inputs.input),
                    None,
                    &inputs.flash_meta,
                )?;
                // Clear the KV cache so chunks do not attend to each other.
                match model.cache_mut() {
                    EitherCache::Full(full) => {
                        for layer in &mut *full.lock() {
                            *layer = None
                        }
                    }
                    EitherCache::Normal(normal) => {
                        for layer in &mut *normal.lock().unwrap().0 {
                            layer.reset();
                        }
                    }
                }
                let end = Instant::now();
                info!(
                    "Processed chunk {}/{n_chunks} ({chunk_len} tokens), {:.2}s",
                    i + 1,
                    end.duration_since(start).as_secs_f32()
                );
            }
            load_device.synchronize()?;
            let end = Instant::now();
            info!(
                "Finished collecting imatrix in {:.2}s",
                end.duration_since(start).as_secs_f32()
            );
        }

        // Quantize in place (unless we are loading from a pre-quantized
        // UQFF artifact, in which case load those artifacts instead).
        if (in_situ_quant.is_some() || self.config.topology.is_some())
            && self.config.from_uqff.is_none()
        {
            let imatrix_source = match (
                self.config.imatrix.as_ref(),
                self.config.calibration_file.is_some(),
            ) {
                (None, false) => None,
                (Some(file), false) => Some(ImatrixDataSource::File(file)),
                (None, true) => Some(ImatrixDataSource::Collected),
                // Both set was rejected above.
                (Some(_), true) => unreachable!(),
            };
            model.quantize(
                in_situ_quant,
                device.clone(),
                self.config.topology.as_ref(),
                silent,
                imatrix_source,
                IsqOrganization::Default,
                self.config.write_uqff.as_ref(),
                UqffFullSer {
                    tokenizer: &tokenizer,
                    template_filename: paths.get_template_filename(),
                    generation_config: paths.get_gen_conf_filename(),
                    config: config.clone(),
                    processor_filename: paths.get_processor_config(),
                    preprocessor_filename: paths.get_preprocessor_config(),
                },
                Arc::new(MultiProgress::new()),
            )?;
        } else if let Some(from_uqff) = &*self.from_uqff.read().unwrap() {
            model.load_from_artifacts(
                device.clone(),
                self.config.topology.as_ref(),
                silent,
                from_uqff,
            )?;
        }

        // PagedAttention cache sizing/engine, when enabled.
        let (cache_config, cache_engine) = if let Some(paged_attn_config) = paged_attn_config {
            anyhow::ensure!(
                !matches!(self.kind, ModelKind::Adapter { .. }),
                "PagedAttention does not support adapter models."
            );
            let cache_config = calculate_cache_config(
                paged_attn_config.mem_gpu,
                paged_attn_config.mem_cpu,
                paged_attn_config.block_size,
                dtype,
                model.config(),
                &device,
                &layer_devices,
                silent,
            )?;
            let cache_engine =
                CacheEngine::new(model.config(), &cache_config, dtype, &device, layer_devices)?;
            (Some(cache_config), Some(cache_engine))
        } else {
            (None, None)
        };

        let max_seq_len = model.max_seq_len();
        let tok_env = build_tok_env(tokenizer.clone());
        // Layer count is inferred from the cache structure.
        let num_hidden_layers = match model.cache() {
            EitherCache::Full(full) => full.lock().len(),
            EitherCache::Normal(normal) => normal.lock().unwrap().0.len(),
        };
        let eos = calculate_eos_tokens(&chat_template, gen_conf, &tokenizer);
        let sliding_window = model.config().sliding_window;
        let model_metadata = Arc::new(model.config().clone());
        Ok(Arc::new(Mutex::new(VisionPipeline {
            model,
            tokenizer: tokenizer.into(),
            chat_template: Arc::new(chat_template),
            model_id: self.model_id.clone(),
            metadata: Arc::new(GeneralMetadata {
                max_seq_len,
                tok_env: Some(tok_env),
                is_xlora: false,
                num_hidden_layers,
                eos_tok: eos,
                kind: self.kind.clone(),
                no_kv_cache: false,
                no_prefix_cache: false,
                activation_dtype: dtype,
                sliding_window,
                cache_config,
                cache_engine,
                prompt_chunksize: self.config.prompt_chunksize,
                model_metadata: Some(model_metadata),
            }),
            processor,
            prefixer: self.inner.prefixer(),
            preprocessor_config: Arc::new(preprocessor_config),
            topology: self.config.topology.clone(),
            silent,
            template_filename: paths.get_template_filename().clone(),
            generation_config: paths.get_gen_conf_filename().cloned(),
            config,
            processor_filename: paths.get_processor_config().clone(),
            preprocessor_filename: paths.get_preprocessor_config().clone(),
            mapper: pipeline_mapper,
            imatrix: self.config.imatrix.clone(),
        })))
    }

    fn get_id(&self) -> String {
        self.model_id.to_string()
    }

    fn get_kind(&self) -> ModelKind {
        self.kind.clone()
    }
}
709
710impl PreProcessingMixin for VisionPipeline {
711 fn get_chat_template(&self) -> Option<Arc<ChatTemplate>> {
712 Some(self.chat_template.clone())
713 }
714 fn get_input_processor_config(&self) -> Option<Arc<dyn Any>> {
715 Some(self.preprocessor_config.clone())
716 }
717 fn get_processor(&self) -> Arc<dyn super::Processor> {
718 self.processor.clone()
719 }
720}
721
impl IsqPipelineMixin for VisionPipeline {
    /// Re-quantize the already-loaded model in place to `dtype`, reusing
    /// the artifacts captured at load time (tokenizer, template, processor
    /// configs) for optional UQFF serialization.
    fn re_isq_model(&mut self, dtype: IsqType) -> Result<()> {
        let device = self.device().clone();
        self.model
            .quantize(
                Some(dtype),
                device,
                self.topology.as_ref(),
                self.silent,
                // An imatrix file recorded at load time drives weighting,
                // if one was provided.
                self.imatrix.as_ref().map(ImatrixDataSource::File),
                IsqOrganization::Default,
                // No UQFF write path on re-quantization.
                None,
                UqffFullSer {
                    tokenizer: &self.tokenizer,
                    template_filename: &self.template_filename,
                    generation_config: self.generation_config.as_ref(),
                    config: self.config.clone(),
                    processor_filename: &self.processor_filename,
                    preprocessor_filename: &self.preprocessor_filename,
                },
                Arc::new(MultiProgress::new()),
            )
            .map_err(anyhow::Error::msg)
    }
}
747
748impl CacheManagerMixin for VisionPipeline {
749 fn clone_in_cache(&self, seqs: &mut [&mut Sequence]) {
750 if matches!(self.model.cache(), EitherCache::Full(_)) {
751 FullCacheManager.clone_in_cache(self, seqs, false)
752 } else {
753 NormalCacheManager.clone_in_cache(self, seqs, false)
754 }
755 }
756 fn clone_out_cache(&self, seqs: &mut [&mut Sequence]) {
757 if matches!(self.model.cache(), EitherCache::Full(_)) {
758 FullCacheManager.clone_out_cache(self, seqs, false)
759 } else {
760 NormalCacheManager.clone_out_cache(self, seqs, false)
761 }
762 }
763 fn set_none_cache(
764 &self,
765 seqs: &mut [&mut Sequence],
766 reset_non_granular: bool,
767 modify_draft_cache: bool,
768
769 load_preallocated_cache: bool,
770 ) {
771 if matches!(self.model.cache(), EitherCache::Full(_)) {
772 FullCacheManager.set_none_cache(self, seqs, modify_draft_cache, false);
773 } else {
774 NormalCacheManager.set_none_cache(
775 self,
776 seqs,
777 modify_draft_cache,
778 load_preallocated_cache,
779 );
780 }
781 if reset_non_granular {
782 self.reset_non_granular_state()
783 }
784 }
785 fn cache(&self) -> &EitherCache {
786 self.model.cache()
787 }
788}
789
790impl MetadataMixin for VisionPipeline {
791 fn device(&self) -> Device {
792 self.model.device().clone()
793 }
794 fn get_metadata(&self) -> Arc<GeneralMetadata> {
795 self.metadata.clone()
796 }
797 fn name(&self) -> String {
798 self.model_id.clone()
799 }
800 fn reset_non_granular_state(&self) {}
801 fn tokenizer(&self) -> Option<Arc<Tokenizer>> {
802 Some(self.tokenizer.clone())
803 }
804 fn device_mapper(&self) -> Option<&dyn DeviceMapper> {
805 Some(&*self.mapper)
806 }
807}
808
#[async_trait::async_trait]
impl Pipeline for VisionPipeline {
    /// Run one forward pass over pre-built [`ModelInputs`], pairing the
    /// PagedAttention metadata with the cache engine when both exist.
    fn forward_inputs(
        &mut self,
        inputs: Box<dyn Any>,
        return_raw_logits: bool,
    ) -> candle_core::Result<ForwardInputsResult> {
        let ModelInputs {
            input_ids,
            seqlen_offsets,
            context_lens,
            position_ids,
            pixel_values,
            model_specific_args,
            paged_attn_meta,
            flash_meta,
        } = *inputs.downcast::<ModelInputs>().expect("Downcast failed.");
        let metadata = self.get_metadata();
        // Cache engine and PagedAttention metadata must agree: both present
        // (use them), both absent (eager path), anything else is an error.
        let paged_attn_meta = match (&metadata.cache_engine, &paged_attn_meta) {
            (Some(engine), Some(meta)) => Some((engine.get_kv_cache().clone(), meta)),
            (Some(_), None) => {
                candle_core::bail!("Forward step expected a PagedAttention input metadata. This was not provided, please ensure that the scheduler config is correctly configured for PagedAttention.")
            }
            (None, Some(_)) => {
                candle_core::bail!("Forward step got a PagedAttention input metadata but there is no cache engine. Please raise an issue.")
            }
            (None, None) => None,
        };
        // On Metal, wrap the forward pass in an autorelease pool so
        // Objective-C temporaries are reclaimed promptly.
        #[cfg(feature = "metal")]
        let logits = objc::rc::autoreleasepool(|| {
            self.model.forward(
                &input_ids,
                pixel_values,
                &seqlen_offsets,
                context_lens,
                position_ids,
                model_specific_args,
                paged_attn_meta,
                &flash_meta,
            )
        })?;
        #[cfg(not(feature = "metal"))]
        let logits = self.model.forward(
            &input_ids,
            pixel_values,
            &seqlen_offsets,
            context_lens,
            position_ids,
            model_specific_args,
            paged_attn_meta,
            &flash_meta,
        )?;
        if return_raw_logits {
            Ok(ForwardInputsResult::RawLogits { logits })
        } else {
            Ok(ForwardInputsResult::CausalGeneration { logits })
        }
    }

    /// Sample next tokens for each sequence from the produced logits.
    async fn sample_causal_gen(
        &self,
        seqs: &mut [&mut Sequence],
        logits: Vec<Tensor>,
        prefix_cacher: &mut PrefixCacheManagerV2,
        disable_eos_stop: bool,
        rng: Arc<std::sync::Mutex<Isaac64Rng>>,
    ) -> Result<(), candle_core::Error> {
        sample_and_add_toks(self, seqs, logits, prefix_cacher, disable_eos_stop, rng).await
    }

    /// This is a vision pipeline; report category with its prompt prefixer.
    fn category(&self) -> ModelCategory {
        let has_conv2d = self.model.has_conv2d();
        ModelCategory::Vision {
            has_conv2d,
            prefixer: self.prefixer.clone(),
        }
    }
}
887
impl AnyMoePipelineMixin for VisionPipeline {
    /// Finalize AnyMoE training, optionally persisting the gating model.
    fn amoe_finish_training(&mut self, gate_model_id: Option<String>) -> candle_core::Result<()> {
        self.model.finish_training(gate_model_id)
    }
    /// Trainable variables, grouped per layer.
    fn amoe_layer_vars(&self) -> Vec<Vec<Var>> {
        self.model.get_vars()
    }
    fn amoe_base_model_trainable_params(&self) -> usize {
        self.model.trainable_params()
    }
    fn amoe_take_cached_gating_outputs(&mut self) -> Vec<Tensor> {
        self.model.take_cached_gating_outputs()
    }
    /// Build AnyMoE expert layers: for each expert model ID, download its
    /// `.safetensors` files from the Hub (or cache) and load only the
    /// tensors whose keys match `match_regex` within the selected `layers`;
    /// optionally load a single-file gating model as well.
    fn amoe_create_layers(
        &mut self,
        model_ids: Vec<String>,
        token: &TokenSource,
        revision: Option<String>,
        match_regex: &str,
        config: crate::amoe::AnyMoeConfig,
        dtype: candle_core::DType,
        dev: &Device,
        (prefix, mlp): (String, String),
        layers: Vec<usize>,
        expert_type: AnyMoeExpertType,
        silent: bool,
        gate_model_id: Option<String>,
    ) -> candle_core::Result<()> {
        let mut vbs = Vec::new();
        // Compile the layer-matching regex once, outside the loop.
        let regex = Regex::new(match_regex).map_err(candle_core::Error::msg)?;
        for model_id in model_ids {
            let model_id_str = &model_id;
            let model_id = Path::new(&model_id);

            let api = {
                let cache = GLOBAL_HF_CACHE.get().cloned().unwrap_or_default();
                let mut api = ApiBuilder::from_cache(cache)
                    .with_progress(!silent)
                    .with_token(get_token(token).map_err(candle_core::Error::msg)?);
                if let Ok(x) = std::env::var("HF_HUB_CACHE") {
                    api = api.with_cache_dir(x.into());
                }
                api.build().map_err(candle_core::Error::msg)?
            };
            let revision = revision.clone().unwrap_or("main".to_string());
            let api = api.repo(Repo::with_revision(
                model_id_str.clone(),
                RepoType::Model,
                revision.clone(),
            ));

            let mut filenames = vec![];
            for rfilename in api_dir_list!(api, model_id).filter(|x| x.ends_with(".safetensors")) {
                filenames.push(api_get_file!(api, &rfilename, model_id));
            }

            let regex = regex.clone();
            let match_regex_clone = match_regex.to_string();
            let layers_clone = layers.clone();
            let vb = from_mmaped_safetensors(
                filenames,
                vec![],
                Some(dtype),
                dev,
                vec![None],
                silent,
                None,
                // Keep only tensors whose key matches the regex AND whose
                // layer number (parsed from the key segment just before the
                // match) is in `layers` (empty `layers` selects all).
                move |key| {
                    if regex.is_match(&key) {
                        // NOTE(review): assumes the match never starts at
                        // index 0 and is preceded by `.<layer_n>.` — the
                        // subtraction underflows otherwise. Confirm against
                        // the key format used by callers.
                        let last_layer_idx = key.find(&match_regex_clone).unwrap() - 1;
                        let first_layer_idx = key[..last_layer_idx].rfind('.').unwrap();
                        let layer_n = key[first_layer_idx + 1..last_layer_idx]
                            .parse::<usize>()
                            .unwrap();
                        layers_clone.contains(&layer_n) || layers_clone.is_empty()
                    } else {
                        false
                    }
                },
                Arc::new(|_| DeviceForLoadTensor::Base),
            )?;
            vbs.push(vb);
        }

        // Optional gating model: must consist of exactly one safetensors
        // file.
        let gate_vb = if let Some(gate_model_id) = gate_model_id {
            let model_id_str = &gate_model_id;
            let model_id = Path::new(&gate_model_id);

            let api = {
                let cache = GLOBAL_HF_CACHE.get().cloned().unwrap_or_default();
                let mut api = ApiBuilder::from_cache(cache)
                    .with_progress(!silent)
                    .with_token(get_token(token).map_err(candle_core::Error::msg)?);
                if let Ok(x) = std::env::var("HF_HUB_CACHE") {
                    api = api.with_cache_dir(x.into());
                }
                api.build().map_err(candle_core::Error::msg)?
            };
            let revision = revision.clone().unwrap_or("main".to_string());
            let api = api.repo(Repo::with_revision(
                model_id_str.clone(),
                RepoType::Model,
                revision.clone(),
            ));

            let mut gate_filenames = vec![];
            for rfilename in api_dir_list!(api, model_id).filter(|x| x.ends_with(".safetensors")) {
                gate_filenames.push(api_get_file!(api, &rfilename, model_id));
            }
            assert_eq!(
                gate_filenames.len(),
                1,
                "Gate model ID must contain only one .safetensors file"
            );

            let vb = from_mmaped_safetensors(
                gate_filenames.clone(),
                vec![],
                Some(dtype),
                dev,
                vec![None],
                silent,
                None,
                |_| true,
                Arc::new(|_| DeviceForLoadTensor::Base),
            )?;
            info!(
                "Loaded gating layers from `{}`",
                gate_filenames[0].display()
            );
            Some(vb)
        } else {
            None
        };

        self.model
            .create_anymoe_layers(vbs, config, (prefix, mlp), layers, expert_type, gate_vb)
    }
    fn amoe_supported(&self) -> bool {
        self.model.amoe_supported()
    }
}