pub fn compute_n_kv_groups(
    total_num_kv_heads: usize,
    num_attention_heads: usize,
    comm: &Comm,
) -> usize
Compute the number of KV groups, taking into account KV head replication.