mistralrs_core/utils/model_config.rs

use super::varbuilder_utils::{
    from_mmaped_safetensors, load_preload_adapters, DeviceForLoadTensor,
};
use anyhow::Result;
use candle_core::{quantized::ggml_file, DType};
use mistralrs_quant::ShardedVarBuilder;
use std::{collections::HashMap, path::PathBuf, sync::Arc};

use crate::{
    device_map::DeviceMapper,
    gguf::Content,
    lora::{LoraConfig, Ordering},
    paged_attention::AttentionImplementation,
    pipeline::{AdapterPaths, ModelPaths},
    xlora_models::XLoraConfig,
};

#[derive(derive_more::From)]
pub struct FileGGML {
    pub ct: ggml_file::Content,
    pub gqa: usize,
    pub dtype: DType,
}

#[derive(derive_more::From)]
pub struct Device<'a> {
    device: &'a candle_core::Device,
    pub mapper: Box<dyn DeviceMapper + Send + Sync>,
}

pub struct Adapter<'a> {
    pub xlora_config: Option<XLoraConfig>,
    pub lora_config: &'a [((String, String), LoraConfig)],
    pub vb: ShardedVarBuilder,
    pub ordering: &'a Ordering,
    pub preload_adapters: Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
}

impl<'a> Adapter<'a> {
    // NOTE: References to the values returned by `load_preload_adapters()` + `from_mmaped_safetensors()`
    // cannot be stored, as they would be dropped when this method returns. `Adapter` therefore takes
    // ownership of `vb` + `preload_adapters` and passes them by reference to the `from_gguf()` /
    // `from_ggml()` methods when proxying to params.
    // NOTE: The references held by the returned struct require the additional lifetime annotations.
    #[allow(clippy::borrowed_box)]
    pub fn try_new<'b: 'a>(
        paths: &'b Box<dyn ModelPaths>,
        device: &'b candle_core::Device,
        silent: bool,
        is_xlora: bool,
    ) -> Result<Self> {
        let AdapterPaths::XLora {
            adapter_configs,
            adapter_safetensors,
            classifier_path,
            xlora_order,
            xlora_config,
            lora_preload_adapter_info,
        } = paths.get_adapter_paths()
        else {
            anyhow::bail!("`Adapter::try_new` expects `AdapterPaths::XLora` adapter paths")
        };

        let lora_config = adapter_configs.as_ref().unwrap();
        let ordering = xlora_order.as_ref().unwrap();
        let preload_adapters = load_preload_adapters(
            lora_preload_adapter_info,
            candle_core::DType::F32,
            device,
            silent,
        )?;

        // X-LoRA support:
        let mut xlora_paths: Vec<PathBuf> = vec![];
        if is_xlora {
            xlora_paths = vec![classifier_path.as_ref().unwrap().to_path_buf()];
        }

        // Create VarBuilder:
        // TODO: `from_mmaped_safetensors` has `xlora_paths` as its 2nd param (valid, but the params could be named better)
        let vb = from_mmaped_safetensors(
            xlora_paths,
            adapter_safetensors
                .as_ref()
                .unwrap()
                .iter()
                .map(|(_, x)| (*x).to_owned())
                .collect::<Vec<_>>(),
            Some(candle_core::DType::F32),
            device,
            vec![None],
            silent,
            None,
            |_| true,
            Arc::new(|_| DeviceForLoadTensor::Base),
        )?;

        Ok(Self {
            lora_config,
            xlora_config: xlora_config.clone(),
            vb,
            ordering,
            preload_adapters,
        })
    }
}
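
// Illustrative call site for `Adapter::try_new` (a hedged sketch; `paths`, `device`, and `silent`
// are assumed to be supplied by the calling pipeline and are not bindings defined in this file):
//
//     let adapter = Adapter::try_new(&paths, &device, silent, /* is_xlora */ true)?;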

// Newtype wrappers that segment the distinct parameter sets used by the `from_ggml()` + `from_gguf()` methods:
pub struct ParamsGGML(pub FileGGML);
pub struct ParamsGGUF<'a, R: std::io::Seek + std::io::Read>(
    pub Content<'a, R>,
    pub Device<'a>,
    pub AttentionImplementation,
    pub DType,
);

// The `None` counterpart to the `Some` type (`Adapter<'a>`):
pub struct NoAdapter {}

// Marker traits to restrict type input:
// (a workaround required to support impls on the subtypes; an enum would otherwise be used)
pub trait QuantParams {}
impl QuantParams for ParamsGGML {}
impl<R: std::io::Seek + std::io::Read> QuantParams for ParamsGGUF<'_, R> {}

// Emulates `Option<Adapter>`, but is usable as a type bound in `impl<T>` for the Some vs None cases:
pub trait MaybeAdapter {}
impl MaybeAdapter for Adapter<'_> {}
impl MaybeAdapter for NoAdapter {}

// `derive_more::From` provides a terser construction for enum variants of `ModelParams`.
#[derive(derive_more::From)]
pub struct Config<Q: QuantParams, A: MaybeAdapter> {
    pub quant: Q,
    pub adapter: A,
}

// NOTE: Variantly is used to derive the `.expect_quantized()` / `.expect_adapted()` methods.
// The `where` clause is required due to a bug with inline bounds:
// https://github.com/luker-os/variantly/pull/16
#[allow(clippy::large_enum_variant)]
#[derive(variantly::Variantly)]
pub enum ModelParams<'a, Q>
where
    Q: QuantParams,
{
    Quantized(Config<Q, NoAdapter>),
    Adapted(Config<Q, Adapter<'a>>),
}

// A `builder()` method is derived from the `new()` method and its params (the derived builder struct fields).
// NOTE: Intended to be built via the fluent API in a single line; params cannot be appended conditionally.
// `.adapter(Adapter<'a>)`, or for conditional usage `.and_adapter(Option<Adapter<'a>>)`, can be used.
// Otherwise omitting an `.adapter()` call prior to calling `build()` is ok and defaults to `None`.
impl<'a, Q: QuantParams> ModelParams<'a, Q> {
    pub fn new<'b: 'a>(quant: Q, adapter: Option<Adapter<'b>>) -> Self {
        match adapter {
            None => Self::Quantized((quant, NoAdapter {}).into()),
            Some(a) => Self::Adapted((quant, a).into()),
        }
    }
}
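
// Illustrative construction (a hedged sketch; `file_ggml` and `maybe_adapter` are assumed values
// produced by the loading pipeline, not bindings defined in this file):
//
//     let params = ModelParams::new(ParamsGGML(file_ggml), maybe_adapter);
//     let model: QLlama = params.try_into()?;          // when `maybe_adapter` is `None`
//     // let model: XLoraQLlama = params.try_into()?;  // when it is `Some(adapter)`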

// Traits for the existing methods used across various model types to impl `from_ggml()` / `from_gguf()`
// Basic:
pub trait FromGGML {
    fn from_ggml(
        ct: ggml_file::Content,
        gqa: usize,
        dtype: DType,
    ) -> Result<Self, candle_core::Error>
    where
        Self: Sized;
}

pub trait FromGGUF {
    fn from_gguf<R: std::io::Seek + std::io::Read>(
        ct: Content<'_, R>,
        device: &candle_core::Device,
        mapper: Box<dyn DeviceMapper + Send + Sync>,
        attention_mechanism: AttentionImplementation,
        dtype: DType,
    ) -> Result<Self, candle_core::Error>
    where
        Self: Sized;
}

// Extended variants:
pub trait FromAdapterGGML {
    #[allow(clippy::too_many_arguments)]
    fn from_ggml(
        ct: ggml_file::Content,
        gqa: usize,
        lora_config: &[((String, String), LoraConfig)],
        vb: &ShardedVarBuilder,
        ordering: &Ordering,
        xlora_config: Option<XLoraConfig>,
        preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
        dtype: DType,
    ) -> Result<Self, candle_core::Error>
    where
        Self: Sized;
}
pub trait FromAdapterGGUF {
    #[allow(clippy::too_many_arguments)]
    fn from_gguf<R: std::io::Seek + std::io::Read>(
        ct: Content<'_, R>,
        device: &candle_core::Device,
        lora_config: &[((String, String), LoraConfig)],
        vb: &ShardedVarBuilder,
        ordering: &Ordering,
        xlora_config: Option<XLoraConfig>,
        mapper: Box<dyn DeviceMapper + Send + Sync>,
        preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
        dtype: DType,
    ) -> Result<Self, candle_core::Error>
    where
        Self: Sized;
}

// NOTE: Below is a workaround to proxy params to the existing `from_gguf()` / `from_ggml()` methods of the traits covered above.
impl Config<ParamsGGML, NoAdapter> {
    pub fn try_into_model<T: FromGGML>(self) -> Result<T, candle_core::Error> {
        // Destructure props:
        let ParamsGGML(FileGGML { ct, gqa, dtype }) = self.quant;

        // Forwards all structured fields above into the required flattened param sequence:
        T::from_ggml(ct, gqa, dtype)
    }
}

impl Config<ParamsGGML, Adapter<'_>> {
    pub fn try_into_model<T: FromAdapterGGML>(self) -> Result<T, candle_core::Error> {
        // Destructure props:
        let ParamsGGML(FileGGML { ct, gqa, dtype }) = self.quant;

        let Adapter {
            xlora_config,
            lora_config,
            vb,
            ordering,
            preload_adapters,
        } = self.adapter;

        // Forwards all structured fields above into the required flattened param sequence:
        T::from_ggml(
            ct,
            gqa,
            lora_config,
            &vb,
            ordering,
            xlora_config,
            &preload_adapters,
            dtype,
        )
    }
}

impl<R: std::io::Seek + std::io::Read> Config<ParamsGGUF<'_, R>, NoAdapter> {
    pub fn try_into_model<T: FromGGUF>(self) -> Result<T, candle_core::Error> {
        // Destructure props:
        let ParamsGGUF(ct, Device { device, mapper }, attention_implementation, dtype) = self.quant;

        // Forwards all structured fields above into the required flattened param sequence:
        T::from_gguf(ct, device, mapper, attention_implementation, dtype)
    }
}

impl<R: std::io::Seek + std::io::Read> Config<ParamsGGUF<'_, R>, Adapter<'_>> {
    pub fn try_into_model<T: FromAdapterGGUF>(self) -> Result<T, candle_core::Error> {
        // Destructure props:
        let ParamsGGUF(ct, Device { device, mapper }, _attention_implementation, dtype) =
            self.quant;

        let Adapter {
            xlora_config,
            lora_config,
            vb,
            ordering,
            preload_adapters,
        } = self.adapter;

        // Forwards all structured fields above into the required flattened param sequence:
        T::from_gguf(
            ct,
            device,
            lora_config,
            &vb,
            ordering,
            xlora_config,
            mapper,
            &preload_adapters,
            dtype,
        )
    }
}

use crate::{
    models::quantized_llama::ModelWeights as QLlama,
    models::quantized_phi2::ModelWeights as QPhi,
    models::quantized_phi3::ModelWeights as QPhi3,
    models::quantized_qwen2::ModelWeights as QQwen2,
    models::quantized_starcoder2::ModelWeights as QStarcoder2,
    xlora_models::{XLoraQLlama, XLoraQPhi3},
};
use akin::akin;

impl TryFrom<ModelParams<'_, ParamsGGML>> for QLlama {
    type Error = candle_core::Error;

    fn try_from(params: ModelParams<'_, ParamsGGML>) -> Result<Self, Self::Error> {
        let config = params.expect_quantized("`Config` should be GGML Quantized");
        config.try_into_model()
    }
}

impl TryFrom<ModelParams<'_, ParamsGGML>> for XLoraQLlama {
    type Error = candle_core::Error;

    fn try_from(params: ModelParams<'_, ParamsGGML>) -> Result<Self, Self::Error> {
        let config = params.expect_adapted("`Config` should be GGML Quantized with an Adapter");
        config.try_into_model()
    }
}

akin! {
    let &models_gguf = [QLlama, QPhi, QPhi3, QStarcoder2, QQwen2];

    impl<R: std::io::Seek + std::io::Read> TryFrom<ModelParams<'_, ParamsGGUF<'_, R>>> for *models_gguf {
        type Error = candle_core::Error;

        fn try_from(params: ModelParams<'_, ParamsGGUF<'_, R>>) -> Result<Self, Self::Error> {
            let config = params.expect_quantized("`Config` should be GGUF Quantized");
            config.try_into_model()
        }
    }
}
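
// Each `akin!` block stamps out one `TryFrom` impl per model listed, substituting the concrete type
// for the `*models_gguf` / `*models_gguf_a` placeholder. As a rough sketch (not verbatim generated
// code), the expansion of the block above for `QPhi` reads:
//
//     impl<R: std::io::Seek + std::io::Read> TryFrom<ModelParams<'_, ParamsGGUF<'_, R>>> for QPhi {
//         type Error = candle_core::Error;
//
//         fn try_from(params: ModelParams<'_, ParamsGGUF<'_, R>>) -> Result<Self, Self::Error> {
//             let config = params.expect_quantized("`Config` should be GGUF Quantized");
//             config.try_into_model()
//         }
//     }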

akin! {
    let &models_gguf_a = [XLoraQLlama, XLoraQPhi3];

    impl<R: std::io::Seek + std::io::Read> TryFrom<ModelParams<'_, ParamsGGUF<'_, R>>> for *models_gguf_a {
        type Error = candle_core::Error;

        fn try_from(params: ModelParams<'_, ParamsGGUF<'_, R>>) -> Result<Self, Self::Error> {
            let config = params.expect_adapted("`Config` should be GGUF Quantized with an Adapter");
            config.try_into_model()
        }
    }
}
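
// End-to-end shape of the conversion path (a hedged sketch; `content`, `device`, `mapper`,
// `attention_impl`, and `dtype` are assumed inputs from the loading pipeline, not bindings
// defined in this file):
//
//     let quant = ParamsGGUF(content, Device::from((&device, mapper)), attention_impl, dtype);
//     let model = QLlama::try_from(ModelParams::new(quant, None))?;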