From fc2bcbe21c86f7684c80e42771b128da9fc17571 Mon Sep 17 00:00:00 2001 From: rjg-lyh <83491835+rjg-lyh@users.noreply.github.com> Date: Fri, 12 Sep 2025 16:58:08 +0800 Subject: [PATCH] [Ops] Fix bug in register_custom_ops without forward_context (#2883) ### What this PR does / why we need it? This PR fixed the bug in register_custom_ops without forward_context. We set try-except to consider this situation. ### Does this PR introduce _any_ user-facing change? No. ### How was this patch tested? CI passed with new added/existing test. - vLLM version: main - vLLM main: https://github.com/vllm-project/vllm/commit/7920de0a2af7c663efae29856f1bac061a553b58 Signed-off-by: rjg-lyh <1318825571@qq.com> --- vllm_ascend/envs.py | 10 +++-- vllm_ascend/ops/register_custom_ops.py | 62 ++++++++++++++++++++------ 2 files changed, 54 insertions(+), 18 deletions(-) diff --git a/vllm_ascend/envs.py b/vllm_ascend/envs.py index ef8e33e..5792c83 100644 --- a/vllm_ascend/envs.py +++ b/vllm_ascend/envs.py @@ -139,11 +139,13 @@ env_variables: Dict[str, Callable[[], Any]] = { "VLLM_ASCEND_ENABLE_PREFETCH_MLP": lambda: bool(int(os.getenv("VLLM_ASCEND_ENABLE_PREFETCH_MLP", '0'))), # buffer size for gate up prefetch - "MLP_GATE_UP_PREFETCH_SIZE": - lambda: int(os.getenv("MLP_GATE_UP_PREFETCH_SIZE", 18 * 1024 * 1024)), + "VLLM_ASCEND_MLP_GATE_UP_PREFETCH_SIZE": + lambda: int( + os.getenv("VLLM_ASCEND_MLP_GATE_UP_PREFETCH_SIZE", 18 * 1024 * 1024)), # buffer size for down proj prefetch - "MLP_DOWN_PREFETCH_SIZE": - lambda: int(os.getenv("MLP_DOWN_PREFETCH_SIZE", 18 * 1024 * 1024)), + "VLLM_ASCEND_MLP_DOWN_PREFETCH_SIZE": + lambda: int( + os.getenv("VLLM_ASCEND_MLP_DOWN_PREFETCH_SIZE", 18 * 1024 * 1024)), # Whether to enable dense model and general optimizations for better performance. # Since we modified the base parent class `linear`, this optimization is also applicable to other model types. # However, there might be hidden issues, and it is currently recommended to prioritize its use with dense models. diff --git a/vllm_ascend/ops/register_custom_ops.py b/vllm_ascend/ops/register_custom_ops.py index d066dc9..916b86f 100644 --- a/vllm_ascend/ops/register_custom_ops.py +++ b/vllm_ascend/ops/register_custom_ops.py @@ -7,6 +7,7 @@ from vllm.distributed import (get_tensor_model_parallel_rank, tensor_model_parallel_all_reduce, tensor_model_parallel_reduce_scatter) from vllm.forward_context import get_forward_context +from vllm.logger import logger from vllm.utils import direct_register_custom_op import vllm_ascend.envs as envs_ascend @@ -14,12 +15,18 @@ import vllm_ascend.envs as envs_ascend def _maybe_chunk_residual_impl(x: torch.Tensor, residual: torch.Tensor) -> torch.Tensor: + try: + forward_context = get_forward_context() + except AssertionError: + logger.info("Forward context is None, skipping the operation.") + return residual + if x.size(0) != residual.size(0): - flashcomm_v1_enabled = get_forward_context().flashcomm_v1_enabled + flashcomm_v1_enabled = forward_context.flashcomm_v1_enabled assert flashcomm_v1_enabled is True, ( "Currently, this situation only occurs " "when flashcomm_v1 is enabled") - pad_size = get_forward_context().pad_size + pad_size = forward_context.pad_size if pad_size > 0: residual = F.pad(residual, (0, 0, 0, pad_size)) tp_size = get_tensor_model_parallel_world_size() @@ -31,19 +38,31 @@ def _maybe_chunk_residual_impl(x: torch.Tensor, def _maybe_all_gather_and_maybe_unpad_impl(x: torch.Tensor, label: bool) -> torch.Tensor: - flashcomm_v1_enabled = get_forward_context().flashcomm_v1_enabled + try: + forward_context = get_forward_context() + except AssertionError: + logger.info("Forward context is None, skipping the operation.") + return x + + flashcomm_v1_enabled = forward_context.flashcomm_v1_enabled if flashcomm_v1_enabled and label: x = tensor_model_parallel_all_gather(x, 0) - pad_size = get_forward_context().pad_size + pad_size = forward_context.pad_size if pad_size > 0: x = x[:-pad_size, :] return x def _maybe_pad_and_reduce_impl(x: torch.Tensor) -> torch.Tensor: - flashcomm_v1_enabled = get_forward_context().flashcomm_v1_enabled + try: + forward_context = get_forward_context() + except AssertionError: + logger.info("Forward context is None, skipping the operation.") + return tensor_model_parallel_all_reduce(x) + + flashcomm_v1_enabled = forward_context.flashcomm_v1_enabled if flashcomm_v1_enabled: - pad_size = get_forward_context().pad_size + pad_size = forward_context.pad_size if pad_size > 0: x = F.pad(x, (0, 0, 0, pad_size)) return tensor_model_parallel_reduce_scatter(x, 0) @@ -53,7 +72,12 @@ def _maybe_pad_and_reduce_impl(x: torch.Tensor) -> torch.Tensor: def _maybe_prefetch_mlp_gate_up_proj_impl(x_dependency: torch.Tensor, prefix: str) -> None: - forward_context = get_forward_context() + try: + forward_context = get_forward_context() + except AssertionError: + logger.info("Forward context is None, skipping the operation.") + return + if not forward_context.prefetch_mlp_enabled: return model_instance = forward_context.model_instance @@ -67,9 +91,9 @@ def _maybe_prefetch_mlp_gate_up_proj_impl(x_dependency: torch.Tensor, prefetch_stream.wait_stream(torch.npu.current_stream()) with torch.npu.stream(prefetch_stream): - MLP_GATE_UP_PREFETCH_SIZE = envs_ascend.MLP_GATE_UP_PREFETCH_SIZE + mlp_gate_up_prefetch_size = envs_ascend.VLLM_ASCEND_MLP_GATE_UP_PREFETCH_SIZE torch_npu.npu_prefetch(model_instance.model.layers[layer_idx].mlp.gate_up_proj.weight, \ - x_dependency, MLP_GATE_UP_PREFETCH_SIZE) + x_dependency, mlp_gate_up_prefetch_size) return @@ -79,7 +103,12 @@ def _maybe_prefetch_mlp_gate_up_proj_impl_fake(x_dependency: torch.Tensor, def _maybe_prefetch_mlp_down_proj_impl(x_dependency: torch.Tensor) -> None: - forward_context = get_forward_context() + try: + forward_context = get_forward_context() + except AssertionError: + logger.info("Forward context is None, skipping the operation.") + return + if not forward_context.prefetch_mlp_enabled: return forward_context.prefetch_mlp_down_proj = True @@ -91,9 +120,9 @@ def _maybe_prefetch_mlp_down_proj_impl(x_dependency: torch.Tensor) -> None: prefetch_stream.wait_stream(torch.npu.current_stream()) with torch.npu.stream(prefetch_stream): - MLP_DOWN_PREFETCH_SIZE = envs_ascend.MLP_DOWN_PREFETCH_SIZE + mlp_down_prefetch_size = envs_ascend.VLLM_ASCEND_MLP_DOWN_PREFETCH_SIZE torch_npu.npu_prefetch(model_instance.model.layers[layer_idx].mlp.down_proj.weight, \ - x_dependency, MLP_DOWN_PREFETCH_SIZE) + x_dependency, mlp_down_prefetch_size) forward_context.layer_idx += 1 return @@ -104,12 +133,17 @@ def _maybe_prefetch_mlp_down_proj_impl_fake( def _maybe_wait_prefetch_done_impl(x: torch.Tensor) -> None: - forward_context = get_forward_context() + try: + forward_context = get_forward_context() + except AssertionError: + logger.info("Forward context is None, skipping the operation.") + return + if not forward_context.prefetch_mlp_enabled: return if forward_context.prefetch_mlp_gate_up_proj or \ forward_context.prefetch_mlp_down_proj: - prefetch_stream = get_forward_context().prefetch_stream + prefetch_stream = forward_context.prefetch_stream # wait until prefetch done torch.npu.current_stream().wait_stream(prefetch_stream) forward_context.prefetch_mlp_gate_up_proj = False