mistralrs_quant/utils/
uqff.rs

1use byteorder::{LittleEndian, ReadBytesExt};
2
3use candle_core::{DType, Device, Result, Tensor, WithDType};
4use float8::F8E4M3;
5use half::{bf16, f16};
6
// Format changelog:
// v0.1.0: initial release
// v0.1.1: add i16 dtype
// v0.1.2: add F8E4M3
// v0.1.3: add AFQ

const UQFF_VERSION_MAJOR: u32 = 0;
const UQFF_VERSION_MINOR: u32 = 1;
const UQFF_VERSION_PATCH: u32 = 3;

/// Packed UQFF version, 4 bytes, little endian: [ UNSPECIFIED ] [ MAJOR ] [ MINOR ] [ PATCH ].
///
/// Patch lives in bits 0..8, minor in bits 8..16 and major in bits 16..24.
pub(crate) const UQFF_VERSION: u32 =
    UQFF_VERSION_PATCH | (UQFF_VERSION_MINOR << 8) | (UQFF_VERSION_MAJOR << 16);
/// Byte offset of the quant type: a UQFF artifact always begins with the `u32` version.
pub const UQFF_QUANT_TYPE_OFFSET: usize = (u32::BITS / 8) as usize;
21
22/// Check if major version matches: is backwards compatible
23pub(crate) fn version_is_compatible(version: u32) -> Result<()> {
24    let major = version >> (8 * 2);
25    let _minor = version >> 8;
26    let _patch = version;
27
28    if major != UQFF_VERSION_MAJOR {
29        candle_core::bail!("Major version of ISQ artifact file ({major}) does not match the implementation in this build ({UQFF_VERSION_MAJOR})");
30    }
31
32    Ok(())
33}
34
35// -----------------------
36// Tensor dtype, u32, little endian
37// -----------------------
38pub(crate) fn write_dtype(dtype: DType, buffer: &mut Vec<u8>) {
39    let dtype: u32 = match dtype {
40        DType::U8 => 0,
41        DType::U32 => 1,
42        DType::I32 => 2,
43        DType::I64 => 3,
44        DType::F16 => 4,
45        DType::BF16 => 5,
46        DType::F32 => 6,
47        DType::F64 => 7,
48        DType::I16 => 8,
49        DType::F8E4M3 => 9,
50    };
51    buffer.extend(&dtype.to_le_bytes());
52}
53
54pub(crate) fn read_dtype<R: std::io::Read>(buffer: &mut R) -> Result<DType> {
55    let dtype = buffer.read_u32::<LittleEndian>()?;
56    let dtype = match dtype {
57        0 => DType::U8,
58        1 => DType::U32,
59        2 => DType::I32,
60        3 => DType::I64,
61        4 => DType::F16,
62        5 => DType::BF16,
63        6 => DType::F32,
64        7 => DType::F64,
65        8 => DType::I16,
66        9 => DType::F8E4M3,
67        _ => candle_core::bail!("unknown dtype for quantized tensor {dtype}"),
68    };
69    Ok(dtype)
70}
71
72// -----------------------
73// Tensor data length, u32, little endian
74// -----------------------
75// Tensor dtype, u32, little endian
76// -----------------------
77// Num shape dims, u32, little endian
78// -----------------------
79// ...
80// Array (in original order): shape dims, u32, little endian
81// ...
82// -----------------------
83// ...
84// Array: tensor data, u8s
85// ...
86// -----------------------
87
88pub(crate) fn serialize_tensor(buffer: &mut Vec<u8>, tensor: &Tensor) -> Result<()> {
89    let b_shape = tensor.dims();
90    let tensor = tensor.flatten_all()?;
91
92    let bias = match tensor.dtype() {
93        DType::U8 => data_to_bytes::<u8>(tensor.to_vec1()?),
94        DType::U32 => data_to_bytes::<u32>(tensor.to_vec1()?),
95        DType::I16 => data_to_bytes::<i16>(tensor.to_vec1()?),
96        DType::I32 => data_to_bytes::<i32>(tensor.to_vec1()?),
97        DType::I64 => data_to_bytes::<i64>(tensor.to_vec1()?),
98        DType::F16 => data_to_bytes::<half::f16>(tensor.to_vec1()?),
99        DType::BF16 => data_to_bytes::<half::bf16>(tensor.to_vec1()?),
100        DType::F32 => data_to_bytes::<f32>(tensor.to_vec1()?),
101        DType::F64 => data_to_bytes::<f64>(tensor.to_vec1()?),
102        DType::F8E4M3 => data_to_bytes::<F8E4M3>(tensor.to_vec1()?),
103    };
104    buffer.extend(&(bias.len() as u32).to_le_bytes());
105
106    // DType
107    write_dtype(tensor.dtype(), buffer);
108
109    // Shape
110    buffer.extend((b_shape.len() as u32).to_le_bytes());
111    for dim in b_shape {
112        buffer.extend((*dim as u32).to_le_bytes());
113    }
114
115    buffer.extend(bias);
116
117    Ok(())
118}
119
120pub(crate) fn deserialize_tensor<R: std::io::Read>(
121    buffer: &mut R,
122    device: &Device,
123) -> Result<Tensor> {
124    let data_len = buffer.read_u32::<LittleEndian>()? as usize;
125
126    // DType
127    let dtype = read_dtype(buffer)?;
128
129    let n_dims = buffer.read_u32::<LittleEndian>()? as usize;
130
131    let mut dims = Vec::with_capacity(n_dims);
132    for _ in 0..n_dims {
133        dims.push(buffer.read_u32::<LittleEndian>()? as usize)
134    }
135
136    let mut tensor_data = vec![0; data_len];
137    buffer.read_exact(&mut tensor_data)?;
138
139    match dtype {
140        DType::F16 => bytes_to_data::<f16>(&tensor_data, &dims, device),
141        DType::BF16 => bytes_to_data::<bf16>(&tensor_data, &dims, device),
142        DType::F32 => bytes_to_data::<f32>(&tensor_data, &dims, device),
143        DType::F64 => bytes_to_data::<f64>(&tensor_data, &dims, device),
144        DType::I32 => bytes_to_data::<i32>(&tensor_data, &dims, device),
145        DType::I64 => bytes_to_data::<i64>(&tensor_data, &dims, device),
146        DType::I16 => bytes_to_data::<i16>(&tensor_data, &dims, device),
147        DType::U32 => bytes_to_data::<u32>(&tensor_data, &dims, device),
148        DType::U8 => bytes_to_data::<u8>(&tensor_data, &dims, device),
149        DType::F8E4M3 => bytes_to_data::<F8E4M3>(&tensor_data, &dims, device),
150    }
151}
152
153/// Just seek the reader ahead.
154pub(crate) fn fake_deserialize_tensor<R: std::io::Read + std::io::Seek>(
155    buffer: &mut R,
156) -> Result<()> {
157    let data_len = buffer.read_u32::<LittleEndian>()? as usize;
158
159    // DType
160    let _dtype = read_dtype(buffer)?;
161
162    let n_dims = buffer.read_u32::<LittleEndian>()? as usize;
163
164    let mut dims = Vec::with_capacity(n_dims);
165    for _ in 0..n_dims {
166        dims.push(buffer.read_u32::<LittleEndian>()? as usize)
167    }
168
169    // Fake read the data in bytes
170    buffer.seek_relative(data_len as i64)?;
171
172    Ok(())
173}
174
175fn data_to_bytes<T: WithDType>(mut vs: Vec<T>) -> Vec<u8> {
176    let size_in_bytes = T::DTYPE.size_in_bytes();
177    let length = vs.len() * size_in_bytes;
178    let capacity = vs.capacity() * size_in_bytes;
179    let ptr = vs.as_mut_ptr() as *mut u8;
180    // Don't run the destructor for Vec<T>
181    std::mem::forget(vs);
182    // SAFETY:
183    //
184    // Every T is larger than u8, so there is no issue regarding alignment.
185    // This re-interpret the Vec<T> as a Vec<u8>.
186    unsafe { Vec::from_raw_parts(ptr, length, capacity) }
187}
188
189fn bytes_to_data<T: WithDType>(
190    data: &[u8],
191    shape: &[usize],
192    device: &candle_core::Device,
193) -> Result<Tensor> {
194    let size_in_bytes = T::DTYPE.size_in_bytes();
195    let elem_count = data.len() / size_in_bytes;
196    if (data.as_ptr() as usize) % size_in_bytes == 0 {
197        // SAFETY This is safe because we just checked that this
198        // was correctly aligned.
199        let data: &[T] =
200            unsafe { std::slice::from_raw_parts(data.as_ptr() as *const T, elem_count) };
201        Tensor::from_slice(data, shape, device)
202    } else {
203        // XXX: We need to specify `T` here, otherwise the compiler will infer u8 because of the following cast
204        // Making this vector too small to fit a full f16/f32/f64 weights, resulting in out-of-bounds access
205        let mut c: Vec<T> = Vec::with_capacity(elem_count);
206        // SAFETY: We just created c, so the allocated memory is necessarily
207        // contiguous and non overlapping with the view's data.
208        // We're downgrading the `c` pointer from T to u8, which removes alignment
209        // constraints.
210        unsafe {
211            std::ptr::copy_nonoverlapping(data.as_ptr(), c.as_mut_ptr() as *mut u8, data.len());
212            c.set_len(elem_count)
213        }
214        Tensor::from_slice(&c, shape, device)
215    }
216}