use super::llg::build_llg_factory;
use super::{
    get_model_paths, get_xlora_paths, text_models_inputs_processor::ModelInputs, AdapterKind,
    CacheManager, GeneralMetadata, Loader, ModelKind, ModelPaths, PrettyName, QuantizationKind,
    TokenSource,
};
use super::{
    AnyMoePipelineMixin, CacheManagerMixin, EitherCache, ForwardInputsResult, IsqPipelineMixin,
    MetadataMixin, ModelCategory, PreProcessingMixin,
};
use crate::attention::ATTENTION_CHUNK_SIZE;
use crate::device_map::{self, DeviceMapper};
use crate::gguf::{
    get_gguf_chat_template, {convert_gguf_to_hf_tokenizer, GgufTokenizerConversion},
};
use crate::gguf::{Content, GGUFArchitecture};
use crate::kv_cache::{FullCacheManager, NormalCacheManager};
use crate::lora::Ordering;
use crate::paged_attention::{
    calculate_cache_config, AttentionImplementation, CacheEngine, ModelConfigLike,
};
use crate::pipeline::chat_template::{calculate_eos_tokens, BeginEndUnkPadTok, GenerationConfig};
use crate::pipeline::loaders::DeviceMappedModelLoader;
use crate::pipeline::sampling::sample_and_add_toks;
use crate::pipeline::ChatTemplate;
use crate::pipeline::{get_chat_template, Modalities, SupportedModality};
use crate::prefix_cacher::PrefixCacheManagerV2;
use crate::sequence::Sequence;
use crate::utils::gguf_metadata::{ContentConfig, GgufDeviceMapLoaderInner};
use crate::utils::model_config as ModelConfig;
use crate::utils::tokenizer::get_tokenizer;
use crate::xlora_models::NonGranularState;
use crate::{
    get_mut_arcmutex, get_paths_gguf, DeviceMapSetting, LocalModelPaths, PagedAttentionConfig,
    Pipeline, Topology, TryIntoDType,
};
use crate::{
    models::quantized_llama::ModelWeights as QLlama,
    models::quantized_phi2::ModelWeights as QPhi,
    models::quantized_phi3::ModelWeights as QPhi3,
    models::quantized_qwen::ModelWeights as QQwen,
    models::quantized_qwen3::ModelWeights as QQwen3,
    models::quantized_qwen3_moe::ModelWeights as QQwen3MoE,
    models::quantized_starcoder2::ModelWeights as QStarcoder2,
    utils::tokens::get_token,
    xlora_models::{XLoraQLlama, XLoraQPhi3},
};
use anyhow::{bail, Result};
use candle_core::{Device, Tensor};
use either::Either;
use hf_hub::{api::sync::ApiBuilder, Repo, RepoType};
use mistralrs_quant::IsqType;
use rand_isaac::Isaac64Rng;
use std::any::Any;
use std::fs;
use std::path::PathBuf;
use std::str::FromStr;
use std::sync::Arc;
use tokenizers::Tokenizer;
use tokio::sync::Mutex;
use tracing::{info, warn};

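/// Quantized GGUF model weight variants supported by this pipeline, including the
/// X-LoRA-adapted Llama and Phi 3 variants.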
enum Model {
    Llama(QLlama),
    Phi2(QPhi),
    XLoraLlama(XLoraQLlama),
    XLoraPhi3(XLoraQPhi3),
    Phi3(QPhi3),
    Starcoder2(QStarcoder2),
    Qwen(QQwen),
    Qwen3(QQwen3),
    Qwen3MoE(QQwen3MoE),
}

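/// A loaded GGUF pipeline: the quantized model together with its tokenizer, chat
/// template, device mapper, and general metadata.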
pub struct GGUFPipeline {
    model: Model,
    tokenizer: Arc<Tokenizer>,
    no_kv_cache: bool,
    chat_template: Arc<ChatTemplate>,
    model_id: String,
    non_granular_state: Option<NonGranularState>,
    metadata: Arc<GeneralMetadata>,
    mapper: Box<dyn DeviceMapper + Send + Sync>,
}

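/// A [`Loader`] for GGUF models, optionally with X-LoRA or LoRA adapters.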
pub struct GGUFLoader {
    model_id: Option<String>,
    quantized_model_id: String,
    quantized_filenames: Vec<String>,
    xlora_model_id: Option<String>,
    xlora_order: Option<Ordering>,
    no_kv_cache: bool,
    chat_template: Option<String>,
    kind: ModelKind,
    tgt_non_granular_index: Option<usize>,
    config: GGUFSpecificConfig,
    jinja_explicit: Option<String>,
    lora_adapter_ids: Option<Vec<String>>,
}

#[derive(Clone, Default)]
pub struct GGUFSpecificConfig {
    pub topology: Option<Topology>,
}

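/// Builder for a [`GGUFLoader`]. Adapter support is added via
/// [`GGUFLoaderBuilder::with_xlora`] or [`GGUFLoaderBuilder::with_lora`] before
/// calling [`GGUFLoaderBuilder::build`].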
#[derive(Default)]
pub struct GGUFLoaderBuilder {
    model_id: Option<String>,
    quantized_model_id: String,
    quantized_filenames: Vec<String>,
    xlora_model_id: Option<String>,
    kind: ModelKind,
    xlora_order: Option<Ordering>,
    no_kv_cache: bool,
    chat_template: Option<String>,
    tgt_non_granular_index: Option<usize>,
    config: GGUFSpecificConfig,
    jinja_explicit: Option<String>,
}

impl GGUFLoaderBuilder {
    pub fn new(
        chat_template: Option<String>,
        tok_model_id: Option<String>,
        quantized_model_id: String,
        quantized_filenames: Vec<String>,
        config: GGUFSpecificConfig,
        no_kv_cache: bool,
        jinja_explicit: Option<String>,
    ) -> Self {
        let kind = ModelKind::GgufQuantized {
            quant: QuantizationKind::Gguf,
        };

        Self {
            chat_template,
            model_id: tok_model_id,
            kind,
            quantized_filenames,
            quantized_model_id,
            config,
            jinja_explicit,
            no_kv_cache,
            ..Default::default()
        }
    }

    fn with_adapter(
        mut self,
        xlora_model_id: String,
        xlora_order: Ordering,
        no_kv_cache: bool,
        tgt_non_granular_index: Option<usize>,
    ) -> Self {
        self.xlora_model_id = Some(xlora_model_id);
        self.xlora_order = Some(xlora_order);
        self.no_kv_cache = no_kv_cache;
        self.tgt_non_granular_index = tgt_non_granular_index;
        self.model_id = if let Some(id) = self.model_id {
            Some(id)
        } else {
            info!(
                "Using adapter base model ID: `{}`",
                self.xlora_order.as_ref().unwrap().base_model_id
            );
            Some(self.xlora_order.as_ref().unwrap().base_model_id.clone())
        };
        self
    }

    pub fn with_xlora(
        mut self,
        xlora_model_id: String,
        xlora_order: Ordering,
        no_kv_cache: bool,
        tgt_non_granular_index: Option<usize>,
    ) -> Self {
        self.kind = (AdapterKind::XLora, QuantizationKind::Gguf).into();

        self.with_adapter(
            xlora_model_id,
            xlora_order,
            no_kv_cache,
            tgt_non_granular_index,
        )
    }

    pub fn with_lora(mut self, lora_model_id: String, lora_order: Ordering) -> Self {
        self.kind = (AdapterKind::Lora, QuantizationKind::Gguf).into();

        self.with_adapter(lora_model_id, lora_order, false, None)
    }

    pub fn build(self) -> Box<dyn Loader> {
        Box::new(GGUFLoader {
            model_id: self.model_id,
            xlora_model_id: self.xlora_model_id,
            kind: self.kind,
            xlora_order: self.xlora_order,
            no_kv_cache: self.no_kv_cache,
            chat_template: self.chat_template,
            tgt_non_granular_index: self.tgt_non_granular_index,
            quantized_filenames: self.quantized_filenames,
            quantized_model_id: self.quantized_model_id,
            config: self.config,
            jinja_explicit: self.jinja_explicit,
            lora_adapter_ids: None,
        })
    }
}

impl GGUFLoader {
    #[allow(clippy::too_many_arguments)]
    pub fn new(
        model_id: Option<String>,
        quantized_model_id: String,
        quantized_filenames: Vec<String>,
        xlora_model_id: Option<String>,
        kind: ModelKind,
        xlora_order: Option<Ordering>,
        no_kv_cache: bool,
        chat_template: Option<String>,
        tgt_non_granular_index: Option<usize>,
        config: GGUFSpecificConfig,
        jinja_explicit: Option<String>,
    ) -> Self {
        let model_id = if let Some(id) = model_id {
            Some(id)
        } else if let Some(xlora_order) = xlora_order.clone() {
            info!(
                "Using adapter base model ID: `{}`",
                xlora_order.base_model_id
            );
            Some(xlora_order.base_model_id.clone())
        } else {
            None
        };
        Self {
            model_id,
            quantized_model_id,
            quantized_filenames,
            xlora_model_id,
            xlora_order,
            no_kv_cache,
            chat_template,
            kind,
            tgt_non_granular_index,
            config,
            jinja_explicit,
            lora_adapter_ids: None,
        }
    }
}

impl Loader for GGUFLoader {
    #[allow(clippy::type_complexity, clippy::too_many_arguments)]
    fn load_model_from_hf(
        &self,
        revision: Option<String>,
        token_source: TokenSource,
        dtype: &dyn TryIntoDType,
        device: &Device,
        silent: bool,
        mapper: DeviceMapSetting,
        in_situ_quant: Option<IsqType>,
        paged_attn_config: Option<PagedAttentionConfig>,
    ) -> Result<Arc<Mutex<dyn Pipeline + Send + Sync>>> {
        let paths: anyhow::Result<Box<dyn ModelPaths>> = get_paths_gguf!(
            LocalModelPaths,
            &token_source,
            revision,
            self,
            self.quantized_model_id.clone(),
            self.quantized_filenames.clone(),
            silent
        );
        self.load_model_from_path(
            &paths?,
            dtype,
            device,
            silent,
            mapper,
            in_situ_quant,
            paged_attn_config,
        )
    }

    #[allow(clippy::type_complexity, clippy::too_many_arguments)]
    fn load_model_from_path(
        &self,
        paths: &Box<dyn ModelPaths>,
        dtype: &dyn TryIntoDType,
        device: &Device,
        silent: bool,
        mut mapper: DeviceMapSetting,
        in_situ_quant: Option<IsqType>,
        mut paged_attn_config: Option<PagedAttentionConfig>,
    ) -> Result<Arc<Mutex<dyn Pipeline + Send + Sync>>> {
        if in_situ_quant.is_some() {
            anyhow::bail!(
                "You are trying to in-situ quantize a GGUF model. This will not do anything."
            );
        }

        info!("Prompt chunk size is {ATTENTION_CHUNK_SIZE}.");

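        // Open every GGUF shard and parse the combined content and metadata.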
        let mut readers = Vec::new();
        for filename in paths.get_weight_filenames() {
            readers.push(std::fs::File::open(filename)?);
        }
        let mut readers = readers.iter_mut().collect::<Vec<_>>();

        let model = Content::from_readers(&mut readers)?;
        if !silent {
            model.print_metadata()?;
        }
        let arch = model.arch();

        let num_layers = model.get_metadata()[&format!("{arch}.block_count")].to_u32()? as usize;
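        // For automatic device mapping, estimate per-layer and non-mapped sizes from the
        // GGUF metadata and resolve them into a concrete layer-to-device map.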
        if let DeviceMapSetting::Auto(params) = mapper.clone() {
            let devices = device_map::get_all_similar_devices(device)?;
            let dtype = dtype.try_into_dtype(&devices.iter().collect::<Vec<_>>())?;

            let model = GgufDeviceMapLoaderInner {
                model: &model,
                arch,
            };

            let layer_sizes_in_bytes =
                model.layer_sizes_in_bytes("this is a dummy config!", dtype, 1, None)?;
            let non_mapped_size_in_bytes =
                model.non_mapped_size_in_bytes("this is a dummy config!", dtype, 1, None)?;
            let total_model_size_in_bytes =
                layer_sizes_in_bytes.iter().sum::<usize>() + non_mapped_size_in_bytes;

            let new = model.get_device_layers(
                "this is a dummy config!",
                num_layers,
                layer_sizes_in_bytes,
                non_mapped_size_in_bytes,
                total_model_size_in_bytes,
                &devices,
                dtype,
                &params,
                paged_attn_config.as_ref(),
            )?;
            mapper = DeviceMapSetting::Map(new);
        }

        #[cfg(feature = "cuda")]
        if let Device::Cuda(dev) = &device {
            unsafe { dev.disable_event_tracking() };
        }

        let pipeline_mapper =
            mapper.into_mapper(num_layers, device, self.config.topology.as_ref())?;
        let mapper = mapper.into_mapper(num_layers, device, self.config.topology.as_ref())?;
        let mut layer_devices = Vec::new();
        for layer in 0..num_layers {
            let device = mapper.device_for(layer, false).cloned();
            layer_devices.push(device);
        }

        let mapping_uses_cpu = mapper.get_unique_devices().iter().any(Device::is_cpu);
        if mapping_uses_cpu {
            warn!("Device mapping contains a mix of GPU and CPU. There is no CPU support for PagedAttention, disabling PagedAttention.");
            paged_attn_config = None;
        }

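        // Prefer an external tokenizer file when one was provided; otherwise convert the
        // tokenizer embedded in the GGUF metadata.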
        let GgufTokenizerConversion {
            tokenizer,
            bos,
            eos,
            unk,
        } = if paths.get_tokenizer_filename().to_string_lossy().is_empty() {
            convert_gguf_to_hf_tokenizer(&model)?
        } else {
            GgufTokenizerConversion {
                tokenizer: get_tokenizer(paths.get_tokenizer_filename(), None)?,
                bos: None,
                eos: None,
                unk: None,
            }
        };

        let gguf_chat_template =
            if paths.get_template_filename().is_none() && self.chat_template.is_none() {
                get_gguf_chat_template(&model)?
            } else {
                None
            };

        let has_adapter = self.kind.is_adapted();
        let is_xlora = self.kind.is_adapted_and(|a| a.is_x_lora());

        let paged_attn_config = if matches!(self.kind, ModelKind::GgufAdapter { .. }) {
            warn!("Adapter models do not currently support PagedAttention, running without");
            None
        } else {
            paged_attn_config
        };

        let model_config_metadata: ContentConfig = (&model).into();
        let internal_dtype = mapper.get_min_dtype(dtype)?;

        let model_config = {
            let quant = ModelConfig::ParamsGGUF(
                model,
                (device, mapper).into(),
                if paged_attn_config.is_some() {
                    AttentionImplementation::PagedAttention
                } else {
                    AttentionImplementation::Eager
                },
                internal_dtype,
            );

            let mut adapter = None;
            if has_adapter {
                adapter.replace(ModelConfig::Adapter::try_new(
                    paths, device, silent, is_xlora,
                )?);
            }

            ModelConfig::ModelParams::new(quant, adapter)
        };

        let model = match self.kind {
            ModelKind::GgufQuantized { .. } => match arch {
                GGUFArchitecture::Llama => Model::Llama(QLlama::try_from(model_config)?),
                GGUFArchitecture::Phi2 => Model::Phi2(QPhi::try_from(model_config)?),
                GGUFArchitecture::Phi3 => Model::Phi3(QPhi3::try_from(model_config)?),
                GGUFArchitecture::Starcoder2 => {
                    Model::Starcoder2(QStarcoder2::try_from(model_config)?)
                }
                GGUFArchitecture::Qwen2 => Model::Qwen(QQwen::try_from(model_config)?),
                GGUFArchitecture::Qwen3 => Model::Qwen3(QQwen3::try_from(model_config)?),
                GGUFArchitecture::Qwen3MoE => Model::Qwen3MoE(QQwen3MoE::try_from(model_config)?),
                a => bail!("Unsupported architecture `{a:?}` for GGUF"),
            },
            ModelKind::GgufAdapter { adapter, .. } => match arch {
                GGUFArchitecture::Llama => Model::XLoraLlama(XLoraQLlama::try_from(model_config)?),
                GGUFArchitecture::Phi3 => Model::XLoraPhi3(XLoraQPhi3::try_from(model_config)?),
                a => bail!(
                    "Unsupported architecture `{a:?}` for GGUF {kind}",
                    kind = adapter.pretty_name()
                ),
            },
            _ => unreachable!(),
        };

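        // When PagedAttention is enabled, size the KV cache from the configured GPU memory
        // and build the cache engine over the mapped layer devices.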
        let (cache_config, cache_engine) = if let Some(paged_attn_config) = paged_attn_config {
            let model_config: &dyn ModelConfigLike = &model_config_metadata;
            let cache_config = calculate_cache_config(
                paged_attn_config.mem_gpu,
                paged_attn_config.block_size,
                internal_dtype,
                paged_attn_config.cache_type,
                model_config,
                device,
                &layer_devices,
                silent,
            )?;
            let cache_engine = CacheEngine::new(
                model_config,
                &cache_config,
                internal_dtype,
                device,
                layer_devices,
            )?;
            (Some(cache_config), Some(cache_engine))
        } else {
            (None, None)
        };

        let gen_conf: Option<GenerationConfig> = paths
            .get_gen_conf_filename()
            .map(|f| serde_json::from_str(&fs::read_to_string(f).unwrap()).unwrap());
        let chat_template_explicit = paths
            .get_chat_template_explicit()
            .as_ref()
            .map(|x| x.to_string_lossy().to_string());
        let mut chat_template = get_chat_template(
            paths,
            self.jinja_explicit.as_ref(),
            chat_template_explicit.as_ref(),
            self.chat_template.as_ref(),
            gguf_chat_template,
        );

        let max_seq_len = match model {
            Model::Llama(ref l) => l.max_seq_len,
            Model::Phi2(ref p) => p.max_seq_len,
            Model::XLoraLlama(ref xl) => xl.max_seq_len,
            Model::Phi3(ref p) => p.max_seq_len,
            Model::XLoraPhi3(ref p) => p.max_seq_len,
            Model::Starcoder2(ref p) => p.max_seq_len,
            Model::Qwen(ref p) => p.max_seq_len,
            Model::Qwen3(ref p) => p.max_seq_len,
            Model::Qwen3MoE(ref p) => p.max_seq_len,
        };
        let llg_factory = build_llg_factory(tokenizer.clone())?;
        let num_hidden_layers = match model {
            Model::Llama(ref model) => model.cache.normal().0.len(),
            Model::Phi2(ref model) => model.cache.normal().0.len(),
            Model::XLoraLlama(ref model) => model.cache.full().lock().len(),
            Model::Phi3(ref model) => model.cache.normal().0.len(),
            Model::XLoraPhi3(ref model) => model.cache.full().lock().len(),
            Model::Starcoder2(ref model) => model.cache.normal().0.len(),
            Model::Qwen(ref model) => model.cache.normal().0.len(),
            Model::Qwen3(ref model) => model.cache.normal().0.len(),
            Model::Qwen3MoE(ref model) => model.cache.normal().0.len(),
        };

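        // Backfill special tokens recovered from the GGUF tokenizer if the chat template
        // does not already define them.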
        if chat_template.bos_token.is_none() {
            if let Some(v) = bos {
                chat_template.bos_token = Some(BeginEndUnkPadTok(Either::Left(v)));
            }
        }
        if chat_template.eos_token.is_none() {
            if let Some(v) = eos {
                chat_template.eos_token = Some(BeginEndUnkPadTok(Either::Left(v)));
            }
        }
        if chat_template.unk_token.is_none() {
            if let Some(v) = unk {
                chat_template.unk_token = Some(BeginEndUnkPadTok(Either::Left(v)));
            }
        }

        let eos = calculate_eos_tokens(&chat_template, gen_conf, &tokenizer);
        Ok(Arc::new(Mutex::new(GGUFPipeline {
            model,
            tokenizer: tokenizer.into(),
            no_kv_cache: self.no_kv_cache,
            chat_template: Arc::new(chat_template),
            model_id: self
                .model_id
                .clone()
                .unwrap_or(self.quantized_model_id.clone()),
            non_granular_state: self.tgt_non_granular_index.map(|tgt_non_granular_index| {
                NonGranularState {
                    non_granular_index: Arc::new(Mutex::new(0)),
                    tgt_non_granular_index,
                }
            }),
            metadata: Arc::new(GeneralMetadata {
                max_seq_len,
                llg_factory: Some(llg_factory),
                no_kv_cache: self.no_kv_cache,
                no_prefix_cache: false,
                num_hidden_layers,
                eos_tok: eos,
                kind: self.kind.clone(),
                is_xlora,
                activation_dtype: internal_dtype,
                sliding_window: None,
                cache_config,
                cache_engine,
                model_metadata: Some(Arc::new(model_config_metadata)),
                modalities: Modalities {
                    input: vec![SupportedModality::Text],
                    output: vec![SupportedModality::Text],
                },
            }),
            mapper: pipeline_mapper,
        })))
    }

    fn get_id(&self) -> String {
        self.xlora_model_id
            .as_deref()
            .unwrap_or(self.model_id.as_ref().unwrap_or(&self.quantized_model_id))
            .to_string()
    }

    fn get_kind(&self) -> ModelKind {
        self.kind.clone()
    }
}

impl PreProcessingMixin for GGUFPipeline {
    fn get_chat_template(&self) -> Option<Arc<ChatTemplate>> {
        Some(self.chat_template.clone())
    }
    fn get_input_processor_config(&self) -> Option<Arc<dyn Any>> {
        None
    }
}

impl IsqPipelineMixin for GGUFPipeline {
    fn re_isq_model(&mut self, _dtype: IsqType) -> Result<()> {
        anyhow::bail!(
            "You are trying to in-situ requantize a GGUF model. This will not do anything."
        )
    }
}

impl CacheManagerMixin for GGUFPipeline {
    fn clone_in_cache(&self, seqs: &mut [&mut Sequence]) {
        if matches!(self.cache(), EitherCache::Full(_)) {
            FullCacheManager.clone_in_cache(self, seqs, false)
        } else {
            NormalCacheManager.clone_in_cache(self, seqs, false)
        }
    }
    fn clone_out_cache(&self, seqs: &mut [&mut Sequence]) {
        if matches!(self.cache(), EitherCache::Full(_)) {
            FullCacheManager.clone_out_cache(self, seqs, false)
        } else {
            NormalCacheManager.clone_out_cache(self, seqs, false)
        }
    }
    fn set_none_cache(
        &self,
        seqs: &mut [&mut Sequence],
        reset_non_granular: bool,
        modify_draft_cache: bool,
        load_preallocated_cache: bool,
    ) {
        if matches!(self.cache(), EitherCache::Full(_)) {
            FullCacheManager.set_none_cache(self, seqs, modify_draft_cache, false);
        } else {
            NormalCacheManager.set_none_cache(
                self,
                seqs,
                modify_draft_cache,
                load_preallocated_cache,
            );
        }
        if reset_non_granular {
            self.reset_non_granular_state()
        }
    }
    fn cache(&self) -> &EitherCache {
        match self.model {
            Model::Llama(ref model) => &model.cache,
            Model::Phi2(ref model) => &model.cache,
            Model::XLoraLlama(ref model) => &model.cache,
            Model::Phi3(ref model) => &model.cache,
            Model::XLoraPhi3(ref model) => &model.cache,
            Model::Starcoder2(ref model) => &model.cache,
            Model::Qwen(ref model) => &model.cache,
            Model::Qwen3(ref model) => &model.cache,
            Model::Qwen3MoE(ref model) => &model.cache,
        }
    }
}

impl MetadataMixin for GGUFPipeline {
    fn device(&self) -> Device {
        match self.model {
            Model::Llama(ref model) => model.device.clone(),
            Model::Phi2(ref model) => model.device.clone(),
            Model::XLoraLlama(ref model) => model.device.clone(),
            Model::Phi3(ref model) => model.device.clone(),
            Model::XLoraPhi3(ref model) => model.device.clone(),
            Model::Starcoder2(ref model) => model.device.clone(),
            Model::Qwen(ref model) => model.device.clone(),
            Model::Qwen3(ref model) => model.device.clone(),
            Model::Qwen3MoE(ref model) => model.device.clone(),
        }
    }
    fn tokenizer(&self) -> Option<Arc<Tokenizer>> {
        Some(self.tokenizer.clone())
    }
    fn name(&self) -> String {
        self.model_id.clone()
    }
    fn reset_non_granular_state(&self) {
        if let Some(s) = self.non_granular_state.as_ref() {
            *self.cache().full().get_scalings_cache() = None;
            *get_mut_arcmutex!(s.non_granular_index) = 0;
        }
    }
    fn get_metadata(&self) -> Arc<GeneralMetadata> {
        self.metadata.clone()
    }
    fn device_mapper(&self) -> Option<&dyn DeviceMapper> {
        Some(&*self.mapper)
    }
}

#[async_trait::async_trait]
impl Pipeline for GGUFPipeline {
    fn forward_inputs(
        &mut self,
        inputs: Box<dyn Any>,
        return_raw_logits: bool,
    ) -> Result<ForwardInputsResult, candle_core::Error> {
        let ModelInputs {
            input_ids,
            input_ids_full,
            seqlen_offsets,
            seqlen_offsets_full,
            context_lens,
            position_ids: _,
            paged_attn_meta,
            flash_meta,
            flash_meta_full,
        } = *inputs.downcast().expect("Downcast failed.");
        let metadata = self.get_metadata();
        let paged_attn_meta = match (&metadata.cache_engine, &paged_attn_meta) {
            (Some(engine), Some(meta)) => Some((engine.get_kv_cache().clone(), meta)),
            (Some(_), None) => {
                candle_core::bail!("Forward step expected PagedAttention input metadata. This was not provided; please ensure that the scheduler config is correctly configured for PagedAttention.")
            }
            (None, Some(_)) => {
                candle_core::bail!("Forward step got PagedAttention input metadata but there is no cache engine. Please raise an issue.")
            }
            (None, None) => None,
        };
        let logits = match self.model {
            Model::Llama(ref model) => {
                model.forward(&input_ids, &seqlen_offsets, context_lens, paged_attn_meta)?
            }
            Model::Phi2(ref model) => {
                model.forward(&input_ids, &seqlen_offsets, context_lens, paged_attn_meta)?
            }
            Model::XLoraLlama(ref model) => model.forward(
                &input_ids,
                input_ids_full.as_ref().unwrap_or(&input_ids),
                &seqlen_offsets,
                seqlen_offsets_full.as_ref().unwrap_or(&seqlen_offsets),
                self.no_kv_cache,
                &self.non_granular_state,
                context_lens,
                &flash_meta,
                flash_meta_full.as_ref().unwrap_or(&flash_meta),
            )?,
            Model::Phi3(ref model) => {
                model.forward(&input_ids, &seqlen_offsets, paged_attn_meta)?
            }
            Model::XLoraPhi3(ref model) => model.forward(
                &input_ids,
                input_ids_full.as_ref().unwrap_or(&input_ids),
                &seqlen_offsets,
                seqlen_offsets_full.as_ref().unwrap_or(&seqlen_offsets),
                self.no_kv_cache,
                &self.non_granular_state,
                context_lens,
                &flash_meta,
                flash_meta_full.as_ref().unwrap_or(&flash_meta),
            )?,
            Model::Starcoder2(ref model) => {
                model.forward(&input_ids, &seqlen_offsets, paged_attn_meta)?
            }
            Model::Qwen(ref model) => {
                model.forward(&input_ids, &seqlen_offsets, context_lens, paged_attn_meta)?
            }
            Model::Qwen3(ref model) => {
                model.forward(&input_ids, &seqlen_offsets, context_lens, paged_attn_meta)?
            }
            Model::Qwen3MoE(ref model) => {
                model.forward(&input_ids, &seqlen_offsets, context_lens, paged_attn_meta)?
            }
        };
        if return_raw_logits {
            Ok(ForwardInputsResult::RawLogits { logits })
        } else {
            Ok(ForwardInputsResult::CausalGeneration { logits })
        }
    }
    async fn sample_causal_gen(
        &self,
        seqs: &mut [&mut Sequence],
        logits: Vec<Tensor>,
        prefix_cacher: &mut PrefixCacheManagerV2,
        disable_eos_stop: bool,
        rng: Arc<std::sync::Mutex<Isaac64Rng>>,
    ) -> Result<(), candle_core::Error> {
        sample_and_add_toks(self, seqs, logits, prefix_cacher, disable_eos_stop, rng).await
    }
    fn category(&self) -> ModelCategory {
        ModelCategory::Text
    }
}

impl AnyMoePipelineMixin for GGUFPipeline {}