1use super::varbuilder_utils::{
2 from_mmaped_safetensors, load_preload_adapters, DeviceForLoadTensor,
3};
4use anyhow::Result;
5use candle_core::{quantized::ggml_file, DType};
6use mistralrs_quant::ShardedVarBuilder;
7use std::{collections::HashMap, path::PathBuf, sync::Arc};
8
9use crate::{
10 device_map::DeviceMapper,
11 gguf::Content,
12 lora::{LoraConfig, Ordering},
13 paged_attention::AttentionImplementation,
14 pipeline::ModelPaths,
15 xlora_models::XLoraConfig,
16};
17
/// Everything needed to build a model from a GGML file.
///
/// `derive_more::From` allows construction from a `(ct, gqa, dtype)` tuple.
#[derive(derive_more::From)]
pub struct FileGGML {
    /// Parsed GGML container (tensors + hyperparameters).
    pub ct: ggml_file::Content,
    /// Grouped-query-attention factor; GGML files don't carry this, so the
    /// caller supplies it. Presumably query-heads-per-KV-head — confirm at
    /// the model constructors.
    pub gqa: usize,
    /// Working dtype passed through to model construction.
    pub dtype: DType,
}
24
/// Target device plus the mapper that places layers across devices.
///
/// `derive_more::From` allows construction from a `(device, mapper)` tuple.
#[derive(derive_more::From)]
pub struct Device<'a> {
    // Private: only consumed by the `try_into_model` impls in this module.
    device: &'a candle_core::Device,
    pub mapper: Box<dyn DeviceMapper + Send + Sync>,
}
30
/// Loaded (X-)LoRA adapter state handed to adapter-aware model constructors.
pub struct Adapter<'a> {
    /// X-LoRA classifier config; `None` for plain LoRA.
    pub xlora_config: Option<XLoraConfig>,
    /// Per-adapter configs keyed by `(name, path)`-style string pairs.
    pub lora_config: &'a [((String, String), LoraConfig)],
    /// VarBuilder over the mmap'd adapter (and classifier) safetensors.
    pub vb: ShardedVarBuilder<'a>,
    /// Adapter application ordering.
    pub ordering: &'a Ordering,
    /// Adapters preloaded at startup, keyed by name, if any.
    pub preload_adapters: Option<HashMap<String, (ShardedVarBuilder<'a>, LoraConfig)>>,
}
38
impl<'a> Adapter<'a> {
    /// Loads LoRA adapter weights (and, when `is_xlora`, the X-LoRA
    /// classifier) from `paths` onto `device`.
    ///
    /// Adapter tensors are always loaded as `F32` (hard-coded below).
    ///
    /// # Panics
    ///
    /// Panics if `paths` lacks adapter configs, ordering, adapter filenames,
    /// or (when `is_xlora`) the classifier path/config — only call this for
    /// adapter-enabled model paths.
    //
    // NOTE(review): `&Box<dyn ModelPaths>` is kept (rather than
    // `&dyn ModelPaths`) to preserve existing call sites; hence the allow.
    #[allow(clippy::borrowed_box)]
    pub fn try_new<'b: 'a>(
        paths: &'b Box<dyn ModelPaths>,
        device: &'b candle_core::Device,
        silent: bool,
        is_xlora: bool,
    ) -> Result<Self> {
        let lora_config = paths.get_adapter_configs().as_ref().unwrap();
        let ordering = paths.get_ordering().as_ref().unwrap();
        // Optional adapters preloaded at startup; F32 like the main adapters.
        let preload_adapters = load_preload_adapters(
            paths.get_lora_preload_adapter_info(),
            candle_core::DType::F32,
            device,
            silent,
        )?;

        // X-LoRA additionally ships a classifier model next to the adapters.
        let mut xlora_paths: Vec<PathBuf> = vec![];
        let mut xlora_config: Option<XLoraConfig> = None;
        if is_xlora {
            xlora_paths = vec![paths.get_classifier_path().as_ref().unwrap().to_path_buf()];
            xlora_config = Some(paths.get_classifier_config().as_ref().unwrap().clone());
        }

        // Memory-map the classifier (if any) plus all adapter safetensors
        // into one VarBuilder, forcing F32.
        let vb = from_mmaped_safetensors(
            xlora_paths,
            paths
                .get_adapter_filenames()
                .as_ref()
                .unwrap()
                .iter()
                .map(|(_, x)| (*x).to_owned())
                .collect::<Vec<_>>(),
            Some(candle_core::DType::F32),
            device,
            vec![None],
            silent,
            None,
            |_| true, // keep every tensor
            Arc::new(|_| DeviceForLoadTensor::Base),
        )?;

        Ok(Self {
            lora_config,
            xlora_config,
            vb,
            ordering,
            preload_adapters,
        })
    }
}
97
/// GGML-backed construction parameters (newtype over [`FileGGML`]).
pub struct ParamsGGML(pub FileGGML);

