Fix triton_kernels import error on some hardwares (#11831)
This commit is contained in:
@@ -1,6 +1,5 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import importlib.util
|
||||
from typing import TYPE_CHECKING, List, Optional
|
||||
|
||||
import torch
|
||||
@@ -31,8 +30,6 @@ if TYPE_CHECKING:
|
||||
StandardDispatchOutput,
|
||||
)
|
||||
|
||||
has_triton_kernels = importlib.util.find_spec("triton_kernels") is not None
|
||||
|
||||
|
||||
_is_cpu_amx_available = cpu_has_amx_support()
|
||||
_is_hip = is_hip()
|
||||
@@ -143,7 +140,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
|
||||
|
||||
self.triton_kernel_moe_forward = None
|
||||
self.triton_kernel_moe_with_bias_forward = None
|
||||
if torch.cuda.is_available() and has_triton_kernels:
|
||||
if torch.cuda.is_available() and use_triton_kernels:
|
||||
from sglang.srt.layers.moe.fused_moe_triton.triton_kernels_moe import (
|
||||
triton_kernel_moe_forward as _tk_forward,
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user