mistralrs_core/pipeline/inputs_processor.rs

#![allow(clippy::cast_possible_truncation)]

use std::{any::Any, sync::Arc};

use anyhow::Result;
use candle_core::Device;
use text_models_inputs_processor::PagedAttentionMeta;
use tokenizers::Tokenizer;

use crate::{device_map::DeviceMapper, sequence::Sequence};

#[derive(PartialEq)]
pub enum InputsProcessorType {
    Text,
    Vision,
    Embedding,
}

pub struct InputProcessorOutput {
    pub inputs: Box<dyn Any>,
    pub seq_indices: Vec<usize>,
}

/// Processor: Prepare inputs for the model (including preparing the images, if applicable).
pub trait InputsProcessor {
    /// This should also enable matmul via f16 if this is a prompt and the sequence length is
    /// greater than 32; otherwise, matmul via f16 is disabled.
    ///
    /// This should return a type which can be downcast to the proper type as used in `forward_inputs`.
    #[allow(clippy::too_many_arguments)]
    fn process_inputs(
        &self,
        tokenizer: Option<Arc<Tokenizer>>,
        input_seqs: &mut [&mut Sequence],
        is_prompt: bool,
        is_xlora: bool,
        device: &Device,
        no_kv_cache: bool,
        last_n_context_len: Option<(usize, usize)>,
        return_raw_logits: bool,
        other_config: Option<Arc<dyn Any>>,
        paged_attn_metadata: Option<PagedAttentionMeta>,
        mapper: Option<&dyn DeviceMapper>,
    ) -> Result<InputProcessorOutput>;

    fn get_type(&self) -> InputsProcessorType;
}
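
// A minimal sketch of the calling convention, assuming a processor, tokenizer,
// device, and prepared batch are in scope (hypothetical variable names):
//
//     let out = processor.process_inputs(
//         Some(tokenizer), &mut seqs, /* is_prompt */ true, /* is_xlora */ false,
//         &device, /* no_kv_cache */ false, /* last_n_context_len */ None,
//         /* return_raw_logits */ false, None, None, None,
//     )?;
//     // `out.inputs` is then downcast to the concrete type the pipeline
//     // expects, e.g. `text_models_inputs_processor::ModelInputs`.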

// ========================= Text models input processor

pub mod text_models_inputs_processor {
    use std::{any::Any, collections::HashMap, fmt::Debug, sync::Arc};

    use anyhow::Result;
    use candle_core::{DType, Device, DeviceLocation, Tensor, WithDType};
    use tokenizers::Tokenizer;

    use crate::{
        device_map::DeviceMapper,
        get_mut_arcmutex,
        paged_attention::{BlockEngine, _PAD_SLOT_ID},
        sequence::Sequence,
    };

    use super::{InputProcessorOutput, InputsProcessor, InputsProcessorType};

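    /// Pad each row of `x` to `max_len` with `pad`, then concatenate the rows
    /// along dim 0, yielding a flat tensor of length `x.len() * max_len` that
    /// callers reshape as needed (e.g. to `((), max_len)`).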
    fn _make_tensor_with_pad<D: WithDType>(
        x: Vec<Vec<D>>,
        max_len: usize,
        pad: D,
        device: &Device,
    ) -> Result<Tensor> {
        let mut padded_x = Vec::new();
        for mut x_i in x {
            assert!(x_i.len() <= max_len);
            x_i.extend([pad].repeat(max_len - x_i.len()));
            let shape = (x_i.len(),);
            padded_x.push(Tensor::from_vec(x_i, shape, device)?);
        }
        Tensor::cat(&padded_x[..], 0).map_err(anyhow::Error::msg)
    }

    pub struct PagedAttentionMeta {
        pub sliding_window: Option<usize>,
        pub block_size: usize,
        pub block_engine: Arc<tokio::sync::Mutex<BlockEngine>>,
    }

    #[derive(Clone, Debug)]
    #[allow(dead_code)]
    pub struct PagedAttentionInputMetadata {
        pub block_tables: Option<HashMap<DeviceLocation, Tensor>>,
        pub context_lens: Option<HashMap<DeviceLocation, Tensor>>,
        pub slot_mappings: HashMap<DeviceLocation, Tensor>,
        pub max_context_len: Option<usize>,
        pub is_first_prompt_chunk: bool,
    }

    impl PagedAttentionInputMetadata {
        /// Create a dummy input metadata, assuming that this will NOT be used for decoding.
        /// This is used for the case of imatrix generation.
        pub fn dummy(dev: &Device) -> candle_core::Result<Self> {
            Ok(PagedAttentionInputMetadata {
                block_tables: None,
                context_lens: None,
                max_context_len: None,
                slot_mappings: HashMap::from([(dev.location(), Tensor::new(&[0f32], dev)?)]),
                is_first_prompt_chunk: true,
            })
        }
    }

    #[derive(Clone, Debug)]
    pub struct FlashParams {
        pub max_q: u32,
        pub max_k: u32,
        pub cumulative_seqlens_q: HashMap<DeviceLocation, Tensor>,
        pub cumulative_seqlens_k: HashMap<DeviceLocation, Tensor>,
        pub causal: bool,
    }
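    // Illustration: for per-sequence query lengths [3, 5], the cumulative
    // tensors hold [0, 3, 8] (the exclusive prefix sums that flash-attention
    // varlen kernels expect), and max_q = 5.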

    pub struct InputMetadata {
        pub input: Tensor,
        pub positions: Vec<usize>,
        pub context_lens: Vec<(usize, usize)>, // (start index, len)
        pub position_ids: Vec<usize>,
        pub paged_attn_meta: Option<PagedAttentionInputMetadata>, // For paged attention
        pub flash_meta: FlashParams,
    }

    pub struct InnerInputProcessorOutput {
        pub inputs: InputMetadata,
        pub seq_indices: Vec<usize>,
    }

    // `chunk_offset_toks` is the number of prompt tokens already processed in
    // earlier chunks; chunk_offset_toks / prompt_chunksize gives the number of
    // chunks that precede this one.
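    // e.g. with a 1200-token prompt and prompt_chunksize = 512, the chunks are
    // processed with chunk_offset_toks = 0, 512, and 1024 respectively.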
    #[allow(clippy::too_many_arguments)]
    pub fn make_prompt_chunk<T: WithDType + Debug>(
        chunk_offset_toks: usize,
        toks: Vec<&[T]>,
        seq_ids: &[usize],
        device: &Device,
        last_n_context_len: Option<(usize, usize)>,
        return_raw_logits: bool,
        mut paged_attn_metadata: Option<&mut PagedAttentionMeta>,
        mapper: Option<&dyn DeviceMapper>,
    ) -> Result<InputMetadata> {
        let max_len = toks
            .iter()
            .map(|seq| seq.len())
            .max()
            .expect("No sequences");
        let padding_tok = T::zero();
        // Pad each sequence by the padding token to the max len.
        let mut seqs_tensors = Vec::new();
        let mut seqlen_offsets = Vec::new();
        let mut context_lens = Vec::new();
        let mut position_ids = Vec::new();
        let mut slot_mappings = Vec::new();
        let mut block_tables = Vec::new();
        let mut paged_attn_context_lens = Vec::new();
        let flash_attn = crate::using_flash_attn();
        let mut seqlens_q = if flash_attn { vec![0] } else { Vec::new() };
        let mut seqlens_k = if flash_attn { vec![0] } else { Vec::new() };
        for (seq_id, ctxt) in seq_ids.iter().zip(toks) {
            let prompt_len = ctxt.len();
            let offset = last_n_context_len.unwrap_or_default();
            seqlen_offsets.push(offset.1 + chunk_offset_toks);

            position_ids.push(ctxt.len() + chunk_offset_toks);
            let mut ctxt = ctxt.to_vec();
            ctxt.extend(std::iter::repeat_n(
                padding_tok,
                max_len.saturating_sub(ctxt.len()),
            ));
            // If we are returning raw logits, we do not want to trim the logits at all.
            if return_raw_logits {
                if last_n_context_len.is_some() {
                    anyhow::bail!("`return_raw_logits` is incompatible with `last_n_context_len`");
                }

                context_lens.push((0, ctxt.len()));
            } else {
                context_lens.push((
                    ctxt.len()
                        .saturating_sub(last_n_context_len.map(|(a, _)| a).unwrap_or(1)),
                    last_n_context_len.map(|(a, _)| a).unwrap_or(1),
                ));
            }

            if flash_attn {
                seqlens_q.push(ctxt.len() as u32);
                seqlens_k.push((ctxt.len() + chunk_offset_toks) as u32);
            }

            seqs_tensors.push(Tensor::new(ctxt, device).unwrap().unsqueeze(0).unwrap());

            if let Some(paged_attn_metadata) = &mut paged_attn_metadata {
                let block_engine = get_mut_arcmutex!(paged_attn_metadata.block_engine);
                let table = block_engine.block_tables.get(seq_id);

                if table.is_none() {
                    // Will be None during profiling.
                    slot_mappings.push([_PAD_SLOT_ID].repeat(prompt_len));
                    continue;
                }
                let table = table
                    .unwrap()
                    .iter()
                    .map(|block| block.block_id())
                    .collect::<Vec<_>>();

                let start_idx = if let Some(sliding_window) = paged_attn_metadata.sliding_window {
                    prompt_len.saturating_sub(sliding_window)
                } else {
                    0
                };

                let mut slot_mapping = Vec::new();
                let mut ctxt_len = Vec::new();
                for i in chunk_offset_toks..prompt_len + chunk_offset_toks {
                    if i < start_idx {
                        // Pad [0, start_idx) with _PAD_SLOT_ID: these positions fall
                        // outside the sliding window, so no real slot is assigned.
                        slot_mapping.push(_PAD_SLOT_ID);
                        continue;
                    }
                    ctxt_len.push(i);

                    let block_number = if i / paged_attn_metadata.block_size >= table.len() {
                        panic!(
                            "Block table is too small (prompt)! i={} block_size={} table_len={}",
                            i,
                            paged_attn_metadata.block_size,
                            table.len()
                        );
                    } else {
                        table.get(i / paged_attn_metadata.block_size).unwrap()
                    };
                    let block_offset = i % paged_attn_metadata.block_size;
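                    // Illustration: with block_size = 16 and i = 35, the token lives
                    // in logical block 2 at offset 3; if table[2] is physical block 7,
                    // the slot is 7 * 16 + 3 = 115.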
                    // Use checked arithmetic to prevent overflow
                    let slot = block_number
                        .checked_mul(paged_attn_metadata.block_size)
                        .and_then(|v| v.checked_add(block_offset))
                        .expect("Slot calculation overflowed");
                    slot_mapping.push(
                        slot.try_into()
                            .expect("Slot value too large for target integer type"),
                    );
                }
                slot_mappings.push(slot_mapping);
                // Push the block table once per sequence, not once per token.
                block_tables.push(table.clone());
                paged_attn_context_lens.push(ctxt_len);
            }
        }

        let (max_q, max_k, seqlens_q_map, seqlens_k_map) = if flash_attn {
            // SAFETY: seqlens_q/k are initialized with vec![0] when flash_attn is true,
            // so they are guaranteed to be non-empty here.
            let max_q = *seqlens_q
                .iter()
                .max()
                .expect("seqlens_q should not be empty when flash_attn is enabled");
            let max_k = *seqlens_k
                .iter()
                .max()
                .expect("seqlens_k should not be empty when flash_attn is enabled");
            // Create tensors on CPU first to avoid CUDA context issues when copying
            // between different GPU devices. Each GPU has its own CUDA context, and
            // candle/cudarc doesn't properly switch contexts when doing GPU-to-GPU
            // transfers (which go through CPU). By creating on CPU first, we avoid
            // the cross-context memory access that causes CUDA_ERROR_INVALID_VALUE.
            let seqlens_q = Tensor::new(seqlens_q, &Device::Cpu)?
                .to_dtype(DType::F32)?
                .cumsum(0)?
                .to_dtype(DType::U32)?;
            let seqlens_k = Tensor::new(seqlens_k, &Device::Cpu)?
                .to_dtype(DType::F32)?
                .cumsum(0)?
                .to_dtype(DType::U32)?;

            let mut seqlens_q_map = HashMap::new();
            let mut seqlens_k_map = HashMap::new();

            let devices = mapper.unwrap().get_unique_devices();
            for device in devices {
                seqlens_q_map.insert(device.location(), seqlens_q.to_device(&device)?);
                seqlens_k_map.insert(device.location(), seqlens_k.to_device(&device)?);
            }
            (max_q, max_k, seqlens_q_map, seqlens_k_map)
        } else {
            (0, 0, HashMap::new(), HashMap::new())
        };

        let input = Tensor::cat(&seqs_tensors, 0).unwrap();

        let paged_attn_meta = if paged_attn_metadata.is_some() {
            // Create paged attention tensors on CPU first (see comment above about CUDA contexts)
            let max_slot_mapping_len = slot_mappings.iter().map(|x| x.len()).max().unwrap();
            let slot_mappings = _make_tensor_with_pad(
                slot_mappings,
                max_slot_mapping_len,
                _PAD_SLOT_ID,
                &Device::Cpu,
            )?;

            let max_block_table_len = block_tables.iter().map(|x| x.len()).max().unwrap();
            let block_tables = _make_tensor_with_pad(
                block_tables
                    .iter()
                    .map(|x| x.iter().map(|x| *x as u32).collect::<Vec<_>>())
                    .collect::<Vec<_>>(),
                max_block_table_len,
                0,
                &Device::Cpu,
            )?;
            let block_tables = block_tables.reshape(((), max_block_table_len))?;

            let max_context_len = paged_attn_context_lens
                .iter()
                .map(|x| x.len())
                .max()
                .unwrap();

            let context_lens = _make_tensor_with_pad(
                paged_attn_context_lens
                    .iter()
                    .map(|x| x.iter().map(|x| *x as u32).collect::<Vec<_>>())
                    .collect::<Vec<_>>(),
                max_context_len,
                0,
                &Device::Cpu,
            )?
            .reshape(((),))?;

            // For device mapping, make a copy of each tensor for each device
            let devices = mapper.unwrap().get_unique_devices();
            let mut slot_mappings_map = HashMap::new();
            let mut block_tables_map = HashMap::new();
            let mut context_lens_map = HashMap::new();

            for device in devices {
                slot_mappings_map
                    .insert(device.location(), slot_mappings.clone().to_device(&device)?);
                block_tables_map
                    .insert(device.location(), block_tables.clone().to_device(&device)?);
                context_lens_map
                    .insert(device.location(), context_lens.clone().to_device(&device)?);
            }

            Some(PagedAttentionInputMetadata {
                slot_mappings: slot_mappings_map,
                block_tables: Some(block_tables_map),
                context_lens: Some(context_lens_map),
                max_context_len: Some(max_context_len),
                is_first_prompt_chunk: chunk_offset_toks == 0,
            })
        } else {
            None
        };

        Ok(InputMetadata {
            input,
            positions: seqlen_offsets,
            context_lens,
            position_ids,
            paged_attn_meta,
            flash_meta: FlashParams {
                max_k,
                max_q,
                cumulative_seqlens_k: seqlens_k_map,
                cumulative_seqlens_q: seqlens_q_map,
                causal: true,
            },
        })
    }

    fn make_completion_chunk<T: WithDType>(
        toks: Vec<&[T]>,
        input_seqs: &[&mut Sequence],
        device: &Device,
        mut paged_attn_metadata: Option<&mut PagedAttentionMeta>,
        mapper: Option<&dyn DeviceMapper>,
    ) -> Result<InputMetadata> {
        // Pad each sequence by the padding token to the max len.
        let flash_attn = crate::using_flash_attn();
        let mut seqs_tensors = Vec::new();
        let mut seqlen_offsets = Vec::new();
        let mut context_lens = Vec::new();
        let mut position_ids = Vec::new();

        let mut slot_mappings = Vec::new();
        let mut block_tables = Vec::new();
        let mut paged_attn_context_lens = Vec::new();
        let mut seqlens_q = if flash_attn { vec![0] } else { Vec::new() };
        let mut seqlens_k = if flash_attn { vec![0] } else { Vec::new() };
        for (seq, ctxt) in input_seqs.iter().zip(toks) {
            let start_pos = ctxt.len().saturating_sub(1);
            let ctxt = ctxt[start_pos..].to_vec();
            seqlen_offsets.push(start_pos);
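            // During decode only the newest token's logits are needed, so the
            // logit-extraction window is (start, len) = (0, 1).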
            context_lens.push((0, 1));
            position_ids.push(seq.len());

            if flash_attn {
                seqlens_q.push(ctxt.len() as u32);
                seqlens_k.push((ctxt.len() + start_pos) as u32);
            }

            seqs_tensors.push(Tensor::new(ctxt, device).unwrap().unsqueeze(0).unwrap());

            if let Some(paged_attn_metadata) = &mut paged_attn_metadata {
                let block_engine = get_mut_arcmutex!(paged_attn_metadata.block_engine);
                let table = block_engine.block_tables.get(seq.id()).unwrap();

                let table = table
                    .iter()
                    .map(|block| block.block_id())
                    .collect::<Vec<_>>();

                let block_pos = start_pos - seq.token_offset();
                let block_number = if block_pos / paged_attn_metadata.block_size >= table.len() {
                    panic!(
                        "Block table is too small (completion)! block_pos={} block_size={} table_len={}",
                        block_pos,
                        paged_attn_metadata.block_size,
                        table.len()
                    );
                } else {
                    table
                        .get(block_pos / paged_attn_metadata.block_size)
                        .unwrap()
                };
                let block_offset = block_pos % paged_attn_metadata.block_size;
                // Use checked arithmetic to prevent overflow
                let slot = block_number
                    .checked_mul(paged_attn_metadata.block_size)
                    .and_then(|v| v.checked_add(block_offset))
                    .expect("Slot calculation overflowed");
                let slot = slot
                    .try_into()
                    .expect("Slot value too large for target integer type");
                slot_mappings.push(vec![slot]);

                if let Some(sliding_window) = paged_attn_metadata.sliding_window {
                    let sliding_window_blocks = sliding_window / paged_attn_metadata.block_size;
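                    // e.g. sliding_window = 1024 and block_size = 16 -> keep only
                    // the last 64 blocks of the table.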
                    let slide_idx = if table.len() > sliding_window_blocks {
                        table.len() - sliding_window_blocks
                    } else {
                        0
                    };
                    block_tables.push(table.get(slide_idx..).unwrap().to_vec());
                } else {
                    block_tables.push(table);
                }

                let paged_attn_context_len =
                    if let Some(sliding_window) = paged_attn_metadata.sliding_window {
                        seq.len().min(sliding_window)
                    } else {
                        seq.len()
                    };
                paged_attn_context_lens.push(paged_attn_context_len);
            }
        }

        let (max_q, max_k, seqlens_q_map, seqlens_k_map) = if flash_attn {
            // SAFETY: seqlens_q/k are initialized with vec![0] when flash_attn is true,
            // so they are guaranteed to be non-empty here.
            let max_q = *seqlens_q
                .iter()
                .max()
                .expect("seqlens_q should not be empty when flash_attn is enabled");
            let max_k = *seqlens_k
                .iter()
                .max()
                .expect("seqlens_k should not be empty when flash_attn is enabled");
            // Create tensors on CPU first to avoid CUDA context issues (see make_prompt_chunk)
            let seqlens_q = Tensor::new(seqlens_q, &Device::Cpu)?
                .to_dtype(DType::F32)?
                .cumsum(0)?
                .to_dtype(DType::U32)?;
            let seqlens_k = Tensor::new(seqlens_k, &Device::Cpu)?
                .to_dtype(DType::F32)?
                .cumsum(0)?
                .to_dtype(DType::U32)?;

            let mut seqlens_q_map = HashMap::new();
            let mut seqlens_k_map = HashMap::new();

            let devices = mapper.unwrap().get_unique_devices();
            for device in devices {
                seqlens_q_map.insert(device.location(), seqlens_q.to_device(&device)?);
                seqlens_k_map.insert(device.location(), seqlens_k.to_device(&device)?);
            }
            (max_q, max_k, seqlens_q_map, seqlens_k_map)
        } else {
            (0, 0, HashMap::new(), HashMap::new())
        };

        let paged_attn_meta = if paged_attn_metadata.is_some() {
            // Create paged attention tensors on CPU first (see make_prompt_chunk for explanation)
            let slot_mappings =
                _make_tensor_with_pad(slot_mappings, 1, _PAD_SLOT_ID, &Device::Cpu)?;

            let max_block_table_len = block_tables
                .iter()
                .map(|x| x.len())
                .max()
                .expect("block_tables should not be empty when paged attention is enabled");

            let block_tables = _make_tensor_with_pad(
                block_tables
                    .iter()
                    .map(|x| x.iter().map(|x| *x as u32).collect::<Vec<_>>())
                    .collect::<Vec<_>>(),
                max_block_table_len,
                0,
                &Device::Cpu,
            )?;
            let block_tables = block_tables.reshape(((), max_block_table_len))?;

            let max_context_len = paged_attn_context_lens.iter().max().unwrap();

            let context_lens = Tensor::from_vec(
                paged_attn_context_lens
                    .iter()
                    .map(|x| *x as u32)
                    .collect::<Vec<_>>(),
                (paged_attn_context_lens.len(),),
                &Device::Cpu,
            )?;

            // For device mapping, make a copy of each tensor for each device
            let devices = mapper.unwrap().get_unique_devices();
            let mut slot_mappings_map = HashMap::new();
            let mut block_tables_map = HashMap::new();
            let mut context_lens_map = HashMap::new();

            for device in devices {
                slot_mappings_map
                    .insert(device.location(), slot_mappings.clone().to_device(&device)?);
                block_tables_map
                    .insert(device.location(), block_tables.clone().to_device(&device)?);
                context_lens_map
                    .insert(device.location(), context_lens.clone().to_device(&device)?);
            }

            Some(PagedAttentionInputMetadata {
                slot_mappings: slot_mappings_map,
                block_tables: Some(block_tables_map),
                context_lens: Some(context_lens_map),
                max_context_len: Some(*max_context_len),
                is_first_prompt_chunk: false,
            })
        } else {
            None
        };

        Ok(InputMetadata {
            input: Tensor::cat(&seqs_tensors, 0).unwrap(),
            positions: seqlen_offsets,
            context_lens,
            position_ids,
            paged_attn_meta,
            flash_meta: FlashParams {
                max_k,
                max_q,
                cumulative_seqlens_k: seqlens_k_map,
                cumulative_seqlens_q: seqlens_q_map,
                causal: true,
            },
        })
    }

    #[allow(clippy::too_many_arguments)]
    pub(crate) fn get_prompt_input<T: WithDType + std::fmt::Debug>(
        toks: Vec<&[T]>,
        input_seqs: &[&mut Sequence],
        device: &Device,
        last_n_context_len: Option<(usize, usize)>,
        return_raw_logits: bool,
        paged_attn_metadata: Option<&mut PagedAttentionMeta>,
        mapper: Option<&dyn DeviceMapper>,
    ) -> Result<InnerInputProcessorOutput> {
        let offset = input_seqs[0].token_offset();
        make_prompt_chunk(
            offset,
            toks,
            &input_seqs.iter().map(|s| *s.id()).collect::<Vec<_>>(),
            device,
            last_n_context_len,
            return_raw_logits,
            paged_attn_metadata,
            mapper,
        )
        .map(|inputs| InnerInputProcessorOutput {
            inputs,
            seq_indices: (0..input_seqs.len()).collect(),
        })
    }

    #[allow(clippy::too_many_arguments)]
    pub(crate) fn get_completion_input<T: WithDType + std::fmt::Debug>(
        toks: Vec<&[T]>,
        input_seqs: &[&mut Sequence],
        device: &Device,
        no_kv_cache: bool,
        last_n_context_len: Option<(usize, usize)>,
        return_raw_logits: bool,
        paged_attn_metadata: Option<&mut PagedAttentionMeta>,
        mapper: Option<&dyn DeviceMapper>,
    ) -> Result<InnerInputProcessorOutput> {
        if no_kv_cache {
            return get_prompt_input(
                toks,
                input_seqs,
                device,
                last_n_context_len,
                return_raw_logits,
                paged_attn_metadata,
                mapper,
            );
        }

        make_completion_chunk(toks, input_seqs, device, paged_attn_metadata, mapper).map(|inputs| {
            InnerInputProcessorOutput {
                inputs,
                seq_indices: (0..input_seqs.len()).collect(),
            }
        })
    }

    #[derive(Clone)]
    pub struct ModelInputs {
        pub input_ids: Tensor,
        pub input_ids_full: Option<Tensor>,
        pub seqlen_offsets: Vec<usize>,
        pub seqlen_offsets_full: Option<Vec<usize>>,
        pub context_lens: Vec<(usize, usize)>,
        pub position_ids: Vec<usize>,
        pub paged_attn_meta: Option<PagedAttentionInputMetadata>,
        pub flash_meta: FlashParams,
        pub flash_meta_full: Option<FlashParams>,
    }
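
    // A minimal sketch of how a model's `forward_inputs` might recover these
    // inputs from the `Box<dyn Any>` produced by `TextInputsProcessor`
    // (illustrative; `boxed_inputs` is a hypothetical name):
    //
    //     let inputs = *boxed_inputs
    //         .downcast::<ModelInputs>()
    //         .expect("TextInputsProcessor always returns ModelInputs");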

    pub struct TextInputsProcessor;

    impl InputsProcessor for TextInputsProcessor {
        fn process_inputs(
            &self,
            _: Option<Arc<Tokenizer>>,
            input_seqs: &mut [&mut Sequence],
            is_prompt: bool,
            is_xlora: bool,
            device: &Device,
            no_kv_cache: bool,
            last_n_context_len: Option<(usize, usize)>,
            return_raw_logits: bool,
            _: Option<Arc<dyn Any>>,
            mut paged_attn_metadata: Option<PagedAttentionMeta>,
            mapper: Option<&dyn DeviceMapper>,
        ) -> Result<InputProcessorOutput> {
            if is_xlora && !is_prompt {
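                // X-LoRA's scaling pass needs the full token history, so a
                // non-prompt step builds both prompt-style inputs (the full
                // context) and completion-style inputs (the newest token only).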
                let prompt = get_prompt_input(
                    input_seqs
                        .iter()
                        .map(|seq| seq.get_toks())
                        .collect::<Vec<_>>(),
                    input_seqs,
                    device,
                    last_n_context_len,
                    return_raw_logits,
                    paged_attn_metadata.as_mut(),
                    mapper,
                )?;
                let completion = get_completion_input(
                    input_seqs
                        .iter()
                        .map(|seq| seq.get_toks())
                        .collect::<Vec<_>>(),
                    input_seqs,
                    device,
                    no_kv_cache,
                    last_n_context_len,
                    return_raw_logits,
                    paged_attn_metadata.as_mut(),
                    mapper,
                )?;
                let InnerInputProcessorOutput {
                    inputs:
                        InputMetadata {
                            input: input_ids_full,
                            positions: seqlen_offsets_full,
                            context_lens: _,
                            position_ids,
                            paged_attn_meta: _,
                            flash_meta: flash_meta_full,
                        },
                    seq_indices,
                } = prompt;
                let InnerInputProcessorOutput {
                    inputs:
                        InputMetadata {
                            input: input_ids,
                            positions: seqlen_offsets,
                            context_lens,
                            position_ids: _,
                            paged_attn_meta,
                            flash_meta,
                        },
                    seq_indices: _,
                } = completion;
                let inputs: Box<dyn Any> = Box::new(ModelInputs {
                    input_ids,
                    input_ids_full: Some(input_ids_full),
                    seqlen_offsets,
                    seqlen_offsets_full: Some(seqlen_offsets_full),
                    context_lens,
                    position_ids,
                    paged_attn_meta,
                    flash_meta,
                    flash_meta_full: Some(flash_meta_full),
                });
                Ok(InputProcessorOutput {
                    inputs,
                    seq_indices,
                })
            } else if is_xlora && is_prompt {
                let metadata = get_prompt_input(
                    input_seqs
                        .iter()
                        .map(|seq| seq.get_toks())
                        .collect::<Vec<_>>(),
                    input_seqs,
                    device,
                    last_n_context_len,
                    return_raw_logits,
                    paged_attn_metadata.as_mut(),
                    mapper,
                )?;
                let InnerInputProcessorOutput {
                    inputs:
                        InputMetadata {
                            input: input_ids,
                            positions: seqlen_offsets,
                            context_lens,
                            position_ids,
                            paged_attn_meta,
                            flash_meta,
                        },
                    seq_indices,
                } = metadata;
                let inputs: Box<dyn Any> = Box::new(ModelInputs {
                    input_ids: input_ids.clone(),
                    input_ids_full: Some(input_ids),
                    seqlen_offsets: seqlen_offsets.clone(),
                    seqlen_offsets_full: Some(seqlen_offsets),
                    context_lens,
                    position_ids,
                    paged_attn_meta,
                    flash_meta: flash_meta.clone(),
                    flash_meta_full: Some(flash_meta),
                });
                Ok(InputProcessorOutput {
                    inputs,
                    seq_indices,
                })
            } else if is_prompt {
                let metadata = get_prompt_input(
                    input_seqs
                        .iter()
                        .map(|seq| seq.get_toks())
                        .collect::<Vec<_>>(),
                    input_seqs,
                    device,
                    last_n_context_len,
                    return_raw_logits,
                    paged_attn_metadata.as_mut(),
                    mapper,
                )?;
                let InnerInputProcessorOutput {
                    inputs:
                        InputMetadata {
                            input: input_ids,
                            positions: seqlen_offsets,
                            context_lens,
                            position_ids,
                            paged_attn_meta,
                            flash_meta,
                        },
                    seq_indices,
                } = metadata;
                let inputs: Box<dyn Any> = Box::new(ModelInputs {
                    input_ids,
                    input_ids_full: None,
                    seqlen_offsets,
                    seqlen_offsets_full: None,
                    context_lens,
                    position_ids,
                    paged_attn_meta,
                    flash_meta,
                    flash_meta_full: None,
                });
                Ok(InputProcessorOutput {
                    inputs,
                    seq_indices,
                })
            } else {
                let metadata = get_completion_input(
                    input_seqs
                        .iter()
                        .map(|seq| seq.get_toks())
                        .collect::<Vec<_>>(),
                    input_seqs,
                    device,
                    no_kv_cache,
                    last_n_context_len,
                    return_raw_logits,
                    paged_attn_metadata.as_mut(),
                    mapper,
                )?;
                let InnerInputProcessorOutput {
                    inputs:
                        InputMetadata {
                            input: input_ids,
                            positions: seqlen_offsets,
                            context_lens,
                            position_ids,
                            paged_attn_meta,
                            flash_meta,
                        },
                    seq_indices,
                } = metadata;
                let inputs: Box<dyn Any> = Box::new(ModelInputs {
                    input_ids,
                    input_ids_full: None,
                    seqlen_offsets,
                    seqlen_offsets_full: None,
                    context_lens,
                    position_ids,
                    paged_attn_meta,
                    flash_meta,
                    flash_meta_full: None,
                });
                Ok(InputProcessorOutput {
                    inputs,
                    seq_indices,
                })
            }
        }

        fn get_type(&self) -> InputsProcessorType {
            InputsProcessorType::Text
        }
    }
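
    // A minimal sketch of a unit test for `_make_tensor_with_pad`, assuming a
    // CPU candle backend (illustrative; not part of the upstream crate):
    #[cfg(test)]
    mod tests {
        use super::_make_tensor_with_pad;
        use candle_core::Device;

        #[test]
        fn pads_rows_and_flattens() {
            // [[1, 2], [3]] padded to max_len = 2 becomes the flat tensor [1, 2, 3, 0].
            let t = _make_tensor_with_pad(vec![vec![1u32, 2], vec![3]], 2, 0, &Device::Cpu)
                .unwrap();
            assert_eq!(t.dims(), &[4]);
            assert_eq!(t.to_vec1::<u32>().unwrap(), vec![1, 2, 3, 0]);
        }
    }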
}