[0.11.0]Chery pick pta upgrade change (#3940)
This PR cherry-pick two commit from main to upgrade torch-npu to 2.7.1 official release --------- Signed-off-by: wangxiyuan <wangxiyuan1007@gmail.com>
This commit is contained in:
@@ -20,6 +20,13 @@ set(VLLM_ASCEND_INSTALL_PATH "${CMAKE_INSTALL_PREFIX}")
|
|||||||
|
|
||||||
find_package(Torch REQUIRED)
|
find_package(Torch REQUIRED)
|
||||||
|
|
||||||
|
run_python(TORCH_VERSION
|
||||||
|
"import torch; print(torch.__version__)" "Failed to locate torch path")
|
||||||
|
# check torch version is 2.7.1
|
||||||
|
if(NOT ${TORCH_VERSION} VERSION_EQUAL "2.7.1")
|
||||||
|
message(FATAL_ERROR "Expected PyTorch version 2.7.1, but found ${TORCH_VERSION}")
|
||||||
|
endif()
|
||||||
|
|
||||||
set(RUN_MODE "npu" CACHE STRING "cpu/sim/npu")
|
set(RUN_MODE "npu" CACHE STRING "cpu/sim/npu")
|
||||||
set(SOC_VERSION ${SOC_VERSION})
|
set(SOC_VERSION ${SOC_VERSION})
|
||||||
message(STATUS "Detected SOC version: ${SOC_VERSION}")
|
message(STATUS "Detected SOC version: ${SOC_VERSION}")
|
||||||
|
|||||||
@@ -43,7 +43,7 @@ By using vLLM Ascend plugin, popular open-source models, including Transformer-l
|
|||||||
- Software:
|
- Software:
|
||||||
* Python >= 3.9, < 3.12
|
* Python >= 3.9, < 3.12
|
||||||
* CANN >= 8.2.rc1 (Ascend HDK version refers to [here](https://www.hiascend.com/document/detail/zh/canncommercial/82RC1/releasenote/releasenote_0000.html))
|
* CANN >= 8.2.rc1 (Ascend HDK version refers to [here](https://www.hiascend.com/document/detail/zh/canncommercial/82RC1/releasenote/releasenote_0000.html))
|
||||||
* PyTorch >= 2.7.1, torch-npu >= 2.7.1.dev20250724
|
* PyTorch == 2.7.1, torch-npu == 2.7.1
|
||||||
* vLLM (the same version as vllm-ascend)
|
* vLLM (the same version as vllm-ascend)
|
||||||
|
|
||||||
## Getting Started
|
## Getting Started
|
||||||
|
|||||||
@@ -44,7 +44,7 @@ vLLM 昇腾插件 (`vllm-ascend`) 是一个由社区维护的让vLLM在Ascend NP
|
|||||||
- 软件:
|
- 软件:
|
||||||
* Python >= 3.9, < 3.12
|
* Python >= 3.9, < 3.12
|
||||||
* CANN >= 8.2.rc1 (Ascend HDK 版本参考[这里](https://www.hiascend.com/document/detail/zh/canncommercial/82RC1/releasenote/releasenote_0000.html))
|
* CANN >= 8.2.rc1 (Ascend HDK 版本参考[这里](https://www.hiascend.com/document/detail/zh/canncommercial/82RC1/releasenote/releasenote_0000.html))
|
||||||
* PyTorch >= 2.7.1, torch-npu >= 2.7.1.dev20250724
|
* PyTorch == 2.7.1, torch-npu == 2.7.1
|
||||||
* vLLM (与vllm-ascend版本一致)
|
* vLLM (与vllm-ascend版本一致)
|
||||||
|
|
||||||
## 开始使用
|
## 开始使用
|
||||||
|
|||||||
@@ -13,8 +13,8 @@ This document describes how to install vllm-ascend manually.
|
|||||||
|---------------|----------------------------------|-------------------------------------------|
|
|---------------|----------------------------------|-------------------------------------------|
|
||||||
| Ascend HDK | Refer to [here](https://www.hiascend.com/document/detail/zh/canncommercial/82RC1/releasenote/releasenote_0000.html) | Required for CANN |
|
| Ascend HDK | Refer to [here](https://www.hiascend.com/document/detail/zh/canncommercial/82RC1/releasenote/releasenote_0000.html) | Required for CANN |
|
||||||
| CANN | >= 8.2.RC1 | Required for vllm-ascend and torch-npu |
|
| CANN | >= 8.2.RC1 | Required for vllm-ascend and torch-npu |
|
||||||
| torch-npu | >= 2.7.1.dev20250724 | Required for vllm-ascend, No need to install manually, it will be auto installed in below steps |
|
| torch-npu | == 2.7.1 | Required for vllm-ascend, No need to install manually, it will be auto installed in below steps |
|
||||||
| torch | >= 2.7.1 | Required for torch-npu and vllm |
|
| torch | == 2.7.1 | Required for torch-npu and vllm |
|
||||||
|
|
||||||
There are two installation methods:
|
There are two installation methods:
|
||||||
- **Using pip**: first prepare env manually or via CANN image, then install `vllm-ascend` using pip.
|
- **Using pip**: first prepare env manually or via CANN image, then install `vllm-ascend` using pip.
|
||||||
|
|||||||
@@ -5,7 +5,7 @@
|
|||||||
* Software:
|
* Software:
|
||||||
* Python >= 3.9, < 3.12
|
* Python >= 3.9, < 3.12
|
||||||
* CANN >= 8.2.rc1
|
* CANN >= 8.2.rc1
|
||||||
* PyTorch >= 2.7.1, torch-npu >= 2.7.1.dev20250724
|
* PyTorch == 2.7.1, torch-npu == 2.7.1
|
||||||
* vLLM (same version as vllm-ascend)
|
* vLLM (same version as vllm-ascend)
|
||||||
* mooncake-transfer-engine reference documentation: https://github.com/kvcache-ai/Mooncake/blob/main/doc/zh/ascend_transport.md
|
* mooncake-transfer-engine reference documentation: https://github.com/kvcache-ai/Mooncake/blob/main/doc/zh/ascend_transport.md
|
||||||
|
|
||||||
|
|||||||
@@ -5,7 +5,7 @@
|
|||||||
* Software:
|
* Software:
|
||||||
* Python >= 3.9, < 3.12
|
* Python >= 3.9, < 3.12
|
||||||
* CANN >= 8.2.rc1
|
* CANN >= 8.2.rc1
|
||||||
* PyTorch >= 2.7.1, torch-npu >= 2.7.1.dev20250724
|
* PyTorch == 2.7.1, torch-npu == 2.7.1
|
||||||
* vLLM:main branch
|
* vLLM:main branch
|
||||||
* vLLM-Ascend:main branch
|
* vLLM-Ascend:main branch
|
||||||
* Mooncake:[AscendTransport/Mooncake at pooling-async-memcpy](https://github.com/AscendTransport/Mooncake/tree/pooling-async-memcpy)(Currently available branch code, continuously updated.)
|
* Mooncake:[AscendTransport/Mooncake at pooling-async-memcpy](https://github.com/AscendTransport/Mooncake/tree/pooling-async-memcpy)(Currently available branch code, continuously updated.)
|
||||||
|
|||||||
@@ -12,8 +12,8 @@ requires = [
|
|||||||
"scipy",
|
"scipy",
|
||||||
"setuptools>=64",
|
"setuptools>=64",
|
||||||
"setuptools-scm>=8",
|
"setuptools-scm>=8",
|
||||||
"torch-npu==2.7.1.dev20250724",
|
"torch-npu==2.7.1",
|
||||||
"torch>=2.7.1",
|
"torch==2.7.1",
|
||||||
"torchvision",
|
"torchvision",
|
||||||
"wheel",
|
"wheel",
|
||||||
"msgpack",
|
"msgpack",
|
||||||
|
|||||||
@@ -10,7 +10,7 @@ pyyaml
|
|||||||
scipy
|
scipy
|
||||||
setuptools>=64
|
setuptools>=64
|
||||||
setuptools-scm>=8
|
setuptools-scm>=8
|
||||||
torch>=2.7.1
|
torch==2.7.1
|
||||||
torchvision
|
torchvision
|
||||||
wheel
|
wheel
|
||||||
opencv-python-headless<=4.11.0.86 # Required to avoid numpy version conflict with vllm
|
opencv-python-headless<=4.11.0.86 # Required to avoid numpy version conflict with vllm
|
||||||
@@ -23,6 +23,6 @@ quart
|
|||||||
numba
|
numba
|
||||||
|
|
||||||
# Install torch_npu
|
# Install torch_npu
|
||||||
--pre
|
#--pre
|
||||||
--extra-index-url https://mirrors.huaweicloud.com/ascend/repos/pypi
|
#--extra-index-url https://mirrors.huaweicloud.com/ascend/repos/pypi
|
||||||
torch-npu==2.7.1.dev20250724
|
torch-npu==2.7.1
|
||||||
|
|||||||
@@ -7,7 +7,6 @@ from vllm.model_executor.layers.layernorm import RMSNorm
|
|||||||
|
|
||||||
from tests.ut.base import PytestBase
|
from tests.ut.base import PytestBase
|
||||||
from vllm_ascend.quantization.w8a8 import AscendW8A8LinearMethod
|
from vllm_ascend.quantization.w8a8 import AscendW8A8LinearMethod
|
||||||
from vllm_ascend.utils import version_check
|
|
||||||
|
|
||||||
|
|
||||||
def mock_rms_norm(x, weight, eps):
|
def mock_rms_norm(x, weight, eps):
|
||||||
@@ -18,15 +17,6 @@ def mock_add_rms_norm(x, residual, weight, eps):
|
|||||||
return 2 * x, None, 2 * residual
|
return 2 * x, None, 2 * residual
|
||||||
|
|
||||||
|
|
||||||
def mock_add_rms_norm_quant(x, residual, weight, quant_scale, quant_offset,
|
|
||||||
epsilon):
|
|
||||||
x_out = 2 * x
|
|
||||||
residual_out = 2 * residual
|
|
||||||
x_out_quant = x_out.to(torch.int8)
|
|
||||||
residual_out_quant = residual_out.to(torch.int8)
|
|
||||||
return x_out_quant, None, residual_out_quant
|
|
||||||
|
|
||||||
|
|
||||||
def mock_add_rms_norm_quant_with_bias(x, residual, weight, quant_scale,
|
def mock_add_rms_norm_quant_with_bias(x, residual, weight, quant_scale,
|
||||||
quant_offset, beta, epsilon):
|
quant_offset, beta, epsilon):
|
||||||
x_out = 2 * x
|
x_out = 2 * x
|
||||||
@@ -43,10 +33,8 @@ class TestAscendRMSNorm(PytestBase):
|
|||||||
mocker.patch("torch_npu.npu_rms_norm", side_effect=mock_rms_norm)
|
mocker.patch("torch_npu.npu_rms_norm", side_effect=mock_rms_norm)
|
||||||
mocker.patch("torch_npu.npu_add_rms_norm",
|
mocker.patch("torch_npu.npu_add_rms_norm",
|
||||||
side_effect=mock_add_rms_norm)
|
side_effect=mock_add_rms_norm)
|
||||||
torch_npu_check = version_check()
|
|
||||||
arnq_side_effect = mock_add_rms_norm_quant_with_bias if torch_npu_check else mock_add_rms_norm_quant
|
|
||||||
mocker.patch("torch_npu.npu_add_rms_norm_quant",
|
mocker.patch("torch_npu.npu_add_rms_norm_quant",
|
||||||
side_effect=arnq_side_effect)
|
side_effect=mock_add_rms_norm_quant_with_bias)
|
||||||
mocker.patch("torch.ops.vllm.maybe_wait_prefetch_done",
|
mocker.patch("torch.ops.vllm.maybe_wait_prefetch_done",
|
||||||
side_effect=lambda x: None)
|
side_effect=lambda x: None)
|
||||||
|
|
||||||
@@ -82,8 +70,7 @@ class TestAscendRMSNorm(PytestBase):
|
|||||||
|
|
||||||
mock_model_instance = mocker.MagicMock()
|
mock_model_instance = mocker.MagicMock()
|
||||||
mock_forward_context.model_instance = mock_model_instance
|
mock_forward_context.model_instance = mock_model_instance
|
||||||
torch_npu_check = version_check()
|
num_hidden_layers = 3
|
||||||
num_hidden_layers = 3 if torch_npu_check else 2
|
|
||||||
mock_model_instance.model.layers = [
|
mock_model_instance.model.layers = [
|
||||||
mocker.MagicMock() for _ in range(num_hidden_layers)
|
mocker.MagicMock() for _ in range(num_hidden_layers)
|
||||||
]
|
]
|
||||||
@@ -136,37 +123,34 @@ class TestAscendRMSNorm(PytestBase):
|
|||||||
assert mock_forward_context.fusion_linear == "gate_up_dense"
|
assert mock_forward_context.fusion_linear == "gate_up_dense"
|
||||||
assert mock_forward_context.layer_idx == 1
|
assert mock_forward_context.layer_idx == 1
|
||||||
|
|
||||||
if torch_npu_check:
|
mock_forward_context.fusion_linear = "gate_moe"
|
||||||
mock_forward_context.fusion_linear = "gate_moe"
|
x_out, residual_out = layer.forward_oot(x, residual)
|
||||||
|
|
||||||
|
assert mock_get_forward_context.call_count == 5
|
||||||
|
fusion_linear_expected = "qkv_moe"
|
||||||
|
assert mock_forward_context.fusion_linear == fusion_linear_expected
|
||||||
|
assert mock_forward_context.layer_idx == 2
|
||||||
|
|
||||||
x_out, residual_out = layer.forward_oot(x, residual)
|
x_out, residual_out = layer.forward_oot(x, residual)
|
||||||
|
|
||||||
assert mock_get_forward_context.call_count == 6
|
assert mock_get_forward_context.call_count == 6
|
||||||
fusion_linear_expected = "qkv_moe" if torch_npu_check else "qkv_dense"
|
fusion_linear_expected = "gate_moe"
|
||||||
assert mock_forward_context.fusion_linear == fusion_linear_expected
|
assert mock_forward_context.fusion_linear == fusion_linear_expected
|
||||||
assert mock_forward_context.layer_idx == 2
|
assert mock_forward_context.layer_idx == 2
|
||||||
|
|
||||||
|
# last layer returned directly
|
||||||
x_out, residual_out = layer.forward_oot(x, residual)
|
x_out, residual_out = layer.forward_oot(x, residual)
|
||||||
|
|
||||||
assert mock_get_forward_context.call_count == 7
|
assert mock_get_forward_context.call_count == 7
|
||||||
fusion_linear_expected = "gate_moe" if torch_npu_check else "qkv_dense"
|
assert mock_forward_context.fusion_linear == "qkv_moe"
|
||||||
assert mock_forward_context.fusion_linear == fusion_linear_expected
|
assert mock_forward_context.layer_idx == 3
|
||||||
assert mock_forward_context.layer_idx == 2
|
|
||||||
|
|
||||||
if not torch_npu_check:
|
|
||||||
return
|
|
||||||
# last layer returned directly
|
|
||||||
x_out, residual_out = layer.forward_oot(x, residual)
|
x_out, residual_out = layer.forward_oot(x, residual)
|
||||||
|
|
||||||
assert mock_get_forward_context.call_count == 8
|
assert mock_get_forward_context.call_count == 8
|
||||||
assert mock_forward_context.fusion_linear == "qkv_moe"
|
assert mock_forward_context.fusion_linear == "qkv_moe"
|
||||||
assert mock_forward_context.layer_idx == 3
|
assert mock_forward_context.layer_idx == 3
|
||||||
|
|
||||||
x_out, residual_out = layer.forward_oot(x, residual)
|
|
||||||
|
|
||||||
assert mock_get_forward_context.call_count == 9
|
|
||||||
assert mock_forward_context.fusion_linear == "qkv_moe"
|
|
||||||
assert mock_forward_context.layer_idx == 3
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
unittest.main()
|
unittest.main()
|
||||||
|
|||||||
@@ -23,9 +23,9 @@ class TestAscendW8A8FusedMoEMethod(TestBase):
|
|||||||
@patch("torch_npu.npu_swiglu")
|
@patch("torch_npu.npu_swiglu")
|
||||||
@patch("torch_npu.npu_dynamic_quant")
|
@patch("torch_npu.npu_dynamic_quant")
|
||||||
@patch("torch_npu.npu_moe_finalize_routing")
|
@patch("torch_npu.npu_moe_finalize_routing")
|
||||||
@patch("torch_npu.npu_moe_init_routing")
|
@patch("torch_npu.npu_moe_init_routing_quant")
|
||||||
def test_torchair_fused_experts_with_all2all(
|
def test_torchair_fused_experts_with_all2all(
|
||||||
self, mock_moe_init_routing, mock_moe_finalize_routing,
|
self, mock_npu_moe_init_routing_quant, mock_moe_finalize_routing,
|
||||||
mock_dynamic_quant, mock_swiglu, mock_grouped_matmul,
|
mock_dynamic_quant, mock_swiglu, mock_grouped_matmul,
|
||||||
mock_moe_re_routing, mock_all_to_all_single):
|
mock_moe_re_routing, mock_all_to_all_single):
|
||||||
|
|
||||||
@@ -38,11 +38,10 @@ class TestAscendW8A8FusedMoEMethod(TestBase):
|
|||||||
placeholder_ones = torch.ones(self.num_tokens, dtype=torch.int32)
|
placeholder_ones = torch.ones(self.num_tokens, dtype=torch.int32)
|
||||||
mock_all_to_all_single.side_effect = lambda output, input, *args, **kwargs: output.copy_(
|
mock_all_to_all_single.side_effect = lambda output, input, *args, **kwargs: output.copy_(
|
||||||
input)
|
input)
|
||||||
mock_moe_init_routing.return_value = (
|
mock_npu_moe_init_routing_quant.return_value = (
|
||||||
placeholder_int8,
|
placeholder_int8, placeholder_ones, placeholder_ones,
|
||||||
placeholder_ones,
|
torch.bincount(placeholder_ones, minlength=len(expert_map)),
|
||||||
placeholder_ones,
|
torch.randn(self.num_tokens))
|
||||||
)
|
|
||||||
mock_moe_re_routing.return_value = (placeholder_int8, self.placeholder,
|
mock_moe_re_routing.return_value = (placeholder_int8, self.placeholder,
|
||||||
torch.randint(0,
|
torch.randint(0,
|
||||||
100,
|
100,
|
||||||
|
|||||||
@@ -11,8 +11,7 @@ from vllm.forward_context import (BatchDescriptor, get_forward_context,
|
|||||||
set_forward_context)
|
set_forward_context)
|
||||||
|
|
||||||
import vllm_ascend.envs as envs_ascend
|
import vllm_ascend.envs as envs_ascend
|
||||||
from vllm_ascend.utils import (enable_sp, has_layer_idx, is_moe_model,
|
from vllm_ascend.utils import enable_sp, has_layer_idx, is_moe_model
|
||||||
version_check)
|
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from vllm_ascend.ops.weight_prefetch import WeightPrefetchMethod
|
from vllm_ascend.ops.weight_prefetch import WeightPrefetchMethod
|
||||||
@@ -162,9 +161,7 @@ def set_ascend_forward_context(
|
|||||||
# this optim now just support dense models due to the specific operators used.
|
# this optim now just support dense models due to the specific operators used.
|
||||||
# Once the necessary conditions are met, support for MOE models will also be added.
|
# Once the necessary conditions are met, support for MOE models will also be added.
|
||||||
from vllm_ascend.quantization.quant_config import AscendQuantConfig
|
from vllm_ascend.quantization.quant_config import AscendQuantConfig
|
||||||
model_type_scope = ["llama", "qwen2", "qwen3"]
|
model_type_scope = ["llama", "qwen2", "qwen3", "qwen3_moe"]
|
||||||
if version_check():
|
|
||||||
model_type_scope.append("qwen3_moe")
|
|
||||||
addrmsnorm_quant_fusion_enabled = isinstance(vllm_config.quant_config, AscendQuantConfig) and \
|
addrmsnorm_quant_fusion_enabled = isinstance(vllm_config.quant_config, AscendQuantConfig) and \
|
||||||
vllm_config.model_config.hf_config.model_type in model_type_scope and \
|
vllm_config.model_config.hf_config.model_type in model_type_scope and \
|
||||||
forward_context.layer_idx is not None
|
forward_context.layer_idx is not None
|
||||||
|
|||||||
@@ -38,7 +38,7 @@ from vllm_ascend.compilation.acl_graph import (get_graph_params,
|
|||||||
update_graph_params_workspaces)
|
update_graph_params_workspaces)
|
||||||
from vllm_ascend.ops.attention import vanilla_chunked_prefill
|
from vllm_ascend.ops.attention import vanilla_chunked_prefill
|
||||||
from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_NZ, aligned_16, is_310p,
|
from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_NZ, aligned_16, is_310p,
|
||||||
nd_to_nz_2d, nd_to_nz_spec, version_check)
|
nd_to_nz_2d, nd_to_nz_spec)
|
||||||
|
|
||||||
from ..utils import weak_ref_tensors
|
from ..utils import weak_ref_tensors
|
||||||
|
|
||||||
@@ -321,7 +321,6 @@ class AscendAttentionBackendImpl(AttentionImpl):
|
|||||||
self.num_queries_per_kv = self.num_heads // self.num_kv_heads
|
self.num_queries_per_kv = self.num_heads // self.num_kv_heads
|
||||||
self.key_cache = None
|
self.key_cache = None
|
||||||
self.value_cache = None
|
self.value_cache = None
|
||||||
self.torch_npu_check = version_check()
|
|
||||||
|
|
||||||
def _forward_prefill_no_cache(
|
def _forward_prefill_no_cache(
|
||||||
self,
|
self,
|
||||||
@@ -429,22 +428,21 @@ class AscendAttentionBackendImpl(AttentionImpl):
|
|||||||
forward_context: ForwardContext = get_forward_context()
|
forward_context: ForwardContext = get_forward_context()
|
||||||
num_tokens = query.shape[0]
|
num_tokens = query.shape[0]
|
||||||
if forward_context.capturing:
|
if forward_context.capturing:
|
||||||
if self.torch_npu_check:
|
# Get workspace from cache or calculate it if not present.
|
||||||
# Get workspace from cache or calculate it if not present.
|
workspace = graph_params.workspaces.get(num_tokens)
|
||||||
workspace = graph_params.workspaces.get(num_tokens)
|
if workspace is None:
|
||||||
if workspace is None:
|
workspace = torch_npu._npu_paged_attention_get_workspace(
|
||||||
workspace = torch_npu._npu_paged_attention_get_workspace(
|
query=query,
|
||||||
query=query,
|
key_cache=self.key_cache,
|
||||||
key_cache=self.key_cache,
|
value_cache=self.value_cache,
|
||||||
value_cache=self.value_cache,
|
num_kv_heads=self.num_kv_heads,
|
||||||
num_kv_heads=self.num_kv_heads,
|
num_heads=self.num_heads,
|
||||||
num_heads=self.num_heads,
|
scale_value=self.scale,
|
||||||
scale_value=self.scale,
|
block_table=attn_metadata.block_tables,
|
||||||
block_table=attn_metadata.block_tables,
|
context_lens=attn_metadata.seq_lens,
|
||||||
context_lens=attn_metadata.seq_lens,
|
out=output)
|
||||||
out=output)
|
update_graph_params_workspaces(num_tokens,
|
||||||
update_graph_params_workspaces(
|
weak_ref_tensors(workspace))
|
||||||
num_tokens, weak_ref_tensors(workspace))
|
|
||||||
|
|
||||||
# Handle graph capturing mode
|
# Handle graph capturing mode
|
||||||
stream = torch_npu.npu.current_stream()
|
stream = torch_npu.npu.current_stream()
|
||||||
@@ -466,30 +464,17 @@ class AscendAttentionBackendImpl(AttentionImpl):
|
|||||||
))
|
))
|
||||||
|
|
||||||
torch.npu.graph_task_group_begin(stream)
|
torch.npu.graph_task_group_begin(stream)
|
||||||
|
torch_npu._npu_paged_attention(
|
||||||
if self.torch_npu_check:
|
query=query,
|
||||||
torch_npu._npu_paged_attention(
|
key_cache=self.key_cache,
|
||||||
query=query,
|
value_cache=self.value_cache,
|
||||||
key_cache=self.key_cache,
|
num_kv_heads=self.num_kv_heads,
|
||||||
value_cache=self.value_cache,
|
num_heads=self.num_heads,
|
||||||
num_kv_heads=self.num_kv_heads,
|
scale_value=self.scale,
|
||||||
num_heads=self.num_heads,
|
block_table=attn_metadata.block_tables,
|
||||||
scale_value=self.scale,
|
context_lens=attn_metadata.seq_lens,
|
||||||
block_table=attn_metadata.block_tables,
|
out=output,
|
||||||
context_lens=attn_metadata.seq_lens,
|
workspace=workspace)
|
||||||
out=output,
|
|
||||||
workspace=workspace)
|
|
||||||
else:
|
|
||||||
torch_npu._npu_paged_attention(
|
|
||||||
query=query,
|
|
||||||
key_cache=self.key_cache,
|
|
||||||
value_cache=self.value_cache,
|
|
||||||
num_kv_heads=self.num_kv_heads,
|
|
||||||
num_heads=self.num_heads,
|
|
||||||
scale_value=self.scale,
|
|
||||||
block_table=attn_metadata.block_tables,
|
|
||||||
context_lens=attn_metadata.seq_lens,
|
|
||||||
out=output)
|
|
||||||
handle = torch.npu.graph_task_group_end(stream)
|
handle = torch.npu.graph_task_group_end(stream)
|
||||||
graph_params.handles[num_tokens].append(handle)
|
graph_params.handles[num_tokens].append(handle)
|
||||||
else:
|
else:
|
||||||
|
|||||||
@@ -18,8 +18,6 @@ from vllm.forward_context import BatchDescriptor, get_forward_context
|
|||||||
from vllm.logger import logger
|
from vllm.logger import logger
|
||||||
from vllm.platforms import current_platform
|
from vllm.platforms import current_platform
|
||||||
|
|
||||||
from vllm_ascend.utils import version_check
|
|
||||||
|
|
||||||
from ..utils import weak_ref_tensors
|
from ..utils import weak_ref_tensors
|
||||||
|
|
||||||
|
|
||||||
@@ -213,32 +211,20 @@ def update_attn_params(update_stream, forward_context, runtime_shape):
|
|||||||
output,
|
output,
|
||||||
) = param
|
) = param
|
||||||
seq_lens = forward_context.attn_metadata[key].seq_lens
|
seq_lens = forward_context.attn_metadata[key].seq_lens
|
||||||
torch_npu_check = version_check()
|
|
||||||
|
|
||||||
with torch.npu.stream(update_stream):
|
with torch.npu.stream(update_stream):
|
||||||
torch.npu.graph_task_update_begin(update_stream, handle)
|
torch.npu.graph_task_update_begin(update_stream, handle)
|
||||||
if torch_npu_check:
|
torch_npu._npu_paged_attention(
|
||||||
torch_npu._npu_paged_attention(
|
query=query,
|
||||||
query=query,
|
key_cache=key_cache,
|
||||||
key_cache=key_cache,
|
value_cache=value_cache,
|
||||||
value_cache=value_cache,
|
num_kv_heads=num_kv_heads,
|
||||||
num_kv_heads=num_kv_heads,
|
num_heads=num_heads,
|
||||||
num_heads=num_heads,
|
scale_value=scale,
|
||||||
scale_value=scale,
|
block_table=block_table,
|
||||||
block_table=block_table,
|
context_lens=seq_lens,
|
||||||
context_lens=seq_lens,
|
out=output,
|
||||||
out=output,
|
workspace=graph_params.workspaces.get(runtime_shape))
|
||||||
workspace=graph_params.workspaces.get(runtime_shape))
|
|
||||||
else:
|
|
||||||
torch_npu._npu_paged_attention(query=query,
|
|
||||||
key_cache=key_cache,
|
|
||||||
value_cache=value_cache,
|
|
||||||
num_kv_heads=num_kv_heads,
|
|
||||||
num_heads=num_heads,
|
|
||||||
scale_value=scale,
|
|
||||||
block_table=block_table,
|
|
||||||
context_lens=seq_lens,
|
|
||||||
out=output)
|
|
||||||
torch.npu.graph_task_update_end(update_stream)
|
torch.npu.graph_task_update_end(update_stream)
|
||||||
|
|
||||||
event.record(update_stream)
|
event.record(update_stream)
|
||||||
|
|||||||
@@ -22,8 +22,6 @@ from vllm.config import get_current_vllm_config
|
|||||||
from vllm.forward_context import get_forward_context
|
from vllm.forward_context import get_forward_context
|
||||||
from vllm.model_executor.layers.layernorm import GemmaRMSNorm, RMSNorm
|
from vllm.model_executor.layers.layernorm import GemmaRMSNorm, RMSNorm
|
||||||
|
|
||||||
from vllm_ascend.utils import version_check
|
|
||||||
|
|
||||||
|
|
||||||
def _addrmsnorm_forward_oot(
|
def _addrmsnorm_forward_oot(
|
||||||
self,
|
self,
|
||||||
@@ -36,7 +34,6 @@ def _addrmsnorm_forward_oot(
|
|||||||
|
|
||||||
from vllm_ascend.utils import is_310p
|
from vllm_ascend.utils import is_310p
|
||||||
|
|
||||||
torch_npu_check = version_check()
|
|
||||||
if layer is not None and not is_310p():
|
if layer is not None and not is_310p():
|
||||||
layer_cls_name = layer.__class__.__name__
|
layer_cls_name = layer.__class__.__name__
|
||||||
try:
|
try:
|
||||||
@@ -53,23 +50,15 @@ def _addrmsnorm_forward_oot(
|
|||||||
start_flag=x,
|
start_flag=x,
|
||||||
)
|
)
|
||||||
# add_rms_norm_quant
|
# add_rms_norm_quant
|
||||||
if torch_npu_check:
|
x, _, residual = torch_npu.npu_add_rms_norm_quant(
|
||||||
x, _, residual = torch_npu.npu_add_rms_norm_quant(
|
x,
|
||||||
x,
|
residual,
|
||||||
residual,
|
self.weight,
|
||||||
self.weight,
|
layer.aclnn_input_scale,
|
||||||
layer.aclnn_input_scale,
|
layer.aclnn_input_offset,
|
||||||
layer.aclnn_input_offset,
|
beta=bias,
|
||||||
beta=bias,
|
epsilon=self.variance_epsilon)
|
||||||
epsilon=self.variance_epsilon)
|
|
||||||
else:
|
|
||||||
x, _, residual = torch_npu.npu_add_rms_norm_quant(
|
|
||||||
x,
|
|
||||||
residual,
|
|
||||||
self.weight,
|
|
||||||
layer.aclnn_input_scale,
|
|
||||||
layer.aclnn_input_offset,
|
|
||||||
epsilon=self.variance_epsilon)
|
|
||||||
# prefetch qkvo_proj.weight postprocess
|
# prefetch qkvo_proj.weight postprocess
|
||||||
if weight_prefetch_method:
|
if weight_prefetch_method:
|
||||||
weight_prefetch_method.maybe_prefetch_attn_weight_postprocess(
|
weight_prefetch_method.maybe_prefetch_attn_weight_postprocess(
|
||||||
@@ -87,7 +76,7 @@ def _addrmsnorm_forward_oot(
|
|||||||
else:
|
else:
|
||||||
x, _, residual = torch_npu.npu_add_rms_norm(
|
x, _, residual = torch_npu.npu_add_rms_norm(
|
||||||
x, residual, self.weight, self.variance_epsilon)
|
x, residual, self.weight, self.variance_epsilon)
|
||||||
if torch_npu_check and bias is not None:
|
if bias is not None:
|
||||||
x.add_(bias)
|
x.add_(bias)
|
||||||
torch.ops.vllm.maybe_wait_prefetch_done(x)
|
torch.ops.vllm.maybe_wait_prefetch_done(x)
|
||||||
return x, residual
|
return x, residual
|
||||||
@@ -106,9 +95,8 @@ class AscendRMSNorm(RMSNorm):
|
|||||||
super().__init__(hidden_size, eps, var_hidden_size, has_weight, dtype)
|
super().__init__(hidden_size, eps, var_hidden_size, has_weight, dtype)
|
||||||
vllm_config = get_current_vllm_config()
|
vllm_config = get_current_vllm_config()
|
||||||
self.bias = None
|
self.bias = None
|
||||||
self.torch_npu_check = version_check()
|
|
||||||
# quantization with anti_method m4 will generate none-zero norm bias
|
# quantization with anti_method m4 will generate none-zero norm bias
|
||||||
if self.torch_npu_check and vllm_config.quant_config is not None and \
|
if vllm_config.quant_config is not None and \
|
||||||
any("norm.bias" in name for name in vllm_config.quant_config.quant_description.keys()):
|
any("norm.bias" in name for name in vllm_config.quant_config.quant_description.keys()):
|
||||||
self.bias = torch.nn.Parameter(torch.zeros(hidden_size),
|
self.bias = torch.nn.Parameter(torch.zeros(hidden_size),
|
||||||
requires_grad=False)
|
requires_grad=False)
|
||||||
@@ -128,7 +116,7 @@ class AscendRMSNorm(RMSNorm):
|
|||||||
return x, residual
|
return x, residual
|
||||||
x, residual = torch_npu.npu_rms_norm(x, self.weight,
|
x, residual = torch_npu.npu_rms_norm(x, self.weight,
|
||||||
self.variance_epsilon)
|
self.variance_epsilon)
|
||||||
if self.torch_npu_check and self.bias is not None:
|
if self.bias is not None:
|
||||||
x.add_(self.bias)
|
x.add_(self.bias)
|
||||||
return x
|
return x
|
||||||
|
|
||||||
|
|||||||
@@ -7,7 +7,6 @@ from vllm.forward_context import get_forward_context
|
|||||||
from vllm_ascend.ascend_config import WeightPrefetchConfig
|
from vllm_ascend.ascend_config import WeightPrefetchConfig
|
||||||
from vllm_ascend.ops.linear import (AscendQKVParallelLinear,
|
from vllm_ascend.ops.linear import (AscendQKVParallelLinear,
|
||||||
AscendRowParallelLinear)
|
AscendRowParallelLinear)
|
||||||
from vllm_ascend.utils import version_check
|
|
||||||
|
|
||||||
SUPPORTED_MODULES = ["attn", "mlp", "moe"]
|
SUPPORTED_MODULES = ["attn", "mlp", "moe"]
|
||||||
MOE_PREFETCH_TOKEN_THRESHOLD = 96
|
MOE_PREFETCH_TOKEN_THRESHOLD = 96
|
||||||
@@ -83,8 +82,7 @@ class WeightPrefetchMethod:
|
|||||||
if not self.moe.is_active_this_forward:
|
if not self.moe.is_active_this_forward:
|
||||||
return
|
return
|
||||||
forward_context = get_forward_context()
|
forward_context = get_forward_context()
|
||||||
if not version_check():
|
# layer_idx is subtracted by 1 because layer_idx was incremented by 1 at layernorm.
|
||||||
forward_context.layer_idx += 1
|
|
||||||
weight = forward_context.model_instance.model.layers[
|
weight = forward_context.model_instance.model.layers[
|
||||||
forward_context.layer_idx - 1].mlp.experts.w13_weight
|
forward_context.layer_idx - 1].mlp.experts.w13_weight
|
||||||
weight_size = weight.data.element_size() * weight.data.numel(
|
weight_size = weight.data.element_size() * weight.data.numel(
|
||||||
|
|||||||
@@ -510,8 +510,7 @@ def register_ascend_customop(vllm_config: Optional[VllmConfig] = None):
|
|||||||
from vllm_ascend.ops.activation import AscendQuickGELU, AscendSiluAndMul
|
from vllm_ascend.ops.activation import AscendQuickGELU, AscendSiluAndMul
|
||||||
from vllm_ascend.ops.common_fused_moe import (AscendFusedMoE,
|
from vllm_ascend.ops.common_fused_moe import (AscendFusedMoE,
|
||||||
AscendSharedFusedMoE)
|
AscendSharedFusedMoE)
|
||||||
from vllm_ascend.ops.layernorm import (AscendGemmaRMSNorm,
|
from vllm_ascend.ops.layernorm import AscendGemmaRMSNorm, AscendRMSNorm
|
||||||
AscendQuantRMSNorm, AscendRMSNorm)
|
|
||||||
from vllm_ascend.ops.linear import (AscendColumnParallelLinear,
|
from vllm_ascend.ops.linear import (AscendColumnParallelLinear,
|
||||||
AscendMergedColumnParallelLinear,
|
AscendMergedColumnParallelLinear,
|
||||||
AscendQKVParallelLinear,
|
AscendQKVParallelLinear,
|
||||||
@@ -547,12 +546,6 @@ def register_ascend_customop(vllm_config: Optional[VllmConfig] = None):
|
|||||||
"MultiHeadLatentAttention": AscendMultiHeadLatentAttention,
|
"MultiHeadLatentAttention": AscendMultiHeadLatentAttention,
|
||||||
}
|
}
|
||||||
|
|
||||||
if vllm_config is not None and \
|
|
||||||
vllm_config.quant_config is not None and \
|
|
||||||
any("norm.bias" in name for name in vllm_config.quant_config.quant_description.keys()) and \
|
|
||||||
not version_check():
|
|
||||||
REGISTERED_ASCEND_OPS["RMSNorm"] = AscendQuantRMSNorm
|
|
||||||
|
|
||||||
for name, op_cls in REGISTERED_ASCEND_OPS.items():
|
for name, op_cls in REGISTERED_ASCEND_OPS.items():
|
||||||
CustomOp.register_oot(_decorated_op_cls=op_cls, name=name)
|
CustomOp.register_oot(_decorated_op_cls=op_cls, name=name)
|
||||||
|
|
||||||
@@ -743,21 +736,6 @@ def is_hierarchical_communication_enabled():
|
|||||||
and os.getenv("HCCL_INTRA_PCIE_ENABLE", "") == "1")
|
and os.getenv("HCCL_INTRA_PCIE_ENABLE", "") == "1")
|
||||||
|
|
||||||
|
|
||||||
@functools.cache
|
|
||||||
def version_check():
|
|
||||||
"""check if torch_npu version >= dev20250919"""
|
|
||||||
import re
|
|
||||||
torch_npu_version = torch_npu.version.__version__
|
|
||||||
date_pattern = r'dev(\d{8})'
|
|
||||||
|
|
||||||
match = re.search(date_pattern, torch_npu_version)
|
|
||||||
if match:
|
|
||||||
full_date = match.group(1)
|
|
||||||
if full_date >= "20250919":
|
|
||||||
return True
|
|
||||||
return False
|
|
||||||
|
|
||||||
|
|
||||||
def has_layer_idx(model_instance: torch.nn.Module) -> bool:
|
def has_layer_idx(model_instance: torch.nn.Module) -> bool:
|
||||||
if model_instance is None:
|
if model_instance is None:
|
||||||
return False
|
return False
|
||||||
|
|||||||
Reference in New Issue
Block a user