use std::{
    collections::HashMap,
    path::PathBuf,
    sync::Arc,
    thread::{self, JoinHandle},
};

use candle_core::{pickle::PthTensors, DType, Device, Result, Tensor};
use mistralrs_quant::{safetensors::MmapedSafetensors, ShardedSafeTensors, ShardedVarBuilder};
use regex::Regex;

use crate::lora::LoraConfig;
use crate::utils::progress::IterWithProgress;
use derive_new::new;

const MISTRALRS_NO_MMAP: &str = "MISTRALRS_NO_MMAP";

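/// Common interface over the tensor file formats this module can read:
/// memory-mapped safetensors files and PyTorch pickle (`.pth`/`.pt`/`.bin`) checkpoints.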
trait TensorLoaderBackend {
    fn get_names(&self) -> Vec<String>;
    fn load_name(&self, name: &str, device: &Device, dtype: Option<DType>) -> Result<Tensor>;
}

struct SafetensorBackend(MmapedSafetensors);

impl TensorLoaderBackend for SafetensorBackend {
    fn get_names(&self) -> Vec<String> {
        self.0
            .tensors()
            .into_iter()
            .map(|(name, _)| name)
            .collect::<Vec<_>>()
    }
    fn load_name(&self, name: &str, device: &Device, dtype: Option<DType>) -> Result<Tensor> {
        self.0.load(name, device, dtype)
    }
}

struct PickleBackend(PthTensors);

impl TensorLoaderBackend for PickleBackend {
    fn get_names(&self) -> Vec<String> {
        self.0.tensor_infos().keys().cloned().collect::<Vec<_>>()
    }
    fn load_name(&self, name: &str, device: &Device, _dtype: Option<DType>) -> Result<Tensor> {
        self.0
            .get(name)?
            .ok_or(candle_core::Error::Msg(format!(
                "Could not load tensor {name}"
            )))?
            .to_device(device)
    }
}

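/// Which device a tensor should be loaded onto: the base model device, or the device
/// assigned to a particular layer index (falling back to the base device).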
pub enum DeviceForLoadTensor {
    Base,
    Idx(usize),
}

#[allow(clippy::too_many_arguments)]
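/// Load tensors from `paths` (plus any X-LoRA adapter `xlora_paths`) into a
/// `ShardedVarBuilder`.
///
/// If there are no X-LoRA paths and `MISTRALRS_NO_MMAP` is not set to `1`, the
/// safetensors files are memory-mapped directly. Otherwise each file is loaded on its
/// own thread and the tensors are collected into an in-memory builder. Tensor names
/// must satisfy `predicate` to be included; `make_dummy_regexes` marks names whose
/// tensors should not actually be loaded.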
pub(crate) fn from_mmaped_safetensors(
    paths: Vec<PathBuf>,
    xlora_paths: Vec<PathBuf>,
    dtype: Option<DType>,
    base_device: &Device,
    layer_devices: Vec<Option<Device>>,
    silent: bool,
    make_dummy_regexes: Option<Arc<Vec<Regex>>>,
    predicate: impl Fn(String) -> bool + Send + Sync + Clone + 'static,
    get_device_for_tensor: Arc<dyn Fn(String) -> DeviceForLoadTensor + Send + Sync + 'static>,
) -> Result<ShardedVarBuilder> {
    let use_no_mmap = std::env::var(MISTRALRS_NO_MMAP).is_ok_and(|x| x == "1");
    if xlora_paths.is_empty() && !use_no_mmap {
        if !silent {
            tracing::info!("Loading model using mmap strategy.");
        }
        return Ok(unsafe {
            ShardedSafeTensors::sharded(
                &paths,
                dtype.unwrap_or(DType::F16),
                base_device,
                make_dummy_regexes,
                Arc::new(predicate),
            )?
        });
    }

    #[allow(clippy::type_complexity)]
    let mut handles: Vec<JoinHandle<Result<HashMap<String, Tensor>>>> = Vec::new();

    for path in paths {
        let base_device = base_device.clone();
        let layer_devices = layer_devices.clone();
        let get_device_for_tensor = get_device_for_tensor.clone();
        if let Some(regexes) = make_dummy_regexes.clone() {
            let predicate = predicate.clone();
            handles.push(thread::spawn(Box::new(move || {
                let loader = Common::new();
                loader.load_tensors_from_path(
                    &path,
                    &base_device,
                    layer_devices,
                    get_device_for_tensor,
                    dtype,
                    silent,
                    predicate,
                    |key| regexes.iter().any(|r| r.is_match(key)),
                )
            })));
        } else {
            let predicate = predicate.clone();
            handles.push(thread::spawn(Box::new(move || {
                let loader = Common::new();
                loader.load_tensors_from_path(
                    &path,
                    &base_device,
                    layer_devices,
                    get_device_for_tensor,
                    dtype,
                    silent,
                    predicate,
                    |_| false,
                )
            })));
        }
    }
    for (i, path) in xlora_paths.into_iter().enumerate() {
        let base_device = base_device.clone();
        let layer_devices = layer_devices.clone();
        let get_device_for_tensor = get_device_for_tensor.clone();
        if let Some(regexes) = make_dummy_regexes.clone() {
            let predicate = predicate.clone();
            handles.push(thread::spawn(Box::new(move || {
                let loader = XLora::new(i + 1);
                loader.load_tensors_from_path(
                    &path,
                    &base_device,
                    layer_devices,
                    get_device_for_tensor,
                    dtype,
                    silent,
                    predicate,
                    |key| regexes.iter().any(|r| r.is_match(key)),
                )
            })));
        } else {
            let predicate = predicate.clone();
            handles.push(thread::spawn(Box::new(move || {
                let loader = XLora::new(i + 1);
                loader.load_tensors_from_path(
                    &path,
                    &base_device,
                    layer_devices,
                    get_device_for_tensor,
                    dtype,
                    silent,
                    predicate,
                    |_| false,
                )
            })));
        }
    }

    let mut ws = HashMap::new();
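    // Wait for every loader thread to finish before collecting its results.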
    while !handles.iter().all(|h| h.is_finished()) {}
    for h in handles {
        ws.extend(h.join().unwrap()?);
    }

    let backend = Box::new(ws);

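    // Wrap the collected tensors in an in-memory builder; F16 is used when no dtype
    // override was provided.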
    Ok(ShardedSafeTensors::wrap(
        backend,
        dtype.unwrap_or(DType::F16),
        base_device.clone(),
    ))
}

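/// Load any preloaded LoRA adapters, producing a `ShardedVarBuilder` and `LoraConfig`
/// for each named adapter. Returns `Ok(None)` when no adapter paths are given.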
pub(crate) fn load_preload_adapters(
    paths: &Option<HashMap<String, (PathBuf, LoraConfig)>>,
    dtype: DType,
    device: &Device,
    silent: bool,
) -> Result<Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>> {
    if let Some(paths) = paths {
        let mut map = HashMap::new();
        for (name, (path, config)) in paths {
            let loader = Common::new();
            let loaded_tensors = loader.load_tensors_from_path(
                path,
                device,
                vec![None],
                Arc::new(|_| DeviceForLoadTensor::Base),
                Some(dtype),
                silent,
                |_| true,
                |_| false,
            )?;

            let backend = Box::new(loaded_tensors);

            let vb = ShardedSafeTensors::wrap(backend, dtype, device.clone());

            map.insert(name.clone(), (vb, config.clone()));
        }
        Ok(Some(map))
    } else {
        Ok(None)
    }
}

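/// Shared tensor-loading routine. Implementors customize only how on-disk tensor
/// names are mapped to the keys of the returned map.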
trait LoadTensors {
    #[allow(clippy::too_many_arguments)]
    fn load_tensors_from_path(
        &self,
        path: &PathBuf,
        base_device: &Device,
        layer_devices: Vec<Option<Device>>,
        get_device_for_tensor: Arc<dyn Fn(String) -> DeviceForLoadTensor + Send + Sync + 'static>,
        dtype: Option<DType>,
        is_silent: bool,
        predicate: impl Fn(String) -> bool,
        make_dummy_predicate: impl Fn(&str) -> bool,
    ) -> Result<HashMap<String, Tensor>> {
        let tensors: Box<dyn TensorLoaderBackend> = match path
            .extension()
            .expect("Expected extension")
            .to_str()
            .expect("Expected to convert")
        {
            "safetensors" => Box::new(SafetensorBackend(unsafe {
                MmapedSafetensors::new(path)?
            })),
            "pth" | "pt" | "bin" => Box::new(PickleBackend(
                candle_core::pickle::PthTensors::new(path, None)?
            )),
            other => candle_core::bail!("Unexpected extension `{other}`, this should have been handled by `get_model_paths`."),
        };

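        // Enumerate the tensor names in the file, keep only those accepted by
        // `predicate`, and map each one to the key it will be stored under.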
        let names_only = tensors
            .get_names()
            .into_iter()
            .filter(|x| predicate(x.to_string()));
        let iter = self.get_name_key_pairs(names_only).collect::<Vec<_>>();

        let mut loaded_tensors = HashMap::new();
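        // Load each selected tensor onto its target device, skipping any name the
        // dummy predicate matches.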
        if !iter.is_empty() {
            for (load_name, key_name) in iter.into_iter().with_progress(is_silent) {
                if !make_dummy_predicate(&load_name) {
                    let dev = match get_device_for_tensor(load_name.clone()) {
                        DeviceForLoadTensor::Base => base_device,
                        DeviceForLoadTensor::Idx(i) => layer_devices
                            .get(i)
                            .and_then(|d| d.as_ref())
                            .unwrap_or(base_device),
                    };
                    let tensor = tensors.load_name(&load_name, dev, dtype)?;

                    loaded_tensors.insert(key_name, tensor);
                }
            }
        }

        Ok(loaded_tensors)
    }

    fn get_name_key_pairs(
        &self,
        tensors: impl Iterator<Item = String>,
    ) -> impl Iterator<Item = (String, String)> {
        tensors.map(|name| {
            let new_name = name.replace("base_model.model.model", "model");

            (name, new_name)
        })
    }
}

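/// Loader for base model (and plain adapter) weights: tensor names are kept as-is,
/// apart from rewriting `base_model.model.model` to `model`.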
#[derive(new)]
struct Common {}
impl LoadTensors for Common {}

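/// Loader for X-LoRA adapter weights: the adapter index (starting at 1) is appended
/// to each `.lora_A`/`.lora_B` tensor name, and `internal_xlora_classifier` tensors
/// are filtered out.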
#[derive(new)]
struct XLora {
    // 1-based index of this adapter, as assigned in `from_mmaped_safetensors`.
    adapter_index: usize,
}

impl LoadTensors for XLora {
    fn get_name_key_pairs(
        &self,
        tensors: impl Iterator<Item = String>,
    ) -> impl Iterator<Item = (String, String)> {
        let expectation = "tensor name `{new_name}` should have substring `.lora`";

        tensors
            .filter(|name| !name.contains("internal_xlora_classifier"))
            .map(|name| {
                let mut new_name = name.replace("base_model.model.model", "model");
                let pos = new_name.find(".lora").expect(expectation);
                // Insert the adapter index right after `.lora_A`/`.lora_B`,
                // e.g. `...q_proj.lora_A.1.weight`.
                new_name.insert_str(pos + 7, &format!(".{}", self.adapter_index));

                (name, new_name)
            })
    }
}