Add more fused moe benchmark utilities (#2314)
This commit is contained in:
@@ -1,5 +1,6 @@
|
||||
# Adapted from https://github.com/vllm-project/vllm/blob/main/benchmarks/kernels/benchmark_moe.py
|
||||
import argparse
|
||||
import json
|
||||
import time
|
||||
from datetime import datetime
|
||||
from typing import Any, Dict, List, Tuple, TypedDict
|
||||
@@ -9,10 +10,14 @@ import torch
|
||||
import triton
|
||||
from ray.experimental.tqdm_ray import tqdm
|
||||
from transformers import AutoConfig
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils import FlexibleArgumentParser
|
||||
|
||||
from sglang.srt.layers.fused_moe_triton.fused_moe import *
|
||||
from sglang.srt.layers.fused_moe_triton.fused_moe import (
|
||||
fused_moe,
|
||||
get_config_dtype_str,
|
||||
get_config_file_name,
|
||||
get_default_config,
|
||||
get_moe_configs,
|
||||
)
|
||||
|
||||
|
||||
class BenchmarkConfig(TypedDict):
|
||||
@@ -92,7 +97,7 @@ def benchmark_config(
|
||||
input_gating.copy_(gating_output[i])
|
||||
|
||||
def run():
|
||||
from sglang.srt.layers.fused_moe_triton.fused_moe import override_config
|
||||
from sglang.srt.layers.fused_moe_triton import override_config
|
||||
|
||||
with override_config(config):
|
||||
fused_moe(
|
||||
@@ -174,7 +179,7 @@ class BenchmarkWorker:
|
||||
|
||||
def __init__(self, seed: int) -> None:
|
||||
torch.set_default_device("cuda")
|
||||
current_platform.seed_everything(seed)
|
||||
torch.cuda.manual_seed_all(0)
|
||||
self.seed = seed
|
||||
|
||||
def benchmark(
|
||||
@@ -188,7 +193,7 @@ class BenchmarkWorker:
|
||||
use_fp8_w8a8: bool,
|
||||
use_int8_w8a16: bool,
|
||||
) -> Tuple[Dict[str, int], float]:
|
||||
current_platform.seed_everything(self.seed)
|
||||
torch.cuda.manual_seed_all(0)
|
||||
dtype_str = get_config_dtype_str(
|
||||
dtype, use_int8_w8a16=use_int8_w8a16, use_fp8_w8a8=use_fp8_w8a8
|
||||
)
|
||||
@@ -319,7 +324,7 @@ def main(args: argparse.Namespace):
|
||||
intermediate_size = config.moe_intermediate_size
|
||||
shard_intermediate_size = 2 * intermediate_size // args.tp_size
|
||||
else:
|
||||
# Default: Mixtral.
|
||||
# Default: Mixtral
|
||||
E = config.num_local_experts
|
||||
topk = config.num_experts_per_tok
|
||||
intermediate_size = config.intermediate_size
|
||||
@@ -430,7 +435,7 @@ def main(args: argparse.Namespace):
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = FlexibleArgumentParser()
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument(
|
||||
"--model", type=str, default="mistralai/Mixtral-8x7B-Instruct-v0.1"
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user