mistralrs_bench/main.rs

use candle_core::Device;
use clap::Parser;
use cli_table::{format::Justify, print_stdout, Cell, CellStruct, Style, Table};
use mistralrs_core::{
    get_auto_device_map_params, get_model_dtype, initialize_logging, paged_attn_supported,
    parse_isq_value, Constraint, DefaultSchedulerMethod, DeviceLayerMapMetadata, DeviceMapMetadata,
    DeviceMapSetting, DrySamplingParams, Loader, LoaderBuilder, MemoryGpuConfig, MistralRs,
    MistralRsBuilder, ModelSelected, NormalRequest, PagedAttentionConfig, PagedCacheType, Request,
    RequestMessage, Response, SamplingParams, SchedulerConfig, TokenSource, Usage,
};
use std::fmt::Display;
use std::sync::Arc;
use tokio::sync::mpsc::channel;
use tracing::info;

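/// Benchmark test kind: prompt processing ("pp") over a fixed number of prompt
/// tokens, or text generation ("tg") over a fixed number of generated tokens.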
enum TestName {
    Prompt(usize),
    Gen(usize),
}

impl Display for TestName {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let name = match self {
            TestName::Prompt(n) => format!("pp {n}"),
            TestName::Gen(n) => format!("tg {n}"),
        };
        write!(f, "{name}")
    }
}

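/// Per-request `Usage` statistics collected for one test at one concurrency level.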
struct BenchResult {
    usages: Vec<Usage>,
    concurrency: usize,
    test_name: TestName,
}

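/// A tokens-per-second (or ms-per-token) measurement reported as mean ± standard deviation.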
struct UncertainTokSec {
    mean: f32,
    std_dev: f32,
}

impl Display for UncertainTokSec {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "{:.3}±{:.3}", self.mean, self.std_dev)
    }
}

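/// Run one benchmark: for each repetition, send `concurrency` identical requests
/// (the given prompt, generating up to `n_gen` tokens with fixed sampling settings)
/// and collect the per-request `Usage` statistics from the responses.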
async fn run_bench(
    mistralrs: Arc<MistralRs>,
    prompt: RequestMessage,
    n_gen: usize,
    concurrency: usize,
    repetitions: usize,
    test_name: TestName,
) -> anyhow::Result<BenchResult> {
    let sampling_params = SamplingParams {
        temperature: Some(0.1),
        top_k: Some(32),
        top_p: Some(0.1),
        min_p: Some(0.05),
        top_n_logprobs: 0,
        frequency_penalty: Some(0.1),
        presence_penalty: Some(0.1),
        repetition_penalty: None,
        max_len: Some(n_gen),
        stop_toks: None,
        logits_bias: None,
        n_choices: 1,
        dry_params: Some(DrySamplingParams::default()),
    };
    let sender = mistralrs.get_sender(None).unwrap();
    let (tx, mut rx) = channel(10_000);

    let req = Request::Normal(Box::new(NormalRequest {
        id: mistralrs.next_request_id(),
        messages: prompt,
        sampling_params: sampling_params.clone(),
        response: tx,
        return_logprobs: false,
        is_streaming: false,
        constraint: Constraint::None,
        suffix: None,
        tools: None,
        tool_choice: None,
        logits_processors: None,
        return_raw_logits: false,
        web_search_options: None,
        model_id: None,
        truncate_sequence: false,
    }));

    let mut usages = Vec::new();

    for _ in 0..repetitions {
        for _ in 0..concurrency {
            if sender.send(req.clone()).await.is_err() {
                eprintln!("Receiver disconnected");
            }
        }
        for _ in 0..concurrency {
            match rx.recv().await {
                Some(r) => match r {
                    Response::InternalError(e) => {
                        unreachable!("Got an internal error: {e:?}");
                    }
                    Response::ModelError(e, resp) => {
                        unreachable!("Got a model error: {e:?}, response: {resp:?}");
                    }
                    Response::ValidationError(e) => {
                        unreachable!("Got a validation error: {e:?}");
                    }
                    Response::Done(res) => {
                        usages.push(res.usage);
                    }
                    Response::Chunk(_) => unreachable!(),
                    Response::CompletionModelError(_, _) => unreachable!(),
                    Response::CompletionDone(res) => {
                        usages.push(res.usage);
                    }
                    Response::CompletionChunk(_) => unreachable!(),
                    Response::ImageGeneration(_) => unreachable!(),
                    Response::Speech { .. } => unreachable!(),
                    Response::Raw { .. } => unreachable!(),
                    Response::Embeddings { .. } => unreachable!(),
                },
                None => unreachable!("Expected a Done response, got None"),
            }
        }
    }

    Ok(BenchResult {
        usages,
        concurrency,
        test_name,
    })
}

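/// Tokens-per-second statistics for a benchmark run: the mean and (population)
/// standard deviation over the per-request throughput values, using the prompt or
/// completion rate depending on the test kind.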
fn get_tok_s(result: &BenchResult) -> UncertainTokSec {
    let ts_measurements = match result.test_name {
        TestName::Prompt(_) => result
            .usages
            .iter()
            .map(|u| u.avg_prompt_tok_per_sec)
            .collect::<Vec<_>>(),
        TestName::Gen(_) => result
            .usages
            .iter()
            .map(|u| u.avg_compl_tok_per_sec)
            .collect::<Vec<_>>(),
    };
    // Calculate uncertainty
    let mean = ts_measurements.iter().sum::<f32>() / ts_measurements.len() as f32;
    let variance = ts_measurements
        .iter()
        .map(|e| (mean - e).powf(2.))
        .sum::<f32>()
        / ts_measurements.len() as f32;
    let std_dev = variance.sqrt();
    UncertainTokSec { mean, std_dev }
}

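/// Milliseconds-per-token statistics for a benchmark run: the same mean/standard
/// deviation computation as `get_tok_s`, applied to 1000 / tokens-per-second.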
fn get_ms_tok(result: &BenchResult) -> UncertainTokSec {
    let ms_tok_measurements = match result.test_name {
        TestName::Prompt(_) => result
            .usages
            .iter()
            .map(|u| 1000. / u.avg_prompt_tok_per_sec)
            .collect::<Vec<_>>(),
        TestName::Gen(_) => result
            .usages
            .iter()
            .map(|u| 1000. / u.avg_compl_tok_per_sec)
            .collect::<Vec<_>>(),
    };
    // Calculate uncertainty
    let mean = ms_tok_measurements.iter().sum::<f32>() / ms_tok_measurements.len() as f32;
    let variance = ms_tok_measurements
        .iter()
        .map(|e| (mean - e).powf(2.))
        .sum::<f32>()
        / ms_tok_measurements.len() as f32;
    let std_dev = variance.sqrt();
    UncertainTokSec { mean, std_dev }
}

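/// Print the benchmark results as a table with one row per test: model, backend,
/// test name, t/s, ms/t, concurrency, and aggregate throughput (mean t/s × concurrency).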
fn print_usage(model: &str, device: &Device, results: Vec<BenchResult>) {
    let backend = match device {
        Device::Cpu => "CPU",
        Device::Cuda(_) => "CUDA",
        Device::Metal(_) => "Metal",
    };
    let results: Vec<Vec<CellStruct>> = results
        .into_iter()
        .map(|r| {
            vec![
                model.cell(),
                backend.cell(),
                r.test_name.to_string().cell(),
                get_tok_s(&r).cell().justify(Justify::Right),
                get_ms_tok(&r).cell().justify(Justify::Right),
                r.concurrency.cell().justify(Justify::Right),
                (get_tok_s(&r).mean * r.concurrency as f32)
                    .cell()
                    .justify(Justify::Right),
            ]
        })
        .collect();

    let table = results
        .table()
        .title(vec![
            "model".cell().bold(true),
            // "size".cell().bold(true),
            // "params".cell().bold(true),
            "backend".cell().bold(true),
            // "ngl".cell().bold(true),
            "test".cell().bold(true),
            "t/s".cell().bold(true),
            "ms/t".cell().bold(true),
            "concurrency".cell().bold(true),
            "throughput/s".cell().bold(true),
        ])
        .bold(true);
    print_stdout(table).expect("print table");
}

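/// Send a single one-token deterministic completion request before benchmarking so
/// that one-time initialization cost does not skew the measured results.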
async fn warmup_run(mistralrs: Arc<MistralRs>) {
    let sampling_params = SamplingParams {
        max_len: Some(1),
        ..SamplingParams::deterministic()
    };
    let sender = mistralrs.get_sender(None).unwrap();
    let (tx, mut rx) = channel(10_000);

    let req = Request::Normal(Box::new(NormalRequest {
        id: mistralrs.next_request_id(),
        messages: RequestMessage::Completion {
            text: "Hello!".to_string(),
            echo_prompt: false,
            best_of: None,
        },
        sampling_params: sampling_params.clone(),
        response: tx,
        return_logprobs: false,
        is_streaming: false,
        constraint: Constraint::None,
        suffix: None,
        tools: None,
        tool_choice: None,
        logits_processors: None,
        return_raw_logits: false,
        web_search_options: None,
        model_id: None,
        truncate_sequence: false,
    }));

    if sender.send(req.clone()).await.is_err() {
        eprintln!("Receiver disconnected");
    }

    let _ = rx.recv().await;
}

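/// clap value parser for `--pa-cache-type`; delegates to `PagedCacheType`'s `FromStr` impl.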
fn parse_cache_type(s: &str) -> Result<PagedCacheType, String> {
    s.parse()
}

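// Example invocation (the trailing `plain -m <model-id>` subcommand syntax comes from
// `ModelSelected` elsewhere in the repo, and the crate name is assumed to be
// `mistralrs-bench`; treat both as illustrative assumptions):
//   cargo run --release -p mistralrs-bench -- -p 512 -g 128 -c 1,2 plain -m <model-id>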
#[derive(Parser)]
#[command(version, about, long_about = None)]
struct Args {
    /// Model
    #[clap(subcommand)]
    model: ModelSelected,

    /// Integer seed to ensure reproducible random number generation.
    #[arg(short, long)]
    seed: Option<u64>,

    /// Number of prompt tokens to run.
    #[arg(long, short = 'p', default_value_t = 512)]
    n_prompt: usize,

    /// Number of generation tokens to run.
    #[arg(long, short = 'g', default_value_t = 128)]
    n_gen: usize,

    /// Number of concurrent requests to run. Default is 1.
    #[clap(short, long, value_parser, value_delimiter = ',')]
    concurrency: Option<Vec<usize>>,

    /// Number of times to repeat each test.
    #[arg(long, short, default_value_t = 5)]
    repetitions: usize,

    /// NOTE: This can be omitted to use automatic device mapping!
    /// Number of device layers to load and run on GPU(s). All others will be on the CPU.
    /// If one GPU is used, then this value should be an integer. Otherwise, it follows this pattern:
    /// ORD:NUM;... where ORD is a unique device ordinal and NUM is the number of layers for that device.
    #[arg(short, long, value_parser, value_delimiter = ';')]
    num_device_layers: Option<Vec<String>>,

    /// In-situ quantization to apply.
    #[arg(long = "isq")]
    in_situ_quant: Option<String>,

    /// GPU memory to allocate for KV cache with PagedAttention in MBs.
    /// PagedAttention is supported on CUDA and Metal. It is automatically activated on CUDA but not on Metal.
    /// The priority is as follows: `pa-ctxt-len` > `pa-gpu-mem-usage` > `pa-gpu-mem`.
    #[arg(long = "pa-gpu-mem")]
    paged_attn_gpu_mem: Option<usize>,

    /// Percentage of GPU memory to utilize after allocation of KV cache with PagedAttention, from 0 to 1.
    /// If this is not set and the device is CUDA, it will default to `0.9`.
    /// PagedAttention is supported on CUDA and Metal. It is automatically activated on CUDA but not on Metal.
    /// The priority is as follows: `pa-ctxt-len` > `pa-gpu-mem-usage` > `pa-gpu-mem`.
    #[arg(long = "pa-gpu-mem-usage")]
    paged_attn_gpu_mem_usage: Option<f32>,

    /// Total context length to allocate the KV cache for (total number of tokens which the KV cache can hold).
    /// PagedAttention is supported on CUDA and Metal. It is automatically activated on CUDA but not on Metal.
    /// The priority is as follows: `pa-ctxt-len` > `pa-gpu-mem-usage` > `pa-gpu-mem`.
    /// This is the default setting, and it defaults to the `max-seq-len` specified after the model type.
    #[arg(long = "pa-ctxt-len")]
    paged_ctxt_len: Option<usize>,

    /// PagedAttention KV cache type (auto or f8e4m3).
    /// Defaults to `auto`.
    #[arg(long = "pa-cache-type", value_parser = parse_cache_type)]
    cache_type: Option<PagedCacheType>,

    /// Block size (number of tokens per block) for PagedAttention. If this is not set and the device is CUDA, it will default to 32.
    /// PagedAttention is supported on CUDA and Metal. It is automatically activated on CUDA but not on Metal.
    #[arg(long = "pa-blk-size")]
    paged_attn_block_size: Option<usize>,

    /// Disable PagedAttention on CUDA. Because PagedAttention is already disabled on Metal, this is only applicable on CUDA.
    #[arg(long = "no-paged-attn", default_value_t = false)]
    no_paged_attn: bool,

    /// Enable PagedAttention on Metal. Because PagedAttention is already enabled on CUDA, this is only applicable on Metal.
    #[arg(long = "paged-attn", default_value_t = false)]
    paged_attn: bool,
}

#[tokio::main]
async fn main() -> anyhow::Result<()> {
    let mut args = Args::parse();
    initialize_logging();

    args.concurrency = Some(args.concurrency.unwrap_or(vec![1]));

    let dtype = get_model_dtype(&args.model)?;
    let auto_device_map_params = get_auto_device_map_params(&args.model)?;

    let max_seq_len = auto_device_map_params.max_seq_len();

    let loader: Box<dyn Loader> = LoaderBuilder::new(args.model).build()?;
    let model_name = loader.get_id();

    #[cfg(feature = "metal")]
    let device = Device::new_metal(0)?;
    #[cfg(not(feature = "metal"))]
    let device = if mistralrs_core::distributed::use_nccl() {
        Device::Cpu
    } else {
        Device::cuda_if_available(0)?
    };

    if let Some(seed) = args.seed {
        device.set_seed(seed)?;
    }

    let token_source = TokenSource::CacheToken;
    info!(
        "avx: {}, neon: {}, simd128: {}, f16c: {}",
        candle_core::utils::with_avx(),
        candle_core::utils::with_neon(),
        candle_core::utils::with_simd128(),
        candle_core::utils::with_f16c()
    );
    info!("Sampling method: penalties -> temperature -> topk -> topp -> minp -> multinomial");
    info!("Model kind is: {}", loader.get_kind().to_string());

    // Parse device mapper
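    // A single integer maps that many layers onto device ordinal 0; otherwise each
    // "ORD:NUM" entry assigns NUM layers to device ORD. Without --num-device-layers,
    // automatic device mapping is used.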
    let mapper = if let Some(device_layers) = args.num_device_layers {
        if device_layers.len() == 1 && device_layers[0].parse::<usize>().is_ok() {
            let layers = device_layers[0].parse::<usize>().unwrap();
            DeviceMapSetting::Map(DeviceMapMetadata::from_num_device_layers(vec![
                DeviceLayerMapMetadata { ordinal: 0, layers },
            ]))
        } else {
            let mut mapping = Vec::new();
            for layer in device_layers {
                let split = layer.splitn(2, ':').collect::<Vec<_>>();
                if split.len() < 2 {
                    panic!("Expected layer to be of format ORD:NUM, got {layer}");
                }
                let ord = split[0]
                    .parse::<usize>()
                    .unwrap_or_else(|_| panic!("Failed to parse {} as integer.", split[0]));
                let num = split[1]
                    .parse::<usize>()
                    .unwrap_or_else(|_| panic!("Failed to parse {} as integer.", split[1]));
                for DeviceLayerMapMetadata { ordinal, layers: _ } in &mapping {
                    if *ordinal == ord {
                        panic!("Duplicate ordinal {ord}");
                    }
                }
                mapping.push(DeviceLayerMapMetadata {
                    ordinal: ord,
                    layers: num,
                });
            }
            DeviceMapSetting::Map(DeviceMapMetadata::from_num_device_layers(mapping))
        }
    } else {
        DeviceMapSetting::Auto(auto_device_map_params)
    };

    let no_paged_attn = if device.is_cuda() || mistralrs_core::distributed::use_nccl() {
        args.no_paged_attn
    } else if device.is_metal() {
        !args.paged_attn
    } else {
        true
    };

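    // Resolve the PagedAttention KV-cache sizing. When PagedAttention is supported and
    // enabled, exactly one sizing mode is chosen, honoring the documented priority
    // `pa-ctxt-len` > `pa-gpu-mem-usage` > `pa-gpu-mem` (with the model's max-seq-len as
    // the default context size); otherwise no PagedAttention config is built.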
    let cache_config = match (
        args.paged_attn_block_size,
        args.paged_attn_gpu_mem,
        args.paged_attn_gpu_mem_usage,
        args.paged_ctxt_len,
        paged_attn_supported(),
        no_paged_attn,
    ) {
        (block_size, None, None, None, true, false) => Some(PagedAttentionConfig::new(
            block_size,
            MemoryGpuConfig::ContextSize(max_seq_len),
            args.cache_type.unwrap_or_default(),
        )?),
        (block_size, None, None, Some(ctxt), true, false) => Some(PagedAttentionConfig::new(
            block_size,
            MemoryGpuConfig::ContextSize(ctxt),
            args.cache_type.unwrap_or_default(),
        )?),
        (block_size, None, Some(f), None, true, false) => Some(PagedAttentionConfig::new(
            block_size,
            MemoryGpuConfig::Utilization(f),
            args.cache_type.unwrap_or_default(),
        )?),
        (block_size, Some(m), None, None, true, false) => Some(PagedAttentionConfig::new(
            block_size,
            MemoryGpuConfig::MbAmount(m),
            args.cache_type.unwrap_or_default(),
        )?),
        (block_size, Some(_m), Some(f), None, true, false) => {
457            info!("Both memory size, and usage were specified, defaulting to the usage value.");
458            Some(PagedAttentionConfig::new(
459                block_size,
460                MemoryGpuConfig::Utilization(f),
461                args.cache_type.unwrap_or_default(),
462            )?)
463        }
464        (block_size, Some(_m), None, Some(ctxt), true, false) => {
465            info!("All memory size and ctxt len, defaulting to the context len value.");
466            Some(PagedAttentionConfig::new(
467                block_size,
468                MemoryGpuConfig::ContextSize(ctxt),
469                args.cache_type.unwrap_or_default(),
470            )?)
471        }
        (block_size, None, Some(f), Some(_ctxt), true, false) => {
            info!("Both ctxt len and usage were specified, defaulting to the usage value.");
            Some(PagedAttentionConfig::new(
                block_size,
                MemoryGpuConfig::Utilization(f),
                args.cache_type.unwrap_or_default(),
            )?)
        }
        (_, _, _, _, _, _) => None,
    };

    let isq = args
        .in_situ_quant
        .as_ref()
        .and_then(|isq| parse_isq_value(isq, Some(&device)).ok());

    let pipeline = loader.load_model_from_hf(
        None,
        token_source,
        &dtype,
        &device,
        false,
        mapper,
        isq,
        cache_config,
    )?;
    info!("Model loaded.");

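    // Pick the scheduler: use PagedAttention scheduling when the pipeline exposes a
    // KV-cache config, otherwise a fixed-size default scheduler. In both cases the
    // capacity is the largest requested concurrency level.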
    let scheduler_config = if cache_config.is_some() {
        // Handle case where we may have device mapping
        if let Some(ref cache_config) = pipeline.lock().await.get_metadata().cache_config {
            SchedulerConfig::PagedAttentionMeta {
                max_num_seqs: *args.concurrency.as_ref().unwrap().iter().max().unwrap(),
                config: cache_config.clone(),
            }
        } else {
            SchedulerConfig::DefaultScheduler {
                method: DefaultSchedulerMethod::Fixed(
                    (*args.concurrency.as_ref().unwrap().iter().max().unwrap())
                        .try_into()
                        .unwrap(),
                ),
            }
        }
    } else {
        SchedulerConfig::DefaultScheduler {
            method: DefaultSchedulerMethod::Fixed(
                (*args.concurrency.as_ref().unwrap().iter().max().unwrap())
                    .try_into()
                    .unwrap(),
            ),
        }
    };
    let mistralrs = MistralRsBuilder::new(pipeline, scheduler_config, false, None)
        .with_no_prefix_cache(true)
        .with_disable_eos_stop(true)
        .build()
        .await;

    info!("Starting warmup run.");
    warmup_run(mistralrs.clone()).await;
    info!("Finished warmup run.");
    info!("Starting benchmarks.");

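    // For each requested concurrency level, run the text-generation (tg) test driven by a
    // short completion prompt, then the prompt-processing (pp) test driven by `n_prompt`
    // raw token ids (generating a single token), and finally print the results table.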
    for concurrency in args.concurrency.as_ref().unwrap() {
        let mut results = vec![];
        if args.n_gen > 0 {
            let r = run_bench(
                mistralrs.clone(),
                RequestMessage::Completion {
                    text: "Rust".to_string(),
                    echo_prompt: false,
                    best_of: None,
                },
                args.n_gen - 1,
                *concurrency,
                args.repetitions,
                TestName::Gen(args.n_gen),
            )
            .await?;
            results.push(r);
        }

        if args.n_prompt > 0 {
            let tks = (1000..1000 + args.n_prompt as u32).collect();
            let r = run_bench(
                mistralrs.clone(),
                RequestMessage::CompletionTokens(tks),
                1,
                *concurrency,
                args.repetitions,
                TestName::Prompt(args.n_prompt),
            )
            .await?;

            results.push(r);
        }

        print_usage(&model_name, &device, results);
    }

    Ok(())
}