optimize: reduce shuffle and quantization overhead in cutlass_moe sm90 (#8962)
Co-authored-by: 戚余航 <qiyuhang@bytedance.com>
This commit is contained in:
@@ -9,7 +9,7 @@ from typing import Any, Callable, Dict, List, Optional, Tuple
|
|||||||
import torch
|
import torch
|
||||||
|
|
||||||
from sglang.srt.layers.moe.cutlass_moe_params import CutlassMoEParams
|
from sglang.srt.layers.moe.cutlass_moe_params import CutlassMoEParams
|
||||||
from sglang.srt.layers.utils import is_sm100_supported
|
from sglang.srt.layers.utils import is_sm90_supported, is_sm100_supported
|
||||||
from sglang.srt.utils import is_cuda
|
from sglang.srt.utils import is_cuda
|
||||||
|
|
||||||
_is_cuda = is_cuda()
|
_is_cuda = is_cuda()
|
||||||
@@ -124,6 +124,7 @@ def cutlass_fused_experts_fp8(
|
|||||||
|
|
||||||
if is_cuda:
|
if is_cuda:
|
||||||
from sglang.srt.layers.quantization.fp8_kernel import (
|
from sglang.srt.layers.quantization.fp8_kernel import (
|
||||||
|
per_group_transpose,
|
||||||
per_token_group_quant_fp8_hopper_moe_mn_major,
|
per_token_group_quant_fp8_hopper_moe_mn_major,
|
||||||
sglang_per_token_group_quant_fp8,
|
sglang_per_token_group_quant_fp8,
|
||||||
)
|
)
|
||||||
@@ -152,15 +153,12 @@ def cutlass_fused_experts_fp8(
|
|||||||
k,
|
k,
|
||||||
)
|
)
|
||||||
|
|
||||||
if is_sm100_supported():
|
a_q, a1_scale = sglang_per_token_group_quant_fp8(a, 128)
|
||||||
a_q, a1_scale = sglang_per_token_group_quant_fp8(a, 128)
|
rep_a_q = shuffle_rows(a_q, a_map, (m * topk, k))
|
||||||
rep_a_q = shuffle_rows(a_q, a_map, (m * topk, k))
|
rep_a1_scales = shuffle_rows(a1_scale, a_map, (m * topk, int(k / 128)))
|
||||||
rep_a1_scales = shuffle_rows(a1_scale, a_map, (m * topk, int(k / 128)))
|
|
||||||
else:
|
if not is_sm100_supported():
|
||||||
rep_a = shuffle_rows(a, a_map, (m * topk, k))
|
rep_a1_scales = per_group_transpose(rep_a1_scales, expert_offsets)
|
||||||
rep_a_q, rep_a1_scales = per_token_group_quant_fp8_hopper_moe_mn_major(
|
|
||||||
rep_a, expert_offsets, problem_sizes1, 128
|
|
||||||
)
|
|
||||||
w1_scale = w1_scale.contiguous()
|
w1_scale = w1_scale.contiguous()
|
||||||
|
|
||||||
c1 = torch.empty((m * topk, n * 2), device=device, dtype=out_dtype)
|
c1 = torch.empty((m * topk, n * 2), device=device, dtype=out_dtype)
|
||||||
@@ -193,12 +191,9 @@ def cutlass_fused_experts_fp8(
|
|||||||
intermediate = torch.empty((m * topk, n), device=device, dtype=out_dtype)
|
intermediate = torch.empty((m * topk, n), device=device, dtype=out_dtype)
|
||||||
silu_and_mul(c1, intermediate)
|
silu_and_mul(c1, intermediate)
|
||||||
|
|
||||||
if is_sm100_supported():
|
intemediate_q, a2_scale = sglang_per_token_group_quant_fp8(intermediate, 128)
|
||||||
intemediate_q, a2_scale = sglang_per_token_group_quant_fp8(intermediate, 128)
|
if not is_sm100_supported():
|
||||||
else:
|
a2_scale = per_group_transpose(a2_scale, expert_offsets)
|
||||||
intemediate_q, a2_scale = per_token_group_quant_fp8_hopper_moe_mn_major(
|
|
||||||
intermediate, expert_offsets, problem_sizes2, 128
|
|
||||||
)
|
|
||||||
w2_scale = w2_scale.contiguous()
|
w2_scale = w2_scale.contiguous()
|
||||||
|
|
||||||
fp8_blockwise_scaled_grouped_mm(
|
fp8_blockwise_scaled_grouped_mm(
|
||||||
|
|||||||
@@ -1356,3 +1356,62 @@ def per_token_group_quant_fp8_hopper_moe_mn_major(
|
|||||||
expert_tokens_alignment,
|
expert_tokens_alignment,
|
||||||
)
|
)
|
||||||
return a_q, sfa
|
return a_q, sfa
|
||||||
|
|
||||||
|
|
||||||
|
@triton.jit
def _per_group_transpose(
    data_ptr: torch.Tensor,
    trans_data_ptr: torch.Tensor,
    expert_offsets: torch.Tensor,
    k: int,
    M_ALIGNMENT: tl.constexpr,
    BLOCK_SIZE_M: tl.constexpr,
    BLOCK_SIZE_K: tl.constexpr,
):
    """Transpose each expert's row segment independently.

    For expert ``e`` the rows ``[expert_offsets[e], expert_offsets[e + 1])``
    are read from ``data_ptr`` as a row-major ``(rows, k)`` tile and written
    to the same flat region of ``trans_data_ptr`` with the two indices
    swapped, i.e. element ``(r, c)`` lands at offset ``c * rows + r``.

    Launch grid: axis 0 selects the expert, axis 1 splits the rows of that
    expert, axis 2 splits the ``k`` columns.
    """
    pid_expert = tl.program_id(0)
    pid_m = tl.program_id(1)
    pid_k = tl.program_id(2)

    # Row range owned by this expert.
    seg_begin = tl.load(expert_offsets + pid_expert)
    seg_end = tl.load(expert_offsets + pid_expert + 1)
    rows = seg_end - seg_begin
    # Divisibility hints for the compiler.
    tl.multiple_of(seg_begin, M_ALIGNMENT)
    tl.multiple_of(seg_end, M_ALIGNMENT)

    # Source and destination segments start at the same flat offset.
    src_base = data_ptr + seg_begin * k
    dst_base = trans_data_ptr + seg_begin * k

    cols = pid_k * BLOCK_SIZE_K + tl.arange(0, BLOCK_SIZE_K)
    col_ok = cols < k

    # Axis 1 may have fewer programs than row blocks, so every program
    # walks the rows with a stride of (programs on axis 1) * BLOCK_SIZE_M.
    m_stride = BLOCK_SIZE_M * tl.num_programs(1)
    for row_base in tl.range(0, rows, m_stride):
        row_idx = row_base + pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
        row_ok = row_idx < rows
        valid = row_ok[:, None] & col_ok[None, :]

        src_off = row_idx[:, None] * k + cols[None, :]
        dst_off = row_idx[:, None] + cols[None, :] * rows

        tile = tl.load(src_base + src_off, mask=valid)
        tl.store(dst_base + dst_off, tile, mask=valid)
|
||||||
|
|
||||||
|
|
||||||
|
def per_group_transpose(
    a: torch.Tensor,
    expert_offsets: torch.Tensor,
    M_ALIGNMENT: int = 1,
) -> torch.Tensor:
    """Return a copy of ``a`` in which each expert's row segment is transposed.

    Args:
        a: Contiguous 2-D tensor; its rows are partitioned by ``expert_offsets``.
        expert_offsets: 1-D tensor of row boundaries; ``size(0) - 1`` is taken
            as the number of experts.
        M_ALIGNMENT: Forwarded to the kernel as a divisibility hint for the
            offsets.

    Returns:
        A new tensor with the same shape and dtype as ``a``, filled by the
        ``_per_group_transpose`` kernel.
    """
    assert a.dim() == 2
    assert a.is_contiguous(), "`a` is not contiguous"

    num_rows, num_cols = a.shape
    expert_count = expert_offsets.size(0) - 1
    out = torch.empty_like(a)

    def launch_grid(meta):
        # Size the row axis for the average expert load; the kernel strides
        # over any rows beyond that.
        rows_per_expert = (num_rows + expert_count - 1) // expert_count
        return (
            expert_count,
            triton.cdiv(rows_per_expert, meta["BLOCK_SIZE_M"]),
            triton.cdiv(num_cols, meta["BLOCK_SIZE_K"]),
        )

    _per_group_transpose[launch_grid](
        a, out, expert_offsets, num_cols, M_ALIGNMENT, BLOCK_SIZE_M=16, BLOCK_SIZE_K=8
    )
    return out
|
||||||
|
|||||||
Reference in New Issue
Block a user