use super::cache_manager::{FullCacheManager, NormalCacheManager};
use super::inputs_processor::DEFAULT_PROMPT_CHUNK_SIZE;
use super::isq::ImatrixDataSource;
use super::llg::build_tok_env;
use super::{
    get_model_paths, get_xlora_paths, text_models_inputs_processor::ModelInputs, AdapterKind,
    CacheManager, GeneralMetadata, Loader, ModelKind, ModelPaths, NormalModel, NormalModelLoader,
    TokenSource,
};
use super::{
    AnyMoePipelineMixin, CacheManagerMixin, EitherCache, ForwardInputsResult, IsqOrganization,
    IsqPipelineMixin, MetadataMixin, ModelCategory, PreProcessingMixin,
};
use super::{
    AutoLoader, DeepSeekV2Loader, DeepSeekV3Loader, Gemma2Loader, GemmaLoader, LlamaLoader,
    MistralLoader, MixtralLoader, NormalLoaderType, Phi2Loader, Phi3Loader, Phi3_5MoELoader,
    Qwen2Loader, Starcoder2Loader,
};
use crate::amoe::AnyMoeExpertType;
use crate::device_map::{self, DeviceMapper};
use crate::distributed::{self, WorkerTransferData};
use crate::lora::Ordering;
use crate::paged_attention::{calculate_cache_config, AttentionImplementation, CacheEngine};
use crate::pipeline::chat_template::{calculate_eos_tokens, GenerationConfig};
use crate::pipeline::get_chat_template;
use crate::pipeline::isq::UqffFullSer;
use crate::pipeline::sampling::sample_and_add_toks;
use crate::pipeline::text_models_inputs_processor::make_prompt_chunk;
use crate::pipeline::{ChatTemplate, LocalModelPaths};
use crate::prefix_cacher::PrefixCacheManagerV2;
use crate::sequence::Sequence;
use crate::utils::tokenizer::get_tokenizer;
use crate::utils::varbuilder_utils::DeviceForLoadTensor;
use crate::utils::{tokens::get_token, varbuilder_utils::from_mmaped_safetensors};
use crate::xlora_models::NonGranularState;
use crate::{
    api_dir_list, api_get_file, get_mut_arcmutex, get_paths, get_uqff_paths, lora_model_loader,
    normal_model_loader, normal_model_loader_sharded, xlora_model_loader, DeviceMapSetting,
    PagedAttentionConfig, Pipeline, Topology, TryIntoDType, GLOBAL_HF_CACHE,
};
use anyhow::Result;
use candle_core::{Device, Tensor, Var};
use hf_hub::Cache;
use hf_hub::{api::sync::ApiBuilder, Repo, RepoType};
use indicatif::MultiProgress;
use mistralrs_quant::{AfqLayer, GgufMatMul, HqqLayer, IsqType, QuantizedSerdeType};
use rand_isaac::Isaac64Rng;
use regex_automata::meta::Regex;
use std::any::Any;
use std::borrow::Cow;
use std::num::{NonZero, NonZeroUsize};
use std::path::{Path, PathBuf};
use std::str::FromStr;
use std::sync::{Arc, RwLock};
use std::time::Instant;
use std::{env, fs};
use tokenizers::Tokenizer;
use tokio::sync::Mutex;
use tracing::{info, warn};

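/// A text-generation pipeline over a "normal" (non-vision) causal model.
///
/// Holds the loaded model together with everything needed at inference time:
/// tokenizer, chat template, KV-cache state, ISQ organization, and the device
/// mapper used to place layers across devices.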
pub struct NormalPipeline {
    model: Box<dyn NormalModel + Send + Sync>,
    tokenizer: Arc<Tokenizer>,
    no_kv_cache: bool,
    chat_template: Arc<ChatTemplate>,
    non_granular_state: Option<NonGranularState>,
    model_id: String,
    metadata: Arc<GeneralMetadata>,
    topology: Option<Topology>,
    silent: bool,
    organization: IsqOrganization,
    template_filename: Option<PathBuf>,
    generation_config: Option<PathBuf>,
    config: String,
    imatrix: Option<PathBuf>,
    mapper: Box<dyn DeviceMapper + Send + Sync>,
}

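/// [`Loader`] implementation for normal text models, usually constructed via
/// [`NormalLoaderBuilder`]. The `RwLock` fields cache values resolved during
/// `load_model_from_hf` (token source, revision, UQFF paths) for reuse in the
/// later `load_model_from_path` step.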
pub struct NormalLoader {
    inner: Box<dyn NormalModelLoader>,
    model_id: String,
    config: NormalSpecificConfig,
    xlora_model_id: Option<String>,
    lora_adapter_ids: Option<Vec<String>>,
    kind: ModelKind,
    xlora_order: Option<Ordering>,
    no_kv_cache: bool,
    chat_template: Option<String>,
    tokenizer_json: Option<String>,
    tgt_non_granular_index: Option<usize>,
    token_source: RwLock<Option<TokenSource>>,
    revision: RwLock<Option<String>>,
    from_uqff: RwLock<Option<PathBuf>>,
    jinja_explicit: Option<String>,
    hf_cache_path: Option<PathBuf>,
}

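/// Builder for a [`NormalLoader`].
///
/// A minimal usage sketch (the model ID is a placeholder, not a
/// recommendation):
///
/// ```ignore
/// let loader = NormalLoaderBuilder::new(
///     NormalSpecificConfig::default(),
///     None,                                        // chat template override
///     None,                                        // tokenizer.json override
///     Some("mistralai/Mistral-7B-Instruct-v0.2".to_string()),
///     false,                                       // no_kv_cache
///     None,                                        // explicit JINJA template
/// )
/// .build(None)?; // `None` selects the architecture via `AutoLoader`
/// ```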
#[derive(Default)]
pub struct NormalLoaderBuilder {
    model_id: Option<String>,
    config: NormalSpecificConfig,
    xlora_model_id: Option<String>,
    lora_adapter_ids: Option<Vec<String>>,
    kind: ModelKind,
    xlora_order: Option<Ordering>,
    no_kv_cache: bool,
    chat_template: Option<String>,
    tokenizer_json: Option<String>,
    tgt_non_granular_index: Option<usize>,
    jinja_explicit: Option<String>,
    hf_cache_path: Option<PathBuf>,
}

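/// Per-model configuration for normal loaders: prompt chunking, device
/// topology, ISQ organization, UQFF read/write paths, and imatrix/calibration
/// data for quantization.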
#[derive(Clone, Default)]
pub struct NormalSpecificConfig {
    pub use_flash_attn: bool,
    pub prompt_chunksize: Option<NonZeroUsize>,
    pub topology: Option<Topology>,
    pub organization: IsqOrganization,
    pub write_uqff: Option<PathBuf>,
    pub from_uqff: Option<PathBuf>,
    pub imatrix: Option<PathBuf>,
    pub calibration_file: Option<PathBuf>,
    pub hf_cache_path: Option<PathBuf>,
}

impl NormalLoaderBuilder {
    pub fn new(
        config: NormalSpecificConfig,
        chat_template: Option<String>,
        tokenizer_json: Option<String>,
        model_id: Option<String>,
        no_kv_cache: bool,
        jinja_explicit: Option<String>,
    ) -> Self {
        Self {
            config,
            chat_template,
            tokenizer_json,
            model_id,
            kind: ModelKind::Normal,
            jinja_explicit,
            no_kv_cache,
            ..Default::default()
        }
    }

    fn with_adapter(
        mut self,
        xlora_model_id: String,
        xlora_order: Ordering,
        no_kv_cache: bool,
        tgt_non_granular_index: Option<usize>,
    ) -> Self {
        self.xlora_model_id = Some(xlora_model_id);
        self.xlora_order = Some(xlora_order);
        self.no_kv_cache = no_kv_cache;
        self.tgt_non_granular_index = tgt_non_granular_index;
        self.model_id = if let Some(id) = self.model_id {
            Some(id)
        } else {
            info!(
                "Using adapter base model ID: `{}`",
                self.xlora_order.as_ref().unwrap().base_model_id
            );
            Some(self.xlora_order.as_ref().unwrap().base_model_id.clone())
        };
        self
    }

    pub fn with_xlora(
        mut self,
        xlora_model_id: String,
        xlora_order: Ordering,
        no_kv_cache: bool,
        tgt_non_granular_index: Option<usize>,
    ) -> Self {
        self.kind = ModelKind::Adapter {
            adapter: AdapterKind::XLora,
        };
        self.with_adapter(
            xlora_model_id,
            xlora_order,
            no_kv_cache,
            tgt_non_granular_index,
        )
    }

    pub fn with_lora(mut self, lora_adapter_ids: Vec<String>) -> Self {
        self.kind = ModelKind::Adapter {
            adapter: AdapterKind::Lora,
        };
        self.lora_adapter_ids = Some(lora_adapter_ids);
        self
    }

    pub fn hf_cache_path(mut self, hf_cache_path: PathBuf) -> Self {
        self.hf_cache_path = Some(hf_cache_path);
        self
    }

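    /// Build the [`Loader`]. When `loader_tp` is `None`, [`AutoLoader`] is
    /// used to select the model architecture automatically.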
    pub fn build(self, loader_tp: Option<NormalLoaderType>) -> anyhow::Result<Box<dyn Loader>> {
        let loader: Box<dyn NormalModelLoader> = match loader_tp {
            Some(NormalLoaderType::Mistral) => Box::new(MistralLoader),
            Some(NormalLoaderType::Gemma) => Box::new(GemmaLoader),
            Some(NormalLoaderType::Llama) => Box::new(LlamaLoader),
            Some(NormalLoaderType::Mixtral) => Box::new(MixtralLoader),
            Some(NormalLoaderType::Phi2) => Box::new(Phi2Loader),
            Some(NormalLoaderType::Phi3) => Box::new(Phi3Loader),
            Some(NormalLoaderType::Qwen2) => Box::new(Qwen2Loader),
            Some(NormalLoaderType::Gemma2) => Box::new(Gemma2Loader),
            Some(NormalLoaderType::Starcoder2) => Box::new(Starcoder2Loader),
            Some(NormalLoaderType::Phi3_5MoE) => Box::new(Phi3_5MoELoader),
            Some(NormalLoaderType::DeepSeekV2) => Box::new(DeepSeekV2Loader),
            Some(NormalLoaderType::DeepSeekV3) => Box::new(DeepSeekV3Loader),
            None => Box::new(AutoLoader),
        };
        Ok(Box::new(NormalLoader {
            inner: loader,
            model_id: self.model_id.unwrap(),
            config: self.config,
            xlora_model_id: self.xlora_model_id,
            lora_adapter_ids: self.lora_adapter_ids,
            kind: self.kind,
            xlora_order: self.xlora_order,
            no_kv_cache: self.no_kv_cache,
            chat_template: self.chat_template,
            tokenizer_json: self.tokenizer_json,
            tgt_non_granular_index: self.tgt_non_granular_index,
            jinja_explicit: self.jinja_explicit,
            token_source: RwLock::new(None),
            revision: RwLock::new(None),
            from_uqff: RwLock::new(None),
            hf_cache_path: self.hf_cache_path,
        }))
    }
}

impl Loader for NormalLoader {
    #[allow(clippy::type_complexity, clippy::too_many_arguments)]
    fn load_model_from_hf(
        &self,
        revision: Option<String>,
        token_source: TokenSource,
        dtype: &dyn TryIntoDType,
        device: &Device,
        silent: bool,
        mapper: DeviceMapSetting,
        in_situ_quant: Option<IsqType>,
        paged_attn_config: Option<PagedAttentionConfig>,
    ) -> Result<Arc<Mutex<dyn Pipeline + Send + Sync>>> {
        let cache = self
            .hf_cache_path
            .clone()
            .map(Cache::new)
            .unwrap_or_default();
        GLOBAL_HF_CACHE.get_or_init(|| cache);

        let paths: anyhow::Result<Box<dyn ModelPaths>> = get_paths!(
            LocalModelPaths,
            &token_source,
            revision.clone(),
            self,
            None,
            None,
            silent,
            self.config.from_uqff.is_some()
        );
        if let Some(from_uqff) = self.config.from_uqff.clone() {
            *self.from_uqff.write().unwrap() = Some(get_uqff_paths!(&from_uqff, self, silent));
        }
        *self
            .token_source
            .write()
            .expect("Failed to write to token source") = Some(token_source);
        *self.revision.write().expect("Failed to write to revision") = revision;
        self.load_model_from_path(
            &paths?,
            dtype,
            device,
            silent,
            mapper,
            in_situ_quant,
            paged_attn_config,
        )
    }

    #[allow(clippy::type_complexity, clippy::too_many_arguments)]
    fn load_model_from_path(
        &self,
        paths: &Box<dyn ModelPaths>,
        dtype: &dyn TryIntoDType,
        device: &Device,
        silent: bool,
        mut mapper: DeviceMapSetting,
        in_situ_quant: Option<IsqType>,
        mut paged_attn_config: Option<PagedAttentionConfig>,
    ) -> Result<Arc<Mutex<dyn Pipeline + Send + Sync>>> {
        let config = std::fs::read_to_string(paths.get_config_filename())?;

        let prompt_chunksize = self
            .config
            .prompt_chunksize
            .unwrap_or(DEFAULT_PROMPT_CHUNK_SIZE.try_into().unwrap())
            .get();

        info!("Prompt chunk size is {prompt_chunksize}.");

        let use_nccl = mistralrs_quant::distributed::use_nccl();

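        // Device selection: a distributed daemon worker pins itself to CUDA
        // device `worker_rank + 1` (the NCCL-driven main process occupies
        // device 0); otherwise every device similar to the requested one is a
        // candidate for automatic layer mapping.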
        let available_devices = if let Ok(payload) = env::var(distributed::IS_DAEMON_FLAG) {
            let payload: WorkerTransferData = serde_json::from_str(&payload)?;
            let WorkerTransferData::Init { id: _, worker_rank } = payload;
            vec![candle_core::Device::new_cuda(worker_rank + 1)?]
        } else if use_nccl {
            vec![candle_core::Device::new_cuda(0)?]
        } else {
            device_map::get_all_similar_devices(device)?
        };
        let device = if use_nccl {
            available_devices[0].clone()
        } else {
            device.clone()
        };

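        // With NCCL the distributed backend owns placement, so a dummy mapper
        // is installed. Otherwise, automatic mapping estimates the size of
        // each layer, scaled by the quantization pack factor when loading a
        // UQFF artifact or applying ISQ, and assigns layers to devices.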
        if use_nccl {
            mapper = DeviceMapSetting::DummyNccl {
                nm_device: available_devices[0].clone(),
            };
        } else if let DeviceMapSetting::Auto(params) = mapper.clone() {
            let dtype = dtype.try_into_dtype(&available_devices.iter().collect::<Vec<_>>())?;

            let (layer_sizes_in_bytes, non_mapped_size_in_bytes, total_model_size_in_bytes) =
                if let Some(serialized) = &*self.from_uqff.read().unwrap() {
                    let weight_pack_factor = {
                        let ser_artifacts = unsafe {
                            candle_core::safetensors::MmapedSafetensors::new(serialized)?
                        };
                        let mut total_pack_factors = 0;
                        let total_tensors = ser_artifacts.tensors().len();
                        for (_, artifact) in ser_artifacts.tensors() {
                            let artifact = artifact.data();
                            let isq_type = artifact[mistralrs_quant::UQFF_QUANT_TYPE_OFFSET];
                            let pack_factor = match QuantizedSerdeType::try_from(isq_type as usize)?
                            {
                                QuantizedSerdeType::Hqq => {
                                    HqqLayer::get_isq_type_from_uqff(Cow::Borrowed(artifact))?
                                        .pack_factor(dtype)
                                }
                                QuantizedSerdeType::Gguf => {
                                    GgufMatMul::get_isq_type_from_uqff(Cow::Borrowed(artifact))?
                                        .pack_factor(dtype)
                                }
                                QuantizedSerdeType::Fp8 => IsqType::F8E4M3.pack_factor(dtype),
                                QuantizedSerdeType::Unquant => 1,
                                QuantizedSerdeType::Afq => {
                                    AfqLayer::get_isq_type_from_uqff(Cow::Borrowed(artifact))?
                                        .pack_factor(dtype)
                                }
                            };
                            total_pack_factors += pack_factor;
                        }

                        total_pack_factors / total_tensors
                    };

                    let layer_sizes_in_bytes =
                        self.inner
                            .layer_sizes_in_bytes(&config, dtype, weight_pack_factor)?;
                    let non_mapped_size_in_bytes =
                        self.inner
                            .non_mapped_size_in_bytes(&config, dtype, weight_pack_factor)?;
                    let layer_sizes_sum = layer_sizes_in_bytes.iter().sum::<usize>();
                    (
                        layer_sizes_in_bytes,
                        non_mapped_size_in_bytes,
                        layer_sizes_sum + non_mapped_size_in_bytes,
                    )
                } else if let Some(isq) = in_situ_quant {
                    let weight_pack_factor = isq.pack_factor(dtype);
                    let layer_sizes_in_bytes =
                        self.inner
                            .layer_sizes_in_bytes(&config, dtype, weight_pack_factor)?;
                    let non_mapped_size_in_bytes =
                        self.inner
                            .non_mapped_size_in_bytes(&config, dtype, weight_pack_factor)?;
                    let layer_sizes_sum = layer_sizes_in_bytes.iter().sum::<usize>();
                    (
                        layer_sizes_in_bytes,
                        non_mapped_size_in_bytes,
                        layer_sizes_sum + non_mapped_size_in_bytes,
                    )
                } else {
                    let layer_sizes_in_bytes =
                        self.inner.layer_sizes_in_bytes(&config, dtype, 1)?;
                    let non_mapped_size_in_bytes =
                        self.inner.non_mapped_size_in_bytes(&config, dtype, 1)?;
                    let layer_sizes_sum = layer_sizes_in_bytes.iter().sum::<usize>();
                    (
                        layer_sizes_in_bytes,
                        non_mapped_size_in_bytes,
                        layer_sizes_sum + non_mapped_size_in_bytes,
                    )
                };

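            // Ask the model-specific loader to split layers across the
            // available devices given the measured sizes, the activation
            // dtype, and the PagedAttention KV-cache requirements.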
            let new = self.inner.get_device_layers(
                &config,
                self.inner.num_layers(&config)?,
                layer_sizes_in_bytes,
                non_mapped_size_in_bytes,
                total_model_size_in_bytes,
                &available_devices,
                dtype,
                &params,
                prompt_chunksize,
                paged_attn_config.as_ref(),
            )?;
            mapper = DeviceMapSetting::Map(new);
        }

        let pipeline_mapper = mapper.into_mapper(
            self.inner.num_layers(&config)?,
            &device,
            self.config.topology.as_ref(),
        )?;
        let mapper = mapper.into_mapper(
            self.inner.num_layers(&config)?,
            &device,
            self.config.topology.as_ref(),
        )?;
        let mut layer_devices = Vec::new();
        for layer in 0..self.inner.num_layers(&config)? {
            let device = mapper.device_for(layer, false).cloned();
            layer_devices.push(device);
        }
        let dtype = mapper.get_min_dtype(dtype)?;

        let mapping_uses_cpu = mapper.get_unique_devices().iter().any(Device::is_cpu);
        if mapping_uses_cpu {
            warn!("Device mapping contains a mix of GPU and CPU. There is no CPU support for PagedAttention, disabling PagedAttention.");
            paged_attn_config = None;
        }

        info!(
            "Model config: {:?}",
            self.inner
                .get_config_repr(&config, self.config.use_flash_attn)?
        );

        let mut loading_isq = in_situ_quant.is_some() || self.config.from_uqff.is_some();
        if let Some(ref topology) = self.config.topology {
            loading_isq |= topology
                .0
                .iter()
                .any(|layer| layer.as_ref().is_some_and(|layer| layer.isq.is_some()));
        }

        if self.config.imatrix.is_some() && self.config.calibration_file.is_some() {
            anyhow::bail!(
                "`imatrix` and `calibration_file` were both specified, this is not allowed."
            );
        }

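        // When ISQ will be applied, weights are first loaded to CPU so the
        // unquantized copy never sits on the GPU. The exception is
        // calibration: collecting an imatrix requires running the unquantized
        // model, so it is loaded directly to the target device.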
        let load_device = if !loading_isq || self.config.calibration_file.is_some() {
            loading_isq = false;
            device.clone()
        } else {
            Device::Cpu
        };

        let is_xlora = self.kind.is_adapted_and(|a| a.is_x_lora());

        let attention_mechanism = if paged_attn_config.is_some() {
            AttentionImplementation::PagedAttention
        } else {
            AttentionImplementation::Eager
        };

        let multi_progress = Arc::new(MultiProgress::new());

        let mut model = if use_nccl {
            let (mapper, sharded_vb) = distributed::prepare_distributed_mapper(
                dtype,
                &device,
                &load_device,
                &available_devices,
                &config,
                loading_isq,
                self.config.from_uqff.is_some(),
                self.config.organization,
                &*self.inner,
                paths.as_ref(),
            )?;

            match self.kind {
                ModelKind::Normal => normal_model_loader_sharded!(
                    sharded_vb,
                    config,
                    self.inner,
                    self.config.use_flash_attn,
                    mapper,
                    loading_isq,
                    device.clone(),
                    attention_mechanism,
                    multi_progress.clone(),
                ),
                ModelKind::Adapter {
                    adapter: AdapterKind::XLora,
                } => xlora_model_loader!(
                    paths,
                    Some(dtype),
                    &load_device,
                    layer_devices.clone(),
                    config,
                    self.inner,
                    self.config.use_flash_attn,
                    silent,
                    mapper,
                    loading_isq,
                    device.clone(),
                    multi_progress.clone(),
                ),
                ModelKind::Adapter {
                    adapter: AdapterKind::Lora,
                } => lora_model_loader!(
                    paths,
                    Some(dtype),
                    &load_device,
                    layer_devices.clone(),
                    config,
                    self.inner,
                    self.config.use_flash_attn,
                    silent,
                    mapper,
                    loading_isq,
                    self.config.from_uqff.is_some(),
                    device.clone(),
                    attention_mechanism,
                    matches!(self.config.organization, IsqOrganization::MoeExpertsOnly),
                    multi_progress.clone(),
                ),
                _ => unreachable!(),
            }
        } else {
            match self.kind {
                ModelKind::Normal => normal_model_loader!(
                    paths,
                    Some(dtype),
                    &load_device,
                    layer_devices.clone(),
                    config,
                    self.inner,
                    self.config.use_flash_attn,
                    silent,
                    mapper,
                    loading_isq,
                    self.config.from_uqff.is_some(),
                    device.clone(),
                    attention_mechanism,
                    matches!(self.config.organization, IsqOrganization::MoeExpertsOnly),
                    multi_progress.clone(),
                ),
                ModelKind::Adapter {
                    adapter: AdapterKind::XLora,
                } => xlora_model_loader!(
                    paths,
                    Some(dtype),
                    &load_device,
                    layer_devices.clone(),
                    config,
                    self.inner,
                    self.config.use_flash_attn,
                    silent,
                    mapper,
                    loading_isq,
                    device.clone(),
                    multi_progress.clone(),
                ),
                ModelKind::Adapter {
                    adapter: AdapterKind::Lora,
                } => lora_model_loader!(
                    paths,
                    Some(dtype),
                    &load_device,
                    layer_devices.clone(),
                    config,
                    self.inner,
                    self.config.use_flash_attn,
                    silent,
                    mapper,
                    loading_isq,
                    self.config.from_uqff.is_some(),
                    device.clone(),
                    attention_mechanism,
                    matches!(self.config.organization, IsqOrganization::MoeExpertsOnly),
                    multi_progress.clone(),
                ),
                _ => unreachable!(),
            }
        };

        let tokenizer = get_tokenizer(paths.get_tokenizer_filename(), None)?;
        let gen_conf: Option<GenerationConfig> = paths.get_gen_conf_filename().map(|f| {
            serde_json::from_str(&fs::read_to_string(f).unwrap())
                .expect("bos_token_id/eos_token_id missing in generation_config.json")
        });

        let chat_template = get_chat_template(
            paths,
            &self.jinja_explicit,
            &paths
                .get_chat_template_explicit()
                .as_ref()
                .map(|x| x.to_string_lossy().to_string())
                .clone(),
            &self.chat_template,
            None,
        );

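        // Imatrix calibration: run the (still unquantized) model over the
        // calibration text in fixed-size chunks, each prefixed with the BOS
        // token, while the model tracks activation statistics. The KV cache
        // is cleared between chunks so each one is processed as a fresh
        // prompt.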
        if let Some(calibration_file) = &self.config.calibration_file {
            let calibration_data = std::fs::read_to_string(calibration_file)?;
            let tokens = tokenizer
                .encode_fast(calibration_data, false)
                .map_err(anyhow::Error::msg)?
                .get_ids()
                .to_vec();
            info!(
                "Collecting imatrix from calibration file `{}` of {} tokens.",
                calibration_file.display(),
                tokens.len()
            );
            let bos_toks = chat_template.bos_tok().map(|b| vec![b]).unwrap_or_default();
            let bos_tok_id = tokenizer
                .token_to_id(&bos_toks[0])
                .expect("Somehow the bos token is not present.");

            match self.config.organization {
                IsqOrganization::Default => model.begin_track_stats()?,
                IsqOrganization::MoeExpertsOnly => model.begin_track_stats_moe_experts_only()?,
            }

            const CHUNK_SIZE: usize = 1024;
            let n_chunks = tokens.len().div_ceil(CHUNK_SIZE);
            let start = Instant::now();
            for (i, chunk) in tokens.chunks(CHUNK_SIZE).enumerate() {
                let chunk = [vec![bos_tok_id], chunk.to_vec()].concat();
                let chunk_len = chunk.len();

                let start = Instant::now();
                let inputs = make_prompt_chunk(
                    0,
                    vec![chunk],
                    &[0],
                    &load_device,
                    None,
                    false,
                    None,
                    Some(pipeline_mapper.as_ref()),
                )?;

                model.forward(
                    &inputs.input.to_device(model.device())?,
                    &inputs.positions,
                    inputs.context_lens.clone(),
                    inputs.position_ids.clone(),
                    None,
                    &inputs.flash_meta.clone(),
                )?;

                match model.cache_mut() {
                    EitherCache::Full(full) => {
                        for layer in &mut *full.lock() {
                            *layer = None
                        }
                    }
                    EitherCache::Normal(normal) => {
                        for layer in &mut *normal.lock().unwrap().0 {
                            layer.reset();
                        }
                    }
                }

                let end = Instant::now();
                info!(
                    "Processed chunk {}/{n_chunks} ({chunk_len} tokens), {:.2}s",
                    i + 1,
                    end.duration_since(start).as_secs_f32()
                );
            }
            load_device.synchronize()?;
            let end = Instant::now();
            info!(
                "Finished collecting imatrix in {:.2}s",
                end.duration_since(start).as_secs_f32()
            );
        }

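        // Quantize in place when ISQ (or a per-layer topology ISQ setting)
        // was requested and no prequantized UQFF artifact is being loaded;
        // otherwise, if a UQFF artifact was given, load the quantized weights
        // from it instead.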
        if (in_situ_quant.is_some() || self.config.topology.is_some())
            && self.config.from_uqff.is_none()
        {
            let imatrix_source = match (
                self.config.imatrix.as_ref(),
                self.config.calibration_file.is_some(),
            ) {
                (None, false) => None,
                (Some(file), false) => Some(ImatrixDataSource::File(file)),
                (None, true) => Some(ImatrixDataSource::Collected),
                (Some(_), true) => unreachable!(),
            };

            info!("Applying ISQ to all ranks.");

            let multi_progress = Arc::new(MultiProgress::new());

            model.quantize(
                in_situ_quant,
                model.device().clone(),
                self.config.topology.as_ref(),
                silent,
                imatrix_source,
                self.config.organization,
                self.config.write_uqff.as_ref(),
                UqffFullSer {
                    tokenizer: &tokenizer,
                    template_filename: paths.get_template_filename(),
                    generation_config: paths.get_gen_conf_filename(),
                    config: config.clone(),
                    processor_filename: &None,
                    preprocessor_filename: &None,
                },
                multi_progress.clone(),
            )?;
        } else if let Some(from_uqff) = &*self.from_uqff.read().unwrap() {
            model.load_from_artifacts(
                device.clone(),
                self.config.topology.as_ref(),
                silent,
                from_uqff,
            )?;
        }

        let paged_attn_config = if matches!(
            self.kind,
            ModelKind::Adapter {
                adapter: AdapterKind::XLora
            }
        ) {
            warn!("X-LoRA adapter models do not currently support PagedAttention, running without PagedAttention.");
            None
        } else {
            paged_attn_config
        };

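        // Size the PagedAttention KV cache from the configured GPU/CPU memory
        // budgets and block size, then build one cache engine spanning the
        // per-layer devices chosen by the mapper.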
        let (cache_config, cache_engine) = if let Some(paged_attn_config) = paged_attn_config {
            let cache_config = calculate_cache_config(
                paged_attn_config.mem_gpu,
                paged_attn_config.mem_cpu,
                paged_attn_config.block_size,
                dtype,
                model.config(),
                &device,
                &pipeline_mapper
                    .get_unique_devices()
                    .into_iter()
                    .map(Some)
                    .collect::<Vec<_>>(),
                silent,
            )?;

            let mut layer_devices = Vec::new();
            for layer in 0..self.inner.num_layers(&config)? {
                let device = model.get_layers().1.device_for(layer, false).cloned();
                layer_devices.push(device);
            }
            let cache_engine = CacheEngine::new(
                model.config(),
                &cache_config,
                dtype,
                model.device(),
                layer_devices.clone(),
            )?;

            (Some(cache_config), Some(cache_engine))
        } else {
            (None, None)
        };

        let max_seq_len = model.max_seq_len();
        let tok_env = build_tok_env(tokenizer.clone());
        let num_hidden_layers = match model.cache() {
            EitherCache::Full(full) => full.lock().len(),
            EitherCache::Normal(normal) => normal.lock().unwrap().0.len(),
        };
        let eos = calculate_eos_tokens(&chat_template, gen_conf, &tokenizer);
        let sliding_window = model.config().sliding_window;
        let model_metadata = Arc::new(model.config().clone());

        Ok(Arc::new(Mutex::new(NormalPipeline {
            model,
            tokenizer: tokenizer.into(),
            no_kv_cache: self.no_kv_cache,
            chat_template: Arc::new(chat_template),
            non_granular_state: self.tgt_non_granular_index.map(|tgt_non_granular_index| {
                NonGranularState {
                    non_granular_index: Arc::new(Mutex::new(0)),
                    tgt_non_granular_index,
                }
            }),
            model_id: self.model_id.clone(),
            metadata: Arc::new(GeneralMetadata {
                max_seq_len,
                tok_env: Some(tok_env),
                no_kv_cache: self.no_kv_cache,
                no_prefix_cache: is_xlora,
                num_hidden_layers,
                eos_tok: eos,
                kind: self.kind.clone(),
                is_xlora,
                activation_dtype: dtype,
                sliding_window,
                cache_config,
                cache_engine,
                prompt_chunksize: Some(NonZero::new(prompt_chunksize).unwrap()),
                model_metadata: Some(model_metadata),
            }),
            topology: self.config.topology.clone(),
            silent,
            organization: self.config.organization,
            template_filename: paths.get_template_filename().clone(),
            generation_config: paths.get_gen_conf_filename().cloned(),
            config,
            imatrix: self.config.imatrix.clone(),
            mapper: pipeline_mapper,
        })))
    }

    fn get_id(&self) -> String {
        self.model_id.clone()
    }

    fn get_kind(&self) -> ModelKind {
        self.kind.clone()
    }
}

impl PreProcessingMixin for NormalPipeline {
    fn get_chat_template(&self) -> Option<Arc<ChatTemplate>> {
        Some(self.chat_template.clone())
    }
    fn get_input_processor_config(&self) -> Option<Arc<dyn Any>> {
        None
    }
}

impl IsqPipelineMixin for NormalPipeline {
    fn re_isq_model(&mut self, dtype: IsqType) -> Result<()> {
        let device = self.device().clone();
        let multi_progress = Arc::new(MultiProgress::new());
        self.model.quantize(
            Some(dtype),
            device.clone(),
            self.topology.as_ref(),
            self.silent,
            self.imatrix.as_ref().map(ImatrixDataSource::File),
            self.organization,
            None,
            UqffFullSer {
                tokenizer: &self.tokenizer,
                template_filename: &self.template_filename,
                generation_config: self.generation_config.as_ref(),
                config: self.config.clone(),
                processor_filename: &None,
                preprocessor_filename: &None,
            },
            multi_progress.clone(),
        )?;
        Ok(())
    }
}

impl CacheManagerMixin for NormalPipeline {
    fn clone_in_cache(&self, seqs: &mut [&mut Sequence]) {
        if matches!(self.model.cache(), EitherCache::Full(_)) {
            FullCacheManager.clone_in_cache(self, seqs, false)
        } else {
            NormalCacheManager.clone_in_cache(self, seqs, false)
        }
    }
    fn clone_out_cache(&self, seqs: &mut [&mut Sequence]) {
        if matches!(self.model.cache(), EitherCache::Full(_)) {
            FullCacheManager.clone_out_cache(self, seqs, false)
        } else {
            NormalCacheManager.clone_out_cache(self, seqs, false)
        }
    }
    fn set_none_cache(
        &self,
        seqs: &mut [&mut Sequence],
        reset_non_granular: bool,
        modify_draft_cache: bool,
        load_preallocated_cache: bool,
    ) {
        if matches!(self.model.cache(), EitherCache::Full(_)) {
            FullCacheManager.set_none_cache(self, seqs, modify_draft_cache, false);
        } else {
            NormalCacheManager.set_none_cache(
                self,
                seqs,
                modify_draft_cache,
                load_preallocated_cache,
            );
        }
        if reset_non_granular {
            self.reset_non_granular_state()
        }
    }
    fn cache(&self) -> &EitherCache {
        self.model.cache()
    }
}

impl MetadataMixin for NormalPipeline {
    fn device(&self) -> Device {
        self.model.device().clone()
    }
    fn tokenizer(&self) -> Option<Arc<Tokenizer>> {
        Some(self.tokenizer.clone())
    }
    fn name(&self) -> String {
        self.model_id.clone()
    }
    fn reset_non_granular_state(&self) {
        if let Some(s) = self.non_granular_state.as_ref() {
            *self.cache().full().get_scalings_cache() = None;
            *get_mut_arcmutex!(s.non_granular_index) = 0;
        }
    }
    fn get_metadata(&self) -> Arc<GeneralMetadata> {
        self.metadata.clone()
    }
    fn device_mapper(&self) -> Option<&dyn DeviceMapper> {
        Some(&*self.mapper)
    }
}

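// The forward pass dispatches to either the plain or the X-LoRA forward and
// validates that PagedAttention metadata is present exactly when a cache
// engine exists. On Metal, the pass is wrapped in an autoreleasepool to bound
// the lifetime of temporary Objective-C allocations.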
#[async_trait::async_trait]
impl Pipeline for NormalPipeline {
    fn forward_inputs(
        &mut self,
        inputs: Box<dyn Any>,
        return_raw_logits: bool,
    ) -> Result<ForwardInputsResult, candle_core::Error> {
        let ModelInputs {
            input_ids,
            input_ids_full,
            seqlen_offsets,
            seqlen_offsets_full,
            context_lens,
            position_ids,
            paged_attn_meta,
            flash_meta,
            flash_meta_full,
        } = *inputs.downcast().expect("Downcast failed.");
        let metadata = self.get_metadata();
        let paged_attn_meta = match (&metadata.cache_engine, &paged_attn_meta) {
            (Some(cache_engine), Some(meta)) => Some((cache_engine, meta)),
            (Some(_), None) => {
                candle_core::bail!("Forward step expected PagedAttention input metadata. This was not provided; please ensure that the scheduler config is correctly configured for PagedAttention.")
            }
            (None, Some(_)) => {
                candle_core::bail!("Forward step got PagedAttention input metadata but there is no cache engine. Please raise an issue.")
            }
            (None, None) => None,
        };
        #[cfg(feature = "metal")]
        let logits = objc::rc::autoreleasepool(|| -> candle_core::Result<Tensor> {
            match self.model.is_xlora() {
                false => {
                    let paged_attn_meta = paged_attn_meta
                        .as_ref()
                        .map(|meta| (meta.0.get_kv_cache().clone(), meta.1.clone()));

                    self.model.forward(
                        &input_ids,
                        &seqlen_offsets,
                        context_lens,
                        position_ids,
                        paged_attn_meta.as_ref().map(|(a, b)| (a.clone(), b)),
                        &flash_meta,
                    )
                }
                true => self.model.xlora_forward(
                    &input_ids,
                    input_ids_full.as_ref().unwrap_or(&input_ids),
                    &seqlen_offsets,
                    seqlen_offsets_full.as_ref().unwrap_or(&seqlen_offsets),
                    self.no_kv_cache,
                    &self.non_granular_state,
                    context_lens,
                    position_ids,
                    &flash_meta,
                    flash_meta_full.as_ref().unwrap_or(&flash_meta),
                ),
            }
        })?;
        #[cfg(not(feature = "metal"))]
        let logits = match self.model.is_xlora() {
            false => {
                let paged_attn_meta = paged_attn_meta
                    .as_ref()
                    .map(|meta| (meta.0.get_kv_cache().clone(), meta.1.clone()));

                self.model.forward(
                    &input_ids,
                    &seqlen_offsets,
                    context_lens,
                    position_ids,
                    paged_attn_meta.as_ref().map(|(a, b)| (a.clone(), b)),
                    &flash_meta,
                )?
            }
            true => self.model.xlora_forward(
                &input_ids,
                input_ids_full.as_ref().unwrap_or(&input_ids),
                &seqlen_offsets,
                seqlen_offsets_full.as_ref().unwrap_or(&seqlen_offsets),
                self.no_kv_cache,
                &self.non_granular_state,
                context_lens,
                position_ids,
                &flash_meta,
                flash_meta_full.as_ref().unwrap_or(&flash_meta),
            )?,
        };
        if return_raw_logits {
            Ok(ForwardInputsResult::RawLogits { logits })
        } else {
            Ok(ForwardInputsResult::CausalGeneration { logits })
        }
    }
    async fn sample_causal_gen(
        &self,
        seqs: &mut [&mut Sequence],
        logits: Vec<Tensor>,
        prefix_cacher: &mut PrefixCacheManagerV2,
        disable_eos_stop: bool,
        rng: Arc<std::sync::Mutex<Isaac64Rng>>,
    ) -> Result<(), candle_core::Error> {
        sample_and_add_toks(self, seqs, logits, prefix_cacher, disable_eos_stop, rng).await
    }
    fn category(&self) -> ModelCategory {
        ModelCategory::Text
    }
}

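// AnyMoE support: expert layers are created from per-expert safetensors
// checkpoints (optionally plus a single-file gate model), with tensors
// filtered to the requested layers by a caller-supplied regex.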
impl AnyMoePipelineMixin for NormalPipeline {
    fn amoe_finish_training(&mut self, gate_model_id: Option<String>) -> candle_core::Result<()> {
        self.model.finish_training(gate_model_id)
    }
    fn amoe_layer_vars(&self) -> Vec<Vec<Var>> {
        self.model.get_vars()
    }
    fn amoe_base_model_trainable_params(&self) -> usize {
        self.model.trainable_params()
    }
    fn amoe_take_cached_gating_outputs(&mut self) -> Vec<Tensor> {
        self.model.take_cached_gating_outputs()
    }
    fn amoe_create_layers(
        &mut self,
        model_ids: Vec<String>,
        token: &TokenSource,
        revision: Option<String>,
        match_regex: &str,
        config: crate::amoe::AnyMoeConfig,
        dtype: candle_core::DType,
        dev: &Device,
        (prefix, mlp): (String, String),
        layers: Vec<usize>,
        expert_type: AnyMoeExpertType,
        silent: bool,
        gate_model_id: Option<String>,
    ) -> candle_core::Result<()> {
        let mut vbs = Vec::new();
        let regex = Regex::new(match_regex).map_err(candle_core::Error::msg)?;
        for model_id in model_ids {
            let model_id_str = &model_id;
            let model_id = Path::new(&model_id);

            let api = {
                let cache = GLOBAL_HF_CACHE.get().cloned().unwrap_or_default();
                let mut api = ApiBuilder::from_cache(cache)
                    .with_progress(!silent)
                    .with_token(get_token(token).map_err(candle_core::Error::msg)?);
                if let Ok(x) = std::env::var("HF_HUB_CACHE") {
                    api = api.with_cache_dir(x.into());
                }
                api.build().map_err(candle_core::Error::msg)?
            };
            let revision = revision.clone().unwrap_or("main".to_string());
            let api = api.repo(Repo::with_revision(
                model_id_str.clone(),
                RepoType::Model,
                revision.clone(),
            ));

            let mut filenames = vec![];
            for rfilename in api_dir_list!(api, model_id).filter(|x| x.ends_with(".safetensors")) {
                filenames.push(api_get_file!(api, &rfilename, model_id));
            }

            let regex = regex.clone();
            let match_regex_clone = match_regex.to_string();
            let layers_clone = layers.clone();
            let vb = from_mmaped_safetensors(
                filenames,
                vec![],
                Some(dtype),
                dev,
                vec![None],
                silent,
                None,
                move |key| {
                    if regex.is_match(&key) {
                        // Extract the layer index from the tensor key: it is the
                        // dot-delimited component immediately preceding the match.
                        let last_layer_idx = key.find(&match_regex_clone).unwrap() - 1;
                        let first_layer_idx = key[..last_layer_idx].rfind('.').unwrap();
                        let layer_n = key[first_layer_idx + 1..last_layer_idx]
                            .parse::<usize>()
                            .unwrap();
                        layers_clone.contains(&layer_n) || layers_clone.is_empty()
                    } else {
                        false
                    }
                },
                Arc::new(|_| DeviceForLoadTensor::Base),
            )?;
            vbs.push(vb);
        }

        let gate_vb = if let Some(gate_model_id) = gate_model_id {
            let model_id_str = &gate_model_id;
            let model_id = Path::new(&gate_model_id);

            let api = {
                let cache = GLOBAL_HF_CACHE.get().cloned().unwrap_or_default();
                let mut api = ApiBuilder::from_cache(cache)
                    .with_progress(!silent)
                    .with_token(get_token(token).map_err(candle_core::Error::msg)?);
                if let Ok(x) = std::env::var("HF_HUB_CACHE") {
                    api = api.with_cache_dir(x.into());
                }
                api.build().map_err(candle_core::Error::msg)?
            };
            let revision = revision.clone().unwrap_or("main".to_string());
            let api = api.repo(Repo::with_revision(
                model_id_str.clone(),
                RepoType::Model,
                revision.clone(),
            ));

            let mut gate_filenames = vec![];
            for rfilename in api_dir_list!(api, model_id).filter(|x| x.ends_with(".safetensors")) {
                gate_filenames.push(api_get_file!(api, &rfilename, model_id));
            }
            assert_eq!(
                gate_filenames.len(),
                1,
                "Gate model ID must contain only one .safetensors file"
            );

            let vb = from_mmaped_safetensors(
                gate_filenames.clone(),
                vec![],
                Some(dtype),
                dev,
                vec![None],
                silent,
                None,
                |_| true,
                Arc::new(|_| DeviceForLoadTensor::Base),
            )?;
            info!(
                "Loaded gating layers from `{}`",
                gate_filenames[0].display()
            );
            Some(vb)
        } else {
            None
        };

        self.model.create_anymoe_layers(
            vbs.clone(),
            config.clone(),
            (prefix.clone(), mlp.clone()),
            layers.clone(),
            expert_type.clone(),
            gate_vb.clone(),
        )?;

        Ok(())
    }
    fn amoe_supported(&self) -> bool {
        self.model.amoe_supported()
    }
}