[main] flashcomm_v1 optim in Qwen Dense Models (#2802)
### What this PR does / why we need it?
Enables the FlashComm v1 (flashcomm_v1) communication optimization in Qwen dense models.
### Does this PR introduce _any_ user-facing change?
No.
### How was this patch tested?
CI passed with newly added and existing tests.
- vLLM version: v0.10.1.1
- vLLM main: 5e537f45b4
Co-authored-by: 1024daniel <xxltju324@gmail.com>
```diff
@@ -26,20 +26,18 @@ from torch.nn.parameter import Parameter
 from vllm.distributed import divide, split_tensor_along_last_dim
 from vllm.distributed.parallel_state import get_tp_group
 from vllm.lora.utils import LinearBase
-from vllm.model_executor.layers.linear import (WEIGHT_LOADER_V2_SUPPORTED,
-                                               ColumnParallelLinear,
-                                               MergedColumnParallelLinear,
-                                               QuantizeMethodBase,
-                                               RowParallelLinear,
-                                               UnquantizedLinearMethod)
+from vllm.model_executor.layers.linear import (  # noqa
+    WEIGHT_LOADER_V2_SUPPORTED, ColumnParallelLinear,
+    MergedColumnParallelLinear, QKVParallelLinear, QuantizeMethodBase,
+    RowParallelLinear, UnquantizedLinearMethod)
 from vllm.model_executor.layers.quantization.base_config import \
     QuantizationConfig
 from vllm.model_executor.utils import set_weight_attrs

 from vllm_ascend.distributed.parallel_state import (get_mlp_tp_group,
                                                     get_otp_group)
-from vllm_ascend.utils import (matmul_allreduce_enable, mlp_tp_enable,
-                               oproj_tp_enable)
+from vllm_ascend.utils import (dense_optim_enable, matmul_allreduce_enable,
+                               mlp_tp_enable, oproj_tp_enable)

 _HCOMM_INFO = None
```
```diff
@@ -150,6 +148,9 @@ class AscendRowParallelLinear(RowParallelLinear):
             comm_group = get_tp_group()
             self.forward_type = "matmul_allreduce"
             self.hcomm_info = self.get_hcomm_info(comm_group.device_group)
+        elif dense_optim_enable():
+            comm_group = get_tp_group()
+            self.forward_type = "dense_optim"
         else:
             comm_group = get_tp_group()
             self.forward_type = "normal"
```
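The new path is gated by `dense_optim_enable()` from `vllm_ascend.utils`; this diff does not show how that switch is implemented. A minimal sketch of such an env-var gate, with a hypothetical variable name:

```python
import os

# Hypothetical sketch of an env-var gate like dense_optim_enable();
# the actual flag name checked in vllm_ascend.utils is not shown here.
def dense_optim_enable() -> bool:
    return os.environ.get("VLLM_ASCEND_ENABLE_DENSE_OPTIM", "0") == "1"
```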
```diff
@@ -231,6 +232,8 @@ class AscendRowParallelLinear(RowParallelLinear):
             return self._forward_mlp_tp(input_)
         elif self.forward_type == "matmul_allreduce":
             return self._forward_matmul_allreduce(input_)
+        elif self.forward_type == "dense_optim":
+            return self._forward_dense_optim(input_)
         else:
             return super().forward(input_)
```
```diff
@@ -332,6 +335,39 @@ class AscendRowParallelLinear(RowParallelLinear):
             return output
         return output, output_bias
+
+    def _forward_dense_optim(
+            self, input_: torch.Tensor
+    ) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[Parameter]]]:
+        """Linear layer with row parallelism.
+
+        Implements multiple optimizations for dense models, such as
+        FlashComm and communication-computation fusion.
+        """
+        if self.input_is_parallel:
+            input_parallel = input_
+        else:
+            splitted_input = split_tensor_along_last_dim(
+                input_, num_partitions=self.tp_size)
+            input_parallel = splitted_input[self.tp_rank].contiguous()
+
+        assert self.quant_method is not None
+        bias_ = None if (self.tp_rank > 0 or self.skip_bias_add) else self.bias
+
+        if self.tp_size == 1 or not self.reduce_results:
+            output = self.quant_method.apply(self, input_parallel, bias=bias_)
+        else:
+            output_parallel = self.quant_method.apply(self,
+                                                      input_parallel,
+                                                      bias=bias_)
+            output = torch.ops.vllm.maybe_pad_and_reduce(output_parallel)
+
+        output_bias = self.bias if self.skip_bias_add else None
+
+        if not self.return_bias:
+            return output
+        return output, output_bias


 class AscendMergedColumnParallelLinear(MergedColumnParallelLinear):
     """Packed linear layers with column parallelism.
```
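On the row-parallel side, the FlashComm v1 pattern replaces the usual all-reduce after the partial matmul with a pad + reduce-scatter (`maybe_pad_and_reduce`), so each rank keeps only its 1/tp_size slice of the token dimension for the elementwise ops that follow. The custom op itself is not in this diff; a single-process sketch of the arithmetic it is assumed to perform:

```python
import torch

# Single-process sketch of the token-dimension arithmetic assumed behind
# torch.ops.vllm.maybe_pad_and_reduce: instead of all-reducing the full
# [num_tokens, hidden] partial outputs, the sequence is padded to a
# multiple of tp_size and reduce-scattered, so each rank keeps only its
# 1/tp_size slice of tokens.
tp_size, num_tokens, hidden = 4, 10, 8

# Partial row-parallel outputs, one per simulated rank.
partials = [torch.randn(num_tokens, hidden) for _ in range(tp_size)]

# Pad the token dimension so it divides evenly across ranks (10 -> 12).
pad = (-num_tokens) % tp_size
padded = [torch.nn.functional.pad(p, (0, 0, 0, pad)) for p in partials]

# Reduce-scatter: rank r ends up with the summed slice of its tokens.
summed = torch.stack(padded).sum(dim=0)
scattered = summed.chunk(tp_size, dim=0)

# All-gathering the slices reproduces the plain all-reduce result.
regathered = torch.cat(scattered, dim=0)[:num_tokens]
assert torch.allclose(regathered, sum(partials))
```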
```diff
@@ -357,15 +393,18 @@ class AscendMergedColumnParallelLinear(MergedColumnParallelLinear):
         *,
         return_bias: bool = True,
     ):
-        self.comm_group = None
         if prefix.find("gate_up_proj") != -1 and mlp_tp_enable():
-            self.comm_group = get_mlp_tp_group()
+            comm_group = get_mlp_tp_group()
             self.forward_type = "mlp_tp"
+        elif dense_optim_enable():
+            comm_group = get_tp_group()
+            self.forward_type = "dense_optim"
         else:
-            self.comm_group = get_tp_group()
+            comm_group = get_tp_group()
             self.forward_type = "normal_tp"
-        self.tp_rank = self.comm_group.rank_in_group
-        self.tp_size = self.comm_group.world_size
+        self.comm_group = comm_group
+        self.tp_rank = comm_group.rank_in_group
+        self.tp_size = comm_group.world_size

         self.output_sizes = output_sizes
         assert all(output_size % self.tp_size == 0
```
```diff
@@ -387,6 +426,8 @@ class AscendMergedColumnParallelLinear(MergedColumnParallelLinear):
     ) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[Parameter]]]:
         if self.forward_type == "mlp_tp":
             return self._forward_mlp_tp(input_)
+        elif self.forward_type == "dense_optim":
+            return self._forward_dense_optim(input_)
         else:
             return super().forward(input_)
```
```diff
@@ -405,6 +446,138 @@ class AscendMergedColumnParallelLinear(MergedColumnParallelLinear):
             return output
         return output, output_bias
+
+    def _forward_dense_optim(
+            self, input_: torch.Tensor
+    ) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[Parameter]]]:
+        """Linear layer with column parallelism.
+
+        Implements multiple optimizations for dense models, such as
+        FlashComm and communication-computation fusion.
+        """
+        bias = self.bias if not self.skip_bias_add else None
+
+        # Matrix multiply.
+        assert self.quant_method is not None
+
+        input_ = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(input_, True)
+        output_parallel = self.quant_method.apply(self, input_, bias)
+
+        if self.gather_output:
+            # All-gather across the partitions.
+            output = self.comm_group.all_gather(output_parallel)
+        else:
+            output = output_parallel
+        output_bias = self.bias if self.skip_bias_add else None
+        if not self.return_bias:
+            return output
+        return output, output_bias
```
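The column-parallel counterpart re-assembles the token-sharded activations before the matmul: `maybe_all_gather_and_maybe_unpad` gathers the slices produced by the previous reduce-scatter and drops the padding rows. A single-process sketch of that inverse step, under the same assumptions as above:

```python
import torch

# Single-process sketch of the inverse step assumed behind
# torch.ops.vllm.maybe_all_gather_and_maybe_unpad: the token-sharded
# hidden states left by a previous reduce-scatter are re-assembled and
# the padding rows are dropped before the column-parallel matmul.
tp_size, num_tokens, hidden = 4, 10, 8
pad = (-num_tokens) % tp_size

# Each simulated rank holds one slice of the padded token dimension.
full = torch.randn(num_tokens + pad, hidden)
shards = full.chunk(tp_size, dim=0)

# All-gather + unpad restores the original [num_tokens, hidden] input.
gathered = torch.cat(shards, dim=0)[:num_tokens]
assert torch.equal(gathered, full[:num_tokens])
```

The same hunk continues with the new `AscendQKVParallelLinear` class: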
```diff
+
+
+class AscendQKVParallelLinear(QKVParallelLinear):
+    """Linear layers for the attention's QKV transformation.
+
+    Linear layers for the linear transformation of the query, key, and value
+    vectors in the attention layer. The weight matrix is concatenated along
+    the output dimension. The layer is parallelized along the head dimension.
+    When the number of key/value heads is smaller than the number of query
+    heads (e.g., multi-query/grouped-query attention), the key/value head may
+    be replicated while the query heads are partitioned.
+    """
+
+    def __init__(
+        self,
+        hidden_size: int,
+        head_size: int,
+        total_num_heads: int,
+        total_num_kv_heads: Optional[int] = None,
+        bias: bool = True,
+        skip_bias_add: bool = False,
+        params_dtype: Optional[torch.dtype] = None,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+        *,
+        return_bias: bool = True,
+    ):
+        if dense_optim_enable():
+            self.forward_type = "dense_optim"
+        else:
+            self.forward_type = "normal_tp"
+        self.comm_group = get_tp_group()
+        self.hidden_size = hidden_size
+        self.head_size = head_size
+        self.total_num_heads = total_num_heads
+        if total_num_kv_heads is None:
+            total_num_kv_heads = total_num_heads
+        self.total_num_kv_heads = total_num_kv_heads
+        # Divide the weight matrix along the last dimension.
+        tp_size = self.comm_group.world_size
+        self.num_heads = divide(self.total_num_heads, tp_size)
+        if tp_size >= self.total_num_kv_heads:
+            self.num_kv_heads = 1
+            self.num_kv_head_replicas = divide(tp_size,
+                                               self.total_num_kv_heads)
+        else:
+            self.num_kv_heads = divide(self.total_num_kv_heads, tp_size)
+            self.num_kv_head_replicas = 1
+        input_size = self.hidden_size
+        output_size = (self.num_heads +
+                       2 * self.num_kv_heads) * tp_size * self.head_size
+        self.output_sizes = [
+            self.num_heads * self.head_size * tp_size,  # q_proj
+            self.num_kv_heads * self.head_size * tp_size,  # k_proj
+            self.num_kv_heads * self.head_size * tp_size,  # v_proj
+        ]
+        AscendColumnParallelLinear.__init__(self,
+                                            input_size=input_size,
+                                            output_size=output_size,
+                                            bias=bias,
+                                            gather_output=False,
+                                            skip_bias_add=skip_bias_add,
+                                            params_dtype=params_dtype,
+                                            quant_config=quant_config,
+                                            prefix=prefix,
+                                            return_bias=return_bias)
+
+    def forward(
+        self,
+        input_,
+    ) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[Parameter]]]:
+        if self.forward_type == "dense_optim":
+            return self._forward_dense_optim(input_)
+        else:
+            return super().forward(input_)
```
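For reference, the head-partitioning arithmetic in `__init__` works out as follows with illustrative, Qwen-like values (not taken from this PR):

```python
# Illustrative values: head_size 128, 32 query heads, 8 KV heads, tp 4.
tp_size, head_size = 4, 128
total_num_heads, total_num_kv_heads = 32, 8

num_heads = total_num_heads // tp_size        # 8 query heads per rank
num_kv_heads = total_num_kv_heads // tp_size  # 2 KV heads per rank
output_size = (num_heads + 2 * num_kv_heads) * tp_size * head_size

print(output_size)  # (8 + 4) * 4 * 128 = 6144 rows in the fused QKV weight
```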
```diff
+
+    def _forward_dense_optim(
+            self, input_: torch.Tensor
+    ) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[Parameter]]]:
+        """Linear layer with column parallelism.
+
+        Implements multiple optimizations for dense models, such as
+        FlashComm and communication-computation fusion.
+        """
+        bias = self.bias if not self.skip_bias_add else None
+
+        # Matrix multiply.
+        assert self.quant_method is not None
+
+        layer_num = self.prefix.split('.')[2]
+
+        input_ = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(
+            input_, layer_num != '0')
+        output_parallel = self.quant_method.apply(self, input_, bias)
+
+        if self.gather_output:
+            # All-gather across the partitions.
+            output = self.comm_group.all_gather(output_parallel)
+        else:
+            output = output_parallel
+        output_bias = self.bias if self.skip_bias_add else None
+        if not self.return_bias:
+            return output
+        return output, output_bias
+
+
 class AscendLinearBase(LinearBase):
```
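In the QKV path the gather is conditional: `self.prefix.split('.')[2]` extracts the decoder-layer index from a prefix of the form `model.layers.<n>...`, and the gather is skipped for layer 0, presumably because the first layer's input comes straight from the embedding and has not yet been token-scattered by a preceding reduce-scatter. A quick check of the parsing, with an illustrative prefix:

```python
# Illustrative prefix; real prefixes come from the model's module tree.
prefix = "model.layers.0.self_attn.qkv_proj"
layer_num = prefix.split('.')[2]
assert layer_num == '0'  # layer 0 skips the all-gather
```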
```diff
@@ -438,4 +611,4 @@ class AscendLinearBase(LinearBase):
         self.quant_method = quant_config.get_quant_method(self,
                                                           prefix=prefix)
         self.return_bias = return_bias
-        self.disable_tp = disable_tp
+        self.disable_tp = disable_tp
```