use super::llg::build_llg_factory;
use super::{
    get_model_paths, get_xlora_paths, text_models_inputs_processor::ModelInputs, AdapterKind,
    CacheManager, GeneralMetadata, Loader, ModelKind, ModelPaths, QuantizationKind, TokenSource,
};
use super::{
    AnyMoePipelineMixin, CacheManagerMixin, EitherCache, ForwardInputsResult, IsqPipelineMixin,
    MetadataMixin, ModelCategory, PreProcessingMixin,
};
use crate::attention::ATTENTION_CHUNK_SIZE;
use crate::device_map::DeviceMapper;
use crate::kv_cache::FullCacheManager;
use crate::lora::Ordering;
use crate::pipeline::chat_template::{calculate_eos_tokens, GenerationConfig};
use crate::pipeline::sampling::sample_and_add_toks;
use crate::pipeline::{get_chat_template, Modalities, SupportedModality};
use crate::pipeline::{ChatTemplate, LocalModelPaths};
use crate::prefix_cacher::PrefixCacheManagerV2;
use crate::sequence::Sequence;
use crate::utils::debug::DeviceRepr;
use crate::utils::model_config as ModelConfig;
use crate::utils::tokenizer::get_tokenizer;
use crate::xlora_models::NonGranularState;
use crate::{
    get_mut_arcmutex, get_paths, DeviceMapSetting, PagedAttentionConfig, Pipeline, Topology,
    TryIntoDType, DEBUG,
};
use crate::{
    models::quantized_llama::ModelWeights as QLlama, utils::tokens::get_token,
    xlora_models::XLoraQLlama,
};
use anyhow::Result;
use candle_core::quantized::ggml_file;
use candle_core::{Device, Tensor};
use hf_hub::{api::sync::ApiBuilder, Repo, RepoType};
use mistralrs_quant::IsqType;
use rand_isaac::Isaac64Rng;
use std::any::Any;
use std::fs;
use std::path::PathBuf;
use std::str::FromStr;
use std::sync::Arc;
use tokenizers::Tokenizer;
use tokio::sync::Mutex;
use tracing::{info, warn};

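/// The concrete quantized model held by [`GGMLPipeline`]: either a plain
/// GGML-quantized Llama or its X-LoRA/LoRA-adapted counterpart.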
enum Model {
    Llama(QLlama),
    XLoraLlama(Box<XLoraQLlama>),
}

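/// A loaded GGML pipeline: the quantized weights together with the tokenizer,
/// chat template, optional X-LoRA non-granular state, and general generation
/// metadata.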
pub struct GGMLPipeline {
    model: Model,
    tokenizer: Arc<Tokenizer>,
    no_kv_cache: bool,
    chat_template: Arc<ChatTemplate>,
    model_id: String,
    non_granular_state: Option<NonGranularState>,
    metadata: Arc<GeneralMetadata>,
}

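/// A [`Loader`] for GGML-format (pre-GGUF) quantized Llama models, optionally
/// combined with X-LoRA or LoRA adapters.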
pub struct GGMLLoader {
    model_id: String,
    config: GGMLSpecificConfig,
    quantized_model_id: Option<String>,
    quantized_filename: Option<String>,
    xlora_model_id: Option<String>,
    xlora_order: Option<Ordering>,
    no_kv_cache: bool,
    chat_template: Option<String>,
    tokenizer_json: Option<String>,
    kind: ModelKind,
    tgt_non_granular_index: Option<usize>,
    jinja_explicit: Option<String>,
    lora_adapter_ids: Option<Vec<String>>,
}

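/// Configuration specific to GGML loading. `gqa` is the grouped-query-attention
/// factor to assume for the weights (the GGML container does not record it
/// itself), and `topology` is an optional [`Topology`] override.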
#[derive(Clone, Default)]
pub struct GGMLSpecificConfig {
    pub gqa: usize,
    pub topology: Option<Topology>,
}

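/// Builder for a GGML [`Loader`]. Start with [`GGMLLoaderBuilder::new`], attach an
/// adapter with [`with_xlora`](Self::with_xlora) or [`with_lora`](Self::with_lora)
/// if needed, then call [`build`](Self::build).
///
/// A minimal sketch of the intended call sequence; the model IDs and filename are
/// placeholders, and crate-root re-exports of the builder types are assumed:
///
/// ```ignore
/// use mistralrs_core::{GGMLLoaderBuilder, GGMLSpecificConfig};
///
/// let loader = GGMLLoaderBuilder::new(
///     GGMLSpecificConfig::default(),
///     None,                                // chat template
///     None,                                // tokenizer.json path
///     Some("some-org/base-model".into()),  // tokenizer / base model ID
///     "some-org/quantized-model".into(),   // repo holding the GGML file
///     "model-q4_0.bin".into(),             // GGML weight filename
///     false,                               // no_kv_cache
///     None,                                // explicit JINJA chat template
/// )
/// .build();
/// ```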
#[derive(Default)]
pub struct GGMLLoaderBuilder {
    model_id: Option<String>,
    config: GGMLSpecificConfig,
    quantized_model_id: String,
    quantized_filename: String,
    xlora_model_id: Option<String>,
    kind: ModelKind,
    xlora_order: Option<Ordering>,
    no_kv_cache: bool,
    chat_template: Option<String>,
    tokenizer_json: Option<String>,
    tgt_non_granular_index: Option<usize>,
    jinja_explicit: Option<String>,
}

impl GGMLLoaderBuilder {
    #[allow(clippy::too_many_arguments)]
    pub fn new(
        config: GGMLSpecificConfig,
        chat_template: Option<String>,
        tokenizer_json: Option<String>,
        model_id: Option<String>,
        quantized_model_id: String,
        quantized_filename: String,
        no_kv_cache: bool,
        jinja_explicit: Option<String>,
    ) -> Self {
        let kind = ModelKind::GgufQuantized {
            quant: QuantizationKind::Ggml,
        };

        Self {
            config,
            chat_template,
            tokenizer_json,
            model_id,
            kind,
            quantized_filename,
            quantized_model_id,
            no_kv_cache,
            jinja_explicit,
            ..Default::default()
        }
    }

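    /// Shared adapter wiring for [`Self::with_xlora`] and [`Self::with_lora`]:
    /// records the adapter model ID and ordering, and falls back to the ordering's
    /// base model ID when no model ID was provided.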
    fn with_adapter(
        mut self,
        xlora_model_id: String,
        xlora_order: Ordering,
        no_kv_cache: bool,
        tgt_non_granular_index: Option<usize>,
    ) -> Self {
        self.xlora_model_id = Some(xlora_model_id);
        self.xlora_order = Some(xlora_order);
        self.no_kv_cache = no_kv_cache;
        self.tgt_non_granular_index = tgt_non_granular_index;
        self.model_id = if let Some(id) = self.model_id {
            Some(id)
        } else {
            info!(
                "Using adapter base model ID: `{}`",
                self.xlora_order.as_ref().unwrap().base_model_id
            );
            Some(self.xlora_order.as_ref().unwrap().base_model_id.clone())
        };
        self
    }

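    /// Attach an X-LoRA adapter, switching the model kind to a GGML-quantized
    /// X-LoRA model.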
    pub fn with_xlora(
        mut self,
        xlora_model_id: String,
        xlora_order: Ordering,
        no_kv_cache: bool,
        tgt_non_granular_index: Option<usize>,
    ) -> Self {
        self.kind = (AdapterKind::XLora, QuantizationKind::Ggml).into();

        self.with_adapter(
            xlora_model_id,
            xlora_order,
            no_kv_cache,
            tgt_non_granular_index,
        )
    }

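    /// Attach a LoRA adapter, switching the model kind to a GGML-quantized
    /// LoRA model.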
    pub fn with_lora(mut self, lora_model_id: String, lora_order: Ordering) -> Self {
        self.kind = (AdapterKind::Lora, QuantizationKind::Ggml).into();

        self.with_adapter(lora_model_id, lora_order, false, None)
    }

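    /// Finish the builder, producing a boxed [`Loader`].
    ///
    /// Panics if no model ID was set either directly or via an adapter ordering.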
    pub fn build(self) -> Box<dyn Loader> {
        Box::new(GGMLLoader {
            model_id: self.model_id.unwrap(),
            config: self.config,
            xlora_model_id: self.xlora_model_id,
            kind: self.kind,
            xlora_order: self.xlora_order,
            no_kv_cache: self.no_kv_cache,
            chat_template: self.chat_template,
            tokenizer_json: self.tokenizer_json,
            tgt_non_granular_index: self.tgt_non_granular_index,
            quantized_filename: Some(self.quantized_filename),
            quantized_model_id: Some(self.quantized_model_id),
            jinja_explicit: self.jinja_explicit,
            lora_adapter_ids: None,
        })
    }
}

impl GGMLLoader {
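    /// Create a loader directly from its parts; [`GGMLLoaderBuilder`] is the more
    /// convenient entry point. Like the builder, this falls back to the adapter
    /// ordering's base model ID when `model_id` is `None`.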
    #[allow(clippy::too_many_arguments)]
    pub fn new(
        model_id: Option<String>,
        config: GGMLSpecificConfig,
        quantized_model_id: Option<String>,
        quantized_filename: Option<String>,
        xlora_model_id: Option<String>,
        kind: ModelKind,
        xlora_order: Option<Ordering>,
        no_kv_cache: bool,
        chat_template: Option<String>,
        tokenizer_json: Option<String>,
        tgt_non_granular_index: Option<usize>,
        jinja_explicit: Option<String>,
    ) -> Self {
        let model_id = if let Some(id) = model_id {
            id
        } else {
            info!(
                "Using adapter base model ID: `{}`",
                xlora_order.as_ref().unwrap().base_model_id
            );
            xlora_order.as_ref().unwrap().base_model_id.clone()
        };
        Self {
            model_id,
            config,
            quantized_model_id,
            quantized_filename,
            xlora_model_id,
            xlora_order,
            no_kv_cache,
            chat_template,
            tokenizer_json,
            kind,
            tgt_non_granular_index,
            jinja_explicit,
            lora_adapter_ids: None,
        }
    }
}

impl Loader for GGMLLoader {
    #[allow(clippy::type_complexity, clippy::too_many_arguments)]
    fn load_model_from_path(
        &self,
        paths: &Box<dyn ModelPaths>,
        dtype: &dyn TryIntoDType,
        device: &Device,
        silent: bool,
        mapper: DeviceMapSetting,
        in_situ_quant: Option<IsqType>,
        mut paged_attn_config: Option<PagedAttentionConfig>,
    ) -> Result<Arc<Mutex<dyn Pipeline + Send + Sync>>> {
        if in_situ_quant.is_some() {
            anyhow::bail!(
                "You are trying to in-situ quantize a GGML model. This will not do anything."
            );
        }

        if matches!(mapper, DeviceMapSetting::Map(_)) {
            anyhow::bail!("Device mapping is not supported for GGML models.")
        }

        if paged_attn_config.is_some() {
            warn!("PagedAttention is not supported for GGML models, disabling it.");

            paged_attn_config = None;
        }

        info!("Prompt chunk size is {ATTENTION_CHUNK_SIZE}.");

        info!(
            "Loading model `{}` on {}.",
            self.get_id(),
            device.device_pretty_repr()
        );

        #[cfg(feature = "cuda")]
        if let Device::Cuda(dev) = &device {
            unsafe { dev.disable_event_tracking() };
        }

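        // GGML models ship as a single quantized weight file, so read the whole
        // container from the first (and only expected) weight path.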
        let mut file = std::fs::File::open(paths.get_weight_filenames().first().unwrap())?;
        let model = ggml_file::Content::read(&mut file, device)
            .map_err(|e| e.with_path(paths.get_weight_filenames().first().unwrap()))?;

        info!("Model config: {:?}", model.hparams);

        if DEBUG.load(std::sync::atomic::Ordering::Relaxed) {
            let mut tensors = Vec::new();
            for (name, t) in &model.tensors {
                tensors.push(format!(
                    "name = `{name}`, shape = {:?}, dtype = {:?}",
                    t.shape().clone(),
                    t.dtype(),
                ));
            }
            fs::write(
                "mistralrs_ggml_tensors.txt",
                serde_json::to_string_pretty(&tensors).expect("Serialization failed."),
            )?;

            info!("Debug is enabled, wrote the names and information about each tensor to `mistralrs_ggml_tensors.txt`.");
        }

        let _ = if paged_attn_config.is_none() {
            warn!("GGML does not currently support PagedAttention, running without");
            None
        } else {
            paged_attn_config
        };

        let has_adapter = self.kind.is_adapted();
        let is_xlora = self.kind.is_adapted_and(|a| a.is_x_lora());
        let internal_dtype = dtype.try_into_dtype(&[device]).unwrap();

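        // Bundle the GGML weights (plus the user-supplied GQA factor) together with
        // any adapter paths into the shared model-config used by the model builders.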
        let model_config = {
            let quant = ModelConfig::ParamsGGML((model, self.config.gqa, internal_dtype).into());

            let mut adapter = None;
            if has_adapter {
                adapter.replace(ModelConfig::Adapter::try_new(
                    paths, device, silent, is_xlora,
                )?);
            }

            ModelConfig::ModelParams::new(quant, adapter)
        };

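        // The GGML path reuses the GGUF ModelKind variants; only the Llama
        // architecture is supported here, with or without (X-)LoRA adapters.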
        let model = match self.kind {
            ModelKind::GgufQuantized { .. } => Model::Llama(QLlama::try_from(model_config)?),
            ModelKind::GgufAdapter { .. } => {
                Model::XLoraLlama(Box::new(XLoraQLlama::try_from(model_config)?))
            }
            _ => unreachable!(),
        };

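        // Tokenizer, generation config, and chat template are resolved from the
        // model paths rather than from the quantized weight file itself.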
        let tokenizer = get_tokenizer(paths.get_tokenizer_filename(), None)?;
        let gen_conf: Option<GenerationConfig> = paths
            .get_gen_conf_filename()
            .map(|f| serde_json::from_str(&fs::read_to_string(f).unwrap()).unwrap());
        let chat_template_explicit = paths
            .get_chat_template_explicit()
            .as_ref()
            .map(|x| x.to_string_lossy().to_string());
        let chat_template = get_chat_template(
            paths,
            self.jinja_explicit.as_ref(),
            chat_template_explicit.as_ref(),
            self.chat_template.as_ref(),
            None,
        );

        let max_seq_len = match model {
            Model::Llama(ref l) => l.max_seq_len,
            Model::XLoraLlama(ref xl) => xl.max_seq_len,
        };
        let llg_factory = build_llg_factory(tokenizer.clone())?;
        let num_hidden_layers = match model {
            Model::Llama(ref model) => model.cache.normal().0.len(),
            Model::XLoraLlama(ref model) => model.cache.full().lock().len(),
        };
        let eos = calculate_eos_tokens(&chat_template, gen_conf, &tokenizer);
        Ok(Arc::new(Mutex::new(GGMLPipeline {
            model,
            tokenizer: tokenizer.into(),
            no_kv_cache: self.no_kv_cache,
            chat_template: Arc::new(chat_template),
            model_id: self.model_id.clone(),
            non_granular_state: self.tgt_non_granular_index.map(|tgt_non_granular_index| {
                NonGranularState {
                    non_granular_index: Arc::new(Mutex::new(0)),
                    tgt_non_granular_index,
                }
            }),
            metadata: Arc::new(GeneralMetadata {
                max_seq_len,
                llg_factory: Some(llg_factory),
                no_kv_cache: self.no_kv_cache,
                no_prefix_cache: false,
                num_hidden_layers,
                eos_tok: eos,
                kind: self.kind.clone(),
                is_xlora,
                activation_dtype: internal_dtype,
                sliding_window: None,
                cache_config: None,
                cache_engine: None,
                model_metadata: None,
                modalities: Modalities {
                    input: vec![SupportedModality::Text],
                    output: vec![SupportedModality::Text],
                },
            }),
        })))
    }

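    // Resolve model/tokenizer/adapter paths (from the Hugging Face Hub or a local
    // directory) and then delegate to `load_model_from_path`.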
    #[allow(clippy::type_complexity, clippy::too_many_arguments)]
    fn load_model_from_hf(
        &self,
        revision: Option<String>,
        token_source: TokenSource,
        dtype: &dyn TryIntoDType,
        device: &Device,
        silent: bool,
        mapper: DeviceMapSetting,
        in_situ_quant: Option<IsqType>,
        paged_attn_config: Option<PagedAttentionConfig>,
    ) -> Result<Arc<Mutex<dyn Pipeline + Send + Sync>>> {
        let paths: anyhow::Result<Box<dyn ModelPaths>> = get_paths!(
            LocalModelPaths,
            &token_source,
            revision,
            self,
            self.quantized_model_id,
            Some(vec![self.quantized_filename.as_ref().unwrap().clone()]),
            silent,
            false
        );
        self.load_model_from_path(
            &paths?,
            dtype,
            device,
            silent,
            mapper,
            in_situ_quant,
            paged_attn_config,
        )
    }

    fn get_id(&self) -> String {
        self.xlora_model_id
            .as_deref()
            .unwrap_or(&self.model_id)
            .to_string()
    }

    fn get_kind(&self) -> ModelKind {
        self.kind.clone()
    }
}

impl PreProcessingMixin for GGMLPipeline {
    fn get_chat_template(&self) -> Option<Arc<ChatTemplate>> {
        Some(self.chat_template.clone())
    }
    fn get_input_processor_config(&self) -> Option<Arc<dyn Any>> {
        None
    }
}

impl IsqPipelineMixin for GGMLPipeline {
    fn re_isq_model(&mut self, _dtype: IsqType) -> Result<()> {
        anyhow::bail!(
            "You are trying to in-situ requantize a GGML model. This will not do anything."
        )
    }
}

impl CacheManagerMixin for GGMLPipeline {
    fn clone_in_cache(&self, seqs: &mut [&mut Sequence]) {
        FullCacheManager.clone_in_cache(self, seqs, false)
    }
    fn clone_out_cache(&self, seqs: &mut [&mut Sequence]) {
        FullCacheManager.clone_out_cache(self, seqs, false)
    }
    fn set_none_cache(
        &self,
        seqs: &mut [&mut Sequence],
        reset_non_granular: bool,
        modify_draft_cache: bool,
        load_preallocated_cache: bool,
    ) {
        FullCacheManager.set_none_cache(self, seqs, modify_draft_cache, load_preallocated_cache);
        if reset_non_granular {
            self.reset_non_granular_state()
        }
    }
    fn cache(&self) -> &EitherCache {
        match self.model {
            Model::Llama(ref model) => &model.cache,
            Model::XLoraLlama(ref model) => &model.cache,
        }
    }
}

impl MetadataMixin for GGMLPipeline {
    fn device(&self) -> Device {
        match self.model {
            Model::Llama(ref model) => model.device.clone(),
            Model::XLoraLlama(ref model) => model.device.clone(),
        }
    }
    fn tokenizer(&self) -> Option<Arc<Tokenizer>> {
        Some(self.tokenizer.clone())
    }
    fn name(&self) -> String {
        self.model_id.clone()
    }
    fn reset_non_granular_state(&self) {
        if let Some(s) = self.non_granular_state.as_ref() {
            *self.cache().full().get_scalings_cache() = None;
            *get_mut_arcmutex!(s.non_granular_index) = 0;
        }
    }
    fn get_metadata(&self) -> Arc<GeneralMetadata> {
        self.metadata.clone()
    }
    fn device_mapper(&self) -> Option<&dyn DeviceMapper> {
        None
    }
}

#[async_trait::async_trait]
impl Pipeline for GGMLPipeline {
    fn forward_inputs(
        &mut self,
        inputs: Box<dyn Any>,
        return_raw_logits: bool,
    ) -> Result<ForwardInputsResult, candle_core::Error> {
        let ModelInputs {
            input_ids,
            input_ids_full,
            seqlen_offsets,
            seqlen_offsets_full,
            context_lens,
            position_ids: _,
            paged_attn_meta: _,
            flash_meta,
            flash_meta_full,
        } = *inputs.downcast().expect("Downcast failed.");
        let logits = match self.model {
            Model::Llama(ref model) => {
                model.forward(&input_ids, &seqlen_offsets, context_lens, None)?
            }
            Model::XLoraLlama(ref model) => model.forward(
                &input_ids,
                input_ids_full.as_ref().unwrap_or(&input_ids),
                &seqlen_offsets,
                seqlen_offsets_full.as_ref().unwrap_or(&seqlen_offsets),
                self.no_kv_cache,
                &self.non_granular_state,
                context_lens,
                &flash_meta,
                flash_meta_full.as_ref().unwrap_or(&flash_meta),
            )?,
        };
        if return_raw_logits {
            Ok(ForwardInputsResult::RawLogits { logits })
        } else {
            Ok(ForwardInputsResult::CausalGeneration { logits })
        }
    }
    async fn sample_causal_gen(
        &self,
        seqs: &mut [&mut Sequence],
        logits: Vec<Tensor>,
        prefix_cacher: &mut PrefixCacheManagerV2,
        disable_eos_stop: bool,
        rng: Arc<std::sync::Mutex<Isaac64Rng>>,
    ) -> Result<(), candle_core::Error> {
        sample_and_add_toks(self, seqs, logits, prefix_cacher, disable_eos_stop, rng).await
    }
    fn category(&self) -> ModelCategory {
        ModelCategory::Text
    }
}

impl AnyMoePipelineMixin for GGMLPipeline {}