diff --git a/python/sglang/srt/layers/moe/cutlass_moe.py b/python/sglang/srt/layers/moe/cutlass_moe.py
index 6dadf0d0f..262f1ae39 100755
--- a/python/sglang/srt/layers/moe/cutlass_moe.py
+++ b/python/sglang/srt/layers/moe/cutlass_moe.py
@@ -9,7 +9,7 @@ from typing import Any, Callable, Dict, List, Optional, Tuple
 import torch

 from sglang.srt.layers.moe.cutlass_moe_params import CutlassMoEParams
-from sglang.srt.layers.utils import is_sm100_supported
+from sglang.srt.layers.utils import is_sm90_supported, is_sm100_supported
 from sglang.srt.utils import is_cuda

 _is_cuda = is_cuda()
@@ -124,6 +124,7 @@ def cutlass_fused_experts_fp8(

     if is_cuda:
         from sglang.srt.layers.quantization.fp8_kernel import (
+            per_group_transpose,
             per_token_group_quant_fp8_hopper_moe_mn_major,
             sglang_per_token_group_quant_fp8,
         )
@@ -152,15 +153,12 @@ def cutlass_fused_experts_fp8(
         k,
     )

-    if is_sm100_supported():
-        a_q, a1_scale = sglang_per_token_group_quant_fp8(a, 128)
-        rep_a_q = shuffle_rows(a_q, a_map, (m * topk, k))
-        rep_a1_scales = shuffle_rows(a1_scale, a_map, (m * topk, int(k / 128)))
-    else:
-        rep_a = shuffle_rows(a, a_map, (m * topk, k))
-        rep_a_q, rep_a1_scales = per_token_group_quant_fp8_hopper_moe_mn_major(
-            rep_a, expert_offsets, problem_sizes1, 128
-        )
+    a_q, a1_scale = sglang_per_token_group_quant_fp8(a, 128)
+    rep_a_q = shuffle_rows(a_q, a_map, (m * topk, k))
+    rep_a1_scales = shuffle_rows(a1_scale, a_map, (m * topk, int(k / 128)))
+
+    if not is_sm100_supported():
+        rep_a1_scales = per_group_transpose(rep_a1_scales, expert_offsets)

     w1_scale = w1_scale.contiguous()
     c1 = torch.empty((m * topk, n * 2), device=device, dtype=out_dtype)
@@ -193,12 +191,9 @@ def cutlass_fused_experts_fp8(
     intermediate = torch.empty((m * topk, n), device=device, dtype=out_dtype)
     silu_and_mul(c1, intermediate)

-    if is_sm100_supported():
-        intemediate_q, a2_scale = sglang_per_token_group_quant_fp8(intermediate, 128)
-    else:
-        intemediate_q, a2_scale = per_token_group_quant_fp8_hopper_moe_mn_major(
-            intermediate, expert_offsets, problem_sizes2, 128
-        )
+    intemediate_q, a2_scale = sglang_per_token_group_quant_fp8(intermediate, 128)
+    if not is_sm100_supported():
+        a2_scale = per_group_transpose(a2_scale, expert_offsets)

     w2_scale = w2_scale.contiguous()
     fp8_blockwise_scaled_grouped_mm(
diff --git a/python/sglang/srt/layers/quantization/fp8_kernel.py b/python/sglang/srt/layers/quantization/fp8_kernel.py
index 16d1a4d7f..c3be57649 100644
--- a/python/sglang/srt/layers/quantization/fp8_kernel.py
+++ b/python/sglang/srt/layers/quantization/fp8_kernel.py
@@ -1356,3 +1356,62 @@ def per_token_group_quant_fp8_hopper_moe_mn_major(
         expert_tokens_alignment,
     )
     return a_q, sfa
+
+
+@triton.jit
+def _per_group_transpose(
+    data_ptr: torch.Tensor,
+    trans_data_ptr: torch.Tensor,
+    expert_offsets: torch.Tensor,
+    k: int,
+    M_ALIGNMENT: tl.constexpr,
+    BLOCK_SIZE_M: tl.constexpr,
+    BLOCK_SIZE_K: tl.constexpr,
+):
+    expert_id = tl.program_id(0)
+    m_id = tl.program_id(1)
+    k_id = tl.program_id(2)
+
+    curr_expert_offset = tl.load(expert_offsets + expert_id)
+    next_expert_offset = tl.load(expert_offsets + expert_id + 1)
+    num_tokens_of_expert = next_expert_offset - curr_expert_offset
+    tl.multiple_of(curr_expert_offset, M_ALIGNMENT)
+    tl.multiple_of(next_expert_offset, M_ALIGNMENT)
+
+    data_start_ptr = data_ptr + curr_expert_offset * k
+    trans_data_start_ptr = trans_data_ptr + curr_expert_offset * k
+
+    k_coord = k_id * BLOCK_SIZE_K + tl.arange(0, BLOCK_SIZE_K)
+    k_mask = k_coord < k
+    for start_m in tl.range(0, num_tokens_of_expert, BLOCK_SIZE_M * tl.num_programs(1)):
+        m_coord = start_m + m_id * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
+        m_mask = m_coord < num_tokens_of_expert
+        off = m_coord[:, None] * k + k_coord[None, :]
+        trans_off = m_coord[:, None] + k_coord[None, :] * num_tokens_of_expert
+        mask = m_mask[:, None] & k_mask[None, :]
+
+        data = tl.load(data_start_ptr + off, mask=mask)
+        tl.store(trans_data_start_ptr + trans_off, data, mask=mask)
+
+
+def per_group_transpose(
+    a: torch.Tensor,
+    expert_offsets: torch.Tensor,
+    M_ALIGNMENT: int = 1,
+) -> torch.Tensor:
+    assert a.dim() == 2
+    assert a.is_contiguous(), "`a` is not contiguous"
+
+    m, k = a.size()
+    trans_a = torch.empty_like(a)
+    num_experts = expert_offsets.size(0) - 1
+
+    grid = lambda META: (
+        num_experts,
+        triton.cdiv((m + num_experts - 1) // num_experts, META["BLOCK_SIZE_M"]),
+        triton.cdiv(k, META["BLOCK_SIZE_K"]),
+    )
+    _per_group_transpose[grid](
+        a, trans_a, expert_offsets, k, M_ALIGNMENT, BLOCK_SIZE_M=16, BLOCK_SIZE_K=8
+    )
+    return trans_a
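
Note: below is a minimal pure-PyTorch sketch of what the new per_group_transpose kernel computes, usable as a CPU reference when testing this change. It assumes, as in cutlass_fused_experts_fp8 above, that expert_offsets is a (num_experts + 1)-entry prefix sum of per-expert token counts and that a is a contiguous row-major (m, k) tensor; per_group_transpose_ref and the __main__ check are hypothetical test helpers, not part of this diff.

import torch


def per_group_transpose_ref(a: torch.Tensor, expert_offsets: torch.Tensor) -> torch.Tensor:
    # For each expert e owning rows [off[e], off[e + 1]) of the row-major
    # (m, k) input, re-lay-out that (rows_e, k) slice as its (k, rows_e)
    # transpose within the same flat region of the output buffer, i.e. make
    # each expert's scale block MN-major for the SM90 grouped GEMM path.
    m, k = a.shape
    out = torch.empty_like(a)
    flat_in, flat_out = a.reshape(-1), out.reshape(-1)
    offs = expert_offsets.tolist()
    for e in range(len(offs) - 1):
        lo, hi = offs[e], offs[e + 1]
        rows = hi - lo
        if rows == 0:
            continue  # expert received no tokens; kernel also no-ops here
        block = flat_in[lo * k : hi * k].view(rows, k)
        # Element (i, j) lands at offset j * rows + i, matching the kernel's
        # trans_off = m_coord + k_coord * num_tokens_of_expert.
        flat_out[lo * k : hi * k] = block.t().reshape(-1)
    return out


if __name__ == "__main__":
    expert_offsets = torch.tensor([0, 3, 3, 8])  # expert 1 gets no tokens
    a = torch.randn(8, 4)
    ref = per_group_transpose_ref(a, expert_offsets)
    # On an SM90 GPU the Triton kernel should agree bit-for-bit:
    # assert torch.equal(
    #     per_group_transpose(a.cuda(), expert_offsets.cuda().int()).cpu(), ref
    # )

This also documents why the launch grid's second dimension is sized from the average tokens per expert, (m + num_experts - 1) // num_experts: experts with more than the average number of rows are covered by the kernel's loop, which strides by BLOCK_SIZE_M * tl.num_programs(1).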