#![allow(clippy::cast_possible_truncation)]

use std::{any::Any, sync::Arc};

use anyhow::Result;
use candle_core::Device;
use text_models_inputs_processor::PagedAttentionMeta;
use tokenizers::Tokenizer;

use crate::{device_map::DeviceMapper, sequence::Sequence};

#[derive(PartialEq)]
pub enum InputsProcessorType {
    Text,
    Vision,
    Embedding,
}

pub struct InputProcessorOutput {
    pub inputs: Box<dyn Any>,
    pub seq_indices: Vec<usize>,
}
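
/// Converts a batch of sequences into model-ready inputs.
///
/// A minimal calling sketch (illustrative only; `seqs` and `device` are
/// assumed to exist). The type-erased output is downcast to the concrete type
/// the model expects, which for text models is
/// `text_models_inputs_processor::ModelInputs`:
///
/// ```ignore
/// let processor = text_models_inputs_processor::TextInputsProcessor;
/// let out = processor.process_inputs(
///     None,      // tokenizer: unused for text models
///     &mut seqs, // &mut [&mut Sequence]
///     true,      // is_prompt
///     false,     // is_xlora
///     &device,
///     false,     // no_kv_cache
///     None,      // last_n_context_len
///     false,     // return_raw_logits
///     None,      // other_config
///     None,      // paged_attn_metadata
///     None,      // mapper
/// )?;
/// let inputs = out
///     .inputs
///     .downcast_ref::<text_models_inputs_processor::ModelInputs>()
///     .expect("text processor returns ModelInputs");
/// ```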
pub trait InputsProcessor {
    /// Convert the input sequences into inputs ready for the model's forward
    /// pass. The returned [`InputProcessorOutput::inputs`] must downcast to
    /// the concrete input type the target model expects.
    #[allow(clippy::too_many_arguments)]
    fn process_inputs(
        &self,
        tokenizer: Option<Arc<Tokenizer>>,
        input_seqs: &mut [&mut Sequence],
        is_prompt: bool,
        is_xlora: bool,
        device: &Device,
        no_kv_cache: bool,
        last_n_context_len: Option<(usize, usize)>,
        return_raw_logits: bool,
        other_config: Option<Arc<dyn Any>>,
        paged_attn_metadata: Option<PagedAttentionMeta>,
        mapper: Option<&dyn DeviceMapper>,
    ) -> Result<InputProcessorOutput>;

    fn get_type(&self) -> InputsProcessorType;
}

pub mod text_models_inputs_processor {
    use std::{any::Any, collections::HashMap, fmt::Debug, sync::Arc};

    use anyhow::Result;
    use candle_core::{DType, Device, DeviceLocation, Tensor, WithDType};
    use tokenizers::Tokenizer;

    use crate::{
        device_map::DeviceMapper,
        get_mut_arcmutex,
        paged_attention::{BlockEngine, _PAD_SLOT_ID},
        sequence::Sequence,
    };

    use super::{InputProcessorOutput, InputsProcessor, InputsProcessorType};
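
    /// Pads each row to `max_len` with `pad`, then concatenates the rows into
    /// a single flat tensor (callers reshape as needed). A sketch with
    /// `max_len = 3` and `pad = 0`:
    ///
    /// ```ignore
    /// let t = _make_tensor_with_pad(vec![vec![1u32, 2], vec![3]], 3, 0, &Device::Cpu)?;
    /// assert_eq!(t.to_vec1::<u32>()?, [1, 2, 0, 3, 0, 0]);
    /// ```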
    fn _make_tensor_with_pad<D: WithDType>(
        x: Vec<Vec<D>>,
        max_len: usize,
        pad: D,
        device: &Device,
    ) -> Result<Tensor> {
        let mut padded_x = Vec::new();
        for mut x_i in x {
            assert!(x_i.len() <= max_len);
            x_i.extend([pad].repeat(max_len - x_i.len()));
            let shape = (x_i.len(),);
            padded_x.push(Tensor::from_vec(x_i, shape, device)?);
        }
        Tensor::cat(&padded_x[..], 0).map_err(anyhow::Error::msg)
    }

    pub struct PagedAttentionMeta {
        pub sliding_window: Option<usize>,
        pub block_size: usize,
        pub block_engine: Arc<tokio::sync::Mutex<BlockEngine>>,
    }

    #[derive(Clone, Debug)]
    #[allow(dead_code)]
    pub struct PagedAttentionInputMetadata {
        pub block_tables: Option<HashMap<DeviceLocation, Tensor>>,
        pub context_lens: Option<HashMap<DeviceLocation, Tensor>>,
        pub slot_mappings: HashMap<DeviceLocation, Tensor>,
        pub max_context_len: Option<usize>,
        pub is_first_prompt_chunk: bool,
    }

    impl PagedAttentionInputMetadata {
        /// Create a placeholder metadata value for code paths that require
        /// paged-attention metadata but will not actually use it.
        pub fn dummy(dev: &Device) -> candle_core::Result<Self> {
            Ok(PagedAttentionInputMetadata {
                block_tables: None,
                context_lens: None,
                max_context_len: None,
                slot_mappings: HashMap::from([(dev.location(), Tensor::new(&[0f32], dev)?)]),
                is_first_prompt_chunk: true,
            })
        }
    }
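
    /// Parameters for variable-length flash attention: the maximum query/key
    /// lengths in the batch, plus per-device cumulative sequence-length
    /// tensors. For query lengths `[3, 5]` the cumulative tensor is
    /// `[0, 3, 8]`, formed the same way as in the builders below (a sketch):
    ///
    /// ```ignore
    /// let lens = Tensor::new(vec![0u32, 3, 5], &Device::Cpu)?;
    /// let cu = lens.to_dtype(DType::F32)?.cumsum(0)?.to_dtype(DType::U32)?;
    /// assert_eq!(cu.to_vec1::<u32>()?, [0, 3, 8]);
    /// ```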
    #[derive(Clone, Debug)]
    pub struct FlashParams {
        pub max_q: u32,
        pub max_k: u32,
        pub cumulative_seqlens_q: HashMap<DeviceLocation, Tensor>,
        pub cumulative_seqlens_k: HashMap<DeviceLocation, Tensor>,
        pub causal: bool,
    }
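
    /// Model-ready inputs plus the bookkeeping needed to slice logits.
    ///
    /// `context_lens` holds `(start, len)` pairs describing which positions of
    /// each sequence's logits to keep. In the default case only the last
    /// position survives; for a 7-token prompt the entry is `(6, 1)`, used
    /// roughly like this sketch:
    ///
    /// ```ignore
    /// let (start, len) = (6, 1);
    /// let last_logits = seq_logits.narrow(1, start, len)?;
    /// ```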
    pub struct InputMetadata {
        pub input: Tensor,
        pub positions: Vec<usize>,
        pub context_lens: Vec<(usize, usize)>,
        pub position_ids: Vec<usize>,
        pub paged_attn_meta: Option<PagedAttentionInputMetadata>,
        pub flash_meta: FlashParams,
    }

    pub struct InnerInputProcessorOutput {
        pub inputs: InputMetadata,
        pub seq_indices: Vec<usize>,
    }
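
    /// Builds the inputs for one prompt chunk: every sequence is right-padded
    /// to the longest prompt in the batch, and, when paged attention is in
    /// use, a slot mapping is derived from each sequence's block table.
    ///
    /// The slot arithmetic, as a sketch: with `block_size = 16` and a block
    /// table `[7, 9]`, token position `i = 18` falls in logical block
    /// `18 / 16 = 1`, i.e. physical block `9`, at offset `18 % 16 = 2`, giving
    /// slot `9 * 16 + 2 = 146`:
    ///
    /// ```ignore
    /// let (block_size, table, i) = (16usize, vec![7usize, 9], 18usize);
    /// let slot = table[i / block_size] * block_size + i % block_size;
    /// assert_eq!(slot, 146);
    /// ```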
    #[allow(clippy::too_many_arguments)]
    pub fn make_prompt_chunk<T: WithDType + Debug>(
        chunk_offset_toks: usize,
        toks: Vec<&[T]>,
        seq_ids: &[usize],
        device: &Device,
        last_n_context_len: Option<(usize, usize)>,
        return_raw_logits: bool,
        mut paged_attn_metadata: Option<&mut PagedAttentionMeta>,
        mapper: Option<&dyn DeviceMapper>,
    ) -> Result<InputMetadata> {
        let max_len = toks
            .iter()
            .map(|seq| seq.len())
            .max()
            .expect("No sequences");
        let padding_tok = T::zero();
        let mut seqs_tensors = Vec::new();
        let mut seqlen_offsets = Vec::new();
        let mut context_lens = Vec::new();
        let mut position_ids = Vec::new();
        let mut slot_mappings = Vec::new();
        let mut block_tables = Vec::new();
        let mut paged_attn_context_lens = Vec::new();
        let flash_attn = crate::using_flash_attn();
        let mut seqlens_q = if flash_attn { vec![0] } else { Vec::new() };
        let mut seqlens_k = if flash_attn { vec![0] } else { Vec::new() };
        for (seq_id, ctxt) in seq_ids.iter().zip(toks) {
            let prompt_len = ctxt.len();
            let offset = last_n_context_len.unwrap_or_default();
            seqlen_offsets.push(offset.1 + chunk_offset_toks);

            position_ids.push(ctxt.len() + chunk_offset_toks);
            let mut ctxt = ctxt.to_vec();
            // Right-pad this sequence to the longest prompt in the batch.
            ctxt.extend(std::iter::repeat_n(
                padding_tok,
                max_len.saturating_sub(ctxt.len()),
            ));

            if return_raw_logits {
                if last_n_context_len.is_some() {
                    anyhow::bail!(
                        "`return_raw_logits` is incompatible with `last_n_context_len`"
                    );
                }

                context_lens.push((0, ctxt.len()));
            } else {
                context_lens.push((
                    ctxt.len()
                        .saturating_sub(last_n_context_len.map(|(a, _)| a).unwrap_or(1)),
                    last_n_context_len.map(|(a, _)| a).unwrap_or(1),
                ));
            }

            if flash_attn {
                seqlens_q.push(ctxt.len() as u32);
                seqlens_k.push((ctxt.len() + chunk_offset_toks) as u32);
            }

            seqs_tensors.push(Tensor::new(ctxt, device).unwrap().unsqueeze(0).unwrap());

            if let Some(paged_attn_metadata) = &mut paged_attn_metadata {
                let block_engine = get_mut_arcmutex!(paged_attn_metadata.block_engine);
                let table = block_engine.block_tables.get(seq_id);

                if table.is_none() {
                    // There is no block table during profiling: pad the whole
                    // chunk and move on to the next sequence.
                    slot_mappings.push([_PAD_SLOT_ID].repeat(prompt_len));
                    continue;
                }
                let table = table
                    .unwrap()
                    .iter()
                    .map(|block| block.deref_mut().block_id)
                    .collect::<Vec<_>>();

                let start_idx = if let Some(sliding_window) = paged_attn_metadata.sliding_window {
                    prompt_len.saturating_sub(sliding_window)
                } else {
                    0
                };

                let mut slot_mapping = Vec::new();
                let mut ctxt_len = Vec::new();
                for i in chunk_offset_toks..prompt_len + chunk_offset_toks {
                    if i < start_idx {
                        // Pad [0, start_idx) with _PAD_SLOT_ID.
                        slot_mapping.push(_PAD_SLOT_ID);
                    }
                    ctxt_len.push(i);

                    let block_number = if i / paged_attn_metadata.block_size >= table.len() {
                        panic!(
                            "Block table is too small (prompt)! i={} block_size={} table_len={}",
                            i,
                            paged_attn_metadata.block_size,
                            table.len()
                        );
                    } else {
                        table.get(i / paged_attn_metadata.block_size).unwrap()
                    };
                    let block_offset = i % paged_attn_metadata.block_size;
                    let slot = block_number * paged_attn_metadata.block_size + block_offset;
                    slot_mapping.push(slot.try_into().unwrap());
                    block_tables.push(table.clone());
                }
                slot_mappings.push(slot_mapping);
                paged_attn_context_lens.push(ctxt_len);
            }
        }

        let (max_q, max_k, seqlens_q_map, seqlens_k_map) = if flash_attn {
            let max_q = *seqlens_q.iter().max().unwrap();
            let max_k = *seqlens_k.iter().max().unwrap();
            // Candle has no integer cumsum, so round-trip through f32.
            let seqlens_q = Tensor::new(seqlens_q, device)?
                .to_dtype(DType::F32)?
                .cumsum(0)?
                .to_dtype(DType::U32)?;
            let seqlens_k = Tensor::new(seqlens_k, device)?
                .to_dtype(DType::F32)?
                .cumsum(0)?
                .to_dtype(DType::U32)?;

            let mut seqlens_q_map = HashMap::new();
            let mut seqlens_k_map = HashMap::new();

            let devices = mapper.unwrap().get_unique_devices();
            for device in devices {
                seqlens_q_map.insert(device.location(), seqlens_q.to_device(&device)?);
                seqlens_k_map.insert(device.location(), seqlens_k.to_device(&device)?);
            }
            (max_q, max_k, seqlens_q_map, seqlens_k_map)
        } else {
            (0, 0, HashMap::new(), HashMap::new())
        };

        let input = Tensor::cat(&seqs_tensors, 0).unwrap();

        let paged_attn_meta = if paged_attn_metadata.is_some() {
            let max_slot_mapping_len = slot_mappings.iter().map(|x| x.len()).max().unwrap();
            let slot_mappings =
                _make_tensor_with_pad(slot_mappings, max_slot_mapping_len, _PAD_SLOT_ID, device)?;

            let max_block_table_len = block_tables.iter().map(|x| x.len()).max().unwrap();
            let block_tables = _make_tensor_with_pad(
                block_tables
                    .iter()
                    .map(|x| x.iter().map(|x| *x as u32).collect::<Vec<_>>())
                    .collect::<Vec<_>>(),
                max_block_table_len,
                0,
                device,
            )?;
            let block_tables = block_tables.reshape(((), max_block_table_len))?;

            let max_context_len = paged_attn_context_lens
                .iter()
                .map(|x| x.len())
                .max()
                .unwrap();

            let context_lens = _make_tensor_with_pad(
                paged_attn_context_lens
                    .iter()
                    .map(|x| x.iter().map(|x| *x as u32).collect::<Vec<_>>())
                    .collect::<Vec<_>>(),
                max_context_len,
                0,
                device,
            )?
            .reshape(((),))?;

            // For device mapping, make a copy of each tensor for each device.
            let devices = mapper.unwrap().get_unique_devices();
            let mut slot_mappings_map = HashMap::new();
            let mut block_tables_map = HashMap::new();
            let mut context_lens_map = HashMap::new();

            for device in devices {
                slot_mappings_map
                    .insert(device.location(), slot_mappings.clone().to_device(&device)?);
                block_tables_map
                    .insert(device.location(), block_tables.clone().to_device(&device)?);
                context_lens_map
                    .insert(device.location(), context_lens.clone().to_device(&device)?);
            }

            Some(PagedAttentionInputMetadata {
                slot_mappings: slot_mappings_map,
                block_tables: Some(block_tables_map),
                context_lens: Some(context_lens_map),
                max_context_len: Some(max_context_len),
                is_first_prompt_chunk: chunk_offset_toks == 0,
            })
        } else {
            None
        };

        Ok(InputMetadata {
            input,
            positions: seqlen_offsets,
            context_lens,
            position_ids,
            paged_attn_meta,
            flash_meta: FlashParams {
                max_k,
                max_q,
                cumulative_seqlens_k: seqlens_k_map,
                cumulative_seqlens_q: seqlens_q_map,
                causal: true,
            },
        })
    }
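
    /// Builds decode-step inputs: each sequence contributes only its newest
    /// token, so each slot mapping holds exactly one slot. The slot arithmetic
    /// matches the prompt path, applied to the single position `start_pos`
    /// adjusted by the sequence's token offset (a sketch, names assumed):
    ///
    /// ```ignore
    /// let block_pos = start_pos - seq.token_offset();
    /// let slot = table[block_pos / block_size] * block_size + block_pos % block_size;
    /// ```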
    fn make_completion_chunk<T: WithDType>(
        toks: Vec<&[T]>,
        input_seqs: &[&mut Sequence],
        device: &Device,
        mut paged_attn_metadata: Option<&mut PagedAttentionMeta>,
        mapper: Option<&dyn DeviceMapper>,
    ) -> Result<InputMetadata> {
        let flash_attn = crate::using_flash_attn();
        let mut seqs_tensors = Vec::new();
        let mut seqlen_offsets = Vec::new();
        let mut context_lens = Vec::new();
        let mut position_ids = Vec::new();

        let mut slot_mappings = Vec::new();
        let mut block_tables = Vec::new();
        let mut paged_attn_context_lens = Vec::new();
        let mut seqlens_q = if flash_attn { vec![0] } else { Vec::new() };
        let mut seqlens_k = if flash_attn { vec![0] } else { Vec::new() };
        for (seq, ctxt) in input_seqs.iter().zip(toks) {
            // Decode step: only the latest token is fed to the model.
            let start_pos = ctxt.len().saturating_sub(1);
            let ctxt = ctxt[start_pos..].to_vec();
            seqlen_offsets.push(start_pos);
            context_lens.push((0, 1));
            position_ids.push(seq.len());

            if flash_attn {
                seqlens_q.push(ctxt.len() as u32);
                seqlens_k.push((ctxt.len() + start_pos) as u32);
            }

            seqs_tensors.push(Tensor::new(ctxt, device).unwrap().unsqueeze(0).unwrap());

            if let Some(paged_attn_metadata) = &mut paged_attn_metadata {
                let block_engine = get_mut_arcmutex!(paged_attn_metadata.block_engine);
                let table = block_engine.block_tables.get(seq.id()).unwrap();

                let table = table
                    .iter()
                    .map(|block| block.deref_mut().block_id)
                    .collect::<Vec<_>>();

                let block_pos = start_pos - seq.token_offset();
                let block_number = if block_pos / paged_attn_metadata.block_size >= table.len() {
                    panic!(
                        "Block table is too small (completion)! start_pos={} block_size={} table_len={}",
                        block_pos,
                        paged_attn_metadata.block_size,
                        table.len()
                    );
                } else {
                    table
                        .get(block_pos / paged_attn_metadata.block_size)
                        .unwrap()
                };
                let block_offset = block_pos % paged_attn_metadata.block_size;
                let slot = block_number * paged_attn_metadata.block_size + block_offset;
                let slot = slot.try_into().unwrap();
                slot_mappings.push(vec![slot]);

                if let Some(sliding_window) = paged_attn_metadata.sliding_window {
                    // Keep only the newest blocks that fall inside the window.
                    let sliding_window_blocks = sliding_window / paged_attn_metadata.block_size;
                    let slide_idx = if table.len() > sliding_window_blocks {
                        table.len() - sliding_window_blocks
                    } else {
                        0
                    };
                    block_tables.push(table.get(slide_idx..).unwrap().to_vec());
                } else {
                    block_tables.push(table);
                }

                let paged_attn_context_len =
                    if let Some(sliding_window) = paged_attn_metadata.sliding_window {
                        seq.len().min(sliding_window)
                    } else {
                        seq.len()
                    };
                paged_attn_context_lens.push(paged_attn_context_len);
            }
        }

        let (max_q, max_k, seqlens_q_map, seqlens_k_map) = if flash_attn {
            let max_q = *seqlens_q.iter().max().unwrap();
            let max_k = *seqlens_k.iter().max().unwrap();
            // Candle has no integer cumsum, so round-trip through f32.
            let seqlens_q = Tensor::new(seqlens_q, device)?
                .to_dtype(DType::F32)?
                .cumsum(0)?
                .to_dtype(DType::U32)?;
            let seqlens_k = Tensor::new(seqlens_k, device)?
                .to_dtype(DType::F32)?
                .cumsum(0)?
                .to_dtype(DType::U32)?;

            let mut seqlens_q_map = HashMap::new();
            let mut seqlens_k_map = HashMap::new();

            let devices = mapper.unwrap().get_unique_devices();
            for device in devices {
                seqlens_q_map.insert(device.location(), seqlens_q.to_device(&device)?);
                seqlens_k_map.insert(device.location(), seqlens_k.to_device(&device)?);
            }
            (max_q, max_k, seqlens_q_map, seqlens_k_map)
        } else {
            (0, 0, HashMap::new(), HashMap::new())
        };

        let paged_attn_meta = if paged_attn_metadata.is_some() {
            let slot_mappings = _make_tensor_with_pad(slot_mappings, 1, _PAD_SLOT_ID, device)?;

            let max_block_table_len = block_tables.iter().map(|x| x.len()).max().unwrap();

            let block_tables = _make_tensor_with_pad(
                block_tables
                    .iter()
                    .map(|x| x.iter().map(|x| *x as u32).collect::<Vec<_>>())
                    .collect::<Vec<_>>(),
                max_block_table_len,
                0,
                device,
            )?;
            let block_tables = block_tables.reshape(((), max_block_table_len))?;

            let max_context_len = paged_attn_context_lens.iter().max().unwrap();

            let context_lens = Tensor::from_vec(
                paged_attn_context_lens
                    .iter()
                    .map(|x| *x as u32)
                    .collect::<Vec<_>>(),
                (paged_attn_context_lens.len(),),
                device,
            )?;

            // For device mapping, make a copy of each tensor for each device.
            let devices = mapper.unwrap().get_unique_devices();
            let mut slot_mappings_map = HashMap::new();
            let mut block_tables_map = HashMap::new();
            let mut context_lens_map = HashMap::new();

            for device in devices {
                slot_mappings_map
                    .insert(device.location(), slot_mappings.clone().to_device(&device)?);
                block_tables_map
                    .insert(device.location(), block_tables.clone().to_device(&device)?);
                context_lens_map
                    .insert(device.location(), context_lens.clone().to_device(&device)?);
            }

            Some(PagedAttentionInputMetadata {
                slot_mappings: slot_mappings_map,
                block_tables: Some(block_tables_map),
                context_lens: Some(context_lens_map),
                max_context_len: Some(*max_context_len),
                is_first_prompt_chunk: false,
            })
        } else {
            None
        };

        Ok(InputMetadata {
            input: Tensor::cat(&seqs_tensors, 0).unwrap(),
            positions: seqlen_offsets,
            context_lens,
            position_ids,
            paged_attn_meta,
            flash_meta: FlashParams {
                max_k,
                max_q,
                cumulative_seqlens_k: seqlens_k_map,
                cumulative_seqlens_q: seqlens_q_map,
                causal: true,
            },
        })
    }
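
    /// Builds prompt-phase inputs for a batch. The chunk offset is taken from
    /// the first sequence (the whole batch is assumed to share it), and
    /// `seq_indices` is the identity mapping.
    ///
    /// A sketch for two sequences of 5 and 3 prompt tokens (names assumed):
    ///
    /// ```ignore
    /// let out = get_prompt_input(toks, &seqs, &device, None, false, None, None)?;
    /// assert_eq!(out.seq_indices, vec![0, 1]);
    /// // out.inputs.input has shape (2, 5); the shorter row is right-padded
    /// // with the zero token.
    /// ```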
    #[allow(clippy::too_many_arguments)]
    pub(crate) fn get_prompt_input<T: WithDType + std::fmt::Debug>(
        toks: Vec<&[T]>,
        input_seqs: &[&mut Sequence],
        device: &Device,
        last_n_context_len: Option<(usize, usize)>,
        return_raw_logits: bool,
        paged_attn_metadata: Option<&mut PagedAttentionMeta>,
        mapper: Option<&dyn DeviceMapper>,
    ) -> Result<InnerInputProcessorOutput> {
        // All sequences in the batch share the first sequence's token offset.
        let offset = input_seqs[0].token_offset();
        make_prompt_chunk(
            offset,
            toks,
            &input_seqs.iter().map(|s| *s.id()).collect::<Vec<_>>(),
            device,
            last_n_context_len,
            return_raw_logits,
            paged_attn_metadata,
            mapper,
        )
        .map(|inputs| InnerInputProcessorOutput {
            inputs,
            seq_indices: (0..input_seqs.len()).collect(),
        })
    }
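
    /// Builds decode-phase inputs (one new token per sequence). When
    /// `no_kv_cache` is set there is no cache to extend, so the entire context
    /// is reprocessed through [`get_prompt_input`] instead.
    ///
    /// A sketch (names assumed):
    ///
    /// ```ignore
    /// // With no_kv_cache == true this is equivalent to the prompt path and
    /// // re-embeds the full token history on every step.
    /// let out = get_completion_input(toks, &seqs, &device, true, None, false, None, None)?;
    /// ```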
    #[allow(clippy::too_many_arguments)]
    pub(crate) fn get_completion_input<T: WithDType + std::fmt::Debug>(
        toks: Vec<&[T]>,
        input_seqs: &[&mut Sequence],
        device: &Device,
        no_kv_cache: bool,
        last_n_context_len: Option<(usize, usize)>,
        return_raw_logits: bool,
        paged_attn_metadata: Option<&mut PagedAttentionMeta>,
        mapper: Option<&dyn DeviceMapper>,
    ) -> Result<InnerInputProcessorOutput> {
        if no_kv_cache {
            // Without a KV cache the full context must be recomputed, which is
            // exactly the prompt path.
            return get_prompt_input(
                toks,
                input_seqs,
                device,
                last_n_context_len,
                return_raw_logits,
                paged_attn_metadata,
                mapper,
            );
        }

        make_completion_chunk(toks, input_seqs, device, paged_attn_metadata, mapper).map(|inputs| {
            InnerInputProcessorOutput {
                inputs,
                seq_indices: (0..input_seqs.len()).collect(),
            }
        })
    }
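
    /// The concrete input payload produced by [`TextInputsProcessor`] and
    /// consumed by text models. The `*_full` fields are only populated for
    /// X-LoRA, whose scaling pass runs over the full (non-cached) inputs.
    ///
    /// A sketch of a plain (non-X-LoRA) decode-step value, assuming
    /// `input_ids` and `flash_meta` were built elsewhere:
    ///
    /// ```ignore
    /// let inputs = ModelInputs {
    ///     input_ids,                 // one token per sequence at decode time
    ///     input_ids_full: None,      // X-LoRA only
    ///     seqlen_offsets: vec![7],   // tokens already in the KV cache
    ///     seqlen_offsets_full: None, // X-LoRA only
    ///     context_lens: vec![(0, 1)],
    ///     position_ids: vec![8],
    ///     paged_attn_meta: None,
    ///     flash_meta,
    ///     flash_meta_full: None,     // X-LoRA only
    /// };
    /// ```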
    #[derive(Clone)]
    pub struct ModelInputs {
        pub input_ids: Tensor,
        pub input_ids_full: Option<Tensor>,
        pub seqlen_offsets: Vec<usize>,
        pub seqlen_offsets_full: Option<Vec<usize>>,
        pub context_lens: Vec<(usize, usize)>,
        pub position_ids: Vec<usize>,
        pub paged_attn_meta: Option<PagedAttentionInputMetadata>,
        pub flash_meta: FlashParams,
        pub flash_meta_full: Option<FlashParams>,
    }

    pub struct TextInputsProcessor;

    impl InputsProcessor for TextInputsProcessor {
        fn process_inputs(
            &self,
            _: Option<Arc<Tokenizer>>,
            input_seqs: &mut [&mut Sequence],
            is_prompt: bool,
            is_xlora: bool,
            device: &Device,
            no_kv_cache: bool,
            last_n_context_len: Option<(usize, usize)>,
            return_raw_logits: bool,
            _: Option<Arc<dyn Any>>,
            mut paged_attn_metadata: Option<PagedAttentionMeta>,
            mapper: Option<&dyn DeviceMapper>,
        ) -> Result<InputProcessorOutput> {
            if is_xlora && !is_prompt {
                // X-LoRA decode: build both the full-prompt inputs (for the
                // scaling pass) and the regular completion inputs.
                let prompt = get_prompt_input(
                    input_seqs
                        .iter()
                        .map(|seq| seq.get_toks())
                        .collect::<Vec<_>>(),
                    input_seqs,
                    device,
                    last_n_context_len,
                    return_raw_logits,
                    paged_attn_metadata.as_mut(),
                    mapper,
                )?;
                let completion = get_completion_input(
                    input_seqs
                        .iter()
                        .map(|seq| seq.get_toks())
                        .collect::<Vec<_>>(),
                    input_seqs,
                    device,
                    no_kv_cache,
                    last_n_context_len,
                    return_raw_logits,
                    paged_attn_metadata.as_mut(),
                    mapper,
                )?;
                let InnerInputProcessorOutput {
                    inputs:
                        InputMetadata {
                            input: input_ids_full,
                            positions: seqlen_offsets_full,
                            context_lens: _,
                            position_ids,
                            paged_attn_meta: _,
                            flash_meta: flash_meta_full,
                        },
                    seq_indices,
                } = prompt;
                let InnerInputProcessorOutput {
                    inputs:
                        InputMetadata {
                            input: input_ids,
                            positions: seqlen_offsets,
                            context_lens,
                            position_ids: _,
                            paged_attn_meta,
                            flash_meta,
                        },
                    seq_indices: _,
                } = completion;
                let inputs: Box<dyn Any> = Box::new(ModelInputs {
                    input_ids,
                    input_ids_full: Some(input_ids_full),
                    seqlen_offsets,
                    seqlen_offsets_full: Some(seqlen_offsets_full),
                    context_lens,
                    position_ids,
                    paged_attn_meta,
                    flash_meta,
                    flash_meta_full: Some(flash_meta_full),
                });
                Ok(InputProcessorOutput {
                    inputs,
                    seq_indices,
                })
            } else if is_xlora && is_prompt {
                // X-LoRA prompt: the prompt inputs double as the full inputs.
                let metadata = get_prompt_input(
                    input_seqs
                        .iter()
                        .map(|seq| seq.get_toks())
                        .collect::<Vec<_>>(),
                    input_seqs,
                    device,
                    last_n_context_len,
                    return_raw_logits,
                    paged_attn_metadata.as_mut(),
                    mapper,
                )?;
                let InnerInputProcessorOutput {
                    inputs:
                        InputMetadata {
                            input: input_ids,
                            positions: seqlen_offsets,
                            context_lens,
                            position_ids,
                            paged_attn_meta,
                            flash_meta,
                        },
                    seq_indices,
                } = metadata;
                let inputs: Box<dyn Any> = Box::new(ModelInputs {
                    input_ids: input_ids.clone(),
                    input_ids_full: Some(input_ids),
                    seqlen_offsets: seqlen_offsets.clone(),
                    seqlen_offsets_full: Some(seqlen_offsets),
                    context_lens,
                    position_ids,
                    paged_attn_meta,
                    flash_meta: flash_meta.clone(),
                    flash_meta_full: Some(flash_meta),
                });
                Ok(InputProcessorOutput {
                    inputs,
                    seq_indices,
                })
            } else if is_prompt {
                let metadata = get_prompt_input(
                    input_seqs
                        .iter()
                        .map(|seq| seq.get_toks())
                        .collect::<Vec<_>>(),
                    input_seqs,
                    device,
                    last_n_context_len,
                    return_raw_logits,
                    paged_attn_metadata.as_mut(),
                    mapper,
                )?;
                let InnerInputProcessorOutput {
                    inputs:
                        InputMetadata {
                            input: input_ids,
                            positions: seqlen_offsets,
                            context_lens,
                            position_ids,
                            paged_attn_meta,
                            flash_meta,
                        },
                    seq_indices,
                } = metadata;
                let inputs: Box<dyn Any> = Box::new(ModelInputs {
                    input_ids,
                    input_ids_full: None,
                    seqlen_offsets,
                    seqlen_offsets_full: None,
                    context_lens,
                    position_ids,
                    paged_attn_meta,
                    flash_meta,
                    flash_meta_full: None,
                });
                Ok(InputProcessorOutput {
                    inputs,
                    seq_indices,
                })
            } else {
                let metadata = get_completion_input(
                    input_seqs
                        .iter()
                        .map(|seq| seq.get_toks())
                        .collect::<Vec<_>>(),
                    input_seqs,
                    device,
                    no_kv_cache,
                    last_n_context_len,
                    return_raw_logits,
                    paged_attn_metadata.as_mut(),
                    mapper,
                )?;
                let InnerInputProcessorOutput {
                    inputs:
                        InputMetadata {
                            input: input_ids,
                            positions: seqlen_offsets,
                            context_lens,
                            position_ids,
                            paged_attn_meta,
                            flash_meta,
                        },
                    seq_indices,
                } = metadata;
                let inputs: Box<dyn Any> = Box::new(ModelInputs {
                    input_ids,
                    input_ids_full: None,
                    seqlen_offsets,
                    seqlen_offsets_full: None,
                    context_lens,
                    position_ids,
                    paged_attn_meta,
                    flash_meta,
                    flash_meta_full: None,
                });
                Ok(InputProcessorOutput {
                    inputs,
                    seq_indices,
                })
            }
        }

        fn get_type(&self) -> InputsProcessorType {
            InputsProcessorType::Text
        }
    }
}
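
// A minimal smoke test for the paged-attention placeholder metadata. This is
// an illustrative sketch added for documentation purposes, not part of the
// original test suite.
#[cfg(test)]
mod tests {
    use super::text_models_inputs_processor::PagedAttentionInputMetadata;
    use candle_core::Device;

    #[test]
    fn dummy_metadata_has_slot_mapping_for_device() -> candle_core::Result<()> {
        let dev = Device::Cpu;
        let meta = PagedAttentionInputMetadata::dummy(&dev)?;
        // The dummy carries one slot-mapping entry for the requested device
        // and marks itself as the first prompt chunk.
        assert!(meta.slot_mappings.contains_key(&dev.location()));
        assert!(meta.is_first_prompt_chunk);
        Ok(())
    }
}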