mistralrs/embedding_model.rs

use candle_core::Device;
use mistralrs_core::*;
use std::{
    ops::{Deref, DerefMut},
    path::PathBuf,
};

use crate::model_builder_trait::{build_embedding_pipeline, build_model_from_pipeline};
use crate::Model;

#[derive(Clone)]
/// Configure an embedding model with the various parameters for loading, running, and other inference behaviors.
pub struct EmbeddingModelBuilder {
    // Loading model
    pub(crate) model_id: String,
    pub(crate) token_source: TokenSource,
    pub(crate) hf_revision: Option<String>,
    pub(crate) write_uqff: Option<PathBuf>,
    pub(crate) from_uqff: Option<Vec<PathBuf>>,
    pub(crate) tokenizer_json: Option<String>,
    pub(crate) device_mapping: Option<DeviceMapSetting>,
    pub(crate) hf_cache_path: Option<PathBuf>,
    pub(crate) device: Option<Device>,

    // Model running
    pub(crate) topology: Option<Topology>,
    pub(crate) topology_path: Option<String>,
    pub(crate) loader_type: Option<EmbeddingLoaderType>,
    pub(crate) dtype: ModelDType,
    pub(crate) force_cpu: bool,
    pub(crate) isq: Option<IsqType>,
    pub(crate) throughput_logging: bool,

    // Other things
    pub(crate) max_num_seqs: usize,
    pub(crate) with_logging: bool,
}

impl EmbeddingModelBuilder {
    /// A few defaults are applied here:
    /// - Maximum number of sequences running is 32
    /// - Token source is from the cache (.cache/huggingface/token)
    /// - Automatic device mapping with model defaults according to `AutoDeviceMapParams`
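    ///
    /// # Example
    ///
    /// A minimal sketch of building an embedding model. The model id is a placeholder, the ISQ
    /// variant is only illustrative, and the example assumes this builder and `IsqType` are
    /// re-exported at the crate root and that an async runtime (e.g. Tokio) drives the future:
    ///
    /// ```no_run
    /// use mistralrs::{EmbeddingModelBuilder, IsqType};
    ///
    /// # async fn run() -> anyhow::Result<()> {
    /// let model = EmbeddingModelBuilder::new("placeholder/embedding-model-id")
    ///     .with_isq(IsqType::Q8_0)
    ///     .with_logging()
    ///     .build()
    ///     .await?;
    /// # Ok(())
    /// # }
    /// ```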
    pub fn new(model_id: impl ToString) -> Self {
        Self {
            model_id: model_id.to_string(),
            topology: None,
            topology_path: None,
            write_uqff: None,
            from_uqff: None,
            tokenizer_json: None,
            loader_type: None,
            dtype: ModelDType::Auto,
            force_cpu: false,
            token_source: TokenSource::CacheToken,
            hf_revision: None,
            isq: None,
            max_num_seqs: 32,
            with_logging: false,
            device_mapping: None,
            throughput_logging: false,
            hf_cache_path: None,
            device: None,
        }
    }

    /// Enable runner throughput logging.
    pub fn with_throughput_logging(mut self) -> Self {
        self.throughput_logging = true;
        self
    }

    /// Set the model topology for use during loading. If there is an overlap, the topology type is used over the ISQ type.
    pub fn with_topology(mut self, topology: Topology) -> Self {
        self.topology = Some(topology);
        self
    }

    /// Set the model topology from a path. This preserves the path for unload/reload support.
    /// If there is an overlap, the topology type is used over the ISQ type.
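    ///
    /// # Example
    ///
    /// A sketch using a hypothetical topology file path; it assumes this builder is re-exported
    /// at the crate root:
    ///
    /// ```no_run
    /// use mistralrs::EmbeddingModelBuilder;
    ///
    /// # fn run() -> anyhow::Result<()> {
    /// // "topology.yml" is a placeholder path to a topology definition file.
    /// let _builder = EmbeddingModelBuilder::new("placeholder/embedding-model-id")
    ///     .with_topology_from_path("topology.yml")?;
    /// # Ok(())
    /// # }
    /// ```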
    pub fn with_topology_from_path<P: AsRef<std::path::Path>>(
        mut self,
        path: P,
    ) -> anyhow::Result<Self> {
        let path_str = path.as_ref().to_string_lossy().to_string();
        self.topology = Some(Topology::from_path(&path)?);
        self.topology_path = Some(path_str);
        Ok(self)
    }

    /// Path to a discrete `tokenizer.json` file.
    pub fn with_tokenizer_json(mut self, tokenizer_json: impl ToString) -> Self {
        self.tokenizer_json = Some(tokenizer_json.to_string());
        self
    }

    /// Manually set the model loader type. Otherwise, it will attempt to automatically
    /// determine the loader type.
    pub fn with_loader_type(mut self, loader_type: EmbeddingLoaderType) -> Self {
        self.loader_type = Some(loader_type);
        self
    }

    /// Load the model in a certain dtype.
    pub fn with_dtype(mut self, dtype: ModelDType) -> Self {
        self.dtype = dtype;
        self
    }

    /// Force usage of the CPU device. Do not use PagedAttention with this.
    pub fn with_force_cpu(mut self) -> Self {
        self.force_cpu = true;
        self
    }

    /// Source of the Hugging Face token.
    pub fn with_token_source(mut self, token_source: TokenSource) -> Self {
        self.token_source = token_source;
        self
    }

    /// Set the revision to use for a Hugging Face remote model.
    pub fn with_hf_revision(mut self, revision: impl ToString) -> Self {
        self.hf_revision = Some(revision.to_string());
        self
    }

    /// Use ISQ of a certain type. If there is an overlap, the topology type is used over the ISQ type.
    pub fn with_isq(mut self, isq: IsqType) -> Self {
        self.isq = Some(isq);
        self
    }

    /// Set the maximum number of sequences which can be run at once.
    pub fn with_max_num_seqs(mut self, max_num_seqs: usize) -> Self {
        self.max_num_seqs = max_num_seqs;
        self
    }

    /// Enable logging.
    pub fn with_logging(mut self) -> Self {
        self.with_logging = true;
        self
    }

    /// Provide metadata to initialize the device mapper.
    pub fn with_device_mapping(mut self, device_mapping: DeviceMapSetting) -> Self {
        self.device_mapping = Some(device_mapping);
        self
    }

    #[deprecated(
        note = "Use `UqffEmbeddingModelBuilder` to load a UQFF model instead of the generic `from_uqff`"
    )]
    /// Path to read a `.uqff` file from. Other necessary configuration files must be present at this location.
    ///
    /// For example, these include:
    /// - `residual.safetensors`
    /// - `tokenizer.json`
    /// - `config.json`
    /// - More depending on the model
    pub fn from_uqff(mut self, path: Vec<PathBuf>) -> Self {
        self.from_uqff = Some(path);
        self
    }

    /// Path to write a `.uqff` file to and serialize the other necessary files.
    ///
    /// The parent directory (the part of the path excluding the filename) determines where any
    /// other serialized files are written.
    ///
    /// For example, these include:
    /// - `residual.safetensors`
    /// - `tokenizer.json`
    /// - `config.json`
    /// - More depending on the model
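    ///
    /// # Example
    ///
    /// A sketch of quantizing a model with ISQ and serializing it to UQFF; the output path is a
    /// placeholder and the ISQ variant is only illustrative. It assumes this builder and
    /// `IsqType` are re-exported at the crate root and that an async runtime drives the future:
    ///
    /// ```no_run
    /// use mistralrs::{EmbeddingModelBuilder, IsqType};
    /// use std::path::PathBuf;
    ///
    /// # async fn run() -> anyhow::Result<()> {
    /// let model = EmbeddingModelBuilder::new("placeholder/embedding-model-id")
    ///     .with_isq(IsqType::Q4K)
    ///     .write_uqff(PathBuf::from("out/model-q4k.uqff"))
    ///     .build()
    ///     .await?;
    /// # Ok(())
    /// # }
    /// ```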
    pub fn write_uqff(mut self, path: PathBuf) -> Self {
        self.write_uqff = Some(path);
        self
    }

    /// Cache path for Hugging Face models downloaded locally.
    pub fn from_hf_cache_path(mut self, hf_cache_path: PathBuf) -> Self {
        self.hf_cache_path = Some(hf_cache_path);
        self
    }

    /// Set the main device to load this model onto. Automatic device mapping will be performed starting with this device.
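    ///
    /// # Example
    ///
    /// A sketch pinning loading to the CPU device; the model id is a placeholder, and the example
    /// assumes `candle_core` is available as a direct dependency and that this builder is
    /// re-exported at the crate root:
    ///
    /// ```no_run
    /// use candle_core::Device;
    /// use mistralrs::EmbeddingModelBuilder;
    ///
    /// let _builder = EmbeddingModelBuilder::new("placeholder/embedding-model-id")
    ///     .with_device(Device::Cpu);
    /// ```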
    pub fn with_device(mut self, device: Device) -> Self {
        self.device = Some(device);
        self
    }

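    /// Consume the builder, load the embedding model with the configured options, and return a
    /// `Model` handle.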
    pub async fn build(self) -> anyhow::Result<Model> {
        let (pipeline, scheduler_config, add_model_config) = build_embedding_pipeline(self).await?;
        Ok(build_model_from_pipeline(pipeline, scheduler_config, add_model_config).await)
    }
}

#[derive(Clone)]
/// Configure a UQFF embedding model with the various parameters for loading, running, and other inference behaviors.
/// This wraps an `EmbeddingModelBuilder` and implements `Deref`/`DerefMut` to it, so users should take care not to call UQFF-related methods on the inner builder.
pub struct UqffEmbeddingModelBuilder(EmbeddingModelBuilder);

impl UqffEmbeddingModelBuilder {
    /// A few defaults are applied here:
    /// - Token source is from the cache (.cache/huggingface/token)
    /// - Automatic device mapping with model defaults according to `AutoDeviceMapParams`
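    ///
    /// # Example
    ///
    /// A sketch of loading a prequantized UQFF embedding model; the model id and UQFF file name
    /// are placeholders, and the example assumes this builder is re-exported at the crate root
    /// and that an async runtime drives the future:
    ///
    /// ```no_run
    /// use mistralrs::UqffEmbeddingModelBuilder;
    /// use std::path::PathBuf;
    ///
    /// # async fn run() -> anyhow::Result<()> {
    /// let model = UqffEmbeddingModelBuilder::new(
    ///     "placeholder/embedding-model-id-UQFF",
    ///     vec![PathBuf::from("model-q4k.uqff")],
    /// )
    /// .build()
    /// .await?;
    /// # Ok(())
    /// # }
    /// ```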
    pub fn new(model_id: impl ToString, uqff_file: Vec<PathBuf>) -> Self {
        let mut inner = EmbeddingModelBuilder::new(model_id);
        inner.from_uqff = Some(uqff_file);
        Self(inner)
    }

    pub async fn build(self) -> anyhow::Result<Model> {
        self.0.build().await
    }

    /// Consume this wrapper and return the inner `EmbeddingModelBuilder`. Users should take care
    /// not to call UQFF-related methods on the returned builder.
    pub fn into_inner(self) -> EmbeddingModelBuilder {
        self.0
    }
}

impl Deref for UqffEmbeddingModelBuilder {
    type Target = EmbeddingModelBuilder;

    fn deref(&self) -> &Self::Target {
        &self.0
    }
}

impl DerefMut for UqffEmbeddingModelBuilder {
    fn deref_mut(&mut self) -> &mut Self::Target {
        &mut self.0
    }
}

impl From<UqffEmbeddingModelBuilder> for EmbeddingModelBuilder {
    fn from(value: UqffEmbeddingModelBuilder) -> Self {
        value.0
    }
}