mistralrs_core/pipeline/
normal.rs

1use super::inputs_processor::DEFAULT_PROMPT_CHUNK_SIZE;
2use super::isq::ImatrixDataSource;
3use super::llg::build_llg_factory;
4use super::{
5    get_model_paths, get_xlora_paths, text_models_inputs_processor::ModelInputs, AdapterKind,
6    CacheManager, GeneralMetadata, Loader, ModelKind, ModelPaths, NormalModel, NormalModelLoader,
7    TokenSource,
8};
9use super::{
10    AnyMoePipelineMixin, CacheManagerMixin, EitherCache, ForwardInputsResult, IsqOrganization,
11    IsqPipelineMixin, MetadataMixin, ModelCategory, PreProcessingMixin,
12};
13use super::{
14    AutoNormalLoader, DeepSeekV2Loader, DeepSeekV3Loader, GLM4Loader, Gemma2Loader, GemmaLoader,
15    LlamaLoader, MistralLoader, MixtralLoader, NormalLoaderType, Phi2Loader, Phi3Loader,
16    Phi3_5MoELoader, Qwen2Loader, Qwen3Loader, Qwen3MoELoader, SmolLm3Loader, Starcoder2Loader,
17};
18use crate::amoe::AnyMoeExpertType;
19use crate::device_map::{self, DeviceMapper};
20use crate::distributed::{self, WorkerTransferData};
21use crate::kv_cache::{FullCacheManager, NormalCacheManager};
22use crate::lora::Ordering;
23use crate::paged_attention::{calculate_cache_config, AttentionImplementation, CacheEngine};
24use crate::pipeline::chat_template::{calculate_eos_tokens, GenerationConfig};
25use crate::pipeline::isq::UqffFullSer;
26use crate::pipeline::loaders::auto_device_map;
27use crate::pipeline::loaders::QuantizationConfigShim;
28use crate::pipeline::sampling::sample_and_add_toks;
29use crate::pipeline::text_models_inputs_processor::make_prompt_chunk;
30use crate::pipeline::{get_chat_template, Modalities, SupportedModality};
31use crate::pipeline::{ChatTemplate, LocalModelPaths};
32use crate::prefix_cacher::PrefixCacheManagerV2;
33use crate::sequence::Sequence;
34use crate::utils::tokenizer::get_tokenizer;
35use crate::utils::varbuilder_utils::DeviceForLoadTensor;
36use crate::utils::{tokens::get_token, varbuilder_utils::from_mmaped_safetensors};
37use crate::xlora_models::NonGranularState;
38use crate::{
39    api_dir_list, api_get_file, get_mut_arcmutex, get_paths, get_uqff_paths, lora_model_loader,
40    normal_model_loader, normal_model_loader_sharded, xlora_model_loader, DeviceMapSetting,
41    PagedAttentionConfig, Pipeline, Topology, TryIntoDType, GLOBAL_HF_CACHE,
42};
43use anyhow::Result;
44use candle_core::{Device, Tensor, Var};
45use hf_hub::Cache;
46use hf_hub::{api::sync::ApiBuilder, Repo, RepoType};
47use indicatif::MultiProgress;
48use mistralrs_quant::log::once_log_info;
49use mistralrs_quant::{AfqLayer, GgufMatMul, HqqLayer, IsqType, QuantizedSerdeType};
50use rand_isaac::Isaac64Rng;
51use regex_automata::meta::Regex;
52use std::any::Any;
53use std::borrow::Cow;
54use std::num::{NonZero, NonZeroUsize};
55use std::path::{Path, PathBuf};
56use std::str::FromStr;
57use std::sync::{Arc, RwLock};
58use std::time::Instant;
59use std::{env, fs};
60use tokenizers::Tokenizer;
61use tokio::sync::Mutex;
62use tracing::{info, warn};
63
64pub struct NormalPipeline {
65    model: Box<dyn NormalModel + Send + Sync>,
66    tokenizer: Arc<Tokenizer>,
67    no_kv_cache: bool,
68    chat_template: Arc<ChatTemplate>,
69    non_granular_state: Option<NonGranularState>,
70    model_id: String,
71    metadata: Arc<GeneralMetadata>,
72    topology: Option<Topology>,
73    silent: bool,
74    organization: IsqOrganization,
75    // For full UQFF serialization
76    template_filename: Option<PathBuf>,
77    generation_config: Option<PathBuf>,
78    config: String,
79    imatrix: Option<PathBuf>,
80    mapper: Box<dyn DeviceMapper + Send + Sync>,
81}
82
83/// A loader for a "normal" (non-quantized) model.
84pub struct NormalLoader {
85    inner: Box<dyn NormalModelLoader>,
86    model_id: String,
87    config: NormalSpecificConfig,
88    xlora_model_id: Option<String>,
89    lora_adapter_ids: Option<Vec<String>>,
90    kind: ModelKind,
91    xlora_order: Option<Ordering>,
92    no_kv_cache: bool,
93    chat_template: Option<String>,
94    tokenizer_json: Option<String>,
95    tgt_non_granular_index: Option<usize>,
96    token_source: RwLock<Option<TokenSource>>,
97    revision: RwLock<Option<String>>,
98    from_uqff: RwLock<Option<Vec<PathBuf>>>,
99    jinja_explicit: Option<String>,
100    hf_cache_path: Option<PathBuf>,
101}
102
103#[derive(Default)]
104/// A builder for a loader for a "normal" (non-quantized) model.
105pub struct NormalLoaderBuilder {
106    model_id: Option<String>,
107    config: NormalSpecificConfig,
108    xlora_model_id: Option<String>,
109    lora_adapter_ids: Option<Vec<String>>,
110    kind: ModelKind,
111    xlora_order: Option<Ordering>,
112    no_kv_cache: bool,
113    chat_template: Option<String>,
114    tokenizer_json: Option<String>,
115    tgt_non_granular_index: Option<usize>,
116    jinja_explicit: Option<String>,
117    hf_cache_path: Option<PathBuf>,
118}
119
120#[derive(Clone, Default)]
121/// Config specific to loading a normal model.
122pub struct NormalSpecificConfig {
123    pub prompt_chunksize: Option<NonZeroUsize>,
124    pub topology: Option<Topology>,
125    pub organization: IsqOrganization,
126    pub write_uqff: Option<PathBuf>,
127    pub from_uqff: Option<Vec<PathBuf>>,
128    pub imatrix: Option<PathBuf>,
129    pub calibration_file: Option<PathBuf>,
130    pub hf_cache_path: Option<PathBuf>,
131}
132
133impl NormalLoaderBuilder {
134    pub fn new(
135        config: NormalSpecificConfig,
136        chat_template: Option<String>,
137        tokenizer_json: Option<String>,
138        model_id: Option<String>,
139        no_kv_cache: bool,
140        jinja_explicit: Option<String>,
141    ) -> Self {
142        Self {
143            config,
144            chat_template,
145            tokenizer_json,
146            model_id,
147            kind: ModelKind::Normal,
148            jinja_explicit,
149            no_kv_cache,
150            ..Default::default()
151        }
152    }
153
154    fn with_adapter(
155        mut self,
156        xlora_model_id: String,
157        xlora_order: Ordering,
158        no_kv_cache: bool,
159        tgt_non_granular_index: Option<usize>,
160    ) -> Self {
161        self.xlora_model_id = Some(xlora_model_id);
162        self.xlora_order = Some(xlora_order);
163        self.no_kv_cache = no_kv_cache;
164        self.tgt_non_granular_index = tgt_non_granular_index;
165        self.model_id = if let Some(id) = self.model_id {
166            Some(id)
167        } else {
168            info!(
169                "Using adapter base model ID: `{}`",
170                self.xlora_order.as_ref().unwrap().base_model_id
171            );
172            Some(self.xlora_order.as_ref().unwrap().base_model_id.clone())
173        };
174        self
175    }
176
177    pub fn with_xlora(
178        mut self,
179        xlora_model_id: String,
180        xlora_order: Ordering,
181        no_kv_cache: bool,
182        tgt_non_granular_index: Option<usize>,
183    ) -> Self {
184        self.kind = ModelKind::Adapter {
185            adapter: AdapterKind::XLora,
186        };
187        self.with_adapter(
188            xlora_model_id,
189            xlora_order,
190            no_kv_cache,
191            tgt_non_granular_index,
192        )
193    }
194
195    pub fn with_lora(mut self, lora_adapter_ids: Vec<String>) -> Self {
196        self.kind = ModelKind::Adapter {
197            adapter: AdapterKind::Lora,
198        };
199        self.lora_adapter_ids = Some(lora_adapter_ids);
200        self
201    }
202
203    pub fn hf_cache_path(mut self, hf_cache_path: PathBuf) -> Self {
204        self.hf_cache_path = Some(hf_cache_path);
205        self
206    }
207
208    /// If the loader type is not specified, loader type is automatically determined from the
209    /// `architectures` array in the config.
210    pub fn build(self, loader_tp: Option<NormalLoaderType>) -> anyhow::Result<Box<dyn Loader>> {
211        let loader: Box<dyn NormalModelLoader> = match loader_tp {
212            Some(NormalLoaderType::Mistral) => Box::new(MistralLoader),
213            Some(NormalLoaderType::Gemma) => Box::new(GemmaLoader),
214            Some(NormalLoaderType::Llama) => Box::new(LlamaLoader),
215            Some(NormalLoaderType::Mixtral) => Box::new(MixtralLoader),
216            Some(NormalLoaderType::Phi2) => Box::new(Phi2Loader),
217            Some(NormalLoaderType::Phi3) => Box::new(Phi3Loader),
218            Some(NormalLoaderType::Qwen2) => Box::new(Qwen2Loader),
219            Some(NormalLoaderType::Gemma2) => Box::new(Gemma2Loader),
220            Some(NormalLoaderType::Starcoder2) => Box::new(Starcoder2Loader),
221            Some(NormalLoaderType::Phi3_5MoE) => Box::new(Phi3_5MoELoader),
222            Some(NormalLoaderType::DeepSeekV2) => Box::new(DeepSeekV2Loader),
223            Some(NormalLoaderType::DeepSeekV3) => Box::new(DeepSeekV3Loader),
224            Some(NormalLoaderType::Qwen3) => Box::new(Qwen3Loader),
225            Some(NormalLoaderType::GLM4) => Box::new(GLM4Loader),
226            Some(NormalLoaderType::Qwen3Moe) => Box::new(Qwen3MoELoader),
227            Some(NormalLoaderType::SmolLm3) => Box::new(SmolLm3Loader),
228            None => Box::new(AutoNormalLoader),
229        };
230        Ok(Box::new(NormalLoader {
231            inner: loader,
232            model_id: self.model_id.unwrap(),
233            config: self.config,
234            xlora_model_id: self.xlora_model_id,
235            lora_adapter_ids: self.lora_adapter_ids,
236            kind: self.kind,
237            xlora_order: self.xlora_order,
238            no_kv_cache: self.no_kv_cache,
239            chat_template: self.chat_template,
240            tokenizer_json: self.tokenizer_json,
241            tgt_non_granular_index: self.tgt_non_granular_index,
242            jinja_explicit: self.jinja_explicit,
243            token_source: RwLock::new(None),
244            revision: RwLock::new(None),
245            from_uqff: RwLock::new(None),
246            hf_cache_path: self.hf_cache_path,
247        }))
248    }
249}
250
251impl Loader for NormalLoader {
252    #[allow(clippy::type_complexity, clippy::too_many_arguments)]
253    fn load_model_from_hf(
254        &self,
255        revision: Option<String>,
256        token_source: TokenSource,
257        dtype: &dyn TryIntoDType,
258        device: &Device,
259        silent: bool,
260        mapper: DeviceMapSetting,
261        in_situ_quant: Option<IsqType>,
262        paged_attn_config: Option<PagedAttentionConfig>,
263    ) -> Result<Arc<Mutex<dyn Pipeline + Send + Sync>>> {
264        let cache = self
265            .hf_cache_path
266            .clone()
267            .map(Cache::new)
268            .unwrap_or_default();
269        GLOBAL_HF_CACHE.get_or_init(|| cache);
270
271        let paths: anyhow::Result<Box<dyn ModelPaths>> = get_paths!(
272            LocalModelPaths,
273            &token_source,
274            revision.clone(),
275            self,
276            None,
277            None,
278            silent,
279            self.config.from_uqff.is_some()
280        );
281        if let Some(from_uqff) = self.config.from_uqff.clone() {
282            *self.from_uqff.write().unwrap() = Some(get_uqff_paths!(&from_uqff, self, silent));
283        }
284        *self
285            .token_source
286            .write()
287            .expect("Failed to write to token source") = Some(token_source);
288        *self.revision.write().expect("Failed to write to revision") = revision;
289        self.load_model_from_path(
290            &paths?,
291            dtype,
292            device,
293            silent,
294            mapper,
295            in_situ_quant,
296            paged_attn_config,
297        )
298    }
299
300    #[allow(clippy::type_complexity, clippy::too_many_arguments)]
301    fn load_model_from_path(
302        &self,
303        paths: &Box<dyn ModelPaths>,
304        dtype: &dyn TryIntoDType,
305        device: &Device,
306        silent: bool,
307        mut mapper: DeviceMapSetting,
308        mut in_situ_quant: Option<IsqType>,
309        mut paged_attn_config: Option<PagedAttentionConfig>,
310    ) -> Result<Arc<Mutex<dyn Pipeline + Send + Sync>>> {
311        let config = std::fs::read_to_string(paths.get_config_filename())?;
312
313        if !self.inner.supports_paged_attention(&config)? {
314            paged_attn_config = None;
315        }
316
317        // Apply default prompt size here
318        let prompt_chunksize = self
319            .config
320            .prompt_chunksize
321            .unwrap_or(DEFAULT_PROMPT_CHUNK_SIZE.try_into().unwrap())
322            .get();
323
324        info!("Prompt chunk size is {prompt_chunksize}.",);
325
326        let use_nccl = mistralrs_quant::distributed::use_nccl();
327
328        let available_devices = if let Ok(payload) = env::var(distributed::IS_DAEMON_FLAG) {
329            let payload: WorkerTransferData = serde_json::from_str(&payload)?;
330            let WorkerTransferData::Init { id: _, worker_rank } = payload;
331            vec![candle_core::Device::new_cuda(worker_rank + 1)?]
332        } else if use_nccl {
333            vec![candle_core::Device::new_cuda(0)?]
334        } else {
335            device_map::get_all_similar_devices(device)?
336        };
337        let device = if use_nccl || cfg!(feature = "ring") {
338            available_devices[0].clone()
339        } else {
340            device.clone()
341        };
342
343        // If auto, convert to Map if not using nccl
344        if use_nccl || cfg!(feature = "ring") {
345            mapper = DeviceMapSetting::DummyNccl {
346                nm_device: available_devices[0].clone(),
347            };
348        } else if let DeviceMapSetting::Auto(params) = mapper.clone() {
349            // Initial dtype
350            let dtype = dtype.try_into_dtype(&available_devices.iter().collect::<Vec<_>>())?;
351
352            // Disable ISQ if we are loading a prequantized model.
353            if QuantizationConfigShim::get_quant_config_pack_factor(&config, dtype)? != 1 {
354                in_situ_quant = None;
355            }
356
357            // ISQ or UQFF: quantized path
358            // Match logic below where UQFF has priority
359            let (layer_sizes_in_bytes, non_mapped_size_in_bytes, total_model_size_in_bytes) =
360                if let Some(serialized) = &*self.from_uqff.read().unwrap() {
361                    let weight_pack_factor = {
362                        let ser_artifacts = unsafe {
363                            candle_core::safetensors::MmapedSafetensors::multi(serialized)?
364                        };
365                        let mut total_pack_factors = 0;
366                        let total_tensors = ser_artifacts.tensors().len();
367                        for (_, artifact) in ser_artifacts.tensors() {
368                            let artifact = artifact.data();
369                            // NOTE(EricLBuehler): isq type is ALWAYS byte 4 (5th) of the tensor.
370                            let isq_type = artifact[mistralrs_quant::UQFF_QUANT_TYPE_OFFSET];
371                            let pack_factor = match QuantizedSerdeType::try_from(isq_type as usize)?
372                            {
373                                QuantizedSerdeType::Hqq => {
374                                    HqqLayer::get_isq_type_from_uqff(Cow::Borrowed(artifact))?
375                                        .pack_factor(dtype)
376                                }
377                                QuantizedSerdeType::Gguf => {
378                                    GgufMatMul::get_isq_type_from_uqff(Cow::Borrowed(artifact))?
379                                        .pack_factor(dtype)
380                                }
381                                QuantizedSerdeType::Fp8 => IsqType::F8E4M3.pack_factor(dtype),
382                                QuantizedSerdeType::Unquant => 1,
383                                QuantizedSerdeType::Afq => {
384                                    AfqLayer::get_isq_type_from_uqff(Cow::Borrowed(artifact))?
385                                        .pack_factor(dtype)
386                                }
387                            };
388                            total_pack_factors += pack_factor;
389                        }
390
391                        total_pack_factors / total_tensors
392                    };
393
394                    let layer_sizes_in_bytes =
395                        self.inner
396                            .layer_sizes_in_bytes(&config, dtype, weight_pack_factor)?;
397                    let non_mapped_size_in_bytes =
398                        self.inner
399                            .non_mapped_size_in_bytes(&config, dtype, weight_pack_factor)?;
400                    let layer_sizes_sum = layer_sizes_in_bytes.iter().sum::<usize>();
401                    (
402                        layer_sizes_in_bytes,
403                        non_mapped_size_in_bytes,
404                        layer_sizes_sum + non_mapped_size_in_bytes,
405                    )
406                } else if let Some(isq) = in_situ_quant {
407                    let weight_pack_factor = isq.pack_factor(dtype);
408                    let layer_sizes_in_bytes =
409                        self.inner
410                            .layer_sizes_in_bytes(&config, dtype, weight_pack_factor)?;
411                    let non_mapped_size_in_bytes =
412                        self.inner
413                            .non_mapped_size_in_bytes(&config, dtype, weight_pack_factor)?;
414                    let layer_sizes_sum = layer_sizes_in_bytes.iter().sum::<usize>();
415                    (
416                        layer_sizes_in_bytes,
417                        non_mapped_size_in_bytes,
418                        layer_sizes_sum + non_mapped_size_in_bytes,
419                    )
420                } else {
421                    // Be sure to get the weight pack factor here; we might be loading a prequantized model.
422                    let weight_pack_factor =
423                        QuantizationConfigShim::get_quant_config_pack_factor(&config, dtype)?;
424                    let layer_sizes_in_bytes =
425                        self.inner
426                            .layer_sizes_in_bytes(&config, dtype, weight_pack_factor)?;
427                    let non_mapped_size_in_bytes =
428                        self.inner
429                            .non_mapped_size_in_bytes(&config, dtype, weight_pack_factor)?;
430                    let layer_sizes_sum = layer_sizes_in_bytes.iter().sum::<usize>();
431                    (
432                        layer_sizes_in_bytes,
433                        non_mapped_size_in_bytes,
434                        layer_sizes_sum + non_mapped_size_in_bytes,
435                    )
436                };
437
438            let new = auto_device_map::get_device_layers(
439                &*self.inner,
440                &config,
441                self.inner.num_layers(&config)?,
442                layer_sizes_in_bytes,
443                non_mapped_size_in_bytes,
444                total_model_size_in_bytes,
445                &available_devices,
446                dtype,
447                &params,
448                prompt_chunksize,
449                paged_attn_config.as_ref(),
450            )?;
451            mapper = DeviceMapSetting::Map(new);
452        }
453
454        let pipeline_mapper = mapper.into_mapper(
455            self.inner.num_layers(&config)?,
456            &device,
457            self.config.topology.as_ref(),
458        )?;
459        let mapper = mapper.into_mapper(
460            self.inner.num_layers(&config)?,
461            &device,
462            self.config.topology.as_ref(),
463        )?;
464        let mut layer_devices = Vec::new();
465        for layer in 0..self.inner.num_layers(&config)? {
466            let device = mapper.device_for(layer, false).cloned();
467            layer_devices.push(device);
468        }
469        let dtype = mapper.get_min_dtype(dtype)?;
470
471        // TODO: PagedAttention is not supported with CPU for now.
472        // This check is not really necessary because `get_device_layers` should prevent it.
473        let mapping_uses_cpu = mapper.get_unique_devices().iter().any(Device::is_cpu);
474        if mapping_uses_cpu {
475            warn!("Device mapping contains a mix of GPU and CPU. There is no CPU support for PagedAttention, disabling PagedAttention.");
476            paged_attn_config = None;
477        }
478
479        info!("Model config: {:?}", self.inner.get_config_repr(&config)?);
480        if crate::using_flash_attn() {
481            once_log_info("FlashAttention is enabled.");
482        }
483
484        // Logic for ISQ here: if no calibration (i.e imatrix), then allow immediate ISQ. Otherwise, back to normal.
485        let mut loading_isq = if self.config.imatrix.is_none()
486            && self.config.calibration_file.is_none()
487            && !device.is_cuda()
488            && self.config.write_uqff.is_none()
489            && in_situ_quant.is_some()
490        {
491            let predicates = if matches!(self.config.organization, IsqOrganization::MoeExpertsOnly)
492            {
493                self.inner.immediate_isq_predicates_moqe(&config)?
494            } else {
495                self.inner.immediate_isq_predicates(&config)?
496            };
497            info!("Applying ISQ to {in_situ_quant:?}");
498            if predicates.is_empty() {
499                warn!("No predicates for this model and ISQ setting detected. ISQ will not be applied to any weights!");
500            }
501            mistralrs_quant::set_immediate_isq(in_situ_quant, predicates);
502            false
503        } else {
504            in_situ_quant.is_some()
505        };
506
507        if let Some(ref topology) = self.config.topology {
508            loading_isq |= topology
509                .0
510                .iter()
511                .any(|layer| layer.as_ref().is_some_and(|layer| layer.isq.is_some()));
512        }
513
514        if self.config.imatrix.is_some() && self.config.calibration_file.is_some() {
515            anyhow::bail!(
516                "`imatrix` and `calibration_file` were both specified, this is not allowed."
517            );
518        }
519
520        // Load onto the regular device if not using isq or if the calibration file is specified
521        let load_device = if !loading_isq || self.config.calibration_file.is_some() {
522            loading_isq = false;
523            device.clone()
524        } else {
525            Device::Cpu
526        };
527
528        let is_xlora = self.kind.is_adapted_and(|a| a.is_x_lora());
529
530        let attention_mechanism = if paged_attn_config.is_some() {
531            AttentionImplementation::PagedAttention
532        } else {
533            AttentionImplementation::Eager
534        };
535
536        let multi_progress = Arc::new(MultiProgress::new());
537
538        let mut model = if use_nccl || cfg!(feature = "ring") {
539            let (mapper, sharded_vb) = distributed::prepare_distributed_mapper(
540                dtype,
541                &device,
542                &available_devices,
543                silent,
544                &config,
545                loading_isq,
546                self.config.from_uqff.is_some(),
547                self.config.organization,
548                &*self.inner,
549                paths.as_ref(),
550            )?;
551
552            // Special case for where things can be more optimially loaded.
553            match self.kind {
554                ModelKind::Normal => normal_model_loader_sharded!(
555                    sharded_vb,
556                    config,
557                    self.inner,
558                    mapper,
559                    loading_isq,
560                    device.clone(),
561                    attention_mechanism,
562                    multi_progress.clone(),
563                ),
564                ModelKind::Adapter {
565                    adapter: AdapterKind::XLora,
566                } => xlora_model_loader!(
567                    paths,
568                    Some(dtype),
569                    &load_device,
570                    layer_devices.clone(),
571                    config,
572                    self.inner,
573                    silent,
574                    mapper,
575                    loading_isq,
576                    device.clone(),
577                    multi_progress.clone(),
578                ),
579                ModelKind::Adapter {
580                    adapter: AdapterKind::Lora,
581                } => lora_model_loader!(
582                    paths,
583                    Some(dtype),
584                    &load_device,
585                    layer_devices.clone(),
586                    config,
587                    self.inner,
588                    silent,
589                    mapper,
590                    loading_isq,
591                    self.config.from_uqff.is_some(),
592                    device.clone(),
593                    attention_mechanism,
594                    matches!(self.config.organization, IsqOrganization::MoeExpertsOnly),
595                    multi_progress.clone(),
596                ),
597                _ => unreachable!(),
598            }
599        } else {
600            match self.kind {
601                ModelKind::Normal => normal_model_loader!(
602                    paths,
603                    Some(dtype),
604                    &load_device,
605                    layer_devices.clone(),
606                    config,
607                    self.inner,
608                    silent,
609                    mapper,
610                    loading_isq,
611                    self.config.from_uqff.is_some(),
612                    device.clone(),
613                    attention_mechanism,
614                    matches!(self.config.organization, IsqOrganization::MoeExpertsOnly),
615                    multi_progress.clone(),
616                ),
617                ModelKind::Adapter {
618                    adapter: AdapterKind::XLora,
619                } => xlora_model_loader!(
620                    paths,
621                    Some(dtype),
622                    &load_device,
623                    layer_devices.clone(),
624                    config,
625                    self.inner,
626                    silent,
627                    mapper,
628                    loading_isq,
629                    device.clone(),
630                    multi_progress.clone(),
631                ),
632                ModelKind::Adapter {
633                    adapter: AdapterKind::Lora,
634                } => lora_model_loader!(
635                    paths,
636                    Some(dtype),
637                    &load_device,
638                    layer_devices.clone(),
639                    config,
640                    self.inner,
641                    silent,
642                    mapper,
643                    loading_isq,
644                    self.config.from_uqff.is_some(),
645                    device.clone(),
646                    attention_mechanism,
647                    matches!(self.config.organization, IsqOrganization::MoeExpertsOnly),
648                    multi_progress.clone(),
649                ),
650                _ => unreachable!(),
651            }
652        };
653
654        let tokenizer = get_tokenizer(paths.get_tokenizer_filename(), None)?;
655        let gen_conf: Option<GenerationConfig> = paths.get_gen_conf_filename().and_then(|f| {
656            match serde_json::from_str::<GenerationConfig>(&fs::read_to_string(f).unwrap()) {
657                Ok(conf) => Some(conf),
658                Err(e) => {
659                    warn!("Failed to parse generation_config.json: {}", e);
660                    None
661                }
662            }
663        });
664
665        let chat_template_explicit = paths
666            .get_chat_template_explicit()
667            .as_ref()
668            .map(|x| x.to_string_lossy().to_string());
669        let chat_template = get_chat_template(
670            paths,
671            self.jinja_explicit.as_ref(),
672            chat_template_explicit.as_ref(),
673            self.chat_template.as_ref(),
674            None,
675        );
676
677        if let Some(calibration_file) = &self.config.calibration_file {
678            let calibration_data = std::fs::read_to_string(calibration_file)?;
679            // Tokenize, don't add bos yet
680            let tokens = tokenizer
681                .encode_fast(calibration_data, false)
682                .map_err(anyhow::Error::msg)?
683                .get_ids()
684                .to_vec();
685            info!(
686                "Collecting imatrix from calibration file `{}` of {} tokens.",
687                calibration_file.display(),
688                tokens.len()
689            );
690            let bos_toks = chat_template.bos_tok().map(|b| vec![b]).unwrap_or_default();
691            let bos_tok_id = tokenizer
692                .token_to_id(&bos_toks[0])
693                .expect("Somehow the bos token is not present.");
694
695            match self.config.organization {
696                IsqOrganization::Default => model.begin_track_stats()?,
697                IsqOrganization::MoeExpertsOnly => model.begin_track_stats_moe_experts_only()?,
698            }
699
700            const CHUNK_SIZE: usize = 1024;
701            let n_chunks = tokens.len().div_ceil(CHUNK_SIZE);
702            let start = Instant::now();
703            for (i, chunk) in tokens.chunks(CHUNK_SIZE).enumerate() {
704                let chunk = [vec![bos_tok_id], chunk.to_vec()].concat();
705                let chunk_len = chunk.len();
706
707                let start = Instant::now();
708                let inputs = make_prompt_chunk(
709                    0,
710                    vec![&chunk],
711                    &[0],
712                    &load_device,
713                    None,
714                    false,
715                    None,
716                    Some(pipeline_mapper.as_ref()),
717                )?;
718
719                model.forward(
720                    &inputs.input.to_device(model.device())?,
721                    &inputs.positions,
722                    inputs.context_lens.clone(),
723                    inputs.position_ids.clone(),
724                    None,
725                    &inputs.flash_meta.clone(),
726                )?;
727
728                match model.cache_mut() {
729                    EitherCache::Full(full) => {
730                        for layer in &mut *full.lock() {
731                            *layer = None
732                        }
733                    }
734                    EitherCache::Normal(normal) => {
735                        for layer in &mut *normal.lock().unwrap().0 {
736                            layer.reset();
737                        }
738                    }
739                }
740
741                let end = Instant::now();
742                info!(
743                    "Processed chunk {}/{n_chunks} ({chunk_len} tokens), {:.2}s",
744                    i + 1,
745                    end.duration_since(start).as_secs_f32()
746                );
747            }
748            load_device.synchronize()?;
749            let end = Instant::now();
750            info!(
751                "Finished collecting imatrix in {:.2}s",
752                end.duration_since(start).as_secs_f32()
753            );
754        }
755
756        // Only if loading from UQFF
757        if (loading_isq || self.config.topology.is_some()) && self.config.from_uqff.is_none() {
758            let imatrix_source = match (
759                self.config.imatrix.as_ref(),
760                self.config.calibration_file.is_some(),
761            ) {
762                (None, false) => None,
763                (Some(file), false) => Some(ImatrixDataSource::File(file)),
764                (None, true) => Some(ImatrixDataSource::Collected),
765                (Some(_), true) => unreachable!(),
766            };
767
768            info!("Applying ISQ to all ranks.");
769
770            let multi_progress = Arc::new(MultiProgress::new());
771
772            model.quantize(
773                in_situ_quant,
774                model.device().clone(),
775                self.config.topology.as_ref(),
776                silent,
777                imatrix_source,
778                self.config.organization,
779                self.config.write_uqff.as_ref(),
780                UqffFullSer {
781                    tokenizer: &tokenizer,
782                    template_filename: paths.get_template_filename(),
783                    generation_config: paths.get_gen_conf_filename(),
784                    config: config.clone(),
785                    processor_filename: &None,
786                    preprocessor_filename: &None,
787                },
788                multi_progress.clone(),
789            )?;
790        } else if let Some(from_uqff) = &*self.from_uqff.read().unwrap() {
791            model.load_from_artifacts(
792                device.clone(),
793                self.config.topology.as_ref(),
794                silent,
795                from_uqff,
796            )?;
797        }
798
799        let paged_attn_config = if matches!(
800            self.kind,
801            ModelKind::Adapter {
802                adapter: AdapterKind::XLora
803            }
804        ) {
805            warn!(
806                "Adapter parallel_models do not currently support PagedAttention, running without"
807            );
808            None
809        } else {
810            paged_attn_config
811        };
812
813        let (cache_config, cache_engine) = if let Some(paged_attn_config) = paged_attn_config {
814            let cache_config = calculate_cache_config(
815                paged_attn_config.mem_gpu,
816                paged_attn_config.mem_cpu,
817                paged_attn_config.block_size,
818                dtype,
819                paged_attn_config.cache_type,
820                model.config(),
821                &device,
822                &pipeline_mapper
823                    .get_unique_devices()
824                    .into_iter()
825                    .map(Some)
826                    .collect::<Vec<_>>(),
827                silent,
828            )?;
829
830            let mut layer_devices = Vec::new();
831            for layer in 0..self.inner.num_layers(&config)? {
832                let device = model.get_layers().1.device_for(layer, false).cloned();
833                layer_devices.push(device);
834            }
835            let cache_engine = CacheEngine::new(
836                model.config(),
837                &cache_config,
838                dtype,
839                model.device(),
840                layer_devices.clone(),
841            )?;
842
843            (Some(cache_config), Some(cache_engine))
844        } else {
845            (None, None)
846        };
847
848        let max_seq_len = model.max_seq_len();
849        let llg_factory = build_llg_factory(tokenizer.clone())?;
850        let num_hidden_layers = match model.cache() {
851            EitherCache::Full(full) => full.lock().len(),
852            EitherCache::Normal(normal) => normal.lock().unwrap().0.len(),
853        };
854        let eos = calculate_eos_tokens(&chat_template, gen_conf, &tokenizer);
855        let sliding_window = model.config().sliding_window;
856        let model_metadata = Arc::new(model.config().clone());
857
858        Ok(Arc::new(Mutex::new(NormalPipeline {
859            model,
860            tokenizer: tokenizer.into(),
861            no_kv_cache: self.no_kv_cache,
862            chat_template: Arc::new(chat_template),
863            non_granular_state: self.tgt_non_granular_index.map(|tgt_non_granular_index| {
864                NonGranularState {
865                    non_granular_index: Arc::new(Mutex::new(0)),
866                    tgt_non_granular_index,
867                }
868            }),
869            model_id: self.model_id.clone(),
870            metadata: Arc::new(GeneralMetadata {
871                max_seq_len,
872                llg_factory: Some(llg_factory),
873                no_kv_cache: self.no_kv_cache,
874                no_prefix_cache: is_xlora,
875                num_hidden_layers,
876                eos_tok: eos,
877                kind: self.kind.clone(),
878                is_xlora,
879                activation_dtype: dtype,
880                sliding_window,
881                cache_config,
882                cache_engine,
883                prompt_chunksize: Some(NonZero::new(prompt_chunksize).unwrap()),
884                model_metadata: Some(model_metadata),
885                modalities: Modalities {
886                    input: vec![SupportedModality::Text],
887                    output: vec![SupportedModality::Text],
888                },
889            }),
890            topology: self.config.topology.clone(),
891            silent,
892            organization: self.config.organization,
893            template_filename: paths.get_template_filename().clone(),
894            generation_config: paths.get_gen_conf_filename().cloned(),
895            config,
896            imatrix: self.config.imatrix.clone(),
897            mapper: pipeline_mapper,
898        })))
899    }
900
901    fn get_id(&self) -> String {
902        self.model_id.clone()
903    }
904
905    fn get_kind(&self) -> ModelKind {
906        self.kind.clone()
907    }
908}
909
910impl PreProcessingMixin for NormalPipeline {
911    fn get_chat_template(&self) -> Option<Arc<ChatTemplate>> {
912        Some(self.chat_template.clone())
913    }
914    fn get_input_processor_config(&self) -> Option<Arc<dyn Any>> {
915        None
916    }
917}
918
919impl IsqPipelineMixin for NormalPipeline {
920    fn re_isq_model(&mut self, dtype: IsqType) -> Result<()> {
921        let device = self.device().clone();
922        let multi_progress = Arc::new(MultiProgress::new());
923        self.model.quantize(
924            Some(dtype),
925            device.clone(),
926            self.topology.as_ref(),
927            self.silent,
928            self.imatrix.as_ref().map(ImatrixDataSource::File),
929            self.organization,
930            None,
931            UqffFullSer {
932                tokenizer: &self.tokenizer,
933                template_filename: &self.template_filename,
934                generation_config: self.generation_config.as_ref(),
935                config: self.config.clone(),
936                processor_filename: &None,
937                preprocessor_filename: &None,
938            },
939            multi_progress.clone(),
940        )?;
941        Ok(())
942    }
943}
944
945impl CacheManagerMixin for NormalPipeline {
946    fn clone_in_cache(&self, seqs: &mut [&mut Sequence]) {
947        if matches!(self.model.cache(), EitherCache::Full(_)) {
948            FullCacheManager.clone_in_cache(self, seqs, false)
949        } else {
950            NormalCacheManager.clone_in_cache(self, seqs, false)
951        }
952    }
953    fn clone_out_cache(&self, seqs: &mut [&mut Sequence]) {
954        if matches!(self.model.cache(), EitherCache::Full(_)) {
955            FullCacheManager.clone_out_cache(self, seqs, false)
956        } else {
957            NormalCacheManager.clone_out_cache(self, seqs, false)
958        }
959    }
960    fn set_none_cache(
961        &self,
962        seqs: &mut [&mut Sequence],
963        reset_non_granular: bool,
964        modify_draft_cache: bool,
965        load_preallocated_cache: bool,
966    ) {
967        if matches!(self.model.cache(), EitherCache::Full(_)) {
968            FullCacheManager.set_none_cache(self, seqs, modify_draft_cache, false);
969        } else {
970            NormalCacheManager.set_none_cache(
971                self,
972                seqs,
973                modify_draft_cache,
974                load_preallocated_cache,
975            );
976        }
977        if reset_non_granular {
978            self.reset_non_granular_state()
979        }
980    }
981    fn cache(&self) -> &EitherCache {
982        self.model.cache()
983    }
984}
985
986impl MetadataMixin for NormalPipeline {
987    fn device(&self) -> Device {
988        self.model.device().clone()
989    }
990    fn tokenizer(&self) -> Option<Arc<Tokenizer>> {
991        Some(self.tokenizer.clone())
992    }
993    fn name(&self) -> String {
994        self.model_id.clone()
995    }
996    fn reset_non_granular_state(&self) {
997        if let Some(s) = self.non_granular_state.as_ref() {
998            *self.cache().full().get_scalings_cache() = None;
999            *get_mut_arcmutex!(s.non_granular_index) = 0;
1000        }
1001    }
1002    fn get_metadata(&self) -> Arc<GeneralMetadata> {
1003        self.metadata.clone()
1004    }
1005    fn device_mapper(&self) -> Option<&dyn DeviceMapper> {
1006        Some(&*self.mapper)
1007    }
1008}
1009
1010#[async_trait::async_trait]
1011impl Pipeline for NormalPipeline {
1012    fn forward_inputs(
1013        &mut self,
1014        inputs: Box<dyn Any>,
1015        return_raw_logits: bool,
1016    ) -> Result<ForwardInputsResult, candle_core::Error> {
1017        let ModelInputs {
1018            input_ids,
1019            input_ids_full,
1020            seqlen_offsets,
1021            seqlen_offsets_full,
1022            context_lens,
1023            position_ids,
1024            paged_attn_meta,
1025            flash_meta,
1026            flash_meta_full,
1027        } = *inputs.downcast().expect("Downcast failed.");
1028        let metadata = self.get_metadata();
1029        let paged_attn_meta = match (&metadata.cache_engine, &paged_attn_meta) {
1030            (Some(cache_engine), Some(meta)) => Some((cache_engine, meta)),
1031            (Some(_), None) => {
1032                // This can happen if Rust-side user code is wrong
1033                candle_core::bail!("Forward step expected a PagedAttention input metadata. This was not provided, please ensure that the scheduler config is correctly configured for PagedAttention.")
1034            }
1035            (None, Some(_)) => {
1036                // This should never happen but we handle it anyway
1037                candle_core::bail!("Forward step got a PagedAttention input metadata but there is no cache engine. Please raise an issue.")
1038            }
1039            (None, None) => None,
1040        };
1041        let logits = match self.model.is_xlora() {
1042            false => {
1043                let paged_attn_meta = paged_attn_meta
1044                    .as_ref()
1045                    .map(|meta| (meta.0.get_kv_cache().clone(), meta.1.clone()));
1046
1047                self.model.forward(
1048                    &input_ids,
1049                    &seqlen_offsets,
1050                    context_lens,
1051                    position_ids,
1052                    paged_attn_meta.as_ref().map(|(a, b)| (a.clone(), b)),
1053                    &flash_meta,
1054                )?
1055            }
1056            true => self.model.xlora_forward(
1057                &input_ids,
1058                input_ids_full.as_ref().unwrap_or(&input_ids),
1059                &seqlen_offsets,
1060                seqlen_offsets_full.as_ref().unwrap_or(&seqlen_offsets),
1061                self.no_kv_cache,
1062                &self.non_granular_state,
1063                context_lens,
1064                position_ids,
1065                &flash_meta,
1066                flash_meta_full.as_ref().unwrap_or(&flash_meta),
1067            )?,
1068        };
1069        if return_raw_logits {
1070            Ok(ForwardInputsResult::RawLogits { logits })
1071        } else {
1072            Ok(ForwardInputsResult::CausalGeneration { logits })
1073        }
1074    }
1075    async fn sample_causal_gen(
1076        &self,
1077        seqs: &mut [&mut Sequence],
1078        logits: Vec<Tensor>,
1079        prefix_cacher: &mut PrefixCacheManagerV2,
1080        disable_eos_stop: bool,
1081        rng: Arc<std::sync::Mutex<Isaac64Rng>>,
1082    ) -> Result<(), candle_core::Error> {
1083        sample_and_add_toks(self, seqs, logits, prefix_cacher, disable_eos_stop, rng).await
1084    }
1085    fn category(&self) -> ModelCategory {
1086        ModelCategory::Text
1087    }
1088}
1089
1090impl AnyMoePipelineMixin for NormalPipeline {
1091    fn amoe_finish_training(&mut self, gate_model_id: Option<String>) -> candle_core::Result<()> {
1092        self.model.finish_training(gate_model_id)
1093    }
1094    fn amoe_layer_vars(&self) -> Vec<Vec<Var>> {
1095        self.model.get_vars()
1096    }
1097    fn amoe_base_model_trainable_params(&self) -> usize {
1098        self.model.trainable_params()
1099    }
1100    fn amoe_take_cached_gating_outputs(&mut self) -> Vec<Tensor> {
1101        self.model.take_cached_gating_outputs()
1102    }
1103    fn amoe_create_layers(
1104        &mut self,
1105        model_ids: Vec<String>,
1106        token: &TokenSource,
1107        revision: Option<String>,
1108        match_regex: &str,
1109        config: crate::amoe::AnyMoeConfig,
1110        dtype: candle_core::DType,
1111        dev: &Device,
1112        (prefix, mlp): (String, String),
1113        layers: Vec<usize>,
1114        expert_type: AnyMoeExpertType,
1115        silent: bool,
1116        gate_model_id: Option<String>,
1117    ) -> candle_core::Result<()> {
1118        let mut vbs = Vec::new();
1119        // Precompile regex here
1120        let regex = Regex::new(match_regex).map_err(candle_core::Error::msg)?;
1121        for model_id in model_ids {
1122            let model_id_str = &model_id;
1123            let model_id = Path::new(&model_id);
1124
1125            let api = {
1126                let cache = GLOBAL_HF_CACHE.get().cloned().unwrap_or_default();
1127                let mut api = ApiBuilder::from_cache(cache)
1128                    .with_progress(!silent)
1129                    .with_token(get_token(token).map_err(candle_core::Error::msg)?);
1130                if let Ok(x) = std::env::var("HF_HUB_CACHE") {
1131                    api = api.with_cache_dir(x.into());
1132                }
1133                api.build().map_err(candle_core::Error::msg)?
1134            };
1135            let revision = revision.clone().unwrap_or("main".to_string());
1136            let api = api.repo(Repo::with_revision(
1137                model_id_str.clone(),
1138                RepoType::Model,
1139                revision.clone(),
1140            ));
1141
1142            let mut filenames = vec![];
1143            for rfilename in
1144                api_dir_list!(api, model_id, true).filter(|x| x.ends_with(".safetensors"))
1145            {
1146                filenames.push(api_get_file!(api, &rfilename, model_id));
1147            }
1148
1149            let regex = regex.clone();
1150            let match_regex_clone = match_regex.to_string();
1151            let layers_clone = layers.clone();
1152            let vb = from_mmaped_safetensors(
1153                filenames,
1154                vec![],
1155                Some(dtype),
1156                dev,
1157                vec![None],
1158                silent,
1159                None,
1160                move |key| {
1161                    if regex.is_match(&key) {
1162                        // Idx of the last char of the layer id, +1
1163                        // Assumes N.MLP
1164                        let last_layer_idx = key.find(&match_regex_clone).unwrap() - 1;
1165                        let first_layer_idx = key[..last_layer_idx].rfind('.').unwrap();
1166                        let layer_n = key[first_layer_idx + 1..last_layer_idx]
1167                            .parse::<usize>()
1168                            .unwrap();
1169                        layers_clone.contains(&layer_n) || layers_clone.is_empty()
1170                    } else {
1171                        false
1172                    }
1173                },
1174                Arc::new(|_| DeviceForLoadTensor::Base),
1175            )?;
1176            vbs.push(vb);
1177        }
1178
1179        let gate_vb = if let Some(gate_model_id) = gate_model_id {
1180            let model_id_str = &gate_model_id;
1181            let model_id = Path::new(&gate_model_id);
1182
1183            let api = {
1184                let cache = GLOBAL_HF_CACHE.get().cloned().unwrap_or_default();
1185                let mut api = ApiBuilder::from_cache(cache)
1186                    .with_progress(!silent)
1187                    .with_token(get_token(token).map_err(candle_core::Error::msg)?);
1188                if let Ok(x) = std::env::var("HF_HUB_CACHE") {
1189                    api = api.with_cache_dir(x.into());
1190                }
1191                api.build().map_err(candle_core::Error::msg)?
1192            };
1193            let revision = revision.clone().unwrap_or("main".to_string());
1194            let api = api.repo(Repo::with_revision(
1195                model_id_str.clone(),
1196                RepoType::Model,
1197                revision.clone(),
1198            ));
1199
1200            let mut gate_filenames = vec![];
1201            for rfilename in
1202                api_dir_list!(api, model_id, true).filter(|x| x.ends_with(".safetensors"))
1203            {
1204                gate_filenames.push(api_get_file!(api, &rfilename, model_id));
1205            }
1206            assert_eq!(
1207                gate_filenames.len(),
1208                1,
1209                "Gate model ID must contain only one .safetensors file"
1210            );
1211
1212            let vb = from_mmaped_safetensors(
1213                gate_filenames.clone(),
1214                vec![],
1215                Some(dtype),
1216                dev,
1217                vec![None],
1218                silent,
1219                None,
1220                |_| true,
1221                Arc::new(|_| DeviceForLoadTensor::Base),
1222            )?;
1223            info!(
1224                "Loaded gating layers from `{}`",
1225                gate_filenames[0].display()
1226            );
1227            Some(vb)
1228        } else {
1229            None
1230        };
1231
1232        self.model.create_anymoe_layers(
1233            vbs.clone(),
1234            config.clone(),
1235            (prefix.clone(), mlp.clone()),
1236            layers.clone(),
1237            expert_type.clone(),
1238            gate_vb.clone(),
1239        )?;
1240
1241        Ok(())
1242    }
1243    fn amoe_supported(&self) -> bool {
1244        self.model.amoe_supported()
1245    }
1246}