mistralrs_bench/main.rs

use candle_core::Device;
use clap::Parser;
use cli_table::{format::Justify, print_stdout, Cell, CellStruct, Style, Table};
use mistralrs_core::{
    get_auto_device_map_params, get_model_dtype, initialize_logging, paged_attn_supported,
    parse_isq_value, Constraint, DefaultSchedulerMethod, DeviceLayerMapMetadata, DeviceMapMetadata,
    DeviceMapSetting, DrySamplingParams, Loader, LoaderBuilder, MemoryGpuConfig, MistralRs,
    MistralRsBuilder, ModelSelected, NormalRequest, PagedAttentionConfig, PagedCacheType, Request,
    RequestMessage, Response, SamplingParams, SchedulerConfig, TokenSource, Usage,
};
use std::sync::Arc;
use std::{fmt::Display, num::NonZeroUsize};
use tokio::sync::mpsc::channel;
use tracing::info;

enum TestName {
    Prompt(usize),
    Gen(usize),
}

impl Display for TestName {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let name = match self {
            TestName::Prompt(n) => format!("pp {n}"),
            TestName::Gen(n) => format!("tg {n}"),
        };
        write!(f, "{name}")
    }
}

struct BenchResult {
    usages: Vec<Usage>,
    concurrency: usize,
    test_name: TestName,
}

struct UncertainTokSec {
    mean: f32,
    std_dev: f32,
}

impl Display for UncertainTokSec {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "{:.3}±{:.3}", self.mean, self.std_dev)
    }
}

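/// Sends `concurrency` identical requests per repetition, blocks on each response, and collects
/// the per-request `Usage` statistics for later aggregation.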
fn run_bench(
    mistralrs: Arc<MistralRs>,
    prompt: RequestMessage,
    n_gen: usize,
    concurrency: usize,
    repetitions: usize,
    test_name: TestName,
) -> anyhow::Result<BenchResult> {
    let sampling_params = SamplingParams {
        temperature: Some(0.1),
        top_k: Some(32),
        top_p: Some(0.1),
        min_p: Some(0.05),
        top_n_logprobs: 0,
        frequency_penalty: Some(0.1),
        presence_penalty: Some(0.1),
        max_len: Some(n_gen),
        stop_toks: None,
        logits_bias: None,
        n_choices: 1,
        dry_params: Some(DrySamplingParams::default()),
    };
    let sender = mistralrs.get_sender(None).unwrap();
    let (tx, mut rx) = channel(10_000);

    let req = Request::Normal(Box::new(NormalRequest {
        id: mistralrs.next_request_id(),
        messages: prompt,
        sampling_params: sampling_params.clone(),
        response: tx,
        return_logprobs: false,
        is_streaming: false,
        constraint: Constraint::None,
        suffix: None,
        tools: None,
        tool_choice: None,
        logits_processors: None,
        return_raw_logits: false,
        web_search_options: None,
        model_id: None,
    }));

    let mut usages = Vec::new();

    for _ in 0..repetitions {
        for _ in 0..concurrency {
            if sender.blocking_send(req.clone()).is_err() {
                eprintln!("Receiver disconnected");
            }
        }
        for _ in 0..concurrency {
            match rx.blocking_recv() {
                Some(r) => match r {
                    Response::InternalError(e) => {
                        unreachable!("Got an internal error: {e:?}");
                    }
                    Response::ModelError(e, resp) => {
                        unreachable!("Got a model error: {e:?}, response: {resp:?}");
                    }
                    Response::ValidationError(e) => {
                        unreachable!("Got a validation error: {e:?}");
                    }
                    Response::Done(res) => {
                        usages.push(res.usage);
                    }
                    Response::Chunk(_) => unreachable!(),
                    Response::CompletionModelError(_, _) => unreachable!(),
                    Response::CompletionDone(res) => {
                        usages.push(res.usage);
                    }
                    Response::CompletionChunk(_) => unreachable!(),
                    Response::ImageGeneration(_) => unreachable!(),
                    Response::Speech { .. } => unreachable!(),
                    Response::Raw { .. } => unreachable!(),
                },
                None => unreachable!("Expected a Done response, got None"),
            }
        }
    }

    Ok(BenchResult {
        usages,
        concurrency,
        test_name,
    })
}

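/// Mean and standard deviation of tokens/sec across the collected usages (prompt t/s for
/// `Prompt` tests, completion t/s for `Gen` tests).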
fn get_tok_s(result: &BenchResult) -> UncertainTokSec {
    let ts_measurements = match result.test_name {
        TestName::Prompt(_) => result
            .usages
            .iter()
            .map(|u| u.avg_prompt_tok_per_sec)
            .collect::<Vec<_>>(),
        TestName::Gen(_) => result
            .usages
            .iter()
            .map(|u| u.avg_compl_tok_per_sec)
            .collect::<Vec<_>>(),
    };
    // Calculate uncertainty
    let mean = ts_measurements.iter().sum::<f32>() / ts_measurements.len() as f32;
    let variance = ts_measurements
        .iter()
        .map(|e| (mean - e).powf(2.))
        .sum::<f32>()
        / ts_measurements.len() as f32;
    let std_dev = variance.sqrt();
    UncertainTokSec { mean, std_dev }
}

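/// Mean and standard deviation of milliseconds per token, computed as 1000 / tokens-per-second
/// for each collected usage.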
fn get_ms_tok(result: &BenchResult) -> UncertainTokSec {
    let ms_tok_measurements = match result.test_name {
        TestName::Prompt(_) => result
            .usages
            .iter()
            .map(|u| 1000. / u.avg_prompt_tok_per_sec)
            .collect::<Vec<_>>(),
        TestName::Gen(_) => result
            .usages
            .iter()
            .map(|u| 1000. / u.avg_compl_tok_per_sec)
            .collect::<Vec<_>>(),
    };
    // Calculate uncertainty
    let mean = ms_tok_measurements.iter().sum::<f32>() / ms_tok_measurements.len() as f32;
    let variance = ms_tok_measurements
        .iter()
        .map(|e| (mean - e).powf(2.))
        .sum::<f32>()
        / ms_tok_measurements.len() as f32;
    let std_dev = variance.sqrt();
    UncertainTokSec { mean, std_dev }
}

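/// Prints the benchmark results as a table on stdout: model, backend, test, t/s, ms/t,
/// concurrency, and aggregate throughput/s.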
fn print_usage(model: &str, device: &Device, results: Vec<BenchResult>) {
    let backend = match device {
        Device::Cpu => "CPU",
        Device::Cuda(_) => "CUDA",
        Device::Metal(_) => "Metal",
    };
    let results: Vec<Vec<CellStruct>> = results
        .into_iter()
        .map(|r| {
            vec![
                model.cell(),
                backend.cell(),
                r.test_name.to_string().cell(),
                get_tok_s(&r).cell().justify(Justify::Right),
                get_ms_tok(&r).cell().justify(Justify::Right),
                r.concurrency.cell().justify(Justify::Right),
                (get_tok_s(&r).mean * r.concurrency as f32)
                    .cell()
                    .justify(Justify::Right),
            ]
        })
        .collect();

    let table = results
        .table()
        .title(vec![
            "model".cell().bold(true),
            // "size".cell().bold(true),
            // "params".cell().bold(true),
            "backend".cell().bold(true),
            // "ngl".cell().bold(true),
            "test".cell().bold(true),
            "t/s".cell().bold(true),
            "ms/t".cell().bold(true),
            "concurrency".cell().bold(true),
            "throughput/s".cell().bold(true),
        ])
        .bold(true);
    print_stdout(table).expect("print table");
}

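/// Runs a single short completion request so the pipeline is warmed up before the timed benchmarks.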
fn warmup_run(mistralrs: Arc<MistralRs>) {
    let sampling_params = SamplingParams {
        temperature: Some(0.1),
        top_k: Some(32),
        top_p: Some(0.1),
        min_p: Some(0.05),
        top_n_logprobs: 0,
        frequency_penalty: Some(0.1),
        presence_penalty: Some(0.1),
        max_len: Some(5),
        stop_toks: None,
        logits_bias: None,
        n_choices: 1,
        dry_params: Some(DrySamplingParams::default()),
    };
    let sender = mistralrs.get_sender(None).unwrap();
    let (tx, mut rx) = channel(10_000);

    let req = Request::Normal(Box::new(NormalRequest {
        id: mistralrs.next_request_id(),
        messages: RequestMessage::Completion {
            text: "Hello!".to_string(),
            echo_prompt: false,
            best_of: None,
        },
        sampling_params: sampling_params.clone(),
        response: tx,
        return_logprobs: false,
        is_streaming: false,
        constraint: Constraint::None,
        suffix: None,
        tools: None,
        tool_choice: None,
        logits_processors: None,
        return_raw_logits: false,
        web_search_options: None,
        model_id: None,
    }));

    if sender.blocking_send(req.clone()).is_err() {
        eprintln!("Receiver disconnected");
    }

    let _ = rx.blocking_recv();
}

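/// Parses the `--pa-cache-type` value via `PagedCacheType`'s `FromStr` implementation.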
fn parse_cache_type(s: &str) -> Result<PagedCacheType, String> {
    s.parse()
}

#[derive(Parser)]
#[command(version, about, long_about = None)]
struct Args {
    /// Model
    #[clap(subcommand)]
    model: ModelSelected,

    /// Integer seed to ensure reproducible random number generation.
    #[arg(short, long)]
    seed: Option<u64>,

    /// Number of prompt tokens to run.
    #[arg(long, short = 'p', default_value_t = 512)]
    n_prompt: usize,

    /// Number of generation tokens to run.
    #[arg(long, short = 'g', default_value_t = 128)]
    n_gen: usize,

    /// Number of concurrent requests to run. Default is 1.
    #[clap(short, long, value_parser, value_delimiter = ',')]
    concurrency: Option<Vec<usize>>,

    /// Number of times to repeat each test.
    #[arg(long, short, default_value_t = 5)]
    repetitions: usize,

    /// NOTE: This can be omitted to use automatic device mapping!
    /// Number of device layers to load and run on GPU(s). All others will be on the CPU.
    /// If one GPU is used, then this value should be an integer. Otherwise, it follows the pattern
    /// ORD:NUM;... where ORD is a unique device ordinal and NUM is the number of layers for that device.
    #[arg(short, long, value_parser, value_delimiter = ';')]
    num_device_layers: Option<Vec<String>>,

    /// In-situ quantization to apply.
    #[arg(long = "isq")]
    in_situ_quant: Option<String>,

    /// GPU memory to allocate for KV cache with PagedAttention in MBs.
    /// PagedAttention is supported on CUDA and Metal. It is automatically activated on CUDA but not on Metal.
    /// The priority is as follows: `pa-ctxt-len` > `pa-gpu-mem-usage` > `pa-gpu-mem`.
    #[arg(long = "pa-gpu-mem")]
    paged_attn_gpu_mem: Option<usize>,

    /// Percentage of GPU memory to utilize after allocation of KV cache with PagedAttention, from 0 to 1.
    /// If this is not set and the device is CUDA, it will default to `0.9`.
    /// PagedAttention is supported on CUDA and Metal. It is automatically activated on CUDA but not on Metal.
    /// The priority is as follows: `pa-ctxt-len` > `pa-gpu-mem-usage` > `pa-gpu-mem`.
    #[arg(long = "pa-gpu-mem-usage")]
    paged_attn_gpu_mem_usage: Option<f32>,

    /// Total context length to allocate the KV cache for (total number of tokens which the KV cache can hold).
    /// PagedAttention is supported on CUDA and Metal. It is automatically activated on CUDA but not on Metal.
    /// The priority is as follows: `pa-ctxt-len` > `pa-gpu-mem-usage` > `pa-gpu-mem`.
    /// This is the default setting, and it defaults to the `max-seq-len` specified after the model type.
    #[arg(long = "pa-ctxt-len")]
    paged_ctxt_len: Option<usize>,

    /// PagedAttention KV cache type (auto or f8e4m3).
    /// Defaults to `auto`.
    #[arg(long = "pa-cache-type", value_parser = parse_cache_type)]
    cache_type: Option<PagedCacheType>,

    /// Block size (number of tokens per block) for PagedAttention. If this is not set and the device is CUDA, it will default to 32.
    /// PagedAttention is supported on CUDA and Metal. It is automatically activated on CUDA but not on Metal.
    #[arg(long = "pa-blk-size")]
    paged_attn_block_size: Option<usize>,

    /// Disable PagedAttention on CUDA. Because PagedAttention is already disabled on Metal, this is only applicable on CUDA.
    #[arg(long = "no-paged-attn", default_value_t = false)]
    no_paged_attn: bool,

    /// Enable PagedAttention on Metal. Because PagedAttention is already enabled on CUDA, this is only applicable on Metal.
    #[arg(long = "paged-attn", default_value_t = false)]
    paged_attn: bool,

    /// Number of tokens to batch the prompt step into. This can help with OOM errors during the prompt step, but reduces performance.
    #[arg(long = "prompt-batchsize")]
    prompt_chunksize: Option<usize>,
}

354
355#[tokio::main]
356async fn main() -> anyhow::Result<()> {
357    let mut args = Args::parse();
358    initialize_logging();
359
360    args.concurrency = Some(args.concurrency.unwrap_or(vec![1]));
361
362    let prompt_chunksize = match args.prompt_chunksize {
363        Some(0) => {
364            anyhow::bail!("`prompt_chunksize` must be a strictly positive integer, got 0.",)
365        }
366        Some(x) => Some(NonZeroUsize::new(x).unwrap()),
367        None => None,
368    };
369
370    let dtype = get_model_dtype(&args.model)?;
371    let auto_device_map_params = get_auto_device_map_params(&args.model)?;
372
373    let max_seq_len = auto_device_map_params.max_seq_len();
374
375    let loader: Box<dyn Loader> = LoaderBuilder::new(args.model)
376        .with_prompt_chunksize(prompt_chunksize)
377        .build()?;
378    let model_name = loader.get_id();
379
380    #[cfg(feature = "metal")]
381    let device = Device::new_metal(0)?;
382    #[cfg(not(feature = "metal"))]
383    let device = if mistralrs_core::distributed::use_nccl() {
384        Device::Cpu
385    } else {
386        Device::cuda_if_available(0)?
387    };
388
389    if let Some(seed) = args.seed {
390        device.set_seed(seed)?;
391    }
392
393    let token_source = TokenSource::CacheToken;
394    info!(
395        "avx: {}, neon: {}, simd128: {}, f16c: {}",
396        candle_core::utils::with_avx(),
397        candle_core::utils::with_neon(),
398        candle_core::utils::with_simd128(),
399        candle_core::utils::with_f16c()
400    );
401    info!("Sampling method: penalties -> temperature -> topk -> topp -> minp -> multinomial");
402    info!("Model kind is: {}", loader.get_kind().to_string());
403
404    // Parse device mapper
405    let mapper = if let Some(device_layers) = args.num_device_layers {
406        if device_layers.len() == 1 && device_layers[0].parse::<usize>().is_ok() {
407            let layers = device_layers[0].parse::<usize>().unwrap();
408            DeviceMapSetting::Map(DeviceMapMetadata::from_num_device_layers(vec![
409                DeviceLayerMapMetadata { ordinal: 0, layers },
410            ]))
411        } else {
412            let mut mapping = Vec::new();
413            for layer in device_layers {
414                let split = layer.splitn(2, ':').collect::<Vec<_>>();
415                if split.len() < 2 {
416                    panic!("Expected layer to be of format ORD:NUM, got {layer}");
417                }
418                let ord = split[0]
419                    .parse::<usize>()
420                    .unwrap_or_else(|_| panic!("Failed to parse {} as integer.", split[0]));
421                let num = split[1]
422                    .parse::<usize>()
423                    .unwrap_or_else(|_| panic!("Failed to parse {} as integer.", split[1]));
424                for DeviceLayerMapMetadata { ordinal, layers: _ } in &mapping {
425                    if *ordinal == ord {
426                        panic!("Duplicate ordinal {ord}");
427                    }
428                }
429                mapping.push(DeviceLayerMapMetadata {
430                    ordinal: ord,
431                    layers: num,
432                });
433            }
434            DeviceMapSetting::Map(DeviceMapMetadata::from_num_device_layers(mapping))
435        }
436    } else {
437        DeviceMapSetting::Auto(auto_device_map_params)
438    };
439
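    // PagedAttention is on by default for CUDA (opt out with `--no-paged-attn`), opt-in for Metal
    // (`--paged-attn`), and unavailable on other backends.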
    let no_paged_attn = if device.is_cuda() || mistralrs_core::distributed::use_nccl() {
        args.no_paged_attn
    } else if device.is_metal() {
        !args.paged_attn
    } else {
        true
    };

    // Allocate 0.5 GB of CPU memory just as a placeholder.
    // Nothing happens here as we have no `swap_out`, see `_preempt_by_swap`.
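    // Build the PagedAttention KV-cache config from the sizing flags; when none are given, the
    // cache is sized to hold `max_seq_len` tokens of context.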
    let cache_config = match (
        args.paged_attn_block_size,
        args.paged_attn_gpu_mem,
        args.paged_attn_gpu_mem_usage,
        args.paged_ctxt_len,
        paged_attn_supported(),
        no_paged_attn,
    ) {
        (block_size, None, None, None, true, false) => Some(PagedAttentionConfig::new(
            block_size,
            512,
            MemoryGpuConfig::ContextSize(max_seq_len),
            args.cache_type.unwrap_or_default(),
        )?),
        (block_size, None, None, Some(ctxt), true, false) => Some(PagedAttentionConfig::new(
            block_size,
            512,
            MemoryGpuConfig::ContextSize(ctxt),
            args.cache_type.unwrap_or_default(),
        )?),
        (block_size, None, Some(f), None, true, false) => Some(PagedAttentionConfig::new(
            block_size,
            512,
            MemoryGpuConfig::Utilization(f),
            args.cache_type.unwrap_or_default(),
        )?),
        (block_size, Some(m), None, None, true, false) => Some(PagedAttentionConfig::new(
            block_size,
            512,
            MemoryGpuConfig::MbAmount(m),
            args.cache_type.unwrap_or_default(),
        )?),
        (block_size, Some(_m), Some(f), None, true, false) => {
            info!("Both memory size and usage were specified, defaulting to the usage value.");
            Some(PagedAttentionConfig::new(
                block_size,
                512,
                MemoryGpuConfig::Utilization(f),
                args.cache_type.unwrap_or_default(),
            )?)
        }
        (block_size, Some(_m), None, Some(ctxt), true, false) => {
            info!("Both memory size and context length were specified, defaulting to the context length value.");
            Some(PagedAttentionConfig::new(
                block_size,
                512,
                MemoryGpuConfig::ContextSize(ctxt),
                args.cache_type.unwrap_or_default(),
            )?)
        }
        (block_size, None, Some(f), Some(_ctxt), true, false) => {
            info!("Both context length and usage were specified, defaulting to the usage value.");
            Some(PagedAttentionConfig::new(
                block_size,
                512,
                MemoryGpuConfig::Utilization(f),
                args.cache_type.unwrap_or_default(),
            )?)
        }
        (_, _, _, _, _, _) => None,
    };
511
512    let isq = args
513        .in_situ_quant
514        .as_ref()
515        .and_then(|isq| parse_isq_value(isq, Some(&device)).ok());
516
517    let pipeline = loader.load_model_from_hf(
518        None,
519        token_source,
520        &dtype,
521        &device,
522        false,
523        mapper,
524        isq,
525        cache_config,
526    )?;
527    info!("Model loaded.");
528
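    // Use the PagedAttention scheduler when a KV-cache config was built and survived device
    // mapping; otherwise fall back to the fixed-size default scheduler, sized to the largest
    // requested concurrency.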
    let scheduler_config = if cache_config.is_some() {
        // Handle case where we may have device mapping
        if let Some(ref cache_config) = pipeline.blocking_lock().get_metadata().cache_config {
            SchedulerConfig::PagedAttentionMeta {
                max_num_seqs: *args.concurrency.as_ref().unwrap().iter().max().unwrap(),
                config: cache_config.clone(),
            }
        } else {
            SchedulerConfig::DefaultScheduler {
                method: DefaultSchedulerMethod::Fixed(
                    (*args.concurrency.as_ref().unwrap().iter().max().unwrap())
                        .try_into()
                        .unwrap(),
                ),
            }
        }
    } else {
        SchedulerConfig::DefaultScheduler {
            method: DefaultSchedulerMethod::Fixed(
                (*args.concurrency.as_ref().unwrap().iter().max().unwrap())
                    .try_into()
                    .unwrap(),
            ),
        }
    };
    let mistralrs = MistralRsBuilder::new(pipeline, scheduler_config, false, None)
        .with_no_prefix_cache(true)
        .with_disable_eos_stop(true)
        .build()
        .await;

    info!("Starting warmup run.");
    warmup_run(mistralrs.clone());
    info!("Finished warmup run.");
    info!("Starting benchmarks.");

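    // For each requested concurrency level, run the generation (tg) benchmark followed by the
    // prompt-processing (pp) benchmark, then print the results table.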
    for concurrency in args.concurrency.as_ref().unwrap() {
        let mut results = vec![];
        if args.n_gen > 0 {
            let r = run_bench(
                mistralrs.clone(),
                RequestMessage::Completion {
                    text: "Rust".to_string(),
                    echo_prompt: false,
                    best_of: None,
                },
                args.n_gen - 1,
                *concurrency,
                args.repetitions,
                TestName::Gen(args.n_gen),
            )?;
            results.push(r);
        }

        if args.n_prompt > 0 {
            let tks = (1000..1000 + args.n_prompt as u32).collect();
            let r = run_bench(
                mistralrs.clone(),
                RequestMessage::CompletionTokens(tks),
                1,
                *concurrency,
                args.repetitions,
                TestName::Prompt(args.n_prompt),
            )?;

            results.push(r);
        }

        print_usage(&model_name, &device, results);
    }

    Ok(())
}