use super::llg::build_llg_factory;
use super::{
    get_model_paths, get_xlora_paths, text_models_inputs_processor::ModelInputs, AdapterKind,
    CacheManager, GeneralMetadata, Loader, ModelKind, ModelPaths, PrettyName, QuantizationKind,
    TokenSource,
};
use super::{
    AnyMoePipelineMixin, CacheManagerMixin, EitherCache, ForwardInputsResult, IsqPipelineMixin,
    MetadataMixin, ModelCategory, PreProcessingMixin,
};
use crate::device_map::{self, DeviceMapper};
use crate::gguf::{
    get_gguf_chat_template, {convert_gguf_to_hf_tokenizer, GgufTokenizerConversion},
};
use crate::gguf::{Content, GGUFArchitecture};
use crate::kv_cache::{FullCacheManager, NormalCacheManager};
use crate::lora::Ordering;
use crate::paged_attention::{
    calculate_cache_config, AttentionImplementation, CacheEngine, ModelConfigLike,
};
use crate::pipeline::chat_template::{calculate_eos_tokens, BeginEndUnkPadTok, GenerationConfig};
use crate::pipeline::inputs_processor::DEFAULT_PROMPT_CHUNK_SIZE;
use crate::pipeline::loaders::DeviceMappedModelLoader;
use crate::pipeline::sampling::sample_and_add_toks;
use crate::pipeline::ChatTemplate;
use crate::pipeline::{get_chat_template, Modalities, SupportedModality};
use crate::prefix_cacher::PrefixCacheManagerV2;
use crate::sequence::Sequence;
use crate::utils::gguf_metadata::{ContentConfig, GgufDeviceMapLoaderInner};
use crate::utils::model_config as ModelConfig;
use crate::utils::tokenizer::get_tokenizer;
use crate::xlora_models::NonGranularState;
use crate::{
    get_mut_arcmutex, get_paths_gguf, DeviceMapSetting, LocalModelPaths, PagedAttentionConfig,
    Pipeline, Topology, TryIntoDType,
};
use crate::{
    models::quantized_llama::ModelWeights as QLlama,
    models::quantized_phi2::ModelWeights as QPhi,
    models::quantized_phi3::ModelWeights as QPhi3,
    models::quantized_qwen::ModelWeights as QQwen,
    models::quantized_qwen3::ModelWeights as QQwen3,
    models::quantized_starcoder2::ModelWeights as QStarcoder2,
    utils::tokens::get_token,
    xlora_models::{XLoraQLlama, XLoraQPhi3},
};
use anyhow::{bail, Result};
use candle_core::{Device, Tensor};
use either::Either;
use hf_hub::{api::sync::ApiBuilder, Repo, RepoType};
use mistralrs_quant::IsqType;
use rand_isaac::Isaac64Rng;
use std::any::Any;
use std::fs;
use std::num::{NonZero, NonZeroUsize};
use std::path::PathBuf;
use std::str::FromStr;
use std::sync::Arc;
use tokenizers::Tokenizer;
use tokio::sync::Mutex;
use tracing::{info, warn};

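/// Quantized model weights supported by the GGUF pipeline, one variant per GGUF architecture,
/// plus X-LoRA-adapted variants for the Llama and Phi3 architectures.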
enum Model {
    Llama(QLlama),
    Phi2(QPhi),
    XLoraLlama(XLoraQLlama),
    XLoraPhi3(XLoraQPhi3),
    Phi3(QPhi3),
    Starcoder2(QStarcoder2),
    Qwen(QQwen),
    Qwen3(QQwen3),
}

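/// A loaded, ready-to-run GGUF text pipeline: the quantized weights together with the tokenizer,
/// chat template, device mapper, and general metadata.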
pub struct GGUFPipeline {
    model: Model,
    tokenizer: Arc<Tokenizer>,
    no_kv_cache: bool,
    chat_template: Arc<ChatTemplate>,
    model_id: String,
    non_granular_state: Option<NonGranularState>,
    metadata: Arc<GeneralMetadata>,
    mapper: Box<dyn DeviceMapper + Send + Sync>,
}

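/// A [`Loader`] for GGUF models: everything needed to resolve the quantized files, tokenizer,
/// chat template, and optional (X-)LoRA adapters before loading.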
pub struct GGUFLoader {
    model_id: Option<String>,
    quantized_model_id: String,
    quantized_filenames: Vec<String>,
    xlora_model_id: Option<String>,
    xlora_order: Option<Ordering>,
    no_kv_cache: bool,
    chat_template: Option<String>,
    kind: ModelKind,
    tgt_non_granular_index: Option<usize>,
    config: GGUFSpecificConfig,
    jinja_explicit: Option<String>,
    lora_adapter_ids: Option<Vec<String>>,
}

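/// GGUF-specific loading options: the prompt chunk size (how many prompt tokens are processed per
/// forward pass during prefill) and an optional device topology.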
#[derive(Clone, Default)]
pub struct GGUFSpecificConfig {
    pub prompt_chunksize: Option<NonZeroUsize>,
    pub topology: Option<Topology>,
}

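/// Builder for a [`GGUFLoader`]. Start with [`GGUFLoaderBuilder::new`], optionally attach an
/// X-LoRA or LoRA adapter, then call [`build`](Self::build) to obtain a boxed [`Loader`].
///
/// A minimal sketch (the repository and filename below are illustrative placeholders, not
/// defaults shipped with this crate):
///
/// ```ignore
/// let loader = GGUFLoaderBuilder::new(
///     None,                                        // chat_template: use the model's own
///     None,                                        // tok_model_id: derive tokenizer from GGUF
///     "some-org/some-model-GGUF".to_string(),      // quantized_model_id
///     vec!["some-model.Q4_K_M.gguf".to_string()],  // quantized_filenames
///     GGUFSpecificConfig::default(),
///     false,                                       // no_kv_cache
///     None,                                        // jinja_explicit
/// )
/// .build();
/// ```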
#[derive(Default)]
pub struct GGUFLoaderBuilder {
    model_id: Option<String>,
    quantized_model_id: String,
    quantized_filenames: Vec<String>,
    xlora_model_id: Option<String>,
    kind: ModelKind,
    xlora_order: Option<Ordering>,
    no_kv_cache: bool,
    chat_template: Option<String>,
    tgt_non_granular_index: Option<usize>,
    config: GGUFSpecificConfig,
    jinja_explicit: Option<String>,
}

impl GGUFLoaderBuilder {
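    /// Create a builder for a plain (non-adapter) GGUF quantized model. `tok_model_id` is the
    /// model ID used to resolve the tokenizer and chat template; if it is `None`, those are
    /// derived from the GGUF file's metadata where possible. If `chat_template` is given, it is
    /// treated as a path to a local chat template file.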
    pub fn new(
        chat_template: Option<String>,
        tok_model_id: Option<String>,
        quantized_model_id: String,
        quantized_filenames: Vec<String>,
        config: GGUFSpecificConfig,
        no_kv_cache: bool,
        jinja_explicit: Option<String>,
    ) -> Self {
        let kind = ModelKind::GgufQuantized {
            quant: QuantizationKind::Gguf,
        };

        Self {
            chat_template,
            model_id: tok_model_id,
            kind,
            quantized_filenames,
            quantized_model_id,
            config,
            jinja_explicit,
            no_kv_cache,
            ..Default::default()
        }
    }

    fn with_adapter(
        mut self,
        xlora_model_id: String,
        xlora_order: Ordering,
        no_kv_cache: bool,
        tgt_non_granular_index: Option<usize>,
    ) -> Self {
        self.xlora_model_id = Some(xlora_model_id);
        self.xlora_order = Some(xlora_order);
        self.no_kv_cache = no_kv_cache;
        self.tgt_non_granular_index = tgt_non_granular_index;
        self.model_id = if let Some(id) = self.model_id {
            Some(id)
        } else {
            info!(
                "Using adapter base model ID: `{}`",
                self.xlora_order.as_ref().unwrap().base_model_id
            );
            Some(self.xlora_order.as_ref().unwrap().base_model_id.clone())
        };
        self
    }

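    /// Attach an X-LoRA adapter to the model being built. This switches the model kind to an
    /// adapted GGUF kind with [`AdapterKind::XLora`] and, if no tokenizer model ID was given,
    /// falls back to the adapter ordering's base model ID.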
    pub fn with_xlora(
        mut self,
        xlora_model_id: String,
        xlora_order: Ordering,
        no_kv_cache: bool,
        tgt_non_granular_index: Option<usize>,
    ) -> Self {
        self.kind = (AdapterKind::XLora, QuantizationKind::Gguf).into();

        self.with_adapter(
            xlora_model_id,
            xlora_order,
            no_kv_cache,
            tgt_non_granular_index,
        )
    }

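    /// Attach a LoRA adapter to the model being built, switching the model kind to an adapted
    /// GGUF kind with [`AdapterKind::Lora`].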
    pub fn with_lora(mut self, lora_model_id: String, lora_order: Ordering) -> Self {
        self.kind = (AdapterKind::Lora, QuantizationKind::Gguf).into();

        self.with_adapter(lora_model_id, lora_order, false, None)
    }

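    /// Finish the builder, producing a boxed [`GGUFLoader`].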
    pub fn build(self) -> Box<dyn Loader> {
        Box::new(GGUFLoader {
            model_id: self.model_id,
            xlora_model_id: self.xlora_model_id,
            kind: self.kind,
            xlora_order: self.xlora_order,
            no_kv_cache: self.no_kv_cache,
            chat_template: self.chat_template,
            tgt_non_granular_index: self.tgt_non_granular_index,
            quantized_filenames: self.quantized_filenames,
            quantized_model_id: self.quantized_model_id,
            config: self.config,
            jinja_explicit: self.jinja_explicit,
            lora_adapter_ids: None,
        })
    }
}

impl GGUFLoader {
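    /// Construct a [`GGUFLoader`] directly. Prefer [`GGUFLoaderBuilder`] for the common cases;
    /// this constructor exposes every field, including the adapter ordering and the target
    /// non-granular index. If `model_id` is `None` and an adapter ordering is provided, the
    /// adapter's base model ID is used.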
    #[allow(clippy::too_many_arguments)]
    pub fn new(
        model_id: Option<String>,
        quantized_model_id: String,
        quantized_filenames: Vec<String>,
        xlora_model_id: Option<String>,
        kind: ModelKind,
        xlora_order: Option<Ordering>,
        no_kv_cache: bool,
        chat_template: Option<String>,
        tgt_non_granular_index: Option<usize>,
        config: GGUFSpecificConfig,
        jinja_explicit: Option<String>,
    ) -> Self {
        let model_id = if let Some(id) = model_id {
            Some(id)
        } else if let Some(xlora_order) = xlora_order.clone() {
            info!(
                "Using adapter base model ID: `{}`",
                xlora_order.base_model_id
            );
            Some(xlora_order.base_model_id.clone())
        } else {
            None
        };
        Self {
            model_id,
            quantized_model_id,
            quantized_filenames,
            xlora_model_id,
            xlora_order,
            no_kv_cache,
            chat_template,
            kind,
            tgt_non_granular_index,
            config,
            jinja_explicit,
            lora_adapter_ids: None,
        }
    }
}

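// The Loader implementation: resolve paths (optionally via the Hugging Face Hub), read the GGUF
// content, map it onto devices, build the tokenizer and chat template, and assemble a pipeline.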
impl Loader for GGUFLoader {
    #[allow(clippy::type_complexity, clippy::too_many_arguments)]
    fn load_model_from_hf(
        &self,
        revision: Option<String>,
        token_source: TokenSource,
        dtype: &dyn TryIntoDType,
        device: &Device,
        silent: bool,
        mapper: DeviceMapSetting,
        in_situ_quant: Option<IsqType>,
        paged_attn_config: Option<PagedAttentionConfig>,
    ) -> Result<Arc<Mutex<dyn Pipeline + Send + Sync>>> {
        let paths: anyhow::Result<Box<dyn ModelPaths>> = get_paths_gguf!(
            LocalModelPaths,
            &token_source,
            revision,
            self,
            self.quantized_model_id.clone(),
            self.quantized_filenames.clone(),
            silent
        );
        self.load_model_from_path(
            &paths?,
            dtype,
            device,
            silent,
            mapper,
            in_situ_quant,
            paged_attn_config,
        )
    }

    #[allow(clippy::type_complexity, clippy::too_many_arguments)]
    fn load_model_from_path(
        &self,
        paths: &Box<dyn ModelPaths>,
        dtype: &dyn TryIntoDType,
        device: &Device,
        silent: bool,
        mut mapper: DeviceMapSetting,
        in_situ_quant: Option<IsqType>,
        mut paged_attn_config: Option<PagedAttentionConfig>,
    ) -> Result<Arc<Mutex<dyn Pipeline + Send + Sync>>> {
        if in_situ_quant.is_some() {
            anyhow::bail!(
                "You are trying to in-situ quantize a GGUF model. This will not do anything."
            );
        }

        let prompt_chunksize = self
            .config
            .prompt_chunksize
            .unwrap_or(DEFAULT_PROMPT_CHUNK_SIZE.try_into().unwrap())
            .get();

        info!("Prompt chunk size is {prompt_chunksize}.");

        let mut readers = Vec::new();
        for filename in paths.get_weight_filenames() {
            readers.push(std::fs::File::open(filename)?);
        }
        let mut readers = readers.iter_mut().collect::<Vec<_>>();

        let model = Content::from_readers(&mut readers)?;
        if !silent {
            model.print_metadata()?;
        }
        let arch = model.arch();

        let num_layers = model.get_metadata()[&format!("{arch}.block_count")].to_u32()? as usize;
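
        // Automatic device mapping: estimate per-layer and non-mapped sizes from the GGUF
        // metadata, then ask the device-map loader to split the layers across available devices.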
        if let DeviceMapSetting::Auto(params) = mapper.clone() {
            let devices = device_map::get_all_similar_devices(device)?;
            let dtype = dtype.try_into_dtype(&devices.iter().collect::<Vec<_>>())?;

            let model = GgufDeviceMapLoaderInner {
                model: &model,
                arch,
            };

            let layer_sizes_in_bytes =
                model.layer_sizes_in_bytes("this is a dummy config!", dtype, 1)?;
            let non_mapped_size_in_bytes =
                model.non_mapped_size_in_bytes("this is a dummy config!", dtype, 1)?;
            let total_model_size_in_bytes =
                layer_sizes_in_bytes.iter().sum::<usize>() + non_mapped_size_in_bytes;

            let new = model.get_device_layers(
                "this is a dummy config!",
                num_layers,
                layer_sizes_in_bytes,
                non_mapped_size_in_bytes,
                total_model_size_in_bytes,
                &devices,
                dtype,
                &params,
                prompt_chunksize,
                paged_attn_config.as_ref(),
            )?;
            mapper = DeviceMapSetting::Map(new);
        }

        let pipeline_mapper =
            mapper.into_mapper(num_layers, device, self.config.topology.as_ref())?;
        let mapper = mapper.into_mapper(num_layers, device, self.config.topology.as_ref())?;
        let mut layer_devices = Vec::new();
        for layer in 0..num_layers {
            let device = mapper.device_for(layer, false).cloned();
            layer_devices.push(device);
        }

        let mapping_uses_cpu = mapper.get_unique_devices().iter().any(Device::is_cpu);
        if mapping_uses_cpu {
            warn!("Device mapping contains a mix of GPU and CPU. There is no CPU support for PagedAttention, disabling PagedAttention.");
            paged_attn_config = None;
        }

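        // Tokenizer: use the resolved tokenizer file when there is one; otherwise convert the
        // tokenizer embedded in the GGUF metadata into a Hugging Face tokenizer.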
        let GgufTokenizerConversion {
            tokenizer,
            bos,
            eos,
            unk,
        } = if paths.get_tokenizer_filename().to_string_lossy().is_empty() {
            convert_gguf_to_hf_tokenizer(&model)?
        } else {
            GgufTokenizerConversion {
                tokenizer: get_tokenizer(paths.get_tokenizer_filename(), None)?,
                bos: None,
                eos: None,
                unk: None,
            }
        };

        let gguf_chat_template =
            if paths.get_template_filename().is_none() && self.chat_template.is_none() {
                get_gguf_chat_template(&model)?
            } else {
                None
            };

        let has_adapter = self.kind.is_adapted();
        let is_xlora = self.kind.is_adapted_and(|a| a.is_x_lora());

        let paged_attn_config = if matches!(self.kind, ModelKind::GgufAdapter { .. }) {
            warn!("Adapter models do not currently support PagedAttention, running without");
            None
        } else {
            paged_attn_config
        };

        let model_config_metadata: ContentConfig = (&model).into();
        let internal_dtype = mapper.get_min_dtype(dtype)?;

        let model_config = {
            let quant = ModelConfig::ParamsGGUF(
                model,
                (device, mapper).into(),
                if paged_attn_config.is_some() {
                    AttentionImplementation::PagedAttention
                } else {
                    AttentionImplementation::Eager
                },
                internal_dtype,
            );

            let mut adapter = None;
            if has_adapter {
                adapter.replace(ModelConfig::Adapter::try_new(
                    paths, device, silent, is_xlora,
                )?);
            }

            ModelConfig::ModelParams::new(quant, adapter)
        };

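        // Dispatch on the GGUF architecture recorded in the file metadata; adapter (X-LoRA/LoRA)
        // loading is only wired up for the Llama and Phi3 architectures.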
        let model = match self.kind {
            ModelKind::GgufQuantized { .. } => match arch {
                GGUFArchitecture::Llama => Model::Llama(QLlama::try_from(model_config)?),
                GGUFArchitecture::Phi2 => Model::Phi2(QPhi::try_from(model_config)?),
                GGUFArchitecture::Phi3 => Model::Phi3(QPhi3::try_from(model_config)?),
                GGUFArchitecture::Starcoder2 => {
                    Model::Starcoder2(QStarcoder2::try_from(model_config)?)
                }
                GGUFArchitecture::Qwen2 => Model::Qwen(QQwen::try_from(model_config)?),
                GGUFArchitecture::Qwen3 => Model::Qwen3(QQwen3::try_from(model_config)?),
                a => bail!("Unsupported architecture `{a:?}` for GGUF"),
            },
            ModelKind::GgufAdapter { adapter, .. } => match arch {
                GGUFArchitecture::Llama => Model::XLoraLlama(XLoraQLlama::try_from(model_config)?),
                GGUFArchitecture::Phi3 => Model::XLoraPhi3(XLoraQPhi3::try_from(model_config)?),
                a => bail!(
                    "Unsupported architecture `{a:?}` for GGUF {kind}",
                    kind = adapter.pretty_name()
                ),
            },
            _ => unreachable!(),
        };

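        // If PagedAttention is enabled, size the KV-cache blocks from the requested GPU/CPU
        // memory budgets and build the cache engine over the mapped layer devices.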
        let (cache_config, cache_engine) = if let Some(paged_attn_config) = paged_attn_config {
            let model_config: &dyn ModelConfigLike = &model_config_metadata;
            let cache_config = calculate_cache_config(
                paged_attn_config.mem_gpu,
                paged_attn_config.mem_cpu,
                paged_attn_config.block_size,
                internal_dtype,
                paged_attn_config.cache_type,
                model_config,
                device,
                &layer_devices,
                silent,
            )?;
            let cache_engine = CacheEngine::new(
                model_config,
                &cache_config,
                internal_dtype,
                device,
                layer_devices,
            )?;
            (Some(cache_config), Some(cache_engine))
        } else {
            (None, None)
        };

        let gen_conf: Option<GenerationConfig> = paths
            .get_gen_conf_filename()
            .map(|f| serde_json::from_str(&fs::read_to_string(f).unwrap()).unwrap());
        let chat_template_explicit = paths
            .get_chat_template_explicit()
            .as_ref()
            .map(|x| x.to_string_lossy().to_string());
        let mut chat_template = get_chat_template(
            paths,
            self.jinja_explicit.as_ref(),
            chat_template_explicit.as_ref(),
            self.chat_template.as_ref(),
            gguf_chat_template,
        );

        let max_seq_len = match model {
            Model::Llama(ref l) => l.max_seq_len,
            Model::Phi2(ref p) => p.max_seq_len,
            Model::XLoraLlama(ref xl) => xl.max_seq_len,
            Model::Phi3(ref p) => p.max_seq_len,
            Model::XLoraPhi3(ref p) => p.max_seq_len,
            Model::Starcoder2(ref p) => p.max_seq_len,
            Model::Qwen(ref p) => p.max_seq_len,
            Model::Qwen3(ref p) => p.max_seq_len,
        };
        let llg_factory = build_llg_factory(tokenizer.clone())?;
        let num_hidden_layers = match model {
            Model::Llama(ref model) => model.cache.normal().0.len(),
            Model::Phi2(ref model) => model.cache.normal().0.len(),
            Model::XLoraLlama(ref model) => model.cache.full().lock().len(),
            Model::Phi3(ref model) => model.cache.normal().0.len(),
            Model::XLoraPhi3(ref model) => model.cache.full().lock().len(),
            Model::Starcoder2(ref model) => model.cache.normal().0.len(),
            Model::Qwen(ref model) => model.cache.normal().0.len(),
            Model::Qwen3(ref model) => model.cache.normal().0.len(),
        };

        if chat_template.bos_token.is_none() && bos.is_some() {
            chat_template.bos_token = Some(BeginEndUnkPadTok(Either::Left(bos.unwrap())));
        }
        if chat_template.eos_token.is_none() && eos.is_some() {
            chat_template.eos_token = Some(BeginEndUnkPadTok(Either::Left(eos.unwrap())));
        }
        if chat_template.unk_token.is_none() && unk.is_some() {
            chat_template.unk_token = Some(BeginEndUnkPadTok(Either::Left(unk.unwrap())));
        }

        let eos = calculate_eos_tokens(&chat_template, gen_conf, &tokenizer);
        Ok(Arc::new(Mutex::new(GGUFPipeline {
            model,
            tokenizer: tokenizer.into(),
            no_kv_cache: self.no_kv_cache,
            chat_template: Arc::new(chat_template),
            model_id: self
                .model_id
                .clone()
                .unwrap_or(self.quantized_model_id.clone()),
            non_granular_state: self.tgt_non_granular_index.map(|tgt_non_granular_index| {
                NonGranularState {
                    non_granular_index: Arc::new(Mutex::new(0)),
                    tgt_non_granular_index,
                }
            }),
            metadata: Arc::new(GeneralMetadata {
                max_seq_len,
                llg_factory: Some(llg_factory),
                no_kv_cache: self.no_kv_cache,
                no_prefix_cache: false,
                num_hidden_layers,
                eos_tok: eos,
                kind: self.kind.clone(),
                is_xlora,
                activation_dtype: internal_dtype,
                sliding_window: None,
                cache_config,
                cache_engine,
                prompt_chunksize: Some(NonZero::new(prompt_chunksize).unwrap()),
                model_metadata: Some(Arc::new(model_config_metadata)),
                modalities: Modalities {
                    input: vec![SupportedModality::Text],
                    output: vec![SupportedModality::Text],
                },
            }),
            mapper: pipeline_mapper,
        })))
    }

    fn get_id(&self) -> String {
        self.xlora_model_id
            .as_deref()
            .unwrap_or(self.model_id.as_ref().unwrap_or(&self.quantized_model_id))
            .to_string()
    }

    fn get_kind(&self) -> ModelKind {
        self.kind.clone()
    }
}

impl PreProcessingMixin for GGUFPipeline {
    fn get_chat_template(&self) -> Option<Arc<ChatTemplate>> {
        Some(self.chat_template.clone())
    }
    fn get_input_processor_config(&self) -> Option<Arc<dyn Any>> {
        None
    }
}

impl IsqPipelineMixin for GGUFPipeline {
    fn re_isq_model(&mut self, _dtype: IsqType) -> Result<()> {
        anyhow::bail!(
            "You are trying to in-situ requantize a GGUF model. This will not do anything."
        )
    }
}

impl CacheManagerMixin for GGUFPipeline {
    fn clone_in_cache(&self, seqs: &mut [&mut Sequence]) {
        if matches!(self.cache(), EitherCache::Full(_)) {
            FullCacheManager.clone_in_cache(self, seqs, false)
        } else {
            NormalCacheManager.clone_in_cache(self, seqs, false)
        }
    }
    fn clone_out_cache(&self, seqs: &mut [&mut Sequence]) {
        if matches!(self.cache(), EitherCache::Full(_)) {
            FullCacheManager.clone_out_cache(self, seqs, false)
        } else {
            NormalCacheManager.clone_out_cache(self, seqs, false)
        }
    }
    fn set_none_cache(
        &self,
        seqs: &mut [&mut Sequence],
        reset_non_granular: bool,
        modify_draft_cache: bool,
        load_preallocated_cache: bool,
    ) {
        if matches!(self.cache(), EitherCache::Full(_)) {
            FullCacheManager.set_none_cache(self, seqs, modify_draft_cache, false);
        } else {
            NormalCacheManager.set_none_cache(
                self,
                seqs,
                modify_draft_cache,
                load_preallocated_cache,
            );
        }
        if reset_non_granular {
            self.reset_non_granular_state()
        }
    }
    fn cache(&self) -> &EitherCache {
        match self.model {
            Model::Llama(ref model) => &model.cache,
            Model::Phi2(ref model) => &model.cache,
            Model::XLoraLlama(ref model) => &model.cache,
            Model::Phi3(ref model) => &model.cache,
            Model::XLoraPhi3(ref model) => &model.cache,
            Model::Starcoder2(ref model) => &model.cache,
            Model::Qwen(ref model) => &model.cache,
            Model::Qwen3(ref model) => &model.cache,
        }
    }
}

impl MetadataMixin for GGUFPipeline {
    fn device(&self) -> Device {
        match self.model {
            Model::Llama(ref model) => model.device.clone(),
            Model::Phi2(ref model) => model.device.clone(),
            Model::XLoraLlama(ref model) => model.device.clone(),
            Model::Phi3(ref model) => model.device.clone(),
            Model::XLoraPhi3(ref model) => model.device.clone(),
            Model::Starcoder2(ref model) => model.device.clone(),
            Model::Qwen(ref model) => model.device.clone(),
            Model::Qwen3(ref model) => model.device.clone(),
        }
    }
    fn tokenizer(&self) -> Option<Arc<Tokenizer>> {
        Some(self.tokenizer.clone())
    }
    fn name(&self) -> String {
        self.model_id.clone()
    }
    fn reset_non_granular_state(&self) {
        if let Some(s) = self.non_granular_state.as_ref() {
            *self.cache().full().get_scalings_cache() = None;
            *get_mut_arcmutex!(s.non_granular_index) = 0;
        }
    }
    fn get_metadata(&self) -> Arc<GeneralMetadata> {
        self.metadata.clone()
    }
    fn device_mapper(&self) -> Option<&dyn DeviceMapper> {
        Some(&*self.mapper)
    }
}

#[async_trait::async_trait]
impl Pipeline for GGUFPipeline {
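    // Downcast the type-erased inputs back to `ModelInputs`, pair the PagedAttention metadata
    // with the cache engine's KV cache when both are present, and run the architecture-specific
    // forward pass.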
    fn forward_inputs(
        &mut self,
        inputs: Box<dyn Any>,
        return_raw_logits: bool,
    ) -> Result<ForwardInputsResult, candle_core::Error> {
        let ModelInputs {
            input_ids,
            input_ids_full,
            seqlen_offsets,
            seqlen_offsets_full,
            context_lens,
            position_ids: _,
            paged_attn_meta,
            flash_meta,
            flash_meta_full,
        } = *inputs.downcast().expect("Downcast failed.");
        let metadata = self.get_metadata();
        let paged_attn_meta = match (&metadata.cache_engine, &paged_attn_meta) {
            (Some(engine), Some(meta)) => Some((engine.get_kv_cache().clone(), meta)),
            (Some(_), None) => {
                candle_core::bail!("Forward step expected a PagedAttention input metadata. This was not provided, please ensure that the scheduler config is correctly configured for PagedAttention.")
            }
            (None, Some(_)) => {
                candle_core::bail!("Forward step got a PagedAttention input metadata but there is no cache engine. Please raise an issue.")
            }
            (None, None) => None,
        };
        let logits = match self.model {
            Model::Llama(ref model) => {
                model.forward(&input_ids, &seqlen_offsets, context_lens, paged_attn_meta)?
            }
            Model::Phi2(ref model) => {
                model.forward(&input_ids, &seqlen_offsets, context_lens, paged_attn_meta)?
            }
            Model::XLoraLlama(ref model) => model.forward(
                &input_ids,
                input_ids_full.as_ref().unwrap_or(&input_ids),
                &seqlen_offsets,
                seqlen_offsets_full.as_ref().unwrap_or(&seqlen_offsets),
                self.no_kv_cache,
                &self.non_granular_state,
                context_lens,
                &flash_meta,
                flash_meta_full.as_ref().unwrap_or(&flash_meta),
            )?,
            Model::Phi3(ref model) => {
                model.forward(&input_ids, &seqlen_offsets, paged_attn_meta)?
            }
            Model::XLoraPhi3(ref model) => model.forward(
                &input_ids,
                input_ids_full.as_ref().unwrap_or(&input_ids),
                &seqlen_offsets,
                seqlen_offsets_full.as_ref().unwrap_or(&seqlen_offsets),
                self.no_kv_cache,
                &self.non_granular_state,
                context_lens,
                &flash_meta,
                flash_meta_full.as_ref().unwrap_or(&flash_meta),
            )?,
            Model::Starcoder2(ref model) => {
                model.forward(&input_ids, &seqlen_offsets, paged_attn_meta)?
            }
            Model::Qwen(ref model) => {
                model.forward(&input_ids, &seqlen_offsets, context_lens, paged_attn_meta)?
            }
            Model::Qwen3(ref model) => {
                model.forward(&input_ids, &seqlen_offsets, context_lens, paged_attn_meta)?
            }
        };
        if return_raw_logits {
            Ok(ForwardInputsResult::RawLogits { logits })
        } else {
            Ok(ForwardInputsResult::CausalGeneration { logits })
        }
    }
    async fn sample_causal_gen(
        &self,
        seqs: &mut [&mut Sequence],
        logits: Vec<Tensor>,
        prefix_cacher: &mut PrefixCacheManagerV2,
        disable_eos_stop: bool,
        rng: Arc<std::sync::Mutex<Isaac64Rng>>,
    ) -> Result<(), candle_core::Error> {
        sample_and_add_toks(self, seqs, logits, prefix_cacher, disable_eos_stop, rng).await
    }
    fn category(&self) -> ModelCategory {
        ModelCategory::Text
    }
}

impl AnyMoePipelineMixin for GGUFPipeline {}