diff --git a/.github/workflows/vllm_ascend_test.yaml b/.github/workflows/vllm_ascend_test.yaml index 7987f5b..8b22625 100644 --- a/.github/workflows/vllm_ascend_test.yaml +++ b/.github/workflows/vllm_ascend_test.yaml @@ -258,4 +258,4 @@ jobs: VLLM_WORKER_MULTIPROC_METHOD: spawn VLLM_USE_MODELSCOPE: True run: | - pytest -sv tests/e2e/multicard/test_qwen3_moe.py::test_models_distributed_Qwen3_MOE_TP2_WITH_EP + pytest -sv tests/e2e/multicard/test_qwen3_moe.py::test_models_distributed_Qwen3_MOE_TP2_WITH_EP \ No newline at end of file diff --git a/.github/workflows/vllm_ascend_test_full.yaml b/.github/workflows/vllm_ascend_test_full.yaml index 04e654e..779b2e6 100644 --- a/.github/workflows/vllm_ascend_test_full.yaml +++ b/.github/workflows/vllm_ascend_test_full.yaml @@ -226,6 +226,8 @@ jobs: pytest -sv tests/e2e/multicard/test_offline_inference_distributed.py::test_models_distributed_Qwen3_W4A8DYNAMIC pytest -sv tests/e2e/multicard/test_offline_inference_distributed.py::test_models_distributed_DeepSeek_W4A8DYNAMIC pytest -sv tests/e2e/multicard/test_offline_inference_distributed.py::test_sp_for_qwen3_moe + pytest -sv tests/e2e/multicard/test_offline_inference_distributed.py::test_models_distributed_Qwen_Dense_with_flashcomm_v1 + pytest -sv tests/e2e/multicard/test_offline_inference_distributed.py::test_models_distributed_Qwen_Dense_with_prefetch_mlp_weight #pytest -sv tests/e2e/multicard/test_pipeline_parallel.py pytest -sv tests/e2e/multicard/test_prefix_caching.py diff --git a/tests/e2e/multicard/test_offline_inference_distributed.py b/tests/e2e/multicard/test_offline_inference_distributed.py index 4fd72ce..907ac40 100644 --- a/tests/e2e/multicard/test_offline_inference_distributed.py +++ b/tests/e2e/multicard/test_offline_inference_distributed.py @@ -31,7 +31,9 @@ from tests.e2e.conftest import VllmRunner os.environ["PYTORCH_NPU_ALLOC_CONF"] = "max_split_size_mb:256" -QWEN_DENSE_MODELS = ["Qwen/QwQ-32B", "Qwen/Qwen-32B"] +QWEN_DENSE_MODELS = [ + "vllm-ascend/Qwen3-8B-W8A8", "vllm-ascend/Qwen2.5-0.5B-Instruct-W8A8" +] def test_models_distributed_QwQ(): @@ -170,6 +172,29 @@ def test_models_distributed_Qwen_Dense_with_flashcomm_v1(model, enforce_eager): max_model_len=8192, enforce_eager=enforce_eager, dtype="auto", - tensor_parallel_size=4, + tensor_parallel_size=2, + quantization="ascend", + ) as vllm_model: + vllm_model.generate_greedy(example_prompts, max_tokens) + + +@pytest.mark.parametrize("enforce_eager", [True, False]) +@pytest.mark.parametrize("model", QWEN_DENSE_MODELS) +@patch.dict(os.environ, {"VLLM_ASCEND_ENABLE_DENSE_OPTIMIZE": "1"}) +@patch.dict(os.environ, {"VLLM_ASCEND_ENABLE_PREFETCH_MLP": "1"}) +def test_models_distributed_Qwen_Dense_with_prefetch_mlp_weight( + model, enforce_eager): + example_prompts = [ + "Hello, my name is", + ] + max_tokens = 5 + + with VllmRunner( + snapshot_download(model), + max_model_len=8192, + enforce_eager=enforce_eager, + dtype="auto", + tensor_parallel_size=2, + quantization="ascend", ) as vllm_model: vllm_model.generate_greedy(example_prompts, max_tokens) diff --git a/tests/ut/ops/test_activation.py b/tests/ut/ops/test_activation.py index b90ccff..76bc55d 100644 --- a/tests/ut/ops/test_activation.py +++ b/tests/ut/ops/test_activation.py @@ -38,7 +38,12 @@ def test_QuickGELU_forward(mock_gelu, dummy_tensor): @pytest.mark.parametrize("is_310p_return", [True, False]) @patch("torch_npu.npu_swiglu", side_effect=lambda x: x + 1) -def test_SiluAndMul_forward(mock_swiglu, is_310p_return, dummy_tensor): +@patch("torch.ops.vllm.maybe_wait_prefetch_done", side_effect=lambda x: None) +@patch("torch.ops.vllm.maybe_prefetch_mlp_down_proj", + side_effect=lambda x: None) +def test_SiluAndMul_forward(mock_maybe_prefetch_mlp_down_proj, + mock_maybe_wait_prefetch_done, mock_swiglu, + is_310p_return, dummy_tensor): with patch("vllm_ascend.utils.is_310p", return_value=is_310p_return): layer = SiluAndMul() @@ -49,9 +54,15 @@ def test_SiluAndMul_forward(mock_swiglu, is_310p_return, dummy_tensor): else: expected_arg = dummy_tensor + # assert mock_maybe_prefetch_mlp_down_proj.call_count == 1 + mock_maybe_prefetch_mlp_down_proj.assert_called_once() + # assert mock_swiglu.call_count == 1 mock_swiglu.assert_called_once() + # assert mock_maybe_wait_prefetch_done.call_count == 1 + mock_maybe_wait_prefetch_done.assert_called_once() + actual_arg = mock_swiglu.call_args[0][0] assert torch.allclose( actual_arg, diff --git a/tests/ut/ops/test_layernorm.py b/tests/ut/ops/test_layernorm.py index a2d9877..3bed078 100644 --- a/tests/ut/ops/test_layernorm.py +++ b/tests/ut/ops/test_layernorm.py @@ -30,9 +30,11 @@ def mock_add_rms_norm(x, residual, weight, eps): [None, torch.randn(4, 8, dtype=torch.float32)]) @patch("torch_npu.npu_rms_norm", side_effect=mock_rms_norm) @patch("torch_npu.npu_add_rms_norm", side_effect=mock_add_rms_norm) +@patch("torch.ops.vllm.maybe_wait_prefetch_done", side_effect=lambda x: None) @patch("torch.ops.vllm.maybe_chunk_residual", side_effect=mock_maybe_chunk_residual) -def test_RMSNorm_forward(mock_maybe_chunk_residual, mock_add_rmsnorm, +def test_RMSNorm_forward(mock_maybe_chunk_residual, + mock_maybe_wait_prefetch_done, mock_add_rmsnorm, mock_rmsnorm, is_310p_return, residual, dummy_tensor): with patch("vllm_ascend.utils.is_310p", return_value=is_310p_return): @@ -45,13 +47,17 @@ def test_RMSNorm_forward(mock_maybe_chunk_residual, mock_add_rmsnorm, expected_out_x = expected_arg_x + 1 expected_out_residual = expected_arg_x.to(residual.dtype) + mock_maybe_chunk_residual.assert_called_once() mock_rmsnorm.assert_called_once() + mock_maybe_wait_prefetch_done.assert_called_once() assert torch.allclose(out_x, expected_out_x) assert torch.allclose(out_residual, expected_out_residual) else: expected_out_x = 2 * dummy_tensor expected_out_residual = 2 * residual + mock_maybe_chunk_residual.assert_called_once() mock_add_rmsnorm.assert_called_once() + mock_maybe_wait_prefetch_done.assert_called_once() assert torch.allclose(out_x, expected_out_x) assert torch.allclose(out_residual, expected_out_residual) else: @@ -64,9 +70,11 @@ def test_RMSNorm_forward(mock_maybe_chunk_residual, mock_add_rmsnorm, @patch("vllm_ascend.utils.is_310p", return_value=False) @patch("torch_npu.npu_add_rms_norm", side_effect=mock_add_rms_norm) +@patch("torch.ops.vllm.maybe_wait_prefetch_done", side_effect=lambda x: None) @patch("torch.ops.vllm.maybe_chunk_residual", side_effect=mock_maybe_chunk_residual) def test_RMSNorm_forward_with_flashcomm_v1(mock_maybe_chunk_residual, + mock_maybe_wait_prefetch_done, mock_add_rms_norm, mock_is310p): x = torch.randn(4, 512, dtype=torch.bfloat16) residual = torch.randn(16, 512, dtype=torch.bfloat16) @@ -79,6 +87,7 @@ def test_RMSNorm_forward_with_flashcomm_v1(mock_maybe_chunk_residual, mock_maybe_chunk_residual.assert_called_once() mock_add_rms_norm.assert_called_once() + mock_maybe_wait_prefetch_done.assert_called_once() assert out_residual.size(0) == 4 assert torch.allclose(out_x, expected_out_x) assert torch.allclose(out_residual, expected_out_residual) diff --git a/tests/ut/torchair/models/test_torchair_deepseek_v2.py b/tests/ut/torchair/models/test_torchair_deepseek_v2.py index 912ff9a..3942144 100644 --- a/tests/ut/torchair/models/test_torchair_deepseek_v2.py +++ b/tests/ut/torchair/models/test_torchair_deepseek_v2.py @@ -275,7 +275,12 @@ def test_torchair_deepseek_v2_mla_attention(mock_rms_norm, mock_distributed, @patch("torch_npu.npu_add_rms_norm") @patch("torch_npu.npu_rms_norm") -def test_torchair_deepseek_v2_decoder_layer(mock_rms_norm, mock_add_norm, +@patch("torch.ops.vllm.maybe_wait_prefetch_done", side_effect=lambda x: None) +@patch("torch.ops.vllm.maybe_chunk_residual", + side_effect=lambda x, residual: residual) +def test_torchair_deepseek_v2_decoder_layer(mock_maybe_chunk_residual, + mock_maybe_wait_prefetch_done, + mock_rms_norm, mock_add_norm, mock_distributed, base_config, vllm_config): mock_rms_norm.return_value = (torch.randn(2, 128), torch.randn(2, 128)) diff --git a/vllm_ascend/ascend_forward_context.py b/vllm_ascend/ascend_forward_context.py index be38b9d..d107a9e 100644 --- a/vllm_ascend/ascend_forward_context.py +++ b/vllm_ascend/ascend_forward_context.py @@ -66,7 +66,9 @@ def set_ascend_forward_context( moe_comm_method: str = "", num_actual_tokens: Optional[int] = None, aclgraph_runtime_mode: CUDAGraphMode = CUDAGraphMode.NONE, - batch_descriptor: Optional[BatchDescriptor] = None): + batch_descriptor: Optional[BatchDescriptor] = None, + prefetch_stream: torch.npu.Stream = None, + model_instance: torch.nn.Module = None): """A context manager that stores the current forward context, can be attention metadata, etc. We add some additional param into forward_context. @@ -108,7 +110,8 @@ def set_ascend_forward_context( # Currently, it is an empirical value. In normal scenarios, if the concurrency exceeds this threshold, # the performance benefits can be maximized. Conversely, if the concurrency is below the threshold, # the performance may degrade due to the switching of communication methods. - flashcomm_v1_enabled = envs_ascend.VLLM_ASCEND_ENABLE_FLASHCOMM and \ + flashcomm_v1_enabled = envs_ascend.VLLM_ASCEND_ENABLE_DENSE_OPTIMIZE and \ + envs_ascend.VLLM_ASCEND_ENABLE_FLASHCOMM and \ tp_world_size > 1 and \ num_tokens is not None and num_tokens > 1000 @@ -122,6 +125,26 @@ def set_ascend_forward_context( # set this for rope forward_oot using forward_context.is_first_layer = True + # set layer_idx to enable optimization features that depend on this information. + # This is only applicable to models that contain these necessary attributes. + forward_context.layer_idx = None + if model_instance is not None and \ + hasattr(model_instance, "model") and \ + hasattr(model_instance.model, "start_layer"): + forward_context.layer_idx = model_instance.model.start_layer + + # set for mlp weight prefetch + prefetch_mlp_enabled = envs_ascend.VLLM_ASCEND_ENABLE_DENSE_OPTIMIZE and \ + envs_ascend.VLLM_ASCEND_ENABLE_PREFETCH_MLP and \ + forward_context.layer_idx is not None and \ + num_tokens is not None and num_tokens < 500 + if prefetch_mlp_enabled: + forward_context.prefetch_stream = prefetch_stream + forward_context.model_instance = model_instance + forward_context.prefetch_mlp_gate_up_proj = False + forward_context.prefetch_mlp_down_proj = False + forward_context.prefetch_mlp_enabled = prefetch_mlp_enabled + if num_tokens is None and attn_metadata is not None: num_tokens = attn_metadata.num_actual_tokens diff --git a/vllm_ascend/envs.py b/vllm_ascend/envs.py index a9ae83d..ef8e33e 100644 --- a/vllm_ascend/envs.py +++ b/vllm_ascend/envs.py @@ -135,6 +135,15 @@ env_variables: Dict[str, Callable[[], Any]] = { # This feature will get better performance when concurrency is large. "VLLM_ASCEND_ENABLE_FLASHCOMM": lambda: bool(int(os.getenv("VLLM_ASCEND_ENABLE_FLASHCOMM", '0'))), + # Whether to enable MLP weight prefetch, only used in small concurrency. + "VLLM_ASCEND_ENABLE_PREFETCH_MLP": + lambda: bool(int(os.getenv("VLLM_ASCEND_ENABLE_PREFETCH_MLP", '0'))), + # buffer size for gate up prefetch + "MLP_GATE_UP_PREFETCH_SIZE": + lambda: int(os.getenv("MLP_GATE_UP_PREFETCH_SIZE", 18 * 1024 * 1024)), + # buffer size for down proj prefetch + "MLP_DOWN_PREFETCH_SIZE": + lambda: int(os.getenv("MLP_DOWN_PREFETCH_SIZE", 18 * 1024 * 1024)), # Whether to enable dense model and general optimizations for better performance. # Since we modified the base parent class `linear`, this optimization is also applicable to other model types. # However, there might be hidden issues, and it is currently recommended to prioritize its use with dense models. diff --git a/vllm_ascend/ops/activation.py b/vllm_ascend/ops/activation.py index 26082fe..fb1abe6 100644 --- a/vllm_ascend/ops/activation.py +++ b/vllm_ascend/ops/activation.py @@ -35,8 +35,10 @@ class AscendSiluAndMul(SiluAndMul): from vllm_ascend.utils import is_310p + torch.ops.vllm.maybe_prefetch_mlp_down_proj(x) if is_310p(): out = torch_npu.npu_swiglu(x.to(torch.float32)).to(torch.float16) else: out = torch_npu.npu_swiglu(x) + torch.ops.vllm.maybe_wait_prefetch_done(out) return out diff --git a/vllm_ascend/ops/layernorm.py b/vllm_ascend/ops/layernorm.py index adaa73c..d97d771 100644 --- a/vllm_ascend/ops/layernorm.py +++ b/vllm_ascend/ops/layernorm.py @@ -44,12 +44,7 @@ class AddRMSNormW8A8Quant(RMSNorm): import torch_npu if residual is not None: - # FIXME(rjg-lyh): This is a hacky way to chunk residuals when the flashcomm_v1 feature - # is enabled, without interfering with the normal operation of components like torchair. - # The final solution should be to move this check into the operator and support - # integration with torchair. - if x.size(0) != residual.size(0): - residual = torch.ops.vllm.maybe_chunk_residual(x, residual) + residual = torch.ops.vllm.maybe_chunk_residual(x, residual) assert x.size(0) == residual.size(0) x, _, residual = torch_npu.npu_add_rms_norm_quant( x, @@ -58,6 +53,7 @@ class AddRMSNormW8A8Quant(RMSNorm): self.layer.aclnn_input_scale, self.layer.aclnn_input_offset, epsilon=self.variance_epsilon) + torch.ops.vllm.maybe_wait_prefetch_done(x) return x, residual x, residual = torch_npu.npu_rms_norm(x, self.weight, @@ -76,12 +72,7 @@ class AscendRMSNorm(RMSNorm): from vllm_ascend.utils import is_310p if residual is not None: - # FIXME(rjg-lyh): This is a hacky way to chunk residuals when the flashcomm_v1 feature - # is enabled, without interfering with the normal operation of components like torchair. - # The final solution should be to move this check into the operator and support - # integration with torchair. - if x.size(0) != residual.size(0): - residual = torch.ops.vllm.maybe_chunk_residual(x, residual) + residual = torch.ops.vllm.maybe_chunk_residual(x, residual) assert x.size(0) == residual.size(0) if is_310p(): orig_dtype = residual.dtype @@ -92,6 +83,7 @@ class AscendRMSNorm(RMSNorm): else: x, _, residual = torch_npu.npu_add_rms_norm( x, residual, self.weight, self.variance_epsilon) + torch.ops.vllm.maybe_wait_prefetch_done(x) return x, residual x, residual = torch_npu.npu_rms_norm(x, self.weight, diff --git a/vllm_ascend/ops/linear.py b/vllm_ascend/ops/linear.py index 0c6f430..9b472a7 100644 --- a/vllm_ascend/ops/linear.py +++ b/vllm_ascend/ops/linear.py @@ -390,6 +390,7 @@ class AscendRowParallelLinear(RowParallelLinear): input_parallel, bias=bias_) output = torch.ops.vllm.maybe_pad_and_reduce(output_parallel) + torch.ops.vllm.maybe_prefetch_mlp_gate_up_proj(output, self.prefix) output_bias = self.bias if self.skip_bias_add else None diff --git a/vllm_ascend/ops/register_custom_ops.py b/vllm_ascend/ops/register_custom_ops.py index 0391258..d066dc9 100644 --- a/vllm_ascend/ops/register_custom_ops.py +++ b/vllm_ascend/ops/register_custom_ops.py @@ -1,5 +1,6 @@ import torch import torch.nn.functional as F +import torch_npu from vllm.distributed import (get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size, tensor_model_parallel_all_gather, @@ -8,10 +9,16 @@ from vllm.distributed import (get_tensor_model_parallel_rank, from vllm.forward_context import get_forward_context from vllm.utils import direct_register_custom_op +import vllm_ascend.envs as envs_ascend + def _maybe_chunk_residual_impl(x: torch.Tensor, residual: torch.Tensor) -> torch.Tensor: - if get_forward_context().flashcomm_v1_enabled: + if x.size(0) != residual.size(0): + flashcomm_v1_enabled = get_forward_context().flashcomm_v1_enabled + assert flashcomm_v1_enabled is True, ( + "Currently, this situation only occurs " + "when flashcomm_v1 is enabled") pad_size = get_forward_context().pad_size if pad_size > 0: residual = F.pad(residual, (0, 0, 0, pad_size)) @@ -44,6 +51,76 @@ def _maybe_pad_and_reduce_impl(x: torch.Tensor) -> torch.Tensor: return tensor_model_parallel_all_reduce(x) +def _maybe_prefetch_mlp_gate_up_proj_impl(x_dependency: torch.Tensor, + prefix: str) -> None: + forward_context = get_forward_context() + if not forward_context.prefetch_mlp_enabled: + return + model_instance = forward_context.model_instance + prefetch_stream = forward_context.prefetch_stream + layer_idx = int(prefix.split('.')[2]) + + # start point of gate_up_proj weight prefetch + if prefix.split('.')[-2] == "self_attn": + forward_context.prefetch_mlp_gate_up_proj = True + if forward_context.prefetch_mlp_gate_up_proj: + prefetch_stream.wait_stream(torch.npu.current_stream()) + + with torch.npu.stream(prefetch_stream): + MLP_GATE_UP_PREFETCH_SIZE = envs_ascend.MLP_GATE_UP_PREFETCH_SIZE + torch_npu.npu_prefetch(model_instance.model.layers[layer_idx].mlp.gate_up_proj.weight, \ + x_dependency, MLP_GATE_UP_PREFETCH_SIZE) + return + + +def _maybe_prefetch_mlp_gate_up_proj_impl_fake(x_dependency: torch.Tensor, + prefix: str) -> None: + return + + +def _maybe_prefetch_mlp_down_proj_impl(x_dependency: torch.Tensor) -> None: + forward_context = get_forward_context() + if not forward_context.prefetch_mlp_enabled: + return + forward_context.prefetch_mlp_down_proj = True + model_instance = forward_context.model_instance + prefetch_stream = forward_context.prefetch_stream + layer_idx = forward_context.layer_idx + + # start point of down_proj weight prefetch + prefetch_stream.wait_stream(torch.npu.current_stream()) + + with torch.npu.stream(prefetch_stream): + MLP_DOWN_PREFETCH_SIZE = envs_ascend.MLP_DOWN_PREFETCH_SIZE + torch_npu.npu_prefetch(model_instance.model.layers[layer_idx].mlp.down_proj.weight, \ + x_dependency, MLP_DOWN_PREFETCH_SIZE) + forward_context.layer_idx += 1 + return + + +def _maybe_prefetch_mlp_down_proj_impl_fake( + x_dependency: torch.Tensor) -> None: + return + + +def _maybe_wait_prefetch_done_impl(x: torch.Tensor) -> None: + forward_context = get_forward_context() + if not forward_context.prefetch_mlp_enabled: + return + if forward_context.prefetch_mlp_gate_up_proj or \ + forward_context.prefetch_mlp_down_proj: + prefetch_stream = get_forward_context().prefetch_stream + # wait until prefetch done + torch.npu.current_stream().wait_stream(prefetch_stream) + forward_context.prefetch_mlp_gate_up_proj = False + forward_context.prefetch_mlp_down_proj = False + return + + +def _maybe_wait_prefetch_done_impl_fake(x: torch.Tensor) -> None: + return + + direct_register_custom_op(op_name="maybe_chunk_residual", op_func=_maybe_chunk_residual_impl, fake_impl=lambda x, residual: residual, @@ -60,4 +137,22 @@ direct_register_custom_op(op_name="maybe_pad_and_reduce", op_func=_maybe_pad_and_reduce_impl, fake_impl=lambda x: x, mutates_args=[], - dispatch_key="PrivateUse1") \ No newline at end of file + dispatch_key="PrivateUse1") + +direct_register_custom_op(op_name="maybe_prefetch_mlp_gate_up_proj", + op_func=_maybe_prefetch_mlp_gate_up_proj_impl, + fake_impl=_maybe_prefetch_mlp_gate_up_proj_impl_fake, + mutates_args=[], + dispatch_key="PrivateUse1") + +direct_register_custom_op(op_name="maybe_prefetch_mlp_down_proj", + op_func=_maybe_prefetch_mlp_down_proj_impl, + fake_impl=_maybe_prefetch_mlp_down_proj_impl_fake, + mutates_args=[], + dispatch_key="PrivateUse1") + +direct_register_custom_op(op_name="maybe_wait_prefetch_done", + op_func=_maybe_wait_prefetch_done_impl, + fake_impl=_maybe_wait_prefetch_done_impl_fake, + mutates_args=[], + dispatch_key="PrivateUse1") diff --git a/vllm_ascend/torchair/ops/torchair_activation.py b/vllm_ascend/torchair/ops/torchair_activation.py new file mode 100644 index 0000000..0721ea0 --- /dev/null +++ b/vllm_ascend/torchair/ops/torchair_activation.py @@ -0,0 +1,37 @@ +# +# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# This file is a part of the vllm-ascend project. +# + +import torch + + +def torchair_silu_and_mul_forward_oot(self, x: torch.Tensor) -> torch.Tensor: + """AscendSiluAndMul forward in torchair mode. + + The key difference from the original implementation is the removal of operators + from the torch.ops.vllm class, as these operators only function in non-torchair + modes. Adding them back would cause the graph compilation to fail. + """ + + import torch_npu + + from vllm_ascend.utils import is_310p + + if is_310p(): + out = torch_npu.npu_swiglu(x.to(torch.float32)).to(torch.float16) + else: + out = torch_npu.npu_swiglu(x) + return out diff --git a/vllm_ascend/torchair/ops/torchair_layernorm.py b/vllm_ascend/torchair/ops/torchair_layernorm.py new file mode 100644 index 0000000..d90f889 --- /dev/null +++ b/vllm_ascend/torchair/ops/torchair_layernorm.py @@ -0,0 +1,51 @@ +# +# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# This file is a part of the vllm-ascend project. +# + +from typing import Optional, Tuple, Union + +import torch + + +def torchair_rmsnorm_forward_oot( + self, + x: torch.Tensor, + residual: Optional[torch.Tensor] = None, +) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: + """AscendRMSNorm forward in torchair mode. + + The key difference from the original implementation is the removal of operators + from the torch.ops.vllm class, as these operators only function in non-torchair + modes. Adding them back would cause the graph compilation to fail. + """ + + import torch_npu + + from vllm_ascend.utils import is_310p + if residual is not None: + if is_310p(): + orig_dtype = residual.dtype + x = x + residual.to(x.dtype) + residual = x.to(orig_dtype) + x, _ = torch_npu.npu_rms_norm(x, self.weight, + self.variance_epsilon) + else: + x, _, residual = torch_npu.npu_add_rms_norm( + x, residual, self.weight, self.variance_epsilon) + return x, residual + + x, residual = torch_npu.npu_rms_norm(x, self.weight, self.variance_epsilon) + return x diff --git a/vllm_ascend/torchair/utils.py b/vllm_ascend/torchair/utils.py index fcf2914..56b8c71 100644 --- a/vllm_ascend/torchair/utils.py +++ b/vllm_ascend/torchair/utils.py @@ -199,8 +199,12 @@ def torchair_quant_method_register(): def torchair_ops_patch(): + from vllm_ascend.ops.activation import AscendSiluAndMul + from vllm_ascend.ops.layernorm import AscendRMSNorm from vllm_ascend.ops.rotary_embedding import ( AscendDeepseekScalingRotaryEmbedding, AscendRotaryEmbedding) + from vllm_ascend.torchair.ops import (torchair_activation, + torchair_layernorm) from vllm_ascend.torchair.ops.torchair_rotary_embedding import ( deepseek_rope_init_func, native_rope_deepseek_forward, qwen_rope_init_func, rope_forward) @@ -210,3 +214,6 @@ def torchair_ops_patch(): AscendDeepseekScalingRotaryEmbedding.__init__ = deepseek_rope_init_func # type: ignore[method-assign] AscendDeepseekScalingRotaryEmbedding.forward = native_rope_deepseek_forward # type: ignore[method-assign] + + AscendRMSNorm.forward_oot = torchair_layernorm.torchair_rmsnorm_forward_oot # type: ignore[method-assign] + AscendSiluAndMul.forward_oot = torchair_activation.torchair_silu_and_mul_forward_oot # type: ignore[method-assign] diff --git a/vllm_ascend/worker/model_runner_v1.py b/vllm_ascend/worker/model_runner_v1.py index a7c8966..ab8f593 100644 --- a/vllm_ascend/worker/model_runner_v1.py +++ b/vllm_ascend/worker/model_runner_v1.py @@ -227,6 +227,10 @@ class NPUModelRunner(LoRAModelRunnerMixin): self.dp_size = vllm_config.parallel_config.data_parallel_size self.dp_rank = vllm_config.parallel_config.data_parallel_rank self.device = device + if envs_ascend.VLLM_ASCEND_ENABLE_PREFETCH_MLP: + self.prefetch_stream = torch.npu.Stream(device=device) + else: + self.prefetch_stream = None self.dtype = self.model_config.dtype if envs_ascend.VLLM_ASCEND_ENABLE_TOPK_TOPP_OPTIMIZATION: # TODO: drop the env config to use ascend sampler by default @@ -1592,7 +1596,9 @@ class NPUModelRunner(LoRAModelRunnerMixin): aclgraph_runtime_mode=aclgraph_runtime_mode, batch_descriptor=batch_descriptor, num_actual_tokens=scheduler_output. - total_num_scheduled_tokens): + total_num_scheduled_tokens, + prefetch_stream=self.prefetch_stream, + model_instance=self.model): self.maybe_setup_kv_connector(scheduler_output) hidden_states = self._generate_process_reqs_hidden_states( @@ -2057,7 +2063,9 @@ class NPUModelRunner(LoRAModelRunnerMixin): moe_comm_method=moe_comm_method, num_actual_tokens=0, aclgraph_runtime_mode=aclgraph_runtime_mode, - batch_descriptor=batch_descriptor): + batch_descriptor=batch_descriptor, + prefetch_stream=self.prefetch_stream, + model_instance=self.model): hidden_states = self._generate_dummy_run_hidden_states( with_prefill, is_torchair_compile, input_ids, positions, attn_metadata, num_tokens, intermediate_tensors, diff --git a/vllm_ascend/worker/worker_v1.py b/vllm_ascend/worker/worker_v1.py index 681c3f8..ef23645 100644 --- a/vllm_ascend/worker/worker_v1.py +++ b/vllm_ascend/worker/worker_v1.py @@ -51,6 +51,18 @@ from vllm_ascend.utils import (init_ascend_soc_version, try_register_lib) from vllm_ascend.worker.model_runner_v1 import NPUModelRunner +torch._dynamo.trace_rules.clear_lru_cache() # noqa: E402 +from torch._dynamo.variables import TorchInGraphFunctionVariable # noqa: E402 + +torch_non_c_binding_in_graph_functions_npu = dict.fromkeys( + ["torch.npu.current_stream"], + TorchInGraphFunctionVariable, +) # noqa: E402 +torch_non_c_binding_in_graph_functions_npu[ + "torch.npu.stream"] = TorchInGraphFunctionVariable # noqa: E402 +torch._dynamo.trace_rules.torch_name_rule_map.append( + torch_non_c_binding_in_graph_functions_npu) # noqa: E402 + class NPUWorker(WorkerBase):