1#![allow(clippy::cast_possible_truncation)]
2
3use std::{any::Any, sync::Arc};
4
5use anyhow::Result;
6use candle_core::Device;
7use text_models_inputs_processor::PagedAttentionMeta;
8use tokenizers::Tokenizer;
9
10use crate::{device_map::DeviceMapper, sequence::Sequence};
11
/// The modality a concrete [`InputsProcessor`] implementation handles.
///
/// Derives are kept maximal for this fieldless enum (`Copy`/`Eq`/`Debug` are
/// free and make the type usable in logs, maps, and assertions).
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum InputsProcessorType {
    Text,
    Vision,
    Embedding,
}
18
/// Result of [`InputsProcessor::process_inputs`]: type-erased model inputs
/// plus the indices of the sequences those inputs correspond to.
pub struct InputProcessorOutput {
    // Type-erased inputs (e.g. `ModelInputs` for text); the consuming model
    // downcasts to the concrete type it expects.
    pub inputs: Box<dyn Any>,
    // Indices into the caller's `input_seqs` slice for each batched row.
    pub seq_indices: Vec<usize>,
}
23
/// Converts a batch of scheduled [`Sequence`]s into the concrete inputs a
/// model forward pass consumes.
pub trait InputsProcessor {
    /// Build model inputs for `input_seqs`.
    ///
    /// * `tokenizer` — tokenizer for processors that need it; may be `None`.
    /// * `input_seqs` — the batch of sequences to process.
    /// * `is_prompt` — `true` for a prefill (prompt) step, `false` for decode.
    /// * `is_xlora` — whether dual (full-context + incremental) inputs are needed.
    /// * `device` — primary device for the produced tensors.
    /// * `no_kv_cache` — if `true`, the full context is reprocessed every step.
    /// * `last_n_context_len` — keep only logits for the last n context tokens
    ///   (incompatible with `return_raw_logits`).
    /// * `return_raw_logits` — return logits for the whole sequence.
    /// * `other_config` — processor-specific configuration, if any.
    /// * `paged_attn_metadata` — paged-attention state when paged KV cache is enabled.
    /// * `mapper` — device mapper used to replicate tensors across unique devices.
    #[allow(clippy::too_many_arguments)]
    fn process_inputs(
        &self,
        tokenizer: Option<Arc<Tokenizer>>,
        input_seqs: &mut [&mut Sequence],
        is_prompt: bool,
        is_xlora: bool,
        device: &Device,
        no_kv_cache: bool,
        last_n_context_len: Option<(usize, usize)>,
        return_raw_logits: bool,
        other_config: Option<Arc<dyn Any>>,
        paged_attn_metadata: Option<PagedAttentionMeta>,
        mapper: Option<&dyn DeviceMapper>,
    ) -> Result<InputProcessorOutput>;

