### What this PR does / why we need it?
1. ✅ Upgrade vllm commit to 0115 (8471b27df97c3eb79f891802fc0e858f8f7ac6a0).
Modify import paths due to the refactors
https://github.com/vllm-project/vllm/pull/32245 and
https://github.com/vllm-project/vllm/pull/32060 (the usual compatibility
pattern is sketched below).
Test result:
https://github.com/vllm-project/vllm-ascend/actions/runs/21034239336/job/60490156965?pr=5913
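For illustration only, this is the kind of guarded import typically used to absorb such relocations; the module paths and the `SomeHelper` name below are placeholders, not the symbols actually touched by these refactors:

```python
# Hypothetical compatibility shim with placeholder paths: prefer the
# post-refactor location, fall back to the old one on earlier commits.
try:
    from vllm.new_location import SomeHelper  # new path after the refactor
except ImportError:
    try:
        from vllm.old_location import SomeHelper  # pre-refactor path
    except ImportError:
        SomeHelper = None  # neither placeholder exists; sketch only
```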
2. ✅ Upgrade vllm commit to 0119 (9a1f16da1e423ede2c2f52a9850cbfbb39cefe96).
Fix `WorkerProc.__init__() missing 1 required positional argument:
'is_driver_worker'` caused by
https://github.com/vllm-project/vllm/pull/28506 (the breakage pattern is
sketched below).
Test result:
https://github.com/vllm-project/vllm-ascend/actions/runs/21156263050/job/60841668755?5569
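A minimal, self-contained sketch of the failure mode, using a stand-in class rather than vLLM's actual `WorkerProc`:

```python
class WorkerProc:
    """Stand-in for vLLM's WorkerProc after the upstream change made
    `is_driver_worker` a required constructor argument."""

    def __init__(self, rank: int, is_driver_worker: bool):
        self.rank = rank
        self.is_driver_worker = is_driver_worker


# Call sites that omitted the argument started raising:
#   TypeError: WorkerProc.__init__() missing 1 required positional
#   argument: 'is_driver_worker'
# so the fix is to pass it explicitly:
proc = WorkerProc(rank=0, is_driver_worker=True)
```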
3. ✅ Upgrade vllm commit to 0120 (148117ea2e689cd43df4be6892671a17cdae5833).
    1. Add the `skip_compiled` param in `set_forward_context` due to
       https://github.com/vllm-project/vllm/pull/30385.
    2. Modify `tests/ut/spec_decode/test_eagle_proposer.py` due to
       https://github.com/vllm-project/vllm/pull/24322, which changed the
       buffer sizing to `self.max_num_tokens =
       vllm_config.scheduler_config.max_num_batched_tokens + max_batch_size`
       (a quick check of the new sizing follows below).
    3. Modify UT import paths due to the refactor
       https://github.com/vllm-project/vllm/pull/32060.
Test result:
https://github.com/vllm-project/vllm-ascend/actions/runs/21204851770/job/60999046946
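A quick check of the new sizing against the values used in the attached UT; only the formula comes from the upstream change, and the mapping of `max_batch_size` to `max_num_seqs` is an assumption made for this example:

```python
# Values taken from the unit test's mocked scheduler_config.
max_num_batched_tokens = 1024  # scheduler_config.max_num_batched_tokens
max_batch_size = 32            # assumed here to track max_num_seqs

# Sizing introduced by vllm-project/vllm#24322.
max_num_tokens = max_num_batched_tokens + max_batch_size
print(max_num_tokens)  # 1056; the UT now asserts against
                       # proposer.max_num_tokens, not a hard-coded constant
```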
4. ✅ Upgrade vllm commit to 0121 (f23fb5a7c1b61350c5c40ca1115d3bf8cf2b8cc9).
    1. vLLM moved the `uses_mrope` check from the target to the draft
       model config, making `positions` and `mrope_positions` mutually
       exclusive. This broke vllm-ascend's direct `self.positions` access
       and tests that did not set `draft_model_config.uses_mrope`; see the
       sketch after this item.
       https://github.com/vllm-project/vllm/pull/32048
    2. Move `bs_to_padded_graph_size` from `CompilationConfig` to
       `CudagraphDispatcher` due to the refactor
       https://github.com/vllm-project/vllm/pull/30143.
    3. Remove the unused `maybe_setup_kv_connector` due to
       https://github.com/vllm-project/vllm/pull/32077.
Test result:
https://github.com/vllm-project/vllm-ascend/actions/runs/21217728738/job/61043738834
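A minimal sketch of the resulting dispatch, assuming the attribute names from the description (this shows the pattern, not the actual vllm-ascend code):

```python
def select_positions(proposer, speculative_config):
    """Choose the position buffer via the *draft* model config; after
    vllm-project/vllm#32048 the two buffers are mutually exclusive."""
    if speculative_config.draft_model_config.uses_mrope:
        return proposer.mrope_positions  # M-RoPE path (multimodal models)
    return proposer.positions  # plain 1-D position ids
```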
5. ✅ Upgrade vllm commit to 0122 (8ebf271bb6d1e7e9b1a55be73d755ef1a57dbbe5).
Update `FusedMoEParallelConfig` (which gained `enable_eplb`) and
`FusedMoEConfig` due to https://github.com/vllm-project/vllm/pull/32414
(a sketch of the new field follows below).
Test result:
https://github.com/vllm-project/vllm-ascend/actions/runs/21249922546/job/61148613054
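To show only the shape of the change, here is an illustrative stand-in; every field except `enable_eplb` is invented for this sketch and does not reflect vLLM's real definition:

```python
from dataclasses import dataclass


@dataclass
class FusedMoEParallelConfigSketch:
    """Illustrative stand-in for vLLM's FusedMoEParallelConfig."""
    tp_size: int
    dp_size: int
    ep_size: int
    enable_eplb: bool = False  # the newly added EPLB switch


# Construction sites in vllm-ascend now have to account for the new field.
cfg = FusedMoEParallelConfigSketch(tp_size=1, dp_size=1, ep_size=1,
                                   enable_eplb=False)
```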
6. ✅ Upgrade vllm commit to 0123 (dc917cceb877dfd13f98c538c4c96158047d98bd).
Set `temperature=0.0` explicitly, since the default temperature value was
removed in https://github.com/vllm-project/vllm/pull/32723 (see the
example below).
Test result:
https://github.com/vllm-project/vllm-ascend/actions/runs/21280796875
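A minimal example of the explicit setting, assuming the affected call sites use vLLM's `SamplingParams`:

```python
from vllm import SamplingParams

# With no default temperature to rely on, greedy decoding is requested
# explicitly.
params = SamplingParams(temperature=0.0)
```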
### Does this PR introduce _any_ user-facing change?
### How was this patch tested?
- vLLM version: v0.14.0
- vLLM main: d68209402d
---------
Signed-off-by: wjunLu <wjunlu217@gmail.com>
Signed-off-by: Meihan-chen <jcccx.cmh@gmail.com>
Co-authored-by: wjunLu <wjunlu217@gmail.com>
The updated `tests/ut/spec_decode/test_eagle_proposer.py` (447 lines, Python):
```python
from unittest.mock import MagicMock, patch

import numpy as np
import torch
from vllm.config import CacheConfig, CompilationMode, CUDAGraphMode, VllmConfig

from tests.ut.base import TestBase
from vllm_ascend.ascend_config import init_ascend_config
from vllm_ascend.spec_decode.eagle_proposer import EagleProposer
from vllm_ascend.spec_decode.interface import SpecDcodeType


class TestEagleProposerInitialization(TestBase):

    def setUp(self):
        self.vllm_config = MagicMock(spec=VllmConfig)
        self.vllm_config.speculative_config = MagicMock()
        self.vllm_config.cache_config = MagicMock(spec=CacheConfig)
        self.vllm_config.scheduler_config = MagicMock()
        self.vllm_config.model_config = MagicMock()
        self.device = torch.device("cpu")
        self.runner = MagicMock()
        self.runner.pin_memory = False

        self.vllm_config.cache_config.block_size = 16
        self.vllm_config.scheduler_config.max_num_batched_tokens = 1024
        self.vllm_config.scheduler_config.max_num_seqs = 32
        self.vllm_config.model_config.dtype = torch.float16
        self.vllm_config.model_config.max_model_len = 2048
        self.vllm_config.model_config.uses_mrope = False
        self.vllm_config.parallel_config.tensor_parallel_size = 1
        self.vllm_config.speculative_config.draft_tensor_parallel_size = 1
        self.vllm_config.speculative_config.num_speculative_tokens = 2
        self.vllm_config.speculative_config.speculative_token_tree = str([
            (i + 1) * (0, ) for i in range(2)
        ])
        self.vllm_config.additional_config = None

        self.mock_cpugpubuffer = patch(
            "vllm.v1.spec_decode.eagle.CpuGpuBuffer")
        self.mock_cpugpubuffer.start()
        self.mock_supports_multimodal_inputs = patch(
            "vllm.multimodal.registry.MultiModalRegistry.supports_multimodal_inputs"
        )
        self.mock_supports_multimodal_inputs.start()

    def tearDown(self):
        self.mock_cpugpubuffer.stop()
        self.mock_supports_multimodal_inputs.stop()

    def test_initialization_eagle_graph(self):
        self.vllm_config.speculative_config.method = "eagle"
        self.vllm_config.speculative_config.draft_model_config.get_hidden_size.return_value = 4096
        self.vllm_config.speculative_config.draft_model_config.uses_mrope = False
        self.vllm_config.compilation_config.mode = CompilationMode.VLLM_COMPILE
        self.vllm_config.model_config.enforce_eager = False
        self.vllm_config.model_config.uses_mrope = False
        self.vllm_config.speculative_config.enforce_eager = False
        self.vllm_config.scheduler_config.async_scheduling = False
        init_ascend_config(self.vllm_config)

        proposer = EagleProposer(vllm_config=self.vllm_config,
                                 device=self.device,
                                 runner=self.runner)

        self.assertEqual(proposer.hidden_size, 4096)
        self.assertTrue(proposer.use_cuda_graph)

        expected_max_num_tokens = proposer.max_num_tokens
        self.assertEqual(proposer.input_ids.shape, (expected_max_num_tokens, ))
        self.assertEqual(proposer.positions.shape, (expected_max_num_tokens, ))
        self.assertEqual(proposer.hidden_states.shape,
                         (expected_max_num_tokens, 4096))
        self.assertEqual(proposer.arange.shape, (expected_max_num_tokens, ))

    def test_initialization_eagle3_enforce_eager(self):
        self.vllm_config.speculative_config.method = "eagle3"
        self.vllm_config.speculative_config.draft_model_config.get_hidden_size.return_value = 2048
        self.vllm_config.compilation_config.mode = CompilationMode.NONE
        self.vllm_config.model_config.enforce_eager = True
        init_ascend_config(self.vllm_config)

        proposer = EagleProposer(vllm_config=self.vllm_config,
                                 device=self.device,
                                 runner=self.runner)

        self.assertEqual(proposer.hidden_size, 2048)
        self.assertFalse(proposer.use_cuda_graph)
        expected_max_num_tokens = proposer.max_num_tokens
        self.assertEqual(proposer.hidden_states.shape,
                         (expected_max_num_tokens, 2048))

    def test_initialization_eagle3_full_graph_async(self):
        self.vllm_config.speculative_config.method = "eagle3"
        self.vllm_config.speculative_config.draft_model_config.get_hidden_size.return_value = 2048
        self.vllm_config.compilation_config.mode = CompilationMode.VLLM_COMPILE
        self.vllm_config.model_config.enforce_eager = False
        self.vllm_config.speculative_config.enforce_eager = False
        self.vllm_config.scheduler_config.async_scheduling = True
        init_ascend_config(self.vllm_config)

        proposer = EagleProposer(vllm_config=self.vllm_config,
                                 device=self.device,
                                 runner=self.runner)

        self.assertEqual(proposer.hidden_size, 2048)
        self.assertTrue(proposer.use_cuda_graph)
        expected_max_num_tokens = proposer.max_num_tokens
        self.assertEqual(proposer.hidden_states.shape,
                         (expected_max_num_tokens, 2048))

    def test_initialization_mtp_full_graph_async(self):
        self.vllm_config.speculative_config.method = "mtp"
        self.vllm_config.speculative_config.draft_model_config.get_hidden_size.return_value = 2048
        self.vllm_config.compilation_config.mode = CompilationMode.VLLM_COMPILE
        self.vllm_config.model_config.enforce_eager = False
        self.vllm_config.speculative_config.enforce_eager = False
        self.vllm_config.scheduler_config.async_scheduling = True
        init_ascend_config(self.vllm_config)

        proposer = EagleProposer(vllm_config=self.vllm_config,
                                 device=self.device,
                                 runner=self.runner)

        self.assertEqual(proposer.hidden_size, 2048)
        self.assertFalse(proposer.use_cuda_graph)
        expected_max_num_tokens = proposer.max_num_tokens
        self.assertEqual(proposer.hidden_states.shape,
                         (expected_max_num_tokens, 2048))


class TestEagleProposerLoadModel(TestBase):

    def setUp(self):
        self.vllm_config = MagicMock(spec=VllmConfig)
        self.vllm_config.speculative_config = MagicMock()
        self.vllm_config.speculative_config.method = "eagle"
        self.device = torch.device("cpu")
        self.runner = MagicMock()
        self.runner.pin_memory = False

        self.vllm_config.cache_config.block_size = 16
        self.vllm_config.scheduler_config.max_num_batched_tokens = 1024
        self.vllm_config.scheduler_config.max_num_seqs = 32
        self.vllm_config.model_config.dtype = torch.float16
        self.vllm_config.model_config.max_model_len = 2048
        self.vllm_config.model_config.uses_mrope = False
        self.vllm_config.parallel_config.tensor_parallel_size = 1
        self.vllm_config.speculative_config.draft_tensor_parallel_size = 1
        self.vllm_config.speculative_config.num_speculative_tokens = 2
        self.vllm_config.speculative_config.speculative_token_tree = str([
            (i + 1) * (0, ) for i in range(2)
        ])
        self.vllm_config.additional_config = None
        init_ascend_config(self.vllm_config)

        self.mock_cpugpubuffer = patch(
            "vllm.v1.spec_decode.eagle.CpuGpuBuffer")
        self.mock_cpugpubuffer.start()
        self.mock_supports_multimodal_inputs = patch(
            "vllm.multimodal.registry.MultiModalRegistry.supports_multimodal_inputs"
        )
        self.mock_supports_multimodal_inputs.start()
        self.proposer = EagleProposer(vllm_config=self.vllm_config,
                                      device=self.device,
                                      runner=self.runner)

    def tearDown(self):
        self.mock_cpugpubuffer.stop()
        self.mock_supports_multimodal_inputs.stop()

    @patch(
        "vllm_ascend.spec_decode.eagle_proposer.get_layers_from_vllm_config")
    @patch("vllm_ascend.spec_decode.eagle_proposer.get_model")
    @patch("vllm_ascend.spec_decode.eagle_proposer.get_pp_group")
    def test_load_model_pp1(self, mock_pp_group, mock_get_model,
                            mock_get_layers):
        mock_pp_group.return_value.world_size = 1
        mock_target_layer1 = MagicMock()
        mock_target_layer2 = MagicMock()
        mock_draft_layer1 = MagicMock()
        mock_draft_layer3 = MagicMock()
        mock_get_layers.side_effect = [{
            "layer1": mock_target_layer1,
            "layer2": mock_target_layer2
        }, {}, {}, {
            "layer1": mock_draft_layer1,
            "layer3": mock_draft_layer3
        }]

        weight = torch.zeros(0)

        mock_model = MagicMock()
        mock_model.supports_multimodal = False
        mock_model.lm_head = MagicMock()
        mock_model.multimodal_cpu_fields = None
        mock_model.merge_by_field_config = None
        mock_model.model.embed_tokens = MagicMock()
        mock_model.model.embed_tokens.weight = weight

        self.proposer.name = SpecDcodeType.EAGLE
        mock_get_model.return_value = MagicMock()
        mock_get_model.return_value.model.embed_tokens.weight = weight

        self.proposer.load_model(mock_model)
        mock_get_model.assert_called_once()
        self.assertEqual(self.proposer.attn_layer_names, ["layer3"])
        self.assertIs(self.proposer.model.model.embed_tokens,
                      mock_model.model.embed_tokens)

    @patch(
        "vllm_ascend.spec_decode.eagle_proposer.get_layers_from_vllm_config")
    @patch("vllm_ascend.spec_decode.eagle_proposer.get_model")
    @patch("vllm_ascend.spec_decode.eagle_proposer.get_pp_group")
    def test_load_model_pp_gt1(self, mock_pp_group, mock_get_model,
                               mock_get_layers):
        mock_pp_group.return_value.world_size = 2
        mock_target_layer1 = MagicMock()
        mock_draft_layer2 = MagicMock()

        mock_get_layers.side_effect = [{
            "layer1": mock_target_layer1
        }, {}, {}, {
            "layer2": mock_draft_layer2
        }]

        mock_model = MagicMock()
        original_embed = MagicMock()
        mock_model.multimodal_cpu_fields = None
        mock_model.merge_by_field_config = None
        mock_get_model.return_value = MagicMock(model=MagicMock(
            embed_tokens=original_embed))

        self.proposer.load_model(mock_model)

        self.assertIsNot(self.proposer.model.model.embed_tokens,
                         mock_model.model.embed_tokens)
        self.assertEqual(self.proposer.attn_layer_names, ["layer2"])

    @patch(
        "vllm_ascend.spec_decode.eagle_proposer.get_layers_from_vllm_config")
    @patch("vllm_ascend.spec_decode.eagle_proposer.get_model")
    @patch("vllm_ascend.spec_decode.eagle_proposer.get_pp_group")
    @patch("vllm_ascend.spec_decode.eagle_proposer.supports_multimodal")
    def test_load_model_multimodal(self, mock_supports_multi, mock_pp_group,
                                   mock_get_model, mock_get_layers):
        mock_model = MagicMock()
        mock_model.get_language_model.return_value.lm_head = MagicMock()
        mock_supports_multi.return_value = True
        original_embed = MagicMock()
        mock_get_model.return_value = MagicMock(model=MagicMock(
            embed_tokens=original_embed))

        mock_target_layer1 = MagicMock()
        mock_draft_layer2 = MagicMock()

        mock_get_layers.side_effect = [{
            "layer1": mock_target_layer1
        }, {}, {}, {
            "layer2": mock_draft_layer2
        }]
        mock_pp_group.return_value.world_size = 2

        self.proposer.model = MagicMock()
        self.proposer.name = SpecDcodeType.EAGLE

        self.proposer.load_model(mock_model)
        self.assertEqual(mock_model.get_language_model.call_count, 2)
        self.assertIs(self.proposer.model.lm_head,
                      mock_model.get_language_model.return_value.lm_head)


class TestEagleProposerDummyRun(TestBase):

    def setUp(self):
        self.vllm_config = MagicMock(spec=VllmConfig)
        self.vllm_config.speculative_config = MagicMock()
        self.vllm_config.speculative_config.num_speculative_tokens = 4
        self.device = torch.device("cpu")
        self.runner = MagicMock()
        self.runner.pcp_size = 1
        self.runner.dcp_size = 1
        self.runner.pin_memory = False
        self.runner._sync_metadata_across_dp.return_value = (8, torch.tensor([8]), False)

        self.vllm_config.cache_config.block_size = 16
        self.vllm_config.scheduler_config.max_num_batched_tokens = 1024
        self.vllm_config.scheduler_config.max_num_seqs = 32
        self.vllm_config.model_config.dtype = torch.float16
        self.vllm_config.model_config.max_model_len = 2048
        self.vllm_config.model_config.uses_mrope = False
        self.vllm_config.model_config.use_mla = False
        self.vllm_config.parallel_config.tensor_parallel_size = 1
        self.vllm_config.speculative_config.draft_tensor_parallel_size = 1
        self.vllm_config.speculative_config.speculative_token_tree = str([
            (i + 1) * (0, ) for i in range(4)
        ])
        self.vllm_config.additional_config = None
        init_ascend_config(self.vllm_config)

        self.mock_cpugpubuffer = patch(
            "vllm.v1.spec_decode.eagle.CpuGpuBuffer")
        self.mock_cpugpubuffer.start()
        self.mock_supports_multimodal_inputs = patch(
            "vllm.multimodal.registry.MultiModalRegistry.supports_multimodal_inputs"
        )
        self.mock_supports_multimodal_inputs.start()
        self.proposer = EagleProposer(vllm_config=self.vllm_config,
                                      device=self.device,
                                      runner=self.runner)
        self.proposer.model = MagicMock()
        self.proposer._runnable = MagicMock()
        self.proposer.update_stream = MagicMock()

    def tearDown(self):
        self.mock_cpugpubuffer.stop()
        self.mock_supports_multimodal_inputs.stop()

    # cpu does not support parallel-group, let alone `sp`
    @patch("vllm_ascend.spec_decode.eagle_proposer.get_forward_context",
           **{"return_value.sp_enabled": False})
    @patch("vllm_ascend.spec_decode.eagle_proposer.set_ascend_forward_context")
    def test_dummy_run_basic(self, mock_context, mock_get_context):
        num_tokens = 32
        with_prefill = False

        # cpu does not support `torch.ops.vllm.maybe_pad_and_reduce`
        self.proposer.enable_shared_expert_dp = False
        self.proposer.dummy_run(num_tokens=num_tokens,
                                with_prefill=with_prefill)

        self.assertTrue(self.proposer._runnable.call_count == 1)

    # cpu does not support parallel-group, let alone `sp`
    @patch("vllm_ascend.spec_decode.eagle_proposer.get_forward_context",
           **{"return_value.sp_enabled": False})
    @patch("vllm_ascend.spec_decode.eagle_proposer.set_ascend_forward_context")
    def test_dummy_run_with_prefill(self, mock_context, mock_get_context):
        mock_context.return_value.__enter__.return_value = None
        # cpu does not support `torch.ops.vllm.maybe_pad_and_reduce`
        self.proposer.enable_shared_expert_dp = False
        self.proposer.dummy_run(num_tokens=64, with_prefill=True, num_reqs=4)
        self.assertTrue(self.proposer._runnable.call_count == 1)

    @patch("vllm_ascend.spec_decode.eagle_proposer.update_full_graph_params")
    @patch("vllm_ascend.spec_decode.eagle_proposer.get_forward_context")
    @patch("vllm_ascend.spec_decode.eagle_proposer.set_ascend_forward_context")
    def test_dummy_run_in_graph_capture(self, mock_context, mock_get_context,
                                        mock_update_full_graph_params):
        last_use_cuda_graph = self.proposer.use_cuda_graph
        mock_return_context = MagicMock()
        mock_return_context.cudagraph_runtime_mode = CUDAGraphMode.FULL
        mock_return_context.capturing = True
        # cpu does not support parallel-group, let alone `sp`
        mock_return_context.sp_enabled = False
        mock_get_context.return_value = mock_return_context
        self.proposer.use_cuda_graph = True
        # cpu does not support `torch.ops.vllm.maybe_pad_and_reduce`
        self.proposer.enable_shared_expert_dp = False
        self.proposer.dummy_run(num_tokens=64,
                                in_graph_capturing=True,
                                aclgraph_runtime_mode=CUDAGraphMode.FULL)
        self.assertTrue(self.proposer._runnable.call_count == 1)
        mock_update_full_graph_params.assert_not_called()
        self.proposer.use_cuda_graph = last_use_cuda_graph

    @patch("vllm_ascend.spec_decode.eagle_proposer.update_full_graph_params")
    @patch("vllm_ascend.spec_decode.eagle_proposer.get_forward_context")
    @patch("vllm_ascend.spec_decode.eagle_proposer.set_ascend_forward_context")
    def test_dummy_run_in_graph_run(self, mock_context, mock_get_context,
                                    mock_update_full_graph_params):
        last_use_cuda_graph = self.proposer.use_cuda_graph
        mock_return_context = MagicMock()
        mock_return_context.cudagraph_runtime_mode = CUDAGraphMode.FULL
        mock_return_context.capturing = False
        # cpu does not support parallel-group, let alone `sp`
        mock_return_context.sp_enabled = False
        mock_get_context.return_value = mock_return_context
        self.proposer.use_cuda_graph = True
        # cpu does not support `torch.ops.vllm.maybe_pad_and_reduce`
        self.proposer.enable_shared_expert_dp = False
        self.proposer.dummy_run(num_tokens=64,
                                in_graph_capturing=False,
                                aclgraph_runtime_mode=CUDAGraphMode.FULL)
        self.assertTrue(self.proposer._runnable.call_count == 1)
        self.assertTrue(mock_update_full_graph_params.call_count == 1)
        self.proposer.use_cuda_graph = last_use_cuda_graph


class TestEagleProposerHelperMethods(TestBase):

    # TODO: Can add some tests about prepare_next_token_ids in future.

    def setUp(self):
        self.vllm_config = MagicMock(spec=VllmConfig)
        self.vllm_config.scheduler_config = MagicMock(max_num_seqs=3)
        self.device = torch.device("cpu")
        self.runner = MagicMock()
        self.runner.input_batch = MagicMock()
        self.runner.input_batch.req_ids = [0, 1, 2]
        self.runner.arange_np = np.arange(10)
        self.runner.input_batch.num_reqs = 3
        self.runner.pin_memory = False

        self.vllm_config.cache_config.block_size = 16
        self.vllm_config.scheduler_config.max_num_batched_tokens = 1024
        self.vllm_config.scheduler_config.max_num_seqs = 32
        self.vllm_config.model_config.dtype = torch.float16
        self.vllm_config.model_config.max_model_len = 2048
        self.vllm_config.model_config.uses_mrope = False
        self.vllm_config.parallel_config.tensor_parallel_size = 1
        self.vllm_config.speculative_config.draft_tensor_parallel_size = 1
        self.vllm_config.speculative_config.num_speculative_tokens = 2
        self.vllm_config.speculative_config.speculative_token_tree = str([
            (i + 1) * (0, ) for i in range(2)
        ])
        self.vllm_config.additional_config = None
        init_ascend_config(self.vllm_config)

        self.mock_cpugpubuffer = patch(
            "vllm.v1.spec_decode.eagle.CpuGpuBuffer")
        self.mock_cpugpubuffer.start()
        self.mock_supports_multimodal_inputs = patch(
            "vllm.multimodal.registry.MultiModalRegistry.supports_multimodal_inputs"
        )
        self.mock_supports_multimodal_inputs.start()
        self.proposer = EagleProposer(vllm_config=self.vllm_config,
                                      device=self.device,
                                      runner=self.runner)

    def tearDown(self):
        self.mock_cpugpubuffer.stop()
        self.mock_supports_multimodal_inputs.stop()

    # TODO: This is equivalent to disable_padded_drafter_batch=True.
    # We need to add a test_prepare_inputs_padded in future.
    def test_prepare_inputs(self):
        self.proposer.token_arange_np = np.arange(10)
        mock_attn = MagicMock()
        mock_attn.slot_mapping = torch.tensor([0, 1, 2, 3, 4, 5])
        num_rejected = torch.tensor([1, 0, 1], device=self.device)
        mock_return_attn = MagicMock()

        with patch.object(self.proposer,
                          'prepare_inputs',
                          return_value=(mock_return_attn,
                                        torch.tensor([1, 2, 4]))):
            return_attn, indices = self.proposer.prepare_inputs(
                mock_attn, num_rejected)
            self.assertEqual(indices.tolist(), [1, 2, 4])
```