1use candle_core::Device;
2use mistralrs_core::*;
3use mistralrs_core::{SearchCallback, Tool, ToolCallback};
4use std::collections::HashMap;
5use std::{
6 ops::{Deref, DerefMut},
7 path::PathBuf,
8 sync::Arc,
9};
10
11use crate::model_builder_trait::{build_model_from_pipeline, build_text_pipeline};
12use crate::Model;
13
14#[derive(Clone)]
16pub struct ToolCallbackWithTool {
17 pub callback: Arc<ToolCallback>,
18 pub tool: Tool,
19}
20
21#[derive(Clone)]
22pub struct TextModelBuilder {
24 pub(crate) model_id: String,
26 pub(crate) token_source: TokenSource,
27 pub(crate) hf_revision: Option<String>,
28 pub(crate) write_uqff: Option<PathBuf>,
29 pub(crate) from_uqff: Option<Vec<PathBuf>>,
30 pub(crate) imatrix: Option<PathBuf>,
31 pub(crate) calibration_file: Option<PathBuf>,
32 pub(crate) chat_template: Option<String>,
33 pub(crate) jinja_explicit: Option<String>,
34 pub(crate) tokenizer_json: Option<String>,
35 pub(crate) device_mapping: Option<DeviceMapSetting>,
36 pub(crate) hf_cache_path: Option<PathBuf>,
37 pub(crate) search_embedding_model: Option<SearchEmbeddingModel>,
38 pub(crate) search_callback: Option<Arc<SearchCallback>>,
39 pub(crate) tool_callbacks: HashMap<String, Arc<ToolCallback>>,
40 pub(crate) tool_callbacks_with_tools: HashMap<String, ToolCallbackWithTool>,
41 pub(crate) mcp_client_config: Option<McpClientConfig>,
42 pub(crate) device: Option<Device>,
43 pub(crate) matformer_config_path: Option<PathBuf>,
44 pub(crate) matformer_slice_name: Option<String>,
45
46 pub(crate) topology: Option<Topology>,
48 pub(crate) topology_path: Option<String>,
49 pub(crate) organization: IsqOrganization,
50 pub(crate) loader_type: Option<NormalLoaderType>,
51 pub(crate) dtype: ModelDType,
52 pub(crate) force_cpu: bool,
53 pub(crate) isq: Option<IsqType>,
54 pub(crate) throughput_logging: bool,
55
56 pub(crate) paged_attn_cfg: Option<PagedAttentionConfig>,
58 pub(crate) max_num_seqs: usize,
59 pub(crate) no_kv_cache: bool,
60 pub(crate) with_logging: bool,
61 pub(crate) prefix_cache_n: Option<usize>,
62}
63
64pub struct PagedAttentionMetaBuilder {
66 block_size: Option<usize>,
67 mem_gpu: MemoryGpuConfig,
68 cache_type: PagedCacheType,
69}
70
71impl Default for PagedAttentionMetaBuilder {
72 fn default() -> Self {
73 Self {
74 block_size: None,
75 mem_gpu: MemoryGpuConfig::ContextSize(4096),
76 cache_type: PagedCacheType::Auto,
77 }
78 }
79}
80
81impl PagedAttentionMetaBuilder {
82 pub fn with_block_size(mut self, block_size: usize) -> Self {
83 self.block_size = Some(block_size);
84 self
85 }
86
87 pub fn with_gpu_memory(mut self, mem_gpu: MemoryGpuConfig) -> Self {
88 self.mem_gpu = mem_gpu;
89 self
90 }
91
92 pub fn with_paged_cache_type(mut self, cache_type: PagedCacheType) -> Self {
93 self.cache_type = cache_type;
94 self
95 }
96
97 pub fn build(self) -> anyhow::Result<PagedAttentionConfig> {
98 PagedAttentionConfig::new(self.block_size, self.mem_gpu, self.cache_type)
99 }
100}
101
102impl TextModelBuilder {
103 pub fn new(model_id: impl ToString) -> Self {
111 Self {
112 model_id: model_id.to_string(),
113 topology: None,
114 topology_path: None,
115 organization: IsqOrganization::Default,
116 write_uqff: None,
117 from_uqff: None,
118 chat_template: None,
119 tokenizer_json: None,
120 loader_type: None,
121 dtype: ModelDType::Auto,
122 force_cpu: false,
123 token_source: TokenSource::CacheToken,
124 hf_revision: None,
125 isq: None,
126 paged_attn_cfg: None,
127 max_num_seqs: 32,
128 no_kv_cache: false,
129 prefix_cache_n: Some(16),
130 with_logging: false,
131 device_mapping: None,
132 imatrix: None,
133 calibration_file: None,
134 jinja_explicit: None,
135 throughput_logging: false,
136 hf_cache_path: None,
137 search_embedding_model: None,
138 search_callback: None,
139 tool_callbacks: HashMap::new(),
140 tool_callbacks_with_tools: HashMap::new(),
141 mcp_client_config: None,
142 device: None,
143 matformer_config_path: None,
144 matformer_slice_name: None,
145 }
146 }
147
148 pub fn with_search(mut self, search_embedding_model: SearchEmbeddingModel) -> Self {
150 self.search_embedding_model = Some(search_embedding_model);
151 self
152 }
153
154 pub fn with_search_callback(mut self, callback: Arc<SearchCallback>) -> Self {
156 self.search_callback = Some(callback);
157 self
158 }
159
160 pub fn with_tool_callback(
162 mut self,
163 name: impl Into<String>,
164 callback: Arc<ToolCallback>,
165 ) -> Self {
166 self.tool_callbacks.insert(name.into(), callback);
167 self
168 }
169
170 pub fn with_tool_callback_and_tool(
173 mut self,
174 name: impl Into<String>,
175 callback: Arc<ToolCallback>,
176 tool: Tool,
177 ) -> Self {
178 let name = name.into();
179 self.tool_callbacks_with_tools
180 .insert(name, ToolCallbackWithTool { callback, tool });
181 self
182 }
183
184 pub fn with_mcp_client(mut self, config: McpClientConfig) -> Self {
187 self.mcp_client_config = Some(config);
188 self
189 }
190
191 pub fn with_throughput_logging(mut self) -> Self {
193 self.throughput_logging = true;
194 self
195 }
196
197 pub fn with_jinja_explicit(mut self, jinja_explicit: String) -> Self {
199 self.jinja_explicit = Some(jinja_explicit);
200 self
201 }
202
203 pub fn with_topology(mut self, topology: Topology) -> Self {
205 self.topology = Some(topology);
206 self
207 }
208
209 pub fn with_topology_from_path<P: AsRef<std::path::Path>>(
212 mut self,
213 path: P,
214 ) -> anyhow::Result<Self> {
215 let path_str = path.as_ref().to_string_lossy().to_string();
216 self.topology = Some(Topology::from_path(&path)?);
217 self.topology_path = Some(path_str);
218 Ok(self)
219 }
220
221 pub fn with_mixture_qexperts_isq(mut self) -> Self {
223 self.organization = IsqOrganization::MoeExpertsOnly;
224 self
225 }
226
227 pub fn with_chat_template(mut self, chat_template: impl ToString) -> Self {
229 self.chat_template = Some(chat_template.to_string());
230 self
231 }
232
233 pub fn with_tokenizer_json(mut self, tokenizer_json: impl ToString) -> Self {
235 self.tokenizer_json = Some(tokenizer_json.to_string());
236 self
237 }
238
239 pub fn with_loader_type(mut self, loader_type: NormalLoaderType) -> Self {
242 self.loader_type = Some(loader_type);
243 self
244 }
245
246 pub fn with_dtype(mut self, dtype: ModelDType) -> Self {
248 self.dtype = dtype;
249 self
250 }
251
252 pub fn with_force_cpu(mut self) -> Self {
254 self.force_cpu = true;
255 self
256 }
257
258 pub fn with_token_source(mut self, token_source: TokenSource) -> Self {
260 self.token_source = token_source;
261 self
262 }
263
264 pub fn with_hf_revision(mut self, revision: impl ToString) -> Self {
266 self.hf_revision = Some(revision.to_string());
267 self
268 }
269
270 pub fn with_isq(mut self, isq: IsqType) -> Self {
272 self.isq = Some(isq);
273 self
274 }
275
276 pub fn with_imatrix(mut self, path: PathBuf) -> Self {
278 self.imatrix = Some(path);
279 self
280 }
281
282 pub fn with_calibration_file(mut self, path: PathBuf) -> Self {
284 self.calibration_file = Some(path);
285 self
286 }
287
288 pub fn with_paged_attn(
293 mut self,
294 paged_attn_cfg: impl FnOnce() -> anyhow::Result<PagedAttentionConfig>,
295 ) -> anyhow::Result<Self> {
296 if paged_attn_supported() {
297 self.paged_attn_cfg = Some(paged_attn_cfg()?);
298 } else {
299 self.paged_attn_cfg = None;
300 }
301 Ok(self)
302 }
303
304 pub fn with_max_num_seqs(mut self, max_num_seqs: usize) -> Self {
306 self.max_num_seqs = max_num_seqs;
307 self
308 }
309
310 pub fn with_no_kv_cache(mut self) -> Self {
312 self.no_kv_cache = true;
313 self
314 }
315
316 pub fn with_prefix_cache_n(mut self, n_seqs: Option<usize>) -> Self {
318 self.prefix_cache_n = n_seqs;
319 self
320 }
321
322 pub fn with_logging(mut self) -> Self {
324 self.with_logging = true;
325 self
326 }
327
328 pub fn with_device_mapping(mut self, device_mapping: DeviceMapSetting) -> Self {
330 self.device_mapping = Some(device_mapping);
331 self
332 }
333
334 #[deprecated(
335 note = "Use `UqffTextModelBuilder` to load a UQFF model instead of the generic `from_uqff`"
336 )]
337 pub fn from_uqff(mut self, path: Vec<PathBuf>) -> Self {
345 self.from_uqff = Some(path);
346 self
347 }
348
349 pub fn write_uqff(mut self, path: PathBuf) -> Self {
360 self.write_uqff = Some(path);
361 self
362 }
363
364 pub fn from_hf_cache_pathf(mut self, hf_cache_path: PathBuf) -> Self {
366 self.hf_cache_path = Some(hf_cache_path);
367 self
368 }
369
370 pub fn with_device(mut self, device: Device) -> Self {
372 self.device = Some(device);
373 self
374 }
375
376 pub fn with_matformer_config_path(mut self, path: PathBuf) -> Self {
378 self.matformer_config_path = Some(path);
379 self
380 }
381
382 pub fn with_matformer_slice_name(mut self, name: String) -> Self {
384 self.matformer_slice_name = Some(name);
385 self
386 }
387
388 pub async fn build(self) -> anyhow::Result<Model> {
389 let (pipeline, scheduler_config, add_model_config) = build_text_pipeline(self).await?;
390 Ok(build_model_from_pipeline(pipeline, scheduler_config, add_model_config).await)
391 }
392}
393
394#[derive(Clone)]
395pub struct UqffTextModelBuilder(TextModelBuilder);
398
399impl UqffTextModelBuilder {
400 pub fn new(model_id: impl ToString, uqff_file: Vec<PathBuf>) -> Self {
407 let mut inner = TextModelBuilder::new(model_id);
408 inner.from_uqff = Some(uqff_file);
409 Self(inner)
410 }
411
412 pub async fn build(self) -> anyhow::Result<Model> {
413 self.0.build().await
414 }
415
416 pub fn into_inner(self) -> TextModelBuilder {
418 self.0
419 }
420}
421
422impl Deref for UqffTextModelBuilder {
423 type Target = TextModelBuilder;
424
425 fn deref(&self) -> &Self::Target {
426 &self.0
427 }
428}
429
430impl DerefMut for UqffTextModelBuilder {
431 fn deref_mut(&mut self) -> &mut Self::Target {
432 &mut self.0
433 }
434}
435
436impl From<UqffTextModelBuilder> for TextModelBuilder {
437 fn from(value: UqffTextModelBuilder) -> Self {
438 value.0
439 }
440}