1#![allow(clippy::cast_possible_truncation)]
2
3use std::{any::Any, sync::Arc};
4
5use anyhow::Result;
6use candle_core::Device;
7use text_models_inputs_processor::PagedAttentionMeta;
8use tokenizers::Tokenizer;
9
10use crate::{device_map::DeviceMapper, sequence::Sequence};
11
/// The family of models an [`InputsProcessor`] prepares inputs for.
#[derive(PartialEq)]
pub enum InputsProcessorType {
    /// Text-only (decoder) models.
    Text,
    /// Vision-language models.
    Vision,
    /// Embedding models.
    Embedding,
}
18
/// Result of input processing: type-erased model inputs plus the indices of
/// the sequences within the original batch that those inputs cover.
pub struct InputProcessorOutput {
    /// Type-erased inputs; the pipeline downcasts this to its concrete input
    /// type (e.g. `ModelInputs` for text models).
    pub inputs: Box<dyn Any>,
    /// Indices into the `input_seqs` slice identifying which sequences
    /// `inputs` corresponds to.
    pub seq_indices: Vec<usize>,
}
23
/// Converts a batch of sequences into the tensors and metadata a model's
/// forward pass consumes.
pub trait InputsProcessor {
    /// Process the input sequences into model inputs.
    ///
    /// - `tokenizer`: optional tokenizer (unused by pure text processing).
    /// - `input_seqs`: the batch of sequences to process.
    /// - `is_prompt`: whether this is a prompt (prefill) step rather than a
    ///   single-token completion step.
    /// - `is_xlora`: whether X-LoRA-style dual (full + completion) inputs are
    ///   required.
    /// - `no_kv_cache`: when true, completion steps are processed like
    ///   prompts (the full context is re-fed).
    /// - `last_n_context_len`: assumed to be `(len, offset)` controlling how
    ///   many trailing logits are kept — TODO confirm against callers.
    /// - `return_raw_logits`: request logits for every position; incompatible
    ///   with `last_n_context_len`.
    /// - `paged_attn_metadata`: present only when PagedAttention is enabled.
    /// - `mapper`: device mapper used to replicate metadata across devices.
    #[allow(clippy::too_many_arguments)]
    fn process_inputs(
        &self,
        tokenizer: Option<Arc<Tokenizer>>,
        input_seqs: &mut [&mut Sequence],
        is_prompt: bool,
        is_xlora: bool,
        device: &Device,
        no_kv_cache: bool,
        last_n_context_len: Option<(usize, usize)>,
        return_raw_logits: bool,
        other_config: Option<Arc<dyn Any>>,
        paged_attn_metadata: Option<PagedAttentionMeta>,
        mapper: Option<&dyn DeviceMapper>,
    ) -> Result<InputProcessorOutput>;

