// mistralrs_core/pipeline/inputs_processor.rs

1#![allow(clippy::cast_possible_truncation)]
2
3use std::{any::Any, sync::Arc};
4
5use anyhow::Result;
6use candle_core::Device;
7use text_models_inputs_processor::PagedAttentionMeta;
8use tokenizers::Tokenizer;
9
10use crate::{device_map::DeviceMapper, sequence::Sequence};
11
/// Which family of inputs processor an implementation belongs to.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum InputsProcessorType {
    /// Text-only models.
    Text,
    /// Vision (multimodal) models.
    Vision,
    /// Embedding models.
    Embedding,
}
18
/// Type-erased output of an inputs processor.
pub struct InputProcessorOutput {
    /// Model inputs; downcast to the concrete type the model's `forward_inputs`
    /// expects (e.g. `ModelInputs` for text models).
    pub inputs: Box<dyn Any>,
    /// Index of each batch row in the caller's `input_seqs` slice.
    pub seq_indices: Vec<usize>,
}
23
/// Processor: Prepare inputs for the model (potentially preparing the images if applicable)
pub trait InputsProcessor {
    /// Prepare model inputs for a batch of sequences.
    ///
    /// This should also enable matmul via f16 if prompt and the sequence length is greater than 32.
    /// Otherwise, matmul via f16 is disabled.
    ///
    /// This should return a type which can be downcasted to the proper type as used in
    /// `forward_inputs`.
    ///
    /// - `is_prompt`: true for a prefill (prompt) step, false for a decode step.
    /// - `no_kv_cache`: when set, decode steps are prepared like prompts (full recompute).
    /// - `last_n_context_len`: when present, controls how many trailing logits are kept.
    /// - `paged_attn_metadata`: present only when PagedAttention is enabled.
    /// - `mapper`: used to replicate metadata tensors across mapped devices.
    #[allow(clippy::too_many_arguments)]
    fn process_inputs(
        &self,
        tokenizer: Option<Arc<Tokenizer>>,
        input_seqs: &mut [&mut Sequence],
        is_prompt: bool,
        is_xlora: bool,
        device: &Device,
        no_kv_cache: bool,
        last_n_context_len: Option<(usize, usize)>,
        return_raw_logits: bool,
        other_config: Option<Arc<dyn Any>>,
        paged_attn_metadata: Option<PagedAttentionMeta>,
        mapper: Option<&dyn DeviceMapper>,
    ) -> Result<InputProcessorOutput>;

