mistralrs/
embedding_model.rs1use candle_core::Device;
2use mistralrs_core::*;
3use std::{
4 ops::{Deref, DerefMut},
5 path::PathBuf,
6};
7
8use crate::model_builder_trait::{build_embedding_pipeline, build_model_from_pipeline};
9use crate::Model;
10
11#[derive(Clone)]
12pub struct EmbeddingModelBuilder {
14 pub(crate) model_id: String,
16 pub(crate) token_source: TokenSource,
17 pub(crate) hf_revision: Option<String>,
18 pub(crate) write_uqff: Option<PathBuf>,
19 pub(crate) from_uqff: Option<Vec<PathBuf>>,
20 pub(crate) tokenizer_json: Option<String>,
21 pub(crate) device_mapping: Option<DeviceMapSetting>,
22 pub(crate) hf_cache_path: Option<PathBuf>,
23 pub(crate) device: Option<Device>,
24
25 pub(crate) topology: Option<Topology>,
27 pub(crate) topology_path: Option<String>,
28 pub(crate) loader_type: Option<EmbeddingLoaderType>,
29 pub(crate) dtype: ModelDType,
30 pub(crate) force_cpu: bool,
31 pub(crate) isq: Option<IsqType>,
32 pub(crate) throughput_logging: bool,
33
34 pub(crate) max_num_seqs: usize,
36 pub(crate) with_logging: bool,
37}
38
39impl EmbeddingModelBuilder {
40 pub fn new(model_id: impl ToString) -> Self {
45 Self {
46 model_id: model_id.to_string(),
47 topology: None,
48 topology_path: None,
49 write_uqff: None,
50 from_uqff: None,
51 tokenizer_json: None,
52 loader_type: None,
53 dtype: ModelDType::Auto,
54 force_cpu: false,
55 token_source: TokenSource::CacheToken,
56 hf_revision: None,
57 isq: None,
58 max_num_seqs: 32,
59 with_logging: false,
60 device_mapping: None,
61 throughput_logging: false,
62 hf_cache_path: None,
63 device: None,
64 }
65 }
66
67 pub fn with_throughput_logging(mut self) -> Self {
69 self.throughput_logging = true;
70 self
71 }
72
73 pub fn with_topology(mut self, topology: Topology) -> Self {
75 self.topology = Some(topology);
76 self
77 }
78
79 pub fn with_topology_from_path<P: AsRef<std::path::Path>>(
82 mut self,
83 path: P,
84 ) -> anyhow::Result<Self> {
85 let path_str = path.as_ref().to_string_lossy().to_string();
86 self.topology = Some(Topology::from_path(&path)?);
87 self.topology_path = Some(path_str);
88 Ok(self)
89 }
90
91 pub fn with_tokenizer_json(mut self, tokenizer_json: impl ToString) -> Self {
93 self.tokenizer_json = Some(tokenizer_json.to_string());
94 self
95 }
96
97 pub fn with_loader_type(mut self, loader_type: EmbeddingLoaderType) -> Self {
100 self.loader_type = Some(loader_type);
101 self
102 }
103
104 pub fn with_dtype(mut self, dtype: ModelDType) -> Self {
106 self.dtype = dtype;
107 self
108 }
109
110 pub fn with_force_cpu(mut self) -> Self {
112 self.force_cpu = true;
113 self
114 }
115
116 pub fn with_token_source(mut self, token_source: TokenSource) -> Self {
118 self.token_source = token_source;
119 self
120 }
121
122 pub fn with_hf_revision(mut self, revision: impl ToString) -> Self {
124 self.hf_revision = Some(revision.to_string());
125 self
126 }
127
128 pub fn with_isq(mut self, isq: IsqType) -> Self {
130 self.isq = Some(isq);
131 self
132 }
133
134 pub fn with_max_num_seqs(mut self, max_num_seqs: usize) -> Self {
136 self.max_num_seqs = max_num_seqs;
137 self
138 }
139
140 pub fn with_logging(mut self) -> Self {
142 self.with_logging = true;
143 self
144 }
145
146 pub fn with_device_mapping(mut self, device_mapping: DeviceMapSetting) -> Self {
148 self.device_mapping = Some(device_mapping);
149 self
150 }
151
152 #[deprecated(
153 note = "Use `UqffTextModelBuilder` to load a UQFF model instead of the generic `from_uqff`"
154 )]
155 pub fn from_uqff(mut self, path: Vec<PathBuf>) -> Self {
163 self.from_uqff = Some(path);
164 self
165 }
166
167 pub fn write_uqff(mut self, path: PathBuf) -> Self {
178 self.write_uqff = Some(path);
179 self
180 }
181
182 pub fn from_hf_cache_path(mut self, hf_cache_path: PathBuf) -> Self {
184 self.hf_cache_path = Some(hf_cache_path);
185 self
186 }
187
188 pub fn with_device(mut self, device: Device) -> Self {
190 self.device = Some(device);
191 self
192 }
193
194 pub async fn build(self) -> anyhow::Result<Model> {
195 let (pipeline, scheduler_config, add_model_config) = build_embedding_pipeline(self).await?;
196 Ok(build_model_from_pipeline(pipeline, scheduler_config, add_model_config).await)
197 }
198}
199
200#[derive(Clone)]
201pub struct UqffEmbeddingModelBuilder(EmbeddingModelBuilder);
204
205impl UqffEmbeddingModelBuilder {
206 pub fn new(model_id: impl ToString, uqff_file: Vec<PathBuf>) -> Self {
210 let mut inner = EmbeddingModelBuilder::new(model_id);
211 inner.from_uqff = Some(uqff_file);
212 Self(inner)
213 }
214
215 pub async fn build(self) -> anyhow::Result<Model> {
216 self.0.build().await
217 }
218
219 pub fn into_inner(self) -> EmbeddingModelBuilder {
221 self.0
222 }
223}
224
225impl Deref for UqffEmbeddingModelBuilder {
226 type Target = EmbeddingModelBuilder;
227
228 fn deref(&self) -> &Self::Target {
229 &self.0
230 }
231}
232
233impl DerefMut for UqffEmbeddingModelBuilder {
234 fn deref_mut(&mut self) -> &mut Self::Target {
235 &mut self.0
236 }
237}
238
239impl From<UqffEmbeddingModelBuilder> for EmbeddingModelBuilder {
240 fn from(value: UqffEmbeddingModelBuilder) -> Self {
241 value.0
242 }
243}