mistralrs_bench/main.rs

use candle_core::Device;
use clap::Parser;
use cli_table::{format::Justify, print_stdout, Cell, CellStruct, Style, Table};
use mistralrs_core::{
    get_auto_device_map_params, get_model_dtype, initialize_logging, paged_attn_supported,
    parse_isq_value, Constraint, DefaultSchedulerMethod, DeviceLayerMapMetadata, DeviceMapMetadata,
    DeviceMapSetting, DrySamplingParams, IsqType, Loader, LoaderBuilder, MemoryGpuConfig,
    MistralRs, MistralRsBuilder, ModelSelected, NormalRequest, PagedAttentionConfig, Request,
    RequestMessage, Response, SamplingParams, SchedulerConfig, TokenSource, Usage,
};
use std::sync::Arc;
use std::{fmt::Display, num::NonZeroUsize};
use tokio::sync::mpsc::channel;
use tracing::{info, warn};

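/// Benchmark phase: prompt processing (`pp <n>`) or token generation (`tg <n>`).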
enum TestName {
    Prompt(usize),
    Gen(usize),
}

impl Display for TestName {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let name = match self {
            TestName::Prompt(n) => format!("pp {}", n),
            TestName::Gen(n) => format!("tg {}", n),
        };
        write!(f, "{}", name)
    }
}

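/// Collected `Usage` statistics for one test at a given concurrency level.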
struct BenchResult {
    usages: Vec<Usage>,
    concurrency: usize,
    test_name: TestName,
}

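/// A measurement with its spread, displayed as `mean±std_dev`.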
struct UncertainTokSec {
    mean: f32,
    std_dev: f32,
}

impl Display for UncertainTokSec {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "{:.3}±{:.3}", self.mean, self.std_dev)
    }
}

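/// Sends `concurrency` identical requests per repetition and blocks until every
/// response arrives, collecting the resulting `Usage` statistics.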
fn run_bench(
    mistralrs: Arc<MistralRs>,
    prompt: RequestMessage,
    n_gen: usize,
    concurrency: usize,
    repetitions: usize,
    test_name: TestName,
) -> anyhow::Result<BenchResult> {
    let sampling_params = SamplingParams {
        temperature: Some(0.1),
        top_k: Some(32),
        top_p: Some(0.1),
        min_p: Some(0.05),
        top_n_logprobs: 0,
        frequency_penalty: Some(0.1),
        presence_penalty: Some(0.1),
        max_len: Some(n_gen),
        stop_toks: None,
        logits_bias: None,
        n_choices: 1,
        dry_params: Some(DrySamplingParams::default()),
    };
    let sender = mistralrs.get_sender().unwrap();
    let (tx, mut rx) = channel(10_000);

    let req = Request::Normal(NormalRequest {
        id: mistralrs.next_request_id(),
        messages: prompt,
        sampling_params: sampling_params.clone(),
        response: tx,
        return_logprobs: false,
        is_streaming: false,
        constraint: Constraint::None,
        suffix: None,
        adapters: None,
        tools: None,
        tool_choice: None,
        logits_processors: None,
        return_raw_logits: false,
        web_search_options: None,
    });

    let mut usages = Vec::new();

    for _ in 0..repetitions {
        for _ in 0..concurrency {
            sender
                .blocking_send(req.clone())
                .expect("Expected receiver.");
        }
        for _ in 0..concurrency {
            match rx.blocking_recv() {
                Some(r) => match r {
                    Response::InternalError(e) => {
                        unreachable!("Got an internal error: {e:?}");
                    }
                    Response::ModelError(e, resp) => {
                        unreachable!("Got a model error: {e:?}, response: {resp:?}");
                    }
                    Response::ValidationError(e) => {
                        unreachable!("Got a validation error: {e:?}");
                    }
                    Response::Done(res) => {
                        usages.push(res.usage);
                    }
                    Response::Chunk(_) => unreachable!(),
                    Response::CompletionModelError(_, _) => unreachable!(),
                    Response::CompletionDone(res) => {
                        usages.push(res.usage);
                    }
                    Response::CompletionChunk(_) => unreachable!(),
                    Response::ImageGeneration(_) => unreachable!(),
                    Response::Raw { .. } => unreachable!(),
                },
                None => unreachable!("Expected a Done response, got None"),
            }
        }
    }

    Ok(BenchResult {
        usages,
        concurrency,
        test_name,
    })
}

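/// Mean and standard deviation of tokens/sec across all collected usages.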
fn get_tok_s(result: &BenchResult) -> UncertainTokSec {
    let ts_measurements = match result.test_name {
        TestName::Prompt(_) => result
            .usages
            .iter()
            .map(|u| u.avg_prompt_tok_per_sec)
            .collect::<Vec<_>>(),
        TestName::Gen(_) => result
            .usages
            .iter()
            .map(|u| u.avg_compl_tok_per_sec)
            .collect::<Vec<_>>(),
    };
    // Calculate uncertainty
    let mean = ts_measurements.iter().sum::<f32>() / ts_measurements.len() as f32;
    let variance = ts_measurements
        .iter()
        .map(|e| (mean - e).powf(2.))
        .sum::<f32>()
        / ts_measurements.len() as f32;
    let std_dev = variance.sqrt();
    UncertainTokSec { mean, std_dev }
}

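/// Mean and standard deviation of milliseconds per token across all collected usages.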
fn get_ms_tok(result: &BenchResult) -> UncertainTokSec {
    let ms_tok_measurements = match result.test_name {
        TestName::Prompt(_) => result
            .usages
            .iter()
            .map(|u| 1000. / u.avg_prompt_tok_per_sec)
            .collect::<Vec<_>>(),
        TestName::Gen(_) => result
            .usages
            .iter()
            .map(|u| 1000. / u.avg_compl_tok_per_sec)
            .collect::<Vec<_>>(),
    };
    // Calculate uncertainty
    let mean = ms_tok_measurements.iter().sum::<f32>() / ms_tok_measurements.len() as f32;
    let variance = ms_tok_measurements
        .iter()
        .map(|e| (mean - e).powf(2.))
        .sum::<f32>()
        / ms_tok_measurements.len() as f32;
    let std_dev = variance.sqrt();
    UncertainTokSec { mean, std_dev }
}

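/// Prints the results as a table: model, backend, test, t/s, ms/t, concurrency, throughput/s.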
fn print_usage(model: &str, device: &Device, results: Vec<BenchResult>) {
    let backend = match device {
        Device::Cpu => "CPU",
        Device::Cuda(_) => "CUDA",
        Device::Metal(_) => "Metal",
    };
    let results: Vec<Vec<CellStruct>> = results
        .into_iter()
        .map(|r| {
            vec![
                model.cell(),
                backend.cell(),
                r.test_name.to_string().cell(),
                get_tok_s(&r).cell().justify(Justify::Right),
                get_ms_tok(&r).cell().justify(Justify::Right),
                r.concurrency.cell().justify(Justify::Right),
                (get_tok_s(&r).mean * r.concurrency as f32)
                    .cell()
                    .justify(Justify::Right),
            ]
        })
        .collect();

    let table = results
        .table()
        .title(vec![
            "model".cell().bold(true),
            // "size".cell().bold(true),
            // "params".cell().bold(true),
            "backend".cell().bold(true),
            // "ngl".cell().bold(true),
            "test".cell().bold(true),
            "t/s".cell().bold(true),
            "ms/t".cell().bold(true),
            "concurrency".cell().bold(true),
            "throughput/s".cell().bold(true),
        ])
        .bold(true);
    print_stdout(table).expect("print table");
}

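/// Issues a single short completion request as a warmup before the timed benchmark runs.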
fn warmup_run(mistralrs: Arc<MistralRs>) {
    let sampling_params = SamplingParams {
        temperature: Some(0.1),
        top_k: Some(32),
        top_p: Some(0.1),
        min_p: Some(0.05),
        top_n_logprobs: 0,
        frequency_penalty: Some(0.1),
        presence_penalty: Some(0.1),
        max_len: Some(5),
        stop_toks: None,
        logits_bias: None,
        n_choices: 1,
        dry_params: Some(DrySamplingParams::default()),
    };
    let sender = mistralrs.get_sender().unwrap();
    let (tx, mut rx) = channel(10_000);

    let req = Request::Normal(NormalRequest {
        id: mistralrs.next_request_id(),
        messages: RequestMessage::Completion {
            text: "Hello!".to_string(),
            echo_prompt: false,
            best_of: None,
        },
        sampling_params: sampling_params.clone(),
        response: tx,
        return_logprobs: false,
        is_streaming: false,
        constraint: Constraint::None,
        suffix: None,
        adapters: None,
        tools: None,
        tool_choice: None,
        logits_processors: None,
        return_raw_logits: false,
        web_search_options: None,
    });

    sender
        .blocking_send(req.clone())
        .expect("Expected receiver.");

    let _ = rx.blocking_recv();
}

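// Command-line arguments for the benchmark harness.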
#[derive(Parser)]
#[command(version, about, long_about = None)]
struct Args {
    /// Model
    #[clap(subcommand)]
    model: ModelSelected,

    /// Integer seed to ensure reproducible random number generation.
    #[arg(short, long)]
    seed: Option<u64>,

    /// Number of prompt tokens to run.
    #[arg(long, short = 'p', default_value_t = 512)]
    n_prompt: usize,

    /// Number of generation tokens to run.
    #[arg(long, short = 'g', default_value_t = 128)]
    n_gen: usize,

    /// Number of concurrent requests to run. Default is 1.
    #[clap(short, long, value_parser, value_delimiter = ',')]
    concurrency: Option<Vec<usize>>,

    /// Number of times to repeat each test.
    #[arg(long, short, default_value_t = 5)]
    repetitions: usize,

    /// NOTE: This can be omitted to use automatic device mapping!
    /// Number of device layers to load and run on GPU(s). All others will be on the CPU.
    /// If one GPU is used, then this value should be an integer. Otherwise, it follows the pattern
    /// ORD:NUM;... where ORD is a unique device ordinal and NUM is the number of layers for that device.
    #[arg(short, long, value_parser, value_delimiter = ';')]
    num_device_layers: Option<Vec<String>>,

    /// In-situ quantization to apply.
    #[arg(long = "isq", value_parser = parse_isq_value)]
    in_situ_quant: Option<IsqType>,

    /// GPU memory to allocate for KV cache with PagedAttention in MBs.
    /// PagedAttention is supported on CUDA and Metal. It is automatically activated on CUDA but not on Metal.
    /// The priority is as follows: `pa-ctxt-len` > `pa-gpu-mem-usage` > `pa-gpu-mem`.
    #[arg(long = "pa-gpu-mem")]
    paged_attn_gpu_mem: Option<usize>,

    /// Percentage of GPU memory to utilize after allocation of KV cache with PagedAttention, from 0 to 1.
    /// If this is not set and the device is CUDA, it will default to `0.9`.
    /// PagedAttention is supported on CUDA and Metal. It is automatically activated on CUDA but not on Metal.
    /// The priority is as follows: `pa-ctxt-len` > `pa-gpu-mem-usage` > `pa-gpu-mem`.
    #[arg(long = "pa-gpu-mem-usage")]
    paged_attn_gpu_mem_usage: Option<f32>,

    /// Total context length to allocate the KV cache for (total number of tokens which the KV cache can hold).
    /// PagedAttention is supported on CUDA and Metal. It is automatically activated on CUDA but not on Metal.
    /// The priority is as follows: `pa-ctxt-len` > `pa-gpu-mem-usage` > `pa-gpu-mem`.
    /// This is the default setting, and it defaults to the `max-seq-len` specified after the model type.
    #[arg(long = "pa-ctxt-len")]
    paged_ctxt_len: Option<usize>,

    /// Block size (number of tokens per block) for PagedAttention. If this is not set and the device is CUDA, it will default to 32.
    /// PagedAttention is supported on CUDA and Metal. It is automatically activated on CUDA but not on Metal.
    #[arg(long = "pa-blk-size")]
    paged_attn_block_size: Option<usize>,

    /// Disable PagedAttention on CUDA. Because PagedAttention is already disabled on Metal, this is only applicable on CUDA.
    #[arg(long = "no-paged-attn", default_value_t = false)]
    no_paged_attn: bool,

    /// Enable PagedAttention on Metal. Because PagedAttention is already enabled on CUDA, this is only applicable on Metal.
    #[arg(long = "paged-attn", default_value_t = false)]
    paged_attn: bool,

    /// Number of tokens to batch the prompt step into. This can help with OOM errors during the prompt step, but reduces performance.
    #[arg(long = "prompt-batchsize")]
    prompt_chunksize: Option<usize>,
}

fn main() -> anyhow::Result<()> {
    let mut args = Args::parse();
    initialize_logging();

    args.concurrency = Some(args.concurrency.unwrap_or(vec![1]));

    let use_flash_attn = mistralrs_core::using_flash_attn();

    let prompt_chunksize = match args.prompt_chunksize {
        Some(0) => {
            anyhow::bail!("`prompt_chunksize` must be a strictly positive integer, got 0.")
        }
        Some(x) => Some(NonZeroUsize::new(x).unwrap()),
        None => None,
    };

    let dtype = get_model_dtype(&args.model)?;
    let auto_device_map_params = get_auto_device_map_params(&args.model)?;

    let max_seq_len = auto_device_map_params.max_seq_len();

    let loader: Box<dyn Loader> = LoaderBuilder::new(args.model)
        .with_use_flash_attn(use_flash_attn)
        .with_prompt_chunksize(prompt_chunksize)
        .build()?;
    let model_name = loader.get_id();

    #[cfg(feature = "metal")]
    let device = Device::new_metal(0)?;
    #[cfg(not(feature = "metal"))]
    let device = if mistralrs_core::distributed::use_nccl() {
        Device::Cpu
    } else {
        Device::cuda_if_available(0)?
    };

    if let Some(seed) = args.seed {
        device.set_seed(seed)?;
    }

    let token_source = TokenSource::CacheToken;
    info!(
        "avx: {}, neon: {}, simd128: {}, f16c: {}",
        candle_core::utils::with_avx(),
        candle_core::utils::with_neon(),
        candle_core::utils::with_simd128(),
        candle_core::utils::with_f16c()
    );
    info!("Sampling method: penalties -> temperature -> topk -> topp -> minp -> multinomial");
    if use_flash_attn {
        info!("Using flash attention.");
    }
    if use_flash_attn && loader.get_kind().is_quantized() {
        warn!("Using flash attention with a quantized model has no effect!")
    }
    info!("Model kind is: {}", loader.get_kind().to_string());

    // Parse device mapper
    let mapper = if let Some(device_layers) = args.num_device_layers {
        if device_layers.len() == 1 && device_layers[0].parse::<usize>().is_ok() {
            let layers = device_layers[0].parse::<usize>().unwrap();
            DeviceMapSetting::Map(DeviceMapMetadata::from_num_device_layers(vec![
                DeviceLayerMapMetadata { ordinal: 0, layers },
            ]))
        } else {
            let mut mapping = Vec::new();
            for layer in device_layers {
                let split = layer.splitn(2, ':').collect::<Vec<_>>();
                if split.len() < 2 {
                    panic!("Expected layer to be of format ORD:NUM, got {layer}");
                }
                let ord = split[0]
                    .parse::<usize>()
                    .unwrap_or_else(|_| panic!("Failed to parse {} as integer.", split[0]));
                let num = split[1]
                    .parse::<usize>()
                    .unwrap_or_else(|_| panic!("Failed to parse {} as integer.", split[1]));
                for DeviceLayerMapMetadata { ordinal, layers: _ } in &mapping {
                    if *ordinal == ord {
                        panic!("Duplicate ordinal {ord}");
                    }
                }
                mapping.push(DeviceLayerMapMetadata {
                    ordinal: ord,
                    layers: num,
                });
            }
            DeviceMapSetting::Map(DeviceMapMetadata::from_num_device_layers(mapping))
        }
    } else {
        DeviceMapSetting::Auto(auto_device_map_params)
    };

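    // Decide whether PagedAttention is disabled: on CUDA (or when NCCL is in use) honor
    // `--no-paged-attn`, on Metal require an explicit `--paged-attn` opt-in, and on any
    // other device it stays off.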
    let no_paged_attn = if device.is_cuda() || mistralrs_core::distributed::use_nccl() {
        args.no_paged_attn
    } else if device.is_metal() {
        !args.paged_attn
    } else {
        true
    };

    // Allocate 0.5 GB of CPU memory just as a placeholder.
    // Nothing happens here as we have no `swap_out`, see `_preempt_by_swap`.
    let cache_config = match (
        args.paged_attn_block_size,
        args.paged_attn_gpu_mem,
        args.paged_attn_gpu_mem_usage,
        args.paged_ctxt_len,
        paged_attn_supported(),
        no_paged_attn,
    ) {
        (block_size, None, None, None, true, false) => Some(PagedAttentionConfig::new(
            block_size,
            512,
            MemoryGpuConfig::ContextSize(max_seq_len),
        )?),
        (block_size, None, None, Some(ctxt), true, false) => Some(PagedAttentionConfig::new(
            block_size,
            512,
            MemoryGpuConfig::ContextSize(ctxt),
        )?),
        (block_size, None, Some(f), None, true, false) => Some(PagedAttentionConfig::new(
            block_size,
            512,
            MemoryGpuConfig::Utilization(f),
        )?),
        (block_size, Some(m), None, None, true, false) => Some(PagedAttentionConfig::new(
            block_size,
            512,
            MemoryGpuConfig::MbAmount(m),
        )?),
        (block_size, Some(_m), Some(f), None, true, false) => {
            info!("Both memory size and usage were specified, defaulting to the usage value.");
            Some(PagedAttentionConfig::new(
                block_size,
                512,
                MemoryGpuConfig::Utilization(f),
            )?)
        }
        (block_size, Some(_m), None, Some(ctxt), true, false) => {
            info!("Both memory size and context length were specified, defaulting to the context length value.");
            Some(PagedAttentionConfig::new(
                block_size,
                512,
                MemoryGpuConfig::ContextSize(ctxt),
            )?)
        }
        (block_size, None, Some(f), Some(_ctxt), true, false) => {
            info!("Both context length and usage were specified, defaulting to the usage value.");
            Some(PagedAttentionConfig::new(
                block_size,
                512,
                MemoryGpuConfig::Utilization(f),
            )?)
        }
        (_, _, _, _, _, _) => None,
    };

    let pipeline = loader.load_model_from_hf(
        None,
        token_source,
        &dtype,
        &device,
        false,
        mapper,
        args.in_situ_quant,
        cache_config,
    )?;
    info!("Model loaded.");

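    // Choose the scheduler: PagedAttention metadata when a KV cache config was built
    // (and survived device mapping), otherwise a fixed-capacity default scheduler sized
    // to the largest requested concurrency.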
    let scheduler_config = if cache_config.is_some() {
        // Handle case where we may have device mapping
        if let Some(ref cache_config) = pipeline.blocking_lock().get_metadata().cache_config {
            SchedulerConfig::PagedAttentionMeta {
                max_num_seqs: *args.concurrency.as_ref().unwrap().iter().max().unwrap(),
                config: cache_config.clone(),
            }
        } else {
            SchedulerConfig::DefaultScheduler {
                method: DefaultSchedulerMethod::Fixed(
                    (*args.concurrency.as_ref().unwrap().iter().max().unwrap())
                        .try_into()
                        .unwrap(),
                ),
            }
        }
    } else {
        SchedulerConfig::DefaultScheduler {
            method: DefaultSchedulerMethod::Fixed(
                (*args.concurrency.as_ref().unwrap().iter().max().unwrap())
                    .try_into()
                    .unwrap(),
            ),
        }
    };
    let mistralrs = MistralRsBuilder::new(pipeline, scheduler_config, false, None)
        .with_no_prefix_cache(true)
        .with_disable_eos_stop(true)
        .build();

    info!("Starting warmup run.");
    warmup_run(mistralrs.clone());
    info!("Finished warmup run.");
    info!("Starting benchmarks.");

    for concurrency in args.concurrency.as_ref().unwrap() {
        let mut results = vec![];
        if args.n_gen > 0 {
            let r = run_bench(
                mistralrs.clone(),
                RequestMessage::Completion {
                    text: "Rust".to_string(),
                    echo_prompt: false,
                    best_of: None,
                },
                args.n_gen - 1,
                *concurrency,
                args.repetitions,
                TestName::Gen(args.n_gen),
            )?;
            results.push(r);
        }

        if args.n_prompt > 0 {
            let tks = (1000..1000 + args.n_prompt as u32).collect();
            let r = run_bench(
                mistralrs.clone(),
                RequestMessage::CompletionTokens(tks),
                1,
                *concurrency,
                args.repetitions,
                TestName::Prompt(args.n_prompt),
            )?;

            results.push(r);
        }

        print_usage(&model_name, &device, results);
    }

    Ok(())
}