feat: check for is_cuda for sgl_kernel import (#2984)
@@ -15,18 +15,18 @@ from vllm import _custom_ops as ops
 from sglang.srt.layers.moe.topk import select_experts
 from sglang.srt.layers.quantization.fp8_kernel import per_token_group_quant_fp8
-from sglang.srt.utils import direct_register_custom_op, get_device_name, is_hip
+from sglang.srt.utils import (
+    direct_register_custom_op,
+    get_device_name,
+    is_cuda_available,
+    is_hip,
+)
 
-is_hip_flag = False
-if not is_hip():
-    if torch.cuda.is_available():
-        from sgl_kernel import moe_align_block_size as sgl_moe_align_block_size
-    else:
-        sgl_moe_align_block_size = None
-
-    is_hip_flag = False
-else:
-    is_hip_flag = True
+is_cuda = is_cuda_available()
+is_hip_flag = is_hip()
+if is_cuda:
+    from sgl_kernel import moe_align_block_size as sgl_moe_align_block_size
 
 logger = logging.getLogger(__name__)
 padding_size = 128 if bool(int(os.getenv("MOE_PADDING", "0"))) else 0
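For context, the pattern this commit lands is a single module-level device probe that gates the optional compiled-kernel import, replacing the earlier nested is_hip() / torch.cuda.is_available() branching. Below is a minimal sketch of the idea, assuming is_cuda_available() amounts to torch's runtime probe gated on a CUDA build; the actual helper in sglang.srt.utils may be implemented differently.

    import torch

    def is_cuda_available() -> bool:
        # Sketch (assumption, not the verbatim sglang helper):
        # True only when torch was built against CUDA and a CUDA
        # device is visible at runtime. On ROCm builds,
        # torch.version.cuda is None, so this returns False there.
        return torch.version.cuda is not None and torch.cuda.is_available()

    is_cuda = is_cuda_available()

    if is_cuda:
        # sgl_kernel is a compiled extension; importing it on a
        # non-CUDA build would fail, so the import stays guarded.
        from sgl_kernel import moe_align_block_size as sgl_moe_align_block_size

With this shape, ROCm and CPU builds never touch sgl_kernel at import time, and is_hip_flag is computed directly from is_hip() instead of being assigned separately in both branches.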