mistralrs_core/pipeline/ggml.rs

use super::llg::build_llg_factory;
use super::{
    get_model_paths, get_xlora_paths, text_models_inputs_processor::ModelInputs, AdapterKind,
    CacheManager, GeneralMetadata, Loader, ModelKind, ModelPaths, QuantizationKind, TokenSource,
};
use super::{
    AnyMoePipelineMixin, CacheManagerMixin, EitherCache, ForwardInputsResult, IsqPipelineMixin,
    MetadataMixin, ModelCategory, PreProcessingMixin,
};
use crate::attention::ATTENTION_CHUNK_SIZE;
use crate::device_map::DeviceMapper;
use crate::kv_cache::FullCacheManager;
use crate::lora::Ordering;
use crate::pipeline::chat_template::{calculate_eos_tokens, GenerationConfig};
use crate::pipeline::sampling::sample_and_add_toks;
use crate::pipeline::{get_chat_template, Modalities, SupportedModality};
use crate::pipeline::{ChatTemplate, LocalModelPaths};
use crate::prefix_cacher::PrefixCacheManagerV2;
use crate::sequence::Sequence;
use crate::utils::debug::DeviceRepr;
use crate::utils::model_config as ModelConfig;
use crate::utils::tokenizer::get_tokenizer;
use crate::xlora_models::NonGranularState;
use crate::{
    get_mut_arcmutex, get_paths, DeviceMapSetting, PagedAttentionConfig, Pipeline, Topology,
    TryIntoDType, DEBUG,
};
use crate::{
    models::quantized_llama::ModelWeights as QLlama, utils::tokens::get_token,
    xlora_models::XLoraQLlama,
};
use anyhow::Result;
use candle_core::quantized::ggml_file;
use candle_core::{Device, Tensor};
use hf_hub::{api::sync::ApiBuilder, Repo, RepoType};
use mistralrs_quant::IsqType;
use rand_isaac::Isaac64Rng;
use std::any::Any;
use std::fs;
use std::path::PathBuf;
use std::str::FromStr;
use std::sync::Arc;
use tokenizers::Tokenizer;
use tokio::sync::Mutex;
use tracing::{info, warn};

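/// The model variants a GGML pipeline can drive: a plain quantized Llama, or a
/// quantized Llama with X-LoRA/LoRA adapters.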
enum Model {
    Llama(QLlama),
    XLoraLlama(Box<XLoraQLlama>),
}

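/// A pipeline for a GGML-quantized Llama model, optionally adapted with LoRA or X-LoRA.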
pub struct GGMLPipeline {
    model: Model,
    tokenizer: Arc<Tokenizer>,
    no_kv_cache: bool,
    chat_template: Arc<ChatTemplate>,
    model_id: String,
    non_granular_state: Option<NonGranularState>,
    metadata: Arc<GeneralMetadata>,
}

/// A loader for a GGML model.
pub struct GGMLLoader {
    model_id: String,
    config: GGMLSpecificConfig,
    quantized_model_id: Option<String>,
    quantized_filename: Option<String>,
    xlora_model_id: Option<String>,
    xlora_order: Option<Ordering>,
    no_kv_cache: bool,
    chat_template: Option<String>,
    tokenizer_json: Option<String>,
    kind: ModelKind,
    tgt_non_granular_index: Option<usize>,
    jinja_explicit: Option<String>,
    lora_adapter_ids: Option<Vec<String>>,
}

#[derive(Clone, Default)]
/// Config for a GGML loader.
pub struct GGMLSpecificConfig {
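    /// Grouped-query attention (GQA) factor. Legacy GGML files do not record this,
    /// so it must be supplied by the caller (for example, 8 for Llama 2 70B).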
    pub gqa: usize,
    pub topology: Option<Topology>,
}

#[derive(Default)]
/// A builder for a GGML loader.
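///
/// Illustrative sketch only: the model IDs and GGML filename below are placeholders,
/// and the example assumes these types are re-exported from the crate root.
///
/// ```no_run
/// use mistralrs_core::{GGMLLoaderBuilder, GGMLSpecificConfig};
///
/// let loader = GGMLLoaderBuilder::new(
///     GGMLSpecificConfig { gqa: 1, ..Default::default() },
///     None,                                               // chat template
///     None,                                               // tokenizer.json override
///     Some("meta-llama/Llama-2-7b-chat-hf".to_string()),  // tokenizer/config source
///     "TheBloke/Llama-2-7B-Chat-GGML".to_string(),        // quantized model ID
///     "llama-2-7b-chat.ggmlv3.q4_K_M.bin".to_string(),    // quantized filename
///     false,                                              // no_kv_cache
///     None,                                               // jinja_explicit
/// )
/// .build();
/// ```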
pub struct GGMLLoaderBuilder {
    model_id: Option<String>,
    config: GGMLSpecificConfig,
    quantized_model_id: String,
    quantized_filename: String,
    xlora_model_id: Option<String>,
    kind: ModelKind,
    xlora_order: Option<Ordering>,
    no_kv_cache: bool,
    chat_template: Option<String>,
    tokenizer_json: Option<String>,
    tgt_non_granular_index: Option<usize>,
    jinja_explicit: Option<String>,
}

impl GGMLLoaderBuilder {
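    /// Create a builder for a GGML model. `quantized_model_id` and `quantized_filename`
    /// locate the GGML weights; `model_id` supplies the tokenizer, config, and chat
    /// template, falling back to the adapter ordering's base model ID when `None`.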
    #[allow(clippy::too_many_arguments)]
    pub fn new(
        config: GGMLSpecificConfig,
        chat_template: Option<String>,
        tokenizer_json: Option<String>,
        model_id: Option<String>,
        quantized_model_id: String,
        quantized_filename: String,
        no_kv_cache: bool,
        jinja_explicit: Option<String>,
    ) -> Self {
        let kind = ModelKind::GgufQuantized {
            quant: QuantizationKind::Ggml,
        };

        Self {
            config,
            chat_template,
            tokenizer_json,
            model_id,
            kind,
            quantized_filename,
            quantized_model_id,
            no_kv_cache,
            jinja_explicit,
            ..Default::default()
        }
    }

    fn with_adapter(
        mut self,
        xlora_model_id: String,
        xlora_order: Ordering,
        no_kv_cache: bool,
        tgt_non_granular_index: Option<usize>,
    ) -> Self {
        self.xlora_model_id = Some(xlora_model_id);
        self.xlora_order = Some(xlora_order);
        self.no_kv_cache = no_kv_cache;
        self.tgt_non_granular_index = tgt_non_granular_index;
        self.model_id = if let Some(id) = self.model_id {
            Some(id)
        } else {
            info!(
                "Using adapter base model ID: `{}`",
                self.xlora_order.as_ref().unwrap().base_model_id
            );
            Some(self.xlora_order.as_ref().unwrap().base_model_id.clone())
        };
        self
    }

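    /// Attach an X-LoRA adapter, switching the model kind to X-LoRA-adapted GGML.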
    pub fn with_xlora(
        mut self,
        xlora_model_id: String,
        xlora_order: Ordering,
        no_kv_cache: bool,
        tgt_non_granular_index: Option<usize>,
    ) -> Self {
        self.kind = (AdapterKind::XLora, QuantizationKind::Ggml).into();

        self.with_adapter(
            xlora_model_id,
            xlora_order,
            no_kv_cache,
            tgt_non_granular_index,
        )
    }

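    /// Attach LoRA adapters, switching the model kind to LoRA-adapted GGML.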
    pub fn with_lora(mut self, lora_model_id: String, lora_order: Ordering) -> Self {
        self.kind = (AdapterKind::Lora, QuantizationKind::Ggml).into();

        self.with_adapter(lora_model_id, lora_order, false, None)
    }

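    /// Finish the builder, producing a boxed [`Loader`].
    ///
    /// Panics if no `model_id` was set and no adapter ordering supplied a base model ID.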
    pub fn build(self) -> Box<dyn Loader> {
        Box::new(GGMLLoader {
            model_id: self.model_id.unwrap(),
            config: self.config,
            xlora_model_id: self.xlora_model_id,
            kind: self.kind,
            xlora_order: self.xlora_order,
            no_kv_cache: self.no_kv_cache,
            chat_template: self.chat_template,
            tokenizer_json: self.tokenizer_json,
            tgt_non_granular_index: self.tgt_non_granular_index,
            quantized_filename: Some(self.quantized_filename),
            quantized_model_id: Some(self.quantized_model_id),
            jinja_explicit: self.jinja_explicit,
            lora_adapter_ids: None,
        })
    }
}

impl GGMLLoader {
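    /// Create a GGML loader directly. If `model_id` is `None`, the adapter ordering's
    /// base model ID is used to locate the tokenizer, config, and chat template.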
    #[allow(clippy::too_many_arguments)]
    pub fn new(
        model_id: Option<String>,
        config: GGMLSpecificConfig,
        quantized_model_id: Option<String>,
        quantized_filename: Option<String>,
        xlora_model_id: Option<String>,
        kind: ModelKind,
        xlora_order: Option<Ordering>,
        no_kv_cache: bool,
        chat_template: Option<String>,
        tokenizer_json: Option<String>,
        tgt_non_granular_index: Option<usize>,
        jinja_explicit: Option<String>,
    ) -> Self {
        let model_id = if let Some(id) = model_id {
            id
        } else {
            info!(
                "Using adapter base model ID: `{}`",
                xlora_order.as_ref().unwrap().base_model_id
            );
            xlora_order.as_ref().unwrap().base_model_id.clone()
        };
        Self {
            model_id,
            config,
            quantized_model_id,
            quantized_filename,
            xlora_model_id,
            xlora_order,
            no_kv_cache,
            chat_template,
            tokenizer_json,
            kind,
            tgt_non_granular_index,
            jinja_explicit,
            lora_adapter_ids: None,
        }
    }
}

impl Loader for GGMLLoader {
    #[allow(clippy::type_complexity, clippy::too_many_arguments)]
    fn load_model_from_path(
        &self,
        paths: &Box<dyn ModelPaths>,
        dtype: &dyn TryIntoDType,
        device: &Device,
        silent: bool,
        mapper: DeviceMapSetting,
        in_situ_quant: Option<IsqType>,
        mut paged_attn_config: Option<PagedAttentionConfig>,
    ) -> Result<Arc<Mutex<dyn Pipeline + Send + Sync>>> {
        if in_situ_quant.is_some() {
            anyhow::bail!(
                "You are trying to in-situ quantize a GGML model. This will not do anything."
            );
        }

        if matches!(mapper, DeviceMapSetting::Map(_)) {
            anyhow::bail!("Device mapping is not supported for GGML models.")
        }

        if paged_attn_config.is_some() {
            warn!("PagedAttention is not supported for GGML models, disabling it.");

            paged_attn_config = None;
        }

        info!("Prompt chunk size is {ATTENTION_CHUNK_SIZE}.");

        info!(
            "Loading model `{}` on {}.",
            self.get_id(),
            device.device_pretty_repr()
        );

        #[cfg(feature = "cuda")]
        if let Device::Cuda(dev) = &device {
            unsafe { dev.disable_event_tracking() };
        }

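        // Read the legacy GGML container (hyperparameters, vocab, and tensors) from disk.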
        let mut file = std::fs::File::open(paths.get_weight_filenames().first().unwrap())?;
        let model = ggml_file::Content::read(&mut file, device)
            .map_err(|e| e.with_path(paths.get_weight_filenames().first().unwrap()))?;

        info!("Model config: {:?}", model.hparams);

        if DEBUG.load(std::sync::atomic::Ordering::Relaxed) {
            let mut tensors = Vec::new();
            for (name, t) in &model.tensors {
                tensors.push(format!(
                    "name = `{name}`, shape = {:?}, dtype = {:?}",
                    t.shape().clone(),
                    t.dtype(),
                ));
            }
            fs::write(
                "mistralrs_ggml_tensors.txt",
                serde_json::to_string_pretty(&tensors).expect("Serialization failed."),
            )?;

            info!("Debug is enabled, wrote the names and information about each tensor to `mistralrs_ggml_tensors.txt`.");
        }

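        // LoRA and X-LoRA adapters are both handled by the X-LoRA Llama variant built below.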
        let has_adapter = self.kind.is_adapted();
        let is_xlora = self.kind.is_adapted_and(|a| a.is_x_lora());
        let internal_dtype = dtype.try_into_dtype(&[device]).unwrap();

        let model_config = {
            // Base config (quantization only):
            let quant = ModelConfig::ParamsGGML((model, self.config.gqa, internal_dtype).into());

            // With optional adapter config:
            let mut adapter = None;
            if has_adapter {
                adapter.replace(ModelConfig::Adapter::try_new(
                    paths, device, silent, is_xlora,
                )?);
            }

            ModelConfig::ModelParams::new(quant, adapter)
        };

        // Config into model:
        // NOTE: Unlike GGUF, there is no architecture to infer; a Llama model is assumed.
        let model = match self.kind {
            ModelKind::GgufQuantized { .. } => Model::Llama(QLlama::try_from(model_config)?),
            ModelKind::GgufAdapter { .. } => {
                Model::XLoraLlama(Box::new(XLoraQLlama::try_from(model_config)?))
            }
            _ => unreachable!(),
        };

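        // The tokenizer, generation config, and chat template come from the (base) model ID
        // paths rather than from the GGML file itself.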
        let tokenizer = get_tokenizer(paths.get_tokenizer_filename(), None)?;
        let gen_conf: Option<GenerationConfig> = paths
            .get_gen_conf_filename()
            .map(|f| serde_json::from_str(&fs::read_to_string(f).unwrap()).unwrap());
        let chat_template_explicit = paths
            .get_chat_template_explicit()
            .as_ref()
            .map(|x| x.to_string_lossy().to_string());
        let chat_template = get_chat_template(
            paths,
            self.jinja_explicit.as_ref(),
            chat_template_explicit.as_ref(),
            self.chat_template.as_ref(),
            None,
        );

        let max_seq_len = match model {
            Model::Llama(ref l) => l.max_seq_len,
            Model::XLoraLlama(ref xl) => xl.max_seq_len,
        };
        let llg_factory = build_llg_factory(tokenizer.clone())?;
        let num_hidden_layers = match model {
            Model::Llama(ref model) => model.cache.normal().0.len(),
            Model::XLoraLlama(ref model) => model.cache.full().lock().len(),
        };
        let eos = calculate_eos_tokens(&chat_template, gen_conf, &tokenizer);
        Ok(Arc::new(Mutex::new(GGMLPipeline {
            model,
            tokenizer: tokenizer.into(),
            no_kv_cache: self.no_kv_cache,
            chat_template: Arc::new(chat_template),
            model_id: self.model_id.clone(),
            non_granular_state: self.tgt_non_granular_index.map(|tgt_non_granular_index| {
                NonGranularState {
                    non_granular_index: Arc::new(Mutex::new(0)),
                    tgt_non_granular_index,
                }
            }),
            metadata: Arc::new(GeneralMetadata {
                max_seq_len,
                llg_factory: Some(llg_factory),
                no_kv_cache: self.no_kv_cache,
                no_prefix_cache: false,
                num_hidden_layers,
                eos_tok: eos,
                kind: self.kind.clone(),
                is_xlora,
                activation_dtype: internal_dtype,
                sliding_window: None,
                cache_config: None,
                cache_engine: None,
                model_metadata: None,
                modalities: Modalities {
                    input: vec![SupportedModality::Text],
                    output: vec![SupportedModality::Text],
                },
            }),
        })))
    }

    #[allow(clippy::type_complexity, clippy::too_many_arguments)]
    fn load_model_from_hf(
        &self,
        revision: Option<String>,
        token_source: TokenSource,
        dtype: &dyn TryIntoDType,
        device: &Device,
        silent: bool,
        mapper: DeviceMapSetting,
        in_situ_quant: Option<IsqType>,
        paged_attn_config: Option<PagedAttentionConfig>,
    ) -> Result<Arc<Mutex<dyn Pipeline + Send + Sync>>> {
        let paths: anyhow::Result<Box<dyn ModelPaths>> = get_paths!(
            LocalModelPaths,
            &token_source,
            revision,
            self,
            self.quantized_model_id,
            Some(vec![self.quantized_filename.as_ref().unwrap().clone()]),
            silent,
            false // Never loading UQFF
        );
        self.load_model_from_path(
            &paths?,
            dtype,
            device,
            silent,
            mapper,
            in_situ_quant,
            paged_attn_config,
        )
    }

    fn get_id(&self) -> String {
        self.xlora_model_id
            .as_deref()
            .unwrap_or(&self.model_id)
            .to_string()
    }

    fn get_kind(&self) -> ModelKind {
        self.kind.clone()
    }
}

impl PreProcessingMixin for GGMLPipeline {
    fn get_chat_template(&self) -> Option<Arc<ChatTemplate>> {
        Some(self.chat_template.clone())
    }
    fn get_input_processor_config(&self) -> Option<Arc<dyn Any>> {
        None
    }
}

impl IsqPipelineMixin for GGMLPipeline {
    fn re_isq_model(&mut self, _dtype: IsqType) -> Result<()> {
        anyhow::bail!(
            "You are trying to in-situ requantize a GGML model. This will not do anything."
        )
    }
}

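// GGML pipelines always use the full (non-paged) KV cache manager.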
impl CacheManagerMixin for GGMLPipeline {
    fn clone_in_cache(&self, seqs: &mut [&mut Sequence]) {
        FullCacheManager.clone_in_cache(self, seqs, false)
    }
    fn clone_out_cache(&self, seqs: &mut [&mut Sequence]) {
        FullCacheManager.clone_out_cache(self, seqs, false)
    }
    fn set_none_cache(
        &self,
        seqs: &mut [&mut Sequence],
        reset_non_granular: bool,
        modify_draft_cache: bool,
        load_preallocated_cache: bool,
    ) {
        FullCacheManager.set_none_cache(self, seqs, modify_draft_cache, load_preallocated_cache);
        if reset_non_granular {
            self.reset_non_granular_state()
        }
    }
    fn cache(&self) -> &EitherCache {
        match self.model {
            Model::Llama(ref model) => &model.cache,
            Model::XLoraLlama(ref model) => &model.cache,
        }
    }
}

impl MetadataMixin for GGMLPipeline {
    fn device(&self) -> Device {
        match self.model {
            Model::Llama(ref model) => model.device.clone(),
            Model::XLoraLlama(ref model) => model.device.clone(),
        }
    }
    fn tokenizer(&self) -> Option<Arc<Tokenizer>> {
        Some(self.tokenizer.clone())
    }
    fn name(&self) -> String {
        self.model_id.clone()
    }
    fn reset_non_granular_state(&self) {
        if let Some(s) = self.non_granular_state.as_ref() {
            *self.cache().full().get_scalings_cache() = None;
            *get_mut_arcmutex!(s.non_granular_index) = 0;
        }
    }
    fn get_metadata(&self) -> Arc<GeneralMetadata> {
        self.metadata.clone()
    }
    fn device_mapper(&self) -> Option<&dyn DeviceMapper> {
        None
    }
}

#[async_trait::async_trait]
impl Pipeline for GGMLPipeline {
    fn forward_inputs(
        &mut self,
        inputs: Box<dyn Any>,
        return_raw_logits: bool,
    ) -> Result<ForwardInputsResult, candle_core::Error> {
        let ModelInputs {
            input_ids,
            input_ids_full,
            seqlen_offsets,
            seqlen_offsets_full,
            context_lens,
            position_ids: _,    // NOTE(EricLBuehler): ignore, it is for phi3
            paged_attn_meta: _, // NOTE(EricLBuehler): ignore it for ggml
            flash_meta,         // NOTE(EricLBuehler): ignore it for ggml dequant into f32
            flash_meta_full,    // NOTE(EricLBuehler): ignore it for ggml dequant into f32
        } = *inputs.downcast().expect("Downcast failed.");
        let logits = match self.model {
            Model::Llama(ref model) => {
                model.forward(&input_ids, &seqlen_offsets, context_lens, None)?
            }
            Model::XLoraLlama(ref model) => model.forward(
                &input_ids,
                input_ids_full.as_ref().unwrap_or(&input_ids),
                &seqlen_offsets,
                seqlen_offsets_full.as_ref().unwrap_or(&seqlen_offsets),
                self.no_kv_cache,
                &self.non_granular_state,
                context_lens,
                &flash_meta,
                flash_meta_full.as_ref().unwrap_or(&flash_meta),
            )?,
        };
        if return_raw_logits {
            Ok(ForwardInputsResult::RawLogits { logits })
        } else {
            Ok(ForwardInputsResult::CausalGeneration { logits })
        }
    }
    async fn sample_causal_gen(
        &self,
        seqs: &mut [&mut Sequence],
        logits: Vec<Tensor>,
        prefix_cacher: &mut PrefixCacheManagerV2,
        disable_eos_stop: bool,
        rng: Arc<std::sync::Mutex<Isaac64Rng>>,
    ) -> Result<(), candle_core::Error> {
        sample_and_add_toks(self, seqs, logits, prefix_cacher, disable_eos_stop, rng).await
    }
    fn category(&self) -> ModelCategory {
        ModelCategory::Text
    }
}

// TODO
impl AnyMoePipelineMixin for GGMLPipeline {}