mistralrs_core/pipeline/
normal.rs

1use super::inputs_processor::DEFAULT_PROMPT_CHUNK_SIZE;
2use super::isq::ImatrixDataSource;
3use super::llg::build_llg_factory;
4use super::{
5    get_model_paths, get_xlora_paths, text_models_inputs_processor::ModelInputs, AdapterKind,
6    CacheManager, GeneralMetadata, Loader, ModelKind, ModelPaths, NormalModel, NormalModelLoader,
7    TokenSource,
8};
9use super::{
10    AnyMoePipelineMixin, CacheManagerMixin, EitherCache, ForwardInputsResult, IsqOrganization,
11    IsqPipelineMixin, MetadataMixin, ModelCategory, PreProcessingMixin,
12};
13use super::{
14    AutoNormalLoader, DeepSeekV2Loader, DeepSeekV3Loader, Gemma2Loader, GemmaLoader, LlamaLoader,
15    MistralLoader, MixtralLoader, NormalLoaderType, Phi2Loader, Phi3Loader, Phi3_5MoELoader,
16    Qwen2Loader, Qwen3Loader, Qwen3MoELoader, Starcoder2Loader,
17};
18use crate::amoe::AnyMoeExpertType;
19use crate::device_map::{self, DeviceMapper};
20use crate::distributed::{self, WorkerTransferData};
21use crate::kv_cache::{FullCacheManager, NormalCacheManager};
22use crate::lora::Ordering;
23use crate::paged_attention::{calculate_cache_config, AttentionImplementation, CacheEngine};
24use crate::pipeline::chat_template::{calculate_eos_tokens, GenerationConfig};
25use crate::pipeline::get_chat_template;
26use crate::pipeline::isq::UqffFullSer;
27use crate::pipeline::loaders::QuantizationConfigShim;
28use crate::pipeline::sampling::sample_and_add_toks;
29use crate::pipeline::text_models_inputs_processor::make_prompt_chunk;
30use crate::pipeline::{ChatTemplate, LocalModelPaths};
31use crate::prefix_cacher::PrefixCacheManagerV2;
32use crate::sequence::Sequence;
33use crate::utils::tokenizer::get_tokenizer;
34use crate::utils::varbuilder_utils::DeviceForLoadTensor;
35use crate::utils::{tokens::get_token, varbuilder_utils::from_mmaped_safetensors};
36use crate::xlora_models::NonGranularState;
37use crate::{
38    api_dir_list, api_get_file, get_mut_arcmutex, get_paths, get_uqff_paths, lora_model_loader,
39    normal_model_loader, normal_model_loader_sharded, xlora_model_loader, DeviceMapSetting,
40    PagedAttentionConfig, Pipeline, Topology, TryIntoDType, GLOBAL_HF_CACHE,
41};
42use anyhow::Result;
43use candle_core::{Device, Tensor, Var};
44use hf_hub::Cache;
45use hf_hub::{api::sync::ApiBuilder, Repo, RepoType};
46use indicatif::MultiProgress;
47use mistralrs_quant::log::once_log_info;
48use mistralrs_quant::{AfqLayer, GgufMatMul, HqqLayer, IsqType, QuantizedSerdeType};
49use rand_isaac::Isaac64Rng;
50use regex_automata::meta::Regex;
51use std::any::Any;
52use std::borrow::Cow;
53use std::num::{NonZero, NonZeroUsize};
54use std::path::{Path, PathBuf};
55use std::str::FromStr;
56use std::sync::{Arc, RwLock};
57use std::time::Instant;
58use std::{env, fs};
59use tokenizers::Tokenizer;
60use tokio::sync::Mutex;
61use tracing::{info, warn};
62
/// State held by the pipeline for a "normal" model.
///
/// NOTE(review): no constructor is visible in this part of the file; presumably
/// `NormalLoader::load_model_from_path` assembles this struct — confirm at the
/// construction site before relying on the per-field notes below.
pub struct NormalPipeline {
    /// The loaded model, type-erased behind the `NormalModel` trait.
    model: Box<dyn NormalModel + Send + Sync>,
    /// Shared handle to the tokenizer.
    tokenizer: Arc<Tokenizer>,
    /// Presumably mirrors the loader's `no_kv_cache` setting — confirm.
    no_kv_cache: bool,
    /// Shared handle to the chat template.
    chat_template: Arc<ChatTemplate>,
    /// X-LoRA non-granular state, when present.
    non_granular_state: Option<NonGranularState>,
    /// Identifier of the model.
    model_id: String,
    /// Shared general metadata for the pipeline.
    metadata: Arc<GeneralMetadata>,
    /// Optional topology.
    topology: Option<Topology>,
    /// Presumably suppresses progress output when `true` — confirm.
    silent: bool,
    /// ISQ organization.
    organization: IsqOrganization,
    // For full UQFF serialization
    /// Path of the chat template file, if any.
    template_filename: Option<PathBuf>,
    /// Path of the generation config file, if any.
    generation_config: Option<PathBuf>,
    /// Presumably the raw text of the model config file — confirm.
    config: String,
    /// Path of an imatrix file, if any.
    imatrix: Option<PathBuf>,
    /// Device mapper.
    mapper: Box<dyn DeviceMapper + Send + Sync>,
}
81
/// A loader for a "normal" (non-quantized) model.
///
/// Instances are produced by `NormalLoaderBuilder::build`. The `RwLock`-wrapped
/// fields start out as `None` and are filled in by `load_model_from_hf`, which
/// only has `&self` access.
pub struct NormalLoader {
    /// Architecture-specific loader chosen in `NormalLoaderBuilder::build`
    /// (or the automatic loader when no loader type was given).
    inner: Box<dyn NormalModelLoader>,
    /// Model ID taken from the builder.
    model_id: String,
    /// Loading options specific to normal models.
    config: NormalSpecificConfig,
    /// X-LoRA model ID, set through `NormalLoaderBuilder::with_xlora`.
    xlora_model_id: Option<String>,
    /// LoRA adapter IDs, set through `NormalLoaderBuilder::with_lora`.
    lora_adapter_ids: Option<Vec<String>>,
    /// Whether this is a plain model or an adapter (X-LoRA / LoRA) model.
    kind: ModelKind,
    /// X-LoRA ordering, set through `NormalLoaderBuilder::with_xlora`.
    xlora_order: Option<Ordering>,
    /// Copied from the builder's `no_kv_cache` setting.
    no_kv_cache: bool,
    /// Optional chat template override; forwarded to `get_chat_template`.
    chat_template: Option<String>,
    /// Optional tokenizer JSON, copied from the builder.
    tokenizer_json: Option<String>,
    /// Optional non-granular index for X-LoRA, copied from the builder.
    tgt_non_granular_index: Option<usize>,
    /// Token source recorded by `load_model_from_hf`.
    token_source: RwLock<Option<TokenSource>>,
    /// Revision recorded by `load_model_from_hf`.
    revision: RwLock<Option<String>>,
    /// Resolved UQFF file paths; populated by `load_model_from_hf` when
    /// `config.from_uqff` is set and read back in `load_model_from_path`.
    from_uqff: RwLock<Option<Vec<PathBuf>>>,
    /// Optional explicit Jinja template; forwarded to `get_chat_template`.
    jinja_explicit: Option<String>,
    /// Directory for the Hugging Face cache; when `None`, `load_model_from_hf`
    /// uses the default cache.
    hf_cache_path: Option<PathBuf>,
}
101
/// A builder for a loader for a "normal" (non-quantized) model.
///
/// Start with `new`, optionally chain `with_xlora`, `with_lora` and
/// `hf_cache_path`, then call `build` to obtain the loader.
#[derive(Default)]
pub struct NormalLoaderBuilder {
    /// Model ID; when `None`, `with_xlora` fills it in from the ordering's
    /// `base_model_id`.
    model_id: Option<String>,
    /// Loading options specific to normal models.
    config: NormalSpecificConfig,
    /// X-LoRA model ID, set by `with_xlora`.
    xlora_model_id: Option<String>,
    /// LoRA adapter IDs, set by `with_lora`.
    lora_adapter_ids: Option<Vec<String>>,
    /// `ModelKind::Normal` after `new`; switched to an adapter kind by
    /// `with_xlora` / `with_lora`.
    kind: ModelKind,
    /// X-LoRA ordering, set by `with_xlora`.
    xlora_order: Option<Ordering>,
    /// Set by `new`; overwritten by `with_xlora`.
    no_kv_cache: bool,
    /// Optional chat template override.
    chat_template: Option<String>,
    /// Optional tokenizer JSON.
    tokenizer_json: Option<String>,
    /// Optional non-granular index for X-LoRA, set by `with_xlora`.
    tgt_non_granular_index: Option<usize>,
    /// Optional explicit Jinja template.
    jinja_explicit: Option<String>,
    /// Optional Hugging Face cache directory, set by `hf_cache_path`.
    hf_cache_path: Option<PathBuf>,
}
118
/// Config specific to loading a normal model.
#[derive(Clone, Default)]
pub struct NormalSpecificConfig {
    /// Prompt chunk size; the loader falls back to `DEFAULT_PROMPT_CHUNK_SIZE`
    /// when this is `None`.
    pub prompt_chunksize: Option<NonZeroUsize>,
    /// Optional topology; forwarded to the device mapper and to quantization.
    /// If any layer in it requests ISQ, the load is treated as an ISQ load.
    pub topology: Option<Topology>,
    /// ISQ organization; `MoeExpertsOnly` selects the MoE-experts-only
    /// predicates and statistics tracking.
    pub organization: IsqOrganization,
    /// Where to write UQFF output, forwarded to `quantize`. Setting this
    /// disables the immediate-ISQ path.
    pub write_uqff: Option<PathBuf>,
    /// UQFF files to load from instead of quantizing in place.
    pub from_uqff: Option<Vec<PathBuf>>,
    /// Path to an imatrix file. Must not be combined with `calibration_file`;
    /// the loader returns an error when both are set.
    pub imatrix: Option<PathBuf>,
    /// Path to a text file that is tokenized and run through the model to
    /// collect imatrix data. Must not be combined with `imatrix`.
    pub calibration_file: Option<PathBuf>,
    /// Hugging Face cache directory.
    ///
    /// NOTE(review): not read anywhere in the visible portion of this file;
    /// the loader uses `NormalLoaderBuilder::hf_cache_path` instead — confirm
    /// whether this field is meant to act as a fallback.
    pub hf_cache_path: Option<PathBuf>,
}
131
132impl NormalLoaderBuilder {
133    pub fn new(
134        config: NormalSpecificConfig,
135        chat_template: Option<String>,
136        tokenizer_json: Option<String>,
137        model_id: Option<String>,
138        no_kv_cache: bool,
139        jinja_explicit: Option<String>,
140    ) -> Self {
141        Self {
142            config,
143            chat_template,
144            tokenizer_json,
145            model_id,
146            kind: ModelKind::Normal,
147            jinja_explicit,
148            no_kv_cache,
149            ..Default::default()
150        }
151    }
152
153    fn with_adapter(
154        mut self,
155        xlora_model_id: String,
156        xlora_order: Ordering,
157        no_kv_cache: bool,
158        tgt_non_granular_index: Option<usize>,
159    ) -> Self {
160        self.xlora_model_id = Some(xlora_model_id);
161        self.xlora_order = Some(xlora_order);
162        self.no_kv_cache = no_kv_cache;
163        self.tgt_non_granular_index = tgt_non_granular_index;
164        self.model_id = if let Some(id) = self.model_id {
165            Some(id)
166        } else {
167            info!(
168                "Using adapter base model ID: `{}`",
169                self.xlora_order.as_ref().unwrap().base_model_id
170            );
171            Some(self.xlora_order.as_ref().unwrap().base_model_id.clone())
172        };
173        self
174    }
175
176    pub fn with_xlora(
177        mut self,
178        xlora_model_id: String,
179        xlora_order: Ordering,
180        no_kv_cache: bool,
181        tgt_non_granular_index: Option<usize>,
182    ) -> Self {
183        self.kind = ModelKind::Adapter {
184            adapter: AdapterKind::XLora,
185        };
186        self.with_adapter(
187            xlora_model_id,
188            xlora_order,
189            no_kv_cache,
190            tgt_non_granular_index,
191        )
192    }
193
194    pub fn with_lora(mut self, lora_adapter_ids: Vec<String>) -> Self {
195        self.kind = ModelKind::Adapter {
196            adapter: AdapterKind::Lora,
197        };
198        self.lora_adapter_ids = Some(lora_adapter_ids);
199        self
200    }
201
202    pub fn hf_cache_path(mut self, hf_cache_path: PathBuf) -> Self {
203        self.hf_cache_path = Some(hf_cache_path);
204        self
205    }
206
207    /// If the loader type is not specified, loader type is automatically determined from the
208    /// `architectures` array in the config.
209    pub fn build(self, loader_tp: Option<NormalLoaderType>) -> anyhow::Result<Box<dyn Loader>> {
210        let loader: Box<dyn NormalModelLoader> = match loader_tp {
211            Some(NormalLoaderType::Mistral) => Box::new(MistralLoader),
212            Some(NormalLoaderType::Gemma) => Box::new(GemmaLoader),
213            Some(NormalLoaderType::Llama) => Box::new(LlamaLoader),
214            Some(NormalLoaderType::Mixtral) => Box::new(MixtralLoader),
215            Some(NormalLoaderType::Phi2) => Box::new(Phi2Loader),
216            Some(NormalLoaderType::Phi3) => Box::new(Phi3Loader),
217            Some(NormalLoaderType::Qwen2) => Box::new(Qwen2Loader),
218            Some(NormalLoaderType::Gemma2) => Box::new(Gemma2Loader),
219            Some(NormalLoaderType::Starcoder2) => Box::new(Starcoder2Loader),
220            Some(NormalLoaderType::Phi3_5MoE) => Box::new(Phi3_5MoELoader),
221            Some(NormalLoaderType::DeepSeekV2) => Box::new(DeepSeekV2Loader),
222            Some(NormalLoaderType::DeepSeekV3) => Box::new(DeepSeekV3Loader),
223            Some(NormalLoaderType::Qwen3) => Box::new(Qwen3Loader),
224            Some(NormalLoaderType::Qwen3Moe) => Box::new(Qwen3MoELoader),
225            None => Box::new(AutoNormalLoader),
226        };
227        Ok(Box::new(NormalLoader {
228            inner: loader,
229            model_id: self.model_id.unwrap(),
230            config: self.config,
231            xlora_model_id: self.xlora_model_id,
232            lora_adapter_ids: self.lora_adapter_ids,
233            kind: self.kind,
234            xlora_order: self.xlora_order,
235            no_kv_cache: self.no_kv_cache,
236            chat_template: self.chat_template,
237            tokenizer_json: self.tokenizer_json,
238            tgt_non_granular_index: self.tgt_non_granular_index,
239            jinja_explicit: self.jinja_explicit,
240            token_source: RwLock::new(None),
241            revision: RwLock::new(None),
242            from_uqff: RwLock::new(None),
243            hf_cache_path: self.hf_cache_path,
244        }))
245    }
246}
247
248impl Loader for NormalLoader {
249    #[allow(clippy::type_complexity, clippy::too_many_arguments)]
250    fn load_model_from_hf(
251        &self,
252        revision: Option<String>,
253        token_source: TokenSource,
254        dtype: &dyn TryIntoDType,
255        device: &Device,
256        silent: bool,
257        mapper: DeviceMapSetting,
258        in_situ_quant: Option<IsqType>,
259        paged_attn_config: Option<PagedAttentionConfig>,
260    ) -> Result<Arc<Mutex<dyn Pipeline + Send + Sync>>> {
261        let cache = self
262            .hf_cache_path
263            .clone()
264            .map(Cache::new)
265            .unwrap_or_default();
266        GLOBAL_HF_CACHE.get_or_init(|| cache);
267
268        let paths: anyhow::Result<Box<dyn ModelPaths>> = get_paths!(
269            LocalModelPaths,
270            &token_source,
271            revision.clone(),
272            self,
273            None,
274            None,
275            silent,
276            self.config.from_uqff.is_some()
277        );
278        if let Some(from_uqff) = self.config.from_uqff.clone() {
279            *self.from_uqff.write().unwrap() = Some(get_uqff_paths!(&from_uqff, self, silent));
280        }
281        *self
282            .token_source
283            .write()
284            .expect("Failed to write to token source") = Some(token_source);
285        *self.revision.write().expect("Failed to write to revision") = revision;
286        self.load_model_from_path(
287            &paths?,
288            dtype,
289            device,
290            silent,
291            mapper,
292            in_situ_quant,
293            paged_attn_config,
294        )
295    }
296
297    #[allow(clippy::type_complexity, clippy::too_many_arguments)]
298    fn load_model_from_path(
299        &self,
300        paths: &Box<dyn ModelPaths>,
301        dtype: &dyn TryIntoDType,
302        device: &Device,
303        silent: bool,
304        mut mapper: DeviceMapSetting,
305        mut in_situ_quant: Option<IsqType>,
306        mut paged_attn_config: Option<PagedAttentionConfig>,
307    ) -> Result<Arc<Mutex<dyn Pipeline + Send + Sync>>> {
308        let config = std::fs::read_to_string(paths.get_config_filename())?;
309
310        if !self.inner.supports_paged_attention(&config)? {
311            paged_attn_config = None;
312        }
313
314        // Apply default prompt size here
315        let prompt_chunksize = self
316            .config
317            .prompt_chunksize
318            .unwrap_or(DEFAULT_PROMPT_CHUNK_SIZE.try_into().unwrap())
319            .get();
320
321        info!("Prompt chunk size is {prompt_chunksize}.",);
322
323        let use_nccl = mistralrs_quant::distributed::use_nccl();
324
325        let available_devices = if let Ok(payload) = env::var(distributed::IS_DAEMON_FLAG) {
326            let payload: WorkerTransferData = serde_json::from_str(&payload)?;
327            let WorkerTransferData::Init { id: _, worker_rank } = payload;
328            vec![candle_core::Device::new_cuda(worker_rank + 1)?]
329        } else if use_nccl {
330            vec![candle_core::Device::new_cuda(0)?]
331        } else {
332            device_map::get_all_similar_devices(device)?
333        };
334        let device = if use_nccl {
335            available_devices[0].clone()
336        } else {
337            device.clone()
338        };
339
340        // If auto, convert to Map if not using nccl
341        if use_nccl {
342            mapper = DeviceMapSetting::DummyNccl {
343                nm_device: available_devices[0].clone(),
344            };
345        } else if let DeviceMapSetting::Auto(params) = mapper.clone() {
346            // Initial dtype
347            let dtype = dtype.try_into_dtype(&available_devices.iter().collect::<Vec<_>>())?;
348
349            // Disable ISQ if we are loading a prequantized model.
350            if QuantizationConfigShim::get_quant_config_pack_factor(&config, dtype)? != 1 {
351                in_situ_quant = None;
352            }
353
354            // ISQ or UQFF: quantized path
355            // Match logic below where UQFF has priority
356            let (layer_sizes_in_bytes, non_mapped_size_in_bytes, total_model_size_in_bytes) =
357                if let Some(serialized) = &*self.from_uqff.read().unwrap() {
358                    let weight_pack_factor = {
359                        let ser_artifacts = unsafe {
360                            candle_core::safetensors::MmapedSafetensors::multi(serialized)?
361                        };
362                        let mut total_pack_factors = 0;
363                        let total_tensors = ser_artifacts.tensors().len();
364                        for (_, artifact) in ser_artifacts.tensors() {
365                            let artifact = artifact.data();
366                            // NOTE(EricLBuehler): isq type is ALWAYS byte 4 (5th) of the tensor.
367                            let isq_type = artifact[mistralrs_quant::UQFF_QUANT_TYPE_OFFSET];
368                            let pack_factor = match QuantizedSerdeType::try_from(isq_type as usize)?
369                            {
370                                QuantizedSerdeType::Hqq => {
371                                    HqqLayer::get_isq_type_from_uqff(Cow::Borrowed(artifact))?
372                                        .pack_factor(dtype)
373                                }
374                                QuantizedSerdeType::Gguf => {
375                                    GgufMatMul::get_isq_type_from_uqff(Cow::Borrowed(artifact))?
376                                        .pack_factor(dtype)
377                                }
378                                QuantizedSerdeType::Fp8 => IsqType::F8E4M3.pack_factor(dtype),
379                                QuantizedSerdeType::Unquant => 1,
380                                QuantizedSerdeType::Afq => {
381                                    AfqLayer::get_isq_type_from_uqff(Cow::Borrowed(artifact))?
382                                        .pack_factor(dtype)
383                                }
384                            };
385                            total_pack_factors += pack_factor;
386                        }
387
388                        total_pack_factors / total_tensors
389                    };
390
391                    let layer_sizes_in_bytes =
392                        self.inner
393                            .layer_sizes_in_bytes(&config, dtype, weight_pack_factor)?;
394                    let non_mapped_size_in_bytes =
395                        self.inner
396                            .non_mapped_size_in_bytes(&config, dtype, weight_pack_factor)?;
397                    let layer_sizes_sum = layer_sizes_in_bytes.iter().sum::<usize>();
398                    (
399                        layer_sizes_in_bytes,
400                        non_mapped_size_in_bytes,
401                        layer_sizes_sum + non_mapped_size_in_bytes,
402                    )
403                } else if let Some(isq) = in_situ_quant {
404                    let weight_pack_factor = isq.pack_factor(dtype);
405                    let layer_sizes_in_bytes =
406                        self.inner
407                            .layer_sizes_in_bytes(&config, dtype, weight_pack_factor)?;
408                    let non_mapped_size_in_bytes =
409                        self.inner
410                            .non_mapped_size_in_bytes(&config, dtype, weight_pack_factor)?;
411                    let layer_sizes_sum = layer_sizes_in_bytes.iter().sum::<usize>();
412                    (
413                        layer_sizes_in_bytes,
414                        non_mapped_size_in_bytes,
415                        layer_sizes_sum + non_mapped_size_in_bytes,
416                    )
417                } else {
418                    // Be sure to get the weight pack factor here; we might be loading a prequantized model.
419                    let weight_pack_factor =
420                        QuantizationConfigShim::get_quant_config_pack_factor(&config, dtype)?;
421                    let layer_sizes_in_bytes =
422                        self.inner
423                            .layer_sizes_in_bytes(&config, dtype, weight_pack_factor)?;
424                    let non_mapped_size_in_bytes =
425                        self.inner
426                            .non_mapped_size_in_bytes(&config, dtype, weight_pack_factor)?;
427                    let layer_sizes_sum = layer_sizes_in_bytes.iter().sum::<usize>();
428                    (
429                        layer_sizes_in_bytes,
430                        non_mapped_size_in_bytes,
431                        layer_sizes_sum + non_mapped_size_in_bytes,
432                    )
433                };
434
435            let new = self.inner.get_device_layers(
436                &config,
437                self.inner.num_layers(&config)?,
438                layer_sizes_in_bytes,
439                non_mapped_size_in_bytes,
440                total_model_size_in_bytes,
441                &available_devices,
442                dtype,
443                &params,
444                prompt_chunksize,
445                paged_attn_config.as_ref(),
446            )?;
447            mapper = DeviceMapSetting::Map(new);
448        }
449
450        let pipeline_mapper = mapper.into_mapper(
451            self.inner.num_layers(&config)?,
452            &device,
453            self.config.topology.as_ref(),
454        )?;
455        let mapper = mapper.into_mapper(
456            self.inner.num_layers(&config)?,
457            &device,
458            self.config.topology.as_ref(),
459        )?;
460        let mut layer_devices = Vec::new();
461        for layer in 0..self.inner.num_layers(&config)? {
462            let device = mapper.device_for(layer, false).cloned();
463            layer_devices.push(device);
464        }
465        let dtype = mapper.get_min_dtype(dtype)?;
466
467        // TODO: PagedAttention is not supported with CPU for now.
468        // This check is not really necessary because `get_device_layers` should prevent it.
469        let mapping_uses_cpu = mapper.get_unique_devices().iter().any(Device::is_cpu);
470        if mapping_uses_cpu {
471            warn!("Device mapping contains a mix of GPU and CPU. There is no CPU support for PagedAttention, disabling PagedAttention.");
472            paged_attn_config = None;
473        }
474
475        info!("Model config: {:?}", self.inner.get_config_repr(&config)?);
476        if crate::using_flash_attn() {
477            once_log_info("FlashAttention is enabled.");
478        }
479
480        // Logic for ISQ here: if no calibration (i.e imatrix), then allow immediate ISQ. Otherwise, back to normal.
481        let mut loading_isq = if self.config.imatrix.is_none()
482            && self.config.calibration_file.is_none()
483            && !device.is_cuda()
484            && self.config.write_uqff.is_none()
485            && in_situ_quant.is_some()
486        {
487            let predicates = if matches!(self.config.organization, IsqOrganization::MoeExpertsOnly)
488            {
489                self.inner.immediate_isq_predicates_moqe(&config)?
490            } else {
491                self.inner.immediate_isq_predicates(&config)?
492            };
493            info!("Applying ISQ to {in_situ_quant:?}");
494            if predicates.is_empty() {
495                warn!("No predicates for this model and ISQ setting detected. ISQ will not be applied to any weights!");
496            }
497            mistralrs_quant::set_immediate_isq(in_situ_quant, predicates);
498            false
499        } else {
500            in_situ_quant.is_some()
501        };
502
503        if let Some(ref topology) = self.config.topology {
504            loading_isq |= topology
505                .0
506                .iter()
507                .any(|layer| layer.as_ref().is_some_and(|layer| layer.isq.is_some()));
508        }
509
510        if self.config.imatrix.is_some() && self.config.calibration_file.is_some() {
511            anyhow::bail!(
512                "`imatrix` and `calibration_file` were both specified, this is not allowed."
513            );
514        }
515
516        // Load onto the regular device if not using isq or if the calibration file is specified
517        let load_device = if !loading_isq || self.config.calibration_file.is_some() {
518            loading_isq = false;
519            device.clone()
520        } else {
521            Device::Cpu
522        };
523
524        let is_xlora = self.kind.is_adapted_and(|a| a.is_x_lora());
525
526        let attention_mechanism = if paged_attn_config.is_some() {
527            AttentionImplementation::PagedAttention
528        } else {
529            AttentionImplementation::Eager
530        };
531
532        let multi_progress = Arc::new(MultiProgress::new());
533
534        let mut model = if use_nccl {
535            let (mapper, sharded_vb) = distributed::prepare_distributed_mapper(
536                dtype,
537                &device,
538                &available_devices,
539                silent,
540                &config,
541                loading_isq,
542                self.config.from_uqff.is_some(),
543                self.config.organization,
544                &*self.inner,
545                paths.as_ref(),
546            )?;
547
548            // Special case for where things can be more optimially loaded.
549            match self.kind {
550                ModelKind::Normal => normal_model_loader_sharded!(
551                    sharded_vb,
552                    config,
553                    self.inner,
554                    mapper,
555                    loading_isq,
556                    device.clone(),
557                    attention_mechanism,
558                    multi_progress.clone(),
559                ),
560                ModelKind::Adapter {
561                    adapter: AdapterKind::XLora,
562                } => xlora_model_loader!(
563                    paths,
564                    Some(dtype),
565                    &load_device,
566                    layer_devices.clone(),
567                    config,
568                    self.inner,
569                    silent,
570                    mapper,
571                    loading_isq,
572                    device.clone(),
573                    multi_progress.clone(),
574                ),
575                ModelKind::Adapter {
576                    adapter: AdapterKind::Lora,
577                } => lora_model_loader!(
578                    paths,
579                    Some(dtype),
580                    &load_device,
581                    layer_devices.clone(),
582                    config,
583                    self.inner,
584                    silent,
585                    mapper,
586                    loading_isq,
587                    self.config.from_uqff.is_some(),
588                    device.clone(),
589                    attention_mechanism,
590                    matches!(self.config.organization, IsqOrganization::MoeExpertsOnly),
591                    multi_progress.clone(),
592                ),
593                _ => unreachable!(),
594            }
595        } else {
596            match self.kind {
597                ModelKind::Normal => normal_model_loader!(
598                    paths,
599                    Some(dtype),
600                    &load_device,
601                    layer_devices.clone(),
602                    config,
603                    self.inner,
604                    silent,
605                    mapper,
606                    loading_isq,
607                    self.config.from_uqff.is_some(),
608                    device.clone(),
609                    attention_mechanism,
610                    matches!(self.config.organization, IsqOrganization::MoeExpertsOnly),
611                    multi_progress.clone(),
612                ),
613                ModelKind::Adapter {
614                    adapter: AdapterKind::XLora,
615                } => xlora_model_loader!(
616                    paths,
617                    Some(dtype),
618                    &load_device,
619                    layer_devices.clone(),
620                    config,
621                    self.inner,
622                    silent,
623                    mapper,
624                    loading_isq,
625                    device.clone(),
626                    multi_progress.clone(),
627                ),
628                ModelKind::Adapter {
629                    adapter: AdapterKind::Lora,
630                } => lora_model_loader!(
631                    paths,
632                    Some(dtype),
633                    &load_device,
634                    layer_devices.clone(),
635                    config,
636                    self.inner,
637                    silent,
638                    mapper,
639                    loading_isq,
640                    self.config.from_uqff.is_some(),
641                    device.clone(),
642                    attention_mechanism,
643                    matches!(self.config.organization, IsqOrganization::MoeExpertsOnly),
644                    multi_progress.clone(),
645                ),
646                _ => unreachable!(),
647            }
648        };
649
650        let tokenizer = get_tokenizer(paths.get_tokenizer_filename(), None)?;
651        let gen_conf: Option<GenerationConfig> = paths.get_gen_conf_filename().map(|f| {
652            serde_json::from_str(&fs::read_to_string(f).unwrap())
653                .expect("bos_token_id/eos_token_id missing in generation_config.json")
654        });
655
656        let chat_template_explicit = paths
657            .get_chat_template_explicit()
658            .as_ref()
659            .map(|x| x.to_string_lossy().to_string());
660        let chat_template = get_chat_template(
661            paths,
662            self.jinja_explicit.as_ref(),
663            chat_template_explicit.as_ref(),
664            self.chat_template.as_ref(),
665            None,
666        );
667
668        if let Some(calibration_file) = &self.config.calibration_file {
669            let calibration_data = std::fs::read_to_string(calibration_file)?;
670            // Tokenize, don't add bos yet
671            let tokens = tokenizer
672                .encode_fast(calibration_data, false)
673                .map_err(anyhow::Error::msg)?
674                .get_ids()
675                .to_vec();
676            info!(
677                "Collecting imatrix from calibration file `{}` of {} tokens.",
678                calibration_file.display(),
679                tokens.len()
680            );
681            let bos_toks = chat_template.bos_tok().map(|b| vec![b]).unwrap_or_default();
682            let bos_tok_id = tokenizer
683                .token_to_id(&bos_toks[0])
684                .expect("Somehow the bos token is not present.");
685
686            match self.config.organization {
687                IsqOrganization::Default => model.begin_track_stats()?,
688                IsqOrganization::MoeExpertsOnly => model.begin_track_stats_moe_experts_only()?,
689            }
690
691            const CHUNK_SIZE: usize = 1024;
692            let n_chunks = tokens.len().div_ceil(CHUNK_SIZE);
693            let start = Instant::now();
694            for (i, chunk) in tokens.chunks(CHUNK_SIZE).enumerate() {
695                let chunk = [vec![bos_tok_id], chunk.to_vec()].concat();
696                let chunk_len = chunk.len();
697
698                let start = Instant::now();
699                let inputs = make_prompt_chunk(
700                    0,
701                    vec![&chunk],
702                    &[0],
703                    &load_device,
704                    None,
705                    false,
706                    None,
707                    Some(pipeline_mapper.as_ref()),
708                )?;
709
710                model.forward(
711                    &inputs.input.to_device(model.device())?,
712                    &inputs.positions,
713                    inputs.context_lens.clone(),
714                    inputs.position_ids.clone(),
715                    None,
716                    &inputs.flash_meta.clone(),
717                )?;
718
719                match model.cache_mut() {
720                    EitherCache::Full(full) => {
721                        for layer in &mut *full.lock() {
722                            *layer = None
723                        }
724                    }
725                    EitherCache::Normal(normal) => {
726                        for layer in &mut *normal.lock().unwrap().0 {
727                            layer.reset();
728                        }
729                    }
730                }
731
732                let end = Instant::now();
733                info!(
734                    "Processed chunk {}/{n_chunks} ({chunk_len} tokens), {:.2}s",
735                    i + 1,
736                    end.duration_since(start).as_secs_f32()
737                );
738            }
739            load_device.synchronize()?;
740            let end = Instant::now();
741            info!(
742                "Finished collecting imatrix in {:.2}s",
743                end.duration_since(start).as_secs_f32()
744            );
745        }
746
747        // Only if loading from UQFF
748        if (loading_isq || self.config.topology.is_some()) && self.config.from_uqff.is_none() {
749            let imatrix_source = match (
750                self.config.imatrix.as_ref(),
751                self.config.calibration_file.is_some(),
752            ) {
753                (None, false) => None,
754                (Some(file), false) => Some(ImatrixDataSource::File(file)),
755                (None, true) => Some(ImatrixDataSource::Collected),
756                (Some(_), true) => unreachable!(),
757            };
758
759            info!("Applying ISQ to all ranks.");
760
761            let multi_progress = Arc::new(MultiProgress::new());
762
763            model.quantize(
764                in_situ_quant,
765                model.device().clone(),
766                self.config.topology.as_ref(),
767                silent,
768                imatrix_source,
769                self.config.organization,
770                self.config.write_uqff.as_ref(),
771                UqffFullSer {
772                    tokenizer: &tokenizer,
773                    template_filename: paths.get_template_filename(),
774                    generation_config: paths.get_gen_conf_filename(),
775                    config: config.clone(),
776                    processor_filename: &None,
777                    preprocessor_filename: &None,
778                },
779                multi_progress.clone(),
780            )?;
781        } else if let Some(from_uqff) = &*self.from_uqff.read().unwrap() {
782            model.load_from_artifacts(
783                device.clone(),
784                self.config.topology.as_ref(),
785                silent,
786                from_uqff,
787            )?;
788        }
789
790        let paged_attn_config = if matches!(
791            self.kind,
792            ModelKind::Adapter {
793                adapter: AdapterKind::XLora
794            }
795        ) {
796            warn!(
797                "Adapter parallel_models do not currently support PagedAttention, running without"
798            );
799            None
800        } else {
801            paged_attn_config
802        };
803
804        let (cache_config, cache_engine) = if let Some(paged_attn_config) = paged_attn_config {
805            let cache_config = calculate_cache_config(
806                paged_attn_config.mem_gpu,
807                paged_attn_config.mem_cpu,
808                paged_attn_config.block_size,
809                dtype,
810                model.config(),
811                &device,
812                &pipeline_mapper
813                    .get_unique_devices()
814                    .into_iter()
815                    .map(Some)
816                    .collect::<Vec<_>>(),
817                silent,
818            )?;
819
820            let mut layer_devices = Vec::new();
821            for layer in 0..self.inner.num_layers(&config)? {
822                let device = model.get_layers().1.device_for(layer, false).cloned();
823                layer_devices.push(device);
824            }
825            let cache_engine = CacheEngine::new(
826                model.config(),
827                &cache_config,
828                dtype,
829                model.device(),
830                layer_devices.clone(),
831            )?;
832
833            (Some(cache_config), Some(cache_engine))
834        } else {
835            (None, None)
836        };
837
838        let max_seq_len = model.max_seq_len();
839        let llg_factory = build_llg_factory(tokenizer.clone())?;
840        let num_hidden_layers = match model.cache() {
841            EitherCache::Full(full) => full.lock().len(),
842            EitherCache::Normal(normal) => normal.lock().unwrap().0.len(),
843        };
844        let eos = calculate_eos_tokens(&chat_template, gen_conf, &tokenizer);
845        let sliding_window = model.config().sliding_window;
846        let model_metadata = Arc::new(model.config().clone());
847
848        Ok(Arc::new(Mutex::new(NormalPipeline {
849            model,
850            tokenizer: tokenizer.into(),
851            no_kv_cache: self.no_kv_cache,
852            chat_template: Arc::new(chat_template),
853            non_granular_state: self.tgt_non_granular_index.map(|tgt_non_granular_index| {
854                NonGranularState {
855                    non_granular_index: Arc::new(Mutex::new(0)),
856                    tgt_non_granular_index,
857                }
858            }),
859            model_id: self.model_id.clone(),
860            metadata: Arc::new(GeneralMetadata {
861                max_seq_len,
862                llg_factory: Some(llg_factory),
863                no_kv_cache: self.no_kv_cache,
864                no_prefix_cache: is_xlora,
865                num_hidden_layers,
866                eos_tok: eos,
867                kind: self.kind.clone(),
868                is_xlora,
869                activation_dtype: dtype,
870                sliding_window,
871                cache_config,
872                cache_engine,
873                prompt_chunksize: Some(NonZero::new(prompt_chunksize).unwrap()),
874                model_metadata: Some(model_metadata),
875            }),
876            topology: self.config.topology.clone(),
877            silent,
878            organization: self.config.organization,
879            template_filename: paths.get_template_filename().clone(),
880            generation_config: paths.get_gen_conf_filename().cloned(),
881            config,
882            imatrix: self.config.imatrix.clone(),
883            mapper: pipeline_mapper,
884        })))
885    }
886
887    fn get_id(&self) -> String {
888        self.model_id.clone()
889    }
890
891    fn get_kind(&self) -> ModelKind {
892        self.kind.clone()
893    }
894}
895
896impl PreProcessingMixin for NormalPipeline {
897    fn get_chat_template(&self) -> Option<Arc<ChatTemplate>> {
898        Some(self.chat_template.clone())
899    }
900    fn get_input_processor_config(&self) -> Option<Arc<dyn Any>> {
901        None
902    }
903}
904
905impl IsqPipelineMixin for NormalPipeline {
906    fn re_isq_model(&mut self, dtype: IsqType) -> Result<()> {
907        let device = self.device().clone();
908        let multi_progress = Arc::new(MultiProgress::new());
909        self.model.quantize(
910            Some(dtype),
911            device.clone(),
912            self.topology.as_ref(),
913            self.silent,
914            self.imatrix.as_ref().map(ImatrixDataSource::File),
915            self.organization,
916            None,
917            UqffFullSer {
918                tokenizer: &self.tokenizer,
919                template_filename: &self.template_filename,
920                generation_config: self.generation_config.as_ref(),
921                config: self.config.clone(),
922                processor_filename: &None,
923                preprocessor_filename: &None,
924            },
925            multi_progress.clone(),
926        )?;
927        Ok(())
928    }
929}
930
931impl CacheManagerMixin for NormalPipeline {
932    fn clone_in_cache(&self, seqs: &mut [&mut Sequence]) {
933        if matches!(self.model.cache(), EitherCache::Full(_)) {
934            FullCacheManager.clone_in_cache(self, seqs, false)
935        } else {
936            NormalCacheManager.clone_in_cache(self, seqs, false)
937        }
938    }
939    fn clone_out_cache(&self, seqs: &mut [&mut Sequence]) {
940        if matches!(self.model.cache(), EitherCache::Full(_)) {
941            FullCacheManager.clone_out_cache(self, seqs, false)
942        } else {
943            NormalCacheManager.clone_out_cache(self, seqs, false)
944        }
945    }
946    fn set_none_cache(
947        &self,
948        seqs: &mut [&mut Sequence],
949        reset_non_granular: bool,
950        modify_draft_cache: bool,
951        load_preallocated_cache: bool,
952    ) {
953        if matches!(self.model.cache(), EitherCache::Full(_)) {
954            FullCacheManager.set_none_cache(self, seqs, modify_draft_cache, false);
955        } else {
956            NormalCacheManager.set_none_cache(
957                self,
958                seqs,
959                modify_draft_cache,
960                load_preallocated_cache,
961            );
962        }
963        if reset_non_granular {
964            self.reset_non_granular_state()
965        }
966    }
967    fn cache(&self) -> &EitherCache {
968        self.model.cache()
969    }
970}
971
972impl MetadataMixin for NormalPipeline {
973    fn device(&self) -> Device {
974        self.model.device().clone()
975    }
976    fn tokenizer(&self) -> Option<Arc<Tokenizer>> {
977        Some(self.tokenizer.clone())
978    }
979    fn name(&self) -> String {
980        self.model_id.clone()
981    }
982    fn reset_non_granular_state(&self) {
983        if let Some(s) = self.non_granular_state.as_ref() {
984            *self.cache().full().get_scalings_cache() = None;
985            *get_mut_arcmutex!(s.non_granular_index) = 0;
986        }
987    }
988    fn get_metadata(&self) -> Arc<GeneralMetadata> {
989        self.metadata.clone()
990    }
991    fn device_mapper(&self) -> Option<&dyn DeviceMapper> {
992        Some(&*self.mapper)
993    }
994}
995
996#[async_trait::async_trait]
997impl Pipeline for NormalPipeline {
998    fn forward_inputs(
999        &mut self,
1000        inputs: Box<dyn Any>,
1001        return_raw_logits: bool,
1002    ) -> Result<ForwardInputsResult, candle_core::Error> {
1003        let ModelInputs {
1004            input_ids,
1005            input_ids_full,
1006            seqlen_offsets,
1007            seqlen_offsets_full,
1008            context_lens,
1009            position_ids,
1010            paged_attn_meta,
1011            flash_meta,
1012            flash_meta_full,
1013        } = *inputs.downcast().expect("Downcast failed.");
1014        let metadata = self.get_metadata();
1015        let paged_attn_meta = match (&metadata.cache_engine, &paged_attn_meta) {
1016            (Some(cache_engine), Some(meta)) => Some((cache_engine, meta)),
1017            (Some(_), None) => {
1018                // This can happen if Rust-side user code is wrong
1019                candle_core::bail!("Forward step expected a PagedAttention input metadata. This was not provided, please ensure that the scheduler config is correctly configured for PagedAttention.")
1020            }
1021            (None, Some(_)) => {
1022                // This should never happen but we handle it anyway
1023                candle_core::bail!("Forward step got a PagedAttention input metadata but there is no cache engine. Please raise an issue.")
1024            }
1025            (None, None) => None,
1026        };
1027        let logits = match self.model.is_xlora() {
1028            false => {
1029                let paged_attn_meta = paged_attn_meta
1030                    .as_ref()
1031                    .map(|meta| (meta.0.get_kv_cache().clone(), meta.1.clone()));
1032
1033                self.model.forward(
1034                    &input_ids,
1035                    &seqlen_offsets,
1036                    context_lens,
1037                    position_ids,
1038                    paged_attn_meta.as_ref().map(|(a, b)| (a.clone(), b)),
1039                    &flash_meta,
1040                )?
1041            }
1042            true => self.model.xlora_forward(
1043                &input_ids,
1044                input_ids_full.as_ref().unwrap_or(&input_ids),
1045                &seqlen_offsets,
1046                seqlen_offsets_full.as_ref().unwrap_or(&seqlen_offsets),
1047                self.no_kv_cache,
1048                &self.non_granular_state,
1049                context_lens,
1050                position_ids,
1051                &flash_meta,
1052                flash_meta_full.as_ref().unwrap_or(&flash_meta),
1053            )?,
1054        };
1055        if return_raw_logits {
1056            Ok(ForwardInputsResult::RawLogits { logits })
1057        } else {
1058            Ok(ForwardInputsResult::CausalGeneration { logits })
1059        }
1060    }
1061    async fn sample_causal_gen(
1062        &self,
1063        seqs: &mut [&mut Sequence],
1064        logits: Vec<Tensor>,
1065        prefix_cacher: &mut PrefixCacheManagerV2,
1066        disable_eos_stop: bool,
1067        rng: Arc<std::sync::Mutex<Isaac64Rng>>,
1068    ) -> Result<(), candle_core::Error> {
1069        sample_and_add_toks(self, seqs, logits, prefix_cacher, disable_eos_stop, rng).await
1070    }
1071    fn category(&self) -> ModelCategory {
1072        ModelCategory::Text
1073    }
1074}
1075
1076impl AnyMoePipelineMixin for NormalPipeline {
1077    fn amoe_finish_training(&mut self, gate_model_id: Option<String>) -> candle_core::Result<()> {
1078        self.model.finish_training(gate_model_id)
1079    }
1080    fn amoe_layer_vars(&self) -> Vec<Vec<Var>> {
1081        self.model.get_vars()
1082    }
1083    fn amoe_base_model_trainable_params(&self) -> usize {
1084        self.model.trainable_params()
1085    }
1086    fn amoe_take_cached_gating_outputs(&mut self) -> Vec<Tensor> {
1087        self.model.take_cached_gating_outputs()
1088    }
1089    fn amoe_create_layers(
1090        &mut self,
1091        model_ids: Vec<String>,
1092        token: &TokenSource,
1093        revision: Option<String>,
1094        match_regex: &str,
1095        config: crate::amoe::AnyMoeConfig,
1096        dtype: candle_core::DType,
1097        dev: &Device,
1098        (prefix, mlp): (String, String),
1099        layers: Vec<usize>,
1100        expert_type: AnyMoeExpertType,
1101        silent: bool,
1102        gate_model_id: Option<String>,
1103    ) -> candle_core::Result<()> {
1104        let mut vbs = Vec::new();
1105        // Precompile regex here
1106        let regex = Regex::new(match_regex).map_err(candle_core::Error::msg)?;
1107        for model_id in model_ids {
1108            let model_id_str = &model_id;
1109            let model_id = Path::new(&model_id);
1110
1111            let api = {
1112                let cache = GLOBAL_HF_CACHE.get().cloned().unwrap_or_default();
1113                let mut api = ApiBuilder::from_cache(cache)
1114                    .with_progress(!silent)
1115                    .with_token(get_token(token).map_err(candle_core::Error::msg)?);
1116                if let Ok(x) = std::env::var("HF_HUB_CACHE") {
1117                    api = api.with_cache_dir(x.into());
1118                }
1119                api.build().map_err(candle_core::Error::msg)?
1120            };
1121            let revision = revision.clone().unwrap_or("main".to_string());
1122            let api = api.repo(Repo::with_revision(
1123                model_id_str.clone(),
1124                RepoType::Model,
1125                revision.clone(),
1126            ));
1127
1128            let mut filenames = vec![];
1129            for rfilename in api_dir_list!(api, model_id).filter(|x| x.ends_with(".safetensors")) {
1130                filenames.push(api_get_file!(api, &rfilename, model_id));
1131            }
1132
1133            let regex = regex.clone();
1134            let match_regex_clone = match_regex.to_string();
1135            let layers_clone = layers.clone();
1136            let vb = from_mmaped_safetensors(
1137                filenames,
1138                vec![],
1139                Some(dtype),
1140                dev,
1141                vec![None],
1142                silent,
1143                None,
1144                move |key| {
1145                    if regex.is_match(&key) {
1146                        // Idx of the last char of the layer id, +1
1147                        // Assumes N.MLP
1148                        let last_layer_idx = key.find(&match_regex_clone).unwrap() - 1;
1149                        let first_layer_idx = key[..last_layer_idx].rfind('.').unwrap();
1150                        let layer_n = key[first_layer_idx + 1..last_layer_idx]
1151                            .parse::<usize>()
1152                            .unwrap();
1153                        layers_clone.contains(&layer_n) || layers_clone.is_empty()
1154                    } else {
1155                        false
1156                    }
1157                },
1158                Arc::new(|_| DeviceForLoadTensor::Base),
1159            )?;
1160            vbs.push(vb);
1161        }
1162
1163        let gate_vb = if let Some(gate_model_id) = gate_model_id {
1164            let model_id_str = &gate_model_id;
1165            let model_id = Path::new(&gate_model_id);
1166
1167            let api = {
1168                let cache = GLOBAL_HF_CACHE.get().cloned().unwrap_or_default();
1169                let mut api = ApiBuilder::from_cache(cache)
1170                    .with_progress(!silent)
1171                    .with_token(get_token(token).map_err(candle_core::Error::msg)?);
1172                if let Ok(x) = std::env::var("HF_HUB_CACHE") {
1173                    api = api.with_cache_dir(x.into());
1174                }
1175                api.build().map_err(candle_core::Error::msg)?
1176            };
1177            let revision = revision.clone().unwrap_or("main".to_string());
1178            let api = api.repo(Repo::with_revision(
1179                model_id_str.clone(),
1180                RepoType::Model,
1181                revision.clone(),
1182            ));
1183
1184            let mut gate_filenames = vec![];
1185            for rfilename in api_dir_list!(api, model_id).filter(|x| x.ends_with(".safetensors")) {
1186                gate_filenames.push(api_get_file!(api, &rfilename, model_id));
1187            }
1188            assert_eq!(
1189                gate_filenames.len(),
1190                1,
1191                "Gate model ID must contain only one .safetensors file"
1192            );
1193
1194            let vb = from_mmaped_safetensors(
1195                gate_filenames.clone(),
1196                vec![],
1197                Some(dtype),
1198                dev,
1199                vec![None],
1200                silent,
1201                None,
1202                |_| true,
1203                Arc::new(|_| DeviceForLoadTensor::Base),
1204            )?;
1205            info!(
1206                "Loaded gating layers from `{}`",
1207                gate_filenames[0].display()
1208            );
1209            Some(vb)
1210        } else {
1211            None
1212        };
1213
1214        self.model.create_anymoe_layers(
1215            vbs.clone(),
1216            config.clone(),
1217            (prefix.clone(), mlp.clone()),
1218            layers.clone(),
1219            expert_type.clone(),
1220            gate_vb.clone(),
1221        )?;
1222
1223        Ok(())
1224    }
1225    fn amoe_supported(&self) -> bool {
1226        self.model.amoe_supported()
1227    }
1228}