// mistralrs/embedding_model.rs

1use candle_core::Device;
2use mistralrs_core::*;
3use std::num::NonZeroUsize;
4use std::{
5    ops::{Deref, DerefMut},
6    path::PathBuf,
7};
8
9use crate::{best_device, Model};
10
11#[derive(Clone)]
12/// Configure an embedding model with the various parameters for loading, running, and other inference behaviors.
13pub struct EmbeddingModelBuilder {
14    // Loading model
15    pub(crate) model_id: String,
16    pub(crate) token_source: TokenSource,
17    pub(crate) hf_revision: Option<String>,
18    pub(crate) write_uqff: Option<PathBuf>,
19    pub(crate) from_uqff: Option<Vec<PathBuf>>,
20    pub(crate) tokenizer_json: Option<String>,
21    pub(crate) device_mapping: Option<DeviceMapSetting>,
22    pub(crate) hf_cache_path: Option<PathBuf>,
23    pub(crate) device: Option<Device>,
24
25    // Model running
26    pub(crate) topology: Option<Topology>,
27    pub(crate) loader_type: Option<EmbeddingLoaderType>,
28    pub(crate) dtype: ModelDType,
29    pub(crate) force_cpu: bool,
30    pub(crate) isq: Option<IsqType>,
31    pub(crate) throughput_logging: bool,
32
33    // Other things
34    pub(crate) with_logging: bool,
35}
36
37impl EmbeddingModelBuilder {
38    /// A few defaults are applied here:
39    /// - Token source is from the cache (.cache/huggingface/token)
40    /// - Automatic device mapping with model defaults according to `AutoDeviceMapParams`
41    pub fn new(model_id: impl ToString) -> Self {
42        Self {
43            model_id: model_id.to_string(),
44            topology: None,
45            write_uqff: None,
46            from_uqff: None,
47            tokenizer_json: None,
48            loader_type: None,
49            dtype: ModelDType::Auto,
50            force_cpu: false,
51            token_source: TokenSource::CacheToken,
52            hf_revision: None,
53            isq: None,
54            with_logging: false,
55            device_mapping: None,
56            throughput_logging: false,
57            hf_cache_path: None,
58            device: None,
59        }
60    }
61
62    /// Enable runner throughput logging.
63    pub fn with_throughput_logging(mut self) -> Self {
64        self.throughput_logging = true;
65        self
66    }
67
68    /// Set the model topology for use during loading. If there is an overlap, the topology type is used over the ISQ type.
69    pub fn with_topology(mut self, topology: Topology) -> Self {
70        self.topology = Some(topology);
71        self
72    }
73
74    /// Path to a discrete `tokenizer.json` file.
75    pub fn with_tokenizer_json(mut self, tokenizer_json: impl ToString) -> Self {
76        self.tokenizer_json = Some(tokenizer_json.to_string());
77        self
78    }
79
80    /// Manually set the model loader type. Otherwise, it will attempt to automatically
81    /// determine the loader type.
82    pub fn with_loader_type(mut self, loader_type: EmbeddingLoaderType) -> Self {
83        self.loader_type = Some(loader_type);
84        self
85    }
86
87    /// Load the model in a certain dtype.
88    pub fn with_dtype(mut self, dtype: ModelDType) -> Self {
89        self.dtype = dtype;
90        self
91    }
92
93    /// Force usage of the CPU device. Do not use PagedAttention with this.
94    pub fn with_force_cpu(mut self) -> Self {
95        self.force_cpu = true;
96        self
97    }
98
99    /// Source of the Hugging Face token.
100    pub fn with_token_source(mut self, token_source: TokenSource) -> Self {
101        self.token_source = token_source;
102        self
103    }
104
105    /// Set the revision to use for a Hugging Face remote model.
106    pub fn with_hf_revision(mut self, revision: impl ToString) -> Self {
107        self.hf_revision = Some(revision.to_string());
108        self
109    }
110
111    /// Use ISQ of a certain type. If there is an overlap, the topology type is used over the ISQ type.
112    pub fn with_isq(mut self, isq: IsqType) -> Self {
113        self.isq = Some(isq);
114        self
115    }
116
117    /// Enable logging.
118    pub fn with_logging(mut self) -> Self {
119        self.with_logging = true;
120        self
121    }
122
123    /// Provide metadata to initialize the device mapper.
124    pub fn with_device_mapping(mut self, device_mapping: DeviceMapSetting) -> Self {
125        self.device_mapping = Some(device_mapping);
126        self
127    }
128
129    #[deprecated(
130        note = "Use `UqffTextModelBuilder` to load a UQFF model instead of the generic `from_uqff`"
131    )]
132    /// Path to read a `.uqff` file from. Other necessary configuration files must be present at this location.
133    ///
134    /// For example, these include:
135    /// - `residual.safetensors`
136    /// - `tokenizer.json`
137    /// - `config.json`
138    /// - More depending on the model
139    pub fn from_uqff(mut self, path: Vec<PathBuf>) -> Self {
140        self.from_uqff = Some(path);
141        self
142    }
143
144    /// Path to write a `.uqff` file to and serialize the other necessary files.
145    ///
146    /// The parent (part of the path excluding the filename) will determine where any other files
147    /// serialized are written to.
148    ///
149    /// For example, these include:
150    /// - `residual.safetensors`
151    /// - `tokenizer.json`
152    /// - `config.json`
153    /// - More depending on the model
154    pub fn write_uqff(mut self, path: PathBuf) -> Self {
155        self.write_uqff = Some(path);
156        self
157    }
158
159    /// Cache path for Hugging Face models downloaded locally
160    pub fn from_hf_cache_path(mut self, hf_cache_path: PathBuf) -> Self {
161        self.hf_cache_path = Some(hf_cache_path);
162        self
163    }
164
165    /// Set the main device to load this model onto. Automatic device mapping will be performed starting with this device.
166    pub fn with_device(mut self, device: Device) -> Self {
167        self.device = Some(device);
168        self
169    }
170
171    pub async fn build(self) -> anyhow::Result<Model> {
172        let config = EmbeddingSpecificConfig {
173            topology: self.topology,
174            write_uqff: self.write_uqff,
175            from_uqff: self.from_uqff,
176            hf_cache_path: self.hf_cache_path,
177        };
178
179        if self.with_logging {
180            initialize_logging();
181        }
182
183        let loader = EmbeddingLoaderBuilder::new(config, self.tokenizer_json, Some(self.model_id))
184            .build(self.loader_type);
185
186        // Load, into a Pipeline
187        let pipeline = loader.load_model_from_hf(
188            self.hf_revision,
189            self.token_source,
190            &self.dtype,
191            &self.device.unwrap_or(best_device(self.force_cpu).unwrap()),
192            !self.with_logging,
193            self.device_mapping
194                .unwrap_or(DeviceMapSetting::Auto(AutoDeviceMapParams::default_text())),
195            self.isq,
196            None,
197        )?;
198
199        let scheduler_method = SchedulerConfig::DefaultScheduler {
200            method: DefaultSchedulerMethod::Fixed(NonZeroUsize::new(1).unwrap()), // just a dummy
201        };
202
203        let runner =
204            MistralRsBuilder::new(pipeline, scheduler_method, self.throughput_logging, None);
205
206        Ok(Model::new(runner.build().await))
207    }
208}
209
210#[derive(Clone)]
211/// Configure a UQFF embedding model with the various parameters for loading, running, and other inference behaviors.
212/// This wraps and implements `DerefMut` for the UqffEmbeddingModelBuilder, so users should take care to not call UQFF-related methods.
213pub struct UqffEmbeddingModelBuilder(EmbeddingModelBuilder);
214
215impl UqffEmbeddingModelBuilder {
216    /// A few defaults are applied here:
217    /// - Token source is from the cache (.cache/huggingface/token)
218    /// - Automatic device mapping with model defaults according to `AutoDeviceMapParams`
219    pub fn new(model_id: impl ToString, uqff_file: Vec<PathBuf>) -> Self {
220        let mut inner = EmbeddingModelBuilder::new(model_id);
221        inner.from_uqff = Some(uqff_file);
222        Self(inner)
223    }
224
225    pub async fn build(self) -> anyhow::Result<Model> {
226        self.0.build().await
227    }
228
229    /// This wraps the EmbeddingModelBuilder, so users should take care to not call UQFF-related methods.
230    pub fn into_inner(self) -> EmbeddingModelBuilder {
231        self.0
232    }
233}
234
235impl Deref for UqffEmbeddingModelBuilder {
236    type Target = EmbeddingModelBuilder;
237
238    fn deref(&self) -> &Self::Target {
239        &self.0
240    }
241}
242
243impl DerefMut for UqffEmbeddingModelBuilder {
244    fn deref_mut(&mut self) -> &mut Self::Target {
245        &mut self.0
246    }
247}
248
249impl From<UqffEmbeddingModelBuilder> for EmbeddingModelBuilder {
250    fn from(value: UqffEmbeddingModelBuilder) -> Self {
251        value.0
252    }
253}