// mistralrs_core/utils/varbuilder_utils.rs

//! Utilities for creating a VarBuilder from a VarMap loaded from tensor storage formats.

use std::{
    collections::HashMap,
    path::PathBuf,
    sync::Arc,
    thread::{self, JoinHandle},
};

use candle_core::{pickle::PthTensors, DType, Device, Result, Tensor};
use mistralrs_quant::{safetensors::MmapedSafetensors, ShardedSafeTensors, ShardedVarBuilder};
use regex::Regex;

use crate::lora::LoraConfig;
use crate::utils::progress::IterWithProgress;
use derive_new::new;
17
/// Abstraction over the on-disk tensor container formats we can load from
/// (safetensors and pickle-based checkpoints).
trait TensorLoaderBackend {
    /// Names of all tensors stored in this file.
    fn get_names(&self) -> Vec<String>;
    /// Load the tensor `name` onto `device`, optionally casting to `dtype`.
    /// NOTE(review): whether `dtype` is honored is backend-specific — the
    /// pickle backend below ignores it.
    fn load_name(&self, name: &str, device: &Device, dtype: Option<DType>) -> Result<Tensor>;
}
22
23struct SafetensorBackend(MmapedSafetensors);
24
25impl TensorLoaderBackend for SafetensorBackend {
26    fn get_names(&self) -> Vec<String> {
27        self.0
28            .tensors()
29            .into_iter()
30            .map(|(name, _)| name)
31            .collect::<Vec<_>>()
32    }
33    fn load_name(&self, name: &str, device: &Device, dtype: Option<DType>) -> Result<Tensor> {
34        self.0.load(name, device, dtype)
35    }
36}
37
38struct PickleBackend(PthTensors);
39
40impl TensorLoaderBackend for PickleBackend {
41    fn get_names(&self) -> Vec<String> {
42        self.0.tensor_infos().keys().cloned().collect::<Vec<_>>()
43    }
44    fn load_name(&self, name: &str, device: &Device, _dtype: Option<DType>) -> Result<Tensor> {
45        self.0
46            .get(name)?
47            .ok_or(candle_core::Error::Msg(format!(
48                "Could not load tensor {name}"
49            )))?
50            .to_device(device)
51    }
52}
53
/// Selects which device a given tensor should be loaded onto.
pub enum DeviceForLoadTensor {
    /// Load onto the base (default) device.
    Base,
    /// Load onto the per-layer device at this index; callers fall back to the
    /// base device when no device is registered for the index.
    Idx(usize),
}
58
59/// Load tensors into a VarBuilder backed by a VarMap using MmapedSafetensors.
60/// Set `silent` to not show a progress bar.
61///
62/// # Predicate semantics:
63/// - If `regexes` is specified, this will be used in `make_dummy_predicate` based on `.any`
64/// - Otherwise, only include keys for which predicate evaluates to true.
65#[allow(clippy::too_many_arguments)]
66pub(crate) fn from_mmaped_safetensors<'a>(
67    paths: Vec<PathBuf>,
68    xlora_paths: Vec<PathBuf>,
69    dtype: Option<DType>,
70    base_device: &Device,
71    layer_devices: Vec<Option<Device>>,
72    silent: bool,
73    make_dummy_regexes: Option<Arc<Vec<Regex>>>,
74    predicate: impl Fn(String) -> bool + Send + Sync + Clone + 'static,
75    get_device_for_tensor: Arc<dyn Fn(String) -> DeviceForLoadTensor + Send + Sync + 'static>,
76) -> Result<ShardedVarBuilder<'a>> {
77    if base_device.is_cuda() {
78        return Ok(unsafe {
79            ShardedSafeTensors::sharded(
80                &paths,
81                dtype.unwrap_or(DType::F16),
82                base_device,
83                make_dummy_regexes,
84            )?
85        });
86    }
87
88    #[allow(clippy::type_complexity)]
89    let mut handles: Vec<JoinHandle<Result<HashMap<String, Tensor>>>> = Vec::new();
90
91    for path in paths {
92        let base_device = base_device.clone();
93        let layer_devices = layer_devices.clone();
94        let get_device_for_tensor = get_device_for_tensor.clone();
95        if let Some(regexes) = make_dummy_regexes.clone() {
96            let predicate = predicate.clone();
97            handles.push(thread::spawn(Box::new(move || {
98                let loader = Common::new();
99                loader.load_tensors_from_path(
100                    &path,
101                    &base_device,
102                    layer_devices,
103                    get_device_for_tensor,
104                    dtype,
105                    silent,
106                    predicate,
107                    |key| regexes.iter().any(|r| r.is_match(key)),
108                )
109            })));
110        } else {
111            let predicate = predicate.clone();
112            handles.push(thread::spawn(Box::new(move || {
113                let loader = Common::new();
114                loader.load_tensors_from_path(
115                    &path,
116                    &base_device,
117                    layer_devices,
118                    get_device_for_tensor,
119                    dtype,
120                    silent,
121                    predicate,
122                    |_| false,
123                )
124            })));
125        }
126    }
127    for (i, path) in xlora_paths.into_iter().enumerate() {
128        let base_device = base_device.clone();
129        let layer_devices = layer_devices.clone();
130        let get_device_for_tensor = get_device_for_tensor.clone();
131        if let Some(regexes) = make_dummy_regexes.clone() {
132            let predicate = predicate.clone();
133            handles.push(thread::spawn(Box::new(move || {
134                let loader = XLora::new(i + 1);
135                loader.load_tensors_from_path(
136                    &path,
137                    &base_device,
138                    layer_devices,
139                    get_device_for_tensor,
140                    dtype,
141                    silent,
142                    predicate,
143                    |key| regexes.iter().any(|r| r.is_match(key)),
144                )
145            })));
146        } else {
147            let predicate = predicate.clone();
148            handles.push(thread::spawn(Box::new(move || {
149                let loader = XLora::new(i + 1);
150                loader.load_tensors_from_path(
151                    &path,
152                    &base_device,
153                    layer_devices,
154                    get_device_for_tensor,
155                    dtype,
156                    silent,
157                    predicate,
158                    |_| false,
159                )
160            })));
161        }
162    }
163
164    let mut ws = HashMap::new();
165    // Wait until all spawned threads have finished loading tensors:
166    while !handles.iter().all(|h| h.is_finished()) {}
167    for h in handles {
168        ws.extend(h.join().unwrap()?);
169    }
170
171    let backend = Box::new(ws);
172
173    // TODO(EricLBuehler): separation of concerns.
174    // This is to have WNA16 for GPTQ which is required. No bf16 for GPTQ
175    Ok(ShardedSafeTensors::wrap(
176        backend,
177        dtype.unwrap_or(DType::F16),
178        base_device.clone(),
179    ))
180}
181
182pub(crate) fn load_preload_adapters<'a>(
183    paths: &Option<HashMap<String, (PathBuf, LoraConfig)>>,
184    dtype: DType,
185    device: &Device,
186    silent: bool,
187) -> Result<Option<HashMap<String, (ShardedVarBuilder<'a>, LoraConfig)>>> {
188    if let Some(paths) = paths {
189        let mut map = HashMap::new();
190        for (name, (path, config)) in paths {
191            let loader = Common::new();
192            let loaded_tensors = loader.load_tensors_from_path(
193                path,
194                device,
195                vec![None],
196                Arc::new(|_| DeviceForLoadTensor::Base),
197                Some(dtype),
198                silent,
199                |_| true,
200                |_| false,
201            )?;
202
203            let backend = Box::new(loaded_tensors);
204
205            // TODO(EricLBuehler): separation of concerns.
206            // This is to have WNA16 for GPTQ which is required. No bf16 for GPTQ
207            let vb = ShardedSafeTensors::wrap(backend, dtype, device.clone());
208
209            map.insert(name.clone(), (vb, config.clone()));
210        }
211        Ok(Some(map))
212    } else {
213        Ok(None)
214    }
215}
216
217// Presently this logic only needs to diverge for X-LoRA support via `get_name_key_pairs()`
218trait LoadTensors {
219    #[allow(clippy::too_many_arguments)]
220    fn load_tensors_from_path(
221        &self,
222        path: &PathBuf,
223        base_device: &Device,
224        layer_devices: Vec<Option<Device>>,
225        get_device_for_tensor: Arc<dyn Fn(String) -> DeviceForLoadTensor + Send + Sync + 'static>,
226        dtype: Option<DType>,
227        is_silent: bool,
228        predicate: impl Fn(String) -> bool,
229        make_dummy_predicate: impl Fn(&str) -> bool,
230    ) -> Result<HashMap<String, Tensor>> {
231        let tensors: Box<dyn TensorLoaderBackend> = match path
232            .extension()
233            .expect("Expected extension")
234            .to_str()
235            .expect("Expected to convert")
236        {
237            "safetensors" => Box::new(SafetensorBackend(unsafe {
238                MmapedSafetensors::new(path)?
239            })),
240            "pth" | "pt" | "bin" => Box::new(PickleBackend(
241                candle_core::pickle::PthTensors::new(path, None)?
242            )),
243            other => candle_core::bail!("Unexpected extension `{other}`, this should have been handled by `get_model_paths`."),
244        };
245
246        // Extracts the tensor name and processes it, filtering tensors and deriving the key name:
247        let names_only = tensors
248            .get_names()
249            .into_iter()
250            .filter(|x| predicate(x.to_string()));
251        let iter = self.get_name_key_pairs(names_only).collect::<Vec<_>>();
252
253        // Take the filtered list of tensors to load, store with derived lookup key:
254        let mut loaded_tensors = HashMap::new();
255        if !iter.is_empty() {
256            for (load_name, key_name) in iter.into_iter().with_progress(is_silent) {
257                if !make_dummy_predicate(&load_name) {
258                    let dev = match get_device_for_tensor(load_name.clone()) {
259                        DeviceForLoadTensor::Base => base_device,
260                        DeviceForLoadTensor::Idx(i) => layer_devices
261                            .get(i)
262                            .and_then(|d| d.as_ref())
263                            .unwrap_or(base_device),
264                    };
265                    // If making a dummy, don't add the tensor. `mistralrs_quant` handles this!
266                    let tensor = tensors.load_name(&load_name, dev, dtype)?;
267
268                    loaded_tensors.insert(key_name, tensor);
269                }
270            }
271        }
272
273        Ok(loaded_tensors)
274    }
275
276    fn get_name_key_pairs(
277        &self,
278        tensors: impl Iterator<Item = String>,
279    ) -> impl Iterator<Item = (String, String)> {
280        tensors.map(|name| {
281            let new_name = name.replace("base_model.model.model", "model");
282
283            (name, new_name)
284        })
285    }
286}
287
/// Loader for base-model (and preload-adapter) tensors; uses the default
/// `LoadTensors` behavior unchanged.
#[derive(new)]
struct Common {}
impl LoadTensors for Common {}
291
/// Loader for X-LoRA adapter tensors.
#[derive(new)]
struct XLora {
    // Matches the associated path instance for reference in `get_name_key_pairs()`
    // (1-based: callers construct this with `XLora::new(i + 1)`).
    adapter_index: usize,
}
297
298impl LoadTensors for XLora {
299    fn get_name_key_pairs(
300        &self,
301        tensors: impl Iterator<Item = String>,
302    ) -> impl Iterator<Item = (String, String)> {
303        let expectation = "tensor name `{new_name}` should have substring `.lora`";
304
305        tensors
306            .filter(|name| !name.contains("internal_xlora_classifier"))
307            .map(|name| {
308                let mut new_name = name.replace("base_model.model.model", "model");
309                // TODO: Add better context to describe intent / requirement:
310                let pos = new_name.find(".lora").expect(expectation);
311                new_name.insert_str(pos + 7, &format!(".{}", self.adapter_index));
312
313                (name, new_name)
314            })
315    }
316}