#![allow(clippy::cast_possible_truncation)]

use std::{any::Any, sync::Arc};

use anyhow::Result;
use candle_core::Device;
use text_models_inputs_processor::PagedAttentionMeta;
use tokenizers::Tokenizer;

use crate::{device_map::DeviceMapper, sequence::Sequence};

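/// Broad class of inputs a processor produces and consumes: text-only or vision.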
#[derive(PartialEq)]
pub enum InputsProcessorType {
    Text,
    Vision,
}

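/// Type-erased model inputs plus the indices of the sequences they were built from.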
pub struct InputProcessorOutput {
    pub inputs: Box<dyn Any>,
    pub seq_indices: Vec<usize>,
}

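/// Converts a batch of sequences into model-ready inputs.
///
/// `process_inputs` returns a type-erased value which the pipeline downcasts to the
/// concrete input type it expects (for text models, `ModelInputs` below).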
pub trait InputsProcessor {
    #[allow(clippy::too_many_arguments)]
    fn process_inputs(
        &self,
        tokenizer: Option<Arc<Tokenizer>>,
        input_seqs: &mut [&mut Sequence],
        is_prompt: bool,
        is_xlora: bool,
        device: &Device,
        no_kv_cache: bool,
        last_n_context_len: Option<(usize, usize)>,
        return_raw_logits: bool,
        other_config: Option<Arc<dyn Any>>,
        paged_attn_metadata: Option<PagedAttentionMeta>,
        mapper: Option<&dyn DeviceMapper>,
    ) -> Result<InputProcessorOutput>;

    fn get_type(&self) -> InputsProcessorType;
}

pub mod text_models_inputs_processor {
    use std::{any::Any, collections::HashMap, fmt::Debug, sync::Arc};

    use anyhow::Result;
    use candle_core::{DType, Device, DeviceLocation, Tensor, WithDType};
    use tokenizers::Tokenizer;

    use crate::{
        device_map::DeviceMapper,
        get_mut_arcmutex,
        paged_attention::{BlockEngine, _PAD_SLOT_ID},
        sequence::Sequence,
    };

    use super::{InputProcessorOutput, InputsProcessor, InputsProcessorType};

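    /// Right-pads every row with `pad` to `max_len` and concatenates the rows into a
    /// single flat tensor; callers reshape the result, e.g. to `(batch, max_len)`.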
    fn _make_tensor_with_pad<D: WithDType>(
        x: Vec<Vec<D>>,
        max_len: usize,
        pad: D,
        device: &Device,
    ) -> Result<Tensor> {
        let mut padded_x = Vec::new();
        for mut x_i in x {
            assert!(x_i.len() <= max_len);
            x_i.extend([pad].repeat(max_len - x_i.len()));
            let shape = (x_i.len(),);
            padded_x.push(Tensor::from_vec(x_i, shape, device)?);
        }
        Tensor::cat(&padded_x[..], 0).map_err(anyhow::Error::msg)
    }

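    /// Scheduler-side state needed to build PagedAttention inputs for a batch.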
    pub struct PagedAttentionMeta {
        pub sliding_window: Option<usize>,
        pub block_size: usize,
        pub block_engine: Arc<tokio::sync::Mutex<BlockEngine>>,
    }

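    /// Per-device PagedAttention tensors handed to the model. The optional fields stay
    /// `None` in the dummy metadata below, which carries no real KV-cache state.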
    #[derive(Clone, Debug)]
    #[allow(dead_code)]
    pub struct PagedAttentionInputMetadata {
        pub block_tables: Option<HashMap<DeviceLocation, Tensor>>,
        pub context_lens: Option<HashMap<DeviceLocation, Tensor>>,
        pub slot_mappings: HashMap<DeviceLocation, Tensor>,
        pub max_context_len: Option<usize>,
        pub is_first_prompt_chunk: bool,
    }

    impl PagedAttentionInputMetadata {
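        /// Builds placeholder metadata with a single zero slot mapping, for passes that
        /// never read or write the KV cache (e.g. profiling-style dummy runs).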
        pub fn dummy(dev: &Device) -> candle_core::Result<Self> {
            Ok(PagedAttentionInputMetadata {
                block_tables: None,
                context_lens: None,
                max_context_len: None,
                slot_mappings: HashMap::from([(dev.location(), Tensor::new(&[0f32], dev)?)]),
                is_first_prompt_chunk: true,
            })
        }
    }

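    /// Cumulative per-device query/key sequence lengths in the layout varlen
    /// flash-attention kernels expect, plus the per-batch maxima.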
    #[derive(Clone, Debug)]
    pub struct FlashParams {
        pub max_q: u32,
        pub max_k: u32,
        pub cumulative_seqlens_q: HashMap<DeviceLocation, Tensor>,
        pub cumulative_seqlens_k: HashMap<DeviceLocation, Tensor>,
    }

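    /// Fully prepared inputs for one forward pass over a batch of sequences.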
    pub struct InputMetadata {
        pub input: Tensor,
        pub positions: Vec<usize>,
        pub context_lens: Vec<(usize, usize)>,
        pub position_ids: Vec<usize>,
        pub paged_attn_meta: Option<PagedAttentionInputMetadata>,
        pub flash_meta: FlashParams,
    }

    pub struct InnerInputProcessorOutput {
        pub inputs: InputMetadata,
        pub seq_indices: Vec<usize>,
    }

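    /// Pads and collates one chunk of prompt tokens into `InputMetadata`.
    ///
    /// With PagedAttention metadata, each token index `i` maps to KV-cache slot
    /// `table[i / block_size] * block_size + i % block_size`. For example, with
    /// `block_size = 16`, token `i = 35` falls in block `35 / 16 = 2` at offset
    /// `35 % 16 = 3`, so its slot is `table[2] * 16 + 3`.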
    #[allow(clippy::too_many_arguments)]
    pub fn make_prompt_chunk<T: WithDType + Debug>(
        chunk_offset_toks: usize,
        toks: Vec<&[T]>,
        seq_ids: &[usize],
        device: &Device,
        last_n_context_len: Option<(usize, usize)>,
        return_raw_logits: bool,
        mut paged_attn_metadata: Option<&mut PagedAttentionMeta>,
        mapper: Option<&dyn DeviceMapper>,
    ) -> Result<InputMetadata> {
        let max_len = toks
            .iter()
            .map(|seq| seq.len())
            .max()
            .expect("No sequences");
        let padding_tok = T::zero();
        let mut seqs_tensors = Vec::new();
        let mut seqlen_offsets = Vec::new();
        let mut context_lens = Vec::new();
        let mut position_ids = Vec::new();
        let mut slot_mappings = Vec::new();
        let mut block_tables = Vec::new();
        let mut paged_attn_context_lens = Vec::new();
        let flash_attn = crate::using_flash_attn();
        let mut seqlens_q = if flash_attn { vec![0] } else { Vec::new() };
        let mut seqlens_k = if flash_attn { vec![0] } else { Vec::new() };
        for (seq_id, ctxt) in seq_ids.iter().zip(toks) {
            let prompt_len = ctxt.len();
            let offset = last_n_context_len.unwrap_or_default();
            seqlen_offsets.push(offset.1 + chunk_offset_toks);

            position_ids.push(ctxt.len() + chunk_offset_toks);
            let mut ctxt = ctxt.to_vec();
            ctxt.extend(std::iter::repeat_n(
                padding_tok,
                max_len.saturating_sub(ctxt.len()),
            ));
            if return_raw_logits {
                if last_n_context_len.is_some() {
                    anyhow::bail!("`return_raw_logits` is incompatible with `last_n_context_len`");
                }

                context_lens.push((0, ctxt.len()));
            } else {
                context_lens.push((
                    ctxt.len()
                        .saturating_sub(last_n_context_len.map(|(a, _)| a).unwrap_or(1)),
                    last_n_context_len.map(|(a, _)| a).unwrap_or(1),
                ));
            }

            if flash_attn {
                seqlens_q.push(ctxt.len() as u32);
                seqlens_k.push((ctxt.len() + chunk_offset_toks) as u32);
            }

            seqs_tensors.push(Tensor::new(ctxt, device).unwrap().unsqueeze(0).unwrap());

            if let Some(paged_attn_metadata) = &mut paged_attn_metadata {
                let block_engine = get_mut_arcmutex!(paged_attn_metadata.block_engine);
                let table = block_engine.block_tables.get(seq_id);

                if table.is_none() {
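                    // No block table has been allocated for this sequence yet (e.g.
                    // during memory profiling), so pad out the whole slot mapping.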
                    slot_mappings.push([_PAD_SLOT_ID].repeat(prompt_len));
                    continue;
                }
                let table = table
                    .unwrap()
                    .iter()
                    .map(|block| block.deref_mut().block_id)
                    .collect::<Vec<_>>();

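                // With a sliding window, tokens before `start_idx` can never be
                // attended to, so their slots are padded out below.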
                let start_idx = if let Some(sliding_window) = paged_attn_metadata.sliding_window {
                    prompt_len.saturating_sub(sliding_window)
                } else {
                    0
                };

                let mut slot_mapping = Vec::new();
                let mut ctxt_len = Vec::new();
                for i in chunk_offset_toks..prompt_len + chunk_offset_toks {
                    if i < start_idx {
                        // Token falls outside the sliding window: pad its slot and skip it.
                        slot_mapping.push(_PAD_SLOT_ID);
                        continue;
                    }
                    ctxt_len.push(i);

                    let block_number = if i / paged_attn_metadata.block_size >= table.len() {
                        panic!(
                            "Block table is too small (prompt)! i={} block_size={} table_len={}",
                            i,
                            paged_attn_metadata.block_size,
                            table.len()
                        );
                    } else {
                        table.get(i / paged_attn_metadata.block_size).unwrap()
                    };
                    let block_offset = i % paged_attn_metadata.block_size;
                    let slot = block_number * paged_attn_metadata.block_size + block_offset;
                    slot_mapping.push(slot.try_into().unwrap());
                }
                slot_mappings.push(slot_mapping);
                // One block table per sequence (not per token), so rows line up with the batch.
                block_tables.push(table.clone());
                paged_attn_context_lens.push(ctxt_len);
            }
        }

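        // Varlen flash attention consumes cumulative sequence lengths; compute the
        // prefix sums (via f32 for `cumsum`, then back to u32) and replicate them
        // onto every device in the mapping.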
        let (max_q, max_k, seqlens_q_map, seqlens_k_map) = if flash_attn {
            let max_q = *seqlens_q.iter().max().unwrap();
            let max_k = *seqlens_k.iter().max().unwrap();
            let seqlens_q = Tensor::new(seqlens_q, device)?
                .to_dtype(DType::F32)?
                .cumsum(0)?
                .to_dtype(DType::U32)?;
            let seqlens_k = Tensor::new(seqlens_k, device)?
                .to_dtype(DType::F32)?
                .cumsum(0)?
                .to_dtype(DType::U32)?;

            let mut seqlens_q_map = HashMap::new();
            let mut seqlens_k_map = HashMap::new();

            let devices = mapper.unwrap().get_unique_devices();
            for device in devices {
                seqlens_q_map.insert(device.location(), seqlens_q.to_device(&device)?);
                seqlens_k_map.insert(device.location(), seqlens_k.to_device(&device)?);
            }
            (max_q, max_k, seqlens_q_map, seqlens_k_map)
        } else {
            (0, 0, HashMap::new(), HashMap::new())
        };

        let input = Tensor::cat(&seqs_tensors, 0).unwrap();

        let paged_attn_meta = if paged_attn_metadata.is_some() {
            let max_slot_mapping_len = slot_mappings.iter().map(|x| x.len()).max().unwrap();
            let slot_mappings =
                _make_tensor_with_pad(slot_mappings, max_slot_mapping_len, _PAD_SLOT_ID, device)?;

            let max_block_table_len = block_tables.iter().map(|x| x.len()).max().unwrap();
            let block_tables = _make_tensor_with_pad(
                block_tables
                    .iter()
                    .map(|x| x.iter().map(|x| *x as u32).collect::<Vec<_>>())
                    .collect::<Vec<_>>(),
                max_block_table_len,
                0,
                device,
            )?;
            let block_tables = block_tables.reshape(((), max_block_table_len))?;

            let max_context_len = paged_attn_context_lens
                .iter()
                .map(|x| x.len())
                .max()
                .unwrap();

            let context_lens = _make_tensor_with_pad(
                paged_attn_context_lens
                    .iter()
                    .map(|x| x.iter().map(|x| *x as u32).collect::<Vec<_>>())
                    .collect::<Vec<_>>(),
                max_context_len,
                0,
                device,
            )?
            .reshape(((),))?;

            let devices = mapper.unwrap().get_unique_devices();
            let mut slot_mappings_map = HashMap::new();
            let mut block_tables_map = HashMap::new();
            let mut context_lens_map = HashMap::new();

            for device in devices {
                slot_mappings_map
                    .insert(device.location(), slot_mappings.clone().to_device(&device)?);
                block_tables_map
                    .insert(device.location(), block_tables.clone().to_device(&device)?);
                context_lens_map
                    .insert(device.location(), context_lens.clone().to_device(&device)?);
            }

            Some(PagedAttentionInputMetadata {
                slot_mappings: slot_mappings_map,
                block_tables: Some(block_tables_map),
                context_lens: Some(context_lens_map),
                max_context_len: Some(max_context_len),
                is_first_prompt_chunk: chunk_offset_toks == 0,
            })
        } else {
            None
        };

        Ok(InputMetadata {
            input,
            positions: seqlen_offsets,
            context_lens,
            position_ids,
            paged_attn_meta,
            flash_meta: FlashParams {
                max_k,
                max_q,
                cumulative_seqlens_k: seqlens_k_map,
                cumulative_seqlens_q: seqlens_q_map,
            },
        })
    }

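    /// Prepares single-token decode inputs: only the last token of each sequence is
    /// fed, with its position and PagedAttention slot derived from the cached prefix.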
    fn make_completion_chunk<T: WithDType>(
        toks: Vec<&[T]>,
        input_seqs: &[&mut Sequence],
        device: &Device,
        mut paged_attn_metadata: Option<&mut PagedAttentionMeta>,
        mapper: Option<&dyn DeviceMapper>,
    ) -> Result<InputMetadata> {
        let flash_attn = crate::using_flash_attn();
        let mut seqs_tensors = Vec::new();
        let mut seqlen_offsets = Vec::new();
        let mut context_lens = Vec::new();
        let mut position_ids = Vec::new();

        let mut slot_mappings = Vec::new();
        let mut block_tables = Vec::new();
        let mut paged_attn_context_lens = Vec::new();
        let mut seqlens_q = if flash_attn { vec![0] } else { Vec::new() };
        let mut seqlens_k = if flash_attn { vec![0] } else { Vec::new() };
        for (seq, ctxt) in input_seqs.iter().zip(toks) {
            let start_pos = ctxt.len().saturating_sub(1);
            let ctxt = ctxt[start_pos..].to_vec();
            seqlen_offsets.push(start_pos);
            context_lens.push((0, 1));
            position_ids.push(seq.len());

            if flash_attn {
                seqlens_q.push(ctxt.len() as u32);
                seqlens_k.push((ctxt.len() + start_pos) as u32);
            }

            seqs_tensors.push(Tensor::new(ctxt, device).unwrap().unsqueeze(0).unwrap());

            if let Some(paged_attn_metadata) = &mut paged_attn_metadata {
                let block_engine = get_mut_arcmutex!(paged_attn_metadata.block_engine);
                let table = block_engine.block_tables.get(seq.id()).unwrap();

                let table = table
                    .iter()
                    .map(|block| block.deref_mut().block_id)
                    .collect::<Vec<_>>();

                let block_pos = start_pos - seq.token_offset();
                let block_number = if block_pos / paged_attn_metadata.block_size >= table.len() {
                    panic!("Block table is too small (completion)! start_pos={} block_size={} table_len={}", block_pos, paged_attn_metadata.block_size, table.len());
                } else {
                    table
                        .get(block_pos / paged_attn_metadata.block_size)
                        .unwrap()
                };
                let block_offset = block_pos % paged_attn_metadata.block_size;
                let slot = block_number * paged_attn_metadata.block_size + block_offset;
                let slot = slot.try_into().unwrap();
                slot_mappings.push(vec![slot]);

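                // Under a sliding window only the last `sliding_window / block_size`
                // blocks are reachable, so trim the front of the block table.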
                if let Some(sliding_window) = paged_attn_metadata.sliding_window {
                    let sliding_window_blocks = sliding_window / paged_attn_metadata.block_size;
                    let slide_idx = if table.len() > sliding_window_blocks {
                        table.len() - sliding_window_blocks
                    } else {
                        0
                    };
                    block_tables.push(table.get(slide_idx..).unwrap().to_vec());
                } else {
                    block_tables.push(table);
                }

                let paged_attn_context_len =
                    if let Some(sliding_window) = paged_attn_metadata.sliding_window {
                        seq.len().min(sliding_window)
                    } else {
                        seq.len()
                    };
                paged_attn_context_lens.push(paged_attn_context_len);
            }
        }

        let (max_q, max_k, seqlens_q_map, seqlens_k_map) = if flash_attn {
            let max_q = *seqlens_q.iter().max().unwrap();
            let max_k = *seqlens_k.iter().max().unwrap();
            let seqlens_q = Tensor::new(seqlens_q, device)?
                .to_dtype(DType::F32)?
                .cumsum(0)?
                .to_dtype(DType::U32)?;
            let seqlens_k = Tensor::new(seqlens_k, device)?
                .to_dtype(DType::F32)?
                .cumsum(0)?
                .to_dtype(DType::U32)?;

            let mut seqlens_q_map = HashMap::new();
            let mut seqlens_k_map = HashMap::new();

            let devices = mapper.unwrap().get_unique_devices();
            for device in devices {
                seqlens_q_map.insert(device.location(), seqlens_q.to_device(&device)?);
                seqlens_k_map.insert(device.location(), seqlens_k.to_device(&device)?);
            }
            (max_q, max_k, seqlens_q_map, seqlens_k_map)
        } else {
            (0, 0, HashMap::new(), HashMap::new())
        };

        let paged_attn_meta = if paged_attn_metadata.is_some() {
            let slot_mappings = _make_tensor_with_pad(slot_mappings, 1, _PAD_SLOT_ID, device)?;

            let max_block_table_len = block_tables.iter().map(|x| x.len()).max().unwrap();

            let block_tables = _make_tensor_with_pad(
                block_tables
                    .iter()
                    .map(|x| x.iter().map(|x| *x as u32).collect::<Vec<_>>())
                    .collect::<Vec<_>>(),
                max_block_table_len,
                0,
                device,
            )?;
            let block_tables = block_tables.reshape(((), max_block_table_len))?;

            let max_context_len = paged_attn_context_lens.iter().max().unwrap();

            let context_lens = Tensor::from_vec(
                paged_attn_context_lens
                    .iter()
                    .map(|x| *x as u32)
                    .collect::<Vec<_>>(),
                (paged_attn_context_lens.len(),),
                device,
            )?;

            let devices = mapper.unwrap().get_unique_devices();
            let mut slot_mappings_map = HashMap::new();
            let mut block_tables_map = HashMap::new();
            let mut context_lens_map = HashMap::new();

            for device in devices {
                slot_mappings_map
                    .insert(device.location(), slot_mappings.clone().to_device(&device)?);
                block_tables_map
                    .insert(device.location(), block_tables.clone().to_device(&device)?);
                context_lens_map
                    .insert(device.location(), context_lens.clone().to_device(&device)?);
            }

            Some(PagedAttentionInputMetadata {
                slot_mappings: slot_mappings_map,
                block_tables: Some(block_tables_map),
                context_lens: Some(context_lens_map),
                max_context_len: Some(*max_context_len),
                is_first_prompt_chunk: false,
            })
        } else {
            None
        };

        Ok(InputMetadata {
            input: Tensor::cat(&seqs_tensors, 0).unwrap(),
            positions: seqlen_offsets,
            context_lens,
            position_ids,
            paged_attn_meta,
            flash_meta: FlashParams {
                max_k,
                max_q,
                cumulative_seqlens_k: seqlens_k_map,
                cumulative_seqlens_q: seqlens_q_map,
            },
        })
    }

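    /// Builds prompt-phase inputs for a batch via `make_prompt_chunk`, using the first
    /// sequence's token offset as the chunk offset.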
    #[allow(clippy::too_many_arguments)]
    pub(crate) fn get_prompt_input<T: WithDType + std::fmt::Debug>(
        toks: Vec<&[T]>,
        input_seqs: &[&mut Sequence],
        device: &Device,
        last_n_context_len: Option<(usize, usize)>,
        return_raw_logits: bool,
        paged_attn_metadata: Option<&mut PagedAttentionMeta>,
        mapper: Option<&dyn DeviceMapper>,
    ) -> Result<InnerInputProcessorOutput> {
        let offset = input_seqs[0].token_offset();
        make_prompt_chunk(
            offset,
            toks,
            &input_seqs.iter().map(|s| *s.id()).collect::<Vec<_>>(),
            device,
            last_n_context_len,
            return_raw_logits,
            paged_attn_metadata,
            mapper,
        )
        .map(|inputs| InnerInputProcessorOutput {
            inputs,
            seq_indices: (0..input_seqs.len()).collect(),
        })
    }

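    /// Builds decode-phase inputs. With `no_kv_cache` the full context must be re-fed
    /// every step, so this falls back to `get_prompt_input`.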
    #[allow(clippy::too_many_arguments)]
    pub(crate) fn get_completion_input<T: WithDType + std::fmt::Debug>(
        toks: Vec<&[T]>,
        input_seqs: &[&mut Sequence],
        device: &Device,
        no_kv_cache: bool,
        last_n_context_len: Option<(usize, usize)>,
        return_raw_logits: bool,
        paged_attn_metadata: Option<&mut PagedAttentionMeta>,
        mapper: Option<&dyn DeviceMapper>,
    ) -> Result<InnerInputProcessorOutput> {
        if no_kv_cache {
            return get_prompt_input(
                toks,
                input_seqs,
                device,
                last_n_context_len,
                return_raw_logits,
                paged_attn_metadata,
                mapper,
            );
        }

        make_completion_chunk(toks, input_seqs, device, paged_attn_metadata, mapper).map(|inputs| {
            InnerInputProcessorOutput {
                inputs,
                seq_indices: (0..input_seqs.len()).collect(),
            }
        })
    }

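    /// Concrete inputs for text models. The `*_full` fields are populated only for
    /// X-LoRA, which additionally runs over the full (non-cached) context.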
    #[derive(Clone)]
    pub struct ModelInputs {
        pub input_ids: Tensor,
        pub input_ids_full: Option<Tensor>,
        pub seqlen_offsets: Vec<usize>,
        pub seqlen_offsets_full: Option<Vec<usize>>,
        pub context_lens: Vec<(usize, usize)>,
        pub position_ids: Vec<usize>,
        pub paged_attn_meta: Option<PagedAttentionInputMetadata>,
        pub flash_meta: FlashParams,
        pub flash_meta_full: Option<FlashParams>,
    }

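    /// `InputsProcessor` for text-only models.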
    pub struct TextInputsProcessor;

    impl InputsProcessor for TextInputsProcessor {
        fn process_inputs(
            &self,
            _: Option<Arc<Tokenizer>>,
            input_seqs: &mut [&mut Sequence],
            is_prompt: bool,
            is_xlora: bool,
            device: &Device,
            no_kv_cache: bool,
            last_n_context_len: Option<(usize, usize)>,
            return_raw_logits: bool,
            _: Option<Arc<dyn Any>>,
            mut paged_attn_metadata: Option<PagedAttentionMeta>,
            mapper: Option<&dyn DeviceMapper>,
        ) -> Result<InputProcessorOutput> {
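            // Four cases follow: X-LoRA decode (which needs the full-context inputs as
            // well as the regular decode inputs), X-LoRA prompt, prompt, and decode.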
            if is_xlora && !is_prompt {
                let prompt = get_prompt_input(
                    input_seqs
                        .iter()
                        .map(|seq| seq.get_toks())
                        .collect::<Vec<_>>(),
                    input_seqs,
                    device,
                    last_n_context_len,
                    return_raw_logits,
                    paged_attn_metadata.as_mut(),
                    mapper,
                )?;
                let completion = get_completion_input(
                    input_seqs
                        .iter()
                        .map(|seq| seq.get_toks())
                        .collect::<Vec<_>>(),
                    input_seqs,
                    device,
                    no_kv_cache,
                    last_n_context_len,
                    return_raw_logits,
                    paged_attn_metadata.as_mut(),
                    mapper,
                )?;
                let InnerInputProcessorOutput {
                    inputs:
                        InputMetadata {
                            input: input_ids_full,
                            positions: seqlen_offsets_full,
                            context_lens: _,
                            position_ids,
                            paged_attn_meta: _,
                            flash_meta: flash_meta_full,
                        },
                    seq_indices,
                } = prompt;
                let InnerInputProcessorOutput {
                    inputs:
                        InputMetadata {
                            input: input_ids,
                            positions: seqlen_offsets,
                            context_lens,
                            position_ids: _,
                            paged_attn_meta,
                            flash_meta,
                        },
                    seq_indices: _,
                } = completion;
                let inputs: Box<dyn Any> = Box::new(ModelInputs {
                    input_ids,
                    input_ids_full: Some(input_ids_full),
                    seqlen_offsets,
                    seqlen_offsets_full: Some(seqlen_offsets_full),
                    context_lens,
                    position_ids,
                    paged_attn_meta,
                    flash_meta,
                    flash_meta_full: Some(flash_meta_full),
                });
                Ok(InputProcessorOutput {
                    inputs,
                    seq_indices,
                })
            } else if is_xlora && is_prompt {
                let metadata = get_prompt_input(
                    input_seqs
                        .iter()
                        .map(|seq| seq.get_toks())
                        .collect::<Vec<_>>(),
                    input_seqs,
                    device,
                    last_n_context_len,
                    return_raw_logits,
                    paged_attn_metadata.as_mut(),
                    mapper,
                )?;
                let InnerInputProcessorOutput {
                    inputs:
                        InputMetadata {
                            input: input_ids,
                            positions: seqlen_offsets,
                            context_lens,
                            position_ids,
                            paged_attn_meta,
                            flash_meta,
                        },
                    seq_indices,
                } = metadata;
                let inputs: Box<dyn Any> = Box::new(ModelInputs {
                    input_ids: input_ids.clone(),
                    input_ids_full: Some(input_ids),
                    seqlen_offsets: seqlen_offsets.clone(),
                    seqlen_offsets_full: Some(seqlen_offsets),
                    context_lens,
                    position_ids,
                    paged_attn_meta,
                    flash_meta: flash_meta.clone(),
                    flash_meta_full: Some(flash_meta),
                });
                Ok(InputProcessorOutput {
                    inputs,
                    seq_indices,
                })
            } else if is_prompt {
                let metadata = get_prompt_input(
                    input_seqs
                        .iter()
                        .map(|seq| seq.get_toks())
                        .collect::<Vec<_>>(),
                    input_seqs,
                    device,
                    last_n_context_len,
                    return_raw_logits,
                    paged_attn_metadata.as_mut(),
                    mapper,
                )?;
                let InnerInputProcessorOutput {
                    inputs:
                        InputMetadata {
                            input: input_ids,
                            positions: seqlen_offsets,
                            context_lens,
                            position_ids,
                            paged_attn_meta,
                            flash_meta,
                        },
                    seq_indices,
                } = metadata;
                let inputs: Box<dyn Any> = Box::new(ModelInputs {
                    input_ids,
                    input_ids_full: None,
                    seqlen_offsets,
                    seqlen_offsets_full: None,
                    context_lens,
                    position_ids,
                    paged_attn_meta,
                    flash_meta,
                    flash_meta_full: None,
                });
                Ok(InputProcessorOutput {
                    inputs,
                    seq_indices,
                })
            } else {
                let metadata = get_completion_input(
                    input_seqs
                        .iter()
                        .map(|seq| seq.get_toks())
                        .collect::<Vec<_>>(),
                    input_seqs,
                    device,
                    no_kv_cache,
                    last_n_context_len,
                    return_raw_logits,
                    paged_attn_metadata.as_mut(),
                    mapper,
                )?;
                let InnerInputProcessorOutput {
                    inputs:
                        InputMetadata {
                            input: input_ids,
                            positions: seqlen_offsets,
                            context_lens,
                            position_ids,
                            paged_attn_meta,
                            flash_meta,
                        },
                    seq_indices,
                } = metadata;
                let inputs: Box<dyn Any> = Box::new(ModelInputs {
                    input_ids,
                    input_ids_full: None,
                    seqlen_offsets,
                    seqlen_offsets_full: None,
                    context_lens,
                    position_ids,
                    paged_attn_meta,
                    flash_meta,
                    flash_meta_full: None,
                });
                Ok(InputProcessorOutput {
                    inputs,
                    seq_indices,
                })
            }
        }

        fn get_type(&self) -> InputsProcessorType {
            InputsProcessorType::Text
        }
    }
}