1use super::inputs_processor::DEFAULT_PROMPT_CHUNK_SIZE;
2use super::isq::ImatrixDataSource;
3use super::llg::build_llg_factory;
4use super::{
5 get_model_paths, get_xlora_paths, text_models_inputs_processor::ModelInputs, AdapterKind,
6 CacheManager, GeneralMetadata, Loader, ModelKind, ModelPaths, NormalModel, NormalModelLoader,
7 TokenSource,
8};
9use super::{
10 AnyMoePipelineMixin, CacheManagerMixin, EitherCache, ForwardInputsResult, IsqOrganization,
11 IsqPipelineMixin, MetadataMixin, ModelCategory, PreProcessingMixin,
12};
13use super::{
14 AutoNormalLoader, DeepSeekV2Loader, DeepSeekV3Loader, GLM4Loader, Gemma2Loader, GemmaLoader,
15 LlamaLoader, MistralLoader, MixtralLoader, NormalLoaderType, Phi2Loader, Phi3Loader,
16 Phi3_5MoELoader, Qwen2Loader, Qwen3Loader, Qwen3MoELoader, Starcoder2Loader,
17};
18use crate::amoe::AnyMoeExpertType;
19use crate::device_map::{self, DeviceMapper};
20use crate::distributed::{self, WorkerTransferData};
21use crate::kv_cache::{FullCacheManager, NormalCacheManager};
22use crate::lora::Ordering;
23use crate::paged_attention::{calculate_cache_config, AttentionImplementation, CacheEngine};
24use crate::pipeline::chat_template::{calculate_eos_tokens, GenerationConfig};
25use crate::pipeline::isq::UqffFullSer;
26use crate::pipeline::loaders::auto_device_map;
27use crate::pipeline::loaders::QuantizationConfigShim;
28use crate::pipeline::sampling::sample_and_add_toks;
29use crate::pipeline::text_models_inputs_processor::make_prompt_chunk;
30use crate::pipeline::{get_chat_template, Modalities, SupportedModality};
31use crate::pipeline::{ChatTemplate, LocalModelPaths};
32use crate::prefix_cacher::PrefixCacheManagerV2;
33use crate::sequence::Sequence;
34use crate::utils::tokenizer::get_tokenizer;
35use crate::utils::varbuilder_utils::DeviceForLoadTensor;
36use crate::utils::{tokens::get_token, varbuilder_utils::from_mmaped_safetensors};
37use crate::xlora_models::NonGranularState;
38use crate::{
39 api_dir_list, api_get_file, get_mut_arcmutex, get_paths, get_uqff_paths, lora_model_loader,
40 normal_model_loader, normal_model_loader_sharded, xlora_model_loader, DeviceMapSetting,
41 PagedAttentionConfig, Pipeline, Topology, TryIntoDType, GLOBAL_HF_CACHE,
42};
43use anyhow::Result;
44use candle_core::{Device, Tensor, Var};
45use hf_hub::Cache;
46use hf_hub::{api::sync::ApiBuilder, Repo, RepoType};
47use indicatif::MultiProgress;
48use mistralrs_quant::log::once_log_info;
49use mistralrs_quant::{AfqLayer, GgufMatMul, HqqLayer, IsqType, QuantizedSerdeType};
50use rand_isaac::Isaac64Rng;
51use regex_automata::meta::Regex;
52use std::any::Any;
53use std::borrow::Cow;
54use std::num::{NonZero, NonZeroUsize};
55use std::path::{Path, PathBuf};
56use std::str::FromStr;
57use std::sync::{Arc, RwLock};
58use std::time::Instant;
59use std::{env, fs};
60use tokenizers::Tokenizer;
61use tokio::sync::Mutex;
62use tracing::{info, warn};
63
/// A fully-loaded text-generation pipeline for a "normal" (plain-text) model,
/// bundling the model, tokenizer, chat template and runtime metadata.
pub struct NormalPipeline {
    model: Box<dyn NormalModel + Send + Sync>,
    tokenizer: Arc<Tokenizer>,
    // Disables the KV cache; forwarded to the X-LoRA forward pass.
    no_kv_cache: bool,
    chat_template: Arc<ChatTemplate>,
    // X-LoRA non-granular scaling state, present only when a target
    // non-granular index was configured on the loader.
    non_granular_state: Option<NonGranularState>,
    model_id: String,
    metadata: Arc<GeneralMetadata>,
    topology: Option<Topology>,
    silent: bool,
    // How ISQ is applied (all layers vs. MoE experts only); used by `re_isq_model`.
    organization: IsqOrganization,
    // Template/generation-config paths retained for UQFF serialization on re-ISQ.
    template_filename: Option<PathBuf>,
    generation_config: Option<PathBuf>,
    // Raw model config JSON as read from disk.
    config: String,
    // Optional imatrix data file used when re-applying ISQ.
    imatrix: Option<PathBuf>,
    mapper: Box<dyn DeviceMapper + Send + Sync>,
}
82
/// A [`Loader`] for "normal" (plain-text) models, optionally carrying X-LoRA
/// or LoRA adapter configuration.
pub struct NormalLoader {
    // Architecture-specific loader implementation (selected in the builder).
    inner: Box<dyn NormalModelLoader>,
    model_id: String,
    config: NormalSpecificConfig,
    xlora_model_id: Option<String>,
    lora_adapter_ids: Option<Vec<String>>,
    kind: ModelKind,
    xlora_order: Option<Ordering>,
    no_kv_cache: bool,
    chat_template: Option<String>,
    tokenizer_json: Option<String>,
    tgt_non_granular_index: Option<usize>,
    // The three RwLock fields below are written during `load_model_from_hf`
    // so later loading stages can read them through `&self`.
    token_source: RwLock<Option<TokenSource>>,
    revision: RwLock<Option<String>>,
    from_uqff: RwLock<Option<Vec<PathBuf>>>,
    // Explicit JINJA chat template path passed through to `get_chat_template`.
    jinja_explicit: Option<String>,
    hf_cache_path: Option<PathBuf>,
}
102
/// Builder for a [`NormalLoader`].
#[derive(Default)]
pub struct NormalLoaderBuilder {
    // May remain `None` until an adapter ordering supplies a base model ID.
    model_id: Option<String>,
    config: NormalSpecificConfig,
    xlora_model_id: Option<String>,
    lora_adapter_ids: Option<Vec<String>>,
    kind: ModelKind,
    xlora_order: Option<Ordering>,
    no_kv_cache: bool,
    chat_template: Option<String>,
    tokenizer_json: Option<String>,
    tgt_non_granular_index: Option<usize>,
    jinja_explicit: Option<String>,
    hf_cache_path: Option<PathBuf>,
}
119
/// Configuration options specific to loading "normal" (plain-text) models.
#[derive(Clone, Default)]
pub struct NormalSpecificConfig {
    // Number of tokens per prompt chunk; falls back to
    // `DEFAULT_PROMPT_CHUNK_SIZE` when `None`.
    pub prompt_chunksize: Option<NonZeroUsize>,
    pub topology: Option<Topology>,
    // ISQ application scope (all layers vs. MoE experts only).
    pub organization: IsqOrganization,
    // Destination for serializing quantized weights in UQFF format.
    pub write_uqff: Option<PathBuf>,
    // Previously-serialized UQFF artifacts to load instead of quantizing.
    pub from_uqff: Option<Vec<PathBuf>>,
    // Pre-computed imatrix data file for ISQ (mutually exclusive with
    // `calibration_file`; enforced at load time).
    pub imatrix: Option<PathBuf>,
    // Calibration text used to collect an imatrix during loading.
    pub calibration_file: Option<PathBuf>,
    pub hf_cache_path: Option<PathBuf>,
}
132
133impl NormalLoaderBuilder {
134 pub fn new(
135 config: NormalSpecificConfig,
136 chat_template: Option<String>,
137 tokenizer_json: Option<String>,
138 model_id: Option<String>,
139 no_kv_cache: bool,
140 jinja_explicit: Option<String>,
141 ) -> Self {
142 Self {
143 config,
144 chat_template,
145 tokenizer_json,
146 model_id,
147 kind: ModelKind::Normal,
148 jinja_explicit,
149 no_kv_cache,
150 ..Default::default()
151 }
152 }
153
154 fn with_adapter(
155 mut self,
156 xlora_model_id: String,
157 xlora_order: Ordering,
158 no_kv_cache: bool,
159 tgt_non_granular_index: Option<usize>,
160 ) -> Self {
161 self.xlora_model_id = Some(xlora_model_id);
162 self.xlora_order = Some(xlora_order);
163 self.no_kv_cache = no_kv_cache;
164 self.tgt_non_granular_index = tgt_non_granular_index;
165 self.model_id = if let Some(id) = self.model_id {
166 Some(id)
167 } else {
168 info!(
169 "Using adapter base model ID: `{}`",
170 self.xlora_order.as_ref().unwrap().base_model_id
171 );
172 Some(self.xlora_order.as_ref().unwrap().base_model_id.clone())
173 };
174 self
175 }
176
177 pub fn with_xlora(
178 mut self,
179 xlora_model_id: String,
180 xlora_order: Ordering,
181 no_kv_cache: bool,
182 tgt_non_granular_index: Option<usize>,
183 ) -> Self {
184 self.kind = ModelKind::Adapter {
185 adapter: AdapterKind::XLora,
186 };
187 self.with_adapter(
188 xlora_model_id,
189 xlora_order,
190 no_kv_cache,
191 tgt_non_granular_index,
192 )
193 }
194
195 pub fn with_lora(mut self, lora_adapter_ids: Vec<String>) -> Self {
196 self.kind = ModelKind::Adapter {
197 adapter: AdapterKind::Lora,
198 };
199 self.lora_adapter_ids = Some(lora_adapter_ids);
200 self
201 }
202
203 pub fn hf_cache_path(mut self, hf_cache_path: PathBuf) -> Self {
204 self.hf_cache_path = Some(hf_cache_path);
205 self
206 }
207
208 pub fn build(self, loader_tp: Option<NormalLoaderType>) -> anyhow::Result<Box<dyn Loader>> {
211 let loader: Box<dyn NormalModelLoader> = match loader_tp {
212 Some(NormalLoaderType::Mistral) => Box::new(MistralLoader),
213 Some(NormalLoaderType::Gemma) => Box::new(GemmaLoader),
214 Some(NormalLoaderType::Llama) => Box::new(LlamaLoader),
215 Some(NormalLoaderType::Mixtral) => Box::new(MixtralLoader),
216 Some(NormalLoaderType::Phi2) => Box::new(Phi2Loader),
217 Some(NormalLoaderType::Phi3) => Box::new(Phi3Loader),
218 Some(NormalLoaderType::Qwen2) => Box::new(Qwen2Loader),
219 Some(NormalLoaderType::Gemma2) => Box::new(Gemma2Loader),
220 Some(NormalLoaderType::Starcoder2) => Box::new(Starcoder2Loader),
221 Some(NormalLoaderType::Phi3_5MoE) => Box::new(Phi3_5MoELoader),
222 Some(NormalLoaderType::DeepSeekV2) => Box::new(DeepSeekV2Loader),
223 Some(NormalLoaderType::DeepSeekV3) => Box::new(DeepSeekV3Loader),
224 Some(NormalLoaderType::Qwen3) => Box::new(Qwen3Loader),
225 Some(NormalLoaderType::GLM4) => Box::new(GLM4Loader),
226 Some(NormalLoaderType::Qwen3Moe) => Box::new(Qwen3MoELoader),
227 None => Box::new(AutoNormalLoader),
228 };
229 Ok(Box::new(NormalLoader {
230 inner: loader,
231 model_id: self.model_id.unwrap(),
232 config: self.config,
233 xlora_model_id: self.xlora_model_id,
234 lora_adapter_ids: self.lora_adapter_ids,
235 kind: self.kind,
236 xlora_order: self.xlora_order,
237 no_kv_cache: self.no_kv_cache,
238 chat_template: self.chat_template,
239 tokenizer_json: self.tokenizer_json,
240 tgt_non_granular_index: self.tgt_non_granular_index,
241 jinja_explicit: self.jinja_explicit,
242 token_source: RwLock::new(None),
243 revision: RwLock::new(None),
244 from_uqff: RwLock::new(None),
245 hf_cache_path: self.hf_cache_path,
246 }))
247 }
248}
249
250impl Loader for NormalLoader {
    // Resolves all model artifacts from the Hugging Face Hub (or local cache),
    // stashes the token source / revision / UQFF paths on `self`, then defers
    // to `load_model_from_path` for the actual load.
    #[allow(clippy::type_complexity, clippy::too_many_arguments)]
    fn load_model_from_hf(
        &self,
        revision: Option<String>,
        token_source: TokenSource,
        dtype: &dyn TryIntoDType,
        device: &Device,
        silent: bool,
        mapper: DeviceMapSetting,
        in_situ_quant: Option<IsqType>,
        paged_attn_config: Option<PagedAttentionConfig>,
    ) -> Result<Arc<Mutex<dyn Pipeline + Send + Sync>>> {
        // Initialize the process-wide HF cache once, preferring an explicit
        // cache path when one was configured on this loader.
        let cache = self
            .hf_cache_path
            .clone()
            .map(Cache::new)
            .unwrap_or_default();
        GLOBAL_HF_CACHE.get_or_init(|| cache);

        // Resolve config/tokenizer/weight file paths from the Hub or cache.
        let paths: anyhow::Result<Box<dyn ModelPaths>> = get_paths!(
            LocalModelPaths,
            &token_source,
            revision.clone(),
            self,
            None,
            None,
            silent,
            self.config.from_uqff.is_some()
        );
        // Resolve UQFF artifact paths up front so `load_model_from_path` can
        // read them later through `self.from_uqff`.
        if let Some(from_uqff) = self.config.from_uqff.clone() {
            *self.from_uqff.write().unwrap() = Some(get_uqff_paths!(&from_uqff, self, silent));
        }
        // Stash the token source and revision for later stages (interior
        // mutability since this is a `&self` method).
        *self
            .token_source
            .write()
            .expect("Failed to write to token source") = Some(token_source);
        *self.revision.write().expect("Failed to write to revision") = revision;
        self.load_model_from_path(
            &paths?,
            dtype,
            device,
            silent,
            mapper,
            in_situ_quant,
            paged_attn_config,
        )
    }
298
    // The main loading routine: reads the model config, plans the device map,
    // constructs the model (optionally sharded/NCCL, X-LoRA or LoRA), applies
    // ISQ or loads UQFF artifacts, optionally collects an imatrix from a
    // calibration file, sets up PagedAttention caches, and assembles the
    // final `NormalPipeline`.
    #[allow(clippy::type_complexity, clippy::too_many_arguments)]
    fn load_model_from_path(
        &self,
        paths: &Box<dyn ModelPaths>,
        dtype: &dyn TryIntoDType,
        device: &Device,
        silent: bool,
        mut mapper: DeviceMapSetting,
        mut in_situ_quant: Option<IsqType>,
        mut paged_attn_config: Option<PagedAttentionConfig>,
    ) -> Result<Arc<Mutex<dyn Pipeline + Send + Sync>>> {
        let config = std::fs::read_to_string(paths.get_config_filename())?;

        // Disable PagedAttention when this architecture does not support it.
        if !self.inner.supports_paged_attention(&config)? {
            paged_attn_config = None;
        }

        let prompt_chunksize = self
            .config
            .prompt_chunksize
            .unwrap_or(DEFAULT_PROMPT_CHUNK_SIZE.try_into().unwrap())
            .get();

        info!("Prompt chunk size is {prompt_chunksize}.",);

        let use_nccl = mistralrs_quant::distributed::use_nccl();

        // Device selection: a distributed daemon worker gets its own CUDA
        // ordinal (rank + 1); NCCL uses device 0; otherwise all devices
        // similar to the requested one are candidates for mapping.
        let available_devices = if let Ok(payload) = env::var(distributed::IS_DAEMON_FLAG) {
            let payload: WorkerTransferData = serde_json::from_str(&payload)?;
            let WorkerTransferData::Init { id: _, worker_rank } = payload;
            vec![candle_core::Device::new_cuda(worker_rank + 1)?]
        } else if use_nccl {
            vec![candle_core::Device::new_cuda(0)?]
        } else {
            device_map::get_all_similar_devices(device)?
        };
        let device = if use_nccl || cfg!(feature = "ring") {
            available_devices[0].clone()
        } else {
            device.clone()
        };

        // Device-map planning: NCCL/ring gets a dummy mapper; an `Auto`
        // setting is resolved into a concrete per-layer map based on
        // estimated layer sizes.
        if use_nccl || cfg!(feature = "ring") {
            mapper = DeviceMapSetting::DummyNccl {
                nm_device: available_devices[0].clone(),
            };
        } else if let DeviceMapSetting::Auto(params) = mapper.clone() {
            // Tentative dtype over all candidate devices, used only for size
            // estimation here; the final dtype is derived from the mapper below.
            let dtype = dtype.try_into_dtype(&available_devices.iter().collect::<Vec<_>>())?;

            // If the model config itself specifies a quantization (pack factor
            // != 1), ISQ is not applicable on top of it.
            if QuantizationConfigShim::get_quant_config_pack_factor(&config, dtype)? != 1 {
                in_situ_quant = None;
            }

            // Estimate sizes using, in priority order: UQFF artifacts (average
            // pack factor across serialized tensors), the requested ISQ type,
            // or the config's own quantization.
            let (layer_sizes_in_bytes, non_mapped_size_in_bytes, total_model_size_in_bytes) =
                if let Some(serialized) = &*self.from_uqff.read().unwrap() {
                    let weight_pack_factor = {
                        let ser_artifacts = unsafe {
                            candle_core::safetensors::MmapedSafetensors::multi(serialized)?
                        };
                        let mut total_pack_factors = 0;
                        let total_tensors = ser_artifacts.tensors().len();
                        for (_, artifact) in ser_artifacts.tensors() {
                            let artifact = artifact.data();
                            // The quant type is encoded in the artifact header.
                            let isq_type = artifact[mistralrs_quant::UQFF_QUANT_TYPE_OFFSET];
                            let pack_factor = match QuantizedSerdeType::try_from(isq_type as usize)?
                            {
                                QuantizedSerdeType::Hqq => {
                                    HqqLayer::get_isq_type_from_uqff(Cow::Borrowed(artifact))?
                                        .pack_factor(dtype)
                                }
                                QuantizedSerdeType::Gguf => {
                                    GgufMatMul::get_isq_type_from_uqff(Cow::Borrowed(artifact))?
                                        .pack_factor(dtype)
                                }
                                QuantizedSerdeType::Fp8 => IsqType::F8E4M3.pack_factor(dtype),
                                QuantizedSerdeType::Unquant => 1,
                                QuantizedSerdeType::Afq => {
                                    AfqLayer::get_isq_type_from_uqff(Cow::Borrowed(artifact))?
                                        .pack_factor(dtype)
                                }
                            };
                            total_pack_factors += pack_factor;
                        }

                        total_pack_factors / total_tensors
                    };

                    let layer_sizes_in_bytes =
                        self.inner
                            .layer_sizes_in_bytes(&config, dtype, weight_pack_factor)?;
                    let non_mapped_size_in_bytes =
                        self.inner
                            .non_mapped_size_in_bytes(&config, dtype, weight_pack_factor)?;
                    let layer_sizes_sum = layer_sizes_in_bytes.iter().sum::<usize>();
                    (
                        layer_sizes_in_bytes,
                        non_mapped_size_in_bytes,
                        layer_sizes_sum + non_mapped_size_in_bytes,
                    )
                } else if let Some(isq) = in_situ_quant {
                    let weight_pack_factor = isq.pack_factor(dtype);
                    let layer_sizes_in_bytes =
                        self.inner
                            .layer_sizes_in_bytes(&config, dtype, weight_pack_factor)?;
                    let non_mapped_size_in_bytes =
                        self.inner
                            .non_mapped_size_in_bytes(&config, dtype, weight_pack_factor)?;
                    let layer_sizes_sum = layer_sizes_in_bytes.iter().sum::<usize>();
                    (
                        layer_sizes_in_bytes,
                        non_mapped_size_in_bytes,
                        layer_sizes_sum + non_mapped_size_in_bytes,
                    )
                } else {
                    let weight_pack_factor =
                        QuantizationConfigShim::get_quant_config_pack_factor(&config, dtype)?;
                    let layer_sizes_in_bytes =
                        self.inner
                            .layer_sizes_in_bytes(&config, dtype, weight_pack_factor)?;
                    let non_mapped_size_in_bytes =
                        self.inner
                            .non_mapped_size_in_bytes(&config, dtype, weight_pack_factor)?;
                    let layer_sizes_sum = layer_sizes_in_bytes.iter().sum::<usize>();
                    (
                        layer_sizes_in_bytes,
                        non_mapped_size_in_bytes,
                        layer_sizes_sum + non_mapped_size_in_bytes,
                    )
                };

            let new = auto_device_map::get_device_layers(
                &*self.inner,
                &config,
                self.inner.num_layers(&config)?,
                layer_sizes_in_bytes,
                non_mapped_size_in_bytes,
                total_model_size_in_bytes,
                &available_devices,
                dtype,
                &params,
                prompt_chunksize,
                paged_attn_config.as_ref(),
            )?;
            mapper = DeviceMapSetting::Map(new);
        }

        // Two mappers from the same setting: one is kept for the pipeline,
        // the other is consumed during model construction.
        let pipeline_mapper = mapper.into_mapper(
            self.inner.num_layers(&config)?,
            &device,
            self.config.topology.as_ref(),
        )?;
        let mapper = mapper.into_mapper(
            self.inner.num_layers(&config)?,
            &device,
            self.config.topology.as_ref(),
        )?;
        let mut layer_devices = Vec::new();
        for layer in 0..self.inner.num_layers(&config)? {
            let device = mapper.device_for(layer, false).cloned();
            layer_devices.push(device);
        }
        let dtype = mapper.get_min_dtype(dtype)?;

        // PagedAttention has no CPU support; disable it if any mapped device
        // is the CPU.
        let mapping_uses_cpu = mapper.get_unique_devices().iter().any(Device::is_cpu);
        if mapping_uses_cpu {
            warn!("Device mapping contains a mix of GPU and CPU. There is no CPU support for PagedAttention, disabling PagedAttention.");
            paged_attn_config = None;
        }

        info!("Model config: {:?}", self.inner.get_config_repr(&config)?);
        if crate::using_flash_attn() {
            once_log_info("FlashAttention is enabled.");
        }

        // "Immediate ISQ" path: quantize weights as they are loaded, which is
        // only taken when there is no imatrix/calibration/UQFF-write involved
        // and the target device is not CUDA. Otherwise `loading_isq` marks
        // that a post-load `quantize` pass is needed.
        let mut loading_isq = if self.config.imatrix.is_none()
            && self.config.calibration_file.is_none()
            && !device.is_cuda()
            && self.config.write_uqff.is_none()
            && in_situ_quant.is_some()
        {
            let predicates = if matches!(self.config.organization, IsqOrganization::MoeExpertsOnly)
            {
                self.inner.immediate_isq_predicates_moqe(&config)?
            } else {
                self.inner.immediate_isq_predicates(&config)?
            };
            info!("Applying ISQ to {in_situ_quant:?}");
            if predicates.is_empty() {
                warn!("No predicates for this model and ISQ setting detected. ISQ will not be applied to any weights!");
            }
            mistralrs_quant::set_immediate_isq(in_situ_quant, predicates);
            false
        } else {
            in_situ_quant.is_some()
        };

        // A topology that requests ISQ for any layer also forces the
        // post-load quantize pass.
        if let Some(ref topology) = self.config.topology {
            loading_isq |= topology
                .0
                .iter()
                .any(|layer| layer.as_ref().is_some_and(|layer| layer.isq.is_some()));
        }

        if self.config.imatrix.is_some() && self.config.calibration_file.is_some() {
            anyhow::bail!(
                "`imatrix` and `calibration_file` were both specified, this is not allowed."
            );
        }

        // Load onto the target device directly unless a post-load ISQ pass is
        // planned (then weights are staged on CPU first); a calibration file
        // also forces a direct device load and clears `loading_isq`.
        let load_device = if !loading_isq || self.config.calibration_file.is_some() {
            loading_isq = false;
            device.clone()
        } else {
            Device::Cpu
        };

        let is_xlora = self.kind.is_adapted_and(|a| a.is_x_lora());

        let attention_mechanism = if paged_attn_config.is_some() {
            AttentionImplementation::PagedAttention
        } else {
            AttentionImplementation::Eager
        };

        let multi_progress = Arc::new(MultiProgress::new());

        // Construct the model: the NCCL/ring path uses a sharded varbuilder,
        // the default path loads per the device map. Both dispatch on the
        // model kind (normal / X-LoRA / LoRA).
        let mut model = if use_nccl || cfg!(feature = "ring") {
            let (mapper, sharded_vb) = distributed::prepare_distributed_mapper(
                dtype,
                &device,
                &available_devices,
                silent,
                &config,
                loading_isq,
                self.config.from_uqff.is_some(),
                self.config.organization,
                &*self.inner,
                paths.as_ref(),
            )?;

            match self.kind {
                ModelKind::Normal => normal_model_loader_sharded!(
                    sharded_vb,
                    config,
                    self.inner,
                    mapper,
                    loading_isq,
                    device.clone(),
                    attention_mechanism,
                    multi_progress.clone(),
                ),
                ModelKind::Adapter {
                    adapter: AdapterKind::XLora,
                } => xlora_model_loader!(
                    paths,
                    Some(dtype),
                    &load_device,
                    layer_devices.clone(),
                    config,
                    self.inner,
                    silent,
                    mapper,
                    loading_isq,
                    device.clone(),
                    multi_progress.clone(),
                ),
                ModelKind::Adapter {
                    adapter: AdapterKind::Lora,
                } => lora_model_loader!(
                    paths,
                    Some(dtype),
                    &load_device,
                    layer_devices.clone(),
                    config,
                    self.inner,
                    silent,
                    mapper,
                    loading_isq,
                    self.config.from_uqff.is_some(),
                    device.clone(),
                    attention_mechanism,
                    matches!(self.config.organization, IsqOrganization::MoeExpertsOnly),
                    multi_progress.clone(),
                ),
                _ => unreachable!(),
            }
        } else {
            match self.kind {
                ModelKind::Normal => normal_model_loader!(
                    paths,
                    Some(dtype),
                    &load_device,
                    layer_devices.clone(),
                    config,
                    self.inner,
                    silent,
                    mapper,
                    loading_isq,
                    self.config.from_uqff.is_some(),
                    device.clone(),
                    attention_mechanism,
                    matches!(self.config.organization, IsqOrganization::MoeExpertsOnly),
                    multi_progress.clone(),
                ),
                ModelKind::Adapter {
                    adapter: AdapterKind::XLora,
                } => xlora_model_loader!(
                    paths,
                    Some(dtype),
                    &load_device,
                    layer_devices.clone(),
                    config,
                    self.inner,
                    silent,
                    mapper,
                    loading_isq,
                    device.clone(),
                    multi_progress.clone(),
                ),
                ModelKind::Adapter {
                    adapter: AdapterKind::Lora,
                } => lora_model_loader!(
                    paths,
                    Some(dtype),
                    &load_device,
                    layer_devices.clone(),
                    config,
                    self.inner,
                    silent,
                    mapper,
                    loading_isq,
                    self.config.from_uqff.is_some(),
                    device.clone(),
                    attention_mechanism,
                    matches!(self.config.organization, IsqOrganization::MoeExpertsOnly),
                    multi_progress.clone(),
                ),
                _ => unreachable!(),
            }
        };

        let tokenizer = get_tokenizer(paths.get_tokenizer_filename(), None)?;
        // A malformed generation_config.json is tolerated with a warning.
        let gen_conf: Option<GenerationConfig> = paths.get_gen_conf_filename().and_then(|f| {
            match serde_json::from_str::<GenerationConfig>(&fs::read_to_string(f).unwrap()) {
                Ok(conf) => Some(conf),
                Err(e) => {
                    warn!("Failed to parse generation_config.json: {}", e);
                    None
                }
            }
        });

        let chat_template_explicit = paths
            .get_chat_template_explicit()
            .as_ref()
            .map(|x| x.to_string_lossy().to_string());
        let chat_template = get_chat_template(
            paths,
            self.jinja_explicit.as_ref(),
            chat_template_explicit.as_ref(),
            self.chat_template.as_ref(),
            None,
        );

        // Imatrix collection: run the calibration text through the model in
        // chunks while stat tracking is enabled, clearing the KV cache after
        // each chunk.
        if let Some(calibration_file) = &self.config.calibration_file {
            let calibration_data = std::fs::read_to_string(calibration_file)?;
            // Tokenize without BOS; the BOS token is prepended per chunk below.
            let tokens = tokenizer
                .encode_fast(calibration_data, false)
                .map_err(anyhow::Error::msg)?
                .get_ids()
                .to_vec();
            info!(
                "Collecting imatrix from calibration file `{}` of {} tokens.",
                calibration_file.display(),
                tokens.len()
            );
            let bos_toks = chat_template.bos_tok().map(|b| vec![b]).unwrap_or_default();
            let bos_tok_id = tokenizer
                .token_to_id(&bos_toks[0])
                .expect("Somehow the bos token is not present.");

            match self.config.organization {
                IsqOrganization::Default => model.begin_track_stats()?,
                IsqOrganization::MoeExpertsOnly => model.begin_track_stats_moe_experts_only()?,
            }

            const CHUNK_SIZE: usize = 1024;
            let n_chunks = tokens.len().div_ceil(CHUNK_SIZE);
            let start = Instant::now();
            for (i, chunk) in tokens.chunks(CHUNK_SIZE).enumerate() {
                let chunk = [vec![bos_tok_id], chunk.to_vec()].concat();
                let chunk_len = chunk.len();

                let start = Instant::now();
                let inputs = make_prompt_chunk(
                    0,
                    vec![&chunk],
                    &[0],
                    &load_device,
                    None,
                    false,
                    None,
                    Some(pipeline_mapper.as_ref()),
                )?;

                model.forward(
                    &inputs.input.to_device(model.device())?,
                    &inputs.positions,
                    inputs.context_lens.clone(),
                    inputs.position_ids.clone(),
                    None,
                    &inputs.flash_meta.clone(),
                )?;

                // Drop the KV cache between chunks so each chunk is processed
                // as a fresh prompt.
                match model.cache_mut() {
                    EitherCache::Full(full) => {
                        for layer in &mut *full.lock() {
                            *layer = None
                        }
                    }
                    EitherCache::Normal(normal) => {
                        for layer in &mut *normal.lock().unwrap().0 {
                            layer.reset();
                        }
                    }
                }

                let end = Instant::now();
                info!(
                    "Processed chunk {}/{n_chunks} ({chunk_len} tokens), {:.2}s",
                    i + 1,
                    end.duration_since(start).as_secs_f32()
                );
            }
            load_device.synchronize()?;
            let end = Instant::now();
            info!(
                "Finished collecting imatrix in {:.2}s",
                end.duration_since(start).as_secs_f32()
            );
        }

        // Post-load quantization (unless loading from UQFF, which takes the
        // `load_from_artifacts` branch instead).
        if (loading_isq || self.config.topology.is_some()) && self.config.from_uqff.is_none() {
            let imatrix_source = match (
                self.config.imatrix.as_ref(),
                self.config.calibration_file.is_some(),
            ) {
                (None, false) => None,
                (Some(file), false) => Some(ImatrixDataSource::File(file)),
                (None, true) => Some(ImatrixDataSource::Collected),
                // Ruled out by the bail! above.
                (Some(_), true) => unreachable!(),
            };

            info!("Applying ISQ to all ranks.");

            let multi_progress = Arc::new(MultiProgress::new());

            model.quantize(
                in_situ_quant,
                model.device().clone(),
                self.config.topology.as_ref(),
                silent,
                imatrix_source,
                self.config.organization,
                self.config.write_uqff.as_ref(),
                UqffFullSer {
                    tokenizer: &tokenizer,
                    template_filename: paths.get_template_filename(),
                    generation_config: paths.get_gen_conf_filename(),
                    config: config.clone(),
                    processor_filename: &None,
                    preprocessor_filename: &None,
                },
                multi_progress.clone(),
            )?;
        } else if let Some(from_uqff) = &*self.from_uqff.read().unwrap() {
            model.load_from_artifacts(
                device.clone(),
                self.config.topology.as_ref(),
                silent,
                from_uqff,
            )?;
        }

        // X-LoRA adapters do not support PagedAttention.
        let paged_attn_config = if matches!(
            self.kind,
            ModelKind::Adapter {
                adapter: AdapterKind::XLora
            }
        ) {
            warn!(
                "Adapter parallel_models do not currently support PagedAttention, running without"
            );
            None
        } else {
            paged_attn_config
        };

        // Set up PagedAttention KV-cache config and engine, if enabled.
        let (cache_config, cache_engine) = if let Some(paged_attn_config) = paged_attn_config {
            let cache_config = calculate_cache_config(
                paged_attn_config.mem_gpu,
                paged_attn_config.mem_cpu,
                paged_attn_config.block_size,
                dtype,
                paged_attn_config.cache_type,
                model.config(),
                &device,
                &pipeline_mapper
                    .get_unique_devices()
                    .into_iter()
                    .map(Some)
                    .collect::<Vec<_>>(),
                silent,
            )?;

            // Recompute per-layer devices from the *loaded* model's mapper.
            let mut layer_devices = Vec::new();
            for layer in 0..self.inner.num_layers(&config)? {
                let device = model.get_layers().1.device_for(layer, false).cloned();
                layer_devices.push(device);
            }
            let cache_engine = CacheEngine::new(
                model.config(),
                &cache_config,
                dtype,
                model.device(),
                layer_devices.clone(),
            )?;

            (Some(cache_config), Some(cache_engine))
        } else {
            (None, None)
        };

        let max_seq_len = model.max_seq_len();
        let llg_factory = build_llg_factory(tokenizer.clone())?;
        let num_hidden_layers = match model.cache() {
            EitherCache::Full(full) => full.lock().len(),
            EitherCache::Normal(normal) => normal.lock().unwrap().0.len(),
        };
        let eos = calculate_eos_tokens(&chat_template, gen_conf, &tokenizer);
        let sliding_window = model.config().sliding_window;
        let model_metadata = Arc::new(model.config().clone());

        Ok(Arc::new(Mutex::new(NormalPipeline {
            model,
            tokenizer: tokenizer.into(),
            no_kv_cache: self.no_kv_cache,
            chat_template: Arc::new(chat_template),
            non_granular_state: self.tgt_non_granular_index.map(|tgt_non_granular_index| {
                NonGranularState {
                    non_granular_index: Arc::new(Mutex::new(0)),
                    tgt_non_granular_index,
                }
            }),
            model_id: self.model_id.clone(),
            metadata: Arc::new(GeneralMetadata {
                max_seq_len,
                llg_factory: Some(llg_factory),
                no_kv_cache: self.no_kv_cache,
                // Prefix caching is disabled for X-LoRA models.
                no_prefix_cache: is_xlora,
                num_hidden_layers,
                eos_tok: eos,
                kind: self.kind.clone(),
                is_xlora,
                activation_dtype: dtype,
                sliding_window,
                cache_config,
                cache_engine,
                prompt_chunksize: Some(NonZero::new(prompt_chunksize).unwrap()),
                model_metadata: Some(model_metadata),
                modalities: Modalities {
                    input: vec![SupportedModality::Text],
                    output: vec![SupportedModality::Text],
                },
            }),
            topology: self.config.topology.clone(),
            silent,
            organization: self.config.organization,
            template_filename: paths.get_template_filename().clone(),
            generation_config: paths.get_gen_conf_filename().cloned(),
            config,
            imatrix: self.config.imatrix.clone(),
            mapper: pipeline_mapper,
        })))
    }
899
900 fn get_id(&self) -> String {
901 self.model_id.clone()
902 }
903
904 fn get_kind(&self) -> ModelKind {
905 self.kind.clone()
906 }
907}
908
909impl PreProcessingMixin for NormalPipeline {
910 fn get_chat_template(&self) -> Option<Arc<ChatTemplate>> {
911 Some(self.chat_template.clone())
912 }
913 fn get_input_processor_config(&self) -> Option<Arc<dyn Any>> {
914 None
915 }
916}
917
918impl IsqPipelineMixin for NormalPipeline {
919 fn re_isq_model(&mut self, dtype: IsqType) -> Result<()> {
920 let device = self.device().clone();
921 let multi_progress = Arc::new(MultiProgress::new());
922 self.model.quantize(
923 Some(dtype),
924 device.clone(),
925 self.topology.as_ref(),
926 self.silent,
927 self.imatrix.as_ref().map(ImatrixDataSource::File),
928 self.organization,
929 None,
930 UqffFullSer {
931 tokenizer: &self.tokenizer,
932 template_filename: &self.template_filename,
933 generation_config: self.generation_config.as_ref(),
934 config: self.config.clone(),
935 processor_filename: &None,
936 preprocessor_filename: &None,
937 },
938 multi_progress.clone(),
939 )?;
940 Ok(())
941 }
942}
943
944impl CacheManagerMixin for NormalPipeline {
945 fn clone_in_cache(&self, seqs: &mut [&mut Sequence]) {
946 if matches!(self.model.cache(), EitherCache::Full(_)) {
947 FullCacheManager.clone_in_cache(self, seqs, false)
948 } else {
949 NormalCacheManager.clone_in_cache(self, seqs, false)
950 }
951 }
952 fn clone_out_cache(&self, seqs: &mut [&mut Sequence]) {
953 if matches!(self.model.cache(), EitherCache::Full(_)) {
954 FullCacheManager.clone_out_cache(self, seqs, false)
955 } else {
956 NormalCacheManager.clone_out_cache(self, seqs, false)
957 }
958 }
959 fn set_none_cache(
960 &self,
961 seqs: &mut [&mut Sequence],
962 reset_non_granular: bool,
963 modify_draft_cache: bool,
964 load_preallocated_cache: bool,
965 ) {
966 if matches!(self.model.cache(), EitherCache::Full(_)) {
967 FullCacheManager.set_none_cache(self, seqs, modify_draft_cache, false);
968 } else {
969 NormalCacheManager.set_none_cache(
970 self,
971 seqs,
972 modify_draft_cache,
973 load_preallocated_cache,
974 );
975 }
976 if reset_non_granular {
977 self.reset_non_granular_state()
978 }
979 }
980 fn cache(&self) -> &EitherCache {
981 self.model.cache()
982 }
983}
984
985impl MetadataMixin for NormalPipeline {
986 fn device(&self) -> Device {
987 self.model.device().clone()
988 }
989 fn tokenizer(&self) -> Option<Arc<Tokenizer>> {
990 Some(self.tokenizer.clone())
991 }
992 fn name(&self) -> String {
993 self.model_id.clone()
994 }
995 fn reset_non_granular_state(&self) {
996 if let Some(s) = self.non_granular_state.as_ref() {
997 *self.cache().full().get_scalings_cache() = None;
998 *get_mut_arcmutex!(s.non_granular_index) = 0;
999 }
1000 }
1001 fn get_metadata(&self) -> Arc<GeneralMetadata> {
1002 self.metadata.clone()
1003 }
1004 fn device_mapper(&self) -> Option<&dyn DeviceMapper> {
1005 Some(&*self.mapper)
1006 }
1007}
1008
1009#[async_trait::async_trait]
1010impl Pipeline for NormalPipeline {
1011 fn forward_inputs(
1012 &mut self,
1013 inputs: Box<dyn Any>,
1014 return_raw_logits: bool,
1015 ) -> Result<ForwardInputsResult, candle_core::Error> {
1016 let ModelInputs {
1017 input_ids,
1018 input_ids_full,
1019 seqlen_offsets,
1020 seqlen_offsets_full,
1021 context_lens,
1022 position_ids,
1023 paged_attn_meta,
1024 flash_meta,
1025 flash_meta_full,
1026 } = *inputs.downcast().expect("Downcast failed.");
1027 let metadata = self.get_metadata();
1028 let paged_attn_meta = match (&metadata.cache_engine, &paged_attn_meta) {
1029 (Some(cache_engine), Some(meta)) => Some((cache_engine, meta)),
1030 (Some(_), None) => {
1031 candle_core::bail!("Forward step expected a PagedAttention input metadata. This was not provided, please ensure that the scheduler config is correctly configured for PagedAttention.")
1033 }
1034 (None, Some(_)) => {
1035 candle_core::bail!("Forward step got a PagedAttention input metadata but there is no cache engine. Please raise an issue.")
1037 }
1038 (None, None) => None,
1039 };
1040 let logits = match self.model.is_xlora() {
1041 false => {
1042 let paged_attn_meta = paged_attn_meta
1043 .as_ref()
1044 .map(|meta| (meta.0.get_kv_cache().clone(), meta.1.clone()));
1045
1046 self.model.forward(
1047 &input_ids,
1048 &seqlen_offsets,
1049 context_lens,
1050 position_ids,
1051 paged_attn_meta.as_ref().map(|(a, b)| (a.clone(), b)),
1052 &flash_meta,
1053 )?
1054 }
1055 true => self.model.xlora_forward(
1056 &input_ids,
1057 input_ids_full.as_ref().unwrap_or(&input_ids),
1058 &seqlen_offsets,
1059 seqlen_offsets_full.as_ref().unwrap_or(&seqlen_offsets),
1060 self.no_kv_cache,
1061 &self.non_granular_state,
1062 context_lens,
1063 position_ids,
1064 &flash_meta,
1065 flash_meta_full.as_ref().unwrap_or(&flash_meta),
1066 )?,
1067 };
1068 if return_raw_logits {
1069 Ok(ForwardInputsResult::RawLogits { logits })
1070 } else {
1071 Ok(ForwardInputsResult::CausalGeneration { logits })
1072 }
1073 }
1074 async fn sample_causal_gen(
1075 &self,
1076 seqs: &mut [&mut Sequence],
1077 logits: Vec<Tensor>,
1078 prefix_cacher: &mut PrefixCacheManagerV2,
1079 disable_eos_stop: bool,
1080 rng: Arc<std::sync::Mutex<Isaac64Rng>>,
1081 ) -> Result<(), candle_core::Error> {
1082 sample_and_add_toks(self, seqs, logits, prefix_cacher, disable_eos_stop, rng).await
1083 }
1084 fn category(&self) -> ModelCategory {
1085 ModelCategory::Text
1086 }
1087}
1088
1089impl AnyMoePipelineMixin for NormalPipeline {
1090 fn amoe_finish_training(&mut self, gate_model_id: Option<String>) -> candle_core::Result<()> {
1091 self.model.finish_training(gate_model_id)
1092 }
1093 fn amoe_layer_vars(&self) -> Vec<Vec<Var>> {
1094 self.model.get_vars()
1095 }
1096 fn amoe_base_model_trainable_params(&self) -> usize {
1097 self.model.trainable_params()
1098 }
1099 fn amoe_take_cached_gating_outputs(&mut self) -> Vec<Tensor> {
1100 self.model.take_cached_gating_outputs()
1101 }
1102 fn amoe_create_layers(
1103 &mut self,
1104 model_ids: Vec<String>,
1105 token: &TokenSource,
1106 revision: Option<String>,
1107 match_regex: &str,
1108 config: crate::amoe::AnyMoeConfig,
1109 dtype: candle_core::DType,
1110 dev: &Device,
1111 (prefix, mlp): (String, String),
1112 layers: Vec<usize>,
1113 expert_type: AnyMoeExpertType,
1114 silent: bool,
1115 gate_model_id: Option<String>,
1116 ) -> candle_core::Result<()> {
1117 let mut vbs = Vec::new();
1118 let regex = Regex::new(match_regex).map_err(candle_core::Error::msg)?;
1120 for model_id in model_ids {
1121 let model_id_str = &model_id;
1122 let model_id = Path::new(&model_id);
1123
1124 let api = {
1125 let cache = GLOBAL_HF_CACHE.get().cloned().unwrap_or_default();
1126 let mut api = ApiBuilder::from_cache(cache)
1127 .with_progress(!silent)
1128 .with_token(get_token(token).map_err(candle_core::Error::msg)?);
1129 if let Ok(x) = std::env::var("HF_HUB_CACHE") {
1130 api = api.with_cache_dir(x.into());
1131 }
1132 api.build().map_err(candle_core::Error::msg)?
1133 };
1134 let revision = revision.clone().unwrap_or("main".to_string());
1135 let api = api.repo(Repo::with_revision(
1136 model_id_str.clone(),
1137 RepoType::Model,
1138 revision.clone(),
1139 ));
1140
1141 let mut filenames = vec![];
1142 for rfilename in
1143 api_dir_list!(api, model_id, true).filter(|x| x.ends_with(".safetensors"))
1144 {
1145 filenames.push(api_get_file!(api, &rfilename, model_id));
1146 }
1147
1148 let regex = regex.clone();
1149 let match_regex_clone = match_regex.to_string();
1150 let layers_clone = layers.clone();
1151 let vb = from_mmaped_safetensors(
1152 filenames,
1153 vec![],
1154 Some(dtype),
1155 dev,
1156 vec![None],
1157 silent,
1158 None,
1159 move |key| {
1160 if regex.is_match(&key) {
1161 let last_layer_idx = key.find(&match_regex_clone).unwrap() - 1;
1164 let first_layer_idx = key[..last_layer_idx].rfind('.').unwrap();
1165 let layer_n = key[first_layer_idx + 1..last_layer_idx]
1166 .parse::<usize>()
1167 .unwrap();
1168 layers_clone.contains(&layer_n) || layers_clone.is_empty()
1169 } else {
1170 false
1171 }
1172 },
1173 Arc::new(|_| DeviceForLoadTensor::Base),
1174 )?;
1175 vbs.push(vb);
1176 }
1177
1178 let gate_vb = if let Some(gate_model_id) = gate_model_id {
1179 let model_id_str = &gate_model_id;
1180 let model_id = Path::new(&gate_model_id);
1181
1182 let api = {
1183 let cache = GLOBAL_HF_CACHE.get().cloned().unwrap_or_default();
1184 let mut api = ApiBuilder::from_cache(cache)
1185 .with_progress(!silent)
1186 .with_token(get_token(token).map_err(candle_core::Error::msg)?);
1187 if let Ok(x) = std::env::var("HF_HUB_CACHE") {
1188 api = api.with_cache_dir(x.into());
1189 }
1190 api.build().map_err(candle_core::Error::msg)?
1191 };
1192 let revision = revision.clone().unwrap_or("main".to_string());
1193 let api = api.repo(Repo::with_revision(
1194 model_id_str.clone(),
1195 RepoType::Model,
1196 revision.clone(),
1197 ));
1198
1199 let mut gate_filenames = vec![];
1200 for rfilename in
1201 api_dir_list!(api, model_id, true).filter(|x| x.ends_with(".safetensors"))
1202 {
1203 gate_filenames.push(api_get_file!(api, &rfilename, model_id));
1204 }
1205 assert_eq!(
1206 gate_filenames.len(),
1207 1,
1208 "Gate model ID must contain only one .safetensors file"
1209 );
1210
1211 let vb = from_mmaped_safetensors(
1212 gate_filenames.clone(),
1213 vec![],
1214 Some(dtype),
1215 dev,
1216 vec![None],
1217 silent,
1218 None,
1219 |_| true,
1220 Arc::new(|_| DeviceForLoadTensor::Base),
1221 )?;
1222 info!(
1223 "Loaded gating layers from `{}`",
1224 gate_filenames[0].display()
1225 );
1226 Some(vb)
1227 } else {
1228 None
1229 };
1230
1231 self.model.create_anymoe_layers(
1232 vbs.clone(),
1233 config.clone(),
1234 (prefix.clone(), mlp.clone()),
1235 layers.clone(),
1236 expert_type.clone(),
1237 gate_vb.clone(),
1238 )?;
1239
1240 Ok(())
1241 }
1242 fn amoe_supported(&self) -> bool {
1243 self.model.amoe_supported()
1244 }
1245}