1#![allow(clippy::cast_possible_truncation)]
2
3use std::{any::Any, num::NonZeroUsize, sync::Arc};
4
5use anyhow::Result;
6use candle_core::Device;
7use text_models_inputs_processor::PagedAttentionMeta;
8use tokenizers::Tokenizer;
9
10use crate::{device_map::DeviceMapper, sequence::Sequence};
11
/// Default number of prompt tokens processed per chunk when prompt chunking is enabled.
pub const DEFAULT_PROMPT_CHUNK_SIZE: usize = 1024;
13
/// Discriminates the two families of input processors.
#[derive(PartialEq)]
pub enum InputsProcessorType {
    /// Processor for text-only models.
    Text,
    /// Processor for vision(-language) models.
    Vision,
}
19
/// One batch of prepared model inputs plus the indices (into the original
/// `input_seqs` slice) of the sequences it covers.
pub struct InputProcessorOutput {
    /// Model-specific inputs; downcast to the concrete type by the model.
    pub inputs: Box<dyn Any>,
    /// Indices of the sequences included in `inputs`.
    pub seq_indices: Vec<usize>,
}
24
/// Converts a batch of [`Sequence`]s into the concrete input type a model
/// implementation consumes.
pub trait InputsProcessor {
    /// Build model inputs for `input_seqs`.
    ///
    /// Returns an iterator because a prompt may be processed in several
    /// chunks; each item pairs the prepared inputs (boxed as `dyn Any`, to be
    /// downcast by the model) with the indices of the sequences it covers.
    ///
    /// - `is_prompt`: true for the prefill phase, false for decode.
    /// - `no_kv_cache`: when true, decode steps re-process the full sequence
    ///   like a prompt.
    /// - `last_n_context_len`: restrict returned logits to the last N tokens;
    ///   incompatible with `return_raw_logits`.
    /// - `paged_attn_metadata`: present only when PagedAttention is enabled.
    /// - `prompt_chunksize`: maximum tokens per prompt chunk, if chunking.
    /// - `mapper`: device mapper used to replicate metadata tensors across
    ///   the devices of a multi-device pipeline.
    #[allow(clippy::too_many_arguments)]
    fn process_inputs(
        &self,
        tokenizer: Option<Arc<Tokenizer>>,
        input_seqs: &mut [&mut Sequence],
        is_prompt: bool,
        is_xlora: bool,
        device: &Device,
        no_kv_cache: bool,
        last_n_context_len: Option<(usize, usize)>,
        return_raw_logits: bool,
        other_config: Option<Arc<dyn Any>>,
        paged_attn_metadata: Option<PagedAttentionMeta>,
        prompt_chunksize: Option<NonZeroUsize>,
        mapper: Option<&dyn DeviceMapper>,
    ) -> Box<dyn Iterator<Item = Result<InputProcessorOutput>>>;

