mistralrs_core/pipeline/inputs_processor.rs

1#![allow(clippy::cast_possible_truncation)]
2
3use std::{any::Any, num::NonZeroUsize, sync::Arc};
4
5use anyhow::Result;
6use candle_core::Device;
7use text_models_inputs_processor::PagedAttentionMeta;
8use tokenizers::Tokenizer;
9
10use crate::{device_map::DeviceMapper, sequence::Sequence};
11
/// Default number of prompt tokens processed per prefill chunk when chunked
/// prompt processing is enabled.
pub const DEFAULT_PROMPT_CHUNK_SIZE: usize = 1024;

/// Discriminates text-only input processors from vision-capable ones.
#[derive(PartialEq)]
pub enum InputsProcessorType {
    Text,
    Vision,
}
19
/// One batch of prepared model inputs plus the sequences it covers.
pub struct InputProcessorOutput {
    /// Opaque model inputs; downcast to the concrete type used by `forward_inputs`.
    pub inputs: Box<dyn Any>,
    /// Indices (into the scheduler's sequence slice) that this batch covers.
    pub seq_indices: Vec<usize>,
}
24
/// Processor: Prepare inputs for the model (potentially preparing the images if applicable)
pub trait InputsProcessor {
    /// This should also enable matmul via f16 if prompt and the sequence length is greater than 32.
    /// Otherwise, matmul via f16 is disabled.
    ///
    /// This should return a type which can be downcasted to the proper type as used in `forward_inputs`
    ///
    /// The iterator yields one item per forward pass: a single item normally, or
    /// several when the prompt is split into chunks (`prompt_chunksize`). Each item
    /// pairs the boxed inputs with the indices of the sequences it covers.
    #[allow(clippy::too_many_arguments)]
    fn process_inputs(
        &self,
        tokenizer: Option<Arc<Tokenizer>>,
        input_seqs: &mut [&mut Sequence],
        is_prompt: bool,
        is_xlora: bool,
        device: &Device,
        no_kv_cache: bool,
        last_n_context_len: Option<(usize, usize)>,
        return_raw_logits: bool,
        other_config: Option<Arc<dyn Any>>,
        paged_attn_metadata: Option<PagedAttentionMeta<'_>>,
        prompt_chunksize: Option<NonZeroUsize>,
        mapper: Option<&dyn DeviceMapper>,
    ) -> Box<dyn Iterator<Item = Result<InputProcessorOutput>>>;

    /// Whether this processor handles text-only or vision inputs.
    fn get_type(&self) -> InputsProcessorType;
}
50
51// ========================= Test models input processor
52
53pub mod text_models_inputs_processor {
54    use std::{any::Any, collections::HashMap, fmt::Debug, num::NonZeroUsize, sync::Arc};
55
56    use anyhow::Result;
57    use candle_core::{DType, Device, DeviceLocation, Tensor, WithDType};
58    use tokenizers::Tokenizer;
59
60    use crate::{
61        device_map::DeviceMapper,
62        paged_attention::{BlockEngine, _PAD_SLOT_ID},
63        sequence::Sequence,
64        using_flash_attn,
65    };
66
67    use super::{InputProcessorOutput, InputsProcessor, InputsProcessorType};
68
69    #[inline(always)]
70    fn _make_tensor_with_pad<D: WithDType>(
71        x: Vec<Vec<D>>,
72        max_len: usize,
73        pad: D,
74        device: &Device,
75    ) -> Result<Tensor> {
76        let bs = x.len();
77        let mut padded_x = Vec::with_capacity(x.len() * max_len);
78        for mut x_i in x {
79            assert!(x_i.len() <= max_len);
80            x_i.extend(std::iter::repeat_n(pad, max_len.saturating_sub(x_i.len())));
81            padded_x.extend(x_i);
82        }
83        Tensor::from_vec(padded_x, (bs, max_len), device).map_err(anyhow::Error::msg)
84    }
85
    /// Mutable paged-attention state handed to the input processor for one
    /// scheduling step: the block engine owning the per-sequence block tables,
    /// plus the configuration needed to turn token positions into cache slots.
    pub struct PagedAttentionMeta<'a> {
        /// Sliding-window size in tokens, if the model uses windowed attention.
        pub sliding_window: Option<usize>,
        /// Number of KV-cache slots per block.
        pub block_size: usize,
        /// Engine that owns the block tables, keyed by sequence id.
        pub block_engine: &'a mut BlockEngine,
    }
