support cp&dcp (#3260)

### What this PR does / why we need it?
This PR adds the Prefill Context Parallelism (PCP) feature, the prefill-stage counterpart of Decode Context Parallelism (DCP). For implementation details, please refer to the RFC: https://github.com/vllm-project/vllm/issues/25749.
TL;DR: PCP improves long-sequence inference by partitioning sequences along the sequence (token) dimension across devices during the prefill stage.
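For intuition, here is a minimal, self-contained sketch (not taken from this PR) of the load-balanced split that PCP performs along the sequence dimension: each rank receives a "head" chunk from the front of the prompt and a "tail" chunk from the back, so the causal-attention work per rank stays roughly balanced. The function name and chunking details below are illustrative assumptions; the actual partitioning lives in the modified model runner.

```python
# Illustrative only (not part of this PR): a load-balanced assignment of one
# prompt's token indices to PCP ranks, two chunks per rank (head + tail).
def pcp_split_indices(seq_len: int, pcp_size: int, pcp_rank: int) -> list[int]:
    num_chunks = 2 * pcp_size
    chunk = (seq_len + num_chunks - 1) // num_chunks  # ceil; padding ignored here
    head_start = pcp_rank * chunk
    tail_start = (num_chunks - 1 - pcp_rank) * chunk
    head = list(range(head_start, min(head_start + chunk, seq_len)))
    tail = list(range(tail_start, min(tail_start + chunk, seq_len)))
    return head + tail

# e.g. seq_len=8, pcp_size=2: rank 0 -> [0, 1, 6, 7], rank 1 -> [2, 3, 4, 5]
```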
### Does this PR introduce _any_ user-facing change?
The current implementation primarily includes the following changes:

- Modified ModelRunner.py to add the CP partitioning logic for tokens.
- Modified attention_v1.py and mla_v1.py to adapt the GQA and MLA backends to PCP.
- Modified block_tables.py to extend KV-cache storage for DCP & PCP.
- Added the necessary command-line arguments to control PCP parallelism (see the example script in the diff below).
### How was this patch tested?


- vLLM version: v0.11.0rc3
- vLLM main: https://github.com/vllm-project/vllm/commit/v0.11.0

---------

Signed-off-by: LookAround <lixushi@huawei.com>
Signed-off-by: chenjie <chenjie137@huawei.com>
Signed-off-by: Delphine-Nic <tanwenqin@huawei.com>
Signed-off-by: zhangsicheng5 <zhangsicheng5@huawei.com>
Signed-off-by: Feng Liu <liufeng248@huawei.com>
Signed-off-by: gaojc <1055866782@qq.com>
Signed-off-by: weiguihua2 <weiguihua2@huawei.com>
Signed-off-by: z50049692 <zhangmingwei11@huawei.com>
Co-authored-by: chenjie <chenjie137@huawei.com>
Co-authored-by: Delphine-Nic <tanwenqin@huawei.com>
Co-authored-by: zhangsicheng5 <zhangsicheng5@huawei.com>
Co-authored-by: Feng Liu <liufeng248@huawei.com>
Co-authored-by: gaojc <1055866782@qq.com>
Co-authored-by: weiguihua2 <weiguihua2@huawei.com>
Co-authored-by: z50049692 <zhangmingwei11@huawei.com>
Co-authored-by: w00896881 <wangzixuan40@huawei.com>
Authored by LookAround0301, committed by GitHub on 2025-10-24 10:32:01 +08:00.
parent 2bcadcb9d5, commit b54d44e664
18 changed files with 1729 additions and 211 deletions


@@ -0,0 +1,60 @@
import os
import time
import argparse
from vllm import LLM, SamplingParams
os.environ["VLLM_USE_MODELSCOPE"] = "True"
os.environ["VLLM_ASCEND_ENABLE_CONTEXT_PARALLEL"] = "1"
os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn"
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument('--input_len', type=int, default=1024)
parser.add_argument('--output_len', type=int, default=128)
parser.add_argument('--bs', type=int, default=1)
parser.add_argument('--model_path', type=str, default="deepseek-ai/DeepSeek-V2-Lite")
parser.add_argument('--tp', type=int, default=2)
parser.add_argument('--pcp', type=int, default=2)
parser.add_argument('--dcp', type=int, default=1)
parser.add_argument('--iter_times', type=int, default=1)
args = parser.parse_args()
prompts = [
"The capital of France is",
"Hello, my name is Tom, I am",
"The president of United States is",
"AI future is"
]
sampling_params = SamplingParams(temperature=0.8, top_p=0.95, max_tokens=args.output_len)
llm = LLM(
model=args.model_path,
trust_remote_code=True,
enforce_eager=True,
tensor_parallel_size=args.tp,
prefill_context_parallel_size=args.pcp,
decode_context_parallel_size=args.dcp,
enable_prefix_caching=False,
enable_expert_parallel=True,
enable_chunked_prefill=False,
max_num_batched_tokens=2048,
max_model_len=1024,
additional_config={"ascend_scheduler_config": {"enabled": False}},
max_num_seqs=1,
block_size=128,
gpu_memory_utilization=0.9
)
t0 = time.time()
for _ in range(args.iter_times):
outputs = llm.generate(prompts, sampling_params)
t1 = time.time()
print(f"TTFT: {(t1 - t0) * 1000 / (args.iter_times * args.bs)} ms")
for i, output in enumerate(outputs):
prompt = output.prompt
generated_text = output.outputs[0].text
print(f"req_num: {i}\nGenerated text: {generated_text!r}")


@@ -1,6 +1,7 @@
from unittest.mock import MagicMock, patch
import torch
from vllm.distributed.parallel_state import GroupCoordinator
from tests.ut.base import TestBase
from vllm_ascend.attention.attention_v1 import (AscendAttentionBackend,
@@ -175,7 +176,19 @@ class TestAscendAttentionMetadataBuilder(TestBase):
class TestAscendAttentionBackendImpl(TestBase):
def setUp(self):
@patch('vllm.distributed.parallel_state.get_dcp_group')
@patch('vllm.distributed.parallel_state._DCP',
new_callable=lambda: MagicMock(spec=GroupCoordinator))
@patch("vllm.distributed.get_decode_context_model_parallel_world_size",
return_value=1)
def setUp(self, mock_get_dcp_size, mock_dcp, mock_get_dcp_group):
mock_dcp.world_size = 1
dcp_group = MagicMock(spec=GroupCoordinator)
dcp_group.rank_in_group = 0
dcp_group.world_size = 1
dcp_group.device_group = MagicMock()
mock_get_dcp_group.return_value = dcp_group
self.layer = MagicMock()
self.layer.layer_name = "test_layer"
self.layer._k_scale_float = 1.0
@@ -328,6 +341,8 @@ class TestAscendAttentionBackendImpl(TestBase):
metadata.seq_lens = torch.tensor([10])
metadata.num_actual_tokens = 10
metadata.slot_mapping = torch.zeros(10, dtype=torch.long)
metadata.num_decodes = 0
metadata.num_prefills = 10
layer = self.layer_no_quant
# layer.quant_method.apply.return_value = metadata
print(self.layer_no_quant._v_scale_float)
@@ -360,6 +375,8 @@ class TestAscendAttentionBackendImpl(TestBase):
metadata.block_tables = torch.zeros(1, 5, dtype=torch.long)
metadata.num_actual_tokens = 10
metadata.slot_mapping = torch.zeros(10, dtype=torch.long)
metadata.num_decodes = 0
metadata.num_prefills = 10
layer = self.layer_no_quant
output = self.impl.forward(layer,
@@ -390,6 +407,8 @@ class TestAscendAttentionBackendImpl(TestBase):
metadata.block_tables = torch.zeros(1, 5, dtype=torch.long)
metadata.num_actual_tokens = 10
metadata.slot_mapping = torch.zeros(10, dtype=torch.long)
metadata.num_decodes = 10
metadata.num_prefills = 0
layer = self.layer_no_quant
mock_get_forward_context.return_value = MagicMock(capturing=False)
@@ -496,6 +515,8 @@ class TestAscendAttentionBackendImpl(TestBase):
metadata.block_tables = torch.zeros(1, 5, dtype=torch.long)
metadata.num_actual_tokens = 10
metadata.slot_mapping = torch.zeros(10, dtype=torch.long)
metadata.num_decodes = 0
metadata.num_prefills = 10
layer = self.layer_no_quant
mock_get_forward_context.return_value = MagicMock(capturing=True)
@@ -527,6 +548,8 @@ class TestAscendAttentionBackendImpl(TestBase):
metadata.block_tables = torch.zeros(1, 5, dtype=torch.long)
metadata.num_actual_tokens = 100
metadata.slot_mapping = torch.zeros(10, dtype=torch.long)
metadata.num_decodes = 10
metadata.num_prefills = 0
layer = self.layer_no_quant
mock_fused_infer_attention_score.return_value = (torch.ones(10, 8,
64), 1)
@@ -560,6 +583,8 @@ class TestAscendAttentionBackendImpl(TestBase):
metadata.block_tables = torch.zeros(1, 5, dtype=torch.long)
metadata.num_actual_tokens = 10
metadata.slot_mapping = torch.zeros(10, dtype=torch.long)
metadata.num_decodes = 10
metadata.num_prefills = 0
mock_fused_infer_attention_score.return_value = (torch.ones(10, 8,
64), 1)
@@ -579,11 +604,13 @@ class TestAscendAttentionBackendImpl(TestBase):
assert output.shape == (10, 8 * 64)
@patch('torch.version')
@patch('vllm_ascend.attention.attention_v1.is_310p', return_value=False)
@patch('torch_npu._npu_reshape_and_cache')
@patch('vllm_ascend.attention.attention_v1.vanilla_chunked_prefill')
def test_forward_head_size_192(self, mock_vanilla_prefill,
mock_npu_reshape_and_cache, mock_is_310p):
mock_npu_reshape_and_cache, mock_is_310p,
mock_version):
"""Test forward pass when head_size is 192"""
self.impl.head_size = 192
@@ -598,7 +625,10 @@ class TestAscendAttentionBackendImpl(TestBase):
metadata.block_tables = torch.zeros(1, 5, dtype=torch.long)
metadata.num_actual_tokens = 10
metadata.slot_mapping = torch.zeros(10, dtype=torch.long)
metadata.num_decodes = 10
metadata.num_prefills = 0
layer = self.layer_no_quant
mock_version.cann = "8.4.RC1"
mock_vanilla_prefill.return_value = MagicMock()
output = self.impl_192.forward(layer,
@@ -612,10 +642,12 @@ class TestAscendAttentionBackendImpl(TestBase):
mock_vanilla_prefill.assert_called_once()
assert output.shape == (10, 8 * 192)
@patch('torch.version')
@patch('torch_npu._npu_reshape_and_cache')
@patch('torch_npu._npu_paged_attention_splitfuse')
def test_forward_normal_v1_situation(self, mock_paged_attention,
mock_npu_reshape_and_cache):
mock_npu_reshape_and_cache,
mock_version):
"""Test forward pass in normal V1 situation"""
query = torch.randn(10, 8 * 64)
key = torch.randn(10, 8 * 64)
@@ -628,8 +660,12 @@ class TestAscendAttentionBackendImpl(TestBase):
metadata.block_tables = torch.zeros(1, 5, dtype=torch.long)
metadata.num_actual_tokens = 10
metadata.slot_mapping = torch.zeros(10, dtype=torch.long)
metadata.num_decodes = 0
metadata.num_prefills = 10
layer = self.layer_no_quant
mock_version.cann = "8.4.RC1"
output = self.impl.forward(layer,
query,
key,
@@ -641,13 +677,14 @@ class TestAscendAttentionBackendImpl(TestBase):
mock_paged_attention.assert_called_once()
assert output.shape == (10, 8 * 64)
@patch('torch.version')
@patch('torch_npu.npu_format_cast')
@patch('torch_npu._npu_reshape_and_cache')
@patch('torch_npu._npu_paged_attention_splitfuse')
@patch('vllm_ascend.attention.attention_v1.is_310p', return_value=True)
def test_forward_310p_device(self, mock_is_310p, mock_paged_attention,
mock_npu_reshape_and_cache,
mock_npu_format_cast):
mock_npu_format_cast, mock_version):
"""Test forward pass on 310P device"""
query = torch.randn(10, 8 * 64)
key = torch.randn(10, 8 * 64)
@@ -660,9 +697,12 @@ class TestAscendAttentionBackendImpl(TestBase):
metadata.block_tables = torch.zeros(1, 5, dtype=torch.long)
metadata.num_actual_tokens = 10
metadata.slot_mapping = torch.zeros(10, dtype=torch.long)
metadata.num_decodes = 0
metadata.num_prefills = 10
layer = self.layer_no_quant
mock_npu_format_cast.return_value = metadata.attn_mask
mock_version.cann = "8.4.RC1"
output = self.impl.forward(layer,
query,
key,
@@ -687,6 +727,8 @@ class TestAscendAttentionBackendImpl(TestBase):
metadata.block_tables = torch.zeros(1, 5, dtype=torch.long)
metadata.num_actual_tokens = 10
metadata.slot_mapping = torch.zeros(10, dtype=torch.long)
metadata.num_decodes = 0
metadata.num_prefills = 10
layer = self.layer_no_quant
with self.assertRaises(NotImplementedError):


@@ -130,6 +130,7 @@ class TestAscendMLADecodeMetadata(TestBase):
class TestAscendMLAMetadata(TestBase):
def test_ascend_mla_metadata_default(self):
num_actual_tokens_pcp_padded = 100
num_actual_tokens = 100
slot_mapping = torch.randn(100, 4, 1024)
query_start_loc = torch.tensor([1, 2, 3, 4])
@@ -150,12 +151,11 @@ class TestAscendMLAMetadata(TestBase):
decode = None
prefill = None
metadata = AscendMLAMetadata(num_actual_tokens, slot_mapping,
query_start_loc, seq_lens, block_tables,
num_decodes, num_decode_tokens,
num_prefills, num_input_tokens,
query_lens, head_dim, attn_mask,
attn_state, decode, prefill)
metadata = AscendMLAMetadata(
num_actual_tokens_pcp_padded, num_actual_tokens, slot_mapping,
query_start_loc, seq_lens, block_tables, num_decodes,
num_decode_tokens, num_prefills, num_input_tokens, query_lens,
head_dim, attn_mask, attn_state, decode, prefill)
self.assertEqual(metadata.num_actual_tokens, num_actual_tokens)
self.assertIs(metadata.slot_mapping, slot_mapping)
@@ -266,6 +266,10 @@ class TestAscendMLAMetadataBuilder(TestBase):
class TestAscendMLAImpl(TestBase):
@patch('vllm.distributed.parallel_state._DCP',
new_callable=lambda: MagicMock(spec=GroupCoordinator))
@patch("vllm.distributed.get_decode_context_model_parallel_world_size",
return_value=1)
@patch('vllm.distributed.parallel_state._TP',
new_callable=lambda: MagicMock(spec=GroupCoordinator))
@patch("vllm.distributed.get_tensor_model_parallel_world_size",
@@ -273,8 +277,13 @@ class TestAscendMLAImpl(TestBase):
@patch("vllm_ascend.attention.mla_v1.get_current_vllm_config")
@patch("vllm_ascend.attention.mla_v1.get_ascend_config")
def setUp(self, ascend_config, get_current_vllm_config, mock_get_tp_size,
mock_tp):
mock_tp, mock_get_dcp_size, mock_dcp):
mock_tp.world_size = 2
mock_tp.rank_in_group = MagicMock()
mock_tp.device_group = MagicMock()
mock_dcp.world_size = 1
mock_dcp.rank_in_group = MagicMock()
mock_dcp.device_group = MagicMock()
vllm_config = MagicMock()
speculative_config = MagicMock()
model_config = MagicMock()


@@ -80,6 +80,8 @@ def test_read_agent_metadata():
worker.local_ip = worker_local_ip
worker.tp_rank = worker_tp_rank
worker.llm_datadist_role = LLMRole.PROMPT
worker.pcp_rank = 0
worker.tp_size = worker_tp_rank + 1
os.environ["ASCEND_RT_VISIBLE_DEVICES"] = worker_visible_devices
agent_metadata = LLMDataDistCMgrConnectorWorker.read_agent_metadata(
worker, rank_table)


@@ -149,7 +149,9 @@ def create_request(
range(num_remote_blocks)),
remote_host="my-host",
remote_port=1234,
remote_tp_size=1)
remote_tp_size=1,
remote_cp_size=1,
remote_dcp_size=1)
max_tokens = 1 if do_remote_decode else max_tokens
sampling_params = SamplingParams(max_tokens=max_tokens)


@@ -13,7 +13,7 @@
# This file is a part of the vllm-ascend project.
#
from types import SimpleNamespace
from unittest.mock import Mock, patch
from unittest.mock import MagicMock, Mock, patch
import pytest
import torch
@@ -100,6 +100,11 @@ def mock_distributed():
pp_group.rank_in_group = 0
pp_group.world_size = 1
dcp_group = MagicMock(spec=GroupCoordinator)
dcp_group.rank_in_group = 0
dcp_group.world_size = 1
dcp_group.device_group = MagicMock()
mlp_tp_group = Mock(spec=GroupCoordinator)
mlp_tp_group.rank_in_group = 0
mlp_tp_group.world_size = 1
@@ -117,6 +122,9 @@ def mock_distributed():
patch("vllm_ascend.torchair.models.torchair_deepseek_v2.get_pp_group", return_value=pp_group), \
patch("vllm_ascend.torchair.models.torchair_deepseek_v2.get_pp_group",
return_value=Mock(is_first_rank=False, is_last_rank=False)), \
patch('vllm.distributed.parallel_state.get_dcp_group', return_value=dcp_group), \
patch('vllm.distributed.parallel_state._DCP', new_callable=lambda: MagicMock(spec=GroupCoordinator)), \
patch("vllm.distributed.get_decode_context_model_parallel_world_size", return_value=1),\
patch("vllm_ascend.torchair.ops.torchair_fused_moe.get_current_vllm_config", return_value=mock_vllm_config), \
patch.dict("vllm.distributed.parallel_state.__dict__", _TP=tp_group, _EP=ep_group, _DP=dp_group,
_PP=pp_group), \


@@ -19,29 +19,45 @@ from dataclasses import dataclass
from enum import Enum
from typing import ClassVar, List, Optional, Tuple, Type
import numpy as np
import torch
import torch.distributed as dist
import torch.nn as nn
import torch.nn.functional as F
import torch_npu
from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
AttentionLayer, AttentionType)
from vllm.config import VllmConfig
from vllm.distributed import (get_dcp_group,
get_decode_context_model_parallel_rank,
get_decode_context_model_parallel_world_size)
from vllm.forward_context import ForwardContext, get_forward_context
from vllm.utils import cdiv, direct_register_custom_op
from vllm.v1.attention.backends.utils import AttentionCGSupport
from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.kv_cache_interface import AttentionSpec
# isort: off
from vllm_ascend.attention.utils import (AscendCommonAttentionMetadata,
maybe_save_kv_layer_to_connector,
split_decodes_and_prefills,
wait_for_kv_layer_from_connector)
from vllm_ascend.compilation.acl_graph import (get_graph_params,
update_graph_params_workspaces)
from vllm_ascend.ops.attention import vanilla_chunked_prefill
from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_NZ, aligned_16, is_310p,
nd_to_nz_2d, nd_to_nz_spec, version_check)
nd_to_nz_2d, nd_to_nz_spec,
prefill_context_parallel_enable, version_check)
from ..utils import weak_ref_tensors
if prefill_context_parallel_enable():
from vllm.distributed import (get_pcp_group,
get_prefill_context_model_parallel_rank,
get_prefill_context_model_parallel_world_size
)
# isort:on
class AscendAttentionBackend(AttentionBackend):
accept_output_buffer: bool = True
@@ -127,15 +143,47 @@ class AscendAttentionState(Enum):
@dataclass
class AscendMetadata:
class AscendPCPMetadata:
q_head_idx: torch.Tensor = None
q_tail_idx: torch.Tensor = None
kv_with_q_head_nomask_idx: torch.Tensor = None
kv_with_q_head_mask_idx: torch.Tensor = None
kv_with_q_tail_nomask_idx: torch.Tensor = None
kv_with_q_tail_mask_idx: torch.Tensor = None
attn_mask_seqlens: torch.Tensor = None
head_attn_nomask_seqlens: torch.Tensor = None
tail_attn_nomask_seqlens: torch.Tensor = None
q_full_idx: torch.Tensor = None
pcp_prefill_mask: torch.Tensor = None
@dataclass
class AscendMetadataForPrefill:
""" Prefill Specific Metadata for Ascend"""
pcp_metadata: Optional[AscendPCPMetadata] = None
pcp_allgather_restore_idx: Optional[List[int]] = None
@dataclass
class AscendMetadataForDecode:
""" Decode Specific Metadata for Ascend"""
num_computed_tokens_of_pcp_dcp: Optional[list[Optional[list[Optional[
list[int]]]]]] = None
@dataclass
class AscendMetadata:
# **************************** Basic Properties ************************** #
attn_mask: Optional[torch.Tensor] = None
# Current state of this attention run.
attn_state: AscendAttentionState = AscendAttentionState.ChunkedPrefill
# Number of tokens excluding padding.
num_actual_tokens_pcp_padded: int = 0
num_actual_tokens: int = 0
num_decode_tokens: int = 0
num_prefills: int = 0
num_decodes: int = 0
# The sequence length per sequence. Sequence length means the computed
# tokens + new tokens (is None if it is a decoding).
@@ -168,6 +216,10 @@ class AscendMetadata:
# *************************** Other Properties *************************** #
enable_dbo_across_dp: bool = False
prefill: Optional[AscendMetadataForPrefill] = None
decode_meta: Optional[AscendMetadataForDecode] = None
class AscendAttentionMetadataBuilder:
# Does this backend/builder support ACL Graphs for attention (default: no).
@@ -207,10 +259,25 @@ class AscendAttentionMetadataBuilder:
query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu[:
num_reqs
+ 1]
decode_threshold = 1
num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = \
split_decodes_and_prefills(common_attn_metadata, decode_threshold=decode_threshold)
assert num_decodes + num_prefills == num_reqs
assert num_decode_tokens + num_prefill_tokens == num_actual_tokens
block_table = common_attn_metadata.block_table_tensor
query_lens = query_start_loc_cpu[1:] - query_start_loc_cpu[:-1]
seq_lens = common_attn_metadata.seq_lens_cpu[:num_reqs]
slot_mapping = common_attn_metadata.slot_mapping[:num_actual_tokens]
long_seq_metadata = common_attn_metadata.prefill_context_parallel_metadata
num_actual_tokens_pcp_padded = long_seq_metadata.num_actual_tokens_pcp_padded if long_seq_metadata else None
if num_actual_tokens_pcp_padded is None:
num_actual_tokens_pcp_padded = num_actual_tokens
slot_mapping = common_attn_metadata.slot_mapping[:
num_actual_tokens_pcp_padded]
# slot_mapping = common_attn_metadata.slot_mapping[:num_actual_tokens]
attn_mask = common_attn_metadata.attn_mask
attn_state = common_attn_metadata.attn_state
query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu[:
@@ -218,7 +285,7 @@ class AscendAttentionMetadataBuilder:
+ 1]
if attn_state == AscendAttentionState.DecodeOnly and \
common_attn_metadata.num_input_tokens > num_actual_tokens:
common_attn_metadata.num_input_tokens > num_actual_tokens:
padded_num_tokens = common_attn_metadata.num_input_tokens - num_actual_tokens
seq_lens = torch.cat([
seq_lens,
@@ -252,8 +319,51 @@ class AscendAttentionMetadataBuilder:
attn_mask = torch_npu.npu_format_cast(mask_nz.contiguous(),
ACL_FORMAT_FRACTAL_NZ)
prefill_metadata = None
if num_prefills > 0:
pcp_metadata = None
common_long_seq_metadata = common_attn_metadata.prefill_context_parallel_metadata
if common_long_seq_metadata is not None:
pcp_metadata = AscendPCPMetadata(
q_head_idx=common_long_seq_metadata.q_head_idx_tensor,
q_tail_idx=common_long_seq_metadata.q_tail_idx_tensor,
kv_with_q_head_nomask_idx=common_long_seq_metadata.
kv_with_q_head_nomask_idx_tensor,
kv_with_q_head_mask_idx=common_long_seq_metadata.
kv_with_q_head_mask_idx_tensor,
kv_with_q_tail_nomask_idx=common_long_seq_metadata.
kv_with_q_tail_nomask_idx_tensor,
kv_with_q_tail_mask_idx=common_long_seq_metadata.
kv_with_q_tail_mask_idx_tensor,
attn_mask_seqlens=common_long_seq_metadata.
attn_mask_seqlens,
head_attn_nomask_seqlens=common_long_seq_metadata.
head_attn_nomask_seqlens,
tail_attn_nomask_seqlens=common_long_seq_metadata.
tail_attn_nomask_seqlens,
q_full_idx=common_long_seq_metadata.q_full_idx,
pcp_prefill_mask=common_long_seq_metadata.pcp_prefill_mask)
prefill_metadata = AscendMetadataForPrefill(
pcp_metadata=pcp_metadata,
pcp_allgather_restore_idx=common_long_seq_metadata.
pcp_allgather_restore_idx
if common_long_seq_metadata is not None else None)
decode_metadata = None
if num_decodes > 0:
common_long_seq_metadata = common_attn_metadata.prefill_context_parallel_metadata
if common_long_seq_metadata is not None:
num_computed_tokens_of_pcp_dcp = common_long_seq_metadata.num_computed_tokens_of_pcp_dcp
num_computed_tokens_of_pcp_dcp = np.array(
num_computed_tokens_of_pcp_dcp)
decode_metadata = AscendMetadataForDecode(
num_computed_tokens_of_pcp_dcp=
num_computed_tokens_of_pcp_dcp)
attn_metadata = AscendMetadata(
num_actual_tokens=num_actual_tokens,
num_decode_tokens=num_decode_tokens,
num_actual_tokens_pcp_padded=num_actual_tokens_pcp_padded,
block_tables=block_table,
query_start_loc=query_start_loc,
query_lens=query_lens,
@@ -264,7 +374,11 @@ class AscendAttentionMetadataBuilder:
slot_mapping=slot_mapping,
attn_mask=attn_mask,
attn_state=attn_state,
enable_dbo_across_dp=common_attn_metadata.enable_dbo_across_dp)
enable_dbo_across_dp=common_attn_metadata.enable_dbo_across_dp,
num_prefills=num_prefills,
num_decodes=num_decodes,
prefill=prefill_metadata,
decode_meta=decode_metadata)
return attn_metadata
def build_for_graph_capture(
@@ -322,6 +436,18 @@ class AscendAttentionBackendImpl(AttentionImpl):
self.key_cache = None
self.value_cache = None
self.torch_npu_check = version_check()
self.pcp_size = get_prefill_context_model_parallel_world_size(
) if prefill_context_parallel_enable() else 1
self.pcp_rank = get_prefill_context_model_parallel_rank(
) if self.pcp_size > 1 else 0
self.pcp_group = get_pcp_group(
).device_group if self.pcp_size > 1 else None
self.dcp_size = get_decode_context_model_parallel_world_size()
self.dcp_rank = get_decode_context_model_parallel_rank(
) if self.dcp_size > 1 else 0
self.dcp_group = get_dcp_group(
).device_group if self.dcp_size > 1 else None
def _forward_prefill_no_cache(
self,
@@ -581,6 +707,236 @@ class AscendAttentionBackendImpl(AttentionImpl):
out=output)
return output
def _pack_tnd_2_bsnd(self, tensor_tnd: torch.Tensor,
lengths: List[int]) -> torch.Tensor:
max_len = max(lengths)
splits = torch.split(tensor_tnd, lengths, dim=0)
padded = []
for s in splits:
pad_len = max_len - s.shape[0]
s_pad = F.pad(s, (0, 0, 0, 0, 0, pad_len))
padded.append(s_pad)
tensor_bsnd = torch.stack(padded, dim=0)
return tensor_bsnd
def _unpack_bsnd_2_tnd(self, tensor_bsnd: torch.Tensor,
lengths: List[int]) -> torch.Tensor:
slices = []
for i, length in enumerate(lengths):
slices.append(tensor_bsnd[i, :length])
tensor_tnd = torch.cat(slices, dim=0)
return tensor_tnd
def _attention_with_nomask_and_mask(self, q: torch.Tensor,
q_seqlens: List[int],
k_nomask: torch.Tensor,
v_nomask: torch.Tensor,
kv_seqlens_nomask: List[int],
k_mask: torch.Tensor,
v_mask: torch.Tensor,
kv_seqlens_mask: List[int],
mask: torch.Tensor) -> torch.Tensor:
q = self._pack_tnd_2_bsnd(q, q_seqlens)
# nomask Attention
if k_nomask is not None:
attn_out_nomask, attn_lse_nomask = torch.ops.npu.npu_fused_infer_attention_score(
q,
self._pack_tnd_2_bsnd(k_nomask, kv_seqlens_nomask),
self._pack_tnd_2_bsnd(v_nomask, kv_seqlens_nomask),
num_heads=self.num_heads,
num_key_value_heads=self.num_kv_heads,
input_layout="BSND",
atten_mask=None,
scale=self.scale,
sparse_mode=0,
antiquant_mode=0,
antiquant_scale=None,
softmax_lse_flag=True,
actual_seq_lengths_kv=kv_seqlens_nomask,
actual_seq_lengths=q_seqlens)
attn_out_nomask = self._unpack_bsnd_2_tnd(attn_out_nomask,
q_seqlens)
# (B, N, Q_S, 1) -> (B, Q_S, N, 1) -> (T, N, 1)
attn_lse_nomask = self._unpack_bsnd_2_tnd(
attn_lse_nomask.permute([0, 2, 1, 3]), q_seqlens)
# mask Attention
attn_out_mask, attn_lse_mask = torch.ops.npu.npu_fused_infer_attention_score(
q,
self._pack_tnd_2_bsnd(k_mask, kv_seqlens_mask),
self._pack_tnd_2_bsnd(v_mask, kv_seqlens_mask),
num_heads=self.num_heads,
num_key_value_heads=self.num_kv_heads,
input_layout="BSND",
atten_mask=mask,
scale=self.scale,
sparse_mode=0,
antiquant_mode=0,
antiquant_scale=None,
softmax_lse_flag=True,
actual_seq_lengths_kv=kv_seqlens_mask,
actual_seq_lengths=q_seqlens)
attn_out_mask = self._unpack_bsnd_2_tnd(attn_out_mask, q_seqlens)
attn_lse_mask = self._unpack_bsnd_2_tnd(
attn_lse_mask.permute([0, 2, 1, 3]), q_seqlens)
# update
output = attn_out_mask
if k_nomask is not None:
output, _ = self._update_out_and_lse(
torch.stack([attn_out_nomask, attn_out_mask], dim=0),
torch.stack([attn_lse_nomask, attn_lse_mask], dim=0))
return output
def _forward_prefill_cp(self, query: torch.Tensor, key: torch.Tensor,
value: torch.Tensor,
attn_metadata: AscendMetadata) -> torch.Tensor:
assert attn_metadata is not None
assert attn_metadata.prefill is not None
assert attn_metadata.prefill.pcp_metadata is not None
# Use precomputed indices from the metadata (already converted to tensors and on device)
q_head_idx = attn_metadata.prefill.pcp_metadata.q_head_idx
q_tail_idx = attn_metadata.prefill.pcp_metadata.q_tail_idx
kv_with_q_head_nomask_idx = attn_metadata.prefill.pcp_metadata.kv_with_q_head_nomask_idx
kv_with_q_head_mask_idx = attn_metadata.prefill.pcp_metadata.kv_with_q_head_mask_idx
kv_with_q_tail_nomask_idx = attn_metadata.prefill.pcp_metadata.kv_with_q_tail_nomask_idx
kv_with_q_tail_mask_idx = attn_metadata.prefill.pcp_metadata.kv_with_q_tail_mask_idx
attn_mask_seqlens = attn_metadata.prefill.pcp_metadata.attn_mask_seqlens
head_attn_nomask_seqlens = attn_metadata.prefill.pcp_metadata.head_attn_nomask_seqlens
tail_attn_nomask_seqlens = attn_metadata.prefill.pcp_metadata.tail_attn_nomask_seqlens
mask = attn_metadata.prefill.pcp_metadata.pcp_prefill_mask
# 1. Attention calculation in the first half of Q in load balancing
output_head = self._attention_with_nomask_and_mask(
q=torch.index_select(query, 0, q_head_idx),
q_seqlens=attn_mask_seqlens[0].tolist(),
k_nomask=torch.index_select(key, 0, kv_with_q_head_nomask_idx)
if self.pcp_rank > 0 else None,
v_nomask=torch.index_select(value, 0, kv_with_q_head_nomask_idx)
if self.pcp_rank > 0 else None,
kv_seqlens_nomask=head_attn_nomask_seqlens[1].tolist(),
k_mask=torch.index_select(key, 0, kv_with_q_head_mask_idx),
v_mask=torch.index_select(value, 0, kv_with_q_head_mask_idx),
kv_seqlens_mask=attn_mask_seqlens[0].tolist(),
mask=mask)
# 2. the Attention calculation in the latter half of Q in load balancing
# pcp_rank0: Q3*KV0~KV2 + Q3*KV3
# pcp_rank1: Q2*KV0~KV1 + Q2*KV2
output_tail = self._attention_with_nomask_and_mask(
q=torch.index_select(query, 0, q_tail_idx),
q_seqlens=attn_mask_seqlens[0].tolist(),
k_nomask=torch.index_select(key, 0, kv_with_q_tail_nomask_idx),
v_nomask=torch.index_select(value, 0, kv_with_q_tail_nomask_idx),
kv_seqlens_nomask=tail_attn_nomask_seqlens[1].tolist(),
k_mask=torch.index_select(key, 0, kv_with_q_tail_mask_idx),
v_mask=torch.index_select(value, 0, kv_with_q_tail_mask_idx),
kv_seqlens_mask=attn_mask_seqlens[0].tolist(),
mask=mask)
# 3. Combine the output of the first half and second half.
q_full_idx = attn_metadata.prefill.pcp_metadata.q_full_idx
output = torch.index_select(
torch.cat([output_head, output_tail], dim=0), 0, q_full_idx)
return output
def _update_out_and_lse(self, out_list: torch.Tensor,
lse_list: torch.Tensor) -> torch.Tensor:
"""LSE_final = log(sum(exp(LSE_i))), O_final = sum(exp(LSE_i - LSE_final) * O_i)
Args:
out_list: shape = [N, batch_size, num_heads, head_size]
lse_list: shape = [N, batch_size, num_heads, 1]
Returns:
out_final: shape = [batch_size, num_heads, head_size]
lse_final: shape = [batch_size, num_heads, 1]
"""
lse_final = torch.logsumexp(lse_list, dim=0, keepdim=False)
out_final = torch.sum(torch.exp(lse_list - lse_final) * out_list,
dim=0)
return out_final, lse_final
def _forward_decode_pcp_dcp(self, query: torch.Tensor,
attn_metadata: AscendMetadata) -> torch.Tensor:
assert self.key_cache is not None
assert self.value_cache is not None
if self.dcp_size > 1:
query = get_dcp_group().all_gather(query, 1)
num_heads = self.num_heads * self.dcp_size
else:
num_heads = self.num_heads
# 1. Compute out&lse by "npu_fused_infer_attention_score"
attn_out, attn_lse = torch.ops.npu.npu_fused_infer_attention_score(
query.view(query.shape[0], 1, query.shape[1], query.shape[2]),
# [b,num_heads,head_size] -> [b,1,num_heads,head_size]
self.key_cache.view(self.key_cache.shape[0],
self.key_cache.shape[1], -1),
self.value_cache.view(self.key_cache.shape[0],
self.key_cache.shape[1], -1),
num_heads=num_heads,
num_key_value_heads=self.num_kv_heads,
input_layout="BSND",
atten_mask=None,
scale=self.scale,
antiquant_mode=0,
antiquant_scale=None,
softmax_lse_flag=True,
block_table=attn_metadata.block_tables,
block_size=self.key_cache.shape[1],
actual_seq_lengths_kv=attn_metadata.decode_meta.
num_computed_tokens_of_pcp_dcp[:, self.pcp_rank, self.dcp_rank],
)
attn_out = attn_out.view(attn_out.shape[0], attn_out.shape[2],
attn_out.shape[3])
attn_lse = attn_lse.view(attn_lse.shape[0], attn_lse.shape[1], 1)
if self.dcp_size > 1:
# Concat out&lse: [bs,num_heads,v_head_dim] + [bs,num_heads,1] -> [bs,num_heads,v_head_dim+1]
attn_out_lse = torch.cat([attn_out, attn_lse], dim=-1)
# permute: [bs, num_heads, v_head_dim+1] -> [num_heads, v_head_dim+1, bs]
attn_out_lse = attn_out_lse.permute([1, 2, 0]).contiguous()
attn_out_lse_all2all = torch.empty_like(attn_out_lse)
dist.all_to_all_single(attn_out_lse_all2all,
attn_out_lse,
group=self.dcp_group)
# permute: [num_heads, v_head_dim+1, bs] -> [bs, num_heads, v_head_dim+1]
attn_out_lse_all2all = attn_out_lse_all2all.permute([2, 0, 1])
attn_out_lse_split_on_seq = list(
torch.chunk(attn_out_lse_all2all, self.dcp_size, dim=1))
attn_out_lse_split_dcp = torch.stack(
attn_out_lse_split_on_seq,
dim=0) # [dcp, batch_size, num_heads, head_size+1]
# Update out&lse
attn_out_split_dcp, attn_lse_split_dcp = torch.split(
attn_out_lse_split_dcp, [self.head_size, 1], dim=-1)
attn_out, attn_lse = self._update_out_and_lse(
attn_out_split_dcp, attn_lse_split_dcp)
if self.pcp_size > 1:
# 2. Concat out&lse: [bs,num_heads,head_size] + [bs,num_heads,1] -> [bs,num_heads,head_size+1]
attn_out_lse = torch.cat([attn_out, attn_lse], dim=-1)
# 3. AllGather out&lse within CP group
attn_out_lse_list = [
torch.empty_like(attn_out_lse) for _ in range(self.pcp_size)
]
dist.all_gather(attn_out_lse_list,
attn_out_lse,
group=self.pcp_group)
# 4. Update out&lse
attn_out_lse_allgather = torch.stack(
attn_out_lse_list,
dim=0) # [pcp, batch_size, num_heads, head_size+1]
attn_out_allgather, attn_lse_allgather = torch.split(
attn_out_lse_allgather, [self.head_size, 1], dim=-1)
attn_out, _ = self._update_out_and_lse(attn_out_allgather,
attn_lse_allgather)
return attn_out
def forward(
self,
layer: AttentionLayer,
@@ -633,7 +989,10 @@ class AscendAttentionBackendImpl(AttentionImpl):
else:
if attn_metadata is None:
return output.view(num_tokens, self.hidden_size)
num_actual_tokens = attn_metadata.num_actual_tokens
num_decode_tokens = attn_metadata.num_decode_tokens
has_decode = attn_metadata.num_decodes > 0
has_prefill = attn_metadata.num_prefills > 0
assert layer._k_scale_float == 1.0 and layer._v_scale_float == 1.0
attn_type = self.attn_type
if attn_type != AttentionType.DECODER and attn_type != AttentionType.ENCODER_ONLY:
@@ -650,14 +1009,46 @@ class AscendAttentionBackendImpl(AttentionImpl):
if len(kv_cache) > 1:
if self.key_cache is None:
self.key_cache, self.value_cache = kv_cache[0], kv_cache[1]
slots = attn_metadata.slot_mapping
torch_npu._npu_reshape_and_cache(
key=key[:num_actual_tokens],
value=value[:num_actual_tokens],
key_cache=self.key_cache,
value_cache=self.value_cache,
slot_indices=slots)
if attn_type == AttentionType.ENCODER_ONLY:
if has_decode:
slot_mapping = attn_metadata.slot_mapping[:num_decode_tokens * self.pcp_size: self.pcp_size] \
if self.pcp_size * self.dcp_size > 1 else attn_metadata.slot_mapping[:num_decode_tokens]
torch_npu._npu_reshape_and_cache(
key=key[:num_decode_tokens],
value=value[:num_decode_tokens],
key_cache=self.key_cache,
value_cache=self.value_cache,
slot_indices=slot_mapping)
if has_prefill:
if self.pcp_size > 1:
kv = torch.cat([key, value], dim=-1)
all_kv = get_pcp_group().all_gather(kv, dim=0)
pcp_allgather_restore_idx = attn_metadata.prefill.pcp_allgather_restore_idx if attn_metadata.prefill else None
all_kv = torch.index_select(all_kv, 0,
pcp_allgather_restore_idx)
key, value = all_kv.split(
[self.head_size, self.head_size], dim=-1)
torch_npu._npu_reshape_and_cache(
key=key[self.pcp_size *
num_decode_tokens:attn_metadata.
num_actual_tokens_pcp_padded],
value=value[self.pcp_size *
num_decode_tokens:attn_metadata.
num_actual_tokens_pcp_padded],
key_cache=self.key_cache,
value_cache=self.value_cache,
slot_indices=attn_metadata.
slot_mapping[self.pcp_size *
num_decode_tokens:attn_metadata.
num_actual_tokens_pcp_padded])
if self.pcp_size * self.dcp_size > 1:
output = self._forward_pcp_dcp(query, key, value,
attn_metadata, output)
elif attn_type == AttentionType.ENCODER_ONLY:
cum_seq_len = attn_metadata.query_start_loc[1:].tolist()
attn_out = torch_npu.npu_fusion_attention(
query,
@@ -668,7 +1059,6 @@ class AscendAttentionBackendImpl(AttentionImpl):
scale=self.scale,
sparse_mode=4,
atten_mask=attn_metadata.attn_mask,
pre_tockens=attn_metadata.max_query_len,
next_tockens=attn_metadata.max_query_len,
actual_seq_qlen=cum_seq_len,
actual_seq_kvlen=cum_seq_len,
@@ -679,7 +1069,7 @@ class AscendAttentionBackendImpl(AttentionImpl):
output = self._forward_prefill_no_cache(
query, key, value, attn_metadata, output, num_tokens)
elif attn_metadata.attn_state == \
AscendAttentionState.PrefillCacheHit:
AscendAttentionState.PrefillCacheHit:
output = self._forward_prefill_cache_hit(
query, attn_metadata, output)
elif attn_metadata.attn_state == AscendAttentionState.DecodeOnly:
@@ -701,6 +1091,46 @@ class AscendAttentionBackendImpl(AttentionImpl):
ori_output[:num_tokens, :, :] = output[:num_tokens, :, :]
return output.view(num_tokens, self.hidden_size)
def _forward_pcp_dcp(self, query: torch.Tensor, key: torch.Tensor,
value: torch.Tensor, attn_metadata: AscendMetadata,
output: Optional[torch.Tensor]) -> torch.Tensor:
assert attn_metadata is not None
has_decode = attn_metadata.num_decodes > 0
has_prefill = attn_metadata.num_prefills > 0
num_decode_tokens = attn_metadata.num_decode_tokens
if output is None:
raise ValueError("Output buffer is required")
if has_decode:
decode_query = query[:num_decode_tokens]
output_decode = self._forward_decode_pcp_dcp(
decode_query, attn_metadata)
output[:num_decode_tokens] = output_decode
if has_prefill:
prefill_query = query[num_decode_tokens:]
key = key[self.pcp_size * num_decode_tokens:]
value = value[self.pcp_size * num_decode_tokens:]
if self.pcp_size > 1:
output_prefill = self._forward_prefill_cp(
prefill_query, key, value, attn_metadata)
else:
max_prefill_seq_len = attn_metadata.seq_lens[
attn_metadata.num_decode_tokens:].max().item()
if attn_metadata.attn_mask is not None:
attn_metadata.attn_mask = attn_metadata.attn_mask[:
max_prefill_seq_len, :
max_prefill_seq_len]
else:
ValueError("Attn_metadata.attn_mask is required")
seq_lens_back = attn_metadata.seq_lens
attn_metadata.seq_lens = attn_metadata.seq_lens[
attn_metadata.num_decode_tokens:]
output_prefill = self._forward_prefill_no_cache(
prefill_query, key, value, attn_metadata,
output[num_decode_tokens:], prefill_query.shape[0])
attn_metadata.seq_lens = seq_lens_back
output[num_decode_tokens:] = output_prefill
return output
def unified_ascend_attention_with_output(
query: torch.Tensor,

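Editor's note (not part of the diff): the `_update_out_and_lse` helper added to attention_v1.py above merges per-chunk attention outputs through their log-sum-exp, i.e. LSE = log(sum_i exp(LSE_i)) and O = sum_i exp(LSE_i - LSE) * O_i. A tiny single-process sketch, using plain CPU `torch` under illustrative shapes, that checks this merge against attention over the full KV:

```python
# Standalone sanity check (editor-added, not part of the diff): the
# log-sum-exp merge used by _update_out_and_lse in attention_v1.py
# reproduces attention over the full KV when KV is processed in chunks.
import torch

def chunk_attn(q, k, v):
    """q: [n, d]; k, v: [s, d]. Returns per-chunk output and log-sum-exp."""
    scores = q @ k.T                                        # [n, s]
    lse = torch.logsumexp(scores, dim=-1, keepdim=True)     # [n, 1]
    out = torch.softmax(scores, dim=-1) @ v                 # [n, d]
    return out, lse

torch.manual_seed(0)
q, k, v = torch.randn(4, 8), torch.randn(16, 8), torch.randn(16, 8)
out_full, _ = chunk_attn(q, k, v)

out0, lse0 = chunk_attn(q, k[:8], v[:8])
out1, lse1 = chunk_attn(q, k[8:], v[8:])
lse = torch.logsumexp(torch.stack([lse0, lse1]), dim=0)     # LSE_final
merged = torch.exp(lse0 - lse) * out0 + torch.exp(lse1 - lse) * out1
assert torch.allclose(merged, out_full, atol=1e-5)
```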

@@ -2,14 +2,24 @@ from dataclasses import dataclass
from typing import (TYPE_CHECKING, ClassVar, NamedTuple, Optional, Tuple, Type,
TypeVar)
import numpy as np
import torch
import torch.distributed as dist
import torch.nn.functional as F
import torch_npu
from torch import nn
from vllm.attention.backends.abstract import (AttentionBackend,
AttentionMetadata,
MLAAttentionImpl)
from vllm.config import VllmConfig, get_current_vllm_config
from vllm.distributed import get_tensor_model_parallel_world_size
# isort: off
from vllm.distributed import (get_dcp_group,
get_decode_context_model_parallel_rank,
get_decode_context_model_parallel_world_size,
get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size,
get_tp_group)
from vllm.forward_context import ForwardContext, get_forward_context
from vllm.logger import logger
from vllm.model_executor.layers.linear import (LinearBase,
@@ -32,9 +42,15 @@ from vllm_ascend.multistream.ms_split import model_input_split_v1_mla_attn
from vllm_ascend.ops.weight_prefetch import maybe_npu_prefetch
from vllm_ascend.quantization.w8a8 import AscendW8A8LinearMethod
from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_ND, ACL_FORMAT_FRACTAL_NZ,
is_enable_nz)
is_enable_nz, prefill_context_parallel_enable)
from vllm_ascend.worker.npu_input_batch import InputBatch
if prefill_context_parallel_enable():
from vllm.distributed import (get_pcp_group,
get_prefill_context_model_parallel_rank,
get_prefill_context_model_parallel_world_size
)
# isort:on
if TYPE_CHECKING:
from vllm.v1.core.sched.output import SchedulerOutput
@@ -65,6 +81,22 @@ class AscendMLABackend(AttentionBackend):
return AscendMLAImpl
@dataclass
class AscendPCPMetadata:
q_head_idx: torch.Tensor = None
q_tail_idx: torch.Tensor = None
kv_with_q_head_nomask_idx: torch.Tensor = None
kv_with_q_head_mask_idx: torch.Tensor = None
kv_with_q_tail_nomask_idx: torch.Tensor = None
kv_with_q_tail_mask_idx: torch.Tensor = None
attn_mask_seqlens: torch.Tensor = None
head_attn_nomask_seqlens: torch.Tensor = None
tail_attn_nomask_seqlens: torch.Tensor = None
q_full_idx: torch.Tensor = None
pcp_prefill_mask: torch.Tensor = None
pcp_allgather_restore_idx: Optional[list[int]] = None
@dataclass
class AscendMLAPrefillMetadata:
""" Prefill Specific Metadata for Ascend"""
@@ -92,6 +124,7 @@ class AscendMLAPrefillMetadata:
chunked_context: Optional[ChunkedContextMetadata] = None
sin: torch.Tensor = None
cos: torch.Tensor = None
pcp_metadata: Optional[AscendPCPMetadata] = None
@dataclass
@@ -107,6 +140,8 @@ class AscendMLADecodeMetadata:
attn_mask: Optional[torch.Tensor] = None
sin: torch.Tensor = None
cos: torch.Tensor = None
num_computed_tokens_of_pcp_dcp: Optional[list[Optional[list[Optional[
list[int]]]]]] = None
@dataclass
@@ -124,6 +159,7 @@ class AscendMLAMetadata:
# |-------------------- seq_len ---------------------|
# |-- query_len ---|
num_actual_tokens_pcp_padded: int
num_actual_tokens: int # Number of tokens excluding padding.
slot_mapping: torch.Tensor
query_start_loc: torch.Tensor
@@ -297,6 +333,11 @@ class AscendMLAMetadataBuilder:
num_actual_tokens = common_attn_metadata.num_actual_tokens
query_start_loc = common_attn_metadata.query_start_loc
query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu
long_seq_metadata = common_attn_metadata.prefill_context_parallel_metadata
num_actual_tokens_pcp_padded = long_seq_metadata.num_actual_tokens_pcp_padded if long_seq_metadata else None
num_computed_tokens_of_pcp_dcp = long_seq_metadata.num_computed_tokens_of_pcp_dcp if long_seq_metadata else None
num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = \
split_decodes_and_prefills(common_attn_metadata, decode_threshold=self.decode_threshold)
assert num_decodes + num_prefills == num_reqs
@@ -308,10 +349,14 @@ class AscendMLAMetadataBuilder:
device = self.device
block_table = (common_attn_metadata.block_table_tensor[:num_reqs])
slot_mapping = common_attn_metadata.slot_mapping[:num_actual_tokens]
input_positions = common_attn_metadata.positions[:
num_actual_tokens].long(
)
if num_actual_tokens_pcp_padded is None:
num_actual_tokens_pcp_padded = num_actual_tokens
slot_mapping = common_attn_metadata.slot_mapping[:
num_actual_tokens_pcp_padded]
if self.cos_cache is None:
self.cos_cache = model.model.layers[
@@ -332,6 +377,31 @@ class AscendMLAMetadataBuilder:
prefill_metadata = None
chunked_context_metadata = None
if num_prefills > 0:
pcp_metadata = None
common_long_seq_metadata = common_attn_metadata.prefill_context_parallel_metadata
if common_long_seq_metadata is not None:
pcp_metadata = AscendPCPMetadata(
q_head_idx=common_long_seq_metadata.q_head_idx_tensor,
q_tail_idx=common_long_seq_metadata.q_tail_idx_tensor,
kv_with_q_head_nomask_idx=common_long_seq_metadata.
kv_with_q_head_nomask_idx_tensor,
kv_with_q_head_mask_idx=common_long_seq_metadata.
kv_with_q_head_mask_idx_tensor,
kv_with_q_tail_nomask_idx=common_long_seq_metadata.
kv_with_q_tail_nomask_idx_tensor,
kv_with_q_tail_mask_idx=common_long_seq_metadata.
kv_with_q_tail_mask_idx_tensor,
attn_mask_seqlens=common_long_seq_metadata.
attn_mask_seqlens,
head_attn_nomask_seqlens=common_long_seq_metadata.
head_attn_nomask_seqlens,
tail_attn_nomask_seqlens=common_long_seq_metadata.
tail_attn_nomask_seqlens,
q_full_idx=common_long_seq_metadata.q_full_idx,
pcp_prefill_mask=common_long_seq_metadata.pcp_prefill_mask,
pcp_allgather_restore_idx=long_seq_metadata.
pcp_allgather_restore_idx if long_seq_metadata else None)
reqs_start = num_decodes # prefill_start
tokens_start = num_decode_tokens
max_query_len = query_lens[reqs_start:].max().item()
@@ -392,6 +462,7 @@ class AscendMLAMetadataBuilder:
chunked_context=chunked_context_metadata,
sin=sin,
cos=cos,
pcp_metadata=pcp_metadata,
)
decode_metadata = None
@@ -426,7 +497,9 @@ class AscendMLAMetadataBuilder:
attn_mask=common_attn_metadata.spec_attn_mask,
actual_seq_lengths_q=actual_seq_lengths_q,
sin=sin,
cos=cos)
cos=cos,
num_computed_tokens_of_pcp_dcp=
num_computed_tokens_of_pcp_dcp)
else:
cos[:num_decode_tokens,
...] = self.cos_cache[input_positions].unsqueeze(
@@ -444,9 +517,12 @@ class AscendMLAMetadataBuilder:
attn_mask=common_attn_metadata.spec_attn_mask,
actual_seq_lengths_q=actual_seq_lengths_q,
sin=sin[:num_decode_tokens, ...],
cos=cos[:num_decode_tokens, ...])
cos=cos[:num_decode_tokens, ...],
num_computed_tokens_of_pcp_dcp=
num_computed_tokens_of_pcp_dcp)
return self.metadata_cls( # type: ignore
num_actual_tokens_pcp_padded=num_actual_tokens_pcp_padded,
num_input_tokens=common_attn_metadata.num_input_tokens,
num_actual_tokens=num_actual_tokens,
query_lens=query_lens.tolist(),
@@ -494,6 +570,7 @@ class DecodeMLAPreprocessResult(NamedTuple):
q_pe: Optional[torch.Tensor] = None
k_nope: Optional[torch.Tensor] = None
k_pe: Optional[torch.Tensor] = None
decode_q_wo_k_up: Optional[torch.Tensor] = None
class PrefillMLAPreprocessResult(NamedTuple):
@@ -561,8 +638,27 @@ class AscendMLAImpl(MLAAttentionImpl):
self.speculative_config = vllm_config.speculative_config
self.enable_mlapo = envs.VLLM_ASCEND_ENABLE_MLAPO
self.pcp_size = get_prefill_context_model_parallel_world_size(
) if prefill_context_parallel_enable() else 1
self.pcp_rank = get_prefill_context_model_parallel_rank(
) if self.pcp_size > 1 else 0
self.pcp_group = get_pcp_group(
).device_group if self.pcp_size > 1 else None
self.dcp_size = get_decode_context_model_parallel_world_size()
self.dcp_rank = get_decode_context_model_parallel_rank(
) if self.dcp_size > 1 else 0
self.dcp_group = get_dcp_group(
).device_group if self.dcp_size > 1 else None
self.tp_size = get_tensor_model_parallel_world_size()
self.tp_rank = get_tensor_model_parallel_rank()
self.tp_group = get_tp_group(
).device_group if self.tp_size > 1 else None
def _v_up_proj(self, x):
if self.W_UV.shape[0] * self.W_UV.shape[1] < 65536:
if self.W_UV.shape[0] * self.W_UV.shape[
1] < 65536 and not self.dcp_size * self.pcp_size > 1:
x = x.view(-1, self.num_heads, self.kv_lora_rank)
x = torch_npu.npu_transpose_batchmatmul(x,
self.W_UV,
@@ -1062,7 +1158,6 @@ class AscendMLAImpl(MLAAttentionImpl):
else:
attn_output, _ = torch_npu.npu_fused_infer_attention_score(
q_nope, k_nope, k_nope, **common_kwargs)
current_ms_metadata = get_multistream_comm_context()
if current_ms_metadata is None:
return self._v_up_proj(attn_output)
@@ -1162,9 +1257,9 @@ class AscendMLAImpl(MLAAttentionImpl):
# Process for Flash Comm V1
q_c = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(
q_c, need_gather_q_kv)
q_c.contiguous(), need_gather_q_kv)
kv_no_split = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(
kv_no_split, need_gather_q_kv)
kv_no_split.contiguous(), need_gather_q_kv)
decode_preprocess_res = None
prefill_preprocess_res = None
@@ -1177,8 +1272,17 @@ class AscendMLAImpl(MLAAttentionImpl):
sin = attn_metadata.decode.sin
decode_ql_nope, decode_q_pe = \
self._q_proj_and_k_up_proj(decode_q_c)
if self.dcp_size > 1:
decode_q_no_split = torch.cat([decode_ql_nope, decode_q_pe],
dim=-1)
decode_q_no_split = get_dcp_group().all_gather(
decode_q_no_split, 1)
decode_ql_nope, decode_q_pe = decode_q_no_split.split(
[self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
decode_q_pe = self.rope_single(decode_q_pe, cos, sin)
decode_slots = attn_metadata.slot_mapping[:num_decode_tokens]
decode_slots = attn_metadata.slot_mapping[:num_decode_tokens *
self.pcp_size:self.
pcp_size]
decode_kv_no_split = kv_no_split[:num_decode_tokens]
decode_k_pe, decode_k_nope = self.exec_kv_decode(
decode_kv_no_split, cos, sin, kv_cache, decode_slots)
@@ -1186,6 +1290,10 @@ class AscendMLAImpl(MLAAttentionImpl):
decode_ql_nope, decode_q_pe, decode_k_nope, decode_k_pe)
# Preprocess for prefill tokens
if has_prefill:
if self.pcp_size > 1:
num_actual_tokens = (attn_metadata.num_actual_tokens_pcp_padded
- self.pcp_size * num_decode_tokens
) // self.pcp_size + num_decode_tokens
prefill_kv_no_split = kv_no_split[
num_decode_tokens:num_actual_tokens]
prefill_q_c = q_c[num_decode_tokens:num_actual_tokens]
@@ -1193,20 +1301,65 @@ class AscendMLAImpl(MLAAttentionImpl):
.view(-1, self.num_heads, self.qk_head_dim)
prefill_q_pe = prefill_q[..., self.qk_nope_head_dim:]
prefill_q_nope = prefill_q[..., :self.qk_nope_head_dim]
cos = attn_metadata.prefill.cos
sin = attn_metadata.prefill.sin
if self.pcp_size > 1:
cos = attn_metadata.prefill.cos[:num_actual_tokens -
num_decode_tokens]
sin = attn_metadata.prefill.sin[:num_actual_tokens -
num_decode_tokens]
else:
cos = attn_metadata.prefill.cos
sin = attn_metadata.prefill.sin
prefill_slots = attn_metadata.slot_mapping[
num_decode_tokens:num_actual_tokens]
prefill_q_pe = self.rope_single(prefill_q_pe, cos, sin)
prefill_k_pe, prefill_k_c_normed = self.exec_kv_prefill(
prefill_kv_no_split, cos, sin, kv_cache, prefill_slots)
prefill_k_pe = prefill_k_pe.view(prefill_q_c.shape[0],
self.num_kv_heads, -1)
if self.pcp_size > 1:
prefill_kv_no_split = kv_no_split[:num_actual_tokens]
kv_c, k_pe = prefill_kv_no_split.split(
[self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
kv_c_normed = self.kv_a_layernorm(kv_c.contiguous())
assert len(
kv_cache
) > 1, "the number of kv cache should be greater than 1, namely (nope_cache and rope_cache)"
kv_c_normed = kv_c_normed.view(
[num_actual_tokens, self.num_kv_heads, -1])
k_pe = k_pe.unsqueeze(1)
prefill_k_pe = k_pe
prefill_k_pe[
num_decode_tokens:num_actual_tokens] = self.rope_single(
prefill_k_pe[num_decode_tokens:num_actual_tokens], cos,
sin)
prefill_k_c_normed = kv_c_normed[:num_actual_tokens]
prefill_kv_c_k_pe = torch.cat(
[prefill_k_c_normed, prefill_k_pe], dim=-1)
prefill_kv_c_k_pe = get_pcp_group().all_gather(
prefill_kv_c_k_pe, 0)
prefill_kv_c_k_pe = torch.index_select(
prefill_kv_c_k_pe, 0, attn_metadata.prefill.pcp_metadata.
pcp_allgather_restore_idx)
prefill_kv_c_k_pe = prefill_kv_c_k_pe[num_decode_tokens *
self.pcp_size:]
prefill_k_c_normed, prefill_k_pe = prefill_kv_c_k_pe.split(
[self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
kv_c_normed, k_pe = prefill_k_c_normed, prefill_k_pe
prefill_k_c_normed = prefill_k_c_normed.squeeze()
slot_mapping = attn_metadata.slot_mapping[self.pcp_size *
num_decode_tokens:]
torch_npu._npu_reshape_and_cache(key=kv_c_normed,
value=k_pe,
key_cache=kv_cache[0],
value_cache=kv_cache[1],
slot_indices=slot_mapping)
else:
prefill_k_pe, prefill_k_c_normed = self.exec_kv_prefill(
prefill_kv_no_split, cos, sin, kv_cache, prefill_slots)
prefill_k_nope, prefill_value = self.kv_b_proj(
prefill_k_c_normed)[0].view(
-1, self.num_heads,
self.qk_nope_head_dim + self.v_head_dim).split(
[self.qk_nope_head_dim, self.v_head_dim], dim=-1)
if not self.pcp_size > 1:
prefill_k_pe = prefill_k_pe.view(prefill_q_c.shape[0],
self.num_kv_heads, -1)
prefill_k_pe = prefill_k_pe.expand(
(*prefill_k_nope.shape[:-1], -1))
prefill_preprocess_res = PrefillMLAPreprocessResult(
@@ -1227,7 +1380,10 @@ class AscendMLAImpl(MLAAttentionImpl):
if attn_metadata is None:
# Profiling run.
return output
num_actual_tokens = attn_metadata.num_actual_tokens
if self.pcp_size > 1:
num_actual_tokens = attn_metadata.num_actual_tokens_pcp_padded // self.pcp_size
else:
num_actual_tokens = attn_metadata.num_actual_tokens
assert attn_metadata.num_decodes is not None and \
attn_metadata.num_prefills is not None and \
attn_metadata.num_decode_tokens is not None
@@ -1253,12 +1409,20 @@ class AscendMLAImpl(MLAAttentionImpl):
if decode_preprocess_res is not None:
# MLA Preprocess for decoding
output_decode = self._forward_decode(decode_preprocess_res.ql_nope,
decode_preprocess_res.q_pe,
decode_preprocess_res.k_nope,
decode_preprocess_res.k_pe,
kv_cache[0].shape[1],
attn_metadata)
if self.pcp_size * self.dcp_size > 1:
output_decode = self._forward_decode_pcp_dcp(
decode_preprocess_res.ql_nope,
decode_preprocess_res.q_pe,
decode_preprocess_res.k_nope,
decode_preprocess_res.k_pe,
kv_cache[0].shape[1],
attn_metadata,
)
else:
output_decode = self._forward_decode(
decode_preprocess_res.ql_nope, decode_preprocess_res.q_pe,
decode_preprocess_res.k_nope, decode_preprocess_res.k_pe,
kv_cache[0].shape[1], attn_metadata)
current_ms_metadata = get_multistream_comm_context()
if current_ms_metadata is not None:
with torch.npu.stream(current_ms_metadata.comm_stream):
@@ -1271,10 +1435,16 @@ class AscendMLAImpl(MLAAttentionImpl):
# FIX: aicore move should be also placed on the comm stream in dbo,
# otherwise it may affect the accuracy
# TODO: use an elegant way to overlap
output_prefill = self._forward_prefill(
prefill_preprocess_res.q_nope, prefill_preprocess_res.q_pe,
prefill_preprocess_res.k_nope, prefill_preprocess_res.k_pe,
prefill_preprocess_res.value, kv_cache, attn_metadata)
if self.pcp_size > 1:
output_prefill = self._forward_prefill_cp(
prefill_preprocess_res.q_nope, prefill_preprocess_res.q_pe,
prefill_preprocess_res.k_nope, prefill_preprocess_res.k_pe,
prefill_preprocess_res.value, kv_cache, attn_metadata)
else:
output_prefill = self._forward_prefill(
prefill_preprocess_res.q_nope, prefill_preprocess_res.q_pe,
prefill_preprocess_res.k_nope, prefill_preprocess_res.k_pe,
prefill_preprocess_res.value, kv_cache, attn_metadata)
current_ms_metadata = get_multistream_comm_context()
if current_ms_metadata is not None:
with torch.npu.stream(current_ms_metadata.comm_stream):
@@ -1311,3 +1481,281 @@ class AscendMLAImpl(MLAAttentionImpl):
if has_prefill:
maybe_save_kv_layer_to_connector(layer_name, list(kv_cache))
return output_padded
def _forward_prefill_cp(
self,
q_nope: torch.Tensor,
q_pe: torch.Tensor,
k_nope: torch.Tensor,
k_pe: torch.Tensor,
value: torch.Tensor,
kv_c_and_k_pe_cache: Tuple[torch.Tensor],
attn_metadata: AscendMLAMetadata,
) -> torch.Tensor:
assert attn_metadata.prefill is not None
assert attn_metadata.prefill.pcp_metadata is not None
num_tokens = q_nope.size(0)
# Use precomputed indices from the metadata (already converted to tensors and on device)
q_head_idx = attn_metadata.prefill.pcp_metadata.q_head_idx
q_tail_idx = attn_metadata.prefill.pcp_metadata.q_tail_idx
kv_with_q_head_nomask_idx = attn_metadata.prefill.pcp_metadata.kv_with_q_head_nomask_idx
kv_with_q_head_mask_idx = attn_metadata.prefill.pcp_metadata.kv_with_q_head_mask_idx
kv_with_q_tail_nomask_idx = attn_metadata.prefill.pcp_metadata.kv_with_q_tail_nomask_idx
kv_with_q_tail_mask_idx = attn_metadata.prefill.pcp_metadata.kv_with_q_tail_mask_idx
attn_mask_seqlens = attn_metadata.prefill.pcp_metadata.attn_mask_seqlens
head_attn_nomask_seqlens = attn_metadata.prefill.pcp_metadata.head_attn_nomask_seqlens
tail_attn_nomask_seqlens = attn_metadata.prefill.pcp_metadata.tail_attn_nomask_seqlens
mask = attn_metadata.prefill.pcp_metadata.pcp_prefill_mask
output_head = self._attention_with_mask_and_nomask(
q_nope=torch.index_select(q_nope, 0, q_head_idx),
q_pe=torch.index_select(q_pe, 0, q_head_idx),
k_nope=k_nope,
k_pe=k_pe,
value=value,
kv_mask_idx=kv_with_q_head_mask_idx,
kv_nomask_idx=kv_with_q_head_nomask_idx,
attn_mask_seqlens=attn_mask_seqlens,
attn_nomask_seqlens=head_attn_nomask_seqlens,
mask=mask)
output_tail = self._attention_with_mask_and_nomask(
q_nope=torch.index_select(q_nope, 0, q_tail_idx),
q_pe=torch.index_select(q_pe, 0, q_tail_idx),
k_nope=k_nope,
k_pe=k_pe,
value=value,
kv_mask_idx=kv_with_q_tail_mask_idx,
kv_nomask_idx=kv_with_q_tail_nomask_idx,
attn_mask_seqlens=attn_mask_seqlens,
attn_nomask_seqlens=tail_attn_nomask_seqlens,
mask=mask)
q_full_idx = attn_metadata.prefill.pcp_metadata.q_full_idx
output = torch.index_select(
torch.cat([output_head, output_tail], dim=0), 0, q_full_idx)
output = output.reshape([num_tokens, self.num_heads * self.v_head_dim])
return output
def _attention_with_mask_and_nomask(
self, q_nope: torch.Tensor, q_pe: torch.Tensor,
k_nope: torch.Tensor, k_pe: torch.Tensor, value: torch.Tensor,
kv_mask_idx: torch.Tensor, kv_nomask_idx: torch.Tensor,
attn_mask_seqlens: torch.Tensor, attn_nomask_seqlens: torch.Tensor,
mask: torch.Tensor):
attn_output = torch.empty(q_nope.shape[0],
self.num_heads,
self.v_head_dim,
dtype=k_pe.dtype,
device=k_pe.device)
attn_lse = torch.empty(self.num_heads,
q_pe.shape[0],
dtype=torch.float32,
device=k_pe.device)
# mask
k_nope_mask = torch.index_select(k_nope, 0, kv_mask_idx)
value_mask = torch.index_select(value, 0, kv_mask_idx)
k_pe_mask = torch.index_select(k_pe, 0, kv_mask_idx)
torch_npu.atb.npu_ring_mla(q_nope=q_nope,
q_rope=q_pe,
k_nope=k_nope_mask,
k_rope=k_pe_mask,
value=value_mask,
mask=mask,
seqlen=attn_mask_seqlens,
head_num=self.num_heads,
kv_head_num=self.num_heads,
pre_out=None,
prev_lse=None,
qk_scale=self.scale,
kernel_type="kernel_type_high_precision",
mask_type="mask_type_triu",
input_layout="type_bsnd",
calc_type="calc_type_first_ring",
output=attn_output,
softmax_lse=attn_lse)
# nomask
if kv_nomask_idx.shape[0] == 0:
return attn_output
k_nope_nomask = torch.index_select(k_nope, 0, kv_nomask_idx)
value_nomask = torch.index_select(value, 0, kv_nomask_idx)
k_pe_nomask = torch.index_select(k_pe, 0, kv_nomask_idx)
torch_npu.atb.npu_ring_mla(q_nope=q_nope,
q_rope=q_pe,
k_nope=k_nope_nomask,
k_rope=k_pe_nomask,
value=value_nomask,
mask=mask,
seqlen=attn_nomask_seqlens,
head_num=self.num_heads,
kv_head_num=self.num_heads,
pre_out=attn_output,
prev_lse=attn_lse,
qk_scale=self.scale,
kernel_type="kernel_type_high_precision",
mask_type="no_mask",
input_layout="type_bsnd",
calc_type="calc_type_default",
output=attn_output,
softmax_lse=attn_lse)
return attn_output
def _forward_decode_pcp_dcp(
self,
q_nope: torch.Tensor,
q_pe: torch.Tensor,
k_nope: torch.Tensor,
k_pe: torch.Tensor,
block_size: int,
attn_metadata: AscendMLAMetadata,
) -> torch.Tensor:
decode_meta = attn_metadata.decode
assert decode_meta is not None
num_tokens = q_nope.size(0)
# shape of knope/k_pe for npu graph mode should be:
# [num_blocks, num_kv_heads, block_size, self.kv_lora_rank/self.qk_rope_head_dim]
if self.dcp_size > 1:
num_heads = self.num_heads * self.dcp_size
else:
num_heads = self.num_heads
k_nope = k_nope.view(-1, block_size, self.num_kv_heads,
self.kv_lora_rank)
k_pe = k_pe.view(-1, block_size, self.num_kv_heads,
self.qk_rope_head_dim)
q_nope = q_nope.view(num_tokens, num_heads, -1)
q_pe = q_pe.view(num_tokens, num_heads, -1)
# use pcp & dcp split computed token nums from scheduler to compute actual seq_len and seq_mask
num_computed_tokens_of_pcp_dcp = np.array(
decode_meta.num_computed_tokens_of_pcp_dcp
)[:attn_metadata.num_decodes] # [bs, pcp_size, dcp_size]
seq_mask_pcp = torch.where(
torch.tensor(num_computed_tokens_of_pcp_dcp.sum(2)) == 0, 0,
1).to(torch.uint8).to(q_pe.device)
seq_mask_dcp = torch.where(
torch.tensor(
num_computed_tokens_of_pcp_dcp[:, self.pcp_rank, :]) == 0, 0,
1).to(torch.uint8).to(q_pe.device)
seq_len = num_computed_tokens_of_pcp_dcp[:, self.pcp_rank,
self.dcp_rank]
seq_len = torch.tensor(seq_len, dtype=torch.int32)
if torch.sum(seq_len).item() == 0:
    # No kv_cache has been stored on this rank yet, so skip the attention
    # computation and emit a zero output with -inf lse (discarded later via seq_mask).
    attn_output = torch.zeros(
        [num_tokens, num_heads, self.kv_lora_rank],
        dtype=q_nope.dtype,
        device=q_nope.device)
    softmax_lse = torch.full((num_tokens, num_heads, 1),
                             float('-inf'),
                             dtype=q_nope.dtype,
                             device=q_nope.device)
else:
    # npu_multi_head_latent_attention does not support seq_len == 0, so replace
    # zeros with 1. This does not change the result, since seq_mask is used
    # afterwards to discard the corresponding lse entries.
    seq_len = torch.where(seq_len == 0, 1, seq_len)
    attn_output, softmax_lse = torch_npu.atb.npu_multi_head_latent_attention(
q_nope,
q_pe,
k_nope,
k_pe,
decode_meta.block_table,
seq_len,
num_heads,
self.scale,
self.num_kv_heads,
return_lse=True,
calc_type="calc_type_ring")
if self.dcp_size > 1:
# Concat out&lse: [bs,num_heads,v_head_dim] + [bs,num_heads,1] -> [bs,num_heads,v_head_dim+1]
attn_out_lse = torch.cat([attn_output, softmax_lse], dim=-1)
# permute: [bs, num_heads, v_head_dim+1] -> [num_heads, v_head_dim+1, bs]
attn_out_lse = attn_out_lse.permute([1, 2, 0]).contiguous()
attn_out_lse_all2all = torch.empty_like(attn_out_lse)
dist.all_to_all_single(attn_out_lse_all2all,
attn_out_lse,
group=self.dcp_group)
# permute: [num_heads, v_head_dim+1, bs] -> [bs, num_heads, v_head_dim+1]
attn_out_lse_all2all = attn_out_lse_all2all.permute([2, 0, 1])
attn_out_lse_split_on_seq = list(
torch.chunk(attn_out_lse_all2all, self.dcp_size, dim=1))
# Update out&lse
attn_out_g = None
attn_lse_g = None
for i, attn_out_lse_l in enumerate(attn_out_lse_split_on_seq):
attn_out_l, attn_lse_l = torch.split(attn_out_lse_l,
[self.kv_lora_rank, 1],
dim=-1)
attn_out_g, attn_lse_g = self._update_out_and_lse(
attn_out_g, attn_lse_g, attn_out_l, attn_lse_l,
seq_mask_dcp[:, i])
attn_output = attn_out_g
softmax_lse = attn_lse_g
if self.pcp_size > 1:
# Concat out&lse: [bs,num_heads,v_head_dim] + [bs,num_heads,1] -> [bs,num_heads,v_head_dim+1]
attn_out_lse = torch.cat([attn_output, softmax_lse], dim=-1)
# AllGather out&lse within PCP group
attn_out_lse_list = [
torch.empty_like(attn_out_lse) for _ in range(self.pcp_size)
]
dist.all_gather(attn_out_lse_list,
attn_out_lse,
group=self.pcp_group)
# Update out&lse
attn_out_g = None
attn_lse_g = None
for i, attn_out_lse_l in enumerate(attn_out_lse_list):
attn_out_l, attn_lse_l = torch.split(attn_out_lse_l,
[self.kv_lora_rank, 1],
dim=-1)
attn_out_g, attn_lse_g = self._update_out_and_lse(
attn_out_g, attn_lse_g, attn_out_l, attn_lse_l,
seq_mask_pcp[:, i])
attn_output = attn_out_g
current_ms_metadata = get_multistream_comm_context()
if current_ms_metadata is None:
return self._v_up_proj(attn_output)
else:
current_ms_metadata.before_comm_event.record()
with torch.npu.stream(current_ms_metadata.comm_stream):
current_ms_metadata.before_comm_event.wait()
return self._v_up_proj(attn_output)
# TODO use update op to replace this
def _update_out_and_lse(
self,
out: torch.Tensor,
lse: torch.Tensor,
block_out: torch.Tensor,
block_lse: torch.Tensor,
mask: torch.Tensor = None,
):
if out is None:
out = block_out.to(torch.float32)
lse = block_lse
else:
if mask is None:
mask = torch.ones([block_out.size(0)],
dtype=torch.uint8,
device=block_out.device)
out_mask = mask[:, None, None].expand_as(block_out)
lse_mask = mask[:, None, None].expand_as(block_lse)
block_out = block_out.to(torch.float32)
out_without_update = out.clone()
lse_without_update = lse.clone()
out = out - F.sigmoid(block_lse - lse) * (out - block_out)
lse = lse - F.logsigmoid(lse - block_lse)
# mask
out = torch.where(out_mask, out, out_without_update)
lse = torch.where(lse_mask, lse, lse_without_update)
return out, lse
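The merge rule in _update_out_and_lse combines per-chunk attention outputs through their log-sum-exp (lse) statistics. Below is a minimal self-contained sketch, not part of this diff and with purely illustrative names, showing that the same sigmoid/logsigmoid update applied to two KV chunks reproduces attention over the concatenated KV:

import torch
import torch.nn.functional as F

def chunk_attn(q, k, v, scale):
    # Attention over one KV chunk, returning output and per-query log-sum-exp.
    scores = (q @ k.T) * scale                 # [q_len, kv_len]
    lse = torch.logsumexp(scores, dim=-1, keepdim=True)
    out = torch.softmax(scores, dim=-1) @ v    # [q_len, head_dim]
    return out, lse

def merge(out, lse, block_out, block_lse):
    # Same update rule as _update_out_and_lse (rank mask omitted).
    new_out = out - F.sigmoid(block_lse - lse) * (out - block_out)
    new_lse = lse - F.logsigmoid(lse - block_lse)
    return new_out, new_lse

torch.manual_seed(0)
head_dim = 8
scale = head_dim ** -0.5
q = torch.randn(2, head_dim)
k = torch.randn(10, head_dim)
v = torch.randn(10, head_dim)

ref, _ = chunk_attn(q, k, v, scale)                 # attention over the full KV
out1, lse1 = chunk_attn(q, k[:6], v[:6], scale)     # e.g. the local KV chunk
out2, lse2 = chunk_attn(q, k[6:], v[6:], scale)     # e.g. a remote rank's KV chunk
merged, _ = merge(out1, lse1, out2, lse2)
assert torch.allclose(merged, ref, atol=1e-5)

The optional mask argument in the real helper simply keeps the previous out/lse for requests whose rank holds no KV for that chunk, matching the seq_mask handling above.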

View File

@@ -1,5 +1,5 @@
from dataclasses import dataclass
from typing import Any, List
from typing import Any, List, Optional
import torch
import torch.nn.functional as F
@@ -9,6 +9,39 @@ from vllm.distributed.kv_transfer import (get_kv_transfer_group,
from vllm.forward_context import ForwardContext, get_forward_context
@dataclass
# class AscendCommonLongSequenceMetadata:
class AscendPrefillContextParallelMetadata:
pcp_allgather_restore_idx: Optional[torch.Tensor] = None
num_actual_tokens_pcp_padded: Optional[int] = None
num_computed_tokens_of_pcp_dcp: Optional[list[Optional[list[Optional[
    list[int]]]]]] = None
q_head_idx_tensor: Optional[torch.Tensor] = None
q_tail_idx_tensor: Optional[torch.Tensor] = None
kv_with_q_head_nomask_idx_tensor: Optional[torch.Tensor] = None
kv_with_q_head_mask_idx_tensor: Optional[torch.Tensor] = None
kv_with_q_tail_nomask_idx_tensor: Optional[torch.Tensor] = None
kv_with_q_tail_mask_idx_tensor: Optional[torch.Tensor] = None
attn_mask_seqlens: Optional[torch.Tensor] = None
head_attn_nomask_seqlens: Optional[torch.Tensor] = None
tail_attn_nomask_seqlens: Optional[torch.Tensor] = None
q_full_idx: Optional[torch.Tensor] = None
pcp_prefill_mask: Optional[torch.Tensor] = None
@dataclass
class AscendCommonAttentionMetadata:
"""
@@ -72,6 +105,9 @@ class AscendCommonAttentionMetadata:
cos: torch.Tensor = None
sin: torch.Tensor = None
prefill_context_parallel_metadata: Optional[
AscendPrefillContextParallelMetadata] = None
def split_decodes_and_prefills(
common_attn_metadata: AscendCommonAttentionMetadata,

View File

@@ -22,7 +22,8 @@ from vllm import envs
from vllm.config import KVTransferConfig, VllmConfig
from vllm.distributed.kv_transfer.kv_connector.v1.base import (
KVConnectorBase_V1, KVConnectorMetadata, KVConnectorRole)
from vllm.distributed.parallel_state import get_tp_group, get_world_group
from vllm.distributed.parallel_state import (get_dcp_group, get_tp_group,
get_world_group)
from vllm.forward_context import ForwardContext
from vllm.utils import get_ip, logger
from vllm.v1.core.kv_cache_manager import KVCacheBlocks
@@ -30,7 +31,12 @@ from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.request import Request, RequestStatus
import vllm_ascend.envs as envs_ascend
from vllm_ascend.utils import AscendSocVersion, get_ascend_soc_version
from vllm_ascend.utils import (AscendSocVersion, get_ascend_soc_version,
prefill_context_parallel_enable)
if prefill_context_parallel_enable():
from vllm.distributed.parallel_state import \
get_prefill_context_model_parallel_rank
TORCH_DTYPE_TO_NPU_DTYPE = {
torch.half: llm_datadist.DataType.DT_FLOAT16,
@@ -66,6 +72,8 @@ class ReqMeta:
remote_port: str
engine_id: str
remote_tp_size: str
remote_cp_size: str
remote_dcp_size: str
class LLMDataDistCMgrConnectorMetadata(KVConnectorMetadata):
@@ -82,6 +90,8 @@ class LLMDataDistCMgrConnectorMetadata(KVConnectorMetadata):
remote_host=kv_transfer_params["remote_host"],
remote_port=kv_transfer_params["remote_port"],
remote_tp_size=kv_transfer_params["remote_tp_size"],
remote_cp_size=kv_transfer_params["remote_cp_size"],
remote_dcp_size=kv_transfer_params["remote_dcp_size"],
)
@@ -185,8 +195,11 @@ class LLMDataDistCMgrConnectorScheduler():
else:
dp_rank_local = vllm_config.parallel_config.data_parallel_rank_local
tp_size = self.vllm_config.parallel_config.tensor_parallel_size
self.pcp_size = self.vllm_config.parallel_config.prefill_context_parallel_size if prefill_context_parallel_enable(
) else 1
self.dcp_size = vllm_config.parallel_config.decode_context_parallel_size
self.port = dp_rank_local * tp_size + envs_ascend.VLLM_ASCEND_LLMDD_RPC_PORT if dp_rank_local is not None else tp_size + envs_ascend.VLLM_ASCEND_LLMDD_RPC_PORT
self.port = dp_rank_local * self.pcp_size * tp_size + envs_ascend.VLLM_ASCEND_LLMDD_RPC_PORT if dp_rank_local is not None else tp_size + envs_ascend.VLLM_ASCEND_LLMDD_RPC_PORT
self._reqs_need_recv: dict[str, tuple[Request, list[int]]] = {}
self._reqs_need_send: dict[str, float] = {}
@@ -298,6 +311,8 @@ class LLMDataDistCMgrConnectorScheduler():
remote_port=self.port,
remote_tp_size=str(
self.vllm_config.parallel_config.tensor_parallel_size),
remote_cp_size=str(self.pcp_size),
remote_dcp_size=str(self.dcp_size),
)
@@ -322,6 +337,11 @@ class LLMDataDistCMgrConnectorWorker():
self.tp_size = vllm_config.parallel_config.tensor_parallel_size
self.tp_rank = get_tp_group().rank_in_group
self.rank = get_world_group().rank
self.pcp_size = vllm_config.parallel_config.prefill_context_parallel_size if prefill_context_parallel_enable(
) else 1
self.pcp_rank = get_prefill_context_model_parallel_rank(
) if prefill_context_parallel_enable() else 0
self.dcp_size = get_dcp_group().world_size
self.local_ip = get_ip()
self.kv_transfer_config: KVTransferConfig = vllm_config.kv_transfer_config
self.local_agent_metadata: Optional[
@@ -362,7 +382,8 @@ class LLMDataDistCMgrConnectorWorker():
def listen_for_agent_metadata_req(self, event: threading.Event):
assert self.local_agent_metadata is not None
port = envs_ascend.VLLM_ASCEND_LLMDD_RPC_PORT + self.local_dp_rank * self.tp_size + self.tp_rank if self.local_dp_rank is not None else envs_ascend.VLLM_ASCEND_LLMDD_RPC_PORT + self.tp_size + self.tp_rank
port = envs_ascend.VLLM_ASCEND_LLMDD_RPC_PORT + self.local_dp_rank * self.pcp_size * self.tp_size + self.pcp_rank * self.tp_size + self.tp_rank \
if self.local_dp_rank is not None else envs_ascend.VLLM_ASCEND_LLMDD_RPC_PORT + self.tp_size + self.tp_rank
url = f"tcp://{envs_ascend.VLLM_ASCEND_LLMDD_RPC_IP}:{port}"
msg_encoder = msgspec.msgpack.Encoder()
msg_decoder = msgspec.msgpack.Decoder()
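For orientation, a hedged sketch of the RPC port layout implied by the expression above; the base port value and the sizes here are assumptions for illustration, not values taken from the repo:

base, tp_size, pcp_size = 5557, 2, 2      # assumed values for illustration only
ports = {(dp, pcp, tp): base + dp * pcp_size * tp_size + pcp * tp_size + tp
         for dp in range(2) for pcp in range(pcp_size) for tp in range(tp_size)}
print(ports)  # every (dp, pcp, tp) rank listens on a distinct port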
@@ -472,9 +493,10 @@ class LLMDataDistCMgrConnectorWorker():
d for d in device_list if d.get("server_id") == self.local_ip
and device_filter(d.get("device_id", ""))
]
if len(device_list) <= self.tp_rank:
if len(device_list) <= self.pcp_rank * self.tp_size + self.tp_rank:
continue
device_info = device_list[self.tp_rank]
device_info = device_list[self.pcp_rank * self.tp_size +
self.tp_rank]
super_pod_id_ = device_info.get("super_pod_id", None)
server_id_ = device_info["server_id"]
device_id_ = device_info["device_id"]
@@ -648,6 +670,8 @@ class LLMDataDistCMgrConnectorWorker():
remote_engine_id=meta.engine_id,
request_id=req_id,
remote_tp_size=meta.remote_tp_size,
remote_cp_size=meta.remote_cp_size,
remote_dcp_size=meta.remote_dcp_size,
)
futures.append(future)
@@ -839,6 +863,62 @@ class LLMDataDistCMgrConnectorWorker():
f"Failed to send reqest_id {request_id} to prefill: {e}"
)
def _get_kv_split_metadata(
self,
local_block_ids: list[int],
remote_block_ids: list[int],
remote_port: int,
remote_tp_size: int,
remote_cp_size: int,
remote_dcp_size: int,
) -> tuple[int, list[int], list[int]]:
"""
In cp/dcp scenario, kv_cache may be split, so we need to pull multiple blocks from multiple remote P node.
Use this function to calculate remote port and remote block number of each remote P node that we need to pull.
"""
if self.pcp_size == remote_cp_size and self.dcp_size == remote_dcp_size:
# remote & local cp/dcp are equal, do kv transfer point-to-point
remote_kv_num = 1
remote_ports = [remote_port + self.pcp_rank * self.tp_size + tp_offset \
for tp_offset in range(self.tp_rank, int(remote_tp_size), self.tp_size)]
remote_block_nums = [len(remote_block_ids)]
elif (self.use_mla and self.pcp_size == 1 and self.dcp_size == 1) \
or (not self.use_mla and self.pcp_size == 1 and remote_tp_size == self.tp_size and remote_dcp_size == self.dcp_size):
# remote & local cp/dcp are not equal, each D node needs to pull from cp(*dcp) P nodes
# 1. for mla, support D cp_size = dcp_size = 1
# 2. for gqa, support D tp_size = P tp_size, D dcp_size = P dcp_size
remote_dcp_size = remote_dcp_size // self.dcp_size
remote_kv_num = remote_cp_size * remote_dcp_size
cp_dcp_offsets = []
for cp_idx in range(remote_cp_size):
cp_offset = cp_idx * remote_tp_size
cp_dcp_offsets += list(
range(cp_offset, cp_offset + remote_dcp_size))
remote_ports = [remote_port + cp_dcp_offset + (self.tp_rank if not self.use_mla else 0) \
for cp_dcp_offset in cp_dcp_offsets]
# recompute the cp/dcp block assignment here; it could also be passed from the P node metadata
local_block_num = len(local_block_ids)
remote_block_nums = [
local_block_num // (remote_cp_size * remote_dcp_size)
] * remote_cp_size * remote_dcp_size
num_remain_blocks = local_block_num % (remote_cp_size *
remote_dcp_size)
for i in range(num_remain_blocks):
remote_block_nums[i] += 1
# ensure the last block of the P nodes (which may not be full) maps to the last block of the D node
remote_ports = remote_ports[
num_remain_blocks:] + remote_ports[:num_remain_blocks]
remote_block_nums = remote_block_nums[
num_remain_blocks:] + remote_block_nums[:num_remain_blocks]
else:
# Other cases are not supported now, maybe need to reshard kv_cache.
raise NotImplementedError(
f'Current case is not supported now: use_mla={self.use_mla}, '
f'P tp={remote_tp_size}, pcp={remote_cp_size}, dcp={remote_dcp_size}, '
f'D tp={self.tp_size}, pcp={self.pcp_size}, dcp={self.dcp_size}'
)
return remote_kv_num, remote_ports, remote_block_nums
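As a standalone illustration of the block assignment above (sketch only, not repo code), for the MLA case with local pcp = dcp = 1: the local blocks are spread as evenly as possible over the remote_cp_size * remote_dcp_size P ranks, then rotated so the possibly-unfull last P block becomes the last block on the D node:

def split_block_nums(local_block_num: int, remote_cp_size: int,
                     remote_dcp_size: int) -> list[int]:
    # Mirrors the remote_block_nums computation in _get_kv_split_metadata.
    n = remote_cp_size * remote_dcp_size
    nums = [local_block_num // n] * n
    remain = local_block_num % n
    for i in range(remain):
        nums[i] += 1
    # Rotate so the ranks holding one extra block come last.
    return nums[remain:] + nums[:remain]

print(split_block_nums(5, 2, 1))  # [2, 3]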
def _read_blocks(
self,
local_block_ids: list[int],
@@ -848,97 +928,119 @@ class LLMDataDistCMgrConnectorWorker():
remote_engine_id: str,
request_id: str,
remote_tp_size: str,
remote_cp_size: str,
remote_dcp_size: str,
):
# if remote_ip not in self.linked_cluster:
tp_offset = self.tp_rank % int(remote_tp_size)
remote_cluster_id = self.connect_to_remote_agent(
remote_ip, remote_port + tp_offset)
num_local_blocks = len(local_block_ids)
if num_local_blocks == 0:
return
num_remote_blocks = len(remote_block_ids)
assert num_local_blocks <= num_remote_blocks
if num_local_blocks < num_remote_blocks:
remote_block_ids = remote_block_ids[-num_local_blocks:]
remote_kv_num, remote_ports, remote_block_nums = self._get_kv_split_metadata(
local_block_ids=local_block_ids,
remote_block_ids=remote_block_ids,
remote_port=remote_port,
remote_tp_size=int(remote_tp_size),
remote_cp_size=int(remote_cp_size),
remote_dcp_size=int(remote_dcp_size),
)
logger.debug(
f'Pull blocks from remote: remote_kv_num={remote_kv_num}, remote_ports={remote_ports}, '
f'remote_block_nums={remote_block_nums}, local_block_ids={local_block_ids}'
)
logger.info(f"remote cluster id is: {remote_cluster_id}")
if self.use_mla:
remote_cache_key_k_normed = BlocksCacheKey(
cluster_id=remote_cluster_id, model_id=0)
remote_cache_key_k_pe = BlocksCacheKey(
cluster_id=remote_cluster_id, model_id=1)
logger.info("Try pull blocks from remote server")
try:
self.cache_manager.pull_blocks(
remote_cache_key_k_normed,
self.cache[0], # type: ignore[has-type]
remote_block_ids,
local_block_ids)
self.cache_manager.pull_blocks(
remote_cache_key_k_pe,
self.cache[1], # type: ignore[has-type]
remote_block_ids,
local_block_ids)
except (TypeError, ValueError):
raise RuntimeError(
f"LLMDataDistCMgrConnectorWorker: Passing unexpected parameter to pull_blocks remote_cache_key: {remote_cache_key_k_normed} {remote_cache_key_k_pe}, cache: {self.cache}, local_block_ids: {local_block_ids}, remote_block_ids: {remote_block_ids}" # type: ignore[has-type]
)
except LLMException:
raise RuntimeError(
"LLMDataDistCMgrConnectorWorker: Timeout during pull_blocks, you can try to increase the sync_kv_timeout config or checking your connect status"
)
elif self.use_sparse:
remote_cache_key_k_normed = BlocksCacheKey(
cluster_id=remote_cluster_id, model_id=0)
remote_cache_key_k_pe = BlocksCacheKey(
cluster_id=remote_cluster_id, model_id=1)
remote_cache_key_k_idx = BlocksCacheKey(
cluster_id=remote_cluster_id, model_id=2)
logger.info("Try pull blocks from remote server")
try:
self.cache_manager.pull_blocks(
remote_cache_key_k_normed,
self.cache[0], # type: ignore[has-type]
remote_block_ids,
local_block_ids)
self.cache_manager.pull_blocks(
remote_cache_key_k_pe,
self.cache[1], # type: ignore[has-type]
remote_block_ids,
local_block_ids)
self.cache_manager.pull_blocks(
remote_cache_key_k_idx,
self.cache[2], # type: ignore[has-type]
remote_block_ids,
local_block_ids)
except (TypeError, ValueError):
raise RuntimeError(
f"LLMDataDistCMgrConnectorWorker: Passing unexpected parameter to pull_blocks remote_cache_key: {remote_cache_key_k_normed} {remote_cache_key_k_pe} {remote_cache_key_k_idx}, cache: {self.cache}, local_block_ids: {local_block_ids}, remote_block_ids: {remote_block_ids}" # type: ignore[has-type]
)
except LLMException:
raise RuntimeError(
"LLMDataDistCMgrConnectorWorker: Timeout during pull_blocks, you can try to increase the sync_kv_timeout config or checking your connect status"
)
else:
remote_cache_key = BlocksCacheKey(cluster_id=remote_cluster_id)
logger.info("Try pull blocks from remote server")
try:
self.cache_manager.pull_blocks(
remote_cache_key,
self.cache, # type: ignore[has-type]
remote_block_ids,
local_block_ids)
except (TypeError, ValueError):
raise RuntimeError(
f"LLMDataDistCMgrConnectorWorker: Passing unexpected parameter to pull_blocks remote_cache_key: {remote_cache_key}, cache: {self.cache}, local_block_ids: {local_block_ids}, remote_block_ids: {remote_block_ids}" # type: ignore[has-type]
)
except LLMException:
raise RuntimeError(
"LLMDataDistCMgrConnectorWorker: Timeout during pull_blocks, you can try to increase the sync_kv_timeout config or checking your connect status"
)
remote_ports = list(
range(remote_port + self.tp_rank,
remote_port + int(remote_tp_size), self.tp_size))
local_block_offset = 0
remote_block_ids_full = remote_block_ids
local_block_ids_full = local_block_ids
for remote_kv_id in range(remote_kv_num):
remote_port = remote_ports[remote_kv_id]
num_blocks_to_pull = remote_block_nums[remote_kv_id]
if num_blocks_to_pull == 0:
continue
remote_block_ids = remote_block_ids_full[:num_blocks_to_pull]
local_block_ids = local_block_ids_full[
local_block_offset:local_block_offset + num_blocks_to_pull]
local_block_offset += num_blocks_to_pull
remote_cluster_id = self.connect_to_remote_agent(
remote_ip, remote_port)
num_local_blocks = len(local_block_ids)
if num_local_blocks == 0:
return
num_remote_blocks = len(remote_block_ids)
assert num_local_blocks <= num_remote_blocks
if num_local_blocks < num_remote_blocks:
remote_block_ids = remote_block_ids[-num_local_blocks:]
logger.info(f"remote cluster id is: {remote_cluster_id}")
if self.use_mla:
remote_cache_key_k_normed = BlocksCacheKey(
cluster_id=remote_cluster_id, model_id=0)
remote_cache_key_k_pe = BlocksCacheKey(
cluster_id=remote_cluster_id, model_id=1)
logger.info("Try pull blocks from remote server")
try:
self.cache_manager.pull_blocks(
remote_cache_key_k_normed,
self.cache[0], # type: ignore[has-type]
remote_block_ids,
local_block_ids)
self.cache_manager.pull_blocks(
remote_cache_key_k_pe,
self.cache[1], # type: ignore[has-type]
remote_block_ids,
local_block_ids)
except (TypeError, ValueError):
raise RuntimeError(
f"LLMDataDistCMgrConnectorWorker: Passing unexpected parameter to pull_blocks remote_cache_key: {remote_cache_key_k_normed} {remote_cache_key_k_pe}, cache: {self.cache}, local_block_ids: {local_block_ids}, remote_block_ids: {remote_block_ids}" # type: ignore[has-type]
)
except LLMException:
raise RuntimeError(
"LLMDataDistCMgrConnectorWorker: Timeout during pull_blocks, you can try to increase the sync_kv_timeout config or checking your connect status"
)
elif self.use_sparse:
remote_cache_key_k_normed = BlocksCacheKey(
cluster_id=remote_cluster_id, model_id=0)
remote_cache_key_k_pe = BlocksCacheKey(
cluster_id=remote_cluster_id, model_id=1)
remote_cache_key_k_idx = BlocksCacheKey(
cluster_id=remote_cluster_id, model_id=2)
logger.info("Try pull blocks from remote server")
try:
self.cache_manager.pull_blocks(
remote_cache_key_k_normed,
self.cache[0], # type: ignore[has-type]
remote_block_ids,
local_block_ids)
self.cache_manager.pull_blocks(
remote_cache_key_k_pe,
self.cache[1], # type: ignore[has-type]
remote_block_ids,
local_block_ids)
self.cache_manager.pull_blocks(
remote_cache_key_k_idx,
self.cache[2], # type: ignore[has-type]
remote_block_ids,
local_block_ids)
except (TypeError, ValueError):
raise RuntimeError(
f"LLMDataDistCMgrConnectorWorker: Passing unexpected parameter to pull_blocks remote_cache_key: {remote_cache_key_k_normed} {remote_cache_key_k_pe} {remote_cache_key_k_idx}, cache: {self.cache}, local_block_ids: {local_block_ids}, remote_block_ids: {remote_block_ids}" # type: ignore[has-type]
)
except LLMException:
raise RuntimeError(
"LLMDataDistCMgrConnectorWorker: Timeout during pull_blocks, you can try to increase the sync_kv_timeout config or checking your connect status"
)
else:
remote_cache_key = BlocksCacheKey(cluster_id=remote_cluster_id)
logger.info("Try pull blocks from remote server")
try:
self.cache_manager.pull_blocks(
remote_cache_key,
self.cache, # type: ignore[has-type]
remote_block_ids,
local_block_ids)
except (TypeError, ValueError):
raise RuntimeError(
f"LLMDataDistCMgrConnectorWorker: Passing unexpected parameter to pull_blocks remote_cache_key: {remote_cache_key}, cache: {self.cache}, local_block_ids: {local_block_ids}, remote_block_ids: {remote_block_ids}" # type: ignore[has-type]
)
except LLMException:
raise RuntimeError(
"LLMDataDistCMgrConnectorWorker: Timeout during pull_blocks, you can try to increase the sync_kv_timeout config or checking your connect status"
)
self.send_finish_to_remote(remote_ip, remote_ports, request_id)
with self.thread_lock:
self.finished_reqs.add(request_id)
@@ -990,4 +1092,4 @@ def zmq_ctx(socket_type: Any,
yield socket
finally:
if ctx is not None:
ctx.destroy(linger=0)
ctx.destroy(linger=0)

View File

@@ -7,6 +7,7 @@ from vllm.distributed.parallel_state import (GroupCoordinator, get_world_group,
import vllm_ascend.envs as envs_ascend
from vllm_ascend.ascend_config import get_ascend_config
from vllm_ascend.utils import prefill_context_parallel_enable
# Currently, mc2 ops need their own group coordinator.
_MC2: Optional[GroupCoordinator] = None
@@ -58,9 +59,15 @@ def init_ascend_model_parallel(parallel_config: ParallelConfig, ):
# The layout of all ranks: ExternalDP * EP
# ExternalDP is the data parallel group that is not part of the model,
# every dp rank can generate independently (in verl integration).
all_ranks = torch.arange(world_size).reshape(
-1, parallel_config.data_parallel_size *
parallel_config.tensor_parallel_size)
if prefill_context_parallel_enable():
all_ranks = torch.arange(world_size).reshape(
-1, parallel_config.data_parallel_size *
parallel_config.prefill_context_parallel_size *
parallel_config.tensor_parallel_size)
else:
all_ranks = torch.arange(world_size).reshape(
-1, parallel_config.data_parallel_size *
parallel_config.tensor_parallel_size)
pd_tp_ratio = get_ascend_config().pd_tp_ratio
pd_head_ratio = get_ascend_config().pd_head_ratio

View File

@@ -172,6 +172,9 @@ env_variables: Dict[str, Callable[[], Any]] = {
# Whether to enable transpose weight and cast format to FRACTAL_NZ.
"VLLM_ASCEND_ENABLE_NZ":
lambda: int(os.getenv("VLLM_ASCEND_ENABLE_NZ", 1)),
# Whether to enable context parallelism (CP).
"VLLM_ASCEND_ENABLE_CONTEXT_PARALLEL":
lambda: bool(int(os.getenv("VLLM_ASCEND_ENABLE_CONTEXT_PARALLEL", '0')))
}
# end-env-vars-definition

View File

@@ -32,6 +32,7 @@ from vllm_ascend.ascend_config import (check_ascend_config, get_ascend_config,
from vllm_ascend.torchair.utils import (check_torchair_cache_exist,
delete_torchair_cache_file)
from vllm_ascend.utils import (ASCEND_QUANTIZATION_METHOD, enable_sp, is_310p,
prefill_context_parallel_enable,
update_aclgraph_sizes)
if TYPE_CHECKING:
@@ -131,7 +132,8 @@ class NPUPlatform(Platform):
if (model_config is not None and not model_config.use_mla
and not scheduler_config.async_scheduling
and model_config.runner_type != "pooling"):
and model_config.runner_type != "pooling"
and not prefill_context_parallel_enable()):
logger.info(
"Non-MLA LLMs forcibly disable the chunked prefill feature,"
"as the performance of operators supporting this feature "
@@ -322,6 +324,16 @@ class NPUPlatform(Platform):
vllm_config.scheduler_config.chunked_prefill_enabled = True
vllm_config.scheduler_config.SLO_limits_for_dynamic_batch = ascend_config.SLO_limits_for_dynamic_batch
if vllm_config.kv_transfer_config is not None and \
prefill_context_parallel_enable() and \
cache_config.block_size != parallel_config.cp_kv_cache_interleave_size and \
parallel_config.decode_context_parallel_size * parallel_config.prefill_context_parallel_size > 1:
raise AssertionError(
f"cp_kv_cache_interleave_size({parallel_config.cp_kv_cache_interleave_size}) "
f"and block_size({cache_config.block_size}) "
"needs to be equal if use cp or dcp > 1 in P/D disaggregate scenario."
)
@classmethod
def get_attn_backend_cls(
cls,

View File

@@ -648,6 +648,10 @@ def shared_expert_dp_enabled() -> bool:
return get_ascend_config().enable_shared_expert_dp or enable_sp()
def prefill_context_parallel_enable() -> bool:
return envs_ascend.VLLM_ASCEND_ENABLE_CONTEXT_PARALLEL
def is_moe_model(vllm_config: VllmConfig):
global _IS_MOE_MODEL
if _IS_MOE_MODEL is None:

View File

@@ -5,6 +5,11 @@ import torch
from vllm.distributed import get_dcp_group
from vllm.utils import cdiv
from vllm_ascend.utils import prefill_context_parallel_enable
if prefill_context_parallel_enable():
from vllm.distributed import get_pcp_group
class BlockTable:
@@ -15,7 +20,8 @@ class BlockTable:
max_num_batched_tokens: int,
pin_memory: bool,
device: torch.device,
kernel_sizes: Union[list[int], None] = None):
kernel_sizes: Union[list[int], None] = None,
cp_kv_cache_interleave_size: int = 1):
self.max_num_reqs = max_num_reqs
self.max_num_blocks_per_req = max_num_blocks_per_req
self.max_num_batched_tokens = max_num_batched_tokens
@@ -80,13 +86,20 @@ class BlockTable:
dtype=torch.int64,
device=self.device)
try:
self.pcp_world_size = get_pcp_group(
).world_size if prefill_context_parallel_enable() else 1
self.pcp_rank = get_pcp_group(
).rank_in_group if self.pcp_world_size > 1 else 0
self.dcp_world_size = get_dcp_group().world_size
self.dcp_rank = get_dcp_group().rank_in_group
except AssertionError:
# PCP/DCP might not be initialized in testing
self.dcp_world_size = 1
self.dcp_rank = 0
self.pcp_world_size = 1
self.pcp_rank = 0
self.kernel_sizes = kernel_sizes
self.cp_kv_cache_interleave_size = cp_kv_cache_interleave_size
def append_row(
self,
@@ -132,14 +145,14 @@ class BlockTable:
# here because M (max_model_len) is not necessarily divisible by
# block_size.
if self.dcp_world_size > 1:
if self.dcp_world_size * self.pcp_world_size > 1:
# Note(hc): DCP/PCP store the kv cache in an interleaved style: the kv cache
# for token i lives on the rank whose flattened (pcp, dcp) index equals
# (i // cp_kv_cache_interleave_size) % (pcp_world_size * dcp_world_size).
# Use a "virtual block" of world_size * block_size tokens for
# block_table_indices calculation.
virtual_block_size = self.block_size * self.dcp_world_size
virtual_block_size = self.block_size * self.dcp_world_size * self.pcp_world_size
# IMPORTANT: In hybrid mode, positions are in logical block space,
# but we need to map them to the correct logical block table indices
@@ -157,9 +170,14 @@ class BlockTable:
# Use virtual_block_size for mask calculation, which marks local
# tokens.
virtual_block_offsets = positions % virtual_block_size
mask = virtual_block_offsets % self.dcp_world_size == self.dcp_rank
self.current_rank = self.dcp_world_size * self.pcp_rank + self.dcp_rank
mask = (virtual_block_offsets // self.cp_kv_cache_interleave_size %
(self.dcp_world_size *
self.pcp_world_size) == self.current_rank)
# Calculate local block_offsets
block_offsets = virtual_block_offsets // self.dcp_world_size
block_offsets = virtual_block_offsets \
// (self.dcp_world_size * self.pcp_world_size * self.cp_kv_cache_interleave_size) \
* self.cp_kv_cache_interleave_size + virtual_block_offsets % self.cp_kv_cache_interleave_size
# Calculate slot_mapping
slot_mapping = block_numbers * self.block_size + block_offsets
# Write final slots, use -1 for not-local
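A small numeric sketch of the interleaved placement computed above (illustration only, assuming pcp_world_size = 2, dcp_world_size = 2, block_size = 4 and cp_kv_cache_interleave_size = 1): which flattened (pcp, dcp) rank owns each token position, and its local block offset.

import numpy as np

block_size, pcp_ws, dcp_ws, interleave = 4, 2, 2, 1
world = pcp_ws * dcp_ws
virtual_block_size = block_size * world

positions = np.arange(10)
virtual_off = positions % virtual_block_size
owner_rank = (virtual_off // interleave) % world   # equals dcp_ws * pcp_rank + dcp_rank
local_off = (virtual_off // (world * interleave)) * interleave \
    + virtual_off % interleave

print(list(zip(positions.tolist(), owner_rank.tolist(), local_off.tolist())))
# [(0, 0, 0), (1, 1, 0), (2, 2, 0), (3, 3, 0), (4, 0, 1), (5, 1, 1), ...]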
@@ -242,16 +260,20 @@ class MultiGroupBlockTable:
device: torch.device,
block_sizes: list[int],
num_speculative_tokens: int = 0,
kernel_sizes: Optional[list[list[int]]] = None) -> None:
kernel_sizes: Optional[list[list[int]]] = None,
cp_kv_cache_interleave_size: int = 1) -> None:
# Note(hc): each pcp/dcp rank only stores
# (max_model_len // (pcp_world_size * dcp_world_size)) tokens in the kv cache,
# so the block_size used to calculate max_num_blocks_per_req
# must be multiplied by pcp_world_size * dcp_world_size.
try:
dcp_world_size = get_dcp_group().world_size
cp_world_size = get_pcp_group(
).world_size if prefill_context_parallel_enable() else 1
except AssertionError:
# PCP/DCP might not be initialized in testing
dcp_world_size = 1
cp_world_size = 1
if kernel_sizes is None:
kernel_sizes = [[0]] * len(block_sizes)
@@ -267,9 +289,12 @@ class MultiGroupBlockTable:
self.block_tables = [
BlockTable(
block_size, max_num_reqs,
max(cdiv(max_model_len, block_size * dcp_world_size),
max(
cdiv(max_model_len,
block_size * dcp_world_size * cp_world_size),
1 + num_speculative_tokens), max_num_batched_tokens,
pin_memory, device, kernel_size_list)
pin_memory, device, kernel_size_list,
cp_kv_cache_interleave_size)
for block_size, kernel_size_list in zip(block_sizes, kernel_sizes)
]
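A quick arithmetic check of the per-rank block budget above, with assumed values: max_model_len = 8192, block_size = 128, dcp_world_size = 2 and cp (pcp) world_size = 2 give each rank at most 8192 / 4 = 2048 tokens of a request, i.e. 16 blocks (ignoring speculative tokens):

from math import ceil

max_model_len, block_size, dcp_ws, cp_ws = 8192, 128, 2, 2
print(ceil(max_model_len / (block_size * dcp_ws * cp_ws)))  # 16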

View File

@@ -50,8 +50,8 @@ from vllm.distributed import tensor_model_parallel_all_gather
from vllm.distributed.kv_transfer import (get_kv_transfer_group,
has_kv_transfer_group)
from vllm.distributed.kv_transfer.kv_connector.v1 import KVConnectorBase_V1
from vllm.distributed.parallel_state import (get_dp_group, get_pp_group,
get_tp_group,
from vllm.distributed.parallel_state import (get_dcp_group, get_dp_group,
get_pp_group, get_tp_group,
is_global_first_rank)
from vllm.forward_context import BatchDescriptor, get_forward_context
from vllm.logger import logger
@@ -107,7 +107,8 @@ from vllm_ascend.ascend_forward_context import (MoECommType,
set_ascend_forward_context)
from vllm_ascend.attention.attention_mask import AttentionMaskBuilder
from vllm_ascend.attention.attention_v1 import AscendAttentionState
from vllm_ascend.attention.utils import AscendCommonAttentionMetadata
from vllm_ascend.attention.utils import (AscendCommonAttentionMetadata,
AscendPrefillContextParallelMetadata)
from vllm_ascend.compilation.acl_graph import (ACLGraphWrapper,
set_graph_params,
update_attn_params,
@@ -132,9 +133,16 @@ from vllm_ascend.spec_decode.mtp_proposer import MtpProposer
from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_ND, ACL_FORMAT_FRACTAL_NZ,
AscendSocVersion, ProfileExecuteDuration,
enable_sp, get_ascend_soc_version, is_310p,
is_enable_nz, lmhead_tp_enable)
is_enable_nz, lmhead_tp_enable,
prefill_context_parallel_enable)
from vllm_ascend.worker.npu_input_batch import CachedRequestState, InputBatch
if prefill_context_parallel_enable():
from vllm.distributed import get_pcp_group
from vllm.distributed.parallel_state import (
get_prefill_context_model_parallel_rank,
get_prefill_context_model_parallel_world_size)
if TYPE_CHECKING:
import xgrammar as xgr # type: ignore[import-untyped]
from vllm.v1.core.sched.output import SchedulerOutput
@@ -260,6 +268,12 @@ class NPUModelRunner(LoRAModelRunnerMixin):
decode_max_num_seqs)
self.dp_size = vllm_config.parallel_config.data_parallel_size
self.dp_rank = vllm_config.parallel_config.data_parallel_rank
self.pcp_size = get_prefill_context_model_parallel_world_size(
) if prefill_context_parallel_enable() else 1
self.pcp_rank = get_prefill_context_model_parallel_rank(
) if self.pcp_size > 1 else 0
self.dcp_size = get_dcp_group().world_size
self.dcp_rank = get_dcp_group().rank_in_group
self.device = device
if envs_ascend.VLLM_ASCEND_ENABLE_PREFETCH_MLP:
self.prefetch_stream = torch.npu.Stream(device=device)
@@ -320,7 +334,9 @@ class NPUModelRunner(LoRAModelRunnerMixin):
self.block_size,
use_mla=self.model_config.use_mla,
use_sparse=self.use_sparse)
if torch.version.cann.startswith("8.3"):
if self.pcp_size > 1:
self.attn_mask_builder = None
elif torch.version.cann.startswith("8.3"):
self.attn_mask_builder = AttentionMaskBuilder(
self.scheduler_config.max_num_batched_tokens, self.dtype,
self.device)
@@ -454,6 +470,13 @@ class NPUModelRunner(LoRAModelRunnerMixin):
device="cpu",
pin_memory=True)
self.seq_lens_np = self.seq_lens_cpu.numpy()
self.pcp_allgather_restore_idx = torch.zeros(self.max_num_tokens,
dtype=torch.int32,
device=self.device)
self.num_pcp_pads = torch.zeros(self.max_num_reqs, dtype=torch.int32)
self.pcp_padded_slot_mapping = torch.zeros(self.max_num_tokens,
dtype=torch.int32,
device=self.device)
self.use_aclgraph = self._use_aclgraph()
self.aclgraph_batch_sizes = list(
@@ -525,6 +548,9 @@ class NPUModelRunner(LoRAModelRunnerMixin):
self.vllm_config.model_config.logits_processors),
is_pooling_model=self.is_pooling_model,
kernel_block_sizes=[[self.vllm_config.cache_config.block_size]],
cp_kv_cache_interleave_size=self.parallel_config.
cp_kv_cache_interleave_size
if prefill_context_parallel_enable() else 1,
)
self.num_accepted_tokens = self._make_buffer(self.max_num_reqs,
dtype=torch.int64)
@@ -890,12 +916,20 @@ class NPUModelRunner(LoRAModelRunnerMixin):
def _make_attention_mask(self, seq_lens, position,
attn_state) -> torch.Tensor:
if self.pcp_size > 1:
return None
if self.attn_mask_builder is None:
raise ValueError("Attn mask builder is None")
# Pooling situation.
if self.model_config.runner_type == "pooling" and self.model_config.pooler_config.pooling_type == "CLS":
return self.attn_mask_builder.get_pooling_mask(self.device)
# Chunk Prefill situation.
elif attn_state == AscendAttentionState.ChunkedPrefill and not self.vllm_config.model_config.use_mla and not self.use_sparse:
if torch.version.cann.startswith("8.3"):
if self.dcp_size > 1:
max_seq_len = max(seq_lens.max().item(), 0)
return self.attn_mask_builder.get_attn_mask(
max_seq_len, self.dtype, self.device)
elif torch.version.cann.startswith("8.3"):
return self.attn_mask_builder.get_splitfuse_attn_mask()
else:
return self.attn_mask_builder.get_splitfuse_attn_mask(
@@ -945,7 +979,7 @@ class NPUModelRunner(LoRAModelRunnerMixin):
src_end = num_computed_tokens + prompt_part_len
self.mrope_positions_cpu[:, dst_start:dst_end] = \
req.mrope_positions[:,src_start:src_end]
req.mrope_positions[:, src_start:src_end]
mrope_pos_ptr += prompt_part_len
@@ -1219,7 +1253,27 @@ class NPUModelRunner(LoRAModelRunnerMixin):
req_ids = self.input_batch.req_ids
tokens = [scheduler_output.num_scheduled_tokens[i] for i in req_ids]
num_scheduled_tokens = np.array(tokens, dtype=np.int32)
max_num_scheduled_tokens = num_scheduled_tokens.max()
req_indices = np.repeat(self.arange_np[:num_reqs],
num_scheduled_tokens)
_, arange = self._get_cumsum_and_arange(num_scheduled_tokens)
positions_np = np.add(
self.input_batch.num_computed_tokens_cpu[req_indices],
arange,
)
self.input_batch.block_table.compute_slot_mapping(
req_indices, positions_np)
tokens, position_pcp, pcp_unpad_mask = self._update_tokens_for_pcp(
tokens)
num_scheduled_tokens = np.array(tokens, dtype=np.int32)
# update total_num_scheduled_tokens
total_num_scheduled_tokens = sum(num_scheduled_tokens[:num_reqs])
self.input_batch.block_table.commit_slot_mapping(
total_num_scheduled_tokens)
total_num_pcp_pads = sum(self.num_pcp_pads)
max_num_scheduled_tokens = max(tokens)
num_valid_tokens = np.array([
num_tokens -
len(scheduler_output.scheduled_spec_decode_tokens.get(i, []))
@@ -1284,10 +1338,13 @@ class NPUModelRunner(LoRAModelRunnerMixin):
cu_num_tokens, arange = self._get_cumsum_and_arange(
num_scheduled_tokens)
positions_np = self.positions_np[:total_num_scheduled_tokens]
np.add(self.input_batch.num_computed_tokens_cpu[req_indices],
arange,
out=positions_np)
if self.pcp_size > 1:
positions_np = self.positions_np[:total_num_scheduled_tokens]
np.add(self.input_batch.num_computed_tokens_cpu[req_indices],
position_pcp[:total_num_scheduled_tokens],
out=positions_np)
else:
self.positions_np[:total_num_scheduled_tokens] = positions_np
# Calculate M-RoPE positions.
# Only relevant for models using M-RoPE (e.g, Qwen2-VL)
@@ -1315,13 +1372,6 @@ class NPUModelRunner(LoRAModelRunnerMixin):
torch.from_numpy(token_indices),
out=self.input_ids_cpu[:total_num_scheduled_tokens])
# Prepare some information for building Attention-Metadata
# Compute and commit slot mapping
self.input_batch.block_table.compute_slot_mapping(
req_indices, positions_np)
self.input_batch.block_table.commit_slot_mapping(
total_num_scheduled_tokens)
self.query_start_loc_np[0] = 0
self.query_start_loc_np[1:num_reqs + 1] = cu_num_tokens
self.query_start_loc[:num_reqs + 1].copy_(
@@ -1351,6 +1401,7 @@ class NPUModelRunner(LoRAModelRunnerMixin):
positions_cpu = self.positions_cpu[:num_input_tokens]
positions = self.positions[:num_input_tokens]
seq_lens_cpu = self.seq_lens_cpu[:num_reqs]
attn_state = self._build_attn_state(num_reqs, num_scheduled_tokens,
num_valid_tokens)
self.attn_mask = self._make_attention_mask(seq_lens=seq_lens_cpu,
@@ -1428,9 +1479,13 @@ class NPUModelRunner(LoRAModelRunnerMixin):
# We will ignore the sampled tokens from the partial requests.
# TODO: Support prompt logprobs.
spec_decode_metadata = None
logits_indices = torch.from_numpy(cu_num_tokens - 1).to(
self.device, non_blocking=True)
logits_indices = torch.from_numpy(
cu_num_tokens
) * self.pcp_size - self.num_pcp_pads[:num_reqs] - 1
logits_indices = logits_indices.to(self.device, non_blocking=True)
else:
# pcp not supported now
assert self.pcp_size == 1
# Get the number of draft tokens for each request.
# Iterate over the dictionary rather than all requests since not all
# requests have draft tokens.
@@ -1458,10 +1513,17 @@ class NPUModelRunner(LoRAModelRunnerMixin):
self.num_accepted_tokens.np[num_reqs:].fill(1)
self.num_accepted_tokens.copy_to_gpu()
# prepare pcp metadata
long_seq_metadata = self._generate_pcp_metadata(
total_num_scheduled_tokens, seq_lens_cpu)
# Prepare the attention metadata for each KV cache group and make layers
# in the same group share the same metadata.
for kv_cache_group_id, kv_cache_group_spec in enumerate(
self.kv_cache_config.kv_cache_groups):
slot_mapping_size = (total_num_scheduled_tokens
if self.pcp_size == 1 else
total_num_scheduled_tokens * self.pcp_size -
total_num_pcp_pads)
if isinstance(kv_cache_group_spec.kv_cache_spec,
EncoderOnlyAttentionSpec):
# Encoder-only layers do not have KV cache, so we need to
@@ -1479,13 +1541,24 @@ class NPUModelRunner(LoRAModelRunnerMixin):
else:
blk_table = self.input_batch.block_table[kv_cache_group_id]
blk_table_tensor = blk_table.get_device_tensor()
slot_mapping = blk_table.slot_mapping_cpu[:
total_num_scheduled_tokens]
self.slot_mapping[:total_num_scheduled_tokens].copy_(
slot_mapping[:total_num_scheduled_tokens],
slot_mapping = blk_table.slot_mapping_cpu[:slot_mapping_size]
self.slot_mapping[:slot_mapping_size].copy_(
slot_mapping[:slot_mapping_size],
non_blocking=True,
)
self.slot_mapping[total_num_scheduled_tokens:].fill_(0)
self.slot_mapping[slot_mapping_size:].fill_(0)
if self.pcp_size > 1:
assert pcp_unpad_mask is not None
pcp_padded_slot_mapping = self.pcp_padded_slot_mapping[:pcp_unpad_mask.shape[0]]
pcp_padded_slot_mapping.fill_(-1)
pcp_padded_slot_mapping[
pcp_unpad_mask] = self.slot_mapping[:slot_mapping_size]
self.slot_mapping[:long_seq_metadata.
num_actual_tokens_pcp_padded] = pcp_padded_slot_mapping
# Make AscendCommonAttentionMetadata
common_attn_metadata = AscendCommonAttentionMetadata(
@@ -1494,7 +1567,7 @@ class NPUModelRunner(LoRAModelRunnerMixin):
seq_lens_cpu=self.seq_lens_cpu,
seq_lens=self.seq_lens_cpu[:num_reqs],
num_reqs=num_reqs,
num_actual_tokens=total_num_scheduled_tokens,
num_actual_tokens=slot_mapping_size,
num_input_tokens=num_input_tokens,
actual_seq_lengths_q=self.actual_seq_lengths_q,
# TODO: change this to the right block table for linear attn
@@ -1512,6 +1585,7 @@ class NPUModelRunner(LoRAModelRunnerMixin):
decode_token_per_req=self.decode_token_per_req,
cos=self.cos,
sin=self.sin,
prefill_context_parallel_metadata=long_seq_metadata,
)
if self.speculative_config and \
@@ -1587,6 +1661,12 @@ class NPUModelRunner(LoRAModelRunnerMixin):
pad_size = get_forward_context().pad_size
if pad_size > 0:
hidden_states = hidden_states[:-pad_size, :]
if self.pcp_size > 1:
hidden_states = get_pcp_group().all_gather(hidden_states, 0)
hidden_states = torch.index_select(
hidden_states, 0,
self.pcp_allgather_restore_idx[:hidden_states.shape[0]])
return hidden_states
def _build_attn_state(self, num_reqs, num_scheduled_tokens,
@@ -2485,8 +2565,10 @@ class NPUModelRunner(LoRAModelRunnerMixin):
def profile_run(self) -> None:
# Trigger compilation for general shape.
with self.set_in_profile_run():
hidden_states = self._dummy_run(self.max_num_tokens,
with_prefill=True)
hidden_states = self._dummy_run(
self.max_num_tokens //
self.pcp_size if self.pcp_size > 1 else self.max_num_tokens,
with_prefill=True)
# MC2 will consume additional NPU memory.
# Therefore, we need to run the MC2 path once here to complete its initialization,
# allowing vLLM to correctly estimate the maximum memory required.
@@ -3620,3 +3702,236 @@ class NPUModelRunner(LoRAModelRunnerMixin):
def _build_drafter_prepare_inputs_torchair_param(self):
return False
def _update_tokens_for_pcp(self, tokens):
num_reqs = self.input_batch.num_reqs
self.num_pcp_pads = self.num_pcp_pads[:num_reqs]
if not self.pcp_size > 1:
return tokens, None, None
tokens = np.array(tokens, dtype=np.int32)
num_decode_reqs = sum(
self.input_batch.num_computed_tokens_cpu[:num_reqs] >=
self.input_batch.num_prompt_tokens[:num_reqs])
num_padded_scheduled_tokens = np.ceil(
tokens /
(2 * self.pcp_size)).astype(np.int32) * (2 * self.pcp_size)
num_padded_scheduled_tokens[:num_decode_reqs] = self.pcp_size
self.num_pcp_pads = num_padded_scheduled_tokens - tokens
cu_padded_tokens, pcp_padded_arange = \
self._get_cumsum_and_arange(num_padded_scheduled_tokens)
unpad_mask = torch.from_numpy(
pcp_padded_arange < np.repeat(tokens, num_padded_scheduled_tokens))
pcp_tokens = num_padded_scheduled_tokens // self.pcp_size
pcp_chunk_sizes = (pcp_tokens // 2).clip(min=1)
_, pcp_arange = self._get_cumsum_and_arange(pcp_tokens)
_, pcp_chunk_arange = self._get_cumsum_and_arange(pcp_chunk_sizes)
pcp_head_chunk_mask = pcp_arange < np.repeat(pcp_chunk_sizes,
pcp_tokens)
def get_current_rank_positions(cu_tokens, rank):
positions_start_loc = np.zeros_like(cu_tokens)
positions_start_loc[1:] = cu_tokens[:-1]
positions = np.zeros(len(pcp_head_chunk_mask), dtype=np.int32)
head_start_loc = positions_start_loc + rank * pcp_chunk_sizes
tail_start_loc = positions_start_loc + \
(2 * self.pcp_size - rank - 1) * pcp_chunk_sizes
positions[pcp_head_chunk_mask] = pcp_chunk_arange + \
np.repeat(head_start_loc, pcp_chunk_sizes)
# Decode reqs do not have tail chunks.
positions[~pcp_head_chunk_mask] = \
pcp_chunk_arange[num_decode_reqs:] + \
np.repeat(tail_start_loc, pcp_chunk_sizes)[num_decode_reqs:]
return positions
positions = get_current_rank_positions(
np.zeros(num_reqs, dtype=np.int32), self.pcp_rank)
# Decode tokens are duplicated across pcp ranks and their positions are always 0.
positions[:num_decode_reqs] = 0
all_positions = [
get_current_rank_positions(cu_padded_tokens, rank_i)
for rank_i in range(self.pcp_size)
]
all_positions_tensor = torch.from_numpy(np.concatenate(all_positions))
self.pcp_allgather_restore_idx[:all_positions_tensor.shape[0]].copy_(
all_positions_tensor.float().argsort().long(), non_blocking=True)
pcp_tokens[:num_decode_reqs] = 1
return pcp_tokens, positions, unpad_mask
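_update_tokens_for_pcp pads every prefill request to a multiple of 2 * pcp_size tokens and assigns rank r the chunk pair (r, 2 * pcp_size - 1 - r), so the causally masked work is balanced across ranks (decode requests are simply duplicated). A tiny illustrative sketch, not repo code:

pcp_size = 4
assignment = {rank: (rank, 2 * pcp_size - 1 - rank) for rank in range(pcp_size)}
print(assignment)  # {0: (0, 7), 1: (1, 6), 2: (2, 5), 3: (3, 4)}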
def _get_pcp_local_seq_lens(
self,
seq_lens: torch.Tensor,
pcp_world_size: int = 1,
dcp_world_size: int = 1,
cp_kv_cache_interleave_size: int = 1,
) -> torch.Tensor:
"""While using pcp or dcp, kv_cache size stored on each rank may be different,
use this function to calculate split decode seq_lens of each (p/d)cp rank.
"""
num_requests = seq_lens.size(0)
total_world_size = pcp_world_size * dcp_world_size
seq_lens_tiled = seq_lens.unsqueeze(-1).repeat(1, total_world_size)
rank_offsets = (torch.arange(total_world_size,
dtype=torch.int32).unsqueeze(0).repeat(
num_requests, 1))
base = (seq_lens_tiled // cp_kv_cache_interleave_size //
total_world_size * cp_kv_cache_interleave_size)
remainder = seq_lens_tiled - base * total_world_size
remainder = torch.clip(
remainder - rank_offsets * cp_kv_cache_interleave_size,
0,
cp_kv_cache_interleave_size,
)
dcp_local_seq_lens = (base + remainder).reshape(
[-1, pcp_world_size, dcp_world_size])
return dcp_local_seq_lens
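For illustration (assumed values, not from the diff): a decode request with 10 stored tokens, pcp_world_size = 2, dcp_world_size = 2 and cp_kv_cache_interleave_size = 1 is split by the same arithmetic as _get_pcp_local_seq_lens as follows:

import torch

seq_lens = torch.tensor([10], dtype=torch.int32)
pcp_ws, dcp_ws, interleave = 2, 2, 1
world = pcp_ws * dcp_ws

tiled = seq_lens.unsqueeze(-1).repeat(1, world)
ranks = torch.arange(world, dtype=torch.int32).unsqueeze(0)
base = tiled // interleave // world * interleave
remainder = torch.clip(tiled - base * world - ranks * interleave, 0, interleave)
print((base + remainder).reshape(-1, pcp_ws, dcp_ws))
# tensor([[[3, 3], [2, 2]]]): the dcp ranks under pcp rank 0 hold 3 tokens each,
# those under pcp rank 1 hold 2 each (3 + 3 + 2 + 2 = 10).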
def _generate_pcp_metadata(self, total_num_scheduled_tokens, seq_lens):
num_reqs = self.input_batch.num_reqs
num_decodes = sum(self.input_batch.num_computed_tokens_cpu[:num_reqs]
>= self.input_batch.num_prompt_tokens[:num_reqs])
num_actual_tokens_pcp_padded = total_num_scheduled_tokens * self.pcp_size
num_prefills = num_reqs - num_decodes
long_seq_metadata = None
if self.pcp_size * self.dcp_size > 1:
long_seq_metadata = AscendPrefillContextParallelMetadata(
num_actual_tokens_pcp_padded=num_actual_tokens_pcp_padded,
num_computed_tokens_of_pcp_dcp=self._get_pcp_local_seq_lens(
seq_lens,
self.pcp_size,
self.dcp_size,
self.parallel_config.cp_kv_cache_interleave_size,
).numpy(),
)
if self.pcp_size > 1:
q_head_idx, q_tail_idx = [], []
kv_with_q_head_nomask_idx, kv_with_q_head_mask_idx = [], []
kv_with_q_tail_nomask_idx, kv_with_q_tail_mask_idx = [], []
chunk_seqlens = []
kv_with_q_head_nomask_seqlens, kv_with_q_tail_nomask_seqlens = [], []
q_req_offset = 0
kv_req_offset = 0
q_head_chunk_id = self.pcp_rank
q_tail_chunk_id = self.pcp_size * 2 - 1 - self.pcp_rank
for i, seq_len in enumerate(seq_lens):
if i < num_decodes:
continue
chunk_len = seq_len // 2
chunk_seqlens.append(chunk_len)
q_head_idx.extend(
list(range(q_req_offset, q_req_offset + chunk_len)))
kv_with_q_head_nomask_idx.extend(
list(
range(kv_req_offset, kv_req_offset +
chunk_len * q_head_chunk_id)))
kv_with_q_head_mask_idx.extend(
list(
range(
kv_req_offset + chunk_len * q_head_chunk_id,
kv_req_offset + chunk_len *
(q_head_chunk_id + 1))))
kv_with_q_head_nomask_seqlens.append(chunk_len *
q_head_chunk_id)
q_tail_idx.extend(
list(
range(q_req_offset + chunk_len,
q_req_offset + chunk_len * 2)))
kv_with_q_tail_nomask_idx.extend(
list(
range(kv_req_offset, kv_req_offset +
chunk_len * q_tail_chunk_id)))
kv_with_q_tail_mask_idx.extend(
list(
range(
kv_req_offset + chunk_len * q_tail_chunk_id,
kv_req_offset + chunk_len *
(q_tail_chunk_id + 1))))
kv_with_q_tail_nomask_seqlens.append(chunk_len *
q_tail_chunk_id)
q_req_offset += seq_len
kv_req_offset += seq_len * self.pcp_size
# Convert lists to tensors and move to device
def _list_to_tensor(lst, device, dtype=torch.int32):
tensor_npu = torch.zeros(len(lst),
dtype=dtype,
device=device)
tensor_npu.copy_(torch.tensor(lst, dtype=dtype),
non_blocking=True)
return tensor_npu
q_head_idx_tensor = _list_to_tensor(q_head_idx, self.device)
q_tail_idx_tensor = _list_to_tensor(q_tail_idx, self.device)
self.q_head_idx_tensor = q_head_idx_tensor
self.q_tail_idx_tensor = q_tail_idx_tensor
q_full_idx = torch.cat([q_head_idx_tensor, q_tail_idx_tensor])
q_full_idx = q_full_idx.to(torch.float32).argsort().to(
torch.int32)
self.q_full_idx = q_full_idx
self.kv_idx_names = {
'kv_with_q_head_nomask_idx_tensor':
kv_with_q_head_nomask_idx,
'kv_with_q_head_mask_idx_tensor': kv_with_q_head_mask_idx,
'kv_with_q_tail_nomask_idx_tensor':
kv_with_q_tail_nomask_idx,
'kv_with_q_tail_mask_idx_tensor': kv_with_q_tail_mask_idx
}
for key, value in self.kv_idx_names.items():
tensor_npu = _list_to_tensor(value, self.device)
self.kv_idx_names[key] = tensor_npu
attn_mask_seqlens = torch.tensor(
[chunk_seqlens, chunk_seqlens], dtype=torch.int32)
head_attn_nomask_seqlens = torch.tensor(
[chunk_seqlens, kv_with_q_head_nomask_seqlens],
dtype=torch.int32)
tail_attn_nomask_seqlens = torch.tensor(
[chunk_seqlens, kv_with_q_tail_nomask_seqlens],
dtype=torch.int32)
if self.vllm_config.model_config.use_mla:
pcp_prefill_mask = torch.triu(
torch.ones(512,
512,
device=self.device,
dtype=self.dtype), 1)
else:
max_seq_len = max(seq_lens, default=0)
pcp_prefill_mask = torch.triu(
torch.full((num_prefills, max_seq_len, max_seq_len),
True,
device=self.device,
dtype=torch.bool), 1)
self.extra_long_seq_kwargs = {
'attn_mask_seqlens': attn_mask_seqlens,
'head_attn_nomask_seqlens': head_attn_nomask_seqlens,
'tail_attn_nomask_seqlens': tail_attn_nomask_seqlens,
'pcp_prefill_mask': pcp_prefill_mask
}
long_seq_metadata.pcp_allgather_restore_idx = self.pcp_allgather_restore_idx[:
num_actual_tokens_pcp_padded]
long_seq_metadata.q_head_idx_tensor = self.q_head_idx_tensor
long_seq_metadata.q_tail_idx_tensor = self.q_tail_idx_tensor
long_seq_metadata.q_full_idx = self.q_full_idx
long_seq_metadata.kv_with_q_head_nomask_idx_tensor = self.kv_idx_names[
'kv_with_q_head_nomask_idx_tensor']
long_seq_metadata.kv_with_q_head_mask_idx_tensor = self.kv_idx_names[
'kv_with_q_head_mask_idx_tensor']
long_seq_metadata.kv_with_q_tail_nomask_idx_tensor = self.kv_idx_names[
'kv_with_q_tail_nomask_idx_tensor']
long_seq_metadata.kv_with_q_tail_mask_idx_tensor = self.kv_idx_names[
'kv_with_q_tail_mask_idx_tensor']
long_seq_metadata.attn_mask_seqlens = self.extra_long_seq_kwargs[
'attn_mask_seqlens']
long_seq_metadata.head_attn_nomask_seqlens = self.extra_long_seq_kwargs[
'head_attn_nomask_seqlens']
long_seq_metadata.tail_attn_nomask_seqlens = self.extra_long_seq_kwargs[
'tail_attn_nomask_seqlens']
long_seq_metadata.pcp_prefill_mask = self.extra_long_seq_kwargs[
'pcp_prefill_mask']
return long_seq_metadata
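The head/tail index bookkeeping above drives the masked/no-mask split consumed by _attention_with_mask_and_nomask: for pcp rank r, the head query chunk (chunk id r) attends causally to KV chunk r and without mask to KV chunks 0..r-1, while the tail query chunk (chunk id 2 * pcp_size - 1 - r) attends causally to its own chunk and without mask to all earlier chunks. A minimal sketch with assumed values:

pcp_size, rank = 4, 2
head_chunk = rank
tail_chunk = 2 * pcp_size - 1 - rank
print(head_chunk, list(range(head_chunk)))  # 2 [0, 1]         (masked chunk, no-mask chunks)
print(tail_chunk, list(range(tail_chunk)))  # 5 [0, 1, 2, 3, 4]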

View File

@@ -94,19 +94,21 @@ class CachedRequestState:
class InputBatch:
def __init__(
self,
max_num_reqs: int,
max_model_len: int,
max_num_batched_tokens: int,
device: torch.device,
pin_memory: bool,
vocab_size: int,
block_sizes: list[int], # The block_size of each kv cache group
logitsprocs: Optional[LogitsProcessors] = None,
is_spec_decode: bool = False,
is_pooling_model: bool = False,
num_speculative_tokens: int = 0,
kernel_block_sizes: Optional[list[list[int]]] = None):
self,
max_num_reqs: int,
max_model_len: int,
max_num_batched_tokens: int,
device: torch.device,
pin_memory: bool,
vocab_size: int,
block_sizes: list[int], # The block_size of each kv cache group
logitsprocs: Optional[LogitsProcessors] = None,
is_spec_decode: bool = False,
is_pooling_model: bool = False,
num_speculative_tokens: int = 0,
kernel_block_sizes: Optional[list[list[int]]] = None,
cp_kv_cache_interleave_size: int = 1,
):
self.is_pooling_model = is_pooling_model
self.is_spec_decode = is_spec_decode
self.max_num_reqs = max_num_reqs
@@ -151,7 +153,9 @@ class InputBatch:
device=device,
block_sizes=block_sizes,
num_speculative_tokens=num_speculative_tokens,
kernel_sizes=kernel_block_sizes)
kernel_sizes=kernel_block_sizes,
cp_kv_cache_interleave_size=cp_kv_cache_interleave_size,
)
# Sampling-related.
self.temperature = torch.empty((max_num_reqs, ),

View File

@@ -49,6 +49,7 @@ from vllm_ascend.device_allocator.camem import CaMemAllocator
from vllm_ascend.distributed.parallel_state import init_ascend_model_parallel
from vllm_ascend.platform import NPUPlatform
from vllm_ascend.utils import (init_ascend_soc_version,
prefill_context_parallel_enable,
register_ascend_customop, sleep_mode_enabled,
try_register_lib)
from vllm_ascend.worker.model_runner_v1 import NPUModelRunner
@@ -381,9 +382,17 @@ class NPUWorker(WorkerBase):
init_distributed_environment(self.parallel_config.world_size,
self.rank, self.distributed_init_method,
self.local_rank, "hccl")
ensure_model_parallel_initialized(
self.parallel_config.tensor_parallel_size,
self.parallel_config.pipeline_parallel_size)
if prefill_context_parallel_enable():
ensure_model_parallel_initialized(
self.parallel_config.tensor_parallel_size,
self.parallel_config.pipeline_parallel_size,
self.parallel_config.prefill_context_parallel_size,
self.parallel_config.decode_context_parallel_size)
else:
ensure_model_parallel_initialized(
self.parallel_config.tensor_parallel_size,
self.parallel_config.pipeline_parallel_size,
self.parallel_config.decode_context_parallel_size)
init_ascend_model_parallel(self.parallel_config)
ensure_kv_transfer_initialized(self.vllm_config)