sglangv0.5.2 & support Qwen3-Next-80B-A3B-Instruct
This commit is contained in:
76
benchmark/kernels/fused_moe_triton/README.md
Normal file
76
benchmark/kernels/fused_moe_triton/README.md
Normal file
@@ -0,0 +1,76 @@
|
||||
## Tuning Triton MoE Kernels
|
||||
|
||||
This directory contains benchmarking tools for MoE (Mixture of Experts) kernels.
|
||||
|
||||
### Tuning Tool
|
||||
|
||||
- `tuning_fused_moe_triton.py`: A tool for tuning the `fused_moe_triton` kernel. Adapted from [vllm's benchmark_moe.py](https://github.com/vllm-project/vllm/blob/main/benchmarks/kernels/benchmark_moe.py), with added support for various model architectures.
|
||||
|
||||
Example usage:
|
||||
```bash
|
||||
# Tune Mixtral-8x7B with default settings
|
||||
python benchmark/kernels/fused_moe_triton/tuning_fused_moe_triton.py \
|
||||
--model mistralai/Mixtral-8x7B-Instruct-v0.1 \
|
||||
--tune
|
||||
|
||||
# Tune Qwen2-57B with FP8 and TP=4
|
||||
python benchmark/kernels/fused_moe_triton/tuning_fused_moe_triton.py \
|
||||
--model Qwen/Qwen2-57B-A14B-Instruct \
|
||||
--tp-size 4 \
|
||||
--dtype fp8_w8a8 \
|
||||
--tune
|
||||
|
||||
# Tune Qwen3-235B-A22B-FP8 with TP=4
|
||||
python benchmark/kernels/fused_moe_triton/tuning_fused_moe_triton.py \
|
||||
--model Qwen/Qwen3-235B-A22B-FP8 \
|
||||
--tp-size 4 \
|
||||
--dtype fp8_w8a8 \
|
||||
--tune
|
||||
|
||||
# Tune DeepSeek-V3 with FP8 and TP=8
|
||||
python benchmark/kernels/fused_moe_triton/tuning_fused_moe_triton.py \
|
||||
--model deepseek-ai/DeepSeek-V3-0324 \
|
||||
--tp-size 8 \
|
||||
--dtype fp8_w8a8 \
|
||||
--tune
|
||||
|
||||
# Tune DeepSeek-R1 with channel-wise INT8 and TP=16
|
||||
python benchmark/kernels/fused_moe_triton/tuning_fused_moe_triton.py \
|
||||
--model meituan/DeepSeek-R1-Channel-INT8 \
|
||||
--tp-size 16 \
|
||||
--dtype int8_w8a8 \
|
||||
--tune
|
||||
```
|
||||
|
||||
After tuning, a configuration file (e.g., `E=64,N=640,device_name=NVIDIA_GeForce_RTX_4090,dtype=fp8_w8a8.json`) will be generated in the current directory. You can move this file to the `sglang/srt/layers/moe/fused_moe_triton/configs/triton_version` directory to use it in `sglang`.
|
||||
|
||||
### Performance Comparison Tool
|
||||
|
||||
- `benchmark_vllm_vs_sglang_fused_moe_triton.py`: A tool for comparing the performance of fused MoE kernels between vllm and sglang implementations. Supports various model architectures and data types.
|
||||
|
||||
Example usage:
|
||||
```bash
|
||||
# Compare with default settings (Mixtral model)
|
||||
python benchmark/kernels/fused_moe_triton/benchmark_vllm_vs_sglang_fused_moe_triton.py
|
||||
|
||||
# Compare with FP8 mode for Qwen2-57B
|
||||
python benchmark/kernels/fused_moe_triton/benchmark_vllm_vs_sglang_fused_moe_triton.py \
|
||||
--model Qwen/Qwen2-57B-A14B-Instruct \
|
||||
--use-fp8-w8a8
|
||||
|
||||
# Compare with custom TP size
|
||||
python benchmark/kernels/fused_moe_triton/benchmark_vllm_vs_sglang_fused_moe_triton.py \
|
||||
--model deepseek-ai/DeepSeek-V3-0324 \
|
||||
--tp-size 8
|
||||
|
||||
# Compare with custom TP size
|
||||
python benchmark/kernels/fused_moe_triton/benchmark_vllm_vs_sglang_fused_moe_triton.py \
|
||||
--model deepseek-ai/DeepSeek-V3-0324 \
|
||||
--tp-size 8
|
||||
```
|
||||
|
||||
The benchmark results will be saved as plots and data files in the specified output directory (default: `./configs/benchmark_ops/vllm_sglang_fused_moe/`).
|
||||
|
||||
- `benchmark_torch_compile_fused_moe.py`: A tool for benchmarking the performance of the fused MoE kernel with `torch.compile` and original fused MoE kernel.
|
||||
|
||||
Usage is the same as for `benchmark_vllm_vs_sglang_fused_moe_triton.py`. Note that `torch.compile` does not support the `fp8_w8a8` and `int8_w8a8` fused MoE kernels.
|
||||
@@ -0,0 +1,292 @@
|
||||
# python3 benchmark/kernels/fused_moe_triton/sglang_fused_moe_triton.py --model /DeepSeek-V3/ --tp-size 8
|
||||
import argparse
|
||||
|
||||
import torch
|
||||
import triton
|
||||
from transformers import AutoConfig
|
||||
|
||||
from sglang.srt.distributed.parallel_state import (
|
||||
destroy_distributed_environment,
|
||||
destroy_model_parallel,
|
||||
init_distributed_environment,
|
||||
initialize_model_parallel,
|
||||
)
|
||||
from sglang.srt.layers.moe.fused_moe_triton.fused_moe import (
|
||||
fused_moe as fused_moe_sglang,
|
||||
)
|
||||
from sglang.srt.layers.moe.fused_moe_triton.triton_kernels_moe import (
|
||||
triton_kernel_moe_forward,
|
||||
)
|
||||
from sglang.srt.layers.moe.moe_runner import MoeRunnerConfig
|
||||
from sglang.srt.layers.moe.topk import TopK, TopKConfig, select_experts
|
||||
|
||||
|
||||
def get_model_config(model_name: str, tp_size: int):
    """Load a model's Hugging Face config and derive the MoE benchmark shapes.

    Args:
        model_name: Identifier or path handed to ``AutoConfig.from_pretrained``.
        tp_size: Divisor applied to the doubled expert intermediate size.

    Returns:
        A dict with the keys ``num_experts``, ``topk``, ``hidden_size``,
        ``shard_intermediate_size``, ``dtype`` and ``block_shape``
        (``None`` unless the config carries a ``weight_block_size``).
    """
    config = AutoConfig.from_pretrained(model_name, trust_remote_code=True)
    arch = config.architectures[0]

    if arch in ("Qwen2MoeForCausalLM", "Qwen3MoeForCausalLM"):
        E = config.num_experts
        intermediate_size = config.moe_intermediate_size
    elif arch in (
        "DeepseekV2ForCausalLM",
        "DeepseekV3ForCausalLM",
        "Glm4MoeForCausalLM",
    ):
        E = config.n_routed_experts
        if arch == "DeepseekV3ForCausalLM":
            # DeepseekV3 is benchmarked with one expert more than it routes to.
            E += 1
        intermediate_size = config.moe_intermediate_size
    else:
        # Default: Mixtral
        E = config.num_local_experts
        intermediate_size = config.intermediate_size

    # Every architecture shares the same top-k attribute and sharding formula.
    topk = config.num_experts_per_tok
    shard_intermediate_size = 2 * intermediate_size // tp_size

    block_shape = None
    if hasattr(config, "quantization_config"):
        quant_cfg = config.quantization_config
        if "weight_block_size" in quant_cfg:
            block_shape = quant_cfg["weight_block_size"]
            assert len(block_shape) == 2

    shape_configs = {
        "num_experts": E,
        "topk": topk,
        "hidden_size": config.hidden_size,
        "shard_intermediate_size": shard_intermediate_size,
        "dtype": config.torch_dtype,
        "block_shape": block_shape,
    }
    print(f"{shape_configs=}")
    return shape_configs
|
||||
|
||||
|
||||
def fused_moe_triton_api(x, w1, w2, input_gating, topk):
    """Run the ``triton_kernel_moe_forward`` path for one batch.

    A ``TopK`` op (no renormalization, no grouped top-k) is flagged to use
    Triton kernels and produces the routing output from ``input_gating``;
    that output is then passed, together with a non-inplace
    ``MoeRunnerConfig``, to ``triton_kernel_moe_forward``.

    Args:
        x: Hidden states given to both the router and the MoE kernel.
        w1: First expert weight tensor.
        w2: Second expert weight tensor.
        input_gating: Router logits.
        topk: Number of experts selected per token.

    Returns:
        Whatever ``triton_kernel_moe_forward`` returns.
    """
    router = TopK(top_k=topk, renormalize=False, use_grouped_topk=False)
    router.use_triton_kernels = True
    routing = router.forward_cuda(hidden_states=x, router_logits=input_gating)

    runner_cfg = MoeRunnerConfig(inplace=False)
    return triton_kernel_moe_forward(x, w1, w2, routing, runner_cfg)
|
||||
|
||||
|
||||
def fused_moe_sglang_api(
    x,
    w1,
    w2,
    input_gating,
    topk,
    use_fp8_w8a8=False,
    w1_scale=None,
    w2_scale=None,
    a1_scale=None,
    a2_scale=None,
    block_shape=None,
):
    """Run sglang's ``fused_moe`` after selecting experts from router logits.

    Experts are chosen with ``select_experts`` using a ``TopKConfig`` that
    does not renormalize; the resulting top-k output and all quantization
    arguments are forwarded unchanged to ``fused_moe``.

    Returns:
        Whatever ``fused_moe`` returns.
    """
    routing_cfg = TopKConfig(top_k=topk, renormalize=False)
    routing = select_experts(
        hidden_states=x,
        router_logits=input_gating,
        topk_config=routing_cfg,
    )

    quant_kwargs = {
        "use_fp8_w8a8": use_fp8_w8a8,
        "w1_scale": w1_scale,
        "w2_scale": w2_scale,
        "a1_scale": a1_scale,
        "a2_scale": a2_scale,
        "block_shape": block_shape,
    }
    return fused_moe_sglang(x, w1, w2, routing, **quant_kwargs)
|
||||
|
||||
|
||||
@triton.testing.perf_report(
    triton.testing.Benchmark(
        x_names=["batch_size"],
        x_vals=[128, 256, 512, 1024, 2048, 4096, 8192],
        line_arg="provider",
        line_vals=["sglang_fused_moe_triton_v340", "sglang_fused_moe_triton"],
        line_names=["sglang_fused_moe_triton_v340", "sglang_fused_moe_triton"],
        styles=[("blue", "-"), ("green", "-")],
        ylabel="Time (ms)",
        plot_name="fused-moe-performance",
        args={},
    )
)
def benchmark(
    batch_size,
    provider,
    model_config,
    use_fp8_w8a8=False,
    use_cuda_graph: bool = False,
):
    """Time one MoE provider at one batch size.

    Args:
        batch_size: Number of tokens in the random input.
        provider: ``"sglang_fused_moe_triton_v340"`` selects
            ``fused_moe_triton_api``; anything else selects
            ``fused_moe_sglang_api``.
        model_config: Dict providing ``num_experts``, ``hidden_size``,
            ``shard_intermediate_size``, ``topk``, ``dtype`` and
            ``block_shape``.
        use_fp8_w8a8: Forwarded to ``fused_moe_sglang_api`` only.
        use_cuda_graph: Capture a single call into a CUDA graph and time
            graph replays instead of direct calls.

    Returns:
        The three values ``do_bench`` reports for quantiles 0.5, 0.2 and 0.8.
    """
    print(f"benchmark {provider} with batch_size={batch_size}")
    torch.set_default_device("cuda")
    torch.cuda.manual_seed_all(0)

    num_tokens = batch_size
    num_experts = model_config["num_experts"]
    hidden_size = model_config["hidden_size"]
    shard_intermediate_size = model_config["shard_intermediate_size"]
    dtype = model_config["dtype"]

    # Random tensors are drawn in a fixed order: x, w1, w2, router logits.
    x = torch.randn(num_tokens, hidden_size, dtype=dtype)
    w1 = torch.randn(num_experts, shard_intermediate_size, hidden_size, dtype=dtype)
    w2 = torch.randn(
        num_experts, hidden_size, shard_intermediate_size // 2, dtype=dtype
    )

    # The triton-kernels path is fed weights with the last two dims swapped.
    w1_tri = w1.clone().transpose(-2, -1).contiguous()
    w2_tri = w2.clone().transpose(-2, -1).contiguous()

    input_gating = torch.randn(num_tokens, num_experts, dtype=torch.float32)

    call_kwargs = {
        "x": x,
        "input_gating": input_gating,
        "topk": model_config["topk"],
    }
    if provider == "sglang_fused_moe_triton_v340":
        api_func = fused_moe_triton_api
        call_kwargs.update(w1=w1_tri, w2=w2_tri)
    else:
        api_func = fused_moe_sglang_api
        call_kwargs.update(
            w1=w1,
            w2=w2,
            use_fp8_w8a8=use_fp8_w8a8,
            block_shape=model_config["block_shape"],
        )

    # Warmup
    for _ in range(10):
        api_func(**call_kwargs)
    torch.cuda.synchronize()

    if use_cuda_graph:
        capture_stream = torch.cuda.Stream()
        graph = torch.cuda.CUDAGraph()
        with torch.cuda.graph(graph, stream=capture_stream):
            api_func(**call_kwargs)
        torch.cuda.synchronize()

        timed_fn = graph.replay
    else:

        def timed_fn():
            return api_func(**call_kwargs)

    ms, min_ms, max_ms = triton.testing.do_bench(timed_fn, quantiles=[0.5, 0.2, 0.8])
    return ms, min_ms, max_ms
|
||||
|
||||
|
||||
def main():
    """Parse CLI options, set up a single-process group, and run the benchmark.

    The distributed and model-parallel state is torn down in a ``finally``
    clause, so cleanup is attempted even when setup or the benchmark raises.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--model", type=str, default="mistralai/Mixtral-8x7B-Instruct-v0.1"
    )
    parser.add_argument("--tp-size", type=int, default=2)
    parser.add_argument("--use-fp8-w8a8", action="store_true")
    parser.add_argument(
        "--use-cuda-graph", action="store_true", help="Enable CUDA Graph capture/replay"
    )
    parser.add_argument(
        "--save-path",
        type=str,
        default="./configs/benchmark_ops/sglang_fused_moe/",
    )
    # Accepted on the command line; this function does not read the value.
    parser.add_argument("--trust-remote-code", action="store_true")
    args = parser.parse_args()

    rendezvous = "tcp://127.0.0.1:23456"
    backend = "nccl" if torch.cuda.is_available() else "gloo"

    try:
        if not torch.distributed.is_initialized():
            torch.distributed.init_process_group(
                backend=backend,
                init_method=rendezvous,
                world_size=1,
                rank=0,
            )

        init_distributed_environment(
            world_size=1,
            rank=0,
            distributed_init_method=rendezvous,
            local_rank=0,
            backend=backend,
        )
        initialize_model_parallel(
            tensor_model_parallel_size=1,
            pipeline_model_parallel_size=1,
        )

        benchmark.run(
            show_plots=True,
            print_data=True,
            save_path=args.save_path,
            model_config=get_model_config(args.model, args.tp_size),
            use_fp8_w8a8=args.use_fp8_w8a8,
            use_cuda_graph=args.use_cuda_graph,
        )
    finally:
        destroy_model_parallel()
        destroy_distributed_environment()
|
||||
|
||||
|
||||
# Run the benchmark CLI only when this file is executed as a script.
if __name__ == "__main__":
    main()
|
||||
202
benchmark/kernels/fused_moe_triton/benchmark_sum_scale.py
Normal file
202
benchmark/kernels/fused_moe_triton/benchmark_sum_scale.py
Normal file
@@ -0,0 +1,202 @@
|
||||
import torch
|
||||
import triton
|
||||
import triton.language as tl
|
||||
from triton.testing import do_bench
|
||||
|
||||
|
||||
@triton.jit
def _moe_sum_reduce_kernel(
    input_ptr,
    input_stride_0,
    input_stride_1,
    input_stride_2,
    output_ptr,
    output_stride_0,
    output_stride_1,
    token_num: int,
    topk_num: int,
    hidden_dim: int,
    routed_scaling_factor: tl.constexpr,
    BLOCK_M: tl.constexpr,
    BLOCK_DIM: tl.constexpr,
    NUM_STAGE: tl.constexpr,
):
    """Sum the input over its second (top-k) axis and scale the result.

    Each program handles a ``BLOCK_M x BLOCK_DIM`` tile of the output. The
    tile is accumulated in float32 across ``topk_num`` slices, multiplied by
    ``routed_scaling_factor`` and stored cast to the input element type.
    ``input_stride_2`` and ``output_stride_1`` are accepted but not used:
    the last axis is addressed with plain element offsets.
    """
    # Widen the strides to 64 bit before they enter pointer arithmetic.
    input_stride_0 = tl.cast(input_stride_0, dtype=tl.int64)
    input_stride_1 = tl.cast(input_stride_1, dtype=tl.int64)
    output_stride_0 = tl.cast(output_stride_0, dtype=tl.int64)

    pid_token = tl.program_id(0)
    pid_dim = tl.program_id(1)

    rows = pid_token * BLOCK_M + tl.arange(0, BLOCK_M)
    cols = pid_dim * BLOCK_DIM + tl.arange(0, BLOCK_DIM)

    # Guard the ragged edges of the last tile along either axis.
    row_ok = rows < token_num
    col_ok = cols < hidden_dim
    tile_mask = row_ok[:, None] & col_ok[None, :]

    src_ptrs = input_ptr + rows[:, None] * input_stride_0 + cols[None, :]

    acc = tl.zeros((BLOCK_M, BLOCK_DIM), dtype=tl.float32)
    for k in tl.range(0, topk_num, num_stages=NUM_STAGE):
        chunk = tl.load(src_ptrs + k * input_stride_1, mask=tile_mask, other=0.0)
        acc += chunk.to(tl.float32)
    acc *= routed_scaling_factor

    # -------- Write back --------
    dst_ptrs = output_ptr + rows[:, None] * output_stride_0 + cols[None, :]
    tl.store(dst_ptrs, acc.to(input_ptr.dtype.element_ty), mask=tile_mask)
|
||||
|
||||
|
||||
# _moe_sum_reduce_kernel kernel modified from https://github.com/ModelTC/lightllm/blob/main/lightllm/common/fused_moe/moe_sum_reduce.py
|
||||
def moe_sum_reduce(
    input: torch.Tensor, output: torch.Tensor, routed_scaling_factor: float
):
    """Launch ``_moe_sum_reduce_kernel`` to write the scaled sum into ``output``.

    Args:
        input: Contiguous tensor whose shape unpacks as
            ``(token_num, topk_num, hidden_dim)``.
        output: Contiguous tensor whose first two dimensions are
            ``(token_num, hidden_dim)``; written in place.
        routed_scaling_factor: Multiplier passed to the kernel.

    Returns:
        None. The result is delivered through ``output``.
    """
    assert input.is_contiguous()
    assert output.is_contiguous()

    token_num, topk_num, hidden_dim = input.shape
    assert output.shape[0] == token_num and output.shape[1] == hidden_dim

    # Fixed launch configuration: one token row and 2048 columns per program.
    launch_meta = {
        "BLOCK_M": 1,
        "BLOCK_DIM": 2048,
        "NUM_STAGE": 1,
        "num_warps": 16,
    }
    grid = (
        triton.cdiv(token_num, launch_meta["BLOCK_M"]),
        triton.cdiv(hidden_dim, launch_meta["BLOCK_DIM"]),
    )

    _moe_sum_reduce_kernel[grid](
        input,
        *input.stride(),
        output,
        *output.stride(),
        token_num=token_num,
        topk_num=topk_num,
        hidden_dim=hidden_dim,
        routed_scaling_factor=routed_scaling_factor,
        **launch_meta,
    )
|
||||
|
||||
|
||||
def compute_sum_scaled_baseline(
    x: torch.Tensor, out: torch.Tensor, routed_scaling_factor: float
) -> torch.Tensor:
    """Sum ``x`` over dim 1 into ``out``, then scale ``out`` in place.

    Args:
        x: Tensor to reduce along dimension 1.
        out: Destination tensor; overwritten with the reduced values.
        routed_scaling_factor: Factor applied to ``out`` after the reduction.

    Returns:
        The same ``out`` tensor that was passed in.
    """
    reduced = torch.sum(x, dim=1, out=out)
    # ``reduced`` aliases ``out``, so this multiplies the destination in place.
    reduced.mul_(routed_scaling_factor)
    return out
|
||||
|
||||
|
||||
@torch.compile
def compute_sum_scaled_compiled(
    x: torch.Tensor, out: torch.Tensor, routed_scaling_factor: float
) -> torch.Tensor:
    """Scale ``x`` first, then sum it over dim 1 into ``out``.

    The function is wrapped by ``torch.compile``.

    Args:
        x: Tensor to scale and reduce along dimension 1.
        out: Destination tensor; overwritten with the result.
        routed_scaling_factor: Factor applied to ``x`` before the reduction.

    Returns:
        The same ``out`` tensor that was passed in.
    """
    scaled = x * routed_scaling_factor
    torch.sum(scaled, dim=1, out=out)
    return out
|
||||
|
||||
|
||||
def get_benchmark():
    """Build the ``perf_report``-decorated benchmark for the three reducers.

    The returned object sweeps ``num_tokens`` over the powers of two from
    1 to 4096 and times the ``baseline``, ``compiled`` and ``triton``
    implementations on a bfloat16 CUDA input of shape
    ``(num_tokens, 9, 4096)`` with a scaling factor of 0.3.

    Returns:
        The decorated ``benchmark`` function; call ``.run(...)`` on it.
    """
    num_tokens_range = [2**i for i in range(0, 13)]

    @triton.testing.perf_report(
        triton.testing.Benchmark(
            x_names=["num_tokens"],
            x_vals=num_tokens_range,
            line_arg="version",
            line_vals=["baseline", "compiled", "triton"],
            line_names=["Original", "TorchCompile", "TritonKernel"],
            styles=[("blue", "-"), ("green", "-"), ("red", "-")],
            ylabel="us",
            plot_name="sum_scaled_performance",
            args={},
        )
    )
    def benchmark(num_tokens, version):
        """Return (median, 20th pct, 80th pct) latency in microseconds."""
        topk = 9
        hidden_size = 4096
        dtype = torch.bfloat16
        scaling_factor = 0.3

        x = torch.randn(num_tokens, topk, hidden_size, dtype=dtype, device="cuda")
        out = torch.empty(num_tokens, hidden_size, dtype=dtype, device="cuda")

        # Any version other than the two torch variants runs the Triton kernel.
        if version == "baseline":
            impl = compute_sum_scaled_baseline
        elif version == "compiled":
            impl = compute_sum_scaled_compiled
        else:
            impl = moe_sum_reduce

        # Warmup
        for _ in range(3):
            impl(x, out, scaling_factor)

        # Benchmark
        quantiles = [0.5, 0.2, 0.8]
        ms, min_ms, max_ms = do_bench(
            lambda: impl(x, out, scaling_factor), quantiles=quantiles
        )

        # perf_report unpacks the result as (mean, min, max); keep the lower
        # quantile second and the upper quantile third. do_bench reports
        # milliseconds, hence the conversion to microseconds.
        return 1000 * ms, 1000 * min_ms, 1000 * max_ms

    return benchmark
|
||||
|
||||
|
||||
def verify_correctness(num_tokens=1024):
    """Compare the three reducers on one random CUDA input and print a verdict.

    Args:
        num_tokens: Size of the first dimension of the random bfloat16 input,
            whose full shape is ``(num_tokens, 9, 4096)``.

    The baseline result is compared against the compiled and the Triton
    results with ``atol=rtol=1e-2``. On a mismatch the maximum absolute
    difference of each pair is printed as well.
    """
    scaling_factor = 0.3
    x = torch.randn(num_tokens, 9, 4096, device="cuda", dtype=torch.bfloat16)

    out_baseline = torch.empty_like(x[:, 0])
    out_compiled = torch.empty_like(out_baseline)
    out_triton = torch.empty_like(out_baseline)

    compute_sum_scaled_baseline(x, out_baseline, scaling_factor)
    compute_sum_scaled_compiled(x, out_compiled, scaling_factor)
    moe_sum_reduce(x, out_triton, scaling_factor)

    tolerances = {"atol": 1e-2, "rtol": 1e-2}
    all_match = torch.allclose(
        out_baseline, out_compiled, **tolerances
    ) and torch.allclose(out_baseline, out_triton, **tolerances)

    if all_match:
        print("✅ All implementations match")
        return

    print("❌ Implementations differ")
    print(
        f"Baseline vs Compiled: {(out_baseline - out_compiled).abs().max().item()}"
    )
    print(f"Baseline vs Triton: {(out_baseline - out_triton).abs().max().item()}")
|
||||
|
||||
|
||||
# Script entry point: check the implementations agree, then time them.
if __name__ == "__main__":
    print("Running correctness verification...")
    verify_correctness()

    print("\nRunning performance benchmark...")
    get_benchmark().run(
        print_data=True,
        # save_path="./configs/benchmark_ops/sum_scaled/"
    )
|
||||
@@ -0,0 +1,305 @@
|
||||
# python3 benchmark/kernels/fused_moe_triton/benchmark_torch_compile_fused_moe.py --model /DeepSeek-V3/ --tp-size 8 --use-fp8-w8a8
|
||||
import argparse
|
||||
|
||||
import torch
|
||||
import triton
|
||||
from torch.nn import functional as F
|
||||
from transformers import AutoConfig
|
||||
|
||||
from sglang.srt.layers.moe.fused_moe_triton.fused_moe import (
|
||||
fused_moe as fused_moe_triton,
|
||||
)
|
||||
from sglang.srt.model_executor.cuda_graph_runner import set_torch_compile_config
|
||||
|
||||
|
||||
def get_model_config(model_name: str, tp_size: int):
    """Load a model's Hugging Face config and derive the MoE benchmark shapes.

    Args:
        model_name: Identifier or path handed to ``AutoConfig.from_pretrained``.
        tp_size: Divisor applied to the doubled expert intermediate size.

    Returns:
        A dict with the keys ``num_experts``, ``topk``, ``hidden_size``,
        ``shard_intermediate_size`` and ``dtype``.
    """
    config = AutoConfig.from_pretrained(model_name, trust_remote_code=True)
    arch = config.architectures[0]

    if arch == "DbrxForCausalLM":
        E = config.ffn_config.moe_num_experts
        topk = config.ffn_config.moe_top_k
        intermediate_size = config.ffn_config.ffn_hidden_size
    elif arch == "JambaForCausalLM":
        E = config.num_experts
        topk = config.num_experts_per_tok
        intermediate_size = config.intermediate_size
    elif arch in ("Qwen2MoeForCausalLM", "Qwen3MoeForCausalLM"):
        # Qwen3 MoE is read through ``num_experts`` like Qwen2 MoE; this
        # branch previously read ``n_routed_experts`` for Qwen3.
        E = config.num_experts
        topk = config.num_experts_per_tok
        intermediate_size = config.moe_intermediate_size
    elif arch in ("DeepseekV2ForCausalLM", "DeepseekV3ForCausalLM"):
        E = config.n_routed_experts
        topk = config.num_experts_per_tok
        intermediate_size = config.moe_intermediate_size
    elif arch == "Llama4ForConditionalGeneration":
        E = config.text_config.num_local_experts
        topk = config.text_config.num_experts_per_tok
        intermediate_size = config.text_config.intermediate_size
    elif arch in ("Grok1ForCausalLM", "Grok1ImgGen", "Grok1AForCausalLM"):
        E = config.num_local_experts
        topk = config.num_experts_per_tok
        intermediate_size = config.moe_intermediate_size
    else:
        # Default: Mixtral
        E = config.num_local_experts
        topk = config.num_experts_per_tok
        intermediate_size = config.intermediate_size

    # The sharding formula is the same for every architecture.
    shard_intermediate_size = 2 * intermediate_size // tp_size

    shape_configs = {
        "num_experts": E,
        "topk": topk,
        "hidden_size": config.hidden_size,
        "shard_intermediate_size": shard_intermediate_size,
        "dtype": config.torch_dtype,
    }
    print(f"{shape_configs=}")
    return shape_configs
|
||||
|
||||
|
||||
def fused_topk_native(
    hidden_states: torch.Tensor,
    gating_output: torch.Tensor,
    topk: int,
    renormalize: bool,
):
    """Select the top-k experts per token from softmaxed gating logits.

    Args:
        hidden_states: Only its first dimension is inspected; it must match
            the first dimension of ``gating_output``.
        gating_output: Router logits; softmax is taken over the last dim
            after casting to float32.
        topk: Number of experts to keep per token.
        renormalize: If true, divide the kept weights by their per-token sum.

    Returns:
        ``(topk_weights, topk_ids)`` as produced by ``torch.topk`` on the
        softmax probabilities (weights optionally renormalized).

    Raises:
        AssertionError: If the two inputs disagree on the number of tokens.
    """
    assert hidden_states.shape[0] == gating_output.shape[0], "Number of tokens mismatch"

    # No output buffers are pre-allocated: softmax and topk return fresh
    # tensors, so the former torch.empty placeholders were never read.
    probs = F.softmax(gating_output.float(), dim=-1)
    topk_weights, topk_ids = torch.topk(probs, topk, dim=-1)
    if renormalize:
        topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
    return topk_weights, topk_ids
|
||||
|
||||
|
||||
@torch.compile(dynamic=False)
def fused_moe_torch(
    x,
    w1,
    w2,
    input_gating,
    topk,
    use_fp8_w8a8=False,
    w1_scale=None,
    w2_scale=None,
    a1_scale=None,
    a2_scale=None,
) -> torch.Tensor:
    """Pure-PyTorch MoE forward pass, wrapped by ``torch.compile``.

    Routing uses ``fused_topk_native`` with renormalization. The weights of
    the selected experts are gathered per token, ``w1`` is split in two
    along dim 2 (SiLU is applied to the projection through the first half),
    and the expert outputs are combined using the routing weights. The
    scale arguments are accepted but not used.

    Raises:
        AssertionError: If ``use_fp8_w8a8`` is true.
    """
    assert not use_fp8_w8a8, "Fp8_w8a8 fused_moe is not supported for torch compile"

    routing_weights, expert_ids = fused_topk_native(
        hidden_states=x,
        gating_output=input_gating,
        topk=topk,
        renormalize=True,
    )

    # Gather the weights of the selected experts for every token.
    gate_proj, up_proj = torch.chunk(w1[expert_ids], 2, dim=2)
    down_proj = w2[expert_ids]

    gated = F.silu(torch.einsum("ti,taoi -> tao", x, gate_proj))
    up = torch.einsum("ti, taoi -> tao", x, up_proj)
    per_expert = torch.einsum("tao, taio -> tai", (gated * up), down_proj)

    # Weighted sum over the selected experts of each token.
    combine_weights = routing_weights.to(per_expert.dtype)
    return torch.einsum("tai,ta -> ti", per_expert, combine_weights)
|
||||
|
||||
|
||||
def fused_moe_torch_compile(
    x,
    w1,
    w2,
    input_gating,
    topk,
    use_fp8_w8a8=False,
    w1_scale=None,
    w2_scale=None,
    a1_scale=None,
    a2_scale=None,
):
    """Forward every argument unchanged to ``fused_moe_torch``.

    Gives the compiled implementation the same call shape as the sglang
    wrapper used by the benchmark.
    """
    quant_kwargs = {
        "use_fp8_w8a8": use_fp8_w8a8,
        "w1_scale": w1_scale,
        "w2_scale": w2_scale,
        "a1_scale": a1_scale,
        "a2_scale": a2_scale,
    }
    return fused_moe_torch(x, w1, w2, input_gating, topk, **quant_kwargs)
|
||||
|
||||
|
||||
def fused_moe_sglang_api(
    x,
    w1,
    w2,
    input_gating,
    topk,
    use_fp8_w8a8=False,
    w1_scale=None,
    w2_scale=None,
    a1_scale=None,
    a2_scale=None,
):
    """Call sglang's Triton ``fused_moe`` with renormalization, in place.

    NOTE(review): raw router logits and the top-k count are passed
    positionally together with ``renormalize``/``inplace`` keywords;
    confirm this matches the ``fused_moe`` signature of the installed
    sglang version.

    Returns:
        Whatever ``fused_moe`` returns.
    """
    quant_kwargs = {
        "use_fp8_w8a8": use_fp8_w8a8,
        "w1_scale": w1_scale,
        "w2_scale": w2_scale,
        "a1_scale": a1_scale,
        "a2_scale": a2_scale,
    }
    return fused_moe_triton(
        x,
        w1,
        w2,
        input_gating,
        topk,
        renormalize=True,
        inplace=True,
        **quant_kwargs,
    )
|
||||
|
||||
|
||||
@triton.testing.perf_report(
    triton.testing.Benchmark(
        x_names=["batch_size"],
        x_vals=list(range(1, 5)),
        line_arg="provider",
        line_vals=["fused_moe_triton", "fused_moe_torch_compile"],
        line_names=["fused_moe_triton", "fused_moe_torch_compile"],
        styles=[("blue", "-"), ("green", "-")],
        ylabel="Time (ms)",
        plot_name="fused-moe-performance",
        args={},
    )
)
def benchmark(batch_size, provider, model_config, use_fp8_w8a8=False):
    """Time one MoE provider at one batch size.

    Args:
        batch_size: Number of tokens in the random input.
        provider: ``"fused_moe_torch_compile"`` selects the compiled PyTorch
            path; anything else selects ``fused_moe_sglang_api``.
        model_config: Dict providing ``num_experts``, ``hidden_size``,
            ``shard_intermediate_size``, ``topk`` and ``dtype``.
        use_fp8_w8a8: Cast the weights to ``float8_e4m3fn`` and create
            random scale tensors.

    Returns:
        The three values ``do_bench`` reports for quantiles 0.5, 0.2 and 0.8.
    """
    print(f"benchmark {provider} with batch_size={batch_size}")
    torch.set_default_device("cuda")
    torch.cuda.manual_seed_all(0)
    set_torch_compile_config()

    num_tokens = batch_size
    num_experts = model_config["num_experts"]
    hidden_size = model_config["hidden_size"]
    shard_intermediate_size = model_config["shard_intermediate_size"]
    topk = model_config["topk"]
    dtype = model_config["dtype"]

    # Random tensors are drawn in a fixed order: x, w1, w2, scales, logits.
    x = torch.randn(num_tokens, hidden_size, dtype=dtype)
    w1 = torch.randn(num_experts, shard_intermediate_size, hidden_size, dtype=dtype)
    w2 = torch.randn(
        num_experts, hidden_size, shard_intermediate_size // 2, dtype=dtype
    )

    w1_scale = w2_scale = a1_scale = a2_scale = None
    if use_fp8_w8a8:
        # Weights are generated in the model dtype and then cast to fp8.
        w1 = w1.to(torch.float8_e4m3fn)
        w2 = w2.to(torch.float8_e4m3fn)
        w1_scale = torch.randn(num_experts, dtype=torch.float32)
        w2_scale = torch.randn(num_experts, dtype=torch.float32)
        a1_scale = torch.randn(1, dtype=torch.float32)
        a2_scale = torch.randn(1, dtype=torch.float32)

    input_gating = torch.randn(num_tokens, num_experts, dtype=torch.float32)

    if provider == "fused_moe_torch_compile":
        api_func = fused_moe_torch_compile
    else:
        api_func = fused_moe_sglang_api

    def invoke():
        return api_func(
            x,
            w1,
            w2,
            input_gating,
            topk,
            use_fp8_w8a8=use_fp8_w8a8,
            w1_scale=w1_scale,
            w2_scale=w2_scale,
            a1_scale=a1_scale,
            a2_scale=a2_scale,
        )

    # Warmup
    for _ in range(10):
        invoke()
    torch.cuda.synchronize()

    ms, min_ms, max_ms = triton.testing.do_bench(
        lambda: invoke()[0],
        quantiles=[0.5, 0.2, 0.8],
    )
    return ms, min_ms, max_ms
|
||||
|
||||
|
||||
def main():
    """Parse CLI options and run the torch.compile vs. Triton benchmark."""
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--model", type=str, default="mistralai/Mixtral-8x7B-Instruct-v0.1"
    )
    parser.add_argument("--tp-size", type=int, default=2)
    parser.add_argument("--use-fp8-w8a8", action="store_true")
    parser.add_argument(
        "--save-path",
        type=str,
        default="./configs/benchmark_ops/fused_moe_torch_compile/",
    )
    args = parser.parse_args()

    # Shapes are derived once and handed to every benchmark invocation.
    shapes = get_model_config(args.model, args.tp_size)
    run_options = {
        "show_plots": True,
        "print_data": True,
        "save_path": args.save_path,
        "model_config": shapes,
        "use_fp8_w8a8": args.use_fp8_w8a8,
    }
    benchmark.run(**run_options)
|
||||
|
||||
|
||||
# Run the benchmark CLI only when this file is executed as a script.
if __name__ == "__main__":
    main()
|
||||
@@ -0,0 +1,349 @@
|
||||
# python3 benchmark/kernels/fused_moe_triton/benchmark_vllm_vs_sglang_fused_moe_triton.py --model /DeepSeek-V3/ --tp-size 8 --use-fp8-w8a8
|
||||
import argparse
|
||||
|
||||
import torch
|
||||
import triton
|
||||
import vllm
|
||||
from transformers import AutoConfig
|
||||
from vllm.model_executor.layers.fused_moe.fused_moe import fused_moe as fused_moe_vllm
|
||||
|
||||
from sglang.srt.distributed.parallel_state import (
|
||||
destroy_distributed_environment,
|
||||
destroy_model_parallel,
|
||||
init_distributed_environment,
|
||||
initialize_model_parallel,
|
||||
)
|
||||
from sglang.srt.layers.moe.fused_moe_triton.fused_moe import (
|
||||
fused_moe as fused_moe_sglang,
|
||||
)
|
||||
|
||||
|
||||
def get_model_config(model_name: str, tp_size: int):
    """Load a model's Hugging Face config and derive the MoE benchmark shapes.

    Args:
        model_name: Identifier or path handed to ``AutoConfig.from_pretrained``.
        tp_size: Divisor applied to the doubled expert intermediate size.

    Returns:
        A dict with the keys ``num_experts``, ``topk``, ``hidden_size``,
        ``shard_intermediate_size``, ``dtype`` and ``block_shape``
        (``None`` unless the config carries a ``weight_block_size``).

    Raises:
        AssertionError: If a block shape is present but does not have two
            entries, or if the installed vLLM is older than 0.6.6.
    """
    config = AutoConfig.from_pretrained(model_name, trust_remote_code=True)
    arch = config.architectures[0]

    if arch == "DbrxForCausalLM":
        E = config.ffn_config.moe_num_experts
        topk = config.ffn_config.moe_top_k
        intermediate_size = config.ffn_config.ffn_hidden_size
    elif arch == "JambaForCausalLM":
        E = config.num_experts
        topk = config.num_experts_per_tok
        intermediate_size = config.intermediate_size
    elif arch in ("Qwen2MoeForCausalLM", "Qwen3MoeForCausalLM"):
        E = config.num_experts
        topk = config.num_experts_per_tok
        intermediate_size = config.moe_intermediate_size
    elif arch in (
        "DeepseekV2ForCausalLM",
        "DeepseekV3ForCausalLM",
        "Glm4MoeForCausalLM",
    ):
        # DeepseekV3 is benchmarked with one expert more than it routes to.
        E = (
            config.n_routed_experts + 1
            if arch == "DeepseekV3ForCausalLM"
            else config.n_routed_experts
        )
        topk = config.num_experts_per_tok
        intermediate_size = config.moe_intermediate_size
    elif arch == "Llama4ForConditionalGeneration":
        E = config.text_config.num_local_experts
        topk = config.text_config.num_experts_per_tok
        intermediate_size = config.text_config.intermediate_size
    elif arch in ("Grok1ForCausalLM", "Grok1ImgGen", "Grok1AForCausalLM"):
        E = config.num_local_experts
        topk = config.num_experts_per_tok
        intermediate_size = config.moe_intermediate_size
    else:
        # Default: Mixtral
        E = config.num_local_experts
        topk = config.num_experts_per_tok
        intermediate_size = config.intermediate_size

    # The sharding formula is the same for every architecture.
    shard_intermediate_size = 2 * intermediate_size // tp_size

    block_shape = None
    if (
        hasattr(config, "quantization_config")
        and "weight_block_size" in config.quantization_config
    ):
        block_shape = config.quantization_config["weight_block_size"]
        assert len(block_shape) == 2
        # Compare the version component-wise. Folding it into
        # major*100 + minor*10 + patch is lossy: 0.5.16 would score 66 and
        # pass. The check only matters for block-quantized checkpoints, so
        # the version is inspected only here.
        assert tuple(vllm.__version_tuple__[:3]) >= (
            0,
            6,
            6,
        ), "Block-wise quantized fp8 fused_moe is only supported for VLLM>=0.6.6.post1"

    shape_configs = {
        "num_experts": E,
        "topk": topk,
        "hidden_size": config.hidden_size,
        "shard_intermediate_size": shard_intermediate_size,
        "dtype": config.torch_dtype,
        "block_shape": block_shape,
    }
    print(f"{shape_configs=}")
    return shape_configs
|
||||
|
||||
|
||||
def fused_moe_vllm_api(
    x,
    w1,
    w2,
    input_gating,
    topk,
    use_fp8_w8a8=False,
    w1_scale=None,
    w2_scale=None,
    a1_scale=None,
    a2_scale=None,
    block_shape=None,
):
    """Call ``fused_moe_vllm`` with the fixed settings used by this benchmark.

    ``renormalize`` and ``inplace`` are always ``True``; the quantization flag
    and the four scale tensors are forwarded unchanged.

    Args:
        x: Input activations.
        w1: First expert weight tensor.
        w2: Second expert weight tensor.
        input_gating: Router logits handed to the kernel.
        topk: Number of experts selected per token.
        use_fp8_w8a8: Forwarded as the kernel's FP8 flag.
        w1_scale, w2_scale, a1_scale, a2_scale: Optional scale tensors.
        block_shape: Optional block-quantization shape. It is passed to the
            kernel only when it is not ``None``.

    Returns:
        Whatever ``fused_moe_vllm`` returns.
    """
    # The keyword is omitted entirely when no block shape is given
    # (presumably so vLLM builds without that parameter still work - confirm).
    optional_kwargs = {} if block_shape is None else {"block_shape": block_shape}
    return fused_moe_vllm(
        x,
        w1,
        w2,
        input_gating,
        topk,
        renormalize=True,
        inplace=True,
        use_fp8_w8a8=use_fp8_w8a8,
        w1_scale=w1_scale,
        w2_scale=w2_scale,
        a1_scale=a1_scale,
        a2_scale=a2_scale,
        **optional_kwargs,
    )
|
||||
|
||||
|
||||
def fused_moe_sglang_api(
    x,
    w1,
    w2,
    input_gating,
    topk,
    use_fp8_w8a8=False,
    w1_scale=None,
    w2_scale=None,
    a1_scale=None,
    a2_scale=None,
    block_shape=None,
):
    """Call ``fused_moe_sglang`` with the fixed settings used by this benchmark.

    ``renormalize`` and ``inplace`` are always ``True``. Unlike the vLLM
    wrapper, ``block_shape`` is always forwarded, even when it is ``None``.

    Args:
        x: Input activations.
        w1: First expert weight tensor.
        w2: Second expert weight tensor.
        input_gating: Router logits handed to the kernel.
        topk: Number of experts selected per token.
        use_fp8_w8a8: Forwarded as the kernel's FP8 flag.
        w1_scale, w2_scale, a1_scale, a2_scale: Optional scale tensors.
        block_shape: Optional block-quantization shape.

    Returns:
        Whatever ``fused_moe_sglang`` returns.
    """
    quant_kwargs = {
        "use_fp8_w8a8": use_fp8_w8a8,
        "w1_scale": w1_scale,
        "w2_scale": w2_scale,
        "a1_scale": a1_scale,
        "a2_scale": a2_scale,
        "block_shape": block_shape,
    }
    return fused_moe_sglang(
        x,
        w1,
        w2,
        input_gating,
        topk,
        renormalize=True,
        inplace=True,
        **quant_kwargs,
    )
|
||||
|
||||
|
||||
@triton.testing.perf_report(
    triton.testing.Benchmark(
        x_names=["batch_size"],
        x_vals=list(range(1, 513)),
        line_arg="provider",
        line_vals=[
            "vllm_fused_moe_triton",
            "sglang_fused_moe_triton",
        ],
        line_names=[
            "vllm_fused_moe_triton",
            "sglang_fused_moe_triton",
        ],
        styles=[
            ("blue", "-"),
            ("green", "-"),
        ],
        ylabel="Time (ms)",
        plot_name="fused-moe-performance",
        args={},
    )
)
def benchmark(batch_size, provider, model_config, use_fp8_w8a8=False):
    """Time one fused-MoE provider on random data for a single batch size.

    Args:
        batch_size: Number of rows (tokens) in the random input.
        provider: ``"vllm_fused_moe_triton"`` selects the vLLM wrapper; any
            other value selects the SGLang wrapper.
        model_config: Mapping providing ``num_experts``, ``hidden_size``,
            ``shard_intermediate_size``, ``topk``, ``dtype`` and
            ``block_shape``.
        use_fp8_w8a8: When true, weights are cast to ``float8_e4m3fn`` and
            scale tensors are generated (per-tensor, or per-block when
            ``block_shape`` is set).

    Returns:
        The three values ``triton.testing.do_bench`` reports for the
        quantiles ``[0.5, 0.2, 0.8]``, in that order.
    """
    print(f"benchmark {provider} with batch_size={batch_size}")
    torch.set_default_device("cuda")
    torch.cuda.manual_seed_all(0)

    num_tokens = batch_size
    num_experts = model_config["num_experts"]
    hidden_size = model_config["hidden_size"]
    shard_intermediate_size = model_config["shard_intermediate_size"]
    topk = model_config["topk"]
    dtype = model_config["dtype"]
    block_shape = model_config["block_shape"]

    def _ceil_div(numerator, denominator):
        # Number of tiles of size ``denominator`` needed to cover ``numerator``.
        return (numerator + denominator - 1) // denominator

    x = torch.randn(num_tokens, hidden_size, dtype=dtype)

    # Weights are always drawn in the model dtype first; the FP8 path casts
    # them afterwards.
    w1 = torch.randn(num_experts, shard_intermediate_size, hidden_size, dtype=dtype)
    w2 = torch.randn(
        num_experts, hidden_size, shard_intermediate_size // 2, dtype=dtype
    )

    w1_scale = w2_scale = a1_scale = a2_scale = None
    if use_fp8_w8a8:
        w1 = w1.to(torch.float8_e4m3fn)
        w2 = w2.to(torch.float8_e4m3fn)

        if block_shape is None:
            # One weight scale per expert, one activation scale overall.
            w1_scale = torch.randn(num_experts, dtype=torch.float32)
            w2_scale = torch.randn(num_experts, dtype=torch.float32)
            a1_scale = torch.randn(1, dtype=torch.float32)
            a2_scale = torch.randn(1, dtype=torch.float32)
        else:
            # One weight scale per (block_n x block_k) tile of each expert.
            block_n, block_k = block_shape[0], block_shape[1]
            n_tiles_w1 = _ceil_div(shard_intermediate_size, block_n)
            n_tiles_w2 = _ceil_div(hidden_size, block_n)
            k_tiles_w1 = _ceil_div(hidden_size, block_k)
            k_tiles_w2 = _ceil_div(shard_intermediate_size // 2, block_k)
            w1_scale = torch.rand(
                (num_experts, n_tiles_w1, k_tiles_w1), dtype=torch.float32
            )
            w2_scale = torch.rand(
                (num_experts, n_tiles_w2, k_tiles_w2), dtype=torch.float32
            )

    input_gating = torch.randn(num_tokens, num_experts, dtype=torch.float32)

    if provider == "vllm_fused_moe_triton":
        api_func = fused_moe_vllm_api
    else:
        api_func = fused_moe_sglang_api

    def _invoke():
        # Single kernel call shared by the warmup loop and the timed run.
        return api_func(
            x,
            w1,
            w2,
            input_gating,
            topk,
            use_fp8_w8a8=use_fp8_w8a8,
            w1_scale=w1_scale,
            w2_scale=w2_scale,
            a1_scale=a1_scale,
            a2_scale=a2_scale,
            block_shape=block_shape,
        )

    # Warmup
    for _ in range(10):
        _invoke()
    torch.cuda.synchronize()

    quantiles = [0.5, 0.2, 0.8]
    ms, min_ms, max_ms = triton.testing.do_bench(
        lambda: _invoke()[0],
        quantiles=quantiles,
    )
    return ms, min_ms, max_ms
|
||||
|
||||
|
||||
def main():
    """Parse the command line and run the vLLM vs. SGLang fused-MoE benchmark.

    A single-process distributed environment (world size 1, rank 0) is set up
    before the benchmark runs and is torn down afterwards, including when an
    exception is raised.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--model", type=str, default="mistralai/Mixtral-8x7B-Instruct-v0.1"
    )
    parser.add_argument("--tp-size", type=int, default=2)
    parser.add_argument("--use-fp8-w8a8", action="store_true")
    parser.add_argument(
        "--save-path",
        type=str,
        default="./configs/benchmark_ops/vllm_sglang_fused_moe/",
    )
    args = parser.parse_args()

    init_method = "tcp://127.0.0.1:23456"

    try:
        if not torch.distributed.is_initialized():
            torch.distributed.init_process_group(
                backend="nccl" if torch.cuda.is_available() else "gloo",
                init_method=init_method,
                world_size=1,
                rank=0,
            )

        init_distributed_environment(
            world_size=1,
            rank=0,
            distributed_init_method=init_method,
            local_rank=0,
            backend="nccl" if torch.cuda.is_available() else "gloo",
        )

        initialize_model_parallel(
            tensor_model_parallel_size=1,
            pipeline_model_parallel_size=1,
        )

        # --tp-size only affects the shapes derived from the model config;
        # the process group itself always has a single rank.
        model_config = get_model_config(args.model, args.tp_size)
        benchmark.run(
            show_plots=True,
            print_data=True,
            save_path=args.save_path,
            model_config=model_config,
            use_fp8_w8a8=args.use_fp8_w8a8,
        )
    finally:
        destroy_model_parallel()
        destroy_distributed_environment()
