[AMD] Add triton awq_dequantize kernel to support AWQ on ROCm (#7661)
@@ -43,11 +43,20 @@ try:
 except ImportError:
     ops = None
 
-from sglang.srt.utils import is_cuda
+from sglang.srt.utils import is_cuda, is_hip
 
 _is_cuda = is_cuda()
+_is_hip = is_hip()
 
 if _is_cuda:
     from sgl_kernel import awq_dequantize, fused_marlin_moe
+elif _is_hip:
+    from sglang.srt.layers.quantization.awq_triton import (
+        awq_dequantize_triton as awq_dequantize,
+    )
+
+    warnings.warn(f"HIP does not support fused_marlin_moe currently.")
+else:
+    warnings.warn(f"Only CUDA and HIP support AWQ currently.")
 
 logger = logging.getLogger(__name__)
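Note: on HIP the Triton kernel is imported under the same `awq_dequantize` name, so callers are unchanged. For intuition, here is a minimal pure-PyTorch sketch of what an AWQ dequantize computes; this is an illustration, not the kernel from awq_triton.py, and `awq_dequantize_ref`, the shapes, and `group_size=128` are assumptions:

import torch

# AWQ packs 8 4-bit values per int32 in an interleaved order;
# this is the inverse permutation used when unpacking.
AWQ_UNPACK_ORDER = [0, 4, 1, 5, 2, 6, 3, 7]

def awq_dequantize_ref(qweight, scales, qzeros, group_size=128):
    # qweight: (K, N // 8) int32, scales: (K // group_size, N) fp16,
    # qzeros:  (K // group_size, N // 8) int32 -> returns (K, N) fp16
    shifts = torch.tensor([4 * i for i in AWQ_UNPACK_ORDER], device=qweight.device)
    # unpack 8 nibbles per int32, restoring the logical column order
    w = ((qweight.unsqueeze(-1) >> shifts) & 0xF).reshape(qweight.shape[0], -1)
    z = ((qzeros.unsqueeze(-1) >> shifts) & 0xF).reshape(qzeros.shape[0], -1)
    # broadcast per-group zeros/scales down the K dimension, then dequantize
    z = z.repeat_interleave(group_size, dim=0)
    s = scales.repeat_interleave(group_size, dim=0)
    return (w - z).to(s.dtype) * s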
@@ -398,7 +407,6 @@ class AWQLinearMethod(LinearMethodBase):
         pack_factor = self.quant_config.pack_factor
         out_shape = x.shape[:-1] + (qweight.shape[-1] * pack_factor,)
         reshaped_x = x.reshape(-1, x.shape[-1])
 
         out = awq_dequantize(qweight, scales, qzeros)
         out = torch.matmul(reshaped_x, out)
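As a hedged usage check of this dequantize-then-GEMM path (random packed data; `awq_dequantize_ref` is the hypothetical reference sketched earlier, not the library API):

M, K, N, group_size = 4, 256, 512, 128
x = torch.randn(M, K, dtype=torch.float16)
qweight = torch.randint(0, 2**31, (K, N // 8), dtype=torch.int32)
qzeros = torch.randint(0, 2**31, (K // group_size, N // 8), dtype=torch.int32)
scales = torch.randn(K // group_size, N, dtype=torch.float16)

w = awq_dequantize_ref(qweight, scales, qzeros, group_size)  # (K, N) fp16
out = torch.matmul(x, w)  # (M, K) @ (K, N) -> (M, N)
assert out.shape == (M, N)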