mistralrs/
embedding_model.rs1use candle_core::Device;
2use mistralrs_core::*;
3use std::num::NonZeroUsize;
4use std::{
5    ops::{Deref, DerefMut},
6    path::PathBuf,
7};
8
9use crate::{best_device, Model};
10
11#[derive(Clone)]
12pub struct EmbeddingModelBuilder {
14    pub(crate) model_id: String,
16    pub(crate) token_source: TokenSource,
17    pub(crate) hf_revision: Option<String>,
18    pub(crate) write_uqff: Option<PathBuf>,
19    pub(crate) from_uqff: Option<Vec<PathBuf>>,
20    pub(crate) tokenizer_json: Option<String>,
21    pub(crate) device_mapping: Option<DeviceMapSetting>,
22    pub(crate) hf_cache_path: Option<PathBuf>,
23    pub(crate) device: Option<Device>,
24
25    pub(crate) topology: Option<Topology>,
27    pub(crate) loader_type: Option<EmbeddingLoaderType>,
28    pub(crate) dtype: ModelDType,
29    pub(crate) force_cpu: bool,
30    pub(crate) isq: Option<IsqType>,
31    pub(crate) throughput_logging: bool,
32
33    pub(crate) with_logging: bool,
35}
36
37impl EmbeddingModelBuilder {
38    pub fn new(model_id: impl ToString) -> Self {
42        Self {
43            model_id: model_id.to_string(),
44            topology: None,
45            write_uqff: None,
46            from_uqff: None,
47            tokenizer_json: None,
48            loader_type: None,
49            dtype: ModelDType::Auto,
50            force_cpu: false,
51            token_source: TokenSource::CacheToken,
52            hf_revision: None,
53            isq: None,
54            with_logging: false,
55            device_mapping: None,
56            throughput_logging: false,
57            hf_cache_path: None,
58            device: None,
59        }
60    }
61
62    pub fn with_throughput_logging(mut self) -> Self {
64        self.throughput_logging = true;
65        self
66    }
67
68    pub fn with_topology(mut self, topology: Topology) -> Self {
70        self.topology = Some(topology);
71        self
72    }
73
74    pub fn with_tokenizer_json(mut self, tokenizer_json: impl ToString) -> Self {
76        self.tokenizer_json = Some(tokenizer_json.to_string());
77        self
78    }
79
80    pub fn with_loader_type(mut self, loader_type: EmbeddingLoaderType) -> Self {
83        self.loader_type = Some(loader_type);
84        self
85    }
86
87    pub fn with_dtype(mut self, dtype: ModelDType) -> Self {
89        self.dtype = dtype;
90        self
91    }
92
93    pub fn with_force_cpu(mut self) -> Self {
95        self.force_cpu = true;
96        self
97    }
98
99    pub fn with_token_source(mut self, token_source: TokenSource) -> Self {
101        self.token_source = token_source;
102        self
103    }
104
105    pub fn with_hf_revision(mut self, revision: impl ToString) -> Self {
107        self.hf_revision = Some(revision.to_string());
108        self
109    }
110
111    pub fn with_isq(mut self, isq: IsqType) -> Self {
113        self.isq = Some(isq);
114        self
115    }
116
117    pub fn with_logging(mut self) -> Self {
119        self.with_logging = true;
120        self
121    }
122
123    pub fn with_device_mapping(mut self, device_mapping: DeviceMapSetting) -> Self {
125        self.device_mapping = Some(device_mapping);
126        self
127    }
128
129    #[deprecated(
130        note = "Use `UqffTextModelBuilder` to load a UQFF model instead of the generic `from_uqff`"
131    )]
132    pub fn from_uqff(mut self, path: Vec<PathBuf>) -> Self {
140        self.from_uqff = Some(path);
141        self
142    }
143
144    pub fn write_uqff(mut self, path: PathBuf) -> Self {
155        self.write_uqff = Some(path);
156        self
157    }
158
159    pub fn from_hf_cache_path(mut self, hf_cache_path: PathBuf) -> Self {
161        self.hf_cache_path = Some(hf_cache_path);
162        self
163    }
164
165    pub fn with_device(mut self, device: Device) -> Self {
167        self.device = Some(device);
168        self
169    }
170
171    pub async fn build(self) -> anyhow::Result<Model> {
172        let config = EmbeddingSpecificConfig {
173            topology: self.topology,
174            write_uqff: self.write_uqff,
175            from_uqff: self.from_uqff,
176            hf_cache_path: self.hf_cache_path,
177        };
178
179        if self.with_logging {
180            initialize_logging();
181        }
182
183        let loader = EmbeddingLoaderBuilder::new(config, self.tokenizer_json, Some(self.model_id))
184            .build(self.loader_type);
185
186        let pipeline = loader.load_model_from_hf(
188            self.hf_revision,
189            self.token_source,
190            &self.dtype,
191            &self.device.unwrap_or(best_device(self.force_cpu).unwrap()),
192            !self.with_logging,
193            self.device_mapping
194                .unwrap_or(DeviceMapSetting::Auto(AutoDeviceMapParams::default_text())),
195            self.isq,
196            None,
197        )?;
198
199        let scheduler_method = SchedulerConfig::DefaultScheduler {
200            method: DefaultSchedulerMethod::Fixed(NonZeroUsize::new(1).unwrap()), };
202
203        let runner =
204            MistralRsBuilder::new(pipeline, scheduler_method, self.throughput_logging, None);
205
206        Ok(Model::new(runner.build().await))
207    }
208}
209
210#[derive(Clone)]
211pub struct UqffEmbeddingModelBuilder(EmbeddingModelBuilder);
214
215impl UqffEmbeddingModelBuilder {
216    pub fn new(model_id: impl ToString, uqff_file: Vec<PathBuf>) -> Self {
220        let mut inner = EmbeddingModelBuilder::new(model_id);
221        inner.from_uqff = Some(uqff_file);
222        Self(inner)
223    }
224
225    pub async fn build(self) -> anyhow::Result<Model> {
226        self.0.build().await
227    }
228
229    pub fn into_inner(self) -> EmbeddingModelBuilder {
231        self.0
232    }
233}
234
235impl Deref for UqffEmbeddingModelBuilder {
236    type Target = EmbeddingModelBuilder;
237
238    fn deref(&self) -> &Self::Target {
239        &self.0
240    }
241}
242
243impl DerefMut for UqffEmbeddingModelBuilder {
244    fn deref_mut(&mut self) -> &mut Self::Target {
245        &mut self.0
246    }
247}
248
249impl From<UqffEmbeddingModelBuilder> for EmbeddingModelBuilder {
250    fn from(value: UqffEmbeddingModelBuilder) -> Self {
251        value.0
252    }
253}