91
    /// Per-forward-pass paged-attention metadata. Each tensor is replicated once
    /// per device (keyed by `DeviceLocation`) to support device mapping.
    #[derive(Clone, Debug)]
    #[allow(dead_code)]
    pub struct PagedAttentionInputMetadata {
        /// Per-sequence KV-cache block ids, shape (bs, max_table_len); `None` for decode-only dummies.
        pub block_tables: Option<HashMap<DeviceLocation, Tensor>>,
        /// Per-sequence context lengths as a tensor; `None` when not applicable.
        pub context_lens: Option<HashMap<DeviceLocation, Tensor>>,
        /// Flat KV-cache slot index for each token being written this pass.
        pub slot_mappings: HashMap<DeviceLocation, Tensor>,
        /// Longest context length in the batch.
        pub max_context_len: Option<usize>,
        /// True only for the first chunk of a chunked prompt (chunk offset == 0).
        pub is_first_prompt_chunk: bool,
    }
101
102    impl PagedAttentionInputMetadata {
103        /// Create a dummy input metadata, assuming that this will NOT be used for decoding.
104        /// This is used for the case of imatrix generation.
105        pub fn dummy(dev: &Device) -> candle_core::Result<Self> {
106            Ok(PagedAttentionInputMetadata {
107                block_tables: None,
108                context_lens: None,
109                max_context_len: None,
110                slot_mappings: HashMap::from([(dev.location(), Tensor::new(&[0f32], dev)?)]),
111                is_first_prompt_chunk: true,
112            })
113        }
114    }
115
    /// Flash-attention varlen parameters: per-device cumulative sequence-length
    /// tensors plus the maximum query/key lengths in the batch.
    #[derive(Clone, Debug)]
    pub struct FlashParams {
        /// Maximum query length across the batch.
        pub max_q: u32,
        /// Maximum key length across the batch.
        pub max_k: u32,
        /// Cumulative query lengths (cu_seqlens_q), replicated per device.
        pub cumulative_seqlens_q: HashMap<DeviceLocation, Tensor>,
        /// Cumulative key lengths (cu_seqlens_k), replicated per device.
        pub cumulative_seqlens_k: HashMap<DeviceLocation, Tensor>,
    }
123
    /// Everything needed for one forward pass over a batch of sequences.
    pub struct InputMetadata {
        /// Token ids, shape (bs, padded_len).
        pub input: Tensor,
        /// Per-sequence starting position offsets.
        pub positions: Vec<usize>,
        pub context_lens: Vec<(usize, usize)>, // (start index, len)
        /// Per-sequence position ids (token count incl. any chunk offset).
        pub position_ids: Vec<usize>,
        pub paged_attn_meta: Option<PagedAttentionInputMetadata>, // For paged attention
        /// Flash-attention varlen parameters for this batch.
        pub flash_meta: FlashParams,
    }
132
    /// Internal (pre-boxing) counterpart of [`super::InputProcessorOutput`]:
    /// concrete `InputMetadata` plus the sequence indices it covers.
    pub struct InnerInputProcessorOutput {
        pub inputs: InputMetadata,
        pub seq_indices: Vec<usize>,
    }
