[main] addrmsnorm + quant fusion optim in Dense Models (#2772)
### What this PR does / why we need it?
This PR fused addrmsnorm op and w8a8 quant op to get better perf.
### Does this PR introduce _any_ user-facing change?
No.
### How was this patch tested?
CI passed with new added/existing test.
- vLLM version: v0.10.2
- vLLM main:
0faf3cc3e8
Signed-off-by: rjg-lyh <1318825571@qq.com>
This commit is contained in:
@@ -1,19 +1,17 @@
|
||||
from unittest.mock import patch
|
||||
import unittest
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
from pytest_mock import MockerFixture
|
||||
from vllm.model_executor.layers.layernorm import RMSNorm
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def dummy_tensor():
|
||||
return torch.randn(4, 8, dtype=torch.float16)
|
||||
from tests.ut.base import PytestBase
|
||||
from vllm_ascend.quantization.w8a8 import AscendW8A8LinearMethod
|
||||
|
||||
|
||||
def mock_maybe_chunk_residual(x, residual):
|
||||
if x.size(0) != residual.size(0):
|
||||
return residual[:4]
|
||||
|
||||
return residual
|
||||
|
||||
|
||||
@@ -25,69 +23,139 @@ def mock_add_rms_norm(x, residual, weight, eps):
|
||||
return 2 * x, None, 2 * residual
|
||||
|
||||
|
||||
@pytest.mark.parametrize("is_310p_return", [True, False])
|
||||
@pytest.mark.parametrize("residual",
|
||||
[None, torch.randn(4, 8, dtype=torch.float32)])
|
||||
@patch("torch_npu.npu_rms_norm", side_effect=mock_rms_norm)
|
||||
@patch("torch_npu.npu_add_rms_norm", side_effect=mock_add_rms_norm)
|
||||
@patch("torch.ops.vllm.maybe_wait_prefetch_done", side_effect=lambda x: None)
|
||||
@patch("torch.ops.vllm.maybe_chunk_residual",
|
||||
side_effect=mock_maybe_chunk_residual)
|
||||
def test_RMSNorm_forward(mock_maybe_chunk_residual,
|
||||
mock_maybe_wait_prefetch_done, mock_add_rmsnorm,
|
||||
mock_rmsnorm, is_310p_return, residual, dummy_tensor):
|
||||
def mock_add_rms_norm_quant(x, residual, weight, quant_scale, quant_offset,
|
||||
epsilon):
|
||||
x_out = 2 * x
|
||||
residual_out = 2 * residual
|
||||
x_out_quant = x_out.to(torch.int8)
|
||||
residual_out_quant = residual_out.to(torch.int8)
|
||||
return x_out_quant, None, residual_out_quant
|
||||
|
||||
with patch("vllm_ascend.utils.is_310p", return_value=is_310p_return):
|
||||
|
||||
class TestAscendRMSNorm(PytestBase):
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def context(self, mocker: MockerFixture):
|
||||
mocker.patch("torch.ops.vllm.maybe_chunk_residual",
|
||||
side_effect=mock_maybe_chunk_residual)
|
||||
mocker.patch("torch_npu.npu_rms_norm", side_effect=mock_rms_norm)
|
||||
mocker.patch("torch_npu.npu_add_rms_norm",
|
||||
side_effect=mock_add_rms_norm)
|
||||
mocker.patch("torch_npu.npu_add_rms_norm_quant",
|
||||
side_effect=mock_add_rms_norm_quant)
|
||||
mocker.patch("torch.ops.vllm.maybe_wait_prefetch_done",
|
||||
side_effect=lambda x: None)
|
||||
|
||||
# Test case for the most common and basic scenario
|
||||
@pytest.mark.parametrize(
|
||||
"residual", [None, torch.randn(4, 8, dtype=torch.float16)])
|
||||
def test_forward_oot_basic(self, residual):
|
||||
layer = RMSNorm(hidden_size=8, eps=1e-05)
|
||||
x = torch.randn(4, 8, dtype=torch.float16)
|
||||
if residual is not None:
|
||||
out_x, out_residual = layer.forward_oot(dummy_tensor, residual)
|
||||
x_out, residual_out = layer.forward_oot(x, residual)
|
||||
|
||||
if is_310p_return:
|
||||
expected_arg_x = dummy_tensor + residual.to(dummy_tensor.dtype)
|
||||
expected_out_x = expected_arg_x + 1
|
||||
expected_out_residual = expected_arg_x.to(residual.dtype)
|
||||
x_out_expected = 2 * x
|
||||
residual_out_expected = 2 * residual
|
||||
|
||||
mock_maybe_chunk_residual.assert_called_once()
|
||||
mock_rmsnorm.assert_called_once()
|
||||
mock_maybe_wait_prefetch_done.assert_called_once()
|
||||
assert torch.allclose(out_x, expected_out_x)
|
||||
assert torch.allclose(out_residual, expected_out_residual)
|
||||
else:
|
||||
expected_out_x = 2 * dummy_tensor
|
||||
expected_out_residual = 2 * residual
|
||||
mock_maybe_chunk_residual.assert_called_once()
|
||||
mock_add_rmsnorm.assert_called_once()
|
||||
mock_maybe_wait_prefetch_done.assert_called_once()
|
||||
assert torch.allclose(out_x, expected_out_x)
|
||||
assert torch.allclose(out_residual, expected_out_residual)
|
||||
assert torch.allclose(x_out, x_out_expected)
|
||||
assert torch.allclose(residual_out, residual_out_expected)
|
||||
else:
|
||||
out_x = layer.forward(dummy_tensor, residual)
|
||||
expected_out_x = dummy_tensor + 1
|
||||
x_out = layer.forward(x, residual)
|
||||
x_out_expected = x + 1
|
||||
|
||||
mock_rmsnorm.assert_called_once()
|
||||
assert torch.allclose(out_x, expected_out_x)
|
||||
assert torch.allclose(x_out, x_out_expected)
|
||||
|
||||
# Test case for flashcomm_v1 scenario
|
||||
def test_forward_oot_with_flashcomm_v1(self):
|
||||
layer = RMSNorm(hidden_size=512, eps=1e-05)
|
||||
x = torch.randn(4, 512, dtype=torch.bfloat16)
|
||||
residual = torch.randn(16, 512, dtype=torch.bfloat16)
|
||||
|
||||
x_out, residual_out = layer.forward_oot(x, residual)
|
||||
|
||||
x_out_expected = 2 * x
|
||||
residual_out_expected = 2 * residual[:4]
|
||||
|
||||
assert residual_out.size(0) == 4
|
||||
assert torch.allclose(x_out, x_out_expected)
|
||||
assert torch.allclose(residual_out, residual_out_expected)
|
||||
|
||||
# Test case for addrmsnorm + w8a8 quant fusion
|
||||
def test_forward_oot_with_quant_fusion(self, mocker: MockerFixture):
|
||||
mock_is_310p = mocker.patch("vllm_ascend.utils.is_310p")
|
||||
mock_is_310p.return_value = False
|
||||
mock_get_forward_context = mocker.patch(
|
||||
"vllm_ascend.ops.layernorm.get_forward_context")
|
||||
|
||||
# Simulating a scenario with quant_fusion enabled
|
||||
mock_forward_context = mocker.MagicMock()
|
||||
|
||||
mock_model_instance = mocker.MagicMock()
|
||||
mock_forward_context.model_instance = mock_model_instance
|
||||
mock_model_instance.model.layers = [
|
||||
mocker.MagicMock() for _ in range(2)
|
||||
]
|
||||
|
||||
mock_layer_0 = mock_model_instance.model.layers[0]
|
||||
mock_layer_0.self_attn.qkv_proj = mocker.MagicMock()
|
||||
mock_layer_0.mlp.gate_up_proj = mocker.MagicMock()
|
||||
|
||||
mock_layer_1 = mock_model_instance.model.layers[1]
|
||||
mock_layer_1.self_attn.qkv_proj = mocker.MagicMock()
|
||||
mock_layer_1.mlp.gate_up_proj = mocker.MagicMock()
|
||||
|
||||
mock_quant_method_0_qkv = mocker.MagicMock()
|
||||
mock_quant_method_0_qkv.quant_method = AscendW8A8LinearMethod()
|
||||
mock_quant_method_0_gate_up = mocker.MagicMock()
|
||||
mock_quant_method_0_gate_up.quant_method = AscendW8A8LinearMethod()
|
||||
mock_layer_0.self_attn.qkv_proj.quant_method = mock_quant_method_0_qkv
|
||||
mock_layer_0.mlp.gate_up_proj.quant_method = mock_quant_method_0_gate_up
|
||||
|
||||
mock_quant_method_1_qkv = mocker.MagicMock()
|
||||
mock_quant_method_1_qkv.quant_method = AscendW8A8LinearMethod()
|
||||
mock_quant_method_1_gate_up = mocker.MagicMock()
|
||||
mock_quant_method_1_gate_up.quant_method = AscendW8A8LinearMethod()
|
||||
mock_layer_1.self_attn.qkv_proj.quant_method = mock_quant_method_1_qkv
|
||||
mock_layer_1.mlp.gate_up_proj.quant_method = mock_quant_method_1_gate_up
|
||||
|
||||
mock_get_forward_context.return_value = mock_forward_context
|
||||
|
||||
mock_forward_context.addrmsnorm_quant_fusion_enabled = True
|
||||
mock_forward_context.prefetch_mlp_enabled = False
|
||||
mock_forward_context.layer_idx = 0
|
||||
mock_forward_context.num_hidden_layers = 2
|
||||
mock_forward_context.fusion_linear = "gate_up_dense"
|
||||
|
||||
# Ensure fusion and layer_idx increment are handled correctly
|
||||
x = torch.randn(4, 8, dtype=torch.float16)
|
||||
residual = torch.randn(4, 8, dtype=torch.float16)
|
||||
layer = RMSNorm(hidden_size=8, eps=1e-05)
|
||||
|
||||
x_out, residual_out = layer.forward_oot(x, residual)
|
||||
|
||||
assert mock_get_forward_context.call_count == 1
|
||||
assert mock_forward_context.fusion_linear == "qkv_dense"
|
||||
assert mock_forward_context.layer_idx == 1
|
||||
|
||||
x_out, residual_out = layer.forward_oot(x, residual)
|
||||
|
||||
assert mock_get_forward_context.call_count == 2
|
||||
assert mock_forward_context.fusion_linear == "gate_up_dense"
|
||||
assert mock_forward_context.layer_idx == 1
|
||||
|
||||
x_out, residual_out = layer.forward_oot(x, residual)
|
||||
|
||||
assert mock_get_forward_context.call_count == 3
|
||||
assert mock_forward_context.fusion_linear == "qkv_dense"
|
||||
assert mock_forward_context.layer_idx == 2
|
||||
|
||||
x_out, residual_out = layer.forward_oot(x, residual)
|
||||
|
||||
assert mock_get_forward_context.call_count == 4
|
||||
assert mock_forward_context.fusion_linear == "qkv_dense"
|
||||
assert mock_forward_context.layer_idx == 2
|
||||
|
||||
|
||||
@patch("vllm_ascend.utils.is_310p", return_value=False)
|
||||
@patch("torch_npu.npu_add_rms_norm", side_effect=mock_add_rms_norm)
|
||||
@patch("torch.ops.vllm.maybe_wait_prefetch_done", side_effect=lambda x: None)
|
||||
@patch("torch.ops.vllm.maybe_chunk_residual",
|
||||
side_effect=mock_maybe_chunk_residual)
|
||||
def test_RMSNorm_forward_with_flashcomm_v1(mock_maybe_chunk_residual,
|
||||
mock_maybe_wait_prefetch_done,
|
||||
mock_add_rms_norm, mock_is310p):
|
||||
x = torch.randn(4, 512, dtype=torch.bfloat16)
|
||||
residual = torch.randn(16, 512, dtype=torch.bfloat16)
|
||||
layer = RMSNorm(hidden_size=512, eps=1e-05)
|
||||
|
||||
out_x, out_residual = layer.forward_oot(x, residual)
|
||||
|
||||
expected_out_x = 2 * x
|
||||
expected_out_residual = 2 * residual[:4]
|
||||
|
||||
mock_maybe_chunk_residual.assert_called_once()
|
||||
mock_add_rms_norm.assert_called_once()
|
||||
mock_maybe_wait_prefetch_done.assert_called_once()
|
||||
assert out_residual.size(0) == 4
|
||||
assert torch.allclose(out_x, expected_out_x)
|
||||
assert torch.allclose(out_residual, expected_out_residual)
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
|
||||
@@ -129,6 +129,22 @@ def set_ascend_forward_context(
|
||||
forward_context.prefetch_mlp_down_proj = False
|
||||
forward_context.prefetch_mlp_enabled = prefetch_mlp_enabled
|
||||
|
||||
# TODO(rjg-lyh): The current implementation is somewhat brute force and not elegant.
|
||||
# It will be improved later by implementing operator fusion through the FX graph.
|
||||
#
|
||||
# set for addrmsnorm+quant fusion.
|
||||
# this optim now just support dense models due to the specific operators used.
|
||||
# Once the necessary conditions are met, support for MOE models will also be added.
|
||||
from vllm_ascend.quantization.quant_config import AscendQuantConfig
|
||||
addrmsnorm_quant_fusion_enabled = isinstance(vllm_config.quant_config, AscendQuantConfig) and \
|
||||
vllm_config.model_config.hf_config.model_type in ["llama", "qwen2", "qwen3"] and \
|
||||
forward_context.layer_idx is not None
|
||||
if addrmsnorm_quant_fusion_enabled:
|
||||
forward_context.model_instance = model_instance
|
||||
forward_context.num_hidden_layers = vllm_config.model_config.hf_config.num_hidden_layers
|
||||
forward_context.fusion_linear = "gate_up_dense" if forward_context.layer_idx == 0 else "qkv_dense"
|
||||
forward_context.addrmsnorm_quant_fusion_enabled = addrmsnorm_quant_fusion_enabled
|
||||
|
||||
if num_tokens is None and attn_metadata is not None:
|
||||
num_tokens = attn_metadata.num_actual_tokens
|
||||
|
||||
|
||||
@@ -35,9 +35,6 @@ def register_model():
|
||||
"Qwen3MoeForCausalLM",
|
||||
"vllm_ascend.models.qwen3_moe:CustomQwen3MoeForCausalLM")
|
||||
|
||||
ModelRegistry.register_model(
|
||||
"Qwen3ForCausalLM", "vllm_ascend.models.qwen3:CustomQwen3ForCausalLM")
|
||||
|
||||
# There is no PanguProMoEForCausalLM in vLLM, so we should register it before vLLM config initialization
|
||||
# to make sure the model can be loaded correctly. This register step can be removed once vLLM support PanguProMoEForCausalLM.
|
||||
ModelRegistry.register_model(
|
||||
|
||||
@@ -1,156 +0,0 @@
|
||||
from collections.abc import Iterable
|
||||
from typing import Optional, Union
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
from transformers import Qwen3Config
|
||||
from vllm.compilation.decorators import support_torch_compile
|
||||
from vllm.config import CacheConfig, VllmConfig
|
||||
from vllm.distributed import get_pp_group
|
||||
from vllm.model_executor.layers.logits_processor import LogitsProcessor
|
||||
from vllm.model_executor.layers.quantization import QuantizationConfig
|
||||
from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
|
||||
from vllm.model_executor.models.interfaces import SupportsLoRA, SupportsPP
|
||||
from vllm.model_executor.models.qwen2 import Qwen2Model
|
||||
from vllm.model_executor.models.qwen3 import Qwen3DecoderLayer
|
||||
from vllm.model_executor.models.utils import (AutoWeightsLoader,
|
||||
PPMissingLayer, maybe_prefix)
|
||||
from vllm.model_executor.sampling_metadata import SamplingMetadata
|
||||
from vllm.sequence import IntermediateTensors
|
||||
|
||||
from vllm_ascend.ops.layernorm import AddRMSNormW8A8Quant
|
||||
|
||||
|
||||
class CustomQwen3DecoderLayer(Qwen3DecoderLayer):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: Qwen3Config,
|
||||
cache_config: Optional[CacheConfig] = None,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = "",
|
||||
) -> None:
|
||||
super().__init__(config=config,
|
||||
cache_config=cache_config,
|
||||
quant_config=quant_config,
|
||||
prefix=prefix)
|
||||
if quant_config is None:
|
||||
return
|
||||
|
||||
from vllm_ascend.quantization.quant_config import AscendQuantConfig
|
||||
from vllm_ascend.quantization.w8a8 import AscendW8A8LinearMethod
|
||||
|
||||
assert isinstance(quant_config, AscendQuantConfig), \
|
||||
"Expected quant_config to be an instance of AscendQuantConfig"
|
||||
|
||||
if isinstance(self.self_attn.qkv_proj.quant_method.quant_method,
|
||||
AscendW8A8LinearMethod):
|
||||
self.input_layernorm = AddRMSNormW8A8Quant(
|
||||
config.hidden_size,
|
||||
layer=self.self_attn.qkv_proj,
|
||||
eps=config.rms_norm_eps)
|
||||
if isinstance(self.mlp.gate_up_proj.quant_method.quant_method,
|
||||
AscendW8A8LinearMethod):
|
||||
self.post_attention_layernorm = AddRMSNormW8A8Quant(
|
||||
config.hidden_size,
|
||||
layer=self.mlp.gate_up_proj,
|
||||
eps=config.rms_norm_eps)
|
||||
|
||||
|
||||
ALL_DECODER_LAYER_TYPES = {
|
||||
"attention": CustomQwen3DecoderLayer,
|
||||
}
|
||||
|
||||
|
||||
@support_torch_compile(
|
||||
dynamic_arg_dims={
|
||||
"input_ids": 0,
|
||||
# positions is of shape (3, seq_len) if mrope is enabled for qwen2-vl,
|
||||
# otherwise (seq_len, ).
|
||||
"positions": -1,
|
||||
"intermediate_tensors": 0,
|
||||
"inputs_embeds": 0,
|
||||
})
|
||||
class CustomQwen3Model(Qwen2Model):
|
||||
|
||||
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
|
||||
super().__init__(vllm_config=vllm_config,
|
||||
prefix=prefix,
|
||||
decoder_layer_type=CustomQwen3DecoderLayer)
|
||||
|
||||
|
||||
class CustomQwen3ForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
|
||||
# add `CustomQwen3Model` to init self.model
|
||||
packed_modules_mapping = {
|
||||
"qkv_proj": [
|
||||
"q_proj",
|
||||
"k_proj",
|
||||
"v_proj",
|
||||
],
|
||||
"gate_up_proj": [
|
||||
"gate_proj",
|
||||
"up_proj",
|
||||
],
|
||||
}
|
||||
|
||||
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
|
||||
super().__init__()
|
||||
config = vllm_config.model_config.hf_config
|
||||
quant_config = vllm_config.quant_config
|
||||
lora_config = vllm_config.lora_config
|
||||
|
||||
self.config = config
|
||||
self.lora_config = lora_config
|
||||
|
||||
self.quant_config = quant_config
|
||||
self.model = CustomQwen3Model(vllm_config=vllm_config,
|
||||
prefix=maybe_prefix(prefix, "model"))
|
||||
|
||||
if get_pp_group().is_last_rank:
|
||||
if config.tie_word_embeddings:
|
||||
self.lm_head = self.model.embed_tokens
|
||||
else:
|
||||
self.lm_head = ParallelLMHead(config.vocab_size,
|
||||
config.hidden_size,
|
||||
quant_config=quant_config,
|
||||
prefix=maybe_prefix(
|
||||
prefix, "lm_head"))
|
||||
else:
|
||||
self.lm_head = PPMissingLayer()
|
||||
|
||||
self.logits_processor = LogitsProcessor(config.vocab_size)
|
||||
|
||||
self.make_empty_intermediate_tensors = (
|
||||
self.model.make_empty_intermediate_tensors)
|
||||
|
||||
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
|
||||
return self.model.get_input_embeddings(input_ids)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
positions: torch.Tensor,
|
||||
intermediate_tensors: Optional[IntermediateTensors] = None,
|
||||
inputs_embeds: Optional[torch.Tensor] = None,
|
||||
) -> Union[torch.Tensor, IntermediateTensors]:
|
||||
hidden_states = self.model(input_ids, positions, intermediate_tensors,
|
||||
inputs_embeds)
|
||||
return hidden_states
|
||||
|
||||
def compute_logits(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
sampling_metadata: SamplingMetadata,
|
||||
) -> Optional[torch.Tensor]:
|
||||
logits = self.logits_processor(self.lm_head, hidden_states,
|
||||
sampling_metadata)
|
||||
return logits
|
||||
|
||||
def load_weights(self, weights: Iterable[tuple[str,
|
||||
torch.Tensor]]) -> set[str]:
|
||||
loader = AutoWeightsLoader(
|
||||
self,
|
||||
skip_prefixes=(["lm_head."]
|
||||
if self.config.tie_word_embeddings else None),
|
||||
)
|
||||
return loader.load_weights(weights)
|
||||
@@ -18,47 +18,40 @@
|
||||
from typing import Optional, Tuple, Union, cast
|
||||
|
||||
import torch
|
||||
from vllm.forward_context import get_forward_context
|
||||
from vllm.model_executor.layers.layernorm import RMSNorm
|
||||
|
||||
|
||||
class AddRMSNormW8A8Quant(RMSNorm):
|
||||
# Fuse AddRmsNorm and W8A8 quantization ops together
|
||||
def _addrmsnorm_forward_oot(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
residual: torch.Tensor,
|
||||
layer: Optional[torch.nn.Module] = None,
|
||||
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
|
||||
import torch_npu
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
hidden_size: int,
|
||||
layer: torch.nn.Module,
|
||||
eps: float = 1e-6,
|
||||
var_hidden_size: Optional[int] = None,
|
||||
has_weight: bool = True,
|
||||
dtype: Optional[torch.dtype] = None,
|
||||
) -> None:
|
||||
super().__init__(hidden_size, eps, var_hidden_size, has_weight, dtype)
|
||||
self.layer = layer
|
||||
from vllm_ascend.utils import is_310p
|
||||
|
||||
def forward(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
residual: Optional[torch.Tensor] = None,
|
||||
) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
|
||||
import torch_npu
|
||||
|
||||
if residual is not None:
|
||||
residual = torch.ops.vllm.maybe_chunk_residual(x, residual)
|
||||
assert x.size(0) == residual.size(0)
|
||||
x, _, residual = torch_npu.npu_add_rms_norm_quant(
|
||||
x,
|
||||
residual,
|
||||
self.weight,
|
||||
self.layer.aclnn_input_scale,
|
||||
self.layer.aclnn_input_offset,
|
||||
epsilon=self.variance_epsilon)
|
||||
torch.ops.vllm.maybe_wait_prefetch_done(x)
|
||||
return x, residual
|
||||
|
||||
x, residual = torch_npu.npu_rms_norm(x, self.weight,
|
||||
self.variance_epsilon)
|
||||
return x
|
||||
if layer is not None and not is_310p():
|
||||
x, _, residual = torch_npu.npu_add_rms_norm_quant(
|
||||
x,
|
||||
residual,
|
||||
self.weight,
|
||||
layer.aclnn_input_scale,
|
||||
layer.aclnn_input_offset,
|
||||
epsilon=self.variance_epsilon)
|
||||
else:
|
||||
if is_310p():
|
||||
orig_dtype = residual.dtype
|
||||
x = x + residual.to(x.dtype)
|
||||
residual = x.to(orig_dtype)
|
||||
x, _ = torch_npu.npu_rms_norm(x, self.weight,
|
||||
self.variance_epsilon)
|
||||
else:
|
||||
x, _, residual = torch_npu.npu_add_rms_norm(
|
||||
x, residual, self.weight, self.variance_epsilon)
|
||||
torch.ops.vllm.maybe_wait_prefetch_done(x)
|
||||
return x, residual
|
||||
|
||||
|
||||
class AscendRMSNorm(RMSNorm):
|
||||
@@ -70,26 +63,49 @@ class AscendRMSNorm(RMSNorm):
|
||||
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
|
||||
import torch_npu
|
||||
|
||||
from vllm_ascend.utils import is_310p
|
||||
if residual is not None:
|
||||
residual = torch.ops.vllm.maybe_chunk_residual(x, residual)
|
||||
assert x.size(0) == residual.size(0)
|
||||
if is_310p():
|
||||
orig_dtype = residual.dtype
|
||||
x = x + residual.to(x.dtype)
|
||||
residual = x.to(orig_dtype)
|
||||
x, _ = torch_npu.npu_rms_norm(x, self.weight,
|
||||
self.variance_epsilon)
|
||||
else:
|
||||
x, _, residual = torch_npu.npu_add_rms_norm(
|
||||
x, residual, self.weight, self.variance_epsilon)
|
||||
torch.ops.vllm.maybe_wait_prefetch_done(x)
|
||||
x, residual = _addrmsnorm_forward_oot(
|
||||
self, x, residual, self.next_need_quant_fusion_linear)
|
||||
return x, residual
|
||||
|
||||
x, residual = torch_npu.npu_rms_norm(x, self.weight,
|
||||
self.variance_epsilon)
|
||||
return x
|
||||
|
||||
@property
|
||||
def next_need_quant_fusion_linear(self):
|
||||
try:
|
||||
forward_context = get_forward_context()
|
||||
if not forward_context.addrmsnorm_quant_fusion_enabled or \
|
||||
forward_context.layer_idx == forward_context.num_hidden_layers:
|
||||
return None
|
||||
except AssertionError:
|
||||
return None
|
||||
|
||||
next_linear = None
|
||||
model_instance = forward_context.model_instance
|
||||
layer_idx = forward_context.layer_idx
|
||||
fusion_linear = forward_context.fusion_linear
|
||||
next_linear = None
|
||||
if fusion_linear == "qkv_dense":
|
||||
next_linear = model_instance.model.layers[
|
||||
layer_idx].self_attn.qkv_proj
|
||||
forward_context.fusion_linear = "gate_up_dense"
|
||||
elif fusion_linear == "gate_up_dense":
|
||||
next_linear = model_instance.model.layers[
|
||||
layer_idx].mlp.gate_up_proj
|
||||
forward_context.fusion_linear = "qkv_dense"
|
||||
# if prefetch_mlp_weight enabled, following accumulation operation
|
||||
# does not need to be repeated
|
||||
if not forward_context.prefetch_mlp_enabled:
|
||||
forward_context.layer_idx += 1
|
||||
from vllm_ascend.quantization.w8a8 import AscendW8A8LinearMethod
|
||||
if next_linear is not None and \
|
||||
not isinstance(next_linear.quant_method.quant_method, AscendW8A8LinearMethod):
|
||||
next_linear = None
|
||||
return next_linear
|
||||
|
||||
|
||||
class AscendQuantRMSNorm(AscendRMSNorm):
|
||||
|
||||
|
||||
Reference in New Issue
Block a user