Files
xc-llm-ascend/vllm_ascend/ops/flashcomm2_oshard_manager.py
SILONG ZENG 06aa6036f6 [Lint]Style: Convert vllm-ascend/ to ruff format(new Batch #8) (#6604)
### What this PR does / why we need it?
**Scope of Changes**:
| File Path |
| :--- |
| vllm_ascend/ops/\_\_init\_\_.py |
| vllm_ascend/ops/activation.py |
| vllm_ascend/ops/flashcomm2_oshard_manager.py |
| vllm_ascend/ops/layernorm.py |
| vllm_ascend/ops/mla.py |
| vllm_ascend/ops/mm_encoder_attention.py |
| vllm_ascend/ops/register_custom_ops.py |
| vllm_ascend/ops/vocab_parallel_embedding.py |
| vllm_ascend/ops/weight_prefetch.py |
| vllm_ascend/spec_decode/\_\_init\_\_.py |
| vllm_ascend/spec_decode/eagle_proposer.py |
| vllm_ascend/spec_decode/interface.py |
| vllm_ascend/spec_decode/mtp_proposer.py |
| vllm_ascend/spec_decode/ngram_proposer.py |
| vllm_ascend/spec_decode/suffix_proposer.py |

### Does this PR introduce _any_ user-facing change?

### How was this patch tested?

- vLLM version: v0.15.0
- vLLM main: d7e17aaacd

Signed-off-by: MrZ20 <2609716663@qq.com>
2026-02-07 09:16:07 +08:00

102 lines
4.0 KiB
Python

from typing import Any
from vllm.model_executor.models.utils import extract_layer_index
from vllm_ascend.distributed.parallel_state import get_shard_weight_group
from vllm_ascend.ops.layer_shard_linear import (
is_hidden_layer,
post_process_after_loading_for_shard_weight_series,
reach_layer_for_shard_weight_series,
register_layer_to_shard_weight_series,
)
from vllm_ascend.utils import flashcomm2_enable, o_shard_enable
class Flashcomm2OShardManager:
"""Manages sharded layers for the FlashComm2 O-Shard feature.
This class is implemented to centralize all logic related to Flashcomm2OShard layers.
Its main responsibilities are:
1. Registering Attention `o_proj` layers that require O-Sharding.
2. Storing and managing these layers in a dictionary mapping layer indices
to layer objects (`layer_index -> layer`).
3. Providing a high-level API for external callers to use at key stages
like model initialization, computation, and weight loading.
Attributes:
_shard_layers: A dictionary to store the registered sharded layers,
mapping a layer index (int) to its corresponding layer object.
"""
def __init__(self):
self._shard_layers: dict[int, Any] = {}
def flashcomm2_oshard_enable(self):
return flashcomm2_enable() and o_shard_enable()
def register_layer(self, layer: Any, prefetch_step: int = 1):
"""Registers a layer for O-Sharding.
This method first checks if the O-Shard feature is enabled and if the
provided layer qualifies as a target (e.g., a hidden layer). If so,
it performs two actions:
1. Caches the layer internally in the `_shard_layers` dictionary.
2. Calls the underlying `register_layer_to_shared_weight_series`
function to register it for communication.
Args:
layer: The layer object to be registered.
prefetch_step: The prefetch step to be used when registering the
layer to the shared weight series.
"""
# Check if the layer is a target for sharding.
if is_hidden_layer(layer):
layer_idx = extract_layer_index(layer.prefix)
self._shard_layers[layer_idx] = layer
register_layer_to_shard_weight_series(
series_name="o_proj", group=get_shard_weight_group(), layer=layer, prefetch_step=prefetch_step
)
def get_layer(self, layer_idx: int) -> Any | None:
"""Safely retrieves a registered layer by its index.
Args:
layer_idx: The index of the layer to retrieve.
Returns:
The layer object if found, otherwise None.
"""
return self._shard_layers.get(layer_idx)
def trigger_broadcast_for_layer(self, layer_prefix: str):
"""Triggers a broadcast for a specific layer during model computation.
This method is intended to be called within a layer's forward pass.
It extracts the layer index from the prefix, retrieves the corresponding
registered layer object, and then triggers the broadcast operation
if all conditions are met.
Args:
layer_prefix: The name prefix of the current layer being computed.
"""
layer_idx = extract_layer_index(layer_prefix)
target_layer = self.get_layer(layer_idx)
# Ensure the layer exists and meets the sharding criteria.
if target_layer and is_hidden_layer(target_layer):
reach_layer_for_shard_weight_series(target_layer)
def post_process_after_loading(self):
"""Performs post-processing on all registered layers after weight loading.
This should be called once after the model weights have been fully loaded.
"""
if self._shard_layers:
# Pick any layer (e.g., the first one) to trigger the shard post-processing
any_layer = next(iter(self._shard_layers.values()))
post_process_after_loading_for_shard_weight_series(any_layer)
# Module-level instance of Flashcomm2OShardManager, created when this module is imported.
flashcomm2_oshard_manager = Flashcomm2OShardManager()