137
    // chunk_offset_toks is the number of tokens by which the tokens are offset,
    // chunk_offset_toks / prompt_chunksize = number of batches
    /// Build an [`InputMetadata`] for one prompt (prefill) chunk.
    ///
    /// - `chunk_offset_toks`: tokens already processed before this chunk (0 for the
    ///   first chunk of a prompt).
    /// - `toks`: per-sequence token ids; padded with `T::zero()` to the longest one.
    /// - `seq_ids`: sequence ids, parallel to `toks`, used for block-table lookup.
    /// - `last_n_context_len`: `Some((n, off))` keeps only the last `n` logits and
    ///   shifts the position offsets by `off`; incompatible with `return_raw_logits`.
    /// - `paged_attn_metadata`: when present, slot mappings / block tables / context
    ///   lengths for paged attention are built as well.
    /// - `mapper`: supplies the set of devices to replicate tensors onto; `unwrap`ed
    ///   when flash or paged attention is active, so it must be `Some` then.
    #[allow(clippy::too_many_arguments)]
    pub fn make_prompt_chunk<T: WithDType + Debug>(
        chunk_offset_toks: usize,
        toks: Vec<Vec<T>>,
        seq_ids: &[usize],
        device: &Device,
        last_n_context_len: Option<(usize, usize)>,
        return_raw_logits: bool,
        mut paged_attn_metadata: Option<&mut PagedAttentionMeta<'_>>,
        mapper: Option<&dyn DeviceMapper>,
    ) -> Result<InputMetadata> {
        // The longest sequence defines the padded batch width.
        let max_len = toks
            .iter()
            .map(|seq| seq.len())
            .max()
            .expect("No sequences");
        let padding_tok = T::zero();
        // Pad each sequence by the padding token to the max len.
        let mut seqs_tensors = Vec::new();
        let mut seqlen_offsets = Vec::new();
        let mut context_lens = Vec::new();
        let mut position_ids = Vec::new();
        let mut slot_mappings = Vec::new();
        let mut block_tables = Vec::new();
        let mut paged_attn_context_lens = Vec::new();
        // Leading 0 so the cumsum below produces cumulative offsets in the
        // flash-attention varlen (cu_seqlens) format.
        let mut seqlens_q = vec![0];
        let mut seqlens_k = vec![0];
        for (seq_id, mut ctxt) in seq_ids.iter().zip(toks) {
            let prompt_len = ctxt.len();
            let offset = last_n_context_len.unwrap_or_default();
            seqlen_offsets.push(offset.1 + chunk_offset_toks);

            position_ids.push(ctxt.len() + chunk_offset_toks);
            // Right-pad this sequence to the batch width.
            ctxt.extend(std::iter::repeat_n(
                padding_tok,
                max_len.saturating_sub(ctxt.len()),
            ));
            // If we are returning raw logits, we want to not trim the logits at all.
            if return_raw_logits {
                if last_n_context_len.is_some() {
                    anyhow::bail!("`return_raw_logits` is incompatible with `last_n_context_len`");
                }

                context_lens.push((0, ctxt.len()));
            } else {
                // Keep only the last n logits (default 1: just the next-token logit).
                context_lens.push((
                    ctxt.len() - last_n_context_len.map(|(a, _)| a).unwrap_or(1),
                    last_n_context_len.map(|(a, _)| a).unwrap_or(1),
                ));
            }

            // NOTE(review): these are read after padding, so they reflect the padded
            // width (max_len), not prompt_len — confirm this matches the kernels.
            seqlens_q.push(ctxt.len() as u32);
            seqlens_k.push((ctxt.len() + chunk_offset_toks) as u32);

            seqs_tensors.push(Tensor::new(ctxt, device).unwrap().unsqueeze(0).unwrap());

            if let Some(paged_attn_metadata) = &mut paged_attn_metadata {
                let table = paged_attn_metadata.block_engine.block_tables.get(seq_id);

                if table.is_none() {
                    // Will be None during profiling.
                    slot_mappings.push([_PAD_SLOT_ID].repeat(prompt_len));
                    continue;
                }
                let table = table
                    .unwrap()
                    .iter()
                    .map(|block| block.deref_mut().block_id)
                    .collect::<Vec<_>>();

                // With a sliding window, positions that fell out of the window get no
                // real slots; start_idx is the first position that does.
                let start_idx = if let Some(sliding_window) = paged_attn_metadata.sliding_window {
                    if prompt_len > sliding_window {
                        chunk_offset_toks.min(prompt_len - sliding_window)
                    } else {
                        chunk_offset_toks
                    }
                } else {
                    chunk_offset_toks
                };

                let mut slot_mapping = Vec::new();
                let mut ctxt_len = Vec::new();
                // Absolute token positions covered by this chunk.
                for i in chunk_offset_toks..prompt_len + chunk_offset_toks {
                    if i < start_idx {
                        // Pad [0,start_idx) with _PAD_TOKEN_ID
                        // NOTE(review): no `continue` here, so a real slot is also
                        // pushed below for these positions — confirm this is intended.
                        slot_mapping.push(_PAD_SLOT_ID);
                    }
                    ctxt_len.push(i);

                    let block_number = if i / paged_attn_metadata.block_size >= table.len() {
                        panic!(
                            "Block table is too small (prompt)! i={} block_size={} table_len={}",
                            i,
                            paged_attn_metadata.block_size,
                            table.len()
                        );
                    } else {
                        table.get(i / paged_attn_metadata.block_size).unwrap()
                    };
                    let block_offset = i % paged_attn_metadata.block_size;
                    // Flat slot index into the paged KV cache.
                    let slot = block_number * paged_attn_metadata.block_size + block_offset;
                    slot_mapping.push(slot.try_into().unwrap());
                    // NOTE(review): this pushes one copy of the table per *token*,
                    // not per sequence, inflating `block_tables` — verify intended.
                    block_tables.push(table.clone());
                }
                slot_mappings.push(slot_mapping);
                paged_attn_context_lens.push(ctxt_len);
            }
        }

        let max_q = *seqlens_q.iter().max().unwrap();
        let max_k = *seqlens_k.iter().max().unwrap();

        let mut seqlens_q_map = HashMap::new();
        let mut seqlens_k_map = HashMap::new();

        if using_flash_attn() {
            // Build cu_seqlens tensors; round-trips through F32 for the cumsum
            // (presumably cumsum lacks a U32 kernel — TODO confirm).
            let seqlens_q = Tensor::new(seqlens_q, device)?
                .to_dtype(DType::F32)?
                .cumsum(0)?
                .to_dtype(DType::U32)?;
            let seqlens_k = Tensor::new(seqlens_k, device)?
                .to_dtype(DType::F32)?
                .cumsum(0)?
                .to_dtype(DType::U32)?;

            // Replicate onto every device the mapper uses.
            let devices = mapper.unwrap().get_unique_devices();
            for device in devices {
                seqlens_q_map.insert(device.location(), seqlens_q.to_device(&device)?);
                seqlens_k_map.insert(device.location(), seqlens_k.to_device(&device)?);
            }
        }

        // Stack the per-sequence rows into a (bs, max_len) batch.
        let input = Tensor::cat(&seqs_tensors, 0).unwrap();

        let paged_attn_meta = if paged_attn_metadata.is_some() {
            // Pad slot mappings to the longest and stack into one tensor.
            let max_slot_mapping_len = slot_mappings.iter().map(|x| x.len()).max().unwrap();
            let slot_mappings =
                _make_tensor_with_pad(slot_mappings, max_slot_mapping_len, _PAD_SLOT_ID, device)?;

            let max_block_table_len = block_tables.iter().map(|x| x.len()).max().unwrap();
            let block_tables = _make_tensor_with_pad(
                block_tables
                    .iter()
                    .map(|x| x.iter().map(|x| *x as u32).collect::<Vec<_>>())
                    .collect::<Vec<_>>(),
                max_block_table_len,
                0,
                device,
            )?;
            let block_tables = block_tables.reshape(((), max_block_table_len))?;

            let max_context_len = paged_attn_context_lens
                .iter()
                .map(|x| x.len())
                .max()
                .unwrap();

            let context_lens = _make_tensor_with_pad(
                paged_attn_context_lens
                    .iter()
                    .map(|x| x.iter().map(|x| *x as u32).collect::<Vec<_>>())
                    .collect::<Vec<_>>(),
                max_context_len,
                0,
                device,
            )?
            .reshape(((),))?;

            // For device mapping, make a copy of each tensor for each device
            let devices = mapper.unwrap().get_unique_devices();
            let mut slot_mappings_map = HashMap::new();
            let mut block_tables_map = HashMap::new();
            let mut context_lens_map = HashMap::new();

            for device in devices {
                slot_mappings_map
                    .insert(device.location(), slot_mappings.clone().to_device(&device)?);
                block_tables_map
                    .insert(device.location(), block_tables.clone().to_device(&device)?);
                context_lens_map
                    .insert(device.location(), context_lens.clone().to_device(&device)?);
            }

            Some(PagedAttentionInputMetadata {
                slot_mappings: slot_mappings_map,
                block_tables: Some(block_tables_map),
                context_lens: Some(context_lens_map),
                max_context_len: Some(max_context_len),
                // First chunk iff nothing was processed before this one.
                is_first_prompt_chunk: chunk_offset_toks == 0,
            })
        } else {
            None
        };

        Ok(InputMetadata {
            input,
            positions: seqlen_offsets,
            context_lens,
            position_ids,
            paged_attn_meta,
            flash_meta: FlashParams {
                max_k,
                max_q,
                cumulative_seqlens_k: seqlens_k_map,
                cumulative_seqlens_q: seqlens_q_map,
            },
        })
    }
