[refactor] Remove unnecessary attributes from set_ascend_forward_context (#5204)

### What this PR does / why we need it?
Remove unnecessary attributes from `set_ascend_forward_context`:
1. `prefetch_stream`
2. `weight_prefetch_method`
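
Both attributes now come from module-level accessors in `vllm_ascend.utils` instead of being threaded through the forward context. A minimal before/after sketch of the consumer side (names taken from this diff; `hidden_states` is a stand-in tensor):

```python
import torch

from vllm_ascend.utils import get_weight_prefetch_method

hidden_states = torch.randn(1, 128)  # stand-in input

# Before: weight_prefetch_method = get_forward_context().weight_prefetch_method
# After: consumers ask vllm_ascend.utils directly, no forward context required.
weight_prefetch_method = get_weight_prefetch_method()
if weight_prefetch_method:
    weight_prefetch_method.maybe_prefetch_moe_weight_preprocess(
        hidden_states, "gate_up")
```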
### Does this PR introduce _any_ user-facing change?
No.
### How was this patch tested?

- vLLM version: v0.12.0
- vLLM main: ad32e3e19c

Signed-off-by: Wang Kunpeng <1289706727@qq.com>
Wang Kunpeng
2025-12-23 08:49:52 +08:00
committed by GitHub
parent 95e8a52156
commit c3a8d13ca7
10 changed files with 55 additions and 83 deletions

View File

@@ -286,8 +286,8 @@ def test_select_experts(
     with patch("vllm_ascend.ops.fused_moe.experts_selector._native_grouped_topk"
                ) as mock_native_grouped_topk, \
-        patch('vllm_ascend.ops.fused_moe.experts_selector.get_forward_context',
-              return_value=MagicMock(weight_prefetch_method=MagicMock())):
+        patch('vllm_ascend.ops.fused_moe.experts_selector.get_weight_prefetch_method',
+              return_value=MagicMock()):
         mock_native_grouped_topk.side_effect = lambda x, num_groups, k: torch.randn_like(
             x)
@@ -323,8 +323,8 @@ def test_select_experts(
 @pytest.mark.parametrize("device", DEVICE)
 def test_select_experts_invalid_scoring_func(device: str):
-    with patch('vllm_ascend.ops.fused_moe.experts_selector.get_forward_context',
-               return_value=MagicMock(weight_prefetch_method=MagicMock())), \
+    with patch('vllm_ascend.ops.fused_moe.experts_selector.get_weight_prefetch_method',
+               return_value=MagicMock()), \
         pytest.raises(ValueError,
                       match="Unsupported scoring function: invalid"):
         select_experts(hidden_states=torch.randn(1, 128, device=device),
@@ -336,17 +336,3 @@ def test_select_experts_invalid_scoring_func(device: str):
     gc.collect()
     torch.npu.empty_cache()
     torch.npu.reset_peak_memory_stats()
-
-
-@pytest.mark.parametrize("device", DEVICE)
-def test_select_experts_missing_group_params(device: str):
-    with pytest.raises(AssertionError):
-        select_experts(hidden_states=torch.randn(1, 128, device=device),
-                       router_logits=torch.randn(1, 64, device=device),
-                       top_k=2,
-                       use_grouped_topk=True,
-                       renormalize=False,
-                       scoring_func="softmax")
-    gc.collect()
-    torch.npu.empty_cache()
-    torch.npu.reset_peak_memory_stats()
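
One consequence for the tests above: `experts_selector` imports `get_weight_prefetch_method` by name, so the patch target must be the reference inside the importing module, not `vllm_ascend.utils` itself. A minimal sketch of the pattern:

```python
from unittest.mock import MagicMock, patch

# Patch where the name is *looked up* (the importing module), not where it
# is defined; patching vllm_ascend.utils would not affect the copy of the
# name already bound inside experts_selector.
with patch(
        'vllm_ascend.ops.fused_moe.experts_selector.get_weight_prefetch_method',
        return_value=MagicMock()):
    pass  # code under test now sees the MagicMock instead of the real singleton
```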

View File

@@ -92,15 +92,14 @@ def mock_dist_env(mocker: MockerFixture):
     mock_moe_comm_method.finalize.side_effect = mock_finalize

     dp_metadata = MagicMock(num_tokens_across_dp_cpu=[5, 5])
     mock_weight_prefetch_method = MagicMock()
-    mock_forward_context_obj = MagicMock(
-        moe_comm_method=mock_moe_comm_method,
-        moe_comm_type=MoECommType.MC2,
-        max_tokens_across_dp=10,
-        dp_metadata=dp_metadata,
-        mc2_mask=torch.zeros(16, dtype=torch.bool),
-        padded_num_tokens=16,
-        with_quant=False,
-        weight_prefetch_method=mock_weight_prefetch_method)
+    mock_forward_context_obj = MagicMock(moe_comm_method=mock_moe_comm_method,
+                                         moe_comm_type=MoECommType.MC2,
+                                         max_tokens_across_dp=10,
+                                         dp_metadata=dp_metadata,
+                                         mc2_mask=torch.zeros(
+                                             16, dtype=torch.bool),
+                                         padded_num_tokens=16,
+                                         with_quant=False)
     with patch('torch.distributed.get_rank', return_value=0), \
         patch('torch.distributed.get_world_size', return_value=4), \
@@ -133,8 +132,8 @@ def mock_dist_env(mocker: MockerFixture):
                return_value=None), \
         patch('vllm_ascend.ops.fused_moe.moe_comm_method.AllGatherCommImpl._get_token_dispatcher',
               return_value=None), \
-        patch('vllm_ascend.ops.fused_moe.experts_selector.get_forward_context',
-              return_value=mock_forward_context_obj):
+        patch('vllm_ascend.ops.fused_moe.experts_selector.get_weight_prefetch_method',
+              return_value=mock_weight_prefetch_method):

         yield {
             'mock_forward_context_obj': mock_forward_context_obj,
@@ -590,4 +589,4 @@ class TestUnifiedApplyMLP(TestBase):
         self.assertTrue(mock_forward_context.with_quant)

         self.assertEqual(result.shape, hidden_states_shape)
-        self.assertEqual(result.dtype, torch.bfloat16)
\ No newline at end of file
+        self.assertEqual(result.dtype, torch.bfloat16)

View File

@@ -63,21 +63,18 @@ class TestAscendW8A8LinearMethod(TestBase):
         self.assertEqual(params['weight_scale'].shape, (10, 1))
         self.assertEqual(params['weight_offset'].shape, (10, 1))

-    @patch("vllm_ascend.quantization.w8a8.get_forward_context")
+    @patch("vllm_ascend.quantization.w8a8.get_weight_prefetch_method")
     @patch("torch.ops.vllm.quantize")
     @patch("torch_npu.npu_quant_matmul")
     def test_apply_with_x_not_int8(self, mock_npu_quant_matmul, mock_quantize,
-                                   mock_get_forward_context):
+                                   mock_get_weight_prefetch_method):
         layer = MagicMock()
         layer.aclnn_input_scale = 0.1
         layer.aclnn_input_offset = 0.2
         layer.weight = torch.randn(128, 256)
         layer.deq_scale = 0.3

-        mock_forward_context = MagicMock()
-        mock_get_forward_context.return_value = mock_forward_context
-        mock_weight_prefetch_method = MagicMock()
-        mock_forward_context.weight_prefetch_method = mock_weight_prefetch_method
+        mock_get_weight_prefetch_method.return_value = MagicMock()

         x = torch.randn(32, 128)
         bias = torch.randn(256)

View File

@@ -1,7 +1,7 @@
 import math
 from contextlib import contextmanager
 from enum import Enum
-from typing import TYPE_CHECKING, Any, Optional
+from typing import Any, Optional

 import torch
 from vllm.config import CUDAGraphMode, VllmConfig
@@ -16,11 +16,6 @@ from vllm_ascend.utils import (AscendDeviceType, enable_sp, flashcomm2_enable,
                                get_ascend_device_type, has_layer_idx,
                                is_moe_model)

-if TYPE_CHECKING:
-    from vllm_ascend.ops.weight_prefetch import WeightPrefetchMethod
-else:
-    WeightPrefetchMethod = None
-

 class MoECommType(Enum):
     ALLGATHER = 0
@@ -41,9 +36,7 @@ def set_ascend_forward_context(
         num_actual_tokens: Optional[int] = None,
         aclgraph_runtime_mode: CUDAGraphMode = CUDAGraphMode.NONE,
         batch_descriptor: Optional[BatchDescriptor] = None,
-        prefetch_stream: torch.npu.Stream = None,
         model_instance: torch.nn.Module = None,
-        weight_prefetch_method: Optional[WeightPrefetchMethod] = None,
         is_mtp_model=False):
     """A context manager that stores the current forward context,
     can be attention metadata, etc.
@@ -116,13 +109,10 @@ def set_ascend_forward_context(
             forward_context.layer_idx is not None and \
             num_tokens is not None and num_tokens < 500

         if prefetch_mlp_enabled:
-            forward_context.prefetch_stream = prefetch_stream
-            forward_context.model_instance = model_instance
             forward_context.prefetch_mlp_gate_up_proj = False
             forward_context.prefetch_mlp_down_proj = False
         forward_context.prefetch_mlp_enabled = prefetch_mlp_enabled
         forward_context.model_instance = model_instance
-        forward_context.weight_prefetch_method = weight_prefetch_method
         forward_context.is_mtp_model = is_mtp_model

         if num_tokens is None and attn_metadata is not None:
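
With the two parameters gone, a call site reduces to something like the following hedged sketch, based on the model-runner hunks later in this commit; the `attn_metadata` and `vllm_config` positional arguments and the surrounding variables are assumptions:

```python
with set_ascend_forward_context(
        attn_metadata,  # assumed positional argument
        vllm_config,    # assumed positional argument
        num_actual_tokens=scheduler_output.total_num_scheduled_tokens,
        batch_descriptor=batch_descriptor,
        model_instance=self.model):
    hidden_states = self._generate_process_reqs_hidden_states(
        scheduler_output)
```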

View File

@@ -18,7 +18,8 @@ from typing import Callable, Optional
 import torch
 import torch_npu

 from vllm.forward_context import get_forward_context
+from vllm_ascend.utils import get_weight_prefetch_method


 def select_experts(hidden_states: torch.Tensor,
@@ -56,7 +57,7 @@ def select_experts(hidden_states: torch.Tensor,
         topk_ids: selected expert IDs of shape (num_tokens, top_k).
     """
     # prefetch w1_w3_proj.weight preprocess
-    weight_prefetch_method = get_forward_context().weight_prefetch_method
+    weight_prefetch_method = get_weight_prefetch_method()
     if weight_prefetch_method:
         weight_prefetch_method.maybe_prefetch_moe_weight_preprocess(
             hidden_states, "gate_up")

View File

@@ -24,7 +24,8 @@ from vllm.triton_utils import HAS_TRITON

 from vllm_ascend.ascend_forward_context import MoECommType
 from vllm_ascend.utils import (AscendDeviceType, dispose_tensor,
-                               enable_custom_op, get_ascend_device_type)
+                               enable_custom_op, get_ascend_device_type,
+                               get_weight_prefetch_method)


 def _custom_gmm_swiglu_enabled(fusion, dynamic_eplb):
@@ -100,7 +101,7 @@ def quant_apply_mlp(hidden_states: torch.Tensor,
         bias1, bias2 = None, None
         _output_dtype = w2_scale[0].dtype

-    weight_prefetch_method = get_forward_context().weight_prefetch_method
+    weight_prefetch_method = get_weight_prefetch_method()
     if weight_prefetch_method:
         weight_prefetch_method.maybe_prefetch_moe_weight_postprocess(
             hidden_states)

View File

@@ -119,16 +119,16 @@ def _maybe_prefetch_mlp_gate_up_proj_impl(x_dependency: torch.Tensor,
     if not getattr(forward_context, 'prefetch_mlp_enabled', False):
         return

     model_instance = forward_context.model_instance
-    prefetch_stream = forward_context.prefetch_stream
+    weight_prefetch_stream = prefetch_stream()
     layer_idx = int(prefix.split('.')[2])

     # start point of gate_up_proj weight prefetch
     if prefix.split('.')[-2] == "self_attn":
         forward_context.prefetch_mlp_gate_up_proj = True
     if forward_context.prefetch_mlp_gate_up_proj:
-        prefetch_stream.wait_stream(torch.npu.current_stream())
-        with torch.npu.stream(prefetch_stream):
+        weight_prefetch_stream.wait_stream(torch.npu.current_stream())
+        with torch.npu.stream(weight_prefetch_stream):
             mlp_gate_up_prefetch_size = envs_ascend.VLLM_ASCEND_MLP_GATE_UP_PREFETCH_SIZE
             torch_npu.npu_prefetch(
                 model_instance.model.layers[layer_idx].mlp.gate_up_proj.weight,
@@ -178,13 +178,13 @@ def _maybe_prefetch_mlp_down_proj_impl(x_dependency: torch.Tensor) -> None:
         return

     forward_context.prefetch_mlp_down_proj = True
     model_instance = forward_context.model_instance
-    prefetch_stream = forward_context.prefetch_stream
+    weight_prefetch_stream = prefetch_stream()
     layer_idx = forward_context.layer_idx

     # start point of down_proj weight prefetch
-    prefetch_stream.wait_stream(torch.npu.current_stream())
-    with torch.npu.stream(prefetch_stream):
+    weight_prefetch_stream.wait_stream(torch.npu.current_stream())
+    with torch.npu.stream(weight_prefetch_stream):
         mlp_down_prefetch_size = envs_ascend.VLLM_ASCEND_MLP_DOWN_PREFETCH_SIZE
         torch_npu.npu_prefetch(
             model_instance.model.layers[layer_idx].mlp.down_proj.weight,
@@ -208,9 +208,9 @@ def _maybe_wait_prefetch_done_impl(x: torch.Tensor) -> None:
         return
     if forward_context.prefetch_mlp_gate_up_proj or \
             forward_context.prefetch_mlp_down_proj:
-        prefetch_stream = forward_context.prefetch_stream
+        weight_prefetch_stream = prefetch_stream()
         # wait until prefetch done
-        torch.npu.current_stream().wait_stream(prefetch_stream)
+        torch.npu.current_stream().wait_stream(weight_prefetch_stream)
         forward_context.prefetch_mlp_gate_up_proj = False
         forward_context.prefetch_mlp_down_proj = False
     return
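
The synchronization pattern itself is unchanged; only the stream's source moves from the forward context to the `prefetch_stream()` helper in `vllm_ascend.utils`. A condensed sketch, assuming `torch.npu` mirrors the `torch.cuda` stream API and that `torch_npu.npu_prefetch(tensor, dependency, max_size)` is the signature used above:

```python
import torch
import torch_npu

from vllm_ascend.utils import prefetch_stream


def prefetch_on_side_stream(weight: torch.Tensor, dependency: torch.Tensor,
                            max_size: int) -> None:
    side = prefetch_stream()
    # The side stream must not run ahead of work already queued on the
    # current stream, or it could prefetch before inputs are ready.
    side.wait_stream(torch.npu.current_stream())
    with torch.npu.stream(side):
        torch_npu.npu_prefetch(weight, dependency, max_size)


def wait_prefetch_done() -> None:
    # Before the weight is consumed, the compute stream waits for the
    # side stream, mirroring _maybe_wait_prefetch_done_impl above.
    torch.npu.current_stream().wait_stream(prefetch_stream())
```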

View File

@@ -19,10 +19,10 @@ from typing import Any, Dict, Optional
 import torch
 import torch_npu

 from vllm.forward_context import get_forward_context
 from vllm_ascend.utils import (COMPRESSED_TENSORS_METHOD, AscendDeviceType,
-                               get_ascend_device_type, maybe_trans_nz)
+                               get_ascend_device_type,
+                               get_weight_prefetch_method, maybe_trans_nz)


 def quant_per_tensor(in_tensor: torch.Tensor,
@@ -98,12 +98,7 @@ class AscendW8A8LinearMethod:
     ) -> torch.Tensor:
         if x.dtype != torch.int8:
             layer_cls_name = layer.__class__.__name__
-            try:
-                weight_prefetch_method = get_forward_context(
-                ).weight_prefetch_method
-            except AssertionError:
-                weight_prefetch_method = None
+            weight_prefetch_method = get_weight_prefetch_method()
             # prefetch qkvo_proj.weight preprocess
             if weight_prefetch_method:
                 weight_prefetch_method.maybe_prefetch_attn_weight_preprocess(
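
The deleted `try/except` existed only because vLLM's `get_forward_context()` asserts that a context has been set; the new accessor simply returns `None` until the runner registers one, so a plain truthiness check suffices. A sketch of the resulting control flow (the helper name `maybe_prefetch` is hypothetical):

```python
from vllm_ascend.utils import get_weight_prefetch_method


def maybe_prefetch(layer, x):
    # Never raises: returns None until set_weight_prefetch_method() has
    # been called by the model runner (e.g. in unit tests or warm-up).
    method = get_weight_prefetch_method()
    if method is None:
        return
    # The real code calls method.maybe_prefetch_attn_weight_preprocess(...)
    # here; its argument list is elided in this sketch.
```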

View File

@@ -34,7 +34,7 @@ from vllm.logger import logger
 from vllm.sequence import IntermediateTensors

 import vllm_ascend.envs as envs_ascend
-from vllm_ascend.ascend_config import get_ascend_config
+from vllm_ascend.ascend_config import WeightPrefetchConfig, get_ascend_config

 if TYPE_CHECKING:
     from vllm.config import VllmConfig
@@ -52,6 +52,7 @@ ACL_FORMAT_FRACTAL_NZ = 29

 _CUSTOM_OP_ENABLED = None
 _CURRENT_STREAM = None
 _PREFETCH_STREAM = None
+_WEIGHT_PREFETCH_METHOD = None
 _GLOBAL_STREAM = None
 _SHARED_EXPERTS_CALCULATION_STREAM = None
 _ASCEND_CUSTOMOP_IS_REIGISTERED = False
@@ -309,6 +310,18 @@ def prefetch_stream() -> torch.npu.Stream:
     return _PREFETCH_STREAM


+def set_weight_prefetch_method(weight_prefetch_config: WeightPrefetchConfig):
+    global _WEIGHT_PREFETCH_METHOD
+    if _WEIGHT_PREFETCH_METHOD is None:
+        from vllm_ascend.ops.weight_prefetch import WeightPrefetchMethod
+        _WEIGHT_PREFETCH_METHOD = WeightPrefetchMethod(weight_prefetch_config)
+    return _WEIGHT_PREFETCH_METHOD
+
+
+def get_weight_prefetch_method():
+    return _WEIGHT_PREFETCH_METHOD
+
+
 def global_stream() -> torch.npu.Stream:
     global _GLOBAL_STREAM
     if _GLOBAL_STREAM is None:
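
The setter is idempotent: the first call constructs the singleton, later calls return the existing instance, and the getter never raises. A short usage sketch tying both ends together (`ascend_config` and `hidden_states` are assumed to exist in the caller's scope):

```python
from vllm_ascend.utils import (get_weight_prefetch_method,
                               set_weight_prefetch_method)

# Model runner, once at initialisation:
set_weight_prefetch_method(ascend_config.weight_prefetch_config)

# Any op, at forward time, with no forward context in scope:
method = get_weight_prefetch_method()
if method:
    method.maybe_prefetch_moe_weight_preprocess(hidden_states, "gate_up")
```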

View File

@@ -82,7 +82,6 @@ from vllm.v1.worker.gpu_model_runner import (AsyncGPUModelRunnerOutput,
 from vllm.v1.worker.kv_connector_model_runner_mixin import KVConnectorOutput
 from vllm.v1.worker.utils import AttentionGroup

 import vllm_ascend.envs as envs_ascend
-from vllm_ascend.ascend_config import get_ascend_config
 from vllm_ascend.attention.attention_mask import AttentionMaskBuilder
 from vllm_ascend.attention.attention_v1 import AscendAttentionState
@@ -106,7 +105,6 @@ from vllm_ascend.eplb.core.eplb_worker import EplbProcess
 from vllm_ascend.eplb.eplb_updator import EplbUpdator
 from vllm_ascend.eplb.utils import model_register
 from vllm_ascend.ops.rotary_embedding import set_cos_and_sin, update_cos_sin
-from vllm_ascend.ops.weight_prefetch import WeightPrefetchMethod
 from vllm_ascend.patch.worker.patch_module import patch_torch_npu_argsort
 from vllm_ascend.sample.logits_processor import build_logitsprocs
 from vllm_ascend.sample.sampler import AscendSampler
@@ -115,7 +113,8 @@ from vllm_ascend.spec_decode.eagle_proposer import EagleProposer
 from vllm_ascend.spec_decode.mtp_proposer import MtpProposer
 from vllm_ascend.utils import (AscendDeviceType, ProfileExecuteDuration,
                                enable_sp, get_ascend_device_type, is_moe_model,
-                               lmhead_tp_enable, maybe_trans_nz)
+                               lmhead_tp_enable, maybe_trans_nz,
+                               set_weight_prefetch_method)
 from vllm_ascend.worker.npu_input_batch import NPUInputBatch

 from vllm_ascend.ascend_forward_context import ( # isort: skip
@@ -209,18 +208,13 @@ class NPUModelRunner(GPUModelRunner):
         self.pcp_rank = 0
         if self.pcp_size > 1:
             self.model_config.max_model_len += 2 * self.pcp_size * self.max_num_reqs
-        if envs_ascend.VLLM_ASCEND_ENABLE_PREFETCH_MLP:
-            self.prefetch_stream = torch.npu.Stream(device=device)
-        else:
-            self.prefetch_stream = None
         self.sampler = AscendSampler()
         self.attn_mask = None
         self.attn_state = None

         # Ascend-specific configurations
         self.ascend_config = get_ascend_config()
-        self.weight_prefetch_method = WeightPrefetchMethod(
-            self.ascend_config.weight_prefetch_config)
+        set_weight_prefetch_method(self.ascend_config.weight_prefetch_config)
         # Dump / PrecisionDebugger configuration now comes from AscendConfig
         dump_cfg = self.ascend_config.dump_config
         self.dump_enable = dump_cfg.enable_dump
@@ -1420,9 +1414,7 @@ class NPUModelRunner(GPUModelRunner):
                 batch_descriptor=batch_descriptor,
                 num_actual_tokens=scheduler_output.
                 total_num_scheduled_tokens,
-                prefetch_stream=self.prefetch_stream,
-                model_instance=self.model,
-                weight_prefetch_method=self.weight_prefetch_method):
+                model_instance=self.model):
             self.maybe_setup_kv_connector(scheduler_output)

             hidden_states = self._generate_process_reqs_hidden_states(
@@ -2133,9 +2125,7 @@ class NPUModelRunner(GPUModelRunner):
                 num_actual_tokens=0,
                 aclgraph_runtime_mode=aclgraph_runtime_mode,
                 batch_descriptor=batch_descriptor,
-                prefetch_stream=self.prefetch_stream,
-                model_instance=self.model,
-                weight_prefetch_method=self.weight_prefetch_method):
+                model_instance=self.model):
             hidden_states = self._generate_dummy_run_hidden_states(
                 input_ids, positions, num_tokens_padded,
                 intermediate_tensors, inputs_embeds)