// mistralrs/vision_model.rs
use mistralrs_core::*;
use std::{num::NonZeroUsize, path::PathBuf};
use crate::{best_device, Model};
/// Builder that gathers every setting needed to load and run a vision model.
///
/// Start with `VisionModelBuilder::new`, adjust settings through the chained
/// setter methods, then call `VisionModelBuilder::build` to obtain a `Model`.
pub struct VisionModelBuilder {
    /// Model identifier; `build` hands it to `VisionLoaderBuilder::new` as `Some(model_id)`.
    pub(crate) model_id: String,
    /// Token source passed to `load_model_from_hf`. `new` uses `TokenSource::CacheToken`.
    pub(crate) token_source: TokenSource,
    /// Optional revision passed to `load_model_from_hf`.
    pub(crate) hf_revision: Option<String>,
    /// Optional UQFF output path, forwarded through `VisionSpecificConfig::write_uqff`.
    pub(crate) write_uqff: Option<PathBuf>,
    /// Optional UQFF input path, forwarded through `VisionSpecificConfig::from_uqff`.
    pub(crate) from_uqff: Option<PathBuf>,
    /// Optional calibration file, forwarded through `VisionSpecificConfig::calibration_file`.
    pub(crate) calibration_file: Option<PathBuf>,
    /// Optional chat template string handed to `VisionLoaderBuilder::new`.
    pub(crate) chat_template: Option<String>,
    /// Optional tokenizer JSON string handed to `VisionLoaderBuilder::new`.
    /// NOTE(review): presumably a path to a `tokenizer.json` file — confirm.
    pub(crate) tokenizer_json: Option<String>,
    /// Optional device mapping; `build` falls back to `DeviceMapMetadata::dummy()` when unset.
    pub(crate) device_mapping: Option<DeviceMapMetadata>,
    /// Optional maximum edge value, forwarded through `VisionSpecificConfig::max_edge`.
    /// NOTE(review): presumably an image edge length in pixels — confirm.
    pub(crate) max_edge: Option<u32>,
    /// Flash-attention flag; `new` sets it from `cfg!(feature = "flash-attn")`.
    pub(crate) use_flash_attn: bool,
    /// Optional prompt batch size, forwarded through `VisionSpecificConfig::prompt_batchsize`.
    pub(crate) prompt_batchsize: Option<NonZeroUsize>,
    /// Optional topology, forwarded through `VisionSpecificConfig::topology`.
    pub(crate) topology: Option<Topology>,
    /// Loader type given to `VisionLoaderBuilder::build`.
    pub(crate) loader_type: VisionLoaderType,
    /// Data type passed by reference to `load_model_from_hf`. `new` uses `ModelDType::Auto`.
    pub(crate) dtype: ModelDType,
    /// Passed to `best_device` when `build` selects the device.
    pub(crate) force_cpu: bool,
    /// Optional ISQ type passed to `load_model_from_hf`.
    pub(crate) isq: Option<IsqType>,
    /// Value converted for `DefaultSchedulerMethod::Fixed` in `build`. `new` uses 32.
    pub(crate) max_num_seqs: usize,
    /// When `true`, `build` calls `initialize_logging()`; its negation is also passed
    /// to `load_model_from_hf`.
    pub(crate) with_logging: bool,
}
impl VisionModelBuilder {
    /// Creates a builder for `model_id`, to be loaded with `loader_type`.
    ///
    /// Defaults applied here:
    /// - `token_source` is `TokenSource::CacheToken`.
    /// - `use_flash_attn` reflects whether the `flash-attn` feature is compiled in.
    /// - `dtype` is `ModelDType::Auto`.
    /// - `max_num_seqs` is 32.
    /// - `force_cpu` and `with_logging` are `false`.
    /// - Every optional setting starts as `None`.
    pub fn new(model_id: impl ToString, loader_type: VisionLoaderType) -> Self {
        Self {
            model_id: model_id.to_string(),
            token_source: TokenSource::CacheToken,
            hf_revision: None,
            write_uqff: None,
            from_uqff: None,
            calibration_file: None,
            chat_template: None,
            tokenizer_json: None,
            device_mapping: None,
            max_edge: None,
            use_flash_attn: cfg!(feature = "flash-attn"),
            prompt_batchsize: None,
            topology: None,
            loader_type,
            dtype: ModelDType::Auto,
            force_cpu: false,
            isq: None,
            max_num_seqs: 32,
            with_logging: false,
        }
    }

    /// Sets the prompt batch size forwarded in the loader configuration.
    pub fn with_prompt_batchsize(mut self, prompt_batchsize: NonZeroUsize) -> Self {
        self.prompt_batchsize = Some(prompt_batchsize);
        self
    }

    /// Sets the topology forwarded in the loader configuration.
    pub fn with_topology(mut self, topology: Topology) -> Self {
        self.topology = Some(topology);
        self
    }

    /// Sets the chat template; it is stored as a `String` and handed to the loader builder.
    pub fn with_chat_template(mut self, chat_template: impl ToString) -> Self {
        self.chat_template = Some(chat_template.to_string());
        self
    }

