mistralrs_core/pipeline/
inputs_processor.rs

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
#![allow(clippy::cast_possible_truncation)]

use std::{any::Any, num::NonZeroUsize, sync::Arc};

use anyhow::Result;
use candle_core::Device;
use text_models_inputs_processor::PagedAttentionMeta;
use tokenizers::Tokenizer;

use crate::{device_map::DeviceMapper, sequence::Sequence};

/// Kind of inputs an [`InputsProcessor`] prepares.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum InputsProcessorType {
    /// Text-only inputs.
    Text,
    /// Inputs that may include images.
    Vision,
}

/// One chunk of model inputs yielded by [`InputsProcessor::process_inputs`].
pub struct InputProcessorOutput {
    /// Type-erased model inputs; downcast to the concrete type expected by `forward_inputs`.
    pub inputs: Box<dyn Any>,
    /// Indices (into the processed `input_seqs`) of the sequences covered by this chunk.
    pub seq_indices: Vec<usize>,
}

/// Processor: prepares inputs for the model (potentially preparing the images if applicable).
pub trait InputsProcessor {
    /// Build model inputs for `input_seqs`, possibly split into several chunks.
    ///
    /// Implementations are also responsible for toggling matmul via f16. It should be enabled
    /// only for prompt chunks whose sequence length exceeds an implementation-chosen threshold
    /// (the text processor uses `VIA_F16_TOK_THRESHOLD`, i.e. 512 tokens). Otherwise it should
    /// be disabled.
    ///
    /// Each yielded [`InputProcessorOutput`] carries a boxed value which can be downcast to the
    /// proper type as used in `forward_inputs`.
    #[allow(clippy::too_many_arguments)]
    fn process_inputs(
        &self,
        tokenizer: Option<Arc<Tokenizer>>,
        input_seqs: &mut [&mut Sequence],
        is_prompt: bool,
        is_xlora: bool,
        device: &Device,
        no_kv_cache: bool,
        last_n_context_len: Option<(usize, usize)>,
        return_raw_logits: bool,
        other_config: Option<Arc<dyn Any>>,
        paged_attn_metadata: Option<PagedAttentionMeta<'_>>,
        prompt_batchsize: Option<NonZeroUsize>,
        mapper: Option<&dyn DeviceMapper>,
    ) -> Box<dyn Iterator<Item = Result<InputProcessorOutput>>>;

    /// Report which kind of inputs this processor prepares.
    fn get_type(&self) -> InputsProcessorType;
}

// ========================= Text models input processor

pub mod text_models_inputs_processor {
    use std::{
        any::Any, collections::HashMap, fmt::Debug, iter::repeat, num::NonZeroUsize, sync::Arc,
    };

    use anyhow::Result;
    use candle_core::{DType, Device, DeviceLocation, Tensor, WithDType};
    use tokenizers::Tokenizer;

    use crate::{
        device_map::DeviceMapper,
        layers::set_use_matmul_via_f16,
        paged_attention::{BlockEngine, _PAD_SLOT_ID},
        sequence::Sequence,
    };

    use super::{InputProcessorOutput, InputsProcessor, InputsProcessorType};

    /// Prompt chunks whose padded length is strictly greater than this many tokens enable
    /// matmul via f16 (see `make_prompt_chunk`); completion chunks always disable it.
    const VIA_F16_TOK_THRESHOLD: usize = 512;

    fn _make_tensor_with_pad<D: WithDType>(
        x: Vec<Vec<D>>,
        max_len: usize,
        pad: D,
        device: &Device,
    ) -> Result<Tensor> {
        let mut padded_x = Vec::new();
        for mut x_i in x {
            assert!(x_i.len() <= max_len);
            x_i.extend([pad].repeat(max_len - x_i.len()));
            let shape = (x_i.len(),);
            padded_x.push(Tensor::from_vec(x_i, shape, device)?);
        }
        Tensor::cat(&padded_x[..], 0).map_err(anyhow::Error::msg)
    }

