Better optimization log for gpt-oss model (#8953)
This commit is contained in:
@@ -24,6 +24,7 @@ from sglang.srt.utils import (
|
|||||||
is_cuda,
|
is_cuda,
|
||||||
is_flashinfer_available,
|
is_flashinfer_available,
|
||||||
is_hip,
|
is_hip,
|
||||||
|
log_info_on_rank0,
|
||||||
next_power_of_2,
|
next_power_of_2,
|
||||||
round_up,
|
round_up,
|
||||||
set_weight_attrs,
|
set_weight_attrs,
|
||||||
@@ -34,7 +35,6 @@ has_triton_kernels = importlib.util.find_spec("triton_kernels") is not None
|
|||||||
|
|
||||||
|
|
||||||
if is_flashinfer_available():
|
if is_flashinfer_available():
|
||||||
# from flashinfer.fused_moe import cutlass_fused_moe
|
|
||||||
from flashinfer import (
|
from flashinfer import (
|
||||||
mxfp8_quantize,
|
mxfp8_quantize,
|
||||||
shuffle_matrix_a,
|
shuffle_matrix_a,
|
||||||
@@ -63,7 +63,7 @@ def _swizzle_mxfp4(quant_tensor, scale, num_warps):
|
|||||||
scale_layout, scale_layout_opts = layout.make_default_matmul_mxfp4_w_scale_layout(
|
scale_layout, scale_layout_opts = layout.make_default_matmul_mxfp4_w_scale_layout(
|
||||||
mx_axis=1, num_warps=num_warps
|
mx_axis=1, num_warps=num_warps
|
||||||
)
|
)
|
||||||
if is_cuda() and torch.cuda.get_device_capability()[0] == 10:
|
if _is_sm100_supported:
|
||||||
constraints = {
|
constraints = {
|
||||||
"is_persistent": True,
|
"is_persistent": True,
|
||||||
"epilogue_subtile": 1,
|
"epilogue_subtile": 1,
|
||||||
@@ -331,8 +331,9 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
|
|||||||
|
|
||||||
def process_weights_after_loading(self, layer):
|
def process_weights_after_loading(self, layer):
|
||||||
if self.use_flashinfer:
|
if self.use_flashinfer:
|
||||||
logger.info(
|
log_info_on_rank0(
|
||||||
"Shuffling MoE weights for FlashInfer, it might take a while..."
|
logger,
|
||||||
|
"Shuffling MoE weights for FlashInfer MXFP4 moe kernel, it might take a while...",
|
||||||
)
|
)
|
||||||
layer.gemm1_alpha = Parameter(
|
layer.gemm1_alpha = Parameter(
|
||||||
torch.tensor([1.702] * self.num_experts, dtype=torch.float32).cuda(),
|
torch.tensor([1.702] * self.num_experts, dtype=torch.float32).cuda(),
|
||||||
|
|||||||
@@ -488,8 +488,14 @@ class ServerArgs:
|
|||||||
if is_sm100_supported() and is_mxfp4_quant_format:
|
if is_sm100_supported() and is_mxfp4_quant_format:
|
||||||
self.enable_flashinfer_mxfp4_moe = True
|
self.enable_flashinfer_mxfp4_moe = True
|
||||||
self.enable_triton_kernel_moe = False
|
self.enable_triton_kernel_moe = False
|
||||||
|
logger.info(
|
||||||
|
"Detected SM100 and MXFP4 quantization format for GPT-OSS model, enabling FlashInfer MXFP4 MOE kernel."
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
self.enable_triton_kernel_moe = True
|
self.enable_triton_kernel_moe = True
|
||||||
|
logger.info(
|
||||||
|
"Detected GPT-OSS model, enabling triton_kernels MOE kernel."
|
||||||
|
)
|
||||||
|
|
||||||
self.disable_hybrid_swa_memory = True
|
self.disable_hybrid_swa_memory = True
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user