mistralrs_core/pipeline/gguf.rs

use super::llg::build_llg_factory;
use super::{
    get_model_paths, get_xlora_paths, text_models_inputs_processor::ModelInputs, AdapterKind,
    CacheManager, GeneralMetadata, Loader, ModelKind, ModelPaths, PrettyName, QuantizationKind,
    TokenSource,
};
use super::{
    AnyMoePipelineMixin, CacheManagerMixin, EitherCache, ForwardInputsResult, IsqPipelineMixin,
    MetadataMixin, ModelCategory, PreProcessingMixin,
};
use crate::device_map::{self, DeviceMapper};
use crate::gguf::{
    get_gguf_chat_template, {convert_gguf_to_hf_tokenizer, GgufTokenizerConversion},
};
use crate::gguf::{Content, GGUFArchitecture};
use crate::kv_cache::{FullCacheManager, NormalCacheManager};
use crate::lora::Ordering;
use crate::paged_attention::{
    calculate_cache_config, AttentionImplementation, CacheEngine, ModelConfigLike,
};
use crate::pipeline::chat_template::{calculate_eos_tokens, BeginEndUnkPadTok, GenerationConfig};
use crate::pipeline::inputs_processor::DEFAULT_PROMPT_CHUNK_SIZE;
use crate::pipeline::loaders::DeviceMappedModelLoader;
use crate::pipeline::sampling::sample_and_add_toks;
use crate::pipeline::ChatTemplate;
use crate::pipeline::{get_chat_template, Modalities, SupportedModality};
use crate::prefix_cacher::PrefixCacheManagerV2;
use crate::sequence::Sequence;
use crate::utils::gguf_metadata::{ContentConfig, GgufDeviceMapLoaderInner};
use crate::utils::model_config as ModelConfig;
use crate::utils::tokenizer::get_tokenizer;
use crate::xlora_models::NonGranularState;
use crate::{
    get_mut_arcmutex, get_paths_gguf, DeviceMapSetting, LocalModelPaths, PagedAttentionConfig,
    Pipeline, Topology, TryIntoDType,
};
use crate::{
    models::quantized_llama::ModelWeights as QLlama,
    models::quantized_phi2::ModelWeights as QPhi,
    models::quantized_phi3::ModelWeights as QPhi3,
    models::quantized_qwen::ModelWeights as QQwen,
    models::quantized_qwen3::ModelWeights as QQwen3,
    models::quantized_starcoder2::ModelWeights as QStarcoder2,
    utils::tokens::get_token,
    xlora_models::{XLoraQLlama, XLoraQPhi3},
};
use anyhow::{bail, Result};
use candle_core::{Device, Tensor};
use either::Either;
use hf_hub::{api::sync::ApiBuilder, Repo, RepoType};
use mistralrs_quant::IsqType;
use rand_isaac::Isaac64Rng;
use std::any::Any;
use std::fs;
use std::num::{NonZero, NonZeroUsize};
use std::path::PathBuf;
use std::str::FromStr;
use std::sync::Arc;
use tokenizers::Tokenizer;
use tokio::sync::Mutex;
use tracing::{info, warn};

enum Model {
    Llama(QLlama),
    Phi2(QPhi),
    XLoraLlama(XLoraQLlama),
    XLoraPhi3(XLoraQPhi3),
    Phi3(QPhi3),
    Starcoder2(QStarcoder2),
    Qwen(QQwen),
    Qwen3(QQwen3),
}

pub struct GGUFPipeline {
    model: Model,
    tokenizer: Arc<Tokenizer>,
    no_kv_cache: bool,
    chat_template: Arc<ChatTemplate>,
    model_id: String,
    non_granular_state: Option<NonGranularState>,
    metadata: Arc<GeneralMetadata>,
    mapper: Box<dyn DeviceMapper + Send + Sync>,
}

/// Loader for a GGUF model.
pub struct GGUFLoader {
    model_id: Option<String>,
    quantized_model_id: String,
    quantized_filenames: Vec<String>,
    xlora_model_id: Option<String>,
    xlora_order: Option<Ordering>,
    no_kv_cache: bool,
    chat_template: Option<String>,
    kind: ModelKind,
    tgt_non_granular_index: Option<usize>,
    config: GGUFSpecificConfig,
    jinja_explicit: Option<String>,
    lora_adapter_ids: Option<Vec<String>>,
}

#[derive(Clone, Default)]
/// Config for a GGUF loader.
pub struct GGUFSpecificConfig {
    pub prompt_chunksize: Option<NonZeroUsize>,
    pub topology: Option<Topology>,
}

#[derive(Default)]
/// A builder for a GGUF loader.
pub struct GGUFLoaderBuilder {
    model_id: Option<String>,
    quantized_model_id: String,
    quantized_filenames: Vec<String>,
    xlora_model_id: Option<String>,
    kind: ModelKind,
    xlora_order: Option<Ordering>,
    no_kv_cache: bool,
    chat_template: Option<String>,
    tgt_non_granular_index: Option<usize>,
    config: GGUFSpecificConfig,
    jinja_explicit: Option<String>,
}

impl GGUFLoaderBuilder {
    /// Create a loader builder for a GGUF model. `tok_model_id` is the model ID where a
    /// `tokenizer_config.json` file can be found. If `chat_template` is specified, it is treated
    /// as a local path and used instead of remote files, avoiding all remote accesses.
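    ///
    /// A minimal, hypothetical usage sketch (the repository name and filename below are
    /// placeholders, not values taken from this crate):
    ///
    /// ```ignore
    /// let loader = GGUFLoaderBuilder::new(
    ///     None,                                       // chat_template
    ///     None,                                       // tok_model_id
    ///     "example-org/Example-7B-GGUF".to_string(),  // quantized_model_id
    ///     vec!["example-7b.Q4_K_M.gguf".to_string()], // quantized_filenames
    ///     GGUFSpecificConfig::default(),              // config
    ///     false,                                      // no_kv_cache
    ///     None,                                       // jinja_explicit
    /// )
    /// .build();
    /// ```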
    pub fn new(
        chat_template: Option<String>,
        tok_model_id: Option<String>,
        quantized_model_id: String,
        quantized_filenames: Vec<String>,
        config: GGUFSpecificConfig,
        no_kv_cache: bool,
        jinja_explicit: Option<String>,
    ) -> Self {
        let kind = ModelKind::GgufQuantized {
            quant: QuantizationKind::Gguf,
        };

        Self {
            chat_template,
            model_id: tok_model_id,
            kind,
            quantized_filenames,
            quantized_model_id,
            config,
            jinja_explicit,
            no_kv_cache,
            ..Default::default()
        }
    }

    fn with_adapter(
        mut self,
        xlora_model_id: String,
        xlora_order: Ordering,
        no_kv_cache: bool,
        tgt_non_granular_index: Option<usize>,
    ) -> Self {
        self.xlora_model_id = Some(xlora_model_id);
        self.xlora_order = Some(xlora_order);
        self.no_kv_cache = no_kv_cache;
        self.tgt_non_granular_index = tgt_non_granular_index;
        self.model_id = if let Some(id) = self.model_id {
            Some(id)
        } else {
            info!(
                "Using adapter base model ID: `{}`",
                self.xlora_order.as_ref().unwrap().base_model_id
            );
            Some(self.xlora_order.as_ref().unwrap().base_model_id.clone())
        };
        self
    }

    pub fn with_xlora(
        mut self,
        xlora_model_id: String,
        xlora_order: Ordering,
        no_kv_cache: bool,
        tgt_non_granular_index: Option<usize>,
    ) -> Self {
        self.kind = (AdapterKind::XLora, QuantizationKind::Gguf).into();

        self.with_adapter(
            xlora_model_id,
            xlora_order,
            no_kv_cache,
            tgt_non_granular_index,
        )
    }

    pub fn with_lora(mut self, lora_model_id: String, lora_order: Ordering) -> Self {
        self.kind = (AdapterKind::Lora, QuantizationKind::Gguf).into();

        self.with_adapter(lora_model_id, lora_order, false, None)
    }

    pub fn build(self) -> Box<dyn Loader> {
        Box::new(GGUFLoader {
            model_id: self.model_id,
            xlora_model_id: self.xlora_model_id,
            kind: self.kind,
            xlora_order: self.xlora_order,
            no_kv_cache: self.no_kv_cache,
            chat_template: self.chat_template,
            tgt_non_granular_index: self.tgt_non_granular_index,
            quantized_filenames: self.quantized_filenames,
            quantized_model_id: self.quantized_model_id,
            config: self.config,
            jinja_explicit: self.jinja_explicit,
            lora_adapter_ids: None,
        })
    }
}

impl GGUFLoader {
    #[allow(clippy::too_many_arguments)]
    pub fn new(
        model_id: Option<String>,
        quantized_model_id: String,
        quantized_filenames: Vec<String>,
        xlora_model_id: Option<String>,
        kind: ModelKind,
        xlora_order: Option<Ordering>,
        no_kv_cache: bool,
        chat_template: Option<String>,
        tgt_non_granular_index: Option<usize>,
        config: GGUFSpecificConfig,
        jinja_explicit: Option<String>,
    ) -> Self {
        let model_id = if let Some(id) = model_id {
            Some(id)
        } else if let Some(xlora_order) = xlora_order.clone() {
            info!(
                "Using adapter base model ID: `{}`",
                xlora_order.base_model_id
            );
            Some(xlora_order.base_model_id.clone())
        } else {
            None
        };
        Self {
            model_id,
            quantized_model_id,
            quantized_filenames,
            xlora_model_id,
            xlora_order,
            no_kv_cache,
            chat_template,
            kind,
            tgt_non_granular_index,
            config,
            jinja_explicit,
            lora_adapter_ids: None,
        }
    }
}
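
// A hedged usage sketch (illustrative only; the setup of `token_source`, `dtype`, `device`,
// and `mapper` is omitted and the values shown are placeholders), based on the
// `Loader::load_model_from_hf` signature implemented below:
//
// let pipeline = loader.load_model_from_hf(
//     None,          // revision
//     token_source,  // TokenSource
//     &dtype,        // &dyn TryIntoDType
//     &device,       // &candle_core::Device
//     false,         // silent
//     mapper,        // DeviceMapSetting
//     None,          // in_situ_quant: rejected for GGUF models
//     None,          // paged_attn_config
// )?;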

impl Loader for GGUFLoader {
    #[allow(clippy::type_complexity, clippy::too_many_arguments)]
    fn load_model_from_hf(
        &self,
        revision: Option<String>,
        token_source: TokenSource,
        dtype: &dyn TryIntoDType,
        device: &Device,
        silent: bool,
        mapper: DeviceMapSetting,
        in_situ_quant: Option<IsqType>,
        paged_attn_config: Option<PagedAttentionConfig>,
    ) -> Result<Arc<Mutex<dyn Pipeline + Send + Sync>>> {
        let paths: anyhow::Result<Box<dyn ModelPaths>> = get_paths_gguf!(
            LocalModelPaths,
            &token_source,
            revision,
            self,
            self.quantized_model_id.clone(),
            self.quantized_filenames.clone(),
            silent
        );
        self.load_model_from_path(
            &paths?,
            dtype,
            device,
            silent,
            mapper,
            in_situ_quant,
            paged_attn_config,
        )
    }

    #[allow(clippy::type_complexity, clippy::too_many_arguments)]
    fn load_model_from_path(
        &self,
        paths: &Box<dyn ModelPaths>,
        dtype: &dyn TryIntoDType,
        device: &Device,
        silent: bool,
        mut mapper: DeviceMapSetting,
        in_situ_quant: Option<IsqType>,
        mut paged_attn_config: Option<PagedAttentionConfig>,
    ) -> Result<Arc<Mutex<dyn Pipeline + Send + Sync>>> {
        if in_situ_quant.is_some() {
            anyhow::bail!(
                "You are trying to in-situ quantize a GGUF model. This will not do anything."
            );
        }

        // Apply the default prompt chunk size if none was configured
        let prompt_chunksize = self
            .config
            .prompt_chunksize
            .unwrap_or(DEFAULT_PROMPT_CHUNK_SIZE.try_into().unwrap())
            .get();

        info!("Prompt chunk size is {prompt_chunksize}.");

        let mut readers = Vec::new();
        for filename in paths.get_weight_filenames() {
            readers.push(std::fs::File::open(filename)?);
        }
        let mut readers = readers.iter_mut().collect::<Vec<_>>();

        let model = Content::from_readers(&mut readers)?;
        if !silent {
            model.print_metadata()?;
        }
        let arch = model.arch();

        // If auto, convert to Map
        let num_layers = model.get_metadata()[&format!("{arch}.block_count")].to_u32()? as usize;
        if let DeviceMapSetting::Auto(params) = mapper.clone() {
            let devices = device_map::get_all_similar_devices(device)?;
            // Initial dtype
            let dtype = dtype.try_into_dtype(&devices.iter().collect::<Vec<_>>())?;

            let model = GgufDeviceMapLoaderInner {
                model: &model,
                arch,
            };

            let layer_sizes_in_bytes =
                model.layer_sizes_in_bytes("this is a dummy config!", dtype, 1)?;
            let non_mapped_size_in_bytes =
                model.non_mapped_size_in_bytes("this is a dummy config!", dtype, 1)?;
            let total_model_size_in_bytes =
                layer_sizes_in_bytes.iter().sum::<usize>() + non_mapped_size_in_bytes;

            let new = model.get_device_layers(
                "this is a dummy config!",
                num_layers,
                layer_sizes_in_bytes,
                non_mapped_size_in_bytes,
                total_model_size_in_bytes,
                &devices,
                dtype,
                &params,
                prompt_chunksize,
                paged_attn_config.as_ref(),
            )?;
            mapper = DeviceMapSetting::Map(new);
        }

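        // Two mappers are materialized from the same setting: `mapper` is consumed while
        // loading the weights below, while `pipeline_mapper` is kept on the pipeline.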
        let pipeline_mapper =
            mapper.into_mapper(num_layers, device, self.config.topology.as_ref())?;
        let mapper = mapper.into_mapper(num_layers, device, self.config.topology.as_ref())?;
        let mut layer_devices = Vec::new();
        for layer in 0..num_layers {
            let device = mapper.device_for(layer, false).cloned();
            layer_devices.push(device);
        }

        // TODO: PagedAttention is not supported with CPU for now.
        // This check is not really necessary because `get_device_layers` should prevent it.
        let mapping_uses_cpu = mapper.get_unique_devices().iter().any(Device::is_cpu);
        if mapping_uses_cpu {
            warn!("Device mapping contains a mix of GPU and CPU. There is no CPU support for PagedAttention, disabling PagedAttention.");
            paged_attn_config = None;
        }

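        // Prefer an external tokenizer file when one was resolved; otherwise rebuild a
        // Hugging Face tokenizer (plus any BOS/EOS/UNK tokens) from the GGUF metadata.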
        let GgufTokenizerConversion {
            tokenizer,
            bos,
            eos,
            unk,
        } = if paths.get_tokenizer_filename().to_string_lossy().is_empty() {
            convert_gguf_to_hf_tokenizer(&model)?
        } else {
            GgufTokenizerConversion {
                tokenizer: get_tokenizer(paths.get_tokenizer_filename(), None)?,
                bos: None,
                eos: None,
                unk: None,
            }
        };

        // Only load gguf chat template if there is nothing else
        let gguf_chat_template =
            if paths.get_template_filename().is_none() && self.chat_template.is_none() {
                get_gguf_chat_template(&model)?
            } else {
                None
            };

        let has_adapter = self.kind.is_adapted();
        let is_xlora = self.kind.is_adapted_and(|a| a.is_x_lora());

        let paged_attn_config = if matches!(self.kind, ModelKind::GgufAdapter { .. }) {
            warn!("Adapter models do not currently support PagedAttention, running without");
            None
        } else {
            paged_attn_config
        };

        let model_config_metadata: ContentConfig = (&model).into();
        let internal_dtype = mapper.get_min_dtype(dtype)?;

        let model_config = {
            // Base config (quantization only):
            let quant = ModelConfig::ParamsGGUF(
                model,
                (device, mapper).into(),
                if paged_attn_config.is_some() {
                    AttentionImplementation::PagedAttention
                } else {
                    AttentionImplementation::Eager
                },
                internal_dtype,
            );

            // With optional adapter config:
            let mut adapter = None;
            if has_adapter {
                adapter.replace(ModelConfig::Adapter::try_new(
                    paths, device, silent, is_xlora,
                )?);
            }

            ModelConfig::ModelParams::new(quant, adapter)
        };

        // Config into model:
        let model = match self.kind {
            ModelKind::GgufQuantized { .. } => match arch {
                GGUFArchitecture::Llama => Model::Llama(QLlama::try_from(model_config)?),
                GGUFArchitecture::Phi2 => Model::Phi2(QPhi::try_from(model_config)?),
                GGUFArchitecture::Phi3 => Model::Phi3(QPhi3::try_from(model_config)?),
                GGUFArchitecture::Starcoder2 => {
                    Model::Starcoder2(QStarcoder2::try_from(model_config)?)
                }
                GGUFArchitecture::Qwen2 => Model::Qwen(QQwen::try_from(model_config)?),
                GGUFArchitecture::Qwen3 => Model::Qwen3(QQwen3::try_from(model_config)?),
                a => bail!("Unsupported architecture `{a:?}` for GGUF"),
            },
            ModelKind::GgufAdapter { adapter, .. } => match arch {
                GGUFArchitecture::Llama => Model::XLoraLlama(XLoraQLlama::try_from(model_config)?),
                GGUFArchitecture::Phi3 => Model::XLoraPhi3(XLoraQPhi3::try_from(model_config)?),
                a => bail!(
                    "Unsupported architecture `{a:?}` for GGUF {kind}",
                    kind = adapter.pretty_name()
                ),
            },
            _ => unreachable!(),
        };

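        // When PagedAttention is enabled, size the KV-cache block allocation from the
        // configured memory budgets and build a cache engine over the mapped layer devices.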
        let (cache_config, cache_engine) = if let Some(paged_attn_config) = paged_attn_config {
            let model_config: &dyn ModelConfigLike = &model_config_metadata;
            let cache_config = calculate_cache_config(
                paged_attn_config.mem_gpu,
                paged_attn_config.mem_cpu,
                paged_attn_config.block_size,
                internal_dtype,
                paged_attn_config.cache_type,
                model_config,
                device,
                &layer_devices,
                silent,
            )?;
            let cache_engine = CacheEngine::new(
                model_config,
                &cache_config,
                internal_dtype,
                device,
                layer_devices,
            )?;
            (Some(cache_config), Some(cache_engine))
        } else {
            (None, None)
        };

        let gen_conf: Option<GenerationConfig> = paths
            .get_gen_conf_filename()
            .map(|f| serde_json::from_str(&fs::read_to_string(f).unwrap()).unwrap());
        let chat_template_explicit = paths
            .get_chat_template_explicit()
            .as_ref()
            .map(|x| x.to_string_lossy().to_string());
        let mut chat_template = get_chat_template(
            paths,
            self.jinja_explicit.as_ref(),
            chat_template_explicit.as_ref(),
            self.chat_template.as_ref(),
            gguf_chat_template,
        );

        let max_seq_len = match model {
            Model::Llama(ref l) => l.max_seq_len,
            Model::Phi2(ref p) => p.max_seq_len,
            Model::XLoraLlama(ref xl) => xl.max_seq_len,
            Model::Phi3(ref p) => p.max_seq_len,
            Model::XLoraPhi3(ref p) => p.max_seq_len,
            Model::Starcoder2(ref p) => p.max_seq_len,
            Model::Qwen(ref p) => p.max_seq_len,
            Model::Qwen3(ref p) => p.max_seq_len,
        };
        let llg_factory = build_llg_factory(tokenizer.clone())?;
        let num_hidden_layers = match model {
            Model::Llama(ref model) => model.cache.normal().0.len(),
            Model::Phi2(ref model) => model.cache.normal().0.len(),
            Model::XLoraLlama(ref model) => model.cache.full().lock().len(),
            Model::Phi3(ref model) => model.cache.normal().0.len(),
            Model::XLoraPhi3(ref model) => model.cache.full().lock().len(),
            Model::Starcoder2(ref model) => model.cache.normal().0.len(),
            Model::Qwen(ref model) => model.cache.normal().0.len(),
            Model::Qwen3(ref model) => model.cache.normal().0.len(),
        };

        if chat_template.bos_token.is_none() && bos.is_some() {
            chat_template.bos_token = Some(BeginEndUnkPadTok(Either::Left(bos.unwrap())));
        }
        if chat_template.eos_token.is_none() && eos.is_some() {
            chat_template.eos_token = Some(BeginEndUnkPadTok(Either::Left(eos.unwrap())));
        }
        if chat_template.unk_token.is_none() && unk.is_some() {
            chat_template.unk_token = Some(BeginEndUnkPadTok(Either::Left(unk.unwrap())));
        }

        let eos = calculate_eos_tokens(&chat_template, gen_conf, &tokenizer);
        Ok(Arc::new(Mutex::new(GGUFPipeline {
            model,
            tokenizer: tokenizer.into(),
            no_kv_cache: self.no_kv_cache,
            chat_template: Arc::new(chat_template),
            model_id: self
                .model_id
                .clone()
                .unwrap_or(self.quantized_model_id.clone()),
            non_granular_state: self.tgt_non_granular_index.map(|tgt_non_granular_index| {
                NonGranularState {
                    non_granular_index: Arc::new(Mutex::new(0)),
                    tgt_non_granular_index,
                }
            }),
            metadata: Arc::new(GeneralMetadata {
                max_seq_len,
                llg_factory: Some(llg_factory),
                no_kv_cache: self.no_kv_cache,
                no_prefix_cache: false,
                num_hidden_layers,
                eos_tok: eos,
                kind: self.kind.clone(),
                is_xlora,
                activation_dtype: internal_dtype,
                sliding_window: None,
                cache_config,
                cache_engine,
                prompt_chunksize: Some(NonZero::new(prompt_chunksize).unwrap()),
                model_metadata: Some(Arc::new(model_config_metadata)),
                modalities: Modalities {
                    input: vec![SupportedModality::Text],
                    output: vec![SupportedModality::Text],
                },
            }),
            mapper: pipeline_mapper,
        })))
    }

    fn get_id(&self) -> String {
        self.xlora_model_id
            .as_deref()
            .unwrap_or(self.model_id.as_ref().unwrap_or(&self.quantized_model_id))
            .to_string()
    }

    fn get_kind(&self) -> ModelKind {
        self.kind.clone()
    }
}

impl PreProcessingMixin for GGUFPipeline {
    fn get_chat_template(&self) -> Option<Arc<ChatTemplate>> {
        Some(self.chat_template.clone())
    }
    fn get_input_processor_config(&self) -> Option<Arc<dyn Any>> {
        None
    }
}

impl IsqPipelineMixin for GGUFPipeline {
    fn re_isq_model(&mut self, _dtype: IsqType) -> Result<()> {
        anyhow::bail!(
            "You are trying to in-situ requantize a GGUF model. This will not do anything."
        )
    }
}

impl CacheManagerMixin for GGUFPipeline {
    fn clone_in_cache(&self, seqs: &mut [&mut Sequence]) {
        if matches!(self.cache(), EitherCache::Full(_)) {
            FullCacheManager.clone_in_cache(self, seqs, false)
        } else {
            NormalCacheManager.clone_in_cache(self, seqs, false)
        }
    }
    fn clone_out_cache(&self, seqs: &mut [&mut Sequence]) {
        if matches!(self.cache(), EitherCache::Full(_)) {
            FullCacheManager.clone_out_cache(self, seqs, false)
        } else {
            NormalCacheManager.clone_out_cache(self, seqs, false)
        }
    }
    fn set_none_cache(
        &self,
        seqs: &mut [&mut Sequence],
        reset_non_granular: bool,
        modify_draft_cache: bool,
        load_preallocated_cache: bool,
    ) {
        if matches!(self.cache(), EitherCache::Full(_)) {
            FullCacheManager.set_none_cache(self, seqs, modify_draft_cache, false);
        } else {
            NormalCacheManager.set_none_cache(
                self,
                seqs,
                modify_draft_cache,
                load_preallocated_cache,
            );
        }
        if reset_non_granular {
            self.reset_non_granular_state()
        }
    }
    fn cache(&self) -> &EitherCache {
        match self.model {
            Model::Llama(ref model) => &model.cache,
            Model::Phi2(ref model) => &model.cache,
            Model::XLoraLlama(ref model) => &model.cache,
            Model::Phi3(ref model) => &model.cache,
            Model::XLoraPhi3(ref model) => &model.cache,
            Model::Starcoder2(ref model) => &model.cache,
            Model::Qwen(ref model) => &model.cache,
            Model::Qwen3(ref model) => &model.cache,
        }
    }
}

impl MetadataMixin for GGUFPipeline {
    fn device(&self) -> Device {
        match self.model {
            Model::Llama(ref model) => model.device.clone(),
            Model::Phi2(ref model) => model.device.clone(),
            Model::XLoraLlama(ref model) => model.device.clone(),
            Model::Phi3(ref model) => model.device.clone(),
            Model::XLoraPhi3(ref model) => model.device.clone(),
            Model::Starcoder2(ref model) => model.device.clone(),
            Model::Qwen(ref model) => model.device.clone(),
            Model::Qwen3(ref model) => model.device.clone(),
        }
    }
    fn tokenizer(&self) -> Option<Arc<Tokenizer>> {
        Some(self.tokenizer.clone())
    }
    fn name(&self) -> String {
        self.model_id.clone()
    }
    fn reset_non_granular_state(&self) {
        if let Some(s) = self.non_granular_state.as_ref() {
            *self.cache().full().get_scalings_cache() = None;
            *get_mut_arcmutex!(s.non_granular_index) = 0;
        }
    }
    fn get_metadata(&self) -> Arc<GeneralMetadata> {
        self.metadata.clone()
    }
    fn device_mapper(&self) -> Option<&dyn DeviceMapper> {
        Some(&*self.mapper)
    }
}

#[async_trait::async_trait]
impl Pipeline for GGUFPipeline {
    fn forward_inputs(
        &mut self,
        inputs: Box<dyn Any>,
        return_raw_logits: bool,
    ) -> Result<ForwardInputsResult, candle_core::Error> {
        let ModelInputs {
            input_ids,
            input_ids_full,
            seqlen_offsets,
            seqlen_offsets_full,
            context_lens,
            position_ids: _, // NOTE(EricLBuehler): ignore, it is for phi3
            paged_attn_meta,
            flash_meta,
            flash_meta_full,
        } = *inputs.downcast().expect("Downcast failed.");
        let metadata = self.get_metadata();
        let paged_attn_meta = match (&metadata.cache_engine, &paged_attn_meta) {
            (Some(engine), Some(meta)) => Some((engine.get_kv_cache().clone(), meta)),
            (Some(_), None) => {
                // This can happen if Rust-side user code is wrong
                candle_core::bail!("Forward step expected PagedAttention input metadata, but none was provided. Please ensure that the scheduler config is correctly configured for PagedAttention.")
            }
            (None, Some(_)) => {
                // This should never happen but we handle it anyway
                candle_core::bail!("Forward step got PagedAttention input metadata but there is no cache engine. Please raise an issue.")
            }
            (None, None) => None,
        };
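        // Dispatch to the concrete architecture. X-LoRA variants additionally receive the
        // full input ids and sequence-length offsets together with the non-granular state.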
        let logits = match self.model {
            Model::Llama(ref model) => {
                model.forward(&input_ids, &seqlen_offsets, context_lens, paged_attn_meta)?
            }
            Model::Phi2(ref model) => {
                model.forward(&input_ids, &seqlen_offsets, context_lens, paged_attn_meta)?
            }
            Model::XLoraLlama(ref model) => model.forward(
                &input_ids,
                input_ids_full.as_ref().unwrap_or(&input_ids),
                &seqlen_offsets,
                seqlen_offsets_full.as_ref().unwrap_or(&seqlen_offsets),
                self.no_kv_cache,
                &self.non_granular_state,
                context_lens,
                &flash_meta,
                flash_meta_full.as_ref().unwrap_or(&flash_meta),
            )?,
            Model::Phi3(ref model) => {
                model.forward(&input_ids, &seqlen_offsets, paged_attn_meta)?
            }
            Model::XLoraPhi3(ref model) => model.forward(
                &input_ids,
                input_ids_full.as_ref().unwrap_or(&input_ids),
                &seqlen_offsets,
                seqlen_offsets_full.as_ref().unwrap_or(&seqlen_offsets),
                self.no_kv_cache,
                &self.non_granular_state,
                context_lens,
                &flash_meta,
                flash_meta_full.as_ref().unwrap_or(&flash_meta),
            )?,
            Model::Starcoder2(ref model) => {
                model.forward(&input_ids, &seqlen_offsets, paged_attn_meta)?
            }
            Model::Qwen(ref model) => {
                model.forward(&input_ids, &seqlen_offsets, context_lens, paged_attn_meta)?
            }
            Model::Qwen3(ref model) => {
                model.forward(&input_ids, &seqlen_offsets, context_lens, paged_attn_meta)?
            }
        };
        if return_raw_logits {
            Ok(ForwardInputsResult::RawLogits { logits })
        } else {
            Ok(ForwardInputsResult::CausalGeneration { logits })
        }
    }
    async fn sample_causal_gen(
        &self,
        seqs: &mut [&mut Sequence],
        logits: Vec<Tensor>,
        prefix_cacher: &mut PrefixCacheManagerV2,
        disable_eos_stop: bool,
        rng: Arc<std::sync::Mutex<Isaac64Rng>>,
    ) -> Result<(), candle_core::Error> {
        sample_and_add_toks(self, seqs, logits, prefix_cacher, disable_eos_stop, rng).await
    }
    fn category(&self) -> ModelCategory {
        ModelCategory::Text
    }
}

// TODO: AnyMoE is not yet implemented for the GGUF pipeline.
impl AnyMoePipelineMixin for GGUFPipeline {}