mistralrs_core/pipeline/ggml.rs

use super::llg::build_llg_factory;
use super::{
    get_model_paths, get_xlora_paths, text_models_inputs_processor::ModelInputs, AdapterKind,
    CacheManager, GeneralMetadata, Loader, ModelKind, ModelPaths, QuantizationKind, TokenSource,
};
use super::{
    AnyMoePipelineMixin, CacheManagerMixin, EitherCache, ForwardInputsResult, IsqPipelineMixin,
    MetadataMixin, ModelCategory, PreProcessingMixin,
};
use crate::attention::ATTENTION_CHUNK_SIZE;
use crate::device_map::DeviceMapper;
use crate::kv_cache::FullCacheManager;
use crate::lora::Ordering;
use crate::pipeline::chat_template::{calculate_eos_tokens, GenerationConfig};
use crate::pipeline::sampling::sample_and_add_toks;
use crate::pipeline::{get_chat_template, Modalities, SupportedModality};
use crate::pipeline::{ChatTemplate, LocalModelPaths};
use crate::prefix_cacher::PrefixCacheManagerV2;
use crate::sequence::Sequence;
use crate::utils::debug::DeviceRepr;
use crate::utils::model_config as ModelConfig;
use crate::utils::progress::ProgressScopeGuard;
use crate::utils::tokenizer::get_tokenizer;
use crate::xlora_models::NonGranularState;
use crate::{
    get_mut_arcmutex, get_paths, DeviceMapSetting, PagedAttentionConfig, Pipeline, Topology,
    TryIntoDType, DEBUG,
};
use crate::{
    models::quantized_llama::ModelWeights as QLlama, utils::tokens::get_token,
    xlora_models::XLoraQLlama,
};
use anyhow::Result;
use candle_core::quantized::ggml_file;
use candle_core::{Device, Tensor};
use hf_hub::{api::sync::ApiBuilder, Repo, RepoType};
use mistralrs_quant::IsqType;
use rand_isaac::Isaac64Rng;
use std::any::Any;
use std::fs;
use std::path::PathBuf;
use std::str::FromStr;
use std::sync::Arc;
use tokenizers::Tokenizer;
use tokio::sync::Mutex;
use tracing::{info, warn};

enum Model {
    Llama(Box<QLlama>),
    XLoraLlama(Box<XLoraQLlama>),
}

pub struct GGMLPipeline {
    model: Model,
    tokenizer: Arc<Tokenizer>,
    no_kv_cache: bool,
    chat_template: Arc<ChatTemplate>,
    model_id: String,
    non_granular_state: Option<NonGranularState>,
    metadata: Arc<GeneralMetadata>,
}

/// A loader for a GGML model.
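///
/// This loads a quantized Llama model from a single GGML weight file, optionally with
/// X-LoRA or LoRA adapters applied on top. It is typically constructed through
/// [`GGMLLoaderBuilder`] rather than [`GGMLLoader::new`] directly.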
pub struct GGMLLoader {
    model_id: String,
    config: GGMLSpecificConfig,
    quantized_model_id: Option<String>,
    quantized_filename: Option<String>,
    xlora_model_id: Option<String>,
    xlora_order: Option<Ordering>,
    no_kv_cache: bool,
    chat_template: Option<String>,
    tokenizer_json: Option<String>,
    kind: ModelKind,
    tgt_non_granular_index: Option<usize>,
    jinja_explicit: Option<String>,
    lora_adapter_ids: Option<Vec<String>>,
}

#[derive(Clone, Default)]
/// Config for a GGML loader.
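///
/// `gqa` is the grouped-query attention value that is forwarded together with the GGML
/// content when the model weights are built; `topology` is optional. A minimal construction
/// sketch (the `mistralrs_core` re-export path is an assumption):
///
/// ```ignore
/// use mistralrs_core::GGMLSpecificConfig;
///
/// // GQA value of 1 for an illustrative non-GQA Llama; everything else stays at its default.
/// let config = GGMLSpecificConfig {
///     gqa: 1,
///     ..Default::default()
/// };
/// ```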
pub struct GGMLSpecificConfig {
    pub gqa: usize,
    pub topology: Option<Topology>,
}

#[derive(Default)]
/// A builder for a GGML loader.
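///
/// Usage sketch (illustrative only: the `mistralrs_core` re-export path and the model/file
/// names below are placeholders, not verified identifiers):
///
/// ```ignore
/// use mistralrs_core::{GGMLLoaderBuilder, GGMLSpecificConfig};
///
/// let loader = GGMLLoaderBuilder::new(
///     GGMLSpecificConfig::default(),
///     None,                                        // chat_template
///     None,                                        // tokenizer_json
///     Some("base/tokenizer-model-id".to_string()), // model_id (base model for the tokenizer)
///     "quantized/model-id".to_string(),            // quantized_model_id
///     "model.q4_0.bin".to_string(),                // quantized_filename
///     false,                                       // no_kv_cache
///     None,                                        // jinja_explicit
/// )
/// .build();
///
/// // The returned `Box<dyn Loader>` is then driven through `load_model_from_hf` or
/// // `load_model_from_path` to produce the `GGMLPipeline`.
/// ```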
pub struct GGMLLoaderBuilder {
    model_id: Option<String>,
    config: GGMLSpecificConfig,
    quantized_model_id: String,
    quantized_filename: String,
    xlora_model_id: Option<String>,
    kind: ModelKind,
    xlora_order: Option<Ordering>,
    no_kv_cache: bool,
    chat_template: Option<String>,
    tokenizer_json: Option<String>,
    tgt_non_granular_index: Option<usize>,
    jinja_explicit: Option<String>,
}

impl GGMLLoaderBuilder {
    #[allow(clippy::too_many_arguments)]
    pub fn new(
        config: GGMLSpecificConfig,
        chat_template: Option<String>,
        tokenizer_json: Option<String>,
        model_id: Option<String>,
        quantized_model_id: String,
        quantized_filename: String,
        no_kv_cache: bool,
        jinja_explicit: Option<String>,
    ) -> Self {
        let kind = ModelKind::GgufQuantized {
            quant: QuantizationKind::Ggml,
        };

        Self {
            config,
            chat_template,
            tokenizer_json,
            model_id,
            kind,
            quantized_filename,
            quantized_model_id,
            no_kv_cache,
            jinja_explicit,
            ..Default::default()
        }
    }

    fn with_adapter(
        mut self,
        xlora_model_id: String,
        xlora_order: Ordering,
        no_kv_cache: bool,
        tgt_non_granular_index: Option<usize>,
    ) -> Self {
        self.xlora_model_id = Some(xlora_model_id);
        self.xlora_order = Some(xlora_order);
        self.no_kv_cache = no_kv_cache;
        self.tgt_non_granular_index = tgt_non_granular_index;
        self.model_id = if let Some(id) = self.model_id {
            Some(id)
        } else {
            info!(
                "Using adapter base model ID: `{}`",
                self.xlora_order.as_ref().unwrap().base_model_id
            );
            Some(self.xlora_order.as_ref().unwrap().base_model_id.clone())
        };
        self
    }

    pub fn with_xlora(
        mut self,
        xlora_model_id: String,
        xlora_order: Ordering,
        no_kv_cache: bool,
        tgt_non_granular_index: Option<usize>,
    ) -> Self {
        self.kind = (AdapterKind::XLora, QuantizationKind::Ggml).into();

        self.with_adapter(
            xlora_model_id,
            xlora_order,
            no_kv_cache,
            tgt_non_granular_index,
        )
    }

    pub fn with_lora(mut self, lora_model_id: String, lora_order: Ordering) -> Self {
        self.kind = (AdapterKind::Lora, QuantizationKind::Ggml).into();

        self.with_adapter(lora_model_id, lora_order, false, None)
    }

    pub fn build(self) -> Box<dyn Loader> {
        Box::new(GGMLLoader {
            model_id: self.model_id.unwrap(),
            config: self.config,
            xlora_model_id: self.xlora_model_id,
            kind: self.kind,
            xlora_order: self.xlora_order,
            no_kv_cache: self.no_kv_cache,
            chat_template: self.chat_template,
            tokenizer_json: self.tokenizer_json,
            tgt_non_granular_index: self.tgt_non_granular_index,
            quantized_filename: Some(self.quantized_filename),
            quantized_model_id: Some(self.quantized_model_id),
            jinja_explicit: self.jinja_explicit,
            lora_adapter_ids: None,
        })
    }
}

impl GGMLLoader {
    #[allow(clippy::too_many_arguments)]
    pub fn new(
        model_id: Option<String>,
        config: GGMLSpecificConfig,
        quantized_model_id: Option<String>,
        quantized_filename: Option<String>,
        xlora_model_id: Option<String>,
        kind: ModelKind,
        xlora_order: Option<Ordering>,
        no_kv_cache: bool,
        chat_template: Option<String>,
        tokenizer_json: Option<String>,
        tgt_non_granular_index: Option<usize>,
        jinja_explicit: Option<String>,
    ) -> Self {
        let model_id = if let Some(id) = model_id {
            id
        } else {
            info!(
                "Using adapter base model ID: `{}`",
                xlora_order.as_ref().unwrap().base_model_id
            );
            xlora_order.as_ref().unwrap().base_model_id.clone()
        };
        Self {
            model_id,
            config,
            quantized_model_id,
            quantized_filename,
            xlora_model_id,
            xlora_order,
            no_kv_cache,
            chat_template,
            tokenizer_json,
            kind,
            tgt_non_granular_index,
            jinja_explicit,
            lora_adapter_ids: None,
        }
    }
}

impl Loader for GGMLLoader {
    #[allow(clippy::type_complexity, clippy::too_many_arguments)]
    fn load_model_from_path(
        &self,
        paths: &Box<dyn ModelPaths>,
        dtype: &dyn TryIntoDType,
        device: &Device,
        silent: bool,
        mapper: DeviceMapSetting,
        in_situ_quant: Option<IsqType>,
        mut paged_attn_config: Option<PagedAttentionConfig>,
    ) -> Result<Arc<Mutex<dyn Pipeline + Send + Sync>>> {
        let _progress_guard = ProgressScopeGuard::new(silent);
        if in_situ_quant.is_some() {
            anyhow::bail!("In-situ quantization is not supported for GGML models.");
        }

        if matches!(mapper, DeviceMapSetting::Map(_)) {
            anyhow::bail!("Device mapping is not supported for GGML models.")
        }

        if paged_attn_config.is_some() {
            warn!("PagedAttention is not supported for GGML models, disabling it.");

            paged_attn_config = None;
        }

        info!("Prompt chunk size is {ATTENTION_CHUNK_SIZE}.");

        info!(
            "Loading model `{}` on {}.",
            self.get_id(),
            device.device_pretty_repr()
        );

        #[cfg(feature = "cuda")]
        if let Device::Cuda(dev) = &device {
            unsafe { dev.disable_event_tracking() };
        }

        let mut file = std::fs::File::open(paths.get_weight_filenames().first().unwrap())?;
        let model = ggml_file::Content::read(&mut file, device)
            .map_err(|e| e.with_path(paths.get_weight_filenames().first().unwrap()))?;

        info!("Model config: {:?}", model.hparams);

        if DEBUG.load(std::sync::atomic::Ordering::Relaxed) {
            let mut tensors = Vec::new();
            for (name, t) in &model.tensors {
                tensors.push(format!(
                    "name = `{name}`, shape = {:?}, dtype = {:?}",
                    t.shape().clone(),
                    t.dtype(),
                ));
            }
            fs::write(
                "mistralrs_ggml_tensors.txt",
                serde_json::to_string_pretty(&tensors).expect("Serialization failed."),
            )?;

            info!("Debug is enabled, wrote the names and information about each tensor to `mistralrs_ggml_tensors.txt`.");
        }

        let has_adapter = self.kind.is_adapted();
        let is_xlora = self.kind.is_adapted_and(|a| a.is_x_lora());
        let internal_dtype = dtype.try_into_dtype(&[device]).unwrap();

        let model_config = {
            // Base config (quantization only):
            let quant = ModelConfig::ParamsGGML((model, self.config.gqa, internal_dtype).into());

            // With optional adapter config:
            let mut adapter = None;
            if has_adapter {
                adapter.replace(ModelConfig::Adapter::try_new(
                    paths, device, silent, is_xlora,
                )?);
            }

            ModelConfig::ModelParams::new(quant, adapter)
        };

        // Build the model from the config.
        // NOTE: Unlike GGUF, there is no architecture to infer; the Llama model is matched implicitly.
        let model = match self.kind {
            ModelKind::GgufQuantized { .. } => {
                Model::Llama(Box::new(QLlama::try_from(model_config)?))
            }
            ModelKind::GgufAdapter { .. } => {
                Model::XLoraLlama(Box::new(XLoraQLlama::try_from(model_config)?))
            }
            _ => unreachable!(),
        };

        let tokenizer = get_tokenizer(paths.get_tokenizer_filename(), None)?;
        let gen_conf: Option<GenerationConfig> = paths
            .get_gen_conf_filename()
            .map(|f| serde_json::from_str(&fs::read_to_string(f).unwrap()).unwrap());
        let chat_template_explicit = paths
            .get_chat_template_explicit()
            .as_ref()
            .map(|x| x.to_string_lossy().to_string());
        let chat_template = get_chat_template(
            paths,
            self.jinja_explicit.as_ref(),
            chat_template_explicit.as_ref(),
            self.chat_template.as_ref(),
            None,
        );

        let max_seq_len = match model {
            Model::Llama(ref l) => l.max_seq_len,
            Model::XLoraLlama(ref xl) => xl.max_seq_len,
        };
        let llg_factory = build_llg_factory(tokenizer.clone())?;
        let num_hidden_layers = match model {
            Model::Llama(ref model) => model.cache.normal().0.len(),
            Model::XLoraLlama(ref model) => model.cache.full().lock().len(),
        };
        let eos = calculate_eos_tokens(&chat_template, gen_conf, &tokenizer);
        Ok(Arc::new(Mutex::new(GGMLPipeline {
            model,
            tokenizer: tokenizer.into(),
            no_kv_cache: self.no_kv_cache,
            chat_template: Arc::new(chat_template),
            model_id: self.model_id.clone(),
            non_granular_state: self.tgt_non_granular_index.map(|tgt_non_granular_index| {
                NonGranularState {
                    non_granular_index: Arc::new(Mutex::new(0)),
                    tgt_non_granular_index,
                }
            }),
            metadata: Arc::new(GeneralMetadata {
                max_seq_len,
                llg_factory: Some(llg_factory),
                no_kv_cache: self.no_kv_cache,
                no_prefix_cache: false,
                num_hidden_layers,
                eos_tok: eos,
                kind: self.kind.clone(),
                is_xlora,
                activation_dtype: internal_dtype,
                sliding_window: None,
                cache_config: None,
                cache_engine: None,
                model_metadata: None,
                modalities: Modalities {
                    input: vec![SupportedModality::Text],
                    output: vec![SupportedModality::Text],
                },
            }),
        })))
    }

    #[allow(clippy::type_complexity, clippy::too_many_arguments)]
    fn load_model_from_hf(
        &self,
        revision: Option<String>,
        token_source: TokenSource,
        dtype: &dyn TryIntoDType,
        device: &Device,
        silent: bool,
        mapper: DeviceMapSetting,
        in_situ_quant: Option<IsqType>,
        paged_attn_config: Option<PagedAttentionConfig>,
    ) -> Result<Arc<Mutex<dyn Pipeline + Send + Sync>>> {
        let _progress_guard = ProgressScopeGuard::new(silent);
        let paths: anyhow::Result<Box<dyn ModelPaths>> = get_paths!(
            LocalModelPaths,
            &token_source,
            revision,
            self,
            self.quantized_model_id,
            Some(vec![self.quantized_filename.as_ref().unwrap().clone()]),
            silent,
            false // Never loading UQFF
        );
        self.load_model_from_path(
            &paths?,
            dtype,
            device,
            silent,
            mapper,
            in_situ_quant,
            paged_attn_config,
        )
    }

    fn get_id(&self) -> String {
        self.xlora_model_id
            .as_deref()
            .unwrap_or(&self.model_id)
            .to_string()
    }

    fn get_kind(&self) -> ModelKind {
        self.kind.clone()
    }
}

impl PreProcessingMixin for GGMLPipeline {
    fn get_chat_template(&self) -> Option<Arc<ChatTemplate>> {
        Some(self.chat_template.clone())
    }
    fn get_input_processor_config(&self) -> Option<Arc<dyn Any>> {
        None
    }
}

impl IsqPipelineMixin for GGMLPipeline {
    fn re_isq_model(&mut self, _dtype: IsqType) -> Result<()> {
        anyhow::bail!("In-situ requantization is not supported for GGML models.")
    }
}

impl CacheManagerMixin for GGMLPipeline {
    fn clone_in_cache(&self, seqs: &mut [&mut Sequence]) {
        FullCacheManager.clone_in_cache(self, seqs, false)
    }
    fn clone_out_cache(&self, seqs: &mut [&mut Sequence]) {
        FullCacheManager.clone_out_cache(self, seqs, false)
    }
    fn set_none_cache(
        &self,
        seqs: &mut [&mut Sequence],
        reset_non_granular: bool,
        modify_draft_cache: bool,
        load_preallocated_cache: bool,
    ) {
        FullCacheManager.set_none_cache(self, seqs, modify_draft_cache, load_preallocated_cache);
        if reset_non_granular {
            self.reset_non_granular_state()
        }
    }
    fn cache(&self) -> &EitherCache {
        match self.model {
            Model::Llama(ref model) => &model.cache,
            Model::XLoraLlama(ref model) => &model.cache,
        }
    }
}

impl MetadataMixin for GGMLPipeline {
    fn device(&self) -> Device {
        match self.model {
            Model::Llama(ref model) => model.device.clone(),
            Model::XLoraLlama(ref model) => model.device.clone(),
        }
    }
    fn tokenizer(&self) -> Option<Arc<Tokenizer>> {
        Some(self.tokenizer.clone())
    }
    fn name(&self) -> String {
        self.model_id.clone()
    }
    fn reset_non_granular_state(&self) {
        if let Some(s) = self.non_granular_state.as_ref() {
            *self.cache().full().get_scalings_cache() = None;
            *get_mut_arcmutex!(s.non_granular_index) = 0;
        }
    }
    fn get_metadata(&self) -> Arc<GeneralMetadata> {
        self.metadata.clone()
    }
    fn device_mapper(&self) -> Option<&dyn DeviceMapper> {
        None
    }
}

#[async_trait::async_trait]
impl Pipeline for GGMLPipeline {
    fn forward_inputs(
        &mut self,
        inputs: Box<dyn Any>,
        return_raw_logits: bool,
    ) -> Result<ForwardInputsResult, candle_core::Error> {
        let ModelInputs {
            input_ids,
            input_ids_full,
            seqlen_offsets,
            seqlen_offsets_full,
            context_lens,
            position_ids: _,    // NOTE(EricLBuehler): ignore, it is for phi3
            paged_attn_meta: _, // NOTE(EricLBuehler): ignore it for ggml
            flash_meta,         // NOTE(EricLBuehler): ignore it for ggml dequant into f32
            flash_meta_full,    // NOTE(EricLBuehler): ignore it for ggml dequant into f32
        } = *inputs.downcast().expect("Downcast failed.");
        let logits = match self.model {
            Model::Llama(ref model) => {
                model.forward(&input_ids, &seqlen_offsets, context_lens, None)?
            }
            Model::XLoraLlama(ref model) => model.forward(
                &input_ids,
                input_ids_full.as_ref().unwrap_or(&input_ids),
                &seqlen_offsets,
                seqlen_offsets_full.as_ref().unwrap_or(&seqlen_offsets),
                self.no_kv_cache,
                &self.non_granular_state,
                context_lens,
                &flash_meta,
                flash_meta_full.as_ref().unwrap_or(&flash_meta),
            )?,
        };
        if return_raw_logits {
            Ok(ForwardInputsResult::RawLogits { logits })
        } else {
            Ok(ForwardInputsResult::CausalGeneration { logits })
        }
    }
    async fn sample_causal_gen(
        &self,
        seqs: &mut [&mut Sequence],
        logits: Vec<Tensor>,
        prefix_cacher: &mut PrefixCacheManagerV2,
        disable_eos_stop: bool,
        rng: Arc<std::sync::Mutex<Isaac64Rng>>,
    ) -> Result<(), candle_core::Error> {
        sample_and_add_toks(self, seqs, logits, prefix_cacher, disable_eos_stop, rng).await
    }
    fn category(&self) -> ModelCategory {
        ModelCategory::Text
    }
}

// TODO
impl AnyMoePipelineMixin for GGMLPipeline {}