[Refactor]3/N Refactor mla_v1.py & extract mla_cp (#4933)

RFC: https://github.com/vllm-project/vllm-ascend/issues/4629
Reason:
The functions related to Cp differ significantly from those of normal
MLA-Attention, but the coupling is quite severe.

Steps:
Isolate PCP and DCP
(1) create a new python file: mla_cp.py
(2) add classes AscendMlaCPImpl and
AscendMlaCPMetadataBuilder,Inheritance AscendMLAImpl and
AscendMLAMetadataBuilder
(3) Remove PCP and DCP-related methods from mla_v1.py to mla_cp.py

vLLM version: v0.12.0

- vLLM version: v0.12.0
- vLLM main:
ad32e3e19c

---------

Signed-off-by: wujinyuan1 <wjy9595@qq.com>
Co-authored-by: wujinyuan1 <wjy9595@qq.com>
Co-authored-by: weijinqian0 <1184188277@qq.com>
This commit is contained in:
wujinyuan1
2025-12-15 12:59:18 +08:00
committed by GitHub
parent 98b9e2e18e
commit 545e856971
3 changed files with 1359 additions and 719 deletions

View File

@@ -912,7 +912,6 @@ class TestAscendMLAImpl(TestBase):
self.assertIsNotNone(self.impl.kv_a_proj_with_mqa) self.assertIsNotNone(self.impl.kv_a_proj_with_mqa)
self.assertIsNotNone(self.impl.kv_a_layernorm) self.assertIsNotNone(self.impl.kv_a_layernorm)
self.assertEqual(self.impl.num_queries_per_kv, 32) self.assertEqual(self.impl.num_queries_per_kv, 32)
self.assertEqual(self.impl.tp_size, 2)
def test_q_proj_and_k_up_proj(self): def test_q_proj_and_k_up_proj(self):
batch_size = 4 batch_size = 4

File diff suppressed because it is too large Load Diff

View File

@@ -1,27 +1,24 @@
from dataclasses import dataclass from dataclasses import dataclass
from typing import (TYPE_CHECKING, ClassVar, List, NamedTuple, Optional, Tuple, from typing import (TYPE_CHECKING, ClassVar, NamedTuple, Optional, Tuple, Type,
Type, TypeVar) TypeVar)
import numpy as np import numpy as np
import torch import torch
import torch.distributed as dist
import torch_npu import torch_npu
from torch import nn from torch import nn
from vllm.attention.backends.abstract import AttentionBackend, MLAAttentionImpl from vllm.attention.backends.abstract import AttentionBackend, MLAAttentionImpl
from vllm.attention.backends.utils import PAD_SLOT_ID from vllm.attention.backends.utils import PAD_SLOT_ID
from vllm.config import VllmConfig, get_current_vllm_config from vllm.config import VllmConfig, get_current_vllm_config
from vllm.distributed import (get_dcp_group, from vllm.distributed import (get_decode_context_model_parallel_rank,
get_decode_context_model_parallel_rank,
get_decode_context_model_parallel_world_size, get_decode_context_model_parallel_world_size,
get_pcp_group, get_tensor_model_parallel_rank, get_pcp_group)
get_tensor_model_parallel_world_size,
get_tp_group)
from vllm.forward_context import ForwardContext, get_forward_context from vllm.forward_context import ForwardContext, get_forward_context
from vllm.logger import logger from vllm.logger import logger
from vllm.model_executor.layers.linear import (LinearBase, from vllm.model_executor.layers.linear import (LinearBase,
UnquantizedLinearMethod) UnquantizedLinearMethod)
from vllm.utils.math_utils import cdiv, round_down from vllm.utils.math_utils import cdiv, round_down
from vllm.v1.attention.backends.utils import AttentionCGSupport from vllm.v1.attention.backends.utils import AttentionCGSupport
from vllm.v1.kv_cache_interface import MLAAttentionSpec
from vllm_ascend import envs from vllm_ascend import envs
from vllm_ascend.ascend_config import get_ascend_config from vllm_ascend.ascend_config import get_ascend_config
@@ -53,7 +50,6 @@ MAX_O_PROJ_PREFETCH_SIZE = 16 * 1024 * 1024
class AscendMLABackend(AttentionBackend): class AscendMLABackend(AttentionBackend):
accept_output_buffer: bool = True accept_output_buffer: bool = True
@staticmethod @staticmethod
@@ -62,34 +58,26 @@ class AscendMLABackend(AttentionBackend):
@staticmethod @staticmethod
def get_builder_cls(): def get_builder_cls():
prefill_config = get_current_vllm_config().parallel_config
if prefill_config.prefill_context_parallel_size > 1 or prefill_config.decode_context_parallel_size > 1:
from vllm_ascend.attention.mla_cp import AscendMlaCPMetadataBuilder
return AscendMlaCPMetadataBuilder
return AscendMLAMetadataBuilder return AscendMLAMetadataBuilder
@staticmethod @staticmethod
def get_kv_cache_shape(num_blocks: int, block_size: int, num_kv_heads: int, def get_kv_cache_shape(num_blocks: int, block_size: int, num_kv_heads: int,
head_size: int) -> tuple[int, ...]: head_size: int) -> tuple[int, ...]:
return (num_blocks, block_size, num_kv_heads, head_size) return num_blocks, block_size, num_kv_heads, head_size
@staticmethod @staticmethod
def get_impl_cls() -> Type["MLAAttentionImpl"]: def get_impl_cls() -> Type["MLAAttentionImpl"]:
prefill_config = get_current_vllm_config().parallel_config
if prefill_config.prefill_context_parallel_size > 1 or prefill_config.decode_context_parallel_size > 1:
from vllm_ascend.attention.mla_cp import AscendMlaCPImpl
return AscendMlaCPImpl
return AscendMLAImpl return AscendMLAImpl
@dataclass
class AscendPCPMetadata:
q_head_idx: torch.Tensor = None
q_tail_idx: torch.Tensor = None
kv_with_q_head_nomask_idx: torch.Tensor = None
kv_with_q_head_mask_idx: torch.Tensor = None
kv_with_q_tail_nomask_idx: torch.Tensor = None
kv_with_q_tail_mask_idx: torch.Tensor = None
attn_mask_seqlens: torch.Tensor = None
head_attn_nomask_seqlens: torch.Tensor = None
tail_attn_nomask_seqlens: torch.Tensor = None
q_full_idx: torch.Tensor = None
pcp_prefill_mask: torch.Tensor = None
pcp_allgather_restore_idx: Optional[list[int]] = None
@dataclass @dataclass
class AscendMLAPrefillMetadata: class AscendMLAPrefillMetadata:
""" Prefill Specific Metadata for Ascend""" """ Prefill Specific Metadata for Ascend"""
@@ -113,6 +101,21 @@ class AscendMLAPrefillMetadata:
cu_seq_lens_lst: Optional[list[list[int]]] = None cu_seq_lens_lst: Optional[list[list[int]]] = None
chunk_size: Optional[int] = None chunk_size: Optional[int] = None
@dataclass
class AscendPCPMetadata:
q_head_idx: torch.Tensor = None
q_tail_idx: torch.Tensor = None
kv_with_q_head_nomask_idx: torch.Tensor = None
kv_with_q_head_mask_idx: torch.Tensor = None
kv_with_q_tail_nomask_idx: torch.Tensor = None
kv_with_q_tail_mask_idx: torch.Tensor = None
attn_mask_seqlens: torch.Tensor = None
head_attn_nomask_seqlens: torch.Tensor = None
tail_attn_nomask_seqlens: torch.Tensor = None
q_full_idx: torch.Tensor = None
pcp_prefill_mask: torch.Tensor = None
pcp_allgather_restore_idx: Optional[list[int]] = None
attn_mask: torch.Tensor attn_mask: torch.Tensor
query_lens: torch.Tensor query_lens: torch.Tensor
seq_lens: list[int] seq_lens: list[int]
@@ -148,7 +151,6 @@ class AscendMLADecodeMetadata:
@dataclass @dataclass
class AscendMLAMetadata: class AscendMLAMetadata:
"""Metadata for MLACommon. """Metadata for MLACommon.
NOTE: Please read the comment at the top of the file before trying to NOTE: Please read the comment at the top of the file before trying to
understand this class understand this class
""" """
@@ -209,8 +211,8 @@ class AscendMLAMetadataBuilder:
""" """
def __init__(self, def __init__(self,
kv_cache_spec, kv_cache_spec: MLAAttentionSpec,
layer_names, layer_names: list[str],
vllm_config: VllmConfig, vllm_config: VllmConfig,
device: torch.device, device: torch.device,
metadata_cls: Optional[AscendMLAMetadata] = None): metadata_cls: Optional[AscendMLAMetadata] = None):
@@ -350,7 +352,8 @@ class AscendMLAMetadataBuilder:
FIA_SEQ_LEN_LIMIT = 16 FIA_SEQ_LEN_LIMIT = 16
need_padding = num_reqs_pad_size != 0 and \ need_padding = num_reqs_pad_size != 0 and \
len(common_attn_metadata.actual_seq_lengths_q) > num_reqs and \ len(common_attn_metadata.actual_seq_lengths_q) > num_reqs and \
common_attn_metadata.actual_seq_lengths_q[num_reqs] - actual_seq_lengths_q[-1] > FIA_SEQ_LEN_LIMIT common_attn_metadata.actual_seq_lengths_q[num_reqs] - actual_seq_lengths_q[
-1] > FIA_SEQ_LEN_LIMIT
if need_padding: if need_padding:
padding_seq_len_q = common_attn_metadata.actual_seq_lengths_q[ padding_seq_len_q = common_attn_metadata.actual_seq_lengths_q[
num_reqs:num_reqs + num_reqs_pad_size] num_reqs:num_reqs + num_reqs_pad_size]
@@ -408,7 +411,6 @@ class AscendMLAMetadataBuilder:
long_seq_metadata = common_attn_metadata.prefill_context_parallel_metadata long_seq_metadata = common_attn_metadata.prefill_context_parallel_metadata
num_actual_tokens_pcp_padded = long_seq_metadata.num_actual_tokens_pcp_padded if long_seq_metadata else None num_actual_tokens_pcp_padded = long_seq_metadata.num_actual_tokens_pcp_padded if long_seq_metadata else None
num_computed_tokens_of_pcp_dcp = long_seq_metadata.num_computed_tokens_of_pcp_dcp if long_seq_metadata else None
num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = \ num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = \
split_decodes_and_prefills(common_attn_metadata, decode_threshold=self.decode_threshold) split_decodes_and_prefills(common_attn_metadata, decode_threshold=self.decode_threshold)
@@ -428,13 +430,7 @@ class AscendMLAMetadataBuilder:
common_attn_metadata.block_table_tensor[:graph_pad_size]) common_attn_metadata.block_table_tensor[:graph_pad_size])
else: else:
block_table = (common_attn_metadata.block_table_tensor[:num_reqs]) block_table = (common_attn_metadata.block_table_tensor[:num_reqs])
# NOTE: Currently, MTP-fullgraph is incompatibility pcp
if self.pcp_size > 1:
num_decodes_flatten = num_decodes * self.decode_threshold
block_table = common_attn_metadata.block_table_tensor[:
num_decodes_flatten
+
num_prefills]
if num_actual_tokens_pcp_padded is None: if num_actual_tokens_pcp_padded is None:
num_actual_tokens_pcp_padded = num_actual_tokens num_actual_tokens_pcp_padded = num_actual_tokens
@@ -465,30 +461,6 @@ class AscendMLAMetadataBuilder:
chunked_context_metadata = None chunked_context_metadata = None
if num_prefills > 0: if num_prefills > 0:
pcp_metadata = None pcp_metadata = None
common_long_seq_metadata = common_attn_metadata.prefill_context_parallel_metadata
if common_long_seq_metadata is not None:
pcp_metadata = AscendPCPMetadata(
q_head_idx=common_long_seq_metadata.q_head_idx_tensor,
q_tail_idx=common_long_seq_metadata.q_tail_idx_tensor,
kv_with_q_head_nomask_idx=common_long_seq_metadata.
kv_with_q_head_nomask_idx_tensor,
kv_with_q_head_mask_idx=common_long_seq_metadata.
kv_with_q_head_mask_idx_tensor,
kv_with_q_tail_nomask_idx=common_long_seq_metadata.
kv_with_q_tail_nomask_idx_tensor,
kv_with_q_tail_mask_idx=common_long_seq_metadata.
kv_with_q_tail_mask_idx_tensor,
attn_mask_seqlens=common_long_seq_metadata.
attn_mask_seqlens,
head_attn_nomask_seqlens=common_long_seq_metadata.
head_attn_nomask_seqlens,
tail_attn_nomask_seqlens=common_long_seq_metadata.
tail_attn_nomask_seqlens,
q_full_idx=common_long_seq_metadata.q_full_idx,
pcp_prefill_mask=common_long_seq_metadata.pcp_prefill_mask
if long_seq_metadata else None,
pcp_allgather_restore_idx=long_seq_metadata.
pcp_allgather_restore_idx if long_seq_metadata else None)
reqs_start = num_decodes # prefill_start reqs_start = num_decodes # prefill_start
tokens_start = num_decode_tokens tokens_start = num_decode_tokens
@@ -522,78 +494,14 @@ class AscendMLAMetadataBuilder:
out=cu_seq_lens_cpu[:, 1:], out=cu_seq_lens_cpu[:, 1:],
dtype=torch.int32) dtype=torch.int32)
if self.dcp_size * self.pcp_size > 1:
if num_computed_tokens_of_pcp_dcp is not None:
local_context_lens_allranks = torch.tensor(
num_computed_tokens_of_pcp_dcp[reqs_start:num_reqs]
).reshape(-1, self.dcp_size * self.pcp_size)
# Note(qcs): The max local context lengths
# padded to `cp_local_block_size`.
padded_local_context_lens_cpu = (cdiv(
context_lens_cpu,
self.cp_virtual_block_size,
) * self.cp_local_block_size)
padded_local_max_context_chunk_across_ranks = (cdiv(
max_context_chunk,
self.cp_virtual_block_size,
) * self.cp_local_block_size)
local_chunk_starts = (
torch.arange(num_chunks,
dtype=torch.int32).unsqueeze(1).expand(
-1, num_prefills) *
padded_local_max_context_chunk_across_ranks)
local_chunk_ends = torch.min(
padded_local_context_lens_cpu.unsqueeze(0),
local_chunk_starts +
padded_local_max_context_chunk_across_ranks,
)
padded_local_chunk_seq_lens = (local_chunk_ends -
local_chunk_starts).clamp(
min=0)
padded_local_cu_chunk_seq_lens_cpu = torch.zeros(
num_chunks,
num_prefills + 1,
dtype=torch.int32,
pin_memory=True)
torch.cumsum(
padded_local_chunk_seq_lens,
dim=1,
out=padded_local_cu_chunk_seq_lens_cpu[:, 1:],
dtype=torch.int32,
)
chunked_context_metadata = AscendMLAPrefillMetadata.ChunkedContextMetadata(
cu_seq_lens=cu_seq_lens_cpu.pin_memory().to(
device, non_blocking=True),
starts=local_chunk_starts.pin_memory().to(
device, non_blocking=True),
seq_tot=padded_local_chunk_seq_lens.sum(
dim=1).tolist(),
max_seq_lens=chunk_seq_lens.max(dim=1).values.tolist(),
chunk_seq_lens=chunk_seq_lens,
chunk_seq_lens_npu=chunk_seq_lens.npu(),
workspace=self.chunked_prefill_workspace,
padded_chunk_seq_lens_npu=padded_local_chunk_seq_lens.
npu(),
padded_local_chunk_seq_lens=padded_local_chunk_seq_lens
.tolist(),
local_context_lens_allranks=local_context_lens_allranks
.tolist(),
padded_local_cu_seq_lens=
padded_local_cu_chunk_seq_lens_cpu.pin_memory().to(
device, non_blocking=True),
cu_seq_lens_lst=cu_seq_lens_cpu.tolist(),
chunk_size=padded_local_max_context_chunk_across_ranks,
)
else:
chunked_context_metadata = ( chunked_context_metadata = (
AscendMLAPrefillMetadata.ChunkedContextMetadata( AscendMLAPrefillMetadata.ChunkedContextMetadata(
cu_seq_lens=cu_seq_lens_cpu.pin_memory().to( cu_seq_lens=cu_seq_lens_cpu.pin_memory().to(
device, non_blocking=True), device, non_blocking=True),
starts=chunk_starts.pin_memory().to( starts=chunk_starts.pin_memory().to(device,
device, non_blocking=True), non_blocking=True),
seq_tot=chunk_seq_lens.sum(dim=1).tolist(), seq_tot=chunk_seq_lens.sum(dim=1).tolist(),
max_seq_lens=chunk_seq_lens.max( max_seq_lens=chunk_seq_lens.max(dim=1).values.tolist(),
dim=1).values.tolist(),
chunk_seq_lens=chunk_seq_lens, chunk_seq_lens=chunk_seq_lens,
chunk_seq_lens_npu=chunk_seq_lens.npu(), chunk_seq_lens_npu=chunk_seq_lens.npu(),
workspace=self.chunked_prefill_workspace, workspace=self.chunked_prefill_workspace,
@@ -620,9 +528,6 @@ class AscendMLAMetadataBuilder:
cos=cos, cos=cos,
pcp_metadata=pcp_metadata, pcp_metadata=pcp_metadata,
) )
if self.pcp_size > 1:
prefill_metadata.block_table = block_table[
num_decodes_flatten:, ...]
decode_metadata = None decode_metadata = None
if num_decodes > 0: if num_decodes > 0:
@@ -633,11 +538,6 @@ class AscendMLAMetadataBuilder:
max_seq_lens = seq_lens[:num_decodes].max().item() max_seq_lens = seq_lens[:num_decodes].max().item()
seq_lens = seq_lens[:num_decodes] seq_lens = seq_lens[:num_decodes]
input_positions = input_positions[:num_decode_tokens] input_positions = input_positions[:num_decode_tokens]
if self.pcp_size > 1:
# For pcp + spec decode, we flatten seq_lens and block_table
# to avoid irregular spec_attn_mask shape
block_table = block_table[:num_decodes_flatten, ...]
else:
block_table = block_table[:num_decodes, ...] block_table = block_table[:num_decodes, ...]
# NOTE: Currently, MTP-fullgraph is incompatibility pcp # NOTE: Currently, MTP-fullgraph is incompatibility pcp
# NOTE: Maybe this block_table change can be removed when graph_pad_size > 1. # NOTE: Maybe this block_table change can be removed when graph_pad_size > 1.
@@ -646,23 +546,6 @@ class AscendMLAMetadataBuilder:
block_table = block_table[:graph_pad_size, ...] block_table = block_table[:graph_pad_size, ...]
seq_lens_list = seq_lens.tolist() seq_lens_list = seq_lens.tolist()
if num_computed_tokens_of_pcp_dcp is not None:
# [bs, pcp_size, dcp_size]
num_computed_tokens_of_cp_dcp_array = np.array(
num_computed_tokens_of_pcp_dcp)[:num_decodes *
self.decode_threshold]
cp_seq_len = num_computed_tokens_of_cp_dcp_array[:,
self.pcp_rank,
self.dcp_rank]
cp_seq_len = torch.tensor(cp_seq_len, dtype=torch.int32)
batch_seq_mask = (cp_seq_len == 0)
self.batch_seq_mask_buf[:batch_seq_mask.shape[0]].copy_(
batch_seq_mask, non_blocking=True)
batch_seq_mask = self.batch_seq_mask_buf[:batch_seq_mask.
shape[0]]
cp_seq_len = torch.where(cp_seq_len == 0, 1, cp_seq_len)
else:
cp_seq_len, batch_seq_mask = None, None cp_seq_len, batch_seq_mask = None, None
if graph_pad_size > num_reqs: if graph_pad_size > num_reqs:
@@ -670,7 +553,7 @@ class AscendMLAMetadataBuilder:
num_reqs_pad_size = graph_pad_size - num_reqs num_reqs_pad_size = graph_pad_size - num_reqs
actual_seq_lengths_q = self.pad_actual_seq_len_q_mtp_disable_pad( actual_seq_lengths_q = self.pad_actual_seq_len_q_mtp_disable_pad(
num_reqs_pad_size, num_reqs, actual_seq_lengths_q) num_reqs_pad_size, num_reqs, actual_seq_lengths_q)
seq_lens_list = seq_lens_list + [0] * (graph_pad_size - \ seq_lens_list = seq_lens_list + [0] * (graph_pad_size -
num_decodes) num_decodes)
num_block_pad_size = graph_pad_size - block_table.shape[0] num_block_pad_size = graph_pad_size - block_table.shape[0]
if num_block_pad_size > 0: if num_block_pad_size > 0:
@@ -833,7 +716,7 @@ class AscendMLAImpl(MLAAttentionImpl):
attn_type: str, attn_type: str,
kv_sharing_target_layer_name: Optional[str], kv_sharing_target_layer_name: Optional[str],
**kwargs, **kwargs,
) -> None: ):
self.num_heads = num_heads self.num_heads = num_heads
self.head_size = head_size self.head_size = head_size
self.scale = float(scale) self.scale = float(scale)
@@ -870,7 +753,6 @@ class AscendMLAImpl(MLAAttentionImpl):
self.kv_a_layernorm = kwargs.get('kv_a_layernorm', None) self.kv_a_layernorm = kwargs.get('kv_a_layernorm', None)
self.q_a_layernorm = kwargs.get('q_a_layernorm', None) self.q_a_layernorm = kwargs.get('q_a_layernorm', None)
self.num_queries_per_kv = self.num_heads // self.num_kv_heads self.num_queries_per_kv = self.num_heads // self.num_kv_heads
self.tp_size = get_tensor_model_parallel_world_size()
ascend_config = get_ascend_config() ascend_config = get_ascend_config()
self.enable_shared_expert_dp = ascend_config.enable_shared_expert_dp self.enable_shared_expert_dp = ascend_config.enable_shared_expert_dp
@@ -881,35 +763,7 @@ class AscendMLAImpl(MLAAttentionImpl):
self.speculative_config = self.vllm_config.speculative_config self.speculative_config = self.vllm_config.speculative_config
self.enable_mlapo = envs.VLLM_ASCEND_ENABLE_MLAPO self.enable_mlapo = envs.VLLM_ASCEND_ENABLE_MLAPO
self.pcp_size = get_pcp_group().world_size
self.pcp_rank = get_pcp_group(
).rank_in_group if self.pcp_size > 1 else 0
self.pcp_group = get_pcp_group(
).device_group if self.pcp_size > 1 else None
self.dcp_size = get_decode_context_model_parallel_world_size()
self.dcp_rank = get_decode_context_model_parallel_rank(
) if self.dcp_size > 1 else 0
self.dcp_group = get_dcp_group(
).device_group if self.dcp_size > 1 else None
self.tp_size = get_tensor_model_parallel_world_size()
self.tp_rank = get_tensor_model_parallel_rank()
self.tp_group = get_tp_group(
).device_group if self.tp_size > 1 else None
def _v_up_proj(self, x): def _v_up_proj(self, x):
if x.dtype in [torch.float16, torch.bfloat16] \
and hasattr(torch.ops._C_ascend, "batch_matmul_transpose") \
and not self.dcp_size * self.pcp_size > 1:
x = x.view(-1, self.num_heads, self.kv_lora_rank)
b, _, _ = x.shape
res = torch.empty((b, self.num_heads, self.v_head_dim),
dtype=x.dtype,
device=x.device)
torch.ops._C_ascend.batch_matmul_transpose(x, self.W_UV, res)
x = res.reshape(-1, self.num_heads * self.v_head_dim)
else:
# Convert from (B, N, L) to (N, B, L) # Convert from (B, N, L) to (N, B, L)
x = x.view(-1, self.num_heads, self.kv_lora_rank).transpose(0, 1) x = x.view(-1, self.num_heads, self.kv_lora_rank).transpose(0, 1)
# # Multiply (N, B, L) x (N, L, V) -> (N, B, V) # # Multiply (N, B, L) x (N, L, V) -> (N, B, V)
@@ -1137,10 +991,6 @@ class AscendMLAImpl(MLAAttentionImpl):
dtype=q_nope.dtype, dtype=q_nope.dtype,
device=q_nope.device) device=q_nope.device)
if self.dcp_size * self.pcp_size > 1:
context_seq_len_npu = prefill_metadata.chunked_context.padded_chunk_seq_lens_npu[
i]
torch_npu.atb.npu_paged_cache_load( torch_npu.atb.npu_paged_cache_load(
cache_kv_c, cache_kv_c,
cache_k_pe, cache_k_pe,
@@ -1151,36 +1001,8 @@ class AscendMLAImpl(MLAAttentionImpl):
value=k_pe, value=k_pe,
) )
cache_kv_c_k_pe = torch.cat([kv_c_normed, k_pe], dim=-1)
if self.dcp_size > 1:
cache_kv_c_k_pe = get_dcp_group().all_gather(
cache_kv_c_k_pe, 0)
if self.pcp_size > 1:
cache_kv_c_k_pe = get_pcp_group().all_gather(
cache_kv_c_k_pe, 0)
if self.dcp_size * self.pcp_size > 1:
allgatered_kv_c_normed, allgatered_k_pe = cache_kv_c_k_pe.split(
[self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
kv_c_normed, k_pe = self._reorg_kvcache(
allgatered_kv_c_normed,
allgatered_k_pe,
padded_local_chunk_seq_lens_lst=prefill_metadata.
chunked_context.padded_local_chunk_seq_lens[i],
local_context_lens_allranks=prefill_metadata.
chunked_context.local_context_lens_allranks,
sum_seq_len=prefill_metadata.chunked_context.
cu_seq_lens_lst[i][-1],
max_seq_len=prefill_metadata.chunked_context.
max_seq_lens[i],
chunk_size=prefill_metadata.chunked_context.chunk_size,
chunk_idx=i,
toks=toks,
)
kv_c_normed = kv_c_normed.squeeze() kv_c_normed = kv_c_normed.squeeze()
kv_nope = self.kv_b_proj(kv_c_normed)[0].view( \ kv_nope = self.kv_b_proj(kv_c_normed)[0].view(
-1, self.num_heads, self.qk_nope_head_dim + self.v_head_dim) -1, self.num_heads, self.qk_nope_head_dim + self.v_head_dim)
k_nope, v = kv_nope \ k_nope, v = kv_nope \
.split([self.qk_nope_head_dim, self.v_head_dim], dim=-1) .split([self.qk_nope_head_dim, self.v_head_dim], dim=-1)
@@ -1248,8 +1070,9 @@ class AscendMLAImpl(MLAAttentionImpl):
calc_type="calc_type_first_ring", calc_type="calc_type_first_ring",
output=attn_output, output=attn_output,
softmax_lse=attn_lse) softmax_lse=attn_lse)
attn_output, attn_lse = self._compute_prefill_context( \ attn_output, attn_lse = self._compute_prefill_context(
q_nope, q_pe, kv_c_and_k_pe_cache, self.qk_rope_head_dim, attn_metadata, attn_output, attn_lse) q_nope, q_pe, kv_c_and_k_pe_cache, self.qk_rope_head_dim,
attn_metadata, attn_output, attn_lse)
attn_output = attn_output.reshape( attn_output = attn_output.reshape(
[num_tokens, self.num_heads * self.v_head_dim]) [num_tokens, self.num_heads * self.v_head_dim])
@@ -1488,13 +1311,6 @@ class AscendMLAImpl(MLAAttentionImpl):
self.kv_lora_rank) self.kv_lora_rank)
decode_q_pe = decode_q_pe.view(bsz, self.num_heads, -1) decode_q_pe = decode_q_pe.view(bsz, self.num_heads, -1)
if self.dcp_size > 1:
decode_q_no_split = torch.cat([decode_q_nope, decode_q_pe], dim=-1)
decode_q_no_split = get_dcp_group().all_gather(
decode_q_no_split, 1)
decode_q_nope, decode_q_pe = decode_q_no_split.split(
[self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
decode_preprocess_res = DecodeMLAPreprocessResult( decode_preprocess_res = DecodeMLAPreprocessResult(
decode_q_nope, decode_q_pe, decode_k_nope, decode_k_pe) decode_q_nope, decode_q_pe, decode_k_nope, decode_k_pe)
return decode_preprocess_res, None return decode_preprocess_res, None
@@ -1551,17 +1367,8 @@ class AscendMLAImpl(MLAAttentionImpl):
sin = attn_metadata.decode.sin sin = attn_metadata.decode.sin
decode_ql_nope, decode_q_pe = \ decode_ql_nope, decode_q_pe = \
self._q_proj_and_k_up_proj(decode_q_c) self._q_proj_and_k_up_proj(decode_q_c)
if self.dcp_size > 1:
decode_q_no_split = torch.cat([decode_ql_nope, decode_q_pe],
dim=-1)
decode_q_no_split = get_dcp_group().all_gather(
decode_q_no_split, 1)
decode_ql_nope, decode_q_pe = decode_q_no_split.split(
[self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
decode_q_pe = self.rope_single(decode_q_pe, cos, sin) decode_q_pe = self.rope_single(decode_q_pe, cos, sin)
decode_slots = attn_metadata.slot_mapping[:num_decode_tokens * decode_slots = attn_metadata.slot_mapping[:num_decode_tokens:1]
self.pcp_size:self.
pcp_size]
decode_kv_no_split = kv_no_split[:num_decode_tokens] decode_kv_no_split = kv_no_split[:num_decode_tokens]
decode_k_pe, decode_k_nope = self.exec_kv_decode( decode_k_pe, decode_k_nope = self.exec_kv_decode(
decode_kv_no_split, cos, sin, kv_cache, decode_slots) decode_kv_no_split, cos, sin, kv_cache, decode_slots)
@@ -1569,10 +1376,6 @@ class AscendMLAImpl(MLAAttentionImpl):
decode_ql_nope, decode_q_pe, decode_k_nope, decode_k_pe) decode_ql_nope, decode_q_pe, decode_k_nope, decode_k_pe)
# Preprocess for prefill tokens # Preprocess for prefill tokens
if has_prefill: if has_prefill:
if self.pcp_size > 1:
num_actual_tokens = (attn_metadata.num_actual_tokens_pcp_padded
- self.pcp_size * num_decode_tokens
) // self.pcp_size + num_decode_tokens
prefill_kv_no_split = kv_no_split[ prefill_kv_no_split = kv_no_split[
num_decode_tokens:num_actual_tokens] num_decode_tokens:num_actual_tokens]
prefill_q_c = q_c[num_decode_tokens:num_actual_tokens] prefill_q_c = q_c[num_decode_tokens:num_actual_tokens]
@@ -1580,55 +1383,11 @@ class AscendMLAImpl(MLAAttentionImpl):
.view(-1, self.num_heads, self.qk_head_dim) .view(-1, self.num_heads, self.qk_head_dim)
prefill_q_pe = prefill_q[..., self.qk_nope_head_dim:] prefill_q_pe = prefill_q[..., self.qk_nope_head_dim:]
prefill_q_nope = prefill_q[..., :self.qk_nope_head_dim] prefill_q_nope = prefill_q[..., :self.qk_nope_head_dim]
if self.pcp_size > 1:
cos = attn_metadata.prefill.cos[:num_actual_tokens -
num_decode_tokens]
sin = attn_metadata.prefill.sin[:num_actual_tokens -
num_decode_tokens]
else:
cos = attn_metadata.prefill.cos cos = attn_metadata.prefill.cos
sin = attn_metadata.prefill.sin sin = attn_metadata.prefill.sin
prefill_slots = attn_metadata.slot_mapping[ prefill_slots = attn_metadata.slot_mapping[
num_decode_tokens:num_actual_tokens] num_decode_tokens:num_actual_tokens]
prefill_q_pe = self.rope_single(prefill_q_pe, cos, sin) prefill_q_pe = self.rope_single(prefill_q_pe, cos, sin)
if self.pcp_size > 1:
prefill_kv_no_split = kv_no_split[:num_actual_tokens]
kv_c, k_pe = prefill_kv_no_split.split(
[self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
kv_c_normed = self.kv_a_layernorm(kv_c.contiguous())
assert len(
kv_cache
) > 1, "the number of kv cache should be greater than 1, namely (nope_cache and rope_cache)"
kv_c_normed = kv_c_normed.view(
[num_actual_tokens, self.num_kv_heads, -1])
k_pe = k_pe.unsqueeze(1)
prefill_k_pe = k_pe
prefill_k_pe[
num_decode_tokens:num_actual_tokens] = self.rope_single(
prefill_k_pe[num_decode_tokens:num_actual_tokens], cos,
sin)
prefill_k_c_normed = kv_c_normed[:num_actual_tokens]
prefill_kv_c_k_pe = torch.cat(
[prefill_k_c_normed, prefill_k_pe], dim=-1)
prefill_kv_c_k_pe = get_pcp_group().all_gather(
prefill_kv_c_k_pe, 0)
prefill_kv_c_k_pe = torch.index_select(
prefill_kv_c_k_pe, 0, attn_metadata.prefill.pcp_metadata.
pcp_allgather_restore_idx)
prefill_kv_c_k_pe = prefill_kv_c_k_pe[num_decode_tokens *
self.pcp_size:]
prefill_k_c_normed, prefill_k_pe = prefill_kv_c_k_pe.split(
[self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
kv_c_normed, k_pe = prefill_k_c_normed, prefill_k_pe
prefill_k_c_normed = prefill_k_c_normed.squeeze()
slot_mapping = attn_metadata.slot_mapping[self.pcp_size *
num_decode_tokens:]
torch_npu._npu_reshape_and_cache(key=kv_c_normed,
value=k_pe,
key_cache=kv_cache[0],
value_cache=kv_cache[1],
slot_indices=slot_mapping)
else:
prefill_k_pe, prefill_k_c_normed = self.exec_kv_prefill( prefill_k_pe, prefill_k_c_normed = self.exec_kv_prefill(
prefill_kv_no_split, cos, sin, kv_cache, prefill_slots) prefill_kv_no_split, cos, sin, kv_cache, prefill_slots)
prefill_k_nope, prefill_value = self.kv_b_proj( prefill_k_nope, prefill_value = self.kv_b_proj(
@@ -1636,7 +1395,6 @@ class AscendMLAImpl(MLAAttentionImpl):
-1, self.num_heads, -1, self.num_heads,
self.qk_nope_head_dim + self.v_head_dim).split( self.qk_nope_head_dim + self.v_head_dim).split(
[self.qk_nope_head_dim, self.v_head_dim], dim=-1) [self.qk_nope_head_dim, self.v_head_dim], dim=-1)
if not self.pcp_size > 1:
prefill_k_pe = prefill_k_pe.view(prefill_q_c.shape[0], prefill_k_pe = prefill_k_pe.view(prefill_q_c.shape[0],
self.num_kv_heads, -1) self.num_kv_heads, -1)
prefill_k_pe = prefill_k_pe.expand( prefill_k_pe = prefill_k_pe.expand(
@@ -1662,9 +1420,6 @@ class AscendMLAImpl(MLAAttentionImpl):
self.vllm_config, self.o_proj): self.vllm_config, self.o_proj):
reach_layer_for_shared_weight_series(self.o_proj) reach_layer_for_shared_weight_series(self.o_proj)
return output.fill_(0) return output.fill_(0)
if self.pcp_size > 1:
num_actual_tokens = attn_metadata.num_actual_tokens_pcp_padded // self.pcp_size
else:
num_actual_tokens = attn_metadata.num_actual_tokens num_actual_tokens = attn_metadata.num_actual_tokens
assert attn_metadata.num_decodes is not None and \ assert attn_metadata.num_decodes is not None and \
attn_metadata.num_prefills is not None and \ attn_metadata.num_prefills is not None and \
@@ -1693,20 +1448,12 @@ class AscendMLAImpl(MLAAttentionImpl):
if decode_preprocess_res is not None: if decode_preprocess_res is not None:
# MLA Preprocess for decoding # MLA Preprocess for decoding
if self.pcp_size * self.dcp_size > 1: output_decode = self._forward_decode(decode_preprocess_res.ql_nope,
output_decode = self._forward_decode_pcp_dcp(
decode_preprocess_res.ql_nope,
decode_preprocess_res.q_pe, decode_preprocess_res.q_pe,
decode_preprocess_res.k_nope, decode_preprocess_res.k_nope,
decode_preprocess_res.k_pe, decode_preprocess_res.k_pe,
kv_cache[0].shape[1], kv_cache[0].shape[1],
attn_metadata, attn_metadata)
)
else:
output_decode = self._forward_decode(
decode_preprocess_res.ql_nope, decode_preprocess_res.q_pe,
decode_preprocess_res.k_nope, decode_preprocess_res.k_pe,
kv_cache[0].shape[1], attn_metadata)
o_proj_input[:num_decode_tokens] = output_decode o_proj_input[:num_decode_tokens] = output_decode
@@ -1714,12 +1461,6 @@ class AscendMLAImpl(MLAAttentionImpl):
# FIX: aicore move should be also placed on the comm stream in dbo, # FIX: aicore move should be also placed on the comm stream in dbo,
# otherwise it may affect the accuracy # otherwise it may affect the accuracy
# TODO: use an elegant way to overlap # TODO: use an elegant way to overlap
if self.pcp_size > 1:
output_prefill = self._forward_prefill_cp(
prefill_preprocess_res.q_nope, prefill_preprocess_res.q_pe,
prefill_preprocess_res.k_nope, prefill_preprocess_res.k_pe,
prefill_preprocess_res.value, kv_cache, attn_metadata)
else:
output_prefill = self._forward_prefill( output_prefill = self._forward_prefill(
prefill_preprocess_res.q_nope, prefill_preprocess_res.q_pe, prefill_preprocess_res.q_nope, prefill_preprocess_res.q_pe,
prefill_preprocess_res.k_nope, prefill_preprocess_res.k_pe, prefill_preprocess_res.k_nope, prefill_preprocess_res.k_pe,
@@ -1743,377 +1484,3 @@ class AscendMLAImpl(MLAAttentionImpl):
if has_prefill: if has_prefill:
maybe_save_kv_layer_to_connector(layer_name, list(kv_cache)) maybe_save_kv_layer_to_connector(layer_name, list(kv_cache))
return output_padded return output_padded
def _forward_prefill_cp(
self,
q_nope: torch.Tensor,
q_pe: torch.Tensor,
k_nope: torch.Tensor,
k_pe: torch.Tensor,
value: torch.Tensor,
kv_c_and_k_pe_cache: Tuple[torch.Tensor],
attn_metadata: AscendMLAMetadata,
) -> torch.Tensor:
assert attn_metadata.prefill is not None
assert attn_metadata.prefill.pcp_metadata is not None
num_tokens = q_nope.size(0)
# Use precomputed indices from the metadata (already converted to tensors and on device)
q_head_idx = attn_metadata.prefill.pcp_metadata.q_head_idx
q_tail_idx = attn_metadata.prefill.pcp_metadata.q_tail_idx
kv_with_q_head_nomask_idx = attn_metadata.prefill.pcp_metadata.kv_with_q_head_nomask_idx
kv_with_q_head_mask_idx = attn_metadata.prefill.pcp_metadata.kv_with_q_head_mask_idx
kv_with_q_tail_nomask_idx = attn_metadata.prefill.pcp_metadata.kv_with_q_tail_nomask_idx
kv_with_q_tail_mask_idx = attn_metadata.prefill.pcp_metadata.kv_with_q_tail_mask_idx
attn_mask_seqlens = attn_metadata.prefill.pcp_metadata.attn_mask_seqlens
head_attn_nomask_seqlens = attn_metadata.prefill.pcp_metadata.head_attn_nomask_seqlens
tail_attn_nomask_seqlens = attn_metadata.prefill.pcp_metadata.tail_attn_nomask_seqlens
mask = attn_metadata.prefill.pcp_metadata.pcp_prefill_mask
output_head, lse_head = self._attention_with_mask_and_nomask(
q_nope=torch.index_select(q_nope, 0, q_head_idx),
q_pe=torch.index_select(q_pe, 0, q_head_idx),
k_nope=k_nope,
k_pe=k_pe,
value=value,
kv_mask_idx=kv_with_q_head_mask_idx,
kv_nomask_idx=kv_with_q_head_nomask_idx,
attn_mask_seqlens=attn_mask_seqlens,
attn_nomask_seqlens=head_attn_nomask_seqlens,
mask=mask)
output_tail, lse_tail = self._attention_with_mask_and_nomask(
q_nope=torch.index_select(q_nope, 0, q_tail_idx),
q_pe=torch.index_select(q_pe, 0, q_tail_idx),
k_nope=k_nope,
k_pe=k_pe,
value=value,
kv_mask_idx=kv_with_q_tail_mask_idx,
kv_nomask_idx=kv_with_q_tail_nomask_idx,
attn_mask_seqlens=attn_mask_seqlens,
attn_nomask_seqlens=tail_attn_nomask_seqlens,
mask=mask)
q_full_idx = attn_metadata.prefill.pcp_metadata.q_full_idx
attn_output = torch.index_select(
torch.cat([output_head, output_tail], dim=0), 0, q_full_idx)
attn_lse = torch.index_select(torch.cat([lse_head, lse_tail], dim=1),
1, q_full_idx)
output, _ = self._compute_prefill_context( \
q_nope, q_pe, kv_c_and_k_pe_cache, self.qk_rope_head_dim, attn_metadata, attn_output, attn_lse)
output = output.reshape([num_tokens, self.num_heads * self.v_head_dim])
return output
def _attention_with_mask_and_nomask(
self, q_nope: torch.Tensor, q_pe: torch.Tensor,
k_nope: torch.Tensor, k_pe: torch.Tensor, value: torch.Tensor,
kv_mask_idx: torch.Tensor, kv_nomask_idx: torch.Tensor,
attn_mask_seqlens: torch.Tensor, attn_nomask_seqlens: torch.Tensor,
mask: torch.Tensor):
attn_output = torch.empty(q_nope.shape[0],
self.num_heads,
self.v_head_dim,
dtype=k_pe.dtype,
device=k_pe.device)
attn_lse = torch.empty(self.num_heads,
q_pe.shape[0],
dtype=torch.float32,
device=k_pe.device)
# mask
k_nope_mask = torch.index_select(k_nope, 0, kv_mask_idx)
value_mask = torch.index_select(value, 0, kv_mask_idx)
k_pe_mask = torch.index_select(k_pe, 0, kv_mask_idx)
torch_npu.atb.npu_ring_mla(q_nope=q_nope,
q_rope=q_pe,
k_nope=k_nope_mask,
k_rope=k_pe_mask,
value=value_mask,
mask=mask,
seqlen=attn_mask_seqlens,
head_num=self.num_heads,
kv_head_num=self.num_heads,
pre_out=None,
prev_lse=None,
qk_scale=self.scale,
kernel_type="kernel_type_high_precision",
mask_type="mask_type_triu",
input_layout="type_bsnd",
calc_type="calc_type_first_ring",
output=attn_output,
softmax_lse=attn_lse)
# nomask
if kv_nomask_idx.shape[0] == 0:
return attn_output, attn_lse
k_nope_nomask = torch.index_select(k_nope, 0, kv_nomask_idx)
value_nomask = torch.index_select(value, 0, kv_nomask_idx)
k_pe_nomask = torch.index_select(k_pe, 0, kv_nomask_idx)
torch_npu.atb.npu_ring_mla(q_nope=q_nope,
q_rope=q_pe,
k_nope=k_nope_nomask,
k_rope=k_pe_nomask,
value=value_nomask,
mask=mask,
seqlen=attn_nomask_seqlens,
head_num=self.num_heads,
kv_head_num=self.num_heads,
pre_out=attn_output,
prev_lse=attn_lse,
qk_scale=self.scale,
kernel_type="kernel_type_high_precision",
mask_type="no_mask",
input_layout="type_bsnd",
calc_type="calc_type_default",
output=attn_output,
softmax_lse=attn_lse)
return attn_output, attn_lse
def _forward_decode_pcp_dcp(
self,
q_nope: torch.Tensor,
q_pe: torch.Tensor,
k_nope: torch.Tensor,
k_pe: torch.Tensor,
block_size: int,
attn_metadata: AscendMLAMetadata,
) -> torch.Tensor:
decode_meta = attn_metadata.decode
assert decode_meta is not None
num_tokens = q_nope.size(0)
# shape of knope/k_pe for npu graph mode should be:
# [num_blocks, num_kv_heads, block_size, self.kv_lora_rank/self.qk_rope_head_dim]
if self.dcp_size > 1:
num_heads = self.num_heads * self.dcp_size
else:
num_heads = self.num_heads
k_nope = k_nope.view(-1, block_size, self.num_kv_heads,
self.kv_lora_rank)
k_pe = k_pe.view(-1, block_size, self.num_kv_heads,
self.qk_rope_head_dim)
q_nope = q_nope.view(num_tokens, num_heads, -1)
q_pe = q_pe.view(num_tokens, num_heads, -1)
# use pcp & dcp split computed token nums from scheduler to compute actual seq_len and seq_mask
seq_len = decode_meta.cp_seq_len
common_kwargs = {
"return_lse": True,
"calc_type": "calc_type_ring",
}
graph_params = get_graph_params()
forward_context: ForwardContext = get_forward_context()
if forward_context.capturing:
stream = torch_npu.npu.current_stream()
event = torch.npu.ExternalEvent()
event.wait(stream)
event.reset(stream)
graph_params.events[num_tokens].append(event)
workspace = graph_params.workspaces.get(num_tokens)
if workspace is None:
workspace = torch_npu.atb._npu_multi_head_latent_attention_get_workspace(
q_nope, q_pe, k_nope, k_pe, decode_meta.block_table,
seq_len, num_heads, self.scale, self.num_kv_heads,
**common_kwargs)
update_graph_params_workspaces(num_tokens, workspace)
attn_output = torch.empty_like(q_nope)
softmax_lse = torch.empty((num_tokens, num_heads, 1),
dtype=q_nope.dtype,
device=q_nope.device)
graph_params.attn_params[num_tokens].append(
(weak_ref_tensors(q_nope), weak_ref_tensors(q_pe),
weak_ref_tensors(k_nope), weak_ref_tensors(k_pe),
decode_meta.block_table, seq_len, num_heads, self.scale,
self.num_kv_heads, weak_ref_tensors(attn_output),
weak_ref_tensors(softmax_lse)))
torch.npu.graph_task_group_begin(stream)
torch_npu.atb.npu_multi_head_latent_attention(
q_nope,
q_pe,
k_nope,
k_pe,
decode_meta.block_table,
seq_len,
num_heads,
self.scale,
self.num_kv_heads,
**common_kwargs,
workspace=workspace,
output=attn_output,
lse=softmax_lse)
handle = torch.npu.graph_task_group_end(stream)
graph_params.handles[num_tokens].append(handle)
else:
attn_output = torch.empty_like(q_nope)
softmax_lse = torch.empty((num_tokens, num_heads, 1),
dtype=q_nope.dtype,
device=q_nope.device)
torch_npu.atb.npu_multi_head_latent_attention(
q_nope,
q_pe,
k_nope,
k_pe,
decode_meta.block_table,
seq_len,
num_heads,
self.scale,
self.num_kv_heads,
return_lse=True,
calc_type="calc_type_ring",
output=attn_output,
lse=softmax_lse)
# Update out&lse
attn_out_lse_list = self._process_attn_out_lse(attn_output,
softmax_lse,
decode_meta)
attn_output = self._npu_attention_update(attn_out_lse_list)
return self._v_up_proj(attn_output)
def _npu_attention_update(
self, attn_out_lse_list: List[torch.Tensor]) -> torch.Tensor:
attn_out_split_cp = []
attn_lse_split_cp = []
for attn_out_lse in attn_out_lse_list:
attn_out_allgather, attn_lse_allgather = self._out_lse_reshape(
*torch.split(attn_out_lse, [self.kv_lora_rank, 1], dim=-1))
attn_out_split_cp.append(attn_out_allgather)
attn_lse_split_cp.append(attn_lse_allgather)
attn_out, _ = torch_npu.npu_attention_update(attn_lse_split_cp,
attn_out_split_cp, 0)
attn_out = attn_out.view(-1, attn_out_lse_list[0].shape[1],
self.kv_lora_rank)
return attn_out
def _out_lse_reshape(self, attn_out: torch.Tensor,
attn_lse: torch.Tensor) -> torch.Tensor:
attn_out = attn_out.contiguous().view(
attn_out.shape[0] * attn_out.shape[1], attn_out.shape[2])
attn_lse = attn_lse.contiguous().view(
attn_lse.shape[0] * attn_lse.shape[1] * attn_lse.shape[2])
return attn_out, attn_lse
def _process_attn_out_lse(
self,
attn_output: torch.Tensor,
softmax_lse: torch.Tensor,
decode_meta: AscendMLADecodeMetadata,
) -> List[torch.Tensor]:
attn_out_lse_list = []
out_mask = decode_meta.batch_seq_mask[:, None,
None].expand_as(attn_output)
attn_output = torch.where(out_mask, 0, attn_output)
lse_mask = decode_meta.batch_seq_mask[:, None,
None].expand_as(softmax_lse)
softmax_lse = torch.where(lse_mask, -torch.inf, softmax_lse)
softmax_lse = softmax_lse.to(torch.float32)
attn_output = attn_output.to(torch.float32)
# Concat out&lse: [bs,num_heads,v_head_dim] + [bs,num_heads,1] -> [bs,num_heads,v_head_dim+1]
attn_out_lse = torch.cat([attn_output, softmax_lse], dim=-1)
if self.dcp_size > 1:
# permute: [bs, num_heads, v_head_dim+1] -> [num_heads, v_head_dim+1, bs]
attn_out_lse = attn_out_lse.permute([1, 2, 0]).contiguous()
attn_out_lse_all2all = torch.empty_like(attn_out_lse)
dist.all_to_all_single(attn_out_lse_all2all,
attn_out_lse,
group=self.dcp_group)
# permute: [num_heads, v_head_dim+1, bs] -> [bs, num_heads, v_head_dim+1]
attn_out_lse_all2all = attn_out_lse_all2all.permute([2, 0, 1])
if self.pcp_size > 1:
attn_out_lse = attn_out_lse_all2all.contiguous()
attn_out_lse_list = list(
torch.chunk(attn_out_lse_all2all, self.dcp_size, dim=1))
if self.pcp_size > 1:
# AllGather out&lse within PCP group
attn_out_lse_list = [
torch.empty_like(attn_out_lse) for _ in range(self.pcp_size)
]
dist.all_gather(attn_out_lse_list,
attn_out_lse,
group=self.pcp_group)
if self.dcp_size > 1 and self.pcp_size > 1:
attn_out_lse_list_pcp_dcp = []
for s in attn_out_lse_list:
attn_out_lse_list_split = list(
torch.chunk(s, self.dcp_size, dim=1))
attn_out_lse_list_pcp_dcp += attn_out_lse_list_split
attn_out_lse_list = attn_out_lse_list_pcp_dcp
return attn_out_lse_list
def _reorg_kvcache(
self,
allgatered_kv_c_normed: torch.Tensor,
allgatered_k_pe: torch.Tensor,
padded_local_chunk_seq_lens_lst: list[int],
local_context_lens_allranks: list[list[int]],
sum_seq_len: int,
max_seq_len: int,
chunk_size: int,
chunk_idx: int,
toks: int,
) -> tuple[torch.Tensor, torch.Tensor]:
"""
reorg and unpad kvcache after cp local gather to tp layout for attn kernel.
e.g.
kv_c_normed in rank0 = [T0_0, T0_1, T0_2, T0_3, T1_0, T1_1, ...]
kv_c_normed in rank1 = [T0_4, T0_5, pad, pad, T1_2, pad, ...]
allgatered_kv_c_normed = [T0_0, T0_1, T0_2, T0_3, T1_0, T1_1, ...,
T0_4, T0_5, pad, pad, T1_2, pad, ...]
-> reorganized_kv_c_normed = [T0_0, T0_1, T0_2, T0_3, T0_4, T0_5,
T1_0, T1_1, T1_2, ...]
Args:
padded_local_chunk_seq_lens_lst: local chunk context lengths
under current CP rank.
local_context_lens_allranks: local context lengths on each CP rank.
sum_seq_len: the sum of cp_chunk_seq_lens_lst.
max_seq_len: the max value of cp_chunk_seq_lens_lst.
chunk_size: the local padded max context chunk from
chunked_context_metadata building.
chunk_idx: chunk idx of chunked_prefill.
toks: the number of tokens for local gather cache.
"""
kv_c_segments = []
k_pe_segments = []
src_token_idx = 0
max_seq_len_check = 0
for padded_local_chunk_seq_len, local_context_lens in zip(
padded_local_chunk_seq_lens_lst, local_context_lens_allranks):
cur_seq_len = 0
for rank, local_context_len in enumerate(local_context_lens):
# Note(qcs): We split the context into multiple chunks,
# depending on the size of the workspace.
# local_context in dcp0: |-----------------|
# local_context in dcp1: |--------------|
# n*padded_local_chunk: |-----|-----|-----|
# local_chunk_len in dcp1: |-----|-----|--|
# so we need update the last chunk length in dcp1.
local_chunk_len = min(
max(0, local_context_len - chunk_idx * chunk_size),
padded_local_chunk_seq_len,
)
if local_chunk_len != 0:
kv_c_segment = allgatered_kv_c_normed[rank * toks +
src_token_idx:rank *
toks +
src_token_idx +
local_chunk_len]
k_pe_segment = allgatered_k_pe[rank * toks +
src_token_idx:rank * toks +
src_token_idx +
local_chunk_len]
kv_c_segments.append(kv_c_segment)
k_pe_segments.append(k_pe_segment)
cur_seq_len += local_chunk_len
max_seq_len_check = max(max_seq_len_check, cur_seq_len)
src_token_idx += padded_local_chunk_seq_len
reorganized_kv_c_normed = torch.cat(kv_c_segments, dim=0)
reorganized_k_pe = torch.cat(k_pe_segments, dim=0)
assert reorganized_kv_c_normed.shape[0] == sum_seq_len
assert reorganized_k_pe.shape[0] == sum_seq_len
assert max_seq_len_check == max_seq_len
return reorganized_kv_c_normed, reorganized_k_pe