mistralrs_quant/utils/uqff.rs

use byteorder::{LittleEndian, ReadBytesExt};

use candle_core::{DType, Device, Result, Tensor, WithDType};
use float8::F8E4M3;
use half::{bf16, f16};

// v0.1.0: initial release
// v0.1.1: add i16 dtype
// v0.1.2: add F8E4M3

const HQFF_VERSION_MAJOR: u32 = 0;
const HQFF_VERSION_MINOR: u32 = 1;
const HQFF_VERSION_PATCH: u32 = 2;

/// Format 4 bytes, little endian: [ UNSPECIFIED ] [ MAJOR ] [ MINOR ] [ PATCH ]
pub(crate) const HQFF_VERSION: u32 =
    (HQFF_VERSION_MAJOR << (8 * 2)) | (HQFF_VERSION_MINOR << 8) | HQFF_VERSION_PATCH;

/// Check that the major version matches; differing minor and patch versions are backwards compatible.
pub(crate) fn version_is_compatible(version: u32) -> Result<()> {
    let major = (version >> (8 * 2)) & 0xff;
    let _minor = (version >> 8) & 0xff;
    let _patch = version & 0xff;

    if major != HQFF_VERSION_MAJOR {
        candle_core::bail!("Major version of ISQ artifact file ({major}) does not match the implementation in this build ({HQFF_VERSION_MAJOR})");
    }

    Ok(())
}
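
// A minimal sketch (test-only, names illustrative) of how the packed version
// constant and the compatibility check are expected to behave; the concrete
// byte values assume the current 0.1.2 constants above.
#[cfg(test)]
mod version_tests {
    use super::*;

    #[test]
    fn packed_version_layout() {
        // [ UNSPECIFIED ] [ MAJOR ] [ MINOR ] [ PATCH ]: 0.1.2 packs to 0x00000102.
        assert_eq!(HQFF_VERSION, 0x0000_0102);
    }

    #[test]
    fn only_major_version_must_match() {
        // Any minor/patch under the same major version should be accepted...
        assert!(version_is_compatible((HQFF_VERSION_MAJOR << 16) | 0xff).is_ok());
        // ...while a different major version must be rejected.
        assert!(version_is_compatible(1u32 << 16).is_err());
    }
}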

// -----------------------
// Tensor dtype, u32, little endian
// -----------------------
pub(crate) fn write_dtype(dtype: DType, buffer: &mut Vec<u8>) {
    let dtype: u32 = match dtype {
        DType::U8 => 0,
        DType::U32 => 1,
        DType::I32 => 2,
        DType::I64 => 3,
        DType::F16 => 4,
        DType::BF16 => 5,
        DType::F32 => 6,
        DType::F64 => 7,
        DType::I16 => 8,
        DType::F8E4M3 => 9,
    };
    buffer.extend(&dtype.to_le_bytes());
}

pub(crate) fn read_dtype<R: std::io::Read>(buffer: &mut R) -> Result<DType> {
    let dtype = buffer.read_u32::<LittleEndian>()?;
    let dtype = match dtype {
        0 => DType::U8,
        1 => DType::U32,
        2 => DType::I32,
        3 => DType::I64,
        4 => DType::F16,
        5 => DType::BF16,
        6 => DType::F32,
        7 => DType::F64,
        8 => DType::I16,
        9 => DType::F8E4M3,
        _ => candle_core::bail!("unknown dtype for quantized tensor {dtype}"),
    };
    Ok(dtype)
}
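
// A hedged round-trip sketch (test-only, names illustrative): every dtype tag
// written by write_dtype should be read back identically by read_dtype, and
// unassigned tags should error rather than panic.
#[cfg(test)]
mod dtype_roundtrip_tests {
    use super::*;
    use std::io::Cursor;

    #[test]
    fn dtype_roundtrip() {
        for dtype in [
            DType::U8,
            DType::U32,
            DType::I16,
            DType::I32,
            DType::I64,
            DType::F16,
            DType::BF16,
            DType::F32,
            DType::F64,
            DType::F8E4M3,
        ] {
            let mut buffer = Vec::new();
            write_dtype(dtype, &mut buffer);
            let read_back = read_dtype(&mut Cursor::new(buffer)).unwrap();
            assert_eq!(read_back, dtype);
        }
    }

    #[test]
    fn unknown_dtype_tag_is_rejected() {
        // Tag 10 and above are unassigned as of v0.1.2.
        let mut cursor = Cursor::new(10u32.to_le_bytes().to_vec());
        assert!(read_dtype(&mut cursor).is_err());
    }
}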

// -----------------------
// Tensor data length, u32, little endian
// -----------------------
// Tensor dtype, u32, little endian
// -----------------------
// Num shape dims, u32, little endian
// -----------------------
// ...
// Array (in original order): shape dims, u32, little endian
// ...
// -----------------------
// ...
// Array: tensor data, u8s
// ...
// -----------------------
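
// Worked example (illustrative only, derived from the layout above): a 2x3 F32
// tensor would serialize as
//   [ 24u32 LE ]            data length: 6 elements * 4 bytes each
//   [ 6u32 LE ]             dtype tag (F32, per write_dtype above)
//   [ 2u32 LE ]             number of shape dims
//   [ 2u32 LE ] [ 3u32 LE ] shape dims, in original order
//   [ 24 bytes ]            raw element bytes (native byte order; little endian on typical targets)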

pub(crate) fn serialize_tensor(buffer: &mut Vec<u8>, tensor: &Tensor) -> Result<()> {
    let shape = tensor.dims();
    let tensor = tensor.flatten_all()?;

    let data = match tensor.dtype() {
        DType::U8 => data_to_bytes::<u8>(tensor.to_vec1()?),
        DType::U32 => data_to_bytes::<u32>(tensor.to_vec1()?),
        DType::I16 => data_to_bytes::<i16>(tensor.to_vec1()?),
        DType::I32 => data_to_bytes::<i32>(tensor.to_vec1()?),
        DType::I64 => data_to_bytes::<i64>(tensor.to_vec1()?),
        DType::F16 => data_to_bytes::<f16>(tensor.to_vec1()?),
        DType::BF16 => data_to_bytes::<bf16>(tensor.to_vec1()?),
        DType::F32 => data_to_bytes::<f32>(tensor.to_vec1()?),
        DType::F64 => data_to_bytes::<f64>(tensor.to_vec1()?),
        DType::F8E4M3 => data_to_bytes::<F8E4M3>(tensor.to_vec1()?),
    };
    buffer.extend(&(data.len() as u32).to_le_bytes());

    // DType
    write_dtype(tensor.dtype(), buffer);

    // Shape
    buffer.extend((shape.len() as u32).to_le_bytes());
    for dim in shape {
        buffer.extend((*dim as u32).to_le_bytes());
    }

    buffer.extend(data);

    Ok(())
}

pub(crate) fn deserialize_tensor<R: std::io::Read>(
    buffer: &mut R,
    device: &Device,
) -> Result<Tensor> {
    let data_len = buffer.read_u32::<LittleEndian>()? as usize;

    // DType
    let dtype = read_dtype(buffer)?;

    let n_dims = buffer.read_u32::<LittleEndian>()? as usize;

    let mut dims = Vec::with_capacity(n_dims);
    for _ in 0..n_dims {
        dims.push(buffer.read_u32::<LittleEndian>()? as usize)
    }

    let mut tensor_data = vec![0; data_len];
    buffer.read_exact(&mut tensor_data)?;

    match dtype {
        DType::F16 => bytes_to_data::<f16>(&tensor_data, &dims, device),
        DType::BF16 => bytes_to_data::<bf16>(&tensor_data, &dims, device),
        DType::F32 => bytes_to_data::<f32>(&tensor_data, &dims, device),
        DType::F64 => bytes_to_data::<f64>(&tensor_data, &dims, device),
        DType::I32 => bytes_to_data::<i32>(&tensor_data, &dims, device),
        DType::I64 => bytes_to_data::<i64>(&tensor_data, &dims, device),
        DType::I16 => bytes_to_data::<i16>(&tensor_data, &dims, device),
        DType::U32 => bytes_to_data::<u32>(&tensor_data, &dims, device),
        DType::U8 => bytes_to_data::<u8>(&tensor_data, &dims, device),
        DType::F8E4M3 => bytes_to_data::<F8E4M3>(&tensor_data, &dims, device),
    }
}
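
// A minimal round-trip sketch (test-only, names illustrative) over
// serialize_tensor/deserialize_tensor on the CPU device; the dtype and shape
// chosen here are arbitrary.
#[cfg(test)]
mod tensor_roundtrip_tests {
    use super::*;

    #[test]
    fn tensor_roundtrip() -> Result<()> {
        let device = Device::Cpu;
        let tensor = Tensor::from_slice(&[1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0], (2, 3), &device)?;

        let mut buffer = Vec::new();
        serialize_tensor(&mut buffer, &tensor)?;

        let read_back = deserialize_tensor(&mut std::io::Cursor::new(buffer), &device)?;
        assert_eq!(read_back.dims(), tensor.dims());
        assert_eq!(read_back.dtype(), DType::F32);
        assert_eq!(read_back.to_vec2::<f32>()?, tensor.to_vec2::<f32>()?);
        Ok(())
    }
}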

fn data_to_bytes<T: WithDType>(mut vs: Vec<T>) -> Vec<u8> {
    let size_in_bytes = T::DTYPE.size_in_bytes();
    let length = vs.len() * size_in_bytes;
    let capacity = vs.capacity() * size_in_bytes;
    let ptr = vs.as_mut_ptr() as *mut u8;
    // Don't run the destructor for Vec<T>
    std::mem::forget(vs);
    // SAFETY:
    //
    // `u8` has an alignment of 1, so the allocation behind `vs` is always
    // suitably aligned for it. This reinterprets the Vec<T> as a Vec<u8> without copying.
    unsafe { Vec::from_raw_parts(ptr, length, capacity) }
}

fn bytes_to_data<T: WithDType>(
    data: &[u8],
    shape: &[usize],
    device: &candle_core::Device,
) -> Result<Tensor> {
    let size_in_bytes = T::DTYPE.size_in_bytes();
    let elem_count = data.len() / size_in_bytes;
    if (data.as_ptr() as usize) % size_in_bytes == 0 {
        // SAFETY This is safe because we just checked that this
        // was correctly aligned.
        let data: &[T] =
            unsafe { std::slice::from_raw_parts(data.as_ptr() as *const T, elem_count) };
        Tensor::from_slice(data, shape, device)
    } else {
        // XXX: We need to specify `T` here, otherwise the compiler would infer u8 because of the
        // following cast, making this vector too small to fit the full f16/f32/f64 weights and
        // resulting in out-of-bounds access.
        let mut c: Vec<T> = Vec::with_capacity(elem_count);
        // SAFETY: We just created c, so the allocated memory is necessarily
        // contiguous and non overlapping with the view's data.
        // We're downgrading the `c` pointer from T to u8, which removes alignment
        // constraints.
        unsafe {
            std::ptr::copy_nonoverlapping(data.as_ptr(), c.as_mut_ptr() as *mut u8, data.len());
            c.set_len(elem_count)
        }
        Tensor::from_slice(&c, shape, device)
    }
}
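
// A sketched check (test-only, names illustrative) that the unsafe
// reinterpretation helpers invert each other: data_to_bytes followed by
// bytes_to_data should reproduce the original values.
#[cfg(test)]
mod byte_conversion_tests {
    use super::*;

    #[test]
    fn f16_bytes_roundtrip() -> Result<()> {
        let values = vec![f16::from_f32(0.5), f16::from_f32(-1.25), f16::from_f32(3.0)];
        let bytes = data_to_bytes::<f16>(values.clone());
        // 3 elements * 2 bytes per f16.
        assert_eq!(bytes.len(), 6);

        let tensor = bytes_to_data::<f16>(&bytes, &[3], &Device::Cpu)?;
        assert_eq!(tensor.to_vec1::<f16>()?, values);
        Ok(())
    }
}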