// mistralrs_core/pipeline/processing.rs

1use std::sync::Arc;
2
3use anyhow::{Context, Result};
4use either::Either;
5use indexmap::IndexMap;
6
7use crate::{
8    vision_models::{preprocessor_config::PreProcessorConfig, processor_config::ProcessorConfig},
9    MessageContent, Pipeline, Tool,
10};
11
12use super::{chat_template::apply_chat_template_to, text_models_inputs_processor, InputsProcessor};
13
14/// Trait to create processors.
15pub trait ProcessorCreator {
16    fn new_processor(
17        _: Option<ProcessorConfig>,
18        _: PreProcessorConfig,
19    ) -> Arc<dyn Processor + Send + Sync>;
20}
21
/// How [`apply_chat_template`] should treat the `content` field of each
/// message before rendering the chat template.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum MessagesAction {
    /// Keep messages as-is. For idefics2 and others which use the "new"
    /// OpenAI format, where `content` may be a list of typed parts.
    Keep,
    /// Flatten a multi-part `content` down to just its first `"text"` part.
    /// For everything else (templates that only understand plain text).
    FlattenOnlyText,
}
28
29/// Processor for messages.
30/// Also includes method to retrieve the input processor for processing inputs for the
31/// model.
32pub trait Processor {
33    /// Get the tokens and the untokenized prompt. `add_special_tokens` should usually be true.
34    fn process(
35        &self,
36        pipeline: &dyn Pipeline,
37        messages: Vec<IndexMap<String, MessageContent>>,
38        add_generation_prompt: bool,
39        add_special_tokens: bool,
40        tools: Vec<Tool>,
41    ) -> Result<(Vec<u32>, String)> {
42        // for message in messages.iter_mut() {
43        //     if message["role"].as_ref().left().is_some_and(|x| x == "tool") {
44        //         message["role"] = Either::Left("ipython".to_string());
45        //         message["content"] = Either::Left(format!(
46        //             "{{\"output\": \"{}\"}}",
47        //             message["content"].as_ref().unwrap_left()
48        //         ));
49        //     }
50        // }
51
52        let prompt = apply_chat_template(
53            pipeline,
54            messages,
55            add_generation_prompt,
56            self.template_action(),
57            tools,
58        )?;
59        let encoding = pipeline
60            .tokenizer()
61            .with_context(|| {
62                "Default `Processor::process` requires the model to have a tokenizer."
63            })?
64            .encode_fast(prompt.clone(), add_special_tokens)
65            .map_err(anyhow::Error::msg)?;
66        Ok((encoding.get_ids().to_vec(), prompt))
67    }
68    fn inputs_processor(&self) -> Arc<dyn InputsProcessor>;
69    fn get_special_tokens(&self) -> &[&'static str];
70    fn template_action(&self) -> MessagesAction;
71}
72
73pub(crate) fn apply_chat_template(
74    pipeline: &dyn Pipeline,
75    messages: Vec<IndexMap<String, MessageContent>>,
76    add_generation_prompt: bool,
77    action: MessagesAction,
78    tools: Vec<Tool>,
79) -> Result<String> {
80    let messages = match action {
81        MessagesAction::Keep => messages,
82        MessagesAction::FlattenOnlyText => {
83            // This is really only for image models. If they need to flatten it s.t. they only see
84            // the text, do that.
85            let mut new_messages = Vec::new();
86            for message in messages {
87                let mut new_message = IndexMap::new();
88                for (k, v) in message {
89                    if k == "content" {
90                        match v {
91                            Either::Left(lv) => {
92                                new_message.insert(k, Either::Left(lv));
93                            }
94                            Either::Right(rv) => {
95                                'outer: for content_row in rv {
96                                    for (content_k, content_v) in content_row {
97                                        if content_k == "text" {
98                                            if let Some(content_str) = content_v.as_str() {
99                                                new_message.insert(
100                                                    k,
101                                                    Either::Left(content_str.to_string()),
102                                                );
103                                                break 'outer;
104                                            }
105                                        }
106                                    }
107                                }
108                            }
109                        }
110                    } else {
111                        new_message.insert(k, Either::Left(v.left().unwrap()));
112                    }
113                }
114                new_messages.push(new_message)
115            }
116            new_messages
117        }
118    };
119    let chat_template = pipeline
120        .get_chat_template()
121        .with_context(|| "`apply_chat_template` expects the pipeline to have a chat template.")?;
122    let template = chat_template.chat_template.as_ref().unwrap();
123    let bos_tok = if let Some(ref bos) = chat_template.bos_token {
124        match bos.0 {
125            Either::Left(ref lit) => Some(lit.to_string()),
126            Either::Right(ref added) => Some(added.content.to_string()),
127        }
128    } else {
129        None
130    };
131    let eos_tok = if let Some(ref eos) = chat_template.eos_token {
132        match eos.0 {
133            Either::Left(ref lit) => Some(lit.to_string()),
134            Either::Right(ref added) => Some(added.content.to_string()),
135        }
136    } else {
137        None
138    };
139    let unk_tok = if let Some(ref unk) = chat_template.unk_token {
140        match unk.0 {
141            Either::Left(ref lit) => Some(lit.to_string()),
142            Either::Right(ref added) => Some(added.content.to_string()),
143        }
144    } else {
145        None
146    };
147    apply_chat_template_to(
148        messages,
149        add_generation_prompt,
150        template,
151        bos_tok,
152        eos_tok,
153        unk_tok,
154        tools,
155    )
156}
157
158pub struct BasicProcessor;
159
// Relies entirely on the trait's default `process` implementation; only the
// model-specific hooks are provided here.
impl Processor for BasicProcessor {
    fn inputs_processor(&self) -> Arc<dyn InputsProcessor> {
        Arc::new(text_models_inputs_processor::TextInputsProcessor)
    }
    fn get_special_tokens(&self) -> &[&'static str] {
        // No special tokens beyond what the tokenizer already knows.
        &[]
    }
    fn template_action(&self) -> MessagesAction {
        // Text-only models can consume the "new" OpenAI message format as-is.
        MessagesAction::Keep
    }
}