mistralrs/embedding_model.rs

use candle_core::Device;
use mistralrs_core::*;
use std::{
    ops::{Deref, DerefMut},
    path::PathBuf,
};

use crate::{best_device, Model};

#[derive(Clone)]
/// Configure an embedding model with the various parameters for loading, running, and other inference behaviors.
pub struct EmbeddingModelBuilder {
    // Loading model
    pub(crate) model_id: String,
    pub(crate) token_source: TokenSource,
    pub(crate) hf_revision: Option<String>,
    pub(crate) write_uqff: Option<PathBuf>,
    pub(crate) from_uqff: Option<Vec<PathBuf>>,
    pub(crate) tokenizer_json: Option<String>,
    pub(crate) device_mapping: Option<DeviceMapSetting>,
    pub(crate) hf_cache_path: Option<PathBuf>,
    pub(crate) device: Option<Device>,

    // Model running
    pub(crate) topology: Option<Topology>,
    pub(crate) loader_type: Option<EmbeddingLoaderType>,
    pub(crate) dtype: ModelDType,
    pub(crate) force_cpu: bool,
    pub(crate) isq: Option<IsqType>,
    pub(crate) throughput_logging: bool,

    // Other things
    pub(crate) max_num_seqs: usize,
    pub(crate) with_logging: bool,
}

impl EmbeddingModelBuilder {
    /// A few defaults are applied here:
    /// - Maximum number of sequences running is 32
    /// - Token source is from the cache (.cache/huggingface/token)
    /// - Automatic device mapping with model defaults according to `AutoDeviceMapParams`
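    ///
    /// A minimal usage sketch (the model ID and settings below are illustrative, not crate defaults):
    ///
    /// ```ignore
    /// let builder = EmbeddingModelBuilder::new("org/embedding-model")
    ///     .with_logging()
    ///     .with_max_num_seqs(16);
    /// ```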
    pub fn new(model_id: impl ToString) -> Self {
        Self {
            model_id: model_id.to_string(),
            topology: None,
            write_uqff: None,
            from_uqff: None,
            tokenizer_json: None,
            loader_type: None,
            dtype: ModelDType::Auto,
            force_cpu: false,
            token_source: TokenSource::CacheToken,
            hf_revision: None,
            isq: None,
            max_num_seqs: 32,
            with_logging: false,
            device_mapping: None,
            throughput_logging: false,
            hf_cache_path: None,
            device: None,
        }
    }

    /// Enable runner throughput logging.
    pub fn with_throughput_logging(mut self) -> Self {
        self.throughput_logging = true;
        self
    }

    /// Set the model topology for use during loading. If there is an overlap, the topology type is used over the ISQ type.
    pub fn with_topology(mut self, topology: Topology) -> Self {
        self.topology = Some(topology);
        self
    }

    /// Path to a discrete `tokenizer.json` file.
    pub fn with_tokenizer_json(mut self, tokenizer_json: impl ToString) -> Self {
        self.tokenizer_json = Some(tokenizer_json.to_string());
        self
    }

    /// Manually set the model loader type. Otherwise, it will attempt to automatically
    /// determine the loader type.
    pub fn with_loader_type(mut self, loader_type: EmbeddingLoaderType) -> Self {
        self.loader_type = Some(loader_type);
        self
    }

    /// Load the model in a certain dtype.
    pub fn with_dtype(mut self, dtype: ModelDType) -> Self {
        self.dtype = dtype;
        self
    }

    /// Force usage of the CPU device. Do not use PagedAttention with this.
    pub fn with_force_cpu(mut self) -> Self {
        self.force_cpu = true;
        self
    }

    /// Source of the Hugging Face token.
    pub fn with_token_source(mut self, token_source: TokenSource) -> Self {
        self.token_source = token_source;
        self
    }

    /// Set the revision to use for a Hugging Face remote model.
    pub fn with_hf_revision(mut self, revision: impl ToString) -> Self {
        self.hf_revision = Some(revision.to_string());
        self
    }

    /// Use ISQ of a certain type. If there is an overlap, the topology type is used over the ISQ type.
    pub fn with_isq(mut self, isq: IsqType) -> Self {
        self.isq = Some(isq);
        self
    }

    /// Set the maximum number of sequences which can be run at once.
    pub fn with_max_num_seqs(mut self, max_num_seqs: usize) -> Self {
        self.max_num_seqs = max_num_seqs;
        self
    }

    /// Enable logging.
    pub fn with_logging(mut self) -> Self {
        self.with_logging = true;
        self
    }

    /// Provide metadata to initialize the device mapper.
    pub fn with_device_mapping(mut self, device_mapping: DeviceMapSetting) -> Self {
        self.device_mapping = Some(device_mapping);
        self
    }

    #[deprecated(
        note = "Use `UqffEmbeddingModelBuilder` to load a UQFF model instead of the generic `from_uqff`"
    )]
    /// Path to read a `.uqff` file from. Other necessary configuration files must be present at this location.
    ///
    /// For example, these include:
    /// - `residual.safetensors`
    /// - `tokenizer.json`
    /// - `config.json`
    /// - More depending on the model
    pub fn from_uqff(mut self, path: Vec<PathBuf>) -> Self {
        self.from_uqff = Some(path);
        self
    }

    /// Path to write a `.uqff` file to and serialize the other necessary files.
    ///
    /// The parent (part of the path excluding the filename) will determine where any other files
    /// serialized are written to.
    ///
    /// For example, these include:
    /// - `residual.safetensors`
    /// - `tokenizer.json`
    /// - `config.json`
    /// - More depending on the model
    pub fn write_uqff(mut self, path: PathBuf) -> Self {
        self.write_uqff = Some(path);
        self
    }

    /// Cache path for Hugging Face models downloaded locally
    pub fn from_hf_cache_path(mut self, hf_cache_path: PathBuf) -> Self {
        self.hf_cache_path = Some(hf_cache_path);
        self
    }

    /// Set the main device to load this model onto. Automatic device mapping will be performed starting with this device.
    pub fn with_device(mut self, device: Device) -> Self {
        self.device = Some(device);
        self
    }

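    /// Build the configured model: construct the embedding loader, load the pipeline
    /// (applying the device, dtype, device mapping, and ISQ settings above), and start
    /// the runner with a fixed scheduler capped at `max_num_seqs` concurrent sequences.
    ///
    /// A minimal sketch of calling it (async runtime assumed, model ID illustrative):
    ///
    /// ```ignore
    /// let model = EmbeddingModelBuilder::new("org/embedding-model").build().await?;
    /// ```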
    pub async fn build(self) -> anyhow::Result<Model> {
        let config = EmbeddingSpecificConfig {
            topology: self.topology,
            write_uqff: self.write_uqff,
            from_uqff: self.from_uqff,
            hf_cache_path: self.hf_cache_path,
        };

        if self.with_logging {
            initialize_logging();
        }

        let loader = EmbeddingLoaderBuilder::new(config, self.tokenizer_json, Some(self.model_id))
            .build(self.loader_type);

        // Load the model into a Pipeline
        let pipeline = loader.load_model_from_hf(
            self.hf_revision,
            self.token_source,
            &self.dtype,
            &self.device.unwrap_or(best_device(self.force_cpu).unwrap()),
            !self.with_logging,
            self.device_mapping
                .unwrap_or(DeviceMapSetting::Auto(AutoDeviceMapParams::default_text())),
            self.isq,
            None,
        )?;

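        // Cap the number of sequences scheduled concurrently at `max_num_seqs`.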
        let scheduler_method = SchedulerConfig::DefaultScheduler {
            method: DefaultSchedulerMethod::Fixed(self.max_num_seqs.try_into()?),
        };

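        // Assemble the runner around the pipeline; throughput logging is forwarded from the builder.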
        let runner =
            MistralRsBuilder::new(pipeline, scheduler_method, self.throughput_logging, None);

        Ok(Model::new(runner.build().await))
    }
}

#[derive(Clone)]
/// Configure a UQFF embedding model with the various parameters for loading, running, and other inference behaviors.
/// This wraps an `EmbeddingModelBuilder` and implements `Deref`/`DerefMut` to it, so users should take care not to call UQFF-related methods through it.
pub struct UqffEmbeddingModelBuilder(EmbeddingModelBuilder);

impl UqffEmbeddingModelBuilder {
    /// A few defaults are applied here:
    /// - Token source is from the cache (.cache/huggingface/token)
    /// - Automatic device mapping with model defaults according to `AutoDeviceMapParams`
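    ///
    /// A minimal usage sketch (the model ID and UQFF filename below are illustrative):
    ///
    /// ```ignore
    /// let builder = UqffEmbeddingModelBuilder::new(
    ///     "org/embedding-model-uqff",
    ///     vec!["model-q4k-0.uqff".into()],
    /// );
    /// ```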
    pub fn new(model_id: impl ToString, uqff_file: Vec<PathBuf>) -> Self {
        let mut inner = EmbeddingModelBuilder::new(model_id);
        inner.from_uqff = Some(uqff_file);
        Self(inner)
    }

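    /// Build the model; this simply delegates to [`EmbeddingModelBuilder::build`] on the wrapped builder.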
    pub async fn build(self) -> anyhow::Result<Model> {
        self.0.build().await
    }

    /// Return the wrapped `EmbeddingModelBuilder`; take care not to call UQFF-related methods on it, as `from_uqff` is already set.
    pub fn into_inner(self) -> EmbeddingModelBuilder {
        self.0
    }
}

impl Deref for UqffEmbeddingModelBuilder {
    type Target = EmbeddingModelBuilder;

    fn deref(&self) -> &Self::Target {
        &self.0
    }
}

impl DerefMut for UqffEmbeddingModelBuilder {
    fn deref_mut(&mut self) -> &mut Self::Target {
        &mut self.0
    }
}

impl From<UqffEmbeddingModelBuilder> for EmbeddingModelBuilder {
    fn from(value: UqffEmbeddingModelBuilder) -> Self {
        value.0
    }
}