mistralrs_core/pipeline/inputs_processor.rs

1#![allow(clippy::cast_possible_truncation)]
2
3use std::{any::Any, num::NonZeroUsize, sync::Arc};
4
5use anyhow::Result;
6use candle_core::Device;
7use text_models_inputs_processor::PagedAttentionMeta;
8use tokenizers::Tokenizer;
9
10use crate::{device_map::DeviceMapper, sequence::Sequence};
11
/// Default number of prompt tokens processed per prefill chunk.
pub const DEFAULT_PROMPT_CHUNK_SIZE: usize = 1024;
13
/// Discriminates the two families of input processors.
///
/// Previously only `PartialEq` was derived; the standard `Debug`/`Clone`/
/// `Copy`/`Eq` derives are free for a fieldless enum and make the type
/// printable and trivially copyable.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum InputsProcessorType {
    /// Text-only input processing.
    Text,
    /// Vision input processing (may also prepare images, per the trait docs).
    Vision,
}
19
/// One item yielded by [`InputsProcessor::process_inputs`]: type-erased model
/// inputs plus the indices of the sequences they cover.
pub struct InputProcessorOutput {
    /// Model inputs; downcast to the concrete type expected by `forward_inputs`.
    pub inputs: Box<dyn Any>,
    /// Indices (into the processed batch) of the sequences these inputs cover.
    pub seq_indices: Vec<usize>,
}
24
25/// Processor: Prepare inputs for the model (potentially preparing the images if applicable)
/// Processor: Prepare inputs for the model (potentially preparing the images if applicable)
pub trait InputsProcessor {
    /// This should also enable matmul via f16 if prompt and the sequence length is greater than 32.
    /// Otherwise, matmul via f16 is disabled.
    ///
    /// This should return a type which can be downcasted to the proper type as used in `forward_inputs`
    ///
    /// Each yielded item pairs the type-erased inputs with the indices of the
    /// sequences (into `input_seqs`) they cover; more than one item is yielded
    /// when the prompt is processed in chunks (see `prompt_chunksize`).
    #[allow(clippy::too_many_arguments)]
    fn process_inputs(
        &self,
        tokenizer: Option<Arc<Tokenizer>>,
        input_seqs: &mut [&mut Sequence],
        is_prompt: bool,
        is_xlora: bool,
        device: &Device,
        no_kv_cache: bool,
        last_n_context_len: Option<(usize, usize)>,
        return_raw_logits: bool,
        other_config: Option<Arc<dyn Any>>,
        paged_attn_metadata: Option<PagedAttentionMeta>,
        prompt_chunksize: Option<NonZeroUsize>,
        mapper: Option<&dyn DeviceMapper>,
    ) -> Box<dyn Iterator<Item = Result<InputProcessorOutput>>>;

