### What this PR does / why we need it?
**Scope of Changes**:
| File Path |
| :--- |
| `vllm_ascend/attention/mla_v1.py` |
| `vllm_ascend/attention/sfa_v1.py` |
| `vllm_ascend/core/recompute_scheduler.py` |
| `vllm_ascend/core/scheduler_dynamic_batch.py` |
| `vllm_ascend/distributed/device_communicators/npu_communicator.py` |
| `vllm_ascend/distributed/device_communicators/pyhccl.py` |
| `vllm_ascend/distributed/device_communicators/pyhccl_wrapper.py` |
### Does this PR introduce _any_ user-facing change?
### How was this patch tested?
- vLLM version: v0.13.0
- vLLM main:
2c24bc6996
---------
Signed-off-by: MrZ20 <2609716663@qq.com>
Co-authored-by: Soren <user@SorendeMac-mini.local>
This commit is contained in:
@@ -51,11 +51,6 @@ line-length = 120
|
||||
# Folder to be modified
|
||||
exclude = [
|
||||
"tests/**",
|
||||
# (3)
|
||||
"vllm_ascend/attention/*.py",
|
||||
"vllm_ascend/core/*.py",
|
||||
"vllm_ascend/distributed/device_communicators/**",
|
||||
"vllm_ascend/distributed/utils.py",
|
||||
# (5)
|
||||
"vllm_ascend/distributed/kv_transfer/kv_pool/**",
|
||||
"vllm_ascend/distributed/kv_transfer/utils/**",
|
||||
|
||||
@@ -394,7 +394,7 @@ class AscendMlaCPImpl(AscendMLAImpl):
|
||||
prefill_q_pe = self.rope_single(prefill_q_pe, cos, sin)
|
||||
prefill_kv_no_split = kv_no_split[:num_actual_tokens]
|
||||
kv_c, k_pe = prefill_kv_no_split.split([self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
|
||||
kv_c_normed = self.kv_a_layernorm(kv_c.contiguous())
|
||||
kv_c_normed = self.kv_a_layernorm(kv_c.contiguous()) # type: ignore[misc]
|
||||
assert len(kv_cache) > 1, "the number of kv cache should be greater than 1, namely (nope_cache and rope_cache)"
|
||||
kv_c_normed = kv_c_normed.view([num_actual_tokens, self.num_kv_heads, -1])
|
||||
k_pe = k_pe.unsqueeze(1)
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -1,19 +1,17 @@
|
||||
from dataclasses import dataclass
|
||||
from typing import TYPE_CHECKING, Optional, Tuple, Type, TypeVar
|
||||
from typing import TYPE_CHECKING, TypeVar
|
||||
|
||||
import torch
|
||||
import torch_npu
|
||||
import vllm.envs as envs_vllm
|
||||
from torch import nn
|
||||
from vllm.config import CUDAGraphMode, VllmConfig, get_current_vllm_config
|
||||
from vllm.config import VllmConfig, get_current_vllm_config
|
||||
from vllm.distributed import get_tensor_model_parallel_world_size, get_tp_group
|
||||
from vllm.forward_context import get_forward_context
|
||||
from vllm.logger import logger
|
||||
from vllm.model_executor.layers.linear import UnquantizedLinearMethod
|
||||
from vllm.triton_utils import HAS_TRITON
|
||||
from vllm.v1.attention.backend import ( # type: ignore
|
||||
AttentionBackend, AttentionCGSupport, MLAAttentionImpl)
|
||||
from vllm.v1.attention.backend import AttentionBackend, AttentionCGSupport, MLAAttentionImpl # type: ignore
|
||||
from vllm.v1.attention.backends.mla.common import MLACommonMetadataBuilder
|
||||
from vllm.v1.kv_cache_interface import AttentionSpec
|
||||
|
||||
@@ -22,22 +20,33 @@ from vllm_ascend.ascend_config import get_ascend_config
|
||||
from vllm_ascend.attention.attention_mask import AttentionMaskBuilder
|
||||
from vllm_ascend.attention.attention_v1 import AscendAttentionState
|
||||
from vllm_ascend.attention.mla_v1 import MAX_O_PROJ_PREFETCH_SIZE, MLAPO_MAX_SUPPORTED_TOKENS
|
||||
from vllm_ascend.attention.utils import (AscendCommonAttentionMetadata,
|
||||
from vllm_ascend.attention.utils import (
|
||||
AscendCommonAttentionMetadata,
|
||||
ascend_chunked_prefill_workspace_size,
|
||||
maybe_save_kv_layer_to_connector,
|
||||
trans_rope_weight, transdata,
|
||||
wait_for_kv_layer_from_connector)
|
||||
trans_rope_weight,
|
||||
transdata,
|
||||
wait_for_kv_layer_from_connector,
|
||||
)
|
||||
from vllm_ascend.distributed.utils import all_gather_async
|
||||
from vllm_ascend.ops.layer_shard_linear import (
|
||||
is_hidden_layer, post_process_after_loading_for_shard_weight_series,
|
||||
is_hidden_layer,
|
||||
post_process_after_loading_for_shard_weight_series,
|
||||
reach_layer_for_shard_weight_series,
|
||||
register_all_layers_to_shard_weight_series)
|
||||
register_all_layers_to_shard_weight_series,
|
||||
)
|
||||
from vllm_ascend.ops.rotary_embedding import get_cos_and_sin_mla
|
||||
from vllm_ascend.ops.triton.rope import rope_forward_triton
|
||||
from vllm_ascend.ops.weight_prefetch import maybe_npu_prefetch
|
||||
from vllm_ascend.quantization.methods import AscendW8A8LinearMethod
|
||||
from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_ND, _round_up, dispose_layer,
|
||||
enable_dsa_cp, enable_dsa_cp_with_layer_shard, maybe_trans_nz)
|
||||
from vllm_ascend.utils import (
|
||||
ACL_FORMAT_FRACTAL_ND,
|
||||
_round_up,
|
||||
dispose_layer,
|
||||
enable_dsa_cp,
|
||||
enable_dsa_cp_with_layer_shard,
|
||||
maybe_trans_nz,
|
||||
)
|
||||
from vllm_ascend.worker.npu_input_batch import NPUInputBatch
|
||||
|
||||
if TYPE_CHECKING:
|
||||
@@ -48,7 +57,6 @@ BMM_TRANS_MAX_SUPPORTED_TOKENS = 1024
|
||||
|
||||
|
||||
class AscendSFABackend(AttentionBackend):
|
||||
|
||||
accept_output_buffer: bool = True
|
||||
|
||||
@staticmethod
|
||||
@@ -63,12 +71,11 @@ class AscendSFABackend(AttentionBackend):
|
||||
return AscendSFAMetadataBuilder
|
||||
|
||||
@staticmethod
|
||||
def get_kv_cache_shape(num_blocks: int, block_size: int, num_kv_heads: int,
|
||||
head_size: int) -> tuple[int, ...]:
|
||||
def get_kv_cache_shape(num_blocks: int, block_size: int, num_kv_heads: int, head_size: int) -> tuple[int, ...]:
|
||||
return (num_blocks, block_size, num_kv_heads, head_size)
|
||||
|
||||
@staticmethod
|
||||
def get_impl_cls() -> Type["AscendSFAImpl"]:
|
||||
def get_impl_cls() -> type["AscendSFAImpl"]:
|
||||
return AscendSFAImpl
|
||||
|
||||
|
||||
@@ -91,6 +98,7 @@ class AscendSFAMetadata:
|
||||
NOTE: Please read the comment at the top of the file before trying to
|
||||
understand this class
|
||||
"""
|
||||
|
||||
# NOTE(sang): Definition of context_len, query_len, and seq_len.
|
||||
# |---------- N-1 iteration --------|
|
||||
# |---------------- N iteration ---------------------|
|
||||
@@ -109,11 +117,11 @@ class AscendSFAMetadata:
|
||||
# For logging.
|
||||
num_input_tokens: int = 0 # Number of tokens including padding.
|
||||
# The dimension of the attention heads
|
||||
head_dim: Optional[int] = None
|
||||
head_dim: int | None = None
|
||||
attn_mask: torch.Tensor = None
|
||||
# chunked prefill by default if no attn_states passed
|
||||
attn_state: AscendAttentionState = AscendAttentionState.ChunkedPrefill
|
||||
dsa_cp_context: Optional[DSACPContext] = None
|
||||
dsa_cp_context: DSACPContext | None = None
|
||||
reshape_cache_event: torch.npu.Event = None
|
||||
|
||||
|
||||
@@ -136,37 +144,38 @@ class AscendSFAMetadataBuilder(MLACommonMetadataBuilder[AscendSFAMetadata]):
|
||||
supports_dcp_with_varlen: bool = False,
|
||||
):
|
||||
super().__init__(
|
||||
kv_cache_spec, layer_names, vllm_config, device,
|
||||
kv_cache_spec,
|
||||
layer_names,
|
||||
vllm_config,
|
||||
device,
|
||||
metadata_cls if metadata_cls is not None else AscendSFAMetadata,
|
||||
supports_dcp_with_varlen)
|
||||
supports_dcp_with_varlen,
|
||||
)
|
||||
|
||||
self.block_size = vllm_config.cache_config.block_size
|
||||
self.max_blocks = (vllm_config.model_config.max_model_len +
|
||||
self.block_size - 1) // self.block_size
|
||||
self.max_blocks = (vllm_config.model_config.max_model_len + self.block_size - 1) // self.block_size
|
||||
|
||||
self.speculative_config = vllm_config.speculative_config
|
||||
self.decode_threshold = 1
|
||||
if self.speculative_config:
|
||||
spec_token_num = self.speculative_config.num_speculative_tokens
|
||||
self.decode_threshold += spec_token_num
|
||||
assert self.decode_threshold <= 16, f"decode_threshold exceeded \
|
||||
assert self.decode_threshold <= 16, (
|
||||
f"decode_threshold exceeded \
|
||||
npu_fused_infer_attention_score TND layout's limit of 16, \
|
||||
got {self.decode_threshold}"
|
||||
)
|
||||
|
||||
self.attn_mask_builder = AttentionMaskBuilder(self.device)
|
||||
self.rope_dim = self.model_config.hf_text_config.qk_rope_head_dim
|
||||
self.enable_dsa_cp = enable_dsa_cp()
|
||||
|
||||
max_num_reqs = vllm_config.scheduler_config.max_num_seqs
|
||||
self.actual_seq_lengths_query = torch.zeros(max_num_reqs + 1,
|
||||
dtype=torch.int32,
|
||||
device=device)
|
||||
self.actual_seq_lengths_key = torch.empty_like(
|
||||
self.actual_seq_lengths_query)
|
||||
self.actual_seq_lengths_query = torch.zeros(max_num_reqs + 1, dtype=torch.int32, device=device)
|
||||
self.actual_seq_lengths_key = torch.empty_like(self.actual_seq_lengths_query)
|
||||
|
||||
@staticmethod
|
||||
def determine_chunked_prefill_workspace_size(
|
||||
vllm_config: VllmConfig) -> int:
|
||||
def determine_chunked_prefill_workspace_size(vllm_config: VllmConfig) -> int:
|
||||
return ascend_chunked_prefill_workspace_size(vllm_config)
|
||||
|
||||
@classmethod
|
||||
@@ -179,8 +188,7 @@ class AscendSFAMetadataBuilder(MLACommonMetadataBuilder[AscendSFAMetadata]):
|
||||
# @override omitted only because of mypy limitation due to type variable.
|
||||
return AttentionCGSupport.UNIFORM_BATCH
|
||||
|
||||
def reorder_batch(self, input_batch: "NPUInputBatch",
|
||||
scheduler_output: "SchedulerOutput") -> bool:
|
||||
def reorder_batch(self, input_batch: "NPUInputBatch", scheduler_output: "SchedulerOutput") -> bool:
|
||||
# No need to reorder for Ascend SFA
|
||||
return False
|
||||
|
||||
@@ -196,9 +204,7 @@ class AscendSFAMetadataBuilder(MLACommonMetadataBuilder[AscendSFAMetadata]):
|
||||
|
||||
block_table = common_attn_metadata.block_table_tensor[:num_reqs]
|
||||
slot_mapping = common_attn_metadata.slot_mapping[:num_input_tokens]
|
||||
input_positions = common_attn_metadata.positions[:
|
||||
num_input_tokens].long(
|
||||
)
|
||||
input_positions = common_attn_metadata.positions[:num_input_tokens].long()
|
||||
|
||||
cum_query_lens = common_attn_metadata.query_start_loc[1 : num_reqs + 1]
|
||||
seq_lens = common_attn_metadata.seq_lens[:num_reqs]
|
||||
@@ -216,8 +222,7 @@ class AscendSFAMetadataBuilder(MLACommonMetadataBuilder[AscendSFAMetadata]):
|
||||
local_end = min(local_end_with_pad, num_actual_tokens)
|
||||
|
||||
pad_size = num_tokens_pad - cos.shape[0]
|
||||
assert cos.shape == sin.shape, \
|
||||
f"cos.shape must be equal to sin.shape, got {cos.shape} and {sin.shape}"
|
||||
assert cos.shape == sin.shape, f"cos.shape must be equal to sin.shape, got {cos.shape} and {sin.shape}"
|
||||
|
||||
if pad_size > 0:
|
||||
cos = nn.functional.pad(cos, (0, 0, 0, 0, 0, 0, 0, pad_size))
|
||||
@@ -225,9 +230,7 @@ class AscendSFAMetadataBuilder(MLACommonMetadataBuilder[AscendSFAMetadata]):
|
||||
|
||||
pad_size_slot = num_tokens_pad - slot_mapping.shape[0]
|
||||
if pad_size_slot > 0:
|
||||
slot_mapping = nn.functional.pad(slot_mapping,
|
||||
(0, pad_size_slot),
|
||||
value=-1)
|
||||
slot_mapping = nn.functional.pad(slot_mapping, (0, pad_size_slot), value=-1)
|
||||
else:
|
||||
slot_mapping = slot_mapping[:num_tokens_pad]
|
||||
slot_mapping_cp = slot_mapping[local_start:local_end_with_pad]
|
||||
@@ -235,15 +238,18 @@ class AscendSFAMetadataBuilder(MLACommonMetadataBuilder[AscendSFAMetadata]):
|
||||
cos = cos[local_start:local_end_with_pad]
|
||||
sin = sin[local_start:local_end_with_pad]
|
||||
|
||||
assert cos.shape[0] == num_tokens_per_device, \
|
||||
assert cos.shape[0] == num_tokens_per_device, (
|
||||
f"cos.shape[0] must be equal to num_tokens_per_device, \
|
||||
got {cos.shape[0]} and {num_tokens_per_device}"
|
||||
assert slot_mapping_cp.shape[0] == num_tokens_per_device, \
|
||||
)
|
||||
assert slot_mapping_cp.shape[0] == num_tokens_per_device, (
|
||||
f"slot_mapping_cp.shape[0] must be equal to num_tokens_per_device, \
|
||||
got {slot_mapping_cp.shape[0]} and {num_tokens_per_device}"
|
||||
assert slot_mapping.shape[0] == num_tokens_pad, \
|
||||
)
|
||||
assert slot_mapping.shape[0] == num_tokens_pad, (
|
||||
f"slot_mapping.shape[0] must be equal to num_tokens_pad, \
|
||||
got {slot_mapping.shape[0]} and {num_tokens_pad}"
|
||||
)
|
||||
|
||||
actual_seq_lengths_query = self.actual_seq_lengths_query
|
||||
actual_seq_lengths_key = self.actual_seq_lengths_key
|
||||
@@ -291,31 +297,26 @@ class AscendSFAMetadataBuilder(MLACommonMetadataBuilder[AscendSFAMetadata]):
|
||||
seq_lens=seq_lens,
|
||||
slot_mapping=slot_mapping,
|
||||
head_dim=self.model_config.get_head_size(),
|
||||
attn_mask=self.attn_mask_builder.get_attention_mask(
|
||||
self.model_config),
|
||||
attn_mask=self.attn_mask_builder.get_attention_mask(self.model_config),
|
||||
attn_state=common_attn_metadata.attn_state,
|
||||
block_tables=block_table,
|
||||
sin=sin[:num_input_tokens],
|
||||
cos=cos[:num_input_tokens],
|
||||
dsa_cp_context=dsa_cp_context)
|
||||
dsa_cp_context=dsa_cp_context,
|
||||
)
|
||||
|
||||
def build_for_graph_capture(
|
||||
self,
|
||||
common_attn_metadata: AscendCommonAttentionMetadata,
|
||||
attn_state: AscendAttentionState = AscendAttentionState.DecodeOnly,
|
||||
):
|
||||
if attn_state in {
|
||||
AscendAttentionState.DecodeOnly,
|
||||
AscendAttentionState.SpecDecoding
|
||||
}:
|
||||
if attn_state in {AscendAttentionState.DecodeOnly, AscendAttentionState.SpecDecoding}:
|
||||
attn_metadata = self.build(
|
||||
common_prefix_len=0,
|
||||
common_attn_metadata=common_attn_metadata,
|
||||
)
|
||||
else:
|
||||
raise NotImplementedError(
|
||||
"Currently we only support building dummy metadata for DecodeOnly state"
|
||||
)
|
||||
raise NotImplementedError("Currently we only support building dummy metadata for DecodeOnly state")
|
||||
|
||||
attn_metadata.attn_state = attn_state
|
||||
return attn_metadata
|
||||
@@ -326,8 +327,9 @@ class AscendSFAImpl(MLAAttentionImpl):
|
||||
NOTE: Please read the comment at the top of the file before trying to
|
||||
understand this class
|
||||
"""
|
||||
|
||||
# Supports forward using the all-gather o_proj weight for decode requests when Sharded CP is enabled.
|
||||
o_proj_full_pool: Optional[torch.Tensor] = None
|
||||
o_proj_full_pool: torch.Tensor | None = None
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@@ -335,12 +337,12 @@ class AscendSFAImpl(MLAAttentionImpl):
|
||||
head_size: int,
|
||||
scale: float,
|
||||
num_kv_heads: int,
|
||||
alibi_slopes: Optional[list[float]],
|
||||
sliding_window: Optional[int],
|
||||
alibi_slopes: list[float] | None,
|
||||
sliding_window: int | None,
|
||||
kv_cache_dtype: str,
|
||||
logits_soft_cap: Optional[float],
|
||||
logits_soft_cap: float | None,
|
||||
attn_type: str,
|
||||
kv_sharing_target_layer_name: Optional[str],
|
||||
kv_sharing_target_layer_name: str | None,
|
||||
**kwargs,
|
||||
) -> None:
|
||||
self.num_heads = num_heads
|
||||
@@ -350,26 +352,25 @@ class AscendSFAImpl(MLAAttentionImpl):
|
||||
self.kv_cache_dtype = kv_cache_dtype
|
||||
|
||||
# MLA Args
|
||||
self.q_lora_rank = kwargs['q_lora_rank']
|
||||
self.kv_lora_rank = kwargs['kv_lora_rank']
|
||||
self.qk_nope_head_dim = kwargs['qk_nope_head_dim']
|
||||
self.qk_rope_head_dim = kwargs['qk_rope_head_dim']
|
||||
self.qk_head_dim = kwargs['qk_head_dim']
|
||||
self.v_head_dim = kwargs['v_head_dim']
|
||||
self.rotary_emb = kwargs['rotary_emb']
|
||||
self.q_proj = kwargs['q_proj'] if self.q_lora_rank is None else kwargs[
|
||||
'q_b_proj']
|
||||
self.fused_qkv_a_proj = kwargs.get('fused_qkv_a_proj', None)
|
||||
self.kv_b_proj = kwargs['kv_b_proj']
|
||||
self.o_proj = kwargs['o_proj']
|
||||
self.indexer = kwargs['indexer']
|
||||
self.kv_a_proj_with_mqa = kwargs.get('kv_a_proj_with_mqa', None)
|
||||
self.kv_a_layernorm = kwargs.get('kv_a_layernorm', None)
|
||||
self.q_a_layernorm = kwargs.get('q_a_layernorm', None)
|
||||
self.q_lora_rank = kwargs["q_lora_rank"]
|
||||
self.kv_lora_rank = kwargs["kv_lora_rank"]
|
||||
self.qk_nope_head_dim = kwargs["qk_nope_head_dim"]
|
||||
self.qk_rope_head_dim = kwargs["qk_rope_head_dim"]
|
||||
self.qk_head_dim = kwargs["qk_head_dim"]
|
||||
self.v_head_dim = kwargs["v_head_dim"]
|
||||
self.rotary_emb = kwargs["rotary_emb"]
|
||||
self.q_proj = kwargs["q_proj"] if self.q_lora_rank is None else kwargs["q_b_proj"]
|
||||
self.fused_qkv_a_proj = kwargs.get("fused_qkv_a_proj")
|
||||
self.kv_b_proj = kwargs["kv_b_proj"]
|
||||
self.o_proj = kwargs["o_proj"]
|
||||
self.indexer = kwargs["indexer"]
|
||||
self.kv_a_proj_with_mqa = kwargs.get("kv_a_proj_with_mqa")
|
||||
self.kv_a_layernorm = kwargs.get("kv_a_layernorm")
|
||||
self.q_a_layernorm = kwargs.get("q_a_layernorm")
|
||||
self.num_queries_per_kv = self.num_heads // self.num_kv_heads
|
||||
self.tp_size = get_tensor_model_parallel_world_size()
|
||||
self.tp_rank = get_tp_group().rank_in_group
|
||||
self.q_b_proj = kwargs['q_b_proj']
|
||||
self.q_b_proj = kwargs["q_b_proj"]
|
||||
|
||||
ascend_config = get_ascend_config()
|
||||
self.enable_shared_expert_dp = ascend_config.enable_shared_expert_dp
|
||||
@@ -383,7 +384,9 @@ class AscendSFAImpl(MLAAttentionImpl):
|
||||
|
||||
self.local_num_heads = self.num_heads
|
||||
self.vllm_config = get_current_vllm_config()
|
||||
self.is_kv_producer = self.vllm_config.kv_transfer_config is not None and self.vllm_config.kv_transfer_config.is_kv_producer
|
||||
self.is_kv_producer = (
|
||||
self.vllm_config.kv_transfer_config is not None and self.vllm_config.kv_transfer_config.is_kv_producer
|
||||
)
|
||||
|
||||
# indexer param
|
||||
self.n_head: int = self.indexer.n_head # 64
|
||||
@@ -400,38 +403,38 @@ class AscendSFAImpl(MLAAttentionImpl):
|
||||
self.local_num_heads = self.num_heads * self.tp_size
|
||||
if self.enable_dsa_cp_prefill_only:
|
||||
self.layer_sharding_kwargs = []
|
||||
for layer_name in (get_ascend_config().layer_sharding or []):
|
||||
for layer_name in get_ascend_config().layer_sharding or []:
|
||||
if layer_name in kwargs:
|
||||
self.layer_sharding_kwargs.append(kwargs[layer_name])
|
||||
else:
|
||||
logger.warning_once(
|
||||
f"[SFAImpl init] Layer '{layer_name}' not found in kwargs for layer sharding, skipping sharding configuration"
|
||||
f"[SFAImpl init] Layer '{layer_name}' not found in kwargs for layer sharding, "
|
||||
"skipping sharding configuration"
|
||||
)
|
||||
register_all_layers_to_shard_weight_series(
|
||||
self.layer_sharding_kwargs)
|
||||
register_all_layers_to_shard_weight_series(self.layer_sharding_kwargs)
|
||||
|
||||
def process_weights_after_loading(self, act_dtype: torch.dtype):
|
||||
# NOTE: We currently do not support quant kv_b_proj.
|
||||
assert isinstance(self.kv_b_proj.quant_method, UnquantizedLinearMethod)
|
||||
# NOTE: Weight will be reshaped next, we need to revert and transpose it.
|
||||
kv_b_proj_weight = torch_npu.npu_format_cast(
|
||||
self.kv_b_proj.weight.data, ACL_FORMAT_FRACTAL_ND).T
|
||||
kv_b_proj_weight = torch_npu.npu_format_cast(self.kv_b_proj.weight.data, ACL_FORMAT_FRACTAL_ND).T
|
||||
assert kv_b_proj_weight.shape == (
|
||||
self.kv_lora_rank, self.local_num_heads *
|
||||
(self.qk_nope_head_dim + self.v_head_dim)), (
|
||||
self.kv_lora_rank,
|
||||
self.local_num_heads * (self.qk_nope_head_dim + self.v_head_dim),
|
||||
), (
|
||||
f"{kv_b_proj_weight.shape=}, "
|
||||
f"{self.kv_lora_rank=}, "
|
||||
f"{self.local_num_heads=}, "
|
||||
f"{self.qk_nope_head_dim=}, "
|
||||
f"{self.v_head_dim=}")
|
||||
f"{self.v_head_dim=}"
|
||||
)
|
||||
kv_b_proj_weight = kv_b_proj_weight.view(
|
||||
self.kv_lora_rank,
|
||||
self.local_num_heads,
|
||||
self.qk_nope_head_dim + self.v_head_dim,
|
||||
)
|
||||
|
||||
W_UK, W_UV = kv_b_proj_weight.split(
|
||||
[self.qk_nope_head_dim, self.v_head_dim], dim=-1)
|
||||
W_UK, W_UV = kv_b_proj_weight.split([self.qk_nope_head_dim, self.v_head_dim], dim=-1)
|
||||
|
||||
# Convert from (L, N, V) to (N, L, V)
|
||||
self.W_UV = W_UV.transpose(0, 1).contiguous()
|
||||
@@ -445,10 +448,9 @@ class AscendSFAImpl(MLAAttentionImpl):
|
||||
dispose_layer(self.kv_b_proj)
|
||||
if self.enable_dsa_cp:
|
||||
if self.enable_dsa_cp_prefill_only:
|
||||
for layer in (self.layer_sharding_kwargs or []):
|
||||
for layer in self.layer_sharding_kwargs or []:
|
||||
if is_hidden_layer(layer):
|
||||
post_process_after_loading_for_shard_weight_series(
|
||||
layer)
|
||||
post_process_after_loading_for_shard_weight_series(layer)
|
||||
else:
|
||||
self._init_o_proj_tp_full_params()
|
||||
|
||||
@@ -459,15 +461,14 @@ class AscendSFAImpl(MLAAttentionImpl):
|
||||
None,
|
||||
)
|
||||
reasons = []
|
||||
if self.fused_qkv_a_proj is None or not isinstance(
|
||||
quant_method, AscendW8A8LinearMethod):
|
||||
if self.fused_qkv_a_proj is None or not isinstance(quant_method, AscendW8A8LinearMethod):
|
||||
reasons.append(
|
||||
"Currently mlapo only supports W8A8 quantization in SFA scenario."
|
||||
"Some layers in your model are not quantized with W8A8,"
|
||||
"thus mlapo is disabled for these layers.")
|
||||
"thus mlapo is disabled for these layers."
|
||||
)
|
||||
if self.enable_dsa_cp:
|
||||
reasons.append("Currently mlapo does not support SFA with CP,"
|
||||
"thus mlapo is disabled for these layers.")
|
||||
reasons.append("Currently mlapo does not support SFA with CP,thus mlapo is disabled for these layers.")
|
||||
if reasons:
|
||||
self.enable_mlapo = False
|
||||
for msg in reasons:
|
||||
@@ -480,32 +481,31 @@ class AscendSFAImpl(MLAAttentionImpl):
|
||||
|
||||
def _v_up_proj(self, x):
|
||||
num_input_tokens, _, _ = x.shape
|
||||
if x.dtype in [torch.float16, torch.bfloat16] \
|
||||
and hasattr(torch.ops._C_ascend, "batch_matmul_transpose") \
|
||||
and num_input_tokens <= BMM_TRANS_MAX_SUPPORTED_TOKENS:
|
||||
if (
|
||||
x.dtype in [torch.float16, torch.bfloat16]
|
||||
and hasattr(torch.ops._C_ascend, "batch_matmul_transpose")
|
||||
and num_input_tokens <= BMM_TRANS_MAX_SUPPORTED_TOKENS
|
||||
):
|
||||
x = x.view(-1, self.local_num_heads, self.kv_lora_rank)
|
||||
res = torch.empty((num_input_tokens, self.local_num_heads, self.v_head_dim),
|
||||
dtype=x.dtype,
|
||||
device=x.device)
|
||||
res = torch.empty((num_input_tokens, self.local_num_heads, self.v_head_dim), dtype=x.dtype, device=x.device)
|
||||
torch.ops._C_ascend.batch_matmul_transpose(x, self.W_UV, res)
|
||||
x = res.reshape(-1, self.local_num_heads * self.v_head_dim)
|
||||
else:
|
||||
# Convert from (B, N, L) to (N, B, L)
|
||||
x = x.view(-1, self.local_num_heads,
|
||||
self.kv_lora_rank).transpose(0, 1)
|
||||
x = x.view(-1, self.local_num_heads, self.kv_lora_rank).transpose(0, 1)
|
||||
# # Multiply (N, B, L) x (N, L, V) -> (N, B, V)
|
||||
x = torch.bmm(x, self.W_UV)
|
||||
# # Convert from (N, B, V) to (B, N * V)
|
||||
x = x.transpose(0,
|
||||
1).reshape(-1,
|
||||
self.local_num_heads * self.v_head_dim)
|
||||
x = x.transpose(0, 1).reshape(-1, self.local_num_heads * self.v_head_dim)
|
||||
return x
|
||||
|
||||
# Return `ql_nope`, `q_pe`
|
||||
def _q_proj_and_k_up_proj(self, x):
|
||||
q_nope, q_pe = self.q_proj(x)[0]\
|
||||
.view(-1, self.local_num_heads, self.qk_head_dim)\
|
||||
q_nope, q_pe = (
|
||||
self.q_proj(x)[0]
|
||||
.view(-1, self.local_num_heads, self.qk_head_dim)
|
||||
.split([self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)
|
||||
)
|
||||
|
||||
# Convert from (B, N, P) to (N, B, P)
|
||||
q_nope = q_nope.transpose(0, 1)
|
||||
@@ -519,27 +519,26 @@ class AscendSFAImpl(MLAAttentionImpl):
|
||||
kv_no_split: torch.Tensor,
|
||||
cos: torch.Tensor,
|
||||
sin: torch.Tensor,
|
||||
kv_cache: Tuple,
|
||||
kv_cache: tuple,
|
||||
slots: torch.Tensor,
|
||||
):
|
||||
B = kv_no_split.shape[0]
|
||||
N = self.num_kv_heads
|
||||
S = 1
|
||||
# npu_kv_rmsnorm_rope_cache needs [B, N, S, D]
|
||||
kv_no_split = kv_no_split.view(
|
||||
B, N, S, self.kv_lora_rank + self.qk_rope_head_dim)
|
||||
kv_no_split = kv_no_split.view(B, N, S, self.kv_lora_rank + self.qk_rope_head_dim)
|
||||
cache_mode = "PA"
|
||||
|
||||
if self.enable_dsa_cp:
|
||||
_, _, k_pe, k_nope = torch_npu.npu_kv_rmsnorm_rope_cache(
|
||||
kv_no_split,
|
||||
self.kv_a_layernorm.weight,
|
||||
self.kv_a_layernorm.weight, # type: ignore[union-attr]
|
||||
cos,
|
||||
sin,
|
||||
slots.to(torch.int64),
|
||||
kv_cache[1],
|
||||
kv_cache[0],
|
||||
epsilon=self.kv_a_layernorm.variance_epsilon,
|
||||
epsilon=self.kv_a_layernorm.variance_epsilon, # type: ignore[union-attr]
|
||||
cache_mode=cache_mode,
|
||||
is_output_kv=True,
|
||||
)
|
||||
@@ -547,13 +546,13 @@ class AscendSFAImpl(MLAAttentionImpl):
|
||||
else:
|
||||
torch_npu.npu_kv_rmsnorm_rope_cache(
|
||||
kv_no_split,
|
||||
self.kv_a_layernorm.weight,
|
||||
self.kv_a_layernorm.weight, # type: ignore[union-attr]
|
||||
cos,
|
||||
sin,
|
||||
slots.to(torch.int64),
|
||||
kv_cache[1],
|
||||
kv_cache[0],
|
||||
epsilon=self.kv_a_layernorm.variance_epsilon,
|
||||
epsilon=self.kv_a_layernorm.variance_epsilon, # type: ignore[union-attr]
|
||||
cache_mode=cache_mode,
|
||||
)
|
||||
return None, None
|
||||
@@ -577,78 +576,53 @@ class AscendSFAImpl(MLAAttentionImpl):
|
||||
assert self.kv_a_proj_with_mqa is None
|
||||
assert self.fused_qkv_a_proj is not None
|
||||
|
||||
kv_a_proj_wt = self.fused_qkv_a_proj.weight.data[
|
||||
..., self.q_lora_rank:].contiguous()
|
||||
q_a_proj_wt = self.fused_qkv_a_proj.weight.data[
|
||||
..., :self.q_lora_rank].contiguous()
|
||||
kv_a_proj_wt = self.fused_qkv_a_proj.weight.data[..., self.q_lora_rank :].contiguous()
|
||||
q_a_proj_wt = self.fused_qkv_a_proj.weight.data[..., : self.q_lora_rank].contiguous()
|
||||
|
||||
kv_a_proj_wt = kv_a_proj_wt.t().contiguous()
|
||||
kv_a_proj_wt = trans_rope_weight(kv_a_proj_wt, self.qk_rope_head_dim)
|
||||
kv_a_proj_wt = kv_a_proj_wt.t().contiguous()
|
||||
wd_qkv = torch.cat((kv_a_proj_wt, q_a_proj_wt), dim=-1)
|
||||
wd_qkv = wd_qkv.t().contiguous()
|
||||
wd_qkv = transdata(wd_qkv,
|
||||
block_size=(16, 32)).unsqueeze(0).contiguous()
|
||||
wd_qkv = transdata(wd_qkv, block_size=(16, 32)).unsqueeze(0).contiguous()
|
||||
self.wd_qkv = torch_npu.npu_format_cast(wd_qkv, 29)
|
||||
|
||||
kv_a_proj_deq_scl = self.fused_qkv_a_proj.deq_scale[
|
||||
self.q_lora_rank:].contiguous()
|
||||
q_a_proj_deq_scl = self.fused_qkv_a_proj.deq_scale[:self.
|
||||
q_lora_rank].contiguous(
|
||||
)
|
||||
kv_a_proj_deq_scl = kv_a_proj_deq_scl.reshape(
|
||||
self.kv_lora_rank + self.qk_rope_head_dim, -1).contiguous()
|
||||
kv_a_proj_deq_scl = trans_rope_weight(kv_a_proj_deq_scl,
|
||||
self.qk_rope_head_dim)
|
||||
kv_a_proj_deq_scl = kv_a_proj_deq_scl.view(
|
||||
self.kv_lora_rank + self.qk_rope_head_dim).contiguous()
|
||||
self.deq_scale_qkv = torch.cat((kv_a_proj_deq_scl, q_a_proj_deq_scl),
|
||||
dim=-1).contiguous()
|
||||
kv_a_proj_deq_scl = self.fused_qkv_a_proj.deq_scale[self.q_lora_rank :].contiguous()
|
||||
q_a_proj_deq_scl = self.fused_qkv_a_proj.deq_scale[: self.q_lora_rank].contiguous()
|
||||
kv_a_proj_deq_scl = kv_a_proj_deq_scl.reshape(self.kv_lora_rank + self.qk_rope_head_dim, -1).contiguous()
|
||||
kv_a_proj_deq_scl = trans_rope_weight(kv_a_proj_deq_scl, self.qk_rope_head_dim)
|
||||
kv_a_proj_deq_scl = kv_a_proj_deq_scl.view(self.kv_lora_rank + self.qk_rope_head_dim).contiguous()
|
||||
self.deq_scale_qkv = torch.cat((kv_a_proj_deq_scl, q_a_proj_deq_scl), dim=-1).contiguous()
|
||||
|
||||
kv_a_proj_qt_bias = self.fused_qkv_a_proj.quant_bias[
|
||||
self.q_lora_rank:].contiguous()
|
||||
q_a_proj_qt_bias = self.fused_qkv_a_proj.quant_bias[:self.
|
||||
q_lora_rank].contiguous(
|
||||
)
|
||||
kv_a_proj_qt_bias = self.fused_qkv_a_proj.quant_bias[self.q_lora_rank :].contiguous()
|
||||
q_a_proj_qt_bias = self.fused_qkv_a_proj.quant_bias[: self.q_lora_rank].contiguous()
|
||||
|
||||
kv_a_proj_qt_bias = kv_a_proj_qt_bias.reshape(
|
||||
self.kv_lora_rank + self.qk_rope_head_dim, -1).contiguous()
|
||||
kv_a_proj_qt_bias = trans_rope_weight(kv_a_proj_qt_bias,
|
||||
self.qk_rope_head_dim)
|
||||
kv_a_proj_qt_bias = kv_a_proj_qt_bias.view(
|
||||
self.kv_lora_rank + self.qk_rope_head_dim).contiguous()
|
||||
self.quant_bias_qkv = torch.cat((kv_a_proj_qt_bias, q_a_proj_qt_bias),
|
||||
dim=-1).contiguous()
|
||||
kv_a_proj_qt_bias = kv_a_proj_qt_bias.reshape(self.kv_lora_rank + self.qk_rope_head_dim, -1).contiguous()
|
||||
kv_a_proj_qt_bias = trans_rope_weight(kv_a_proj_qt_bias, self.qk_rope_head_dim)
|
||||
kv_a_proj_qt_bias = kv_a_proj_qt_bias.view(self.kv_lora_rank + self.qk_rope_head_dim).contiguous()
|
||||
self.quant_bias_qkv = torch.cat((kv_a_proj_qt_bias, q_a_proj_qt_bias), dim=-1).contiguous()
|
||||
|
||||
wu_q = self.q_proj.weight.data
|
||||
wu_q = wu_q.t().reshape(self.num_heads,
|
||||
self.qk_nope_head_dim + self.qk_rope_head_dim,
|
||||
-1)
|
||||
wu_q = wu_q.t().reshape(self.num_heads, self.qk_nope_head_dim + self.qk_rope_head_dim, -1)
|
||||
wu_q = trans_rope_weight(wu_q, self.qk_rope_head_dim)
|
||||
wu_q = wu_q.reshape(
|
||||
self.num_heads * (self.qk_nope_head_dim + self.qk_rope_head_dim),
|
||||
-1)
|
||||
wu_q = wu_q.reshape(self.num_heads * (self.qk_nope_head_dim + self.qk_rope_head_dim), -1)
|
||||
wu_q = transdata(wu_q, block_size=(16, 32)).unsqueeze(0).contiguous()
|
||||
self.wu_q = torch_npu.npu_format_cast(wu_q, 29)
|
||||
|
||||
qb_deq_scl = self.q_proj.deq_scale.data
|
||||
qb_deq_scl = qb_deq_scl.reshape(
|
||||
self.num_heads, self.qk_nope_head_dim + self.qk_rope_head_dim, -1)
|
||||
qb_deq_scl = qb_deq_scl.reshape(self.num_heads, self.qk_nope_head_dim + self.qk_rope_head_dim, -1)
|
||||
qb_deq_scl = trans_rope_weight(qb_deq_scl, self.qk_rope_head_dim)
|
||||
self.qb_deq_scl = qb_deq_scl.reshape(
|
||||
self.num_heads * (self.qk_nope_head_dim + self.qk_rope_head_dim))
|
||||
self.qb_deq_scl = qb_deq_scl.reshape(self.num_heads * (self.qk_nope_head_dim + self.qk_rope_head_dim))
|
||||
|
||||
qb_qt_bias = self.q_proj.quant_bias.data
|
||||
qb_qt_bias = qb_qt_bias.reshape(
|
||||
self.num_heads, self.qk_nope_head_dim + self.qk_rope_head_dim, -1)
|
||||
qb_qt_bias = qb_qt_bias.reshape(self.num_heads, self.qk_nope_head_dim + self.qk_rope_head_dim, -1)
|
||||
qb_qt_bias = trans_rope_weight(qb_qt_bias, self.qk_rope_head_dim)
|
||||
self.qb_qt_bias = qb_qt_bias.reshape(
|
||||
self.num_heads * (self.qk_nope_head_dim + self.qk_rope_head_dim))
|
||||
self.qb_qt_bias = qb_qt_bias.reshape(self.num_heads * (self.qk_nope_head_dim + self.qk_rope_head_dim))
|
||||
|
||||
device = self.q_proj.weight.device
|
||||
self.gamma1 = self.q_a_layernorm.weight.data
|
||||
self.beta1 = self.q_a_layernorm.bias.data
|
||||
self.gamma2 = self.kv_a_layernorm.weight.data
|
||||
self.gamma1 = self.q_a_layernorm.weight.data # type: ignore[union-attr]
|
||||
self.beta1 = self.q_a_layernorm.bias.data # type: ignore[union-attr]
|
||||
self.gamma2 = self.kv_a_layernorm.weight.data # type: ignore[union-attr]
|
||||
self.quant_scale0 = self.fused_qkv_a_proj.input_scale.data
|
||||
self.quant_offset0 = self.fused_qkv_a_proj.input_offset.data
|
||||
self.quant_scale1 = self.q_proj.input_scale.data
|
||||
@@ -659,9 +633,11 @@ class AscendSFAImpl(MLAAttentionImpl):
|
||||
# On KV consumers (decode-only) MLAPO uses the transformed weights built above;
|
||||
# the original fused_qkv_a_proj/q_proj weights and quant params are no longer
|
||||
# referenced, so drop them to save memory.
|
||||
if self.vllm_config.kv_transfer_config is not None and \
|
||||
self.vllm_config.kv_transfer_config.is_kv_consumer and \
|
||||
self.vllm_config.scheduler_config.max_num_batched_tokens <= MLAPO_MAX_SUPPORTED_TOKENS:
|
||||
if (
|
||||
self.vllm_config.kv_transfer_config is not None
|
||||
and self.vllm_config.kv_transfer_config.is_kv_consumer
|
||||
and self.vllm_config.scheduler_config.max_num_batched_tokens <= MLAPO_MAX_SUPPORTED_TOKENS
|
||||
):
|
||||
self.fused_qkv_a_proj.weight = None
|
||||
self.fused_qkv_a_proj.deq_scale = None
|
||||
self.fused_qkv_a_proj.quant_bias = None
|
||||
@@ -673,13 +649,12 @@ class AscendSFAImpl(MLAAttentionImpl):
|
||||
def _sfa_preprocess_decode(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
kv_cache: Tuple[torch.Tensor, torch.Tensor, torch.Tensor],
|
||||
kv_cache: tuple[torch.Tensor, torch.Tensor, torch.Tensor],
|
||||
attn_metadata: M,
|
||||
need_gather_q_kv: bool,
|
||||
num_input_tokens: int,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
hidden_states = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(
|
||||
hidden_states.contiguous(), need_gather_q_kv)
|
||||
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
hidden_states = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(hidden_states.contiguous(), need_gather_q_kv)
|
||||
k_nope, k_pe = kv_cache[0], kv_cache[1]
|
||||
ql_nope = torch.empty(
|
||||
(num_input_tokens, self.W_UK_T.shape[0], k_nope.shape[-1]),
|
||||
@@ -734,17 +709,17 @@ class AscendSFAImpl(MLAAttentionImpl):
|
||||
self,
|
||||
layer_name,
|
||||
hidden_states: torch.Tensor, # query in unified attn
|
||||
kv_cache: Tuple[torch.Tensor, torch.Tensor, torch.Tensor],
|
||||
kv_cache: tuple[torch.Tensor, torch.Tensor, torch.Tensor],
|
||||
attn_metadata: M,
|
||||
need_gather_q_kv: bool = False,
|
||||
output: Optional[torch.Tensor] = None,
|
||||
output: torch.Tensor | None = None,
|
||||
) -> torch.Tensor:
|
||||
assert output is not None, "Output tensor must be provided."
|
||||
forward_context = get_forward_context()
|
||||
if attn_metadata is None:
|
||||
# Profiling run.
|
||||
if self.enable_dsa_cp_prefill_only and not forward_context.in_profile_run:
|
||||
for layer in (self.layer_sharding_kwargs or []):
|
||||
for layer in self.layer_sharding_kwargs or []:
|
||||
if is_hidden_layer(layer):
|
||||
reach_layer_for_shard_weight_series(layer)
|
||||
return output.fill_(0)
|
||||
@@ -761,12 +736,13 @@ class AscendSFAImpl(MLAAttentionImpl):
|
||||
|
||||
# all-gather o_proj weight for prefill stage of PD mix node
|
||||
o_proj_full_handle = None
|
||||
# if is PD mix stage, using original TP o_proj weight, and also need to full gather for o_proj weight for prefill stage.
|
||||
# if is PD mix stage, using original TP o_proj weight, and also need to full gather for o_proj
|
||||
# weight for prefill stage.
|
||||
should_shard_weight = self.enable_dsa_cp_prefill_only or attn_metadata.attn_state not in {
|
||||
AscendAttentionState.DecodeOnly, AscendAttentionState.SpecDecoding
|
||||
AscendAttentionState.DecodeOnly,
|
||||
AscendAttentionState.SpecDecoding,
|
||||
}
|
||||
|
||||
|
||||
if self.enable_mlapo and num_input_tokens <= MLAPO_MAX_SUPPORTED_TOKENS:
|
||||
hidden_states, ql_nope, q_pe, q_c = self._sfa_preprocess_decode(
|
||||
hidden_states=hidden_states,
|
||||
@@ -776,35 +752,30 @@ class AscendSFAImpl(MLAAttentionImpl):
|
||||
num_input_tokens=num_input_tokens,
|
||||
)
|
||||
q, k = self.indexer_select_pre_process(
|
||||
x=hidden_states,
|
||||
qr=q_c,
|
||||
cos=cos,
|
||||
sin=sin,
|
||||
need_gather_q_kv=need_gather_q_kv)
|
||||
x=hidden_states, qr=q_c, cos=cos, sin=sin, need_gather_q_kv=need_gather_q_kv
|
||||
)
|
||||
else:
|
||||
assert self.fused_qkv_a_proj is not None, "q lora is required for DSA."
|
||||
maybe_npu_prefetch(inputs=self.fused_qkv_a_proj.weight,
|
||||
dependency=hidden_states,
|
||||
enabled=self.enable_prefetch)
|
||||
maybe_npu_prefetch(
|
||||
inputs=self.fused_qkv_a_proj.weight, dependency=hidden_states, enabled=self.enable_prefetch
|
||||
)
|
||||
qkv_lora = self.fused_qkv_a_proj(hidden_states)[0]
|
||||
q_c, kv_no_split = qkv_lora.split(
|
||||
[self.q_lora_rank, self.kv_lora_rank + self.qk_rope_head_dim],
|
||||
dim=-1,
|
||||
)
|
||||
assert self.q_a_layernorm is not None, "q_a_layernorm must be initialized"
|
||||
q_c = self.q_a_layernorm(q_c)
|
||||
# Process for Flash Comm V1
|
||||
if need_gather_q_kv:
|
||||
q_c = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(
|
||||
q_c.contiguous(), need_gather_q_kv)
|
||||
q_c = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(q_c.contiguous(), need_gather_q_kv)
|
||||
kv_no_split = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(
|
||||
kv_no_split.contiguous(), need_gather_q_kv)
|
||||
kv_no_split.contiguous(), need_gather_q_kv
|
||||
)
|
||||
|
||||
q, k = self.indexer_select_pre_process(
|
||||
x=hidden_states,
|
||||
qr=q_c,
|
||||
cos=cos,
|
||||
sin=sin,
|
||||
need_gather_q_kv=need_gather_q_kv)
|
||||
x=hidden_states, qr=q_c, cos=cos, sin=sin, need_gather_q_kv=need_gather_q_kv
|
||||
)
|
||||
|
||||
wait_for_kv_layer_from_connector(layer_name)
|
||||
|
||||
@@ -815,22 +786,20 @@ class AscendSFAImpl(MLAAttentionImpl):
|
||||
actual_seq_lengths_query = attn_metadata.dsa_cp_context.actual_seq_lengths_query
|
||||
actual_seq_lengths_key = attn_metadata.dsa_cp_context.actual_seq_lengths_key
|
||||
|
||||
k_pe, k_nope = self.exec_kv(kv_no_split, cos, sin, kv_cache,
|
||||
slot_mapping)
|
||||
k_pe, k_nope = self.exec_kv(kv_no_split, cos, sin, kv_cache, slot_mapping)
|
||||
|
||||
if self.enable_dsa_cp:
|
||||
assert k_pe is not None
|
||||
assert k_nope is not None
|
||||
# support all_gather kv async for communication calculation overlap
|
||||
fused_kv_no_split, kv_ag_handle = all_gather_async(
|
||||
torch.cat([
|
||||
k_pe.view(-1, k_pe.shape[-1]),
|
||||
k_nope.view(-1, k_nope.shape[-1]),
|
||||
k.view(-1, k.shape[-1])
|
||||
],
|
||||
dim=1),
|
||||
torch.cat(
|
||||
[k_pe.view(-1, k_pe.shape[-1]), k_nope.view(-1, k_nope.shape[-1]), k.view(-1, k.shape[-1])],
|
||||
dim=1,
|
||||
),
|
||||
get_tp_group(),
|
||||
async_op=should_shard_weight)
|
||||
async_op=should_shard_weight,
|
||||
)
|
||||
|
||||
ql_nope, q_pe = self._q_proj_and_k_up_proj(q_c)
|
||||
q_pe = self.rope_single(q_pe, cos, sin)
|
||||
@@ -840,34 +809,27 @@ class AscendSFAImpl(MLAAttentionImpl):
|
||||
kv_ag_handle.wait()
|
||||
|
||||
if self.enable_dsa_cp_prefill_only:
|
||||
for layer in (self.layer_sharding_kwargs or []):
|
||||
for layer in self.layer_sharding_kwargs or []:
|
||||
if is_hidden_layer(layer):
|
||||
reach_layer_for_shard_weight_series(layer)
|
||||
elif should_shard_weight:
|
||||
_, o_proj_full_handle = all_gather_async(
|
||||
self.o_proj_tp_weight,
|
||||
get_tp_group(),
|
||||
output=AscendSFAImpl.o_proj_full_pool)
|
||||
self.o_proj_tp_weight, get_tp_group(), output=AscendSFAImpl.o_proj_full_pool
|
||||
)
|
||||
|
||||
if kv_cache is not None:
|
||||
assert fused_kv_no_split is not None
|
||||
k_pe, k_nope, k = fused_kv_no_split.split([
|
||||
self.qk_rope_head_dim, self.kv_lora_rank, self.head_dim
|
||||
],
|
||||
dim=-1)
|
||||
k_pe, k_nope, k = fused_kv_no_split.split(
|
||||
[self.qk_rope_head_dim, self.kv_lora_rank, self.head_dim], dim=-1
|
||||
)
|
||||
slot_mapping = attn_metadata.slot_mapping.view(-1, 1)
|
||||
torch_npu.npu_scatter_nd_update_(
|
||||
kv_cache[0].view(-1, k_nope.shape[-1]), slot_mapping,
|
||||
k_nope)
|
||||
torch_npu.npu_scatter_nd_update_(
|
||||
kv_cache[1].view(-1, k_pe.shape[-1]), slot_mapping,
|
||||
k_pe)
|
||||
torch_npu.npu_scatter_nd_update_(kv_cache[0].view(-1, k_nope.shape[-1]), slot_mapping, k_nope)
|
||||
torch_npu.npu_scatter_nd_update_(kv_cache[1].view(-1, k_pe.shape[-1]), slot_mapping, k_pe)
|
||||
|
||||
if kv_cache is not None:
|
||||
torch_npu.npu_scatter_nd_update_(
|
||||
kv_cache[2].view(-1, k.shape[-1]),
|
||||
attn_metadata.slot_mapping.view(-1, 1),
|
||||
k.view(-1, k.shape[-1])) # b, s, n, d
|
||||
kv_cache[2].view(-1, k.shape[-1]), attn_metadata.slot_mapping.view(-1, 1), k.view(-1, k.shape[-1])
|
||||
) # b, s, n, d
|
||||
|
||||
topk_indices = self.indexer_select_post_process(
|
||||
x=hidden_states,
|
||||
@@ -880,7 +842,8 @@ class AscendSFAImpl(MLAAttentionImpl):
|
||||
sin=sin,
|
||||
actual_seq_lengths_query=actual_seq_lengths_query,
|
||||
actual_seq_lengths_key=actual_seq_lengths_key,
|
||||
need_gather_q_kv=need_gather_q_kv)
|
||||
need_gather_q_kv=need_gather_q_kv,
|
||||
)
|
||||
|
||||
attn_output = torch.ops._C_ascend.npu_sparse_flash_attention(
|
||||
query=ql_nope,
|
||||
@@ -900,10 +863,12 @@ class AscendSFAImpl(MLAAttentionImpl):
|
||||
)
|
||||
|
||||
attn_output = self._v_up_proj(attn_output)
|
||||
maybe_npu_prefetch(inputs=self.o_proj.weight,
|
||||
maybe_npu_prefetch(
|
||||
inputs=self.o_proj.weight,
|
||||
dependency=attn_output,
|
||||
max_size=MAX_O_PROJ_PREFETCH_SIZE,
|
||||
enabled=self.enable_prefetch)
|
||||
enabled=self.enable_prefetch,
|
||||
)
|
||||
|
||||
if self.enable_dsa_cp and not self.enable_dsa_cp_prefill_only:
|
||||
# When using SFA-CP with pd mixed, o_proj has two cases:
|
||||
@@ -913,7 +878,8 @@ class AscendSFAImpl(MLAAttentionImpl):
|
||||
attn_output=attn_output,
|
||||
output=output,
|
||||
o_proj_full_handle=o_proj_full_handle,
|
||||
should_shard_weight=should_shard_weight)
|
||||
should_shard_weight=should_shard_weight,
|
||||
)
|
||||
if not require_o_proj_forward:
|
||||
return result
|
||||
attn_output = result
|
||||
@@ -933,8 +899,7 @@ class AscendSFAImpl(MLAAttentionImpl):
|
||||
need_gather_q_kv: bool = False,
|
||||
):
|
||||
k_proj, _ = self.wk(x) # [b,s,7168] @ [7168,128] = [b,s,128]
|
||||
k_proj = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(
|
||||
k_proj, need_gather_q_kv)
|
||||
k_proj = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(k_proj, need_gather_q_kv)
|
||||
k = self.k_norm(k_proj).unsqueeze(1)
|
||||
k = k.view(-1, 1, self.head_dim)
|
||||
|
||||
@@ -944,17 +909,9 @@ class AscendSFAImpl(MLAAttentionImpl):
|
||||
|
||||
cos = cos.view(-1, self.qk_rope_head_dim)
|
||||
sin = sin.view(-1, self.qk_rope_head_dim)
|
||||
q, k = rope_forward_triton(q,
|
||||
k,
|
||||
cos,
|
||||
sin,
|
||||
rope_dim=self.qk_rope_head_dim,
|
||||
is_neox_style=True)
|
||||
q, k = rope_forward_triton(q, k, cos, sin, rope_dim=self.qk_rope_head_dim, is_neox_style=True)
|
||||
else:
|
||||
k_pe, k_nope = torch.split(
|
||||
k,
|
||||
[self.qk_rope_head_dim, self.head_dim - self.qk_rope_head_dim],
|
||||
dim=-1)
|
||||
k_pe, k_nope = torch.split(k, [self.qk_rope_head_dim, self.head_dim - self.qk_rope_head_dim], dim=-1)
|
||||
|
||||
cos = cos.view(-1, 1, 1, self.qk_rope_head_dim)
|
||||
sin = sin.view(-1, 1, 1, self.qk_rope_head_dim)
|
||||
@@ -972,9 +929,9 @@ class AscendSFAImpl(MLAAttentionImpl):
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
qr: torch.Tensor,
|
||||
q: Optional[torch.Tensor],
|
||||
q: torch.Tensor | None,
|
||||
k: torch.Tensor,
|
||||
kv_cache: Tuple[torch.Tensor, torch.Tensor, torch.Tensor],
|
||||
kv_cache: tuple[torch.Tensor, torch.Tensor, torch.Tensor],
|
||||
attn_metadata: M,
|
||||
cos: torch.Tensor,
|
||||
sin: torch.Tensor,
|
||||
@@ -988,9 +945,8 @@ class AscendSFAImpl(MLAAttentionImpl):
|
||||
cos_q, sin_q = cos, sin
|
||||
|
||||
q_pe, q_nope = torch.split(
|
||||
q,
|
||||
[self.qk_rope_head_dim, self.head_dim - self.qk_rope_head_dim],
|
||||
dim=-1) # [b,s,64,64+64]
|
||||
q, [self.qk_rope_head_dim, self.head_dim - self.qk_rope_head_dim], dim=-1
|
||||
) # [b,s,64,64+64]
|
||||
|
||||
q_pe = q_pe.unsqueeze(2)
|
||||
q_pe = torch_npu.npu_rotary_mul(q_pe, cos_q, sin_q)
|
||||
@@ -1000,17 +956,14 @@ class AscendSFAImpl(MLAAttentionImpl):
|
||||
if kv_cache is not None:
|
||||
if self.is_kv_producer:
|
||||
attn_metadata.reshape_cache_event = torch.npu.Event()
|
||||
torch_npu.npu_scatter_nd_update_(kv_cache[2].view(-1, k.shape[-1]),
|
||||
attn_metadata.slot_mapping.view(
|
||||
-1, 1),
|
||||
k.view(-1,
|
||||
k.shape[-1])) # b, s, n, d
|
||||
torch_npu.npu_scatter_nd_update_(
|
||||
kv_cache[2].view(-1, k.shape[-1]), attn_metadata.slot_mapping.view(-1, 1), k.view(-1, k.shape[-1])
|
||||
) # b, s, n, d
|
||||
if self.is_kv_producer:
|
||||
attn_metadata.reshape_cache_event.record()
|
||||
|
||||
weights, _ = self.weights_proj(x)
|
||||
weights = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(
|
||||
weights, need_gather_q_kv)
|
||||
weights = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(weights, need_gather_q_kv)
|
||||
|
||||
block_table = attn_metadata.block_tables
|
||||
|
||||
@@ -1024,7 +977,8 @@ class AscendSFAImpl(MLAAttentionImpl):
|
||||
layout_query="TND",
|
||||
layout_key="PA_BSND",
|
||||
sparse_count=2048,
|
||||
sparse_mode=3)
|
||||
sparse_mode=3,
|
||||
)
|
||||
return topk_indices
|
||||
|
||||
def _init_o_proj_tp_full_params(self):
|
||||
@@ -1039,38 +993,33 @@ class AscendSFAImpl(MLAAttentionImpl):
|
||||
if AscendSFAImpl.o_proj_full_pool is None:
|
||||
sample = self.o_proj.weight
|
||||
AscendSFAImpl.o_proj_full_pool = torch.empty(
|
||||
(sample.shape[0] * self.tp_size, sample.shape[1]),
|
||||
dtype=sample.dtype,
|
||||
device=sample.device)
|
||||
(sample.shape[0] * self.tp_size, sample.shape[1]), dtype=sample.dtype, device=sample.device
|
||||
)
|
||||
|
||||
# Save TP-mode parameters (original sharded weights)
|
||||
self.o_proj_tp_weight = self.o_proj.weight.clone().detach()
|
||||
self.o_proj_tp_aclnn_input_scale = self.o_proj.aclnn_input_scale.clone(
|
||||
).detach()
|
||||
self.o_proj_tp_aclnn_input_scale_reciprocal = self.o_proj.aclnn_input_scale_reciprocal.clone(
|
||||
).detach()
|
||||
self.o_proj_tp_aclnn_input_offset = self.o_proj.aclnn_input_offset.clone(
|
||||
).detach()
|
||||
self.o_proj_tp_aclnn_input_scale = self.o_proj.aclnn_input_scale.clone().detach()
|
||||
self.o_proj_tp_aclnn_input_scale_reciprocal = self.o_proj.aclnn_input_scale_reciprocal.clone().detach()
|
||||
self.o_proj_tp_aclnn_input_offset = self.o_proj.aclnn_input_offset.clone().detach()
|
||||
|
||||
# Initially switch to TP mode for graph capture
|
||||
self.o_proj.weight.set_(self.o_proj_tp_weight)
|
||||
self.o_proj.aclnn_input_scale.set_(self.o_proj_tp_aclnn_input_scale)
|
||||
self.o_proj.aclnn_input_scale_reciprocal.set_(
|
||||
self.o_proj_tp_aclnn_input_scale_reciprocal)
|
||||
self.o_proj.aclnn_input_scale_reciprocal.set_(self.o_proj_tp_aclnn_input_scale_reciprocal)
|
||||
self.o_proj.aclnn_input_offset.set_(self.o_proj_tp_aclnn_input_offset)
|
||||
|
||||
# Precompute Full-mode quantization parameters by repeating TP parameters across all TP ranks
|
||||
self.o_proj_full_aclnn_input_scale = self.o_proj.aclnn_input_scale.repeat(
|
||||
self.tp_size)
|
||||
self.o_proj_full_aclnn_input_scale_reciprocal = self.o_proj.aclnn_input_scale_reciprocal.repeat(
|
||||
self.tp_size)
|
||||
self.o_proj_full_aclnn_input_offset = self.o_proj.aclnn_input_offset.repeat(
|
||||
self.tp_size)
|
||||
self.o_proj_full_aclnn_input_scale = self.o_proj.aclnn_input_scale.repeat(self.tp_size)
|
||||
self.o_proj_full_aclnn_input_scale_reciprocal = self.o_proj.aclnn_input_scale_reciprocal.repeat(self.tp_size)
|
||||
self.o_proj_full_aclnn_input_offset = self.o_proj.aclnn_input_offset.repeat(self.tp_size)
|
||||
|
||||
def _handle_o_proj_weight_switch_and_forward(
|
||||
self, attn_output: torch.Tensor, output: torch.Tensor,
|
||||
o_proj_full_handle: Optional[torch.distributed.Work],
|
||||
should_shard_weight: bool) -> Tuple[torch.Tensor, bool]:
|
||||
self,
|
||||
attn_output: torch.Tensor,
|
||||
output: torch.Tensor,
|
||||
o_proj_full_handle: torch.distributed.Work | None,
|
||||
should_shard_weight: bool,
|
||||
) -> tuple[torch.Tensor, bool]:
|
||||
"""
|
||||
Handle o_proj weight switching between TP-mode and Full-mode, and execute forward computation.
|
||||
"""
|
||||
@@ -1082,36 +1031,30 @@ class AscendSFAImpl(MLAAttentionImpl):
|
||||
|
||||
# Switch o_proj to Full-mode (gathered weight from all TP ranks)
|
||||
self.o_proj.weight.set_(AscendSFAImpl.o_proj_full_pool)
|
||||
self.o_proj.aclnn_input_scale.set_(
|
||||
self.o_proj_full_aclnn_input_scale)
|
||||
self.o_proj.aclnn_input_scale_reciprocal.set_(
|
||||
self.o_proj_full_aclnn_input_scale_reciprocal)
|
||||
self.o_proj.aclnn_input_offset.set_(
|
||||
self.o_proj_full_aclnn_input_offset)
|
||||
self.o_proj.aclnn_input_scale.set_(self.o_proj_full_aclnn_input_scale)
|
||||
self.o_proj.aclnn_input_scale_reciprocal.set_(self.o_proj_full_aclnn_input_scale_reciprocal)
|
||||
self.o_proj.aclnn_input_offset.set_(self.o_proj_full_aclnn_input_offset)
|
||||
|
||||
# Apply quantization method and execute forward computation
|
||||
output[...] = self.o_proj.quant_method.quant_method.apply(
|
||||
self.o_proj, attn_output)
|
||||
output[...] = self.o_proj.quant_method.quant_method.apply(self.o_proj, attn_output)
|
||||
|
||||
# Switch o_proj back to TP-mode for subsequent decode operations
|
||||
self.o_proj.weight.set_(self.o_proj_tp_weight)
|
||||
self.o_proj.aclnn_input_scale.set_(
|
||||
self.o_proj_tp_aclnn_input_scale)
|
||||
self.o_proj.aclnn_input_scale_reciprocal.set_(
|
||||
self.o_proj_tp_aclnn_input_scale_reciprocal)
|
||||
self.o_proj.aclnn_input_offset.set_(
|
||||
self.o_proj_tp_aclnn_input_offset)
|
||||
self.o_proj.aclnn_input_scale.set_(self.o_proj_tp_aclnn_input_scale)
|
||||
self.o_proj.aclnn_input_scale_reciprocal.set_(self.o_proj_tp_aclnn_input_scale_reciprocal)
|
||||
self.o_proj.aclnn_input_offset.set_(self.o_proj_tp_aclnn_input_offset)
|
||||
|
||||
return output, False
|
||||
else:
|
||||
# For decode scenario: perform all-to-all communication on o_proj input activations
|
||||
# Reshape for all-to-all: [batch * seq, tp_size, head_dim] -> [tp_size, batch * seq, head_dim]
|
||||
send = attn_output.view(-1, self.tp_size, self.num_heads *
|
||||
self.v_head_dim).permute(1, 0, 2).reshape(
|
||||
-1, self.num_heads * self.v_head_dim)
|
||||
send = (
|
||||
attn_output.view(-1, self.tp_size, self.num_heads * self.v_head_dim)
|
||||
.permute(1, 0, 2)
|
||||
.reshape(-1, self.num_heads * self.v_head_dim)
|
||||
)
|
||||
|
||||
attn_output = torch.empty_like(send)
|
||||
torch.distributed.all_to_all_single(
|
||||
attn_output, send, group=get_tp_group().device_group)
|
||||
torch.distributed.all_to_all_single(attn_output, send, group=get_tp_group().device_group)
|
||||
|
||||
return attn_output, True
|
||||
|
||||
@@ -21,26 +21,21 @@ from __future__ import annotations
|
||||
import time
|
||||
from collections import defaultdict
|
||||
from dataclasses import dataclass, fields
|
||||
from typing import Type, Union
|
||||
|
||||
from vllm._bc_linter import bc_linter_include
|
||||
from vllm.config import SchedulerConfig, VllmConfig
|
||||
from vllm.distributed.ec_transfer.ec_connector.base import ECConnectorMetadata
|
||||
from vllm.distributed.kv_events import KVEventBatch
|
||||
from vllm.distributed.kv_transfer.kv_connector.v1.base import \
|
||||
KVConnectorMetadata
|
||||
from vllm.distributed.kv_transfer.kv_connector.v1.metrics import \
|
||||
KVConnectorStats
|
||||
from vllm.distributed.kv_transfer.kv_connector.v1.base import KVConnectorMetadata
|
||||
from vllm.distributed.kv_transfer.kv_connector.v1.metrics import KVConnectorStats
|
||||
from vllm.logger import init_logger
|
||||
from vllm.v1.core.kv_cache_manager import KVCacheBlocks
|
||||
from vllm.v1.core.sched.async_scheduler import AsyncScheduler
|
||||
from vllm.v1.core.sched.output import NewRequestData, SchedulerOutput
|
||||
from vllm.v1.core.sched.request_queue import (SchedulingPolicy,
|
||||
create_request_queue)
|
||||
from vllm.v1.core.sched.request_queue import SchedulingPolicy, create_request_queue
|
||||
from vllm.v1.core.sched.scheduler import Scheduler
|
||||
from vllm.v1.core.sched.utils import check_stop, remove_all
|
||||
from vllm.v1.engine import (EngineCoreEventType, EngineCoreOutput,
|
||||
EngineCoreOutputs, FinishReason)
|
||||
from vllm.v1.engine import EngineCoreEventType, EngineCoreOutput, EngineCoreOutputs, FinishReason
|
||||
from vllm.v1.outputs import ModelRunnerOutput
|
||||
from vllm.v1.request import Request, RequestStatus
|
||||
from vllm.v1.spec_decode.metrics import SpecDecodingStats
|
||||
@@ -51,26 +46,22 @@ logger = init_logger(__name__)
|
||||
|
||||
@dataclass
|
||||
class RecomputeSchedulerConfig(SchedulerConfig):
|
||||
scheduler_cls: Union[str, Type[object]] = (
|
||||
"vllm_ascend.core.recompute_scheduler.RecomputeScheduler")
|
||||
scheduler_cls: str | type[object] = "vllm_ascend.core.recompute_scheduler.RecomputeScheduler"
|
||||
|
||||
@classmethod
|
||||
def initialize_from_config(cls, vllm_config: VllmConfig):
|
||||
vllm_scheduler_config = vllm_config.scheduler_config
|
||||
scheduler_config = {
|
||||
field.name: getattr(vllm_scheduler_config, field.name)
|
||||
for field in fields(vllm_scheduler_config) if field.init
|
||||
for field in fields(vllm_scheduler_config)
|
||||
if field.init
|
||||
}
|
||||
if vllm_scheduler_config.async_scheduling:
|
||||
scheduler_config["scheduler_cls"] = (
|
||||
"vllm_ascend.core.recompute_scheduler.AsyncRecomputeScheduler")
|
||||
scheduler_config["scheduler_cls"] = "vllm_ascend.core.recompute_scheduler.AsyncRecomputeScheduler"
|
||||
else:
|
||||
scheduler_config["scheduler_cls"] = (
|
||||
"vllm_ascend.core.recompute_scheduler.RecomputeScheduler")
|
||||
scheduler_config[
|
||||
"max_model_len"] = vllm_config.model_config.max_model_len
|
||||
scheduler_config[
|
||||
"is_encoder_decoder"] = vllm_config.model_config.is_encoder_decoder
|
||||
scheduler_config["scheduler_cls"] = "vllm_ascend.core.recompute_scheduler.RecomputeScheduler"
|
||||
scheduler_config["max_model_len"] = vllm_config.model_config.max_model_len
|
||||
scheduler_config["is_encoder_decoder"] = vllm_config.model_config.is_encoder_decoder
|
||||
return cls(**scheduler_config)
|
||||
|
||||
|
||||
@@ -125,33 +116,32 @@ class RecomputeScheduler(Scheduler):
|
||||
while req_index < len(self.running) and token_budget > 0:
|
||||
request = self.running[req_index]
|
||||
|
||||
if (request.num_output_placeholders > 0
|
||||
if (
|
||||
request.num_output_placeholders > 0
|
||||
# This is (num_computed_tokens + 1) - (num_output_placeholders - 1).
|
||||
# Since output placeholders are also included in the computed tokens
|
||||
# count, we subtract (num_output_placeholders - 1) to remove any draft
|
||||
# tokens, so that we can be sure no further steps are needed even if
|
||||
# they are all rejected.
|
||||
and request.num_computed_tokens + 2 -
|
||||
request.num_output_placeholders
|
||||
>= request.num_prompt_tokens + request.max_tokens):
|
||||
and request.num_computed_tokens + 2 - request.num_output_placeholders
|
||||
>= request.num_prompt_tokens + request.max_tokens
|
||||
):
|
||||
# Async scheduling: Avoid scheduling an extra step when we are sure that
|
||||
# the previous step has reached request.max_tokens. We don't schedule
|
||||
# partial draft tokens since this prevents uniform decode optimizations.
|
||||
req_index += 1
|
||||
continue
|
||||
|
||||
num_new_tokens = (request.num_tokens_with_spec +
|
||||
request.num_output_placeholders -
|
||||
request.num_computed_tokens)
|
||||
num_new_tokens = (
|
||||
request.num_tokens_with_spec + request.num_output_placeholders - request.num_computed_tokens
|
||||
)
|
||||
if 0 < self.scheduler_config.long_prefill_token_threshold < num_new_tokens:
|
||||
num_new_tokens = self.scheduler_config.long_prefill_token_threshold
|
||||
num_new_tokens = min(num_new_tokens, token_budget)
|
||||
|
||||
# Make sure the input position does not exceed the max model len.
|
||||
# This is necessary when using spec decoding.
|
||||
num_new_tokens = min(
|
||||
num_new_tokens,
|
||||
self.max_model_len - 1 - request.num_computed_tokens)
|
||||
num_new_tokens = min(num_new_tokens, self.max_model_len - 1 - request.num_computed_tokens)
|
||||
|
||||
# Schedule encoder inputs.
|
||||
encoder_inputs_to_schedule = None
|
||||
@@ -209,9 +199,10 @@ class RecomputeScheduler(Scheduler):
|
||||
recomputed_req = self.running.pop()
|
||||
self.kv_cache_manager.free(recomputed_req)
|
||||
recomputed_reqs.append(
|
||||
RecomputeReqInfo(recomputed_req.request_id,
|
||||
recomputed_req.output_token_ids,
|
||||
recomputed_req.client_index))
|
||||
RecomputeReqInfo(
|
||||
recomputed_req.request_id, recomputed_req.output_token_ids, recomputed_req.client_index
|
||||
)
|
||||
)
|
||||
if recomputed_req == request:
|
||||
break
|
||||
else:
|
||||
@@ -223,28 +214,23 @@ class RecomputeScheduler(Scheduler):
|
||||
self.running.remove(preempted_req)
|
||||
if preempted_req in scheduled_running_reqs:
|
||||
scheduled_running_reqs.remove(preempted_req)
|
||||
token_budget += num_scheduled_tokens[
|
||||
preempted_req.request_id]
|
||||
token_budget += num_scheduled_tokens[preempted_req.request_id]
|
||||
req_to_new_blocks.pop(preempted_req.request_id)
|
||||
num_scheduled_tokens.pop(
|
||||
preempted_req.request_id)
|
||||
scheduled_spec_decode_tokens.pop(
|
||||
preempted_req.request_id, None)
|
||||
preempted_encoder_inputs = scheduled_encoder_inputs.pop(
|
||||
preempted_req.request_id, None)
|
||||
num_scheduled_tokens.pop(preempted_req.request_id)
|
||||
scheduled_spec_decode_tokens.pop(preempted_req.request_id, None)
|
||||
preempted_encoder_inputs = scheduled_encoder_inputs.pop(preempted_req.request_id, None)
|
||||
if preempted_encoder_inputs:
|
||||
# Restore encoder compute budget if the preempted
|
||||
# request had encoder inputs scheduled in this step.
|
||||
num_embeds_to_restore = sum(
|
||||
preempted_req.get_num_encoder_embeds(i)
|
||||
for i in preempted_encoder_inputs)
|
||||
preempted_req.get_num_encoder_embeds(i) for i in preempted_encoder_inputs
|
||||
)
|
||||
encoder_compute_budget += num_embeds_to_restore
|
||||
req_index -= 1
|
||||
else:
|
||||
preempted_req = self.running.pop()
|
||||
|
||||
self._preempt_request(preempted_req,
|
||||
scheduled_timestamp)
|
||||
self._preempt_request(preempted_req, scheduled_timestamp)
|
||||
preempted_reqs.append(preempted_req)
|
||||
if preempted_req == request:
|
||||
# No more request to preempt. Cannot schedule this request.
|
||||
@@ -263,23 +249,20 @@ class RecomputeScheduler(Scheduler):
|
||||
|
||||
# Speculative decode related.
|
||||
if request.spec_token_ids:
|
||||
num_scheduled_spec_tokens = (num_new_tokens +
|
||||
request.num_computed_tokens -
|
||||
request.num_tokens -
|
||||
request.num_output_placeholders)
|
||||
num_scheduled_spec_tokens = (
|
||||
num_new_tokens + request.num_computed_tokens - request.num_tokens - request.num_output_placeholders
|
||||
)
|
||||
if num_scheduled_spec_tokens > 0:
|
||||
# Trim spec_token_ids list to num_scheduled_spec_tokens.
|
||||
del request.spec_token_ids[num_scheduled_spec_tokens:]
|
||||
scheduled_spec_decode_tokens[request.request_id] = (
|
||||
request.spec_token_ids)
|
||||
scheduled_spec_decode_tokens[request.request_id] = request.spec_token_ids
|
||||
# New spec tokens will be set in `update_draft_token_ids` before the
|
||||
# next step when applicable.
|
||||
request.spec_token_ids = []
|
||||
|
||||
# Encoder-related.
|
||||
if encoder_inputs_to_schedule:
|
||||
scheduled_encoder_inputs[request.request_id] = (
|
||||
encoder_inputs_to_schedule)
|
||||
scheduled_encoder_inputs[request.request_id] = encoder_inputs_to_schedule
|
||||
# Allocate the encoder cache.
|
||||
for i in encoder_inputs_to_schedule:
|
||||
self.encoder_cache_manager.allocate(request, i)
|
||||
@@ -294,8 +277,10 @@ class RecomputeScheduler(Scheduler):
|
||||
scheduled_loras: set[int] = set()
|
||||
if self.lora_config:
|
||||
scheduled_loras = set(
|
||||
req.lora_request.lora_int_id for req in scheduled_running_reqs
|
||||
if req.lora_request and req.lora_request.lora_int_id > 0)
|
||||
req.lora_request.lora_int_id
|
||||
for req in scheduled_running_reqs
|
||||
if req.lora_request and req.lora_request.lora_int_id > 0
|
||||
)
|
||||
assert len(scheduled_loras) <= self.lora_config.max_loras
|
||||
|
||||
# Use a temporary RequestQueue to collect requests that need to be
|
||||
@@ -337,9 +322,14 @@ class RecomputeScheduler(Scheduler):
|
||||
|
||||
# Check that adding the request still respects the max_loras
|
||||
# constraint.
|
||||
if (self.lora_config and request.lora_request and
|
||||
(len(scheduled_loras) == self.lora_config.max_loras and
|
||||
request.lora_request.lora_int_id not in scheduled_loras)):
|
||||
if (
|
||||
self.lora_config
|
||||
and request.lora_request
|
||||
and (
|
||||
len(scheduled_loras) == self.lora_config.max_loras
|
||||
and request.lora_request.lora_int_id not in scheduled_loras
|
||||
)
|
||||
):
|
||||
# Scheduling would exceed max_loras, skip.
|
||||
self.waiting.pop_request()
|
||||
skipped_waiting_requests.prepend_request(request)
|
||||
@@ -351,14 +341,15 @@ class RecomputeScheduler(Scheduler):
|
||||
# Get already-cached tokens.
|
||||
if request.num_computed_tokens == 0:
|
||||
# Get locally-cached tokens.
|
||||
new_computed_blocks, num_new_local_computed_tokens = (
|
||||
self.kv_cache_manager.get_computed_blocks(request))
|
||||
new_computed_blocks, num_new_local_computed_tokens = self.kv_cache_manager.get_computed_blocks(
|
||||
request
|
||||
)
|
||||
|
||||
# Get externally-cached tokens if using a KVConnector.
|
||||
if self.connector is not None:
|
||||
ext_tokens, load_kv_async = (
|
||||
self.connector.get_num_new_matched_tokens(
|
||||
request, num_new_local_computed_tokens))
|
||||
ext_tokens, load_kv_async = self.connector.get_num_new_matched_tokens(
|
||||
request, num_new_local_computed_tokens
|
||||
)
|
||||
|
||||
if ext_tokens is None:
|
||||
# The request cannot be scheduled because
|
||||
@@ -372,8 +363,7 @@ class RecomputeScheduler(Scheduler):
|
||||
num_external_computed_tokens = ext_tokens
|
||||
|
||||
# Total computed tokens (local + external).
|
||||
num_computed_tokens = (num_new_local_computed_tokens +
|
||||
num_external_computed_tokens)
|
||||
num_computed_tokens = num_new_local_computed_tokens + num_external_computed_tokens
|
||||
else:
|
||||
# KVTransfer: WAITING reqs have num_computed_tokens > 0
|
||||
# after async KV recvs are completed.
|
||||
@@ -401,8 +391,7 @@ class RecomputeScheduler(Scheduler):
|
||||
|
||||
# chunked prefill has to be enabled explicitly to allow
|
||||
# pooling requests to be chunked
|
||||
if (not self.scheduler_config.enable_chunked_prefill
|
||||
and num_new_tokens > token_budget):
|
||||
if not self.scheduler_config.enable_chunked_prefill and num_new_tokens > token_budget:
|
||||
# If chunked_prefill is disabled,
|
||||
# we can stop the scheduling here.
|
||||
break
|
||||
@@ -433,9 +422,7 @@ class RecomputeScheduler(Scheduler):
|
||||
# extra block gets allocated which
|
||||
# creates a mismatch between the number
|
||||
# of local and remote blocks.
|
||||
effective_lookahead_tokens = (0 if request.num_computed_tokens
|
||||
== 0 else
|
||||
self.num_lookahead_tokens)
|
||||
effective_lookahead_tokens = 0 if request.num_computed_tokens == 0 else self.num_lookahead_tokens
|
||||
|
||||
# Determine if we need to allocate cross-attention blocks.
|
||||
if self.is_encoder_decoder and request.has_encoder_inputs:
|
||||
@@ -443,8 +430,7 @@ class RecomputeScheduler(Scheduler):
|
||||
# always padded to the maximum length. If we support other
|
||||
# encoder-decoder models, this will need to be updated if we
|
||||
# want to only allocate what is needed.
|
||||
num_encoder_tokens = (
|
||||
self.scheduler_config.max_num_encoder_input_tokens)
|
||||
num_encoder_tokens = self.scheduler_config.max_num_encoder_input_tokens
|
||||
else:
|
||||
num_encoder_tokens = 0
|
||||
|
||||
@@ -488,20 +474,17 @@ class RecomputeScheduler(Scheduler):
|
||||
req_index += 1
|
||||
self.running.append(request)
|
||||
if self.log_stats:
|
||||
request.record_event(EngineCoreEventType.SCHEDULED,
|
||||
scheduled_timestamp)
|
||||
request.record_event(EngineCoreEventType.SCHEDULED, scheduled_timestamp)
|
||||
if request.status == RequestStatus.WAITING:
|
||||
scheduled_new_reqs.append(request)
|
||||
elif request.status == RequestStatus.PREEMPTED:
|
||||
scheduled_resumed_reqs.append(request)
|
||||
else:
|
||||
raise RuntimeError(
|
||||
f"Invalid request status: {request.status}")
|
||||
raise RuntimeError(f"Invalid request status: {request.status}")
|
||||
|
||||
if self.lora_config and request.lora_request:
|
||||
scheduled_loras.add(request.lora_request.lora_int_id)
|
||||
req_to_new_blocks[request.request_id] = (
|
||||
self.kv_cache_manager.get_blocks(request.request_id))
|
||||
req_to_new_blocks[request.request_id] = self.kv_cache_manager.get_blocks(request.request_id)
|
||||
num_scheduled_tokens[request.request_id] = num_new_tokens
|
||||
token_budget -= num_new_tokens
|
||||
request.status = RequestStatus.RUNNING
|
||||
@@ -511,8 +494,7 @@ class RecomputeScheduler(Scheduler):
|
||||
request.num_cached_tokens = num_computed_tokens
|
||||
# Encoder-related.
|
||||
if encoder_inputs_to_schedule:
|
||||
scheduled_encoder_inputs[request.request_id] = (
|
||||
encoder_inputs_to_schedule)
|
||||
scheduled_encoder_inputs[request.request_id] = encoder_inputs_to_schedule
|
||||
# Allocate the encoder cache.
|
||||
for i in encoder_inputs_to_schedule:
|
||||
self.encoder_cache_manager.allocate(request, i)
|
||||
@@ -522,8 +504,7 @@ class RecomputeScheduler(Scheduler):
|
||||
for i in external_load_encoder_input:
|
||||
self.encoder_cache_manager.allocate(request, i)
|
||||
if self.ec_connector is not None:
|
||||
self.ec_connector.update_state_after_alloc(
|
||||
request, i)
|
||||
self.ec_connector.update_state_after_alloc(request, i)
|
||||
# Put back any skipped requests at the head of the waiting queue
|
||||
if skipped_waiting_requests:
|
||||
self.waiting.prepend_requests(skipped_waiting_requests)
|
||||
@@ -537,20 +518,15 @@ class RecomputeScheduler(Scheduler):
|
||||
# Since some requests in the RUNNING queue may not be scheduled in
|
||||
# this step, the total number of scheduled requests can be smaller than
|
||||
# len(self.running).
|
||||
assert len(scheduled_new_reqs) + len(scheduled_resumed_reqs) + len(
|
||||
scheduled_running_reqs) <= len(self.running)
|
||||
assert len(scheduled_new_reqs) + len(scheduled_resumed_reqs) + len(scheduled_running_reqs) <= len(self.running)
|
||||
|
||||
# Get the longest common prefix among all requests in the running queue.
|
||||
# This can be potentially used for cascade attention.
|
||||
num_common_prefix_blocks = [0] * len(
|
||||
self.kv_cache_config.kv_cache_groups)
|
||||
with record_function_or_nullcontext(
|
||||
"schedule: get_num_common_prefix_blocks"):
|
||||
num_common_prefix_blocks = [0] * len(self.kv_cache_config.kv_cache_groups)
|
||||
with record_function_or_nullcontext("schedule: get_num_common_prefix_blocks"):
|
||||
if self.running:
|
||||
any_request = self.running[0]
|
||||
num_common_prefix_blocks = (
|
||||
self.kv_cache_manager.get_num_common_prefix_blocks(
|
||||
any_request.request_id))
|
||||
num_common_prefix_blocks = self.kv_cache_manager.get_num_common_prefix_blocks(any_request.request_id)
|
||||
|
||||
# Construct the scheduler output.
|
||||
if self.use_v2_model_runner:
|
||||
@@ -561,17 +537,16 @@ class RecomputeScheduler(Scheduler):
|
||||
req,
|
||||
req_to_new_blocks[req.request_id].get_block_ids(),
|
||||
req._all_token_ids,
|
||||
) for req in scheduled_new_reqs
|
||||
)
|
||||
for req in scheduled_new_reqs
|
||||
]
|
||||
else:
|
||||
new_reqs_data = [
|
||||
NewRequestData.from_request(
|
||||
req, req_to_new_blocks[req.request_id].get_block_ids())
|
||||
NewRequestData.from_request(req, req_to_new_blocks[req.request_id].get_block_ids())
|
||||
for req in scheduled_new_reqs
|
||||
]
|
||||
|
||||
with record_function_or_nullcontext(
|
||||
"schedule: make_cached_request_data"):
|
||||
with record_function_or_nullcontext("schedule: make_cached_request_data"):
|
||||
cached_reqs_data = self._make_cached_request_data(
|
||||
scheduled_running_reqs,
|
||||
scheduled_resumed_reqs,
|
||||
@@ -592,15 +567,13 @@ class RecomputeScheduler(Scheduler):
|
||||
scheduled_spec_decode_tokens=scheduled_spec_decode_tokens,
|
||||
scheduled_encoder_inputs=scheduled_encoder_inputs,
|
||||
num_common_prefix_blocks=num_common_prefix_blocks,
|
||||
preempted_req_ids={req.request_id
|
||||
for req in preempted_reqs},
|
||||
preempted_req_ids={req.request_id for req in preempted_reqs},
|
||||
# finished_req_ids is an existing state in the scheduler,
|
||||
# instead of being newly scheduled in this step.
|
||||
# It contains the request IDs that are finished in between
|
||||
# the previous and the current steps.
|
||||
finished_req_ids=self.finished_req_ids,
|
||||
free_encoder_mm_hashes=self.encoder_cache_manager.
|
||||
get_freed_mm_hashes(),
|
||||
free_encoder_mm_hashes=self.encoder_cache_manager.get_freed_mm_hashes(),
|
||||
recomputed_reqs=recomputed_reqs,
|
||||
)
|
||||
|
||||
@@ -609,14 +582,12 @@ class RecomputeScheduler(Scheduler):
|
||||
# 2. Wrap up all the KV cache load / save ops into an opaque object
|
||||
# 3. Clear the internal states of the connector
|
||||
if self.connector is not None:
|
||||
meta: KVConnectorMetadata = self.connector.build_connector_meta(
|
||||
scheduler_output)
|
||||
meta: KVConnectorMetadata = self.connector.build_connector_meta(scheduler_output)
|
||||
scheduler_output.kv_connector_metadata = meta
|
||||
|
||||
# Build the connector meta for ECConnector
|
||||
if self.ec_connector is not None:
|
||||
ec_meta: ECConnectorMetadata = self.ec_connector.build_connector_meta(
|
||||
scheduler_output)
|
||||
ec_meta: ECConnectorMetadata = self.ec_connector.build_connector_meta(scheduler_output)
|
||||
scheduler_output.ec_connector_metadata = ec_meta
|
||||
|
||||
with record_function_or_nullcontext("schedule: update_after_schedule"):
|
||||
@@ -639,8 +610,8 @@ class RecomputeScheduler(Scheduler):
|
||||
outputs: dict[int, list[EngineCoreOutput]] = defaultdict(list)
|
||||
spec_decoding_stats: SpecDecodingStats | None = None
|
||||
kv_connector_stats: KVConnectorStats | None = (
|
||||
kv_connector_output.kv_connector_stats
|
||||
if kv_connector_output else None)
|
||||
kv_connector_output.kv_connector_stats if kv_connector_output else None
|
||||
)
|
||||
if kv_connector_stats and self.connector:
|
||||
kv_stats = self.connector.get_kv_connector_stats()
|
||||
if kv_stats:
|
||||
@@ -651,8 +622,7 @@ class RecomputeScheduler(Scheduler):
|
||||
# These blocks contain externally computed tokens that failed to
|
||||
# load. Identify affected requests and adjust their computed token
|
||||
# count to trigger recomputation of the invalid blocks.
|
||||
failed_kv_load_req_ids = self._handle_invalid_blocks(
|
||||
kv_connector_output.invalid_block_ids)
|
||||
failed_kv_load_req_ids = self._handle_invalid_blocks(kv_connector_output.invalid_block_ids)
|
||||
|
||||
# return recomputed requests as EngineCoreOutput
|
||||
if scheduler_output.recomputed_reqs is not None:
|
||||
@@ -663,7 +633,8 @@ class RecomputeScheduler(Scheduler):
|
||||
finish_reason=FinishReason.STOP,
|
||||
new_token_ids=[req_info.output_token_ids[-1]],
|
||||
stop_reason="recomputed",
|
||||
))
|
||||
)
|
||||
)
|
||||
|
||||
# NOTE(woosuk): As len(num_scheduled_tokens) can be up to 1K or more,
|
||||
# the below loop can be a performance bottleneck. We should do our best
|
||||
@@ -683,11 +654,9 @@ class RecomputeScheduler(Scheduler):
|
||||
continue
|
||||
|
||||
req_index = model_runner_output.req_id_to_index[req_id]
|
||||
generated_token_ids = (sampled_token_ids[req_index]
|
||||
if sampled_token_ids else [])
|
||||
generated_token_ids = sampled_token_ids[req_index] if sampled_token_ids else []
|
||||
|
||||
scheduled_spec_token_ids = (
|
||||
scheduler_output.scheduled_spec_decode_tokens.get(req_id))
|
||||
scheduled_spec_token_ids = scheduler_output.scheduled_spec_decode_tokens.get(req_id)
|
||||
if scheduled_spec_token_ids:
|
||||
num_draft_tokens = len(scheduled_spec_token_ids)
|
||||
num_accepted = len(generated_token_ids) - 1
|
||||
@@ -717,15 +686,13 @@ class RecomputeScheduler(Scheduler):
|
||||
|
||||
# Check for stop and update request status.
|
||||
if new_token_ids:
|
||||
new_token_ids, stopped = self._update_request_with_output(
|
||||
request, new_token_ids)
|
||||
new_token_ids, stopped = self._update_request_with_output(request, new_token_ids)
|
||||
|
||||
# Stop checking for pooler models.
|
||||
pooler_output = None
|
||||
if pooler_outputs:
|
||||
pooler_output = pooler_outputs[req_index]
|
||||
stopped = check_stop(request, self.max_model_len,
|
||||
pooler_output)
|
||||
stopped = check_stop(request, self.max_model_len, pooler_output)
|
||||
|
||||
if stopped:
|
||||
kv_transfer_params = self._free_request(request)
|
||||
@@ -735,19 +702,14 @@ class RecomputeScheduler(Scheduler):
|
||||
stopped_preempted_reqs.add(request)
|
||||
|
||||
# Extract sample logprobs if needed.
|
||||
if (request.sampling_params is not None
|
||||
and request.sampling_params.logprobs is not None
|
||||
and logprobs):
|
||||
new_logprobs = logprobs.slice_request(req_index,
|
||||
len(new_token_ids))
|
||||
if request.sampling_params is not None and request.sampling_params.logprobs is not None and logprobs:
|
||||
new_logprobs = logprobs.slice_request(req_index, len(new_token_ids))
|
||||
|
||||
if new_token_ids and self.structured_output_manager.should_advance(
|
||||
request):
|
||||
if new_token_ids and self.structured_output_manager.should_advance(request):
|
||||
struct_output_request = request.structured_output_request
|
||||
assert struct_output_request is not None
|
||||
assert struct_output_request.grammar is not None
|
||||
struct_output_request.grammar.accept_tokens(
|
||||
req_id, new_token_ids)
|
||||
struct_output_request.grammar.accept_tokens(req_id, new_token_ids)
|
||||
|
||||
if num_nans_in_logits is not None and req_id in num_nans_in_logits:
|
||||
request.num_nans_in_logits = num_nans_in_logits[req_id]
|
||||
@@ -770,7 +732,8 @@ class RecomputeScheduler(Scheduler):
|
||||
trace_headers=request.trace_headers,
|
||||
num_cached_tokens=request.num_cached_tokens,
|
||||
num_nans_in_logits=request.num_nans_in_logits,
|
||||
))
|
||||
)
|
||||
)
|
||||
else:
|
||||
# Invariant: EngineCore returns no partial prefill outputs.
|
||||
assert not prompt_logprobs_tensors
|
||||
@@ -805,10 +768,7 @@ class RecomputeScheduler(Scheduler):
|
||||
|
||||
# Create EngineCoreOutputs for all clients that have requests with
|
||||
# outputs in this step.
|
||||
engine_core_outputs = {
|
||||
client_index: EngineCoreOutputs(outputs=outs)
|
||||
for client_index, outs in outputs.items()
|
||||
}
|
||||
engine_core_outputs = {client_index: EngineCoreOutputs(outputs=outs) for client_index, outs in outputs.items()}
|
||||
|
||||
finished_req_ids = self.finished_req_ids_dict
|
||||
if finished_req_ids:
|
||||
@@ -819,12 +779,10 @@ class RecomputeScheduler(Scheduler):
|
||||
if (eco := engine_core_outputs.get(client_index)) is not None:
|
||||
eco.finished_requests = finished_set
|
||||
else:
|
||||
engine_core_outputs[client_index] = EngineCoreOutputs(
|
||||
finished_requests=finished_set)
|
||||
engine_core_outputs[client_index] = EngineCoreOutputs(finished_requests=finished_set)
|
||||
finished_req_ids.clear()
|
||||
|
||||
if (stats := self.make_stats(spec_decoding_stats,
|
||||
kv_connector_stats)) is not None:
|
||||
if (stats := self.make_stats(spec_decoding_stats, kv_connector_stats)) is not None:
|
||||
# Return stats to only one of the front-ends.
|
||||
if (eco := next(iter(engine_core_outputs.values()), None)) is None:
|
||||
# We must return the stats even if there are no request
|
||||
@@ -836,6 +794,5 @@ class RecomputeScheduler(Scheduler):
|
||||
|
||||
|
||||
class AsyncRecomputeScheduler(AsyncScheduler, RecomputeScheduler):
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
@@ -16,7 +16,6 @@
|
||||
#
|
||||
import os
|
||||
import time
|
||||
from typing import Optional
|
||||
|
||||
import pandas as pd
|
||||
from vllm.config import VllmConfig
|
||||
@@ -25,8 +24,7 @@ from vllm.logger import logger
|
||||
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry
|
||||
from vllm.v1.core.kv_cache_manager import KVCacheBlocks
|
||||
from vllm.v1.core.sched.output import NewRequestData, SchedulerOutput
|
||||
from vllm.v1.core.sched.request_queue import (SchedulingPolicy,
|
||||
create_request_queue)
|
||||
from vllm.v1.core.sched.request_queue import SchedulingPolicy, create_request_queue
|
||||
from vllm.v1.core.sched.scheduler import Scheduler
|
||||
from vllm.v1.engine import EngineCoreEventType
|
||||
from vllm.v1.kv_cache_interface import KVCacheConfig
|
||||
@@ -43,8 +41,9 @@ class BudgetRefiner:
|
||||
if not self.enabled:
|
||||
return
|
||||
logger.info(
|
||||
"Dynamic batch is enabled with SLO limit: {}, and chunked prefill is forced to be activated because dynamic batch relies on it"
|
||||
.format(str(slo_limit)))
|
||||
"Dynamic batch is enabled with SLO limit: {}, and chunked prefill is "
|
||||
"forced to be activated because dynamic batch relies on it".format(str(slo_limit))
|
||||
)
|
||||
self.lookup: dict[tuple[int, int], int] = {}
|
||||
self.context_keys: set[int] = set()
|
||||
self.dnum_keys: set[int] = set()
|
||||
@@ -61,19 +60,20 @@ class BudgetRefiner:
|
||||
"The dynamic batching feature requires the lookup table "
|
||||
"'profile_table.csv', but it was not found at '%s'. "
|
||||
"Please download the corresponding table file.",
|
||||
table_file_path)
|
||||
table_file_path,
|
||||
)
|
||||
self.enabled = False
|
||||
return
|
||||
else:
|
||||
df = pd.read_csv(table_file_path)
|
||||
grouped = df.groupby(['ctx_len', 'd_num'])
|
||||
grouped = df.groupby(["ctx_len", "d_num"])
|
||||
for (ctx_len, d_num), group in grouped:
|
||||
valid = group[group['cost'] <= slo_limit]
|
||||
valid = group[group["cost"] <= slo_limit]
|
||||
if not valid.empty:
|
||||
max_row = valid.loc[valid['chunk_size'].idxmax()]
|
||||
max_row = valid.loc[valid["chunk_size"].idxmax()]
|
||||
assert isinstance(ctx_len, int), "ctx_len must be an integer"
|
||||
assert isinstance(d_num, int), "d_num must be an integer"
|
||||
self.lookup[(ctx_len, d_num)] = int(max_row['chunk_size'])
|
||||
self.lookup[(ctx_len, d_num)] = int(max_row["chunk_size"])
|
||||
self.context_keys.add(ctx_len)
|
||||
self.dnum_keys.add(d_num)
|
||||
self.context_keys = set(sorted(self.context_keys))
|
||||
@@ -97,7 +97,10 @@ class BudgetRefiner:
|
||||
logger.warn(f"Table miss for ctx,dnum{aligned_ctx, aligned_dnum}")
|
||||
budget = self.default_budget
|
||||
# For debug.
|
||||
# logger.info(f"budget {budget}, ctx,dnum {aligned_ctx, aligned_dnum}, raw ctx,dnum {num_deocde_tokens, num_decode}")
|
||||
# logger.info(
|
||||
# f"budget {budget}, ctx,dnum {aligned_ctx, aligned_dnum}, "
|
||||
# f"raw ctx,dnum {num_deocde_tokens, num_decode}"
|
||||
# )
|
||||
return budget
|
||||
|
||||
def refine_budget(self, running_request, budget):
|
||||
@@ -106,9 +109,8 @@ class BudgetRefiner:
|
||||
return budget
|
||||
# assume all running request will be scheduled.
|
||||
num_decode_token_lst = [
|
||||
req.num_tokens_with_spec \
|
||||
for req in running_request \
|
||||
if req.num_computed_tokens >= req.num_prompt_tokens ]
|
||||
req.num_tokens_with_spec for req in running_request if req.num_computed_tokens >= req.num_prompt_tokens
|
||||
]
|
||||
num_decode = len(num_decode_token_lst)
|
||||
if num_decode <= 0:
|
||||
return budget
|
||||
@@ -125,18 +127,25 @@ class SchedulerDynamicBatch(Scheduler):
|
||||
vllm_config: VllmConfig,
|
||||
kv_cache_config: KVCacheConfig,
|
||||
structured_output_manager: StructuredOutputManager,
|
||||
block_size: Optional[int] = None,
|
||||
block_size: int | None = None,
|
||||
mm_registry: MultiModalRegistry = MULTIMODAL_REGISTRY,
|
||||
include_finished_set: bool = False,
|
||||
log_stats: bool = False,
|
||||
) -> None:
|
||||
super().__init__(vllm_config, kv_cache_config,
|
||||
structured_output_manager, block_size, mm_registry,
|
||||
include_finished_set, log_stats)
|
||||
super().__init__(
|
||||
vllm_config,
|
||||
kv_cache_config,
|
||||
structured_output_manager,
|
||||
block_size,
|
||||
mm_registry,
|
||||
include_finished_set,
|
||||
log_stats,
|
||||
)
|
||||
self.running: list[Request] = []
|
||||
self.budget_refiner = BudgetRefiner(
|
||||
default_budget=self.scheduler_config.max_num_batched_tokens,
|
||||
slo_limit=self.scheduler_config.SLO_limits_for_dynamic_batch)
|
||||
slo_limit=self.scheduler_config.SLO_limits_for_dynamic_batch,
|
||||
)
|
||||
|
||||
def schedule(self) -> SchedulerOutput:
|
||||
# NOTE: This scheduling algorithm is developed based on the "super.schedule()"
|
||||
@@ -159,20 +168,13 @@ class SchedulerDynamicBatch(Scheduler):
|
||||
req_to_new_blocks: dict[str, KVCacheBlocks] = {}
|
||||
num_scheduled_tokens: dict[str, int] = {}
|
||||
token_budget = self.max_num_scheduled_tokens
|
||||
token_budget = self.budget_refiner.refine_budget(
|
||||
self.running, token_budget)
|
||||
token_budget = self.budget_refiner.refine_budget(self.running, token_budget)
|
||||
|
||||
# NOTE: We move the prefill requests to the end of the self.running
|
||||
# list and keep the relative order unchanged. This rearrangement makes this
|
||||
# scheduling algorithm a strict decode-first chunked prefills.
|
||||
d_lst = [
|
||||
req for req in self.running
|
||||
if req.num_computed_tokens >= req.num_prompt_tokens
|
||||
]
|
||||
p_lst = [
|
||||
req for req in self.running
|
||||
if req.num_computed_tokens < req.num_prompt_tokens
|
||||
]
|
||||
d_lst = [req for req in self.running if req.num_computed_tokens >= req.num_prompt_tokens]
|
||||
p_lst = [req for req in self.running if req.num_computed_tokens < req.num_prompt_tokens]
|
||||
self.running = d_lst + p_lst
|
||||
|
||||
# Encoder-related.
|
||||
@@ -189,30 +191,26 @@ class SchedulerDynamicBatch(Scheduler):
|
||||
while req_index < len(self.running) and token_budget > 0:
|
||||
request = self.running[req_index]
|
||||
|
||||
num_new_tokens = (request.num_tokens_with_spec +
|
||||
request.num_output_placeholders -
|
||||
request.num_computed_tokens)
|
||||
if (0 < self.scheduler_config.long_prefill_token_threshold <
|
||||
num_new_tokens):
|
||||
num_new_tokens = (
|
||||
self.scheduler_config.long_prefill_token_threshold)
|
||||
request.num_tokens_with_spec + request.num_output_placeholders - request.num_computed_tokens
|
||||
)
|
||||
if 0 < self.scheduler_config.long_prefill_token_threshold < num_new_tokens:
|
||||
num_new_tokens = self.scheduler_config.long_prefill_token_threshold
|
||||
num_new_tokens = min(num_new_tokens, token_budget)
|
||||
|
||||
# Make sure the input position does not exceed the max model len.
|
||||
# This is necessary when using spec decoding.
|
||||
num_new_tokens = min(
|
||||
num_new_tokens,
|
||||
self.max_model_len - 1 - request.num_computed_tokens)
|
||||
num_new_tokens = min(num_new_tokens, self.max_model_len - 1 - request.num_computed_tokens)
|
||||
|
||||
# Schedule encoder inputs.
|
||||
encoder_inputs_to_schedule = None
|
||||
new_encoder_compute_budget = encoder_compute_budget
|
||||
if request.has_encoder_inputs:
|
||||
(encoder_inputs_to_schedule, num_new_tokens,
|
||||
new_encoder_compute_budget
|
||||
) = self._try_schedule_encoder_inputs(
|
||||
request, request.num_computed_tokens, num_new_tokens,
|
||||
encoder_compute_budget)
|
||||
(encoder_inputs_to_schedule, num_new_tokens, new_encoder_compute_budget) = (
|
||||
self._try_schedule_encoder_inputs(
|
||||
request, request.num_computed_tokens, num_new_tokens, encoder_compute_budget
|
||||
)
|
||||
)
|
||||
|
||||
if num_new_tokens == 0:
|
||||
# The request cannot be scheduled because one of the following
|
||||
@@ -231,9 +229,8 @@ class SchedulerDynamicBatch(Scheduler):
|
||||
|
||||
while True:
|
||||
new_blocks = self.kv_cache_manager.allocate_slots(
|
||||
request,
|
||||
num_new_tokens,
|
||||
num_lookahead_tokens=self.num_lookahead_tokens)
|
||||
request, num_new_tokens, num_lookahead_tokens=self.num_lookahead_tokens
|
||||
)
|
||||
if new_blocks is None:
|
||||
# The request cannot be scheduled.
|
||||
# Preempt the lowest-priority request.
|
||||
@@ -253,8 +250,7 @@ class SchedulerDynamicBatch(Scheduler):
|
||||
preempted_req.status = RequestStatus.PREEMPTED
|
||||
preempted_req.num_computed_tokens = 0
|
||||
if self.log_stats:
|
||||
preempted_req.record_event(
|
||||
EngineCoreEventType.PREEMPTED, scheduled_timestamp)
|
||||
preempted_req.record_event(EngineCoreEventType.PREEMPTED, scheduled_timestamp)
|
||||
|
||||
self.waiting.prepend_request(preempted_req)
|
||||
preempted_reqs.append(preempted_req)
|
||||
@@ -279,19 +275,15 @@ class SchedulerDynamicBatch(Scheduler):
|
||||
|
||||
# Speculative decode related.
|
||||
if request.spec_token_ids:
|
||||
num_scheduled_spec_tokens = (num_new_tokens +
|
||||
request.num_computed_tokens -
|
||||
request.num_tokens)
|
||||
num_scheduled_spec_tokens = num_new_tokens + request.num_computed_tokens - request.num_tokens
|
||||
if num_scheduled_spec_tokens > 0:
|
||||
# Trim spec_token_ids list to num_scheduled_spec_tokens.
|
||||
del request.spec_token_ids[num_scheduled_spec_tokens:]
|
||||
scheduled_spec_decode_tokens[request.request_id] = (
|
||||
request.spec_token_ids)
|
||||
scheduled_spec_decode_tokens[request.request_id] = request.spec_token_ids
|
||||
|
||||
# Encoder-related.
|
||||
if encoder_inputs_to_schedule:
|
||||
scheduled_encoder_inputs[request.request_id] = (
|
||||
encoder_inputs_to_schedule)
|
||||
scheduled_encoder_inputs[request.request_id] = encoder_inputs_to_schedule
|
||||
# Allocate the encoder cache.
|
||||
for i in encoder_inputs_to_schedule:
|
||||
self.encoder_cache_manager.allocate(request, i)
|
||||
@@ -301,8 +293,10 @@ class SchedulerDynamicBatch(Scheduler):
|
||||
scheduled_loras: set[int] = set()
|
||||
if self.lora_config:
|
||||
scheduled_loras = set(
|
||||
req.lora_request.lora_int_id for req in scheduled_running_reqs
|
||||
if req.lora_request and req.lora_request.lora_int_id > 0)
|
||||
req.lora_request.lora_int_id
|
||||
for req in scheduled_running_reqs
|
||||
if req.lora_request and req.lora_request.lora_int_id > 0
|
||||
)
|
||||
assert len(scheduled_loras) <= self.lora_config.max_loras
|
||||
|
||||
# Use a temporary RequestQueue to collect requests that need to be
|
||||
@@ -323,9 +317,7 @@ class SchedulerDynamicBatch(Scheduler):
|
||||
if is_ready:
|
||||
request.status = RequestStatus.WAITING
|
||||
else:
|
||||
logger.debug(
|
||||
"%s is still in WAITING_FOR_REMOTE_KVS state.",
|
||||
request.request_id)
|
||||
logger.debug("%s is still in WAITING_FOR_REMOTE_KVS state.", request.request_id)
|
||||
self.waiting.pop_request()
|
||||
skipped_waiting_requests.prepend_request(request)
|
||||
continue
|
||||
@@ -343,9 +335,14 @@ class SchedulerDynamicBatch(Scheduler):
|
||||
|
||||
# Check that adding the request still respects the max_loras
|
||||
# constraint.
|
||||
if (self.lora_config and request.lora_request and
|
||||
(len(scheduled_loras) == self.lora_config.max_loras and
|
||||
request.lora_request.lora_int_id not in scheduled_loras)):
|
||||
if (
|
||||
self.lora_config
|
||||
and request.lora_request
|
||||
and (
|
||||
len(scheduled_loras) == self.lora_config.max_loras
|
||||
and request.lora_request.lora_int_id not in scheduled_loras
|
||||
)
|
||||
):
|
||||
# Scheduling would exceed max_loras, skip.
|
||||
self.waiting.pop_request()
|
||||
skipped_waiting_requests.prepend_request(request)
|
||||
@@ -357,15 +354,15 @@ class SchedulerDynamicBatch(Scheduler):
|
||||
# Get already-cached tokens.
|
||||
if request.num_computed_tokens == 0:
|
||||
# Get locally-cached tokens.
|
||||
new_computed_blocks, num_new_local_computed_tokens = \
|
||||
self.kv_cache_manager.get_computed_blocks(
|
||||
request)
|
||||
new_computed_blocks, num_new_local_computed_tokens = self.kv_cache_manager.get_computed_blocks(
|
||||
request
|
||||
)
|
||||
|
||||
# Get externally-cached tokens if using a KVConnector.
|
||||
if self.connector is not None:
|
||||
num_external_computed_tokens, load_kv_async = (
|
||||
self.connector.get_num_new_matched_tokens(
|
||||
request, num_new_local_computed_tokens))
|
||||
num_external_computed_tokens, load_kv_async = self.connector.get_num_new_matched_tokens(
|
||||
request, num_new_local_computed_tokens
|
||||
)
|
||||
|
||||
if num_external_computed_tokens is None:
|
||||
# The request cannot be scheduled because
|
||||
@@ -376,13 +373,11 @@ class SchedulerDynamicBatch(Scheduler):
|
||||
continue
|
||||
|
||||
# Total computed tokens (local + external).
|
||||
num_computed_tokens = (num_new_local_computed_tokens +
|
||||
num_external_computed_tokens)
|
||||
num_computed_tokens = num_new_local_computed_tokens + num_external_computed_tokens
|
||||
# KVTransfer: WAITING reqs have num_computed_tokens > 0
|
||||
# after async KV recvs are completed.
|
||||
else:
|
||||
new_computed_blocks = (
|
||||
self.kv_cache_manager.create_empty_block_list())
|
||||
new_computed_blocks = self.kv_cache_manager.create_empty_block_list()
|
||||
num_new_local_computed_tokens = 0
|
||||
num_computed_tokens = request.num_computed_tokens
|
||||
|
||||
@@ -399,15 +394,12 @@ class SchedulerDynamicBatch(Scheduler):
|
||||
# `request.num_prompt_tokens` to consider the resumed
|
||||
# requests, which have output tokens.
|
||||
num_new_tokens = request.num_tokens - num_computed_tokens
|
||||
if (0 < self.scheduler_config.long_prefill_token_threshold
|
||||
< num_new_tokens):
|
||||
num_new_tokens = (
|
||||
self.scheduler_config.long_prefill_token_threshold)
|
||||
if 0 < self.scheduler_config.long_prefill_token_threshold < num_new_tokens:
|
||||
num_new_tokens = self.scheduler_config.long_prefill_token_threshold
|
||||
|
||||
# chunked prefill has to be enabled explicitly to allow
|
||||
# pooling requests to be chunked
|
||||
if not self.scheduler_config.enable_chunked_prefill and \
|
||||
num_new_tokens > token_budget:
|
||||
if not self.scheduler_config.enable_chunked_prefill and num_new_tokens > token_budget:
|
||||
self.waiting.pop_request()
|
||||
skipped_waiting_requests.prepend_request(request)
|
||||
continue
|
||||
@@ -417,11 +409,11 @@ class SchedulerDynamicBatch(Scheduler):
|
||||
|
||||
# Schedule encoder inputs.
|
||||
if request.has_encoder_inputs:
|
||||
(encoder_inputs_to_schedule, num_new_tokens,
|
||||
new_encoder_compute_budget,
|
||||
_) = self._try_schedule_encoder_inputs(
|
||||
request, num_computed_tokens, num_new_tokens,
|
||||
encoder_compute_budget)
|
||||
(encoder_inputs_to_schedule, num_new_tokens, new_encoder_compute_budget, _) = (
|
||||
self._try_schedule_encoder_inputs(
|
||||
request, num_computed_tokens, num_new_tokens, encoder_compute_budget
|
||||
)
|
||||
)
|
||||
if num_new_tokens == 0:
|
||||
# The request cannot be scheduled.
|
||||
break
|
||||
@@ -431,9 +423,7 @@ class SchedulerDynamicBatch(Scheduler):
|
||||
# extra block gets allocated which
|
||||
# creates a mismatch between the number
|
||||
# of local and remote blocks.
|
||||
effective_lookahead_tokens = (0 if request.num_computed_tokens
|
||||
== 0 else
|
||||
self.num_lookahead_tokens)
|
||||
effective_lookahead_tokens = 0 if request.num_computed_tokens == 0 else self.num_lookahead_tokens
|
||||
|
||||
# Determine if we need to allocate cross-attention blocks.
|
||||
if self.is_encoder_decoder and request.has_encoder_inputs:
|
||||
@@ -441,8 +431,7 @@ class SchedulerDynamicBatch(Scheduler):
|
||||
# always padded to the maximum length. If we support other
|
||||
# encoder-decoder models, this will need to be updated if we
|
||||
# want to only allocate what is needed.
|
||||
num_encoder_tokens =\
|
||||
self.scheduler_config.max_num_encoder_input_tokens
|
||||
num_encoder_tokens = self.scheduler_config.max_num_encoder_input_tokens
|
||||
else:
|
||||
num_encoder_tokens = 0
|
||||
|
||||
@@ -484,20 +473,17 @@ class SchedulerDynamicBatch(Scheduler):
|
||||
req_index += 1
|
||||
self.running.append(request)
|
||||
if self.log_stats:
|
||||
request.record_event(EngineCoreEventType.SCHEDULED,
|
||||
scheduled_timestamp)
|
||||
request.record_event(EngineCoreEventType.SCHEDULED, scheduled_timestamp)
|
||||
if request.status == RequestStatus.WAITING:
|
||||
scheduled_new_reqs.append(request)
|
||||
elif request.status == RequestStatus.PREEMPTED:
|
||||
scheduled_resumed_reqs.append(request)
|
||||
else:
|
||||
raise RuntimeError(
|
||||
f"Invalid request status: {request.status}")
|
||||
raise RuntimeError(f"Invalid request status: {request.status}")
|
||||
|
||||
if self.lora_config and request.lora_request:
|
||||
scheduled_loras.add(request.lora_request.lora_int_id)
|
||||
req_to_new_blocks[request.request_id] = (
|
||||
self.kv_cache_manager.get_blocks(request.request_id))
|
||||
req_to_new_blocks[request.request_id] = self.kv_cache_manager.get_blocks(request.request_id)
|
||||
num_scheduled_tokens[request.request_id] = num_new_tokens
|
||||
token_budget -= num_new_tokens
|
||||
request.status = RequestStatus.RUNNING
|
||||
@@ -507,8 +493,7 @@ class SchedulerDynamicBatch(Scheduler):
|
||||
request.num_cached_tokens = num_computed_tokens
|
||||
# Encoder-related.
|
||||
if encoder_inputs_to_schedule:
|
||||
scheduled_encoder_inputs[request.request_id] = (
|
||||
encoder_inputs_to_schedule)
|
||||
scheduled_encoder_inputs[request.request_id] = encoder_inputs_to_schedule
|
||||
# Allocate the encoder cache.
|
||||
for i in encoder_inputs_to_schedule:
|
||||
self.encoder_cache_manager.allocate(request, i)
|
||||
@@ -526,22 +511,17 @@ class SchedulerDynamicBatch(Scheduler):
|
||||
# Since some requests in the RUNNING queue may not be scheduled in
|
||||
# this step, the total number of scheduled requests can be smaller than
|
||||
# len(self.running).
|
||||
assert (len(scheduled_new_reqs) + len(scheduled_resumed_reqs) +
|
||||
len(scheduled_running_reqs) <= len(self.running))
|
||||
assert len(scheduled_new_reqs) + len(scheduled_resumed_reqs) + len(scheduled_running_reqs) <= len(self.running)
|
||||
|
||||
# Get the longest common prefix among all requests in the running queue.
|
||||
# This can be potentially used for cascade attention.
|
||||
num_common_prefix_blocks = [0] * len(
|
||||
self.kv_cache_config.kv_cache_groups)
|
||||
num_common_prefix_blocks = [0] * len(self.kv_cache_config.kv_cache_groups)
|
||||
if self.running:
|
||||
any_request = self.running[0]
|
||||
num_common_prefix_blocks = (
|
||||
self.kv_cache_manager.get_num_common_prefix_blocks(
|
||||
any_request.request_id))
|
||||
num_common_prefix_blocks = self.kv_cache_manager.get_num_common_prefix_blocks(any_request.request_id)
|
||||
# Construct the scheduler output.
|
||||
new_reqs_data = [
|
||||
NewRequestData.from_request(
|
||||
req, req_to_new_blocks[req.request_id].get_block_ids())
|
||||
NewRequestData.from_request(req, req_to_new_blocks[req.request_id].get_block_ids())
|
||||
for req in scheduled_new_reqs
|
||||
]
|
||||
cached_reqs_data = self._make_cached_request_data(
|
||||
@@ -564,8 +544,7 @@ class SchedulerDynamicBatch(Scheduler):
|
||||
# It contains the request IDs that are finished in between
|
||||
# the previous and the current steps.
|
||||
finished_req_ids=self.finished_req_ids,
|
||||
free_encoder_mm_hashes=self.encoder_cache_manager.
|
||||
get_freed_mm_hashes(),
|
||||
free_encoder_mm_hashes=self.encoder_cache_manager.get_freed_mm_hashes(),
|
||||
)
|
||||
|
||||
# NOTE(Kuntai): this function is designed for multiple purposes:
|
||||
|
||||
@@ -14,61 +14,50 @@
|
||||
# limitations under the License.
|
||||
# This file is a part of the vllm-ascend project.
|
||||
#
|
||||
from typing import List, Optional
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
from vllm.distributed.device_communicators.base_device_communicator import \
|
||||
DeviceCommunicatorBase
|
||||
from vllm.distributed.device_communicators.base_device_communicator import DeviceCommunicatorBase
|
||||
|
||||
|
||||
class NPUCommunicator(DeviceCommunicatorBase):
|
||||
|
||||
def __init__(self,
|
||||
def __init__(
|
||||
self,
|
||||
cpu_group: dist.ProcessGroup,
|
||||
device: Optional[torch.device] = None,
|
||||
device_group: Optional[dist.ProcessGroup] = None,
|
||||
unique_name: str = ""):
|
||||
device: torch.device | None = None,
|
||||
device_group: dist.ProcessGroup | None = None,
|
||||
unique_name: str = "",
|
||||
):
|
||||
super().__init__(cpu_group, device, device_group, unique_name)
|
||||
# TODO(hz): Refer to CudaCommunicator's implementation to integrate PyHcclCommunicator
|
||||
# init device according to rank
|
||||
self.device = torch.npu.current_device()
|
||||
|
||||
def all_to_all(self,
|
||||
def all_to_all(
|
||||
self,
|
||||
input_: torch.Tensor,
|
||||
scatter_dim: int = 0,
|
||||
gather_dim: int = -1,
|
||||
scatter_sizes: Optional[List[int]] = None,
|
||||
gather_sizes: Optional[List[int]] = None) -> torch.Tensor:
|
||||
|
||||
scatter_sizes: list[int] | None = None,
|
||||
gather_sizes: list[int] | None = None,
|
||||
) -> torch.Tensor:
|
||||
if scatter_dim < 0:
|
||||
scatter_dim += input_.dim()
|
||||
if gather_dim < 0:
|
||||
gather_dim += input_.dim()
|
||||
|
||||
if scatter_sizes is not None and gather_sizes is not None:
|
||||
input_list = [
|
||||
t.contiguous()
|
||||
for t in torch.split(input_, scatter_sizes, scatter_dim)
|
||||
]
|
||||
input_list = [t.contiguous() for t in torch.split(input_, scatter_sizes, scatter_dim)]
|
||||
output_list = []
|
||||
tensor_shape_base = input_list[self.rank].size()
|
||||
for i in range(self.world_size):
|
||||
tensor_shape = list(tensor_shape_base)
|
||||
tensor_shape[gather_dim] = gather_sizes[i]
|
||||
output_list.append(
|
||||
torch.empty(tensor_shape,
|
||||
dtype=input_.dtype,
|
||||
device=input_.device))
|
||||
output_list.append(torch.empty(tensor_shape, dtype=input_.dtype, device=input_.device))
|
||||
|
||||
else:
|
||||
input_list = [
|
||||
t.contiguous() for t in torch.tensor_split(
|
||||
input_, self.world_size, scatter_dim)
|
||||
]
|
||||
output_list = [
|
||||
torch.empty_like(input_list[i]) for i in range(self.world_size)
|
||||
]
|
||||
input_list = [t.contiguous() for t in torch.tensor_split(input_, self.world_size, scatter_dim)]
|
||||
output_list = [torch.empty_like(input_list[i]) for i in range(self.world_size)]
|
||||
|
||||
dist.all_to_all(output_list, input_list, group=self.device_group)
|
||||
output_tensor = torch.cat(output_list, dim=gather_dim).contiguous()
|
||||
|
||||
@@ -15,7 +15,6 @@
|
||||
# limitations under the License.
|
||||
#
|
||||
|
||||
from typing import Optional, Union
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
@@ -24,18 +23,23 @@ from vllm.distributed.utils import StatelessProcessGroup
|
||||
from vllm.logger import logger
|
||||
|
||||
from vllm_ascend.distributed.device_communicators.pyhccl_wrapper import (
|
||||
HCCLLibrary, aclrtStream_t, buffer_type, hcclComm_t, hcclDataTypeEnum,
|
||||
hcclRedOpTypeEnum, hcclUniqueId)
|
||||
HCCLLibrary,
|
||||
aclrtStream_t,
|
||||
buffer_type,
|
||||
hcclComm_t,
|
||||
hcclDataTypeEnum,
|
||||
hcclRedOpTypeEnum,
|
||||
hcclUniqueId,
|
||||
)
|
||||
from vllm_ascend.utils import current_stream
|
||||
|
||||
|
||||
class PyHcclCommunicator:
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
group: Union[ProcessGroup, StatelessProcessGroup],
|
||||
device: Union[int, str, torch.device],
|
||||
library_path: Optional[str] = None,
|
||||
group: ProcessGroup | StatelessProcessGroup,
|
||||
device: int | str | torch.device,
|
||||
library_path: str | None = None,
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
@@ -52,7 +56,8 @@ class PyHcclCommunicator:
|
||||
if not isinstance(group, StatelessProcessGroup):
|
||||
assert dist.is_initialized()
|
||||
assert dist.get_backend(group) != dist.Backend.HCCL, (
|
||||
"PyHcclCommunicator should be attached to a non-HCCL group.")
|
||||
"PyHcclCommunicator should be attached to a non-HCCL group."
|
||||
)
|
||||
# note: this rank is the rank in the group
|
||||
self.rank = dist.get_rank(group)
|
||||
self.world_size = dist.get_world_size(group)
|
||||
@@ -113,8 +118,7 @@ class PyHcclCommunicator:
|
||||
# `torch.npu.device` is a context manager that changes the
|
||||
# current npu device to the specified one
|
||||
with torch.npu.device(device):
|
||||
self.comm: hcclComm_t = self.hccl.hcclCommInitRank(
|
||||
self.world_size, self.unique_id, self.rank)
|
||||
self.comm: hcclComm_t = self.hccl.hcclCommInitRank(self.world_size, self.unique_id, self.rank)
|
||||
|
||||
stream = current_stream()
|
||||
# A small all_reduce for warmup.
|
||||
@@ -123,43 +127,48 @@ class PyHcclCommunicator:
|
||||
stream.synchronize()
|
||||
del data
|
||||
|
||||
def all_reduce(self,
|
||||
in_tensor: torch.Tensor,
|
||||
op: ReduceOp = ReduceOp.SUM,
|
||||
stream=None) -> torch.Tensor:
|
||||
def all_reduce(self, in_tensor: torch.Tensor, op: ReduceOp = ReduceOp.SUM, stream=None) -> torch.Tensor:
|
||||
if self.disabled:
|
||||
return None
|
||||
# hccl communicator created on a specific device
|
||||
# will only work on tensors on the same device
|
||||
# otherwise it will cause "illegal memory access"
|
||||
assert in_tensor.device == self.device, (
|
||||
f"this hccl communicator is created to work on {self.device}, "
|
||||
f"but the input tensor is on {in_tensor.device}")
|
||||
f"this hccl communicator is created to work on {self.device}, but the input tensor is on {in_tensor.device}"
|
||||
)
|
||||
|
||||
out_tensor = torch.empty_like(in_tensor)
|
||||
|
||||
if stream is None:
|
||||
stream = current_stream()
|
||||
self.hccl.hcclAllReduce(buffer_type(in_tensor.data_ptr()),
|
||||
self.hccl.hcclAllReduce(
|
||||
buffer_type(in_tensor.data_ptr()),
|
||||
buffer_type(out_tensor.data_ptr()),
|
||||
in_tensor.numel(),
|
||||
hcclDataTypeEnum.from_torch(in_tensor.dtype),
|
||||
hcclRedOpTypeEnum.from_torch(op), self.comm,
|
||||
aclrtStream_t(stream.npu_stream))
|
||||
hcclRedOpTypeEnum.from_torch(op),
|
||||
self.comm,
|
||||
aclrtStream_t(stream.npu_stream),
|
||||
)
|
||||
return out_tensor
|
||||
|
||||
def broadcast(self, tensor: torch.Tensor, src: int, stream=None):
|
||||
if self.disabled:
|
||||
return
|
||||
assert tensor.device == self.device, (
|
||||
f"this hccl communicator is created to work on {self.device}, "
|
||||
f"but the input tensor is on {tensor.device}")
|
||||
f"this hccl communicator is created to work on {self.device}, but the input tensor is on {tensor.device}"
|
||||
)
|
||||
if stream is None:
|
||||
stream = current_stream()
|
||||
if src == self.rank:
|
||||
buffer = buffer_type(tensor.data_ptr())
|
||||
else:
|
||||
buffer = buffer_type(tensor.data_ptr())
|
||||
self.hccl.hcclBroadcast(buffer, tensor.numel(),
|
||||
hcclDataTypeEnum.from_torch(tensor.dtype), src,
|
||||
self.comm, aclrtStream_t(stream.npu_stream))
|
||||
self.hccl.hcclBroadcast(
|
||||
buffer,
|
||||
tensor.numel(),
|
||||
hcclDataTypeEnum.from_torch(tensor.dtype),
|
||||
src,
|
||||
self.comm,
|
||||
aclrtStream_t(stream.npu_stream),
|
||||
)
|
||||
|
||||
@@ -18,7 +18,7 @@
|
||||
import ctypes
|
||||
import platform
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Dict, List, Optional
|
||||
from typing import Any
|
||||
|
||||
import torch
|
||||
from torch.distributed import ReduceOp
|
||||
@@ -107,33 +107,36 @@ class hcclRedOpTypeEnum:
|
||||
class Function:
|
||||
name: str
|
||||
restype: Any
|
||||
argtypes: List[Any]
|
||||
argtypes: list[Any]
|
||||
|
||||
|
||||
class HCCLLibrary:
|
||||
exported_functions = [
|
||||
# const char* HcclGetErrorString(HcclResult code);
|
||||
Function("HcclGetErrorString", ctypes.c_char_p, [hcclResult_t]),
|
||||
|
||||
# HcclResult HcclGetRootInfo(HcclRootInfo *rootInfo);
|
||||
Function("HcclGetRootInfo", hcclResult_t,
|
||||
[ctypes.POINTER(hcclUniqueId)]),
|
||||
|
||||
Function("HcclGetRootInfo", hcclResult_t, [ctypes.POINTER(hcclUniqueId)]),
|
||||
# HcclResult HcclCommInitRootInfo(
|
||||
# uint32_t nRanks, const HcclRootInfo *rootInfo, uint32_t rank, HcclComm *comm);
|
||||
# note that HcclComm is a pointer type, so the last argument is a pointer to a pointer
|
||||
Function("HcclCommInitRootInfo", hcclResult_t, [
|
||||
Function(
|
||||
"HcclCommInitRootInfo",
|
||||
hcclResult_t,
|
||||
[
|
||||
ctypes.c_int,
|
||||
ctypes.POINTER(hcclUniqueId),
|
||||
ctypes.c_int,
|
||||
ctypes.POINTER(hcclComm_t),
|
||||
]),
|
||||
|
||||
],
|
||||
),
|
||||
# HcclResult HcclAllReduce(
|
||||
# void *sendBuf, void *recvBuf, uint64_t count,
|
||||
# HcclDataType dataType, HcclReduceOp op, HcclComm comm,
|
||||
# aclrtStream stream);
|
||||
Function("HcclAllReduce", hcclResult_t, [
|
||||
Function(
|
||||
"HcclAllReduce",
|
||||
hcclResult_t,
|
||||
[
|
||||
buffer_type,
|
||||
buffer_type,
|
||||
ctypes.c_size_t,
|
||||
@@ -141,35 +144,37 @@ class HCCLLibrary:
|
||||
hcclRedOp_t,
|
||||
hcclComm_t,
|
||||
aclrtStream_t,
|
||||
]),
|
||||
|
||||
],
|
||||
),
|
||||
# HcclResult HcclBroadcast(
|
||||
# void *buf, uint64_t count,
|
||||
# HcclDataType dataType, uint32_t root,
|
||||
# HcclComm comm, aclrtStream stream);
|
||||
Function("HcclBroadcast", hcclResult_t, [
|
||||
Function(
|
||||
"HcclBroadcast",
|
||||
hcclResult_t,
|
||||
[
|
||||
buffer_type,
|
||||
ctypes.c_size_t,
|
||||
hcclDataType_t,
|
||||
ctypes.c_int,
|
||||
hcclComm_t,
|
||||
aclrtStream_t,
|
||||
]),
|
||||
|
||||
],
|
||||
),
|
||||
# HcclResult HcclCommDestroy(HcclComm comm);
|
||||
Function("HcclCommDestroy", hcclResult_t, [hcclComm_t]),
|
||||
]
|
||||
|
||||
# class attribute to store the mapping from the path to the library
|
||||
# to avoid loading the same library multiple times
|
||||
path_to_library_cache: Dict[str, Any] = {}
|
||||
path_to_library_cache: dict[str, Any] = {}
|
||||
|
||||
# class attribute to store the mapping from library path
|
||||
# to the correspongding directory
|
||||
path_to_dict_mapping: Dict[str, Dict[str, Any]] = {}
|
||||
|
||||
def __init__(self, so_file: Optional[str] = None):
|
||||
path_to_dict_mapping: dict[str, dict[str, Any]] = {}
|
||||
|
||||
def __init__(self, so_file: str | None = None):
|
||||
so_file = so_file or find_hccl_library()
|
||||
|
||||
try:
|
||||
@@ -185,12 +190,14 @@ class HCCLLibrary:
|
||||
"or it does not support the current platform %s. "
|
||||
"If you already have the library, please set the "
|
||||
"environment variable HCCL_SO_PATH"
|
||||
" to point to the correct hccl library path.", so_file,
|
||||
platform.platform())
|
||||
" to point to the correct hccl library path.",
|
||||
so_file,
|
||||
platform.platform(),
|
||||
)
|
||||
raise e
|
||||
|
||||
if so_file not in HCCLLibrary.path_to_dict_mapping:
|
||||
_funcs: Dict[str, Any] = {}
|
||||
_funcs: dict[str, Any] = {}
|
||||
for func in HCCLLibrary.exported_functions:
|
||||
f = getattr(self.lib, func.name)
|
||||
f.restype = func.restype
|
||||
@@ -209,34 +216,37 @@ class HCCLLibrary:
|
||||
|
||||
def hcclGetUniqueId(self) -> hcclUniqueId:
|
||||
unique_id = hcclUniqueId()
|
||||
self.HCCL_CHECK(self._funcs["HcclGetRootInfo"](
|
||||
ctypes.byref(unique_id)))
|
||||
self.HCCL_CHECK(self._funcs["HcclGetRootInfo"](ctypes.byref(unique_id)))
|
||||
return unique_id
|
||||
|
||||
def hcclCommInitRank(self, world_size: int, unique_id: hcclUniqueId,
|
||||
rank: int) -> hcclComm_t:
|
||||
def hcclCommInitRank(self, world_size: int, unique_id: hcclUniqueId, rank: int) -> hcclComm_t:
|
||||
comm = hcclComm_t()
|
||||
self.HCCL_CHECK(self._funcs["HcclCommInitRootInfo"](
|
||||
world_size, ctypes.byref(unique_id), rank, ctypes.byref(comm)))
|
||||
self.HCCL_CHECK(
|
||||
self._funcs["HcclCommInitRootInfo"](world_size, ctypes.byref(unique_id), rank, ctypes.byref(comm))
|
||||
)
|
||||
return comm
|
||||
|
||||
def hcclAllReduce(self, sendbuff: buffer_type, recvbuff: buffer_type,
|
||||
count: int, datatype: int, op: int, comm: hcclComm_t,
|
||||
stream: aclrtStream_t) -> None:
|
||||
def hcclAllReduce(
|
||||
self,
|
||||
sendbuff: buffer_type,
|
||||
recvbuff: buffer_type,
|
||||
count: int,
|
||||
datatype: int,
|
||||
op: int,
|
||||
comm: hcclComm_t,
|
||||
stream: aclrtStream_t,
|
||||
) -> None:
|
||||
# `datatype` actually should be `hcclDataType_t`
|
||||
# and `op` should be `hcclRedOp_t`
|
||||
# both are aliases of `ctypes.c_int`
|
||||
# when we pass int to a function, it will be converted to `ctypes.c_int`
|
||||
# by ctypes automatically
|
||||
self.HCCL_CHECK(self._funcs["HcclAllReduce"](sendbuff, recvbuff, count,
|
||||
datatype, op, comm,
|
||||
stream))
|
||||
self.HCCL_CHECK(self._funcs["HcclAllReduce"](sendbuff, recvbuff, count, datatype, op, comm, stream))
|
||||
|
||||
def hcclBroadcast(self, buf: buffer_type, count: int, datatype: int,
|
||||
root: int, comm: hcclComm_t,
|
||||
stream: aclrtStream_t) -> None:
|
||||
self.HCCL_CHECK(self._funcs["HcclBroadcast"](buf, count, datatype,
|
||||
root, comm, stream))
|
||||
def hcclBroadcast(
|
||||
self, buf: buffer_type, count: int, datatype: int, root: int, comm: hcclComm_t, stream: aclrtStream_t
|
||||
) -> None:
|
||||
self.HCCL_CHECK(self._funcs["HcclBroadcast"](buf, count, datatype, root, comm, stream))
|
||||
|
||||
def hcclCommDestroy(self, comm: hcclComm_t) -> None:
|
||||
self.HCCL_CHECK(self._funcs["HcclCommDestroy"](comm))
|
||||
|
||||
Reference in New Issue
Block a user