mistralrs_core/diffusion_models/processor.rs

use std::{any::Any, num::NonZeroUsize, sync::Arc};

use anyhow::{Context, Result};
use candle_core::Device;
use indexmap::IndexMap;
use tokenizers::Tokenizer;

use crate::{
    device_map::DeviceMapper,
    pipeline::{
        text_models_inputs_processor::PagedAttentionMeta, InputProcessorOutput, InputsProcessor,
        InputsProcessorType, MessagesAction, Processor,
    },
    sequence::Sequence,
    MessageContent, Pipeline,
};

use super::DiffusionGenerationParams;

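/// `Processor` implementation for diffusion models. These models consume raw
/// prompt strings rather than chat messages, so `process` always bails and
/// only `inputs_processor` is meaningful.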
pub struct DiffusionProcessor;

impl Processor for DiffusionProcessor {
    fn process(
        &self,
        _pipeline: &dyn Pipeline,
        _messages: Vec<IndexMap<String, MessageContent>>,
        _add_generation_prompt: bool,
        _add_special_tokens: bool,
        _tools: Vec<crate::Tool>,
    ) -> Result<(Vec<u32>, String)> {
        anyhow::bail!(
            "DiffusionProcessor::process should not be used. It does not expect chat messages."
        )
    }
    fn inputs_processor(&self) -> Arc<dyn InputsProcessor> {
        Arc::new(DiffusionInputsProcessor)
    }
    fn get_special_tokens(&self) -> &[&'static str] {
        &[]
    }
    fn template_action(&self) -> MessagesAction {
        // Chat templates are never applied for diffusion models (`process`
        // bails above), so this is simply a reasonable default.
        MessagesAction::FlattenOnlyText
    }
}

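/// Builds `ModelInputs` for a batch of diffusion sequences.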
pub struct DiffusionInputsProcessor;

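/// Inputs handed to a diffusion model: one prompt per sequence in the batch,
/// plus the generation parameters shared by the whole batch.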
#[derive(Clone)]
pub struct ModelInputs {
    pub(crate) prompts: Vec<String>,
    pub(crate) params: DiffusionGenerationParams,
}

impl InputsProcessor for DiffusionInputsProcessor {
    fn get_type(&self) -> InputsProcessorType {
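        // Diffusion prompts are plain text, so this processor is typed as text.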
        InputsProcessorType::Text
    }

    fn process_inputs(
        &self,
        _tokenizer: Option<Arc<Tokenizer>>,
        input_seqs: &mut [&mut Sequence],
        _is_prompt: bool,
        _is_xlora: bool,
        _device: &Device,
        _no_kv_cache: bool,
        _last_n_context_len: Option<(usize, usize)>,
        _return_raw_logits: bool,
        _other_config: Option<Arc<dyn Any>>,
        _paged_attn_metadata: Option<PagedAttentionMeta<'_>>,
        prompt_batchsize: Option<NonZeroUsize>,
        _mapper: Option<&dyn DeviceMapper>,
    ) -> Box<dyn Iterator<Item = Result<InputProcessorOutput>>> {
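        // Prompt batching means chunking a long prompt across multiple forward
        // passes; that has no analogue for diffusion models, so reject it.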
        let mut make_value = if prompt_batchsize.is_some() {
            return Box::new(std::iter::once(Err(anyhow::Error::msg(
                "Prompt batching is unsupported for diffusion models",
            ))));
        } else {
            || {
                let inputs = ModelInputs {
                    prompts: input_seqs
                        .iter_mut()
                        .map(|seq| seq.get_initial_prompt().to_string())
                        .collect::<Vec<_>>(),
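                    // Generation parameters are shared across the batch, so
                    // take them from the first sequence.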
                    params: input_seqs[0]
                        .get_diffusion_diffusion_params()
                        .context("Diffusion model params must be present")?,
                };
                Ok(InputProcessorOutput {
                    inputs: Box::new(inputs),
                    seq_indices: (0..input_seqs.len()).collect::<Vec<_>>(),
                })
            }
        };
        Box::new(std::iter::once(make_value()))
    }
}
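
// A minimal sketch of how the `Box<dyn Any>` built above can be recovered on
// the consuming side. `unpack_model_inputs` is a hypothetical helper shown for
// illustration only, not part of this crate's API; the actual pipeline is
// expected to downcast in a similar way.
#[allow(dead_code)]
fn unpack_model_inputs(inputs: Box<dyn Any>) -> Result<ModelInputs> {
    inputs
        // `downcast` succeeds only if the boxed value really is `ModelInputs`.
        .downcast::<ModelInputs>()
        .map(|boxed| *boxed)
        .map_err(|_| anyhow::anyhow!("expected diffusion ModelInputs"))
}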