Fix AMD/ROCm regression introduced by a recent change (#2606)
This commit is contained in:
@@ -11,12 +11,17 @@ from typing import Any, Callable, Dict, List, Optional, Tuple
|
|||||||
import torch
|
import torch
|
||||||
import triton
|
import triton
|
||||||
import triton.language as tl
|
import triton.language as tl
|
||||||
from sgl_kernel import moe_align_block_size as sgl_moe_align_block_size
|
|
||||||
from vllm import _custom_ops as ops
|
from vllm import _custom_ops as ops
|
||||||
|
|
||||||
from sglang.srt.layers.moe.topk import select_experts
|
from sglang.srt.layers.moe.topk import select_experts
|
||||||
from sglang.srt.layers.quantization.fp8_kernel import per_token_group_quant_fp8
|
from sglang.srt.layers.quantization.fp8_kernel import per_token_group_quant_fp8
|
||||||
from sglang.srt.utils import direct_register_custom_op, get_device_name
|
from sglang.srt.utils import direct_register_custom_op, get_device_name, is_hip
|
||||||
|
|
||||||
|
not_hip = False
|
||||||
|
if not is_hip():
|
||||||
|
from sgl_kernel import moe_align_block_size as sgl_moe_align_block_size
|
||||||
|
|
||||||
|
not_hip = True
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
padding_size = 128 if bool(int(os.getenv("MOE_PADDING", "0"))) else 0
|
padding_size = 128 if bool(int(os.getenv("MOE_PADDING", "0"))) else 0
|
||||||
@@ -268,7 +273,7 @@ def moe_align_block_size(
|
|||||||
)
|
)
|
||||||
num_tokens_post_pad = torch.empty((1), dtype=torch.int32, device=topk_ids.device)
|
num_tokens_post_pad = torch.empty((1), dtype=torch.int32, device=topk_ids.device)
|
||||||
# FIXME(zhyncs)
|
# FIXME(zhyncs)
|
||||||
if num_experts >= 256:
|
if not_hip and num_experts >= 256:
|
||||||
sgl_moe_align_block_size(
|
sgl_moe_align_block_size(
|
||||||
topk_ids,
|
topk_ids,
|
||||||
num_experts,
|
num_experts,
|
||||||
|
|||||||
Reference in New Issue
Block a user