From 7722c11c1d2a2da5b914f3e043b7e8fcd182c0f5 Mon Sep 17 00:00:00 2001
From: HAI
Date: Thu, 26 Dec 2024 20:22:14 -0800
Subject: [PATCH] Regression fix to AMD/ROCm from recent change (#2606)

---
 .../srt/layers/moe/fused_moe_triton/fused_moe.py | 11 ++++++++---
 1 file changed, 8 insertions(+), 3 deletions(-)

diff --git a/python/sglang/srt/layers/moe/fused_moe_triton/fused_moe.py b/python/sglang/srt/layers/moe/fused_moe_triton/fused_moe.py
index 108561842..aa649254d 100644
--- a/python/sglang/srt/layers/moe/fused_moe_triton/fused_moe.py
+++ b/python/sglang/srt/layers/moe/fused_moe_triton/fused_moe.py
@@ -11,12 +11,17 @@ from typing import Any, Callable, Dict, List, Optional, Tuple
 import torch
 import triton
 import triton.language as tl
-from sgl_kernel import moe_align_block_size as sgl_moe_align_block_size
 from vllm import _custom_ops as ops
 
 from sglang.srt.layers.moe.topk import select_experts
 from sglang.srt.layers.quantization.fp8_kernel import per_token_group_quant_fp8
-from sglang.srt.utils import direct_register_custom_op, get_device_name
+from sglang.srt.utils import direct_register_custom_op, get_device_name, is_hip
+
+not_hip = False
+if not is_hip():
+    from sgl_kernel import moe_align_block_size as sgl_moe_align_block_size
+
+    not_hip = True
 
 logger = logging.getLogger(__name__)
 padding_size = 128 if bool(int(os.getenv("MOE_PADDING", "0"))) else 0
@@ -268,7 +273,7 @@ def moe_align_block_size(
     )
     num_tokens_post_pad = torch.empty((1), dtype=torch.int32, device=topk_ids.device)
     # FIXME(zhyncs)
-    if num_experts >= 256:
+    if not_hip and num_experts >= 256:
         sgl_moe_align_block_size(
             topk_ids,
             num_experts,