optimize: reduce shulffle and quantization overhead in cutlass_moe sm90 (#8962)
Co-authored-by: 戚余航 <qiyuhang@bytedance.com>
This commit is contained in:
@@ -9,7 +9,7 @@ from typing import Any, Callable, Dict, List, Optional, Tuple
|
||||
import torch
|
||||
|
||||
from sglang.srt.layers.moe.cutlass_moe_params import CutlassMoEParams
|
||||
from sglang.srt.layers.utils import is_sm100_supported
|
||||
from sglang.srt.layers.utils import is_sm90_supported, is_sm100_supported
|
||||
from sglang.srt.utils import is_cuda
|
||||
|
||||
_is_cuda = is_cuda()
|
||||
@@ -124,6 +124,7 @@ def cutlass_fused_experts_fp8(
|
||||
|
||||
if is_cuda:
|
||||
from sglang.srt.layers.quantization.fp8_kernel import (
|
||||
per_group_transpose,
|
||||
per_token_group_quant_fp8_hopper_moe_mn_major,
|
||||
sglang_per_token_group_quant_fp8,
|
||||
)
|
||||
@@ -152,15 +153,12 @@ def cutlass_fused_experts_fp8(
|
||||
k,
|
||||
)
|
||||
|
||||
if is_sm100_supported():
|
||||
a_q, a1_scale = sglang_per_token_group_quant_fp8(a, 128)
|
||||
rep_a_q = shuffle_rows(a_q, a_map, (m * topk, k))
|
||||
rep_a1_scales = shuffle_rows(a1_scale, a_map, (m * topk, int(k / 128)))
|
||||
else:
|
||||
rep_a = shuffle_rows(a, a_map, (m * topk, k))
|
||||
rep_a_q, rep_a1_scales = per_token_group_quant_fp8_hopper_moe_mn_major(
|
||||
rep_a, expert_offsets, problem_sizes1, 128
|
||||
)
|
||||
a_q, a1_scale = sglang_per_token_group_quant_fp8(a, 128)
|
||||
rep_a_q = shuffle_rows(a_q, a_map, (m * topk, k))
|
||||
rep_a1_scales = shuffle_rows(a1_scale, a_map, (m * topk, int(k / 128)))
|
||||
|
||||
if not is_sm100_supported():
|
||||
rep_a1_scales = per_group_transpose(rep_a1_scales, expert_offsets)
|
||||
w1_scale = w1_scale.contiguous()
|
||||
|
||||
c1 = torch.empty((m * topk, n * 2), device=device, dtype=out_dtype)
|
||||
@@ -193,12 +191,9 @@ def cutlass_fused_experts_fp8(
|
||||
intermediate = torch.empty((m * topk, n), device=device, dtype=out_dtype)
|
||||
silu_and_mul(c1, intermediate)
|
||||
|
||||
if is_sm100_supported():
|
||||
intemediate_q, a2_scale = sglang_per_token_group_quant_fp8(intermediate, 128)
|
||||
else:
|
||||
intemediate_q, a2_scale = per_token_group_quant_fp8_hopper_moe_mn_major(
|
||||
intermediate, expert_offsets, problem_sizes2, 128
|
||||
)
|
||||
intemediate_q, a2_scale = sglang_per_token_group_quant_fp8(intermediate, 128)
|
||||
if not is_sm100_supported():
|
||||
a2_scale = per_group_transpose(a2_scale, expert_offsets)
|
||||
w2_scale = w2_scale.contiguous()
|
||||
|
||||
fp8_blockwise_scaled_grouped_mm(
|
||||
|
||||
@@ -1356,3 +1356,62 @@ def per_token_group_quant_fp8_hopper_moe_mn_major(
|
||||
expert_tokens_alignment,
|
||||
)
|
||||
return a_q, sfa
|
||||
|
||||
|
||||
@triton.jit
|
||||
def _per_group_transpose(
|
||||
data_ptr: torch.Tensor,
|
||||
trans_data_ptr: torch.Tensor,
|
||||
expert_offsets: torch.Tensor,
|
||||
k: int,
|
||||
M_ALIGNMENT: tl.constexpr,
|
||||
BLOCK_SIZE_M: tl.constexpr,
|
||||
BLOCK_SIZE_K: tl.constexpr,
|
||||
):
|
||||
expert_id = tl.program_id(0)
|
||||
m_id = tl.program_id(1)
|
||||
k_id = tl.program_id(2)
|
||||
|
||||
curr_expert_offset = tl.load(expert_offsets + expert_id)
|
||||
next_expert_offset = tl.load(expert_offsets + expert_id + 1)
|
||||
num_tokens_of_expert = next_expert_offset - curr_expert_offset
|
||||
tl.multiple_of(curr_expert_offset, M_ALIGNMENT)
|
||||
tl.multiple_of(next_expert_offset, M_ALIGNMENT)
|
||||
|
||||
data_start_ptr = data_ptr + curr_expert_offset * k
|
||||
trans_data_start_ptr = trans_data_ptr + curr_expert_offset * k
|
||||
|
||||
k_coord = k_id * BLOCK_SIZE_K + tl.arange(0, BLOCK_SIZE_K)
|
||||
k_mask = k_coord < k
|
||||
for start_m in tl.range(0, num_tokens_of_expert, BLOCK_SIZE_M * tl.num_programs(1)):
|
||||
m_coord = start_m + m_id * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
|
||||
m_mask = m_coord < num_tokens_of_expert
|
||||
off = m_coord[:, None] * k + k_coord[None, :]
|
||||
trans_off = m_coord[:, None] + k_coord[None, :] * num_tokens_of_expert
|
||||
mask = m_mask[:, None] & k_mask[None, :]
|
||||
|
||||
data = tl.load(data_start_ptr + off, mask=mask)
|
||||
tl.store(trans_data_start_ptr + trans_off, data, mask=mask)
|
||||
|
||||
|
||||
def per_group_transpose(
|
||||
a: torch.Tensor,
|
||||
expert_offsets: torch.Tensor,
|
||||
M_ALIGNMENT: int = 1,
|
||||
) -> torch.Tensor:
|
||||
assert a.dim() == 2
|
||||
assert a.is_contiguous(), "`a` is not contiguous"
|
||||
|
||||
m, k = a.size()
|
||||
trans_a = torch.empty_like(a)
|
||||
num_experts = expert_offsets.size(0) - 1
|
||||
|
||||
grid = lambda META: (
|
||||
num_experts,
|
||||
triton.cdiv((m + num_experts - 1) // num_experts, META["BLOCK_SIZE_M"]),
|
||||
triton.cdiv(k, META["BLOCK_SIZE_K"]),
|
||||
)
|
||||
_per_group_transpose[grid](
|
||||
a, trans_a, expert_offsets, k, M_ALIGNMENT, BLOCK_SIZE_M=16, BLOCK_SIZE_K=8
|
||||
)
|
||||
return trans_a
|
||||
|
||||
Reference in New Issue
Block a user