### What this PR does / why we need it?
**Scope of Changes**:
| File Path |
| :--- |
| `vllm_ascend/ops/layer_shard_linear.py` |
| `vllm_ascend/ops/linear.py` |
| `vllm_ascend/ops/linear_op.py` |
| `vllm_ascend/worker/worker.py` |
| `vllm_ascend/patch/worker/patch_bert.py` |
| `vllm_ascend/patch/worker/patch_deepseek.py` |
| `vllm_ascend/patch/worker/patch_distributed.py` |
| `vllm_ascend/patch/worker/patch_module.py` |
| `vllm_ascend/patch/worker/patch_multimodal_merge.py` |
| `vllm_ascend/patch/worker/patch_qwen3_next.py` |
| `vllm_ascend/patch/worker/patch_qwen3_next_mtp.py` |
| `vllm_ascend/patch/worker/patch_rejection_sampler.py` |
| `vllm_ascend/patch/worker/patch_rope.py` |
| `vllm_ascend/patch/worker/patch_triton.py` |
| `vllm_ascend/patch/worker/patch_unquantized_gemm.py` |
| `vllm_ascend/patch/worker/patch_v2_egale.py` |
| `vllm_ascend/worker/npu_input_batch.py` |
| `vllm_ascend/worker/v2/aclgraph_utils.py` |
| `vllm_ascend/worker/v2/attn_utils.py` |
| `vllm_ascend/worker/v2/model_runner.py` |
| `vllm_ascend/worker/v2/sample/gumbel.py` |
| `vllm_ascend/worker/v2/sample/penalties.py` |
| `vllm_ascend/worker/v2/sample/sampler.py` |
| `vllm_ascend/worker/v2/spec_decode/__init__.py` |
| `vllm_ascend/worker/v2/spec_decode/eagle.py` |
| `vllm_ascend/worker/v2/states.py` |
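
The diff below, from `vllm_ascend/ops/linear_op.py`, is representative of the change: `Optional[Union[...]]` annotations become PEP 604 unions (`X | Y | None`), and calls the old formatter wrapped one argument per line are collapsed where they fit. A minimal before/after sketch (the function and names here are illustrative, not taken from the diff):

```python
import torch
from torch.nn.parameter import Parameter


# Before: `-> Union[torch.Tensor, tuple[torch.Tensor, Optional[Parameter]]]`,
# with the call below wrapped one argument per line by the old formatter.
def apply(x: torch.Tensor, bias: Parameter | None = None) -> torch.Tensor | tuple[torch.Tensor, Parameter | None]:
    # Identity weight keeps this sketch runnable; the real ops use layer weights.
    return torch.nn.functional.linear(x, torch.eye(x.size(-1)), bias)


out = apply(torch.ones(2, 4))
```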
### Does this PR introduce _any_ user-facing change?
### How was this patch tested?
- vLLM version: v0.14.0
- vLLM main: d68209402d
Signed-off-by: MrZ20 <2609716663@qq.com>
Signed-off-by: SILONG ZENG <2609716663@qq.com>
Signed-off-by: wangxiyuan <wangxiyuan1007@gmail.com>
Co-authored-by: wangxiyuan <wangxiyuan1007@gmail.com>
vllm_ascend/ops/linear_op.py

@@ -31,16 +31,18 @@ CustomLinearOp
 └── CustomReplicatedOp
 
 How to extend a new linear op? Taking column parallel op as an example:
 1. Inherit from CustomColumnParallelOp and create a new class MyColumnParallelOp
-2. [Optional] The default communication group is the TP group. If a custom communication group is needed, override the comm_group method
+2. [Optional] The default communication group is the TP group. If a custom communication group is needed,
+   override the comm_group method
 3. Override the apply method according to requirements, which will replace the original linear.forward
-4. Add selection logic for MyColumnParallelOp in the get_column_parallel_op method, typically based on prefix and configuration judgments
-Row parallel op follows a similar approach - inherit from RowColumnParallelOp and register the new class in get_row_parallel_op.
+4. Add selection logic for MyColumnParallelOp in the get_column_parallel_op method, typically based on
+   prefix and configuration judgments
+Row parallel op follows a similar approach - inherit from RowColumnParallelOp and register the new class in
+get_row_parallel_op.
 """
 
 import re
 from functools import lru_cache
 from types import SimpleNamespace
-from typing import Optional, Union
 
 import torch
 import torch.distributed as dist
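
The four steps in this docstring map onto a small subclass. A minimal sketch, assuming the module paths shown in the scope table and that subclasses hook `apply_impl` (as the concrete ops later in this diff do); the `"my_proj"` prefix and the MLP-TP group choice are illustrative only:

```python
from vllm_ascend.distributed.parallel_state import get_mlp_tp_group
from vllm_ascend.ops.linear_op import CustomColumnParallelOp


class MyColumnParallelOp(CustomColumnParallelOp):  # step 1: inherit
    @property
    def comm_group(self):  # step 2 (optional): use a custom group instead of TP
        return get_mlp_tp_group()

    def apply_impl(self, input_):  # step 3: replaces the original linear.forward
        bias = self.bias if not self.skip_bias_add else None
        assert self.quant_method is not None
        output = self.quant_method.apply(self.layer, input_, bias)
        output_bias = self.bias if self.skip_bias_add else None
        return output, output_bias


# step 4 (sketch): select the new op in _get_column_parallel_op, e.g.
#   if "my_proj" in prefix:
#       return MyColumnParallelOp(layer)
```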
@@ -49,27 +51,37 @@ import torch_npu
 from torch import nn
 from torch.distributed import ProcessGroup
 from torch.nn.parameter import Parameter
-from vllm.distributed import (split_tensor_along_last_dim,
-                              tensor_model_parallel_all_reduce,
-                              tensor_model_parallel_reduce_scatter)
+from vllm.distributed import (
+    split_tensor_along_last_dim,
+    tensor_model_parallel_all_reduce,
+    tensor_model_parallel_reduce_scatter,
+)
 from vllm.distributed.parallel_state import get_tp_group
 from vllm.forward_context import get_forward_context
 
 from vllm_ascend import envs as envs_ascend
 from vllm_ascend.ascend_config import get_ascend_config
-from vllm_ascend.distributed.parallel_state import (get_flashcomm2_odp_group,
-                                                    get_flashcomm2_otp_group,
-                                                    get_mlp_tp_group,
-                                                    get_otp_group)
+from vllm_ascend.distributed.parallel_state import (
+    get_flashcomm2_odp_group,
+    get_flashcomm2_otp_group,
+    get_mlp_tp_group,
+    get_otp_group,
+)
 from vllm_ascend.ops.flashcomm2_oshard_manager import flashcomm2_oshard_manager
-from vllm_ascend.utils import (enable_dsa_cp, enable_dsa_cp_with_layer_shard, enable_sp, flashcomm2_enable,
-                               get_flashcomm2_reorgnized_batch_ids,
-                               matmul_allreduce_enable, mlp_tp_enable,
-                               oproj_tp_enable, shared_expert_dp_enabled,
-                               get_weight_prefetch_method)
+from vllm_ascend.utils import (
+    enable_dsa_cp,
+    enable_dsa_cp_with_layer_shard,
+    enable_sp,
+    flashcomm2_enable,
+    get_flashcomm2_reorgnized_batch_ids,
+    get_weight_prefetch_method,
+    matmul_allreduce_enable,
+    mlp_tp_enable,
+    oproj_tp_enable,
+    shared_expert_dp_enabled,
+)
 
 
 class CustomLinearOp:
-
     def __init__(self, layer):
         self.layer = layer
         self.bias = None
@@ -112,7 +124,6 @@ class CustomLinearOp:
 
 
 class CustomColumnParallelOp(CustomLinearOp):
-
     def __init__(self, layer):
         super().__init__(layer)
         self.gather_output = None
@@ -123,7 +134,6 @@ class CustomColumnParallelOp(CustomLinearOp):
 
 
 class CustomRowParallelOp(CustomLinearOp):
-
     def __init__(self, layer):
         super().__init__(layer)
         self.reduce_results = None
@@ -140,7 +150,9 @@ class CustomRowParallelOp(CustomLinearOp):
         output, output_bias = self.apply_impl(input_)
         weight_prefetch_method = get_weight_prefetch_method()
         if weight_prefetch_method:
-            weight_prefetch_method.maybe_prefetch_mlp_weight_preprocess(weight_prefetch_method.MLP_GATE_UP, output, self.prefix)
+            weight_prefetch_method.maybe_prefetch_mlp_weight_preprocess(
+                weight_prefetch_method.MLP_GATE_UP, output, self.prefix
+            )
 
         if not self.return_bias:
             return output
@@ -148,7 +160,6 @@ class CustomRowParallelOp(CustomLinearOp):
 
 
 class CustomReplicatedOp(CustomLinearOp):
-
     def apply_impl(self, input_):
         bias = self.bias if not self.skip_bias_add else None
         assert self.quant_method is not None
@@ -160,7 +171,6 @@ class CustomReplicatedOp(CustomLinearOp):
 
 
 class MLPColumnParallelOp(CustomColumnParallelOp):
-
     def __init__(self, layer):
         super().__init__(layer)
 
@@ -171,7 +181,7 @@ class MLPColumnParallelOp(CustomColumnParallelOp):
     def apply_impl(
         self,
         input_: torch.Tensor,
-    ) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[Parameter]]]:
+    ) -> torch.Tensor | tuple[torch.Tensor, Parameter | None]:
         bias = self.bias if not self.skip_bias_add else None
         # Matrix multiply.
         assert self.quant_method is not None
@@ -183,7 +193,6 @@ class MLPColumnParallelOp(CustomColumnParallelOp):
 
 
 class MLPRowParallelOp(CustomRowParallelOp):
-
     def __init__(self, layer):
         super().__init__(layer)
 
@@ -191,22 +200,16 @@ class MLPRowParallelOp(CustomRowParallelOp):
     def comm_group(self):
         return get_mlp_tp_group()
 
-    def apply_impl(
-        self, input_: torch.Tensor
-    ) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[Parameter]]]:
+    def apply_impl(self, input_: torch.Tensor) -> torch.Tensor | tuple[torch.Tensor, Parameter | None]:
         if self.input_is_parallel:
             input_parallel = input_
         else:
-            splitted_input = split_tensor_along_last_dim(
-                input_, num_partitions=self.tp_size)
+            splitted_input = split_tensor_along_last_dim(input_, num_partitions=self.tp_size)
             input_parallel = splitted_input[self.tp_rank].contiguous()
 
         assert self.quant_method is not None
-        bias_ = None if (self.tp_rank > 0
-                         or self.skip_bias_add) else self.layer.bias
-        output_parallel = self.quant_method.apply(self.layer,
-                                                  input_parallel,
-                                                  bias=bias_)
+        bias_ = None if (self.tp_rank > 0 or self.skip_bias_add) else self.layer.bias
+        output_parallel = self.quant_method.apply(self.layer, input_parallel, bias=bias_)
         output = self.comm_group.reduce_scatter(output_parallel, 0)
 
         output_bias = self.bias if self.skip_bias_add else None
@@ -214,7 +217,6 @@ class MLPRowParallelOp(CustomRowParallelOp):
 
 
 class OProjRowParallelOp(CustomRowParallelOp):
-
     def __init__(self, layer):
         super().__init__(layer)
 
@@ -225,13 +227,11 @@ class OProjRowParallelOp(CustomRowParallelOp):
     def apply_impl(
         self,
         input_: torch.Tensor,
-    ) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[Parameter]]]:
-
+    ) -> torch.Tensor | tuple[torch.Tensor, Parameter | None]:
         if self.input_is_parallel:
             input_parallel = input_
         else:
-            splitted_input = split_tensor_along_last_dim(
-                input_, num_partitions=self.tp_size)
+            splitted_input = split_tensor_along_last_dim(input_, num_partitions=self.tp_size)
             input_parallel = splitted_input[self.tp_rank].contiguous()
 
         # Prepare tensors for all-to-all communication
@@ -241,27 +241,19 @@ class OProjRowParallelOp(CustomRowParallelOp):
 
         # Reshape tensor for efficient cross-device transfer:
         # [batch, dim] -> [tp_size, batch, chunk] -> flattened
-        send_buf = (input_parallel.reshape(-1,
-                                           self.tp_size, chunk_size).transpose(
-                                               0, 1).contiguous().view(-1))
+        send_buf = input_parallel.reshape(-1, self.tp_size, chunk_size).transpose(0, 1).contiguous().view(-1)
 
         # Create receive buffer
-        recv_buf = torch.empty(total_batch_size * chunk_size,
-                               dtype=input_parallel.dtype,
-                               device=input_parallel.device)
+        recv_buf = torch.empty(total_batch_size * chunk_size, dtype=input_parallel.dtype, device=input_parallel.device)
 
         # Perform all-to-all communication
-        dist.all_to_all_single(recv_buf,
-                               send_buf,
-                               group=self.comm_group.device_group)
+        dist.all_to_all_single(recv_buf, send_buf, group=self.comm_group.device_group)
         input_parallel = recv_buf.view(total_batch_size, chunk_size)
 
         # Only fuse bias add for rank 0 to avoid duplicate bias addition in TP>1
         bias_ = None if (self.tp_rank > 0 or self.skip_bias_add) else self.bias
         assert self.quant_method is not None
-        output_parallel = self.quant_method.apply(self.layer,
-                                                  input_parallel,
-                                                  bias=bias_)
+        output_parallel = self.quant_method.apply(self.layer, input_parallel, bias=bias_)
 
         # otp-specific: Combine partial results across devices
         output = self.comm_group.reduce_scatter(output_parallel, dim=0)
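
The send-buffer reshape above can be checked in isolation. A small sketch with illustrative sizes (`tp_size=2`, `batch=4`, `chunk=3` are made up for the demonstration):

```python
import torch

tp_size, batch, chunk = 2, 4, 3
x = torch.arange(batch * tp_size * chunk, dtype=torch.float32).reshape(batch, tp_size * chunk)
# [batch, dim] -> [tp_size, batch, chunk] -> flattened, as in the diff above.
send_buf = x.reshape(-1, tp_size, chunk).transpose(0, 1).contiguous().view(-1)
assert send_buf.shape == (batch * tp_size * chunk,)
# Rank r's contiguous slice now holds every row's r-th chunk, ready for all_to_all.
assert torch.equal(send_buf[: batch * chunk].view(batch, chunk), x[:, :chunk])
```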
@@ -278,14 +270,12 @@ class OProjRowParallelOp(CustomRowParallelOp):
 
 
 class Flashcomm2OProjRowParallelOp(CustomRowParallelOp):
-
     def __init__(self, layer):
         super().__init__(layer)
         self.odp_group = get_flashcomm2_odp_group()
         self.odp_size = self.odp_group.world_size
         self.otp_size = get_ascend_config().flashcomm2_oproj_tensor_parallel_size
-        self.reorgnized_batch_ids = get_flashcomm2_reorgnized_batch_ids(
-            get_tp_group().world_size)
+        self.reorgnized_batch_ids = get_flashcomm2_reorgnized_batch_ids(get_tp_group().world_size)
         self.group_indices = torch.tensor(self.reorgnized_batch_ids).npu()
         self.layer._quant_comm_config = {}
 
@@ -308,32 +298,28 @@ class Flashcomm2OProjRowParallelOp(CustomRowParallelOp):
     def apply_impl(
         self,
         input_: torch.Tensor,
-    ) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[Parameter]]]:
+    ) -> torch.Tensor | tuple[torch.Tensor, Parameter | None]:
         """Linear layer for Flashcomm2.
-           Input.ahspe = [batchsize*seqlength, headnum*headdim/TP]
-           Output.shape = [(batchsize*seqlength+padsize)/TP, hiddensize]
+        Input.ahspe = [batchsize*seqlength, headnum*headdim/TP]
+        Output.shape = [(batchsize*seqlength+padsize)/TP, hiddensize]
         """
         # Handle input parallelism - split or use as-is
         if self.input_is_parallel:
             input_parallel = input_
         else:
             tp_rank = self.tp_rank
-            splitted_input = split_tensor_along_last_dim(
-                input_, num_partitions=self.tp_size)
+            splitted_input = split_tensor_along_last_dim(input_, num_partitions=self.tp_size)
             input_parallel = splitted_input[tp_rank].contiguous()
 
         # padding for all-to-all
         forward_context = get_forward_context()
         num_padding_tokens = forward_context.pad_size
         if num_padding_tokens > 0:
-            input_parallel = nn.functional.pad(input_parallel,
-                                               (0, 0, 0, num_padding_tokens))
+            input_parallel = nn.functional.pad(input_parallel, (0, 0, 0, num_padding_tokens))
 
         def otp_maybe_quant_comm(x):
-
             # Reorganize the tensor so that the batch id and rank id correspond to each other.
-            chunk_num = len(self.reorgnized_batch_ids) * len(
-                self.reorgnized_batch_ids[0])
+            chunk_num = len(self.reorgnized_batch_ids) * len(self.reorgnized_batch_ids[0])
             batch_size = x.size(0)
 
             assert batch_size % chunk_num == 0, f"Batch_size({batch_size}) must be divisible by chunk_num({chunk_num})"
@@ -352,26 +338,19 @@ class Flashcomm2OProjRowParallelOp(CustomRowParallelOp):
             total_intermediate_size = local_intermediate_size * all2all_tp_size
 
             # Create receive buffer
-            recv_buf = torch.empty(total_intermediate_size * chunk_size,
-                                   dtype=x.dtype,
-                                   device=x.device)
+            recv_buf = torch.empty(total_intermediate_size * chunk_size, dtype=x.dtype, device=x.device)
 
             # Perform all-to-all communication
-            dist.all_to_all_single(recv_buf,
-                                   send_buf,
-                                   group=self.odp_group.device_group)
+            dist.all_to_all_single(recv_buf, send_buf, group=self.odp_group.device_group)
 
-            return recv_buf.view(all2all_tp_size, chunk_size,
-                                 -1).transpose(0, 1).reshape(chunk_size, -1)
+            return recv_buf.view(all2all_tp_size, chunk_size, -1).transpose(0, 1).reshape(chunk_size, -1)
 
         if not hasattr(self, "_quant_comm_config"):
             self.layer._quant_comm_config = {}
-        self.layer._quant_comm_config[
-            "communication_fn"] = otp_maybe_quant_comm
-        actual_quant_method = getattr(self.quant_method, 'quant_method',
-                                      self.quant_method)
-        from vllm_ascend.quantization.methods.w8a8_static import \
-            AscendW8A8LinearMethod
+        self.layer._quant_comm_config["communication_fn"] = otp_maybe_quant_comm
+        actual_quant_method = getattr(self.quant_method, "quant_method", self.quant_method)
+        from vllm_ascend.quantization.methods.w8a8_static import AscendW8A8LinearMethod
 
         if not isinstance(actual_quant_method, AscendW8A8LinearMethod):
             # Check if w8a8 quantization is enabled. If not, communicate immediately.
             input_parallel = otp_maybe_quant_comm(input_parallel)
@@ -382,9 +361,7 @@ class Flashcomm2OProjRowParallelOp(CustomRowParallelOp):
         # bias will not get added more than once in TP>1 case)
         bias_ = None if (self.tp_rank > 0 or self.skip_bias_add) else self.bias
 
-        output_parallel = self.quant_method.apply(self.layer,
-                                                  input_parallel,
-                                                  bias=bias_)
+        output_parallel = self.quant_method.apply(self.layer, input_parallel, bias=bias_)
         # output_parallel shape: [bs/(TP/flashcomm2_otp_size), hiddenstate]
         if self.tp_size > 1:
             # flashcomm2 with reduce-scatter
@@ -408,8 +385,7 @@ class Flashcomm2OProjRowParallelOp(CustomRowParallelOp):
         self.input_is_parallel = self.layer.input_is_parallel
         self.input_size_per_partition = self.layer.input_size_per_partition
         if flashcomm2_oshard_manager.flashcomm2_oshard_enable():
-            flashcomm2_oshard_manager.register_layer(self.layer,
-                                                     prefetch_step=1)
+            flashcomm2_oshard_manager.register_layer(self.layer, prefetch_step=1)
 
 
 class MatmulAllreduceRowParallelOp(CustomRowParallelOp):
@@ -419,28 +395,22 @@ class MatmulAllreduceRowParallelOp(CustomRowParallelOp):
         super().__init__(layer)
         self.hcomm_info = self.get_hcomm_info(self.comm_group.device_group)
 
-    def apply_impl(
-        self, input_: torch.Tensor
-    ) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[Parameter]]]:
+    def apply_impl(self, input_: torch.Tensor) -> torch.Tensor | tuple[torch.Tensor, Parameter | None]:
         if self.input_is_parallel:
             input_parallel = input_
         else:
-            splitted_input = split_tensor_along_last_dim(
-                input_, num_partitions=self.tp_size)
+            splitted_input = split_tensor_along_last_dim(input_, num_partitions=self.tp_size)
             input_parallel = splitted_input[self.tp_rank].contiguous()
         """Calculate the output tensor of forward by considering
         fusing communication and computation."""
         bias_ = None if (self.tp_rank > 0 or self.skip_bias_add) else self.bias
         if self.reduce_results and self.tp_size > 1:
-            output = torch_npu.npu_mm_all_reduce_base(input_parallel,
-                                                      self.layer.weight.t(),
-                                                      self.hcomm_info,
-                                                      bias=bias_)
+            output = torch_npu.npu_mm_all_reduce_base(
+                input_parallel, self.layer.weight.t(), self.hcomm_info, bias=bias_
+            )
         else:
             assert self.quant_method is not None
-            output = self.quant_method.apply(self.layer,
-                                             input_parallel,
-                                             bias=bias_)
+            output = self.quant_method.apply(self.layer, input_parallel, bias=bias_)
 
         output_bias = self.bias if self.skip_bias_add else None
         return output, output_bias
@@ -454,18 +424,14 @@ class MatmulAllreduceRowParallelOp(CustomRowParallelOp):
         rank = torch.distributed.get_rank(group)
         if torch.__version__ > "2.0":
             global_rank = torch.distributed.get_global_rank(group, rank)
-            cls._HCOMM_INFO = group._get_backend(
-                torch.device("npu")).get_hccl_comm_name(global_rank)
+            cls._HCOMM_INFO = group._get_backend(torch.device("npu")).get_hccl_comm_name(global_rank)
         else:
             cls._HCOMM_INFO = group.get_hccl_comm_name(rank)
         return cls._HCOMM_INFO
 
 
 class SequenceColumnParallelOp(CustomColumnParallelOp):
-
-    def apply_impl(
-        self, input_: torch.Tensor
-    ) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[Parameter]]]:
+    def apply_impl(self, input_: torch.Tensor) -> torch.Tensor | tuple[torch.Tensor, Parameter | None]:
         """Linear layer with column parallelism.
 
         Implemented multiple optimization projects for dense models, such as FlashComm and
@@ -490,13 +456,10 @@ class SequenceColumnParallelOp(CustomColumnParallelOp):
 
 
 class Flashcomm2OshardQKVParallelOp(CustomColumnParallelOp):
-
     def __init__(self, layer):
         super().__init__(layer)
 
-    def apply_impl(
-        self, input_: torch.Tensor
-    ) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[Parameter]]]:
+    def apply_impl(self, input_: torch.Tensor) -> torch.Tensor | tuple[torch.Tensor, Parameter | None]:
         """Column-parallel linear with FlashComm2 OShard optimization."""
 
         bias = self.bias if not self.skip_bias_add else None
@@ -505,12 +468,10 @@ class Flashcomm2OshardQKVParallelOp(CustomColumnParallelOp):
         assert self.quant_method is not None
 
         if enable_sp():
-            input_ = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(
-                input_, True)
+            input_ = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(input_, True)
 
         # Trigger async broadcast before matmul to overlap communication.
-        flashcomm2_oshard_manager.trigger_broadcast_for_layer(
-            self.layer.prefix)
+        flashcomm2_oshard_manager.trigger_broadcast_for_layer(self.layer.prefix)
 
         output_parallel = self.quant_method.apply(self.layer, input_, bias)
         if self.gather_output and self.tp_size > 1:
@@ -523,14 +484,11 @@ class Flashcomm2OshardQKVParallelOp(CustomColumnParallelOp):
 
 
 class SequenceRowParallelOp(CustomRowParallelOp):
-
     def __init__(self, layer):
         super().__init__(layer)
         self.unique_prefix = None
 
-    def apply_impl(
-        self, input_: torch.Tensor
-    ) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[Parameter]]]:
+    def apply_impl(self, input_: torch.Tensor) -> torch.Tensor | tuple[torch.Tensor, Parameter | None]:
         """Linear layer with column parallelism.
 
         Implemented multiple optimization projects for dense models, such as FlashComm and
@@ -540,26 +498,21 @@ class SequenceRowParallelOp(CustomRowParallelOp):
         if self.input_is_parallel:
             input_parallel = input_
         else:
-            splitted_input = split_tensor_along_last_dim(
-                input_, num_partitions=self.tp_size)
+            splitted_input = split_tensor_along_last_dim(input_, num_partitions=self.tp_size)
             input_parallel = splitted_input[self.tp_rank].contiguous()
 
         assert self.quant_method is not None
         bias_ = None if (self.tp_rank > 0 or self.skip_bias_add) else self.bias
 
         if self.tp_size == 1 or not self.reduce_results:
-            output = self.quant_method.apply(self.layer,
-                                             input_parallel,
-                                             bias=bias_)
+            output = self.quant_method.apply(self.layer, input_parallel, bias=bias_)
         else:
-            output = torch.ops.vllm.matmul_and_reduce(input_parallel,
-                                                      self.unique_prefix)
+            output = torch.ops.vllm.matmul_and_reduce(input_parallel, self.unique_prefix)
 
         output_bias = self.bias if self.skip_bias_add else None
         return output, output_bias
 
-    def matmul_and_reduce(self, input_parallel: torch.Tensor,
-                          bias_: Optional[Parameter]) -> torch.Tensor:
+    def matmul_and_reduce(self, input_parallel: torch.Tensor, bias_: Parameter | None) -> torch.Tensor:
         assert self.quant_method is not None
         try:
             forward_context = get_forward_context()
@@ -572,29 +525,24 @@ class SequenceRowParallelOp(CustomRowParallelOp):
             x = input_parallel
 
         if not sp_enabled:
-            output_parallel = self.layer.quant_method.apply(self.layer,
-                                                            x,
-                                                            bias=bias_)
+            output_parallel = self.layer.quant_method.apply(self.layer, x, bias=bias_)
             return tensor_model_parallel_all_reduce(output_parallel)
 
         pad_size = forward_context.pad_size
-        if pad_size > 0 and not (enable_dsa_cp()
-                                 and "o_proj" in self.layer.prefix):
+        if pad_size > 0 and not (enable_dsa_cp() and "o_proj" in self.layer.prefix):
            x = F.pad(x, (0, 0, 0, pad_size))
 
         world_size = self.layer.tp_size
         comm_mode = "aiv"
-        hcom_name = get_tp_group().device_group._get_backend(
-            torch.device('npu')).get_hccl_comm_name(self.layer.tp_rank)
+        hcom_name = get_tp_group().device_group._get_backend(torch.device("npu")).get_hccl_comm_name(self.layer.tp_rank)
 
         from vllm.model_executor.layers.linear import UnquantizedLinearMethod
 
-        from vllm_ascend.quantization.methods import AscendW8A8LinearMethod
         from vllm_ascend.quantization.method_adapters import AscendLinearMethod
+        from vllm_ascend.quantization.methods import AscendW8A8LinearMethod
 
         # For unquant
-        if mmrs_fusion and isinstance(self.layer.quant_method,
-                                      UnquantizedLinearMethod):
+        if mmrs_fusion and isinstance(self.layer.quant_method, UnquantizedLinearMethod):
             output = torch_npu.npu_mm_reduce_scatter_base(
                 x,
                 self.layer.weight.t(),
@@ -603,19 +551,22 @@ class SequenceRowParallelOp(CustomRowParallelOp):
                 reduce_op="sum",
                 bias=None,
                 comm_turn=0,
-                comm_mode=comm_mode)
+                comm_mode=comm_mode,
+            )
             if bias_ is not None:
                 output.add_(bias_)
         # For w8a8 quant
         elif mmrs_fusion and (
-                isinstance(self.layer.quant_method, AscendLinearMethod)
-                and isinstance(self.layer.quant_method.quant_method,
-                               AscendW8A8LinearMethod)):
+            isinstance(self.layer.quant_method, AscendLinearMethod)
+            and isinstance(self.layer.quant_method.quant_method, AscendW8A8LinearMethod)
+        ):
             if x.dtype != torch.int8:
                 x_quant = torch.ops.vllm.quantize(
-                    x, self.layer.aclnn_input_scale,
+                    x,
+                    self.layer.aclnn_input_scale,
                     self.layer.aclnn_input_scale_reciprocal,
-                    self.layer.aclnn_input_offset)
+                    self.layer.aclnn_input_offset,
+                )
             else:
                 x_quant = x
             quant_bias = self.layer.quant_bias
@@ -631,14 +582,11 @@ class SequenceRowParallelOp(CustomRowParallelOp):
                 comm_turn=0,
                 x2_scale=deq_scale,
                 output_dtype=output_dtype,
-                comm_mode=comm_mode)
-            output = torch.add(
-                output,
-                torch.mul(quant_bias, deq_scale).to(self.layer.params_dtype))
+                comm_mode=comm_mode,
+            )
+            output = torch.add(output, torch.mul(quant_bias, deq_scale).to(self.layer.params_dtype))
         else:
-            output_parallel = self.layer.quant_method.apply(self.layer,
-                                                            x,
-                                                            bias=bias_)
+            output_parallel = self.layer.quant_method.apply(self.layer, x, bias=bias_)
             output = tensor_model_parallel_reduce_scatter(output_parallel, 0)
 
         return output
@@ -651,13 +599,10 @@ class SequenceRowParallelOp(CustomRowParallelOp):
 
 
 class ShardedCPRowParallelOp(CustomRowParallelOp):
-
     @property
     def comm_group(self):
         # fake comm group to bypass tp logic
-        return SimpleNamespace(world_size=1,
-                               rank_in_group=0,
-                               device_group=None)
+        return SimpleNamespace(world_size=1, rank_in_group=0, device_group=None)
 
     def apply_impl(
         self,
@@ -677,13 +622,10 @@ class ShardedCPRowParallelOp(CustomRowParallelOp):
 
 
 class ShardedCPColumnParallelOp(CustomColumnParallelOp):
-
     @property
     def comm_group(self):
         # fake comm group to bypass tp logic
-        return SimpleNamespace(world_size=1,
-                               rank_in_group=0,
-                               device_group=None)
+        return SimpleNamespace(world_size=1, rank_in_group=0, device_group=None)
 
     def apply_impl(
         self,
@@ -700,12 +642,10 @@ class ShardedCPColumnParallelOp(CustomColumnParallelOp):
 
 def _get_column_parallel_op(
     prefix, layer
-) -> Optional[Union[MLPColumnParallelOp, SequenceColumnParallelOp,
-                    ShardedCPColumnParallelOp, Flashcomm2OshardQKVParallelOp]]:
+) -> MLPColumnParallelOp | SequenceColumnParallelOp | ShardedCPColumnParallelOp | Flashcomm2OshardQKVParallelOp | None:
     if enable_dsa_cp() and ("q_b_proj" in prefix or "kv_b_proj" in prefix):
         return ShardedCPColumnParallelOp(layer)
-    if "gate_up_proj" in prefix and mlp_tp_enable(
-    ) and not is_moe_layer(prefix):
+    if "gate_up_proj" in prefix and mlp_tp_enable() and not is_moe_layer(prefix):
         return MLPColumnParallelOp(layer)
     if flashcomm2_oshard_manager.flashcomm2_oshard_enable():
         if any(p in prefix for p in ("qkv_proj", "conv1d", "query_key_value")):
@@ -714,7 +654,7 @@ def _get_column_parallel_op(
     if "shared_expert" in prefix:
         return None
     sp_column_prefix = [
-        "gate_up_proj",  # first MLP of most LLMs
+        "gate_up_proj",  # first MLP of most LLMs
         "in_proj",  # gated deltanet of Qwen3 Next
         "qkv_proj",  # qkv linear of most LLMs
         "conv1d",  # gated deltanet of Qwen3 Next
@@ -729,9 +669,15 @@ def _get_column_parallel_op(
 
 def _get_row_parallel_op(
     prefix, layer
-) -> Optional[Union[MLPRowParallelOp, OProjRowParallelOp,
-                    Flashcomm2OProjRowParallelOp, MatmulAllreduceRowParallelOp,
-                    SequenceRowParallelOp, ShardedCPRowParallelOp]]:
+) -> (
+    MLPRowParallelOp
+    | OProjRowParallelOp
+    | Flashcomm2OProjRowParallelOp
+    | MatmulAllreduceRowParallelOp
+    | SequenceRowParallelOp
+    | ShardedCPRowParallelOp
+    | None
+):
     if enable_dsa_cp_with_layer_shard() and "o_proj" in prefix:
         return ShardedCPRowParallelOp(layer)
     if "down_proj" in prefix and mlp_tp_enable() and not is_moe_layer(prefix):
@@ -760,16 +706,21 @@ def _get_row_parallel_op(
 
 
 def get_parallel_op(disable_tp, prefix, layer, direct):
-    if disable_tp or ("shared_experts" in prefix
-                      and shared_expert_dp_enabled()):
+    if disable_tp or ("shared_experts" in prefix and shared_expert_dp_enabled()):
         return None, 0, 1
-    custom_op: Optional[Union[MLPColumnParallelOp, SequenceColumnParallelOp,
-                              MLPRowParallelOp, OProjRowParallelOp,
-                              Flashcomm2OProjRowParallelOp,
-                              Flashcomm2OshardQKVParallelOp,
-                              MatmulAllreduceRowParallelOp,
-                              SequenceRowParallelOp, ShardedCPRowParallelOp,
-                              ShardedCPColumnParallelOp]] = None
+    custom_op: (
+        MLPColumnParallelOp
+        | SequenceColumnParallelOp
+        | MLPRowParallelOp
+        | OProjRowParallelOp
+        | Flashcomm2OProjRowParallelOp
+        | Flashcomm2OshardQKVParallelOp
+        | MatmulAllreduceRowParallelOp
+        | SequenceRowParallelOp
+        | ShardedCPRowParallelOp
+        | ShardedCPColumnParallelOp
+        | None
+    ) = None
     if direct == "row":
         custom_op = _get_row_parallel_op(prefix, layer)
 
@@ -782,8 +733,7 @@ def get_parallel_op(disable_tp, prefix, layer, direct):
     return None, get_tp_group().rank_in_group, get_tp_group().world_size
 
 
-def get_replicated_op(disable_tp, prefix,
-                      layer) -> Optional[Union[CustomReplicatedOp]]:
+def get_replicated_op(disable_tp, prefix, layer) -> CustomReplicatedOp | None:
     if disable_tp:
         return None
 
@@ -791,24 +741,22 @@ def get_replicated_op(disable_tp, prefix,
 
 
 def is_moe_layer(prefix: str) -> bool:
-
     @lru_cache(maxsize=1)
     def get_moe_params():
         from vllm.config import get_current_vllm_config
 
         vllm_config = get_current_vllm_config()
         config = vllm_config.model_config.hf_text_config
-        n_routed_experts = getattr(config, 'n_routed_experts', 0)
-        first_k_dense_replace = getattr(config, 'first_k_dense_replace',
-                                        float('inf'))
-        moe_layer_freq = getattr(config, 'moe_layer_freq', 1)
+        n_routed_experts = getattr(config, "n_routed_experts", 0)
+        first_k_dense_replace = getattr(config, "first_k_dense_replace", float("inf"))
+        moe_layer_freq = getattr(config, "moe_layer_freq", 1)
         return n_routed_experts, first_k_dense_replace, moe_layer_freq
 
-    match = re.search(r'layers\.(\d+)\.', prefix)
+    match = re.search(r"layers\.(\d+)\.", prefix)
     if match is None:
         return False
     layer_idx = int(match.group(1))
 
     n_routed_experts, first_k_dense_replace, moe_layer_freq = get_moe_params()
 
-    return (n_routed_experts is not None and layer_idx >= first_k_dense_replace
-            and layer_idx % moe_layer_freq == 0)
+    return n_routed_experts is not None and layer_idx >= first_k_dense_replace and layer_idx % moe_layer_freq == 0
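
For reference, the prefix matching in `is_moe_layer` can be traced by hand; a worked example (the `first_k_dense_replace` and `moe_layer_freq` values are illustrative, not from any particular model config):

```python
import re

# The layer index is parsed out of the weight prefix.
match = re.search(r"layers\.(\d+)\.", "model.layers.12.mlp.down_proj")
assert match is not None
layer_idx = int(match.group(1))  # 12
# With first_k_dense_replace=1 and moe_layer_freq=1 (DeepSeek-style values),
# layer 12 is MoE: 12 >= 1 and 12 % 1 == 0.
```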