/// GGUF-backed construction parameters: parsed content, device + layer
/// mapper, attention implementation, and working dtype.
pub struct ParamsGGUF<'a, R: std::io::Seek + std::io::Read>(
    pub Content<'a, R>,
    pub Device<'a>,
    pub AttentionImplementation,
    pub DType,
);
106
107pub struct NoAdapter {}
109
/// Marker trait: types usable as the quantized-weight source in [`Config`].
pub trait QuantParams {}
impl QuantParams for ParamsGGML {}
impl<R: std::io::Seek + std::io::Read> QuantParams for ParamsGGUF<'_, R> {}
115
/// Marker trait: either [`Adapter`] (adapters present) or [`NoAdapter`].
pub trait MaybeAdapter {}
impl MaybeAdapter for Adapter<'_> {}
impl MaybeAdapter for NoAdapter {}
120
/// Pairs a quantized-weight source with an (optional) adapter, selected at
/// the type level. `derive_more::From` allows building one from a `(Q, A)`
/// tuple.
#[derive(derive_more::From)]
pub struct Config<Q: QuantParams, A: MaybeAdapter> {
    pub quant: Q,
    pub adapter: A,
}
127
/// Full input to model construction: quantized weights, with or without an
/// attached adapter.
///
/// `variantly::Variantly` derives accessor helpers (e.g. `expect_quantized`,
/// `expect_adapted`) used by this module's `TryFrom` impls.
#[allow(clippy::large_enum_variant)]
#[derive(variantly::Variantly)]
pub enum ModelParams<'a, Q>
where
    Q: QuantParams,
{
    /// Plain quantized model, no adapters.
    Quantized(Config<Q, NoAdapter>),
    /// Quantized model with (X-)LoRA adapters attached.
    Adapted(Config<Q, Adapter<'a>>),
}
140
141impl<'a, Q: QuantParams> ModelParams<'a, Q> {
146 pub fn new<'b: 'a>(quant: Q, adapter: Option<Adapter<'b>>) -> Self {
147 match adapter {
148 None => Self::Quantized((quant, NoAdapter {}).into()),
149 Some(a) => Self::Adapted((quant, a).into()),
150 }
151 }
152}
153
/// Models constructible from a plain GGML file (no adapters).
pub trait FromGGML {
    /// Builds the model from parsed GGML content, with the caller-supplied
    /// GQA factor and working dtype.
    fn from_ggml(
        ct: ggml_file::Content,
        gqa: usize,
        dtype: DType,
    ) -> Result<Self, candle_core::Error>
    where
        Self: Sized;
}
165
/// Models constructible from GGUF content (no adapters).
pub trait FromGGUF {
    /// Builds the model from GGUF content, placing layers per `mapper` and
    /// using the given attention implementation and working dtype.
    fn from_gguf<R: std::io::Seek + std::io::Read>(
        ct: Content<'_, R>,
        device: &candle_core::Device,
        mapper: Box<dyn DeviceMapper + Send + Sync>,
        attention_mechanism: AttentionImplementation,
        dtype: DType,
    ) -> Result<Self, candle_core::Error>
    where
        Self: Sized;
}
177
/// Models constructible from GGML weights plus (X-)LoRA adapter state.
pub trait FromAdapterGGML {
    /// Builds the model; adapter arguments mirror the fields of [`Adapter`].
    /// `xlora_config` is `None` for plain LoRA.
    #[allow(clippy::too_many_arguments)]
    fn from_ggml(
        ct: ggml_file::Content,
        gqa: usize,
        lora_config: &[((String, String), LoraConfig)],
        vb: &ShardedVarBuilder,
        ordering: &Ordering,
        xlora_config: Option<XLoraConfig>,
        preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
        dtype: DType,
    ) -> Result<Self, candle_core::Error>
    where
        Self: Sized;
}
/// Models constructible from GGUF content plus (X-)LoRA adapter state.
pub trait FromAdapterGGUF {
    /// Builds the model; adapter arguments mirror the fields of [`Adapter`].
    /// Note there is no `AttentionImplementation` parameter on this path.
    #[allow(clippy::too_many_arguments)]
    fn from_gguf<R: std::io::Seek + std::io::Read>(
        ct: Content<'_, R>,
        device: &candle_core::Device,
        lora_config: &[((String, String), LoraConfig)],
        vb: &ShardedVarBuilder,
        ordering: &Ordering,
        xlora_config: Option<XLoraConfig>,
        mapper: Box<dyn DeviceMapper + Send + Sync>,
        preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
        dtype: DType,
    ) -> Result<Self, candle_core::Error>
    where
        Self: Sized;
}
210
211impl Config<ParamsGGML, NoAdapter> {
213 pub fn try_into_model<T: FromGGML>(self) -> Result<T, candle_core::Error> {
214 let ParamsGGML(FileGGML { ct, gqa, dtype }) = self.quant;
216
217 T::from_ggml(ct, gqa, dtype)
219 }
220}
221
222impl Config<ParamsGGML, Adapter<'_>> {
223 pub fn try_into_model<T: FromAdapterGGML>(self) -> Result<T, candle_core::Error> {
224 let ParamsGGML(FileGGML { ct, gqa, dtype }) = self.quant;
226
227 let Adapter {
228 xlora_config,
229 lora_config,
230 vb,
231 ordering,
232 preload_adapters,
233 } = self.adapter;
234
235 T::from_ggml(
237 ct,
238 gqa,
239 lora_config,
240 &vb,
241 ordering,
242 xlora_config,
243 &preload_adapters,
244 dtype,
245 )
246 }
247}
248
249impl<R: std::io::Seek + std::io::Read> Config<ParamsGGUF<'_, R>, NoAdapter> {
250 pub fn try_into_model<T: FromGGUF>(self) -> Result<T, candle_core::Error> {
251 let ParamsGGUF(ct, Device { device, mapper }, attention_implementation, dtype) = self.quant;
253
254 T::from_gguf(ct, device, mapper, attention_implementation, dtype)
256 }
257}
258
259impl<R: std::io::Seek + std::io::Read> Config<ParamsGGUF<'_, R>, Adapter<'_>> {
260 pub fn try_into_model<T: FromAdapterGGUF>(self) -> Result<T, candle_core::Error> {
261 let ParamsGGUF(ct, Device { device, mapper }, _attention_implementation, dtype) =
263 self.quant;
264
265 let Adapter {
266 xlora_config,
267 lora_config,
268 vb,
269 ordering,
270 preload_adapters,
271 } = self.adapter;
272
273 T::from_gguf(
275 ct,
276 device,
277 lora_config,
278 &vb,
279 ordering,
280 xlora_config,
281 mapper,
282 &preload_adapters,
283 dtype,
284 )
285 }
286}
287
288use crate::{
289 models::quantized_llama::ModelWeights as QLlama,
290 models::quantized_phi2::ModelWeights as QPhi,
291 models::quantized_phi3::ModelWeights as QPhi3,
292 models::quantized_qwen2::ModelWeights as QQwen2,
293 models::quantized_starcoder2::ModelWeights as QStarcoder2,
294 xlora_models::{XLoraQLlama, XLoraQPhi3},
295};
296use akin::akin;
297
298impl TryFrom<ModelParams<'_, ParamsGGML>> for QLlama {
299 type Error = candle_core::Error;
300
301 fn try_from(params: ModelParams<'_, ParamsGGML>) -> Result<Self, Self::Error> {
302 let config = params.expect_quantized("`Config` should be GGML Quantized");
303 config.try_into_model()
304 }
305}
306
307impl TryFrom<ModelParams<'_, ParamsGGML>> for XLoraQLlama {
308 type Error = candle_core::Error;
309
310 fn try_from(params: ModelParams<'_, ParamsGGML>) -> Result<Self, Self::Error> {
311 let config = params.expect_adapted("`Config` should be GGML Quantized with an Adapter");
312 config.try_into_model()
313 }
314}
315
// `akin!` stamps out one `TryFrom<ModelParams<ParamsGGUF>>` impl per listed
// non-adapter model type, substituting `*models_gguf` with each name.
akin! {
    let &models_gguf = [QLlama, QPhi, QPhi3, QStarcoder2, QQwen2];

    impl<R: std::io::Seek + std::io::Read> TryFrom<ModelParams<'_, ParamsGGUF<'_, R>>> for *models_gguf {
        type Error = candle_core::Error;

        fn try_from(params: ModelParams<'_, ParamsGGUF<'_, R>>) -> Result<Self, Self::Error> {
            let config = params.expect_quantized("`Config` should be GGUF Quantized");
            config.try_into_model()
        }
    }
}
328
// Same expansion as above, but for the adapter-enabled (X-LoRA) GGUF models,
// which go through the `Adapted` variant instead.
akin! {
    let &models_gguf_a = [XLoraQLlama, XLoraQPhi3];

    impl<R: std::io::Seek + std::io::Read> TryFrom<ModelParams<'_, ParamsGGUF<'_, R>>> for *models_gguf_a {
        type Error = candle_core::Error;

        fn try_from(params: ModelParams<'_, ParamsGGUF<'_, R>>) -> Result<Self, Self::Error> {
            let config = params.expect_adapted("`Config` should be GGUF Quantized with an Adapter");
            config.try_into_model()
        }
    }
}