1use super::varbuilder_utils::{
2 from_mmaped_safetensors, load_preload_adapters, DeviceForLoadTensor,
3};
4use anyhow::Result;
5use candle_core::{quantized::ggml_file, DType};
6use mistralrs_quant::ShardedVarBuilder;
7use std::{collections::HashMap, path::PathBuf, sync::Arc};
8
9use crate::{
10 device_map::DeviceMapper,
11 gguf::Content,
12 lora::{LoraConfig, Ordering},
13 paged_attention::AttentionImplementation,
14 pipeline::{AdapterPaths, ModelPaths},
15 xlora_models::XLoraConfig,
16};
17
/// A fully-parsed GGML file together with the parameters needed to build a
/// model from it. `derive_more::From` lets this be built from a
/// `(Content, usize, DType)` tuple, so field order is significant.
#[derive(derive_more::From)]
pub struct FileGGML {
    /// Parsed GGML file contents (tensors and metadata).
    pub ct: ggml_file::Content,
    /// Grouped-query-attention parameter, forwarded verbatim to `from_ggml`.
    pub gqa: usize,
    /// Activation dtype, forwarded verbatim to `from_ggml`.
    pub dtype: DType,
}
24
/// A target device paired with the layer-to-device mapper; both are passed
/// through to the GGUF model constructors. `derive_more::From` lets this be
/// built from a `(&Device, Box<dyn DeviceMapper ...>)` tuple.
#[derive(derive_more::From)]
pub struct Device<'a> {
    // Base device tensors are loaded onto (private: consumed by try_into_model).
    device: &'a candle_core::Device,
    /// Device mapper handed to `from_gguf` to place model layers.
    pub mapper: Box<dyn DeviceMapper + Send + Sync>,
}
30
/// LoRA/X-LoRA adapter state: borrowed configuration plus the variable
/// builder holding the adapter tensors (see [`Adapter::try_new`]).
pub struct Adapter<'a> {
    /// X-LoRA classifier configuration; `None` for plain LoRA.
    pub xlora_config: Option<XLoraConfig>,
    /// Per-adapter LoRA configs keyed by a `(String, String)` pair
    /// (presumably adapter name/identifier — confirm with callers).
    pub lora_config: &'a [((String, String), LoraConfig)],
    /// VarBuilder over the memory-mapped adapter safetensors (F32).
    pub vb: ShardedVarBuilder,
    /// Adapter ordering information, forwarded to the model constructors.
    pub ordering: &'a Ordering,
    /// Adapters preloaded by name, if any were requested.
    pub preload_adapters: Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
}
38
39impl<'a> Adapter<'a> {
40 #[allow(clippy::borrowed_box)]
45 pub fn try_new<'b: 'a>(
46 paths: &'b Box<dyn ModelPaths>,
47 device: &'b candle_core::Device,
48 silent: bool,
49 is_xlora: bool,
50 ) -> Result<Self> {
51 let AdapterPaths::XLora {
52 adapter_configs,
53 adapter_safetensors,
54 classifier_path,
55 xlora_order,
56 xlora_config,
57 lora_preload_adapter_info,
58 } = paths.get_adapter_paths()
59 else {
60 todo!()
61 };
62
63 let lora_config = adapter_configs.as_ref().unwrap();
64 let ordering = xlora_order.as_ref().unwrap();
65 let preload_adapters = load_preload_adapters(
66 lora_preload_adapter_info,
67 candle_core::DType::F32,
68 device,
69 silent,
70 )?;
71
72 let mut xlora_paths: Vec<PathBuf> = vec![];
74 if is_xlora {
75 xlora_paths = vec![classifier_path.as_ref().unwrap().to_path_buf()];
76 }
77
78 let vb = from_mmaped_safetensors(
81 xlora_paths,
82 adapter_safetensors
83 .as_ref()
84 .unwrap()
85 .iter()
86 .map(|(_, x)| (*x).to_owned())
87 .collect::<Vec<_>>(),
88 Some(candle_core::DType::F32),
89 device,
90 vec![None],
91 silent,
92 None,
93 |_| true,
94 Arc::new(|_| DeviceForLoadTensor::Base),
95 )?;
96
97 Ok(Self {
98 lora_config,
99 xlora_config: xlora_config.clone(),
100 vb,
101 ordering,
102 preload_adapters,
103 })
104 }
105}
106
/// Quantized-model parameters sourced from a GGML file.
pub struct ParamsGGML(pub FileGGML);
/// Quantized-model parameters sourced from a GGUF file: parsed content,
/// target device + mapper, attention implementation, and activation dtype.
pub struct ParamsGGUF<'a, R: std::io::Seek + std::io::Read>(
    pub Content<'a, R>,
    pub Device<'a>,
    pub AttentionImplementation,
    pub DType,
);
115
/// Type-level marker: the model is loaded without any LoRA/X-LoRA adapter.
pub struct NoAdapter {}
118
/// Marker trait over the supported quantized-parameter bundles (GGML/GGUF).
pub trait QuantParams {}
impl QuantParams for ParamsGGML {}
impl<R: std::io::Seek + std::io::Read> QuantParams for ParamsGGUF<'_, R> {}
124
/// Marker trait selecting, at the type level, whether adapters are present.
pub trait MaybeAdapter {}
impl MaybeAdapter for Adapter<'_> {}
impl MaybeAdapter for NoAdapter {}
129
/// Pairing of quantized parameters with an adapter (or the `NoAdapter`
/// marker). `derive_more::From` lets this be built from a `(Q, A)` tuple,
/// as done in `ModelParams::new`; field order is therefore significant.
#[derive(derive_more::From)]
pub struct Config<Q: QuantParams, A: MaybeAdapter> {
    pub quant: Q,
    pub adapter: A,
}
136
/// Everything needed to construct a quantized model: either plain
/// (`Quantized`) or with LoRA/X-LoRA adapters attached (`Adapted`).
///
/// `variantly::Variantly` derives accessors such as `expect_quantized` and
/// `expect_adapted`, which the `TryFrom` implementations in this module use.
#[allow(clippy::large_enum_variant)]
#[derive(variantly::Variantly)]
pub enum ModelParams<'a, Q>
where
    Q: QuantParams,
{
    /// Quantized model without adapters.
    Quantized(Config<Q, NoAdapter>),
    /// Quantized model with adapter state attached.
    Adapted(Config<Q, Adapter<'a>>),
}
149
150impl<'a, Q: QuantParams> ModelParams<'a, Q> {
155 pub fn new<'b: 'a>(quant: Q, adapter: Option<Adapter<'b>>) -> Self {
156 match adapter {
157 None => Self::Quantized((quant, NoAdapter {}).into()),
158 Some(a) => Self::Adapted((quant, a).into()),
159 }
160 }
161}
162
/// Constructor for plain (non-adapter) models from a parsed GGML file.
pub trait FromGGML {
    /// Build the model weights from GGML contents `ct`, with GQA parameter
    /// `gqa`, producing activations in `dtype`.
    fn from_ggml(
        ct: ggml_file::Content,
        gqa: usize,
        dtype: DType,
    ) -> Result<Self, candle_core::Error>
    where
        Self: Sized;
}
174
/// Constructor for plain (non-adapter) models from parsed GGUF content.
pub trait FromGGUF {
    /// Build the model weights from GGUF content `ct` on `device`, placing
    /// layers via `mapper`, using the given attention implementation and
    /// activation `dtype`.
    fn from_gguf<R: std::io::Seek + std::io::Read>(
        ct: Content<'_, R>,
        device: &candle_core::Device,
        mapper: Box<dyn DeviceMapper + Send + Sync>,
        attention_mechanism: AttentionImplementation,
        dtype: DType,
    ) -> Result<Self, candle_core::Error>
    where
        Self: Sized;
}
186
/// Constructor for adapter-enabled (LoRA/X-LoRA) models from a GGML file.
pub trait FromAdapterGGML {
    /// Build the model weights from GGML contents `ct` together with the
    /// adapter configuration, tensors (`vb`), ordering, optional X-LoRA
    /// classifier config, and any preloaded adapters.
    #[allow(clippy::too_many_arguments)]
    fn from_ggml(
        ct: ggml_file::Content,
        gqa: usize,
        lora_config: &[((String, String), LoraConfig)],
        vb: &ShardedVarBuilder,
        ordering: &Ordering,
        xlora_config: Option<XLoraConfig>,
        preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
        dtype: DType,
    ) -> Result<Self, candle_core::Error>
    where
        Self: Sized;
}
/// Constructor for adapter-enabled (LoRA/X-LoRA) models from GGUF content.
/// Note: unlike `FromGGUF`, no attention implementation is taken.
pub trait FromAdapterGGUF {
    /// Build the model weights from GGUF content `ct` on `device`, with the
    /// adapter configuration, tensors (`vb`), ordering, optional X-LoRA
    /// classifier config, device mapper, and any preloaded adapters.
    #[allow(clippy::too_many_arguments)]
    fn from_gguf<R: std::io::Seek + std::io::Read>(
        ct: Content<'_, R>,
        device: &candle_core::Device,
        lora_config: &[((String, String), LoraConfig)],
        vb: &ShardedVarBuilder,
        ordering: &Ordering,
        xlora_config: Option<XLoraConfig>,
        mapper: Box<dyn DeviceMapper + Send + Sync>,
        preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
        dtype: DType,
    ) -> Result<Self, candle_core::Error>
    where
        Self: Sized;
}
219
220impl Config<ParamsGGML, NoAdapter> {
222 pub fn try_into_model<T: FromGGML>(self) -> Result<T, candle_core::Error> {
223 let ParamsGGML(FileGGML { ct, gqa, dtype }) = self.quant;
225
226 T::from_ggml(ct, gqa, dtype)
228 }
229}
230
231impl Config<ParamsGGML, Adapter<'_>> {
232 pub fn try_into_model<T: FromAdapterGGML>(self) -> Result<T, candle_core::Error> {
233 let ParamsGGML(FileGGML { ct, gqa, dtype }) = self.quant;
235
236 let Adapter {
237 xlora_config,
238 lora_config,
239 vb,
240 ordering,
241 preload_adapters,
242 } = self.adapter;
243
244 T::from_ggml(
246 ct,
247 gqa,
248 lora_config,
249 &vb,
250 ordering,
251 xlora_config,
252 &preload_adapters,
253 dtype,
254 )
255 }
256}
257
258impl<R: std::io::Seek + std::io::Read> Config<ParamsGGUF<'_, R>, NoAdapter> {
259 pub fn try_into_model<T: FromGGUF>(self) -> Result<T, candle_core::Error> {
260 let ParamsGGUF(ct, Device { device, mapper }, attention_implementation, dtype) = self.quant;
262
263 T::from_gguf(ct, device, mapper, attention_implementation, dtype)
265 }
266}
267
268impl<R: std::io::Seek + std::io::Read> Config<ParamsGGUF<'_, R>, Adapter<'_>> {
269 pub fn try_into_model<T: FromAdapterGGUF>(self) -> Result<T, candle_core::Error> {
270 let ParamsGGUF(ct, Device { device, mapper }, _attention_implementation, dtype) =
272 self.quant;
273
274 let Adapter {
275 xlora_config,
276 lora_config,
277 vb,
278 ordering,
279 preload_adapters,
280 } = self.adapter;
281
282 T::from_gguf(
284 ct,
285 device,
286 lora_config,
287 &vb,
288 ordering,
289 xlora_config,
290 mapper,
291 &preload_adapters,
292 dtype,
293 )
294 }
295}
296
297use crate::{
298 models::quantized_llama::ModelWeights as QLlama,
299 models::quantized_phi2::ModelWeights as QPhi,
300 models::quantized_phi3::ModelWeights as QPhi3,
301 models::quantized_qwen2::ModelWeights as QQwen2,
302 models::quantized_starcoder2::ModelWeights as QStarcoder2,
303 xlora_models::{XLoraQLlama, XLoraQPhi3},
304};
305use akin::akin;
306
307impl TryFrom<ModelParams<'_, ParamsGGML>> for QLlama {
308 type Error = candle_core::Error;
309
310 fn try_from(params: ModelParams<'_, ParamsGGML>) -> Result<Self, Self::Error> {
311 let config = params.expect_quantized("`Config` should be GGML Quantized");
312 config.try_into_model()
313 }
314}
315
316impl TryFrom<ModelParams<'_, ParamsGGML>> for XLoraQLlama {
317 type Error = candle_core::Error;
318
319 fn try_from(params: ModelParams<'_, ParamsGGML>) -> Result<Self, Self::Error> {
320 let config = params.expect_adapted("`Config` should be GGML Quantized with an Adapter");
321 config.try_into_model()
322 }
323}
324
// Generate a `TryFrom<ModelParams<ParamsGGUF>>` impl for every plain
// (non-adapter) GGUF model type; `akin` substitutes `*models_gguf` with each
// type listed below.
akin! {
    let &models_gguf = [QLlama, QPhi, QPhi3, QStarcoder2, QQwen2];

    impl<R: std::io::Seek + std::io::Read> TryFrom<ModelParams<'_, ParamsGGUF<'_, R>>> for *models_gguf {
        type Error = candle_core::Error;

        fn try_from(params: ModelParams<'_, ParamsGGUF<'_, R>>) -> Result<Self, Self::Error> {
            // Panics (via `expect_quantized`) if adapters were supplied.
            let config = params.expect_quantized("`Config` should be GGUF Quantized");
            config.try_into_model()
        }
    }
}
337
// Generate a `TryFrom<ModelParams<ParamsGGUF>>` impl for every adapter-enabled
// (X-LoRA) GGUF model type; `akin` substitutes `*models_gguf_a` with each
// type listed below.
akin! {
    let &models_gguf_a = [XLoraQLlama, XLoraQPhi3];

    impl<R: std::io::Seek + std::io::Read> TryFrom<ModelParams<'_, ParamsGGUF<'_, R>>> for *models_gguf_a {
        type Error = candle_core::Error;

        fn try_from(params: ModelParams<'_, ParamsGGUF<'_, R>>) -> Result<Self, Self::Error> {
            // Panics (via `expect_adapted`) if no adapters were supplied.
            let config = params.expect_adapted("`Config` should be GGUF Quantized with an Adapter");
            config.try_into_model()
        }
    }
}