mistralrs_core/pipeline/
normal.rs

1use super::inputs_processor::DEFAULT_PROMPT_CHUNK_SIZE;
2use super::isq::ImatrixDataSource;
3use super::llg::build_llg_factory;
4use super::{
5    get_model_paths, get_xlora_paths, text_models_inputs_processor::ModelInputs, AdapterKind,
6    CacheManager, GeneralMetadata, Loader, ModelKind, ModelPaths, NormalModel, NormalModelLoader,
7    TokenSource,
8};
9use super::{
10    AnyMoePipelineMixin, CacheManagerMixin, EitherCache, ForwardInputsResult, IsqOrganization,
11    IsqPipelineMixin, MetadataMixin, ModelCategory, PreProcessingMixin,
12};
13use super::{
14    AutoNormalLoader, DeepSeekV2Loader, DeepSeekV3Loader, GLM4Loader, Gemma2Loader, GemmaLoader,
15    LlamaLoader, MistralLoader, MixtralLoader, NormalLoaderType, Phi2Loader, Phi3Loader,
16    Phi3_5MoELoader, Qwen2Loader, Qwen3Loader, Qwen3MoELoader, SmolLm3Loader, Starcoder2Loader,
17};
18use crate::amoe::AnyMoeExpertType;
19use crate::device_map::{self, DeviceMapper};
20use crate::distributed::{self, WorkerTransferData};
21use crate::kv_cache::{FullCacheManager, NormalCacheManager};
22use crate::lora::Ordering;
23use crate::paged_attention::{calculate_cache_config, AttentionImplementation, CacheEngine};
24use crate::pipeline::chat_template::{calculate_eos_tokens, GenerationConfig};
25use crate::pipeline::isq::UqffFullSer;
26use crate::pipeline::loaders::auto_device_map;
27use crate::pipeline::loaders::QuantizationConfigShim;
28use crate::pipeline::sampling::sample_and_add_toks;
29use crate::pipeline::text_models_inputs_processor::make_prompt_chunk;
30use crate::pipeline::{get_chat_template, Modalities, SupportedModality};
31use crate::pipeline::{ChatTemplate, LocalModelPaths};
32use crate::prefix_cacher::PrefixCacheManagerV2;
33use crate::sequence::Sequence;
34use crate::utils::tokenizer::get_tokenizer;
35use crate::utils::varbuilder_utils::DeviceForLoadTensor;
36use crate::utils::{tokens::get_token, varbuilder_utils::from_mmaped_safetensors};
37use crate::xlora_models::NonGranularState;
38use crate::{
39    api_dir_list, api_get_file, get_mut_arcmutex, get_paths, get_uqff_paths, lora_model_loader,
40    normal_model_loader, normal_model_loader_sharded, xlora_model_loader, DeviceMapSetting,
41    PagedAttentionConfig, Pipeline, Topology, TryIntoDType, GLOBAL_HF_CACHE,
42};
43use anyhow::Result;
44use candle_core::{Device, Tensor, Var};
45use hf_hub::Cache;
46use hf_hub::{api::sync::ApiBuilder, Repo, RepoType};
47use indicatif::MultiProgress;
48use mistralrs_quant::log::once_log_info;
49use mistralrs_quant::{AfqLayer, GgufMatMul, HqqLayer, IsqType, QuantizedSerdeType};
50use rand_isaac::Isaac64Rng;
51use regex_automata::meta::Regex;
52use std::any::Any;
53use std::borrow::Cow;
54use std::num::{NonZero, NonZeroUsize};
55use std::path::{Path, PathBuf};
56use std::str::FromStr;
57use std::sync::{Arc, RwLock};
58use std::time::Instant;
59use std::{env, fs};
60use tokenizers::Tokenizer;
61use tokio::sync::Mutex;
62use tracing::{info, warn};
63
64pub struct NormalPipeline {
65    model: Box<dyn NormalModel + Send + Sync>,
66    tokenizer: Arc<Tokenizer>,
67    no_kv_cache: bool,
68    chat_template: Arc<ChatTemplate>,
69    non_granular_state: Option<NonGranularState>,
70    model_id: String,
71    metadata: Arc<GeneralMetadata>,
72    topology: Option<Topology>,
73    silent: bool,
74    organization: IsqOrganization,
75    // For full UQFF serialization
76    template_filename: Option<PathBuf>,
77    generation_config: Option<PathBuf>,
78    config: String,
79    imatrix: Option<PathBuf>,
80    mapper: Box<dyn DeviceMapper + Send + Sync>,
81}
82
83/// A loader for a "normal" (non-quantized) model.
84pub struct NormalLoader {
85    inner: Box<dyn NormalModelLoader>,
86    model_id: String,
87    config: NormalSpecificConfig,
88    xlora_model_id: Option<String>,
89    lora_adapter_ids: Option<Vec<String>>,
90    kind: ModelKind,
91    xlora_order: Option<Ordering>,
92    no_kv_cache: bool,
93    chat_template: Option<String>,
94    tokenizer_json: Option<String>,
95    tgt_non_granular_index: Option<usize>,
96    token_source: RwLock<Option<TokenSource>>,
97    revision: RwLock<Option<String>>,
98    from_uqff: RwLock<Option<Vec<PathBuf>>>,
99    jinja_explicit: Option<String>,
100    hf_cache_path: Option<PathBuf>,
101}
102
103#[derive(Default)]
104/// A builder for a loader for a "normal" (non-quantized) model.
105pub struct NormalLoaderBuilder {
106    model_id: Option<String>,
107    config: NormalSpecificConfig,
108    xlora_model_id: Option<String>,
109    lora_adapter_ids: Option<Vec<String>>,
110    kind: ModelKind,
111    xlora_order: Option<Ordering>,
112    no_kv_cache: bool,
113    chat_template: Option<String>,
114    tokenizer_json: Option<String>,
115    tgt_non_granular_index: Option<usize>,
116    jinja_explicit: Option<String>,
117    hf_cache_path: Option<PathBuf>,
118}
119
120#[derive(Clone, Default)]
121/// Config specific to loading a normal model.
122pub struct NormalSpecificConfig {
123    pub prompt_chunksize: Option<NonZeroUsize>,
124    pub topology: Option<Topology>,
125    pub organization: IsqOrganization,
126    pub write_uqff: Option<PathBuf>,
127    pub from_uqff: Option<Vec<PathBuf>>,
128    pub imatrix: Option<PathBuf>,
129    pub calibration_file: Option<PathBuf>,
130    pub hf_cache_path: Option<PathBuf>,
131    pub matformer_config_path: Option<PathBuf>,
132    pub matformer_slice_name: Option<String>,
133}
134
135impl NormalLoaderBuilder {
136    pub fn new(
137        config: NormalSpecificConfig,
138        chat_template: Option<String>,
139        tokenizer_json: Option<String>,
140        model_id: Option<String>,
141        no_kv_cache: bool,
142        jinja_explicit: Option<String>,
143    ) -> Self {
144        Self {
145            config,
146            chat_template,
147            tokenizer_json,
148            model_id,
149            kind: ModelKind::Normal,
150            jinja_explicit,
151            no_kv_cache,
152            ..Default::default()
153        }
154    }
155
156    fn with_adapter(
157        mut self,
158        xlora_model_id: String,
159        xlora_order: Ordering,
160        no_kv_cache: bool,
161        tgt_non_granular_index: Option<usize>,
162    ) -> Self {
163        self.xlora_model_id = Some(xlora_model_id);
164        self.xlora_order = Some(xlora_order);
165        self.no_kv_cache = no_kv_cache;
166        self.tgt_non_granular_index = tgt_non_granular_index;
167        self.model_id = if let Some(id) = self.model_id {
168            Some(id)
169        } else {
170            info!(
171                "Using adapter base model ID: `{}`",
172                self.xlora_order.as_ref().unwrap().base_model_id
173            );
174            Some(self.xlora_order.as_ref().unwrap().base_model_id.clone())
175        };
176        self
177    }
178
179    pub fn with_xlora(
180        mut self,
181        xlora_model_id: String,
182        xlora_order: Ordering,
183        no_kv_cache: bool,
184        tgt_non_granular_index: Option<usize>,
185    ) -> Self {
186        self.kind = ModelKind::Adapter {
187            adapter: AdapterKind::XLora,
188        };
189        self.with_adapter(
190            xlora_model_id,
191            xlora_order,
192            no_kv_cache,
193            tgt_non_granular_index,
194        )
195    }
196
197    pub fn with_lora(mut self, lora_adapter_ids: Vec<String>) -> Self {
198        self.kind = ModelKind::Adapter {
199            adapter: AdapterKind::Lora,
200        };
201        self.lora_adapter_ids = Some(lora_adapter_ids);
202        self
203    }
204
205    pub fn hf_cache_path(mut self, hf_cache_path: PathBuf) -> Self {
206        self.hf_cache_path = Some(hf_cache_path);
207        self
208    }
209
210    /// If the loader type is not specified, loader type is automatically determined from the
211    /// `architectures` array in the config.
212    pub fn build(self, loader_tp: Option<NormalLoaderType>) -> anyhow::Result<Box<dyn Loader>> {
213        let loader: Box<dyn NormalModelLoader> = match loader_tp {
214            Some(NormalLoaderType::Mistral) => Box::new(MistralLoader),
215            Some(NormalLoaderType::Gemma) => Box::new(GemmaLoader),
216            Some(NormalLoaderType::Llama) => Box::new(LlamaLoader),
217            Some(NormalLoaderType::Mixtral) => Box::new(MixtralLoader),
218            Some(NormalLoaderType::Phi2) => Box::new(Phi2Loader),
219            Some(NormalLoaderType::Phi3) => Box::new(Phi3Loader),
220            Some(NormalLoaderType::Qwen2) => Box::new(Qwen2Loader),
221            Some(NormalLoaderType::Gemma2) => Box::new(Gemma2Loader),
222            Some(NormalLoaderType::Starcoder2) => Box::new(Starcoder2Loader),
223            Some(NormalLoaderType::Phi3_5MoE) => Box::new(Phi3_5MoELoader),
224            Some(NormalLoaderType::DeepSeekV2) => Box::new(DeepSeekV2Loader),
225            Some(NormalLoaderType::DeepSeekV3) => Box::new(DeepSeekV3Loader),
226            Some(NormalLoaderType::Qwen3) => Box::new(Qwen3Loader),
227            Some(NormalLoaderType::GLM4) => Box::new(GLM4Loader),
228            Some(NormalLoaderType::Qwen3Moe) => Box::new(Qwen3MoELoader),
229            Some(NormalLoaderType::SmolLm3) => Box::new(SmolLm3Loader),
230            None => Box::new(AutoNormalLoader),
231        };
232        Ok(Box::new(NormalLoader {
233            inner: loader,
234            model_id: self.model_id.unwrap(),
235            config: self.config,
236            xlora_model_id: self.xlora_model_id,
237            lora_adapter_ids: self.lora_adapter_ids,
238            kind: self.kind,
239            xlora_order: self.xlora_order,
240            no_kv_cache: self.no_kv_cache,
241            chat_template: self.chat_template,
242            tokenizer_json: self.tokenizer_json,
243            tgt_non_granular_index: self.tgt_non_granular_index,
244            jinja_explicit: self.jinja_explicit,
245            token_source: RwLock::new(None),
246            revision: RwLock::new(None),
247            from_uqff: RwLock::new(None),
248            hf_cache_path: self.hf_cache_path,
249        }))
250    }
251}
252
253impl Loader for NormalLoader {
254    #[allow(clippy::type_complexity, clippy::too_many_arguments)]
255    fn load_model_from_hf(
256        &self,
257        revision: Option<String>,
258        token_source: TokenSource,
259        dtype: &dyn TryIntoDType,
260        device: &Device,
261        silent: bool,
262        mapper: DeviceMapSetting,
263        in_situ_quant: Option<IsqType>,
264        paged_attn_config: Option<PagedAttentionConfig>,
265    ) -> Result<Arc<Mutex<dyn Pipeline + Send + Sync>>> {
266        let cache = self
267            .hf_cache_path
268            .clone()
269            .map(Cache::new)
270            .unwrap_or_default();
271        GLOBAL_HF_CACHE.get_or_init(|| cache);
272
273        let paths: anyhow::Result<Box<dyn ModelPaths>> = get_paths!(
274            LocalModelPaths,
275            &token_source,
276            revision.clone(),
277            self,
278            None,
279            None,
280            silent,
281            self.config.from_uqff.is_some()
282        );
283        if let Some(from_uqff) = self.config.from_uqff.clone() {
284            *self.from_uqff.write().unwrap() = Some(get_uqff_paths!(&from_uqff, self, silent));
285        }
286        *self
287            .token_source
288            .write()
289            .expect("Failed to write to token source") = Some(token_source);
290        *self.revision.write().expect("Failed to write to revision") = revision;
291        self.load_model_from_path(
292            &paths?,
293            dtype,
294            device,
295            silent,
296            mapper,
297            in_situ_quant,
298            paged_attn_config,
299        )
300    }
301
302    #[allow(clippy::type_complexity, clippy::too_many_arguments)]
303    fn load_model_from_path(
304        &self,
305        paths: &Box<dyn ModelPaths>,
306        dtype: &dyn TryIntoDType,
307        device: &Device,
308        silent: bool,
309        mut mapper: DeviceMapSetting,
310        mut in_situ_quant: Option<IsqType>,
311        mut paged_attn_config: Option<PagedAttentionConfig>,
312    ) -> Result<Arc<Mutex<dyn Pipeline + Send + Sync>>> {
313        let config = std::fs::read_to_string(paths.get_config_filename())?;
314
315        if !self.inner.supports_paged_attention(&config)? {
316            paged_attn_config = None;
317        }
318
319        // Apply default prompt size here
320        let prompt_chunksize = self
321            .config
322            .prompt_chunksize
323            .unwrap_or(DEFAULT_PROMPT_CHUNK_SIZE.try_into().unwrap())
324            .get();
325
326        info!("Prompt chunk size is {prompt_chunksize}.",);
327
328        let use_nccl = mistralrs_quant::distributed::use_nccl();
329
330        let available_devices = if let Ok(payload) = env::var(distributed::IS_DAEMON_FLAG) {
331            let payload: WorkerTransferData = serde_json::from_str(&payload)?;
332            let WorkerTransferData::Init { id: _, worker_rank } = payload;
333            vec![candle_core::Device::new_cuda(worker_rank + 1)?]
334        } else if use_nccl {
335            vec![candle_core::Device::new_cuda(0)?]
336        } else {
337            device_map::get_all_similar_devices(device)?
338        };
339        let device = if use_nccl || cfg!(feature = "ring") {
340            available_devices[0].clone()
341        } else {
342            device.clone()
343        };
344
345        // If auto, convert to Map if not using nccl
346        if use_nccl || cfg!(feature = "ring") {
347            mapper = DeviceMapSetting::DummyNccl {
348                nm_device: available_devices[0].clone(),
349            };
350        } else if let DeviceMapSetting::Auto(params) = mapper.clone() {
351            // Initial dtype
352            let dtype = dtype.try_into_dtype(&available_devices.iter().collect::<Vec<_>>())?;
353
354            // Disable ISQ if we are loading a prequantized model.
355            if QuantizationConfigShim::get_quant_config_pack_factor(&config, dtype)? != 1 {
356                in_situ_quant = None;
357            }
358
359            // ISQ or UQFF: quantized path
360            // Match logic below where UQFF has priority
361            let (layer_sizes_in_bytes, non_mapped_size_in_bytes, total_model_size_in_bytes) =
362                if let Some(serialized) = &*self.from_uqff.read().unwrap() {
363                    let weight_pack_factor = {
364                        let ser_artifacts = unsafe {
365                            candle_core::safetensors::MmapedSafetensors::multi(serialized)?
366                        };
367                        let mut total_pack_factors = 0;
368                        let total_tensors = ser_artifacts.tensors().len();
369                        for (_, artifact) in ser_artifacts.tensors() {
370                            let artifact = artifact.data();
371                            // NOTE(EricLBuehler): isq type is ALWAYS byte 4 (5th) of the tensor.
372                            let isq_type = artifact[mistralrs_quant::UQFF_QUANT_TYPE_OFFSET];
373                            let pack_factor = match QuantizedSerdeType::try_from(isq_type as usize)?
374                            {
375                                QuantizedSerdeType::Hqq => {
376                                    HqqLayer::get_isq_type_from_uqff(Cow::Borrowed(artifact))?
377                                        .pack_factor(dtype)
378                                }
379                                QuantizedSerdeType::Gguf => {
380                                    GgufMatMul::get_isq_type_from_uqff(Cow::Borrowed(artifact))?
381                                        .pack_factor(dtype)
382                                }
383                                QuantizedSerdeType::Fp8 => IsqType::F8E4M3.pack_factor(dtype),
384                                QuantizedSerdeType::Unquant => 1,
385                                QuantizedSerdeType::Afq => {
386                                    AfqLayer::get_isq_type_from_uqff(Cow::Borrowed(artifact))?
387                                        .pack_factor(dtype)
388                                }
389                            };
390                            total_pack_factors += pack_factor;
391                        }
392
393                        total_pack_factors / total_tensors
394                    };
395
396                    let layer_sizes_in_bytes = self.inner.layer_sizes_in_bytes(
397                        &config,
398                        dtype,
399                        weight_pack_factor,
400                        None,
401                    )?;
402                    let non_mapped_size_in_bytes = self.inner.non_mapped_size_in_bytes(
403                        &config,
404                        dtype,
405                        weight_pack_factor,
406                        None,
407                    )?;
408                    let layer_sizes_sum = layer_sizes_in_bytes.iter().sum::<usize>();
409                    (
410                        layer_sizes_in_bytes,
411                        non_mapped_size_in_bytes,
412                        layer_sizes_sum + non_mapped_size_in_bytes,
413                    )
414                } else if let Some(isq) = in_situ_quant {
415                    let weight_pack_factor = isq.pack_factor(dtype);
416                    let layer_sizes_in_bytes = self.inner.layer_sizes_in_bytes(
417                        &config,
418                        dtype,
419                        weight_pack_factor,
420                        None,
421                    )?;
422                    let non_mapped_size_in_bytes = self.inner.non_mapped_size_in_bytes(
423                        &config,
424                        dtype,
425                        weight_pack_factor,
426                        None,
427                    )?;
428                    let layer_sizes_sum = layer_sizes_in_bytes.iter().sum::<usize>();
429                    (
430                        layer_sizes_in_bytes,
431                        non_mapped_size_in_bytes,
432                        layer_sizes_sum + non_mapped_size_in_bytes,
433                    )
434                } else {
435                    // Be sure to get the weight pack factor here; we might be loading a prequantized model.
436                    let weight_pack_factor =
437                        QuantizationConfigShim::get_quant_config_pack_factor(&config, dtype)?;
438                    let layer_sizes_in_bytes = self.inner.layer_sizes_in_bytes(
439                        &config,
440                        dtype,
441                        weight_pack_factor,
442                        None,
443                    )?;
444                    let non_mapped_size_in_bytes = self.inner.non_mapped_size_in_bytes(
445                        &config,
446                        dtype,
447                        weight_pack_factor,
448                        None,
449                    )?;
450                    let layer_sizes_sum = layer_sizes_in_bytes.iter().sum::<usize>();
451                    (
452                        layer_sizes_in_bytes,
453                        non_mapped_size_in_bytes,
454                        layer_sizes_sum + non_mapped_size_in_bytes,
455                    )
456                };
457
458            let new = auto_device_map::get_device_layers(
459                &*self.inner,
460                &config,
461                self.inner.num_layers(&config)?,
462                layer_sizes_in_bytes,
463                non_mapped_size_in_bytes,
464                total_model_size_in_bytes,
465                &available_devices,
466                dtype,
467                &params,
468                prompt_chunksize,
469                paged_attn_config.as_ref(),
470            )?;
471            mapper = DeviceMapSetting::Map(new);
472        }
473
474        let pipeline_mapper = mapper.into_mapper(
475            self.inner.num_layers(&config)?,
476            &device,
477            self.config.topology.as_ref(),
478        )?;
479        let mapper = mapper.into_mapper(
480            self.inner.num_layers(&config)?,
481            &device,
482            self.config.topology.as_ref(),
483        )?;
484        let mut layer_devices = Vec::new();
485        for layer in 0..self.inner.num_layers(&config)? {
486            let device = mapper.device_for(layer, false).cloned();
487            layer_devices.push(device);
488        }
489        let dtype = mapper.get_min_dtype(dtype)?;
490
491        // TODO: PagedAttention is not supported with CPU for now.
492        // This check is not really necessary because `get_device_layers` should prevent it.
493        let mapping_uses_cpu = mapper.get_unique_devices().iter().any(Device::is_cpu);
494        if mapping_uses_cpu && paged_attn_config.is_some() {
495            warn!("Device mapping contains a mix of GPU and CPU. There is no CPU support for PagedAttention, disabling PagedAttention.");
496            paged_attn_config = None;
497        }
498
499        info!("Model config: {:?}", self.inner.get_config_repr(&config)?);
500        if crate::using_flash_attn() {
501            once_log_info("FlashAttention is enabled.");
502        }
503
504        // Logic for ISQ here: if no calibration (i.e imatrix), then allow immediate ISQ. Otherwise, back to normal.
505        let mut loading_isq = if self.config.imatrix.is_none()
506            && self.config.calibration_file.is_none()
507            && !device.is_cuda()
508            && self.config.write_uqff.is_none()
509            && in_situ_quant.is_some()
510        {
511            let predicates = if matches!(self.config.organization, IsqOrganization::MoeExpertsOnly)
512            {
513                self.inner.immediate_isq_predicates_moqe(&config)?
514            } else {
515                self.inner.immediate_isq_predicates(&config)?
516            };
517            info!("Applying ISQ to {in_situ_quant:?}");
518            if predicates.is_empty() {
519                warn!("No predicates for this model and ISQ setting detected. ISQ will not be applied to any weights!");
520            }
521            mistralrs_quant::set_immediate_isq(in_situ_quant, predicates);
522            false
523        } else {
524            in_situ_quant.is_some()
525        };
526
527        if let Some(ref topology) = self.config.topology {
528            loading_isq |= topology
529                .0
530                .iter()
531                .any(|layer| layer.as_ref().is_some_and(|layer| layer.isq.is_some()));
532        }
533
534        if self.config.imatrix.is_some() && self.config.calibration_file.is_some() {
535            anyhow::bail!(
536                "`imatrix` and `calibration_file` were both specified, this is not allowed."
537            );
538        }
539
540        // Load onto the regular device if not using isq or if the calibration file is specified
541        let load_device = if !loading_isq || self.config.calibration_file.is_some() {
542            loading_isq = false;
543            device.clone()
544        } else {
545            Device::Cpu
546        };
547
548        let is_xlora = self.kind.is_adapted_and(|a| a.is_x_lora());
549
550        let attention_mechanism = if paged_attn_config.is_some() {
551            AttentionImplementation::PagedAttention
552        } else {
553            AttentionImplementation::Eager
554        };
555
556        let multi_progress = Arc::new(MultiProgress::new());
557
558        // Load matformer slicing config if provided
559        let matformer_slicing_config = if let Some(matformer_path) =
560            &self.config.matformer_config_path
561        {
562            use crate::matformer::{MatformerConfig, MatformerSliceConfig};
563            info!("Loading Matformer config from {:?}", matformer_path);
564            let config = Arc::new(MatformerConfig::from_file(matformer_path)?);
565
566            if let Some(slice_name) = &self.config.matformer_slice_name {
567                info!("Using Matformer slice: {}", slice_name);
568                Some(MatformerSliceConfig::new(slice_name.clone(), config))
569            } else {
570                // If no slice name is provided but config exists, we'll need to handle this
571                // For now, return None and let the model handle the default slice selection
572                warn!("Matformer config loaded but no slice name specified. Models will use their default slice.");
573                None
574            }
575        } else {
576            None
577        };
578
579        let mut model = if use_nccl || cfg!(feature = "ring") {
580            let (mapper, sharded_vb) = distributed::prepare_distributed_mapper(
581                dtype,
582                &device,
583                &available_devices,
584                silent,
585                &config,
586                loading_isq,
587                self.config.from_uqff.is_some(),
588                self.config.organization,
589                &*self.inner,
590                paths.as_ref(),
591            )?;
592
            // Special case for where things can be more optimally loaded.
594            match self.kind {
595                ModelKind::Normal => normal_model_loader_sharded!(
596                    sharded_vb,
597                    config,
598                    self.inner,
599                    mapper,
600                    loading_isq,
601                    device.clone(),
602                    attention_mechanism,
603                    multi_progress.clone(),
604                    matformer_slicing_config.clone(),
605                ),
606                ModelKind::Adapter {
607                    adapter: AdapterKind::XLora,
608                } => xlora_model_loader!(
609                    paths,
610                    Some(dtype),
611                    &load_device,
612                    layer_devices.clone(),
613                    config,
614                    self.inner,
615                    silent,
616                    mapper,
617                    loading_isq,
618                    device.clone(),
619                    multi_progress.clone(),
620                    matformer_slicing_config.clone(),
621                ),
622                ModelKind::Adapter {
623                    adapter: AdapterKind::Lora,
624                } => lora_model_loader!(
625                    paths,
626                    Some(dtype),
627                    &load_device,
628                    layer_devices.clone(),
629                    config,
630                    self.inner,
631                    silent,
632                    mapper,
633                    loading_isq,
634                    self.config.from_uqff.is_some(),
635                    device.clone(),
636                    attention_mechanism,
637                    matches!(self.config.organization, IsqOrganization::MoeExpertsOnly),
638                    multi_progress.clone(),
639                    matformer_slicing_config.clone(),
640                ),
641                _ => unreachable!(),
642            }
643        } else {
644            match self.kind {
645                ModelKind::Normal => normal_model_loader!(
646                    paths,
647                    Some(dtype),
648                    &load_device,
649                    layer_devices.clone(),
650                    config,
651                    self.inner,
652                    silent,
653                    mapper,
654                    loading_isq,
655                    self.config.from_uqff.is_some(),
656                    device.clone(),
657                    attention_mechanism,
658                    matches!(self.config.organization, IsqOrganization::MoeExpertsOnly),
659                    multi_progress.clone(),
660                    matformer_slicing_config.clone(),
661                ),
662                ModelKind::Adapter {
663                    adapter: AdapterKind::XLora,
664                } => xlora_model_loader!(
665                    paths,
666                    Some(dtype),
667                    &load_device,
668                    layer_devices.clone(),
669                    config,
670                    self.inner,
671                    silent,
672                    mapper,
673                    loading_isq,
674                    device.clone(),
675                    multi_progress.clone(),
676                    matformer_slicing_config.clone(),
677                ),
678                ModelKind::Adapter {
679                    adapter: AdapterKind::Lora,
680                } => lora_model_loader!(
681                    paths,
682                    Some(dtype),
683                    &load_device,
684                    layer_devices.clone(),
685                    config,
686                    self.inner,
687                    silent,
688                    mapper,
689                    loading_isq,
690                    self.config.from_uqff.is_some(),
691                    device.clone(),
692                    attention_mechanism,
693                    matches!(self.config.organization, IsqOrganization::MoeExpertsOnly),
694                    multi_progress.clone(),
695                    matformer_slicing_config.clone(),
696                ),
697                _ => unreachable!(),
698            }
699        };
700
701        let tokenizer = get_tokenizer(paths.get_tokenizer_filename(), None)?;
702        let gen_conf: Option<GenerationConfig> = paths.get_gen_conf_filename().and_then(|f| {
703            match serde_json::from_str::<GenerationConfig>(&fs::read_to_string(f).unwrap()) {
704                Ok(conf) => Some(conf),
705                Err(e) => {
706                    warn!("Failed to parse generation_config.json: {}", e);
707                    None
708                }
709            }
710        });
711
712        let chat_template_explicit = paths
713            .get_chat_template_explicit()
714            .as_ref()
715            .map(|x| x.to_string_lossy().to_string());
716        let chat_template = get_chat_template(
717            paths,
718            self.jinja_explicit.as_ref(),
719            chat_template_explicit.as_ref(),
720            self.chat_template.as_ref(),
721            None,
722        );
723
724        if let Some(calibration_file) = &self.config.calibration_file {
725            let calibration_data = std::fs::read_to_string(calibration_file)?;
726            // Tokenize, don't add bos yet
727            let tokens = tokenizer
728                .encode_fast(calibration_data, false)
729                .map_err(anyhow::Error::msg)?
730                .get_ids()
731                .to_vec();
732            info!(
733                "Collecting imatrix from calibration file `{}` of {} tokens.",
734                calibration_file.display(),
735                tokens.len()
736            );
737            let bos_toks = chat_template.bos_tok().map(|b| vec![b]).unwrap_or_default();
738            let bos_tok_id = tokenizer
739                .token_to_id(&bos_toks[0])
740                .expect("Somehow the bos token is not present.");
741
742            match self.config.organization {
743                IsqOrganization::Default => model.begin_track_stats()?,
744                IsqOrganization::MoeExpertsOnly => model.begin_track_stats_moe_experts_only()?,
745            }
746
747            const CHUNK_SIZE: usize = 1024;
748            let n_chunks = tokens.len().div_ceil(CHUNK_SIZE);
749            let start = Instant::now();
750            for (i, chunk) in tokens.chunks(CHUNK_SIZE).enumerate() {
751                let chunk = [vec![bos_tok_id], chunk.to_vec()].concat();
752                let chunk_len = chunk.len();
753
754                let start = Instant::now();
755                let inputs = make_prompt_chunk(
756                    0,
757                    vec![&chunk],
758                    &[0],
759                    &load_device,
760                    None,
761                    false,
762                    None,
763                    Some(pipeline_mapper.as_ref()),
764                )?;
765
766                model.forward(
767                    &inputs.input.to_device(model.device())?,
768                    &inputs.positions,
769                    inputs.context_lens.clone(),
770                    inputs.position_ids.clone(),
771                    None,
772                    &inputs.flash_meta.clone(),
773                )?;
774
775                match model.cache_mut() {
776                    EitherCache::Full(full) => {
777                        for layer in &mut *full.lock() {
778                            *layer = None
779                        }
780                    }
781                    EitherCache::Normal(normal) => {
782                        for layer in &mut *normal.lock().unwrap().0 {
783                            layer.reset();
784                        }
785                    }
786                }
787
788                let end = Instant::now();
789                info!(
790                    "Processed chunk {}/{n_chunks} ({chunk_len} tokens), {:.2}s",
791                    i + 1,
792                    end.duration_since(start).as_secs_f32()
793                );
794            }
795            load_device.synchronize()?;
796            let end = Instant::now();
797            info!(
798                "Finished collecting imatrix in {:.2}s",
799                end.duration_since(start).as_secs_f32()
800            );
801        }
802
803        // Only apply ISQ when NOT loading from UQFF (prequantized artifacts are loaded in the branch below)
804        if (loading_isq || self.config.topology.is_some()) && self.config.from_uqff.is_none() {
805            let imatrix_source = match (
806                self.config.imatrix.as_ref(),
807                self.config.calibration_file.is_some(),
808            ) {
809                (None, false) => None,
810                (Some(file), false) => Some(ImatrixDataSource::File(file)),
811                (None, true) => Some(ImatrixDataSource::Collected),
812                (Some(_), true) => unreachable!(),
813            };
814
815            info!("Applying ISQ to all ranks.");
816
817            let multi_progress = Arc::new(MultiProgress::new());
818
819            model.quantize(
820                in_situ_quant,
821                model.device().clone(),
822                self.config.topology.as_ref(),
823                silent,
824                imatrix_source,
825                self.config.organization,
826                self.config.write_uqff.as_ref(),
827                UqffFullSer {
828                    tokenizer: &tokenizer,
829                    template_filename: paths.get_template_filename(),
830                    generation_config: paths.get_gen_conf_filename(),
831                    config: config.clone(),
832                    processor_filename: &None,
833                    preprocessor_filename: &None,
834                },
835                multi_progress.clone(),
836            )?;
837        } else if let Some(from_uqff) = &*self.from_uqff.read().unwrap() {
838            model.load_from_artifacts(
839                device.clone(),
840                self.config.topology.as_ref(),
841                silent,
842                from_uqff,
843            )?;
844        }
845
846        let paged_attn_config = if matches!(
847            self.kind,
848            ModelKind::Adapter {
849                adapter: AdapterKind::XLora
850            }
851        ) {
852            warn!(
853                "Adapter parallel_models do not currently support PagedAttention, running without"
854            );
855            None
856        } else {
857            paged_attn_config
858        };
859
860        let (cache_config, cache_engine) = if let Some(paged_attn_config) = paged_attn_config {
861            let cache_config = calculate_cache_config(
862                paged_attn_config.mem_gpu,
863                paged_attn_config.mem_cpu,
864                paged_attn_config.block_size,
865                dtype,
866                paged_attn_config.cache_type,
867                model.config(),
868                &device,
869                &pipeline_mapper
870                    .get_unique_devices()
871                    .into_iter()
872                    .map(Some)
873                    .collect::<Vec<_>>(),
874                silent,
875            )?;
876
877            let mut layer_devices = Vec::new();
878            for layer in 0..self.inner.num_layers(&config)? {
879                let device = model.get_layers().1.device_for(layer, false).cloned();
880                layer_devices.push(device);
881            }
882            let cache_engine = CacheEngine::new(
883                model.config(),
884                &cache_config,
885                dtype,
886                model.device(),
887                layer_devices.clone(),
888            )?;
889
890            (Some(cache_config), Some(cache_engine))
891        } else {
892            (None, None)
893        };
894
895        let max_seq_len = model.max_seq_len();
896        let llg_factory = build_llg_factory(tokenizer.clone())?;
897        let num_hidden_layers = match model.cache() {
898            EitherCache::Full(full) => full.lock().len(),
899            EitherCache::Normal(normal) => normal.lock().unwrap().0.len(),
900        };
901        let eos = calculate_eos_tokens(&chat_template, gen_conf, &tokenizer);
902        let sliding_window = model.config().sliding_window;
903        let model_metadata = Arc::new(model.config().clone());
904
905        Ok(Arc::new(Mutex::new(NormalPipeline {
906            model,
907            tokenizer: tokenizer.into(),
908            no_kv_cache: self.no_kv_cache,
909            chat_template: Arc::new(chat_template),
910            non_granular_state: self.tgt_non_granular_index.map(|tgt_non_granular_index| {
911                NonGranularState {
912                    non_granular_index: Arc::new(Mutex::new(0)),
913                    tgt_non_granular_index,
914                }
915            }),
916            model_id: self.model_id.clone(),
917            metadata: Arc::new(GeneralMetadata {
918                max_seq_len,
919                llg_factory: Some(llg_factory),
920                no_kv_cache: self.no_kv_cache,
921                no_prefix_cache: is_xlora,
922                num_hidden_layers,
923                eos_tok: eos,
924                kind: self.kind.clone(),
925                is_xlora,
926                activation_dtype: dtype,
927                sliding_window,
928                cache_config,
929                cache_engine,
930                prompt_chunksize: Some(NonZero::new(prompt_chunksize).unwrap()),
931                model_metadata: Some(model_metadata),
932                modalities: Modalities {
933                    input: vec![SupportedModality::Text],
934                    output: vec![SupportedModality::Text],
935                },
936            }),
937            topology: self.config.topology.clone(),
938            silent,
939            organization: self.config.organization,
940            template_filename: paths.get_template_filename().clone(),
941            generation_config: paths.get_gen_conf_filename().cloned(),
942            config,
943            imatrix: self.config.imatrix.clone(),
944            mapper: pipeline_mapper,
945        })))
946    }
947
948    fn get_id(&self) -> String {
949        self.model_id.clone()
950    }
951
952    fn get_kind(&self) -> ModelKind {
953        self.kind.clone()
954    }
955}
956
957impl PreProcessingMixin for NormalPipeline {
958    fn get_chat_template(&self) -> Option<Arc<ChatTemplate>> {
959        Some(self.chat_template.clone())
960    }
961    fn get_input_processor_config(&self) -> Option<Arc<dyn Any>> {
962        None
963    }
964}
965
966impl IsqPipelineMixin for NormalPipeline {
967    fn re_isq_model(&mut self, dtype: IsqType) -> Result<()> {
968        let device = self.device().clone();
969        let multi_progress = Arc::new(MultiProgress::new());
970        self.model.quantize(
971            Some(dtype),
972            device.clone(),
973            self.topology.as_ref(),
974            self.silent,
975            self.imatrix.as_ref().map(ImatrixDataSource::File),
976            self.organization,
977            None,
978            UqffFullSer {
979                tokenizer: &self.tokenizer,
980                template_filename: &self.template_filename,
981                generation_config: self.generation_config.as_ref(),
982                config: self.config.clone(),
983                processor_filename: &None,
984                preprocessor_filename: &None,
985            },
986            multi_progress.clone(),
987        )?;
988        Ok(())
989    }
990}
991
992impl CacheManagerMixin for NormalPipeline {
993    fn clone_in_cache(&self, seqs: &mut [&mut Sequence]) {
994        if matches!(self.model.cache(), EitherCache::Full(_)) {
995            FullCacheManager.clone_in_cache(self, seqs, false)
996        } else {
997            NormalCacheManager.clone_in_cache(self, seqs, false)
998        }
999    }
1000    fn clone_out_cache(&self, seqs: &mut [&mut Sequence]) {
1001        if matches!(self.model.cache(), EitherCache::Full(_)) {
1002            FullCacheManager.clone_out_cache(self, seqs, false)
1003        } else {
1004            NormalCacheManager.clone_out_cache(self, seqs, false)
1005        }
1006    }
1007    fn set_none_cache(
1008        &self,
1009        seqs: &mut [&mut Sequence],
1010        reset_non_granular: bool,
1011        modify_draft_cache: bool,
1012        load_preallocated_cache: bool,
1013    ) {
1014        if matches!(self.model.cache(), EitherCache::Full(_)) {
1015            FullCacheManager.set_none_cache(self, seqs, modify_draft_cache, false);
1016        } else {
1017            NormalCacheManager.set_none_cache(
1018                self,
1019                seqs,
1020                modify_draft_cache,
1021                load_preallocated_cache,
1022            );
1023        }
1024        if reset_non_granular {
1025            self.reset_non_granular_state()
1026        }
1027    }
1028    fn cache(&self) -> &EitherCache {
1029        self.model.cache()
1030    }
1031}
1032
1033impl MetadataMixin for NormalPipeline {
1034    fn device(&self) -> Device {
1035        self.model.device().clone()
1036    }
1037    fn tokenizer(&self) -> Option<Arc<Tokenizer>> {
1038        Some(self.tokenizer.clone())
1039    }
1040    fn name(&self) -> String {
1041        self.model_id.clone()
1042    }
1043    fn reset_non_granular_state(&self) {
1044        if let Some(s) = self.non_granular_state.as_ref() {
1045            *self.cache().full().get_scalings_cache() = None;
1046            *get_mut_arcmutex!(s.non_granular_index) = 0;
1047        }
1048    }
1049    fn get_metadata(&self) -> Arc<GeneralMetadata> {
1050        self.metadata.clone()
1051    }
1052    fn device_mapper(&self) -> Option<&dyn DeviceMapper> {
1053        Some(&*self.mapper)
1054    }
1055}
1056
1057#[async_trait::async_trait]
1058impl Pipeline for NormalPipeline {
1059    fn forward_inputs(
1060        &mut self,
1061        inputs: Box<dyn Any>,
1062        return_raw_logits: bool,
1063    ) -> Result<ForwardInputsResult, candle_core::Error> {
1064        let ModelInputs {
1065            input_ids,
1066            input_ids_full,
1067            seqlen_offsets,
1068            seqlen_offsets_full,
1069            context_lens,
1070            position_ids,
1071            paged_attn_meta,
1072            flash_meta,
1073            flash_meta_full,
1074        } = *inputs.downcast().expect("Downcast failed.");
1075        let metadata = self.get_metadata();
1076        let paged_attn_meta = match (&metadata.cache_engine, &paged_attn_meta) {
1077            (Some(cache_engine), Some(meta)) => Some((cache_engine, meta)),
1078            (Some(_), None) => {
1079                // This can happen if Rust-side user code is wrong
1080                candle_core::bail!("Forward step expected a PagedAttention input metadata. This was not provided, please ensure that the scheduler config is correctly configured for PagedAttention.")
1081            }
1082            (None, Some(_)) => {
1083                // This should never happen but we handle it anyway
1084                candle_core::bail!("Forward step got a PagedAttention input metadata but there is no cache engine. Please raise an issue.")
1085            }
1086            (None, None) => None,
1087        };
1088        let logits = match self.model.is_xlora() {
1089            false => {
1090                let paged_attn_meta = paged_attn_meta
1091                    .as_ref()
1092                    .map(|meta| (meta.0.get_kv_cache().clone(), meta.1.clone()));
1093
1094                self.model.forward(
1095                    &input_ids,
1096                    &seqlen_offsets,
1097                    context_lens,
1098                    position_ids,
1099                    paged_attn_meta.as_ref().map(|(a, b)| (a.clone(), b)),
1100                    &flash_meta,
1101                )?
1102            }
1103            true => self.model.xlora_forward(
1104                &input_ids,
1105                input_ids_full.as_ref().unwrap_or(&input_ids),
1106                &seqlen_offsets,
1107                seqlen_offsets_full.as_ref().unwrap_or(&seqlen_offsets),
1108                self.no_kv_cache,
1109                &self.non_granular_state,
1110                context_lens,
1111                position_ids,
1112                &flash_meta,
1113                flash_meta_full.as_ref().unwrap_or(&flash_meta),
1114            )?,
1115        };
1116        if return_raw_logits {
1117            Ok(ForwardInputsResult::RawLogits { logits })
1118        } else {
1119            Ok(ForwardInputsResult::CausalGeneration { logits })
1120        }
1121    }
1122    async fn sample_causal_gen(
1123        &self,
1124        seqs: &mut [&mut Sequence],
1125        logits: Vec<Tensor>,
1126        prefix_cacher: &mut PrefixCacheManagerV2,
1127        disable_eos_stop: bool,
1128        rng: Arc<std::sync::Mutex<Isaac64Rng>>,
1129    ) -> Result<(), candle_core::Error> {
1130        sample_and_add_toks(self, seqs, logits, prefix_cacher, disable_eos_stop, rng).await
1131    }
1132    fn category(&self) -> ModelCategory {
1133        ModelCategory::Text
1134    }
1135}
1136
1137impl AnyMoePipelineMixin for NormalPipeline {
1138    fn amoe_finish_training(&mut self, gate_model_id: Option<String>) -> candle_core::Result<()> {
1139        self.model.finish_training(gate_model_id)
1140    }
1141    fn amoe_layer_vars(&self) -> Vec<Vec<Var>> {
1142        self.model.get_vars()
1143    }
1144    fn amoe_base_model_trainable_params(&self) -> usize {
1145        self.model.trainable_params()
1146    }
1147    fn amoe_take_cached_gating_outputs(&mut self) -> Vec<Tensor> {
1148        self.model.take_cached_gating_outputs()
1149    }
1150    fn amoe_create_layers(
1151        &mut self,
1152        model_ids: Vec<String>,
1153        token: &TokenSource,
1154        revision: Option<String>,
1155        match_regex: &str,
1156        config: crate::amoe::AnyMoeConfig,
1157        dtype: candle_core::DType,
1158        dev: &Device,
1159        (prefix, mlp): (String, String),
1160        layers: Vec<usize>,
1161        expert_type: AnyMoeExpertType,
1162        silent: bool,
1163        gate_model_id: Option<String>,
1164    ) -> candle_core::Result<()> {
1165        let mut vbs = Vec::new();
1166        // Precompile regex here
1167        let regex = Regex::new(match_regex).map_err(candle_core::Error::msg)?;
1168        for model_id in model_ids {
1169            let model_id_str = &model_id;
1170            let model_id = Path::new(&model_id);
1171
1172            let api = {
1173                let cache = GLOBAL_HF_CACHE.get().cloned().unwrap_or_default();
1174                let mut api = ApiBuilder::from_cache(cache)
1175                    .with_progress(!silent)
1176                    .with_token(get_token(token).map_err(candle_core::Error::msg)?);
1177                if let Ok(x) = std::env::var("HF_HUB_CACHE") {
1178                    api = api.with_cache_dir(x.into());
1179                }
1180                api.build().map_err(candle_core::Error::msg)?
1181            };
1182            let revision = revision.clone().unwrap_or("main".to_string());
1183            let api = api.repo(Repo::with_revision(
1184                model_id_str.clone(),
1185                RepoType::Model,
1186                revision.clone(),
1187            ));
1188
1189            let mut filenames = vec![];
1190            for rfilename in
1191                api_dir_list!(api, model_id, true).filter(|x| x.ends_with(".safetensors"))
1192            {
1193                filenames.push(api_get_file!(api, &rfilename, model_id));
1194            }
1195
1196            let regex = regex.clone();
1197            let match_regex_clone = match_regex.to_string();
1198            let layers_clone = layers.clone();
1199            let vb = from_mmaped_safetensors(
1200                filenames,
1201                vec![],
1202                Some(dtype),
1203                dev,
1204                vec![None],
1205                silent,
1206                None,
1207                move |key| {
1208                    if regex.is_match(&key) {
1209                        // Idx of the last char of the layer id, +1
1210                        // Assumes N.MLP
1211                        let last_layer_idx = key.find(&match_regex_clone).unwrap() - 1;
1212                        let first_layer_idx = key[..last_layer_idx].rfind('.').unwrap();
1213                        let layer_n = key[first_layer_idx + 1..last_layer_idx]
1214                            .parse::<usize>()
1215                            .unwrap();
1216                        layers_clone.contains(&layer_n) || layers_clone.is_empty()
1217                    } else {
1218                        false
1219                    }
1220                },
1221                Arc::new(|_| DeviceForLoadTensor::Base),
1222            )?;
1223            vbs.push(vb);
1224        }
1225
1226        let gate_vb = if let Some(gate_model_id) = gate_model_id {
1227            let model_id_str = &gate_model_id;
1228            let model_id = Path::new(&gate_model_id);
1229
1230            let api = {
1231                let cache = GLOBAL_HF_CACHE.get().cloned().unwrap_or_default();
1232                let mut api = ApiBuilder::from_cache(cache)
1233                    .with_progress(!silent)
1234                    .with_token(get_token(token).map_err(candle_core::Error::msg)?);
1235                if let Ok(x) = std::env::var("HF_HUB_CACHE") {
1236                    api = api.with_cache_dir(x.into());
1237                }
1238                api.build().map_err(candle_core::Error::msg)?
1239            };
1240            let revision = revision.clone().unwrap_or("main".to_string());
1241            let api = api.repo(Repo::with_revision(
1242                model_id_str.clone(),
1243                RepoType::Model,
1244                revision.clone(),
1245            ));
1246
1247            let mut gate_filenames = vec![];
1248            for rfilename in
1249                api_dir_list!(api, model_id, true).filter(|x| x.ends_with(".safetensors"))
1250            {
1251                gate_filenames.push(api_get_file!(api, &rfilename, model_id));
1252            }
1253            assert_eq!(
1254                gate_filenames.len(),
1255                1,
1256                "Gate model ID must contain only one .safetensors file"
1257            );
1258
1259            let vb = from_mmaped_safetensors(
1260                gate_filenames.clone(),
1261                vec![],
1262                Some(dtype),
1263                dev,
1264                vec![None],
1265                silent,
1266                None,
1267                |_| true,
1268                Arc::new(|_| DeviceForLoadTensor::Base),
1269            )?;
1270            info!(
1271                "Loaded gating layers from `{}`",
1272                gate_filenames[0].display()
1273            );
1274            Some(vb)
1275        } else {
1276            None
1277        };
1278
1279        self.model.create_anymoe_layers(
1280            vbs.clone(),
1281            config.clone(),
1282            (prefix.clone(), mlp.clone()),
1283            layers.clone(),
1284            expert_type.clone(),
1285            gate_vb.clone(),
1286        )?;
1287
1288        Ok(())
1289    }
1290    fn amoe_supported(&self) -> bool {
1291        self.model.amoe_supported()
1292    }
1293}