Integrate triton moe kernel (#7689)
Co-authored-by: luoyuan.luo <luoyuan.luo@antgroup.com>
This commit is contained in:
@@ -0,0 +1,271 @@
|
|||||||
|
# python3 benchmark/kernels/fused_moe_triton/sglang_fused_moe_triton.py --model /DeepSeek-V3/ --tp-size 8
|
||||||
|
import argparse
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import triton
|
||||||
|
from transformers import AutoConfig
|
||||||
|
|
||||||
|
from sglang.srt.distributed.parallel_state import (
|
||||||
|
destroy_distributed_environment,
|
||||||
|
destroy_model_parallel,
|
||||||
|
init_distributed_environment,
|
||||||
|
initialize_model_parallel,
|
||||||
|
)
|
||||||
|
from sglang.srt.layers.moe.fused_moe_triton.fused_moe import (
|
||||||
|
fused_moe as fused_moe_sglang,
|
||||||
|
)
|
||||||
|
from sglang.srt.layers.moe.fused_moe_triton.triton_kernels_moe import (
|
||||||
|
triton_kernel_moe_forward,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def get_model_config(model_name: str, tp_size: int):
|
||||||
|
"""Get model configuration parameters"""
|
||||||
|
config = AutoConfig.from_pretrained(model_name, trust_remote_code=True)
|
||||||
|
|
||||||
|
if config.architectures[0] == "Qwen2MoeForCausalLM":
|
||||||
|
E = config.num_experts
|
||||||
|
topk = config.num_experts_per_tok
|
||||||
|
intermediate_size = config.moe_intermediate_size
|
||||||
|
shard_intermediate_size = 2 * intermediate_size // tp_size
|
||||||
|
elif config.architectures[0] == "Qwen3MoeForCausalLM":
|
||||||
|
E = config.num_experts
|
||||||
|
topk = config.num_experts_per_tok
|
||||||
|
intermediate_size = config.moe_intermediate_size
|
||||||
|
shard_intermediate_size = 2 * intermediate_size // tp_size
|
||||||
|
elif config.architectures[0] in ["DeepseekV2ForCausalLM", "DeepseekV3ForCausalLM"]:
|
||||||
|
E = (
|
||||||
|
config.n_routed_experts + 1
|
||||||
|
if config.architectures[0] in ["DeepseekV3ForCausalLM"]
|
||||||
|
else config.n_routed_experts
|
||||||
|
)
|
||||||
|
topk = config.num_experts_per_tok
|
||||||
|
intermediate_size = config.moe_intermediate_size
|
||||||
|
shard_intermediate_size = 2 * intermediate_size // tp_size
|
||||||
|
else:
|
||||||
|
# Default: Mixtral
|
||||||
|
E = config.num_local_experts
|
||||||
|
topk = config.num_experts_per_tok
|
||||||
|
intermediate_size = config.intermediate_size
|
||||||
|
shard_intermediate_size = 2 * intermediate_size // tp_size
|
||||||
|
|
||||||
|
block_shape = None
|
||||||
|
if (
|
||||||
|
hasattr(config, "quantization_config")
|
||||||
|
and "weight_block_size" in config.quantization_config
|
||||||
|
):
|
||||||
|
block_shape = config.quantization_config["weight_block_size"]
|
||||||
|
assert len(block_shape) == 2
|
||||||
|
|
||||||
|
shape_configs = {
|
||||||
|
"num_experts": E,
|
||||||
|
"topk": topk,
|
||||||
|
"hidden_size": config.hidden_size,
|
||||||
|
"shard_intermediate_size": shard_intermediate_size,
|
||||||
|
"dtype": config.torch_dtype,
|
||||||
|
"block_shape": block_shape,
|
||||||
|
}
|
||||||
|
print(f"{shape_configs=}")
|
||||||
|
return shape_configs
|
||||||
|
|
||||||
|
|
||||||
|
def fused_moe_triton_api(
|
||||||
|
x,
|
||||||
|
w1,
|
||||||
|
w2,
|
||||||
|
input_gating,
|
||||||
|
topk,
|
||||||
|
):
|
||||||
|
return triton_kernel_moe_forward(
|
||||||
|
x,
|
||||||
|
w1,
|
||||||
|
w2,
|
||||||
|
input_gating,
|
||||||
|
topk,
|
||||||
|
renormalize=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def fused_moe_sglang_api(
|
||||||
|
x,
|
||||||
|
w1,
|
||||||
|
w2,
|
||||||
|
input_gating,
|
||||||
|
topk,
|
||||||
|
use_fp8_w8a8=False,
|
||||||
|
w1_scale=None,
|
||||||
|
w2_scale=None,
|
||||||
|
a1_scale=None,
|
||||||
|
a2_scale=None,
|
||||||
|
block_shape=None,
|
||||||
|
):
|
||||||
|
return fused_moe_sglang(
|
||||||
|
x,
|
||||||
|
w1,
|
||||||
|
w2,
|
||||||
|
input_gating,
|
||||||
|
topk,
|
||||||
|
renormalize=False,
|
||||||
|
inplace=True,
|
||||||
|
use_fp8_w8a8=use_fp8_w8a8,
|
||||||
|
w1_scale=w1_scale,
|
||||||
|
w2_scale=w2_scale,
|
||||||
|
a1_scale=a1_scale,
|
||||||
|
a2_scale=a2_scale,
|
||||||
|
block_shape=block_shape,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@triton.testing.perf_report(
|
||||||
|
triton.testing.Benchmark(
|
||||||
|
x_names=["batch_size"],
|
||||||
|
x_vals=list([128, 256, 512, 1024, 2048, 4096, 8192]),
|
||||||
|
line_arg="provider",
|
||||||
|
line_vals=[
|
||||||
|
"sglang_fused_moe_triton_v340",
|
||||||
|
"sglang_fused_moe_triton",
|
||||||
|
],
|
||||||
|
line_names=[
|
||||||
|
"sglang_fused_moe_triton_v340",
|
||||||
|
"sglang_fused_moe_triton",
|
||||||
|
],
|
||||||
|
styles=[
|
||||||
|
("blue", "-"),
|
||||||
|
("green", "-"),
|
||||||
|
],
|
||||||
|
ylabel="Time (ms)",
|
||||||
|
plot_name="fused-moe-performance",
|
||||||
|
args={},
|
||||||
|
)
|
||||||
|
)
|
||||||
|
def benchmark(
|
||||||
|
batch_size,
|
||||||
|
provider,
|
||||||
|
model_config,
|
||||||
|
use_fp8_w8a8=False,
|
||||||
|
use_cuda_graph: bool = False,
|
||||||
|
):
|
||||||
|
print(f"benchmark {provider} with batch_size={batch_size}")
|
||||||
|
torch.set_default_device("cuda")
|
||||||
|
torch.cuda.manual_seed_all(0)
|
||||||
|
|
||||||
|
num_tokens = batch_size
|
||||||
|
num_experts = model_config["num_experts"]
|
||||||
|
hidden_size = model_config["hidden_size"]
|
||||||
|
shard_intermediate_size = model_config["shard_intermediate_size"]
|
||||||
|
topk = model_config["topk"]
|
||||||
|
dtype = model_config["dtype"]
|
||||||
|
block_shape = model_config["block_shape"]
|
||||||
|
|
||||||
|
x = torch.randn(num_tokens, hidden_size, dtype=dtype)
|
||||||
|
|
||||||
|
w1 = torch.randn(num_experts, shard_intermediate_size, hidden_size, dtype=dtype)
|
||||||
|
w2 = torch.randn(
|
||||||
|
num_experts, hidden_size, shard_intermediate_size // 2, dtype=dtype
|
||||||
|
)
|
||||||
|
|
||||||
|
w1_tri = w1.clone()
|
||||||
|
w2_tri = w2.clone()
|
||||||
|
w1_tri = w1_tri.transpose(-2, -1).contiguous()
|
||||||
|
w2_tri = w2_tri.transpose(-2, -1).contiguous()
|
||||||
|
|
||||||
|
input_gating = torch.randn(num_tokens, num_experts, dtype=torch.float32)
|
||||||
|
|
||||||
|
if provider == "sglang_fused_moe_triton_v340":
|
||||||
|
api_func = fused_moe_triton_api
|
||||||
|
api_kwargs = {
|
||||||
|
"x": x,
|
||||||
|
"w1": w1_tri,
|
||||||
|
"w2": w2_tri,
|
||||||
|
"input_gating": input_gating,
|
||||||
|
"topk": topk,
|
||||||
|
}
|
||||||
|
else:
|
||||||
|
api_func = fused_moe_sglang_api
|
||||||
|
api_kwargs = {
|
||||||
|
"x": x,
|
||||||
|
"w1": w1,
|
||||||
|
"w2": w2,
|
||||||
|
"input_gating": input_gating,
|
||||||
|
"topk": topk,
|
||||||
|
"use_fp8_w8a8": use_fp8_w8a8,
|
||||||
|
"block_shape": block_shape,
|
||||||
|
}
|
||||||
|
|
||||||
|
# Warmup
|
||||||
|
for _ in range(10):
|
||||||
|
_ = api_func(**api_kwargs)
|
||||||
|
torch.cuda.synchronize()
|
||||||
|
|
||||||
|
if use_cuda_graph:
|
||||||
|
stream = torch.cuda.Stream()
|
||||||
|
graph = torch.cuda.CUDAGraph()
|
||||||
|
with torch.cuda.graph(graph, stream=stream):
|
||||||
|
api_func(**api_kwargs)
|
||||||
|
torch.cuda.synchronize()
|
||||||
|
|
||||||
|
bench_lambda = lambda: graph.replay()
|
||||||
|
else:
|
||||||
|
bench_lambda = lambda: api_func(**api_kwargs)
|
||||||
|
|
||||||
|
quantiles = [0.5, 0.2, 0.8]
|
||||||
|
ms, min_ms, max_ms = triton.testing.do_bench(bench_lambda, quantiles=quantiles)
|
||||||
|
return ms, min_ms, max_ms
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
parser = argparse.ArgumentParser()
|
||||||
|
parser.add_argument(
|
||||||
|
"--model", type=str, default="mistralai/Mixtral-8x7B-Instruct-v0.1"
|
||||||
|
)
|
||||||
|
parser.add_argument("--tp-size", type=int, default=2)
|
||||||
|
parser.add_argument("--use-fp8-w8a8", action="store_true")
|
||||||
|
parser.add_argument(
|
||||||
|
"--use-cuda-graph", action="store_true", help="Enable CUDA Graph capture/replay"
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--save-path",
|
||||||
|
type=str,
|
||||||
|
default="./configs/benchmark_ops/sglang_fused_moe/",
|
||||||
|
)
|
||||||
|
parser.add_argument("--trust-remote-code", action="store_true")
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
try:
|
||||||
|
if not torch.distributed.is_initialized():
|
||||||
|
torch.distributed.init_process_group(
|
||||||
|
backend="nccl" if torch.cuda.is_available() else "gloo",
|
||||||
|
init_method="tcp://127.0.0.1:23456",
|
||||||
|
world_size=1,
|
||||||
|
rank=0,
|
||||||
|
)
|
||||||
|
|
||||||
|
init_distributed_environment(
|
||||||
|
world_size=1,
|
||||||
|
rank=0,
|
||||||
|
distributed_init_method="tcp://127.0.0.1:23456",
|
||||||
|
local_rank=0,
|
||||||
|
backend="nccl" if torch.cuda.is_available() else "gloo",
|
||||||
|
)
|
||||||
|
|
||||||
|
initialize_model_parallel(
|
||||||
|
tensor_model_parallel_size=1,
|
||||||
|
pipeline_model_parallel_size=1,
|
||||||
|
)
|
||||||
|
|
||||||
|
model_config = get_model_config(args.model, args.tp_size)
|
||||||
|
benchmark.run(
|
||||||
|
show_plots=True,
|
||||||
|
print_data=True,
|
||||||
|
save_path=args.save_path,
|
||||||
|
model_config=model_config,
|
||||||
|
use_fp8_w8a8=args.use_fp8_w8a8,
|
||||||
|
use_cuda_graph=args.use_cuda_graph,
|
||||||
|
)
|
||||||
|
finally:
|
||||||
|
destroy_model_parallel()
|
||||||
|
destroy_distributed_environment()
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
@@ -1737,6 +1737,7 @@ def fused_moe(
|
|||||||
renormalize: bool,
|
renormalize: bool,
|
||||||
inplace: bool = False,
|
inplace: bool = False,
|
||||||
activation: str = "silu",
|
activation: str = "silu",
|
||||||
|
apply_router_weight_on_input: bool = False,
|
||||||
use_grouped_topk: bool = False,
|
use_grouped_topk: bool = False,
|
||||||
num_expert_group: Optional[int] = None,
|
num_expert_group: Optional[int] = None,
|
||||||
num_fused_shared_experts: int = 0,
|
num_fused_shared_experts: int = 0,
|
||||||
@@ -1822,6 +1823,7 @@ def fused_moe(
|
|||||||
topk_ids,
|
topk_ids,
|
||||||
inplace=inplace,
|
inplace=inplace,
|
||||||
activation=activation,
|
activation=activation,
|
||||||
|
apply_router_weight_on_input=apply_router_weight_on_input,
|
||||||
use_fp8_w8a8=use_fp8_w8a8,
|
use_fp8_w8a8=use_fp8_w8a8,
|
||||||
use_int8_w8a8=use_int8_w8a8,
|
use_int8_w8a8=use_int8_w8a8,
|
||||||
use_int8_w8a16=use_int8_w8a16,
|
use_int8_w8a16=use_int8_w8a16,
|
||||||
|
|||||||
@@ -1,5 +1,6 @@
|
|||||||
# Adapted from https://github.com/vllm-project/vllm/blob/a6221a144af772fd1a68fe7e627935dc53e81738/vllm/model_executor/layers/fused_moe/layer.py
|
# Adapted from https://github.com/vllm-project/vllm/blob/a6221a144af772fd1a68fe7e627935dc53e81738/vllm/model_executor/layers/fused_moe/layer.py
|
||||||
|
|
||||||
|
import importlib
|
||||||
from abc import abstractmethod
|
from abc import abstractmethod
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
from typing import Callable, List, Optional, Tuple
|
from typing import Callable, List, Optional, Tuple
|
||||||
@@ -19,6 +20,7 @@ from sglang.srt.layers.quantization.base_config import (
|
|||||||
QuantizationConfig,
|
QuantizationConfig,
|
||||||
QuantizeMethodBase,
|
QuantizeMethodBase,
|
||||||
)
|
)
|
||||||
|
from sglang.srt.managers.schedule_batch import global_server_args_dict
|
||||||
from sglang.srt.model_loader.weight_utils import narrow_padded_param_and_loaded_weight
|
from sglang.srt.model_loader.weight_utils import narrow_padded_param_and_loaded_weight
|
||||||
from sglang.srt.utils import (
|
from sglang.srt.utils import (
|
||||||
cpu_has_amx_support,
|
cpu_has_amx_support,
|
||||||
@@ -29,8 +31,15 @@ from sglang.srt.utils import (
|
|||||||
use_intel_amx_backend,
|
use_intel_amx_backend,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
has_triton_kernels = importlib.util.find_spec("triton_kernels") is not None
|
||||||
|
|
||||||
if torch.cuda.is_available():
|
if torch.cuda.is_available():
|
||||||
from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_experts
|
from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_experts
|
||||||
|
|
||||||
|
if has_triton_kernels:
|
||||||
|
from sglang.srt.layers.moe.fused_moe_triton.triton_kernels_moe import (
|
||||||
|
triton_kernel_moe_forward,
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
fused_experts = None # type: ignore
|
fused_experts = None # type: ignore
|
||||||
|
|
||||||
@@ -87,6 +96,10 @@ class FusedMoEMethodBase(QuantizeMethodBase):
|
|||||||
class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
|
class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
|
||||||
"""MoE method without quantization."""
|
"""MoE method without quantization."""
|
||||||
|
|
||||||
|
def __init__(self, use_triton_kernels: bool = False):
|
||||||
|
super().__init__()
|
||||||
|
self.use_triton_kernels = use_triton_kernels
|
||||||
|
|
||||||
def create_weights(
|
def create_weights(
|
||||||
self,
|
self,
|
||||||
layer: torch.nn.Module,
|
layer: torch.nn.Module,
|
||||||
@@ -97,20 +110,25 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
|
|||||||
**extra_weight_attrs,
|
**extra_weight_attrs,
|
||||||
):
|
):
|
||||||
# Fused gate_up_proj (column parallel)
|
# Fused gate_up_proj (column parallel)
|
||||||
|
w13_weight_n, w13_weight_k = 2 * intermediate_size, hidden_size
|
||||||
|
if self.use_triton_kernels:
|
||||||
|
w13_weight_n, w13_weight_k = w13_weight_k, w13_weight_n
|
||||||
w13_weight = torch.nn.Parameter(
|
w13_weight = torch.nn.Parameter(
|
||||||
torch.empty(
|
torch.empty(num_experts, w13_weight_n, w13_weight_k, dtype=params_dtype),
|
||||||
num_experts, 2 * intermediate_size, hidden_size, dtype=params_dtype
|
|
||||||
),
|
|
||||||
requires_grad=False,
|
requires_grad=False,
|
||||||
)
|
)
|
||||||
layer.register_parameter("w13_weight", w13_weight)
|
layer.register_parameter("w13_weight", w13_weight)
|
||||||
set_weight_attrs(w13_weight, extra_weight_attrs)
|
set_weight_attrs(w13_weight, extra_weight_attrs)
|
||||||
|
|
||||||
# down_proj (row parallel)
|
# down_proj (row parallel)
|
||||||
|
w2_weight_n, w2_weight_k = (
|
||||||
|
hidden_size,
|
||||||
|
intermediate_size,
|
||||||
|
)
|
||||||
|
if self.use_triton_kernels:
|
||||||
|
w2_weight_n, w2_weight_k = w2_weight_k, w2_weight_n
|
||||||
w2_weight = torch.nn.Parameter(
|
w2_weight = torch.nn.Parameter(
|
||||||
torch.empty(
|
torch.empty(num_experts, w2_weight_n, w2_weight_k, dtype=params_dtype),
|
||||||
num_experts, hidden_size, intermediate_size, dtype=params_dtype
|
|
||||||
),
|
|
||||||
requires_grad=False,
|
requires_grad=False,
|
||||||
)
|
)
|
||||||
layer.register_parameter("w2_weight", w2_weight)
|
layer.register_parameter("w2_weight", w2_weight)
|
||||||
@@ -192,59 +210,72 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
|
|||||||
no_combine: bool = False,
|
no_combine: bool = False,
|
||||||
routed_scaling_factor: Optional[float] = None,
|
routed_scaling_factor: Optional[float] = None,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
topk_weights, topk_ids = select_experts(
|
|
||||||
hidden_states=x,
|
|
||||||
router_logits=router_logits,
|
|
||||||
use_grouped_topk=use_grouped_topk,
|
|
||||||
top_k=top_k,
|
|
||||||
renormalize=renormalize,
|
|
||||||
topk_group=topk_group,
|
|
||||||
num_expert_group=num_expert_group,
|
|
||||||
num_fused_shared_experts=num_fused_shared_experts,
|
|
||||||
custom_routing_function=custom_routing_function,
|
|
||||||
correction_bias=correction_bias,
|
|
||||||
routed_scaling_factor=routed_scaling_factor,
|
|
||||||
)
|
|
||||||
|
|
||||||
if _use_aiter:
|
if self.use_triton_kernels:
|
||||||
assert not no_combine, "unsupported"
|
return triton_kernel_moe_forward(
|
||||||
if apply_router_weight_on_input:
|
|
||||||
assert (
|
|
||||||
topk_weights.dim() == 2
|
|
||||||
), "`topk_weights` should be in shape (num_tokens, topk)"
|
|
||||||
_, topk = topk_weights.shape
|
|
||||||
assert (
|
|
||||||
topk == 1
|
|
||||||
), "Only support topk=1 when `apply_router_weight_on_input` is True"
|
|
||||||
x = x * topk_weights.to(x.dtype)
|
|
||||||
topk_weights = torch.ones_like(
|
|
||||||
topk_weights, dtype=torch.float32
|
|
||||||
) # topk_weights must be FP32 (float32)
|
|
||||||
|
|
||||||
return fused_moe(
|
|
||||||
x,
|
|
||||||
layer.w13_weight,
|
|
||||||
layer.w2_weight,
|
|
||||||
topk_weights,
|
|
||||||
topk_ids,
|
|
||||||
activation=(
|
|
||||||
ActivationType.Silu if activation == "silu" else ActivationType.Gelu
|
|
||||||
),
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
return fused_experts(
|
|
||||||
hidden_states=x,
|
hidden_states=x,
|
||||||
w1=layer.w13_weight,
|
w1=layer.w13_weight,
|
||||||
w2=layer.w2_weight,
|
w2=layer.w2_weight,
|
||||||
topk_weights=topk_weights,
|
gating_output=router_logits,
|
||||||
topk_ids=topk_ids,
|
topk=top_k,
|
||||||
inplace=inplace and not no_combine,
|
renormalize=renormalize,
|
||||||
activation=activation,
|
)
|
||||||
apply_router_weight_on_input=apply_router_weight_on_input,
|
else:
|
||||||
no_combine=no_combine,
|
topk_weights, topk_ids = select_experts(
|
||||||
|
hidden_states=x,
|
||||||
|
router_logits=router_logits,
|
||||||
|
use_grouped_topk=use_grouped_topk,
|
||||||
|
top_k=top_k,
|
||||||
|
renormalize=renormalize,
|
||||||
|
topk_group=topk_group,
|
||||||
|
num_expert_group=num_expert_group,
|
||||||
|
num_fused_shared_experts=num_fused_shared_experts,
|
||||||
|
custom_routing_function=custom_routing_function,
|
||||||
|
correction_bias=correction_bias,
|
||||||
routed_scaling_factor=routed_scaling_factor,
|
routed_scaling_factor=routed_scaling_factor,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if _use_aiter:
|
||||||
|
assert not no_combine, "unsupported"
|
||||||
|
if apply_router_weight_on_input:
|
||||||
|
assert (
|
||||||
|
topk_weights.dim() == 2
|
||||||
|
), "`topk_weights` should be in shape (num_tokens, topk)"
|
||||||
|
_, topk = topk_weights.shape
|
||||||
|
assert (
|
||||||
|
topk == 1
|
||||||
|
), "Only support topk=1 when `apply_router_weight_on_input` is True"
|
||||||
|
x = x * topk_weights.to(x.dtype)
|
||||||
|
topk_weights = torch.ones_like(
|
||||||
|
topk_weights, dtype=torch.float32
|
||||||
|
) # topk_weights must be FP32 (float32)
|
||||||
|
|
||||||
|
return fused_moe(
|
||||||
|
x,
|
||||||
|
layer.w13_weight,
|
||||||
|
layer.w2_weight,
|
||||||
|
topk_weights,
|
||||||
|
topk_ids,
|
||||||
|
activation=(
|
||||||
|
ActivationType.Silu
|
||||||
|
if activation == "silu"
|
||||||
|
else ActivationType.Gelu
|
||||||
|
),
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
return fused_experts(
|
||||||
|
hidden_states=x,
|
||||||
|
w1=layer.w13_weight,
|
||||||
|
w2=layer.w2_weight,
|
||||||
|
topk_weights=topk_weights,
|
||||||
|
topk_ids=topk_ids,
|
||||||
|
inplace=inplace and not no_combine,
|
||||||
|
activation=activation,
|
||||||
|
apply_router_weight_on_input=apply_router_weight_on_input,
|
||||||
|
no_combine=no_combine,
|
||||||
|
routed_scaling_factor=routed_scaling_factor,
|
||||||
|
)
|
||||||
|
|
||||||
def forward_cpu(
|
def forward_cpu(
|
||||||
self,
|
self,
|
||||||
layer: torch.nn.Module,
|
layer: torch.nn.Module,
|
||||||
@@ -475,9 +506,13 @@ class FusedMoE(torch.nn.Module):
|
|||||||
self.inplace = inplace
|
self.inplace = inplace
|
||||||
self.no_combine = no_combine
|
self.no_combine = no_combine
|
||||||
|
|
||||||
|
self.use_triton_kernels = (
|
||||||
|
not _is_cpu and global_server_args_dict["enable_triton_kernel_moe"]
|
||||||
|
)
|
||||||
|
|
||||||
if quant_config is None:
|
if quant_config is None:
|
||||||
self.quant_method: Optional[QuantizeMethodBase] = (
|
self.quant_method: Optional[QuantizeMethodBase] = UnquantizedFusedMoEMethod(
|
||||||
UnquantizedFusedMoEMethod()
|
self.use_triton_kernels
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
self.quant_method = quant_config.get_quant_method(self, prefix)
|
self.quant_method = quant_config.get_quant_method(self, prefix)
|
||||||
@@ -597,6 +632,8 @@ class FusedMoE(torch.nn.Module):
|
|||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
if not self.use_presharded_weights:
|
if not self.use_presharded_weights:
|
||||||
|
if self.use_triton_kernels:
|
||||||
|
loaded_weight = loaded_weight.transpose(-2, -1)
|
||||||
loaded_weight = loaded_weight.narrow(
|
loaded_weight = loaded_weight.narrow(
|
||||||
shard_dim, shard_size * tp_rank, shard_size
|
shard_dim, shard_size * tp_rank, shard_size
|
||||||
)
|
)
|
||||||
@@ -630,6 +667,8 @@ class FusedMoE(torch.nn.Module):
|
|||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
if not self.use_presharded_weights:
|
if not self.use_presharded_weights:
|
||||||
|
if self.use_triton_kernels:
|
||||||
|
loaded_weight = loaded_weight.transpose(-2, -1)
|
||||||
loaded_weight = loaded_weight.narrow(
|
loaded_weight = loaded_weight.narrow(
|
||||||
shard_dim, shard_size * tp_rank, shard_size
|
shard_dim, shard_size * tp_rank, shard_size
|
||||||
)
|
)
|
||||||
@@ -716,6 +755,8 @@ class FusedMoE(torch.nn.Module):
|
|||||||
# should be whatever dimension intermediate_size is
|
# should be whatever dimension intermediate_size is
|
||||||
is_transposed = getattr(param, "is_transposed", False)
|
is_transposed = getattr(param, "is_transposed", False)
|
||||||
shard_dim = SHARD_ID_TO_SHARDED_DIM[shard_id]
|
shard_dim = SHARD_ID_TO_SHARDED_DIM[shard_id]
|
||||||
|
if self.use_triton_kernels:
|
||||||
|
is_transposed = True
|
||||||
if is_transposed:
|
if is_transposed:
|
||||||
shard_dim = int(not shard_dim)
|
shard_dim = int(not shard_dim)
|
||||||
|
|
||||||
|
|||||||
@@ -0,0 +1,176 @@
|
|||||||
|
# Adapted from https://github.com/vllm-project/vllm/pull/18595/files#diff-f426a6de78c82ffec568eff6811bfbf0043dab5f87f1a8c0cffdbdcb8a81e035
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from sgl_kernel import gelu_and_mul, silu_and_mul
|
||||||
|
from triton_kernels.matmul_ogs import matmul_ogs
|
||||||
|
from triton_kernels.routing import GatherIndx, RoutingData, ScatterIndx, routing
|
||||||
|
|
||||||
|
from sglang.srt.utils import direct_register_custom_op
|
||||||
|
|
||||||
|
|
||||||
|
def triton_kernel_moe_forward(
|
||||||
|
hidden_states: torch.Tensor,
|
||||||
|
w1: torch.Tensor,
|
||||||
|
w2: torch.Tensor,
|
||||||
|
gating_output: torch.Tensor,
|
||||||
|
topk: int,
|
||||||
|
renormalize: bool,
|
||||||
|
inplace: bool = False,
|
||||||
|
activation: str = "silu",
|
||||||
|
apply_router_weight_on_input: bool = False,
|
||||||
|
use_fp8_w8a8: bool = False,
|
||||||
|
per_channel_quant: bool = False,
|
||||||
|
global_num_experts: int = -1,
|
||||||
|
expert_map: Optional[torch.Tensor] = None,
|
||||||
|
w1_scale: Optional[torch.Tensor] = None,
|
||||||
|
w2_scale: Optional[torch.Tensor] = None,
|
||||||
|
a1_scale: Optional[torch.Tensor] = None,
|
||||||
|
a2_scale: Optional[torch.Tensor] = None,
|
||||||
|
block_shape: Optional[list[int]] = None,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
|
||||||
|
if not renormalize:
|
||||||
|
gating_output = torch.softmax(gating_output, dim=-1)
|
||||||
|
routing_data, gather_idx, scatter_idx = routing(gating_output, topk, renormalize)
|
||||||
|
|
||||||
|
return triton_kernel_fused_experts(
|
||||||
|
hidden_states,
|
||||||
|
w1,
|
||||||
|
w2,
|
||||||
|
routing_data,
|
||||||
|
gather_idx,
|
||||||
|
scatter_idx,
|
||||||
|
inplace=inplace,
|
||||||
|
activation=activation,
|
||||||
|
apply_router_weight_on_input=apply_router_weight_on_input,
|
||||||
|
use_fp8_w8a8=use_fp8_w8a8,
|
||||||
|
per_channel_quant=per_channel_quant,
|
||||||
|
global_num_experts=global_num_experts,
|
||||||
|
expert_map=expert_map,
|
||||||
|
w1_scale=w1_scale,
|
||||||
|
w2_scale=w2_scale,
|
||||||
|
a1_scale=a1_scale,
|
||||||
|
a2_scale=a2_scale,
|
||||||
|
block_shape=block_shape,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# This is a triton implementation of the fused_experts function
|
||||||
|
def triton_kernel_fused_experts(
|
||||||
|
hidden_states: torch.Tensor,
|
||||||
|
w1: torch.Tensor,
|
||||||
|
w2: torch.Tensor,
|
||||||
|
routing_data: RoutingData,
|
||||||
|
gather_indx: GatherIndx,
|
||||||
|
scatter_indx: ScatterIndx,
|
||||||
|
inplace: bool = False,
|
||||||
|
activation: str = "silu",
|
||||||
|
apply_router_weight_on_input: bool = False,
|
||||||
|
use_fp8_w8a8: bool = False,
|
||||||
|
per_channel_quant: bool = False,
|
||||||
|
global_num_experts: int = -1,
|
||||||
|
expert_map: Optional[torch.Tensor] = None,
|
||||||
|
w1_scale: Optional[torch.Tensor] = None,
|
||||||
|
w2_scale: Optional[torch.Tensor] = None,
|
||||||
|
a1_scale: Optional[torch.Tensor] = None,
|
||||||
|
a2_scale: Optional[torch.Tensor] = None,
|
||||||
|
block_shape: Optional[list[int]] = None,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
|
||||||
|
assert use_fp8_w8a8 == False, "use_fp8_w8a8 is not supported"
|
||||||
|
assert per_channel_quant == False, "per_channel_quant is not supported"
|
||||||
|
assert expert_map == None, "expert_map is not supported"
|
||||||
|
assert w1_scale == None, "w1_scale is not supported"
|
||||||
|
assert w2_scale == None, "w2_scale is not supported"
|
||||||
|
assert a1_scale == None, "a1_scale is not supported"
|
||||||
|
assert a2_scale == None, "a2_scale is not supported"
|
||||||
|
assert block_shape == None, "block_shape is not supported"
|
||||||
|
|
||||||
|
# type check
|
||||||
|
assert hidden_states.dtype == torch.bfloat16, "hidden_states must be bfloat16"
|
||||||
|
assert w1.dtype == torch.bfloat16, "w1 must be bfloat16"
|
||||||
|
assert w2.dtype == torch.bfloat16, "w2 must be bfloat16"
|
||||||
|
|
||||||
|
# Shape check
|
||||||
|
assert hidden_states.ndim == 2, "hidden_states must be 2D"
|
||||||
|
assert (
|
||||||
|
hidden_states.shape[-1] == w1.shape[-2]
|
||||||
|
), f"hidden_states shape[-1] {hidden_states.shape} must be equal to w1 shape[-2] {w1.shape}"
|
||||||
|
assert (
|
||||||
|
w2.shape[-1] == w1.shape[1]
|
||||||
|
), f"w2 shape[-1] {w2.shape[-1]} must be equal to w1 shape[1] {w1.shape[1]}"
|
||||||
|
|
||||||
|
# feature check
|
||||||
|
assert inplace == False, "Inplace is not supported in new triton MoE kernel"
|
||||||
|
|
||||||
|
M, K = hidden_states.shape
|
||||||
|
E, _, N = w1.shape
|
||||||
|
n_expts_act = routing_data.n_expts_act
|
||||||
|
dtype = hidden_states.dtype
|
||||||
|
|
||||||
|
if global_num_experts == -1:
|
||||||
|
global_num_experts = E
|
||||||
|
|
||||||
|
# consistent with default implementation
|
||||||
|
intermediate_cache2 = torch.empty(
|
||||||
|
(M * n_expts_act, N // 2), device="cuda", dtype=dtype
|
||||||
|
)
|
||||||
|
|
||||||
|
intermediate_cache1 = matmul_ogs(
|
||||||
|
hidden_states,
|
||||||
|
w1,
|
||||||
|
None,
|
||||||
|
routing_data,
|
||||||
|
gather_indx=gather_indx,
|
||||||
|
gammas=routing_data.gate_scal if apply_router_weight_on_input else None,
|
||||||
|
)
|
||||||
|
|
||||||
|
if activation == "silu":
|
||||||
|
silu_and_mul(intermediate_cache1.view(-1, N), intermediate_cache2)
|
||||||
|
elif activation == "gelu":
|
||||||
|
gelu_and_mul(intermediate_cache1.view(-1, N), intermediate_cache2)
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Unsupported FusedMoe activation: {activation}")
|
||||||
|
|
||||||
|
intermediate_cache3 = matmul_ogs(
|
||||||
|
intermediate_cache2,
|
||||||
|
w2,
|
||||||
|
None,
|
||||||
|
routing_data,
|
||||||
|
scatter_indx=scatter_indx,
|
||||||
|
gammas=None if apply_router_weight_on_input else routing_data.gate_scal,
|
||||||
|
)
|
||||||
|
|
||||||
|
return intermediate_cache3
|
||||||
|
|
||||||
|
|
||||||
|
def triton_kernel_moe_forward_fake(
|
||||||
|
hidden_states: torch.Tensor,
|
||||||
|
w1: torch.Tensor,
|
||||||
|
w2: torch.Tensor,
|
||||||
|
gating_output: torch.Tensor,
|
||||||
|
topk: int,
|
||||||
|
renormalize: bool,
|
||||||
|
inplace: bool = False,
|
||||||
|
activation: str = "silu",
|
||||||
|
apply_router_weight_on_input: bool = False,
|
||||||
|
use_fp8_w8a8: bool = False,
|
||||||
|
per_channel_quant: bool = False,
|
||||||
|
global_num_experts: int = -1,
|
||||||
|
expert_map: Optional[torch.Tensor] = None,
|
||||||
|
w1_scale: Optional[torch.Tensor] = None,
|
||||||
|
w2_scale: Optional[torch.Tensor] = None,
|
||||||
|
a1_scale: Optional[torch.Tensor] = None,
|
||||||
|
a2_scale: Optional[torch.Tensor] = None,
|
||||||
|
block_shape: Optional[list[int]] = None,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
return torch.empty_like(hidden_states)
|
||||||
|
|
||||||
|
|
||||||
|
direct_register_custom_op(
|
||||||
|
op_name="forward_cuda_triton",
|
||||||
|
op_func=triton_kernel_moe_forward,
|
||||||
|
mutates_args=[],
|
||||||
|
fake_impl=triton_kernel_moe_forward_fake,
|
||||||
|
)
|
||||||
@@ -101,6 +101,7 @@ GLOBAL_SERVER_ARGS_KEYS = [
|
|||||||
"triton_attention_reduce_in_fp32",
|
"triton_attention_reduce_in_fp32",
|
||||||
"num_reserved_decode_tokens",
|
"num_reserved_decode_tokens",
|
||||||
"weight_loader_disable_mmap",
|
"weight_loader_disable_mmap",
|
||||||
|
"enable_triton_kernel_moe",
|
||||||
]
|
]
|
||||||
|
|
||||||
# Put some global args for easy access
|
# Put some global args for easy access
|
||||||
|
|||||||
@@ -222,6 +222,7 @@ class ServerArgs:
|
|||||||
disable_chunked_prefix_cache: bool = False
|
disable_chunked_prefix_cache: bool = False
|
||||||
disable_fast_image_processor: bool = False
|
disable_fast_image_processor: bool = False
|
||||||
enable_return_hidden_states: bool = False
|
enable_return_hidden_states: bool = False
|
||||||
|
enable_triton_kernel_moe: bool = False
|
||||||
warmups: Optional[str] = None
|
warmups: Optional[str] = None
|
||||||
|
|
||||||
# Debug tensor dumps
|
# Debug tensor dumps
|
||||||
@@ -1554,6 +1555,11 @@ class ServerArgs:
|
|||||||
action="store_true",
|
action="store_true",
|
||||||
help="Enable returning hidden states with responses.",
|
help="Enable returning hidden states with responses.",
|
||||||
)
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--enable-triton-kernel-moe",
|
||||||
|
action="store_true",
|
||||||
|
help="Use triton moe grouped gemm kernel.",
|
||||||
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--warmups",
|
"--warmups",
|
||||||
type=str,
|
type=str,
|
||||||
|
|||||||
146
test/srt/test_triton_fused_moe.py
Normal file
146
test/srt/test_triton_fused_moe.py
Normal file
@@ -0,0 +1,146 @@
|
|||||||
|
import unittest
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.nn.functional as F
|
||||||
|
from tqdm import tqdm
|
||||||
|
|
||||||
|
from sglang.srt.layers.activation import SiluAndMul
|
||||||
|
from sglang.srt.layers.moe.fused_moe_triton.triton_kernels_moe import (
|
||||||
|
triton_kernel_moe_forward,
|
||||||
|
)
|
||||||
|
from sglang.test.test_utils import CustomTestCase
|
||||||
|
|
||||||
|
|
||||||
|
class TestFusedMOE(CustomTestCase):
|
||||||
|
NUM_EXPERTS = [8, 64]
|
||||||
|
TOP_KS = [2, 4]
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def create_random_cuda_tensor(shape, dtype, mean=0, std=0.01):
|
||||||
|
"""Create a random CUDA tensor
|
||||||
|
|
||||||
|
Args:
|
||||||
|
shape: Tensor shape
|
||||||
|
dtype: Data type
|
||||||
|
mean: Mean value
|
||||||
|
std: Standard deviation
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
torch.Tensor: Randomly initialized CUDA tensor
|
||||||
|
"""
|
||||||
|
return torch.empty(shape, dtype=dtype, device="cuda").normal_(mean, std)
|
||||||
|
|
||||||
|
def get_tolerance(self, dtype):
|
||||||
|
"""Get tolerance values for different data types
|
||||||
|
|
||||||
|
Args:
|
||||||
|
dtype: Data type
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
tuple: (relative tolerance, absolute tolerance)
|
||||||
|
"""
|
||||||
|
if dtype == torch.float32:
|
||||||
|
return 1e-5, 1e-5
|
||||||
|
elif dtype in [torch.float16, torch.bfloat16]:
|
||||||
|
return 1e-5, 1e-5
|
||||||
|
else:
|
||||||
|
return 1e-2, 1e-2 # Default values for other types
|
||||||
|
|
||||||
|
def torch_naive_moe(
|
||||||
|
self,
|
||||||
|
a,
|
||||||
|
w1,
|
||||||
|
w2,
|
||||||
|
score,
|
||||||
|
topk,
|
||||||
|
):
|
||||||
|
B, D = a.shape
|
||||||
|
a = a.view(B, -1, D).repeat(1, topk, 1).reshape(-1, D)
|
||||||
|
out = torch.zeros(B * topk, w2.shape[1], dtype=a.dtype, device=a.device)
|
||||||
|
score = torch.softmax(score, dim=-1, dtype=torch.float32)
|
||||||
|
topk_weight, topk_ids = torch.topk(score, topk)
|
||||||
|
topk_weight = topk_weight.view(-1)
|
||||||
|
topk_ids = topk_ids.view(-1)
|
||||||
|
|
||||||
|
if w1.dtype == torch.float8_e4m3fn:
|
||||||
|
w1_compute = w1.to(a.dtype)
|
||||||
|
w2_compute = w2.to(a.dtype)
|
||||||
|
else:
|
||||||
|
w1_compute = w1
|
||||||
|
w2_compute = w2
|
||||||
|
|
||||||
|
for i in range(w1_compute.shape[0]):
|
||||||
|
mask = topk_ids == i
|
||||||
|
if mask.sum():
|
||||||
|
out[mask] = SiluAndMul()(
|
||||||
|
a[mask] @ w1_compute[i].transpose(0, 1)
|
||||||
|
) @ w2_compute[i].transpose(0, 1)
|
||||||
|
|
||||||
|
return (
|
||||||
|
out.view(B, -1, w2.shape[1]) * topk_weight.view(B, -1, 1).to(out.dtype)
|
||||||
|
).sum(dim=1)
|
||||||
|
|
||||||
|
def _test_case(self, m, n, k, e, topk, dtype):
|
||||||
|
rtol, atol = self.get_tolerance(dtype)
|
||||||
|
|
||||||
|
a = self.create_random_cuda_tensor((m, k), dtype)
|
||||||
|
w1 = self.create_random_cuda_tensor((e, 2 * n, k), dtype)
|
||||||
|
w2 = self.create_random_cuda_tensor((e, k, n), dtype)
|
||||||
|
w1_tri = w1.clone()
|
||||||
|
w2_tri = w2.clone()
|
||||||
|
w1_tri = w1_tri.transpose(-2, -1).contiguous()
|
||||||
|
w2_tri = w2_tri.transpose(-2, -1).contiguous()
|
||||||
|
score = self.create_random_cuda_tensor((m, e), dtype)
|
||||||
|
|
||||||
|
triton_output = triton_kernel_moe_forward(
|
||||||
|
a, w1_tri, w2_tri, score, topk, renormalize=False
|
||||||
|
)
|
||||||
|
torch_output = self.torch_naive_moe(a, w1, w2, score, topk)
|
||||||
|
torch.testing.assert_close(triton_output, torch_output, rtol=rtol, atol=atol)
|
||||||
|
|
||||||
|
def test_various_configurations(self):
|
||||||
|
m_values = [1, 32, 64, 256]
|
||||||
|
n_values = [128, 1024]
|
||||||
|
k_values = [128, 512, 1024]
|
||||||
|
dtypes = [torch.bfloat16]
|
||||||
|
|
||||||
|
# Calculate total number of tests
|
||||||
|
total_tests = (
|
||||||
|
len(m_values)
|
||||||
|
* len(n_values)
|
||||||
|
* len(k_values)
|
||||||
|
* len(self.NUM_EXPERTS)
|
||||||
|
* len(self.TOP_KS)
|
||||||
|
* len(dtypes)
|
||||||
|
)
|
||||||
|
|
||||||
|
# Create progress bar
|
||||||
|
with tqdm(total=total_tests, desc="Running MoE tests") as pbar:
|
||||||
|
for m in m_values:
|
||||||
|
for n in n_values:
|
||||||
|
for k in k_values:
|
||||||
|
for e in self.NUM_EXPERTS:
|
||||||
|
for topk in self.TOP_KS:
|
||||||
|
for dtype in dtypes:
|
||||||
|
with self.subTest(
|
||||||
|
m=m,
|
||||||
|
n=n,
|
||||||
|
k=k,
|
||||||
|
e=e,
|
||||||
|
topk=topk,
|
||||||
|
dtype=dtype,
|
||||||
|
):
|
||||||
|
self._test_case(
|
||||||
|
m,
|
||||||
|
n,
|
||||||
|
k,
|
||||||
|
e,
|
||||||
|
topk,
|
||||||
|
dtype,
|
||||||
|
)
|
||||||
|
torch.cuda.empty_cache()
|
||||||
|
pbar.update(1)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
unittest.main()
|
||||||
Reference in New Issue
Block a user