use std::{fmt::Debug, sync::Barrier};
use candle_core::{Result, Tensor};
pub use ops::{Comm, Id, SumAllReduce};
pub mod layers;
pub mod socket;
/// Abstraction over a barrier-style synchronization point.
///
/// Implementors must be `Debug + Send + Sync`. Unlike `std::sync::Barrier::wait`,
/// the `wait` declared here is fallible and reports failure through
/// `candle_core::Result`.
pub trait BarrierLike: Debug + Send + Sync {
/// Waits on the barrier; any failure is surfaced as an `Err`.
fn wait(&self) -> Result<()>;
/// Lets a plain `std::sync::Barrier` be used wherever a [`BarrierLike`] is expected.
impl BarrierLike for Barrier {
/// Adapter from the trait's fallible `wait` onto the standard barrier.
fn wait(&self) -> Result<()> {
/// Reports the global tensor-parallel ("tp") size, judging by the name.
///
/// The result is feature-dependent: there is one code path for builds with the
/// `cuda` feature and another for builds without it.
///
/// # Errors
/// Returns a `candle_core` error if the CUDA-side query fails
/// (presumably — the fallible call itself is not shown here; confirm).
pub fn get_global_tp_size_from_devices() -> Result<usize> {
// CUDA builds: `WrapErr` is brought into scope for error conversion.
#[cfg(feature = "cuda")]
use candle_core::cuda::WrapErr;
// Cast the successfully queried value to `usize`.
// NOTE(review): `as` cast — the source integer type is not visible here;
// confirm it can never be negative.
.map(|x| x as usize)
// Non-CUDA builds take the alternative path below.
#[cfg(not(feature = "cuda"))]
/// Decides whether NCCL should be used.
///
/// The result combines a runtime check on the `MISTRALRS_NO_NCCL` environment
/// variable with a compile-time requirement that both the `nccl` and `cuda`
/// features are enabled; without both features this is always `false`.
pub fn use_nccl() -> bool {
// Runtime arm: holds when `MISTRALRS_NO_NCCL` is set to anything other than "1".
// NOTE(review): the variable is read from the environment on every call.
|| std::env::var("MISTRALRS_NO_NCCL").is_ok_and(|x| x != "1"))
// Compile-time gate: both features must be compiled in.
&& (cfg!(feature = "nccl") && cfg!(feature = "cuda"))
/// Interface for collective tensor operations.
pub trait DistributedOperation {
/// Applies a sum all-reduce to `xs` and returns the resulting tensor.
///
/// # Errors
/// Returns a `candle_core` error if the underlying operation fails.
fn sum_all_reduce(&self, xs: &Tensor) -> Result<Tensor>;
#[cfg(all(feature = "cuda", feature = "nccl"))]
mod ops {
use std::{fmt::Debug, ops::Deref, sync::Arc};
use candle_core::{
backend::BackendStorage, cuda::cudarc, cuda_backend::WrapErr, CpuStorage, CustomOp1, DType,
Device, Layout, Result, Shape, Tensor,
/// Newtype wrapper around `cudarc::nccl::Id`.
///
/// It is `Copy`, and is passed by value to `Comm::from_device`.
#[derive(Debug, Clone, Copy)]
pub struct Id(cudarc::nccl::Id);
impl Id {
/// Creates a fresh id via `cudarc::nccl::Id::new`.
///
/// # Panics
/// Panics with "Failed to create `Id`." if cudarc reports an error.
pub fn new() -> Self {
let id = cudarc::nccl::Id::new().expect("Failed to create `Id`.");
/// Builds an `Id` from a raw 128-element `c_char` array.
// NOTE(review): presumably used to rebuild an id received from another
// process — confirm against callers.
pub fn uninit(internal: [::core::ffi::c_char; 128usize]) -> Self {
/// Borrows the raw 128-element `c_char` representation of this id.
pub fn internal(&self) -> &[::core::ffi::c_char; 128usize] {
/// Owning wrapper around a `cudarc::nccl::Comm` communicator.
pub struct Comm {
// The wrapped cudarc communicator, also reachable through `Deref`.
comm: cudarc::nccl::Comm,
impl Comm {
/// Creates the communicator for participant `rank` of `world_size` on `dev`,
/// using the NCCL `id`.
///
/// # Errors
/// Returns an error if `dev` is not a CUDA device (`as_cuda_device` fails).
///
/// # Panics
/// - if `rank` differs from the CUDA device ordinal;
/// - if `cudarc::nccl::Comm::from_rank` fails
///   ("Failed to create `Comm`, error code").
pub fn from_device(id: Id, dev: &Device, rank: usize, world_size: usize) -> Result<Self> {
let device = dev.as_cuda_device()?.cuda_device();
// The rank is required to coincide with the CUDA device ordinal.
assert_eq!(rank, device.ordinal());
Ok(Self {
comm: cudarc::nccl::Comm::from_rank(device, rank, world_size, id.0)
// Keep only the first field of the cudarc error for the panic message.
.map_err(|e| e.0)
// NOTE(review): this panics inside a `Result`-returning function;
// consider propagating the failure as an `Err` instead.
.expect("Failed to create `Comm`, error code"),
// These impls assert that `Comm` may be shared between and moved across
// threads.
// NOTE(review): no SAFETY justification is stated here — confirm against the
// thread-safety guarantees of cudarc / NCCL communicators.
unsafe impl Sync for Comm {}
unsafe impl Send for Comm {}
/// Exposes the wrapped `cudarc::nccl::Comm` API directly on `Comm`
/// (for example, `rank()` is called on the wrapper in `cuda_fwd`).
impl Deref for Comm {
type Target = cudarc::nccl::Comm;
/// Borrows the wrapped communicator.
fn deref(&self) -> &Self::Target {
/// Sum all-reduce operation bound to a shared [`Comm`].
#[derive(Clone, Debug)]
pub struct SumAllReduce {
// Shared communicator handle; cloning the op clones the `Arc`, not the `Comm`.
comm: Arc<Comm>,
impl SumAllReduce {
/// Creates the operation, holding its own `Arc` reference to `comm`.
pub fn new(comm: &Arc<Comm>) -> Self {
// `Arc` clone: a new reference to the same communicator.
Self { comm: comm.clone() }
/// Implements the parent module's `DistributedOperation` trait for the
/// NCCL-backed op.
impl super::DistributedOperation for SumAllReduce {
/// Sum all-reduce of `xs`.
// NOTE(review): presumably dispatches to the `CustomOp1` impl of this type
// — confirm.
fn sum_all_reduce(&self, xs: &Tensor) -> Result<Tensor> {
/// candle custom-op implementation: the CUDA path performs the all-reduce and
/// the CPU path is rejected.
impl CustomOp1 for SumAllReduce {
/// Name under which this custom op is reported.
fn name(&self) -> &'static str {
/// Always fails: this op has no CPU implementation.
fn cpu_fwd(&self, _s: &CpuStorage, _l: &Layout) -> Result<(CpuStorage, Shape)> {
candle_core::bail!("SumAllReduce is never used on cpu")
/// CUDA forward pass: sum all-reduces the input storage into a newly
/// allocated buffer of `elem_count` elements and returns it with the input's
/// shape. BF16, F16 and F32 are supported.
///
/// # Errors
/// - the input is not contiguous from offset 0 over the whole slice
///   ("input has to be contiguous");
/// - the dtype is not BF16/F16/F32 ("unsupported dtype ...");
/// - the device allocation fails.
///
/// # Panics
/// In the BF16 arm only: if the device ordinal differs from the communicator
/// rank, or if the tensor has zero elements.
fn cuda_fwd(
s: &candle_core::CudaStorage,
l: &Layout,
) -> Result<(candle_core::CudaStorage, Shape)> {
use cudarc::{driver::DeviceSlice, nccl::ReduceOp};
use half::{bf16, f16};
// The destination buffer is sized from the layout's shape.
let elem_count = l.shape().elem_count();
let dev = s.device().clone();
// One arm per supported dtype; the three arms have the same structure.
let dst = match s.dtype() {
DType::BF16 => {
let s = s.as_cuda_slice::<bf16>()?;
// Accept only a contiguous layout whose first offset is 0 and whose second
// offset equals the slice length.
let s = match l.contiguous_offsets() {
Some((0, l)) if l == s.len() => s,
Some(_) | None => candle_core::bail!("input has to be contiguous"),
// NOTE(review): these two sanity checks exist only in this BF16 arm; the
// F16 and F32 arms below do not repeat them.
assert_eq!(dev.ordinal(), self.comm.rank());
assert!(elem_count > 0);
// The allocation is `unsafe`; `dst` is then passed as the all-reduce
// destination.
// NOTE(review): presumably uninitialized until written by the collective —
// confirm.
let mut dst = unsafe { dev.alloc::<bf16>(elem_count) }.w()?;
// Element-wise sum reduction from `s` into `dst`.
.all_reduce(s, &mut dst, &ReduceOp::Sum)
candle_core::CudaStorage::wrap_cuda_slice(dst, dev)
DType::F16 => {
let s = s.as_cuda_slice::<f16>()?;
// Same contiguity requirement as the BF16 arm.
let s = match l.contiguous_offsets() {
Some((0, l)) if l == s.len() => s,
Some(_) | None => candle_core::bail!("input has to be contiguous"),
let mut dst = unsafe { dev.alloc::<f16>(elem_count) }.w()?;
.all_reduce(s, &mut dst, &ReduceOp::Sum)
candle_core::CudaStorage::wrap_cuda_slice(dst, dev)
DType::F32 => {
let s = s.as_cuda_slice::<f32>()?;
// Same contiguity requirement as the BF16 arm.
let s = match l.contiguous_offsets() {
Some((0, l)) if l == s.len() => s,
Some(_) | None => candle_core::bail!("input has to be contiguous"),
let mut dst = unsafe { dev.alloc::<f32>(elem_count) }.w()?;
.all_reduce(s, &mut dst, &ReduceOp::Sum)
candle_core::CudaStorage::wrap_cuda_slice(dst, dev)
// Every other dtype is rejected with an error rather than a panic.
dtype => candle_core::bail!("unsupported dtype {dtype:?}"),
// The output keeps the input's shape.
Ok((dst, l.shape().clone()))
#[cfg(not(all(feature = "cuda", feature = "nccl")))]
mod ops {
use std::sync::Arc;
use candle_core::{Device, Result, Tensor};
/// Placeholder `Id` for builds where the `cuda` and `nccl` features are not
/// both enabled. It carries no data.
#[derive(Debug, Clone, Copy)]
pub struct Id;
/// `Default` for the placeholder id.
impl Default for Id {
/// Returns the placeholder id.
fn default() -> Self {
impl Id {
/// Creates the placeholder id.
pub fn new() -> Self {
/// Same signature as the NCCL-backed `Id::uninit`; the argument is
/// underscore-prefixed, i.e. intended to be unused.
pub fn uninit(_internal: [::core::ffi::c_char; 128usize]) -> Self {
/// Returns a reference to a 128-element `c_char` array.
pub fn internal(&self) -> &[::core::ffi::c_char; 128usize] {
// All-zero array with `'static` lifetime.
// NOTE(review): presumably the value handed back to the caller — confirm.
static ZEROED_ID: [::core::ffi::c_char; 128] = [0; 128];
/// Placeholder communicator for builds where the `cuda` and `nccl` features
/// are not both enabled. It carries no data.
pub struct Comm;
impl Comm {
/// Same signature as the NCCL-backed `Comm::from_device`; every argument is
/// underscore-prefixed, i.e. intended to be unused.
pub fn from_device(
_id: Id,
_dev: &Device,
_rank: usize,
_world_size: usize,
) -> Result<Self> {
/// Rank of the placeholder communicator.
pub fn rank(&self) -> usize {
/// World size of the placeholder communicator.
pub fn world_size(&self) -> usize {
/// Placeholder sum all-reduce op for builds where the `cuda` and `nccl`
/// features are not both enabled. It carries no data.
#[derive(Clone, Debug)]
pub struct SumAllReduce;
impl SumAllReduce {
/// Same signature as the NCCL-backed constructor; `_comm` is
/// underscore-prefixed, i.e. intended to be unused.
pub fn new(_comm: &Arc<Comm>) -> Self {
/// Implements the parent module's `DistributedOperation` trait for the
/// placeholder op, so callers can use the same interface in builds without
/// `cuda` + `nccl`.
impl super::DistributedOperation for SumAllReduce {
/// Fallback `sum_all_reduce` for builds without NCCL.
fn sum_all_reduce(&self, xs: &Tensor) -> Result<Tensor> {