mistralrs/
embedding_model.rs1use candle_core::Device;
2use mistralrs_core::*;
3use std::{
4 ops::{Deref, DerefMut},
5 path::PathBuf,
6};
7
8use crate::{best_device, Model};
9
10#[derive(Clone)]
11pub struct EmbeddingModelBuilder {
13 pub(crate) model_id: String,
15 pub(crate) token_source: TokenSource,
16 pub(crate) hf_revision: Option<String>,
17 pub(crate) write_uqff: Option<PathBuf>,
18 pub(crate) from_uqff: Option<Vec<PathBuf>>,
19 pub(crate) tokenizer_json: Option<String>,
20 pub(crate) device_mapping: Option<DeviceMapSetting>,
21 pub(crate) hf_cache_path: Option<PathBuf>,
22 pub(crate) device: Option<Device>,
23
24 pub(crate) topology: Option<Topology>,
26 pub(crate) loader_type: Option<EmbeddingLoaderType>,
27 pub(crate) dtype: ModelDType,
28 pub(crate) force_cpu: bool,
29 pub(crate) isq: Option<IsqType>,
30 pub(crate) throughput_logging: bool,
31
32 pub(crate) max_num_seqs: usize,
34 pub(crate) with_logging: bool,
35}
36
37impl EmbeddingModelBuilder {
38 pub fn new(model_id: impl ToString) -> Self {
43 Self {
44 model_id: model_id.to_string(),
45 topology: None,
46 write_uqff: None,
47 from_uqff: None,
48 tokenizer_json: None,
49 loader_type: None,
50 dtype: ModelDType::Auto,
51 force_cpu: false,
52 token_source: TokenSource::CacheToken,
53 hf_revision: None,
54 isq: None,
55 max_num_seqs: 32,
56 with_logging: false,
57 device_mapping: None,
58 throughput_logging: false,
59 hf_cache_path: None,
60 device: None,
61 }
62 }
63
64 pub fn with_throughput_logging(mut self) -> Self {
66 self.throughput_logging = true;
67 self
68 }
69
70 pub fn with_topology(mut self, topology: Topology) -> Self {
72 self.topology = Some(topology);
73 self
74 }
75
76 pub fn with_tokenizer_json(mut self, tokenizer_json: impl ToString) -> Self {
78 self.tokenizer_json = Some(tokenizer_json.to_string());
79 self
80 }
81
82 pub fn with_loader_type(mut self, loader_type: EmbeddingLoaderType) -> Self {
85 self.loader_type = Some(loader_type);
86 self
87 }
88
89 pub fn with_dtype(mut self, dtype: ModelDType) -> Self {
91 self.dtype = dtype;
92 self
93 }
94
95 pub fn with_force_cpu(mut self) -> Self {
97 self.force_cpu = true;
98 self
99 }
100
101 pub fn with_token_source(mut self, token_source: TokenSource) -> Self {
103 self.token_source = token_source;
104 self
105 }
106
107 pub fn with_hf_revision(mut self, revision: impl ToString) -> Self {
109 self.hf_revision = Some(revision.to_string());
110 self
111 }
112
113 pub fn with_isq(mut self, isq: IsqType) -> Self {
115 self.isq = Some(isq);
116 self
117 }
118
119 pub fn with_max_num_seqs(mut self, max_num_seqs: usize) -> Self {
121 self.max_num_seqs = max_num_seqs;
122 self
123 }
124
125 pub fn with_logging(mut self) -> Self {
127 self.with_logging = true;
128 self
129 }
130
131 pub fn with_device_mapping(mut self, device_mapping: DeviceMapSetting) -> Self {
133 self.device_mapping = Some(device_mapping);
134 self
135 }
136
137 #[deprecated(
138 note = "Use `UqffTextModelBuilder` to load a UQFF model instead of the generic `from_uqff`"
139 )]
140 pub fn from_uqff(mut self, path: Vec<PathBuf>) -> Self {
148 self.from_uqff = Some(path);
149 self
150 }
151
152 pub fn write_uqff(mut self, path: PathBuf) -> Self {
163 self.write_uqff = Some(path);
164 self
165 }
166
167 pub fn from_hf_cache_path(mut self, hf_cache_path: PathBuf) -> Self {
169 self.hf_cache_path = Some(hf_cache_path);
170 self
171 }
172
173 pub fn with_device(mut self, device: Device) -> Self {
175 self.device = Some(device);
176 self
177 }
178
179 pub async fn build(self) -> anyhow::Result<Model> {
180 let config = EmbeddingSpecificConfig {
181 topology: self.topology,
182 write_uqff: self.write_uqff,
183 from_uqff: self.from_uqff,
184 hf_cache_path: self.hf_cache_path,
185 };
186
187 if self.with_logging {
188 initialize_logging();
189 }
190
191 let loader = EmbeddingLoaderBuilder::new(config, self.tokenizer_json, Some(self.model_id))
192 .build(self.loader_type);
193
194 let pipeline = loader.load_model_from_hf(
196 self.hf_revision,
197 self.token_source,
198 &self.dtype,
199 &self.device.unwrap_or(best_device(self.force_cpu).unwrap()),
200 !self.with_logging,
201 self.device_mapping
202 .unwrap_or(DeviceMapSetting::Auto(AutoDeviceMapParams::default_text())),
203 self.isq,
204 None,
205 )?;
206
207 let scheduler_method = SchedulerConfig::DefaultScheduler {
208 method: DefaultSchedulerMethod::Fixed(self.max_num_seqs.try_into()?),
209 };
210
211 let runner =
212 MistralRsBuilder::new(pipeline, scheduler_method, self.throughput_logging, None);
213
214 Ok(Model::new(runner.build().await))
215 }
216}
217
218#[derive(Clone)]
219pub struct UqffEmbeddingModelBuilder(EmbeddingModelBuilder);
222
223impl UqffEmbeddingModelBuilder {
224 pub fn new(model_id: impl ToString, uqff_file: Vec<PathBuf>) -> Self {
228 let mut inner = EmbeddingModelBuilder::new(model_id);
229 inner.from_uqff = Some(uqff_file);
230 Self(inner)
231 }
232
233 pub async fn build(self) -> anyhow::Result<Model> {
234 self.0.build().await
235 }
236
237 pub fn into_inner(self) -> EmbeddingModelBuilder {
239 self.0
240 }
241}
242
243impl Deref for UqffEmbeddingModelBuilder {
244 type Target = EmbeddingModelBuilder;
245
246 fn deref(&self) -> &Self::Target {
247 &self.0
248 }
249}
250
251impl DerefMut for UqffEmbeddingModelBuilder {
252 fn deref_mut(&mut self) -> &mut Self::Target {
253 &mut self.0
254 }
255}
256
257impl From<UqffEmbeddingModelBuilder> for EmbeddingModelBuilder {
258 fn from(value: UqffEmbeddingModelBuilder) -> Self {
259 value.0
260 }
261}