Add AWQ quantization support for NPU. (#10158)
Co-authored-by: Alisehen <814073252@qq.com>
Co-authored-by: Yaochen Han <48639761+Alisehen@users.noreply.github.com>
Co-authored-by: Zhengda Qin <zhengdqin@gmail.com>
@@ -51,6 +51,7 @@ WEIGHT_LOADER_V2_SUPPORTED = [
    "CompressedTensorsLinearMethod",
    "AWQMarlinLinearMethod",
    "AWQLinearMethod",
    "AWQLinearAscendMethod",
    "GPTQMarlinLinearMethod",
    "Fp8LinearMethod",
    "BlockInt8LinearMethod",

@@ -31,6 +31,7 @@ from sglang.srt.layers.quantization.marlin_utils import (
)
from sglang.srt.layers.quantization.unquant import UnquantizedLinearMethod
from sglang.srt.layers.quantization.utils import get_scalar_types, replace_parameter
from sglang.srt.layers.quantization.w8a8_int8 import npu_fused_experts

if TYPE_CHECKING:
    from sglang.srt.layers.moe.moe_runner import MoeRunnerConfig

@@ -39,11 +40,16 @@ if TYPE_CHECKING:
        CombineInput,
    )

from sglang.srt.utils import is_cuda, is_hip, is_xpu
from sglang.srt.utils import is_cuda, is_hip, is_npu, is_xpu

_is_cuda = is_cuda()
_is_hip = is_hip()
_is_xpu = is_xpu()
_is_npu = is_npu()

if _is_npu:
    import torch_npu

if _is_cuda:
    from sgl_kernel import (
        awq_dequantize,
@@ -117,12 +123,17 @@ class AWQConfig(QuantizationConfig):
        return "awq"

    def get_supported_act_dtypes(self) -> List[torch.dtype]:
        return [torch.half]
        return [torch.float16] if not _is_npu else [torch.float16, torch.bfloat16]

    @classmethod
    def get_min_capability(cls) -> int:
        # The AWQ kernel only supports Turing or newer GPUs.
        return 75
        if _is_npu:
            raise NotImplementedError(
                'NPU hardware does not support "get_min_capability" feature.'
            )
        else:
            return 75

    @staticmethod
    def get_config_filenames() -> List[str]:
@@ -146,6 +157,16 @@ class AWQConfig(QuantizationConfig):
        self, layer: torch.nn.Module, prefix: str
    ) -> Optional[LinearMethodBase]:
        from sglang.srt.layers.linear import LinearBase
        from sglang.srt.layers.moe.fused_moe_triton import FusedMoE

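        # On Ascend, linear layers use AWQLinearAscendMethod and fused MoE layers
        # use AWQMoEAscendMethod; modules listed in modules_to_not_convert stay
        # unquantized.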
        if _is_npu:
            if isinstance(layer, LinearBase):
                if is_layer_skipped_awq(prefix, self.modules_to_not_convert):
                    return UnquantizedLinearMethod()
                return AWQLinearAscendMethod(self)
            elif isinstance(layer, FusedMoE):
                return AWQMoEAscendMethod(self)
            return None

        if isinstance(layer, LinearBase):
            if is_layer_skipped_awq(prefix, self.modules_to_not_convert):

@@ -575,6 +596,64 @@ class AWQMarlinLinearMethod(LinearMethodBase):
        )


class AWQLinearAscendMethod(AWQLinearMethod):
    """Linear method for AWQ on Ascend.

    Args:
        quant_config: The AWQ quantization config.
    """

    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        layer.scales = torch.nn.Parameter(layer.scales.data, requires_grad=False)
        qweight_tmp = torch.zeros_like(layer.qweight.data)
        qzeros_tmp = layer.qzeros.data
        qzeros_list = []
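        # AWQ packs eight 4-bit weights into each int32 using the interleaved
        # column order (0, 2, 4, 6, 1, 3, 5, 7); shifts[i] is the nibble slot that
        # holds logical column i, so the loop below rewrites qweight (and unpacks
        # qzeros) back into natural column order.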
        shifts = [0, 4, 1, 5, 2, 6, 3, 7]
        for i in range(0, self.quant_config.pack_factor):
            shift_num = shifts[i] * 4
            qzeros_list.append((qzeros_tmp.reshape(-1, 1) >> shift_num) & 0xF)
            qweight_tmp.bitwise_or_(
                ((layer.qweight.data >> shift_num) * (2 ** (4 * i))) & (0xF << (4 * i))
            )

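        # Flipping the high bit of every nibble maps unsigned q in [0, 15] to the
        # two's-complement value q - 8. The zero-points below become floating-point
        # offsets 8 - z, so the antiquant path computes ((q - 8) + (8 - z)) * scale,
        # i.e. the usual AWQ dequantization (q - z) * scale.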
        qweight_tmp.bitwise_xor_(0x88888888)

        qzeros_tmp = torch.cat(qzeros_list, dim=-1).reshape(qzeros_tmp.shape[0], -1)
        qzeros_tmp = -(qzeros_tmp - 8)
        qzeros_tmp = qzeros_tmp.to(layer.scales.data.dtype)

        layer.qzeros = torch.nn.Parameter(qzeros_tmp, requires_grad=False)
        layer.qweight = torch.nn.Parameter(qweight_tmp, requires_grad=False)

    def apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        bias: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        qweight = layer.qweight
        scales = layer.scales
        qzeros = layer.qzeros
        pack_factor = self.quant_config.pack_factor
        out_shape = x.shape[:-1] + (qweight.shape[-1] * pack_factor,)
        reshaped_x = x.reshape(-1, x.shape[-1])

        if bias is not None and bias.dtype == torch.bfloat16:
            bias = bias.float()

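        # Weight-only (W4A16) matmul: the NPU kernel dequantizes the packed int4
        # weight on the fly using the per-group antiquant scales and offsets.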
        out = torch_npu.npu_weight_quant_batchmatmul(
            reshaped_x,
            qweight,
            antiquant_scale=scales,
            antiquant_offset=qzeros,
            antiquant_group_size=self.quant_config.group_size,
            bias=bias,
        )

        return out.reshape(out_shape)


class AWQMoEMethod(FusedMoEMethodBase):

    def __init__(self, quant_config: AWQMarlinConfig):

@@ -677,7 +756,8 @@ class AWQMoEMethod(FusedMoEMethodBase):
        set_weight_attrs(w2_qzeros, extra_weight_attrs)

        device = layer.w13_qweight.device
        layer.workspace = marlin_make_workspace(device, 4)
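        # The Marlin workspace is only used by the CUDA kernels, so skip it on NPU.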
        if not _is_npu:
            layer.workspace = marlin_make_workspace(device, 4)

    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        num_experts = layer.w13_qweight.shape[0]

@@ -785,3 +865,95 @@ class AWQMoEMethod(FusedMoEMethodBase):
            num_bits=self.quant_config.weight_bits,
        ).to(orig_dtype)
        return StandardCombineInput(hidden_states=output)


class AWQMoEAscendMethod(AWQMoEMethod):
    def __init__(self, quant_config: AWQConfig):
        self.quant_config = quant_config

    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        w13_qweight_tmp = torch.zeros_like(layer.w13_qweight.data)
        w2_qweight_tmp = torch.zeros_like(layer.w2_qweight.data)
        w13_qzeros_list = []
        w2_qzeros_list = []
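        # Same nibble reordering and sign conversion as AWQLinearAscendMethod,
        # applied per expert to the gate/up (w13) and down (w2) projections.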
        shifts = [0, 4, 1, 5, 2, 6, 3, 7]
        for i in range(0, self.quant_config.pack_factor):
            shift_num = shifts[i] * 4
            w13_qzeros_list.append(
                (layer.w13_qzeros.data.reshape(-1, 1) >> shift_num) & 0xF
            )
            w2_qzeros_list.append(
                (layer.w2_qzeros.data.reshape(-1, 1) >> shift_num) & 0xF
            )
            w13_qweight_tmp.bitwise_or_(
                ((layer.w13_qweight.data >> shift_num) * (2 ** (4 * i)))
                & (0xF << (4 * i))
            )
            w2_qweight_tmp.bitwise_or_(
                ((layer.w2_qweight.data >> shift_num) * (2 ** (4 * i)))
                & (0xF << (4 * i))
            )

        w13_qweight_tmp.bitwise_xor_(0x88888888)
        w2_qweight_tmp.bitwise_xor_(0x88888888)

        w13_qzeros_tmp = torch.cat(w13_qzeros_list, dim=-1).reshape(
            layer.w13_qzeros.shape[0], layer.w13_qzeros.shape[1], -1
        )
        w13_qzeros_tmp = -(w13_qzeros_tmp - 8)
        w13_qzeros_tmp = w13_qzeros_tmp.to(layer.w13_scales.data.dtype)
        w2_qzeros_tmp = torch.cat(w2_qzeros_list, dim=-1).reshape(
            layer.w2_qzeros.shape[0], layer.w2_qzeros.shape[1], -1
        )
        w2_qzeros_tmp = -(w2_qzeros_tmp - 8)
        w2_qzeros_tmp = w2_qzeros_tmp.to(layer.w2_scales.data.dtype)

        layer.register_parameter(
            "w13_qzeros", torch.nn.Parameter(w13_qzeros_tmp, requires_grad=False)
        )
        layer.register_parameter(
            "w13_qweight", torch.nn.Parameter(w13_qweight_tmp, requires_grad=False)
        )
        layer.register_parameter(
            "w2_qzeros", torch.nn.Parameter(w2_qzeros_tmp, requires_grad=False)
        )
        layer.register_parameter(
            "w2_qweight", torch.nn.Parameter(w2_qweight_tmp, requires_grad=False)
        )

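    # No Marlin/Triton runner is set up on Ascend; apply() below calls
    # npu_fused_experts directly.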
    def create_moe_runner(
        self, layer: torch.nn.Module, moe_runner_config: MoeRunnerConfig
    ):
        self.moe_runner_config = moe_runner_config

    def apply(
        self,
        layer: torch.nn.Module,
        dispatch_output: StandardDispatchOutput,
    ) -> torch.Tensor:
        from sglang.srt.layers.moe.token_dispatcher import StandardCombineInput

        assert (
            self.moe_runner_config.activation == "silu"
        ), "Only SiLU activation is supported."

        x = dispatch_output.hidden_states
        topk_output = dispatch_output.topk_output

        topk_weights, topk_ids, _ = topk_output
        topk_ids = topk_ids.to(torch.int32)
        topk_weights = topk_weights.to(x.dtype)
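        # use_wna16=True selects the weight-only int4 path inside npu_fused_experts:
        # the expert weights stay packed and the qzeros/scales are passed as
        # antiquant offsets/scales.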
        output = npu_fused_experts(
            hidden_states=x,
            w13=layer.w13_qweight,
            w13_scale=layer.w13_scales,
            w13_offset=layer.w13_qzeros,
            w2=layer.w2_qweight,
            w2_scale=layer.w2_scales,
            w2_offset=layer.w2_qzeros,
            topk_weights=topk_weights,
            topk_ids=topk_ids,
            top_k=topk_ids.shape[1],
            use_wna16=True,
        )
        return StandardCombineInput(hidden_states=output)

@@ -337,3 +337,32 @@ def awq_gemm_triton(
    result = result.sum(0)

    return result


def awq_dequantize_decomposition(
    qweight: torch.Tensor,
    scales: torch.Tensor,
    zeros: torch.Tensor,
) -> torch.Tensor:
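    # Pure-PyTorch AWQ dequantization used on NPU in place of the CUDA
    # awq_dequantize kernel: unpack the interleaved nibbles, then compute
    # (q - z) * scales per quantization group.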
    qweight_tmp = qweight
    qzeros_tmp = zeros
    qweight_list = []
    qzeros_list = []
    shifts = [0, 4, 1, 5, 2, 6, 3, 7]
    for i in range(0, 8):
        shift_num = shifts[i] * 4
        qzeros_list.append((qzeros_tmp.reshape(-1, 1) >> shift_num) & 0xF)
        qweight_list.append((qweight_tmp.reshape(-1, 1) >> shift_num) & 0xF)
    qzeros_tmp = (
        torch.cat(qzeros_list, dim=-1).reshape(qzeros_tmp.shape[0], -1).to(scales.dtype)
    )
    qweight_tmp = (
        torch.cat(qweight_list, dim=-1)
        .reshape(qweight_tmp.shape[0], -1)
        .to(scales.dtype)
    )
    res = (
        qweight_tmp.reshape(qzeros_tmp.shape[0], -1, qzeros_tmp.shape[1])
        - qzeros_tmp.unsqueeze(1)
    ) * scales.unsqueeze(1)
    return res.reshape(qweight_tmp.shape[0], -1)
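
# A rough sanity check of the decomposition (illustrative only; the shapes and
# group size below are assumptions, not part of this change):
#
#     K, N, group = 128, 256, 64
#     qweight = torch.randint(0, 2**31, (K, N // 8), dtype=torch.int32)
#     zeros = torch.randint(0, 2**31, (K // group, N // 8), dtype=torch.int32)
#     scales = torch.rand(K // group, N, dtype=torch.float16)
#     w = awq_dequantize_decomposition(qweight, scales, zeros)
#     assert w.shape == (K, N)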

@@ -102,7 +102,12 @@ def npu_fused_experts(
    topk_weights: torch.Tensor,
    topk_ids: torch.Tensor,
    top_k: int,
    **kwargs,
):
    w13_offset = kwargs.get("w13_offset", None)
    w2_offset = kwargs.get("w2_offset", None)
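    # use_wna16 switches the grouped matmuls to the weight-only int4 (W4A16) path:
    # activations stay in fp16/bf16 and the weights are dequantized in-kernel via
    # antiquant scale/offset; otherwise activations are dynamically quantized to
    # int8 as before.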
    use_wna16 = kwargs.get("use_wna16", False)

    original_shape = hidden_states.shape
    original_dtype = hidden_states.dtype
    scale_dtype = original_dtype if original_dtype == torch.bfloat16 else torch.float32

@@ -127,12 +132,22 @@ def npu_fused_experts(
    )
    expert_tokens = expert_tokens.to(torch.int64)
    # gmm1: gate_up_proj
    hidden_states, pertoken_scale = torch_npu.npu_dynamic_quant(hidden_states)
    if not use_wna16:
        hidden_states, pertoken_scale = torch_npu.npu_dynamic_quant(hidden_states)
        scale_args13 = {
            "scale": [w13_scale.to(scale_dtype)],
            "per_token_scale": [pertoken_scale],
        }
    else:
        scale_args13 = {
            "antiquant_scale": [w13_scale],
            "antiquant_offset": [w13_offset],
        }

    hidden_states = torch_npu.npu_grouped_matmul(
        x=[hidden_states],
        weight=[w13],
        scale=[w13_scale.to(scale_dtype)],
        per_token_scale=[pertoken_scale],
        **scale_args13,
        split_item=2,
        group_list_type=0,
        group_type=0,

@@ -141,13 +156,20 @@ def npu_fused_experts(
    )[0]
    # act_fn: swiglu
    hidden_states = torch_npu.npu_swiglu(hidden_states)
    hidden_states, pertoken_scale = torch_npu.npu_dynamic_quant(hidden_states)
    if not use_wna16:
        hidden_states, pertoken_scale = torch_npu.npu_dynamic_quant(hidden_states)

        scale_args2 = {
            "scale": [w2_scale.to(scale_dtype)],
            "per_token_scale": [pertoken_scale],
        }
    else:
        scale_args2 = {"antiquant_scale": [w2_scale], "antiquant_offset": [w2_offset]}
    # gmm2: down_proj
    hidden_states = torch_npu.npu_grouped_matmul(
        x=[hidden_states],
        weight=[w2],
        scale=[w2_scale.to(scale_dtype)],
        per_token_scale=[pertoken_scale],
        **scale_args2,
        split_item=2,
        group_list_type=0,
        group_type=0,

@@ -612,6 +612,8 @@ class DefaultModelLoader(BaseModelLoader):
        # parameters onto device for processing and back off after.
        with device_loading_context(module, target_device):
            quant_method.process_weights_after_loading(module)
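            # Free cached NPU memory after each module's post-processing; the AWQ
            # repacking creates large temporaries.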
            if _is_npu:
                torch.npu.empty_cache()


class LayeredModelLoader(DefaultModelLoader):

@@ -189,6 +189,10 @@ elif _is_npu:
    import custom_ops  # noqa: F401
    import sgl_kernel_npu  # noqa: F401
    import torch_npu  # noqa: F401

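    # There is no fused awq_dequantize kernel on NPU, so alias the pure-PyTorch
    # decomposition under the same name; this lets the AWQ-compatible kv_b_proj
    # dequantization run on NPU as well.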
    from sglang.srt.layers.quantization.awq_triton import (
        awq_dequantize_decomposition as awq_dequantize,
    )
else:
    pass

@@ -2965,7 +2969,7 @@ class DeepseekV2ForCausalLM(nn.Module):
                )
                if hasattr(self_attn.kv_b_proj, "qweight"):
                    # AWQ compatible
                    if _is_cuda or _is_hip:
                    if _is_cuda or _is_hip or _is_npu:
                        w = awq_dequantize(
                            self_attn.kv_b_proj.qweight,
                            self_attn.kv_b_proj.scales,

@@ -510,6 +510,8 @@ def get_available_gpu_memory(
                f"WARNING: current device is not {gpu_id}, but {torch.npu.current_device()}, ",
                "which may cause useless memory allocation for torch NPU context.",
            )
        if empty_cache:
            torch.npu.empty_cache()
        free_gpu_memory, total_gpu_memory = torch.npu.mem_get_info()

    if distributed: