[Feat] flashcomm_v2 optim solution (#3232)
### What this PR does / why we need it?
Supports generalized FlashComm2 optimization, which reduces
communication overhead, decreases RmsNorm computation, and saves one
AllGather step by replacing Allreduce operations in the Attention module
with pre-AlltoAll and post-AllGather operations (used in combination
with FlashComm1). This feature is enabled during the Prefill phase and
is recommended to be used together with FlashComm1, delivering broad
performance improvements, especially in long sequence scenarios with
large tensor parallelism (TP) configurations. Benchmark tests show that
under TP16DP1 configuration, it can improve the prefill performance of
the DeepSeek model by 8% on top of FlashComm1.
### Does this PR introduce _any_ user-facing change?
### How was this patch tested?
- vLLM version: v0.11.0
- vLLM main:
83f478bb19
---------
Signed-off-by: zzhxx <2783294813@qq.com>
Signed-off-by: Levi-JQ <yujinqi2@huawei.com>
Co-authored-by: Levi-JQ <yujinqi2@huawei.com>
Co-authored-by: zzhxx <2783294813@qq.com>
This commit is contained in:
1
.github/workflows/_e2e_test.yaml
vendored
1
.github/workflows/_e2e_test.yaml
vendored
@@ -195,6 +195,7 @@ jobs:
|
||||
pytest -sv tests/e2e/multicard/test_offline_inference_distributed.py::test_models_distributed_Qwen3_W4A8DYNAMIC_old_version
|
||||
pytest -sv tests/e2e/multicard/test_offline_inference_distributed.py::test_models_distributed_DeepSeek_W4A8DYNAMIC
|
||||
pytest -sv tests/e2e/multicard/test_offline_inference_distributed.py::test_sp_for_qwen3_moe
|
||||
pytest -sv tests/e2e/multicard/test_offline_inference_distributed.py::test_fc2_for_qwen3_moe
|
||||
pytest -sv tests/e2e/multicard/test_offline_inference_distributed.py::test_models_distributed_Qwen_Dense_with_flashcomm_v1
|
||||
pytest -sv tests/e2e/multicard/test_offline_inference_distributed.py::test_models_distributed_Qwen_Dense_with_prefetch_mlp_weight
|
||||
|
||||
|
||||
@@ -189,6 +189,26 @@ def test_sp_for_qwen3_moe() -> None:
|
||||
vllm_model.generate(example_prompts, sampling_params)
|
||||
|
||||
|
||||
@patch.dict(os.environ, {"VLLM_ASCEND_ENABLE_FLASHCOMM1": "1"})
|
||||
@patch.dict(os.environ, {"VLLM_ASCEND_FLASHCOMM2_PARALLEL_SIZE": "1"})
|
||||
def test_fc2_for_qwen3_moe() -> None:
|
||||
example_prompts = [
|
||||
"Hello, my name is",
|
||||
]
|
||||
sampling_params = SamplingParams(max_tokens=5,
|
||||
temperature=0.0,
|
||||
top_k=50,
|
||||
top_p=0.9)
|
||||
|
||||
with VllmRunner(snapshot_download("Qwen/Qwen3-30B-A3B"),
|
||||
dtype="auto",
|
||||
tensor_parallel_size=2,
|
||||
distributed_executor_backend="mp",
|
||||
enable_expert_parallel=True,
|
||||
enforce_eager=True) as vllm_model:
|
||||
vllm_model.generate(example_prompts, sampling_params)
|
||||
|
||||
|
||||
@patch.dict(os.environ, {"VLLM_ASCEND_ENABLE_FLASHCOMM1": "1"})
|
||||
def test_models_distributed_deepseek_v2_lite_with_flashcomm_v1() -> None:
|
||||
example_prompts = [
|
||||
|
||||
@@ -34,6 +34,7 @@ MODELS = ["Qwen/Qwen3-0.6B", "vllm-ascend/DeepSeek-V2-Lite-W8A8"]
|
||||
reason="aclgraph only support on v1")
|
||||
@pytest.mark.parametrize("model", MODELS)
|
||||
@pytest.mark.parametrize("max_tokens", [4])
|
||||
@patch.dict(os.environ, {"VLLM_ASCEND_FLASHCOMM2_PARALLEL_SIZE": "0"})
|
||||
@patch.dict(os.environ, {"ASCEND_RT_VISIBLE_DEVICES": "0,1"})
|
||||
def test_aclgraph_mem_use(model: str, max_tokens: int) -> None:
|
||||
del os.environ["VLLM_WORKER_MULTIPROC_METHOD"]
|
||||
|
||||
@@ -4,9 +4,10 @@ import pytest
|
||||
from vllm.config import ParallelConfig
|
||||
|
||||
from vllm_ascend.distributed.parallel_state import (
|
||||
_LMTP, _MC2, _OTP, _P_TP, destroy_ascend_model_parallel,
|
||||
get_lmhead_tp_group, get_mc2_group, get_otp_group, get_p_tp_group,
|
||||
init_ascend_model_parallel)
|
||||
_FLASHCOMM2_ODP, _FLASHCOMM2_OTP, _LMTP, _MC2, _OTP, _P_TP,
|
||||
destroy_ascend_model_parallel, get_flashcomm2_odp_group,
|
||||
get_flashcomm2_otp_group, get_lmhead_tp_group, get_mc2_group,
|
||||
get_otp_group, get_p_tp_group, init_ascend_model_parallel)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
@@ -21,9 +22,13 @@ def mock_distributed():
|
||||
with patch('torch.distributed.is_initialized', return_value=True), \
|
||||
patch('torch.distributed.get_world_size', return_value=8), \
|
||||
patch('torch.distributed.get_backend', return_value='nccl'), \
|
||||
patch('vllm_ascend.distributed.parallel_state.get_world_group') as mock_group:
|
||||
patch('vllm_ascend.distributed.parallel_state.get_world_group') as mock_group, \
|
||||
patch('vllm_ascend.distributed.parallel_state.get_tp_group') as mock_tp_group, \
|
||||
patch('vllm_ascend.distributed.parallel_state.get_dp_group') as mock_dp_group:
|
||||
mock_group.return_value.local_rank = 0
|
||||
mock_group.return_value.device_group = MagicMock()
|
||||
mock_tp_group.return_value.world_size = 4
|
||||
mock_dp_group.return_value.world_size = 2
|
||||
yield
|
||||
|
||||
|
||||
@@ -31,23 +36,33 @@ def test_init_ascend_model_parallel(mock_distributed, parallel_config):
|
||||
mock_ascend_config = MagicMock()
|
||||
mock_ascend_config.lmhead_tensor_parallel_size = 2
|
||||
mock_ascend_config.oproj_tensor_parallel_size = 2
|
||||
mock_ascend_config.flashcomm2_oproj_tensor_parallel_size = 2
|
||||
mock_ascend_config.pd_tp_ratio = 2
|
||||
mock_ascend_config.num_head_replica = 0
|
||||
mock_ascend_config.pd_head_ratio = 2
|
||||
mock_vllm_config = MagicMock()
|
||||
mock_vllm_config.kv_transfer_config.is_kv_producer = True
|
||||
mock_envs_ascend = MagicMock()
|
||||
mock_envs_ascend.VLLM_ASCEND_FLASHCOMM2_PARALLEL_SIZE = 2
|
||||
mock_envs_ascend.VLLM_ASCEND_ENABLE_CONTEXT_PARALLEL = 0
|
||||
with patch('vllm_ascend.distributed.parallel_state.model_parallel_initialized', return_value=False), \
|
||||
patch('vllm_ascend.distributed.parallel_state.init_model_parallel_group'), \
|
||||
patch('vllm_ascend.distributed.parallel_state.get_current_vllm_config', return_value=mock_vllm_config), \
|
||||
patch('vllm_ascend.distributed.parallel_state.get_ascend_config', return_value=mock_ascend_config):
|
||||
patch('vllm_ascend.distributed.parallel_state.get_ascend_config', return_value=mock_ascend_config), \
|
||||
patch('vllm_ascend.utils.envs_ascend', new=mock_envs_ascend), \
|
||||
patch('vllm_ascend.utils.get_ascend_config', return_value=mock_ascend_config):
|
||||
init_ascend_model_parallel(parallel_config)
|
||||
|
||||
mc2_group = get_mc2_group()
|
||||
lmheadtp_group = get_lmhead_tp_group()
|
||||
otp_group = get_otp_group()
|
||||
flashcomm2_otp_group = get_flashcomm2_otp_group()
|
||||
flashcomm2_odp_group = get_flashcomm2_odp_group()
|
||||
p_tp_group = get_p_tp_group()
|
||||
assert mc2_group is not None
|
||||
assert otp_group is not None
|
||||
assert flashcomm2_otp_group is not None
|
||||
assert flashcomm2_odp_group is not None
|
||||
assert lmheadtp_group is not None
|
||||
assert p_tp_group is not None
|
||||
|
||||
@@ -55,4 +70,6 @@ def test_init_ascend_model_parallel(mock_distributed, parallel_config):
|
||||
assert _MC2 is None
|
||||
assert _LMTP is None
|
||||
assert _OTP is None
|
||||
assert _FLASHCOMM2_OTP is None
|
||||
assert _FLASHCOMM2_ODP is None
|
||||
assert _P_TP is None
|
||||
|
||||
@@ -130,6 +130,10 @@ class AscendConfig:
|
||||
"Only support P node tp size lagger then D node tp size")
|
||||
self.SLO_limits_for_dynamic_batch = additional_config.get(
|
||||
"SLO_limits_for_dynamic_batch", -1)
|
||||
from vllm_ascend.utils import \
|
||||
get_flashcomm2_oproj_tp_size_and_validate_config
|
||||
self.flashcomm2_oproj_tensor_parallel_size = get_flashcomm2_oproj_tp_size_and_validate_config(
|
||||
self, vllm_config)
|
||||
|
||||
|
||||
class TorchairGraphConfig:
|
||||
|
||||
@@ -11,7 +11,8 @@ from vllm.forward_context import (BatchDescriptor, get_forward_context,
|
||||
set_forward_context)
|
||||
|
||||
import vllm_ascend.envs as envs_ascend
|
||||
from vllm_ascend.utils import enable_sp, has_layer_idx, is_moe_model
|
||||
from vllm_ascend.utils import (enable_sp, flashcomm2_enable, has_layer_idx,
|
||||
is_moe_model)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm_ascend.ops.weight_prefetch import WeightPrefetchMethod
|
||||
@@ -121,13 +122,17 @@ def set_ascend_forward_context(
|
||||
sp_enabled = enable_sp(vllm_config) and \
|
||||
num_tokens is not None and num_tokens > 1000
|
||||
forward_context.mmrs_fusion = mmrs_fusion
|
||||
forward_context.num_tokens = num_tokens
|
||||
forward_context.sp_enabled = sp_enabled
|
||||
#TODO(Levi-JQ): another PR to normalize the enabling logic for sp/fc2
|
||||
forward_context.flashcomm_v2_enabled = flashcomm2_enable(
|
||||
) and tp_world_size > 1 and num_tokens is not None
|
||||
|
||||
if sp_enabled:
|
||||
if (forward_context.sp_enabled
|
||||
or forward_context.flashcomm_v2_enabled):
|
||||
pad_size = (tp_world_size -
|
||||
(num_tokens % tp_world_size)) % tp_world_size
|
||||
forward_context.pad_size = pad_size
|
||||
forward_context.sp_enabled = sp_enabled
|
||||
forward_context.num_tokens = num_tokens
|
||||
|
||||
# set this for rope forward_oot using
|
||||
forward_context.is_first_layer = True
|
||||
@@ -179,7 +184,8 @@ def set_ascend_forward_context(
|
||||
if dp_world_size > 1 and forward_context.dp_metadata is not None:
|
||||
max_tokens_across_dp = \
|
||||
forward_context.dp_metadata.max_tokens_across_dp_cpu.item()
|
||||
if sp_enabled:
|
||||
if (forward_context.sp_enabled
|
||||
or forward_context.flashcomm_v2_enabled):
|
||||
padded_length = (max_tokens_across_dp + tp_world_size -
|
||||
1) // tp_world_size * tp_world_size
|
||||
pad_size = padded_length - num_tokens
|
||||
|
||||
@@ -2,12 +2,14 @@ from typing import Optional
|
||||
|
||||
import torch
|
||||
from vllm.config import ParallelConfig, get_current_vllm_config
|
||||
from vllm.distributed.parallel_state import (GroupCoordinator, get_world_group,
|
||||
from vllm.distributed.parallel_state import (GroupCoordinator, get_dp_group,
|
||||
get_tp_group, get_world_group,
|
||||
init_model_parallel_group)
|
||||
|
||||
import vllm_ascend.envs as envs_ascend
|
||||
from vllm_ascend.ascend_config import get_ascend_config
|
||||
from vllm_ascend.utils import prefill_context_parallel_enable
|
||||
from vllm_ascend.utils import (flashcomm2_enable,
|
||||
prefill_context_parallel_enable)
|
||||
|
||||
# Currently, mc2 op need their own group coordinator.
|
||||
_MC2: Optional[GroupCoordinator] = None
|
||||
@@ -15,6 +17,8 @@ _MLP_TP: Optional[GroupCoordinator] = None
|
||||
_OTP: Optional[GroupCoordinator] = None
|
||||
_LMTP: Optional[GroupCoordinator] = None
|
||||
_P_TP: Optional[GroupCoordinator] = None
|
||||
_FLASHCOMM2_OTP: Optional[GroupCoordinator] = None
|
||||
_FLASHCOMM2_ODP: Optional[GroupCoordinator] = None
|
||||
|
||||
|
||||
def get_mc2_group() -> GroupCoordinator:
|
||||
@@ -34,6 +38,16 @@ def get_lmhead_tp_group() -> GroupCoordinator:
|
||||
return _LMTP
|
||||
|
||||
|
||||
def get_flashcomm2_otp_group() -> GroupCoordinator:
|
||||
return _FLASHCOMM2_OTP
|
||||
|
||||
|
||||
def get_flashcomm2_odp_group() -> GroupCoordinator:
|
||||
assert _FLASHCOMM2_ODP is not None, (
|
||||
"output data parallel group for flashcomm2 is not initialized")
|
||||
return _FLASHCOMM2_ODP
|
||||
|
||||
|
||||
def get_mlp_tp_group() -> GroupCoordinator:
|
||||
assert _MLP_TP is not None, ("mlp group is not initialized")
|
||||
return _MLP_TP
|
||||
@@ -165,6 +179,48 @@ def init_ascend_model_parallel(parallel_config: ParallelConfig, ):
|
||||
backend,
|
||||
group_name="lmheadtp")
|
||||
|
||||
# TODO: Extract and unify the logic across different communication group.
|
||||
if flashcomm2_enable():
|
||||
flashcomm2_otp_size = get_ascend_config(
|
||||
).flashcomm2_oproj_tensor_parallel_size
|
||||
global_tp_size = get_tp_group().world_size
|
||||
global_dp_size = get_dp_group().world_size
|
||||
num_fc2_oproj_tensor_parallel_groups: int = (global_tp_size //
|
||||
flashcomm2_otp_size)
|
||||
|
||||
global _FLASHCOMM2_OTP
|
||||
global _FLASHCOMM2_ODP
|
||||
|
||||
_FLASHCOMM2_OTP = None
|
||||
_FLASHCOMM2_ODP = get_tp_group()
|
||||
|
||||
if flashcomm2_otp_size > 1:
|
||||
otp_group_ranks = []
|
||||
odp_group_ranks: list[list[int]] = [
|
||||
[] for _ in range(flashcomm2_otp_size * global_dp_size)
|
||||
]
|
||||
|
||||
for dp_group_index in range(global_dp_size):
|
||||
for i in range(num_fc2_oproj_tensor_parallel_groups):
|
||||
ranks = []
|
||||
for j in range(flashcomm2_otp_size):
|
||||
rank_idx = dp_group_index * global_tp_size + i + j * num_fc2_oproj_tensor_parallel_groups
|
||||
ranks.append(rank_idx)
|
||||
odp_group_index = dp_group_index * flashcomm2_otp_size + j
|
||||
odp_group_ranks[odp_group_index].append(rank_idx)
|
||||
otp_group_ranks.append(ranks)
|
||||
|
||||
_FLASHCOMM2_OTP = init_model_parallel_group(
|
||||
otp_group_ranks,
|
||||
get_world_group().local_rank,
|
||||
backend,
|
||||
group_name="flashcomm2_otp")
|
||||
_FLASHCOMM2_ODP = init_model_parallel_group(
|
||||
odp_group_ranks,
|
||||
get_world_group().local_rank,
|
||||
backend,
|
||||
group_name="flashcomm2_odp")
|
||||
|
||||
|
||||
def get_mlp_tensor_model_parallel_world_size():
|
||||
"""Return world size for the tensor model parallel group."""
|
||||
@@ -201,3 +257,15 @@ def destroy_ascend_model_parallel():
|
||||
if _P_TP:
|
||||
_P_TP.destroy()
|
||||
_P_TP = None
|
||||
|
||||
global _FLASHCOMM2_OTP
|
||||
if _FLASHCOMM2_OTP and get_ascend_config(
|
||||
).flashcomm2_oproj_tensor_parallel_size != 1:
|
||||
_FLASHCOMM2_OTP.destroy()
|
||||
_FLASHCOMM2_OTP = None
|
||||
|
||||
global _FLASHCOMM2_ODP
|
||||
if _FLASHCOMM2_ODP and get_ascend_config(
|
||||
).flashcomm2_oproj_tensor_parallel_size != 1:
|
||||
_FLASHCOMM2_ODP.destroy()
|
||||
_FLASHCOMM2_ODP = None
|
||||
|
||||
@@ -132,6 +132,12 @@ env_variables: Dict[str, Callable[[], Any]] = {
|
||||
# This feature will get better performance when concurrency is large.
|
||||
"VLLM_ASCEND_ENABLE_FLASHCOMM1":
|
||||
lambda: bool(int(os.getenv("VLLM_ASCEND_ENABLE_FLASHCOMM1", '0'))),
|
||||
# Whether to enable FLASHCOMM2. Setting it to 0 disables the feature, while setting it to 1 or above enables it.
|
||||
# The specific value set will be used as the O-matrix TP group size for flashcomm2.
|
||||
# For a detailed introduction to the parameters and the differences and applicable scenarios
|
||||
# between this feature and FLASHCOMM1, please refer to the feature guide in the documentation.
|
||||
"VLLM_ASCEND_FLASHCOMM2_PARALLEL_SIZE":
|
||||
lambda: int(os.getenv("VLLM_ASCEND_FLASHCOMM2_PARALLEL_SIZE", 0)),
|
||||
# Whether to enable MLP weight prefetch, only used in small concurrency.
|
||||
"VLLM_ASCEND_ENABLE_PREFETCH_MLP":
|
||||
lambda: bool(int(os.getenv("VLLM_ASCEND_ENABLE_PREFETCH_MLP", '0'))),
|
||||
@@ -185,4 +191,4 @@ def __getattr__(name: str):
|
||||
|
||||
|
||||
def __dir__():
|
||||
return list(env_variables.keys())
|
||||
return list(env_variables.keys())
|
||||
|
||||
@@ -24,6 +24,7 @@ CustomLinearOp
|
||||
└── CustomRowParallelOp
|
||||
│ ├── MLPRowParallelOp
|
||||
│ ├── OProjRowParallelOp
|
||||
| ├── Flashcomm2OProjRowParallelOp
|
||||
│ ├── MatmulAllreduceRowParallelOp
|
||||
│ └── SequenceRowParallelOp
|
||||
└── CustomReplicatedOp
|
||||
@@ -41,6 +42,7 @@ import torch
|
||||
import torch.distributed as dist
|
||||
import torch.nn.functional as F
|
||||
import torch_npu
|
||||
from torch import nn
|
||||
from torch.distributed import ProcessGroup
|
||||
from torch.nn.parameter import Parameter
|
||||
from vllm.distributed import (split_tensor_along_last_dim,
|
||||
@@ -49,9 +51,14 @@ from vllm.distributed import (split_tensor_along_last_dim,
|
||||
from vllm.distributed.parallel_state import get_tp_group
|
||||
from vllm.forward_context import get_forward_context
|
||||
|
||||
from vllm_ascend.distributed.parallel_state import (get_mlp_tp_group,
|
||||
from vllm_ascend.ascend_config import get_ascend_config
|
||||
from vllm_ascend.distributed.parallel_state import (get_flashcomm2_odp_group,
|
||||
get_flashcomm2_otp_group,
|
||||
get_mlp_tp_group,
|
||||
get_otp_group)
|
||||
from vllm_ascend.utils import (dense_optim_enable, enable_sp,
|
||||
flashcomm2_enable,
|
||||
get_flashcomm2_reorgnized_batch_ids,
|
||||
matmul_allreduce_enable, mlp_tp_enable,
|
||||
oproj_tp_enable, shared_expert_dp_enabled)
|
||||
|
||||
@@ -263,6 +270,135 @@ class OProjRowParallelOp(CustomRowParallelOp):
|
||||
self.input_size_per_partition = self.layer.input_size_per_partition
|
||||
|
||||
|
||||
class Flashcomm2OProjRowParallelOp(CustomRowParallelOp):
|
||||
|
||||
def __init__(self, layer):
|
||||
super().__init__(layer)
|
||||
self.odp_group = get_flashcomm2_odp_group()
|
||||
self.odp_size = self.odp_group.world_size
|
||||
self.reorgnized_batch_ids = get_flashcomm2_reorgnized_batch_ids(
|
||||
get_tp_group().world_size)
|
||||
self.group_indices = torch.tensor(self.reorgnized_batch_ids).npu()
|
||||
self.layer._quant_comm_config = {}
|
||||
|
||||
@property
|
||||
def comm_group(self):
|
||||
return get_flashcomm2_otp_group()
|
||||
|
||||
@property
|
||||
def tp_rank(self):
|
||||
if get_ascend_config().flashcomm2_oproj_tensor_parallel_size == 1:
|
||||
return 0
|
||||
return self.comm_group.rank_in_group
|
||||
|
||||
@property
|
||||
def tp_size(self):
|
||||
if get_ascend_config().flashcomm2_oproj_tensor_parallel_size == 1:
|
||||
return 1
|
||||
return self.comm_group.world_size
|
||||
|
||||
def apply_impl(
|
||||
self,
|
||||
input_: torch.Tensor,
|
||||
) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[Parameter]]]:
|
||||
"""Linear layer for Flashcomm2.
|
||||
Input.ahspe = [batchsize*seqlength, headnum*headdim/TP]
|
||||
Output.shape = [(batchsize*seqlength+padsize)/TP, hiddensize]
|
||||
"""
|
||||
# Handle input parallelism - split or use as-is
|
||||
if self.input_is_parallel:
|
||||
input_parallel = input_
|
||||
else:
|
||||
tp_rank = self.tp_rank
|
||||
splitted_input = split_tensor_along_last_dim(
|
||||
input_, num_partitions=self.tp_size)
|
||||
input_parallel = splitted_input[tp_rank].contiguous()
|
||||
|
||||
# padding for all-to-all
|
||||
forward_context = get_forward_context()
|
||||
num_padding_tokens = forward_context.pad_size
|
||||
if num_padding_tokens > 0:
|
||||
input_parallel = nn.functional.pad(input_parallel,
|
||||
(0, 0, 0, num_padding_tokens))
|
||||
|
||||
def otp_maybe_quant_comm(x):
|
||||
|
||||
# Reorganize the tensor so that the batch id and rank id correspond to each other.
|
||||
chunk_num = len(self.reorgnized_batch_ids) * len(
|
||||
self.reorgnized_batch_ids[0])
|
||||
batch_size = x.size(0)
|
||||
|
||||
assert batch_size % chunk_num == 0, f"Batch_size({batch_size}) must be divisible by chunk_num({chunk_num})"
|
||||
|
||||
batch_size_per_chunk = batch_size // chunk_num
|
||||
# Indices of reorganized tensor
|
||||
chunked = x.view(chunk_num, batch_size_per_chunk, x.shape[1])
|
||||
reorganized_chunks = chunked[self.group_indices]
|
||||
send_buf = reorganized_chunks.flatten(1, 2)
|
||||
|
||||
# all-to-all operation parameters
|
||||
all2all_tp_size = self.odp_size
|
||||
local_intermediate_size = x.size(1)
|
||||
chunk_size = x.size(0) // all2all_tp_size
|
||||
total_intermediate_size = local_intermediate_size * all2all_tp_size
|
||||
|
||||
# Create receive buffer
|
||||
recv_buf = torch.empty(total_intermediate_size * chunk_size,
|
||||
dtype=x.dtype,
|
||||
device=x.device)
|
||||
|
||||
# Perform all-to-all communication
|
||||
dist.all_to_all_single(recv_buf,
|
||||
send_buf,
|
||||
group=self.odp_group.device_group)
|
||||
|
||||
return recv_buf.view(all2all_tp_size, chunk_size,
|
||||
-1).transpose(0, 1).reshape(chunk_size, -1)
|
||||
|
||||
if not hasattr(self, "_quant_comm_config"):
|
||||
self.layer._quant_comm_config = {}
|
||||
self.layer._quant_comm_config[
|
||||
"communication_fn"] = otp_maybe_quant_comm
|
||||
actual_quant_method = getattr(self.quant_method, 'quant_method',
|
||||
self.quant_method)
|
||||
from vllm_ascend.quantization.w8a8 import AscendW8A8LinearMethod
|
||||
if not isinstance(actual_quant_method, AscendW8A8LinearMethod):
|
||||
# Check if w8a8 quantization is enabled. If not, communicate immediately.
|
||||
input_parallel = otp_maybe_quant_comm(input_parallel)
|
||||
|
||||
# Matrix multiply.
|
||||
assert self.quant_method is not None
|
||||
# Only fuse bias add into GEMM for rank 0 (this ensures that
|
||||
# bias will not get added more than once in TP>1 case)
|
||||
bias_ = None if (self.tp_rank > 0 or self.skip_bias_add) else self.bias
|
||||
|
||||
output_parallel = self.quant_method.apply(self.layer,
|
||||
input_parallel,
|
||||
bias=bias_)
|
||||
# output_parallel shape: [bs/(TP/flashcomm2_otp_size), hiddenstate]
|
||||
if self.tp_size > 1:
|
||||
# flashcomm2 with reduce-scatter
|
||||
output = self.comm_group.reduce_scatter(output_parallel, dim=0)
|
||||
else:
|
||||
output = output_parallel
|
||||
|
||||
if not forward_context.sp_enabled:
|
||||
# flashcomm1 not enabled
|
||||
output = get_tp_group().all_gather(output, 0)
|
||||
if num_padding_tokens > 0:
|
||||
output = output[:-num_padding_tokens]
|
||||
|
||||
# Handle bias return based on configuration
|
||||
output_bias = self.bias if self.skip_bias_add else None
|
||||
|
||||
return output, output_bias
|
||||
|
||||
def update_attrs(self):
|
||||
super().update_attrs()
|
||||
self.input_is_parallel = self.layer.input_is_parallel
|
||||
self.input_size_per_partition = self.layer.input_size_per_partition
|
||||
|
||||
|
||||
class MatmulAllreduceRowParallelOp(CustomRowParallelOp):
|
||||
_HCOMM_INFO = None
|
||||
|
||||
@@ -487,13 +623,17 @@ def _get_column_parallel_op(
|
||||
def _get_row_parallel_op(
|
||||
prefix, layer
|
||||
) -> Optional[Union[MLPRowParallelOp, OProjRowParallelOp,
|
||||
MatmulAllreduceRowParallelOp, SequenceRowParallelOp]]:
|
||||
Flashcomm2OProjRowParallelOp, MatmulAllreduceRowParallelOp,
|
||||
SequenceRowParallelOp]]:
|
||||
if "down_proj" in prefix and mlp_tp_enable():
|
||||
return MLPRowParallelOp(layer)
|
||||
if "o_proj" in prefix and oproj_tp_enable():
|
||||
return OProjRowParallelOp(layer)
|
||||
if matmul_allreduce_enable():
|
||||
return MatmulAllreduceRowParallelOp(layer)
|
||||
if flashcomm2_enable():
|
||||
if "o_proj" in prefix or "out_proj" in prefix:
|
||||
return Flashcomm2OProjRowParallelOp(layer)
|
||||
if enable_sp():
|
||||
if "shared_expert" in prefix:
|
||||
return None
|
||||
@@ -509,6 +649,7 @@ def get_parallel_op(disable_tp, prefix, layer, direct):
|
||||
return None, 0, 1
|
||||
custom_op: Optional[Union[MLPColumnParallelOp, SequenceColumnParallelOp,
|
||||
MLPRowParallelOp, OProjRowParallelOp,
|
||||
Flashcomm2OProjRowParallelOp,
|
||||
MatmulAllreduceRowParallelOp,
|
||||
SequenceRowParallelOp]] = None
|
||||
if direct == "row":
|
||||
|
||||
@@ -35,12 +35,14 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
|
||||
from vllm.model_executor.parameter import PerTensorScaleParameter
|
||||
from vllm.model_executor.utils import set_weight_attrs
|
||||
|
||||
from vllm_ascend.distributed.parallel_state import (get_mlp_tp_group,
|
||||
from vllm_ascend.ascend_config import get_ascend_config
|
||||
from vllm_ascend.distributed.parallel_state import (get_flashcomm2_otp_group,
|
||||
get_mlp_tp_group,
|
||||
get_otp_group)
|
||||
from vllm_ascend.ops.fused_moe.fused_moe import AscendUnquantizedFusedMoEMethod
|
||||
from vllm_ascend.ops.linear import AscendUnquantizedLinearMethod
|
||||
from vllm_ascend.utils import (ASCEND_QUANTIZATION_METHOD, mlp_tp_enable,
|
||||
oproj_tp_enable)
|
||||
from vllm_ascend.utils import (ASCEND_QUANTIZATION_METHOD, flashcomm2_enable,
|
||||
mlp_tp_enable, oproj_tp_enable)
|
||||
|
||||
from .utils import get_quant_method
|
||||
|
||||
@@ -348,6 +350,13 @@ class AscendLinearMethod(LinearMethodBase):
|
||||
tp_rank = get_otp_group().rank_in_group
|
||||
elif layer.prefix.find("down_proj") != -1 and mlp_tp_enable():
|
||||
tp_rank = get_mlp_tp_group().rank_in_group
|
||||
elif (layer.prefix.find("o_proj") != -1 or
|
||||
layer.prefix.find("out_proj") != -1) and flashcomm2_enable():
|
||||
if get_ascend_config(
|
||||
).flashcomm2_oproj_tensor_parallel_size == 1:
|
||||
tp_rank = 0
|
||||
else:
|
||||
tp_rank = get_flashcomm2_otp_group().rank_in_group
|
||||
else:
|
||||
tp_rank = get_tensor_model_parallel_rank()
|
||||
else:
|
||||
|
||||
@@ -115,12 +115,30 @@ class AscendW8A8LinearMethod:
|
||||
weight=layer.weight,
|
||||
start_flag=x,
|
||||
)
|
||||
# quant
|
||||
x = quant_per_tensor(
|
||||
x,
|
||||
layer.aclnn_input_scale_reciprocal,
|
||||
layer.aclnn_input_offset,
|
||||
)
|
||||
|
||||
quant_comm_config = getattr(layer, "_quant_comm_config", {})
|
||||
comm_fn = quant_comm_config.get("communication_fn")
|
||||
enable_flashcomm2_quant_comm = comm_fn is not None and (
|
||||
"o_proj" in layer.prefix or "out_proj" in layer.prefix)
|
||||
if enable_flashcomm2_quant_comm:
|
||||
quant_input_x = x.contiguous().view(
|
||||
-1, layer.aclnn_input_scale_reciprocal.size(0))
|
||||
quant_x = quant_per_tensor(
|
||||
quant_input_x,
|
||||
layer.aclnn_input_scale_reciprocal,
|
||||
layer.aclnn_input_offset,
|
||||
)
|
||||
comm_input = quant_x.view(x.size(0), -1)
|
||||
assert comm_fn is not None
|
||||
x = comm_fn(comm_input)
|
||||
else:
|
||||
# quant
|
||||
x = quant_per_tensor(
|
||||
x,
|
||||
layer.aclnn_input_scale_reciprocal,
|
||||
layer.aclnn_input_offset,
|
||||
)
|
||||
|
||||
# prefetch qkvo_proj.weight postprocess
|
||||
if weight_prefetch_method:
|
||||
weight_prefetch_method.maybe_prefetch_attn_weight_postprocess(
|
||||
|
||||
@@ -814,3 +814,68 @@ def has_layer_idx(model_instance: torch.nn.Module) -> bool:
|
||||
_HAS_LAYER_IDX = hasattr(model_instance, "model") and \
|
||||
hasattr(model_instance.model, "start_layer")
|
||||
return _HAS_LAYER_IDX
|
||||
|
||||
|
||||
def flashcomm2_enable() -> bool:
|
||||
return envs_ascend.VLLM_ASCEND_FLASHCOMM2_PARALLEL_SIZE > 0
|
||||
|
||||
|
||||
def get_flashcomm2_oproj_tp_size_and_validate_config(ascend_config,
|
||||
vllm_config):
|
||||
flashcomm2_oproj_tp_size = envs_ascend.VLLM_ASCEND_FLASHCOMM2_PARALLEL_SIZE
|
||||
global_tp_size = vllm_config.parallel_config.tensor_parallel_size
|
||||
|
||||
if not flashcomm2_enable():
|
||||
logger.info("FLASHCOMM2 not enable.")
|
||||
return flashcomm2_oproj_tp_size
|
||||
|
||||
logger.info(
|
||||
f"Enable FLASHCOMM2 with flashcomm2_oproj_tensor_parallel_size={flashcomm2_oproj_tp_size} and global_tp_size={global_tp_size}"
|
||||
)
|
||||
if not envs_ascend.VLLM_ASCEND_ENABLE_FLASHCOMM1:
|
||||
logger.warning_once(
|
||||
"It is recommended to enable FLASHCOMM1 simultaneously when starting FLASHCOMM2 for optimal performance."
|
||||
)
|
||||
if ascend_config.oproj_tensor_parallel_size is not None:
|
||||
raise AssertionError(
|
||||
"flashcomm2_oproj_tensor_parallel_size cannot be enabled simultaneously with oproj_tensor_parallel_size"
|
||||
)
|
||||
if global_tp_size <= flashcomm2_oproj_tp_size:
|
||||
raise AssertionError(
|
||||
f"flashcomm2_oproj_tensor_parallel_size ({flashcomm2_oproj_tp_size}) cannot exceed global tensor parallel size ({global_tp_size})"
|
||||
)
|
||||
if global_tp_size % flashcomm2_oproj_tp_size != 0:
|
||||
raise AssertionError(
|
||||
f"Global tensor parallel size ({global_tp_size}) must be divisible by flashcomm2_oproj_tensor_parallel_size ({flashcomm2_oproj_tp_size})"
|
||||
)
|
||||
if vllm_config.kv_transfer_config is None:
|
||||
logger.warning_once(
|
||||
"It is recommended to enable FLASHCOMM2 in P-scenario deployments, enable it in hybrid deployment may lead to decode performance degradation."
|
||||
)
|
||||
if vllm_config.kv_transfer_config is not None and vllm_config.kv_transfer_config.is_kv_consumer:
|
||||
raise AssertionError(
|
||||
"FLASHCOMM2 primarily targets P-scenario deployments, "
|
||||
"with additional support for hybrid deployment scenarios. "
|
||||
"It is not applicable in D-scenario environments.")
|
||||
|
||||
return flashcomm2_oproj_tp_size
|
||||
|
||||
|
||||
def get_flashcomm2_reorgnized_batch_ids(global_tp_size) -> list[list[int]]:
|
||||
# Reorganize batch_ids so that, after the all2all and reduce-scatter operation, each batch_id corresponds to the rank_id within the DP domain.
|
||||
# For example, when DP = [0, 1, 2, ..., 15] and flashcomm2_oproj_tensor_parallel_size = 2,
|
||||
# the reorganized batch_ids will be [[batch0, batch8], [batch1, batch9], ..., [batch7, batch15]].
|
||||
flashcomm2_otp_size = get_ascend_config(
|
||||
).flashcomm2_oproj_tensor_parallel_size
|
||||
num_oproj_tensor_parallel_groups: int = (global_tp_size //
|
||||
flashcomm2_otp_size)
|
||||
|
||||
reorgnized_batch_ids = []
|
||||
for i in range(num_oproj_tensor_parallel_groups):
|
||||
ranks = []
|
||||
for j in range(flashcomm2_otp_size):
|
||||
rank_idx = i + j * num_oproj_tensor_parallel_groups
|
||||
ranks.append(rank_idx)
|
||||
reorgnized_batch_ids.append(ranks)
|
||||
|
||||
return reorgnized_batch_ids
|
||||
|
||||
Reference in New Issue
Block a user