From 1bbb20ea13e2b9a47936abebdcfb6143fdce8079 Mon Sep 17 00:00:00 2001 From: rjg-lyh <83491835+rjg-lyh@users.noreply.github.com> Date: Mon, 8 Sep 2025 22:52:24 +0800 Subject: [PATCH] [main] flashcomm_v1 optim in Qwen Dense Models (#2802) ### What this PR does / why we need it? Flashcomm_v1 optim in Qwen Dense Models. ### Does this PR introduce _any_ user-facing change? No. ### How was this patch tested? CI passed with new added/existing test. - vLLM version: v0.10.1.1 - vLLM main: https://github.com/vllm-project/vllm/commit/5e537f45b453096c489d06f0c47776f9436ef99b Co-authored-by: 1024daniel --- .../test_offline_inference_distributed.py | 23 ++ tests/ut/ops/test_layernorm.py | 37 +++- tests/ut/test_utils.py | 4 +- vllm_ascend/ascend_forward_context.py | 17 +- vllm_ascend/envs.py | 9 + vllm_ascend/ops/__init__.py | 1 + vllm_ascend/ops/layernorm.py | 14 ++ vllm_ascend/ops/linear.py | 201 ++++++++++++++++-- vllm_ascend/ops/register_custom_ops.py | 63 ++++++ vllm_ascend/utils.py | 7 + vllm_ascend/worker/model_runner_v1.py | 6 + 11 files changed, 362 insertions(+), 20 deletions(-) create mode 100644 vllm_ascend/ops/register_custom_ops.py diff --git a/tests/e2e/multicard/test_offline_inference_distributed.py b/tests/e2e/multicard/test_offline_inference_distributed.py index a90c864..4fd72ce 100644 --- a/tests/e2e/multicard/test_offline_inference_distributed.py +++ b/tests/e2e/multicard/test_offline_inference_distributed.py @@ -23,6 +23,7 @@ Run `pytest tests/test_offline_inference.py`. import os from unittest.mock import patch +import pytest from modelscope import snapshot_download # type: ignore from vllm import SamplingParams @@ -30,6 +31,8 @@ from tests.e2e.conftest import VllmRunner os.environ["PYTORCH_NPU_ALLOC_CONF"] = "max_split_size_mb:256" +QWEN_DENSE_MODELS = ["Qwen/QwQ-32B", "Qwen/Qwen-32B"] + def test_models_distributed_QwQ(): example_prompts = [ @@ -150,3 +153,23 @@ def test_sp_for_qwen3_moe() -> None: enable_expert_parallel=True, enforce_eager=True) as vllm_model: vllm_model.generate(example_prompts, sampling_params) + + +@pytest.mark.parametrize("enforce_eager", [True, False]) +@pytest.mark.parametrize("model", QWEN_DENSE_MODELS) +@patch.dict(os.environ, {"VLLM_ASCEND_ENABLE_DENSE_OPTIMIZE": "1"}) +@patch.dict(os.environ, {"VLLM_ASCEND_ENABLE_FLASHCOMM": "1"}) +def test_models_distributed_Qwen_Dense_with_flashcomm_v1(model, enforce_eager): + example_prompts = [ + "Hello, my name is", + ] + max_tokens = 5 + + with VllmRunner( + snapshot_download(model), + max_model_len=8192, + enforce_eager=enforce_eager, + dtype="auto", + tensor_parallel_size=4, + ) as vllm_model: + vllm_model.generate_greedy(example_prompts, max_tokens) diff --git a/tests/ut/ops/test_layernorm.py b/tests/ut/ops/test_layernorm.py index c7bc657..a2d9877 100644 --- a/tests/ut/ops/test_layernorm.py +++ b/tests/ut/ops/test_layernorm.py @@ -10,6 +10,13 @@ def dummy_tensor(): return torch.randn(4, 8, dtype=torch.float16) +def mock_maybe_chunk_residual(x, residual): + if x.size(0) != residual.size(0): + return residual[:4] + + return residual + + def mock_rms_norm(x, weight, eps): return x + 1, None @@ -23,11 +30,13 @@ def mock_add_rms_norm(x, residual, weight, eps): [None, torch.randn(4, 8, dtype=torch.float32)]) @patch("torch_npu.npu_rms_norm", side_effect=mock_rms_norm) @patch("torch_npu.npu_add_rms_norm", side_effect=mock_add_rms_norm) -def test_RMSNorm_forward(mock_add_rmsnorm, mock_rmsnorm, is_310p_return, - residual, dummy_tensor): +@patch("torch.ops.vllm.maybe_chunk_residual", + side_effect=mock_maybe_chunk_residual) +def test_RMSNorm_forward(mock_maybe_chunk_residual, mock_add_rmsnorm, + mock_rmsnorm, is_310p_return, residual, dummy_tensor): with patch("vllm_ascend.utils.is_310p", return_value=is_310p_return): - layer = RMSNorm(hidden_size=32, eps=1e-05) + layer = RMSNorm(hidden_size=8, eps=1e-05) if residual is not None: out_x, out_residual = layer.forward_oot(dummy_tensor, residual) @@ -51,3 +60,25 @@ def test_RMSNorm_forward(mock_add_rmsnorm, mock_rmsnorm, is_310p_return, mock_rmsnorm.assert_called_once() assert torch.allclose(out_x, expected_out_x) + + +@patch("vllm_ascend.utils.is_310p", return_value=False) +@patch("torch_npu.npu_add_rms_norm", side_effect=mock_add_rms_norm) +@patch("torch.ops.vllm.maybe_chunk_residual", + side_effect=mock_maybe_chunk_residual) +def test_RMSNorm_forward_with_flashcomm_v1(mock_maybe_chunk_residual, + mock_add_rms_norm, mock_is310p): + x = torch.randn(4, 512, dtype=torch.bfloat16) + residual = torch.randn(16, 512, dtype=torch.bfloat16) + layer = RMSNorm(hidden_size=512, eps=1e-05) + + out_x, out_residual = layer.forward_oot(x, residual) + + expected_out_x = 2 * x + expected_out_residual = 2 * residual[:4] + + mock_maybe_chunk_residual.assert_called_once() + mock_add_rms_norm.assert_called_once() + assert out_residual.size(0) == 4 + assert torch.allclose(out_x, expected_out_x) + assert torch.allclose(out_residual, expected_out_residual) diff --git a/tests/ut/test_utils.py b/tests/ut/test_utils.py index 463a30d..99f821a 100644 --- a/tests/ut/test_utils.py +++ b/tests/ut/test_utils.py @@ -303,13 +303,13 @@ class TestUtils(TestBase): # ascend custom op is not registered utils.register_ascend_customop() # should call register_oot three - self.assertEqual(mock_customop.register_oot.call_count, 12) + self.assertEqual(mock_customop.register_oot.call_count, 13) self.assertTrue(utils._ASCEND_CUSTOMOP_IS_REIGISTERED) # ascend custom op is already registered utils.register_ascend_customop() # should not register_oot again, thus only called three in this ut - self.assertEqual(mock_customop.register_oot.call_count, 12) + self.assertEqual(mock_customop.register_oot.call_count, 13) class TestProfileExecuteDuration(TestBase): diff --git a/vllm_ascend/ascend_forward_context.py b/vllm_ascend/ascend_forward_context.py index 71ae4d0..9bcddf6 100644 --- a/vllm_ascend/ascend_forward_context.py +++ b/vllm_ascend/ascend_forward_context.py @@ -83,6 +83,7 @@ def set_ascend_forward_context( forward_context = get_forward_context() forward_context.moe_comm_method_name = moe_comm_method + "commimpl" forward_context.with_prefill = with_prefill + tp_world_size = get_tensor_model_parallel_world_size() ep_size = (get_ep_group().world_size if vllm_config.parallel_config.enable_expert_parallel else 1) @@ -103,6 +104,21 @@ def set_ascend_forward_context( # due to multiple warmups before actual capturing forward_context.capturing = False + # set for flashcomm_v1, 1000 is the batchsize concurrency threshold for enabling the flashcomm_v1 feature. + # Currently, it is an empirical value. In normal scenarios, if the concurrency exceeds this threshold, + # the performance benefits can be maximized. Conversely, if the concurrency is below the threshold, + # the performance may degrade due to the switching of communication methods. + flashcomm_v1_enabled = envs_ascend.VLLM_ASCEND_ENABLE_FLASHCOMM and \ + tp_world_size > 1 and \ + num_tokens is not None and num_tokens > 1000 + + if flashcomm_v1_enabled: + pad_size = (tp_world_size - + (num_tokens % tp_world_size)) % tp_world_size + forward_context.pad_size = pad_size + + forward_context.flashcomm_v1_enabled = flashcomm_v1_enabled + if num_tokens is None and attn_metadata is not None: num_tokens = attn_metadata.num_actual_tokens @@ -118,7 +134,6 @@ def set_ascend_forward_context( if num_tokens is not None: if num_actual_tokens is None: num_actual_tokens = num_tokens - tp_world_size = get_tensor_model_parallel_world_size() # NOTE: token num which need to pad to when mc2 forward_context.padded_num_tokens = math.ceil( max_tokens_across_dp / tp_world_size) * tp_world_size diff --git a/vllm_ascend/envs.py b/vllm_ascend/envs.py index 78f8c50..a9ae83d 100644 --- a/vllm_ascend/envs.py +++ b/vllm_ascend/envs.py @@ -131,6 +131,15 @@ env_variables: Dict[str, Callable[[], Any]] = { # this feature is supported in A2, and eager mode will get better performance. "VLLM_ASCEND_ENABLE_MATMUL_ALLREDUCE": lambda: bool(int(os.getenv("VLLM_ASCEND_ENABLE_MATMUL_ALLREDUCE", '0'))), + # Whether to enable FlashComm optimization when tensor parallel is enabled. + # This feature will get better performance when concurrency is large. + "VLLM_ASCEND_ENABLE_FLASHCOMM": + lambda: bool(int(os.getenv("VLLM_ASCEND_ENABLE_FLASHCOMM", '0'))), + # Whether to enable dense model and general optimizations for better performance. + # Since we modified the base parent class `linear`, this optimization is also applicable to other model types. + # However, there might be hidden issues, and it is currently recommended to prioritize its use with dense models. + "VLLM_ASCEND_ENABLE_DENSE_OPTIMIZE": + lambda: bool(int(os.getenv("VLLM_ASCEND_ENABLE_DENSE_OPTIMIZE", '0'))), # Whether to enable mlp optimize when tensor parallel is enabled. # this feature in eager mode will get better performance. "VLLM_ASCEND_ENABLE_MLP_OPTIMIZE": diff --git a/vllm_ascend/ops/__init__.py b/vllm_ascend/ops/__init__.py index a1e7417..5c8a798 100644 --- a/vllm_ascend/ops/__init__.py +++ b/vllm_ascend/ops/__init__.py @@ -20,6 +20,7 @@ import torch import vllm_ascend.ops.common_fused_moe # noqa import vllm_ascend.ops.fused_moe # noqa import vllm_ascend.ops.layernorm # noqa +import vllm_ascend.ops.register_custom_ops # noqa import vllm_ascend.ops.vocab_parallel_embedding # noqa from vllm_ascend.ops.activation import AscendQuickGELU, AscendSiluAndMul from vllm_ascend.ops.rotary_embedding import ( diff --git a/vllm_ascend/ops/layernorm.py b/vllm_ascend/ops/layernorm.py index 4f0b550..adaa73c 100644 --- a/vllm_ascend/ops/layernorm.py +++ b/vllm_ascend/ops/layernorm.py @@ -44,6 +44,13 @@ class AddRMSNormW8A8Quant(RMSNorm): import torch_npu if residual is not None: + # FIXME(rjg-lyh): This is a hacky way to chunk residuals when the flashcomm_v1 feature + # is enabled, without interfering with the normal operation of components like torchair. + # The final solution should be to move this check into the operator and support + # integration with torchair. + if x.size(0) != residual.size(0): + residual = torch.ops.vllm.maybe_chunk_residual(x, residual) + assert x.size(0) == residual.size(0) x, _, residual = torch_npu.npu_add_rms_norm_quant( x, residual, @@ -69,6 +76,13 @@ class AscendRMSNorm(RMSNorm): from vllm_ascend.utils import is_310p if residual is not None: + # FIXME(rjg-lyh): This is a hacky way to chunk residuals when the flashcomm_v1 feature + # is enabled, without interfering with the normal operation of components like torchair. + # The final solution should be to move this check into the operator and support + # integration with torchair. + if x.size(0) != residual.size(0): + residual = torch.ops.vllm.maybe_chunk_residual(x, residual) + assert x.size(0) == residual.size(0) if is_310p(): orig_dtype = residual.dtype x = x + residual.to(x.dtype) diff --git a/vllm_ascend/ops/linear.py b/vllm_ascend/ops/linear.py index c29837a..8bb7b85 100644 --- a/vllm_ascend/ops/linear.py +++ b/vllm_ascend/ops/linear.py @@ -26,20 +26,18 @@ from torch.nn.parameter import Parameter from vllm.distributed import divide, split_tensor_along_last_dim from vllm.distributed.parallel_state import get_tp_group from vllm.lora.utils import LinearBase -from vllm.model_executor.layers.linear import (WEIGHT_LOADER_V2_SUPPORTED, - ColumnParallelLinear, - MergedColumnParallelLinear, - QuantizeMethodBase, - RowParallelLinear, - UnquantizedLinearMethod) +from vllm.model_executor.layers.linear import ( # noqa + WEIGHT_LOADER_V2_SUPPORTED, ColumnParallelLinear, + MergedColumnParallelLinear, QKVParallelLinear, QuantizeMethodBase, + RowParallelLinear, UnquantizedLinearMethod) from vllm.model_executor.layers.quantization.base_config import \ QuantizationConfig from vllm.model_executor.utils import set_weight_attrs from vllm_ascend.distributed.parallel_state import (get_mlp_tp_group, get_otp_group) -from vllm_ascend.utils import (matmul_allreduce_enable, mlp_tp_enable, - oproj_tp_enable) +from vllm_ascend.utils import (dense_optim_enable, matmul_allreduce_enable, + mlp_tp_enable, oproj_tp_enable) _HCOMM_INFO = None @@ -150,6 +148,9 @@ class AscendRowParallelLinear(RowParallelLinear): comm_group = get_tp_group() self.forward_type = "matmul_allreduce" self.hcomm_info = self.get_hcomm_info(comm_group.device_group) + elif dense_optim_enable(): + comm_group = get_tp_group() + self.forward_type = "dense_optim" else: comm_group = get_tp_group() self.forward_type = "normal" @@ -231,6 +232,8 @@ class AscendRowParallelLinear(RowParallelLinear): return self._forward_mlp_tp(input_) elif self.forward_type == "matmul_allreduce": return self._forward_matmul_allreduce(input_) + elif self.forward_type == "dense_optim": + return self._forward_dense_optim(input_) else: return super().forward(input_) @@ -332,6 +335,39 @@ class AscendRowParallelLinear(RowParallelLinear): return output return output, output_bias + def _forward_dense_optim( + self, input_: torch.Tensor + ) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[Parameter]]]: + """Linear layer with column parallelism. + + Implemented multiple optimization projects for dense models, such as FlashComm and + communication-computation fusion. + """ + + if self.input_is_parallel: + input_parallel = input_ + else: + splitted_input = split_tensor_along_last_dim( + input_, num_partitions=self.tp_size) + input_parallel = splitted_input[self.tp_rank].contiguous() + + assert self.quant_method is not None + bias_ = None if (self.tp_rank > 0 or self.skip_bias_add) else self.bias + + if self.tp_size == 1 or not self.reduce_results: + output = self.quant_method.apply(self, input_parallel, bias=bias_) + else: + output_parallel = self.quant_method.apply(self, + input_parallel, + bias=bias_) + output = torch.ops.vllm.maybe_pad_and_reduce(output_parallel) + + output_bias = self.bias if self.skip_bias_add else None + + if not self.return_bias: + return output + return output, output_bias + class AscendMergedColumnParallelLinear(MergedColumnParallelLinear): """Packed linear layers with column parallelism. @@ -357,15 +393,18 @@ class AscendMergedColumnParallelLinear(MergedColumnParallelLinear): *, return_bias: bool = True, ): - self.comm_group = None if prefix.find("gate_up_proj") != -1 and mlp_tp_enable(): - self.comm_group = get_mlp_tp_group() + comm_group = get_mlp_tp_group() self.forward_type = "mlp_tp" + elif dense_optim_enable(): + comm_group = get_tp_group() + self.forward_type = "dense_optim" else: - self.comm_group = get_tp_group() + comm_group = get_tp_group() self.forward_type = "normal_tp" - self.tp_rank = self.comm_group.rank_in_group - self.tp_size = self.comm_group.world_size + self.comm_group = comm_group + self.tp_rank = comm_group.rank_in_group + self.tp_size = comm_group.world_size self.output_sizes = output_sizes assert all(output_size % self.tp_size == 0 @@ -387,6 +426,8 @@ class AscendMergedColumnParallelLinear(MergedColumnParallelLinear): ) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[Parameter]]]: if self.forward_type == "mlp_tp": return self._forward_mlp_tp(input_) + elif self.forward_type == "dense_optim": + return self._forward_dense_optim(input_) else: return super().forward(input_) @@ -405,6 +446,138 @@ class AscendMergedColumnParallelLinear(MergedColumnParallelLinear): return output return output, output_bias + def _forward_dense_optim( + self, input_: torch.Tensor + ) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[Parameter]]]: + """Linear layer with column parallelism. + + Implemented multiple optimization projects for dense models, such as FlashComm and + communication-computation fusion. + """ + + bias = self.bias if not self.skip_bias_add else None + + # Matrix multiply. + assert self.quant_method is not None + + input_ = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(input_, True) + output_parallel = self.quant_method.apply(self, input_, bias) + + if self.gather_output: + # All-gather across the partitions. + output = self.comm_group.all_gather(output_parallel) + else: + output = output_parallel + output_bias = self.bias if self.skip_bias_add else None + if not self.return_bias: + return output + return output, output_bias + + +class AscendQKVParallelLinear(QKVParallelLinear): + """Linear layers for the attention's QKV transformation. + + Linear layers for the linear transformation of the query, key, and value + vectors in the attention layer. The weight matrix is concatenated along + the output dimension. The layer is parallelized along the head dimension. + When the number of key/value heads is smaller than the number of query + heads (e.g., multi-query/grouped-query attention), the key/value head may + be replicated while the query heads are partitioned. + """ + + def __init__( + self, + hidden_size: int, + head_size: int, + total_num_heads: int, + total_num_kv_heads: Optional[int] = None, + bias: bool = True, + skip_bias_add: bool = False, + params_dtype: Optional[torch.dtype] = None, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + *, + return_bias: bool = True, + ): + if dense_optim_enable(): + self.forward_type = "dense_optim" + else: + self.forward_type = "normal_tp" + self.comm_group = get_tp_group() + self.hidden_size = hidden_size + self.head_size = head_size + self.total_num_heads = total_num_heads + if total_num_kv_heads is None: + total_num_kv_heads = total_num_heads + self.total_num_kv_heads = total_num_kv_heads + # Divide the weight matrix along the last dimension. + tp_size = self.comm_group.world_size + self.num_heads = divide(self.total_num_heads, tp_size) + if tp_size >= self.total_num_kv_heads: + self.num_kv_heads = 1 + self.num_kv_head_replicas = divide(tp_size, + self.total_num_kv_heads) + else: + self.num_kv_heads = divide(self.total_num_kv_heads, tp_size) + self.num_kv_head_replicas = 1 + input_size = self.hidden_size + output_size = (self.num_heads + + 2 * self.num_kv_heads) * tp_size * self.head_size + self.output_sizes = [ + self.num_heads * self.head_size * tp_size, # q_proj + self.num_kv_heads * self.head_size * tp_size, # k_proj + self.num_kv_heads * self.head_size * tp_size, # v_proj + ] + AscendColumnParallelLinear.__init__(self, + input_size=input_size, + output_size=output_size, + bias=bias, + gather_output=False, + skip_bias_add=skip_bias_add, + params_dtype=params_dtype, + quant_config=quant_config, + prefix=prefix, + return_bias=return_bias) + + def forward( + self, + input_, + ) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[Parameter]]]: + if self.forward_type == "dense_optim": + return self._forward_dense_optim(input_) + else: + return super().forward(input_) + + def _forward_dense_optim( + self, input_: torch.Tensor + ) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[Parameter]]]: + """Linear layer with column parallelism. + + Implemented multiple optimization projects for dense models, such as FlashComm and + communication-computation fusion. + """ + + bias = self.bias if not self.skip_bias_add else None + + # Matrix multiply. + assert self.quant_method is not None + + layer_num = self.prefix.split('.')[2] + + input_ = torch.ops.vllm.maybe_all_gather_and_maybe_unpad( + input_, layer_num != '0') + output_parallel = self.quant_method.apply(self, input_, bias) + + if self.gather_output: + # All-gather across the partitions. + output = self.comm_group.all_gather(output_parallel) + else: + output = output_parallel + output_bias = self.bias if self.skip_bias_add else None + if not self.return_bias: + return output + return output, output_bias + class AscendLinearBase(LinearBase): @@ -438,4 +611,4 @@ class AscendLinearBase(LinearBase): self.quant_method = quant_config.get_quant_method(self, prefix=prefix) self.return_bias = return_bias - self.disable_tp = disable_tp + self.disable_tp = disable_tp \ No newline at end of file diff --git a/vllm_ascend/ops/register_custom_ops.py b/vllm_ascend/ops/register_custom_ops.py new file mode 100644 index 0000000..0391258 --- /dev/null +++ b/vllm_ascend/ops/register_custom_ops.py @@ -0,0 +1,63 @@ +import torch +import torch.nn.functional as F +from vllm.distributed import (get_tensor_model_parallel_rank, + get_tensor_model_parallel_world_size, + tensor_model_parallel_all_gather, + tensor_model_parallel_all_reduce, + tensor_model_parallel_reduce_scatter) +from vllm.forward_context import get_forward_context +from vllm.utils import direct_register_custom_op + + +def _maybe_chunk_residual_impl(x: torch.Tensor, + residual: torch.Tensor) -> torch.Tensor: + if get_forward_context().flashcomm_v1_enabled: + pad_size = get_forward_context().pad_size + if pad_size > 0: + residual = F.pad(residual, (0, 0, 0, pad_size)) + tp_size = get_tensor_model_parallel_world_size() + tp_rank = get_tensor_model_parallel_rank() + residual = torch.chunk(residual, tp_size, dim=0)[tp_rank] + + return residual + + +def _maybe_all_gather_and_maybe_unpad_impl(x: torch.Tensor, + label: bool) -> torch.Tensor: + flashcomm_v1_enabled = get_forward_context().flashcomm_v1_enabled + if flashcomm_v1_enabled and label: + x = tensor_model_parallel_all_gather(x, 0) + pad_size = get_forward_context().pad_size + if pad_size > 0: + x = x[:-pad_size, :] + return x + + +def _maybe_pad_and_reduce_impl(x: torch.Tensor) -> torch.Tensor: + flashcomm_v1_enabled = get_forward_context().flashcomm_v1_enabled + if flashcomm_v1_enabled: + pad_size = get_forward_context().pad_size + if pad_size > 0: + x = F.pad(x, (0, 0, 0, pad_size)) + return tensor_model_parallel_reduce_scatter(x, 0) + else: + return tensor_model_parallel_all_reduce(x) + + +direct_register_custom_op(op_name="maybe_chunk_residual", + op_func=_maybe_chunk_residual_impl, + fake_impl=lambda x, residual: residual, + mutates_args=[], + dispatch_key="PrivateUse1") + +direct_register_custom_op(op_name="maybe_all_gather_and_maybe_unpad", + op_func=_maybe_all_gather_and_maybe_unpad_impl, + fake_impl=lambda x, label: x, + mutates_args=[], + dispatch_key="PrivateUse1") + +direct_register_custom_op(op_name="maybe_pad_and_reduce", + op_func=_maybe_pad_and_reduce_impl, + fake_impl=lambda x: x, + mutates_args=[], + dispatch_key="PrivateUse1") \ No newline at end of file diff --git a/vllm_ascend/utils.py b/vllm_ascend/utils.py index 2959310..8813b68 100644 --- a/vllm_ascend/utils.py +++ b/vllm_ascend/utils.py @@ -493,6 +493,7 @@ def register_ascend_customop(): from vllm_ascend.ops.activation import AscendQuickGELU, AscendSiluAndMul from vllm_ascend.ops.linear import (AscendColumnParallelLinear, AscendMergedColumnParallelLinear, + AscendQKVParallelLinear, AscendRowParallelLinear) from vllm_ascend.ops.rotary_embedding import ( AscendDeepseekScalingRotaryEmbedding, AscendRotaryEmbedding) @@ -510,6 +511,8 @@ def register_ascend_customop(): name="RowParallelLinear") CustomOp.register_oot(_decorated_op_cls=AscendMergedColumnParallelLinear, name="MergedColumnParallelLinear") + CustomOp.register_oot(_decorated_op_cls=AscendQKVParallelLinear, + name="QKVParallelLinear") CustomOp.register_oot( _decorated_op_cls=AscendDeepseekScalingRotaryEmbedding, name="DeepseekScalingRotaryEmbedding") @@ -572,3 +575,7 @@ def mlp_tp_enable() -> bool: def matmul_allreduce_enable() -> bool: return envs_ascend.VLLM_ASCEND_ENABLE_MATMUL_ALLREDUCE + + +def dense_optim_enable() -> bool: + return envs_ascend.VLLM_ASCEND_ENABLE_DENSE_OPTIMIZE diff --git a/vllm_ascend/worker/model_runner_v1.py b/vllm_ascend/worker/model_runner_v1.py index b80fa94..ba5b439 100644 --- a/vllm_ascend/worker/model_runner_v1.py +++ b/vllm_ascend/worker/model_runner_v1.py @@ -37,6 +37,7 @@ from vllm.attention.layer import Attention from vllm.compilation.counter import compilation_counter from vllm.compilation.monitor import set_cudagraph_capturing_enabled from vllm.config import CompilationLevel, CUDAGraphMode, VllmConfig +from vllm.distributed import tensor_model_parallel_all_gather from vllm.distributed.kv_transfer import (get_kv_transfer_group, has_kv_transfer_group) from vllm.distributed.kv_transfer.kv_connector.v1 import KVConnectorBase_V1 @@ -1182,6 +1183,11 @@ class NPUModelRunner(LoRAModelRunnerMixin): intermediate_tensors=intermediate_tensors, inputs_embeds=inputs_embeds, ) + if get_forward_context().flashcomm_v1_enabled: + hidden_states = tensor_model_parallel_all_gather(hidden_states, 0) + pad_size = get_forward_context().pad_size + if pad_size > 0: + hidden_states = hidden_states[:-pad_size, :] return hidden_states def _build_attn_state(self, num_reqs, num_scheduled_tokens,