    /// Which family of models this processor serves.
    fn get_type(&self) -> InputsProcessorType;
}
48
49pub mod text_models_inputs_processor {
52 use std::{any::Any, collections::HashMap, fmt::Debug, sync::Arc};
53
54 use anyhow::Result;
55 use candle_core::{DType, Device, DeviceLocation, Tensor, WithDType};
56 use tokenizers::Tokenizer;
57
58 use crate::{
59 device_map::DeviceMapper,
60 get_mut_arcmutex,
61 paged_attention::{BlockEngine, _PAD_SLOT_ID},
62 sequence::Sequence,
63 };
64
65 use super::{InputProcessorOutput, InputsProcessor, InputsProcessorType};
66
67 fn _make_tensor_with_pad<D: WithDType>(
68 x: Vec<Vec<D>>,
69 max_len: usize,
70 pad: D,
71 device: &Device,
72 ) -> Result<Tensor> {
73 let mut padded_x = Vec::new();
74 for mut x_i in x {
75 assert!(x_i.len() <= max_len);
76 x_i.extend([pad].repeat(max_len - x_i.len()));
77 let shape = (x_i.len(),);
78 padded_x.push(Tensor::from_vec(x_i, shape, device)?);
79 }
80 Tensor::cat(&padded_x[..], 0).map_err(anyhow::Error::msg)
81 }
82
/// Scheduler-owned state needed to build per-batch PagedAttention metadata.
pub struct PagedAttentionMeta {
    /// Sliding-window size in tokens, if the model uses windowed attention.
    pub sliding_window: Option<usize>,
    /// Number of KV-cache slots per block.
    pub block_size: usize,
    /// Block engine holding the per-sequence block tables.
    pub block_engine: Arc<tokio::sync::Mutex<BlockEngine>>,
}
88
/// Per-forward-pass PagedAttention metadata, with each tensor replicated onto
/// every device the model spans (keyed by `DeviceLocation`).
#[derive(Clone, Debug)]
#[allow(dead_code)]
pub struct PagedAttentionInputMetadata {
    /// Per-device block tables; `None` for the first prompt chunk paths that
    /// do not produce them (see `dummy`).
    pub block_tables: Option<HashMap<DeviceLocation, Tensor>>,
    /// Per-device context lengths; `None` when not produced.
    pub context_lens: Option<HashMap<DeviceLocation, Tensor>>,
    /// Per-device token-index -> KV-slot mapping.
    pub slot_mappings: HashMap<DeviceLocation, Tensor>,
    /// Maximum context length across the batch, if known.
    pub max_context_len: Option<usize>,
    /// True when this metadata belongs to the first chunk of a prompt.
    pub is_first_prompt_chunk: bool,
}
98
99 impl PagedAttentionInputMetadata {
100 pub fn dummy(dev: &Device) -> candle_core::Result<Self> {
103 Ok(PagedAttentionInputMetadata {
104 block_tables: None,
105 context_lens: None,
106 max_context_len: None,
107 slot_mappings: HashMap::from([(dev.location(), Tensor::new(&[0f32], dev)?)]),
108 is_first_prompt_chunk: true,
109 })
110 }
111 }
112
/// Flash-attention sequence-length metadata, replicated per device.
#[derive(Clone, Debug)]
pub struct FlashParams {
    /// Maximum query length across the batch.
    pub max_q: u32,
    /// Maximum key length across the batch.
    pub max_k: u32,
    /// Per-device cumulative query sequence lengths (prefix sums, U32).
    pub cumulative_seqlens_q: HashMap<DeviceLocation, Tensor>,
    /// Per-device cumulative key sequence lengths (prefix sums, U32).
    pub cumulative_seqlens_k: HashMap<DeviceLocation, Tensor>,
    /// Whether causal masking applies (always `true` in this module).
    pub causal: bool,
}
121
/// Everything the chunk builders produce for one model forward pass.
pub struct InputMetadata {
    /// Batched (right-padded) input token tensor.
    pub input: Tensor,
    /// Per-sequence position offsets (tokens already processed).
    pub positions: Vec<usize>,
    /// Per-sequence `(start, len)` ranges selecting which logits to keep.
    pub context_lens: Vec<(usize, usize)>,
    /// Per-sequence total token count including the chunk offset.
    pub position_ids: Vec<usize>,
    /// PagedAttention metadata; `Some` only when paged attention is enabled.
    pub paged_attn_meta: Option<PagedAttentionInputMetadata>,
    /// Flash-attention metadata (zeros/empty maps when flash attn disabled).
    pub flash_meta: FlashParams,
}
130
/// Typed (not yet type-erased) counterpart of [`super::InputProcessorOutput`].
pub struct InnerInputProcessorOutput {
    /// The assembled inputs for one forward pass.
    pub inputs: InputMetadata,
    /// Indices of the sequences these inputs cover.
    pub seq_indices: Vec<usize>,
}
135
136 #[allow(clippy::too_many_arguments)]
139 pub fn make_prompt_chunk<T: WithDType + Debug>(
140 chunk_offset_toks: usize,
141 toks: Vec<&[T]>,
142 seq_ids: &[usize],
143 device: &Device,
144 last_n_context_len: Option<(usize, usize)>,
145 return_raw_logits: bool,
146 mut paged_attn_metadata: Option<&mut PagedAttentionMeta>,
147 mapper: Option<&dyn DeviceMapper>,
148 ) -> Result<InputMetadata> {
149 let max_len = toks
150 .iter()
151 .map(|seq| seq.len())
152 .max()
153 .expect("No sequences");
154 let padding_tok = T::zero();
155 let mut seqs_tensors = Vec::new();
157 let mut seqlen_offsets = Vec::new();
158 let mut context_lens = Vec::new();
159 let mut position_ids = Vec::new();
160 let mut slot_mappings = Vec::new();
161 let mut block_tables = Vec::new();
162 let mut paged_attn_context_lens = Vec::new();
163 let flash_attn = crate::using_flash_attn();
164 let mut seqlens_q = if flash_attn { vec![0] } else { Vec::new() };
165 let mut seqlens_k = if flash_attn { vec![0] } else { Vec::new() };
166 for (seq_id, ctxt) in seq_ids.iter().zip(toks) {
167 let prompt_len = ctxt.len();
168 let offset = last_n_context_len.unwrap_or_default();
169 seqlen_offsets.push(offset.1 + chunk_offset_toks);
170
171 position_ids.push(ctxt.len() + chunk_offset_toks);
172 let mut ctxt = ctxt.to_vec();
173 ctxt.extend(std::iter::repeat_n(
174 padding_tok,
175 max_len.saturating_sub(ctxt.len()),
176 ));
177 if return_raw_logits {
179 if last_n_context_len.is_some() {
180 anyhow::bail!("`return_raw_logits` is incompatible with `last_n_context_len`");
181 }
182
183 context_lens.push((0, ctxt.len()));
184 } else {
185 context_lens.push((
186 ctxt.len()
187 .saturating_sub(last_n_context_len.map(|(a, _)| a).unwrap_or(1)),
188 last_n_context_len.map(|(a, _)| a).unwrap_or(1),
189 ));
190 }
191
192 if flash_attn {
193 seqlens_q.push(ctxt.len() as u32);
194 seqlens_k.push((ctxt.len() + chunk_offset_toks) as u32);
195 }
196
197 seqs_tensors.push(Tensor::new(ctxt, device).unwrap().unsqueeze(0).unwrap());
198
199 if let Some(paged_attn_metadata) = &mut paged_attn_metadata {
200 let block_engine = get_mut_arcmutex!(paged_attn_metadata.block_engine);
201 let table = block_engine.block_tables.get(seq_id);
202
203 if table.is_none() {
204 slot_mappings.push([_PAD_SLOT_ID].repeat(prompt_len));
206 continue;
207 }
208 let table = table
209 .unwrap()
210 .iter()
211 .map(|block| block.block_id())
212 .collect::<Vec<_>>();
213
214 let start_idx = if let Some(sliding_window) = paged_attn_metadata.sliding_window {
215 prompt_len.saturating_sub(sliding_window)
216 } else {
217 0
218 };
219
220 let mut slot_mapping = Vec::new();
221 let mut ctxt_len = Vec::new();
222 for i in chunk_offset_toks..prompt_len + chunk_offset_toks {
223 if i < start_idx {
224 slot_mapping.push(_PAD_SLOT_ID);
226 }
227 ctxt_len.push(i);
228
229 let block_number = if i / paged_attn_metadata.block_size >= table.len() {
230 panic!(
231 "Block table is too small (prompt)! i={} block_size={} table_len={}",
232 i,
233 paged_attn_metadata.block_size,
234 table.len()
235 );
236 } else {
237 table.get(i / paged_attn_metadata.block_size).unwrap()
238 };
239 let block_offset = i % paged_attn_metadata.block_size;
240 let slot = block_number
242 .checked_mul(paged_attn_metadata.block_size)
243 .and_then(|v| v.checked_add(block_offset))
244 .expect("Slot calculation overflowed");
245 slot_mapping.push(
246 slot.try_into()
247 .expect("Slot value too large for target integer type"),
248 );
249 block_tables.push(table.clone());
250 }
251 slot_mappings.push(slot_mapping);
252 paged_attn_context_lens.push(ctxt_len);
253 }
254 }
255
256 let (max_q, max_k, seqlens_q_map, seqlens_k_map) = if flash_attn {
257 let max_q = *seqlens_q
260 .iter()
261 .max()
262 .expect("seqlens_q should not be empty when flash_attn is enabled");
263 let max_k = *seqlens_k
264 .iter()
265 .max()
266 .expect("seqlens_k should not be empty when flash_attn is enabled");
267 let seqlens_q = Tensor::new(seqlens_q, &Device::Cpu)?
273 .to_dtype(DType::F32)?
274 .cumsum(0)?
275 .to_dtype(DType::U32)?;
276 let seqlens_k = Tensor::new(seqlens_k, &Device::Cpu)?
277 .to_dtype(DType::F32)?
278 .cumsum(0)?
279 .to_dtype(DType::U32)?;
280
281 let mut seqlens_q_map = HashMap::new();
282 let mut seqlens_k_map = HashMap::new();
283
284 let devices = mapper.unwrap().get_unique_devices();
285 for device in devices {
286 seqlens_q_map.insert(device.location(), seqlens_q.to_device(&device)?);
287 seqlens_k_map.insert(device.location(), seqlens_k.to_device(&device)?);
288 }
289 (max_q, max_k, seqlens_q_map, seqlens_k_map)
290 } else {
291 (0, 0, HashMap::new(), HashMap::new())
292 };
293
294 let input = Tensor::cat(&seqs_tensors, 0).unwrap();
295
296 let paged_attn_meta = if paged_attn_metadata.is_some() {
297 let max_slot_mapping_len = slot_mappings.iter().map(|x| x.len()).max().unwrap();
299 let slot_mappings = _make_tensor_with_pad(
300 slot_mappings,
301 max_slot_mapping_len,
302 _PAD_SLOT_ID,
303 &Device::Cpu,
304 )?;
305
306 let max_block_table_len = block_tables.iter().map(|x| x.len()).max().unwrap();
307 let block_tables = _make_tensor_with_pad(
308 block_tables
309 .iter()
310 .map(|x| x.iter().map(|x| *x as u32).collect::<Vec<_>>())
311 .collect::<Vec<_>>(),
312 max_block_table_len,
313 0,
314 &Device::Cpu,
315 )?;
316 let block_tables = block_tables.reshape(((), max_block_table_len))?;
317
318 let max_context_len = paged_attn_context_lens
319 .iter()
320 .map(|x| x.len())
321 .max()
322 .unwrap();
323
324 let context_lens = _make_tensor_with_pad(
325 paged_attn_context_lens
326 .iter()
327 .map(|x| x.iter().map(|x| *x as u32).collect::<Vec<_>>())
328 .collect::<Vec<_>>(),
329 max_context_len,
330 0,
331 &Device::Cpu,
332 )?
333 .reshape(((),))?;
334
335 let devices = mapper.unwrap().get_unique_devices();
337 let mut slot_mappings_map = HashMap::new();
338 let mut block_tables_map = HashMap::new();
339 let mut context_lens_map = HashMap::new();
340
341 for device in devices {
342 slot_mappings_map
343 .insert(device.location(), slot_mappings.clone().to_device(&device)?);
344 block_tables_map
345 .insert(device.location(), block_tables.clone().to_device(&device)?);
346 context_lens_map
347 .insert(device.location(), context_lens.clone().to_device(&device)?);
348 }
349
350 Some(PagedAttentionInputMetadata {
351 slot_mappings: slot_mappings_map,
352 block_tables: Some(block_tables_map),
353 context_lens: Some(context_lens_map),
354 max_context_len: Some(max_context_len),
355 is_first_prompt_chunk: chunk_offset_toks == 0,
356 })
357 } else {
358 None
359 };
360
361 Ok(InputMetadata {
362 input,
363 positions: seqlen_offsets,
364 context_lens,
365 position_ids,
366 paged_attn_meta,
367 flash_meta: FlashParams {
368 max_k,
369 max_q,
370 cumulative_seqlens_k: seqlens_k_map,
371 cumulative_seqlens_q: seqlens_q_map,
372 causal: true,
373 },
374 })
375 }
376
/// Build the batched [`InputMetadata`] for a single-token completion
/// (decode) step: only the last token of each sequence is fed to the model,
/// with the KV cache supplying prior context.
///
/// When `paged_attn_metadata` is provided, the KV slot of the new token, the
/// (possibly window-truncated) block table, and per-sequence context lengths
/// are produced. When flash attention is enabled, cumulative sequence
/// lengths are replicated onto every device of `mapper`.
fn make_completion_chunk<T: WithDType>(
    toks: Vec<&[T]>,
    input_seqs: &[&mut Sequence],
    device: &Device,
    mut paged_attn_metadata: Option<&mut PagedAttentionMeta>,
    mapper: Option<&dyn DeviceMapper>,
) -> Result<InputMetadata> {
    let flash_attn = crate::using_flash_attn();
    let mut seqs_tensors = Vec::new();
    let mut seqlen_offsets = Vec::new();
    let mut context_lens = Vec::new();
    let mut position_ids = Vec::new();

    let mut slot_mappings = Vec::new();
    let mut block_tables = Vec::new();
    let mut paged_attn_context_lens = Vec::new();
    // Flash attention consumes *cumulative* lengths, so seed with 0.
    let mut seqlens_q = if flash_attn { vec![0] } else { Vec::new() };
    let mut seqlens_k = if flash_attn { vec![0] } else { Vec::new() };
    for (seq, ctxt) in input_seqs.iter().zip(toks) {
        // Feed only the final token; everything before it is cached.
        let start_pos = ctxt.len().saturating_sub(1);
        let ctxt = ctxt[start_pos..].to_vec();
        seqlen_offsets.push(start_pos);
        // Keep the single new logit.
        context_lens.push((0, 1));
        position_ids.push(seq.len());

        if flash_attn {
            seqlens_q.push(ctxt.len() as u32);
            seqlens_k.push((ctxt.len() + start_pos) as u32);
        }

        seqs_tensors.push(Tensor::new(ctxt, device).unwrap().unsqueeze(0).unwrap());

        if let Some(paged_attn_metadata) = &mut paged_attn_metadata {
            let block_engine = get_mut_arcmutex!(paged_attn_metadata.block_engine);
            let table = block_engine.block_tables.get(seq.id()).unwrap();

            let table = table
                .iter()
                .map(|block| block.block_id())
                .collect::<Vec<_>>();

            // Locate the KV slot of the token being generated.
            let block_pos = start_pos - seq.token_offset();
            let block_number = if block_pos / paged_attn_metadata.block_size >= table.len() {
                panic!("Block table is too small (completion)! start_pos={} block_size={} table_len={}", block_pos, paged_attn_metadata.block_size, table.len());
            } else {
                table
                    .get(block_pos / paged_attn_metadata.block_size)
                    .unwrap()
            };
            let block_offset = block_pos % paged_attn_metadata.block_size;
            // Checked arithmetic: slot = block_number * block_size + offset.
            let slot = block_number
                .checked_mul(paged_attn_metadata.block_size)
                .and_then(|v| v.checked_add(block_offset))
                .expect("Slot calculation overflowed");
            let slot = slot
                .try_into()
                .expect("Slot value too large for target integer type");
            slot_mappings.push(vec![slot]);

            if let Some(sliding_window) = paged_attn_metadata.sliding_window {
                // Only the trailing blocks inside the sliding window remain
                // visible to the attention kernel.
                let sliding_window_blocks = sliding_window / paged_attn_metadata.block_size;
                let slide_idx = if table.len() > sliding_window_blocks {
                    table.len() - sliding_window_blocks
                } else {
                    0
                };
                block_tables.push(table.get(slide_idx..).unwrap().to_vec());
            } else {
                block_tables.push(table);
            }

            // Context is capped at the sliding window when one is set.
            let paged_attn_context_len =
                if let Some(sliding_window) = paged_attn_metadata.sliding_window {
                    seq.len().min(sliding_window)
                } else {
                    seq.len()
                };
            paged_attn_context_lens.push(paged_attn_context_len);
        }
    }

    let (max_q, max_k, seqlens_q_map, seqlens_k_map) = if flash_attn {
        let max_q = *seqlens_q
            .iter()
            .max()
            .expect("seqlens_q should not be empty when flash_attn is enabled");
        let max_k = *seqlens_k
            .iter()
            .max()
            .expect("seqlens_k should not be empty when flash_attn is enabled");
        // Prefix sums are computed in F32 and cast back to U32.
        let seqlens_q = Tensor::new(seqlens_q, &Device::Cpu)?
            .to_dtype(DType::F32)?
            .cumsum(0)?
            .to_dtype(DType::U32)?;
        let seqlens_k = Tensor::new(seqlens_k, &Device::Cpu)?
            .to_dtype(DType::F32)?
            .cumsum(0)?
            .to_dtype(DType::U32)?;

        let mut seqlens_q_map = HashMap::new();
        let mut seqlens_k_map = HashMap::new();

        // Replicate the cumulative lengths onto every model device.
        let devices = mapper.unwrap().get_unique_devices();
        for device in devices {
            seqlens_q_map.insert(device.location(), seqlens_q.to_device(&device)?);
            seqlens_k_map.insert(device.location(), seqlens_k.to_device(&device)?);
        }
        (max_q, max_k, seqlens_q_map, seqlens_k_map)
    } else {
        (0, 0, HashMap::new(), HashMap::new())
    };

    let paged_attn_meta = if paged_attn_metadata.is_some() {
        // One slot per sequence, so the padded row length is 1.
        let slot_mappings =
            _make_tensor_with_pad(slot_mappings, 1, _PAD_SLOT_ID, &Device::Cpu)?;

        let max_block_table_len = block_tables
            .iter()
            .map(|x| x.len())
            .max()
            .expect("block_tables should not be empty when paged attention is enabled");

        let block_tables = _make_tensor_with_pad(
            block_tables
                .iter()
                .map(|x| x.iter().map(|x| *x as u32).collect::<Vec<_>>())
                .collect::<Vec<_>>(),
            max_block_table_len,
            0,
            &Device::Cpu,
        )?;
        let block_tables = block_tables.reshape(((), max_block_table_len))?;

        let max_context_len = paged_attn_context_lens.iter().max().unwrap();

        let context_lens = Tensor::from_vec(
            paged_attn_context_lens
                .iter()
                .map(|x| *x as u32)
                .collect::<Vec<_>>(),
            (paged_attn_context_lens.len(),),
            &Device::Cpu,
        )?;

        // Replicate the metadata tensors onto every model device.
        let devices = mapper.unwrap().get_unique_devices();
        let mut slot_mappings_map = HashMap::new();
        let mut block_tables_map = HashMap::new();
        let mut context_lens_map = HashMap::new();

        for device in devices {
            slot_mappings_map
                .insert(device.location(), slot_mappings.clone().to_device(&device)?);
            block_tables_map
                .insert(device.location(), block_tables.clone().to_device(&device)?);
            context_lens_map
                .insert(device.location(), context_lens.clone().to_device(&device)?);
        }

        Some(PagedAttentionInputMetadata {
            slot_mappings: slot_mappings_map,
            block_tables: Some(block_tables_map),
            context_lens: Some(context_lens_map),
            max_context_len: Some(*max_context_len),
            is_first_prompt_chunk: false,
        })
    } else {
        None
    };

    Ok(InputMetadata {
        input: Tensor::cat(&seqs_tensors, 0).unwrap(),
        positions: seqlen_offsets,
        context_lens,
        position_ids,
        paged_attn_meta,
        flash_meta: FlashParams {
            max_k,
            max_q,
            cumulative_seqlens_k: seqlens_k_map,
            cumulative_seqlens_q: seqlens_q_map,
            causal: true,
        },
    })
}
568
569 #[allow(clippy::too_many_arguments)]
570 pub(crate) fn get_prompt_input<T: WithDType + std::fmt::Debug>(
571 toks: Vec<&[T]>,
572 input_seqs: &[&mut Sequence],
573 device: &Device,
574 last_n_context_len: Option<(usize, usize)>,
575 return_raw_logits: bool,
576 paged_attn_metadata: Option<&mut PagedAttentionMeta>,
577 mapper: Option<&dyn DeviceMapper>,
578 ) -> Result<InnerInputProcessorOutput> {
579 let offset = input_seqs[0].token_offset();
580 make_prompt_chunk(
581 offset,
582 toks,
583 &input_seqs.iter().map(|s| *s.id()).collect::<Vec<_>>(),
584 device,
585 last_n_context_len,
586 return_raw_logits,
587 paged_attn_metadata,
588 mapper,
589 )
590 .map(|inputs| InnerInputProcessorOutput {
591 inputs,
592 seq_indices: (0..input_seqs.len()).collect(),
593 })
594 }
595
596 #[allow(clippy::too_many_arguments)]
597 pub(crate) fn get_completion_input<T: WithDType + std::fmt::Debug>(
598 toks: Vec<&[T]>,
599 input_seqs: &[&mut Sequence],
600 device: &Device,
601 no_kv_cache: bool,
602 last_n_context_len: Option<(usize, usize)>,
603 return_raw_logits: bool,
604 paged_attn_metadata: Option<&mut PagedAttentionMeta>,
605 mapper: Option<&dyn DeviceMapper>,
606 ) -> Result<InnerInputProcessorOutput> {
607 if no_kv_cache {
608 return get_prompt_input(
609 toks,
610 input_seqs,
611 device,
612 last_n_context_len,
613 return_raw_logits,
614 paged_attn_metadata,
615 mapper,
616 );
617 }
618
619 make_completion_chunk(toks, input_seqs, device, paged_attn_metadata, mapper).map(|inputs| {
620 InnerInputProcessorOutput {
621 inputs,
622 seq_indices: (0..input_seqs.len()).collect(),
623 }
624 })
625 }
626
/// Concrete inputs for text models; boxed as `dyn Any` inside
/// `InputProcessorOutput` and downcast by the pipeline.
#[derive(Clone)]
pub struct ModelInputs {
    /// Batched input token tensor.
    pub input_ids: Tensor,
    /// Full-prompt token tensor; populated only on the X-LoRA paths.
    pub input_ids_full: Option<Tensor>,
    /// Per-sequence position offsets.
    pub seqlen_offsets: Vec<usize>,
    /// Full-prompt position offsets; populated only on the X-LoRA paths.
    pub seqlen_offsets_full: Option<Vec<usize>>,
    /// Per-sequence `(start, len)` logit-selection ranges.
    pub context_lens: Vec<(usize, usize)>,
    /// Per-sequence total position counts.
    pub position_ids: Vec<usize>,
    /// PagedAttention metadata when enabled.
    pub paged_attn_meta: Option<PagedAttentionInputMetadata>,
    /// Flash-attention metadata.
    pub flash_meta: FlashParams,
    /// Full-prompt flash metadata; populated only on the X-LoRA paths.
    pub flash_meta_full: Option<FlashParams>,
}
639
/// [`InputsProcessor`] implementation for text-only models.
pub struct TextInputsProcessor;
641
642 impl InputsProcessor for TextInputsProcessor {
643 fn process_inputs(
644 &self,
645 _: Option<Arc<Tokenizer>>,
646 input_seqs: &mut [&mut Sequence],
647 is_prompt: bool,
648 is_xlora: bool,
649 device: &Device,
650 no_kv_cache: bool,
651 last_n_context_len: Option<(usize, usize)>,
652 return_raw_logits: bool,
653 _: Option<Arc<dyn Any>>,
654 mut paged_attn_metadata: Option<PagedAttentionMeta>,
655 mapper: Option<&dyn DeviceMapper>,
656 ) -> Result<InputProcessorOutput> {
657 if is_xlora && !is_prompt {
658 let prompt = get_prompt_input(
659 input_seqs
660 .iter()
661 .map(|seq| seq.get_toks())
662 .collect::<Vec<_>>(),
663 input_seqs,
664 device,
665 last_n_context_len,
666 return_raw_logits,
667 paged_attn_metadata.as_mut(),
668 mapper,
669 )?;
670 let completion = get_completion_input(
671 input_seqs
672 .iter()
673 .map(|seq| seq.get_toks())
674 .collect::<Vec<_>>(),
675 input_seqs,
676 device,
677 no_kv_cache,
678 last_n_context_len,
679 return_raw_logits,
680 paged_attn_metadata.as_mut(),
681 mapper,
682 )?;
683 let InnerInputProcessorOutput {
684 inputs:
685 InputMetadata {
686 input: input_ids_full,
687 positions: seqlen_offsets_full,
688 context_lens: _,
689 position_ids,
690 paged_attn_meta: _,
691 flash_meta: flash_meta_full,
692 },
693 seq_indices,
694 } = prompt;
695 let InnerInputProcessorOutput {
696 inputs:
697 InputMetadata {
698 input: input_ids,
699 positions: seqlen_offsets,
700 context_lens,
701 position_ids: _,
702 paged_attn_meta,
703 flash_meta,
704 },
705 seq_indices: _,
706 } = completion;
707 let inputs: Box<dyn Any> = Box::new(ModelInputs {
708 input_ids,
709 input_ids_full: Some(input_ids_full),
710 seqlen_offsets,
711 seqlen_offsets_full: Some(seqlen_offsets_full),
712 context_lens,
713 position_ids,
714 paged_attn_meta,
715 flash_meta,
716 flash_meta_full: Some(flash_meta_full),
717 });
718 Ok(InputProcessorOutput {
719 inputs,
720 seq_indices,
721 })
722 } else if is_xlora && is_prompt {
723 let metadata = get_prompt_input(
724 input_seqs
725 .iter()
726 .map(|seq| seq.get_toks())
727 .collect::<Vec<_>>(),
728 input_seqs,
729 device,
730 last_n_context_len,
731 return_raw_logits,
732 paged_attn_metadata.as_mut(),
733 mapper,
734 )?;
735 let InnerInputProcessorOutput {
736 inputs:
737 InputMetadata {
738 input: input_ids,
739 positions: seqlen_offsets,
740 context_lens,
741 position_ids,
742 paged_attn_meta,
743 flash_meta,
744 },
745 seq_indices,
746 } = metadata;
747 let inputs: Box<dyn Any> = Box::new(ModelInputs {
748 input_ids: input_ids.clone(),
749 input_ids_full: Some(input_ids),
750 seqlen_offsets: seqlen_offsets.clone(),
751 seqlen_offsets_full: Some(seqlen_offsets),
752 context_lens,
753 position_ids,
754 paged_attn_meta,
755 flash_meta: flash_meta.clone(),
756 flash_meta_full: Some(flash_meta),
757 });
758 Ok(InputProcessorOutput {
759 inputs,
760 seq_indices,
761 })
762 } else if is_prompt {
763 let metadata = get_prompt_input(
764 input_seqs
765 .iter()
766 .map(|seq| seq.get_toks())
767 .collect::<Vec<_>>(),
768 input_seqs,
769 device,
770 last_n_context_len,
771 return_raw_logits,
772 paged_attn_metadata.as_mut(),
773 mapper,
774 )?;
775 let InnerInputProcessorOutput {
776 inputs:
777 InputMetadata {
778 input: input_ids,
779 positions: seqlen_offsets,
780 context_lens,
781 position_ids,
782 paged_attn_meta,
783 flash_meta,
784 },
785 seq_indices,
786 } = metadata;
787 let inputs: Box<dyn Any> = Box::new(ModelInputs {
788 input_ids,
789 input_ids_full: None,
790 seqlen_offsets,
791 seqlen_offsets_full: None,
792 context_lens,
793 position_ids,
794 paged_attn_meta,
795 flash_meta,
796 flash_meta_full: None,
797 });
798 Ok(InputProcessorOutput {
799 inputs,
800 seq_indices,
801 })
802 } else {
803 let metadata = get_completion_input(
804 input_seqs
805 .iter()
806 .map(|seq| seq.get_toks())
807 .collect::<Vec<_>>(),
808 input_seqs,
809 device,
810 no_kv_cache,
811 last_n_context_len,
812 return_raw_logits,
813 paged_attn_metadata.as_mut(),
814 mapper,
815 )?;
816 let InnerInputProcessorOutput {
817 inputs:
818 InputMetadata {
819 input: input_ids,
820 positions: seqlen_offsets,
821 context_lens,
822 position_ids,
823 paged_attn_meta,
824 flash_meta,
825 },
826 seq_indices,
827 } = metadata;
828 let inputs: Box<dyn Any> = Box::new(ModelInputs {
829 input_ids,
830 input_ids_full: None,
831 seqlen_offsets,
832 seqlen_offsets_full: None,
833 context_lens,
834 position_ids,
835 paged_attn_meta,
836 flash_meta,
837 flash_meta_full: None,
838 });
839 Ok(InputProcessorOutput {
840 inputs,
841 seq_indices,
842 })
843 }
844 }
845
846 fn get_type(&self) -> InputsProcessorType {
847 InputsProcessorType::Text
848 }
849 }
850}