// mistralrs/diffusion_model.rs
use mistralrs_core::*;
use crate::{best_device, Model};
/// Builder for a diffusion [`Model`].
///
/// Start with [`DiffusionModelBuilder::new`], adjust settings with the `with_*` methods, and
/// finish with [`DiffusionModelBuilder::build`].
pub struct DiffusionModelBuilder {
    // --- What to load ---
    /// Model identifier handed to `DiffusionLoaderBuilder::new`.
    pub(crate) model_id: String,
    /// Loader variant used to build the diffusion loader.
    pub(crate) loader_type: DiffusionLoaderType,
    /// Optional Hugging Face revision forwarded to `load_model_from_hf`.
    pub(crate) hf_revision: Option<String>,
    /// Token source forwarded to `load_model_from_hf`.
    pub(crate) token_source: TokenSource,

    // --- How to load it ---
    /// Dtype requested when loading the model.
    pub(crate) dtype: ModelDType,
    /// When set, `best_device` is asked to use the CPU.
    pub(crate) force_cpu: bool,
    /// Copied into `DiffusionSpecificConfig::use_flash_attn`.
    pub(crate) use_flash_attn: bool,

    // --- Runtime behaviour ---
    /// Size handed to the fixed default scheduler.
    pub(crate) max_num_seqs: usize,
    /// Whether `build` initialises logging (and loads non-silently).
    pub(crate) with_logging: bool,
}
impl DiffusionModelBuilder {
    /// Create a builder for the diffusion model identified by `model_id`, to be loaded with
    /// the given `loader_type`.
    ///
    /// Defaults:
    /// - dtype `ModelDType::Auto`;
    /// - the device is chosen by `best_device` without forcing the CPU;
    /// - token source `TokenSource::CacheToken`;
    /// - no explicit Hugging Face revision;
    /// - `max_num_seqs = 32`;
    /// - logging disabled;
    /// - flash attention enabled exactly when the crate is compiled with the `flash-attn`
    ///   feature.
    #[must_use]
    pub fn new(model_id: impl ToString, loader_type: DiffusionLoaderType) -> Self {
        Self {
            model_id: model_id.to_string(),
            use_flash_attn: cfg!(feature = "flash-attn"),
            loader_type,
            dtype: ModelDType::Auto,
            force_cpu: false,
            token_source: TokenSource::CacheToken,
            hf_revision: None,
            max_num_seqs: 32,
            with_logging: false,
        }
    }

    /// Set the dtype the model is loaded with.
    #[must_use]
    pub fn with_dtype(mut self, dtype: ModelDType) -> Self {
        self.dtype = dtype;
        self
    }

    /// Force the model onto the CPU (forwarded to `best_device`).
    #[must_use]
    pub fn with_force_cpu(mut self) -> Self {
        self.force_cpu = true;
        self
    }

    /// Set the token source forwarded to `load_model_from_hf`.
    #[must_use]
    pub fn with_token_source(mut self, token_source: TokenSource) -> Self {
        self.token_source = token_source;
        self
    }

    /// Load a specific Hugging Face revision of the model.
    #[must_use]
    pub fn with_hf_revision(mut self, revision: impl ToString) -> Self {
        self.hf_revision = Some(revision.to_string());
        self
    }

    /// Set the size of the fixed default scheduler.
    ///
    /// The value is validated in [`DiffusionModelBuilder::build`]. A value the fixed
    /// scheduler cannot accept (e.g. `0`) makes `build` return an error.
    #[must_use]
    pub fn with_max_num_seqs(mut self, max_num_seqs: usize) -> Self {
        self.max_num_seqs = max_num_seqs;
        self
    }

    /// Enable logging.
    ///
    /// `build` will call `initialize_logging` and will pass `false` as the loader flag that
    /// is otherwise `true`.
    #[must_use]
    pub fn with_logging(mut self) -> Self {
        self.with_logging = true;
        self
    }

    /// Load the configured diffusion model and wrap it in a [`Model`].
    ///
    /// # Errors
    ///
    /// Returns an error in three cases:
    /// - `max_num_seqs` cannot be used as a fixed scheduler size (e.g. it is `0`);
    /// - `best_device` fails;
    /// - `load_model_from_hf` fails.
    ///
    /// The `max_num_seqs` check runs before any model loading, so an invalid value fails
    /// fast.
    pub async fn build(self) -> anyhow::Result<Model> {
        // Validate the scheduler size first. This is a cheap check, whereas the model load
        // below may be expensive, so an invalid size should be rejected before that work.
        // The error names the offending value instead of surfacing a bare
        // integer-conversion error.
        let max_num_seqs = self.max_num_seqs;
        let scheduler_method = SchedulerConfig::DefaultScheduler {
            method: DefaultSchedulerMethod::Fixed(max_num_seqs.try_into().map_err(|_| {
                anyhow::anyhow!(
                    "invalid `max_num_seqs` ({}): the fixed scheduler needs at least 1 sequence",
                    max_num_seqs
                )
            })?),
        };

        let config = DiffusionSpecificConfig {
            use_flash_attn: self.use_flash_attn,
        };
        if self.with_logging {
            initialize_logging();
        }

        let loader =
            DiffusionLoaderBuilder::new(config, Some(self.model_id)).build(self.loader_type);
        let pipeline = loader.load_model_from_hf(
            self.hf_revision,
            self.token_source,
            &self.dtype,
            &best_device(self.force_cpu)?,
            // NOTE(review): presumably the loader's "silent" flag, i.e. quiet unless logging
            // was requested. Confirm against `load_model_from_hf`.
            !self.with_logging,
            DeviceMapMetadata::dummy(),
            None,
            None,
        )?;

        let runner =
            MistralRsBuilder::new(pipeline, scheduler_method).with_gemm_full_precision_f16(true);
        Ok(Model::new(runner.build()))
    }
}