// mistralrs_core/utils/model_config.rs

1use super::varbuilder_utils::{
2    from_mmaped_safetensors, load_preload_adapters, DeviceForLoadTensor,
3};
4use anyhow::Result;
5use candle_core::{quantized::ggml_file, DType};
6use mistralrs_quant::ShardedVarBuilder;
7use std::{collections::HashMap, path::PathBuf, sync::Arc};
8
9use crate::{
10    device_map::DeviceMapper,
11    gguf::Content,
12    lora::{LoraConfig, Ordering},
13    paged_attention::AttentionImplementation,
14    pipeline::ModelPaths,
15    xlora_models::XLoraConfig,
16};
17
/// Owned inputs for constructing a model from a GGML file.
#[derive(derive_more::From)]
pub struct FileGGML {
    /// Parsed GGML file content (tensors + hyperparameters).
    pub ct: ggml_file::Content,
    /// Grouped-query attention factor, forwarded verbatim to `from_ggml()`.
    pub gqa: usize,
    /// Target dtype for the loaded weights.
    pub dtype: DType,
}
24
/// Target device paired with the mapper that decides layer placement.
#[derive(derive_more::From)]
pub struct Device<'a> {
    // Private on purpose: only consumed within this module when proxying to `from_gguf()`.
    device: &'a candle_core::Device,
    pub mapper: Box<dyn DeviceMapper + Send + Sync>,
}
30
/// LoRA / X-LoRA adapter state threaded through to the `from_ggml()` /
/// `from_gguf()` adapter variants. Mixes owned and borrowed fields; see
/// `try_new()` for why `vb` and `preload_adapters` must be owned.
pub struct Adapter<'a> {
    /// X-LoRA classifier config; `None` for plain LoRA.
    pub xlora_config: Option<XLoraConfig>,
    /// Per-adapter configs, each keyed by a pair of strings
    /// (exact semantics of the pair not visible here — see `ModelPaths::get_adapter_configs()`).
    pub lora_config: &'a [((String, String), LoraConfig)],
    /// VarBuilder owning the mmap'd adapter tensors.
    pub vb: ShardedVarBuilder<'a>,
    pub ordering: &'a Ordering,
    /// Pre-loaded adapters keyed by string (presumably adapter name — confirm against loader).
    pub preload_adapters: Option<HashMap<String, (ShardedVarBuilder<'a>, LoraConfig)>>,
}
38
39impl<'a> Adapter<'a> {
40    // NOTE: It is not possible to store references for values returned by: load_preload_adapters() + from_mmaped_safetensors(),
41    // As referenced value would drop after this method, Adapter takes ownership of vb + preload_adapters
42    // and then passes by reference to the `from_gguf()` / `from_ggml()` methods when proxying to params.
43    // NOTE: Due to reference usage persisting in returned struct, additional lifetime annotations were required.
44    #[allow(clippy::borrowed_box)]
45    pub fn try_new<'b: 'a>(
46        paths: &'b Box<dyn ModelPaths>,
47        device: &'b candle_core::Device,
48        silent: bool,
49        is_xlora: bool,
50    ) -> Result<Self> {
51        let lora_config = paths.get_adapter_configs().as_ref().unwrap();
52        let ordering = paths.get_ordering().as_ref().unwrap();
53        let preload_adapters = load_preload_adapters(
54            paths.get_lora_preload_adapter_info(),
55            candle_core::DType::F32,
56            device,
57            silent,
58        )?;
59
60        // X-LoRA support:
61        let mut xlora_paths: Vec<PathBuf> = vec![];
62        let mut xlora_config: Option<XLoraConfig> = None;
63        if is_xlora {
64            xlora_paths = vec![paths.get_classifier_path().as_ref().unwrap().to_path_buf()];
65            xlora_config = Some(paths.get_classifier_config().as_ref().unwrap().clone());
66        }
67
68        // Create VarBuilder:
69        // TODO: `from_mmaped_safetensors` has `xlora_paths` as the 2nd param (_valid but params need to be named better_)
70        let vb = from_mmaped_safetensors(
71            xlora_paths,
72            paths
73                .get_adapter_filenames()
74                .as_ref()
75                .unwrap()
76                .iter()
77                .map(|(_, x)| (*x).to_owned())
78                .collect::<Vec<_>>(),
79            Some(candle_core::DType::F32),
80            device,
81            vec![None],
82            silent,
83            None,
84            |_| true,
85            Arc::new(|_| DeviceForLoadTensor::Base),
86        )?;
87
88        Ok(Self {
89            lora_config,
90            xlora_config,
91            vb,
92            ordering,
93            preload_adapters,
94        })
95    }
96}
97
// New type wrappers that segment the distinct parameter sets used by `from_ggml()` + `from_gguf()` methods:

/// Parameter set for GGML-based loading: just the file bundle.
pub struct ParamsGGML(pub FileGGML);

/// Parameter set for GGUF-based loading: parsed content, device + mapper,
/// attention implementation, and target dtype.
pub struct ParamsGGUF<'a, R: std::io::Seek + std::io::Read>(
    pub Content<'a, R>,
    pub Device<'a>,
    pub AttentionImplementation,
    pub DType,
);

