// diffusion_rs_cli/main.rs

use cliclack::input;
use std::{path::PathBuf, time::Instant};

use clap::{Parser, Subcommand};
use diffusion_rs_core::{
    DiffusionGenerationParams, ModelDType, ModelSource, Offloading, Pipeline, TokenSource,
};
use tracing::level_filters::LevelFilter;
use tracing_subscriber::EnvFilter;

/// Guidance scale passed to the pipeline when the user does not supply `--scale`.
const GUIDANCE_SCALE_DEFAULT: f64 = 0.0;

// Subcommand selecting where the model is loaded from.
//
// NOTE: the `///` comments inside this enum are consumed by clap's derive
// macro and shown as `--help` text, so they are user-facing strings; edit them
// with that in mind. Explanatory notes for maintainers use `//` instead.
#[derive(Debug, Subcommand)]
pub enum SourceCommand {
    /// Load the model from a DDUF file.
    Dduf {
        /// DDUF file path
        #[arg(short, long)]
        file: String,
    },

    /// Load the model from some model ID (local path or Hugging Face model ID)
    ModelId {
        /// Model ID
        #[arg(short, long)]
        model_id: String,
    },
}

#[derive(Parser)]
struct Args {
    #[clap(subcommand)]
    source: SourceCommand,

    /// Hugging Face token. Useful for accessing gated repositories.
    /// By default, the Hugging Face token at ~/.cache/huggingface/token is used.
    #[arg(long)]
    token: Option<String>,

    /// Guidance scale to use. This is model specific. If not specified, defaults to 0.0.
    #[arg(short, long)]
    scale: Option<f64>,

    /// Number of denoising steps. This is model specific. A higher number of steps often means higher quality.
    #[arg(short, long)]
    num_steps: usize,

    /// Offloading setting to use for this model
    #[arg(short, long)]
    offloading: Option<Offloading>,

    /// DType for the model. The default is to use an automatic strategy with a fallback pattern: BF16 -> F16 -> F32
    #[arg(short, long, default_value = "auto")]
    dtype: ModelDType,
}

/// Validates a prompted image dimension (height or width).
///
/// Returns the parse error's text when `input` is not a valid `usize`, and
/// `"Nonzero value is required!"` when it parses to zero.
fn validate_dimension(input: &str) -> Result<(), String> {
    match input.parse::<usize>() {
        Ok(0) => Err("Nonzero value is required!".to_string()),
        Ok(_) => Ok(()),
        Err(e) => Err(e.to_string()),
    }
}

/// Validates a prompted output path: it must be non-empty and carry a `png`
/// or `jpg` extension (compared case-sensitively).
fn validate_image_path(input: &str) -> Result<(), &'static str> {
    if input.is_empty() {
        return Err("Image path is required!");
    }

    let path = PathBuf::from(input);
    match path.extension() {
        None => Err("Extension is required!"),
        // `to_str()` yielding `None` falls through to the rejection arm
        // instead of panicking.
        Some(ext) => match ext.to_str() {
            Some("png" | "jpg") => Ok(()),
            _ => Err(".png or .jpg extension is required!"),
        },
    }
}

/// Entry point: parses CLI arguments, loads the pipeline, asks once for the
/// output dimensions, then repeatedly prompts for a text prompt, generates an
/// image, and asks where to save it.
///
/// # Errors
///
/// The loop has no `break`; the function only returns when one of the `?`
/// operators propagates an error (model loading, a prompt interaction,
/// generation, an empty generation result, or saving the image).
fn main() -> anyhow::Result<()> {
    let args = Args::parse();

    // Log at INFO by default; the filter is also built from the environment.
    let filter = EnvFilter::builder()
        .with_default_directive(LevelFilter::INFO.into())
        .from_env_lossy();
    tracing_subscriber::fmt().with_env_filter(filter).init();

    let source = match args.source {
        SourceCommand::Dduf { file } => ModelSource::dduf(file)?,
        SourceCommand::ModelId { model_id } => ModelSource::from_model_id(model_id),
    };
    // An explicit `--token` wins; otherwise fall back to the cached token.
    let token = args
        .token
        .map(TokenSource::Literal)
        .unwrap_or(TokenSource::CacheToken);

    // NOTE(review): the positional `false` and `None` arguments are defined by
    // `Pipeline::load`; confirm their meaning against its documentation.
    let pipeline = Pipeline::load(source, false, token, None, args.offloading, &args.dtype)?;

    // Dimensions are asked once and reused for every generated image.
    let height: usize = input("Height:")
        .default_input("720")
        .validate(|value: &String| validate_dimension(value))
        .interact()?;
    let width: usize = input("Width:")
        .default_input("1280")
        .validate(|value: &String| validate_dimension(value))
        .interact()?;

    // Loop-invariant: resolve the guidance scale once.
    let guidance_scale = args.scale.unwrap_or(GUIDANCE_SCALE_DEFAULT);

    loop {
        let prompt: String = input("Prompt:")
            .validate(|value: &String| {
                if value.is_empty() {
                    Err("Prompt is required!")
                } else {
                    Ok(())
                }
            })
            .interact()?;

        let start = Instant::now();

        let images = pipeline.forward(
            vec![prompt],
            DiffusionGenerationParams {
                height,
                width,
                num_steps: args.num_steps,
                guidance_scale,
            },
        )?;

        println!(
            "Image generation took: {:.2}s",
            start.elapsed().as_secs_f32()
        );

        // Fail with an error (not an index panic) if nothing was generated,
        // and do so before asking the user for a save path.
        let image = images
            .first()
            .ok_or_else(|| anyhow::anyhow!("Pipeline returned no images for the prompt"))?;

        let out_file: String = input("Save image to:")
            .validate(|value: &String| validate_image_path(value))
            .interact()?;

        image.save(out_file)?;
    }
}