    /// Which processor family this implementation belongs to.
    fn get_type(&self) -> InputsProcessorType;
}
48
// ========================= Text models input processor
50
51pub mod text_models_inputs_processor {
52    use std::{any::Any, collections::HashMap, fmt::Debug, sync::Arc};
53
54    use anyhow::Result;
55    use candle_core::{DType, Device, DeviceLocation, Tensor, WithDType};
56    use tokenizers::Tokenizer;
57
58    use crate::{
59        device_map::DeviceMapper,
60        get_mut_arcmutex,
61        paged_attention::{BlockEngine, _PAD_SLOT_ID},
62        sequence::Sequence,
63    };
64
65    use super::{InputProcessorOutput, InputsProcessor, InputsProcessorType};
66
67    fn _make_tensor_with_pad<D: WithDType>(
68        x: Vec<Vec<D>>,
69        max_len: usize,
70        pad: D,
71        device: &Device,
72    ) -> Result<Tensor> {
73        let mut padded_x = Vec::new();
74        for mut x_i in x {
75            assert!(x_i.len() <= max_len);
76            x_i.extend([pad].repeat(max_len - x_i.len()));
77            let shape = (x_i.len(),);
78            padded_x.push(Tensor::from_vec(x_i, shape, device)?);
79        }
80        Tensor::cat(&padded_x[..], 0).map_err(anyhow::Error::msg)
81    }
82
    /// Scheduling metadata handed to the inputs processor when PagedAttention
    /// is enabled.
    pub struct PagedAttentionMeta {
        /// Sliding-window size in tokens, if the model uses one.
        pub sliding_window: Option<usize>,
        /// Number of KV-cache slots per block.
        pub block_size: usize,
        /// Shared block engine owning the per-sequence block tables.
        pub block_engine: Arc<tokio::sync::Mutex<BlockEngine>>,
    }
88
    /// Per-forward-pass PagedAttention tensors. Each map holds one copy of the
    /// tensor per device (keyed by `DeviceLocation`) for device-mapped models.
    #[derive(Clone, Debug)]
    #[allow(dead_code)]
    pub struct PagedAttentionInputMetadata {
        /// Block ids per sequence, reshaped to `(_, max_block_table_len)`;
        /// `None` for dummy/profiling runs.
        pub block_tables: Option<HashMap<DeviceLocation, Tensor>>,
        /// Per-sequence context lengths; `None` for dummy/profiling runs.
        pub context_lens: Option<HashMap<DeviceLocation, Tensor>>,
        /// Mapping from each input token to its KV-cache slot
        /// (`_PAD_SLOT_ID` marks padding).
        pub slot_mappings: HashMap<DeviceLocation, Tensor>,
        /// Maximum context length across the batch, if known.
        pub max_context_len: Option<usize>,
        /// True only for the first chunk of a chunked prompt (token offset 0).
        pub is_first_prompt_chunk: bool,
    }
98
99    impl PagedAttentionInputMetadata {
100        /// Create a dummy input metadata, assuming that this will NOT be used for decoding.
101        /// This is used for the case of imatrix generation.
102        pub fn dummy(dev: &Device) -> candle_core::Result<Self> {
103            Ok(PagedAttentionInputMetadata {
104                block_tables: None,
105                context_lens: None,
106                max_context_len: None,
107                slot_mappings: HashMap::from([(dev.location(), Tensor::new(&[0f32], dev)?)]),
108                is_first_prompt_chunk: true,
109            })
110        }
111    }
112
    /// Cumulative-sequence-length metadata consumed by flash attention,
    /// with one tensor copy per device for device-mapped models.
    #[derive(Clone, Debug)]
    pub struct FlashParams {
        /// Longest query length in the batch.
        pub max_q: u32,
        /// Longest key length in the batch.
        pub max_k: u32,
        /// Cumulative query lengths (u32, leading 0), keyed by device location.
        pub cumulative_seqlens_q: HashMap<DeviceLocation, Tensor>,
        /// Cumulative key lengths (u32, leading 0), keyed by device location.
        pub cumulative_seqlens_k: HashMap<DeviceLocation, Tensor>,
        /// Whether the attention mask is causal.
        pub causal: bool,
    }
121
    /// Everything a text model needs for one forward pass over a batch.
    pub struct InputMetadata {
        /// Token ids, shape (batch, padded_seq_len).
        pub input: Tensor,
        /// Per-sequence starting position (KV-cache offset).
        pub positions: Vec<usize>,
        pub context_lens: Vec<(usize, usize)>, // (start index, len)
        /// Per-sequence total position count (token count + chunk offset).
        pub position_ids: Vec<usize>,
        pub paged_attn_meta: Option<PagedAttentionInputMetadata>, // For paged attention
        pub flash_meta: FlashParams,
    }
130
    /// Non-type-erased counterpart of [`InputProcessorOutput`], used inside
    /// this module before boxing.
    pub struct InnerInputProcessorOutput {
        /// Prepared model inputs for this batch.
        pub inputs: InputMetadata,
        /// Index of each batch row in the caller's sequence slice.
        pub seq_indices: Vec<usize>,
    }
135
136    // chunk_offset_toks is the number of tokens by which the tokens are offset,
137    // chunk_offset_toks / prompt_chunksize = number of batches
138    #[allow(clippy::too_many_arguments)]
139    pub fn make_prompt_chunk<T: WithDType + Debug>(
140        chunk_offset_toks: usize,
141        toks: Vec<&[T]>,
142        seq_ids: &[usize],
143        device: &Device,
144        last_n_context_len: Option<(usize, usize)>,
145        return_raw_logits: bool,
146        mut paged_attn_metadata: Option<&mut PagedAttentionMeta>,
147        mapper: Option<&dyn DeviceMapper>,
148    ) -> Result<InputMetadata> {
149        let max_len = toks
150            .iter()
151            .map(|seq| seq.len())
152            .max()
153            .expect("No sequences");
154        let padding_tok = T::zero();
155        // Pad each sequence by the padding token to the max len.
156        let mut seqs_tensors = Vec::new();
157        let mut seqlen_offsets = Vec::new();
158        let mut context_lens = Vec::new();
159        let mut position_ids = Vec::new();
160        let mut slot_mappings = Vec::new();
161        let mut block_tables = Vec::new();
162        let mut paged_attn_context_lens = Vec::new();
163        let flash_attn = crate::using_flash_attn();
164        let mut seqlens_q = if flash_attn { vec![0] } else { Vec::new() };
165        let mut seqlens_k = if flash_attn { vec![0] } else { Vec::new() };
166        for (seq_id, ctxt) in seq_ids.iter().zip(toks) {
167            let prompt_len = ctxt.len();
168            let offset = last_n_context_len.unwrap_or_default();
169            seqlen_offsets.push(offset.1 + chunk_offset_toks);
170
171            position_ids.push(ctxt.len() + chunk_offset_toks);
172            let mut ctxt = ctxt.to_vec();
173            ctxt.extend(std::iter::repeat_n(
174                padding_tok,
175                max_len.saturating_sub(ctxt.len()),
176            ));
177            // If we are returning raw logits, we want to not trim the logits at all.
178            if return_raw_logits {
179                if last_n_context_len.is_some() {
180                    anyhow::bail!("`return_raw_logits` is incompatible with `last_n_context_len`");
181                }
182
183                context_lens.push((0, ctxt.len()));
184            } else {
185                context_lens.push((
186                    ctxt.len()
187                        .saturating_sub(last_n_context_len.map(|(a, _)| a).unwrap_or(1)),
188                    last_n_context_len.map(|(a, _)| a).unwrap_or(1),
189                ));
190            }
191
192            if flash_attn {
193                seqlens_q.push(ctxt.len() as u32);
194                seqlens_k.push((ctxt.len() + chunk_offset_toks) as u32);
195            }
196
197            seqs_tensors.push(Tensor::new(ctxt, device).unwrap().unsqueeze(0).unwrap());
198
199            if let Some(paged_attn_metadata) = &mut paged_attn_metadata {
200                let block_engine = get_mut_arcmutex!(paged_attn_metadata.block_engine);
201                let table = block_engine.block_tables.get(seq_id);
202
203                if table.is_none() {
204                    // Will be None during profiling.
205                    slot_mappings.push([_PAD_SLOT_ID].repeat(prompt_len));
206                    continue;
207                }
208                let table = table
209                    .unwrap()
210                    .iter()
211                    .map(|block| block.deref_mut().block_id)
212                    .collect::<Vec<_>>();
213
214                let start_idx = if let Some(sliding_window) = paged_attn_metadata.sliding_window {
215                    prompt_len.saturating_sub(sliding_window)
216                } else {
217                    0
218                };
219
220                let mut slot_mapping = Vec::new();
221                let mut ctxt_len = Vec::new();
222                for i in chunk_offset_toks..prompt_len + chunk_offset_toks {
223                    if i < start_idx {
224                        // Pad [0,start_idx) with _PAD_TOKEN_ID
225                        slot_mapping.push(_PAD_SLOT_ID);
226                    }
227                    ctxt_len.push(i);
228
229                    let block_number = if i / paged_attn_metadata.block_size >= table.len() {
230                        panic!(
231                            "Block table is too small (prompt)! i={} block_size={} table_len={}",
232                            i,
233                            paged_attn_metadata.block_size,
234                            table.len()
235                        );
236                    } else {
237                        table.get(i / paged_attn_metadata.block_size).unwrap()
238                    };
239                    let block_offset = i % paged_attn_metadata.block_size;
240                    let slot = block_number * paged_attn_metadata.block_size + block_offset;
241                    slot_mapping.push(slot.try_into().unwrap());
242                    block_tables.push(table.clone());
243                }
244                slot_mappings.push(slot_mapping);
245                paged_attn_context_lens.push(ctxt_len);
246            }
247        }
248
249        let (max_q, max_k, seqlens_q_map, seqlens_k_map) = if flash_attn {
250            let max_q = *seqlens_q.iter().max().unwrap();
251            let max_k = *seqlens_k.iter().max().unwrap();
252            let seqlens_q = Tensor::new(seqlens_q, device)?
253                .to_dtype(DType::F32)?
254                .cumsum(0)?
255                .to_dtype(DType::U32)?;
256            let seqlens_k = Tensor::new(seqlens_k, device)?
257                .to_dtype(DType::F32)?
258                .cumsum(0)?
259                .to_dtype(DType::U32)?;
260
261            let mut seqlens_q_map = HashMap::new();
262            let mut seqlens_k_map = HashMap::new();
263
264            let devices = mapper.unwrap().get_unique_devices();
265            for device in devices {
266                seqlens_q_map.insert(device.location(), seqlens_q.to_device(&device)?);
267                seqlens_k_map.insert(device.location(), seqlens_k.to_device(&device)?);
268            }
269            (max_q, max_k, seqlens_q_map, seqlens_k_map)
270        } else {
271            (0, 0, HashMap::new(), HashMap::new())
272        };
273
274        let input = Tensor::cat(&seqs_tensors, 0).unwrap();
275
276        let paged_attn_meta = if paged_attn_metadata.is_some() {
277            let max_slot_mapping_len = slot_mappings.iter().map(|x| x.len()).max().unwrap();
278            let slot_mappings =
279                _make_tensor_with_pad(slot_mappings, max_slot_mapping_len, _PAD_SLOT_ID, device)?;
280
281            let max_block_table_len = block_tables.iter().map(|x| x.len()).max().unwrap();
282            let block_tables = _make_tensor_with_pad(
283                block_tables
284                    .iter()
285                    .map(|x| x.iter().map(|x| *x as u32).collect::<Vec<_>>())
286                    .collect::<Vec<_>>(),
287                max_block_table_len,
288                0,
289                device,
290            )?;
291            let block_tables = block_tables.reshape(((), max_block_table_len))?;
292
293            let max_context_len = paged_attn_context_lens
294                .iter()
295                .map(|x| x.len())
296                .max()
297                .unwrap();
298
299            let context_lens = _make_tensor_with_pad(
300                paged_attn_context_lens
301                    .iter()
302                    .map(|x| x.iter().map(|x| *x as u32).collect::<Vec<_>>())
303                    .collect::<Vec<_>>(),
304                max_context_len,
305                0,
306                device,
307            )?
308            .reshape(((),))?;
309
310            // For device mapping, make a copy of each tensor for each device
311            let devices = mapper.unwrap().get_unique_devices();
312            let mut slot_mappings_map = HashMap::new();
313            let mut block_tables_map = HashMap::new();
314            let mut context_lens_map = HashMap::new();
315
316            for device in devices {
317                slot_mappings_map
318                    .insert(device.location(), slot_mappings.clone().to_device(&device)?);
319                block_tables_map
320                    .insert(device.location(), block_tables.clone().to_device(&device)?);
321                context_lens_map
322                    .insert(device.location(), context_lens.clone().to_device(&device)?);
323            }
324
325            Some(PagedAttentionInputMetadata {
326                slot_mappings: slot_mappings_map,
327                block_tables: Some(block_tables_map),
328                context_lens: Some(context_lens_map),
329                max_context_len: Some(max_context_len),
330                is_first_prompt_chunk: chunk_offset_toks == 0,
331            })
332        } else {
333            None
334        };
335
336        Ok(InputMetadata {
337            input,
338            positions: seqlen_offsets,
339            context_lens,
340            position_ids,
341            paged_attn_meta,
342            flash_meta: FlashParams {
343                max_k,
344                max_q,
345                cumulative_seqlens_k: seqlens_k_map,
346                cumulative_seqlens_q: seqlens_q_map,
347                causal: true,
348            },
349        })
350    }
351
    /// Build `InputMetadata` for a single decode step: each sequence feeds only
    /// its last token, attending to its previously cached context.
    fn make_completion_chunk<T: WithDType>(
        toks: Vec<&[T]>,
        input_seqs: &[&mut Sequence],
        device: &Device,
        mut paged_attn_metadata: Option<&mut PagedAttentionMeta>,
        mapper: Option<&dyn DeviceMapper>,
    ) -> Result<InputMetadata> {
        // Pad each sequence by the padding token to the max len.
        let flash_attn = crate::using_flash_attn();
        let mut seqs_tensors = Vec::new();
        let mut seqlen_offsets = Vec::new();
        let mut context_lens = Vec::new();
        let mut position_ids = Vec::new();

        let mut slot_mappings = Vec::new();
        let mut block_tables = Vec::new();
        let mut paged_attn_context_lens = Vec::new();
        // Flash attention consumes cumulative lengths, so seed with 0.
        let mut seqlens_q = if flash_attn { vec![0] } else { Vec::new() };
        let mut seqlens_k = if flash_attn { vec![0] } else { Vec::new() };
        for (seq, ctxt) in input_seqs.iter().zip(toks) {
            // Only the final token is fed to the model; everything before it
            // is already in the KV cache.
            let start_pos = ctxt.len().saturating_sub(1);
            let ctxt = ctxt[start_pos..].to_vec();
            seqlen_offsets.push(start_pos);
            // Keep exactly one logit per sequence.
            context_lens.push((0, 1));
            position_ids.push(seq.len());

            if flash_attn {
                seqlens_q.push(ctxt.len() as u32);
                seqlens_k.push((ctxt.len() + start_pos) as u32);
            }

            seqs_tensors.push(Tensor::new(ctxt, device).unwrap().unsqueeze(0).unwrap());

            if let Some(paged_attn_metadata) = &mut paged_attn_metadata {
                let block_engine = get_mut_arcmutex!(paged_attn_metadata.block_engine);
                let table = block_engine.block_tables.get(seq.id()).unwrap();

                let table = table
                    .iter()
                    .map(|block| block.deref_mut().block_id)
                    .collect::<Vec<_>>();

                // NOTE(review): `token_offset` is subtracted so the block lookup
                // is relative to this sequence's table start — confirm against
                // `Sequence::token_offset` semantics.
                let block_pos = start_pos - seq.token_offset();
                let block_number = if block_pos / paged_attn_metadata.block_size >= table.len() {
                    panic!("Block table is too small (completion)! start_pos={} block_size={} table_len={}", block_pos, paged_attn_metadata.block_size, table.len());
                } else {
                    table
                        .get(block_pos / paged_attn_metadata.block_size)
                        .unwrap()
                };
                let block_offset = block_pos % paged_attn_metadata.block_size;
                let slot = block_number * paged_attn_metadata.block_size + block_offset;
                let slot = slot.try_into().unwrap();
                slot_mappings.push(vec![slot]);

                if let Some(sliding_window) = paged_attn_metadata.sliding_window {
                    // Drop leading blocks that can no longer fall inside the window.
                    let sliding_window_blocks = sliding_window / paged_attn_metadata.block_size;
                    let slide_idx = if table.len() > sliding_window_blocks {
                        table.len() - sliding_window_blocks
                    } else {
                        0
                    };
                    block_tables.push(table.get(slide_idx..).unwrap().to_vec());
                } else {
                    block_tables.push(table);
                }

                // Visible context is capped by the sliding window, if any.
                let paged_attn_context_len =
                    if let Some(sliding_window) = paged_attn_metadata.sliding_window {
                        seq.len().min(sliding_window)
                    } else {
                        seq.len()
                    };
                paged_attn_context_lens.push(paged_attn_context_len);
            }
        }

        let (max_q, max_k, seqlens_q_map, seqlens_k_map) = if flash_attn {
            let max_q = *seqlens_q.iter().max().unwrap();
            let max_k = *seqlens_k.iter().max().unwrap();
            // Cumulative sums are computed in f32 then cast back to u32.
            let seqlens_q = Tensor::new(seqlens_q, device)?
                .to_dtype(DType::F32)?
                .cumsum(0)?
                .to_dtype(DType::U32)?;
            let seqlens_k = Tensor::new(seqlens_k, device)?
                .to_dtype(DType::F32)?
                .cumsum(0)?
                .to_dtype(DType::U32)?;

            let mut seqlens_q_map = HashMap::new();
            let mut seqlens_k_map = HashMap::new();

            // One copy per unique device for device-mapped models.
            let devices = mapper.unwrap().get_unique_devices();
            for device in devices {
                seqlens_q_map.insert(device.location(), seqlens_q.to_device(&device)?);
                seqlens_k_map.insert(device.location(), seqlens_k.to_device(&device)?);
            }
            (max_q, max_k, seqlens_q_map, seqlens_k_map)
        } else {
            (0, 0, HashMap::new(), HashMap::new())
        };

        let paged_attn_meta = if paged_attn_metadata.is_some() {
            // A decode step writes exactly one slot per sequence.
            let slot_mappings = _make_tensor_with_pad(slot_mappings, 1, _PAD_SLOT_ID, device)?;

            let max_block_table_len = block_tables.iter().map(|x| x.len()).max().unwrap();

            let block_tables = _make_tensor_with_pad(
                block_tables
                    .iter()
                    .map(|x| x.iter().map(|x| *x as u32).collect::<Vec<_>>())
                    .collect::<Vec<_>>(),
                max_block_table_len,
                0,
                device,
            )?;
            let block_tables = block_tables.reshape(((), max_block_table_len))?;

            let max_context_len = paged_attn_context_lens.iter().max().unwrap();

            let context_lens = Tensor::from_vec(
                paged_attn_context_lens
                    .iter()
                    .map(|x| *x as u32)
                    .collect::<Vec<_>>(),
                (paged_attn_context_lens.len(),),
                device,
            )?;

            // For device mapping, make a copy of each tensor for each device
            let devices = mapper.unwrap().get_unique_devices();
            let mut slot_mappings_map = HashMap::new();
            let mut block_tables_map = HashMap::new();
            let mut context_lens_map = HashMap::new();

            for device in devices {
                slot_mappings_map
                    .insert(device.location(), slot_mappings.clone().to_device(&device)?);
                block_tables_map
                    .insert(device.location(), block_tables.clone().to_device(&device)?);
                context_lens_map
                    .insert(device.location(), context_lens.clone().to_device(&device)?);
            }

            Some(PagedAttentionInputMetadata {
                slot_mappings: slot_mappings_map,
                block_tables: Some(block_tables_map),
                context_lens: Some(context_lens_map),
                max_context_len: Some(*max_context_len),
                // Decode steps are never the first prompt chunk.
                is_first_prompt_chunk: false,
            })
        } else {
            None
        };

        Ok(InputMetadata {
            input: Tensor::cat(&seqs_tensors, 0).unwrap(),
            positions: seqlen_offsets,
            context_lens,
            position_ids,
            paged_attn_meta,
            flash_meta: FlashParams {
                max_k,
                max_q,
                cumulative_seqlens_k: seqlens_k_map,
                cumulative_seqlens_q: seqlens_q_map,
                causal: true,
            },
        })
    }
522
523    #[allow(clippy::too_many_arguments)]
524    pub(crate) fn get_prompt_input<T: WithDType + std::fmt::Debug>(
525        toks: Vec<&[T]>,
526        input_seqs: &[&mut Sequence],
527        device: &Device,
528        last_n_context_len: Option<(usize, usize)>,
529        return_raw_logits: bool,
530        paged_attn_metadata: Option<&mut PagedAttentionMeta>,
531        mapper: Option<&dyn DeviceMapper>,
532    ) -> Result<InnerInputProcessorOutput> {
533        let offset = input_seqs[0].token_offset();
534        make_prompt_chunk(
535            offset,
536            toks,
537            &input_seqs.iter().map(|s| *s.id()).collect::<Vec<_>>(),
538            device,
539            last_n_context_len,
540            return_raw_logits,
541            paged_attn_metadata,
542            mapper,
543        )
544        .map(|inputs| InnerInputProcessorOutput {
545            inputs,
546            seq_indices: (0..input_seqs.len()).collect(),
547        })
548    }
549
550    #[allow(clippy::too_many_arguments)]
551    pub(crate) fn get_completion_input<T: WithDType + std::fmt::Debug>(
552        toks: Vec<&[T]>,
553        input_seqs: &[&mut Sequence],
554        device: &Device,
555        no_kv_cache: bool,
556        last_n_context_len: Option<(usize, usize)>,
557        return_raw_logits: bool,
558        paged_attn_metadata: Option<&mut PagedAttentionMeta>,
559        mapper: Option<&dyn DeviceMapper>,
560    ) -> Result<InnerInputProcessorOutput> {
561        if no_kv_cache {
562            return get_prompt_input(
563                toks,
564                input_seqs,
565                device,
566                last_n_context_len,
567                return_raw_logits,
568                paged_attn_metadata,
569                mapper,
570            );
571        }
572
573        make_completion_chunk(toks, input_seqs, device, paged_attn_metadata, mapper).map(|inputs| {
574            InnerInputProcessorOutput {
575                inputs,
576                seq_indices: (0..input_seqs.len()).collect(),
577            }
578        })
579    }
580
    /// Concrete inputs for text models; this is the type boxed into
    /// `InputProcessorOutput::inputs` by `TextInputsProcessor`.
    #[derive(Clone)]
    pub struct ModelInputs {
        /// Token ids for this step, shape (batch, padded_seq_len).
        pub input_ids: Tensor,
        /// Full-prompt token ids; populated only on X-LoRA paths.
        pub input_ids_full: Option<Tensor>,
        /// Per-sequence KV-cache offsets.
        pub seqlen_offsets: Vec<usize>,
        /// Offsets matching `input_ids_full` (X-LoRA only).
        pub seqlen_offsets_full: Option<Vec<usize>>,
        /// (start index, len) pairs selecting which logits to keep.
        pub context_lens: Vec<(usize, usize)>,
        /// Per-sequence total position counts.
        pub position_ids: Vec<usize>,
        /// PagedAttention metadata, when enabled.
        pub paged_attn_meta: Option<PagedAttentionInputMetadata>,
        /// Flash-attention cumulative-length metadata.
        pub flash_meta: FlashParams,
        /// Flash metadata for the full-prompt pass (X-LoRA only).
        pub flash_meta_full: Option<FlashParams>,
    }
593
    /// [`InputsProcessor`] for text-only models: packs tokenized sequences
    /// into boxed `ModelInputs`.
    pub struct TextInputsProcessor;
595
596    impl InputsProcessor for TextInputsProcessor {
597        fn process_inputs(
598            &self,
599            _: Option<Arc<Tokenizer>>,
600            input_seqs: &mut [&mut Sequence],
601            is_prompt: bool,
602            is_xlora: bool,
603            device: &Device,
604            no_kv_cache: bool,
605            last_n_context_len: Option<(usize, usize)>,
606            return_raw_logits: bool,
607            _: Option<Arc<dyn Any>>,
608            mut paged_attn_metadata: Option<PagedAttentionMeta>,
609            mapper: Option<&dyn DeviceMapper>,
610        ) -> Result<InputProcessorOutput> {
611            if is_xlora && !is_prompt {
612                let prompt = get_prompt_input(
613                    input_seqs
614                        .iter()
615                        .map(|seq| seq.get_toks())
616                        .collect::<Vec<_>>(),
617                    input_seqs,
618                    device,
619                    last_n_context_len,
620                    return_raw_logits,
621                    paged_attn_metadata.as_mut(),
622                    mapper,
623                )?;
624                let completion = get_completion_input(
625                    input_seqs
626                        .iter()
627                        .map(|seq| seq.get_toks())
628                        .collect::<Vec<_>>(),
629                    input_seqs,
630                    device,
631                    no_kv_cache,
632                    last_n_context_len,
633                    return_raw_logits,
634                    paged_attn_metadata.as_mut(),
635                    mapper,
636                )?;
637                let InnerInputProcessorOutput {
638                    inputs:
639                        InputMetadata {
640                            input: input_ids_full,
641                            positions: seqlen_offsets_full,
642                            context_lens: _,
643                            position_ids,
644                            paged_attn_meta: _,
645                            flash_meta: flash_meta_full,
646                        },
647                    seq_indices,
648                } = prompt;
649                let InnerInputProcessorOutput {
650                    inputs:
651                        InputMetadata {
652                            input: input_ids,
653                            positions: seqlen_offsets,
654                            context_lens,
655                            position_ids: _,
656                            paged_attn_meta,
657                            flash_meta,
658                        },
659                    seq_indices: _,
660                } = completion;
661                let inputs: Box<dyn Any> = Box::new(ModelInputs {
662                    input_ids,
663                    input_ids_full: Some(input_ids_full),
664                    seqlen_offsets,
665                    seqlen_offsets_full: Some(seqlen_offsets_full),
666                    context_lens,
667                    position_ids,
668                    paged_attn_meta,
669                    flash_meta,
670                    flash_meta_full: Some(flash_meta_full),
671                });
672                Ok(InputProcessorOutput {
673                    inputs,
674                    seq_indices,
675                })
676            } else if is_xlora && is_prompt {
677                let metadata = get_prompt_input(
678                    input_seqs
679                        .iter()
680                        .map(|seq| seq.get_toks())
681                        .collect::<Vec<_>>(),
682                    input_seqs,
683                    device,
684                    last_n_context_len,
685                    return_raw_logits,
686                    paged_attn_metadata.as_mut(),
687                    mapper,
688                )?;
689                let InnerInputProcessorOutput {
690                    inputs:
691                        InputMetadata {
692                            input: input_ids,
693                            positions: seqlen_offsets,
694                            context_lens,
695                            position_ids,
696                            paged_attn_meta,
697                            flash_meta,
698                        },
699                    seq_indices,
700                } = metadata;
701                let inputs: Box<dyn Any> = Box::new(ModelInputs {
702                    input_ids: input_ids.clone(),
703                    input_ids_full: Some(input_ids),
704                    seqlen_offsets: seqlen_offsets.clone(),
705                    seqlen_offsets_full: Some(seqlen_offsets),
706                    context_lens,
707                    position_ids,
708                    paged_attn_meta,
709                    flash_meta: flash_meta.clone(),
710                    flash_meta_full: Some(flash_meta),
711                });
712                Ok(InputProcessorOutput {
713                    inputs,
714                    seq_indices,
715                })
716            } else if is_prompt {
717                let metadata = get_prompt_input(
718                    input_seqs
719                        .iter()
720                        .map(|seq| seq.get_toks())
721                        .collect::<Vec<_>>(),
722                    input_seqs,
723                    device,
724                    last_n_context_len,
725                    return_raw_logits,
726                    paged_attn_metadata.as_mut(),
727                    mapper,
728                )?;
729                let InnerInputProcessorOutput {
730                    inputs:
731                        InputMetadata {
732                            input: input_ids,
733                            positions: seqlen_offsets,
734                            context_lens,
735                            position_ids,
736                            paged_attn_meta,
737                            flash_meta,
738                        },
739                    seq_indices,
740                } = metadata;
741                let inputs: Box<dyn Any> = Box::new(ModelInputs {
742                    input_ids,
743                    input_ids_full: None,
744                    seqlen_offsets,
745                    seqlen_offsets_full: None,
746                    context_lens,
747                    position_ids,
748                    paged_attn_meta,
749                    flash_meta,
750                    flash_meta_full: None,
751                });
752                Ok(InputProcessorOutput {
753                    inputs,
754                    seq_indices,
755                })
756            } else {
757                let metadata = get_completion_input(
758                    input_seqs
759                        .iter()
760                        .map(|seq| seq.get_toks())
761                        .collect::<Vec<_>>(),
762                    input_seqs,
763                    device,
764                    no_kv_cache,
765                    last_n_context_len,
766                    return_raw_logits,
767                    paged_attn_metadata.as_mut(),
768                    mapper,
769                )?;
770                let InnerInputProcessorOutput {
771                    inputs:
772                        InputMetadata {
773                            input: input_ids,
774                            positions: seqlen_offsets,
775                            context_lens,
776                            position_ids,
777                            paged_attn_meta,
778                            flash_meta,
779                        },
780                    seq_indices,
781                } = metadata;
782                let inputs: Box<dyn Any> = Box::new(ModelInputs {
783                    input_ids,
784                    input_ids_full: None,
785                    seqlen_offsets,
786                    seqlen_offsets_full: None,
787                    context_lens,
788                    position_ids,
789                    paged_attn_meta,
790                    flash_meta,
791                    flash_meta_full: None,
792                });
793                Ok(InputProcessorOutput {
794                    inputs,
795                    seq_indices,
796                })
797            }
798        }
799
800        fn get_type(&self) -> InputsProcessorType {
801            InputsProcessorType::Text
802        }
803    }
804}