diff --git a/python/sglang/srt/layers/moe/ep_moe/kernels.py b/python/sglang/srt/layers/moe/ep_moe/kernels.py index e0486891a..ae7d13ea5 100644 --- a/python/sglang/srt/layers/moe/ep_moe/kernels.py +++ b/python/sglang/srt/layers/moe/ep_moe/kernels.py @@ -1,10 +1,17 @@ import logging -from typing import Optional +from typing import List, Optional import torch import triton import triton.language as tl +from sglang.srt.layers.quantization.fp8_kernel import per_token_group_quant_fp8 + +_is_cuda = torch.cuda.is_available() and torch.version.cuda +if _is_cuda: + from sglang.srt.layers.quantization.fp8_kernel import ( + sglang_per_token_group_quant_fp8, + ) logger = logging.getLogger(__name__) @@ -218,12 +225,19 @@ def grouped_gemm_triton_kernel( seg_indptr, weight_indices, m_num_tiles_indptr, - use_fp8_w8a8, scale_a, scale_b, + use_fp8_w8a8: tl.constexpr, + group_n: tl.constexpr, + group_k: tl.constexpr, a_stride_0: tl.constexpr, b_stride_0: tl.constexpr, b_stride_1: tl.constexpr, + as_stride_0: tl.constexpr, + as_stride_1: tl.constexpr, + bs_stride_0: tl.constexpr, + bs_stride_2: tl.constexpr, + bs_stride_1: tl.constexpr, BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, @@ -260,6 +274,12 @@ def grouped_gemm_triton_kernel( + (n_range_start + offs_bn[:, None]) * b_stride_1 + offs_k[None, :] ) + + if group_k > 0 and group_n > 0: + a_scale_ptrs = scale_a + (m_range_start + offs_am[:, None]) * as_stride_0 + offs_bsn = (n_range_start + offs_bn) // group_n + b_scale_ptrs = scale_b + (expert_id * bs_stride_0) + offs_bsn * bs_stride_1 + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)): a_tile = tl.load( @@ -268,14 +288,23 @@ def grouped_gemm_triton_kernel( b_tile = tl.load( b_ptr, mask=offs_k[None, :] < (K - k * BLOCK_SIZE_K), other=0.0 ) - accumulator = tl.dot(a_tile, b_tile.T, accumulator) + + if group_k > 0 and group_n > 0: + k_start = k * BLOCK_SIZE_K + offs_ks = k_start // group_k + a_scale = tl.load(a_scale_ptrs + offs_ks * as_stride_1) + b_scale = tl.load(b_scale_ptrs + offs_ks * bs_stride_2) + accumulator += tl.dot(a_tile, b_tile.T) * a_scale * b_scale[None, :] + else: + accumulator = tl.dot(a_tile, b_tile.T, accumulator) a_ptr += BLOCK_SIZE_K b_ptr += BLOCK_SIZE_K - if use_fp8_w8a8: + if use_fp8_w8a8 and not (group_k > 0 and group_n > 0): scale_a_value = tl.load(scale_a + expert_id) scale_b_value = tl.load(scale_b + expert_id) accumulator *= scale_a_value * scale_b_value + c_tile = accumulator.to(c_dtype) offs_cm = m_range_start + tl.arange(0, BLOCK_SIZE_M) @@ -307,14 +336,29 @@ def grouped_gemm_triton( use_fp8_w8a8: bool = False, scale_a: torch.Tensor = None, scale_b: torch.Tensor = None, + block_shape: Optional[List[int]] = None, ): assert weight_column_major == True # TODO: more - if use_fp8_w8a8: + if use_fp8_w8a8 and block_shape is None: assert scale_a is not None and scale_b is not None + if block_shape is not None: + assert len(block_shape) == 2 + block_n, block_k = block_shape[0], block_shape[1] + if _is_cuda: + a, scale_a = sglang_per_token_group_quant_fp8(a, block_k) + else: + a, scale_a = per_token_group_quant_fp8(a, block_k) + + assert triton.cdiv(a.shape[-1], block_k) == scale_a.shape[-1] + assert triton.cdiv(b.shape[-2], block_n) == scale_b.shape[-2] + assert triton.cdiv(b.shape[-1], block_k) == scale_b.shape[-1] + + # TODO: adjust config or tune kernel + # Reduce block size to prevent L40 shared memory overflow. 
config = { - "BLOCK_SIZE_M": 128, - "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 128, } @@ -338,12 +382,19 @@ def grouped_gemm_triton( seg_indptr, weight_indices, m_num_tiles_indptr, - use_fp8_w8a8, scale_a, scale_b, + use_fp8_w8a8, + 0 if block_shape is None else block_shape[0], + 0 if block_shape is None else block_shape[1], a.stride(0), b.stride(0), b.stride(1), + scale_a.stride(0) if scale_a is not None and scale_a.ndim == 2 else 0, + scale_a.stride(1) if scale_a is not None and scale_a.ndim == 2 else 0, + scale_b.stride(0) if scale_b is not None and scale_b.ndim >= 2 else 0, + scale_b.stride(2) if scale_b is not None and scale_b.ndim == 3 else 0, + scale_b.stride(1) if scale_b is not None and scale_b.ndim >= 2 else 0, **config, ) return c diff --git a/python/sglang/srt/layers/moe/ep_moe/layer.py b/python/sglang/srt/layers/moe/ep_moe/layer.py index 4d6040646..7468c0b91 100644 --- a/python/sglang/srt/layers/moe/ep_moe/layer.py +++ b/python/sglang/srt/layers/moe/ep_moe/layer.py @@ -17,6 +17,7 @@ from sglang.srt.layers.moe.ep_moe.kernels import ( run_moe_ep_preproess, silu_and_mul_triton_kernel, ) +from sglang.srt.layers.moe.fused_moe_triton import FusedMoeWeightScaleSupported from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoEMethodBase from sglang.srt.layers.moe.topk import select_experts from sglang.srt.layers.quantization.base_config import ( @@ -61,6 +62,7 @@ class GroupedGemmRunner(torch.nn.Module): use_fp8_w8a8: bool = False, scale_a: torch.Tensor = None, scale_b: torch.Tensor = None, + block_shape: Optional[List[int]] = None, ): if self.use_flashinfer: # TODO: flashinfer @@ -87,6 +89,7 @@ class GroupedGemmRunner(torch.nn.Module): use_fp8_w8a8, scale_a, scale_b, + block_shape=block_shape, ) return c @@ -147,12 +150,20 @@ class EPMoE(torch.nn.Module): if quant_config is None: self.quant_method: Optional[QuantizeMethodBase] = UnquantizedEPMoEMethod() self.use_fp8_w8a8 = False + self.use_block_quant = False + self.block_shape = None self.activation_scheme = None else: self.quant_method: Optional[QuantizeMethodBase] = Fp8EPMoEMethod( quant_config ) self.use_fp8_w8a8 = True + self.use_block_quant = getattr(self.quant_method, "block_quant", False) + self.block_shape = ( + self.quant_method.quant_config.weight_block_size + if self.use_block_quant + else None + ) self.fp8_dtype = torch.float8_e4m3fn self.activation_scheme = quant_config.activation_scheme @@ -173,7 +184,8 @@ class EPMoE(torch.nn.Module): if self.grouped_gemm_runner is None: self.grouped_gemm_runner = GroupedGemmRunner( - hidden_states.device, use_flashinfer=False # TODO: use flashinfer + hidden_states.device, + use_flashinfer=False, # TODO: use flashinfer ) topk_weights, topk_ids = select_experts( @@ -195,9 +207,13 @@ class EPMoE(torch.nn.Module): gateup_input = torch.empty( (int(hidden_states.shape[0] * self.top_k), hidden_states.shape[1]), device=hidden_states.device, - dtype=self.fp8_dtype if self.use_fp8_w8a8 else hidden_states.dtype, + dtype=( + self.fp8_dtype + if (self.use_fp8_w8a8 and not self.use_block_quant) + else hidden_states.dtype + ), ) - if self.activation_scheme == "dynamic": + if self.activation_scheme == "dynamic" and not self.use_block_quant: max_value = ( torch.max(hidden_states) .repeat(self.num_experts_per_partition) @@ -243,7 +259,12 @@ class EPMoE(torch.nn.Module): weight_indices=weight_indices_cur_rank, use_fp8_w8a8=self.use_fp8_w8a8, scale_a=self.w13_input_scale, - scale_b=self.w13_weight_scale, + scale_b=( + self.w13_weight_scale_inv + if 
self.use_block_quant + else self.w13_weight_scale + ), + block_shape=self.block_shape, ) # Act @@ -251,9 +272,13 @@ class EPMoE(torch.nn.Module): gateup_output.shape[0], gateup_output.shape[1] // 2, device=gateup_output.device, - dtype=self.fp8_dtype if self.use_fp8_w8a8 else hidden_states.dtype, + dtype=( + self.fp8_dtype + if (self.use_fp8_w8a8 and not self.use_block_quant) + else hidden_states.dtype + ), ) - if self.w2_input_scale is None: + if self.w2_input_scale is None and not self.use_block_quant: self.w2_input_scale = torch.ones( self.num_experts_per_partition, dtype=torch.float32, @@ -291,7 +316,12 @@ class EPMoE(torch.nn.Module): weight_indices=weight_indices_cur_rank, use_fp8_w8a8=self.use_fp8_w8a8, scale_a=self.w2_input_scale, - scale_b=self.w2_weight_scale, + scale_b=( + self.w2_weight_scale_inv + if self.use_block_quant + else self.w2_weight_scale + ), + block_shape=self.block_shape, ) # PostReorder @@ -358,7 +388,11 @@ class EPMoE(torch.nn.Module): # Special case for fp8 scales. if "scale" in weight_name: self._load_fp8_scale( - param.data, loaded_weight, weight_name, shard_id, expert_id + param.data, + loaded_weight, + weight_name, + shard_id, + expert_id, ) return @@ -395,18 +429,33 @@ class EPMoE(torch.nn.Module): param_data[expert_id] = loaded_weight # Weight scales elif "weight_scale" in weight_name: + if self.use_block_quant: + block_n, block_k = self.block_shape[0], self.block_shape[1] + if shard_id == "w1": + param_data[expert_id][ + : (self.intermediate_size + block_n - 1) // block_n, : + ] = loaded_weight + elif shard_id == "w3": + param_data[expert_id][ + (self.intermediate_size + block_n - 1) // block_n :, : + ] = loaded_weight + else: # w2 + param_data[expert_id] = loaded_weight # If we are in merged column case (gate_up_proj) - if shard_id in ("w1", "w3"): - # We have to keep the weight scales of w1 and w3 because - # we need to re-quantize w1/w3 weights after weight loading. - idx = 0 if shard_id == "w1" else 1 - param_data[expert_id][idx] = loaded_weight - # If we are in the row parallel case (down_proj) else: - param_data[expert_id] = loaded_weight + if shard_id in ("w1", "w3"): + # We have to keep the weight scales of w1 and w3 because + # we need to re-quantize w1/w3 weights after weight loading. + idx = 0 if shard_id == "w1" else 1 + param_data[expert_id][idx] = loaded_weight + + # If we are in the row parallel case (down_proj) + else: + param_data[expert_id] = loaded_weight class UnquantizedEPMoEMethod(FusedMoEMethodBase, CustomOp): + def create_weights( self, layer: torch.nn.Module, @@ -498,6 +547,7 @@ class Fp8EPMoEMethod(Fp8MoEMethod): def __init__(self, quant_config: Fp8Config): self.quant_config = quant_config + self.block_quant = self.quant_config.weight_block_size is not None def create_weights( self, @@ -512,6 +562,29 @@ class Fp8EPMoEMethod(Fp8MoEMethod): if self.quant_config.is_checkpoint_fp8_serialized: params_dtype = torch.float8_e4m3fn + tp_size = get_tensor_model_parallel_world_size() + if self.block_quant: + block_n, block_k = ( + self.quant_config.weight_block_size[0], + self.quant_config.weight_block_size[1], + ) + # NOTE(HandH1998): To ensure proper alignment of the block-wise quantization scales, the output_size of the weights for both the gate and up layers must be divisible by block_n. 
+            # Required by column parallel or enabling merged weights
+            if intermediate_size % block_n != 0:
+                raise ValueError(
+                    f"The output_size of gate's and up's weight = "
+                    f"{intermediate_size} is not divisible by "
+                    f"weight quantization block_n = {block_n}."
+                )
+            if tp_size > 1:
+                # Required by row parallel
+                if intermediate_size % block_k != 0:
+                    raise ValueError(
+                        f"The input_size of down's weight = "
+                        f"{intermediate_size} is not divisible by "
+                        f"weight quantization block_k = {block_k}."
+                    )
+
         # WEIGHTS
         w13_weight = torch.nn.Parameter(
             torch.empty(
@@ -538,21 +611,49 @@ class Fp8EPMoEMethod(Fp8MoEMethod):
         set_weight_attrs(w2_weight, extra_weight_attrs)

         # WEIGHT_SCALES
-        # Allocate 2 scales for w1 and w3 respectively.
-        w13_weight_scale = torch.nn.Parameter(
-            torch.ones(num_experts_per_partition, 2, dtype=torch.float32),
-            requires_grad=False,
-        )
-        layer.register_parameter("w13_weight_scale", w13_weight_scale)
+        if self.block_quant:
+            w13_weight_scale = torch.nn.Parameter(
+                torch.ones(
+                    num_experts_per_partition,
+                    2 * ((intermediate_size + block_n - 1) // block_n),
+                    (hidden_size + block_k - 1) // block_k,
+                    dtype=torch.float32,
+                ),
+                requires_grad=False,
+            )
+            w2_weight_scale = torch.nn.Parameter(
+                torch.ones(
+                    num_experts_per_partition,
+                    (hidden_size + block_n - 1) // block_n,
+                    (intermediate_size + block_k - 1) // block_k,
+                    dtype=torch.float32,
+                ),
+                requires_grad=False,
+            )
+            layer.register_parameter("w13_weight_scale_inv", w13_weight_scale)
+            layer.register_parameter("w2_weight_scale_inv", w2_weight_scale)
+            assert self.quant_config.activation_scheme == "dynamic"
+        else:
+            # WEIGHT_SCALES
+            # Allocate 2 scales for w1 and w3 respectively.
+            w13_weight_scale = torch.nn.Parameter(
+                torch.ones(num_experts_per_partition, 2, dtype=torch.float32),
+                requires_grad=False,
+            )
+            layer.register_parameter("w13_weight_scale", w13_weight_scale)

-        w2_weight_scale = torch.nn.Parameter(
-            torch.ones(num_experts_per_partition, dtype=torch.float32),
-            requires_grad=False,
-        )
-        layer.register_parameter("w2_weight_scale", w2_weight_scale)
+            w2_weight_scale = torch.nn.Parameter(
+                torch.ones(num_experts_per_partition, dtype=torch.float32),
+                requires_grad=False,
+            )
+            layer.register_parameter("w2_weight_scale", w2_weight_scale)

         # Add the quantization method used (per tensor/grouped/channel)
         # to ensure the weight scales are loaded in properly
-        extra_weight_attrs.update({"quant_method": "tensor"})
+        extra_weight_attrs.update(
+            {"quant_method": FusedMoeWeightScaleSupported.BLOCK.value}
+            if self.block_quant
+            else {"quant_method": FusedMoeWeightScaleSupported.TENSOR.value}
+        )
         # If loading fp8 checkpoint, pass the weight loaders.
# If loading an fp16 checkpoint, do not (we will quantize in # process_weights_after_loading() diff --git a/python/sglang/test/test_block_fp8_ep.py b/python/sglang/test/test_block_fp8_ep.py new file mode 100644 index 000000000..c077d0c45 --- /dev/null +++ b/python/sglang/test/test_block_fp8_ep.py @@ -0,0 +1,361 @@ +import itertools +import random +import unittest +from typing import Any, Callable, Dict, List, Optional, Tuple + +import torch + +from sglang.srt.layers.moe.ep_moe.kernels import ( + grouped_gemm_triton, + post_reorder_triton_kernel, + pre_reorder_triton_kernel, + run_moe_ep_preproess, + silu_and_mul_triton_kernel, +) +from sglang.srt.layers.moe.topk import select_experts + + +# For test +def ep_moe( + hidden_states: torch.Tensor, + w1: torch.Tensor, + w2: torch.Tensor, + router_logits: torch.Tensor, + top_k: int, + renormalize: bool, + # ep config + num_experts: int = 256, + fp8_dtype: torch.types = torch.float8_e4m3fn, + num_experts_per_partition: int = 128, + start_expert_id: int = 0, + end_expert_id: int = 127, + use_grouped_topk: bool = False, + num_expert_group: Optional[int] = None, + topk_group: Optional[int] = None, + custom_routing_function: Optional[Callable] = None, + use_fp8_w8a8: bool = False, + w1_scale_inv: Optional[torch.Tensor] = None, + w2_scale_inv: Optional[torch.Tensor] = None, + block_shape: Optional[List[int]] = None, +): + use_blockwise_fp8 = block_shape is not None + topk_weights, topk_ids = select_experts( + hidden_states=hidden_states, + router_logits=router_logits, + top_k=top_k, + use_grouped_topk=use_grouped_topk, + renormalize=renormalize, + topk_group=topk_group, + num_expert_group=num_expert_group, + # correction_bias=correction_bias, #skip this in test + custom_routing_function=custom_routing_function, + ) + + reorder_topk_ids, src2dst, seg_indptr = run_moe_ep_preproess(topk_ids, num_experts) + + gateup_input = torch.empty( + (int(hidden_states.shape[0] * top_k), hidden_states.shape[1]), + device=hidden_states.device, + dtype=( + fp8_dtype + if (use_fp8_w8a8 and not use_blockwise_fp8) + else hidden_states.dtype + ), + ) + + if use_fp8_w8a8 and not use_blockwise_fp8: + max_value = ( + torch.max(hidden_states).repeat(num_experts_per_partition).to(torch.float32) + ) + w1_input_scale = max_value / torch.finfo(fp8_dtype).max + else: + w1_input_scale = None + + # PreReorder + pre_reorder_triton_kernel[(hidden_states.shape[0],)]( + hidden_states, + gateup_input, + src2dst, + topk_ids, + w1_input_scale, + start_expert_id, + end_expert_id, + top_k, + hidden_states.shape[1], + BLOCK_SIZE=512, + ) + + seg_indptr_cur_rank = seg_indptr[start_expert_id : end_expert_id + 2] + weight_indices_cur_rank = torch.arange( + 0, + num_experts_per_partition, + device=hidden_states.device, + dtype=torch.int64, + ) + + # GroupGemm-0 + gateup_output = torch.empty( + gateup_input.shape[0], + w1.shape[1], + device=hidden_states.device, + dtype=hidden_states.dtype, + ) + + gateup_output = grouped_gemm_triton( + a=gateup_input, + b=w1, + c=gateup_output, + batch_size=num_experts_per_partition, + weight_column_major=True, + seg_indptr=seg_indptr_cur_rank, + weight_indices=weight_indices_cur_rank, + use_fp8_w8a8=use_fp8_w8a8, + scale_a=w1_input_scale, + scale_b=w1_scale_inv, + block_shape=block_shape, + ) + + # Act + down_input = torch.empty( + gateup_output.shape[0], + gateup_output.shape[1] // 2, + device=gateup_output.device, + dtype=( + fp8_dtype + if (use_fp8_w8a8 and not use_blockwise_fp8) + else hidden_states.dtype + ), + ) + if use_fp8_w8a8 and not use_blockwise_fp8: + 
w2_input_scale = torch.ones( + num_experts_per_partition, + dtype=torch.float32, + device=hidden_states.device, + ) + else: + w2_input_scale = None + + silu_and_mul_triton_kernel[(gateup_output.shape[0],)]( + gateup_output, + down_input, + gateup_output.shape[1], + reorder_topk_ids, + w2_input_scale, + start_expert_id, + end_expert_id, + BLOCK_SIZE=512, + ) + + # GroupGemm-1 + down_output = torch.empty( + down_input.shape[0], + w2.shape[1], + device=hidden_states.device, + dtype=hidden_states.dtype, + ) + + down_output = grouped_gemm_triton( + a=down_input, + b=w2, + c=down_output, + batch_size=num_experts_per_partition, + weight_column_major=True, + seg_indptr=seg_indptr_cur_rank, + weight_indices=weight_indices_cur_rank, + use_fp8_w8a8=use_fp8_w8a8, + scale_a=w2_input_scale, + scale_b=w2_scale_inv, + block_shape=block_shape, + ) + + # PostReorder + output = torch.empty_like(hidden_states) + post_reorder_triton_kernel[(hidden_states.size(0),)]( + down_output, + output, + src2dst, + topk_ids, + topk_weights, + start_expert_id, + end_expert_id, + top_k, + hidden_states.size(1), + BLOCK_SIZE=512, + ) + return output + + +# test util +def block_dequant( + x_q_block: torch.Tensor, + x_s: torch.Tensor, + block_size: List[int], +) -> Tuple[torch.Tensor, torch.Tensor]: + """This function converts block-wise quantization to tensor-wise quantization. + The inputs are block-wise quantization tensor `x_q_block`, block-wise quantization scale + and the block size. + The outputs are tensor-wise quantization tensor and tensor-wise quantization scale. + Note only float8 is supported for now. + """ + + # process 3D tensor + if x_q_block.dim() == 3: + batch_size = x_q_block.size(0) + return torch.stack( + [block_dequant(x_q_block[b], x_s[b], block_size) for b in range(batch_size)] + ) + + block_n, block_k = block_size[0], block_size[1] + n, k = x_q_block.shape + n_tiles = (n + block_n - 1) // block_n + k_tiles = (k + block_k - 1) // block_k + assert n_tiles == x_s.shape[0] + assert k_tiles == x_s.shape[1] + + x_dq_block = x_q_block.to(torch.float32) + + x_dq_block_tiles = [ + [ + x_dq_block[ + j * block_n : min((j + 1) * block_n, n), + i * block_k : min((i + 1) * block_k, k), + ] + for i in range(k_tiles) + ] + for j in range(n_tiles) + ] + + for i in range(k_tiles): + for j in range(n_tiles): + x_dq_block_tiles[j][i][:, :] = x_dq_block_tiles[j][i] * x_s[j][i] + + return x_dq_block + + +class TestW8A8BlockFP8EPMoE(unittest.TestCase): + DTYPES = [torch.half, torch.bfloat16] + M = [1, 222, 1024, 2048] + N = [128, 1024, 2048] + K = [256, 4096, 5120] + E = [8, 16] + ep_size = [2, 4] + TOP_KS = [2, 4] + BLOCK_SIZE = [[128, 128]] + SEEDS = [0] + + @classmethod + def setUpClass(cls): + if not torch.cuda.is_available(): + raise unittest.SkipTest("CUDA is not available") + torch.set_default_device("cuda") + + def _w8a8_block_fp8_ep_moe( + self, M, N, K, E, ep_size, topk, block_size, dtype, seed + ): + torch.manual_seed(seed) + random.seed(seed) + # NOTE(HandH1998): to avoid overflow when out_dtype = torch.half + factor_for_scale = 1e-2 + fp8_info = torch.finfo(torch.float8_e4m3fn) + fp8_max, fp8_min = fp8_info.max, fp8_info.min + + a = torch.randn((M, K), dtype=dtype) / 10 + + w1_fp32 = (torch.rand((E, 2 * N, K), dtype=dtype) - 0.5) * 2 * fp8_max + w1 = w1_fp32.clamp(min=fp8_min, max=fp8_max).to(torch.float8_e4m3fn) + + w2_fp32 = (torch.rand((E, K, N), dtype=dtype) - 0.5) * 2 * fp8_max + w2 = w2_fp32.clamp(min=fp8_min, max=fp8_max).to(torch.float8_e4m3fn) + + block_n, block_k = block_size[0], block_size[1] + 
n_tiles_w1 = (2 * N + block_n - 1) // block_n + n_tiles_w2 = (K + block_n - 1) // block_n + k_tiles_w1 = (K + block_k - 1) // block_k + k_tiles_w2 = (N + block_k - 1) // block_k + + w1_s = ( + torch.rand((E, n_tiles_w1, k_tiles_w1), dtype=torch.float32) + * factor_for_scale + ) + w2_s = ( + torch.rand((E, n_tiles_w2, k_tiles_w2), dtype=torch.float32) + * factor_for_scale + ) + + w1_ref = block_dequant(w1, w1_s, block_size).to(dtype) + w2_ref = block_dequant(w2, w2_s, block_size).to(dtype) + + score = torch.randn((M, E), dtype=dtype) + num_experts_per_partition = E // ep_size + cur_rank = random.randint(0, ep_size - 1) + start_id = cur_rank * num_experts_per_partition + end_id = start_id + num_experts_per_partition - 1 + + with torch.inference_mode(): + out = ep_moe( + hidden_states=a, + w1=w1, + w2=w2, + router_logits=score, + top_k=topk, + renormalize=False, + use_fp8_w8a8=True, + w1_scale_inv=w1_s, + w2_scale_inv=w2_s, + block_shape=block_size, + num_experts=E, + num_experts_per_partition=num_experts_per_partition, + start_expert_id=start_id, + end_expert_id=end_id, + ) + ref_out = ep_moe( + hidden_states=a, + w1=w1_ref, + w2=w2_ref, + router_logits=score, + top_k=topk, + renormalize=False, + use_fp8_w8a8=False, + w1_scale_inv=None, + w2_scale_inv=None, + block_shape=None, + num_experts=E, + num_experts_per_partition=num_experts_per_partition, + start_expert_id=start_id, + end_expert_id=end_id, + ) + self.assertTrue( + torch.mean(torch.abs(out.to(torch.float32) - ref_out.to(torch.float32))) + / (torch.mean(torch.abs(ref_out.to(torch.float32))) + 1e-6) + < 0.06 + ) + + def test_w8a8_block_fp8_ep_moe(self): + for params in itertools.product( + self.M, + self.N, + self.K, + self.E, + self.ep_size, + self.TOP_KS, + self.BLOCK_SIZE, + self.DTYPES, + self.SEEDS, + ): + with self.subTest( + M=params[0], + N=params[1], + K=params[2], + E=params[3], + ep_size=params[4], + topk=params[5], + block_size=params[6], + dtype=params[7], + seed=params[8], + ): + self._w8a8_block_fp8_ep_moe(*params) + torch.cuda.empty_cache() + + +if __name__ == "__main__": + unittest.main(verbosity=2)
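
For reference, a minimal PyTorch sketch of what the new group_k/group_n path in
grouped_gemm_triton_kernel computes for a single expert. The helper name
(blockwise_fp8_gemm_ref) and the exact tensor shapes are illustrative assumptions,
not part of the patch; they mirror the strides passed from grouped_gemm_triton
(scale_a indexed per token row and k-group, scale_b indexed per expert, n-block,
and k-block).

    # Illustrative reference only; names and shapes are assumptions, not patch code.
    import torch

    def blockwise_fp8_gemm_ref(a_q, a_s, b_q, b_s, block_n, block_k):
        """Per-block scaled GEMM, as in the group_k/group_n branch of the kernel.

        a_q: (M, K) fp8 activations, a_s: (M, ceil(K / block_k)) per-token-group scales
        b_q: (N, K) fp8 weights,     b_s: (ceil(N / block_n), ceil(K / block_k)) block scales
        """
        M, K = a_q.shape
        N = b_q.shape[0]
        acc = torch.zeros((M, N), dtype=torch.float32, device=a_q.device)
        # Output column n reads its weight scale from row n // block_n,
        # matching offs_bsn = (n_range_start + offs_bn) // group_n in the kernel.
        col_scale_row = torch.arange(N, device=a_q.device) // block_n
        for kb in range((K + block_k - 1) // block_k):
            ks, ke = kb * block_k, min((kb + 1) * block_k, K)
            partial = a_q[:, ks:ke].to(torch.float32) @ b_q[:, ks:ke].to(torch.float32).T
            # accumulator += tl.dot(a_tile, b_tile.T) * a_scale * b_scale[None, :]
            acc += partial * a_s[:, kb : kb + 1] * b_s[col_scale_row, kb][None, :]
        return acc

The new unit test can be run with: python3 -m unittest sglang.test.test_block_fp8_ep
(with sglang installed); it skips itself when CUDA is not available.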