From c3a8d13ca7894fe323ec37e7b505ff42249249a2 Mon Sep 17 00:00:00 2001
From: Wang Kunpeng <1289706727@qq.com>
Date: Tue, 23 Dec 2025 08:49:52 +0800
Subject: [PATCH] [refactor] Remove unnecessary attributes from
 set_ascend_forward_context (#5204)

### What this PR does / why we need it?
Remove unnecessary attributes from set_ascend_forward_context:
1. prefetch_stream
2. weight_prefetch_method

### Does this PR introduce _any_ user-facing change?
no

### How was this patch tested?

- vLLM version: v0.12.0
- vLLM main: https://github.com/vllm-project/vllm/commit/ad32e3e19ccf0526cb6744a5fed09a138a5fb2f9

Signed-off-by: Wang Kunpeng <1289706727@qq.com>
---
 tests/e2e/nightly/ops/test_fused_moe.py       | 22 ++++--------------
 tests/ut/ops/test_fused_moe.py                | 23 +++++++++----------
 tests/ut/quantization/test_w8a8.py            |  9 +++-----
 vllm_ascend/ascend_forward_context.py         | 12 +---------
 vllm_ascend/ops/fused_moe/experts_selector.py |  5 ++--
 vllm_ascend/ops/fused_moe/moe_mlp.py          |  5 ++--
 vllm_ascend/ops/register_custom_ops.py        | 16 ++++++-------
 vllm_ascend/quantization/w8a8.py              | 11 +++------
 vllm_ascend/utils.py                          | 15 +++++++++++-
 vllm_ascend/worker/model_runner_v1.py         | 20 ++++-----------
 10 files changed, 55 insertions(+), 83 deletions(-)

diff --git a/tests/e2e/nightly/ops/test_fused_moe.py b/tests/e2e/nightly/ops/test_fused_moe.py
index 8fcac0e4..971d9310 100644
--- a/tests/e2e/nightly/ops/test_fused_moe.py
+++ b/tests/e2e/nightly/ops/test_fused_moe.py
@@ -286,8 +286,8 @@ def test_select_experts(
 
     with patch("vllm_ascend.ops.fused_moe.experts_selector._native_grouped_topk"
                ) as mock_native_grouped_topk, \
-        patch('vllm_ascend.ops.fused_moe.experts_selector.get_forward_context',
-              return_value=MagicMock(weight_prefetch_method=MagicMock())):
+        patch('vllm_ascend.ops.fused_moe.experts_selector.get_weight_prefetch_method',
+              return_value=MagicMock()):
         mock_native_grouped_topk.side_effect = lambda x, num_groups, k: torch.randn_like(
             x)
 
@@ -323,8 +323,8 @@ def test_select_experts(
 
 @pytest.mark.parametrize("device", DEVICE)
 def test_select_experts_invalid_scoring_func(device: str):
-    with patch('vllm_ascend.ops.fused_moe.experts_selector.get_forward_context',
-               return_value=MagicMock(weight_prefetch_method=MagicMock())), \
+    with patch('vllm_ascend.ops.fused_moe.experts_selector.get_weight_prefetch_method',
+               return_value=MagicMock()), \
         pytest.raises(ValueError,
                       match="Unsupported scoring function: invalid"):
         select_experts(hidden_states=torch.randn(1, 128, device=device),
@@ -336,17 +336,3 @@ def test_select_experts_invalid_scoring_func(device: str):
     gc.collect()
     torch.npu.empty_cache()
     torch.npu.reset_peak_memory_stats()
-
-
-@pytest.mark.parametrize("device", DEVICE)
-def test_select_experts_missing_group_params(device: str):
-    with pytest.raises(AssertionError):
-        select_experts(hidden_states=torch.randn(1, 128, device=device),
-                       router_logits=torch.randn(1, 64, device=device),
-                       top_k=2,
-                       use_grouped_topk=True,
-                       renormalize=False,
-                       scoring_func="softmax")
-    gc.collect()
-    torch.npu.empty_cache()
-    torch.npu.reset_peak_memory_stats()
diff --git a/tests/ut/ops/test_fused_moe.py b/tests/ut/ops/test_fused_moe.py
index 96daaa25..215b076d 100644
--- a/tests/ut/ops/test_fused_moe.py
+++ b/tests/ut/ops/test_fused_moe.py
@@ -92,15 +92,14 @@ def mock_dist_env(mocker: MockerFixture):
     mock_moe_comm_method.finalize.side_effect = mock_finalize
     dp_metadata = MagicMock(num_tokens_across_dp_cpu=[5, 5])
     mock_weight_prefetch_method = MagicMock()
-    mock_forward_context_obj = MagicMock(
-        moe_comm_method=mock_moe_comm_method,
-        moe_comm_type=MoECommType.MC2,
-        max_tokens_across_dp=10,
-        dp_metadata=dp_metadata,
-        mc2_mask=torch.zeros(16, dtype=torch.bool),
-        padded_num_tokens=16,
-        with_quant=False,
-        weight_prefetch_method=mock_weight_prefetch_method)
+    mock_forward_context_obj = MagicMock(moe_comm_method=mock_moe_comm_method,
+                                         moe_comm_type=MoECommType.MC2,
+                                         max_tokens_across_dp=10,
+                                         dp_metadata=dp_metadata,
+                                         mc2_mask=torch.zeros(
+                                             16, dtype=torch.bool),
+                                         padded_num_tokens=16,
+                                         with_quant=False)
 
     with patch('torch.distributed.get_rank', return_value=0), \
         patch('torch.distributed.get_world_size', return_value=4), \
@@ -133,8 +132,8 @@ def mock_dist_env(mocker: MockerFixture):
                return_value=None), \
         patch('vllm_ascend.ops.fused_moe.moe_comm_method.AllGatherCommImpl._get_token_dispatcher',
               return_value=None), \
-        patch('vllm_ascend.ops.fused_moe.experts_selector.get_forward_context',
-              return_value=mock_forward_context_obj):
+        patch('vllm_ascend.ops.fused_moe.experts_selector.get_weight_prefetch_method',
+              return_value=mock_weight_prefetch_method):
 
         yield {
             'mock_forward_context_obj': mock_forward_context_obj,
@@ -590,4 +589,4 @@ class TestUnifiedApplyMLP(TestBase):
         self.assertTrue(mock_forward_context.with_quant)
 
         self.assertEqual(result.shape, hidden_states_shape)
-        self.assertEqual(result.dtype, torch.bfloat16)
\ No newline at end of file
+        self.assertEqual(result.dtype, torch.bfloat16)
diff --git a/tests/ut/quantization/test_w8a8.py b/tests/ut/quantization/test_w8a8.py
index 9fa549b2..b8639cc4 100644
--- a/tests/ut/quantization/test_w8a8.py
+++ b/tests/ut/quantization/test_w8a8.py
@@ -63,21 +63,18 @@ class TestAscendW8A8LinearMethod(TestBase):
         self.assertEqual(params['weight_scale'].shape, (10, 1))
         self.assertEqual(params['weight_offset'].shape, (10, 1))
 
-    @patch("vllm_ascend.quantization.w8a8.get_forward_context")
+    @patch("vllm_ascend.quantization.w8a8.get_weight_prefetch_method")
     @patch("torch.ops.vllm.quantize")
     @patch("torch_npu.npu_quant_matmul")
     def test_apply_with_x_not_int8(self, mock_npu_quant_matmul, mock_quantize,
-                                   mock_get_forward_context):
+                                   mock_get_weight_prefetch_method):
         layer = MagicMock()
         layer.aclnn_input_scale = 0.1
         layer.aclnn_input_offset = 0.2
         layer.weight = torch.randn(128, 256)
         layer.deq_scale = 0.3
 
-        mock_forward_context = MagicMock()
-        mock_get_forward_context.return_value = mock_forward_context
-        mock_weight_prefetch_method = MagicMock()
-        mock_forward_context.weight_prefetch_method = mock_weight_prefetch_method
+        mock_get_weight_prefetch_method.return_value = MagicMock()
 
         x = torch.randn(32, 128)
         bias = torch.randn(256)
diff --git a/vllm_ascend/ascend_forward_context.py b/vllm_ascend/ascend_forward_context.py
index f8b9d1cd..e12b45fa 100644
--- a/vllm_ascend/ascend_forward_context.py
+++ b/vllm_ascend/ascend_forward_context.py
@@ -1,7 +1,7 @@
 import math
 from contextlib import contextmanager
 from enum import Enum
-from typing import TYPE_CHECKING, Any, Optional
+from typing import Any, Optional
 
 import torch
 from vllm.config import CUDAGraphMode, VllmConfig
@@ -16,11 +16,6 @@ from vllm_ascend.utils import (AscendDeviceType, enable_sp, flashcomm2_enable,
                                get_ascend_device_type, has_layer_idx,
                                is_moe_model)
 
-if TYPE_CHECKING:
-    from vllm_ascend.ops.weight_prefetch import WeightPrefetchMethod
-else:
-    WeightPrefetchMethod = None
-
 
 class MoECommType(Enum):
     ALLGATHER = 0
@@ -41,9 +36,7 @@ def set_ascend_forward_context(
         num_actual_tokens: Optional[int] = None,
         aclgraph_runtime_mode: CUDAGraphMode = CUDAGraphMode.NONE,
         batch_descriptor: Optional[BatchDescriptor] = None,
-        prefetch_stream: torch.npu.Stream = None,
         model_instance: torch.nn.Module = None,
-        weight_prefetch_method: Optional[WeightPrefetchMethod] = None,
         is_mtp_model=False):
     """A context manager that stores the current forward context,
     can be attention metadata, etc.
@@ -116,13 +109,10 @@ def set_ascend_forward_context(
                 forward_context.layer_idx is not None and \
                 num_tokens is not None and num_tokens < 500
             if prefetch_mlp_enabled:
-                forward_context.prefetch_stream = prefetch_stream
-                forward_context.model_instance = model_instance
                 forward_context.prefetch_mlp_gate_up_proj = False
                 forward_context.prefetch_mlp_down_proj = False
             forward_context.prefetch_mlp_enabled = prefetch_mlp_enabled
         forward_context.model_instance = model_instance
-        forward_context.weight_prefetch_method = weight_prefetch_method
         forward_context.is_mtp_model = is_mtp_model
 
         if num_tokens is None and attn_metadata is not None:
diff --git a/vllm_ascend/ops/fused_moe/experts_selector.py b/vllm_ascend/ops/fused_moe/experts_selector.py
index 05ec0e38..51e0cb9f 100644
--- a/vllm_ascend/ops/fused_moe/experts_selector.py
+++ b/vllm_ascend/ops/fused_moe/experts_selector.py
@@ -18,7 +18,8 @@ from typing import Callable, Optional
 
 import torch
 import torch_npu
-from vllm.forward_context import get_forward_context
+
+from vllm_ascend.utils import get_weight_prefetch_method
 
 
 def select_experts(hidden_states: torch.Tensor,
@@ -56,7 +57,7 @@ def select_experts(hidden_states: torch.Tensor,
         topk_ids: selected expert IDs of shape (num_tokens, top_k).
     """
     # prefetch w1_w3_proj.weight preprocess
-    weight_prefetch_method = get_forward_context().weight_prefetch_method
+    weight_prefetch_method = get_weight_prefetch_method()
     if weight_prefetch_method:
         weight_prefetch_method.maybe_prefetch_moe_weight_preprocess(
             hidden_states, "gate_up")
diff --git a/vllm_ascend/ops/fused_moe/moe_mlp.py b/vllm_ascend/ops/fused_moe/moe_mlp.py
index 5893168a..d102a1d5 100644
--- a/vllm_ascend/ops/fused_moe/moe_mlp.py
+++ b/vllm_ascend/ops/fused_moe/moe_mlp.py
@@ -24,7 +24,8 @@ from vllm.triton_utils import HAS_TRITON
 
 from vllm_ascend.ascend_forward_context import MoECommType
 from vllm_ascend.utils import (AscendDeviceType, dispose_tensor,
-                               enable_custom_op, get_ascend_device_type)
+                               enable_custom_op, get_ascend_device_type,
+                               get_weight_prefetch_method)
 
 
 def _custom_gmm_swiglu_enabled(fusion, dynamic_eplb):
@@ -100,7 +101,7 @@ def quant_apply_mlp(hidden_states: torch.Tensor,
         bias1, bias2 = None, None
         _output_dtype = w2_scale[0].dtype
 
-    weight_prefetch_method = get_forward_context().weight_prefetch_method
+    weight_prefetch_method = get_weight_prefetch_method()
     if weight_prefetch_method:
         weight_prefetch_method.maybe_prefetch_moe_weight_postprocess(
             hidden_states)
diff --git a/vllm_ascend/ops/register_custom_ops.py b/vllm_ascend/ops/register_custom_ops.py
index 6874687f..8403f438 100644
--- a/vllm_ascend/ops/register_custom_ops.py
+++ b/vllm_ascend/ops/register_custom_ops.py
@@ -119,16 +119,16 @@ def _maybe_prefetch_mlp_gate_up_proj_impl(x_dependency: torch.Tensor,
     if not getattr(forward_context, 'prefetch_mlp_enabled', False):
         return
     model_instance = forward_context.model_instance
-    prefetch_stream = forward_context.prefetch_stream
+    weight_prefetch_stream = prefetch_stream()
     layer_idx = int(prefix.split('.')[2])
 
     # start point of gate_up_proj weight prefetch
     if prefix.split('.')[-2] == "self_attn":
         forward_context.prefetch_mlp_gate_up_proj = True
     if forward_context.prefetch_mlp_gate_up_proj:
-        prefetch_stream.wait_stream(torch.npu.current_stream())
+        weight_prefetch_stream.wait_stream(torch.npu.current_stream())
 
-        with torch.npu.stream(prefetch_stream):
+        with torch.npu.stream(weight_prefetch_stream):
             mlp_gate_up_prefetch_size = envs_ascend.VLLM_ASCEND_MLP_GATE_UP_PREFETCH_SIZE
             torch_npu.npu_prefetch(
                 model_instance.model.layers[layer_idx].mlp.gate_up_proj.weight,
@@ -178,13 +178,13 @@ def _maybe_prefetch_mlp_down_proj_impl(x_dependency: torch.Tensor) -> None:
         return
     forward_context.prefetch_mlp_down_proj = True
     model_instance = forward_context.model_instance
-    prefetch_stream = forward_context.prefetch_stream
+    weight_prefetch_stream = prefetch_stream()
     layer_idx = forward_context.layer_idx
 
     # start point of down_proj weight prefetch
-    prefetch_stream.wait_stream(torch.npu.current_stream())
+    weight_prefetch_stream.wait_stream(torch.npu.current_stream())
 
-    with torch.npu.stream(prefetch_stream):
+    with torch.npu.stream(weight_prefetch_stream):
         mlp_down_prefetch_size = envs_ascend.VLLM_ASCEND_MLP_DOWN_PREFETCH_SIZE
         torch_npu.npu_prefetch(
             model_instance.model.layers[layer_idx].mlp.down_proj.weight,
@@ -208,9 +208,9 @@ def _maybe_wait_prefetch_done_impl(x: torch.Tensor) -> None:
         return
     if forward_context.prefetch_mlp_gate_up_proj or \
             forward_context.prefetch_mlp_down_proj:
-        prefetch_stream = forward_context.prefetch_stream
+        weight_prefetch_stream = prefetch_stream()
         # wait until prefetch done
-        torch.npu.current_stream().wait_stream(prefetch_stream)
+        torch.npu.current_stream().wait_stream(weight_prefetch_stream)
         forward_context.prefetch_mlp_gate_up_proj = False
         forward_context.prefetch_mlp_down_proj = False
     return
diff --git a/vllm_ascend/quantization/w8a8.py b/vllm_ascend/quantization/w8a8.py
index 30846a3c..8809682e 100644
--- a/vllm_ascend/quantization/w8a8.py
+++ b/vllm_ascend/quantization/w8a8.py
@@ -19,10 +19,10 @@ from typing import Any, Dict, Optional
 
 import torch
 import torch_npu
-from vllm.forward_context import get_forward_context
 
 from vllm_ascend.utils import (COMPRESSED_TENSORS_METHOD, AscendDeviceType,
-                               get_ascend_device_type, maybe_trans_nz)
+                               get_ascend_device_type,
+                               get_weight_prefetch_method, maybe_trans_nz)
 
 
 def quant_per_tensor(in_tensor: torch.Tensor,
@@ -98,12 +98,7 @@ class AscendW8A8LinearMethod:
     ) -> torch.Tensor:
         if x.dtype != torch.int8:
             layer_cls_name = layer.__class__.__name__
-            try:
-                weight_prefetch_method = get_forward_context(
-                ).weight_prefetch_method
-            except AssertionError:
-                weight_prefetch_method = None
-
+            weight_prefetch_method = get_weight_prefetch_method()
             # prefetch qkvo_proj.weight preprocess
             if weight_prefetch_method:
                 weight_prefetch_method.maybe_prefetch_attn_weight_preprocess(
diff --git a/vllm_ascend/utils.py b/vllm_ascend/utils.py
index 4ef58970..e7e0484c 100644
--- a/vllm_ascend/utils.py
+++ b/vllm_ascend/utils.py
@@ -34,7 +34,7 @@
 from vllm.logger import logger
 from vllm.sequence import IntermediateTensors
 
 import vllm_ascend.envs as envs_ascend
-from vllm_ascend.ascend_config import get_ascend_config
+from vllm_ascend.ascend_config import WeightPrefetchConfig, get_ascend_config
 
 if TYPE_CHECKING:
     from vllm.config import VllmConfig
@@ -52,6 +52,7 @@ ACL_FORMAT_FRACTAL_NZ = 29
 _CUSTOM_OP_ENABLED = None
 _CURRENT_STREAM = None
 _PREFETCH_STREAM = None
+_WEIGHT_PREFETCH_METHOD = None
 _GLOBAL_STREAM = None
 _SHARED_EXPERTS_CALCULATION_STREAM = None
 _ASCEND_CUSTOMOP_IS_REIGISTERED = False
@@ -309,6 +310,18 @@ def prefetch_stream() -> torch.npu.Stream:
     return _PREFETCH_STREAM
 
 
+def set_weight_prefetch_method(weight_prefetch_config: WeightPrefetchConfig):
+    global _WEIGHT_PREFETCH_METHOD
+    if _WEIGHT_PREFETCH_METHOD is None:
+        from vllm_ascend.ops.weight_prefetch import WeightPrefetchMethod
+        _WEIGHT_PREFETCH_METHOD = WeightPrefetchMethod(weight_prefetch_config)
+    return _WEIGHT_PREFETCH_METHOD
+
+
+def get_weight_prefetch_method():
+    return _WEIGHT_PREFETCH_METHOD
+
+
 def global_stream() -> torch.npu.Stream:
     global _GLOBAL_STREAM
     if _GLOBAL_STREAM is None:
diff --git a/vllm_ascend/worker/model_runner_v1.py b/vllm_ascend/worker/model_runner_v1.py
index 82fc6861..d24584e7 100644
--- a/vllm_ascend/worker/model_runner_v1.py
+++ b/vllm_ascend/worker/model_runner_v1.py
@@ -82,7 +82,6 @@ from vllm.v1.worker.gpu_model_runner import (AsyncGPUModelRunnerOutput,
 from vllm.v1.worker.kv_connector_model_runner_mixin import KVConnectorOutput
 from vllm.v1.worker.utils import AttentionGroup
 
-import vllm_ascend.envs as envs_ascend
 from vllm_ascend.ascend_config import get_ascend_config
 from vllm_ascend.attention.attention_mask import AttentionMaskBuilder
 from vllm_ascend.attention.attention_v1 import AscendAttentionState
@@ -106,7 +105,6 @@ from vllm_ascend.eplb.core.eplb_worker import EplbProcess
 from vllm_ascend.eplb.eplb_updator import EplbUpdator
 from vllm_ascend.eplb.utils import model_register
 from vllm_ascend.ops.rotary_embedding import set_cos_and_sin, update_cos_sin
-from vllm_ascend.ops.weight_prefetch import WeightPrefetchMethod
 from vllm_ascend.patch.worker.patch_module import patch_torch_npu_argsort
 from vllm_ascend.sample.logits_processor import build_logitsprocs
 from vllm_ascend.sample.sampler import AscendSampler
@@ -115,7 +113,8 @@ from vllm_ascend.spec_decode.eagle_proposer import EagleProposer
 from vllm_ascend.spec_decode.mtp_proposer import MtpProposer
 from vllm_ascend.utils import (AscendDeviceType, ProfileExecuteDuration,
                                enable_sp, get_ascend_device_type, is_moe_model,
-                               lmhead_tp_enable, maybe_trans_nz)
+                               lmhead_tp_enable, maybe_trans_nz,
+                               set_weight_prefetch_method)
 from vllm_ascend.worker.npu_input_batch import NPUInputBatch
 
 from vllm_ascend.ascend_forward_context import (  # isort: skip
@@ -209,18 +208,13 @@ class NPUModelRunner(GPUModelRunner):
         self.pcp_rank = 0
         if self.pcp_size > 1:
             self.model_config.max_model_len += 2 * self.pcp_size * self.max_num_reqs
-        if envs_ascend.VLLM_ASCEND_ENABLE_PREFETCH_MLP:
-            self.prefetch_stream = torch.npu.Stream(device=device)
-        else:
-            self.prefetch_stream = None
         self.sampler = AscendSampler()
         self.attn_mask = None
         self.attn_state = None
 
         # Ascend-specific configurations
         self.ascend_config = get_ascend_config()
-        self.weight_prefetch_method = WeightPrefetchMethod(
-            self.ascend_config.weight_prefetch_config)
+        set_weight_prefetch_method(self.ascend_config.weight_prefetch_config)
         # Dump / PrecisionDebugger configuration now comes from AscendConfig
         dump_cfg = self.ascend_config.dump_config
         self.dump_enable = dump_cfg.enable_dump
@@ -1420,9 +1414,7 @@ class NPUModelRunner(GPUModelRunner):
                 batch_descriptor=batch_descriptor,
                 num_actual_tokens=scheduler_output.
                 total_num_scheduled_tokens,
-                prefetch_stream=self.prefetch_stream,
-                model_instance=self.model,
-                weight_prefetch_method=self.weight_prefetch_method):
+                model_instance=self.model):
             self.maybe_setup_kv_connector(scheduler_output)
 
             hidden_states = self._generate_process_reqs_hidden_states(
@@ -2133,9 +2125,7 @@ class NPUModelRunner(GPUModelRunner):
                 num_actual_tokens=0,
                 aclgraph_runtime_mode=aclgraph_runtime_mode,
                 batch_descriptor=batch_descriptor,
-                prefetch_stream=self.prefetch_stream,
-                model_instance=self.model,
-                weight_prefetch_method=self.weight_prefetch_method):
+                model_instance=self.model):
             hidden_states = self._generate_dummy_run_hidden_states(
                 input_ids, positions, num_tokens_padded,
                 intermediate_tensors, inputs_embeds)
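
Note: the core of this refactor is the pair of module-level accessors added to `vllm_ascend/utils.py`. The model runner registers a single `WeightPrefetchMethod` once at init via `set_weight_prefetch_method()`, and call sites fetch it through `get_weight_prefetch_method()` instead of reading it off the per-step forward context. A minimal, self-contained sketch of that lazy-singleton pattern follows; the `WeightPrefetchConfig` and `WeightPrefetchMethod` classes here are simplified stand-ins for the real vllm_ascend types, not their actual implementations.

```python
from typing import Optional


class WeightPrefetchConfig:
    """Stand-in for vllm_ascend.ascend_config.WeightPrefetchConfig."""

    def __init__(self, enabled: bool = True) -> None:
        self.enabled = enabled


class WeightPrefetchMethod:
    """Stand-in: the real class drives torch_npu weight prefetching."""

    def __init__(self, config: WeightPrefetchConfig) -> None:
        self.config = config


_WEIGHT_PREFETCH_METHOD: Optional[WeightPrefetchMethod] = None


def set_weight_prefetch_method(
        config: WeightPrefetchConfig) -> WeightPrefetchMethod:
    # Constructed once (model-runner init); repeated calls reuse the instance.
    global _WEIGHT_PREFETCH_METHOD
    if _WEIGHT_PREFETCH_METHOD is None:
        _WEIGHT_PREFETCH_METHOD = WeightPrefetchMethod(config)
    return _WEIGHT_PREFETCH_METHOD


def get_weight_prefetch_method() -> Optional[WeightPrefetchMethod]:
    # May return None: nothing is registered until the runner calls the
    # setter, which is why every patched call site guards with
    # `if weight_prefetch_method:` before prefetching.
    return _WEIGHT_PREFETCH_METHOD


if __name__ == "__main__":
    set_weight_prefetch_method(WeightPrefetchConfig())
    method = get_weight_prefetch_method()
    if method is not None:
        print("prefetch method registered, enabled =", method.config.enabled)
```

A side effect of this design, visible in the w8a8 hunk above, is that callers no longer need the try/except around `get_forward_context()`: the getter is safe to call outside a forward pass and simply returns None.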
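The `prefetch_stream()` helper that replaces `forward_context.prefetch_stream` follows the same lazy-singleton idea, and the rewritten custom ops use it in a standard side-stream pattern: the prefetch stream first waits on the compute stream, the prefetch is issued on the side stream so it overlaps with compute, and a later op makes the compute stream wait before the weights are consumed. Below is a rough sketch of that ordering, using torch.cuda streams as a stand-in for torch.npu (the Stream/wait_stream/stream APIs have the same shape) and a placeholder in-place op where `torch_npu.npu_prefetch` would go; it needs a CUDA device to actually run.

```python
import torch

_PREFETCH_STREAM = None


def prefetch_stream() -> "torch.cuda.Stream":
    # Lazily create one long-lived side stream per process.
    global _PREFETCH_STREAM
    if _PREFETCH_STREAM is None:
        _PREFETCH_STREAM = torch.cuda.Stream()
    return _PREFETCH_STREAM


def start_weight_prefetch(weight: torch.Tensor) -> None:
    side = prefetch_stream()
    # Producer side: don't touch the weight until prior compute that is
    # already queued on the main stream has been ordered before us.
    side.wait_stream(torch.cuda.current_stream())
    with torch.cuda.stream(side):
        # Placeholder for torch_npu.npu_prefetch(weight, ...): any kernel
        # issued here runs on the side stream, overlapping the main stream.
        weight.mul_(1.0)


def wait_prefetch_done() -> None:
    # Consumer side: the compute stream must not read the weights until
    # the work queued on the prefetch stream is ordered before it.
    torch.cuda.current_stream().wait_stream(prefetch_stream())
```

Fetching the stream through an accessor rather than a context attribute means the ops no longer depend on the model runner threading the stream through every `set_ascend_forward_context` call, which is exactly what lets this PR delete those two keyword arguments.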