mistralrs_core/pipeline/
normal.rs

1use super::isq::ImatrixDataSource;
2use super::llg::build_llg_factory;
3use super::{
4    get_model_paths, get_xlora_paths, text_models_inputs_processor::ModelInputs, AdapterKind,
5    CacheManager, GeneralMetadata, Loader, ModelKind, ModelPaths, NormalModel, NormalModelLoader,
6    TokenSource,
7};
8use super::{
9    AnyMoePipelineMixin, CacheManagerMixin, EitherCache, ForwardInputsResult, IsqOrganization,
10    IsqPipelineMixin, MetadataMixin, ModelCategory, PreProcessingMixin,
11};
12use super::{
13    AutoNormalLoader, DeepSeekV2Loader, DeepSeekV3Loader, GLM4Loader, Gemma2Loader, GemmaLoader,
14    LlamaLoader, MistralLoader, MixtralLoader, NormalLoaderType, Phi2Loader, Phi3Loader,
15    Phi3_5MoELoader, Qwen2Loader, Qwen3Loader, Qwen3MoELoader, SmolLm3Loader, Starcoder2Loader,
16};
17use crate::amoe::AnyMoeExpertType;
18use crate::attention::ATTENTION_CHUNK_SIZE;
19use crate::device_map::{self, DeviceMapper};
20use crate::distributed::{self, WorkerTransferData};
21use crate::kv_cache::{FullCacheManager, NormalCacheManager};
22use crate::lora::Ordering;
23use crate::paged_attention::{calculate_cache_config, AttentionImplementation, CacheEngine};
24use crate::pipeline::chat_template::{calculate_eos_tokens, GenerationConfig};
25use crate::pipeline::isq::UqffFullSer;
26use crate::pipeline::loaders::auto_device_map;
27use crate::pipeline::loaders::QuantizationConfigShim;
28use crate::pipeline::sampling::sample_and_add_toks;
29use crate::pipeline::text_models_inputs_processor::make_prompt_chunk;
30use crate::pipeline::{get_chat_template, Modalities, SupportedModality};
31use crate::pipeline::{ChatTemplate, LocalModelPaths};
32use crate::prefix_cacher::PrefixCacheManagerV2;
33use crate::sequence::Sequence;
34use crate::utils::tokenizer::get_tokenizer;
35use crate::utils::varbuilder_utils::DeviceForLoadTensor;
36use crate::utils::{tokens::get_token, varbuilder_utils::from_mmaped_safetensors};
37use crate::xlora_models::NonGranularState;
38use crate::{
39    api_dir_list, api_get_file, get_mut_arcmutex, get_paths, get_uqff_paths, lora_model_loader,
40    normal_model_loader, normal_model_loader_sharded, xlora_model_loader, DeviceMapSetting,
41    PagedAttentionConfig, Pipeline, Topology, TryIntoDType, GLOBAL_HF_CACHE,
42};
43use anyhow::Result;
44use candle_core::{Device, Tensor, Var};
45use hf_hub::Cache;
46use hf_hub::{api::sync::ApiBuilder, Repo, RepoType};
47use indicatif::MultiProgress;
48use mistralrs_quant::log::once_log_info;
49use mistralrs_quant::{
50    AfqLayer, GgufMatMul, HqqLayer, ImmediateIsqOverride, IsqType, QuantizedSerdeType,
51};
52use rand_isaac::Isaac64Rng;
53use regex_automata::meta::Regex;
54use std::any::Any;
55use std::borrow::Cow;
56use std::path::{Path, PathBuf};
57use std::str::FromStr;
58use std::sync::{Arc, RwLock};
59use std::time::Instant;
60use std::{env, fs};
61use tokenizers::Tokenizer;
62use tokio::sync::Mutex;
63use tracing::{info, warn};
64
/// Runtime state of a loaded "normal" model pipeline.
///
/// Bundles the boxed model with its tokenizer, chat template, shared metadata,
/// and the file paths / raw config text retained for full UQFF serialization.
65pub struct NormalPipeline {
    // The loaded model, held as a `NormalModel` trait object.
66    model: Box<dyn NormalModel + Send + Sync>,
67    tokenizer: Arc<Tokenizer>,
68    no_kv_cache: bool,
69    chat_template: Arc<ChatTemplate>,
    // NOTE(review): presumably only populated for X-LoRA models — confirm at the construction site.
70    non_granular_state: Option<NonGranularState>,
71    model_id: String,
72    metadata: Arc<GeneralMetadata>,
73    topology: Option<Topology>,
74    silent: bool,
75    organization: IsqOrganization,
76    // For full UQFF serialization
77    template_filename: Option<PathBuf>,
78    generation_config: Option<PathBuf>,
    // NOTE(review): looks like the raw model config text (the loader reads it as a string) — confirm.
79    config: String,
80    imatrix: Option<PathBuf>,
81    mapper: Box<dyn DeviceMapper + Send + Sync>,
82}
83
84/// A loader for a "normal" (non-quantized) model.
///
/// Constructed by `NormalLoaderBuilder::build`. The `RwLock`-wrapped fields start
/// out as `None` and are filled in by `load_model_from_hf`.
85pub struct NormalLoader {
    // Architecture-specific loader selected in `NormalLoaderBuilder::build`.
86    inner: Box<dyn NormalModelLoader>,
87    model_id: String,
88    config: NormalSpecificConfig,
89    xlora_model_id: Option<String>,
90    lora_adapter_ids: Option<Vec<String>>,
    // Plain model or adapter (X-LoRA / LoRA); selects which model-loading macro is used.
91    kind: ModelKind,
92    xlora_order: Option<Ordering>,
93    no_kv_cache: bool,
94    chat_template: Option<String>,
95    tokenizer_json: Option<String>,
96    tgt_non_granular_index: Option<usize>,
    // Recorded by `load_model_from_hf` from its `token_source` argument.
97    token_source: RwLock<Option<TokenSource>>,
    // Recorded by `load_model_from_hf` from its `revision` argument.
98    revision: RwLock<Option<String>>,
    // Resolved UQFF paths; written by `load_model_from_hf` when `config.from_uqff` is set.
99    from_uqff: RwLock<Option<Vec<PathBuf>>>,
100    jinja_explicit: Option<String>,
    // Optional Hugging Face cache directory; when `None` the default `Cache` is used.
101    hf_cache_path: Option<PathBuf>,
102}
103
104#[derive(Default)]
105/// A builder for a loader for a "normal" (non-quantized) model.
///
/// Start with `NormalLoaderBuilder::new`, optionally chain `with_xlora`,
/// `with_lora` and/or `hf_cache_path`, then call `build`.
106pub struct NormalLoaderBuilder {
    // May be `None`; `with_xlora` fills it in from the ordering's `base_model_id` when absent.
107    model_id: Option<String>,
108    config: NormalSpecificConfig,
109    xlora_model_id: Option<String>,
110    lora_adapter_ids: Option<Vec<String>>,
    // `ModelKind::Normal` after `new`; switched to an adapter kind by `with_xlora` / `with_lora`.
111    kind: ModelKind,
112    xlora_order: Option<Ordering>,
113    no_kv_cache: bool,
114    chat_template: Option<String>,
115    tokenizer_json: Option<String>,
116    tgt_non_granular_index: Option<usize>,
117    jinja_explicit: Option<String>,
    // Set through the `hf_cache_path` builder method.
118    hf_cache_path: Option<PathBuf>,
119}
120
121#[derive(Clone, Default)]
122/// Config specific to loading a normal model.
123pub struct NormalSpecificConfig {
    /// Optional topology; passed to the device mapper and used as the source of
    /// per-pattern ISQ/device overrides during loading.
124    pub topology: Option<Topology>,
    /// Chooses between the default ISQ handling and the MoE-experts-only variant.
125    pub organization: IsqOrganization,
    /// NOTE(review): presumably the output path for UQFF serialization — its use is
    /// outside the visible portion of this file; confirm.
126    pub write_uqff: Option<PathBuf>,
    /// UQFF files to load from; when set, the paths are resolved in `load_model_from_hf`.
127    pub from_uqff: Option<Vec<PathBuf>>,
    /// Imatrix file. Mutually exclusive with `calibration_file` (loading bails if both are set).
128    pub imatrix: Option<PathBuf>,
    /// Text file that is tokenized and run through the model to collect an imatrix.
    /// Mutually exclusive with `imatrix`.
129    pub calibration_file: Option<PathBuf>,
    /// NOTE(review): the visible loader code reads the cache directory from
    /// `NormalLoader::hf_cache_path`, not from this field — confirm where this one is consumed.
130    pub hf_cache_path: Option<PathBuf>,
    /// Path to a Matformer config file, loaded with `MatformerConfig::from_file`.
131    pub matformer_config_path: Option<PathBuf>,
    /// Name of the Matformer slice to use. If a config path is given without a slice
    /// name, a warning is logged and no slicing config is passed to the model.
132    pub matformer_slice_name: Option<String>,
133}
134
135impl NormalLoaderBuilder {
136    pub fn new(
137        config: NormalSpecificConfig,
138        chat_template: Option<String>,
139        tokenizer_json: Option<String>,
140        model_id: Option<String>,
141        no_kv_cache: bool,
142        jinja_explicit: Option<String>,
143    ) -> Self {
144        Self {
145            config,
146            chat_template,
147            tokenizer_json,
148            model_id,
149            kind: ModelKind::Normal,
150            jinja_explicit,
151            no_kv_cache,
152            ..Default::default()
153        }
154    }
155
156    fn with_adapter(
157        mut self,
158        xlora_model_id: String,
159        xlora_order: Ordering,
160        no_kv_cache: bool,
161        tgt_non_granular_index: Option<usize>,
162    ) -> Self {
163        self.xlora_model_id = Some(xlora_model_id);
164        self.xlora_order = Some(xlora_order);
165        self.no_kv_cache = no_kv_cache;
166        self.tgt_non_granular_index = tgt_non_granular_index;
167        self.model_id = if let Some(id) = self.model_id {
168            Some(id)
169        } else {
170            info!(
171                "Using adapter base model ID: `{}`",
172                self.xlora_order.as_ref().unwrap().base_model_id
173            );
174            Some(self.xlora_order.as_ref().unwrap().base_model_id.clone())
175        };
176        self
177    }
178
179    pub fn with_xlora(
180        mut self,
181        xlora_model_id: String,
182        xlora_order: Ordering,
183        no_kv_cache: bool,
184        tgt_non_granular_index: Option<usize>,
185    ) -> Self {
186        self.kind = ModelKind::Adapter {
187            adapter: AdapterKind::XLora,
188        };
189        self.with_adapter(
190            xlora_model_id,
191            xlora_order,
192            no_kv_cache,
193            tgt_non_granular_index,
194        )
195    }
196
197    pub fn with_lora(mut self, lora_adapter_ids: Vec<String>) -> Self {
198        self.kind = ModelKind::Adapter {
199            adapter: AdapterKind::Lora,
200        };
201        self.lora_adapter_ids = Some(lora_adapter_ids);
202        self
203    }
204
205    pub fn hf_cache_path(mut self, hf_cache_path: PathBuf) -> Self {
206        self.hf_cache_path = Some(hf_cache_path);
207        self
208    }
209
210    /// If the loader type is not specified, loader type is automatically determined from the
211    /// `architectures` array in the config.
212    pub fn build(self, loader_tp: Option<NormalLoaderType>) -> anyhow::Result<Box<dyn Loader>> {
213        let loader: Box<dyn NormalModelLoader> = match loader_tp {
214            Some(NormalLoaderType::Mistral) => Box::new(MistralLoader),
215            Some(NormalLoaderType::Gemma) => Box::new(GemmaLoader),
216            Some(NormalLoaderType::Llama) => Box::new(LlamaLoader),
217            Some(NormalLoaderType::Mixtral) => Box::new(MixtralLoader),
218            Some(NormalLoaderType::Phi2) => Box::new(Phi2Loader),
219            Some(NormalLoaderType::Phi3) => Box::new(Phi3Loader),
220            Some(NormalLoaderType::Qwen2) => Box::new(Qwen2Loader),
221            Some(NormalLoaderType::Gemma2) => Box::new(Gemma2Loader),
222            Some(NormalLoaderType::Starcoder2) => Box::new(Starcoder2Loader),
223            Some(NormalLoaderType::Phi3_5MoE) => Box::new(Phi3_5MoELoader),
224            Some(NormalLoaderType::DeepSeekV2) => Box::new(DeepSeekV2Loader),
225            Some(NormalLoaderType::DeepSeekV3) => Box::new(DeepSeekV3Loader),
226            Some(NormalLoaderType::Qwen3) => Box::new(Qwen3Loader),
227            Some(NormalLoaderType::GLM4) => Box::new(GLM4Loader),
228            Some(NormalLoaderType::Qwen3Moe) => Box::new(Qwen3MoELoader),
229            Some(NormalLoaderType::SmolLm3) => Box::new(SmolLm3Loader),
230            None => Box::new(AutoNormalLoader),
231        };
232        Ok(Box::new(NormalLoader {
233            inner: loader,
234            model_id: self.model_id.unwrap(),
235            config: self.config,
236            xlora_model_id: self.xlora_model_id,
237            lora_adapter_ids: self.lora_adapter_ids,
238            kind: self.kind,
239            xlora_order: self.xlora_order,
240            no_kv_cache: self.no_kv_cache,
241            chat_template: self.chat_template,
242            tokenizer_json: self.tokenizer_json,
243            tgt_non_granular_index: self.tgt_non_granular_index,
244            jinja_explicit: self.jinja_explicit,
245            token_source: RwLock::new(None),
246            revision: RwLock::new(None),
247            from_uqff: RwLock::new(None),
248            hf_cache_path: self.hf_cache_path,
249        }))
250    }
251}
252
253impl Loader for NormalLoader {
254    #[allow(clippy::type_complexity, clippy::too_many_arguments)]
255    fn load_model_from_hf(
256        &self,
257        revision: Option<String>,
258        token_source: TokenSource,
259        dtype: &dyn TryIntoDType,
260        device: &Device,
261        silent: bool,
262        mapper: DeviceMapSetting,
263        in_situ_quant: Option<IsqType>,
264        paged_attn_config: Option<PagedAttentionConfig>,
265    ) -> Result<Arc<Mutex<dyn Pipeline + Send + Sync>>> {
266        let cache = self
267            .hf_cache_path
268            .clone()
269            .map(Cache::new)
270            .unwrap_or_default();
271        GLOBAL_HF_CACHE.get_or_init(|| cache);
272
273        let paths: anyhow::Result<Box<dyn ModelPaths>> = get_paths!(
274            LocalModelPaths,
275            &token_source,
276            revision.clone(),
277            self,
278            None,
279            None,
280            silent,
281            self.config.from_uqff.is_some()
282        );
283        if let Some(from_uqff) = self.config.from_uqff.clone() {
284            *self.from_uqff.write().unwrap() = Some(get_uqff_paths!(&from_uqff, self, silent));
285        }
286        *self
287            .token_source
288            .write()
289            .expect("Failed to write to token source") = Some(token_source);
290        *self.revision.write().expect("Failed to write to revision") = revision;
291        self.load_model_from_path(
292            &paths?,
293            dtype,
294            device,
295            silent,
296            mapper,
297            in_situ_quant,
298            paged_attn_config,
299        )
300    }
301
302    #[allow(clippy::type_complexity, clippy::too_many_arguments)]
303    fn load_model_from_path(
304        &self,
305        paths: &Box<dyn ModelPaths>,
306        dtype: &dyn TryIntoDType,
307        device: &Device,
308        silent: bool,
309        mut mapper: DeviceMapSetting,
310        mut in_situ_quant: Option<IsqType>,
311        mut paged_attn_config: Option<PagedAttentionConfig>,
312    ) -> Result<Arc<Mutex<dyn Pipeline + Send + Sync>>> {
313        let config = std::fs::read_to_string(paths.get_config_filename())?;
314
315        if !self.inner.supports_paged_attention(&config)? {
316            paged_attn_config = None;
317        }
318
319        info!("Prompt chunk size is {ATTENTION_CHUNK_SIZE}.");
320
321        let use_nccl = mistralrs_quant::distributed::use_nccl();
322
323        let available_devices = if let Ok(payload) = env::var(distributed::IS_DAEMON_FLAG) {
324            let payload: WorkerTransferData = serde_json::from_str(&payload)?;
325            let WorkerTransferData::Init { id: _, worker_rank } = payload;
326            vec![candle_core::Device::new_cuda(worker_rank + 1)?]
327        } else if use_nccl {
328            vec![candle_core::Device::new_cuda(0)?]
329        } else {
330            device_map::get_all_similar_devices(device)?
331        };
332        #[cfg(feature = "cuda")]
333        for device in &available_devices {
334            if let Device::Cuda(dev) = device {
335                unsafe { dev.disable_event_tracking() };
336            }
337        }
338        let device = if use_nccl || cfg!(feature = "ring") {
339            available_devices[0].clone()
340        } else {
341            device.clone()
342        };
343
344        // If auto, convert to Map if not using nccl
345        if use_nccl || cfg!(feature = "ring") {
346            mapper = DeviceMapSetting::DummyNccl {
347                nm_device: available_devices[0].clone(),
348            };
349        } else if let DeviceMapSetting::Auto(params) = mapper.clone() {
350            // Initial dtype
351            let dtype = dtype.try_into_dtype(&available_devices.iter().collect::<Vec<_>>())?;
352
353            // Disable ISQ if we are loading a prequantized model.
354            if QuantizationConfigShim::get_quant_config_pack_factor(&config, dtype)? != 1 {
355                in_situ_quant = None;
356            }
357
358            // ISQ or UQFF: quantized path
359            // Match logic below where UQFF has priority
360            let (layer_sizes_in_bytes, non_mapped_size_in_bytes, total_model_size_in_bytes) =
361                if let Some(serialized) = &*self.from_uqff.read().unwrap() {
362                    let weight_pack_factor = {
363                        let ser_artifacts = unsafe {
364                            candle_core::safetensors::MmapedSafetensors::multi(serialized)?
365                        };
366                        let mut total_pack_factors = 0;
367                        let total_tensors = ser_artifacts.tensors().len();
368                        for (_, artifact) in ser_artifacts.tensors() {
369                            let artifact = artifact.data();
370                            // NOTE(EricLBuehler): isq type is ALWAYS byte 4 (5th) of the tensor.
371                            let isq_type = artifact[mistralrs_quant::UQFF_QUANT_TYPE_OFFSET];
372                            let pack_factor = match QuantizedSerdeType::try_from(isq_type as usize)?
373                            {
374                                QuantizedSerdeType::Hqq => {
375                                    HqqLayer::get_isq_type_from_uqff(Cow::Borrowed(artifact))?
376                                        .pack_factor(dtype)
377                                }
378                                QuantizedSerdeType::Gguf => {
379                                    GgufMatMul::get_isq_type_from_uqff(Cow::Borrowed(artifact))?
380                                        .pack_factor(dtype)
381                                }
382                                QuantizedSerdeType::Fp8 => IsqType::F8E4M3.pack_factor(dtype),
383                                QuantizedSerdeType::Unquant => 1,
384                                QuantizedSerdeType::Afq => {
385                                    AfqLayer::get_isq_type_from_uqff(Cow::Borrowed(artifact))?
386                                        .pack_factor(dtype)
387                                }
388                            };
389                            total_pack_factors += pack_factor;
390                        }
391
392                        total_pack_factors / total_tensors
393                    };
394
395                    let layer_sizes_in_bytes = self.inner.layer_sizes_in_bytes(
396                        &config,
397                        dtype,
398                        weight_pack_factor,
399                        None,
400                    )?;
401                    let non_mapped_size_in_bytes = self.inner.non_mapped_size_in_bytes(
402                        &config,
403                        dtype,
404                        weight_pack_factor,
405                        None,
406                    )?;
407                    let layer_sizes_sum = layer_sizes_in_bytes.iter().sum::<usize>();
408                    (
409                        layer_sizes_in_bytes,
410                        non_mapped_size_in_bytes,
411                        layer_sizes_sum + non_mapped_size_in_bytes,
412                    )
413                } else if let Some(isq) = in_situ_quant {
414                    let weight_pack_factor = isq.pack_factor(dtype);
415                    let layer_sizes_in_bytes = self.inner.layer_sizes_in_bytes(
416                        &config,
417                        dtype,
418                        weight_pack_factor,
419                        None,
420                    )?;
421                    let non_mapped_size_in_bytes = self.inner.non_mapped_size_in_bytes(
422                        &config,
423                        dtype,
424                        weight_pack_factor,
425                        None,
426                    )?;
427                    let layer_sizes_sum = layer_sizes_in_bytes.iter().sum::<usize>();
428                    (
429                        layer_sizes_in_bytes,
430                        non_mapped_size_in_bytes,
431                        layer_sizes_sum + non_mapped_size_in_bytes,
432                    )
433                } else {
434                    // Be sure to get the weight pack factor here; we might be loading a prequantized model.
435                    let weight_pack_factor =
436                        QuantizationConfigShim::get_quant_config_pack_factor(&config, dtype)?;
437                    let layer_sizes_in_bytes = self.inner.layer_sizes_in_bytes(
438                        &config,
439                        dtype,
440                        weight_pack_factor,
441                        None,
442                    )?;
443                    let non_mapped_size_in_bytes = self.inner.non_mapped_size_in_bytes(
444                        &config,
445                        dtype,
446                        weight_pack_factor,
447                        None,
448                    )?;
449                    let layer_sizes_sum = layer_sizes_in_bytes.iter().sum::<usize>();
450                    (
451                        layer_sizes_in_bytes,
452                        non_mapped_size_in_bytes,
453                        layer_sizes_sum + non_mapped_size_in_bytes,
454                    )
455                };
456
457            let new = auto_device_map::get_device_layers(
458                &*self.inner,
459                &config,
460                self.inner.num_layers(&config)?,
461                layer_sizes_in_bytes,
462                non_mapped_size_in_bytes,
463                total_model_size_in_bytes,
464                &available_devices,
465                dtype,
466                &params,
467                paged_attn_config.as_ref(),
468            )?;
469            mapper = DeviceMapSetting::Map(new);
470        }
471
472        let pipeline_mapper = mapper.into_mapper(
473            self.inner.num_layers(&config)?,
474            &device,
475            self.config.topology.as_ref(),
476        )?;
477        let mapper = mapper.into_mapper(
478            self.inner.num_layers(&config)?,
479            &device,
480            self.config.topology.as_ref(),
481        )?;
482        let mut layer_devices = Vec::new();
483        for layer in 0..self.inner.num_layers(&config)? {
484            let device = mapper.device_for(layer, false).cloned();
485            layer_devices.push(device);
486        }
487        let dtype = mapper.get_min_dtype(dtype)?;
488
489        // TODO: PagedAttention is not supported with CPU for now.
490        // This check is not really necessary because `get_device_layers` should prevent it.
491        let mapping_uses_cpu = mapper.get_unique_devices().iter().any(Device::is_cpu);
492        if mapping_uses_cpu && paged_attn_config.is_some() {
493            warn!("Device mapping contains a mix of GPU and CPU. There is no CPU support for PagedAttention, disabling PagedAttention.");
494            paged_attn_config = None;
495        }
496
497        info!("Model config: {:?}", self.inner.get_config_repr(&config)?);
498        if crate::using_flash_attn() {
499            once_log_info("FlashAttention is enabled.");
500        }
501
502        let topology_overrides = self
503            .config
504            .topology
505            .as_ref()
506            .map(|topology| {
507                topology
508                    .pattern_overrides()
509                    .into_iter()
510                    .map(|(regex, layer)| ImmediateIsqOverride {
511                        predicate: regex,
512                        ty: layer.isq,
513                        device: layer.device.clone(),
514                    })
515                    .collect::<Vec<_>>()
516            })
517            .unwrap_or_default();
518        let has_override_isq = topology_overrides
519            .iter()
520            .any(|override_entry| override_entry.ty.is_some());
521        let topology_requires_post_quant = self
522            .config
523            .topology
524            .as_ref()
525            .is_some_and(|topology| topology.requires_post_quantization());
526
527        let allow_immediate_cli = self.config.imatrix.is_none()
528            && self.config.calibration_file.is_none()
529            && !device.is_cuda()
530            && in_situ_quant.is_some();
531
532        let mut immediate_ty = None;
533        let mut immediate_predicates = Vec::new();
534        if allow_immediate_cli {
535            immediate_ty = in_situ_quant;
536            immediate_predicates =
537                if matches!(self.config.organization, IsqOrganization::MoeExpertsOnly) {
538                    self.inner.immediate_isq_predicates_moqe(&config)?
539                } else {
540                    self.inner.immediate_isq_predicates(&config)?
541                };
542            info!("Applying ISQ to {in_situ_quant:?}");
543            if immediate_predicates.is_empty() {
544                warn!("No predicates for this model and ISQ setting detected. ISQ will not be applied to any weights!");
545            }
546        }
547
548        let use_immediate = allow_immediate_cli || has_override_isq;
549        if use_immediate {
550            mistralrs_quant::set_immediate_isq_with_overrides(
551                immediate_ty,
552                immediate_predicates.clone(),
553                topology_overrides.clone(),
554            );
555        }
556
557        // Logic for ISQ here: if no calibration (i.e imatrix), then allow immediate ISQ. Otherwise, back to normal.
558        let mut loading_isq = if use_immediate {
559            false
560        } else {
561            in_situ_quant.is_some()
562        };
563        if self.config.imatrix.is_some() || self.config.calibration_file.is_some() {
564            loading_isq = true;
565        }
566        loading_isq |= topology_requires_post_quant;
567
568        if self.config.imatrix.is_some() && self.config.calibration_file.is_some() {
569            anyhow::bail!(
570                "`imatrix` and `calibration_file` were both specified, this is not allowed."
571            );
572        }
573
574        // Load onto the regular device if not using isq or if the calibration file is specified
575        let load_device = if !loading_isq || self.config.calibration_file.is_some() {
576            loading_isq = false;
577            device.clone()
578        } else {
579            Device::Cpu
580        };
581
582        let is_xlora = self.kind.is_adapted_and(|a| a.is_x_lora());
583
584        let attention_mechanism = if paged_attn_config.is_some() {
585            AttentionImplementation::PagedAttention
586        } else {
587            AttentionImplementation::Eager
588        };
589
590        let multi_progress = Arc::new(MultiProgress::new());
591
592        // Load matformer slicing config if provided
593        let matformer_slicing_config = if let Some(matformer_path) =
594            &self.config.matformer_config_path
595        {
596            use crate::matformer::{MatformerConfig, MatformerSliceConfig};
597            info!("Loading Matformer config from {:?}", matformer_path);
598            let config = Arc::new(MatformerConfig::from_file(matformer_path)?);
599
600            if let Some(slice_name) = &self.config.matformer_slice_name {
601                info!("Using Matformer slice: {}", slice_name);
602                Some(MatformerSliceConfig::new(slice_name.clone(), config))
603            } else {
604                // If no slice name is provided but config exists, we'll need to handle this
605                // For now, return None and let the model handle the default slice selection
606                warn!("Matformer config loaded but no slice name specified. Models will use their default slice.");
607                None
608            }
609        } else {
610            None
611        };
612
613        let mut model = if use_nccl || cfg!(feature = "ring") {
614            let (mapper, sharded_vb) = distributed::prepare_distributed_mapper(
615                dtype,
616                &device,
617                &available_devices,
618                silent,
619                &config,
620                loading_isq,
621                self.config.from_uqff.is_some(),
622                self.config.organization,
623                &*self.inner,
624                paths.as_ref(),
625            )?;
626
627            // Special case for where things can be more optimially loaded.
628            match self.kind {
629                ModelKind::Normal => normal_model_loader_sharded!(
630                    sharded_vb,
631                    config,
632                    self.inner,
633                    mapper,
634                    loading_isq,
635                    device.clone(),
636                    attention_mechanism,
637                    multi_progress.clone(),
638                    matformer_slicing_config.clone(),
639                ),
640                ModelKind::Adapter {
641                    adapter: AdapterKind::XLora,
642                } => xlora_model_loader!(
643                    paths,
644                    Some(dtype),
645                    &load_device,
646                    layer_devices.clone(),
647                    config,
648                    self.inner,
649                    silent,
650                    mapper,
651                    loading_isq,
652                    device.clone(),
653                    multi_progress.clone(),
654                    matformer_slicing_config.clone(),
655                ),
656                ModelKind::Adapter {
657                    adapter: AdapterKind::Lora,
658                } => lora_model_loader!(
659                    paths,
660                    Some(dtype),
661                    &load_device,
662                    layer_devices.clone(),
663                    config,
664                    self.inner,
665                    silent,
666                    mapper,
667                    loading_isq,
668                    self.config.from_uqff.is_some(),
669                    device.clone(),
670                    attention_mechanism,
671                    matches!(self.config.organization, IsqOrganization::MoeExpertsOnly),
672                    multi_progress.clone(),
673                    matformer_slicing_config.clone(),
674                ),
675                _ => unreachable!(),
676            }
677        } else {
678            match self.kind {
679                ModelKind::Normal => normal_model_loader!(
680                    paths,
681                    Some(dtype),
682                    &load_device,
683                    layer_devices.clone(),
684                    config,
685                    self.inner,
686                    silent,
687                    mapper,
688                    loading_isq,
689                    self.config.from_uqff.is_some(),
690                    device.clone(),
691                    attention_mechanism,
692                    matches!(self.config.organization, IsqOrganization::MoeExpertsOnly),
693                    multi_progress.clone(),
694                    matformer_slicing_config.clone(),
695                ),
696                ModelKind::Adapter {
697                    adapter: AdapterKind::XLora,
698                } => xlora_model_loader!(
699                    paths,
700                    Some(dtype),
701                    &load_device,
702                    layer_devices.clone(),
703                    config,
704                    self.inner,
705                    silent,
706                    mapper,
707                    loading_isq,
708                    device.clone(),
709                    multi_progress.clone(),
710                    matformer_slicing_config.clone(),
711                ),
712                ModelKind::Adapter {
713                    adapter: AdapterKind::Lora,
714                } => lora_model_loader!(
715                    paths,
716                    Some(dtype),
717                    &load_device,
718                    layer_devices.clone(),
719                    config,
720                    self.inner,
721                    silent,
722                    mapper,
723                    loading_isq,
724                    self.config.from_uqff.is_some(),
725                    device.clone(),
726                    attention_mechanism,
727                    matches!(self.config.organization, IsqOrganization::MoeExpertsOnly),
728                    multi_progress.clone(),
729                    matformer_slicing_config.clone(),
730                ),
731                _ => unreachable!(),
732            }
733        };
734
735        let tokenizer = get_tokenizer(paths.get_tokenizer_filename(), None)?;
736        let gen_conf: Option<GenerationConfig> = paths.get_gen_conf_filename().and_then(|f| {
737            match serde_json::from_str::<GenerationConfig>(&fs::read_to_string(f).unwrap()) {
738                Ok(conf) => Some(conf),
739                Err(e) => {
740                    warn!("Failed to parse generation_config.json: {}", e);
741                    None
742                }
743            }
744        });
745
746        let chat_template_explicit = paths
747            .get_chat_template_explicit()
748            .as_ref()
749            .map(|x| x.to_string_lossy().to_string());
750        let chat_template = get_chat_template(
751            paths,
752            self.jinja_explicit.as_ref(),
753            chat_template_explicit.as_ref(),
754            self.chat_template.as_ref(),
755            None,
756        );
757
758        if let Some(calibration_file) = &self.config.calibration_file {
759            let calibration_data = std::fs::read_to_string(calibration_file)?;
760            // Tokenize, don't add bos yet
761            let tokens = tokenizer
762                .encode_fast(calibration_data, false)
763                .map_err(anyhow::Error::msg)?
764                .get_ids()
765                .to_vec();
766            info!(
767                "Collecting imatrix from calibration file `{}` of {} tokens.",
768                calibration_file.display(),
769                tokens.len()
770            );
771            let bos_toks = chat_template.bos_tok().map(|b| vec![b]).unwrap_or_default();
772            let bos_tok_id = tokenizer
773                .token_to_id(&bos_toks[0])
774                .expect("Somehow the bos token is not present.");
775
776            match self.config.organization {
777                IsqOrganization::Default => model.begin_track_stats()?,
778                IsqOrganization::MoeExpertsOnly => model.begin_track_stats_moe_experts_only()?,
779            }
780
781            const CHUNK_SIZE: usize = 1024;
782            let n_chunks = tokens.len().div_ceil(CHUNK_SIZE);
783            let start = Instant::now();
784            for (i, chunk) in tokens.chunks(CHUNK_SIZE).enumerate() {
785                let chunk = [vec![bos_tok_id], chunk.to_vec()].concat();
786                let chunk_len = chunk.len();
787
788                let start = Instant::now();
789                let inputs = make_prompt_chunk(
790                    0,
791                    vec![&chunk],
792                    &[0],
793                    &load_device,
794                    None,
795                    false,
796                    None,
797                    Some(pipeline_mapper.as_ref()),
798                )?;
799
800                model.forward(
801                    &inputs.input.to_device(model.device())?,
802                    &inputs.positions,
803                    inputs.context_lens.clone(),
804                    inputs.position_ids.clone(),
805                    None,
806                    &inputs.flash_meta.clone(),
807                )?;
808
809                match model.cache_mut() {
810                    EitherCache::Full(full) => {
811                        for layer in &mut *full.lock() {
812                            *layer = None
813                        }
814                    }
815                    EitherCache::Normal(normal) => {
816                        for layer in &mut *normal.lock().unwrap().0 {
817                            layer.reset();
818                        }
819                    }
820                }
821
822                let end = Instant::now();
823                info!(
824                    "Processed chunk {}/{n_chunks} ({chunk_len} tokens), {:.2}s",
825                    i + 1,
826                    end.duration_since(start).as_secs_f32()
827                );
828            }
829            load_device.synchronize()?;
830            let end = Instant::now();
831            info!(
832                "Finished collecting imatrix in {:.2}s",
833                end.duration_since(start).as_secs_f32()
834            );
835        }
836
837        // Only if loading from UQFF
838        let should_serialize = self.config.write_uqff.is_some();
839        let should_quantize_pass = loading_isq;
840
841        if (should_quantize_pass || should_serialize) && self.config.from_uqff.is_none() {
842            let imatrix_source = if should_quantize_pass {
843                match (
844                    self.config.imatrix.as_ref(),
845                    self.config.calibration_file.is_some(),
846                ) {
847                    (None, false) => None,
848                    (Some(file), false) => Some(ImatrixDataSource::File(file)),
849                    (None, true) => Some(ImatrixDataSource::Collected),
850                    (Some(_), true) => unreachable!(),
851                }
852            } else {
853                None
854            };
855
856            if should_quantize_pass {
857                info!("Applying ISQ to all ranks.");
858            } else {
859                info!("Serializing existing ISQ tensors without additional quantization.");
860            }
861
862            let multi_progress = Arc::new(MultiProgress::new());
863
864            model.quantize(
865                in_situ_quant,
866                model.device().clone(),
867                self.config.topology.as_ref(),
868                silent,
869                imatrix_source,
870                self.config.organization,
871                should_quantize_pass,
872                self.config.write_uqff.as_ref(),
873                UqffFullSer {
874                    tokenizer: &tokenizer,
875                    template_filename: paths.get_template_filename(),
876                    generation_config: paths.get_gen_conf_filename(),
877                    config: config.clone(),
878                    processor_filename: &None,
879                    preprocessor_filename: &None,
880                    modules: None,
881                    module_paths: None,
882                },
883                multi_progress.clone(),
884            )?;
885        } else if let Some(from_uqff) = &*self.from_uqff.read().unwrap() {
886            model.load_from_artifacts(
887                device.clone(),
888                self.config.topology.as_ref(),
889                silent,
890                from_uqff,
891            )?;
892        }
893
894        let paged_attn_config = if matches!(
895            self.kind,
896            ModelKind::Adapter {
897                adapter: AdapterKind::XLora
898            }
899        ) {
900            warn!(
901                "Adapter parallel_models do not currently support PagedAttention, running without"
902            );
903            None
904        } else {
905            paged_attn_config
906        };
907
908        let (cache_config, cache_engine) = if let Some(paged_attn_config) = paged_attn_config {
909            let cache_config = calculate_cache_config(
910                paged_attn_config.mem_gpu,
911                paged_attn_config.block_size,
912                dtype,
913                paged_attn_config.cache_type,
914                model.config(),
915                &device,
916                &pipeline_mapper
917                    .get_unique_devices()
918                    .into_iter()
919                    .map(Some)
920                    .collect::<Vec<_>>(),
921                silent,
922            )?;
923
924            let mut layer_devices = Vec::new();
925            for layer in 0..self.inner.num_layers(&config)? {
926                let device = model.get_layers().1.device_for(layer, false).cloned();
927                layer_devices.push(device);
928            }
929            let cache_engine = CacheEngine::new(
930                model.config(),
931                &cache_config,
932                dtype,
933                model.device(),
934                layer_devices.clone(),
935            )?;
936
937            (Some(cache_config), Some(cache_engine))
938        } else {
939            (None, None)
940        };
941
942        let max_seq_len = model.max_seq_len();
943        let llg_factory = build_llg_factory(tokenizer.clone())?;
944        let num_hidden_layers = match model.cache() {
945            EitherCache::Full(full) => full.lock().len(),
946            EitherCache::Normal(normal) => normal.lock().unwrap().0.len(),
947        };
948        let eos = calculate_eos_tokens(&chat_template, gen_conf, &tokenizer);
949        let sliding_window = model.config().sliding_window;
950        let model_metadata = Arc::new(model.config().clone());
951
952        Ok(Arc::new(Mutex::new(NormalPipeline {
953            model,
954            tokenizer: tokenizer.into(),
955            no_kv_cache: self.no_kv_cache,
956            chat_template: Arc::new(chat_template),
957            non_granular_state: self.tgt_non_granular_index.map(|tgt_non_granular_index| {
958                NonGranularState {
959                    non_granular_index: Arc::new(Mutex::new(0)),
960                    tgt_non_granular_index,
961                }
962            }),
963            model_id: self.model_id.clone(),
964            metadata: Arc::new(GeneralMetadata {
965                max_seq_len,
966                llg_factory: Some(llg_factory),
967                no_kv_cache: self.no_kv_cache,
968                no_prefix_cache: is_xlora,
969                num_hidden_layers,
970                eos_tok: eos,
971                kind: self.kind.clone(),
972                is_xlora,
973                activation_dtype: dtype,
974                sliding_window,
975                cache_config,
976                cache_engine,
977                model_metadata: Some(model_metadata),
978                modalities: Modalities {
979                    input: vec![SupportedModality::Text],
980                    output: vec![SupportedModality::Text],
981                },
982            }),
983            topology: self.config.topology.clone(),
984            silent,
985            organization: self.config.organization,
986            template_filename: paths.get_template_filename().clone(),
987            generation_config: paths.get_gen_conf_filename().cloned(),
988            config,
989            imatrix: self.config.imatrix.clone(),
990            mapper: pipeline_mapper,
991        })))
992    }
993
994    fn get_id(&self) -> String {
995        self.model_id.clone()
996    }
997
998    fn get_kind(&self) -> ModelKind {
999        self.kind.clone()
1000    }
1001}
1002
1003impl PreProcessingMixin for NormalPipeline {
1004    fn get_chat_template(&self) -> Option<Arc<ChatTemplate>> {
1005        Some(self.chat_template.clone())
1006    }
1007    fn get_input_processor_config(&self) -> Option<Arc<dyn Any>> {
1008        None
1009    }
1010}
1011
1012impl IsqPipelineMixin for NormalPipeline {
1013    fn re_isq_model(&mut self, dtype: IsqType) -> Result<()> {
1014        let device = self.device().clone();
1015        let multi_progress = Arc::new(MultiProgress::new());
1016        self.model.quantize(
1017            Some(dtype),
1018            device.clone(),
1019            self.topology.as_ref(),
1020            self.silent,
1021            self.imatrix.as_ref().map(ImatrixDataSource::File),
1022            self.organization,
1023            true,
1024            None,
1025            UqffFullSer {
1026                tokenizer: &self.tokenizer,
1027                template_filename: &self.template_filename,
1028                generation_config: self.generation_config.as_ref(),
1029                config: self.config.clone(),
1030                processor_filename: &None,
1031                preprocessor_filename: &None,
1032                modules: None,
1033                module_paths: None,
1034            },
1035            multi_progress.clone(),
1036        )?;
1037        Ok(())
1038    }
1039}
1040
1041impl CacheManagerMixin for NormalPipeline {
1042    fn clone_in_cache(&self, seqs: &mut [&mut Sequence]) {
1043        if matches!(self.model.cache(), EitherCache::Full(_)) {
1044            FullCacheManager.clone_in_cache(self, seqs, false)
1045        } else {
1046            NormalCacheManager.clone_in_cache(self, seqs, false)
1047        }
1048    }
1049    fn clone_out_cache(&self, seqs: &mut [&mut Sequence]) {
1050        if matches!(self.model.cache(), EitherCache::Full(_)) {
1051            FullCacheManager.clone_out_cache(self, seqs, false)
1052        } else {
1053            NormalCacheManager.clone_out_cache(self, seqs, false)
1054        }
1055    }
1056    fn set_none_cache(
1057        &self,
1058        seqs: &mut [&mut Sequence],
1059        reset_non_granular: bool,
1060        modify_draft_cache: bool,
1061        load_preallocated_cache: bool,
1062    ) {
1063        if matches!(self.model.cache(), EitherCache::Full(_)) {
1064            FullCacheManager.set_none_cache(self, seqs, modify_draft_cache, false);
1065        } else {
1066            NormalCacheManager.set_none_cache(
1067                self,
1068                seqs,
1069                modify_draft_cache,
1070                load_preallocated_cache,
1071            );
1072        }
1073        if reset_non_granular {
1074            self.reset_non_granular_state()
1075        }
1076    }
1077    fn cache(&self) -> &EitherCache {
1078        self.model.cache()
1079    }
1080}
1081
1082impl MetadataMixin for NormalPipeline {
1083    fn device(&self) -> Device {
1084        self.model.device().clone()
1085    }
1086    fn tokenizer(&self) -> Option<Arc<Tokenizer>> {
1087        Some(self.tokenizer.clone())
1088    }
1089    fn name(&self) -> String {
1090        self.model_id.clone()
1091    }
1092    fn reset_non_granular_state(&self) {
1093        if let Some(s) = self.non_granular_state.as_ref() {
1094            *self.cache().full().get_scalings_cache() = None;
1095            *get_mut_arcmutex!(s.non_granular_index) = 0;
1096        }
1097    }
1098    fn get_metadata(&self) -> Arc<GeneralMetadata> {
1099        self.metadata.clone()
1100    }
1101    fn device_mapper(&self) -> Option<&dyn DeviceMapper> {
1102        Some(&*self.mapper)
1103    }
1104}
1105
1106#[async_trait::async_trait]
1107impl Pipeline for NormalPipeline {
1108    fn forward_inputs(
1109        &mut self,
1110        inputs: Box<dyn Any>,
1111        return_raw_logits: bool,
1112    ) -> Result<ForwardInputsResult, candle_core::Error> {
1113        let ModelInputs {
1114            input_ids,
1115            input_ids_full,
1116            seqlen_offsets,
1117            seqlen_offsets_full,
1118            context_lens,
1119            position_ids,
1120            paged_attn_meta,
1121            flash_meta,
1122            flash_meta_full,
1123        } = *inputs.downcast().expect("Downcast failed.");
1124        let metadata = self.get_metadata();
1125        let paged_attn_meta = match (&metadata.cache_engine, &paged_attn_meta) {
1126            (Some(cache_engine), Some(meta)) => Some((cache_engine, meta)),
1127            (Some(_), None) => {
1128                // This can happen if Rust-side user code is wrong
1129                candle_core::bail!("Forward step expected a PagedAttention input metadata. This was not provided, please ensure that the scheduler config is correctly configured for PagedAttention.")
1130            }
1131            (None, Some(_)) => {
1132                // This should never happen but we handle it anyway
1133                candle_core::bail!("Forward step got a PagedAttention input metadata but there is no cache engine. Please raise an issue.")
1134            }
1135            (None, None) => None,
1136        };
1137        let logits = match self.model.is_xlora() {
1138            false => {
1139                let paged_attn_meta = paged_attn_meta
1140                    .as_ref()
1141                    .map(|meta| (meta.0.get_kv_cache().clone(), meta.1.clone()));
1142
1143                self.model.forward(
1144                    &input_ids,
1145                    &seqlen_offsets,
1146                    context_lens,
1147                    position_ids,
1148                    paged_attn_meta.as_ref().map(|(a, b)| (a.clone(), b)),
1149                    &flash_meta,
1150                )?
1151            }
1152            true => self.model.xlora_forward(
1153                &input_ids,
1154                input_ids_full.as_ref().unwrap_or(&input_ids),
1155                &seqlen_offsets,
1156                seqlen_offsets_full.as_ref().unwrap_or(&seqlen_offsets),
1157                self.no_kv_cache,
1158                &self.non_granular_state,
1159                context_lens,
1160                position_ids,
1161                &flash_meta,
1162                flash_meta_full.as_ref().unwrap_or(&flash_meta),
1163            )?,
1164        };
1165        if return_raw_logits {
1166            Ok(ForwardInputsResult::RawLogits { logits })
1167        } else {
1168            Ok(ForwardInputsResult::CausalGeneration { logits })
1169        }
1170    }
1171    async fn sample_causal_gen(
1172        &self,
1173        seqs: &mut [&mut Sequence],
1174        logits: Vec<Tensor>,
1175        prefix_cacher: &mut PrefixCacheManagerV2,
1176        disable_eos_stop: bool,
1177        rng: Arc<std::sync::Mutex<Isaac64Rng>>,
1178    ) -> Result<(), candle_core::Error> {
1179        sample_and_add_toks(self, seqs, logits, prefix_cacher, disable_eos_stop, rng).await
1180    }
1181    fn category(&self) -> ModelCategory {
1182        ModelCategory::Text
1183    }
1184}
1185
1186impl AnyMoePipelineMixin for NormalPipeline {
1187    fn amoe_finish_training(&mut self, gate_model_id: Option<String>) -> candle_core::Result<()> {
1188        self.model.finish_training(gate_model_id)
1189    }
1190    fn amoe_layer_vars(&self) -> Vec<Vec<Var>> {
1191        self.model.get_vars()
1192    }
1193    fn amoe_base_model_trainable_params(&self) -> usize {
1194        self.model.trainable_params()
1195    }
1196    fn amoe_take_cached_gating_outputs(&mut self) -> Vec<Tensor> {
1197        self.model.take_cached_gating_outputs()
1198    }
1199    fn amoe_create_layers(
1200        &mut self,
1201        model_ids: Vec<String>,
1202        token: &TokenSource,
1203        revision: Option<String>,
1204        match_regex: &str,
1205        config: crate::amoe::AnyMoeConfig,
1206        dtype: candle_core::DType,
1207        dev: &Device,
1208        (prefix, mlp): (String, String),
1209        layers: Vec<usize>,
1210        expert_type: AnyMoeExpertType,
1211        silent: bool,
1212        gate_model_id: Option<String>,
1213    ) -> candle_core::Result<()> {
1214        let mut vbs = Vec::new();
1215        // Precompile regex here
1216        let regex = Regex::new(match_regex).map_err(candle_core::Error::msg)?;
1217        for model_id in model_ids {
1218            let model_id_str = &model_id;
1219            let model_id = Path::new(&model_id);
1220
1221            let api = {
1222                let cache = GLOBAL_HF_CACHE.get().cloned().unwrap_or_default();
1223                let mut api = ApiBuilder::from_cache(cache)
1224                    .with_progress(!silent)
1225                    .with_token(get_token(token).map_err(candle_core::Error::msg)?);
1226                if let Ok(x) = std::env::var("HF_HUB_CACHE") {
1227                    api = api.with_cache_dir(x.into());
1228                }
1229                api.build().map_err(candle_core::Error::msg)?
1230            };
1231            let revision = revision.clone().unwrap_or("main".to_string());
1232            let api = api.repo(Repo::with_revision(
1233                model_id_str.clone(),
1234                RepoType::Model,
1235                revision.clone(),
1236            ));
1237
1238            let mut filenames = vec![];
1239            for rfilename in
1240                api_dir_list!(api, model_id, true).filter(|x| x.ends_with(".safetensors"))
1241            {
1242                filenames.push(api_get_file!(api, &rfilename, model_id));
1243            }
1244
1245            let regex = regex.clone();
1246            let match_regex_clone = match_regex.to_string();
1247            let layers_clone = layers.clone();
1248            let vb = from_mmaped_safetensors(
1249                filenames,
1250                vec![],
1251                Some(dtype),
1252                dev,
1253                vec![None],
1254                silent,
1255                None,
1256                move |key| {
1257                    if regex.is_match(&key) {
1258                        // Idx of the last char of the layer id, +1
1259                        // Assumes N.MLP
1260                        let last_layer_idx = key.find(&match_regex_clone).unwrap() - 1;
1261                        let first_layer_idx = key[..last_layer_idx].rfind('.').unwrap();
1262                        let layer_n = key[first_layer_idx + 1..last_layer_idx]
1263                            .parse::<usize>()
1264                            .unwrap();
1265                        layers_clone.contains(&layer_n) || layers_clone.is_empty()
1266                    } else {
1267                        false
1268                    }
1269                },
1270                Arc::new(|_| DeviceForLoadTensor::Base),
1271            )?;
1272            vbs.push(vb);
1273        }
1274
1275        let gate_vb = if let Some(gate_model_id) = gate_model_id {
1276            let model_id_str = &gate_model_id;
1277            let model_id = Path::new(&gate_model_id);
1278
1279            let api = {
1280                let cache = GLOBAL_HF_CACHE.get().cloned().unwrap_or_default();
1281                let mut api = ApiBuilder::from_cache(cache)
1282                    .with_progress(!silent)
1283                    .with_token(get_token(token).map_err(candle_core::Error::msg)?);
1284                if let Ok(x) = std::env::var("HF_HUB_CACHE") {
1285                    api = api.with_cache_dir(x.into());
1286                }
1287                api.build().map_err(candle_core::Error::msg)?
1288            };
1289            let revision = revision.clone().unwrap_or("main".to_string());
1290            let api = api.repo(Repo::with_revision(
1291                model_id_str.clone(),
1292                RepoType::Model,
1293                revision.clone(),
1294            ));
1295
1296            let mut gate_filenames = vec![];
1297            for rfilename in
1298                api_dir_list!(api, model_id, true).filter(|x| x.ends_with(".safetensors"))
1299            {
1300                gate_filenames.push(api_get_file!(api, &rfilename, model_id));
1301            }
1302            assert_eq!(
1303                gate_filenames.len(),
1304                1,
1305                "Gate model ID must contain only one .safetensors file"
1306            );
1307
1308            let vb = from_mmaped_safetensors(
1309                gate_filenames.clone(),
1310                vec![],
1311                Some(dtype),
1312                dev,
1313                vec![None],
1314                silent,
1315                None,
1316                |_| true,
1317                Arc::new(|_| DeviceForLoadTensor::Base),
1318            )?;
1319            info!(
1320                "Loaded gating layers from `{}`",
1321                gate_filenames[0].display()
1322            );
1323            Some(vb)
1324        } else {
1325            None
1326        };
1327
1328        self.model.create_anymoe_layers(
1329            vbs.clone(),
1330            config.clone(),
1331            (prefix.clone(), mlp.clone()),
1332            layers.clone(),
1333            expert_type.clone(),
1334            gate_vb.clone(),
1335        )?;
1336
1337        Ok(())
1338    }
1339    fn amoe_supported(&self) -> bool {
1340        self.model.amoe_supported()
1341    }
1342}