[main] flashcomm_v1 optim in Qwen Dense Models (#2802)

### What this PR does / why we need it?
Enables the FlashComm v1 communication optimization for Qwen dense models. `AscendRowParallelLinear` and `AscendMergedColumnParallelLinear` gain a `dense_optim` forward path, and a new `AscendQKVParallelLinear` layer is added; together they replace the per-layer tensor-parallel all-reduce with a pad-and-reduce collective after row-parallel layers and an all-gather-and-unpad before column-parallel layers, gated by `dense_optim_enable()`.
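For context, FlashComm v1 splits the tensor-parallel all-reduce into its two halves so the cheap elementwise work in between (norm, residual add) runs on a per-rank slice of the tokens. A minimal sketch of the underlying identity, assuming an initialized process group and a token count divisible by the world size; the function names here are illustrative, not the ops this PR registers:

```python
import torch
import torch.distributed as dist


def allreduce_path(x: torch.Tensor) -> torch.Tensor:
    # Baseline tensor parallelism: every rank ends up with the full
    # summed activation after each row-parallel linear.
    dist.all_reduce(x)
    return x


def flashcomm_v1_path(x: torch.Tensor) -> torch.Tensor:
    # FlashComm v1: reduce-scatter so each rank keeps only its slice of
    # the token dimension, run the sharded elementwise ops, then
    # all-gather right before the next column-parallel matmul.
    world = dist.get_world_size()
    shard = x.new_empty((x.shape[0] // world, *x.shape[1:]))
    dist.reduce_scatter_tensor(shard, x)
    # ... norm / residual ops would run here on 1/world of the tokens ...
    out = torch.empty_like(x)
    dist.all_gather_into_tensor(out, shard)
    return out  # numerically equal to allreduce_path(x)
```

The `maybe_pad_and_reduce` / `maybe_all_gather_and_maybe_unpad` ops in the diff below additionally handle the padding needed when the token count does not divide the TP world size.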

### Does this PR introduce _any_ user-facing change?
No.

### How was this patch tested?
CI passed with newly added and existing tests.

- vLLM version: v0.10.1.1
- vLLM main: 5e537f45b4

Co-authored-by: 1024daniel <xxltju324@gmail.com>
Commit 1bbb20ea13 (parent 4df8df5b94), authored by rjg-lyh on 2025-09-08 22:52:24 +08:00 and committed via GitHub.
11 changed files with 362 additions and 20 deletions.


@@ -26,20 +26,18 @@ from torch.nn.parameter import Parameter
from vllm.distributed import divide, split_tensor_along_last_dim
from vllm.distributed.parallel_state import get_tp_group
from vllm.lora.utils import LinearBase
-from vllm.model_executor.layers.linear import (WEIGHT_LOADER_V2_SUPPORTED,
-                                               ColumnParallelLinear,
-                                               MergedColumnParallelLinear,
-                                               QuantizeMethodBase,
-                                               RowParallelLinear,
-                                               UnquantizedLinearMethod)
+from vllm.model_executor.layers.linear import (  # noqa
+    WEIGHT_LOADER_V2_SUPPORTED, ColumnParallelLinear,
+    MergedColumnParallelLinear, QKVParallelLinear, QuantizeMethodBase,
+    RowParallelLinear, UnquantizedLinearMethod)
 from vllm.model_executor.layers.quantization.base_config import \
     QuantizationConfig
 from vllm.model_executor.utils import set_weight_attrs
 from vllm_ascend.distributed.parallel_state import (get_mlp_tp_group,
                                                     get_otp_group)
-from vllm_ascend.utils import (matmul_allreduce_enable, mlp_tp_enable,
-                               oproj_tp_enable)
+from vllm_ascend.utils import (dense_optim_enable, matmul_allreduce_enable,
+                               mlp_tp_enable, oproj_tp_enable)
_HCOMM_INFO = None
@@ -150,6 +148,9 @@ class AscendRowParallelLinear(RowParallelLinear):
             comm_group = get_tp_group()
             self.forward_type = "matmul_allreduce"
             self.hcomm_info = self.get_hcomm_info(comm_group.device_group)
+        elif dense_optim_enable():
+            comm_group = get_tp_group()
+            self.forward_type = "dense_optim"
         else:
             comm_group = get_tp_group()
             self.forward_type = "normal"
@@ -231,6 +232,8 @@ class AscendRowParallelLinear(RowParallelLinear):
             return self._forward_mlp_tp(input_)
         elif self.forward_type == "matmul_allreduce":
             return self._forward_matmul_allreduce(input_)
+        elif self.forward_type == "dense_optim":
+            return self._forward_dense_optim(input_)
         else:
             return super().forward(input_)
@@ -332,6 +335,39 @@ class AscendRowParallelLinear(RowParallelLinear):
         if not self.return_bias:
             return output
         return output, output_bias

+    def _forward_dense_optim(
+        self, input_: torch.Tensor
+    ) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[Parameter]]]:
+        """Linear layer with row parallelism.
+
+        Implements several optimizations for dense models, such as
+        FlashComm and communication-computation fusion.
+        """
+        if self.input_is_parallel:
+            input_parallel = input_
+        else:
+            splitted_input = split_tensor_along_last_dim(
+                input_, num_partitions=self.tp_size)
+            input_parallel = splitted_input[self.tp_rank].contiguous()
+
+        assert self.quant_method is not None
+        bias_ = None if (self.tp_rank > 0 or self.skip_bias_add) else self.bias
+        if self.tp_size == 1 or not self.reduce_results:
+            output = self.quant_method.apply(self, input_parallel, bias=bias_)
+        else:
+            output_parallel = self.quant_method.apply(self,
+                                                      input_parallel,
+                                                      bias=bias_)
+            output = torch.ops.vllm.maybe_pad_and_reduce(output_parallel)
+
+        output_bias = self.bias if self.skip_bias_add else None
+        if not self.return_bias:
+            return output
+        return output, output_bias
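`torch.ops.vllm.maybe_pad_and_reduce` is a custom op registered in one of this PR's other changed files; its body is not visible here. Under the FlashComm v1 scheme sketched earlier, a hedged guess at its semantics (the name reuse is real, everything else is an assumption):

```python
import torch
import torch.distributed as dist
import torch.nn.functional as F


def maybe_pad_and_reduce_sketch(x: torch.Tensor) -> torch.Tensor:
    # Assumed semantics: pad the token dimension so it divides the TP
    # world size, then reduce-scatter instead of the usual all-reduce,
    # leaving each rank with its slice of the summed activation.
    world = dist.get_world_size()
    pad = (-x.shape[0]) % world
    if pad:
        x = F.pad(x, (0, 0, 0, pad))  # pad rows of a [tokens, hidden] tensor
    shard = x.new_empty((x.shape[0] // world, *x.shape[1:]))
    dist.reduce_scatter_tensor(shard, x)
    return shard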
 class AscendMergedColumnParallelLinear(MergedColumnParallelLinear):
     """Packed linear layers with column parallelism.
@@ -357,15 +393,18 @@ class AscendMergedColumnParallelLinear(MergedColumnParallelLinear):
         *,
         return_bias: bool = True,
     ):
-        self.comm_group = None
         if prefix.find("gate_up_proj") != -1 and mlp_tp_enable():
-            self.comm_group = get_mlp_tp_group()
+            comm_group = get_mlp_tp_group()
             self.forward_type = "mlp_tp"
+        elif dense_optim_enable():
+            comm_group = get_tp_group()
+            self.forward_type = "dense_optim"
         else:
-            self.comm_group = get_tp_group()
+            comm_group = get_tp_group()
             self.forward_type = "normal_tp"
-        self.tp_rank = self.comm_group.rank_in_group
-        self.tp_size = self.comm_group.world_size
+        self.comm_group = comm_group
+        self.tp_rank = comm_group.rank_in_group
+        self.tp_size = comm_group.world_size
         self.output_sizes = output_sizes
         assert all(output_size % self.tp_size == 0
@@ -387,6 +426,8 @@ class AscendMergedColumnParallelLinear(MergedColumnParallelLinear):
     ) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[Parameter]]]:
         if self.forward_type == "mlp_tp":
             return self._forward_mlp_tp(input_)
+        elif self.forward_type == "dense_optim":
+            return self._forward_dense_optim(input_)
         else:
             return super().forward(input_)
@@ -405,6 +446,138 @@ class AscendMergedColumnParallelLinear(MergedColumnParallelLinear):
         if not self.return_bias:
             return output
         return output, output_bias

+    def _forward_dense_optim(
+        self, input_: torch.Tensor
+    ) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[Parameter]]]:
+        """Linear layer with column parallelism.
+
+        Implements several optimizations for dense models, such as
+        FlashComm and communication-computation fusion.
+        """
+        bias = self.bias if not self.skip_bias_add else None
+
+        # Matrix multiply.
+        assert self.quant_method is not None
+        input_ = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(input_, True)
+        output_parallel = self.quant_method.apply(self, input_, bias)
+        if self.gather_output:
+            # All-gather across the partitions.
+            output = self.comm_group.all_gather(output_parallel)
+        else:
+            output = output_parallel
+        output_bias = self.bias if self.skip_bias_add else None
+        if not self.return_bias:
+            return output
+        return output, output_bias
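`maybe_all_gather_and_maybe_unpad` is the counterpart on the column-parallel side: before the matmul it reassembles the full token dimension that the row-parallel layer left sharded. A sketch under the same assumptions as the previous one; the boolean mirrors the second argument the diff passes, and `orig_tokens` is bookkeeping the real op would presumably carry internally:

```python
import torch
import torch.distributed as dist


def maybe_all_gather_and_maybe_unpad_sketch(shard: torch.Tensor,
                                            gather: bool,
                                            orig_tokens: int) -> torch.Tensor:
    # Assumed inverse of the pad-and-reduce sketch: all-gather the token
    # shards, then drop the padding rows added before the reduce-scatter.
    if not gather:
        return shard
    world = dist.get_world_size()
    full = shard.new_empty((shard.shape[0] * world, *shard.shape[1:]))
    dist.all_gather_into_tensor(full, shard)
    return full[:orig_tokens]
```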
+class AscendQKVParallelLinear(QKVParallelLinear):
+    """Linear layers for the attention's QKV transformation.
+
+    Linear layers for the linear transformation of the query, key, and value
+    vectors in the attention layer. The weight matrix is concatenated along
+    the output dimension. The layer is parallelized along the head dimension.
+    When the number of key/value heads is smaller than the number of query
+    heads (e.g., multi-query/grouped-query attention), the key/value heads
+    may be replicated while the query heads are partitioned.
+    """
+
+    def __init__(
+        self,
+        hidden_size: int,
+        head_size: int,
+        total_num_heads: int,
+        total_num_kv_heads: Optional[int] = None,
+        bias: bool = True,
+        skip_bias_add: bool = False,
+        params_dtype: Optional[torch.dtype] = None,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+        *,
+        return_bias: bool = True,
+    ):
+        if dense_optim_enable():
+            self.forward_type = "dense_optim"
+        else:
+            self.forward_type = "normal_tp"
+        self.comm_group = get_tp_group()
+        self.hidden_size = hidden_size
+        self.head_size = head_size
+        self.total_num_heads = total_num_heads
+        if total_num_kv_heads is None:
+            total_num_kv_heads = total_num_heads
+        self.total_num_kv_heads = total_num_kv_heads
+        # Divide the weight matrix along the last dimension.
+        tp_size = self.comm_group.world_size
+        self.num_heads = divide(self.total_num_heads, tp_size)
+        if tp_size >= self.total_num_kv_heads:
+            self.num_kv_heads = 1
+            self.num_kv_head_replicas = divide(tp_size,
+                                               self.total_num_kv_heads)
+        else:
+            self.num_kv_heads = divide(self.total_num_kv_heads, tp_size)
+            self.num_kv_head_replicas = 1
+        input_size = self.hidden_size
+        output_size = (self.num_heads +
+                       2 * self.num_kv_heads) * tp_size * self.head_size
+        self.output_sizes = [
+            self.num_heads * self.head_size * tp_size,  # q_proj
+            self.num_kv_heads * self.head_size * tp_size,  # k_proj
+            self.num_kv_heads * self.head_size * tp_size,  # v_proj
+        ]
+        AscendColumnParallelLinear.__init__(self,
+                                            input_size=input_size,
+                                            output_size=output_size,
+                                            bias=bias,
+                                            gather_output=False,
+                                            skip_bias_add=skip_bias_add,
+                                            params_dtype=params_dtype,
+                                            quant_config=quant_config,
+                                            prefix=prefix,
+                                            return_bias=return_bias)
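The head-partitioning arithmetic above is easier to follow with concrete numbers. A grouped-query configuration (values chosen for illustration, not taken from a specific Qwen checkpoint):

```python
# 32 query heads, 8 KV heads, head_size 128, tensor parallel size 4
total_num_heads, total_num_kv_heads, head_size, tp_size = 32, 8, 128, 4

num_heads = total_num_heads // tp_size        # 8 query heads per rank
num_kv_heads = total_num_kv_heads // tp_size  # 2 KV heads per rank
output_sizes = [
    num_heads * head_size * tp_size,     # q_proj: 4096
    num_kv_heads * head_size * tp_size,  # k_proj: 1024
    num_kv_heads * head_size * tp_size,  # v_proj: 1024
]

# With tp_size=16 > 8 KV heads, each rank instead keeps num_kv_heads=1
# and every KV head is replicated num_kv_head_replicas = 16 // 8 = 2 times.
```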
+    def forward(
+        self,
+        input_,
+    ) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[Parameter]]]:
+        if self.forward_type == "dense_optim":
+            return self._forward_dense_optim(input_)
+        else:
+            return super().forward(input_)
+
+    def _forward_dense_optim(
+        self, input_: torch.Tensor
+    ) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[Parameter]]]:
+        """Linear layer with column parallelism.
+
+        Implements several optimizations for dense models, such as
+        FlashComm and communication-computation fusion.
+        """
+        bias = self.bias if not self.skip_bias_add else None
+
+        # Matrix multiply.
+        assert self.quant_method is not None
+        layer_num = self.prefix.split('.')[2]
+        input_ = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(
+            input_, layer_num != '0')
+        output_parallel = self.quant_method.apply(self, input_, bias)
+        if self.gather_output:
+            # All-gather across the partitions.
+            output = self.comm_group.all_gather(output_parallel)
+        else:
+            output = output_parallel
+        output_bias = self.bias if self.skip_bias_add else None
+        if not self.return_bias:
+            return output
+        return output, output_bias
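The `layer_num != '0'` condition parses the decoder-layer index out of the module prefix; reading the surrounding code, the first layer's input comes straight from the embedding and has not been reduce-scattered yet, so the all-gather is skipped there (this interpretation is an inference, not stated in the PR). Assuming vLLM's usual `model.layers.<i>....` prefixes:

```python
prefix = "model.layers.0.self_attn.qkv_proj"
prefix.split('.')[2]   # -> '0': first layer, input is already full

prefix = "model.layers.17.self_attn.qkv_proj"
prefix.split('.')[2]   # -> '17': gather the sharded activations first
```

Note that the positional split is fragile: it silently misbehaves if a model's module-naming scheme puts the layer index anywhere other than the third dotted component.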
class AscendLinearBase(LinearBase):
@@ -438,4 +611,4 @@ class AscendLinearBase(LinearBase):
         self.quant_method = quant_config.get_quant_method(self,
                                                           prefix=prefix)
         self.return_bias = return_bias
-        self.disable_tp = disable_tp
\ No newline at end of file
+        self.disable_tp = disable_tp