use super::llg::build_llg_factory;
use super::{
    get_model_paths, get_xlora_paths, text_models_inputs_processor::ModelInputs, AdapterKind,
    CacheManager, GeneralMetadata, Loader, ModelKind, ModelPaths, PrettyName, QuantizationKind,
    TokenSource,
};
use super::{
    AnyMoePipelineMixin, CacheManagerMixin, EitherCache, ForwardInputsResult, IsqPipelineMixin,
    MetadataMixin, ModelCategory, PreProcessingMixin,
};
use crate::attention::ATTENTION_CHUNK_SIZE;
use crate::device_map::{self, DeviceMapper};
use crate::gguf::{
    get_gguf_chat_template, {convert_gguf_to_hf_tokenizer, GgufTokenizerConversion},
};
use crate::gguf::{Content, GGUFArchitecture};
use crate::kv_cache::{FullCacheManager, NormalCacheManager};
use crate::lora::Ordering;
use crate::paged_attention::{
    calculate_cache_config, AttentionImplementation, CacheEngine, ModelConfigLike,
};
use crate::pipeline::chat_template::{calculate_eos_tokens, BeginEndUnkPadTok, GenerationConfig};
use crate::pipeline::loaders::DeviceMappedModelLoader;
use crate::pipeline::sampling::sample_and_add_toks;
use crate::pipeline::ChatTemplate;
use crate::pipeline::{get_chat_template, Modalities, SupportedModality};
use crate::prefix_cacher::PrefixCacheManagerV2;
use crate::sequence::Sequence;
use crate::utils::gguf_metadata::{ContentConfig, GgufDeviceMapLoaderInner};
use crate::utils::model_config as ModelConfig;
use crate::utils::tokenizer::get_tokenizer;
use crate::xlora_models::NonGranularState;
use crate::{
    get_mut_arcmutex, get_paths_gguf, DeviceMapSetting, LocalModelPaths, PagedAttentionConfig,
    Pipeline, Topology, TryIntoDType,
};
use crate::{
    models::quantized_llama::ModelWeights as QLlama,
    models::quantized_phi2::ModelWeights as QPhi,
    models::quantized_phi3::ModelWeights as QPhi3,
    models::quantized_qwen::ModelWeights as QQwen,
    models::quantized_qwen3::ModelWeights as QQwen3,
    models::quantized_qwen3_moe::ModelWeights as QQwen3MoE,
    models::quantized_starcoder2::ModelWeights as QStarcoder2,
    utils::tokens::get_token,
    xlora_models::{XLoraQLlama, XLoraQPhi3},
};
use anyhow::{bail, Result};
use candle_core::{Device, Tensor};
use either::Either;
use hf_hub::{api::sync::ApiBuilder, Repo, RepoType};
use mistralrs_quant::IsqType;
use rand_isaac::Isaac64Rng;
use std::any::Any;
use std::fs;
use std::path::PathBuf;
use std::str::FromStr;
use std::sync::Arc;
use tokenizers::Tokenizer;
use tokio::sync::Mutex;
use tracing::{info, warn};

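/// Quantized model weights, one variant per supported GGUF architecture.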
enum Model {
    Llama(QLlama),
    Phi2(QPhi),
    XLoraLlama(XLoraQLlama),
    XLoraPhi3(XLoraQPhi3),
    Phi3(QPhi3),
    Starcoder2(QStarcoder2),
    Qwen(QQwen),
    Qwen3(QQwen3),
    Qwen3MoE(QQwen3MoE),
}

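/// A fully loaded GGUF pipeline: quantized weights together with the tokenizer,
/// chat template, and runtime metadata.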
pub struct GGUFPipeline {
    model: Model,
    tokenizer: Arc<Tokenizer>,
    no_kv_cache: bool,
    chat_template: Arc<ChatTemplate>,
    model_id: String,
    non_granular_state: Option<NonGranularState>,
    metadata: Arc<GeneralMetadata>,
    mapper: Box<dyn DeviceMapper + Send + Sync>,
}

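/// A [`Loader`] for GGUF models, with optional X-LoRA/LoRA adapter support.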
pub struct GGUFLoader {
    model_id: Option<String>,
    quantized_model_id: String,
    quantized_filenames: Vec<String>,
    xlora_model_id: Option<String>,
    xlora_order: Option<Ordering>,
    no_kv_cache: bool,
    chat_template: Option<String>,
    kind: ModelKind,
    tgt_non_granular_index: Option<usize>,
    config: GGUFSpecificConfig,
    jinja_explicit: Option<String>,
    lora_adapter_ids: Option<Vec<String>>,
}

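/// GGUF-specific loader configuration (e.g. device topology).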
#[derive(Clone, Default)]
pub struct GGUFSpecificConfig {
    pub topology: Option<Topology>,
}

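/// Builder used to construct a [`GGUFLoader`].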
#[derive(Default)]
pub struct GGUFLoaderBuilder {
    model_id: Option<String>,
    quantized_model_id: String,
    quantized_filenames: Vec<String>,
    xlora_model_id: Option<String>,
    kind: ModelKind,
    xlora_order: Option<Ordering>,
    no_kv_cache: bool,
    chat_template: Option<String>,
    tgt_non_granular_index: Option<usize>,
    config: GGUFSpecificConfig,
    jinja_explicit: Option<String>,
}

impl GGUFLoaderBuilder {
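    /// Create a builder for a GGUF model. `tok_model_id` optionally names a model ID from
    /// which the tokenizer and chat template can be loaded; when it is `None`, these are
    /// reconstructed from the GGUF metadata where possible.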
    pub fn new(
        chat_template: Option<String>,
        tok_model_id: Option<String>,
        quantized_model_id: String,
        quantized_filenames: Vec<String>,
        config: GGUFSpecificConfig,
        no_kv_cache: bool,
        jinja_explicit: Option<String>,
    ) -> Self {
        let kind = ModelKind::GgufQuantized {
            quant: QuantizationKind::Gguf,
        };

        Self {
            chat_template,
            model_id: tok_model_id,
            kind,
            quantized_filenames,
            quantized_model_id,
            config,
            jinja_explicit,
            no_kv_cache,
            ..Default::default()
        }
    }

    fn with_adapter(
        mut self,
        xlora_model_id: String,
        xlora_order: Ordering,
        no_kv_cache: bool,
        tgt_non_granular_index: Option<usize>,
    ) -> Self {
        self.xlora_model_id = Some(xlora_model_id);
        self.xlora_order = Some(xlora_order);
        self.no_kv_cache = no_kv_cache;
        self.tgt_non_granular_index = tgt_non_granular_index;
        self.model_id = if let Some(id) = self.model_id {
            Some(id)
        } else {
            info!(
                "Using adapter base model ID: `{}`",
                self.xlora_order.as_ref().unwrap().base_model_id
            );
            Some(self.xlora_order.as_ref().unwrap().base_model_id.clone())
        };
        self
    }

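    /// Configure the loader to use an X-LoRA adapter with the quantized base model.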
    pub fn with_xlora(
        mut self,
        xlora_model_id: String,
        xlora_order: Ordering,
        no_kv_cache: bool,
        tgt_non_granular_index: Option<usize>,
    ) -> Self {
        self.kind = (AdapterKind::XLora, QuantizationKind::Gguf).into();

        self.with_adapter(
            xlora_model_id,
            xlora_order,
            no_kv_cache,
            tgt_non_granular_index,
        )
    }

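    /// Configure the loader to use a LoRA adapter with the quantized base model.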
    pub fn with_lora(mut self, lora_model_id: String, lora_order: Ordering) -> Self {
        self.kind = (AdapterKind::Lora, QuantizationKind::Gguf).into();

        self.with_adapter(lora_model_id, lora_order, false, None)
    }

    pub fn build(self) -> Box<dyn Loader> {
        Box::new(GGUFLoader {
            model_id: self.model_id,
            xlora_model_id: self.xlora_model_id,
            kind: self.kind,
            xlora_order: self.xlora_order,
            no_kv_cache: self.no_kv_cache,
            chat_template: self.chat_template,
            tgt_non_granular_index: self.tgt_non_granular_index,
            quantized_filenames: self.quantized_filenames,
            quantized_model_id: self.quantized_model_id,
            config: self.config,
            jinja_explicit: self.jinja_explicit,
            lora_adapter_ids: None,
        })
    }
}

impl GGUFLoader {
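    /// Create a GGUF loader directly. If `model_id` is `None`, the base model ID from the
    /// adapter ordering is used when one is available.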
    #[allow(clippy::too_many_arguments)]
    pub fn new(
        model_id: Option<String>,
        quantized_model_id: String,
        quantized_filenames: Vec<String>,
        xlora_model_id: Option<String>,
        kind: ModelKind,
        xlora_order: Option<Ordering>,
        no_kv_cache: bool,
        chat_template: Option<String>,
        tgt_non_granular_index: Option<usize>,
        config: GGUFSpecificConfig,
        jinja_explicit: Option<String>,
    ) -> Self {
        let model_id = if let Some(id) = model_id {
            Some(id)
        } else if let Some(xlora_order) = xlora_order.clone() {
            info!(
                "Using adapter base model ID: `{}`",
                xlora_order.base_model_id
            );
            Some(xlora_order.base_model_id.clone())
        } else {
            None
        };
        Self {
            model_id,
            quantized_model_id,
            quantized_filenames,
            xlora_model_id,
            xlora_order,
            no_kv_cache,
            chat_template,
            kind,
            tgt_non_granular_index,
            config,
            jinja_explicit,
            lora_adapter_ids: None,
        }
    }
}

impl Loader for GGUFLoader {
    #[allow(clippy::type_complexity, clippy::too_many_arguments)]
    fn load_model_from_hf(
        &self,
        revision: Option<String>,
        token_source: TokenSource,
        dtype: &dyn TryIntoDType,
        device: &Device,
        silent: bool,
        mapper: DeviceMapSetting,
        in_situ_quant: Option<IsqType>,
        paged_attn_config: Option<PagedAttentionConfig>,
    ) -> Result<Arc<Mutex<dyn Pipeline + Send + Sync>>> {
        let paths: anyhow::Result<Box<dyn ModelPaths>> = get_paths_gguf!(
            LocalModelPaths,
            &token_source,
            revision,
            self,
            self.quantized_model_id.clone(),
            self.quantized_filenames.clone(),
            silent
        );
        self.load_model_from_path(
            &paths?,
            dtype,
            device,
            silent,
            mapper,
            in_situ_quant,
            paged_attn_config,
        )
    }

    #[allow(clippy::type_complexity, clippy::too_many_arguments)]
    fn load_model_from_path(
        &self,
        paths: &Box<dyn ModelPaths>,
        dtype: &dyn TryIntoDType,
        device: &Device,
        silent: bool,
        mut mapper: DeviceMapSetting,
        in_situ_quant: Option<IsqType>,
        mut paged_attn_config: Option<PagedAttentionConfig>,
    ) -> Result<Arc<Mutex<dyn Pipeline + Send + Sync>>> {
        if in_situ_quant.is_some() {
            anyhow::bail!(
                "You are trying to in-situ quantize a GGUF model. This will not do anything."
            );
        }

        info!("Prompt chunk size is {ATTENTION_CHUNK_SIZE}.");

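        // Open every GGUF shard listed in the model paths and parse them as one `Content`.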
        let mut readers = Vec::new();
        for filename in paths.get_weight_filenames() {
            readers.push(std::fs::File::open(filename)?);
        }
        let mut readers = readers.iter_mut().collect::<Vec<_>>();

        let model = Content::from_readers(&mut readers)?;
        if !silent {
            model.print_metadata()?;
        }
        let arch = model.arch();

        let num_layers = model.get_metadata()[&format!("{arch}.block_count")].to_u32()? as usize;
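        // For automatic device mapping, estimate per-layer and non-mapped memory from the
        // GGUF tensors and distribute the layers across the available devices.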
        if let DeviceMapSetting::Auto(params) = mapper.clone() {
            let devices = device_map::get_all_similar_devices(device)?;
            let dtype = dtype.try_into_dtype(&devices.iter().collect::<Vec<_>>())?;

            let model = GgufDeviceMapLoaderInner {
                model: &model,
                arch,
            };

            let layer_sizes_in_bytes =
                model.layer_sizes_in_bytes("this is a dummy config!", dtype, 1, None)?;
            let non_mapped_size_in_bytes =
                model.non_mapped_size_in_bytes("this is a dummy config!", dtype, 1, None)?;
            let total_model_size_in_bytes =
                layer_sizes_in_bytes.iter().sum::<usize>() + non_mapped_size_in_bytes;

            let new = model.get_device_layers(
                "this is a dummy config!",
                num_layers,
                layer_sizes_in_bytes,
                non_mapped_size_in_bytes,
                total_model_size_in_bytes,
                &devices,
                dtype,
                &params,
                paged_attn_config.as_ref(),
            )?;
            mapper = DeviceMapSetting::Map(new);
        }

        #[cfg(feature = "cuda")]
        if let Device::Cuda(dev) = &device {
            unsafe { dev.disable_event_tracking() };
        }

        let pipeline_mapper =
            mapper.into_mapper(num_layers, device, self.config.topology.as_ref())?;
        let mapper = mapper.into_mapper(num_layers, device, self.config.topology.as_ref())?;
        let mut layer_devices = Vec::new();
        for layer in 0..num_layers {
            let device = mapper.device_for(layer, false).cloned();
            layer_devices.push(device);
        }

        let mapping_uses_cpu = mapper.get_unique_devices().iter().any(Device::is_cpu);
        if mapping_uses_cpu {
            warn!("Device mapping contains a mix of GPU and CPU. There is no CPU support for PagedAttention, disabling PagedAttention.");
            paged_attn_config = None;
        }

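        // Prefer an explicit tokenizer.json if one was resolved; otherwise reconstruct the
        // tokenizer (and BOS/EOS/UNK tokens) from the GGUF metadata.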
        let GgufTokenizerConversion {
            tokenizer,
            bos,
            eos,
            unk,
        } = if paths.get_tokenizer_filename().to_string_lossy().is_empty() {
            convert_gguf_to_hf_tokenizer(&model)?
        } else {
            GgufTokenizerConversion {
                tokenizer: get_tokenizer(paths.get_tokenizer_filename(), None)?,
                bos: None,
                eos: None,
                unk: None,
            }
        };

        let gguf_chat_template =
            if paths.get_template_filename().is_none() && self.chat_template.is_none() {
                get_gguf_chat_template(&model)?
            } else {
                None
            };

        let has_adapter = self.kind.is_adapted();
        let is_xlora = self.kind.is_adapted_and(|a| a.is_x_lora());

        let paged_attn_config = if matches!(self.kind, ModelKind::GgufAdapter { .. }) {
            warn!("Adapter models do not currently support PagedAttention, running without");
            None
        } else {
            paged_attn_config
        };

        let model_config_metadata: ContentConfig = (&model).into();
        let internal_dtype = mapper.get_min_dtype(dtype)?;

        let model_config = {
            let quant = ModelConfig::ParamsGGUF(
                model,
                (device, mapper).into(),
                if paged_attn_config.is_some() {
                    AttentionImplementation::PagedAttention
                } else {
                    AttentionImplementation::Eager
                },
                internal_dtype,
            );

            let mut adapter = None;
            if has_adapter {
                adapter.replace(ModelConfig::Adapter::try_new(
                    paths, device, silent, is_xlora,
                )?);
            }

            ModelConfig::ModelParams::new(quant, adapter)
        };

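        // Build the architecture-specific quantized weights, attaching adapter weights if present.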
        let model = match self.kind {
            ModelKind::GgufQuantized { .. } => match arch {
                GGUFArchitecture::Llama => Model::Llama(QLlama::try_from(model_config)?),
                GGUFArchitecture::Phi2 => Model::Phi2(QPhi::try_from(model_config)?),
                GGUFArchitecture::Phi3 => Model::Phi3(QPhi3::try_from(model_config)?),
                GGUFArchitecture::Starcoder2 => {
                    Model::Starcoder2(QStarcoder2::try_from(model_config)?)
                }
                GGUFArchitecture::Qwen2 => Model::Qwen(QQwen::try_from(model_config)?),
                GGUFArchitecture::Qwen3 => Model::Qwen3(QQwen3::try_from(model_config)?),
                GGUFArchitecture::Qwen3MoE => Model::Qwen3MoE(QQwen3MoE::try_from(model_config)?),
                a => bail!("Unsupported architecture `{a:?}` for GGUF"),
            },
            ModelKind::GgufAdapter { adapter, .. } => match arch {
                GGUFArchitecture::Llama => Model::XLoraLlama(XLoraQLlama::try_from(model_config)?),
                GGUFArchitecture::Phi3 => Model::XLoraPhi3(XLoraQPhi3::try_from(model_config)?),
                a => bail!(
                    "Unsupported architecture `{a:?}` for GGUF {kind}",
                    kind = adapter.pretty_name()
                ),
            },
            _ => unreachable!(),
        };

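        // With PagedAttention enabled, size the KV cache blocks and construct the cache engine.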
        let (cache_config, cache_engine) = if let Some(paged_attn_config) = paged_attn_config {
            let model_config: &dyn ModelConfigLike = &model_config_metadata;
            let cache_config = calculate_cache_config(
                paged_attn_config.mem_gpu,
                paged_attn_config.mem_cpu,
                paged_attn_config.block_size,
                internal_dtype,
                paged_attn_config.cache_type,
                model_config,
                device,
                &layer_devices,
                silent,
            )?;
            let cache_engine = CacheEngine::new(
                model_config,
                &cache_config,
                internal_dtype,
                device,
                layer_devices,
            )?;
            (Some(cache_config), Some(cache_engine))
        } else {
            (None, None)
        };

        let gen_conf: Option<GenerationConfig> = paths
            .get_gen_conf_filename()
            .map(|f| serde_json::from_str(&fs::read_to_string(f).unwrap()).unwrap());
        let chat_template_explicit = paths
            .get_chat_template_explicit()
            .as_ref()
            .map(|x| x.to_string_lossy().to_string());
        let mut chat_template = get_chat_template(
            paths,
            self.jinja_explicit.as_ref(),
            chat_template_explicit.as_ref(),
            self.chat_template.as_ref(),
            gguf_chat_template,
        );

        let max_seq_len = match model {
            Model::Llama(ref l) => l.max_seq_len,
            Model::Phi2(ref p) => p.max_seq_len,
            Model::XLoraLlama(ref xl) => xl.max_seq_len,
            Model::Phi3(ref p) => p.max_seq_len,
            Model::XLoraPhi3(ref p) => p.max_seq_len,
            Model::Starcoder2(ref p) => p.max_seq_len,
            Model::Qwen(ref p) => p.max_seq_len,
            Model::Qwen3(ref p) => p.max_seq_len,
            Model::Qwen3MoE(ref p) => p.max_seq_len,
        };
        let llg_factory = build_llg_factory(tokenizer.clone())?;
        let num_hidden_layers = match model {
            Model::Llama(ref model) => model.cache.normal().0.len(),
            Model::Phi2(ref model) => model.cache.normal().0.len(),
            Model::XLoraLlama(ref model) => model.cache.full().lock().len(),
            Model::Phi3(ref model) => model.cache.normal().0.len(),
            Model::XLoraPhi3(ref model) => model.cache.full().lock().len(),
            Model::Starcoder2(ref model) => model.cache.normal().0.len(),
            Model::Qwen(ref model) => model.cache.normal().0.len(),
            Model::Qwen3(ref model) => model.cache.normal().0.len(),
            Model::Qwen3MoE(ref model) => model.cache.normal().0.len(),
        };

        if chat_template.bos_token.is_none() {
            if let Some(v) = bos {
                chat_template.bos_token = Some(BeginEndUnkPadTok(Either::Left(v)));
            }
        }
        if chat_template.eos_token.is_none() {
            if let Some(v) = eos {
                chat_template.eos_token = Some(BeginEndUnkPadTok(Either::Left(v)));
            }
        }
        if chat_template.unk_token.is_none() {
            if let Some(v) = unk {
                chat_template.unk_token = Some(BeginEndUnkPadTok(Either::Left(v)));
            }
        }

        let eos = calculate_eos_tokens(&chat_template, gen_conf, &tokenizer);
        Ok(Arc::new(Mutex::new(GGUFPipeline {
            model,
            tokenizer: tokenizer.into(),
            no_kv_cache: self.no_kv_cache,
            chat_template: Arc::new(chat_template),
            model_id: self
                .model_id
                .clone()
                .unwrap_or(self.quantized_model_id.clone()),
            non_granular_state: self.tgt_non_granular_index.map(|tgt_non_granular_index| {
                NonGranularState {
                    non_granular_index: Arc::new(Mutex::new(0)),
                    tgt_non_granular_index,
                }
            }),
            metadata: Arc::new(GeneralMetadata {
                max_seq_len,
                llg_factory: Some(llg_factory),
                no_kv_cache: self.no_kv_cache,
                no_prefix_cache: false,
                num_hidden_layers,
                eos_tok: eos,
                kind: self.kind.clone(),
                is_xlora,
                activation_dtype: internal_dtype,
                sliding_window: None,
                cache_config,
                cache_engine,
                model_metadata: Some(Arc::new(model_config_metadata)),
                modalities: Modalities {
                    input: vec![SupportedModality::Text],
                    output: vec![SupportedModality::Text],
                },
            }),
            mapper: pipeline_mapper,
        })))
    }

    fn get_id(&self) -> String {
        self.xlora_model_id
            .as_deref()
            .unwrap_or(self.model_id.as_ref().unwrap_or(&self.quantized_model_id))
            .to_string()
    }

    fn get_kind(&self) -> ModelKind {
        self.kind.clone()
    }
}

impl PreProcessingMixin for GGUFPipeline {
    fn get_chat_template(&self) -> Option<Arc<ChatTemplate>> {
        Some(self.chat_template.clone())
    }
    fn get_input_processor_config(&self) -> Option<Arc<dyn Any>> {
        None
    }
}

impl IsqPipelineMixin for GGUFPipeline {
    fn re_isq_model(&mut self, _dtype: IsqType) -> Result<()> {
        anyhow::bail!(
            "You are trying to in-situ requantize a GGUF model. This will not do anything."
        )
    }
}

impl CacheManagerMixin for GGUFPipeline {
    fn clone_in_cache(&self, seqs: &mut [&mut Sequence]) {
        if matches!(self.cache(), EitherCache::Full(_)) {
            FullCacheManager.clone_in_cache(self, seqs, false)
        } else {
            NormalCacheManager.clone_in_cache(self, seqs, false)
        }
    }
    fn clone_out_cache(&self, seqs: &mut [&mut Sequence]) {
        if matches!(self.cache(), EitherCache::Full(_)) {
            FullCacheManager.clone_out_cache(self, seqs, false)
        } else {
            NormalCacheManager.clone_out_cache(self, seqs, false)
        }
    }
    fn set_none_cache(
        &self,
        seqs: &mut [&mut Sequence],
        reset_non_granular: bool,
        modify_draft_cache: bool,
        load_preallocated_cache: bool,
    ) {
        if matches!(self.cache(), EitherCache::Full(_)) {
            FullCacheManager.set_none_cache(self, seqs, modify_draft_cache, false);
        } else {
            NormalCacheManager.set_none_cache(
                self,
                seqs,
                modify_draft_cache,
                load_preallocated_cache,
            );
        }
        if reset_non_granular {
            self.reset_non_granular_state()
        }
    }
    fn cache(&self) -> &EitherCache {
        match self.model {
            Model::Llama(ref model) => &model.cache,
            Model::Phi2(ref model) => &model.cache,
            Model::XLoraLlama(ref model) => &model.cache,
            Model::Phi3(ref model) => &model.cache,
            Model::XLoraPhi3(ref model) => &model.cache,
            Model::Starcoder2(ref model) => &model.cache,
            Model::Qwen(ref model) => &model.cache,
            Model::Qwen3(ref model) => &model.cache,
            Model::Qwen3MoE(ref model) => &model.cache,
        }
    }
}

impl MetadataMixin for GGUFPipeline {
    fn device(&self) -> Device {
        match self.model {
            Model::Llama(ref model) => model.device.clone(),
            Model::Phi2(ref model) => model.device.clone(),
            Model::XLoraLlama(ref model) => model.device.clone(),
            Model::Phi3(ref model) => model.device.clone(),
            Model::XLoraPhi3(ref model) => model.device.clone(),
            Model::Starcoder2(ref model) => model.device.clone(),
            Model::Qwen(ref model) => model.device.clone(),
            Model::Qwen3(ref model) => model.device.clone(),
            Model::Qwen3MoE(ref model) => model.device.clone(),
        }
    }
    fn tokenizer(&self) -> Option<Arc<Tokenizer>> {
        Some(self.tokenizer.clone())
    }
    fn name(&self) -> String {
        self.model_id.clone()
    }
    fn reset_non_granular_state(&self) {
        if let Some(s) = self.non_granular_state.as_ref() {
            *self.cache().full().get_scalings_cache() = None;
            *get_mut_arcmutex!(s.non_granular_index) = 0;
        }
    }
    fn get_metadata(&self) -> Arc<GeneralMetadata> {
        self.metadata.clone()
    }
    fn device_mapper(&self) -> Option<&dyn DeviceMapper> {
        Some(&*self.mapper)
    }
}

#[async_trait::async_trait]
impl Pipeline for GGUFPipeline {
    fn forward_inputs(
        &mut self,
        inputs: Box<dyn Any>,
        return_raw_logits: bool,
    ) -> Result<ForwardInputsResult, candle_core::Error> {
        let ModelInputs {
            input_ids,
            input_ids_full,
            seqlen_offsets,
            seqlen_offsets_full,
            context_lens,
            position_ids: _,
            paged_attn_meta,
            flash_meta,
            flash_meta_full,
        } = *inputs.downcast().expect("Downcast failed.");
        let metadata = self.get_metadata();
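        // A cache engine and PagedAttention metadata must be present together: pair the
        // engine's KV cache with the per-batch metadata when both exist.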
        let paged_attn_meta = match (&metadata.cache_engine, &paged_attn_meta) {
            (Some(engine), Some(meta)) => Some((engine.get_kv_cache().clone(), meta)),
            (Some(_), None) => {
                candle_core::bail!("Forward step expected a PagedAttention input metadata. This was not provided, please ensure that the scheduler config is correctly configured for PagedAttention.")
            }
            (None, Some(_)) => {
                candle_core::bail!("Forward step got a PagedAttention input metadata but there is no cache engine. Please raise an issue.")
            }
            (None, None) => None,
        };
        let logits = match self.model {
            Model::Llama(ref model) => {
                model.forward(&input_ids, &seqlen_offsets, context_lens, paged_attn_meta)?
            }
            Model::Phi2(ref model) => {
                model.forward(&input_ids, &seqlen_offsets, context_lens, paged_attn_meta)?
            }
            Model::XLoraLlama(ref model) => model.forward(
                &input_ids,
                input_ids_full.as_ref().unwrap_or(&input_ids),
                &seqlen_offsets,
                seqlen_offsets_full.as_ref().unwrap_or(&seqlen_offsets),
                self.no_kv_cache,
                &self.non_granular_state,
                context_lens,
                &flash_meta,
                flash_meta_full.as_ref().unwrap_or(&flash_meta),
            )?,
            Model::Phi3(ref model) => {
                model.forward(&input_ids, &seqlen_offsets, paged_attn_meta)?
            }
            Model::XLoraPhi3(ref model) => model.forward(
                &input_ids,
                input_ids_full.as_ref().unwrap_or(&input_ids),
                &seqlen_offsets,
                seqlen_offsets_full.as_ref().unwrap_or(&seqlen_offsets),
                self.no_kv_cache,
                &self.non_granular_state,
                context_lens,
                &flash_meta,
                flash_meta_full.as_ref().unwrap_or(&flash_meta),
            )?,
            Model::Starcoder2(ref model) => {
                model.forward(&input_ids, &seqlen_offsets, paged_attn_meta)?
            }
            Model::Qwen(ref model) => {
                model.forward(&input_ids, &seqlen_offsets, context_lens, paged_attn_meta)?
            }
            Model::Qwen3(ref model) => {
                model.forward(&input_ids, &seqlen_offsets, context_lens, paged_attn_meta)?
            }
            Model::Qwen3MoE(ref model) => {
                model.forward(&input_ids, &seqlen_offsets, context_lens, paged_attn_meta)?
            }
        };
        if return_raw_logits {
            Ok(ForwardInputsResult::RawLogits { logits })
        } else {
            Ok(ForwardInputsResult::CausalGeneration { logits })
        }
    }
    async fn sample_causal_gen(
        &self,
        seqs: &mut [&mut Sequence],
        logits: Vec<Tensor>,
        prefix_cacher: &mut PrefixCacheManagerV2,
        disable_eos_stop: bool,
        rng: Arc<std::sync::Mutex<Isaac64Rng>>,
    ) -> Result<(), candle_core::Error> {
        sample_and_add_toks(self, seqs, logits, prefix_cacher, disable_eos_stop, rng).await
    }
    fn category(&self) -> ModelCategory {
        ModelCategory::Text
    }
}

impl AnyMoePipelineMixin for GGUFPipeline {}