Files
xc-llm-ascend/vllm_ascend/ops/register_custom_ops.py
rjg-lyh fc2bcbe21c [Ops] Fix bug in register_custom_ops without forward_context (#2883)
### What this PR does / why we need it?
This PR fixed the bug in register_custom_ops without forward_context. We
set try-except to consider this situation.

### Does this PR introduce _any_ user-facing change?
No.

### How was this patch tested?
CI passed with new added/existing test.

- vLLM version: main
- vLLM main:
7920de0a2a

Signed-off-by: rjg-lyh <1318825571@qq.com>
2025-09-12 16:58:08 +08:00

193 lines
7.5 KiB
Python

import torch
import torch.nn.functional as F
import torch_npu
from vllm.distributed import (get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size,
tensor_model_parallel_all_gather,
tensor_model_parallel_all_reduce,
tensor_model_parallel_reduce_scatter)
from vllm.forward_context import get_forward_context
from vllm.logger import logger
from vllm.utils import direct_register_custom_op
import vllm_ascend.envs as envs_ascend
def _maybe_chunk_residual_impl(x: torch.Tensor,
residual: torch.Tensor) -> torch.Tensor:
try:
forward_context = get_forward_context()
except AssertionError:
logger.info("Forward context is None, skipping the operation.")
return residual
if x.size(0) != residual.size(0):
flashcomm_v1_enabled = forward_context.flashcomm_v1_enabled
assert flashcomm_v1_enabled is True, (
"Currently, this situation only occurs "
"when flashcomm_v1 is enabled")
pad_size = forward_context.pad_size
if pad_size > 0:
residual = F.pad(residual, (0, 0, 0, pad_size))
tp_size = get_tensor_model_parallel_world_size()
tp_rank = get_tensor_model_parallel_rank()
residual = torch.chunk(residual, tp_size, dim=0)[tp_rank]
return residual
def _maybe_all_gather_and_maybe_unpad_impl(x: torch.Tensor,
label: bool) -> torch.Tensor:
try:
forward_context = get_forward_context()
except AssertionError:
logger.info("Forward context is None, skipping the operation.")
return x
flashcomm_v1_enabled = forward_context.flashcomm_v1_enabled
if flashcomm_v1_enabled and label:
x = tensor_model_parallel_all_gather(x, 0)
pad_size = forward_context.pad_size
if pad_size > 0:
x = x[:-pad_size, :]
return x
def _maybe_pad_and_reduce_impl(x: torch.Tensor) -> torch.Tensor:
try:
forward_context = get_forward_context()
except AssertionError:
logger.info("Forward context is None, skipping the operation.")
return tensor_model_parallel_all_reduce(x)
flashcomm_v1_enabled = forward_context.flashcomm_v1_enabled
if flashcomm_v1_enabled:
pad_size = forward_context.pad_size
if pad_size > 0:
x = F.pad(x, (0, 0, 0, pad_size))
return tensor_model_parallel_reduce_scatter(x, 0)
else:
return tensor_model_parallel_all_reduce(x)
def _maybe_prefetch_mlp_gate_up_proj_impl(x_dependency: torch.Tensor,
prefix: str) -> None:
try:
forward_context = get_forward_context()
except AssertionError:
logger.info("Forward context is None, skipping the operation.")
return
if not forward_context.prefetch_mlp_enabled:
return
model_instance = forward_context.model_instance
prefetch_stream = forward_context.prefetch_stream
layer_idx = int(prefix.split('.')[2])
# start point of gate_up_proj weight prefetch
if prefix.split('.')[-2] == "self_attn":
forward_context.prefetch_mlp_gate_up_proj = True
if forward_context.prefetch_mlp_gate_up_proj:
prefetch_stream.wait_stream(torch.npu.current_stream())
with torch.npu.stream(prefetch_stream):
mlp_gate_up_prefetch_size = envs_ascend.VLLM_ASCEND_MLP_GATE_UP_PREFETCH_SIZE
torch_npu.npu_prefetch(model_instance.model.layers[layer_idx].mlp.gate_up_proj.weight, \
x_dependency, mlp_gate_up_prefetch_size)
return
def _maybe_prefetch_mlp_gate_up_proj_impl_fake(x_dependency: torch.Tensor,
prefix: str) -> None:
return
def _maybe_prefetch_mlp_down_proj_impl(x_dependency: torch.Tensor) -> None:
try:
forward_context = get_forward_context()
except AssertionError:
logger.info("Forward context is None, skipping the operation.")
return
if not forward_context.prefetch_mlp_enabled:
return
forward_context.prefetch_mlp_down_proj = True
model_instance = forward_context.model_instance
prefetch_stream = forward_context.prefetch_stream
layer_idx = forward_context.layer_idx
# start point of down_proj weight prefetch
prefetch_stream.wait_stream(torch.npu.current_stream())
with torch.npu.stream(prefetch_stream):
mlp_down_prefetch_size = envs_ascend.VLLM_ASCEND_MLP_DOWN_PREFETCH_SIZE
torch_npu.npu_prefetch(model_instance.model.layers[layer_idx].mlp.down_proj.weight, \
x_dependency, mlp_down_prefetch_size)
forward_context.layer_idx += 1
return
def _maybe_prefetch_mlp_down_proj_impl_fake(
x_dependency: torch.Tensor) -> None:
return
def _maybe_wait_prefetch_done_impl(x: torch.Tensor) -> None:
try:
forward_context = get_forward_context()
except AssertionError:
logger.info("Forward context is None, skipping the operation.")
return
if not forward_context.prefetch_mlp_enabled:
return
if forward_context.prefetch_mlp_gate_up_proj or \
forward_context.prefetch_mlp_down_proj:
prefetch_stream = forward_context.prefetch_stream
# wait until prefetch done
torch.npu.current_stream().wait_stream(prefetch_stream)
forward_context.prefetch_mlp_gate_up_proj = False
forward_context.prefetch_mlp_down_proj = False
return
def _maybe_wait_prefetch_done_impl_fake(x: torch.Tensor) -> None:
return
direct_register_custom_op(op_name="maybe_chunk_residual",
op_func=_maybe_chunk_residual_impl,
fake_impl=lambda x, residual: residual,
mutates_args=[],
dispatch_key="PrivateUse1")
direct_register_custom_op(op_name="maybe_all_gather_and_maybe_unpad",
op_func=_maybe_all_gather_and_maybe_unpad_impl,
fake_impl=lambda x, label: x,
mutates_args=[],
dispatch_key="PrivateUse1")
direct_register_custom_op(op_name="maybe_pad_and_reduce",
op_func=_maybe_pad_and_reduce_impl,
fake_impl=lambda x: x,
mutates_args=[],
dispatch_key="PrivateUse1")
direct_register_custom_op(op_name="maybe_prefetch_mlp_gate_up_proj",
op_func=_maybe_prefetch_mlp_gate_up_proj_impl,
fake_impl=_maybe_prefetch_mlp_gate_up_proj_impl_fake,
mutates_args=[],
dispatch_key="PrivateUse1")
direct_register_custom_op(op_name="maybe_prefetch_mlp_down_proj",
op_func=_maybe_prefetch_mlp_down_proj_impl,
fake_impl=_maybe_prefetch_mlp_down_proj_impl_fake,
mutates_args=[],
dispatch_key="PrivateUse1")
direct_register_custom_op(op_name="maybe_wait_prefetch_done",
op_func=_maybe_wait_prefetch_done_impl,
fake_impl=_maybe_wait_prefetch_done_impl_fake,
mutates_args=[],
dispatch_key="PrivateUse1")