[AMD] Add triton awq_dequantize kernel to support AWQ on ROCm (#7661)

This commit is contained in:
Hubert Lu
2025-07-18 14:27:25 -07:00
committed by GitHub
parent c8f31042a8
commit 7750b91ca8
5 changed files with 530 additions and 3 deletions

View File

@@ -43,11 +43,20 @@ try:
except ImportError:
ops = None
from sglang.srt.utils import is_cuda
from sglang.srt.utils import is_cuda, is_hip
# Resolve the platform once at import time; the flags below select which
# `awq_dequantize` implementation this module binds.
_is_cuda = is_cuda()
_is_hip = is_hip()

if _is_cuda:
    # CUDA: both kernels are imported from the sgl_kernel package.
    from sgl_kernel import awq_dequantize, fused_marlin_moe
elif _is_hip:
    # HIP: bind the Triton implementation under the same `awq_dequantize`
    # name so call sites do not need a platform check of their own.
    from sglang.srt.layers.quantization.awq_triton import (
        awq_dequantize_triton as awq_dequantize,
    )

    # NOTE: `fused_marlin_moe` is not bound on this branch.
    warnings.warn("HIP does not support fused_marlin_moe currently.")
else:
    # NOTE: neither `awq_dequantize` nor `fused_marlin_moe` is bound on this
    # branch; the warning is the only import-time signal.
    warnings.warn("Only CUDA and HIP support AWQ currently.")
logger = logging.getLogger(__name__)
@@ -398,7 +407,6 @@ class AWQLinearMethod(LinearMethodBase):
pack_factor = self.quant_config.pack_factor
out_shape = x.shape[:-1] + (qweight.shape[-1] * pack_factor,)
reshaped_x = x.reshape(-1, x.shape[-1])
out = awq_dequantize(qweight, scales, qzeros)
out = torch.matmul(reshaped_x, out)