Add more fused moe benchmark utilities (#2314)
This commit is contained in:
@@ -0,0 +1,275 @@
|
|||||||
|
import argparse
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import triton
|
||||||
|
from torch.nn import functional as F
|
||||||
|
from transformers import AutoConfig
|
||||||
|
|
||||||
|
from sglang.srt.layers.fused_moe_triton.fused_moe import fused_moe as fused_moe_triton
|
||||||
|
|
||||||
|
|
||||||
|
def get_model_config(model_name: str, tp_size: int):
|
||||||
|
"""Get model configuration parameters"""
|
||||||
|
config = AutoConfig.from_pretrained(model_name)
|
||||||
|
|
||||||
|
if config.architectures[0] == "DbrxForCausalLM":
|
||||||
|
E = config.ffn_config.moe_num_experts
|
||||||
|
topk = config.ffn_config.moe_top_k
|
||||||
|
intermediate_size = config.ffn_config.ffn_hidden_size
|
||||||
|
shard_intermediate_size = 2 * intermediate_size // tp_size
|
||||||
|
elif config.architectures[0] == "JambaForCausalLM":
|
||||||
|
E = config.num_experts
|
||||||
|
topk = config.num_experts_per_tok
|
||||||
|
intermediate_size = config.intermediate_size
|
||||||
|
shard_intermediate_size = 2 * intermediate_size // tp_size
|
||||||
|
elif config.architectures[0] == "Qwen2MoeForCausalLM":
|
||||||
|
E = config.num_experts
|
||||||
|
topk = config.num_experts_per_tok
|
||||||
|
intermediate_size = config.moe_intermediate_size
|
||||||
|
shard_intermediate_size = 2 * intermediate_size // tp_size
|
||||||
|
else:
|
||||||
|
# Default: Mixtral
|
||||||
|
E = config.num_local_experts
|
||||||
|
topk = config.num_experts_per_tok
|
||||||
|
intermediate_size = config.intermediate_size
|
||||||
|
shard_intermediate_size = 2 * intermediate_size // tp_size
|
||||||
|
|
||||||
|
shape_configs = {
|
||||||
|
"num_experts": E,
|
||||||
|
"topk": topk,
|
||||||
|
"hidden_size": config.hidden_size,
|
||||||
|
"shard_intermediate_size": shard_intermediate_size,
|
||||||
|
"dtype": config.torch_dtype,
|
||||||
|
}
|
||||||
|
print(f"{shape_configs=}")
|
||||||
|
return shape_configs
|
||||||
|
|
||||||
|
|
||||||
|
def fused_topk_native(
|
||||||
|
hidden_states: torch.Tensor,
|
||||||
|
gating_output: torch.Tensor,
|
||||||
|
topk: int,
|
||||||
|
renormalize: bool,
|
||||||
|
):
|
||||||
|
assert hidden_states.shape[0] == gating_output.shape[0], "Number of tokens mismatch"
|
||||||
|
M, _ = hidden_states.shape
|
||||||
|
topk_weights = torch.empty(
|
||||||
|
M, topk, dtype=torch.float32, device=hidden_states.device
|
||||||
|
)
|
||||||
|
topk_ids = torch.empty(M, topk, dtype=torch.int32, device=hidden_states.device)
|
||||||
|
topk_weights = F.softmax(gating_output.float(), dim=-1)
|
||||||
|
topk_weights, topk_ids = torch.topk(topk_weights, topk, dim=-1)
|
||||||
|
if renormalize:
|
||||||
|
topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
|
||||||
|
return topk_weights, topk_ids
|
||||||
|
|
||||||
|
|
||||||
|
@torch.compile
|
||||||
|
def fused_moe_torch(
|
||||||
|
x,
|
||||||
|
w1,
|
||||||
|
w2,
|
||||||
|
input_gating,
|
||||||
|
topk,
|
||||||
|
use_fp8_w8a8=False,
|
||||||
|
w1_scale=None,
|
||||||
|
w2_scale=None,
|
||||||
|
a1_scale=None,
|
||||||
|
a2_scale=None,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
assert not use_fp8_w8a8, "Not supported"
|
||||||
|
|
||||||
|
topk_weights, topk_ids = fused_topk_native(
|
||||||
|
hidden_states=x,
|
||||||
|
gating_output=input_gating,
|
||||||
|
topk=topk,
|
||||||
|
renormalize=True,
|
||||||
|
)
|
||||||
|
w13_weights = w1[topk_ids]
|
||||||
|
w1_weights, w3_weights = torch.chunk(w13_weights, 2, dim=2)
|
||||||
|
w2_weights = w2[topk_ids]
|
||||||
|
x1 = F.gelu(torch.einsum("ti,taoi -> tao", x, w1_weights))
|
||||||
|
x3 = torch.einsum("ti, taoi -> tao", x, w3_weights)
|
||||||
|
expert_outs = torch.einsum("tao, taio -> tai", (x1 * x3), w2_weights)
|
||||||
|
return torch.einsum("tai,ta -> ti", expert_outs, topk_weights.to(expert_outs.dtype))
|
||||||
|
|
||||||
|
|
||||||
|
def fused_moe_torch_compile(
|
||||||
|
x,
|
||||||
|
w1,
|
||||||
|
w2,
|
||||||
|
input_gating,
|
||||||
|
topk,
|
||||||
|
use_fp8_w8a8=False,
|
||||||
|
w1_scale=None,
|
||||||
|
w2_scale=None,
|
||||||
|
a1_scale=None,
|
||||||
|
a2_scale=None,
|
||||||
|
):
|
||||||
|
return fused_moe_torch(
|
||||||
|
x,
|
||||||
|
w1,
|
||||||
|
w2,
|
||||||
|
input_gating,
|
||||||
|
topk,
|
||||||
|
use_fp8_w8a8=use_fp8_w8a8,
|
||||||
|
w1_scale=w1_scale,
|
||||||
|
w2_scale=w2_scale,
|
||||||
|
a1_scale=a1_scale,
|
||||||
|
a2_scale=a2_scale,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def fused_moe_sglang_api(
|
||||||
|
x,
|
||||||
|
w1,
|
||||||
|
w2,
|
||||||
|
input_gating,
|
||||||
|
topk,
|
||||||
|
use_fp8_w8a8=False,
|
||||||
|
w1_scale=None,
|
||||||
|
w2_scale=None,
|
||||||
|
a1_scale=None,
|
||||||
|
a2_scale=None,
|
||||||
|
):
|
||||||
|
return fused_moe_triton(
|
||||||
|
x,
|
||||||
|
w1,
|
||||||
|
w2,
|
||||||
|
input_gating,
|
||||||
|
topk,
|
||||||
|
renormalize=True,
|
||||||
|
inplace=True,
|
||||||
|
use_fp8_w8a8=use_fp8_w8a8,
|
||||||
|
w1_scale=w1_scale,
|
||||||
|
w2_scale=w2_scale,
|
||||||
|
a1_scale=a1_scale,
|
||||||
|
a2_scale=a2_scale,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@triton.testing.perf_report(
|
||||||
|
triton.testing.Benchmark(
|
||||||
|
x_names=["batch_size"],
|
||||||
|
x_vals=list(range(1, 5)),
|
||||||
|
line_arg="provider",
|
||||||
|
line_vals=[
|
||||||
|
"fused_moe_triton",
|
||||||
|
"fused_moe_torch_compile",
|
||||||
|
],
|
||||||
|
line_names=[
|
||||||
|
"fused_moe_triton",
|
||||||
|
"fused_moe_torch_compile",
|
||||||
|
],
|
||||||
|
styles=[
|
||||||
|
("blue", "-"),
|
||||||
|
("green", "-"),
|
||||||
|
],
|
||||||
|
ylabel="Time (ms)",
|
||||||
|
plot_name="fused-moe-performance",
|
||||||
|
args={},
|
||||||
|
)
|
||||||
|
)
|
||||||
|
def benchmark(batch_size, provider, model_config, use_fp8=False):
|
||||||
|
print(f"benchmark {provider} with batch_size={batch_size}")
|
||||||
|
torch.set_default_device("cuda")
|
||||||
|
torch.cuda.manual_seed_all(0)
|
||||||
|
|
||||||
|
num_tokens = batch_size
|
||||||
|
num_experts = model_config["num_experts"]
|
||||||
|
hidden_size = model_config["hidden_size"]
|
||||||
|
shard_intermediate_size = model_config["shard_intermediate_size"]
|
||||||
|
topk = model_config["topk"]
|
||||||
|
dtype = model_config["dtype"]
|
||||||
|
|
||||||
|
x = torch.randn(num_tokens, hidden_size, dtype=dtype)
|
||||||
|
|
||||||
|
if use_fp8:
|
||||||
|
init_dtype = dtype
|
||||||
|
w1 = torch.randn(
|
||||||
|
num_experts, shard_intermediate_size, hidden_size, dtype=init_dtype
|
||||||
|
)
|
||||||
|
w2 = torch.randn(
|
||||||
|
num_experts, hidden_size, shard_intermediate_size // 2, dtype=init_dtype
|
||||||
|
)
|
||||||
|
w1 = w1.to(torch.float8_e4m3fn)
|
||||||
|
w2 = w2.to(torch.float8_e4m3fn)
|
||||||
|
w1_scale = torch.randn(num_experts, dtype=torch.float32)
|
||||||
|
w2_scale = torch.randn(num_experts, dtype=torch.float32)
|
||||||
|
a1_scale = torch.randn(1, dtype=torch.float32)
|
||||||
|
a2_scale = torch.randn(1, dtype=torch.float32)
|
||||||
|
else:
|
||||||
|
w1 = torch.randn(num_experts, shard_intermediate_size, hidden_size, dtype=dtype)
|
||||||
|
w2 = torch.randn(
|
||||||
|
num_experts, hidden_size, shard_intermediate_size // 2, dtype=dtype
|
||||||
|
)
|
||||||
|
w1_scale = w2_scale = a1_scale = a2_scale = None
|
||||||
|
|
||||||
|
input_gating = torch.randn(num_tokens, num_experts, dtype=torch.float32)
|
||||||
|
|
||||||
|
# Warmup
|
||||||
|
api_func = (
|
||||||
|
fused_moe_torch_compile
|
||||||
|
if provider == "fused_moe_torch_compile"
|
||||||
|
else fused_moe_sglang_api
|
||||||
|
)
|
||||||
|
for _ in range(10):
|
||||||
|
y = api_func(
|
||||||
|
x,
|
||||||
|
w1,
|
||||||
|
w2,
|
||||||
|
input_gating,
|
||||||
|
topk,
|
||||||
|
use_fp8_w8a8=use_fp8,
|
||||||
|
w1_scale=w1_scale,
|
||||||
|
w2_scale=w2_scale,
|
||||||
|
a1_scale=a1_scale,
|
||||||
|
a2_scale=a2_scale,
|
||||||
|
)
|
||||||
|
torch.cuda.synchronize()
|
||||||
|
|
||||||
|
quantiles = [0.5, 0.2, 0.8]
|
||||||
|
ms, min_ms, max_ms = triton.testing.do_bench(
|
||||||
|
lambda: api_func(
|
||||||
|
x,
|
||||||
|
w1,
|
||||||
|
w2,
|
||||||
|
input_gating,
|
||||||
|
topk,
|
||||||
|
use_fp8_w8a8=use_fp8,
|
||||||
|
w1_scale=w1_scale,
|
||||||
|
w2_scale=w2_scale,
|
||||||
|
a1_scale=a1_scale,
|
||||||
|
a2_scale=a2_scale,
|
||||||
|
)[0],
|
||||||
|
quantiles=quantiles,
|
||||||
|
)
|
||||||
|
return ms, min_ms, max_ms
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
parser = argparse.ArgumentParser()
|
||||||
|
parser.add_argument(
|
||||||
|
"--model", type=str, default="mistralai/Mixtral-8x7B-Instruct-v0.1"
|
||||||
|
)
|
||||||
|
parser.add_argument("--tp-size", type=int, default=2)
|
||||||
|
parser.add_argument("--use-fp8", action="store_true")
|
||||||
|
parser.add_argument(
|
||||||
|
"--save-path",
|
||||||
|
type=str,
|
||||||
|
default="./configs/benchmark_ops/fused_moe_torch_compile/",
|
||||||
|
)
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
model_config = get_model_config(args.model, args.tp_size)
|
||||||
|
benchmark.run(
|
||||||
|
show_plots=True,
|
||||||
|
print_data=True,
|
||||||
|
save_path=args.save_path,
|
||||||
|
model_config=model_config,
|
||||||
|
use_fp8=args.use_fp8,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
@@ -1,22 +1,11 @@
|
|||||||
import argparse
|
import argparse
|
||||||
import numbers
|
|
||||||
from typing import Optional
|
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import triton
|
import triton
|
||||||
from torch.nn import init
|
|
||||||
from torch.nn.parameter import Parameter
|
|
||||||
from transformers import AutoConfig
|
from transformers import AutoConfig
|
||||||
from vllm.model_executor.layers.fused_moe.fused_moe import fused_moe as fused_moe_vllm
|
from vllm.model_executor.layers.fused_moe.fused_moe import fused_moe as fused_moe_vllm
|
||||||
from vllm.model_executor.layers.fused_moe.fused_moe import (
|
|
||||||
get_moe_configs as get_moe_configs_vllm,
|
|
||||||
)
|
|
||||||
from vllm.utils import FlexibleArgumentParser
|
|
||||||
|
|
||||||
from sglang.srt.layers.fused_moe_triton.fused_moe import fused_moe as fused_moe_sglang
|
from sglang.srt.layers.fused_moe_triton.fused_moe import fused_moe as fused_moe_sglang
|
||||||
from sglang.srt.layers.fused_moe_triton.fused_moe import (
|
|
||||||
get_moe_configs as get_moe_configs_sglang,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def get_model_config(model_name: str, tp_size: int):
|
def get_model_config(model_name: str, tp_size: int):
|
||||||
@@ -39,19 +28,21 @@ def get_model_config(model_name: str, tp_size: int):
|
|||||||
intermediate_size = config.moe_intermediate_size
|
intermediate_size = config.moe_intermediate_size
|
||||||
shard_intermediate_size = 2 * intermediate_size // tp_size
|
shard_intermediate_size = 2 * intermediate_size // tp_size
|
||||||
else:
|
else:
|
||||||
# Default: Mixtral, Grok1, etc.
|
# Default: Mixtral
|
||||||
E = config.num_local_experts
|
E = config.num_local_experts
|
||||||
topk = config.num_experts_per_tok
|
topk = config.num_experts_per_tok
|
||||||
intermediate_size = config.intermediate_size
|
intermediate_size = config.intermediate_size
|
||||||
shard_intermediate_size = 2 * intermediate_size // tp_size
|
shard_intermediate_size = 2 * intermediate_size // tp_size
|
||||||
|
|
||||||
return {
|
shape_configs = {
|
||||||
"num_experts": E,
|
"num_experts": E,
|
||||||
"topk": topk,
|
"topk": topk,
|
||||||
"hidden_size": config.hidden_size,
|
"hidden_size": config.hidden_size,
|
||||||
"shard_intermediate_size": shard_intermediate_size,
|
"shard_intermediate_size": shard_intermediate_size,
|
||||||
"dtype": config.torch_dtype,
|
"dtype": config.torch_dtype,
|
||||||
}
|
}
|
||||||
|
print(f"{shape_configs=}")
|
||||||
|
return shape_configs
|
||||||
|
|
||||||
|
|
||||||
def fused_moe_vllm_api(
|
def fused_moe_vllm_api(
|
||||||
@@ -133,7 +124,7 @@ def fused_moe_sglang_api(
|
|||||||
)
|
)
|
||||||
)
|
)
|
||||||
def benchmark(batch_size, provider, model_config, use_fp8=False):
|
def benchmark(batch_size, provider, model_config, use_fp8=False):
|
||||||
print(f"benchmark for batch_size={batch_size}")
|
print(f"benchmark {provider} with batch_size={batch_size}")
|
||||||
torch.set_default_device("cuda")
|
torch.set_default_device("cuda")
|
||||||
torch.cuda.manual_seed_all(0)
|
torch.cuda.manual_seed_all(0)
|
||||||
|
|
||||||
@@ -210,7 +201,7 @@ def benchmark(batch_size, provider, model_config, use_fp8=False):
|
|||||||
|
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
parser = FlexibleArgumentParser()
|
parser = argparse.ArgumentParser()
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--model", type=str, default="mistralai/Mixtral-8x7B-Instruct-v0.1"
|
"--model", type=str, default="mistralai/Mixtral-8x7B-Instruct-v0.1"
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -1,5 +1,6 @@
|
|||||||
# Adapted from https://github.com/vllm-project/vllm/blob/main/benchmarks/kernels/benchmark_moe.py
|
# Adapted from https://github.com/vllm-project/vllm/blob/main/benchmarks/kernels/benchmark_moe.py
|
||||||
import argparse
|
import argparse
|
||||||
|
import json
|
||||||
import time
|
import time
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from typing import Any, Dict, List, Tuple, TypedDict
|
from typing import Any, Dict, List, Tuple, TypedDict
|
||||||
@@ -9,10 +10,14 @@ import torch
|
|||||||
import triton
|
import triton
|
||||||
from ray.experimental.tqdm_ray import tqdm
|
from ray.experimental.tqdm_ray import tqdm
|
||||||
from transformers import AutoConfig
|
from transformers import AutoConfig
|
||||||
from vllm.platforms import current_platform
|
|
||||||
from vllm.utils import FlexibleArgumentParser
|
|
||||||
|
|
||||||
from sglang.srt.layers.fused_moe_triton.fused_moe import *
|
from sglang.srt.layers.fused_moe_triton.fused_moe import (
|
||||||
|
fused_moe,
|
||||||
|
get_config_dtype_str,
|
||||||
|
get_config_file_name,
|
||||||
|
get_default_config,
|
||||||
|
get_moe_configs,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class BenchmarkConfig(TypedDict):
|
class BenchmarkConfig(TypedDict):
|
||||||
@@ -92,7 +97,7 @@ def benchmark_config(
|
|||||||
input_gating.copy_(gating_output[i])
|
input_gating.copy_(gating_output[i])
|
||||||
|
|
||||||
def run():
|
def run():
|
||||||
from sglang.srt.layers.fused_moe_triton.fused_moe import override_config
|
from sglang.srt.layers.fused_moe_triton import override_config
|
||||||
|
|
||||||
with override_config(config):
|
with override_config(config):
|
||||||
fused_moe(
|
fused_moe(
|
||||||
@@ -174,7 +179,7 @@ class BenchmarkWorker:
|
|||||||
|
|
||||||
def __init__(self, seed: int) -> None:
|
def __init__(self, seed: int) -> None:
|
||||||
torch.set_default_device("cuda")
|
torch.set_default_device("cuda")
|
||||||
current_platform.seed_everything(seed)
|
torch.cuda.manual_seed_all(0)
|
||||||
self.seed = seed
|
self.seed = seed
|
||||||
|
|
||||||
def benchmark(
|
def benchmark(
|
||||||
@@ -188,7 +193,7 @@ class BenchmarkWorker:
|
|||||||
use_fp8_w8a8: bool,
|
use_fp8_w8a8: bool,
|
||||||
use_int8_w8a16: bool,
|
use_int8_w8a16: bool,
|
||||||
) -> Tuple[Dict[str, int], float]:
|
) -> Tuple[Dict[str, int], float]:
|
||||||
current_platform.seed_everything(self.seed)
|
torch.cuda.manual_seed_all(0)
|
||||||
dtype_str = get_config_dtype_str(
|
dtype_str = get_config_dtype_str(
|
||||||
dtype, use_int8_w8a16=use_int8_w8a16, use_fp8_w8a8=use_fp8_w8a8
|
dtype, use_int8_w8a16=use_int8_w8a16, use_fp8_w8a8=use_fp8_w8a8
|
||||||
)
|
)
|
||||||
@@ -319,7 +324,7 @@ def main(args: argparse.Namespace):
|
|||||||
intermediate_size = config.moe_intermediate_size
|
intermediate_size = config.moe_intermediate_size
|
||||||
shard_intermediate_size = 2 * intermediate_size // args.tp_size
|
shard_intermediate_size = 2 * intermediate_size // args.tp_size
|
||||||
else:
|
else:
|
||||||
# Default: Mixtral.
|
# Default: Mixtral
|
||||||
E = config.num_local_experts
|
E = config.num_local_experts
|
||||||
topk = config.num_experts_per_tok
|
topk = config.num_experts_per_tok
|
||||||
intermediate_size = config.intermediate_size
|
intermediate_size = config.intermediate_size
|
||||||
@@ -430,7 +435,7 @@ def main(args: argparse.Namespace):
|
|||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
parser = FlexibleArgumentParser()
|
parser = argparse.ArgumentParser()
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--model", type=str, default="mistralai/Mixtral-8x7B-Instruct-v0.1"
|
"--model", type=str, default="mistralai/Mixtral-8x7B-Instruct-v0.1"
|
||||||
)
|
)
|
||||||
|
|||||||
Reference in New Issue
Block a user