#![allow(clippy::cast_possible_truncation)]

use std::{any::Any, num::NonZeroUsize, sync::Arc};

use anyhow::Result;
use candle_core::Device;
use text_models_inputs_processor::PagedAttentionMeta;
use tokenizers::Tokenizer;

use crate::{device_map::DeviceMapper, sequence::Sequence};

pub const DEFAULT_PROMPT_CHUNK_SIZE: usize = 1024;

#[derive(PartialEq)]
pub enum InputsProcessorType {
    Text,
    Vision,
}

pub struct InputProcessorOutput {
    pub inputs: Box<dyn Any>,
    pub seq_indices: Vec<usize>,
}

pub trait InputsProcessor {
    #[allow(clippy::too_many_arguments)]
    fn process_inputs(
        &self,
        tokenizer: Option<Arc<Tokenizer>>,
        input_seqs: &mut [&mut Sequence],
        is_prompt: bool,
        is_xlora: bool,
        device: &Device,
        no_kv_cache: bool,
        last_n_context_len: Option<(usize, usize)>,
        return_raw_logits: bool,
        other_config: Option<Arc<dyn Any>>,
        paged_attn_metadata: Option<PagedAttentionMeta<'_>>,
        prompt_chunksize: Option<NonZeroUsize>,
        mapper: Option<&dyn DeviceMapper>,
    ) -> Box<dyn Iterator<Item = Result<InputProcessorOutput>>>;

    fn get_type(&self) -> InputsProcessorType;
}

pub mod text_models_inputs_processor {
    use std::{any::Any, collections::HashMap, fmt::Debug, num::NonZeroUsize, sync::Arc};

    use anyhow::Result;
    use candle_core::{DType, Device, DeviceLocation, Tensor, WithDType};
    use tokenizers::Tokenizer;

    use crate::{
        device_map::DeviceMapper,
        paged_attention::{BlockEngine, _PAD_SLOT_ID},
        sequence::Sequence,
        using_flash_attn,
    };

    use super::{InputProcessorOutput, InputsProcessor, InputsProcessorType};

    #[inline(always)]
    fn _make_tensor_with_pad<D: WithDType>(
        x: Vec<Vec<D>>,
        max_len: usize,
        pad: D,
        device: &Device,
    ) -> Result<Tensor> {
        let bs = x.len();
        let mut padded_x = Vec::with_capacity(x.len() * max_len);
        for mut x_i in x {
            assert!(x_i.len() <= max_len);
            x_i.extend(std::iter::repeat_n(pad, max_len.saturating_sub(x_i.len())));
            padded_x.extend(x_i);
        }
        Tensor::from_vec(padded_x, (bs, max_len), device).map_err(anyhow::Error::msg)
    }
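
    // Illustrative test (added; not from the original source) of
    // `_make_tensor_with_pad`: ragged rows are right-padded with `pad` up to
    // `max_len`, yielding a dense `(batch, max_len)` tensor.
    #[cfg(test)]
    mod pad_tensor_example {
        use super::_make_tensor_with_pad;
        use candle_core::Device;

        #[test]
        fn pads_rows_to_max_len() -> anyhow::Result<()> {
            let t = _make_tensor_with_pad(vec![vec![1i64, 2, 3], vec![4]], 3, 0, &Device::Cpu)?;
            assert_eq!(t.dims(), &[2, 3]);
            assert_eq!(t.to_vec2::<i64>()?, vec![vec![1, 2, 3], vec![4, 0, 0]]);
            Ok(())
        }
    }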

    pub struct PagedAttentionMeta<'a> {
        pub sliding_window: Option<usize>,
        pub block_size: usize,
        pub block_engine: &'a mut BlockEngine,
    }

    #[derive(Clone, Debug)]
    #[allow(dead_code)]
    pub struct PagedAttentionInputMetadata {
        pub block_tables: Option<HashMap<DeviceLocation, Tensor>>,
        pub context_lens: Option<HashMap<DeviceLocation, Tensor>>,
        pub slot_mappings: HashMap<DeviceLocation, Tensor>,
        pub max_context_len: Option<usize>,
        pub is_first_prompt_chunk: bool,
    }

    impl PagedAttentionInputMetadata {
        /// Construct placeholder metadata with a single slot-mapping entry for
        /// the given device. This is not usable for an actual decode step.
        pub fn dummy(dev: &Device) -> candle_core::Result<Self> {
            Ok(PagedAttentionInputMetadata {
                block_tables: None,
                context_lens: None,
                max_context_len: None,
                slot_mappings: HashMap::from([(dev.location(), Tensor::new(&[0f32], dev)?)]),
                is_first_prompt_chunk: true,
            })
        }
    }
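
    // Sketch (added for illustration) of what `dummy` yields: empty block
    // tables and context lens, plus one slot-mapping entry keyed by the
    // requested device's location.
    #[cfg(test)]
    mod dummy_metadata_example {
        use super::PagedAttentionInputMetadata;
        use candle_core::Device;

        #[test]
        fn dummy_has_single_device_entry() -> candle_core::Result<()> {
            let dev = Device::Cpu;
            let meta = PagedAttentionInputMetadata::dummy(&dev)?;
            assert!(meta.block_tables.is_none());
            assert!(meta.context_lens.is_none());
            assert!(meta.is_first_prompt_chunk);
            assert!(meta.slot_mappings.contains_key(&dev.location()));
            Ok(())
        }
    }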

    #[derive(Clone, Debug)]
    pub struct FlashParams {
        pub max_q: u32,
        pub max_k: u32,
        pub cumulative_seqlens_q: HashMap<DeviceLocation, Tensor>,
        pub cumulative_seqlens_k: HashMap<DeviceLocation, Tensor>,
    }
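
    // Illustrative sketch (added; not from the original source): varlen flash
    // attention kernels consume cumulative sequence lengths, so per-sequence
    // lengths [3, 1] (with a leading 0) become offsets [0, 3, 4]. This is how
    // the `cumulative_seqlens_q`/`_k` maps are populated further below.
    #[cfg(test)]
    mod cumulative_seqlens_example {
        use candle_core::{DType, Device, Tensor};

        #[test]
        fn cumsum_builds_offsets() -> candle_core::Result<()> {
            let seqlens_q: Vec<u32> = vec![0, 3, 1];
            let cu = Tensor::new(seqlens_q, &Device::Cpu)?
                .to_dtype(DType::F32)?
                .cumsum(0)?
                .to_dtype(DType::U32)?;
            assert_eq!(cu.to_vec1::<u32>()?, vec![0, 3, 4]);
            Ok(())
        }
    }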

    pub struct InputMetadata {
        pub input: Tensor,
        pub positions: Vec<usize>,
        pub context_lens: Vec<(usize, usize)>,
        pub position_ids: Vec<usize>,
        pub paged_attn_meta: Option<PagedAttentionInputMetadata>,
        pub flash_meta: FlashParams,
    }

    pub struct InnerInputProcessorOutput {
        pub inputs: InputMetadata,
        pub seq_indices: Vec<usize>,
    }

    #[allow(clippy::too_many_arguments)]
    pub fn make_prompt_chunk<T: WithDType + Debug>(
        chunk_offset_toks: usize,
        toks: Vec<Vec<T>>,
        seq_ids: &[usize],
        device: &Device,
        last_n_context_len: Option<(usize, usize)>,
        return_raw_logits: bool,
        mut paged_attn_metadata: Option<&mut PagedAttentionMeta<'_>>,
        mapper: Option<&dyn DeviceMapper>,
    ) -> Result<InputMetadata> {
        let max_len = toks
            .iter()
            .map(|seq| seq.len())
            .max()
            .expect("No sequences");
        let padding_tok = T::zero();
        let mut seqs_tensors = Vec::new();
        let mut seqlen_offsets = Vec::new();
        let mut context_lens = Vec::new();
        let mut position_ids = Vec::new();
        let mut slot_mappings = Vec::new();
        let mut block_tables = Vec::new();
        let mut paged_attn_context_lens = Vec::new();
        let mut seqlens_q = vec![0];
        let mut seqlens_k = vec![0];
        for (seq_id, mut ctxt) in seq_ids.iter().zip(toks) {
            let prompt_len = ctxt.len();
            let offset = last_n_context_len.unwrap_or_default();
            seqlen_offsets.push(offset.1 + chunk_offset_toks);

            position_ids.push(ctxt.len() + chunk_offset_toks);
            ctxt.extend(std::iter::repeat_n(
                padding_tok,
                max_len.saturating_sub(ctxt.len()),
            ));

            if return_raw_logits {
                if last_n_context_len.is_some() {
                    anyhow::bail!("`return_raw_logits` is incompatible with `last_n_context_len`");
                }

                context_lens.push((0, ctxt.len()));
            } else {
                context_lens.push((
                    ctxt.len() - last_n_context_len.map(|(a, _)| a).unwrap_or(1),
                    last_n_context_len.map(|(a, _)| a).unwrap_or(1),
                ));
            }

            seqlens_q.push(ctxt.len() as u32);
            seqlens_k.push((ctxt.len() + chunk_offset_toks) as u32);

            seqs_tensors.push(Tensor::new(ctxt, device).unwrap().unsqueeze(0).unwrap());

            if let Some(paged_attn_metadata) = &mut paged_attn_metadata {
                let table = paged_attn_metadata.block_engine.block_tables.get(seq_id);

                if table.is_none() {
                    // There is no block table during memory profiling; map the
                    // whole prompt to the pad slot.
                    slot_mappings.push([_PAD_SLOT_ID].repeat(prompt_len));
                    continue;
                }
                let table = table
                    .unwrap()
                    .iter()
                    .map(|block| block.deref_mut().block_id)
                    .collect::<Vec<_>>();

                let start_idx = if let Some(sliding_window) = paged_attn_metadata.sliding_window {
                    if prompt_len > sliding_window {
                        chunk_offset_toks.min(prompt_len - sliding_window)
                    } else {
                        chunk_offset_toks
                    }
                } else {
                    chunk_offset_toks
                };

                let mut slot_mapping = Vec::new();
                let mut ctxt_len = Vec::new();
                for i in chunk_offset_toks..prompt_len + chunk_offset_toks {
                    if i < start_idx {
                        // Tokens outside the sliding window get the pad slot
                        // and do not count toward the context length.
                        slot_mapping.push(_PAD_SLOT_ID);
                        continue;
                    }
                    ctxt_len.push(i);

                    let block_number = if i / paged_attn_metadata.block_size >= table.len() {
                        panic!(
                            "Block table is too small (prompt)! i={} block_size={} table_len={}",
                            i,
                            paged_attn_metadata.block_size,
                            table.len()
                        );
                    } else {
                        table.get(i / paged_attn_metadata.block_size).unwrap()
                    };
                    let block_offset = i % paged_attn_metadata.block_size;
                    let slot = block_number * paged_attn_metadata.block_size + block_offset;
                    slot_mapping.push(slot.try_into().unwrap());
                }
                slot_mappings.push(slot_mapping);
                // One block table per sequence, not per token.
                block_tables.push(table);
                paged_attn_context_lens.push(ctxt_len);
            }
        }

        let max_q = *seqlens_q.iter().max().unwrap();
        let max_k = *seqlens_k.iter().max().unwrap();

        let mut seqlens_q_map = HashMap::new();
        let mut seqlens_k_map = HashMap::new();

        // Varlen flash attention kernels consume cumulative sequence lengths,
        // so prefix-sum the per-sequence lengths here.
        if using_flash_attn() {
            let seqlens_q = Tensor::new(seqlens_q, device)?
                .to_dtype(DType::F32)?
                .cumsum(0)?
                .to_dtype(DType::U32)?;
            let seqlens_k = Tensor::new(seqlens_k, device)?
                .to_dtype(DType::F32)?
                .cumsum(0)?
                .to_dtype(DType::U32)?;

            let devices = mapper.unwrap().get_unique_devices();
            for device in devices {
                seqlens_q_map.insert(device.location(), seqlens_q.to_device(&device)?);
                seqlens_k_map.insert(device.location(), seqlens_k.to_device(&device)?);
            }
        }

        let input = Tensor::cat(&seqs_tensors, 0).unwrap();

        let paged_attn_meta = if paged_attn_metadata.is_some() {
            let max_slot_mapping_len = slot_mappings.iter().map(|x| x.len()).max().unwrap();
            let slot_mappings =
                _make_tensor_with_pad(slot_mappings, max_slot_mapping_len, _PAD_SLOT_ID, device)?;

            let max_block_table_len = block_tables.iter().map(|x| x.len()).max().unwrap();
            let block_tables = _make_tensor_with_pad(
                block_tables
                    .iter()
                    .map(|x| x.iter().map(|x| *x as u32).collect::<Vec<_>>())
                    .collect::<Vec<_>>(),
                max_block_table_len,
                0,
                device,
            )?;
            let block_tables = block_tables.reshape(((), max_block_table_len))?;

            let max_context_len = paged_attn_context_lens
                .iter()
                .map(|x| x.len())
                .max()
                .unwrap();

            let context_lens = _make_tensor_with_pad(
                paged_attn_context_lens
                    .iter()
                    .map(|x| x.iter().map(|x| *x as u32).collect::<Vec<_>>())
                    .collect::<Vec<_>>(),
                max_context_len,
                0,
                device,
            )?
            .reshape(((),))?;

            let devices = mapper.unwrap().get_unique_devices();

            let mut slot_mappings_map = HashMap::new();
            let mut block_tables_map = HashMap::new();
            let mut context_lens_map = HashMap::new();

            for device in devices {
                slot_mappings_map
                    .insert(device.location(), slot_mappings.clone().to_device(&device)?);
                block_tables_map
                    .insert(device.location(), block_tables.clone().to_device(&device)?);
                context_lens_map
                    .insert(device.location(), context_lens.clone().to_device(&device)?);
            }

            Some(PagedAttentionInputMetadata {
                slot_mappings: slot_mappings_map,
                block_tables: Some(block_tables_map),
                context_lens: Some(context_lens_map),
                max_context_len: Some(max_context_len),
                is_first_prompt_chunk: chunk_offset_toks == 0,
            })
        } else {
            None
        };

        Ok(InputMetadata {
            input,
            positions: seqlen_offsets,
            context_lens,
            position_ids,
            paged_attn_meta,
            flash_meta: FlashParams {
                max_k,
                max_q,
                cumulative_seqlens_k: seqlens_k_map,
                cumulative_seqlens_q: seqlens_q_map,
            },
        })
    }
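
    // Illustrative sketch (added; not part of the original API) of the
    // slot-mapping arithmetic used above: token position `i` lives in physical
    // block `table[i / block_size]` at offset `i % block_size`.
    #[cfg(test)]
    mod slot_mapping_example {
        #[test]
        fn slot_index_arithmetic() {
            let block_size = 16usize;
            let table = [7usize, 3, 9]; // physical block ids for one sequence
            let i = 20usize; // logical token position
            let slot = table[i / block_size] * block_size + i % block_size;
            assert_eq!(slot, 3 * 16 + 4);
        }
    }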

    fn make_completion_chunk<T: WithDType>(
        toks: Vec<Vec<T>>,
        input_seqs: &[&mut Sequence],
        device: &Device,
        mut paged_attn_metadata: Option<&mut PagedAttentionMeta<'_>>,
        mapper: Option<&dyn DeviceMapper>,
    ) -> Result<InputMetadata> {
        let mut seqs_tensors = Vec::new();
        let mut seqlen_offsets = Vec::new();
        let mut context_lens = Vec::new();
        let mut position_ids = Vec::new();

        let mut slot_mappings = Vec::new();
        let mut block_tables = Vec::new();
        let mut paged_attn_context_lens = Vec::new();
        let mut seqlens_q = vec![0];
        let mut seqlens_k = vec![0];
        for (seq, ctxt) in input_seqs.iter().zip(toks) {
            // A decode step processes only the most recent token.
            let start_pos = ctxt.len().saturating_sub(1);
            let ctxt = ctxt[start_pos..].to_vec();
            seqlen_offsets.push(start_pos);
            context_lens.push((0, 1));
            position_ids.push(seq.len());

            seqlens_q.push(ctxt.len() as u32);
            seqlens_k.push((ctxt.len() + start_pos) as u32);

            seqs_tensors.push(Tensor::new(ctxt, device).unwrap().unsqueeze(0).unwrap());

            if let Some(paged_attn_metadata) = &mut paged_attn_metadata {
                let table = paged_attn_metadata
                    .block_engine
                    .block_tables
                    .get(seq.id())
                    .unwrap();

                let table = table
                    .iter()
                    .map(|block| block.deref_mut().block_id)
                    .collect::<Vec<_>>();

                let block_number = if start_pos / paged_attn_metadata.block_size >= table.len() {
                    panic!(
                        "Block table is too small (completion)! start_pos={} block_size={} table_len={}",
                        start_pos,
                        paged_attn_metadata.block_size,
                        table.len()
                    );
                } else {
                    table
                        .get(start_pos / paged_attn_metadata.block_size)
                        .unwrap()
                };
                let block_offset = start_pos % paged_attn_metadata.block_size;
                let slot = block_number * paged_attn_metadata.block_size + block_offset;
                let slot = slot.try_into().unwrap();
                slot_mappings.push(vec![slot]);

                if let Some(sliding_window) = paged_attn_metadata.sliding_window {
                    // Keep only the most recent blocks covering the window.
                    let sliding_window_blocks = sliding_window / paged_attn_metadata.block_size;
                    let slide_idx = if table.len() > sliding_window_blocks {
                        table.len() - sliding_window_blocks
                    } else {
                        0
                    };
                    block_tables.push(table.get(slide_idx..).unwrap().to_vec());
                } else {
                    block_tables.push(table);
                }

                let paged_attn_context_len =
                    if let Some(sliding_window) = paged_attn_metadata.sliding_window {
                        seq.len().min(sliding_window)
                    } else {
                        seq.len()
                    };
                paged_attn_context_lens.push(paged_attn_context_len);
            }
        }

        let max_q = *seqlens_q.iter().max().unwrap();
        let max_k = *seqlens_k.iter().max().unwrap();

        let mut seqlens_q_map = HashMap::new();
        let mut seqlens_k_map = HashMap::new();

        if using_flash_attn() {
            let seqlens_q = Tensor::new(seqlens_q, device)?
                .to_dtype(DType::F32)?
                .cumsum(0)?
                .to_dtype(DType::U32)?;
            let seqlens_k = Tensor::new(seqlens_k, device)?
                .to_dtype(DType::F32)?
                .cumsum(0)?
                .to_dtype(DType::U32)?;

            let devices = mapper.unwrap().get_unique_devices();
            for device in devices {
                seqlens_q_map.insert(device.location(), seqlens_q.to_device(&device)?);
                seqlens_k_map.insert(device.location(), seqlens_k.to_device(&device)?);
            }
        }

        let paged_attn_meta = if paged_attn_metadata.is_some() {
            let slot_mappings = _make_tensor_with_pad(slot_mappings, 1, _PAD_SLOT_ID, device)?;

            let max_block_table_len = block_tables.iter().map(|x| x.len()).max().unwrap();

            let block_tables = _make_tensor_with_pad(
                block_tables
                    .iter()
                    .map(|x| x.iter().map(|x| *x as u32).collect::<Vec<_>>())
                    .collect::<Vec<_>>(),
                max_block_table_len,
                0,
                device,
            )?;
            let block_tables = block_tables.reshape(((), max_block_table_len))?;

            let max_context_len = paged_attn_context_lens.iter().max().unwrap();

            let context_lens = Tensor::from_vec(
                paged_attn_context_lens
                    .iter()
                    .map(|x| *x as u32)
                    .collect::<Vec<_>>(),
                (paged_attn_context_lens.len(),),
                device,
            )?;

            let devices = mapper.unwrap().get_unique_devices();

            let mut slot_mappings_map = HashMap::new();
            let mut block_tables_map = HashMap::new();
            let mut context_lens_map = HashMap::new();

            for device in devices {
                slot_mappings_map
                    .insert(device.location(), slot_mappings.clone().to_device(&device)?);
                block_tables_map
                    .insert(device.location(), block_tables.clone().to_device(&device)?);
                context_lens_map
                    .insert(device.location(), context_lens.clone().to_device(&device)?);
            }

            Some(PagedAttentionInputMetadata {
                slot_mappings: slot_mappings_map,
                block_tables: Some(block_tables_map),
                context_lens: Some(context_lens_map),
                max_context_len: Some(*max_context_len),
                is_first_prompt_chunk: false,
            })
        } else {
            None
        };

        Ok(InputMetadata {
            input: Tensor::cat(&seqs_tensors, 0).unwrap(),
            positions: seqlen_offsets,
            context_lens,
            position_ids,
            paged_attn_meta,
            flash_meta: FlashParams {
                max_k,
                max_q,
                cumulative_seqlens_k: seqlens_k_map,
                cumulative_seqlens_q: seqlens_q_map,
            },
        })
    }
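
    // Sketch (added for illustration) of the decode-time sliding-window
    // truncation in `make_completion_chunk`: only the most recent
    // `sliding_window / block_size` blocks of the table are kept.
    #[cfg(test)]
    mod sliding_window_example {
        #[test]
        fn truncates_block_table() {
            let block_size = 16usize;
            let sliding_window = 32usize;
            let table = vec![1usize, 2, 3, 4, 5];
            let sliding_window_blocks = sliding_window / block_size; // = 2
            let slide_idx = table.len().saturating_sub(sliding_window_blocks);
            assert_eq!(&table[slide_idx..], &[4, 5]);
        }
    }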

    #[allow(clippy::too_many_arguments)]
    pub(crate) fn get_prompt_input<T: WithDType + std::fmt::Debug>(
        toks: Vec<Vec<T>>,
        input_seqs: &[&mut Sequence],
        device: &Device,
        last_n_context_len: Option<(usize, usize)>,
        return_raw_logits: bool,
        mut paged_attn_metadata: Option<&mut PagedAttentionMeta<'_>>,
        prompt_chunksize: Option<NonZeroUsize>,
        mapper: Option<&dyn DeviceMapper>,
    ) -> Box<dyn Iterator<Item = Result<InnerInputProcessorOutput>>> {
        // Chunk the prompt only when PagedAttention is not in use.
        if let (Some(prompt_chunksize), true) = (prompt_chunksize, paged_attn_metadata.is_none()) {
            let mut seq_chunks = Vec::new();
            let mut n_chunks = Vec::new();
            let prompt_chunksize: usize = prompt_chunksize.into();

            let offset = input_seqs[0].token_offset();

            // Split each sequence's tokens into chunks of at most `prompt_chunksize`.
            for ctxt in toks.iter() {
                let chunks = ctxt.chunks(prompt_chunksize).collect::<Vec<_>>();
                n_chunks.push(chunks.len());
                seq_chunks.push(chunks);
            }

            // Transpose: group the i-th chunk of every sequence together so
            // that each group can be batched into one forward pass.
            let mut chunks_transposed: Vec<Vec<(Vec<T>, usize)>> = Vec::new();
            for (seq_n, seq) in seq_chunks.into_iter().enumerate() {
                for (i, chunk) in seq.into_iter().enumerate() {
                    match chunks_transposed.get_mut(i) {
                        Some(part) => part.push((chunk.to_vec(), seq_n)),
                        None => chunks_transposed.push(vec![(chunk.to_vec(), seq_n)]),
                    }
                }
            }
            let chunks = chunks_transposed
                .into_iter()
                .enumerate()
                .map(|(i, chunk)| {
                    let (toks, seq_ns): (Vec<Vec<T>>, Vec<usize>) = chunk.into_iter().unzip();
                    make_prompt_chunk(
                        i * prompt_chunksize + offset,
                        toks,
                        &seq_ns
                            .iter()
                            .map(|i| *input_seqs[*i].id())
                            .collect::<Vec<_>>(),
                        device,
                        last_n_context_len,
                        return_raw_logits,
                        paged_attn_metadata.as_deref_mut(),
                        mapper,
                    )
                    .map(|inputs| InnerInputProcessorOutput {
                        inputs,
                        seq_indices: seq_ns,
                    })
                })
                .collect::<Vec<_>>();
            Box::new(chunks.into_iter())
        } else {
            let offset = input_seqs[0].token_offset();
            if offset != 0 && paged_attn_metadata.is_some() {
                return Box::new(std::iter::once(Err(anyhow::Error::msg(
                    "PagedAttention does not yet support sequences with an offset != 0.",
                ))));
            }
            Box::new(std::iter::once(
                make_prompt_chunk(
                    offset,
                    toks,
                    &input_seqs.iter().map(|s| *s.id()).collect::<Vec<_>>(),
                    device,
                    last_n_context_len,
                    return_raw_logits,
                    paged_attn_metadata,
                    mapper,
                )
                .map(|inputs| InnerInputProcessorOutput {
                    inputs,
                    seq_indices: (0..input_seqs.len()).collect(),
                }),
            ))
        }
    }
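
    // Illustrative sketch (added; not from the original source) of the chunk
    // transposition in `get_prompt_input`: per-sequence chunks are regrouped
    // so that the i-th chunk of every sequence is batched together.
    #[cfg(test)]
    mod chunking_example {
        #[test]
        fn transposes_chunks() {
            let prompt_chunksize = 2usize;
            let toks = vec![vec![1u32, 2, 3, 4, 5], vec![6u32, 7, 8]];
            let mut chunks_transposed: Vec<Vec<(Vec<u32>, usize)>> = Vec::new();
            for (seq_n, seq) in toks.iter().enumerate() {
                for (i, chunk) in seq.chunks(prompt_chunksize).enumerate() {
                    match chunks_transposed.get_mut(i) {
                        Some(part) => part.push((chunk.to_vec(), seq_n)),
                        None => chunks_transposed.push(vec![(chunk.to_vec(), seq_n)]),
                    }
                }
            }
            // Chunk 0 holds the first two tokens of both sequences; chunk 2
            // holds only the tail of the longer sequence.
            assert_eq!(chunks_transposed.len(), 3);
            assert_eq!(chunks_transposed[0], vec![(vec![1, 2], 0), (vec![6, 7], 1)]);
            assert_eq!(chunks_transposed[2], vec![(vec![5], 0)]);
        }
    }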

    #[allow(clippy::too_many_arguments)]
    pub(crate) fn get_completion_input<T: WithDType + std::fmt::Debug>(
        toks: Vec<Vec<T>>,
        input_seqs: &[&mut Sequence],
        device: &Device,
        no_kv_cache: bool,
        last_n_context_len: Option<(usize, usize)>,
        return_raw_logits: bool,
        paged_attn_metadata: Option<&mut PagedAttentionMeta<'_>>,
        prompt_chunksize: Option<NonZeroUsize>,
        mapper: Option<&dyn DeviceMapper>,
    ) -> Box<dyn Iterator<Item = Result<InnerInputProcessorOutput>>> {
        if no_kv_cache {
            // Without a KV cache, every step must reprocess the full prompt.
            return get_prompt_input(
                toks,
                input_seqs,
                device,
                last_n_context_len,
                return_raw_logits,
                paged_attn_metadata,
                prompt_chunksize,
                mapper,
            );
        }

        Box::new(std::iter::once(
            make_completion_chunk(toks, input_seqs, device, paged_attn_metadata, mapper).map(
                |inputs| InnerInputProcessorOutput {
                    inputs,
                    seq_indices: (0..input_seqs.len()).collect(),
                },
            ),
        ))
    }

    #[derive(Clone)]
    pub struct ModelInputs {
        pub input_ids: Tensor,
        pub input_ids_full: Option<Tensor>,
        pub seqlen_offsets: Vec<usize>,
        pub seqlen_offsets_full: Option<Vec<usize>>,
        pub context_lens: Vec<(usize, usize)>,
        pub position_ids: Vec<usize>,
        pub paged_attn_meta: Option<PagedAttentionInputMetadata>,
        pub flash_meta: FlashParams,
        pub flash_meta_full: Option<FlashParams>,
    }

    pub struct TextInputsProcessor;

    impl InputsProcessor for TextInputsProcessor {
        fn process_inputs(
            &self,
            _: Option<Arc<Tokenizer>>,
            input_seqs: &mut [&mut Sequence],
            is_prompt: bool,
            is_xlora: bool,
            device: &Device,
            no_kv_cache: bool,
            last_n_context_len: Option<(usize, usize)>,
            return_raw_logits: bool,
            _: Option<Arc<dyn Any>>,
            mut paged_attn_metadata: Option<PagedAttentionMeta<'_>>,
            prompt_chunksize: Option<NonZeroUsize>,
            mapper: Option<&dyn DeviceMapper>,
        ) -> Box<dyn Iterator<Item = Result<InputProcessorOutput>>> {
            if is_xlora && !is_prompt {
                // X-LoRA decoding needs the full-context inputs (for the
                // scaling pass) alongside the regular completion inputs.
                Box::new(
                    get_prompt_input(
                        input_seqs
                            .iter()
                            .map(|seq| seq.get_toks().to_vec())
                            .collect::<Vec<_>>(),
                        input_seqs,
                        device,
                        last_n_context_len,
                        return_raw_logits,
                        paged_attn_metadata.as_mut(),
                        prompt_chunksize,
                        mapper,
                    )
                    .zip(get_completion_input(
                        input_seqs
                            .iter()
                            .map(|seq| seq.get_toks().to_vec())
                            .collect::<Vec<_>>(),
                        input_seqs,
                        device,
                        no_kv_cache,
                        last_n_context_len,
                        return_raw_logits,
                        paged_attn_metadata.as_mut(),
                        prompt_chunksize,
                        mapper,
                    ))
                    .map(|(prompt, completion)| {
                        let InnerInputProcessorOutput {
                            inputs:
                                InputMetadata {
                                    input: input_ids_full,
                                    positions: seqlen_offsets_full,
                                    context_lens: _,
                                    position_ids,
                                    paged_attn_meta: _,
                                    flash_meta: flash_meta_full,
                                },
                            seq_indices,
                        } = prompt?;
                        let InnerInputProcessorOutput {
                            inputs:
                                InputMetadata {
                                    input: input_ids,
                                    positions: seqlen_offsets,
                                    context_lens,
                                    position_ids: _,
                                    paged_attn_meta,
                                    flash_meta,
                                },
                            seq_indices: _,
                        } = completion?;
                        let inputs: Box<dyn Any> = Box::new(ModelInputs {
                            input_ids,
                            input_ids_full: Some(input_ids_full),
                            seqlen_offsets,
                            seqlen_offsets_full: Some(seqlen_offsets_full),
                            context_lens,
                            position_ids,
                            paged_attn_meta,
                            flash_meta,
                            flash_meta_full: Some(flash_meta_full),
                        });
                        Ok(InputProcessorOutput {
                            inputs,
                            seq_indices,
                        })
                    }),
                )
            } else if is_xlora && is_prompt {
                // For an X-LoRA prompt, one pass provides both the scaling and
                // the regular forward inputs.
                Box::new(
                    get_prompt_input(
                        input_seqs
                            .iter()
                            .map(|seq| seq.get_toks().to_vec())
                            .collect::<Vec<_>>(),
                        input_seqs,
                        device,
                        last_n_context_len,
                        return_raw_logits,
                        paged_attn_metadata.as_mut(),
                        prompt_chunksize,
                        mapper,
                    )
                    .map(|metadata| {
                        let InnerInputProcessorOutput {
                            inputs:
                                InputMetadata {
                                    input: input_ids,
                                    positions: seqlen_offsets,
                                    context_lens,
                                    position_ids,
                                    paged_attn_meta,
                                    flash_meta,
                                },
                            seq_indices,
                        } = metadata?;
                        let inputs: Box<dyn Any> = Box::new(ModelInputs {
                            input_ids: input_ids.clone(),
                            input_ids_full: Some(input_ids),
                            seqlen_offsets: seqlen_offsets.clone(),
                            seqlen_offsets_full: Some(seqlen_offsets),
                            context_lens,
                            position_ids,
                            paged_attn_meta,
                            flash_meta: flash_meta.clone(),
                            flash_meta_full: Some(flash_meta),
                        });
                        Ok(InputProcessorOutput {
                            inputs,
                            seq_indices,
                        })
                    }),
                )
            } else if is_prompt {
                Box::new(
                    get_prompt_input(
                        input_seqs
                            .iter()
                            .map(|seq| seq.get_toks().to_vec())
                            .collect::<Vec<_>>(),
                        input_seqs,
                        device,
                        last_n_context_len,
                        return_raw_logits,
                        paged_attn_metadata.as_mut(),
                        prompt_chunksize,
                        mapper,
                    )
                    .map(|metadata| {
                        let InnerInputProcessorOutput {
                            inputs:
                                InputMetadata {
                                    input: input_ids,
                                    positions: seqlen_offsets,
                                    context_lens,
                                    position_ids,
                                    paged_attn_meta,
                                    flash_meta,
                                },
                            seq_indices,
                        } = metadata?;
                        let inputs: Box<dyn Any> = Box::new(ModelInputs {
                            input_ids,
                            input_ids_full: None,
                            seqlen_offsets,
                            seqlen_offsets_full: None,
                            context_lens,
                            position_ids,
                            paged_attn_meta,
                            flash_meta,
                            flash_meta_full: None,
                        });
                        Ok(InputProcessorOutput {
                            inputs,
                            seq_indices,
                        })
                    }),
                )
            } else {
                Box::new(
                    get_completion_input(
                        input_seqs
                            .iter()
                            .map(|seq| seq.get_toks().to_vec())
                            .collect::<Vec<_>>(),
                        input_seqs,
                        device,
                        no_kv_cache,
                        last_n_context_len,
                        return_raw_logits,
                        paged_attn_metadata.as_mut(),
                        prompt_chunksize,
                        mapper,
                    )
                    .map(|metadata| {
                        let InnerInputProcessorOutput {
                            inputs:
                                InputMetadata {
                                    input: input_ids,
                                    positions: seqlen_offsets,
                                    context_lens,
                                    position_ids,
                                    paged_attn_meta,
                                    flash_meta,
                                },
                            seq_indices,
                        } = metadata?;
                        let inputs: Box<dyn Any> = Box::new(ModelInputs {
                            input_ids,
                            input_ids_full: None,
                            seqlen_offsets,
                            seqlen_offsets_full: None,
                            context_lens,
                            position_ids,
                            paged_attn_meta,
                            flash_meta,
                            flash_meta_full: None,
                        });
                        Ok(InputProcessorOutput {
                            inputs,
                            seq_indices,
                        })
                    }),
                )
            }
        }

        fn get_type(&self) -> InputsProcessorType {
            InputsProcessorType::Text
        }
    }
}