mistralrs_core/pipeline/
normal.rs

1use super::isq::ImatrixDataSource;
2use super::llg::build_llg_factory;
3use super::{
4    get_model_paths, get_xlora_paths, text_models_inputs_processor::ModelInputs, AdapterKind,
5    CacheManager, GeneralMetadata, Loader, ModelKind, ModelPaths, NormalModel, NormalModelLoader,
6    TokenSource,
7};
8use super::{
9    AnyMoePipelineMixin, CacheManagerMixin, EitherCache, ForwardInputsResult, IsqOrganization,
10    IsqPipelineMixin, MetadataMixin, ModelCategory, PreProcessingMixin,
11};
12use super::{
13    AutoNormalLoader, DeepSeekV2Loader, DeepSeekV3Loader, GLM4Loader, Gemma2Loader, GemmaLoader,
14    LlamaLoader, MistralLoader, MixtralLoader, NormalLoaderType, Phi2Loader, Phi3Loader,
15    Phi3_5MoELoader, Qwen2Loader, Qwen3Loader, Qwen3MoELoader, SmolLm3Loader, Starcoder2Loader,
16};
17use crate::amoe::AnyMoeExpertType;
18use crate::attention::ATTENTION_CHUNK_SIZE;
19use crate::device_map::{self, DeviceMapper};
20use crate::distributed::{self, WorkerTransferData};
21use crate::kv_cache::{FullCacheManager, NormalCacheManager};
22use crate::lora::Ordering;
23use crate::paged_attention::{calculate_cache_config, AttentionImplementation, CacheEngine};
24use crate::pipeline::chat_template::{calculate_eos_tokens, GenerationConfig};
25use crate::pipeline::isq::UqffFullSer;
26use crate::pipeline::loaders::auto_device_map;
27use crate::pipeline::loaders::QuantizationConfigShim;
28use crate::pipeline::sampling::sample_and_add_toks;
29use crate::pipeline::text_models_inputs_processor::make_prompt_chunk;
30use crate::pipeline::{get_chat_template, Modalities, SupportedModality};
31use crate::pipeline::{ChatTemplate, LocalModelPaths};
32use crate::prefix_cacher::PrefixCacheManagerV2;
33use crate::sequence::Sequence;
34use crate::utils::tokenizer::get_tokenizer;
35use crate::utils::varbuilder_utils::DeviceForLoadTensor;
36use crate::utils::{tokens::get_token, varbuilder_utils::from_mmaped_safetensors};
37use crate::xlora_models::NonGranularState;
38use crate::{
39    api_dir_list, api_get_file, get_mut_arcmutex, get_paths, get_uqff_paths, lora_model_loader,
40    normal_model_loader, normal_model_loader_sharded, xlora_model_loader, DeviceMapSetting,
41    PagedAttentionConfig, Pipeline, Topology, TryIntoDType, GLOBAL_HF_CACHE,
42};
43use anyhow::Result;
44use candle_core::{Device, Tensor, Var};
45use hf_hub::Cache;
46use hf_hub::{api::sync::ApiBuilder, Repo, RepoType};
47use indicatif::MultiProgress;
48use mistralrs_quant::log::once_log_info;
49use mistralrs_quant::{AfqLayer, GgufMatMul, HqqLayer, IsqType, QuantizedSerdeType};
50use rand_isaac::Isaac64Rng;
51use regex_automata::meta::Regex;
52use std::any::Any;
53use std::borrow::Cow;
54use std::path::{Path, PathBuf};
55use std::str::FromStr;
56use std::sync::{Arc, RwLock};
57use std::time::Instant;
58use std::{env, fs};
59use tokenizers::Tokenizer;
60use tokio::sync::Mutex;
61use tracing::{info, warn};
62
/// Pipeline state for a loaded "normal" (non-quantized) text model.
///
/// NOTE(review): the code that constructs and drives this struct is outside the visible part
/// of the file; the field notes below are derived from names and types only — confirm them
/// against the constructor.
pub struct NormalPipeline {
    /// The loaded model, held behind the `NormalModel` trait object.
    model: Box<dyn NormalModel + Send + Sync>,
    /// Shared handle to the tokenizer.
    tokenizer: Arc<Tokenizer>,
    /// Presumably mirrors the loader's `no_kv_cache` flag — TODO confirm.
    no_kv_cache: bool,
    /// Shared handle to the chat template.
    chat_template: Arc<ChatTemplate>,
    /// Presumably only populated for X-LoRA models — TODO confirm.
    non_granular_state: Option<NonGranularState>,
    model_id: String,
    metadata: Arc<GeneralMetadata>,
    topology: Option<Topology>,
    silent: bool,
    organization: IsqOrganization,
    // For full UQFF serialization
    template_filename: Option<PathBuf>,
    generation_config: Option<PathBuf>,
    config: String,
    imatrix: Option<PathBuf>,
    /// Device mapper, held behind the `DeviceMapper` trait object.
    mapper: Box<dyn DeviceMapper + Send + Sync>,
}
81
/// A loader for a "normal" (non-quantized) model.
///
/// Instances are produced by `NormalLoaderBuilder::build`. The `RwLock`-wrapped fields start
/// out as `None` and are filled in by `load_model_from_hf`, which only receives `&self`.
pub struct NormalLoader {
    /// Architecture-specific loader chosen in `NormalLoaderBuilder::build`
    /// (`AutoNormalLoader` when no loader type is given).
    inner: Box<dyn NormalModelLoader>,
    model_id: String,
    /// Loading options (topology, ISQ organization, UQFF/imatrix/calibration paths, ...).
    config: NormalSpecificConfig,
    xlora_model_id: Option<String>,
    lora_adapter_ids: Option<Vec<String>>,
    /// `ModelKind::Normal`, or an adapter kind (X-LoRA / LoRA).
    kind: ModelKind,
    xlora_order: Option<Ordering>,
    no_kv_cache: bool,
    chat_template: Option<String>,
    tokenizer_json: Option<String>,
    tgt_non_granular_index: Option<usize>,
    /// Token source recorded by `load_model_from_hf`.
    token_source: RwLock<Option<TokenSource>>,
    /// Revision recorded by `load_model_from_hf`.
    revision: RwLock<Option<String>>,
    /// UQFF file paths resolved by `load_model_from_hf` when `config.from_uqff` is set.
    from_uqff: RwLock<Option<Vec<PathBuf>>>,
    jinja_explicit: Option<String>,
    /// Optional directory used to construct the Hugging Face `Cache` in `load_model_from_hf`;
    /// the default cache is used when this is `None`.
    hf_cache_path: Option<PathBuf>,
}
101
/// A builder for a loader for a "normal" (non-quantized) model.
///
/// Start with `NormalLoaderBuilder::new`, optionally chain `with_xlora`, `with_lora` and
/// `hf_cache_path`, then call `build` to obtain a boxed `Loader`.
#[derive(Default)]
pub struct NormalLoaderBuilder {
    /// Model ID. `with_xlora` fills it in from the ordering's `base_model_id` when it is still
    /// unset; it must be `Some` by the time `build` is called.
    model_id: Option<String>,
    config: NormalSpecificConfig,
    /// Set by `with_xlora`.
    xlora_model_id: Option<String>,
    /// Set by `with_lora`.
    lora_adapter_ids: Option<Vec<String>>,
    /// `ModelKind::Normal` after `new`; switched to an adapter kind by `with_xlora`/`with_lora`.
    kind: ModelKind,
    /// Set by `with_xlora`.
    xlora_order: Option<Ordering>,
    /// Set by `new`, and overwritten by `with_xlora`.
    no_kv_cache: bool,
    chat_template: Option<String>,
    tokenizer_json: Option<String>,
    /// Set by `with_xlora`.
    tgt_non_granular_index: Option<usize>,
    jinja_explicit: Option<String>,
    /// Set by the `hf_cache_path` builder method.
    hf_cache_path: Option<PathBuf>,
}
118
/// Config specific to loading a normal model.
#[derive(Clone, Default)]
pub struct NormalSpecificConfig {
    /// Optional topology; handed to the device mapper, and any layer with an ISQ setting
    /// switches the loader into ISQ loading.
    pub topology: Option<Topology>,
    /// Selects which ISQ predicates / statistics tracking are used
    /// (`Default` vs. `MoeExpertsOnly`).
    pub organization: IsqOrganization,
    /// When set, the immediate-ISQ fast path is not taken during loading.
    pub write_uqff: Option<PathBuf>,
    /// UQFF files to load from; resolved to concrete paths in `load_model_from_hf`.
    pub from_uqff: Option<Vec<PathBuf>>,
    /// Imatrix file. Mutually exclusive with `calibration_file`: loading fails if both are set.
    pub imatrix: Option<PathBuf>,
    /// Text file that is tokenized and run through the model in chunks to collect an imatrix.
    /// Mutually exclusive with `imatrix`.
    pub calibration_file: Option<PathBuf>,
    /// NOTE(review): not read in the visible part of this file (the loader uses its own
    /// `hf_cache_path` field) — confirm where this one is consumed.
    pub hf_cache_path: Option<PathBuf>,
    /// Path to a Matformer config file, loaded via `MatformerConfig::from_file`.
    pub matformer_config_path: Option<PathBuf>,
    /// Name of the Matformer slice to use. If a config path is given without a slice name, a
    /// warning is logged and no slicing config is passed to the model loader.
    pub matformer_slice_name: Option<String>,
}
132
133impl NormalLoaderBuilder {
134    pub fn new(
135        config: NormalSpecificConfig,
136        chat_template: Option<String>,
137        tokenizer_json: Option<String>,
138        model_id: Option<String>,
139        no_kv_cache: bool,
140        jinja_explicit: Option<String>,
141    ) -> Self {
142        Self {
143            config,
144            chat_template,
145            tokenizer_json,
146            model_id,
147            kind: ModelKind::Normal,
148            jinja_explicit,
149            no_kv_cache,
150            ..Default::default()
151        }
152    }
153
154    fn with_adapter(
155        mut self,
156        xlora_model_id: String,
157        xlora_order: Ordering,
158        no_kv_cache: bool,
159        tgt_non_granular_index: Option<usize>,
160    ) -> Self {
161        self.xlora_model_id = Some(xlora_model_id);
162        self.xlora_order = Some(xlora_order);
163        self.no_kv_cache = no_kv_cache;
164        self.tgt_non_granular_index = tgt_non_granular_index;
165        self.model_id = if let Some(id) = self.model_id {
166            Some(id)
167        } else {
168            info!(
169                "Using adapter base model ID: `{}`",
170                self.xlora_order.as_ref().unwrap().base_model_id
171            );
172            Some(self.xlora_order.as_ref().unwrap().base_model_id.clone())
173        };
174        self
175    }
176
177    pub fn with_xlora(
178        mut self,
179        xlora_model_id: String,
180        xlora_order: Ordering,
181        no_kv_cache: bool,
182        tgt_non_granular_index: Option<usize>,
183    ) -> Self {
184        self.kind = ModelKind::Adapter {
185            adapter: AdapterKind::XLora,
186        };
187        self.with_adapter(
188            xlora_model_id,
189            xlora_order,
190            no_kv_cache,
191            tgt_non_granular_index,
192        )
193    }
194
195    pub fn with_lora(mut self, lora_adapter_ids: Vec<String>) -> Self {
196        self.kind = ModelKind::Adapter {
197            adapter: AdapterKind::Lora,
198        };
199        self.lora_adapter_ids = Some(lora_adapter_ids);
200        self
201    }
202
203    pub fn hf_cache_path(mut self, hf_cache_path: PathBuf) -> Self {
204        self.hf_cache_path = Some(hf_cache_path);
205        self
206    }
207
208    /// If the loader type is not specified, loader type is automatically determined from the
209    /// `architectures` array in the config.
210    pub fn build(self, loader_tp: Option<NormalLoaderType>) -> anyhow::Result<Box<dyn Loader>> {
211        let loader: Box<dyn NormalModelLoader> = match loader_tp {
212            Some(NormalLoaderType::Mistral) => Box::new(MistralLoader),
213            Some(NormalLoaderType::Gemma) => Box::new(GemmaLoader),
214            Some(NormalLoaderType::Llama) => Box::new(LlamaLoader),
215            Some(NormalLoaderType::Mixtral) => Box::new(MixtralLoader),
216            Some(NormalLoaderType::Phi2) => Box::new(Phi2Loader),
217            Some(NormalLoaderType::Phi3) => Box::new(Phi3Loader),
218            Some(NormalLoaderType::Qwen2) => Box::new(Qwen2Loader),
219            Some(NormalLoaderType::Gemma2) => Box::new(Gemma2Loader),
220            Some(NormalLoaderType::Starcoder2) => Box::new(Starcoder2Loader),
221            Some(NormalLoaderType::Phi3_5MoE) => Box::new(Phi3_5MoELoader),
222            Some(NormalLoaderType::DeepSeekV2) => Box::new(DeepSeekV2Loader),
223            Some(NormalLoaderType::DeepSeekV3) => Box::new(DeepSeekV3Loader),
224            Some(NormalLoaderType::Qwen3) => Box::new(Qwen3Loader),
225            Some(NormalLoaderType::GLM4) => Box::new(GLM4Loader),
226            Some(NormalLoaderType::Qwen3Moe) => Box::new(Qwen3MoELoader),
227            Some(NormalLoaderType::SmolLm3) => Box::new(SmolLm3Loader),
228            None => Box::new(AutoNormalLoader),
229        };
230        Ok(Box::new(NormalLoader {
231            inner: loader,
232            model_id: self.model_id.unwrap(),
233            config: self.config,
234            xlora_model_id: self.xlora_model_id,
235            lora_adapter_ids: self.lora_adapter_ids,
236            kind: self.kind,
237            xlora_order: self.xlora_order,
238            no_kv_cache: self.no_kv_cache,
239            chat_template: self.chat_template,
240            tokenizer_json: self.tokenizer_json,
241            tgt_non_granular_index: self.tgt_non_granular_index,
242            jinja_explicit: self.jinja_explicit,
243            token_source: RwLock::new(None),
244            revision: RwLock::new(None),
245            from_uqff: RwLock::new(None),
246            hf_cache_path: self.hf_cache_path,
247        }))
248    }
249}
250
251impl Loader for NormalLoader {
252    #[allow(clippy::type_complexity, clippy::too_many_arguments)]
253    fn load_model_from_hf(
254        &self,
255        revision: Option<String>,
256        token_source: TokenSource,
257        dtype: &dyn TryIntoDType,
258        device: &Device,
259        silent: bool,
260        mapper: DeviceMapSetting,
261        in_situ_quant: Option<IsqType>,
262        paged_attn_config: Option<PagedAttentionConfig>,
263    ) -> Result<Arc<Mutex<dyn Pipeline + Send + Sync>>> {
264        let cache = self
265            .hf_cache_path
266            .clone()
267            .map(Cache::new)
268            .unwrap_or_default();
269        GLOBAL_HF_CACHE.get_or_init(|| cache);
270
271        let paths: anyhow::Result<Box<dyn ModelPaths>> = get_paths!(
272            LocalModelPaths,
273            &token_source,
274            revision.clone(),
275            self,
276            None,
277            None,
278            silent,
279            self.config.from_uqff.is_some()
280        );
281        if let Some(from_uqff) = self.config.from_uqff.clone() {
282            *self.from_uqff.write().unwrap() = Some(get_uqff_paths!(&from_uqff, self, silent));
283        }
284        *self
285            .token_source
286            .write()
287            .expect("Failed to write to token source") = Some(token_source);
288        *self.revision.write().expect("Failed to write to revision") = revision;
289        self.load_model_from_path(
290            &paths?,
291            dtype,
292            device,
293            silent,
294            mapper,
295            in_situ_quant,
296            paged_attn_config,
297        )
298    }
299
300    #[allow(clippy::type_complexity, clippy::too_many_arguments)]
301    fn load_model_from_path(
302        &self,
303        paths: &Box<dyn ModelPaths>,
304        dtype: &dyn TryIntoDType,
305        device: &Device,
306        silent: bool,
307        mut mapper: DeviceMapSetting,
308        mut in_situ_quant: Option<IsqType>,
309        mut paged_attn_config: Option<PagedAttentionConfig>,
310    ) -> Result<Arc<Mutex<dyn Pipeline + Send + Sync>>> {
311        let config = std::fs::read_to_string(paths.get_config_filename())?;
312
313        if !self.inner.supports_paged_attention(&config)? {
314            paged_attn_config = None;
315        }
316
317        info!("Prompt chunk size is {ATTENTION_CHUNK_SIZE}.");
318
319        let use_nccl = mistralrs_quant::distributed::use_nccl();
320
321        let available_devices = if let Ok(payload) = env::var(distributed::IS_DAEMON_FLAG) {
322            let payload: WorkerTransferData = serde_json::from_str(&payload)?;
323            let WorkerTransferData::Init { id: _, worker_rank } = payload;
324            vec![candle_core::Device::new_cuda(worker_rank + 1)?]
325        } else if use_nccl {
326            vec![candle_core::Device::new_cuda(0)?]
327        } else {
328            device_map::get_all_similar_devices(device)?
329        };
330        #[cfg(feature = "cuda")]
331        for device in &available_devices {
332            if let Device::Cuda(dev) = device {
333                unsafe { dev.disable_event_tracking() };
334            }
335        }
336        let device = if use_nccl || cfg!(feature = "ring") {
337            available_devices[0].clone()
338        } else {
339            device.clone()
340        };
341
342        // If auto, convert to Map if not using nccl
343        if use_nccl || cfg!(feature = "ring") {
344            mapper = DeviceMapSetting::DummyNccl {
345                nm_device: available_devices[0].clone(),
346            };
347        } else if let DeviceMapSetting::Auto(params) = mapper.clone() {
348            // Initial dtype
349            let dtype = dtype.try_into_dtype(&available_devices.iter().collect::<Vec<_>>())?;
350
351            // Disable ISQ if we are loading a prequantized model.
352            if QuantizationConfigShim::get_quant_config_pack_factor(&config, dtype)? != 1 {
353                in_situ_quant = None;
354            }
355
356            // ISQ or UQFF: quantized path
357            // Match logic below where UQFF has priority
358            let (layer_sizes_in_bytes, non_mapped_size_in_bytes, total_model_size_in_bytes) =
359                if let Some(serialized) = &*self.from_uqff.read().unwrap() {
360                    let weight_pack_factor = {
361                        let ser_artifacts = unsafe {
362                            candle_core::safetensors::MmapedSafetensors::multi(serialized)?
363                        };
364                        let mut total_pack_factors = 0;
365                        let total_tensors = ser_artifacts.tensors().len();
366                        for (_, artifact) in ser_artifacts.tensors() {
367                            let artifact = artifact.data();
368                            // NOTE(EricLBuehler): isq type is ALWAYS byte 4 (5th) of the tensor.
369                            let isq_type = artifact[mistralrs_quant::UQFF_QUANT_TYPE_OFFSET];
370                            let pack_factor = match QuantizedSerdeType::try_from(isq_type as usize)?
371                            {
372                                QuantizedSerdeType::Hqq => {
373                                    HqqLayer::get_isq_type_from_uqff(Cow::Borrowed(artifact))?
374                                        .pack_factor(dtype)
375                                }
376                                QuantizedSerdeType::Gguf => {
377                                    GgufMatMul::get_isq_type_from_uqff(Cow::Borrowed(artifact))?
378                                        .pack_factor(dtype)
379                                }
380                                QuantizedSerdeType::Fp8 => IsqType::F8E4M3.pack_factor(dtype),
381                                QuantizedSerdeType::Unquant => 1,
382                                QuantizedSerdeType::Afq => {
383                                    AfqLayer::get_isq_type_from_uqff(Cow::Borrowed(artifact))?
384                                        .pack_factor(dtype)
385                                }
386                            };
387                            total_pack_factors += pack_factor;
388                        }
389
390                        total_pack_factors / total_tensors
391                    };
392
393                    let layer_sizes_in_bytes = self.inner.layer_sizes_in_bytes(
394                        &config,
395                        dtype,
396                        weight_pack_factor,
397                        None,
398                    )?;
399                    let non_mapped_size_in_bytes = self.inner.non_mapped_size_in_bytes(
400                        &config,
401                        dtype,
402                        weight_pack_factor,
403                        None,
404                    )?;
405                    let layer_sizes_sum = layer_sizes_in_bytes.iter().sum::<usize>();
406                    (
407                        layer_sizes_in_bytes,
408                        non_mapped_size_in_bytes,
409                        layer_sizes_sum + non_mapped_size_in_bytes,
410                    )
411                } else if let Some(isq) = in_situ_quant {
412                    let weight_pack_factor = isq.pack_factor(dtype);
413                    let layer_sizes_in_bytes = self.inner.layer_sizes_in_bytes(
414                        &config,
415                        dtype,
416                        weight_pack_factor,
417                        None,
418                    )?;
419                    let non_mapped_size_in_bytes = self.inner.non_mapped_size_in_bytes(
420                        &config,
421                        dtype,
422                        weight_pack_factor,
423                        None,
424                    )?;
425                    let layer_sizes_sum = layer_sizes_in_bytes.iter().sum::<usize>();
426                    (
427                        layer_sizes_in_bytes,
428                        non_mapped_size_in_bytes,
429                        layer_sizes_sum + non_mapped_size_in_bytes,
430                    )
431                } else {
432                    // Be sure to get the weight pack factor here; we might be loading a prequantized model.
433                    let weight_pack_factor =
434                        QuantizationConfigShim::get_quant_config_pack_factor(&config, dtype)?;
435                    let layer_sizes_in_bytes = self.inner.layer_sizes_in_bytes(
436                        &config,
437                        dtype,
438                        weight_pack_factor,
439                        None,
440                    )?;
441                    let non_mapped_size_in_bytes = self.inner.non_mapped_size_in_bytes(
442                        &config,
443                        dtype,
444                        weight_pack_factor,
445                        None,
446                    )?;
447                    let layer_sizes_sum = layer_sizes_in_bytes.iter().sum::<usize>();
448                    (
449                        layer_sizes_in_bytes,
450                        non_mapped_size_in_bytes,
451                        layer_sizes_sum + non_mapped_size_in_bytes,
452                    )
453                };
454
455            let new = auto_device_map::get_device_layers(
456                &*self.inner,
457                &config,
458                self.inner.num_layers(&config)?,
459                layer_sizes_in_bytes,
460                non_mapped_size_in_bytes,
461                total_model_size_in_bytes,
462                &available_devices,
463                dtype,
464                &params,
465                paged_attn_config.as_ref(),
466            )?;
467            mapper = DeviceMapSetting::Map(new);
468        }
469
470        let pipeline_mapper = mapper.into_mapper(
471            self.inner.num_layers(&config)?,
472            &device,
473            self.config.topology.as_ref(),
474        )?;
475        let mapper = mapper.into_mapper(
476            self.inner.num_layers(&config)?,
477            &device,
478            self.config.topology.as_ref(),
479        )?;
480        let mut layer_devices = Vec::new();
481        for layer in 0..self.inner.num_layers(&config)? {
482            let device = mapper.device_for(layer, false).cloned();
483            layer_devices.push(device);
484        }
485        let dtype = mapper.get_min_dtype(dtype)?;
486
487        // TODO: PagedAttention is not supported with CPU for now.
488        // This check is not really necessary because `get_device_layers` should prevent it.
489        let mapping_uses_cpu = mapper.get_unique_devices().iter().any(Device::is_cpu);
490        if mapping_uses_cpu && paged_attn_config.is_some() {
491            warn!("Device mapping contains a mix of GPU and CPU. There is no CPU support for PagedAttention, disabling PagedAttention.");
492            paged_attn_config = None;
493        }
494
495        info!("Model config: {:?}", self.inner.get_config_repr(&config)?);
496        if crate::using_flash_attn() {
497            once_log_info("FlashAttention is enabled.");
498        }
499
500        // Logic for ISQ here: if no calibration (i.e imatrix), then allow immediate ISQ. Otherwise, back to normal.
501        let mut loading_isq = if self.config.imatrix.is_none()
502            && self.config.calibration_file.is_none()
503            && !device.is_cuda()
504            && self.config.write_uqff.is_none()
505            && in_situ_quant.is_some()
506        {
507            let predicates = if matches!(self.config.organization, IsqOrganization::MoeExpertsOnly)
508            {
509                self.inner.immediate_isq_predicates_moqe(&config)?
510            } else {
511                self.inner.immediate_isq_predicates(&config)?
512            };
513            info!("Applying ISQ to {in_situ_quant:?}");
514            if predicates.is_empty() {
515                warn!("No predicates for this model and ISQ setting detected. ISQ will not be applied to any weights!");
516            }
517            mistralrs_quant::set_immediate_isq(in_situ_quant, predicates);
518            false
519        } else {
520            in_situ_quant.is_some()
521        };
522
523        if let Some(ref topology) = self.config.topology {
524            loading_isq |= topology
525                .0
526                .iter()
527                .any(|layer| layer.as_ref().is_some_and(|layer| layer.isq.is_some()));
528        }
529
530        if self.config.imatrix.is_some() && self.config.calibration_file.is_some() {
531            anyhow::bail!(
532                "`imatrix` and `calibration_file` were both specified, this is not allowed."
533            );
534        }
535
536        // Load onto the regular device if not using isq or if the calibration file is specified
537        let load_device = if !loading_isq || self.config.calibration_file.is_some() {
538            loading_isq = false;
539            device.clone()
540        } else {
541            Device::Cpu
542        };
543
544        let is_xlora = self.kind.is_adapted_and(|a| a.is_x_lora());
545
546        let attention_mechanism = if paged_attn_config.is_some() {
547            AttentionImplementation::PagedAttention
548        } else {
549            AttentionImplementation::Eager
550        };
551
552        let multi_progress = Arc::new(MultiProgress::new());
553
554        // Load matformer slicing config if provided
555        let matformer_slicing_config = if let Some(matformer_path) =
556            &self.config.matformer_config_path
557        {
558            use crate::matformer::{MatformerConfig, MatformerSliceConfig};
559            info!("Loading Matformer config from {:?}", matformer_path);
560            let config = Arc::new(MatformerConfig::from_file(matformer_path)?);
561
562            if let Some(slice_name) = &self.config.matformer_slice_name {
563                info!("Using Matformer slice: {}", slice_name);
564                Some(MatformerSliceConfig::new(slice_name.clone(), config))
565            } else {
566                // If no slice name is provided but config exists, we'll need to handle this
567                // For now, return None and let the model handle the default slice selection
568                warn!("Matformer config loaded but no slice name specified. Models will use their default slice.");
569                None
570            }
571        } else {
572            None
573        };
574
575        let mut model = if use_nccl || cfg!(feature = "ring") {
576            let (mapper, sharded_vb) = distributed::prepare_distributed_mapper(
577                dtype,
578                &device,
579                &available_devices,
580                silent,
581                &config,
582                loading_isq,
583                self.config.from_uqff.is_some(),
584                self.config.organization,
585                &*self.inner,
586                paths.as_ref(),
587            )?;
588
589            // Special case for where things can be more optimially loaded.
590            match self.kind {
591                ModelKind::Normal => normal_model_loader_sharded!(
592                    sharded_vb,
593                    config,
594                    self.inner,
595                    mapper,
596                    loading_isq,
597                    device.clone(),
598                    attention_mechanism,
599                    multi_progress.clone(),
600                    matformer_slicing_config.clone(),
601                ),
602                ModelKind::Adapter {
603                    adapter: AdapterKind::XLora,
604                } => xlora_model_loader!(
605                    paths,
606                    Some(dtype),
607                    &load_device,
608                    layer_devices.clone(),
609                    config,
610                    self.inner,
611                    silent,
612                    mapper,
613                    loading_isq,
614                    device.clone(),
615                    multi_progress.clone(),
616                    matformer_slicing_config.clone(),
617                ),
618                ModelKind::Adapter {
619                    adapter: AdapterKind::Lora,
620                } => lora_model_loader!(
621                    paths,
622                    Some(dtype),
623                    &load_device,
624                    layer_devices.clone(),
625                    config,
626                    self.inner,
627                    silent,
628                    mapper,
629                    loading_isq,
630                    self.config.from_uqff.is_some(),
631                    device.clone(),
632                    attention_mechanism,
633                    matches!(self.config.organization, IsqOrganization::MoeExpertsOnly),
634                    multi_progress.clone(),
635                    matformer_slicing_config.clone(),
636                ),
637                _ => unreachable!(),
638            }
639        } else {
640            match self.kind {
641                ModelKind::Normal => normal_model_loader!(
642                    paths,
643                    Some(dtype),
644                    &load_device,
645                    layer_devices.clone(),
646                    config,
647                    self.inner,
648                    silent,
649                    mapper,
650                    loading_isq,
651                    self.config.from_uqff.is_some(),
652                    device.clone(),
653                    attention_mechanism,
654                    matches!(self.config.organization, IsqOrganization::MoeExpertsOnly),
655                    multi_progress.clone(),
656                    matformer_slicing_config.clone(),
657                ),
658                ModelKind::Adapter {
659                    adapter: AdapterKind::XLora,
660                } => xlora_model_loader!(
661                    paths,
662                    Some(dtype),
663                    &load_device,
664                    layer_devices.clone(),
665                    config,
666                    self.inner,
667                    silent,
668                    mapper,
669                    loading_isq,
670                    device.clone(),
671                    multi_progress.clone(),
672                    matformer_slicing_config.clone(),
673                ),
674                ModelKind::Adapter {
675                    adapter: AdapterKind::Lora,
676                } => lora_model_loader!(
677                    paths,
678                    Some(dtype),
679                    &load_device,
680                    layer_devices.clone(),
681                    config,
682                    self.inner,
683                    silent,
684                    mapper,
685                    loading_isq,
686                    self.config.from_uqff.is_some(),
687                    device.clone(),
688                    attention_mechanism,
689                    matches!(self.config.organization, IsqOrganization::MoeExpertsOnly),
690                    multi_progress.clone(),
691                    matformer_slicing_config.clone(),
692                ),
693                _ => unreachable!(),
694            }
695        };
696
697        let tokenizer = get_tokenizer(paths.get_tokenizer_filename(), None)?;
698        let gen_conf: Option<GenerationConfig> = paths.get_gen_conf_filename().and_then(|f| {
699            match serde_json::from_str::<GenerationConfig>(&fs::read_to_string(f).unwrap()) {
700                Ok(conf) => Some(conf),
701                Err(e) => {
702                    warn!("Failed to parse generation_config.json: {}", e);
703                    None
704                }
705            }
706        });
707
708        let chat_template_explicit = paths
709            .get_chat_template_explicit()
710            .as_ref()
711            .map(|x| x.to_string_lossy().to_string());
712        let chat_template = get_chat_template(
713            paths,
714            self.jinja_explicit.as_ref(),
715            chat_template_explicit.as_ref(),
716            self.chat_template.as_ref(),
717            None,
718        );
719
720        if let Some(calibration_file) = &self.config.calibration_file {
721            let calibration_data = std::fs::read_to_string(calibration_file)?;
722            // Tokenize, don't add bos yet
723            let tokens = tokenizer
724                .encode_fast(calibration_data, false)
725                .map_err(anyhow::Error::msg)?
726                .get_ids()
727                .to_vec();
728            info!(
729                "Collecting imatrix from calibration file `{}` of {} tokens.",
730                calibration_file.display(),
731                tokens.len()
732            );
733            let bos_toks = chat_template.bos_tok().map(|b| vec![b]).unwrap_or_default();
734            let bos_tok_id = tokenizer
735                .token_to_id(&bos_toks[0])
736                .expect("Somehow the bos token is not present.");
737
738            match self.config.organization {
739                IsqOrganization::Default => model.begin_track_stats()?,
740                IsqOrganization::MoeExpertsOnly => model.begin_track_stats_moe_experts_only()?,
741            }
742
743            const CHUNK_SIZE: usize = 1024;
744            let n_chunks = tokens.len().div_ceil(CHUNK_SIZE);
745            let start = Instant::now();
746            for (i, chunk) in tokens.chunks(CHUNK_SIZE).enumerate() {
747                let chunk = [vec![bos_tok_id], chunk.to_vec()].concat();
748                let chunk_len = chunk.len();
749
750                let start = Instant::now();
751                let inputs = make_prompt_chunk(
752                    0,
753                    vec![&chunk],
754                    &[0],
755                    &load_device,
756                    None,
757                    false,
758                    None,
759                    Some(pipeline_mapper.as_ref()),
760                )?;
761
762                model.forward(
763                    &inputs.input.to_device(model.device())?,
764                    &inputs.positions,
765                    inputs.context_lens.clone(),
766                    inputs.position_ids.clone(),
767                    None,
768                    &inputs.flash_meta.clone(),
769                )?;
770
771                match model.cache_mut() {
772                    EitherCache::Full(full) => {
773                        for layer in &mut *full.lock() {
774                            *layer = None
775                        }
776                    }
777                    EitherCache::Normal(normal) => {
778                        for layer in &mut *normal.lock().unwrap().0 {
779                            layer.reset();
780                        }
781                    }
782                }
783
784                let end = Instant::now();
785                info!(
786                    "Processed chunk {}/{n_chunks} ({chunk_len} tokens), {:.2}s",
787                    i + 1,
788                    end.duration_since(start).as_secs_f32()
789                );
790            }
791            load_device.synchronize()?;
792            let end = Instant::now();
793            info!(
794                "Finished collecting imatrix in {:.2}s",
795                end.duration_since(start).as_secs_f32()
796            );
797        }
798
799        // Only if loading from UQFF
800        if (loading_isq || self.config.topology.is_some()) && self.config.from_uqff.is_none() {
801            let imatrix_source = match (
802                self.config.imatrix.as_ref(),
803                self.config.calibration_file.is_some(),
804            ) {
805                (None, false) => None,
806                (Some(file), false) => Some(ImatrixDataSource::File(file)),
807                (None, true) => Some(ImatrixDataSource::Collected),
808                (Some(_), true) => unreachable!(),
809            };
810
811            info!("Applying ISQ to all ranks.");
812
813            let multi_progress = Arc::new(MultiProgress::new());
814
815            model.quantize(
816                in_situ_quant,
817                model.device().clone(),
818                self.config.topology.as_ref(),
819                silent,
820                imatrix_source,
821                self.config.organization,
822                self.config.write_uqff.as_ref(),
823                UqffFullSer {
824                    tokenizer: &tokenizer,
825                    template_filename: paths.get_template_filename(),
826                    generation_config: paths.get_gen_conf_filename(),
827                    config: config.clone(),
828                    processor_filename: &None,
829                    preprocessor_filename: &None,
830                },
831                multi_progress.clone(),
832            )?;
833        } else if let Some(from_uqff) = &*self.from_uqff.read().unwrap() {
834            model.load_from_artifacts(
835                device.clone(),
836                self.config.topology.as_ref(),
837                silent,
838                from_uqff,
839            )?;
840        }
841
842        let paged_attn_config = if matches!(
843            self.kind,
844            ModelKind::Adapter {
845                adapter: AdapterKind::XLora
846            }
847        ) {
848            warn!(
849                "Adapter parallel_models do not currently support PagedAttention, running without"
850            );
851            None
852        } else {
853            paged_attn_config
854        };
855
856        let (cache_config, cache_engine) = if let Some(paged_attn_config) = paged_attn_config {
857            let cache_config = calculate_cache_config(
858                paged_attn_config.mem_gpu,
859                paged_attn_config.mem_cpu,
860                paged_attn_config.block_size,
861                dtype,
862                paged_attn_config.cache_type,
863                model.config(),
864                &device,
865                &pipeline_mapper
866                    .get_unique_devices()
867                    .into_iter()
868                    .map(Some)
869                    .collect::<Vec<_>>(),
870                silent,
871            )?;
872
873            let mut layer_devices = Vec::new();
874            for layer in 0..self.inner.num_layers(&config)? {
875                let device = model.get_layers().1.device_for(layer, false).cloned();
876                layer_devices.push(device);
877            }
878            let cache_engine = CacheEngine::new(
879                model.config(),
880                &cache_config,
881                dtype,
882                model.device(),
883                layer_devices.clone(),
884            )?;
885
886            (Some(cache_config), Some(cache_engine))
887        } else {
888            (None, None)
889        };
890
891        let max_seq_len = model.max_seq_len();
892        let llg_factory = build_llg_factory(tokenizer.clone())?;
893        let num_hidden_layers = match model.cache() {
894            EitherCache::Full(full) => full.lock().len(),
895            EitherCache::Normal(normal) => normal.lock().unwrap().0.len(),
896        };
897        let eos = calculate_eos_tokens(&chat_template, gen_conf, &tokenizer);
898        let sliding_window = model.config().sliding_window;
899        let model_metadata = Arc::new(model.config().clone());
900
901        Ok(Arc::new(Mutex::new(NormalPipeline {
902            model,
903            tokenizer: tokenizer.into(),
904            no_kv_cache: self.no_kv_cache,
905            chat_template: Arc::new(chat_template),
906            non_granular_state: self.tgt_non_granular_index.map(|tgt_non_granular_index| {
907                NonGranularState {
908                    non_granular_index: Arc::new(Mutex::new(0)),
909                    tgt_non_granular_index,
910                }
911            }),
912            model_id: self.model_id.clone(),
913            metadata: Arc::new(GeneralMetadata {
914                max_seq_len,
915                llg_factory: Some(llg_factory),
916                no_kv_cache: self.no_kv_cache,
917                no_prefix_cache: is_xlora,
918                num_hidden_layers,
919                eos_tok: eos,
920                kind: self.kind.clone(),
921                is_xlora,
922                activation_dtype: dtype,
923                sliding_window,
924                cache_config,
925                cache_engine,
926                model_metadata: Some(model_metadata),
927                modalities: Modalities {
928                    input: vec![SupportedModality::Text],
929                    output: vec![SupportedModality::Text],
930                },
931            }),
932            topology: self.config.topology.clone(),
933            silent,
934            organization: self.config.organization,
935            template_filename: paths.get_template_filename().clone(),
936            generation_config: paths.get_gen_conf_filename().cloned(),
937            config,
938            imatrix: self.config.imatrix.clone(),
939            mapper: pipeline_mapper,
940        })))
941    }
942
943    fn get_id(&self) -> String {
944        self.model_id.clone()
945    }
946
947    fn get_kind(&self) -> ModelKind {
948        self.kind.clone()
949    }
950}
951
952impl PreProcessingMixin for NormalPipeline {
953    fn get_chat_template(&self) -> Option<Arc<ChatTemplate>> {
954        Some(self.chat_template.clone())
955    }
956    fn get_input_processor_config(&self) -> Option<Arc<dyn Any>> {
957        None
958    }
959}
960
961impl IsqPipelineMixin for NormalPipeline {
962    fn re_isq_model(&mut self, dtype: IsqType) -> Result<()> {
963        let device = self.device().clone();
964        let multi_progress = Arc::new(MultiProgress::new());
965        self.model.quantize(
966            Some(dtype),
967            device.clone(),
968            self.topology.as_ref(),
969            self.silent,
970            self.imatrix.as_ref().map(ImatrixDataSource::File),
971            self.organization,
972            None,
973            UqffFullSer {
974                tokenizer: &self.tokenizer,
975                template_filename: &self.template_filename,
976                generation_config: self.generation_config.as_ref(),
977                config: self.config.clone(),
978                processor_filename: &None,
979                preprocessor_filename: &None,
980            },
981            multi_progress.clone(),
982        )?;
983        Ok(())
984    }
985}
986
987impl CacheManagerMixin for NormalPipeline {
988    fn clone_in_cache(&self, seqs: &mut [&mut Sequence]) {
989        if matches!(self.model.cache(), EitherCache::Full(_)) {
990            FullCacheManager.clone_in_cache(self, seqs, false)
991        } else {
992            NormalCacheManager.clone_in_cache(self, seqs, false)
993        }
994    }
995    fn clone_out_cache(&self, seqs: &mut [&mut Sequence]) {
996        if matches!(self.model.cache(), EitherCache::Full(_)) {
997            FullCacheManager.clone_out_cache(self, seqs, false)
998        } else {
999            NormalCacheManager.clone_out_cache(self, seqs, false)
1000        }
1001    }
1002    fn set_none_cache(
1003        &self,
1004        seqs: &mut [&mut Sequence],
1005        reset_non_granular: bool,
1006        modify_draft_cache: bool,
1007        load_preallocated_cache: bool,
1008    ) {
1009        if matches!(self.model.cache(), EitherCache::Full(_)) {
1010            FullCacheManager.set_none_cache(self, seqs, modify_draft_cache, false);
1011        } else {
1012            NormalCacheManager.set_none_cache(
1013                self,
1014                seqs,
1015                modify_draft_cache,
1016                load_preallocated_cache,
1017            );
1018        }
1019        if reset_non_granular {
1020            self.reset_non_granular_state()
1021        }
1022    }
1023    fn cache(&self) -> &EitherCache {
1024        self.model.cache()
1025    }
1026}
1027
1028impl MetadataMixin for NormalPipeline {
1029    fn device(&self) -> Device {
1030        self.model.device().clone()
1031    }
1032    fn tokenizer(&self) -> Option<Arc<Tokenizer>> {
1033        Some(self.tokenizer.clone())
1034    }
1035    fn name(&self) -> String {
1036        self.model_id.clone()
1037    }
1038    fn reset_non_granular_state(&self) {
1039        if let Some(s) = self.non_granular_state.as_ref() {
1040            *self.cache().full().get_scalings_cache() = None;
1041            *get_mut_arcmutex!(s.non_granular_index) = 0;
1042        }
1043    }
1044    fn get_metadata(&self) -> Arc<GeneralMetadata> {
1045        self.metadata.clone()
1046    }
1047    fn device_mapper(&self) -> Option<&dyn DeviceMapper> {
1048        Some(&*self.mapper)
1049    }
1050}
1051
1052#[async_trait::async_trait]
1053impl Pipeline for NormalPipeline {
1054    fn forward_inputs(
1055        &mut self,
1056        inputs: Box<dyn Any>,
1057        return_raw_logits: bool,
1058    ) -> Result<ForwardInputsResult, candle_core::Error> {
1059        let ModelInputs {
1060            input_ids,
1061            input_ids_full,
1062            seqlen_offsets,
1063            seqlen_offsets_full,
1064            context_lens,
1065            position_ids,
1066            paged_attn_meta,
1067            flash_meta,
1068            flash_meta_full,
1069        } = *inputs.downcast().expect("Downcast failed.");
1070        let metadata = self.get_metadata();
1071        let paged_attn_meta = match (&metadata.cache_engine, &paged_attn_meta) {
1072            (Some(cache_engine), Some(meta)) => Some((cache_engine, meta)),
1073            (Some(_), None) => {
1074                // This can happen if Rust-side user code is wrong
1075                candle_core::bail!("Forward step expected a PagedAttention input metadata. This was not provided, please ensure that the scheduler config is correctly configured for PagedAttention.")
1076            }
1077            (None, Some(_)) => {
1078                // This should never happen but we handle it anyway
1079                candle_core::bail!("Forward step got a PagedAttention input metadata but there is no cache engine. Please raise an issue.")
1080            }
1081            (None, None) => None,
1082        };
1083        let logits = match self.model.is_xlora() {
1084            false => {
1085                let paged_attn_meta = paged_attn_meta
1086                    .as_ref()
1087                    .map(|meta| (meta.0.get_kv_cache().clone(), meta.1.clone()));
1088
1089                self.model.forward(
1090                    &input_ids,
1091                    &seqlen_offsets,
1092                    context_lens,
1093                    position_ids,
1094                    paged_attn_meta.as_ref().map(|(a, b)| (a.clone(), b)),
1095                    &flash_meta,
1096                )?
1097            }
1098            true => self.model.xlora_forward(
1099                &input_ids,
1100                input_ids_full.as_ref().unwrap_or(&input_ids),
1101                &seqlen_offsets,
1102                seqlen_offsets_full.as_ref().unwrap_or(&seqlen_offsets),
1103                self.no_kv_cache,
1104                &self.non_granular_state,
1105                context_lens,
1106                position_ids,
1107                &flash_meta,
1108                flash_meta_full.as_ref().unwrap_or(&flash_meta),
1109            )?,
1110        };
1111        if return_raw_logits {
1112            Ok(ForwardInputsResult::RawLogits { logits })
1113        } else {
1114            Ok(ForwardInputsResult::CausalGeneration { logits })
1115        }
1116    }
1117    async fn sample_causal_gen(
1118        &self,
1119        seqs: &mut [&mut Sequence],
1120        logits: Vec<Tensor>,
1121        prefix_cacher: &mut PrefixCacheManagerV2,
1122        disable_eos_stop: bool,
1123        rng: Arc<std::sync::Mutex<Isaac64Rng>>,
1124    ) -> Result<(), candle_core::Error> {
1125        sample_and_add_toks(self, seqs, logits, prefix_cacher, disable_eos_stop, rng).await
1126    }
1127    fn category(&self) -> ModelCategory {
1128        ModelCategory::Text
1129    }
1130}
1131
1132impl AnyMoePipelineMixin for NormalPipeline {
1133    fn amoe_finish_training(&mut self, gate_model_id: Option<String>) -> candle_core::Result<()> {
1134        self.model.finish_training(gate_model_id)
1135    }
1136    fn amoe_layer_vars(&self) -> Vec<Vec<Var>> {
1137        self.model.get_vars()
1138    }
1139    fn amoe_base_model_trainable_params(&self) -> usize {
1140        self.model.trainable_params()
1141    }
1142    fn amoe_take_cached_gating_outputs(&mut self) -> Vec<Tensor> {
1143        self.model.take_cached_gating_outputs()
1144    }
1145    fn amoe_create_layers(
1146        &mut self,
1147        model_ids: Vec<String>,
1148        token: &TokenSource,
1149        revision: Option<String>,
1150        match_regex: &str,
1151        config: crate::amoe::AnyMoeConfig,
1152        dtype: candle_core::DType,
1153        dev: &Device,
1154        (prefix, mlp): (String, String),
1155        layers: Vec<usize>,
1156        expert_type: AnyMoeExpertType,
1157        silent: bool,
1158        gate_model_id: Option<String>,
1159    ) -> candle_core::Result<()> {
1160        let mut vbs = Vec::new();
1161        // Precompile regex here
1162        let regex = Regex::new(match_regex).map_err(candle_core::Error::msg)?;
1163        for model_id in model_ids {
1164            let model_id_str = &model_id;
1165            let model_id = Path::new(&model_id);
1166
1167            let api = {
1168                let cache = GLOBAL_HF_CACHE.get().cloned().unwrap_or_default();
1169                let mut api = ApiBuilder::from_cache(cache)
1170                    .with_progress(!silent)
1171                    .with_token(get_token(token).map_err(candle_core::Error::msg)?);
1172                if let Ok(x) = std::env::var("HF_HUB_CACHE") {
1173                    api = api.with_cache_dir(x.into());
1174                }
1175                api.build().map_err(candle_core::Error::msg)?
1176            };
1177            let revision = revision.clone().unwrap_or("main".to_string());
1178            let api = api.repo(Repo::with_revision(
1179                model_id_str.clone(),
1180                RepoType::Model,
1181                revision.clone(),
1182            ));
1183
1184            let mut filenames = vec![];
1185            for rfilename in
1186                api_dir_list!(api, model_id, true).filter(|x| x.ends_with(".safetensors"))
1187            {
1188                filenames.push(api_get_file!(api, &rfilename, model_id));
1189            }
1190
1191            let regex = regex.clone();
1192            let match_regex_clone = match_regex.to_string();
1193            let layers_clone = layers.clone();
1194            let vb = from_mmaped_safetensors(
1195                filenames,
1196                vec![],
1197                Some(dtype),
1198                dev,
1199                vec![None],
1200                silent,
1201                None,
1202                move |key| {
1203                    if regex.is_match(&key) {
1204                        // Idx of the last char of the layer id, +1
1205                        // Assumes N.MLP
1206                        let last_layer_idx = key.find(&match_regex_clone).unwrap() - 1;
1207                        let first_layer_idx = key[..last_layer_idx].rfind('.').unwrap();
1208                        let layer_n = key[first_layer_idx + 1..last_layer_idx]
1209                            .parse::<usize>()
1210                            .unwrap();
1211                        layers_clone.contains(&layer_n) || layers_clone.is_empty()
1212                    } else {
1213                        false
1214                    }
1215                },
1216                Arc::new(|_| DeviceForLoadTensor::Base),
1217            )?;
1218            vbs.push(vb);
1219        }
1220
1221        let gate_vb = if let Some(gate_model_id) = gate_model_id {
1222            let model_id_str = &gate_model_id;
1223            let model_id = Path::new(&gate_model_id);
1224
1225            let api = {
1226                let cache = GLOBAL_HF_CACHE.get().cloned().unwrap_or_default();
1227                let mut api = ApiBuilder::from_cache(cache)
1228                    .with_progress(!silent)
1229                    .with_token(get_token(token).map_err(candle_core::Error::msg)?);
1230                if let Ok(x) = std::env::var("HF_HUB_CACHE") {
1231                    api = api.with_cache_dir(x.into());
1232                }
1233                api.build().map_err(candle_core::Error::msg)?
1234            };
1235            let revision = revision.clone().unwrap_or("main".to_string());
1236            let api = api.repo(Repo::with_revision(
1237                model_id_str.clone(),
1238                RepoType::Model,
1239                revision.clone(),
1240            ));
1241
1242            let mut gate_filenames = vec![];
1243            for rfilename in
1244                api_dir_list!(api, model_id, true).filter(|x| x.ends_with(".safetensors"))
1245            {
1246                gate_filenames.push(api_get_file!(api, &rfilename, model_id));
1247            }
1248            assert_eq!(
1249                gate_filenames.len(),
1250                1,
1251                "Gate model ID must contain only one .safetensors file"
1252            );
1253
1254            let vb = from_mmaped_safetensors(
1255                gate_filenames.clone(),
1256                vec![],
1257                Some(dtype),
1258                dev,
1259                vec![None],
1260                silent,
1261                None,
1262                |_| true,
1263                Arc::new(|_| DeviceForLoadTensor::Base),
1264            )?;
1265            info!(
1266                "Loaded gating layers from `{}`",
1267                gate_filenames[0].display()
1268            );
1269            Some(vb)
1270        } else {
1271            None
1272        };
1273
1274        self.model.create_anymoe_layers(
1275            vbs.clone(),
1276            config.clone(),
1277            (prefix.clone(), mlp.clone()),
1278            layers.clone(),
1279            expert_type.clone(),
1280            gate_vb.clone(),
1281        )?;
1282
1283        Ok(())
1284    }
1285    fn amoe_supported(&self) -> bool {
1286        self.model.amoe_supported()
1287    }
1288}