mistralrs_core/pipeline/
normal.rs

use super::isq::ImatrixDataSource;
use super::llg::build_llg_factory;
use super::{
    get_model_paths, get_xlora_paths, text_models_inputs_processor::ModelInputs, AdapterKind,
    CacheManager, GeneralMetadata, Loader, ModelKind, ModelPaths, NormalModel, NormalModelLoader,
    TokenSource,
};
use super::{
    AnyMoePipelineMixin, CacheManagerMixin, EitherCache, ForwardInputsResult, IsqOrganization,
    IsqPipelineMixin, MetadataMixin, ModelCategory, PreProcessingMixin,
};
use super::{
    AutoNormalLoader, DeepSeekV2Loader, DeepSeekV3Loader, GLM4Loader, Gemma2Loader, GemmaLoader,
    GptOssLoader, GraniteMoeHybridLoader, LlamaLoader, MistralLoader, MixtralLoader,
    NormalLoaderType, Phi2Loader, Phi3Loader, Phi3_5MoELoader, Qwen2Loader, Qwen3Loader,
    Qwen3MoELoader, SmolLm3Loader, Starcoder2Loader,
};
use crate::amoe::AnyMoeExpertType;
use crate::attention::ATTENTION_CHUNK_SIZE;
use crate::device_map::{self, DeviceMapper};
use crate::distributed::{self, WorkerTransferData};
use crate::kv_cache::{FullCacheManager, HybridCacheManager, NormalCacheManager};
use crate::lora::Ordering;
use crate::paged_attention::{calculate_cache_config, AttentionImplementation, CacheEngine};
use crate::pipeline::chat_template::{calculate_eos_tokens, GenerationConfig};
use crate::pipeline::isq::UqffFullSer;
use crate::pipeline::loaders::auto_device_map;
use crate::pipeline::loaders::QuantizationConfigShim;
use crate::pipeline::sampling::sample_and_add_toks;
use crate::pipeline::text_models_inputs_processor::make_prompt_chunk;
use crate::pipeline::{get_chat_template, Modalities, SupportedModality};
use crate::pipeline::{ChatTemplate, LocalModelPaths};
use crate::prefix_cacher::PrefixCacheManagerV2;
use crate::sequence::Sequence;
use crate::utils::tokenizer::get_tokenizer;
use crate::utils::varbuilder_utils::DeviceForLoadTensor;
use crate::utils::{
    progress::{new_multi_progress, ProgressScopeGuard},
    tokens::get_token,
    varbuilder_utils::from_mmaped_safetensors,
};
use crate::xlora_models::NonGranularState;
use crate::{
    api_dir_list, api_get_file, get_mut_arcmutex, get_paths, get_uqff_paths, lora_model_loader,
    normal_model_loader, normal_model_loader_sharded, xlora_model_loader, DeviceMapSetting,
    PagedAttentionConfig, Pipeline, Topology, TryIntoDType, GLOBAL_HF_CACHE,
};
use anyhow::Result;
use candle_core::{Device, Tensor, Var};
use hf_hub::Cache;
use hf_hub::{api::sync::ApiBuilder, Repo, RepoType};
use mistralrs_quant::log::once_log_info;
use mistralrs_quant::{
    AfqLayer, GgufMatMul, HqqLayer, ImmediateIsqOverride, IsqType, QuantizedSerdeType,
};
use rand_isaac::Isaac64Rng;
use regex_automata::meta::Regex;
use std::any::Any;
use std::borrow::Cow;
use std::path::{Path, PathBuf};
use std::str::FromStr;
use std::sync::{Arc, RwLock};
use std::time::Instant;
use std::{env, fs};
use tokenizers::Tokenizer;
use tokio::sync::Mutex;
use tracing::{info, warn};

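/// A fully loaded text-only pipeline built around a [`NormalModel`], together with its tokenizer,
/// chat template, device mapper, and the [`GeneralMetadata`] consumed by the engine. Instances
/// are constructed by [`NormalLoader`].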
pub struct NormalPipeline {
    model: Box<dyn NormalModel + Send + Sync>,
    tokenizer: Arc<Tokenizer>,
    no_kv_cache: bool,
    chat_template: Arc<ChatTemplate>,
    non_granular_state: Option<NonGranularState>,
    model_id: String,
    metadata: Arc<GeneralMetadata>,
    topology: Option<Topology>,
    silent: bool,
    organization: IsqOrganization,
    // For full UQFF serialization
    template_filename: Option<PathBuf>,
    generation_config: Option<PathBuf>,
    config: String,
    imatrix: Option<PathBuf>,
    mapper: Box<dyn DeviceMapper + Send + Sync>,
}

/// A loader for a "normal" (non-quantized) model.
pub struct NormalLoader {
    inner: Box<dyn NormalModelLoader>,
    model_id: String,
    config: NormalSpecificConfig,
    xlora_model_id: Option<String>,
    lora_adapter_ids: Option<Vec<String>>,
    kind: ModelKind,
    xlora_order: Option<Ordering>,
    no_kv_cache: bool,
    chat_template: Option<String>,
    tokenizer_json: Option<String>,
    tgt_non_granular_index: Option<usize>,
    token_source: RwLock<Option<TokenSource>>,
    revision: RwLock<Option<String>>,
    from_uqff: RwLock<Option<Vec<PathBuf>>>,
    jinja_explicit: Option<String>,
    hf_cache_path: Option<PathBuf>,
}

#[derive(Default)]
/// A builder for a loader for a "normal" (non-quantized) model.
pub struct NormalLoaderBuilder {
    model_id: Option<String>,
    config: NormalSpecificConfig,
    xlora_model_id: Option<String>,
    lora_adapter_ids: Option<Vec<String>>,
    kind: ModelKind,
    xlora_order: Option<Ordering>,
    no_kv_cache: bool,
    chat_template: Option<String>,
    tokenizer_json: Option<String>,
    tgt_non_granular_index: Option<usize>,
    jinja_explicit: Option<String>,
    hf_cache_path: Option<PathBuf>,
}

#[derive(Clone, Default)]
/// Config specific to loading a normal model.
pub struct NormalSpecificConfig {
    pub topology: Option<Topology>,
    pub organization: IsqOrganization,
    pub write_uqff: Option<PathBuf>,
    pub from_uqff: Option<Vec<PathBuf>>,
    pub imatrix: Option<PathBuf>,
    pub calibration_file: Option<PathBuf>,
    pub hf_cache_path: Option<PathBuf>,
    pub matformer_config_path: Option<PathBuf>,
    pub matformer_slice_name: Option<String>,
}

impl NormalLoaderBuilder {
    pub fn new(
        config: NormalSpecificConfig,
        chat_template: Option<String>,
        tokenizer_json: Option<String>,
        model_id: Option<String>,
        no_kv_cache: bool,
        jinja_explicit: Option<String>,
    ) -> Self {
        Self {
            config,
            chat_template,
            tokenizer_json,
            model_id,
            kind: ModelKind::Normal,
            jinja_explicit,
            no_kv_cache,
            ..Default::default()
        }
    }

    fn with_adapter(
        mut self,
        xlora_model_id: String,
        xlora_order: Ordering,
        no_kv_cache: bool,
        tgt_non_granular_index: Option<usize>,
    ) -> Self {
        self.xlora_model_id = Some(xlora_model_id);
        self.xlora_order = Some(xlora_order);
        self.no_kv_cache = no_kv_cache;
        self.tgt_non_granular_index = tgt_non_granular_index;
        self.model_id = if let Some(id) = self.model_id {
            Some(id)
        } else {
            info!(
                "Using adapter base model ID: `{}`",
                self.xlora_order.as_ref().unwrap().base_model_id
            );
            Some(self.xlora_order.as_ref().unwrap().base_model_id.clone())
        };
        self
    }

    pub fn with_xlora(
        mut self,
        xlora_model_id: String,
        xlora_order: Ordering,
        no_kv_cache: bool,
        tgt_non_granular_index: Option<usize>,
    ) -> Self {
        self.kind = ModelKind::Adapter {
            adapter: AdapterKind::XLora,
        };
        self.with_adapter(
            xlora_model_id,
            xlora_order,
            no_kv_cache,
            tgt_non_granular_index,
        )
    }

    pub fn with_lora(mut self, lora_adapter_ids: Vec<String>) -> Self {
        self.kind = ModelKind::Adapter {
            adapter: AdapterKind::Lora,
        };
        self.lora_adapter_ids = Some(lora_adapter_ids);
        self
    }

    pub fn hf_cache_path(mut self, hf_cache_path: PathBuf) -> Self {
        self.hf_cache_path = Some(hf_cache_path);
        self
    }

    /// If the loader type is not specified, it is automatically determined from the
    /// `architectures` array in the config.
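    ///
    /// # Example
    ///
    /// A minimal sketch (the model ID is a placeholder and error handling is elided):
    ///
    /// ```ignore
    /// let loader = NormalLoaderBuilder::new(
    ///     NormalSpecificConfig::default(),
    ///     None,                                     // chat template override
    ///     None,                                     // tokenizer.json override
    ///     Some("some-org/some-model".to_string()),  // model ID (placeholder)
    ///     false,                                    // keep the KV cache enabled
    ///     None,                                     // explicit Jinja template
    /// )
    /// .build(None)?; // `None` => auto-detect the architecture via `AutoNormalLoader`
    /// ```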
    pub fn build(self, loader_tp: Option<NormalLoaderType>) -> anyhow::Result<Box<dyn Loader>> {
        let loader: Box<dyn NormalModelLoader> = match loader_tp {
            Some(NormalLoaderType::Mistral) => Box::new(MistralLoader),
            Some(NormalLoaderType::Gemma) => Box::new(GemmaLoader),
            Some(NormalLoaderType::Llama) => Box::new(LlamaLoader),
            Some(NormalLoaderType::Mixtral) => Box::new(MixtralLoader),
            Some(NormalLoaderType::Phi2) => Box::new(Phi2Loader),
            Some(NormalLoaderType::Phi3) => Box::new(Phi3Loader),
            Some(NormalLoaderType::Qwen2) => Box::new(Qwen2Loader),
            Some(NormalLoaderType::Gemma2) => Box::new(Gemma2Loader),
            Some(NormalLoaderType::Starcoder2) => Box::new(Starcoder2Loader),
            Some(NormalLoaderType::Phi3_5MoE) => Box::new(Phi3_5MoELoader),
            Some(NormalLoaderType::DeepSeekV2) => Box::new(DeepSeekV2Loader),
            Some(NormalLoaderType::DeepSeekV3) => Box::new(DeepSeekV3Loader),
            Some(NormalLoaderType::Qwen3) => Box::new(Qwen3Loader),
            Some(NormalLoaderType::GLM4) => Box::new(GLM4Loader),
            Some(NormalLoaderType::Qwen3Moe) => Box::new(Qwen3MoELoader),
            Some(NormalLoaderType::SmolLm3) => Box::new(SmolLm3Loader),
            Some(NormalLoaderType::GraniteMoeHybrid) => Box::new(GraniteMoeHybridLoader),
            Some(NormalLoaderType::GptOss) => Box::new(GptOssLoader),
            None => Box::new(AutoNormalLoader),
        };
        Ok(Box::new(NormalLoader {
            inner: loader,
            model_id: self.model_id.unwrap(),
            config: self.config,
            xlora_model_id: self.xlora_model_id,
            lora_adapter_ids: self.lora_adapter_ids,
            kind: self.kind,
            xlora_order: self.xlora_order,
            no_kv_cache: self.no_kv_cache,
            chat_template: self.chat_template,
            tokenizer_json: self.tokenizer_json,
            tgt_non_granular_index: self.tgt_non_granular_index,
            jinja_explicit: self.jinja_explicit,
            token_source: RwLock::new(None),
            revision: RwLock::new(None),
            from_uqff: RwLock::new(None),
            hf_cache_path: self.hf_cache_path,
        }))
    }
}

impl Loader for NormalLoader {
    #[allow(clippy::type_complexity, clippy::too_many_arguments)]
    fn load_model_from_hf(
        &self,
        revision: Option<String>,
        token_source: TokenSource,
        dtype: &dyn TryIntoDType,
        device: &Device,
        silent: bool,
        mapper: DeviceMapSetting,
        in_situ_quant: Option<IsqType>,
        paged_attn_config: Option<PagedAttentionConfig>,
    ) -> Result<Arc<Mutex<dyn Pipeline + Send + Sync>>> {
        let _progress_guard = ProgressScopeGuard::new(silent);
        let cache = self
            .hf_cache_path
            .clone()
            .map(Cache::new)
            .unwrap_or_default();
        GLOBAL_HF_CACHE.get_or_init(|| cache);

        let paths: anyhow::Result<Box<dyn ModelPaths>> = get_paths!(
            LocalModelPaths,
            &token_source,
            revision.clone(),
            self,
            None,
            None,
            silent,
            self.config.from_uqff.is_some()
        );
        if let Some(from_uqff) = self.config.from_uqff.clone() {
            *self.from_uqff.write().unwrap() = Some(get_uqff_paths!(&from_uqff, self, silent));
        }
        *self
            .token_source
            .write()
            .expect("Failed to write to token source") = Some(token_source);
        *self.revision.write().expect("Failed to write to revision") = revision;
        self.load_model_from_path(
            &paths?,
            dtype,
            device,
            silent,
            mapper,
            in_situ_quant,
            paged_attn_config,
        )
    }

    #[allow(clippy::type_complexity, clippy::too_many_arguments)]
    fn load_model_from_path(
        &self,
        paths: &Box<dyn ModelPaths>,
        dtype: &dyn TryIntoDType,
        device: &Device,
        silent: bool,
        mut mapper: DeviceMapSetting,
        mut in_situ_quant: Option<IsqType>,
        mut paged_attn_config: Option<PagedAttentionConfig>,
    ) -> Result<Arc<Mutex<dyn Pipeline + Send + Sync>>> {
        let _progress_guard = ProgressScopeGuard::new(silent);
        let config = std::fs::read_to_string(paths.get_config_filename())?;

        if !self.inner.supports_paged_attention(&config)? {
            paged_attn_config = None;
        }

        info!("Prompt chunk size is {ATTENTION_CHUNK_SIZE}.");

        let use_nccl = mistralrs_quant::distributed::use_nccl();

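        // Device discovery: a distributed daemon worker pins itself to the CUDA device implied by
        // its rank, an NCCL head process starts from CUDA device 0, and the single-process path
        // enumerates all devices similar to the requested one so they can be auto-mapped.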
        let available_devices = if let Ok(payload) = env::var(distributed::IS_DAEMON_FLAG) {
            let payload: WorkerTransferData = serde_json::from_str(&payload)?;
            let WorkerTransferData::Init { id: _, worker_rank } = payload;
            vec![candle_core::Device::new_cuda(worker_rank + 1)?]
        } else if use_nccl {
            vec![candle_core::Device::new_cuda(0)?]
        } else {
            device_map::get_all_similar_devices(device)?
        };
        #[cfg(feature = "cuda")]
        for device in &available_devices {
            if let Device::Cuda(dev) = device {
                unsafe { dev.disable_event_tracking() };
            }
        }
        let device = if use_nccl || cfg!(feature = "ring") {
            available_devices[0].clone()
        } else {
            device.clone()
        };

        // If auto, convert to Map if not using nccl
        if use_nccl || cfg!(feature = "ring") {
            mapper = DeviceMapSetting::DummyNccl {
                nm_device: available_devices[0].clone(),
            };
        } else if let DeviceMapSetting::Auto(params) = mapper.clone() {
            // Initial dtype
            let dtype = dtype.try_into_dtype(&available_devices.iter().collect::<Vec<_>>())?;

            // Disable ISQ if we are loading a prequantized model.
            if QuantizationConfigShim::get_quant_config_pack_factor(&config, dtype)? != 1 {
                in_situ_quant = None;
            }

            // ISQ or UQFF: quantized path
            // Match logic below where UQFF has priority
            let (layer_sizes_in_bytes, non_mapped_size_in_bytes, total_model_size_in_bytes) =
                if let Some(serialized) = &*self.from_uqff.read().unwrap() {
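                    // Estimate an average weight pack factor by reading the quantization-type byte
                    // of every tensor in the UQFF artifacts; this feeds the per-layer size
                    // estimates used by the automatic device mapper below.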
                    let weight_pack_factor = {
                        let ser_artifacts = unsafe {
                            candle_core::safetensors::MmapedSafetensors::multi(serialized)?
                        };
                        let mut total_pack_factors = 0;
                        let total_tensors = ser_artifacts.tensors().len();
                        for (_, artifact) in ser_artifacts.tensors() {
                            let artifact = artifact.data();
                            // NOTE(EricLBuehler): isq type is ALWAYS byte 4 (5th) of the tensor.
                            let isq_type = artifact[mistralrs_quant::UQFF_QUANT_TYPE_OFFSET];
                            let pack_factor = match QuantizedSerdeType::try_from(isq_type as usize)?
                            {
                                QuantizedSerdeType::Hqq => {
                                    HqqLayer::get_isq_type_from_uqff(Cow::Borrowed(artifact))?
                                        .pack_factor(dtype)
                                }
                                QuantizedSerdeType::Gguf => {
                                    GgufMatMul::get_isq_type_from_uqff(Cow::Borrowed(artifact))?
                                        .pack_factor(dtype)
                                }
                                QuantizedSerdeType::Fp8 => IsqType::F8E4M3.pack_factor(dtype),
                                QuantizedSerdeType::Unquant => 1,
                                QuantizedSerdeType::Afq => {
                                    AfqLayer::get_isq_type_from_uqff(Cow::Borrowed(artifact))?
                                        .pack_factor(dtype)
                                }
                            };
                            total_pack_factors += pack_factor;
                        }

                        total_pack_factors / total_tensors
                    };

                    let layer_sizes_in_bytes = self.inner.layer_sizes_in_bytes(
                        &config,
                        dtype,
                        weight_pack_factor,
                        None,
                    )?;
                    let non_mapped_size_in_bytes = self.inner.non_mapped_size_in_bytes(
                        &config,
                        dtype,
                        weight_pack_factor,
                        None,
                    )?;
                    let layer_sizes_sum = layer_sizes_in_bytes.iter().sum::<usize>();
                    (
                        layer_sizes_in_bytes,
                        non_mapped_size_in_bytes,
                        layer_sizes_sum + non_mapped_size_in_bytes,
                    )
                } else if let Some(isq) = in_situ_quant {
                    let weight_pack_factor = isq.pack_factor(dtype);
                    let layer_sizes_in_bytes = self.inner.layer_sizes_in_bytes(
                        &config,
                        dtype,
                        weight_pack_factor,
                        None,
                    )?;
                    let non_mapped_size_in_bytes = self.inner.non_mapped_size_in_bytes(
                        &config,
                        dtype,
                        weight_pack_factor,
                        None,
                    )?;
                    let layer_sizes_sum = layer_sizes_in_bytes.iter().sum::<usize>();
                    (
                        layer_sizes_in_bytes,
                        non_mapped_size_in_bytes,
                        layer_sizes_sum + non_mapped_size_in_bytes,
                    )
                } else {
                    // Be sure to get the weight pack factor here; we might be loading a prequantized model.
                    let weight_pack_factor =
                        QuantizationConfigShim::get_quant_config_pack_factor(&config, dtype)?;
                    let layer_sizes_in_bytes = self.inner.layer_sizes_in_bytes(
                        &config,
                        dtype,
                        weight_pack_factor,
                        None,
                    )?;
                    let non_mapped_size_in_bytes = self.inner.non_mapped_size_in_bytes(
                        &config,
                        dtype,
                        weight_pack_factor,
                        None,
                    )?;
                    let layer_sizes_sum = layer_sizes_in_bytes.iter().sum::<usize>();
                    (
                        layer_sizes_in_bytes,
                        non_mapped_size_in_bytes,
                        layer_sizes_sum + non_mapped_size_in_bytes,
                    )
                };

            let new = auto_device_map::get_device_layers(
                &*self.inner,
                &config,
                self.inner.num_layers(&config)?,
                layer_sizes_in_bytes,
                non_mapped_size_in_bytes,
                total_model_size_in_bytes,
                &available_devices,
                dtype,
                &params,
                paged_attn_config.as_ref(),
            )?;
            mapper = DeviceMapSetting::Map(new);
        }

        let pipeline_mapper = mapper.into_mapper(
            self.inner.num_layers(&config)?,
            &device,
            self.config.topology.as_ref(),
        )?;
        let mapper = mapper.into_mapper(
            self.inner.num_layers(&config)?,
            &device,
            self.config.topology.as_ref(),
        )?;
        let mut layer_devices = Vec::new();
        for layer in 0..self.inner.num_layers(&config)? {
            let device = mapper.device_for(layer, false).cloned();
            layer_devices.push(device);
        }
        let dtype = mapper.get_min_dtype(dtype)?;

        // TODO: PagedAttention is not supported with CPU for now.
        // This check is not really necessary because `get_device_layers` should prevent it.
        let mapping_uses_cpu = mapper.get_unique_devices().iter().any(Device::is_cpu);
        if mapping_uses_cpu && paged_attn_config.is_some() {
            warn!("Device mapping contains a mix of GPU and CPU. There is no CPU support for PagedAttention, disabling PagedAttention.");
            paged_attn_config = None;
        }

        info!("Model config: {:?}", self.inner.get_config_repr(&config)?);
        if crate::using_flash_attn() {
            once_log_info("FlashAttention is enabled.");
        }

        let topology_overrides = self
            .config
            .topology
            .as_ref()
            .map(|topology| {
                topology
                    .pattern_overrides()
                    .into_iter()
                    .map(|(regex, layer)| ImmediateIsqOverride {
                        predicate: regex,
                        ty: layer.isq,
                        device: layer.device.clone(),
                    })
                    .collect::<Vec<_>>()
            })
            .unwrap_or_default();
        let has_override_isq = topology_overrides
            .iter()
            .any(|override_entry| override_entry.ty.is_some());
        let topology_requires_post_quant = self
            .config
            .topology
            .as_ref()
            .is_some_and(|topology| topology.requires_post_quantization());

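        // Immediate (on-load) ISQ applies quantization while the weights are being loaded rather
        // than in a separate post-load pass. It is only taken when no imatrix/calibration data is
        // involved, a CLI ISQ type was requested, and the target device is not CUDA.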
        let allow_immediate_cli = self.config.imatrix.is_none()
            && self.config.calibration_file.is_none()
            && !device.is_cuda()
            && in_situ_quant.is_some();

        let mut immediate_ty = None;
        let mut immediate_predicates = Vec::new();
        if allow_immediate_cli {
            immediate_ty = in_situ_quant;
            immediate_predicates =
                if matches!(self.config.organization, IsqOrganization::MoeExpertsOnly) {
                    self.inner.immediate_isq_predicates_moqe(&config)?
                } else {
                    self.inner.immediate_isq_predicates(&config)?
                };
            info!("Applying ISQ to {in_situ_quant:?}");
            if immediate_predicates.is_empty() {
                warn!("No predicates for this model and ISQ setting detected. ISQ will not be applied to any weights!");
            }
        }

        let use_immediate = allow_immediate_cli || has_override_isq;
        if use_immediate {
            mistralrs_quant::set_immediate_isq_with_overrides(
                immediate_ty,
                immediate_predicates.clone(),
                topology_overrides.clone(),
            );
        }

        // Logic for ISQ here: if there is no calibration data (i.e., no imatrix), allow immediate ISQ. Otherwise, fall back to the normal post-load quantization pass.
        let mut loading_isq = if use_immediate {
            false
        } else {
            in_situ_quant.is_some()
        };
        if self.config.imatrix.is_some() || self.config.calibration_file.is_some() {
            loading_isq = true;
        }
        loading_isq |= topology_requires_post_quant;

        if self.config.imatrix.is_some() && self.config.calibration_file.is_some() {
            anyhow::bail!(
                "`imatrix` and `calibration_file` were both specified, this is not allowed."
            );
        }

        // Load onto the regular device if not using isq or if the calibration file is specified
        let load_device = if !loading_isq || self.config.calibration_file.is_some() {
            loading_isq = false;
            device.clone()
        } else {
            Device::Cpu
        };

        let is_xlora = self.kind.is_adapted_and(|a| a.is_x_lora());

        let attention_mechanism = if paged_attn_config.is_some() {
            AttentionImplementation::PagedAttention
        } else {
            AttentionImplementation::Eager
        };

        let multi_progress = Arc::new(new_multi_progress());

        // Load matformer slicing config if provided
        let matformer_slicing_config = if let Some(matformer_path) =
            &self.config.matformer_config_path
        {
            use crate::matformer::{MatformerConfig, MatformerSliceConfig};
            info!("Loading Matformer config from {:?}", matformer_path);
            let config = Arc::new(MatformerConfig::from_file(matformer_path)?);

            if let Some(slice_name) = &self.config.matformer_slice_name {
                info!("Using Matformer slice: {}", slice_name);
                Some(MatformerSliceConfig::new(slice_name.clone(), config))
            } else {
                // If no slice name is provided but config exists, we'll need to handle this
                // For now, return None and let the model handle the default slice selection
                warn!("Matformer config loaded but no slice name specified. Models will use their default slice.");
                None
            }
        } else {
            None
        };

        let mut model = if use_nccl || cfg!(feature = "ring") {
            let (mapper, sharded_vb) = distributed::prepare_distributed_mapper(
                dtype,
                &device,
                &available_devices,
                silent,
                &config,
                loading_isq,
                self.config.from_uqff.is_some(),
                self.config.organization,
                &*self.inner,
                paths.as_ref(),
            )?;

            // Special case for where things can be more optimally loaded.
            match self.kind {
                ModelKind::Normal => normal_model_loader_sharded!(
                    sharded_vb,
                    config,
                    self.inner,
                    mapper,
                    loading_isq,
                    device.clone(),
                    attention_mechanism,
                    multi_progress.clone(),
                    matformer_slicing_config.clone(),
                ),
                ModelKind::Adapter {
                    adapter: AdapterKind::XLora,
                } => xlora_model_loader!(
                    paths,
                    Some(dtype),
                    &load_device,
                    layer_devices.clone(),
                    config,
                    self.inner,
                    silent,
                    mapper,
                    loading_isq,
                    device.clone(),
                    multi_progress.clone(),
                    matformer_slicing_config.clone(),
                ),
                ModelKind::Adapter {
                    adapter: AdapterKind::Lora,
                } => lora_model_loader!(
                    paths,
                    Some(dtype),
                    &load_device,
                    layer_devices.clone(),
                    config,
                    self.inner,
                    silent,
                    mapper,
                    loading_isq,
                    self.config.from_uqff.is_some(),
                    device.clone(),
                    attention_mechanism,
                    matches!(self.config.organization, IsqOrganization::MoeExpertsOnly),
                    multi_progress.clone(),
                    matformer_slicing_config.clone(),
                ),
                _ => unreachable!(),
            }
        } else {
            match self.kind {
                ModelKind::Normal => normal_model_loader!(
                    paths,
                    Some(dtype),
                    &load_device,
                    layer_devices.clone(),
                    config,
                    self.inner,
                    silent,
                    mapper,
                    loading_isq,
                    self.config.from_uqff.is_some(),
                    device.clone(),
                    attention_mechanism,
                    matches!(self.config.organization, IsqOrganization::MoeExpertsOnly),
                    multi_progress.clone(),
                    matformer_slicing_config.clone(),
                ),
                ModelKind::Adapter {
                    adapter: AdapterKind::XLora,
                } => xlora_model_loader!(
                    paths,
                    Some(dtype),
                    &load_device,
                    layer_devices.clone(),
                    config,
                    self.inner,
                    silent,
                    mapper,
                    loading_isq,
                    device.clone(),
                    multi_progress.clone(),
                    matformer_slicing_config.clone(),
                ),
                ModelKind::Adapter {
                    adapter: AdapterKind::Lora,
                } => lora_model_loader!(
                    paths,
                    Some(dtype),
                    &load_device,
                    layer_devices.clone(),
                    config,
                    self.inner,
                    silent,
                    mapper,
                    loading_isq,
                    self.config.from_uqff.is_some(),
                    device.clone(),
                    attention_mechanism,
                    matches!(self.config.organization, IsqOrganization::MoeExpertsOnly),
                    multi_progress.clone(),
                    matformer_slicing_config.clone(),
                ),
                _ => unreachable!(),
            }
        };

        let tokenizer = get_tokenizer(paths.get_tokenizer_filename(), None)?;
        let gen_conf: Option<GenerationConfig> = paths.get_gen_conf_filename().and_then(|f| {
            match serde_json::from_str::<GenerationConfig>(&fs::read_to_string(f).unwrap()) {
                Ok(conf) => Some(conf),
                Err(e) => {
                    warn!("Failed to parse generation_config.json: {}", e);
                    None
                }
            }
        });

        let chat_template_explicit = paths
            .get_chat_template_explicit()
            .as_ref()
            .map(|x| x.to_string_lossy().to_string());
        let chat_template = get_chat_template(
            paths,
            self.jinja_explicit.as_ref(),
            chat_template_explicit.as_ref(),
            self.chat_template.as_ref(),
            None,
        );

        if let Some(calibration_file) = &self.config.calibration_file {
            let calibration_data = std::fs::read_to_string(calibration_file)?;
            // Tokenize, don't add bos yet
            let tokens = tokenizer
                .encode_fast(calibration_data, false)
                .map_err(anyhow::Error::msg)?
                .get_ids()
                .to_vec();
            info!(
                "Collecting imatrix from calibration file `{}` of {} tokens.",
                calibration_file.display(),
                tokens.len()
            );
            let bos_toks = chat_template.bos_tok().map(|b| vec![b]).unwrap_or_default();
            let bos_tok_id = tokenizer
                .token_to_id(&bos_toks[0])
                .expect("Somehow the bos token is not present.");

            match self.config.organization {
                IsqOrganization::Default => model.begin_track_stats()?,
                IsqOrganization::MoeExpertsOnly => model.begin_track_stats_moe_experts_only()?,
            }

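            // Run the calibration text through the model in fixed-size chunks (each prefixed with
            // the BOS token) so the tracked statistics can form the imatrix; the KV cache is
            // cleared between chunks since only activation statistics are needed.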
            const CHUNK_SIZE: usize = 1024;
            let n_chunks = tokens.len().div_ceil(CHUNK_SIZE);
            let start = Instant::now();
            for (i, chunk) in tokens.chunks(CHUNK_SIZE).enumerate() {
                let chunk = [vec![bos_tok_id], chunk.to_vec()].concat();
                let chunk_len = chunk.len();

                let start = Instant::now();
                let inputs = make_prompt_chunk(
                    0,
                    vec![&chunk],
                    &[0],
                    &load_device,
                    None,
                    false,
                    None,
                    Some(pipeline_mapper.as_ref()),
                )?;

                model.forward(
                    &inputs.input.to_device(model.device())?,
                    &inputs.positions,
                    inputs.context_lens.clone(),
                    inputs.position_ids.clone(),
                    None,
                    &inputs.flash_meta.clone(),
                )?;

                match model.cache_mut() {
                    EitherCache::Full(full) => {
                        for layer in &mut *full.lock() {
                            *layer = None
                        }
                    }
                    EitherCache::Normal(normal) => {
                        for layer in &mut *normal.lock().unwrap().0 {
                            layer.reset();
                        }
                    }
                    EitherCache::Hybrid(hybrid) => {
                        hybrid.lock().unwrap().reset();
                    }
                }

                let end = Instant::now();
                info!(
                    "Processed chunk {}/{n_chunks} ({chunk_len} tokens), {:.2}s",
                    i + 1,
                    end.duration_since(start).as_secs_f32()
                );
            }
            load_device.synchronize()?;
            let end = Instant::now();
            info!(
                "Finished collecting imatrix in {:.2}s",
                end.duration_since(start).as_secs_f32()
            );
        }

        // Decide whether to serialize to UQFF and/or run a post-load quantization pass.
        let should_serialize = self.config.write_uqff.is_some();
        let should_quantize_pass = loading_isq;

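        // Two mutually exclusive paths below: either quantize/serialize the freshly loaded
        // weights, or (when loading from UQFF) restore the already-quantized tensors from the
        // serialized artifacts.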
        if (should_quantize_pass || should_serialize) && self.config.from_uqff.is_none() {
            let imatrix_source = if should_quantize_pass {
                match (
                    self.config.imatrix.as_ref(),
                    self.config.calibration_file.is_some(),
                ) {
                    (None, false) => None,
                    (Some(file), false) => Some(ImatrixDataSource::File(file)),
                    (None, true) => Some(ImatrixDataSource::Collected),
                    (Some(_), true) => unreachable!(),
                }
            } else {
                None
            };

            if should_quantize_pass {
                info!("Applying ISQ to all ranks.");
            } else {
                info!("Serializing existing ISQ tensors without additional quantization.");
            }

            let multi_progress = Arc::new(new_multi_progress());

            model.quantize(
                in_situ_quant,
                model.device().clone(),
                self.config.topology.as_ref(),
                silent,
                imatrix_source,
                self.config.organization,
                should_quantize_pass,
                self.config.write_uqff.as_ref(),
                UqffFullSer {
                    tokenizer: &tokenizer,
                    template_filename: paths.get_template_filename(),
                    generation_config: paths.get_gen_conf_filename(),
                    config: config.clone(),
                    processor_filename: &None,
                    preprocessor_filename: &None,
                    modules: None,
                    module_paths: None,
                },
                multi_progress.clone(),
            )?;
        } else if let Some(from_uqff) = &*self.from_uqff.read().unwrap() {
            model.load_from_artifacts(
                device.clone(),
                self.config.topology.as_ref(),
                silent,
                from_uqff,
            )?;
        }

        let paged_attn_config = if matches!(
            self.kind,
            ModelKind::Adapter {
                adapter: AdapterKind::XLora
            }
        ) {
            warn!(
                "Adapter models do not currently support PagedAttention, running without"
            );
            None
        } else {
            paged_attn_config
        };

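        // With PagedAttention enabled, size the block-based KV cache from the configured GPU
        // memory budget and build a CacheEngine spanning each mapped layer device.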
        let (cache_config, cache_engine) = if let Some(paged_attn_config) = paged_attn_config {
            let cache_config = calculate_cache_config(
                paged_attn_config.mem_gpu,
                paged_attn_config.block_size,
                dtype,
                paged_attn_config.cache_type,
                model.config(),
                &device,
                &pipeline_mapper
                    .get_unique_devices()
                    .into_iter()
                    .map(Some)
                    .collect::<Vec<_>>(),
                silent,
            )?;

            let mut layer_devices = Vec::new();
            for layer in 0..self.inner.num_layers(&config)? {
                let device = model.get_layers().1.device_for(layer, false).cloned();
                layer_devices.push(device);
            }
            let cache_engine = CacheEngine::new(
                model.config(),
                &cache_config,
                dtype,
                model.device(),
                layer_devices.clone(),
            )?;

            (Some(cache_config), Some(cache_engine))
        } else {
            (None, None)
        };

        let max_seq_len = model.max_seq_len();
        let llg_factory = build_llg_factory(tokenizer.clone())?;
        let num_hidden_layers = match model.cache() {
            EitherCache::Full(full) => full.lock().len(),
            EitherCache::Normal(normal) => normal.lock().unwrap().0.len(),
            EitherCache::Hybrid(hybrid) => hybrid.lock().unwrap().num_layers(),
        };
        let eos = calculate_eos_tokens(&chat_template, gen_conf, &tokenizer);
        let sliding_window = model.config().sliding_window;
        let model_metadata = Arc::new(model.config().clone());

        Ok(Arc::new(Mutex::new(NormalPipeline {
            model,
            tokenizer: tokenizer.into(),
            no_kv_cache: self.no_kv_cache,
            chat_template: Arc::new(chat_template),
            non_granular_state: self.tgt_non_granular_index.map(|tgt_non_granular_index| {
                NonGranularState {
                    non_granular_index: Arc::new(Mutex::new(0)),
                    tgt_non_granular_index,
                }
            }),
            model_id: self.model_id.clone(),
            metadata: Arc::new(GeneralMetadata {
                max_seq_len,
                llg_factory: Some(llg_factory),
                no_kv_cache: self.no_kv_cache,
                no_prefix_cache: is_xlora,
                num_hidden_layers,
                eos_tok: eos,
                kind: self.kind.clone(),
                is_xlora,
                activation_dtype: dtype,
                sliding_window,
                cache_config,
                cache_engine,
                model_metadata: Some(model_metadata),
                modalities: Modalities {
                    input: vec![SupportedModality::Text],
                    output: vec![SupportedModality::Text],
                },
            }),
            topology: self.config.topology.clone(),
            silent,
            organization: self.config.organization,
            template_filename: paths.get_template_filename().clone(),
            generation_config: paths.get_gen_conf_filename().cloned(),
            config,
            imatrix: self.config.imatrix.clone(),
            mapper: pipeline_mapper,
        })))
    }

    fn get_id(&self) -> String {
        self.model_id.clone()
    }

    fn get_kind(&self) -> ModelKind {
        self.kind.clone()
    }
}

impl PreProcessingMixin for NormalPipeline {
    fn get_chat_template(&self) -> Option<Arc<ChatTemplate>> {
        Some(self.chat_template.clone())
    }
    fn get_input_processor_config(&self) -> Option<Arc<dyn Any>> {
        None
    }
}

impl IsqPipelineMixin for NormalPipeline {
    fn re_isq_model(&mut self, dtype: IsqType) -> Result<()> {
        let device = self.device().clone();
        let multi_progress = Arc::new(new_multi_progress());
        self.model.quantize(
            Some(dtype),
            device.clone(),
            self.topology.as_ref(),
            self.silent,
            self.imatrix.as_ref().map(ImatrixDataSource::File),
            self.organization,
            true,
            None,
            UqffFullSer {
                tokenizer: &self.tokenizer,
                template_filename: &self.template_filename,
                generation_config: self.generation_config.as_ref(),
                config: self.config.clone(),
                processor_filename: &None,
                preprocessor_filename: &None,
                modules: None,
                module_paths: None,
            },
            multi_progress.clone(),
        )?;
        Ok(())
    }
}

impl CacheManagerMixin for NormalPipeline {
    fn clone_in_cache(&self, seqs: &mut [&mut Sequence]) {
        match self.model.cache() {
            EitherCache::Full(_) => FullCacheManager.clone_in_cache(self, seqs, false),
            EitherCache::Normal(_) => NormalCacheManager.clone_in_cache(self, seqs, false),
            EitherCache::Hybrid(_) => HybridCacheManager.clone_in_cache(self, seqs, false),
        }
    }
    fn clone_out_cache(&self, seqs: &mut [&mut Sequence]) {
        match self.model.cache() {
            EitherCache::Full(_) => FullCacheManager.clone_out_cache(self, seqs, false),
            EitherCache::Normal(_) => NormalCacheManager.clone_out_cache(self, seqs, false),
            EitherCache::Hybrid(_) => HybridCacheManager.clone_out_cache(self, seqs, false),
        }
    }
    fn set_none_cache(
        &self,
        seqs: &mut [&mut Sequence],
        reset_non_granular: bool,
        modify_draft_cache: bool,
        load_preallocated_cache: bool,
    ) {
        match self.model.cache() {
            EitherCache::Full(_) => {
                FullCacheManager.set_none_cache(self, seqs, modify_draft_cache, false)
            }
            EitherCache::Normal(_) => NormalCacheManager.set_none_cache(
                self,
                seqs,
                modify_draft_cache,
                load_preallocated_cache,
            ),
            EitherCache::Hybrid(_) => HybridCacheManager.set_none_cache(
                self,
                seqs,
                modify_draft_cache,
                load_preallocated_cache,
            ),
        }
        if reset_non_granular {
            self.reset_non_granular_state()
        }
    }
    fn cache(&self) -> &EitherCache {
        self.model.cache()
    }
}

impl MetadataMixin for NormalPipeline {
    fn device(&self) -> Device {
        self.model.device().clone()
    }
    fn tokenizer(&self) -> Option<Arc<Tokenizer>> {
        Some(self.tokenizer.clone())
    }
    fn name(&self) -> String {
        self.model_id.clone()
    }
    fn reset_non_granular_state(&self) {
        if let Some(s) = self.non_granular_state.as_ref() {
            *self.cache().full().get_scalings_cache() = None;
            *get_mut_arcmutex!(s.non_granular_index) = 0;
        }
    }
    fn get_metadata(&self) -> Arc<GeneralMetadata> {
        self.metadata.clone()
    }
    fn device_mapper(&self) -> Option<&dyn DeviceMapper> {
        Some(&*self.mapper)
    }
}

#[async_trait::async_trait]
impl Pipeline for NormalPipeline {
    fn forward_inputs(
        &mut self,
        inputs: Box<dyn Any>,
        return_raw_logits: bool,
    ) -> Result<ForwardInputsResult, candle_core::Error> {
        let ModelInputs {
            input_ids,
            input_ids_full,
            seqlen_offsets,
            seqlen_offsets_full,
            context_lens,
            position_ids,
            paged_attn_meta,
            flash_meta,
            flash_meta_full,
        } = *inputs.downcast().expect("Downcast failed.");
        let metadata = self.get_metadata();
        let paged_attn_meta = match (&metadata.cache_engine, &paged_attn_meta) {
            (Some(cache_engine), Some(meta)) => Some((cache_engine, meta)),
            (Some(_), None) => {
                // This can happen if Rust-side user code is wrong
                candle_core::bail!("Forward step expected a PagedAttention input metadata. This was not provided, please ensure that the scheduler config is correctly configured for PagedAttention.")
            }
            (None, Some(_)) => {
                // This should never happen but we handle it anyway
                candle_core::bail!("Forward step got a PagedAttention input metadata but there is no cache engine. Please raise an issue.")
            }
            (None, None) => None,
        };
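        // X-LoRA models go through `xlora_forward`, which consumes the `*_full` inputs for its
        // dual forward pass; plain models take the standard causal `forward` path.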
        let logits = match self.model.is_xlora() {
            false => {
                let paged_attn_meta = paged_attn_meta
                    .as_ref()
                    .map(|meta| (meta.0.get_kv_cache().clone(), meta.1.clone()));

                self.model.forward(
                    &input_ids,
                    &seqlen_offsets,
                    context_lens,
                    position_ids,
                    paged_attn_meta.as_ref().map(|(a, b)| (a.clone(), b)),
                    &flash_meta,
                )?
            }
            true => self.model.xlora_forward(
                &input_ids,
                input_ids_full.as_ref().unwrap_or(&input_ids),
                &seqlen_offsets,
                seqlen_offsets_full.as_ref().unwrap_or(&seqlen_offsets),
                self.no_kv_cache,
                &self.non_granular_state,
                context_lens,
                position_ids,
                &flash_meta,
                flash_meta_full.as_ref().unwrap_or(&flash_meta),
            )?,
        };
        if return_raw_logits {
            Ok(ForwardInputsResult::RawLogits { logits })
        } else {
            Ok(ForwardInputsResult::CausalGeneration { logits })
        }
    }
    async fn sample_causal_gen(
        &self,
        seqs: &mut [&mut Sequence],
        logits: Vec<Tensor>,
        prefix_cacher: &mut PrefixCacheManagerV2,
        disable_eos_stop: bool,
        rng: Arc<std::sync::Mutex<Isaac64Rng>>,
    ) -> Result<(), candle_core::Error> {
        sample_and_add_toks(self, seqs, logits, prefix_cacher, disable_eos_stop, rng).await
    }
    fn category(&self) -> ModelCategory {
        ModelCategory::Text
    }
}

impl AnyMoePipelineMixin for NormalPipeline {
    fn amoe_finish_training(&mut self, gate_model_id: Option<String>) -> candle_core::Result<()> {
        self.model.finish_training(gate_model_id)
    }
    fn amoe_layer_vars(&self) -> Vec<Vec<Var>> {
        self.model.get_vars()
    }
    fn amoe_base_model_trainable_params(&self) -> usize {
        self.model.trainable_params()
    }
    fn amoe_take_cached_gating_outputs(&mut self) -> Vec<Tensor> {
        self.model.take_cached_gating_outputs()
    }
    fn amoe_create_layers(
        &mut self,
        model_ids: Vec<String>,
        token: &TokenSource,
        revision: Option<String>,
        match_regex: &str,
        config: crate::amoe::AnyMoeConfig,
        dtype: candle_core::DType,
        dev: &Device,
        (prefix, mlp): (String, String),
        layers: Vec<usize>,
        expert_type: AnyMoeExpertType,
        silent: bool,
        gate_model_id: Option<String>,
    ) -> candle_core::Result<()> {
        let mut vbs = Vec::new();
        // Precompile regex here
        let regex = Regex::new(match_regex).map_err(candle_core::Error::msg)?;
        for model_id in model_ids {
            let model_id_str = &model_id;
            let model_id = Path::new(&model_id);

            let api = {
                let cache = GLOBAL_HF_CACHE.get().cloned().unwrap_or_default();
                let mut api = ApiBuilder::from_cache(cache)
                    .with_progress(!silent)
                    .with_token(get_token(token).map_err(candle_core::Error::msg)?);
                if let Ok(x) = std::env::var("HF_HUB_CACHE") {
                    api = api.with_cache_dir(x.into());
                }
                api.build().map_err(candle_core::Error::msg)?
            };
            let revision = revision.clone().unwrap_or("main".to_string());
            let api = api.repo(Repo::with_revision(
                model_id_str.clone(),
                RepoType::Model,
                revision.clone(),
            ));

            let mut filenames = vec![];
            for rfilename in
                api_dir_list!(api, model_id, true).filter(|x| x.ends_with(".safetensors"))
            {
                filenames.push(api_get_file!(api, &rfilename, model_id));
            }

            let regex = regex.clone();
            let match_regex_clone = match_regex.to_string();
            let layers_clone = layers.clone();
            let vb = from_mmaped_safetensors(
                filenames,
                vec![],
                Some(dtype),
                dev,
                vec![None],
                silent,
                None,
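                // Only load tensors whose key matches the MLP regex and whose parsed layer index
                // is in `layers` (an empty list means all layers).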
                move |key| {
                    if regex.is_match(&key) {
                        // Idx of the last char of the layer id, +1
                        // Assumes N.MLP
                        let last_layer_idx = key.find(&match_regex_clone).unwrap() - 1;
                        let first_layer_idx = key[..last_layer_idx].rfind('.').unwrap();
                        let layer_n = key[first_layer_idx + 1..last_layer_idx]
                            .parse::<usize>()
                            .unwrap();
                        layers_clone.contains(&layer_n) || layers_clone.is_empty()
                    } else {
                        false
                    }
                },
                Arc::new(|_| DeviceForLoadTensor::Base),
            )?;
            vbs.push(vb);
        }

        let gate_vb = if let Some(gate_model_id) = gate_model_id {
            let model_id_str = &gate_model_id;
            let model_id = Path::new(&gate_model_id);

            let api = {
                let cache = GLOBAL_HF_CACHE.get().cloned().unwrap_or_default();
                let mut api = ApiBuilder::from_cache(cache)
                    .with_progress(!silent)
                    .with_token(get_token(token).map_err(candle_core::Error::msg)?);
                if let Ok(x) = std::env::var("HF_HUB_CACHE") {
                    api = api.with_cache_dir(x.into());
                }
                api.build().map_err(candle_core::Error::msg)?
            };
            let revision = revision.clone().unwrap_or("main".to_string());
            let api = api.repo(Repo::with_revision(
                model_id_str.clone(),
                RepoType::Model,
                revision.clone(),
            ));

            let mut gate_filenames = vec![];
            for rfilename in
                api_dir_list!(api, model_id, true).filter(|x| x.ends_with(".safetensors"))
            {
                gate_filenames.push(api_get_file!(api, &rfilename, model_id));
            }
            assert_eq!(
                gate_filenames.len(),
                1,
                "Gate model ID must contain only one .safetensors file"
            );

            let vb = from_mmaped_safetensors(
                gate_filenames.clone(),
                vec![],
                Some(dtype),
                dev,
                vec![None],
                silent,
                None,
                |_| true,
                Arc::new(|_| DeviceForLoadTensor::Base),
            )?;
            info!(
                "Loaded gating layers from `{}`",
                gate_filenames[0].display()
            );
            Some(vb)
        } else {
            None
        };

        self.model.create_anymoe_layers(
            vbs.clone(),
            config.clone(),
            (prefix.clone(), mlp.clone()),
            layers.clone(),
            expert_type.clone(),
            gate_vb.clone(),
        )?;

        Ok(())
    }
    fn amoe_supported(&self) -> bool {
        self.model.amoe_supported()
    }
}