From 0a62e671fb0244f799137d3b4fae471dfd74f91d Mon Sep 17 00:00:00 2001
From: Levi <54832289+Levi-JQ@users.noreply.github.com>
Date: Mon, 10 Nov 2025 11:01:45 +0800
Subject: [PATCH] [Feat] flashcomm_v2 optim solution (#3232)

### What this PR does / why we need it?
Supports the generalized FlashComm2 optimization, which reduces
communication overhead, removes redundant RmsNorm computation, and saves
one AllGather step by replacing the AllReduce in the Attention module
with a pre-AlltoAll plus post-AllGather pattern (used in combination
with FlashComm1). The feature takes effect in the Prefill phase and is
recommended together with FlashComm1; it delivers broad performance
gains, especially for long sequences with large tensor-parallel (TP)
sizes. Benchmarks under a TP16DP1 configuration show an additional 8%
prefill improvement for the DeepSeek model on top of FlashComm1.
Illustrative sketches of the communication pattern are appended after
the diff.

### Does this PR introduce _any_ user-facing change?
Yes. It adds a new environment variable,
VLLM_ASCEND_FLASHCOMM2_PARALLEL_SIZE: 0 (the default) disables
FlashComm2, and any value >= 1 enables it and is used as the o_proj
tensor-parallel group size.

### How was this patch tested?
- vLLM version: v0.11.0
- vLLM main:
https://github.com/vllm-project/vllm/commit/83f478bb19489b41e9d208b47b4bb5a95ac171ac

---------

Signed-off-by: zzhxx <2783294813@qq.com>
Signed-off-by: Levi-JQ
Co-authored-by: Levi-JQ
Co-authored-by: zzhxx <2783294813@qq.com>
---
 .github/workflows/_e2e_test.yaml              |   1 +
 .../test_offline_inference_distributed.py     |  20 +++
 tests/e2e/singlecard/test_aclgraph_mem.py     |   1 +
 tests/ut/distributed/test_parallel_state.py   |  27 +++-
 vllm_ascend/ascend_config.py                  |   4 +
 vllm_ascend/ascend_forward_context.py         |  16 +-
 vllm_ascend/distributed/parallel_state.py     |  72 ++++++++-
 vllm_ascend/envs.py                           |   8 +-
 vllm_ascend/ops/linear_op.py                  | 145 +++++++++++++++++-
 vllm_ascend/quantization/quant_config.py      |  15 +-
 vllm_ascend/quantization/w8a8.py              |  30 +++-
 vllm_ascend/utils.py                          |  65 ++++++++
 12 files changed, 380 insertions(+), 24 deletions(-)

diff --git a/.github/workflows/_e2e_test.yaml b/.github/workflows/_e2e_test.yaml
index 6bbd4ba6..cd3e3e58 100644
--- a/.github/workflows/_e2e_test.yaml
+++ b/.github/workflows/_e2e_test.yaml
@@ -195,6 +195,7 @@ jobs:
           pytest -sv tests/e2e/multicard/test_offline_inference_distributed.py::test_models_distributed_Qwen3_W4A8DYNAMIC_old_version
           pytest -sv tests/e2e/multicard/test_offline_inference_distributed.py::test_models_distributed_DeepSeek_W4A8DYNAMIC
           pytest -sv tests/e2e/multicard/test_offline_inference_distributed.py::test_sp_for_qwen3_moe
+          pytest -sv tests/e2e/multicard/test_offline_inference_distributed.py::test_fc2_for_qwen3_moe
           pytest -sv tests/e2e/multicard/test_offline_inference_distributed.py::test_models_distributed_Qwen_Dense_with_flashcomm_v1
           pytest -sv tests/e2e/multicard/test_offline_inference_distributed.py::test_models_distributed_Qwen_Dense_with_prefetch_mlp_weight
diff --git a/tests/e2e/multicard/test_offline_inference_distributed.py b/tests/e2e/multicard/test_offline_inference_distributed.py
index a8102ec7..320c3bdf 100644
--- a/tests/e2e/multicard/test_offline_inference_distributed.py
+++ b/tests/e2e/multicard/test_offline_inference_distributed.py
@@ -189,6 +189,26 @@ def test_sp_for_qwen3_moe() -> None:
         vllm_model.generate(example_prompts, sampling_params)
 
 
+@patch.dict(os.environ, {"VLLM_ASCEND_ENABLE_FLASHCOMM1": "1"})
+@patch.dict(os.environ, {"VLLM_ASCEND_FLASHCOMM2_PARALLEL_SIZE": "1"})
+def test_fc2_for_qwen3_moe() -> None:
+    example_prompts = [
+        "Hello, my name is",
+    ]
+    sampling_params = SamplingParams(max_tokens=5,
+                                     temperature=0.0,
+                                     top_k=50,
+                                     top_p=0.9)
+
+    with VllmRunner(snapshot_download("Qwen/Qwen3-30B-A3B"),
dtype="auto", + tensor_parallel_size=2, + distributed_executor_backend="mp", + enable_expert_parallel=True, + enforce_eager=True) as vllm_model: + vllm_model.generate(example_prompts, sampling_params) + + @patch.dict(os.environ, {"VLLM_ASCEND_ENABLE_FLASHCOMM1": "1"}) def test_models_distributed_deepseek_v2_lite_with_flashcomm_v1() -> None: example_prompts = [ diff --git a/tests/e2e/singlecard/test_aclgraph_mem.py b/tests/e2e/singlecard/test_aclgraph_mem.py index c7e50788..df7d355e 100644 --- a/tests/e2e/singlecard/test_aclgraph_mem.py +++ b/tests/e2e/singlecard/test_aclgraph_mem.py @@ -34,6 +34,7 @@ MODELS = ["Qwen/Qwen3-0.6B", "vllm-ascend/DeepSeek-V2-Lite-W8A8"] reason="aclgraph only support on v1") @pytest.mark.parametrize("model", MODELS) @pytest.mark.parametrize("max_tokens", [4]) +@patch.dict(os.environ, {"VLLM_ASCEND_FLASHCOMM2_PARALLEL_SIZE": "0"}) @patch.dict(os.environ, {"ASCEND_RT_VISIBLE_DEVICES": "0,1"}) def test_aclgraph_mem_use(model: str, max_tokens: int) -> None: del os.environ["VLLM_WORKER_MULTIPROC_METHOD"] diff --git a/tests/ut/distributed/test_parallel_state.py b/tests/ut/distributed/test_parallel_state.py index c6724ce0..15a5c509 100644 --- a/tests/ut/distributed/test_parallel_state.py +++ b/tests/ut/distributed/test_parallel_state.py @@ -4,9 +4,10 @@ import pytest from vllm.config import ParallelConfig from vllm_ascend.distributed.parallel_state import ( - _LMTP, _MC2, _OTP, _P_TP, destroy_ascend_model_parallel, - get_lmhead_tp_group, get_mc2_group, get_otp_group, get_p_tp_group, - init_ascend_model_parallel) + _FLASHCOMM2_ODP, _FLASHCOMM2_OTP, _LMTP, _MC2, _OTP, _P_TP, + destroy_ascend_model_parallel, get_flashcomm2_odp_group, + get_flashcomm2_otp_group, get_lmhead_tp_group, get_mc2_group, + get_otp_group, get_p_tp_group, init_ascend_model_parallel) @pytest.fixture @@ -21,9 +22,13 @@ def mock_distributed(): with patch('torch.distributed.is_initialized', return_value=True), \ patch('torch.distributed.get_world_size', return_value=8), \ patch('torch.distributed.get_backend', return_value='nccl'), \ - patch('vllm_ascend.distributed.parallel_state.get_world_group') as mock_group: + patch('vllm_ascend.distributed.parallel_state.get_world_group') as mock_group, \ + patch('vllm_ascend.distributed.parallel_state.get_tp_group') as mock_tp_group, \ + patch('vllm_ascend.distributed.parallel_state.get_dp_group') as mock_dp_group: mock_group.return_value.local_rank = 0 mock_group.return_value.device_group = MagicMock() + mock_tp_group.return_value.world_size = 4 + mock_dp_group.return_value.world_size = 2 yield @@ -31,23 +36,33 @@ def test_init_ascend_model_parallel(mock_distributed, parallel_config): mock_ascend_config = MagicMock() mock_ascend_config.lmhead_tensor_parallel_size = 2 mock_ascend_config.oproj_tensor_parallel_size = 2 + mock_ascend_config.flashcomm2_oproj_tensor_parallel_size = 2 mock_ascend_config.pd_tp_ratio = 2 mock_ascend_config.num_head_replica = 0 mock_ascend_config.pd_head_ratio = 2 mock_vllm_config = MagicMock() mock_vllm_config.kv_transfer_config.is_kv_producer = True + mock_envs_ascend = MagicMock() + mock_envs_ascend.VLLM_ASCEND_FLASHCOMM2_PARALLEL_SIZE = 2 + mock_envs_ascend.VLLM_ASCEND_ENABLE_CONTEXT_PARALLEL = 0 with patch('vllm_ascend.distributed.parallel_state.model_parallel_initialized', return_value=False), \ patch('vllm_ascend.distributed.parallel_state.init_model_parallel_group'), \ patch('vllm_ascend.distributed.parallel_state.get_current_vllm_config', return_value=mock_vllm_config), \ - 
patch('vllm_ascend.distributed.parallel_state.get_ascend_config', return_value=mock_ascend_config): + patch('vllm_ascend.distributed.parallel_state.get_ascend_config', return_value=mock_ascend_config), \ + patch('vllm_ascend.utils.envs_ascend', new=mock_envs_ascend), \ + patch('vllm_ascend.utils.get_ascend_config', return_value=mock_ascend_config): init_ascend_model_parallel(parallel_config) mc2_group = get_mc2_group() lmheadtp_group = get_lmhead_tp_group() otp_group = get_otp_group() + flashcomm2_otp_group = get_flashcomm2_otp_group() + flashcomm2_odp_group = get_flashcomm2_odp_group() p_tp_group = get_p_tp_group() assert mc2_group is not None assert otp_group is not None + assert flashcomm2_otp_group is not None + assert flashcomm2_odp_group is not None assert lmheadtp_group is not None assert p_tp_group is not None @@ -55,4 +70,6 @@ def test_init_ascend_model_parallel(mock_distributed, parallel_config): assert _MC2 is None assert _LMTP is None assert _OTP is None + assert _FLASHCOMM2_OTP is None + assert _FLASHCOMM2_ODP is None assert _P_TP is None diff --git a/vllm_ascend/ascend_config.py b/vllm_ascend/ascend_config.py index 82eb78ea..f947fc62 100644 --- a/vllm_ascend/ascend_config.py +++ b/vllm_ascend/ascend_config.py @@ -130,6 +130,10 @@ class AscendConfig: "Only support P node tp size lagger then D node tp size") self.SLO_limits_for_dynamic_batch = additional_config.get( "SLO_limits_for_dynamic_batch", -1) + from vllm_ascend.utils import \ + get_flashcomm2_oproj_tp_size_and_validate_config + self.flashcomm2_oproj_tensor_parallel_size = get_flashcomm2_oproj_tp_size_and_validate_config( + self, vllm_config) class TorchairGraphConfig: diff --git a/vllm_ascend/ascend_forward_context.py b/vllm_ascend/ascend_forward_context.py index a0f1edd5..65bc5a47 100644 --- a/vllm_ascend/ascend_forward_context.py +++ b/vllm_ascend/ascend_forward_context.py @@ -11,7 +11,8 @@ from vllm.forward_context import (BatchDescriptor, get_forward_context, set_forward_context) import vllm_ascend.envs as envs_ascend -from vllm_ascend.utils import enable_sp, has_layer_idx, is_moe_model +from vllm_ascend.utils import (enable_sp, flashcomm2_enable, has_layer_idx, + is_moe_model) if TYPE_CHECKING: from vllm_ascend.ops.weight_prefetch import WeightPrefetchMethod @@ -121,13 +122,17 @@ def set_ascend_forward_context( sp_enabled = enable_sp(vllm_config) and \ num_tokens is not None and num_tokens > 1000 forward_context.mmrs_fusion = mmrs_fusion + forward_context.num_tokens = num_tokens + forward_context.sp_enabled = sp_enabled + #TODO(Levi-JQ): another PR to normalize the enabling logic for sp/fc2 + forward_context.flashcomm_v2_enabled = flashcomm2_enable( + ) and tp_world_size > 1 and num_tokens is not None - if sp_enabled: + if (forward_context.sp_enabled + or forward_context.flashcomm_v2_enabled): pad_size = (tp_world_size - (num_tokens % tp_world_size)) % tp_world_size forward_context.pad_size = pad_size - forward_context.sp_enabled = sp_enabled - forward_context.num_tokens = num_tokens # set this for rope forward_oot using forward_context.is_first_layer = True @@ -179,7 +184,8 @@ def set_ascend_forward_context( if dp_world_size > 1 and forward_context.dp_metadata is not None: max_tokens_across_dp = \ forward_context.dp_metadata.max_tokens_across_dp_cpu.item() - if sp_enabled: + if (forward_context.sp_enabled + or forward_context.flashcomm_v2_enabled): padded_length = (max_tokens_across_dp + tp_world_size - 1) // tp_world_size * tp_world_size pad_size = padded_length - num_tokens diff --git 
a/vllm_ascend/distributed/parallel_state.py b/vllm_ascend/distributed/parallel_state.py index 4885d4d1..9b5dde0f 100644 --- a/vllm_ascend/distributed/parallel_state.py +++ b/vllm_ascend/distributed/parallel_state.py @@ -2,12 +2,14 @@ from typing import Optional import torch from vllm.config import ParallelConfig, get_current_vllm_config -from vllm.distributed.parallel_state import (GroupCoordinator, get_world_group, +from vllm.distributed.parallel_state import (GroupCoordinator, get_dp_group, + get_tp_group, get_world_group, init_model_parallel_group) import vllm_ascend.envs as envs_ascend from vllm_ascend.ascend_config import get_ascend_config -from vllm_ascend.utils import prefill_context_parallel_enable +from vllm_ascend.utils import (flashcomm2_enable, + prefill_context_parallel_enable) # Currently, mc2 op need their own group coordinator. _MC2: Optional[GroupCoordinator] = None @@ -15,6 +17,8 @@ _MLP_TP: Optional[GroupCoordinator] = None _OTP: Optional[GroupCoordinator] = None _LMTP: Optional[GroupCoordinator] = None _P_TP: Optional[GroupCoordinator] = None +_FLASHCOMM2_OTP: Optional[GroupCoordinator] = None +_FLASHCOMM2_ODP: Optional[GroupCoordinator] = None def get_mc2_group() -> GroupCoordinator: @@ -34,6 +38,16 @@ def get_lmhead_tp_group() -> GroupCoordinator: return _LMTP +def get_flashcomm2_otp_group() -> GroupCoordinator: + return _FLASHCOMM2_OTP + + +def get_flashcomm2_odp_group() -> GroupCoordinator: + assert _FLASHCOMM2_ODP is not None, ( + "output data parallel group for flashcomm2 is not initialized") + return _FLASHCOMM2_ODP + + def get_mlp_tp_group() -> GroupCoordinator: assert _MLP_TP is not None, ("mlp group is not initialized") return _MLP_TP @@ -165,6 +179,48 @@ def init_ascend_model_parallel(parallel_config: ParallelConfig, ): backend, group_name="lmheadtp") + # TODO: Extract and unify the logic across different communication group. 
+ if flashcomm2_enable(): + flashcomm2_otp_size = get_ascend_config( + ).flashcomm2_oproj_tensor_parallel_size + global_tp_size = get_tp_group().world_size + global_dp_size = get_dp_group().world_size + num_fc2_oproj_tensor_parallel_groups: int = (global_tp_size // + flashcomm2_otp_size) + + global _FLASHCOMM2_OTP + global _FLASHCOMM2_ODP + + _FLASHCOMM2_OTP = None + _FLASHCOMM2_ODP = get_tp_group() + + if flashcomm2_otp_size > 1: + otp_group_ranks = [] + odp_group_ranks: list[list[int]] = [ + [] for _ in range(flashcomm2_otp_size * global_dp_size) + ] + + for dp_group_index in range(global_dp_size): + for i in range(num_fc2_oproj_tensor_parallel_groups): + ranks = [] + for j in range(flashcomm2_otp_size): + rank_idx = dp_group_index * global_tp_size + i + j * num_fc2_oproj_tensor_parallel_groups + ranks.append(rank_idx) + odp_group_index = dp_group_index * flashcomm2_otp_size + j + odp_group_ranks[odp_group_index].append(rank_idx) + otp_group_ranks.append(ranks) + + _FLASHCOMM2_OTP = init_model_parallel_group( + otp_group_ranks, + get_world_group().local_rank, + backend, + group_name="flashcomm2_otp") + _FLASHCOMM2_ODP = init_model_parallel_group( + odp_group_ranks, + get_world_group().local_rank, + backend, + group_name="flashcomm2_odp") + def get_mlp_tensor_model_parallel_world_size(): """Return world size for the tensor model parallel group.""" @@ -201,3 +257,15 @@ def destroy_ascend_model_parallel(): if _P_TP: _P_TP.destroy() _P_TP = None + + global _FLASHCOMM2_OTP + if _FLASHCOMM2_OTP and get_ascend_config( + ).flashcomm2_oproj_tensor_parallel_size != 1: + _FLASHCOMM2_OTP.destroy() + _FLASHCOMM2_OTP = None + + global _FLASHCOMM2_ODP + if _FLASHCOMM2_ODP and get_ascend_config( + ).flashcomm2_oproj_tensor_parallel_size != 1: + _FLASHCOMM2_ODP.destroy() + _FLASHCOMM2_ODP = None diff --git a/vllm_ascend/envs.py b/vllm_ascend/envs.py index a6b4081a..fdadfa24 100644 --- a/vllm_ascend/envs.py +++ b/vllm_ascend/envs.py @@ -132,6 +132,12 @@ env_variables: Dict[str, Callable[[], Any]] = { # This feature will get better performance when concurrency is large. "VLLM_ASCEND_ENABLE_FLASHCOMM1": lambda: bool(int(os.getenv("VLLM_ASCEND_ENABLE_FLASHCOMM1", '0'))), + # Whether to enable FLASHCOMM2. Setting it to 0 disables the feature, while setting it to 1 or above enables it. + # The specific value set will be used as the O-matrix TP group size for flashcomm2. + # For a detailed introduction to the parameters and the differences and applicable scenarios + # between this feature and FLASHCOMM1, please refer to the feature guide in the documentation. + "VLLM_ASCEND_FLASHCOMM2_PARALLEL_SIZE": + lambda: int(os.getenv("VLLM_ASCEND_FLASHCOMM2_PARALLEL_SIZE", 0)), # Whether to enable MLP weight prefetch, only used in small concurrency. 
"VLLM_ASCEND_ENABLE_PREFETCH_MLP": lambda: bool(int(os.getenv("VLLM_ASCEND_ENABLE_PREFETCH_MLP", '0'))), @@ -185,4 +191,4 @@ def __getattr__(name: str): def __dir__(): - return list(env_variables.keys()) \ No newline at end of file + return list(env_variables.keys()) diff --git a/vllm_ascend/ops/linear_op.py b/vllm_ascend/ops/linear_op.py index 1271f8e9..2bffa44c 100644 --- a/vllm_ascend/ops/linear_op.py +++ b/vllm_ascend/ops/linear_op.py @@ -24,6 +24,7 @@ CustomLinearOp └── CustomRowParallelOp │ ├── MLPRowParallelOp │ ├── OProjRowParallelOp +| ├── Flashcomm2OProjRowParallelOp │ ├── MatmulAllreduceRowParallelOp │ └── SequenceRowParallelOp └── CustomReplicatedOp @@ -41,6 +42,7 @@ import torch import torch.distributed as dist import torch.nn.functional as F import torch_npu +from torch import nn from torch.distributed import ProcessGroup from torch.nn.parameter import Parameter from vllm.distributed import (split_tensor_along_last_dim, @@ -49,9 +51,14 @@ from vllm.distributed import (split_tensor_along_last_dim, from vllm.distributed.parallel_state import get_tp_group from vllm.forward_context import get_forward_context -from vllm_ascend.distributed.parallel_state import (get_mlp_tp_group, +from vllm_ascend.ascend_config import get_ascend_config +from vllm_ascend.distributed.parallel_state import (get_flashcomm2_odp_group, + get_flashcomm2_otp_group, + get_mlp_tp_group, get_otp_group) from vllm_ascend.utils import (dense_optim_enable, enable_sp, + flashcomm2_enable, + get_flashcomm2_reorgnized_batch_ids, matmul_allreduce_enable, mlp_tp_enable, oproj_tp_enable, shared_expert_dp_enabled) @@ -263,6 +270,135 @@ class OProjRowParallelOp(CustomRowParallelOp): self.input_size_per_partition = self.layer.input_size_per_partition +class Flashcomm2OProjRowParallelOp(CustomRowParallelOp): + + def __init__(self, layer): + super().__init__(layer) + self.odp_group = get_flashcomm2_odp_group() + self.odp_size = self.odp_group.world_size + self.reorgnized_batch_ids = get_flashcomm2_reorgnized_batch_ids( + get_tp_group().world_size) + self.group_indices = torch.tensor(self.reorgnized_batch_ids).npu() + self.layer._quant_comm_config = {} + + @property + def comm_group(self): + return get_flashcomm2_otp_group() + + @property + def tp_rank(self): + if get_ascend_config().flashcomm2_oproj_tensor_parallel_size == 1: + return 0 + return self.comm_group.rank_in_group + + @property + def tp_size(self): + if get_ascend_config().flashcomm2_oproj_tensor_parallel_size == 1: + return 1 + return self.comm_group.world_size + + def apply_impl( + self, + input_: torch.Tensor, + ) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[Parameter]]]: + """Linear layer for Flashcomm2. + Input.ahspe = [batchsize*seqlength, headnum*headdim/TP] + Output.shape = [(batchsize*seqlength+padsize)/TP, hiddensize] + """ + # Handle input parallelism - split or use as-is + if self.input_is_parallel: + input_parallel = input_ + else: + tp_rank = self.tp_rank + splitted_input = split_tensor_along_last_dim( + input_, num_partitions=self.tp_size) + input_parallel = splitted_input[tp_rank].contiguous() + + # padding for all-to-all + forward_context = get_forward_context() + num_padding_tokens = forward_context.pad_size + if num_padding_tokens > 0: + input_parallel = nn.functional.pad(input_parallel, + (0, 0, 0, num_padding_tokens)) + + def otp_maybe_quant_comm(x): + + # Reorganize the tensor so that the batch id and rank id correspond to each other. 
+            chunk_num = len(self.reorgnized_batch_ids) * len(
+                self.reorgnized_batch_ids[0])
+            batch_size = x.size(0)
+
+            assert batch_size % chunk_num == 0, f"batch_size({batch_size}) must be divisible by chunk_num({chunk_num})"
+
+            batch_size_per_chunk = batch_size // chunk_num
+            # Reorder the chunks according to the reorganized batch ids
+            chunked = x.view(chunk_num, batch_size_per_chunk, x.shape[1])
+            reorganized_chunks = chunked[self.group_indices]
+            send_buf = reorganized_chunks.flatten(1, 2)
+
+            # all-to-all operation parameters
+            all2all_tp_size = self.odp_size
+            local_intermediate_size = x.size(1)
+            chunk_size = x.size(0) // all2all_tp_size
+            total_intermediate_size = local_intermediate_size * all2all_tp_size
+
+            # Create receive buffer
+            recv_buf = torch.empty(total_intermediate_size * chunk_size,
+                                   dtype=x.dtype,
+                                   device=x.device)
+
+            # Perform all-to-all communication
+            dist.all_to_all_single(recv_buf,
+                                   send_buf,
+                                   group=self.odp_group.device_group)
+
+            return recv_buf.view(all2all_tp_size, chunk_size,
+                                 -1).transpose(0, 1).reshape(chunk_size, -1)
+
+        if not hasattr(self.layer, "_quant_comm_config"):
+            self.layer._quant_comm_config = {}
+        self.layer._quant_comm_config[
+            "communication_fn"] = otp_maybe_quant_comm
+        actual_quant_method = getattr(self.quant_method, 'quant_method',
+                                      self.quant_method)
+        from vllm_ascend.quantization.w8a8 import AscendW8A8LinearMethod
+        if not isinstance(actual_quant_method, AscendW8A8LinearMethod):
+            # w8a8 quantization is not enabled, so communicate immediately.
+            input_parallel = otp_maybe_quant_comm(input_parallel)
+
+        # Matrix multiply.
+        assert self.quant_method is not None
+        # Only fuse bias add into GEMM for rank 0 (this ensures that
+        # bias will not get added more than once in TP>1 case)
+        bias_ = None if (self.tp_rank > 0 or self.skip_bias_add) else self.bias
+
+        output_parallel = self.quant_method.apply(self.layer,
+                                                  input_parallel,
+                                                  bias=bias_)
+        # output_parallel shape: [bs/(TP/flashcomm2_otp_size), hidden_size]
+        if self.tp_size > 1:
+            # flashcomm2 with reduce-scatter
+            output = self.comm_group.reduce_scatter(output_parallel, dim=0)
+        else:
+            output = output_parallel
+
+        if not forward_context.sp_enabled:
+            # flashcomm1 not enabled
+            output = get_tp_group().all_gather(output, 0)
+            if num_padding_tokens > 0:
+                output = output[:-num_padding_tokens]
+
+        # Handle bias return based on configuration
+        output_bias = self.bias if self.skip_bias_add else None
+
+        return output, output_bias
+
+    def update_attrs(self):
+        super().update_attrs()
+        self.input_is_parallel = self.layer.input_is_parallel
+        self.input_size_per_partition = self.layer.input_size_per_partition
+
+
 class MatmulAllreduceRowParallelOp(CustomRowParallelOp):
     _HCOMM_INFO = None
@@ -487,13 +623,17 @@ def _get_column_parallel_op(
 def _get_row_parallel_op(
     prefix, layer
 ) -> Optional[Union[MLPRowParallelOp, OProjRowParallelOp,
-                    MatmulAllreduceRowParallelOp, SequenceRowParallelOp]]:
+                    Flashcomm2OProjRowParallelOp, MatmulAllreduceRowParallelOp,
+                    SequenceRowParallelOp]]:
     if "down_proj" in prefix and mlp_tp_enable():
         return MLPRowParallelOp(layer)
     if "o_proj" in prefix and oproj_tp_enable():
         return OProjRowParallelOp(layer)
     if matmul_allreduce_enable():
         return MatmulAllreduceRowParallelOp(layer)
+    if flashcomm2_enable():
+        if "o_proj" in prefix or "out_proj" in prefix:
+            return Flashcomm2OProjRowParallelOp(layer)
     if enable_sp():
         if "shared_expert" in prefix:
             return None
@@ -509,6 +649,7 @@ def get_parallel_op(disable_tp, prefix, layer, direct):
         return None, 0, 1
 
     custom_op: Optional[Union[MLPColumnParallelOp,
SequenceColumnParallelOp, MLPRowParallelOp, OProjRowParallelOp, + Flashcomm2OProjRowParallelOp, MatmulAllreduceRowParallelOp, SequenceRowParallelOp]] = None if direct == "row": diff --git a/vllm_ascend/quantization/quant_config.py b/vllm_ascend/quantization/quant_config.py index 5960d2f8..c0760c80 100644 --- a/vllm_ascend/quantization/quant_config.py +++ b/vllm_ascend/quantization/quant_config.py @@ -35,12 +35,14 @@ from vllm.model_executor.layers.vocab_parallel_embedding import ( from vllm.model_executor.parameter import PerTensorScaleParameter from vllm.model_executor.utils import set_weight_attrs -from vllm_ascend.distributed.parallel_state import (get_mlp_tp_group, +from vllm_ascend.ascend_config import get_ascend_config +from vllm_ascend.distributed.parallel_state import (get_flashcomm2_otp_group, + get_mlp_tp_group, get_otp_group) from vllm_ascend.ops.fused_moe.fused_moe import AscendUnquantizedFusedMoEMethod from vllm_ascend.ops.linear import AscendUnquantizedLinearMethod -from vllm_ascend.utils import (ASCEND_QUANTIZATION_METHOD, mlp_tp_enable, - oproj_tp_enable) +from vllm_ascend.utils import (ASCEND_QUANTIZATION_METHOD, flashcomm2_enable, + mlp_tp_enable, oproj_tp_enable) from .utils import get_quant_method @@ -348,6 +350,13 @@ class AscendLinearMethod(LinearMethodBase): tp_rank = get_otp_group().rank_in_group elif layer.prefix.find("down_proj") != -1 and mlp_tp_enable(): tp_rank = get_mlp_tp_group().rank_in_group + elif (layer.prefix.find("o_proj") != -1 or + layer.prefix.find("out_proj") != -1) and flashcomm2_enable(): + if get_ascend_config( + ).flashcomm2_oproj_tensor_parallel_size == 1: + tp_rank = 0 + else: + tp_rank = get_flashcomm2_otp_group().rank_in_group else: tp_rank = get_tensor_model_parallel_rank() else: diff --git a/vllm_ascend/quantization/w8a8.py b/vllm_ascend/quantization/w8a8.py index 07b7cac2..dcd692ac 100644 --- a/vllm_ascend/quantization/w8a8.py +++ b/vllm_ascend/quantization/w8a8.py @@ -115,12 +115,30 @@ class AscendW8A8LinearMethod: weight=layer.weight, start_flag=x, ) - # quant - x = quant_per_tensor( - x, - layer.aclnn_input_scale_reciprocal, - layer.aclnn_input_offset, - ) + + quant_comm_config = getattr(layer, "_quant_comm_config", {}) + comm_fn = quant_comm_config.get("communication_fn") + enable_flashcomm2_quant_comm = comm_fn is not None and ( + "o_proj" in layer.prefix or "out_proj" in layer.prefix) + if enable_flashcomm2_quant_comm: + quant_input_x = x.contiguous().view( + -1, layer.aclnn_input_scale_reciprocal.size(0)) + quant_x = quant_per_tensor( + quant_input_x, + layer.aclnn_input_scale_reciprocal, + layer.aclnn_input_offset, + ) + comm_input = quant_x.view(x.size(0), -1) + assert comm_fn is not None + x = comm_fn(comm_input) + else: + # quant + x = quant_per_tensor( + x, + layer.aclnn_input_scale_reciprocal, + layer.aclnn_input_offset, + ) + # prefetch qkvo_proj.weight postprocess if weight_prefetch_method: weight_prefetch_method.maybe_prefetch_attn_weight_postprocess( diff --git a/vllm_ascend/utils.py b/vllm_ascend/utils.py index 46e80606..ec3c6f03 100644 --- a/vllm_ascend/utils.py +++ b/vllm_ascend/utils.py @@ -814,3 +814,68 @@ def has_layer_idx(model_instance: torch.nn.Module) -> bool: _HAS_LAYER_IDX = hasattr(model_instance, "model") and \ hasattr(model_instance.model, "start_layer") return _HAS_LAYER_IDX + + +def flashcomm2_enable() -> bool: + return envs_ascend.VLLM_ASCEND_FLASHCOMM2_PARALLEL_SIZE > 0 + + +def get_flashcomm2_oproj_tp_size_and_validate_config(ascend_config, + vllm_config): + flashcomm2_oproj_tp_size = 
envs_ascend.VLLM_ASCEND_FLASHCOMM2_PARALLEL_SIZE
+    global_tp_size = vllm_config.parallel_config.tensor_parallel_size
+
+    if not flashcomm2_enable():
+        logger.info("FLASHCOMM2 not enabled.")
+        return flashcomm2_oproj_tp_size
+
+    logger.info(
+        f"Enable FLASHCOMM2 with flashcomm2_oproj_tensor_parallel_size={flashcomm2_oproj_tp_size} and global_tp_size={global_tp_size}"
+    )
+    if not envs_ascend.VLLM_ASCEND_ENABLE_FLASHCOMM1:
+        logger.warning_once(
+            "It is recommended to enable FLASHCOMM1 together with FLASHCOMM2 for optimal performance."
+        )
+    if ascend_config.oproj_tensor_parallel_size is not None:
+        raise AssertionError(
+            "flashcomm2_oproj_tensor_parallel_size cannot be enabled simultaneously with oproj_tensor_parallel_size"
+        )
+    if global_tp_size <= flashcomm2_oproj_tp_size:
+        raise AssertionError(
+            f"flashcomm2_oproj_tensor_parallel_size ({flashcomm2_oproj_tp_size}) must be smaller than the global tensor parallel size ({global_tp_size})"
+        )
+    if global_tp_size % flashcomm2_oproj_tp_size != 0:
+        raise AssertionError(
+            f"Global tensor parallel size ({global_tp_size}) must be divisible by flashcomm2_oproj_tensor_parallel_size ({flashcomm2_oproj_tp_size})"
+        )
+    if vllm_config.kv_transfer_config is None:
+        logger.warning_once(
+            "It is recommended to enable FLASHCOMM2 in P-scenario deployments; enabling it in hybrid deployments may degrade decode performance."
+        )
+    if vllm_config.kv_transfer_config is not None and vllm_config.kv_transfer_config.is_kv_consumer:
+        raise AssertionError(
+            "FLASHCOMM2 primarily targets P-scenario deployments, "
+            "with additional support for hybrid deployment scenarios. "
+            "It is not applicable in D-scenario environments.")
+
+    return flashcomm2_oproj_tp_size
+
+
+def get_flashcomm2_reorgnized_batch_ids(global_tp_size) -> list[list[int]]:
+    # Reorganize batch_ids so that, after the all2all and reduce-scatter operations, each batch_id corresponds to the rank_id within the DP domain.
+    # For example, when DP = [0, 1, 2, ..., 15] and flashcomm2_oproj_tensor_parallel_size = 2,
+    # the reorganized batch_ids will be [[batch0, batch8], [batch1, batch9], ..., [batch7, batch15]].
+    flashcomm2_otp_size = get_ascend_config(
+    ).flashcomm2_oproj_tensor_parallel_size
+    num_oproj_tensor_parallel_groups: int = (global_tp_size //
+                                             flashcomm2_otp_size)
+
+    reorgnized_batch_ids = []
+    for i in range(num_oproj_tensor_parallel_groups):
+        ranks = []
+        for j in range(flashcomm2_otp_size):
+            rank_idx = i + j * num_oproj_tensor_parallel_groups
+            ranks.append(rank_idx)
+        reorgnized_batch_ids.append(ranks)
+
+    return reorgnized_batch_ids
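
### Illustrative sketches

The sketches below are editorial illustrations, not part of the patch.
FlashComm1 rests on the identity all_reduce(x) == all_gather(reduce_scatter(x));
FlashComm2 additionally routes the attention-output activations through an
AlltoAll into a smaller o_proj TP group before the matmul, with an AllGather
afterwards when FlashComm1 is off. A minimal single-process check of the
underlying identity, simulating per-rank partial sums with plain CPU tensors
(all sizes are made up):

import torch

tp = 4                   # tensor-parallel world size
tokens, hidden = 8, 16   # token count already padded to a multiple of tp
partials = [torch.randn(tokens, hidden) for _ in range(tp)]

# AllReduce: every rank ends up with the full sum of the partial outputs.
allreduce_out = sum(partials)

# ReduceScatter: rank r keeps only the r-th token shard of the sum...
shards = [sum(p.chunk(tp, dim=0)[r] for p in partials) for r in range(tp)]
# ...AllGather: concatenating the shards restores the AllReduce result.
allgather_out = torch.cat(shards, dim=0)

assert torch.allclose(allreduce_out, allgather_out, atol=1e-5)
print("reduce_scatter + all_gather == all_reduce: OK")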
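
Both padding branches in set_ascend_forward_context round the token count up
to a multiple of the TP world size so the token dimension can be sharded
evenly. The arithmetic, checked standalone with made-up sizes:

def pad_size(num_tokens: int, tp_world_size: int) -> int:
    # Same formula as forward_context.pad_size in the patch.
    return (tp_world_size - num_tokens % tp_world_size) % tp_world_size

assert pad_size(1000, 16) == 8   # 1000 + 8 = 1008 = 63 * 16
assert pad_size(1024, 16) == 0   # already a multiple of 16
# The DP branch reaches the same padded length by ceil-division:
assert (1000 + 16 - 1) // 16 * 16 == 1008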
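
The otp/odp group construction in init_ascend_model_parallel is pure index
arithmetic and can be previewed offline. A standalone sketch with
illustrative sizes (global TP 4, DP 2, flashcomm2_oproj_tensor_parallel_size
2), mirroring the rank formula from the patch:

global_tp_size, global_dp_size, flashcomm2_otp_size = 4, 2, 2
num_fc2_groups = global_tp_size // flashcomm2_otp_size

otp_group_ranks = []
odp_group_ranks = [[] for _ in range(flashcomm2_otp_size * global_dp_size)]
for dp_group_index in range(global_dp_size):
    for i in range(num_fc2_groups):
        ranks = []
        for j in range(flashcomm2_otp_size):
            rank_idx = (dp_group_index * global_tp_size + i +
                        j * num_fc2_groups)
            ranks.append(rank_idx)
            odp_group_ranks[dp_group_index * flashcomm2_otp_size +
                            j].append(rank_idx)
        otp_group_ranks.append(ranks)

print(otp_group_ranks)  # [[0, 2], [1, 3], [4, 6], [5, 7]]
print(odp_group_ranks)  # [[0, 1], [2, 3], [4, 5], [6, 7]]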
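
get_flashcomm2_reorgnized_batch_ids is likewise pure index arithmetic, so the
example in its comment can be verified without an NPU. A detached version
that takes the otp size as an argument instead of reading the ascend config:

def reorganized_batch_ids(global_tp_size: int,
                          otp_size: int) -> list[list[int]]:
    num_groups = global_tp_size // otp_size
    return [[i + j * num_groups for j in range(otp_size)]
            for i in range(num_groups)]

# Matches the comment's example: TP = 16, flashcomm2 otp size = 2.
assert reorganized_batch_ids(16, 2) == [[i, i + 8] for i in range(8)]
# [[0, 8], [1, 9], [2, 10], [3, 11], [4, 12], [5, 13], [6, 14], [7, 15]]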
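
In otp_maybe_quant_comm, the send buffer is built by viewing the padded
activations as chunks, gathering them in reorganized order, and flattening;
after dist.all_to_all_single, the receive buffer is rebuilt with a
view/transpose/reshape (with w8a8, the patch quantizes to int8 first, so the
AlltoAll moves roughly half the bytes). The reshape round-trip can be
exercised in a single process for the simplest case
flashcomm2_oproj_tensor_parallel_size == 1, where the reorganization is the
identity and the odp group is the whole TP group; the collective is replaced
here by slicing every rank's send buffer:

import torch

tp = 4                  # odp world size in this case
tokens, hidden = 8, 16  # padded token count and full hidden size
full = torch.randn(tokens, hidden)
# Rank r holds the r-th hidden shard of every token: [tokens, hidden / tp].
shard = hidden // tp
x = [full[:, r * shard:(r + 1) * shard] for r in range(tp)]

chunk = tokens // tp
for r in range(tp):
    # Receive side of rank r: token-chunk r from every source rank.
    recv = torch.cat(
        [x[src][r * chunk:(r + 1) * chunk] for src in range(tp)])
    # view/transpose/reshape exactly as in the patch:
    # [tp, chunk, hidden/tp] -> [chunk, hidden]
    out = recv.view(tp, chunk, shard).transpose(0, 1).reshape(chunk, -1)
    assert torch.equal(out, full[r * chunk:(r + 1) * chunk])
print("all-to-all reshape round-trip: OK")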