Add more fused moe benchmark utilities (#2314)

This commit is contained in:
Lianmin Zheng
2024-12-02 04:26:55 -08:00
committed by GitHub
parent 18108abe5d
commit 33deca81b5
3 changed files with 294 additions and 23 deletions

View File

@@ -1,5 +1,6 @@
# Adapted from https://github.com/vllm-project/vllm/blob/main/benchmarks/kernels/benchmark_moe.py
import argparse
import json
import time
from datetime import datetime
from typing import Any, Dict, List, Tuple, TypedDict
@@ -9,10 +10,14 @@ import torch
import triton
from ray.experimental.tqdm_ray import tqdm
from transformers import AutoConfig
from vllm.platforms import current_platform
from vllm.utils import FlexibleArgumentParser
from sglang.srt.layers.fused_moe_triton.fused_moe import *
from sglang.srt.layers.fused_moe_triton.fused_moe import (
fused_moe,
get_config_dtype_str,
get_config_file_name,
get_default_config,
get_moe_configs,
)
class BenchmarkConfig(TypedDict):
@@ -92,7 +97,7 @@ def benchmark_config(
input_gating.copy_(gating_output[i])
def run():
from sglang.srt.layers.fused_moe_triton.fused_moe import override_config
from sglang.srt.layers.fused_moe_triton import override_config
with override_config(config):
fused_moe(
@@ -174,7 +179,7 @@ class BenchmarkWorker:
def __init__(self, seed: int) -> None:
torch.set_default_device("cuda")
current_platform.seed_everything(seed)
torch.cuda.manual_seed_all(0)
self.seed = seed
def benchmark(
@@ -188,7 +193,7 @@ class BenchmarkWorker:
use_fp8_w8a8: bool,
use_int8_w8a16: bool,
) -> Tuple[Dict[str, int], float]:
current_platform.seed_everything(self.seed)
torch.cuda.manual_seed_all(0)
dtype_str = get_config_dtype_str(
dtype, use_int8_w8a16=use_int8_w8a16, use_fp8_w8a8=use_fp8_w8a8
)
@@ -319,7 +324,7 @@ def main(args: argparse.Namespace):
intermediate_size = config.moe_intermediate_size
shard_intermediate_size = 2 * intermediate_size // args.tp_size
else:
# Default: Mixtral.
# Default: Mixtral
E = config.num_local_experts
topk = config.num_experts_per_tok
intermediate_size = config.intermediate_size
@@ -430,7 +435,7 @@ def main(args: argparse.Namespace):
if __name__ == "__main__":
parser = FlexibleArgumentParser()
parser = argparse.ArgumentParser()
parser.add_argument(
"--model", type=str, default="mistralai/Mixtral-8x7B-Instruct-v0.1"
)