[feature] support pcp + mtp in full graph (#4572)

1. Support pcp + mtp in full graph mode (see the configuration sketch below).
2. Fix pcp/dcp-related mtp bugs.
3. Support pcp + mtpx (mtp with num_speculative_tokens > 1).

- vLLM version: v0.12.0
- vLLM main: ad32e3e19c

Signed-off-by: zhangsicheng5 <zhangsicheng5@huawei.com>
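As a minimal sketch of the configuration this PR enables (not part of the diff; it condenses the new e2e test file added in this commit, with some of the test's auxiliary arguments omitted):

import pytest

from tests.e2e.conftest import VllmRunner

# pcp + dcp + mtp (3 speculative tokens) with full-graph decode capture;
# values mirror tests/e2e/multicard/long_sequence/test_mtp.py.
with VllmRunner(
        "wemaster/deepseek_mtp_main_random_bf16",
        max_model_len=1024,
        tensor_parallel_size=2,
        prefill_context_parallel_size=2,
        decode_context_parallel_size=2,
        speculative_config={
            "num_speculative_tokens": 3,
            "method": "deepseek_mtp",
        },
        compilation_config={
            "cudagraph_mode": "FULL_DECODE_ONLY",
            "cudagraph_capture_sizes": [4, 8, 16],
        },
) as runner:
    runner.generate_greedy(["The capital of France is"], 32)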
Changed file: .github/workflows/_e2e_test.yaml (vendored), 1 line added
@@ -269,6 +269,7 @@ jobs:
           pytest -sv --durations=0 tests/e2e/multicard/test_data_parallel_tp2.py
           pytest -sv --durations=0 tests/e2e/multicard/long_sequence/test_basic.py
           pytest -sv --durations=0 tests/e2e/multicard/long_sequence/test_accuracy.py
+          pytest -sv --durations=0 tests/e2e/multicard/long_sequence/test_mtp.py

       - name: Install Ascend toolkit & triton_ascend (for Qwen3-Next-80B-A3B-Instruct)
         shell: bash -l {0}
New file: tests/e2e/multicard/long_sequence/test_mtp.py (165 lines)
@@ -0,0 +1,165 @@
#
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
# Copyright 2023 The vLLM team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# This file is a part of the vllm-ascend project.
# Adapted from vllm/tests/basic_correctness/test_basic_correctness.py
#

import os

import pytest

from tests.e2e.conftest import VllmRunner
from vllm_ascend.utils import vllm_version_is

os.environ["HCCL_BUFFSIZE"] = "512"


@pytest.mark.skipif(vllm_version_is('0.12.0'),
                    reason="0.12.0 is not supported for context sequence.")
def test_pcp_dcp_mtp1_eager():
    prompts = [
        "The capital of France is", "Hello, my name is Tom, I am",
        "The president of United States is", "AI future is"
    ]
    model = "wemaster/deepseek_mtp_main_random_bf16"
    with VllmRunner(
            model,
            max_model_len=1024,
            tensor_parallel_size=2,
            prefill_context_parallel_size=2,
            decode_context_parallel_size=2,
            max_num_batched_tokens=1024,
            enable_expert_parallel=True,
            block_size=128,
            speculative_config={
                "num_speculative_tokens": 1,
                "method": "deepseek_mtp",
            },
            enforce_eager=True,
    ) as runner:
        runner.generate_greedy(prompts, 32)


@pytest.mark.skipif(vllm_version_is('0.12.0'),
                    reason="0.12.0 is not supported for context sequence.")
def test_pcp_dcp_mtp3_eager():
    prompts = [
        "The capital of France is", "Hello, my name is Tom, I am",
        "The president of United States is", "AI future is"
    ]
    model = "wemaster/deepseek_mtp_main_random_bf16"
    with VllmRunner(
            model,
            max_model_len=1024,
            tensor_parallel_size=2,
            prefill_context_parallel_size=2,
            decode_context_parallel_size=2,
            max_num_batched_tokens=1024,
            enable_expert_parallel=True,
            block_size=128,
            speculative_config={
                "num_speculative_tokens": 3,
                "method": "deepseek_mtp",
            },
            enforce_eager=True,
    ) as runner:
        runner.generate_greedy(prompts, 32)


@pytest.mark.skipif(vllm_version_is('0.12.0'),
                    reason="0.12.0 is not supported for context sequence.")
def test_pcp_dcp_mtp3_piecewise_graph():
    prompts = [
        "The capital of France is", "Hello, my name is Tom, I am",
        "The president of United States is", "AI future is"
    ]
    model = "wemaster/deepseek_mtp_main_random_bf16"
    with VllmRunner(
            model,
            max_model_len=1024,
            tensor_parallel_size=2,
            prefill_context_parallel_size=2,
            decode_context_parallel_size=2,
            max_num_batched_tokens=1024,
            enable_expert_parallel=True,
            block_size=128,
            speculative_config={
                "num_speculative_tokens": 3,
                "method": "deepseek_mtp",
            },
            compilation_config={
                "cudagraph_mode": "PIECEWISE",
                "cudagraph_capture_sizes": [4, 8, 16],
            },
    ) as runner:
        runner.generate_greedy(prompts, 32)


@pytest.mark.skipif(vllm_version_is('0.12.0'),
                    reason="0.12.0 is not supported for context sequence.")
def test_pcp_dcp_mtp3_full_graph():
    prompts = [
        "The capital of France is", "Hello, my name is Tom, I am",
        "The president of United States is", "AI future is"
    ]
    model = "wemaster/deepseek_mtp_main_random_bf16"
    with VllmRunner(
            model,
            max_model_len=1024,
            tensor_parallel_size=2,
            prefill_context_parallel_size=2,
            decode_context_parallel_size=2,
            max_num_batched_tokens=1024,
            enable_expert_parallel=True,
            block_size=128,
            speculative_config={
                "num_speculative_tokens": 3,
                "method": "deepseek_mtp",
            },
            compilation_config={
                "cudagraph_mode": "FULL_DECODE_ONLY",
                "cudagraph_capture_sizes": [4, 8, 16],
            },
    ) as runner:
        runner.generate_greedy(prompts, 32)


@pytest.mark.skipif(vllm_version_is('0.12.0'),
                    reason="0.12.0 is not supported for context sequence.")
def test_dcp_mtp3_full_graph():
    prompts = [
        "The capital of France is", "Hello, my name is Tom, I am",
        "The president of United States is", "AI future is"
    ]
    model = "wemaster/deepseek_mtp_main_random_bf16"
    with VllmRunner(
            model,
            max_model_len=1024,
            tensor_parallel_size=2,
            decode_context_parallel_size=2,
            max_num_batched_tokens=1024,
            enable_expert_parallel=True,
            block_size=128,
            speculative_config={
                "num_speculative_tokens": 3,
                "method": "deepseek_mtp",
            },
            compilation_config={
                "cudagraph_mode": "FULL_DECODE_ONLY",
                "cudagraph_capture_sizes": [4, 8, 16],
            },
    ) as runner:
        runner.generate_greedy(prompts, 32)
@@ -11,6 +11,7 @@ from vllm.model_executor.models.deepseek_v2 import DeepseekV32IndexerCache
 from vllm.v1.attention.backends.utils import CommonAttentionMetadata
 from vllm.v1.core.sched.output import SchedulerOutput
 from vllm.v1.spec_decode.metadata import SpecDecodeMetadata
+from vllm.v1.utils import CpuGpuBuffer
 from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch

 from vllm_ascend.ascend_config import init_ascend_config
@@ -215,10 +216,23 @@ class TestMtpProposer:
         mock_deps.runner.input_ids = torch.arange(16, dtype=torch.int32)
         mock_deps.runner.spec_decode_common_attn_metadata = MagicMock()
         mock_deps.runner.pcp_size = 2
-        mock_deps.runner.input_ids_pcp_full = torch.arange(32,
-                                                           dtype=torch.int32)
-        mock_deps.runner.query_start_loc_pcp_full_cpu = torch.tensor(
-            [0, 8, 16, 24, 32])
+        mock_deps.runner.dcp_size = 1
+        mock_deps.runner.input_ids_pcp_full = CpuGpuBuffer(
+            32,
+            dtype=torch.int32,
+            pin_memory=False,
+            device='cpu',
+        )
+        mock_deps.runner.input_ids_pcp_full.cpu = \
+            torch.arange(32, dtype=torch.int32)
+        mock_deps.runner.query_start_loc_pcp_full = CpuGpuBuffer(
+            5,
+            dtype=torch.int32,
+            pin_memory=False,
+            device='cpu',
+        )
+        mock_deps.runner.query_start_loc_pcp_full.cpu = \
+            torch.tensor([0, 8, 16, 24, 32])
         mock_deps.positions = torch.arange(16, dtype=torch.int32)
         mock_deps.hidden_states = torch.zeros(16, 4096, dtype=torch.float16)
         mock_deps.sampled_token_ids = torch.tensor([[100, 101, -1],

@@ -232,6 +246,7 @@ class TestMtpProposer:
         proposer.speculative_config = MagicMock(
             disable_padded_drafter_batch=False)
         proposer.pcp_size = mock_deps.runner.pcp_size
+        proposer.dcp_size = mock_deps.runner.dcp_size
         proposer.prepare_next_token_ids_padded = MagicMock(
             return_value=(torch.tensor([101, 200, 302]), 3))
         proposer.prepare_inputs_padded = MagicMock(
@@ -50,6 +50,7 @@ def test_generate_pcp_metadata_basic(pcp_size, dcp_size, num_reqs, query_lens,

     mock_runner.input_batch = MagicMock()
     mock_runner.input_batch.num_reqs = num_reqs
+    mock_runner.speculative_config = None

     num_computed_tokens = []
     num_prompt_tokens = []
@@ -169,23 +170,24 @@ def test_pcp_allgather_restore_idx_slicing():


 @pytest.mark.parametrize(
-    "tokens, num_reqs, num_computed_tokens, num_prompt_tokens, pcp_size, pcp_rank, expected_pcp_tokens",
+    "tokens, num_reqs, num_computed_tokens, num_prompt_tokens," \
+    "pcp_size, pcp_rank, decode_threshold, expected_pcp_tokens",
     [
         # Case 1: prefill only
-        ([8, 12, 16], 3, [0, 0, 0], [8, 12, 16], 4, 0, [2, 4, 4]),
+        ([8, 12, 16], 3, [0, 0, 0], [8, 12, 16], 4, 0, 1, [2, 4, 4]),

-        # Case 2: mix prefill and decode
-        ([8, 4, 12], 3, [8, 4, 0], [8, 4, 12], 4, 0, [8, 4, 4]),
+        # Case 2: mix prefill and decode (with spec decode)
+        ([8, 4, 12], 3, [8, 4, 0], [8, 4, 12], 4, 0, 8, [8, 4, 4]),

         # Case 3: request which need to be padded
-        ([3, 7, 9], 3, [0, 0, 0], [3, 7, 9], 4, 0, [2, 2, 4]),
+        ([3, 7, 9], 3, [0, 0, 0], [3, 7, 9], 4, 0, 1, [2, 2, 4]),

         # Case 4: single request
-        ([10], 1, [0], [10], 4, 0, [4]),
+        ([10], 1, [0], [10], 4, 0, 1, [4]),
     ])
 def test_update_tokens_for_pcp_basic(tokens, num_reqs, num_computed_tokens,
-                                     num_prompt_tokens, pcp_size, pcp_rank,
-                                     expected_pcp_tokens):
+                                     num_prompt_tokens, pcp_size, pcp_rank,
+                                     decode_threshold, expected_pcp_tokens):
     mock_runner = MagicMock(spec=NPUModelRunner)
     mock_runner.pcp_size = pcp_size
     mock_runner.pcp_rank = pcp_rank
@@ -201,6 +203,7 @@ def test_update_tokens_for_pcp_basic(tokens, num_reqs, num_computed_tokens,

     mock_runner.num_pcp_pads = [0] * num_reqs
     mock_runner.arange_np = np.arange(10000)
+    mock_runner.decode_threshold = decode_threshold

     mock_runner._update_tokens_for_pcp = NPUModelRunner._update_tokens_for_pcp.__get__(
         mock_runner, NPUModelRunner)
@@ -243,6 +246,7 @@ def test_update_tokens_for_pcp_with_padding():

     mock_runner.num_pcp_pads = [0, 0, 0]
     mock_runner.pcp_allgather_restore_idx = torch.zeros(1000, dtype=torch.long)
+    mock_runner.decode_threshold = 1

     mock_runner._update_tokens_for_pcp = NPUModelRunner._update_tokens_for_pcp.__get__(
         mock_runner, NPUModelRunner)
@@ -279,6 +283,7 @@ def test_update_tokens_for_pcp_unpad_mask():

     mock_runner.num_pcp_pads = [0, 0]
     mock_runner.pcp_allgather_restore_idx = torch.zeros(1000, dtype=torch.long)
+    mock_runner.decode_threshold = 1

     mock_runner._update_tokens_for_pcp = NPUModelRunner._update_tokens_for_pcp.__get__(
         mock_runner, NPUModelRunner)
@@ -369,6 +374,9 @@ def pcp_mtp_mock_runner():

     mock_runner.input_ids_pcp_full = NPUModelRunner._make_buffer(
         mock_runner, max_num_tokens, dtype=torch.int32)
+    mock_runner.query_lens_pcp_full = NPUModelRunner._make_buffer(
+        mock_runner, max_num_reqs, dtype=torch.int32)
+    mock_runner.decode_threshold = 1

     mock_runner.arange_np = np.arange(max_model_len)
     mock_runner.input_batch = MagicMock()
@@ -27,6 +27,7 @@ from vllm_ascend.attention.utils import (AscendCommonAttentionMetadata,
                                          split_decodes_and_prefills,
                                          wait_for_kv_layer_from_connector)
 from vllm_ascend.compilation.acl_graph import (get_graph_params,
+                                               get_mtp_graph_params,
                                                update_graph_params_workspaces)
 from vllm_ascend.ops.rotary_embedding import get_cos_and_sin_mla
 from vllm_ascend.ops.shared_weight_layer import (
@@ -92,6 +93,10 @@ class AscendMlaCPMetadataBuilder(AscendMLAMetadataBuilder):
         num_actual_tokens_pcp_padded = long_seq_metadata.num_actual_tokens_pcp_padded
         if num_actual_tokens_pcp_padded is None:
             num_actual_tokens_pcp_padded = num_actual_tokens
+        # In dcp only spec decode graph padding case,
+        # num_actual_tokens_pcp_padded may be less than num_actual_tokens
+        num_actual_tokens_pcp_padded = max(num_actual_tokens_pcp_padded,
+                                           num_actual_tokens)
         num_computed_tokens_of_pcp_dcp = long_seq_metadata.num_computed_tokens_of_pcp_dcp
         assert num_computed_tokens_of_pcp_dcp is not None
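A standalone numeric sketch (not part of the diff) of the clamp added above: in the dcp-only spec decode case, graph padding can leave the padded count smaller than the live token count, so the larger of the two is taken.

# Illustrative numbers only: graph was padded for 8 tokens, 12 are live.
num_actual_tokens = 12
num_actual_tokens_pcp_padded = 8
num_actual_tokens_pcp_padded = max(num_actual_tokens_pcp_padded,
                                   num_actual_tokens)
assert num_actual_tokens_pcp_padded == 12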
@@ -113,15 +118,6 @@ class AscendMlaCPMetadataBuilder(AscendMLAMetadataBuilder):
                 common_attn_metadata.block_table_tensor[:graph_pad_size])
         else:
             block_table = (common_attn_metadata.block_table_tensor[:num_reqs])
-            # NOTE: Currently, MTP-fullgraph is incompatibility pcp
-            if self.pcp_size > 1:
-                num_decodes_flatten = num_decodes * self.decode_threshold
-                block_table = common_attn_metadata.block_table_tensor[:
-                                                                      num_decodes_flatten
-                                                                      +
-                                                                      num_prefills]
-
-        # NOTE: Currently, MTP-fullgraph is incompatibility pcp
         slot_mapping = common_attn_metadata.slot_mapping[:
                                                          num_actual_tokens_pcp_padded]
         input_positions = common_attn_metadata.positions[:
@@ -144,6 +140,13 @@ class AscendMlaCPMetadataBuilder(AscendMLAMetadataBuilder):
         seq_lens = common_attn_metadata.seq_lens_cpu[:num_reqs]
         num_computed_tokens_cpu = (seq_lens - query_lens)

+        # For pcp + spec decode, we flatten seq_lens and block_table
+        # to avoid irregular spec_attn_mask shape
+        num_decodes_flatten = query_lens[:num_decodes].sum().item()
+        block_table = common_attn_metadata.block_table_tensor[:
+                                                              num_decodes_flatten
+                                                              + num_prefills]
+
         prefill_metadata = None
         chunked_context_metadata = None
         if num_prefills > 0:
@@ -201,7 +204,7 @@ class AscendMlaCPMetadataBuilder(AscendMLAMetadataBuilder):
                     dtype=torch.int32)

                 local_context_lens_allranks = torch.tensor(
-                    num_computed_tokens_of_pcp_dcp[reqs_start:num_reqs]
+                    num_computed_tokens_of_pcp_dcp[num_decodes_flatten:]
                 ).reshape(-1, self.dcp_size * self.pcp_size)
                 # Note(qcs): The max local context lengths
                 # padded to `cp_local_block_size`.
@@ -280,9 +283,8 @@ class AscendMlaCPMetadataBuilder(AscendMLAMetadataBuilder):
                 cos=cos,
                 pcp_metadata=pcp_metadata,
             )
             if self.pcp_size > 1:
-                prefill_metadata.block_table = block_table[
-                    num_decodes_flatten:, ...]
+                prefill_metadata.block_table = \
+                    block_table[num_decodes_flatten:, ...]

         decode_metadata = None
         if num_decodes > 0:
@@ -293,13 +295,7 @@ class AscendMlaCPMetadataBuilder(AscendMLAMetadataBuilder):
             max_seq_lens = seq_lens[:num_decodes].max().item()
             seq_lens = seq_lens[:num_decodes]
             input_positions = input_positions[:num_decode_tokens]
-            if self.pcp_size > 1:
-                # For pcp + spec decode, we flatten seq_lens and block_table
-                # to avoid irregular spec_attn_mask shape
-                block_table = block_table[:num_decodes_flatten, ...]
-            else:
-                block_table = block_table[:num_decodes, ...]
-            # NOTE: Currently, MTP-fullgraph is incompatibility pcp
+            block_table = block_table[:num_decodes_flatten, ...]
             # NOTE: Maybe this block_table change can be removed when graph_pad_size > 1.
             if graph_pad_size > num_decodes and \
                 self.speculative_config.disable_padded_drafter_batch:
@@ -308,8 +304,7 @@ class AscendMlaCPMetadataBuilder(AscendMLAMetadataBuilder):

             # [bs, pcp_size, dcp_size]
             num_computed_tokens_of_cp_dcp_array = np.array(
-                num_computed_tokens_of_pcp_dcp)[:num_decodes *
-                                                self.decode_threshold]
+                num_computed_tokens_of_pcp_dcp)[:num_decodes_flatten]

             cp_seq_len = num_computed_tokens_of_cp_dcp_array[:, self.pcp_rank,
                                                              self.dcp_rank]
@@ -1057,8 +1052,11 @@ class AscendMlaCPImpl(AscendMLAImpl):
             "return_lse": True,
             "calc_type": "calc_type_ring",
         }
-        graph_params = get_graph_params()
         forward_context: ForwardContext = get_forward_context()
+        if forward_context.is_mtp_model:
+            graph_params = get_mtp_graph_params()
+        else:
+            graph_params = get_graph_params()
         if forward_context.capturing:
             stream = torch_npu.npu.current_stream()
             event = torch.npu.ExternalEvent()
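A standalone sketch (not part of the diff) of the dispatch pattern introduced here: the draft (MTP) model records and replays its graph attention parameters in a registry separate from the main model's, selected by a flag on the forward context. The dataclass and names below are hypothetical stand-ins for the real get_graph_params / get_mtp_graph_params.

from dataclasses import dataclass, field

@dataclass
class GraphParams:
    # Per-layer attention args captured during graph capture.
    attn_args: dict = field(default_factory=dict)

_MAIN_PARAMS = GraphParams()
_MTP_PARAMS = GraphParams()

def select_graph_params(is_mtp_model: bool) -> GraphParams:
    # Mirrors: get_mtp_graph_params() if is_mtp_model else get_graph_params()
    return _MTP_PARAMS if is_mtp_model else _MAIN_PARAMS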
@@ -67,6 +67,12 @@ class AscendPrefillContextParallelMetadata:

    pcp_prefill_mask: torch.Tensor = None

+    # original query_lens before pcp split
+    query_lens_pcp_full_cpu: torch.Tensor = None
+
+    # original max_query_len before pcp split
+    max_query_len_pcp_full: int = 0
+

 @dataclass
 class AscendCommonAttentionMetadata:
@@ -189,6 +195,8 @@ def split_decodes_and_prefills(
     """
     Assuming a reordered batch, finds the boundary between prefill and decode
     requests.
+    While pcp > 1, query_lens is split across pcp ranks, so we pass in the
+    original query_lens and max_query_len to distinguish prefills and decodes.

     Args:
         common_attn_metadata: AscendCommonAttentionMetadata object containing the
@@ -201,7 +209,13 @@ def split_decodes_and_prefills(
         num_decode_tokens: The number of tokens in the decode requests.
         num_prefill_tokens: The number of tokens in the prefill requests.
     """
-    max_query_len = common_attn_metadata.max_query_len
+    long_seq_metadata = common_attn_metadata.prefill_context_parallel_metadata
+    query_lens_pcp_full = long_seq_metadata.query_lens_pcp_full_cpu \
+        if long_seq_metadata else None
+    max_query_len_pcp_full = long_seq_metadata.max_query_len_pcp_full \
+        if long_seq_metadata else 0
+    max_query_len = common_attn_metadata.max_query_len \
+        if max_query_len_pcp_full == 0 else max_query_len_pcp_full
     num_reqs = common_attn_metadata.num_reqs
     num_tokens = common_attn_metadata.num_actual_tokens
     query_start_loc = common_attn_metadata.query_start_loc_cpu
@@ -209,7 +223,8 @@ def split_decodes_and_prefills(
     if max_query_len <= decode_threshold:
         return num_reqs, 0, num_tokens, 0

-    query_lens = query_start_loc[1:] - query_start_loc[:-1]
+    query_lens = (query_start_loc[1:] - query_start_loc[:-1]) \
+        if query_lens_pcp_full is None else query_lens_pcp_full
     is_prefill = query_lens > decode_threshold
     if not torch.any(is_prefill):
         return num_reqs, 0, num_tokens, 0
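A standalone sketch (not part of the diff) of why the split needs the pre-split lengths; the numbers are illustrative. With pcp enabled, a prefill's tokens are divided across pcp ranks, so rank-local query lengths can fall to decode_threshold and be misclassified; the full lengths carried in the metadata stay unambiguous.

import torch

decode_threshold = 4  # 1 + num_speculative_tokens = 1 + 3
# Local view on one pcp rank: a decode request (4 tokens) and a
# prefill request whose 16 tokens were split to 4 local tokens (pcp_size=4).
query_start_loc = torch.tensor([0, 4, 8])
local_query_lens = query_start_loc[1:] - query_start_loc[:-1]   # [4, 4]
# Pre-split lengths from the metadata: [4, 16].
query_lens_pcp_full = torch.tensor([4, 16])
# The local lengths would classify both as decodes; the full ones do not.
is_prefill = query_lens_pcp_full > decode_threshold             # [False, True]
print(is_prefill.tolist())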
@@ -440,7 +440,10 @@ def update_attn_dcp_pcp_params(update_stream, forward_context, runtime_shape):

 def update_mla_attn_dcp_pcp_params(update_stream, forward_context,
                                    runtime_shape):
-    graph_params = get_graph_params()
+    if forward_context.is_mtp_model:
+        graph_params = get_mtp_graph_params()
+    else:
+        graph_params = get_graph_params()
     # FIXME: Behold! We are using a temporary hack here to update the args
     # for each layer's attention op in the graph.
     with torch.npu.stream(update_stream):
@@ -32,6 +32,7 @@ from vllm_ascend.ascend_forward_context import set_ascend_forward_context
 from vllm_ascend.attention.attention_v1 import AscendAttentionState
 from vllm_ascend.attention.utils import AscendCommonAttentionMetadata
 from vllm_ascend.compilation.acl_graph import (ACLGraphWrapper,
+                                               update_mla_attn_dcp_pcp_params,
                                                update_mla_attn_params)
 from vllm_ascend.spec_decode.interface import Proposer, SpecDcodeType
 from vllm_ascend.utils import (ProfileExecuteDuration, lmhead_tp_enable,
@@ -98,6 +99,7 @@ class MtpProposer(Proposer):
         self.pcp_size = self.runner.pcp_size
         self.dcp_size = self.runner.dcp_size
         self.pcp_rank = self.runner.pcp_rank
+        self.dcp_rank = self.runner.dcp_rank

         self.attn_metadata_builder: Optional[AttentionMetadataBuilder] = None
         self.draft_indexer_metadata_builder: Optional[
@@ -267,6 +269,13 @@ class MtpProposer(Proposer):
             attn_state=self.runner.attn_state,
             decode_token_per_req=self.runner.decode_token_per_req,
         )
+        if self.pcp_size * self.dcp_size > 1:
+            # update long_seq related params and flatten block_table
+            common_attn_metadata.prefill_context_parallel_metadata = \
+                self.runner.long_seq_metadata
+            common_attn_metadata.block_table_tensor = \
+                self.runner.input_batch.block_table[0].get_device_tensor()[
+                    :num_reqs * self.decode_threshold]

         builder = self.runner.attn_groups[0][0].get_metadata_builder()
         attn_metadata_mtp = builder.build_for_graph_capture(
@@ -310,9 +319,15 @@ class MtpProposer(Proposer):
         if forward_context.cudagraph_runtime_mode == CUDAGraphMode.FULL and \
                 not forward_context.capturing:
             if self.vllm_config.model_config.use_mla and not self.use_sparse:
-                update_mla_attn_params(
-                    self.update_stream, forward_context, num_tokens,
-                    self.vllm_config.speculative_config)
+                if self.pcp_size * self.dcp_size > 1:
+                    update_mla_attn_dcp_pcp_params(
+                        self.update_stream, forward_context,
+                        num_tokens)
+                else:
+                    update_mla_attn_params(
+                        self.update_stream, forward_context,
+                        num_tokens,
+                        self.vllm_config.speculative_config)
         if self.enable_shared_expert_dp:
             positions = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(
                 positions, True)
@@ -364,11 +379,11 @@ class MtpProposer(Proposer):
                 valid_sampled_tokens_count)

         req_scheduled_tokens = scheduler_output.num_scheduled_tokens
-        if self.pcp_size > 1:
+        if self.pcp_size * self.dcp_size > 1:
             long_seq_metadata = self.runner.long_seq_metadata
-            input_ids_pcp_full = self.runner.input_ids_pcp_full
-            query_start_loc_pcp_full = self.runner.query_start_loc_pcp_full
-            query_start_loc_pcp_full_cpu = self.runner.query_start_loc_pcp_full_cpu
+            input_ids_pcp_full = self.runner.input_ids_pcp_full.gpu
+            query_start_loc_pcp_full = self.runner.query_start_loc_pcp_full.gpu
+            query_start_loc_pcp_full_cpu = self.runner.query_start_loc_pcp_full.cpu
             num_reqs = self.runner.input_batch.num_reqs
             ori_query_lens = query_start_loc_pcp_full_cpu[1:num_reqs+1] - \
                 query_start_loc_pcp_full_cpu[:num_reqs]
@@ -396,12 +411,11 @@ class MtpProposer(Proposer):
             target_hidden_states = hidden_states[:num_scheduled_tokens]
         else:
             if self.pcp_size > 1:
-                common_attn_metadata.query_start_loc_cpu = \
+                common_attn_metadata.query_start_loc_cpu[:num_reqs + 1] = \
                     query_start_loc_pcp_full_cpu[:num_reqs + 1]
-                common_attn_metadata.query_start_loc = \
+                common_attn_metadata.query_start_loc[:num_reqs + 1] = \
                     query_start_loc_pcp_full[:num_reqs + 1]
             if self.speculative_config.disable_padded_drafter_batch:
-                # NOTE: Currently, MTP-fullgraph is incompatibility with pcp
                 token_indices_to_sample = None
                 common_attn_metadata, token_indices =\
                     self._prepare_inputs(
@@ -630,15 +644,18 @@ class MtpProposer(Proposer):
         self.input_ids[last_token_indices] = next_token_ids

         # update pcp related params
-        if self.pcp_size > 1:
+        if self.pcp_size * self.dcp_size > 1:
             assert long_seq_metadata is not None
             common_attn_metadata.prefill_context_parallel_metadata = long_seq_metadata
             ori_last_token_indices = last_token_indices.clone()
+            query_lens_d = self.runner.query_lens[:num_decode_reqs]
+        if self.pcp_size > 1:
             # 1. preprocess decode/prefill input_ids & target_hidden_states
             # decode input_ids: keep unchanged
             # decode target_hidden_states: remove padding
             # prefill input_ids: add padding and pcp split
             # prefill target_hidden_states: pcp split
-            num_tokens_d = num_decode_reqs * self.decode_threshold
+            num_tokens_d = query_lens_d.sum().item()
             num_tokens_d_padded = num_tokens_d * self.pcp_size
             input_ids_d = self.input_ids[:num_tokens_d]
             input_ids_p = self.input_ids[num_tokens_d:num_tokens]
@@ -646,12 +663,17 @@ class MtpProposer(Proposer):
                 target_hidden_states[:num_tokens_d_padded]
             if num_tokens_d:
                 # remove padding (from pcp all-gather) in decode part
-                target_hidden_states_d = target_hidden_states_d_padded.reshape(
-                    [
-                        num_decode_reqs, self.decode_threshold * self.pcp_size,
-                        -1
-                    ])[:, :self.decode_threshold, :].reshape(
-                        [num_tokens_d, -1])
+                mask_start_loc = torch.cat([
+                    torch.tensor([0], dtype=torch.int32),
+                    torch.cumsum(query_lens_d * self.pcp_size, dim=0)[:-1]
+                ])
+                mask_len = query_lens_d
+                mask = []
+                for req_id in range(num_decode_reqs):
+                    mask += list(
+                        range(mask_start_loc[req_id],
+                              mask_start_loc[req_id] + mask_len[req_id]))
+                target_hidden_states_d = target_hidden_states_d_padded[mask]
             else:
                 target_hidden_states_d = target_hidden_states_d_padded
             target_hidden_states_p = target_hidden_states[num_tokens_d_padded:]
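A standalone sketch (not part of the diff) of the padding removal above, with pcp_size=2 and two decode requests of query lengths [2, 3]: the all-gathered decode hidden states hold query_len * pcp_size rows per request, and only the first query_len rows of each group are real.

import torch

pcp_size = 2
query_lens_d = torch.tensor([2, 3])           # tokens per decode request
rows = (query_lens_d * pcp_size).sum()        # 10 padded rows after all-gather
hidden = torch.arange(rows).unsqueeze(1)      # stand-in hidden states

# First real-row index of each request's padded group: [0, 4]
start = torch.cat([torch.tensor([0]),
                   torch.cumsum(query_lens_d * pcp_size, dim=0)[:-1]])
mask = [i for req in range(len(query_lens_d))
        for i in range(start[req], start[req] + query_lens_d[req])]
print(hidden[mask].squeeze(1).tolist())       # [0, 1, 4, 5, 6]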
@@ -670,25 +692,26 @@ class MtpProposer(Proposer):
                 torch.cat([input_ids_d, input_ids_p], dim=0))
             target_hidden_states = torch.cat(
                 [target_hidden_states_d, target_hidden_states_p], dim=0)
-            # 2. update attn_metadata params that may be influenced by pcp
-            common_attn_metadata.num_actual_tokens = num_tokens
-            common_attn_metadata.max_query_len = max(self.decode_threshold,
-                                                     max_query_len_p)
-            common_attn_metadata.seq_lens[num_decode_reqs:] = seq_lens_p
-            common_attn_metadata.seq_lens_cpu[num_decode_reqs:] = seq_lens_p
-            query_start_loc_p = cu_num_tokens_p[1:] + \
-                common_attn_metadata.query_start_loc[num_decode_reqs].item()
-            common_attn_metadata.query_start_loc[num_decode_reqs + 1:] = \
-                query_start_loc_p
-            common_attn_metadata.query_start_loc_cpu[num_decode_reqs + 1:] = \
-                query_start_loc_p
-            # 3. update sample_indices according to main model
+            # 2. update sample_indices according to main model
             if num_decode_reqs:
                 last_token_indices[:num_decode_reqs] = \
                     self.runner.logits_indices[last_token_indices[:num_decode_reqs]]
             if num_prefill_reqs:
                 last_token_indices[-num_prefill_reqs:] = \
                     self.runner.logits_indices[-num_prefill_reqs:]
+            # 3. update attn_metadata params that may be influenced by pcp
+            common_attn_metadata.num_actual_tokens = num_tokens
+            common_attn_metadata.max_query_len = max(
+                self.decode_threshold, max_query_len_p)
+            common_attn_metadata.seq_lens[-num_prefill_reqs:] = seq_lens_p
+            common_attn_metadata.seq_lens_cpu[
+                -num_prefill_reqs:] = seq_lens_p
+            query_start_loc_p = cu_num_tokens_p[1:] + \
+                common_attn_metadata.query_start_loc[num_decode_reqs].item()
+            common_attn_metadata.query_start_loc[-num_prefill_reqs:] = \
+                query_start_loc_p
+            common_attn_metadata.query_start_loc_cpu[-num_prefill_reqs:] = \
+                query_start_loc_p

         assert self.runner is not None
@@ -796,10 +819,15 @@ class MtpProposer(Proposer):
         forward_context = get_forward_context()
         if forward_context.cudagraph_runtime_mode == CUDAGraphMode.FULL:
             if self.vllm_config.model_config.use_mla and not self.use_sparse:
-                update_mla_attn_params(
-                    self.update_stream, forward_context,
-                    num_input_tokens,
-                    self.vllm_config.speculative_config)
+                if self.pcp_size * self.dcp_size > 1:
+                    update_mla_attn_dcp_pcp_params(
+                        self.update_stream, forward_context,
+                        num_input_tokens)
+                else:
+                    update_mla_attn_params(
+                        self.update_stream, forward_context,
+                        num_input_tokens,
+                        self.vllm_config.speculative_config)

         if self.enable_shared_expert_dp:
             hidden_states = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(
@@ -814,7 +842,9 @@ class MtpProposer(Proposer):
                     last_token_indices,
                     (0, max_num_reqs_across_dp - num_indices))

-            if self.pcp_size > 1:
+            if self.pcp_size > 1 and step == 0:
+                # remove graph padding before all_gather
+                hidden_states = hidden_states[:num_tokens]
                 hidden_states = get_pcp_group().all_gather(hidden_states, 0)
                 hidden_states = torch.index_select(
                     hidden_states, 0, self.runner.
@@ -855,6 +885,51 @@ class MtpProposer(Proposer):
                 last_token_indices = self.arange[:batch_size]
                 if getattr(attn_metadata_i, "num_decode_tokens", 0):
                     attn_metadata_i.num_decode_tokens = batch_size
+                if self.pcp_size * self.dcp_size > 1:
+                    positions = target_positions[ori_last_token_indices]
+                    # For pcp/dcp, tokens are split across different cp ranks,
+                    # so we can not simply update slot_mapping by += 1.
+                    # Instead, we pre-allocate mtp slot_mapping in model_runner
+                    # (_generate_pcp_mtp_input), and use updated slot_indices
+                    # to get corresponding slot_mapping in each step.
+                    num_reject_tokens = torch.tensor(
+                        self.runner.cu_num_tokens_pcp_full,
+                        dtype=torch.int32).to(
+                            self.device) - ori_last_token_indices - 1
+                    num_accept_tokens = \
+                        query_lens_d.to(self.device) - num_reject_tokens
+                    ori_seq_len = attn_metadata_i.seq_lens
+                    mtp_slot_mapping = self.runner.mtp_slot_pad
+
+                    # slot_mapping index base offset:
+                    # scheduled tokens + pre-allocated mtp tokens + accepted tokens
+                    slot_idx_base = (
+                        torch.cat([
+                            torch.tensor(
+                                [0], dtype=torch.int32, device=self.device),
+                            (torch.cumsum(query_lens_d, dim=0)[:-1] *
+                             self.pcp_size).to(self.device)
+                        ]) +
+                        torch.arange(num_decode_reqs, device=self.device) *
+                        (self.num_speculative_tokens - 1) * self.pcp_size +
+                        (num_accept_tokens - 1) * self.pcp_size)
+                    slot_indices_list = []
+                    for req_id in range(num_decode_reqs):
+                        slot_indices_list.append(
+                            torch.arange(slot_idx_base[req_id],
+                                         slot_idx_base[req_id] + self.pcp_size,
+                                         device=self.device))
+                    slot_indices = torch.cat(slot_indices_list, dim=0)
+
+                    # fold block_table (restore it to original size before flattened)
+                    block_indices = torch.cat([
+                        torch.tensor([0], dtype=torch.int32),
+                        torch.cumsum(query_lens_d, dim=0)[:-1]
+                    ])
+                    attn_metadata_i.decode.block_table[:batch_size] = \
+                        attn_metadata_i.decode.block_table[block_indices]
+                    attn_metadata_i.decode.block_table = \
+                        attn_metadata_i.decode.block_table[:batch_size]

                 input_ids = draft_token_ids_list[-1].int()
                 positions += 1
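A standalone numeric sketch (not part of the diff) of the slot_idx_base arithmetic above, for pcp_size=2, two decode requests, num_speculative_tokens=3 and query_lens_d=[3, 2]; the accepted-token counts are illustrative. Each request owns (query_len + num_speculative_tokens - 1) * pcp_size pre-allocated slots, and the base points at the slot group of its last accepted token.

import torch

pcp_size, num_spec = 2, 3
query_lens_d = torch.tensor([3, 2])       # scheduled tokens per decode request
num_accept = torch.tensor([3, 1])         # accepted tokens (illustrative)

base = (torch.cat([torch.tensor([0]),
                   torch.cumsum(query_lens_d, dim=0)[:-1] * pcp_size]) +
        torch.arange(len(query_lens_d)) * (num_spec - 1) * pcp_size +
        (num_accept - 1) * pcp_size)
slot_indices = torch.cat([torch.arange(b, b + pcp_size) for b in base])
print(base.tolist(), slot_indices.tolist())   # [4, 10] [4, 5, 10, 11]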
@@ -901,13 +976,40 @@ class MtpProposer(Proposer):
                 # Otherwise, the KV cache will be inadvertently updated with the
                 # padding tokens.
                 slot_mapping += 1
+                if self.pcp_size > 1:
+                    exceeds_max_model_len = exceeds_max_model_len.repeat_interleave(
+                        slot_mapping.size(0) // exceeds_max_model_len.size(0))
                 slot_mapping.masked_fill_(exceeds_max_model_len, PADDING_SLOT_ID)

                 # copy inputs to buffer for cudagraph
                 self.input_ids[:batch_size] = input_ids
                 self.positions[:batch_size] = clamped_positions
                 self.hidden_states[:hidden_states.shape[0]] = hidden_states
-                attn_metadata_i.slot_mapping[:batch_size] = slot_mapping
+                if self.pcp_size * self.dcp_size > 1:
+                    # update local seq_len and batch_seq_mask
+                    num_computed_tokens_of_pcp_dcp = self.runner._get_cp_local_seq_lens(
+                        ori_seq_len + step + 1,
+                        self.pcp_size,
+                        self.dcp_size,
+                        self.runner.parallel_config.cp_kv_cache_interleave_size,
+                    )
+                    cp_seq_len = \
+                        num_computed_tokens_of_pcp_dcp[:, self.pcp_rank, self.dcp_rank]
+                    batch_seq_mask = (cp_seq_len == 0)
+                    builder.batch_seq_mask_buf[:batch_seq_mask.shape[0]].copy_(
+                        batch_seq_mask, non_blocking=True)
+                    batch_seq_mask = builder.batch_seq_mask_buf[:batch_seq_mask.
+                                                                shape[0]]
+                    cp_seq_len = torch.where(cp_seq_len == 0, 1, cp_seq_len)
+                    attn_metadata_i.decode.cp_seq_len = cp_seq_len
+                    attn_metadata_i.decode.batch_seq_mask = batch_seq_mask
+                    # update slot_mapping
+                    slot_indices += self.pcp_size
+                    slot_mapping = mtp_slot_mapping[slot_indices]
+                    attn_metadata_i.slot_mapping[:batch_size *
+                                                 self.pcp_size] = slot_mapping
+                else:
+                    attn_metadata_i.slot_mapping[:batch_size] = slot_mapping
                 if self.speculative_config.disable_padded_drafter_batch:
                     self.positions[batch_size:num_input_tokens] = 0
                     self.input_ids[batch_size:num_input_tokens] = 0
@@ -75,7 +75,7 @@ class BlockTable:
         logical_table_size = max_num_blocks_per_req

         duplicate_size = 1
-        if self.pcp_world_size > 1:
+        if self.pcp_world_size * self.dcp_world_size > 1:
             duplicate_size += num_speculative_tokens
         self.block_table = self._make_buffer(max_num_reqs * duplicate_size,
                                              logical_table_size,
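A standalone sketch (not part of the diff) of the resulting buffer sizing: with pcp or dcp plus speculative decoding, each request may need one block-table row per decode token, so the physical table is over-allocated by num_speculative_tokens extra rows per request.

# Illustrative numbers only.
max_num_reqs = 16
num_speculative_tokens = 3
cp_enabled = True  # pcp_world_size * dcp_world_size > 1

duplicate_size = 1 + (num_speculative_tokens if cp_enabled else 0)
rows = max_num_reqs * duplicate_size  # 64 physical block-table rows
print(rows)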
@@ -280,7 +280,7 @@ class NPUModelRunner(GPUModelRunner):
                                                  dtype=torch.int32,
                                                  device=self.device)
         self.num_actual_tokens_pcp_padded = 0
-        if self.speculative_config and self.pcp_size > 1:
+        if self.speculative_config and self.pcp_size * self.dcp_size > 1:
             self.input_ids_pcp_full = self._make_buffer(self.max_num_tokens,
                                                         dtype=torch.int32)
             self.query_start_loc_pcp_full = self._make_buffer(
@@ -289,8 +289,9 @@ class NPUModelRunner(GPUModelRunner):
                 dtype=torch.int64,
                 device="cpu",
                 pin_memory=True)
-            self.decode_token_per_req += self.speculative_config.num_speculative_tokens
             self.positions_pcp_full_np = self.positions_pcp_full.numpy()
+            self.query_lens_pcp_full = self._make_buffer(self.max_num_reqs,
+                                                         dtype=torch.int32)
         self.decode_threshold = 1 + (
             self.speculative_config.num_speculative_tokens
             if self.speculative_config else 0)
@@ -575,6 +576,7 @@ class NPUModelRunner(GPUModelRunner):
         if self.pcp_size > 1:
             if not self.vllm_config.model_config.use_mla:
                 self.generate_kv_idx(scheduler_output)
+            tokens_before_update = tokens.copy()
             tokens, position_pcp, pcp_unpad_mask = self._update_tokens_for_pcp(
                 tokens)
             num_scheduled_tokens = np.array(tokens, dtype=np.int32)
@@ -591,7 +593,8 @@ class NPUModelRunner(GPUModelRunner):
         num_valid_tokens = np.array([
             num_tokens -
             len(scheduler_output.scheduled_spec_decode_tokens.get(i, []))
-            for num_tokens, i in zip(tokens, req_ids)
+            for num_tokens, i in zip((tokens_before_update if self.
+                                      pcp_size > 1 else tokens), req_ids)
         ],
                                     dtype=np.int32)
@@ -909,7 +912,8 @@ class NPUModelRunner(GPUModelRunner):
                 >= self.input_batch.num_prompt_tokens[req_idx]) else -1)

         spec_decode_metadata = self._calc_spec_decode_metadata(
-            num_draft_tokens, cu_num_tokens, self.num_pcp_pads[:num_reqs])
+            num_draft_tokens, cu_num_tokens,
+            self.num_pcp_pads[:num_reqs].numpy())
         logits_indices = spec_decode_metadata.logits_indices

         # For DECODE only cuda graph of some attention backends (e.g., GDN).
@@ -931,10 +935,11 @@ class NPUModelRunner(GPUModelRunner):
             self.num_accepted_tokens.np[num_reqs:].fill(1)
             self.num_accepted_tokens.copy_to_gpu()

-        if self.speculative_config and self.pcp_size > 1:
+        if self.speculative_config and self.pcp_size * self.dcp_size > 1:
             self._generate_pcp_mtp_input(
                 num_reqs, scheduler_output.total_num_scheduled_tokens,
-                scheduler_output.num_scheduled_tokens)
+                scheduler_output.num_scheduled_tokens, with_prefill,
+                req_indices, positions_np, cu_num_tokens)

         long_seq_metadata = self._generate_pcp_metadata(
             total_num_scheduled_tokens)
@@ -1040,7 +1045,7 @@ class NPUModelRunner(GPUModelRunner):
             prefill_context_parallel_metadata=long_seq_metadata,
         )

-        if self.speculative_config and self.pcp_size > 1:
+        if self.speculative_config and self.pcp_size * self.dcp_size > 1:
             # For pcp + spec decode, we flatten block_table
             # to avoid irregular spec_attn_mask shape, e.g.,
             # num_decode_req=2, num_prefill_req=3, num_speculative_tokens=1,
@@ -1048,12 +1053,13 @@ class NPUModelRunner(GPUModelRunner):
             # (num_reqs_d + num_reqs_p, max_num_blocks),
             # flattened block_table: [d0, d0, d1, d1, p0, p1, p2]
             # (num_reqs_d * decode_threshold + num_reqs_p, max_num_blocks),
-            ori_query_lens = self.query_start_loc_pcp_full.cpu[1:num_reqs + 1] - \
-                self.query_start_loc_pcp_full.cpu[:num_reqs]
+            ori_query_lens_cpu = self.query_lens_pcp_full.cpu[:num_reqs]
+            ori_query_lens = self.query_lens_pcp_full.gpu[:num_reqs]
             num_prefill_reqs = (ori_query_lens
                                 > self.decode_threshold).sum().item()
             num_decode_reqs = num_reqs - num_prefill_reqs
-            num_decode_reqs_flatten = num_decode_reqs * self.decode_threshold
+            num_decode_reqs_flatten = \
+                ori_query_lens_cpu[:num_decode_reqs].sum().item()
             blk_table_tensor[
                 num_decode_reqs_flatten:num_decode_reqs_flatten +
                 num_prefill_reqs].copy_(
@@ -1061,9 +1067,15 @@ class NPUModelRunner(GPUModelRunner):
                     num_prefill_reqs].clone())
             blk_table_tensor[:num_decode_reqs_flatten].copy_(
                 blk_table_tensor[:num_decode_reqs].repeat_interleave(
-                    self.decode_threshold, dim=0))
+                    ori_query_lens[:num_decode_reqs], dim=0))
             common_attn_metadata.block_table_tensor = \
                 blk_table_tensor[:num_decode_reqs_flatten + num_prefill_reqs]
+            long_seq_metadata.query_lens_pcp_full_cpu = ori_query_lens_cpu
+            if 'pad_size' in locals() and pad_size > 0:
+                ori_query_lens_cpu[-pad_size:] = \
+                    torch.full([pad_size], ori_query_lens_cpu[-pad_size - 1].item())
+            long_seq_metadata.max_query_len_pcp_full = \
+                ori_query_lens_cpu.max().item()

         if self.speculative_config and \
                 self.spec_decode_common_attn_metadata is None:
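A standalone sketch (not part of the diff) of the flattening described in the comments above: decode rows are repeated once per scheduled decode token so that every token row indexes a rectangular table, e.g. [d0, d1, p0, p1, p2] with two tokens per decode request becomes [d0, d0, d1, d1, p0, p1, p2].

import torch

# Rows d0, d1 are decodes; p0..p2 are prefills (values are block ids).
block_table = torch.tensor([[10], [20], [30], [40], [50]])
num_decode_reqs = 2
query_lens_d = torch.tensor([2, 2])  # decode_threshold tokens per decode req

flat_decodes = block_table[:num_decode_reqs].repeat_interleave(
    query_lens_d, dim=0)             # [d0, d0, d1, d1]
flat = torch.cat([flat_decodes, block_table[num_decode_reqs:]], dim=0)
print(flat.squeeze(1).tolist())      # [10, 10, 20, 20, 30, 40, 50]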
@@ -1861,7 +1873,7 @@ class NPUModelRunner(GPUModelRunner):
             decode_token_per_req=self.decode_token_per_req,
             prefill_context_parallel_metadata=long_seq_metadata,
         )
-        if self.pcp_size > 1:
+        if self.pcp_size * self.dcp_size > 1:
             common_attn_metadata.block_table_tensor = \
                 block_table_tensor[:num_reqs * self.decode_threshold]
         attn_state = AscendAttentionState.DecodeOnly
@@ -3029,9 +3041,7 @@ class NPUModelRunner(GPUModelRunner):
         num_reqs = self.input_batch.num_reqs
         self.num_pcp_pads = self.num_pcp_pads[:num_reqs]
         tokens = np.array(tokens, dtype=np.int32)
-        num_decode_reqs = sum(
-            self.input_batch.num_computed_tokens_cpu[:num_reqs] >=
-            self.input_batch.num_prompt_tokens[:num_reqs])
+        num_decode_reqs = (np.array(tokens) <= self.decode_threshold).sum()
         num_decode_tokens = sum(tokens[:num_decode_reqs])
         num_padded_scheduled_tokens = np.ceil(
             tokens /
@@ -3118,8 +3128,10 @@ class NPUModelRunner(GPUModelRunner):
     def _generate_pcp_metadata(self, total_num_scheduled_tokens):
         # In dummy run num_reqs == 0, update it from seq_lens
         num_reqs = self.input_batch.num_reqs or self.query_lens.size(0)
-        num_decodes = sum(self.input_batch.num_computed_tokens_cpu[:num_reqs]
-                          >= self.input_batch.num_prompt_tokens[:num_reqs])
+        query_lens = self.query_lens_pcp_full.cpu[:num_reqs] \
+            if self.pcp_size > 1 and self.speculative_config else self.query_lens
+        num_decodes = (query_lens <= self.decode_threshold).sum().item()
         num_prefills = num_reqs - num_decodes
         num_actual_tokens_pcp_padded = total_num_scheduled_tokens * self.pcp_size
         self.num_actual_tokens_pcp_padded = num_actual_tokens_pcp_padded
         long_seq_metadata = None
@@ -3137,16 +3149,41 @@ class NPUModelRunner(GPUModelRunner):
                 dtype=torch.int32,
             )
             # For pcp + spec decode, we flatten seq_lens
-            # to avoid irregular spec_attn_mask shape
+            # to avoid irregular spec_attn_mask shape.
+            # Same as block_table, we flatten decode seq_lens to query_lens,
+            # and keep prefill seq_lens unchanged.
             for decode_idx in range(self.decode_threshold):
                 num_computed_tokens_of_pcp_dcp[
                     self.decode_threshold - 1 - decode_idx::self.decode_threshold] = \
                     self._get_cp_local_seq_lens(
-                        torch.tensor(context_lens),
+                        torch.tensor(context_lens) - decode_idx,
                         self.pcp_size,
                         self.dcp_size,
                         self.parallel_config.cp_kv_cache_interleave_size,
                     )
+            if self.decode_threshold > 1:
+                num_computed_tokens_of_pcp_dcp_list = []
+                if num_decodes:
+                    num_decodes_flatten = \
+                        self.query_lens[:num_decodes].sum().item()
+                    if self.query_lens[:num_decodes].min().item(
+                    ) == self.decode_threshold:
+                        decode_flatten_idx = list(range(num_decodes_flatten))
+                    else:
+                        decode_flatten_idx = []
+                        for req_id in range(num_decodes):
+                            offset = (req_id + 1) * self.decode_threshold
+                            decode_flatten_idx += \
+                                list(range(offset - self.query_lens[req_id], offset))
+                    num_computed_tokens_of_pcp_dcp_list.append(
+                        num_computed_tokens_of_pcp_dcp[decode_flatten_idx])
+                if num_prefills:
+                    num_computed_tokens_of_pcp_dcp_list.append(
+                        num_computed_tokens_of_pcp_dcp[
+                            (num_decodes + 1) * self.decode_threshold -
+                            1::self.decode_threshold])
+                num_computed_tokens_of_pcp_dcp = torch.cat(
+                    num_computed_tokens_of_pcp_dcp_list, dim=0)
             long_seq_metadata = AscendPrefillContextParallelMetadata(
                 num_actual_tokens_pcp_padded=num_actual_tokens_pcp_padded,
                 num_computed_tokens_of_pcp_dcp=num_computed_tokens_of_pcp_dcp.
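A standalone, simplified sketch (not part of the diff) of the per-token seq_lens layout the loop above produces, for one decode request with decode_threshold=3 and context length 10: the row for draft step i is derived from context_lens - i, so the request's flattened rows hold the lengths 8, 9, 10 in order. The real code additionally distributes each length across the pcp*dcp ranks via _get_cp_local_seq_lens.

import torch

decode_threshold = 3                 # 1 + num_speculative_tokens
context_lens = torch.tensor([10])    # one decode request
rows = torch.empty(decode_threshold, dtype=torch.int64)

# Row for draft step i holds the sequence length as of that token.
for decode_idx in range(decode_threshold):
    rows[decode_threshold - 1 - decode_idx::decode_threshold] = \
        context_lens - decode_idx
print(rows.tolist())                 # [8, 9, 10]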
@@ -3278,6 +3315,10 @@ class NPUModelRunner(GPUModelRunner):
             num_reqs: int,
             total_num_scheduled_tokens: int,
             num_scheduled_tokens: dict[str, int],
+            with_prefill: bool = True,
+            req_indices=None,
+            positions_np=None,
+            cu_num_tokens=None,
     ):
         """
         While pcp > 1, model inputs (input_ids, position, etc.) are split across pcp group,
@@ -3288,6 +3329,8 @@ class NPUModelRunner(GPUModelRunner):
         num_scheduled_tokens_pcp_full = np.empty(num_reqs, dtype=np.int32)
         for i, req_id in enumerate(self.input_batch.req_ids):
             num_scheduled_tokens_pcp_full[i] = num_scheduled_tokens[req_id]
+        self.query_lens_pcp_full.cpu[:num_reqs] = torch.from_numpy(
+            num_scheduled_tokens_pcp_full)
         req_indices_pcp_full = np.repeat(self.arange_np[:num_reqs],
                                          num_scheduled_tokens_pcp_full)
         cu_num_tokens_pcp_full = np.cumsum(num_scheduled_tokens_pcp_full)
@@ -3313,11 +3356,45 @@ class NPUModelRunner(GPUModelRunner):
             torch.from_numpy(token_indices_pcp_full),
             out=self.input_ids_pcp_full.
             cpu[:total_num_scheduled_tokens_pcp_full])
+        self.query_lens_pcp_full.copy_to_gpu()
         self.query_start_loc_pcp_full.copy_to_gpu()
         self.input_ids_pcp_full.gpu[:total_num_scheduled_tokens_pcp_full].copy_(
             self.input_ids_pcp_full.cpu[:total_num_scheduled_tokens_pcp_full],
             non_blocking=True,
         )
+        self.cu_num_tokens_pcp_full = cu_num_tokens_pcp_full
+        # For mtpx, pre-allocate mtp slot_mapping here
+        if self.decode_threshold > 2 and not with_prefill:
+            num_tokens_ori = sum(list(num_scheduled_tokens.values()))
+            num_tokens_mtp = \
+                num_tokens_ori + num_reqs * (self.decode_threshold - 2)
+            num_tokens_mtp_pad = num_tokens_mtp * self.pcp_size
+            req_indices_split = np.array_split(req_indices,
+                                               cu_num_tokens)[:num_reqs]
+            positions_split = np.array_split(positions_np,
+                                             cu_num_tokens)[:num_reqs]
+            for req_idx in range(num_reqs):
+                ori_req_indice = req_indices_split[req_idx]
+                ori_position = positions_split[req_idx]
+                req_indices_split[req_idx] = np.append(
+                    ori_req_indice,
+                    np.repeat(ori_req_indice[-1], self.decode_threshold - 2))
+                positions_split[req_idx] = np.append(
+                    ori_position,
+                    np.arange(ori_position[-1] + 1,
+                              ori_position[-1] + self.decode_threshold - 1))
+            req_indices_mtp = np.concatenate(req_indices_split)
+            positions_mtp = np.concatenate(positions_split)
+            self.input_batch.block_table.compute_slot_mapping(
+                req_indices_mtp, positions_mtp)
+            mtp_slot_ori = self.input_batch.block_table.block_tables[
+                0].slot_mapping.cpu[:num_tokens_mtp]
+            unpad_mask = np.repeat(False, num_tokens_mtp_pad)
+            unpad_mask[::self.pcp_size] = True
+            mtp_slot_pad = \
+                torch.full([num_tokens_mtp_pad], -1, dtype=torch.int32)
+            mtp_slot_pad[unpad_mask] = mtp_slot_ori
+            self.mtp_slot_pad = mtp_slot_pad.to(self.device, non_blocking=True)


 @contextmanager
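A standalone sketch (not part of the diff) of the interleaved slot padding at the end of the hunk above, with pcp_size=2 and 3 real slots: real slot ids occupy every pcp_size-th entry and the rest stay -1, matching unpad_mask[::pcp_size] = True (a torch mask is used here in place of the original np.repeat construction).

import torch

pcp_size = 2
mtp_slot_ori = torch.tensor([7, 8, 9], dtype=torch.int32)  # real slot ids
num_pad = mtp_slot_ori.numel() * pcp_size

unpad_mask = torch.zeros(num_pad, dtype=torch.bool)
unpad_mask[::pcp_size] = True                    # [T, F, T, F, T, F]
mtp_slot_pad = torch.full([num_pad], -1, dtype=torch.int32)
mtp_slot_pad[unpad_mask] = mtp_slot_ori
print(mtp_slot_pad.tolist())                     # [7, -1, 8, -1, 9, -1]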