// mistralrs_quant/utils/uqff.rs

1use byteorder::{LittleEndian, ReadBytesExt};
2
3use candle_core::{DType, Device, Result, Tensor, WithDType};
4use float8::F8E4M3;
5use half::{bf16, f16};
6
// UQFF on-disk format version history:
// v0.1.0: initial release
// v0.1.1: add i16 dtype
// v0.1.2: add F8E4M3
// v0.1.3: add AFQ
// v0.2.0: add f4/f6e3m2/f6e2m3/f8e8m0 type handling

// Version implemented by this build; bump according to the history above.
const UQFF_VERSION_MAJOR: u32 = 0;
const UQFF_VERSION_MINOR: u32 = 2;
const UQFF_VERSION_PATCH: u32 = 0;

/// Format 4 bytes, little endian: [ UNSPECIFIED ] [ MAJOR ] [ MINOR ] [ PATCH ]
/// (one byte per field; the top byte is currently unused).
pub(crate) const UQFF_VERSION: u32 =
    (UQFF_VERSION_MAJOR << (8 * 2)) | (UQFF_VERSION_MINOR << 8) | UQFF_VERSION_PATCH;
/// Offset for the quant type. UQFF always serializes the version first,
/// so the quant-type tag begins right after the 4-byte version word.
pub const UQFF_QUANT_TYPE_OFFSET: usize = std::mem::size_of::<u32>();
22
23/// Check if major version matches: is backwards compatible
24pub(crate) fn version_is_compatible(version: u32) -> Result<()> {
25    let major = version >> (8 * 2);
26    let minor = (version >> 8) & 0xFF;
27    let patch = version & 0xFF;
28
29    if major != UQFF_VERSION_MAJOR {
30        candle_core::bail!("Major version of ISQ artifact file ({major}) does not match the implementation in this build ({UQFF_VERSION_MAJOR})");
31    }
32
33    // Check minor version for forward compatibility
34    if minor > UQFF_VERSION_MINOR {
35        candle_core::bail!("Minor version of ISQ artifact file ({major}.{minor}.{patch}) is newer than this build supports ({UQFF_VERSION_MAJOR}.{UQFF_VERSION_MINOR}.{UQFF_VERSION_PATCH}). Please update mistral.rs.");
36    }
37
38    Ok(())
39}
40
41// -----------------------
42// Tensor dtype, u32, little endian
43// -----------------------
44pub(crate) fn write_dtype(dtype: DType, buffer: &mut Vec<u8>) {
45    let dtype: u32 = match dtype {
46        DType::U8 => 0,
47        DType::U32 => 1,
48        DType::I32 => 2,
49        DType::I64 => 3,
50        DType::F16 => 4,
51        DType::BF16 => 5,
52        DType::F32 => 6,
53        DType::F64 => 7,
54        DType::I16 => 8,
55        DType::F8E4M3 => 9,
56        DType::F6E2M3 => 10,
57        DType::F6E3M2 => 11,
58        DType::F4 => 12,
59        DType::F8E8M0 => 13,
60    };
61    buffer.extend(&dtype.to_le_bytes());
62}
63
64pub(crate) fn read_dtype<R: std::io::Read>(buffer: &mut R) -> Result<DType> {
65    let dtype = buffer.read_u32::<LittleEndian>()?;
66    let dtype = match dtype {
67        0 => DType::U8,
68        1 => DType::U32,
69        2 => DType::I32,
70        3 => DType::I64,
71        4 => DType::F16,
72        5 => DType::BF16,
73        6 => DType::F32,
74        7 => DType::F64,
75        8 => DType::I16,
76        9 => DType::F8E4M3,
77        10 => DType::F6E2M3,
78        11 => DType::F6E3M2,
79        12 => DType::F4,
80        13 => DType::F8E8M0,
81        _ => candle_core::bail!("unknown dtype for quantized tensor {dtype}"),
82    };
83    Ok(dtype)
84}
85
86// -----------------------
87// Tensor data length, u32, little endian
88// -----------------------
89// Tensor dtype, u32, little endian
90// -----------------------
91// Num shape dims, u32, little endian
92// -----------------------
93// ...
94// Array (in original order): shape dims, u32, little endian
95// ...
96// -----------------------
97// ...
98// Array: tensor data, u8s
99// ...
100// -----------------------
101
102pub(crate) fn serialize_tensor(buffer: &mut Vec<u8>, tensor: &Tensor) -> Result<()> {
103    let b_shape = tensor.dims();
104    let tensor = tensor.flatten_all()?;
105
106    let bias = match tensor.dtype() {
107        DType::U8 => data_to_bytes::<u8>(tensor.to_vec1()?),
108        DType::U32 => data_to_bytes::<u32>(tensor.to_vec1()?),
109        DType::I16 => data_to_bytes::<i16>(tensor.to_vec1()?),
110        DType::I32 => data_to_bytes::<i32>(tensor.to_vec1()?),
111        DType::I64 => data_to_bytes::<i64>(tensor.to_vec1()?),
112        DType::F16 => data_to_bytes::<half::f16>(tensor.to_vec1()?),
113        DType::BF16 => data_to_bytes::<half::bf16>(tensor.to_vec1()?),
114        DType::F32 => data_to_bytes::<f32>(tensor.to_vec1()?),
115        DType::F64 => data_to_bytes::<f64>(tensor.to_vec1()?),
116        DType::F8E4M3 => data_to_bytes::<F8E4M3>(tensor.to_vec1()?),
117        DType::F4 | DType::F6E3M2 | DType::F6E2M3 | DType::F8E8M0 => {
118            candle_core::bail!("f4/f6e3m2/f6e2m3/f8e8m0 tensors cannot be serialized.")
119        }
120    };
121
122    // Check for potential overflow when converting usize to u32
123    let data_len = bias.len();
124    if data_len > u32::MAX as usize {
125        candle_core::bail!(
126            "Tensor data too large for UQFF format: {} bytes exceeds u32::MAX",
127            data_len
128        );
129    }
130    buffer.extend(&(data_len as u32).to_le_bytes());
131
132    // DType
133    write_dtype(tensor.dtype(), buffer);
134
135    // Shape
136    let shape_len = b_shape.len();
137    if shape_len > u32::MAX as usize {
138        candle_core::bail!(
139            "Tensor has too many dimensions for UQFF format: {} exceeds u32::MAX",
140            shape_len
141        );
142    }
143    buffer.extend((shape_len as u32).to_le_bytes());
144    for dim in b_shape {
145        if *dim > u32::MAX as usize {
146            candle_core::bail!(
147                "Tensor dimension too large for UQFF format: {} exceeds u32::MAX",
148                dim
149            );
150        }
151        buffer.extend((*dim as u32).to_le_bytes());
152    }
153
154    buffer.extend(bias);
155
156    Ok(())
157}
158
159pub(crate) fn deserialize_tensor<R: std::io::Read>(
160    buffer: &mut R,
161    device: &Device,
162) -> Result<Tensor> {
163    let data_len = buffer.read_u32::<LittleEndian>()? as usize;
164
165    // DType
166    let dtype = read_dtype(buffer)?;
167
168    let n_dims = buffer.read_u32::<LittleEndian>()? as usize;
169
170    let mut dims = Vec::with_capacity(n_dims);
171    for _ in 0..n_dims {
172        dims.push(buffer.read_u32::<LittleEndian>()? as usize)
173    }
174
175    let mut tensor_data = vec![0; data_len];
176    buffer.read_exact(&mut tensor_data)?;
177
178    match dtype {
179        DType::F16 => bytes_to_data::<f16>(&tensor_data, &dims, device),
180        DType::BF16 => bytes_to_data::<bf16>(&tensor_data, &dims, device),
181        DType::F32 => bytes_to_data::<f32>(&tensor_data, &dims, device),
182        DType::F64 => bytes_to_data::<f64>(&tensor_data, &dims, device),
183        DType::I32 => bytes_to_data::<i32>(&tensor_data, &dims, device),
184        DType::I64 => bytes_to_data::<i64>(&tensor_data, &dims, device),
185        DType::I16 => bytes_to_data::<i16>(&tensor_data, &dims, device),
186        DType::U32 => bytes_to_data::<u32>(&tensor_data, &dims, device),
187        DType::U8 => bytes_to_data::<u8>(&tensor_data, &dims, device),
188        DType::F8E4M3 => bytes_to_data::<F8E4M3>(&tensor_data, &dims, device),
189        DType::F4 | DType::F6E3M2 | DType::F6E2M3 | DType::F8E8M0 => {
190            candle_core::bail!("f4/f6e3m2/f6e2m3/f8e8m0 tensors cannot be deserialized.")
191        }
192    }
193}
194
195/// Just seek the reader ahead.
196pub(crate) fn fake_deserialize_tensor<R: std::io::Read + std::io::Seek>(
197    buffer: &mut R,
198) -> Result<()> {
199    let data_len = buffer.read_u32::<LittleEndian>()? as usize;
200
201    // DType
202    let _dtype = read_dtype(buffer)?;
203
204    let n_dims = buffer.read_u32::<LittleEndian>()? as usize;
205
206    let mut dims = Vec::with_capacity(n_dims);
207    for _ in 0..n_dims {
208        dims.push(buffer.read_u32::<LittleEndian>()? as usize)
209    }
210
211    // Fake read the data in bytes
212    buffer.seek_relative(data_len as i64)?;
213
214    Ok(())
215}
216
217fn data_to_bytes<T: WithDType>(mut vs: Vec<T>) -> Vec<u8> {
218    let size_in_bytes = T::DTYPE.size_in_bytes();
219    let length = vs.len() * size_in_bytes;
220    let capacity = vs.capacity() * size_in_bytes;
221    let ptr = vs.as_mut_ptr() as *mut u8;
222    // Don't run the destructor for Vec<T>
223    std::mem::forget(vs);
224    // SAFETY:
225    //
226    // Every T is larger than u8, so there is no issue regarding alignment.
227    // This re-interpret the Vec<T> as a Vec<u8>.
228    unsafe { Vec::from_raw_parts(ptr, length, capacity) }
229}
230
231fn bytes_to_data<T: WithDType>(
232    data: &[u8],
233    shape: &[usize],
234    device: &candle_core::Device,
235) -> Result<Tensor> {
236    let size_in_bytes = T::DTYPE.size_in_bytes();
237    let elem_count = data.len() / size_in_bytes;
238    if (data.as_ptr() as usize) % size_in_bytes == 0 {
239        // SAFETY This is safe because we just checked that this
240        // was correctly aligned.
241        let data: &[T] =
242            unsafe { std::slice::from_raw_parts(data.as_ptr() as *const T, elem_count) };
243        Tensor::from_slice(data, shape, device)
244    } else {
245        // XXX: We need to specify `T` here, otherwise the compiler will infer u8 because of the following cast
246        // Making this vector too small to fit a full f16/f32/f64 weights, resulting in out-of-bounds access
247        let mut c: Vec<T> = Vec::with_capacity(elem_count);
248        // SAFETY: We just created c, so the allocated memory is necessarily
249        // contiguous and non overlapping with the view's data.
250        // We're downgrading the `c` pointer from T to u8, which removes alignment
251        // constraints.
252        unsafe {
253            std::ptr::copy_nonoverlapping(data.as_ptr(), c.as_mut_ptr() as *mut u8, data.len());
254            c.set_len(elem_count)
255        }
256        Tensor::from_slice(&c, shape, device)
257    }
258}