diff --git a/python/sglang/srt/layers/moe/ep_moe/kernels.py b/python/sglang/srt/layers/moe/ep_moe/kernels.py
index e00562f58..dd2ecd8f2 100644
--- a/python/sglang/srt/layers/moe/ep_moe/kernels.py
+++ b/python/sglang/srt/layers/moe/ep_moe/kernels.py
@@ -178,6 +178,7 @@ def pre_reorder_triton_kernel(
     topk,
     hidden_size,
     BLOCK_SIZE: tl.constexpr,
+    use_per_token_if_dynamic: tl.constexpr,
 ):
     OutDtype = gateup_input_ptr.dtype.element_ty
 
@@ -188,11 +189,15 @@ def pre_reorder_triton_kernel(
 
     vec = tl.arange(0, BLOCK_SIZE)
 
+    if a1_scales_ptr is not None and use_per_token_if_dynamic:
+        scale = 1.0 / tl.load(a1_scales_ptr + src_idx)
+
     for idx in range(topk):
         expert_id = tl.load(topk_ids_ptr + idx)
         if expert_id >= start_expert_id and expert_id <= end_expert_id:
             if a1_scales_ptr is not None:
-                scale = 1.0 / tl.load(a1_scales_ptr + expert_id - start_expert_id)
+                if not use_per_token_if_dynamic:
+                    scale = 1.0 / tl.load(a1_scales_ptr + expert_id - start_expert_id)
             else:
                 scale = 1.0
 
@@ -558,6 +563,7 @@ def grouped_gemm_triton_kernel(
     bs_stride_0: tl.constexpr,
     bs_stride_2: tl.constexpr,
     bs_stride_1: tl.constexpr,
+    use_per_token_if_dynamic: tl.constexpr,
     BLOCK_SIZE_M: tl.constexpr,
     BLOCK_SIZE_N: tl.constexpr,
     BLOCK_SIZE_K: tl.constexpr,
@@ -621,7 +627,10 @@ def grouped_gemm_triton_kernel(
         b_ptr += BLOCK_SIZE_K
 
     if use_fp8_w8a8 and not (group_k > 0 and group_n > 0):
-        scale_a_value = tl.load(scale_a + m_range_start + offs_am[:, None])
+        if use_per_token_if_dynamic:
+            scale_a_value = tl.load(scale_a + (m_range_start + offs_am[:, None]))
+        else:
+            scale_a_value = tl.load(scale_a + expert_id)
         scale_b_value = tl.load(scale_b + expert_id)
         accumulator *= scale_a_value * scale_b_value
 
@@ -658,6 +667,7 @@ def grouped_gemm_triton(
     scale_b: torch.Tensor = None,
     block_shape: Optional[List[int]] = None,
     c_dtype=None,
+    use_per_token_if_dynamic: bool = True,
 ):
     assert weight_column_major == True  # TODO: more
     if use_fp8_w8a8 and block_shape is None:
@@ -698,6 +708,11 @@ def grouped_gemm_triton(
         triton.cdiv(b.size(1), META["BLOCK_SIZE_N"]),
     )
 
+    if use_fp8_w8a8 and block_shape is None and use_per_token_if_dynamic:
+        assert (
+            scale_a.shape[0] == a.shape[0]
+        ), f"scale_a.shape: {scale_a.shape}, a.shape: {a.shape}"
+
     grouped_gemm_triton_kernel[grid](
         a,
         b,
@@ -721,6 +736,7 @@ def grouped_gemm_triton(
         scale_b.stride(0) if scale_b is not None and scale_b.ndim >= 2 else 0,
         scale_b.stride(2) if scale_b is not None and scale_b.ndim == 3 else 0,
         scale_b.stride(1) if scale_b is not None and scale_b.ndim >= 2 else 0,
+        use_per_token_if_dynamic,
         **config,
     )
     return c
diff --git a/python/sglang/srt/layers/moe/ep_moe/layer.py b/python/sglang/srt/layers/moe/ep_moe/layer.py
index 6d629333e..63b13ac34 100644
--- a/python/sglang/srt/layers/moe/ep_moe/layer.py
+++ b/python/sglang/srt/layers/moe/ep_moe/layer.py
@@ -50,7 +50,10 @@ from sglang.srt.layers.quantization.base_config import (
     QuantizeMethodBase,
 )
 from sglang.srt.layers.quantization.fp8 import Fp8Config, Fp8MoEMethod
-from sglang.srt.layers.quantization.fp8_kernel import scaled_fp8_quant
+from sglang.srt.layers.quantization.fp8_kernel import (
+    scaled_fp8_quant,
+    sglang_per_token_quant_fp8,
+)
 from sglang.srt.model_executor.forward_batch_info import ForwardMode
 from sglang.srt.utils import DeepEPMode, dispose_tensor, is_hip, set_weight_attrs
@@ -65,10 +68,16 @@ logger = logging.getLogger(__name__)
 
 class GroupedGemmRunner(torch.nn.Module):
     flashinfer_gemm_warpper = None
 
-    def __init__(self, device, use_flashinfer: bool = False):
+    def __init__(
+        self,
+        device,
+        use_flashinfer: bool = False,
+        use_per_token_if_dynamic: bool = True,
+    ):
         super().__init__()
         self.device = device
         self.use_flashinfer = use_flashinfer
+        self.use_per_token_if_dynamic = use_per_token_if_dynamic
         if self.use_flashinfer and GroupedGemmRunner.flashinfer_gemm_warpper is None:
             GroupedGemmRunner._init_flashinfer_wrapper(device)
@@ -124,6 +133,7 @@ class GroupedGemmRunner(torch.nn.Module):
                 scale_b,
                 block_shape=block_shape,
                 c_dtype=c_dtype,
+                use_per_token_if_dynamic=self.use_per_token_if_dynamic,
             )
         return c
 
@@ -154,6 +164,7 @@ class EPMoE(torch.nn.Module):
         custom_routing_function: Optional[Callable] = None,
         activation: str = "silu",
         routed_scaling_factor: Optional[float] = None,
+        use_per_token_if_dynamic: bool = True,
     ):
         super().__init__()
 
@@ -184,6 +195,7 @@ class EPMoE(torch.nn.Module):
         self.custom_routing_function = custom_routing_function
         self.activation = activation
         self.routed_scaling_factor = routed_scaling_factor
+        self.use_per_token_if_dynamic = use_per_token_if_dynamic
 
         if quant_config is None:
             self.quant_method: Optional[QuantizeMethodBase] = UnquantizedEPMoEMethod()
@@ -227,6 +239,7 @@ class EPMoE(torch.nn.Module):
             self.grouped_gemm_runner = GroupedGemmRunner(
                 hidden_states.device,
                 use_flashinfer=False,  # TODO: use flashinfer
+                use_per_token_if_dynamic=self.use_per_token_if_dynamic,
             )
 
         topk_weights, topk_ids = select_experts(
@@ -259,12 +272,16 @@ class EPMoE(torch.nn.Module):
             ),
         )
         if self.activation_scheme == "dynamic" and not self.use_block_quant:
-            max_value = (
-                torch.max(hidden_states)
-                .repeat(self.num_experts_per_partition)
-                .to(torch.float32)
-            )
-            self.w13_input_scale = max_value / torch.finfo(self.fp8_dtype).max
+            if self.use_per_token_if_dynamic:
+                max_value = torch.max(hidden_states, dim=1).values.to(torch.float32)
+                self.w13_input_scale = max_value / torch.finfo(self.fp8_dtype).max
+            else:
+                max_value = (
+                    torch.max(hidden_states)
+                    .repeat(self.num_experts_per_partition)
+                    .to(torch.float32)
+                )
+                self.w13_input_scale = max_value / torch.finfo(self.fp8_dtype).max
 
         # PreReorder
         pre_reorder_triton_kernel[(hidden_states.shape[0],)](
@@ -278,9 +295,27 @@ class EPMoE(torch.nn.Module):
             self.top_k,
             hidden_states.shape[1],
             BLOCK_SIZE=512,
+            use_per_token_if_dynamic=self.use_per_token_if_dynamic,
         )
         dispose_tensor(hidden_states)
 
+        if (
+            self.activation_scheme == "dynamic"
+            and not self.use_block_quant
+            and self.use_per_token_if_dynamic
+        ):
+            scale = torch.empty(
+                hidden_states_shape[0] * self.top_k,
+                device=hidden_states_device,
+                dtype=torch.float32,
+            )
+            scale[src2dst] = (
+                self.w13_input_scale.unsqueeze(1)
+                .expand(hidden_states_shape[0], self.top_k)
+                .reshape(-1)
+            )
+            self.w13_input_scale = scale
+
         seg_indptr_cur_rank = seg_indptr[self.start_expert_id : self.end_expert_id + 2]
         weight_indices_cur_rank = torch.arange(
             0,
@@ -310,21 +345,24 @@ class EPMoE(torch.nn.Module):
         del gateup_input
 
         # Act
-        down_input = torch.empty(
-            gateup_output.shape[0],
-            gateup_output.shape[1] // 2,
-            device=gateup_output.device,
-            dtype=(
-                self.fp8_dtype
-                if (self.use_fp8_w8a8 and not self.use_block_quant)
-                else hidden_states_dtype
-            ),
-        )
-        if self.w2_input_scale is None and not self.use_block_quant:
-            self.w2_input_scale = torch.ones(
-                self.num_experts_per_partition,
-                dtype=torch.float32,
-                device=hidden_states_device,
+        if self.activation_scheme == "dynamic" and not self.use_block_quant:
+            self.w2_input_scale = None
+            down_input = torch.empty(
+                gateup_output.shape[0],
+                gateup_output.shape[1] // 2,
+                device=gateup_output.device,
+                dtype=hidden_states_dtype,
+            )
+        else:
+            down_input = torch.empty(
+                gateup_output.shape[0],
+                gateup_output.shape[1] // 2,
+                device=gateup_output.device,
+                dtype=(
+                    self.fp8_dtype
+                    if (self.use_fp8_w8a8 and not self.use_block_quant)
+                    else hidden_states_dtype
+                ),
             )
 
         if self.activation == "silu":
@@ -353,6 +391,16 @@ class EPMoE(torch.nn.Module):
             raise ValueError(f"Unsupported activation: {self.activation=}")
         del gateup_output
 
+        if self.activation_scheme == "dynamic" and not self.use_block_quant:
+            if self.use_per_token_if_dynamic:
+                down_input, self.w2_input_scale = sglang_per_token_quant_fp8(down_input)
+            else:
+                self.w2_input_scale = torch.ones(
+                    self.num_experts_per_partition,
+                    dtype=torch.float32,
+                    device=hidden_states_device,
+                )
+
         # GroupGemm-1
         down_output = torch.empty(
             down_input.shape[0],
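
For context, here is a minimal standalone sketch of the per-token dynamic FP8 quantization scheme this diff switches `EPMoE` onto: one scale per token row instead of one scale shared across the expert partition. This is an illustration only, assuming abs-max scaling and `torch.float8_e4m3fn`; the helper name `per_token_quant_fp8` is hypothetical and merely mirrors the `(values, scales)` return contract of the real `sglang_per_token_quant_fp8` kernel used above.

```python
import torch

def per_token_quant_fp8(x: torch.Tensor, fp8_dtype=torch.float8_e4m3fn):
    """Hypothetical per-token (per-row) dynamic FP8 quantization sketch.

    Each token row of `x` gets its own scale from that row's abs-max,
    rather than one scale shared across the whole tensor, so a single
    outlier token no longer inflates the quantization error of every
    other token in the batch.
    """
    fp8_max = torch.finfo(fp8_dtype).max
    # One scale per row: shape [num_tokens], float32 like w13_input_scale.
    amax = x.abs().amax(dim=1).clamp(min=1e-10).to(torch.float32)
    scale = amax / fp8_max
    # Quantize: divide each row by its scale, clamp to range, cast to FP8.
    x_q = (x.float() / scale.unsqueeze(1)).clamp(-fp8_max, fp8_max).to(fp8_dtype)
    return x_q, scale
```

The `scale[src2dst]` scatter added after `pre_reorder_triton_kernel` follows from this layout: the grouped GEMM consumes rows in expert-sorted order, so each of the `num_tokens * top_k` reordered rows must carry the scale of the source token it came from. That is also why `grouped_gemm_triton` now asserts `scale_a.shape[0] == a.shape[0]` on the per-token path.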