// mistralrs_core/utils/varbuilder_utils.rs

//! Utilities for creating a VarBuilder from a VarMap loaded from tensor storage formats.
2
use std::{
    collections::HashMap,
    path::PathBuf,
    sync::Arc,
    thread::{self, JoinHandle},
};

use candle_core::{pickle::PthTensors, DType, Device, Result, Tensor};
use mistralrs_quant::{safetensors::MmapedSafetensors, ShardedSafeTensors, ShardedVarBuilder};
use regex::Regex;

use crate::lora::LoraConfig;
use crate::utils::progress::IterWithProgress;
use derive_new::new;
17
/// Env var name; when set to `"1"`, forces the non-mmap (thread-per-file, in-memory) load path.
const MISTRALRS_NO_MMAP: &str = "MISTRALRS_NO_MMAP";
19
/// Abstraction over the on-disk tensor formats we can read (safetensors and pickle).
trait TensorLoaderBackend {
    /// Names of all tensors stored in the file.
    fn get_names(&self) -> Vec<String>;
    /// Load the named tensor onto `device`. `dtype` is a backend-dependent hint;
    /// the pickle backend ignores it.
    fn load_name(&self, name: &str, device: &Device, dtype: Option<DType>) -> Result<Tensor>;
}
24
25struct SafetensorBackend(MmapedSafetensors);
26
27impl TensorLoaderBackend for SafetensorBackend {
28    fn get_names(&self) -> Vec<String> {
29        self.0
30            .tensors()
31            .into_iter()
32            .map(|(name, _)| name)
33            .collect::<Vec<_>>()
34    }
35    fn load_name(&self, name: &str, device: &Device, dtype: Option<DType>) -> Result<Tensor> {
36        self.0.load(name, device, dtype)
37    }
38}
39
40struct PickleBackend(PthTensors);
41
42impl TensorLoaderBackend for PickleBackend {
43    fn get_names(&self) -> Vec<String> {
44        self.0.tensor_infos().keys().cloned().collect::<Vec<_>>()
45    }
46    fn load_name(&self, name: &str, device: &Device, _dtype: Option<DType>) -> Result<Tensor> {
47        self.0
48            .get(name)?
49            .ok_or(candle_core::Error::Msg(format!(
50                "Could not load tensor {name}"
51            )))?
52            .to_device(device)
53    }
54}
55
/// Selects which device a tensor should be loaded onto.
pub enum DeviceForLoadTensor {
    /// Use the base device passed to the loader.
    Base,
    /// Use `layer_devices[i]` when present; fall back to the base device otherwise.
    Idx(usize),
}
60
61/// Load tensors into a VarBuilder backed by a VarMap using MmapedSafetensors.
62/// Set `silent` to not show a progress bar.
63///
64/// # Predicate semantics:
65/// - If `regexes` is specified, this will be used in `make_dummy_predicate` based on `.any`
66/// - Otherwise, only include keys for which predicate evaluates to true.
67#[allow(clippy::too_many_arguments)]
68pub(crate) fn from_mmaped_safetensors(
69    paths: Vec<PathBuf>,
70    xlora_paths: Vec<PathBuf>,
71    dtype: Option<DType>,
72    base_device: &Device,
73    layer_devices: Vec<Option<Device>>,
74    silent: bool,
75    make_dummy_regexes: Option<Arc<Vec<Regex>>>,
76    predicate: impl Fn(String) -> bool + Send + Sync + Clone + 'static,
77    get_device_for_tensor: Arc<dyn Fn(String) -> DeviceForLoadTensor + Send + Sync + 'static>,
78) -> Result<ShardedVarBuilder> {
79    let use_no_mmap = std::env::var(MISTRALRS_NO_MMAP).is_ok_and(|x| x == "1");
80    if xlora_paths.is_empty() && !use_no_mmap {
81        if !silent {
82            tracing::info!("Loading model using mmap strategy.");
83        }
84        return Ok(unsafe {
85            ShardedSafeTensors::sharded(
86                &paths,
87                dtype.unwrap_or(DType::F16),
88                base_device,
89                make_dummy_regexes,
90                Arc::new(predicate),
91            )?
92        });
93    }
94
95    #[allow(clippy::type_complexity)]
96    let mut handles: Vec<JoinHandle<Result<HashMap<String, Tensor>>>> = Vec::new();
97
98    for path in paths {
99        let base_device = base_device.clone();
100        let layer_devices = layer_devices.clone();
101        let get_device_for_tensor = get_device_for_tensor.clone();
102        if let Some(regexes) = make_dummy_regexes.clone() {
103            let predicate = predicate.clone();
104            handles.push(thread::spawn(Box::new(move || {
105                let loader = Common::new();
106                loader.load_tensors_from_path(
107                    &path,
108                    &base_device,
109                    layer_devices,
110                    get_device_for_tensor,
111                    dtype,
112                    silent,
113                    predicate,
114                    |key| regexes.iter().any(|r| r.is_match(key)),
115                )
116            })));
117        } else {
118            let predicate = predicate.clone();
119            handles.push(thread::spawn(Box::new(move || {
120                let loader = Common::new();
121                loader.load_tensors_from_path(
122                    &path,
123                    &base_device,
124                    layer_devices,
125                    get_device_for_tensor,
126                    dtype,
127                    silent,
128                    predicate,
129                    |_| false,
130                )
131            })));
132        }
133    }
134    for (i, path) in xlora_paths.into_iter().enumerate() {
135        let base_device = base_device.clone();
136        let layer_devices = layer_devices.clone();
137        let get_device_for_tensor = get_device_for_tensor.clone();
138        if let Some(regexes) = make_dummy_regexes.clone() {
139            let predicate = predicate.clone();
140            handles.push(thread::spawn(Box::new(move || {
141                let loader = XLora::new(i + 1);
142                loader.load_tensors_from_path(
143                    &path,
144                    &base_device,
145                    layer_devices,
146                    get_device_for_tensor,
147                    dtype,
148                    silent,
149                    predicate,
150                    |key| regexes.iter().any(|r| r.is_match(key)),
151                )
152            })));
153        } else {
154            let predicate = predicate.clone();
155            handles.push(thread::spawn(Box::new(move || {
156                let loader = XLora::new(i + 1);
157                loader.load_tensors_from_path(
158                    &path,
159                    &base_device,
160                    layer_devices,
161                    get_device_for_tensor,
162                    dtype,
163                    silent,
164                    predicate,
165                    |_| false,
166                )
167            })));
168        }
169    }
170
171    let mut ws = HashMap::new();
172    // Wait until all spawned threads have finished loading tensors:
173    while !handles.iter().all(|h| h.is_finished()) {}
174    for h in handles {
175        ws.extend(h.join().unwrap()?);
176    }
177
178    let backend = Box::new(ws);
179
180    // TODO(EricLBuehler): separation of concerns.
181    // This is to have WNA16 for GPTQ which is required. No bf16 for GPTQ
182    Ok(ShardedSafeTensors::wrap(
183        backend,
184        dtype.unwrap_or(DType::F16),
185        base_device.clone(),
186    ))
187}
188
189pub(crate) fn load_preload_adapters(
190    paths: &Option<HashMap<String, (PathBuf, LoraConfig)>>,
191    dtype: DType,
192    device: &Device,
193    silent: bool,
194) -> Result<Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>> {
195    if let Some(paths) = paths {
196        let mut map = HashMap::new();
197        for (name, (path, config)) in paths {
198            let loader = Common::new();
199            let loaded_tensors = loader.load_tensors_from_path(
200                path,
201                device,
202                vec![None],
203                Arc::new(|_| DeviceForLoadTensor::Base),
204                Some(dtype),
205                silent,
206                |_| true,
207                |_| false,
208            )?;
209
210            let backend = Box::new(loaded_tensors);
211
212            // TODO(EricLBuehler): separation of concerns.
213            // This is to have WNA16 for GPTQ which is required. No bf16 for GPTQ
214            let vb = ShardedSafeTensors::wrap(backend, dtype, device.clone());
215
216            map.insert(name.clone(), (vb, config.clone()));
217        }
218        Ok(Some(map))
219    } else {
220        Ok(None)
221    }
222}
223
224// Presently this logic only needs to diverge for X-LoRA support via `get_name_key_pairs()`
225trait LoadTensors {
226    #[allow(clippy::too_many_arguments)]
227    fn load_tensors_from_path(
228        &self,
229        path: &PathBuf,
230        base_device: &Device,
231        layer_devices: Vec<Option<Device>>,
232        get_device_for_tensor: Arc<dyn Fn(String) -> DeviceForLoadTensor + Send + Sync + 'static>,
233        dtype: Option<DType>,
234        is_silent: bool,
235        predicate: impl Fn(String) -> bool,
236        make_dummy_predicate: impl Fn(&str) -> bool,
237    ) -> Result<HashMap<String, Tensor>> {
238        let tensors: Box<dyn TensorLoaderBackend> = match path
239            .extension()
240            .expect("Expected extension")
241            .to_str()
242            .expect("Expected to convert")
243        {
244            "safetensors" => Box::new(SafetensorBackend(unsafe {
245                MmapedSafetensors::new(path)?
246            })),
247            "pth" | "pt" | "bin" => Box::new(PickleBackend(
248                candle_core::pickle::PthTensors::new(path, None)?
249            )),
250            other => candle_core::bail!("Unexpected extension `{other}`, this should have been handled by `get_model_paths`."),
251        };
252
253        // Extracts the tensor name and processes it, filtering tensors and deriving the key name:
254        let names_only = tensors
255            .get_names()
256            .into_iter()
257            .filter(|x| predicate(x.to_string()));
258        let iter = self.get_name_key_pairs(names_only).collect::<Vec<_>>();
259
260        // Take the filtered list of tensors to load, store with derived lookup key:
261        let mut loaded_tensors = HashMap::new();
262        if !iter.is_empty() {
263            for (load_name, key_name) in iter.into_iter().with_progress(is_silent) {
264                if !make_dummy_predicate(&load_name) {
265                    let dev = match get_device_for_tensor(load_name.clone()) {
266                        DeviceForLoadTensor::Base => base_device,
267                        DeviceForLoadTensor::Idx(i) => layer_devices
268                            .get(i)
269                            .and_then(|d| d.as_ref())
270                            .unwrap_or(base_device),
271                    };
272                    // If making a dummy, don't add the tensor. `mistralrs_quant` handles this!
273                    let tensor = tensors.load_name(&load_name, dev, dtype)?;
274
275                    loaded_tensors.insert(key_name, tensor);
276                }
277            }
278        }
279
280        Ok(loaded_tensors)
281    }
282
283    fn get_name_key_pairs(
284        &self,
285        tensors: impl Iterator<Item = String>,
286    ) -> impl Iterator<Item = (String, String)> {
287        tensors.map(|name| {
288            let new_name = name.replace("base_model.model.model", "model");
289
290            (name, new_name)
291        })
292    }
293}
294
/// Loader for base-model weights; uses `LoadTensors`'s default key derivation
/// (replaces `base_model.model.model` with `model`).
#[derive(new)]
struct Common {}
impl LoadTensors for Common {}
298
/// Loader for X-LoRA adapter weights; inserts the adapter index into each key.
#[derive(new)]
struct XLora {
    // Matches the associated path instance for reference in `get_name_key_pairs()`
    // (1-based: adapter `i` corresponds to `xlora_paths[i - 1]`).
    adapter_index: usize,
}
304
305impl LoadTensors for XLora {
306    fn get_name_key_pairs(
307        &self,
308        tensors: impl Iterator<Item = String>,
309    ) -> impl Iterator<Item = (String, String)> {
310        let expectation = "tensor name `{new_name}` should have substring `.lora`";
311
312        tensors
313            .filter(|name| !name.contains("internal_xlora_classifier"))
314            .map(|name| {
315                let mut new_name = name.replace("base_model.model.model", "model");
316                // TODO: Add better context to describe intent / requirement:
317                let pos = new_name.find(".lora").expect(expectation);
318                new_name.insert_str(pos + 7, &format!(".{}", self.adapter_index));
319
320                (name, new_name)
321            })
322    }
323}