mistralrs_core/pipeline/inputs_processor.rs

#![allow(clippy::cast_possible_truncation)]

use std::{any::Any, sync::Arc};

use anyhow::Result;
use candle_core::Device;
use text_models_inputs_processor::PagedAttentionMeta;
use tokenizers::Tokenizer;

use crate::{device_map::DeviceMapper, sequence::Sequence};

#[derive(PartialEq)]
pub enum InputsProcessorType {
    Text,
    Vision,
    Embedding,
}

pub struct InputProcessorOutput {
    pub inputs: Box<dyn Any>,
    pub seq_indices: Vec<usize>,
}
/// Processor: prepares inputs for the model (including preparing the images, if applicable).
pub trait InputsProcessor {
    /// This should also enable matmul via f16 if this is a prompt and the sequence length is
    /// greater than 32. Otherwise, matmul via f16 is disabled.
    ///
    /// This should return a type which can be downcast to the proper type as used in
    /// `forward_inputs`.
    #[allow(clippy::too_many_arguments)]
    fn process_inputs(
        &self,
        tokenizer: Option<Arc<Tokenizer>>,
        input_seqs: &mut [&mut Sequence],
        is_prompt: bool,
        is_xlora: bool,
        device: &Device,
        no_kv_cache: bool,
        last_n_context_len: Option<(usize, usize)>,
        return_raw_logits: bool,
        other_config: Option<Arc<dyn Any>>,
        paged_attn_metadata: Option<PagedAttentionMeta>,
        mapper: Option<&dyn DeviceMapper>,
    ) -> Result<InputProcessorOutput>;

    fn get_type(&self) -> InputsProcessorType;
}
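
// Hedged usage sketch (an illustration, not engine code): drive an
// `InputsProcessor` and recover the concrete inputs type. `TextInputsProcessor`
// (defined below) erases its output as `Box<dyn Any>`, so callers downcast it
// back to `text_models_inputs_processor::ModelInputs` before `forward_inputs`.
#[allow(dead_code)]
fn example_process_and_downcast(
    processor: &dyn InputsProcessor,
    tokenizer: Option<Arc<Tokenizer>>,
    input_seqs: &mut [&mut Sequence],
    device: &Device,
) -> Result<()> {
    let out = processor.process_inputs(
        tokenizer,
        input_seqs,
        true,  // is_prompt
        false, // is_xlora
        device,
        false, // no_kv_cache
        None,  // last_n_context_len
        false, // return_raw_logits
        None,  // other_config
        None,  // paged_attn_metadata
        None,  // mapper
    )?;
    // `inputs` is type-erased; downcast to the concrete type the model expects.
    let _model_inputs = out
        .inputs
        .downcast::<text_models_inputs_processor::ModelInputs>()
        .map_err(|_| anyhow::anyhow!("unexpected concrete inputs type"))?;
    Ok(())
}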

// ========================= Text models input processor

pub mod text_models_inputs_processor {
    use std::{any::Any, collections::HashMap, fmt::Debug, sync::Arc};

    use anyhow::Result;
    use candle_core::{DType, Device, DeviceLocation, Tensor, WithDType};
    use tokenizers::Tokenizer;

    use crate::{
        device_map::DeviceMapper,
        get_mut_arcmutex,
        paged_attention::{BlockEngine, _PAD_SLOT_ID},
        sequence::Sequence,
    };

    use super::{InputProcessorOutput, InputsProcessor, InputsProcessorType};

    /// Right-pad each row to `max_len` with `pad`, then concatenate the rows
    /// along dim 0, yielding a flat 1-D tensor of length `rows * max_len`.
    fn _make_tensor_with_pad<D: WithDType>(
        x: Vec<Vec<D>>,
        max_len: usize,
        pad: D,
        device: &Device,
    ) -> Result<Tensor> {
        let mut padded_x = Vec::new();
        for mut x_i in x {
            assert!(x_i.len() <= max_len);
            x_i.extend([pad].repeat(max_len - x_i.len()));
            let shape = (x_i.len(),);
            padded_x.push(Tensor::from_vec(x_i, shape, device)?);
        }
        Tensor::cat(&padded_x[..], 0).map_err(anyhow::Error::msg)
    }
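
    // A minimal sketch of the padding helper above: rows are right-padded to
    // `max_len` and then concatenated along dim 0 into a flat 1-D tensor, so
    // [[1, 2, 3], [4]] with max_len 3 and pad 0 becomes [1, 2, 3, 4, 0, 0].
    #[cfg(test)]
    mod make_tensor_with_pad_tests {
        use super::_make_tensor_with_pad;
        use candle_core::Device;

        #[test]
        fn pads_and_flattens() -> anyhow::Result<()> {
            let t =
                _make_tensor_with_pad(vec![vec![1u32, 2, 3], vec![4]], 3, 0u32, &Device::Cpu)?;
            assert_eq!(t.to_vec1::<u32>()?, vec![1, 2, 3, 4, 0, 0]);
            Ok(())
        }
    }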

    /// Per-batch handle on the paged-attention block engine, used while
    /// building the per-sequence slot mappings and block tables.
    pub struct PagedAttentionMeta {
        pub sliding_window: Option<usize>,
        pub block_size: usize,
        pub block_engine: Arc<tokio::sync::Mutex<BlockEngine>>,
    }

    #[derive(Clone, Debug)]
    #[allow(dead_code)]
    pub struct PagedAttentionInputMetadata {
        pub block_tables: Option<HashMap<DeviceLocation, Tensor>>,
        pub context_lens: Option<HashMap<DeviceLocation, Tensor>>,
        pub slot_mappings: HashMap<DeviceLocation, Tensor>,
        pub max_context_len: Option<usize>,
        pub is_first_prompt_chunk: bool,
    }

    impl PagedAttentionInputMetadata {
        /// Create a dummy input metadata, assuming that this will NOT be used for decoding.
        /// This is used for the case of imatrix generation.
        pub fn dummy(dev: &Device) -> candle_core::Result<Self> {
            Ok(PagedAttentionInputMetadata {
                block_tables: None,
                context_lens: None,
                max_context_len: None,
                slot_mappings: HashMap::from([(dev.location(), Tensor::new(&[0f32], dev)?)]),
                is_first_prompt_chunk: true,
            })
        }
    }
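
    // A hedged sketch of `dummy` above: it is used where paged-attention
    // metadata must exist structurally but no KV blocks are allocated (the
    // imatrix-generation path), so only `slot_mappings` carries a placeholder.
    #[cfg(test)]
    mod dummy_metadata_tests {
        use super::PagedAttentionInputMetadata;
        use candle_core::Device;

        #[test]
        fn dummy_has_no_tables_or_context() -> candle_core::Result<()> {
            let meta = PagedAttentionInputMetadata::dummy(&Device::Cpu)?;
            assert!(meta.block_tables.is_none());
            assert!(meta.context_lens.is_none());
            assert!(meta.is_first_prompt_chunk);
            Ok(())
        }
    }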

    #[derive(Clone, Debug)]
    pub struct FlashParams {
        pub max_q: u32,
        pub max_k: u32,
        pub cumulative_seqlens_q: HashMap<DeviceLocation, Tensor>,
        pub cumulative_seqlens_k: HashMap<DeviceLocation, Tensor>,
        pub causal: bool,
    }
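
    // The cumulative sequence lengths above follow the varlen flash-attention
    // convention: a leading 0 followed by prefix sums of the per-sequence
    // lengths, e.g. lengths [3, 5] become [0, 3, 8]. A minimal sketch of the
    // f32 cumsum round-trip used in `make_prompt_chunk`/`make_completion_chunk` below:
    #[cfg(test)]
    mod cumulative_seqlen_tests {
        use candle_core::{DType, Device, Tensor};

        #[test]
        fn prefix_sums_match_varlen_convention() -> anyhow::Result<()> {
            let lens = vec![0u32, 3, 5]; // leading 0, then per-sequence lengths
            let cu = Tensor::new(lens, &Device::Cpu)?
                .to_dtype(DType::F32)?
                .cumsum(0)?
                .to_dtype(DType::U32)?;
            assert_eq!(cu.to_vec1::<u32>()?, vec![0, 3, 8]);
            Ok(())
        }
    }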

    pub struct InputMetadata {
        pub input: Tensor,
        pub positions: Vec<usize>,
        pub context_lens: Vec<(usize, usize)>, // (start index, len)
        pub position_ids: Vec<usize>,
        pub paged_attn_meta: Option<PagedAttentionInputMetadata>, // For paged attention
        pub flash_meta: FlashParams,
    }

    pub struct InnerInputProcessorOutput {
        pub inputs: InputMetadata,
        pub seq_indices: Vec<usize>,
    }

    // `chunk_offset_toks` is the number of prompt tokens already processed in earlier
    // chunks, i.e. this chunk's offset into the prompt; equivalently,
    // chunk_offset_toks / prompt_chunksize gives the number of chunks (batches) processed so far.
    #[allow(clippy::too_many_arguments)]
    pub fn make_prompt_chunk<T: WithDType + Debug>(
        chunk_offset_toks: usize,
        toks: Vec<&[T]>,
        seq_ids: &[usize],
        device: &Device,
        last_n_context_len: Option<(usize, usize)>,
        return_raw_logits: bool,
        mut paged_attn_metadata: Option<&mut PagedAttentionMeta>,
        mapper: Option<&dyn DeviceMapper>,
    ) -> Result<InputMetadata> {
        let max_len = toks
            .iter()
            .map(|seq| seq.len())
            .max()
            .expect("No sequences");
        let padding_tok = T::zero();
        // Pad each sequence to the max len with the padding token.
        let mut seqs_tensors = Vec::new();
        let mut seqlen_offsets = Vec::new();
        let mut context_lens = Vec::new();
        let mut position_ids = Vec::new();
        let mut slot_mappings = Vec::new();
        let mut block_tables = Vec::new();
        let mut paged_attn_context_lens = Vec::new();
        let flash_attn = crate::using_flash_attn();
        let mut seqlens_q = if flash_attn { vec![0] } else { Vec::new() };
        let mut seqlens_k = if flash_attn { vec![0] } else { Vec::new() };
        for (seq_id, ctxt) in seq_ids.iter().zip(toks) {
            let prompt_len = ctxt.len();
            let offset = last_n_context_len.unwrap_or_default();
            seqlen_offsets.push(offset.1 + chunk_offset_toks);

            position_ids.push(ctxt.len() + chunk_offset_toks);
            let mut ctxt = ctxt.to_vec();
            ctxt.extend(std::iter::repeat_n(
                padding_tok,
                max_len.saturating_sub(ctxt.len()),
            ));
            // If we are returning raw logits, do not trim the logits at all.
            if return_raw_logits {
                if last_n_context_len.is_some() {
                    anyhow::bail!("`return_raw_logits` is incompatible with `last_n_context_len`");
                }

                context_lens.push((0, ctxt.len()));
            } else {
                context_lens.push((
                    ctxt.len()
                        .saturating_sub(last_n_context_len.map(|(a, _)| a).unwrap_or(1)),
                    last_n_context_len.map(|(a, _)| a).unwrap_or(1),
                ));
            }

            if flash_attn {
                seqlens_q.push(ctxt.len() as u32);
                seqlens_k.push((ctxt.len() + chunk_offset_toks) as u32);
            }

            seqs_tensors.push(Tensor::new(ctxt, device).unwrap().unsqueeze(0).unwrap());

            if let Some(paged_attn_metadata) = &mut paged_attn_metadata {
                let block_engine = get_mut_arcmutex!(paged_attn_metadata.block_engine);
                let table = block_engine.block_tables.get(seq_id);

                if table.is_none() {
                    // Will be None during profiling.
                    slot_mappings.push([_PAD_SLOT_ID].repeat(prompt_len));
                    continue;
                }
                let table = table
                    .unwrap()
                    .iter()
                    .map(|block| block.deref_mut().block_id)
                    .collect::<Vec<_>>();

                let start_idx = if let Some(sliding_window) = paged_attn_metadata.sliding_window {
                    prompt_len.saturating_sub(sliding_window)
                } else {
                    0
                };

                let mut slot_mapping = Vec::new();
                let mut ctxt_len = Vec::new();
                for i in chunk_offset_toks..prompt_len + chunk_offset_toks {
                    if i < start_idx {
                        // Pad [0, start_idx) with _PAD_SLOT_ID.
                        slot_mapping.push(_PAD_SLOT_ID);
                    }
                    ctxt_len.push(i);

                    let block_number = if i / paged_attn_metadata.block_size >= table.len() {
                        panic!(
                            "Block table is too small (prompt)! i={} block_size={} table_len={}",
                            i,
                            paged_attn_metadata.block_size,
                            table.len()
                        );
                    } else {
                        table.get(i / paged_attn_metadata.block_size).unwrap()
                    };
                    let block_offset = i % paged_attn_metadata.block_size;
                    // Use checked arithmetic to prevent overflow.
                    let slot = block_number
                        .checked_mul(paged_attn_metadata.block_size)
                        .and_then(|v| v.checked_add(block_offset))
                        .expect("Slot calculation overflowed");
                    slot_mapping.push(
                        slot.try_into()
                            .expect("Slot value too large for target integer type"),
                    );
                    block_tables.push(table.clone());
                }
                slot_mappings.push(slot_mapping);
                paged_attn_context_lens.push(ctxt_len);
            }
        }

        let (max_q, max_k, seqlens_q_map, seqlens_k_map) = if flash_attn {
            // SAFETY: seqlens_q/k are initialized with vec![0] when flash_attn is true,
            // so they are guaranteed to be non-empty here.
            let max_q = *seqlens_q
                .iter()
                .max()
                .expect("seqlens_q should not be empty when flash_attn is enabled");
            let max_k = *seqlens_k
                .iter()
                .max()
                .expect("seqlens_k should not be empty when flash_attn is enabled");
            let seqlens_q = Tensor::new(seqlens_q, device)?
                .to_dtype(DType::F32)?
                .cumsum(0)?
                .to_dtype(DType::U32)?;
            let seqlens_k = Tensor::new(seqlens_k, device)?
                .to_dtype(DType::F32)?
                .cumsum(0)?
                .to_dtype(DType::U32)?;

            let mut seqlens_q_map = HashMap::new();
            let mut seqlens_k_map = HashMap::new();

            let devices = mapper.unwrap().get_unique_devices();
            for device in devices {
                seqlens_q_map.insert(device.location(), seqlens_q.to_device(&device)?);
                seqlens_k_map.insert(device.location(), seqlens_k.to_device(&device)?);
            }
            (max_q, max_k, seqlens_q_map, seqlens_k_map)
        } else {
            (0, 0, HashMap::new(), HashMap::new())
        };

        let input = Tensor::cat(&seqs_tensors, 0).unwrap();

        let paged_attn_meta = if paged_attn_metadata.is_some() {
            let max_slot_mapping_len = slot_mappings.iter().map(|x| x.len()).max().unwrap();
            let slot_mappings =
                _make_tensor_with_pad(slot_mappings, max_slot_mapping_len, _PAD_SLOT_ID, device)?;

            let max_block_table_len = block_tables.iter().map(|x| x.len()).max().unwrap();
            let block_tables = _make_tensor_with_pad(
                block_tables
                    .iter()
                    .map(|x| x.iter().map(|x| *x as u32).collect::<Vec<_>>())
                    .collect::<Vec<_>>(),
                max_block_table_len,
                0,
                device,
            )?;
            let block_tables = block_tables.reshape(((), max_block_table_len))?;

            let max_context_len = paged_attn_context_lens
                .iter()
                .map(|x| x.len())
                .max()
                .unwrap();

            let context_lens = _make_tensor_with_pad(
                paged_attn_context_lens
                    .iter()
                    .map(|x| x.iter().map(|x| *x as u32).collect::<Vec<_>>())
                    .collect::<Vec<_>>(),
                max_context_len,
                0,
                device,
            )?
            .reshape(((),))?;

            // For device mapping, make a copy of each tensor for each device.
            let devices = mapper.unwrap().get_unique_devices();
            let mut slot_mappings_map = HashMap::new();
            let mut block_tables_map = HashMap::new();
            let mut context_lens_map = HashMap::new();

            for device in devices {
                slot_mappings_map
                    .insert(device.location(), slot_mappings.clone().to_device(&device)?);
                block_tables_map
                    .insert(device.location(), block_tables.clone().to_device(&device)?);
                context_lens_map
                    .insert(device.location(), context_lens.clone().to_device(&device)?);
            }

            Some(PagedAttentionInputMetadata {
                slot_mappings: slot_mappings_map,
                block_tables: Some(block_tables_map),
                context_lens: Some(context_lens_map),
                max_context_len: Some(max_context_len),
                is_first_prompt_chunk: chunk_offset_toks == 0,
            })
        } else {
            None
        };

        Ok(InputMetadata {
            input,
            positions: seqlen_offsets,
            context_lens,
            position_ids,
            paged_attn_meta,
            flash_meta: FlashParams {
                max_k,
                max_q,
                cumulative_seqlens_k: seqlens_k_map,
                cumulative_seqlens_q: seqlens_q_map,
                causal: true,
            },
        })
    }
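
    // Worked example of the slot arithmetic above (a sketch, not engine code):
    // with block_size = 16, token i = 35 falls in logical block 35 / 16 = 2; if
    // that logical block maps to physical block id 7, the slot is 7 * 16 + 35 % 16 = 115.
    #[cfg(test)]
    mod slot_arithmetic_tests {
        #[test]
        fn maps_token_index_to_slot() {
            let block_size: usize = 16;
            let table: Vec<usize> = vec![3, 9, 7]; // hypothetical physical block ids
            let i: usize = 35; // absolute token position
            let block_number = table[i / block_size];
            let slot = block_number
                .checked_mul(block_size)
                .and_then(|v| v.checked_add(i % block_size))
                .expect("Slot calculation overflowed");
            assert_eq!(slot, 115);
        }
    }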

    fn make_completion_chunk<T: WithDType>(
        toks: Vec<&[T]>,
        input_seqs: &[&mut Sequence],
        device: &Device,
        mut paged_attn_metadata: Option<&mut PagedAttentionMeta>,
        mapper: Option<&dyn DeviceMapper>,
    ) -> Result<InputMetadata> {
        // Decode step: only the last token of each sequence is fed to the model.
        let flash_attn = crate::using_flash_attn();
        let mut seqs_tensors = Vec::new();
        let mut seqlen_offsets = Vec::new();
        let mut context_lens = Vec::new();
        let mut position_ids = Vec::new();

        let mut slot_mappings = Vec::new();
        let mut block_tables = Vec::new();
        let mut paged_attn_context_lens = Vec::new();
        let mut seqlens_q = if flash_attn { vec![0] } else { Vec::new() };
        let mut seqlens_k = if flash_attn { vec![0] } else { Vec::new() };
        for (seq, ctxt) in input_seqs.iter().zip(toks) {
            let start_pos = ctxt.len().saturating_sub(1);
            let ctxt = ctxt[start_pos..].to_vec();
            seqlen_offsets.push(start_pos);
            context_lens.push((0, 1));
            position_ids.push(seq.len());

            if flash_attn {
                seqlens_q.push(ctxt.len() as u32);
                seqlens_k.push((ctxt.len() + start_pos) as u32);
            }

            seqs_tensors.push(Tensor::new(ctxt, device).unwrap().unsqueeze(0).unwrap());

            if let Some(paged_attn_metadata) = &mut paged_attn_metadata {
                let block_engine = get_mut_arcmutex!(paged_attn_metadata.block_engine);
                let table = block_engine.block_tables.get(seq.id()).unwrap();

                let table = table
                    .iter()
                    .map(|block| block.deref_mut().block_id)
                    .collect::<Vec<_>>();

                let block_pos = start_pos - seq.token_offset();
                let block_number = if block_pos / paged_attn_metadata.block_size >= table.len() {
                    panic!(
                        "Block table is too small (completion)! block_pos={} block_size={} table_len={}",
                        block_pos,
                        paged_attn_metadata.block_size,
                        table.len()
                    );
                } else {
                    table
                        .get(block_pos / paged_attn_metadata.block_size)
                        .unwrap()
                };
                let block_offset = block_pos % paged_attn_metadata.block_size;
                // Use checked arithmetic to prevent overflow.
                let slot = block_number
                    .checked_mul(paged_attn_metadata.block_size)
                    .and_then(|v| v.checked_add(block_offset))
                    .expect("Slot calculation overflowed");
                let slot = slot
                    .try_into()
                    .expect("Slot value too large for target integer type");
                slot_mappings.push(vec![slot]);

                if let Some(sliding_window) = paged_attn_metadata.sliding_window {
                    let sliding_window_blocks = sliding_window / paged_attn_metadata.block_size;
                    let slide_idx = table.len().saturating_sub(sliding_window_blocks);
                    block_tables.push(table.get(slide_idx..).unwrap().to_vec());
                } else {
                    block_tables.push(table);
                }

                let paged_attn_context_len =
                    if let Some(sliding_window) = paged_attn_metadata.sliding_window {
                        seq.len().min(sliding_window)
                    } else {
                        seq.len()
                    };
                paged_attn_context_lens.push(paged_attn_context_len);
            }
        }

        let (max_q, max_k, seqlens_q_map, seqlens_k_map) = if flash_attn {
            // SAFETY: seqlens_q/k are initialized with vec![0] when flash_attn is true,
            // so they are guaranteed to be non-empty here.
            let max_q = *seqlens_q
                .iter()
                .max()
                .expect("seqlens_q should not be empty when flash_attn is enabled");
            let max_k = *seqlens_k
                .iter()
                .max()
                .expect("seqlens_k should not be empty when flash_attn is enabled");
            let seqlens_q = Tensor::new(seqlens_q, device)?
                .to_dtype(DType::F32)?
                .cumsum(0)?
                .to_dtype(DType::U32)?;
            let seqlens_k = Tensor::new(seqlens_k, device)?
                .to_dtype(DType::F32)?
                .cumsum(0)?
                .to_dtype(DType::U32)?;

            let mut seqlens_q_map = HashMap::new();
            let mut seqlens_k_map = HashMap::new();

            let devices = mapper.unwrap().get_unique_devices();
            for device in devices {
                seqlens_q_map.insert(device.location(), seqlens_q.to_device(&device)?);
                seqlens_k_map.insert(device.location(), seqlens_k.to_device(&device)?);
            }
            (max_q, max_k, seqlens_q_map, seqlens_k_map)
        } else {
            (0, 0, HashMap::new(), HashMap::new())
        };

        let paged_attn_meta = if paged_attn_metadata.is_some() {
            let slot_mappings = _make_tensor_with_pad(slot_mappings, 1, _PAD_SLOT_ID, device)?;

            let max_block_table_len = block_tables
                .iter()
                .map(|x| x.len())
                .max()
                .expect("block_tables should not be empty when paged attention is enabled");

            let block_tables = _make_tensor_with_pad(
                block_tables
                    .iter()
                    .map(|x| x.iter().map(|x| *x as u32).collect::<Vec<_>>())
                    .collect::<Vec<_>>(),
                max_block_table_len,
                0,
                device,
            )?;
            let block_tables = block_tables.reshape(((), max_block_table_len))?;

            let max_context_len = paged_attn_context_lens.iter().max().unwrap();

            let context_lens = Tensor::from_vec(
                paged_attn_context_lens
                    .iter()
                    .map(|x| *x as u32)
                    .collect::<Vec<_>>(),
                (paged_attn_context_lens.len(),),
                device,
            )?;

            // For device mapping, make a copy of each tensor for each device.
            let devices = mapper.unwrap().get_unique_devices();
            let mut slot_mappings_map = HashMap::new();
            let mut block_tables_map = HashMap::new();
            let mut context_lens_map = HashMap::new();

            for device in devices {
                slot_mappings_map
                    .insert(device.location(), slot_mappings.clone().to_device(&device)?);
                block_tables_map
                    .insert(device.location(), block_tables.clone().to_device(&device)?);
                context_lens_map
                    .insert(device.location(), context_lens.clone().to_device(&device)?);
            }

            Some(PagedAttentionInputMetadata {
                slot_mappings: slot_mappings_map,
                block_tables: Some(block_tables_map),
                context_lens: Some(context_lens_map),
                max_context_len: Some(*max_context_len),
                is_first_prompt_chunk: false,
            })
        } else {
            None
        };

        Ok(InputMetadata {
            input: Tensor::cat(&seqs_tensors, 0).unwrap(),
            positions: seqlen_offsets,
            context_lens,
            position_ids,
            paged_attn_meta,
            flash_meta: FlashParams {
                max_k,
                max_q,
                cumulative_seqlens_k: seqlens_k_map,
                cumulative_seqlens_q: seqlens_q_map,
                causal: true,
            },
        })
    }
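
    // A hedged sketch of the per-device replication pattern used above: tensors
    // consumed by device-mapped layers are cloned to every unique device and
    // later looked up by `DeviceLocation` at dispatch time.
    #[cfg(test)]
    mod device_replication_tests {
        use candle_core::{Device, Tensor};
        use std::collections::HashMap;

        #[test]
        fn replicates_per_unique_device() -> candle_core::Result<()> {
            let devices = vec![Device::Cpu]; // single-device stand-in for a mapper
            let t = Tensor::new(&[1u32, 2, 3], &Device::Cpu)?;
            let mut map = HashMap::new();
            for device in &devices {
                map.insert(device.location(), t.clone().to_device(device)?);
            }
            assert!(map.contains_key(&Device::Cpu.location()));
            Ok(())
        }
    }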

    #[allow(clippy::too_many_arguments)]
    pub(crate) fn get_prompt_input<T: WithDType + std::fmt::Debug>(
        toks: Vec<&[T]>,
        input_seqs: &[&mut Sequence],
        device: &Device,
        last_n_context_len: Option<(usize, usize)>,
        return_raw_logits: bool,
        paged_attn_metadata: Option<&mut PagedAttentionMeta>,
        mapper: Option<&dyn DeviceMapper>,
    ) -> Result<InnerInputProcessorOutput> {
        let offset = input_seqs[0].token_offset();
        make_prompt_chunk(
            offset,
            toks,
            &input_seqs.iter().map(|s| *s.id()).collect::<Vec<_>>(),
            device,
            last_n_context_len,
            return_raw_logits,
            paged_attn_metadata,
            mapper,
        )
        .map(|inputs| InnerInputProcessorOutput {
            inputs,
            seq_indices: (0..input_seqs.len()).collect(),
        })
    }

    #[allow(clippy::too_many_arguments)]
    pub(crate) fn get_completion_input<T: WithDType + std::fmt::Debug>(
        toks: Vec<&[T]>,
        input_seqs: &[&mut Sequence],
        device: &Device,
        no_kv_cache: bool,
        last_n_context_len: Option<(usize, usize)>,
        return_raw_logits: bool,
        paged_attn_metadata: Option<&mut PagedAttentionMeta>,
        mapper: Option<&dyn DeviceMapper>,
    ) -> Result<InnerInputProcessorOutput> {
        if no_kv_cache {
            // Without a KV cache, every step must re-run the full prompt.
            return get_prompt_input(
                toks,
                input_seqs,
                device,
                last_n_context_len,
                return_raw_logits,
                paged_attn_metadata,
                mapper,
            );
        }

        make_completion_chunk(toks, input_seqs, device, paged_attn_metadata, mapper).map(|inputs| {
            InnerInputProcessorOutput {
                inputs,
                seq_indices: (0..input_seqs.len()).collect(),
            }
        })
    }

    #[derive(Clone)]
    pub struct ModelInputs {
        pub input_ids: Tensor,
        pub input_ids_full: Option<Tensor>,
        pub seqlen_offsets: Vec<usize>,
        pub seqlen_offsets_full: Option<Vec<usize>>,
        pub context_lens: Vec<(usize, usize)>,
        pub position_ids: Vec<usize>,
        pub paged_attn_meta: Option<PagedAttentionInputMetadata>,
        pub flash_meta: FlashParams,
        pub flash_meta_full: Option<FlashParams>,
    }

    pub struct TextInputsProcessor;

    impl InputsProcessor for TextInputsProcessor {
        fn process_inputs(
            &self,
            _: Option<Arc<Tokenizer>>,
            input_seqs: &mut [&mut Sequence],
            is_prompt: bool,
            is_xlora: bool,
            device: &Device,
            no_kv_cache: bool,
            last_n_context_len: Option<(usize, usize)>,
            return_raw_logits: bool,
            _: Option<Arc<dyn Any>>,
            mut paged_attn_metadata: Option<PagedAttentionMeta>,
            mapper: Option<&dyn DeviceMapper>,
        ) -> Result<InputProcessorOutput> {
            if is_xlora && !is_prompt {
                // X-LoRA decode: the model needs both the full-prompt inputs and the
                // usual last-token completion inputs.
                let prompt = get_prompt_input(
                    input_seqs
                        .iter()
                        .map(|seq| seq.get_toks())
                        .collect::<Vec<_>>(),
                    input_seqs,
                    device,
                    last_n_context_len,
                    return_raw_logits,
                    paged_attn_metadata.as_mut(),
                    mapper,
                )?;
                let completion = get_completion_input(
                    input_seqs
                        .iter()
                        .map(|seq| seq.get_toks())
                        .collect::<Vec<_>>(),
                    input_seqs,
                    device,
                    no_kv_cache,
                    last_n_context_len,
                    return_raw_logits,
                    paged_attn_metadata.as_mut(),
                    mapper,
                )?;
                let InnerInputProcessorOutput {
                    inputs:
                        InputMetadata {
                            input: input_ids_full,
                            positions: seqlen_offsets_full,
                            context_lens: _,
                            position_ids,
                            paged_attn_meta: _,
                            flash_meta: flash_meta_full,
                        },
                    seq_indices,
                } = prompt;
                let InnerInputProcessorOutput {
                    inputs:
                        InputMetadata {
                            input: input_ids,
                            positions: seqlen_offsets,
                            context_lens,
                            position_ids: _,
                            paged_attn_meta,
                            flash_meta,
                        },
                    seq_indices: _,
                } = completion;
                let inputs: Box<dyn Any> = Box::new(ModelInputs {
                    input_ids,
                    input_ids_full: Some(input_ids_full),
                    seqlen_offsets,
                    seqlen_offsets_full: Some(seqlen_offsets_full),
                    context_lens,
                    position_ids,
                    paged_attn_meta,
                    flash_meta,
                    flash_meta_full: Some(flash_meta_full),
                });
                Ok(InputProcessorOutput {
                    inputs,
                    seq_indices,
                })
            } else if is_xlora && is_prompt {
                // X-LoRA prompt: the prompt inputs double as the "full" inputs.
                let metadata = get_prompt_input(
                    input_seqs
                        .iter()
                        .map(|seq| seq.get_toks())
                        .collect::<Vec<_>>(),
                    input_seqs,
                    device,
                    last_n_context_len,
                    return_raw_logits,
                    paged_attn_metadata.as_mut(),
                    mapper,
                )?;
                let InnerInputProcessorOutput {
                    inputs:
                        InputMetadata {
                            input: input_ids,
                            positions: seqlen_offsets,
                            context_lens,
                            position_ids,
                            paged_attn_meta,
                            flash_meta,
                        },
                    seq_indices,
                } = metadata;
                let inputs: Box<dyn Any> = Box::new(ModelInputs {
                    input_ids: input_ids.clone(),
                    input_ids_full: Some(input_ids),
                    seqlen_offsets: seqlen_offsets.clone(),
                    seqlen_offsets_full: Some(seqlen_offsets),
                    context_lens,
                    position_ids,
                    paged_attn_meta,
                    flash_meta: flash_meta.clone(),
                    flash_meta_full: Some(flash_meta),
                });
                Ok(InputProcessorOutput {
                    inputs,
                    seq_indices,
                })
            } else if is_prompt {
                // Plain prompt processing.
                let metadata = get_prompt_input(
                    input_seqs
                        .iter()
                        .map(|seq| seq.get_toks())
                        .collect::<Vec<_>>(),
                    input_seqs,
                    device,
                    last_n_context_len,
                    return_raw_logits,
                    paged_attn_metadata.as_mut(),
                    mapper,
                )?;
                let InnerInputProcessorOutput {
                    inputs:
                        InputMetadata {
                            input: input_ids,
                            positions: seqlen_offsets,
                            context_lens,
                            position_ids,
                            paged_attn_meta,
                            flash_meta,
                        },
                    seq_indices,
                } = metadata;
                let inputs: Box<dyn Any> = Box::new(ModelInputs {
                    input_ids,
                    input_ids_full: None,
                    seqlen_offsets,
                    seqlen_offsets_full: None,
                    context_lens,
                    position_ids,
                    paged_attn_meta,
                    flash_meta,
                    flash_meta_full: None,
                });
                Ok(InputProcessorOutput {
                    inputs,
                    seq_indices,
                })
            } else {
                // Plain decode step.
                let metadata = get_completion_input(
                    input_seqs
                        .iter()
                        .map(|seq| seq.get_toks())
                        .collect::<Vec<_>>(),
                    input_seqs,
                    device,
                    no_kv_cache,
                    last_n_context_len,
                    return_raw_logits,
                    paged_attn_metadata.as_mut(),
                    mapper,
                )?;
                let InnerInputProcessorOutput {
                    inputs:
                        InputMetadata {
                            input: input_ids,
                            positions: seqlen_offsets,
                            context_lens,
                            position_ids,
                            paged_attn_meta,
                            flash_meta,
                        },
                    seq_indices,
                } = metadata;
                let inputs: Box<dyn Any> = Box::new(ModelInputs {
                    input_ids,
                    input_ids_full: None,
                    seqlen_offsets,
                    seqlen_offsets_full: None,
                    context_lens,
                    position_ids,
                    paged_attn_meta,
                    flash_meta,
                    flash_meta_full: None,
                });
                Ok(InputProcessorOutput {
                    inputs,
                    seq_indices,
                })
            }
        }

        fn get_type(&self) -> InputsProcessorType {
            InputsProcessorType::Text
        }
    }
}