FlashInfer NVFP4 MoE with EP & 2-stream shared expert (#7327)
Co-authored-by: JieXin Liang <Alcanderian@users.noreply.github.com>
Co-authored-by: alcanderian <alcanderian@gmail.com>
```diff
@@ -1295,6 +1295,9 @@ class DeepEPMoE(EPMoE):
 def get_moe_impl_class():
     if global_server_args_dict["enable_deepep_moe"]:
         return DeepEPMoE
+    if global_server_args_dict["enable_flashinfer_moe"]:
+        # Must come before EPMoE because FusedMoE also supports enable_ep_moe
+        return FusedMoE
     if global_server_args_dict["enable_ep_moe"]:
         return EPMoE
     return FusedMoE
```
```diff
@@ -314,6 +314,8 @@ class FusedMoE(torch.nn.Module):
         inplace: bool = True,
         no_combine: bool = False,
         routed_scaling_factor: Optional[float] = None,
+        enable_flashinfer_moe: Optional[bool] = False,
+        enable_ep_moe: Optional[bool] = False,
     ):
         super().__init__()
 
```
```diff
@@ -324,9 +326,34 @@ class FusedMoE(torch.nn.Module):
         self.tp_size = (
             tp_size if tp_size is not None else get_tensor_model_parallel_world_size()
         )
+        self.tp_rank = get_tensor_model_parallel_rank()
+        self.num_experts = num_experts
+        self.expert_map = None
+        self.enable_flashinfer_moe = enable_flashinfer_moe
+        if enable_ep_moe:
+            assert (
+                self.enable_flashinfer_moe
+            ), "FusedMoE only supports EP with --enable-flashinfer-moe"
+            self.ep_size = self.tp_size
+            self.ep_rank = self.tp_rank
+            self.tp_size = 1
+            self.tp_rank = 0
+            # Create a tensor of size num_experts filled with -1
+            self.expert_map = torch.full((self.num_experts,), -1, dtype=torch.int32)
+            # Create a expert map for the local experts
+            assert num_experts % self.ep_size == 0
+            self.local_num_experts = num_experts // self.ep_size
+            self.expert_map[
+                self.ep_rank
+                * self.local_num_experts : (self.ep_rank + 1)
+                * self.local_num_experts
+            ] = torch.arange(0, self.local_num_experts, dtype=torch.int32, device="cpu")
+        else:
+            self.ep_size = 1
+            self.ep_rank = 0
+            self.local_num_experts = num_experts
         self.routed_scaling_factor = routed_scaling_factor
         self.top_k = top_k
-        self.num_experts = num_experts
         assert intermediate_size % self.tp_size == 0
         self.intermediate_size_per_partition = intermediate_size // self.tp_size
         self.reduce_results = reduce_results
```
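To make the expert-map construction above concrete, here is a minimal standalone sketch; the values for `num_experts`, `ep_size`, and `ep_rank` are illustrative and not part of this commit:

```python
import torch

# Sketch of the expert map built in the hunk above, assuming
# num_experts=8 and ep_size=4 (values chosen for illustration only).
num_experts, ep_size, ep_rank = 8, 4, 1
local_num_experts = num_experts // ep_size  # 2 experts per rank

expert_map = torch.full((num_experts,), -1, dtype=torch.int32)
expert_map[
    ep_rank * local_num_experts : (ep_rank + 1) * local_num_experts
] = torch.arange(0, local_num_experts, dtype=torch.int32)

# Rank 1 owns global experts 2 and 3, mapped to local ids 0 and 1;
# every other global expert maps to -1 (not resident on this rank).
print(expert_map)  # tensor([-1, -1,  0,  1, -1, -1, -1, -1], dtype=torch.int32)
```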
```diff
@@ -344,7 +371,6 @@ class FusedMoE(torch.nn.Module):
         self.use_presharded_weights = use_presharded_weights
         self.inplace = inplace
         self.no_combine = no_combine
-        self.local_num_experts = num_experts
 
         if quant_config is None:
             self.quant_method: Optional[QuantizeMethodBase] = (
```
```diff
@@ -352,11 +378,13 @@ class FusedMoE(torch.nn.Module):
             )
         else:
             self.quant_method = quant_config.get_quant_method(self, prefix)
+            if self.quant_method.__class__.__name__ == "ModelOptNvFp4FusedMoEMethod":
+                self.quant_method.enable_flashinfer_moe = self.enable_flashinfer_moe
         assert self.quant_method is not None
 
         self.quant_method.create_weights(
             layer=self,
-            num_experts=num_experts,
+            num_experts=self.local_num_experts,
             hidden_size=hidden_size,
             # FIXME: figure out which intermediate_size to use
             intermediate_size=self.intermediate_size_per_partition,
```
```diff
@@ -450,12 +478,15 @@ class FusedMoE(torch.nn.Module):
 
         # Narrow parameter and load.
         # w1, gate_proj: Load into first logical weight of w13.
-        if shard_id == "w1":
-            expert_data = expert_data.narrow(shard_dim, 0, shard_size)
         # w3, up_proj: Load into second logical weight of w13.
+        # trtllm cutlass kernel assumes differently
+        assert shard_id in ("w1", "w3")
+        switch_w13 = getattr(self.quant_method, "load_up_proj_weight_first", False)
+        if (switch_w13 and shard_id == "w1") or (not switch_w13 and shard_id == "w3"):
+            start = shard_size
         else:
-            assert shard_id == "w3"
-            expert_data = expert_data.narrow(shard_dim, shard_size, shard_size)
+            start = 0
+        expert_data = expert_data.narrow(shard_dim, start, shard_size)
         expert_data.copy_(loaded_weight)
 
     def _load_w2(
```
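The rewritten narrow logic exists because the default FusedMoE layout stacks w13 as [Gate, Up], while the FlashInfer/TRT-LLM CUTLASS kernel expects [Up, Gate]. A small sketch of how `switch_w13` flips the `start` offset; the shard size is illustrative:

```python
# Sketch of the w13 offset logic above, with an illustrative size.
shard_size = 4  # rows per logical weight (gate or up)

def w13_start(shard_id: str, switch_w13: bool) -> int:
    # switch_w13=True means the kernel wants [Up, Gate] instead of [Gate, Up].
    if (switch_w13 and shard_id == "w1") or (not switch_w13 and shard_id == "w3"):
        return shard_size  # second half of w13
    return 0  # first half of w13

assert w13_start("w1", switch_w13=False) == 0  # gate first (default layout)
assert w13_start("w3", switch_w13=False) == 4  # up second
assert w13_start("w1", switch_w13=True) == 4   # gate second (FlashInfer layout)
assert w13_start("w3", switch_w13=True) == 0   # up first
```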
```diff
@@ -509,6 +540,11 @@ class FusedMoE(torch.nn.Module):
         assert shard_id in ("w1", "w3")
         expert_data.copy_(loaded_weight)
 
+    def _map_global_expert_id_to_local_expert_id(self, expert_id: int) -> int:
+        if self.expert_map is None:
+            return expert_id
+        return self.expert_map[expert_id].item()
+
     def weight_loader(
         self,
         param: torch.nn.Parameter,
```
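Together with the expert map, this lets the loader ignore checkpoint shards for experts that are not resident on the current rank. A self-contained sketch of that skip behavior, using the same illustrative 8-expert, 4-rank map as above:

```python
import torch

# Sketch: weight loading skips experts this rank does not own.
expert_map = torch.tensor([-1, -1, 0, 1, -1, -1, -1, -1], dtype=torch.int32)

def local_expert_id(expert_id: int) -> int:
    return int(expert_map[expert_id])  # -1 means "not resident here"

for global_id in range(8):
    local_id = local_expert_id(global_id)
    if local_id == -1:
        continue  # this expert's checkpoint shard is ignored on this rank
    print(f"global expert {global_id} -> local slot {local_id}")
# Prints only:
#   global expert 2 -> local slot 0
#   global expert 3 -> local slot 1
```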
```diff
@@ -517,6 +553,13 @@ class FusedMoE(torch.nn.Module):
         shard_id: str,
         expert_id: int,
     ) -> None:
+        expert_id = self._map_global_expert_id_to_local_expert_id(expert_id)
+        if expert_id == -1:
+            return
+
+        # TP rank is set to 0 if EP is enabled
+        tp_rank = 0 if self.ep_size > 1 else get_tensor_model_parallel_rank()
+
         # compressed-tensors checkpoints with packed weights are stored flipped
         # TODO (mgoin): check self.quant_method.quant_config.quant_format
         # against known CompressionFormat enum values that have this quality
```
```diff
@@ -541,7 +584,6 @@ class FusedMoE(torch.nn.Module):
         SHARD_ID_TO_SHARDED_DIM = {"w1": 0, "w2": 1, "w3": 0}
 
         expert_data = param.data[expert_id]
-        tp_rank = get_tensor_model_parallel_rank()
 
         # is_transposed: if the dim to shard the weight
         # should be flipped. Required by GPTQ, compressed-tensors
```
```diff
@@ -549,7 +591,7 @@ class FusedMoE(torch.nn.Module):
         is_transposed = getattr(param, "is_transposed", False)
         shard_dim = SHARD_ID_TO_SHARDED_DIM[shard_id]
         if is_transposed:
-            shard_dim = ~shard_dim
+            shard_dim = int(not shard_dim)
 
         # Case input scale: input_scale loading is only supported for fp8
         if "input_scale" in weight_name:
```
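The `~shard_dim` to `int(not shard_dim)` change is a correctness cleanup: bitwise NOT maps 0 to -1 and 1 to -2, which only behaved like a 0/1 flip through Python's negative-dimension indexing on 2-D weights. The new form makes the flip explicit:

```python
# Why the change: bitwise NOT does not flip 0 <-> 1; it yields negative
# dims that only worked via negative-index semantics on 2-D tensors
# (dim -1 == dim 1, dim -2 == dim 0).
assert ~0 == -1 and ~1 == -2                 # old behavior: negative dims
assert int(not 0) == 1 and int(not 1) == 0   # new behavior: explicit 0 <-> 1 flip
```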
```diff
@@ -690,9 +732,19 @@ class FusedMoE(torch.nn.Module):
             activation=self.activation,
             apply_router_weight_on_input=self.apply_router_weight_on_input,
             routed_scaling_factor=self.routed_scaling_factor,
+            **(
+                dict(
+                    tp_rank=self.tp_rank,
+                    tp_size=self.tp_size,
+                    ep_rank=self.ep_rank,
+                    ep_size=self.ep_size,
+                )
+                if self.quant_method.__class__.__name__ == "ModelOptNvFp4FusedMoEMethod"
+                else {}
+            ),
         )
 
-        if self.reduce_results and self.tp_size > 1:
+        if self.reduce_results and (self.tp_size > 1 or self.ep_size > 1):
             final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states)
 
         return final_hidden_states
```
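The widened reduce condition is needed because under EP each rank holds only its local experts, so its output is a partial sum over the experts it owns; the all-reduce adds the partials back together. A toy, single-process illustration of that identity, with hypothetical routing and expert outputs:

```python
import torch

# Toy illustration (no distributed setup): with EP, each rank produces a
# partial output covering only its local experts; summing the partials
# equals the full MoE output, which is what the all-reduce computes.
torch.manual_seed(0)
tokens, hidden = 4, 8
expert_out = {e: torch.randn(tokens, hidden) for e in range(4)}
topk_id = torch.tensor([0, 3, 1, 2])  # one expert per token, for simplicity

full = torch.stack([expert_out[int(e)][i] for i, e in enumerate(topk_id)])

partials = []
for local_experts in [{0, 1}, {2, 3}]:  # ep_size = 2, two "ranks"
    part = torch.zeros(tokens, hidden)
    for i, e in enumerate(topk_id):
        if int(e) in local_experts:  # only resident experts contribute
            part[i] = expert_out[int(e)][i]
    partials.append(part)

assert torch.allclose(full, sum(partials))  # all-reduce reconstructs the output
```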
```diff
@@ -29,11 +29,17 @@ from sglang.srt.layers.quantization.utils import (
     requantize_with_max_scale,
 )
 from sglang.srt.layers.radix_attention import RadixAttention
-from sglang.srt.utils import is_cuda
+from sglang.srt.utils import is_cuda, next_power_of_2
 
 if is_cuda():
     from sgl_kernel import cutlass_scaled_fp4_mm, scaled_fp4_quant
 
+try:
+    from flashinfer import fp4_quantize as fp4_quantize
+    from flashinfer.fused_moe import cutlass_fused_moe as flashinfer_cutlass_fused_moe
+except ImportError:
+    flashinfer_cutlass_fused_moe = None
+
 # Initialize logger for the module
 logger = logging.getLogger(__name__)
 
```
```diff
@@ -429,6 +435,9 @@ class ModelOptFp4LinearMethod(LinearMethodBase):
         layer.alpha = Parameter(
             layer.input_scale * layer.weight_scale_2, requires_grad=False
         )
+        layer.input_scale_inv = Parameter(
+            (1 / input_scale_2).to(torch.float32), requires_grad=False
+        )
 
         # Pad and blockwise interleave weight_scale
         scales = layer.weight_scale
```
```diff
@@ -467,7 +476,7 @@ class ModelOptFp4LinearMethod(LinearMethodBase):
         output_shape = [x_m, w_n]
 
         # Quantize BF16 or FP16 to (FP4 and interleaved block scale)
-        x_fp4, x_scale_interleaved = scaled_fp4_quant(x, 1 / layer.input_scale)
+        x_fp4, x_scale_interleaved = scaled_fp4_quant(x, layer.input_scale_inv)
 
         assert x_fp4.dtype == torch.uint8
         assert x_scale_interleaved.dtype == torch.float8_e4m3fn
```
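The two hunks above are a small forward-path optimization: `input_scale_inv` caches the reciprocal once at weight-processing time, so `apply` no longer performs a tensor division on every call. The pattern in isolation, with an illustrative scale value:

```python
import torch
from torch.nn import Parameter

# Sketch of the precomputed-reciprocal pattern (scale value is illustrative).
input_scale = torch.tensor(0.02, dtype=torch.float32)

# Before: a division on every forward call:
#     x_fp4, x_scale = scaled_fp4_quant(x, 1 / input_scale)
# After: divide once at load time, reuse the cached reciprocal:
input_scale_inv = Parameter((1 / input_scale).to(torch.float32), requires_grad=False)
#     x_fp4, x_scale = scaled_fp4_quant(x, input_scale_inv)

assert torch.allclose(input_scale_inv, torch.tensor(50.0))
```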
```diff
@@ -521,6 +530,7 @@ class ModelOptNvFp4FusedMoEMethod:
                 " quantization. Please use Blackwell and"
                 " above."
             )
+        self.enable_flashinfer_moe = False
 
     def create_weights(
         self,
```
```diff
@@ -674,7 +684,10 @@ class ModelOptNvFp4FusedMoEMethod:
         w13_weight_scale_2 = layer.w13_weight_scale_2[:, 0]
         layer.w13_weight_scale_2 = Parameter(w13_weight_scale_2, requires_grad=False)
 
-        w13_input_scale = layer.w13_input_scale.max(dim=1).values.to(torch.float32)
+        if self.enable_flashinfer_moe:
+            w13_input_scale = layer.w13_input_scale.max().to(torch.float32)
+        else:
+            w13_input_scale = layer.w13_input_scale.max(dim=1).values.to(torch.float32)
         layer.g1_alphas = Parameter(
             (w13_input_scale * w13_weight_scale_2).to(torch.float32),
             requires_grad=False,
```
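This branch changes the granularity of the activation scale: the CUTLASS path keeps one input scale per expert, while the FlashInfer path collapses them to a single global scalar. A shape-level sketch, assuming `w13_input_scale` has shape `[num_experts, 2]` (one scale per gate/up shard, an assumption made for illustration):

```python
import torch

# Illustrative w13_input_scale of shape [num_experts, 2].
w13_input_scale = torch.rand(8, 2)

per_expert = w13_input_scale.max(dim=1).values.to(torch.float32)
global_max = w13_input_scale.max().to(torch.float32)

assert per_expert.shape == (8,)  # default path: one scale per expert
assert global_max.shape == ()    # FlashInfer path: one scalar for all experts
```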
```diff
@@ -700,14 +713,19 @@ class ModelOptNvFp4FusedMoEMethod:
         layer.w13_weight = Parameter(layer.w13_weight.data, requires_grad=False)
 
         # GEMM 2
+        if self.enable_flashinfer_moe:
+            w2_input_scale = layer.w2_input_scale.max().to(torch.float32)
+        else:
+            w2_input_scale = layer.w2_input_scale
+
         layer.g2_alphas = Parameter(
-            (layer.w2_input_scale * layer.w2_weight_scale_2).to(torch.float32),
+            (w2_input_scale * layer.w2_weight_scale_2).to(torch.float32),
             requires_grad=False,
         )
 
         # This is for quantization, so we need to invert it.
         layer.w2_input_scale_quant = Parameter(
-            (1 / layer.w2_input_scale).to(torch.float32), requires_grad=False
+            (1 / w2_input_scale).to(torch.float32), requires_grad=False
         )
 
         assert (
```
```diff
@@ -727,11 +745,16 @@ class ModelOptNvFp4FusedMoEMethod:
         layer.cutlass_moe_params = CutlassMoEParams(
             CutlassMoEType.BlockscaledFP4,
             device,
-            num_experts=layer.num_experts,
+            num_experts=layer.num_experts,  # global num experts
             intermediate_size_per_partition=layer.w2_weight.shape[2] * 2,  # n
             hidden_size=layer.w13_weight.shape[2] * 2,
         )  # k
 
+    @property
+    def load_up_proj_weight_first(self) -> bool:
+        # FlashInfer CUTLASS kernel assumes [Up, Gate] Proj as W13
+        return self.enable_flashinfer_moe
+
     def apply(
         self,
         layer: torch.nn.Module,
```
```diff
@@ -750,11 +773,13 @@ class ModelOptNvFp4FusedMoEMethod:
         inplace: bool = True,
         no_combine: bool = False,
         routed_scaling_factor: Optional[float] = None,
+        ep_rank: Optional[int] = None,
+        ep_size: Optional[int] = None,
+        tp_rank: Optional[int] = None,
+        tp_size: Optional[int] = None,
     ) -> torch.Tensor:
 
         assert activation == "silu", "Only SiLU activation is supported."
 
-        from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_experts
         from sglang.srt.layers.moe.topk import select_experts
 
         topk_weights, topk_ids = select_experts(
```
```diff
@@ -771,6 +796,35 @@ class ModelOptNvFp4FusedMoEMethod:
             routed_scaling_factor=routed_scaling_factor,
         )
 
+        if self.enable_flashinfer_moe:
+            assert (
+                not apply_router_weight_on_input
+            ), "apply_router_weight_on_input is not supported for Flashinfer"
+            # TRTLLM Cutlass moe takes in activations in BF16/Half/nvfp4 precision
+            # and fp4 quantized weights loaded from the checkpoint
+            output = flashinfer_cutlass_fused_moe(
+                x,
+                topk_ids.to(torch.int),
+                topk_weights,
+                layer.w13_weight.view(torch.long),
+                layer.w2_weight.view(torch.long),
+                x.dtype,
+                quant_scales=[
+                    layer.w13_input_scale_quant,
+                    layer.w13_blockscale_swizzled.view(torch.int32),
+                    layer.g1_alphas,
+                    layer.w2_input_scale_quant,
+                    layer.w2_blockscale_swizzled.view(torch.int32),
+                    layer.g2_alphas,
+                ],
+                ep_size=ep_size,
+                ep_rank=ep_rank,
+                tp_size=tp_size,
+                tp_rank=tp_rank,
+                tune_max_num_tokens=next_power_of_2(x.shape[0]),
+            )
+            return output[0]
+
         from sglang.srt.layers.moe.cutlass_moe import cutlass_moe_fp4
 
         return cutlass_moe_fp4(
```
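`tune_max_num_tokens` is rounded up with `next_power_of_2` so the kernel's autotuning sees a small set of bucketed sizes instead of every possible batch size. A sketch of the assumed behavior of `sglang.srt.utils.next_power_of_2` (an illustrative reimplementation, not the library source):

```python
# Assumed behavior: smallest power of two >= n.
def next_power_of_2(n: int) -> int:
    return 1 << (n - 1).bit_length() if n > 0 else 1

assert [next_power_of_2(n) for n in (1, 3, 8, 100)] == [1, 4, 8, 128]
```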
```diff
@@ -86,6 +86,7 @@ GLOBAL_SERVER_ARGS_KEYS = [
     "enable_deepep_moe",
     "deepep_mode",
     "enable_ep_moe",
+    "enable_flashinfer_moe",
     "moe_dense_tp_size",
     "ep_dispatch_algorithm",
     "deepep_config",
```
```diff
@@ -226,6 +226,7 @@ class DeepseekV2MoE(nn.Module):
         layer_id: int,
         quant_config: Optional[QuantizationConfig] = None,
         prefix: str = "",
+        alt_stream: Optional[torch.cuda.Stream] = None,
     ):
         super().__init__()
         self.tp_size = get_tensor_model_parallel_world_size()
@@ -238,6 +239,7 @@ class DeepseekV2MoE(nn.Module):
         )
         self.config = config
         self.layer_id = layer_id
+        self.alt_stream = alt_stream
 
         if self.tp_size > config.n_routed_experts:
             raise ValueError(
```
```diff
@@ -275,6 +277,15 @@ class DeepseekV2MoE(nn.Module):
                 if global_server_args_dict["enable_deepep_moe"]
                 else {}
             ),
+            # Additional args for FusedMoE
+            **(
+                dict(
+                    enable_flashinfer_moe=True,
+                    enable_ep_moe=global_server_args_dict["enable_ep_moe"],
+                )
+                if global_server_args_dict["enable_flashinfer_moe"]
+                else {}
+            ),
         )
 
         if config.n_shared_experts is not None and self.num_fused_shared_experts == 0:
```
```diff
@@ -338,10 +349,36 @@ class DeepseekV2MoE(nn.Module):
         self, hidden_states: torch.Tensor, forward_batch: Optional[ForwardBatch] = None
     ) -> torch.Tensor:
         if not self._enable_deepep_moe:
-            return self.forward_normal(hidden_states)
+            DUAL_STREAM_TOKEN_THRESHOLD = 1024
+            if (
+                self.alt_stream is not None
+                and self.num_fused_shared_experts == 0
+                and hidden_states.shape[0] <= DUAL_STREAM_TOKEN_THRESHOLD
+            ):
+                return self.forward_normal_dual_stream(hidden_states)
+            else:
+                return self.forward_normal(hidden_states)
         else:
             return self.forward_deepep(hidden_states, forward_batch)
 
+    def forward_normal_dual_stream(self, hidden_states: torch.Tensor) -> torch.Tensor:
+        current_stream = torch.cuda.current_stream()
+        self.alt_stream.wait_stream(current_stream)
+        shared_output = self._forward_shared_experts(hidden_states)
+        with torch.cuda.stream(self.alt_stream):
+            # router_logits: (num_tokens, n_experts)
+            router_logits = self.gate(hidden_states)
+            final_hidden_states = self.experts(
+                hidden_states=hidden_states, router_logits=router_logits
+            )
+            if not _is_cuda:
+                final_hidden_states *= self.routed_scaling_factor
+        current_stream.wait_stream(self.alt_stream)
+        final_hidden_states = final_hidden_states + shared_output
+        if self.tp_size > 1:
+            final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states)
+        return final_hidden_states
+
     def forward_normal(self, hidden_states: torch.Tensor) -> torch.Tensor:
         shared_output = self._forward_shared_experts(hidden_states)
         # router_logits: (num_tokens, n_experts)
```
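The dual-stream forward overlaps the shared-expert MLP (issued on the current stream) with routing plus routed experts (issued on the alternate stream), joining with `wait_stream` before the residual add; the 1024-token threshold restricts this to small batches, where each branch alone underutilizes the GPU. A self-contained sketch of the same fork/join pattern, with stand-in workloads in place of the shared and routed experts:

```python
import torch

# Minimal fork/join sketch of the dual-stream pattern above.
def dual_stream_forward(x: torch.Tensor, alt_stream: torch.cuda.Stream):
    current = torch.cuda.current_stream()
    alt_stream.wait_stream(current)      # alt stream sees x as ready
    shared = x @ x.T                     # "shared experts" on current stream
    with torch.cuda.stream(alt_stream):
        routed = torch.relu(x) @ x.T     # "routed experts" on alt stream
    current.wait_stream(alt_stream)      # join before combining results
    return routed + shared

if torch.cuda.is_available():
    x = torch.randn(256, 256, device="cuda")
    out = dual_stream_forward(x, torch.cuda.Stream())
```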
```diff
@@ -1446,6 +1483,7 @@ class DeepseekV2DecoderLayer(nn.Module):
                 quant_config=quant_config,
                 prefix=add_prefix("mlp", prefix),
                 layer_id=self.layer_id,
+                alt_stream=alt_stream,
             )
         else:
             if enable_moe_dense_fully_dp():
```
```diff
@@ -152,6 +152,7 @@ class ServerArgs:
     ep_size: int = 1
     enable_ep_moe: bool = False
     enable_deepep_moe: bool = False
+    enable_flashinfer_moe: bool = False
     deepep_mode: Optional[Literal["auto", "normal", "low_latency"]] = "auto"
     ep_num_redundant_experts: int = 0
     ep_dispatch_algorithm: Optional[Literal["static", "dynamic", "fake"]] = None
```
```diff
@@ -244,7 +245,15 @@ class ServerArgs:
             logger.warning(
                 f"EP MoE is enabled. The expert parallel size is adjusted to be the same as the tensor parallel size[{self.tp_size}]."
             )
+        if self.enable_flashinfer_moe:
+            assert (
+                self.quantization == "modelopt_fp4"
+            ), "modelopt_fp4 quantization is required for Flashinfer MOE"
+            os.environ["TRTLLM_ENABLE_PDL"] = "1"
+            self.disable_shared_experts_fusion = True
+            logger.warning(
+                f"Flashinfer MoE is enabled. Shared expert fusion is disabled."
+            )
         # Set missing default values
         if self.tokenizer_path is None:
             self.tokenizer_path = self.model_path
```
```diff
@@ -1162,6 +1171,11 @@ class ServerArgs:
             action="store_true",
             help="Enabling expert parallelism for moe. The ep size is equal to the tp size.",
         )
+        parser.add_argument(
+            "--enable-flashinfer-moe",
+            action="store_true",
+            help="Enable FlashInfer CUTLASS MoE backend for modelopt_fp4 quant on Blackwell. Supports MoE-EP with --enable-ep-moe",
+        )
         parser.add_argument(
             "--enable-deepep-moe",
             action="store_true",
```