Update fused_moe's benchmark (#3346)
@@ -2,6 +2,7 @@ import argparse
 
 import torch
 import triton
+import vllm
 from transformers import AutoConfig
 from vllm.model_executor.layers.fused_moe.fused_moe import fused_moe as fused_moe_vllm
 
@@ -29,11 +30,11 @@ def get_model_config(model_name: str, tp_size: int):
         topk = config.num_experts_per_tok
         intermediate_size = config.moe_intermediate_size
         shard_intermediate_size = 2 * intermediate_size // tp_size
-    elif config.architectures[0] == "DeepseekV2ForCausalLM":
+    elif config.architectures[0] in ["DeepseekV2ForCausalLM", "DeepseekV3ForCausalLM"]:
         E = config.n_routed_experts
         topk = config.num_experts_per_tok
-        intermediate_size = config.intermediate_size
-        shard_intermediate_size = 2 * intermediate_size // args.tp_size
+        intermediate_size = config.moe_intermediate_size
+        shard_intermediate_size = 2 * intermediate_size // tp_size
     else:
         # Default: Mixtral
         E = config.num_local_experts
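
Note on this hunk: DeepSeek-V2/V3 configs carry both intermediate_size (the dense-MLP width) and moe_intermediate_size (the per-expert width); the expert weights must be sized from the latter, and the sharding divisor is the function's own tp_size parameter rather than a leaked args.tp_size. A minimal sketch of the lookup, assuming a reachable DeepSeek checkpoint (the model id is illustrative):

    from transformers import AutoConfig

    # Illustrative only: DeepSeek configs need trust_remote_code.
    config = AutoConfig.from_pretrained("deepseek-ai/DeepSeek-V2", trust_remote_code=True)
    E = config.n_routed_experts                       # routed experts
    topk = config.num_experts_per_tok                 # experts active per token
    intermediate_size = config.moe_intermediate_size  # per-expert FFN width
    tp_size = 8                                       # example TP degree
    # gate and up projections are fused, hence the factor of 2
    shard_intermediate_size = 2 * intermediate_size // tp_size
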
@@ -41,12 +42,27 @@ def get_model_config(model_name: str, tp_size: int):
         intermediate_size = config.intermediate_size
         shard_intermediate_size = 2 * intermediate_size // tp_size
 
+    vllm_version_num = (
+        vllm.__version_tuple__[0] * 100
+        + vllm.__version_tuple__[1] * 10
+        + vllm.__version_tuple__[2]
+    )
+    block_shape = None
+    if (
+        hasattr(config, "quantization_config")
+        and "weight_block_size" in config.quantization_config
+    ):
+        block_shape = config.quantization_config["weight_block_size"]
+        assert len(block_shape) == 2
+        assert vllm_version_num >= 66, "Block-wise quantized fp8 fused_moe is only supported for VLLM>=0.6.6.post1"
+
     shape_configs = {
         "num_experts": E,
         "topk": topk,
         "hidden_size": config.hidden_size,
         "shard_intermediate_size": shard_intermediate_size,
         "dtype": config.torch_dtype,
+        "block_shape": block_shape,
     }
     print(f"{shape_configs=}")
     return shape_configs
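
The digit-weighted sum above encodes the version tuple as an integer, so the gate vllm_version_num >= 66 admits vLLM 0.6.6 and later. A quick check of the arithmetic, assuming __version_tuple__ holds (major, minor, patch):

    def version_num(t):
        major, minor, patch = t[:3]
        return major * 100 + minor * 10 + patch

    assert version_num((0, 6, 6)) == 66  # first version with block-wise fp8 fused_moe
    assert version_num((0, 6, 5)) == 65  # rejected by the gate
    # Caveat: the encoding collides once a component reaches 10,
    # e.g. (0, 6, 10) and (0, 7, 0) both map to 70.
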
@@ -63,21 +79,39 @@ def fused_moe_vllm_api(
     w2_scale=None,
     a1_scale=None,
     a2_scale=None,
+    block_shape=None,
 ):
-    return fused_moe_vllm(
-        x,
-        w1,
-        w2,
-        input_gating,
-        topk,
-        renormalize=True,
-        inplace=True,
-        use_fp8_w8a8=use_fp8_w8a8,
-        w1_scale=w1_scale,
-        w2_scale=w2_scale,
-        a1_scale=a1_scale,
-        a2_scale=a2_scale,
-    )
+    if block_shape is not None:
+        return fused_moe_vllm(
+            x,
+            w1,
+            w2,
+            input_gating,
+            topk,
+            renormalize=True,
+            inplace=True,
+            use_fp8_w8a8=use_fp8_w8a8,
+            w1_scale=w1_scale,
+            w2_scale=w2_scale,
+            a1_scale=a1_scale,
+            a2_scale=a2_scale,
+            block_shape=block_shape,
+        )
+    else:
+        return fused_moe_vllm(
+            x,
+            w1,
+            w2,
+            input_gating,
+            topk,
+            renormalize=True,
+            inplace=True,
+            use_fp8_w8a8=use_fp8_w8a8,
+            w1_scale=w1_scale,
+            w2_scale=w2_scale,
+            a1_scale=a1_scale,
+            a2_scale=a2_scale,
+        )
 
 
 def fused_moe_sglang_api(
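
The two nearly identical calls exist only so that block_shape is never passed to vLLM builds whose fused_moe predates the keyword, which would raise a TypeError. A more compact equivalent, sketched under that same assumption:

    from vllm.model_executor.layers.fused_moe.fused_moe import fused_moe as fused_moe_vllm

    def fused_moe_vllm_api_compact(
        x, w1, w2, input_gating, topk, use_fp8_w8a8=False,
        w1_scale=None, w2_scale=None, a1_scale=None, a2_scale=None,
        block_shape=None,
    ):
        # Forward block_shape only when set, keeping older signatures working.
        extra = {} if block_shape is None else {"block_shape": block_shape}
        return fused_moe_vllm(
            x, w1, w2, input_gating, topk,
            renormalize=True, inplace=True, use_fp8_w8a8=use_fp8_w8a8,
            w1_scale=w1_scale, w2_scale=w2_scale,
            a1_scale=a1_scale, a2_scale=a2_scale,
            **extra,
        )
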
@@ -91,6 +125,7 @@ def fused_moe_sglang_api(
     w2_scale=None,
     a1_scale=None,
     a2_scale=None,
+    block_shape=None,
 ):
     return fused_moe_sglang(
         x,
@@ -105,6 +140,7 @@ def fused_moe_sglang_api(
         w2_scale=w2_scale,
         a1_scale=a1_scale,
         a2_scale=a2_scale,
+        block_shape=block_shape,
     )
 
 
@@ -141,8 +177,10 @@ def benchmark(batch_size, provider, model_config, use_fp8=False):
     shard_intermediate_size = model_config["shard_intermediate_size"]
     topk = model_config["topk"]
     dtype = model_config["dtype"]
+    block_shape = model_config.get("block_shape", None)
 
     x = torch.randn(num_tokens, hidden_size, dtype=dtype)
+    w1_scale = w2_scale = a1_scale = a2_scale = None
 
     if use_fp8:
         init_dtype = dtype
@@ -154,16 +192,29 @@ def benchmark(batch_size, provider, model_config, use_fp8=False):
         )
         w1 = w1.to(torch.float8_e4m3fn)
         w2 = w2.to(torch.float8_e4m3fn)
-        w1_scale = torch.randn(num_experts, dtype=torch.float32)
-        w2_scale = torch.randn(num_experts, dtype=torch.float32)
-        a1_scale = torch.randn(1, dtype=torch.float32)
-        a2_scale = torch.randn(1, dtype=torch.float32)
+
+        if block_shape is None:
+            w1_scale = torch.randn(num_experts, dtype=torch.float32)
+            w2_scale = torch.randn(num_experts, dtype=torch.float32)
+            a1_scale = torch.randn(1, dtype=torch.float32)
+            a2_scale = torch.randn(1, dtype=torch.float32)
+        else:
+            block_n, block_k = block_shape[0], block_shape[1]
+            n_tiles_w1 = (shard_intermediate_size + block_n - 1) // block_n
+            n_tiles_w2 = (hidden_size + block_n - 1) // block_n
+            k_tiles_w1 = (hidden_size + block_k - 1) // block_k
+            k_tiles_w2 = (shard_intermediate_size // 2 + block_k - 1) // block_k
+            w1_scale = torch.rand(
+                (num_experts, n_tiles_w1, k_tiles_w1), dtype=torch.float32
+            )
+            w2_scale = torch.rand(
+                (num_experts, n_tiles_w2, k_tiles_w2), dtype=torch.float32
+            )
     else:
         w1 = torch.randn(num_experts, shard_intermediate_size, hidden_size, dtype=dtype)
         w2 = torch.randn(
             num_experts, hidden_size, shard_intermediate_size // 2, dtype=dtype
         )
-        w1_scale = w2_scale = a1_scale = a2_scale = None
 
     input_gating = torch.randn(num_tokens, num_experts, dtype=torch.float32)
 
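
The ceil divisions allocate one fp32 scale per (block_n, block_k) tile of each expert weight: w1 is (num_experts, shard_intermediate_size, hidden_size) and w2 is (num_experts, hidden_size, shard_intermediate_size // 2). Worked through with DeepSeek-V3-like numbers for illustration (hidden_size 7168, moe_intermediate_size 2048, tp_size 8, weight_block_size [128, 128]):

    hidden_size = 7168
    shard_intermediate_size = 2 * 2048 // 8  # 512
    block_n = block_k = 128

    def ceil_div(a, b):
        return (a + b - 1) // b

    n_tiles_w1 = ceil_div(shard_intermediate_size, block_n)       # 4
    k_tiles_w1 = ceil_div(hidden_size, block_k)                   # 56
    n_tiles_w2 = ceil_div(hidden_size, block_n)                   # 56
    k_tiles_w2 = ceil_div(shard_intermediate_size // 2, block_k)  # 2
    # w1_scale: (num_experts, 4, 56); w2_scale: (num_experts, 56, 2)
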
@@ -185,6 +236,7 @@ def benchmark(batch_size, provider, model_config, use_fp8=False):
             w2_scale=w2_scale,
             a1_scale=a1_scale,
             a2_scale=a2_scale,
+            block_shape=block_shape,
         )
         torch.cuda.synchronize()
 
@@ -201,6 +253,7 @@ def benchmark(batch_size, provider, model_config, use_fp8=False):
                 w2_scale=w2_scale,
                 a1_scale=a1_scale,
                 a2_scale=a2_scale,
+                block_shape=block_shape,
             )[0],
         quantiles=quantiles,
     )