[Lint]Style: Convert vllm-ascend/ to ruff format(Batch #2) (#5977)

### What this PR does / why we need it?
**Scope of Changes**:
| File Path |
| :--- |
| `vllm_ascend/attention/attention_mask.py` |
| `vllm_ascend/attention/attention_v1.py` |
| `vllm_ascend/attention/context_parallel/attention_cp.py` |
| `vllm_ascend/attention/context_parallel/common_cp.py` |
| `vllm_ascend/attention/context_parallel/mla_cp.py` |
| `vllm_ascend/attention/utils.py` |
| `vllm_ascend/batch_invariant.py` |
| `vllm_ascend/device/device_op.py` |
| `vllm_ascend/device_allocator/camem.py` |
| `vllm_ascend/envs.py` |


- vLLM version: v0.13.0
- vLLM main:
2c24bc6996

---------

Signed-off-by: MrZ20 <2609716663@qq.com>
This commit is contained in:
SILONG ZENG
2026-01-19 08:59:46 +08:00
committed by GitHub
parent 2b6dc100b5
commit 329961b375
11 changed files with 920 additions and 1045 deletions

View File

@@ -49,11 +49,9 @@ line-length = 120
# Folder to be modified # Folder to be modified
exclude = [ exclude = [
"tests/**", "tests/**",
"vllm_ascend/_cann_ops_custom", "vllm_ascend/attention/mla_v1.py",
"vllm_ascend/attention", "vllm_ascend/attention/sfa_v1.py",
"vllm_ascend/core", "vllm_ascend/core",
"vllm_ascend/device",
"vllm_ascend/device_allocator",
"vllm_ascend/distributed", "vllm_ascend/distributed",
"vllm_ascend/eplb", "vllm_ascend/eplb",
"vllm_ascend/kv_offload", "vllm_ascend/kv_offload",
@@ -66,8 +64,6 @@ exclude = [
"vllm_ascend/spec_decode", "vllm_ascend/spec_decode",
"vllm_ascend/worker", "vllm_ascend/worker",
"vllm_ascend/xlite", "vllm_ascend/xlite",
"vllm_ascend/envs.py",
"vllm_ascend/batch_invariant.py",
] ]
[tool.ruff.lint] [tool.ruff.lint]

View File

@@ -21,21 +21,18 @@ from vllm_ascend.utils import singleton
def _generate_attn_mask(max_seq_len, dtype): def _generate_attn_mask(max_seq_len, dtype):
# Construct lower triangle matrix. # Construct lower triangle matrix.
mask_flag = torch.ones((max_seq_len, max_seq_len), mask_flag = torch.ones((max_seq_len, max_seq_len), dtype=torch.bool).tril_()
dtype=torch.bool).tril_()
# Create upper triangle matrix used to mark mask positions. # Create upper triangle matrix used to mark mask positions.
mask_flag = ~mask_flag mask_flag = ~mask_flag
# Currently for fp16 dtype, the mask value should be set to -inf. # Currently for fp16 dtype, the mask value should be set to -inf.
# TODO: Eliminate this part in the future. # TODO: Eliminate this part in the future.
mask_value = float('-inf') if dtype == torch.float16 else 1 mask_value = float("-inf") if dtype == torch.float16 else 1
attn_mask = torch.zeros(size=(max_seq_len, max_seq_len), dtype=dtype) \ attn_mask = torch.zeros(size=(max_seq_len, max_seq_len), dtype=dtype).masked_fill_(mask_flag, mask_value)
.masked_fill_(mask_flag, mask_value)
return attn_mask return attn_mask
@singleton @singleton
class AttentionMaskBuilder: class AttentionMaskBuilder:
def __init__(self, device: torch.device): def __init__(self, device: torch.device):
self.attn_mask_cache = None self.attn_mask_cache = None
self._seq_len_cached = 0 self._seq_len_cached = 0
@@ -52,14 +49,13 @@ class AttentionMaskBuilder:
assert self.attn_mask_cache is not None, "Something is wrong in generate_attn_mask." assert self.attn_mask_cache is not None, "Something is wrong in generate_attn_mask."
if self.attn_mask_cache.dtype != dtype: if self.attn_mask_cache.dtype != dtype:
self.attn_mask_cache = self.attn_mask_cache.to(dtype) self.attn_mask_cache = self.attn_mask_cache.to(dtype)
return self.attn_mask_cache[:max_seq_len, :max_seq_len].contiguous( return self.attn_mask_cache[:max_seq_len, :max_seq_len].contiguous().to(self.device, non_blocking=True)
).to(self.device, non_blocking=True)
def get_splitfuse_attn_mask(self) -> torch.Tensor: def get_splitfuse_attn_mask(self) -> torch.Tensor:
if self.chunked_prefill_attn_mask is None: if self.chunked_prefill_attn_mask is None:
self.chunked_prefill_attn_mask = torch.triu( self.chunked_prefill_attn_mask = (
torch.ones(2048, torch.triu(torch.ones(2048, 2048), diagonal=1).to(torch.int8).to(self.device)
2048), diagonal=1).to(torch.int8).to(self.device) )
return self.chunked_prefill_attn_mask return self.chunked_prefill_attn_mask
def get_mla_mask(self, dtype: torch.dtype) -> torch.Tensor: def get_mla_mask(self, dtype: torch.dtype) -> torch.Tensor:
@@ -68,16 +64,13 @@ class AttentionMaskBuilder:
mask_value = torch.finfo(torch.float32).min mask_value = torch.finfo(torch.float32).min
else: else:
mask_value = 1 mask_value = 1
prefill_mask = torch.triu( prefill_mask = torch.triu(torch.ones(512, 512, device=self.device, dtype=dtype), 1)
torch.ones(512, 512, device=self.device, dtype=dtype), 1) self.mla_mask = torch.where(prefill_mask == 1, mask_value, 0).to(dtype)
self.mla_mask = torch.where(prefill_mask == 1, mask_value,
0).to(dtype)
return self.mla_mask return self.mla_mask
def get_pcp_mla_mask(self, dtype: torch.dtype): def get_pcp_mla_mask(self, dtype: torch.dtype):
if self.pcp_mla_mask is None or self.pcp_mla_mask.dtype != dtype: if self.pcp_mla_mask is None or self.pcp_mla_mask.dtype != dtype:
self.pcp_mla_mask = torch.triu( self.pcp_mla_mask = torch.triu(torch.ones(512, 512, device=self.device, dtype=dtype), 1)
torch.ones(512, 512, device=self.device, dtype=dtype), 1)
return self.pcp_mla_mask return self.pcp_mla_mask
def get_swa_mask(self, dtype: torch.dtype, sliding_window): def get_swa_mask(self, dtype: torch.dtype, sliding_window):

View File

@@ -17,7 +17,7 @@
from dataclasses import dataclass from dataclasses import dataclass
from enum import Enum from enum import Enum
from typing import ClassVar, List, Optional, Tuple, Type from typing import ClassVar
import torch import torch
import torch_npu import torch_npu
@@ -29,32 +29,49 @@ from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.kv_cache_interface import AttentionSpec, CrossAttentionSpec from vllm.v1.kv_cache_interface import AttentionSpec, CrossAttentionSpec
from vllm_ascend.attention.attention_mask import AttentionMaskBuilder from vllm_ascend.attention.attention_mask import AttentionMaskBuilder
from vllm_ascend.attention.context_parallel.common_cp import ( from vllm_ascend.attention.context_parallel.common_cp import AscendMetadataForDecode, AscendMetadataForPrefill
AscendMetadataForDecode, AscendMetadataForPrefill) from vllm_ascend.attention.utils import (
from vllm_ascend.attention.utils import (AscendCommonAttentionMetadata, AscendCommonAttentionMetadata,
enable_cp, split_decodes_and_prefills, enable_cp,
using_paged_attention) split_decodes_and_prefills,
using_paged_attention,
)
from vllm_ascend.compilation.acl_graph import ( from vllm_ascend.compilation.acl_graph import (
get_draft_graph_params, get_graph_params, get_draft_graph_params,
update_draft_graph_params_workspaces, update_graph_params_workspaces) get_graph_params,
update_draft_graph_params_workspaces,
update_graph_params_workspaces,
)
from vllm_ascend.device.device_op import DeviceOperator from vllm_ascend.device.device_op import DeviceOperator
from vllm_ascend.ops.flashcomm2_oshard_manager import flashcomm2_oshard_manager from vllm_ascend.ops.flashcomm2_oshard_manager import flashcomm2_oshard_manager
from vllm_ascend.utils import vllm_version_is, weak_ref_tensors from vllm_ascend.utils import vllm_version_is, weak_ref_tensors
# isort: off # isort: off
if vllm_version_is('0.13.0'): if vllm_version_is("0.13.0"):
from vllm.v1.attention.backends.utils import (AttentionCGSupport, from vllm.v1.attention.backends.utils import AttentionCGSupport, AttentionMetadataBuilder
AttentionMetadataBuilder)
from vllm.attention.backends.abstract import ( # type: ignore from vllm.attention.backends.abstract import ( # type: ignore
AttentionBackend, AttentionImpl, AttentionLayer, AttentionType) AttentionBackend,
AttentionImpl,
AttentionLayer,
AttentionType,
)
from vllm.attention.backends.registry import ( # type: ignore from vllm.attention.backends.registry import ( # type: ignore
AttentionBackendEnum, register_backend) AttentionBackendEnum,
register_backend,
)
else: else:
from vllm.v1.attention.backend import ( # type: ignore from vllm.v1.attention.backend import ( # type: ignore
AttentionBackend, AttentionCGSupport, AttentionImpl, AttentionLayer, AttentionBackend,
AttentionType, AttentionMetadataBuilder) AttentionCGSupport,
AttentionImpl,
AttentionLayer,
AttentionType,
AttentionMetadataBuilder,
)
from vllm.v1.attention.backends.registry import ( # type: ignore from vllm.v1.attention.backends.registry import ( # type: ignore
AttentionBackendEnum, register_backend) AttentionBackendEnum,
register_backend,
)
# isort: on # isort: on
# default max value of sliding window size # default max value of sliding window size
@@ -73,18 +90,18 @@ class AscendAttentionBackend(AttentionBackend):
return "CUSTOM" if not envs_vllm.VLLM_USE_V2_MODEL_RUNNER else "FLASH_ATTN" return "CUSTOM" if not envs_vllm.VLLM_USE_V2_MODEL_RUNNER else "FLASH_ATTN"
@staticmethod @staticmethod
def get_impl_cls() -> Type["AscendAttentionBackendImpl"]: def get_impl_cls() -> type["AscendAttentionBackendImpl"]:
if enable_cp(): if enable_cp():
from vllm_ascend.attention.context_parallel.attention_cp import \ from vllm_ascend.attention.context_parallel.attention_cp import AscendAttentionCPImpl
AscendAttentionCPImpl
return AscendAttentionCPImpl return AscendAttentionCPImpl
return AscendAttentionBackendImpl return AscendAttentionBackendImpl
@staticmethod @staticmethod
def get_builder_cls() -> type["AscendAttentionMetadataBuilder"]: def get_builder_cls() -> type["AscendAttentionMetadataBuilder"]:
if enable_cp(): if enable_cp():
from vllm_ascend.attention.context_parallel.attention_cp import \ from vllm_ascend.attention.context_parallel.attention_cp import AscendAttentionCPMetadataBuilder
AscendAttentionCPMetadataBuilder
return AscendAttentionCPMetadataBuilder return AscendAttentionCPMetadataBuilder
return AscendAttentionMetadataBuilder return AscendAttentionMetadataBuilder
@@ -94,13 +111,13 @@ class AscendAttentionBackend(AttentionBackend):
block_size: int, block_size: int,
num_kv_heads: int, num_kv_heads: int,
head_size: int, head_size: int,
) -> Tuple[int, ...]: ) -> tuple[int, ...]:
return (2, num_blocks, block_size, num_kv_heads, head_size) return (2, num_blocks, block_size, num_kv_heads, head_size)
@staticmethod @staticmethod
def swap_blocks( def swap_blocks(
src_kv_cache: List[torch.Tensor], src_kv_cache: list[torch.Tensor],
dst_kv_cache: List[torch.Tensor], dst_kv_cache: list[torch.Tensor],
src_to_dst: torch.Tensor, src_to_dst: torch.Tensor,
) -> None: ) -> None:
src_key_cache, src_value_cache = src_kv_cache[0], src_kv_cache[1] src_key_cache, src_value_cache = src_kv_cache[0], src_kv_cache[1]
@@ -108,14 +125,12 @@ class AscendAttentionBackend(AttentionBackend):
src_indices = src_to_dst[:, 0] src_indices = src_to_dst[:, 0]
dst_indices = src_to_dst[:, 1] dst_indices = src_to_dst[:, 1]
dst_key_cache[dst_indices] = src_key_cache[src_indices].to( dst_key_cache[dst_indices] = src_key_cache[src_indices].to(dst_key_cache.device)
dst_key_cache.device) dst_value_cache[dst_indices] = src_value_cache[src_indices].to(dst_key_cache.device)
dst_value_cache[dst_indices] = src_value_cache[src_indices].to(
dst_key_cache.device)
@staticmethod @staticmethod
def copy_blocks( def copy_blocks(
kv_caches: List[torch.Tensor], kv_caches: list[torch.Tensor],
src_to_dists: torch.Tensor, src_to_dists: torch.Tensor,
) -> None: ) -> None:
src_indices = src_to_dists[:, 0] src_indices = src_to_dists[:, 0]
@@ -148,8 +163,9 @@ class AscendMetadata:
Contains attention masks, token counts, sequence lengths and KV cache Contains attention masks, token counts, sequence lengths and KV cache
related properties for attention computation. related properties for attention computation.
""" """
# **************************** Basic Properties ************************** # # **************************** Basic Properties ************************** #
attn_mask: Optional[torch.Tensor] = None attn_mask: torch.Tensor | None = None
# Current state of this attention run. # Current state of this attention run.
attn_state: AscendAttentionState = AscendAttentionState.ChunkedPrefill attn_state: AscendAttentionState = AscendAttentionState.ChunkedPrefill
@@ -168,12 +184,12 @@ class AscendMetadata:
# should simplified these parameters once attention schema in vLLM-Ascend # should simplified these parameters once attention schema in vLLM-Ascend
# is unified. # is unified.
seq_lens: torch.Tensor = None seq_lens: torch.Tensor = None
seq_lens_list: List[int] = None # type: ignore seq_lens_list: list[int] = None # type: ignore
actual_seq_lengths_q: List[int] = None # type: ignore actual_seq_lengths_q: list[int] = None # type: ignore
query_start_loc: torch.Tensor = None query_start_loc: torch.Tensor = None
# Maximum query length in the batch (None for decoding). # Maximum query length in the batch (None for decoding).
max_query_len: Optional[int] = None max_query_len: int | None = None
# ********************** KV Cache Related Properties ********************* # # ********************** KV Cache Related Properties ********************* #
# Block addresses per sequence (Seq id -> list of physical block). # Block addresses per sequence (Seq id -> list of physical block).
@@ -187,9 +203,9 @@ class AscendMetadata:
# (num_tokens,) # (num_tokens,)
slot_mapping: torch.Tensor = None slot_mapping: torch.Tensor = None
# pcp # pcp
prefill: Optional[AscendMetadataForPrefill] = None prefill: AscendMetadataForPrefill | None = None
# dcp # dcp
decode_meta: Optional[AscendMetadataForDecode] = None decode_meta: AscendMetadataForDecode | None = None
causal: bool = True causal: bool = True
# runner_type in model_config. # runner_type in model_config.
@@ -198,7 +214,7 @@ class AscendMetadata:
reshape_cache_event: torch.npu.Event = None reshape_cache_event: torch.npu.Event = None
# sliding window attention mask # sliding window attention mask
swa_mask: Optional[torch.Tensor] = None swa_mask: torch.Tensor | None = None
class AscendAttentionMetadataBuilder(AttentionMetadataBuilder[AscendMetadata]): class AscendAttentionMetadataBuilder(AttentionMetadataBuilder[AscendMetadata]):
@@ -208,6 +224,7 @@ class AscendAttentionMetadataBuilder(AttentionMetadataBuilder[AscendMetadata]):
Handles attention mask generation and metadata preparation for Handles attention mask generation and metadata preparation for
Ascend FlashAttention backend. Ascend FlashAttention backend.
""" """
# Does this backend/builder reorder the batch? # Does this backend/builder reorder the batch?
# If not, set this to None. Otherwise set it to the query # If not, set this to None. Otherwise set it to the query
# length that will be pulled into the front of the batch. # length that will be pulled into the front of the batch.
@@ -226,17 +243,19 @@ class AscendAttentionMetadataBuilder(AttentionMetadataBuilder[AscendMetadata]):
self.compilation_config = vllm_config.compilation_config self.compilation_config = vllm_config.compilation_config
self.device = device self.device = device
self.max_num_blocks_per_req = cdiv( self.max_num_blocks_per_req = cdiv(
self.model_config.max_model_len, self.model_config.max_model_len, AscendAttentionBackend.get_supported_block_size()[0]
AscendAttentionBackend.get_supported_block_size()[0]) )
self.speculative_config = vllm_config.speculative_config self.speculative_config = vllm_config.speculative_config
self.decode_threshold = 1 self.decode_threshold = 1
if self.speculative_config: if self.speculative_config:
spec_token_num = self.speculative_config.num_speculative_tokens spec_token_num = self.speculative_config.num_speculative_tokens
self.decode_threshold += spec_token_num self.decode_threshold += spec_token_num
assert self.decode_threshold <= 16, f"decode_threshold exceeded \ assert self.decode_threshold <= 16, (
f"decode_threshold exceeded \
npu_fused_infer_attention_score TND layout's limit of 16, \ npu_fused_infer_attention_score TND layout's limit of 16, \
got {self.decode_threshold}" got {self.decode_threshold}"
)
AscendAttentionMetadataBuilder.reorder_batch_threshold = self.decode_threshold AscendAttentionMetadataBuilder.reorder_batch_threshold = self.decode_threshold
@@ -254,8 +273,7 @@ class AscendAttentionMetadataBuilder(AttentionMetadataBuilder[AscendMetadata]):
# @override omitted only because of mypy limitation due to type variable. # @override omitted only because of mypy limitation due to type variable.
return AttentionCGSupport.ALWAYS return AttentionCGSupport.ALWAYS
def reorder_batch(self, input_batch, def reorder_batch(self, input_batch, scheduler_output: "SchedulerOutput") -> bool:
scheduler_output: "SchedulerOutput") -> bool:
return False return False
def build( def build(
@@ -266,12 +284,11 @@ class AscendAttentionMetadataBuilder(AttentionMetadataBuilder[AscendMetadata]):
) -> AscendMetadata: ) -> AscendMetadata:
num_reqs = common_attn_metadata.num_reqs num_reqs = common_attn_metadata.num_reqs
num_actual_tokens = common_attn_metadata.num_actual_tokens num_actual_tokens = common_attn_metadata.num_actual_tokens
query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu[: query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu[: num_reqs + 1]
num_reqs
+ 1]
num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = \ num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = split_decodes_and_prefills(
split_decodes_and_prefills(common_attn_metadata, decode_threshold=self.decode_threshold) common_attn_metadata, decode_threshold=self.decode_threshold
)
block_table = common_attn_metadata.block_table_tensor block_table = common_attn_metadata.block_table_tensor
seq_lens = common_attn_metadata.seq_lens_cpu[:num_reqs] seq_lens = common_attn_metadata.seq_lens_cpu[:num_reqs]
@@ -283,19 +300,17 @@ class AscendAttentionMetadataBuilder(AttentionMetadataBuilder[AscendMetadata]):
attn_state = common_attn_metadata.attn_state attn_state = common_attn_metadata.attn_state
# Get attn_mask and swa_mask from singleton AttentionMaskBuilder # Get attn_mask and swa_mask from singleton AttentionMaskBuilder
attn_mask = self.attn_mask_builder.get_attention_mask( attn_mask = self.attn_mask_builder.get_attention_mask(self.model_config)
self.model_config)
swa_mask = None swa_mask = None
is_swa = hasattr(self.model_config.hf_text_config, 'sliding_window') is_swa = hasattr(self.model_config.hf_text_config, "sliding_window")
if self.model_config is not None and is_swa: if self.model_config is not None and is_swa:
swa_mask = self.attn_mask_builder.get_swa_mask( swa_mask = self.attn_mask_builder.get_swa_mask(
self.model_config.dtype, self.model_config.dtype, self.model_config.hf_text_config.sliding_window
self.model_config.hf_text_config.sliding_window) )
# TODO: Yet another unnecessary H2D while we already have a query_start_loc on device # TODO: Yet another unnecessary H2D while we already have a query_start_loc on device
query_start_loc = query_start_loc_cpu.pin_memory().to( query_start_loc = query_start_loc_cpu.pin_memory().to(self.device, non_blocking=True)
self.device, non_blocking=True)
attn_metadata = AscendMetadata( attn_metadata = AscendMetadata(
num_actual_tokens=num_actual_tokens, num_actual_tokens=num_actual_tokens,
@@ -313,7 +328,8 @@ class AscendAttentionMetadataBuilder(AttentionMetadataBuilder[AscendMetadata]):
num_prefills=num_prefills, num_prefills=num_prefills,
num_decodes=num_decodes, num_decodes=num_decodes,
causal=common_attn_metadata.causal, causal=common_attn_metadata.causal,
model_runner_type=self.model_config.runner_type) model_runner_type=self.model_config.runner_type,
)
return attn_metadata return attn_metadata
def build_for_graph_capture( def build_for_graph_capture(
@@ -321,9 +337,7 @@ class AscendAttentionMetadataBuilder(AttentionMetadataBuilder[AscendMetadata]):
common_attn_metadata: AscendCommonAttentionMetadata, common_attn_metadata: AscendCommonAttentionMetadata,
attn_state: AscendAttentionState = AscendAttentionState.DecodeOnly, attn_state: AscendAttentionState = AscendAttentionState.DecodeOnly,
): ):
if attn_state in (AscendAttentionState.DecodeOnly, AscendAttentionState.ChunkedPrefill):
if attn_state in (AscendAttentionState.DecodeOnly,
AscendAttentionState.ChunkedPrefill):
attn_metadata = self.build( attn_metadata = self.build(
common_prefix_len=0, common_prefix_len=0,
common_attn_metadata=common_attn_metadata, common_attn_metadata=common_attn_metadata,
@@ -338,19 +352,18 @@ class AscendAttentionMetadataBuilder(AttentionMetadataBuilder[AscendMetadata]):
class AscendAttentionBackendImpl(AttentionImpl): class AscendAttentionBackendImpl(AttentionImpl):
def __init__( def __init__(
self, self,
num_heads: int, num_heads: int,
head_size: int, head_size: int,
scale: float, scale: float,
num_kv_heads: int, num_kv_heads: int,
alibi_slopes: Optional[List[float]], alibi_slopes: list[float] | None,
sliding_window: Optional[int], sliding_window: int | None,
kv_cache_dtype: str, kv_cache_dtype: str,
logits_soft_cap: Optional[float], logits_soft_cap: float | None,
attn_type: str, attn_type: str,
kv_sharing_target_layer_name: Optional[str], kv_sharing_target_layer_name: str | None,
**kwargs, **kwargs,
) -> None: ) -> None:
self.vllm_config = get_current_vllm_config() self.vllm_config = get_current_vllm_config()
@@ -362,9 +375,7 @@ class AscendAttentionBackendImpl(AttentionImpl):
self.kv_cache_dtype = kv_cache_dtype self.kv_cache_dtype = kv_cache_dtype
self.sliding_window = sliding_window self.sliding_window = sliding_window
if alibi_slopes is not None: if alibi_slopes is not None:
alibi_slopes = torch.tensor(alibi_slopes, alibi_slopes = torch.tensor(alibi_slopes, dtype=torch.float32, device="npu")
dtype=torch.float32,
device="npu")
self.alibi_slopes = alibi_slopes self.alibi_slopes = alibi_slopes
self.attn_type = attn_type self.attn_type = attn_type
@@ -372,18 +383,24 @@ class AscendAttentionBackendImpl(AttentionImpl):
self.num_queries_per_kv = self.num_heads // self.num_kv_heads self.num_queries_per_kv = self.num_heads // self.num_kv_heads
self.key_cache = None self.key_cache = None
self.value_cache = None self.value_cache = None
self.is_kv_producer = self.vllm_config.kv_transfer_config is not None and self.vllm_config.kv_transfer_config.is_kv_producer self.is_kv_producer = (
self.vllm_config.kv_transfer_config is not None and self.vllm_config.kv_transfer_config.is_kv_producer
)
def process_weights_after_loading(self, act_dtype: torch.dtype): def process_weights_after_loading(self, act_dtype: torch.dtype):
super().process_weights_after_loading(act_dtype) super().process_weights_after_loading(act_dtype)
if flashcomm2_oshard_manager.flashcomm2_oshard_enable(): if flashcomm2_oshard_manager.flashcomm2_oshard_enable():
flashcomm2_oshard_manager.post_process_after_loading() flashcomm2_oshard_manager.post_process_after_loading()
def full_graph_fia(self, query: torch.Tensor, key: torch.Tensor, def full_graph_fia(
value: torch.Tensor, attn_metadata: AscendMetadata, self,
output: torch.Tensor) -> torch.Tensor: query: torch.Tensor,
key, value, block_size, block_table, actual_seq_lengths_kv \ key: torch.Tensor,
= self._get_fia_params(key, value, attn_metadata) value: torch.Tensor,
attn_metadata: AscendMetadata,
output: torch.Tensor,
) -> torch.Tensor:
key, value, block_size, block_table, actual_seq_lengths_kv = self._get_fia_params(key, value, attn_metadata)
num_tokens = attn_metadata.actual_seq_lengths_q[-1] num_tokens = attn_metadata.actual_seq_lengths_q[-1]
forward_context = get_forward_context() forward_context = get_forward_context()
@@ -427,12 +444,22 @@ class AscendAttentionBackendImpl(AttentionImpl):
event.reset(stream) event.reset(stream)
graph_params.events[num_tokens].append(event) graph_params.events[num_tokens].append(event)
graph_params.attn_params[num_tokens].append( graph_params.attn_params[num_tokens].append(
(weak_ref_tensors(query), weak_ref_tensors(key), (
weak_ref_tensors(value), weak_ref_tensors(block_table), weak_ref_tensors(query),
weak_ref_tensors(attn_metadata.attn_mask), block_size, weak_ref_tensors(key),
actual_seq_lengths_kv, actual_seq_lengths_q, self.num_kv_heads, weak_ref_tensors(value),
self.num_heads, self.scale, weak_ref_tensors(output), weak_ref_tensors(block_table),
weak_ref_tensors(softmax_lse))) weak_ref_tensors(attn_metadata.attn_mask),
block_size,
actual_seq_lengths_kv,
actual_seq_lengths_q,
self.num_kv_heads,
self.num_heads,
self.scale,
weak_ref_tensors(output),
weak_ref_tensors(softmax_lse),
)
)
torch.npu.graph_task_group_begin(stream) torch.npu.graph_task_group_begin(stream)
torch_npu.npu_fused_infer_attention_score.out( torch_npu.npu_fused_infer_attention_score.out(
@@ -463,7 +490,7 @@ class AscendAttentionBackendImpl(AttentionImpl):
self, self,
query: torch.Tensor, query: torch.Tensor,
attn_metadata: AscendMetadata, attn_metadata: AscendMetadata,
output: Optional[torch.Tensor] = None, output: torch.Tensor | None = None,
): ):
graph_params = get_graph_params() graph_params = get_graph_params()
forward_context: ForwardContext = get_forward_context() forward_context: ForwardContext = get_forward_context()
@@ -481,7 +508,8 @@ class AscendAttentionBackendImpl(AttentionImpl):
scale_value=self.scale, scale_value=self.scale,
block_table=attn_metadata.block_tables, block_table=attn_metadata.block_tables,
context_lens=attn_metadata.seq_lens, context_lens=attn_metadata.seq_lens,
out=output) out=output,
)
update_graph_params_workspaces(num_tokens, workspace) update_graph_params_workspaces(num_tokens, workspace)
# Handle graph capturing mode # Handle graph capturing mode
@@ -491,7 +519,8 @@ class AscendAttentionBackendImpl(AttentionImpl):
event.wait(stream) event.wait(stream)
event.reset(stream) event.reset(stream)
graph_params.events[num_tokens].append(event) graph_params.events[num_tokens].append(event)
graph_params.attn_params[num_tokens].append(( graph_params.attn_params[num_tokens].append(
(
weak_ref_tensors(query), weak_ref_tensors(query),
weak_ref_tensors(self.key_cache), weak_ref_tensors(self.key_cache),
weak_ref_tensors(self.value_cache), weak_ref_tensors(self.value_cache),
@@ -501,7 +530,8 @@ class AscendAttentionBackendImpl(AttentionImpl):
attn_metadata.block_tables, attn_metadata.block_tables,
attn_metadata.seq_lens, attn_metadata.seq_lens,
weak_ref_tensors(output), weak_ref_tensors(output),
)) )
)
torch.npu.graph_task_group_begin(stream) torch.npu.graph_task_group_begin(stream)
torch_npu._npu_paged_attention( torch_npu._npu_paged_attention(
@@ -514,53 +544,54 @@ class AscendAttentionBackendImpl(AttentionImpl):
block_table=attn_metadata.block_tables, block_table=attn_metadata.block_tables,
context_lens=attn_metadata.seq_lens, context_lens=attn_metadata.seq_lens,
out=output, out=output,
workspace=workspace) workspace=workspace,
)
handle = torch.npu.graph_task_group_end(stream) handle = torch.npu.graph_task_group_end(stream)
graph_params.handles[num_tokens].append(handle) graph_params.handles[num_tokens].append(handle)
return output return output
def _get_fia_params(self, key: torch.Tensor, value: torch.Tensor, def _get_fia_params(self, key: torch.Tensor, value: torch.Tensor, attn_metadata: AscendMetadata):
attn_metadata: AscendMetadata):
if attn_metadata.attn_state == AscendAttentionState.PrefillNoCache: if attn_metadata.attn_state == AscendAttentionState.PrefillNoCache:
block_size = 128 block_size = 128
block_table = None block_table = None
actual_seq_lengths_kv = attn_metadata.actual_seq_lengths_q actual_seq_lengths_kv = attn_metadata.actual_seq_lengths_q
if self.attn_type == AttentionType.ENCODER_DECODER: if self.attn_type == AttentionType.ENCODER_DECODER:
actual_seq_lengths_kv = torch.cumsum(attn_metadata.seq_lens, actual_seq_lengths_kv = torch.cumsum(attn_metadata.seq_lens, dim=0).tolist()
dim=0).tolist() elif attn_metadata.attn_state == AscendAttentionState.PrefillCacheHit:
elif attn_metadata.attn_state == \
AscendAttentionState.PrefillCacheHit:
batch_size = attn_metadata.seq_lens.shape[0] batch_size = attn_metadata.seq_lens.shape[0]
block_table = attn_metadata.block_tables[:batch_size, :] block_table = attn_metadata.block_tables[:batch_size, :]
num_block, block_size, _, _ = self.key_cache.shape # type: ignore num_block, block_size, _, _ = self.key_cache.shape # type: ignore
key = self.key_cache.view( # type: ignore key = self.key_cache.view( # type: ignore
num_block, block_size, -1) num_block, block_size, -1
)
value = self.value_cache.view( # type: ignore value = self.value_cache.view( # type: ignore
num_block, block_size, -1) num_block, block_size, -1
)
actual_seq_lengths_kv = attn_metadata.seq_lens_list actual_seq_lengths_kv = attn_metadata.seq_lens_list
elif attn_metadata.attn_state == AscendAttentionState.DecodeOnly: elif attn_metadata.attn_state == AscendAttentionState.DecodeOnly:
num_block, block_size, _, _ = self.key_cache.shape # type: ignore num_block, block_size, _, _ = self.key_cache.shape # type: ignore
key = self.key_cache.view( # type: ignore key = self.key_cache.view( # type: ignore
num_block, block_size, -1) num_block, block_size, -1
)
value = self.value_cache.view( # type: ignore value = self.value_cache.view( # type: ignore
num_block, block_size, -1) num_block, block_size, -1
)
block_table = attn_metadata.block_tables block_table = attn_metadata.block_tables
actual_seq_lengths_kv = attn_metadata.seq_lens_list actual_seq_lengths_kv = attn_metadata.seq_lens_list
# chunked prefill. # chunked prefill.
else: else:
num_block, block_size, _, _ = self.key_cache.shape # type: ignore num_block, block_size, _, _ = self.key_cache.shape # type: ignore
key = self.key_cache.view( # type: ignore key = self.key_cache.view( # type: ignore
num_block, block_size, -1) num_block, block_size, -1
)
value = self.value_cache.view( # type: ignore value = self.value_cache.view( # type: ignore
num_block, block_size, -1) num_block, block_size, -1
)
block_table = attn_metadata.block_tables block_table = attn_metadata.block_tables
actual_seq_lengths_kv = attn_metadata.seq_lens_list actual_seq_lengths_kv = attn_metadata.seq_lens_list
return key, value, block_size, block_table, actual_seq_lengths_kv return key, value, block_size, block_table, actual_seq_lengths_kv
def _forward_fia_slidingwindow(self, query: torch.Tensor, def _forward_fia_slidingwindow(self, query: torch.Tensor, attn_metadata: AscendMetadata, output: torch.Tensor):
attn_metadata: AscendMetadata,
output: torch.Tensor):
batch_size = attn_metadata.seq_lens.shape[0] batch_size = attn_metadata.seq_lens.shape[0]
block_size = 128 block_size = 128
query = query.view(batch_size, 1, self.num_heads * self.head_size) query = query.view(batch_size, 1, self.num_heads * self.head_size)
@@ -583,34 +614,41 @@ class AscendAttentionBackendImpl(AttentionImpl):
scale=self.scale, scale=self.scale,
block_table=attn_metadata.block_tables, block_table=attn_metadata.block_tables,
actual_seq_lengths=[1] * len(attn_metadata.seq_lens), actual_seq_lengths=[1] * len(attn_metadata.seq_lens),
actual_seq_lengths_kv=attn_metadata.seq_lens) actual_seq_lengths_kv=attn_metadata.seq_lens,
)
output = output.view(batch_size, self.num_heads, self.head_size) output = output.view(batch_size, self.num_heads, self.head_size)
return output return output
def forward_fused_infer_attention(self, query: torch.Tensor, def forward_fused_infer_attention(
key: torch.Tensor, value: torch.Tensor, self,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
attn_metadata: AscendMetadata, attn_metadata: AscendMetadata,
output: torch.Tensor): output: torch.Tensor,
):
forward_context: ForwardContext = get_forward_context() forward_context: ForwardContext = get_forward_context()
# we inherit ForwardContext in model runner v2, when enable model # we inherit ForwardContext in model runner v2, when enable model
# runner v2, there is not capturing attribute in forward_context, # runner v2, there is not capturing attribute in forward_context,
# just use getattr to avoid attribute error. # just use getattr to avoid attribute error.
if getattr(forward_context, "capturing", False): if getattr(forward_context, "capturing", False):
attn_output, num_tokens = self.full_graph_fia( attn_output, num_tokens = self.full_graph_fia(query, key, value, attn_metadata, output)
query, key, value, attn_metadata, output)
output[:num_tokens] = attn_output[:num_tokens] output[:num_tokens] = attn_output[:num_tokens]
return output return output
if (attn_metadata.attn_state == AscendAttentionState.DecodeOnly if (
attn_metadata.attn_state == AscendAttentionState.DecodeOnly
and self.sliding_window is not None and self.sliding_window is not None
and attn_metadata.seq_lens.shape[0] == query.size(0)): and attn_metadata.seq_lens.shape[0] == query.size(0)
return self._forward_fia_slidingwindow(query, attn_metadata, ):
output) return self._forward_fia_slidingwindow(query, attn_metadata, output)
key, value, block_size, block_table, actual_seq_lengths_kv \ key, value, block_size, block_table, actual_seq_lengths_kv = self._get_fia_params(key, value, attn_metadata)
= self._get_fia_params(key, value, attn_metadata)
num_tokens = attn_metadata.actual_seq_lengths_q[-1] num_tokens = attn_metadata.actual_seq_lengths_q[-1]
query = query[:num_tokens] query = query[:num_tokens]
if attn_metadata.attn_state == AscendAttentionState.PrefillNoCache and self.attn_type != AttentionType.ENCODER_DECODER: if (
attn_metadata.attn_state == AscendAttentionState.PrefillNoCache
and self.attn_type != AttentionType.ENCODER_DECODER
):
key = key[:num_tokens] key = key[:num_tokens]
value = value[:num_tokens] value = value[:num_tokens]
# Get workspace from cache or calculate it if not present. # Get workspace from cache or calculate it if not present.
@@ -630,8 +668,7 @@ class AscendAttentionBackendImpl(AttentionImpl):
sparse_mode=3, sparse_mode=3,
) )
attn_output = attn_output.view(num_tokens, self.num_heads, attn_output = attn_output.view(num_tokens, self.num_heads, self.head_size)
self.head_size)
output[:num_tokens] = attn_output[:num_tokens] output[:num_tokens] = attn_output[:num_tokens]
return output return output
@@ -639,12 +676,13 @@ class AscendAttentionBackendImpl(AttentionImpl):
self, self,
query: torch.Tensor, query: torch.Tensor,
attn_metadata: AscendMetadata, attn_metadata: AscendMetadata,
output: Optional[torch.Tensor] = None, output: torch.Tensor | None = None,
) -> torch.Tensor: ) -> torch.Tensor:
forward_context: ForwardContext = get_forward_context() forward_context: ForwardContext = get_forward_context()
if forward_context.capturing: if forward_context.capturing:
return self.full_graph_pa(query, attn_metadata, output) return self.full_graph_pa(query, attn_metadata, output)
torch_npu._npu_paged_attention(query=query, torch_npu._npu_paged_attention(
query=query,
key_cache=self.key_cache, key_cache=self.key_cache,
value_cache=self.value_cache, value_cache=self.value_cache,
num_kv_heads=self.num_kv_heads, num_kv_heads=self.num_kv_heads,
@@ -652,13 +690,18 @@ class AscendAttentionBackendImpl(AttentionImpl):
scale_value=self.scale, scale_value=self.scale,
block_table=attn_metadata.block_tables, block_table=attn_metadata.block_tables,
context_lens=attn_metadata.seq_lens, context_lens=attn_metadata.seq_lens,
out=output) out=output,
)
return output return output
def _forward_encoder_attention(self, query: torch.Tensor, def _forward_encoder_attention(
key: torch.Tensor, value: torch.Tensor, self,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
attn_metadata: AscendMetadata, attn_metadata: AscendMetadata,
_: torch.Tensor) -> torch.Tensor: _: torch.Tensor,
) -> torch.Tensor:
assert attn_metadata is not None assert attn_metadata is not None
if attn_metadata.causal: if attn_metadata.causal:
@@ -692,26 +735,23 @@ class AscendAttentionBackendImpl(AttentionImpl):
self, self,
key: torch.Tensor, key: torch.Tensor,
value: torch.Tensor, value: torch.Tensor,
kv_cache: Tuple[torch.Tensor], kv_cache: tuple[torch.Tensor],
attn_metadata: AscendMetadata, attn_metadata: AscendMetadata,
): ):
if len(kv_cache) > 1: if len(kv_cache) > 1:
if self.is_kv_producer: if self.is_kv_producer:
attn_metadata.reshape_cache_event = torch.npu.Event() attn_metadata.reshape_cache_event = torch.npu.Event()
if self.key_cache is None: if self.key_cache is None:
self.key_cache, self.value_cache = kv_cache[0], kv_cache[1] self.key_cache, self.value_cache = kv_cache[0], kv_cache[1]
slots = attn_metadata.slot_mapping slots = attn_metadata.slot_mapping
encoder_decoder = (self.attn_type == AttentionType.ENCODER_DECODER) encoder_decoder = self.attn_type == AttentionType.ENCODER_DECODER
DeviceOperator.reshape_and_cache( DeviceOperator.reshape_and_cache(
key=key[:attn_metadata.num_actual_tokens] key=key[: attn_metadata.num_actual_tokens] if not encoder_decoder else key,
if not encoder_decoder else key, value=value[: attn_metadata.num_actual_tokens] if not encoder_decoder else value,
value=value[:attn_metadata.num_actual_tokens]
if not encoder_decoder else value,
key_cache=self.key_cache, key_cache=self.key_cache,
value_cache=self.value_cache, value_cache=self.value_cache,
slot_mapping=slots[:attn_metadata.num_actual_tokens] slot_mapping=slots[: attn_metadata.num_actual_tokens] if not encoder_decoder else slots,
if not encoder_decoder else slots) )
if self.is_kv_producer: if self.is_kv_producer:
attn_metadata.reshape_cache_event.record() attn_metadata.reshape_cache_event.record()
return key, value return key, value
@@ -721,18 +761,19 @@ class AscendAttentionBackendImpl(AttentionImpl):
query: torch.Tensor, query: torch.Tensor,
key: torch.Tensor, key: torch.Tensor,
value: torch.Tensor, value: torch.Tensor,
kv_cache: Tuple[torch.Tensor], kv_cache: tuple[torch.Tensor],
attn_metadata: AscendMetadata, attn_metadata: AscendMetadata,
output: torch.Tensor, output: torch.Tensor,
): ):
num_tokens = query.shape[0] num_tokens = query.shape[0]
if (attn_metadata.attn_state == AscendAttentionState.DecodeOnly if (
attn_metadata.attn_state == AscendAttentionState.DecodeOnly
and using_paged_attention(num_tokens, self.vllm_config) and using_paged_attention(num_tokens, self.vllm_config)
and self.sliding_window is None): and self.sliding_window is None
):
output = self.forward_paged_attention(query, attn_metadata, output) output = self.forward_paged_attention(query, attn_metadata, output)
else: else:
output = self.forward_fused_infer_attention( output = self.forward_fused_infer_attention(query, key, value, attn_metadata, output)
query, key, value, attn_metadata, output)
return output return output
@@ -742,11 +783,11 @@ class AscendAttentionBackendImpl(AttentionImpl):
query: torch.Tensor, query: torch.Tensor,
key: torch.Tensor, key: torch.Tensor,
value: torch.Tensor, value: torch.Tensor,
kv_cache: Tuple[torch.Tensor], kv_cache: tuple[torch.Tensor],
attn_metadata: AscendMetadata, attn_metadata: AscendMetadata,
output: Optional[torch.Tensor] = None, output: torch.Tensor | None = None,
output_scale: Optional[torch.Tensor] = None, output_scale: torch.Tensor | None = None,
output_block_scale: Optional[torch.Tensor] = None, output_block_scale: torch.Tensor | None = None,
) -> torch.Tensor: ) -> torch.Tensor:
"""Forward pass with Ascend attention. """Forward pass with Ascend attention.
Args: Args:
@@ -762,23 +803,18 @@ class AscendAttentionBackendImpl(AttentionImpl):
assert output is not None, "Output tensor must be provided." assert output is not None, "Output tensor must be provided."
if output_scale is not None or output_block_scale is not None: if output_scale is not None or output_block_scale is not None:
raise NotImplementedError( raise NotImplementedError("fused output quantization is not yet supported for AscendAttentionBackendImpl")
"fused output quantization is not yet supported"
" for AscendAttentionBackendImpl")
assert layer._k_scale_float == 1.0 and layer._v_scale_float == 1.0 assert layer._k_scale_float == 1.0 and layer._v_scale_float == 1.0
num_tokens = query.shape[0] num_tokens = query.shape[0]
if attn_metadata is None: if attn_metadata is None:
return output.fill_(0) return output.fill_(0)
if key is not None and value is not None: if key is not None and value is not None:
key, value = self.reshape_and_cache(key, value, kv_cache, key, value = self.reshape_and_cache(key, value, kv_cache, attn_metadata)
attn_metadata)
# pooling model branch # pooling model branch
if attn_metadata.model_runner_type == "pooling": if attn_metadata.model_runner_type == "pooling":
attn_output = self._forward_encoder_attention( attn_output = self._forward_encoder_attention(query, key, value, attn_metadata, output)
query, key, value, attn_metadata, output)
output[:num_tokens] = attn_output[:num_tokens] output[:num_tokens] = attn_output[:num_tokens]
return output return output
output = self.forward_impl(query, key, value, kv_cache, attn_metadata, output = self.forward_impl(query, key, value, kv_cache, attn_metadata, output)
output)
return output return output

