mistralrs_core/pipeline/gguf.rs

use super::llg::build_llg_factory;
use super::{
    get_model_paths, get_xlora_paths, text_models_inputs_processor::ModelInputs, AdapterKind,
    CacheManager, GeneralMetadata, Loader, ModelKind, ModelPaths, PrettyName, QuantizationKind,
    TokenSource,
};
use super::{
    AnyMoePipelineMixin, CacheManagerMixin, EitherCache, ForwardInputsResult, IsqPipelineMixin,
    MetadataMixin, ModelCategory, PreProcessingMixin,
};
use crate::attention::ATTENTION_CHUNK_SIZE;
use crate::device_map::{self, DeviceMapper};
use crate::gguf::{
    get_gguf_chat_template, {convert_gguf_to_hf_tokenizer, GgufTokenizerConversion},
};
use crate::gguf::{Content, GGUFArchitecture};
use crate::kv_cache::{FullCacheManager, NormalCacheManager};
use crate::lora::Ordering;
use crate::paged_attention::{
    calculate_cache_config, AttentionImplementation, CacheEngine, ModelConfigLike,
};
use crate::pipeline::chat_template::{calculate_eos_tokens, BeginEndUnkPadTok, GenerationConfig};
use crate::pipeline::loaders::DeviceMappedModelLoader;
use crate::pipeline::sampling::sample_and_add_toks;
use crate::pipeline::ChatTemplate;
use crate::pipeline::{get_chat_template, Modalities, SupportedModality};
use crate::prefix_cacher::PrefixCacheManagerV2;
use crate::sequence::Sequence;
use crate::utils::gguf_metadata::{ContentConfig, GgufDeviceMapLoaderInner};
use crate::utils::model_config as ModelConfig;
use crate::utils::progress::ProgressScopeGuard;
use crate::utils::tokenizer::get_tokenizer;
use crate::xlora_models::NonGranularState;
use crate::{
    get_mut_arcmutex, get_paths_gguf, DeviceMapSetting, LocalModelPaths, PagedAttentionConfig,
    Pipeline, Topology, TryIntoDType,
};
use crate::{
    models::quantized_llama::ModelWeights as QLlama,
    models::quantized_phi2::ModelWeights as QPhi,
    models::quantized_phi3::ModelWeights as QPhi3,
    models::quantized_qwen::ModelWeights as QQwen,
    models::quantized_qwen3::ModelWeights as QQwen3,
    models::quantized_qwen3_moe::ModelWeights as QQwen3MoE,
    models::quantized_starcoder2::ModelWeights as QStarcoder2,
    utils::tokens::get_token,
    xlora_models::{XLoraQLlama, XLoraQPhi3},
};
use anyhow::{bail, Result};
use candle_core::{Device, Tensor};
use either::Either;
use hf_hub::{api::sync::ApiBuilder, Repo, RepoType};
use mistralrs_quant::IsqType;
use rand_isaac::Isaac64Rng;
use std::any::Any;
use std::fs;
use std::path::PathBuf;
use std::str::FromStr;
use std::sync::Arc;
use tokenizers::Tokenizer;
use tokio::sync::Mutex;
use tracing::{info, warn};

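/// The concrete quantized model held by the pipeline: one variant per supported GGUF
/// architecture, plus X-LoRA-wrapped variants for adapter models.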
enum Model {
    Llama(QLlama),
    Phi2(QPhi),
    XLoraLlama(XLoraQLlama),
    XLoraPhi3(XLoraQPhi3),
    Phi3(QPhi3),
    Starcoder2(QStarcoder2),
    Qwen(QQwen),
    Qwen3(QQwen3),
    Qwen3MoE(QQwen3MoE),
}

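/// A loaded GGUF pipeline: the quantized model together with its tokenizer, chat template,
/// device mapper, and general metadata.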
pub struct GGUFPipeline {
    model: Model,
    tokenizer: Arc<Tokenizer>,
    no_kv_cache: bool,
    chat_template: Arc<ChatTemplate>,
    model_id: String,
    non_granular_state: Option<NonGranularState>,
    metadata: Arc<GeneralMetadata>,
    mapper: Box<dyn DeviceMapper + Send + Sync>,
}

/// Loader for a GGUF model.
pub struct GGUFLoader {
    model_id: Option<String>,
    quantized_model_id: String,
    quantized_filenames: Vec<String>,
    xlora_model_id: Option<String>,
    xlora_order: Option<Ordering>,
    no_kv_cache: bool,
    chat_template: Option<String>,
    kind: ModelKind,
    tgt_non_granular_index: Option<usize>,
    config: GGUFSpecificConfig,
    jinja_explicit: Option<String>,
    lora_adapter_ids: Option<Vec<String>>,
}

#[derive(Clone, Default)]
/// Config for a GGUF loader.
pub struct GGUFSpecificConfig {
    pub topology: Option<Topology>,
}

#[derive(Default)]
/// A builder for a GGUF loader.
pub struct GGUFLoaderBuilder {
    model_id: Option<String>,
    quantized_model_id: String,
    quantized_filenames: Vec<String>,
    xlora_model_id: Option<String>,
    kind: ModelKind,
    xlora_order: Option<Ordering>,
    no_kv_cache: bool,
    chat_template: Option<String>,
    tgt_non_granular_index: Option<usize>,
    config: GGUFSpecificConfig,
    jinja_explicit: Option<String>,
}

impl GGUFLoaderBuilder {
    /// Create a loader builder for a GGUF model. `tok_model_id` is the model ID where a
    /// `tokenizer_config.json` file can be found. If `chat_template` is specified, it is treated
    /// as a path to a local chat template and used instead of any remote file, avoiding remote accesses.
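    ///
    /// A minimal usage sketch; the repository and file names below are hypothetical, and the
    /// crate-root re-exports are assumed:
    /// ```ignore
    /// use mistralrs_core::{GGUFLoaderBuilder, GGUFSpecificConfig};
    ///
    /// let loader = GGUFLoaderBuilder::new(
    ///     None,                                    // chat_template: none provided
    ///     Some("some-org/some-model".to_string()), // tok_model_id (hypothetical)
    ///     "some-org/some-model-GGUF".to_string(),  // quantized_model_id (hypothetical)
    ///     vec!["model-q4_k_m.gguf".to_string()],   // quantized_filenames (hypothetical)
    ///     GGUFSpecificConfig::default(),
    ///     false, // no_kv_cache
    ///     None,  // jinja_explicit
    /// )
    /// .build();
    /// ```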
    pub fn new(
        chat_template: Option<String>,
        tok_model_id: Option<String>,
        quantized_model_id: String,
        quantized_filenames: Vec<String>,
        config: GGUFSpecificConfig,
        no_kv_cache: bool,
        jinja_explicit: Option<String>,
    ) -> Self {
        let kind = ModelKind::GgufQuantized {
            quant: QuantizationKind::Gguf,
        };

        Self {
            chat_template,
            model_id: tok_model_id,
            kind,
            quantized_filenames,
            quantized_model_id,
            config,
            jinja_explicit,
            no_kv_cache,
            ..Default::default()
        }
    }

    fn with_adapter(
        mut self,
        xlora_model_id: String,
        xlora_order: Ordering,
        no_kv_cache: bool,
        tgt_non_granular_index: Option<usize>,
    ) -> Self {
        self.xlora_model_id = Some(xlora_model_id);
        self.xlora_order = Some(xlora_order);
        self.no_kv_cache = no_kv_cache;
        self.tgt_non_granular_index = tgt_non_granular_index;
        self.model_id = if let Some(id) = self.model_id {
            Some(id)
        } else {
            info!(
                "Using adapter base model ID: `{}`",
                self.xlora_order.as_ref().unwrap().base_model_id
            );
            Some(self.xlora_order.as_ref().unwrap().base_model_id.clone())
        };
        self
    }

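    /// Configure this loader for an X-LoRA adapter model. `xlora_model_id` is the adapter model
    /// ID and `xlora_order` describes the adapter ordering; if no base model ID was set, the
    /// ordering's base model ID is used.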
    pub fn with_xlora(
        mut self,
        xlora_model_id: String,
        xlora_order: Ordering,
        no_kv_cache: bool,
        tgt_non_granular_index: Option<usize>,
    ) -> Self {
        self.kind = (AdapterKind::XLora, QuantizationKind::Gguf).into();

        self.with_adapter(
            xlora_model_id,
            xlora_order,
            no_kv_cache,
            tgt_non_granular_index,
        )
    }

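    /// Configure this loader for a LoRA adapter model. The KV cache stays enabled and no
    /// non-granular index is set.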
    pub fn with_lora(mut self, lora_model_id: String, lora_order: Ordering) -> Self {
        self.kind = (AdapterKind::Lora, QuantizationKind::Gguf).into();

        self.with_adapter(lora_model_id, lora_order, false, None)
    }

    pub fn build(self) -> Box<dyn Loader> {
        Box::new(GGUFLoader {
            model_id: self.model_id,
            xlora_model_id: self.xlora_model_id,
            kind: self.kind,
            xlora_order: self.xlora_order,
            no_kv_cache: self.no_kv_cache,
            chat_template: self.chat_template,
            tgt_non_granular_index: self.tgt_non_granular_index,
            quantized_filenames: self.quantized_filenames,
            quantized_model_id: self.quantized_model_id,
            config: self.config,
            jinja_explicit: self.jinja_explicit,
            lora_adapter_ids: None,
        })
    }
}

impl GGUFLoader {
    #[allow(clippy::too_many_arguments)]
    pub fn new(
        model_id: Option<String>,
        quantized_model_id: String,
        quantized_filenames: Vec<String>,
        xlora_model_id: Option<String>,
        kind: ModelKind,
        xlora_order: Option<Ordering>,
        no_kv_cache: bool,
        chat_template: Option<String>,
        tgt_non_granular_index: Option<usize>,
        config: GGUFSpecificConfig,
        jinja_explicit: Option<String>,
    ) -> Self {
        let model_id = if let Some(id) = model_id {
            Some(id)
        } else if let Some(xlora_order) = xlora_order.clone() {
            info!(
                "Using adapter base model ID: `{}`",
                xlora_order.base_model_id
            );
            Some(xlora_order.base_model_id.clone())
        } else {
            None
        };
        Self {
            model_id,
            quantized_model_id,
            quantized_filenames,
            xlora_model_id,
            xlora_order,
            no_kv_cache,
            chat_template,
            kind,
            tgt_non_granular_index,
            config,
            jinja_explicit,
            lora_adapter_ids: None,
        }
    }
}

impl Loader for GGUFLoader {
    #[allow(clippy::type_complexity, clippy::too_many_arguments)]
    fn load_model_from_hf(
        &self,
        revision: Option<String>,
        token_source: TokenSource,
        dtype: &dyn TryIntoDType,
        device: &Device,
        silent: bool,
        mapper: DeviceMapSetting,
        in_situ_quant: Option<IsqType>,
        paged_attn_config: Option<PagedAttentionConfig>,
    ) -> Result<Arc<Mutex<dyn Pipeline + Send + Sync>>> {
        let _progress_guard = ProgressScopeGuard::new(silent);
        let paths: anyhow::Result<Box<dyn ModelPaths>> = get_paths_gguf!(
            LocalModelPaths,
            &token_source,
            revision,
            self,
            self.quantized_model_id.clone(),
            self.quantized_filenames.clone(),
            silent
        );
        self.load_model_from_path(
            &paths?,
            dtype,
            device,
            silent,
            mapper,
            in_situ_quant,
            paged_attn_config,
        )
    }

    #[allow(clippy::type_complexity, clippy::too_many_arguments)]
    fn load_model_from_path(
        &self,
        paths: &Box<dyn ModelPaths>,
        dtype: &dyn TryIntoDType,
        device: &Device,
        silent: bool,
        mut mapper: DeviceMapSetting,
        in_situ_quant: Option<IsqType>,
        mut paged_attn_config: Option<PagedAttentionConfig>,
    ) -> Result<Arc<Mutex<dyn Pipeline + Send + Sync>>> {
        let _progress_guard = ProgressScopeGuard::new(silent);
        if in_situ_quant.is_some() {
            anyhow::bail!(
                "You are trying to in-situ quantize a GGUF model. This will not do anything."
            );
        }

        info!("Prompt chunk size is {ATTENTION_CHUNK_SIZE}.");

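        // Open every GGUF shard listed in the paths and parse them together as a single
        // logical `Content` (metadata plus tensor index).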
        let mut readers = Vec::new();
        for filename in paths.get_weight_filenames() {
            readers.push(std::fs::File::open(filename)?);
        }
        let mut readers = readers.iter_mut().collect::<Vec<_>>();

        let model = Content::from_readers(&mut readers)?;
        if !silent {
            model.print_metadata()?;
        }
        let arch = model.arch();

        // If auto, convert to Map
        let num_layers = model.get_metadata()[&format!("{arch}.block_count")].to_u32()? as usize;
        if let DeviceMapSetting::Auto(params) = mapper.clone() {
            let devices = device_map::get_all_similar_devices(device)?;
            // Initial dtype
            let dtype = dtype.try_into_dtype(&devices.iter().collect::<Vec<_>>())?;

            let model = GgufDeviceMapLoaderInner {
                model: &model,
                arch,
            };

            let layer_sizes_in_bytes =
                model.layer_sizes_in_bytes("this is a dummy config!", dtype, 1, None)?;
            let non_mapped_size_in_bytes =
                model.non_mapped_size_in_bytes("this is a dummy config!", dtype, 1, None)?;
            let total_model_size_in_bytes =
                layer_sizes_in_bytes.iter().sum::<usize>() + non_mapped_size_in_bytes;

            let new = model.get_device_layers(
                "this is a dummy config!",
                num_layers,
                layer_sizes_in_bytes,
                non_mapped_size_in_bytes,
                total_model_size_in_bytes,
                &devices,
                dtype,
                &params,
                paged_attn_config.as_ref(),
            )?;
            mapper = DeviceMapSetting::Map(new);
        }

        #[cfg(feature = "cuda")]
        if let Device::Cuda(dev) = &device {
            unsafe { dev.disable_event_tracking() };
        }

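        // Build the device mapper twice: one copy is kept on the pipeline, the other is consumed
        // while constructing the model. Also record each layer's target device for the
        // PagedAttention cache engine.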
        let pipeline_mapper =
            mapper.into_mapper(num_layers, device, self.config.topology.as_ref())?;
        let mapper = mapper.into_mapper(num_layers, device, self.config.topology.as_ref())?;
        let mut layer_devices = Vec::new();
        for layer in 0..num_layers {
            let device = mapper.device_for(layer, false).cloned();
            layer_devices.push(device);
        }

        // TODO: PagedAttention is not supported with CPU for now.
        // This check is not really necessary because `get_device_layers` should prevent it.
        let mapping_uses_cpu = mapper.get_unique_devices().iter().any(Device::is_cpu);
        if mapping_uses_cpu {
            warn!("Device mapping contains a mix of GPU and CPU. There is no CPU support for PagedAttention, disabling PagedAttention.");
            paged_attn_config = None;
        }

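        // If no tokenizer.json path was resolved, reconstruct a Hugging Face tokenizer (and the
        // BOS/EOS/UNK tokens) from the GGUF metadata; otherwise load the provided tokenizer file.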
        let GgufTokenizerConversion {
            tokenizer,
            bos,
            eos,
            unk,
        } = if paths.get_tokenizer_filename().to_string_lossy().is_empty() {
            convert_gguf_to_hf_tokenizer(&model)?
        } else {
            GgufTokenizerConversion {
                tokenizer: get_tokenizer(paths.get_tokenizer_filename(), None)?,
                bos: None,
                eos: None,
                unk: None,
            }
        };

        // Only load gguf chat template if there is nothing else
        let gguf_chat_template =
            if paths.get_template_filename().is_none() && self.chat_template.is_none() {
                get_gguf_chat_template(&model)?
            } else {
                None
            };

        let has_adapter = self.kind.is_adapted();
        let is_xlora = self.kind.is_adapted_and(|a| a.is_x_lora());

        let paged_attn_config = if matches!(self.kind, ModelKind::GgufAdapter { .. }) {
            warn!("Adapter models do not currently support PagedAttention; running without it.");
            None
        } else {
            paged_attn_config
        };

        let model_config_metadata: ContentConfig = (&model).into();
        let internal_dtype = mapper.get_min_dtype(dtype)?;

        let model_config = {
            // Base config (quantization only):
            let quant = ModelConfig::ParamsGGUF(
                model,
                (device, mapper).into(),
                if paged_attn_config.is_some() {
                    AttentionImplementation::PagedAttention
                } else {
                    AttentionImplementation::Eager
                },
                internal_dtype,
            );

            // With optional adapter config:
            let mut adapter = None;
            if has_adapter {
                adapter.replace(ModelConfig::Adapter::try_new(
                    paths, device, silent, is_xlora,
                )?);
            }

            ModelConfig::ModelParams::new(quant, adapter)
        };

        // Config into model:
        let model = match self.kind {
            ModelKind::GgufQuantized { .. } => match arch {
                GGUFArchitecture::Llama => Model::Llama(QLlama::try_from(model_config)?),
                GGUFArchitecture::Phi2 => Model::Phi2(QPhi::try_from(model_config)?),
                GGUFArchitecture::Phi3 => Model::Phi3(QPhi3::try_from(model_config)?),
                GGUFArchitecture::Starcoder2 => {
                    Model::Starcoder2(QStarcoder2::try_from(model_config)?)
                }
                GGUFArchitecture::Qwen2 => Model::Qwen(QQwen::try_from(model_config)?),
                GGUFArchitecture::Qwen3 => Model::Qwen3(QQwen3::try_from(model_config)?),
                GGUFArchitecture::Qwen3MoE => Model::Qwen3MoE(QQwen3MoE::try_from(model_config)?),
                a => bail!("Unsupported architecture `{a:?}` for GGUF"),
            },
            ModelKind::GgufAdapter { adapter, .. } => match arch {
                GGUFArchitecture::Llama => Model::XLoraLlama(XLoraQLlama::try_from(model_config)?),
                GGUFArchitecture::Phi3 => Model::XLoraPhi3(XLoraQPhi3::try_from(model_config)?),
                a => bail!(
                    "Unsupported architecture `{a:?}` for GGUF {kind}",
                    kind = adapter.pretty_name()
                ),
            },
            _ => unreachable!(),
        };

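        // When PagedAttention is enabled, size the KV-cache blocks from the requested GPU memory
        // and build the cache engine across the mapped layer devices.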
        let (cache_config, cache_engine) = if let Some(paged_attn_config) = paged_attn_config {
            let model_config: &dyn ModelConfigLike = &model_config_metadata;
            let cache_config = calculate_cache_config(
                paged_attn_config.mem_gpu,
                paged_attn_config.block_size,
                internal_dtype,
                paged_attn_config.cache_type,
                model_config,
                device,
                &layer_devices,
                silent,
            )?;
            let cache_engine = CacheEngine::new(
                model_config,
                &cache_config,
                internal_dtype,
                device,
                layer_devices,
            )?;
            (Some(cache_config), Some(cache_engine))
        } else {
            (None, None)
        };

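        // Load the generation config (if present) and resolve the chat template from the
        // available sources: explicit JINJA, provided template files, or the GGUF-embedded
        // template.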
        let gen_conf: Option<GenerationConfig> = paths
            .get_gen_conf_filename()
            .map(|f| serde_json::from_str(&fs::read_to_string(f).unwrap()).unwrap());
        let chat_template_explicit = paths
            .get_chat_template_explicit()
            .as_ref()
            .map(|x| x.to_string_lossy().to_string());
        let mut chat_template = get_chat_template(
            paths,
            self.jinja_explicit.as_ref(),
            chat_template_explicit.as_ref(),
            self.chat_template.as_ref(),
            gguf_chat_template,
        );

        let max_seq_len = match model {
            Model::Llama(ref l) => l.max_seq_len,
            Model::Phi2(ref p) => p.max_seq_len,
            Model::XLoraLlama(ref xl) => xl.max_seq_len,
            Model::Phi3(ref p) => p.max_seq_len,
            Model::XLoraPhi3(ref p) => p.max_seq_len,
            Model::Starcoder2(ref p) => p.max_seq_len,
            Model::Qwen(ref p) => p.max_seq_len,
            Model::Qwen3(ref p) => p.max_seq_len,
            Model::Qwen3MoE(ref p) => p.max_seq_len,
        };
        let llg_factory = build_llg_factory(tokenizer.clone())?;
        let num_hidden_layers = match model {
            Model::Llama(ref model) => model.cache.normal().0.len(),
            Model::Phi2(ref model) => model.cache.normal().0.len(),
            Model::XLoraLlama(ref model) => model.cache.full().lock().len(),
            Model::Phi3(ref model) => model.cache.normal().0.len(),
            Model::XLoraPhi3(ref model) => model.cache.full().lock().len(),
            Model::Starcoder2(ref model) => model.cache.normal().0.len(),
            Model::Qwen(ref model) => model.cache.normal().0.len(),
            Model::Qwen3(ref model) => model.cache.normal().0.len(),
            Model::Qwen3MoE(ref model) => model.cache.normal().0.len(),
        };

        if chat_template.bos_token.is_none() {
            if let Some(v) = bos {
                chat_template.bos_token = Some(BeginEndUnkPadTok(Either::Left(v)));
            }
        }
        if chat_template.eos_token.is_none() {
            if let Some(v) = eos {
                chat_template.eos_token = Some(BeginEndUnkPadTok(Either::Left(v)));
            }
        }
        if chat_template.unk_token.is_none() {
            if let Some(v) = unk {
                chat_template.unk_token = Some(BeginEndUnkPadTok(Either::Left(v)));
            }
        }

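        // Compute the final EOS token set and assemble the pipeline with its metadata.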
        let eos = calculate_eos_tokens(&chat_template, gen_conf, &tokenizer);
        Ok(Arc::new(Mutex::new(GGUFPipeline {
            model,
            tokenizer: tokenizer.into(),
            no_kv_cache: self.no_kv_cache,
            chat_template: Arc::new(chat_template),
            model_id: self
                .model_id
                .clone()
                .unwrap_or(self.quantized_model_id.clone()),
            non_granular_state: self.tgt_non_granular_index.map(|tgt_non_granular_index| {
                NonGranularState {
                    non_granular_index: Arc::new(Mutex::new(0)),
                    tgt_non_granular_index,
                }
            }),
            metadata: Arc::new(GeneralMetadata {
                max_seq_len,
                llg_factory: Some(llg_factory),
                no_kv_cache: self.no_kv_cache,
                no_prefix_cache: false,
                num_hidden_layers,
                eos_tok: eos,
                kind: self.kind.clone(),
                is_xlora,
                activation_dtype: internal_dtype,
                sliding_window: None,
                cache_config,
                cache_engine,
                model_metadata: Some(Arc::new(model_config_metadata)),
                modalities: Modalities {
                    input: vec![SupportedModality::Text],
                    output: vec![SupportedModality::Text],
                },
            }),
            mapper: pipeline_mapper,
        })))
    }

    fn get_id(&self) -> String {
        self.xlora_model_id
            .as_deref()
            .unwrap_or(self.model_id.as_ref().unwrap_or(&self.quantized_model_id))
            .to_string()
    }

    fn get_kind(&self) -> ModelKind {
        self.kind.clone()
    }
}

impl PreProcessingMixin for GGUFPipeline {
    fn get_chat_template(&self) -> Option<Arc<ChatTemplate>> {
        Some(self.chat_template.clone())
    }
    fn get_input_processor_config(&self) -> Option<Arc<dyn Any>> {
        None
    }
}

impl IsqPipelineMixin for GGUFPipeline {
    fn re_isq_model(&mut self, _dtype: IsqType) -> Result<()> {
        anyhow::bail!(
            "You are trying to in-situ requantize a GGUF model. This will not do anything."
        )
    }
}

impl CacheManagerMixin for GGUFPipeline {
    fn clone_in_cache(&self, seqs: &mut [&mut Sequence]) {
        if matches!(self.cache(), EitherCache::Full(_)) {
            FullCacheManager.clone_in_cache(self, seqs, false)
        } else {
            NormalCacheManager.clone_in_cache(self, seqs, false)
        }
    }
    fn clone_out_cache(&self, seqs: &mut [&mut Sequence]) {
        if matches!(self.cache(), EitherCache::Full(_)) {
            FullCacheManager.clone_out_cache(self, seqs, false)
        } else {
            NormalCacheManager.clone_out_cache(self, seqs, false)
        }
    }
    fn set_none_cache(
        &self,
        seqs: &mut [&mut Sequence],
        reset_non_granular: bool,
        modify_draft_cache: bool,
        load_preallocated_cache: bool,
    ) {
        if matches!(self.cache(), EitherCache::Full(_)) {
            FullCacheManager.set_none_cache(self, seqs, modify_draft_cache, false);
        } else {
            NormalCacheManager.set_none_cache(
                self,
                seqs,
                modify_draft_cache,
                load_preallocated_cache,
            );
        }
        if reset_non_granular {
            self.reset_non_granular_state()
        }
    }
    fn cache(&self) -> &EitherCache {
        match self.model {
            Model::Llama(ref model) => &model.cache,
            Model::Phi2(ref model) => &model.cache,
            Model::XLoraLlama(ref model) => &model.cache,
            Model::Phi3(ref model) => &model.cache,
            Model::XLoraPhi3(ref model) => &model.cache,
            Model::Starcoder2(ref model) => &model.cache,
            Model::Qwen(ref model) => &model.cache,
            Model::Qwen3(ref model) => &model.cache,
            Model::Qwen3MoE(ref model) => &model.cache,
        }
    }
}

impl MetadataMixin for GGUFPipeline {
    fn device(&self) -> Device {
        match self.model {
            Model::Llama(ref model) => model.device.clone(),
            Model::Phi2(ref model) => model.device.clone(),
            Model::XLoraLlama(ref model) => model.device.clone(),
            Model::Phi3(ref model) => model.device.clone(),
            Model::XLoraPhi3(ref model) => model.device.clone(),
            Model::Starcoder2(ref model) => model.device.clone(),
            Model::Qwen(ref model) => model.device.clone(),
            Model::Qwen3(ref model) => model.device.clone(),
            Model::Qwen3MoE(ref model) => model.device.clone(),
        }
    }
    fn tokenizer(&self) -> Option<Arc<Tokenizer>> {
        Some(self.tokenizer.clone())
    }
    fn name(&self) -> String {
        self.model_id.clone()
    }
    fn reset_non_granular_state(&self) {
        if let Some(s) = self.non_granular_state.as_ref() {
            *self.cache().full().get_scalings_cache() = None;
            *get_mut_arcmutex!(s.non_granular_index) = 0;
        }
    }
    fn get_metadata(&self) -> Arc<GeneralMetadata> {
        self.metadata.clone()
    }
    fn device_mapper(&self) -> Option<&dyn DeviceMapper> {
        Some(&*self.mapper)
    }
}

#[async_trait::async_trait]
impl Pipeline for GGUFPipeline {
    fn forward_inputs(
        &mut self,
        inputs: Box<dyn Any>,
        return_raw_logits: bool,
    ) -> Result<ForwardInputsResult, candle_core::Error> {
        let ModelInputs {
            input_ids,
            input_ids_full,
            seqlen_offsets,
            seqlen_offsets_full,
            context_lens,
            position_ids: _, // NOTE(EricLBuehler): ignore, it is for phi3
            paged_attn_meta,
            flash_meta,
            flash_meta_full,
        } = *inputs.downcast().expect("Downcast failed.");
        let metadata = self.get_metadata();
        let paged_attn_meta = match (&metadata.cache_engine, &paged_attn_meta) {
            (Some(engine), Some(meta)) => Some((engine.get_kv_cache().clone(), meta)),
            (Some(_), None) => {
                // This can happen if Rust-side user code is wrong
                candle_core::bail!("Forward step expected PagedAttention input metadata, but none was provided. Please ensure that the scheduler config is correctly configured for PagedAttention.")
            }
            (None, Some(_)) => {
                // This should never happen but we handle it anyway
                candle_core::bail!("Forward step got a PagedAttention input metadata but there is no cache engine. Please raise an issue.")
            }
            (None, None) => None,
        };
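        // Dispatch to the concrete model; X-LoRA variants additionally take the full
        // (non-cached) inputs and the non-granular scalings state.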
        let logits = match self.model {
            Model::Llama(ref model) => {
                model.forward(&input_ids, &seqlen_offsets, context_lens, paged_attn_meta)?
            }
            Model::Phi2(ref model) => {
                model.forward(&input_ids, &seqlen_offsets, context_lens, paged_attn_meta)?
            }
            Model::XLoraLlama(ref model) => model.forward(
                &input_ids,
                input_ids_full.as_ref().unwrap_or(&input_ids),
                &seqlen_offsets,
                seqlen_offsets_full.as_ref().unwrap_or(&seqlen_offsets),
                self.no_kv_cache,
                &self.non_granular_state,
                context_lens,
                &flash_meta,
                flash_meta_full.as_ref().unwrap_or(&flash_meta),
            )?,
            Model::Phi3(ref model) => {
                model.forward(&input_ids, &seqlen_offsets, paged_attn_meta)?
            }
            Model::XLoraPhi3(ref model) => model.forward(
                &input_ids,
                input_ids_full.as_ref().unwrap_or(&input_ids),
                &seqlen_offsets,
                seqlen_offsets_full.as_ref().unwrap_or(&seqlen_offsets),
                self.no_kv_cache,
                &self.non_granular_state,
                context_lens,
                &flash_meta,
                flash_meta_full.as_ref().unwrap_or(&flash_meta),
            )?,
            Model::Starcoder2(ref model) => {
                model.forward(&input_ids, &seqlen_offsets, paged_attn_meta)?
            }
            Model::Qwen(ref model) => {
                model.forward(&input_ids, &seqlen_offsets, context_lens, paged_attn_meta)?
            }
            Model::Qwen3(ref model) => {
                model.forward(&input_ids, &seqlen_offsets, context_lens, paged_attn_meta)?
            }
            Model::Qwen3MoE(ref model) => {
                model.forward(&input_ids, &seqlen_offsets, context_lens, paged_attn_meta)?
            }
        };
        if return_raw_logits {
            Ok(ForwardInputsResult::RawLogits { logits })
        } else {
            Ok(ForwardInputsResult::CausalGeneration { logits })
        }
    }
    async fn sample_causal_gen(
        &self,
        seqs: &mut [&mut Sequence],
        logits: Vec<Tensor>,
        prefix_cacher: &mut PrefixCacheManagerV2,
        disable_eos_stop: bool,
        rng: Arc<std::sync::Mutex<Isaac64Rng>>,
    ) -> Result<(), candle_core::Error> {
        sample_and_add_toks(self, seqs, logits, prefix_cacher, disable_eos_stop, rng).await
    }
    fn category(&self) -> ModelCategory {
        ModelCategory::Text
    }
}

// TODO
impl AnyMoePipelineMixin for GGUFPipeline {}