ROCm: sgl-kernel enablement starting with sgl_moe_align_block (#3287)

Author: HAI
Date: 2025-02-04 05:44:44 -08:00 (committed by GitHub)
Parent: d39899e85c
Commit: 2c1a695ff1
6 changed files with 131 additions and 13 deletions


@@ -15,18 +15,10 @@ from vllm import _custom_ops as ops
 from sglang.srt.layers.moe.topk import select_experts
 from sglang.srt.layers.quantization.fp8_kernel import per_token_group_quant_fp8
-from sglang.srt.utils import (
-    direct_register_custom_op,
-    get_device_name,
-    is_cuda_available,
-    is_hip,
-)
+from sglang.srt.utils import direct_register_custom_op, get_device_name, is_hip
-is_cuda = is_cuda_available()
 is_hip_flag = is_hip()
-if is_cuda:
-    from sgl_kernel import moe_align_block_size as sgl_moe_align_block_size
+from sgl_kernel import moe_align_block_size as sgl_moe_align_block_size
 logger = logging.getLogger(__name__)
 padding_size = 128 if bool(int(os.getenv("MOE_PADDING", "0"))) else 0
@@ -415,7 +407,7 @@ def moe_align_block_size(
     )
     num_tokens_post_pad = torch.empty((1), dtype=torch.int32, device=topk_ids.device)
     if num_experts >= 224:
-        if enable_moe_align_block_size_triton or is_hip_flag:
+        if enable_moe_align_block_size_triton:
             moe_align_block_size_triton(
                 topk_ids,
                 num_experts,
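
Net effect of the two hunks: sgl_kernel's moe_align_block_size is now imported unconditionally (CUDA and ROCm alike), and the Triton kernel becomes a pure opt-in rather than the forced path on HIP. Below is a minimal, self-contained sketch of that dispatch rule; select_alignment_backend is a hypothetical helper written for this note, only the num_experts >= 224 / enable_moe_align_block_size_triton condition is taken from the diff, and the small-expert branch is not shown in these hunks.

# Hypothetical illustration of the dispatch rule after this commit; only the
# condition mirrors the diff above, everything else is an assumption.
def select_alignment_backend(num_experts: int,
                             enable_moe_align_block_size_triton: bool) -> str:
    """Pick which kernel aligns token blocks per expert for large expert counts."""
    if num_experts >= 224:
        if enable_moe_align_block_size_triton:
            # Explicit opt-in only; before this commit, is_hip_flag also
            # forced this branch, so ROCm could never reach sgl_kernel.
            return "triton"
        # Default on CUDA and, after this commit, on ROCm as well.
        return "sgl_kernel"
    # The path for smaller expert counts lies outside the hunks shown above.
    return "other"


if __name__ == "__main__":
    print(select_alignment_backend(256, False))  # sgl_kernel (CUDA and ROCm)
    print(select_alignment_backend(256, True))   # triton (opt-in)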