1#![allow(clippy::cast_possible_truncation)]
2
3use std::{any::Any, num::NonZeroUsize, sync::Arc};
4
5use anyhow::Result;
6use candle_core::Device;
7use text_models_inputs_processor::PagedAttentionMeta;
8use tokenizers::Tokenizer;
9
10use crate::{device_map::DeviceMapper, sequence::Sequence};
11
/// Default prompt chunk size (1024).
///
/// NOTE(review): presumably the token count used for `prompt_chunksize` when none is
/// configured explicitly — confirm at the use sites.
pub const DEFAULT_PROMPT_CHUNK_SIZE: usize = 1 << 10;
13
/// Which family of inputs an [`InputsProcessor`] builds.
///
/// A small public fieldless enum: it derives `Debug` for diagnostics, and `Copy`/`Eq` so that
/// callers can compare and pass it by value freely.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum InputsProcessorType {
    /// Inputs for text-only models.
    Text,
    /// Inputs for vision models.
    Vision,
}
19
/// One unit of work produced by [`InputsProcessor::process_inputs`].
pub struct InputProcessorOutput {
    /// Type-erased model inputs; the consumer downcasts to the concrete type it expects.
    pub inputs: Box<dyn Any>,
    /// Indices of the sequences these inputs were built for.
    pub seq_indices: Vec<usize>,
}
24
25pub trait InputsProcessor {
27 #[allow(clippy::too_many_arguments)]
32 fn process_inputs(
33 &self,
34 tokenizer: Option<Arc<Tokenizer>>,
35 input_seqs: &mut [&mut Sequence],
36 is_prompt: bool,
37 is_xlora: bool,
38 device: &Device,
39 no_kv_cache: bool,
40 last_n_context_len: Option<(usize, usize)>,
41 return_raw_logits: bool,
42 other_config: Option<Arc<dyn Any>>,
43 paged_attn_metadata: Option<PagedAttentionMeta<'_>>,
44 prompt_chunksize: Option<NonZeroUsize>,
45 mapper: Option<&dyn DeviceMapper>,
46 ) -> Box<dyn Iterator<Item = Result<InputProcessorOutput>>>;
47
48 fn get_type(&self) -> InputsProcessorType;
49}
50
51pub mod text_models_inputs_processor {
54 use std::{any::Any, collections::HashMap, fmt::Debug, num::NonZeroUsize, sync::Arc};
55
56 use anyhow::Result;
57 use candle_core::{DType, Device, DeviceLocation, Tensor, WithDType};
58 use tokenizers::Tokenizer;
59
60 use crate::{
61 device_map::DeviceMapper,
62 paged_attention::{BlockEngine, _PAD_SLOT_ID},
63 sequence::Sequence,
64 };
65
66 use super::{InputProcessorOutput, InputsProcessor, InputsProcessorType};
67
68 fn _make_tensor_with_pad<D: WithDType>(
69 x: Vec<Vec<D>>,
70 max_len: usize,
71 pad: D,
72 device: &Device,
73 ) -> Result<Tensor> {
74 let mut padded_x = Vec::new();
75 for mut x_i in x {
76 assert!(x_i.len() <= max_len);
77 x_i.extend([pad].repeat(max_len - x_i.len()));
78 let shape = (x_i.len(),);
79 padded_x.push(Tensor::from_vec(x_i, shape, device)?);
80 }
81 Tensor::cat(&padded_x[..], 0).map_err(anyhow::Error::msg)
82 }
83
84 pub struct PagedAttentionMeta<'a> {
85 pub sliding_window: Option<usize>,
86 pub block_size: usize,
87 pub block_engine: &'a mut BlockEngine,
88 }
89
90 #[derive(Clone, Debug)]
91 #[allow(dead_code)]
92 pub struct PagedAttentionInputMetadata {
93 pub block_tables: Option<HashMap<DeviceLocation, Tensor>>,
94 pub context_lens: Option<HashMap<DeviceLocation, Tensor>>,
95 pub slot_mappings: HashMap<DeviceLocation, Tensor>,
96 pub max_context_len: Option<usize>,
97 pub is_first_prompt_chunk: bool,
98 }
99
100 impl PagedAttentionInputMetadata {
101 pub fn dummy(dev: &Device) -> candle_core::Result<Self> {
104 Ok(PagedAttentionInputMetadata {
105 block_tables: None,
106 context_lens: None,
107 max_context_len: None,
108 slot_mappings: HashMap::from([(dev.location(), Tensor::new(&[0f32], dev)?)]),
109 is_first_prompt_chunk: true,
110 })
111 }
112 }
113
114 #[derive(Clone, Debug)]
115 pub struct FlashParams {
116 pub max_q: u32,
117 pub max_k: u32,
118 pub cumulative_seqlens_q: HashMap<DeviceLocation, Tensor>,
119 pub cumulative_seqlens_k: HashMap<DeviceLocation, Tensor>,
120 }
121
122 pub struct InputMetadata {
123 pub input: Tensor,
124 pub positions: Vec<usize>,
125 pub context_lens: Vec<(usize, usize)>, pub position_ids: Vec<usize>,
127 pub paged_attn_meta: Option<PagedAttentionInputMetadata>, pub flash_meta: FlashParams,
129 }
130
131 pub struct InnerInputProcessorOutput {
132 pub inputs: InputMetadata,
133 pub seq_indices: Vec<usize>,
134 }
135
136 #[allow(clippy::too_many_arguments)]
139 pub fn make_prompt_chunk<T: WithDType + Debug>(
140 chunk_offset_toks: usize,
141 toks: Vec<Vec<T>>,
142 seq_ids: &[usize],
143 device: &Device,
144 last_n_context_len: Option<(usize, usize)>,
145 return_raw_logits: bool,
146 mut paged_attn_metadata: Option<&mut PagedAttentionMeta<'_>>,
147 mapper: Option<&dyn DeviceMapper>,
148 ) -> Result<InputMetadata> {
149 let max_len = toks
150 .iter()
151 .map(|seq| seq.len())
152 .max()
153 .expect("No sequences");
154 let padding_tok = T::zero();
155 let mut seqs_tensors = Vec::new();
157 let mut seqlen_offsets = Vec::new();
158 let mut context_lens = Vec::new();
159 let mut position_ids = Vec::new();
160 let mut slot_mappings = Vec::new();
161 let mut block_tables = Vec::new();
162 let mut paged_attn_context_lens = Vec::new();
163 let mut seqlens_q = vec![0];
164 let mut seqlens_k = vec![0];
165 for (seq_id, mut ctxt) in seq_ids.iter().zip(toks) {
166 let prompt_len = ctxt.len();
167 let offset = last_n_context_len.unwrap_or_default();
168 seqlen_offsets.push(offset.1 + chunk_offset_toks);
169
170 position_ids.push(ctxt.len() + chunk_offset_toks);
171 ctxt.extend(std::iter::repeat_n(
172 padding_tok,
173 max_len.saturating_sub(ctxt.len()),
174 ));
175 if return_raw_logits {
177 if last_n_context_len.is_some() {
178 anyhow::bail!("`return_raw_logits` is incompatible with `last_n_context_len`");
179 }
180
181 context_lens.push((0, ctxt.len()));
182 } else {
183 context_lens.push((
184 ctxt.len() - last_n_context_len.map(|(a, _)| a).unwrap_or(1),
185 last_n_context_len.map(|(a, _)| a).unwrap_or(1),
186 ));
187 }
188
189 seqlens_q.push(ctxt.len() as u32);
190 seqlens_k.push((ctxt.len() + chunk_offset_toks) as u32);
191
192 seqs_tensors.push(Tensor::new(ctxt, device).unwrap().unsqueeze(0).unwrap());
193
194 if let Some(paged_attn_metadata) = &mut paged_attn_metadata {
195 let table = paged_attn_metadata.block_engine.block_tables.get(seq_id);
196
197 if table.is_none() {
198 slot_mappings.push([_PAD_SLOT_ID].repeat(prompt_len));
200 continue;
201 }
202 let table = table
203 .unwrap()
204 .iter()
205 .map(|block| block.deref_mut().block_id)
206 .collect::<Vec<_>>();
207
208 let start_idx = if let Some(sliding_window) = paged_attn_metadata.sliding_window {
209 if prompt_len > sliding_window {
210 chunk_offset_toks.min(prompt_len - sliding_window)
211 } else {
212 chunk_offset_toks
213 }
214 } else {
215 chunk_offset_toks
216 };
217
218 let mut slot_mapping = Vec::new();
219 let mut ctxt_len = Vec::new();
220 for i in chunk_offset_toks..prompt_len + chunk_offset_toks {
221 if i < start_idx {
222 slot_mapping.push(_PAD_SLOT_ID);
224 }
225 ctxt_len.push(i);
226
227 let block_number = if i / paged_attn_metadata.block_size >= table.len() {
228 panic!(
229 "Block table is too small (prompt)! i={} block_size={} table_len={}",
230 i,
231 paged_attn_metadata.block_size,
232 table.len()
233 );
234 } else {
235 table.get(i / paged_attn_metadata.block_size).unwrap()
236 };
237 let block_offset = i % paged_attn_metadata.block_size;
238 let slot = block_number * paged_attn_metadata.block_size + block_offset;
239 slot_mapping.push(slot.try_into().unwrap());
240 block_tables.push(table.clone());
241 }
242 slot_mappings.push(slot_mapping);
243 paged_attn_context_lens.push(ctxt_len);
244 }
245 }
246
247 let max_q = *seqlens_q.iter().max().unwrap();
248 let max_k = *seqlens_k.iter().max().unwrap();
249 let seqlens_q = Tensor::new(seqlens_q, device)?
250 .to_dtype(DType::F32)?
251 .cumsum(0)?
252 .to_dtype(DType::U32)?;
253 let seqlens_k = Tensor::new(seqlens_k, device)?
254 .to_dtype(DType::F32)?
255 .cumsum(0)?
256 .to_dtype(DType::U32)?;
257
258 let mut seqlens_q_map = HashMap::new();
259 let mut seqlens_k_map = HashMap::new();
260
261 let devices = mapper.unwrap().get_unique_devices();
262 for device in devices {
263 seqlens_q_map.insert(device.location(), seqlens_q.to_device(&device)?);
264 seqlens_k_map.insert(device.location(), seqlens_k.to_device(&device)?);
265 }
266
267 let input = Tensor::cat(&seqs_tensors, 0).unwrap();
268
269 let paged_attn_meta = if paged_attn_metadata.is_some() {
270 let max_slot_mapping_len = slot_mappings.iter().map(|x| x.len()).max().unwrap();
271 let slot_mappings =
272 _make_tensor_with_pad(slot_mappings, max_slot_mapping_len, _PAD_SLOT_ID, device)?;
273
274 let max_block_table_len = block_tables.iter().map(|x| x.len()).max().unwrap();
275 let block_tables = _make_tensor_with_pad(
276 block_tables
277 .iter()
278 .map(|x| x.iter().map(|x| *x as u32).collect::<Vec<_>>())
279 .collect::<Vec<_>>(),
280 max_block_table_len,
281 0,
282 device,
283 )?;
284 let block_tables = block_tables.reshape(((), max_block_table_len))?;
285
286 let max_context_len = paged_attn_context_lens
287 .iter()
288 .map(|x| x.len())
289 .max()
290 .unwrap();
291
292 let context_lens = _make_tensor_with_pad(
293 paged_attn_context_lens
294 .iter()
295 .map(|x| x.iter().map(|x| *x as u32).collect::<Vec<_>>())
296 .collect::<Vec<_>>(),
297 max_context_len,
298 0,
299 device,
300 )?
301 .reshape(((),))?;
302
303 let devices = mapper.unwrap().get_unique_devices();
305 let mut slot_mappings_map = HashMap::new();
306 let mut block_tables_map = HashMap::new();
307 let mut context_lens_map = HashMap::new();
308
309 for device in devices {
310 slot_mappings_map
311 .insert(device.location(), slot_mappings.clone().to_device(&device)?);
312 block_tables_map
313 .insert(device.location(), block_tables.clone().to_device(&device)?);
314 context_lens_map
315 .insert(device.location(), context_lens.clone().to_device(&device)?);
316 }
317
318 Some(PagedAttentionInputMetadata {
319 slot_mappings: slot_mappings_map,
320 block_tables: Some(block_tables_map),
321 context_lens: Some(context_lens_map),
322 max_context_len: Some(max_context_len),
323 is_first_prompt_chunk: chunk_offset_toks == 0,
324 })
325 } else {
326 None
327 };
328
329 Ok(InputMetadata {
330 input,
331 positions: seqlen_offsets,
332 context_lens,
333 position_ids,
334 paged_attn_meta,
335 flash_meta: FlashParams {
336 max_k,
337 max_q,
338 cumulative_seqlens_k: seqlens_k_map,
339 cumulative_seqlens_q: seqlens_q_map,
340 },
341 })
342 }
343
344 fn make_completion_chunk<T: WithDType>(
345 toks: Vec<Vec<T>>,
346 input_seqs: &[&mut Sequence],
347 device: &Device,
348 mut paged_attn_metadata: Option<&mut PagedAttentionMeta<'_>>,
349 mapper: Option<&dyn DeviceMapper>,
350 ) -> Result<InputMetadata> {
351 let mut seqs_tensors = Vec::new();
353 let mut seqlen_offsets = Vec::new();
354 let mut context_lens = Vec::new();
355 let mut position_ids = Vec::new();
356
357 let mut slot_mappings = Vec::new();
358 let mut block_tables = Vec::new();
359 let mut paged_attn_context_lens = Vec::new();
360 let mut seqlens_q = vec![0];
361 let mut seqlens_k = vec![0];
362 for (seq, ctxt) in input_seqs.iter().zip(toks) {
363 let start_pos = ctxt.len().saturating_sub(1);
364 let ctxt = ctxt[start_pos..].to_vec();
365 seqlen_offsets.push(start_pos);
366 context_lens.push((0, 1));
367 position_ids.push(seq.len());
368
369 seqlens_q.push(ctxt.len() as u32);
370 seqlens_k.push((ctxt.len() + start_pos) as u32);
371
372 seqs_tensors.push(Tensor::new(ctxt, device).unwrap().unsqueeze(0).unwrap());
373
374 if let Some(paged_attn_metadata) = &mut paged_attn_metadata {
375 let table = paged_attn_metadata
376 .block_engine
377 .block_tables
378 .get(seq.id())
379 .unwrap();
380
381 let table = table
382 .iter()
383 .map(|block| block.deref_mut().block_id)
384 .collect::<Vec<_>>();
385
386 let block_number = if start_pos / paged_attn_metadata.block_size >= table.len() {
387 panic!("Block table is too small (completion)! start_pos={} block_size={} table_len={}", start_pos, paged_attn_metadata.block_size, table.len());
388 } else {
389 table
390 .get(start_pos / paged_attn_metadata.block_size)
391 .unwrap()
392 };
393 let block_offset = start_pos % paged_attn_metadata.block_size;
394 let slot = block_number * paged_attn_metadata.block_size + block_offset;
395 let slot = slot.try_into().unwrap();
396 slot_mappings.push(vec![slot]);
397
398 if let Some(sliding_window) = paged_attn_metadata.sliding_window {
399 let sliding_window_blocks = sliding_window / paged_attn_metadata.block_size;
400 let slide_idx = if table.len() > sliding_window_blocks {
401 table.len() - sliding_window_blocks
402 } else {
403 0
404 };
405 block_tables.push(table.get(slide_idx..).unwrap().to_vec());
406 } else {
407 block_tables.push(table);
408 }
409
410 let paged_attn_context_len =
411 if let Some(sliding_window) = paged_attn_metadata.sliding_window {
412 seq.len().min(sliding_window)
413 } else {
414 seq.len()
415 };
416 paged_attn_context_lens.push(paged_attn_context_len);
417 }
418 }
419
420 let max_q = *seqlens_q.iter().max().unwrap();
421 let max_k = *seqlens_k.iter().max().unwrap();
422 let seqlens_q = Tensor::new(seqlens_q, device)?
423 .to_dtype(DType::F32)?
424 .cumsum(0)?
425 .to_dtype(DType::U32)?;
426 let seqlens_k = Tensor::new(seqlens_k, device)?
427 .to_dtype(DType::F32)?
428 .cumsum(0)?
429 .to_dtype(DType::U32)?;
430
431 let mut seqlens_q_map = HashMap::new();
432 let mut seqlens_k_map = HashMap::new();
433
434 let devices = mapper.unwrap().get_unique_devices();
435 for device in devices {
436 seqlens_q_map.insert(device.location(), seqlens_q.to_device(&device)?);
437 seqlens_k_map.insert(device.location(), seqlens_k.to_device(&device)?);
438 }
439
440 let paged_attn_meta = if paged_attn_metadata.is_some() {
441 let slot_mappings = _make_tensor_with_pad(slot_mappings, 1, _PAD_SLOT_ID, device)?;
442
443 let max_block_table_len = block_tables.iter().map(|x| x.len()).max().unwrap();
444
445 let block_tables = _make_tensor_with_pad(
446 block_tables
447 .iter()
448 .map(|x| x.iter().map(|x| *x as u32).collect::<Vec<_>>())
449 .collect::<Vec<_>>(),
450 max_block_table_len,
451 0,
452 device,
453 )?;
454 let block_tables = block_tables.reshape(((), max_block_table_len))?;
455
456 let max_context_len = paged_attn_context_lens.iter().max().unwrap();
457
458 let context_lens = Tensor::from_vec(
459 paged_attn_context_lens
460 .iter()
461 .map(|x| *x as u32)
462 .collect::<Vec<_>>(),
463 (paged_attn_context_lens.len(),),
464 device,
465 )?;
466
467 let devices = mapper.unwrap().get_unique_devices();
469 let mut slot_mappings_map = HashMap::new();
470 let mut block_tables_map = HashMap::new();
471 let mut context_lens_map = HashMap::new();
472
473 for device in devices {
474 slot_mappings_map
475 .insert(device.location(), slot_mappings.clone().to_device(&device)?);
476 block_tables_map
477 .insert(device.location(), block_tables.clone().to_device(&device)?);
478 context_lens_map
479 .insert(device.location(), context_lens.clone().to_device(&device)?);
480 }
481
482 Some(PagedAttentionInputMetadata {
483 slot_mappings: slot_mappings_map,
484 block_tables: Some(block_tables_map),
485 context_lens: Some(context_lens_map),
486 max_context_len: Some(*max_context_len),
487 is_first_prompt_chunk: false,
488 })
489 } else {
490 None
491 };
492
493 Ok(InputMetadata {
494 input: Tensor::cat(&seqs_tensors, 0).unwrap(),
495 positions: seqlen_offsets,
496 context_lens,
497 position_ids,
498 paged_attn_meta,
499 flash_meta: FlashParams {
500 max_k,
501 max_q,
502 cumulative_seqlens_k: seqlens_k_map,
503 cumulative_seqlens_q: seqlens_q_map,
504 },
505 })
506 }
507
508 #[allow(clippy::too_many_arguments)]
509 pub(crate) fn get_prompt_input<T: WithDType + std::fmt::Debug>(
510 toks: Vec<Vec<T>>,
511 input_seqs: &[&mut Sequence],
512 device: &Device,
513 last_n_context_len: Option<(usize, usize)>,
514 return_raw_logits: bool,
515 mut paged_attn_metadata: Option<&mut PagedAttentionMeta<'_>>,
516 prompt_chunksize: Option<NonZeroUsize>,
517 mapper: Option<&dyn DeviceMapper>,
518 ) -> Box<dyn Iterator<Item = Result<InnerInputProcessorOutput>>> {
519 if let (Some(prompt_chunksize), true) = (prompt_chunksize, paged_attn_metadata.is_none()) {
520 let mut seq_chunks = Vec::new();
521 let mut n_chunks = Vec::new();
522 let prompt_chunksize: usize = prompt_chunksize.into();
523
524 let offset = input_seqs[0].token_offset();
527
528 for ctxt in toks.iter() {
530 let chunks = ctxt.chunks(prompt_chunksize).collect::<Vec<_>>();
531 n_chunks.push(chunks.len());
532 seq_chunks.push(chunks);
533 }
534 let mut chunks_transposed: Vec<Vec<(Vec<T>, usize)>> = Vec::new();
536 for (seq_n, seq) in seq_chunks.into_iter().enumerate() {
537 for (i, chunk) in seq.into_iter().enumerate() {
538 match chunks_transposed.get_mut(i) {
539 Some(part) => part.push((chunk.to_vec(), seq_n)),
540 None => chunks_transposed.push(vec![(chunk.to_vec(), seq_n)]),
541 }
542 }
543 }
544 let chunks = chunks_transposed
545 .into_iter()
546 .enumerate()
547 .map(|(i, chunk)| {
548 let (toks, seq_ns): (Vec<Vec<T>>, Vec<usize>) = chunk.into_iter().unzip();
549 make_prompt_chunk(
550 i * prompt_chunksize + offset,
551 toks,
552 &seq_ns
553 .iter()
554 .map(|i| *input_seqs[*i].id())
555 .collect::<Vec<_>>(),
556 device,
557 last_n_context_len,
558 return_raw_logits,
559 paged_attn_metadata.as_deref_mut(),
560 mapper,
561 )
562 .map(|inputs| InnerInputProcessorOutput {
563 inputs,
564 seq_indices: seq_ns,
565 })
566 })
567 .collect::<Vec<_>>();
568 Box::new(chunks.into_iter())
569 } else {
570 let offset = input_seqs[0].token_offset();
571 if offset != 0 && paged_attn_metadata.is_some() {
572 return Box::new(std::iter::once(Err(anyhow::Error::msg(
573 "PagedAttention does not yet support sequences with an offset != 0.",
574 ))));
575 }
576 Box::new(std::iter::once(
577 make_prompt_chunk(
578 offset,
579 toks,
580 &input_seqs.iter().map(|s| *s.id()).collect::<Vec<_>>(),
581 device,
582 last_n_context_len,
583 return_raw_logits,
584 paged_attn_metadata,
585 mapper,
586 )
587 .map(|inputs| InnerInputProcessorOutput {
588 inputs,
589 seq_indices: (0..input_seqs.len()).collect(),
590 }),
591 ))
592 }
593 }
594
595 #[allow(clippy::too_many_arguments)]
596 pub(crate) fn get_completion_input<T: WithDType + std::fmt::Debug>(
597 toks: Vec<Vec<T>>,
598 input_seqs: &[&mut Sequence],
599 device: &Device,
600 no_kv_cache: bool,
601 last_n_context_len: Option<(usize, usize)>,
602 return_raw_logits: bool,
603 paged_attn_metadata: Option<&mut PagedAttentionMeta<'_>>,
604 prompt_chunksize: Option<NonZeroUsize>,
605 mapper: Option<&dyn DeviceMapper>,
606 ) -> Box<dyn Iterator<Item = Result<InnerInputProcessorOutput>>> {
607 if no_kv_cache {
608 return get_prompt_input(
609 toks,
610 input_seqs,
611 device,
612 last_n_context_len,
613 return_raw_logits,
614 paged_attn_metadata,
615 prompt_chunksize,
616 mapper,
617 );
618 }
619
620 Box::new(std::iter::once(
621 make_completion_chunk(toks, input_seqs, device, paged_attn_metadata, mapper).map(
622 |inputs| InnerInputProcessorOutput {
623 inputs,
624 seq_indices: (0..input_seqs.len()).collect(),
625 },
626 ),
627 ))
628 }
629
630 #[derive(Clone)]
631 pub struct ModelInputs {
632 pub input_ids: Tensor,
633 pub input_ids_full: Option<Tensor>,
634 pub seqlen_offsets: Vec<usize>,
635 pub seqlen_offsets_full: Option<Vec<usize>>,
636 pub context_lens: Vec<(usize, usize)>,
637 pub position_ids: Vec<usize>,
638 pub paged_attn_meta: Option<PagedAttentionInputMetadata>,
639 pub flash_meta: FlashParams,
640 pub flash_meta_full: Option<FlashParams>,
641 }
642
643 pub struct TextInputsProcessor;
644
645 impl InputsProcessor for TextInputsProcessor {
646 fn process_inputs(
647 &self,
648 _: Option<Arc<Tokenizer>>,
649 input_seqs: &mut [&mut Sequence],
650 is_prompt: bool,
651 is_xlora: bool,
652 device: &Device,
653 no_kv_cache: bool,
654 last_n_context_len: Option<(usize, usize)>,
655 return_raw_logits: bool,
656 _: Option<Arc<dyn Any>>,
657 mut paged_attn_metadata: Option<PagedAttentionMeta<'_>>,
658 prompt_chunksize: Option<NonZeroUsize>,
659 mapper: Option<&dyn DeviceMapper>,
660 ) -> Box<dyn Iterator<Item = Result<InputProcessorOutput>>> {
661 if is_xlora && !is_prompt {
662 Box::new(
663 get_prompt_input(
664 input_seqs
665 .iter()
666 .map(|seq| seq.get_toks().to_vec())
667 .collect::<Vec<_>>(),
668 input_seqs,
669 device,
670 last_n_context_len,
671 return_raw_logits,
672 paged_attn_metadata.as_mut(),
673 prompt_chunksize,
674 mapper,
675 )
676 .zip(get_completion_input(
677 input_seqs
678 .iter()
679 .map(|seq| seq.get_toks().to_vec())
680 .collect::<Vec<_>>(),
681 input_seqs,
682 device,
683 no_kv_cache,
684 last_n_context_len,
685 return_raw_logits,
686 paged_attn_metadata.as_mut(),
687 prompt_chunksize,
688 mapper,
689 ))
690 .map(|(prompt, completion)| {
691 let InnerInputProcessorOutput {
692 inputs:
693 InputMetadata {
694 input: input_ids_full,
695 positions: seqlen_offsets_full,
696 context_lens: _,
697 position_ids,
698 paged_attn_meta: _,
699 flash_meta: flash_meta_full,
700 },
701 seq_indices,
702 } = prompt?;
703 let InnerInputProcessorOutput {
704 inputs:
705 InputMetadata {
706 input: input_ids,
707 positions: seqlen_offsets,
708 context_lens,
709 position_ids: _,
710 paged_attn_meta,
711 flash_meta,
712 },
713 seq_indices: _,
714 } = completion?;
715 let inputs: Box<dyn Any> = Box::new(ModelInputs {
716 input_ids,
717 input_ids_full: Some(input_ids_full),
718 seqlen_offsets,
719 seqlen_offsets_full: Some(seqlen_offsets_full),
720 context_lens,
721 position_ids,
722 paged_attn_meta,
723 flash_meta,
724 flash_meta_full: Some(flash_meta_full),
725 });
726 Ok(InputProcessorOutput {
727 inputs,
728 seq_indices,
729 })
730 }),
731 )
732 } else if is_xlora && is_prompt {
733 Box::new(
734 get_prompt_input(
735 input_seqs
736 .iter()
737 .map(|seq| seq.get_toks().to_vec())
738 .collect::<Vec<_>>(),
739 input_seqs,
740 device,
741 last_n_context_len,
742 return_raw_logits,
743 paged_attn_metadata.as_mut(),
744 prompt_chunksize,
745 mapper,
746 )
747 .map(|metadata| {
748 let InnerInputProcessorOutput {
749 inputs:
750 InputMetadata {
751 input: input_ids,
752 positions: seqlen_offsets,
753 context_lens,
754 position_ids,
755 paged_attn_meta,
756 flash_meta,
757 },
758 seq_indices,
759 } = metadata?;
760 let inputs: Box<dyn Any> = Box::new(ModelInputs {
761 input_ids: input_ids.clone(),
762 input_ids_full: Some(input_ids),
763 seqlen_offsets: seqlen_offsets.clone(),
764 seqlen_offsets_full: Some(seqlen_offsets),
765 context_lens,
766 position_ids,
767 paged_attn_meta,
768 flash_meta: flash_meta.clone(),
769 flash_meta_full: Some(flash_meta),
770 });
771 Ok(InputProcessorOutput {
772 inputs,
773 seq_indices,
774 })
775 }),
776 )
777 } else if is_prompt {
778 Box::new(
779 get_prompt_input(
780 input_seqs
781 .iter()
782 .map(|seq| seq.get_toks().to_vec())
783 .collect::<Vec<_>>(),
784 input_seqs,
785 device,
786 last_n_context_len,
787 return_raw_logits,
788 paged_attn_metadata.as_mut(),
789 prompt_chunksize,
790 mapper,
791 )
792 .map(|metadata| {
793 let InnerInputProcessorOutput {
794 inputs:
795 InputMetadata {
796 input: input_ids,
797 positions: seqlen_offsets,
798 context_lens,
799 position_ids,
800 paged_attn_meta,
801 flash_meta,
802 },
803 seq_indices,
804 } = metadata?;
805 let inputs: Box<dyn Any> = Box::new(ModelInputs {
806 input_ids,
807 input_ids_full: None,
808 seqlen_offsets,
809 seqlen_offsets_full: None,
810 context_lens,
811 position_ids,
812 paged_attn_meta,
813 flash_meta,
814 flash_meta_full: None,
815 });
816 Ok(InputProcessorOutput {
817 inputs,
818 seq_indices,
819 })
820 }),
821 )
822 } else {
823 Box::new(
824 get_completion_input(
825 input_seqs
826 .iter()
827 .map(|seq| seq.get_toks().to_vec())
828 .collect::<Vec<_>>(),
829 input_seqs,
830 device,
831 no_kv_cache,
832 last_n_context_len,
833 return_raw_logits,
834 paged_attn_metadata.as_mut(),
835 prompt_chunksize,
836 mapper,
837 )
838 .map(|metadata| {
839 let InnerInputProcessorOutput {
840 inputs:
841 InputMetadata {
842 input: input_ids,
843 positions: seqlen_offsets,
844 context_lens,
845 position_ids,
846 paged_attn_meta,
847 flash_meta,
848 },
849 seq_indices,
850 } = metadata?;
851 let inputs: Box<dyn Any> = Box::new(ModelInputs {
852 input_ids,
853 input_ids_full: None,
854 seqlen_offsets,
855 seqlen_offsets_full: None,
856 context_lens,
857 position_ids,
858 paged_attn_meta,
859 flash_meta,
860 flash_meta_full: None,
861 });
862 Ok(InputProcessorOutput {
863 inputs,
864 seq_indices,
865 })
866 }),
867 )
868 }
869 }
870
871 fn get_type(&self) -> InputsProcessorType {
872 InputsProcessorType::Text
873 }
874 }
875}