use super::cache_manager::{FullCacheManager, NormalCacheManager};
use super::isq::ImatrixDataSource;
use super::isq::UqffFullSer;
use super::{
    get_model_paths, get_xlora_paths, AdapterKind, AnyMoePipelineMixin, CacheManager,
    CacheManagerMixin, EitherCache, ForwardInputsResult, Gemma3Loader, GeneralMetadata,
    IsqPipelineMixin, Loader, MetadataMixin, MiniCpmOLoader, ModelCategory, ModelKind, ModelPaths,
    Phi4MMLoader, PreProcessingMixin, Processor, Qwen2VLLoader, TokenSource, VLlama4Loader,
    VLlamaLoader, VisionModel, VisionModelLoader, VisionPromptPrefixer,
};
use super::{
    Idefics2Loader, Idefics3Loader, LLaVALoader, LLaVANextLoader, Mistral3Loader, Phi3VLoader,
    Qwen2_5VLLoader, VisionLoaderType,
};
use crate::device_map::{self, DeviceMapper};
use crate::distributed::{self, WorkerTransferData};
use crate::paged_attention::{calculate_cache_config, AttentionImplementation, CacheEngine};
use crate::pipeline::chat_template::{calculate_eos_tokens, GenerationConfig};
use crate::pipeline::llg::build_llg_factory;
use crate::pipeline::sampling::sample_and_add_toks;
use crate::pipeline::text_models_inputs_processor::make_prompt_chunk;
use crate::pipeline::{get_chat_template, ChatTemplate, IsqOrganization, LocalModelPaths};
use crate::prefix_cacher::PrefixCacheManagerV2;
use crate::sequence::Sequence;
use crate::utils::tokenizer::get_tokenizer;
use crate::utils::varbuilder_utils::DeviceForLoadTensor;
use crate::utils::{tokens::get_token, varbuilder_utils::from_mmaped_safetensors};
use crate::vision_models::preprocessor_config::PreProcessorConfig;
use crate::vision_models::processor_config::ProcessorConfig;
use crate::vision_models::ModelInputs;
use crate::{
    api_dir_list, api_get_file, get_paths, get_uqff_paths, vision_normal_model_loader,
    vision_normal_model_loader_sharded, AnyMoeExpertType, DeviceMapSetting, Ordering,
    PagedAttentionConfig, Pipeline, Topology, TryIntoDType, GLOBAL_HF_CACHE,
};
use anyhow::Result;
use candle_core::{Device, Tensor, Var};
use hf_hub::Cache;
use hf_hub::{api::sync::ApiBuilder, Repo, RepoType};
use indicatif::MultiProgress;
use mistralrs_quant::{AfqLayer, GgufMatMul, HqqLayer, IsqType, QuantizedSerdeType};
use rand_isaac::Isaac64Rng;
use regex_automata::meta::Regex;
use std::any::Any;
use std::borrow::Cow;
use std::num::NonZeroUsize;
use std::path::{Path, PathBuf};
use std::str::FromStr;
use std::sync::{Arc, RwLock};
use std::time::Instant;
use std::{env, fs};
use tokenizers::Tokenizer;
use tokio::sync::Mutex;
use tracing::{info, warn};

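/// A fully loaded vision pipeline: the model itself plus everything needed to
/// preprocess images, tokenize and template prompts, and run generation.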
pub struct VisionPipeline {
    model: Box<dyn VisionModel + Send + Sync>,
    tokenizer: Arc<Tokenizer>,
    chat_template: Arc<ChatTemplate>,
    model_id: String,
    metadata: Arc<GeneralMetadata>,
    processor: Arc<dyn Processor + Send + Sync>,
    preprocessor_config: Arc<PreProcessorConfig>,
    topology: Option<Topology>,
    silent: bool,
    prefixer: Arc<dyn VisionPromptPrefixer>,
    mapper: Box<dyn DeviceMapper + Send + Sync>,

    template_filename: Option<PathBuf>,
    generation_config: Option<PathBuf>,
    config: String,
    processor_filename: Option<PathBuf>,
    preprocessor_filename: Option<PathBuf>,
    imatrix: Option<PathBuf>,
}

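/// A [`Loader`] for vision models, normally constructed via [`VisionLoaderBuilder`].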
pub struct VisionLoader {
    inner: Box<dyn VisionModelLoader>,
    model_id: String,
    config: VisionSpecificConfig,
    kind: ModelKind,
    chat_template: Option<String>,
    tokenizer_json: Option<String>,
    xlora_model_id: Option<String>,
    xlora_order: Option<Ordering>,
    token_source: RwLock<Option<TokenSource>>,
    revision: RwLock<Option<String>>,
    from_uqff: RwLock<Option<Vec<PathBuf>>>,
    jinja_explicit: Option<String>,
    hf_cache_path: Option<PathBuf>,
    lora_adapter_ids: Option<Vec<String>>,
}

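/// Builder for a vision model [`Loader`].
///
/// A minimal usage sketch; the model ID and loader type below are
/// placeholders, not recommendations:
///
/// ```ignore
/// let loader = VisionLoaderBuilder::new(
///     VisionSpecificConfig::default(),
///     None,                                  // chat template override
///     None,                                  // tokenizer.json override
///     Some("org/vision-model".to_string()),  // HF model ID or local path
///     None,                                  // explicit JINJA template
/// )
/// .build(VisionLoaderType::Idefics3);
/// ```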
#[derive(Default)]
pub struct VisionLoaderBuilder {
    model_id: Option<String>,
    config: VisionSpecificConfig,
    kind: ModelKind,
    chat_template: Option<String>,
    tokenizer_json: Option<String>,
    jinja_explicit: Option<String>,
    hf_cache_path: Option<PathBuf>,
    lora_adapter_ids: Option<Vec<String>>,
}

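/// Vision-specific loading options: ISQ/UQFF quantization, device topology,
/// prompt chunking, image edge limits, and imatrix calibration.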
#[derive(Clone, Default)]
pub struct VisionSpecificConfig {
    pub use_flash_attn: bool,
    pub prompt_chunksize: Option<NonZeroUsize>,
    pub topology: Option<Topology>,
    pub write_uqff: Option<PathBuf>,
    pub from_uqff: Option<Vec<PathBuf>>,
    pub max_edge: Option<u32>,
    pub imatrix: Option<PathBuf>,
    pub calibration_file: Option<PathBuf>,
    pub hf_cache_path: Option<PathBuf>,
}

impl VisionLoaderBuilder {
    pub fn new(
        config: VisionSpecificConfig,
        chat_template: Option<String>,
        tokenizer_json: Option<String>,
        model_id: Option<String>,
        jinja_explicit: Option<String>,
    ) -> Self {
        Self {
            config,
            chat_template,
            tokenizer_json,
            model_id,
            jinja_explicit,
            kind: ModelKind::Normal,
            hf_cache_path: None,
            ..Default::default()
        }
    }

    pub fn hf_cache_path(mut self, hf_cache_path: PathBuf) -> Self {
        self.hf_cache_path = Some(hf_cache_path);
        self
    }

    pub fn with_lora(mut self, lora_adapter_ids: Vec<String>) -> Self {
        self.kind = ModelKind::Adapter {
            adapter: AdapterKind::Lora,
        };
        self.lora_adapter_ids = Some(lora_adapter_ids);
        self
    }

    pub fn build(self, loader: VisionLoaderType) -> Box<dyn Loader> {
        let loader: Box<dyn VisionModelLoader> = match loader {
            VisionLoaderType::Phi3V => Box::new(Phi3VLoader),
            VisionLoaderType::Idefics2 => Box::new(Idefics2Loader),
            VisionLoaderType::LLaVANext => Box::new(LLaVANextLoader),
            VisionLoaderType::LLaVA => Box::new(LLaVALoader),
            VisionLoaderType::VLlama => Box::new(VLlamaLoader),
            VisionLoaderType::Qwen2VL => Box::new(Qwen2VLLoader),
            VisionLoaderType::Idefics3 => Box::new(Idefics3Loader),
            VisionLoaderType::MiniCpmO => Box::new(MiniCpmOLoader),
            VisionLoaderType::Phi4MM => Box::new(Phi4MMLoader),
            VisionLoaderType::Qwen2_5VL => Box::new(Qwen2_5VLLoader),
            VisionLoaderType::Gemma3 => Box::new(Gemma3Loader),
            VisionLoaderType::Mistral3 => Box::new(Mistral3Loader),
            VisionLoaderType::Llama4 => Box::new(VLlama4Loader),
        };
        Box::new(VisionLoader {
            inner: loader,
            model_id: self.model_id.unwrap(),
            config: self.config,
            kind: self.kind,
            chat_template: self.chat_template,
            tokenizer_json: self.tokenizer_json,
            xlora_model_id: None,
            xlora_order: None,
            jinja_explicit: self.jinja_explicit,
            token_source: RwLock::new(None),
            revision: RwLock::new(None),
            from_uqff: RwLock::new(None),
            hf_cache_path: self.hf_cache_path,
            lora_adapter_ids: self.lora_adapter_ids,
        })
    }
}

impl Loader for VisionLoader {
    #[allow(clippy::type_complexity, clippy::too_many_arguments)]
    fn load_model_from_hf(
        &self,
        revision: Option<String>,
        token_source: TokenSource,
        dtype: &dyn TryIntoDType,
        device: &Device,
        silent: bool,
        mapper: DeviceMapSetting,
        in_situ_quant: Option<IsqType>,
        paged_attn_config: Option<PagedAttentionConfig>,
    ) -> Result<Arc<Mutex<dyn Pipeline + Send + Sync>>> {
        let cache = self
            .hf_cache_path
            .clone()
            .map(Cache::new)
            .unwrap_or_default();
        GLOBAL_HF_CACHE.get_or_init(|| cache);

        let paths: anyhow::Result<Box<dyn ModelPaths>> = get_paths!(
            LocalModelPaths,
            &token_source,
            revision.clone(),
            self,
            None,
            None,
            silent,
            self.config.from_uqff.is_some()
        );
        if let Some(from_uqff) = self.config.from_uqff.clone() {
            *self.from_uqff.write().unwrap() = Some(get_uqff_paths!(&from_uqff, self, silent));
        }
        *self
            .token_source
            .write()
            .expect("Failed to write to token source") = Some(token_source);
        *self.revision.write().expect("Failed to write to revision") = revision;
        self.load_model_from_path(
            &paths?,
            dtype,
            device,
            silent,
            mapper,
            in_situ_quant,
            paged_attn_config,
        )
    }

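    // Loading proceeds in stages: resolve the device map, construct the model
    // (optionally sharded via NCCL), optionally collect imatrix calibration
    // data, apply ISQ or load UQFF artifacts, then set up PagedAttention.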
    #[allow(clippy::type_complexity, clippy::too_many_arguments)]
    fn load_model_from_path(
        &self,
        paths: &Box<dyn ModelPaths>,
        dtype: &dyn TryIntoDType,
        device: &Device,
        silent: bool,
        mut mapper: DeviceMapSetting,
        in_situ_quant: Option<IsqType>,
        mut paged_attn_config: Option<PagedAttentionConfig>,
    ) -> Result<Arc<Mutex<dyn Pipeline + Send + Sync>>> {
        let config = std::fs::read_to_string(paths.get_config_filename())?;

        if !self.inner.supports_paged_attention() {
            paged_attn_config = None;
        }

        let use_nccl = mistralrs_quant::distributed::use_nccl();

        let available_devices = if let Ok(payload) = env::var(distributed::IS_DAEMON_FLAG) {
            let payload: WorkerTransferData = serde_json::from_str(&payload)?;
            let WorkerTransferData::Init { id: _, worker_rank } = payload;
            vec![candle_core::Device::new_cuda_with_stream(worker_rank + 1)?]
        } else if use_nccl {
            vec![candle_core::Device::new_cuda_with_stream(0)?]
        } else {
            device_map::get_all_similar_devices(device)?
        };
        let device = if use_nccl {
            available_devices[0].clone()
        } else {
            device.clone()
        };

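        // With NCCL, tensor parallelism replaces device mapping. Otherwise an
        // Auto map is resolved from estimated per-layer sizes; for UQFF models
        // the average pack factor over the serialized tensors approximates the
        // quantized weight footprint.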
        if use_nccl {
            mapper = DeviceMapSetting::DummyNccl {
                nm_device: available_devices[0].clone(),
            };
        } else if let DeviceMapSetting::Auto(params) = mapper.clone() {
            let dtype = dtype.try_into_dtype(&available_devices.iter().collect::<Vec<_>>())?;

            let (layer_sizes_in_bytes, non_mapped_size_in_bytes, total_model_size_in_bytes) =
                if let Some(serialized) = &*self.from_uqff.read().unwrap() {
                    let weight_pack_factor = {
                        let ser_artifacts = unsafe {
                            candle_core::safetensors::MmapedSafetensors::multi(serialized)?
                        };
                        let mut total_pack_factors = 0;
                        let total_tensors = ser_artifacts.tensors().len();
                        for (_, artifact) in ser_artifacts.tensors() {
                            let artifact = artifact.data();
                            let isq_type = artifact[mistralrs_quant::UQFF_QUANT_TYPE_OFFSET];
                            let pack_factor = match QuantizedSerdeType::try_from(isq_type as usize)?
                            {
                                QuantizedSerdeType::Hqq => {
                                    HqqLayer::get_isq_type_from_uqff(Cow::Borrowed(artifact))?
                                        .pack_factor(dtype)
                                }
                                QuantizedSerdeType::Gguf => {
                                    GgufMatMul::get_isq_type_from_uqff(Cow::Borrowed(artifact))?
                                        .pack_factor(dtype)
                                }
                                QuantizedSerdeType::Fp8 => IsqType::F8E4M3.pack_factor(dtype),
                                QuantizedSerdeType::Unquant => 1,
                                QuantizedSerdeType::Afq => {
                                    AfqLayer::get_isq_type_from_uqff(Cow::Borrowed(artifact))?
                                        .pack_factor(dtype)
                                }
                            };
                            total_pack_factors += pack_factor;
                        }

                        total_pack_factors / total_tensors
                    };

                    let layer_sizes_in_bytes =
                        self.inner
                            .layer_sizes_in_bytes(&config, dtype, weight_pack_factor)?;
                    let non_mapped_size_in_bytes =
                        self.inner
                            .non_mapped_size_in_bytes(&config, dtype, weight_pack_factor)?;
                    let layer_sizes_sum = layer_sizes_in_bytes.iter().sum::<usize>();
                    (
                        layer_sizes_in_bytes,
                        non_mapped_size_in_bytes,
                        layer_sizes_sum + non_mapped_size_in_bytes,
                    )
                } else if let Some(isq) = in_situ_quant {
                    let weight_pack_factor = isq.pack_factor(dtype);
                    let layer_sizes_in_bytes =
                        self.inner
                            .layer_sizes_in_bytes(&config, dtype, weight_pack_factor)?;
                    let non_mapped_size_in_bytes =
                        self.inner
                            .non_mapped_size_in_bytes(&config, dtype, weight_pack_factor)?;
                    let layer_sizes_sum = layer_sizes_in_bytes.iter().sum::<usize>();
                    (
                        layer_sizes_in_bytes,
                        non_mapped_size_in_bytes,
                        layer_sizes_sum + non_mapped_size_in_bytes,
                    )
                } else {
                    let layer_sizes_in_bytes =
                        self.inner.layer_sizes_in_bytes(&config, dtype, 1)?;
                    let non_mapped_size_in_bytes =
                        self.inner.non_mapped_size_in_bytes(&config, dtype, 1)?;
                    let layer_sizes_sum = layer_sizes_in_bytes.iter().sum::<usize>();
                    (
                        layer_sizes_in_bytes,
                        non_mapped_size_in_bytes,
                        layer_sizes_sum + non_mapped_size_in_bytes,
                    )
                };

            let new = self.inner.get_device_layers(
                &config,
                self.inner.num_layers(&config)?,
                layer_sizes_in_bytes,
                non_mapped_size_in_bytes,
                total_model_size_in_bytes,
                &available_devices,
                dtype,
                &params,
                params.max_seq_len(),
                paged_attn_config.as_ref(),
            )?;
            mapper = DeviceMapSetting::Map(new);
        }

        let pipeline_mapper = mapper.into_mapper(
            self.inner.num_layers(&config)?,
            &device,
            self.config.topology.as_ref(),
        )?;
        let mapper = mapper.into_mapper(
            self.inner.num_layers(&config)?,
            &device,
            self.config.topology.as_ref(),
        )?;
        let mut layer_devices = Vec::new();
        for layer in 0..self.inner.num_layers(&config)? {
            let device = mapper.device_for(layer, false).cloned();
            layer_devices.push(device);
        }
        let dtype = mapper.get_min_dtype(dtype)?;

        let mapping_uses_cpu = mapper.get_unique_devices().iter().any(Device::is_cpu);
        if mapping_uses_cpu {
            warn!("Device mapping contains a mix of GPU and CPU. There is no CPU support for PagedAttention; disabling PagedAttention.");
            paged_attn_config = None;
        }

        info!(
            "Model config: {:?}",
            self.inner
                .get_config_repr(&config, self.config.use_flash_attn)?
        );

        let mut loading_isq = in_situ_quant.is_some() || self.config.from_uqff.is_some();
        if let Some(ref topology) = self.config.topology {
            loading_isq |= topology
                .0
                .iter()
                .any(|layer| layer.as_ref().is_some_and(|layer| layer.isq.is_some()));
        }

        if self.config.imatrix.is_some() && self.config.calibration_file.is_some() {
            anyhow::bail!(
                "`imatrix` and `calibration_file` were both specified; this is not allowed."
            );
        }

        let load_device = if !loading_isq || self.config.calibration_file.is_some() {
            loading_isq = false;
            device.clone()
        } else {
            Device::Cpu
        };

        let attention_mechanism = if paged_attn_config.is_some() {
            AttentionImplementation::PagedAttention
        } else {
            AttentionImplementation::Eager
        };

        let multi_progress = Arc::new(MultiProgress::new());

        let mut model = if use_nccl {
            let (mapper, sharded_vb) = distributed::prepare_distributed_mapper(
                dtype,
                &device,
                &available_devices,
                silent,
                &config,
                loading_isq,
                self.config.from_uqff.is_some(),
                IsqOrganization::Default,
                &*self.inner,
                paths.as_ref(),
            )?;

            match self.kind {
                ModelKind::Normal => vision_normal_model_loader_sharded!(
                    sharded_vb,
                    config,
                    self.inner,
                    self.config.use_flash_attn,
                    mapper,
                    loading_isq,
                    device.clone(),
                    attention_mechanism,
                    multi_progress.clone(),
                ),
                _ => unreachable!(),
            }
        } else {
            match self.kind {
                ModelKind::Normal => vision_normal_model_loader!(
                    paths,
                    Some(dtype),
                    &load_device,
                    layer_devices.clone(),
                    config,
                    self.inner,
                    self.config.use_flash_attn,
                    silent,
                    mapper,
                    loading_isq,
                    self.config.from_uqff.is_some(),
                    device.clone(),
                    attention_mechanism,
                    multi_progress,
                ),
                _ => unreachable!(),
            }
        };

        let preprocessor_config: PreProcessorConfig = match paths.get_preprocessor_config().as_ref()
        {
            Some(preprocessor_config) => {
                serde_json::from_str(&fs::read_to_string(preprocessor_config).unwrap()).unwrap()
            }
            None => PreProcessorConfig::default(),
        };
        let processor_config: Option<ProcessorConfig> = paths
            .get_processor_config()
            .as_ref()
            .map(|f| serde_json::from_str(&fs::read_to_string(f).unwrap()).unwrap());

        let processor = self.inner.get_processor(
            &config,
            processor_config,
            preprocessor_config.clone(),
            self.config.max_edge,
        );

        let tokenizer = get_tokenizer(
            paths.get_tokenizer_filename(),
            Some(processor.get_special_tokens()),
        )?;

        let gen_conf: Option<GenerationConfig> = paths.get_gen_conf_filename().map(|f| {
            serde_json::from_str(&fs::read_to_string(f).unwrap())
                .expect("bos_token_id/eos_token_id missing in generation_config.json")
        });
        let chat_template = get_chat_template(
            paths,
            &self.jinja_explicit,
            &paths
                .get_chat_template_explicit()
                .as_ref()
                .map(|x| x.to_string_lossy().to_string())
                .clone(),
            &self.chat_template,
            None,
        );

        if let Some(calibration_file) = &self.config.calibration_file {
            let calibration_data = std::fs::read_to_string(calibration_file)?;
            let tokens = tokenizer
                .encode_fast(calibration_data, false)
                .map_err(anyhow::Error::msg)?
                .get_ids()
                .to_vec();
            info!(
                "Collecting imatrix from calibration file `{}` of {} tokens.",
                calibration_file.display(),
                tokens.len()
            );
            let bos_toks = chat_template.bos_tok().map(|b| vec![b]).unwrap_or_default();
            let bos_tok_id = tokenizer
                .token_to_id(&bos_toks[0])
                .expect("Somehow the bos token is not present.");

            model.begin_track_stats()?;

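            // Run the calibration text through the model in fixed-size chunks,
            // prepending BOS to each chunk and clearing the KV cache between
            // chunks so that only imatrix statistics accumulate.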
            const CHUNK_SIZE: usize = 1024;
            let n_chunks: usize = tokens.len().div_ceil(CHUNK_SIZE);
            let start = Instant::now();
            for (i, chunk) in tokens.chunks(CHUNK_SIZE).enumerate() {
                let chunk = [vec![bos_tok_id], chunk.to_vec()].concat();
                let chunk_len = chunk.len();

                let start = Instant::now();
                let inputs = make_prompt_chunk(
                    0,
                    vec![&chunk],
                    &[0],
                    &load_device,
                    None,
                    false,
                    None,
                    None,
                )?;
                let _ = model.forward(
                    &inputs.input,
                    None,
                    &inputs.positions,
                    inputs.context_lens,
                    inputs.position_ids,
                    model.default_model_specific_args(&inputs.input),
                    None,
                    &inputs.flash_meta,
                )?;
                match model.cache_mut() {
                    EitherCache::Full(full) => {
                        for layer in &mut *full.lock() {
                            *layer = None
                        }
                    }
                    EitherCache::Normal(normal) => {
                        for layer in &mut *normal.lock().unwrap().0 {
                            layer.reset();
                        }
                    }
                }
                let end = Instant::now();
                info!(
                    "Processed chunk {}/{n_chunks} ({chunk_len} tokens), {:.2}s",
                    i + 1,
                    end.duration_since(start).as_secs_f32()
                );
            }
            load_device.synchronize()?;
            let end = Instant::now();
            info!(
                "Finished collecting imatrix in {:.2}s",
                end.duration_since(start).as_secs_f32()
            );
        }

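        // Apply ISQ now if requested and no UQFF artifacts are in play;
        // otherwise load the pre-quantized weights from the UQFF artifacts.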
        if (in_situ_quant.is_some() || self.config.topology.is_some())
            && self.config.from_uqff.is_none()
        {
            let imatrix_source = match (
                self.config.imatrix.as_ref(),
                self.config.calibration_file.is_some(),
            ) {
                (None, false) => None,
                (Some(file), false) => Some(ImatrixDataSource::File(file)),
                (None, true) => Some(ImatrixDataSource::Collected),
                (Some(_), true) => unreachable!(),
            };
            model.quantize(
                in_situ_quant,
                device.clone(),
                self.config.topology.as_ref(),
                silent,
                imatrix_source,
                IsqOrganization::Default,
                self.config.write_uqff.as_ref(),
                UqffFullSer {
                    tokenizer: &tokenizer,
                    template_filename: paths.get_template_filename(),
                    generation_config: paths.get_gen_conf_filename(),
                    config: config.clone(),
                    processor_filename: paths.get_processor_config(),
                    preprocessor_filename: paths.get_preprocessor_config(),
                },
                Arc::new(MultiProgress::new()),
            )?;
        } else if let Some(from_uqff) = &*self.from_uqff.read().unwrap() {
            model.load_from_artifacts(
                device.clone(),
                self.config.topology.as_ref(),
                silent,
                from_uqff,
            )?;
        }

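        // Size the PagedAttention KV cache from the configured memory budgets
        // and block size, then construct the cache engine.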
        let (cache_config, cache_engine) = if let Some(paged_attn_config) = paged_attn_config {
            anyhow::ensure!(
                !matches!(self.kind, ModelKind::Adapter { .. }),
                "PagedAttention does not support adapter models."
            );
            let cache_config = calculate_cache_config(
                paged_attn_config.mem_gpu,
                paged_attn_config.mem_cpu,
                paged_attn_config.block_size,
                dtype,
                model.config(),
                &device,
                &layer_devices,
                silent,
            )?;
            let cache_engine =
                CacheEngine::new(model.config(), &cache_config, dtype, &device, layer_devices)?;
            (Some(cache_config), Some(cache_engine))
        } else {
            (None, None)
        };

        let max_seq_len = model.max_seq_len();
        let llg_factory = build_llg_factory(tokenizer.clone())?;
        let num_hidden_layers = match model.cache() {
            EitherCache::Full(full) => full.lock().len(),
            EitherCache::Normal(normal) => normal.lock().unwrap().0.len(),
        };
        let eos = calculate_eos_tokens(&chat_template, gen_conf, &tokenizer);
        let sliding_window = model.config().sliding_window;
        let model_metadata = Arc::new(model.config().clone());
        Ok(Arc::new(Mutex::new(VisionPipeline {
            model,
            tokenizer: tokenizer.into(),
            chat_template: Arc::new(chat_template),
            model_id: self.model_id.clone(),
            metadata: Arc::new(GeneralMetadata {
                max_seq_len,
                llg_factory: Some(llg_factory),
                is_xlora: false,
                num_hidden_layers,
                eos_tok: eos,
                kind: self.kind.clone(),
                no_kv_cache: false,
                no_prefix_cache: false,
                activation_dtype: dtype,
                sliding_window,
                cache_config,
                cache_engine,
                prompt_chunksize: self.config.prompt_chunksize,
                model_metadata: Some(model_metadata),
            }),
            processor,
            prefixer: self.inner.prefixer(),
            preprocessor_config: Arc::new(preprocessor_config),
            topology: self.config.topology.clone(),
            silent,
            template_filename: paths.get_template_filename().clone(),
            generation_config: paths.get_gen_conf_filename().cloned(),
            config,
            processor_filename: paths.get_processor_config().clone(),
            preprocessor_filename: paths.get_preprocessor_config().clone(),
            mapper: pipeline_mapper,
            imatrix: self.config.imatrix.clone(),
        })))
    }

    fn get_id(&self) -> String {
        self.model_id.to_string()
    }

    fn get_kind(&self) -> ModelKind {
        self.kind.clone()
    }
}

impl PreProcessingMixin for VisionPipeline {
    fn get_chat_template(&self) -> Option<Arc<ChatTemplate>> {
        Some(self.chat_template.clone())
    }
    fn get_input_processor_config(&self) -> Option<Arc<dyn Any>> {
        Some(self.preprocessor_config.clone())
    }
    fn get_processor(&self) -> Arc<dyn super::Processor> {
        self.processor.clone()
    }
}

impl IsqPipelineMixin for VisionPipeline {
    fn re_isq_model(&mut self, dtype: IsqType) -> Result<()> {
        let device = self.device().clone();
        self.model
            .quantize(
                Some(dtype),
                device,
                self.topology.as_ref(),
                self.silent,
                self.imatrix.as_ref().map(ImatrixDataSource::File),
                IsqOrganization::Default,
                None,
                UqffFullSer {
                    tokenizer: &self.tokenizer,
                    template_filename: &self.template_filename,
                    generation_config: self.generation_config.as_ref(),
                    config: self.config.clone(),
                    processor_filename: &self.processor_filename,
                    preprocessor_filename: &self.preprocessor_filename,
                },
                Arc::new(MultiProgress::new()),
            )
            .map_err(anyhow::Error::msg)
    }
}

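// Cache management dispatches on the model's cache flavor: `EitherCache::Full`
// is handled by `FullCacheManager`, `EitherCache::Normal` by `NormalCacheManager`.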
impl CacheManagerMixin for VisionPipeline {
    fn clone_in_cache(&self, seqs: &mut [&mut Sequence]) {
        if matches!(self.model.cache(), EitherCache::Full(_)) {
            FullCacheManager.clone_in_cache(self, seqs, false)
        } else {
            NormalCacheManager.clone_in_cache(self, seqs, false)
        }
    }
    fn clone_out_cache(&self, seqs: &mut [&mut Sequence]) {
        if matches!(self.model.cache(), EitherCache::Full(_)) {
            FullCacheManager.clone_out_cache(self, seqs, false)
        } else {
            NormalCacheManager.clone_out_cache(self, seqs, false)
        }
    }
    fn set_none_cache(
        &self,
        seqs: &mut [&mut Sequence],
        reset_non_granular: bool,
        modify_draft_cache: bool,
        load_preallocated_cache: bool,
    ) {
        if matches!(self.model.cache(), EitherCache::Full(_)) {
            FullCacheManager.set_none_cache(self, seqs, modify_draft_cache, false);
        } else {
            NormalCacheManager.set_none_cache(
                self,
                seqs,
                modify_draft_cache,
                load_preallocated_cache,
            );
        }
        if reset_non_granular {
            self.reset_non_granular_state()
        }
    }
    fn cache(&self) -> &EitherCache {
        self.model.cache()
    }
}

impl MetadataMixin for VisionPipeline {
    fn device(&self) -> Device {
        self.model.device().clone()
    }
    fn get_metadata(&self) -> Arc<GeneralMetadata> {
        self.metadata.clone()
    }
    fn name(&self) -> String {
        self.model_id.clone()
    }
    fn reset_non_granular_state(&self) {}
    fn tokenizer(&self) -> Option<Arc<Tokenizer>> {
        Some(self.tokenizer.clone())
    }
    fn device_mapper(&self) -> Option<&dyn DeviceMapper> {
        Some(&*self.mapper)
    }
}

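// Core inference entry points: `forward_inputs` runs a single forward pass,
// threading the PagedAttention KV cache through when a cache engine exists;
// `sample_causal_gen` samples next tokens from the resulting logits.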
#[async_trait::async_trait]
impl Pipeline for VisionPipeline {
    fn forward_inputs(
        &mut self,
        inputs: Box<dyn Any>,
        return_raw_logits: bool,
    ) -> candle_core::Result<ForwardInputsResult> {
        let ModelInputs {
            input_ids,
            seqlen_offsets,
            context_lens,
            position_ids,
            pixel_values,
            model_specific_args,
            paged_attn_meta,
            flash_meta,
        } = *inputs.downcast::<ModelInputs>().expect("Downcast failed.");
        let metadata = self.get_metadata();
        let paged_attn_meta = match (&metadata.cache_engine, &paged_attn_meta) {
            (Some(engine), Some(meta)) => Some((engine.get_kv_cache().clone(), meta)),
            (Some(_), None) => {
                candle_core::bail!("Forward step expected PagedAttention input metadata, but none was provided. Please ensure that the scheduler config is correctly configured for PagedAttention.")
            }
            (None, Some(_)) => {
                candle_core::bail!("Forward step got PagedAttention input metadata, but there is no cache engine. Please raise an issue.")
            }
            (None, None) => None,
        };
        let logits = self.model.forward(
            &input_ids,
            pixel_values,
            &seqlen_offsets,
            context_lens,
            position_ids,
            model_specific_args,
            paged_attn_meta,
            &flash_meta,
        )?;
        if return_raw_logits {
            Ok(ForwardInputsResult::RawLogits { logits })
        } else {
            Ok(ForwardInputsResult::CausalGeneration { logits })
        }
    }
    async fn sample_causal_gen(
        &self,
        seqs: &mut [&mut Sequence],
        logits: Vec<Tensor>,
        prefix_cacher: &mut PrefixCacheManagerV2,
        disable_eos_stop: bool,
        rng: Arc<std::sync::Mutex<Isaac64Rng>>,
    ) -> Result<(), candle_core::Error> {
        sample_and_add_toks(self, seqs, logits, prefix_cacher, disable_eos_stop, rng).await
    }
    fn category(&self) -> ModelCategory {
        let has_conv2d = self.model.has_conv2d();
        ModelCategory::Vision {
            has_conv2d,
            prefixer: self.prefixer.clone(),
        }
    }
}

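// AnyMoE support: expert layers are built from per-expert safetensors
// checkpoints whose tensor names match `match_regex`, restricted to the
// requested layer indices (an empty `layers` list selects all layers).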
impl AnyMoePipelineMixin for VisionPipeline {
    fn amoe_finish_training(&mut self, gate_model_id: Option<String>) -> candle_core::Result<()> {
        self.model.finish_training(gate_model_id)
    }
    fn amoe_layer_vars(&self) -> Vec<Vec<Var>> {
        self.model.get_vars()
    }
    fn amoe_base_model_trainable_params(&self) -> usize {
        self.model.trainable_params()
    }
    fn amoe_take_cached_gating_outputs(&mut self) -> Vec<Tensor> {
        self.model.take_cached_gating_outputs()
    }
    fn amoe_create_layers(
        &mut self,
        model_ids: Vec<String>,
        token: &TokenSource,
        revision: Option<String>,
        match_regex: &str,
        config: crate::amoe::AnyMoeConfig,
        dtype: candle_core::DType,
        dev: &Device,
        (prefix, mlp): (String, String),
        layers: Vec<usize>,
        expert_type: AnyMoeExpertType,
        silent: bool,
        gate_model_id: Option<String>,
    ) -> candle_core::Result<()> {
        let mut vbs = Vec::new();
        let regex = Regex::new(match_regex).map_err(candle_core::Error::msg)?;
        for model_id in model_ids {
            let model_id_str = &model_id;
            let model_id = Path::new(&model_id);

            let api = {
                let cache = GLOBAL_HF_CACHE.get().cloned().unwrap_or_default();
                let mut api = ApiBuilder::from_cache(cache)
                    .with_progress(!silent)
                    .with_token(get_token(token).map_err(candle_core::Error::msg)?);
                if let Ok(x) = std::env::var("HF_HUB_CACHE") {
                    api = api.with_cache_dir(x.into());
                }
                api.build().map_err(candle_core::Error::msg)?
            };
            let revision = revision.clone().unwrap_or("main".to_string());
            let api = api.repo(Repo::with_revision(
                model_id_str.clone(),
                RepoType::Model,
                revision.clone(),
            ));

            let mut filenames = vec![];
            for rfilename in api_dir_list!(api, model_id).filter(|x| x.ends_with(".safetensors")) {
                filenames.push(api_get_file!(api, &rfilename, model_id));
            }

            let regex = regex.clone();
            let match_regex_clone = match_regex.to_string();
            let layers_clone = layers.clone();
            let vb = from_mmaped_safetensors(
                filenames,
                vec![],
                Some(dtype),
                dev,
                vec![None],
                silent,
                None,
                move |key| {
                    if regex.is_match(&key) {
                        // Extract the layer index from the tensor name; it sits
                        // directly before the matched suffix.
                        let last_layer_idx = key.find(&match_regex_clone).unwrap() - 1;
                        let first_layer_idx = key[..last_layer_idx].rfind('.').unwrap();
                        let layer_n = key[first_layer_idx + 1..last_layer_idx]
                            .parse::<usize>()
                            .unwrap();
                        layers_clone.contains(&layer_n) || layers_clone.is_empty()
                    } else {
                        false
                    }
                },
                Arc::new(|_| DeviceForLoadTensor::Base),
            )?;
            vbs.push(vb);
        }

        let gate_vb = if let Some(gate_model_id) = gate_model_id {
            let model_id_str = &gate_model_id;
            let model_id = Path::new(&gate_model_id);

            let api = {
                let cache = GLOBAL_HF_CACHE.get().cloned().unwrap_or_default();
                let mut api = ApiBuilder::from_cache(cache)
                    .with_progress(!silent)
                    .with_token(get_token(token).map_err(candle_core::Error::msg)?);
                if let Ok(x) = std::env::var("HF_HUB_CACHE") {
                    api = api.with_cache_dir(x.into());
                }
                api.build().map_err(candle_core::Error::msg)?
            };
            let revision = revision.clone().unwrap_or("main".to_string());
            let api = api.repo(Repo::with_revision(
                model_id_str.clone(),
                RepoType::Model,
                revision.clone(),
            ));

            let mut gate_filenames = vec![];
            for rfilename in api_dir_list!(api, model_id).filter(|x| x.ends_with(".safetensors")) {
                gate_filenames.push(api_get_file!(api, &rfilename, model_id));
            }
            assert_eq!(
                gate_filenames.len(),
                1,
                "Gate model ID must contain only one .safetensors file"
            );

            let vb = from_mmaped_safetensors(
                gate_filenames.clone(),
                vec![],
                Some(dtype),
                dev,
                vec![None],
                silent,
                None,
                |_| true,
                Arc::new(|_| DeviceForLoadTensor::Base),
            )?;
            info!(
                "Loaded gating layers from `{}`",
                gate_filenames[0].display()
            );
            Some(vb)
        } else {
            None
        };

        self.model
            .create_anymoe_layers(vbs, config, (prefix, mlp), layers, expert_type, gate_vb)
    }
    fn amoe_supported(&self) -> bool {
        self.model.amoe_supported()
    }
}