    /// Sets the tokenizer JSON value; it is stored as a `String` and handed to the
    /// loader builder.
    /// NOTE(review): presumably a path to a `tokenizer.json` file — confirm.
    pub fn with_tokenizer_json(mut self, tokenizer_json: impl ToString) -> Self {
        self.tokenizer_json = Some(tokenizer_json.to_string());
        self
    }

    /// Sets the data type passed to the model load call.
    pub fn with_dtype(mut self, dtype: ModelDType) -> Self {
        self.dtype = dtype;
        self
    }

    /// Sets `force_cpu`, so `build` selects its device with `best_device(true)`.
    pub fn with_force_cpu(mut self) -> Self {
        self.force_cpu = true;
        self
    }

    /// Replaces the token source passed to the model load call.
    pub fn with_token_source(mut self, token_source: TokenSource) -> Self {
        self.token_source = token_source;
        self
    }

    /// Sets the revision passed to the model load call.
    pub fn with_hf_revision(mut self, revision: impl ToString) -> Self {
        self.hf_revision = Some(revision.to_string());
        self
    }

    /// Sets the ISQ type passed to the model load call.
    pub fn with_isq(mut self, isq: IsqType) -> Self {
        self.isq = Some(isq);
        self
    }

    /// Sets the calibration file forwarded in the loader configuration.
    pub fn with_calibration_file(mut self, path: PathBuf) -> Self {
        self.calibration_file = Some(path);
        self
    }

    /// Sets the value used for the fixed scheduler in `build`.
    ///
    /// `build` converts this value with `try_into` and returns an error when the
    /// conversion fails.
    /// NOTE(review): the target type is presumably `NonZeroUsize`, which would make
    /// `0` invalid — confirm against `DefaultSchedulerMethod::Fixed`.
    pub fn with_max_num_seqs(mut self, max_num_seqs: usize) -> Self {
        self.max_num_seqs = max_num_seqs;
        self
    }

    /// Enables logging: `build` will call `initialize_logging()`.
    pub fn with_logging(mut self) -> Self {
        self.with_logging = true;
        self
    }

    /// Sets the device mapping used instead of the `DeviceMapMetadata::dummy()` fallback.
    pub fn with_device_mapping(mut self, device_mapping: DeviceMapMetadata) -> Self {
        self.device_mapping = Some(device_mapping);
        self
    }

    /// Sets the UQFF input path forwarded in the loader configuration.
    pub fn from_uqff(mut self, path: PathBuf) -> Self {
        self.from_uqff = Some(path);
        self
    }

    /// Sets the maximum edge value forwarded in the loader configuration.
    /// NOTE(review): presumably an image edge length in pixels — confirm.
    pub fn from_max_edge(mut self, max_edge: u32) -> Self {
        self.max_edge = Some(max_edge);
        self
    }

    /// Sets the UQFF output path forwarded in the loader configuration.
    pub fn write_uqff(mut self, path: PathBuf) -> Self {
        self.write_uqff = Some(path);
        self
    }

    /// Consumes the builder, loads the model, and returns it wrapped in a `Model`.
    ///
    /// Steps, in order:
    /// 1. Convert `max_num_seqs` into the scheduler configuration.
    /// 2. Call `initialize_logging()` when logging was enabled.
    /// 3. Build the loader and load the pipeline with `load_model_from_hf`.
    /// 4. Configure the runner and wrap the built runner in `Model::new`.
    ///
    /// # Errors
    ///
    /// Returns an error when `max_num_seqs` cannot be converted, when `best_device`
    /// fails, or when `load_model_from_hf` fails.
    ///
    /// NOTE(review): this function is `async` but contains no `.await`, so the load
    /// call runs synchronously inside the returned future.
    pub async fn build(self) -> anyhow::Result<Model> {
        // Validate the scheduler setting first so an unusable `max_num_seqs` is
        // reported before the expensive model load rather than after it.
        let scheduler_method = SchedulerConfig::DefaultScheduler {
            method: DefaultSchedulerMethod::Fixed(self.max_num_seqs.try_into()?),
        };

        let config = VisionSpecificConfig {
            use_flash_attn: self.use_flash_attn,
            prompt_batchsize: self.prompt_batchsize,
            topology: self.topology,
            write_uqff: self.write_uqff,
            from_uqff: self.from_uqff,
            max_edge: self.max_edge,
            calibration_file: self.calibration_file,
        };

        if self.with_logging {
            initialize_logging();
        }

        let loader = VisionLoaderBuilder::new(
            config,
            self.chat_template,
            self.tokenizer_json,
            Some(self.model_id),
        )
        .build(self.loader_type);

        let pipeline = loader.load_model_from_hf(
            self.hf_revision,
            self.token_source,
            &self.dtype,
            &best_device(self.force_cpu)?,
            // Negated logging flag. NOTE(review): presumably a "silent" switch — confirm.
            !self.with_logging,
            // Build the fallback mapping only when none was supplied.
            self.device_mapping
                .unwrap_or_else(DeviceMapMetadata::dummy),
            self.isq,
            // No value is supplied for the loader's final optional argument.
            None,
        )?;

        let runner = MistralRsBuilder::new(pipeline, scheduler_method)
            .with_no_kv_cache(false)
            .with_gemm_full_precision_f16(true)
            .with_no_prefix_cache(false);

        Ok(Model::new(runner.build()))
    }
}