[refactor] Remove unnecessary attributes from set_ascend_forward_context (#5204)

### What this PR does / why we need it?
Remove unnecessary attributes from set_ascend_forward_context
1.prefetch_stream
2.weight_prefetch_method
### Does this PR introduce _any_ user-facing change?
no
### How was this patch tested?

- vLLM version: v0.12.0
- vLLM main:
ad32e3e19c

Signed-off-by: Wang Kunpeng <1289706727@qq.com>
This commit is contained in:
Wang Kunpeng
2025-12-23 08:49:52 +08:00
committed by GitHub
parent 95e8a52156
commit c3a8d13ca7
10 changed files with 55 additions and 83 deletions

View File

@@ -1,7 +1,7 @@
import math
from contextlib import contextmanager
from enum import Enum
from typing import TYPE_CHECKING, Any, Optional
from typing import Any, Optional
import torch
from vllm.config import CUDAGraphMode, VllmConfig
@@ -16,11 +16,6 @@ from vllm_ascend.utils import (AscendDeviceType, enable_sp, flashcomm2_enable,
get_ascend_device_type, has_layer_idx,
is_moe_model)
if TYPE_CHECKING:
from vllm_ascend.ops.weight_prefetch import WeightPrefetchMethod
else:
WeightPrefetchMethod = None
class MoECommType(Enum):
ALLGATHER = 0
@@ -41,9 +36,7 @@ def set_ascend_forward_context(
num_actual_tokens: Optional[int] = None,
aclgraph_runtime_mode: CUDAGraphMode = CUDAGraphMode.NONE,
batch_descriptor: Optional[BatchDescriptor] = None,
prefetch_stream: torch.npu.Stream = None,
model_instance: torch.nn.Module = None,
weight_prefetch_method: Optional[WeightPrefetchMethod] = None,
is_mtp_model=False):
"""A context manager that stores the current forward context,
can be attention metadata, etc.
@@ -116,13 +109,10 @@ def set_ascend_forward_context(
forward_context.layer_idx is not None and \
num_tokens is not None and num_tokens < 500
if prefetch_mlp_enabled:
forward_context.prefetch_stream = prefetch_stream
forward_context.model_instance = model_instance
forward_context.prefetch_mlp_gate_up_proj = False
forward_context.prefetch_mlp_down_proj = False
forward_context.prefetch_mlp_enabled = prefetch_mlp_enabled
forward_context.model_instance = model_instance
forward_context.weight_prefetch_method = weight_prefetch_method
forward_context.is_mtp_model = is_mtp_model
if num_tokens is None and attn_metadata is not None:

View File

@@ -18,7 +18,8 @@ from typing import Callable, Optional
import torch
import torch_npu
from vllm.forward_context import get_forward_context
from vllm_ascend.utils import get_weight_prefetch_method
def select_experts(hidden_states: torch.Tensor,
@@ -56,7 +57,7 @@ def select_experts(hidden_states: torch.Tensor,
topk_ids: selected expert IDs of shape (num_tokens, top_k).
"""
# prefetch w1_w3_proj.weight preprocess
weight_prefetch_method = get_forward_context().weight_prefetch_method
weight_prefetch_method = get_weight_prefetch_method()
if weight_prefetch_method:
weight_prefetch_method.maybe_prefetch_moe_weight_preprocess(
hidden_states, "gate_up")

View File

@@ -24,7 +24,8 @@ from vllm.triton_utils import HAS_TRITON
from vllm_ascend.ascend_forward_context import MoECommType
from vllm_ascend.utils import (AscendDeviceType, dispose_tensor,
enable_custom_op, get_ascend_device_type)
enable_custom_op, get_ascend_device_type,
get_weight_prefetch_method)
def _custom_gmm_swiglu_enabled(fusion, dynamic_eplb):
@@ -100,7 +101,7 @@ def quant_apply_mlp(hidden_states: torch.Tensor,
bias1, bias2 = None, None
_output_dtype = w2_scale[0].dtype
weight_prefetch_method = get_forward_context().weight_prefetch_method
weight_prefetch_method = get_weight_prefetch_method()
if weight_prefetch_method:
weight_prefetch_method.maybe_prefetch_moe_weight_postprocess(
hidden_states)

View File

@@ -119,16 +119,16 @@ def _maybe_prefetch_mlp_gate_up_proj_impl(x_dependency: torch.Tensor,
if not getattr(forward_context, 'prefetch_mlp_enabled', False):
return
model_instance = forward_context.model_instance
prefetch_stream = forward_context.prefetch_stream
weight_prefetch_stream = prefetch_stream()
layer_idx = int(prefix.split('.')[2])
# start point of gate_up_proj weight prefetch
if prefix.split('.')[-2] == "self_attn":
forward_context.prefetch_mlp_gate_up_proj = True
if forward_context.prefetch_mlp_gate_up_proj:
prefetch_stream.wait_stream(torch.npu.current_stream())
weight_prefetch_stream.wait_stream(torch.npu.current_stream())
with torch.npu.stream(prefetch_stream):
with torch.npu.stream(weight_prefetch_stream):
mlp_gate_up_prefetch_size = envs_ascend.VLLM_ASCEND_MLP_GATE_UP_PREFETCH_SIZE
torch_npu.npu_prefetch(
model_instance.model.layers[layer_idx].mlp.gate_up_proj.weight,
@@ -178,13 +178,13 @@ def _maybe_prefetch_mlp_down_proj_impl(x_dependency: torch.Tensor) -> None:
return
forward_context.prefetch_mlp_down_proj = True
model_instance = forward_context.model_instance
prefetch_stream = forward_context.prefetch_stream
weight_prefetch_stream = prefetch_stream()
layer_idx = forward_context.layer_idx
# start point of down_proj weight prefetch
prefetch_stream.wait_stream(torch.npu.current_stream())
weight_prefetch_stream.wait_stream(torch.npu.current_stream())
with torch.npu.stream(prefetch_stream):
with torch.npu.stream(weight_prefetch_stream):
mlp_down_prefetch_size = envs_ascend.VLLM_ASCEND_MLP_DOWN_PREFETCH_SIZE
torch_npu.npu_prefetch(
model_instance.model.layers[layer_idx].mlp.down_proj.weight,
@@ -208,9 +208,9 @@ def _maybe_wait_prefetch_done_impl(x: torch.Tensor) -> None:
return
if forward_context.prefetch_mlp_gate_up_proj or \
forward_context.prefetch_mlp_down_proj:
prefetch_stream = forward_context.prefetch_stream
weight_prefetch_stream = prefetch_stream()
# wait until prefetch done
torch.npu.current_stream().wait_stream(prefetch_stream)
torch.npu.current_stream().wait_stream(weight_prefetch_stream)
forward_context.prefetch_mlp_gate_up_proj = False
forward_context.prefetch_mlp_down_proj = False
return

View File

@@ -19,10 +19,10 @@ from typing import Any, Dict, Optional
import torch
import torch_npu
from vllm.forward_context import get_forward_context
from vllm_ascend.utils import (COMPRESSED_TENSORS_METHOD, AscendDeviceType,
get_ascend_device_type, maybe_trans_nz)
get_ascend_device_type,
get_weight_prefetch_method, maybe_trans_nz)
def quant_per_tensor(in_tensor: torch.Tensor,
@@ -98,12 +98,7 @@ class AscendW8A8LinearMethod:
) -> torch.Tensor:
if x.dtype != torch.int8:
layer_cls_name = layer.__class__.__name__
try:
weight_prefetch_method = get_forward_context(
).weight_prefetch_method
except AssertionError:
weight_prefetch_method = None
weight_prefetch_method = get_weight_prefetch_method()
# prefetch qkvo_proj.weight preprocess
if weight_prefetch_method:
weight_prefetch_method.maybe_prefetch_attn_weight_preprocess(

View File

@@ -34,7 +34,7 @@ from vllm.logger import logger
from vllm.sequence import IntermediateTensors
import vllm_ascend.envs as envs_ascend
from vllm_ascend.ascend_config import get_ascend_config
from vllm_ascend.ascend_config import WeightPrefetchConfig, get_ascend_config
if TYPE_CHECKING:
from vllm.config import VllmConfig
@@ -52,6 +52,7 @@ ACL_FORMAT_FRACTAL_NZ = 29
_CUSTOM_OP_ENABLED = None
_CURRENT_STREAM = None
_PREFETCH_STREAM = None
_WEIGHT_PREFETCH_METHOD = None
_GLOBAL_STREAM = None
_SHARED_EXPERTS_CALCULATION_STREAM = None
_ASCEND_CUSTOMOP_IS_REIGISTERED = False
@@ -309,6 +310,18 @@ def prefetch_stream() -> torch.npu.Stream:
return _PREFETCH_STREAM
def set_weight_prefetch_method(weight_prefetch_config: WeightPrefetchConfig):
global _WEIGHT_PREFETCH_METHOD
if _WEIGHT_PREFETCH_METHOD is None:
from vllm_ascend.ops.weight_prefetch import WeightPrefetchMethod
_WEIGHT_PREFETCH_METHOD = WeightPrefetchMethod(weight_prefetch_config)
return _WEIGHT_PREFETCH_METHOD
def get_weight_prefetch_method():
return _WEIGHT_PREFETCH_METHOD
def global_stream() -> torch.npu.Stream:
global _GLOBAL_STREAM
if _GLOBAL_STREAM is None:

View File

@@ -82,7 +82,6 @@ from vllm.v1.worker.gpu_model_runner import (AsyncGPUModelRunnerOutput,
from vllm.v1.worker.kv_connector_model_runner_mixin import KVConnectorOutput
from vllm.v1.worker.utils import AttentionGroup
import vllm_ascend.envs as envs_ascend
from vllm_ascend.ascend_config import get_ascend_config
from vllm_ascend.attention.attention_mask import AttentionMaskBuilder
from vllm_ascend.attention.attention_v1 import AscendAttentionState
@@ -106,7 +105,6 @@ from vllm_ascend.eplb.core.eplb_worker import EplbProcess
from vllm_ascend.eplb.eplb_updator import EplbUpdator
from vllm_ascend.eplb.utils import model_register
from vllm_ascend.ops.rotary_embedding import set_cos_and_sin, update_cos_sin
from vllm_ascend.ops.weight_prefetch import WeightPrefetchMethod
from vllm_ascend.patch.worker.patch_module import patch_torch_npu_argsort
from vllm_ascend.sample.logits_processor import build_logitsprocs
from vllm_ascend.sample.sampler import AscendSampler
@@ -115,7 +113,8 @@ from vllm_ascend.spec_decode.eagle_proposer import EagleProposer
from vllm_ascend.spec_decode.mtp_proposer import MtpProposer
from vllm_ascend.utils import (AscendDeviceType, ProfileExecuteDuration,
enable_sp, get_ascend_device_type, is_moe_model,
lmhead_tp_enable, maybe_trans_nz)
lmhead_tp_enable, maybe_trans_nz,
set_weight_prefetch_method)
from vllm_ascend.worker.npu_input_batch import NPUInputBatch
from vllm_ascend.ascend_forward_context import ( # isort: skip
@@ -209,18 +208,13 @@ class NPUModelRunner(GPUModelRunner):
self.pcp_rank = 0
if self.pcp_size > 1:
self.model_config.max_model_len += 2 * self.pcp_size * self.max_num_reqs
if envs_ascend.VLLM_ASCEND_ENABLE_PREFETCH_MLP:
self.prefetch_stream = torch.npu.Stream(device=device)
else:
self.prefetch_stream = None
self.sampler = AscendSampler()
self.attn_mask = None
self.attn_state = None
# Ascend-specific configurations
self.ascend_config = get_ascend_config()
self.weight_prefetch_method = WeightPrefetchMethod(
self.ascend_config.weight_prefetch_config)
set_weight_prefetch_method(self.ascend_config.weight_prefetch_config)
# Dump / PrecisionDebugger configuration now comes from AscendConfig
dump_cfg = self.ascend_config.dump_config
self.dump_enable = dump_cfg.enable_dump
@@ -1420,9 +1414,7 @@ class NPUModelRunner(GPUModelRunner):
batch_descriptor=batch_descriptor,
num_actual_tokens=scheduler_output.
total_num_scheduled_tokens,
prefetch_stream=self.prefetch_stream,
model_instance=self.model,
weight_prefetch_method=self.weight_prefetch_method):
model_instance=self.model):
self.maybe_setup_kv_connector(scheduler_output)
hidden_states = self._generate_process_reqs_hidden_states(
@@ -2133,9 +2125,7 @@ class NPUModelRunner(GPUModelRunner):
num_actual_tokens=0,
aclgraph_runtime_mode=aclgraph_runtime_mode,
batch_descriptor=batch_descriptor,
prefetch_stream=self.prefetch_stream,
model_instance=self.model,
weight_prefetch_method=self.weight_prefetch_method):
model_instance=self.model):
hidden_states = self._generate_dummy_run_hidden_states(
input_ids, positions, num_tokens_padded,
intermediate_tensors, inputs_embeds)