[feature] support pcp + mtp in full graph (#4572)
1. support pcp + mtp in full graph
2. pcp/dcp related mtp bugfix
3. support pcp + mtpx
- vLLM version: v0.12.0
- vLLM main:
ad32e3e19c
Signed-off-by: zhangsicheng5 <zhangsicheng5@huawei.com>
This commit is contained in:
@@ -11,6 +11,7 @@ from vllm.model_executor.models.deepseek_v2 import DeepseekV32IndexerCache
|
||||
from vllm.v1.attention.backends.utils import CommonAttentionMetadata
|
||||
from vllm.v1.core.sched.output import SchedulerOutput
|
||||
from vllm.v1.spec_decode.metadata import SpecDecodeMetadata
|
||||
from vllm.v1.utils import CpuGpuBuffer
|
||||
from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch
|
||||
|
||||
from vllm_ascend.ascend_config import init_ascend_config
|
||||
@@ -215,10 +216,23 @@ class TestMtpProposer:
|
||||
mock_deps.runner.input_ids = torch.arange(16, dtype=torch.int32)
|
||||
mock_deps.runner.spec_decode_common_attn_metadata = MagicMock()
|
||||
mock_deps.runner.pcp_size = 2
|
||||
mock_deps.runner.input_ids_pcp_full = torch.arange(32,
|
||||
dtype=torch.int32)
|
||||
mock_deps.runner.query_start_loc_pcp_full_cpu = torch.tensor(
|
||||
[0, 8, 16, 24, 32])
|
||||
mock_deps.runner.dcp_size = 1
|
||||
mock_deps.runner.input_ids_pcp_full = CpuGpuBuffer(
|
||||
32,
|
||||
dtype=torch.int32,
|
||||
pin_memory=False,
|
||||
device='cpu',
|
||||
)
|
||||
mock_deps.runner.input_ids_pcp_full.cpu = \
|
||||
torch.arange(32, dtype=torch.int32)
|
||||
mock_deps.runner.query_start_loc_pcp_full = CpuGpuBuffer(
|
||||
5,
|
||||
dtype=torch.int32,
|
||||
pin_memory=False,
|
||||
device='cpu',
|
||||
)
|
||||
mock_deps.runner.query_start_loc_pcp_full.cpu = \
|
||||
torch.tensor([0, 8, 16, 24, 32])
|
||||
mock_deps.positions = torch.arange(16, dtype=torch.int32)
|
||||
mock_deps.hidden_states = torch.zeros(16, 4096, dtype=torch.float16)
|
||||
mock_deps.sampled_token_ids = torch.tensor([[100, 101, -1],
|
||||
@@ -232,6 +246,7 @@ class TestMtpProposer:
|
||||
proposer.speculative_config = MagicMock(
|
||||
disable_padded_drafter_batch=False)
|
||||
proposer.pcp_size = mock_deps.runner.pcp_size
|
||||
proposer.dcp_size = mock_deps.runner.dcp_size
|
||||
proposer.prepare_next_token_ids_padded = MagicMock(
|
||||
return_value=(torch.tensor([101, 200, 302]), 3))
|
||||
proposer.prepare_inputs_padded = MagicMock(
|
||||
|
||||
Reference in New Issue
Block a user