mistralrs_core/pipeline/inputs_processor.rs

#![allow(clippy::cast_possible_truncation)]

use std::{any::Any, sync::Arc};

use anyhow::Result;
use candle_core::Device;
use text_models_inputs_processor::PagedAttentionMeta;
use tokenizers::Tokenizer;

use crate::{device_map::DeviceMapper, sequence::Sequence};

#[derive(PartialEq)]
pub enum InputsProcessorType {
    Text,
    Vision,
}

pub struct InputProcessorOutput {
    pub inputs: Box<dyn Any>,
    pub seq_indices: Vec<usize>,
}

/// Processor: prepare inputs for the model (including preprocessing images, if applicable).
pub trait InputsProcessor {
    /// This should also enable matmul via f16 if this is a prompt and the sequence length is
    /// greater than 32. Otherwise, matmul via f16 is disabled.
    ///
    /// This should return a type which can be downcast to the proper type as used in `forward_inputs`.
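    ///
    /// A sketch of the call-site pattern (not a doc-test; `processor` and the elided
    /// arguments are assumed):
    ///
    /// ```ignore
    /// let output = processor.process_inputs(/* ... */)?;
    /// // `TextInputsProcessor` boxes a `text_models_inputs_processor::ModelInputs`.
    /// let inputs = output
    ///     .inputs
    ///     .downcast::<text_models_inputs_processor::ModelInputs>()
    ///     .expect("wrong input type for this pipeline");
    /// ```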
    #[allow(clippy::too_many_arguments)]
    fn process_inputs(
        &self,
        tokenizer: Option<Arc<Tokenizer>>,
        input_seqs: &mut [&mut Sequence],
        is_prompt: bool,
        is_xlora: bool,
        device: &Device,
        no_kv_cache: bool,
        last_n_context_len: Option<(usize, usize)>,
        return_raw_logits: bool,
        other_config: Option<Arc<dyn Any>>,
        paged_attn_metadata: Option<PagedAttentionMeta>,
        mapper: Option<&dyn DeviceMapper>,
    ) -> Result<InputProcessorOutput>;

    fn get_type(&self) -> InputsProcessorType;
}

// ========================= Text models input processor

pub mod text_models_inputs_processor {
    use std::{any::Any, collections::HashMap, fmt::Debug, sync::Arc};

    use anyhow::Result;
    use candle_core::{DType, Device, DeviceLocation, Tensor, WithDType};
    use tokenizers::Tokenizer;

    use crate::{
        device_map::DeviceMapper,
        get_mut_arcmutex,
        paged_attention::{BlockEngine, _PAD_SLOT_ID},
        sequence::Sequence,
    };

    use super::{InputProcessorOutput, InputsProcessor, InputsProcessorType};

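    /// Pad each row of `x` to `max_len` with `pad`, then concatenate the rows into a
    /// single flat tensor of length `x.len() * max_len` (callers reshape it as needed).
    ///
    /// A sketch of the resulting layout (not a doc-test; values are illustrative):
    ///
    /// ```ignore
    /// // rows [[1, 2], [3]] with max_len = 2 and pad = 0
    /// // produce the flat tensor [1, 2, 3, 0]
    /// ```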
    fn _make_tensor_with_pad<D: WithDType>(
        x: Vec<Vec<D>>,
        max_len: usize,
        pad: D,
        device: &Device,
    ) -> Result<Tensor> {
        let mut padded_x = Vec::new();
        for mut x_i in x {
            assert!(x_i.len() <= max_len);
            x_i.extend([pad].repeat(max_len - x_i.len()));
            let shape = (x_i.len(),);
            padded_x.push(Tensor::from_vec(x_i, shape, device)?);
        }
        Tensor::cat(&padded_x[..], 0).map_err(anyhow::Error::msg)
    }

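    /// Handed to the inputs processor when PagedAttention is enabled: the KV block size,
    /// the optional sliding window, and a shared handle to the scheduler's `BlockEngine`,
    /// which owns the per-sequence block tables.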
    pub struct PagedAttentionMeta {
        pub sliding_window: Option<usize>,
        pub block_size: usize,
        pub block_engine: Arc<tokio::sync::Mutex<BlockEngine>>,
    }

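    /// Tensor-ified PagedAttention metadata. Each map is keyed by device location so
    /// that device-mapped models can fetch a copy local to whichever device holds a
    /// given layer.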
    #[derive(Clone, Debug)]
    #[allow(dead_code)]
    pub struct PagedAttentionInputMetadata {
        pub block_tables: Option<HashMap<DeviceLocation, Tensor>>,
        pub context_lens: Option<HashMap<DeviceLocation, Tensor>>,
        pub slot_mappings: HashMap<DeviceLocation, Tensor>,
        pub max_context_len: Option<usize>,
        pub is_first_prompt_chunk: bool,
    }

    impl PagedAttentionInputMetadata {
        /// Create a dummy input metadata, assuming that this will NOT be used for decoding.
        /// This is used for the case of imatrix generation.
        pub fn dummy(dev: &Device) -> candle_core::Result<Self> {
            Ok(PagedAttentionInputMetadata {
                block_tables: None,
                context_lens: None,
                max_context_len: None,
                slot_mappings: HashMap::from([(dev.location(), Tensor::new(&[0f32], dev)?)]),
                is_first_prompt_chunk: true,
            })
        }
    }

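    /// Metadata for flash-attention varlen kernels. `cumulative_seqlens_q`/`_k` hold the
    /// standard cumulative sequence lengths starting at 0 (e.g. lengths `[3, 5]` yield
    /// `[0, 3, 8]`), with one copy per device location for device mapping.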
    #[derive(Clone, Debug)]
    pub struct FlashParams {
        pub max_q: u32,
        pub max_k: u32,
        pub cumulative_seqlens_q: HashMap<DeviceLocation, Tensor>,
        pub cumulative_seqlens_k: HashMap<DeviceLocation, Tensor>,
    }

    pub struct InputMetadata {
        pub input: Tensor,
        pub positions: Vec<usize>,
        pub context_lens: Vec<(usize, usize)>, // (start index, len)
        pub position_ids: Vec<usize>,
        pub paged_attn_meta: Option<PagedAttentionInputMetadata>, // For paged attention
        pub flash_meta: FlashParams,
    }

    pub struct InnerInputProcessorOutput {
        pub inputs: InputMetadata,
        pub seq_indices: Vec<usize>,
    }

    // `chunk_offset_toks` is the number of prompt tokens that precede this chunk;
    // chunk_offset_toks / prompt_chunksize gives the number of chunks already processed.
    #[allow(clippy::too_many_arguments)]
    pub fn make_prompt_chunk<T: WithDType + Debug>(
        chunk_offset_toks: usize,
        toks: Vec<&[T]>,
        seq_ids: &[usize],
        device: &Device,
        last_n_context_len: Option<(usize, usize)>,
        return_raw_logits: bool,
        mut paged_attn_metadata: Option<&mut PagedAttentionMeta>,
        mapper: Option<&dyn DeviceMapper>,
    ) -> Result<InputMetadata> {
        let max_len = toks
            .iter()
            .map(|seq| seq.len())
            .max()
            .expect("No sequences");
        let padding_tok = T::zero();
        // Pad each sequence by the padding token to the max len.
        let mut seqs_tensors = Vec::new();
        let mut seqlen_offsets = Vec::new();
        let mut context_lens = Vec::new();
        let mut position_ids = Vec::new();
        let mut slot_mappings = Vec::new();
        let mut block_tables = Vec::new();
        let mut paged_attn_context_lens = Vec::new();
        let flash_attn = crate::using_flash_attn();
        let mut seqlens_q = if flash_attn { vec![0] } else { Vec::new() };
        let mut seqlens_k = if flash_attn { vec![0] } else { Vec::new() };
        for (seq_id, ctxt) in seq_ids.iter().zip(toks) {
            let prompt_len = ctxt.len();
            let offset = last_n_context_len.unwrap_or_default();
            seqlen_offsets.push(offset.1 + chunk_offset_toks);

            position_ids.push(ctxt.len() + chunk_offset_toks);
            let mut ctxt = ctxt.to_vec();
            ctxt.extend(std::iter::repeat_n(
                padding_tok,
                max_len.saturating_sub(ctxt.len()),
            ));
            // If we are returning raw logits, do not trim the logits at all.
            if return_raw_logits {
                if last_n_context_len.is_some() {
                    anyhow::bail!("`return_raw_logits` is incompatible with `last_n_context_len`");
                }

                context_lens.push((0, ctxt.len()));
            } else {
                context_lens.push((
                    ctxt.len()
                        .saturating_sub(last_n_context_len.map(|(a, _)| a).unwrap_or(1)),
                    last_n_context_len.map(|(a, _)| a).unwrap_or(1),
                ));
            }

            if flash_attn {
                seqlens_q.push(ctxt.len() as u32);
                seqlens_k.push((ctxt.len() + chunk_offset_toks) as u32);
            }

            seqs_tensors.push(Tensor::new(ctxt, device).unwrap().unsqueeze(0).unwrap());

            if let Some(paged_attn_metadata) = &mut paged_attn_metadata {
                let block_engine = get_mut_arcmutex!(paged_attn_metadata.block_engine);
                let table = block_engine.block_tables.get(seq_id);

                if table.is_none() {
                    // Will be None during profiling.
                    slot_mappings.push([_PAD_SLOT_ID].repeat(prompt_len));
                    continue;
                }
                let table = table
                    .unwrap()
                    .iter()
                    .map(|block| block.deref_mut().block_id)
                    .collect::<Vec<_>>();

                let start_idx = if let Some(sliding_window) = paged_attn_metadata.sliding_window {
                    prompt_len.saturating_sub(sliding_window)
                } else {
                    0
                };

                let mut slot_mapping = Vec::new();
                let mut ctxt_len = Vec::new();
                for i in chunk_offset_toks..prompt_len + chunk_offset_toks {
                    if i < start_idx {
                        // Pad [0, start_idx) with _PAD_SLOT_ID.
                        slot_mapping.push(_PAD_SLOT_ID);
                    }
                    ctxt_len.push(i);

                    let block_number = if i / paged_attn_metadata.block_size >= table.len() {
                        panic!(
                            "Block table is too small (prompt)! i={} block_size={} table_len={}",
                            i,
                            paged_attn_metadata.block_size,
                            table.len()
                        );
                    } else {
                        table.get(i / paged_attn_metadata.block_size).unwrap()
                    };
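                    // Map a token position to its physical KV-cache slot. For example
                    // (illustrative numbers), with block_size = 16 and i = 35: block
                    // index 35 / 16 = 2, offset 35 % 16 = 3, slot = table[2] * 16 + 3.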
                    let block_offset = i % paged_attn_metadata.block_size;
                    let slot = block_number * paged_attn_metadata.block_size + block_offset;
                    slot_mapping.push(slot.try_into().unwrap());
                }
                slot_mappings.push(slot_mapping);
                block_tables.push(table.clone());
                paged_attn_context_lens.push(ctxt_len);
            }
        }

        let (max_q, max_k, seqlens_q_map, seqlens_k_map) = if flash_attn {
            let max_q = *seqlens_q.iter().max().unwrap();
            let max_k = *seqlens_k.iter().max().unwrap();
            let seqlens_q = Tensor::new(seqlens_q, device)?
                .to_dtype(DType::F32)?
                .cumsum(0)?
                .to_dtype(DType::U32)?;
            let seqlens_k = Tensor::new(seqlens_k, device)?
                .to_dtype(DType::F32)?
                .cumsum(0)?
                .to_dtype(DType::U32)?;

            let mut seqlens_q_map = HashMap::new();
            let mut seqlens_k_map = HashMap::new();

            let devices = mapper.unwrap().get_unique_devices();
            for device in devices {
                seqlens_q_map.insert(device.location(), seqlens_q.to_device(&device)?);
                seqlens_k_map.insert(device.location(), seqlens_k.to_device(&device)?);
            }
            (max_q, max_k, seqlens_q_map, seqlens_k_map)
        } else {
            (0, 0, HashMap::new(), HashMap::new())
        };

        let input = Tensor::cat(&seqs_tensors, 0).unwrap();

        let paged_attn_meta = if paged_attn_metadata.is_some() {
            let max_slot_mapping_len = slot_mappings.iter().map(|x| x.len()).max().unwrap();
            let slot_mappings =
                _make_tensor_with_pad(slot_mappings, max_slot_mapping_len, _PAD_SLOT_ID, device)?;

            let max_block_table_len = block_tables.iter().map(|x| x.len()).max().unwrap();
            let block_tables = _make_tensor_with_pad(
                block_tables
                    .iter()
                    .map(|x| x.iter().map(|x| *x as u32).collect::<Vec<_>>())
                    .collect::<Vec<_>>(),
                max_block_table_len,
                0,
                device,
            )?;
            let block_tables = block_tables.reshape(((), max_block_table_len))?;

            let max_context_len = paged_attn_context_lens
                .iter()
                .map(|x| x.len())
                .max()
                .unwrap();

            let context_lens = _make_tensor_with_pad(
                paged_attn_context_lens
                    .iter()
                    .map(|x| x.iter().map(|x| *x as u32).collect::<Vec<_>>())
                    .collect::<Vec<_>>(),
                max_context_len,
                0,
                device,
            )?
            .reshape(((),))?;

            // For device mapping, make a copy of each tensor for each device
            let devices = mapper.unwrap().get_unique_devices();
            let mut slot_mappings_map = HashMap::new();
            let mut block_tables_map = HashMap::new();
            let mut context_lens_map = HashMap::new();

            for device in devices {
                slot_mappings_map
                    .insert(device.location(), slot_mappings.clone().to_device(&device)?);
                block_tables_map
                    .insert(device.location(), block_tables.clone().to_device(&device)?);
                context_lens_map
                    .insert(device.location(), context_lens.clone().to_device(&device)?);
            }

            Some(PagedAttentionInputMetadata {
                slot_mappings: slot_mappings_map,
                block_tables: Some(block_tables_map),
                context_lens: Some(context_lens_map),
                max_context_len: Some(max_context_len),
                is_first_prompt_chunk: chunk_offset_toks == 0,
            })
        } else {
            None
        };

        Ok(InputMetadata {
            input,
            positions: seqlen_offsets,
            context_lens,
            position_ids,
            paged_attn_meta,
            flash_meta: FlashParams {
                max_k,
                max_q,
                cumulative_seqlens_k: seqlens_k_map,
                cumulative_seqlens_q: seqlens_q_map,
            },
        })
    }

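    /// Build inputs for a single decode step: only the last token of each sequence is fed
    /// to the model, and `positions` points at that token's index in the KV cache.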
    fn make_completion_chunk<T: WithDType>(
        toks: Vec<&[T]>,
        input_seqs: &[&mut Sequence],
        device: &Device,
        mut paged_attn_metadata: Option<&mut PagedAttentionMeta>,
        mapper: Option<&dyn DeviceMapper>,
    ) -> Result<InputMetadata> {
        // Each sequence contributes exactly its last token, so no padding is needed.
        let flash_attn = crate::using_flash_attn();
        let mut seqs_tensors = Vec::new();
        let mut seqlen_offsets = Vec::new();
        let mut context_lens = Vec::new();
        let mut position_ids = Vec::new();

        let mut slot_mappings = Vec::new();
        let mut block_tables = Vec::new();
        let mut paged_attn_context_lens = Vec::new();
        let mut seqlens_q = if flash_attn { vec![0] } else { Vec::new() };
        let mut seqlens_k = if flash_attn { vec![0] } else { Vec::new() };
        for (seq, ctxt) in input_seqs.iter().zip(toks) {
            let start_pos = ctxt.len().saturating_sub(1);
            let ctxt = ctxt[start_pos..].to_vec();
            seqlen_offsets.push(start_pos);
            context_lens.push((0, 1));
            position_ids.push(seq.len());

            if flash_attn {
                seqlens_q.push(ctxt.len() as u32);
                seqlens_k.push((ctxt.len() + start_pos) as u32);
            }

            seqs_tensors.push(Tensor::new(ctxt, device).unwrap().unsqueeze(0).unwrap());

            if let Some(paged_attn_metadata) = &mut paged_attn_metadata {
                let block_engine = get_mut_arcmutex!(paged_attn_metadata.block_engine);
                let table = block_engine.block_tables.get(seq.id()).unwrap();

                let table = table
                    .iter()
                    .map(|block| block.deref_mut().block_id)
                    .collect::<Vec<_>>();

                let block_pos = start_pos - seq.token_offset();
                let block_number = if block_pos / paged_attn_metadata.block_size >= table.len() {
                    panic!(
                        "Block table is too small (completion)! start_pos={} block_size={} table_len={}",
                        block_pos,
                        paged_attn_metadata.block_size,
                        table.len()
                    );
                } else {
                    table
                        .get(block_pos / paged_attn_metadata.block_size)
                        .unwrap()
                };
                let block_offset = block_pos % paged_attn_metadata.block_size;
                let slot = block_number * paged_attn_metadata.block_size + block_offset;
                let slot = slot.try_into().unwrap();
                slot_mappings.push(vec![slot]);

                if let Some(sliding_window) = paged_attn_metadata.sliding_window {
                    let sliding_window_blocks = sliding_window / paged_attn_metadata.block_size;
                    let slide_idx = if table.len() > sliding_window_blocks {
                        table.len() - sliding_window_blocks
                    } else {
                        0
                    };
                    block_tables.push(table.get(slide_idx..).unwrap().to_vec());
                } else {
                    block_tables.push(table);
                }

                let paged_attn_context_len =
                    if let Some(sliding_window) = paged_attn_metadata.sliding_window {
                        seq.len().min(sliding_window)
                    } else {
                        seq.len()
                    };
                paged_attn_context_lens.push(paged_attn_context_len);
            }
        }

        let (max_q, max_k, seqlens_q_map, seqlens_k_map) = if flash_attn {
            let max_q = *seqlens_q.iter().max().unwrap();
            let max_k = *seqlens_k.iter().max().unwrap();
            let seqlens_q = Tensor::new(seqlens_q, device)?
                .to_dtype(DType::F32)?
                .cumsum(0)?
                .to_dtype(DType::U32)?;
            let seqlens_k = Tensor::new(seqlens_k, device)?
                .to_dtype(DType::F32)?
                .cumsum(0)?
                .to_dtype(DType::U32)?;

            let mut seqlens_q_map = HashMap::new();
            let mut seqlens_k_map = HashMap::new();

            let devices = mapper.unwrap().get_unique_devices();
            for device in devices {
                seqlens_q_map.insert(device.location(), seqlens_q.to_device(&device)?);
                seqlens_k_map.insert(device.location(), seqlens_k.to_device(&device)?);
            }
            (max_q, max_k, seqlens_q_map, seqlens_k_map)
        } else {
            (0, 0, HashMap::new(), HashMap::new())
        };

        let paged_attn_meta = if paged_attn_metadata.is_some() {
            let slot_mappings = _make_tensor_with_pad(slot_mappings, 1, _PAD_SLOT_ID, device)?;

            let max_block_table_len = block_tables.iter().map(|x| x.len()).max().unwrap();

            let block_tables = _make_tensor_with_pad(
                block_tables
                    .iter()
                    .map(|x| x.iter().map(|x| *x as u32).collect::<Vec<_>>())
                    .collect::<Vec<_>>(),
                max_block_table_len,
                0,
                device,
            )?;
            let block_tables = block_tables.reshape(((), max_block_table_len))?;

            let max_context_len = paged_attn_context_lens.iter().max().unwrap();

            let context_lens = Tensor::from_vec(
                paged_attn_context_lens
                    .iter()
                    .map(|x| *x as u32)
                    .collect::<Vec<_>>(),
                (paged_attn_context_lens.len(),),
                device,
            )?;

            // For device mapping, make a copy of each tensor for each device
            let devices = mapper.unwrap().get_unique_devices();
            let mut slot_mappings_map = HashMap::new();
            let mut block_tables_map = HashMap::new();
            let mut context_lens_map = HashMap::new();

            for device in devices {
                slot_mappings_map
                    .insert(device.location(), slot_mappings.clone().to_device(&device)?);
                block_tables_map
                    .insert(device.location(), block_tables.clone().to_device(&device)?);
                context_lens_map
                    .insert(device.location(), context_lens.clone().to_device(&device)?);
            }

            Some(PagedAttentionInputMetadata {
                slot_mappings: slot_mappings_map,
                block_tables: Some(block_tables_map),
                context_lens: Some(context_lens_map),
                max_context_len: Some(*max_context_len),
                is_first_prompt_chunk: false,
            })
        } else {
            None
        };

        Ok(InputMetadata {
            input: Tensor::cat(&seqs_tensors, 0).unwrap(),
            positions: seqlen_offsets,
            context_lens,
            position_ids,
            paged_attn_meta,
            flash_meta: FlashParams {
                max_k,
                max_q,
                cumulative_seqlens_k: seqlens_k_map,
                cumulative_seqlens_q: seqlens_q_map,
            },
        })
    }

    #[allow(clippy::too_many_arguments)]
    pub(crate) fn get_prompt_input<T: WithDType + std::fmt::Debug>(
        toks: Vec<&[T]>,
        input_seqs: &[&mut Sequence],
        device: &Device,
        last_n_context_len: Option<(usize, usize)>,
        return_raw_logits: bool,
        paged_attn_metadata: Option<&mut PagedAttentionMeta>,
        mapper: Option<&dyn DeviceMapper>,
    ) -> Result<InnerInputProcessorOutput> {
        let offset = input_seqs[0].token_offset();
        make_prompt_chunk(
            offset,
            toks,
            &input_seqs.iter().map(|s| *s.id()).collect::<Vec<_>>(),
            device,
            last_n_context_len,
            return_raw_logits,
            paged_attn_metadata,
            mapper,
        )
        .map(|inputs| InnerInputProcessorOutput {
            inputs,
            seq_indices: (0..input_seqs.len()).collect(),
        })
    }

    #[allow(clippy::too_many_arguments)]
    pub(crate) fn get_completion_input<T: WithDType + std::fmt::Debug>(
        toks: Vec<&[T]>,
        input_seqs: &[&mut Sequence],
        device: &Device,
        no_kv_cache: bool,
        last_n_context_len: Option<(usize, usize)>,
        return_raw_logits: bool,
        paged_attn_metadata: Option<&mut PagedAttentionMeta>,
        mapper: Option<&dyn DeviceMapper>,
    ) -> Result<InnerInputProcessorOutput> {
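        // Without a KV cache there is nothing to reuse from earlier steps, so every step
        // is a full prompt-style forward over all tokens.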
        if no_kv_cache {
            return get_prompt_input(
                toks,
                input_seqs,
                device,
                last_n_context_len,
                return_raw_logits,
                paged_attn_metadata,
                mapper,
            );
        }

        make_completion_chunk(toks, input_seqs, device, paged_attn_metadata, mapper).map(|inputs| {
            InnerInputProcessorOutput {
                inputs,
                seq_indices: (0..input_seqs.len()).collect(),
            }
        })
    }

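    /// The concrete type boxed into `InputProcessorOutput::inputs` by `TextInputsProcessor`;
    /// pipelines downcast back to this in their forward pass. The `*_full` fields are only
    /// populated for X-LoRA, which additionally needs a full, non-cached view of the
    /// sequence.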
    #[derive(Clone)]
    pub struct ModelInputs {
        pub input_ids: Tensor,
        pub input_ids_full: Option<Tensor>,
        pub seqlen_offsets: Vec<usize>,
        pub seqlen_offsets_full: Option<Vec<usize>>,
        pub context_lens: Vec<(usize, usize)>,
        pub position_ids: Vec<usize>,
        pub paged_attn_meta: Option<PagedAttentionInputMetadata>,
        pub flash_meta: FlashParams,
        pub flash_meta_full: Option<FlashParams>,
    }

    pub struct TextInputsProcessor;

    impl InputsProcessor for TextInputsProcessor {
        fn process_inputs(
            &self,
            _: Option<Arc<Tokenizer>>,
            input_seqs: &mut [&mut Sequence],
            is_prompt: bool,
            is_xlora: bool,
            device: &Device,
            no_kv_cache: bool,
            last_n_context_len: Option<(usize, usize)>,
            return_raw_logits: bool,
            _: Option<Arc<dyn Any>>,
            mut paged_attn_metadata: Option<PagedAttentionMeta>,
            mapper: Option<&dyn DeviceMapper>,
        ) -> Result<InputProcessorOutput> {
            if is_xlora && !is_prompt {
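                // X-LoRA decode: build both a full prompt-style view of the sequence (for
                // the non-cached scalings pass) and the usual single-token completion
                // inputs.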
                let prompt = get_prompt_input(
                    input_seqs
                        .iter()
                        .map(|seq| seq.get_toks())
                        .collect::<Vec<_>>(),
                    input_seqs,
                    device,
                    last_n_context_len,
                    return_raw_logits,
                    paged_attn_metadata.as_mut(),
                    mapper,
                )?;
                let completion = get_completion_input(
                    input_seqs
                        .iter()
                        .map(|seq| seq.get_toks())
                        .collect::<Vec<_>>(),
                    input_seqs,
                    device,
                    no_kv_cache,
                    last_n_context_len,
                    return_raw_logits,
                    paged_attn_metadata.as_mut(),
                    mapper,
                )?;
                let InnerInputProcessorOutput {
                    inputs:
                        InputMetadata {
                            input: input_ids_full,
                            positions: seqlen_offsets_full,
                            context_lens: _,
                            position_ids,
                            paged_attn_meta: _,
                            flash_meta: flash_meta_full,
                        },
                    seq_indices,
                } = prompt;
                let InnerInputProcessorOutput {
                    inputs:
                        InputMetadata {
                            input: input_ids,
                            positions: seqlen_offsets,
                            context_lens,
                            position_ids: _,
                            paged_attn_meta,
                            flash_meta,
                        },
                    seq_indices: _,
                } = completion;
                let inputs: Box<dyn Any> = Box::new(ModelInputs {
                    input_ids,
                    input_ids_full: Some(input_ids_full),
                    seqlen_offsets,
                    seqlen_offsets_full: Some(seqlen_offsets_full),
                    context_lens,
                    position_ids,
                    paged_attn_meta,
                    flash_meta,
                    flash_meta_full: Some(flash_meta_full),
                });
                Ok(InputProcessorOutput {
                    inputs,
                    seq_indices,
                })
            } else if is_xlora && is_prompt {
                let metadata = get_prompt_input(
                    input_seqs
                        .iter()
                        .map(|seq| seq.get_toks())
                        .collect::<Vec<_>>(),
                    input_seqs,
                    device,
                    last_n_context_len,
                    return_raw_logits,
                    paged_attn_metadata.as_mut(),
                    mapper,
                )?;
                let InnerInputProcessorOutput {
                    inputs:
                        InputMetadata {
                            input: input_ids,
                            positions: seqlen_offsets,
                            context_lens,
                            position_ids,
                            paged_attn_meta,
                            flash_meta,
                        },
                    seq_indices,
                } = metadata;
                let inputs: Box<dyn Any> = Box::new(ModelInputs {
                    input_ids: input_ids.clone(),
                    input_ids_full: Some(input_ids),
                    seqlen_offsets: seqlen_offsets.clone(),
                    seqlen_offsets_full: Some(seqlen_offsets),
                    context_lens,
                    position_ids,
                    paged_attn_meta,
                    flash_meta: flash_meta.clone(),
                    flash_meta_full: Some(flash_meta),
                });
                Ok(InputProcessorOutput {
                    inputs,
                    seq_indices,
                })
            } else if is_prompt {
                let metadata = get_prompt_input(
                    input_seqs
                        .iter()
                        .map(|seq| seq.get_toks())
                        .collect::<Vec<_>>(),
                    input_seqs,
                    device,
                    last_n_context_len,
                    return_raw_logits,
                    paged_attn_metadata.as_mut(),
                    mapper,
                )?;
                let InnerInputProcessorOutput {
                    inputs:
                        InputMetadata {
                            input: input_ids,
                            positions: seqlen_offsets,
                            context_lens,
                            position_ids,
                            paged_attn_meta,
                            flash_meta,
                        },
                    seq_indices,
                } = metadata;
                let inputs: Box<dyn Any> = Box::new(ModelInputs {
                    input_ids,
                    input_ids_full: None,
                    seqlen_offsets,
                    seqlen_offsets_full: None,
                    context_lens,
                    position_ids,
                    paged_attn_meta,
                    flash_meta,
                    flash_meta_full: None,
                });
                Ok(InputProcessorOutput {
                    inputs,
                    seq_indices,
                })
            } else {
                let metadata = get_completion_input(
                    input_seqs
                        .iter()
                        .map(|seq| seq.get_toks())
                        .collect::<Vec<_>>(),
                    input_seqs,
                    device,
                    no_kv_cache,
                    last_n_context_len,
                    return_raw_logits,
                    paged_attn_metadata.as_mut(),
                    mapper,
                )?;
                let InnerInputProcessorOutput {
                    inputs:
                        InputMetadata {
                            input: input_ids,
                            positions: seqlen_offsets,
                            context_lens,
                            position_ids,
                            paged_attn_meta,
                            flash_meta,
                        },
                    seq_indices,
                } = metadata;
                let inputs: Box<dyn Any> = Box::new(ModelInputs {
                    input_ids,
                    input_ids_full: None,
                    seqlen_offsets,
                    seqlen_offsets_full: None,
                    context_lens,
                    position_ids,
                    paged_attn_meta,
                    flash_meta,
                    flash_meta_full: None,
                });
                Ok(InputProcessorOutput {
                    inputs,
                    seq_indices,
                })
            }
        }

        fn get_type(&self) -> InputsProcessorType {
            InputsProcessorType::Text
        }
    }
}