Files
xc-llm-ascend/vllm_ascend/ops/register_custom_ops.py
rjg-lyh 1bbb20ea13 [main] flashcomm_v1 optim in Qwen Dense Models (#2802)
### What this PR does / why we need it?
Flashcomm_v1 optimization for Qwen dense models.

### Does this PR introduce _any_ user-facing change?
No.

### How was this patch tested?
CI passed with newly added and existing tests.

- vLLM version: v0.10.1.1
- vLLM main:
5e537f45b4

Co-authored-by: 1024daniel <xxltju324@gmail.com>
2025-09-08 22:52:24 +08:00

63 lines
2.6 KiB
Python

import torch
import torch.nn.functional as F
from vllm.distributed import (get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size,
tensor_model_parallel_all_gather,
tensor_model_parallel_all_reduce,
tensor_model_parallel_reduce_scatter)
from vllm.forward_context import get_forward_context
from vllm.utils import direct_register_custom_op
def _maybe_chunk_residual_impl(x: torch.Tensor,
                               residual: torch.Tensor) -> torch.Tensor:
    """Return this rank's slice of ``residual`` when flashcomm_v1 is enabled.

    When the current forward context has ``flashcomm_v1_enabled`` set,
    ``residual`` is optionally zero-padded by ``pad_size`` entries, split
    into ``tp_size`` chunks along dim 0, and the chunk belonging to the
    current tensor-parallel rank is returned. Otherwise ``residual`` is
    returned untouched.

    Args:
        x: Not read by this function.
            NOTE(review): presumably kept only so the registered op has
            the signature its callers expect -- confirm.
        residual: Tensor to pad and chunk.

    Returns:
        The rank-local chunk of ``residual``, or ``residual`` itself when
        flashcomm_v1 is disabled.
    """
    forward_context = get_forward_context()
    if not forward_context.flashcomm_v1_enabled:
        return residual

    rows_to_pad = forward_context.pad_size
    if rows_to_pad > 0:
        # Pad spec (0, 0, 0, rows_to_pad): last dim untouched, the
        # second-to-last dim gets `rows_to_pad` zeros appended at its end.
        # Assumes residual is 2-D so that this is dim 0 -- TODO confirm.
        residual = F.pad(residual, (0, 0, 0, rows_to_pad))

    world_size = get_tensor_model_parallel_world_size()
    rank = get_tensor_model_parallel_rank()
    shards = residual.chunk(world_size, dim=0)
    return shards[rank]
def _maybe_all_gather_and_maybe_unpad_impl(x: torch.Tensor,
                                           label: bool) -> torch.Tensor:
    """All-gather ``x`` along dim 0 and strip trailing padding, if requested.

    The gather only happens when the current forward context has
    ``flashcomm_v1_enabled`` set *and* ``label`` is truthy; in every other
    case ``x`` is returned as-is.

    Args:
        x: Tensor to gather across the tensor-parallel group.
        label: Caller-supplied switch; a falsy value disables the gather
            even when flashcomm_v1 is enabled.

    Returns:
        The gathered tensor with the last ``pad_size`` rows removed (when
        ``pad_size > 0``), or ``x`` unchanged.
    """
    forward_context = get_forward_context()
    if not forward_context.flashcomm_v1_enabled or not label:
        return x

    gathered = tensor_model_parallel_all_gather(x, 0)
    rows_to_drop = forward_context.pad_size
    if rows_to_drop > 0:
        # Drop the trailing `rows_to_drop` entries of dim 0. The two-index
        # slice requires the gathered tensor to have at least two dims.
        gathered = gathered[:-rows_to_drop, :]
    return gathered
def _maybe_pad_and_reduce_impl(x: torch.Tensor) -> torch.Tensor:
    """Reduce ``x`` across the tensor-parallel group.

    With flashcomm_v1 disabled this is a plain all-reduce. With it enabled,
    ``x`` is first zero-padded by the forward context's ``pad_size`` (when
    positive) and then reduce-scattered along dim 0.

    Args:
        x: Tensor to reduce.

    Returns:
        The all-reduced tensor, or the reduce-scattered (and possibly
        padded) tensor when flashcomm_v1 is enabled.
    """
    forward_context = get_forward_context()
    if not forward_context.flashcomm_v1_enabled:
        return tensor_model_parallel_all_reduce(x)

    rows_to_pad = forward_context.pad_size
    # Pad spec (0, 0, 0, rows_to_pad): last dim untouched, zeros appended
    # at the end of the second-to-last dim.
    padded = F.pad(x, (0, 0, 0, rows_to_pad)) if rows_to_pad > 0 else x
    return tensor_model_parallel_reduce_scatter(padded, 0)
# Register the flashcomm_v1 helpers as custom ops under the "PrivateUse1"
# dispatch key. Each row is (op name, real implementation, fake
# implementation); rows are registered in the order listed.
#
# NOTE(review): every fake implementation returns one of its inputs
# unchanged, so it reports the input's shape even in the cases where the
# real op pads, chunks, gathers or reduce-scatters along dim 0 -- confirm
# this is acceptable wherever the fake implementations are used.
for _op_name, _op_func, _fake_impl in (
    ("maybe_chunk_residual", _maybe_chunk_residual_impl,
     lambda x, residual: residual),
    ("maybe_all_gather_and_maybe_unpad",
     _maybe_all_gather_and_maybe_unpad_impl, lambda x, label: x),
    ("maybe_pad_and_reduce", _maybe_pad_and_reduce_impl, lambda x: x),
):
    direct_register_custom_op(op_name=_op_name,
                              op_func=_op_func,
                              fake_impl=_fake_impl,
                              mutates_args=[],
                              dispatch_key="PrivateUse1")
# Keep the module namespace free of the loop temporaries.
del _op_name, _op_func, _fake_impl