use super::cache_manager::{FullCacheManager, NormalCacheManager};
use super::llg::build_tok_env;
use super::{
    get_model_paths, get_xlora_paths, text_models_inputs_processor::ModelInputs, AdapterKind,
    CacheManager, GeneralMetadata, Loader, ModelKind, ModelPaths, PrettyName, QuantizationKind,
    TokenSource,
};
use super::{
    AnyMoePipelineMixin, CacheManagerMixin, EitherCache, ForwardInputsResult, IsqPipelineMixin,
    MetadataMixin, ModelCategory, PreProcessingMixin,
};
use crate::device_map::{self, DeviceMapper};
use crate::gguf::{
    get_gguf_chat_template, {convert_gguf_to_hf_tokenizer, GgufTokenizerConversion},
};
use crate::gguf::{Content, GGUFArchitecture};
use crate::lora::Ordering;
use crate::paged_attention::{
    calculate_cache_config, AttentionImplementation, CacheEngine, ModelConfigLike,
};
use crate::pipeline::chat_template::{calculate_eos_tokens, BeginEndUnkPadTok, GenerationConfig};
use crate::pipeline::get_chat_template;
use crate::pipeline::inputs_processor::DEFAULT_PROMPT_CHUNK_SIZE;
use crate::pipeline::loaders::DeviceMappedModelLoader;
use crate::pipeline::sampling::sample_and_add_toks;
use crate::pipeline::ChatTemplate;
use crate::prefix_cacher::PrefixCacheManagerV2;
use crate::sequence::Sequence;
use crate::utils::gguf_metadata::{ContentConfig, GgufDeviceMapLoaderInner};
use crate::utils::model_config as ModelConfig;
use crate::utils::tokenizer::get_tokenizer;
use crate::xlora_models::NonGranularState;
use crate::{
    get_mut_arcmutex, get_paths_gguf, DeviceMapSetting, LocalModelPaths, PagedAttentionConfig,
    Pipeline, Topology, TryIntoDType,
};
use crate::{
    models::quantized_llama::ModelWeights as QLlama,
    models::quantized_phi2::ModelWeights as QPhi,
    models::quantized_phi3::ModelWeights as QPhi3,
    models::quantized_qwen2::ModelWeights as QQwen2,
    models::quantized_starcoder2::ModelWeights as QStarcoder2,
    utils::tokens::get_token,
    xlora_models::{XLoraQLlama, XLoraQPhi3},
};
use anyhow::{bail, Result};
use candle_core::{Device, Tensor};
use either::Either;
use hf_hub::{api::sync::ApiBuilder, Repo, RepoType};
use mistralrs_quant::IsqType;
use rand_isaac::Isaac64Rng;
use std::any::Any;
use std::fs;
use std::num::{NonZero, NonZeroUsize};
use std::path::PathBuf;
use std::str::FromStr;
use std::sync::Arc;
use tokenizers::Tokenizer;
use tokio::sync::Mutex;
use tracing::{info, warn};
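/// The concrete quantized weights backing a [`GGUFPipeline`], with one variant per supported
/// GGUF architecture plus the X-LoRA adapted Llama and Phi3 variants.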
enum Model {
    Llama(QLlama),
    Phi2(QPhi),
    XLoraLlama(XLoraQLlama),
    XLoraPhi3(XLoraQPhi3),
    Phi3(QPhi3),
    Starcoder2(QStarcoder2),
    Qwen2(QQwen2),
}
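/// A loaded GGUF text-generation pipeline: the quantized model together with its tokenizer,
/// chat template, generation metadata, and the device mapper used to place layers.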
pub struct GGUFPipeline {
    model: Model,
    tokenizer: Arc<Tokenizer>,
    no_kv_cache: bool,
    chat_template: Arc<ChatTemplate>,
    model_id: String,
    non_granular_state: Option<NonGranularState>,
    metadata: Arc<GeneralMetadata>,
    mapper: Box<dyn DeviceMapper + Send + Sync>,
}
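/// A [`Loader`] that builds a [`GGUFPipeline`] from one or more GGUF files, optionally with an
/// X-LoRA or LoRA adapter applied on top. Usually constructed through [`GGUFLoaderBuilder`].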
pub struct GGUFLoader {
    model_id: Option<String>,
    quantized_model_id: String,
    quantized_filenames: Vec<String>,
    xlora_model_id: Option<String>,
    xlora_order: Option<Ordering>,
    no_kv_cache: bool,
    chat_template: Option<String>,
    kind: ModelKind,
    tgt_non_granular_index: Option<usize>,
    config: GGUFSpecificConfig,
    jinja_explicit: Option<String>,
    lora_adapter_ids: Option<Vec<String>>,
}
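/// GGUF-specific loading options: the prompt chunk size used during prefill and an optional
/// layer [`Topology`].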
#[derive(Clone, Default)]
pub struct GGUFSpecificConfig {
    pub prompt_chunksize: Option<NonZeroUsize>,
    pub topology: Option<Topology>,
}
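/// Builder for a GGUF [`Loader`]. Start with [`GGUFLoaderBuilder::new`], optionally attach an
/// adapter via `with_xlora`/`with_lora`, then call `build`.
///
/// A minimal usage sketch; the model ID, filename, and crate-root re-export paths below are
/// illustrative assumptions, not values taken from this file:
///
/// ```ignore
/// use mistralrs_core::{GGUFLoaderBuilder, GGUFSpecificConfig};
///
/// let loader = GGUFLoaderBuilder::new(
///     None,                                   // chat template override
///     None,                                   // tokenizer model ID (falls back to GGUF metadata)
///     "some-org/some-model-GGUF".to_string(), // quantized model ID (placeholder)
///     vec!["model-q4_k_m.gguf".to_string()],  // quantized filenames (placeholder)
///     GGUFSpecificConfig::default(),
///     false,                                  // no_kv_cache
///     None,                                   // jinja_explicit
/// )
/// .build();
/// ```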
#[derive(Default)]
pub struct GGUFLoaderBuilder {
    model_id: Option<String>,
    quantized_model_id: String,
    quantized_filenames: Vec<String>,
    xlora_model_id: Option<String>,
    kind: ModelKind,
    xlora_order: Option<Ordering>,
    no_kv_cache: bool,
    chat_template: Option<String>,
    tgt_non_granular_index: Option<usize>,
    config: GGUFSpecificConfig,
    jinja_explicit: Option<String>,
}

impl GGUFLoaderBuilder {
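    /// Create a builder for a plain (non-adapter) GGUF quantized model. `tok_model_id` may be
    /// `None`; in that case the tokenizer and special tokens are reconstructed from GGUF metadata.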
    pub fn new(
        chat_template: Option<String>,
        tok_model_id: Option<String>,
        quantized_model_id: String,
        quantized_filenames: Vec<String>,
        config: GGUFSpecificConfig,
        no_kv_cache: bool,
        jinja_explicit: Option<String>,
    ) -> Self {
        let kind = ModelKind::GgufQuantized {
            quant: QuantizationKind::Gguf,
        };

        Self {
            chat_template,
            model_id: tok_model_id,
            kind,
            quantized_filenames,
            quantized_model_id,
            config,
            jinja_explicit,
            no_kv_cache,
            ..Default::default()
        }
    }
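    /// Shared plumbing for the adapter variants: records the adapter model ID and ordering, and
    /// falls back to the adapter's base model ID when no tokenizer model ID was provided.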
    fn with_adapter(
        mut self,
        xlora_model_id: String,
        xlora_order: Ordering,
        no_kv_cache: bool,
        tgt_non_granular_index: Option<usize>,
    ) -> Self {
        self.xlora_model_id = Some(xlora_model_id);
        self.xlora_order = Some(xlora_order);
        self.no_kv_cache = no_kv_cache;
        self.tgt_non_granular_index = tgt_non_granular_index;
        self.model_id = if let Some(id) = self.model_id {
            Some(id)
        } else {
            info!(
                "Using adapter base model ID: `{}`",
                self.xlora_order.as_ref().unwrap().base_model_id
            );
            Some(self.xlora_order.as_ref().unwrap().base_model_id.clone())
        };
        self
    }
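    /// Attach an X-LoRA adapter, switching the model kind to a GGUF-quantized X-LoRA model.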
    pub fn with_xlora(
        mut self,
        xlora_model_id: String,
        xlora_order: Ordering,
        no_kv_cache: bool,
        tgt_non_granular_index: Option<usize>,
    ) -> Self {
        self.kind = (AdapterKind::XLora, QuantizationKind::Gguf).into();

        self.with_adapter(
            xlora_model_id,
            xlora_order,
            no_kv_cache,
            tgt_non_granular_index,
        )
    }
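    /// Attach a LoRA adapter, switching the model kind to a GGUF-quantized LoRA model.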
    pub fn with_lora(mut self, lora_model_id: String, lora_order: Ordering) -> Self {
        self.kind = (AdapterKind::Lora, QuantizationKind::Gguf).into();

        self.with_adapter(lora_model_id, lora_order, false, None)
    }
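    /// Finish the builder, producing a boxed [`Loader`] ready to resolve files and load the model.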
    pub fn build(self) -> Box<dyn Loader> {
        Box::new(GGUFLoader {
            model_id: self.model_id,
            xlora_model_id: self.xlora_model_id,
            kind: self.kind,
            xlora_order: self.xlora_order,
            no_kv_cache: self.no_kv_cache,
            chat_template: self.chat_template,
            tgt_non_granular_index: self.tgt_non_granular_index,
            quantized_filenames: self.quantized_filenames,
            quantized_model_id: self.quantized_model_id,
            config: self.config,
            jinja_explicit: self.jinja_explicit,
            lora_adapter_ids: None,
        })
    }
}

impl GGUFLoader {
    #[allow(clippy::too_many_arguments)]
    pub fn new(
        model_id: Option<String>,
        quantized_model_id: String,
        quantized_filenames: Vec<String>,
        xlora_model_id: Option<String>,
        kind: ModelKind,
        xlora_order: Option<Ordering>,
        no_kv_cache: bool,
        chat_template: Option<String>,
        tgt_non_granular_index: Option<usize>,
        config: GGUFSpecificConfig,
        jinja_explicit: Option<String>,
    ) -> Self {
        let model_id = if let Some(id) = model_id {
            Some(id)
        } else if let Some(xlora_order) = xlora_order.clone() {
            info!(
                "Using adapter base model ID: `{}`",
                xlora_order.base_model_id
            );
            Some(xlora_order.base_model_id.clone())
        } else {
            None
        };
        Self {
            model_id,
            quantized_model_id,
            quantized_filenames,
            xlora_model_id,
            xlora_order,
            no_kv_cache,
            chat_template,
            kind,
            tgt_non_granular_index,
            config,
            jinja_explicit,
            lora_adapter_ids: None,
        }
    }
}
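// `Loader` entry points: `load_model_from_hf` resolves the GGUF, tokenizer, and adapter files
// (from the Hugging Face Hub or local paths) and then defers to `load_model_from_path`, which
// performs device mapping and the actual weight loading.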
impl Loader for GGUFLoader {
    #[allow(clippy::type_complexity, clippy::too_many_arguments)]
    fn load_model_from_hf(
        &self,
        revision: Option<String>,
        token_source: TokenSource,
        dtype: &dyn TryIntoDType,
        device: &Device,
        silent: bool,
        mapper: DeviceMapSetting,
        in_situ_quant: Option<IsqType>,
        paged_attn_config: Option<PagedAttentionConfig>,
    ) -> Result<Arc<Mutex<dyn Pipeline + Send + Sync>>> {
        let paths: anyhow::Result<Box<dyn ModelPaths>> = get_paths_gguf!(
            LocalModelPaths,
            &token_source,
            revision,
            self,
            self.quantized_model_id.clone(),
            self.quantized_filenames.clone(),
            silent
        );
        self.load_model_from_path(
            &paths?,
            dtype,
            device,
            silent,
            mapper,
            in_situ_quant,
            paged_attn_config,
        )
    }

    #[allow(clippy::type_complexity, clippy::too_many_arguments)]
    fn load_model_from_path(
        &self,
        paths: &Box<dyn ModelPaths>,
        dtype: &dyn TryIntoDType,
        device: &Device,
        silent: bool,
        mut mapper: DeviceMapSetting,
        in_situ_quant: Option<IsqType>,
        mut paged_attn_config: Option<PagedAttentionConfig>,
    ) -> Result<Arc<Mutex<dyn Pipeline + Send + Sync>>> {
        if in_situ_quant.is_some() {
            anyhow::bail!(
                "You are trying to in-situ quantize a GGUF model. This will not do anything."
            );
        }

        let prompt_chunksize = self
            .config
            .prompt_chunksize
            .unwrap_or(DEFAULT_PROMPT_CHUNK_SIZE.try_into().unwrap())
            .get();

        info!("Prompt chunk size is {prompt_chunksize}.");

        let mut readers = Vec::new();
        for filename in paths.get_weight_filenames() {
            readers.push(std::fs::File::open(filename)?);
        }
        let mut readers = readers.iter_mut().collect::<Vec<_>>();

        let model = Content::from_readers(&mut readers)?;
        if !silent {
            model.print_metadata()?;
        }
        let arch = model.arch();

        let num_layers = model.get_metadata()[&format!("{arch}.block_count")].to_u32()? as usize;
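        // If automatic device mapping was requested, estimate per-layer and non-mapped sizes from
        // the GGUF metadata and turn the request into a concrete layer-to-device map.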
        if let DeviceMapSetting::Auto(params) = mapper.clone() {
            let devices = device_map::get_all_similar_devices(device)?;
            let dtype = dtype.try_into_dtype(&devices.iter().collect::<Vec<_>>())?;

            let model = GgufDeviceMapLoaderInner {
                model: &model,
                arch,
            };

            let layer_sizes_in_bytes =
                model.layer_sizes_in_bytes("this is a dummy config!", dtype, 1)?;
            let non_mapped_size_in_bytes =
                model.non_mapped_size_in_bytes("this is a dummy config!", dtype, 1)?;
            let total_model_size_in_bytes =
                layer_sizes_in_bytes.iter().sum::<usize>() + non_mapped_size_in_bytes;

            let new = model.get_device_layers(
                "this is a dummy config!",
                num_layers,
                layer_sizes_in_bytes,
                non_mapped_size_in_bytes,
                total_model_size_in_bytes,
                &devices,
                dtype,
                &params,
                prompt_chunksize,
                paged_attn_config.as_ref(),
            )?;
            mapper = DeviceMapSetting::Map(new);
        }

        let pipeline_mapper =
            mapper.into_mapper(num_layers, device, self.config.topology.as_ref())?;
        let mapper = mapper.into_mapper(num_layers, device, self.config.topology.as_ref())?;
        let mut layer_devices = Vec::new();
        for layer in 0..num_layers {
            let device = mapper.device_for(layer, false).cloned();
            layer_devices.push(device);
        }

        let mapping_uses_cpu = mapper.get_unique_devices().iter().any(Device::is_cpu);
        if mapping_uses_cpu {
            warn!("Device mapping contains a mix of GPU and CPU. There is no CPU support for PagedAttention, disabling PagedAttention.");
            paged_attn_config = None;
        }
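        // Prefer an external tokenizer.json when one was resolved; otherwise reconstruct the
        // tokenizer (and any BOS/EOS/UNK tokens) directly from the GGUF metadata.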
        let GgufTokenizerConversion {
            tokenizer,
            bos,
            eos,
            unk,
        } = if paths.get_tokenizer_filename().to_string_lossy().is_empty() {
            convert_gguf_to_hf_tokenizer(&model)?
        } else {
            GgufTokenizerConversion {
                tokenizer: get_tokenizer(paths.get_tokenizer_filename(), None)?,
                bos: None,
                eos: None,
                unk: None,
            }
        };

        let gguf_chat_template =
            if paths.get_template_filename().is_none() && self.chat_template.is_none() {
                get_gguf_chat_template(&model)?
            } else {
                None
            };

        let has_adapter = self.kind.is_adapted();
        let is_xlora = self.kind.is_adapted_and(|a| a.is_x_lora());

        let paged_attn_config = if matches!(self.kind, ModelKind::GgufAdapter { .. }) {
            warn!("Adapter models do not currently support PagedAttention, running without");
            None
        } else {
            paged_attn_config
        };

        let model_config_metadata: ContentConfig = (&model).into();
        let internal_dtype = mapper.get_min_dtype(dtype)?;

        let model_config = {
            let quant = ModelConfig::ParamsGGUF(
                model,
                (device, mapper).into(),
                if paged_attn_config.is_some() {
                    AttentionImplementation::PagedAttention
                } else {
                    AttentionImplementation::Eager
                },
                internal_dtype,
            );

            let mut adapter = None;
            if has_adapter {
                adapter.replace(ModelConfig::Adapter::try_new(
                    paths, device, silent, is_xlora,
                )?);
            }

            ModelConfig::ModelParams::new(quant, adapter)
        };

        let model = match self.kind {
            ModelKind::GgufQuantized { .. } => match arch {
                GGUFArchitecture::Llama => Model::Llama(QLlama::try_from(model_config)?),
                GGUFArchitecture::Phi2 => Model::Phi2(QPhi::try_from(model_config)?),
                GGUFArchitecture::Phi3 => Model::Phi3(QPhi3::try_from(model_config)?),
                GGUFArchitecture::Starcoder2 => {
                    Model::Starcoder2(QStarcoder2::try_from(model_config)?)
                }
                GGUFArchitecture::Qwen2 => Model::Qwen2(QQwen2::try_from(model_config)?),
                a => bail!("Unsupported architecture `{a:?}` for GGUF"),
            },
            ModelKind::GgufAdapter { adapter, .. } => match arch {
                GGUFArchitecture::Llama => Model::XLoraLlama(XLoraQLlama::try_from(model_config)?),
                GGUFArchitecture::Phi3 => Model::XLoraPhi3(XLoraQPhi3::try_from(model_config)?),
                a => bail!(
                    "Unsupported architecture `{a:?}` for GGUF {kind}",
                    kind = adapter.pretty_name()
                ),
            },
            _ => unreachable!(),
        };
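        // With the model built, size the PagedAttention KV cache (if enabled) for the mapped
        // devices and dtype, and construct the cache engine the pipeline will drive.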
        let (cache_config, cache_engine) = if let Some(paged_attn_config) = paged_attn_config {
            let model_config: &dyn ModelConfigLike = &model_config_metadata;
            let cache_config = calculate_cache_config(
                paged_attn_config.mem_gpu,
                paged_attn_config.mem_cpu,
                paged_attn_config.block_size,
                internal_dtype,
                model_config,
                device,
                &layer_devices,
                silent,
            )?;
            let cache_engine = CacheEngine::new(
                model_config,
                &cache_config,
                internal_dtype,
                device,
                layer_devices,
            )?;
            (Some(cache_config), Some(cache_engine))
        } else {
            (None, None)
        };
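        // Read generation_config.json (if present) for BOS/EOS token IDs, then assemble the chat
        // template from the explicit, JINJA, and GGUF-embedded sources gathered above.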
        let gen_conf: Option<GenerationConfig> = paths.get_gen_conf_filename().map(|f| {
            serde_json::from_str(&fs::read_to_string(f).unwrap())
                .expect("bos_token_id/eos_token_id missing in generation_config.json")
        });
        let mut chat_template = get_chat_template(
            paths,
            &self.jinja_explicit,
            &paths
                .get_chat_template_explicit()
                .as_ref()
                .map(|x| x.to_string_lossy().to_string())
                .clone(),
            &self.chat_template,
            gguf_chat_template,
        );

        let max_seq_len = match model {
            Model::Llama(ref l) => l.max_seq_len,
            Model::Phi2(ref p) => p.max_seq_len,
            Model::XLoraLlama(ref xl) => xl.max_seq_len,
            Model::Phi3(ref p) => p.max_seq_len,
            Model::XLoraPhi3(ref p) => p.max_seq_len,
            Model::Starcoder2(ref p) => p.max_seq_len,
            Model::Qwen2(ref p) => p.max_seq_len,
        };
        let tok_env = build_tok_env(tokenizer.clone());
        let num_hidden_layers = match model {
            Model::Llama(ref model) => model.cache.normal().0.len(),
            Model::Phi2(ref model) => model.cache.normal().0.len(),
            Model::XLoraLlama(ref model) => model.cache.full().lock().len(),
            Model::Phi3(ref model) => model.cache.normal().0.len(),
            Model::XLoraPhi3(ref model) => model.cache.full().lock().len(),
            Model::Starcoder2(ref model) => model.cache.normal().0.len(),
            Model::Qwen2(ref model) => model.cache.normal().0.len(),
        };
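        // Backfill BOS/EOS/UNK tokens recovered from the GGUF metadata into the chat template
        // when the template itself does not define them.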
        if chat_template.bos_token.is_none() && bos.is_some() {
            chat_template.bos_token = Some(BeginEndUnkPadTok(Either::Left(bos.unwrap())));
        }
        if chat_template.eos_token.is_none() && eos.is_some() {
            chat_template.eos_token = Some(BeginEndUnkPadTok(Either::Left(eos.unwrap())));
        }
        if chat_template.unk_token.is_none() && unk.is_some() {
            chat_template.unk_token = Some(BeginEndUnkPadTok(Either::Left(unk.unwrap())));
        }

        let eos = calculate_eos_tokens(&chat_template, gen_conf, &tokenizer);
        Ok(Arc::new(Mutex::new(GGUFPipeline {
            model,
            tokenizer: tokenizer.into(),
            no_kv_cache: self.no_kv_cache,
            chat_template: Arc::new(chat_template),
            model_id: self
                .model_id
                .clone()
                .unwrap_or(self.quantized_model_id.clone()),
            non_granular_state: self.tgt_non_granular_index.map(|tgt_non_granular_index| {
                NonGranularState {
                    non_granular_index: Arc::new(Mutex::new(0)),
                    tgt_non_granular_index,
                }
            }),
            metadata: Arc::new(GeneralMetadata {
                max_seq_len,
                tok_env: Some(tok_env),
                no_kv_cache: self.no_kv_cache,
                no_prefix_cache: false,
                num_hidden_layers,
                eos_tok: eos,
                kind: self.kind.clone(),
                is_xlora,
                activation_dtype: internal_dtype,
                sliding_window: None,
                cache_config,
                cache_engine,
                prompt_chunksize: Some(NonZero::new(prompt_chunksize).unwrap()),
                model_metadata: Some(Arc::new(model_config_metadata)),
            }),
            mapper: pipeline_mapper,
        })))
    }

    fn get_id(&self) -> String {
        self.xlora_model_id
            .as_deref()
            .unwrap_or(self.model_id.as_ref().unwrap_or(&self.quantized_model_id))
            .to_string()
    }

    fn get_kind(&self) -> ModelKind {
        self.kind.clone()
    }
}

impl PreProcessingMixin for GGUFPipeline {
    fn get_chat_template(&self) -> Option<Arc<ChatTemplate>> {
        Some(self.chat_template.clone())
    }
    fn get_input_processor_config(&self) -> Option<Arc<dyn Any>> {
        None
    }
}
impl IsqPipelineMixin for GGUFPipeline {
    fn re_isq_model(&mut self, _dtype: IsqType) -> Result<()> {
        anyhow::bail!(
            "You are trying to in-situ requantize a GGUF model. This will not do anything."
        )
    }
}
impl CacheManagerMixin for GGUFPipeline {
    fn clone_in_cache(&self, seqs: &mut [&mut Sequence]) {
        if matches!(self.cache(), EitherCache::Full(_)) {
            FullCacheManager.clone_in_cache(self, seqs, false)
        } else {
            NormalCacheManager.clone_in_cache(self, seqs, false)
        }
    }
    fn clone_out_cache(&self, seqs: &mut [&mut Sequence]) {
        if matches!(self.cache(), EitherCache::Full(_)) {
            FullCacheManager.clone_out_cache(self, seqs, false)
        } else {
            NormalCacheManager.clone_out_cache(self, seqs, false)
        }
    }
    fn set_none_cache(
        &self,
        seqs: &mut [&mut Sequence],
        reset_non_granular: bool,
        modify_draft_cache: bool,
        load_preallocated_cache: bool,
    ) {
        if matches!(self.cache(), EitherCache::Full(_)) {
            FullCacheManager.set_none_cache(self, seqs, modify_draft_cache, false);
        } else {
            NormalCacheManager.set_none_cache(
                self,
                seqs,
                modify_draft_cache,
                load_preallocated_cache,
            );
        }
        if reset_non_granular {
            self.reset_non_granular_state()
        }
    }
    fn cache(&self) -> &EitherCache {
        match self.model {
            Model::Llama(ref model) => &model.cache,
            Model::Phi2(ref model) => &model.cache,
            Model::XLoraLlama(ref model) => &model.cache,
            Model::Phi3(ref model) => &model.cache,
            Model::XLoraPhi3(ref model) => &model.cache,
            Model::Starcoder2(ref model) => &model.cache,
            Model::Qwen2(ref model) => &model.cache,
        }
    }
}

impl MetadataMixin for GGUFPipeline {
    fn device(&self) -> Device {
        match self.model {
            Model::Llama(ref model) => model.device.clone(),
            Model::Phi2(ref model) => model.device.clone(),
            Model::XLoraLlama(ref model) => model.device.clone(),
            Model::Phi3(ref model) => model.device.clone(),
            Model::XLoraPhi3(ref model) => model.device.clone(),
            Model::Starcoder2(ref model) => model.device.clone(),
            Model::Qwen2(ref model) => model.device.clone(),
        }
    }
    fn tokenizer(&self) -> Option<Arc<Tokenizer>> {
        Some(self.tokenizer.clone())
    }
    fn name(&self) -> String {
        self.model_id.clone()
    }
    fn reset_non_granular_state(&self) {
        if let Some(s) = self.non_granular_state.as_ref() {
            *self.cache().full().get_scalings_cache() = None;
            *get_mut_arcmutex!(s.non_granular_index) = 0;
        }
    }
    fn get_metadata(&self) -> Arc<GeneralMetadata> {
        self.metadata.clone()
    }
    fn device_mapper(&self) -> Option<&dyn DeviceMapper> {
        Some(&*self.mapper)
    }
}
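// Core `Pipeline` implementation: `forward_inputs` pairs the optional PagedAttention metadata
// with the cache engine and dispatches to the architecture-specific forward pass.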
#[async_trait::async_trait]
impl Pipeline for GGUFPipeline {
    fn forward_inputs(
        &mut self,
        inputs: Box<dyn Any>,
        return_raw_logits: bool,
    ) -> Result<ForwardInputsResult, candle_core::Error> {
        let ModelInputs {
            input_ids,
            input_ids_full,
            seqlen_offsets,
            seqlen_offsets_full,
            context_lens,
            position_ids: _,
            paged_attn_meta,
            flash_meta,
            flash_meta_full,
        } = *inputs.downcast().expect("Downcast failed.");
        let metadata = self.get_metadata();
        let paged_attn_meta = match (&metadata.cache_engine, &paged_attn_meta) {
            (Some(engine), Some(meta)) => Some((engine.get_kv_cache().clone(), meta)),
            (Some(_), None) => {
                candle_core::bail!("Forward step expected a PagedAttention input metadata. This was not provided, please ensure that the scheduler config is correctly configured for PagedAttention.")
            }
            (None, Some(_)) => {
                candle_core::bail!("Forward step got a PagedAttention input metadata but there is no cache engine. Please raise an issue.")
            }
            (None, None) => None,
        };
        let logits = match self.model {
            Model::Llama(ref model) => {
                model.forward(&input_ids, &seqlen_offsets, context_lens, paged_attn_meta)?
            }
            Model::Phi2(ref model) => {
                model.forward(&input_ids, &seqlen_offsets, context_lens, paged_attn_meta)?
            }
            Model::XLoraLlama(ref model) => model.forward(
                &input_ids,
                input_ids_full.as_ref().unwrap_or(&input_ids),
                &seqlen_offsets,
                seqlen_offsets_full.as_ref().unwrap_or(&seqlen_offsets),
                self.no_kv_cache,
                &self.non_granular_state,
                context_lens,
                &flash_meta,
                flash_meta_full.as_ref().unwrap_or(&flash_meta),
            )?,
            Model::Phi3(ref model) => {
                model.forward(&input_ids, &seqlen_offsets, paged_attn_meta)?
            }
            Model::XLoraPhi3(ref model) => model.forward(
                &input_ids,
                input_ids_full.as_ref().unwrap_or(&input_ids),
                &seqlen_offsets,
                seqlen_offsets_full.as_ref().unwrap_or(&seqlen_offsets),
                self.no_kv_cache,
                &self.non_granular_state,
                context_lens,
                &flash_meta,
                flash_meta_full.as_ref().unwrap_or(&flash_meta),
            )?,
            Model::Starcoder2(ref model) => {
                model.forward(&input_ids, &seqlen_offsets, paged_attn_meta)?
            }
            Model::Qwen2(ref model) => {
                model.forward(&input_ids, &seqlen_offsets, context_lens, paged_attn_meta)?
            }
        };
        if return_raw_logits {
            Ok(ForwardInputsResult::RawLogits { logits })
        } else {
            Ok(ForwardInputsResult::CausalGeneration { logits })
        }
    }
    async fn sample_causal_gen(
        &self,
        seqs: &mut [&mut Sequence],
        logits: Vec<Tensor>,
        prefix_cacher: &mut PrefixCacheManagerV2,
        disable_eos_stop: bool,
        rng: Arc<std::sync::Mutex<Isaac64Rng>>,
    ) -> Result<(), candle_core::Error> {
        sample_and_add_toks(self, seqs, logits, prefix_cacher, disable_eos_stop, rng).await
    }
    fn category(&self) -> ModelCategory {
        ModelCategory::Text
    }
}

impl AnyMoePipelineMixin for GGUFPipeline {}