|
||||
|
||||
|
||||
# Run the benchmark only when this file is executed as a script.
if __name__ == "__main__":
    main()
|
||||
599
benchmark/kernels/fused_moe_triton/tuning_fused_moe_triton.py
Normal file
599
benchmark/kernels/fused_moe_triton/tuning_fused_moe_triton.py
Normal file
@@ -0,0 +1,599 @@
|
||||
# Adapted from https://github.com/vllm-project/vllm/blob/main/benchmarks/kernels/benchmark_moe.py
|
||||
import argparse
|
||||
import json
|
||||
import time
|
||||
from contextlib import nullcontext
|
||||
from datetime import datetime
|
||||
from typing import Any, Dict, List, Tuple, TypedDict
|
||||
|
||||
import ray
|
||||
import torch
|
||||
import triton
|
||||
from ray.experimental.tqdm_ray import tqdm
|
||||
from transformers import AutoConfig
|
||||
|
||||
from sglang.srt.layers.moe.fused_moe_triton import override_config
|
||||
from sglang.srt.layers.moe.fused_moe_triton.fused_moe import (
|
||||
fused_moe,
|
||||
get_config_dtype_str,
|
||||
get_config_file_name,
|
||||
get_default_config,
|
||||
get_moe_configs,
|
||||
)
|
||||
from sglang.srt.layers.moe.moe_runner import MoeRunnerConfig
|
||||
from sglang.srt.layers.moe.topk import TopKConfig, select_experts
|
||||
from sglang.srt.utils import is_hip
|
||||
|
||||
_is_hip = is_hip()
|
||||
|
||||
|
||||
class BenchmarkConfig(TypedDict):
    """Kernel launch configuration benchmarked for the fused-MoE kernel."""

    # Tile and grouping parameters.
    BLOCK_SIZE_M: int
    BLOCK_SIZE_N: int
    BLOCK_SIZE_K: int
    GROUP_SIZE_M: int
    # Launch parameters.
    num_warps: int
    num_stages: int
