[Feat] flashcomm2+oshard Generalized (#4723)
### What this PR does / why we need it?
[FlashComm2](https://gitcode.com/ascend-tribe/ascend-inference-cluster/blob/main/FlashComm/FlashComm2%E5%A4%A7%E6%A8%A1%E5%9E%8B%E6%8E%A8%E7%90%86%E4%B8%AD%E4%BB%A5%E5%AD%98%E6%8D%A2%E4%BC%A0%E7%9A%84%E9%80%9A%E4%BF%A1%E4%BC%98%E5%8C%96%E6%8A%80%E6%9C%AF.pdf)
introduces redundant storage of the o_proj matrix, which imposes
pressure on GPU memory. We propose the FlashComm2+Oshard approach by
integrating the shared linear layer feature (#2931). This approach
distributes weights layer-by-layer to each GPU and accesses the o_proj
of each layer via asynchronous broadcast operations, thereby alleviating
memory pressure while achieving nearly lossless performance compared to
the original FlashComm2. This PR implements a generalized
FlashComm2+Oshard solution.
Using following env to support flashcomm2 with oshard
```shell
export VLLM_ASCEND_FLASHCOMM2_PARALLEL_SIZE=1
--additional-config '{
"layer_sharding": ["o_proj"]
}'
```
### How was this patch tested?
- vLLM version: v0.12.0
- vLLM main:
ad32e3e19c
---------
Signed-off-by: Levi-JQ <yujinqi2@huawei.com>
Co-authored-by: Levi-JQ <yujinqi2@huawei.com>
This commit is contained in:
@@ -157,6 +157,29 @@ def test_qwen3_moe_fc2_tp2() -> None:
|
|||||||
vllm_model.generate(example_prompts, sampling_params)
|
vllm_model.generate(example_prompts, sampling_params)
|
||||||
|
|
||||||
|
|
||||||
|
@patch.dict(os.environ, {"VLLM_ASCEND_ENABLE_FLASHCOMM1": "1"})
|
||||||
|
@patch.dict(os.environ, {"VLLM_ASCEND_FLASHCOMM2_PARALLEL_SIZE": "1"})
|
||||||
|
def test_qwen3_moe_fc2_oshard_tp2() -> None:
|
||||||
|
example_prompts = [
|
||||||
|
"Hello, my name is",
|
||||||
|
]
|
||||||
|
sampling_params = SamplingParams(max_tokens=5,
|
||||||
|
temperature=0.0,
|
||||||
|
top_k=50,
|
||||||
|
top_p=0.9)
|
||||||
|
|
||||||
|
with VllmRunner(
|
||||||
|
snapshot_download("Qwen/Qwen3-30B-A3B"),
|
||||||
|
dtype="auto",
|
||||||
|
tensor_parallel_size=2,
|
||||||
|
distributed_executor_backend="mp",
|
||||||
|
enable_expert_parallel=True,
|
||||||
|
enforce_eager=
|
||||||
|
True, # TODO(Levi-JQ): support graph mode for fc2 in Qwen
|
||||||
|
additional_config={"layer_sharding": ["o_proj"]}) as vllm_model:
|
||||||
|
vllm_model.generate(example_prompts, sampling_params)
|
||||||
|
|
||||||
|
|
||||||
@patch.dict(os.environ, {"VLLM_ASCEND_ENABLE_FLASHCOMM1": "1"})
|
@patch.dict(os.environ, {"VLLM_ASCEND_ENABLE_FLASHCOMM1": "1"})
|
||||||
def test_deepseek_v2_lite_fc1_tp2() -> None:
|
def test_deepseek_v2_lite_fc1_tp2() -> None:
|
||||||
example_prompts = [
|
example_prompts = [
|
||||||
|
|||||||
@@ -43,6 +43,7 @@ from vllm_ascend.attention.utils import (AscendCommonAttentionMetadata,
|
|||||||
from vllm_ascend.compilation.acl_graph import (
|
from vllm_ascend.compilation.acl_graph import (
|
||||||
get_draft_graph_params, get_graph_params,
|
get_draft_graph_params, get_graph_params,
|
||||||
update_draft_graph_params_workspaces, update_graph_params_workspaces)
|
update_draft_graph_params_workspaces, update_graph_params_workspaces)
|
||||||
|
from vllm_ascend.ops.flashcomm2_oshard_manager import flashcomm2_oshard_manager
|
||||||
from vllm_ascend.utils import (AscendDeviceType, get_ascend_device_type,
|
from vllm_ascend.utils import (AscendDeviceType, get_ascend_device_type,
|
||||||
weak_ref_tensors)
|
weak_ref_tensors)
|
||||||
|
|
||||||
@@ -349,6 +350,11 @@ class AscendAttentionBackendImpl(AttentionImpl):
|
|||||||
self.value_cache = None
|
self.value_cache = None
|
||||||
self.is_kv_producer = self.vllm_config.kv_transfer_config is not None and self.vllm_config.kv_transfer_config.is_kv_producer
|
self.is_kv_producer = self.vllm_config.kv_transfer_config is not None and self.vllm_config.kv_transfer_config.is_kv_producer
|
||||||
|
|
||||||
|
def process_weights_after_loading(self, act_dtype: torch.dtype):
|
||||||
|
super().process_weights_after_loading(act_dtype)
|
||||||
|
if flashcomm2_oshard_manager.flashcomm2_oshard_enable():
|
||||||
|
flashcomm2_oshard_manager.post_process_after_loading()
|
||||||
|
|
||||||
def full_graph_fia(self, query: torch.Tensor, key: torch.Tensor,
|
def full_graph_fia(self, query: torch.Tensor, key: torch.Tensor,
|
||||||
value: torch.Tensor, attn_metadata: AscendMetadata,
|
value: torch.Tensor, attn_metadata: AscendMetadata,
|
||||||
output: torch.Tensor) -> torch.Tensor:
|
output: torch.Tensor) -> torch.Tensor:
|
||||||
|
|||||||
100
vllm_ascend/ops/flashcomm2_oshard_manager.py
Normal file
100
vllm_ascend/ops/flashcomm2_oshard_manager.py
Normal file
@@ -0,0 +1,100 @@
|
|||||||
|
from typing import Any, Dict, Optional
|
||||||
|
|
||||||
|
from vllm.model_executor.models.utils import extract_layer_index
|
||||||
|
|
||||||
|
from vllm_ascend.distributed.parallel_state import get_shard_weight_group
|
||||||
|
from vllm_ascend.ops.layer_shard_linear import (
|
||||||
|
is_hidden_layer, post_process_after_loading_for_shard_weight_series,
|
||||||
|
reach_layer_for_shard_weight_series, register_layer_to_shard_weight_series)
|
||||||
|
from vllm_ascend.utils import flashcomm2_enable, o_shard_enable
|
||||||
|
|
||||||
|
|
||||||
|
class Flashcomm2OShardManager:
|
||||||
|
"""Manages sharded layers for the FlashComm2 O-Shard feature.
|
||||||
|
|
||||||
|
This class is implemented to centralize all logic related to Flashcomm2OShard layers.
|
||||||
|
Its main responsibilities are:
|
||||||
|
1. Registering Attention `o_proj` layers that require O-Sharding.
|
||||||
|
2. Storing and managing these layers in a dictionary mapping layer indices
|
||||||
|
to layer objects (`layer_index -> layer`).
|
||||||
|
3. Providing a high-level API for external callers to use at key stages
|
||||||
|
like model initialization, computation, and weight loading.
|
||||||
|
|
||||||
|
Attributes:
|
||||||
|
_shard_layers: A dictionary to store the registered sharded layers,
|
||||||
|
mapping a layer index (int) to its corresponding layer object.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
self._shard_layers: Dict[int, Any] = {}
|
||||||
|
|
||||||
|
def flashcomm2_oshard_enable(self):
|
||||||
|
return flashcomm2_enable() and o_shard_enable()
|
||||||
|
|
||||||
|
def register_layer(self, layer: Any, prefetch_step: int = 1):
|
||||||
|
"""Registers a layer for O-Sharding.
|
||||||
|
|
||||||
|
This method first checks if the O-Shard feature is enabled and if the
|
||||||
|
provided layer qualifies as a target (e.g., a hidden layer). If so,
|
||||||
|
it performs two actions:
|
||||||
|
1. Caches the layer internally in the `_shard_layers` dictionary.
|
||||||
|
2. Calls the underlying `register_layer_to_shared_weight_series`
|
||||||
|
function to register it for communication.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
layer: The layer object to be registered.
|
||||||
|
prefetch_step: The prefetch step to be used when registering the
|
||||||
|
layer to the shared weight series.
|
||||||
|
"""
|
||||||
|
# Check if the layer is a target for sharding.
|
||||||
|
if is_hidden_layer(layer):
|
||||||
|
layer_idx = extract_layer_index(layer.prefix)
|
||||||
|
self._shard_layers[layer_idx] = layer
|
||||||
|
|
||||||
|
register_layer_to_shard_weight_series(
|
||||||
|
series_name="o_proj",
|
||||||
|
group=get_shard_weight_group(),
|
||||||
|
layer=layer,
|
||||||
|
prefetch_step=prefetch_step)
|
||||||
|
|
||||||
|
def get_layer(self, layer_idx: int) -> Optional[Any]:
|
||||||
|
"""Safely retrieves a registered layer by its index.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
layer_idx: The index of the layer to retrieve.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The layer object if found, otherwise None.
|
||||||
|
"""
|
||||||
|
return self._shard_layers.get(layer_idx)
|
||||||
|
|
||||||
|
def trigger_broadcast_for_layer(self, layer_prefix: str):
|
||||||
|
"""Triggers a broadcast for a specific layer during model computation.
|
||||||
|
|
||||||
|
This method is intended to be called within a layer's forward pass.
|
||||||
|
It extracts the layer index from the prefix, retrieves the corresponding
|
||||||
|
registered layer object, and then triggers the broadcast operation
|
||||||
|
if all conditions are met.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
layer_prefix: The name prefix of the current layer being computed.
|
||||||
|
"""
|
||||||
|
layer_idx = extract_layer_index(layer_prefix)
|
||||||
|
target_layer = self.get_layer(layer_idx)
|
||||||
|
|
||||||
|
# Ensure the layer exists and meets the sharding criteria.
|
||||||
|
if target_layer and is_hidden_layer(target_layer):
|
||||||
|
reach_layer_for_shard_weight_series(target_layer)
|
||||||
|
|
||||||
|
def post_process_after_loading(self):
|
||||||
|
"""Performs post-processing on all registered layers after weight loading.
|
||||||
|
|
||||||
|
This should be called once after the model weights have been fully loaded.
|
||||||
|
"""
|
||||||
|
if self._shard_layers:
|
||||||
|
# Pick any layer (e.g., the first one) to trigger the shard post-processing
|
||||||
|
any_layer = next(iter(self._shard_layers.values()))
|
||||||
|
post_process_after_loading_for_shard_weight_series(any_layer)
|
||||||
|
|
||||||
|
|
||||||
|
flashcomm2_oshard_manager = Flashcomm2OShardManager()
|
||||||
@@ -21,6 +21,7 @@ CustomLinearOp
|
|||||||
├── CustomColumnParallelOp
|
├── CustomColumnParallelOp
|
||||||
│ ├── MLPColumnParallelOp
|
│ ├── MLPColumnParallelOp
|
||||||
│ ├── SequenceColumnParallelOp
|
│ ├── SequenceColumnParallelOp
|
||||||
|
│ ├── Flashcomm2OshardQKVParallelOp
|
||||||
└── CustomRowParallelOp
|
└── CustomRowParallelOp
|
||||||
│ ├── MLPRowParallelOp
|
│ ├── MLPRowParallelOp
|
||||||
│ ├── OProjRowParallelOp
|
│ ├── OProjRowParallelOp
|
||||||
@@ -60,6 +61,7 @@ from vllm_ascend.distributed.parallel_state import (get_flashcomm2_odp_group,
|
|||||||
get_flashcomm2_otp_group,
|
get_flashcomm2_otp_group,
|
||||||
get_mlp_tp_group,
|
get_mlp_tp_group,
|
||||||
get_otp_group)
|
get_otp_group)
|
||||||
|
from vllm_ascend.ops.flashcomm2_oshard_manager import flashcomm2_oshard_manager
|
||||||
from vllm_ascend.utils import (enable_dsa_cp, enable_sp, flashcomm2_enable,
|
from vllm_ascend.utils import (enable_dsa_cp, enable_sp, flashcomm2_enable,
|
||||||
get_flashcomm2_reorgnized_batch_ids,
|
get_flashcomm2_reorgnized_batch_ids,
|
||||||
matmul_allreduce_enable, mlp_tp_enable,
|
matmul_allreduce_enable, mlp_tp_enable,
|
||||||
@@ -400,6 +402,9 @@ class Flashcomm2OProjRowParallelOp(CustomRowParallelOp):
|
|||||||
super().update_attrs()
|
super().update_attrs()
|
||||||
self.input_is_parallel = self.layer.input_is_parallel
|
self.input_is_parallel = self.layer.input_is_parallel
|
||||||
self.input_size_per_partition = self.layer.input_size_per_partition
|
self.input_size_per_partition = self.layer.input_size_per_partition
|
||||||
|
if flashcomm2_oshard_manager.flashcomm2_oshard_enable():
|
||||||
|
flashcomm2_oshard_manager.register_layer(self.layer,
|
||||||
|
prefetch_step=1)
|
||||||
|
|
||||||
|
|
||||||
class MatmulAllreduceRowParallelOp(CustomRowParallelOp):
|
class MatmulAllreduceRowParallelOp(CustomRowParallelOp):
|
||||||
@@ -479,6 +484,39 @@ class SequenceColumnParallelOp(CustomColumnParallelOp):
|
|||||||
return output, output_bias
|
return output, output_bias
|
||||||
|
|
||||||
|
|
||||||
|
class Flashcomm2OshardQKVParallelOp(CustomColumnParallelOp):
|
||||||
|
|
||||||
|
def __init__(self, layer):
|
||||||
|
super().__init__(layer)
|
||||||
|
|
||||||
|
def apply_impl(
|
||||||
|
self, input_: torch.Tensor
|
||||||
|
) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[Parameter]]]:
|
||||||
|
"""Column-parallel linear with FlashComm2 OShard optimization."""
|
||||||
|
|
||||||
|
bias = self.bias if not self.skip_bias_add else None
|
||||||
|
|
||||||
|
# Matrix multiply.
|
||||||
|
assert self.quant_method is not None
|
||||||
|
|
||||||
|
if enable_sp():
|
||||||
|
input_ = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(
|
||||||
|
input_, True)
|
||||||
|
|
||||||
|
# Trigger async broadcast before matmul to overlap communication.
|
||||||
|
flashcomm2_oshard_manager.trigger_broadcast_for_layer(
|
||||||
|
self.layer.prefix)
|
||||||
|
|
||||||
|
output_parallel = self.quant_method.apply(self.layer, input_, bias)
|
||||||
|
if self.gather_output and self.tp_size > 1:
|
||||||
|
# All-gather across the partitions.
|
||||||
|
output = self.comm_group.all_gather(output_parallel)
|
||||||
|
else:
|
||||||
|
output = output_parallel
|
||||||
|
output_bias = self.bias if self.skip_bias_add else None
|
||||||
|
return output, output_bias
|
||||||
|
|
||||||
|
|
||||||
class SequenceRowParallelOp(CustomRowParallelOp):
|
class SequenceRowParallelOp(CustomRowParallelOp):
|
||||||
|
|
||||||
def __init__(self, layer):
|
def __init__(self, layer):
|
||||||
@@ -657,12 +695,15 @@ class ShardedCPColumnParallelOp(CustomColumnParallelOp):
|
|||||||
def _get_column_parallel_op(
|
def _get_column_parallel_op(
|
||||||
prefix, layer
|
prefix, layer
|
||||||
) -> Optional[Union[MLPColumnParallelOp, SequenceColumnParallelOp,
|
) -> Optional[Union[MLPColumnParallelOp, SequenceColumnParallelOp,
|
||||||
ShardedCPColumnParallelOp]]:
|
ShardedCPColumnParallelOp, Flashcomm2OshardQKVParallelOp]]:
|
||||||
if enable_dsa_cp() and ("q_b_proj" in prefix or "kv_b_proj" in prefix):
|
if enable_dsa_cp() and ("q_b_proj" in prefix or "kv_b_proj" in prefix):
|
||||||
return ShardedCPColumnParallelOp(layer)
|
return ShardedCPColumnParallelOp(layer)
|
||||||
if "gate_up_proj" in prefix and mlp_tp_enable(
|
if "gate_up_proj" in prefix and mlp_tp_enable(
|
||||||
) and not is_moe_layer(prefix):
|
) and not is_moe_layer(prefix):
|
||||||
return MLPColumnParallelOp(layer)
|
return MLPColumnParallelOp(layer)
|
||||||
|
if flashcomm2_oshard_manager.flashcomm2_oshard_enable():
|
||||||
|
if any(p in prefix for p in ("qkv_proj", "conv1d", "query_key_value")):
|
||||||
|
return Flashcomm2OshardQKVParallelOp(layer)
|
||||||
if enable_sp():
|
if enable_sp():
|
||||||
if "shared_expert" in prefix:
|
if "shared_expert" in prefix:
|
||||||
return None
|
return None
|
||||||
@@ -719,6 +760,7 @@ def get_parallel_op(disable_tp, prefix, layer, direct):
|
|||||||
custom_op: Optional[Union[MLPColumnParallelOp, SequenceColumnParallelOp,
|
custom_op: Optional[Union[MLPColumnParallelOp, SequenceColumnParallelOp,
|
||||||
MLPRowParallelOp, OProjRowParallelOp,
|
MLPRowParallelOp, OProjRowParallelOp,
|
||||||
Flashcomm2OProjRowParallelOp,
|
Flashcomm2OProjRowParallelOp,
|
||||||
|
Flashcomm2OshardQKVParallelOp,
|
||||||
MatmulAllreduceRowParallelOp,
|
MatmulAllreduceRowParallelOp,
|
||||||
SequenceRowParallelOp, ShardedCPRowParallelOp,
|
SequenceRowParallelOp, ShardedCPRowParallelOp,
|
||||||
ShardedCPColumnParallelOp]] = None
|
ShardedCPColumnParallelOp]] = None
|
||||||
|
|||||||
@@ -1017,6 +1017,13 @@ def flashcomm2_enable() -> bool:
|
|||||||
return envs_ascend.VLLM_ASCEND_FLASHCOMM2_PARALLEL_SIZE > 0
|
return envs_ascend.VLLM_ASCEND_FLASHCOMM2_PARALLEL_SIZE > 0
|
||||||
|
|
||||||
|
|
||||||
|
def o_shard_enable() -> bool:
|
||||||
|
layer_sharding = get_ascend_config().layer_sharding
|
||||||
|
if layer_sharding is None:
|
||||||
|
return False
|
||||||
|
return "o_proj" in layer_sharding
|
||||||
|
|
||||||
|
|
||||||
def get_flashcomm2_config_and_validate(ascend_config, vllm_config):
|
def get_flashcomm2_config_and_validate(ascend_config, vllm_config):
|
||||||
flashcomm2_oproj_tp_size = envs_ascend.VLLM_ASCEND_FLASHCOMM2_PARALLEL_SIZE
|
flashcomm2_oproj_tp_size = envs_ascend.VLLM_ASCEND_FLASHCOMM2_PARALLEL_SIZE
|
||||||
global_tp_size = vllm_config.parallel_config.tensor_parallel_size
|
global_tp_size = vllm_config.parallel_config.tensor_parallel_size
|
||||||
|
|||||||
Reference in New Issue
Block a user