Better optimization log for gpt-oss model (#8953)

This commit is contained in:
Xiaoyu Zhang
2025-08-08 15:11:48 +08:00
committed by GitHub
parent 774b47f3f1
commit 0d1e27a0c5
2 changed files with 11 additions and 4 deletions

View File

@@ -24,6 +24,7 @@ from sglang.srt.utils import (
is_cuda,
is_flashinfer_available,
is_hip,
log_info_on_rank0,
next_power_of_2,
round_up,
set_weight_attrs,
@@ -34,7 +35,6 @@ has_triton_kernels = importlib.util.find_spec("triton_kernels") is not None
if is_flashinfer_available():
# from flashinfer.fused_moe import cutlass_fused_moe
from flashinfer import (
mxfp8_quantize,
shuffle_matrix_a,
@@ -63,7 +63,7 @@ def _swizzle_mxfp4(quant_tensor, scale, num_warps):
scale_layout, scale_layout_opts = layout.make_default_matmul_mxfp4_w_scale_layout(
mx_axis=1, num_warps=num_warps
)
if is_cuda() and torch.cuda.get_device_capability()[0] == 10:
if _is_sm100_supported:
constraints = {
"is_persistent": True,
"epilogue_subtile": 1,
@@ -331,8 +331,9 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
def process_weights_after_loading(self, layer):
if self.use_flashinfer:
logger.info(
"Shuffling MoE weights for FlashInfer, it might take a while..."
log_info_on_rank0(
logger,
"Shuffling MoE weights for FlashInfer MXFP4 moe kernel, it might take a while...",
)
layer.gemm1_alpha = Parameter(
torch.tensor([1.702] * self.num_experts, dtype=torch.float32).cuda(),

View File

@@ -488,8 +488,14 @@ class ServerArgs:
if is_sm100_supported() and is_mxfp4_quant_format:
self.enable_flashinfer_mxfp4_moe = True
self.enable_triton_kernel_moe = False
logger.info(
"Detected SM100 and MXFP4 quantization format for GPT-OSS model, enabling FlashInfer MXFP4 MOE kernel."
)
else:
self.enable_triton_kernel_moe = True
logger.info(
"Detected GPT-OSS model, enabling triton_kernels MOE kernel."
)
self.disable_hybrid_swa_memory = True