[main] flashcomm_v1 optim in Qwen Dense Models (#2802)

### What this PR does / why we need it?
Flashcomm_v1 optim in Qwen Dense Models.

### Does this PR introduce _any_ user-facing change?
No.

### How was this patch tested?
CI passed with new added/existing test.

- vLLM version: v0.10.1.1
- vLLM main:
5e537f45b4

Co-authored-by: 1024daniel <xxltju324@gmail.com>
This commit is contained in:
rjg-lyh
2025-09-08 22:52:24 +08:00
committed by GitHub
parent 4df8df5b94
commit 1bbb20ea13
11 changed files with 362 additions and 20 deletions

View File

@@ -23,6 +23,7 @@ Run `pytest tests/test_offline_inference.py`.
import os
from unittest.mock import patch
import pytest
from modelscope import snapshot_download # type: ignore
from vllm import SamplingParams
@@ -30,6 +31,8 @@ from tests.e2e.conftest import VllmRunner
os.environ["PYTORCH_NPU_ALLOC_CONF"] = "max_split_size_mb:256"
QWEN_DENSE_MODELS = ["Qwen/QwQ-32B", "Qwen/Qwen-32B"]
def test_models_distributed_QwQ():
example_prompts = [
@@ -150,3 +153,23 @@ def test_sp_for_qwen3_moe() -> None:
enable_expert_parallel=True,
enforce_eager=True) as vllm_model:
vllm_model.generate(example_prompts, sampling_params)
@pytest.mark.parametrize("enforce_eager", [True, False])
@pytest.mark.parametrize("model", QWEN_DENSE_MODELS)
@patch.dict(os.environ, {"VLLM_ASCEND_ENABLE_DENSE_OPTIMIZE": "1"})
@patch.dict(os.environ, {"VLLM_ASCEND_ENABLE_FLASHCOMM": "1"})
def test_models_distributed_Qwen_Dense_with_flashcomm_v1(model, enforce_eager):
example_prompts = [
"Hello, my name is",
]
max_tokens = 5
with VllmRunner(
snapshot_download(model),
max_model_len=8192,
enforce_eager=enforce_eager,
dtype="auto",
tensor_parallel_size=4,
) as vllm_model:
vllm_model.generate_greedy(example_prompts, max_tokens)

View File

@@ -10,6 +10,13 @@ def dummy_tensor():
return torch.randn(4, 8, dtype=torch.float16)
def mock_maybe_chunk_residual(x, residual):
if x.size(0) != residual.size(0):
return residual[:4]
return residual
def mock_rms_norm(x, weight, eps):
return x + 1, None
@@ -23,11 +30,13 @@ def mock_add_rms_norm(x, residual, weight, eps):
[None, torch.randn(4, 8, dtype=torch.float32)])
@patch("torch_npu.npu_rms_norm", side_effect=mock_rms_norm)
@patch("torch_npu.npu_add_rms_norm", side_effect=mock_add_rms_norm)
def test_RMSNorm_forward(mock_add_rmsnorm, mock_rmsnorm, is_310p_return,
residual, dummy_tensor):
@patch("torch.ops.vllm.maybe_chunk_residual",
side_effect=mock_maybe_chunk_residual)
def test_RMSNorm_forward(mock_maybe_chunk_residual, mock_add_rmsnorm,
mock_rmsnorm, is_310p_return, residual, dummy_tensor):
with patch("vllm_ascend.utils.is_310p", return_value=is_310p_return):
layer = RMSNorm(hidden_size=32, eps=1e-05)
layer = RMSNorm(hidden_size=8, eps=1e-05)
if residual is not None:
out_x, out_residual = layer.forward_oot(dummy_tensor, residual)
@@ -51,3 +60,25 @@ def test_RMSNorm_forward(mock_add_rmsnorm, mock_rmsnorm, is_310p_return,
mock_rmsnorm.assert_called_once()
assert torch.allclose(out_x, expected_out_x)
@patch("vllm_ascend.utils.is_310p", return_value=False)
@patch("torch_npu.npu_add_rms_norm", side_effect=mock_add_rms_norm)
@patch("torch.ops.vllm.maybe_chunk_residual",
side_effect=mock_maybe_chunk_residual)
def test_RMSNorm_forward_with_flashcomm_v1(mock_maybe_chunk_residual,
mock_add_rms_norm, mock_is310p):
x = torch.randn(4, 512, dtype=torch.bfloat16)
residual = torch.randn(16, 512, dtype=torch.bfloat16)
layer = RMSNorm(hidden_size=512, eps=1e-05)
out_x, out_residual = layer.forward_oot(x, residual)
expected_out_x = 2 * x
expected_out_residual = 2 * residual[:4]
mock_maybe_chunk_residual.assert_called_once()
mock_add_rms_norm.assert_called_once()
assert out_residual.size(0) == 4
assert torch.allclose(out_x, expected_out_x)
assert torch.allclose(out_residual, expected_out_residual)

View File

@@ -303,13 +303,13 @@ class TestUtils(TestBase):
# ascend custom op is not registered
utils.register_ascend_customop()
# should call register_oot three
self.assertEqual(mock_customop.register_oot.call_count, 12)
self.assertEqual(mock_customop.register_oot.call_count, 13)
self.assertTrue(utils._ASCEND_CUSTOMOP_IS_REIGISTERED)
# ascend custom op is already registered
utils.register_ascend_customop()
# should not register_oot again, thus only called three in this ut
self.assertEqual(mock_customop.register_oot.call_count, 12)
self.assertEqual(mock_customop.register_oot.call_count, 13)
class TestProfileExecuteDuration(TestBase):

View File

@@ -83,6 +83,7 @@ def set_ascend_forward_context(
forward_context = get_forward_context()
forward_context.moe_comm_method_name = moe_comm_method + "commimpl"
forward_context.with_prefill = with_prefill
tp_world_size = get_tensor_model_parallel_world_size()
ep_size = (get_ep_group().world_size if
vllm_config.parallel_config.enable_expert_parallel else 1)
@@ -103,6 +104,21 @@ def set_ascend_forward_context(
# due to multiple warmups before actual capturing
forward_context.capturing = False
# set for flashcomm_v1, 1000 is the batchsize concurrency threshold for enabling the flashcomm_v1 feature.
# Currently, it is an empirical value. In normal scenarios, if the concurrency exceeds this threshold,
# the performance benefits can be maximized. Conversely, if the concurrency is below the threshold,
# the performance may degrade due to the switching of communication methods.
flashcomm_v1_enabled = envs_ascend.VLLM_ASCEND_ENABLE_FLASHCOMM and \
tp_world_size > 1 and \
num_tokens is not None and num_tokens > 1000
if flashcomm_v1_enabled:
pad_size = (tp_world_size -
(num_tokens % tp_world_size)) % tp_world_size
forward_context.pad_size = pad_size
forward_context.flashcomm_v1_enabled = flashcomm_v1_enabled
if num_tokens is None and attn_metadata is not None:
num_tokens = attn_metadata.num_actual_tokens
@@ -118,7 +134,6 @@ def set_ascend_forward_context(
if num_tokens is not None:
if num_actual_tokens is None:
num_actual_tokens = num_tokens
tp_world_size = get_tensor_model_parallel_world_size()
# NOTE: token num which need to pad to when mc2
forward_context.padded_num_tokens = math.ceil(
max_tokens_across_dp / tp_world_size) * tp_world_size

View File

@@ -131,6 +131,15 @@ env_variables: Dict[str, Callable[[], Any]] = {
# this feature is supported in A2, and eager mode will get better performance.
"VLLM_ASCEND_ENABLE_MATMUL_ALLREDUCE":
lambda: bool(int(os.getenv("VLLM_ASCEND_ENABLE_MATMUL_ALLREDUCE", '0'))),
# Whether to enable FlashComm optimization when tensor parallel is enabled.
# This feature will get better performance when concurrency is large.
"VLLM_ASCEND_ENABLE_FLASHCOMM":
lambda: bool(int(os.getenv("VLLM_ASCEND_ENABLE_FLASHCOMM", '0'))),
# Whether to enable dense model and general optimizations for better performance.
# Since we modified the base parent class `linear`, this optimization is also applicable to other model types.
# However, there might be hidden issues, and it is currently recommended to prioritize its use with dense models.
"VLLM_ASCEND_ENABLE_DENSE_OPTIMIZE":
lambda: bool(int(os.getenv("VLLM_ASCEND_ENABLE_DENSE_OPTIMIZE", '0'))),
# Whether to enable mlp optimize when tensor parallel is enabled.
# this feature in eager mode will get better performance.
"VLLM_ASCEND_ENABLE_MLP_OPTIMIZE":

View File

@@ -20,6 +20,7 @@ import torch
import vllm_ascend.ops.common_fused_moe # noqa
import vllm_ascend.ops.fused_moe # noqa
import vllm_ascend.ops.layernorm # noqa
import vllm_ascend.ops.register_custom_ops # noqa
import vllm_ascend.ops.vocab_parallel_embedding # noqa
from vllm_ascend.ops.activation import AscendQuickGELU, AscendSiluAndMul
from vllm_ascend.ops.rotary_embedding import (

View File

@@ -44,6 +44,13 @@ class AddRMSNormW8A8Quant(RMSNorm):
import torch_npu
if residual is not None:
# FIXME(rjg-lyh): This is a hacky way to chunk residuals when the flashcomm_v1 feature
# is enabled, without interfering with the normal operation of components like torchair.
# The final solution should be to move this check into the operator and support
# integration with torchair.
if x.size(0) != residual.size(0):
residual = torch.ops.vllm.maybe_chunk_residual(x, residual)
assert x.size(0) == residual.size(0)
x, _, residual = torch_npu.npu_add_rms_norm_quant(
x,
residual,
@@ -69,6 +76,13 @@ class AscendRMSNorm(RMSNorm):
from vllm_ascend.utils import is_310p
if residual is not None:
# FIXME(rjg-lyh): This is a hacky way to chunk residuals when the flashcomm_v1 feature
# is enabled, without interfering with the normal operation of components like torchair.
# The final solution should be to move this check into the operator and support
# integration with torchair.
if x.size(0) != residual.size(0):
residual = torch.ops.vllm.maybe_chunk_residual(x, residual)
assert x.size(0) == residual.size(0)
if is_310p():
orig_dtype = residual.dtype
x = x + residual.to(x.dtype)

View File

@@ -26,20 +26,18 @@ from torch.nn.parameter import Parameter
from vllm.distributed import divide, split_tensor_along_last_dim
from vllm.distributed.parallel_state import get_tp_group
from vllm.lora.utils import LinearBase
from vllm.model_executor.layers.linear import (WEIGHT_LOADER_V2_SUPPORTED,
ColumnParallelLinear,
MergedColumnParallelLinear,
QuantizeMethodBase,
RowParallelLinear,
UnquantizedLinearMethod)
from vllm.model_executor.layers.linear import ( # noqa
WEIGHT_LOADER_V2_SUPPORTED, ColumnParallelLinear,
MergedColumnParallelLinear, QKVParallelLinear, QuantizeMethodBase,
RowParallelLinear, UnquantizedLinearMethod)
from vllm.model_executor.layers.quantization.base_config import \
QuantizationConfig
from vllm.model_executor.utils import set_weight_attrs
from vllm_ascend.distributed.parallel_state import (get_mlp_tp_group,
get_otp_group)
from vllm_ascend.utils import (matmul_allreduce_enable, mlp_tp_enable,
oproj_tp_enable)
from vllm_ascend.utils import (dense_optim_enable, matmul_allreduce_enable,
mlp_tp_enable, oproj_tp_enable)
_HCOMM_INFO = None
@@ -150,6 +148,9 @@ class AscendRowParallelLinear(RowParallelLinear):
comm_group = get_tp_group()
self.forward_type = "matmul_allreduce"
self.hcomm_info = self.get_hcomm_info(comm_group.device_group)
elif dense_optim_enable():
comm_group = get_tp_group()
self.forward_type = "dense_optim"
else:
comm_group = get_tp_group()
self.forward_type = "normal"
@@ -231,6 +232,8 @@ class AscendRowParallelLinear(RowParallelLinear):
return self._forward_mlp_tp(input_)
elif self.forward_type == "matmul_allreduce":
return self._forward_matmul_allreduce(input_)
elif self.forward_type == "dense_optim":
return self._forward_dense_optim(input_)
else:
return super().forward(input_)
@@ -332,6 +335,39 @@ class AscendRowParallelLinear(RowParallelLinear):
return output
return output, output_bias
def _forward_dense_optim(
self, input_: torch.Tensor
) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[Parameter]]]:
"""Linear layer with column parallelism.
Implemented multiple optimization projects for dense models, such as FlashComm and
communication-computation fusion.
"""
if self.input_is_parallel:
input_parallel = input_
else:
splitted_input = split_tensor_along_last_dim(
input_, num_partitions=self.tp_size)
input_parallel = splitted_input[self.tp_rank].contiguous()
assert self.quant_method is not None
bias_ = None if (self.tp_rank > 0 or self.skip_bias_add) else self.bias
if self.tp_size == 1 or not self.reduce_results:
output = self.quant_method.apply(self, input_parallel, bias=bias_)
else:
output_parallel = self.quant_method.apply(self,
input_parallel,
bias=bias_)
output = torch.ops.vllm.maybe_pad_and_reduce(output_parallel)
output_bias = self.bias if self.skip_bias_add else None
if not self.return_bias:
return output
return output, output_bias
class AscendMergedColumnParallelLinear(MergedColumnParallelLinear):
"""Packed linear layers with column parallelism.
@@ -357,15 +393,18 @@ class AscendMergedColumnParallelLinear(MergedColumnParallelLinear):
*,
return_bias: bool = True,
):
self.comm_group = None
if prefix.find("gate_up_proj") != -1 and mlp_tp_enable():
self.comm_group = get_mlp_tp_group()
comm_group = get_mlp_tp_group()
self.forward_type = "mlp_tp"
elif dense_optim_enable():
comm_group = get_tp_group()
self.forward_type = "dense_optim"
else:
self.comm_group = get_tp_group()
comm_group = get_tp_group()
self.forward_type = "normal_tp"
self.tp_rank = self.comm_group.rank_in_group
self.tp_size = self.comm_group.world_size
self.comm_group = comm_group
self.tp_rank = comm_group.rank_in_group
self.tp_size = comm_group.world_size
self.output_sizes = output_sizes
assert all(output_size % self.tp_size == 0
@@ -387,6 +426,8 @@ class AscendMergedColumnParallelLinear(MergedColumnParallelLinear):
) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[Parameter]]]:
if self.forward_type == "mlp_tp":
return self._forward_mlp_tp(input_)
elif self.forward_type == "dense_optim":
return self._forward_dense_optim(input_)
else:
return super().forward(input_)
@@ -405,6 +446,138 @@ class AscendMergedColumnParallelLinear(MergedColumnParallelLinear):
return output
return output, output_bias
def _forward_dense_optim(
self, input_: torch.Tensor
) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[Parameter]]]:
"""Linear layer with column parallelism.
Implemented multiple optimization projects for dense models, such as FlashComm and
communication-computation fusion.
"""
bias = self.bias if not self.skip_bias_add else None
# Matrix multiply.
assert self.quant_method is not None
input_ = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(input_, True)
output_parallel = self.quant_method.apply(self, input_, bias)
if self.gather_output:
# All-gather across the partitions.
output = self.comm_group.all_gather(output_parallel)
else:
output = output_parallel
output_bias = self.bias if self.skip_bias_add else None
if not self.return_bias:
return output
return output, output_bias
class AscendQKVParallelLinear(QKVParallelLinear):
"""Linear layers for the attention's QKV transformation.
Linear layers for the linear transformation of the query, key, and value
vectors in the attention layer. The weight matrix is concatenated along
the output dimension. The layer is parallelized along the head dimension.
When the number of key/value heads is smaller than the number of query
heads (e.g., multi-query/grouped-query attention), the key/value head may
be replicated while the query heads are partitioned.
"""
def __init__(
self,
hidden_size: int,
head_size: int,
total_num_heads: int,
total_num_kv_heads: Optional[int] = None,
bias: bool = True,
skip_bias_add: bool = False,
params_dtype: Optional[torch.dtype] = None,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
*,
return_bias: bool = True,
):
if dense_optim_enable():
self.forward_type = "dense_optim"
else:
self.forward_type = "normal_tp"
self.comm_group = get_tp_group()
self.hidden_size = hidden_size
self.head_size = head_size
self.total_num_heads = total_num_heads
if total_num_kv_heads is None:
total_num_kv_heads = total_num_heads
self.total_num_kv_heads = total_num_kv_heads
# Divide the weight matrix along the last dimension.
tp_size = self.comm_group.world_size
self.num_heads = divide(self.total_num_heads, tp_size)
if tp_size >= self.total_num_kv_heads:
self.num_kv_heads = 1
self.num_kv_head_replicas = divide(tp_size,
self.total_num_kv_heads)
else:
self.num_kv_heads = divide(self.total_num_kv_heads, tp_size)
self.num_kv_head_replicas = 1
input_size = self.hidden_size
output_size = (self.num_heads +
2 * self.num_kv_heads) * tp_size * self.head_size
self.output_sizes = [
self.num_heads * self.head_size * tp_size, # q_proj
self.num_kv_heads * self.head_size * tp_size, # k_proj
self.num_kv_heads * self.head_size * tp_size, # v_proj
]
AscendColumnParallelLinear.__init__(self,
input_size=input_size,
output_size=output_size,
bias=bias,
gather_output=False,
skip_bias_add=skip_bias_add,
params_dtype=params_dtype,
quant_config=quant_config,
prefix=prefix,
return_bias=return_bias)
def forward(
self,
input_,
) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[Parameter]]]:
if self.forward_type == "dense_optim":
return self._forward_dense_optim(input_)
else:
return super().forward(input_)
def _forward_dense_optim(
self, input_: torch.Tensor
) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[Parameter]]]:
"""Linear layer with column parallelism.
Implemented multiple optimization projects for dense models, such as FlashComm and
communication-computation fusion.
"""
bias = self.bias if not self.skip_bias_add else None
# Matrix multiply.
assert self.quant_method is not None
layer_num = self.prefix.split('.')[2]
input_ = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(
input_, layer_num != '0')
output_parallel = self.quant_method.apply(self, input_, bias)
if self.gather_output:
# All-gather across the partitions.
output = self.comm_group.all_gather(output_parallel)
else:
output = output_parallel
output_bias = self.bias if self.skip_bias_add else None
if not self.return_bias:
return output
return output, output_bias
class AscendLinearBase(LinearBase):
@@ -438,4 +611,4 @@ class AscendLinearBase(LinearBase):
self.quant_method = quant_config.get_quant_method(self,
prefix=prefix)
self.return_bias = return_bias
self.disable_tp = disable_tp
self.disable_tp = disable_tp

View File

@@ -0,0 +1,63 @@
import torch
import torch.nn.functional as F
from vllm.distributed import (get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size,
tensor_model_parallel_all_gather,
tensor_model_parallel_all_reduce,
tensor_model_parallel_reduce_scatter)
from vllm.forward_context import get_forward_context
from vllm.utils import direct_register_custom_op
def _maybe_chunk_residual_impl(x: torch.Tensor,
residual: torch.Tensor) -> torch.Tensor:
if get_forward_context().flashcomm_v1_enabled:
pad_size = get_forward_context().pad_size
if pad_size > 0:
residual = F.pad(residual, (0, 0, 0, pad_size))
tp_size = get_tensor_model_parallel_world_size()
tp_rank = get_tensor_model_parallel_rank()
residual = torch.chunk(residual, tp_size, dim=0)[tp_rank]
return residual
def _maybe_all_gather_and_maybe_unpad_impl(x: torch.Tensor,
label: bool) -> torch.Tensor:
flashcomm_v1_enabled = get_forward_context().flashcomm_v1_enabled
if flashcomm_v1_enabled and label:
x = tensor_model_parallel_all_gather(x, 0)
pad_size = get_forward_context().pad_size
if pad_size > 0:
x = x[:-pad_size, :]
return x
def _maybe_pad_and_reduce_impl(x: torch.Tensor) -> torch.Tensor:
flashcomm_v1_enabled = get_forward_context().flashcomm_v1_enabled
if flashcomm_v1_enabled:
pad_size = get_forward_context().pad_size
if pad_size > 0:
x = F.pad(x, (0, 0, 0, pad_size))
return tensor_model_parallel_reduce_scatter(x, 0)
else:
return tensor_model_parallel_all_reduce(x)
direct_register_custom_op(op_name="maybe_chunk_residual",
op_func=_maybe_chunk_residual_impl,
fake_impl=lambda x, residual: residual,
mutates_args=[],
dispatch_key="PrivateUse1")
direct_register_custom_op(op_name="maybe_all_gather_and_maybe_unpad",
op_func=_maybe_all_gather_and_maybe_unpad_impl,
fake_impl=lambda x, label: x,
mutates_args=[],
dispatch_key="PrivateUse1")
direct_register_custom_op(op_name="maybe_pad_and_reduce",
op_func=_maybe_pad_and_reduce_impl,
fake_impl=lambda x: x,
mutates_args=[],
dispatch_key="PrivateUse1")

View File

@@ -493,6 +493,7 @@ def register_ascend_customop():
from vllm_ascend.ops.activation import AscendQuickGELU, AscendSiluAndMul
from vllm_ascend.ops.linear import (AscendColumnParallelLinear,
AscendMergedColumnParallelLinear,
AscendQKVParallelLinear,
AscendRowParallelLinear)
from vllm_ascend.ops.rotary_embedding import (
AscendDeepseekScalingRotaryEmbedding, AscendRotaryEmbedding)
@@ -510,6 +511,8 @@ def register_ascend_customop():
name="RowParallelLinear")
CustomOp.register_oot(_decorated_op_cls=AscendMergedColumnParallelLinear,
name="MergedColumnParallelLinear")
CustomOp.register_oot(_decorated_op_cls=AscendQKVParallelLinear,
name="QKVParallelLinear")
CustomOp.register_oot(
_decorated_op_cls=AscendDeepseekScalingRotaryEmbedding,
name="DeepseekScalingRotaryEmbedding")
@@ -572,3 +575,7 @@ def mlp_tp_enable() -> bool:
def matmul_allreduce_enable() -> bool:
return envs_ascend.VLLM_ASCEND_ENABLE_MATMUL_ALLREDUCE
def dense_optim_enable() -> bool:
return envs_ascend.VLLM_ASCEND_ENABLE_DENSE_OPTIMIZE

View File

@@ -37,6 +37,7 @@ from vllm.attention.layer import Attention
from vllm.compilation.counter import compilation_counter
from vllm.compilation.monitor import set_cudagraph_capturing_enabled
from vllm.config import CompilationLevel, CUDAGraphMode, VllmConfig
from vllm.distributed import tensor_model_parallel_all_gather
from vllm.distributed.kv_transfer import (get_kv_transfer_group,
has_kv_transfer_group)
from vllm.distributed.kv_transfer.kv_connector.v1 import KVConnectorBase_V1
@@ -1182,6 +1183,11 @@ class NPUModelRunner(LoRAModelRunnerMixin):
intermediate_tensors=intermediate_tensors,
inputs_embeds=inputs_embeds,
)
if get_forward_context().flashcomm_v1_enabled:
hidden_states = tensor_model_parallel_all_gather(hidden_states, 0)
pad_size = get_forward_context().pad_size
if pad_size > 0:
hidden_states = hidden_states[:-pad_size, :]
return hidden_states
def _build_attn_state(self, num_reqs, num_scheduled_tokens,