    /// Borrowed PagedAttention state needed while building a batch's inputs.
    pub struct PagedAttentionMeta<'a> {
        /// Optional sliding-window size, in tokens, limiting the attended context.
        pub sliding_window: Option<usize>,
        /// Number of token slots per KV-cache block.
        pub block_size: usize,
        /// Block engine whose `block_tables` map a sequence id to its allocated blocks.
        pub block_engine: &'a mut BlockEngine,
    }

    /// Per-batch PagedAttention tensors, with one copy per device keyed by its location.
    #[derive(Clone, Debug)]
    #[allow(dead_code)]
    pub struct PagedAttentionInputMetadata {
        /// Block ids (u32) per row, padded with 0; `None` in the `dummy` metadata.
        pub block_tables: Option<HashMap<DeviceLocation, Tensor>>,
        /// Attended context lengths (u32); `None` in the `dummy` metadata.
        pub context_lens: Option<HashMap<DeviceLocation, Tensor>>,
        /// KV-cache slot index for every input token, padded with `_PAD_SLOT_ID`.
        pub slot_mappings: HashMap<DeviceLocation, Tensor>,
        /// Largest context length in the batch; `None` in the `dummy` metadata.
        pub max_context_len: Option<usize>,
    }

    impl PagedAttentionInputMetadata {
        /// Build placeholder metadata for imatrix generation.
        ///
        /// Only `slot_mappings` is populated, with a single one-element tensor on `dev`. Block
        /// tables, context lengths and the maximum context length are left unset, so the result
        /// must NOT be used for decoding.
        // NOTE(review): the placeholder slot tensor is f32, while real slot mappings in this
        // module are built from `_PAD_SLOT_ID`; confirm the dtype does not matter on this path.
        pub fn dummy(dev: &Device) -> candle_core::Result<Self> {
            let placeholder_slots = Tensor::new(&[0f32], dev)?;
            let mut slot_mappings = HashMap::new();
            slot_mappings.insert(dev.location(), placeholder_slots);
            Ok(Self {
                slot_mappings,
                block_tables: None,
                context_lens: None,
                max_context_len: None,
            })
        }
    }

    /// Sequence-length bookkeeping for flash attention over a (padded) batch.
    #[derive(Clone, Debug)]
    pub struct FlashParams {
        /// Largest per-sequence query length in the batch.
        pub max_q: u32,
        /// Largest per-sequence key length (query length plus preceding tokens).
        pub max_k: u32,
        /// Prefix sums of query lengths starting at 0 (u32, `bs + 1` entries).
        pub cumulative_seqlens_q: Tensor,
        /// Prefix sums of key lengths starting at 0 (u32, `bs + 1` entries).
        pub cumulative_seqlens_k: Tensor,
    }

    /// Inputs for one chunk, as built by `make_prompt_chunk` / `make_completion_chunk`.
    pub struct InputMetadata {
        /// Token ids, one row per sequence (rows right-padded with zero).
        pub input: Tensor,
        /// Per-sequence position offset of the first token in `input`.
        pub positions: Vec<usize>,
        /// Per-token positions as an i64 tensor.
        pub positions_kernel: Tensor,          // [bs, seq len]
        /// Range of positions whose logits are kept.
        pub context_lens: Vec<(usize, usize)>, // (start index, len)
        /// Per-sequence length up to and including this chunk.
        pub position_ids: Vec<usize>,
        pub paged_attn_meta: Option<PagedAttentionInputMetadata>, // For paged attention
        /// Flash-attention sequence-length bookkeeping.
        pub flash_meta: FlashParams,
    }

    /// Typed counterpart of [`InputProcessorOutput`], produced before the inputs are boxed.
    pub struct InnerInputProcessorOutput {
        /// Inputs for this chunk.
        pub inputs: InputMetadata,
        /// Indices into `input_seqs` of the sequences in this chunk.
        pub seq_indices: Vec<usize>,
    }

    // chunk_offset_toks is the number of tokens by which the tokens are offset,
    // chunk_offset_toks / prompt_batchsize = number of batches
    #[allow(clippy::too_many_arguments)]
    pub fn make_prompt_chunk<T: WithDType + Debug>(
        chunk_offset_toks: usize,
        toks: Vec<Vec<T>>,
        seq_ids: &[usize],
        device: &Device,
        last_n_context_len: Option<(usize, usize)>,
        return_raw_logits: bool,
        mut paged_attn_metadata: Option<&mut PagedAttentionMeta<'_>>,
        mapper: Option<&dyn DeviceMapper>,
    ) -> Result<InputMetadata> {
        let max_len = toks
            .iter()
            .map(|seq| seq.len())
            .max()
            .expect("No sequences");
        let padding_tok = T::zero();
        // Pad each sequence by the padding token to the max len.
        let mut seqs_tensors = Vec::new();
        let mut seqlen_offsets = Vec::new();
        let mut context_lens = Vec::new();
        let mut position_ids = Vec::new();
        let mut slot_mappings = Vec::new();
        let mut block_tables = Vec::new();
        let mut paged_attn_context_lens = Vec::new();
        let mut seqlens_q = vec![0];
        let mut seqlens_k = vec![0];
        for (seq_id, mut ctxt) in seq_ids.iter().zip(toks) {
            let prompt_len = ctxt.len();
            let offset = last_n_context_len.unwrap_or_default();
            seqlen_offsets.push(offset.1 + chunk_offset_toks);

            position_ids.push(ctxt.len() + chunk_offset_toks);
            ctxt.extend(repeat(padding_tok).take(max_len.saturating_sub(ctxt.len())));
            // If we are returning raw logits, we want to not trim the logits at all.
            if return_raw_logits {
                if last_n_context_len.is_some() {
                    anyhow::bail!("`return_raw_logits` is incompatible with `last_n_context_len`");
                }

                context_lens.push((0, ctxt.len()));
            } else {
                context_lens.push((
                    ctxt.len() - last_n_context_len.map(|(a, _)| a).unwrap_or(1),
                    last_n_context_len.map(|(a, _)| a).unwrap_or(1),
                ));
            }

            seqlens_q.push(ctxt.len() as u32);
            seqlens_k.push((ctxt.len() + chunk_offset_toks) as u32);

            seqs_tensors.push(Tensor::new(ctxt, device).unwrap().unsqueeze(0).unwrap());

            if let Some(paged_attn_metadata) = &mut paged_attn_metadata {
                let table = paged_attn_metadata.block_engine.block_tables.get(seq_id);

                if table.is_none() {
                    // Will be None during profiling.
                    slot_mappings.push([_PAD_SLOT_ID].repeat(prompt_len));
                    continue;
                }
                let table = table
                    .unwrap()
                    .iter()
                    .map(|block| block.deref_mut().block_id)
                    .collect::<Vec<_>>();

                let start_idx = if let Some(sliding_window) = paged_attn_metadata.sliding_window {
                    if prompt_len > sliding_window {
                        chunk_offset_toks.min(prompt_len - sliding_window)
                    } else {
                        chunk_offset_toks
                    }
                } else {
                    chunk_offset_toks
                };

                let mut slot_mapping = Vec::new();
                let mut ctxt_len = Vec::new();
                for i in chunk_offset_toks..prompt_len + chunk_offset_toks {
                    if i < start_idx {
                        // Pad [0,start_idx) with _PAD_TOKEN_ID
                        slot_mapping.push(_PAD_SLOT_ID);
                    }
                    ctxt_len.push(i);

                    let block_number = if i / paged_attn_metadata.block_size >= table.len() {
                        panic!(
                            "Block table is too small (prompt)! i={} block_size={} table_len={}",
                            i,
                            paged_attn_metadata.block_size,
                            table.len()
                        );
                    } else {
                        table.get(i / paged_attn_metadata.block_size).unwrap()
                    };
                    let block_offset = i % paged_attn_metadata.block_size;
                    let slot = block_number * paged_attn_metadata.block_size + block_offset;
                    slot_mapping.push(slot.try_into().unwrap());
                    block_tables.push(table.clone());
                }
                slot_mappings.push(slot_mapping);
                paged_attn_context_lens.push(ctxt_len);
            }
        }

        let mut tmp = Vec::new();
        if last_n_context_len.is_some() {
            for pos in (0..seqs_tensors.len())
                .map(|i| {
                    (*seqlen_offsets.get(i).unwrap() as i64
                        ..*seqlen_offsets.get(i).unwrap() as i64 + max_len as i64)
                        .collect::<Vec<_>>()
                })
                .collect::<Vec<_>>()
            {
                tmp.push(Tensor::from_slice(&pos, pos.len(), device)?.unsqueeze(0)?);
            }
        } else {
            for pos in (0..seqs_tensors.len())
                .map(|_| (0..max_len).map(|x| x as i64).collect::<Vec<_>>())
                .collect::<Vec<_>>()
            {
                tmp.push(Tensor::from_slice(&pos, pos.len(), device)?.unsqueeze(0)?);
            }
        }
        let max_q = *seqlens_q.iter().max().unwrap();
        let max_k = *seqlens_k.iter().max().unwrap();
        let seqlens_q = Tensor::new(seqlens_q, device)?
            .to_dtype(DType::F32)?
            .cumsum(0)?
            .to_dtype(DType::U32)?;
        let seqlens_k = Tensor::new(seqlens_k, device)?
            .to_dtype(DType::F32)?
            .cumsum(0)?
            .to_dtype(DType::U32)?;
        //dbg!(&seqlens_q, &seqlens_k, &seqlen_offsets, &position_ids);
        let positions_kernel = Tensor::cat(&tmp, 0)?;
        let input = Tensor::cat(&seqs_tensors, 0).unwrap();
        // Only use matmul via f16 if prompt and seqlen > 512
        if input.dim(1)? > VIA_F16_TOK_THRESHOLD {
            set_use_matmul_via_f16(true);
        } else {
            set_use_matmul_via_f16(false);
        }

        let paged_attn_meta = if paged_attn_metadata.is_some() {
            let max_slot_mapping_len = slot_mappings.iter().map(|x| x.len()).max().unwrap();
            let slot_mappings =
                _make_tensor_with_pad(slot_mappings, max_slot_mapping_len, _PAD_SLOT_ID, device)?;

            let max_block_table_len = block_tables.iter().map(|x| x.len()).max().unwrap();
            let block_tables = _make_tensor_with_pad(
                block_tables
                    .iter()
                    .map(|x| x.iter().map(|x| *x as u32).collect::<Vec<_>>())
                    .collect::<Vec<_>>(),
                max_block_table_len,
                0,
                device,
            )?;
            let block_tables = block_tables.reshape(((), max_block_table_len))?;

            let max_context_len = paged_attn_context_lens
                .iter()
                .map(|x| x.len())
                .max()
                .unwrap();

            let context_lens = _make_tensor_with_pad(
                paged_attn_context_lens
                    .iter()
                    .map(|x| x.iter().map(|x| *x as u32).collect::<Vec<_>>())
                    .collect::<Vec<_>>(),
                max_context_len,
                0,
                device,
            )?
            .reshape(((),))?;

            // For device mapping, make a copy of each tensor for each device
            let devices = mapper.unwrap().get_unique_devices();
            let mut slot_mappings_map = HashMap::new();
            let mut block_tables_map = HashMap::new();
            let mut context_lens_map = HashMap::new();

            for device in devices {
                slot_mappings_map
                    .insert(device.location(), slot_mappings.clone().to_device(&device)?);
                block_tables_map
                    .insert(device.location(), block_tables.clone().to_device(&device)?);
                context_lens_map
                    .insert(device.location(), context_lens.clone().to_device(&device)?);
            }

            Some(PagedAttentionInputMetadata {
                slot_mappings: slot_mappings_map,
                block_tables: Some(block_tables_map),
                context_lens: Some(context_lens_map),
                max_context_len: Some(max_context_len),
            })
        } else {
            None
        };

        Ok(InputMetadata {
            input,
            positions: seqlen_offsets,
            positions_kernel,
            context_lens,
            position_ids,
            paged_attn_meta,
            flash_meta: FlashParams {
                max_k,
                max_q,
                cumulative_seqlens_k: seqlens_k,
                cumulative_seqlens_q: seqlens_q,
            },
        })
    }

    fn make_completion_chunk<T: WithDType>(
        toks: Vec<Vec<T>>,
        input_seqs: &[&mut Sequence],
        device: &Device,
        mut paged_attn_metadata: Option<&mut PagedAttentionMeta<'_>>,
        mapper: Option<&dyn DeviceMapper>,
    ) -> Result<InputMetadata> {
        // Pad each sequence by the padding token to the max len.
        let mut seqs_tensors = Vec::new();
        let mut seqlen_offsets = Vec::new();
        let mut context_lens = Vec::new();
        let mut position_ids = Vec::new();

        let mut slot_mappings = Vec::new();
        let mut block_tables = Vec::new();
        let mut paged_attn_context_lens = Vec::new();
        let mut seqlens_q = vec![0];
        let mut seqlens_k = vec![0];
        for (seq, ctxt) in input_seqs.iter().zip(toks) {
            let start_pos = ctxt.len().saturating_sub(1);
            let ctxt = ctxt[start_pos..].to_vec();
            seqlen_offsets.push(start_pos);
            context_lens.push((0, 1));
            position_ids.push(seq.len());

            seqlens_q.push(ctxt.len() as u32);
            seqlens_k.push((ctxt.len() + start_pos) as u32);

            seqs_tensors.push(Tensor::new(ctxt, device).unwrap().unsqueeze(0).unwrap());

            if let Some(paged_attn_metadata) = &mut paged_attn_metadata {
                let table = paged_attn_metadata
                    .block_engine
                    .block_tables
                    .get(seq.id())
                    .unwrap();

                let table = table
                    .iter()
                    .map(|block| block.deref_mut().block_id)
                    .collect::<Vec<_>>();

                let block_number = if start_pos / paged_attn_metadata.block_size >= table.len() {
                    panic!("Block table is too small (completion)! start_pos={} block_size={} table_len={}", start_pos, paged_attn_metadata.block_size, table.len());
                } else {
                    table
                        .get(start_pos / paged_attn_metadata.block_size)
                        .unwrap()
                };
                let block_offset = start_pos % paged_attn_metadata.block_size;
                let slot = block_number * paged_attn_metadata.block_size + block_offset;
                let slot = slot.try_into().unwrap();
                slot_mappings.push(vec![slot]);

                if let Some(sliding_window) = paged_attn_metadata.sliding_window {
                    let sliding_window_blocks = sliding_window / paged_attn_metadata.block_size;
                    let slide_idx = if table.len() > sliding_window_blocks {
                        table.len() - sliding_window_blocks
                    } else {
                        0
                    };
                    block_tables.push(table.get(slide_idx..).unwrap().to_vec());
                } else {
                    block_tables.push(table);
                }

                let paged_attn_context_len =
                    if let Some(sliding_window) = paged_attn_metadata.sliding_window {
                        seq.len().min(sliding_window)
                    } else {
                        seq.len()
                    };
                paged_attn_context_lens.push(paged_attn_context_len);
            }
        }
        let mut tmp = Vec::new();
        for pos in (0..seqs_tensors.len())
            .map(|i| vec![*seqlen_offsets.get(i).unwrap() as i64])
            .collect::<Vec<_>>()
        {
            tmp.push(Tensor::from_slice(&pos, pos.len(), device)?.unsqueeze(0)?);
        }
        let max_q = *seqlens_q.iter().max().unwrap();
        let max_k = *seqlens_k.iter().max().unwrap();
        let seqlens_q = Tensor::new(seqlens_q, device)?
            .to_dtype(DType::F32)?
            .cumsum(0)?
            .to_dtype(DType::U32)?;
        let seqlens_k = Tensor::new(seqlens_k, device)?
            .to_dtype(DType::F32)?
            .cumsum(0)?
            .to_dtype(DType::U32)?;
        //dbg!(&seqlens_q, &seqlens_k, &seqlen_offsets, &position_ids);
        let positions_kernel = Tensor::cat(&tmp, 0)?;
        set_use_matmul_via_f16(false);

        let paged_attn_meta = if paged_attn_metadata.is_some() {
            let slot_mappings = _make_tensor_with_pad(slot_mappings, 1, _PAD_SLOT_ID, device)?;

            let max_block_table_len = block_tables.iter().map(|x| x.len()).max().unwrap();

            let block_tables = _make_tensor_with_pad(
                block_tables
                    .iter()
                    .map(|x| x.iter().map(|x| *x as u32).collect::<Vec<_>>())
                    .collect::<Vec<_>>(),
                max_block_table_len,
                0,
                device,
            )?;
            let block_tables = block_tables.reshape(((), max_block_table_len))?;

            let max_context_len = paged_attn_context_lens.iter().max().unwrap();

            let context_lens = Tensor::from_vec(
                paged_attn_context_lens
                    .iter()
                    .map(|x| *x as u32)
                    .collect::<Vec<_>>(),
                (paged_attn_context_lens.len(),),
                device,
            )?;

            // For device mapping, make a copy of each tensor for each device
            let devices = mapper.unwrap().get_unique_devices();
            let mut slot_mappings_map = HashMap::new();
            let mut block_tables_map = HashMap::new();
            let mut context_lens_map = HashMap::new();

            for device in devices {
                slot_mappings_map
                    .insert(device.location(), slot_mappings.clone().to_device(&device)?);
                block_tables_map
                    .insert(device.location(), block_tables.clone().to_device(&device)?);
                context_lens_map
                    .insert(device.location(), context_lens.clone().to_device(&device)?);
            }

            Some(PagedAttentionInputMetadata {
                slot_mappings: slot_mappings_map,
                block_tables: Some(block_tables_map),
                context_lens: Some(context_lens_map),
                max_context_len: Some(*max_context_len),
            })
        } else {
            None
        };

        Ok(InputMetadata {
            input: Tensor::cat(&seqs_tensors, 0).unwrap(),
            positions: seqlen_offsets,
            positions_kernel,
            context_lens,
            position_ids,
            paged_attn_meta,
            flash_meta: FlashParams {
                max_k,
                max_q,
                cumulative_seqlens_k: seqlens_k,
                cumulative_seqlens_q: seqlens_q,
            },
        })
    }

    #[allow(clippy::too_many_arguments)]
    pub(crate) fn get_prompt_input<T: WithDType + std::fmt::Debug>(
        toks: Vec<Vec<T>>,
        input_seqs: &[&mut Sequence],
        device: &Device,
        last_n_context_len: Option<(usize, usize)>,
        return_raw_logits: bool,
        mut paged_attn_metadata: Option<&mut PagedAttentionMeta<'_>>,
        prompt_batchsize: Option<NonZeroUsize>,
        mapper: Option<&dyn DeviceMapper>,
    ) -> Box<dyn Iterator<Item = Result<InnerInputProcessorOutput>>> {
        if let (Some(prompt_batchsize), true) = (prompt_batchsize, paged_attn_metadata.is_none()) {
            let mut seq_chunks = Vec::new();
            let mut n_chunks = Vec::new();
            let prompt_batchsize: usize = prompt_batchsize.into();

            // Pad each sequence by the padding token to the max len.
            for ctxt in toks.iter() {
                let chunks = ctxt.chunks(prompt_batchsize).collect::<Vec<_>>();
                n_chunks.push(chunks.len());
                seq_chunks.push(chunks);
            }
            // Basically convert the sequences and tok chunks into chunks of seqs and the corresp toks
            let mut chunks_transposed: Vec<Vec<(Vec<T>, usize)>> = Vec::new();
            for (seq_n, seq) in seq_chunks.into_iter().enumerate() {
                for (i, chunk) in seq.into_iter().enumerate() {
                    match chunks_transposed.get_mut(i) {
                        Some(part) => part.push((chunk.to_vec(), seq_n)),
                        None => chunks_transposed.push(vec![(chunk.to_vec(), seq_n)]),
                    }
                }
            }
            let chunks = chunks_transposed
                .into_iter()
                .enumerate()
                .map(|(i, chunk)| {
                    let (toks, seq_ns): (Vec<Vec<T>>, Vec<usize>) = chunk.into_iter().unzip();
                    make_prompt_chunk(
                        i * prompt_batchsize,
                        toks,
                        &seq_ns
                            .iter()
                            .map(|i| *input_seqs[*i].id())
                            .collect::<Vec<_>>(),
                        device,
                        last_n_context_len,
                        return_raw_logits,
                        paged_attn_metadata.as_deref_mut(),
                        mapper,
                    )
                    .map(|inputs| InnerInputProcessorOutput {
                        inputs,
                        seq_indices: seq_ns,
                    })
                })
                .collect::<Vec<_>>();
            Box::new(chunks.into_iter())
        } else {
            if prompt_batchsize.is_some() {
                // TODO(EricLBuehler)
                return Box::new(std::iter::once(Err(anyhow::Error::msg(
                    "PagedAttention does not yet support prompt batching.",
                ))));
            }
            let offset = input_seqs[0].token_offset();
            if offset != 0 && paged_attn_metadata.is_some() {
                return Box::new(std::iter::once(Err(anyhow::Error::msg(
                    "PagedAttention does not yet support sequences with an offset != 0.",
                ))));
            }
            Box::new(std::iter::once(
                make_prompt_chunk(
                    offset,
                    toks,
                    &input_seqs.iter().map(|s| *s.id()).collect::<Vec<_>>(),
                    device,
                    last_n_context_len,
                    return_raw_logits,
                    paged_attn_metadata,
                    mapper,
                )
                .map(|inputs| InnerInputProcessorOutput {
                    inputs,
                    seq_indices: (0..input_seqs.len()).collect(),
                }),
            ))
        }
    }

    /// Build model inputs for a decoding step.
    ///
    /// With a KV cache this yields exactly one completion chunk covering every sequence.
    /// Without one (`no_kv_cache`), the full token lists are processed as a prompt via
    /// [`get_prompt_input`]. `last_n_context_len`, `return_raw_logits` and `prompt_batchsize`
    /// only matter in that case.
    #[allow(clippy::too_many_arguments)]
    pub(crate) fn get_completion_input<T: WithDType + std::fmt::Debug>(
        toks: Vec<Vec<T>>,
        input_seqs: &[&mut Sequence],
        device: &Device,
        no_kv_cache: bool,
        last_n_context_len: Option<(usize, usize)>,
        return_raw_logits: bool,
        paged_attn_metadata: Option<&mut PagedAttentionMeta<'_>>,
        prompt_batchsize: Option<NonZeroUsize>,
        mapper: Option<&dyn DeviceMapper>,
    ) -> Box<dyn Iterator<Item = Result<InnerInputProcessorOutput>>> {
        if !no_kv_cache {
            let chunk = make_completion_chunk(toks, input_seqs, device, paged_attn_metadata, mapper)
                .map(|inputs| InnerInputProcessorOutput {
                    inputs,
                    seq_indices: (0..input_seqs.len()).collect(),
                });
            return Box::new(std::iter::once(chunk));
        }

        get_prompt_input(
            toks,
            input_seqs,
            device,
            last_n_context_len,
            return_raw_logits,
            paged_attn_metadata,
            prompt_batchsize,
            mapper,
        )
    }

    /// Inputs for text models, boxed into `InputProcessorOutput::inputs` by
    /// `TextInputsProcessor`.
    ///
    /// The `*_full` fields carry a full-context prompt pass. They are filled by the X-LoRA
    /// non-prompt path in `process_inputs`; presumably they are `None` elsewhere, so confirm
    /// against the remaining branches.
    #[derive(Clone)]
    pub struct ModelInputs {
        /// Token ids for this step.
        pub input_ids: Tensor,
        /// Token ids of the full-context prompt pass, if any.
        pub input_ids_full: Option<Tensor>,
        /// Per-sequence position offsets for `input_ids`.
        pub seqlen_offsets: Vec<usize>,
        /// Per-sequence position offsets for `input_ids_full`, if any.
        pub seqlen_offsets_full: Option<Vec<usize>>,
        /// Per-token positions tensor for `input_ids`.
        pub seqlen_offsets_kernel: Tensor,
        /// Per-token positions tensor for `input_ids_full`, if any.
        pub seqlen_offsets_kernel_full: Option<Tensor>,
        /// `(start index, len)` of the positions whose logits are kept.
        pub context_lens: Vec<(usize, usize)>,
        /// Per-sequence length up to and including this chunk.
        pub position_ids: Vec<usize>,
        /// PagedAttention tensors, when PagedAttention is enabled.
        pub paged_attn_meta: Option<PagedAttentionInputMetadata>,
        /// Flash-attention bookkeeping for `input_ids`.
        pub flash_meta: FlashParams,
        /// Flash-attention bookkeeping for `input_ids_full`, if any.
        pub flash_meta_full: Option<FlashParams>,
    }

    /// Stateless [`InputsProcessor`] for text models.
    pub struct TextInputsProcessor;

    impl InputsProcessor for TextInputsProcessor {
        fn process_inputs(
            &self,
            _: Option<Arc<Tokenizer>>,
            input_seqs: &mut [&mut Sequence],
            is_prompt: bool,
            is_xlora: bool,
            device: &Device,
            no_kv_cache: bool,
            last_n_context_len: Option<(usize, usize)>,
            return_raw_logits: bool,
            _: Option<Arc<dyn Any>>,
            mut paged_attn_metadata: Option<PagedAttentionMeta<'_>>,
            prompt_batchsize: Option<NonZeroUsize>,
            mapper: Option<&dyn DeviceMapper>,
        ) -> Box<dyn Iterator<Item = Result<InputProcessorOutput>>> {
            if is_xlora && !is_prompt {
                Box::new(
                    get_prompt_input(
                        input_seqs
                            .iter()
                            .map(|seq| seq.get_toks().to_vec())
                            .collect::<Vec<_>>(),
                        input_seqs,
                        device,
                        last_n_context_len,
                        return_raw_logits,
                        paged_attn_metadata.as_mut(),
                        prompt_batchsize,
                        mapper,
                    )
                    .zip(get_completion_input(
                        input_seqs
                            .iter()
                            .map(|seq| seq.get_toks().to_vec())
                            .collect::<Vec<_>>(),
                        input_seqs,
                        device,
                        no_kv_cache,
                        last_n_context_len,
                        return_raw_logits,
                        paged_attn_metadata.as_mut(),
                        prompt_batchsize,
                        mapper,
                    ))
                    .map(|(prompt, completion)| {
                        let InnerInputProcessorOutput {
                            inputs:
                                InputMetadata {
                                    input: input_ids_full,
                                    positions: seqlen_offsets_full,
                                    positions_kernel: seqlen_offsets_kernel_full,
                                    context_lens: _,
                                    position_ids,
                                    paged_attn_meta: _,
                                    flash_meta: flash_meta_full,
                                },
                            seq_indices,
                        } = prompt?;
                        let InnerInputProcessorOutput {
                            inputs:
                                InputMetadata {
                                    input: input_ids,
                                    positions: seqlen_offsets,
                                    positions_kernel: seqlen_offsets_kernel,
                                    context_lens,
                                    position_ids: _,
                                    paged_attn_meta,
                                    flash_meta,
                                },
                            seq_indices: _,
                        } = completion?;
                        let inputs: Box<dyn Any> = Box::new(ModelInputs {
                            input_ids,
                            input_ids_full: Some(input_ids_full),
                            seqlen_offsets,
                            seqlen_offsets_full: Some(seqlen_offsets_full),
                            seqlen_offsets_kernel,
                            seqlen_offsets_kernel_full: Some(seqlen_offsets_kernel_full),
                            context_lens,
                            position_ids,
                            paged_attn_meta,
                            flash_meta,
                            flash_meta_full: Some(flash_meta_full),
                        });
                        Ok(InputProcessorOutput {
                            inputs,
                            seq_indices,
                        })
                    }),
                )
            } else if is_xlora && is_prompt {
                Box::new(
                    get_prompt_input(
                        input_seqs
                            .iter()
                            .map(|seq| seq.get_toks().to_vec())
                            .collect::<Vec<_>>(),
                        input_seqs,
                        device,
                        last_n_context_len,
                        return_raw_logits,
                        paged_attn_metadata.as_mut(),
                        prompt_batchsize,
                        mapper,
                    )
                    .map(|metadata| {
                        let InnerInputProcessorOutput {
                            inputs:
                                InputMetadata {
                                    input: input_ids,
                                    positions: seqlen_offsets,
                                    positions_kernel: seqlen_offsets_kernel,
                                    context_lens,
                                    position_ids,
                                    paged_attn_meta,
                                    flash_meta,
                                },
                            seq_indices,
                        } = metadata?;
                        let inputs: Box<dyn Any> = Box::new(ModelInputs {
                            input_ids: input_ids.clone(),
                            input_ids_full: Some(input_ids),
                            seqlen_offsets: seqlen_offsets.clone(),
                            seqlen_offsets_full: Some(seqlen_offsets),
                            seqlen_offsets_kernel: seqlen_offsets_kernel.clone(),
                            seqlen_offsets_kernel_full: Some(seqlen_offsets_kernel),
                            context_lens,
                            position_ids,
                            paged_attn_meta,
                            flash_meta: flash_meta.clone(),
                            flash_meta_full: Some(flash_meta),
                        });
                        Ok(InputProcessorOutput {
                            inputs,
                            seq_indices,
                        })
                    }),
                )
            } else if is_prompt {
                Box::new(
                    get_prompt_input(
                        input_seqs
                            .iter()
                            .map(|seq| seq.get_toks().to_vec())
                            .collect::<Vec<_>>(),
                        input_seqs,
                        device,
                        last_n_context_len,
                        return_raw_logits,
                        paged_attn_metadata.as_mut(),
                        prompt_batchsize,
                        mapper,
                    )
                    .map(|metadata| {
                        let InnerInputProcessorOutput {
                            inputs:
                                InputMetadata {
                                    input: input_ids,
                                    positions: seqlen_offsets,
                                    positions_kernel: seqlen_offsets_kernel,
                                    context_lens,
                                    position_ids,
                                    paged_attn_meta,
                                    flash_meta,
                                },
                            seq_indices,
                        } = metadata?;
                        let inputs: Box<dyn Any> = Box::new(ModelInputs {
                            input_ids,
                            input_ids_full: None,
                            seqlen_offsets,
                            seqlen_offsets_full: None,
                            seqlen_offsets_kernel,
                            seqlen_offsets_kernel_full: None,
                            context_lens,
                            position_ids,
                            paged_attn_meta,
                            flash_meta,
                            flash_meta_full: None,
                        });
                        Ok(InputProcessorOutput {
                            inputs,
                            seq_indices,
                        })
                    }),
                )
            } else {
                Box::new(
                    get_completion_input(
                        input_seqs
                            .iter()
                            .map(|seq| seq.get_toks().to_vec())
                            .collect::<Vec<_>>(),
                        input_seqs,
                        device,
                        no_kv_cache,
                        last_n_context_len,
                        return_raw_logits,
                        paged_attn_metadata.as_mut(),
                        prompt_batchsize,
                        mapper,
                    )
                    .map(|metadata| {
                        let InnerInputProcessorOutput {
                            inputs:
                                InputMetadata {
                                    input: input_ids,
                                    positions: seqlen_offsets,
                                    positions_kernel: seqlen_offsets_kernel,
                                    context_lens,
                                    position_ids,
                                    paged_attn_meta,
                                    flash_meta,
                                },
                            seq_indices,
                        } = metadata?;
                        let inputs: Box<dyn Any> = Box::new(ModelInputs {
                            input_ids,
                            input_ids_full: None,
                            seqlen_offsets,
                            seqlen_offsets_full: None,
                            seqlen_offsets_kernel,
                            seqlen_offsets_kernel_full: None,
                            context_lens,
                            position_ids,
                            paged_attn_meta,
                            flash_meta,
                            flash_meta_full: None,
                        });
                        Ok(InputProcessorOutput {
                            inputs,
                            seq_indices,
                        })
                    }),
                )
            }
        }

        /// Reports which category of inputs processor this is.
        ///
        /// This processor always identifies itself as the plain-text variant,
        /// regardless of any state on `self`.
        fn get_type(&self) -> InputsProcessorType {
            // The kind is fixed for this processor and does not depend on `self`.
            let processor_kind: InputsProcessorType = InputsProcessorType::Text;
            processor_kind
        }
    }
}