[feature] support pcp + mtp in full graph (#4572)

1. Support PCP + MTP in full graph mode.
2. Fix PCP/DCP-related MTP bugs.
3. Support PCP + MTPx.

- vLLM version: v0.12.0
- vLLM main: ad32e3e19c

Signed-off-by: zhangsicheng5 <zhangsicheng5@huawei.com>
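Here PCP and DCP stand for prefill and decode context parallelism (the `prefill_context_parallel_size` and `decode_context_parallel_size` options exercised below), and MTP is DeepSeek-style multi-token prediction, driven through vLLM speculative decoding with `method: "deepseek_mtp"`; "MTPx" presumably means MTP with more than one speculative token, which the `mtp3` tests below cover. The new e2e file runs a random-weight DeepSeek MTP model through eager, piecewise-graph, and full-graph execution, and the unit-test hunks that follow thread a `decode_threshold` attribute and `CpuGpuBuffer`-backed PCP buffers into the existing mocks. `HCCL_BUFFSIZE=512` enlarges the HCCL communication buffer (in MB) for the context-parallel collectives.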
tests/e2e/multicard/long_sequence/test_mtp.py (new file, 165 lines)

@@ -0,0 +1,165 @@
#
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
# Copyright 2023 The vLLM team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# This file is a part of the vllm-ascend project.
# Adapted from vllm/tests/basic_correctness/test_basic_correctness.py
#

import os

import pytest

from tests.e2e.conftest import VllmRunner
from vllm_ascend.utils import vllm_version_is

os.environ["HCCL_BUFFSIZE"] = "512"
@pytest.mark.skipif(vllm_version_is('0.12.0'),
                    reason="context parallel is not supported on vLLM 0.12.0.")
def test_pcp_dcp_mtp1_eager():
    prompts = [
        "The capital of France is", "Hello, my name is Tom, I am",
        "The president of United States is", "AI future is"
    ]
    model = "wemaster/deepseek_mtp_main_random_bf16"
    with VllmRunner(
            model,
            max_model_len=1024,
            tensor_parallel_size=2,
            prefill_context_parallel_size=2,
            decode_context_parallel_size=2,
            max_num_batched_tokens=1024,
            enable_expert_parallel=True,
            block_size=128,
            speculative_config={
                "num_speculative_tokens": 1,
                "method": "deepseek_mtp",
            },
            enforce_eager=True,
    ) as runner:
        runner.generate_greedy(prompts, 32)
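Each test is a smoke check: it only verifies that greedy generation of 32 tokens completes under the given parallel and graph configuration. Since the model weights are random (`deepseek_mtp_main_random_bf16`), outputs are not compared against references; the remaining tests below vary only the speculative token count and the graph capture mode.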
@pytest.mark.skipif(vllm_version_is('0.12.0'),
                    reason="context parallel is not supported on vLLM 0.12.0.")
def test_pcp_dcp_mtp3_eager():
    prompts = [
        "The capital of France is", "Hello, my name is Tom, I am",
        "The president of United States is", "AI future is"
    ]
    model = "wemaster/deepseek_mtp_main_random_bf16"
    with VllmRunner(
            model,
            max_model_len=1024,
            tensor_parallel_size=2,
            prefill_context_parallel_size=2,
            decode_context_parallel_size=2,
            max_num_batched_tokens=1024,
            enable_expert_parallel=True,
            block_size=128,
            speculative_config={
                "num_speculative_tokens": 3,
                "method": "deepseek_mtp",
            },
            enforce_eager=True,
    ) as runner:
        runner.generate_greedy(prompts, 32)
@pytest.mark.skipif(vllm_version_is('0.12.0'),
                    reason="context parallel is not supported on vLLM 0.12.0.")
def test_pcp_dcp_mtp3_piecewise_graph():
    prompts = [
        "The capital of France is", "Hello, my name is Tom, I am",
        "The president of United States is", "AI future is"
    ]
    model = "wemaster/deepseek_mtp_main_random_bf16"
    with VllmRunner(
            model,
            max_model_len=1024,
            tensor_parallel_size=2,
            prefill_context_parallel_size=2,
            decode_context_parallel_size=2,
            max_num_batched_tokens=1024,
            enable_expert_parallel=True,
            block_size=128,
            speculative_config={
                "num_speculative_tokens": 3,
                "method": "deepseek_mtp",
            },
            compilation_config={
                "cudagraph_mode": "PIECEWISE",
                "cudagraph_capture_sizes": [4, 8, 16],
            },
    ) as runner:
        runner.generate_greedy(prompts, 32)
@pytest.mark.skipif(vllm_version_is('0.12.0'),
                    reason="context parallel is not supported on vLLM 0.12.0.")
def test_pcp_dcp_mtp3_full_graph():
    prompts = [
        "The capital of France is", "Hello, my name is Tom, I am",
        "The president of United States is", "AI future is"
    ]
    model = "wemaster/deepseek_mtp_main_random_bf16"
    with VllmRunner(
            model,
            max_model_len=1024,
            tensor_parallel_size=2,
            prefill_context_parallel_size=2,
            decode_context_parallel_size=2,
            max_num_batched_tokens=1024,
            enable_expert_parallel=True,
            block_size=128,
            speculative_config={
                "num_speculative_tokens": 3,
                "method": "deepseek_mtp",
            },
            compilation_config={
                "cudagraph_mode": "FULL_DECODE_ONLY",
                "cudagraph_capture_sizes": [4, 8, 16],
            },
    ) as runner:
        runner.generate_greedy(prompts, 32)
@pytest.mark.skipif(vllm_version_is('0.12.0'),
                    reason="context parallel is not supported on vLLM 0.12.0.")
def test_dcp_mtp3_full_graph():
    prompts = [
        "The capital of France is", "Hello, my name is Tom, I am",
        "The president of United States is", "AI future is"
    ]
    model = "wemaster/deepseek_mtp_main_random_bf16"
    with VllmRunner(
            model,
            max_model_len=1024,
            tensor_parallel_size=2,
            decode_context_parallel_size=2,
            max_num_batched_tokens=1024,
            enable_expert_parallel=True,
            block_size=128,
            speculative_config={
                "num_speculative_tokens": 3,
                "method": "deepseek_mtp",
            },
            compilation_config={
                "cudagraph_mode": "FULL_DECODE_ONLY",
                "cudagraph_capture_sizes": [4, 8, 16],
            },
    ) as runner:
        runner.generate_greedy(prompts, 32)
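Taken together, the five tests climb the capture ladder for a fixed parallel layout (TP=2, PCP=2, DCP=2, expert parallel on): eager with one and three speculative tokens, `cudagraph_mode: "PIECEWISE"` (graph segments stitched around attention), and `FULL_DECODE_ONLY` (one full graph per entry in `cudagraph_capture_sizes`, used only for pure-decode batches while prefill stays eager). The last test drops `prefill_context_parallel_size` to cover DCP-only decode under full graphs. The mode names follow vLLM's compilation config; on Ascend hardware they are presumably realized as ACL graph captures rather than CUDA graphs.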
@@ -11,6 +11,7 @@ from vllm.model_executor.models.deepseek_v2 import DeepseekV32IndexerCache
 from vllm.v1.attention.backends.utils import CommonAttentionMetadata
 from vllm.v1.core.sched.output import SchedulerOutput
 from vllm.v1.spec_decode.metadata import SpecDecodeMetadata
+from vllm.v1.utils import CpuGpuBuffer
 from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch
 
 from vllm_ascend.ascend_config import init_ascend_config
@@ -215,10 +216,23 @@ class TestMtpProposer:
         mock_deps.runner.input_ids = torch.arange(16, dtype=torch.int32)
         mock_deps.runner.spec_decode_common_attn_metadata = MagicMock()
         mock_deps.runner.pcp_size = 2
-        mock_deps.runner.input_ids_pcp_full = torch.arange(32,
-                                                           dtype=torch.int32)
-        mock_deps.runner.query_start_loc_pcp_full_cpu = torch.tensor(
-            [0, 8, 16, 24, 32])
+        mock_deps.runner.dcp_size = 1
+        mock_deps.runner.input_ids_pcp_full = CpuGpuBuffer(
+            32,
+            dtype=torch.int32,
+            pin_memory=False,
+            device='cpu',
+        )
+        mock_deps.runner.input_ids_pcp_full.cpu = \
+            torch.arange(32, dtype=torch.int32)
+        mock_deps.runner.query_start_loc_pcp_full = CpuGpuBuffer(
+            5,
+            dtype=torch.int32,
+            pin_memory=False,
+            device='cpu',
+        )
+        mock_deps.runner.query_start_loc_pcp_full.cpu = \
+            torch.tensor([0, 8, 16, 24, 32])
         mock_deps.positions = torch.arange(16, dtype=torch.int32)
         mock_deps.hidden_states = torch.zeros(16, 4096, dtype=torch.float16)
         mock_deps.sampled_token_ids = torch.tensor([[100, 101, -1],
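The hunk above swaps plain tensors for `CpuGpuBuffer`, matching the runner's real buffer type. A minimal sketch of why the mock assigns to `.cpu`, assuming (per the import and constructor call in this diff) that `vllm.v1.utils.CpuGpuBuffer` allocates a CPU-side tensor exposed as `.cpu` alongside a device-side twin:

```python
import torch
from vllm.v1.utils import CpuGpuBuffer

# Same constructor call as the test above; device='cpu' keeps the
# device-side twin on the host so the sketch runs anywhere.
buf = CpuGpuBuffer(32, dtype=torch.int32, pin_memory=False, device='cpu')

# The test overwrites the CPU side so the proposer reads deterministic
# token ids instead of zeros.
buf.cpu = torch.arange(32, dtype=torch.int32)
assert buf.cpu[5].item() == 5
```

Replacing the attribute wholesale (rather than copying in place) is fine for a mock, since only the proposer's reads of `.cpu` matter here.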
@@ -232,6 +246,7 @@ class TestMtpProposer:
         proposer.speculative_config = MagicMock(
             disable_padded_drafter_batch=False)
         proposer.pcp_size = mock_deps.runner.pcp_size
+        proposer.dcp_size = mock_deps.runner.dcp_size
         proposer.prepare_next_token_ids_padded = MagicMock(
             return_value=(torch.tensor([101, 200, 302]), 3))
         proposer.prepare_inputs_padded = MagicMock(
@@ -50,6 +50,7 @@ def test_generate_pcp_metadata_basic(pcp_size, dcp_size, num_reqs, query_lens,
 
     mock_runner.input_batch = MagicMock()
     mock_runner.input_batch.num_reqs = num_reqs
+    mock_runner.speculative_config = None
 
     num_computed_tokens = []
     num_prompt_tokens = []
@@ -169,23 +170,24 @@ def test_pcp_allgather_restore_idx_slicing():
 
 
 @pytest.mark.parametrize(
-    "tokens, num_reqs, num_computed_tokens, num_prompt_tokens, pcp_size, pcp_rank, expected_pcp_tokens",
+    "tokens, num_reqs, num_computed_tokens, num_prompt_tokens," \
+    "pcp_size, pcp_rank, decode_threshold, expected_pcp_tokens",
     [
         # Case 1: prefill only
-        ([8, 12, 16], 3, [0, 0, 0], [8, 12, 16], 4, 0, [2, 4, 4]),
+        ([8, 12, 16], 3, [0, 0, 0], [8, 12, 16], 4, 0, 1, [2, 4, 4]),
 
-        # Case 2: mix prefill and decode
-        ([8, 4, 12], 3, [8, 4, 0], [8, 4, 12], 4, 0, [8, 4, 4]),
+        # Case 2: mix prefill and decode (with spec decode)
+        ([8, 4, 12], 3, [8, 4, 0], [8, 4, 12], 4, 0, 8, [8, 4, 4]),
 
         # Case 3: request which needs to be padded
-        ([3, 7, 9], 3, [0, 0, 0], [3, 7, 9], 4, 0, [2, 2, 4]),
+        ([3, 7, 9], 3, [0, 0, 0], [3, 7, 9], 4, 0, 1, [2, 2, 4]),
 
         # Case 4: single request
-        ([10], 1, [0], [10], 4, 0, [4]),
+        ([10], 1, [0], [10], 4, 0, 1, [4]),
     ])
 def test_update_tokens_for_pcp_basic(tokens, num_reqs, num_computed_tokens,
                                      num_prompt_tokens, pcp_size, pcp_rank,
-                                     expected_pcp_tokens):
+                                     decode_threshold, expected_pcp_tokens):
     mock_runner = MagicMock(spec=NPUModelRunner)
     mock_runner.pcp_size = pcp_size
     mock_runner.pcp_rank = pcp_rank
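The new `decode_threshold` column mirrors the runner attribute set in the hunks below. Judging by the expected values, a request whose scheduled token count is at or below the threshold is treated as a decode and keeps all of its tokens on the local PCP rank, while longer prefill requests are padded to a multiple of 2 * pcp_size (8 here) and split evenly across ranks. Case 2 exercises both paths: with pcp_size=4 and decode_threshold=8, the 8- and 4-token requests (prompts fully computed) stay whole, while the 12-token prefill is padded to 16 and each rank keeps 16 / 4 = 4 tokens. A threshold above 1 matters once MTP is active, since a decode step then schedules 1 + num_speculative_tokens tokens per request.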
@@ -201,6 +203,7 @@ def test_update_tokens_for_pcp_basic(tokens, num_reqs, num_computed_tokens,
 
     mock_runner.num_pcp_pads = [0] * num_reqs
     mock_runner.arange_np = np.arange(10000)
+    mock_runner.decode_threshold = decode_threshold
 
     mock_runner._update_tokens_for_pcp = NPUModelRunner._update_tokens_for_pcp.__get__(
         mock_runner, NPUModelRunner)
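A side note on the recurring `NPUModelRunner._update_tokens_for_pcp.__get__(mock_runner, NPUModelRunner)` idiom: fetching the function from the class and invoking `__get__` manually produces a bound method whose `self` is the `MagicMock`, so the real implementation runs against purely mocked state. A minimal, self-contained illustration (hypothetical class, not from the diff):

```python
from unittest.mock import MagicMock


class Runner:
    def doubled(self):
        # Real logic that reads attributes off `self`.
        return self.value * 2


mock = MagicMock()
mock.value = 21
# Bind the real method to the mock via the descriptor protocol.
bound = Runner.doubled.__get__(mock, Runner)
assert bound() == 42
```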
@@ -243,6 +246,7 @@ def test_update_tokens_for_pcp_with_padding():
 
     mock_runner.num_pcp_pads = [0, 0, 0]
     mock_runner.pcp_allgather_restore_idx = torch.zeros(1000, dtype=torch.long)
+    mock_runner.decode_threshold = 1
 
     mock_runner._update_tokens_for_pcp = NPUModelRunner._update_tokens_for_pcp.__get__(
         mock_runner, NPUModelRunner)
@@ -279,6 +283,7 @@ def test_update_tokens_for_pcp_unpad_mask():
 
     mock_runner.num_pcp_pads = [0, 0]
     mock_runner.pcp_allgather_restore_idx = torch.zeros(1000, dtype=torch.long)
+    mock_runner.decode_threshold = 1
 
     mock_runner._update_tokens_for_pcp = NPUModelRunner._update_tokens_for_pcp.__get__(
         mock_runner, NPUModelRunner)
@@ -369,6 +374,9 @@ def pcp_mtp_mock_runner():
 
     mock_runner.input_ids_pcp_full = NPUModelRunner._make_buffer(
         mock_runner, max_num_tokens, dtype=torch.int32)
+    mock_runner.query_lens_pcp_full = NPUModelRunner._make_buffer(
+        mock_runner, max_num_reqs, dtype=torch.int32)
+    mock_runner.decode_threshold = 1
 
     mock_runner.arange_np = np.arange(max_model_len)
     mock_runner.input_batch = MagicMock()