File diff suppressed because it is too large Load Diff

View File

@@ -1,12 +1,9 @@
from dataclasses import dataclass from dataclasses import dataclass
from typing import Optional
import torch import torch
import torch.distributed as dist import torch.distributed as dist
import torch_npu import torch_npu
from vllm.distributed import (get_dcp_group, from vllm.distributed import get_dcp_group, get_decode_context_model_parallel_world_size, get_pcp_group
get_decode_context_model_parallel_world_size,
get_pcp_group)
@dataclass @dataclass
@@ -17,6 +14,7 @@ class AscendPCPMetadata:
Stores index tensors and sequence lengths for routing attention Stores index tensors and sequence lengths for routing attention
computations across PCP ranks during long sequence processing. computations across PCP ranks during long sequence processing.
""" """
q_head_idx: torch.Tensor = None q_head_idx: torch.Tensor = None
q_tail_idx: torch.Tensor = None q_tail_idx: torch.Tensor = None
kv_with_q_head_nomask_idx: torch.Tensor = None kv_with_q_head_nomask_idx: torch.Tensor = None
@@ -27,7 +25,7 @@ class AscendPCPMetadata:
head_attn_nomask_seqlens: torch.Tensor = None head_attn_nomask_seqlens: torch.Tensor = None
tail_attn_nomask_seqlens: torch.Tensor = None tail_attn_nomask_seqlens: torch.Tensor = None
q_full_idx: torch.Tensor = None q_full_idx: torch.Tensor = None
pcp_allgather_restore_idx: Optional[list[int]] = None pcp_allgather_restore_idx: list[int] | None = None
@dataclass @dataclass
@@ -37,6 +35,7 @@ class CPChunkedContextMetadata:
Extends chunked prefill with per-rank chunk information for PCP/DCP. Extends chunked prefill with per-rank chunk information for PCP/DCP.
""" """
# For handling chunked prefill # For handling chunked prefill
cu_seq_lens: torch.Tensor cu_seq_lens: torch.Tensor
starts: torch.Tensor starts: torch.Tensor
@@ -47,11 +46,11 @@ class CPChunkedContextMetadata:
chunk_seq_lens_npu: torch.Tensor chunk_seq_lens_npu: torch.Tensor
# for mla DCP & PCP # for mla DCP & PCP
padded_chunk_seq_lens_npu: torch.Tensor = None padded_chunk_seq_lens_npu: torch.Tensor = None
padded_local_chunk_seq_lens: Optional[list[list[int]]] = None padded_local_chunk_seq_lens: list[list[int]] | None = None
local_context_lens_allranks: Optional[list[list[int]]] = None local_context_lens_allranks: list[list[int]] | None = None
padded_local_cu_seq_lens: torch.Tensor = None padded_local_cu_seq_lens: torch.Tensor = None
cu_seq_lens_lst: Optional[list[list[int]]] = None cu_seq_lens_lst: list[list[int]] | None = None
chunk_size: Optional[int] = None chunk_size: int | None = None
@dataclass @dataclass
@@ -61,20 +60,21 @@ class AscendMetadataForPrefill:
@dataclass @dataclass
class ChunkedContextMetadata: class ChunkedContextMetadata:
"""Metadata for chunked context processing within prefill phase.""" """Metadata for chunked context processing within prefill phase."""
actual_chunk_seq_lengths: torch.Tensor actual_chunk_seq_lengths: torch.Tensor
actual_seq_lengths_kv: torch.Tensor actual_seq_lengths_kv: torch.Tensor
starts: torch.Tensor starts: torch.Tensor
chunk_seq_mask_filtered_indices: torch.Tensor chunk_seq_mask_filtered_indices: torch.Tensor
chunked_req_mask: Optional[list[bool]] = None chunked_req_mask: list[bool] | None = None
local_context_lens_allranks: Optional[list[list[int]]] = None local_context_lens_allranks: list[list[int]] | None = None
cp_kv_recover_idx_for_chunk: Optional[list[int]] = None cp_kv_recover_idx_for_chunk: list[int] | None = None
kv_inverse_idx_for_chunk: Optional[list[int]] = None kv_inverse_idx_for_chunk: list[int] | None = None
batch_chunk_seq_mask: Optional[list[bool]] = None batch_chunk_seq_mask: list[bool] | None = None
local_total_toks: Optional[int] = None local_total_toks: int | None = None
""" Prefill Specific Metadata for Ascend""" """ Prefill Specific Metadata for Ascend"""
pcp_metadata: Optional[AscendPCPMetadata] = None pcp_metadata: AscendPCPMetadata | None = None
chunked_context: Optional[ChunkedContextMetadata] = None chunked_context: ChunkedContextMetadata | None = None
block_tables: torch.Tensor = None block_tables: torch.Tensor = None
actual_seq_lengths_q: torch.Tensor = None actual_seq_lengths_q: torch.Tensor = None
@@ -82,13 +82,15 @@ class AscendMetadataForPrefill:
@dataclass @dataclass
class AscendMetadataForDecode: class AscendMetadataForDecode:
"""Decode-specific metadata for Ascend attention with Context Parallelism.""" """Decode-specific metadata for Ascend attention with Context Parallelism."""
num_computed_tokens_of_pcp_dcp: Optional[list[list[list[int]]]] = None
num_computed_tokens_of_pcp_dcp: list[list[list[int]]] | None = None
batch_seq_mask: torch.Tensor = None batch_seq_mask: torch.Tensor = None
block_tables: torch.Tensor = None block_tables: torch.Tensor = None
def _process_attn_out_lse(attn_output: torch.Tensor, softmax_lse: torch.Tensor, def _process_attn_out_lse(
batch_seq_mask: torch.Tensor) -> torch.Tensor: attn_output: torch.Tensor, softmax_lse: torch.Tensor, batch_seq_mask: torch.Tensor
) -> torch.Tensor:
pcp_size = get_pcp_group().world_size pcp_size = get_pcp_group().world_size
dcp_size = get_decode_context_model_parallel_world_size() dcp_size = get_decode_context_model_parallel_world_size()
dcp_group = get_dcp_group().device_group if dcp_size > 1 else None dcp_group = get_dcp_group().device_group if dcp_size > 1 else None
@@ -104,21 +106,17 @@ def _process_attn_out_lse(attn_output: torch.Tensor, softmax_lse: torch.Tensor,
# permute: [bs, num_heads, v_head_dim+1] -> [num_heads, v_head_dim+1, bs] # permute: [bs, num_heads, v_head_dim+1] -> [num_heads, v_head_dim+1, bs]
attn_out_lse = attn_out_lse.permute([1, 2, 0]).contiguous() attn_out_lse = attn_out_lse.permute([1, 2, 0]).contiguous()
attn_out_lse_all2all = torch.empty_like(attn_out_lse) attn_out_lse_all2all = torch.empty_like(attn_out_lse)
dist.all_to_all_single(attn_out_lse_all2all, dist.all_to_all_single(attn_out_lse_all2all, attn_out_lse, group=dcp_group)
attn_out_lse,
group=dcp_group)
attn_out_lse = attn_out_lse_all2all.permute([2, 0, 1]) attn_out_lse = attn_out_lse_all2all.permute([2, 0, 1])
if pcp_size > 1: if pcp_size > 1:
# AllGather out&lse within CP group # AllGather out&lse within CP group
attn_out_lse = get_pcp_group().all_gather(attn_out_lse.contiguous(), attn_out_lse = get_pcp_group().all_gather(attn_out_lse.contiguous(), dim=0)
dim=0)
return attn_out_lse return attn_out_lse
def _npu_attention_update(head_size, def _npu_attention_update(head_size, attn_out_lse: torch.Tensor) -> torch.Tensor:
attn_out_lse: torch.Tensor) -> torch.Tensor:
pcp_size = get_pcp_group().world_size pcp_size = get_pcp_group().world_size
dcp_size = get_decode_context_model_parallel_world_size() dcp_size = get_decode_context_model_parallel_world_size()
# [PCP * S, DCP * H, D+1] # [PCP * S, DCP * H, D+1]
@@ -134,8 +132,7 @@ def _npu_attention_update(head_size,
# Flatten [N, S, H, D+1], N = pcp_size * dcp_size # Flatten [N, S, H, D+1], N = pcp_size * dcp_size
x = x.view(-1, S, H, D_plus_1) x = x.view(-1, S, H, D_plus_1)
# Split out lse # Split out lse
out_flat, lse_flat = torch.split(x, [D, 1], out_flat, lse_flat = torch.split(x, [D, 1], dim=-1) # [N, S, H, D], [N, S, H, 1]
dim=-1) # [N, S, H, D], [N, S, H, 1]
# out: [N, S, H, D] -> [N, S*H, D] # out: [N, S, H, D] -> [N, S*H, D]
# lse: [N, S, H, 1] -> [N, S*H] # lse: [N, S, H, 1] -> [N, S*H]
out_flat = out_flat.flatten(1, 2) # [N, S*H, D] out_flat = out_flat.flatten(1, 2) # [N, S*H, D]

View File

@@ -1,35 +1,43 @@
from typing import Optional, Tuple, TypeVar from typing import TypeVar
import numpy as np import numpy as np
import torch import torch
import torch_npu import torch_npu
from vllm.config import VllmConfig from vllm.config import VllmConfig
from vllm.distributed import (get_dcp_group, from vllm.distributed import (
get_dcp_group,
get_decode_context_model_parallel_rank, get_decode_context_model_parallel_rank,
get_decode_context_model_parallel_world_size, get_decode_context_model_parallel_world_size,
get_pcp_group) get_pcp_group,
)
from vllm.forward_context import ForwardContext, get_forward_context from vllm.forward_context import ForwardContext, get_forward_context
from vllm.utils.math_utils import cdiv from vllm.utils.math_utils import cdiv
from vllm.v1.kv_cache_interface import AttentionSpec, MLAAttentionSpec from vllm.v1.kv_cache_interface import AttentionSpec, MLAAttentionSpec
# isort: off # isort: off
from vllm_ascend.attention.mla_v1 import ( from vllm_ascend.attention.mla_v1 import (
AscendMLADecodeMetadata, AscendMLAImpl, AscendMLAMetadata, AscendMLADecodeMetadata,
AscendMLAMetadataBuilder, AscendMLAPrefillMetadata, AscendMLAImpl,
DecodeMLAPreprocessResult, PrefillMLAPreprocessResult, AscendMLAMetadata,
BUILD_METADATA_STEP_PREFILL) AscendMLAMetadataBuilder,
AscendMLAPrefillMetadata,
DecodeMLAPreprocessResult,
PrefillMLAPreprocessResult,
BUILD_METADATA_STEP_PREFILL,
)
# isort: on # isort: on
from vllm_ascend.attention.utils import (AscendCommonAttentionMetadata)
from vllm_ascend.attention.context_parallel.common_cp import ( from vllm_ascend.attention.context_parallel.common_cp import (
AscendPCPMetadata, CPChunkedContextMetadata, _process_attn_out_lse, AscendPCPMetadata,
_npu_attention_update) CPChunkedContextMetadata,
from vllm_ascend.compilation.acl_graph import (get_draft_graph_params, _npu_attention_update,
get_graph_params, _process_attn_out_lse,
update_graph_params_workspaces) )
from vllm_ascend.utils import weak_ref_tensors, vllm_version_is from vllm_ascend.attention.utils import AscendCommonAttentionMetadata
from vllm_ascend.compilation.acl_graph import get_draft_graph_params, get_graph_params, update_graph_params_workspaces
from vllm_ascend.utils import vllm_version_is, weak_ref_tensors
if vllm_version_is('0.13.0'): if vllm_version_is("0.13.0"):
from vllm.v1.attention.backends.utils import AttentionCGSupport from vllm.v1.attention.backends.utils import AttentionCGSupport
else: else:
from vllm.v1.attention.backend import AttentionCGSupport from vllm.v1.attention.backend import AttentionCGSupport
@@ -54,28 +62,21 @@ class AscendMlaCPMetadataBuilder(AscendMLAMetadataBuilder):
metadata_cls: type[AscendMLAMetadata] | None = None, metadata_cls: type[AscendMLAMetadata] | None = None,
supports_dcp_with_varlen: bool = False, supports_dcp_with_varlen: bool = False,
): ):
super().__init__(kv_cache_spec, layer_names, vllm_config, device, super().__init__(kv_cache_spec, layer_names, vllm_config, device, metadata_cls, supports_dcp_with_varlen)
metadata_cls, supports_dcp_with_varlen)
self.pcp_size = get_pcp_group().world_size self.pcp_size = get_pcp_group().world_size
self.pcp_rank = get_pcp_group( self.pcp_rank = get_pcp_group().rank_in_group if self.pcp_size > 1 else 0
).rank_in_group if self.pcp_size > 1 else 0
self.dcp_size = get_decode_context_model_parallel_world_size() self.dcp_size = get_decode_context_model_parallel_world_size()
self.dcp_rank = get_decode_context_model_parallel_rank( self.dcp_rank = get_decode_context_model_parallel_rank() if self.dcp_size > 1 else 0
) if self.dcp_size > 1 else 0
self.cp_local_block_size = vllm_config.parallel_config.cp_kv_cache_interleave_size self.cp_local_block_size = vllm_config.parallel_config.cp_kv_cache_interleave_size
self.cp_virtual_block_size = self.cp_local_block_size * self.dcp_size * self.pcp_size self.cp_virtual_block_size = self.cp_local_block_size * self.dcp_size * self.pcp_size
scheduler_config = vllm_config.scheduler_config scheduler_config = vllm_config.scheduler_config
decode_max_num_seqs = getattr(scheduler_config, 'decode_max_num_seqs', decode_max_num_seqs = getattr(scheduler_config, "decode_max_num_seqs", 0)
0)
max_num_seqs = max(scheduler_config.max_num_seqs, decode_max_num_seqs) max_num_seqs = max(scheduler_config.max_num_seqs, decode_max_num_seqs)
self.batch_seq_mask_buf = torch.empty(max_num_seqs * self.batch_seq_mask_buf = torch.empty(max_num_seqs * self.decode_threshold, dtype=torch.uint8, device=device)
self.decode_threshold, self.block_size = (self.block_size * self.cp_virtual_block_size) // np.gcd(
dtype=torch.uint8, self.block_size, self.cp_virtual_block_size
device=device) )
self.block_size = (self.block_size *
self.cp_virtual_block_size) // np.gcd(
self.block_size, self.cp_virtual_block_size)
def build( def build(
self, self,
@@ -85,15 +86,10 @@ class AscendMlaCPMetadataBuilder(AscendMLAMetadataBuilder):
) -> AscendMLAMetadata: ) -> AscendMLAMetadata:
metadata_cls = super().build(common_prefix_len, common_attn_metadata) metadata_cls = super().build(common_prefix_len, common_attn_metadata)
if self.num_prefills == 0 and self.pcp_size > 1: if self.num_prefills == 0 and self.pcp_size > 1:
self.slot_mapping[:self. self.slot_mapping[: self.num_decode_tokens] = self.slot_mapping[
num_decode_tokens] = self.slot_mapping[:self. : self.num_decode_tokens * self.pcp_size : self.pcp_size
num_decode_tokens ]
* self. self.slot_mapping[self.num_decode_tokens : self.num_decode_tokens * self.pcp_size].fill_(-1)
pcp_size:
self.
pcp_size]
self.slot_mapping[self.num_decode_tokens:self.num_decode_tokens *
self.pcp_size].fill_(-1)
metadata_cls.slot_mapping = self.slot_mapping metadata_cls.slot_mapping = self.slot_mapping
return metadata_cls return metadata_cls
@@ -118,8 +114,8 @@ class AscendMlaCPMetadataBuilder(AscendMLAMetadataBuilder):
# In dcp only spec decode graph padding case, # In dcp only spec decode graph padding case,
# num_actual_tokens_pcp_padded may be less than num_actual_tokens # num_actual_tokens_pcp_padded may be less than num_actual_tokens
self.num_actual_tokens = max( self.num_actual_tokens = max(
long_seq_metadata.num_actual_tokens_pcp_padded, long_seq_metadata.num_actual_tokens_pcp_padded, common_attn_metadata.num_actual_tokens
common_attn_metadata.num_actual_tokens) )
def build_cp_metadata( def build_cp_metadata(
self, self,
@@ -131,30 +127,23 @@ class AscendMlaCPMetadataBuilder(AscendMLAMetadataBuilder):
return AscendPCPMetadata( return AscendPCPMetadata(
q_head_idx=common_long_seq_metadata.q_head_idx_tensor, q_head_idx=common_long_seq_metadata.q_head_idx_tensor,
q_tail_idx=common_long_seq_metadata.q_tail_idx_tensor, q_tail_idx=common_long_seq_metadata.q_tail_idx_tensor,
kv_with_q_head_nomask_idx=common_long_seq_metadata. kv_with_q_head_nomask_idx=common_long_seq_metadata.kv_with_q_head_nomask_idx_tensor,
kv_with_q_head_nomask_idx_tensor, kv_with_q_head_mask_idx=common_long_seq_metadata.kv_with_q_head_mask_idx_tensor,
kv_with_q_head_mask_idx=common_long_seq_metadata. kv_with_q_tail_nomask_idx=common_long_seq_metadata.kv_with_q_tail_nomask_idx_tensor,
kv_with_q_head_mask_idx_tensor, kv_with_q_tail_mask_idx=common_long_seq_metadata.kv_with_q_tail_mask_idx_tensor,
kv_with_q_tail_nomask_idx=common_long_seq_metadata.
kv_with_q_tail_nomask_idx_tensor,
kv_with_q_tail_mask_idx=common_long_seq_metadata.
kv_with_q_tail_mask_idx_tensor,
attn_mask_seqlens=common_long_seq_metadata.attn_mask_seqlens, attn_mask_seqlens=common_long_seq_metadata.attn_mask_seqlens,
head_attn_nomask_seqlens=common_long_seq_metadata. head_attn_nomask_seqlens=common_long_seq_metadata.head_attn_nomask_seqlens,
head_attn_nomask_seqlens, tail_attn_nomask_seqlens=common_long_seq_metadata.tail_attn_nomask_seqlens,
tail_attn_nomask_seqlens=common_long_seq_metadata.
tail_attn_nomask_seqlens,
q_full_idx=common_long_seq_metadata.q_full_idx, q_full_idx=common_long_seq_metadata.q_full_idx,
pcp_allgather_restore_idx=common_long_seq_metadata. pcp_allgather_restore_idx=common_long_seq_metadata.pcp_allgather_restore_idx,
pcp_allgather_restore_idx) )
def build_chunked_metadata( def build_chunked_metadata(
self, self,
common_prefix_len: int, common_prefix_len: int,
common_attn_metadata: AscendCommonAttentionMetadata, common_attn_metadata: AscendCommonAttentionMetadata,
): ):
chunked_context_metadata = super().build_chunked_metadata( chunked_context_metadata = super().build_chunked_metadata(common_prefix_len, common_attn_metadata)
common_prefix_len, common_attn_metadata)
if chunked_context_metadata is None: if chunked_context_metadata is None:
return None return None
@@ -162,33 +151,37 @@ class AscendMlaCPMetadataBuilder(AscendMLAMetadataBuilder):
assert long_seq_metadata is not None assert long_seq_metadata is not None
num_computed_tokens_of_pcp_dcp = long_seq_metadata.num_computed_tokens_of_pcp_dcp num_computed_tokens_of_pcp_dcp = long_seq_metadata.num_computed_tokens_of_pcp_dcp
assert num_computed_tokens_of_pcp_dcp is not None assert num_computed_tokens_of_pcp_dcp is not None
local_context_lens_allranks = torch.tensor( local_context_lens_allranks = torch.tensor(num_computed_tokens_of_pcp_dcp[self.num_decodes_flatten :]).reshape(
num_computed_tokens_of_pcp_dcp[self.num_decodes_flatten:]).reshape( -1, self.dcp_size * self.pcp_size
-1, self.dcp_size * self.pcp_size) )
# Note(qcs): The max local context lengths # Note(qcs): The max local context lengths
# padded to `cp_local_block_size`. # padded to `cp_local_block_size`.
padded_local_context_lens_cpu = (cdiv( padded_local_context_lens_cpu = (
cdiv(
self.context_lens_cpu, self.context_lens_cpu,
self.cp_virtual_block_size, self.cp_virtual_block_size,
) * self.cp_local_block_size) )
padded_local_max_context_chunk_across_ranks = (cdiv( * self.cp_local_block_size
)
padded_local_max_context_chunk_across_ranks = (
cdiv(
self.max_context_chunk, self.max_context_chunk,
self.cp_virtual_block_size, self.cp_virtual_block_size,
) * self.cp_local_block_size) )
local_chunk_starts = (torch.arange( * self.cp_local_block_size
self.num_chunks, dtype=torch.int32).unsqueeze(1).expand( )
-1, self.num_prefills) * local_chunk_starts = (
padded_local_max_context_chunk_across_ranks) torch.arange(self.num_chunks, dtype=torch.int32).unsqueeze(1).expand(-1, self.num_prefills)
* padded_local_max_context_chunk_across_ranks
)
local_chunk_ends = torch.min( local_chunk_ends = torch.min(
padded_local_context_lens_cpu.unsqueeze(0), padded_local_context_lens_cpu.unsqueeze(0),
local_chunk_starts + padded_local_max_context_chunk_across_ranks, local_chunk_starts + padded_local_max_context_chunk_across_ranks,
) )
padded_local_chunk_seq_lens = (local_chunk_ends - padded_local_chunk_seq_lens = (local_chunk_ends - local_chunk_starts).clamp(min=0)
local_chunk_starts).clamp(min=0) padded_local_cu_chunk_seq_lens_cpu = torch.zeros(
padded_local_cu_chunk_seq_lens_cpu = torch.zeros(self.num_chunks, self.num_chunks, self.num_prefills + 1, dtype=torch.int32, pin_memory=True
self.num_prefills + 1, )
dtype=torch.int32,
pin_memory=True)
torch.cumsum( torch.cumsum(
padded_local_chunk_seq_lens, padded_local_chunk_seq_lens,
dim=1, dim=1,
@@ -197,8 +190,7 @@ class AscendMlaCPMetadataBuilder(AscendMLAMetadataBuilder):
) )
chunked_metadata = CPChunkedContextMetadata( chunked_metadata = CPChunkedContextMetadata(
cu_seq_lens=chunked_context_metadata.cu_seq_lens, cu_seq_lens=chunked_context_metadata.cu_seq_lens,
starts=local_chunk_starts.pin_memory().to(self.device, starts=local_chunk_starts.pin_memory().to(self.device, non_blocking=True),
non_blocking=True),
seq_tot=padded_local_chunk_seq_lens.sum(dim=1).tolist(), seq_tot=padded_local_chunk_seq_lens.sum(dim=1).tolist(),
max_seq_lens=chunked_context_metadata.max_seq_lens, max_seq_lens=chunked_context_metadata.max_seq_lens,
chunk_seq_lens=self.chunk_seq_lens, chunk_seq_lens=self.chunk_seq_lens,
@@ -207,18 +199,14 @@ class AscendMlaCPMetadataBuilder(AscendMLAMetadataBuilder):
padded_chunk_seq_lens_npu=padded_local_chunk_seq_lens.npu(), padded_chunk_seq_lens_npu=padded_local_chunk_seq_lens.npu(),
padded_local_chunk_seq_lens=padded_local_chunk_seq_lens.tolist(), padded_local_chunk_seq_lens=padded_local_chunk_seq_lens.tolist(),
local_context_lens_allranks=local_context_lens_allranks.tolist(), local_context_lens_allranks=local_context_lens_allranks.tolist(),
padded_local_cu_seq_lens=padded_local_cu_chunk_seq_lens_cpu. padded_local_cu_seq_lens=padded_local_cu_chunk_seq_lens_cpu.pin_memory().to(self.device, non_blocking=True),
pin_memory().to(self.device, non_blocking=True),
cu_seq_lens_lst=self.cu_seq_lens_cpu.tolist(), cu_seq_lens_lst=self.cu_seq_lens_cpu.tolist(),
chunk_size=padded_local_max_context_chunk_across_ranks, chunk_size=padded_local_max_context_chunk_across_ranks,
) )
return chunked_metadata return chunked_metadata
def get_block_table_size( def get_block_table_size(self, common_attn_metadata: AscendCommonAttentionMetadata, build_metadata_step: int):
self, common_attn_metadata: AscendCommonAttentionMetadata, self.num_decodes_flatten = self.query_lens[: self.num_decodes].sum().item()
build_metadata_step: int):
self.num_decodes_flatten = self.query_lens[:self.num_decodes].sum(
).item()
if build_metadata_step == BUILD_METADATA_STEP_PREFILL: if build_metadata_step == BUILD_METADATA_STEP_PREFILL:
# For pcp + spec decode, we flatten seq_lens and block_table # For pcp + spec decode, we flatten seq_lens and block_table
# to avoid irregular attn_mask shape # to avoid irregular attn_mask shape
@@ -231,12 +219,9 @@ class AscendMlaCPMetadataBuilder(AscendMLAMetadataBuilder):
common_prefix_len: int, common_prefix_len: int,
common_attn_metadata: AscendCommonAttentionMetadata, common_attn_metadata: AscendCommonAttentionMetadata,
) -> AscendMLAPrefillMetadata: ) -> AscendMLAPrefillMetadata:
prefill_metadata = super().build_prefill_metadata( prefill_metadata = super().build_prefill_metadata(common_prefix_len, common_attn_metadata)
common_prefix_len, common_attn_metadata) prefill_metadata.pcp_metadata = self.build_cp_metadata(common_prefix_len, common_attn_metadata)
prefill_metadata.pcp_metadata = self.build_cp_metadata( prefill_metadata.block_table = self.block_table[self.num_decodes_flatten :, ...]
common_prefix_len, common_attn_metadata)
prefill_metadata.block_table = self.block_table[
self.num_decodes_flatten:, ...]
return prefill_metadata return prefill_metadata
def build_decode_metadata( def build_decode_metadata(
@@ -244,23 +229,19 @@ class AscendMlaCPMetadataBuilder(AscendMLAMetadataBuilder):
common_prefix_len: int, common_prefix_len: int,
common_attn_metadata: AscendCommonAttentionMetadata, common_attn_metadata: AscendCommonAttentionMetadata,
) -> AscendMLADecodeMetadata: ) -> AscendMLADecodeMetadata:
decode_metadata = super().build_decode_metadata( decode_metadata = super().build_decode_metadata(common_prefix_len, common_attn_metadata)
common_prefix_len, common_attn_metadata)
long_seq_metadata = common_attn_metadata.prefill_context_parallel_metadata long_seq_metadata = common_attn_metadata.prefill_context_parallel_metadata
assert long_seq_metadata is not None assert long_seq_metadata is not None
num_computed_tokens_of_pcp_dcp = long_seq_metadata.num_computed_tokens_of_pcp_dcp num_computed_tokens_of_pcp_dcp = long_seq_metadata.num_computed_tokens_of_pcp_dcp
assert num_computed_tokens_of_pcp_dcp is not None assert num_computed_tokens_of_pcp_dcp is not None
# [bs, pcp_size, dcp_size] # [bs, pcp_size, dcp_size]
num_computed_tokens_of_cp_dcp_array = np.array( num_computed_tokens_of_cp_dcp_array = np.array(num_computed_tokens_of_pcp_dcp)[: self.num_decodes_flatten]
num_computed_tokens_of_pcp_dcp)[:self.num_decodes_flatten]
cp_seq_len = num_computed_tokens_of_cp_dcp_array[:, self.pcp_rank, cp_seq_len = num_computed_tokens_of_cp_dcp_array[:, self.pcp_rank, self.dcp_rank]
self.dcp_rank]
cp_seq_len = torch.tensor(cp_seq_len, dtype=torch.int32) cp_seq_len = torch.tensor(cp_seq_len, dtype=torch.int32)
batch_seq_mask = (cp_seq_len == 0) batch_seq_mask = cp_seq_len == 0
self.batch_seq_mask_buf[:batch_seq_mask.shape[0]].copy_( self.batch_seq_mask_buf[: batch_seq_mask.shape[0]].copy_(batch_seq_mask, non_blocking=True)
batch_seq_mask, non_blocking=True)
batch_seq_mask = self.batch_seq_mask_buf[: batch_seq_mask.shape[0]] batch_seq_mask = self.batch_seq_mask_buf[: batch_seq_mask.shape[0]]
cp_seq_len = torch.where(cp_seq_len == 0, 1, cp_seq_len) cp_seq_len = torch.where(cp_seq_len == 0, 1, cp_seq_len)
decode_metadata.cp_seq_len = cp_seq_len decode_metadata.cp_seq_len = cp_seq_len
@@ -280,30 +261,35 @@ class AscendMlaCPImpl(AscendMLAImpl):
head_size: int, head_size: int,
scale: float, scale: float,
num_kv_heads: int, num_kv_heads: int,
alibi_slopes: Optional[list[float]], alibi_slopes: list[float] | None,
sliding_window: Optional[int], sliding_window: int | None,
kv_cache_dtype: str, kv_cache_dtype: str,
logits_soft_cap: Optional[float], logits_soft_cap: float | None,
attn_type: str, attn_type: str,
kv_sharing_target_layer_name: Optional[str], kv_sharing_target_layer_name: str | None,
**kwargs, **kwargs,
): ):
super().__init__(num_heads, head_size, scale, num_kv_heads, super().__init__(
alibi_slopes, sliding_window, kv_cache_dtype, num_heads,
logits_soft_cap, attn_type, head_size,
kv_sharing_target_layer_name, **kwargs) scale,
num_kv_heads,
alibi_slopes,
sliding_window,
kv_cache_dtype,
logits_soft_cap,
attn_type,
kv_sharing_target_layer_name,
**kwargs,
)
self.pcp_size = get_pcp_group().world_size self.pcp_size = get_pcp_group().world_size
self.pcp_rank = get_pcp_group( self.pcp_rank = get_pcp_group().rank_in_group if self.pcp_size > 1 else 0
).rank_in_group if self.pcp_size > 1 else 0 self.pcp_group = get_pcp_group().device_group if self.pcp_size > 1 else None
self.pcp_group = get_pcp_group(
).device_group if self.pcp_size > 1 else None
self.dcp_size = get_decode_context_model_parallel_world_size() self.dcp_size = get_decode_context_model_parallel_world_size()
self.dcp_rank = get_decode_context_model_parallel_rank( self.dcp_rank = get_decode_context_model_parallel_rank() if self.dcp_size > 1 else 0
) if self.dcp_size > 1 else 0 self.dcp_group = get_dcp_group().device_group if self.dcp_size > 1 else None
self.dcp_group = get_dcp_group(
).device_group if self.dcp_size > 1 else None
def get_num_actual_tokens(self, attn_metadata: M): def get_num_actual_tokens(self, attn_metadata: M):
if self.pcp_size > 1: if self.pcp_size > 1:
@@ -320,103 +306,80 @@ class AscendMlaCPImpl(AscendMLAImpl):
x = x.transpose(0, 1).reshape(-1, self.num_heads * self.v_head_dim) x = x.transpose(0, 1).reshape(-1, self.num_heads * self.v_head_dim)
return x return x
def mla_preprocess_prefill(self, q_c, kv_no_split, kv_cache, def mla_preprocess_prefill(self, q_c, kv_no_split, kv_cache, attn_metadata):
attn_metadata):
if not self.pcp_size > 1: if not self.pcp_size > 1:
return super().mla_preprocess_prefill(q_c, kv_no_split, kv_cache, return super().mla_preprocess_prefill(q_c, kv_no_split, kv_cache, attn_metadata)
attn_metadata)
num_decode_tokens = attn_metadata.num_decode_tokens num_decode_tokens = attn_metadata.num_decode_tokens
num_actual_tokens = (attn_metadata.num_actual_tokens_pcp_padded - num_actual_tokens = (
self.pcp_size * num_decode_tokens attn_metadata.num_actual_tokens_pcp_padded - self.pcp_size * num_decode_tokens
) // self.pcp_size + num_decode_tokens ) // self.pcp_size + num_decode_tokens
prefill_q_c = q_c[num_decode_tokens:num_actual_tokens] prefill_q_c = q_c[num_decode_tokens:num_actual_tokens]
prefill_q = self.q_proj(prefill_q_c)[0] \ prefill_q = self.q_proj(prefill_q_c)[0].view(-1, self.num_heads, self.qk_head_dim)
.view(-1, self.num_heads, self.qk_head_dim)
prefill_q_pe = prefill_q[..., self.qk_nope_head_dim :] prefill_q_pe = prefill_q[..., self.qk_nope_head_dim :]
prefill_q_nope = prefill_q[..., : self.qk_nope_head_dim] prefill_q_nope = prefill_q[..., : self.qk_nope_head_dim]
cos = attn_metadata.prefill.cos[: num_actual_tokens - num_decode_tokens] cos = attn_metadata.prefill.cos[: num_actual_tokens - num_decode_tokens]
sin = attn_metadata.prefill.sin[: num_actual_tokens - num_decode_tokens] sin = attn_metadata.prefill.sin[: num_actual_tokens - num_decode_tokens]
prefill_q_pe = self.rope_single(prefill_q_pe, cos, sin) prefill_q_pe = self.rope_single(prefill_q_pe, cos, sin)
prefill_kv_no_split = kv_no_split[:num_actual_tokens] prefill_kv_no_split = kv_no_split[:num_actual_tokens]
kv_c, k_pe = prefill_kv_no_split.split( kv_c, k_pe = prefill_kv_no_split.split([self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
[self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
kv_c_normed = self.kv_a_layernorm(kv_c.contiguous()) kv_c_normed = self.kv_a_layernorm(kv_c.contiguous())
assert len( assert len(kv_cache) > 1, "the number of kv cache should be greater than 1, namely (nope_cache and rope_cache)"
kv_cache kv_c_normed = kv_c_normed.view([num_actual_tokens, self.num_kv_heads, -1])
) > 1, "the number of kv cache should be greater than 1, namely (nope_cache and rope_cache)"
kv_c_normed = kv_c_normed.view(
[num_actual_tokens, self.num_kv_heads, -1])
k_pe = k_pe.unsqueeze(1) k_pe = k_pe.unsqueeze(1)
prefill_k_pe = k_pe prefill_k_pe = k_pe
prefill_k_pe[num_decode_tokens:num_actual_tokens] = self.rope_single( prefill_k_pe[num_decode_tokens:num_actual_tokens] = self.rope_single(
prefill_k_pe[num_decode_tokens:num_actual_tokens], cos, sin) prefill_k_pe[num_decode_tokens:num_actual_tokens], cos, sin
)
prefill_k_c_normed = kv_c_normed[:num_actual_tokens] prefill_k_c_normed = kv_c_normed[:num_actual_tokens]
prefill_kv_c_k_pe = torch.cat([prefill_k_c_normed, prefill_k_pe], prefill_kv_c_k_pe = torch.cat([prefill_k_c_normed, prefill_k_pe], dim=-1)
dim=-1)
prefill_kv_c_k_pe = get_pcp_group().all_gather(prefill_kv_c_k_pe, 0) prefill_kv_c_k_pe = get_pcp_group().all_gather(prefill_kv_c_k_pe, 0)
prefill_kv_c_k_pe = torch.index_select( prefill_kv_c_k_pe = torch.index_select(
prefill_kv_c_k_pe, 0, prefill_kv_c_k_pe, 0, attn_metadata.prefill.pcp_metadata.pcp_allgather_restore_idx
attn_metadata.prefill.pcp_metadata.pcp_allgather_restore_idx) )
prefill_kv_c_k_pe = prefill_kv_c_k_pe[num_decode_tokens * prefill_kv_c_k_pe = prefill_kv_c_k_pe[num_decode_tokens * self.pcp_size :]
self.pcp_size:] prefill_k_c_normed, prefill_k_pe = prefill_kv_c_k_pe.split([self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
prefill_k_c_normed, prefill_k_pe = prefill_kv_c_k_pe.split(
[self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
kv_c_normed, k_pe = prefill_k_c_normed, prefill_k_pe kv_c_normed, k_pe = prefill_k_c_normed, prefill_k_pe
prefill_k_c_normed = prefill_k_c_normed.squeeze() prefill_k_c_normed = prefill_k_c_normed.squeeze()
slot_mapping = attn_metadata.slot_mapping[self.pcp_size * slot_mapping = attn_metadata.slot_mapping[self.pcp_size * num_decode_tokens :]
num_decode_tokens:] torch_npu._npu_reshape_and_cache(
torch_npu._npu_reshape_and_cache(key=kv_c_normed, key=kv_c_normed, value=k_pe, key_cache=kv_cache[0], value_cache=kv_cache[1], slot_indices=slot_mapping
value=k_pe, )
key_cache=kv_cache[0], prefill_k_nope, prefill_value = (
value_cache=kv_cache[1], self.kv_b_proj(prefill_k_c_normed)[0]
slot_indices=slot_mapping) .view(-1, self.num_heads, self.qk_nope_head_dim + self.v_head_dim)
prefill_k_nope, prefill_value = self.kv_b_proj( .split([self.qk_nope_head_dim, self.v_head_dim], dim=-1)
prefill_k_c_normed)[0].view( )
-1, self.num_heads,
self.qk_nope_head_dim + self.v_head_dim).split(
[self.qk_nope_head_dim, self.v_head_dim], dim=-1)
prefill_k_pe = prefill_k_pe.expand((*prefill_k_nope.shape[:-1], -1)) prefill_k_pe = prefill_k_pe.expand((*prefill_k_nope.shape[:-1], -1))
return PrefillMLAPreprocessResult(prefill_q_nope, prefill_q_pe, return PrefillMLAPreprocessResult(prefill_q_nope, prefill_q_pe, prefill_k_nope, prefill_k_pe, prefill_value)
prefill_k_nope, prefill_k_pe,
prefill_value)
def mla_preprocess_decode(self, q_c, kv_no_split, kv_cache, attn_metadata): def mla_preprocess_decode(self, q_c, kv_no_split, kv_cache, attn_metadata):
num_decode_tokens = attn_metadata.num_decode_tokens num_decode_tokens = attn_metadata.num_decode_tokens
decode_q_c = q_c[:num_decode_tokens] decode_q_c = q_c[:num_decode_tokens]
cos = attn_metadata.decode.cos cos = attn_metadata.decode.cos
sin = attn_metadata.decode.sin sin = attn_metadata.decode.sin
decode_ql_nope, decode_q_pe = \ decode_ql_nope, decode_q_pe = self._q_proj_and_k_up_proj(decode_q_c)
self._q_proj_and_k_up_proj(decode_q_c) decode_ql_nope, decode_q_pe = self.reorg_decode_q(decode_ql_nope, decode_q_pe)
decode_ql_nope, decode_q_pe = self.reorg_decode_q(
decode_ql_nope, decode_q_pe)
decode_q_pe = self.rope_single(decode_q_pe, cos, sin) decode_q_pe = self.rope_single(decode_q_pe, cos, sin)
decode_slots = attn_metadata.slot_mapping[:num_decode_tokens] decode_slots = attn_metadata.slot_mapping[:num_decode_tokens]
decode_kv_no_split = kv_no_split[:num_decode_tokens] decode_kv_no_split = kv_no_split[:num_decode_tokens]
decode_k_pe, decode_k_nope = self.exec_kv_decode( decode_k_pe, decode_k_nope = self.exec_kv_decode(decode_kv_no_split, cos, sin, kv_cache, decode_slots)
decode_kv_no_split, cos, sin, kv_cache, decode_slots) return DecodeMLAPreprocessResult(decode_ql_nope, decode_q_pe, decode_k_nope, decode_k_pe)
return DecodeMLAPreprocessResult(decode_ql_nope, decode_q_pe,
decode_k_nope, decode_k_pe)
def get_context_seq_len_npu(self, index: int, def get_context_seq_len_npu(self, index: int, attn_metadata: AscendMLAMetadata):
attn_metadata: AscendMLAMetadata):
prefill_metadata = attn_metadata.prefill prefill_metadata = attn_metadata.prefill
assert prefill_metadata is not None assert prefill_metadata is not None
assert prefill_metadata.chunked_context is not None assert prefill_metadata.chunked_context is not None
assert isinstance(prefill_metadata.chunked_context, assert isinstance(prefill_metadata.chunked_context, CPChunkedContextMetadata)
CPChunkedContextMetadata)
assert prefill_metadata.chunked_context.padded_chunk_seq_lens_npu is not None assert prefill_metadata.chunked_context.padded_chunk_seq_lens_npu is not None
iters = len(prefill_metadata.chunked_context.seq_tot) iters = len(prefill_metadata.chunked_context.seq_tot)
assert 0 <= index < iters assert 0 <= index < iters
return prefill_metadata.chunked_context.padded_chunk_seq_lens_npu[ return prefill_metadata.chunked_context.padded_chunk_seq_lens_npu[index]
index]
def reorg_decode_q(self, decode_q_nope, decode_q_pe): def reorg_decode_q(self, decode_q_nope, decode_q_pe):
if self.dcp_size > 1: if self.dcp_size > 1:
decode_q_no_split = torch.cat([decode_q_nope, decode_q_pe], dim=-1) decode_q_no_split = torch.cat([decode_q_nope, decode_q_pe], dim=-1)
decode_q_no_split = get_dcp_group().all_gather( decode_q_no_split = get_dcp_group().all_gather(decode_q_no_split, 1)
decode_q_no_split, 1) decode_q_nope, decode_q_pe = decode_q_no_split.split([self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
decode_q_nope, decode_q_pe = decode_q_no_split.split(
[self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
return decode_q_nope, decode_q_pe return decode_q_nope, decode_q_pe
def _forward_prefill( def _forward_prefill(
@@ -426,12 +389,11 @@ class AscendMlaCPImpl(AscendMLAImpl):
k_nope: torch.Tensor, k_nope: torch.Tensor,
k_pe: torch.Tensor, k_pe: torch.Tensor,
value: torch.Tensor, value: torch.Tensor,
kv_c_and_k_pe_cache: Tuple[torch.Tensor], kv_c_and_k_pe_cache: tuple[torch.Tensor],
attn_metadata: AscendMLAMetadata, attn_metadata: AscendMLAMetadata,
) -> torch.Tensor: ) -> torch.Tensor:
if not self.pcp_size > 1: if not self.pcp_size > 1:
return super()._forward_prefill(q_nope, q_pe, k_nope, k_pe, value, return super()._forward_prefill(q_nope, q_pe, k_nope, k_pe, value, kv_c_and_k_pe_cache, attn_metadata)
kv_c_and_k_pe_cache, attn_metadata)
assert attn_metadata.prefill is not None assert attn_metadata.prefill is not None
assert attn_metadata.prefill.pcp_metadata is not None assert attn_metadata.prefill.pcp_metadata is not None
num_tokens = q_nope.size(0) num_tokens = q_nope.size(0)
@@ -455,7 +417,8 @@ class AscendMlaCPImpl(AscendMLAImpl):
kv_nomask_idx=kv_with_q_head_nomask_idx, kv_nomask_idx=kv_with_q_head_nomask_idx,
attn_mask_seqlens=attn_mask_seqlens, attn_mask_seqlens=attn_mask_seqlens,
attn_nomask_seqlens=head_attn_nomask_seqlens, attn_nomask_seqlens=head_attn_nomask_seqlens,
mask=attn_metadata.attn_mask) mask=attn_metadata.attn_mask,
)
output_tail, lse_tail = self._attention_with_mask_and_nomask( output_tail, lse_tail = self._attention_with_mask_and_nomask(
q_nope=torch.index_select(q_nope, 0, q_tail_idx), q_nope=torch.index_select(q_nope, 0, q_tail_idx),
@@ -467,19 +430,16 @@ class AscendMlaCPImpl(AscendMLAImpl):
kv_nomask_idx=kv_with_q_tail_nomask_idx, kv_nomask_idx=kv_with_q_tail_nomask_idx,
attn_mask_seqlens=attn_mask_seqlens, attn_mask_seqlens=attn_mask_seqlens,
attn_nomask_seqlens=tail_attn_nomask_seqlens, attn_nomask_seqlens=tail_attn_nomask_seqlens,
mask=attn_metadata.attn_mask) mask=attn_metadata.attn_mask,
)
q_full_idx = attn_metadata.prefill.pcp_metadata.q_full_idx q_full_idx = attn_metadata.prefill.pcp_metadata.q_full_idx
attn_output = torch.index_select( attn_output = torch.index_select(torch.cat([output_head, output_tail], dim=0), 0, q_full_idx)
torch.cat([output_head, output_tail], dim=0), 0, q_full_idx) attn_lse = torch.index_select(torch.cat([lse_head, lse_tail], dim=1), 1, q_full_idx)
attn_lse = torch.index_select(torch.cat([lse_head, lse_tail], dim=1),
1, q_full_idx)
output, _ = self._compute_prefill_context(q_nope, q_pe, output, _ = self._compute_prefill_context(
kv_c_and_k_pe_cache, q_nope, q_pe, kv_c_and_k_pe_cache, self.qk_rope_head_dim, attn_metadata, attn_output, attn_lse
self.qk_rope_head_dim, )
attn_metadata, attn_output,
attn_lse)
output = output.reshape([num_tokens, self.num_heads * self.v_head_dim]) output = output.reshape([num_tokens, self.num_heads * self.v_head_dim])
@@ -498,20 +458,16 @@ class AscendMlaCPImpl(AscendMLAImpl):
attn_nomask_seqlens: list[torch.Tensor], attn_nomask_seqlens: list[torch.Tensor],
mask: torch.Tensor, mask: torch.Tensor,
): ):
attn_output = torch.empty(q_nope.shape[0], attn_output = torch.empty(
self.num_heads, q_nope.shape[0], self.num_heads, self.v_head_dim, dtype=k_pe.dtype, device=k_pe.device
self.v_head_dim, )
dtype=k_pe.dtype, attn_lse = torch.empty(self.num_heads, q_pe.shape[0], dtype=torch.float32, device=k_pe.device)
device=k_pe.device)
attn_lse = torch.empty(self.num_heads,
q_pe.shape[0],
dtype=torch.float32,
device=k_pe.device)
# mask # mask
k_nope_mask = torch.index_select(k_nope, 0, kv_mask_idx) k_nope_mask = torch.index_select(k_nope, 0, kv_mask_idx)
value_mask = torch.index_select(value, 0, kv_mask_idx) value_mask = torch.index_select(value, 0, kv_mask_idx)
k_pe_mask = torch.index_select(k_pe, 0, kv_mask_idx) k_pe_mask = torch.index_select(k_pe, 0, kv_mask_idx)
torch_npu.atb.npu_ring_mla(q_nope=q_nope, torch_npu.atb.npu_ring_mla(
q_nope=q_nope,
q_rope=q_pe, q_rope=q_pe,
k_nope=k_nope_mask, k_nope=k_nope_mask,
k_rope=k_pe_mask, k_rope=k_pe_mask,
@@ -528,14 +484,14 @@ class AscendMlaCPImpl(AscendMLAImpl):
input_layout="type_bsnd", input_layout="type_bsnd",
calc_type="calc_type_first_ring", calc_type="calc_type_first_ring",
output=attn_output, output=attn_output,
softmax_lse=attn_lse) softmax_lse=attn_lse,
)
# nomask # nomask
if not kv_nomask_idx or len(kv_nomask_idx[0]) == 0: if not kv_nomask_idx or len(kv_nomask_idx[0]) == 0:
return attn_output, attn_lse return attn_output, attn_lse
for kv_nomask_idx_split, attn_nomask_seqlens_split in zip( for kv_nomask_idx_split, attn_nomask_seqlens_split in zip(kv_nomask_idx, attn_nomask_seqlens):
kv_nomask_idx, attn_nomask_seqlens):
k_nope_nomask = torch.index_select(k_nope, 0, kv_nomask_idx_split) k_nope_nomask = torch.index_select(k_nope, 0, kv_nomask_idx_split)
value_nomask = torch.index_select(value, 0, kv_nomask_idx_split) value_nomask = torch.index_select(value, 0, kv_nomask_idx_split)
k_pe_nomask = torch.index_select(k_pe, 0, kv_nomask_idx_split) k_pe_nomask = torch.index_select(k_pe, 0, kv_nomask_idx_split)
@@ -557,7 +513,8 @@ class AscendMlaCPImpl(AscendMLAImpl):
input_layout="type_bsnd", input_layout="type_bsnd",
calc_type="calc_type_default", calc_type="calc_type_default",
output=attn_output, output=attn_output,
softmax_lse=attn_lse) softmax_lse=attn_lse,
)
return attn_output, attn_lse return attn_output, attn_lse
def _forward_decode( def _forward_decode(
@@ -579,10 +536,8 @@ class AscendMlaCPImpl(AscendMLAImpl):
else: else:
num_heads = self.num_heads num_heads = self.num_heads
k_nope = k_nope.view(-1, block_size, self.num_kv_heads, k_nope = k_nope.view(-1, block_size, self.num_kv_heads, self.kv_lora_rank)
self.kv_lora_rank) k_pe = k_pe.view(-1, block_size, self.num_kv_heads, self.qk_rope_head_dim)
k_pe = k_pe.view(-1, block_size, self.num_kv_heads,
self.qk_rope_head_dim)
q_nope = q_nope.view(num_tokens, num_heads, -1) q_nope = q_nope.view(num_tokens, num_heads, -1)
q_pe = q_pe.view(num_tokens, num_heads, -1) q_pe = q_pe.view(num_tokens, num_heads, -1)
# use pcp & dcp split computed token nums from scheduler to compute actual seq_len and seq_mask # use pcp & dcp split computed token nums from scheduler to compute actual seq_len and seq_mask
@@ -606,20 +561,35 @@ class AscendMlaCPImpl(AscendMLAImpl):
workspace = graph_params.workspaces.get(num_tokens) workspace = graph_params.workspaces.get(num_tokens)
if workspace is None: if workspace is None:
workspace = torch_npu.atb._npu_multi_head_latent_attention_get_workspace( workspace = torch_npu.atb._npu_multi_head_latent_attention_get_workspace(
q_nope, q_pe, k_nope, k_pe, decode_meta.block_table, q_nope,
seq_len, num_heads, self.scale, self.num_kv_heads, q_pe,
**common_kwargs) k_nope,
k_pe,
decode_meta.block_table,
seq_len,
num_heads,
self.scale,
self.num_kv_heads,
**common_kwargs,
)
update_graph_params_workspaces(num_tokens, workspace) update_graph_params_workspaces(num_tokens, workspace)
attn_output = torch.empty_like(q_nope) attn_output = torch.empty_like(q_nope)
softmax_lse = torch.empty((num_tokens, num_heads, 1), softmax_lse = torch.empty((num_tokens, num_heads, 1), dtype=q_nope.dtype, device=q_nope.device)
dtype=q_nope.dtype,
device=q_nope.device)
graph_params.attn_params[num_tokens].append( graph_params.attn_params[num_tokens].append(
(weak_ref_tensors(q_nope), weak_ref_tensors(q_pe), (
weak_ref_tensors(k_nope), weak_ref_tensors(k_pe), weak_ref_tensors(q_nope),
decode_meta.block_table, seq_len, num_heads, self.scale, weak_ref_tensors(q_pe),
self.num_kv_heads, weak_ref_tensors(attn_output), weak_ref_tensors(k_nope),
weak_ref_tensors(softmax_lse))) weak_ref_tensors(k_pe),
decode_meta.block_table,
seq_len,
num_heads,
self.scale,
self.num_kv_heads,
weak_ref_tensors(attn_output),
weak_ref_tensors(softmax_lse),
)
)
torch.npu.graph_task_group_begin(stream) torch.npu.graph_task_group_begin(stream)
torch_npu.atb.npu_multi_head_latent_attention( torch_npu.atb.npu_multi_head_latent_attention(
q_nope, q_nope,
@@ -634,14 +604,13 @@ class AscendMlaCPImpl(AscendMLAImpl):
**common_kwargs, **common_kwargs,
workspace=workspace, workspace=workspace,
output=attn_output, output=attn_output,
lse=softmax_lse) lse=softmax_lse,
)
handle = torch.npu.graph_task_group_end(stream) handle = torch.npu.graph_task_group_end(stream)
graph_params.handles[num_tokens].append(handle) graph_params.handles[num_tokens].append(handle)
else: else:
attn_output = torch.empty_like(q_nope) attn_output = torch.empty_like(q_nope)
softmax_lse = torch.empty((num_tokens, num_heads, 1), softmax_lse = torch.empty((num_tokens, num_heads, 1), dtype=q_nope.dtype, device=q_nope.device)
dtype=q_nope.dtype,
device=q_nope.device)
torch_npu.atb.npu_multi_head_latent_attention( torch_npu.atb.npu_multi_head_latent_attention(
q_nope, q_nope,
q_pe, q_pe,
@@ -655,20 +624,17 @@ class AscendMlaCPImpl(AscendMLAImpl):
return_lse=True, return_lse=True,
calc_type="calc_type_ring", calc_type="calc_type_ring",
output=attn_output, output=attn_output,
lse=softmax_lse) lse=softmax_lse,
)
# Update out&lse # Update out&lse
attn_out_lse = _process_attn_out_lse(attn_output, softmax_lse, attn_out_lse = _process_attn_out_lse(attn_output, softmax_lse, decode_meta.batch_seq_mask)
decode_meta.batch_seq_mask)
attn_output = _npu_attention_update(self.kv_lora_rank, attn_out_lse) attn_output = _npu_attention_update(self.kv_lora_rank, attn_out_lse)
return self._v_up_proj(attn_output) return self._v_up_proj(attn_output)
def _out_lse_reshape(self, attn_out: torch.Tensor, def _out_lse_reshape(self, attn_out: torch.Tensor, attn_lse: torch.Tensor) -> torch.Tensor:
attn_lse: torch.Tensor) -> torch.Tensor: attn_out = attn_out.contiguous().view(attn_out.shape[0] * attn_out.shape[1], attn_out.shape[2])
attn_out = attn_out.contiguous().view( attn_lse = attn_lse.contiguous().view(attn_lse.shape[0] * attn_lse.shape[1] * attn_lse.shape[2])
attn_out.shape[0] * attn_out.shape[1], attn_out.shape[2])
attn_lse = attn_lse.contiguous().view(
attn_lse.shape[0] * attn_lse.shape[1] * attn_lse.shape[2])
return attn_out, attn_lse return attn_out, attn_lse
def _reorg_kvcache( def _reorg_kvcache(
@@ -706,8 +672,7 @@ class AscendMlaCPImpl(AscendMLAImpl):
assert chunked_context.max_seq_lens is not None assert chunked_context.max_seq_lens is not None
assert chunked_context.chunk_size is not None assert chunked_context.chunk_size is not None
padded_local_chunk_seq_lens_lst = chunked_context.padded_local_chunk_seq_lens[ padded_local_chunk_seq_lens_lst = chunked_context.padded_local_chunk_seq_lens[chunk_idx]
chunk_idx]
local_context_lens_allranks = chunked_context.local_context_lens_allranks local_context_lens_allranks = chunked_context.local_context_lens_allranks
sum_seq_len = chunked_context.cu_seq_lens_lst[chunk_idx][-1] sum_seq_len = chunked_context.cu_seq_lens_lst[chunk_idx][-1]
max_seq_len = chunked_context.max_seq_lens[chunk_idx] max_seq_len = chunked_context.max_seq_lens[chunk_idx]
@@ -720,14 +685,16 @@ class AscendMlaCPImpl(AscendMLAImpl):
cache_kv_c_k_pe = get_pcp_group().all_gather(cache_kv_c_k_pe, 0) cache_kv_c_k_pe = get_pcp_group().all_gather(cache_kv_c_k_pe, 0)
allgatered_kv_c_normed, allgatered_k_pe = cache_kv_c_k_pe.split( allgatered_kv_c_normed, allgatered_k_pe = cache_kv_c_k_pe.split(
[self.kv_lora_rank, self.qk_rope_head_dim], dim=-1) [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1
)
kv_c_segments = [] kv_c_segments = []
k_pe_segments = [] k_pe_segments = []
src_token_idx = 0 src_token_idx = 0
max_seq_len_check = 0 max_seq_len_check = 0
for padded_local_chunk_seq_len, local_context_lens in zip( for padded_local_chunk_seq_len, local_context_lens in zip(
padded_local_chunk_seq_lens_lst, local_context_lens_allranks): padded_local_chunk_seq_lens_lst, local_context_lens_allranks
):
cur_seq_len = 0 cur_seq_len = 0
for rank, local_context_len in enumerate(local_context_lens): for rank, local_context_len in enumerate(local_context_lens):
# Note(qcs): We split the context into multiple chunks, # Note(qcs): We split the context into multiple chunks,
@@ -742,15 +709,12 @@ class AscendMlaCPImpl(AscendMLAImpl):
padded_local_chunk_seq_len, padded_local_chunk_seq_len,
) )
if local_chunk_len != 0: if local_chunk_len != 0:
kv_c_segment = allgatered_kv_c_normed[rank * toks + kv_c_segment = allgatered_kv_c_normed[
src_token_idx:rank * rank * toks + src_token_idx : rank * toks + src_token_idx + local_chunk_len
toks + ]
src_token_idx + k_pe_segment = allgatered_k_pe[
local_chunk_len] rank * toks + src_token_idx : rank * toks + src_token_idx + local_chunk_len
k_pe_segment = allgatered_k_pe[rank * toks + ]
src_token_idx:rank * toks +
src_token_idx +
local_chunk_len]
kv_c_segments.append(kv_c_segment) kv_c_segments.append(kv_c_segment)
k_pe_segments.append(k_pe_segment) k_pe_segments.append(k_pe_segment)
cur_seq_len += local_chunk_len cur_seq_len += local_chunk_len

View File

@@ -1,18 +1,15 @@
from dataclasses import dataclass, field from dataclasses import dataclass, field
from functools import lru_cache from functools import lru_cache
from typing import Any, List, Optional from typing import Any
import torch import torch
import torch.nn.functional as F import torch.nn.functional as F
from vllm.config import VllmConfig, get_current_vllm_config from vllm.config import VllmConfig, get_current_vllm_config
from vllm.distributed.kv_transfer import (get_kv_transfer_group, from vllm.distributed.kv_transfer import get_kv_transfer_group, has_kv_transfer_group, is_v1_kv_transfer_group
has_kv_transfer_group,
is_v1_kv_transfer_group)
from vllm.forward_context import ForwardContext, get_forward_context from vllm.forward_context import ForwardContext, get_forward_context
from vllm.v1.attention.backends.utils import CommonAttentionMetadata from vllm.v1.attention.backends.utils import CommonAttentionMetadata
from vllm_ascend.utils import (AscendDeviceType, get_ascend_config, from vllm_ascend.utils import AscendDeviceType, get_ascend_config, get_ascend_device_type
get_ascend_device_type)
def using_paged_attention(runtime_shape: int, vllm_config: VllmConfig) -> bool: def using_paged_attention(runtime_shape: int, vllm_config: VllmConfig) -> bool:
@@ -21,6 +18,7 @@ def using_paged_attention(runtime_shape: int, vllm_config: VllmConfig) -> bool:
if get_ascend_device_type() == AscendDeviceType.A5: if get_ascend_device_type() == AscendDeviceType.A5:
return False return False
from vllm.config.compilation import CUDAGraphMode from vllm.config.compilation import CUDAGraphMode
cudagraph_mode = vllm_config.compilation_config.cudagraph_mode cudagraph_mode = vllm_config.compilation_config.cudagraph_mode
if cudagraph_mode != CUDAGraphMode.FULL_DECODE_ONLY: if cudagraph_mode != CUDAGraphMode.FULL_DECODE_ONLY:
return False return False
@@ -31,8 +29,7 @@ def using_paged_attention(runtime_shape: int, vllm_config: VllmConfig) -> bool:
@lru_cache(maxsize=1) @lru_cache(maxsize=1)
def enable_cp(): def enable_cp():
prefill_config = get_current_vllm_config().parallel_config prefill_config = get_current_vllm_config().parallel_config
return prefill_config.prefill_context_parallel_size > 1 \ return prefill_config.prefill_context_parallel_size > 1 or prefill_config.decode_context_parallel_size > 1
or prefill_config.decode_context_parallel_size > 1
@dataclass @dataclass
@@ -42,13 +39,14 @@ class AscendPrefillContextParallelMetadata:
Contains index tensors and sequence lengths for PCP operations. Contains index tensors and sequence lengths for PCP operations.
""" """
pcp_allgather_restore_idx: torch.Tensor = None pcp_allgather_restore_idx: torch.Tensor = None
cp_kv_recover_idx_for_chunk: torch.Tensor = None cp_kv_recover_idx_for_chunk: torch.Tensor = None
num_actual_tokens_pcp_padded: int = 0 num_actual_tokens_pcp_padded: int = 0
num_computed_tokens_of_pcp_dcp: Optional[list[list[list[int]]]] = None num_computed_tokens_of_pcp_dcp: list[list[list[int]]] | None = None
q_head_idx_tensor: torch.Tensor = None q_head_idx_tensor: torch.Tensor = None
@@ -85,6 +83,7 @@ class AscendCommonAttentionMetadata(CommonAttentionMetadata):
For many of the tensors we keep both NPU and CPU versions. For many of the tensors we keep both NPU and CPU versions.
""" """
# CPU tensor of sequence lengths for host-side operations. # CPU tensor of sequence lengths for host-side operations.
# E.g., tensor([128, 256, 64]) for 3 requests with different seq lengths. # E.g., tensor([128, 256, 64]) for 3 requests with different seq lengths.
seq_lens_cpu: torch.Tensor = None seq_lens_cpu: torch.Tensor = None
@@ -115,20 +114,17 @@ class AscendCommonAttentionMetadata(CommonAttentionMetadata):
num_input_tokens: int = 0 num_input_tokens: int = 0
# Metadata for Prefill Context Parallelism (PCP) operations. # Metadata for Prefill Context Parallelism (PCP) operations.
prefill_context_parallel_metadata: Optional[ prefill_context_parallel_metadata: AscendPrefillContextParallelMetadata | None = None
AscendPrefillContextParallelMetadata] = None
# TODO: Remove it when vLLM no longer uses this function. # TODO: Remove it when vLLM no longer uses this function.
def unpadded(self, num_actual_tokens: int, def unpadded(self, num_actual_tokens: int, num_actual_reqs: int) -> "AscendCommonAttentionMetadata":
num_actual_reqs: int) -> "AscendCommonAttentionMetadata":
# This only use to eagle now. It will be use to enforce_eager in future. # This only use to eagle now. It will be use to enforce_eager in future.
return AscendCommonAttentionMetadata( return AscendCommonAttentionMetadata(
query_start_loc=self.query_start_loc[: num_actual_reqs + 1], query_start_loc=self.query_start_loc[: num_actual_reqs + 1],
query_start_loc_cpu=self.query_start_loc_cpu[: num_actual_reqs + 1], query_start_loc_cpu=self.query_start_loc_cpu[: num_actual_reqs + 1],
seq_lens=self.seq_lens[:num_actual_reqs], seq_lens=self.seq_lens[:num_actual_reqs],
seq_lens_cpu=self.seq_lens_cpu[:num_actual_reqs], seq_lens_cpu=self.seq_lens_cpu[:num_actual_reqs],
num_computed_tokens_cpu=self. num_computed_tokens_cpu=self.num_computed_tokens_cpu[:num_actual_reqs],
num_computed_tokens_cpu[:num_actual_reqs],
num_reqs=num_actual_reqs, num_reqs=num_actual_reqs,
num_actual_tokens=num_actual_tokens, num_actual_tokens=num_actual_tokens,
max_query_len=self.max_query_len, max_query_len=self.max_query_len,
@@ -144,14 +140,14 @@ class AscendCommonAttentionMetadata(CommonAttentionMetadata):
attn_state=self.attn_state, attn_state=self.attn_state,
graph_pad_size=-1, # It should be -1 when not run in fullgraph mode. graph_pad_size=-1, # It should be -1 when not run in fullgraph mode.
num_input_tokens=self.num_input_tokens, num_input_tokens=self.num_input_tokens,
prefill_context_parallel_metadata=self. prefill_context_parallel_metadata=self.prefill_context_parallel_metadata,
prefill_context_parallel_metadata, max_seq_len=self.max_seq_len,
max_seq_len=self.max_seq_len) )
def filter_chunked_req_indices( def filter_chunked_req_indices(
seq_len: torch.Tensor, seq_len: torch.Tensor,
mask_for_non_zero_chunk: Optional[List[bool]], mask_for_non_zero_chunk: list[bool] | None,
) -> torch.Tensor: ) -> torch.Tensor:
""" """
filter the reqs which are doing real chunk_prefill. filter the reqs which are doing real chunk_prefill.
@@ -162,14 +158,15 @@ def filter_chunked_req_indices(
Returns: Returns:
filtered_indices: the real chunked req's indices filtered_indices: the real chunked req's indices
""" """
assert mask_for_non_zero_chunk is not None and len(seq_len) == len( assert mask_for_non_zero_chunk is not None and len(seq_len) == len(mask_for_non_zero_chunk)
mask_for_non_zero_chunk)
offsets = torch.cumsum(torch.cat([torch.tensor([0]), seq_len[:-1]]), dim=0) offsets = torch.cumsum(torch.cat([torch.tensor([0]), seq_len[:-1]]), dim=0)
filtered_indices = torch.cat([ filtered_indices = torch.cat(
[
torch.arange(offsets[i], offsets[i] + seq_len[i]) torch.arange(offsets[i], offsets[i] + seq_len[i])
for i in range(len(mask_for_non_zero_chunk)) for i in range(len(mask_for_non_zero_chunk))
if mask_for_non_zero_chunk[i] if mask_for_non_zero_chunk[i]
]) ]
)
return filtered_indices return filtered_indices
@@ -195,12 +192,9 @@ def split_decodes_and_prefills(
num_prefill_tokens: The number of tokens in the prefill requests. num_prefill_tokens: The number of tokens in the prefill requests.
""" """
long_seq_metadata = common_attn_metadata.prefill_context_parallel_metadata long_seq_metadata = common_attn_metadata.prefill_context_parallel_metadata
query_lens_pcp_full = long_seq_metadata.query_lens_pcp_full_cpu \ query_lens_pcp_full = long_seq_metadata.query_lens_pcp_full_cpu if long_seq_metadata else None
if long_seq_metadata else None max_query_len_pcp_full = long_seq_metadata.max_query_len_pcp_full if long_seq_metadata else 0
max_query_len_pcp_full = long_seq_metadata.max_query_len_pcp_full \ max_query_len = common_attn_metadata.max_query_len if max_query_len_pcp_full == 0 else max_query_len_pcp_full
if long_seq_metadata else 0
max_query_len = common_attn_metadata.max_query_len \
if max_query_len_pcp_full == 0 else max_query_len_pcp_full
num_reqs = common_attn_metadata.num_reqs num_reqs = common_attn_metadata.num_reqs
num_tokens = common_attn_metadata.num_actual_tokens num_tokens = common_attn_metadata.num_actual_tokens
query_start_loc = common_attn_metadata.query_start_loc_cpu query_start_loc = common_attn_metadata.query_start_loc_cpu
@@ -208,8 +202,7 @@ def split_decodes_and_prefills(
if max_query_len <= decode_threshold: if max_query_len <= decode_threshold:
return num_reqs, 0, num_tokens, 0 return num_reqs, 0, num_tokens, 0
query_lens = (query_start_loc[1:] - query_start_loc[:-1]) \ query_lens = (query_start_loc[1:] - query_start_loc[:-1]) if query_lens_pcp_full is None else query_lens_pcp_full
if query_lens_pcp_full is None else query_lens_pcp_full
is_prefill = query_lens > decode_threshold is_prefill = query_lens > decode_threshold
if not torch.any(is_prefill): if not torch.any(is_prefill):
return num_reqs, 0, num_tokens, 0 return num_reqs, 0, num_tokens, 0
@@ -238,7 +231,7 @@ def wait_for_kv_layer_from_connector(layer_name: str):
def maybe_save_kv_layer_to_connector( def maybe_save_kv_layer_to_connector(
layer_name: str, layer_name: str,
kv_cache_layer: List[torch.Tensor], kv_cache_layer: list[torch.Tensor],
): ):
if not has_kv_transfer_group() or not is_v1_kv_transfer_group(): if not has_kv_transfer_group() or not is_v1_kv_transfer_group():
return return
@@ -264,8 +257,7 @@ def trans_rope_weight(weight, rope_dim):
return weight.contiguous() return weight.contiguous()
nope_part = weight[..., :-rope_dim, :] nope_part = weight[..., :-rope_dim, :]
rope_part = weight[..., -rope_dim:, :] rope_part = weight[..., -rope_dim:, :]
reordered_rope_part = torch.cat( reordered_rope_part = torch.cat((rope_part[..., ::2, :], rope_part[..., 1::2, :]), dim=-2)
(rope_part[..., ::2, :], rope_part[..., 1::2, :]), dim=-2)
return torch.cat((nope_part, reordered_rope_part), dim=-2).contiguous() return torch.cat((nope_part, reordered_rope_part), dim=-2).contiguous()
@@ -278,12 +270,9 @@ def transdata(nd_mat, block_size: tuple = (16, 16)):
nz_mat = torch.permute( nz_mat = torch.permute(
torch.reshape( torch.reshape(
nd_mat, nd_mat,
(r // block_size[0], block_size[0], c // block_size[1], (r // block_size[0], block_size[0], c // block_size[1], block_size[1]),
block_size[1]),
), ),
[2, 0, 1, 3], [2, 0, 1, 3],
) )
nz_mat = torch.reshape( nz_mat = torch.reshape(nz_mat, (nz_mat.shape[0], nz_mat.shape[1] * nz_mat.shape[2], nz_mat.shape[3]))
nz_mat,
(nz_mat.shape[0], nz_mat.shape[1] * nz_mat.shape[2], nz_mat.shape[3]))
return nz_mat return nz_mat

View File

@@ -27,8 +27,12 @@ logger = init_logger(__name__)
if HAS_TRITON: if HAS_TRITON:
from vllm_ascend.ops.triton.batch_invariant.matmul import ( from vllm_ascend.ops.triton.batch_invariant.matmul import (
addmm_batch_invariant, bmm_batch_invariant, linear_batch_invariant, addmm_batch_invariant,
matmul_batch_invariant, mm_batch_invariant) bmm_batch_invariant,
linear_batch_invariant,
matmul_batch_invariant,
mm_batch_invariant,
)
def override_envs_for_invariance(): def override_envs_for_invariance():
@@ -73,10 +77,11 @@ def init_batch_invariance():
if vllm_is_batch_invariant(): if vllm_is_batch_invariant():
if HAS_TRITON: if HAS_TRITON:
logger.info( logger.info(
"Enabling batch-invariant mode for vLLM on Ascend NPU.", ) "Enabling batch-invariant mode for vLLM on Ascend NPU.",
)
override_envs_for_invariance() override_envs_for_invariance()
enable_batch_invariant_mode() enable_batch_invariant_mode()
else: else:
logger.warning( logger.warning(
"Batch-invariant mode requested but Triton is not available." "Batch-invariant mode requested but Triton is not available.skipping batch-invariant initialization.",
"skipping batch-invariant initialization.", ) )

View File

@@ -15,35 +15,26 @@
# limitations under the License. # limitations under the License.
# This file is a part of the vllm-ascend project. # This file is a part of the vllm-ascend project.
# #
from typing import Optional, Type
import torch_npu import torch_npu
from vllm_ascend.utils import AscendDeviceType, get_ascend_device_type from vllm_ascend.utils import AscendDeviceType, get_ascend_device_type
class BaseDeviceAdaptor(object): class BaseDeviceAdaptor:
@classmethod @classmethod
def reshape_and_cache(cls, key, value, key_cache, value_cache, def reshape_and_cache(cls, key, value, key_cache, value_cache, slot_mapping):
slot_mapping): torch_npu._npu_reshape_and_cache(
torch_npu._npu_reshape_and_cache(key=key, key=key, value=value, key_cache=key_cache, value_cache=value_cache, slot_indices=slot_mapping
value=value, )
key_cache=key_cache,
value_cache=value_cache,
slot_indices=slot_mapping)
class A5DeviceAdaptor(BaseDeviceAdaptor): class A5DeviceAdaptor(BaseDeviceAdaptor):
@classmethod @classmethod
def reshape_and_cache(cls, key, value, key_cache, value_cache, def reshape_and_cache(cls, key, value, key_cache, value_cache, slot_mapping):
slot_mapping): torch_npu.npu_scatter_pa_kv_cache(
torch_npu.npu_scatter_pa_kv_cache(key=key, key=key, value=value.contiguous(), key_cache=key_cache, value_cache=value_cache, slot_mapping=slot_mapping
value=value.contiguous(), )
key_cache=key_cache,
value_cache=value_cache,
slot_mapping=slot_mapping)
def get_device_adaptor(): def get_device_adaptor():
@@ -53,4 +44,4 @@ def get_device_adaptor():
return BaseDeviceAdaptor return BaseDeviceAdaptor
DeviceOperator: Optional[Type['BaseDeviceAdaptor']] = get_device_adaptor() DeviceOperator: type["BaseDeviceAdaptor"] | None = get_device_adaptor()

View File

@@ -18,15 +18,16 @@
# #
import dataclasses import dataclasses
import os import os
from collections.abc import Callable
from contextlib import contextmanager from contextlib import contextmanager
from typing import Any, Callable, Dict, Optional, Tuple, Union from typing import Any
import torch import torch
from acl.rt import memcpy # type: ignore # noqa: F401 from acl.rt import memcpy # type: ignore # noqa: F401
from vllm.logger import logger from vllm.logger import logger
def find_loaded_library(lib_name) -> Optional[str]: def find_loaded_library(lib_name) -> str | None:
""" """
According to according to https://man7.org/linux/man-pages/man5/proc_pid_maps.5.html, According to according to https://man7.org/linux/man-pages/man5/proc_pid_maps.5.html,
the file `/proc/self/maps` contains the memory maps of the process, which includes the the file `/proc/self/maps` contains the memory maps of the process, which includes the
@@ -47,20 +48,22 @@ def find_loaded_library(lib_name) -> Optional[str]:
start = found_line.index("/") start = found_line.index("/")
path = found_line[start:].strip() path = found_line[start:].strip()
filename = path.split("/")[-1] filename = path.split("/")[-1]
assert filename.rpartition(".so")[0].startswith(lib_name), \ assert filename.rpartition(".so")[0].startswith(lib_name), f"Unexpected filename: {filename} for library {lib_name}"
f"Unexpected filename: {filename} for library {lib_name}"
return path return path
camem_available = False camem_available = False
try: try:
from vllm_ascend.vllm_ascend_C import ( # type: ignore # noqa: F401 from vllm_ascend.vllm_ascend_C import ( # type: ignore # noqa: F401
init_module, python_create_and_map, python_unmap_and_release) init_module,
python_create_and_map,
python_unmap_and_release,
)
lib_name = find_loaded_library("vllm_ascend_C") lib_name = find_loaded_library("vllm_ascend_C")
camem_available = True camem_available = True
except ImportError as e: except ImportError as e:
logger.warning( logger.warning("Failed to import vllm_ascend_C:%s. Sleep mode will be disabled. ", e)
"Failed to import vllm_ascend_C:%s. Sleep mode will be disabled. ", e)
init_module = None init_module = None
python_create_and_map = None python_create_and_map = None
python_unmap_and_release = None python_unmap_and_release = None
@@ -68,14 +71,14 @@ except ImportError as e:
libcudart = None libcudart = None
# py_device, py_alignedSize, py_d_mem, py_p_memHandle # py_device, py_alignedSize, py_d_mem, py_p_memHandle
HandleType = Tuple[int, int, int, int] HandleType = tuple[int, int, int, int]
@dataclasses.dataclass @dataclasses.dataclass
class AllocationData: class AllocationData:
handle: HandleType handle: HandleType
tag: str tag: str
cpu_backup_tensor: Optional[torch.Tensor] = None cpu_backup_tensor: torch.Tensor | None = None
def create_and_map(allocation_handle: HandleType) -> None: def create_and_map(allocation_handle: HandleType) -> None:
@@ -88,18 +91,18 @@ def unmap_and_release(allocation_handle: HandleType) -> None:
def get_pluggable_allocator( def get_pluggable_allocator(
python_malloc_fn: Callable[[tuple[int, int, int, int]], None], python_malloc_fn: Callable[[tuple[int, int, int, int]], None],
python_free_func: Callable[[int], tuple[int, int, int, int]] python_free_func: Callable[[int], tuple[int, int, int, int]],
) -> torch.npu.memory.NPUPluggableAllocator: ) -> torch.npu.memory.NPUPluggableAllocator:
init_module(python_malloc_fn, python_free_func) init_module(python_malloc_fn, python_free_func)
new_alloc = torch.npu.memory.NPUPluggableAllocator(lib_name, 'my_malloc', new_alloc = torch.npu.memory.NPUPluggableAllocator(lib_name, "my_malloc", "my_free")
'my_free')
return new_alloc return new_alloc
@contextmanager @contextmanager
def use_memory_pool_with_allocator( def use_memory_pool_with_allocator(
python_malloc_fn: Callable[[tuple[int, int, int, int]], None], python_malloc_fn: Callable[[tuple[int, int, int, int]], None],
python_free_func: Callable[[int], tuple[int, int, int, int]]): python_free_func: Callable[[int], tuple[int, int, int, int]],
):
new_alloc = get_pluggable_allocator(python_malloc_fn, python_free_func) new_alloc = get_pluggable_allocator(python_malloc_fn, python_free_func)
mem_pool = torch.npu.memory.MemPool(new_alloc._allocator) mem_pool = torch.npu.memory.MemPool(new_alloc._allocator)
with torch.npu.memory.use_mem_pool(mem_pool): with torch.npu.memory.use_mem_pool(mem_pool):
@@ -127,6 +130,7 @@ class CaMemAllocator:
the global variable will be overwritten and the free callback will the global variable will be overwritten and the free callback will
not work as expected. not work as expected.
""" """
instance = None instance = None
default_tag: str = "default" default_tag: str = "default"
@@ -143,22 +147,22 @@ class CaMemAllocator:
def __init__(self): def __init__(self):
conf = os.environ.get("PYTORCH_NPU_ALLOC_CONF", "") conf = os.environ.get("PYTORCH_NPU_ALLOC_CONF", "")
assert "expandable_segments:True" not in conf, \ assert "expandable_segments:True" not in conf, (
("Expandable segments are not compatible with memory pool. " "Expandable segments are not compatible with memory pool. "
"Please track https://github.com/pytorch/pytorch/issues/147851 " "Please track https://github.com/pytorch/pytorch/issues/147851 "
"for the latest updates.") "for the latest updates."
)
self.pointer_to_data: Dict[int, AllocationData] = {} self.pointer_to_data: dict[int, AllocationData] = {}
self.current_tag: str = CaMemAllocator.default_tag self.current_tag: str = CaMemAllocator.default_tag
self.allocator_and_pools: Dict[str, Any] = {} self.allocator_and_pools: dict[str, Any] = {}
def python_malloc_callback(self, allocation_handle: HandleType) -> None: def python_malloc_callback(self, allocation_handle: HandleType) -> None:
""" """
Internal method to store the allocation data Internal method to store the allocation data
when memory is allocated in the memory pool.""" when memory is allocated in the memory pool."""
py_d_mem = allocation_handle[2] py_d_mem = allocation_handle[2]
self.pointer_to_data[py_d_mem] = AllocationData( self.pointer_to_data[py_d_mem] = AllocationData(allocation_handle, self.current_tag)
allocation_handle, self.current_tag)
return return
def python_free_callback(self, ptr: int) -> HandleType: def python_free_callback(self, ptr: int) -> HandleType:
@@ -170,10 +174,7 @@ class CaMemAllocator:
data.cpu_backup_tensor = None data.cpu_backup_tensor = None
return data.handle return data.handle
def sleep( def sleep(self, offload_tags: tuple[str, ...] | str | None = None) -> None:
self,
offload_tags: Optional[Union[Tuple[str, ...],
str]] = None) -> None:
""" """
Put the allocator in sleep mode. Put the allocator in sleep mode.
All data in the memory allocation with the specified tag will be All data in the memory allocation with the specified tag will be
@@ -194,19 +195,15 @@ class CaMemAllocator:
handle = data.handle handle = data.handle
if data.tag in offload_tags: if data.tag in offload_tags:
size_in_bytes = handle[1] size_in_bytes = handle[1]
cpu_backup_tensor = torch.empty(size_in_bytes, cpu_backup_tensor = torch.empty(size_in_bytes, dtype=torch.uint8, device="cpu", pin_memory=True)
dtype=torch.uint8,
device='cpu',
pin_memory=True)
cpu_ptr = cpu_backup_tensor.data_ptr() cpu_ptr = cpu_backup_tensor.data_ptr()
ACL_MEMCPY_DEVICE_TO_HOST = 2 ACL_MEMCPY_DEVICE_TO_HOST = 2
dest_max = cpu_ptr + size_in_bytes * 2 dest_max = cpu_ptr + size_in_bytes * 2
memcpy(cpu_ptr, dest_max, ptr, size_in_bytes, memcpy(cpu_ptr, dest_max, ptr, size_in_bytes, ACL_MEMCPY_DEVICE_TO_HOST)
ACL_MEMCPY_DEVICE_TO_HOST)
data.cpu_backup_tensor = cpu_backup_tensor data.cpu_backup_tensor = cpu_backup_tensor
unmap_and_release(handle) unmap_and_release(handle)
def wake_up(self, tags: Optional[list[str]] = None) -> None: def wake_up(self, tags: list[str] | None = None) -> None:
""" """
Wake up the allocator from sleep mode. Wake up the allocator from sleep mode.
All data that is previously offloaded will be loaded back to GPU All data that is previously offloaded will be loaded back to GPU
@@ -218,17 +215,15 @@ class CaMemAllocator:
if data.cpu_backup_tensor is not None: if data.cpu_backup_tensor is not None:
cpu_backup_tensor = data.cpu_backup_tensor cpu_backup_tensor = data.cpu_backup_tensor
if cpu_backup_tensor is not None: if cpu_backup_tensor is not None:
size_in_bytes = cpu_backup_tensor.numel( size_in_bytes = cpu_backup_tensor.numel() * cpu_backup_tensor.element_size()
) * cpu_backup_tensor.element_size()
cpu_ptr = cpu_backup_tensor.data_ptr() cpu_ptr = cpu_backup_tensor.data_ptr()
ACL_MEMCPY_HOST_TO_DEVICE = 1 ACL_MEMCPY_HOST_TO_DEVICE = 1
dest_max = ptr + size_in_bytes * 2 dest_max = ptr + size_in_bytes * 2
memcpy(ptr, dest_max, cpu_ptr, size_in_bytes, memcpy(ptr, dest_max, cpu_ptr, size_in_bytes, ACL_MEMCPY_HOST_TO_DEVICE)
ACL_MEMCPY_HOST_TO_DEVICE)
data.cpu_backup_tensor = None data.cpu_backup_tensor = None
@contextmanager @contextmanager
def use_memory_pool(self, tag: Optional[str] = None): def use_memory_pool(self, tag: str | None = None):
""" """
A context manager to use the memory pool. A context manager to use the memory pool.
All memory allocation created inside the context will be allocated All memory allocation created inside the context will be allocated
@@ -243,8 +238,7 @@ class CaMemAllocator:
old_tag = self.current_tag old_tag = self.current_tag
self.current_tag = tag self.current_tag = tag
with use_memory_pool_with_allocator(self.python_malloc_callback, with use_memory_pool_with_allocator(self.python_malloc_callback, self.python_free_callback) as data:
self.python_free_callback) as data:
# start to hit another PyTorch bug in PyTorch 2.6, # start to hit another PyTorch bug in PyTorch 2.6,
# possibly because of gc-related issue w.r.t. the allocator and # possibly because of gc-related issue w.r.t. the allocator and
# the memory pool. # the memory pool.

View File

@@ -19,107 +19,89 @@
# #
import os import os
from typing import Any, Callable, Dict from collections.abc import Callable
from typing import Any
# The begin-* and end* here are used by the documentation generator # The begin-* and end* here are used by the documentation generator
# to extract the used env vars. # to extract the used env vars.
# begin-env-vars-definition # begin-env-vars-definition
env_variables: Dict[str, Callable[[], Any]] = { env_variables: dict[str, Callable[[], Any]] = {
# max compile thread number for package building. Usually, it is set to # max compile thread number for package building. Usually, it is set to
# the number of CPU cores. If not set, the default value is None, which # the number of CPU cores. If not set, the default value is None, which
# means all number of CPU cores will be used. # means all number of CPU cores will be used.
"MAX_JOBS": "MAX_JOBS": lambda: os.getenv("MAX_JOBS", None),
lambda: os.getenv("MAX_JOBS", None),
# The build type of the package. It can be one of the following values: # The build type of the package. It can be one of the following values:
# Release, Debug, RelWithDebugInfo. If not set, the default value is Release. # Release, Debug, RelWithDebugInfo. If not set, the default value is Release.
"CMAKE_BUILD_TYPE": "CMAKE_BUILD_TYPE": lambda: os.getenv("CMAKE_BUILD_TYPE"),
lambda: os.getenv("CMAKE_BUILD_TYPE"),
# The CXX compiler used for compiling the package. If not set, the default # The CXX compiler used for compiling the package. If not set, the default
# value is None, which means the system default CXX compiler will be used. # value is None, which means the system default CXX compiler will be used.
"CXX_COMPILER": "CXX_COMPILER": lambda: os.getenv("CXX_COMPILER", None),
lambda: os.getenv("CXX_COMPILER", None),
# The C compiler used for compiling the package. If not set, the default # The C compiler used for compiling the package. If not set, the default
# value is None, which means the system default C compiler will be used. # value is None, which means the system default C compiler will be used.
"C_COMPILER": "C_COMPILER": lambda: os.getenv("C_COMPILER", None),
lambda: os.getenv("C_COMPILER", None),
# The version of the Ascend chip. It's used for package building. # The version of the Ascend chip. It's used for package building.
# If not set, we will query chip info through `npu-smi`. # If not set, we will query chip info through `npu-smi`.
# Please make sure that the version is correct. # Please make sure that the version is correct.
"SOC_VERSION": "SOC_VERSION": lambda: os.getenv("SOC_VERSION", None),
lambda: os.getenv("SOC_VERSION", None),
# If set, vllm-ascend will print verbose logs during compilation # If set, vllm-ascend will print verbose logs during compilation
"VERBOSE": "VERBOSE": lambda: bool(int(os.getenv("VERBOSE", "0"))),
lambda: bool(int(os.getenv('VERBOSE', '0'))),
# The home path for CANN toolkit. If not set, the default value is # The home path for CANN toolkit. If not set, the default value is
# /usr/local/Ascend/ascend-toolkit/latest # /usr/local/Ascend/ascend-toolkit/latest
"ASCEND_HOME_PATH": "ASCEND_HOME_PATH": lambda: os.getenv("ASCEND_HOME_PATH", None),
lambda: os.getenv("ASCEND_HOME_PATH", None),
# The path for HCCL library, it's used by pyhccl communicator backend. If # The path for HCCL library, it's used by pyhccl communicator backend. If
# not set, the default value is libhccl.so. # not set, the default value is libhccl.so.
"HCCL_SO_PATH": "HCCL_SO_PATH": lambda: os.environ.get("HCCL_SO_PATH", None),
lambda: os.environ.get("HCCL_SO_PATH", None),
# The version of vllm is installed. This value is used for developers who # The version of vllm is installed. This value is used for developers who
# installed vllm from source locally. In this case, the version of vllm is # installed vllm from source locally. In this case, the version of vllm is
# usually changed. For example, if the version of vllm is "0.9.0", but when # usually changed. For example, if the version of vllm is "0.9.0", but when
# it's installed from source, the version of vllm is usually set to "0.9.1". # it's installed from source, the version of vllm is usually set to "0.9.1".
# In this case, developers need to set this value to "0.9.0" to make sure # In this case, developers need to set this value to "0.9.0" to make sure
# that the correct package is installed. # that the correct package is installed.
"VLLM_VERSION": "VLLM_VERSION": lambda: os.getenv("VLLM_VERSION", None),
lambda: os.getenv("VLLM_VERSION", None),
# Whether to enable the model execute time observe profile. Disable it when # Whether to enable the model execute time observe profile. Disable it when
# running vllm ascend in production environment. # running vllm ascend in production environment.
"VLLM_ASCEND_MODEL_EXECUTE_TIME_OBSERVE": "VLLM_ASCEND_MODEL_EXECUTE_TIME_OBSERVE": lambda: bool(
lambda: bool(int(os.getenv("VLLM_ASCEND_MODEL_EXECUTE_TIME_OBSERVE", '0')) int(os.getenv("VLLM_ASCEND_MODEL_EXECUTE_TIME_OBSERVE", "0"))
), ),
# Some models are optimized by vllm ascend. While in some case, e.g. rlhf # Some models are optimized by vllm ascend. While in some case, e.g. rlhf
# training, the optimized model may not be suitable. In this case, set this # training, the optimized model may not be suitable. In this case, set this
# value to False to disable the optimized model. # value to False to disable the optimized model.
"USE_OPTIMIZED_MODEL": "USE_OPTIMIZED_MODEL": lambda: bool(int(os.getenv("USE_OPTIMIZED_MODEL", "1"))),
lambda: bool(int(os.getenv('USE_OPTIMIZED_MODEL', '1'))),
# Whether to enable MatmulAllReduce fusion kernel when tensor parallel is enabled. # Whether to enable MatmulAllReduce fusion kernel when tensor parallel is enabled.
# this feature is supported in A2, and eager mode will get better performance. # this feature is supported in A2, and eager mode will get better performance.
"VLLM_ASCEND_ENABLE_MATMUL_ALLREDUCE": "VLLM_ASCEND_ENABLE_MATMUL_ALLREDUCE": lambda: bool(int(os.getenv("VLLM_ASCEND_ENABLE_MATMUL_ALLREDUCE", "0"))),
lambda: bool(int(os.getenv("VLLM_ASCEND_ENABLE_MATMUL_ALLREDUCE", '0'))),
# Whether to enable FlashComm optimization when tensor parallel is enabled. # Whether to enable FlashComm optimization when tensor parallel is enabled.
# This feature will get better performance when concurrency is large. # This feature will get better performance when concurrency is large.
"VLLM_ASCEND_ENABLE_FLASHCOMM1": "VLLM_ASCEND_ENABLE_FLASHCOMM1": lambda: bool(int(os.getenv("VLLM_ASCEND_ENABLE_FLASHCOMM1", "0"))),
lambda: bool(int(os.getenv("VLLM_ASCEND_ENABLE_FLASHCOMM1", '0'))),
# Whether to enable FLASHCOMM2. Setting it to 0 disables the feature, while setting it to 1 or above enables it. # Whether to enable FLASHCOMM2. Setting it to 0 disables the feature, while setting it to 1 or above enables it.
# The specific value set will be used as the O-matrix TP group size for flashcomm2. # The specific value set will be used as the O-matrix TP group size for flashcomm2.
# For a detailed introduction to the parameters and the differences and applicable scenarios # For a detailed introduction to the parameters and the differences and applicable scenarios
# between this feature and FLASHCOMM1, please refer to the feature guide in the documentation. # between this feature and FLASHCOMM1, please refer to the feature guide in the documentation.
"VLLM_ASCEND_FLASHCOMM2_PARALLEL_SIZE": "VLLM_ASCEND_FLASHCOMM2_PARALLEL_SIZE": lambda: int(os.getenv("VLLM_ASCEND_FLASHCOMM2_PARALLEL_SIZE", 0)),
lambda: int(os.getenv("VLLM_ASCEND_FLASHCOMM2_PARALLEL_SIZE", 0)),
# Whether to enable MLP weight prefetch, only used in small concurrency. # Whether to enable MLP weight prefetch, only used in small concurrency.
"VLLM_ASCEND_ENABLE_PREFETCH_MLP": "VLLM_ASCEND_ENABLE_PREFETCH_MLP": lambda: bool(int(os.getenv("VLLM_ASCEND_ENABLE_PREFETCH_MLP", "0"))),
lambda: bool(int(os.getenv("VLLM_ASCEND_ENABLE_PREFETCH_MLP", '0'))),
# buffer size for gate up prefetch # buffer size for gate up prefetch
"VLLM_ASCEND_MLP_GATE_UP_PREFETCH_SIZE": "VLLM_ASCEND_MLP_GATE_UP_PREFETCH_SIZE": lambda: int(
lambda: int( os.getenv("VLLM_ASCEND_MLP_GATE_UP_PREFETCH_SIZE", 18 * 1024 * 1024)
os.getenv("VLLM_ASCEND_MLP_GATE_UP_PREFETCH_SIZE", 18 * 1024 * 1024)), ),
# buffer size for down proj prefetch # buffer size for down proj prefetch
"VLLM_ASCEND_MLP_DOWN_PREFETCH_SIZE": "VLLM_ASCEND_MLP_DOWN_PREFETCH_SIZE": lambda: int(
lambda: int( os.getenv("VLLM_ASCEND_MLP_DOWN_PREFETCH_SIZE", 18 * 1024 * 1024)
os.getenv("VLLM_ASCEND_MLP_DOWN_PREFETCH_SIZE", 18 * 1024 * 1024)), ),
# Whether to enable msMonitor tool to monitor the performance of vllm-ascend. # Whether to enable msMonitor tool to monitor the performance of vllm-ascend.
"MSMONITOR_USE_DAEMON": "MSMONITOR_USE_DAEMON": lambda: bool(int(os.getenv("MSMONITOR_USE_DAEMON", "0"))),
lambda: bool(int(os.getenv("MSMONITOR_USE_DAEMON", '0'))), "VLLM_ASCEND_ENABLE_MLAPO": lambda: bool(int(os.getenv("VLLM_ASCEND_ENABLE_MLAPO", "0"))),
"VLLM_ASCEND_ENABLE_MLAPO":
lambda: bool(int(os.getenv("VLLM_ASCEND_ENABLE_MLAPO", '0'))),
# Whether to enable weight cast format to FRACTAL_NZ. # Whether to enable weight cast format to FRACTAL_NZ.
# 0: close nz; # 0: close nz;
# 1: only quant case enable nz; # 1: only quant case enable nz;
# 2: enable nz as long as possible. # 2: enable nz as long as possible.
"VLLM_ASCEND_ENABLE_NZ": "VLLM_ASCEND_ENABLE_NZ": lambda: int(os.getenv("VLLM_ASCEND_ENABLE_NZ", 1)),
lambda: int(os.getenv("VLLM_ASCEND_ENABLE_NZ", 1)),
# Decide whether we should enable CP parallelism. # Decide whether we should enable CP parallelism.
"VLLM_ASCEND_ENABLE_CONTEXT_PARALLEL": "VLLM_ASCEND_ENABLE_CONTEXT_PARALLEL": lambda: bool(int(os.getenv("VLLM_ASCEND_ENABLE_CONTEXT_PARALLEL", "0"))),
lambda: bool(int(os.getenv("VLLM_ASCEND_ENABLE_CONTEXT_PARALLEL", '0'))),
# Whether to anbale dynamic EPLB # Whether to anbale dynamic EPLB
"DYNAMIC_EPLB": "DYNAMIC_EPLB": lambda: os.getenv("DYNAMIC_EPLB", "false").lower(),
lambda: os.getenv("DYNAMIC_EPLB", "false").lower(),
# Whether to enable fused mc2(`dispatch_gmm_combine_decode`/`dispatch_ffn_combine` operator) # Whether to enable fused mc2(`dispatch_gmm_combine_decode`/`dispatch_ffn_combine` operator)
# 0, or not set: default ALLTOALL and MC2 will be used. # 0, or not set: default ALLTOALL and MC2 will be used.
# 1: ALLTOALL and MC2 might be replaced by `dispatch_ffn_combine` operator. # 1: ALLTOALL and MC2 might be replaced by `dispatch_ffn_combine` operator.
@@ -127,11 +109,9 @@ env_variables: Dict[str, Callable[[], Any]] = {
# 2: MC2 might be replaced by `dispatch_gmm_combine_decode` operator. # 2: MC2 might be replaced by `dispatch_gmm_combine_decode` operator.
# `dispatch_gmm_combine_decode` can be used only for **decode node** moe layer # `dispatch_gmm_combine_decode` can be used only for **decode node** moe layer
# with W8A8. And MTP layer must be W8A8. # with W8A8. And MTP layer must be W8A8.
"VLLM_ASCEND_ENABLE_FUSED_MC2": "VLLM_ASCEND_ENABLE_FUSED_MC2": lambda: int(os.getenv("VLLM_ASCEND_ENABLE_FUSED_MC2", "0")),
lambda: int(os.getenv("VLLM_ASCEND_ENABLE_FUSED_MC2", '0')),
# Whether to anbale balance scheduling # Whether to anbale balance scheduling
"VLLM_ASCEND_BALANCE_SCHEDULING": "VLLM_ASCEND_BALANCE_SCHEDULING": lambda: bool(int(os.getenv("VLLM_ASCEND_BALANCE_SCHEDULING", "0"))),
lambda: bool(int(os.getenv("VLLM_ASCEND_BALANCE_SCHEDULING", '0'))),
} }
# end-env-vars-definition # end-env-vars-definition