use super::llg::build_llg_factory;
use super::{
    get_model_paths, get_xlora_paths, text_models_inputs_processor::ModelInputs, AdapterKind,
    CacheManager, GeneralMetadata, Loader, ModelKind, ModelPaths, QuantizationKind, TokenSource,
};
use super::{
    AnyMoePipelineMixin, CacheManagerMixin, EitherCache, ForwardInputsResult, IsqPipelineMixin,
    MetadataMixin, ModelCategory, PreProcessingMixin,
};
use crate::attention::ATTENTION_CHUNK_SIZE;
use crate::device_map::DeviceMapper;
use crate::kv_cache::FullCacheManager;
use crate::lora::Ordering;
use crate::pipeline::chat_template::{calculate_eos_tokens, GenerationConfig};
use crate::pipeline::sampling::sample_and_add_toks;
use crate::pipeline::{get_chat_template, Modalities, SupportedModality};
use crate::pipeline::{ChatTemplate, LocalModelPaths};
use crate::prefix_cacher::PrefixCacheManagerV2;
use crate::sequence::Sequence;
use crate::utils::debug::DeviceRepr;
use crate::utils::model_config as ModelConfig;
use crate::utils::progress::ProgressScopeGuard;
use crate::utils::tokenizer::get_tokenizer;
use crate::xlora_models::NonGranularState;
use crate::{
    get_mut_arcmutex, get_paths, DeviceMapSetting, PagedAttentionConfig, Pipeline, Topology,
    TryIntoDType, DEBUG,
};
use crate::{
    models::quantized_llama::ModelWeights as QLlama, utils::tokens::get_token,
    xlora_models::XLoraQLlama,
};
use anyhow::Result;
use candle_core::quantized::ggml_file;
use candle_core::{Device, Tensor};
use hf_hub::{api::sync::ApiBuilder, Repo, RepoType};
use mistralrs_quant::IsqType;
use rand_isaac::Isaac64Rng;
use std::any::Any;
use std::fs;
use std::path::PathBuf;
use std::str::FromStr;
use std::sync::Arc;
use tokenizers::Tokenizer;
use tokio::sync::Mutex;
use tracing::{info, warn};

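/// The concrete quantized model backing a [`GGMLPipeline`]: either a plain
/// quantized Llama or its X-LoRA/LoRA-adapted variant.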
enum Model {
    Llama(Box<QLlama>),
    XLoraLlama(Box<XLoraQLlama>),
}

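/// A loaded GGML text-generation pipeline: the quantized model plus its
/// tokenizer, chat template, and general metadata.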
pub struct GGMLPipeline {
    model: Model,
    tokenizer: Arc<Tokenizer>,
    no_kv_cache: bool,
    chat_template: Arc<ChatTemplate>,
    model_id: String,
    non_granular_state: Option<NonGranularState>,
    metadata: Arc<GeneralMetadata>,
}

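/// A [`Loader`] for GGML models, optionally with X-LoRA or LoRA adapters.
/// Usually constructed through [`GGMLLoaderBuilder`].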
pub struct GGMLLoader {
    model_id: String,
    config: GGMLSpecificConfig,
    quantized_model_id: Option<String>,
    quantized_filename: Option<String>,
    xlora_model_id: Option<String>,
    xlora_order: Option<Ordering>,
    no_kv_cache: bool,
    chat_template: Option<String>,
    tokenizer_json: Option<String>,
    kind: ModelKind,
    tgt_non_granular_index: Option<usize>,
    jinja_explicit: Option<String>,
    lora_adapter_ids: Option<Vec<String>>,
}

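/// GGML-specific loading options: `gqa` is the grouped-query attention factor passed
/// through to the quantized Llama weights, and `topology` is an optional per-layer
/// topology override.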
#[derive(Clone, Default)]
pub struct GGMLSpecificConfig {
    pub gqa: usize,
    pub topology: Option<Topology>,
}

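/// Builder for [`GGMLLoader`].
///
/// A minimal usage sketch (the model IDs and filename below are hypothetical):
///
/// ```ignore
/// let loader = GGMLLoaderBuilder::new(
///     GGMLSpecificConfig::default(),
///     None,                                         // chat_template
///     None,                                         // tokenizer_json
///     Some("meta-llama/Llama-2-7b-chat-hf".into()), // tokenizer/base model ID
///     "TheBloke/Llama-2-7B-Chat-GGML".into(),       // quantized model ID
///     "llama-2-7b-chat.ggmlv3.q4_0.bin".into(),     // quantized filename
///     false,                                        // no_kv_cache
///     None,                                         // jinja_explicit
/// )
/// .build();
/// ```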
#[derive(Default)]
pub struct GGMLLoaderBuilder {
    model_id: Option<String>,
    config: GGMLSpecificConfig,
    quantized_model_id: String,
    quantized_filename: String,
    xlora_model_id: Option<String>,
    kind: ModelKind,
    xlora_order: Option<Ordering>,
    no_kv_cache: bool,
    chat_template: Option<String>,
    tokenizer_json: Option<String>,
    tgt_non_granular_index: Option<usize>,
    jinja_explicit: Option<String>,
}

impl GGMLLoaderBuilder {
    #[allow(clippy::too_many_arguments)]
    pub fn new(
        config: GGMLSpecificConfig,
        chat_template: Option<String>,
        tokenizer_json: Option<String>,
        model_id: Option<String>,
        quantized_model_id: String,
        quantized_filename: String,
        no_kv_cache: bool,
        jinja_explicit: Option<String>,
    ) -> Self {
        let kind = ModelKind::GgufQuantized {
            quant: QuantizationKind::Ggml,
        };

        Self {
            config,
            chat_template,
            tokenizer_json,
            model_id,
            kind,
            quantized_filename,
            quantized_model_id,
            no_kv_cache,
            jinja_explicit,
            ..Default::default()
        }
    }

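    /// Shared plumbing for `with_xlora`/`with_lora`: records the adapter model ID and
    /// ordering, and falls back to the ordering's base model ID when no `model_id` was set.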
    fn with_adapter(
        mut self,
        xlora_model_id: String,
        xlora_order: Ordering,
        no_kv_cache: bool,
        tgt_non_granular_index: Option<usize>,
    ) -> Self {
        self.xlora_model_id = Some(xlora_model_id);
        self.xlora_order = Some(xlora_order);
        self.no_kv_cache = no_kv_cache;
        self.tgt_non_granular_index = tgt_non_granular_index;
        self.model_id = if let Some(id) = self.model_id {
            Some(id)
        } else {
            info!(
                "Using adapter base model ID: `{}`",
                self.xlora_order.as_ref().unwrap().base_model_id
            );
            Some(self.xlora_order.as_ref().unwrap().base_model_id.clone())
        };
        self
    }

    pub fn with_xlora(
        mut self,
        xlora_model_id: String,
        xlora_order: Ordering,
        no_kv_cache: bool,
        tgt_non_granular_index: Option<usize>,
    ) -> Self {
        self.kind = (AdapterKind::XLora, QuantizationKind::Ggml).into();

        self.with_adapter(
            xlora_model_id,
            xlora_order,
            no_kv_cache,
            tgt_non_granular_index,
        )
    }

    pub fn with_lora(mut self, lora_model_id: String, lora_order: Ordering) -> Self {
        self.kind = (AdapterKind::Lora, QuantizationKind::Ggml).into();

        self.with_adapter(lora_model_id, lora_order, false, None)
    }

    pub fn build(self) -> Box<dyn Loader> {
        Box::new(GGMLLoader {
            model_id: self.model_id.unwrap(),
            config: self.config,
            xlora_model_id: self.xlora_model_id,
            kind: self.kind,
            xlora_order: self.xlora_order,
            no_kv_cache: self.no_kv_cache,
            chat_template: self.chat_template,
            tokenizer_json: self.tokenizer_json,
            tgt_non_granular_index: self.tgt_non_granular_index,
            quantized_filename: Some(self.quantized_filename),
            quantized_model_id: Some(self.quantized_model_id),
            jinja_explicit: self.jinja_explicit,
            lora_adapter_ids: None,
        })
    }
}

impl GGMLLoader {
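    /// Create a loader directly, without the builder. If `model_id` is `None`, the base
    /// model ID from the X-LoRA/LoRA `Ordering` is used (and must therefore be present).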
    #[allow(clippy::too_many_arguments)]
    pub fn new(
        model_id: Option<String>,
        config: GGMLSpecificConfig,
        quantized_model_id: Option<String>,
        quantized_filename: Option<String>,
        xlora_model_id: Option<String>,
        kind: ModelKind,
        xlora_order: Option<Ordering>,
        no_kv_cache: bool,
        chat_template: Option<String>,
        tokenizer_json: Option<String>,
        tgt_non_granular_index: Option<usize>,
        jinja_explicit: Option<String>,
    ) -> Self {
        let model_id = if let Some(id) = model_id {
            id
        } else {
            info!(
                "Using adapter base model ID: `{}`",
                xlora_order.as_ref().unwrap().base_model_id
            );
            xlora_order.as_ref().unwrap().base_model_id.clone()
        };
        Self {
            model_id,
            config,
            quantized_model_id,
            quantized_filename,
            xlora_model_id,
            xlora_order,
            no_kv_cache,
            chat_template,
            tokenizer_json,
            kind,
            tgt_non_granular_index,
            jinja_explicit,
            lora_adapter_ids: None,
        }
    }
}

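// Loading a GGML model: read the single quantized weight file, build the (optionally
// adapter-wrapped) quantized Llama, then assemble the tokenizer, chat template, and
// metadata into a `GGMLPipeline`.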
impl Loader for GGMLLoader {
    #[allow(clippy::type_complexity, clippy::too_many_arguments)]
    fn load_model_from_path(
        &self,
        paths: &Box<dyn ModelPaths>,
        dtype: &dyn TryIntoDType,
        device: &Device,
        silent: bool,
        mapper: DeviceMapSetting,
        in_situ_quant: Option<IsqType>,
        mut paged_attn_config: Option<PagedAttentionConfig>,
    ) -> Result<Arc<Mutex<dyn Pipeline + Send + Sync>>> {
        let _progress_guard = ProgressScopeGuard::new(silent);
        if in_situ_quant.is_some() {
            anyhow::bail!(
                "You are trying to in-situ quantize a GGML model. This will not do anything."
            );
        }

        if matches!(mapper, DeviceMapSetting::Map(_)) {
            anyhow::bail!("Device mapping is not supported for GGML models.")
        }

        if paged_attn_config.is_some() {
            warn!("PagedAttention is not supported for GGML models, disabling it.");

            paged_attn_config = None;
        }

        info!("Prompt chunk size is {ATTENTION_CHUNK_SIZE}.");

        info!(
            "Loading model `{}` on {}.",
            self.get_id(),
            device.device_pretty_repr()
        );

        #[cfg(feature = "cuda")]
        if let Device::Cuda(dev) = &device {
            unsafe { dev.disable_event_tracking() };
        }

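        // Parse the GGML container: `Content::read` yields the hyperparameters and tensors.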
        let mut file = std::fs::File::open(paths.get_weight_filenames().first().unwrap())?;
        let model = ggml_file::Content::read(&mut file, device)
            .map_err(|e| e.with_path(paths.get_weight_filenames().first().unwrap()))?;

        info!("Model config: {:?}", model.hparams);

        if DEBUG.load(std::sync::atomic::Ordering::Relaxed) {
            let mut tensors = Vec::new();
            for (name, t) in &model.tensors {
                tensors.push(format!(
                    "name = `{name}`, shape = {:?}, dtype = {:?}",
                    t.shape().clone(),
                    t.dtype(),
                ));
            }
            fs::write(
                "mistralrs_ggml_tensors.txt",
                serde_json::to_string_pretty(&tensors).expect("Serialization failed."),
            )?;

            info!("Debug is enabled, wrote the names and information about each tensor to `mistralrs_ggml_tensors.txt`.");
        }

        let _ = if paged_attn_config.is_none() {
            warn!("GGML does not currently support PagedAttention, running without");
            None
        } else {
            paged_attn_config
        };

        let has_adapter = self.kind.is_adapted();
        let is_xlora = self.kind.is_adapted_and(|a| a.is_x_lora());
        let internal_dtype = dtype.try_into_dtype(&[device]).unwrap();

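        // Bundle the GGML weights, GQA factor, and dtype (plus adapter paths, if any)
        // into the shared model-config used to construct the model below.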
        let model_config = {
            let quant = ModelConfig::ParamsGGML((model, self.config.gqa, internal_dtype).into());

            let mut adapter = None;
            if has_adapter {
                adapter.replace(ModelConfig::Adapter::try_new(
                    paths, device, silent, is_xlora,
                )?);
            }

            ModelConfig::ModelParams::new(quant, adapter)
        };

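        // Quantized-only kinds become a plain quantized Llama; adapter kinds get the
        // X-LoRA/LoRA wrapper.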
        let model = match self.kind {
            ModelKind::GgufQuantized { .. } => {
                Model::Llama(Box::new(QLlama::try_from(model_config)?))
            }
            ModelKind::GgufAdapter { .. } => {
                Model::XLoraLlama(Box::new(XLoraQLlama::try_from(model_config)?))
            }
            _ => unreachable!(),
        };

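        // Load the tokenizer, the optional generation config, and the chat template sources.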
        let tokenizer = get_tokenizer(paths.get_tokenizer_filename(), None)?;
        let gen_conf: Option<GenerationConfig> = paths
            .get_gen_conf_filename()
            .map(|f| serde_json::from_str(&fs::read_to_string(f).unwrap()).unwrap());
        let chat_template_explicit = paths
            .get_chat_template_explicit()
            .as_ref()
            .map(|x| x.to_string_lossy().to_string());
        let chat_template = get_chat_template(
            paths,
            self.jinja_explicit.as_ref(),
            chat_template_explicit.as_ref(),
            self.chat_template.as_ref(),
            None,
        );

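        // Derive runtime metadata: context length, layer count (for KV-cache management),
        // EOS tokens, and the llg factory used for grammar-constrained sampling.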
        let max_seq_len = match model {
            Model::Llama(ref l) => l.max_seq_len,
            Model::XLoraLlama(ref xl) => xl.max_seq_len,
        };
        let llg_factory = build_llg_factory(tokenizer.clone())?;
        let num_hidden_layers = match model {
            Model::Llama(ref model) => model.cache.normal().0.len(),
            Model::XLoraLlama(ref model) => model.cache.full().lock().len(),
        };
        let eos = calculate_eos_tokens(&chat_template, gen_conf, &tokenizer);
        Ok(Arc::new(Mutex::new(GGMLPipeline {
            model,
            tokenizer: tokenizer.into(),
            no_kv_cache: self.no_kv_cache,
            chat_template: Arc::new(chat_template),
            model_id: self.model_id.clone(),
            non_granular_state: self.tgt_non_granular_index.map(|tgt_non_granular_index| {
                NonGranularState {
                    non_granular_index: Arc::new(Mutex::new(0)),
                    tgt_non_granular_index,
                }
            }),
            metadata: Arc::new(GeneralMetadata {
                max_seq_len,
                llg_factory: Some(llg_factory),
                no_kv_cache: self.no_kv_cache,
                no_prefix_cache: false,
                num_hidden_layers,
                eos_tok: eos,
                kind: self.kind.clone(),
                is_xlora,
                activation_dtype: internal_dtype,
                sliding_window: None,
                cache_config: None,
                cache_engine: None,
                model_metadata: None,
                modalities: Modalities {
                    input: vec![SupportedModality::Text],
                    output: vec![SupportedModality::Text],
                },
            }),
        })))
    }

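    // `load_model_from_hf` resolves the quantized file (and any adapter files) via
    // `get_paths!`, then defers to `load_model_from_path`.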
    #[allow(clippy::type_complexity, clippy::too_many_arguments)]
    fn load_model_from_hf(
        &self,
        revision: Option<String>,
        token_source: TokenSource,
        dtype: &dyn TryIntoDType,
        device: &Device,
        silent: bool,
        mapper: DeviceMapSetting,
        in_situ_quant: Option<IsqType>,
        paged_attn_config: Option<PagedAttentionConfig>,
    ) -> Result<Arc<Mutex<dyn Pipeline + Send + Sync>>> {
        let _progress_guard = ProgressScopeGuard::new(silent);
        let paths: anyhow::Result<Box<dyn ModelPaths>> = get_paths!(
            LocalModelPaths,
            &token_source,
            revision,
            self,
            self.quantized_model_id,
            Some(vec![self.quantized_filename.as_ref().unwrap().clone()]),
            silent,
            false
        );
        self.load_model_from_path(
            &paths?,
            dtype,
            device,
            silent,
            mapper,
            in_situ_quant,
            paged_attn_config,
        )
    }

    fn get_id(&self) -> String {
        self.xlora_model_id
            .as_deref()
            .unwrap_or(&self.model_id)
            .to_string()
    }

    fn get_kind(&self) -> ModelKind {
        self.kind.clone()
    }
}

impl PreProcessingMixin for GGMLPipeline {
    fn get_chat_template(&self) -> Option<Arc<ChatTemplate>> {
        Some(self.chat_template.clone())
    }
    fn get_input_processor_config(&self) -> Option<Arc<dyn Any>> {
        None
    }
}

impl IsqPipelineMixin for GGMLPipeline {
    fn re_isq_model(&mut self, _dtype: IsqType) -> Result<()> {
        anyhow::bail!(
            "You are trying to in-situ requantize a GGML model. This will not do anything."
        )
    }
}

impl CacheManagerMixin for GGMLPipeline {
    fn clone_in_cache(&self, seqs: &mut [&mut Sequence]) {
        FullCacheManager.clone_in_cache(self, seqs, false)
    }
    fn clone_out_cache(&self, seqs: &mut [&mut Sequence]) {
        FullCacheManager.clone_out_cache(self, seqs, false)
    }
    fn set_none_cache(
        &self,
        seqs: &mut [&mut Sequence],
        reset_non_granular: bool,
        modify_draft_cache: bool,
        load_preallocated_cache: bool,
    ) {
        FullCacheManager.set_none_cache(self, seqs, modify_draft_cache, load_preallocated_cache);
        if reset_non_granular {
            self.reset_non_granular_state()
        }
    }
    fn cache(&self) -> &EitherCache {
        match self.model {
            Model::Llama(ref model) => &model.cache,
            Model::XLoraLlama(ref model) => &model.cache,
        }
    }
}

impl MetadataMixin for GGMLPipeline {
    fn device(&self) -> Device {
        match self.model {
            Model::Llama(ref model) => model.device.clone(),
            Model::XLoraLlama(ref model) => model.device.clone(),
        }
    }
    fn tokenizer(&self) -> Option<Arc<Tokenizer>> {
        Some(self.tokenizer.clone())
    }
    fn name(&self) -> String {
        self.model_id.clone()
    }
    fn reset_non_granular_state(&self) {
        if let Some(s) = self.non_granular_state.as_ref() {
            *self.cache().full().get_scalings_cache() = None;
            *get_mut_arcmutex!(s.non_granular_index) = 0;
        }
    }
    fn get_metadata(&self) -> Arc<GeneralMetadata> {
        self.metadata.clone()
    }
    fn device_mapper(&self) -> Option<&dyn DeviceMapper> {
        None
    }
}

#[async_trait::async_trait]
impl Pipeline for GGMLPipeline {
    fn forward_inputs(
        &mut self,
        inputs: Box<dyn Any>,
        return_raw_logits: bool,
    ) -> Result<ForwardInputsResult, candle_core::Error> {
        let ModelInputs {
            input_ids,
            input_ids_full,
            seqlen_offsets,
            seqlen_offsets_full,
            context_lens,
            position_ids: _,
            paged_attn_meta: _,
            flash_meta,
            flash_meta_full,
        } = *inputs.downcast().expect("Downcast failed.");
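        // The X-LoRA forward additionally takes the full input ids and offsets (falling back
        // to the incremental ones); the plain quantized Llama only needs the incremental view.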
        let logits = match self.model {
            Model::Llama(ref model) => {
                model.forward(&input_ids, &seqlen_offsets, context_lens, None)?
            }
            Model::XLoraLlama(ref model) => model.forward(
                &input_ids,
                input_ids_full.as_ref().unwrap_or(&input_ids),
                &seqlen_offsets,
                seqlen_offsets_full.as_ref().unwrap_or(&seqlen_offsets),
                self.no_kv_cache,
                &self.non_granular_state,
                context_lens,
                &flash_meta,
                flash_meta_full.as_ref().unwrap_or(&flash_meta),
            )?,
        };
        if return_raw_logits {
            Ok(ForwardInputsResult::RawLogits { logits })
        } else {
            Ok(ForwardInputsResult::CausalGeneration { logits })
        }
    }
    async fn sample_causal_gen(
        &self,
        seqs: &mut [&mut Sequence],
        logits: Vec<Tensor>,
        prefix_cacher: &mut PrefixCacheManagerV2,
        disable_eos_stop: bool,
        rng: Arc<std::sync::Mutex<Isaac64Rng>>,
    ) -> Result<(), candle_core::Error> {
        sample_and_add_toks(self, seqs, logits, prefix_cacher, disable_eos_stop, rng).await
    }
    fn category(&self) -> ModelCategory {
        ModelCategory::Text
    }
}

impl AnyMoePipelineMixin for GGMLPipeline {}