348
    /// Build an [`InputMetadata`] for a single decode (completion) step: each
    /// sequence contributes only its final token, since everything before it is
    /// already in the KV cache.
    ///
    /// `mapper` is `unwrap`ed when flash or paged attention is active, so it must
    /// be `Some` in those cases.
    fn make_completion_chunk<T: WithDType>(
        toks: Vec<Vec<T>>,
        input_seqs: &[&mut Sequence],
        device: &Device,
        mut paged_attn_metadata: Option<&mut PagedAttentionMeta<'_>>,
        mapper: Option<&dyn DeviceMapper>,
    ) -> Result<InputMetadata> {
        // Pad each sequence by the padding token to the max len.
        let mut seqs_tensors = Vec::new();
        let mut seqlen_offsets = Vec::new();
        let mut context_lens = Vec::new();
        let mut position_ids = Vec::new();

        let mut slot_mappings = Vec::new();
        let mut block_tables = Vec::new();
        let mut paged_attn_context_lens = Vec::new();
        // Leading 0 so the cumsum below yields flash-attn style cu_seqlens.
        let mut seqlens_q = vec![0];
        let mut seqlens_k = vec![0];
        for (seq, ctxt) in input_seqs.iter().zip(toks) {
            // Keep only the last token of each sequence.
            let start_pos = ctxt.len().saturating_sub(1);
            let ctxt = ctxt[start_pos..].to_vec();
            seqlen_offsets.push(start_pos);
            // Exactly one logit per sequence is produced during decode.
            context_lens.push((0, 1));
            position_ids.push(seq.len());

            seqlens_q.push(ctxt.len() as u32);
            seqlens_k.push((ctxt.len() + start_pos) as u32);

            seqs_tensors.push(Tensor::new(ctxt, device).unwrap().unsqueeze(0).unwrap());

            if let Some(paged_attn_metadata) = &mut paged_attn_metadata {
                let table = paged_attn_metadata
                    .block_engine
                    .block_tables
                    .get(seq.id())
                    .unwrap();

                let table = table
                    .iter()
                    .map(|block| block.deref_mut().block_id)
                    .collect::<Vec<_>>();

                // Locate the KV-cache slot where the new token will be written.
                let block_number = if start_pos / paged_attn_metadata.block_size >= table.len() {
                    panic!("Block table is too small (completion)! start_pos={} block_size={} table_len={}", start_pos, paged_attn_metadata.block_size, table.len());
                } else {
                    table
                        .get(start_pos / paged_attn_metadata.block_size)
                        .unwrap()
                };
                let block_offset = start_pos % paged_attn_metadata.block_size;
                let slot = block_number * paged_attn_metadata.block_size + block_offset;
                let slot = slot.try_into().unwrap();
                slot_mappings.push(vec![slot]);

                if let Some(sliding_window) = paged_attn_metadata.sliding_window {
                    // Only expose the tail of the block table that falls inside
                    // the sliding window.
                    let sliding_window_blocks = sliding_window / paged_attn_metadata.block_size;
                    let slide_idx = if table.len() > sliding_window_blocks {
                        table.len() - sliding_window_blocks
                    } else {
                        0
                    };
                    block_tables.push(table.get(slide_idx..).unwrap().to_vec());
                } else {
                    block_tables.push(table);
                }

                // Effective context length, capped by the sliding window if any.
                let paged_attn_context_len =
                    if let Some(sliding_window) = paged_attn_metadata.sliding_window {
                        seq.len().min(sliding_window)
                    } else {
                        seq.len()
                    };
                paged_attn_context_lens.push(paged_attn_context_len);
            }
        }

        let max_q = *seqlens_q.iter().max().unwrap();
        let max_k = *seqlens_k.iter().max().unwrap();

        let mut seqlens_q_map = HashMap::new();
        let mut seqlens_k_map = HashMap::new();

        if using_flash_attn() {
            // cu_seqlens tensors; F32 round-trip for the cumsum (presumably cumsum
            // lacks a U32 kernel — TODO confirm).
            let seqlens_q = Tensor::new(seqlens_q, device)?
                .to_dtype(DType::F32)?
                .cumsum(0)?
                .to_dtype(DType::U32)?;
            let seqlens_k = Tensor::new(seqlens_k, device)?
                .to_dtype(DType::F32)?
                .cumsum(0)?
                .to_dtype(DType::U32)?;

            // Replicate onto every device the mapper uses.
            let devices = mapper.unwrap().get_unique_devices();
            for device in devices {
                seqlens_q_map.insert(device.location(), seqlens_q.to_device(&device)?);
                seqlens_k_map.insert(device.location(), seqlens_k.to_device(&device)?);
            }
        }

        let paged_attn_meta = if paged_attn_metadata.is_some() {
            // Decode writes exactly one token per sequence, so pad width is 1.
            let slot_mappings = _make_tensor_with_pad(slot_mappings, 1, _PAD_SLOT_ID, device)?;

            let max_block_table_len = block_tables.iter().map(|x| x.len()).max().unwrap();

            let block_tables = _make_tensor_with_pad(
                block_tables
                    .iter()
                    .map(|x| x.iter().map(|x| *x as u32).collect::<Vec<_>>())
                    .collect::<Vec<_>>(),
                max_block_table_len,
                0,
                device,
            )?;
            let block_tables = block_tables.reshape(((), max_block_table_len))?;

            let max_context_len = paged_attn_context_lens.iter().max().unwrap();

            let context_lens = Tensor::from_vec(
                paged_attn_context_lens
                    .iter()
                    .map(|x| *x as u32)
                    .collect::<Vec<_>>(),
                (paged_attn_context_lens.len(),),
                device,
            )?;

            // For device mapping, make a copy of each tensor for each device
            let devices = mapper.unwrap().get_unique_devices();
            let mut slot_mappings_map = HashMap::new();
            let mut block_tables_map = HashMap::new();
            let mut context_lens_map = HashMap::new();

            for device in devices {
                slot_mappings_map
                    .insert(device.location(), slot_mappings.clone().to_device(&device)?);
                block_tables_map
                    .insert(device.location(), block_tables.clone().to_device(&device)?);
                context_lens_map
                    .insert(device.location(), context_lens.clone().to_device(&device)?);
            }

            Some(PagedAttentionInputMetadata {
                slot_mappings: slot_mappings_map,
                block_tables: Some(block_tables_map),
                context_lens: Some(context_lens_map),
                max_context_len: Some(*max_context_len),
                // A decode step is never the first prompt chunk.
                is_first_prompt_chunk: false,
            })
        } else {
            None
        };

        Ok(InputMetadata {
            input: Tensor::cat(&seqs_tensors, 0).unwrap(),
            positions: seqlen_offsets,
            context_lens,
            position_ids,
            paged_attn_meta,
            flash_meta: FlashParams {
                max_k,
                max_q,
                cumulative_seqlens_k: seqlens_k_map,
                cumulative_seqlens_q: seqlens_q_map,
            },
        })
    }
