mistralrs_core/pipeline/
normal.rs

1use super::isq::ImatrixDataSource;
2use super::llg::build_llg_factory;
3use super::{
4    get_model_paths, get_xlora_paths, text_models_inputs_processor::ModelInputs, AdapterKind,
5    CacheManager, GeneralMetadata, Loader, ModelKind, ModelPaths, NormalModel, NormalModelLoader,
6    TokenSource,
7};
8use super::{
9    AnyMoePipelineMixin, CacheManagerMixin, EitherCache, ForwardInputsResult, IsqOrganization,
10    IsqPipelineMixin, MetadataMixin, ModelCategory, PreProcessingMixin,
11};
12use super::{
13    AutoNormalLoader, DeepSeekV2Loader, DeepSeekV3Loader, GLM4Loader, Gemma2Loader, GemmaLoader,
14    GraniteMoeHybridLoader, LlamaLoader, MistralLoader, MixtralLoader, NormalLoaderType,
15    Phi2Loader, Phi3Loader, Phi3_5MoELoader, Qwen2Loader, Qwen3Loader, Qwen3MoELoader,
16    SmolLm3Loader, Starcoder2Loader,
17};
18use crate::amoe::AnyMoeExpertType;
19use crate::attention::ATTENTION_CHUNK_SIZE;
20use crate::device_map::{self, DeviceMapper};
21use crate::distributed::{self, WorkerTransferData};
22use crate::kv_cache::{FullCacheManager, HybridCacheManager, NormalCacheManager};
23use crate::lora::Ordering;
24use crate::paged_attention::{calculate_cache_config, AttentionImplementation, CacheEngine};
25use crate::pipeline::chat_template::{calculate_eos_tokens, GenerationConfig};
26use crate::pipeline::isq::UqffFullSer;
27use crate::pipeline::loaders::auto_device_map;
28use crate::pipeline::loaders::QuantizationConfigShim;
29use crate::pipeline::sampling::sample_and_add_toks;
30use crate::pipeline::text_models_inputs_processor::make_prompt_chunk;
31use crate::pipeline::{get_chat_template, Modalities, SupportedModality};
32use crate::pipeline::{ChatTemplate, LocalModelPaths};
33use crate::prefix_cacher::PrefixCacheManagerV2;
34use crate::sequence::Sequence;
35use crate::utils::tokenizer::get_tokenizer;
36use crate::utils::varbuilder_utils::DeviceForLoadTensor;
37use crate::utils::{
38    progress::{new_multi_progress, ProgressScopeGuard},
39    tokens::get_token,
40    varbuilder_utils::from_mmaped_safetensors,
41};
42use crate::xlora_models::NonGranularState;
43use crate::{
44    api_dir_list, api_get_file, get_mut_arcmutex, get_paths, get_uqff_paths, lora_model_loader,
45    normal_model_loader, normal_model_loader_sharded, xlora_model_loader, DeviceMapSetting,
46    PagedAttentionConfig, Pipeline, Topology, TryIntoDType, GLOBAL_HF_CACHE,
47};
48use anyhow::Result;
49use candle_core::{Device, Tensor, Var};
50use hf_hub::Cache;
51use hf_hub::{api::sync::ApiBuilder, Repo, RepoType};
52use mistralrs_quant::log::once_log_info;
53use mistralrs_quant::{
54    AfqLayer, GgufMatMul, HqqLayer, ImmediateIsqOverride, IsqType, QuantizedSerdeType,
55};
56use rand_isaac::Isaac64Rng;
57use regex_automata::meta::Regex;
58use std::any::Any;
59use std::borrow::Cow;
60use std::path::{Path, PathBuf};
61use std::str::FromStr;
62use std::sync::{Arc, RwLock};
63use std::time::Instant;
64use std::{env, fs};
65use tokenizers::Tokenizer;
66use tokio::sync::Mutex;
67use tracing::{info, warn};
68
/// Pipeline state wrapped around a boxed [`NormalModel`], together with the
/// tokenizer, chat template and metadata that are stored next to it.
///
/// NOTE(review): the code that constructs and uses this struct is outside the
/// visible part of the file; the field notes below are limited to what the
/// declarations themselves show.
pub struct NormalPipeline {
    /// The model, held as a `Send + Sync` trait object.
    model: Box<dyn NormalModel + Send + Sync>,
    /// Shared handle to the tokenizer.
    tokenizer: Arc<Tokenizer>,
    /// NOTE(review): presumably mirrors the loader's `no_kv_cache` flag — confirm at the construction site.
    no_kv_cache: bool,
    /// Shared handle to the chat template.
    chat_template: Arc<ChatTemplate>,
    /// Optional non-granular state (type comes from `xlora_models`).
    non_granular_state: Option<NonGranularState>,
    /// Identifier of the model.
    model_id: String,
    /// Shared general metadata.
    metadata: Arc<GeneralMetadata>,
    /// Optional topology.
    topology: Option<Topology>,
    /// NOTE(review): presumably the `silent` flag passed to the loader — confirm.
    silent: bool,
    /// ISQ organization.
    organization: IsqOrganization,
    // For full UQFF serialization
    template_filename: Option<PathBuf>,
    generation_config: Option<PathBuf>,
    config: String,
    imatrix: Option<PathBuf>,
    /// Device mapper, held as a `Send + Sync` trait object.
    mapper: Box<dyn DeviceMapper + Send + Sync>,
}
87
/// A loader for a "normal" (non-quantized) model.
///
/// Instances are produced by `NormalLoaderBuilder::build`. The three `RwLock`
/// fields start out as `None` and are filled in by `load_model_from_hf`.
pub struct NormalLoader {
    /// Architecture-specific loader picked in `NormalLoaderBuilder::build`
    /// (`AutoNormalLoader` when no loader type was given).
    inner: Box<dyn NormalModelLoader>,
    /// Model identifier taken from the builder.
    model_id: String,
    /// Loading options specific to normal models.
    config: NormalSpecificConfig,
    /// X-LoRA model identifier, set through `NormalLoaderBuilder::with_xlora`.
    xlora_model_id: Option<String>,
    /// LoRA adapter identifiers, set through `NormalLoaderBuilder::with_lora`.
    lora_adapter_ids: Option<Vec<String>>,
    /// Whether this is a plain model or an adapter (X-LoRA / LoRA) model.
    kind: ModelKind,
    /// X-LoRA ordering, set through `NormalLoaderBuilder::with_xlora`.
    xlora_order: Option<Ordering>,
    /// Flag copied from the builder.
    no_kv_cache: bool,
    /// Chat template option copied from the builder.
    chat_template: Option<String>,
    /// Tokenizer JSON option copied from the builder.
    tokenizer_json: Option<String>,
    /// Index copied from the builder (only set on the X-LoRA path).
    tgt_non_granular_index: Option<usize>,
    /// Token source recorded by `load_model_from_hf`.
    token_source: RwLock<Option<TokenSource>>,
    /// Revision recorded by `load_model_from_hf`.
    revision: RwLock<Option<String>>,
    /// UQFF paths resolved by `load_model_from_hf` when `config.from_uqff` is set;
    /// read again during automatic device mapping in `load_model_from_path`.
    from_uqff: RwLock<Option<Vec<PathBuf>>>,
    /// Explicit Jinja option copied from the builder.
    jinja_explicit: Option<String>,
    /// Optional directory used to construct the Hugging Face `Cache` in
    /// `load_model_from_hf`; the default cache is used when this is `None`.
    hf_cache_path: Option<PathBuf>,
}
107
/// A builder for a loader for a "normal" (non-quantized) model.
///
/// Start with `NormalLoaderBuilder::new`, optionally chain `with_xlora`,
/// `with_lora` or `hf_cache_path`, then call `build`.
#[derive(Default)]
pub struct NormalLoaderBuilder {
    /// Model identifier; when absent, `with_xlora` fills it from the ordering's `base_model_id`.
    model_id: Option<String>,
    /// Loading options specific to normal models.
    config: NormalSpecificConfig,
    /// X-LoRA model identifier (set by `with_xlora`).
    xlora_model_id: Option<String>,
    /// LoRA adapter identifiers (set by `with_lora`).
    lora_adapter_ids: Option<Vec<String>>,
    /// `ModelKind::Normal` after `new`; switched to an adapter kind by `with_xlora` / `with_lora`.
    kind: ModelKind,
    /// X-LoRA ordering (set by `with_xlora`).
    xlora_order: Option<Ordering>,
    /// Set by `new` and overwritten by `with_xlora`.
    no_kv_cache: bool,
    /// Chat template option passed to `new`.
    chat_template: Option<String>,
    /// Tokenizer JSON option passed to `new`.
    tokenizer_json: Option<String>,
    /// Set by `with_xlora`.
    tgt_non_granular_index: Option<usize>,
    /// Explicit Jinja option passed to `new`.
    jinja_explicit: Option<String>,
    /// Optional Hugging Face cache directory (set by `hf_cache_path`).
    hf_cache_path: Option<PathBuf>,
}
124
/// Config specific to loading a normal model.
#[derive(Clone, Default)]
pub struct NormalSpecificConfig {
    /// Optional topology; handed to the device mapper and consulted for
    /// per-pattern ISQ overrides and post-quantization requirements.
    pub topology: Option<Topology>,
    /// ISQ organization; `MoeExpertsOnly` selects the MoE-experts-only
    /// predicates and statistics tracking instead of the default ones.
    pub organization: IsqOrganization,
    /// NOTE(review): not read in the visible part of the file; presumably the
    /// output path for UQFF serialization — confirm.
    pub write_uqff: Option<PathBuf>,
    /// UQFF files to load from; resolved to concrete paths in `load_model_from_hf`.
    pub from_uqff: Option<Vec<PathBuf>>,
    /// Imatrix file. Loading fails if this and `calibration_file` are both set.
    pub imatrix: Option<PathBuf>,
    /// Calibration text file; it is read and tokenized to collect an imatrix.
    /// Loading fails if this and `imatrix` are both set.
    pub calibration_file: Option<PathBuf>,
    /// NOTE(review): the visible loader code reads `NormalLoader::hf_cache_path`
    /// rather than this field — confirm where this one is consumed.
    pub hf_cache_path: Option<PathBuf>,
    /// Path of a Matformer config file, loaded with `MatformerConfig::from_file`.
    pub matformer_config_path: Option<PathBuf>,
    /// Name of the Matformer slice to use. If a config path is given without a
    /// slice name, a warning is logged and no slicing config is passed on.
    pub matformer_slice_name: Option<String>,
}
138
139impl NormalLoaderBuilder {
140    pub fn new(
141        config: NormalSpecificConfig,
142        chat_template: Option<String>,
143        tokenizer_json: Option<String>,
144        model_id: Option<String>,
145        no_kv_cache: bool,
146        jinja_explicit: Option<String>,
147    ) -> Self {
148        Self {
149            config,
150            chat_template,
151            tokenizer_json,
152            model_id,
153            kind: ModelKind::Normal,
154            jinja_explicit,
155            no_kv_cache,
156            ..Default::default()
157        }
158    }
159
160    fn with_adapter(
161        mut self,
162        xlora_model_id: String,
163        xlora_order: Ordering,
164        no_kv_cache: bool,
165        tgt_non_granular_index: Option<usize>,
166    ) -> Self {
167        self.xlora_model_id = Some(xlora_model_id);
168        self.xlora_order = Some(xlora_order);
169        self.no_kv_cache = no_kv_cache;
170        self.tgt_non_granular_index = tgt_non_granular_index;
171        self.model_id = if let Some(id) = self.model_id {
172            Some(id)
173        } else {
174            info!(
175                "Using adapter base model ID: `{}`",
176                self.xlora_order.as_ref().unwrap().base_model_id
177            );
178            Some(self.xlora_order.as_ref().unwrap().base_model_id.clone())
179        };
180        self
181    }
182
183    pub fn with_xlora(
184        mut self,
185        xlora_model_id: String,
186        xlora_order: Ordering,
187        no_kv_cache: bool,
188        tgt_non_granular_index: Option<usize>,
189    ) -> Self {
190        self.kind = ModelKind::Adapter {
191            adapter: AdapterKind::XLora,
192        };
193        self.with_adapter(
194            xlora_model_id,
195            xlora_order,
196            no_kv_cache,
197            tgt_non_granular_index,
198        )
199    }
200
201    pub fn with_lora(mut self, lora_adapter_ids: Vec<String>) -> Self {
202        self.kind = ModelKind::Adapter {
203            adapter: AdapterKind::Lora,
204        };
205        self.lora_adapter_ids = Some(lora_adapter_ids);
206        self
207    }
208
209    pub fn hf_cache_path(mut self, hf_cache_path: PathBuf) -> Self {
210        self.hf_cache_path = Some(hf_cache_path);
211        self
212    }
213
214    /// If the loader type is not specified, loader type is automatically determined from the
215    /// `architectures` array in the config.
216    pub fn build(self, loader_tp: Option<NormalLoaderType>) -> anyhow::Result<Box<dyn Loader>> {
217        let loader: Box<dyn NormalModelLoader> = match loader_tp {
218            Some(NormalLoaderType::Mistral) => Box::new(MistralLoader),
219            Some(NormalLoaderType::Gemma) => Box::new(GemmaLoader),
220            Some(NormalLoaderType::Llama) => Box::new(LlamaLoader),
221            Some(NormalLoaderType::Mixtral) => Box::new(MixtralLoader),
222            Some(NormalLoaderType::Phi2) => Box::new(Phi2Loader),
223            Some(NormalLoaderType::Phi3) => Box::new(Phi3Loader),
224            Some(NormalLoaderType::Qwen2) => Box::new(Qwen2Loader),
225            Some(NormalLoaderType::Gemma2) => Box::new(Gemma2Loader),
226            Some(NormalLoaderType::Starcoder2) => Box::new(Starcoder2Loader),
227            Some(NormalLoaderType::Phi3_5MoE) => Box::new(Phi3_5MoELoader),
228            Some(NormalLoaderType::DeepSeekV2) => Box::new(DeepSeekV2Loader),
229            Some(NormalLoaderType::DeepSeekV3) => Box::new(DeepSeekV3Loader),
230            Some(NormalLoaderType::Qwen3) => Box::new(Qwen3Loader),
231            Some(NormalLoaderType::GLM4) => Box::new(GLM4Loader),
232            Some(NormalLoaderType::Qwen3Moe) => Box::new(Qwen3MoELoader),
233            Some(NormalLoaderType::SmolLm3) => Box::new(SmolLm3Loader),
234            Some(NormalLoaderType::GraniteMoeHybrid) => Box::new(GraniteMoeHybridLoader),
235            None => Box::new(AutoNormalLoader),
236        };
237        Ok(Box::new(NormalLoader {
238            inner: loader,
239            model_id: self.model_id.unwrap(),
240            config: self.config,
241            xlora_model_id: self.xlora_model_id,
242            lora_adapter_ids: self.lora_adapter_ids,
243            kind: self.kind,
244            xlora_order: self.xlora_order,
245            no_kv_cache: self.no_kv_cache,
246            chat_template: self.chat_template,
247            tokenizer_json: self.tokenizer_json,
248            tgt_non_granular_index: self.tgt_non_granular_index,
249            jinja_explicit: self.jinja_explicit,
250            token_source: RwLock::new(None),
251            revision: RwLock::new(None),
252            from_uqff: RwLock::new(None),
253            hf_cache_path: self.hf_cache_path,
254        }))
255    }
256}
257
258impl Loader for NormalLoader {
259    #[allow(clippy::type_complexity, clippy::too_many_arguments)]
260    fn load_model_from_hf(
261        &self,
262        revision: Option<String>,
263        token_source: TokenSource,
264        dtype: &dyn TryIntoDType,
265        device: &Device,
266        silent: bool,
267        mapper: DeviceMapSetting,
268        in_situ_quant: Option<IsqType>,
269        paged_attn_config: Option<PagedAttentionConfig>,
270    ) -> Result<Arc<Mutex<dyn Pipeline + Send + Sync>>> {
271        let _progress_guard = ProgressScopeGuard::new(silent);
272        let cache = self
273            .hf_cache_path
274            .clone()
275            .map(Cache::new)
276            .unwrap_or_default();
277        GLOBAL_HF_CACHE.get_or_init(|| cache);
278
279        let paths: anyhow::Result<Box<dyn ModelPaths>> = get_paths!(
280            LocalModelPaths,
281            &token_source,
282            revision.clone(),
283            self,
284            None,
285            None,
286            silent,
287            self.config.from_uqff.is_some()
288        );
289        if let Some(from_uqff) = self.config.from_uqff.clone() {
290            *self.from_uqff.write().unwrap() = Some(get_uqff_paths!(&from_uqff, self, silent));
291        }
292        *self
293            .token_source
294            .write()
295            .expect("Failed to write to token source") = Some(token_source);
296        *self.revision.write().expect("Failed to write to revision") = revision;
297        self.load_model_from_path(
298            &paths?,
299            dtype,
300            device,
301            silent,
302            mapper,
303            in_situ_quant,
304            paged_attn_config,
305        )
306    }
307
308    #[allow(clippy::type_complexity, clippy::too_many_arguments)]
309    fn load_model_from_path(
310        &self,
311        paths: &Box<dyn ModelPaths>,
312        dtype: &dyn TryIntoDType,
313        device: &Device,
314        silent: bool,
315        mut mapper: DeviceMapSetting,
316        mut in_situ_quant: Option<IsqType>,
317        mut paged_attn_config: Option<PagedAttentionConfig>,
318    ) -> Result<Arc<Mutex<dyn Pipeline + Send + Sync>>> {
319        let _progress_guard = ProgressScopeGuard::new(silent);
320        let config = std::fs::read_to_string(paths.get_config_filename())?;
321
322        if !self.inner.supports_paged_attention(&config)? {
323            paged_attn_config = None;
324        }
325
326        info!("Prompt chunk size is {ATTENTION_CHUNK_SIZE}.");
327
328        let use_nccl = mistralrs_quant::distributed::use_nccl();
329
330        let available_devices = if let Ok(payload) = env::var(distributed::IS_DAEMON_FLAG) {
331            let payload: WorkerTransferData = serde_json::from_str(&payload)?;
332            let WorkerTransferData::Init { id: _, worker_rank } = payload;
333            vec![candle_core::Device::new_cuda(worker_rank + 1)?]
334        } else if use_nccl {
335            vec![candle_core::Device::new_cuda(0)?]
336        } else {
337            device_map::get_all_similar_devices(device)?
338        };
339        #[cfg(feature = "cuda")]
340        for device in &available_devices {
341            if let Device::Cuda(dev) = device {
342                unsafe { dev.disable_event_tracking() };
343            }
344        }
345        let device = if use_nccl || cfg!(feature = "ring") {
346            available_devices[0].clone()
347        } else {
348            device.clone()
349        };
350
351        // If auto, convert to Map if not using nccl
352        if use_nccl || cfg!(feature = "ring") {
353            mapper = DeviceMapSetting::DummyNccl {
354                nm_device: available_devices[0].clone(),
355            };
356        } else if let DeviceMapSetting::Auto(params) = mapper.clone() {
357            // Initial dtype
358            let dtype = dtype.try_into_dtype(&available_devices.iter().collect::<Vec<_>>())?;
359
360            // Disable ISQ if we are loading a prequantized model.
361            if QuantizationConfigShim::get_quant_config_pack_factor(&config, dtype)? != 1 {
362                in_situ_quant = None;
363            }
364
365            // ISQ or UQFF: quantized path
366            // Match logic below where UQFF has priority
367            let (layer_sizes_in_bytes, non_mapped_size_in_bytes, total_model_size_in_bytes) =
368                if let Some(serialized) = &*self.from_uqff.read().unwrap() {
369                    let weight_pack_factor = {
370                        let ser_artifacts = unsafe {
371                            candle_core::safetensors::MmapedSafetensors::multi(serialized)?
372                        };
373                        let mut total_pack_factors = 0;
374                        let total_tensors = ser_artifacts.tensors().len();
375                        for (_, artifact) in ser_artifacts.tensors() {
376                            let artifact = artifact.data();
377                            // NOTE(EricLBuehler): isq type is ALWAYS byte 4 (5th) of the tensor.
378                            let isq_type = artifact[mistralrs_quant::UQFF_QUANT_TYPE_OFFSET];
379                            let pack_factor = match QuantizedSerdeType::try_from(isq_type as usize)?
380                            {
381                                QuantizedSerdeType::Hqq => {
382                                    HqqLayer::get_isq_type_from_uqff(Cow::Borrowed(artifact))?
383                                        .pack_factor(dtype)
384                                }
385                                QuantizedSerdeType::Gguf => {
386                                    GgufMatMul::get_isq_type_from_uqff(Cow::Borrowed(artifact))?
387                                        .pack_factor(dtype)
388                                }
389                                QuantizedSerdeType::Fp8 => IsqType::F8E4M3.pack_factor(dtype),
390                                QuantizedSerdeType::Unquant => 1,
391                                QuantizedSerdeType::Afq => {
392                                    AfqLayer::get_isq_type_from_uqff(Cow::Borrowed(artifact))?
393                                        .pack_factor(dtype)
394                                }
395                            };
396                            total_pack_factors += pack_factor;
397                        }
398
399                        total_pack_factors / total_tensors
400                    };
401
402                    let layer_sizes_in_bytes = self.inner.layer_sizes_in_bytes(
403                        &config,
404                        dtype,
405                        weight_pack_factor,
406                        None,
407                    )?;
408                    let non_mapped_size_in_bytes = self.inner.non_mapped_size_in_bytes(
409                        &config,
410                        dtype,
411                        weight_pack_factor,
412                        None,
413                    )?;
414                    let layer_sizes_sum = layer_sizes_in_bytes.iter().sum::<usize>();
415                    (
416                        layer_sizes_in_bytes,
417                        non_mapped_size_in_bytes,
418                        layer_sizes_sum + non_mapped_size_in_bytes,
419                    )
420                } else if let Some(isq) = in_situ_quant {
421                    let weight_pack_factor = isq.pack_factor(dtype);
422                    let layer_sizes_in_bytes = self.inner.layer_sizes_in_bytes(
423                        &config,
424                        dtype,
425                        weight_pack_factor,
426                        None,
427                    )?;
428                    let non_mapped_size_in_bytes = self.inner.non_mapped_size_in_bytes(
429                        &config,
430                        dtype,
431                        weight_pack_factor,
432                        None,
433                    )?;
434                    let layer_sizes_sum = layer_sizes_in_bytes.iter().sum::<usize>();
435                    (
436                        layer_sizes_in_bytes,
437                        non_mapped_size_in_bytes,
438                        layer_sizes_sum + non_mapped_size_in_bytes,
439                    )
440                } else {
441                    // Be sure to get the weight pack factor here; we might be loading a prequantized model.
442                    let weight_pack_factor =
443                        QuantizationConfigShim::get_quant_config_pack_factor(&config, dtype)?;
444                    let layer_sizes_in_bytes = self.inner.layer_sizes_in_bytes(
445                        &config,
446                        dtype,
447                        weight_pack_factor,
448                        None,
449                    )?;
450                    let non_mapped_size_in_bytes = self.inner.non_mapped_size_in_bytes(
451                        &config,
452                        dtype,
453                        weight_pack_factor,
454                        None,
455                    )?;
456                    let layer_sizes_sum = layer_sizes_in_bytes.iter().sum::<usize>();
457                    (
458                        layer_sizes_in_bytes,
459                        non_mapped_size_in_bytes,
460                        layer_sizes_sum + non_mapped_size_in_bytes,
461                    )
462                };
463
464            let new = auto_device_map::get_device_layers(
465                &*self.inner,
466                &config,
467                self.inner.num_layers(&config)?,
468                layer_sizes_in_bytes,
469                non_mapped_size_in_bytes,
470                total_model_size_in_bytes,
471                &available_devices,
472                dtype,
473                &params,
474                paged_attn_config.as_ref(),
475            )?;
476            mapper = DeviceMapSetting::Map(new);
477        }
478
479        let pipeline_mapper = mapper.into_mapper(
480            self.inner.num_layers(&config)?,
481            &device,
482            self.config.topology.as_ref(),
483        )?;
484        let mapper = mapper.into_mapper(
485            self.inner.num_layers(&config)?,
486            &device,
487            self.config.topology.as_ref(),
488        )?;
489        let mut layer_devices = Vec::new();
490        for layer in 0..self.inner.num_layers(&config)? {
491            let device = mapper.device_for(layer, false).cloned();
492            layer_devices.push(device);
493        }
494        let dtype = mapper.get_min_dtype(dtype)?;
495
496        // TODO: PagedAttention is not supported with CPU for now.
497        // This check is not really necessary because `get_device_layers` should prevent it.
498        let mapping_uses_cpu = mapper.get_unique_devices().iter().any(Device::is_cpu);
499        if mapping_uses_cpu && paged_attn_config.is_some() {
500            warn!("Device mapping contains a mix of GPU and CPU. There is no CPU support for PagedAttention, disabling PagedAttention.");
501            paged_attn_config = None;
502        }
503
504        info!("Model config: {:?}", self.inner.get_config_repr(&config)?);
505        if crate::using_flash_attn() {
506            once_log_info("FlashAttention is enabled.");
507        }
508
509        let topology_overrides = self
510            .config
511            .topology
512            .as_ref()
513            .map(|topology| {
514                topology
515                    .pattern_overrides()
516                    .into_iter()
517                    .map(|(regex, layer)| ImmediateIsqOverride {
518                        predicate: regex,
519                        ty: layer.isq,
520                        device: layer.device.clone(),
521                    })
522                    .collect::<Vec<_>>()
523            })
524            .unwrap_or_default();
525        let has_override_isq = topology_overrides
526            .iter()
527            .any(|override_entry| override_entry.ty.is_some());
528        let topology_requires_post_quant = self
529            .config
530            .topology
531            .as_ref()
532            .is_some_and(|topology| topology.requires_post_quantization());
533
534        let allow_immediate_cli = self.config.imatrix.is_none()
535            && self.config.calibration_file.is_none()
536            && !device.is_cuda()
537            && in_situ_quant.is_some();
538
539        let mut immediate_ty = None;
540        let mut immediate_predicates = Vec::new();
541        if allow_immediate_cli {
542            immediate_ty = in_situ_quant;
543            immediate_predicates =
544                if matches!(self.config.organization, IsqOrganization::MoeExpertsOnly) {
545                    self.inner.immediate_isq_predicates_moqe(&config)?
546                } else {
547                    self.inner.immediate_isq_predicates(&config)?
548                };
549            info!("Applying ISQ to {in_situ_quant:?}");
550            if immediate_predicates.is_empty() {
551                warn!("No predicates for this model and ISQ setting detected. ISQ will not be applied to any weights!");
552            }
553        }
554
555        let use_immediate = allow_immediate_cli || has_override_isq;
556        if use_immediate {
557            mistralrs_quant::set_immediate_isq_with_overrides(
558                immediate_ty,
559                immediate_predicates.clone(),
560                topology_overrides.clone(),
561            );
562        }
563
564        // Logic for ISQ here: if no calibration (i.e imatrix), then allow immediate ISQ. Otherwise, back to normal.
565        let mut loading_isq = if use_immediate {
566            false
567        } else {
568            in_situ_quant.is_some()
569        };
570        if self.config.imatrix.is_some() || self.config.calibration_file.is_some() {
571            loading_isq = true;
572        }
573        loading_isq |= topology_requires_post_quant;
574
575        if self.config.imatrix.is_some() && self.config.calibration_file.is_some() {
576            anyhow::bail!(
577                "`imatrix` and `calibration_file` were both specified, this is not allowed."
578            );
579        }
580
581        // Load onto the regular device if not using isq or if the calibration file is specified
582        let load_device = if !loading_isq || self.config.calibration_file.is_some() {
583            loading_isq = false;
584            device.clone()
585        } else {
586            Device::Cpu
587        };
588
589        let is_xlora = self.kind.is_adapted_and(|a| a.is_x_lora());
590
591        let attention_mechanism = if paged_attn_config.is_some() {
592            AttentionImplementation::PagedAttention
593        } else {
594            AttentionImplementation::Eager
595        };
596
597        let multi_progress = Arc::new(new_multi_progress());
598
599        // Load matformer slicing config if provided
600        let matformer_slicing_config = if let Some(matformer_path) =
601            &self.config.matformer_config_path
602        {
603            use crate::matformer::{MatformerConfig, MatformerSliceConfig};
604            info!("Loading Matformer config from {:?}", matformer_path);
605            let config = Arc::new(MatformerConfig::from_file(matformer_path)?);
606
607            if let Some(slice_name) = &self.config.matformer_slice_name {
608                info!("Using Matformer slice: {}", slice_name);
609                Some(MatformerSliceConfig::new(slice_name.clone(), config))
610            } else {
611                // If no slice name is provided but config exists, we'll need to handle this
612                // For now, return None and let the model handle the default slice selection
613                warn!("Matformer config loaded but no slice name specified. Models will use their default slice.");
614                None
615            }
616        } else {
617            None
618        };
619
620        let mut model = if use_nccl || cfg!(feature = "ring") {
621            let (mapper, sharded_vb) = distributed::prepare_distributed_mapper(
622                dtype,
623                &device,
624                &available_devices,
625                silent,
626                &config,
627                loading_isq,
628                self.config.from_uqff.is_some(),
629                self.config.organization,
630                &*self.inner,
631                paths.as_ref(),
632            )?;
633
634            // Special case for where things can be more optimially loaded.
635            match self.kind {
636                ModelKind::Normal => normal_model_loader_sharded!(
637                    sharded_vb,
638                    config,
639                    self.inner,
640                    mapper,
641                    loading_isq,
642                    device.clone(),
643                    attention_mechanism,
644                    multi_progress.clone(),
645                    matformer_slicing_config.clone(),
646                ),
647                ModelKind::Adapter {
648                    adapter: AdapterKind::XLora,
649                } => xlora_model_loader!(
650                    paths,
651                    Some(dtype),
652                    &load_device,
653                    layer_devices.clone(),
654                    config,
655                    self.inner,
656                    silent,
657                    mapper,
658                    loading_isq,
659                    device.clone(),
660                    multi_progress.clone(),
661                    matformer_slicing_config.clone(),
662                ),
663                ModelKind::Adapter {
664                    adapter: AdapterKind::Lora,
665                } => lora_model_loader!(
666                    paths,
667                    Some(dtype),
668                    &load_device,
669                    layer_devices.clone(),
670                    config,
671                    self.inner,
672                    silent,
673                    mapper,
674                    loading_isq,
675                    self.config.from_uqff.is_some(),
676                    device.clone(),
677                    attention_mechanism,
678                    matches!(self.config.organization, IsqOrganization::MoeExpertsOnly),
679                    multi_progress.clone(),
680                    matformer_slicing_config.clone(),
681                ),
682                _ => unreachable!(),
683            }
684        } else {
685            match self.kind {
686                ModelKind::Normal => normal_model_loader!(
687                    paths,
688                    Some(dtype),
689                    &load_device,
690                    layer_devices.clone(),
691                    config,
692                    self.inner,
693                    silent,
694                    mapper,
695                    loading_isq,
696                    self.config.from_uqff.is_some(),
697                    device.clone(),
698                    attention_mechanism,
699                    matches!(self.config.organization, IsqOrganization::MoeExpertsOnly),
700                    multi_progress.clone(),
701                    matformer_slicing_config.clone(),
702                ),
703                ModelKind::Adapter {
704                    adapter: AdapterKind::XLora,
705                } => xlora_model_loader!(
706                    paths,
707                    Some(dtype),
708                    &load_device,
709                    layer_devices.clone(),
710                    config,
711                    self.inner,
712                    silent,
713                    mapper,
714                    loading_isq,
715                    device.clone(),
716                    multi_progress.clone(),
717                    matformer_slicing_config.clone(),
718                ),
719                ModelKind::Adapter {
720                    adapter: AdapterKind::Lora,
721                } => lora_model_loader!(
722                    paths,
723                    Some(dtype),
724                    &load_device,
725                    layer_devices.clone(),
726                    config,
727                    self.inner,
728                    silent,
729                    mapper,
730                    loading_isq,
731                    self.config.from_uqff.is_some(),
732                    device.clone(),
733                    attention_mechanism,
734                    matches!(self.config.organization, IsqOrganization::MoeExpertsOnly),
735                    multi_progress.clone(),
736                    matformer_slicing_config.clone(),
737                ),
738                _ => unreachable!(),
739            }
740        };
741
742        let tokenizer = get_tokenizer(paths.get_tokenizer_filename(), None)?;
743        let gen_conf: Option<GenerationConfig> = paths.get_gen_conf_filename().and_then(|f| {
744            match serde_json::from_str::<GenerationConfig>(&fs::read_to_string(f).unwrap()) {
745                Ok(conf) => Some(conf),
746                Err(e) => {
747                    warn!("Failed to parse generation_config.json: {}", e);
748                    None
749                }
750            }
751        });
752
753        let chat_template_explicit = paths
754            .get_chat_template_explicit()
755            .as_ref()
756            .map(|x| x.to_string_lossy().to_string());
757        let chat_template = get_chat_template(
758            paths,
759            self.jinja_explicit.as_ref(),
760            chat_template_explicit.as_ref(),
761            self.chat_template.as_ref(),
762            None,
763        );
764
765        if let Some(calibration_file) = &self.config.calibration_file {
766            let calibration_data = std::fs::read_to_string(calibration_file)?;
767            // Tokenize, don't add bos yet
768            let tokens = tokenizer
769                .encode_fast(calibration_data, false)
770                .map_err(anyhow::Error::msg)?
771                .get_ids()
772                .to_vec();
773            info!(
774                "Collecting imatrix from calibration file `{}` of {} tokens.",
775                calibration_file.display(),
776                tokens.len()
777            );
778            let bos_toks = chat_template.bos_tok().map(|b| vec![b]).unwrap_or_default();
779            let bos_tok_id = tokenizer
780                .token_to_id(&bos_toks[0])
781                .expect("Somehow the bos token is not present.");
782
783            match self.config.organization {
784                IsqOrganization::Default => model.begin_track_stats()?,
785                IsqOrganization::MoeExpertsOnly => model.begin_track_stats_moe_experts_only()?,
786            }
787
788            const CHUNK_SIZE: usize = 1024;
789            let n_chunks = tokens.len().div_ceil(CHUNK_SIZE);
790            let start = Instant::now();
791            for (i, chunk) in tokens.chunks(CHUNK_SIZE).enumerate() {
792                let chunk = [vec![bos_tok_id], chunk.to_vec()].concat();
793                let chunk_len = chunk.len();
794
795                let start = Instant::now();
796                let inputs = make_prompt_chunk(
797                    0,
798                    vec![&chunk],
799                    &[0],
800                    &load_device,
801                    None,
802                    false,
803                    None,
804                    Some(pipeline_mapper.as_ref()),
805                )?;
806
807                model.forward(
808                    &inputs.input.to_device(model.device())?,
809                    &inputs.positions,
810                    inputs.context_lens.clone(),
811                    inputs.position_ids.clone(),
812                    None,
813                    &inputs.flash_meta.clone(),
814                )?;
815
816                match model.cache_mut() {
817                    EitherCache::Full(full) => {
818                        for layer in &mut *full.lock() {
819                            *layer = None
820                        }
821                    }
822                    EitherCache::Normal(normal) => {
823                        for layer in &mut *normal.lock().unwrap().0 {
824                            layer.reset();
825                        }
826                    }
827                    EitherCache::Hybrid(hybrid) => {
828                        hybrid.lock().unwrap().reset();
829                    }
830                }
831
832                let end = Instant::now();
833                info!(
834                    "Processed chunk {}/{n_chunks} ({chunk_len} tokens), {:.2}s",
835                    i + 1,
836                    end.duration_since(start).as_secs_f32()
837                );
838            }
839            load_device.synchronize()?;
840            let end = Instant::now();
841            info!(
842                "Finished collecting imatrix in {:.2}s",
843                end.duration_since(start).as_secs_f32()
844            );
845        }
846
847        // Only if loading from UQFF
848        let should_serialize = self.config.write_uqff.is_some();
849        let should_quantize_pass = loading_isq;
850
851        if (should_quantize_pass || should_serialize) && self.config.from_uqff.is_none() {
852            let imatrix_source = if should_quantize_pass {
853                match (
854                    self.config.imatrix.as_ref(),
855                    self.config.calibration_file.is_some(),
856                ) {
857                    (None, false) => None,
858                    (Some(file), false) => Some(ImatrixDataSource::File(file)),
859                    (None, true) => Some(ImatrixDataSource::Collected),
860                    (Some(_), true) => unreachable!(),
861                }
862            } else {
863                None
864            };
865
866            if should_quantize_pass {
867                info!("Applying ISQ to all ranks.");
868            } else {
869                info!("Serializing existing ISQ tensors without additional quantization.");
870            }
871
872            let multi_progress = Arc::new(new_multi_progress());
873
874            model.quantize(
875                in_situ_quant,
876                model.device().clone(),
877                self.config.topology.as_ref(),
878                silent,
879                imatrix_source,
880                self.config.organization,
881                should_quantize_pass,
882                self.config.write_uqff.as_ref(),
883                UqffFullSer {
884                    tokenizer: &tokenizer,
885                    template_filename: paths.get_template_filename(),
886                    generation_config: paths.get_gen_conf_filename(),
887                    config: config.clone(),
888                    processor_filename: &None,
889                    preprocessor_filename: &None,
890                    modules: None,
891                    module_paths: None,
892                },
893                multi_progress.clone(),
894            )?;
895        } else if let Some(from_uqff) = &*self.from_uqff.read().unwrap() {
896            model.load_from_artifacts(
897                device.clone(),
898                self.config.topology.as_ref(),
899                silent,
900                from_uqff,
901            )?;
902        }
903
904        let paged_attn_config = if matches!(
905            self.kind,
906            ModelKind::Adapter {
907                adapter: AdapterKind::XLora
908            }
909        ) {
910            warn!(
911                "Adapter parallel_models do not currently support PagedAttention, running without"
912            );
913            None
914        } else {
915            paged_attn_config
916        };
917
918        let (cache_config, cache_engine) = if let Some(paged_attn_config) = paged_attn_config {
919            let cache_config = calculate_cache_config(
920                paged_attn_config.mem_gpu,
921                paged_attn_config.block_size,
922                dtype,
923                paged_attn_config.cache_type,
924                model.config(),
925                &device,
926                &pipeline_mapper
927                    .get_unique_devices()
928                    .into_iter()
929                    .map(Some)
930                    .collect::<Vec<_>>(),
931                silent,
932            )?;
933
934            let mut layer_devices = Vec::new();
935            for layer in 0..self.inner.num_layers(&config)? {
936                let device = model.get_layers().1.device_for(layer, false).cloned();
937                layer_devices.push(device);
938            }
939            let cache_engine = CacheEngine::new(
940                model.config(),
941                &cache_config,
942                dtype,
943                model.device(),
944                layer_devices.clone(),
945            )?;
946
947            (Some(cache_config), Some(cache_engine))
948        } else {
949            (None, None)
950        };
951
952        let max_seq_len = model.max_seq_len();
953        let llg_factory = build_llg_factory(tokenizer.clone())?;
954        let num_hidden_layers = match model.cache() {
955            EitherCache::Full(full) => full.lock().len(),
956            EitherCache::Normal(normal) => normal.lock().unwrap().0.len(),
957            EitherCache::Hybrid(hybrid) => hybrid.lock().unwrap().num_layers(),
958        };
959        let eos = calculate_eos_tokens(&chat_template, gen_conf, &tokenizer);
960        let sliding_window = model.config().sliding_window;
961        let model_metadata = Arc::new(model.config().clone());
962
963        Ok(Arc::new(Mutex::new(NormalPipeline {
964            model,
965            tokenizer: tokenizer.into(),
966            no_kv_cache: self.no_kv_cache,
967            chat_template: Arc::new(chat_template),
968            non_granular_state: self.tgt_non_granular_index.map(|tgt_non_granular_index| {
969                NonGranularState {
970                    non_granular_index: Arc::new(Mutex::new(0)),
971                    tgt_non_granular_index,
972                }
973            }),
974            model_id: self.model_id.clone(),
975            metadata: Arc::new(GeneralMetadata {
976                max_seq_len,
977                llg_factory: Some(llg_factory),
978                no_kv_cache: self.no_kv_cache,
979                no_prefix_cache: is_xlora,
980                num_hidden_layers,
981                eos_tok: eos,
982                kind: self.kind.clone(),
983                is_xlora,
984                activation_dtype: dtype,
985                sliding_window,
986                cache_config,
987                cache_engine,
988                model_metadata: Some(model_metadata),
989                modalities: Modalities {
990                    input: vec![SupportedModality::Text],
991                    output: vec![SupportedModality::Text],
992                },
993            }),
994            topology: self.config.topology.clone(),
995            silent,
996            organization: self.config.organization,
997            template_filename: paths.get_template_filename().clone(),
998            generation_config: paths.get_gen_conf_filename().cloned(),
999            config,
1000            imatrix: self.config.imatrix.clone(),
1001            mapper: pipeline_mapper,
1002        })))
1003    }
1004
1005    fn get_id(&self) -> String {
1006        self.model_id.clone()
1007    }
1008
1009    fn get_kind(&self) -> ModelKind {
1010        self.kind.clone()
1011    }
1012}
1013
1014impl PreProcessingMixin for NormalPipeline {
1015    fn get_chat_template(&self) -> Option<Arc<ChatTemplate>> {
1016        Some(self.chat_template.clone())
1017    }
1018    fn get_input_processor_config(&self) -> Option<Arc<dyn Any>> {
1019        None
1020    }
1021}
1022
1023impl IsqPipelineMixin for NormalPipeline {
1024    fn re_isq_model(&mut self, dtype: IsqType) -> Result<()> {
1025        let device = self.device().clone();
1026        let multi_progress = Arc::new(new_multi_progress());
1027        self.model.quantize(
1028            Some(dtype),
1029            device.clone(),
1030            self.topology.as_ref(),
1031            self.silent,
1032            self.imatrix.as_ref().map(ImatrixDataSource::File),
1033            self.organization,
1034            true,
1035            None,
1036            UqffFullSer {
1037                tokenizer: &self.tokenizer,
1038                template_filename: &self.template_filename,
1039                generation_config: self.generation_config.as_ref(),
1040                config: self.config.clone(),
1041                processor_filename: &None,
1042                preprocessor_filename: &None,
1043                modules: None,
1044                module_paths: None,
1045            },
1046            multi_progress.clone(),
1047        )?;
1048        Ok(())
1049    }
1050}
1051
1052impl CacheManagerMixin for NormalPipeline {
1053    fn clone_in_cache(&self, seqs: &mut [&mut Sequence]) {
1054        match self.model.cache() {
1055            EitherCache::Full(_) => FullCacheManager.clone_in_cache(self, seqs, false),
1056            EitherCache::Normal(_) => NormalCacheManager.clone_in_cache(self, seqs, false),
1057            EitherCache::Hybrid(_) => HybridCacheManager.clone_in_cache(self, seqs, false),
1058        }
1059    }
1060    fn clone_out_cache(&self, seqs: &mut [&mut Sequence]) {
1061        match self.model.cache() {
1062            EitherCache::Full(_) => FullCacheManager.clone_out_cache(self, seqs, false),
1063            EitherCache::Normal(_) => NormalCacheManager.clone_out_cache(self, seqs, false),
1064            EitherCache::Hybrid(_) => HybridCacheManager.clone_out_cache(self, seqs, false),
1065        }
1066    }
1067    fn set_none_cache(
1068        &self,
1069        seqs: &mut [&mut Sequence],
1070        reset_non_granular: bool,
1071        modify_draft_cache: bool,
1072        load_preallocated_cache: bool,
1073    ) {
1074        match self.model.cache() {
1075            EitherCache::Full(_) => {
1076                FullCacheManager.set_none_cache(self, seqs, modify_draft_cache, false)
1077            }
1078            EitherCache::Normal(_) => NormalCacheManager.set_none_cache(
1079                self,
1080                seqs,
1081                modify_draft_cache,
1082                load_preallocated_cache,
1083            ),
1084            EitherCache::Hybrid(_) => HybridCacheManager.set_none_cache(
1085                self,
1086                seqs,
1087                modify_draft_cache,
1088                load_preallocated_cache,
1089            ),
1090        }
1091        if reset_non_granular {
1092            self.reset_non_granular_state()
1093        }
1094    }
1095    fn cache(&self) -> &EitherCache {
1096        self.model.cache()
1097    }
1098}
1099
1100impl MetadataMixin for NormalPipeline {
1101    fn device(&self) -> Device {
1102        self.model.device().clone()
1103    }
1104    fn tokenizer(&self) -> Option<Arc<Tokenizer>> {
1105        Some(self.tokenizer.clone())
1106    }
1107    fn name(&self) -> String {
1108        self.model_id.clone()
1109    }
1110    fn reset_non_granular_state(&self) {
1111        if let Some(s) = self.non_granular_state.as_ref() {
1112            *self.cache().full().get_scalings_cache() = None;
1113            *get_mut_arcmutex!(s.non_granular_index) = 0;
1114        }
1115    }
1116    fn get_metadata(&self) -> Arc<GeneralMetadata> {
1117        self.metadata.clone()
1118    }
1119    fn device_mapper(&self) -> Option<&dyn DeviceMapper> {
1120        Some(&*self.mapper)
1121    }
1122}
1123
1124#[async_trait::async_trait]
1125impl Pipeline for NormalPipeline {
1126    fn forward_inputs(
1127        &mut self,
1128        inputs: Box<dyn Any>,
1129        return_raw_logits: bool,
1130    ) -> Result<ForwardInputsResult, candle_core::Error> {
1131        let ModelInputs {
1132            input_ids,
1133            input_ids_full,
1134            seqlen_offsets,
1135            seqlen_offsets_full,
1136            context_lens,
1137            position_ids,
1138            paged_attn_meta,
1139            flash_meta,
1140            flash_meta_full,
1141        } = *inputs.downcast().expect("Downcast failed.");
1142        let metadata = self.get_metadata();
1143        let paged_attn_meta = match (&metadata.cache_engine, &paged_attn_meta) {
1144            (Some(cache_engine), Some(meta)) => Some((cache_engine, meta)),
1145            (Some(_), None) => {
1146                // This can happen if Rust-side user code is wrong
1147                candle_core::bail!("Forward step expected a PagedAttention input metadata. This was not provided, please ensure that the scheduler config is correctly configured for PagedAttention.")
1148            }
1149            (None, Some(_)) => {
1150                // This should never happen but we handle it anyway
1151                candle_core::bail!("Forward step got a PagedAttention input metadata but there is no cache engine. Please raise an issue.")
1152            }
1153            (None, None) => None,
1154        };
1155        let logits = match self.model.is_xlora() {
1156            false => {
1157                let paged_attn_meta = paged_attn_meta
1158                    .as_ref()
1159                    .map(|meta| (meta.0.get_kv_cache().clone(), meta.1.clone()));
1160
1161                self.model.forward(
1162                    &input_ids,
1163                    &seqlen_offsets,
1164                    context_lens,
1165                    position_ids,
1166                    paged_attn_meta.as_ref().map(|(a, b)| (a.clone(), b)),
1167                    &flash_meta,
1168                )?
1169            }
1170            true => self.model.xlora_forward(
1171                &input_ids,
1172                input_ids_full.as_ref().unwrap_or(&input_ids),
1173                &seqlen_offsets,
1174                seqlen_offsets_full.as_ref().unwrap_or(&seqlen_offsets),
1175                self.no_kv_cache,
1176                &self.non_granular_state,
1177                context_lens,
1178                position_ids,
1179                &flash_meta,
1180                flash_meta_full.as_ref().unwrap_or(&flash_meta),
1181            )?,
1182        };
1183        if return_raw_logits {
1184            Ok(ForwardInputsResult::RawLogits { logits })
1185        } else {
1186            Ok(ForwardInputsResult::CausalGeneration { logits })
1187        }
1188    }
1189    async fn sample_causal_gen(
1190        &self,
1191        seqs: &mut [&mut Sequence],
1192        logits: Vec<Tensor>,
1193        prefix_cacher: &mut PrefixCacheManagerV2,
1194        disable_eos_stop: bool,
1195        rng: Arc<std::sync::Mutex<Isaac64Rng>>,
1196    ) -> Result<(), candle_core::Error> {
1197        sample_and_add_toks(self, seqs, logits, prefix_cacher, disable_eos_stop, rng).await
1198    }
1199    fn category(&self) -> ModelCategory {
1200        ModelCategory::Text
1201    }
1202}
1203
1204impl AnyMoePipelineMixin for NormalPipeline {
1205    fn amoe_finish_training(&mut self, gate_model_id: Option<String>) -> candle_core::Result<()> {
1206        self.model.finish_training(gate_model_id)
1207    }
1208    fn amoe_layer_vars(&self) -> Vec<Vec<Var>> {
1209        self.model.get_vars()
1210    }
1211    fn amoe_base_model_trainable_params(&self) -> usize {
1212        self.model.trainable_params()
1213    }
1214    fn amoe_take_cached_gating_outputs(&mut self) -> Vec<Tensor> {
1215        self.model.take_cached_gating_outputs()
1216    }
1217    fn amoe_create_layers(
1218        &mut self,
1219        model_ids: Vec<String>,
1220        token: &TokenSource,
1221        revision: Option<String>,
1222        match_regex: &str,
1223        config: crate::amoe::AnyMoeConfig,
1224        dtype: candle_core::DType,
1225        dev: &Device,
1226        (prefix, mlp): (String, String),
1227        layers: Vec<usize>,
1228        expert_type: AnyMoeExpertType,
1229        silent: bool,
1230        gate_model_id: Option<String>,
1231    ) -> candle_core::Result<()> {
1232        let mut vbs = Vec::new();
1233        // Precompile regex here
1234        let regex = Regex::new(match_regex).map_err(candle_core::Error::msg)?;
1235        for model_id in model_ids {
1236            let model_id_str = &model_id;
1237            let model_id = Path::new(&model_id);
1238
1239            let api = {
1240                let cache = GLOBAL_HF_CACHE.get().cloned().unwrap_or_default();
1241                let mut api = ApiBuilder::from_cache(cache)
1242                    .with_progress(!silent)
1243                    .with_token(get_token(token).map_err(candle_core::Error::msg)?);
1244                if let Ok(x) = std::env::var("HF_HUB_CACHE") {
1245                    api = api.with_cache_dir(x.into());
1246                }
1247                api.build().map_err(candle_core::Error::msg)?
1248            };
1249            let revision = revision.clone().unwrap_or("main".to_string());
1250            let api = api.repo(Repo::with_revision(
1251                model_id_str.clone(),
1252                RepoType::Model,
1253                revision.clone(),
1254            ));
1255
1256            let mut filenames = vec![];
1257            for rfilename in
1258                api_dir_list!(api, model_id, true).filter(|x| x.ends_with(".safetensors"))
1259            {
1260                filenames.push(api_get_file!(api, &rfilename, model_id));
1261            }
1262
1263            let regex = regex.clone();
1264            let match_regex_clone = match_regex.to_string();
1265            let layers_clone = layers.clone();
1266            let vb = from_mmaped_safetensors(
1267                filenames,
1268                vec![],
1269                Some(dtype),
1270                dev,
1271                vec![None],
1272                silent,
1273                None,
1274                move |key| {
1275                    if regex.is_match(&key) {
1276                        // Idx of the last char of the layer id, +1
1277                        // Assumes N.MLP
1278                        let last_layer_idx = key.find(&match_regex_clone).unwrap() - 1;
1279                        let first_layer_idx = key[..last_layer_idx].rfind('.').unwrap();
1280                        let layer_n = key[first_layer_idx + 1..last_layer_idx]
1281                            .parse::<usize>()
1282                            .unwrap();
1283                        layers_clone.contains(&layer_n) || layers_clone.is_empty()
1284                    } else {
1285                        false
1286                    }
1287                },
1288                Arc::new(|_| DeviceForLoadTensor::Base),
1289            )?;
1290            vbs.push(vb);
1291        }
1292
1293        let gate_vb = if let Some(gate_model_id) = gate_model_id {
1294            let model_id_str = &gate_model_id;
1295            let model_id = Path::new(&gate_model_id);
1296
1297            let api = {
1298                let cache = GLOBAL_HF_CACHE.get().cloned().unwrap_or_default();
1299                let mut api = ApiBuilder::from_cache(cache)
1300                    .with_progress(!silent)
1301                    .with_token(get_token(token).map_err(candle_core::Error::msg)?);
1302                if let Ok(x) = std::env::var("HF_HUB_CACHE") {
1303                    api = api.with_cache_dir(x.into());
1304                }
1305                api.build().map_err(candle_core::Error::msg)?
1306            };
1307            let revision = revision.clone().unwrap_or("main".to_string());
1308            let api = api.repo(Repo::with_revision(
1309                model_id_str.clone(),
1310                RepoType::Model,
1311                revision.clone(),
1312            ));
1313
1314            let mut gate_filenames = vec![];
1315            for rfilename in
1316                api_dir_list!(api, model_id, true).filter(|x| x.ends_with(".safetensors"))
1317            {
1318                gate_filenames.push(api_get_file!(api, &rfilename, model_id));
1319            }
1320            assert_eq!(
1321                gate_filenames.len(),
1322                1,
1323                "Gate model ID must contain only one .safetensors file"
1324            );
1325
1326            let vb = from_mmaped_safetensors(
1327                gate_filenames.clone(),
1328                vec![],
1329                Some(dtype),
1330                dev,
1331                vec![None],
1332                silent,
1333                None,
1334                |_| true,
1335                Arc::new(|_| DeviceForLoadTensor::Base),
1336            )?;
1337            info!(
1338                "Loaded gating layers from `{}`",
1339                gate_filenames[0].display()
1340            );
1341            Some(vb)
1342        } else {
1343            None
1344        };
1345
1346        self.model.create_anymoe_layers(
1347            vbs.clone(),
1348            config.clone(),
1349            (prefix.clone(), mlp.clone()),
1350            layers.clone(),
1351            expert_type.clone(),
1352            gate_vb.clone(),
1353        )?;
1354
1355        Ok(())
1356    }
1357    fn amoe_supported(&self) -> bool {
1358        self.model.amoe_supported()
1359    }
1360}