mistralrs_core/utils/
varbuilder_utils.rs

1//! Utilities for creating a VarBuilder from a VarMap loaded from tensor storage formats.
2
use std::{
    collections::HashMap,
    path::{Path, PathBuf},
    sync::Arc,
    thread::{self, JoinHandle},
};

use candle_core::{pickle::PthTensors, DType, Device, Result, Tensor};
use derive_new::new;
use mistralrs_quant::{safetensors::MmapedSafetensors, ShardedSafeTensors, ShardedVarBuilder};
use regex::Regex;

use crate::lora::LoraConfig;
use crate::utils::progress::IterWithProgress;
17
/// Abstraction over the on-disk tensor formats a checkpoint can be read from
/// (memory-mapped safetensors, or pickle-based `.pth`/`.pt`/`.bin`).
trait TensorLoaderBackend {
    /// All tensor names present in the underlying file.
    fn get_names(&self) -> Vec<String>;
    /// Load the named tensor onto `device`, optionally casting to `dtype`.
    /// Backends may ignore `dtype` (the pickle backend does).
    fn load_name(&self, name: &str, device: &Device, dtype: Option<DType>) -> Result<Tensor>;
}
22
23struct SafetensorBackend(MmapedSafetensors);
24
25impl TensorLoaderBackend for SafetensorBackend {
26    fn get_names(&self) -> Vec<String> {
27        self.0
28            .tensors()
29            .into_iter()
30            .map(|(name, _)| name)
31            .collect::<Vec<_>>()
32    }
33    fn load_name(&self, name: &str, device: &Device, dtype: Option<DType>) -> Result<Tensor> {
34        self.0.load(name, device, dtype)
35    }
36}
37
38struct PickleBackend(PthTensors);
39
40impl TensorLoaderBackend for PickleBackend {
41    fn get_names(&self) -> Vec<String> {
42        self.0.tensor_infos().keys().cloned().collect::<Vec<_>>()
43    }
44    fn load_name(&self, name: &str, device: &Device, _dtype: Option<DType>) -> Result<Tensor> {
45        self.0
46            .get(name)?
47            .ok_or(candle_core::Error::Msg(format!(
48                "Could not load tensor {name}"
49            )))?
50            .to_device(device)
51    }
52}
53
/// Selects which device a given tensor should be loaded onto.
pub enum DeviceForLoadTensor {
    /// Use the base device supplied to the loader.
    Base,
    /// Use `layer_devices[i]` when present; the loader falls back to the base
    /// device if the index is out of range or the slot is `None`.
    Idx(usize),
}
58
59/// Load tensors into a VarBuilder backed by a VarMap using MmapedSafetensors.
60/// Set `silent` to not show a progress bar.
61///
62/// # Predicate semantics:
63/// - If `regexes` is specified, this will be used in `make_dummy_predicate` based on `.any`
64/// - Otherwise, only include keys for which predicate evaluates to true.
65#[allow(clippy::too_many_arguments)]
66pub(crate) fn from_mmaped_safetensors(
67    paths: Vec<PathBuf>,
68    xlora_paths: Vec<PathBuf>,
69    dtype: Option<DType>,
70    base_device: &Device,
71    layer_devices: Vec<Option<Device>>,
72    silent: bool,
73    make_dummy_regexes: Option<Arc<Vec<Regex>>>,
74    predicate: impl Fn(String) -> bool + Send + Sync + Clone + 'static,
75    get_device_for_tensor: Arc<dyn Fn(String) -> DeviceForLoadTensor + Send + Sync + 'static>,
76) -> Result<ShardedVarBuilder> {
77    // No mmap for cuda.
78    if xlora_paths.is_empty() && !base_device.is_cuda() || cfg!(feature = "ring") {
79        if !silent {
80            tracing::info!("Loading model using mmap strategy.");
81        }
82        return Ok(unsafe {
83            ShardedSafeTensors::sharded(
84                &paths,
85                dtype.unwrap_or(DType::F16),
86                base_device,
87                make_dummy_regexes,
88                Arc::new(predicate),
89            )?
90        });
91    }
92
93    #[allow(clippy::type_complexity)]
94    let mut handles: Vec<JoinHandle<Result<HashMap<String, Tensor>>>> = Vec::new();
95
96    for path in paths {
97        let base_device = base_device.clone();
98        let layer_devices = layer_devices.clone();
99        let get_device_for_tensor = get_device_for_tensor.clone();
100        if let Some(regexes) = make_dummy_regexes.clone() {
101            let predicate = predicate.clone();
102            handles.push(thread::spawn(Box::new(move || {
103                let loader = Common::new();
104                loader.load_tensors_from_path(
105                    &path,
106                    &base_device,
107                    layer_devices,
108                    get_device_for_tensor,
109                    dtype,
110                    silent,
111                    predicate,
112                    |key| regexes.iter().any(|r| r.is_match(key)),
113                )
114            })));
115        } else {
116            let predicate = predicate.clone();
117            handles.push(thread::spawn(Box::new(move || {
118                let loader = Common::new();
119                loader.load_tensors_from_path(
120                    &path,
121                    &base_device,
122                    layer_devices,
123                    get_device_for_tensor,
124                    dtype,
125                    silent,
126                    predicate,
127                    |_| false,
128                )
129            })));
130        }
131    }
132    for (i, path) in xlora_paths.into_iter().enumerate() {
133        let base_device = base_device.clone();
134        let layer_devices = layer_devices.clone();
135        let get_device_for_tensor = get_device_for_tensor.clone();
136        if let Some(regexes) = make_dummy_regexes.clone() {
137            let predicate = predicate.clone();
138            handles.push(thread::spawn(Box::new(move || {
139                let loader = XLora::new(i + 1);
140                loader.load_tensors_from_path(
141                    &path,
142                    &base_device,
143                    layer_devices,
144                    get_device_for_tensor,
145                    dtype,
146                    silent,
147                    predicate,
148                    |key| regexes.iter().any(|r| r.is_match(key)),
149                )
150            })));
151        } else {
152            let predicate = predicate.clone();
153            handles.push(thread::spawn(Box::new(move || {
154                let loader = XLora::new(i + 1);
155                loader.load_tensors_from_path(
156                    &path,
157                    &base_device,
158                    layer_devices,
159                    get_device_for_tensor,
160                    dtype,
161                    silent,
162                    predicate,
163                    |_| false,
164                )
165            })));
166        }
167    }
168
169    let mut ws = HashMap::new();
170    // Wait until all spawned threads have finished loading tensors:
171    while !handles.iter().all(|h| h.is_finished()) {}
172    for h in handles {
173        ws.extend(h.join().unwrap()?);
174    }
175
176    let backend = Box::new(ws);
177
178    // TODO(EricLBuehler): separation of concerns.
179    // This is to have WNA16 for GPTQ which is required. No bf16 for GPTQ
180    Ok(ShardedSafeTensors::wrap(
181        backend,
182        dtype.unwrap_or(DType::F16),
183        base_device.clone(),
184    ))
185}
186
187pub(crate) fn load_preload_adapters(
188    paths: &Option<HashMap<String, (PathBuf, LoraConfig)>>,
189    dtype: DType,
190    device: &Device,
191    silent: bool,
192) -> Result<Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>> {
193    if let Some(paths) = paths {
194        let mut map = HashMap::new();
195        for (name, (path, config)) in paths {
196            let loader = Common::new();
197            let loaded_tensors = loader.load_tensors_from_path(
198                path,
199                device,
200                vec![None],
201                Arc::new(|_| DeviceForLoadTensor::Base),
202                Some(dtype),
203                silent,
204                |_| true,
205                |_| false,
206            )?;
207
208            let backend = Box::new(loaded_tensors);
209
210            // TODO(EricLBuehler): separation of concerns.
211            // This is to have WNA16 for GPTQ which is required. No bf16 for GPTQ
212            let vb = ShardedSafeTensors::wrap(backend, dtype, device.clone());
213
214            map.insert(name.clone(), (vb, config.clone()));
215        }
216        Ok(Some(map))
217    } else {
218        Ok(None)
219    }
220}
221
222// Presently this logic only needs to diverge for X-LoRA support via `get_name_key_pairs()`
223trait LoadTensors {
224    #[allow(clippy::too_many_arguments)]
225    fn load_tensors_from_path(
226        &self,
227        path: &PathBuf,
228        base_device: &Device,
229        layer_devices: Vec<Option<Device>>,
230        get_device_for_tensor: Arc<dyn Fn(String) -> DeviceForLoadTensor + Send + Sync + 'static>,
231        dtype: Option<DType>,
232        is_silent: bool,
233        predicate: impl Fn(String) -> bool,
234        make_dummy_predicate: impl Fn(&str) -> bool,
235    ) -> Result<HashMap<String, Tensor>> {
236        let tensors: Box<dyn TensorLoaderBackend> = match path
237            .extension()
238            .expect("Expected extension")
239            .to_str()
240            .expect("Expected to convert")
241        {
242            "safetensors" => Box::new(SafetensorBackend(unsafe {
243                MmapedSafetensors::new(path)?
244            })),
245            "pth" | "pt" | "bin" => Box::new(PickleBackend(
246                candle_core::pickle::PthTensors::new(path, None)?
247            )),
248            other => candle_core::bail!("Unexpected extension `{other}`, this should have been handled by `get_model_paths`."),
249        };
250
251        // Extracts the tensor name and processes it, filtering tensors and deriving the key name:
252        let names_only = tensors
253            .get_names()
254            .into_iter()
255            .filter(|x| predicate(x.to_string()));
256        let iter = self.get_name_key_pairs(names_only).collect::<Vec<_>>();
257
258        // Take the filtered list of tensors to load, store with derived lookup key:
259        let mut loaded_tensors = HashMap::new();
260        if !iter.is_empty() {
261            for (load_name, key_name) in iter.into_iter().with_progress(is_silent) {
262                if !make_dummy_predicate(&load_name) {
263                    let dev = match get_device_for_tensor(load_name.clone()) {
264                        DeviceForLoadTensor::Base => base_device,
265                        DeviceForLoadTensor::Idx(i) => layer_devices
266                            .get(i)
267                            .and_then(|d| d.as_ref())
268                            .unwrap_or(base_device),
269                    };
270                    // If making a dummy, don't add the tensor. `mistralrs_quant` handles this!
271                    let tensor = tensors.load_name(&load_name, dev, dtype)?;
272
273                    loaded_tensors.insert(key_name, tensor);
274                }
275            }
276        }
277
278        Ok(loaded_tensors)
279    }
280
281    fn get_name_key_pairs(
282        &self,
283        tensors: impl Iterator<Item = String>,
284    ) -> impl Iterator<Item = (String, String)> {
285        tensors.map(|name| {
286            let new_name = name.replace("base_model.model.model", "model");
287
288            (name, new_name)
289        })
290    }
291}
292
/// Loader for regular (non X-LoRA) checkpoints; uses the default
/// `LoadTensors` name/key mapping unchanged.
#[derive(new)]
struct Common {}
impl LoadTensors for Common {}
296
/// Loader for X-LoRA adapter checkpoints; tags each adapter tensor key with
/// its adapter index.
#[derive(new)]
struct XLora {
    // Matches the associated path instance for reference in `get_name_key_pairs()`
    adapter_index: usize,
}
302
303impl LoadTensors for XLora {
304    fn get_name_key_pairs(
305        &self,
306        tensors: impl Iterator<Item = String>,
307    ) -> impl Iterator<Item = (String, String)> {
308        let expectation = "tensor name `{new_name}` should have substring `.lora`";
309
310        tensors
311            .filter(|name| !name.contains("internal_xlora_classifier"))
312            .map(|name| {
313                let mut new_name = name.replace("base_model.model.model", "model");
314                // TODO: Add better context to describe intent / requirement:
315                let pos = new_name.find(".lora").expect(expectation);
316                new_name.insert_str(pos + 7, &format!(".{}", self.adapter_index));
317
318                (name, new_name)
319            })
320    }
321}