// A `None` type vs the `Some` type (`Adapter<'a>`)
pub struct NoAdapter {}
109
// Marker traits to restrict type input:
// (required workaround to support impl on subtypes, otherwise would use an enum)

/// Marker for the quantized parameter new types (`ParamsGGML` / `ParamsGGUF`).
pub trait QuantParams {}
impl QuantParams for ParamsGGML {}
impl<R: std::io::Seek + std::io::Read> QuantParams for ParamsGGUF<'_, R> {}

// Emulates `Option<Adapter>` but is compatible as a type bound in `impl<T>` for Some vs None
pub trait MaybeAdapter {}
impl MaybeAdapter for Adapter<'_> {}
impl MaybeAdapter for NoAdapter {}
120
// `derive_more::From` provides a terser construction for enum variants of `ModelParams`.
/// Pairs a quantized parameter set with a type-encoded optional adapter
/// (`Adapter<'_>` for Some, `NoAdapter` for None).
#[derive(derive_more::From)]
pub struct Config<Q: QuantParams, A: MaybeAdapter> {
    pub quant: Q,
    pub adapter: A,
}
127
// NOTE: Variantly used for `.expect_quantized()` / `.expect_adapted()` methods
// `where` clause required due to bug with inline bounds:
// https://github.com/luker-os/variantly/pull/16
/// Full input to model construction: quantized params, with or without an adapter.
#[allow(clippy::large_enum_variant)]
#[derive(variantly::Variantly)]
pub enum ModelParams<'a, Q>
where
    Q: QuantParams,
{
    /// Quantized model without adapters.
    Quantized(Config<Q, NoAdapter>),
    /// Quantized model with LoRA / X-LoRA adapter state attached.
    Adapted(Config<Q, Adapter<'a>>),
}
140
141// A `builder()` method is derived from the `new()` method and it's params (derived builder struct fields).
142// NOTE: Intended to be built via fluent API in a single line, cannot conditionally append params.
143// `.adapter(Adapter<' >)` or for conditional usage `.and_adapter(Option<Adapter<' >)` can be used.
144// Otherwise omitting an `.adapter()` call prior to calling `build()` is ok, defaults to `None`.
145impl<'a, Q: QuantParams> ModelParams<'a, Q> {
146    pub fn new<'b: 'a>(quant: Q, adapter: Option<Adapter<'b>>) -> Self {
147        match adapter {
148            None => Self::Quantized((quant, NoAdapter {}).into()),
149            Some(a) => Self::Adapted((quant, a).into()),
150        }
151    }
152}
153
// Traits for the existing methods used across various model types to impl `from_ggml()` / `from_gguf()`
// Basic:
/// Construction of a model from parsed GGML content (no adapters).
pub trait FromGGML {
    fn from_ggml(
        ct: ggml_file::Content,
        gqa: usize,
        dtype: DType,
    ) -> Result<Self, candle_core::Error>
    where
        Self: Sized;
}

/// Construction of a model from parsed GGUF content (no adapters).
pub trait FromGGUF {
    fn from_gguf<R: std::io::Seek + std::io::Read>(
        ct: Content<'_, R>,
        device: &candle_core::Device,
        mapper: Box<dyn DeviceMapper + Send + Sync>,
        attention_mechanism: AttentionImplementation,
        dtype: DType,
    ) -> Result<Self, candle_core::Error>
    where
        Self: Sized;
}
177
// Extended variants:
/// Construction of a model from parsed GGML content plus LoRA / X-LoRA adapter state.
pub trait FromAdapterGGML {
    #[allow(clippy::too_many_arguments)]
    fn from_ggml(
        ct: ggml_file::Content,
        gqa: usize,
        lora_config: &[((String, String), LoraConfig)],
        vb: &ShardedVarBuilder,
        ordering: &Ordering,
        xlora_config: Option<XLoraConfig>,
        preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
        dtype: DType,
    ) -> Result<Self, candle_core::Error>
    where
        Self: Sized;
}
/// Construction of a model from parsed GGUF content plus LoRA / X-LoRA adapter state.
/// Note: unlike `FromGGUF`, no `AttentionImplementation` parameter.
pub trait FromAdapterGGUF {
    #[allow(clippy::too_many_arguments)]
    fn from_gguf<R: std::io::Seek + std::io::Read>(
        ct: Content<'_, R>,
        device: &candle_core::Device,
        lora_config: &[((String, String), LoraConfig)],
        vb: &ShardedVarBuilder,
        ordering: &Ordering,
        xlora_config: Option<XLoraConfig>,
        mapper: Box<dyn DeviceMapper + Send + Sync>,
        preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
        dtype: DType,
    ) -> Result<Self, candle_core::Error>
    where
        Self: Sized;
}
210
211// NOTE: Below is a workaround to proxy params to the existing API methods `get_gguf()` / `get_gmml()` traits covered above.
212impl Config<ParamsGGML, NoAdapter> {
213    pub fn try_into_model<T: FromGGML>(self) -> Result<T, candle_core::Error> {
214        // Destructure props:
215        let ParamsGGML(FileGGML { ct, gqa, dtype }) = self.quant;
216
217        // Forwards all structured fields above into the required flattened param sequence:
218        T::from_ggml(ct, gqa, dtype)
219    }
220}
221
222impl Config<ParamsGGML, Adapter<'_>> {
223    pub fn try_into_model<T: FromAdapterGGML>(self) -> Result<T, candle_core::Error> {
224        // Destructure props:
225        let ParamsGGML(FileGGML { ct, gqa, dtype }) = self.quant;
226
227        let Adapter {
228            xlora_config,
229            lora_config,
230            vb,
231            ordering,
232            preload_adapters,
233        } = self.adapter;
234
235        // Forwards all structured fields above into the required flattened param sequence:
236        T::from_ggml(
237            ct,
238            gqa,
239            lora_config,
240            &vb,
241            ordering,
242            xlora_config,
243            &preload_adapters,
244            dtype,
245        )
246    }
247}
248
249impl<R: std::io::Seek + std::io::Read> Config<ParamsGGUF<'_, R>, NoAdapter> {
250    pub fn try_into_model<T: FromGGUF>(self) -> Result<T, candle_core::Error> {
251        // Destructure props:
252        let ParamsGGUF(ct, Device { device, mapper }, attention_implementation, dtype) = self.quant;
253
254        // Forwards all structured fields above into the required flattened param sequence:
255        T::from_gguf(ct, device, mapper, attention_implementation, dtype)
256    }
257}
258
259impl<R: std::io::Seek + std::io::Read> Config<ParamsGGUF<'_, R>, Adapter<'_>> {
260    pub fn try_into_model<T: FromAdapterGGUF>(self) -> Result<T, candle_core::Error> {
261        // Destructure props:
262        let ParamsGGUF(ct, Device { device, mapper }, _attention_implementation, dtype) =
263            self.quant;
264
265        let Adapter {
266            xlora_config,
267            lora_config,
268            vb,
269            ordering,
270            preload_adapters,
271        } = self.adapter;
272
273        // Forwards all structured fields above into the required flattened param sequence:
274        T::from_gguf(
275            ct,
276            device,
277            lora_config,
278            &vb,
279            ordering,
280            xlora_config,
281            mapper,
282            &preload_adapters,
283            dtype,
284        )
285    }
286}
287
288use crate::{
289    models::quantized_llama::ModelWeights as QLlama,
290    models::quantized_phi2::ModelWeights as QPhi,
291    models::quantized_phi3::ModelWeights as QPhi3,
292    models::quantized_qwen2::ModelWeights as QQwen2,
293    models::quantized_starcoder2::ModelWeights as QStarcoder2,
294    xlora_models::{XLoraQLlama, XLoraQPhi3},
295};
296use akin::akin;
297
298impl TryFrom<ModelParams<'_, ParamsGGML>> for QLlama {
299    type Error = candle_core::Error;
300
301    fn try_from(params: ModelParams<'_, ParamsGGML>) -> Result<Self, Self::Error> {
302        let config = params.expect_quantized("`Config` should be GGML Quantized");
303        config.try_into_model()
304    }
305}
306
307impl TryFrom<ModelParams<'_, ParamsGGML>> for XLoraQLlama {
308    type Error = candle_core::Error;
309
310    fn try_from(params: ModelParams<'_, ParamsGGML>) -> Result<Self, Self::Error> {
311        let config = params.expect_adapted("`Config` should be GGML Quantized with an Adapter");
312        config.try_into_model()
313    }
314}
315
// `akin!` stamps out one identical `TryFrom` impl per quantized (adapter-free)
// GGUF model type listed in `models_gguf`, substituting each entry for `*models_gguf`.
akin! {
    let &models_gguf = [QLlama, QPhi, QPhi3, QStarcoder2, QQwen2];

    impl<R: std::io::Seek + std::io::Read> TryFrom<ModelParams<'_, ParamsGGUF<'_, R>>> for *models_gguf {
        type Error = candle_core::Error;

        fn try_from(params: ModelParams<'_, ParamsGGUF<'_, R>>) -> Result<Self, Self::Error> {
            let config = params.expect_quantized("`Config` should be GGUF Quantized");
            config.try_into_model()
        }
    }
}
328
// Same generation as above but for the adapter-carrying (X-LoRA) GGUF model types:
// one `TryFrom` impl per entry of `models_gguf_a`.
akin! {
    let &models_gguf_a = [XLoraQLlama, XLoraQPhi3];

    impl<R: std::io::Seek + std::io::Read> TryFrom<ModelParams<'_, ParamsGGUF<'_, R>>> for *models_gguf_a {
        type Error = candle_core::Error;

        fn try_from(params: ModelParams<'_, ParamsGGUF<'_, R>>) -> Result<Self, Self::Error> {
            let config = params.expect_adapted("`Config` should be GGUF Quantized with an Adapter");
            config.try_into_model()
        }
    }
}