#[cfg(feature = "cuda")]
mod ffi;
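// `ffi` binds the CUDA kernel entry point. A minimal sketch of the extern
// declaration it is assumed to contain, inferred from the call site in
// `apply_rotary_` below (parameter names are illustrative, not taken from
// the original source):
//
//     extern "C" {
//         pub fn rotary_embedding(
//             query: *const core::ffi::c_void,
//             key: *const core::ffi::c_void,
//             cos_cache: *const core::ffi::c_void,
//             sin_cache: *const core::ffi::c_void,
//             is_neox: core::ffi::c_int,
//             head_size: core::ffi::c_int,
//             num_tokens: core::ffi::c_long,
//             rot_dim: core::ffi::c_int,
//             num_heads: core::ffi::c_int,
//             num_kv_heads: core::ffi::c_int,
//             query_stride: core::ffi::c_long,
//             key_stride: core::ffi::c_long,
//             dtype: u32,
//         );
//     }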

#[cfg(feature = "cuda")]
mod cuda {
    use candle_core::{DType, Result, Storage, Tensor};
    use half::{bf16, f16};
    use std::ffi::{c_int, c_long};

    use crate::utils::slice_ptr;

    fn apply_rotary_<
        T: candle_core::cuda_backend::CudaDType
            + candle_core::cuda_backend::cudarc::driver::DeviceRepr,
    >(
        query: &Tensor,
        key: &Tensor,
        cos_cache: &Tensor,
        sin_cache: &Tensor,
        is_neox: bool,
    ) -> Result<()> {
        let dtype = query.dtype();
        if key.dtype() != dtype || cos_cache.dtype() != dtype || sin_cache.dtype() != dtype {
            candle_core::bail!("apply-rotary expects all tensors to have the same dtype");
        }

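        // Integer dtype tag passed across the FFI boundary; the values are
        // assumed to match the kernel's own dtype dispatch
        // (0 = f16, 1 = bf16, 2 = f32).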
        let internal_type = match dtype {
            DType::F16 => 0,
            DType::BF16 => 1,
            DType::F32 => 2,
            dtype => candle_core::bail!("dtype {dtype:?} is not supported"),
        };

        let (q, q_l) = query.storage_and_layout();
        let q = match &*q {
            Storage::Cuda(q) => q,
            _ => candle_core::bail!("query must be a cuda tensor"),
        };

        let (k, k_l) = key.storage_and_layout();
        let k = match &*k {
            Storage::Cuda(k) => k,
            _ => candle_core::bail!("key must be a cuda tensor"),
        };

        let (cc, cc_l) = cos_cache.storage_and_layout();
        let cc = match &*cc {
            Storage::Cuda(cc) => cc,
            _ => candle_core::bail!("cos_cache must be a cuda tensor"),
        };

        let (sc, sc_l) = sin_cache.storage_and_layout();
        let sc = match &*sc {
            Storage::Cuda(sc) => sc,
            _ => candle_core::bail!("sin_cache must be a cuda tensor"),
        };

        let q_rank = q_l.stride().len();
        let k_rank = k_l.stride().len();
        let cc_rank = cc_l.stride().len();
        let sc_rank = sc_l.stride().len();

        if q_rank != 3 || k_rank != 3 {
            candle_core::bail!(
                "apply-rotary expects input tensors of rank 3 (q: {q_l:?}, k: {k_l:?})"
            )
        }

        if cc_rank != 2 || sc_rank != 2 {
            candle_core::bail!(
                "apply-rotary expects cache tensors of rank 2 (cos: {cc_l:?}, sin: {sc_l:?})"
            )
        }

        let q = q.as_cuda_slice::<T>()?;
        let k = k.as_cuda_slice::<T>()?;
        let cc = cc.as_cuda_slice::<T>()?;
        let sc = sc.as_cuda_slice::<T>()?;

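        // `slice_ptr` (crate::utils) is assumed to resolve a CUDA slice plus a
        // start offset into a raw device pointer together with a guard; the
        // `_*_guard` bindings keep those allocations alive across the FFI
        // call below.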
        let (q, _q_guard) = slice_ptr(q, q_l.start_offset());
        let (k, _k_guard) = slice_ptr(k, k_l.start_offset());
        let (cc, _cc_guard) = slice_ptr(cc, cc_l.start_offset());
        let (sc, _sc_guard) = slice_ptr(sc, sc_l.start_offset());

        let (num_tokens, num_heads, head_size) = q_l.shape().dims3()?;
        let (num_tokens_kv, num_kv_heads, head_size_kv) = k_l.shape().dims3()?;

        if (num_tokens, head_size) != (num_tokens_kv, head_size_kv) {
            candle_core::bail!("shape mismatch q {:?} and k {:?}", q_l.shape(), k_l.shape())
        }

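        // The rotary dimension is read off the cos cache itself, so the two
        // checks below pin both caches to the same (num_tokens, rot_dim)
        // shape rather than validating rot_dim against head_size.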
        let rot_dim = cc_l.dims()[1];
        if (num_tokens, rot_dim) != cc_l.shape().dims2()? {
            candle_core::bail!(
                "shape mismatch cos_cache {:?}, expected {:?}",
                cc_l.shape(),
                (num_tokens, rot_dim)
            )
        }

        if (num_tokens, rot_dim) != sc_l.shape().dims2()? {
            candle_core::bail!(
                "shape mismatch sin_cache {:?}, expected {:?}",
                sc_l.shape(),
                (num_tokens, rot_dim)
            )
        }

        let query_stride = q_l.stride()[0];
        let key_stride = k_l.stride()[0];

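        // Nonzero `neox` is taken to select NeoX-style rotation (the rotary
        // dims are split into two contiguous halves), zero the interleaved
        // GPT-J-style pairing; this mirrors the common vLLM kernel convention.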
        let neox = if is_neox { 1 } else { 0 };

        unsafe {
            super::ffi::rotary_embedding(
                q as *const core::ffi::c_void,
                k as *const core::ffi::c_void,
                cc as *const core::ffi::c_void,
                sc as *const core::ffi::c_void,
                neox,
                head_size as c_int,
                num_tokens as c_long,
                rot_dim as c_int,
                num_heads as c_int,
                num_kv_heads as c_int,
                query_stride as c_long,
                key_stride as c_long,
                internal_type,
            )
        }
        Ok(())
    }

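    /// Applies rotary position embeddings to `query` and `key` in place on
    /// the GPU. Shapes follow the checks in `apply_rotary_`:
    ///
    /// * `query` - `(num_tokens, num_heads, head_size)`.
    /// * `key` - `(num_tokens, num_kv_heads, head_size)`.
    /// * `cos_cache` / `sin_cache` - `(num_tokens, rot_dim)`, same dtype as
    ///   `query` and `key`.
    /// * `is_neox` - selects NeoX-style (block-rotated) rather than
    ///   interleaved rotation.
    ///
    /// Supported dtypes: `f16`, `bf16`, `f32`.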
    pub fn apply_rotary_inplace(
        query: &Tensor,
        key: &Tensor,
        cos_cache: &Tensor,
        sin_cache: &Tensor,
        is_neox: bool,
    ) -> Result<()> {
        match key.dtype() {
            DType::F16 => apply_rotary_::<f16>(query, key, cos_cache, sin_cache, is_neox),
            DType::BF16 => apply_rotary_::<bf16>(query, key, cos_cache, sin_cache, is_neox),
            DType::F32 => apply_rotary_::<f32>(query, key, cos_cache, sin_cache, is_neox),
            dt => {
                candle_core::bail!("apply_rotary is only supported for f32, f16 and bf16 ({dt:?})")
            }
        }
    }
}

#[cfg(feature = "cuda")]
pub use cuda::*;

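/// Fallback used when the crate is built without the `cuda` feature; it
/// always returns an error.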
#[cfg(not(feature = "cuda"))]
pub fn apply_rotary_inplace(
    _query: &candle_core::Tensor,
    _key: &candle_core::Tensor,
    _cos_cache: &candle_core::Tensor,
    _sin_cache: &candle_core::Tensor,
    _is_neox: bool,
) -> candle_core::Result<()> {
    candle_core::bail!("apply_rotary is only supported for cuda");
}
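
// A minimal usage sketch (not part of the original source): shapes follow the
// checks in `apply_rotary_` above, and a CUDA device is assumed to be
// available at index 0.
#[cfg(all(test, feature = "cuda"))]
mod tests {
    use super::apply_rotary_inplace;
    use candle_core::{Device, Tensor};

    #[test]
    fn rotary_smoke() -> candle_core::Result<()> {
        let dev = Device::new_cuda(0)?;
        let (num_tokens, num_heads, head_size) = (4, 8, 64);
        // Full-width rotation: rot_dim == head_size.
        let rot_dim = head_size;
        let query = Tensor::randn(0f32, 1f32, (num_tokens, num_heads, head_size), &dev)?;
        let key = Tensor::randn(0f32, 1f32, (num_tokens, num_heads, head_size), &dev)?;
        // Real callers would gather per-position cos/sin values; random
        // values suffice for a launch/shape smoke test.
        let cos = Tensor::randn(0f32, 1f32, (num_tokens, rot_dim), &dev)?;
        let sin = Tensor::randn(0f32, 1f32, (num_tokens, rot_dim), &dev)?;
        // Mutates `query` and `key` in place on the device.
        apply_rotary_inplace(&query, &key, &cos, &sin, true)
    }
}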