[main] addrmsnorm + quant fusion optim in Dense Models (#2772)

### What this PR does / why we need it?
This PR fused addrmsnorm op and w8a8 quant op to get better perf.

### Does this PR introduce _any_ user-facing change?
No.

### How was this patch tested?
CI passed with new added/existing test.

- vLLM version: v0.10.2
- vLLM main:
0faf3cc3e8

Signed-off-by: rjg-lyh <1318825571@qq.com>
This commit is contained in:
rjg-lyh
2025-09-16 22:31:38 +08:00
committed by GitHub
parent 88ca8a051c
commit 6b7117dbb7
5 changed files with 211 additions and 270 deletions

View File

@@ -18,47 +18,40 @@
from typing import Optional, Tuple, Union, cast
import torch
from vllm.forward_context import get_forward_context
from vllm.model_executor.layers.layernorm import RMSNorm
class AddRMSNormW8A8Quant(RMSNorm):
# Fuse AddRmsNorm and W8A8 quantization ops together
def _addrmsnorm_forward_oot(
self,
x: torch.Tensor,
residual: torch.Tensor,
layer: Optional[torch.nn.Module] = None,
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
import torch_npu
def __init__(
self,
hidden_size: int,
layer: torch.nn.Module,
eps: float = 1e-6,
var_hidden_size: Optional[int] = None,
has_weight: bool = True,
dtype: Optional[torch.dtype] = None,
) -> None:
super().__init__(hidden_size, eps, var_hidden_size, has_weight, dtype)
self.layer = layer
from vllm_ascend.utils import is_310p
def forward(
self,
x: torch.Tensor,
residual: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
import torch_npu
if residual is not None:
residual = torch.ops.vllm.maybe_chunk_residual(x, residual)
assert x.size(0) == residual.size(0)
x, _, residual = torch_npu.npu_add_rms_norm_quant(
x,
residual,
self.weight,
self.layer.aclnn_input_scale,
self.layer.aclnn_input_offset,
epsilon=self.variance_epsilon)
torch.ops.vllm.maybe_wait_prefetch_done(x)
return x, residual
x, residual = torch_npu.npu_rms_norm(x, self.weight,
self.variance_epsilon)
return x
if layer is not None and not is_310p():
x, _, residual = torch_npu.npu_add_rms_norm_quant(
x,
residual,
self.weight,
layer.aclnn_input_scale,
layer.aclnn_input_offset,
epsilon=self.variance_epsilon)
else:
if is_310p():
orig_dtype = residual.dtype
x = x + residual.to(x.dtype)
residual = x.to(orig_dtype)
x, _ = torch_npu.npu_rms_norm(x, self.weight,
self.variance_epsilon)
else:
x, _, residual = torch_npu.npu_add_rms_norm(
x, residual, self.weight, self.variance_epsilon)
torch.ops.vllm.maybe_wait_prefetch_done(x)
return x, residual
class AscendRMSNorm(RMSNorm):
@@ -70,26 +63,49 @@ class AscendRMSNorm(RMSNorm):
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
import torch_npu
from vllm_ascend.utils import is_310p
if residual is not None:
residual = torch.ops.vllm.maybe_chunk_residual(x, residual)
assert x.size(0) == residual.size(0)
if is_310p():
orig_dtype = residual.dtype
x = x + residual.to(x.dtype)
residual = x.to(orig_dtype)
x, _ = torch_npu.npu_rms_norm(x, self.weight,
self.variance_epsilon)
else:
x, _, residual = torch_npu.npu_add_rms_norm(
x, residual, self.weight, self.variance_epsilon)
torch.ops.vllm.maybe_wait_prefetch_done(x)
x, residual = _addrmsnorm_forward_oot(
self, x, residual, self.next_need_quant_fusion_linear)
return x, residual
x, residual = torch_npu.npu_rms_norm(x, self.weight,
self.variance_epsilon)
return x
@property
def next_need_quant_fusion_linear(self):
try:
forward_context = get_forward_context()
if not forward_context.addrmsnorm_quant_fusion_enabled or \
forward_context.layer_idx == forward_context.num_hidden_layers:
return None
except AssertionError:
return None
next_linear = None
model_instance = forward_context.model_instance
layer_idx = forward_context.layer_idx
fusion_linear = forward_context.fusion_linear
next_linear = None
if fusion_linear == "qkv_dense":
next_linear = model_instance.model.layers[
layer_idx].self_attn.qkv_proj
forward_context.fusion_linear = "gate_up_dense"
elif fusion_linear == "gate_up_dense":
next_linear = model_instance.model.layers[
layer_idx].mlp.gate_up_proj
forward_context.fusion_linear = "qkv_dense"
# if prefetch_mlp_weight enabled, following accumulation operation
# does not need to be repeated
if not forward_context.prefetch_mlp_enabled:
forward_context.layer_idx += 1
from vllm_ascend.quantization.w8a8 import AscendW8A8LinearMethod
if next_linear is not None and \
not isinstance(next_linear.quant_method.quant_method, AscendW8A8LinearMethod):
next_linear = None
return next_linear
class AscendQuantRMSNorm(AscendRMSNorm):