mistralrs_bench/main.rs

use candle_core::Device;
use clap::Parser;
use cli_table::{format::Justify, print_stdout, Cell, CellStruct, Style, Table};
use mistralrs_core::{
    get_auto_device_map_params, get_model_dtype, initialize_logging, paged_attn_supported,
    parse_isq_value, Constraint, DefaultSchedulerMethod, DeviceLayerMapMetadata, DeviceMapMetadata,
    DeviceMapSetting, DrySamplingParams, IsqType, Loader, LoaderBuilder, MemoryGpuConfig,
    MistralRs, MistralRsBuilder, ModelSelected, NormalRequest, PagedAttentionConfig, Request,
    RequestMessage, Response, SamplingParams, SchedulerConfig, TokenSource, Usage,
};
use std::sync::Arc;
use std::{fmt::Display, num::NonZeroUsize};
use tokio::sync::mpsc::channel;
use tracing::{info, warn};

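/// Which benchmark a result belongs to: prompt processing (`pp n`) or text generation (`tg n`).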
enum TestName {
    Prompt(usize),
    Gen(usize),
}

impl Display for TestName {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let name = match self {
            TestName::Prompt(n) => format!("pp {}", n),
            TestName::Gen(n) => format!("tg {}", n),
        };
        write!(f, "{}", name)
    }
}

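/// Usage statistics collected for one test at a given concurrency level.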
struct BenchResult {
    usages: Vec<Usage>,
    concurrency: usize,
    test_name: TestName,
}

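/// A measurement (tokens/s or ms/token) reported as mean ± standard deviation.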
struct UncertainTokSec {
    mean: f32,
    std_dev: f32,
}

impl Display for UncertainTokSec {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "{:.3}±{:.3}", self.mean, self.std_dev)
    }
}

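/// Sends `concurrency` identical requests per repetition and collects the reported `Usage` of every response.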
fn run_bench(
    mistralrs: Arc<MistralRs>,
    prompt: RequestMessage,
    n_gen: usize,
    concurrency: usize,
    repetitions: usize,
    test_name: TestName,
) -> anyhow::Result<BenchResult> {
    let sampling_params = SamplingParams {
        temperature: Some(0.1),
        top_k: Some(32),
        top_p: Some(0.1),
        min_p: Some(0.05),
        top_n_logprobs: 0,
        frequency_penalty: Some(0.1),
        presence_penalty: Some(0.1),
        max_len: Some(n_gen),
        stop_toks: None,
        logits_bias: None,
        n_choices: 1,
        dry_params: Some(DrySamplingParams::default()),
    };
    let sender = mistralrs.get_sender().unwrap();
    let (tx, mut rx) = channel(10_000);

    let req = Request::Normal(NormalRequest {
        id: mistralrs.next_request_id(),
        messages: prompt,
        sampling_params: sampling_params.clone(),
        response: tx,
        return_logprobs: false,
        is_streaming: false,
        constraint: Constraint::None,
        suffix: None,
        tools: None,
        tool_choice: None,
        logits_processors: None,
        return_raw_logits: false,
        web_search_options: None,
    });

    let mut usages = Vec::new();

    for _ in 0..repetitions {
        for _ in 0..concurrency {
            sender
                .blocking_send(req.clone())
                .expect("Expected receiver.");
        }
        for _ in 0..concurrency {
            match rx.blocking_recv() {
                Some(r) => match r {
                    Response::InternalError(e) => {
                        unreachable!("Got an internal error: {e:?}");
                    }
                    Response::ModelError(e, resp) => {
                        unreachable!("Got a model error: {e:?}, response: {resp:?}");
                    }
                    Response::ValidationError(e) => {
                        unreachable!("Got a validation error: {e:?}");
                    }
                    Response::Done(res) => {
                        usages.push(res.usage);
                    }
                    Response::Chunk(_) => unreachable!(),
                    Response::CompletionModelError(_, _) => unreachable!(),
                    Response::CompletionDone(res) => {
                        usages.push(res.usage);
                    }
                    Response::CompletionChunk(_) => unreachable!(),
                    Response::ImageGeneration(_) => unreachable!(),
                    Response::Raw { .. } => unreachable!(),
                },
                None => unreachable!("Expected a Done response, got None"),
            }
        }
    }

    Ok(BenchResult {
        usages,
        concurrency,
        test_name,
    })
}

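/// Mean and standard deviation of tokens per second across all repetitions of a test.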
fn get_tok_s(result: &BenchResult) -> UncertainTokSec {
    let ts_measurements = match result.test_name {
        TestName::Prompt(_) => result
            .usages
            .iter()
            .map(|u| u.avg_prompt_tok_per_sec)
            .collect::<Vec<_>>(),
        TestName::Gen(_) => result
            .usages
            .iter()
            .map(|u| u.avg_compl_tok_per_sec)
            .collect::<Vec<_>>(),
    };
    // Calculate uncertainty
    let mean = ts_measurements.iter().sum::<f32>() / ts_measurements.len() as f32;
    let variance = ts_measurements
        .iter()
        .map(|e| (mean - e).powf(2.))
        .sum::<f32>()
        / ts_measurements.len() as f32;
    let std_dev = variance.sqrt();
    UncertainTokSec { mean, std_dev }
}

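/// Mean and standard deviation of milliseconds per token across all repetitions of a test.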
fn get_ms_tok(result: &BenchResult) -> UncertainTokSec {
    let ms_tok_measurements = match result.test_name {
        TestName::Prompt(_) => result
            .usages
            .iter()
            .map(|u| 1000. / u.avg_prompt_tok_per_sec)
            .collect::<Vec<_>>(),
        TestName::Gen(_) => result
            .usages
            .iter()
            .map(|u| 1000. / u.avg_compl_tok_per_sec)
            .collect::<Vec<_>>(),
    };
    // Calculate uncertainty
    let mean = ms_tok_measurements.iter().sum::<f32>() / ms_tok_measurements.len() as f32;
    let variance = ms_tok_measurements
        .iter()
        .map(|e| (mean - e).powf(2.))
        .sum::<f32>()
        / ms_tok_measurements.len() as f32;
    let std_dev = variance.sqrt();
    UncertainTokSec { mean, std_dev }
}

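/// Renders the collected results as a table on stdout.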
fn print_usage(model: &str, device: &Device, results: Vec<BenchResult>) {
    let backend = match device {
        Device::Cpu => "CPU",
        Device::Cuda(_) => "CUDA",
        Device::Metal(_) => "Metal",
    };
    let results: Vec<Vec<CellStruct>> = results
        .into_iter()
        .map(|r| {
            vec![
                model.cell(),
                backend.cell(),
                r.test_name.to_string().cell(),
                get_tok_s(&r).cell().justify(Justify::Right),
                get_ms_tok(&r).cell().justify(Justify::Right),
                r.concurrency.cell().justify(Justify::Right),
                (get_tok_s(&r).mean * r.concurrency as f32)
                    .cell()
                    .justify(Justify::Right),
            ]
        })
        .collect();

    let table = results
        .table()
        .title(vec![
            "model".cell().bold(true),
            // "size".cell().bold(true),
            // "params".cell().bold(true),
            "backend".cell().bold(true),
            // "ngl".cell().bold(true),
            "test".cell().bold(true),
            "t/s".cell().bold(true),
            "ms/t".cell().bold(true),
            "concurrency".cell().bold(true),
            "throughput/s".cell().bold(true),
        ])
        .bold(true);
    print_stdout(table).expect("print table");
}

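/// Runs a single short completion request so that one-time setup costs do not skew the benchmark timings.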
fn warmup_run(mistralrs: Arc<MistralRs>) {
    let sampling_params = SamplingParams {
        temperature: Some(0.1),
        top_k: Some(32),
        top_p: Some(0.1),
        min_p: Some(0.05),
        top_n_logprobs: 0,
        frequency_penalty: Some(0.1),
        presence_penalty: Some(0.1),
        max_len: Some(5),
        stop_toks: None,
        logits_bias: None,
        n_choices: 1,
        dry_params: Some(DrySamplingParams::default()),
    };
    let sender = mistralrs.get_sender().unwrap();
    let (tx, mut rx) = channel(10_000);

    let req = Request::Normal(NormalRequest {
        id: mistralrs.next_request_id(),
        messages: RequestMessage::Completion {
            text: "Hello!".to_string(),
            echo_prompt: false,
            best_of: None,
        },
        sampling_params: sampling_params.clone(),
        response: tx,
        return_logprobs: false,
        is_streaming: false,
        constraint: Constraint::None,
        suffix: None,
        tools: None,
        tool_choice: None,
        logits_processors: None,
        return_raw_logits: false,
        web_search_options: None,
    });

    sender
        .blocking_send(req.clone())
        .expect("Expected receiver.");

    let _ = rx.blocking_recv();
}

#[derive(Parser)]
#[command(version, about, long_about = None)]
struct Args {
    /// Model
    #[clap(subcommand)]
    model: ModelSelected,

    /// Integer seed to ensure reproducible random number generation.
    #[arg(short, long)]
    seed: Option<u64>,

    /// Number of prompt tokens to run.
    #[arg(long, short = 'p', default_value_t = 512)]
    n_prompt: usize,

    /// Number of generation tokens to run.
    #[arg(long, short = 'g', default_value_t = 128)]
    n_gen: usize,

    /// Number of concurrent requests to run. Defaults to 1.
    #[clap(short, long, value_parser, value_delimiter = ',')]
    concurrency: Option<Vec<usize>>,

    /// Number of times to repeat each test.
    #[arg(long, short, default_value_t = 5)]
    repetitions: usize,

    /// NOTE: This can be omitted to use automatic device mapping!
    /// Number of device layers to load and run on GPU(s). All others will be on the CPU.
    /// If one GPU is used, then this value should be an integer. Otherwise, it follows the pattern:
    /// ORD:NUM;... where ORD is a unique device ordinal and NUM is the number of layers for that device.
    #[arg(short, long, value_parser, value_delimiter = ';')]
    num_device_layers: Option<Vec<String>>,

    /// In-situ quantization to apply.
    #[arg(long = "isq", value_parser = parse_isq_value)]
    in_situ_quant: Option<IsqType>,

    /// GPU memory to allocate for KV cache with PagedAttention in MBs.
    /// PagedAttention is supported on CUDA and Metal. It is automatically activated on CUDA but not on Metal.
    /// The priority is as follows: `pa-ctxt-len` > `pa-gpu-mem-usage` > `pa-gpu-mem`.
    #[arg(long = "pa-gpu-mem")]
    paged_attn_gpu_mem: Option<usize>,

    /// Percentage of GPU memory to utilize after allocation of KV cache with PagedAttention, from 0 to 1.
    /// If this is not set and the device is CUDA, it will default to `0.9`.
    /// PagedAttention is supported on CUDA and Metal. It is automatically activated on CUDA but not on Metal.
    /// The priority is as follows: `pa-ctxt-len` > `pa-gpu-mem-usage` > `pa-gpu-mem`.
    #[arg(long = "pa-gpu-mem-usage")]
    paged_attn_gpu_mem_usage: Option<f32>,

    /// Total context length to allocate the KV cache for (total number of tokens which the KV cache can hold).
    /// PagedAttention is supported on CUDA and Metal. It is automatically activated on CUDA but not on Metal.
    /// The priority is as follows: `pa-ctxt-len` > `pa-gpu-mem-usage` > `pa-gpu-mem`.
    /// This is the default setting, and it defaults to the `max-seq-len` specified after the model type.
    #[arg(long = "pa-ctxt-len")]
    paged_ctxt_len: Option<usize>,

    /// Block size (number of tokens per block) for PagedAttention. If this is not set and the device is CUDA, it will default to 32.
    /// PagedAttention is supported on CUDA and Metal. It is automatically activated on CUDA but not on Metal.
    #[arg(long = "pa-blk-size")]
    paged_attn_block_size: Option<usize>,

    /// Disable PagedAttention on CUDA. Because PagedAttention is already disabled on Metal, this is only applicable on CUDA.
    #[arg(long = "no-paged-attn", default_value_t = false)]
    no_paged_attn: bool,

    /// Enable PagedAttention on Metal. Because PagedAttention is already enabled on CUDA, this is only applicable on Metal.
    #[arg(long = "paged-attn", default_value_t = false)]
    paged_attn: bool,

    /// Number of tokens to batch the prompt step into. This can help with OOM errors during the prompt step, but reduces performance.
    #[arg(long = "prompt-batchsize")]
    prompt_chunksize: Option<usize>,
}

fn main() -> anyhow::Result<()> {
    let mut args = Args::parse();
    initialize_logging();

    args.concurrency = Some(args.concurrency.unwrap_or(vec![1]));

    let use_flash_attn = mistralrs_core::using_flash_attn();

    let prompt_chunksize = match args.prompt_chunksize {
        Some(0) => {
            anyhow::bail!("`prompt_chunksize` must be a strictly positive integer, got 0.")
        }
        Some(x) => Some(NonZeroUsize::new(x).unwrap()),
        None => None,
    };

    let dtype = get_model_dtype(&args.model)?;
    let auto_device_map_params = get_auto_device_map_params(&args.model)?;

    let max_seq_len = auto_device_map_params.max_seq_len();

    let loader: Box<dyn Loader> = LoaderBuilder::new(args.model)
        .with_use_flash_attn(use_flash_attn)
        .with_prompt_chunksize(prompt_chunksize)
        .build()?;
    let model_name = loader.get_id();

    #[cfg(feature = "metal")]
    let device = Device::new_metal(0)?;
    #[cfg(not(feature = "metal"))]
    let device = if mistralrs_core::distributed::use_nccl() {
        Device::Cpu
    } else {
        Device::cuda_if_available(0)?
    };

    if let Some(seed) = args.seed {
        device.set_seed(seed)?;
    }

    let token_source = TokenSource::CacheToken;
    info!(
        "avx: {}, neon: {}, simd128: {}, f16c: {}",
        candle_core::utils::with_avx(),
        candle_core::utils::with_neon(),
        candle_core::utils::with_simd128(),
        candle_core::utils::with_f16c()
    );
    info!("Sampling method: penalties -> temperature -> topk -> topp -> minp -> multinomial");
    if use_flash_attn {
        info!("Using flash attention.");
    }
    if use_flash_attn && loader.get_kind().is_quantized() {
        warn!("Using flash attention with a quantized model has no effect!");
    }
    info!("Model kind is: {}", loader.get_kind().to_string());

    // Parse device mapper
    let mapper = if let Some(device_layers) = args.num_device_layers {
        if device_layers.len() == 1 && device_layers[0].parse::<usize>().is_ok() {
            let layers = device_layers[0].parse::<usize>().unwrap();
            DeviceMapSetting::Map(DeviceMapMetadata::from_num_device_layers(vec![
                DeviceLayerMapMetadata { ordinal: 0, layers },
            ]))
        } else {
            let mut mapping = Vec::new();
            for layer in device_layers {
                let split = layer.splitn(2, ':').collect::<Vec<_>>();
                if split.len() < 2 {
                    panic!("Expected layer to be of format ORD:NUM, got {layer}");
                }
                let ord = split[0]
                    .parse::<usize>()
                    .unwrap_or_else(|_| panic!("Failed to parse {} as integer.", split[0]));
                let num = split[1]
                    .parse::<usize>()
                    .unwrap_or_else(|_| panic!("Failed to parse {} as integer.", split[1]));
                for DeviceLayerMapMetadata { ordinal, layers: _ } in &mapping {
                    if *ordinal == ord {
                        panic!("Duplicate ordinal {ord}");
                    }
                }
                mapping.push(DeviceLayerMapMetadata {
                    ordinal: ord,
                    layers: num,
                });
            }
            DeviceMapSetting::Map(DeviceMapMetadata::from_num_device_layers(mapping))
        }
    } else {
        DeviceMapSetting::Auto(auto_device_map_params)
    };

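    // PagedAttention is on by default for CUDA (and NCCL), opt-in via `--paged-attn` on Metal, and unavailable on other devices.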
    let no_paged_attn = if device.is_cuda() || mistralrs_core::distributed::use_nccl() {
        args.no_paged_attn
    } else if device.is_metal() {
        !args.paged_attn
    } else {
        true
    };

    // Allocate 0.5 GB of CPU memory just as a placeholder.
    // Nothing happens here as we have no `swap_out`, see `_preempt_by_swap`.
    let cache_config = match (
        args.paged_attn_block_size,
        args.paged_attn_gpu_mem,
        args.paged_attn_gpu_mem_usage,
        args.paged_ctxt_len,
        paged_attn_supported(),
        no_paged_attn,
    ) {
        (block_size, None, None, None, true, false) => Some(PagedAttentionConfig::new(
            block_size,
            512,
            MemoryGpuConfig::ContextSize(max_seq_len),
        )?),
        (block_size, None, None, Some(ctxt), true, false) => Some(PagedAttentionConfig::new(
            block_size,
            512,
            MemoryGpuConfig::ContextSize(ctxt),
        )?),
        (block_size, None, Some(f), None, true, false) => Some(PagedAttentionConfig::new(
            block_size,
            512,
            MemoryGpuConfig::Utilization(f),
        )?),
        (block_size, Some(m), None, None, true, false) => Some(PagedAttentionConfig::new(
            block_size,
            512,
            MemoryGpuConfig::MbAmount(m),
        )?),
        (block_size, Some(_m), Some(f), None, true, false) => {
            info!("Both memory size and usage were specified, defaulting to the usage value.");
            Some(PagedAttentionConfig::new(
                block_size,
                512,
                MemoryGpuConfig::Utilization(f),
            )?)
        }
        (block_size, Some(_m), None, Some(ctxt), true, false) => {
            info!("Both memory size and context length were specified, defaulting to the context length value.");
            Some(PagedAttentionConfig::new(
                block_size,
                512,
                MemoryGpuConfig::ContextSize(ctxt),
            )?)
        }
        (block_size, None, Some(f), Some(_ctxt), true, false) => {
            info!("Both context length and usage were specified, defaulting to the usage value.");
            Some(PagedAttentionConfig::new(
                block_size,
                512,
                MemoryGpuConfig::Utilization(f),
            )?)
        }
        (_, _, _, _, _, _) => None,
    };

    let pipeline = loader.load_model_from_hf(
        None,
        token_source,
        &dtype,
        &device,
        false,
        mapper,
        args.in_situ_quant,
        cache_config,
    )?;
    info!("Model loaded.");

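    // Use the PagedAttention scheduler when a KV cache config was produced; otherwise fall back to
    // the default scheduler, sized to the largest requested concurrency.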
    let scheduler_config = if cache_config.is_some() {
        // Handle case where we may have device mapping
        if let Some(ref cache_config) = pipeline.blocking_lock().get_metadata().cache_config {
            SchedulerConfig::PagedAttentionMeta {
                max_num_seqs: *args.concurrency.as_ref().unwrap().iter().max().unwrap(),
                config: cache_config.clone(),
            }
        } else {
            SchedulerConfig::DefaultScheduler {
                method: DefaultSchedulerMethod::Fixed(
                    (*args.concurrency.as_ref().unwrap().iter().max().unwrap())
                        .try_into()
                        .unwrap(),
                ),
            }
        }
    } else {
        SchedulerConfig::DefaultScheduler {
            method: DefaultSchedulerMethod::Fixed(
                (*args.concurrency.as_ref().unwrap().iter().max().unwrap())
                    .try_into()
                    .unwrap(),
            ),
        }
    };
    let mistralrs = MistralRsBuilder::new(pipeline, scheduler_config, false, None)
        .with_no_prefix_cache(true)
        .with_disable_eos_stop(true)
        .build();

    info!("Starting warmup run.");
    warmup_run(mistralrs.clone());
    info!("Finished warmup run.");
    info!("Starting benchmarks.");

    for concurrency in args.concurrency.as_ref().unwrap() {
        let mut results = vec![];
        if args.n_gen > 0 {
            let r = run_bench(
                mistralrs.clone(),
                RequestMessage::Completion {
                    text: "Rust".to_string(),
                    echo_prompt: false,
                    best_of: None,
                },
                args.n_gen - 1,
                *concurrency,
                args.repetitions,
                TestName::Gen(args.n_gen),
            )?;
            results.push(r);
        }

        if args.n_prompt > 0 {
            let tks = (1000..1000 + args.n_prompt as u32).collect();
            let r = run_bench(
                mistralrs.clone(),
                RequestMessage::CompletionTokens(tks),
                1,
                *concurrency,
                args.repetitions,
                TestName::Prompt(args.n_prompt),
            )?;

            results.push(r);
        }

        print_usage(&model_name, &device, results);
    }

    Ok(())
}