feat: support cutlass_moe_fp8 kernel for fusedmoe in sm90 (#8678)
This commit is contained in:
@@ -9,6 +9,7 @@ from typing import Any, Callable, Dict, List, Optional, Tuple
|
||||
import torch
|
||||
|
||||
from sglang.srt.layers.moe.cutlass_moe_params import CutlassMoEParams
|
||||
from sglang.srt.layers.utils import is_sm100_supported
|
||||
from sglang.srt.utils import is_cuda
|
||||
|
||||
_is_cuda = is_cuda()
|
||||
@@ -123,6 +124,7 @@ def cutlass_fused_experts_fp8(
|
||||
|
||||
if is_cuda:
|
||||
from sglang.srt.layers.quantization.fp8_kernel import (
|
||||
per_token_group_quant_fp8_hopper_moe_mn_major,
|
||||
sglang_per_token_group_quant_fp8,
|
||||
)
|
||||
|
||||
@@ -133,9 +135,7 @@ def cutlass_fused_experts_fp8(
|
||||
n = w2_q.size(1)
|
||||
|
||||
topk = topk_ids.size(1)
|
||||
|
||||
a_q, a1_scale = sglang_per_token_group_quant_fp8(a, 128)
|
||||
device = a_q.device
|
||||
device = a.device
|
||||
|
||||
a_map = torch.empty((topk_ids.numel()), dtype=torch.int32, device=device)
|
||||
c_map = torch.empty((topk_ids.numel()), dtype=torch.int32, device=device)
|
||||
@@ -152,8 +152,16 @@ def cutlass_fused_experts_fp8(
|
||||
k,
|
||||
)
|
||||
|
||||
rep_a_q = shuffle_rows(a_q, a_map, (m * topk, k))
|
||||
rep_a1_scales = shuffle_rows(a1_scale, a_map, (m * topk, int(k / 128)))
|
||||
if is_sm100_supported():
|
||||
a_q, a1_scale = sglang_per_token_group_quant_fp8(a, 128)
|
||||
rep_a_q = shuffle_rows(a_q, a_map, (m * topk, k))
|
||||
rep_a1_scales = shuffle_rows(a1_scale, a_map, (m * topk, int(k / 128)))
|
||||
else:
|
||||
rep_a = shuffle_rows(a, a_map, (m * topk, k))
|
||||
rep_a_q, rep_a1_scales = per_token_group_quant_fp8_hopper_moe_mn_major(
|
||||
rep_a, expert_offsets, problem_sizes1, 128
|
||||
)
|
||||
w1_scale = w1_scale.contiguous()
|
||||
|
||||
c1 = torch.empty((m * topk, n * 2), device=device, dtype=out_dtype)
|
||||
c2 = torch.empty((m * topk, k), device=device, dtype=out_dtype)
|
||||
@@ -185,7 +193,13 @@ def cutlass_fused_experts_fp8(
|
||||
intermediate = torch.empty((m * topk, n), device=device, dtype=out_dtype)
|
||||
silu_and_mul(c1, intermediate)
|
||||
|
||||
intemediate_q, a2_scale = sglang_per_token_group_quant_fp8(intermediate, 128)
|
||||
if is_sm100_supported():
|
||||
intemediate_q, a2_scale = sglang_per_token_group_quant_fp8(intermediate, 128)
|
||||
else:
|
||||
intemediate_q, a2_scale = per_token_group_quant_fp8_hopper_moe_mn_major(
|
||||
intermediate, expert_offsets, problem_sizes2, 128
|
||||
)
|
||||
w2_scale = w2_scale.contiguous()
|
||||
|
||||
fp8_blockwise_scaled_grouped_mm(
|
||||
c2,
|
||||
|
||||
@@ -63,7 +63,7 @@ from sglang.srt.layers.quantization.utils import (
|
||||
per_tensor_dequantize,
|
||||
requantize_with_max_scale,
|
||||
)
|
||||
from sglang.srt.layers.utils import is_sm100_supported
|
||||
from sglang.srt.layers.utils import is_sm90_supported, is_sm100_supported
|
||||
from sglang.srt.utils import (
|
||||
cpu_has_amx_support,
|
||||
get_bool_env_var,
|
||||
@@ -619,7 +619,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
|
||||
if (
|
||||
get_bool_env_var("SGLANG_CUTLASS_MOE")
|
||||
and self.cutlass_fp8_supported
|
||||
and is_sm100_supported()
|
||||
and (is_sm100_supported() or is_sm90_supported())
|
||||
):
|
||||
self.ab_strides1 = torch.full(
|
||||
(num_experts,),
|
||||
@@ -1034,7 +1034,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
|
||||
get_bool_env_var("SGLANG_CUTLASS_MOE")
|
||||
and self.cutlass_fp8_supported
|
||||
and self.block_quant
|
||||
and is_sm100_supported()
|
||||
and (is_sm100_supported() or is_sm90_supported())
|
||||
):
|
||||
from sglang.srt.layers.moe.cutlass_moe import cutlass_fused_experts_fp8
|
||||
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
import logging
|
||||
import re
|
||||
from functools import lru_cache
|
||||
|
||||
import torch
|
||||
|
||||
@@ -35,7 +36,15 @@ class PPMissingLayer(torch.nn.Identity):
|
||||
return (input,) if self.return_tuple else input
|
||||
|
||||
|
||||
@lru_cache(maxsize=1)
|
||||
def is_sm100_supported(device=None) -> bool:
|
||||
return (torch.cuda.get_device_capability(device)[0] == 10) and (
|
||||
torch.version.cuda >= "12.8"
|
||||
)
|
||||
|
||||
|
||||
@lru_cache(maxsize=1)
|
||||
def is_sm90_supported(device=None) -> bool:
|
||||
return (torch.cuda.get_device_capability(device)[0] == 9) and (
|
||||
torch.version.cuda >= "12.3"
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user