[Lint] Style: Convert vllm-ascend/ to ruff format (Batch #3) (#5978)

### What this PR does / why we need it?
**Scope of Changes**:
| File Path |
| :--- |
| `vllm_ascend/attention/mla_v1.py` |
| `vllm_ascend/attention/sfa_v1.py` |
| `vllm_ascend/core/recompute_scheduler.py` |
| `vllm_ascend/core/scheduler_dynamic_batch.py` |
| `vllm_ascend/distributed/device_communicators/npu_communicator.py` |
| `vllm_ascend/distributed/device_communicators/pyhccl.py` |
| `vllm_ascend/distributed/device_communicators/pyhccl_wrapper.py` |
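
For context, the conversion replaces the previous yapf-style hanging indents with ruff's black-compatible output at the project's `line-length = 120`. A minimal before/after sketch of the dominant pattern in these diffs (function and argument names are hypothetical):

```python
# Before: yapf-style continuation, aligned to the opening parenthesis.
result = some_helper(first_argument, second_argument,
                     third_argument)

# After: ruff format collapses the call when it fits within 120 columns...
result = some_helper(first_argument, second_argument, third_argument)

# ...and otherwise explodes it to one argument per line with a trailing comma.
result = a_helper_with_a_much_longer_name(
    first_argument,
    second_argument,
    third_argument,
)
```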

### Does this PR introduce _any_ user-facing change?

### How was this patch tested?

- vLLM version: v0.13.0
- vLLM main: 2c24bc6996

---------

Signed-off-by: MrZ20 <2609716663@qq.com>
Co-authored-by: Soren <user@SorendeMac-mini.local>
Author: SILONG ZENG
Date: 2026-01-24 22:10:18 +08:00 (committed by GitHub)
Parent: 4e53c1d900
Commit: 7faa6878a6
9 changed files with 953 additions and 1148 deletions


@@ -51,11 +51,6 @@ line-length = 120
# Folder to be modified
exclude = [
"tests/**",
# (3)
"vllm_ascend/attention/*.py",
"vllm_ascend/core/*.py",
"vllm_ascend/distributed/device_communicators/**",
"vllm_ascend/distributed/utils.py",
# (5)
"vllm_ascend/distributed/kv_transfer/kv_pool/**",
"vllm_ascend/distributed/kv_transfer/utils/**",


@@ -394,7 +394,7 @@ class AscendMlaCPImpl(AscendMLAImpl):
prefill_q_pe = self.rope_single(prefill_q_pe, cos, sin)
prefill_kv_no_split = kv_no_split[:num_actual_tokens]
kv_c, k_pe = prefill_kv_no_split.split([self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
kv_c_normed = self.kv_a_layernorm(kv_c.contiguous())
kv_c_normed = self.kv_a_layernorm(kv_c.contiguous()) # type: ignore[misc]
assert len(kv_cache) > 1, "the number of kv cache should be greater than 1, namely (nope_cache and rope_cache)"
kv_c_normed = kv_c_normed.view([num_actual_tokens, self.num_kv_heads, -1])
k_pe = k_pe.unsqueeze(1)

File diff suppressed because it is too large.


@@ -1,19 +1,17 @@
from dataclasses import dataclass
from typing import TYPE_CHECKING, Optional, Tuple, Type, TypeVar
from typing import TYPE_CHECKING, TypeVar
import torch
import torch_npu
import vllm.envs as envs_vllm
from torch import nn
from vllm.config import CUDAGraphMode, VllmConfig, get_current_vllm_config
from vllm.config import VllmConfig, get_current_vllm_config
from vllm.distributed import get_tensor_model_parallel_world_size, get_tp_group
from vllm.forward_context import get_forward_context
from vllm.logger import logger
from vllm.model_executor.layers.linear import UnquantizedLinearMethod
from vllm.triton_utils import HAS_TRITON
from vllm.v1.attention.backend import ( # type: ignore
AttentionBackend, AttentionCGSupport, MLAAttentionImpl)
from vllm.v1.attention.backend import AttentionBackend, AttentionCGSupport, MLAAttentionImpl # type: ignore
from vllm.v1.attention.backends.mla.common import MLACommonMetadataBuilder
from vllm.v1.kv_cache_interface import AttentionSpec
@@ -22,22 +20,33 @@ from vllm_ascend.ascend_config import get_ascend_config
from vllm_ascend.attention.attention_mask import AttentionMaskBuilder
from vllm_ascend.attention.attention_v1 import AscendAttentionState
from vllm_ascend.attention.mla_v1 import MAX_O_PROJ_PREFETCH_SIZE, MLAPO_MAX_SUPPORTED_TOKENS
from vllm_ascend.attention.utils import (AscendCommonAttentionMetadata,
from vllm_ascend.attention.utils import (
AscendCommonAttentionMetadata,
ascend_chunked_prefill_workspace_size,
maybe_save_kv_layer_to_connector,
trans_rope_weight, transdata,
wait_for_kv_layer_from_connector)
trans_rope_weight,
transdata,
wait_for_kv_layer_from_connector,
)
from vllm_ascend.distributed.utils import all_gather_async
from vllm_ascend.ops.layer_shard_linear import (
is_hidden_layer, post_process_after_loading_for_shard_weight_series,
is_hidden_layer,
post_process_after_loading_for_shard_weight_series,
reach_layer_for_shard_weight_series,
register_all_layers_to_shard_weight_series)
register_all_layers_to_shard_weight_series,
)
from vllm_ascend.ops.rotary_embedding import get_cos_and_sin_mla
from vllm_ascend.ops.triton.rope import rope_forward_triton
from vllm_ascend.ops.weight_prefetch import maybe_npu_prefetch
from vllm_ascend.quantization.methods import AscendW8A8LinearMethod
from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_ND, _round_up, dispose_layer,
enable_dsa_cp, enable_dsa_cp_with_layer_shard, maybe_trans_nz)
from vllm_ascend.utils import (
ACL_FORMAT_FRACTAL_ND,
_round_up,
dispose_layer,
enable_dsa_cp,
enable_dsa_cp_with_layer_shard,
maybe_trans_nz,
)
from vllm_ascend.worker.npu_input_batch import NPUInputBatch
if TYPE_CHECKING:
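
The exploded import blocks above reflect ruff format's magic trailing comma: once a trailing comma appears inside the parentheses, the formatter keeps one name per line even when the collapsed form would fit in 120 columns. A small illustrative sketch (module and symbol names are placeholders):

```python
# No trailing comma and short enough: ruff format collapses to one line.
from example_module import alpha, beta

# Trailing comma after `gamma`: ruff format pins the one-per-line layout,
# regardless of the configured line length.
from example_module import (
    alpha,
    beta,
    gamma,
)
```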
@@ -48,7 +57,6 @@ BMM_TRANS_MAX_SUPPORTED_TOKENS = 1024
class AscendSFABackend(AttentionBackend):
accept_output_buffer: bool = True
@staticmethod
@@ -63,12 +71,11 @@ class AscendSFABackend(AttentionBackend):
return AscendSFAMetadataBuilder
@staticmethod
def get_kv_cache_shape(num_blocks: int, block_size: int, num_kv_heads: int,
head_size: int) -> tuple[int, ...]:
def get_kv_cache_shape(num_blocks: int, block_size: int, num_kv_heads: int, head_size: int) -> tuple[int, ...]:
return (num_blocks, block_size, num_kv_heads, head_size)
@staticmethod
def get_impl_cls() -> Type["AscendSFAImpl"]:
def get_impl_cls() -> type["AscendSFAImpl"]:
return AscendSFAImpl
@@ -91,6 +98,7 @@ class AscendSFAMetadata:
NOTE: Please read the comment at the top of the file before trying to
understand this class
"""
# NOTE(sang): Definition of context_len, query_len, and seq_len.
# |---------- N-1 iteration --------|
# |---------------- N iteration ---------------------|
@@ -109,11 +117,11 @@ class AscendSFAMetadata:
# For logging.
num_input_tokens: int = 0 # Number of tokens including padding.
# The dimension of the attention heads
head_dim: Optional[int] = None
head_dim: int | None = None
attn_mask: torch.Tensor = None
# chunked prefill by default if no attn_states passed
attn_state: AscendAttentionState = AscendAttentionState.ChunkedPrefill
dsa_cp_context: Optional[DSACPContext] = None
dsa_cp_context: DSACPContext | None = None
reshape_cache_event: torch.npu.Event = None
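
The annotation changes above are the usual modernization applied alongside the reformat (they match ruff's pyupgrade-derived UP rules rather than `ruff format` itself, which never rewrites annotations): `Optional[X]` becomes `X | None` (PEP 604) and `typing.Tuple`/`typing.Type` become the built-in generics `tuple`/`type` (PEP 585). A minimal sketch with placeholder types, assuming Python 3.10+ or `from __future__ import annotations`:

```python
# Before: typing-module generics and Optional.
from typing import Optional, Tuple

def lookup(key: str) -> Optional[Tuple[int, int]]: ...

# After: PEP 604 unions and PEP 585 built-in generics.
def lookup(key: str) -> tuple[int, int] | None: ...
```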
@@ -136,37 +144,38 @@ class AscendSFAMetadataBuilder(MLACommonMetadataBuilder[AscendSFAMetadata]):
supports_dcp_with_varlen: bool = False,
):
super().__init__(
kv_cache_spec, layer_names, vllm_config, device,
kv_cache_spec,
layer_names,
vllm_config,
device,
metadata_cls if metadata_cls is not None else AscendSFAMetadata,
supports_dcp_with_varlen)
supports_dcp_with_varlen,
)
self.block_size = vllm_config.cache_config.block_size
self.max_blocks = (vllm_config.model_config.max_model_len +
self.block_size - 1) // self.block_size
self.max_blocks = (vllm_config.model_config.max_model_len + self.block_size - 1) // self.block_size
self.speculative_config = vllm_config.speculative_config
self.decode_threshold = 1
if self.speculative_config:
spec_token_num = self.speculative_config.num_speculative_tokens
self.decode_threshold += spec_token_num
assert self.decode_threshold <= 16, f"decode_threshold exceeded \
assert self.decode_threshold <= 16, (
f"decode_threshold exceeded \
npu_fused_infer_attention_score TND layout's limit of 16, \
got {self.decode_threshold}"
)
self.attn_mask_builder = AttentionMaskBuilder(self.device)
self.rope_dim = self.model_config.hf_text_config.qk_rope_head_dim
self.enable_dsa_cp = enable_dsa_cp()
max_num_reqs = vllm_config.scheduler_config.max_num_seqs
self.actual_seq_lengths_query = torch.zeros(max_num_reqs + 1,
dtype=torch.int32,
device=device)
self.actual_seq_lengths_key = torch.empty_like(
self.actual_seq_lengths_query)
self.actual_seq_lengths_query = torch.zeros(max_num_reqs + 1, dtype=torch.int32, device=device)
self.actual_seq_lengths_key = torch.empty_like(self.actual_seq_lengths_query)
@staticmethod
def determine_chunked_prefill_workspace_size(
vllm_config: VllmConfig) -> int:
def determine_chunked_prefill_workspace_size(vllm_config: VllmConfig) -> int:
return ascend_chunked_prefill_workspace_size(vllm_config)
@classmethod
@@ -179,8 +188,7 @@ class AscendSFAMetadataBuilder(MLACommonMetadataBuilder[AscendSFAMetadata]):
# @override omitted only because of mypy limitation due to type variable.
return AttentionCGSupport.UNIFORM_BATCH
def reorder_batch(self, input_batch: "NPUInputBatch",
scheduler_output: "SchedulerOutput") -> bool:
def reorder_batch(self, input_batch: "NPUInputBatch", scheduler_output: "SchedulerOutput") -> bool:
# No need to reorder for Ascend SFA
return False
@@ -196,9 +204,7 @@ class AscendSFAMetadataBuilder(MLACommonMetadataBuilder[AscendSFAMetadata]):
block_table = common_attn_metadata.block_table_tensor[:num_reqs]
slot_mapping = common_attn_metadata.slot_mapping[:num_input_tokens]
input_positions = common_attn_metadata.positions[:
num_input_tokens].long(
)
input_positions = common_attn_metadata.positions[:num_input_tokens].long()
cum_query_lens = common_attn_metadata.query_start_loc[1 : num_reqs + 1]
seq_lens = common_attn_metadata.seq_lens[:num_reqs]
@@ -216,8 +222,7 @@ class AscendSFAMetadataBuilder(MLACommonMetadataBuilder[AscendSFAMetadata]):
local_end = min(local_end_with_pad, num_actual_tokens)
pad_size = num_tokens_pad - cos.shape[0]
assert cos.shape == sin.shape, \
f"cos.shape must be equal to sin.shape, got {cos.shape} and {sin.shape}"
assert cos.shape == sin.shape, f"cos.shape must be equal to sin.shape, got {cos.shape} and {sin.shape}"
if pad_size > 0:
cos = nn.functional.pad(cos, (0, 0, 0, 0, 0, 0, 0, pad_size))
@@ -225,9 +230,7 @@ class AscendSFAMetadataBuilder(MLACommonMetadataBuilder[AscendSFAMetadata]):
pad_size_slot = num_tokens_pad - slot_mapping.shape[0]
if pad_size_slot > 0:
slot_mapping = nn.functional.pad(slot_mapping,
(0, pad_size_slot),
value=-1)
slot_mapping = nn.functional.pad(slot_mapping, (0, pad_size_slot), value=-1)
else:
slot_mapping = slot_mapping[:num_tokens_pad]
slot_mapping_cp = slot_mapping[local_start:local_end_with_pad]
@@ -235,15 +238,18 @@ class AscendSFAMetadataBuilder(MLACommonMetadataBuilder[AscendSFAMetadata]):
cos = cos[local_start:local_end_with_pad]
sin = sin[local_start:local_end_with_pad]
assert cos.shape[0] == num_tokens_per_device, \
assert cos.shape[0] == num_tokens_per_device, (
f"cos.shape[0] must be equal to num_tokens_per_device, \
got {cos.shape[0]} and {num_tokens_per_device}"
assert slot_mapping_cp.shape[0] == num_tokens_per_device, \
)
assert slot_mapping_cp.shape[0] == num_tokens_per_device, (
f"slot_mapping_cp.shape[0] must be equal to num_tokens_per_device, \
got {slot_mapping_cp.shape[0]} and {num_tokens_per_device}"
assert slot_mapping.shape[0] == num_tokens_pad, \
)
assert slot_mapping.shape[0] == num_tokens_pad, (
f"slot_mapping.shape[0] must be equal to num_tokens_pad, \
got {slot_mapping.shape[0]} and {num_tokens_pad}"
)
actual_seq_lengths_query = self.actual_seq_lengths_query
actual_seq_lengths_key = self.actual_seq_lengths_key
@@ -291,31 +297,26 @@ class AscendSFAMetadataBuilder(MLACommonMetadataBuilder[AscendSFAMetadata]):
seq_lens=seq_lens,
slot_mapping=slot_mapping,
head_dim=self.model_config.get_head_size(),
attn_mask=self.attn_mask_builder.get_attention_mask(
self.model_config),
attn_mask=self.attn_mask_builder.get_attention_mask(self.model_config),
attn_state=common_attn_metadata.attn_state,
block_tables=block_table,
sin=sin[:num_input_tokens],
cos=cos[:num_input_tokens],
dsa_cp_context=dsa_cp_context)
dsa_cp_context=dsa_cp_context,
)
def build_for_graph_capture(
self,
common_attn_metadata: AscendCommonAttentionMetadata,
attn_state: AscendAttentionState = AscendAttentionState.DecodeOnly,
):
if attn_state in {
AscendAttentionState.DecodeOnly,
AscendAttentionState.SpecDecoding
}:
if attn_state in {AscendAttentionState.DecodeOnly, AscendAttentionState.SpecDecoding}:
attn_metadata = self.build(
common_prefix_len=0,
common_attn_metadata=common_attn_metadata,
)
else:
raise NotImplementedError(
"Currently we only support building dummy metadata for DecodeOnly state"
)
raise NotImplementedError("Currently we only support building dummy metadata for DecodeOnly state")
attn_metadata.attn_state = attn_state
return attn_metadata
@@ -326,8 +327,9 @@ class AscendSFAImpl(MLAAttentionImpl):
NOTE: Please read the comment at the top of the file before trying to
understand this class
"""
# Supports forward using the all-gather o_proj weight for decode requests when Sharded CP is enabled.
o_proj_full_pool: Optional[torch.Tensor] = None
o_proj_full_pool: torch.Tensor | None = None
def __init__(
self,
@@ -335,12 +337,12 @@ class AscendSFAImpl(MLAAttentionImpl):
head_size: int,
scale: float,
num_kv_heads: int,
alibi_slopes: Optional[list[float]],
sliding_window: Optional[int],
alibi_slopes: list[float] | None,
sliding_window: int | None,
kv_cache_dtype: str,
logits_soft_cap: Optional[float],
logits_soft_cap: float | None,
attn_type: str,
kv_sharing_target_layer_name: Optional[str],
kv_sharing_target_layer_name: str | None,
**kwargs,
) -> None:
self.num_heads = num_heads
@@ -350,26 +352,25 @@ class AscendSFAImpl(MLAAttentionImpl):
self.kv_cache_dtype = kv_cache_dtype
# MLA Args
self.q_lora_rank = kwargs['q_lora_rank']
self.kv_lora_rank = kwargs['kv_lora_rank']
self.qk_nope_head_dim = kwargs['qk_nope_head_dim']
self.qk_rope_head_dim = kwargs['qk_rope_head_dim']
self.qk_head_dim = kwargs['qk_head_dim']
self.v_head_dim = kwargs['v_head_dim']
self.rotary_emb = kwargs['rotary_emb']
self.q_proj = kwargs['q_proj'] if self.q_lora_rank is None else kwargs[
'q_b_proj']
self.fused_qkv_a_proj = kwargs.get('fused_qkv_a_proj', None)
self.kv_b_proj = kwargs['kv_b_proj']
self.o_proj = kwargs['o_proj']
self.indexer = kwargs['indexer']
self.kv_a_proj_with_mqa = kwargs.get('kv_a_proj_with_mqa', None)
self.kv_a_layernorm = kwargs.get('kv_a_layernorm', None)
self.q_a_layernorm = kwargs.get('q_a_layernorm', None)
self.q_lora_rank = kwargs["q_lora_rank"]
self.kv_lora_rank = kwargs["kv_lora_rank"]
self.qk_nope_head_dim = kwargs["qk_nope_head_dim"]
self.qk_rope_head_dim = kwargs["qk_rope_head_dim"]
self.qk_head_dim = kwargs["qk_head_dim"]
self.v_head_dim = kwargs["v_head_dim"]
self.rotary_emb = kwargs["rotary_emb"]
self.q_proj = kwargs["q_proj"] if self.q_lora_rank is None else kwargs["q_b_proj"]
self.fused_qkv_a_proj = kwargs.get("fused_qkv_a_proj")
self.kv_b_proj = kwargs["kv_b_proj"]
self.o_proj = kwargs["o_proj"]
self.indexer = kwargs["indexer"]
self.kv_a_proj_with_mqa = kwargs.get("kv_a_proj_with_mqa")
self.kv_a_layernorm = kwargs.get("kv_a_layernorm")
self.q_a_layernorm = kwargs.get("q_a_layernorm")
self.num_queries_per_kv = self.num_heads // self.num_kv_heads
self.tp_size = get_tensor_model_parallel_world_size()
self.tp_rank = get_tp_group().rank_in_group
self.q_b_proj = kwargs['q_b_proj']
self.q_b_proj = kwargs["q_b_proj"]
ascend_config = get_ascend_config()
self.enable_shared_expert_dp = ascend_config.enable_shared_expert_dp
@@ -383,7 +384,9 @@ class AscendSFAImpl(MLAAttentionImpl):
self.local_num_heads = self.num_heads
self.vllm_config = get_current_vllm_config()
self.is_kv_producer = self.vllm_config.kv_transfer_config is not None and self.vllm_config.kv_transfer_config.is_kv_producer
self.is_kv_producer = (
self.vllm_config.kv_transfer_config is not None and self.vllm_config.kv_transfer_config.is_kv_producer
)
# indexer param
self.n_head: int = self.indexer.n_head # 64
@@ -400,38 +403,38 @@ class AscendSFAImpl(MLAAttentionImpl):
self.local_num_heads = self.num_heads * self.tp_size
if self.enable_dsa_cp_prefill_only:
self.layer_sharding_kwargs = []
for layer_name in (get_ascend_config().layer_sharding or []):
for layer_name in get_ascend_config().layer_sharding or []:
if layer_name in kwargs:
self.layer_sharding_kwargs.append(kwargs[layer_name])
else:
logger.warning_once(
f"[SFAImpl init] Layer '{layer_name}' not found in kwargs for layer sharding, skipping sharding configuration"
f"[SFAImpl init] Layer '{layer_name}' not found in kwargs for layer sharding, "
"skipping sharding configuration"
)
register_all_layers_to_shard_weight_series(
self.layer_sharding_kwargs)
register_all_layers_to_shard_weight_series(self.layer_sharding_kwargs)
def process_weights_after_loading(self, act_dtype: torch.dtype):
# NOTE: We currently do not support quant kv_b_proj.
assert isinstance(self.kv_b_proj.quant_method, UnquantizedLinearMethod)
# NOTE: Weight will be reshaped next, we need to revert and transpose it.
kv_b_proj_weight = torch_npu.npu_format_cast(
self.kv_b_proj.weight.data, ACL_FORMAT_FRACTAL_ND).T
kv_b_proj_weight = torch_npu.npu_format_cast(self.kv_b_proj.weight.data, ACL_FORMAT_FRACTAL_ND).T
assert kv_b_proj_weight.shape == (
self.kv_lora_rank, self.local_num_heads *
(self.qk_nope_head_dim + self.v_head_dim)), (
self.kv_lora_rank,
self.local_num_heads * (self.qk_nope_head_dim + self.v_head_dim),
), (
f"{kv_b_proj_weight.shape=}, "
f"{self.kv_lora_rank=}, "
f"{self.local_num_heads=}, "
f"{self.qk_nope_head_dim=}, "
f"{self.v_head_dim=}")
f"{self.v_head_dim=}"
)
kv_b_proj_weight = kv_b_proj_weight.view(
self.kv_lora_rank,
self.local_num_heads,
self.qk_nope_head_dim + self.v_head_dim,
)
W_UK, W_UV = kv_b_proj_weight.split(
[self.qk_nope_head_dim, self.v_head_dim], dim=-1)
W_UK, W_UV = kv_b_proj_weight.split([self.qk_nope_head_dim, self.v_head_dim], dim=-1)
# Convert from (L, N, V) to (N, L, V)
self.W_UV = W_UV.transpose(0, 1).contiguous()
@@ -445,10 +448,9 @@ class AscendSFAImpl(MLAAttentionImpl):
dispose_layer(self.kv_b_proj)
if self.enable_dsa_cp:
if self.enable_dsa_cp_prefill_only:
for layer in (self.layer_sharding_kwargs or []):
for layer in self.layer_sharding_kwargs or []:
if is_hidden_layer(layer):
post_process_after_loading_for_shard_weight_series(
layer)
post_process_after_loading_for_shard_weight_series(layer)
else:
self._init_o_proj_tp_full_params()
@@ -459,15 +461,14 @@ class AscendSFAImpl(MLAAttentionImpl):
None,
)
reasons = []
if self.fused_qkv_a_proj is None or not isinstance(
quant_method, AscendW8A8LinearMethod):
if self.fused_qkv_a_proj is None or not isinstance(quant_method, AscendW8A8LinearMethod):
reasons.append(
"Currently mlapo only supports W8A8 quantization in SFA scenario."
"Some layers in your model are not quantized with W8A8,"
"thus mlapo is disabled for these layers.")
"thus mlapo is disabled for these layers."
)
if self.enable_dsa_cp:
reasons.append("Currently mlapo does not support SFA with CP,"
"thus mlapo is disabled for these layers.")
reasons.append("Currently mlapo does not support SFA with CP,thus mlapo is disabled for these layers.")
if reasons:
self.enable_mlapo = False
for msg in reasons:
@@ -480,32 +481,31 @@ class AscendSFAImpl(MLAAttentionImpl):
def _v_up_proj(self, x):
num_input_tokens, _, _ = x.shape
if x.dtype in [torch.float16, torch.bfloat16] \
and hasattr(torch.ops._C_ascend, "batch_matmul_transpose") \
and num_input_tokens <= BMM_TRANS_MAX_SUPPORTED_TOKENS:
if (
x.dtype in [torch.float16, torch.bfloat16]
and hasattr(torch.ops._C_ascend, "batch_matmul_transpose")
and num_input_tokens <= BMM_TRANS_MAX_SUPPORTED_TOKENS
):
x = x.view(-1, self.local_num_heads, self.kv_lora_rank)
res = torch.empty((num_input_tokens, self.local_num_heads, self.v_head_dim),
dtype=x.dtype,
device=x.device)
res = torch.empty((num_input_tokens, self.local_num_heads, self.v_head_dim), dtype=x.dtype, device=x.device)
torch.ops._C_ascend.batch_matmul_transpose(x, self.W_UV, res)
x = res.reshape(-1, self.local_num_heads * self.v_head_dim)
else:
# Convert from (B, N, L) to (N, B, L)
x = x.view(-1, self.local_num_heads,
self.kv_lora_rank).transpose(0, 1)
x = x.view(-1, self.local_num_heads, self.kv_lora_rank).transpose(0, 1)
# # Multiply (N, B, L) x (N, L, V) -> (N, B, V)
x = torch.bmm(x, self.W_UV)
# # Convert from (N, B, V) to (B, N * V)
x = x.transpose(0,
1).reshape(-1,
self.local_num_heads * self.v_head_dim)
x = x.transpose(0, 1).reshape(-1, self.local_num_heads * self.v_head_dim)
return x
# Return `ql_nope`, `q_pe`
def _q_proj_and_k_up_proj(self, x):
q_nope, q_pe = self.q_proj(x)[0]\
.view(-1, self.local_num_heads, self.qk_head_dim)\
q_nope, q_pe = (
self.q_proj(x)[0]
.view(-1, self.local_num_heads, self.qk_head_dim)
.split([self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)
)
# Convert from (B, N, P) to (N, B, P)
q_nope = q_nope.transpose(0, 1)
@@ -519,27 +519,26 @@ class AscendSFAImpl(MLAAttentionImpl):
kv_no_split: torch.Tensor,
cos: torch.Tensor,
sin: torch.Tensor,
kv_cache: Tuple,
kv_cache: tuple,
slots: torch.Tensor,
):
B = kv_no_split.shape[0]
N = self.num_kv_heads
S = 1
# npu_kv_rmsnorm_rope_cache needs [B, N, S, D]
kv_no_split = kv_no_split.view(
B, N, S, self.kv_lora_rank + self.qk_rope_head_dim)
kv_no_split = kv_no_split.view(B, N, S, self.kv_lora_rank + self.qk_rope_head_dim)
cache_mode = "PA"
if self.enable_dsa_cp:
_, _, k_pe, k_nope = torch_npu.npu_kv_rmsnorm_rope_cache(
kv_no_split,
self.kv_a_layernorm.weight,
self.kv_a_layernorm.weight, # type: ignore[union-attr]
cos,
sin,
slots.to(torch.int64),
kv_cache[1],
kv_cache[0],
epsilon=self.kv_a_layernorm.variance_epsilon,
epsilon=self.kv_a_layernorm.variance_epsilon, # type: ignore[union-attr]
cache_mode=cache_mode,
is_output_kv=True,
)
@@ -547,13 +546,13 @@ class AscendSFAImpl(MLAAttentionImpl):
else:
torch_npu.npu_kv_rmsnorm_rope_cache(
kv_no_split,
self.kv_a_layernorm.weight,
self.kv_a_layernorm.weight, # type: ignore[union-attr]
cos,
sin,
slots.to(torch.int64),
kv_cache[1],
kv_cache[0],
epsilon=self.kv_a_layernorm.variance_epsilon,
epsilon=self.kv_a_layernorm.variance_epsilon, # type: ignore[union-attr]
cache_mode=cache_mode,
)
return None, None
@@ -577,78 +576,53 @@ class AscendSFAImpl(MLAAttentionImpl):
assert self.kv_a_proj_with_mqa is None
assert self.fused_qkv_a_proj is not None
kv_a_proj_wt = self.fused_qkv_a_proj.weight.data[
..., self.q_lora_rank:].contiguous()
q_a_proj_wt = self.fused_qkv_a_proj.weight.data[
..., :self.q_lora_rank].contiguous()
kv_a_proj_wt = self.fused_qkv_a_proj.weight.data[..., self.q_lora_rank :].contiguous()
q_a_proj_wt = self.fused_qkv_a_proj.weight.data[..., : self.q_lora_rank].contiguous()
kv_a_proj_wt = kv_a_proj_wt.t().contiguous()
kv_a_proj_wt = trans_rope_weight(kv_a_proj_wt, self.qk_rope_head_dim)
kv_a_proj_wt = kv_a_proj_wt.t().contiguous()
wd_qkv = torch.cat((kv_a_proj_wt, q_a_proj_wt), dim=-1)
wd_qkv = wd_qkv.t().contiguous()
wd_qkv = transdata(wd_qkv,
block_size=(16, 32)).unsqueeze(0).contiguous()
wd_qkv = transdata(wd_qkv, block_size=(16, 32)).unsqueeze(0).contiguous()
self.wd_qkv = torch_npu.npu_format_cast(wd_qkv, 29)
kv_a_proj_deq_scl = self.fused_qkv_a_proj.deq_scale[
self.q_lora_rank:].contiguous()
q_a_proj_deq_scl = self.fused_qkv_a_proj.deq_scale[:self.
q_lora_rank].contiguous(
)
kv_a_proj_deq_scl = kv_a_proj_deq_scl.reshape(
self.kv_lora_rank + self.qk_rope_head_dim, -1).contiguous()
kv_a_proj_deq_scl = trans_rope_weight(kv_a_proj_deq_scl,
self.qk_rope_head_dim)
kv_a_proj_deq_scl = kv_a_proj_deq_scl.view(
self.kv_lora_rank + self.qk_rope_head_dim).contiguous()
self.deq_scale_qkv = torch.cat((kv_a_proj_deq_scl, q_a_proj_deq_scl),
dim=-1).contiguous()
kv_a_proj_deq_scl = self.fused_qkv_a_proj.deq_scale[self.q_lora_rank :].contiguous()
q_a_proj_deq_scl = self.fused_qkv_a_proj.deq_scale[: self.q_lora_rank].contiguous()
kv_a_proj_deq_scl = kv_a_proj_deq_scl.reshape(self.kv_lora_rank + self.qk_rope_head_dim, -1).contiguous()
kv_a_proj_deq_scl = trans_rope_weight(kv_a_proj_deq_scl, self.qk_rope_head_dim)
kv_a_proj_deq_scl = kv_a_proj_deq_scl.view(self.kv_lora_rank + self.qk_rope_head_dim).contiguous()
self.deq_scale_qkv = torch.cat((kv_a_proj_deq_scl, q_a_proj_deq_scl), dim=-1).contiguous()
kv_a_proj_qt_bias = self.fused_qkv_a_proj.quant_bias[
self.q_lora_rank:].contiguous()
q_a_proj_qt_bias = self.fused_qkv_a_proj.quant_bias[:self.
q_lora_rank].contiguous(
)
kv_a_proj_qt_bias = self.fused_qkv_a_proj.quant_bias[self.q_lora_rank :].contiguous()
q_a_proj_qt_bias = self.fused_qkv_a_proj.quant_bias[: self.q_lora_rank].contiguous()
kv_a_proj_qt_bias = kv_a_proj_qt_bias.reshape(
self.kv_lora_rank + self.qk_rope_head_dim, -1).contiguous()
kv_a_proj_qt_bias = trans_rope_weight(kv_a_proj_qt_bias,
self.qk_rope_head_dim)
kv_a_proj_qt_bias = kv_a_proj_qt_bias.view(
self.kv_lora_rank + self.qk_rope_head_dim).contiguous()
self.quant_bias_qkv = torch.cat((kv_a_proj_qt_bias, q_a_proj_qt_bias),
dim=-1).contiguous()
kv_a_proj_qt_bias = kv_a_proj_qt_bias.reshape(self.kv_lora_rank + self.qk_rope_head_dim, -1).contiguous()
kv_a_proj_qt_bias = trans_rope_weight(kv_a_proj_qt_bias, self.qk_rope_head_dim)
kv_a_proj_qt_bias = kv_a_proj_qt_bias.view(self.kv_lora_rank + self.qk_rope_head_dim).contiguous()
self.quant_bias_qkv = torch.cat((kv_a_proj_qt_bias, q_a_proj_qt_bias), dim=-1).contiguous()
wu_q = self.q_proj.weight.data
wu_q = wu_q.t().reshape(self.num_heads,
self.qk_nope_head_dim + self.qk_rope_head_dim,
-1)
wu_q = wu_q.t().reshape(self.num_heads, self.qk_nope_head_dim + self.qk_rope_head_dim, -1)
wu_q = trans_rope_weight(wu_q, self.qk_rope_head_dim)
wu_q = wu_q.reshape(
self.num_heads * (self.qk_nope_head_dim + self.qk_rope_head_dim),
-1)
wu_q = wu_q.reshape(self.num_heads * (self.qk_nope_head_dim + self.qk_rope_head_dim), -1)
wu_q = transdata(wu_q, block_size=(16, 32)).unsqueeze(0).contiguous()
self.wu_q = torch_npu.npu_format_cast(wu_q, 29)
qb_deq_scl = self.q_proj.deq_scale.data
qb_deq_scl = qb_deq_scl.reshape(
self.num_heads, self.qk_nope_head_dim + self.qk_rope_head_dim, -1)
qb_deq_scl = qb_deq_scl.reshape(self.num_heads, self.qk_nope_head_dim + self.qk_rope_head_dim, -1)
qb_deq_scl = trans_rope_weight(qb_deq_scl, self.qk_rope_head_dim)
self.qb_deq_scl = qb_deq_scl.reshape(
self.num_heads * (self.qk_nope_head_dim + self.qk_rope_head_dim))
self.qb_deq_scl = qb_deq_scl.reshape(self.num_heads * (self.qk_nope_head_dim + self.qk_rope_head_dim))
qb_qt_bias = self.q_proj.quant_bias.data
qb_qt_bias = qb_qt_bias.reshape(
self.num_heads, self.qk_nope_head_dim + self.qk_rope_head_dim, -1)
qb_qt_bias = qb_qt_bias.reshape(self.num_heads, self.qk_nope_head_dim + self.qk_rope_head_dim, -1)
qb_qt_bias = trans_rope_weight(qb_qt_bias, self.qk_rope_head_dim)
self.qb_qt_bias = qb_qt_bias.reshape(
self.num_heads * (self.qk_nope_head_dim + self.qk_rope_head_dim))
self.qb_qt_bias = qb_qt_bias.reshape(self.num_heads * (self.qk_nope_head_dim + self.qk_rope_head_dim))
device = self.q_proj.weight.device
self.gamma1 = self.q_a_layernorm.weight.data
self.beta1 = self.q_a_layernorm.bias.data
self.gamma2 = self.kv_a_layernorm.weight.data
self.gamma1 = self.q_a_layernorm.weight.data # type: ignore[union-attr]
self.beta1 = self.q_a_layernorm.bias.data # type: ignore[union-attr]
self.gamma2 = self.kv_a_layernorm.weight.data # type: ignore[union-attr]
self.quant_scale0 = self.fused_qkv_a_proj.input_scale.data
self.quant_offset0 = self.fused_qkv_a_proj.input_offset.data
self.quant_scale1 = self.q_proj.input_scale.data
@@ -659,9 +633,11 @@ class AscendSFAImpl(MLAAttentionImpl):
# On KV consumers (decode-only) MLAPO uses the transformed weights built above;
# the original fused_qkv_a_proj/q_proj weights and quant params are no longer
# referenced, so drop them to save memory.
if self.vllm_config.kv_transfer_config is not None and \
self.vllm_config.kv_transfer_config.is_kv_consumer and \
self.vllm_config.scheduler_config.max_num_batched_tokens <= MLAPO_MAX_SUPPORTED_TOKENS:
if (
self.vllm_config.kv_transfer_config is not None
and self.vllm_config.kv_transfer_config.is_kv_consumer
and self.vllm_config.scheduler_config.max_num_batched_tokens <= MLAPO_MAX_SUPPORTED_TOKENS
):
self.fused_qkv_a_proj.weight = None
self.fused_qkv_a_proj.deq_scale = None
self.fused_qkv_a_proj.quant_bias = None
@@ -673,13 +649,12 @@ class AscendSFAImpl(MLAAttentionImpl):
def _sfa_preprocess_decode(
self,
hidden_states: torch.Tensor,
kv_cache: Tuple[torch.Tensor, torch.Tensor, torch.Tensor],
kv_cache: tuple[torch.Tensor, torch.Tensor, torch.Tensor],
attn_metadata: M,
need_gather_q_kv: bool,
num_input_tokens: int,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
hidden_states = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(
hidden_states.contiguous(), need_gather_q_kv)
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
hidden_states = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(hidden_states.contiguous(), need_gather_q_kv)
k_nope, k_pe = kv_cache[0], kv_cache[1]
ql_nope = torch.empty(
(num_input_tokens, self.W_UK_T.shape[0], k_nope.shape[-1]),
@@ -734,17 +709,17 @@ class AscendSFAImpl(MLAAttentionImpl):
self,
layer_name,
hidden_states: torch.Tensor, # query in unified attn
kv_cache: Tuple[torch.Tensor, torch.Tensor, torch.Tensor],
kv_cache: tuple[torch.Tensor, torch.Tensor, torch.Tensor],
attn_metadata: M,
need_gather_q_kv: bool = False,
output: Optional[torch.Tensor] = None,
output: torch.Tensor | None = None,
) -> torch.Tensor:
assert output is not None, "Output tensor must be provided."
forward_context = get_forward_context()
if attn_metadata is None:
# Profiling run.
if self.enable_dsa_cp_prefill_only and not forward_context.in_profile_run:
for layer in (self.layer_sharding_kwargs or []):
for layer in self.layer_sharding_kwargs or []:
if is_hidden_layer(layer):
reach_layer_for_shard_weight_series(layer)
return output.fill_(0)
@@ -761,12 +736,13 @@ class AscendSFAImpl(MLAAttentionImpl):
# all-gather o_proj weight for prefill stage of PD mix node
o_proj_full_handle = None
# if is PD mix stage, using original TP o_proj weight, and also need to full gather for o_proj weight for prefill stage.
# if is PD mix stage, using original TP o_proj weight, and also need to full gather for o_proj
# weight for prefill stage.
should_shard_weight = self.enable_dsa_cp_prefill_only or attn_metadata.attn_state not in {
AscendAttentionState.DecodeOnly, AscendAttentionState.SpecDecoding
AscendAttentionState.DecodeOnly,
AscendAttentionState.SpecDecoding,
}
if self.enable_mlapo and num_input_tokens <= MLAPO_MAX_SUPPORTED_TOKENS:
hidden_states, ql_nope, q_pe, q_c = self._sfa_preprocess_decode(
hidden_states=hidden_states,
@@ -776,35 +752,30 @@ class AscendSFAImpl(MLAAttentionImpl):
num_input_tokens=num_input_tokens,
)
q, k = self.indexer_select_pre_process(
x=hidden_states,
qr=q_c,
cos=cos,
sin=sin,
need_gather_q_kv=need_gather_q_kv)
x=hidden_states, qr=q_c, cos=cos, sin=sin, need_gather_q_kv=need_gather_q_kv
)
else:
assert self.fused_qkv_a_proj is not None, "q lora is required for DSA."
maybe_npu_prefetch(inputs=self.fused_qkv_a_proj.weight,
dependency=hidden_states,
enabled=self.enable_prefetch)
maybe_npu_prefetch(
inputs=self.fused_qkv_a_proj.weight, dependency=hidden_states, enabled=self.enable_prefetch
)
qkv_lora = self.fused_qkv_a_proj(hidden_states)[0]
q_c, kv_no_split = qkv_lora.split(
[self.q_lora_rank, self.kv_lora_rank + self.qk_rope_head_dim],
dim=-1,
)
assert self.q_a_layernorm is not None, "q_a_layernorm must be initialized"
q_c = self.q_a_layernorm(q_c)
# Process for Flash Comm V1
if need_gather_q_kv:
q_c = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(
q_c.contiguous(), need_gather_q_kv)
q_c = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(q_c.contiguous(), need_gather_q_kv)
kv_no_split = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(
kv_no_split.contiguous(), need_gather_q_kv)
kv_no_split.contiguous(), need_gather_q_kv
)
q, k = self.indexer_select_pre_process(
x=hidden_states,
qr=q_c,
cos=cos,
sin=sin,
need_gather_q_kv=need_gather_q_kv)
x=hidden_states, qr=q_c, cos=cos, sin=sin, need_gather_q_kv=need_gather_q_kv
)
wait_for_kv_layer_from_connector(layer_name)
@@ -815,22 +786,20 @@ class AscendSFAImpl(MLAAttentionImpl):
actual_seq_lengths_query = attn_metadata.dsa_cp_context.actual_seq_lengths_query
actual_seq_lengths_key = attn_metadata.dsa_cp_context.actual_seq_lengths_key
k_pe, k_nope = self.exec_kv(kv_no_split, cos, sin, kv_cache,
slot_mapping)
k_pe, k_nope = self.exec_kv(kv_no_split, cos, sin, kv_cache, slot_mapping)
if self.enable_dsa_cp:
assert k_pe is not None
assert k_nope is not None
# support all_gather kv async for communication calculation overlap
fused_kv_no_split, kv_ag_handle = all_gather_async(
torch.cat([
k_pe.view(-1, k_pe.shape[-1]),
k_nope.view(-1, k_nope.shape[-1]),
k.view(-1, k.shape[-1])
],
dim=1),
torch.cat(
[k_pe.view(-1, k_pe.shape[-1]), k_nope.view(-1, k_nope.shape[-1]), k.view(-1, k.shape[-1])],
dim=1,
),
get_tp_group(),
async_op=should_shard_weight)
async_op=should_shard_weight,
)
ql_nope, q_pe = self._q_proj_and_k_up_proj(q_c)
q_pe = self.rope_single(q_pe, cos, sin)
@@ -840,34 +809,27 @@ class AscendSFAImpl(MLAAttentionImpl):
kv_ag_handle.wait()
if self.enable_dsa_cp_prefill_only:
for layer in (self.layer_sharding_kwargs or []):
for layer in self.layer_sharding_kwargs or []:
if is_hidden_layer(layer):
reach_layer_for_shard_weight_series(layer)
elif should_shard_weight:
_, o_proj_full_handle = all_gather_async(
self.o_proj_tp_weight,
get_tp_group(),
output=AscendSFAImpl.o_proj_full_pool)
self.o_proj_tp_weight, get_tp_group(), output=AscendSFAImpl.o_proj_full_pool
)
if kv_cache is not None:
assert fused_kv_no_split is not None
k_pe, k_nope, k = fused_kv_no_split.split([
self.qk_rope_head_dim, self.kv_lora_rank, self.head_dim
],
dim=-1)
k_pe, k_nope, k = fused_kv_no_split.split(
[self.qk_rope_head_dim, self.kv_lora_rank, self.head_dim], dim=-1
)
slot_mapping = attn_metadata.slot_mapping.view(-1, 1)
torch_npu.npu_scatter_nd_update_(
kv_cache[0].view(-1, k_nope.shape[-1]), slot_mapping,
k_nope)
torch_npu.npu_scatter_nd_update_(
kv_cache[1].view(-1, k_pe.shape[-1]), slot_mapping,
k_pe)
torch_npu.npu_scatter_nd_update_(kv_cache[0].view(-1, k_nope.shape[-1]), slot_mapping, k_nope)
torch_npu.npu_scatter_nd_update_(kv_cache[1].view(-1, k_pe.shape[-1]), slot_mapping, k_pe)
if kv_cache is not None:
torch_npu.npu_scatter_nd_update_(
kv_cache[2].view(-1, k.shape[-1]),
attn_metadata.slot_mapping.view(-1, 1),
k.view(-1, k.shape[-1])) # b, s, n, d
kv_cache[2].view(-1, k.shape[-1]), attn_metadata.slot_mapping.view(-1, 1), k.view(-1, k.shape[-1])
) # b, s, n, d
topk_indices = self.indexer_select_post_process(
x=hidden_states,
@@ -880,7 +842,8 @@ class AscendSFAImpl(MLAAttentionImpl):
sin=sin,
actual_seq_lengths_query=actual_seq_lengths_query,
actual_seq_lengths_key=actual_seq_lengths_key,
need_gather_q_kv=need_gather_q_kv)
need_gather_q_kv=need_gather_q_kv,
)
attn_output = torch.ops._C_ascend.npu_sparse_flash_attention(
query=ql_nope,
@@ -900,10 +863,12 @@ class AscendSFAImpl(MLAAttentionImpl):
)
attn_output = self._v_up_proj(attn_output)
maybe_npu_prefetch(inputs=self.o_proj.weight,
maybe_npu_prefetch(
inputs=self.o_proj.weight,
dependency=attn_output,
max_size=MAX_O_PROJ_PREFETCH_SIZE,
enabled=self.enable_prefetch)
enabled=self.enable_prefetch,
)
if self.enable_dsa_cp and not self.enable_dsa_cp_prefill_only:
# When using SFA-CP with pd mixed, o_proj has two cases:
@@ -913,7 +878,8 @@ class AscendSFAImpl(MLAAttentionImpl):
attn_output=attn_output,
output=output,
o_proj_full_handle=o_proj_full_handle,
should_shard_weight=should_shard_weight)
should_shard_weight=should_shard_weight,
)
if not require_o_proj_forward:
return result
attn_output = result
@@ -933,8 +899,7 @@ class AscendSFAImpl(MLAAttentionImpl):
need_gather_q_kv: bool = False,
):
k_proj, _ = self.wk(x) # [b,s,7168] @ [7168,128] = [b,s,128]
k_proj = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(
k_proj, need_gather_q_kv)
k_proj = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(k_proj, need_gather_q_kv)
k = self.k_norm(k_proj).unsqueeze(1)
k = k.view(-1, 1, self.head_dim)
@@ -944,17 +909,9 @@ class AscendSFAImpl(MLAAttentionImpl):
cos = cos.view(-1, self.qk_rope_head_dim)
sin = sin.view(-1, self.qk_rope_head_dim)
q, k = rope_forward_triton(q,
k,
cos,
sin,
rope_dim=self.qk_rope_head_dim,
is_neox_style=True)
q, k = rope_forward_triton(q, k, cos, sin, rope_dim=self.qk_rope_head_dim, is_neox_style=True)
else:
k_pe, k_nope = torch.split(
k,
[self.qk_rope_head_dim, self.head_dim - self.qk_rope_head_dim],
dim=-1)
k_pe, k_nope = torch.split(k, [self.qk_rope_head_dim, self.head_dim - self.qk_rope_head_dim], dim=-1)
cos = cos.view(-1, 1, 1, self.qk_rope_head_dim)
sin = sin.view(-1, 1, 1, self.qk_rope_head_dim)
@@ -972,9 +929,9 @@ class AscendSFAImpl(MLAAttentionImpl):
self,
x: torch.Tensor,
qr: torch.Tensor,
q: Optional[torch.Tensor],
q: torch.Tensor | None,
k: torch.Tensor,
kv_cache: Tuple[torch.Tensor, torch.Tensor, torch.Tensor],
kv_cache: tuple[torch.Tensor, torch.Tensor, torch.Tensor],
attn_metadata: M,
cos: torch.Tensor,
sin: torch.Tensor,
@@ -988,9 +945,8 @@ class AscendSFAImpl(MLAAttentionImpl):
cos_q, sin_q = cos, sin
q_pe, q_nope = torch.split(
q,
[self.qk_rope_head_dim, self.head_dim - self.qk_rope_head_dim],
dim=-1) # [b,s,64,64+64]
q, [self.qk_rope_head_dim, self.head_dim - self.qk_rope_head_dim], dim=-1
) # [b,s,64,64+64]
q_pe = q_pe.unsqueeze(2)
q_pe = torch_npu.npu_rotary_mul(q_pe, cos_q, sin_q)
@@ -1000,17 +956,14 @@ class AscendSFAImpl(MLAAttentionImpl):
if kv_cache is not None:
if self.is_kv_producer:
attn_metadata.reshape_cache_event = torch.npu.Event()
torch_npu.npu_scatter_nd_update_(kv_cache[2].view(-1, k.shape[-1]),
attn_metadata.slot_mapping.view(
-1, 1),
k.view(-1,
k.shape[-1])) # b, s, n, d
torch_npu.npu_scatter_nd_update_(
kv_cache[2].view(-1, k.shape[-1]), attn_metadata.slot_mapping.view(-1, 1), k.view(-1, k.shape[-1])
) # b, s, n, d
if self.is_kv_producer:
attn_metadata.reshape_cache_event.record()
weights, _ = self.weights_proj(x)
weights = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(
weights, need_gather_q_kv)
weights = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(weights, need_gather_q_kv)
block_table = attn_metadata.block_tables
@@ -1024,7 +977,8 @@ class AscendSFAImpl(MLAAttentionImpl):
layout_query="TND",
layout_key="PA_BSND",
sparse_count=2048,
sparse_mode=3)
sparse_mode=3,
)
return topk_indices
def _init_o_proj_tp_full_params(self):
@@ -1039,38 +993,33 @@ class AscendSFAImpl(MLAAttentionImpl):
if AscendSFAImpl.o_proj_full_pool is None:
sample = self.o_proj.weight
AscendSFAImpl.o_proj_full_pool = torch.empty(
(sample.shape[0] * self.tp_size, sample.shape[1]),
dtype=sample.dtype,
device=sample.device)
(sample.shape[0] * self.tp_size, sample.shape[1]), dtype=sample.dtype, device=sample.device
)
# Save TP-mode parameters (original sharded weights)
self.o_proj_tp_weight = self.o_proj.weight.clone().detach()
self.o_proj_tp_aclnn_input_scale = self.o_proj.aclnn_input_scale.clone(
).detach()
self.o_proj_tp_aclnn_input_scale_reciprocal = self.o_proj.aclnn_input_scale_reciprocal.clone(
).detach()
self.o_proj_tp_aclnn_input_offset = self.o_proj.aclnn_input_offset.clone(
).detach()
self.o_proj_tp_aclnn_input_scale = self.o_proj.aclnn_input_scale.clone().detach()
self.o_proj_tp_aclnn_input_scale_reciprocal = self.o_proj.aclnn_input_scale_reciprocal.clone().detach()
self.o_proj_tp_aclnn_input_offset = self.o_proj.aclnn_input_offset.clone().detach()
# Initially switch to TP mode for graph capture
self.o_proj.weight.set_(self.o_proj_tp_weight)
self.o_proj.aclnn_input_scale.set_(self.o_proj_tp_aclnn_input_scale)
self.o_proj.aclnn_input_scale_reciprocal.set_(
self.o_proj_tp_aclnn_input_scale_reciprocal)
self.o_proj.aclnn_input_scale_reciprocal.set_(self.o_proj_tp_aclnn_input_scale_reciprocal)
self.o_proj.aclnn_input_offset.set_(self.o_proj_tp_aclnn_input_offset)
# Precompute Full-mode quantization parameters by repeating TP parameters across all TP ranks
self.o_proj_full_aclnn_input_scale = self.o_proj.aclnn_input_scale.repeat(
self.tp_size)
self.o_proj_full_aclnn_input_scale_reciprocal = self.o_proj.aclnn_input_scale_reciprocal.repeat(
self.tp_size)
self.o_proj_full_aclnn_input_offset = self.o_proj.aclnn_input_offset.repeat(
self.tp_size)
self.o_proj_full_aclnn_input_scale = self.o_proj.aclnn_input_scale.repeat(self.tp_size)
self.o_proj_full_aclnn_input_scale_reciprocal = self.o_proj.aclnn_input_scale_reciprocal.repeat(self.tp_size)
self.o_proj_full_aclnn_input_offset = self.o_proj.aclnn_input_offset.repeat(self.tp_size)
def _handle_o_proj_weight_switch_and_forward(
self, attn_output: torch.Tensor, output: torch.Tensor,
o_proj_full_handle: Optional[torch.distributed.Work],
should_shard_weight: bool) -> Tuple[torch.Tensor, bool]:
self,
attn_output: torch.Tensor,
output: torch.Tensor,
o_proj_full_handle: torch.distributed.Work | None,
should_shard_weight: bool,
) -> tuple[torch.Tensor, bool]:
"""
Handle o_proj weight switching between TP-mode and Full-mode, and execute forward computation.
"""
@@ -1082,36 +1031,30 @@ class AscendSFAImpl(MLAAttentionImpl):
# Switch o_proj to Full-mode (gathered weight from all TP ranks)
self.o_proj.weight.set_(AscendSFAImpl.o_proj_full_pool)
self.o_proj.aclnn_input_scale.set_(
self.o_proj_full_aclnn_input_scale)
self.o_proj.aclnn_input_scale_reciprocal.set_(
self.o_proj_full_aclnn_input_scale_reciprocal)
self.o_proj.aclnn_input_offset.set_(
self.o_proj_full_aclnn_input_offset)
self.o_proj.aclnn_input_scale.set_(self.o_proj_full_aclnn_input_scale)
self.o_proj.aclnn_input_scale_reciprocal.set_(self.o_proj_full_aclnn_input_scale_reciprocal)
self.o_proj.aclnn_input_offset.set_(self.o_proj_full_aclnn_input_offset)
# Apply quantization method and execute forward computation
output[...] = self.o_proj.quant_method.quant_method.apply(
self.o_proj, attn_output)
output[...] = self.o_proj.quant_method.quant_method.apply(self.o_proj, attn_output)
# Switch o_proj back to TP-mode for subsequent decode operations
self.o_proj.weight.set_(self.o_proj_tp_weight)
self.o_proj.aclnn_input_scale.set_(
self.o_proj_tp_aclnn_input_scale)
self.o_proj.aclnn_input_scale_reciprocal.set_(
self.o_proj_tp_aclnn_input_scale_reciprocal)
self.o_proj.aclnn_input_offset.set_(
self.o_proj_tp_aclnn_input_offset)
self.o_proj.aclnn_input_scale.set_(self.o_proj_tp_aclnn_input_scale)
self.o_proj.aclnn_input_scale_reciprocal.set_(self.o_proj_tp_aclnn_input_scale_reciprocal)
self.o_proj.aclnn_input_offset.set_(self.o_proj_tp_aclnn_input_offset)
return output, False
else:
# For decode scenario: perform all-to-all communication on o_proj input activations
# Reshape for all-to-all: [batch * seq, tp_size, head_dim] -> [tp_size, batch * seq, head_dim]
send = attn_output.view(-1, self.tp_size, self.num_heads *
self.v_head_dim).permute(1, 0, 2).reshape(
-1, self.num_heads * self.v_head_dim)
send = (
attn_output.view(-1, self.tp_size, self.num_heads * self.v_head_dim)
.permute(1, 0, 2)
.reshape(-1, self.num_heads * self.v_head_dim)
)
attn_output = torch.empty_like(send)
torch.distributed.all_to_all_single(
attn_output, send, group=get_tp_group().device_group)
torch.distributed.all_to_all_single(attn_output, send, group=get_tp_group().device_group)
return attn_output, True


@@ -21,26 +21,21 @@ from __future__ import annotations
import time
from collections import defaultdict
from dataclasses import dataclass, fields
from typing import Type, Union
from vllm._bc_linter import bc_linter_include
from vllm.config import SchedulerConfig, VllmConfig
from vllm.distributed.ec_transfer.ec_connector.base import ECConnectorMetadata
from vllm.distributed.kv_events import KVEventBatch
from vllm.distributed.kv_transfer.kv_connector.v1.base import \
KVConnectorMetadata
from vllm.distributed.kv_transfer.kv_connector.v1.metrics import \
KVConnectorStats
from vllm.distributed.kv_transfer.kv_connector.v1.base import KVConnectorMetadata
from vllm.distributed.kv_transfer.kv_connector.v1.metrics import KVConnectorStats
from vllm.logger import init_logger
from vllm.v1.core.kv_cache_manager import KVCacheBlocks
from vllm.v1.core.sched.async_scheduler import AsyncScheduler
from vllm.v1.core.sched.output import NewRequestData, SchedulerOutput
from vllm.v1.core.sched.request_queue import (SchedulingPolicy,
create_request_queue)
from vllm.v1.core.sched.request_queue import SchedulingPolicy, create_request_queue
from vllm.v1.core.sched.scheduler import Scheduler
from vllm.v1.core.sched.utils import check_stop, remove_all
from vllm.v1.engine import (EngineCoreEventType, EngineCoreOutput,
EngineCoreOutputs, FinishReason)
from vllm.v1.engine import EngineCoreEventType, EngineCoreOutput, EngineCoreOutputs, FinishReason
from vllm.v1.outputs import ModelRunnerOutput
from vllm.v1.request import Request, RequestStatus
from vllm.v1.spec_decode.metrics import SpecDecodingStats
@@ -51,26 +46,22 @@ logger = init_logger(__name__)
@dataclass
class RecomputeSchedulerConfig(SchedulerConfig):
scheduler_cls: Union[str, Type[object]] = (
"vllm_ascend.core.recompute_scheduler.RecomputeScheduler")
scheduler_cls: str | type[object] = "vllm_ascend.core.recompute_scheduler.RecomputeScheduler"
@classmethod
def initialize_from_config(cls, vllm_config: VllmConfig):
vllm_scheduler_config = vllm_config.scheduler_config
scheduler_config = {
field.name: getattr(vllm_scheduler_config, field.name)
for field in fields(vllm_scheduler_config) if field.init
for field in fields(vllm_scheduler_config)
if field.init
}
if vllm_scheduler_config.async_scheduling:
scheduler_config["scheduler_cls"] = (
"vllm_ascend.core.recompute_scheduler.AsyncRecomputeScheduler")
scheduler_config["scheduler_cls"] = "vllm_ascend.core.recompute_scheduler.AsyncRecomputeScheduler"
else:
scheduler_config["scheduler_cls"] = (
"vllm_ascend.core.recompute_scheduler.RecomputeScheduler")
scheduler_config[
"max_model_len"] = vllm_config.model_config.max_model_len
scheduler_config[
"is_encoder_decoder"] = vllm_config.model_config.is_encoder_decoder
scheduler_config["scheduler_cls"] = "vllm_ascend.core.recompute_scheduler.RecomputeScheduler"
scheduler_config["max_model_len"] = vllm_config.model_config.max_model_len
scheduler_config["is_encoder_decoder"] = vllm_config.model_config.is_encoder_decoder
return cls(**scheduler_config)
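
`initialize_from_config` above clones every constructor-visible field of the upstream `SchedulerConfig` into the subclass before overriding a handful of entries. A self-contained sketch of that dataclasses pattern (the config classes here are stand-ins, not the real vLLM types):

```python
from dataclasses import dataclass, fields


@dataclass
class BaseConfig:
    max_num_seqs: int = 256
    policy: str = "fcfs"


@dataclass
class CustomConfig(BaseConfig):
    scheduler_cls: str = "custom.Scheduler"

    @classmethod
    def from_base(cls, base: BaseConfig) -> "CustomConfig":
        # Copy every field that the constructor accepts (field.init),
        # then construct the subclass; subclass defaults fill the rest.
        kwargs = {f.name: getattr(base, f.name) for f in fields(base) if f.init}
        return cls(**kwargs)


cfg = CustomConfig.from_base(BaseConfig(max_num_seqs=128))
assert cfg.max_num_seqs == 128 and cfg.scheduler_cls == "custom.Scheduler"
```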
@@ -125,33 +116,32 @@ class RecomputeScheduler(Scheduler):
while req_index < len(self.running) and token_budget > 0:
request = self.running[req_index]
if (request.num_output_placeholders > 0
if (
request.num_output_placeholders > 0
# This is (num_computed_tokens + 1) - (num_output_placeholders - 1).
# Since output placeholders are also included in the computed tokens
# count, we subtract (num_output_placeholders - 1) to remove any draft
# tokens, so that we can be sure no further steps are needed even if
# they are all rejected.
and request.num_computed_tokens + 2 -
request.num_output_placeholders
>= request.num_prompt_tokens + request.max_tokens):
and request.num_computed_tokens + 2 - request.num_output_placeholders
>= request.num_prompt_tokens + request.max_tokens
):
# Async scheduling: Avoid scheduling an extra step when we are sure that
# the previous step has reached request.max_tokens. We don't schedule
# partial draft tokens since this prevents uniform decode optimizations.
req_index += 1
continue
num_new_tokens = (request.num_tokens_with_spec +
request.num_output_placeholders -
request.num_computed_tokens)
num_new_tokens = (
request.num_tokens_with_spec + request.num_output_placeholders - request.num_computed_tokens
)
if 0 < self.scheduler_config.long_prefill_token_threshold < num_new_tokens:
num_new_tokens = self.scheduler_config.long_prefill_token_threshold
num_new_tokens = min(num_new_tokens, token_budget)
# Make sure the input position does not exceed the max model len.
# This is necessary when using spec decoding.
num_new_tokens = min(
num_new_tokens,
self.max_model_len - 1 - request.num_computed_tokens)
num_new_tokens = min(num_new_tokens, self.max_model_len - 1 - request.num_computed_tokens)
# Schedule encoder inputs.
encoder_inputs_to_schedule = None
@@ -209,9 +199,10 @@ class RecomputeScheduler(Scheduler):
recomputed_req = self.running.pop()
self.kv_cache_manager.free(recomputed_req)
recomputed_reqs.append(
RecomputeReqInfo(recomputed_req.request_id,
recomputed_req.output_token_ids,
recomputed_req.client_index))
RecomputeReqInfo(
recomputed_req.request_id, recomputed_req.output_token_ids, recomputed_req.client_index
)
)
if recomputed_req == request:
break
else:
@@ -223,28 +214,23 @@ class RecomputeScheduler(Scheduler):
self.running.remove(preempted_req)
if preempted_req in scheduled_running_reqs:
scheduled_running_reqs.remove(preempted_req)
token_budget += num_scheduled_tokens[
preempted_req.request_id]
token_budget += num_scheduled_tokens[preempted_req.request_id]
req_to_new_blocks.pop(preempted_req.request_id)
num_scheduled_tokens.pop(
preempted_req.request_id)
scheduled_spec_decode_tokens.pop(
preempted_req.request_id, None)
preempted_encoder_inputs = scheduled_encoder_inputs.pop(
preempted_req.request_id, None)
num_scheduled_tokens.pop(preempted_req.request_id)
scheduled_spec_decode_tokens.pop(preempted_req.request_id, None)
preempted_encoder_inputs = scheduled_encoder_inputs.pop(preempted_req.request_id, None)
if preempted_encoder_inputs:
# Restore encoder compute budget if the preempted
# request had encoder inputs scheduled in this step.
num_embeds_to_restore = sum(
preempted_req.get_num_encoder_embeds(i)
for i in preempted_encoder_inputs)
preempted_req.get_num_encoder_embeds(i) for i in preempted_encoder_inputs
)
encoder_compute_budget += num_embeds_to_restore
req_index -= 1
else:
preempted_req = self.running.pop()
self._preempt_request(preempted_req,
scheduled_timestamp)
self._preempt_request(preempted_req, scheduled_timestamp)
preempted_reqs.append(preempted_req)
if preempted_req == request:
# No more request to preempt. Cannot schedule this request.
@@ -263,23 +249,20 @@ class RecomputeScheduler(Scheduler):
# Speculative decode related.
if request.spec_token_ids:
num_scheduled_spec_tokens = (num_new_tokens +
request.num_computed_tokens -
request.num_tokens -
request.num_output_placeholders)
num_scheduled_spec_tokens = (
num_new_tokens + request.num_computed_tokens - request.num_tokens - request.num_output_placeholders
)
if num_scheduled_spec_tokens > 0:
# Trim spec_token_ids list to num_scheduled_spec_tokens.
del request.spec_token_ids[num_scheduled_spec_tokens:]
scheduled_spec_decode_tokens[request.request_id] = (
request.spec_token_ids)
scheduled_spec_decode_tokens[request.request_id] = request.spec_token_ids
# New spec tokens will be set in `update_draft_token_ids` before the
# next step when applicable.
request.spec_token_ids = []
# Encoder-related.
if encoder_inputs_to_schedule:
scheduled_encoder_inputs[request.request_id] = (
encoder_inputs_to_schedule)
scheduled_encoder_inputs[request.request_id] = encoder_inputs_to_schedule
# Allocate the encoder cache.
for i in encoder_inputs_to_schedule:
self.encoder_cache_manager.allocate(request, i)
@@ -294,8 +277,10 @@ class RecomputeScheduler(Scheduler):
scheduled_loras: set[int] = set()
if self.lora_config:
scheduled_loras = set(
req.lora_request.lora_int_id for req in scheduled_running_reqs
if req.lora_request and req.lora_request.lora_int_id > 0)
req.lora_request.lora_int_id
for req in scheduled_running_reqs
if req.lora_request and req.lora_request.lora_int_id > 0
)
assert len(scheduled_loras) <= self.lora_config.max_loras
# Use a temporary RequestQueue to collect requests that need to be
@@ -337,9 +322,14 @@ class RecomputeScheduler(Scheduler):
# Check that adding the request still respects the max_loras
# constraint.
if (self.lora_config and request.lora_request and
(len(scheduled_loras) == self.lora_config.max_loras and
request.lora_request.lora_int_id not in scheduled_loras)):
if (
self.lora_config
and request.lora_request
and (
len(scheduled_loras) == self.lora_config.max_loras
and request.lora_request.lora_int_id not in scheduled_loras
)
):
# Scheduling would exceed max_loras, skip.
self.waiting.pop_request()
skipped_waiting_requests.prepend_request(request)
@@ -351,14 +341,15 @@ class RecomputeScheduler(Scheduler):
# Get already-cached tokens.
if request.num_computed_tokens == 0:
# Get locally-cached tokens.
new_computed_blocks, num_new_local_computed_tokens = (
self.kv_cache_manager.get_computed_blocks(request))
new_computed_blocks, num_new_local_computed_tokens = self.kv_cache_manager.get_computed_blocks(
request
)
# Get externally-cached tokens if using a KVConnector.
if self.connector is not None:
ext_tokens, load_kv_async = (
self.connector.get_num_new_matched_tokens(
request, num_new_local_computed_tokens))
ext_tokens, load_kv_async = self.connector.get_num_new_matched_tokens(
request, num_new_local_computed_tokens
)
if ext_tokens is None:
# The request cannot be scheduled because
@@ -372,8 +363,7 @@ class RecomputeScheduler(Scheduler):
num_external_computed_tokens = ext_tokens
# Total computed tokens (local + external).
num_computed_tokens = (num_new_local_computed_tokens +
num_external_computed_tokens)
num_computed_tokens = num_new_local_computed_tokens + num_external_computed_tokens
else:
# KVTransfer: WAITING reqs have num_computed_tokens > 0
# after async KV recvs are completed.
@@ -401,8 +391,7 @@ class RecomputeScheduler(Scheduler):
# chunked prefill has to be enabled explicitly to allow
# pooling requests to be chunked
if (not self.scheduler_config.enable_chunked_prefill
and num_new_tokens > token_budget):
if not self.scheduler_config.enable_chunked_prefill and num_new_tokens > token_budget:
# If chunked_prefill is disabled,
# we can stop the scheduling here.
break
@@ -433,9 +422,7 @@ class RecomputeScheduler(Scheduler):
# extra block gets allocated which
# creates a mismatch between the number
# of local and remote blocks.
effective_lookahead_tokens = (0 if request.num_computed_tokens
== 0 else
self.num_lookahead_tokens)
effective_lookahead_tokens = 0 if request.num_computed_tokens == 0 else self.num_lookahead_tokens
# Determine if we need to allocate cross-attention blocks.
if self.is_encoder_decoder and request.has_encoder_inputs:
@@ -443,8 +430,7 @@ class RecomputeScheduler(Scheduler):
# always padded to the maximum length. If we support other
# encoder-decoder models, this will need to be updated if we
# want to only allocate what is needed.
num_encoder_tokens = (
self.scheduler_config.max_num_encoder_input_tokens)
num_encoder_tokens = self.scheduler_config.max_num_encoder_input_tokens
else:
num_encoder_tokens = 0
@@ -488,20 +474,17 @@ class RecomputeScheduler(Scheduler):
req_index += 1
self.running.append(request)
if self.log_stats:
request.record_event(EngineCoreEventType.SCHEDULED,
scheduled_timestamp)
request.record_event(EngineCoreEventType.SCHEDULED, scheduled_timestamp)
if request.status == RequestStatus.WAITING:
scheduled_new_reqs.append(request)
elif request.status == RequestStatus.PREEMPTED:
scheduled_resumed_reqs.append(request)
else:
raise RuntimeError(
f"Invalid request status: {request.status}")
raise RuntimeError(f"Invalid request status: {request.status}")
if self.lora_config and request.lora_request:
scheduled_loras.add(request.lora_request.lora_int_id)
req_to_new_blocks[request.request_id] = (
self.kv_cache_manager.get_blocks(request.request_id))
req_to_new_blocks[request.request_id] = self.kv_cache_manager.get_blocks(request.request_id)
num_scheduled_tokens[request.request_id] = num_new_tokens
token_budget -= num_new_tokens
request.status = RequestStatus.RUNNING
@@ -511,8 +494,7 @@ class RecomputeScheduler(Scheduler):
request.num_cached_tokens = num_computed_tokens
# Encoder-related.
if encoder_inputs_to_schedule:
scheduled_encoder_inputs[request.request_id] = (
encoder_inputs_to_schedule)
scheduled_encoder_inputs[request.request_id] = encoder_inputs_to_schedule
# Allocate the encoder cache.
for i in encoder_inputs_to_schedule:
self.encoder_cache_manager.allocate(request, i)
@@ -522,8 +504,7 @@ class RecomputeScheduler(Scheduler):
for i in external_load_encoder_input:
self.encoder_cache_manager.allocate(request, i)
if self.ec_connector is not None:
self.ec_connector.update_state_after_alloc(
request, i)
self.ec_connector.update_state_after_alloc(request, i)
# Put back any skipped requests at the head of the waiting queue
if skipped_waiting_requests:
self.waiting.prepend_requests(skipped_waiting_requests)
@@ -537,20 +518,15 @@ class RecomputeScheduler(Scheduler):
# Since some requests in the RUNNING queue may not be scheduled in
# this step, the total number of scheduled requests can be smaller than
# len(self.running).
assert len(scheduled_new_reqs) + len(scheduled_resumed_reqs) + len(
scheduled_running_reqs) <= len(self.running)
assert len(scheduled_new_reqs) + len(scheduled_resumed_reqs) + len(scheduled_running_reqs) <= len(self.running)
# Get the longest common prefix among all requests in the running queue.
# This can be potentially used for cascade attention.
num_common_prefix_blocks = [0] * len(
self.kv_cache_config.kv_cache_groups)
with record_function_or_nullcontext(
"schedule: get_num_common_prefix_blocks"):
num_common_prefix_blocks = [0] * len(self.kv_cache_config.kv_cache_groups)
with record_function_or_nullcontext("schedule: get_num_common_prefix_blocks"):
if self.running:
any_request = self.running[0]
num_common_prefix_blocks = (
self.kv_cache_manager.get_num_common_prefix_blocks(
any_request.request_id))
num_common_prefix_blocks = self.kv_cache_manager.get_num_common_prefix_blocks(any_request.request_id)
# Construct the scheduler output.
if self.use_v2_model_runner:
@@ -561,17 +537,16 @@ class RecomputeScheduler(Scheduler):
req,
req_to_new_blocks[req.request_id].get_block_ids(),
req._all_token_ids,
) for req in scheduled_new_reqs
)
for req in scheduled_new_reqs
]
else:
new_reqs_data = [
NewRequestData.from_request(
req, req_to_new_blocks[req.request_id].get_block_ids())
NewRequestData.from_request(req, req_to_new_blocks[req.request_id].get_block_ids())
for req in scheduled_new_reqs
]
with record_function_or_nullcontext(
"schedule: make_cached_request_data"):
with record_function_or_nullcontext("schedule: make_cached_request_data"):
cached_reqs_data = self._make_cached_request_data(
scheduled_running_reqs,
scheduled_resumed_reqs,
@@ -592,15 +567,13 @@ class RecomputeScheduler(Scheduler):
scheduled_spec_decode_tokens=scheduled_spec_decode_tokens,
scheduled_encoder_inputs=scheduled_encoder_inputs,
num_common_prefix_blocks=num_common_prefix_blocks,
preempted_req_ids={req.request_id
for req in preempted_reqs},
preempted_req_ids={req.request_id for req in preempted_reqs},
# finished_req_ids is an existing state in the scheduler,
# instead of being newly scheduled in this step.
# It contains the request IDs that are finished in between
# the previous and the current steps.
finished_req_ids=self.finished_req_ids,
free_encoder_mm_hashes=self.encoder_cache_manager.
get_freed_mm_hashes(),
free_encoder_mm_hashes=self.encoder_cache_manager.get_freed_mm_hashes(),
recomputed_reqs=recomputed_reqs,
)
@@ -609,14 +582,12 @@ class RecomputeScheduler(Scheduler):
# 2. Wrap up all the KV cache load / save ops into an opaque object
# 3. Clear the internal states of the connector
if self.connector is not None:
meta: KVConnectorMetadata = self.connector.build_connector_meta(
scheduler_output)
meta: KVConnectorMetadata = self.connector.build_connector_meta(scheduler_output)
scheduler_output.kv_connector_metadata = meta
# Build the connector meta for ECConnector
if self.ec_connector is not None:
ec_meta: ECConnectorMetadata = self.ec_connector.build_connector_meta(
scheduler_output)
ec_meta: ECConnectorMetadata = self.ec_connector.build_connector_meta(scheduler_output)
scheduler_output.ec_connector_metadata = ec_meta
with record_function_or_nullcontext("schedule: update_after_schedule"):
@@ -639,8 +610,8 @@ class RecomputeScheduler(Scheduler):
outputs: dict[int, list[EngineCoreOutput]] = defaultdict(list)
spec_decoding_stats: SpecDecodingStats | None = None
kv_connector_stats: KVConnectorStats | None = (
kv_connector_output.kv_connector_stats
if kv_connector_output else None)
kv_connector_stats: KVConnectorStats | None = (
kv_connector_output.kv_connector_stats if kv_connector_output else None
)
if kv_connector_stats and self.connector:
kv_stats = self.connector.get_kv_connector_stats()
if kv_stats:
@@ -651,8 +622,7 @@ class RecomputeScheduler(Scheduler):
# These blocks contain externally computed tokens that failed to
# load. Identify affected requests and adjust their computed token
# count to trigger recomputation of the invalid blocks.
failed_kv_load_req_ids = self._handle_invalid_blocks(
kv_connector_output.invalid_block_ids)
failed_kv_load_req_ids = self._handle_invalid_blocks(kv_connector_output.invalid_block_ids)
# return recomputed requests as EngineCoreOutput
if scheduler_output.recomputed_reqs is not None:
@@ -663,7 +633,8 @@ class RecomputeScheduler(Scheduler):
finish_reason=FinishReason.STOP,
new_token_ids=[req_info.output_token_ids[-1]],
stop_reason="recomputed",
))
)
)
# NOTE(woosuk): As len(num_scheduled_tokens) can be up to 1K or more,
# the below loop can be a performance bottleneck. We should do our best
@@ -683,11 +654,9 @@ class RecomputeScheduler(Scheduler):
continue
req_index = model_runner_output.req_id_to_index[req_id]
generated_token_ids = (sampled_token_ids[req_index]
if sampled_token_ids else [])
generated_token_ids = sampled_token_ids[req_index] if sampled_token_ids else []
scheduled_spec_token_ids = (
scheduler_output.scheduled_spec_decode_tokens.get(req_id))
scheduled_spec_token_ids = scheduler_output.scheduled_spec_decode_tokens.get(req_id)
if scheduled_spec_token_ids:
num_draft_tokens = len(scheduled_spec_token_ids)
num_accepted = len(generated_token_ids) - 1
@@ -717,15 +686,13 @@ class RecomputeScheduler(Scheduler):
# Check for stop and update request status.
if new_token_ids:
new_token_ids, stopped = self._update_request_with_output(
request, new_token_ids)
new_token_ids, stopped = self._update_request_with_output(request, new_token_ids)
# Stop checking for pooler models.
pooler_output = None
if pooler_outputs:
pooler_output = pooler_outputs[req_index]
stopped = check_stop(request, self.max_model_len,
pooler_output)
stopped = check_stop(request, self.max_model_len, pooler_output)
if stopped:
kv_transfer_params = self._free_request(request)
@@ -735,19 +702,14 @@ class RecomputeScheduler(Scheduler):
stopped_preempted_reqs.add(request)
# Extract sample logprobs if needed.
if (request.sampling_params is not None
and request.sampling_params.logprobs is not None
and logprobs):
new_logprobs = logprobs.slice_request(req_index,
len(new_token_ids))
if request.sampling_params is not None and request.sampling_params.logprobs is not None and logprobs:
new_logprobs = logprobs.slice_request(req_index, len(new_token_ids))
if new_token_ids and self.structured_output_manager.should_advance(
request):
if new_token_ids and self.structured_output_manager.should_advance(request):
struct_output_request = request.structured_output_request
assert struct_output_request is not None
assert struct_output_request.grammar is not None
struct_output_request.grammar.accept_tokens(
req_id, new_token_ids)
struct_output_request.grammar.accept_tokens(req_id, new_token_ids)
if num_nans_in_logits is not None and req_id in num_nans_in_logits:
request.num_nans_in_logits = num_nans_in_logits[req_id]
@@ -770,7 +732,8 @@ class RecomputeScheduler(Scheduler):
trace_headers=request.trace_headers,
num_cached_tokens=request.num_cached_tokens,
num_nans_in_logits=request.num_nans_in_logits,
))
)
)
else:
# Invariant: EngineCore returns no partial prefill outputs.
assert not prompt_logprobs_tensors
@@ -805,10 +768,7 @@ class RecomputeScheduler(Scheduler):
# Create EngineCoreOutputs for all clients that have requests with
# outputs in this step.
engine_core_outputs = {
client_index: EngineCoreOutputs(outputs=outs)
for client_index, outs in outputs.items()
}
engine_core_outputs = {client_index: EngineCoreOutputs(outputs=outs) for client_index, outs in outputs.items()}
finished_req_ids = self.finished_req_ids_dict
if finished_req_ids:
@@ -819,12 +779,10 @@ class RecomputeScheduler(Scheduler):
if (eco := engine_core_outputs.get(client_index)) is not None:
eco.finished_requests = finished_set
else:
engine_core_outputs[client_index] = EngineCoreOutputs(
finished_requests=finished_set)
engine_core_outputs[client_index] = EngineCoreOutputs(finished_requests=finished_set)
finished_req_ids.clear()
if (stats := self.make_stats(spec_decoding_stats,
kv_connector_stats)) is not None:
if (stats := self.make_stats(spec_decoding_stats, kv_connector_stats)) is not None:
# Return stats to only one of the front-ends.
if (eco := next(iter(engine_core_outputs.values()), None)) is None:
# We must return the stats even if there are no request
@@ -836,6 +794,5 @@ class RecomputeScheduler(Scheduler):
class AsyncRecomputeScheduler(AsyncScheduler, RecomputeScheduler):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)

View File

@@ -16,7 +16,6 @@
#
import os
import time
from typing import Optional
import pandas as pd
from vllm.config import VllmConfig
@@ -25,8 +24,7 @@ from vllm.logger import logger
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry
from vllm.v1.core.kv_cache_manager import KVCacheBlocks
from vllm.v1.core.sched.output import NewRequestData, SchedulerOutput
from vllm.v1.core.sched.request_queue import (SchedulingPolicy,
create_request_queue)
from vllm.v1.core.sched.request_queue import SchedulingPolicy, create_request_queue
from vllm.v1.core.sched.scheduler import Scheduler
from vllm.v1.engine import EngineCoreEventType
from vllm.v1.kv_cache_interface import KVCacheConfig
@@ -43,8 +41,9 @@ class BudgetRefiner:
if not self.enabled:
return
logger.info(
"Dynamic batch is enabled with SLO limit: {}, and chunked prefill is forced to be activated because dynamic batch relies on it"
.format(str(slo_limit)))
"Dynamic batch is enabled with SLO limit: {}, and chunked prefill is "
"forced to be activated because dynamic batch relies on it".format(str(slo_limit))
)
self.lookup: dict[tuple[int, int], int] = {}
self.context_keys: set[int] = set()
self.dnum_keys: set[int] = set()
@@ -61,19 +60,20 @@ class BudgetRefiner:
"The dynamic batching feature requires the lookup table "
"'profile_table.csv', but it was not found at '%s'. "
"Please download the corresponding table file.",
table_file_path)
table_file_path,
)
self.enabled = False
return
else:
df = pd.read_csv(table_file_path)
grouped = df.groupby(['ctx_len', 'd_num'])
grouped = df.groupby(["ctx_len", "d_num"])
for (ctx_len, d_num), group in grouped:
valid = group[group['cost'] <= slo_limit]
valid = group[group["cost"] <= slo_limit]
if not valid.empty:
max_row = valid.loc[valid['chunk_size'].idxmax()]
max_row = valid.loc[valid["chunk_size"].idxmax()]
assert isinstance(ctx_len, int), "ctx_len must be an integer"
assert isinstance(d_num, int), "d_num must be an integer"
self.lookup[(ctx_len, d_num)] = int(max_row['chunk_size'])
self.lookup[(ctx_len, d_num)] = int(max_row["chunk_size"])
self.context_keys.add(ctx_len)
self.dnum_keys.add(d_num)
self.context_keys = set(sorted(self.context_keys))
@@ -97,7 +97,10 @@ class BudgetRefiner:
logger.warn(f"Table miss for ctx,dnum{aligned_ctx, aligned_dnum}")
budget = self.default_budget
# For debug.
# logger.info(f"budget {budget}, ctx,dnum {aligned_ctx, aligned_dnum}, raw ctx,dnum {num_deocde_tokens, num_decode}")
# logger.info(
# f"budget {budget}, ctx,dnum {aligned_ctx, aligned_dnum}, "
# f"raw ctx,dnum {num_deocde_tokens, num_decode}"
# )
return budget
def refine_budget(self, running_request, budget):
@@ -106,9 +109,8 @@ class BudgetRefiner:
return budget
# Assume all running requests will be scheduled.
num_decode_token_lst = [
req.num_tokens_with_spec \
for req in running_request \
if req.num_computed_tokens >= req.num_prompt_tokens ]
req.num_tokens_with_spec for req in running_request if req.num_computed_tokens >= req.num_prompt_tokens
]
num_decode = len(num_decode_token_lst)
if num_decode <= 0:
return budget
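
The lookup flow that `refine_budget` feeds into can be summarized with a short self-contained sketch. The table values, the round-up alignment, and the helper names below are illustrative assumptions; the actual alignment code is not part of this hunk:

```python
import bisect

# Hypothetical pre-aggregated SLO table: (ctx_len, d_num) -> largest
# chunk_size whose profiled cost stays within the SLO limit.
lookup = {(1024, 4): 512, (1024, 8): 256, (4096, 4): 256, (4096, 8): 128}
context_keys = sorted({k[0] for k in lookup})
dnum_keys = sorted({k[1] for k in lookup})
DEFAULT_BUDGET = 2048  # stands in for max_num_batched_tokens

def align_up(keys: list[int], value: int) -> int:
    # Snap a raw value to the smallest table key covering it (assumption).
    i = bisect.bisect_left(keys, value)
    return keys[min(i, len(keys) - 1)]

def refine(num_decode_tokens: int, num_decode: int) -> int:
    key = (align_up(context_keys, num_decode_tokens),
           align_up(dnum_keys, num_decode))
    # A table miss keeps the default (unrefined) budget, as above.
    return lookup.get(key, DEFAULT_BUDGET)

print(refine(3000, 5))  # aligned to (4096, 8) -> 128
```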
@@ -125,18 +127,25 @@ class SchedulerDynamicBatch(Scheduler):
vllm_config: VllmConfig,
kv_cache_config: KVCacheConfig,
structured_output_manager: StructuredOutputManager,
block_size: Optional[int] = None,
block_size: int | None = None,
mm_registry: MultiModalRegistry = MULTIMODAL_REGISTRY,
include_finished_set: bool = False,
log_stats: bool = False,
) -> None:
super().__init__(vllm_config, kv_cache_config,
structured_output_manager, block_size, mm_registry,
include_finished_set, log_stats)
super().__init__(
vllm_config,
kv_cache_config,
structured_output_manager,
block_size,
mm_registry,
include_finished_set,
log_stats,
)
self.running: list[Request] = []
self.budget_refiner = BudgetRefiner(
default_budget=self.scheduler_config.max_num_batched_tokens,
slo_limit=self.scheduler_config.SLO_limits_for_dynamic_batch)
slo_limit=self.scheduler_config.SLO_limits_for_dynamic_batch,
)
def schedule(self) -> SchedulerOutput:
# NOTE: This scheduling algorithm is developed based on "super().schedule()"
@@ -159,20 +168,13 @@ class SchedulerDynamicBatch(Scheduler):
req_to_new_blocks: dict[str, KVCacheBlocks] = {}
num_scheduled_tokens: dict[str, int] = {}
token_budget = self.max_num_scheduled_tokens
token_budget = self.budget_refiner.refine_budget(
self.running, token_budget)
token_budget = self.budget_refiner.refine_budget(self.running, token_budget)
# NOTE: We move the prefill requests to the end of the self.running
# list and keep the relative order unchanged. This rearrangement makes the
# scheduling algorithm strictly decode-first with chunked prefills.
d_lst = [
req for req in self.running
if req.num_computed_tokens >= req.num_prompt_tokens
]
p_lst = [
req for req in self.running
if req.num_computed_tokens < req.num_prompt_tokens
]
d_lst = [req for req in self.running if req.num_computed_tokens >= req.num_prompt_tokens]
p_lst = [req for req in self.running if req.num_computed_tokens < req.num_prompt_tokens]
self.running = d_lst + p_lst
# Encoder-related.
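
The `d_lst` / `p_lst` rearrangement above is a stable partition of the running queue; a minimal standalone illustration (`Req` is a stand-in for vLLM's `Request` class):

```python
from dataclasses import dataclass

@dataclass
class Req:  # stand-in for vLLM's Request class
    rid: str
    num_computed_tokens: int
    num_prompt_tokens: int

running = [Req("a", 10, 10), Req("b", 3, 8), Req("c", 7, 4)]
# Decode phase: the whole prompt has already been computed.
d_lst = [r for r in running if r.num_computed_tokens >= r.num_prompt_tokens]
# Prefill phase: still consuming prompt tokens.
p_lst = [r for r in running if r.num_computed_tokens < r.num_prompt_tokens]
running = d_lst + p_lst  # decodes first, relative order preserved
print([r.rid for r in running])  # -> ['a', 'c', 'b']
```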
@@ -189,30 +191,26 @@ class SchedulerDynamicBatch(Scheduler):
while req_index < len(self.running) and token_budget > 0:
request = self.running[req_index]
num_new_tokens = (request.num_tokens_with_spec +
request.num_output_placeholders -
request.num_computed_tokens)
if (0 < self.scheduler_config.long_prefill_token_threshold <
num_new_tokens):
num_new_tokens = (
self.scheduler_config.long_prefill_token_threshold)
num_new_tokens = (
request.num_tokens_with_spec + request.num_output_placeholders - request.num_computed_tokens
)
if 0 < self.scheduler_config.long_prefill_token_threshold < num_new_tokens:
num_new_tokens = self.scheduler_config.long_prefill_token_threshold
num_new_tokens = min(num_new_tokens, token_budget)
# Make sure the input position does not exceed the max model len.
# This is necessary when using spec decoding.
num_new_tokens = min(
num_new_tokens,
self.max_model_len - 1 - request.num_computed_tokens)
num_new_tokens = min(num_new_tokens, self.max_model_len - 1 - request.num_computed_tokens)
# Schedule encoder inputs.
encoder_inputs_to_schedule = None
new_encoder_compute_budget = encoder_compute_budget
if request.has_encoder_inputs:
(encoder_inputs_to_schedule, num_new_tokens,
new_encoder_compute_budget
) = self._try_schedule_encoder_inputs(
request, request.num_computed_tokens, num_new_tokens,
encoder_compute_budget)
(encoder_inputs_to_schedule, num_new_tokens, new_encoder_compute_budget) = (
self._try_schedule_encoder_inputs(
request, request.num_computed_tokens, num_new_tokens, encoder_compute_budget
)
)
if num_new_tokens == 0:
# The request cannot be scheduled because one of the following
@@ -231,9 +229,8 @@ class SchedulerDynamicBatch(Scheduler):
while True:
new_blocks = self.kv_cache_manager.allocate_slots(
request,
num_new_tokens,
num_lookahead_tokens=self.num_lookahead_tokens)
new_blocks = self.kv_cache_manager.allocate_slots(
request, num_new_tokens, num_lookahead_tokens=self.num_lookahead_tokens
)
if new_blocks is None:
# The request cannot be scheduled.
# Preempt the lowest-priority request.
@@ -253,8 +250,7 @@ class SchedulerDynamicBatch(Scheduler):
preempted_req.status = RequestStatus.PREEMPTED
preempted_req.num_computed_tokens = 0
if self.log_stats:
preempted_req.record_event(
EngineCoreEventType.PREEMPTED, scheduled_timestamp)
preempted_req.record_event(EngineCoreEventType.PREEMPTED, scheduled_timestamp)
self.waiting.prepend_request(preempted_req)
preempted_reqs.append(preempted_req)
@@ -279,19 +275,15 @@ class SchedulerDynamicBatch(Scheduler):
# Speculative decode related.
if request.spec_token_ids:
num_scheduled_spec_tokens = (num_new_tokens +
request.num_computed_tokens -
request.num_tokens)
num_scheduled_spec_tokens = num_new_tokens + request.num_computed_tokens - request.num_tokens
if num_scheduled_spec_tokens > 0:
# Trim spec_token_ids list to num_scheduled_spec_tokens.
del request.spec_token_ids[num_scheduled_spec_tokens:]
scheduled_spec_decode_tokens[request.request_id] = (
request.spec_token_ids)
scheduled_spec_decode_tokens[request.request_id] = request.spec_token_ids
# Encoder-related.
if encoder_inputs_to_schedule:
scheduled_encoder_inputs[request.request_id] = (
encoder_inputs_to_schedule)
scheduled_encoder_inputs[request.request_id] = encoder_inputs_to_schedule
# Allocate the encoder cache.
for i in encoder_inputs_to_schedule:
self.encoder_cache_manager.allocate(request, i)
@@ -301,8 +293,10 @@ class SchedulerDynamicBatch(Scheduler):
scheduled_loras: set[int] = set()
if self.lora_config:
scheduled_loras = set(
req.lora_request.lora_int_id for req in scheduled_running_reqs
if req.lora_request and req.lora_request.lora_int_id > 0)
scheduled_loras = set(
req.lora_request.lora_int_id
for req in scheduled_running_reqs
if req.lora_request and req.lora_request.lora_int_id > 0
)
assert len(scheduled_loras) <= self.lora_config.max_loras
# Use a temporary RequestQueue to collect requests that need to be
@@ -323,9 +317,7 @@ class SchedulerDynamicBatch(Scheduler):
if is_ready:
request.status = RequestStatus.WAITING
else:
logger.debug(
"%s is still in WAITING_FOR_REMOTE_KVS state.",
request.request_id)
logger.debug("%s is still in WAITING_FOR_REMOTE_KVS state.", request.request_id)
self.waiting.pop_request()
skipped_waiting_requests.prepend_request(request)
continue
@@ -343,9 +335,14 @@ class SchedulerDynamicBatch(Scheduler):
# Check that adding the request still respects the max_loras
# constraint.
if (self.lora_config and request.lora_request and
(len(scheduled_loras) == self.lora_config.max_loras and
request.lora_request.lora_int_id not in scheduled_loras)):
if (
self.lora_config
and request.lora_request
and (
len(scheduled_loras) == self.lora_config.max_loras
and request.lora_request.lora_int_id not in scheduled_loras
)
):
# Scheduling would exceed max_loras, skip.
self.waiting.pop_request()
skipped_waiting_requests.prepend_request(request)
@@ -357,15 +354,15 @@ class SchedulerDynamicBatch(Scheduler):
# Get already-cached tokens.
if request.num_computed_tokens == 0:
# Get locally-cached tokens.
new_computed_blocks, num_new_local_computed_tokens = \
self.kv_cache_manager.get_computed_blocks(
request)
new_computed_blocks, num_new_local_computed_tokens = self.kv_cache_manager.get_computed_blocks(
request
)
# Get externally-cached tokens if using a KVConnector.
if self.connector is not None:
num_external_computed_tokens, load_kv_async = (
self.connector.get_num_new_matched_tokens(
request, num_new_local_computed_tokens))
num_external_computed_tokens, load_kv_async = self.connector.get_num_new_matched_tokens(
request, num_new_local_computed_tokens
)
if num_external_computed_tokens is None:
# The request cannot be scheduled because
@@ -376,13 +373,11 @@ class SchedulerDynamicBatch(Scheduler):
continue
# Total computed tokens (local + external).
num_computed_tokens = (num_new_local_computed_tokens +
num_external_computed_tokens)
num_computed_tokens = num_new_local_computed_tokens + num_external_computed_tokens
# KVTransfer: WAITING reqs have num_computed_tokens > 0
# after async KV recvs are completed.
else:
new_computed_blocks = (
self.kv_cache_manager.create_empty_block_list())
new_computed_blocks = self.kv_cache_manager.create_empty_block_list()
num_new_local_computed_tokens = 0
num_computed_tokens = request.num_computed_tokens
@@ -399,15 +394,12 @@ class SchedulerDynamicBatch(Scheduler):
# `request.num_prompt_tokens` to consider the resumed
# requests, which have output tokens.
num_new_tokens = request.num_tokens - num_computed_tokens
if (0 < self.scheduler_config.long_prefill_token_threshold
< num_new_tokens):
num_new_tokens = (
self.scheduler_config.long_prefill_token_threshold)
if 0 < self.scheduler_config.long_prefill_token_threshold < num_new_tokens:
num_new_tokens = self.scheduler_config.long_prefill_token_threshold
# chunked prefill has to be enabled explicitly to allow
# pooling requests to be chunked
if not self.scheduler_config.enable_chunked_prefill and \
num_new_tokens > token_budget:
if not self.scheduler_config.enable_chunked_prefill and num_new_tokens > token_budget:
self.waiting.pop_request()
skipped_waiting_requests.prepend_request(request)
continue
@@ -417,11 +409,11 @@ class SchedulerDynamicBatch(Scheduler):
# Schedule encoder inputs.
if request.has_encoder_inputs:
(encoder_inputs_to_schedule, num_new_tokens,
new_encoder_compute_budget,
_) = self._try_schedule_encoder_inputs(
request, num_computed_tokens, num_new_tokens,
encoder_compute_budget)
(encoder_inputs_to_schedule, num_new_tokens, new_encoder_compute_budget, _) = (
self._try_schedule_encoder_inputs(
request, num_computed_tokens, num_new_tokens, encoder_compute_budget
)
)
if num_new_tokens == 0:
# The request cannot be scheduled.
break
@@ -431,9 +423,7 @@ class SchedulerDynamicBatch(Scheduler):
# extra block gets allocated which
# creates a mismatch between the number
# of local and remote blocks.
effective_lookahead_tokens = (0 if request.num_computed_tokens
== 0 else
self.num_lookahead_tokens)
effective_lookahead_tokens = 0 if request.num_computed_tokens == 0 else self.num_lookahead_tokens
# Determine if we need to allocate cross-attention blocks.
if self.is_encoder_decoder and request.has_encoder_inputs:
@@ -441,8 +431,7 @@ class SchedulerDynamicBatch(Scheduler):
# always padded to the maximum length. If we support other
# encoder-decoder models, this will need to be updated if we
# want to only allocate what is needed.
num_encoder_tokens =\
self.scheduler_config.max_num_encoder_input_tokens
num_encoder_tokens = self.scheduler_config.max_num_encoder_input_tokens
else:
num_encoder_tokens = 0
@@ -484,20 +473,17 @@ class SchedulerDynamicBatch(Scheduler):
req_index += 1
self.running.append(request)
if self.log_stats:
request.record_event(EngineCoreEventType.SCHEDULED,
scheduled_timestamp)
request.record_event(EngineCoreEventType.SCHEDULED, scheduled_timestamp)
if request.status == RequestStatus.WAITING:
scheduled_new_reqs.append(request)
elif request.status == RequestStatus.PREEMPTED:
scheduled_resumed_reqs.append(request)
else:
raise RuntimeError(
f"Invalid request status: {request.status}")
raise RuntimeError(f"Invalid request status: {request.status}")
if self.lora_config and request.lora_request:
scheduled_loras.add(request.lora_request.lora_int_id)
req_to_new_blocks[request.request_id] = (
self.kv_cache_manager.get_blocks(request.request_id))
req_to_new_blocks[request.request_id] = self.kv_cache_manager.get_blocks(request.request_id)
num_scheduled_tokens[request.request_id] = num_new_tokens
token_budget -= num_new_tokens
request.status = RequestStatus.RUNNING
@@ -507,8 +493,7 @@ class SchedulerDynamicBatch(Scheduler):
request.num_cached_tokens = num_computed_tokens
# Encoder-related.
if encoder_inputs_to_schedule:
scheduled_encoder_inputs[request.request_id] = (
encoder_inputs_to_schedule)
scheduled_encoder_inputs[request.request_id] = encoder_inputs_to_schedule
# Allocate the encoder cache.
for i in encoder_inputs_to_schedule:
self.encoder_cache_manager.allocate(request, i)
@@ -526,22 +511,17 @@ class SchedulerDynamicBatch(Scheduler):
# Since some requests in the RUNNING queue may not be scheduled in
# this step, the total number of scheduled requests can be smaller than
# len(self.running).
assert (len(scheduled_new_reqs) + len(scheduled_resumed_reqs) +
len(scheduled_running_reqs) <= len(self.running))
assert len(scheduled_new_reqs) + len(scheduled_resumed_reqs) + len(scheduled_running_reqs) <= len(self.running)
# Get the longest common prefix among all requests in the running queue.
# This can be potentially used for cascade attention.
num_common_prefix_blocks = [0] * len(
self.kv_cache_config.kv_cache_groups)
num_common_prefix_blocks = [0] * len(self.kv_cache_config.kv_cache_groups)
if self.running:
any_request = self.running[0]
num_common_prefix_blocks = (
self.kv_cache_manager.get_num_common_prefix_blocks(
any_request.request_id))
num_common_prefix_blocks = self.kv_cache_manager.get_num_common_prefix_blocks(any_request.request_id)
# Construct the scheduler output.
new_reqs_data = [
NewRequestData.from_request(
req, req_to_new_blocks[req.request_id].get_block_ids())
NewRequestData.from_request(req, req_to_new_blocks[req.request_id].get_block_ids())
for req in scheduled_new_reqs
]
cached_reqs_data = self._make_cached_request_data(
@@ -564,8 +544,7 @@ class SchedulerDynamicBatch(Scheduler):
# It contains the request IDs that are finished in between
# the previous and the current steps.
finished_req_ids=self.finished_req_ids,
free_encoder_mm_hashes=self.encoder_cache_manager.
get_freed_mm_hashes(),
free_encoder_mm_hashes=self.encoder_cache_manager.get_freed_mm_hashes(),
)
# NOTE(Kuntai): this function is designed for multiple purposes:

View File

@@ -14,61 +14,50 @@
# limitations under the License.
# This file is a part of the vllm-ascend project.
#
from typing import List, Optional
import torch
import torch.distributed as dist
from vllm.distributed.device_communicators.base_device_communicator import \
DeviceCommunicatorBase
from vllm.distributed.device_communicators.base_device_communicator import DeviceCommunicatorBase
class NPUCommunicator(DeviceCommunicatorBase):
def __init__(self,
def __init__(
self,
cpu_group: dist.ProcessGroup,
device: Optional[torch.device] = None,
device_group: Optional[dist.ProcessGroup] = None,
unique_name: str = ""):
device: torch.device | None = None,
device_group: dist.ProcessGroup | None = None,
unique_name: str = "",
):
super().__init__(cpu_group, device, device_group, unique_name)
# TODO(hz): Refer to CudaCommunicator's implementation to integrate PyHcclCommunicator
# init device according to rank
self.device = torch.npu.current_device()
def all_to_all(self,
def all_to_all(
self,
input_: torch.Tensor,
scatter_dim: int = 0,
gather_dim: int = -1,
scatter_sizes: Optional[List[int]] = None,
gather_sizes: Optional[List[int]] = None) -> torch.Tensor:
scatter_sizes: list[int] | None = None,
gather_sizes: list[int] | None = None,
) -> torch.Tensor:
if scatter_dim < 0:
scatter_dim += input_.dim()
if gather_dim < 0:
gather_dim += input_.dim()
if scatter_sizes is not None and gather_sizes is not None:
input_list = [
t.contiguous()
for t in torch.split(input_, scatter_sizes, scatter_dim)
]
input_list = [t.contiguous() for t in torch.split(input_, scatter_sizes, scatter_dim)]
output_list = []
tensor_shape_base = input_list[self.rank].size()
for i in range(self.world_size):
tensor_shape = list(tensor_shape_base)
tensor_shape[gather_dim] = gather_sizes[i]
output_list.append(
torch.empty(tensor_shape,
dtype=input_.dtype,
device=input_.device))
output_list.append(torch.empty(tensor_shape, dtype=input_.dtype, device=input_.device))
else:
input_list = [
t.contiguous() for t in torch.tensor_split(
input_, self.world_size, scatter_dim)
]
output_list = [
torch.empty_like(input_list[i]) for i in range(self.world_size)
]
input_list = [t.contiguous() for t in torch.tensor_split(input_, self.world_size, scatter_dim)]
output_list = [torch.empty_like(input_list[i]) for i in range(self.world_size)]
dist.all_to_all(output_list, input_list, group=self.device_group)
output_tensor = torch.cat(output_list, dim=gather_dim).contiguous()
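
A minimal usage sketch for `all_to_all`, assuming a 2-rank NPU group already initialized by vLLM's distributed setup; the group handle and shapes below are illustrative:

```python
import torch

# Runs identically on each of the 2 ranks. `cpu_group` is whatever process
# group the distributed init hands to the communicator (assumption).
comm = NPUCommunicator(cpu_group)
x = torch.randn(6, 32).npu()
# Rows 0:3 go to rank 0 and rows 3:6 to rank 1; each rank then concatenates
# the two chunks it received back along dim 0.
out = comm.all_to_all(x, scatter_dim=0, gather_dim=0,
                      scatter_sizes=[3, 3], gather_sizes=[3, 3])
assert out.shape == (6, 32)
```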

View File

@@ -15,7 +15,6 @@
# limitations under the License.
#
from typing import Optional, Union
import torch
import torch.distributed as dist
@@ -24,18 +23,23 @@ from vllm.distributed.utils import StatelessProcessGroup
from vllm.logger import logger
from vllm_ascend.distributed.device_communicators.pyhccl_wrapper import (
HCCLLibrary, aclrtStream_t, buffer_type, hcclComm_t, hcclDataTypeEnum,
hcclRedOpTypeEnum, hcclUniqueId)
HCCLLibrary,
aclrtStream_t,
buffer_type,
hcclComm_t,
hcclDataTypeEnum,
hcclRedOpTypeEnum,
hcclUniqueId,
)
from vllm_ascend.utils import current_stream
class PyHcclCommunicator:
def __init__(
self,
group: Union[ProcessGroup, StatelessProcessGroup],
device: Union[int, str, torch.device],
library_path: Optional[str] = None,
group: ProcessGroup | StatelessProcessGroup,
device: int | str | torch.device,
library_path: str | None = None,
):
"""
Args:
@@ -52,7 +56,8 @@ class PyHcclCommunicator:
if not isinstance(group, StatelessProcessGroup):
assert dist.is_initialized()
assert dist.get_backend(group) != dist.Backend.HCCL, (
"PyHcclCommunicator should be attached to a non-HCCL group.")
"PyHcclCommunicator should be attached to a non-HCCL group."
)
# note: this rank is the rank in the group
self.rank = dist.get_rank(group)
self.world_size = dist.get_world_size(group)
@@ -113,8 +118,7 @@ class PyHcclCommunicator:
# `torch.npu.device` is a context manager that changes the
# current npu device to the specified one
with torch.npu.device(device):
self.comm: hcclComm_t = self.hccl.hcclCommInitRank(
self.world_size, self.unique_id, self.rank)
self.comm: hcclComm_t = self.hccl.hcclCommInitRank(self.world_size, self.unique_id, self.rank)
stream = current_stream()
# A small all_reduce for warmup.
@@ -123,43 +127,48 @@ class PyHcclCommunicator:
stream.synchronize()
del data
def all_reduce(self,
in_tensor: torch.Tensor,
op: ReduceOp = ReduceOp.SUM,
stream=None) -> torch.Tensor:
def all_reduce(self, in_tensor: torch.Tensor, op: ReduceOp = ReduceOp.SUM, stream=None) -> torch.Tensor:
if self.disabled:
return None
# hccl communicator created on a specific device
# will only work on tensors on the same device
# otherwise it will cause "illegal memory access"
assert in_tensor.device == self.device, (
f"this hccl communicator is created to work on {self.device}, "
f"but the input tensor is on {in_tensor.device}")
f"this hccl communicator is created to work on {self.device}, but the input tensor is on {in_tensor.device}"
)
out_tensor = torch.empty_like(in_tensor)
if stream is None:
stream = current_stream()
self.hccl.hcclAllReduce(buffer_type(in_tensor.data_ptr()),
self.hccl.hcclAllReduce(
buffer_type(in_tensor.data_ptr()),
buffer_type(out_tensor.data_ptr()),
in_tensor.numel(),
hcclDataTypeEnum.from_torch(in_tensor.dtype),
hcclRedOpTypeEnum.from_torch(op), self.comm,
aclrtStream_t(stream.npu_stream))
hcclRedOpTypeEnum.from_torch(op),
self.comm,
aclrtStream_t(stream.npu_stream),
)
return out_tensor
def broadcast(self, tensor: torch.Tensor, src: int, stream=None):
if self.disabled:
return
assert tensor.device == self.device, (
f"this hccl communicator is created to work on {self.device}, "
f"but the input tensor is on {tensor.device}")
f"this hccl communicator is created to work on {self.device}, but the input tensor is on {tensor.device}"
)
if stream is None:
stream = current_stream()
# HcclBroadcast takes a single buffer that is read on the root rank and
# written on every other rank, so no per-rank branching is needed.
buffer = buffer_type(tensor.data_ptr())
self.hccl.hcclBroadcast(buffer, tensor.numel(),
hcclDataTypeEnum.from_torch(tensor.dtype), src,
self.comm, aclrtStream_t(stream.npu_stream))
self.hccl.hcclBroadcast(
buffer,
tensor.numel(),
hcclDataTypeEnum.from_torch(tensor.dtype),
src,
self.comm,
aclrtStream_t(stream.npu_stream),
)
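
Putting the pieces together, a minimal usage sketch; the launch setup is an assumption (torchrun with one NPU per rank and a gloo default backend, since the assertion above forbids attaching to an HCCL group):

```python
import torch
import torch.distributed as dist

dist.init_process_group(backend="gloo")
rank = dist.get_rank()
comm = PyHcclCommunicator(group=dist.group.WORLD, device=f"npu:{rank}")

t = torch.ones(16, device=f"npu:{rank}")
out = comm.all_reduce(t)  # returns a new tensor; ReduceOp.SUM by default
comm.broadcast(t, src=0)  # in-place broadcast from rank 0
```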

View File

@@ -18,7 +18,7 @@
import ctypes
import platform
from dataclasses import dataclass
from typing import Any, Dict, List, Optional
from typing import Any
import torch
from torch.distributed import ReduceOp
@@ -107,33 +107,36 @@ class hcclRedOpTypeEnum:
class Function:
name: str
restype: Any
argtypes: List[Any]
argtypes: list[Any]
class HCCLLibrary:
exported_functions = [
# const char* HcclGetErrorString(HcclResult code);
Function("HcclGetErrorString", ctypes.c_char_p, [hcclResult_t]),
# HcclResult HcclGetRootInfo(HcclRootInfo *rootInfo);
Function("HcclGetRootInfo", hcclResult_t,
[ctypes.POINTER(hcclUniqueId)]),
Function("HcclGetRootInfo", hcclResult_t, [ctypes.POINTER(hcclUniqueId)]),
# HcclResult HcclCommInitRootInfo(
# uint32_t nRanks, const HcclRootInfo *rootInfo, uint32_t rank, HcclComm *comm);
# note that HcclComm is a pointer type, so the last argument is a pointer to a pointer
Function("HcclCommInitRootInfo", hcclResult_t, [
Function(
"HcclCommInitRootInfo",
hcclResult_t,
[
ctypes.c_int,
ctypes.POINTER(hcclUniqueId),
ctypes.c_int,
ctypes.POINTER(hcclComm_t),
]),
],
),
# HcclResult HcclAllReduce(
# void *sendBuf, void *recvBuf, uint64_t count,
# HcclDataType dataType, HcclReduceOp op, HcclComm comm,
# aclrtStream stream);
Function("HcclAllReduce", hcclResult_t, [
Function(
"HcclAllReduce",
hcclResult_t,
[
buffer_type,
buffer_type,
ctypes.c_size_t,
@@ -141,35 +144,37 @@ class HCCLLibrary:
hcclRedOp_t,
hcclComm_t,
aclrtStream_t,
]),
],
),
# HcclResult HcclBroadcast(
# void *buf, uint64_t count,
# HcclDataType dataType, uint32_t root,
# HcclComm comm, aclrtStream stream);
Function("HcclBroadcast", hcclResult_t, [
Function(
"HcclBroadcast",
hcclResult_t,
[
buffer_type,
ctypes.c_size_t,
hcclDataType_t,
ctypes.c_int,
hcclComm_t,
aclrtStream_t,
]),
],
),
# HcclResult HcclCommDestroy(HcclComm comm);
Function("HcclCommDestroy", hcclResult_t, [hcclComm_t]),
]
# class attribute to store the mapping from the path to the library
# to avoid loading the same library multiple times
path_to_library_cache: Dict[str, Any] = {}
path_to_library_cache: dict[str, Any] = {}
# class attribute to store the mapping from library path
# to the correspongding directory
path_to_dict_mapping: Dict[str, Dict[str, Any]] = {}
def __init__(self, so_file: Optional[str] = None):
path_to_dict_mapping: dict[str, dict[str, Any]] = {}
def __init__(self, so_file: str | None = None):
so_file = so_file or find_hccl_library()
try:
@@ -185,12 +190,14 @@ class HCCLLibrary:
"or it does not support the current platform %s. "
"If you already have the library, please set the "
"environment variable HCCL_SO_PATH"
" to point to the correct hccl library path.", so_file,
platform.platform())
" to point to the correct hccl library path.",
so_file,
platform.platform(),
)
raise e
if so_file not in HCCLLibrary.path_to_dict_mapping:
_funcs: Dict[str, Any] = {}
_funcs: dict[str, Any] = {}
for func in HCCLLibrary.exported_functions:
f = getattr(self.lib, func.name)
f.restype = func.restype
@@ -209,34 +216,37 @@ class HCCLLibrary:
def hcclGetUniqueId(self) -> hcclUniqueId:
unique_id = hcclUniqueId()
self.HCCL_CHECK(self._funcs["HcclGetRootInfo"](
ctypes.byref(unique_id)))
self.HCCL_CHECK(self._funcs["HcclGetRootInfo"](ctypes.byref(unique_id)))
return unique_id
def hcclCommInitRank(self, world_size: int, unique_id: hcclUniqueId,
rank: int) -> hcclComm_t:
def hcclCommInitRank(self, world_size: int, unique_id: hcclUniqueId, rank: int) -> hcclComm_t:
comm = hcclComm_t()
self.HCCL_CHECK(self._funcs["HcclCommInitRootInfo"](
world_size, ctypes.byref(unique_id), rank, ctypes.byref(comm)))
self.HCCL_CHECK(
self._funcs["HcclCommInitRootInfo"](world_size, ctypes.byref(unique_id), rank, ctypes.byref(comm))
)
return comm
def hcclAllReduce(self, sendbuff: buffer_type, recvbuff: buffer_type,
count: int, datatype: int, op: int, comm: hcclComm_t,
stream: aclrtStream_t) -> None:
def hcclAllReduce(
self,
sendbuff: buffer_type,
recvbuff: buffer_type,
count: int,
datatype: int,
op: int,
comm: hcclComm_t,
stream: aclrtStream_t,
) -> None:
# `datatype` actually should be `hcclDataType_t`
# and `op` should be `hcclRedOp_t`
# both are aliases of `ctypes.c_int`
# when we pass int to a function, it will be converted to `ctypes.c_int`
# by ctypes automatically
self.HCCL_CHECK(self._funcs["HcclAllReduce"](sendbuff, recvbuff, count,
datatype, op, comm,
stream))
self.HCCL_CHECK(self._funcs["HcclAllReduce"](sendbuff, recvbuff, count, datatype, op, comm, stream))
def hcclBroadcast(self, buf: buffer_type, count: int, datatype: int,
root: int, comm: hcclComm_t,
stream: aclrtStream_t) -> None:
self.HCCL_CHECK(self._funcs["HcclBroadcast"](buf, count, datatype,
root, comm, stream))
def hcclBroadcast(
self, buf: buffer_type, count: int, datatype: int, root: int, comm: hcclComm_t, stream: aclrtStream_t
) -> None:
self.HCCL_CHECK(self._funcs["HcclBroadcast"](buf, count, datatype, root, comm, stream))
def hcclCommDestroy(self, comm: hcclComm_t) -> None:
self.HCCL_CHECK(self._funcs["HcclCommDestroy"](comm))
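
The `exported_functions` table drives a generic ctypes binding loop (`f.restype = ...; f.argtypes = ...`). The same pattern in isolation, using libc's `strlen` as a stand-in since the HCCL shared object may not be present:

```python
import ctypes
import ctypes.util
from dataclasses import dataclass
from typing import Any

@dataclass
class Function:
    name: str
    restype: Any
    argtypes: list[Any]

exported = [Function("strlen", ctypes.c_size_t, [ctypes.c_char_p])]

# find_library("c") resolves libc on typical Linux/macOS systems (assumption).
lib = ctypes.CDLL(ctypes.util.find_library("c"))
funcs: dict[str, Any] = {}
for func in exported:
    f = getattr(lib, func.name)
    f.restype = func.restype    # declare the return type up front...
    f.argtypes = func.argtypes  # ...and the argument types, as above
    funcs[func.name] = f

print(funcs["strlen"](b"hccl"))  # -> 4
```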