// mistralrs_core/pipeline/gguf.rs

use super::llg::build_llg_factory;
use super::{
    get_model_paths, get_xlora_paths, text_models_inputs_processor::ModelInputs, AdapterKind,
    CacheManager, GeneralMetadata, Loader, ModelKind, ModelPaths, PrettyName, QuantizationKind,
    TokenSource,
};
use super::{
    AnyMoePipelineMixin, CacheManagerMixin, EitherCache, ForwardInputsResult, IsqPipelineMixin,
    MetadataMixin, ModelCategory, PreProcessingMixin,
};
use crate::attention::ATTENTION_CHUNK_SIZE;
use crate::device_map::{self, DeviceMapper};
use crate::gguf::{
    get_gguf_chat_template, {convert_gguf_to_hf_tokenizer, GgufTokenizerConversion},
};
use crate::gguf::{Content, GGUFArchitecture};
use crate::kv_cache::{FullCacheManager, NormalCacheManager};
use crate::lora::Ordering;
use crate::paged_attention::{
    calculate_cache_config, AttentionImplementation, CacheEngine, ModelConfigLike,
};
use crate::pipeline::chat_template::{calculate_eos_tokens, BeginEndUnkPadTok, GenerationConfig};
use crate::pipeline::loaders::DeviceMappedModelLoader;
use crate::pipeline::sampling::sample_and_add_toks;
use crate::pipeline::ChatTemplate;
use crate::pipeline::{get_chat_template, Modalities, SupportedModality};
use crate::prefix_cacher::PrefixCacheManagerV2;
use crate::sequence::Sequence;
use crate::utils::gguf_metadata::{ContentConfig, GgufDeviceMapLoaderInner};
use crate::utils::model_config as ModelConfig;
use crate::utils::tokenizer::get_tokenizer;
use crate::xlora_models::NonGranularState;
use crate::{
    get_mut_arcmutex, get_paths_gguf, DeviceMapSetting, LocalModelPaths, PagedAttentionConfig,
    Pipeline, Topology, TryIntoDType,
};
use crate::{
    models::quantized_llama::ModelWeights as QLlama,
    models::quantized_phi2::ModelWeights as QPhi,
    models::quantized_phi3::ModelWeights as QPhi3,
    models::quantized_qwen::ModelWeights as QQwen,
    models::quantized_qwen3::ModelWeights as QQwen3,
    models::quantized_qwen3_moe::ModelWeights as QQwen3MoE,
    models::quantized_starcoder2::ModelWeights as QStarcoder2,
    utils::tokens::get_token,
    xlora_models::{XLoraQLlama, XLoraQPhi3},
};
use anyhow::{bail, Result};
use candle_core::{Device, Tensor};
use either::Either;
use hf_hub::{api::sync::ApiBuilder, Repo, RepoType};
use mistralrs_quant::IsqType;
use rand_isaac::Isaac64Rng;
use std::any::Any;
use std::fs;
use std::path::PathBuf;
use std::str::FromStr;
use std::sync::Arc;
use tokenizers::Tokenizer;
use tokio::sync::Mutex;
use tracing::{info, warn};

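/// Concrete quantized model variants this pipeline can hold, one per supported GGUF architecture.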
enum Model {
    Llama(QLlama),
    Phi2(QPhi),
    XLoraLlama(XLoraQLlama),
    XLoraPhi3(XLoraQPhi3),
    Phi3(QPhi3),
    Starcoder2(QStarcoder2),
    Qwen(QQwen),
    Qwen3(QQwen3),
    Qwen3MoE(QQwen3MoE),
}

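/// A fully loaded GGUF model pipeline: the quantized model, tokenizer, chat template, and
/// the metadata and device mapping needed to run inference.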
pub struct GGUFPipeline {
    model: Model,
    tokenizer: Arc<Tokenizer>,
    no_kv_cache: bool,
    chat_template: Arc<ChatTemplate>,
    model_id: String,
    non_granular_state: Option<NonGranularState>,
    metadata: Arc<GeneralMetadata>,
    mapper: Box<dyn DeviceMapper + Send + Sync>,
}

/// Loader for a GGUF model.
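/// Usually constructed through [`GGUFLoaderBuilder`]; it resolves model paths and produces a
/// [`GGUFPipeline`] via the [`Loader`] trait.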
pub struct GGUFLoader {
    model_id: Option<String>,
    quantized_model_id: String,
    quantized_filenames: Vec<String>,
    xlora_model_id: Option<String>,
    xlora_order: Option<Ordering>,
    no_kv_cache: bool,
    chat_template: Option<String>,
    kind: ModelKind,
    tgt_non_granular_index: Option<usize>,
    config: GGUFSpecificConfig,
    jinja_explicit: Option<String>,
    lora_adapter_ids: Option<Vec<String>>,
}

#[derive(Clone, Default)]
/// Config for a GGUF loader.
pub struct GGUFSpecificConfig {
    pub topology: Option<Topology>,
}

#[derive(Default)]
/// A builder for a GGUF loader.
pub struct GGUFLoaderBuilder {
    model_id: Option<String>,
    quantized_model_id: String,
    quantized_filenames: Vec<String>,
    xlora_model_id: Option<String>,
    kind: ModelKind,
    xlora_order: Option<Ordering>,
    no_kv_cache: bool,
    chat_template: Option<String>,
    tgt_non_granular_index: Option<usize>,
    config: GGUFSpecificConfig,
    jinja_explicit: Option<String>,
}

impl GGUFLoaderBuilder {
    /// Create a loader builder for a GGUF model. `tok_model_id` is the model ID from which the
    /// `tokenizer_config.json` file is loaded. If `chat_template` is specified, it is treated as a
    /// local path and used instead of any remote file, so no remote access is required.
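    ///
    /// A minimal usage sketch (the model and file names below are illustrative placeholders,
    /// not files shipped with this crate):
    ///
    /// ```ignore
    /// let loader = GGUFLoaderBuilder::new(
    ///     None,                                         // chat_template: fall back to remote/GGUF template
    ///     Some("some-org/some-base-model".to_string()), // tok_model_id
    ///     "some-org/some-model-GGUF".to_string(),       // quantized_model_id
    ///     vec!["model-q4_k_m.gguf".to_string()],        // quantized_filenames
    ///     GGUFSpecificConfig::default(),
    ///     false,                                        // no_kv_cache
    ///     None,                                         // jinja_explicit
    /// )
    /// .build();
    /// ```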
    pub fn new(
        chat_template: Option<String>,
        tok_model_id: Option<String>,
        quantized_model_id: String,
        quantized_filenames: Vec<String>,
        config: GGUFSpecificConfig,
        no_kv_cache: bool,
        jinja_explicit: Option<String>,
    ) -> Self {
        let kind = ModelKind::GgufQuantized {
            quant: QuantizationKind::Gguf,
        };

        Self {
            chat_template,
            model_id: tok_model_id,
            kind,
            quantized_filenames,
            quantized_model_id,
            config,
            jinja_explicit,
            no_kv_cache,
            ..Default::default()
        }
    }

    fn with_adapter(
        mut self,
        xlora_model_id: String,
        xlora_order: Ordering,
        no_kv_cache: bool,
        tgt_non_granular_index: Option<usize>,
    ) -> Self {
        self.xlora_model_id = Some(xlora_model_id);
        self.xlora_order = Some(xlora_order);
        self.no_kv_cache = no_kv_cache;
        self.tgt_non_granular_index = tgt_non_granular_index;
        self.model_id = if let Some(id) = self.model_id {
            Some(id)
        } else {
            info!(
                "Using adapter base model ID: `{}`",
                self.xlora_order.as_ref().unwrap().base_model_id
            );
            Some(self.xlora_order.as_ref().unwrap().base_model_id.clone())
        };
        self
    }

    pub fn with_xlora(
        mut self,
        xlora_model_id: String,
        xlora_order: Ordering,
        no_kv_cache: bool,
        tgt_non_granular_index: Option<usize>,
    ) -> Self {
        self.kind = (AdapterKind::XLora, QuantizationKind::Gguf).into();

        self.with_adapter(
            xlora_model_id,
            xlora_order,
            no_kv_cache,
            tgt_non_granular_index,
        )
    }

    pub fn with_lora(mut self, lora_model_id: String, lora_order: Ordering) -> Self {
        self.kind = (AdapterKind::Lora, QuantizationKind::Gguf).into();

        self.with_adapter(lora_model_id, lora_order, false, None)
    }

    pub fn build(self) -> Box<dyn Loader> {
        Box::new(GGUFLoader {
            model_id: self.model_id,
            xlora_model_id: self.xlora_model_id,
            kind: self.kind,
            xlora_order: self.xlora_order,
            no_kv_cache: self.no_kv_cache,
            chat_template: self.chat_template,
            tgt_non_granular_index: self.tgt_non_granular_index,
            quantized_filenames: self.quantized_filenames,
            quantized_model_id: self.quantized_model_id,
            config: self.config,
            jinja_explicit: self.jinja_explicit,
            lora_adapter_ids: None,
        })
    }
}

impl GGUFLoader {
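    /// Construct a [`GGUFLoader`] directly from its parts; [`GGUFLoaderBuilder`] offers a
    /// more convenient interface for the common cases.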
    #[allow(clippy::too_many_arguments)]
    pub fn new(
        model_id: Option<String>,
        quantized_model_id: String,
        quantized_filenames: Vec<String>,
        xlora_model_id: Option<String>,
        kind: ModelKind,
        xlora_order: Option<Ordering>,
        no_kv_cache: bool,
        chat_template: Option<String>,
        tgt_non_granular_index: Option<usize>,
        config: GGUFSpecificConfig,
        jinja_explicit: Option<String>,
    ) -> Self {
        let model_id = if let Some(id) = model_id {
            Some(id)
        } else if let Some(xlora_order) = xlora_order.clone() {
            info!(
                "Using adapter base model ID: `{}`",
                xlora_order.base_model_id
            );
            Some(xlora_order.base_model_id.clone())
        } else {
            None
        };
        Self {
            model_id,
            quantized_model_id,
            quantized_filenames,
            xlora_model_id,
            xlora_order,
            no_kv_cache,
            chat_template,
            kind,
            tgt_non_granular_index,
            config,
            jinja_explicit,
            lora_adapter_ids: None,
        }
    }
}

impl Loader for GGUFLoader {
    #[allow(clippy::type_complexity, clippy::too_many_arguments)]
    fn load_model_from_hf(
        &self,
        revision: Option<String>,
        token_source: TokenSource,
        dtype: &dyn TryIntoDType,
        device: &Device,
        silent: bool,
        mapper: DeviceMapSetting,
        in_situ_quant: Option<IsqType>,
        paged_attn_config: Option<PagedAttentionConfig>,
    ) -> Result<Arc<Mutex<dyn Pipeline + Send + Sync>>> {
        let paths: anyhow::Result<Box<dyn ModelPaths>> = get_paths_gguf!(
            LocalModelPaths,
            &token_source,
            revision,
            self,
            self.quantized_model_id.clone(),
            self.quantized_filenames.clone(),
            silent
        );
        self.load_model_from_path(
            &paths?,
            dtype,
            device,
            silent,
            mapper,
            in_situ_quant,
            paged_attn_config,
        )
    }

    #[allow(clippy::type_complexity, clippy::too_many_arguments)]
    fn load_model_from_path(
        &self,
        paths: &Box<dyn ModelPaths>,
        dtype: &dyn TryIntoDType,
        device: &Device,
        silent: bool,
        mut mapper: DeviceMapSetting,
        in_situ_quant: Option<IsqType>,
        mut paged_attn_config: Option<PagedAttentionConfig>,
    ) -> Result<Arc<Mutex<dyn Pipeline + Send + Sync>>> {
        if in_situ_quant.is_some() {
            anyhow::bail!(
                "You are trying to in-situ quantize a GGUF model. This will not do anything."
            );
        }

        info!("Prompt chunk size is {ATTENTION_CHUNK_SIZE}.");

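        // Open every GGUF shard file and parse the combined metadata and tensor index.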
        let mut readers = Vec::new();
        for filename in paths.get_weight_filenames() {
            readers.push(std::fs::File::open(filename)?);
        }
        let mut readers = readers.iter_mut().collect::<Vec<_>>();

        let model = Content::from_readers(&mut readers)?;
        if !silent {
            model.print_metadata()?;
        }
        let arch = model.arch();

        // If auto, convert to Map
        let num_layers = model.get_metadata()[&format!("{arch}.block_count")].to_u32()? as usize;
        if let DeviceMapSetting::Auto(params) = mapper.clone() {
            let devices = device_map::get_all_similar_devices(device)?;
            // Initial dtype
            let dtype = dtype.try_into_dtype(&devices.iter().collect::<Vec<_>>())?;

            let model = GgufDeviceMapLoaderInner {
                model: &model,
                arch,
            };

            let layer_sizes_in_bytes =
                model.layer_sizes_in_bytes("this is a dummy config!", dtype, 1, None)?;
            let non_mapped_size_in_bytes =
                model.non_mapped_size_in_bytes("this is a dummy config!", dtype, 1, None)?;
            let total_model_size_in_bytes =
                layer_sizes_in_bytes.iter().sum::<usize>() + non_mapped_size_in_bytes;

            let new = model.get_device_layers(
                "this is a dummy config!",
                num_layers,
                layer_sizes_in_bytes,
                non_mapped_size_in_bytes,
                total_model_size_in_bytes,
                &devices,
                dtype,
                &params,
                paged_attn_config.as_ref(),
            )?;
            mapper = DeviceMapSetting::Map(new);
        }

        #[cfg(feature = "cuda")]
        if let Device::Cuda(dev) = &device {
            unsafe { dev.disable_event_tracking() };
        }

        let pipeline_mapper =
            mapper.into_mapper(num_layers, device, self.config.topology.as_ref())?;
        let mapper = mapper.into_mapper(num_layers, device, self.config.topology.as_ref())?;
        let mut layer_devices = Vec::new();
        for layer in 0..num_layers {
            let device = mapper.device_for(layer, false).cloned();
            layer_devices.push(device);
        }

        // TODO: PagedAttention is not supported with CPU for now.
        // This check is not really necessary because `get_device_layers` should prevent it.
        let mapping_uses_cpu = mapper.get_unique_devices().iter().any(Device::is_cpu);
        if mapping_uses_cpu {
            warn!("Device mapping contains a mix of GPU and CPU. There is no CPU support for PagedAttention, disabling PagedAttention.");
            paged_attn_config = None;
        }

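        // Prefer an externally supplied tokenizer file; otherwise rebuild a tokenizer
        // from the metadata embedded in the GGUF file itself.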
        let GgufTokenizerConversion {
            tokenizer,
            bos,
            eos,
            unk,
        } = if paths.get_tokenizer_filename().to_string_lossy().is_empty() {
            convert_gguf_to_hf_tokenizer(&model)?
        } else {
            GgufTokenizerConversion {
                tokenizer: get_tokenizer(paths.get_tokenizer_filename(), None)?,
                bos: None,
                eos: None,
                unk: None,
            }
        };

        // Only load gguf chat template if there is nothing else
        let gguf_chat_template =
            if paths.get_template_filename().is_none() && self.chat_template.is_none() {
                get_gguf_chat_template(&model)?
            } else {
                None
            };

        let has_adapter = self.kind.is_adapted();
        let is_xlora = self.kind.is_adapted_and(|a| a.is_x_lora());

        let paged_attn_config = if matches!(self.kind, ModelKind::GgufAdapter { .. }) {
            warn!("Adapter models do not currently support PagedAttention, running without");
            None
        } else {
            paged_attn_config
        };

        let model_config_metadata: ContentConfig = (&model).into();
        let internal_dtype = mapper.get_min_dtype(dtype)?;

        let model_config = {
            // Base config (quantization only):
            let quant = ModelConfig::ParamsGGUF(
                model,
                (device, mapper).into(),
                if paged_attn_config.is_some() {
                    AttentionImplementation::PagedAttention
                } else {
                    AttentionImplementation::Eager
                },
                internal_dtype,
            );

            // With optional adapter config:
            let mut adapter = None;
            if has_adapter {
                adapter.replace(ModelConfig::Adapter::try_new(
                    paths, device, silent, is_xlora,
                )?);
            }

            ModelConfig::ModelParams::new(quant, adapter)
        };

        // Config into model:
        let model = match self.kind {
            ModelKind::GgufQuantized { .. } => match arch {
                GGUFArchitecture::Llama => Model::Llama(QLlama::try_from(model_config)?),
                GGUFArchitecture::Phi2 => Model::Phi2(QPhi::try_from(model_config)?),
                GGUFArchitecture::Phi3 => Model::Phi3(QPhi3::try_from(model_config)?),
                GGUFArchitecture::Starcoder2 => {
                    Model::Starcoder2(QStarcoder2::try_from(model_config)?)
                }
                GGUFArchitecture::Qwen2 => Model::Qwen(QQwen::try_from(model_config)?),
                GGUFArchitecture::Qwen3 => Model::Qwen3(QQwen3::try_from(model_config)?),
                GGUFArchitecture::Qwen3MoE => Model::Qwen3MoE(QQwen3MoE::try_from(model_config)?),
                a => bail!("Unsupported architecture `{a:?}` for GGUF"),
            },
            ModelKind::GgufAdapter { adapter, .. } => match arch {
                GGUFArchitecture::Llama => Model::XLoraLlama(XLoraQLlama::try_from(model_config)?),
                GGUFArchitecture::Phi3 => Model::XLoraPhi3(XLoraQPhi3::try_from(model_config)?),
                a => bail!(
                    "Unsupported architecture `{a:?}` for GGUF {kind}",
                    kind = adapter.pretty_name()
                ),
            },
            _ => unreachable!(),
        };

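        // If PagedAttention is enabled, size and allocate the KV-cache block engine.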
        let (cache_config, cache_engine) = if let Some(paged_attn_config) = paged_attn_config {
            let model_config: &dyn ModelConfigLike = &model_config_metadata;
            let cache_config = calculate_cache_config(
                paged_attn_config.mem_gpu,
                paged_attn_config.mem_cpu,
                paged_attn_config.block_size,
                internal_dtype,
                paged_attn_config.cache_type,
                model_config,
                device,
                &layer_devices,
                silent,
            )?;
            let cache_engine = CacheEngine::new(
                model_config,
                &cache_config,
                internal_dtype,
                device,
                layer_devices,
            )?;
            (Some(cache_config), Some(cache_engine))
        } else {
            (None, None)
        };

        let gen_conf: Option<GenerationConfig> = paths
            .get_gen_conf_filename()
            .map(|f| serde_json::from_str(&fs::read_to_string(f).unwrap()).unwrap());
        let chat_template_explicit = paths
            .get_chat_template_explicit()
            .as_ref()
            .map(|x| x.to_string_lossy().to_string());
        let mut chat_template = get_chat_template(
            paths,
            self.jinja_explicit.as_ref(),
            chat_template_explicit.as_ref(),
            self.chat_template.as_ref(),
            gguf_chat_template,
        );

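        // Gather per-architecture properties (context length, layer count) used for the metadata.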
        let max_seq_len = match model {
            Model::Llama(ref l) => l.max_seq_len,
            Model::Phi2(ref p) => p.max_seq_len,
            Model::XLoraLlama(ref xl) => xl.max_seq_len,
            Model::Phi3(ref p) => p.max_seq_len,
            Model::XLoraPhi3(ref p) => p.max_seq_len,
            Model::Starcoder2(ref p) => p.max_seq_len,
            Model::Qwen(ref p) => p.max_seq_len,
            Model::Qwen3(ref p) => p.max_seq_len,
            Model::Qwen3MoE(ref p) => p.max_seq_len,
        };
        let llg_factory = build_llg_factory(tokenizer.clone())?;
        let num_hidden_layers = match model {
            Model::Llama(ref model) => model.cache.normal().0.len(),
            Model::Phi2(ref model) => model.cache.normal().0.len(),
            Model::XLoraLlama(ref model) => model.cache.full().lock().len(),
            Model::Phi3(ref model) => model.cache.normal().0.len(),
            Model::XLoraPhi3(ref model) => model.cache.full().lock().len(),
            Model::Starcoder2(ref model) => model.cache.normal().0.len(),
            Model::Qwen(ref model) => model.cache.normal().0.len(),
            Model::Qwen3(ref model) => model.cache.normal().0.len(),
            Model::Qwen3MoE(ref model) => model.cache.normal().0.len(),
        };

        if chat_template.bos_token.is_none() {
            if let Some(v) = bos {
                chat_template.bos_token = Some(BeginEndUnkPadTok(Either::Left(v)));
            }
        }
        if chat_template.eos_token.is_none() {
            if let Some(v) = eos {
                chat_template.eos_token = Some(BeginEndUnkPadTok(Either::Left(v)));
            }
        }
        if chat_template.unk_token.is_none() {
            if let Some(v) = unk {
                chat_template.unk_token = Some(BeginEndUnkPadTok(Either::Left(v)));
            }
        }

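        // Resolve the final EOS token set and assemble the pipeline object.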
        let eos = calculate_eos_tokens(&chat_template, gen_conf, &tokenizer);
        Ok(Arc::new(Mutex::new(GGUFPipeline {
            model,
            tokenizer: tokenizer.into(),
            no_kv_cache: self.no_kv_cache,
            chat_template: Arc::new(chat_template),
            model_id: self
                .model_id
                .clone()
                .unwrap_or(self.quantized_model_id.clone()),
            non_granular_state: self.tgt_non_granular_index.map(|tgt_non_granular_index| {
                NonGranularState {
                    non_granular_index: Arc::new(Mutex::new(0)),
                    tgt_non_granular_index,
                }
            }),
            metadata: Arc::new(GeneralMetadata {
                max_seq_len,
                llg_factory: Some(llg_factory),
                no_kv_cache: self.no_kv_cache,
                no_prefix_cache: false,
                num_hidden_layers,
                eos_tok: eos,
                kind: self.kind.clone(),
                is_xlora,
                activation_dtype: internal_dtype,
                sliding_window: None,
                cache_config,
                cache_engine,
                model_metadata: Some(Arc::new(model_config_metadata)),
                modalities: Modalities {
                    input: vec![SupportedModality::Text],
                    output: vec![SupportedModality::Text],
                },
            }),
            mapper: pipeline_mapper,
        })))
    }

    fn get_id(&self) -> String {
        self.xlora_model_id
            .as_deref()
            .unwrap_or(self.model_id.as_ref().unwrap_or(&self.quantized_model_id))
            .to_string()
    }

    fn get_kind(&self) -> ModelKind {
        self.kind.clone()
    }
}

impl PreProcessingMixin for GGUFPipeline {
    fn get_chat_template(&self) -> Option<Arc<ChatTemplate>> {
        Some(self.chat_template.clone())
    }
    fn get_input_processor_config(&self) -> Option<Arc<dyn Any>> {
        None
    }
}

impl IsqPipelineMixin for GGUFPipeline {
    fn re_isq_model(&mut self, _dtype: IsqType) -> Result<()> {
        anyhow::bail!(
            "You are trying to in-situ requantize a GGUF model. This will not do anything."
        )
    }
}

impl CacheManagerMixin for GGUFPipeline {
    fn clone_in_cache(&self, seqs: &mut [&mut Sequence]) {
        if matches!(self.cache(), EitherCache::Full(_)) {
            FullCacheManager.clone_in_cache(self, seqs, false)
        } else {
            NormalCacheManager.clone_in_cache(self, seqs, false)
        }
    }
    fn clone_out_cache(&self, seqs: &mut [&mut Sequence]) {
        if matches!(self.cache(), EitherCache::Full(_)) {
            FullCacheManager.clone_out_cache(self, seqs, false)
        } else {
            NormalCacheManager.clone_out_cache(self, seqs, false)
        }
    }
    fn set_none_cache(
        &self,
        seqs: &mut [&mut Sequence],
        reset_non_granular: bool,
        modify_draft_cache: bool,
        load_preallocated_cache: bool,
    ) {
        if matches!(self.cache(), EitherCache::Full(_)) {
            FullCacheManager.set_none_cache(self, seqs, modify_draft_cache, false);
        } else {
            NormalCacheManager.set_none_cache(
                self,
                seqs,
                modify_draft_cache,
                load_preallocated_cache,
            );
        }
        if reset_non_granular {
            self.reset_non_granular_state()
        }
    }
    fn cache(&self) -> &EitherCache {
        match self.model {
            Model::Llama(ref model) => &model.cache,
            Model::Phi2(ref model) => &model.cache,
            Model::XLoraLlama(ref model) => &model.cache,
            Model::Phi3(ref model) => &model.cache,
            Model::XLoraPhi3(ref model) => &model.cache,
            Model::Starcoder2(ref model) => &model.cache,
            Model::Qwen(ref model) => &model.cache,
            Model::Qwen3(ref model) => &model.cache,
            Model::Qwen3MoE(ref model) => &model.cache,
        }
    }
}

impl MetadataMixin for GGUFPipeline {
    fn device(&self) -> Device {
        match self.model {
            Model::Llama(ref model) => model.device.clone(),
            Model::Phi2(ref model) => model.device.clone(),
            Model::XLoraLlama(ref model) => model.device.clone(),
            Model::Phi3(ref model) => model.device.clone(),
            Model::XLoraPhi3(ref model) => model.device.clone(),
            Model::Starcoder2(ref model) => model.device.clone(),
            Model::Qwen(ref model) => model.device.clone(),
            Model::Qwen3(ref model) => model.device.clone(),
            Model::Qwen3MoE(ref model) => model.device.clone(),
        }
    }
    fn tokenizer(&self) -> Option<Arc<Tokenizer>> {
        Some(self.tokenizer.clone())
    }
    fn name(&self) -> String {
        self.model_id.clone()
    }
    fn reset_non_granular_state(&self) {
        if let Some(s) = self.non_granular_state.as_ref() {
            *self.cache().full().get_scalings_cache() = None;
            *get_mut_arcmutex!(s.non_granular_index) = 0;
        }
    }
    fn get_metadata(&self) -> Arc<GeneralMetadata> {
        self.metadata.clone()
    }
    fn device_mapper(&self) -> Option<&dyn DeviceMapper> {
        Some(&*self.mapper)
    }
}

#[async_trait::async_trait]
impl Pipeline for GGUFPipeline {
    fn forward_inputs(
        &mut self,
        inputs: Box<dyn Any>,
        return_raw_logits: bool,
    ) -> Result<ForwardInputsResult, candle_core::Error> {
        let ModelInputs {
            input_ids,
            input_ids_full,
            seqlen_offsets,
            seqlen_offsets_full,
            context_lens,
            position_ids: _, // NOTE(EricLBuehler): ignore, it is for phi3
            paged_attn_meta,
            flash_meta,
            flash_meta_full,
        } = *inputs.downcast().expect("Downcast failed.");
        let metadata = self.get_metadata();
        let paged_attn_meta = match (&metadata.cache_engine, &paged_attn_meta) {
            (Some(engine), Some(meta)) => Some((engine.get_kv_cache().clone(), meta)),
            (Some(_), None) => {
                // This can happen if Rust-side user code is wrong
                candle_core::bail!("Forward step expected a PagedAttention input metadata. This was not provided, please ensure that the scheduler config is correctly configured for PagedAttention.")
            }
            (None, Some(_)) => {
                // This should never happen but we handle it anyway
                candle_core::bail!("Forward step got a PagedAttention input metadata but there is no cache engine. Please raise an issue.")
            }
            (None, None) => None,
        };
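        // Dispatch the forward pass to the concrete model architecture.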
        let logits = match self.model {
            Model::Llama(ref model) => {
                model.forward(&input_ids, &seqlen_offsets, context_lens, paged_attn_meta)?
            }
            Model::Phi2(ref model) => {
                model.forward(&input_ids, &seqlen_offsets, context_lens, paged_attn_meta)?
            }
            Model::XLoraLlama(ref model) => model.forward(
                &input_ids,
                input_ids_full.as_ref().unwrap_or(&input_ids),
                &seqlen_offsets,
                seqlen_offsets_full.as_ref().unwrap_or(&seqlen_offsets),
                self.no_kv_cache,
                &self.non_granular_state,
                context_lens,
                &flash_meta,
                flash_meta_full.as_ref().unwrap_or(&flash_meta),
            )?,
            Model::Phi3(ref model) => {
                model.forward(&input_ids, &seqlen_offsets, paged_attn_meta)?
            }
            Model::XLoraPhi3(ref model) => model.forward(
                &input_ids,
                input_ids_full.as_ref().unwrap_or(&input_ids),
                &seqlen_offsets,
                seqlen_offsets_full.as_ref().unwrap_or(&seqlen_offsets),
                self.no_kv_cache,
                &self.non_granular_state,
                context_lens,
                &flash_meta,
                flash_meta_full.as_ref().unwrap_or(&flash_meta),
            )?,
            Model::Starcoder2(ref model) => {
                model.forward(&input_ids, &seqlen_offsets, paged_attn_meta)?
            }
            Model::Qwen(ref model) => {
                model.forward(&input_ids, &seqlen_offsets, context_lens, paged_attn_meta)?
            }
            Model::Qwen3(ref model) => {
                model.forward(&input_ids, &seqlen_offsets, context_lens, paged_attn_meta)?
            }
            Model::Qwen3MoE(ref model) => {
                model.forward(&input_ids, &seqlen_offsets, context_lens, paged_attn_meta)?
            }
        };
        if return_raw_logits {
            Ok(ForwardInputsResult::RawLogits { logits })
        } else {
            Ok(ForwardInputsResult::CausalGeneration { logits })
        }
    }
    async fn sample_causal_gen(
        &self,
        seqs: &mut [&mut Sequence],
        logits: Vec<Tensor>,
        prefix_cacher: &mut PrefixCacheManagerV2,
        disable_eos_stop: bool,
        rng: Arc<std::sync::Mutex<Isaac64Rng>>,
    ) -> Result<(), candle_core::Error> {
        sample_and_add_toks(self, seqs, logits, prefix_cacher, disable_eos_stop, rng).await
    }
    fn category(&self) -> ModelCategory {
        ModelCategory::Text
    }
}

// TODO
impl AnyMoePipelineMixin for GGUFPipeline {}