[main] addrmsnorm + quant fusion optim in Dense Models (#2772)
### What this PR does / why we need it?
This PR fuses the addrmsnorm op and the w8a8 quant op (via `torch_npu.npu_add_rms_norm_quant`) to get better performance in dense models.
### Does this PR introduce _any_ user-facing change?
No.
### How was this patch tested?
CI passed with newly added and existing tests.
- vLLM version: v0.10.2
- vLLM main: 0faf3cc3e8
Signed-off-by: rjg-lyh <1318825571@qq.com>
This commit is contained in:
@@ -1,19 +1,17 @@
|
|||||||
from unittest.mock import patch
|
import unittest
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
import torch
|
import torch
|
||||||
|
from pytest_mock import MockerFixture
|
||||||
from vllm.model_executor.layers.layernorm import RMSNorm
|
from vllm.model_executor.layers.layernorm import RMSNorm
|
||||||
|
|
||||||
|
from tests.ut.base import PytestBase
|
||||||
@pytest.fixture
|
from vllm_ascend.quantization.w8a8 import AscendW8A8LinearMethod
|
||||||
def dummy_tensor():
|
|
||||||
return torch.randn(4, 8, dtype=torch.float16)
|
|
||||||
|
|
||||||
|
|
||||||
def mock_maybe_chunk_residual(x, residual):
|
def mock_maybe_chunk_residual(x, residual):
|
||||||
if x.size(0) != residual.size(0):
|
if x.size(0) != residual.size(0):
|
||||||
return residual[:4]
|
return residual[:4]
|
||||||
|
|
||||||
return residual
|
return residual
|
||||||
|
|
||||||
|
|
||||||
@@ -25,69 +23,139 @@ def mock_add_rms_norm(x, residual, weight, eps):
|
|||||||
return 2 * x, None, 2 * residual
|
return 2 * x, None, 2 * residual
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("is_310p_return", [True, False])
|
def mock_add_rms_norm_quant(x, residual, weight, quant_scale, quant_offset,
|
||||||
@pytest.mark.parametrize("residual",
|
epsilon):
|
||||||
[None, torch.randn(4, 8, dtype=torch.float32)])
|
x_out = 2 * x
|
||||||
@patch("torch_npu.npu_rms_norm", side_effect=mock_rms_norm)
|
residual_out = 2 * residual
|
||||||
@patch("torch_npu.npu_add_rms_norm", side_effect=mock_add_rms_norm)
|
x_out_quant = x_out.to(torch.int8)
|
||||||
@patch("torch.ops.vllm.maybe_wait_prefetch_done", side_effect=lambda x: None)
|
residual_out_quant = residual_out.to(torch.int8)
|
||||||
@patch("torch.ops.vllm.maybe_chunk_residual",
|
return x_out_quant, None, residual_out_quant
|
||||||
side_effect=mock_maybe_chunk_residual)
|
|
||||||
def test_RMSNorm_forward(mock_maybe_chunk_residual,
|
|
||||||
mock_maybe_wait_prefetch_done, mock_add_rmsnorm,
|
|
||||||
mock_rmsnorm, is_310p_return, residual, dummy_tensor):
|
|
||||||
|
|
||||||
with patch("vllm_ascend.utils.is_310p", return_value=is_310p_return):
|
|
||||||
|
class TestAscendRMSNorm(PytestBase):
|
||||||
|
|
||||||
|
@pytest.fixture(autouse=True)
|
||||||
|
def context(self, mocker: MockerFixture):
|
||||||
|
mocker.patch("torch.ops.vllm.maybe_chunk_residual",
|
||||||
|
side_effect=mock_maybe_chunk_residual)
|
||||||
|
mocker.patch("torch_npu.npu_rms_norm", side_effect=mock_rms_norm)
|
||||||
|
mocker.patch("torch_npu.npu_add_rms_norm",
|
||||||
|
side_effect=mock_add_rms_norm)
|
||||||
|
mocker.patch("torch_npu.npu_add_rms_norm_quant",
|
||||||
|
side_effect=mock_add_rms_norm_quant)
|
||||||
|
mocker.patch("torch.ops.vllm.maybe_wait_prefetch_done",
|
||||||
|
side_effect=lambda x: None)
|
||||||
|
|
||||||
|
# Test case for the most common and basic scenario
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"residual", [None, torch.randn(4, 8, dtype=torch.float16)])
|
||||||
|
def test_forward_oot_basic(self, residual):
|
||||||
layer = RMSNorm(hidden_size=8, eps=1e-05)
|
layer = RMSNorm(hidden_size=8, eps=1e-05)
|
||||||
|
x = torch.randn(4, 8, dtype=torch.float16)
|
||||||
if residual is not None:
|
if residual is not None:
|
||||||
out_x, out_residual = layer.forward_oot(dummy_tensor, residual)
|
x_out, residual_out = layer.forward_oot(x, residual)
|
||||||
|
|
||||||
if is_310p_return:
|
x_out_expected = 2 * x
|
||||||
expected_arg_x = dummy_tensor + residual.to(dummy_tensor.dtype)
|
residual_out_expected = 2 * residual
|
||||||
expected_out_x = expected_arg_x + 1
|
|
||||||
expected_out_residual = expected_arg_x.to(residual.dtype)
|
|
||||||
|
|
||||||
mock_maybe_chunk_residual.assert_called_once()
|
assert torch.allclose(x_out, x_out_expected)
|
||||||
mock_rmsnorm.assert_called_once()
|
assert torch.allclose(residual_out, residual_out_expected)
|
||||||
mock_maybe_wait_prefetch_done.assert_called_once()
|
|
||||||
assert torch.allclose(out_x, expected_out_x)
|
|
||||||
assert torch.allclose(out_residual, expected_out_residual)
|
|
||||||
else:
|
|
||||||
expected_out_x = 2 * dummy_tensor
|
|
||||||
expected_out_residual = 2 * residual
|
|
||||||
mock_maybe_chunk_residual.assert_called_once()
|
|
||||||
mock_add_rmsnorm.assert_called_once()
|
|
||||||
mock_maybe_wait_prefetch_done.assert_called_once()
|
|
||||||
assert torch.allclose(out_x, expected_out_x)
|
|
||||||
assert torch.allclose(out_residual, expected_out_residual)
|
|
||||||
else:
|
else:
|
||||||
out_x = layer.forward(dummy_tensor, residual)
|
x_out = layer.forward(x, residual)
|
||||||
expected_out_x = dummy_tensor + 1
|
x_out_expected = x + 1
|
||||||
|
|
||||||
mock_rmsnorm.assert_called_once()
|
assert torch.allclose(x_out, x_out_expected)
|
||||||
assert torch.allclose(out_x, expected_out_x)
|
|
||||||
|
# Test case for flashcomm_v1 scenario
|
||||||
|
def test_forward_oot_with_flashcomm_v1(self):
|
||||||
|
layer = RMSNorm(hidden_size=512, eps=1e-05)
|
||||||
|
x = torch.randn(4, 512, dtype=torch.bfloat16)
|
||||||
|
residual = torch.randn(16, 512, dtype=torch.bfloat16)
|
||||||
|
|
||||||
|
x_out, residual_out = layer.forward_oot(x, residual)
|
||||||
|
|
||||||
|
x_out_expected = 2 * x
|
||||||
|
residual_out_expected = 2 * residual[:4]
|
||||||
|
|
||||||
|
assert residual_out.size(0) == 4
|
||||||
|
assert torch.allclose(x_out, x_out_expected)
|
||||||
|
assert torch.allclose(residual_out, residual_out_expected)
|
||||||
|
|
||||||
|
# Test case for addrmsnorm + w8a8 quant fusion
|
||||||
|
def test_forward_oot_with_quant_fusion(self, mocker: MockerFixture):
|
||||||
|
mock_is_310p = mocker.patch("vllm_ascend.utils.is_310p")
|
||||||
|
mock_is_310p.return_value = False
|
||||||
|
mock_get_forward_context = mocker.patch(
|
||||||
|
"vllm_ascend.ops.layernorm.get_forward_context")
|
||||||
|
|
||||||
|
# Simulating a scenario with quant_fusion enabled
|
||||||
|
mock_forward_context = mocker.MagicMock()
|
||||||
|
|
||||||
|
mock_model_instance = mocker.MagicMock()
|
||||||
|
mock_forward_context.model_instance = mock_model_instance
|
||||||
|
mock_model_instance.model.layers = [
|
||||||
|
mocker.MagicMock() for _ in range(2)
|
||||||
|
]
|
||||||
|
|
||||||
|
mock_layer_0 = mock_model_instance.model.layers[0]
|
||||||
|
mock_layer_0.self_attn.qkv_proj = mocker.MagicMock()
|
||||||
|
mock_layer_0.mlp.gate_up_proj = mocker.MagicMock()
|
||||||
|
|
||||||
|
mock_layer_1 = mock_model_instance.model.layers[1]
|
||||||
|
mock_layer_1.self_attn.qkv_proj = mocker.MagicMock()
|
||||||
|
mock_layer_1.mlp.gate_up_proj = mocker.MagicMock()
|
||||||
|
|
||||||
|
mock_quant_method_0_qkv = mocker.MagicMock()
|
||||||
|
mock_quant_method_0_qkv.quant_method = AscendW8A8LinearMethod()
|
||||||
|
mock_quant_method_0_gate_up = mocker.MagicMock()
|
||||||
|
mock_quant_method_0_gate_up.quant_method = AscendW8A8LinearMethod()
|
||||||
|
mock_layer_0.self_attn.qkv_proj.quant_method = mock_quant_method_0_qkv
|
||||||
|
mock_layer_0.mlp.gate_up_proj.quant_method = mock_quant_method_0_gate_up
|
||||||
|
|
||||||
|
mock_quant_method_1_qkv = mocker.MagicMock()
|
||||||
|
mock_quant_method_1_qkv.quant_method = AscendW8A8LinearMethod()
|
||||||
|
mock_quant_method_1_gate_up = mocker.MagicMock()
|
||||||
|
mock_quant_method_1_gate_up.quant_method = AscendW8A8LinearMethod()
|
||||||
|
mock_layer_1.self_attn.qkv_proj.quant_method = mock_quant_method_1_qkv
|
||||||
|
mock_layer_1.mlp.gate_up_proj.quant_method = mock_quant_method_1_gate_up
|
||||||
|
|
||||||
|
mock_get_forward_context.return_value = mock_forward_context
|
||||||
|
|
||||||
|
mock_forward_context.addrmsnorm_quant_fusion_enabled = True
|
||||||
|
mock_forward_context.prefetch_mlp_enabled = False
|
||||||
|
mock_forward_context.layer_idx = 0
|
||||||
|
mock_forward_context.num_hidden_layers = 2
|
||||||
|
mock_forward_context.fusion_linear = "gate_up_dense"
|
||||||
|
|
||||||
|
# Ensure fusion and layer_idx increment are handled correctly
|
||||||
|
x = torch.randn(4, 8, dtype=torch.float16)
|
||||||
|
residual = torch.randn(4, 8, dtype=torch.float16)
|
||||||
|
layer = RMSNorm(hidden_size=8, eps=1e-05)
|
||||||
|
|
||||||
|
x_out, residual_out = layer.forward_oot(x, residual)
|
||||||
|
|
||||||
|
assert mock_get_forward_context.call_count == 1
|
||||||
|
assert mock_forward_context.fusion_linear == "qkv_dense"
|
||||||
|
assert mock_forward_context.layer_idx == 1
|
||||||
|
|
||||||
|
x_out, residual_out = layer.forward_oot(x, residual)
|
||||||
|
|
||||||
|
assert mock_get_forward_context.call_count == 2
|
||||||
|
assert mock_forward_context.fusion_linear == "gate_up_dense"
|
||||||
|
assert mock_forward_context.layer_idx == 1
|
||||||
|
|
||||||
|
x_out, residual_out = layer.forward_oot(x, residual)
|
||||||
|
|
||||||
|
assert mock_get_forward_context.call_count == 3
|
||||||
|
assert mock_forward_context.fusion_linear == "qkv_dense"
|
||||||
|
assert mock_forward_context.layer_idx == 2
|
||||||
|
|
||||||
|
x_out, residual_out = layer.forward_oot(x, residual)
|
||||||
|
|
||||||
|
assert mock_get_forward_context.call_count == 4
|
||||||
|
assert mock_forward_context.fusion_linear == "qkv_dense"
|
||||||
|
assert mock_forward_context.layer_idx == 2
|
||||||
|
|
||||||
|
|
||||||
@patch("vllm_ascend.utils.is_310p", return_value=False)
|
if __name__ == '__main__':
|
||||||
@patch("torch_npu.npu_add_rms_norm", side_effect=mock_add_rms_norm)
|
unittest.main()
|
||||||
@patch("torch.ops.vllm.maybe_wait_prefetch_done", side_effect=lambda x: None)
|
|
||||||
@patch("torch.ops.vllm.maybe_chunk_residual",
|
|
||||||
side_effect=mock_maybe_chunk_residual)
|
|
||||||
def test_RMSNorm_forward_with_flashcomm_v1(mock_maybe_chunk_residual,
|
|
||||||
mock_maybe_wait_prefetch_done,
|
|
||||||
mock_add_rms_norm, mock_is310p):
|
|
||||||
x = torch.randn(4, 512, dtype=torch.bfloat16)
|
|
||||||
residual = torch.randn(16, 512, dtype=torch.bfloat16)
|
|
||||||
layer = RMSNorm(hidden_size=512, eps=1e-05)
|
|
||||||
|
|
||||||
out_x, out_residual = layer.forward_oot(x, residual)
|
|
||||||
|
|
||||||
expected_out_x = 2 * x
|
|
||||||
expected_out_residual = 2 * residual[:4]
|
|
||||||
|
|
||||||
mock_maybe_chunk_residual.assert_called_once()
|
|
||||||
mock_add_rms_norm.assert_called_once()
|
|
||||||
mock_maybe_wait_prefetch_done.assert_called_once()
|
|
||||||
assert out_residual.size(0) == 4
|
|
||||||
assert torch.allclose(out_x, expected_out_x)
|
|
||||||
assert torch.allclose(out_residual, expected_out_residual)
|
|
||||||
|
|||||||
@@ -129,6 +129,22 @@ def set_ascend_forward_context(
|
|||||||
forward_context.prefetch_mlp_down_proj = False
|
forward_context.prefetch_mlp_down_proj = False
|
||||||
forward_context.prefetch_mlp_enabled = prefetch_mlp_enabled
|
forward_context.prefetch_mlp_enabled = prefetch_mlp_enabled
|
||||||
|
|
||||||
|
# TODO(rjg-lyh): The current implementation is somewhat brute force and not elegant.
|
||||||
|
# It will be improved later by implementing operator fusion through the FX graph.
|
||||||
|
#
|
||||||
|
# Set up forward-context state for addrmsnorm + quant fusion.
|
||||||
|
# This optimization currently supports only dense models due to the specific operators used.
|
||||||
|
# Once the necessary conditions are met, support for MOE models will also be added.
|
||||||
|
from vllm_ascend.quantization.quant_config import AscendQuantConfig
|
||||||
|
addrmsnorm_quant_fusion_enabled = isinstance(vllm_config.quant_config, AscendQuantConfig) and \
|
||||||
|
vllm_config.model_config.hf_config.model_type in ["llama", "qwen2", "qwen3"] and \
|
||||||
|
forward_context.layer_idx is not None
|
||||||
|
if addrmsnorm_quant_fusion_enabled:
|
||||||
|
forward_context.model_instance = model_instance
|
||||||
|
forward_context.num_hidden_layers = vllm_config.model_config.hf_config.num_hidden_layers
|
||||||
|
forward_context.fusion_linear = "gate_up_dense" if forward_context.layer_idx == 0 else "qkv_dense"
|
||||||
|
forward_context.addrmsnorm_quant_fusion_enabled = addrmsnorm_quant_fusion_enabled
|
||||||
|
|
||||||
if num_tokens is None and attn_metadata is not None:
|
if num_tokens is None and attn_metadata is not None:
|
||||||
num_tokens = attn_metadata.num_actual_tokens
|
num_tokens = attn_metadata.num_actual_tokens
|
||||||
|
|
||||||
|
|||||||
@@ -35,9 +35,6 @@ def register_model():
|
|||||||
"Qwen3MoeForCausalLM",
|
"Qwen3MoeForCausalLM",
|
||||||
"vllm_ascend.models.qwen3_moe:CustomQwen3MoeForCausalLM")
|
"vllm_ascend.models.qwen3_moe:CustomQwen3MoeForCausalLM")
|
||||||
|
|
||||||
ModelRegistry.register_model(
|
|
||||||
"Qwen3ForCausalLM", "vllm_ascend.models.qwen3:CustomQwen3ForCausalLM")
|
|
||||||
|
|
||||||
# There is no PanguProMoEForCausalLM in vLLM, so we should register it before vLLM config initialization
|
# There is no PanguProMoEForCausalLM in vLLM, so we should register it before vLLM config initialization
|
||||||
# to make sure the model can be loaded correctly. This register step can be removed once vLLM support PanguProMoEForCausalLM.
|
# to make sure the model can be loaded correctly. This register step can be removed once vLLM support PanguProMoEForCausalLM.
|
||||||
ModelRegistry.register_model(
|
ModelRegistry.register_model(
|
||||||
|
|||||||
@@ -1,156 +0,0 @@
|
|||||||
from collections.abc import Iterable
|
|
||||||
from typing import Optional, Union
|
|
||||||
|
|
||||||
import torch
|
|
||||||
from torch import nn
|
|
||||||
from transformers import Qwen3Config
|
|
||||||
from vllm.compilation.decorators import support_torch_compile
|
|
||||||
from vllm.config import CacheConfig, VllmConfig
|
|
||||||
from vllm.distributed import get_pp_group
|
|
||||||
from vllm.model_executor.layers.logits_processor import LogitsProcessor
|
|
||||||
from vllm.model_executor.layers.quantization import QuantizationConfig
|
|
||||||
from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
|
|
||||||
from vllm.model_executor.models.interfaces import SupportsLoRA, SupportsPP
|
|
||||||
from vllm.model_executor.models.qwen2 import Qwen2Model
|
|
||||||
from vllm.model_executor.models.qwen3 import Qwen3DecoderLayer
|
|
||||||
from vllm.model_executor.models.utils import (AutoWeightsLoader,
|
|
||||||
PPMissingLayer, maybe_prefix)
|
|
||||||
from vllm.model_executor.sampling_metadata import SamplingMetadata
|
|
||||||
from vllm.sequence import IntermediateTensors
|
|
||||||
|
|
||||||
from vllm_ascend.ops.layernorm import AddRMSNormW8A8Quant
|
|
||||||
|
|
||||||
|
|
||||||
class CustomQwen3DecoderLayer(Qwen3DecoderLayer):
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
config: Qwen3Config,
|
|
||||||
cache_config: Optional[CacheConfig] = None,
|
|
||||||
quant_config: Optional[QuantizationConfig] = None,
|
|
||||||
prefix: str = "",
|
|
||||||
) -> None:
|
|
||||||
super().__init__(config=config,
|
|
||||||
cache_config=cache_config,
|
|
||||||
quant_config=quant_config,
|
|
||||||
prefix=prefix)
|
|
||||||
if quant_config is None:
|
|
||||||
return
|
|
||||||
|
|
||||||
from vllm_ascend.quantization.quant_config import AscendQuantConfig
|
|
||||||
from vllm_ascend.quantization.w8a8 import AscendW8A8LinearMethod
|
|
||||||
|
|
||||||
assert isinstance(quant_config, AscendQuantConfig), \
|
|
||||||
"Expected quant_config to be an instance of AscendQuantConfig"
|
|
||||||
|
|
||||||
if isinstance(self.self_attn.qkv_proj.quant_method.quant_method,
|
|
||||||
AscendW8A8LinearMethod):
|
|
||||||
self.input_layernorm = AddRMSNormW8A8Quant(
|
|
||||||
config.hidden_size,
|
|
||||||
layer=self.self_attn.qkv_proj,
|
|
||||||
eps=config.rms_norm_eps)
|
|
||||||
if isinstance(self.mlp.gate_up_proj.quant_method.quant_method,
|
|
||||||
AscendW8A8LinearMethod):
|
|
||||||
self.post_attention_layernorm = AddRMSNormW8A8Quant(
|
|
||||||
config.hidden_size,
|
|
||||||
layer=self.mlp.gate_up_proj,
|
|
||||||
eps=config.rms_norm_eps)
|
|
||||||
|
|
||||||
|
|
||||||
ALL_DECODER_LAYER_TYPES = {
|
|
||||||
"attention": CustomQwen3DecoderLayer,
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
@support_torch_compile(
|
|
||||||
dynamic_arg_dims={
|
|
||||||
"input_ids": 0,
|
|
||||||
# positions is of shape (3, seq_len) if mrope is enabled for qwen2-vl,
|
|
||||||
# otherwise (seq_len, ).
|
|
||||||
"positions": -1,
|
|
||||||
"intermediate_tensors": 0,
|
|
||||||
"inputs_embeds": 0,
|
|
||||||
})
|
|
||||||
class CustomQwen3Model(Qwen2Model):
|
|
||||||
|
|
||||||
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
|
|
||||||
super().__init__(vllm_config=vllm_config,
|
|
||||||
prefix=prefix,
|
|
||||||
decoder_layer_type=CustomQwen3DecoderLayer)
|
|
||||||
|
|
||||||
|
|
||||||
class CustomQwen3ForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
|
|
||||||
# add `CustomQwen3Model` to init self.model
|
|
||||||
packed_modules_mapping = {
|
|
||||||
"qkv_proj": [
|
|
||||||
"q_proj",
|
|
||||||
"k_proj",
|
|
||||||
"v_proj",
|
|
||||||
],
|
|
||||||
"gate_up_proj": [
|
|
||||||
"gate_proj",
|
|
||||||
"up_proj",
|
|
||||||
],
|
|
||||||
}
|
|
||||||
|
|
||||||
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
|
|
||||||
super().__init__()
|
|
||||||
config = vllm_config.model_config.hf_config
|
|
||||||
quant_config = vllm_config.quant_config
|
|
||||||
lora_config = vllm_config.lora_config
|
|
||||||
|
|
||||||
self.config = config
|
|
||||||
self.lora_config = lora_config
|
|
||||||
|
|
||||||
self.quant_config = quant_config
|
|
||||||
self.model = CustomQwen3Model(vllm_config=vllm_config,
|
|
||||||
prefix=maybe_prefix(prefix, "model"))
|
|
||||||
|
|
||||||
if get_pp_group().is_last_rank:
|
|
||||||
if config.tie_word_embeddings:
|
|
||||||
self.lm_head = self.model.embed_tokens
|
|
||||||
else:
|
|
||||||
self.lm_head = ParallelLMHead(config.vocab_size,
|
|
||||||
config.hidden_size,
|
|
||||||
quant_config=quant_config,
|
|
||||||
prefix=maybe_prefix(
|
|
||||||
prefix, "lm_head"))
|
|
||||||
else:
|
|
||||||
self.lm_head = PPMissingLayer()
|
|
||||||
|
|
||||||
self.logits_processor = LogitsProcessor(config.vocab_size)
|
|
||||||
|
|
||||||
self.make_empty_intermediate_tensors = (
|
|
||||||
self.model.make_empty_intermediate_tensors)
|
|
||||||
|
|
||||||
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
|
|
||||||
return self.model.get_input_embeddings(input_ids)
|
|
||||||
|
|
||||||
def forward(
|
|
||||||
self,
|
|
||||||
input_ids: torch.Tensor,
|
|
||||||
positions: torch.Tensor,
|
|
||||||
intermediate_tensors: Optional[IntermediateTensors] = None,
|
|
||||||
inputs_embeds: Optional[torch.Tensor] = None,
|
|
||||||
) -> Union[torch.Tensor, IntermediateTensors]:
|
|
||||||
hidden_states = self.model(input_ids, positions, intermediate_tensors,
|
|
||||||
inputs_embeds)
|
|
||||||
return hidden_states
|
|
||||||
|
|
||||||
def compute_logits(
|
|
||||||
self,
|
|
||||||
hidden_states: torch.Tensor,
|
|
||||||
sampling_metadata: SamplingMetadata,
|
|
||||||
) -> Optional[torch.Tensor]:
|
|
||||||
logits = self.logits_processor(self.lm_head, hidden_states,
|
|
||||||
sampling_metadata)
|
|
||||||
return logits
|
|
||||||
|
|
||||||
def load_weights(self, weights: Iterable[tuple[str,
|
|
||||||
torch.Tensor]]) -> set[str]:
|
|
||||||
loader = AutoWeightsLoader(
|
|
||||||
self,
|
|
||||||
skip_prefixes=(["lm_head."]
|
|
||||||
if self.config.tie_word_embeddings else None),
|
|
||||||
)
|
|
||||||
return loader.load_weights(weights)
|
|
||||||
@@ -18,47 +18,40 @@
|
|||||||
from typing import Optional, Tuple, Union, cast
|
from typing import Optional, Tuple, Union, cast
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
from vllm.forward_context import get_forward_context
|
||||||
from vllm.model_executor.layers.layernorm import RMSNorm
|
from vllm.model_executor.layers.layernorm import RMSNorm
|
||||||
|
|
||||||
|
|
||||||
class AddRMSNormW8A8Quant(RMSNorm):
|
def _addrmsnorm_forward_oot(
|
||||||
# Fuse AddRmsNorm and W8A8 quantization ops together
|
self,
|
||||||
|
x: torch.Tensor,
|
||||||
|
residual: torch.Tensor,
|
||||||
|
layer: Optional[torch.nn.Module] = None,
|
||||||
|
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
|
||||||
|
import torch_npu
|
||||||
|
|
||||||
def __init__(
|
from vllm_ascend.utils import is_310p
|
||||||
self,
|
|
||||||
hidden_size: int,
|
|
||||||
layer: torch.nn.Module,
|
|
||||||
eps: float = 1e-6,
|
|
||||||
var_hidden_size: Optional[int] = None,
|
|
||||||
has_weight: bool = True,
|
|
||||||
dtype: Optional[torch.dtype] = None,
|
|
||||||
) -> None:
|
|
||||||
super().__init__(hidden_size, eps, var_hidden_size, has_weight, dtype)
|
|
||||||
self.layer = layer
|
|
||||||
|
|
||||||
def forward(
|
if layer is not None and not is_310p():
|
||||||
self,
|
x, _, residual = torch_npu.npu_add_rms_norm_quant(
|
||||||
x: torch.Tensor,
|
x,
|
||||||
residual: Optional[torch.Tensor] = None,
|
residual,
|
||||||
) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
|
self.weight,
|
||||||
import torch_npu
|
layer.aclnn_input_scale,
|
||||||
|
layer.aclnn_input_offset,
|
||||||
if residual is not None:
|
epsilon=self.variance_epsilon)
|
||||||
residual = torch.ops.vllm.maybe_chunk_residual(x, residual)
|
else:
|
||||||
assert x.size(0) == residual.size(0)
|
if is_310p():
|
||||||
x, _, residual = torch_npu.npu_add_rms_norm_quant(
|
orig_dtype = residual.dtype
|
||||||
x,
|
x = x + residual.to(x.dtype)
|
||||||
residual,
|
residual = x.to(orig_dtype)
|
||||||
self.weight,
|
x, _ = torch_npu.npu_rms_norm(x, self.weight,
|
||||||
self.layer.aclnn_input_scale,
|
self.variance_epsilon)
|
||||||
self.layer.aclnn_input_offset,
|
else:
|
||||||
epsilon=self.variance_epsilon)
|
x, _, residual = torch_npu.npu_add_rms_norm(
|
||||||
torch.ops.vllm.maybe_wait_prefetch_done(x)
|
x, residual, self.weight, self.variance_epsilon)
|
||||||
return x, residual
|
torch.ops.vllm.maybe_wait_prefetch_done(x)
|
||||||
|
return x, residual
|
||||||
x, residual = torch_npu.npu_rms_norm(x, self.weight,
|
|
||||||
self.variance_epsilon)
|
|
||||||
return x
|
|
||||||
|
|
||||||
|
|
||||||
class AscendRMSNorm(RMSNorm):
|
class AscendRMSNorm(RMSNorm):
|
||||||
@@ -70,26 +63,49 @@ class AscendRMSNorm(RMSNorm):
|
|||||||
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
|
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
|
||||||
import torch_npu
|
import torch_npu
|
||||||
|
|
||||||
from vllm_ascend.utils import is_310p
|
|
||||||
if residual is not None:
|
if residual is not None:
|
||||||
residual = torch.ops.vllm.maybe_chunk_residual(x, residual)
|
residual = torch.ops.vllm.maybe_chunk_residual(x, residual)
|
||||||
assert x.size(0) == residual.size(0)
|
assert x.size(0) == residual.size(0)
|
||||||
if is_310p():
|
x, residual = _addrmsnorm_forward_oot(
|
||||||
orig_dtype = residual.dtype
|
self, x, residual, self.next_need_quant_fusion_linear)
|
||||||
x = x + residual.to(x.dtype)
|
|
||||||
residual = x.to(orig_dtype)
|
|
||||||
x, _ = torch_npu.npu_rms_norm(x, self.weight,
|
|
||||||
self.variance_epsilon)
|
|
||||||
else:
|
|
||||||
x, _, residual = torch_npu.npu_add_rms_norm(
|
|
||||||
x, residual, self.weight, self.variance_epsilon)
|
|
||||||
torch.ops.vllm.maybe_wait_prefetch_done(x)
|
|
||||||
return x, residual
|
return x, residual
|
||||||
|
|
||||||
x, residual = torch_npu.npu_rms_norm(x, self.weight,
|
x, residual = torch_npu.npu_rms_norm(x, self.weight,
|
||||||
self.variance_epsilon)
|
self.variance_epsilon)
|
||||||
return x
|
return x
|
||||||
|
|
||||||
|
@property
|
||||||
|
def next_need_quant_fusion_linear(self):
|
||||||
|
try:
|
||||||
|
forward_context = get_forward_context()
|
||||||
|
if not forward_context.addrmsnorm_quant_fusion_enabled or \
|
||||||
|
forward_context.layer_idx == forward_context.num_hidden_layers:
|
||||||
|
return None
|
||||||
|
except AssertionError:
|
||||||
|
return None
|
||||||
|
|
||||||
|
next_linear = None
|
||||||
|
model_instance = forward_context.model_instance
|
||||||
|
layer_idx = forward_context.layer_idx
|
||||||
|
fusion_linear = forward_context.fusion_linear
|
||||||
|
next_linear = None
|
||||||
|
if fusion_linear == "qkv_dense":
|
||||||
|
next_linear = model_instance.model.layers[
|
||||||
|
layer_idx].self_attn.qkv_proj
|
||||||
|
forward_context.fusion_linear = "gate_up_dense"
|
||||||
|
elif fusion_linear == "gate_up_dense":
|
||||||
|
next_linear = model_instance.model.layers[
|
||||||
|
layer_idx].mlp.gate_up_proj
|
||||||
|
forward_context.fusion_linear = "qkv_dense"
|
||||||
|
# If prefetch_mlp_weight is enabled, the following layer_idx increment
|
||||||
|
# does not need to be repeated.
|
||||||
|
if not forward_context.prefetch_mlp_enabled:
|
||||||
|
forward_context.layer_idx += 1
|
||||||
|
from vllm_ascend.quantization.w8a8 import AscendW8A8LinearMethod
|
||||||
|
if next_linear is not None and \
|
||||||
|
not isinstance(next_linear.quant_method.quant_method, AscendW8A8LinearMethod):
|
||||||
|
next_linear = None
|
||||||
|
return next_linear
|
||||||
|
|
||||||
|
|
||||||
class AscendQuantRMSNorm(AscendRMSNorm):
|
class AscendQuantRMSNorm(AscendRMSNorm):
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user