diffusion_rs_core/pipelines/
scheduler.rs

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
use diffusion_rs_common::core::{Context, Result};
use serde::Deserialize;

/// Scheduler settings deserialized from a diffusers-style scheduler config (JSON).
///
/// Every field except `_class_name` is optional in the input. Missing fields fall back
/// to the constructor defaults of diffusers' `FlowMatchEulerDiscreteScheduler`, so configs
/// that only specify a subset (for example just `shift`) still load.
#[derive(Debug, Deserialize, Clone)]
pub struct SchedulerConfig {
    /// Scheduler implementation to use (JSON key `_class_name`).
    #[serde(rename = "_class_name")]
    pub scheduler_type: SchedulerType,
    /// Defaults to 256. Not read by `get_timesteps`; presumably used by callers to
    /// derive `mu` for dynamic shifting — confirm against call sites.
    #[serde(default = "default_base_image_seq_len")]
    pub base_image_seq_len: usize,
    /// Defaults to 0.5. Not read by `get_timesteps`; presumably paired with
    /// `base_image_seq_len` when deriving `mu` — confirm against call sites.
    #[serde(default = "default_base_shift")]
    pub base_shift: f64,
    /// Defaults to 4096. Not read by `get_timesteps`; presumably used by callers to
    /// derive `mu` for dynamic shifting — confirm against call sites.
    #[serde(default = "default_max_image_seq_len")]
    pub max_image_seq_len: usize,
    /// Defaults to 1.15. Not read by `get_timesteps`; presumably paired with
    /// `max_image_seq_len` when deriving `mu` — confirm against call sites.
    #[serde(default = "default_max_shift")]
    pub max_shift: f64,
    /// Static shift used by `get_timesteps` when `use_dynamic_shifting` is false.
    /// Defaults to 1.0, which leaves the base schedule unchanged.
    #[serde(default = "default_shift")]
    pub shift: f64,
    /// When true, `get_timesteps` requires a caller-supplied `mu` and ignores `shift`.
    /// Defaults to false.
    #[serde(default)]
    pub use_dynamic_shifting: bool,
}

// Serde default providers mirroring diffusers' FlowMatchEulerDiscreteScheduler defaults.

/// Default for [`SchedulerConfig::base_image_seq_len`].
fn default_base_image_seq_len() -> usize {
    256
}

/// Default for [`SchedulerConfig::base_shift`].
fn default_base_shift() -> f64 {
    0.5
}

/// Default for [`SchedulerConfig::max_image_seq_len`].
fn default_max_image_seq_len() -> usize {
    4096
}

/// Default for [`SchedulerConfig::max_shift`].
fn default_max_shift() -> f64 {
    1.15
}

/// Default for [`SchedulerConfig::shift`] (identity shift).
fn default_shift() -> f64 {
    1.0
}

/// Supported scheduler implementations, keyed by the diffusers `_class_name` string.
#[derive(Debug, Deserialize, Clone, PartialEq, Eq)]
pub enum SchedulerType {
    /// Deserialized from `"FlowMatchEulerDiscreteScheduler"`.
    #[serde(rename = "FlowMatchEulerDiscreteScheduler")]
    FlowMatchEulerDiscrete,
}

/// Applies the exponential time shift `e^mu / (e^mu + (1/t - 1)^sigma)` to `t`.
///
/// For `sigma > 0` this maps `t == 1` to `1` and `t == 0` to `0`, because `1/0`
/// evaluates to `+inf` in IEEE arithmetic.
fn time_shift(mu: f64, sigma: f64, t: f64) -> f64 {
    let exp_mu = f64::exp(mu);
    // `recip` is defined as `1.0 / t`, so `t == 0` still produces `+inf` here.
    let tail = (t.recip() - 1.0).powf(sigma);
    exp_mu / (exp_mu + tail)
}

impl SchedulerConfig {
    /// Builds the descending sigma schedule for the configured scheduler.
    ///
    /// The base schedule is `num_steps + 1` evenly spaced values from `1.0` down to
    /// `0.0`, inclusive at both ends. For `FlowMatchEulerDiscrete` each value is then shifted:
    ///
    /// * if `use_dynamic_shifting` is set, `t` becomes `time_shift(mu, 1.0, t)`;
    /// * otherwise, `s` becomes `shift * s / (1 + (shift - 1) * s)`.
    ///
    /// `mu` is only consulted when `use_dynamic_shifting` is set.
    ///
    /// # Errors
    ///
    /// * `num_steps` is zero. The base schedule would divide `0 / 0` and yield `[NaN]`.
    /// * `use_dynamic_shifting` is set and `mu` is `None`.
    pub fn get_timesteps(&self, num_steps: usize, mu: Option<f64>) -> Result<Vec<f64>> {
        // Reject an empty schedule up front instead of returning a NaN sigma.
        (num_steps != 0)
            .then_some(())
            .context("`num_steps` must be greater than zero")?;

        let base = (0..=num_steps)
            .rev()
            .map(|i| i as f64 / num_steps as f64);

        match self.scheduler_type {
            SchedulerType::FlowMatchEulerDiscrete => {
                let sigmas: Vec<f64> = if self.use_dynamic_shifting {
                    let mu = mu.context("`mu` is required for dynamic shifting")?;
                    base.map(|t| time_shift(mu, 1., t)).collect()
                } else {
                    let shift = self.shift;
                    base.map(|s| shift * s / (1. + (shift - 1.) * s)).collect()
                };
                Ok(sigmas)
            }
        }
    }
}