515
516    #[allow(clippy::too_many_arguments)]
517    pub(crate) fn get_prompt_input<T: WithDType + std::fmt::Debug>(
518        toks: Vec<Vec<T>>,
519        input_seqs: &[&mut Sequence],
520        device: &Device,
521        last_n_context_len: Option<(usize, usize)>,
522        return_raw_logits: bool,
523        mut paged_attn_metadata: Option<&mut PagedAttentionMeta<'_>>,
524        prompt_chunksize: Option<NonZeroUsize>,
525        mapper: Option<&dyn DeviceMapper>,
526    ) -> Box<dyn Iterator<Item = Result<InnerInputProcessorOutput>>> {
527        if let (Some(prompt_chunksize), true) = (prompt_chunksize, paged_attn_metadata.is_none()) {
528            let mut seq_chunks = Vec::new();
529            let mut n_chunks = Vec::new();
530            let prompt_chunksize: usize = prompt_chunksize.into();
531
532            // This comes from prefix caching
533            // The invariant where all token offsets are the same is handled by the scheduler
534            let offset = input_seqs[0].token_offset();
535
536            // Pad each sequence by the padding token to the max len.
537            for ctxt in toks.iter() {
538                let chunks = ctxt.chunks(prompt_chunksize).collect::<Vec<_>>();
539                n_chunks.push(chunks.len());
540                seq_chunks.push(chunks);
541            }
542            // Basically convert the sequences and tok chunks into chunks of seqs and the corresp toks
543            let mut chunks_transposed: Vec<Vec<(Vec<T>, usize)>> = Vec::new();
544            for (seq_n, seq) in seq_chunks.into_iter().enumerate() {
545                for (i, chunk) in seq.into_iter().enumerate() {
546                    match chunks_transposed.get_mut(i) {
547                        Some(part) => part.push((chunk.to_vec(), seq_n)),
548                        None => chunks_transposed.push(vec![(chunk.to_vec(), seq_n)]),
549                    }
550                }
551            }
552            let chunks = chunks_transposed
553                .into_iter()
554                .enumerate()
555                .map(|(i, chunk)| {
556                    let (toks, seq_ns): (Vec<Vec<T>>, Vec<usize>) = chunk.into_iter().unzip();
557                    make_prompt_chunk(
558                        i * prompt_chunksize + offset,
559                        toks,
560                        &seq_ns
561                            .iter()
562                            .map(|i| *input_seqs[*i].id())
563                            .collect::<Vec<_>>(),
564                        device,
565                        last_n_context_len,
566                        return_raw_logits,
567                        paged_attn_metadata.as_deref_mut(),
568                        mapper,
569                    )
570                    .map(|inputs| InnerInputProcessorOutput {
571                        inputs,
572                        seq_indices: seq_ns,
573                    })
574                })
575                .collect::<Vec<_>>();
576            Box::new(chunks.into_iter())
577        } else {
578            let offset = input_seqs[0].token_offset();
579            if offset != 0 && paged_attn_metadata.is_some() {
580                return Box::new(std::iter::once(Err(anyhow::Error::msg(
581                    "PagedAttention does not yet support sequences with an offset != 0.",
582                ))));
583            }
584            Box::new(std::iter::once(
585                make_prompt_chunk(
586                    offset,
587                    toks,
588                    &input_seqs.iter().map(|s| *s.id()).collect::<Vec<_>>(),
589                    device,
590                    last_n_context_len,
591                    return_raw_logits,
592                    paged_attn_metadata,
593                    mapper,
594                )
595                .map(|inputs| InnerInputProcessorOutput {
596                    inputs,
597                    seq_indices: (0..input_seqs.len()).collect(),
598                }),
599            ))
600        }
601    }
602
603    #[allow(clippy::too_many_arguments)]
604    pub(crate) fn get_completion_input<T: WithDType + std::fmt::Debug>(
605        toks: Vec<Vec<T>>,
606        input_seqs: &[&mut Sequence],
607        device: &Device,
608        no_kv_cache: bool,
609        last_n_context_len: Option<(usize, usize)>,
610        return_raw_logits: bool,
611        paged_attn_metadata: Option<&mut PagedAttentionMeta<'_>>,
612        prompt_chunksize: Option<NonZeroUsize>,
613        mapper: Option<&dyn DeviceMapper>,
614    ) -> Box<dyn Iterator<Item = Result<InnerInputProcessorOutput>>> {
615        if no_kv_cache {
616            return get_prompt_input(
617                toks,
618                input_seqs,
619                device,
620                last_n_context_len,
621                return_raw_logits,
622                paged_attn_metadata,
623                prompt_chunksize,
624                mapper,
625            );
626        }
627
628        Box::new(std::iter::once(
629            make_completion_chunk(toks, input_seqs, device, paged_attn_metadata, mapper).map(
630                |inputs| InnerInputProcessorOutput {
631                    inputs,
632                    seq_indices: (0..input_seqs.len()).collect(),
633                },
634            ),
635        ))
636    }
637
    /// Fully-prepared inputs for a text model forward pass. The `_full` fields are
    /// populated only for the X-LoRA path, which needs both a whole-prompt pass and
    /// the regular (cached) pass.
    #[derive(Clone)]
    pub struct ModelInputs {
        /// Token ids for this pass.
        pub input_ids: Tensor,
        /// Full-prompt token ids (X-LoRA only).
        pub input_ids_full: Option<Tensor>,
        /// Per-sequence position offsets.
        pub seqlen_offsets: Vec<usize>,
        /// Full-prompt position offsets (X-LoRA only).
        pub seqlen_offsets_full: Option<Vec<usize>>,
        // (start index, len) pairs selecting which logits to keep per sequence.
        pub context_lens: Vec<(usize, usize)>,
        pub position_ids: Vec<usize>,
        /// Paged-attention metadata, when paged attention is active.
        pub paged_attn_meta: Option<PagedAttentionInputMetadata>,
        pub flash_meta: FlashParams,
        /// Flash params for the full-prompt pass (X-LoRA only).
        pub flash_meta_full: Option<FlashParams>,
    }
