[Lint] Style: Convert vllm-ascend/ to ruff format (Batch #2) (#5977)

### What this PR does / why we need it?
**Scope of Changes**:
| File Path |
| :--- |
| `vllm_ascend/attention/attention_mask.py` |
| `vllm_ascend/attention/attention_v1.py` |
| `vllm_ascend/attention/context_parallel/attention_cp.py` |
| `vllm_ascend/attention/context_parallel/common_cp.py` |
| `vllm_ascend/attention/context_parallel/mla_cp.py` |
| `vllm_ascend/attention/utils.py` |
| `vllm_ascend/batch_invariant.py` |
| `vllm_ascend/device/device_op.py` |
| `vllm_ascend/device_allocator/camem.py` |
| `vllm_ascend/envs.py` |


- vLLM version: v0.13.0
- vLLM main: 2c24bc6996
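
The diffs below are intended to be purely mechanical: ruff format rewrites layout and quoting, not behaviour. As a quick orientation, here is a minimal, self-contained sketch of the two conventions it applies throughout this batch, drawn from the `attention_mask.py` hunk further down (the toy `max_seq_len` value is illustrative only): double quotes, and continuations collapsed onto one line when they fit within `line-length = 120`.

```python
import torch

# Toy inputs, for illustration only.
max_seq_len, dtype = 4, torch.float16
mask_flag = ~torch.ones((max_seq_len, max_seq_len), dtype=torch.bool).tril_()

# Previous style: single quotes and a backslash continuation, e.g.
#   mask_value = float('-inf') if dtype == torch.float16 else 1
#   attn_mask = torch.zeros(size=(max_seq_len, max_seq_len), dtype=dtype) \
#       .masked_fill_(mask_flag, mask_value)

# After ruff format: double quotes; the chained call stays on one line
# because it fits within line-length = 120.
mask_value = float("-inf") if dtype == torch.float16 else 1
attn_mask = torch.zeros(size=(max_seq_len, max_seq_len), dtype=dtype).masked_fill_(mask_flag, mask_value)
print(attn_mask)
```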

---------

Signed-off-by: MrZ20 <2609716663@qq.com>
SILONG ZENG, 2026-01-19 08:59:46 +08:00, committed by GitHub
parent 2b6dc100b5
commit 329961b375
11 changed files with 920 additions and 1045 deletions


@@ -49,11 +49,9 @@ line-length = 120
# Folder to be modified
exclude = [
"tests/**",
"vllm_ascend/_cann_ops_custom",
"vllm_ascend/attention",
"vllm_ascend/attention/mla_v1.py",
"vllm_ascend/attention/sfa_v1.py",
"vllm_ascend/core",
"vllm_ascend/device",
"vllm_ascend/device_allocator",
"vllm_ascend/distributed",
"vllm_ascend/eplb",
"vllm_ascend/kv_offload",
@@ -66,8 +64,6 @@ exclude = [
"vllm_ascend/spec_decode",
"vllm_ascend/worker",
"vllm_ascend/xlite",
"vllm_ascend/envs.py",
"vllm_ascend/batch_invariant.py",
]
[tool.ruff.lint]


@@ -21,21 +21,18 @@ from vllm_ascend.utils import singleton
def _generate_attn_mask(max_seq_len, dtype):
# Construct lower triangle matrix.
mask_flag = torch.ones((max_seq_len, max_seq_len),
dtype=torch.bool).tril_()
mask_flag = torch.ones((max_seq_len, max_seq_len), dtype=torch.bool).tril_()
# Create upper triangle matrix used to mark mask positions.
mask_flag = ~mask_flag
# Currently for fp16 dtype, the mask value should be set to -inf.
# TODO: Eliminate this part in the future.
mask_value = float('-inf') if dtype == torch.float16 else 1
attn_mask = torch.zeros(size=(max_seq_len, max_seq_len), dtype=dtype) \
.masked_fill_(mask_flag, mask_value)
mask_value = float("-inf") if dtype == torch.float16 else 1
attn_mask = torch.zeros(size=(max_seq_len, max_seq_len), dtype=dtype).masked_fill_(mask_flag, mask_value)
return attn_mask
@singleton
class AttentionMaskBuilder:
def __init__(self, device: torch.device):
self.attn_mask_cache = None
self._seq_len_cached = 0
@@ -52,14 +49,13 @@ class AttentionMaskBuilder:
assert self.attn_mask_cache is not None, "Something is wrong in generate_attn_mask."
if self.attn_mask_cache.dtype != dtype:
self.attn_mask_cache = self.attn_mask_cache.to(dtype)
return self.attn_mask_cache[:max_seq_len, :max_seq_len].contiguous(
).to(self.device, non_blocking=True)
return self.attn_mask_cache[:max_seq_len, :max_seq_len].contiguous().to(self.device, non_blocking=True)
def get_splitfuse_attn_mask(self) -> torch.Tensor:
if self.chunked_prefill_attn_mask is None:
self.chunked_prefill_attn_mask = torch.triu(
torch.ones(2048,
2048), diagonal=1).to(torch.int8).to(self.device)
self.chunked_prefill_attn_mask = (
torch.triu(torch.ones(2048, 2048), diagonal=1).to(torch.int8).to(self.device)
)
return self.chunked_prefill_attn_mask
def get_mla_mask(self, dtype: torch.dtype) -> torch.Tensor:
@@ -68,16 +64,13 @@ class AttentionMaskBuilder:
mask_value = torch.finfo(torch.float32).min
else:
mask_value = 1
prefill_mask = torch.triu(
torch.ones(512, 512, device=self.device, dtype=dtype), 1)
self.mla_mask = torch.where(prefill_mask == 1, mask_value,
0).to(dtype)
prefill_mask = torch.triu(torch.ones(512, 512, device=self.device, dtype=dtype), 1)
self.mla_mask = torch.where(prefill_mask == 1, mask_value, 0).to(dtype)
return self.mla_mask
def get_pcp_mla_mask(self, dtype: torch.dtype):
if self.pcp_mla_mask is None or self.pcp_mla_mask.dtype != dtype:
self.pcp_mla_mask = torch.triu(
torch.ones(512, 512, device=self.device, dtype=dtype), 1)
self.pcp_mla_mask = torch.triu(torch.ones(512, 512, device=self.device, dtype=dtype), 1)
return self.pcp_mla_mask
def get_swa_mask(self, dtype: torch.dtype, sliding_window):
@@ -99,4 +92,4 @@ class AttentionMaskBuilder:
if get_pcp_group().world_size > 1:
return self.get_pcp_mla_mask(model_config.dtype)
# Prefill stages use 512x512 mask with appropriate dtype
return self.get_mla_mask(model_config.dtype)
return self.get_mla_mask(model_config.dtype)


@@ -17,7 +17,7 @@
from dataclasses import dataclass
from enum import Enum
from typing import ClassVar, List, Optional, Tuple, Type
from typing import ClassVar
import torch
import torch_npu
@@ -29,32 +29,49 @@ from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.kv_cache_interface import AttentionSpec, CrossAttentionSpec
from vllm_ascend.attention.attention_mask import AttentionMaskBuilder
from vllm_ascend.attention.context_parallel.common_cp import (
AscendMetadataForDecode, AscendMetadataForPrefill)
from vllm_ascend.attention.utils import (AscendCommonAttentionMetadata,
enable_cp, split_decodes_and_prefills,
using_paged_attention)
from vllm_ascend.attention.context_parallel.common_cp import AscendMetadataForDecode, AscendMetadataForPrefill
from vllm_ascend.attention.utils import (
AscendCommonAttentionMetadata,
enable_cp,
split_decodes_and_prefills,
using_paged_attention,
)
from vllm_ascend.compilation.acl_graph import (
get_draft_graph_params, get_graph_params,
update_draft_graph_params_workspaces, update_graph_params_workspaces)
get_draft_graph_params,
get_graph_params,
update_draft_graph_params_workspaces,
update_graph_params_workspaces,
)
from vllm_ascend.device.device_op import DeviceOperator
from vllm_ascend.ops.flashcomm2_oshard_manager import flashcomm2_oshard_manager
from vllm_ascend.utils import vllm_version_is, weak_ref_tensors
# isort: off
if vllm_version_is('0.13.0'):
from vllm.v1.attention.backends.utils import (AttentionCGSupport,
AttentionMetadataBuilder)
if vllm_version_is("0.13.0"):
from vllm.v1.attention.backends.utils import AttentionCGSupport, AttentionMetadataBuilder
from vllm.attention.backends.abstract import ( # type: ignore
AttentionBackend, AttentionImpl, AttentionLayer, AttentionType)
AttentionBackend,
AttentionImpl,
AttentionLayer,
AttentionType,
)
from vllm.attention.backends.registry import ( # type: ignore
AttentionBackendEnum, register_backend)
AttentionBackendEnum,
register_backend,
)
else:
from vllm.v1.attention.backend import ( # type: ignore
AttentionBackend, AttentionCGSupport, AttentionImpl, AttentionLayer,
AttentionType, AttentionMetadataBuilder)
AttentionBackend,
AttentionCGSupport,
AttentionImpl,
AttentionLayer,
AttentionType,
AttentionMetadataBuilder,
)
from vllm.v1.attention.backends.registry import ( # type: ignore
AttentionBackendEnum, register_backend)
AttentionBackendEnum,
register_backend,
)
# isort: on
# default max value of sliding window size
@@ -73,18 +90,18 @@ class AscendAttentionBackend(AttentionBackend):
return "CUSTOM" if not envs_vllm.VLLM_USE_V2_MODEL_RUNNER else "FLASH_ATTN"
@staticmethod
def get_impl_cls() -> Type["AscendAttentionBackendImpl"]:
def get_impl_cls() -> type["AscendAttentionBackendImpl"]:
if enable_cp():
from vllm_ascend.attention.context_parallel.attention_cp import \
AscendAttentionCPImpl
from vllm_ascend.attention.context_parallel.attention_cp import AscendAttentionCPImpl
return AscendAttentionCPImpl
return AscendAttentionBackendImpl
@staticmethod
def get_builder_cls() -> type["AscendAttentionMetadataBuilder"]:
if enable_cp():
from vllm_ascend.attention.context_parallel.attention_cp import \
AscendAttentionCPMetadataBuilder
from vllm_ascend.attention.context_parallel.attention_cp import AscendAttentionCPMetadataBuilder
return AscendAttentionCPMetadataBuilder
return AscendAttentionMetadataBuilder
@@ -94,13 +111,13 @@ class AscendAttentionBackend(AttentionBackend):
block_size: int,
num_kv_heads: int,
head_size: int,
) -> Tuple[int, ...]:
) -> tuple[int, ...]:
return (2, num_blocks, block_size, num_kv_heads, head_size)
@staticmethod
def swap_blocks(
src_kv_cache: List[torch.Tensor],
dst_kv_cache: List[torch.Tensor],
src_kv_cache: list[torch.Tensor],
dst_kv_cache: list[torch.Tensor],
src_to_dst: torch.Tensor,
) -> None:
src_key_cache, src_value_cache = src_kv_cache[0], src_kv_cache[1]
@@ -108,14 +125,12 @@ class AscendAttentionBackend(AttentionBackend):
src_indices = src_to_dst[:, 0]
dst_indices = src_to_dst[:, 1]
dst_key_cache[dst_indices] = src_key_cache[src_indices].to(
dst_key_cache.device)
dst_value_cache[dst_indices] = src_value_cache[src_indices].to(
dst_key_cache.device)
dst_key_cache[dst_indices] = src_key_cache[src_indices].to(dst_key_cache.device)
dst_value_cache[dst_indices] = src_value_cache[src_indices].to(dst_key_cache.device)
@staticmethod
def copy_blocks(
kv_caches: List[torch.Tensor],
kv_caches: list[torch.Tensor],
src_to_dists: torch.Tensor,
) -> None:
src_indices = src_to_dists[:, 0]
@@ -148,8 +163,9 @@ class AscendMetadata:
Contains attention masks, token counts, sequence lengths and KV cache
related properties for attention computation.
"""
# **************************** Basic Properties ************************** #
attn_mask: Optional[torch.Tensor] = None
attn_mask: torch.Tensor | None = None
# Current state of this attention run.
attn_state: AscendAttentionState = AscendAttentionState.ChunkedPrefill
@@ -168,12 +184,12 @@ class AscendMetadata:
# should simplified these parameters once attention schema in vLLM-Ascend
# is unified.
seq_lens: torch.Tensor = None
seq_lens_list: List[int] = None # type: ignore
actual_seq_lengths_q: List[int] = None # type: ignore
seq_lens_list: list[int] = None # type: ignore
actual_seq_lengths_q: list[int] = None # type: ignore
query_start_loc: torch.Tensor = None
# Maximum query length in the batch (None for decoding).
max_query_len: Optional[int] = None
max_query_len: int | None = None
# ********************** KV Cache Related Properties ********************* #
# Block addresses per sequence (Seq id -> list of physical block).
@@ -187,9 +203,9 @@ class AscendMetadata:
# (num_tokens,)
slot_mapping: torch.Tensor = None
# pcp
prefill: Optional[AscendMetadataForPrefill] = None
prefill: AscendMetadataForPrefill | None = None
# dcp
decode_meta: Optional[AscendMetadataForDecode] = None
decode_meta: AscendMetadataForDecode | None = None
causal: bool = True
# runner_type in model_config.
@@ -198,7 +214,7 @@ class AscendMetadata:
reshape_cache_event: torch.npu.Event = None
# sliding window attention mask
swa_mask: Optional[torch.Tensor] = None
swa_mask: torch.Tensor | None = None
class AscendAttentionMetadataBuilder(AttentionMetadataBuilder[AscendMetadata]):
@@ -208,6 +224,7 @@ class AscendAttentionMetadataBuilder(AttentionMetadataBuilder[AscendMetadata]):
Handles attention mask generation and metadata preparation for
Ascend FlashAttention backend.
"""
# Does this backend/builder reorder the batch?
# If not, set this to None. Otherwise set it to the query
# length that will be pulled into the front of the batch.
@@ -226,17 +243,19 @@ class AscendAttentionMetadataBuilder(AttentionMetadataBuilder[AscendMetadata]):
self.compilation_config = vllm_config.compilation_config
self.device = device
self.max_num_blocks_per_req = cdiv(
self.model_config.max_model_len,
AscendAttentionBackend.get_supported_block_size()[0])
self.model_config.max_model_len, AscendAttentionBackend.get_supported_block_size()[0]
)
self.speculative_config = vllm_config.speculative_config
self.decode_threshold = 1
if self.speculative_config:
spec_token_num = self.speculative_config.num_speculative_tokens
self.decode_threshold += spec_token_num
assert self.decode_threshold <= 16, f"decode_threshold exceeded \
assert self.decode_threshold <= 16, (
f"decode_threshold exceeded \
npu_fused_infer_attention_score TND layout's limit of 16, \
got {self.decode_threshold}"
)
AscendAttentionMetadataBuilder.reorder_batch_threshold = self.decode_threshold
@@ -254,8 +273,7 @@ class AscendAttentionMetadataBuilder(AttentionMetadataBuilder[AscendMetadata]):
# @override omitted only because of mypy limitation due to type variable.
return AttentionCGSupport.ALWAYS
def reorder_batch(self, input_batch,
scheduler_output: "SchedulerOutput") -> bool:
def reorder_batch(self, input_batch, scheduler_output: "SchedulerOutput") -> bool:
return False
def build(
@@ -266,12 +284,11 @@ class AscendAttentionMetadataBuilder(AttentionMetadataBuilder[AscendMetadata]):
) -> AscendMetadata:
num_reqs = common_attn_metadata.num_reqs
num_actual_tokens = common_attn_metadata.num_actual_tokens
query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu[:
num_reqs
+ 1]
query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu[: num_reqs + 1]
num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = \
split_decodes_and_prefills(common_attn_metadata, decode_threshold=self.decode_threshold)
num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = split_decodes_and_prefills(
common_attn_metadata, decode_threshold=self.decode_threshold
)
block_table = common_attn_metadata.block_table_tensor
seq_lens = common_attn_metadata.seq_lens_cpu[:num_reqs]
@@ -283,19 +300,17 @@ class AscendAttentionMetadataBuilder(AttentionMetadataBuilder[AscendMetadata]):
attn_state = common_attn_metadata.attn_state
# Get attn_mask and swa_mask from singleton AttentionMaskBuilder
attn_mask = self.attn_mask_builder.get_attention_mask(
self.model_config)
attn_mask = self.attn_mask_builder.get_attention_mask(self.model_config)
swa_mask = None
is_swa = hasattr(self.model_config.hf_text_config, 'sliding_window')
is_swa = hasattr(self.model_config.hf_text_config, "sliding_window")
if self.model_config is not None and is_swa:
swa_mask = self.attn_mask_builder.get_swa_mask(
self.model_config.dtype,
self.model_config.hf_text_config.sliding_window)
self.model_config.dtype, self.model_config.hf_text_config.sliding_window
)
# TODO: Yet another unnecessary H2D while we already have a query_start_loc on device
query_start_loc = query_start_loc_cpu.pin_memory().to(
self.device, non_blocking=True)
query_start_loc = query_start_loc_cpu.pin_memory().to(self.device, non_blocking=True)
attn_metadata = AscendMetadata(
num_actual_tokens=num_actual_tokens,
@@ -313,7 +328,8 @@ class AscendAttentionMetadataBuilder(AttentionMetadataBuilder[AscendMetadata]):
num_prefills=num_prefills,
num_decodes=num_decodes,
causal=common_attn_metadata.causal,
model_runner_type=self.model_config.runner_type)
model_runner_type=self.model_config.runner_type,
)
return attn_metadata
def build_for_graph_capture(
@@ -321,9 +337,7 @@ class AscendAttentionMetadataBuilder(AttentionMetadataBuilder[AscendMetadata]):
common_attn_metadata: AscendCommonAttentionMetadata,
attn_state: AscendAttentionState = AscendAttentionState.DecodeOnly,
):
if attn_state in (AscendAttentionState.DecodeOnly,
AscendAttentionState.ChunkedPrefill):
if attn_state in (AscendAttentionState.DecodeOnly, AscendAttentionState.ChunkedPrefill):
attn_metadata = self.build(
common_prefix_len=0,
common_attn_metadata=common_attn_metadata,
@@ -338,19 +352,18 @@ class AscendAttentionMetadataBuilder(AttentionMetadataBuilder[AscendMetadata]):
class AscendAttentionBackendImpl(AttentionImpl):
def __init__(
self,
num_heads: int,
head_size: int,
scale: float,
num_kv_heads: int,
alibi_slopes: Optional[List[float]],
sliding_window: Optional[int],
alibi_slopes: list[float] | None,
sliding_window: int | None,
kv_cache_dtype: str,
logits_soft_cap: Optional[float],
logits_soft_cap: float | None,
attn_type: str,
kv_sharing_target_layer_name: Optional[str],
kv_sharing_target_layer_name: str | None,
**kwargs,
) -> None:
self.vllm_config = get_current_vllm_config()
@@ -362,9 +375,7 @@ class AscendAttentionBackendImpl(AttentionImpl):
self.kv_cache_dtype = kv_cache_dtype
self.sliding_window = sliding_window
if alibi_slopes is not None:
alibi_slopes = torch.tensor(alibi_slopes,
dtype=torch.float32,
device="npu")
alibi_slopes = torch.tensor(alibi_slopes, dtype=torch.float32, device="npu")
self.alibi_slopes = alibi_slopes
self.attn_type = attn_type
@@ -372,18 +383,24 @@ class AscendAttentionBackendImpl(AttentionImpl):
self.num_queries_per_kv = self.num_heads // self.num_kv_heads
self.key_cache = None
self.value_cache = None
self.is_kv_producer = self.vllm_config.kv_transfer_config is not None and self.vllm_config.kv_transfer_config.is_kv_producer
self.is_kv_producer = (
self.vllm_config.kv_transfer_config is not None and self.vllm_config.kv_transfer_config.is_kv_producer
)
def process_weights_after_loading(self, act_dtype: torch.dtype):
super().process_weights_after_loading(act_dtype)
if flashcomm2_oshard_manager.flashcomm2_oshard_enable():
flashcomm2_oshard_manager.post_process_after_loading()
def full_graph_fia(self, query: torch.Tensor, key: torch.Tensor,
value: torch.Tensor, attn_metadata: AscendMetadata,
output: torch.Tensor) -> torch.Tensor:
key, value, block_size, block_table, actual_seq_lengths_kv \
= self._get_fia_params(key, value, attn_metadata)
def full_graph_fia(
self,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
attn_metadata: AscendMetadata,
output: torch.Tensor,
) -> torch.Tensor:
key, value, block_size, block_table, actual_seq_lengths_kv = self._get_fia_params(key, value, attn_metadata)
num_tokens = attn_metadata.actual_seq_lengths_q[-1]
forward_context = get_forward_context()
@@ -427,12 +444,22 @@ class AscendAttentionBackendImpl(AttentionImpl):
event.reset(stream)
graph_params.events[num_tokens].append(event)
graph_params.attn_params[num_tokens].append(
(weak_ref_tensors(query), weak_ref_tensors(key),
weak_ref_tensors(value), weak_ref_tensors(block_table),
weak_ref_tensors(attn_metadata.attn_mask), block_size,
actual_seq_lengths_kv, actual_seq_lengths_q, self.num_kv_heads,
self.num_heads, self.scale, weak_ref_tensors(output),
weak_ref_tensors(softmax_lse)))
(
weak_ref_tensors(query),
weak_ref_tensors(key),
weak_ref_tensors(value),
weak_ref_tensors(block_table),
weak_ref_tensors(attn_metadata.attn_mask),
block_size,
actual_seq_lengths_kv,
actual_seq_lengths_q,
self.num_kv_heads,
self.num_heads,
self.scale,
weak_ref_tensors(output),
weak_ref_tensors(softmax_lse),
)
)
torch.npu.graph_task_group_begin(stream)
torch_npu.npu_fused_infer_attention_score.out(
@@ -463,7 +490,7 @@ class AscendAttentionBackendImpl(AttentionImpl):
self,
query: torch.Tensor,
attn_metadata: AscendMetadata,
output: Optional[torch.Tensor] = None,
output: torch.Tensor | None = None,
):
graph_params = get_graph_params()
forward_context: ForwardContext = get_forward_context()
@@ -481,7 +508,8 @@ class AscendAttentionBackendImpl(AttentionImpl):
scale_value=self.scale,
block_table=attn_metadata.block_tables,
context_lens=attn_metadata.seq_lens,
out=output)
out=output,
)
update_graph_params_workspaces(num_tokens, workspace)
# Handle graph capturing mode
@@ -491,17 +519,19 @@ class AscendAttentionBackendImpl(AttentionImpl):
event.wait(stream)
event.reset(stream)
graph_params.events[num_tokens].append(event)
graph_params.attn_params[num_tokens].append((
weak_ref_tensors(query),
weak_ref_tensors(self.key_cache),
weak_ref_tensors(self.value_cache),
self.num_kv_heads,
self.num_heads,
self.scale,
attn_metadata.block_tables,
attn_metadata.seq_lens,
weak_ref_tensors(output),
))
graph_params.attn_params[num_tokens].append(
(
weak_ref_tensors(query),
weak_ref_tensors(self.key_cache),
weak_ref_tensors(self.value_cache),
self.num_kv_heads,
self.num_heads,
self.scale,
attn_metadata.block_tables,
attn_metadata.seq_lens,
weak_ref_tensors(output),
)
)
torch.npu.graph_task_group_begin(stream)
torch_npu._npu_paged_attention(
@@ -514,53 +544,54 @@ class AscendAttentionBackendImpl(AttentionImpl):
block_table=attn_metadata.block_tables,
context_lens=attn_metadata.seq_lens,
out=output,
workspace=workspace)
workspace=workspace,
)
handle = torch.npu.graph_task_group_end(stream)
graph_params.handles[num_tokens].append(handle)
return output
def _get_fia_params(self, key: torch.Tensor, value: torch.Tensor,
attn_metadata: AscendMetadata):
def _get_fia_params(self, key: torch.Tensor, value: torch.Tensor, attn_metadata: AscendMetadata):
if attn_metadata.attn_state == AscendAttentionState.PrefillNoCache:
block_size = 128
block_table = None
actual_seq_lengths_kv = attn_metadata.actual_seq_lengths_q
if self.attn_type == AttentionType.ENCODER_DECODER:
actual_seq_lengths_kv = torch.cumsum(attn_metadata.seq_lens,
dim=0).tolist()
elif attn_metadata.attn_state == \
AscendAttentionState.PrefillCacheHit:
actual_seq_lengths_kv = torch.cumsum(attn_metadata.seq_lens, dim=0).tolist()
elif attn_metadata.attn_state == AscendAttentionState.PrefillCacheHit:
batch_size = attn_metadata.seq_lens.shape[0]
block_table = attn_metadata.block_tables[:batch_size, :]
num_block, block_size, _, _ = self.key_cache.shape # type: ignore
key = self.key_cache.view( # type: ignore
num_block, block_size, -1)
num_block, block_size, -1
)
value = self.value_cache.view( # type: ignore
num_block, block_size, -1)
num_block, block_size, -1
)
actual_seq_lengths_kv = attn_metadata.seq_lens_list
elif attn_metadata.attn_state == AscendAttentionState.DecodeOnly:
num_block, block_size, _, _ = self.key_cache.shape # type: ignore
key = self.key_cache.view( # type: ignore
num_block, block_size, -1)
num_block, block_size, -1
)
value = self.value_cache.view( # type: ignore
num_block, block_size, -1)
num_block, block_size, -1
)
block_table = attn_metadata.block_tables
actual_seq_lengths_kv = attn_metadata.seq_lens_list
# chunked prefill.
else:
num_block, block_size, _, _ = self.key_cache.shape # type: ignore
key = self.key_cache.view( # type: ignore
num_block, block_size, -1)
num_block, block_size, -1
)
value = self.value_cache.view( # type: ignore
num_block, block_size, -1)
num_block, block_size, -1
)
block_table = attn_metadata.block_tables
actual_seq_lengths_kv = attn_metadata.seq_lens_list
return key, value, block_size, block_table, actual_seq_lengths_kv
def _forward_fia_slidingwindow(self, query: torch.Tensor,
attn_metadata: AscendMetadata,
output: torch.Tensor):
def _forward_fia_slidingwindow(self, query: torch.Tensor, attn_metadata: AscendMetadata, output: torch.Tensor):
batch_size = attn_metadata.seq_lens.shape[0]
block_size = 128
query = query.view(batch_size, 1, self.num_heads * self.head_size)
@@ -583,34 +614,41 @@ class AscendAttentionBackendImpl(AttentionImpl):
scale=self.scale,
block_table=attn_metadata.block_tables,
actual_seq_lengths=[1] * len(attn_metadata.seq_lens),
actual_seq_lengths_kv=attn_metadata.seq_lens)
actual_seq_lengths_kv=attn_metadata.seq_lens,
)
output = output.view(batch_size, self.num_heads, self.head_size)
return output
def forward_fused_infer_attention(self, query: torch.Tensor,
key: torch.Tensor, value: torch.Tensor,
attn_metadata: AscendMetadata,
output: torch.Tensor):
def forward_fused_infer_attention(
self,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
attn_metadata: AscendMetadata,
output: torch.Tensor,
):
forward_context: ForwardContext = get_forward_context()
# we inherit ForwardContext in model runner v2, when enable model
# runner v2, there is not capturing attribute in forward_context,
# just use getattr to avoid attribute error.
if getattr(forward_context, "capturing", False):
attn_output, num_tokens = self.full_graph_fia(
query, key, value, attn_metadata, output)
attn_output, num_tokens = self.full_graph_fia(query, key, value, attn_metadata, output)
output[:num_tokens] = attn_output[:num_tokens]
return output
if (attn_metadata.attn_state == AscendAttentionState.DecodeOnly
and self.sliding_window is not None
and attn_metadata.seq_lens.shape[0] == query.size(0)):
return self._forward_fia_slidingwindow(query, attn_metadata,
output)
key, value, block_size, block_table, actual_seq_lengths_kv \
= self._get_fia_params(key, value, attn_metadata)
if (
attn_metadata.attn_state == AscendAttentionState.DecodeOnly
and self.sliding_window is not None
and attn_metadata.seq_lens.shape[0] == query.size(0)
):
return self._forward_fia_slidingwindow(query, attn_metadata, output)
key, value, block_size, block_table, actual_seq_lengths_kv = self._get_fia_params(key, value, attn_metadata)
num_tokens = attn_metadata.actual_seq_lengths_q[-1]
query = query[:num_tokens]
if attn_metadata.attn_state == AscendAttentionState.PrefillNoCache and self.attn_type != AttentionType.ENCODER_DECODER:
if (
attn_metadata.attn_state == AscendAttentionState.PrefillNoCache
and self.attn_type != AttentionType.ENCODER_DECODER
):
key = key[:num_tokens]
value = value[:num_tokens]
# Get workspace from cache or calculate it if not present.
@@ -630,8 +668,7 @@ class AscendAttentionBackendImpl(AttentionImpl):
sparse_mode=3,
)
attn_output = attn_output.view(num_tokens, self.num_heads,
self.head_size)
attn_output = attn_output.view(num_tokens, self.num_heads, self.head_size)
output[:num_tokens] = attn_output[:num_tokens]
return output
@@ -639,26 +676,32 @@ class AscendAttentionBackendImpl(AttentionImpl):
self,
query: torch.Tensor,
attn_metadata: AscendMetadata,
output: Optional[torch.Tensor] = None,
output: torch.Tensor | None = None,
) -> torch.Tensor:
forward_context: ForwardContext = get_forward_context()
if forward_context.capturing:
return self.full_graph_pa(query, attn_metadata, output)
torch_npu._npu_paged_attention(query=query,
key_cache=self.key_cache,
value_cache=self.value_cache,
num_kv_heads=self.num_kv_heads,
num_heads=self.num_heads,
scale_value=self.scale,
block_table=attn_metadata.block_tables,
context_lens=attn_metadata.seq_lens,
out=output)
torch_npu._npu_paged_attention(
query=query,
key_cache=self.key_cache,
value_cache=self.value_cache,
num_kv_heads=self.num_kv_heads,
num_heads=self.num_heads,
scale_value=self.scale,
block_table=attn_metadata.block_tables,
context_lens=attn_metadata.seq_lens,
out=output,
)
return output
def _forward_encoder_attention(self, query: torch.Tensor,
key: torch.Tensor, value: torch.Tensor,
attn_metadata: AscendMetadata,
_: torch.Tensor) -> torch.Tensor:
def _forward_encoder_attention(
self,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
attn_metadata: AscendMetadata,
_: torch.Tensor,
) -> torch.Tensor:
assert attn_metadata is not None
if attn_metadata.causal:
@@ -692,26 +735,23 @@ class AscendAttentionBackendImpl(AttentionImpl):
self,
key: torch.Tensor,
value: torch.Tensor,
kv_cache: Tuple[torch.Tensor],
kv_cache: tuple[torch.Tensor],
attn_metadata: AscendMetadata,
):
if len(kv_cache) > 1:
if self.is_kv_producer:
attn_metadata.reshape_cache_event = torch.npu.Event()
if self.key_cache is None:
self.key_cache, self.value_cache = kv_cache[0], kv_cache[1]
slots = attn_metadata.slot_mapping
encoder_decoder = (self.attn_type == AttentionType.ENCODER_DECODER)
encoder_decoder = self.attn_type == AttentionType.ENCODER_DECODER
DeviceOperator.reshape_and_cache(
key=key[:attn_metadata.num_actual_tokens]
if not encoder_decoder else key,
value=value[:attn_metadata.num_actual_tokens]
if not encoder_decoder else value,
key=key[: attn_metadata.num_actual_tokens] if not encoder_decoder else key,
value=value[: attn_metadata.num_actual_tokens] if not encoder_decoder else value,
key_cache=self.key_cache,
value_cache=self.value_cache,
slot_mapping=slots[:attn_metadata.num_actual_tokens]
if not encoder_decoder else slots)
slot_mapping=slots[: attn_metadata.num_actual_tokens] if not encoder_decoder else slots,
)
if self.is_kv_producer:
attn_metadata.reshape_cache_event.record()
return key, value
@@ -721,18 +761,19 @@ class AscendAttentionBackendImpl(AttentionImpl):
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
kv_cache: Tuple[torch.Tensor],
kv_cache: tuple[torch.Tensor],
attn_metadata: AscendMetadata,
output: torch.Tensor,
):
num_tokens = query.shape[0]
if (attn_metadata.attn_state == AscendAttentionState.DecodeOnly
and using_paged_attention(num_tokens, self.vllm_config)
and self.sliding_window is None):
if (
attn_metadata.attn_state == AscendAttentionState.DecodeOnly
and using_paged_attention(num_tokens, self.vllm_config)
and self.sliding_window is None
):
output = self.forward_paged_attention(query, attn_metadata, output)
else:
output = self.forward_fused_infer_attention(
query, key, value, attn_metadata, output)
output = self.forward_fused_infer_attention(query, key, value, attn_metadata, output)
return output
@@ -742,11 +783,11 @@ class AscendAttentionBackendImpl(AttentionImpl):
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
kv_cache: Tuple[torch.Tensor],
kv_cache: tuple[torch.Tensor],
attn_metadata: AscendMetadata,
output: Optional[torch.Tensor] = None,
output_scale: Optional[torch.Tensor] = None,
output_block_scale: Optional[torch.Tensor] = None,
output: torch.Tensor | None = None,
output_scale: torch.Tensor | None = None,
output_block_scale: torch.Tensor | None = None,
) -> torch.Tensor:
"""Forward pass with Ascend attention.
Args:
@@ -762,23 +803,18 @@ class AscendAttentionBackendImpl(AttentionImpl):
assert output is not None, "Output tensor must be provided."
if output_scale is not None or output_block_scale is not None:
raise NotImplementedError(
"fused output quantization is not yet supported"
" for AscendAttentionBackendImpl")
raise NotImplementedError("fused output quantization is not yet supported for AscendAttentionBackendImpl")
assert layer._k_scale_float == 1.0 and layer._v_scale_float == 1.0
num_tokens = query.shape[0]
if attn_metadata is None:
return output.fill_(0)
if key is not None and value is not None:
key, value = self.reshape_and_cache(key, value, kv_cache,
attn_metadata)
key, value = self.reshape_and_cache(key, value, kv_cache, attn_metadata)
# pooling model branch
if attn_metadata.model_runner_type == "pooling":
attn_output = self._forward_encoder_attention(
query, key, value, attn_metadata, output)
attn_output = self._forward_encoder_attention(query, key, value, attn_metadata, output)
output[:num_tokens] = attn_output[:num_tokens]
return output
output = self.forward_impl(query, key, value, kv_cache, attn_metadata,
output)
output = self.forward_impl(query, key, value, kv_cache, attn_metadata, output)
return output

File diff suppressed because it is too large.


@@ -1,12 +1,9 @@
from dataclasses import dataclass
from typing import Optional
import torch
import torch.distributed as dist
import torch_npu
from vllm.distributed import (get_dcp_group,
get_decode_context_model_parallel_world_size,
get_pcp_group)
from vllm.distributed import get_dcp_group, get_decode_context_model_parallel_world_size, get_pcp_group
@dataclass
@@ -17,6 +14,7 @@ class AscendPCPMetadata:
Stores index tensors and sequence lengths for routing attention
computations across PCP ranks during long sequence processing.
"""
q_head_idx: torch.Tensor = None
q_tail_idx: torch.Tensor = None
kv_with_q_head_nomask_idx: torch.Tensor = None
@@ -27,7 +25,7 @@ class AscendPCPMetadata:
head_attn_nomask_seqlens: torch.Tensor = None
tail_attn_nomask_seqlens: torch.Tensor = None
q_full_idx: torch.Tensor = None
pcp_allgather_restore_idx: Optional[list[int]] = None
pcp_allgather_restore_idx: list[int] | None = None
@dataclass
@@ -37,6 +35,7 @@ class CPChunkedContextMetadata:
Extends chunked prefill with per-rank chunk information for PCP/DCP.
"""
# For handling chunked prefill
cu_seq_lens: torch.Tensor
starts: torch.Tensor
@@ -47,48 +46,51 @@ class CPChunkedContextMetadata:
chunk_seq_lens_npu: torch.Tensor
# for mla DCP & PCP
padded_chunk_seq_lens_npu: torch.Tensor = None
padded_local_chunk_seq_lens: Optional[list[list[int]]] = None
local_context_lens_allranks: Optional[list[list[int]]] = None
padded_local_chunk_seq_lens: list[list[int]] | None = None
local_context_lens_allranks: list[list[int]] | None = None
padded_local_cu_seq_lens: torch.Tensor = None
cu_seq_lens_lst: Optional[list[list[int]]] = None
chunk_size: Optional[int] = None
cu_seq_lens_lst: list[list[int]] | None = None
chunk_size: int | None = None
@dataclass
class AscendMetadataForPrefill:
""" Prefill-specific metadata for Ascend attention with Context Parallelism."""
"""Prefill-specific metadata for Ascend attention with Context Parallelism."""
@dataclass
class ChunkedContextMetadata:
"""Metadata for chunked context processing within prefill phase."""
actual_chunk_seq_lengths: torch.Tensor
actual_seq_lengths_kv: torch.Tensor
starts: torch.Tensor
chunk_seq_mask_filtered_indices: torch.Tensor
chunked_req_mask: Optional[list[bool]] = None
local_context_lens_allranks: Optional[list[list[int]]] = None
cp_kv_recover_idx_for_chunk: Optional[list[int]] = None
kv_inverse_idx_for_chunk: Optional[list[int]] = None
batch_chunk_seq_mask: Optional[list[bool]] = None
local_total_toks: Optional[int] = None
chunked_req_mask: list[bool] | None = None
local_context_lens_allranks: list[list[int]] | None = None
cp_kv_recover_idx_for_chunk: list[int] | None = None
kv_inverse_idx_for_chunk: list[int] | None = None
batch_chunk_seq_mask: list[bool] | None = None
local_total_toks: int | None = None
""" Prefill Specific Metadata for Ascend"""
pcp_metadata: Optional[AscendPCPMetadata] = None
chunked_context: Optional[ChunkedContextMetadata] = None
pcp_metadata: AscendPCPMetadata | None = None
chunked_context: ChunkedContextMetadata | None = None
block_tables: torch.Tensor = None
actual_seq_lengths_q: torch.Tensor = None
@dataclass
class AscendMetadataForDecode:
""" Decode-specific metadata for Ascend attention with Context Parallelism."""
num_computed_tokens_of_pcp_dcp: Optional[list[list[list[int]]]] = None
"""Decode-specific metadata for Ascend attention with Context Parallelism."""
num_computed_tokens_of_pcp_dcp: list[list[list[int]]] | None = None
batch_seq_mask: torch.Tensor = None
block_tables: torch.Tensor = None
def _process_attn_out_lse(attn_output: torch.Tensor, softmax_lse: torch.Tensor,
batch_seq_mask: torch.Tensor) -> torch.Tensor:
def _process_attn_out_lse(
attn_output: torch.Tensor, softmax_lse: torch.Tensor, batch_seq_mask: torch.Tensor
) -> torch.Tensor:
pcp_size = get_pcp_group().world_size
dcp_size = get_decode_context_model_parallel_world_size()
dcp_group = get_dcp_group().device_group if dcp_size > 1 else None
@@ -104,21 +106,17 @@ def _process_attn_out_lse(attn_output: torch.Tensor, softmax_lse: torch.Tensor,
# permute: [bs, num_heads, v_head_dim+1] -> [num_heads, v_head_dim+1, bs]
attn_out_lse = attn_out_lse.permute([1, 2, 0]).contiguous()
attn_out_lse_all2all = torch.empty_like(attn_out_lse)
dist.all_to_all_single(attn_out_lse_all2all,
attn_out_lse,
group=dcp_group)
dist.all_to_all_single(attn_out_lse_all2all, attn_out_lse, group=dcp_group)
attn_out_lse = attn_out_lse_all2all.permute([2, 0, 1])
if pcp_size > 1:
# AllGather out&lse within CP group
attn_out_lse = get_pcp_group().all_gather(attn_out_lse.contiguous(),
dim=0)
attn_out_lse = get_pcp_group().all_gather(attn_out_lse.contiguous(), dim=0)
return attn_out_lse
def _npu_attention_update(head_size,
attn_out_lse: torch.Tensor) -> torch.Tensor:
def _npu_attention_update(head_size, attn_out_lse: torch.Tensor) -> torch.Tensor:
pcp_size = get_pcp_group().world_size
dcp_size = get_decode_context_model_parallel_world_size()
# [PCP * S, DCP * H, D+1]
@@ -134,8 +132,7 @@ def _npu_attention_update(head_size,
# Flatten [N, S, H, D+1], N = pcp_size * dcp_size
x = x.view(-1, S, H, D_plus_1)
# Split out lse
out_flat, lse_flat = torch.split(x, [D, 1],
dim=-1) # [N, S, H, D], [N, S, H, 1]
out_flat, lse_flat = torch.split(x, [D, 1], dim=-1) # [N, S, H, D], [N, S, H, 1]
# out: [N, S, H, D] -> [N, S*H, D]
# lse: [N, S, H, 1] -> [N, S*H]
out_flat = out_flat.flatten(1, 2) # [N, S*H, D]


@@ -1,35 +1,43 @@
from typing import Optional, Tuple, TypeVar
from typing import TypeVar
import numpy as np
import torch
import torch_npu
from vllm.config import VllmConfig
from vllm.distributed import (get_dcp_group,
get_decode_context_model_parallel_rank,
get_decode_context_model_parallel_world_size,
get_pcp_group)
from vllm.distributed import (
get_dcp_group,
get_decode_context_model_parallel_rank,
get_decode_context_model_parallel_world_size,
get_pcp_group,
)
from vllm.forward_context import ForwardContext, get_forward_context
from vllm.utils.math_utils import cdiv
from vllm.v1.kv_cache_interface import AttentionSpec, MLAAttentionSpec
# isort: off
from vllm_ascend.attention.mla_v1 import (
AscendMLADecodeMetadata, AscendMLAImpl, AscendMLAMetadata,
AscendMLAMetadataBuilder, AscendMLAPrefillMetadata,
DecodeMLAPreprocessResult, PrefillMLAPreprocessResult,
BUILD_METADATA_STEP_PREFILL)
#isort: on
AscendMLADecodeMetadata,
AscendMLAImpl,
AscendMLAMetadata,
AscendMLAMetadataBuilder,
AscendMLAPrefillMetadata,
DecodeMLAPreprocessResult,
PrefillMLAPreprocessResult,
BUILD_METADATA_STEP_PREFILL,
)
# isort: on
from vllm_ascend.attention.utils import (AscendCommonAttentionMetadata)
from vllm_ascend.attention.context_parallel.common_cp import (
AscendPCPMetadata, CPChunkedContextMetadata, _process_attn_out_lse,
_npu_attention_update)
from vllm_ascend.compilation.acl_graph import (get_draft_graph_params,
get_graph_params,
update_graph_params_workspaces)
from vllm_ascend.utils import weak_ref_tensors, vllm_version_is
AscendPCPMetadata,
CPChunkedContextMetadata,
_npu_attention_update,
_process_attn_out_lse,
)
from vllm_ascend.attention.utils import AscendCommonAttentionMetadata
from vllm_ascend.compilation.acl_graph import get_draft_graph_params, get_graph_params, update_graph_params_workspaces
from vllm_ascend.utils import vllm_version_is, weak_ref_tensors
if vllm_version_is('0.13.0'):
if vllm_version_is("0.13.0"):
from vllm.v1.attention.backends.utils import AttentionCGSupport
else:
from vllm.v1.attention.backend import AttentionCGSupport
@@ -54,28 +62,21 @@ class AscendMlaCPMetadataBuilder(AscendMLAMetadataBuilder):
metadata_cls: type[AscendMLAMetadata] | None = None,
supports_dcp_with_varlen: bool = False,
):
super().__init__(kv_cache_spec, layer_names, vllm_config, device,
metadata_cls, supports_dcp_with_varlen)
super().__init__(kv_cache_spec, layer_names, vllm_config, device, metadata_cls, supports_dcp_with_varlen)
self.pcp_size = get_pcp_group().world_size
self.pcp_rank = get_pcp_group(
).rank_in_group if self.pcp_size > 1 else 0
self.pcp_rank = get_pcp_group().rank_in_group if self.pcp_size > 1 else 0
self.dcp_size = get_decode_context_model_parallel_world_size()
self.dcp_rank = get_decode_context_model_parallel_rank(
) if self.dcp_size > 1 else 0
self.dcp_rank = get_decode_context_model_parallel_rank() if self.dcp_size > 1 else 0
self.cp_local_block_size = vllm_config.parallel_config.cp_kv_cache_interleave_size
self.cp_virtual_block_size = self.cp_local_block_size * self.dcp_size * self.pcp_size
scheduler_config = vllm_config.scheduler_config
decode_max_num_seqs = getattr(scheduler_config, 'decode_max_num_seqs',
0)
decode_max_num_seqs = getattr(scheduler_config, "decode_max_num_seqs", 0)
max_num_seqs = max(scheduler_config.max_num_seqs, decode_max_num_seqs)
self.batch_seq_mask_buf = torch.empty(max_num_seqs *
self.decode_threshold,
dtype=torch.uint8,
device=device)
self.block_size = (self.block_size *
self.cp_virtual_block_size) // np.gcd(
self.block_size, self.cp_virtual_block_size)
self.batch_seq_mask_buf = torch.empty(max_num_seqs * self.decode_threshold, dtype=torch.uint8, device=device)
self.block_size = (self.block_size * self.cp_virtual_block_size) // np.gcd(
self.block_size, self.cp_virtual_block_size
)
def build(
self,
@@ -85,15 +86,10 @@ class AscendMlaCPMetadataBuilder(AscendMLAMetadataBuilder):
) -> AscendMLAMetadata:
metadata_cls = super().build(common_prefix_len, common_attn_metadata)
if self.num_prefills == 0 and self.pcp_size > 1:
self.slot_mapping[:self.
num_decode_tokens] = self.slot_mapping[:self.
num_decode_tokens
* self.
pcp_size:
self.
pcp_size]
self.slot_mapping[self.num_decode_tokens:self.num_decode_tokens *
self.pcp_size].fill_(-1)
self.slot_mapping[: self.num_decode_tokens] = self.slot_mapping[
: self.num_decode_tokens * self.pcp_size : self.pcp_size
]
self.slot_mapping[self.num_decode_tokens : self.num_decode_tokens * self.pcp_size].fill_(-1)
metadata_cls.slot_mapping = self.slot_mapping
return metadata_cls
@@ -118,8 +114,8 @@ class AscendMlaCPMetadataBuilder(AscendMLAMetadataBuilder):
# In dcp only spec decode graph padding case,
# num_actual_tokens_pcp_padded may be less than num_actual_tokens
self.num_actual_tokens = max(
long_seq_metadata.num_actual_tokens_pcp_padded,
common_attn_metadata.num_actual_tokens)
long_seq_metadata.num_actual_tokens_pcp_padded, common_attn_metadata.num_actual_tokens
)
def build_cp_metadata(
self,
@@ -131,30 +127,23 @@ class AscendMlaCPMetadataBuilder(AscendMLAMetadataBuilder):
return AscendPCPMetadata(
q_head_idx=common_long_seq_metadata.q_head_idx_tensor,
q_tail_idx=common_long_seq_metadata.q_tail_idx_tensor,
kv_with_q_head_nomask_idx=common_long_seq_metadata.
kv_with_q_head_nomask_idx_tensor,
kv_with_q_head_mask_idx=common_long_seq_metadata.
kv_with_q_head_mask_idx_tensor,
kv_with_q_tail_nomask_idx=common_long_seq_metadata.
kv_with_q_tail_nomask_idx_tensor,
kv_with_q_tail_mask_idx=common_long_seq_metadata.
kv_with_q_tail_mask_idx_tensor,
kv_with_q_head_nomask_idx=common_long_seq_metadata.kv_with_q_head_nomask_idx_tensor,
kv_with_q_head_mask_idx=common_long_seq_metadata.kv_with_q_head_mask_idx_tensor,
kv_with_q_tail_nomask_idx=common_long_seq_metadata.kv_with_q_tail_nomask_idx_tensor,
kv_with_q_tail_mask_idx=common_long_seq_metadata.kv_with_q_tail_mask_idx_tensor,
attn_mask_seqlens=common_long_seq_metadata.attn_mask_seqlens,
head_attn_nomask_seqlens=common_long_seq_metadata.
head_attn_nomask_seqlens,
tail_attn_nomask_seqlens=common_long_seq_metadata.
tail_attn_nomask_seqlens,
head_attn_nomask_seqlens=common_long_seq_metadata.head_attn_nomask_seqlens,
tail_attn_nomask_seqlens=common_long_seq_metadata.tail_attn_nomask_seqlens,
q_full_idx=common_long_seq_metadata.q_full_idx,
pcp_allgather_restore_idx=common_long_seq_metadata.
pcp_allgather_restore_idx)
pcp_allgather_restore_idx=common_long_seq_metadata.pcp_allgather_restore_idx,
)
def build_chunked_metadata(
self,
common_prefix_len: int,
common_attn_metadata: AscendCommonAttentionMetadata,
):
chunked_context_metadata = super().build_chunked_metadata(
common_prefix_len, common_attn_metadata)
chunked_context_metadata = super().build_chunked_metadata(common_prefix_len, common_attn_metadata)
if chunked_context_metadata is None:
return None
@@ -162,33 +151,37 @@ class AscendMlaCPMetadataBuilder(AscendMLAMetadataBuilder):
assert long_seq_metadata is not None
num_computed_tokens_of_pcp_dcp = long_seq_metadata.num_computed_tokens_of_pcp_dcp
assert num_computed_tokens_of_pcp_dcp is not None
local_context_lens_allranks = torch.tensor(
num_computed_tokens_of_pcp_dcp[self.num_decodes_flatten:]).reshape(
-1, self.dcp_size * self.pcp_size)
local_context_lens_allranks = torch.tensor(num_computed_tokens_of_pcp_dcp[self.num_decodes_flatten :]).reshape(
-1, self.dcp_size * self.pcp_size
)
# Note(qcs): The max local context lengths
# padded to `cp_local_block_size`.
padded_local_context_lens_cpu = (cdiv(
self.context_lens_cpu,
self.cp_virtual_block_size,
) * self.cp_local_block_size)
padded_local_max_context_chunk_across_ranks = (cdiv(
self.max_context_chunk,
self.cp_virtual_block_size,
) * self.cp_local_block_size)
local_chunk_starts = (torch.arange(
self.num_chunks, dtype=torch.int32).unsqueeze(1).expand(
-1, self.num_prefills) *
padded_local_max_context_chunk_across_ranks)
padded_local_context_lens_cpu = (
cdiv(
self.context_lens_cpu,
self.cp_virtual_block_size,
)
* self.cp_local_block_size
)
padded_local_max_context_chunk_across_ranks = (
cdiv(
self.max_context_chunk,
self.cp_virtual_block_size,
)
* self.cp_local_block_size
)
local_chunk_starts = (
torch.arange(self.num_chunks, dtype=torch.int32).unsqueeze(1).expand(-1, self.num_prefills)
* padded_local_max_context_chunk_across_ranks
)
local_chunk_ends = torch.min(
padded_local_context_lens_cpu.unsqueeze(0),
local_chunk_starts + padded_local_max_context_chunk_across_ranks,
)
padded_local_chunk_seq_lens = (local_chunk_ends -
local_chunk_starts).clamp(min=0)
padded_local_cu_chunk_seq_lens_cpu = torch.zeros(self.num_chunks,
self.num_prefills + 1,
dtype=torch.int32,
pin_memory=True)
padded_local_chunk_seq_lens = (local_chunk_ends - local_chunk_starts).clamp(min=0)
padded_local_cu_chunk_seq_lens_cpu = torch.zeros(
self.num_chunks, self.num_prefills + 1, dtype=torch.int32, pin_memory=True
)
torch.cumsum(
padded_local_chunk_seq_lens,
dim=1,
@@ -197,8 +190,7 @@ class AscendMlaCPMetadataBuilder(AscendMLAMetadataBuilder):
)
chunked_metadata = CPChunkedContextMetadata(
cu_seq_lens=chunked_context_metadata.cu_seq_lens,
starts=local_chunk_starts.pin_memory().to(self.device,
non_blocking=True),
starts=local_chunk_starts.pin_memory().to(self.device, non_blocking=True),
seq_tot=padded_local_chunk_seq_lens.sum(dim=1).tolist(),
max_seq_lens=chunked_context_metadata.max_seq_lens,
chunk_seq_lens=self.chunk_seq_lens,
@@ -207,18 +199,14 @@ class AscendMlaCPMetadataBuilder(AscendMLAMetadataBuilder):
padded_chunk_seq_lens_npu=padded_local_chunk_seq_lens.npu(),
padded_local_chunk_seq_lens=padded_local_chunk_seq_lens.tolist(),
local_context_lens_allranks=local_context_lens_allranks.tolist(),
padded_local_cu_seq_lens=padded_local_cu_chunk_seq_lens_cpu.
pin_memory().to(self.device, non_blocking=True),
padded_local_cu_seq_lens=padded_local_cu_chunk_seq_lens_cpu.pin_memory().to(self.device, non_blocking=True),
cu_seq_lens_lst=self.cu_seq_lens_cpu.tolist(),
chunk_size=padded_local_max_context_chunk_across_ranks,
)
return chunked_metadata
def get_block_table_size(
self, common_attn_metadata: AscendCommonAttentionMetadata,
build_metadata_step: int):
self.num_decodes_flatten = self.query_lens[:self.num_decodes].sum(
).item()
def get_block_table_size(self, common_attn_metadata: AscendCommonAttentionMetadata, build_metadata_step: int):
self.num_decodes_flatten = self.query_lens[: self.num_decodes].sum().item()
if build_metadata_step == BUILD_METADATA_STEP_PREFILL:
# For pcp + spec decode, we flatten seq_lens and block_table
# to avoid irregular attn_mask shape
@@ -231,12 +219,9 @@ class AscendMlaCPMetadataBuilder(AscendMLAMetadataBuilder):
common_prefix_len: int,
common_attn_metadata: AscendCommonAttentionMetadata,
) -> AscendMLAPrefillMetadata:
prefill_metadata = super().build_prefill_metadata(
common_prefix_len, common_attn_metadata)
prefill_metadata.pcp_metadata = self.build_cp_metadata(
common_prefix_len, common_attn_metadata)
prefill_metadata.block_table = self.block_table[
self.num_decodes_flatten:, ...]
prefill_metadata = super().build_prefill_metadata(common_prefix_len, common_attn_metadata)
prefill_metadata.pcp_metadata = self.build_cp_metadata(common_prefix_len, common_attn_metadata)
prefill_metadata.block_table = self.block_table[self.num_decodes_flatten :, ...]
return prefill_metadata
def build_decode_metadata(
@@ -244,24 +229,20 @@ class AscendMlaCPMetadataBuilder(AscendMLAMetadataBuilder):
common_prefix_len: int,
common_attn_metadata: AscendCommonAttentionMetadata,
) -> AscendMLADecodeMetadata:
decode_metadata = super().build_decode_metadata(
common_prefix_len, common_attn_metadata)
decode_metadata = super().build_decode_metadata(common_prefix_len, common_attn_metadata)
long_seq_metadata = common_attn_metadata.prefill_context_parallel_metadata
assert long_seq_metadata is not None
num_computed_tokens_of_pcp_dcp = long_seq_metadata.num_computed_tokens_of_pcp_dcp
assert num_computed_tokens_of_pcp_dcp is not None
# [bs, pcp_size, dcp_size]
num_computed_tokens_of_cp_dcp_array = np.array(
num_computed_tokens_of_pcp_dcp)[:self.num_decodes_flatten]
num_computed_tokens_of_cp_dcp_array = np.array(num_computed_tokens_of_pcp_dcp)[: self.num_decodes_flatten]
cp_seq_len = num_computed_tokens_of_cp_dcp_array[:, self.pcp_rank,
self.dcp_rank]
cp_seq_len = num_computed_tokens_of_cp_dcp_array[:, self.pcp_rank, self.dcp_rank]
cp_seq_len = torch.tensor(cp_seq_len, dtype=torch.int32)
batch_seq_mask = (cp_seq_len == 0)
self.batch_seq_mask_buf[:batch_seq_mask.shape[0]].copy_(
batch_seq_mask, non_blocking=True)
batch_seq_mask = self.batch_seq_mask_buf[:batch_seq_mask.shape[0]]
batch_seq_mask = cp_seq_len == 0
self.batch_seq_mask_buf[: batch_seq_mask.shape[0]].copy_(batch_seq_mask, non_blocking=True)
batch_seq_mask = self.batch_seq_mask_buf[: batch_seq_mask.shape[0]]
cp_seq_len = torch.where(cp_seq_len == 0, 1, cp_seq_len)
decode_metadata.cp_seq_len = cp_seq_len
decode_metadata.batch_seq_mask = batch_seq_mask
@@ -280,30 +261,35 @@ class AscendMlaCPImpl(AscendMLAImpl):
head_size: int,
scale: float,
num_kv_heads: int,
alibi_slopes: Optional[list[float]],
sliding_window: Optional[int],
alibi_slopes: list[float] | None,
sliding_window: int | None,
kv_cache_dtype: str,
logits_soft_cap: Optional[float],
logits_soft_cap: float | None,
attn_type: str,
kv_sharing_target_layer_name: Optional[str],
kv_sharing_target_layer_name: str | None,
**kwargs,
):
super().__init__(num_heads, head_size, scale, num_kv_heads,
alibi_slopes, sliding_window, kv_cache_dtype,
logits_soft_cap, attn_type,
kv_sharing_target_layer_name, **kwargs)
super().__init__(
num_heads,
head_size,
scale,
num_kv_heads,
alibi_slopes,
sliding_window,
kv_cache_dtype,
logits_soft_cap,
attn_type,
kv_sharing_target_layer_name,
**kwargs,
)
self.pcp_size = get_pcp_group().world_size
self.pcp_rank = get_pcp_group(
).rank_in_group if self.pcp_size > 1 else 0
self.pcp_group = get_pcp_group(
).device_group if self.pcp_size > 1 else None
self.pcp_rank = get_pcp_group().rank_in_group if self.pcp_size > 1 else 0
self.pcp_group = get_pcp_group().device_group if self.pcp_size > 1 else None
self.dcp_size = get_decode_context_model_parallel_world_size()
self.dcp_rank = get_decode_context_model_parallel_rank(
) if self.dcp_size > 1 else 0
self.dcp_group = get_dcp_group(
).device_group if self.dcp_size > 1 else None
self.dcp_rank = get_decode_context_model_parallel_rank() if self.dcp_size > 1 else 0
self.dcp_group = get_dcp_group().device_group if self.dcp_size > 1 else None
def get_num_actual_tokens(self, attn_metadata: M):
if self.pcp_size > 1:
@@ -320,103 +306,80 @@ class AscendMlaCPImpl(AscendMLAImpl):
x = x.transpose(0, 1).reshape(-1, self.num_heads * self.v_head_dim)
return x
def mla_preprocess_prefill(self, q_c, kv_no_split, kv_cache,
attn_metadata):
def mla_preprocess_prefill(self, q_c, kv_no_split, kv_cache, attn_metadata):
if not self.pcp_size > 1:
return super().mla_preprocess_prefill(q_c, kv_no_split, kv_cache,
attn_metadata)
return super().mla_preprocess_prefill(q_c, kv_no_split, kv_cache, attn_metadata)
num_decode_tokens = attn_metadata.num_decode_tokens
num_actual_tokens = (attn_metadata.num_actual_tokens_pcp_padded -
self.pcp_size * num_decode_tokens
) // self.pcp_size + num_decode_tokens
num_actual_tokens = (
attn_metadata.num_actual_tokens_pcp_padded - self.pcp_size * num_decode_tokens
) // self.pcp_size + num_decode_tokens
prefill_q_c = q_c[num_decode_tokens:num_actual_tokens]
prefill_q = self.q_proj(prefill_q_c)[0] \
.view(-1, self.num_heads, self.qk_head_dim)
prefill_q_pe = prefill_q[..., self.qk_nope_head_dim:]
prefill_q_nope = prefill_q[..., :self.qk_nope_head_dim]
cos = attn_metadata.prefill.cos[:num_actual_tokens - num_decode_tokens]
sin = attn_metadata.prefill.sin[:num_actual_tokens - num_decode_tokens]
prefill_q = self.q_proj(prefill_q_c)[0].view(-1, self.num_heads, self.qk_head_dim)
prefill_q_pe = prefill_q[..., self.qk_nope_head_dim :]
prefill_q_nope = prefill_q[..., : self.qk_nope_head_dim]
cos = attn_metadata.prefill.cos[: num_actual_tokens - num_decode_tokens]
sin = attn_metadata.prefill.sin[: num_actual_tokens - num_decode_tokens]
prefill_q_pe = self.rope_single(prefill_q_pe, cos, sin)
prefill_kv_no_split = kv_no_split[:num_actual_tokens]
kv_c, k_pe = prefill_kv_no_split.split(
[self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
kv_c, k_pe = prefill_kv_no_split.split([self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
kv_c_normed = self.kv_a_layernorm(kv_c.contiguous())
assert len(
kv_cache
) > 1, "the number of kv cache should be greater than 1, namely (nope_cache and rope_cache)"
kv_c_normed = kv_c_normed.view(
[num_actual_tokens, self.num_kv_heads, -1])
assert len(kv_cache) > 1, "the number of kv cache should be greater than 1, namely (nope_cache and rope_cache)"
kv_c_normed = kv_c_normed.view([num_actual_tokens, self.num_kv_heads, -1])
k_pe = k_pe.unsqueeze(1)
prefill_k_pe = k_pe
prefill_k_pe[num_decode_tokens:num_actual_tokens] = self.rope_single(
prefill_k_pe[num_decode_tokens:num_actual_tokens], cos, sin)
prefill_k_pe[num_decode_tokens:num_actual_tokens], cos, sin
)
prefill_k_c_normed = kv_c_normed[:num_actual_tokens]
prefill_kv_c_k_pe = torch.cat([prefill_k_c_normed, prefill_k_pe],
dim=-1)
prefill_kv_c_k_pe = torch.cat([prefill_k_c_normed, prefill_k_pe], dim=-1)
prefill_kv_c_k_pe = get_pcp_group().all_gather(prefill_kv_c_k_pe, 0)
prefill_kv_c_k_pe = torch.index_select(
prefill_kv_c_k_pe, 0,
attn_metadata.prefill.pcp_metadata.pcp_allgather_restore_idx)
prefill_kv_c_k_pe = prefill_kv_c_k_pe[num_decode_tokens *
self.pcp_size:]
prefill_k_c_normed, prefill_k_pe = prefill_kv_c_k_pe.split(
[self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
prefill_kv_c_k_pe, 0, attn_metadata.prefill.pcp_metadata.pcp_allgather_restore_idx
)
prefill_kv_c_k_pe = prefill_kv_c_k_pe[num_decode_tokens * self.pcp_size :]
prefill_k_c_normed, prefill_k_pe = prefill_kv_c_k_pe.split([self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
kv_c_normed, k_pe = prefill_k_c_normed, prefill_k_pe
prefill_k_c_normed = prefill_k_c_normed.squeeze()
slot_mapping = attn_metadata.slot_mapping[self.pcp_size *
num_decode_tokens:]
torch_npu._npu_reshape_and_cache(key=kv_c_normed,
value=k_pe,
key_cache=kv_cache[0],
value_cache=kv_cache[1],
slot_indices=slot_mapping)
prefill_k_nope, prefill_value = self.kv_b_proj(
prefill_k_c_normed)[0].view(
-1, self.num_heads,
self.qk_nope_head_dim + self.v_head_dim).split(
[self.qk_nope_head_dim, self.v_head_dim], dim=-1)
slot_mapping = attn_metadata.slot_mapping[self.pcp_size * num_decode_tokens :]
torch_npu._npu_reshape_and_cache(
key=kv_c_normed, value=k_pe, key_cache=kv_cache[0], value_cache=kv_cache[1], slot_indices=slot_mapping
)
prefill_k_nope, prefill_value = (
self.kv_b_proj(prefill_k_c_normed)[0]
.view(-1, self.num_heads, self.qk_nope_head_dim + self.v_head_dim)
.split([self.qk_nope_head_dim, self.v_head_dim], dim=-1)
)
prefill_k_pe = prefill_k_pe.expand((*prefill_k_nope.shape[:-1], -1))
return PrefillMLAPreprocessResult(prefill_q_nope, prefill_q_pe,
prefill_k_nope, prefill_k_pe,
prefill_value)
return PrefillMLAPreprocessResult(prefill_q_nope, prefill_q_pe, prefill_k_nope, prefill_k_pe, prefill_value)
def mla_preprocess_decode(self, q_c, kv_no_split, kv_cache, attn_metadata):
num_decode_tokens = attn_metadata.num_decode_tokens
decode_q_c = q_c[:num_decode_tokens]
cos = attn_metadata.decode.cos
sin = attn_metadata.decode.sin
decode_ql_nope, decode_q_pe = \
self._q_proj_and_k_up_proj(decode_q_c)
decode_ql_nope, decode_q_pe = self.reorg_decode_q(
decode_ql_nope, decode_q_pe)
decode_ql_nope, decode_q_pe = self._q_proj_and_k_up_proj(decode_q_c)
decode_ql_nope, decode_q_pe = self.reorg_decode_q(decode_ql_nope, decode_q_pe)
decode_q_pe = self.rope_single(decode_q_pe, cos, sin)
decode_slots = attn_metadata.slot_mapping[:num_decode_tokens]
decode_kv_no_split = kv_no_split[:num_decode_tokens]
decode_k_pe, decode_k_nope = self.exec_kv_decode(
decode_kv_no_split, cos, sin, kv_cache, decode_slots)
return DecodeMLAPreprocessResult(decode_ql_nope, decode_q_pe,
decode_k_nope, decode_k_pe)
decode_k_pe, decode_k_nope = self.exec_kv_decode(decode_kv_no_split, cos, sin, kv_cache, decode_slots)
return DecodeMLAPreprocessResult(decode_ql_nope, decode_q_pe, decode_k_nope, decode_k_pe)
def get_context_seq_len_npu(self, index: int,
attn_metadata: AscendMLAMetadata):
def get_context_seq_len_npu(self, index: int, attn_metadata: AscendMLAMetadata):
prefill_metadata = attn_metadata.prefill
assert prefill_metadata is not None
assert prefill_metadata.chunked_context is not None
assert isinstance(prefill_metadata.chunked_context,
CPChunkedContextMetadata)
assert isinstance(prefill_metadata.chunked_context, CPChunkedContextMetadata)
assert prefill_metadata.chunked_context.padded_chunk_seq_lens_npu is not None
iters = len(prefill_metadata.chunked_context.seq_tot)
assert 0 <= index < iters
return prefill_metadata.chunked_context.padded_chunk_seq_lens_npu[
index]
return prefill_metadata.chunked_context.padded_chunk_seq_lens_npu[index]
def reorg_decode_q(self, decode_q_nope, decode_q_pe):
if self.dcp_size > 1:
decode_q_no_split = torch.cat([decode_q_nope, decode_q_pe], dim=-1)
decode_q_no_split = get_dcp_group().all_gather(
decode_q_no_split, 1)
decode_q_nope, decode_q_pe = decode_q_no_split.split(
[self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
decode_q_no_split = get_dcp_group().all_gather(decode_q_no_split, 1)
decode_q_nope, decode_q_pe = decode_q_no_split.split([self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
return decode_q_nope, decode_q_pe
def _forward_prefill(
@@ -426,12 +389,11 @@ class AscendMlaCPImpl(AscendMLAImpl):
k_nope: torch.Tensor,
k_pe: torch.Tensor,
value: torch.Tensor,
kv_c_and_k_pe_cache: Tuple[torch.Tensor],
kv_c_and_k_pe_cache: tuple[torch.Tensor],
attn_metadata: AscendMLAMetadata,
) -> torch.Tensor:
if not self.pcp_size > 1:
return super()._forward_prefill(q_nope, q_pe, k_nope, k_pe, value,
kv_c_and_k_pe_cache, attn_metadata)
return super()._forward_prefill(q_nope, q_pe, k_nope, k_pe, value, kv_c_and_k_pe_cache, attn_metadata)
assert attn_metadata.prefill is not None
assert attn_metadata.prefill.pcp_metadata is not None
num_tokens = q_nope.size(0)
@@ -455,7 +417,8 @@ class AscendMlaCPImpl(AscendMLAImpl):
kv_nomask_idx=kv_with_q_head_nomask_idx,
attn_mask_seqlens=attn_mask_seqlens,
attn_nomask_seqlens=head_attn_nomask_seqlens,
mask=attn_metadata.attn_mask)
mask=attn_metadata.attn_mask,
)
output_tail, lse_tail = self._attention_with_mask_and_nomask(
q_nope=torch.index_select(q_nope, 0, q_tail_idx),
@@ -467,19 +430,16 @@ class AscendMlaCPImpl(AscendMLAImpl):
kv_nomask_idx=kv_with_q_tail_nomask_idx,
attn_mask_seqlens=attn_mask_seqlens,
attn_nomask_seqlens=tail_attn_nomask_seqlens,
mask=attn_metadata.attn_mask)
mask=attn_metadata.attn_mask,
)
q_full_idx = attn_metadata.prefill.pcp_metadata.q_full_idx
attn_output = torch.index_select(
torch.cat([output_head, output_tail], dim=0), 0, q_full_idx)
attn_lse = torch.index_select(torch.cat([lse_head, lse_tail], dim=1),
1, q_full_idx)
attn_output = torch.index_select(torch.cat([output_head, output_tail], dim=0), 0, q_full_idx)
attn_lse = torch.index_select(torch.cat([lse_head, lse_tail], dim=1), 1, q_full_idx)
output, _ = self._compute_prefill_context(q_nope, q_pe,
kv_c_and_k_pe_cache,
self.qk_rope_head_dim,
attn_metadata, attn_output,
attn_lse)
output, _ = self._compute_prefill_context(
q_nope, q_pe, kv_c_and_k_pe_cache, self.qk_rope_head_dim, attn_metadata, attn_output, attn_lse
)
output = output.reshape([num_tokens, self.num_heads * self.v_head_dim])
@@ -498,44 +458,40 @@ class AscendMlaCPImpl(AscendMLAImpl):
attn_nomask_seqlens: list[torch.Tensor],
mask: torch.Tensor,
):
attn_output = torch.empty(q_nope.shape[0],
self.num_heads,
self.v_head_dim,
dtype=k_pe.dtype,
device=k_pe.device)
attn_lse = torch.empty(self.num_heads,
q_pe.shape[0],
dtype=torch.float32,
device=k_pe.device)
attn_output = torch.empty(
q_nope.shape[0], self.num_heads, self.v_head_dim, dtype=k_pe.dtype, device=k_pe.device
)
attn_lse = torch.empty(self.num_heads, q_pe.shape[0], dtype=torch.float32, device=k_pe.device)
# mask
k_nope_mask = torch.index_select(k_nope, 0, kv_mask_idx)
value_mask = torch.index_select(value, 0, kv_mask_idx)
k_pe_mask = torch.index_select(k_pe, 0, kv_mask_idx)
torch_npu.atb.npu_ring_mla(q_nope=q_nope,
q_rope=q_pe,
k_nope=k_nope_mask,
k_rope=k_pe_mask,
value=value_mask,
mask=mask,
seqlen=attn_mask_seqlens,
head_num=self.num_heads,
kv_head_num=self.num_heads,
pre_out=None,
prev_lse=None,
qk_scale=self.scale,
kernel_type="kernel_type_high_precision",
mask_type="mask_type_triu",
input_layout="type_bsnd",
calc_type="calc_type_first_ring",
output=attn_output,
softmax_lse=attn_lse)
torch_npu.atb.npu_ring_mla(
q_nope=q_nope,
q_rope=q_pe,
k_nope=k_nope_mask,
k_rope=k_pe_mask,
value=value_mask,
mask=mask,
seqlen=attn_mask_seqlens,
head_num=self.num_heads,
kv_head_num=self.num_heads,
pre_out=None,
prev_lse=None,
qk_scale=self.scale,
kernel_type="kernel_type_high_precision",
mask_type="mask_type_triu",
input_layout="type_bsnd",
calc_type="calc_type_first_ring",
output=attn_output,
softmax_lse=attn_lse,
)
# nomask
if not kv_nomask_idx or len(kv_nomask_idx[0]) == 0:
return attn_output, attn_lse
for kv_nomask_idx_split, attn_nomask_seqlens_split in zip(
kv_nomask_idx, attn_nomask_seqlens):
for kv_nomask_idx_split, attn_nomask_seqlens_split in zip(kv_nomask_idx, attn_nomask_seqlens):
k_nope_nomask = torch.index_select(k_nope, 0, kv_nomask_idx_split)
value_nomask = torch.index_select(value, 0, kv_nomask_idx_split)
k_pe_nomask = torch.index_select(k_pe, 0, kv_nomask_idx_split)
@@ -557,7 +513,8 @@ class AscendMlaCPImpl(AscendMLAImpl):
input_layout="type_bsnd",
calc_type="calc_type_default",
output=attn_output,
softmax_lse=attn_lse)
softmax_lse=attn_lse,
)
return attn_output, attn_lse
def _forward_decode(
@@ -579,10 +536,8 @@ class AscendMlaCPImpl(AscendMLAImpl):
else:
num_heads = self.num_heads
k_nope = k_nope.view(-1, block_size, self.num_kv_heads,
self.kv_lora_rank)
k_pe = k_pe.view(-1, block_size, self.num_kv_heads,
self.qk_rope_head_dim)
k_nope = k_nope.view(-1, block_size, self.num_kv_heads, self.kv_lora_rank)
k_pe = k_pe.view(-1, block_size, self.num_kv_heads, self.qk_rope_head_dim)
q_nope = q_nope.view(num_tokens, num_heads, -1)
q_pe = q_pe.view(num_tokens, num_heads, -1)
# use pcp & dcp split computed token nums from scheduler to compute actual seq_len and seq_mask
@@ -606,20 +561,35 @@ class AscendMlaCPImpl(AscendMLAImpl):
workspace = graph_params.workspaces.get(num_tokens)
if workspace is None:
workspace = torch_npu.atb._npu_multi_head_latent_attention_get_workspace(
q_nope, q_pe, k_nope, k_pe, decode_meta.block_table,
seq_len, num_heads, self.scale, self.num_kv_heads,
**common_kwargs)
q_nope,
q_pe,
k_nope,
k_pe,
decode_meta.block_table,
seq_len,
num_heads,
self.scale,
self.num_kv_heads,
**common_kwargs,
)
update_graph_params_workspaces(num_tokens, workspace)
attn_output = torch.empty_like(q_nope)
softmax_lse = torch.empty((num_tokens, num_heads, 1),
dtype=q_nope.dtype,
device=q_nope.device)
softmax_lse = torch.empty((num_tokens, num_heads, 1), dtype=q_nope.dtype, device=q_nope.device)
graph_params.attn_params[num_tokens].append(
(weak_ref_tensors(q_nope), weak_ref_tensors(q_pe),
weak_ref_tensors(k_nope), weak_ref_tensors(k_pe),
decode_meta.block_table, seq_len, num_heads, self.scale,
self.num_kv_heads, weak_ref_tensors(attn_output),
weak_ref_tensors(softmax_lse)))
(
weak_ref_tensors(q_nope),
weak_ref_tensors(q_pe),
weak_ref_tensors(k_nope),
weak_ref_tensors(k_pe),
decode_meta.block_table,
seq_len,
num_heads,
self.scale,
self.num_kv_heads,
weak_ref_tensors(attn_output),
weak_ref_tensors(softmax_lse),
)
)
torch.npu.graph_task_group_begin(stream)
torch_npu.atb.npu_multi_head_latent_attention(
q_nope,
@@ -634,14 +604,13 @@ class AscendMlaCPImpl(AscendMLAImpl):
**common_kwargs,
workspace=workspace,
output=attn_output,
lse=softmax_lse)
lse=softmax_lse,
)
handle = torch.npu.graph_task_group_end(stream)
graph_params.handles[num_tokens].append(handle)
else:
attn_output = torch.empty_like(q_nope)
softmax_lse = torch.empty((num_tokens, num_heads, 1),
dtype=q_nope.dtype,
device=q_nope.device)
softmax_lse = torch.empty((num_tokens, num_heads, 1), dtype=q_nope.dtype, device=q_nope.device)
torch_npu.atb.npu_multi_head_latent_attention(
q_nope,
q_pe,
@@ -655,20 +624,17 @@ class AscendMlaCPImpl(AscendMLAImpl):
return_lse=True,
calc_type="calc_type_ring",
output=attn_output,
lse=softmax_lse)
lse=softmax_lse,
)
# Update out&lse
attn_out_lse = _process_attn_out_lse(attn_output, softmax_lse,
decode_meta.batch_seq_mask)
attn_out_lse = _process_attn_out_lse(attn_output, softmax_lse, decode_meta.batch_seq_mask)
attn_output = _npu_attention_update(self.kv_lora_rank, attn_out_lse)
return self._v_up_proj(attn_output)
def _out_lse_reshape(self, attn_out: torch.Tensor,
attn_lse: torch.Tensor) -> torch.Tensor:
attn_out = attn_out.contiguous().view(
attn_out.shape[0] * attn_out.shape[1], attn_out.shape[2])
attn_lse = attn_lse.contiguous().view(
attn_lse.shape[0] * attn_lse.shape[1] * attn_lse.shape[2])
def _out_lse_reshape(self, attn_out: torch.Tensor, attn_lse: torch.Tensor) -> torch.Tensor:
attn_out = attn_out.contiguous().view(attn_out.shape[0] * attn_out.shape[1], attn_out.shape[2])
attn_lse = attn_lse.contiguous().view(attn_lse.shape[0] * attn_lse.shape[1] * attn_lse.shape[2])
return attn_out, attn_lse
def _reorg_kvcache(
@@ -706,8 +672,7 @@ class AscendMlaCPImpl(AscendMLAImpl):
assert chunked_context.max_seq_lens is not None
assert chunked_context.chunk_size is not None
padded_local_chunk_seq_lens_lst = chunked_context.padded_local_chunk_seq_lens[
chunk_idx]
padded_local_chunk_seq_lens_lst = chunked_context.padded_local_chunk_seq_lens[chunk_idx]
local_context_lens_allranks = chunked_context.local_context_lens_allranks
sum_seq_len = chunked_context.cu_seq_lens_lst[chunk_idx][-1]
max_seq_len = chunked_context.max_seq_lens[chunk_idx]
@@ -720,14 +685,16 @@ class AscendMlaCPImpl(AscendMLAImpl):
cache_kv_c_k_pe = get_pcp_group().all_gather(cache_kv_c_k_pe, 0)
allgatered_kv_c_normed, allgatered_k_pe = cache_kv_c_k_pe.split(
[self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
[self.kv_lora_rank, self.qk_rope_head_dim], dim=-1
)
kv_c_segments = []
k_pe_segments = []
src_token_idx = 0
max_seq_len_check = 0
for padded_local_chunk_seq_len, local_context_lens in zip(
padded_local_chunk_seq_lens_lst, local_context_lens_allranks):
padded_local_chunk_seq_lens_lst, local_context_lens_allranks
):
cur_seq_len = 0
for rank, local_context_len in enumerate(local_context_lens):
# Note(qcs): We split the context into multiple chunks,
@@ -742,15 +709,12 @@ class AscendMlaCPImpl(AscendMLAImpl):
padded_local_chunk_seq_len,
)
if local_chunk_len != 0:
kv_c_segment = allgatered_kv_c_normed[rank * toks +
src_token_idx:rank *
toks +
src_token_idx +
local_chunk_len]
k_pe_segment = allgatered_k_pe[rank * toks +
src_token_idx:rank * toks +
src_token_idx +
local_chunk_len]
kv_c_segment = allgatered_kv_c_normed[
rank * toks + src_token_idx : rank * toks + src_token_idx + local_chunk_len
]
k_pe_segment = allgatered_k_pe[
rank * toks + src_token_idx : rank * toks + src_token_idx + local_chunk_len
]
kv_c_segments.append(kv_c_segment)
k_pe_segments.append(k_pe_segment)
cur_seq_len += local_chunk_len
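The ring/context-parallel attention above carries a softmax log-sum-exp (LSE) next to every partial output so that results computed on different ranks or KV chunks can be combined exactly rather than approximately. The merge itself is the standard LSE reduction; below is a minimal sketch of it, assuming outputs shaped (num_tokens, num_heads, head_dim) and LSEs shaped (num_heads, num_tokens) as in the buffers allocated above. The helper name is illustrative and is not the `_npu_attention_update` kernel.

import torch

def merge_attn_partials(o1: torch.Tensor, lse1: torch.Tensor,
                        o2: torch.Tensor, lse2: torch.Tensor):
    # o*:   (num_tokens, num_heads, head_dim) partial attention outputs
    # lse*: (num_heads, num_tokens) log-sum-exp of the raw attention scores
    lse = torch.logaddexp(lse1, lse2)                          # combined normalizer
    w1 = torch.exp(lse1 - lse).transpose(0, 1).unsqueeze(-1)   # (tokens, heads, 1)
    w2 = torch.exp(lse2 - lse).transpose(0, 1).unsqueeze(-1)
    return o1 * w1 + o2 * w2, lse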


@@ -1,18 +1,15 @@
from dataclasses import dataclass, field
from functools import lru_cache
from typing import Any, List, Optional
from typing import Any
import torch
import torch.nn.functional as F
from vllm.config import VllmConfig, get_current_vllm_config
from vllm.distributed.kv_transfer import (get_kv_transfer_group,
has_kv_transfer_group,
is_v1_kv_transfer_group)
from vllm.distributed.kv_transfer import get_kv_transfer_group, has_kv_transfer_group, is_v1_kv_transfer_group
from vllm.forward_context import ForwardContext, get_forward_context
from vllm.v1.attention.backends.utils import CommonAttentionMetadata
from vllm_ascend.utils import (AscendDeviceType, get_ascend_config,
get_ascend_device_type)
from vllm_ascend.utils import AscendDeviceType, get_ascend_config, get_ascend_device_type
def using_paged_attention(runtime_shape: int, vllm_config: VllmConfig) -> bool:
@@ -21,6 +18,7 @@ def using_paged_attention(runtime_shape: int, vllm_config: VllmConfig) -> bool:
if get_ascend_device_type() == AscendDeviceType.A5:
return False
from vllm.config.compilation import CUDAGraphMode
cudagraph_mode = vllm_config.compilation_config.cudagraph_mode
if cudagraph_mode != CUDAGraphMode.FULL_DECODE_ONLY:
return False
@@ -31,8 +29,7 @@ def using_paged_attention(runtime_shape: int, vllm_config: VllmConfig) -> bool:
@lru_cache(maxsize=1)
def enable_cp():
prefill_config = get_current_vllm_config().parallel_config
return prefill_config.prefill_context_parallel_size > 1 \
or prefill_config.decode_context_parallel_size > 1
return prefill_config.prefill_context_parallel_size > 1 or prefill_config.decode_context_parallel_size > 1
@dataclass
@@ -42,13 +39,14 @@ class AscendPrefillContextParallelMetadata:
Contains index tensors and sequence lengths for PCP operations.
"""
pcp_allgather_restore_idx: torch.Tensor = None
cp_kv_recover_idx_for_chunk: torch.Tensor = None
num_actual_tokens_pcp_padded: int = 0
num_computed_tokens_of_pcp_dcp: Optional[list[list[list[int]]]] = None
num_computed_tokens_of_pcp_dcp: list[list[list[int]]] | None = None
q_head_idx_tensor: torch.Tensor = None
@@ -85,6 +83,7 @@ class AscendCommonAttentionMetadata(CommonAttentionMetadata):
For many of the tensors we keep both NPU and CPU versions.
"""
# CPU tensor of sequence lengths for host-side operations.
# E.g., tensor([128, 256, 64]) for 3 requests with different seq lengths.
seq_lens_cpu: torch.Tensor = None
@@ -115,20 +114,17 @@ class AscendCommonAttentionMetadata(CommonAttentionMetadata):
num_input_tokens: int = 0
# Metadata for Prefill Context Parallelism (PCP) operations.
prefill_context_parallel_metadata: Optional[
AscendPrefillContextParallelMetadata] = None
prefill_context_parallel_metadata: AscendPrefillContextParallelMetadata | None = None
# TODO: Remove it when vLLM no longer uses this function.
def unpadded(self, num_actual_tokens: int,
num_actual_reqs: int) -> "AscendCommonAttentionMetadata":
def unpadded(self, num_actual_tokens: int, num_actual_reqs: int) -> "AscendCommonAttentionMetadata":
# This is only used by eagle now. It will be used for enforce_eager in the future.
return AscendCommonAttentionMetadata(
query_start_loc=self.query_start_loc[:num_actual_reqs + 1],
query_start_loc_cpu=self.query_start_loc_cpu[:num_actual_reqs + 1],
query_start_loc=self.query_start_loc[: num_actual_reqs + 1],
query_start_loc_cpu=self.query_start_loc_cpu[: num_actual_reqs + 1],
seq_lens=self.seq_lens[:num_actual_reqs],
seq_lens_cpu=self.seq_lens_cpu[:num_actual_reqs],
num_computed_tokens_cpu=self.
num_computed_tokens_cpu[:num_actual_reqs],
num_computed_tokens_cpu=self.num_computed_tokens_cpu[:num_actual_reqs],
num_reqs=num_actual_reqs,
num_actual_tokens=num_actual_tokens,
max_query_len=self.max_query_len,
@@ -144,14 +140,14 @@ class AscendCommonAttentionMetadata(CommonAttentionMetadata):
attn_state=self.attn_state,
graph_pad_size=-1, # It should be -1 when not run in fullgraph mode.
num_input_tokens=self.num_input_tokens,
prefill_context_parallel_metadata=self.
prefill_context_parallel_metadata,
max_seq_len=self.max_seq_len)
prefill_context_parallel_metadata=self.prefill_context_parallel_metadata,
max_seq_len=self.max_seq_len,
)
def filter_chunked_req_indices(
seq_len: torch.Tensor,
mask_for_non_zero_chunk: Optional[List[bool]],
mask_for_non_zero_chunk: list[bool] | None,
) -> torch.Tensor:
"""
filter the reqs which are doing real chunk_prefill.
@@ -162,14 +158,15 @@ def filter_chunked_req_indices(
Returns:
filtered_indices: the real chunked req's indices
"""
assert mask_for_non_zero_chunk is not None and len(seq_len) == len(
mask_for_non_zero_chunk)
assert mask_for_non_zero_chunk is not None and len(seq_len) == len(mask_for_non_zero_chunk)
offsets = torch.cumsum(torch.cat([torch.tensor([0]), seq_len[:-1]]), dim=0)
filtered_indices = torch.cat([
torch.arange(offsets[i], offsets[i] + seq_len[i])
for i in range(len(mask_for_non_zero_chunk))
if mask_for_non_zero_chunk[i]
])
filtered_indices = torch.cat(
[
torch.arange(offsets[i], offsets[i] + seq_len[i])
for i in range(len(mask_for_non_zero_chunk))
if mask_for_non_zero_chunk[i]
]
)
return filtered_indices
@@ -195,12 +192,9 @@ def split_decodes_and_prefills(
num_prefill_tokens: The number of tokens in the prefill requests.
"""
long_seq_metadata = common_attn_metadata.prefill_context_parallel_metadata
query_lens_pcp_full = long_seq_metadata.query_lens_pcp_full_cpu \
if long_seq_metadata else None
max_query_len_pcp_full = long_seq_metadata.max_query_len_pcp_full \
if long_seq_metadata else 0
max_query_len = common_attn_metadata.max_query_len \
if max_query_len_pcp_full == 0 else max_query_len_pcp_full
query_lens_pcp_full = long_seq_metadata.query_lens_pcp_full_cpu if long_seq_metadata else None
max_query_len_pcp_full = long_seq_metadata.max_query_len_pcp_full if long_seq_metadata else 0
max_query_len = common_attn_metadata.max_query_len if max_query_len_pcp_full == 0 else max_query_len_pcp_full
num_reqs = common_attn_metadata.num_reqs
num_tokens = common_attn_metadata.num_actual_tokens
query_start_loc = common_attn_metadata.query_start_loc_cpu
@@ -208,8 +202,7 @@ def split_decodes_and_prefills(
if max_query_len <= decode_threshold:
return num_reqs, 0, num_tokens, 0
query_lens = (query_start_loc[1:] - query_start_loc[:-1]) \
if query_lens_pcp_full is None else query_lens_pcp_full
query_lens = (query_start_loc[1:] - query_start_loc[:-1]) if query_lens_pcp_full is None else query_lens_pcp_full
is_prefill = query_lens > decode_threshold
if not torch.any(is_prefill):
return num_reqs, 0, num_tokens, 0
@@ -238,7 +231,7 @@ def wait_for_kv_layer_from_connector(layer_name: str):
def maybe_save_kv_layer_to_connector(
layer_name: str,
kv_cache_layer: List[torch.Tensor],
kv_cache_layer: list[torch.Tensor],
):
if not has_kv_transfer_group() or not is_v1_kv_transfer_group():
return
@@ -264,8 +257,7 @@ def trans_rope_weight(weight, rope_dim):
return weight.contiguous()
nope_part = weight[..., :-rope_dim, :]
rope_part = weight[..., -rope_dim:, :]
reordered_rope_part = torch.cat(
(rope_part[..., ::2, :], rope_part[..., 1::2, :]), dim=-2)
reordered_rope_part = torch.cat((rope_part[..., ::2, :], rope_part[..., 1::2, :]), dim=-2)
return torch.cat((nope_part, reordered_rope_part), dim=-2).contiguous()
@@ -278,12 +270,9 @@ def transdata(nd_mat, block_size: tuple = (16, 16)):
nz_mat = torch.permute(
torch.reshape(
nd_mat,
(r // block_size[0], block_size[0], c // block_size[1],
block_size[1]),
(r // block_size[0], block_size[0], c // block_size[1], block_size[1]),
),
[2, 0, 1, 3],
)
nz_mat = torch.reshape(
nz_mat,
(nz_mat.shape[0], nz_mat.shape[1] * nz_mat.shape[2], nz_mat.shape[3]))
nz_mat = torch.reshape(nz_mat, (nz_mat.shape[0], nz_mat.shape[1] * nz_mat.shape[2], nz_mat.shape[3]))
return nz_mat
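As a concrete illustration of the ND-to-NZ layout change that `transdata` performs, a 32x32 matrix with the default (16, 16) block size becomes a (2, 32, 16) tensor; a minimal, self-contained check of that shape arithmetic (values are arbitrary):

import torch

nd = torch.arange(32 * 32, dtype=torch.float32).reshape(32, 32)
# (r//16, 16, c//16, 16) -> permute to (c//16, r//16, 16, 16) -> collapse to (c//16, r, 16)
nz = torch.permute(nd.reshape(2, 16, 2, 16), [2, 0, 1, 3]).reshape(2, 32, 16)
assert nz.shape == (2, 32, 16)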


@@ -27,8 +27,12 @@ logger = init_logger(__name__)
if HAS_TRITON:
from vllm_ascend.ops.triton.batch_invariant.matmul import (
addmm_batch_invariant, bmm_batch_invariant, linear_batch_invariant,
matmul_batch_invariant, mm_batch_invariant)
addmm_batch_invariant,
bmm_batch_invariant,
linear_batch_invariant,
matmul_batch_invariant,
mm_batch_invariant,
)
def override_envs_for_invariance():
@@ -73,10 +77,11 @@ def init_batch_invariance():
if vllm_is_batch_invariant():
if HAS_TRITON:
logger.info(
"Enabling batch-invariant mode for vLLM on Ascend NPU.", )
"Enabling batch-invariant mode for vLLM on Ascend NPU.",
)
override_envs_for_invariance()
enable_batch_invariant_mode()
else:
logger.warning(
"Batch-invariant mode requested but Triton is not available."
"skipping batch-invariant initialization.", )
"Batch-invariant mode requested but Triton is not available.skipping batch-invariant initialization.",
)
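For context on what `enable_batch_invariant_mode` wires up: batch-invariant mode swaps selected aten matmul ops for Triton kernels whose reduction order does not change with batch size. A heavily simplified sketch of that override pattern follows; the deterministic kernel body and the "PrivateUse1" dispatch key for the NPU backend are assumptions, not the module's exact code.

import torch

_lib = torch.library.Library("aten", "IMPL")  # must stay alive for the overrides to hold

def _mm_batch_invariant(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    # Placeholder for a fixed-tiling, batch-invariant Triton kernel.
    return torch.mm(a, b)

def enable_batch_invariant_mode_sketch() -> None:
    # Route aten::mm on the NPU dispatch key to the deterministic kernel.
    _lib.impl("aten::mm", _mm_batch_invariant, "PrivateUse1")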


@@ -15,35 +15,26 @@
# limitations under the License.
# This file is a part of the vllm-ascend project.
#
from typing import Optional, Type
import torch_npu
from vllm_ascend.utils import AscendDeviceType, get_ascend_device_type
class BaseDeviceAdaptor(object):
class BaseDeviceAdaptor:
@classmethod
def reshape_and_cache(cls, key, value, key_cache, value_cache,
slot_mapping):
torch_npu._npu_reshape_and_cache(key=key,
value=value,
key_cache=key_cache,
value_cache=value_cache,
slot_indices=slot_mapping)
def reshape_and_cache(cls, key, value, key_cache, value_cache, slot_mapping):
torch_npu._npu_reshape_and_cache(
key=key, value=value, key_cache=key_cache, value_cache=value_cache, slot_indices=slot_mapping
)
class A5DeviceAdaptor(BaseDeviceAdaptor):
@classmethod
def reshape_and_cache(cls, key, value, key_cache, value_cache,
slot_mapping):
torch_npu.npu_scatter_pa_kv_cache(key=key,
value=value.contiguous(),
key_cache=key_cache,
value_cache=value_cache,
slot_mapping=slot_mapping)
def reshape_and_cache(cls, key, value, key_cache, value_cache, slot_mapping):
torch_npu.npu_scatter_pa_kv_cache(
key=key, value=value.contiguous(), key_cache=key_cache, value_cache=value_cache, slot_mapping=slot_mapping
)
def get_device_adaptor():
@@ -53,4 +44,4 @@ def get_device_adaptor():
return BaseDeviceAdaptor
DeviceOperator: Optional[Type['BaseDeviceAdaptor']] = get_device_adaptor()
DeviceOperator: type["BaseDeviceAdaptor"] | None = get_device_adaptor()
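The adaptor gives call sites a single device-agnostic entry point; `DeviceOperator` is resolved once at import time to either `BaseDeviceAdaptor` or `A5DeviceAdaptor`. A minimal usage sketch (the wrapper function and tensor arguments are placeholders):

from vllm_ascend.device.device_op import DeviceOperator

def write_kv_to_cache(key, value, key_cache, value_cache, slot_mapping):
    # Scatter the new key/value tokens into the paged KV cache slots,
    # using whichever backend op the current device requires.
    DeviceOperator.reshape_and_cache(key, value, key_cache, value_cache, slot_mapping)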


@@ -18,21 +18,22 @@
#
import dataclasses
import os
from collections.abc import Callable
from contextlib import contextmanager
from typing import Any, Callable, Dict, Optional, Tuple, Union
from typing import Any
import torch
from acl.rt import memcpy # type: ignore # noqa: F401
from vllm.logger import logger
def find_loaded_library(lib_name) -> Optional[str]:
def find_loaded_library(lib_name) -> str | None:
"""
According to https://man7.org/linux/man-pages/man5/proc_pid_maps.5.html,
the file `/proc/self/maps` contains the memory maps of the process, which includes the
shared libraries loaded by the process. We can use this file to find the path of
a loaded library.
""" # noqa
""" # noqa
found_line = None
with open("/proc/self/maps") as f:
for line in f:
@@ -47,20 +48,22 @@ def find_loaded_library(lib_name) -> Optional[str]:
start = found_line.index("/")
path = found_line[start:].strip()
filename = path.split("/")[-1]
assert filename.rpartition(".so")[0].startswith(lib_name), \
f"Unexpected filename: {filename} for library {lib_name}"
assert filename.rpartition(".so")[0].startswith(lib_name), f"Unexpected filename: {filename} for library {lib_name}"
return path
camem_available = False
try:
from vllm_ascend.vllm_ascend_C import ( # type: ignore # noqa: F401
init_module, python_create_and_map, python_unmap_and_release)
init_module,
python_create_and_map,
python_unmap_and_release,
)
lib_name = find_loaded_library("vllm_ascend_C")
camem_available = True
except ImportError as e:
logger.warning(
"Failed to import vllm_ascend_C:%s. Sleep mode will be disabled. ", e)
logger.warning("Failed to import vllm_ascend_C:%s. Sleep mode will be disabled. ", e)
init_module = None
python_create_and_map = None
python_unmap_and_release = None
@@ -68,14 +71,14 @@ except ImportError as e:
libcudart = None
# py_device, py_alignedSize, py_d_mem, py_p_memHandle
HandleType = Tuple[int, int, int, int]
HandleType = tuple[int, int, int, int]
@dataclasses.dataclass
class AllocationData:
handle: HandleType
tag: str
cpu_backup_tensor: Optional[torch.Tensor] = None
cpu_backup_tensor: torch.Tensor | None = None
def create_and_map(allocation_handle: HandleType) -> None:
@@ -88,18 +91,18 @@ def unmap_and_release(allocation_handle: HandleType) -> None:
def get_pluggable_allocator(
python_malloc_fn: Callable[[tuple[int, int, int, int]], None],
python_free_func: Callable[[int], tuple[int, int, int, int]]
python_free_func: Callable[[int], tuple[int, int, int, int]],
) -> torch.npu.memory.NPUPluggableAllocator:
init_module(python_malloc_fn, python_free_func)
new_alloc = torch.npu.memory.NPUPluggableAllocator(lib_name, 'my_malloc',
'my_free')
new_alloc = torch.npu.memory.NPUPluggableAllocator(lib_name, "my_malloc", "my_free")
return new_alloc
@contextmanager
def use_memory_pool_with_allocator(
python_malloc_fn: Callable[[tuple[int, int, int, int]], None],
python_free_func: Callable[[int], tuple[int, int, int, int]]):
python_malloc_fn: Callable[[tuple[int, int, int, int]], None],
python_free_func: Callable[[int], tuple[int, int, int, int]],
):
new_alloc = get_pluggable_allocator(python_malloc_fn, python_free_func)
mem_pool = torch.npu.memory.MemPool(new_alloc._allocator)
with torch.npu.memory.use_mem_pool(mem_pool):
@@ -127,6 +130,7 @@ class CaMemAllocator:
the global variable will be overwritten and the free callback will
not work as expected.
"""
instance = None
default_tag: str = "default"
@@ -143,22 +147,22 @@ class CaMemAllocator:
def __init__(self):
conf = os.environ.get("PYTORCH_NPU_ALLOC_CONF", "")
assert "expandable_segments:True" not in conf, \
("Expandable segments are not compatible with memory pool. "
assert "expandable_segments:True" not in conf, (
"Expandable segments are not compatible with memory pool. "
"Please track https://github.com/pytorch/pytorch/issues/147851 "
"for the latest updates.")
"for the latest updates."
)
self.pointer_to_data: Dict[int, AllocationData] = {}
self.pointer_to_data: dict[int, AllocationData] = {}
self.current_tag: str = CaMemAllocator.default_tag
self.allocator_and_pools: Dict[str, Any] = {}
self.allocator_and_pools: dict[str, Any] = {}
def python_malloc_callback(self, allocation_handle: HandleType) -> None:
"""
Internal method to store the allocation data
when memory is allocated in the memory pool."""
py_d_mem = allocation_handle[2]
self.pointer_to_data[py_d_mem] = AllocationData(
allocation_handle, self.current_tag)
self.pointer_to_data[py_d_mem] = AllocationData(allocation_handle, self.current_tag)
return
def python_free_callback(self, ptr: int) -> HandleType:
@@ -170,13 +174,10 @@ class CaMemAllocator:
data.cpu_backup_tensor = None
return data.handle
def sleep(
self,
offload_tags: Optional[Union[Tuple[str, ...],
str]] = None) -> None:
def sleep(self, offload_tags: tuple[str, ...] | str | None = None) -> None:
"""
Put the allocator in sleep mode.
All data in the memory allocation with the specified tag will be
All data in the memory allocation with the specified tag will be
offloaded to CPU memory, and others will be discarded.
:param offload_tags: The tags of the memory allocation that will be
offloaded. The rest of the memory allocation will be discarded.
@@ -184,9 +185,9 @@ class CaMemAllocator:
if offload_tags is None:
# by default, allocated tensors are offloaded
# when the allocator sleeps
offload_tags = (CaMemAllocator.default_tag, )
offload_tags = (CaMemAllocator.default_tag,)
elif isinstance(offload_tags, str):
offload_tags = (offload_tags, )
offload_tags = (offload_tags,)
assert isinstance(offload_tags, tuple)
@@ -194,22 +195,18 @@ class CaMemAllocator:
handle = data.handle
if data.tag in offload_tags:
size_in_bytes = handle[1]
cpu_backup_tensor = torch.empty(size_in_bytes,
dtype=torch.uint8,
device='cpu',
pin_memory=True)
cpu_backup_tensor = torch.empty(size_in_bytes, dtype=torch.uint8, device="cpu", pin_memory=True)
cpu_ptr = cpu_backup_tensor.data_ptr()
ACL_MEMCPY_DEVICE_TO_HOST = 2
dest_max = cpu_ptr + size_in_bytes * 2
memcpy(cpu_ptr, dest_max, ptr, size_in_bytes,
ACL_MEMCPY_DEVICE_TO_HOST)
memcpy(cpu_ptr, dest_max, ptr, size_in_bytes, ACL_MEMCPY_DEVICE_TO_HOST)
data.cpu_backup_tensor = cpu_backup_tensor
unmap_and_release(handle)
def wake_up(self, tags: Optional[list[str]] = None) -> None:
def wake_up(self, tags: list[str] | None = None) -> None:
"""
Wake up the allocator from sleep mode.
All data that is previously offloaded will be loaded back to GPU
All data that is previously offloaded will be loaded back to GPU
memory, and the rest of the data will have empty memory."""
for ptr, data in self.pointer_to_data.items():
if tags is None or data.tag in tags:
@@ -218,20 +215,18 @@ class CaMemAllocator:
if data.cpu_backup_tensor is not None:
cpu_backup_tensor = data.cpu_backup_tensor
if cpu_backup_tensor is not None:
size_in_bytes = cpu_backup_tensor.numel(
) * cpu_backup_tensor.element_size()
size_in_bytes = cpu_backup_tensor.numel() * cpu_backup_tensor.element_size()
cpu_ptr = cpu_backup_tensor.data_ptr()
ACL_MEMCPY_HOST_TO_DEVICE = 1
dest_max = ptr + size_in_bytes * 2
memcpy(ptr, dest_max, cpu_ptr, size_in_bytes,
ACL_MEMCPY_HOST_TO_DEVICE)
memcpy(ptr, dest_max, cpu_ptr, size_in_bytes, ACL_MEMCPY_HOST_TO_DEVICE)
data.cpu_backup_tensor = None
@contextmanager
def use_memory_pool(self, tag: Optional[str] = None):
def use_memory_pool(self, tag: str | None = None):
"""
A context manager to use the memory pool.
All memory allocation created inside the context will be allocated
All memory allocation created inside the context will be allocated
in the memory pool, and has the specified tag.
:param tag: The tag of the memory allocation. If None, the default tag
will be used.
@@ -243,8 +238,7 @@ class CaMemAllocator:
old_tag = self.current_tag
self.current_tag = tag
with use_memory_pool_with_allocator(self.python_malloc_callback,
self.python_free_callback) as data:
with use_memory_pool_with_allocator(self.python_malloc_callback, self.python_free_callback) as data:
# start to hit another PyTorch bug in PyTorch 2.6,
# possibly because of gc-related issue w.r.t. the allocator and
# the memory pool.
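A usage sketch of the allocator lifecycle above: allocate under a tag, offload that tag on sleep, restore it on wake. The singleton accessor and the weight-loading call are placeholders for whatever the caller actually does.

from vllm_ascend.device_allocator.camem import CaMemAllocator

allocator = CaMemAllocator.get_instance()  # assuming the usual singleton accessor

# Allocations made inside the pool are tracked under the given tag.
with allocator.use_memory_pool(tag="weights"):
    load_model_weights()  # placeholder for the caller's own allocations

# Offload "weights" to pinned CPU memory; discard everything else on the NPU.
allocator.sleep(offload_tags="weights")

# Copy the offloaded data back and re-map the discarded allocations.
allocator.wake_up()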


@@ -19,107 +19,89 @@
#
import os
from typing import Any, Callable, Dict
from collections.abc import Callable
from typing import Any
# The begin-* and end* here are used by the documentation generator
# to extract the used env vars.
# begin-env-vars-definition
env_variables: Dict[str, Callable[[], Any]] = {
env_variables: dict[str, Callable[[], Any]] = {
# max compile thread number for package building. Usually, it is set to
# the number of CPU cores. If not set, the default value is None, which
# means all CPU cores will be used.
"MAX_JOBS":
lambda: os.getenv("MAX_JOBS", None),
"MAX_JOBS": lambda: os.getenv("MAX_JOBS", None),
# The build type of the package. It can be one of the following values:
# Release, Debug, RelWithDebugInfo. If not set, the default value is Release.
"CMAKE_BUILD_TYPE":
lambda: os.getenv("CMAKE_BUILD_TYPE"),
"CMAKE_BUILD_TYPE": lambda: os.getenv("CMAKE_BUILD_TYPE"),
# The CXX compiler used for compiling the package. If not set, the default
# value is None, which means the system default CXX compiler will be used.
"CXX_COMPILER":
lambda: os.getenv("CXX_COMPILER", None),
"CXX_COMPILER": lambda: os.getenv("CXX_COMPILER", None),
# The C compiler used for compiling the package. If not set, the default
# value is None, which means the system default C compiler will be used.
"C_COMPILER":
lambda: os.getenv("C_COMPILER", None),
"C_COMPILER": lambda: os.getenv("C_COMPILER", None),
# The version of the Ascend chip. It's used for package building.
# If not set, we will query chip info through `npu-smi`.
# Please make sure that the version is correct.
"SOC_VERSION":
lambda: os.getenv("SOC_VERSION", None),
"SOC_VERSION": lambda: os.getenv("SOC_VERSION", None),
# If set, vllm-ascend will print verbose logs during compilation
"VERBOSE":
lambda: bool(int(os.getenv('VERBOSE', '0'))),
"VERBOSE": lambda: bool(int(os.getenv("VERBOSE", "0"))),
# The home path for CANN toolkit. If not set, the default value is
# /usr/local/Ascend/ascend-toolkit/latest
"ASCEND_HOME_PATH":
lambda: os.getenv("ASCEND_HOME_PATH", None),
"ASCEND_HOME_PATH": lambda: os.getenv("ASCEND_HOME_PATH", None),
# The path for HCCL library, it's used by pyhccl communicator backend. If
# not set, the default value is libhccl.so.
"HCCL_SO_PATH":
lambda: os.environ.get("HCCL_SO_PATH", None),
"HCCL_SO_PATH": lambda: os.environ.get("HCCL_SO_PATH", None),
# The version of vllm that is installed. This value is used by developers who
# installed vllm from source locally. In this case, the version of vllm is
# usually changed. For example, the released version of vllm may be "0.9.0",
# but when installed from source the version is usually set to "0.9.1".
# In this case, developers need to set this value to "0.9.0" to make sure
# that the correct package is installed.
"VLLM_VERSION":
lambda: os.getenv("VLLM_VERSION", None),
"VLLM_VERSION": lambda: os.getenv("VLLM_VERSION", None),
# Whether to enable the model execute time observe profile. Disable it when
# running vllm ascend in production environment.
"VLLM_ASCEND_MODEL_EXECUTE_TIME_OBSERVE":
lambda: bool(int(os.getenv("VLLM_ASCEND_MODEL_EXECUTE_TIME_OBSERVE", '0'))
),
"VLLM_ASCEND_MODEL_EXECUTE_TIME_OBSERVE": lambda: bool(
int(os.getenv("VLLM_ASCEND_MODEL_EXECUTE_TIME_OBSERVE", "0"))
),
# Some models are optimized by vllm ascend. While in some case, e.g. rlhf
# training, the optimized model may not be suitable. In this case, set this
# value to False to disable the optimized model.
"USE_OPTIMIZED_MODEL":
lambda: bool(int(os.getenv('USE_OPTIMIZED_MODEL', '1'))),
"USE_OPTIMIZED_MODEL": lambda: bool(int(os.getenv("USE_OPTIMIZED_MODEL", "1"))),
# Whether to enable MatmulAllReduce fusion kernel when tensor parallel is enabled.
# This feature is supported on A2, and eager mode will get better performance.
"VLLM_ASCEND_ENABLE_MATMUL_ALLREDUCE":
lambda: bool(int(os.getenv("VLLM_ASCEND_ENABLE_MATMUL_ALLREDUCE", '0'))),
"VLLM_ASCEND_ENABLE_MATMUL_ALLREDUCE": lambda: bool(int(os.getenv("VLLM_ASCEND_ENABLE_MATMUL_ALLREDUCE", "0"))),
# Whether to enable FlashComm optimization when tensor parallel is enabled.
# This feature will get better performance when concurrency is large.
"VLLM_ASCEND_ENABLE_FLASHCOMM1":
lambda: bool(int(os.getenv("VLLM_ASCEND_ENABLE_FLASHCOMM1", '0'))),
"VLLM_ASCEND_ENABLE_FLASHCOMM1": lambda: bool(int(os.getenv("VLLM_ASCEND_ENABLE_FLASHCOMM1", "0"))),
# Whether to enable FLASHCOMM2. Setting it to 0 disables the feature, while setting it to 1 or above enables it.
# The specific value set will be used as the O-matrix TP group size for flashcomm2.
# For a detailed introduction to the parameters and the differences and applicable scenarios
# between this feature and FLASHCOMM1, please refer to the feature guide in the documentation.
"VLLM_ASCEND_FLASHCOMM2_PARALLEL_SIZE":
lambda: int(os.getenv("VLLM_ASCEND_FLASHCOMM2_PARALLEL_SIZE", 0)),
"VLLM_ASCEND_FLASHCOMM2_PARALLEL_SIZE": lambda: int(os.getenv("VLLM_ASCEND_FLASHCOMM2_PARALLEL_SIZE", 0)),
# Whether to enable MLP weight prefetch, only used in small concurrency.
"VLLM_ASCEND_ENABLE_PREFETCH_MLP":
lambda: bool(int(os.getenv("VLLM_ASCEND_ENABLE_PREFETCH_MLP", '0'))),
"VLLM_ASCEND_ENABLE_PREFETCH_MLP": lambda: bool(int(os.getenv("VLLM_ASCEND_ENABLE_PREFETCH_MLP", "0"))),
# buffer size for gate up prefetch
"VLLM_ASCEND_MLP_GATE_UP_PREFETCH_SIZE":
lambda: int(
os.getenv("VLLM_ASCEND_MLP_GATE_UP_PREFETCH_SIZE", 18 * 1024 * 1024)),
"VLLM_ASCEND_MLP_GATE_UP_PREFETCH_SIZE": lambda: int(
os.getenv("VLLM_ASCEND_MLP_GATE_UP_PREFETCH_SIZE", 18 * 1024 * 1024)
),
# buffer size for down proj prefetch
"VLLM_ASCEND_MLP_DOWN_PREFETCH_SIZE":
lambda: int(
os.getenv("VLLM_ASCEND_MLP_DOWN_PREFETCH_SIZE", 18 * 1024 * 1024)),
"VLLM_ASCEND_MLP_DOWN_PREFETCH_SIZE": lambda: int(
os.getenv("VLLM_ASCEND_MLP_DOWN_PREFETCH_SIZE", 18 * 1024 * 1024)
),
# Whether to enable msMonitor tool to monitor the performance of vllm-ascend.
"MSMONITOR_USE_DAEMON":
lambda: bool(int(os.getenv("MSMONITOR_USE_DAEMON", '0'))),
"VLLM_ASCEND_ENABLE_MLAPO":
lambda: bool(int(os.getenv("VLLM_ASCEND_ENABLE_MLAPO", '0'))),
"MSMONITOR_USE_DAEMON": lambda: bool(int(os.getenv("MSMONITOR_USE_DAEMON", "0"))),
"VLLM_ASCEND_ENABLE_MLAPO": lambda: bool(int(os.getenv("VLLM_ASCEND_ENABLE_MLAPO", "0"))),
# Whether to enable weight cast format to FRACTAL_NZ.
# 0: close nz;
# 1: only quant case enable nz;
# 2: enable nz as long as possible.
"VLLM_ASCEND_ENABLE_NZ":
lambda: int(os.getenv("VLLM_ASCEND_ENABLE_NZ", 1)),
"VLLM_ASCEND_ENABLE_NZ": lambda: int(os.getenv("VLLM_ASCEND_ENABLE_NZ", 1)),
# Decide whether we should enable CP parallelism.
"VLLM_ASCEND_ENABLE_CONTEXT_PARALLEL":
lambda: bool(int(os.getenv("VLLM_ASCEND_ENABLE_CONTEXT_PARALLEL", '0'))),
"VLLM_ASCEND_ENABLE_CONTEXT_PARALLEL": lambda: bool(int(os.getenv("VLLM_ASCEND_ENABLE_CONTEXT_PARALLEL", "0"))),
# Whether to enable dynamic EPLB
"DYNAMIC_EPLB":
lambda: os.getenv("DYNAMIC_EPLB", "false").lower(),
"DYNAMIC_EPLB": lambda: os.getenv("DYNAMIC_EPLB", "false").lower(),
# Whether to enable fused mc2(`dispatch_gmm_combine_decode`/`dispatch_ffn_combine` operator)
# 0, or not set: default ALLTOALL and MC2 will be used.
# 1: ALLTOALL and MC2 might be replaced by `dispatch_ffn_combine` operator.
@@ -127,11 +109,9 @@ env_variables: Dict[str, Callable[[], Any]] = {
# 2: MC2 might be replaced by `dispatch_gmm_combine_decode` operator.
# `dispatch_gmm_combine_decode` can be used only for **decode node** moe layer
# with W8A8. And MTP layer must be W8A8.
"VLLM_ASCEND_ENABLE_FUSED_MC2":
lambda: int(os.getenv("VLLM_ASCEND_ENABLE_FUSED_MC2", '0')),
"VLLM_ASCEND_ENABLE_FUSED_MC2": lambda: int(os.getenv("VLLM_ASCEND_ENABLE_FUSED_MC2", "0")),
# Whether to enable balance scheduling
"VLLM_ASCEND_BALANCE_SCHEDULING":
lambda: bool(int(os.getenv("VLLM_ASCEND_BALANCE_SCHEDULING", '0'))),
"VLLM_ASCEND_BALANCE_SCHEDULING": lambda: bool(int(os.getenv("VLLM_ASCEND_BALANCE_SCHEDULING", "0"))),
}
# end-env-vars-definition
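These lambdas are evaluated lazily: a module-level `__getattr__` typically looks the name up in `env_variables` and calls the lambda, so each variable is re-read from the environment at access time. A minimal sketch of that accessor, assuming it sits next to `env_variables` (the real module may differ in details):

def __getattr__(name: str):
    # Lazily evaluate the env var on access, e.g. envs.VLLM_ASCEND_ENABLE_NZ.
    if name in env_variables:
        return env_variables[name]()
    raise AttributeError(f"module {__name__!r} has no attribute {name!r}")

def __dir__():
    return list(env_variables.keys())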