#[cfg(feature = "cuda")]
mod ffi;

#[cfg(feature = "cuda")]
mod cuda {
    use candle_core::cuda_backend::cudarc::driver::DevicePtr;
    use candle_core::{DType, Result, Storage, Tensor};
    use half::{bf16, f16};
    use std::ffi::{c_int, c_long};

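    /// Validates dtypes, ranks, and shapes, then launches the CUDA
    /// rotary-embedding kernel via `super::ffi::rotary_embedding`, rotating
    /// `query` and `key` in place. `T` must be the element type matching the
    /// tensors' dtype (`f16`, `bf16`, or `f32`).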
    fn apply_rotary_<
        T: candle_core::cuda_backend::CudaDType
            + candle_core::cuda_backend::cudarc::driver::DeviceRepr,
    >(
        query: &Tensor,
        key: &Tensor,
        cos_cache: &Tensor,
        sin_cache: &Tensor,
        is_neox: bool,
    ) -> Result<()> {
        let dtype = query.dtype();
        if key.dtype() != dtype || cos_cache.dtype() != dtype || sin_cache.dtype() != dtype {
            candle_core::bail!("apply-rotary expects all tensors to have the same dtype");
        }

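        // Integer dtype tag forwarded through the FFI boundary; the 0/1/2
        // values must stay in sync with the dtype switch on the kernel side.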
        let internal_type = match dtype {
            DType::F16 => 0,
            DType::BF16 => 1,
            DType::F32 => 2,
            dtype => candle_core::bail!("dtype {dtype:?} is not supported"),
        };

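        // Unwrap the CUDA storage (and layout) of each tensor, rejecting
        // anything that does not live on the GPU.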
        let (q, q_l) = query.storage_and_layout();
        let q = match &*q {
            Storage::Cuda(q) => q,
            _ => candle_core::bail!("query must be a cuda tensor"),
        };

        let (k, k_l) = key.storage_and_layout();
        let k = match &*k {
            Storage::Cuda(k) => k,
            _ => candle_core::bail!("key must be a cuda tensor"),
        };

        let (cc, cc_l) = cos_cache.storage_and_layout();
        let cc = match &*cc {
            Storage::Cuda(cc) => cc,
            _ => candle_core::bail!("cos_cache must be a cuda tensor"),
        };

        let (sc, sc_l) = sin_cache.storage_and_layout();
        let sc = match &*sc {
            Storage::Cuda(sc) => sc,
            _ => candle_core::bail!("sin_cache must be a cuda tensor"),
        };

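        // q/k are expected as (num_tokens, heads, head_size); the cos/sin
        // caches as (num_tokens, rot_dim).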
        let q_rank = q_l.stride().len();
        let k_rank = k_l.stride().len();
        let cc_rank = cc_l.stride().len();
        let sc_rank = sc_l.stride().len();

        if q_rank != 3 || k_rank != 3 {
            candle_core::bail!(
                "apply-rotary expects input tensors of rank 3 (q: {q_l:?}, k: {k_l:?})"
            )
        }

        if cc_rank != 2 || sc_rank != 2 {
            candle_core::bail!(
                "apply-rotary expects cache tensors of rank 2 (cos: {cc_l:?}, sin: {sc_l:?})"
            )
        }

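        // Reinterpret each storage as a typed CUDA slice, skipping to the
        // tensor's start offset within its backing buffer.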
        let q = q.as_cuda_slice::<T>()?;
        let k = k.as_cuda_slice::<T>()?;
        let cc = cc.as_cuda_slice::<T>()?;
        let sc = sc.as_cuda_slice::<T>()?;

        let q = q.slice(q_l.start_offset()..);
        let k = k.slice(k_l.start_offset()..);
        let cc = cc.slice(cc_l.start_offset()..);
        let sc = sc.slice(sc_l.start_offset()..);

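        // Token count and head size must agree between q and k; the number of
        // heads may differ (e.g. grouped-query attention).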
        let (num_tokens, num_heads, head_size) = q_l.shape().dims3()?;
        let (num_tokens_kv, num_kv_heads, head_size_kv) = k_l.shape().dims3()?;

        if (num_tokens, head_size) != (num_tokens_kv, head_size_kv) {
            candle_core::bail!("shape mismatch q {:?} and k {:?}", q_l.shape(), k_l.shape())
        }

        let rot_dim = cc_l.dims()[1];
        if (num_tokens, rot_dim) != cc_l.shape().dims2()? {
            candle_core::bail!(
                "shape mismatch cos_cache {:?}, expected {:?}",
                cc_l.shape(),
                (num_tokens, rot_dim)
            )
        }

        if (num_tokens, rot_dim) != sc_l.shape().dims2()? {
            candle_core::bail!(
                "shape mismatch sin_cache {:?}, expected {:?}",
                sc_l.shape(),
                (num_tokens, rot_dim)
            )
        }

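        // Element stride of the token dimension, passed so the kernel can
        // index rows of q and k.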
        let query_stride = q_l.stride()[0];
        let key_stride = k_l.stride()[0];

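        // Raw device pointers for the FFI call; the storage guards above are
        // still alive, so these stay valid for the duration of the launch.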
        let q_ptr = *q.device_ptr() as *const core::ffi::c_void;
        let k_ptr = *k.device_ptr() as *const core::ffi::c_void;
        let cc_ptr = *cc.device_ptr() as *const core::ffi::c_void;
        let sc_ptr = *sc.device_ptr() as *const core::ffi::c_void;

        let neox = if is_neox { 1 } else { 0 };

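        // SAFETY: the pointers reference live CUDA allocations of element type
        // `T`, and the shapes, strides, and dtype tag passed with them were
        // validated above. The kernel mutates q and k in place.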
        unsafe {
            super::ffi::rotary_embedding(
                q_ptr,
                k_ptr,
                cc_ptr,
                sc_ptr,
                neox,
                head_size as c_int,
                num_tokens as c_long,
                rot_dim as c_int,
                num_heads as c_int,
                num_kv_heads as c_int,
                query_stride as c_long,
                key_stride as c_long,
                internal_type,
            )
        }
        Ok(())
    }

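    /// Applies rotary position embeddings to `query` and `key` in place on
    /// the GPU.
    ///
    /// `query` is `(num_tokens, num_heads, head_size)`, `key` is
    /// `(num_tokens, num_kv_heads, head_size)`, and `cos_cache` / `sin_cache`
    /// are `(num_tokens, rot_dim)`; all four must be CUDA tensors of the same
    /// dtype (`f16`, `bf16`, or `f32`). `is_neox` picks between the two common
    /// RoPE layouts and is forwarded to the kernel unchanged.
    ///
    /// A minimal usage sketch (marked `ignore`: it needs a CUDA device, and a
    /// cache width of `head_size / 2` is assumed here as in the usual RoPE
    /// construction; zeros stand in for real frequency tables):
    ///
    /// ```ignore
    /// use candle_core::{DType, Device, Tensor};
    ///
    /// let dev = Device::new_cuda(0)?;
    /// // 4 tokens, 8 query heads, 2 kv heads (grouped-query), head size 64.
    /// let q = Tensor::zeros((4, 8, 64), DType::F16, &dev)?;
    /// let k = Tensor::zeros((4, 2, 64), DType::F16, &dev)?;
    /// let cos = Tensor::zeros((4, 32), DType::F16, &dev)?;
    /// let sin = Tensor::zeros((4, 32), DType::F16, &dev)?;
    /// apply_rotary_inplace(&q, &k, &cos, &sin, true)?;
    /// ```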
    pub fn apply_rotary_inplace(
        query: &Tensor,
        key: &Tensor,
        cos_cache: &Tensor,
        sin_cache: &Tensor,
        is_neox: bool,
    ) -> Result<()> {
        match key.dtype() {
            DType::F16 => apply_rotary_::<f16>(query, key, cos_cache, sin_cache, is_neox),
            DType::BF16 => apply_rotary_::<bf16>(query, key, cos_cache, sin_cache, is_neox),
            DType::F32 => apply_rotary_::<f32>(query, key, cos_cache, sin_cache, is_neox),
            dt => {
                candle_core::bail!("apply_rotary is only supported for f32, f16 and bf16 ({dt:?})")
            }
        }
    }
}

#[cfg(feature = "cuda")]
pub use cuda::*;

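/// Non-CUDA builds have no kernel to call; this stub always returns an error.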
#[cfg(not(feature = "cuda"))]
pub fn apply_rotary_inplace(
    _query: &candle_core::Tensor,
    _key: &candle_core::Tensor,
    _cos_cache: &candle_core::Tensor,
    _sin_cache: &candle_core::Tensor,
    _is_neox: bool,
) -> candle_core::Result<()> {
    candle_core::bail!("apply_rotary is only supported for cuda");
}