diff --git a/python/sglang/srt/layers/moe/fused_moe_triton/fused_moe.py b/python/sglang/srt/layers/moe/fused_moe_triton/fused_moe.py
index 01ecce1a6..c0d558085 100644
--- a/python/sglang/srt/layers/moe/fused_moe_triton/fused_moe.py
+++ b/python/sglang/srt/layers/moe/fused_moe_triton/fused_moe.py
@@ -15,18 +15,18 @@ from vllm import _custom_ops as ops
 from sglang.srt.layers.moe.topk import select_experts
 from sglang.srt.layers.quantization.fp8_kernel import per_token_group_quant_fp8
-from sglang.srt.utils import direct_register_custom_op, get_device_name, is_hip
+from sglang.srt.utils import (
+    direct_register_custom_op,
+    get_device_name,
+    is_cuda_available,
+    is_hip,
+)
 
-is_hip_flag = False
-if not is_hip():
-    if torch.cuda.is_available():
-        from sgl_kernel import moe_align_block_size as sgl_moe_align_block_size
-    else:
-        sgl_moe_align_block_size = None
+is_cuda = is_cuda_available()
+is_hip_flag = is_hip()
+if is_cuda:
+    from sgl_kernel import moe_align_block_size as sgl_moe_align_block_size
 
-    is_hip_flag = False
-else:
-    is_hip_flag = True
 
 logger = logging.getLogger(__name__)
 
 padding_size = 128 if bool(int(os.getenv("MOE_PADDING", "0"))) else 0