|
||||
|
||||
|
||||
def benchmark_config(
    config: BenchmarkConfig,
    num_tokens: int,
    num_experts: int,
    shard_intermediate_size: int,
    hidden_size: int,
    topk: int,
    dtype: torch.dtype,
    use_fp8_w8a8: bool,
    use_int8_w8a8: bool,
    use_int8_w8a16: bool,
    block_shape: List[int] = None,
    num_iters: int = 100,
) -> float:
    """Measure the fused-MoE kernel latency for one kernel configuration.

    Random inputs, weights and (when quantized) scales are generated, ten
    kernel invocations are captured in a CUDA graph, and the graph is
    replayed ``num_iters`` times with freshly selected experts each time.

    Args:
        config: Kernel configuration applied through ``override_config``.
        num_tokens: Number of input rows.
        num_experts: Number of experts.
        shard_intermediate_size: First-dimension size of each ``w1`` expert
            matrix; ``w2`` uses half of it as its last dimension.
        hidden_size: Hidden dimension of the input.
        topk: Number of experts selected per token.
        dtype: Dtype of the input and, unless quantized, of the weights.
        use_fp8_w8a8: Benchmark with FP8 weights.
        use_int8_w8a8: Benchmark with INT8 weights (w8a8 scales).
        use_int8_w8a16: Benchmark with INT8 weights (w8a16 scales).
        block_shape: Optional ``[block_n, block_k]`` for block-wise scales.
        num_iters: Number of timed graph replays.

    Returns:
        Average time of one kernel invocation, in microseconds.
    """
    w1_shape = (num_experts, shard_intermediate_size, hidden_size)
    w2_shape = (num_experts, hidden_size, shard_intermediate_size // 2)

    # FP8 weights are drawn in float16 and cast to the FP8 dtype further down.
    init_dtype = torch.float16 if use_fp8_w8a8 else dtype

    x = torch.randn(num_tokens, hidden_size, dtype=dtype)
    if use_int8_w8a16 or use_int8_w8a8:
        w1 = torch.randint(-127, 127, w1_shape, dtype=torch.int8)
        w2 = torch.randint(-127, 127, w2_shape, dtype=torch.int8)
    else:
        w1 = torch.randn(*w1_shape, dtype=init_dtype)
        w2 = torch.randn(*w2_shape, dtype=init_dtype)
    # One set of router logits per timed iteration.
    gating_output = torch.randn(num_iters, num_tokens, num_experts, dtype=torch.float32)

    w1_scale = w2_scale = a1_scale = a2_scale = None
    if use_int8_w8a16:
        w1_scale = torch.randn(
            (num_experts, 2 * shard_intermediate_size), dtype=torch.float32
        )
        w2_scale = torch.randn((hidden_size, num_experts), dtype=torch.float32)
    if use_fp8_w8a8 or use_int8_w8a8:
        if block_shape is not None:
            # Block-wise scales: one value per (block_n x block_k) weight tile.
            block_n, block_k = block_shape[0], block_shape[1]
            n_tiles_w1 = (shard_intermediate_size + block_n - 1) // block_n
            n_tiles_w2 = (hidden_size + block_n - 1) // block_n
            k_tiles_w1 = (hidden_size + block_k - 1) // block_k
            k_tiles_w2 = (shard_intermediate_size // 2 + block_k - 1) // block_k
            w1_scale = torch.rand(
                (num_experts, n_tiles_w1, k_tiles_w1), dtype=torch.float32
            )
            w2_scale = torch.rand(
                (num_experts, n_tiles_w2, k_tiles_w2), dtype=torch.float32
            )
        elif use_int8_w8a8:
            # INT8 w8a8 without blocks: one scale per expert output row.
            w1_scale = torch.randn(
                num_experts, shard_intermediate_size, dtype=torch.float32
            )
            w2_scale = torch.randn(num_experts, hidden_size, dtype=torch.float32)
        else:
            # FP8 without blocks: per-expert weight scales, scalar activation
            # scales.
            w1_scale = torch.randn(num_experts, dtype=torch.float32)
            w2_scale = torch.randn(num_experts, dtype=torch.float32)
            a1_scale = torch.randn(1, dtype=torch.float32)
            a2_scale = torch.randn(1, dtype=torch.float32)

    if use_fp8_w8a8:
        fp8_dtype = torch.float8_e4m3fnuz if _is_hip else torch.float8_e4m3fn
        w1 = w1.to(fp8_dtype)
        w2 = w2.to(fp8_dtype)

    input_gating = torch.randn(num_tokens, num_experts, dtype=torch.float32)
    topk_config = TopKConfig(
        top_k=topk,
        renormalize=True,
    )
    topk_output = select_experts(x, input_gating, topk_config)

    def prepare(i: int):
        # Refresh the routing result in place so the captured graph, which
        # reads the tensors inside ``topk_output``, sees new data on replay.
        fresh = select_experts(x, gating_output[i], topk_config)
        topk_output.topk_weights.copy_(fresh.topk_weights)
        topk_output.topk_ids.copy_(fresh.topk_ids)
        topk_output.router_logits.copy_(fresh.router_logits)

    def run():
        moe_runner_config = MoeRunnerConfig(
            inplace=True,
        )

        with override_config(config):
            fused_moe(
                x,
                w1,
                w2,
                topk_output,
                moe_runner_config=moe_runner_config,
                use_fp8_w8a8=use_fp8_w8a8,
                use_int8_w8a8=use_int8_w8a8,
                use_int8_w8a16=use_int8_w8a16,
                w1_scale=w1_scale,
                w2_scale=w2_scale,
                a1_scale=a1_scale,
                a2_scale=a2_scale,
                block_shape=block_shape,
            )

    # JIT compilation & warmup
    run()
    torch.cuda.synchronize()

    # Capture 10 invocations with CUDA graph
    graph = torch.cuda.CUDAGraph()
    with torch.cuda.graph(graph):
        for _ in range(10):
            run()
    torch.cuda.synchronize()

    # Warmup
    for _ in range(5):
        graph.replay()
    torch.cuda.synchronize()

    start_event = torch.cuda.Event(enable_timing=True)
    end_event = torch.cuda.Event(enable_timing=True)

    latencies: List[float] = []
    for iteration in range(num_iters):
        prepare(iteration)
        torch.cuda.synchronize()

        start_event.record()
        graph.replay()
        end_event.record()
        end_event.synchronize()
        latencies.append(start_event.elapsed_time(end_event))

    # Each replay runs 10 invocations; event times are in milliseconds, so the
    # factor 1000 converts the per-invocation mean to microseconds.
    avg = sum(latencies) / (num_iters * 10) * 1000  # us
    graph.reset()
    return avg
|
||||
|
||||
|
||||
def get_rocm_configs_compute_bound() -> List[Dict[str, int]]:
    """Build the kernel-config search space used when running on ROCm.

    Returns:
        One dict per combination of the candidate values below, ordered with
        ``num_stages`` varying slowest and ``GROUP_SIZE_M`` varying fastest.
        Every entry also carries ``waves_per_eu`` fixed at 0.
    """
    waves_per_eu_range = 0
    return [
        {
            "BLOCK_SIZE_M": block_m,
            "BLOCK_SIZE_N": block_n,
            "BLOCK_SIZE_K": block_k,
            "GROUP_SIZE_M": group_size,
            "num_warps": num_warps,
            "num_stages": num_stages,
            "waves_per_eu": waves_per_eu_range,
        }
        for num_stages in (2,)
        for block_m in (32, 64, 128, 256)
        for block_k in (32, 64, 128, 256)
        for block_n in (16, 32, 64, 128, 256)
        for num_warps in (1, 2, 4, 8)
        for group_size in (1, 4, 8, 16, 32)
    ]
|
||||
|
||||
|
||||
def get_configs_compute_bound() -> List[Dict[str, int]]:
    """Build the kernel-config search space for tuning.

    On ROCm (``_is_hip`` true) the list comes from
    ``get_rocm_configs_compute_bound``. Otherwise it is every combination of
    the candidate values below, with ``num_stages`` varying slowest and
    ``GROUP_SIZE_M`` varying fastest.
    """
    # Reduced search space for faster tuning.
    # TODO(woosuk): Increase the search space and use a performance model to
    # prune the search space.
    if _is_hip:
        return get_rocm_configs_compute_bound()

    return [
        {
            "BLOCK_SIZE_M": block_m,
            "BLOCK_SIZE_N": block_n,
            "BLOCK_SIZE_K": block_k,
            "GROUP_SIZE_M": group_size,
            "num_warps": num_warps,
            "num_stages": num_stages,
        }
        for num_stages in (2, 3, 4, 5)
        for block_m in (16, 32, 64, 128, 256)
        for block_k in (64, 128, 256)
        for block_n in (32, 64, 128, 256)
        for num_warps in (4, 8)
        for group_size in (1, 16, 32, 64)
    ]
|
||||
|
||||
|
||||
@ray.remote(num_gpus=1)
class BenchmarkWorker:
    """Ray actor (declared with ``num_gpus=1``) that benchmarks or tunes configs."""

    def __init__(self, seed: int) -> None:
        """Select CUDA as the default device and seed the CUDA RNGs.

        Args:
            seed: RNG seed used here and again at the start of ``benchmark``.
        """
        torch.set_default_device("cuda")
        # Use the caller's seed; it was previously stored but ignored in
        # favour of a hard-coded 0 (the CLI default is 0, so default runs are
        # unchanged).
        torch.cuda.manual_seed_all(seed)
        self.seed = seed
        # Get the device ID to allocate tensors and kernels
        # on the respective GPU.
        self.device_id = int(ray.get_gpu_ids()[0])

    def benchmark(
        self,
        num_tokens: int,
        num_experts: int,
        shard_intermediate_size: int,
        hidden_size: int,
        topk: int,
        dtype: torch.dtype,
        use_fp8_w8a8: bool,
        use_int8_w8a8: bool,
        use_int8_w8a16: bool,
        block_shape: List[int],
    ) -> Tuple[Dict[str, int], float]:
        """Benchmark the stored (or default) config for one batch size.

        The config is looked up with ``get_moe_configs``; when none is found,
        ``get_default_config`` is used. When configs exist, the entry whose
        key is numerically closest to ``num_tokens`` is chosen.

        Returns:
            ``(config, kernel_time)`` where ``kernel_time`` is the value
            returned by ``benchmark_config``.
        """
        torch.cuda.manual_seed_all(self.seed)
        # Pass every quantization flag, matching ``save_configs``; omitting
        # ``use_int8_w8a8`` made INT8 w8a8 runs look up a config under a
        # different dtype string than the one used when saving.
        dtype_str = get_config_dtype_str(
            dtype,
            use_int8_w8a16=use_int8_w8a16,
            use_fp8_w8a8=use_fp8_w8a8,
            use_int8_w8a8=use_int8_w8a8,
        )
        # NOTE(woosuk): The current naming convention uses w2.shape[2], which
        # is the intermediate size after silu_and_mul.
        block_n = block_shape[0] if block_shape else 0
        block_k = block_shape[1] if block_shape else 0
        op_config = get_moe_configs(
            num_experts, shard_intermediate_size // 2, dtype_str, block_n, block_k
        )
        if op_config is None:
            config = get_default_config(
                num_tokens,
                num_experts,
                shard_intermediate_size,
                hidden_size,
                topk,
                dtype_str,
                False,
                block_shape,
            )
        else:
            # Pick the tuned batch size nearest to the requested one.
            config = op_config[min(op_config, key=lambda m: abs(m - num_tokens))]
        with torch.cuda.device(self.device_id) if is_hip() else nullcontext():
            kernel_time = benchmark_config(
                config,
                num_tokens,
                num_experts,
                shard_intermediate_size,
                hidden_size,
                topk,
                dtype,
                use_fp8_w8a8,
                use_int8_w8a8,
                use_int8_w8a16,
                block_shape,
            )
        return config, kernel_time

    def tune(
        self,
        num_tokens: int,
        num_experts: int,
        shard_intermediate_size: int,
        hidden_size: int,
        topk: int,
        dtype: torch.dtype,
        use_fp8_w8a8: bool,
        use_int8_w8a8: bool,
        use_int8_w8a16: bool,
        block_shape: List[int],
        search_space: List[Dict[str, int]],
    ) -> Dict[str, int]:
        """Return the fastest config in ``search_space`` for one batch size.

        Each candidate is timed with ``benchmark_config`` using 10 iterations.
        Candidates that raise Triton's ``OutOfResources`` or a
        ``RuntimeError`` are skipped.

        Raises:
            AssertionError: If no candidate could be benchmarked.
        """
        best_config = None
        best_time = float("inf")
        with torch.cuda.device(self.device_id) if is_hip() else nullcontext():
            for config in tqdm(search_space):
                try:
                    kernel_time = benchmark_config(
                        config,
                        num_tokens,
                        num_experts,
                        shard_intermediate_size,
                        hidden_size,
                        topk,
                        dtype,
                        use_fp8_w8a8,
                        use_int8_w8a8,
                        use_int8_w8a16,
                        block_shape,
                        num_iters=10,
                    )
                except (triton.runtime.autotuner.OutOfResources, RuntimeError):
                    # Some configurations may be invalid and fail to compile.
                    continue

                if kernel_time < best_time:
                    best_time = kernel_time
                    best_config = config
        now = datetime.now()
        print(f"{now.ctime()}] Completed tuning for batch_size={num_tokens}")
        assert best_config is not None
        return best_config
|
||||
|
||||
|
||||
def sort_config(config: BenchmarkConfig) -> BenchmarkConfig:
    """Return a copy of ``config`` with its keys in a fixed, canonical order.

    Only the six standard keys are copied, followed by ``waves_per_eu`` when
    the input has it; any other key is dropped. A missing standard key raises
    ``KeyError``.
    """
    ordered_keys = (
        "BLOCK_SIZE_M",
        "BLOCK_SIZE_N",
        "BLOCK_SIZE_K",
        "GROUP_SIZE_M",
        "num_warps",
        "num_stages",
    )
    ordered = {key: config[key] for key in ordered_keys}
    if "waves_per_eu" in config:
        ordered["waves_per_eu"] = config["waves_per_eu"]
    return ordered
|
||||
|
||||
|
||||
def save_configs(
    configs: Dict[int, BenchmarkConfig],
    num_experts: int,
    shard_intermediate_size: int,
    hidden_size: int,
    topk: int,
    dtype: torch.dtype,
    use_fp8_w8a8: bool,
    use_int8_w8a8: bool,
    use_int8_w8a16: bool,
    block_shape: List[int],
) -> None:
    """Write the tuned configs as indented JSON, followed by a newline.

    The output path is whatever ``get_config_file_name`` returns for the
    expert count, half of ``shard_intermediate_size``, the dtype string and
    the block shape. ``hidden_size`` and ``topk`` are accepted but not used.

    Args:
        configs: Mapping from batch size to the best config found for it.
    """
    quant_flags = {
        "use_int8_w8a16": use_int8_w8a16,
        "use_fp8_w8a8": use_fp8_w8a8,
        "use_int8_w8a8": use_int8_w8a8,
    }
    dtype_str = get_config_dtype_str(dtype, **quant_flags)

    # NOTE(woosuk): The current naming convention uses w2.shape[2], which
    # is the intermediate size after silu_and_mul.
    size_after_activation = shard_intermediate_size // 2
    filename = get_config_file_name(
        num_experts,
        size_after_activation,
        dtype_str,
        block_shape,
    )

    print(f"Writing best config to {filename}...")
    with open(filename, "w") as out_file:
        json.dump(configs, out_file, indent=4)
        out_file.write("\n")
|
||||
|
||||
|
||||
def main(args: argparse.Namespace):
    """Tune or benchmark the fused-MoE kernel for the model named in ``args``.

    The MoE shape (expert count, top-k, intermediate size) is read from the
    model's Hugging Face config according to its first listed architecture.
    One ``BenchmarkWorker`` is created per GPU reported by Ray and the batch
    sizes are handed out to the workers round-robin.

    With ``args.tune`` the best config per batch size is searched for and
    saved; otherwise the existing/default config is benchmarked and printed.
    """
    print(args)

    config = AutoConfig.from_pretrained(args.model, trust_remote_code=True)
    arch = config.architectures[0]
    if arch == "DbrxForCausalLM":
        E = config.ffn_config.moe_num_experts
        topk = config.ffn_config.moe_top_k
        intermediate_size = config.ffn_config.ffn_hidden_size
    elif arch == "JambaForCausalLM":
        E = config.num_experts
        topk = config.num_experts_per_tok
        intermediate_size = config.intermediate_size
    elif arch in ("Qwen2MoeForCausalLM", "Qwen3MoeForCausalLM"):
        E = config.num_experts
        topk = config.num_experts_per_tok
        intermediate_size = config.moe_intermediate_size
    elif arch in ("DeepseekV2ForCausalLM", "DeepseekV3ForCausalLM"):
        E = config.n_routed_experts
        # Only DeepseekV3 counts one extra expert, and only while shared
        # experts fusion is enabled.
        if arch == "DeepseekV3ForCausalLM" and not args.disable_shared_experts_fusion:
            E = E + 1
        topk = config.num_experts_per_tok
        intermediate_size = config.moe_intermediate_size
    elif arch == "Llama4ForConditionalGeneration":
        E = config.text_config.num_local_experts
        if not args.disable_shared_experts_fusion:
            E = E + 1
        topk = config.text_config.num_experts_per_tok
        intermediate_size = config.text_config.intermediate_size
    elif arch in ("Grok1ForCausalLM", "Grok1ImgGen", "Grok1AForCausalLM"):
        E = config.num_local_experts
        topk = config.num_experts_per_tok
        intermediate_size = config.moe_intermediate_size
    elif arch == "Glm4MoeForCausalLM":
        E = config.n_routed_experts
        topk = config.num_experts_per_tok
        intermediate_size = config.moe_intermediate_size
    else:
        # Default: Mixtral
        E = config.num_local_experts
        topk = config.num_experts_per_tok
        intermediate_size = config.intermediate_size

    # Every architecture uses the same per-rank size formula.
    shard_intermediate_size = 2 * intermediate_size // args.tp_size

    hidden_size = getattr(config, "hidden_size", None) or config.text_config.hidden_size
    dtype = config.torch_dtype
    use_fp8_w8a8 = args.dtype == "fp8_w8a8"
    use_int8_w8a8 = args.dtype == "int8_w8a8"
    use_int8_w8a16 = args.dtype == "int8_w8a16"

    block_shape = None
    if (
        hasattr(config, "quantization_config")
        and "weight_block_size" in config.quantization_config
    ):
        block_shape = config.quantization_config["weight_block_size"]
        assert len(block_shape) == 2

    if args.batch_size is not None:
        batch_sizes = [args.batch_size]
    else:
        batch_sizes = [
            1,
            2,
            4,
            8,
            16,
            24,
            32,
            48,
            64,
            96,
            128,
            256,
            512,
            1024,
            1536,
            2048,
            3072,
            4096,
        ]

    ray.init()
    num_gpus = int(ray.available_resources()["GPU"])
    workers = [BenchmarkWorker.remote(args.seed) for _ in range(num_gpus)]

    def _distribute(method: str, inputs: List[Any]) -> List[Any]:
        # Submit one remote call per input, cycling through the workers, then
        # wait for all of them; results come back in input order.
        pending = []
        for position, input_args in enumerate(inputs):
            worker = workers[position % num_gpus]
            pending.append(getattr(worker, method).remote(*input_args))
        return ray.get(pending)

    # Positional arguments shared by BenchmarkWorker.tune and .benchmark,
    # placed after the batch size.
    shared_args = (
        E,
        shard_intermediate_size,
        hidden_size,
        topk,
        dtype,
        use_fp8_w8a8,
        use_int8_w8a8,
        use_int8_w8a16,
        block_shape,
    )

    if args.tune:
        search_space = get_configs_compute_bound()
        if block_shape is not None:
            # Keep only configs whose BLOCK_SIZE_K divides the block's K size.
            block_k = block_shape[1]
            search_space = [
                candidate
                for candidate in search_space
                if block_k % candidate["BLOCK_SIZE_K"] == 0
            ]
        print(f"Start tuning over {len(search_space)} configurations...")

        start = time.perf_counter()
        configs = _distribute(
            "tune",
            [(batch_size, *shared_args, search_space) for batch_size in batch_sizes],
        )
        best_configs = {
            M: sort_config(best) for M, best in zip(batch_sizes, configs)
        }
        save_configs(best_configs, *shared_args)
        end = time.perf_counter()
        print(f"Tuning took {end - start:.2f} seconds")
    else:
        outputs = _distribute(
            "benchmark",
            [(batch_size, *shared_args) for batch_size in batch_sizes],
        )

        for batch_size, (config, kernel_time) in zip(batch_sizes, outputs):
            print(f"Batch size: {batch_size}, config: {config}")
            print(f"Kernel time: {kernel_time:.2f} us")
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Command-line entry point: parse the options and hand them to main().
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--model", type=str, default="mistralai/Mixtral-8x7B-Instruct-v0.1"
    )
    parser.add_argument("--tp-size", "--tp", type=int, default=2)
    parser.add_argument(
        "--dtype",
        type=str,
        choices=["auto", "fp8_w8a8", "int8_w8a16", "int8_w8a8"],
        default="auto",
    )
    parser.add_argument("--seed", type=int, default=0)
    # When --batch-size is omitted, main() sweeps its built-in list of sizes.
    parser.add_argument("--batch-size", type=int, required=False)
    # Without --tune, main() only benchmarks the existing/default configs.
    parser.add_argument("--tune", action="store_true")
    parser.add_argument("--disable-shared-experts-fusion", action="store_true")
    args = parser.parse_args()

    main(args)
|
||||
Reference in New Issue
Block a user