    /// The modality this processor handles.
    fn get_type(&self) -> InputsProcessorType;
}
48
49pub mod text_models_inputs_processor {
52 use std::{any::Any, collections::HashMap, fmt::Debug, sync::Arc};
53
54 use anyhow::Result;
55 use candle_core::{DType, Device, DeviceLocation, Tensor, WithDType};
56 use tokenizers::Tokenizer;
57
58 use crate::{
59 device_map::DeviceMapper,
60 get_mut_arcmutex,
61 paged_attention::{BlockEngine, _PAD_SLOT_ID},
62 sequence::Sequence,
63 };
64
65 use super::{InputProcessorOutput, InputsProcessor, InputsProcessorType};
66
67 fn _make_tensor_with_pad<D: WithDType>(
68 x: Vec<Vec<D>>,
69 max_len: usize,
70 pad: D,
71 device: &Device,
72 ) -> Result<Tensor> {
73 let mut padded_x = Vec::new();
74 for mut x_i in x {
75 assert!(x_i.len() <= max_len);
76 x_i.extend([pad].repeat(max_len - x_i.len()));
77 let shape = (x_i.len(),);
78 padded_x.push(Tensor::from_vec(x_i, shape, device)?);
79 }
80 Tensor::cat(&padded_x[..], 0).map_err(anyhow::Error::msg)
81 }
82
    /// Scheduling-time state needed to build paged-attention inputs
    /// (slot mappings, block tables) for a batch.
    pub struct PagedAttentionMeta {
        // Sliding-window length in tokens, if the model uses windowed attention.
        pub sliding_window: Option<usize>,
        // Number of KV-cache slots per block.
        pub block_size: usize,
        // Shared block allocator holding per-sequence block tables.
        pub block_engine: Arc<tokio::sync::Mutex<BlockEngine>>,
    }
88
    /// Device-resident paged-attention inputs for one forward pass. Tensors are
    /// keyed by [`DeviceLocation`] so a multi-device pipeline can pick its
    /// local copy without a transfer.
    #[derive(Clone, Debug)]
    #[allow(dead_code)]
    pub struct PagedAttentionInputMetadata {
        // Block tables per device; `None` when not applicable (e.g. dummy metadata).
        pub block_tables: Option<HashMap<DeviceLocation, Tensor>>,
        // Per-sequence context lengths per device; `None` when not applicable.
        pub context_lens: Option<HashMap<DeviceLocation, Tensor>>,
        // KV-cache slot index per token, per device.
        pub slot_mappings: HashMap<DeviceLocation, Tensor>,
        // Maximum context length across the batch, if known.
        pub max_context_len: Option<usize>,
        // `true` only for the first chunk of a chunked prompt (offset 0).
        pub is_first_prompt_chunk: bool,
    }
98
    impl PagedAttentionInputMetadata {
        /// Build a placeholder metadata value for passes that have no real
        /// paged-attention state: no block tables or context lengths, and a
        /// single scalar slot-mapping entry for `dev`.
        pub fn dummy(dev: &Device) -> candle_core::Result<Self> {
            Ok(PagedAttentionInputMetadata {
                block_tables: None,
                context_lens: None,
                max_context_len: None,
                // NOTE(review): the dummy slot mapping is an f32 tensor, while real
                // slot mappings are built from integer slots — presumably this value
                // is never read; confirm against the attention kernels.
                slot_mappings: HashMap::from([(dev.location(), Tensor::new(&[0f32], dev)?)]),
                is_first_prompt_chunk: true,
            })
        }
    }
112
    /// Flash-attention metadata: maximum and cumulative sequence lengths for
    /// queries and keys, replicated per device location.
    #[derive(Clone, Debug)]
    pub struct FlashParams {
        // Maximum query length across the batch.
        pub max_q: u32,
        // Maximum key length across the batch.
        pub max_k: u32,
        // Cumulative (prefix-sum) query lengths, starting at 0, per device.
        pub cumulative_seqlens_q: HashMap<DeviceLocation, Tensor>,
        // Cumulative (prefix-sum) key lengths, starting at 0, per device.
        pub cumulative_seqlens_k: HashMap<DeviceLocation, Tensor>,
        // Whether causal masking applies.
        pub causal: bool,
    }
121
122 pub struct InputMetadata {
123 pub input: Tensor,
124 pub positions: Vec<usize>,
125 pub context_lens: Vec<(usize, usize)>, pub position_ids: Vec<usize>,
127 pub paged_attn_meta: Option<PagedAttentionInputMetadata>, pub flash_meta: FlashParams,
129 }
130
    /// Typed (non-erased) counterpart of [`super::InputProcessorOutput`] used
    /// inside this module before boxing into `dyn Any`.
    pub struct InnerInputProcessorOutput {
        // The constructed model inputs.
        pub inputs: InputMetadata,
        // Indices into the caller's sequence slice for each batched row.
        pub seq_indices: Vec<usize>,
    }
135
136 #[allow(clippy::too_many_arguments)]
139 pub fn make_prompt_chunk<T: WithDType + Debug>(
140 chunk_offset_toks: usize,
141 toks: Vec<&[T]>,
142 seq_ids: &[usize],
143 device: &Device,
144 last_n_context_len: Option<(usize, usize)>,
145 return_raw_logits: bool,
146 mut paged_attn_metadata: Option<&mut PagedAttentionMeta>,
147 mapper: Option<&dyn DeviceMapper>,
148 ) -> Result<InputMetadata> {
149 let max_len = toks
150 .iter()
151 .map(|seq| seq.len())
152 .max()
153 .expect("No sequences");
154 let padding_tok = T::zero();
155 let mut seqs_tensors = Vec::new();
157 let mut seqlen_offsets = Vec::new();
158 let mut context_lens = Vec::new();
159 let mut position_ids = Vec::new();
160 let mut slot_mappings = Vec::new();
161 let mut block_tables = Vec::new();
162 let mut paged_attn_context_lens = Vec::new();
163 let flash_attn = crate::using_flash_attn();
164 let mut seqlens_q = if flash_attn { vec![0] } else { Vec::new() };
165 let mut seqlens_k = if flash_attn { vec![0] } else { Vec::new() };
166 for (seq_id, ctxt) in seq_ids.iter().zip(toks) {
167 let prompt_len = ctxt.len();
168 let offset = last_n_context_len.unwrap_or_default();
169 seqlen_offsets.push(offset.1 + chunk_offset_toks);
170
171 position_ids.push(ctxt.len() + chunk_offset_toks);
172 let mut ctxt = ctxt.to_vec();
173 ctxt.extend(std::iter::repeat_n(
174 padding_tok,
175 max_len.saturating_sub(ctxt.len()),
176 ));
177 if return_raw_logits {
179 if last_n_context_len.is_some() {
180 anyhow::bail!("`return_raw_logits` is incompatible with `last_n_context_len`");
181 }
182
183 context_lens.push((0, ctxt.len()));
184 } else {
185 context_lens.push((
186 ctxt.len()
187 .saturating_sub(last_n_context_len.map(|(a, _)| a).unwrap_or(1)),
188 last_n_context_len.map(|(a, _)| a).unwrap_or(1),
189 ));
190 }
191
192 if flash_attn {
193 seqlens_q.push(ctxt.len() as u32);
194 seqlens_k.push((ctxt.len() + chunk_offset_toks) as u32);
195 }
196
197 seqs_tensors.push(Tensor::new(ctxt, device).unwrap().unsqueeze(0).unwrap());
198
199 if let Some(paged_attn_metadata) = &mut paged_attn_metadata {
200 let block_engine = get_mut_arcmutex!(paged_attn_metadata.block_engine);
201 let table = block_engine.block_tables.get(seq_id);
202
203 if table.is_none() {
204 slot_mappings.push([_PAD_SLOT_ID].repeat(prompt_len));
206 continue;
207 }
208 let table = table
209 .unwrap()
210 .iter()
211 .map(|block| block.deref_mut().block_id)
212 .collect::<Vec<_>>();
213
214 let start_idx = if let Some(sliding_window) = paged_attn_metadata.sliding_window {
215 prompt_len.saturating_sub(sliding_window)
216 } else {
217 0
218 };
219
220 let mut slot_mapping = Vec::new();
221 let mut ctxt_len = Vec::new();
222 for i in chunk_offset_toks..prompt_len + chunk_offset_toks {
223 if i < start_idx {
224 slot_mapping.push(_PAD_SLOT_ID);
226 }
227 ctxt_len.push(i);
228
229 let block_number = if i / paged_attn_metadata.block_size >= table.len() {
230 panic!(
231 "Block table is too small (prompt)! i={} block_size={} table_len={}",
232 i,
233 paged_attn_metadata.block_size,
234 table.len()
235 );
236 } else {
237 table.get(i / paged_attn_metadata.block_size).unwrap()
238 };
239 let block_offset = i % paged_attn_metadata.block_size;
240 let slot = block_number
242 .checked_mul(paged_attn_metadata.block_size)
243 .and_then(|v| v.checked_add(block_offset))
244 .expect("Slot calculation overflowed");
245 slot_mapping.push(
246 slot.try_into()
247 .expect("Slot value too large for target integer type"),
248 );
249 block_tables.push(table.clone());
250 }
251 slot_mappings.push(slot_mapping);
252 paged_attn_context_lens.push(ctxt_len);
253 }
254 }
255
256 let (max_q, max_k, seqlens_q_map, seqlens_k_map) = if flash_attn {
257 let max_q = *seqlens_q
260 .iter()
261 .max()
262 .expect("seqlens_q should not be empty when flash_attn is enabled");
263 let max_k = *seqlens_k
264 .iter()
265 .max()
266 .expect("seqlens_k should not be empty when flash_attn is enabled");
267 let seqlens_q = Tensor::new(seqlens_q, device)?
268 .to_dtype(DType::F32)?
269 .cumsum(0)?
270 .to_dtype(DType::U32)?;
271 let seqlens_k = Tensor::new(seqlens_k, device)?
272 .to_dtype(DType::F32)?
273 .cumsum(0)?
274 .to_dtype(DType::U32)?;
275
276 let mut seqlens_q_map = HashMap::new();
277 let mut seqlens_k_map = HashMap::new();
278
279 let devices = mapper.unwrap().get_unique_devices();
280 for device in devices {
281 seqlens_q_map.insert(device.location(), seqlens_q.to_device(&device)?);
282 seqlens_k_map.insert(device.location(), seqlens_k.to_device(&device)?);
283 }
284 (max_q, max_k, seqlens_q_map, seqlens_k_map)
285 } else {
286 (0, 0, HashMap::new(), HashMap::new())
287 };
288
289 let input = Tensor::cat(&seqs_tensors, 0).unwrap();
290
291 let paged_attn_meta = if paged_attn_metadata.is_some() {
292 let max_slot_mapping_len = slot_mappings.iter().map(|x| x.len()).max().unwrap();
293 let slot_mappings =
294 _make_tensor_with_pad(slot_mappings, max_slot_mapping_len, _PAD_SLOT_ID, device)?;
295
296 let max_block_table_len = block_tables.iter().map(|x| x.len()).max().unwrap();
297 let block_tables = _make_tensor_with_pad(
298 block_tables
299 .iter()
300 .map(|x| x.iter().map(|x| *x as u32).collect::<Vec<_>>())
301 .collect::<Vec<_>>(),
302 max_block_table_len,
303 0,
304 device,
305 )?;
306 let block_tables = block_tables.reshape(((), max_block_table_len))?;
307
308 let max_context_len = paged_attn_context_lens
309 .iter()
310 .map(|x| x.len())
311 .max()
312 .unwrap();
313
314 let context_lens = _make_tensor_with_pad(
315 paged_attn_context_lens
316 .iter()
317 .map(|x| x.iter().map(|x| *x as u32).collect::<Vec<_>>())
318 .collect::<Vec<_>>(),
319 max_context_len,
320 0,
321 device,
322 )?
323 .reshape(((),))?;
324
325 let devices = mapper.unwrap().get_unique_devices();
327 let mut slot_mappings_map = HashMap::new();
328 let mut block_tables_map = HashMap::new();
329 let mut context_lens_map = HashMap::new();
330
331 for device in devices {
332 slot_mappings_map
333 .insert(device.location(), slot_mappings.clone().to_device(&device)?);
334 block_tables_map
335 .insert(device.location(), block_tables.clone().to_device(&device)?);
336 context_lens_map
337 .insert(device.location(), context_lens.clone().to_device(&device)?);
338 }
339
340 Some(PagedAttentionInputMetadata {
341 slot_mappings: slot_mappings_map,
342 block_tables: Some(block_tables_map),
343 context_lens: Some(context_lens_map),
344 max_context_len: Some(max_context_len),
345 is_first_prompt_chunk: chunk_offset_toks == 0,
346 })
347 } else {
348 None
349 };
350
351 Ok(InputMetadata {
352 input,
353 positions: seqlen_offsets,
354 context_lens,
355 position_ids,
356 paged_attn_meta,
357 flash_meta: FlashParams {
358 max_k,
359 max_q,
360 cumulative_seqlens_k: seqlens_k_map,
361 cumulative_seqlens_q: seqlens_q_map,
362 causal: true,
363 },
364 })
365 }
366
    /// Builds [`InputMetadata`] for a decode (completion) step: only the last
    /// token of each sequence is fed through the model, relying on the KV
    /// cache for earlier context.
    fn make_completion_chunk<T: WithDType>(
        toks: Vec<&[T]>,
        input_seqs: &[&mut Sequence],
        device: &Device,
        mut paged_attn_metadata: Option<&mut PagedAttentionMeta>,
        mapper: Option<&dyn DeviceMapper>,
    ) -> Result<InputMetadata> {
        let flash_attn = crate::using_flash_attn();
        let mut seqs_tensors = Vec::new();
        let mut seqlen_offsets = Vec::new();
        let mut context_lens = Vec::new();
        let mut position_ids = Vec::new();

        let mut slot_mappings = Vec::new();
        let mut block_tables = Vec::new();
        let mut paged_attn_context_lens = Vec::new();
        // Cumulative seqlens carry a leading 0 as the prefix-sum base.
        let mut seqlens_q = if flash_attn { vec![0] } else { Vec::new() };
        let mut seqlens_k = if flash_attn { vec![0] } else { Vec::new() };
        for (seq, ctxt) in input_seqs.iter().zip(toks) {
            // Only the final token is processed this step.
            let start_pos = ctxt.len().saturating_sub(1);
            let ctxt = ctxt[start_pos..].to_vec();
            seqlen_offsets.push(start_pos);
            context_lens.push((0, 1));
            position_ids.push(seq.len());

            if flash_attn {
                seqlens_q.push(ctxt.len() as u32);
                seqlens_k.push((ctxt.len() + start_pos) as u32);
            }

            seqs_tensors.push(Tensor::new(ctxt, device).unwrap().unsqueeze(0).unwrap());

            if let Some(paged_attn_metadata) = &mut paged_attn_metadata {
                let block_engine = get_mut_arcmutex!(paged_attn_metadata.block_engine);
                let table = block_engine.block_tables.get(seq.id()).unwrap();

                let table = table
                    .iter()
                    .map(|block| block.deref_mut().block_id)
                    .collect::<Vec<_>>();

                // Locate the cache slot of the single new token.
                // NOTE(review): assumes `start_pos >= seq.token_offset()` —
                // would underflow otherwise; confirm against the scheduler.
                let block_pos = start_pos - seq.token_offset();
                let block_number = if block_pos / paged_attn_metadata.block_size >= table.len() {
                    panic!("Block table is too small (completion)! start_pos={} block_size={} table_len={}", block_pos, paged_attn_metadata.block_size, table.len());
                } else {
                    table
                        .get(block_pos / paged_attn_metadata.block_size)
                        .unwrap()
                };
                let block_offset = block_pos % paged_attn_metadata.block_size;
                let slot = block_number
                    .checked_mul(paged_attn_metadata.block_size)
                    .and_then(|v| v.checked_add(block_offset))
                    .expect("Slot calculation overflowed");
                let slot = slot
                    .try_into()
                    .expect("Slot value too large for target integer type");
                slot_mappings.push(vec![slot]);

                if let Some(sliding_window) = paged_attn_metadata.sliding_window {
                    // Keep only the trailing blocks that cover the window.
                    let sliding_window_blocks = sliding_window / paged_attn_metadata.block_size;
                    let slide_idx = if table.len() > sliding_window_blocks {
                        table.len() - sliding_window_blocks
                    } else {
                        0
                    };
                    block_tables.push(table.get(slide_idx..).unwrap().to_vec());
                } else {
                    block_tables.push(table);
                }

                // Effective context length is capped by the sliding window.
                let paged_attn_context_len =
                    if let Some(sliding_window) = paged_attn_metadata.sliding_window {
                        seq.len().min(sliding_window)
                    } else {
                        seq.len()
                    };
                paged_attn_context_lens.push(paged_attn_context_len);
            }
        }

        let (max_q, max_k, seqlens_q_map, seqlens_k_map) = if flash_attn {
            let max_q = *seqlens_q
                .iter()
                .max()
                .expect("seqlens_q should not be empty when flash_attn is enabled");
            let max_k = *seqlens_k
                .iter()
                .max()
                .expect("seqlens_k should not be empty when flash_attn is enabled");
            // Prefix-sum via f32 round-trip, then back to u32 for the kernels.
            let seqlens_q = Tensor::new(seqlens_q, device)?
                .to_dtype(DType::F32)?
                .cumsum(0)?
                .to_dtype(DType::U32)?;
            let seqlens_k = Tensor::new(seqlens_k, device)?
                .to_dtype(DType::F32)?
                .cumsum(0)?
                .to_dtype(DType::U32)?;

            let mut seqlens_q_map = HashMap::new();
            let mut seqlens_k_map = HashMap::new();

            // Replicate the cumulative-seqlen tensors onto every unique device.
            let devices = mapper.unwrap().get_unique_devices();
            for device in devices {
                seqlens_q_map.insert(device.location(), seqlens_q.to_device(&device)?);
                seqlens_k_map.insert(device.location(), seqlens_k.to_device(&device)?);
            }
            (max_q, max_k, seqlens_q_map, seqlens_k_map)
        } else {
            (0, 0, HashMap::new(), HashMap::new())
        };

        let paged_attn_meta = if paged_attn_metadata.is_some() {
            // Exactly one slot per sequence at decode time, so pad length is 1.
            let slot_mappings = _make_tensor_with_pad(slot_mappings, 1, _PAD_SLOT_ID, device)?;

            let max_block_table_len = block_tables
                .iter()
                .map(|x| x.len())
                .max()
                .expect("block_tables should not be empty when paged attention is enabled");

            let block_tables = _make_tensor_with_pad(
                block_tables
                    .iter()
                    .map(|x| x.iter().map(|x| *x as u32).collect::<Vec<_>>())
                    .collect::<Vec<_>>(),
                max_block_table_len,
                0,
                device,
            )?;
            // One padded row per sequence.
            let block_tables = block_tables.reshape(((), max_block_table_len))?;

            let max_context_len = paged_attn_context_lens.iter().max().unwrap();

            let context_lens = Tensor::from_vec(
                paged_attn_context_lens
                    .iter()
                    .map(|x| *x as u32)
                    .collect::<Vec<_>>(),
                (paged_attn_context_lens.len(),),
                device,
            )?;

            // Replicate all paged-attention tensors onto every unique device.
            let devices = mapper.unwrap().get_unique_devices();
            let mut slot_mappings_map = HashMap::new();
            let mut block_tables_map = HashMap::new();
            let mut context_lens_map = HashMap::new();

            for device in devices {
                slot_mappings_map
                    .insert(device.location(), slot_mappings.clone().to_device(&device)?);
                block_tables_map
                    .insert(device.location(), block_tables.clone().to_device(&device)?);
                context_lens_map
                    .insert(device.location(), context_lens.clone().to_device(&device)?);
            }

            Some(PagedAttentionInputMetadata {
                slot_mappings: slot_mappings_map,
                block_tables: Some(block_tables_map),
                context_lens: Some(context_lens_map),
                max_context_len: Some(*max_context_len),
                is_first_prompt_chunk: false,
            })
        } else {
            None
        };

        Ok(InputMetadata {
            input: Tensor::cat(&seqs_tensors, 0).unwrap(),
            positions: seqlen_offsets,
            context_lens,
            position_ids,
            paged_attn_meta,
            flash_meta: FlashParams {
                max_k,
                max_q,
                cumulative_seqlens_k: seqlens_k_map,
                cumulative_seqlens_q: seqlens_q_map,
                causal: true,
            },
        })
    }
555
556 #[allow(clippy::too_many_arguments)]
557 pub(crate) fn get_prompt_input<T: WithDType + std::fmt::Debug>(
558 toks: Vec<&[T]>,
559 input_seqs: &[&mut Sequence],
560 device: &Device,
561 last_n_context_len: Option<(usize, usize)>,
562 return_raw_logits: bool,
563 paged_attn_metadata: Option<&mut PagedAttentionMeta>,
564 mapper: Option<&dyn DeviceMapper>,
565 ) -> Result<InnerInputProcessorOutput> {
566 let offset = input_seqs[0].token_offset();
567 make_prompt_chunk(
568 offset,
569 toks,
570 &input_seqs.iter().map(|s| *s.id()).collect::<Vec<_>>(),
571 device,
572 last_n_context_len,
573 return_raw_logits,
574 paged_attn_metadata,
575 mapper,
576 )
577 .map(|inputs| InnerInputProcessorOutput {
578 inputs,
579 seq_indices: (0..input_seqs.len()).collect(),
580 })
581 }
582
583 #[allow(clippy::too_many_arguments)]
584 pub(crate) fn get_completion_input<T: WithDType + std::fmt::Debug>(
585 toks: Vec<&[T]>,
586 input_seqs: &[&mut Sequence],
587 device: &Device,
588 no_kv_cache: bool,
589 last_n_context_len: Option<(usize, usize)>,
590 return_raw_logits: bool,
591 paged_attn_metadata: Option<&mut PagedAttentionMeta>,
592 mapper: Option<&dyn DeviceMapper>,
593 ) -> Result<InnerInputProcessorOutput> {
594 if no_kv_cache {
595 return get_prompt_input(
596 toks,
597 input_seqs,
598 device,
599 last_n_context_len,
600 return_raw_logits,
601 paged_attn_metadata,
602 mapper,
603 );
604 }
605
606 make_completion_chunk(toks, input_seqs, device, paged_attn_metadata, mapper).map(|inputs| {
607 InnerInputProcessorOutput {
608 inputs,
609 seq_indices: (0..input_seqs.len()).collect(),
610 }
611 })
612 }
613
    /// The type-erased payload placed into [`super::InputProcessorOutput`] for
    /// text models. The `*_full` fields are populated only for X-LoRA, which
    /// needs a second, full-context pass.
    #[derive(Clone)]
    pub struct ModelInputs {
        // Token ids for this pass.
        pub input_ids: Tensor,
        // Full-context token ids (X-LoRA only).
        pub input_ids_full: Option<Tensor>,
        // Per-sequence position offsets for this pass.
        pub seqlen_offsets: Vec<usize>,
        // Full-context position offsets (X-LoRA only).
        pub seqlen_offsets_full: Option<Vec<usize>>,
        // Per-sequence `(start, len)` logit windows.
        pub context_lens: Vec<(usize, usize)>,
        // Per-sequence absolute position ids.
        pub position_ids: Vec<usize>,
        // Paged-attention inputs, when enabled.
        pub paged_attn_meta: Option<PagedAttentionInputMetadata>,
        // Flash-attention metadata for this pass.
        pub flash_meta: FlashParams,
        // Flash-attention metadata for the full-context pass (X-LoRA only).
        pub flash_meta_full: Option<FlashParams>,
    }
626
    /// [`InputsProcessor`] implementation for plain text models.
    pub struct TextInputsProcessor;
628
    impl InputsProcessor for TextInputsProcessor {
        /// Build [`ModelInputs`] for a text model. Four cases:
        /// * X-LoRA decode: both full-context inputs (for the scaling pass) and
        ///   regular completion inputs are produced.
        /// * X-LoRA prefill: the single prompt pass is reused for both.
        /// * Plain prefill and plain decode: one set of inputs.
        fn process_inputs(
            &self,
            _: Option<Arc<Tokenizer>>,
            input_seqs: &mut [&mut Sequence],
            is_prompt: bool,
            is_xlora: bool,
            device: &Device,
            no_kv_cache: bool,
            last_n_context_len: Option<(usize, usize)>,
            return_raw_logits: bool,
            _: Option<Arc<dyn Any>>,
            mut paged_attn_metadata: Option<PagedAttentionMeta>,
            mapper: Option<&dyn DeviceMapper>,
        ) -> Result<InputProcessorOutput> {
            if is_xlora && !is_prompt {
                // X-LoRA decode: compute full-context inputs AND last-token inputs.
                let prompt = get_prompt_input(
                    input_seqs
                        .iter()
                        .map(|seq| seq.get_toks())
                        .collect::<Vec<_>>(),
                    input_seqs,
                    device,
                    last_n_context_len,
                    return_raw_logits,
                    paged_attn_metadata.as_mut(),
                    mapper,
                )?;
                let completion = get_completion_input(
                    input_seqs
                        .iter()
                        .map(|seq| seq.get_toks())
                        .collect::<Vec<_>>(),
                    input_seqs,
                    device,
                    no_kv_cache,
                    last_n_context_len,
                    return_raw_logits,
                    paged_attn_metadata.as_mut(),
                    mapper,
                )?;
                // `position_ids` and `seq_indices` come from the full (prompt) pass;
                // `context_lens` and paged-attention metadata from the completion pass.
                let InnerInputProcessorOutput {
                    inputs:
                        InputMetadata {
                            input: input_ids_full,
                            positions: seqlen_offsets_full,
                            context_lens: _,
                            position_ids,
                            paged_attn_meta: _,
                            flash_meta: flash_meta_full,
                        },
                    seq_indices,
                } = prompt;
                let InnerInputProcessorOutput {
                    inputs:
                        InputMetadata {
                            input: input_ids,
                            positions: seqlen_offsets,
                            context_lens,
                            position_ids: _,
                            paged_attn_meta,
                            flash_meta,
                        },
                    seq_indices: _,
                } = completion;
                let inputs: Box<dyn Any> = Box::new(ModelInputs {
                    input_ids,
                    input_ids_full: Some(input_ids_full),
                    seqlen_offsets,
                    seqlen_offsets_full: Some(seqlen_offsets_full),
                    context_lens,
                    position_ids,
                    paged_attn_meta,
                    flash_meta,
                    flash_meta_full: Some(flash_meta_full),
                });
                Ok(InputProcessorOutput {
                    inputs,
                    seq_indices,
                })
            } else if is_xlora && is_prompt {
                // X-LoRA prefill: one prompt pass serves as both regular and full inputs.
                let metadata = get_prompt_input(
                    input_seqs
                        .iter()
                        .map(|seq| seq.get_toks())
                        .collect::<Vec<_>>(),
                    input_seqs,
                    device,
                    last_n_context_len,
                    return_raw_logits,
                    paged_attn_metadata.as_mut(),
                    mapper,
                )?;
                let InnerInputProcessorOutput {
                    inputs:
                        InputMetadata {
                            input: input_ids,
                            positions: seqlen_offsets,
                            context_lens,
                            position_ids,
                            paged_attn_meta,
                            flash_meta,
                        },
                    seq_indices,
                } = metadata;
                let inputs: Box<dyn Any> = Box::new(ModelInputs {
                    input_ids: input_ids.clone(),
                    input_ids_full: Some(input_ids),
                    seqlen_offsets: seqlen_offsets.clone(),
                    seqlen_offsets_full: Some(seqlen_offsets),
                    context_lens,
                    position_ids,
                    paged_attn_meta,
                    flash_meta: flash_meta.clone(),
                    flash_meta_full: Some(flash_meta),
                });
                Ok(InputProcessorOutput {
                    inputs,
                    seq_indices,
                })
            } else if is_prompt {
                // Plain prefill.
                let metadata = get_prompt_input(
                    input_seqs
                        .iter()
                        .map(|seq| seq.get_toks())
                        .collect::<Vec<_>>(),
                    input_seqs,
                    device,
                    last_n_context_len,
                    return_raw_logits,
                    paged_attn_metadata.as_mut(),
                    mapper,
                )?;
                let InnerInputProcessorOutput {
                    inputs:
                        InputMetadata {
                            input: input_ids,
                            positions: seqlen_offsets,
                            context_lens,
                            position_ids,
                            paged_attn_meta,
                            flash_meta,
                        },
                    seq_indices,
                } = metadata;
                let inputs: Box<dyn Any> = Box::new(ModelInputs {
                    input_ids,
                    input_ids_full: None,
                    seqlen_offsets,
                    seqlen_offsets_full: None,
                    context_lens,
                    position_ids,
                    paged_attn_meta,
                    flash_meta,
                    flash_meta_full: None,
                });
                Ok(InputProcessorOutput {
                    inputs,
                    seq_indices,
                })
            } else {
                // Plain decode.
                let metadata = get_completion_input(
                    input_seqs
                        .iter()
                        .map(|seq| seq.get_toks())
                        .collect::<Vec<_>>(),
                    input_seqs,
                    device,
                    no_kv_cache,
                    last_n_context_len,
                    return_raw_logits,
                    paged_attn_metadata.as_mut(),
                    mapper,
                )?;
                let InnerInputProcessorOutput {
                    inputs:
                        InputMetadata {
                            input: input_ids,
                            positions: seqlen_offsets,
                            context_lens,
                            position_ids,
                            paged_attn_meta,
                            flash_meta,
                        },
                    seq_indices,
                } = metadata;
                let inputs: Box<dyn Any> = Box::new(ModelInputs {
                    input_ids,
                    input_ids_full: None,
                    seqlen_offsets,
                    seqlen_offsets_full: None,
                    context_lens,
                    position_ids,
                    paged_attn_meta,
                    flash_meta,
                    flash_meta_full: None,
                });
                Ok(InputProcessorOutput {
                    inputs,
                    seq_indices,
                })
            }
        }

        fn get_type(&self) -> InputsProcessorType {
            InputsProcessorType::Text
        }
    }
837}