pub fn compute_kv_shard(total_num_kv_heads: usize, head_dim: usize, comm: &Comm) -> Shard
Computes the appropriate KV shard, handling KV head replication. Always use it together with compute_n_kv_groups.
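To make "KV head replication" concrete, here is a minimal sketch of how a per-rank KV shard could be chosen. It is not this crate's implementation. `KvShardSketch` and `kv_shard_sketch` are hypothetical names, and plain `rank` and `world_size` integers stand in for the `Comm` argument. The sketch assumes the KV projection weight is laid out head by head along dimension 0, and that the head count and the rank count divide each other evenly.

```rust
/// Hypothetical stand-in for the crate's `Shard` type; the real variants may differ.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum KvShardSketch {
    /// No sharding: a single rank holds every KV head.
    Full,
    /// Rank `rank` of `world_size` takes an equal contiguous slice along dim 0.
    Even { rank: usize, world_size: usize },
    /// Explicit row range along dim 0, used when several ranks share one head.
    Offset { offset: usize, len: usize },
}

/// Sketch only: `rank` and `world_size` are assumed to come from the communicator.
pub fn kv_shard_sketch(
    total_num_kv_heads: usize,
    head_dim: usize,
    rank: usize,
    world_size: usize,
) -> KvShardSketch {
    assert!(total_num_kv_heads > 0 && world_size > 0 && rank < world_size);

    // Single device: nothing to split.
    if world_size == 1 {
        return KvShardSketch::Full;
    }

    // Enough heads for every rank: split them evenly.
    // (Assumes total_num_kv_heads is a multiple of world_size.)
    if world_size <= total_num_kv_heads {
        return KvShardSketch::Even { rank, world_size };
    }

    // More ranks than heads: each head is replicated on `replicas`
    // consecutive ranks, and every rank loads exactly one head.
    assert!(world_size % total_num_kv_heads == 0);
    let replicas = world_size / total_num_kv_heads;
    let head = rank / replicas;
    KvShardSketch::Offset {
        offset: head * head_dim,
        len: head_dim,
    }
}
```

Under these assumptions, with 2 KV heads, a head dimension of 128, and 8 ranks, ranks 0 to 3 load rows 0..128 and ranks 4 to 7 load rows 128..256. Because each rank then holds fewer distinct KV heads than an even split would suggest, the number of query heads per KV head has to be adjusted to match, which is presumably why the two functions are meant to be used together.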