# xc-llm-ascend/tests/ut/spec_decode/test_mtp_proposer.py

from unittest.mock import MagicMock, patch
import numpy as np
import pytest
import torch
from vllm.config import (CacheConfig, CompilationConfig, CUDAGraphMode,
                         ModelConfig, SchedulerConfig, SpeculativeConfig,
                         VllmConfig)
from vllm.v1.attention.backends.utils import CommonAttentionMetadata
from vllm.v1.spec_decode.metadata import SpecDecodeMetadata
from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch
from vllm_ascend.ascend_config import init_ascend_config
from vllm_ascend.attention.utils import AscendCommonAttentionMetadata
from vllm_ascend.spec_decode.mtp_proposer import MtpProposer
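
# Unit tests for vllm_ascend's MtpProposer (multi-token-prediction draft
# proposer for speculative decoding). Configs and the model runner are
# MagicMocks, so no NPU or real model is needed; in the prepare_* tests the
# unbound MtpProposer methods are bound to a mocked instance via __get__ so
# that only the method under test executes against mocked state.
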
class TestMtpProposer:

    @pytest.fixture(autouse=True)
    def patch_supports_multimodal_inputs(self):
        with patch(
                "vllm.multimodal.registry.MultiModalRegistry.supports_multimodal_inputs"
        ):
            yield
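
    # Builds a minimal VllmConfig stand-in with just the attributes
    # MtpProposer reads: an "mtp" speculative config with 2 draft tokens,
    # an fp16 model config, and a CPU device so the tests run without an NPU.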
    @pytest.fixture
    def vllm_config(self):
        config = MagicMock(spec=VllmConfig)
        config.additional_config = None
        config.speculative_config = MagicMock(spec=SpeculativeConfig)
        config.speculative_config.num_speculative_tokens = 2
        config.speculative_config.method = "mtp"
        config.speculative_config.draft_model_config = MagicMock()
        config.speculative_config.draft_model_config.get_hidden_size.return_value = 4096
        config.speculative_config.draft_model_config.uses_mrope = False
        config.speculative_config.draft_model_config.uses_xdrope_dim = 0
        config.speculative_config.speculative_token_tree = str([
            (i + 1) * (0, ) for i in range(2)
        ])
        config.model_config = MagicMock(spec=ModelConfig)
        config.model_config.dtype = torch.float16
        config.model_config.max_model_len = 2048
        config.model_config.uses_mrope = False
        config.model_config.uses_xdrope_dim = 0
        config.model_config.hf_text_config = None
        config.model_config.hf_config = None
        config.parallel_config.tensor_parallel_size = 1
        config.parallel_config.data_parallel_rank = 0
        config.speculative_config.draft_tensor_parallel_size = 1
        config.load_config = None
        config.cache_config = MagicMock(spec=CacheConfig)
        config.cache_config.block_size = 16
        config.scheduler_config = MagicMock(spec=SchedulerConfig)
        config.scheduler_config.max_num_batched_tokens = 4096
        config.scheduler_config.max_num_seqs = 256
        config.compilation_config = MagicMock(spec=CompilationConfig)
        config.compilation_config.cudagraph_capture_sizes = [1, 2, 4, 8]
        config.compilation_config.static_forward_context = dict()
        config.device_config = MagicMock()
        config.device_config.device = torch.device("cpu")
        init_ascend_config(config)
        return config
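
    # Mocked model runner with single-rank parallel settings (pcp/dcp size 1)
    # and ACL graph capture disabled by default.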
    @pytest.fixture
    def runner(self):
        runner = MagicMock()
        runner.pcp_size = 1
        runner.dcp_size = 1
        runner.pcp_rank = 0
        runner.max_num_tokens = 4096
        runner.max_num_reqs = 256
        runner._use_aclgraph.return_value = False
        runner.reserved_mc2_mask = None
        runner.pin_memory = False
        return runner
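
    # CpuGpuBuffer is patched in the tests below so buffer allocation in
    # MtpProposer.__init__ does not touch device memory.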
@patch("vllm.v1.spec_decode.eagle.CpuGpuBuffer")
def test_init(self, mock_cpu_gpu_buffer, vllm_config, runner):
mock_buffer_instance = MagicMock()
mock_cpu_gpu_buffer.return_value = mock_buffer_instance
# Test basic initialization
proposer = MtpProposer(vllm_config, torch.device("cpu"), runner)
assert proposer.vllm_config == vllm_config
assert proposer.device == torch.device("cpu")
assert proposer.dtype == torch.float16
assert proposer.num_speculative_tokens == 2
assert proposer.hidden_size == 4096
# Test with mrope enabled
assert hasattr(proposer, "positions")
assert not hasattr(proposer, "mrope_positions")
assert proposer.use_sparse is False
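
    # With the runner reporting ACL graph usage and eager mode disabled,
    # the proposer should enable its CUDA-graph path.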
@patch("vllm.v1.spec_decode.eagle.CpuGpuBuffer")
def test_init_with_aclgraph(self, mock_cpu_gpu_buffer, vllm_config,
runner):
mock_buffer_instance = MagicMock()
mock_cpu_gpu_buffer.return_value = mock_buffer_instance
runner._use_aclgraph.return_value = True
vllm_config.scheduler_config.async_scheduling = False
vllm_config.speculative_config.enforce_eager = False
proposer = MtpProposer(vllm_config, torch.device("cpu"), runner)
assert proposer.use_cuda_graph is True
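
    # dummy_run should sync metadata across DP ranks, enter the Ascend
    # forward context, and invoke the draft model once per speculative token.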
@patch("vllm_ascend.spec_decode.mtp_proposer.get_forward_context")
@patch("vllm_ascend.spec_decode.mtp_proposer.set_ascend_forward_context")
@patch("vllm.v1.spec_decode.eagle.CpuGpuBuffer")
def test_dummy_run(self, mock_cpu_gpu_buffer, mock_set_context,
mock_get_forward_context, vllm_config, runner):
mock_buffer_instance = MagicMock()
mock_cpu_gpu_buffer.return_value = mock_buffer_instance
proposer = MtpProposer(vllm_config, torch.device("cpu"), runner)
proposer.model = MagicMock()
proposer.enable_shared_expert_dp = False
runner._sync_metadata_across_dp.return_value = (8, 8, False)
mock_get_forward_context = MagicMock()
mock_get_forward_context.cudagraph_runtime_mode = None
mock_get_forward_context.capturing = True
# Execute
proposer.dummy_run(8)
# Verify
runner._sync_metadata_across_dp.assert_called_once()
mock_set_context.assert_called()
# Check that model was called correct number of times
assert proposer.model.call_count == vllm_config.speculative_config.num_speculative_tokens
@patch("vllm_ascend.spec_decode.mtp_proposer.get_forward_context")
@patch("vllm_ascend.spec_decode.mtp_proposer.set_ascend_forward_context")
@patch("vllm.v1.spec_decode.eagle.CpuGpuBuffer")
def test_dummy_run_full_graph(self, mock_cpu_gpu_buffer, mock_set_context,
mock_get_forward_context, vllm_config,
runner):
# Setup
mock_buffer_instance = MagicMock()
mock_cpu_gpu_buffer.return_value = mock_buffer_instance
proposer = MtpProposer(vllm_config, torch.device("cpu"), runner)
proposer.enable_shared_expert_dp = False
proposer.model = MagicMock()
runner._sync_metadata_across_dp.return_value = (8, 8, False)
runner.attn_groups = []
mock_get_forward_context = MagicMock()
mock_get_forward_context.cudagraph_runtime_mode = None
mock_get_forward_context.capturing = True
# Execute
proposer.dummy_run(num_tokens=8,
num_reqs=5,
aclgraph_runtime_mode=CUDAGraphMode.FULL)
# Verify
runner._sync_metadata_across_dp.assert_called_once()
mock_set_context.assert_called()
# Check that model was called correct number of times
assert proposer.model.call_count == vllm_config.speculative_config.num_speculative_tokens
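
    # Each request's next token should be the last sampled token of its row:
    # [30, 50, 60] for the three rows below.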
@patch("vllm.v1.spec_decode.eagle.CpuGpuBuffer")
def test_prepare_next_token_ids_cpu(self, mock_cpu_gpu_buffer):
mock_buffer_instance = MagicMock()
mock_cpu_gpu_buffer.return_value = mock_buffer_instance
sampled_token_ids = [[10, 20, 30], [40, 50], [60]]
mock_gpu_batch = MagicMock()
mock_gpu_batch.req_ids = ["req1", "req2", "req3"]
mock_num_scheduled = {"req1": 0, "req2": 0, "req3": 0}
proposer = MagicMock(spec=MtpProposer)
proposer.input_ids = MagicMock(device=torch.device("cpu"))
proposer.prepare_next_token_ids_cpu = MtpProposer.prepare_next_token_ids_cpu.__get__(
proposer)
result = proposer.prepare_next_token_ids_cpu(
sampled_token_ids=sampled_token_ids,
requests={},
gpu_input_batch=mock_gpu_batch,
num_scheduled_tokens=mock_num_scheduled)
assert torch.all(
result == torch.tensor([30, 50, 60], dtype=torch.int32))
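
    # The padded variant masks discarded requests, counts valid sampled
    # tokens per row (a -1 or out-of-vocab id invalidates a row), and falls
    # back to the backup_next_token_ids buffer for rows with no valid sample.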
@patch("vllm.v1.spec_decode.eagle.CpuGpuBuffer")
def test_prepare_next_token_ids_padded(self, mock_cpu_gpu_buffer):
mock_common_attn_metadata = MagicMock(spec=CommonAttentionMetadata)
mock_common_attn_metadata.seq_lens_cpu = torch.tensor(
[10, 8, 5, 12], dtype=torch.int32)
mock_sampled_token_ids = torch.tensor([
[101, 102, 103],
[201, -1, 203],
[-1, -1, -1],
[301, 10000, 303],
],
dtype=torch.int32,
device=torch.device("cpu"))
mock_requests = {} # dict[str, CachedRequestState]
req0 = MagicMock(spec=CachedRequestState)
req0.get_token_id = MagicMock(return_value=1000)
mock_requests["req_0"] = req0
req1 = MagicMock(spec=CachedRequestState)
req1.get_token_id = MagicMock(return_value=2000)
mock_requests["req_1"] = req1
req2 = MagicMock(spec=CachedRequestState)
req2.get_token_id = MagicMock(return_value=3000)
mock_requests["req_2"] = req2
req3 = MagicMock(spec=CachedRequestState)
req3.get_token_id = MagicMock(return_value=4000)
mock_requests["req_3"] = req3
mock_gpu_input_batch = MagicMock(spec=InputBatch)
mock_gpu_input_batch.num_reqs = 4
mock_gpu_input_batch.req_ids = ["req_0", "req_1", "req_2", "req_3"]
mock_gpu_input_batch.vocab_size = 5000
mock_backup = MagicMock()
mock_backup.np = np.array([1, 2, 3, 4, 5, 6, 7], dtype=np.int32)
mock_backup.gpu = torch.tensor([1, 2, 3, 4, 5, 6, 7],
dtype=torch.int32)
mock_backup.copy_to_gpu = MagicMock()
mock_cpu_gpu_buffer.return_value = mock_backup
proposer = MagicMock(spec=MtpProposer)
proposer.backup_next_token_ids = mock_backup
proposer.input_ids = MagicMock(device=torch.device("cpu"))
proposer.prepare_next_token_ids_padded = MtpProposer.prepare_next_token_ids_padded.__get__(
proposer)
discard_request_indices = torch.tensor([1, 3], dtype=torch.int64)
num_discarded_requests = 2
next_token_ids, valid_sampled_tokens_count = proposer.prepare_next_token_ids_padded(
common_attn_metadata=mock_common_attn_metadata,
sampled_token_ids=mock_sampled_token_ids,
requests=mock_requests,
gpu_input_batch=mock_gpu_input_batch,
discard_request_indices=discard_request_indices,
num_discarded_requests=num_discarded_requests)
mock_backup_output = proposer.backup_next_token_ids
expected_backup_cpu = np.array(
[1000, 2000, 3000, 4000, 0, 0, 0, 0, 0, 0])
assert np.array_equal(mock_backup_output.np[:4],
expected_backup_cpu[:4])
mock_backup_output.copy_to_gpu.assert_called_once_with(4)
modified_sampled = mock_sampled_token_ids.clone()
modified_sampled.index_fill_(
0, discard_request_indices[:num_discarded_requests], -1)
assert valid_sampled_tokens_count[1].item() == 0
assert valid_sampled_tokens_count[3].item() == 0
expected_valid_count = torch.tensor([3, 0, 0, 0], dtype=torch.int32)
assert torch.equal(valid_sampled_tokens_count, expected_valid_count)
expected_next_tokens = torch.tensor([103, 2, 3, 4],
dtype=torch.int32,
device=torch.device("cpu"))
assert torch.equal(next_token_ids, expected_next_tokens)
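
    # prepare_inputs_padded keeps all 24 scheduled tokens (no compaction) and
    # appears to derive each request's sample position from query_start_loc
    # minus its rejected draft tokens, giving [5, 13, 22] here.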
@patch("vllm_ascend.spec_decode.eagle_proposer.HAS_TRITON", False)
@patch("vllm.v1.spec_decode.eagle.CpuGpuBuffer")
def test_prepare_inputs_padded(self, mock_cpu_gpu_buffer):
mock_buffer_instance = MagicMock()
mock_cpu_gpu_buffer.return_value = mock_buffer_instance
mock_common_attn_metadata = MagicMock(spec=CommonAttentionMetadata)
mock_common_attn_metadata.query_start_loc_cpu = torch.tensor(
[0, 8, 16, 24], dtype=torch.int32)
mock_common_attn_metadata.seq_lens_cpu = torch.tensor(
[8, 8, 8], dtype=torch.int32)
mock_common_attn_metadata.num_input_tokens = 3
mock_common_attn_metadata.query_start_loc = torch.tensor(
[0, 8, 16, 24], dtype=torch.int32)
mock_common_attn_metadata.seq_lens = torch.tensor([8, 8, 8],
dtype=torch.int32)
mock_common_attn_metadata.num_actual_tokens = 24
mock_common_attn_metadata.num_reqs = 3
mock_common_attn_metadata.num_computed_tokens_cpu = torch.tensor(
[5, 6, 7], dtype=torch.int32)
mock_common_attn_metadata.block_table_tensor = MagicMock()
mock_common_attn_metadata.slot_mapping = MagicMock()
mock_common_attn_metadata.positions = MagicMock()
mock_spec_decode_metadata = MagicMock(spec=SpecDecodeMetadata)
mock_spec_decode_metadata.cu_num_draft_tokens = torch.tensor(
[3, 5, 7], dtype=torch.int32)
mock_runner = MagicMock()
mock_runner.actual_seq_lengths_q = MagicMock()
mock_runner.attn_state = MagicMock()
mock_runner.graph_pad_size = 0
mock_runner.pcp_size = 1
mock_runner.decode_token_per_req = MagicMock()
proposer = MagicMock(spec=MtpProposer)
proposer.runner = mock_runner
proposer.pcp_size = 1
proposer.arange = torch.arange(100, dtype=torch.int32)
proposer.prepare_inputs_padded = MtpProposer.prepare_inputs_padded.__get__(
proposer)
mock_valid_sampled_tokens_count = torch.tensor([2, 1, 2],
dtype=torch.int32)
(spec_common_attn_metadata, token_indices,
token_indices_to_sample) = proposer.prepare_inputs_padded(
common_attn_metadata=mock_common_attn_metadata,
spec_decode_metadata=mock_spec_decode_metadata,
valid_sampled_tokens_count=mock_valid_sampled_tokens_count)
total_num_tokens = mock_common_attn_metadata.query_start_loc_cpu[
-1].item()
expected_token_indices = proposer.arange[:total_num_tokens]
assert torch.equal(token_indices, expected_token_indices)
assert token_indices.shape == (24, )
assert token_indices.dtype == torch.int32
expected_sample_indices = torch.tensor([5, 13, 22], dtype=torch.int32)
assert torch.equal(token_indices_to_sample, expected_sample_indices)
assert isinstance(spec_common_attn_metadata,
AscendCommonAttentionMetadata)
assert torch.equal(spec_common_attn_metadata.query_start_loc,
mock_common_attn_metadata.query_start_loc)
assert torch.equal(spec_common_attn_metadata.query_start_loc_cpu,
mock_common_attn_metadata.query_start_loc_cpu)
assert torch.equal(spec_common_attn_metadata.seq_lens_cpu,
mock_common_attn_metadata.seq_lens)
assert spec_common_attn_metadata.num_reqs == mock_common_attn_metadata.num_reqs
assert spec_common_attn_metadata.num_actual_tokens == total_num_tokens
assert spec_common_attn_metadata.max_query_len == 8
assert spec_common_attn_metadata.actual_seq_lengths_q == proposer.runner.actual_seq_lengths_q