mistralrs_core/pipeline/normal.rs

1use super::cache_manager::{FullCacheManager, NormalCacheManager};
2use super::inputs_processor::DEFAULT_PROMPT_CHUNK_SIZE;
3use super::isq::ImatrixDataSource;
4use super::llg::build_tok_env;
5use super::{
6    get_model_paths, get_xlora_paths, text_models_inputs_processor::ModelInputs, AdapterKind,
7    CacheManager, GeneralMetadata, Loader, ModelKind, ModelPaths, NormalModel, NormalModelLoader,
8    TokenSource,
9};
10use super::{
11    AnyMoePipelineMixin, CacheManagerMixin, EitherCache, ForwardInputsResult, IsqOrganization,
12    IsqPipelineMixin, MetadataMixin, ModelCategory, PreProcessingMixin,
13};
14use super::{
15    AutoLoader, DeepSeekV2Loader, DeepSeekV3Loader, Gemma2Loader, GemmaLoader, LlamaLoader,
16    MistralLoader, MixtralLoader, NormalLoaderType, Phi2Loader, Phi3Loader, Phi3_5MoELoader,
17    Qwen2Loader, Starcoder2Loader,
18};
19use crate::amoe::AnyMoeExpertType;
20use crate::device_map::{self, DeviceMapper};
21use crate::distributed::{self, WorkerTransferData};
22use crate::lora::Ordering;
23use crate::paged_attention::{calculate_cache_config, AttentionImplementation, CacheEngine};
24use crate::pipeline::chat_template::{calculate_eos_tokens, GenerationConfig};
25use crate::pipeline::get_chat_template;
26use crate::pipeline::isq::UqffFullSer;
27use crate::pipeline::sampling::sample_and_add_toks;
28use crate::pipeline::text_models_inputs_processor::make_prompt_chunk;
29use crate::pipeline::{ChatTemplate, LocalModelPaths};
30use crate::prefix_cacher::PrefixCacheManagerV2;
31use crate::sequence::Sequence;
32use crate::utils::tokenizer::get_tokenizer;
33use crate::utils::varbuilder_utils::DeviceForLoadTensor;
34use crate::utils::{tokens::get_token, varbuilder_utils::from_mmaped_safetensors};
35use crate::xlora_models::NonGranularState;
36use crate::{
37    api_dir_list, api_get_file, get_mut_arcmutex, get_paths, get_uqff_paths, lora_model_loader,
38    normal_model_loader, normal_model_loader_sharded, xlora_model_loader, DeviceMapSetting,
39    PagedAttentionConfig, Pipeline, Topology, TryIntoDType, GLOBAL_HF_CACHE,
40};
41use anyhow::Result;
42use candle_core::{Device, Tensor, Var};
43use hf_hub::Cache;
44use hf_hub::{api::sync::ApiBuilder, Repo, RepoType};
45use indicatif::MultiProgress;
46use mistralrs_quant::{AfqLayer, GgufMatMul, HqqLayer, IsqType, QuantizedSerdeType};
47use rand_isaac::Isaac64Rng;
48use regex_automata::meta::Regex;
49use std::any::Any;
50use std::borrow::Cow;
51use std::num::{NonZero, NonZeroUsize};
52use std::path::{Path, PathBuf};
53use std::str::FromStr;
54use std::sync::{Arc, RwLock};
55use std::time::Instant;
56use std::{env, fs};
57use tokenizers::Tokenizer;
58use tokio::sync::Mutex;
59use tracing::{info, warn};
60
/// A fully loaded "normal" (non-quantized) text-generation pipeline: the model
/// itself plus the tokenizer, chat template, and metadata needed for inference.
pub struct NormalPipeline {
    model: Box<dyn NormalModel + Send + Sync>,
    tokenizer: Arc<Tokenizer>,
    // Disables the KV cache when set (configured via the loader's `no_kv_cache`).
    no_kv_cache: bool,
    chat_template: Arc<ChatTemplate>,
    // X-LoRA non-granular scaling state; populated only when
    // `tgt_non_granular_index` was set on the loader.
    non_granular_state: Option<NonGranularState>,
    model_id: String,
    metadata: Arc<GeneralMetadata>,
    topology: Option<Topology>,
    silent: bool,
    // Controls how ISQ is applied (default vs. MoE-experts-only).
    organization: IsqOrganization,
    // For full UQFF serialization
    template_filename: Option<PathBuf>,
    generation_config: Option<PathBuf>,
    // Raw JSON text of the model's config file (retained for UQFF serialization).
    config: String,
    imatrix: Option<PathBuf>,
    // Device mapper that decides which device each layer lives on.
    mapper: Box<dyn DeviceMapper + Send + Sync>,
}
79
80/// A loader for a "normal" (non-quantized) model.
/// A loader for a "normal" (non-quantized) model.
pub struct NormalLoader {
    // Architecture-specific loader selected in `NormalLoaderBuilder::build`.
    inner: Box<dyn NormalModelLoader>,
    model_id: String,
    config: NormalSpecificConfig,
    xlora_model_id: Option<String>,
    lora_adapter_ids: Option<Vec<String>>,
    kind: ModelKind,
    xlora_order: Option<Ordering>,
    no_kv_cache: bool,
    chat_template: Option<String>,
    tokenizer_json: Option<String>,
    tgt_non_granular_index: Option<usize>,
    // The three RwLocks below are written by `load_model_from_hf` so the values
    // are available to later loading stages; interior mutability is needed
    // because `Loader` methods take `&self`.
    token_source: RwLock<Option<TokenSource>>,
    revision: RwLock<Option<String>>,
    // Resolved path of the UQFF artifact, filled in from `config.from_uqff`.
    from_uqff: RwLock<Option<PathBuf>>,
    // Explicit Jinja chat-template override.
    jinja_explicit: Option<String>,
    // Custom Hugging Face cache directory; default cache is used when `None`.
    hf_cache_path: Option<PathBuf>,
}
99
#[derive(Default)]
/// A builder for a loader for a "normal" (non-quantized) model.
///
/// Mirrors the fields of [`NormalLoader`]; optional fields default to `None`
/// via `Default`, and `model_id` may be filled in from an adapter ordering.
pub struct NormalLoaderBuilder {
    // May be `None` initially; an adapter's base model ID can supply it later.
    model_id: Option<String>,
    config: NormalSpecificConfig,
    xlora_model_id: Option<String>,
    lora_adapter_ids: Option<Vec<String>>,
    // `ModelKind::Normal` unless `with_xlora`/`with_lora` changes it.
    kind: ModelKind,
    xlora_order: Option<Ordering>,
    no_kv_cache: bool,
    chat_template: Option<String>,
    tokenizer_json: Option<String>,
    tgt_non_granular_index: Option<usize>,
    jinja_explicit: Option<String>,
    hf_cache_path: Option<PathBuf>,
}
116
#[derive(Clone, Default)]
/// Config specific to loading a normal model.
pub struct NormalSpecificConfig {
    pub use_flash_attn: bool,
    /// Prompt chunk size; `DEFAULT_PROMPT_CHUNK_SIZE` is applied when `None`.
    pub prompt_chunksize: Option<NonZeroUsize>,
    /// Per-layer device/ISQ topology.
    pub topology: Option<Topology>,
    /// How ISQ is organized (default vs. MoE-experts-only).
    pub organization: IsqOrganization,
    /// Destination path for serializing the model as UQFF.
    pub write_uqff: Option<PathBuf>,
    /// Source path to load a previously serialized UQFF model from.
    pub from_uqff: Option<PathBuf>,
    /// Precomputed imatrix file for ISQ. Mutually exclusive with
    /// `calibration_file` (loading bails if both are set).
    pub imatrix: Option<PathBuf>,
    /// Calibration text used to collect an imatrix at load time.
    /// Mutually exclusive with `imatrix`.
    pub calibration_file: Option<PathBuf>,
    /// Custom Hugging Face cache directory.
    pub hf_cache_path: Option<PathBuf>,
}
130
131impl NormalLoaderBuilder {
132    pub fn new(
133        config: NormalSpecificConfig,
134        chat_template: Option<String>,
135        tokenizer_json: Option<String>,
136        model_id: Option<String>,
137        no_kv_cache: bool,
138        jinja_explicit: Option<String>,
139    ) -> Self {
140        Self {
141            config,
142            chat_template,
143            tokenizer_json,
144            model_id,
145            kind: ModelKind::Normal,
146            jinja_explicit,
147            no_kv_cache,
148            ..Default::default()
149        }
150    }
151
152    fn with_adapter(
153        mut self,
154        xlora_model_id: String,
155        xlora_order: Ordering,
156        no_kv_cache: bool,
157        tgt_non_granular_index: Option<usize>,
158    ) -> Self {
159        self.xlora_model_id = Some(xlora_model_id);
160        self.xlora_order = Some(xlora_order);
161        self.no_kv_cache = no_kv_cache;
162        self.tgt_non_granular_index = tgt_non_granular_index;
163        self.model_id = if let Some(id) = self.model_id {
164            Some(id)
165        } else {
166            info!(
167                "Using adapter base model ID: `{}`",
168                self.xlora_order.as_ref().unwrap().base_model_id
169            );
170            Some(self.xlora_order.as_ref().unwrap().base_model_id.clone())
171        };
172        self
173    }
174
175    pub fn with_xlora(
176        mut self,
177        xlora_model_id: String,
178        xlora_order: Ordering,
179        no_kv_cache: bool,
180        tgt_non_granular_index: Option<usize>,
181    ) -> Self {
182        self.kind = ModelKind::Adapter {
183            adapter: AdapterKind::XLora,
184        };
185        self.with_adapter(
186            xlora_model_id,
187            xlora_order,
188            no_kv_cache,
189            tgt_non_granular_index,
190        )
191    }
192
193    pub fn with_lora(mut self, lora_adapter_ids: Vec<String>) -> Self {
194        self.kind = ModelKind::Adapter {
195            adapter: AdapterKind::Lora,
196        };
197        self.lora_adapter_ids = Some(lora_adapter_ids);
198        self
199    }
200
201    pub fn hf_cache_path(mut self, hf_cache_path: PathBuf) -> Self {
202        self.hf_cache_path = Some(hf_cache_path);
203        self
204    }
205
206    /// If the loader type is not specified, loader type is automatically determined from the
207    /// `architectures` array in the config.
208    pub fn build(self, loader_tp: Option<NormalLoaderType>) -> anyhow::Result<Box<dyn Loader>> {
209        let loader: Box<dyn NormalModelLoader> = match loader_tp {
210            Some(NormalLoaderType::Mistral) => Box::new(MistralLoader),
211            Some(NormalLoaderType::Gemma) => Box::new(GemmaLoader),
212            Some(NormalLoaderType::Llama) => Box::new(LlamaLoader),
213            Some(NormalLoaderType::Mixtral) => Box::new(MixtralLoader),
214            Some(NormalLoaderType::Phi2) => Box::new(Phi2Loader),
215            Some(NormalLoaderType::Phi3) => Box::new(Phi3Loader),
216            Some(NormalLoaderType::Qwen2) => Box::new(Qwen2Loader),
217            Some(NormalLoaderType::Gemma2) => Box::new(Gemma2Loader),
218            Some(NormalLoaderType::Starcoder2) => Box::new(Starcoder2Loader),
219            Some(NormalLoaderType::Phi3_5MoE) => Box::new(Phi3_5MoELoader),
220            Some(NormalLoaderType::DeepSeekV2) => Box::new(DeepSeekV2Loader),
221            Some(NormalLoaderType::DeepSeekV3) => Box::new(DeepSeekV3Loader),
222            None => Box::new(AutoLoader),
223        };
224        Ok(Box::new(NormalLoader {
225            inner: loader,
226            model_id: self.model_id.unwrap(),
227            config: self.config,
228            xlora_model_id: self.xlora_model_id,
229            lora_adapter_ids: self.lora_adapter_ids,
230            kind: self.kind,
231            xlora_order: self.xlora_order,
232            no_kv_cache: self.no_kv_cache,
233            chat_template: self.chat_template,
234            tokenizer_json: self.tokenizer_json,
235            tgt_non_granular_index: self.tgt_non_granular_index,
236            jinja_explicit: self.jinja_explicit,
237            token_source: RwLock::new(None),
238            revision: RwLock::new(None),
239            from_uqff: RwLock::new(None),
240            hf_cache_path: self.hf_cache_path,
241        }))
242    }
243}
244
245impl Loader for NormalLoader {
    #[allow(clippy::type_complexity, clippy::too_many_arguments)]
    /// Resolve model artifacts from the Hugging Face Hub (or local cache),
    /// stash the token source/revision for later stages, then delegate to
    /// `load_model_from_path` for the actual model construction.
    fn load_model_from_hf(
        &self,
        revision: Option<String>,
        token_source: TokenSource,
        dtype: &dyn TryIntoDType,
        device: &Device,
        silent: bool,
        mapper: DeviceMapSetting,
        in_situ_quant: Option<IsqType>,
        paged_attn_config: Option<PagedAttentionConfig>,
    ) -> Result<Arc<Mutex<dyn Pipeline + Send + Sync>>> {
        // Initialize the process-wide HF cache exactly once. A configured
        // custom cache path wins; otherwise the default cache is used. Note
        // that `get_or_init` means a cache set by an earlier call sticks.
        let cache = self
            .hf_cache_path
            .clone()
            .map(Cache::new)
            .unwrap_or_default();
        GLOBAL_HF_CACHE.get_or_init(|| cache);

        // Resolve paths to all model files (config, weights, tokenizer, ...);
        // the final flag tells the macro whether a UQFF load is intended.
        let paths: anyhow::Result<Box<dyn ModelPaths>> = get_paths!(
            LocalModelPaths,
            &token_source,
            revision.clone(),
            self,
            None,
            None,
            silent,
            self.config.from_uqff.is_some()
        );
        // If a UQFF source is configured, resolve its concrete file path now
        // and cache it for `load_model_from_path`.
        if let Some(from_uqff) = self.config.from_uqff.clone() {
            *self.from_uqff.write().unwrap() = Some(get_uqff_paths!(&from_uqff, self, silent));
        }
        // Stash the token source and revision so later loading steps can
        // reuse them (the `Loader` API only passes them to this method).
        *self
            .token_source
            .write()
            .expect("Failed to write to token source") = Some(token_source);
        *self.revision.write().expect("Failed to write to revision") = revision;
        // Note: `paths?` surfaces any path-resolution error only here, after
        // the token source/revision have been recorded.
        self.load_model_from_path(
            &paths?,
            dtype,
            device,
            silent,
            mapper,
            in_situ_quant,
            paged_attn_config,
        )
    }
293
294    #[allow(clippy::type_complexity, clippy::too_many_arguments)]
295    fn load_model_from_path(
296        &self,
297        paths: &Box<dyn ModelPaths>,
298        dtype: &dyn TryIntoDType,
299        device: &Device,
300        silent: bool,
301        mut mapper: DeviceMapSetting,
302        in_situ_quant: Option<IsqType>,
303        mut paged_attn_config: Option<PagedAttentionConfig>,
304    ) -> Result<Arc<Mutex<dyn Pipeline + Send + Sync>>> {
305        let config = std::fs::read_to_string(paths.get_config_filename())?;
306
307        // Apply default prompt size here
308        let prompt_chunksize = self
309            .config
310            .prompt_chunksize
311            .unwrap_or(DEFAULT_PROMPT_CHUNK_SIZE.try_into().unwrap())
312            .get();
313
314        info!("Prompt chunk size is {prompt_chunksize}.",);
315
316        let use_nccl = mistralrs_quant::distributed::use_nccl();
317
318        let available_devices = if let Ok(payload) = env::var(distributed::IS_DAEMON_FLAG) {
319            let payload: WorkerTransferData = serde_json::from_str(&payload)?;
320            let WorkerTransferData::Init { id: _, worker_rank } = payload;
321            vec![candle_core::Device::new_cuda(worker_rank + 1)?]
322        } else if use_nccl {
323            vec![candle_core::Device::new_cuda(0)?]
324        } else {
325            device_map::get_all_similar_devices(device)?
326        };
327        let device = if use_nccl {
328            available_devices[0].clone()
329        } else {
330            device.clone()
331        };
332
333        // If auto, convert to Map if not using nccl
334        if use_nccl {
335            mapper = DeviceMapSetting::DummyNccl {
336                nm_device: available_devices[0].clone(),
337            };
338        } else if let DeviceMapSetting::Auto(params) = mapper.clone() {
339            // Initial dtype
340            let dtype = dtype.try_into_dtype(&available_devices.iter().collect::<Vec<_>>())?;
341
342            // ISQ or UQFF: quantized path
343            // Match logic below where UQFF has priority
344            let (layer_sizes_in_bytes, non_mapped_size_in_bytes, total_model_size_in_bytes) =
345                if let Some(serialized) = &*self.from_uqff.read().unwrap() {
346                    let weight_pack_factor = {
347                        let ser_artifacts = unsafe {
348                            candle_core::safetensors::MmapedSafetensors::new(serialized)?
349                        };
350                        let mut total_pack_factors = 0;
351                        let total_tensors = ser_artifacts.tensors().len();
352                        for (_, artifact) in ser_artifacts.tensors() {
353                            let artifact = artifact.data();
354                            // NOTE(EricLBuehler): isq type is ALWAYS byte 4 (5th) of the tensor.
355                            let isq_type = artifact[mistralrs_quant::UQFF_QUANT_TYPE_OFFSET];
356                            let pack_factor = match QuantizedSerdeType::try_from(isq_type as usize)?
357                            {
358                                QuantizedSerdeType::Hqq => {
359                                    HqqLayer::get_isq_type_from_uqff(Cow::Borrowed(artifact))?
360                                        .pack_factor(dtype)
361                                }
362                                QuantizedSerdeType::Gguf => {
363                                    GgufMatMul::get_isq_type_from_uqff(Cow::Borrowed(artifact))?
364                                        .pack_factor(dtype)
365                                }
366                                QuantizedSerdeType::Fp8 => IsqType::F8E4M3.pack_factor(dtype),
367                                QuantizedSerdeType::Unquant => 1,
368                                QuantizedSerdeType::Afq => {
369                                    AfqLayer::get_isq_type_from_uqff(Cow::Borrowed(artifact))?
370                                        .pack_factor(dtype)
371                                }
372                            };
373                            total_pack_factors += pack_factor;
374                        }
375
376                        total_pack_factors / total_tensors
377                    };
378
379                    let layer_sizes_in_bytes =
380                        self.inner
381                            .layer_sizes_in_bytes(&config, dtype, weight_pack_factor)?;
382                    let non_mapped_size_in_bytes =
383                        self.inner
384                            .non_mapped_size_in_bytes(&config, dtype, weight_pack_factor)?;
385                    let layer_sizes_sum = layer_sizes_in_bytes.iter().sum::<usize>();
386                    (
387                        layer_sizes_in_bytes,
388                        non_mapped_size_in_bytes,
389                        layer_sizes_sum + non_mapped_size_in_bytes,
390                    )
391                } else if let Some(isq) = in_situ_quant {
392                    let weight_pack_factor = isq.pack_factor(dtype);
393                    let layer_sizes_in_bytes =
394                        self.inner
395                            .layer_sizes_in_bytes(&config, dtype, weight_pack_factor)?;
396                    let non_mapped_size_in_bytes =
397                        self.inner
398                            .non_mapped_size_in_bytes(&config, dtype, weight_pack_factor)?;
399                    let layer_sizes_sum = layer_sizes_in_bytes.iter().sum::<usize>();
400                    (
401                        layer_sizes_in_bytes,
402                        non_mapped_size_in_bytes,
403                        layer_sizes_sum + non_mapped_size_in_bytes,
404                    )
405                } else {
406                    let layer_sizes_in_bytes =
407                        self.inner.layer_sizes_in_bytes(&config, dtype, 1)?;
408                    let non_mapped_size_in_bytes =
409                        self.inner.non_mapped_size_in_bytes(&config, dtype, 1)?;
410                    let layer_sizes_sum = layer_sizes_in_bytes.iter().sum::<usize>();
411                    (
412                        layer_sizes_in_bytes,
413                        non_mapped_size_in_bytes,
414                        layer_sizes_sum + non_mapped_size_in_bytes,
415                    )
416                };
417
418            let new = self.inner.get_device_layers(
419                &config,
420                self.inner.num_layers(&config)?,
421                layer_sizes_in_bytes,
422                non_mapped_size_in_bytes,
423                total_model_size_in_bytes,
424                &available_devices,
425                dtype,
426                &params,
427                prompt_chunksize,
428                paged_attn_config.as_ref(),
429            )?;
430            mapper = DeviceMapSetting::Map(new);
431        }
432
433        let pipeline_mapper = mapper.into_mapper(
434            self.inner.num_layers(&config)?,
435            &device,
436            self.config.topology.as_ref(),
437        )?;
438        let mapper = mapper.into_mapper(
439            self.inner.num_layers(&config)?,
440            &device,
441            self.config.topology.as_ref(),
442        )?;
443        let mut layer_devices = Vec::new();
444        for layer in 0..self.inner.num_layers(&config)? {
445            let device = mapper.device_for(layer, false).cloned();
446            layer_devices.push(device);
447        }
448        let dtype = mapper.get_min_dtype(dtype)?;
449
450        // TODO: PagedAttention is not supported with CPU for now.
451        // This check is not really necessary because `get_device_layers` should prevent it.
452        let mapping_uses_cpu = mapper.get_unique_devices().iter().any(Device::is_cpu);
453        if mapping_uses_cpu {
454            warn!("Device mapping contains a mix of GPU and CPU. There is no CPU support for PagedAttention, disabling PagedAttention.");
455            paged_attn_config = None;
456        }
457
458        info!(
459            "Model config: {:?}",
460            self.inner
461                .get_config_repr(&config, self.config.use_flash_attn)?
462        );
463
464        let mut loading_isq = in_situ_quant.is_some() || self.config.from_uqff.is_some();
465        if let Some(ref topology) = self.config.topology {
466            loading_isq |= topology
467                .0
468                .iter()
469                .any(|layer| layer.as_ref().is_some_and(|layer| layer.isq.is_some()));
470        }
471
472        if self.config.imatrix.is_some() && self.config.calibration_file.is_some() {
473            anyhow::bail!(
474                "`imatrix` and `calibration_file` were both specified, this is not allowed."
475            );
476        }
477
478        // Load onto the regular device if not using isq or if the calibration file is specified
479        let load_device = if !loading_isq || self.config.calibration_file.is_some() {
480            loading_isq = false;
481            device.clone()
482        } else {
483            Device::Cpu
484        };
485
486        let is_xlora = self.kind.is_adapted_and(|a| a.is_x_lora());
487
488        let attention_mechanism = if paged_attn_config.is_some() {
489            AttentionImplementation::PagedAttention
490        } else {
491            AttentionImplementation::Eager
492        };
493
494        let multi_progress = Arc::new(MultiProgress::new());
495
496        let mut model = if use_nccl {
497            let (mapper, sharded_vb) = distributed::prepare_distributed_mapper(
498                dtype,
499                &device,
500                &load_device,
501                &available_devices,
502                &config,
503                loading_isq,
504                self.config.from_uqff.is_some(),
505                self.config.organization,
506                &*self.inner,
507                paths.as_ref(),
508            )?;
509
510            // Special case for where things can be more optimially loaded.
511            match self.kind {
512                ModelKind::Normal => normal_model_loader_sharded!(
513                    sharded_vb,
514                    config,
515                    self.inner,
516                    self.config.use_flash_attn,
517                    mapper,
518                    loading_isq,
519                    device.clone(),
520                    attention_mechanism,
521                    multi_progress.clone(),
522                ),
523                ModelKind::Adapter {
524                    adapter: AdapterKind::XLora,
525                } => xlora_model_loader!(
526                    paths,
527                    Some(dtype),
528                    &load_device,
529                    layer_devices.clone(),
530                    config,
531                    self.inner,
532                    self.config.use_flash_attn,
533                    silent,
534                    mapper,
535                    loading_isq,
536                    device.clone(),
537                    multi_progress.clone(),
538                ),
539                ModelKind::Adapter {
540                    adapter: AdapterKind::Lora,
541                } => lora_model_loader!(
542                    paths,
543                    Some(dtype),
544                    &load_device,
545                    layer_devices.clone(),
546                    config,
547                    self.inner,
548                    self.config.use_flash_attn,
549                    silent,
550                    mapper,
551                    loading_isq,
552                    self.config.from_uqff.is_some(),
553                    device.clone(),
554                    attention_mechanism,
555                    matches!(self.config.organization, IsqOrganization::MoeExpertsOnly),
556                    multi_progress.clone(),
557                ),
558                _ => unreachable!(),
559            }
560        } else {
561            match self.kind {
562                ModelKind::Normal => normal_model_loader!(
563                    paths,
564                    Some(dtype),
565                    &load_device,
566                    layer_devices.clone(),
567                    config,
568                    self.inner,
569                    self.config.use_flash_attn,
570                    silent,
571                    mapper,
572                    loading_isq,
573                    self.config.from_uqff.is_some(),
574                    device.clone(),
575                    attention_mechanism,
576                    matches!(self.config.organization, IsqOrganization::MoeExpertsOnly),
577                    multi_progress.clone(),
578                ),
579                ModelKind::Adapter {
580                    adapter: AdapterKind::XLora,
581                } => xlora_model_loader!(
582                    paths,
583                    Some(dtype),
584                    &load_device,
585                    layer_devices.clone(),
586                    config,
587                    self.inner,
588                    self.config.use_flash_attn,
589                    silent,
590                    mapper,
591                    loading_isq,
592                    device.clone(),
593                    multi_progress.clone(),
594                ),
595                ModelKind::Adapter {
596                    adapter: AdapterKind::Lora,
597                } => lora_model_loader!(
598                    paths,
599                    Some(dtype),
600                    &load_device,
601                    layer_devices.clone(),
602                    config,
603                    self.inner,
604                    self.config.use_flash_attn,
605                    silent,
606                    mapper,
607                    loading_isq,
608                    self.config.from_uqff.is_some(),
609                    device.clone(),
610                    attention_mechanism,
611                    matches!(self.config.organization, IsqOrganization::MoeExpertsOnly),
612                    multi_progress.clone(),
613                ),
614                _ => unreachable!(),
615            }
616        };
617
618        let tokenizer = get_tokenizer(paths.get_tokenizer_filename(), None)?;
619        let gen_conf: Option<GenerationConfig> = paths.get_gen_conf_filename().map(|f| {
620            serde_json::from_str(&fs::read_to_string(f).unwrap())
621                .expect("bos_token_id/eos_token_id missing in generation_config.json")
622        });
623
624        let chat_template = get_chat_template(
625            paths,
626            &self.jinja_explicit,
627            &paths
628                .get_chat_template_explicit()
629                .as_ref()
630                .map(|x| x.to_string_lossy().to_string())
631                .clone(),
632            &self.chat_template,
633            None,
634        );
635
636        if let Some(calibration_file) = &self.config.calibration_file {
637            let calibration_data = std::fs::read_to_string(calibration_file)?;
638            // Tokenize, don't add bos yet
639            let tokens = tokenizer
640                .encode_fast(calibration_data, false)
641                .map_err(anyhow::Error::msg)?
642                .get_ids()
643                .to_vec();
644            info!(
645                "Collecting imatrix from calibration file `{}` of {} tokens.",
646                calibration_file.display(),
647                tokens.len()
648            );
649            let bos_toks = chat_template.bos_tok().map(|b| vec![b]).unwrap_or_default();
650            let bos_tok_id = tokenizer
651                .token_to_id(&bos_toks[0])
652                .expect("Somehow the bos token is not present.");
653
654            match self.config.organization {
655                IsqOrganization::Default => model.begin_track_stats()?,
656                IsqOrganization::MoeExpertsOnly => model.begin_track_stats_moe_experts_only()?,
657            }
658
659            const CHUNK_SIZE: usize = 1024;
660            let n_chunks = tokens.len().div_ceil(CHUNK_SIZE);
661            let start = Instant::now();
662            for (i, chunk) in tokens.chunks(CHUNK_SIZE).enumerate() {
663                let chunk = [vec![bos_tok_id], chunk.to_vec()].concat();
664                let chunk_len = chunk.len();
665
666                let start = Instant::now();
667                let inputs = make_prompt_chunk(
668                    0,
669                    vec![chunk],
670                    &[0],
671                    &load_device,
672                    None,
673                    false,
674                    None,
675                    Some(pipeline_mapper.as_ref()),
676                )?;
677
678                model.forward(
679                    &inputs.input.to_device(model.device())?,
680                    &inputs.positions,
681                    inputs.context_lens.clone(),
682                    inputs.position_ids.clone(),
683                    None,
684                    &inputs.flash_meta.clone(),
685                )?;
686
687                match model.cache_mut() {
688                    EitherCache::Full(full) => {
689                        for layer in &mut *full.lock() {
690                            *layer = None
691                        }
692                    }
693                    EitherCache::Normal(normal) => {
694                        for layer in &mut *normal.lock().unwrap().0 {
695                            layer.reset();
696                        }
697                    }
698                }
699
700                let end = Instant::now();
701                info!(
702                    "Processed chunk {}/{n_chunks} ({chunk_len} tokens), {:.2}s",
703                    i + 1,
704                    end.duration_since(start).as_secs_f32()
705                );
706            }
707            load_device.synchronize()?;
708            let end = Instant::now();
709            info!(
710                "Finished collecting imatrix in {:.2}s",
711                end.duration_since(start).as_secs_f32()
712            );
713        }
714
715        if (in_situ_quant.is_some() || self.config.topology.is_some())
716            && self.config.from_uqff.is_none()
717        {
718            let imatrix_source = match (
719                self.config.imatrix.as_ref(),
720                self.config.calibration_file.is_some(),
721            ) {
722                (None, false) => None,
723                (Some(file), false) => Some(ImatrixDataSource::File(file)),
724                (None, true) => Some(ImatrixDataSource::Collected),
725                (Some(_), true) => unreachable!(),
726            };
727
728            info!("Applying ISQ to all ranks.");
729
730            let multi_progress = Arc::new(MultiProgress::new());
731
732            model.quantize(
733                in_situ_quant,
734                model.device().clone(),
735                self.config.topology.as_ref(),
736                silent,
737                imatrix_source,
738                self.config.organization,
739                self.config.write_uqff.as_ref(),
740                UqffFullSer {
741                    tokenizer: &tokenizer,
742                    template_filename: paths.get_template_filename(),
743                    generation_config: paths.get_gen_conf_filename(),
744                    config: config.clone(),
745                    processor_filename: &None,
746                    preprocessor_filename: &None,
747                },
748                multi_progress.clone(),
749            )?;
750        } else if let Some(from_uqff) = &*self.from_uqff.read().unwrap() {
751            model.load_from_artifacts(
752                device.clone(),
753                self.config.topology.as_ref(),
754                silent,
755                from_uqff,
756            )?;
757        }
758
759        let paged_attn_config = if matches!(
760            self.kind,
761            ModelKind::Adapter {
762                adapter: AdapterKind::XLora
763            }
764        ) {
765            warn!(
766                "Adapter parallel_models do not currently support PagedAttention, running without"
767            );
768            None
769        } else {
770            paged_attn_config
771        };
772
773        let (cache_config, cache_engine) = if let Some(paged_attn_config) = paged_attn_config {
774            let cache_config = calculate_cache_config(
775                paged_attn_config.mem_gpu,
776                paged_attn_config.mem_cpu,
777                paged_attn_config.block_size,
778                dtype,
779                model.config(),
780                &device,
781                &pipeline_mapper
782                    .get_unique_devices()
783                    .into_iter()
784                    .map(Some)
785                    .collect::<Vec<_>>(),
786                silent,
787            )?;
788
789            let mut layer_devices = Vec::new();
790            for layer in 0..self.inner.num_layers(&config)? {
791                let device = model.get_layers().1.device_for(layer, false).cloned();
792                layer_devices.push(device);
793            }
794            let cache_engine = CacheEngine::new(
795                model.config(),
796                &cache_config,
797                dtype,
798                model.device(),
799                layer_devices.clone(),
800            )?;
801
802            (Some(cache_config), Some(cache_engine))
803        } else {
804            (None, None)
805        };
806
807        let max_seq_len = model.max_seq_len();
808        let tok_env = build_tok_env(tokenizer.clone());
809        let num_hidden_layers = match model.cache() {
810            EitherCache::Full(full) => full.lock().len(),
811            EitherCache::Normal(normal) => normal.lock().unwrap().0.len(),
812        };
813        let eos = calculate_eos_tokens(&chat_template, gen_conf, &tokenizer);
814        let sliding_window = model.config().sliding_window;
815        let model_metadata = Arc::new(model.config().clone());
816
817        Ok(Arc::new(Mutex::new(NormalPipeline {
818            model,
819            tokenizer: tokenizer.into(),
820            no_kv_cache: self.no_kv_cache,
821            chat_template: Arc::new(chat_template),
822            non_granular_state: self.tgt_non_granular_index.map(|tgt_non_granular_index| {
823                NonGranularState {
824                    non_granular_index: Arc::new(Mutex::new(0)),
825                    tgt_non_granular_index,
826                }
827            }),
828            model_id: self.model_id.clone(),
829            metadata: Arc::new(GeneralMetadata {
830                max_seq_len,
831                tok_env: Some(tok_env),
832                no_kv_cache: self.no_kv_cache,
833                no_prefix_cache: is_xlora,
834                num_hidden_layers,
835                eos_tok: eos,
836                kind: self.kind.clone(),
837                is_xlora,
838                activation_dtype: dtype,
839                sliding_window,
840                cache_config,
841                cache_engine,
842                prompt_chunksize: Some(NonZero::new(prompt_chunksize).unwrap()),
843                model_metadata: Some(model_metadata),
844            }),
845            topology: self.config.topology.clone(),
846            silent,
847            organization: self.config.organization,
848            template_filename: paths.get_template_filename().clone(),
849            generation_config: paths.get_gen_conf_filename().cloned(),
850            config,
851            imatrix: self.config.imatrix.clone(),
852            mapper: pipeline_mapper,
853        })))
854    }
855
856    fn get_id(&self) -> String {
857        self.model_id.clone()
858    }
859
860    fn get_kind(&self) -> ModelKind {
861        self.kind.clone()
862    }
863}
864
865impl PreProcessingMixin for NormalPipeline {
866    fn get_chat_template(&self) -> Option<Arc<ChatTemplate>> {
867        Some(self.chat_template.clone())
868    }
869    fn get_input_processor_config(&self) -> Option<Arc<dyn Any>> {
870        None
871    }
872}
873
874impl IsqPipelineMixin for NormalPipeline {
875    fn re_isq_model(&mut self, dtype: IsqType) -> Result<()> {
876        let device = self.device().clone();
877        let multi_progress = Arc::new(MultiProgress::new());
878        self.model.quantize(
879            Some(dtype),
880            device.clone(),
881            self.topology.as_ref(),
882            self.silent,
883            self.imatrix.as_ref().map(ImatrixDataSource::File),
884            self.organization,
885            None,
886            UqffFullSer {
887                tokenizer: &self.tokenizer,
888                template_filename: &self.template_filename,
889                generation_config: self.generation_config.as_ref(),
890                config: self.config.clone(),
891                processor_filename: &None,
892                preprocessor_filename: &None,
893            },
894            multi_progress.clone(),
895        )?;
896        Ok(())
897    }
898}
899
900impl CacheManagerMixin for NormalPipeline {
901    fn clone_in_cache(&self, seqs: &mut [&mut Sequence]) {
902        if matches!(self.model.cache(), EitherCache::Full(_)) {
903            FullCacheManager.clone_in_cache(self, seqs, false)
904        } else {
905            NormalCacheManager.clone_in_cache(self, seqs, false)
906        }
907    }
908    fn clone_out_cache(&self, seqs: &mut [&mut Sequence]) {
909        if matches!(self.model.cache(), EitherCache::Full(_)) {
910            FullCacheManager.clone_out_cache(self, seqs, false)
911        } else {
912            NormalCacheManager.clone_out_cache(self, seqs, false)
913        }
914    }
915    fn set_none_cache(
916        &self,
917        seqs: &mut [&mut Sequence],
918        reset_non_granular: bool,
919        modify_draft_cache: bool,
920        load_preallocated_cache: bool,
921    ) {
922        if matches!(self.model.cache(), EitherCache::Full(_)) {
923            FullCacheManager.set_none_cache(self, seqs, modify_draft_cache, false);
924        } else {
925            NormalCacheManager.set_none_cache(
926                self,
927                seqs,
928                modify_draft_cache,
929                load_preallocated_cache,
930            );
931        }
932        if reset_non_granular {
933            self.reset_non_granular_state()
934        }
935    }
936    fn cache(&self) -> &EitherCache {
937        self.model.cache()
938    }
939}
940
941impl MetadataMixin for NormalPipeline {
942    fn device(&self) -> Device {
943        self.model.device().clone()
944    }
945    fn tokenizer(&self) -> Option<Arc<Tokenizer>> {
946        Some(self.tokenizer.clone())
947    }
948    fn name(&self) -> String {
949        self.model_id.clone()
950    }
951    fn reset_non_granular_state(&self) {
952        if let Some(s) = self.non_granular_state.as_ref() {
953            *self.cache().full().get_scalings_cache() = None;
954            *get_mut_arcmutex!(s.non_granular_index) = 0;
955        }
956    }
957    fn get_metadata(&self) -> Arc<GeneralMetadata> {
958        self.metadata.clone()
959    }
960    fn device_mapper(&self) -> Option<&dyn DeviceMapper> {
961        Some(&*self.mapper)
962    }
963}
964
#[async_trait::async_trait]
impl Pipeline for NormalPipeline {
    /// Run one forward pass over a prepared batch.
    ///
    /// `inputs` must downcast to `ModelInputs`; failure to downcast is a
    /// wiring bug and panics. Dispatches to `forward` for plain models or
    /// `xlora_forward` for X-LoRA models, and returns either raw logits or
    /// causal-generation logits depending on `return_raw_logits`.
    fn forward_inputs(
        &mut self,
        inputs: Box<dyn Any>,
        return_raw_logits: bool,
    ) -> Result<ForwardInputsResult, candle_core::Error> {
        let ModelInputs {
            input_ids,
            input_ids_full,
            seqlen_offsets,
            seqlen_offsets_full,
            context_lens,
            position_ids,
            paged_attn_meta,
            flash_meta,
            flash_meta_full,
        } = *inputs.downcast().expect("Downcast failed.");
        let metadata = self.get_metadata();
        // Pair the cache engine with the PagedAttention metadata: both must be
        // present together or absent together; a mismatch means the scheduler
        // and pipeline configuration disagree, so we bail.
        let paged_attn_meta = match (&metadata.cache_engine, &paged_attn_meta) {
            (Some(cache_engine), Some(meta)) => Some((cache_engine, meta)),
            (Some(_), None) => {
                // This can happen if Rust-side user code is wrong
                candle_core::bail!("Forward step expected a PagedAttention input metadata. This was not provided, please ensure that the scheduler config is correctly configured for PagedAttention.")
            }
            (None, Some(_)) => {
                // This should never happen but we handle it anyway
                candle_core::bail!("Forward step got a PagedAttention input metadata but there is no cache engine. Please raise an issue.")
            }
            (None, None) => None,
        };
        // Metal builds wrap the forward pass in an autorelease pool so that
        // Objective-C temporaries created by the backend are freed promptly.
        #[cfg(feature = "metal")]
        let logits = objc::rc::autoreleasepool(|| -> candle_core::Result<Tensor> {
            match self.model.is_xlora() {
                false => {
                    // Materialize owned (kv_cache, metadata) pairs for the model.
                    let paged_attn_meta = paged_attn_meta
                        .as_ref()
                        .map(|meta| (meta.0.get_kv_cache().clone(), meta.1.clone()));

                    self.model.forward(
                        &input_ids,
                        &seqlen_offsets,
                        context_lens,
                        position_ids,
                        paged_attn_meta.as_ref().map(|(a, b)| (a.clone(), b)),
                        &flash_meta,
                    )
                }
                // X-LoRA path: uses the "full" input variants when available
                // (falling back to the regular ones) and takes no
                // PagedAttention metadata.
                true => self.model.xlora_forward(
                    &input_ids,
                    input_ids_full.as_ref().unwrap_or(&input_ids),
                    &seqlen_offsets,
                    seqlen_offsets_full.as_ref().unwrap_or(&seqlen_offsets),
                    self.no_kv_cache,
                    &self.non_granular_state,
                    context_lens,
                    position_ids,
                    &flash_meta,
                    flash_meta_full.as_ref().unwrap_or(&flash_meta),
                ),
            }
        })?;
        // Non-Metal builds: identical dispatch, without the autorelease pool.
        // Kept duplicated because the `cfg` split applies to the whole binding.
        #[cfg(not(feature = "metal"))]
        let logits = match self.model.is_xlora() {
            false => {
                // Materialize owned (kv_cache, metadata) pairs for the model.
                let paged_attn_meta = paged_attn_meta
                    .as_ref()
                    .map(|meta| (meta.0.get_kv_cache().clone(), meta.1.clone()));

                self.model.forward(
                    &input_ids,
                    &seqlen_offsets,
                    context_lens,
                    position_ids,
                    paged_attn_meta.as_ref().map(|(a, b)| (a.clone(), b)),
                    &flash_meta,
                )?
            }
            true => self.model.xlora_forward(
                &input_ids,
                input_ids_full.as_ref().unwrap_or(&input_ids),
                &seqlen_offsets,
                seqlen_offsets_full.as_ref().unwrap_or(&seqlen_offsets),
                self.no_kv_cache,
                &self.non_granular_state,
                context_lens,
                position_ids,
                &flash_meta,
                flash_meta_full.as_ref().unwrap_or(&flash_meta),
            )?,
        };
        if return_raw_logits {
            Ok(ForwardInputsResult::RawLogits { logits })
        } else {
            Ok(ForwardInputsResult::CausalGeneration { logits })
        }
    }
    /// Sample next tokens for each sequence from `logits` and append them,
    /// delegating to the shared `sample_and_add_toks` helper.
    async fn sample_causal_gen(
        &self,
        seqs: &mut [&mut Sequence],
        logits: Vec<Tensor>,
        prefix_cacher: &mut PrefixCacheManagerV2,
        disable_eos_stop: bool,
        rng: Arc<std::sync::Mutex<Isaac64Rng>>,
    ) -> Result<(), candle_core::Error> {
        sample_and_add_toks(self, seqs, logits, prefix_cacher, disable_eos_stop, rng).await
    }
    /// This pipeline serves text-only models.
    fn category(&self) -> ModelCategory {
        ModelCategory::Text
    }
}
1076
1077impl AnyMoePipelineMixin for NormalPipeline {
1078    fn amoe_finish_training(&mut self, gate_model_id: Option<String>) -> candle_core::Result<()> {
1079        self.model.finish_training(gate_model_id)
1080    }
1081    fn amoe_layer_vars(&self) -> Vec<Vec<Var>> {
1082        self.model.get_vars()
1083    }
1084    fn amoe_base_model_trainable_params(&self) -> usize {
1085        self.model.trainable_params()
1086    }
1087    fn amoe_take_cached_gating_outputs(&mut self) -> Vec<Tensor> {
1088        self.model.take_cached_gating_outputs()
1089    }
1090    fn amoe_create_layers(
1091        &mut self,
1092        model_ids: Vec<String>,
1093        token: &TokenSource,
1094        revision: Option<String>,
1095        match_regex: &str,
1096        config: crate::amoe::AnyMoeConfig,
1097        dtype: candle_core::DType,
1098        dev: &Device,
1099        (prefix, mlp): (String, String),
1100        layers: Vec<usize>,
1101        expert_type: AnyMoeExpertType,
1102        silent: bool,
1103        gate_model_id: Option<String>,
1104    ) -> candle_core::Result<()> {
1105        let mut vbs = Vec::new();
1106        // Precompile regex here
1107        let regex = Regex::new(match_regex).map_err(candle_core::Error::msg)?;
1108        for model_id in model_ids {
1109            let model_id_str = &model_id;
1110            let model_id = Path::new(&model_id);
1111
1112            let api = {
1113                let cache = GLOBAL_HF_CACHE.get().cloned().unwrap_or_default();
1114                let mut api = ApiBuilder::from_cache(cache)
1115                    .with_progress(!silent)
1116                    .with_token(get_token(token).map_err(candle_core::Error::msg)?);
1117                if let Ok(x) = std::env::var("HF_HUB_CACHE") {
1118                    api = api.with_cache_dir(x.into());
1119                }
1120                api.build().map_err(candle_core::Error::msg)?
1121            };
1122            let revision = revision.clone().unwrap_or("main".to_string());
1123            let api = api.repo(Repo::with_revision(
1124                model_id_str.clone(),
1125                RepoType::Model,
1126                revision.clone(),
1127            ));
1128
1129            let mut filenames = vec![];
1130            for rfilename in api_dir_list!(api, model_id).filter(|x| x.ends_with(".safetensors")) {
1131                filenames.push(api_get_file!(api, &rfilename, model_id));
1132            }
1133
1134            let regex = regex.clone();
1135            let match_regex_clone = match_regex.to_string();
1136            let layers_clone = layers.clone();
1137            let vb = from_mmaped_safetensors(
1138                filenames,
1139                vec![],
1140                Some(dtype),
1141                dev,
1142                vec![None],
1143                silent,
1144                None,
1145                move |key| {
1146                    if regex.is_match(&key) {
1147                        // Idx of the last char of the layer id, +1
1148                        // Assumes N.MLP
1149                        let last_layer_idx = key.find(&match_regex_clone).unwrap() - 1;
1150                        let first_layer_idx = key[..last_layer_idx].rfind('.').unwrap();
1151                        let layer_n = key[first_layer_idx + 1..last_layer_idx]
1152                            .parse::<usize>()
1153                            .unwrap();
1154                        layers_clone.contains(&layer_n) || layers_clone.is_empty()
1155                    } else {
1156                        false
1157                    }
1158                },
1159                Arc::new(|_| DeviceForLoadTensor::Base),
1160            )?;
1161            vbs.push(vb);
1162        }
1163
1164        let gate_vb = if let Some(gate_model_id) = gate_model_id {
1165            let model_id_str = &gate_model_id;
1166            let model_id = Path::new(&gate_model_id);
1167
1168            let api = {
1169                let cache = GLOBAL_HF_CACHE.get().cloned().unwrap_or_default();
1170                let mut api = ApiBuilder::from_cache(cache)
1171                    .with_progress(!silent)
1172                    .with_token(get_token(token).map_err(candle_core::Error::msg)?);
1173                if let Ok(x) = std::env::var("HF_HUB_CACHE") {
1174                    api = api.with_cache_dir(x.into());
1175                }
1176                api.build().map_err(candle_core::Error::msg)?
1177            };
1178            let revision = revision.clone().unwrap_or("main".to_string());
1179            let api = api.repo(Repo::with_revision(
1180                model_id_str.clone(),
1181                RepoType::Model,
1182                revision.clone(),
1183            ));
1184
1185            let mut gate_filenames = vec![];
1186            for rfilename in api_dir_list!(api, model_id).filter(|x| x.ends_with(".safetensors")) {
1187                gate_filenames.push(api_get_file!(api, &rfilename, model_id));
1188            }
1189            assert_eq!(
1190                gate_filenames.len(),
1191                1,
1192                "Gate model ID must contain only one .safetensors file"
1193            );
1194
1195            let vb = from_mmaped_safetensors(
1196                gate_filenames.clone(),
1197                vec![],
1198                Some(dtype),
1199                dev,
1200                vec![None],
1201                silent,
1202                None,
1203                |_| true,
1204                Arc::new(|_| DeviceForLoadTensor::Base),
1205            )?;
1206            info!(
1207                "Loaded gating layers from `{}`",
1208                gate_filenames[0].display()
1209            );
1210            Some(vb)
1211        } else {
1212            None
1213        };
1214
1215        self.model.create_anymoe_layers(
1216            vbs.clone(),
1217            config.clone(),
1218            (prefix.clone(), mlp.clone()),
1219            layers.clone(),
1220            expert_type.clone(),
1221            gate_vb.clone(),
1222        )?;
1223
1224        Ok(())
1225    }
1226    fn amoe_supported(&self) -> bool {
1227        self.model.amoe_supported()
1228    }
1229}