    /// Whether this processor handles text or vision inputs.
    fn get_type(&self) -> InputsProcessorType;
}
50
// ========================= Text models input processor
52
53pub mod text_models_inputs_processor {
54    use std::{any::Any, collections::HashMap, fmt::Debug, num::NonZeroUsize, sync::Arc};
55
56    use anyhow::Result;
57    use candle_core::{DType, Device, DeviceLocation, Tensor, WithDType};
58    use tokenizers::Tokenizer;
59
60    use crate::{
61        device_map::DeviceMapper,
62        get_mut_arcmutex,
63        paged_attention::{BlockEngine, _PAD_SLOT_ID},
64        sequence::Sequence,
65    };
66
67    use super::{InputProcessorOutput, InputsProcessor, InputsProcessorType};
68
69    fn _make_tensor_with_pad<D: WithDType>(
70        x: Vec<Vec<D>>,
71        max_len: usize,
72        pad: D,
73        device: &Device,
74    ) -> Result<Tensor> {
75        let mut padded_x = Vec::new();
76        for mut x_i in x {
77            assert!(x_i.len() <= max_len);
78            x_i.extend([pad].repeat(max_len - x_i.len()));
79            let shape = (x_i.len(),);
80            padded_x.push(Tensor::from_vec(x_i, shape, device)?);
81        }
82        Tensor::cat(&padded_x[..], 0).map_err(anyhow::Error::msg)
83    }
84
    /// Scheduler-provided state needed to build PagedAttention inputs.
    pub struct PagedAttentionMeta {
        /// Sliding-window length in tokens, if the model uses one.
        pub sliding_window: Option<usize>,
        /// Number of KV-cache slots per block (slot = block_id * block_size + offset).
        pub block_size: usize,
        /// Shared block engine holding the per-sequence block tables.
        pub block_engine: Arc<tokio::sync::Mutex<BlockEngine>>,
    }
90
    /// Per-device tensors consumed by the PagedAttention kernels. Each map is
    /// keyed by `DeviceLocation` so device-mapped models can fetch the copy
    /// living on a given layer's device.
    #[derive(Clone, Debug)]
    #[allow(dead_code)]
    pub struct PagedAttentionInputMetadata {
        // Block tables per sequence; `None` e.g. for the dummy metadata.
        pub block_tables: Option<HashMap<DeviceLocation, Tensor>>,
        // Context lengths per sequence; `None` e.g. for the dummy metadata.
        pub context_lens: Option<HashMap<DeviceLocation, Tensor>>,
        // KV-cache slot index for each token to be written.
        pub slot_mappings: HashMap<DeviceLocation, Tensor>,
        // Maximum context length across the batch.
        pub max_context_len: Option<usize>,
        // True only when processing the chunk at token offset 0 (see
        // `make_prompt_chunk`); always false for decode steps.
        pub is_first_prompt_chunk: bool,
    }
100
101    impl PagedAttentionInputMetadata {
102        /// Create a dummy input metadata, assuming that this will NOT be used for decoding.
103        /// This is used for the case of imatrix generation.
104        pub fn dummy(dev: &Device) -> candle_core::Result<Self> {
105            Ok(PagedAttentionInputMetadata {
106                block_tables: None,
107                context_lens: None,
108                max_context_len: None,
109                slot_mappings: HashMap::from([(dev.location(), Tensor::new(&[0f32], dev)?)]),
110                is_first_prompt_chunk: true,
111            })
112        }
113    }
114
    /// Cumulative sequence-length tensors for flash-attention varlen kernels,
    /// replicated per device for device-mapped models.
    #[derive(Clone, Debug)]
    pub struct FlashParams {
        // Maximum query length in the batch.
        pub max_q: u32,
        // Maximum key length (query + cached context) in the batch.
        pub max_k: u32,
        // Prefix-summed query lengths, keyed by device location.
        pub cumulative_seqlens_q: HashMap<DeviceLocation, Tensor>,
        // Prefix-summed key lengths, keyed by device location.
        pub cumulative_seqlens_k: HashMap<DeviceLocation, Tensor>,
    }
122
    /// Everything a text model needs for one forward pass over a batch.
    pub struct InputMetadata {
        // Padded token-id tensor; rows are built with `unsqueeze(0)` then
        // concatenated on dim 0, so shape is (batch, padded_len).
        pub input: Tensor,
        // Starting KV position per sequence.
        pub positions: Vec<usize>,
        pub context_lens: Vec<(usize, usize)>, // (start index, len)
        // Absolute position (tokens seen so far, incl. this chunk) per sequence.
        pub position_ids: Vec<usize>,
        pub paged_attn_meta: Option<PagedAttentionInputMetadata>, // For paged attention
        pub flash_meta: FlashParams,
    }
131
    /// Non-type-erased counterpart of `InputProcessorOutput`: the input
    /// metadata plus the batch indices of the sequences it covers.
    pub struct InnerInputProcessorOutput {
        pub inputs: InputMetadata,
        // Indices into the original `input_seqs` batch covered by `inputs`.
        pub seq_indices: Vec<usize>,
    }
136
    // chunk_offset_toks is the number of tokens by which the tokens are offset,
    // chunk_offset_toks / prompt_chunksize = number of batches
    /// Build the `InputMetadata` for one prompt (prefill) chunk.
    ///
    /// Right-pads every sequence in `toks` to the longest one with `T::zero()`,
    /// records position/context bookkeeping, flash-attention cumulative lengths,
    /// and — when `paged_attn_metadata` is supplied — the PagedAttention slot
    /// mappings, block tables, and context lengths.
    ///
    /// Panics if `toks` is empty, and (via `mapper.unwrap()`) if `mapper` is
    /// `None` while flash attention or paged attention is active.
    #[allow(clippy::too_many_arguments)]
    pub fn make_prompt_chunk<T: WithDType + Debug>(
        chunk_offset_toks: usize,
        toks: Vec<&[T]>,
        seq_ids: &[usize],
        device: &Device,
        last_n_context_len: Option<(usize, usize)>,
        return_raw_logits: bool,
        mut paged_attn_metadata: Option<&mut PagedAttentionMeta>,
        mapper: Option<&dyn DeviceMapper>,
    ) -> Result<InputMetadata> {
        let max_len = toks
            .iter()
            .map(|seq| seq.len())
            .max()
            .expect("No sequences");
        let padding_tok = T::zero();
        // Pad each sequence by the padding token to the max len.
        let mut seqs_tensors = Vec::new();
        let mut seqlen_offsets = Vec::new();
        let mut context_lens = Vec::new();
        let mut position_ids = Vec::new();
        let mut slot_mappings = Vec::new();
        let mut block_tables = Vec::new();
        let mut paged_attn_context_lens = Vec::new();
        let flash_attn = crate::using_flash_attn();
        // Cumulative-length vectors are seeded with 0 so the cumsum below
        // yields proper prefix sums starting at 0.
        let mut seqlens_q = if flash_attn { vec![0] } else { Vec::new() };
        let mut seqlens_k = if flash_attn { vec![0] } else { Vec::new() };
        for (seq_id, ctxt) in seq_ids.iter().zip(toks) {
            let prompt_len = ctxt.len();
            // KV offset for this chunk: optional last_n offset (second tuple
            // element) plus the tokens consumed by earlier chunks.
            let offset = last_n_context_len.unwrap_or_default();
            seqlen_offsets.push(offset.1 + chunk_offset_toks);

            position_ids.push(ctxt.len() + chunk_offset_toks);
            let mut ctxt = ctxt.to_vec();
            // Right-pad to the longest sequence in the batch.
            ctxt.extend(std::iter::repeat_n(
                padding_tok,
                max_len.saturating_sub(ctxt.len()),
            ));
            // If we are returning raw logits, we want to not trim the logits at all.
            if return_raw_logits {
                if last_n_context_len.is_some() {
                    anyhow::bail!("`return_raw_logits` is incompatible with `last_n_context_len`");
                }

                context_lens.push((0, ctxt.len()));
            } else {
                // Otherwise keep only the last `last_n_context_len.0` logits
                // (default: just the final one).
                context_lens.push((
                    ctxt.len()
                        .saturating_sub(last_n_context_len.map(|(a, _)| a).unwrap_or(1)),
                    last_n_context_len.map(|(a, _)| a).unwrap_or(1),
                ));
            }

            if flash_attn {
                seqlens_q.push(ctxt.len() as u32);
                // Keys also cover the tokens cached by earlier chunks.
                seqlens_k.push((ctxt.len() + chunk_offset_toks) as u32);
            }

            seqs_tensors.push(Tensor::new(ctxt, device).unwrap().unsqueeze(0).unwrap());

            if let Some(paged_attn_metadata) = &mut paged_attn_metadata {
                let block_engine = get_mut_arcmutex!(paged_attn_metadata.block_engine);
                let table = block_engine.block_tables.get(seq_id);

                if table.is_none() {
                    // Will be None during profiling.
                    slot_mappings.push([_PAD_SLOT_ID].repeat(prompt_len));
                    continue;
                }
                let table = table
                    .unwrap()
                    .iter()
                    .map(|block| block.deref_mut().block_id)
                    .collect::<Vec<_>>();

                // With a sliding window only the last `sliding_window` tokens
                // of the prompt receive real slots.
                let start_idx = if let Some(sliding_window) = paged_attn_metadata.sliding_window {
                    prompt_len.saturating_sub(sliding_window)
                } else {
                    0
                };

                let mut slot_mapping = Vec::new();
                let mut ctxt_len = Vec::new();
                for i in chunk_offset_toks..prompt_len + chunk_offset_toks {
                    if i < start_idx {
                        // Pad [0,start_idx) with _PAD_TOKEN_ID
                        // NOTE(review): no `continue` here, so for i < start_idx
                        // both the pad slot AND the computed slot below are
                        // pushed — confirm this is the intended behavior.
                        slot_mapping.push(_PAD_SLOT_ID);
                    }
                    ctxt_len.push(i);

                    let block_number = if i / paged_attn_metadata.block_size >= table.len() {
                        panic!(
                            "Block table is too small (prompt)! i={} block_size={} table_len={}",
                            i,
                            paged_attn_metadata.block_size,
                            table.len()
                        );
                    } else {
                        table.get(i / paged_attn_metadata.block_size).unwrap()
                    };
                    let block_offset = i % paged_attn_metadata.block_size;
                    let slot = block_number * paged_attn_metadata.block_size + block_offset;
                    slot_mapping.push(slot.try_into().unwrap());
                    // NOTE(review): this clones the whole table once per token
                    // (inside the token loop), not once per sequence as in
                    // `make_completion_chunk` — confirm whether it belongs
                    // outside this loop.
                    block_tables.push(table.clone());
                }
                slot_mappings.push(slot_mapping);
                paged_attn_context_lens.push(ctxt_len);
            }
        }

        // Prefix-sum the per-sequence lengths for flash attention and replicate
        // the resulting tensors onto every device of the mapping.
        let (max_q, max_k, seqlens_q_map, seqlens_k_map) = if flash_attn {
            let max_q = *seqlens_q.iter().max().unwrap();
            let max_k = *seqlens_k.iter().max().unwrap();
            // Round-trip through f32 for the cumsum, then back to u32.
            let seqlens_q = Tensor::new(seqlens_q, device)?
                .to_dtype(DType::F32)?
                .cumsum(0)?
                .to_dtype(DType::U32)?;
            let seqlens_k = Tensor::new(seqlens_k, device)?
                .to_dtype(DType::F32)?
                .cumsum(0)?
                .to_dtype(DType::U32)?;

            let mut seqlens_q_map = HashMap::new();
            let mut seqlens_k_map = HashMap::new();

            let devices = mapper.unwrap().get_unique_devices();
            for device in devices {
                seqlens_q_map.insert(device.location(), seqlens_q.to_device(&device)?);
                seqlens_k_map.insert(device.location(), seqlens_k.to_device(&device)?);
            }
            (max_q, max_k, seqlens_q_map, seqlens_k_map)
        } else {
            (0, 0, HashMap::new(), HashMap::new())
        };

        let input = Tensor::cat(&seqs_tensors, 0).unwrap();

        let paged_attn_meta = if paged_attn_metadata.is_some() {
            // Pad the ragged per-sequence vectors into rectangular tensors.
            let max_slot_mapping_len = slot_mappings.iter().map(|x| x.len()).max().unwrap();
            let slot_mappings =
                _make_tensor_with_pad(slot_mappings, max_slot_mapping_len, _PAD_SLOT_ID, device)?;

            let max_block_table_len = block_tables.iter().map(|x| x.len()).max().unwrap();
            let block_tables = _make_tensor_with_pad(
                block_tables
                    .iter()
                    .map(|x| x.iter().map(|x| *x as u32).collect::<Vec<_>>())
                    .collect::<Vec<_>>(),
                max_block_table_len,
                0,
                device,
            )?;
            let block_tables = block_tables.reshape(((), max_block_table_len))?;

            let max_context_len = paged_attn_context_lens
                .iter()
                .map(|x| x.len())
                .max()
                .unwrap();

            let context_lens = _make_tensor_with_pad(
                paged_attn_context_lens
                    .iter()
                    .map(|x| x.iter().map(|x| *x as u32).collect::<Vec<_>>())
                    .collect::<Vec<_>>(),
                max_context_len,
                0,
                device,
            )?
            .reshape(((),))?;

            // For device mapping, make a copy of each tensor for each device
            let devices = mapper.unwrap().get_unique_devices();
            let mut slot_mappings_map = HashMap::new();
            let mut block_tables_map = HashMap::new();
            let mut context_lens_map = HashMap::new();

            for device in devices {
                slot_mappings_map
                    .insert(device.location(), slot_mappings.clone().to_device(&device)?);
                block_tables_map
                    .insert(device.location(), block_tables.clone().to_device(&device)?);
                context_lens_map
                    .insert(device.location(), context_lens.clone().to_device(&device)?);
            }

            Some(PagedAttentionInputMetadata {
                slot_mappings: slot_mappings_map,
                block_tables: Some(block_tables_map),
                context_lens: Some(context_lens_map),
                max_context_len: Some(max_context_len),
                // Only the chunk at token offset 0 is the first prompt chunk.
                is_first_prompt_chunk: chunk_offset_toks == 0,
            })
        } else {
            None
        };

        Ok(InputMetadata {
            input,
            positions: seqlen_offsets,
            context_lens,
            position_ids,
            paged_attn_meta,
            flash_meta: FlashParams {
                max_k,
                max_q,
                cumulative_seqlens_k: seqlens_k_map,
                cumulative_seqlens_q: seqlens_q_map,
            },
        })
    }
351
    /// Build the `InputMetadata` for a single decode (completion) step: only
    /// the last token of each sequence is fed to the model; the rest is served
    /// from the KV cache.
    ///
    /// Panics (via `mapper.unwrap()`) if `mapper` is `None` while flash or
    /// paged attention is active, and if a sequence has no block table.
    fn make_completion_chunk<T: WithDType>(
        toks: Vec<&[T]>,
        input_seqs: &[&mut Sequence],
        device: &Device,
        mut paged_attn_metadata: Option<&mut PagedAttentionMeta>,
        mapper: Option<&dyn DeviceMapper>,
    ) -> Result<InputMetadata> {
        // Pad each sequence by the padding token to the max len.
        let flash_attn = crate::using_flash_attn();
        let mut seqs_tensors = Vec::new();
        let mut seqlen_offsets = Vec::new();
        let mut context_lens = Vec::new();
        let mut position_ids = Vec::new();

        let mut slot_mappings = Vec::new();
        let mut block_tables = Vec::new();
        let mut paged_attn_context_lens = Vec::new();
        // Seeded with 0 so the cumsum below yields prefix sums starting at 0.
        let mut seqlens_q = if flash_attn { vec![0] } else { Vec::new() };
        let mut seqlens_k = if flash_attn { vec![0] } else { Vec::new() };
        for (seq, ctxt) in input_seqs.iter().zip(toks) {
            // Everything before the final token is already in the KV cache.
            let start_pos = ctxt.len().saturating_sub(1);
            let ctxt = ctxt[start_pos..].to_vec();
            seqlen_offsets.push(start_pos);
            // Only the single new logit is of interest.
            context_lens.push((0, 1));
            position_ids.push(seq.len());

            if flash_attn {
                seqlens_q.push(ctxt.len() as u32);
                seqlens_k.push((ctxt.len() + start_pos) as u32);
            }

            seqs_tensors.push(Tensor::new(ctxt, device).unwrap().unsqueeze(0).unwrap());

            if let Some(paged_attn_metadata) = &mut paged_attn_metadata {
                let block_engine = get_mut_arcmutex!(paged_attn_metadata.block_engine);
                let table = block_engine.block_tables.get(seq.id()).unwrap();

                let table = table
                    .iter()
                    .map(|block| block.deref_mut().block_id)
                    .collect::<Vec<_>>();

                // Locate the KV slot for the token being decoded, adjusted by
                // the sequence's token offset.
                let block_pos = start_pos - seq.token_offset();
                let block_number = if block_pos / paged_attn_metadata.block_size >= table.len() {
                    panic!("Block table is too small (completion)! start_pos={} block_size={} table_len={}", block_pos, paged_attn_metadata.block_size, table.len());
                } else {
                    table
                        .get(block_pos / paged_attn_metadata.block_size)
                        .unwrap()
                };
                let block_offset = block_pos % paged_attn_metadata.block_size;
                let slot = block_number * paged_attn_metadata.block_size + block_offset;
                let slot = slot.try_into().unwrap();
                slot_mappings.push(vec![slot]);

                if let Some(sliding_window) = paged_attn_metadata.sliding_window {
                    // Keep only the newest blocks that cover the sliding window.
                    let sliding_window_blocks = sliding_window / paged_attn_metadata.block_size;
                    let slide_idx = if table.len() > sliding_window_blocks {
                        table.len() - sliding_window_blocks
                    } else {
                        0
                    };
                    block_tables.push(table.get(slide_idx..).unwrap().to_vec());
                } else {
                    block_tables.push(table);
                }

                // Context length is capped by the sliding window, if any.
                let paged_attn_context_len =
                    if let Some(sliding_window) = paged_attn_metadata.sliding_window {
                        seq.len().min(sliding_window)
                    } else {
                        seq.len()
                    };
                paged_attn_context_lens.push(paged_attn_context_len);
            }
        }

        // Prefix-sum the lengths for flash attention and replicate onto every
        // device of the mapping.
        let (max_q, max_k, seqlens_q_map, seqlens_k_map) = if flash_attn {
            let max_q = *seqlens_q.iter().max().unwrap();
            let max_k = *seqlens_k.iter().max().unwrap();
            // Round-trip through f32 for the cumsum, then back to u32.
            let seqlens_q = Tensor::new(seqlens_q, device)?
                .to_dtype(DType::F32)?
                .cumsum(0)?
                .to_dtype(DType::U32)?;
            let seqlens_k = Tensor::new(seqlens_k, device)?
                .to_dtype(DType::F32)?
                .cumsum(0)?
                .to_dtype(DType::U32)?;

            let mut seqlens_q_map = HashMap::new();
            let mut seqlens_k_map = HashMap::new();

            let devices = mapper.unwrap().get_unique_devices();
            for device in devices {
                seqlens_q_map.insert(device.location(), seqlens_q.to_device(&device)?);
                seqlens_k_map.insert(device.location(), seqlens_k.to_device(&device)?);
            }
            (max_q, max_k, seqlens_q_map, seqlens_k_map)
        } else {
            (0, 0, HashMap::new(), HashMap::new())
        };

        let paged_attn_meta = if paged_attn_metadata.is_some() {
            // One slot per sequence during decode, so pad width is 1.
            let slot_mappings = _make_tensor_with_pad(slot_mappings, 1, _PAD_SLOT_ID, device)?;

            let max_block_table_len = block_tables.iter().map(|x| x.len()).max().unwrap();

            let block_tables = _make_tensor_with_pad(
                block_tables
                    .iter()
                    .map(|x| x.iter().map(|x| *x as u32).collect::<Vec<_>>())
                    .collect::<Vec<_>>(),
                max_block_table_len,
                0,
                device,
            )?;
            let block_tables = block_tables.reshape(((), max_block_table_len))?;

            let max_context_len = paged_attn_context_lens.iter().max().unwrap();

            let context_lens = Tensor::from_vec(
                paged_attn_context_lens
                    .iter()
                    .map(|x| *x as u32)
                    .collect::<Vec<_>>(),
                (paged_attn_context_lens.len(),),
                device,
            )?;

            // For device mapping, make a copy of each tensor for each device
            let devices = mapper.unwrap().get_unique_devices();
            let mut slot_mappings_map = HashMap::new();
            let mut block_tables_map = HashMap::new();
            let mut context_lens_map = HashMap::new();

            for device in devices {
                slot_mappings_map
                    .insert(device.location(), slot_mappings.clone().to_device(&device)?);
                block_tables_map
                    .insert(device.location(), block_tables.clone().to_device(&device)?);
                context_lens_map
                    .insert(device.location(), context_lens.clone().to_device(&device)?);
            }

            Some(PagedAttentionInputMetadata {
                slot_mappings: slot_mappings_map,
                block_tables: Some(block_tables_map),
                context_lens: Some(context_lens_map),
                max_context_len: Some(*max_context_len),
                // A decode step is never the first prompt chunk.
                is_first_prompt_chunk: false,
            })
        } else {
            None
        };

        Ok(InputMetadata {
            input: Tensor::cat(&seqs_tensors, 0).unwrap(),
            positions: seqlen_offsets,
            context_lens,
            position_ids,
            paged_attn_meta,
            flash_meta: FlashParams {
                max_k,
                max_q,
                cumulative_seqlens_k: seqlens_k_map,
                cumulative_seqlens_q: seqlens_q_map,
            },
        })
    }
521
522    #[allow(clippy::too_many_arguments)]
523    pub(crate) fn get_prompt_input<T: WithDType + std::fmt::Debug>(
524        toks: Vec<&[T]>,
525        input_seqs: &[&mut Sequence],
526        device: &Device,
527        last_n_context_len: Option<(usize, usize)>,
528        return_raw_logits: bool,
529        paged_attn_metadata: Option<&mut PagedAttentionMeta>,
530        prompt_chunksize: Option<NonZeroUsize>,
531        mapper: Option<&dyn DeviceMapper>,
532    ) -> Box<dyn Iterator<Item = Result<InnerInputProcessorOutput>>> {
533        if let (Some(prompt_chunksize), true) = (prompt_chunksize, paged_attn_metadata.is_none()) {
534            let chunk_size = prompt_chunksize.get();
535            let offset = input_seqs[0].token_offset();
536            // Determine the maximum number of chunks across all sequences
537            let num_chunks = toks
538                .iter()
539                .map(|ctxt| ctxt.len().div_ceil(chunk_size))
540                .max()
541                .unwrap_or(0);
542
543            let mut outputs = Vec::with_capacity(num_chunks);
544            for chunk_idx in 0..num_chunks {
545                let mut slices = Vec::new();
546                let mut seq_ids = Vec::new();
547                let mut seq_indices = Vec::new();
548                for (seq_n, ctxt) in toks.iter().enumerate() {
549                    let start = chunk_idx * chunk_size;
550                    if start < ctxt.len() {
551                        let end = (start + chunk_size).min(ctxt.len());
552                        slices.push(&ctxt[start..end]);
553                        seq_indices.push(seq_n);
554                        seq_ids.push(*input_seqs[seq_n].id());
555                    }
556                }
557                let result = make_prompt_chunk(
558                    chunk_idx * chunk_size + offset,
559                    slices,
560                    &seq_ids,
561                    device,
562                    last_n_context_len,
563                    return_raw_logits,
564                    None,
565                    mapper,
566                )
567                .map(|inputs| InnerInputProcessorOutput {
568                    inputs,
569                    seq_indices,
570                });
571                outputs.push(result);
572            }
573            Box::new(outputs.into_iter())
574        } else {
575            let offset = input_seqs[0].token_offset();
576            Box::new(std::iter::once(
577                make_prompt_chunk(
578                    offset,
579                    toks,
580                    &input_seqs.iter().map(|s| *s.id()).collect::<Vec<_>>(),
581                    device,
582                    last_n_context_len,
583                    return_raw_logits,
584                    paged_attn_metadata,
585                    mapper,
586                )
587                .map(|inputs| InnerInputProcessorOutput {
588                    inputs,
589                    seq_indices: (0..input_seqs.len()).collect(),
590                }),
591            ))
592        }
593    }
594
595    #[allow(clippy::too_many_arguments)]
596    pub(crate) fn get_completion_input<T: WithDType + std::fmt::Debug>(
597        toks: Vec<&[T]>,
598        input_seqs: &[&mut Sequence],
599        device: &Device,
600        no_kv_cache: bool,
601        last_n_context_len: Option<(usize, usize)>,
602        return_raw_logits: bool,
603        paged_attn_metadata: Option<&mut PagedAttentionMeta>,
604        prompt_chunksize: Option<NonZeroUsize>,
605        mapper: Option<&dyn DeviceMapper>,
606    ) -> Box<dyn Iterator<Item = Result<InnerInputProcessorOutput>>> {
607        if no_kv_cache {
608            return get_prompt_input(
609                toks,
610                input_seqs,
611                device,
612                last_n_context_len,
613                return_raw_logits,
614                paged_attn_metadata,
615                prompt_chunksize,
616                mapper,
617            );
618        }
619
620        Box::new(std::iter::once(
621            make_completion_chunk(toks, input_seqs, device, paged_attn_metadata, mapper).map(
622                |inputs| InnerInputProcessorOutput {
623                    inputs,
624                    seq_indices: (0..input_seqs.len()).collect(),
625                },
626            ),
627        ))
628    }
629
    /// Concrete downcast target for text-model inputs produced by
    /// `TextInputsProcessor`. The `*_full` fields are populated by the X-LoRA
    /// paths, which run both a full-context and a cached-context pass.
    #[derive(Clone)]
    pub struct ModelInputs {
        pub input_ids: Tensor,
        pub input_ids_full: Option<Tensor>,
        pub seqlen_offsets: Vec<usize>,
        pub seqlen_offsets_full: Option<Vec<usize>>,
        pub context_lens: Vec<(usize, usize)>,
        pub position_ids: Vec<usize>,
        pub paged_attn_meta: Option<PagedAttentionInputMetadata>,
        pub flash_meta: FlashParams,
        pub flash_meta_full: Option<FlashParams>,
    }
642
    /// [`InputsProcessor`] implementation for text-only models.
    pub struct TextInputsProcessor;
644
645    impl InputsProcessor for TextInputsProcessor {
646        fn process_inputs(
647            &self,
648            _: Option<Arc<Tokenizer>>,
649            input_seqs: &mut [&mut Sequence],
650            is_prompt: bool,
651            is_xlora: bool,
652            device: &Device,
653            no_kv_cache: bool,
654            last_n_context_len: Option<(usize, usize)>,
655            return_raw_logits: bool,
656            _: Option<Arc<dyn Any>>,
657            mut paged_attn_metadata: Option<PagedAttentionMeta>,
658            prompt_chunksize: Option<NonZeroUsize>,
659            mapper: Option<&dyn DeviceMapper>,
660        ) -> Box<dyn Iterator<Item = Result<InputProcessorOutput>>> {
661            if is_xlora && !is_prompt {
662                Box::new(
663                    get_prompt_input(
664                        input_seqs
665                            .iter()
666                            .map(|seq| seq.get_toks())
667                            .collect::<Vec<_>>(),
668                        input_seqs,
669                        device,
670                        last_n_context_len,
671                        return_raw_logits,
672                        paged_attn_metadata.as_mut(),
673                        prompt_chunksize,
674                        mapper,
675                    )
676                    .zip(get_completion_input(
677                        input_seqs
678                            .iter()
679                            .map(|seq| seq.get_toks())
680                            .collect::<Vec<_>>(),
681                        input_seqs,
682                        device,
683                        no_kv_cache,
684                        last_n_context_len,
685                        return_raw_logits,
686                        paged_attn_metadata.as_mut(),
687                        prompt_chunksize,
688                        mapper,
689                    ))
690                    .map(|(prompt, completion)| {
691                        let InnerInputProcessorOutput {
692                            inputs:
693                                InputMetadata {
694                                    input: input_ids_full,
695                                    positions: seqlen_offsets_full,
696                                    context_lens: _,
697                                    position_ids,
698                                    paged_attn_meta: _,
699                                    flash_meta: flash_meta_full,
700                                },
701                            seq_indices,
702                        } = prompt?;
703                        let InnerInputProcessorOutput {
704                            inputs:
705                                InputMetadata {
706                                    input: input_ids,
707                                    positions: seqlen_offsets,
708                                    context_lens,
709                                    position_ids: _,
710                                    paged_attn_meta,
711                                    flash_meta,
712                                },
713                            seq_indices: _,
714                        } = completion?;
715                        let inputs: Box<dyn Any> = Box::new(ModelInputs {
716                            input_ids,
717                            input_ids_full: Some(input_ids_full),
718                            seqlen_offsets,
719                            seqlen_offsets_full: Some(seqlen_offsets_full),
720                            context_lens,
721                            position_ids,
722                            paged_attn_meta,
723                            flash_meta,
724                            flash_meta_full: Some(flash_meta_full),
725                        });
726                        Ok(InputProcessorOutput {
727                            inputs,
728                            seq_indices,
729                        })
730                    }),
731                )
732            } else if is_xlora && is_prompt {
733                Box::new(
734                    get_prompt_input(
735                        input_seqs
736                            .iter()
737                            .map(|seq| seq.get_toks())
738                            .collect::<Vec<_>>(),
739                        input_seqs,
740                        device,
741                        last_n_context_len,
742                        return_raw_logits,
743                        paged_attn_metadata.as_mut(),
744                        prompt_chunksize,
745                        mapper,
746                    )
747                    .map(|metadata| {
748                        let InnerInputProcessorOutput {
749                            inputs:
750                                InputMetadata {
751                                    input: input_ids,
752                                    positions: seqlen_offsets,
753                                    context_lens,
754                                    position_ids,
755                                    paged_attn_meta,
756                                    flash_meta,
757                                },
758                            seq_indices,
759                        } = metadata?;
760                        let inputs: Box<dyn Any> = Box::new(ModelInputs {
761                            input_ids: input_ids.clone(),
762                            input_ids_full: Some(input_ids),
763                            seqlen_offsets: seqlen_offsets.clone(),
764                            seqlen_offsets_full: Some(seqlen_offsets),
765                            context_lens,
766                            position_ids,
767                            paged_attn_meta,
768                            flash_meta: flash_meta.clone(),
769                            flash_meta_full: Some(flash_meta),
770                        });
771                        Ok(InputProcessorOutput {
772                            inputs,
773                            seq_indices,
774                        })
775                    }),
776                )
777            } else if is_prompt {
778                Box::new(
779                    get_prompt_input(
780                        input_seqs
781                            .iter()
782                            .map(|seq| seq.get_toks())
783                            .collect::<Vec<_>>(),
784                        input_seqs,
785                        device,
786                        last_n_context_len,
787                        return_raw_logits,
788                        paged_attn_metadata.as_mut(),
789                        prompt_chunksize,
790                        mapper,
791                    )
792                    .map(|metadata| {
793                        let InnerInputProcessorOutput {
794                            inputs:
795                                InputMetadata {
796                                    input: input_ids,
797                                    positions: seqlen_offsets,
798                                    context_lens,
799                                    position_ids,
800                                    paged_attn_meta,
801                                    flash_meta,
802                                },
803                            seq_indices,
804                        } = metadata?;
805                        let inputs: Box<dyn Any> = Box::new(ModelInputs {
806                            input_ids,
807                            input_ids_full: None,
808                            seqlen_offsets,
809                            seqlen_offsets_full: None,
810                            context_lens,
811                            position_ids,
812                            paged_attn_meta,
813                            flash_meta,
814                            flash_meta_full: None,
815                        });
816                        Ok(InputProcessorOutput {
817                            inputs,
818                            seq_indices,
819                        })
820                    }),
821                )
822            } else {
823                Box::new(
824                    get_completion_input(
825                        input_seqs
826                            .iter()
827                            .map(|seq| seq.get_toks())
828                            .collect::<Vec<_>>(),
829                        input_seqs,
830                        device,
831                        no_kv_cache,
832                        last_n_context_len,
833                        return_raw_logits,
834                        paged_attn_metadata.as_mut(),
835                        prompt_chunksize,
836                        mapper,
837                    )
838                    .map(|metadata| {
839                        let InnerInputProcessorOutput {
840                            inputs:
841                                InputMetadata {
842                                    input: input_ids,
843                                    positions: seqlen_offsets,
844                                    context_lens,
845                                    position_ids,
846                                    paged_attn_meta,
847                                    flash_meta,
848                                },
849                            seq_indices,
850                        } = metadata?;
851                        let inputs: Box<dyn Any> = Box::new(ModelInputs {
852                            input_ids,
853                            input_ids_full: None,
854                            seqlen_offsets,
855                            seqlen_offsets_full: None,
856                            context_lens,
857                            position_ids,
858                            paged_attn_meta,
859                            flash_meta,
860                            flash_meta_full: None,
861                        });
862                        Ok(InputProcessorOutput {
863                            inputs,
864                            seq_indices,
865                        })
866                    }),
867                )
868            }
869        }
870
        /// Identifies this processor as a text-only inputs processor
        /// (returns [`InputsProcessorType::Text`], never the vision variant).
        fn get_type(&self) -> InputsProcessorType {
            InputsProcessorType::Text
        }
874    }
875}