650
    /// Input processor for text-only models (no image/vision handling).
    pub struct TextInputsProcessor;
652
653    impl InputsProcessor for TextInputsProcessor {
654        fn process_inputs(
655            &self,
656            _: Option<Arc<Tokenizer>>,
657            input_seqs: &mut [&mut Sequence],
658            is_prompt: bool,
659            is_xlora: bool,
660            device: &Device,
661            no_kv_cache: bool,
662            last_n_context_len: Option<(usize, usize)>,
663            return_raw_logits: bool,
664            _: Option<Arc<dyn Any>>,
665            mut paged_attn_metadata: Option<PagedAttentionMeta<'_>>,
666            prompt_chunksize: Option<NonZeroUsize>,
667            mapper: Option<&dyn DeviceMapper>,
668        ) -> Box<dyn Iterator<Item = Result<InputProcessorOutput>>> {
669            if is_xlora && !is_prompt {
670                Box::new(
671                    get_prompt_input(
672                        input_seqs
673                            .iter()
674                            .map(|seq| seq.get_toks().to_vec())
675                            .collect::<Vec<_>>(),
676                        input_seqs,
677                        device,
678                        last_n_context_len,
679                        return_raw_logits,
680                        paged_attn_metadata.as_mut(),
681                        prompt_chunksize,
682                        mapper,
683                    )
684                    .zip(get_completion_input(
685                        input_seqs
686                            .iter()
687                            .map(|seq| seq.get_toks().to_vec())
688                            .collect::<Vec<_>>(),
689                        input_seqs,
690                        device,
691                        no_kv_cache,
692                        last_n_context_len,
693                        return_raw_logits,
694                        paged_attn_metadata.as_mut(),
695                        prompt_chunksize,
696                        mapper,
697                    ))
698                    .map(|(prompt, completion)| {
699                        let InnerInputProcessorOutput {
700                            inputs:
701                                InputMetadata {
702                                    input: input_ids_full,
703                                    positions: seqlen_offsets_full,
704                                    context_lens: _,
705                                    position_ids,
706                                    paged_attn_meta: _,
707                                    flash_meta: flash_meta_full,
708                                },
709                            seq_indices,
710                        } = prompt?;
711                        let InnerInputProcessorOutput {
712                            inputs:
713                                InputMetadata {
714                                    input: input_ids,
715                                    positions: seqlen_offsets,
716                                    context_lens,
717                                    position_ids: _,
718                                    paged_attn_meta,
719                                    flash_meta,
720                                },
721                            seq_indices: _,
722                        } = completion?;
723                        let inputs: Box<dyn Any> = Box::new(ModelInputs {
724                            input_ids,
725                            input_ids_full: Some(input_ids_full),
726                            seqlen_offsets,
727                            seqlen_offsets_full: Some(seqlen_offsets_full),
728                            context_lens,
729                            position_ids,
730                            paged_attn_meta,
731                            flash_meta,
732                            flash_meta_full: Some(flash_meta_full),
733                        });
734                        Ok(InputProcessorOutput {
735                            inputs,
736                            seq_indices,
737                        })
738                    }),
739                )
740            } else if is_xlora && is_prompt {
741                Box::new(
742                    get_prompt_input(
743                        input_seqs
744                            .iter()
745                            .map(|seq| seq.get_toks().to_vec())
746                            .collect::<Vec<_>>(),
747                        input_seqs,
748                        device,
749                        last_n_context_len,
750                        return_raw_logits,
751                        paged_attn_metadata.as_mut(),
752                        prompt_chunksize,
753                        mapper,
754                    )
755                    .map(|metadata| {
756                        let InnerInputProcessorOutput {
757                            inputs:
758                                InputMetadata {
759                                    input: input_ids,
760                                    positions: seqlen_offsets,
761                                    context_lens,
762                                    position_ids,
763                                    paged_attn_meta,
764                                    flash_meta,
765                                },
766                            seq_indices,
767                        } = metadata?;
768                        let inputs: Box<dyn Any> = Box::new(ModelInputs {
769                            input_ids: input_ids.clone(),
770                            input_ids_full: Some(input_ids),
771                            seqlen_offsets: seqlen_offsets.clone(),
772                            seqlen_offsets_full: Some(seqlen_offsets),
773                            context_lens,
774                            position_ids,
775                            paged_attn_meta,
776                            flash_meta: flash_meta.clone(),
777                            flash_meta_full: Some(flash_meta),
778                        });
779                        Ok(InputProcessorOutput {
780                            inputs,
781                            seq_indices,
782                        })
783                    }),
784                )
785            } else if is_prompt {
786                Box::new(
787                    get_prompt_input(
788                        input_seqs
789                            .iter()
790                            .map(|seq| seq.get_toks().to_vec())
791                            .collect::<Vec<_>>(),
792                        input_seqs,
793                        device,
794                        last_n_context_len,
795                        return_raw_logits,
796                        paged_attn_metadata.as_mut(),
797                        prompt_chunksize,
798                        mapper,
799                    )
800                    .map(|metadata| {
801                        let InnerInputProcessorOutput {
802                            inputs:
803                                InputMetadata {
804                                    input: input_ids,
805                                    positions: seqlen_offsets,
806                                    context_lens,
807                                    position_ids,
808                                    paged_attn_meta,
809                                    flash_meta,
810                                },
811                            seq_indices,
812                        } = metadata?;
813                        let inputs: Box<dyn Any> = Box::new(ModelInputs {
814                            input_ids,
815                            input_ids_full: None,
816                            seqlen_offsets,
817                            seqlen_offsets_full: None,
818                            context_lens,
819                            position_ids,
820                            paged_attn_meta,
821                            flash_meta,
822                            flash_meta_full: None,
823                        });
824                        Ok(InputProcessorOutput {
825                            inputs,
826                            seq_indices,
827                        })
828                    }),
829                )
830            } else {
831                Box::new(
832                    get_completion_input(
833                        input_seqs
834                            .iter()
835                            .map(|seq| seq.get_toks().to_vec())
836                            .collect::<Vec<_>>(),
837                        input_seqs,
838                        device,
839                        no_kv_cache,
840                        last_n_context_len,
841                        return_raw_logits,
842                        paged_attn_metadata.as_mut(),
843                        prompt_chunksize,
844                        mapper,
845                    )
846                    .map(|metadata| {
847                        let InnerInputProcessorOutput {
848                            inputs:
849                                InputMetadata {
850                                    input: input_ids,
851                                    positions: seqlen_offsets,
852                                    context_lens,
853                                    position_ids,
854                                    paged_attn_meta,
855                                    flash_meta,
856                                },
857                            seq_indices,
858                        } = metadata?;
859                        let inputs: Box<dyn Any> = Box::new(ModelInputs {
860                            input_ids,
861                            input_ids_full: None,
862                            seqlen_offsets,
863                            seqlen_offsets_full: None,
864                            context_lens,
865                            position_ids,
866                            paged_attn_meta,
867                            flash_meta,
868                            flash_meta_full: None,
869                        });
870                        Ok(InputProcessorOutput {
871                            inputs,
872                            seq_indices,
873                        })
874                    }),
875                )
876            }
877        }
878
        fn get_type(&self) -> InputsProcessorType {
            // This processor handles token-only inputs (no image preprocessing),
            // so it always identifies itself as the `Text` variant.
            InputsProcessorType::Text
        }
882    }
883}