    /// Whether this processor handles text-only or vision inputs.
    fn get_type(&self) -> InputsProcessorType;
}
50
51pub mod text_models_inputs_processor {
54 use std::{any::Any, collections::HashMap, fmt::Debug, num::NonZeroUsize, sync::Arc};
55
56 use anyhow::Result;
57 use candle_core::{DType, Device, DeviceLocation, Tensor, WithDType};
58 use tokenizers::Tokenizer;
59
60 use crate::{
61 device_map::DeviceMapper,
62 get_mut_arcmutex,
63 paged_attention::{BlockEngine, _PAD_SLOT_ID},
64 sequence::Sequence,
65 };
66
67 use super::{InputProcessorOutput, InputsProcessor, InputsProcessorType};
68
69 fn _make_tensor_with_pad<D: WithDType>(
70 x: Vec<Vec<D>>,
71 max_len: usize,
72 pad: D,
73 device: &Device,
74 ) -> Result<Tensor> {
75 let mut padded_x = Vec::new();
76 for mut x_i in x {
77 assert!(x_i.len() <= max_len);
78 x_i.extend([pad].repeat(max_len - x_i.len()));
79 let shape = (x_i.len(),);
80 padded_x.push(Tensor::from_vec(x_i, shape, device)?);
81 }
82 Tensor::cat(&padded_x[..], 0).map_err(anyhow::Error::msg)
83 }
84
    /// Per-batch scheduling state required to build PagedAttention inputs.
    pub struct PagedAttentionMeta {
        /// Sliding-window size in tokens, if the model uses one.
        pub sliding_window: Option<usize>,
        /// Number of KV-cache slots per block.
        pub block_size: usize,
        /// Shared block engine holding the per-sequence block tables.
        pub block_engine: Arc<tokio::sync::Mutex<BlockEngine>>,
    }
90
    /// Device-resident PagedAttention tensors, keyed by device location so a
    /// multi-device pipeline can pick the copy local to each layer.
    #[derive(Clone, Debug)]
    #[allow(dead_code)]
    pub struct PagedAttentionInputMetadata {
        /// Per-device block tables; `None` for a dummy/profiling pass.
        pub block_tables: Option<HashMap<DeviceLocation, Tensor>>,
        /// Per-device context lengths; `None` for a dummy/profiling pass.
        pub context_lens: Option<HashMap<DeviceLocation, Tensor>>,
        /// Per-device token -> KV-cache-slot mapping.
        pub slot_mappings: HashMap<DeviceLocation, Tensor>,
        /// Maximum context length across the batch, if known.
        pub max_context_len: Option<usize>,
        /// True only for the first chunk of a chunked prompt.
        pub is_first_prompt_chunk: bool,
    }
100
101 impl PagedAttentionInputMetadata {
102 pub fn dummy(dev: &Device) -> candle_core::Result<Self> {
105 Ok(PagedAttentionInputMetadata {
106 block_tables: None,
107 context_lens: None,
108 max_context_len: None,
109 slot_mappings: HashMap::from([(dev.location(), Tensor::new(&[0f32], dev)?)]),
110 is_first_prompt_chunk: true,
111 })
112 }
113 }
114
    /// FlashAttention varlen parameters: cumulative sequence lengths per
    /// device plus the batch maxima.
    #[derive(Clone, Debug)]
    pub struct FlashParams {
        /// Maximum query length in the batch.
        pub max_q: u32,
        /// Maximum key length in the batch.
        pub max_k: u32,
        /// Cumulative query seqlens (with leading 0), keyed by device.
        pub cumulative_seqlens_q: HashMap<DeviceLocation, Tensor>,
        /// Cumulative key seqlens (with leading 0), keyed by device.
        pub cumulative_seqlens_k: HashMap<DeviceLocation, Tensor>,
    }
122
    /// Fully-prepared inputs for one forward pass of a text model.
    pub struct InputMetadata {
        /// Batched, right-padded token tensor.
        pub input: Tensor,
        /// Per-sequence starting position (KV-cache offset).
        pub positions: Vec<usize>,
        /// Per-sequence `(start, len)` ranges selecting which logits to keep.
        pub context_lens: Vec<(usize, usize)>,
        /// Per-sequence absolute position ids (tokens seen so far).
        pub position_ids: Vec<usize>,
        /// PagedAttention metadata, when enabled.
        pub paged_attn_meta: Option<PagedAttentionInputMetadata>,
        /// FlashAttention cumulative-seqlen metadata.
        pub flash_meta: FlashParams,
    }
131
    /// Module-internal analogue of [`super::InputProcessorOutput`] with the
    /// concrete [`InputMetadata`] instead of `Box<dyn Any>`.
    pub struct InnerInputProcessorOutput {
        /// Prepared inputs for this chunk.
        pub inputs: InputMetadata,
        /// Indices of the sequences included in `inputs`.
        pub seq_indices: Vec<usize>,
    }
136
137 #[allow(clippy::too_many_arguments)]
140 pub fn make_prompt_chunk<T: WithDType + Debug>(
141 chunk_offset_toks: usize,
142 toks: Vec<&[T]>,
143 seq_ids: &[usize],
144 device: &Device,
145 last_n_context_len: Option<(usize, usize)>,
146 return_raw_logits: bool,
147 mut paged_attn_metadata: Option<&mut PagedAttentionMeta>,
148 mapper: Option<&dyn DeviceMapper>,
149 ) -> Result<InputMetadata> {
150 let max_len = toks
151 .iter()
152 .map(|seq| seq.len())
153 .max()
154 .expect("No sequences");
155 let padding_tok = T::zero();
156 let mut seqs_tensors = Vec::new();
158 let mut seqlen_offsets = Vec::new();
159 let mut context_lens = Vec::new();
160 let mut position_ids = Vec::new();
161 let mut slot_mappings = Vec::new();
162 let mut block_tables = Vec::new();
163 let mut paged_attn_context_lens = Vec::new();
164 let flash_attn = crate::using_flash_attn();
165 let mut seqlens_q = if flash_attn { vec![0] } else { Vec::new() };
166 let mut seqlens_k = if flash_attn { vec![0] } else { Vec::new() };
167 for (seq_id, ctxt) in seq_ids.iter().zip(toks) {
168 let prompt_len = ctxt.len();
169 let offset = last_n_context_len.unwrap_or_default();
170 seqlen_offsets.push(offset.1 + chunk_offset_toks);
171
172 position_ids.push(ctxt.len() + chunk_offset_toks);
173 let mut ctxt = ctxt.to_vec();
174 ctxt.extend(std::iter::repeat_n(
175 padding_tok,
176 max_len.saturating_sub(ctxt.len()),
177 ));
178 if return_raw_logits {
180 if last_n_context_len.is_some() {
181 anyhow::bail!("`return_raw_logits` is incompatible with `last_n_context_len`");
182 }
183
184 context_lens.push((0, ctxt.len()));
185 } else {
186 context_lens.push((
187 ctxt.len()
188 .saturating_sub(last_n_context_len.map(|(a, _)| a).unwrap_or(1)),
189 last_n_context_len.map(|(a, _)| a).unwrap_or(1),
190 ));
191 }
192
193 if flash_attn {
194 seqlens_q.push(ctxt.len() as u32);
195 seqlens_k.push((ctxt.len() + chunk_offset_toks) as u32);
196 }
197
198 seqs_tensors.push(Tensor::new(ctxt, device).unwrap().unsqueeze(0).unwrap());
199
200 if let Some(paged_attn_metadata) = &mut paged_attn_metadata {
201 let block_engine = get_mut_arcmutex!(paged_attn_metadata.block_engine);
202 let table = block_engine.block_tables.get(seq_id);
203
204 if table.is_none() {
205 slot_mappings.push([_PAD_SLOT_ID].repeat(prompt_len));
207 continue;
208 }
209 let table = table
210 .unwrap()
211 .iter()
212 .map(|block| block.deref_mut().block_id)
213 .collect::<Vec<_>>();
214
215 let start_idx = if let Some(sliding_window) = paged_attn_metadata.sliding_window {
216 prompt_len.saturating_sub(sliding_window)
217 } else {
218 0
219 };
220
221 let mut slot_mapping = Vec::new();
222 let mut ctxt_len = Vec::new();
223 for i in chunk_offset_toks..prompt_len + chunk_offset_toks {
224 if i < start_idx {
225 slot_mapping.push(_PAD_SLOT_ID);
227 }
228 ctxt_len.push(i);
229
230 let block_number = if i / paged_attn_metadata.block_size >= table.len() {
231 panic!(
232 "Block table is too small (prompt)! i={} block_size={} table_len={}",
233 i,
234 paged_attn_metadata.block_size,
235 table.len()
236 );
237 } else {
238 table.get(i / paged_attn_metadata.block_size).unwrap()
239 };
240 let block_offset = i % paged_attn_metadata.block_size;
241 let slot = block_number * paged_attn_metadata.block_size + block_offset;
242 slot_mapping.push(slot.try_into().unwrap());
243 block_tables.push(table.clone());
244 }
245 slot_mappings.push(slot_mapping);
246 paged_attn_context_lens.push(ctxt_len);
247 }
248 }
249
250 let (max_q, max_k, seqlens_q_map, seqlens_k_map) = if flash_attn {
251 let max_q = *seqlens_q.iter().max().unwrap();
252 let max_k = *seqlens_k.iter().max().unwrap();
253 let seqlens_q = Tensor::new(seqlens_q, device)?
254 .to_dtype(DType::F32)?
255 .cumsum(0)?
256 .to_dtype(DType::U32)?;
257 let seqlens_k = Tensor::new(seqlens_k, device)?
258 .to_dtype(DType::F32)?
259 .cumsum(0)?
260 .to_dtype(DType::U32)?;
261
262 let mut seqlens_q_map = HashMap::new();
263 let mut seqlens_k_map = HashMap::new();
264
265 let devices = mapper.unwrap().get_unique_devices();
266 for device in devices {
267 seqlens_q_map.insert(device.location(), seqlens_q.to_device(&device)?);
268 seqlens_k_map.insert(device.location(), seqlens_k.to_device(&device)?);
269 }
270 (max_q, max_k, seqlens_q_map, seqlens_k_map)
271 } else {
272 (0, 0, HashMap::new(), HashMap::new())
273 };
274
275 let input = Tensor::cat(&seqs_tensors, 0).unwrap();
276
277 let paged_attn_meta = if paged_attn_metadata.is_some() {
278 let max_slot_mapping_len = slot_mappings.iter().map(|x| x.len()).max().unwrap();
279 let slot_mappings =
280 _make_tensor_with_pad(slot_mappings, max_slot_mapping_len, _PAD_SLOT_ID, device)?;
281
282 let max_block_table_len = block_tables.iter().map(|x| x.len()).max().unwrap();
283 let block_tables = _make_tensor_with_pad(
284 block_tables
285 .iter()
286 .map(|x| x.iter().map(|x| *x as u32).collect::<Vec<_>>())
287 .collect::<Vec<_>>(),
288 max_block_table_len,
289 0,
290 device,
291 )?;
292 let block_tables = block_tables.reshape(((), max_block_table_len))?;
293
294 let max_context_len = paged_attn_context_lens
295 .iter()
296 .map(|x| x.len())
297 .max()
298 .unwrap();
299
300 let context_lens = _make_tensor_with_pad(
301 paged_attn_context_lens
302 .iter()
303 .map(|x| x.iter().map(|x| *x as u32).collect::<Vec<_>>())
304 .collect::<Vec<_>>(),
305 max_context_len,
306 0,
307 device,
308 )?
309 .reshape(((),))?;
310
311 let devices = mapper.unwrap().get_unique_devices();
313 let mut slot_mappings_map = HashMap::new();
314 let mut block_tables_map = HashMap::new();
315 let mut context_lens_map = HashMap::new();
316
317 for device in devices {
318 slot_mappings_map
319 .insert(device.location(), slot_mappings.clone().to_device(&device)?);
320 block_tables_map
321 .insert(device.location(), block_tables.clone().to_device(&device)?);
322 context_lens_map
323 .insert(device.location(), context_lens.clone().to_device(&device)?);
324 }
325
326 Some(PagedAttentionInputMetadata {
327 slot_mappings: slot_mappings_map,
328 block_tables: Some(block_tables_map),
329 context_lens: Some(context_lens_map),
330 max_context_len: Some(max_context_len),
331 is_first_prompt_chunk: chunk_offset_toks == 0,
332 })
333 } else {
334 None
335 };
336
337 Ok(InputMetadata {
338 input,
339 positions: seqlen_offsets,
340 context_lens,
341 position_ids,
342 paged_attn_meta,
343 flash_meta: FlashParams {
344 max_k,
345 max_q,
346 cumulative_seqlens_k: seqlens_k_map,
347 cumulative_seqlens_q: seqlens_q_map,
348 },
349 })
350 }
351
    /// Builds [`InputMetadata`] for a decode (completion) step: only the last
    /// token of each sequence is fed to the model, with `context_lens` set to
    /// `(0, 1)` so only the final logit is kept.
    fn make_completion_chunk<T: WithDType>(
        toks: Vec<&[T]>,
        input_seqs: &[&mut Sequence],
        device: &Device,
        mut paged_attn_metadata: Option<&mut PagedAttentionMeta>,
        mapper: Option<&dyn DeviceMapper>,
    ) -> Result<InputMetadata> {
        // Pad each sequence to the maximum sequence length
        let flash_attn = crate::using_flash_attn();
        let mut seqs_tensors = Vec::new();
        let mut seqlen_offsets = Vec::new();
        let mut context_lens = Vec::new();
        let mut position_ids = Vec::new();

        let mut slot_mappings = Vec::new();
        let mut block_tables = Vec::new();
        let mut paged_attn_context_lens = Vec::new();
        // Flash-attn cumulative seqlens carry a leading 0 as the cumsum base.
        let mut seqlens_q = if flash_attn { vec![0] } else { Vec::new() };
        let mut seqlens_k = if flash_attn { vec![0] } else { Vec::new() };
        for (seq, ctxt) in input_seqs.iter().zip(toks) {
            // Decode consumes only the last token; `start_pos` is its index.
            let start_pos = ctxt.len().saturating_sub(1);
            let ctxt = ctxt[start_pos..].to_vec();
            seqlen_offsets.push(start_pos);
            context_lens.push((0, 1));
            position_ids.push(seq.len());

            if flash_attn {
                seqlens_q.push(ctxt.len() as u32);
                seqlens_k.push((ctxt.len() + start_pos) as u32);
            }

            seqs_tensors.push(Tensor::new(ctxt, device).unwrap().unsqueeze(0).unwrap());

            if let Some(paged_attn_metadata) = &mut paged_attn_metadata {
                let block_engine = get_mut_arcmutex!(paged_attn_metadata.block_engine);
                // Decode always has an allocated block table, hence the unwrap.
                let table = block_engine.block_tables.get(seq.id()).unwrap();

                let table = table
                    .iter()
                    .map(|block| block.deref_mut().block_id)
                    .collect::<Vec<_>>();

                // NOTE(review): assumes `seq.token_offset() <= start_pos`;
                // otherwise this subtraction underflows — confirm invariant.
                let block_pos = start_pos - seq.token_offset();
                let block_number = if block_pos / paged_attn_metadata.block_size >= table.len() {
                    panic!("Block table is too small (completion)! start_pos={} block_size={} table_len={}", block_pos, paged_attn_metadata.block_size, table.len());
                } else {
                    table
                        .get(block_pos / paged_attn_metadata.block_size)
                        .unwrap()
                };
                let block_offset = block_pos % paged_attn_metadata.block_size;
                let slot = block_number * paged_attn_metadata.block_size + block_offset;
                let slot = slot.try_into().unwrap();
                // Single-token decode: one slot per sequence.
                slot_mappings.push(vec![slot]);

                if let Some(sliding_window) = paged_attn_metadata.sliding_window {
                    // Keep only the blocks that can fall inside the window.
                    let sliding_window_blocks = sliding_window / paged_attn_metadata.block_size;
                    let slide_idx = if table.len() > sliding_window_blocks {
                        table.len() - sliding_window_blocks
                    } else {
                        0
                    };
                    block_tables.push(table.get(slide_idx..).unwrap().to_vec());
                } else {
                    block_tables.push(table);
                }

                // Context length is capped by the sliding window, if any.
                let paged_attn_context_len =
                    if let Some(sliding_window) = paged_attn_metadata.sliding_window {
                        seq.len().min(sliding_window)
                    } else {
                        seq.len()
                    };
                paged_attn_context_lens.push(paged_attn_context_len);
            }
        }

        let (max_q, max_k, seqlens_q_map, seqlens_k_map) = if flash_attn {
            let max_q = *seqlens_q.iter().max().unwrap();
            let max_k = *seqlens_k.iter().max().unwrap();
            // cumsum via an f32 round-trip (presumably no integer cumsum kernel).
            let seqlens_q = Tensor::new(seqlens_q, device)?
                .to_dtype(DType::F32)?
                .cumsum(0)?
                .to_dtype(DType::U32)?;
            let seqlens_k = Tensor::new(seqlens_k, device)?
                .to_dtype(DType::F32)?
                .cumsum(0)?
                .to_dtype(DType::U32)?;

            let mut seqlens_q_map = HashMap::new();
            let mut seqlens_k_map = HashMap::new();

            // `mapper` must be Some when flash-attn is enabled.
            let devices = mapper.unwrap().get_unique_devices();
            for device in devices {
                seqlens_q_map.insert(device.location(), seqlens_q.to_device(&device)?);
                seqlens_k_map.insert(device.location(), seqlens_k.to_device(&device)?);
            }
            (max_q, max_k, seqlens_q_map, seqlens_k_map)
        } else {
            (0, 0, HashMap::new(), HashMap::new())
        };

        let paged_attn_meta = if paged_attn_metadata.is_some() {
            let slot_mappings = _make_tensor_with_pad(slot_mappings, 1, _PAD_SLOT_ID, device)?;

            let max_block_table_len = block_tables.iter().map(|x| x.len()).max().unwrap();

            let block_tables = _make_tensor_with_pad(
                block_tables
                    .iter()
                    .map(|x| x.iter().map(|x| *x as u32).collect::<Vec<_>>())
                    .collect::<Vec<_>>(),
                max_block_table_len,
                0,
                device,
            )?;
            // One row per sequence.
            let block_tables = block_tables.reshape(((), max_block_table_len))?;

            let max_context_len = paged_attn_context_lens.iter().max().unwrap();

            let context_lens = Tensor::from_vec(
                paged_attn_context_lens
                    .iter()
                    .map(|x| *x as u32)
                    .collect::<Vec<_>>(),
                (paged_attn_context_lens.len(),),
                device,
            )?;

            // Replicate the metadata tensors onto every device in the pipeline.
            let devices = mapper.unwrap().get_unique_devices();
            let mut slot_mappings_map = HashMap::new();
            let mut block_tables_map = HashMap::new();
            let mut context_lens_map = HashMap::new();

            for device in devices {
                slot_mappings_map
                    .insert(device.location(), slot_mappings.clone().to_device(&device)?);
                block_tables_map
                    .insert(device.location(), block_tables.clone().to_device(&device)?);
                context_lens_map
                    .insert(device.location(), context_lens.clone().to_device(&device)?);
            }

            Some(PagedAttentionInputMetadata {
                slot_mappings: slot_mappings_map,
                block_tables: Some(block_tables_map),
                context_lens: Some(context_lens_map),
                max_context_len: Some(*max_context_len),
                is_first_prompt_chunk: false,
            })
        } else {
            None
        };

        Ok(InputMetadata {
            input: Tensor::cat(&seqs_tensors, 0).unwrap(),
            positions: seqlen_offsets,
            context_lens,
            position_ids,
            paged_attn_meta,
            flash_meta: FlashParams {
                max_k,
                max_q,
                cumulative_seqlens_k: seqlens_k_map,
                cumulative_seqlens_q: seqlens_q_map,
            },
        })
    }
521
522 #[allow(clippy::too_many_arguments)]
523 pub(crate) fn get_prompt_input<T: WithDType + std::fmt::Debug>(
524 toks: Vec<&[T]>,
525 input_seqs: &[&mut Sequence],
526 device: &Device,
527 last_n_context_len: Option<(usize, usize)>,
528 return_raw_logits: bool,
529 paged_attn_metadata: Option<&mut PagedAttentionMeta>,
530 prompt_chunksize: Option<NonZeroUsize>,
531 mapper: Option<&dyn DeviceMapper>,
532 ) -> Box<dyn Iterator<Item = Result<InnerInputProcessorOutput>>> {
533 if let (Some(prompt_chunksize), true) = (prompt_chunksize, paged_attn_metadata.is_none()) {
534 let chunk_size = prompt_chunksize.get();
535 let offset = input_seqs[0].token_offset();
536 let num_chunks = toks
538 .iter()
539 .map(|ctxt| ctxt.len().div_ceil(chunk_size))
540 .max()
541 .unwrap_or(0);
542
543 let mut outputs = Vec::with_capacity(num_chunks);
544 for chunk_idx in 0..num_chunks {
545 let mut slices = Vec::new();
546 let mut seq_ids = Vec::new();
547 let mut seq_indices = Vec::new();
548 for (seq_n, ctxt) in toks.iter().enumerate() {
549 let start = chunk_idx * chunk_size;
550 if start < ctxt.len() {
551 let end = (start + chunk_size).min(ctxt.len());
552 slices.push(&ctxt[start..end]);
553 seq_indices.push(seq_n);
554 seq_ids.push(*input_seqs[seq_n].id());
555 }
556 }
557 let result = make_prompt_chunk(
558 chunk_idx * chunk_size + offset,
559 slices,
560 &seq_ids,
561 device,
562 last_n_context_len,
563 return_raw_logits,
564 None,
565 mapper,
566 )
567 .map(|inputs| InnerInputProcessorOutput {
568 inputs,
569 seq_indices,
570 });
571 outputs.push(result);
572 }
573 Box::new(outputs.into_iter())
574 } else {
575 let offset = input_seqs[0].token_offset();
576 Box::new(std::iter::once(
577 make_prompt_chunk(
578 offset,
579 toks,
580 &input_seqs.iter().map(|s| *s.id()).collect::<Vec<_>>(),
581 device,
582 last_n_context_len,
583 return_raw_logits,
584 paged_attn_metadata,
585 mapper,
586 )
587 .map(|inputs| InnerInputProcessorOutput {
588 inputs,
589 seq_indices: (0..input_seqs.len()).collect(),
590 }),
591 ))
592 }
593 }
594
595 #[allow(clippy::too_many_arguments)]
596 pub(crate) fn get_completion_input<T: WithDType + std::fmt::Debug>(
597 toks: Vec<&[T]>,
598 input_seqs: &[&mut Sequence],
599 device: &Device,
600 no_kv_cache: bool,
601 last_n_context_len: Option<(usize, usize)>,
602 return_raw_logits: bool,
603 paged_attn_metadata: Option<&mut PagedAttentionMeta>,
604 prompt_chunksize: Option<NonZeroUsize>,
605 mapper: Option<&dyn DeviceMapper>,
606 ) -> Box<dyn Iterator<Item = Result<InnerInputProcessorOutput>>> {
607 if no_kv_cache {
608 return get_prompt_input(
609 toks,
610 input_seqs,
611 device,
612 last_n_context_len,
613 return_raw_logits,
614 paged_attn_metadata,
615 prompt_chunksize,
616 mapper,
617 );
618 }
619
620 Box::new(std::iter::once(
621 make_completion_chunk(toks, input_seqs, device, paged_attn_metadata, mapper).map(
622 |inputs| InnerInputProcessorOutput {
623 inputs,
624 seq_indices: (0..input_seqs.len()).collect(),
625 },
626 ),
627 ))
628 }
629
    /// Concrete input type text models downcast from `Box<dyn Any>`; mirrors
    /// [`InputMetadata`], with `_full` variants carrying the full-prompt pass
    /// used by X-LoRA.
    #[derive(Clone)]
    pub struct ModelInputs {
        /// Batched token tensor for the main pass.
        pub input_ids: Tensor,
        /// Full-prompt token tensor (X-LoRA only).
        pub input_ids_full: Option<Tensor>,
        /// Per-sequence starting positions for the main pass.
        pub seqlen_offsets: Vec<usize>,
        /// Per-sequence starting positions for the full pass (X-LoRA only).
        pub seqlen_offsets_full: Option<Vec<usize>>,
        /// Per-sequence `(start, len)` logit-selection ranges.
        pub context_lens: Vec<(usize, usize)>,
        /// Per-sequence absolute position ids.
        pub position_ids: Vec<usize>,
        /// PagedAttention metadata, when enabled.
        pub paged_attn_meta: Option<PagedAttentionInputMetadata>,
        /// FlashAttention metadata for the main pass.
        pub flash_meta: FlashParams,
        /// FlashAttention metadata for the full pass (X-LoRA only).
        pub flash_meta_full: Option<FlashParams>,
    }
642
    /// [`InputsProcessor`] implementation for text-only models.
    pub struct TextInputsProcessor;
644
    impl InputsProcessor for TextInputsProcessor {
        /// Dispatches on `is_xlora`/`is_prompt`:
        /// 1. X-LoRA decode: zips full-prompt inputs with last-token inputs.
        /// 2. X-LoRA prefill: prompt inputs, duplicated into the `_full` slots.
        /// 3. Plain prefill: prompt inputs only.
        /// 4. Plain decode: last-token completion inputs only.
        ///
        /// The tokenizer and `other_config` parameters are unused for text.
        fn process_inputs(
            &self,
            _: Option<Arc<Tokenizer>>,
            input_seqs: &mut [&mut Sequence],
            is_prompt: bool,
            is_xlora: bool,
            device: &Device,
            no_kv_cache: bool,
            last_n_context_len: Option<(usize, usize)>,
            return_raw_logits: bool,
            _: Option<Arc<dyn Any>>,
            mut paged_attn_metadata: Option<PagedAttentionMeta>,
            prompt_chunksize: Option<NonZeroUsize>,
            mapper: Option<&dyn DeviceMapper>,
        ) -> Box<dyn Iterator<Item = Result<InputProcessorOutput>>> {
            if is_xlora && !is_prompt {
                // Case 1: X-LoRA decode — pair each full-prompt chunk with the
                // corresponding completion chunk.
                Box::new(
                    get_prompt_input(
                        input_seqs
                            .iter()
                            .map(|seq| seq.get_toks())
                            .collect::<Vec<_>>(),
                        input_seqs,
                        device,
                        last_n_context_len,
                        return_raw_logits,
                        paged_attn_metadata.as_mut(),
                        prompt_chunksize,
                        mapper,
                    )
                    .zip(get_completion_input(
                        input_seqs
                            .iter()
                            .map(|seq| seq.get_toks())
                            .collect::<Vec<_>>(),
                        input_seqs,
                        device,
                        no_kv_cache,
                        last_n_context_len,
                        return_raw_logits,
                        paged_attn_metadata.as_mut(),
                        prompt_chunksize,
                        mapper,
                    ))
                    .map(|(prompt, completion)| {
                        // The prompt side supplies the `_full` fields and the
                        // seq indices; the completion side supplies the rest.
                        let InnerInputProcessorOutput {
                            inputs:
                                InputMetadata {
                                    input: input_ids_full,
                                    positions: seqlen_offsets_full,
                                    context_lens: _,
                                    position_ids,
                                    paged_attn_meta: _,
                                    flash_meta: flash_meta_full,
                                },
                            seq_indices,
                        } = prompt?;
                        let InnerInputProcessorOutput {
                            inputs:
                                InputMetadata {
                                    input: input_ids,
                                    positions: seqlen_offsets,
                                    context_lens,
                                    position_ids: _,
                                    paged_attn_meta,
                                    flash_meta,
                                },
                            seq_indices: _,
                        } = completion?;
                        let inputs: Box<dyn Any> = Box::new(ModelInputs {
                            input_ids,
                            input_ids_full: Some(input_ids_full),
                            seqlen_offsets,
                            seqlen_offsets_full: Some(seqlen_offsets_full),
                            context_lens,
                            position_ids,
                            paged_attn_meta,
                            flash_meta,
                            flash_meta_full: Some(flash_meta_full),
                        });
                        Ok(InputProcessorOutput {
                            inputs,
                            seq_indices,
                        })
                    }),
                )
            } else if is_xlora && is_prompt {
                // Case 2: X-LoRA prefill — the same prompt inputs serve both
                // the scaling pass and the full pass (cloned into `_full`).
                Box::new(
                    get_prompt_input(
                        input_seqs
                            .iter()
                            .map(|seq| seq.get_toks())
                            .collect::<Vec<_>>(),
                        input_seqs,
                        device,
                        last_n_context_len,
                        return_raw_logits,
                        paged_attn_metadata.as_mut(),
                        prompt_chunksize,
                        mapper,
                    )
                    .map(|metadata| {
                        let InnerInputProcessorOutput {
                            inputs:
                                InputMetadata {
                                    input: input_ids,
                                    positions: seqlen_offsets,
                                    context_lens,
                                    position_ids,
                                    paged_attn_meta,
                                    flash_meta,
                                },
                            seq_indices,
                        } = metadata?;
                        let inputs: Box<dyn Any> = Box::new(ModelInputs {
                            input_ids: input_ids.clone(),
                            input_ids_full: Some(input_ids),
                            seqlen_offsets: seqlen_offsets.clone(),
                            seqlen_offsets_full: Some(seqlen_offsets),
                            context_lens,
                            position_ids,
                            paged_attn_meta,
                            flash_meta: flash_meta.clone(),
                            flash_meta_full: Some(flash_meta),
                        });
                        Ok(InputProcessorOutput {
                            inputs,
                            seq_indices,
                        })
                    }),
                )
            } else if is_prompt {
                // Case 3: plain prefill — no `_full` fields.
                Box::new(
                    get_prompt_input(
                        input_seqs
                            .iter()
                            .map(|seq| seq.get_toks())
                            .collect::<Vec<_>>(),
                        input_seqs,
                        device,
                        last_n_context_len,
                        return_raw_logits,
                        paged_attn_metadata.as_mut(),
                        prompt_chunksize,
                        mapper,
                    )
                    .map(|metadata| {
                        let InnerInputProcessorOutput {
                            inputs:
                                InputMetadata {
                                    input: input_ids,
                                    positions: seqlen_offsets,
                                    context_lens,
                                    position_ids,
                                    paged_attn_meta,
                                    flash_meta,
                                },
                            seq_indices,
                        } = metadata?;
                        let inputs: Box<dyn Any> = Box::new(ModelInputs {
                            input_ids,
                            input_ids_full: None,
                            seqlen_offsets,
                            seqlen_offsets_full: None,
                            context_lens,
                            position_ids,
                            paged_attn_meta,
                            flash_meta,
                            flash_meta_full: None,
                        });
                        Ok(InputProcessorOutput {
                            inputs,
                            seq_indices,
                        })
                    }),
                )
            } else {
                // Case 4: plain decode — last-token inputs only.
                Box::new(
                    get_completion_input(
                        input_seqs
                            .iter()
                            .map(|seq| seq.get_toks())
                            .collect::<Vec<_>>(),
                        input_seqs,
                        device,
                        no_kv_cache,
                        last_n_context_len,
                        return_raw_logits,
                        paged_attn_metadata.as_mut(),
                        prompt_chunksize,
                        mapper,
                    )
                    .map(|metadata| {
                        let InnerInputProcessorOutput {
                            inputs:
                                InputMetadata {
                                    input: input_ids,
                                    positions: seqlen_offsets,
                                    context_lens,
                                    position_ids,
                                    paged_attn_meta,
                                    flash_meta,
                                },
                            seq_indices,
                        } = metadata?;
                        let inputs: Box<dyn Any> = Box::new(ModelInputs {
                            input_ids,
                            input_ids_full: None,
                            seqlen_offsets,
                            seqlen_offsets_full: None,
                            context_lens,
                            position_ids,
                            paged_attn_meta,
                            flash_meta,
                            flash_meta_full: None,
                        });
                        Ok(InputProcessorOutput {
                            inputs,
                            seq_indices,
                        })
                    }),
                )
            }
        }

        fn get_type(&self) -> InputsProcessorType {
            InputsProcessorType::Text
        }
    }
875}