use super::llg::build_llg_factory;
use super::{
    get_model_paths, get_xlora_paths, text_models_inputs_processor::ModelInputs, AdapterKind,
    CacheManager, GeneralMetadata, Loader, ModelKind, ModelPaths, PrettyName, QuantizationKind,
    TokenSource,
};
use super::{
    AnyMoePipelineMixin, CacheManagerMixin, EitherCache, ForwardInputsResult, IsqPipelineMixin,
    MetadataMixin, ModelCategory, PreProcessingMixin,
};
use crate::attention::ATTENTION_CHUNK_SIZE;
use crate::device_map::{self, DeviceMapper};
use crate::gguf::{
    get_gguf_chat_template, {convert_gguf_to_hf_tokenizer, GgufTokenizerConversion},
};
use crate::gguf::{Content, GGUFArchitecture};
use crate::kv_cache::{FullCacheManager, NormalCacheManager};
use crate::lora::Ordering;
use crate::paged_attention::{
    calculate_cache_config, AttentionImplementation, CacheEngine, ModelConfigLike,
};
use crate::pipeline::chat_template::{calculate_eos_tokens, BeginEndUnkPadTok, GenerationConfig};
use crate::pipeline::loaders::DeviceMappedModelLoader;
use crate::pipeline::sampling::sample_and_add_toks;
use crate::pipeline::ChatTemplate;
use crate::pipeline::{get_chat_template, Modalities, SupportedModality};
use crate::prefix_cacher::PrefixCacheManagerV2;
use crate::sequence::Sequence;
use crate::utils::gguf_metadata::{ContentConfig, GgufDeviceMapLoaderInner};
use crate::utils::model_config as ModelConfig;
use crate::utils::progress::ProgressScopeGuard;
use crate::utils::tokenizer::get_tokenizer;
use crate::xlora_models::NonGranularState;
use crate::{
    get_mut_arcmutex, get_paths_gguf, DeviceMapSetting, LocalModelPaths, PagedAttentionConfig,
    Pipeline, Topology, TryIntoDType,
};
use crate::{
    models::quantized_llama::ModelWeights as QLlama,
    models::quantized_phi2::ModelWeights as QPhi,
    models::quantized_phi3::ModelWeights as QPhi3,
    models::quantized_qwen::ModelWeights as QQwen,
    models::quantized_qwen3::ModelWeights as QQwen3,
    models::quantized_qwen3_moe::ModelWeights as QQwen3MoE,
    models::quantized_starcoder2::ModelWeights as QStarcoder2,
    utils::tokens::get_token,
    xlora_models::{XLoraQLlama, XLoraQPhi3},
};
use anyhow::{bail, Result};
use candle_core::{Device, Tensor};
use either::Either;
use hf_hub::{api::sync::ApiBuilder, Repo, RepoType};
use mistralrs_quant::IsqType;
use rand_isaac::Isaac64Rng;
use std::any::Any;
use std::fs;
use std::path::PathBuf;
use std::str::FromStr;
use std::sync::Arc;
use tokenizers::Tokenizer;
use tokio::sync::Mutex;
use tracing::{info, warn};

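/// The concrete quantized model implementations this GGUF pipeline can dispatch to,
/// including the X-LoRA adapter variants for Llama and Phi 3.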
enum Model {
    Llama(QLlama),
    Phi2(QPhi),
    XLoraLlama(XLoraQLlama),
    XLoraPhi3(XLoraQPhi3),
    Phi3(QPhi3),
    Starcoder2(QStarcoder2),
    Qwen(QQwen),
    Qwen3(QQwen3),
    Qwen3MoE(QQwen3MoE),
}

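/// A loaded, ready-to-run text generation pipeline backed by a GGUF-quantized model.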
pub struct GGUFPipeline {
    model: Model,
    tokenizer: Arc<Tokenizer>,
    no_kv_cache: bool,
    chat_template: Arc<ChatTemplate>,
    model_id: String,
    non_granular_state: Option<NonGranularState>,
    metadata: Arc<GeneralMetadata>,
    mapper: Box<dyn DeviceMapper + Send + Sync>,
}

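/// A [`Loader`] for GGUF-quantized models, optionally with X-LoRA or LoRA adapters.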
pub struct GGUFLoader {
    model_id: Option<String>,
    quantized_model_id: String,
    quantized_filenames: Vec<String>,
    xlora_model_id: Option<String>,
    xlora_order: Option<Ordering>,
    no_kv_cache: bool,
    chat_template: Option<String>,
    kind: ModelKind,
    tgt_non_granular_index: Option<usize>,
    config: GGUFSpecificConfig,
    jinja_explicit: Option<String>,
    lora_adapter_ids: Option<Vec<String>>,
}

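/// Configuration specific to loading GGUF models.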
#[derive(Clone, Default)]
pub struct GGUFSpecificConfig {
    pub topology: Option<Topology>,
}

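/// Builder for a [`GGUFLoader`], the usual entry point for constructing a GGUF pipeline loader.
///
/// A minimal usage sketch; the re-export path, model ID, and filename below are illustrative
/// assumptions rather than values taken from this crate:
///
/// ```ignore
/// use mistralrs_core::{GGUFLoaderBuilder, GGUFSpecificConfig};
///
/// // Hypothetical Hugging Face repo and quantized filename.
/// let loader = GGUFLoaderBuilder::new(
///     None,                                   // chat_template: fall back to GGUF metadata
///     None,                                   // tok_model_id: derive tokenizer from the GGUF
///     "some-org/some-model-GGUF".to_string(), // quantized_model_id
///     vec!["model.Q4_K_M.gguf".to_string()],  // quantized_filenames
///     GGUFSpecificConfig::default(),
///     false,                                  // no_kv_cache
///     None,                                   // jinja_explicit
/// )
/// .build();
/// ```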
#[derive(Default)]
pub struct GGUFLoaderBuilder {
    model_id: Option<String>,
    quantized_model_id: String,
    quantized_filenames: Vec<String>,
    xlora_model_id: Option<String>,
    kind: ModelKind,
    xlora_order: Option<Ordering>,
    no_kv_cache: bool,
    chat_template: Option<String>,
    tgt_non_granular_index: Option<usize>,
    config: GGUFSpecificConfig,
    jinja_explicit: Option<String>,
}

impl GGUFLoaderBuilder {
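    /// Create a builder for a plain (non-adapter) GGUF pipeline. `tok_model_id` optionally
    /// names a model ID supplying the tokenizer and chat template; when it is `None`, both are
    /// recovered from the GGUF metadata where possible. A supplied `chat_template` takes
    /// precedence over any template embedded in the GGUF file.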
    pub fn new(
        chat_template: Option<String>,
        tok_model_id: Option<String>,
        quantized_model_id: String,
        quantized_filenames: Vec<String>,
        config: GGUFSpecificConfig,
        no_kv_cache: bool,
        jinja_explicit: Option<String>,
    ) -> Self {
        let kind = ModelKind::GgufQuantized {
            quant: QuantizationKind::Gguf,
        };

        Self {
            chat_template,
            model_id: tok_model_id,
            kind,
            quantized_filenames,
            quantized_model_id,
            config,
            jinja_explicit,
            no_kv_cache,
            ..Default::default()
        }
    }

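    /// Shared adapter wiring for [`Self::with_xlora`] and [`Self::with_lora`]: records the
    /// adapter model ID and ordering, and falls back to the adapter's base model ID when no
    /// tokenizer model ID was provided.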
    fn with_adapter(
        mut self,
        xlora_model_id: String,
        xlora_order: Ordering,
        no_kv_cache: bool,
        tgt_non_granular_index: Option<usize>,
    ) -> Self {
        self.xlora_model_id = Some(xlora_model_id);
        self.xlora_order = Some(xlora_order);
        self.no_kv_cache = no_kv_cache;
        self.tgt_non_granular_index = tgt_non_granular_index;
        self.model_id = if let Some(id) = self.model_id {
            Some(id)
        } else {
            info!(
                "Using adapter base model ID: `{}`",
                self.xlora_order.as_ref().unwrap().base_model_id
            );
            Some(self.xlora_order.as_ref().unwrap().base_model_id.clone())
        };
        self
    }

    pub fn with_xlora(
        mut self,
        xlora_model_id: String,
        xlora_order: Ordering,
        no_kv_cache: bool,
        tgt_non_granular_index: Option<usize>,
    ) -> Self {
        self.kind = (AdapterKind::XLora, QuantizationKind::Gguf).into();

        self.with_adapter(
            xlora_model_id,
            xlora_order,
            no_kv_cache,
            tgt_non_granular_index,
        )
    }

    pub fn with_lora(mut self, lora_model_id: String, lora_order: Ordering) -> Self {
        self.kind = (AdapterKind::Lora, QuantizationKind::Gguf).into();

        self.with_adapter(lora_model_id, lora_order, false, None)
    }

    pub fn build(self) -> Box<dyn Loader> {
        Box::new(GGUFLoader {
            model_id: self.model_id,
            xlora_model_id: self.xlora_model_id,
            kind: self.kind,
            xlora_order: self.xlora_order,
            no_kv_cache: self.no_kv_cache,
            chat_template: self.chat_template,
            tgt_non_granular_index: self.tgt_non_granular_index,
            quantized_filenames: self.quantized_filenames,
            quantized_model_id: self.quantized_model_id,
            config: self.config,
            jinja_explicit: self.jinja_explicit,
            lora_adapter_ids: None,
        })
    }
}

impl GGUFLoader {
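    /// Create a GGUF loader directly. [`GGUFLoaderBuilder`] is the more convenient way to
    /// construct one; this mirrors the builder's fields, with the model ID falling back to the
    /// adapter ordering's base model ID when it is not given.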
    #[allow(clippy::too_many_arguments)]
    pub fn new(
        model_id: Option<String>,
        quantized_model_id: String,
        quantized_filenames: Vec<String>,
        xlora_model_id: Option<String>,
        kind: ModelKind,
        xlora_order: Option<Ordering>,
        no_kv_cache: bool,
        chat_template: Option<String>,
        tgt_non_granular_index: Option<usize>,
        config: GGUFSpecificConfig,
        jinja_explicit: Option<String>,
    ) -> Self {
        let model_id = if let Some(id) = model_id {
            Some(id)
        } else if let Some(xlora_order) = xlora_order.clone() {
            info!(
                "Using adapter base model ID: `{}`",
                xlora_order.base_model_id
            );
            Some(xlora_order.base_model_id.clone())
        } else {
            None
        };
        Self {
            model_id,
            quantized_model_id,
            quantized_filenames,
            xlora_model_id,
            xlora_order,
            no_kv_cache,
            chat_template,
            kind,
            tgt_non_granular_index,
            config,
            jinja_explicit,
            lora_adapter_ids: None,
        }
    }
}

impl Loader for GGUFLoader {
    #[allow(clippy::type_complexity, clippy::too_many_arguments)]
    fn load_model_from_hf(
        &self,
        revision: Option<String>,
        token_source: TokenSource,
        dtype: &dyn TryIntoDType,
        device: &Device,
        silent: bool,
        mapper: DeviceMapSetting,
        in_situ_quant: Option<IsqType>,
        paged_attn_config: Option<PagedAttentionConfig>,
    ) -> Result<Arc<Mutex<dyn Pipeline + Send + Sync>>> {
        let _progress_guard = ProgressScopeGuard::new(silent);
        let paths: anyhow::Result<Box<dyn ModelPaths>> = get_paths_gguf!(
            LocalModelPaths,
            &token_source,
            revision,
            self,
            self.quantized_model_id.clone(),
            self.quantized_filenames.clone(),
            silent
        );
        self.load_model_from_path(
            &paths?,
            dtype,
            device,
            silent,
            mapper,
            in_situ_quant,
            paged_attn_config,
        )
    }

    #[allow(clippy::type_complexity, clippy::too_many_arguments)]
    fn load_model_from_path(
        &self,
        paths: &Box<dyn ModelPaths>,
        dtype: &dyn TryIntoDType,
        device: &Device,
        silent: bool,
        mut mapper: DeviceMapSetting,
        in_situ_quant: Option<IsqType>,
        mut paged_attn_config: Option<PagedAttentionConfig>,
    ) -> Result<Arc<Mutex<dyn Pipeline + Send + Sync>>> {
        let _progress_guard = ProgressScopeGuard::new(silent);
        if in_situ_quant.is_some() {
            anyhow::bail!(
                "You are trying to in-situ quantize a GGUF model. This will not do anything."
            );
        }

        info!("Prompt chunk size is {ATTENTION_CHUNK_SIZE}.");

        let mut readers = Vec::new();
        for filename in paths.get_weight_filenames() {
            readers.push(std::fs::File::open(filename)?);
        }
        let mut readers = readers.iter_mut().collect::<Vec<_>>();

        let model = Content::from_readers(&mut readers)?;
        if !silent {
            model.print_metadata()?;
        }
        let arch = model.arch();

        let num_layers = model.get_metadata()[&format!("{arch}.block_count")].to_u32()? as usize;

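        // For automatic device mapping, estimate per-layer and non-mapped sizes from the GGUF
        // metadata and turn them into a concrete layer-to-device map before loading weights.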
        if let DeviceMapSetting::Auto(params) = mapper.clone() {
            let devices = device_map::get_all_similar_devices(device)?;
            let dtype = dtype.try_into_dtype(&devices.iter().collect::<Vec<_>>())?;

            let model = GgufDeviceMapLoaderInner {
                model: &model,
                arch,
            };

            let layer_sizes_in_bytes =
                model.layer_sizes_in_bytes("this is a dummy config!", dtype, 1, None)?;
            let non_mapped_size_in_bytes =
                model.non_mapped_size_in_bytes("this is a dummy config!", dtype, 1, None)?;
            let total_model_size_in_bytes =
                layer_sizes_in_bytes.iter().sum::<usize>() + non_mapped_size_in_bytes;

            let new = model.get_device_layers(
                "this is a dummy config!",
                num_layers,
                layer_sizes_in_bytes,
                non_mapped_size_in_bytes,
                total_model_size_in_bytes,
                &devices,
                dtype,
                &params,
                paged_attn_config.as_ref(),
            )?;
            mapper = DeviceMapSetting::Map(new);
        }

        #[cfg(feature = "cuda")]
        if let Device::Cuda(dev) = &device {
            unsafe { dev.disable_event_tracking() };
        }

        let pipeline_mapper =
            mapper.into_mapper(num_layers, device, self.config.topology.as_ref())?;
        let mapper = mapper.into_mapper(num_layers, device, self.config.topology.as_ref())?;
        let mut layer_devices = Vec::new();
        for layer in 0..num_layers {
            let device = mapper.device_for(layer, false).cloned();
            layer_devices.push(device);
        }

        let mapping_uses_cpu = mapper.get_unique_devices().iter().any(Device::is_cpu);
        if mapping_uses_cpu {
            warn!("Device mapping contains a mix of GPU and CPU. There is no CPU support for PagedAttention, disabling PagedAttention.");
            paged_attn_config = None;
        }

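        // Prefer an external tokenizer.json when one was resolved; otherwise reconstruct the
        // tokenizer (and its special tokens) from the GGUF metadata.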
        let GgufTokenizerConversion {
            tokenizer,
            bos,
            eos,
            unk,
        } = if paths.get_tokenizer_filename().to_string_lossy().is_empty() {
            convert_gguf_to_hf_tokenizer(&model)?
        } else {
            GgufTokenizerConversion {
                tokenizer: get_tokenizer(paths.get_tokenizer_filename(), None)?,
                bos: None,
                eos: None,
                unk: None,
            }
        };

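        // Only consult the chat template embedded in the GGUF metadata when neither a template
        // file nor an explicit template override was provided.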
        let gguf_chat_template =
            if paths.get_template_filename().is_none() && self.chat_template.is_none() {
                get_gguf_chat_template(&model)?
            } else {
                None
            };

        let has_adapter = self.kind.is_adapted();
        let is_xlora = self.kind.is_adapted_and(|a| a.is_x_lora());

        let paged_attn_config = if matches!(self.kind, ModelKind::GgufAdapter { .. }) {
            warn!("Adapter models do not currently support PagedAttention, running without");
            None
        } else {
            paged_attn_config
        };

        let model_config_metadata: ContentConfig = (&model).into();
        let internal_dtype = mapper.get_min_dtype(dtype)?;

        let model_config = {
            let quant = ModelConfig::ParamsGGUF(
                model,
                (device, mapper).into(),
                if paged_attn_config.is_some() {
                    AttentionImplementation::PagedAttention
                } else {
                    AttentionImplementation::Eager
                },
                internal_dtype,
            );

            let mut adapter = None;
            if has_adapter {
                adapter.replace(ModelConfig::Adapter::try_new(
                    paths, device, silent, is_xlora,
                )?);
            }

            ModelConfig::ModelParams::new(quant, adapter)
        };

        let model = match self.kind {
            ModelKind::GgufQuantized { .. } => match arch {
                GGUFArchitecture::Llama => Model::Llama(QLlama::try_from(model_config)?),
                GGUFArchitecture::Phi2 => Model::Phi2(QPhi::try_from(model_config)?),
                GGUFArchitecture::Phi3 => Model::Phi3(QPhi3::try_from(model_config)?),
                GGUFArchitecture::Starcoder2 => {
                    Model::Starcoder2(QStarcoder2::try_from(model_config)?)
                }
                GGUFArchitecture::Qwen2 => Model::Qwen(QQwen::try_from(model_config)?),
                GGUFArchitecture::Qwen3 => Model::Qwen3(QQwen3::try_from(model_config)?),
                GGUFArchitecture::Qwen3MoE => Model::Qwen3MoE(QQwen3MoE::try_from(model_config)?),
                a => bail!("Unsupported architecture `{a:?}` for GGUF"),
            },
            ModelKind::GgufAdapter { adapter, .. } => match arch {
                GGUFArchitecture::Llama => Model::XLoraLlama(XLoraQLlama::try_from(model_config)?),
                GGUFArchitecture::Phi3 => Model::XLoraPhi3(XLoraQPhi3::try_from(model_config)?),
                a => bail!(
                    "Unsupported architecture `{a:?}` for GGUF {kind}",
                    kind = adapter.pretty_name()
                ),
            },
            _ => unreachable!(),
        };

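        // When PagedAttention is in use, size the KV cache from the requested GPU memory and
        // block size, then build a cache engine spanning the mapped layer devices.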
        let (cache_config, cache_engine) = if let Some(paged_attn_config) = paged_attn_config {
            let model_config: &dyn ModelConfigLike = &model_config_metadata;
            let cache_config = calculate_cache_config(
                paged_attn_config.mem_gpu,
                paged_attn_config.block_size,
                internal_dtype,
                paged_attn_config.cache_type,
                model_config,
                device,
                &layer_devices,
                silent,
            )?;
            let cache_engine = CacheEngine::new(
                model_config,
                &cache_config,
                internal_dtype,
                device,
                layer_devices,
            )?;
            (Some(cache_config), Some(cache_engine))
        } else {
            (None, None)
        };

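        // Load the optional generation config and resolve the chat template, preferring the
        // explicit sources supplied to the loader over the template found in the GGUF metadata.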
        let gen_conf: Option<GenerationConfig> = paths
            .get_gen_conf_filename()
            .map(|f| serde_json::from_str(&fs::read_to_string(f).unwrap()).unwrap());
        let chat_template_explicit = paths
            .get_chat_template_explicit()
            .as_ref()
            .map(|x| x.to_string_lossy().to_string());
        let mut chat_template = get_chat_template(
            paths,
            self.jinja_explicit.as_ref(),
            chat_template_explicit.as_ref(),
            self.chat_template.as_ref(),
            gguf_chat_template,
        );

        let max_seq_len = match model {
            Model::Llama(ref l) => l.max_seq_len,
            Model::Phi2(ref p) => p.max_seq_len,
            Model::XLoraLlama(ref xl) => xl.max_seq_len,
            Model::Phi3(ref p) => p.max_seq_len,
            Model::XLoraPhi3(ref p) => p.max_seq_len,
            Model::Starcoder2(ref p) => p.max_seq_len,
            Model::Qwen(ref p) => p.max_seq_len,
            Model::Qwen3(ref p) => p.max_seq_len,
            Model::Qwen3MoE(ref p) => p.max_seq_len,
        };
        let llg_factory = build_llg_factory(tokenizer.clone())?;
        let num_hidden_layers = match model {
            Model::Llama(ref model) => model.cache.normal().0.len(),
            Model::Phi2(ref model) => model.cache.normal().0.len(),
            Model::XLoraLlama(ref model) => model.cache.full().lock().len(),
            Model::Phi3(ref model) => model.cache.normal().0.len(),
            Model::XLoraPhi3(ref model) => model.cache.full().lock().len(),
            Model::Starcoder2(ref model) => model.cache.normal().0.len(),
            Model::Qwen(ref model) => model.cache.normal().0.len(),
            Model::Qwen3(ref model) => model.cache.normal().0.len(),
            Model::Qwen3MoE(ref model) => model.cache.normal().0.len(),
        };

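        // Backfill BOS/EOS/UNK tokens recovered from the GGUF tokenizer metadata when the chat
        // template does not already define them.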
        if chat_template.bos_token.is_none() {
            if let Some(v) = bos {
                chat_template.bos_token = Some(BeginEndUnkPadTok(Either::Left(v)));
            }
        }
        if chat_template.eos_token.is_none() {
            if let Some(v) = eos {
                chat_template.eos_token = Some(BeginEndUnkPadTok(Either::Left(v)));
            }
        }
        if chat_template.unk_token.is_none() {
            if let Some(v) = unk {
                chat_template.unk_token = Some(BeginEndUnkPadTok(Either::Left(v)));
            }
        }

        let eos = calculate_eos_tokens(&chat_template, gen_conf, &tokenizer);
        Ok(Arc::new(Mutex::new(GGUFPipeline {
            model,
            tokenizer: tokenizer.into(),
            no_kv_cache: self.no_kv_cache,
            chat_template: Arc::new(chat_template),
            model_id: self
                .model_id
                .clone()
                .unwrap_or(self.quantized_model_id.clone()),
            non_granular_state: self.tgt_non_granular_index.map(|tgt_non_granular_index| {
                NonGranularState {
                    non_granular_index: Arc::new(Mutex::new(0)),
                    tgt_non_granular_index,
                }
            }),
            metadata: Arc::new(GeneralMetadata {
                max_seq_len,
                llg_factory: Some(llg_factory),
                no_kv_cache: self.no_kv_cache,
                no_prefix_cache: false,
                num_hidden_layers,
                eos_tok: eos,
                kind: self.kind.clone(),
                is_xlora,
                activation_dtype: internal_dtype,
                sliding_window: None,
                cache_config,
                cache_engine,
                model_metadata: Some(Arc::new(model_config_metadata)),
                modalities: Modalities {
                    input: vec![SupportedModality::Text],
                    output: vec![SupportedModality::Text],
                },
            }),
            mapper: pipeline_mapper,
        })))
    }

    fn get_id(&self) -> String {
        self.xlora_model_id
            .as_deref()
            .unwrap_or(self.model_id.as_ref().unwrap_or(&self.quantized_model_id))
            .to_string()
    }

    fn get_kind(&self) -> ModelKind {
        self.kind.clone()
    }
}

impl PreProcessingMixin for GGUFPipeline {
    fn get_chat_template(&self) -> Option<Arc<ChatTemplate>> {
        Some(self.chat_template.clone())
    }
    fn get_input_processor_config(&self) -> Option<Arc<dyn Any>> {
        None
    }
}

impl IsqPipelineMixin for GGUFPipeline {
    fn re_isq_model(&mut self, _dtype: IsqType) -> Result<()> {
        anyhow::bail!(
            "You are trying to in-situ requantize a GGUF model. This will not do anything."
        )
    }
}

impl CacheManagerMixin for GGUFPipeline {
    fn clone_in_cache(&self, seqs: &mut [&mut Sequence]) {
        if matches!(self.cache(), EitherCache::Full(_)) {
            FullCacheManager.clone_in_cache(self, seqs, false)
        } else {
            NormalCacheManager.clone_in_cache(self, seqs, false)
        }
    }
    fn clone_out_cache(&self, seqs: &mut [&mut Sequence]) {
        if matches!(self.cache(), EitherCache::Full(_)) {
            FullCacheManager.clone_out_cache(self, seqs, false)
        } else {
            NormalCacheManager.clone_out_cache(self, seqs, false)
        }
    }
    fn set_none_cache(
        &self,
        seqs: &mut [&mut Sequence],
        reset_non_granular: bool,
        modify_draft_cache: bool,
        load_preallocated_cache: bool,
    ) {
        if matches!(self.cache(), EitherCache::Full(_)) {
            FullCacheManager.set_none_cache(self, seqs, modify_draft_cache, false);
        } else {
            NormalCacheManager.set_none_cache(
                self,
                seqs,
                modify_draft_cache,
                load_preallocated_cache,
            );
        }
        if reset_non_granular {
            self.reset_non_granular_state()
        }
    }
    fn cache(&self) -> &EitherCache {
        match self.model {
            Model::Llama(ref model) => &model.cache,
            Model::Phi2(ref model) => &model.cache,
            Model::XLoraLlama(ref model) => &model.cache,
            Model::Phi3(ref model) => &model.cache,
            Model::XLoraPhi3(ref model) => &model.cache,
            Model::Starcoder2(ref model) => &model.cache,
            Model::Qwen(ref model) => &model.cache,
            Model::Qwen3(ref model) => &model.cache,
            Model::Qwen3MoE(ref model) => &model.cache,
        }
    }
}

impl MetadataMixin for GGUFPipeline {
    fn device(&self) -> Device {
        match self.model {
            Model::Llama(ref model) => model.device.clone(),
            Model::Phi2(ref model) => model.device.clone(),
            Model::XLoraLlama(ref model) => model.device.clone(),
            Model::Phi3(ref model) => model.device.clone(),
            Model::XLoraPhi3(ref model) => model.device.clone(),
            Model::Starcoder2(ref model) => model.device.clone(),
            Model::Qwen(ref model) => model.device.clone(),
            Model::Qwen3(ref model) => model.device.clone(),
            Model::Qwen3MoE(ref model) => model.device.clone(),
        }
    }
    fn tokenizer(&self) -> Option<Arc<Tokenizer>> {
        Some(self.tokenizer.clone())
    }
    fn name(&self) -> String {
        self.model_id.clone()
    }
    fn reset_non_granular_state(&self) {
        if let Some(s) = self.non_granular_state.as_ref() {
            *self.cache().full().get_scalings_cache() = None;
            *get_mut_arcmutex!(s.non_granular_index) = 0;
        }
    }
    fn get_metadata(&self) -> Arc<GeneralMetadata> {
        self.metadata.clone()
    }
    fn device_mapper(&self) -> Option<&dyn DeviceMapper> {
        Some(&*self.mapper)
    }
}

#[async_trait::async_trait]
impl Pipeline for GGUFPipeline {
    fn forward_inputs(
        &mut self,
        inputs: Box<dyn Any>,
        return_raw_logits: bool,
    ) -> Result<ForwardInputsResult, candle_core::Error> {
        let ModelInputs {
            input_ids,
            input_ids_full,
            seqlen_offsets,
            seqlen_offsets_full,
            context_lens,
            position_ids: _,
            paged_attn_meta,
            flash_meta,
            flash_meta_full,
        } = *inputs.downcast().expect("Downcast failed.");
        let metadata = self.get_metadata();
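        // PagedAttention metadata must be paired with a cache engine: both present (paged
        // path) or both absent (eager path); any other combination is an error.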
        let paged_attn_meta = match (&metadata.cache_engine, &paged_attn_meta) {
            (Some(engine), Some(meta)) => Some((engine.get_kv_cache().clone(), meta)),
            (Some(_), None) => {
                candle_core::bail!("Forward step expected PagedAttention input metadata. This was not provided; please ensure that the scheduler config is correctly configured for PagedAttention.")
            }
            (None, Some(_)) => {
                candle_core::bail!("Forward step received PagedAttention input metadata but there is no cache engine. Please raise an issue.")
            }
            (None, None) => None,
        };
        let logits = match self.model {
            Model::Llama(ref model) => {
                model.forward(&input_ids, &seqlen_offsets, context_lens, paged_attn_meta)?
            }
            Model::Phi2(ref model) => {
                model.forward(&input_ids, &seqlen_offsets, context_lens, paged_attn_meta)?
            }
            Model::XLoraLlama(ref model) => model.forward(
                &input_ids,
                input_ids_full.as_ref().unwrap_or(&input_ids),
                &seqlen_offsets,
                seqlen_offsets_full.as_ref().unwrap_or(&seqlen_offsets),
                self.no_kv_cache,
                &self.non_granular_state,
                context_lens,
                &flash_meta,
                flash_meta_full.as_ref().unwrap_or(&flash_meta),
            )?,
            Model::Phi3(ref model) => {
                model.forward(&input_ids, &seqlen_offsets, paged_attn_meta)?
            }
            Model::XLoraPhi3(ref model) => model.forward(
                &input_ids,
                input_ids_full.as_ref().unwrap_or(&input_ids),
                &seqlen_offsets,
                seqlen_offsets_full.as_ref().unwrap_or(&seqlen_offsets),
                self.no_kv_cache,
                &self.non_granular_state,
                context_lens,
                &flash_meta,
                flash_meta_full.as_ref().unwrap_or(&flash_meta),
            )?,
            Model::Starcoder2(ref model) => {
                model.forward(&input_ids, &seqlen_offsets, paged_attn_meta)?
            }
            Model::Qwen(ref model) => {
                model.forward(&input_ids, &seqlen_offsets, context_lens, paged_attn_meta)?
            }
            Model::Qwen3(ref model) => {
                model.forward(&input_ids, &seqlen_offsets, context_lens, paged_attn_meta)?
            }
            Model::Qwen3MoE(ref model) => {
                model.forward(&input_ids, &seqlen_offsets, context_lens, paged_attn_meta)?
            }
        };
        if return_raw_logits {
            Ok(ForwardInputsResult::RawLogits { logits })
        } else {
            Ok(ForwardInputsResult::CausalGeneration { logits })
        }
    }
    async fn sample_causal_gen(
        &self,
        seqs: &mut [&mut Sequence],
        logits: Vec<Tensor>,
        prefix_cacher: &mut PrefixCacheManagerV2,
        disable_eos_stop: bool,
        rng: Arc<std::sync::Mutex<Isaac64Rng>>,
    ) -> Result<(), candle_core::Error> {
        sample_and_add_toks(self, seqs, logits, prefix_cacher, disable_eos_stop, rng).await
    }
    fn category(&self) -> ModelCategory {
        ModelCategory::Text
    }
}

impl AnyMoePipelineMixin for GGUFPipeline {}