diff --git a/benchmark/deepseek_v3/README.md b/benchmark/deepseek_v3/README.md
index 8863142c5..cc9f4588c 100644
--- a/benchmark/deepseek_v3/README.md
+++ b/benchmark/deepseek_v3/README.md
@@ -178,10 +178,11 @@ python3 -m sglang.bench_one_batch_server --model None --base-url http://10.0.0.1
 
 ### Example: Serving with 8 A100/A800 with AWQ Quantization
 
-AWQ does not support BF16, so add the `--dtype half` flag if AWQ is used for quantization. One example is as follows:
+Add the `--quantization moe_wna16` flag to enable the MoE WNA16 kernel for better performance.
+One example is as follows:
 
 ```bash
-python3 -m sglang.launch_server --model cognitivecomputations/DeepSeek-R1-AWQ --tp 8 --trust-remote-code --dtype half
+python3 -m sglang.launch_server --model cognitivecomputations/DeepSeek-R1-AWQ --tp 8 --trust-remote-code --quantization moe_wna16
 ```
 
diff --git a/python/sglang/srt/configs/model_config.py b/python/sglang/srt/configs/model_config.py
index 8bc8d12b0..bb16e2747 100644
--- a/python/sglang/srt/configs/model_config.py
+++ b/python/sglang/srt/configs/model_config.py
@@ -258,6 +258,7 @@ class ModelConfig:
             "experts_int8",
             "w8a8_int8",
             "w8a8_fp8",
+            "moe_wna16",
         ]
         compatible_quantization_methods = {
             "w8a8_int8": ["compressed-tensors", "compressed_tensors"],
diff --git a/python/sglang/srt/layers/moe/fused_moe_triton/fused_moe.py b/python/sglang/srt/layers/moe/fused_moe_triton/fused_moe.py
index 946d194a1..7a5a6d3cd 100644
--- a/python/sglang/srt/layers/moe/fused_moe_triton/fused_moe.py
+++ b/python/sglang/srt/layers/moe/fused_moe_triton/fused_moe.py
@@ -52,6 +52,257 @@ if _is_cuda or _is_hip:
     from sgl_kernel import moe_align_block_size as sgl_moe_align_block_size
 
 
+@triton.jit
+def write_zeros_to_output(
+    c_ptr,
+    stride_cm,
+    stride_cn,
+    pid_n,
+    N,
+    offs_token,
+    token_mask,
+    BLOCK_SIZE_M,
+    BLOCK_SIZE_N,
+    compute_type,
+):
+    accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=compute_type)
+    offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
+    c_ptrs = c_ptr + stride_cm * offs_token[:, None] + stride_cn * offs_cn[None, :]
+    c_mask = token_mask[:, None] & (offs_cn[None, :] < N)
+    tl.store(c_ptrs, accumulator, mask=c_mask)
+
+
+@triton.jit
+def fused_moe_kernel_gptq_awq(
+    # Pointers to matrices
+    a_ptr,
+    b_ptr,
+    c_ptr,
+    b_scale_ptr,
+    b_zp_ptr,
+    topk_weights_ptr,
+    sorted_token_ids_ptr,
+    expert_ids_ptr,
+    num_tokens_post_padded_ptr,
+    # Matrix dimensions
+    N: tl.constexpr,
+    K: tl.constexpr,
+    EM,
+    num_valid_tokens,
+    # The stride variables represent how much to increase the ptr by when
+    # moving by 1 element in a particular dimension. E.g. `stride_am` is
+    # how much to increase `a_ptr` by to get the element one row down
+    # (A has M rows).
+    stride_am,
+    stride_ak,
+    stride_be,
+    stride_bk,
+    stride_bn,
+    stride_cm,
+    stride_cn,
+    stride_bse,
+    stride_bsk,
+    stride_bsn,
+    stride_bze,
+    stride_bzk,
+    stride_bzn,
+    group_size: tl.constexpr,
+    # Meta-parameters
+    BLOCK_SIZE_M: tl.constexpr,
+    BLOCK_SIZE_N: tl.constexpr,
+    BLOCK_SIZE_K: tl.constexpr,
+    GROUP_SIZE_M: tl.constexpr,
+    MUL_ROUTED_WEIGHT: tl.constexpr,
+    top_k: tl.constexpr,
+    compute_type: tl.constexpr,
+    has_zp: tl.constexpr,
+    use_int4_w4a16: tl.constexpr,
+    use_int8_w8a16: tl.constexpr,
+    even_Ks: tl.constexpr,
+):
+    """
+    Implements the fused computation for a Mixture of Experts (MOE) using
+    token and expert matrices.
+    Key Parameters:
+    - A: The input tensor representing tokens with shape (*, K), where '*' can
+        be any shape representing batches and K is the feature dimension of
+        each token.
+ - B: The stacked MOE weight tensor with shape (E, N, K), where E is + the number of experts, K is the input feature dimension, and N is + the output feature dimension. + - C: The output cache tensor with shape (M, topk, N), where M is the + total number of tokens post padding, topk is the number of times + each token is repeated, and N is the output feature dimension. + - sorted_token_ids: A tensor containing the sorted indices of tokens, + repeated topk times and arranged by the expert index they are + assigned to. + - expert_ids: A tensor containing the indices of the expert for each + block. It determines which expert matrix from B should be used for + each block in A. + This kernel performs the multiplication of a token by its corresponding + expert matrix as determined by `expert_ids`. The sorting of + `sorted_token_ids` by expert index and padding ensures divisibility by + BLOCK_SIZE_M, which is necessary to maintain consistency in block matrix + multiplication across different blocks processed by the same expert. + """ + # ----------------------------------------------------------- + # Map program ids `pid` to the block of C it should compute. + # This is done in a grouped ordering to promote L2 data reuse. + pid = tl.program_id(axis=0) + num_pid_m = tl.cdiv(EM, BLOCK_SIZE_M) + num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) + num_pid_in_group = GROUP_SIZE_M * num_pid_n + group_id = pid // num_pid_in_group + first_pid_m = group_id * GROUP_SIZE_M + group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) + pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m) + pid_n = (pid % num_pid_in_group) // group_size_m + + # ---------------------------------------------------------- + # Create pointers for the first blocks of A and B. + # We will advance this pointer as we move in the K direction + # and accumulate + # `a_ptrs` is a block of [BLOCK_SIZE_M, BLOCK_SIZE_K] pointers + # `b_ptrs` is a block of [BLOCK_SIZE_K, BLOCK_SIZE_N] pointers + num_tokens_post_padded = tl.load(num_tokens_post_padded_ptr) + if pid_m * BLOCK_SIZE_M >= num_tokens_post_padded: + return + offs_token_id = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M).to(tl.int64) + offs_token = tl.load(sorted_token_ids_ptr + offs_token_id) + token_mask = offs_token < num_valid_tokens + + off_experts = tl.load(expert_ids_ptr + pid_m).to(tl.int64) + if off_experts == -1: + # ----------------------------------------------------------- + # Write back zeros to the output when the expert is not + # in the current expert parallel rank. + write_zeros_to_output( + c_ptr, + stride_cm, + stride_cn, + pid_n, + N, + offs_token, + token_mask, + BLOCK_SIZE_M, + BLOCK_SIZE_N, + compute_type, + ) + return + + offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N).to(tl.int64)) % N + offs_k = tl.arange(0, BLOCK_SIZE_K) + a_ptrs = a_ptr + ( + offs_token[:, None] // top_k * stride_am + offs_k[None, :] * stride_ak + ) + + if use_int4_w4a16: + b_ptrs = ( + b_ptr + + off_experts * stride_be + + (offs_k[:, None] // 2) * stride_bk + + offs_bn[None, :] * stride_bn + ) + b_shifter = (offs_k[:, None] % 2) * 4 + elif use_int8_w8a16: + b_ptrs = ( + b_ptr + + off_experts * stride_be + + offs_k[:, None] * stride_bk + + offs_bn[None, :] * stride_bn + ) + + if not has_zp and use_int4_w4a16: + b_zp_num = 8 + if not has_zp and use_int8_w8a16: + b_zp_num = 128 + elif has_zp and use_int4_w4a16: + b_zp_shifter = (offs_bn[None, :] % 2) * 4 + + # ----------------------------------------------------------- + # Iterate to compute a block of the C matrix. 
+ # We accumulate into a `[BLOCK_SIZE_M, BLOCK_SIZE_N]` block + # of fp32 values for higher accuracy. + # `accumulator` will be converted back to fp16 after the loop. + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)): + # Load the next block of A and B, generate a mask by checking the + # K dimension. + + if not even_Ks: + k_mask = offs_k[:, None] < K - k * BLOCK_SIZE_K + k_other = 0.0 + else: + k_mask = None + k_other = None + + a = tl.load( + a_ptrs, + mask=token_mask[:, None] & (offs_k[None, :] < K - k * BLOCK_SIZE_K), + other=0.0, + ) + b = tl.load(b_ptrs) + if use_int4_w4a16: + b = (b >> b_shifter) & 0xF + + b_scale_ptrs = ( + b_scale_ptr + + off_experts * stride_bse + + offs_bn[None, :] * stride_bsn + + ((offs_k[:, None] + BLOCK_SIZE_K * k) // group_size) * stride_bsk + ) + b_scale = tl.load(b_scale_ptrs, mask=k_mask, other=k_other) + b_scale = b_scale.to(tl.float32) + + if has_zp and use_int4_w4a16: + offs_k_true = (offs_k[:, None] + BLOCK_SIZE_K * k) // group_size + b_zp_ptrs = ( + b_zp_ptr + + off_experts * stride_bze + + (offs_bn[None, :] // 2) * stride_bzn + + offs_k_true * stride_bzk + ) + b_zp = tl.load(b_zp_ptrs, mask=k_mask, other=k_other) + b_zp = (b_zp >> b_zp_shifter) & 0xF + b_zp = b_zp.to(tl.float32) + elif has_zp and use_int8_w8a16: + offs_k_true = (offs_k[:, None] + BLOCK_SIZE_K * k) // group_size + b_zp_ptrs = ( + b_zp_ptr + + off_experts * stride_bze + + offs_bn[None, :] * stride_bzn + + offs_k_true * stride_bzk + ) + b_zp = tl.load(b_zp_ptrs, mask=k_mask, other=k_other) + b_zp = b_zp.to(tl.float32) + + # We accumulate along the K dimension. + if has_zp: + b = ((b.to(tl.float32) - b_zp) * b_scale).to(compute_type) + else: + b = ((b.to(tl.float32) - b_zp_num) * b_scale).to(compute_type) + accumulator = tl.dot(a, b, acc=accumulator) + + # Advance the ptrs to the next K block. 
+ a_ptrs += BLOCK_SIZE_K * stride_ak + if use_int4_w4a16: + b_ptrs += (BLOCK_SIZE_K // 2) * stride_bk + else: + b_ptrs += BLOCK_SIZE_K * stride_bk + + if MUL_ROUTED_WEIGHT: + moe_weight = tl.load(topk_weights_ptr + offs_token, mask=token_mask, other=0) + accumulator = accumulator * moe_weight[:, None] + + accumulator = accumulator.to(compute_type) + # ----------------------------------------------------------- + # Write back the block of the output + offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + c_ptrs = c_ptr + stride_cm * offs_token[:, None] + stride_cn * offs_cn[None, :] + c_mask = token_mask[:, None] & (offs_cn[None, :] < N) + tl.store(c_ptrs, accumulator, mask=c_mask) + + @triton.jit def fused_moe_kernel( # Pointers to matrices @@ -496,6 +747,7 @@ def invoke_fused_moe_kernel( C: torch.Tensor, A_scale: Optional[torch.Tensor], B_scale: Optional[torch.Tensor], + B_zp: Optional[torch.Tensor], topk_weights: torch.Tensor, topk_ids: torch.Tensor, sorted_token_ids: torch.Tensor, @@ -508,6 +760,7 @@ def invoke_fused_moe_kernel( use_fp8_w8a8: bool, use_int8_w8a8: bool, use_int8_w8a16: bool, + use_int4_w4a16: bool, block_shape: Optional[List[int]] = None, no_combine: bool = False, ) -> None: @@ -548,8 +801,9 @@ def invoke_fused_moe_kernel( assert triton.cdiv(A.shape[-1], block_k) == A_scale.shape[-1] assert triton.cdiv(B.shape[-2], block_n) == B_scale.shape[-2] assert triton.cdiv(B.shape[-1], block_k) == B_scale.shape[-1] - elif use_int8_w8a16: + elif use_int8_w8a16 or use_int4_w4a16: assert B_scale is not None + assert block_shape is None or block_shape[0] == 0 else: assert A_scale is None assert B_scale is None @@ -565,43 +819,90 @@ def invoke_fused_moe_kernel( else: even_Ks = False - fused_moe_kernel[grid]( - A, - B, - C, - A_scale, - B_scale, - topk_weights, - sorted_token_ids, - expert_ids, - num_tokens_post_padded, - B.shape[1], - B.shape[2] - padded_size, - sorted_token_ids.shape[0], - topk_ids.numel(), - A.stride(0), - A.stride(1), - B.stride(0), - B.stride(2), - B.stride(1), - C.stride(1), - C.stride(2), - A_scale.stride(0) if A_scale is not None and A_scale.ndim == 2 else 0, - A_scale.stride(1) if A_scale is not None and A_scale.ndim == 2 else 0, - B_scale.stride(0) if B_scale is not None and B_scale.ndim >= 2 else 0, - B_scale.stride(2) if B_scale is not None and B_scale.ndim == 3 else 0, - B_scale.stride(1) if B_scale is not None and B_scale.ndim >= 2 else 0, - 0 if block_shape is None else block_shape[0], - 0 if block_shape is None else block_shape[1], - MUL_ROUTED_WEIGHT=mul_routed_weight, - top_k=top_k, - compute_type=compute_type, - use_fp8_w8a8=use_fp8_w8a8, - use_int8_w8a8=use_int8_w8a8, - use_int8_w8a16=use_int8_w8a16, - even_Ks=even_Ks, - **config, - ) + if ( + (use_int8_w8a16 or use_int4_w4a16) + and block_shape is not None + and block_shape[1] > 0 + ): + assert B_scale is not None and B_scale.ndim == 3 + assert B_zp is None or B_zp.ndim == 3 + fused_moe_kernel_gptq_awq[grid]( + A, + B, + C, + B_scale, + B_zp, + topk_weights, + sorted_token_ids, + expert_ids, + num_tokens_post_padded, + B.shape[1], + A.shape[1], + sorted_token_ids.shape[0], + topk_ids.numel(), + A.stride(0), + A.stride(1), + B.stride(0), + B.stride(2), + B.stride(1), + C.stride(1), + C.stride(2), + B_scale.stride(0), + B_scale.stride(2), + B_scale.stride(1), + B_zp.stride(0) if B_zp is not None else 0, + B_zp.stride(2) if B_zp is not None else 0, + B_zp.stride(1) if B_zp is not None else 0, + group_size=block_shape[1], + MUL_ROUTED_WEIGHT=mul_routed_weight, + top_k=top_k, + 
compute_type=compute_type, + has_zp=B_zp is not None, + use_int4_w4a16=use_int4_w4a16, + use_int8_w8a16=use_int8_w8a16, + even_Ks=even_Ks, + **config, + ) + + else: + + fused_moe_kernel[grid]( + A, + B, + C, + A_scale, + B_scale, + topk_weights, + sorted_token_ids, + expert_ids, + num_tokens_post_padded, + B.shape[1], + B.shape[2] - padded_size, + sorted_token_ids.shape[0], + topk_ids.numel(), + A.stride(0), + A.stride(1), + B.stride(0), + B.stride(2), + B.stride(1), + C.stride(1), + C.stride(2), + A_scale.stride(0) if A_scale is not None and A_scale.ndim == 2 else 0, + A_scale.stride(1) if A_scale is not None and A_scale.ndim == 2 else 0, + B_scale.stride(0) if B_scale is not None and B_scale.ndim >= 2 else 0, + B_scale.stride(2) if B_scale is not None and B_scale.ndim == 3 else 0, + B_scale.stride(1) if B_scale is not None and B_scale.ndim >= 2 else 0, + 0 if block_shape is None else block_shape[0], + 0 if block_shape is None else block_shape[1], + MUL_ROUTED_WEIGHT=mul_routed_weight, + top_k=top_k, + compute_type=compute_type, + use_fp8_w8a8=use_fp8_w8a8, + use_int8_w8a8=use_int8_w8a8, + use_int8_w8a16=use_int8_w8a16, + even_Ks=even_Ks, + **config, + ) def get_config_file_name( @@ -750,6 +1051,7 @@ def try_get_optimal_moe_config( def get_config_dtype_str( dtype: torch.dtype, use_int8_w8a16: Optional[bool] = False, + use_int4_w4a16: Optional[bool] = False, use_fp8_w8a8: Optional[bool] = False, use_int8_w8a8: Optional[bool] = False, ): @@ -757,6 +1059,8 @@ def get_config_dtype_str( return "fp8_w8a8" elif use_int8_w8a8: return "int8_w8a8" + elif use_int4_w4a16: + return "int4_w4a16" elif use_int8_w8a16: return "int8_w8a16" elif dtype == torch.float: @@ -776,8 +1080,11 @@ def inplace_fused_experts( use_fp8_w8a8: bool = False, use_int8_w8a8: bool = False, use_int8_w8a16: bool = False, + use_int4_w4a16: bool = False, w1_scale: Optional[torch.Tensor] = None, w2_scale: Optional[torch.Tensor] = None, + w1_zp: Optional[torch.Tensor] = None, + w2_zp: Optional[torch.Tensor] = None, a1_scale: Optional[torch.Tensor] = None, a2_scale: Optional[torch.Tensor] = None, block_shape: Optional[List[int]] = None, @@ -793,8 +1100,11 @@ def inplace_fused_experts( use_fp8_w8a8, use_int8_w8a8, use_int8_w8a16, + use_int4_w4a16, w1_scale, w2_scale, + w1_zp, + w2_zp, a1_scale, a2_scale, block_shape, @@ -811,8 +1121,11 @@ def inplace_fused_experts_fake( use_fp8_w8a8: bool = False, use_int8_w8a8: bool = False, use_int8_w8a16: bool = False, + use_int4_w4a16: bool = False, w1_scale: Optional[torch.Tensor] = None, w2_scale: Optional[torch.Tensor] = None, + w1_zp: Optional[torch.Tensor] = None, + w2_zp: Optional[torch.Tensor] = None, a1_scale: Optional[torch.Tensor] = None, a2_scale: Optional[torch.Tensor] = None, block_shape: Optional[List[int]] = None, @@ -838,8 +1151,11 @@ def outplace_fused_experts( use_fp8_w8a8: bool = False, use_int8_w8a8: bool = False, use_int8_w8a16: bool = False, + use_int4_w4a16: bool = False, w1_scale: Optional[torch.Tensor] = None, w2_scale: Optional[torch.Tensor] = None, + w1_zp: Optional[torch.Tensor] = None, + w2_zp: Optional[torch.Tensor] = None, a1_scale: Optional[torch.Tensor] = None, a2_scale: Optional[torch.Tensor] = None, block_shape: Optional[List[int]] = None, @@ -856,8 +1172,11 @@ def outplace_fused_experts( use_fp8_w8a8, use_int8_w8a8, use_int8_w8a16, + use_int4_w4a16, w1_scale, w2_scale, + w1_zp, + w2_zp, a1_scale, a2_scale, block_shape, @@ -875,8 +1194,11 @@ def outplace_fused_experts_fake( use_fp8_w8a8: bool = False, use_int8_w8a8: bool = False, use_int8_w8a16: bool = False, + 
use_int4_w4a16: bool = False, w1_scale: Optional[torch.Tensor] = None, w2_scale: Optional[torch.Tensor] = None, + w1_zp: Optional[torch.Tensor] = None, + w2_zp: Optional[torch.Tensor] = None, a1_scale: Optional[torch.Tensor] = None, a2_scale: Optional[torch.Tensor] = None, block_shape: Optional[List[int]] = None, @@ -904,8 +1226,11 @@ def fused_experts( use_fp8_w8a8: bool = False, use_int8_w8a8: bool = False, use_int8_w8a16: bool = False, + use_int4_w4a16: bool = False, w1_scale: Optional[torch.Tensor] = None, w2_scale: Optional[torch.Tensor] = None, + w1_zp: Optional[torch.Tensor] = None, + w2_zp: Optional[torch.Tensor] = None, a1_scale: Optional[torch.Tensor] = None, a2_scale: Optional[torch.Tensor] = None, block_shape: Optional[List[int]] = None, @@ -923,8 +1248,11 @@ def fused_experts( use_fp8_w8a8, use_int8_w8a8, use_int8_w8a16, + use_int4_w4a16, w1_scale, w2_scale, + w1_zp, + w2_zp, a1_scale, a2_scale, block_shape, @@ -941,8 +1269,11 @@ def fused_experts( use_fp8_w8a8, use_int8_w8a8, use_int8_w8a16, + use_int4_w4a16, w1_scale, w2_scale, + w1_zp, + w2_zp, a1_scale, a2_scale, block_shape, @@ -961,8 +1292,11 @@ def fused_experts_impl( use_fp8_w8a8: bool = False, use_int8_w8a8: bool = False, use_int8_w8a16: bool = False, + use_int4_w4a16: bool = False, w1_scale: Optional[torch.Tensor] = None, w2_scale: Optional[torch.Tensor] = None, + w1_zp: Optional[torch.Tensor] = None, + w2_zp: Optional[torch.Tensor] = None, a1_scale: Optional[torch.Tensor] = None, a2_scale: Optional[torch.Tensor] = None, block_shape: Optional[List[int]] = None, @@ -977,7 +1311,12 @@ def fused_experts_impl( padded_size = 0 # Check constraints. - assert hidden_states.shape[1] == w1.shape[2] - padded_size, "Hidden size mismatch" + if use_int4_w4a16: + assert hidden_states.shape[1] // 2 == w1.shape[2], "Hidden size mismatch" + else: + assert ( + hidden_states.shape[1] == w1.shape[2] - padded_size + ), "Hidden size mismatch" assert topk_weights.shape == topk_ids.shape, "topk shape mismatch" assert hidden_states.is_contiguous(), "Hidden_states must be contiguous" assert w1.is_contiguous(), "Expert weights1 must be contiguous" @@ -994,6 +1333,7 @@ def fused_experts_impl( use_fp8_w8a8=use_fp8_w8a8, use_int8_w8a8=use_int8_w8a8, use_int8_w8a16=use_int8_w8a16, + use_int4_w4a16=use_int4_w4a16, dtype=hidden_states.dtype, ) @@ -1075,6 +1415,7 @@ def fused_experts_impl( intermediate_cache1, a1_scale, w1_scale, + w1_zp, curr_topk_weights, curr_topk_ids, sorted_token_ids, @@ -1087,6 +1428,7 @@ def fused_experts_impl( use_fp8_w8a8=use_fp8_w8a8, use_int8_w8a8=use_int8_w8a8, use_int8_w8a16=use_int8_w8a16, + use_int4_w4a16=use_int4_w4a16, block_shape=block_shape, ) if activation == "silu": @@ -1116,6 +1458,7 @@ def fused_experts_impl( ), a2_scale, w2_scale, + w2_zp, curr_topk_weights, curr_topk_ids, sorted_token_ids, @@ -1128,6 +1471,7 @@ def fused_experts_impl( use_fp8_w8a8=use_fp8_w8a8, use_int8_w8a8=use_int8_w8a8, use_int8_w8a16=use_int8_w8a16, + use_int4_w4a16=use_int4_w4a16, block_shape=block_shape, ) @@ -1173,8 +1517,11 @@ def fused_moe( use_fp8_w8a8: bool = False, use_int8_w8a8: bool = False, use_int8_w8a16: bool = False, + use_int4_w4a16: bool = False, w1_scale: Optional[torch.Tensor] = None, w2_scale: Optional[torch.Tensor] = None, + w1_zp: Optional[torch.Tensor] = None, + w2_zp: Optional[torch.Tensor] = None, a1_scale: Optional[torch.Tensor] = None, a2_scale: Optional[torch.Tensor] = None, block_shape: Optional[List[int]] = None, @@ -1204,6 +1551,9 @@ def fused_moe( products for w1 and w2. Defaults to False. 
- use_int8_w8a16 (bool): If True, use fp8 arithmetic to compute the inner products for w1 and w2. Defaults to False. + - use_int4_w4a16 (bool): If True, use matmul of int4 weight and bf16/fp16 + activation to compute the inner products for w1 and w2. + Defaults to False. - w1_scale (Optional[torch.Tensor]): Optional scale to be used for w1. - w2_scale (Optional[torch.Tensor]): Optional scale to be used for @@ -1243,8 +1593,11 @@ def fused_moe( use_fp8_w8a8=use_fp8_w8a8, use_int8_w8a8=use_int8_w8a8, use_int8_w8a16=use_int8_w8a16, + use_int4_w4a16=use_int4_w4a16, w1_scale=w1_scale, w2_scale=w2_scale, + w1_zp=w1_zp, + w2_zp=w2_zp, a1_scale=a1_scale, a2_scale=a2_scale, block_shape=block_shape, diff --git a/python/sglang/srt/layers/quantization/__init__.py b/python/sglang/srt/layers/quantization/__init__.py index 9667ddf4d..bee35c9c7 100644 --- a/python/sglang/srt/layers/quantization/__init__.py +++ b/python/sglang/srt/layers/quantization/__init__.py @@ -61,6 +61,7 @@ from sglang.srt.layers.quantization.compressed_tensors.compressed_tensors import from sglang.srt.layers.quantization.fp8 import Fp8Config from sglang.srt.layers.quantization.gptq import GPTQConfig, GPTQMarlinConfig from sglang.srt.layers.quantization.modelopt_quant import ModelOptFp8Config +from sglang.srt.layers.quantization.moe_wna16 import MoeWNA16Config from sglang.srt.layers.quantization.w8a8_fp8 import W8A8Fp8Config from sglang.srt.layers.quantization.w8a8_int8 import W8A8Int8Config from sglang.srt.layers.vocab_parallel_embedding import ( @@ -75,6 +76,7 @@ BASE_QUANTIZATION_METHODS: Dict[str, Type[QuantizationConfig]] = { "modelopt": ModelOptFp8Config, "w8a8_int8": W8A8Int8Config, "w8a8_fp8": W8A8Fp8Config, + "moe_wna16": MoeWNA16Config, "compressed-tensors": CompressedTensorsConfig, } diff --git a/python/sglang/srt/layers/quantization/moe_wna16.py b/python/sglang/srt/layers/quantization/moe_wna16.py new file mode 100644 index 000000000..b99016d95 --- /dev/null +++ b/python/sglang/srt/layers/quantization/moe_wna16.py @@ -0,0 +1,501 @@ +# Adapted from https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/layers/quantization/moe_wna16.py + +import logging +from typing import Any, Callable, Dict, List, Optional + +import torch + +from sglang.srt.distributed import get_tensor_model_parallel_rank +from sglang.srt.distributed.parallel_state import get_tp_group +from sglang.srt.layers.linear import LinearBase, UnquantizedLinearMethod +from sglang.srt.layers.quantization.awq import AWQConfig +from sglang.srt.layers.quantization.base_config import ( + QuantizationConfig, + QuantizeMethodBase, +) +from sglang.srt.layers.quantization.gptq import GPTQConfig, GPTQMarlinConfig +from sglang.srt.utils import get_device_capability, set_weight_attrs + +logger = logging.getLogger(__name__) + + +class MoeWNA16Config(QuantizationConfig): + """Config class for MOE WNA16 (W8A16/W4A16) quantization.""" + + def __init__( + self, + linear_quant_method: str, + weight_bits: int, + group_size: int, + has_zp: bool, + lm_head_quantized: bool, + modules_to_not_convert: Optional[List[str]], + full_config: Dict[str, Any], + ) -> None: + super().__init__() + self.weight_bits = weight_bits + self.group_size = group_size + self.has_zp = has_zp + self.bit8_pack_factor = 8 // self.weight_bits + self.lm_head_quantized = lm_head_quantized + self.linear_quant_method = linear_quant_method + self.full_config = full_config + self.use_marlin = False + # Avoid circular import + + if self.linear_quant_method == "gptq": + self.use_marlin = 
GPTQMarlinConfig.is_gptq_marlin_compatible(full_config) + elif self.linear_quant_method == "awq": + capability_tuple = get_device_capability() + device_capability = ( + -1 + if capability_tuple is None + else capability_tuple[0] * 10 + capability_tuple[1] + ) + awq_min_capability = AWQConfig.get_min_capability() + if device_capability < awq_min_capability: + raise ValueError( + "The quantization method moe_wna16 + awq is not supported " + "for the current GPU. " + f"Minimum capability: {awq_min_capability}. " + f"Current capability: {device_capability}." + ) + else: + raise ValueError("moe_wna16 only support gptq and awq.") + + if modules_to_not_convert is None: + self.modules_to_not_convert = [] + else: + self.modules_to_not_convert = modules_to_not_convert + + @classmethod + def get_name(cls) -> str: + return "moe_wna16" + + @classmethod + def get_supported_act_dtypes(cls) -> List[torch.dtype]: + return [torch.bfloat16, torch.half] + + @classmethod + def get_min_capability(cls) -> int: + return 70 + + @classmethod + def get_config_filenames(cls) -> List[str]: + return ["quantize_config.json"] + + def get_scaled_act_names(self) -> List[str]: + raise NotImplementedError + + @classmethod + def from_config(cls, config: Dict[str, Any]) -> "MoeWNA16Config": + quant_method = cls.get_from_keys(config, ["quant_method"]) + weight_bits = cls.get_from_keys(config, ["bits"]) + group_size = cls.get_from_keys(config, ["group_size"]) + lm_head_quantized = cls.get_from_keys_or(config, ["lm_head"], default=False) + if quant_method == "gptq": + has_zp = not cls.get_from_keys(config, ["sym"]) + modules_to_not_convert = [] + elif quant_method == "awq": + has_zp = cls.get_from_keys(config, ["zero_point"]) + modules_to_not_convert = cls.get_from_keys_or( + config, ["modules_to_not_convert"], None + ) + else: + raise ValueError("moe_wna16 only support gptq and awq.") + + return cls( + quant_method, + weight_bits, + group_size, + has_zp, + lm_head_quantized, + modules_to_not_convert, + config, + ) + + @classmethod + def override_quantization_method(cls, hf_quant_cfg, user_quant) -> Optional[str]: + can_convert = cls.is_moe_wna16_compatible(hf_quant_cfg) + if can_convert and user_quant == "moe_wna16": + return cls.get_name() + return None + + @classmethod + def is_moe_wna16_compatible(cls, quant_config: Dict[str, Any]): + # Extract data from quant config. 
+ quant_method = quant_config.get("quant_method", "").lower() + num_bits = quant_config.get("bits") + desc_act = quant_config.get("desc_act") + + capability_tuple = get_device_capability() + device_capability = ( + -1 + if capability_tuple is None + else capability_tuple[0] * 10 + capability_tuple[1] + ) + # Avoid circular import + awq_min_capability = AWQConfig.get_min_capability() + + gptq_compatible = quant_method == "gptq" and not desc_act and num_bits in [4, 8] + awq_compatible = ( + quant_method == "awq" + and num_bits == 4 + and device_capability >= awq_min_capability + ) + + return gptq_compatible or awq_compatible + + def get_quant_method( + self, layer: torch.nn.Module, prefix: str + ) -> Optional["QuantizeMethodBase"]: + # avoid circular import + from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE + + if is_layer_skipped_quant(prefix, self.modules_to_not_convert): + return UnquantizedLinearMethod() + elif isinstance(layer, LinearBase): + + if self.linear_quant_method == "gptq": + if self.use_marlin: + return GPTQMarlinConfig.from_config( + self.full_config + ).get_quant_method(layer, prefix) + else: + return GPTQConfig.from_config(self.full_config).get_quant_method( + layer, prefix + ) + elif self.linear_quant_method == "awq": + return AWQConfig.from_config(self.full_config).get_quant_method( + layer, prefix + ) + else: + raise ValueError("moe_wna16 only support gptq and awq.") + elif isinstance(layer, FusedMoE): + return MoeWNA16Method(self) + return None + + +def is_layer_skipped_quant(prefix: str, modules_to_not_convert: List[str]): + return any(module_name in prefix for module_name in modules_to_not_convert) + + +class MoeWNA16Method: + """Linear method for MOE WNA16 (W8A16/W4A16) quantization. + + Args: + quant_config: The MOE WNA16 (W8A16/W4A16) quantization config. 
+ """ + + def __new__(cls, *args, **kwargs): + # avoid circular import + from sglang.srt.layers.moe.fused_moe_triton import FusedMoEMethodBase + + if not hasattr(cls, "_initialized"): + original_init = cls.__init__ + new_cls = type( + cls.__name__, + (FusedMoEMethodBase,), + { + "__init__": original_init, + **{k: v for k, v in cls.__dict__.items() if k != "__dict__"}, + }, + ) + obj = super(new_cls, new_cls).__new__(new_cls) + obj.__init__(*args, **kwargs) + return obj + return super().__new__(cls) + + def __init__(self, quant_config: MoeWNA16Config): + self.quant_config = quant_config + + def create_weights( + self, + layer: torch.nn.Module, + num_experts: int, + hidden_size: int, + intermediate_size_per_partition: int, + params_dtype: torch.dtype, + **extra_weight_attrs, + ): + from sglang.srt.layers.moe.fused_moe_triton import FusedMoeWeightScaleSupported + + layer.quant_config = self.quant_config + bit8_pack_factor = self.quant_config.bit8_pack_factor + group_size = self.quant_config.group_size + group_size_div_factor = 1 + + # make intermediate_size and hidden_size diviable by group_size + # we reduce the group size to ensure that + # and we would repeat the loaded_weight later + while intermediate_size_per_partition % group_size or hidden_size % group_size: + group_size = group_size // 2 + group_size_div_factor *= 2 + assert group_size >= 32 + layer.group_size = group_size + layer.group_size_div_factor = group_size_div_factor + + strategy = FusedMoeWeightScaleSupported.GROUP.value + extra_weight_attrs.update({"quant_method": strategy, "is_transposed": False}) + + assert "weight_loader" in extra_weight_attrs + weight_loader = extra_weight_attrs["weight_loader"] + wrapped_weight_loader = MoeWNA16Method.get_weight_loader(layer, weight_loader) + extra_weight_attrs["weight_loader"] = wrapped_weight_loader + + # Fused gate_up_proj (column parallel) + w13_qweight = torch.nn.Parameter( + torch.empty( + num_experts, + 2 * intermediate_size_per_partition, + hidden_size // bit8_pack_factor, + dtype=torch.uint8, + ), + requires_grad=False, + ) + layer.register_parameter("w13_qweight", w13_qweight) + set_weight_attrs(w13_qweight, extra_weight_attrs) + + # down_proj (row parallel) + w2_qweight = torch.nn.Parameter( + torch.empty( + num_experts, + hidden_size, + intermediate_size_per_partition // bit8_pack_factor, + dtype=torch.uint8, + ), + requires_grad=False, + ) + layer.register_parameter("w2_qweight", w2_qweight) + set_weight_attrs(w2_qweight, extra_weight_attrs) + + w13_scales = torch.nn.Parameter( + torch.zeros( + num_experts, + 2 * intermediate_size_per_partition, + hidden_size // group_size, + dtype=params_dtype, + ), + requires_grad=False, + ) + layer.register_parameter("w13_scales", w13_scales) + set_weight_attrs(w13_scales, extra_weight_attrs) + + w2_scales = torch.nn.Parameter( + torch.zeros( + num_experts, + hidden_size, + intermediate_size_per_partition // group_size, + dtype=params_dtype, + ), + requires_grad=False, + ) + layer.register_parameter("w2_scales", w2_scales) + set_weight_attrs(w2_scales, extra_weight_attrs) + + if self.quant_config.has_zp: + w13_qzeros = torch.nn.Parameter( + torch.zeros( + num_experts, + 2 * intermediate_size_per_partition // bit8_pack_factor, + hidden_size // group_size, + dtype=torch.uint8, + ), + requires_grad=False, + ) + layer.register_parameter("w13_qzeros", w13_qzeros) + set_weight_attrs(w13_qzeros, extra_weight_attrs) + + w2_qzeros = torch.nn.Parameter( + torch.zeros( + num_experts, + hidden_size // bit8_pack_factor, + 
intermediate_size_per_partition // group_size, + dtype=torch.uint8, + ), + requires_grad=False, + ) + layer.register_parameter("w2_qzeros", w2_qzeros) + set_weight_attrs(w2_qzeros, extra_weight_attrs) + + if self.quant_config.linear_quant_method == "gptq": + # some param are unused, but we need to init them in order to + # load weights + invalid_param_keys = ["w13_g_idx", "w2_g_idx"] + if not self.quant_config.has_zp: + invalid_param_keys += ["w13_qzeros", "w2_qzeros"] + for key in invalid_param_keys: + param = torch.nn.Parameter( + torch.empty((0,), dtype=torch.int32), requires_grad=False + ) + layer.register_parameter(key, param) + set_weight_attrs(param, extra_weight_attrs) + + def apply( + self, + layer: torch.nn.Module, + x: torch.Tensor, + router_logits: torch.Tensor, + top_k: int, + renormalize: bool, + use_grouped_topk: bool = False, + topk_group: Optional[int] = None, + num_expert_group: Optional[int] = None, + custom_routing_function: Optional[Callable] = None, + correction_bias: Optional[torch.Tensor] = None, + activation: str = "silu", + inplace: bool = True, + no_combine: bool = False, + ) -> torch.Tensor: + # avoid circular import + from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_experts + from sglang.srt.layers.moe.topk import select_experts + + assert activation == "silu", "Only SiLU activation is supported." + topk_weights, topk_ids = select_experts( + hidden_states=x, + router_logits=router_logits, + top_k=top_k, + use_grouped_topk=use_grouped_topk, + renormalize=renormalize, + topk_group=topk_group, + num_expert_group=num_expert_group, + custom_routing_function=custom_routing_function, + correction_bias=correction_bias, + ) + + weight_bits = self.quant_config.weight_bits + has_zp = self.quant_config.has_zp + + return fused_experts( + x, + layer.w13_qweight, + layer.w2_qweight, + topk_weights=topk_weights, + topk_ids=topk_ids, + inplace=inplace, + use_int4_w4a16=weight_bits == 4, + use_int8_w8a16=weight_bits == 8, + w1_scale=layer.w13_scales, + w2_scale=layer.w2_scales, + w1_zp=layer.w13_qzeros if has_zp else None, + w2_zp=layer.w2_qzeros if has_zp else None, + block_shape=[0, layer.group_size], + no_combine=no_combine, + ) + + @staticmethod + def get_weight_loader(layer, weight_loader): + + def convert_awq_tensor(tensor, tensor_type): + # convert awq qweight/qzeros to a standard format (assume int4) + # qweight: (k, n // pack_factor_bit32) -> (n, k // pack_factor_bit8) + # qzeros: (k // group_size, n // pack_factor_bit32) -> + # (n // pack_factor_bit8, k // group_size) + # pack_factor_bit32 = 32 // weight_bits + # pack_factor_bit8 = 8 // weight_bits + + # 0. suppose origin shape (a, b), dtype int32 + # 1. convert to uint8, shape (a, b) -> (a, 4 * b) + size0 = tensor.size(0) + tensor = tensor.view(torch.uint8) + + # 2. unpack to uint4 (only when weight_bits == 4) + # shape (a, 4 * b) -> (a, 4 * b, 2) + shifter = torch.tensor([0, 4], dtype=torch.uint8, device=tensor.device) + tensor = (tensor[:, :, None] >> shifter) & 0xF + + # 3. change order, see + # https://github.com/casper-hansen/AutoAWQ/blob/v0.2.8/awq/utils/quant_utils.py + # shape -> (a, 4 * b * pack_factor_bit8) + reverse_awq_pack_order = [0, 4, 1, 5, 2, 6, 3, 7] + tensor = tensor.view(-1, 8)[:, reverse_awq_pack_order] + tensor = tensor.view(size0, -1) + + # 4. transpose, shape -> (4 * b * pack_factor_bit8, a) + tensor = tensor.T.contiguous() + + # 5. 
repack (only when weight_bits == 4) + # qweight shape -> (4 * b * pack_factor_bit8, a // pack_factor_bit8) + # qzeros shape -> (4 * b, a) + + if tensor_type == "qweight": + tensor = tensor[:, 1::2] * 16 + tensor[:, ::2] + elif tensor_type == "qzeros": + tensor = tensor[1::2, :] * 16 + tensor[::2, :] + return tensor + + def convert_gptq_int4_qzeros(tensor): + tensor = tensor.view(torch.uint8) + shifter = torch.tensor([0, 4], dtype=torch.uint8, device=tensor.device) + tensor = (tensor[:, :, None] >> shifter) & 0xF + tensor = tensor + 1 + tensor = tensor[:, :, 0] + tensor[:, :, 1] * 16 + return tensor + + def moe_wna16_weight_loader( + param: torch.nn.Parameter, + loaded_weight: torch.Tensor, + weight_name: str, + shard_id: str, + expert_id: int, + ): + if "g_idx" in weight_name: + return + if not layer.quant_config.has_zp and "qzeros" in weight_name: + return + + device = get_tp_group().device + tp_rank = get_tensor_model_parallel_rank() + loaded_weight = loaded_weight.to(device) + shard_size = layer.intermediate_size_per_partition + + # convert gptq and awq weight to a standard format + if layer.quant_config.linear_quant_method == "awq": + assert layer.quant_config.weight_bits == 4 + if "weight" in weight_name: + loaded_weight = convert_awq_tensor(loaded_weight, "qweight") + elif "zeros" in weight_name: + loaded_weight = convert_awq_tensor(loaded_weight, "qzeros") + else: + loaded_weight = loaded_weight.T + elif layer.quant_config.linear_quant_method == "gptq": + assert layer.quant_config.weight_bits in [4, 8] + if "weight" in weight_name: + loaded_weight = loaded_weight.T.contiguous().view(torch.uint8) + elif "zeros" in weight_name: + # add 1 to gptq qzeros to align with awq + loaded_weight = loaded_weight.view(torch.uint8) + if layer.quant_config.weight_bits == 4: + loaded_weight = convert_gptq_int4_qzeros(loaded_weight).T + else: + loaded_weight = loaded_weight.T + 1 + else: + loaded_weight = loaded_weight.T + + # repeat the qzeros/scales to fit new group size + if ( + layer.group_size_div_factor > 1 + and "qzeros" in weight_name + or "scales" in weight_name + ): + loaded_weight = loaded_weight.repeat_interleave( + layer.group_size_div_factor, 1 + ) + + if "w13_qzeros" in weight_name: + tensor = loaded_weight.view(layer.tp_size, -1, loaded_weight.size(1))[ + tp_rank + ] + if shard_id == "w1": + param.data[expert_id, : shard_size // 2] = tensor + else: + param.data[expert_id, shard_size // 2 :] = tensor + elif "w2_qzeros" in weight_name: + param.data[expert_id] = loaded_weight.view( + loaded_weight.size(0), layer.tp_size, -1 + )[:, tp_rank] + else: + weight_loader(param, loaded_weight, weight_name, shard_id, expert_id) + + return moe_wna16_weight_loader diff --git a/python/sglang/srt/layers/quantization/utils.py b/python/sglang/srt/layers/quantization/utils.py index c6524af49..5615dbca3 100644 --- a/python/sglang/srt/layers/quantization/utils.py +++ b/python/sglang/srt/layers/quantization/utils.py @@ -1,7 +1,7 @@ # Adapted from https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/layers/quantization/utils/quant_utils.py from types import MappingProxyType -from typing import List, Mapping, Tuple, Union +from typing import List, Mapping, Optional, Tuple, Union import torch diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py index f025a4b55..3bdb7de9c 100644 --- a/python/sglang/srt/server_args.py +++ b/python/sglang/srt/server_args.py @@ -496,6 +496,7 @@ class ServerArgs: "modelopt", "w8a8_int8", "w8a8_fp8", + "moe_wna16", ], help="The quantization 
method.", ) diff --git a/test/srt/test_triton_moe_wna16.py b/test/srt/test_triton_moe_wna16.py new file mode 100644 index 000000000..2613586a8 --- /dev/null +++ b/test/srt/test_triton_moe_wna16.py @@ -0,0 +1,238 @@ +from typing import Optional + +import pytest +import torch + +from sglang.srt.layers.activation import SiluAndMul +from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_moe + +NUM_EXPERTS = [8, 64] +TOP_KS = [2, 6] + + +def quantize_weights( + w: torch.Tensor, + quant_type: str, + group_size: Optional[int], + zero_points: bool = False, + ref_zero_points_after_scales: bool = False, +): + assert quant_type in ["w4a16", "w4a16b8", "w8a16", "w8a16b128"] + assert not zero_points or group_size is not None, ( + "to have group zero points, group_size must be provided " + "(-1 group_size is channelwise)" + ) + + orig_device = w.device + orig_type = w.dtype + size_k, size_n = w.shape + + assert w.is_floating_point(), "w must be float" + + if group_size == -1: + group_size = size_k + + # Reshape to [groupsize, -1] + if group_size is not None and group_size < size_k: + w = w.reshape((-1, group_size, size_n)) + w = w.permute(1, 0, 2) + w = w.reshape((group_size, -1)) + + # Compute scale for each group + max_val = torch.max(w, 0, keepdim=True).values + min_val = torch.min(w, 0, keepdim=True).values + + if quant_type == "w4a16": + max_q_val = 15 + min_q_val = 0 + elif quant_type == "w4a16b8": + max_q_val = 7 + min_q_val = -1 + elif quant_type == "w8a16": + max_q_val = 255 + min_q_val = 0 + elif quant_type == "w8a16b128": + max_q_val = 127 + min_q_val = -128 + + w_s = torch.Tensor([1.0]).to(w.device) # unscaled case + maybe_w_zp = None + if group_size is not None: + if zero_points: + w_s = (max_val - min_val).clamp(min=1e-5) / max_q_val + maybe_w_zp = ( + torch.round(torch.abs(min_val / w_s)).clamp(min_q_val, max_q_val).int() + ) + else: + # If the bias is such that there are no possible negative/positive + # values, set the max value to inf to avoid divide by 0 + w_s = torch.max( + abs(max_val / (max_q_val if max_q_val != 0 else torch.inf)), + abs(min_val / (min_q_val if min_q_val != 0 else torch.inf)), + ) + + # Quantize + w_q = torch.round(w / w_s).int() + (maybe_w_zp if zero_points else 0) + w_q = torch.clamp(w_q, min_q_val, max_q_val) + + # Compute ref (dequantized) + # For some kernels (namely Machete) the zero-points are applied after the + # scales are applied, for this case computing the reference in similar way + # allows us to use tighter error tolerances in our unit tests. 
+ if ref_zero_points_after_scales and maybe_w_zp is not None: + w_ref = w_q.to(orig_type) * w_s - maybe_w_zp.to(orig_type) * w_s + else: + w_ref = (w_q - (maybe_w_zp if zero_points else 0)).to(orig_type) * w_s + + if quant_type == "w4a16b8": + w_q += 8 + elif quant_type == "w8a16b128": + w_q += 128 + + # Restore original shapes + if group_size is not None and group_size < size_k: + + def reshape_w(w): + w = w.reshape((group_size, -1, size_n)) + w = w.permute(1, 0, 2) + w = w.reshape((size_k, size_n)).contiguous() + return w + + w_q = reshape_w(w_q) + w_ref = reshape_w(w_ref) + w_s = w_s.reshape((-1, size_n)).contiguous() + + if maybe_w_zp is not None: + maybe_w_zp = maybe_w_zp.reshape((-1, size_n)).contiguous() + maybe_w_zp = maybe_w_zp.to(device=orig_device) + + return ( + w_ref.to(device=orig_device), + w_q.to(device=orig_device), + w_s if group_size is not None else None, + maybe_w_zp, + ) + + +def torch_moe(a, w1, w2, score, topk): + B, D = a.shape + a = a.view(B, -1, D).repeat(1, topk, 1).reshape(-1, D) + out = torch.zeros(B * topk, w2.shape[1], dtype=a.dtype, device=a.device) + score = torch.softmax(score, dim=-1, dtype=torch.float32) + topk_weight, topk_ids = torch.topk(score, topk) + topk_weight = topk_weight.view(-1) + topk_ids = topk_ids.view(-1) + for i in range(w1.shape[0]): + mask = topk_ids == i + if mask.sum(): + out[mask] = SiluAndMul()(a[mask] @ w1[i].transpose(0, 1)) @ w2[i].transpose( + 0, 1 + ) + return ( + out.view(B, -1, w2.shape[1]) * topk_weight.view(B, -1, 1).to(out.dtype) + ).sum(dim=1) + + +# fork from https://github.com/vllm-project/vllm/blob/main/tests/kernels/test_moe.py +@pytest.mark.parametrize("m", [1, 32, 222]) +@pytest.mark.parametrize("n", [128, 1024, 2048]) +@pytest.mark.parametrize("k", [128, 1024]) +@pytest.mark.parametrize("e", NUM_EXPERTS) +@pytest.mark.parametrize("topk", TOP_KS) +@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) +@pytest.mark.parametrize("group_size", [64, 128]) +@pytest.mark.parametrize("has_zp", [True, False]) +@pytest.mark.parametrize("weight_bits", [8]) # [4, 8]) +def test_fused_moe_wn16( + m: int, + n: int, + k: int, + e: int, + topk: int, + dtype: torch.dtype, + group_size: int, + has_zp: bool, + weight_bits: int, +): + print(m, n, k, e, topk, dtype, group_size, has_zp, weight_bits) + a = torch.randn((m, k), device="cuda", dtype=dtype) / 10 + w1 = torch.randn((e, 2 * n, k), device="cuda", dtype=dtype) / 10 + w2 = torch.randn((e, k, n), device="cuda", dtype=dtype) / 10 + score = torch.randn((m, e), device="cuda", dtype=dtype) + + if weight_bits == 4: + pack_factor = 2 + quant_type = "w4a16" if has_zp else "w4a16b8" + elif weight_bits == 8: + pack_factor = 1 + quant_type = "w8a16" if has_zp else "w8a16b128" + + w1_ref = w1.clone() + w2_ref = w2.clone() + w1_qweight = torch.empty( + (e, 2 * n, k // pack_factor), device="cuda", dtype=torch.uint8 + ) + w2_qweight = torch.empty((e, k, n // pack_factor), device="cuda", dtype=torch.uint8) + w1_scales = torch.empty((e, 2 * n, k // group_size), device="cuda", dtype=dtype) + w2_scales = torch.empty((e, k, n // group_size), device="cuda", dtype=dtype) + w1_qzeros = torch.empty( + (e, 2 * n // pack_factor, k // group_size), device="cuda", dtype=torch.uint8 + ) + w2_qzeros = torch.empty( + (e, k // pack_factor, n // group_size), device="cuda", dtype=torch.uint8 + ) + + for i in range(e * 2): + expert_id = i % e + if i // e == 0: + w, w_ref, w_qweight, w_scales, w_qzeros = ( + w1, + w1_ref, + w1_qweight, + w1_scales, + w1_qzeros, + ) + else: + w, w_ref, w_qweight, 
w_scales, w_qzeros = ( + w2, + w2_ref, + w2_qweight, + w2_scales, + w2_qzeros, + ) + weight, qweight, scales, qzeros = quantize_weights( + w[expert_id].T, quant_type, group_size, has_zp, False + ) + weight = weight.T + qweight = qweight.T.contiguous().to(torch.uint8) + scales = scales.T + if has_zp: + qzeros = qzeros.T.contiguous().to(torch.uint8) + if weight_bits == 4: + qweight = qweight[:, 1::2] * 16 + qweight[:, ::2] + if has_zp: + qzeros = qzeros[1::2, :] * 16 + qzeros[::2, :] + + w_ref[expert_id] = weight + w_qweight[expert_id] = qweight + w_scales[expert_id] = scales + if has_zp: + w_qzeros[expert_id] = qzeros + + triton_output = fused_moe( + a, + w1_qweight, + w2_qweight, + score, + topk, + renormalize=False, + use_int4_w4a16=weight_bits == 4, + use_int8_w8a16=weight_bits == 8, + w1_scale=w1_scales, + w2_scale=w2_scales, + w1_zp=w1_qzeros if has_zp else None, + w2_zp=w2_qzeros if has_zp else None, + block_shape=[0, group_size], + ) + torch_output = torch_moe(a, w1_ref, w2_ref, score, topk) + torch.testing.assert_close(triton_output, torch_output, atol=2e-2, rtol=0)
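
The new kernel path can be exercised end to end once this branch is installed. A minimal sketch, assuming a CUDA-capable GPU and `pytest` available in the environment; the launch command simply repeats the README example above:

```bash
# Run the new Triton MoE WNA16 kernel test (the int4 case is currently
# commented out of the parametrization, so this covers the int8 path).
python3 -m pytest test/srt/test_triton_moe_wna16.py -v

# Serve an AWQ checkpoint with the moe_wna16 kernel enabled.
python3 -m sglang.launch_server --model cognitivecomputations/DeepSeek-R1-AWQ --tp 8 --trust-remote-code --quantization moe_wna16
```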