// mistralrs_core/pipeline/gguf.rs

use super::cache_manager::{FullCacheManager, NormalCacheManager};
use super::llg::build_tok_env;
use super::{
    get_model_paths, get_xlora_paths, text_models_inputs_processor::ModelInputs, AdapterKind,
    CacheManager, GeneralMetadata, Loader, ModelKind, ModelPaths, PrettyName, QuantizationKind,
    TokenSource,
};
use super::{
    AnyMoePipelineMixin, CacheManagerMixin, EitherCache, ForwardInputsResult, IsqPipelineMixin,
    MetadataMixin, ModelCategory, PreProcessingMixin,
};
use crate::device_map::{self, DeviceMapper};
use crate::gguf::{
    get_gguf_chat_template, {convert_gguf_to_hf_tokenizer, GgufTokenizerConversion},
};
use crate::gguf::{Content, GGUFArchitecture};
use crate::lora::Ordering;
use crate::paged_attention::{
    calculate_cache_config, AttentionImplementation, CacheEngine, ModelConfigLike,
};
use crate::pipeline::chat_template::{calculate_eos_tokens, BeginEndUnkPadTok, GenerationConfig};
use crate::pipeline::get_chat_template;
use crate::pipeline::inputs_processor::DEFAULT_PROMPT_CHUNK_SIZE;
use crate::pipeline::loaders::DeviceMappedModelLoader;
use crate::pipeline::sampling::sample_and_add_toks;
use crate::pipeline::ChatTemplate;
use crate::prefix_cacher::PrefixCacheManagerV2;
use crate::sequence::Sequence;
use crate::utils::gguf_metadata::{ContentConfig, GgufDeviceMapLoaderInner};
use crate::utils::model_config as ModelConfig;
use crate::utils::tokenizer::get_tokenizer;
use crate::xlora_models::NonGranularState;
use crate::{
    get_mut_arcmutex, get_paths_gguf, DeviceMapSetting, LocalModelPaths, PagedAttentionConfig,
    Pipeline, Topology, TryIntoDType,
};
use crate::{
    models::quantized_llama::ModelWeights as QLlama,
    models::quantized_phi2::ModelWeights as QPhi,
    models::quantized_phi3::ModelWeights as QPhi3,
    models::quantized_qwen2::ModelWeights as QQwen2,
    models::quantized_starcoder2::ModelWeights as QStarcoder2,
    utils::tokens::get_token,
    xlora_models::{XLoraQLlama, XLoraQPhi3},
};
use anyhow::{bail, Result};
use candle_core::{Device, Tensor};
use either::Either;
use hf_hub::{api::sync::ApiBuilder, Repo, RepoType};
use mistralrs_quant::IsqType;
use rand_isaac::Isaac64Rng;
use std::any::Any;
use std::fs;
use std::num::{NonZero, NonZeroUsize};
use std::path::PathBuf;
use std::str::FromStr;
use std::sync::Arc;
use tokenizers::Tokenizer;
use tokio::sync::Mutex;
use tracing::{info, warn};

enum Model {
    Llama(QLlama),
    Phi2(QPhi),
    XLoraLlama(XLoraQLlama),
    XLoraPhi3(XLoraQPhi3),
    Phi3(QPhi3),
    Starcoder2(QStarcoder2),
    Qwen2(QQwen2),
}

pub struct GGUFPipeline {
    model: Model,
    tokenizer: Arc<Tokenizer>,
    no_kv_cache: bool,
    chat_template: Arc<ChatTemplate>,
    model_id: String,
    non_granular_state: Option<NonGranularState>,
    metadata: Arc<GeneralMetadata>,
    mapper: Box<dyn DeviceMapper + Send + Sync>,
}

/// Loader for a GGUF model.
pub struct GGUFLoader {
    model_id: Option<String>,
    quantized_model_id: String,
    quantized_filenames: Vec<String>,
    xlora_model_id: Option<String>,
    xlora_order: Option<Ordering>,
    no_kv_cache: bool,
    chat_template: Option<String>,
    kind: ModelKind,
    tgt_non_granular_index: Option<usize>,
    config: GGUFSpecificConfig,
    jinja_explicit: Option<String>,
    lora_adapter_ids: Option<Vec<String>>,
}

#[derive(Clone, Default)]
/// Config for a GGUF loader.
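///
/// Both fields are optional; `GGUFSpecificConfig::default()` leaves the prompt chunk size
/// and the topology unset. A minimal sketch (the chunk size value below is illustrative only):
///
/// ```ignore
/// use std::num::NonZeroUsize;
///
/// let config = GGUFSpecificConfig {
///     prompt_chunksize: NonZeroUsize::new(512),
///     topology: None,
/// };
/// ```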
pub struct GGUFSpecificConfig {
    pub prompt_chunksize: Option<NonZeroUsize>,
    pub topology: Option<Topology>,
}

#[derive(Default)]
/// A builder for a GGUF loader.
pub struct GGUFLoaderBuilder {
    model_id: Option<String>,
    quantized_model_id: String,
    quantized_filenames: Vec<String>,
    xlora_model_id: Option<String>,
    kind: ModelKind,
    xlora_order: Option<Ordering>,
    no_kv_cache: bool,
    chat_template: Option<String>,
    tgt_non_granular_index: Option<usize>,
    config: GGUFSpecificConfig,
    jinja_explicit: Option<String>,
}

impl GGUFLoaderBuilder {
    /// Create a loader builder for a GGUF model. `tok_model_id` is the model ID where the
    /// `tokenizer_config.json` file can be found. If `chat_template` is specified, it is treated as a
    /// local path and takes precedence over any remote file, so no remote accesses are made.
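    ///
    /// A minimal, hedged usage sketch; the model IDs and GGUF filename below are placeholders,
    /// not recommendations:
    ///
    /// ```ignore
    /// let loader = GGUFLoaderBuilder::new(
    ///     None,                                        // chat_template: fall back to other sources
    ///     Some("some-org/some-model".to_string()),     // tok_model_id (hypothetical)
    ///     "some-org/some-model-GGUF".to_string(),      // quantized_model_id (hypothetical)
    ///     vec!["model-q4_k_m.gguf".to_string()],       // quantized_filenames (hypothetical)
    ///     GGUFSpecificConfig::default(),
    ///     false,                                       // no_kv_cache
    ///     None,                                        // jinja_explicit
    /// )
    /// .build();
    /// ```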
    pub fn new(
        chat_template: Option<String>,
        tok_model_id: Option<String>,
        quantized_model_id: String,
        quantized_filenames: Vec<String>,
        config: GGUFSpecificConfig,
        no_kv_cache: bool,
        jinja_explicit: Option<String>,
    ) -> Self {
        let kind = ModelKind::GgufQuantized {
            quant: QuantizationKind::Gguf,
        };

        Self {
            chat_template,
            model_id: tok_model_id,
            kind,
            quantized_filenames,
            quantized_model_id,
            config,
            jinja_explicit,
            no_kv_cache,
            ..Default::default()
        }
    }

    fn with_adapter(
        mut self,
        xlora_model_id: String,
        xlora_order: Ordering,
        no_kv_cache: bool,
        tgt_non_granular_index: Option<usize>,
    ) -> Self {
        self.xlora_model_id = Some(xlora_model_id);
        self.xlora_order = Some(xlora_order);
        self.no_kv_cache = no_kv_cache;
        self.tgt_non_granular_index = tgt_non_granular_index;
        self.model_id = if let Some(id) = self.model_id {
            Some(id)
        } else {
            info!(
                "Using adapter base model ID: `{}`",
                self.xlora_order.as_ref().unwrap().base_model_id
            );
            Some(self.xlora_order.as_ref().unwrap().base_model_id.clone())
        };
        self
    }

    pub fn with_xlora(
        mut self,
        xlora_model_id: String,
        xlora_order: Ordering,
        no_kv_cache: bool,
        tgt_non_granular_index: Option<usize>,
    ) -> Self {
        self.kind = (AdapterKind::XLora, QuantizationKind::Gguf).into();

        self.with_adapter(
            xlora_model_id,
            xlora_order,
            no_kv_cache,
            tgt_non_granular_index,
        )
    }

    pub fn with_lora(mut self, lora_model_id: String, lora_order: Ordering) -> Self {
        self.kind = (AdapterKind::Lora, QuantizationKind::Gguf).into();

        self.with_adapter(lora_model_id, lora_order, false, None)
    }

    pub fn build(self) -> Box<dyn Loader> {
        Box::new(GGUFLoader {
            model_id: self.model_id,
            xlora_model_id: self.xlora_model_id,
            kind: self.kind,
            xlora_order: self.xlora_order,
            no_kv_cache: self.no_kv_cache,
            chat_template: self.chat_template,
            tgt_non_granular_index: self.tgt_non_granular_index,
            quantized_filenames: self.quantized_filenames,
            quantized_model_id: self.quantized_model_id,
            config: self.config,
            jinja_explicit: self.jinja_explicit,
            lora_adapter_ids: None,
        })
    }
}

impl GGUFLoader {
    #[allow(clippy::too_many_arguments)]
    pub fn new(
        model_id: Option<String>,
        quantized_model_id: String,
        quantized_filenames: Vec<String>,
        xlora_model_id: Option<String>,
        kind: ModelKind,
        xlora_order: Option<Ordering>,
        no_kv_cache: bool,
        chat_template: Option<String>,
        tgt_non_granular_index: Option<usize>,
        config: GGUFSpecificConfig,
        jinja_explicit: Option<String>,
    ) -> Self {
        let model_id = if let Some(id) = model_id {
            Some(id)
        } else if let Some(xlora_order) = xlora_order.clone() {
            info!(
                "Using adapter base model ID: `{}`",
                xlora_order.base_model_id
            );
            Some(xlora_order.base_model_id.clone())
        } else {
            None
        };
        Self {
            model_id,
            quantized_model_id,
            quantized_filenames,
            xlora_model_id,
            xlora_order,
            no_kv_cache,
            chat_template,
            kind,
            tgt_non_granular_index,
            config,
            jinja_explicit,
            lora_adapter_ids: None,
        }
    }
}

impl Loader for GGUFLoader {
    #[allow(clippy::type_complexity, clippy::too_many_arguments)]
    fn load_model_from_hf(
        &self,
        revision: Option<String>,
        token_source: TokenSource,
        dtype: &dyn TryIntoDType,
        device: &Device,
        silent: bool,
        mapper: DeviceMapSetting,
        in_situ_quant: Option<IsqType>,
        paged_attn_config: Option<PagedAttentionConfig>,
    ) -> Result<Arc<Mutex<dyn Pipeline + Send + Sync>>> {
        let paths: anyhow::Result<Box<dyn ModelPaths>> = get_paths_gguf!(
            LocalModelPaths,
            &token_source,
            revision,
            self,
            self.quantized_model_id.clone(),
            self.quantized_filenames.clone(),
            silent
        );
        self.load_model_from_path(
            &paths?,
            dtype,
            device,
            silent,
            mapper,
            in_situ_quant,
            paged_attn_config,
        )
    }

    #[allow(clippy::type_complexity, clippy::too_many_arguments)]
    fn load_model_from_path(
        &self,
        paths: &Box<dyn ModelPaths>,
        dtype: &dyn TryIntoDType,
        device: &Device,
        silent: bool,
        mut mapper: DeviceMapSetting,
        in_situ_quant: Option<IsqType>,
        mut paged_attn_config: Option<PagedAttentionConfig>,
    ) -> Result<Arc<Mutex<dyn Pipeline + Send + Sync>>> {
        if in_situ_quant.is_some() {
            anyhow::bail!(
                "You are trying to in-situ quantize a GGUF model. This will not do anything."
            );
        }

        // Apply default prompt size here
        let prompt_chunksize = self
            .config
            .prompt_chunksize
            .unwrap_or(DEFAULT_PROMPT_CHUNK_SIZE.try_into().unwrap())
            .get();

        info!("Prompt chunk size is {prompt_chunksize}.",);

        let mut readers = Vec::new();
        for filename in paths.get_weight_filenames() {
            readers.push(std::fs::File::open(filename)?);
        }
        let mut readers = readers.iter_mut().collect::<Vec<_>>();

        let model = Content::from_readers(&mut readers)?;
        if !silent {
            model.print_metadata()?;
        }
        let arch = model.arch();

        // If auto, convert to Map
        let num_layers = model.get_metadata()[&format!("{arch}.block_count")].to_u32()? as usize;
        if let DeviceMapSetting::Auto(params) = mapper.clone() {
            let devices = device_map::get_all_similar_devices(device)?;
            // Initial dtype
            let dtype = dtype.try_into_dtype(&devices.iter().collect::<Vec<_>>())?;

            let model = GgufDeviceMapLoaderInner {
                model: &model,
                arch,
            };

            let layer_sizes_in_bytes =
                model.layer_sizes_in_bytes("this is a dummy config!", dtype, 1)?;
            let non_mapped_size_in_bytes =
                model.non_mapped_size_in_bytes("this is a dummy config!", dtype, 1)?;
            let total_model_size_in_bytes =
                layer_sizes_in_bytes.iter().sum::<usize>() + non_mapped_size_in_bytes;

            let new = model.get_device_layers(
                "this is a dummy config!",
                num_layers,
                layer_sizes_in_bytes,
                non_mapped_size_in_bytes,
                total_model_size_in_bytes,
                &devices,
                dtype,
                &params,
                prompt_chunksize,
                paged_attn_config.as_ref(),
            )?;
            mapper = DeviceMapSetting::Map(new);
        }

        let pipeline_mapper =
            mapper.into_mapper(num_layers, device, self.config.topology.as_ref())?;
        let mapper = mapper.into_mapper(num_layers, device, self.config.topology.as_ref())?;
        let mut layer_devices = Vec::new();
        for layer in 0..num_layers {
            let device = mapper.device_for(layer, false).cloned();
            layer_devices.push(device);
        }

        // TODO: PagedAttention is not supported with CPU for now.
        // This check is not really necessary because `get_device_layers` should prevent it.
        let mapping_uses_cpu = mapper.get_unique_devices().iter().any(Device::is_cpu);
        if mapping_uses_cpu {
            warn!("Device mapping contains a mix of GPU and CPU. There is no CPU support for PagedAttention, disabling PagedAttention.");
            paged_attn_config = None;
        }

        let GgufTokenizerConversion {
            tokenizer,
            bos,
            eos,
            unk,
        } = if paths.get_tokenizer_filename().to_string_lossy().is_empty() {
            convert_gguf_to_hf_tokenizer(&model)?
        } else {
            GgufTokenizerConversion {
                tokenizer: get_tokenizer(paths.get_tokenizer_filename(), None)?,
                bos: None,
                eos: None,
                unk: None,
            }
        };

        // Only load gguf chat template if there is nothing else
        let gguf_chat_template =
            if paths.get_template_filename().is_none() && self.chat_template.is_none() {
                get_gguf_chat_template(&model)?
            } else {
                None
            };

        let has_adapter = self.kind.is_adapted();
        let is_xlora = self.kind.is_adapted_and(|a| a.is_x_lora());

        let paged_attn_config = if matches!(self.kind, ModelKind::GgufAdapter { .. }) {
            warn!("Adapter models do not currently support PagedAttention, running without");
            None
        } else {
            paged_attn_config
        };

        let model_config_metadata: ContentConfig = (&model).into();
        let internal_dtype = mapper.get_min_dtype(dtype)?;

        let model_config = {
            // Base config (quantization only):
            let quant = ModelConfig::ParamsGGUF(
                model,
                (device, mapper).into(),
                if paged_attn_config.is_some() {
                    AttentionImplementation::PagedAttention
                } else {
                    AttentionImplementation::Eager
                },
                internal_dtype,
            );

            // With optional adapter config:
            let mut adapter = None;
            if has_adapter {
                adapter.replace(ModelConfig::Adapter::try_new(
                    paths, device, silent, is_xlora,
                )?);
            }

            ModelConfig::ModelParams::new(quant, adapter)
        };

        // Config into model:
        let model = match self.kind {
            ModelKind::GgufQuantized { .. } => match arch {
                GGUFArchitecture::Llama => Model::Llama(QLlama::try_from(model_config)?),
                GGUFArchitecture::Phi2 => Model::Phi2(QPhi::try_from(model_config)?),
                GGUFArchitecture::Phi3 => Model::Phi3(QPhi3::try_from(model_config)?),
                GGUFArchitecture::Starcoder2 => {
                    Model::Starcoder2(QStarcoder2::try_from(model_config)?)
                }
                GGUFArchitecture::Qwen2 => Model::Qwen2(QQwen2::try_from(model_config)?),
                a => bail!("Unsupported architecture `{a:?}` for GGUF"),
            },
            ModelKind::GgufAdapter { adapter, .. } => match arch {
                GGUFArchitecture::Llama => Model::XLoraLlama(XLoraQLlama::try_from(model_config)?),
                GGUFArchitecture::Phi3 => Model::XLoraPhi3(XLoraQPhi3::try_from(model_config)?),
                a => bail!(
                    "Unsupported architecture `{a:?}` for GGUF {kind}",
                    kind = adapter.pretty_name()
                ),
            },
            _ => unreachable!(),
        };

        let (cache_config, cache_engine) = if let Some(paged_attn_config) = paged_attn_config {
            let model_config: &dyn ModelConfigLike = &model_config_metadata;
            let cache_config = calculate_cache_config(
                paged_attn_config.mem_gpu,
                paged_attn_config.mem_cpu,
                paged_attn_config.block_size,
                internal_dtype,
                model_config,
                device,
                &layer_devices,
                silent,
            )?;
            let cache_engine = CacheEngine::new(
                model_config,
                &cache_config,
                internal_dtype,
                device,
                layer_devices,
            )?;
            (Some(cache_config), Some(cache_engine))
        } else {
            (None, None)
        };

        let gen_conf: Option<GenerationConfig> = paths.get_gen_conf_filename().map(|f| {
            serde_json::from_str(&fs::read_to_string(f).unwrap())
                .expect("bos_token_id/eos_token_id missing in generation_config.json")
        });
        let mut chat_template = get_chat_template(
            paths,
            &self.jinja_explicit,
            &paths
                .get_chat_template_explicit()
                .as_ref()
                .map(|x| x.to_string_lossy().to_string())
                .clone(),
            &self.chat_template,
            gguf_chat_template,
        );

        let max_seq_len = match model {
            Model::Llama(ref l) => l.max_seq_len,
            Model::Phi2(ref p) => p.max_seq_len,
            Model::XLoraLlama(ref xl) => xl.max_seq_len,
            Model::Phi3(ref p) => p.max_seq_len,
            Model::XLoraPhi3(ref p) => p.max_seq_len,
            Model::Starcoder2(ref p) => p.max_seq_len,
            Model::Qwen2(ref p) => p.max_seq_len,
        };
        let tok_env = build_tok_env(tokenizer.clone());
        let num_hidden_layers = match model {
            Model::Llama(ref model) => model.cache.normal().0.len(),
            Model::Phi2(ref model) => model.cache.normal().0.len(),
            Model::XLoraLlama(ref model) => model.cache.full().lock().len(),
            Model::Phi3(ref model) => model.cache.normal().0.len(),
            Model::XLoraPhi3(ref model) => model.cache.full().lock().len(),
            Model::Starcoder2(ref model) => model.cache.normal().0.len(),
            Model::Qwen2(ref model) => model.cache.normal().0.len(),
        };

        if chat_template.bos_token.is_none() && bos.is_some() {
            chat_template.bos_token = Some(BeginEndUnkPadTok(Either::Left(bos.unwrap())));
        }
        if chat_template.eos_token.is_none() && eos.is_some() {
            chat_template.eos_token = Some(BeginEndUnkPadTok(Either::Left(eos.unwrap())));
        }
        if chat_template.unk_token.is_none() && unk.is_some() {
            chat_template.unk_token = Some(BeginEndUnkPadTok(Either::Left(unk.unwrap())));
        }

        let eos = calculate_eos_tokens(&chat_template, gen_conf, &tokenizer);
        Ok(Arc::new(Mutex::new(GGUFPipeline {
            model,
            tokenizer: tokenizer.into(),
            no_kv_cache: self.no_kv_cache,
            chat_template: Arc::new(chat_template),
            model_id: self
                .model_id
                .clone()
                .unwrap_or(self.quantized_model_id.clone()),
            non_granular_state: self.tgt_non_granular_index.map(|tgt_non_granular_index| {
                NonGranularState {
                    non_granular_index: Arc::new(Mutex::new(0)),
                    tgt_non_granular_index,
                }
            }),
            metadata: Arc::new(GeneralMetadata {
                max_seq_len,
                tok_env: Some(tok_env),
                no_kv_cache: self.no_kv_cache,
                no_prefix_cache: false,
                num_hidden_layers,
                eos_tok: eos,
                kind: self.kind.clone(),
                is_xlora,
                activation_dtype: internal_dtype,
                sliding_window: None,
                cache_config,
                cache_engine,
                prompt_chunksize: Some(NonZero::new(prompt_chunksize).unwrap()),
                model_metadata: Some(Arc::new(model_config_metadata)),
            }),
            mapper: pipeline_mapper,
        })))
    }

    fn get_id(&self) -> String {
        self.xlora_model_id
            .as_deref()
            .unwrap_or(self.model_id.as_ref().unwrap_or(&self.quantized_model_id))
            .to_string()
    }

    fn get_kind(&self) -> ModelKind {
        self.kind.clone()
    }
}

impl PreProcessingMixin for GGUFPipeline {
    fn get_chat_template(&self) -> Option<Arc<ChatTemplate>> {
        Some(self.chat_template.clone())
    }
    fn get_input_processor_config(&self) -> Option<Arc<dyn Any>> {
        None
    }
}

impl IsqPipelineMixin for GGUFPipeline {
    fn re_isq_model(&mut self, _dtype: IsqType) -> Result<()> {
        anyhow::bail!(
595            "You are trying to in-situ requantize a GGML model. This will not do anything."
        )
    }
}

impl CacheManagerMixin for GGUFPipeline {
    fn clone_in_cache(&self, seqs: &mut [&mut Sequence]) {
        if matches!(self.cache(), EitherCache::Full(_)) {
            FullCacheManager.clone_in_cache(self, seqs, false)
        } else {
            NormalCacheManager.clone_in_cache(self, seqs, false)
        }
    }
    fn clone_out_cache(&self, seqs: &mut [&mut Sequence]) {
        if matches!(self.cache(), EitherCache::Full(_)) {
            FullCacheManager.clone_out_cache(self, seqs, false)
        } else {
            NormalCacheManager.clone_out_cache(self, seqs, false)
        }
    }
    fn set_none_cache(
        &self,
        seqs: &mut [&mut Sequence],
        reset_non_granular: bool,
        modify_draft_cache: bool,
        load_preallocated_cache: bool,
    ) {
        if matches!(self.cache(), EitherCache::Full(_)) {
            FullCacheManager.set_none_cache(self, seqs, modify_draft_cache, false);
        } else {
            NormalCacheManager.set_none_cache(
                self,
                seqs,
                modify_draft_cache,
                load_preallocated_cache,
            );
        }
        if reset_non_granular {
            self.reset_non_granular_state()
        }
    }
    fn cache(&self) -> &EitherCache {
        match self.model {
            Model::Llama(ref model) => &model.cache,
            Model::Phi2(ref model) => &model.cache,
            Model::XLoraLlama(ref model) => &model.cache,
            Model::Phi3(ref model) => &model.cache,
            Model::XLoraPhi3(ref model) => &model.cache,
            Model::Starcoder2(ref model) => &model.cache,
            Model::Qwen2(ref model) => &model.cache,
        }
    }
}

impl MetadataMixin for GGUFPipeline {
    fn device(&self) -> Device {
        match self.model {
            Model::Llama(ref model) => model.device.clone(),
            Model::Phi2(ref model) => model.device.clone(),
            Model::XLoraLlama(ref model) => model.device.clone(),
            Model::Phi3(ref model) => model.device.clone(),
            Model::XLoraPhi3(ref model) => model.device.clone(),
            Model::Starcoder2(ref model) => model.device.clone(),
            Model::Qwen2(ref model) => model.device.clone(),
        }
    }
    fn tokenizer(&self) -> Option<Arc<Tokenizer>> {
        Some(self.tokenizer.clone())
    }
    fn name(&self) -> String {
        self.model_id.clone()
    }
    fn reset_non_granular_state(&self) {
        if let Some(s) = self.non_granular_state.as_ref() {
            *self.cache().full().get_scalings_cache() = None;
            *get_mut_arcmutex!(s.non_granular_index) = 0;
        }
    }
    fn get_metadata(&self) -> Arc<GeneralMetadata> {
        self.metadata.clone()
    }
    fn device_mapper(&self) -> Option<&dyn DeviceMapper> {
        Some(&*self.mapper)
    }
}

#[async_trait::async_trait]
impl Pipeline for GGUFPipeline {
    fn forward_inputs(
        &mut self,
        inputs: Box<dyn Any>,
        return_raw_logits: bool,
    ) -> Result<ForwardInputsResult, candle_core::Error> {
        let ModelInputs {
            input_ids,
            input_ids_full,
            seqlen_offsets,
            seqlen_offsets_full,
            context_lens,
            position_ids: _, // NOTE(EricLBuehler): ignore, it is for phi3
            paged_attn_meta,
            flash_meta,
            flash_meta_full,
        } = *inputs.downcast().expect("Downcast failed.");
        let metadata = self.get_metadata();
        let paged_attn_meta = match (&metadata.cache_engine, &paged_attn_meta) {
            (Some(engine), Some(meta)) => Some((engine.get_kv_cache().clone(), meta)),
            (Some(_), None) => {
                // This can happen if Rust-side user code is wrong
                candle_core::bail!("Forward step expected a PagedAttention input metadata. This was not provided, please ensure that the scheduler config is correctly configured for PagedAttention.")
            }
            (None, Some(_)) => {
                // This should never happen but we handle it anyway
                candle_core::bail!("Forward step got a PagedAttention input metadata but there is no cache engine. Please raise an issue.")
            }
            (None, None) => None,
        };
        let logits = match self.model {
            Model::Llama(ref model) => {
                model.forward(&input_ids, &seqlen_offsets, context_lens, paged_attn_meta)?
            }
            Model::Phi2(ref model) => {
                model.forward(&input_ids, &seqlen_offsets, context_lens, paged_attn_meta)?
            }
            Model::XLoraLlama(ref model) => model.forward(
                &input_ids,
                input_ids_full.as_ref().unwrap_or(&input_ids),
                &seqlen_offsets,
                seqlen_offsets_full.as_ref().unwrap_or(&seqlen_offsets),
                self.no_kv_cache,
                &self.non_granular_state,
                context_lens,
                &flash_meta,
                flash_meta_full.as_ref().unwrap_or(&flash_meta),
            )?,
            Model::Phi3(ref model) => {
                model.forward(&input_ids, &seqlen_offsets, paged_attn_meta)?
            }
            Model::XLoraPhi3(ref model) => model.forward(
                &input_ids,
                input_ids_full.as_ref().unwrap_or(&input_ids),
                &seqlen_offsets,
                seqlen_offsets_full.as_ref().unwrap_or(&seqlen_offsets),
                self.no_kv_cache,
                &self.non_granular_state,
                context_lens,
                &flash_meta,
                flash_meta_full.as_ref().unwrap_or(&flash_meta),
            )?,
            Model::Starcoder2(ref model) => {
                model.forward(&input_ids, &seqlen_offsets, paged_attn_meta)?
            }
            Model::Qwen2(ref model) => {
                model.forward(&input_ids, &seqlen_offsets, context_lens, paged_attn_meta)?
            }
        };
        if return_raw_logits {
            Ok(ForwardInputsResult::RawLogits { logits })
        } else {
            Ok(ForwardInputsResult::CausalGeneration { logits })
        }
    }
    async fn sample_causal_gen(
        &self,
        seqs: &mut [&mut Sequence],
        logits: Vec<Tensor>,
        prefix_cacher: &mut PrefixCacheManagerV2,
        disable_eos_stop: bool,
        rng: Arc<std::sync::Mutex<Isaac64Rng>>,
    ) -> Result<(), candle_core::Error> {
        sample_and_add_toks(self, seqs, logits, prefix_cacher, disable_eos_stop, rng).await
    }
    fn category(&self) -> ModelCategory {
        ModelCategory::Text
    }
}

// TODO
impl AnyMoePipelineMixin for GGUFPipeline {}