"""Unit tests for ``EagleProposer``, run on CPU against mocked vLLM config."""

from unittest.mock import MagicMock, patch

import numpy as np
import torch
from vllm.config import CacheConfig, CompilationMode, CUDAGraphMode, VllmConfig

from tests.ut.base import TestBase
from vllm_ascend.ascend_config import init_ascend_config
from vllm_ascend.spec_decode.eagle_proposer import EagleProposer
from vllm_ascend.spec_decode.interface import SpecDcodeType

# Patch targets shared by every test class in this module.
_CPU_GPU_BUFFER_TARGET = "vllm.v1.spec_decode.eagle.CpuGpuBuffer"
_SUPPORTS_MM_INPUTS_TARGET = (
    "vllm.multimodal.registry.MultiModalRegistry.supports_multimodal_inputs")


def _start_common_patchers(test_case):
    """Start the two module-wide patchers and register their cleanup.

    The patcher objects are stored on ``test_case`` as ``mock_cpugpubuffer``
    and ``mock_supports_multimodal_inputs``.

    ``addCleanup`` is used instead of ``tearDown`` because unittest does not
    run ``tearDown`` when ``setUp`` raises; cleanups registered before the
    failure still run, so a failing ``setUp`` cannot leak active patches into
    later tests.
    """
    for attr_name, target in (
        ("mock_cpugpubuffer", _CPU_GPU_BUFFER_TARGET),
        ("mock_supports_multimodal_inputs", _SUPPORTS_MM_INPUTS_TARGET),
    ):
        patcher = patch(target)
        setattr(test_case, attr_name, patcher)
        patcher.start()
        test_case.addCleanup(patcher.stop)


class TestEagleProposerInitialization(TestBase):
    """Checks the attributes ``EagleProposer.__init__`` derives from config."""

    def setUp(self):
        """Build a mocked ``VllmConfig`` and start the shared patchers."""
        self.vllm_config = MagicMock(spec=VllmConfig)
        self.vllm_config.speculative_config = MagicMock()
        self.vllm_config.cache_config = MagicMock(spec=CacheConfig)
        self.vllm_config.scheduler_config = MagicMock()
        self.vllm_config.model_config = MagicMock()
        self.device = torch.device("cpu")
        self.runner = MagicMock()

        self.vllm_config.cache_config.block_size = 16
        self.vllm_config.scheduler_config.max_num_batched_tokens = 1024
        self.vllm_config.scheduler_config.max_num_seqs = 32
        self.vllm_config.model_config.dtype = torch.float16
        self.vllm_config.model_config.max_model_len = 2048
        self.vllm_config.model_config.uses_mrope = False
        self.vllm_config.speculative_config.num_speculative_tokens = 2
        self.vllm_config.speculative_config.speculative_token_tree = str([
            (i + 1) * (0, ) for i in range(2)
        ])
        self.vllm_config.additional_config = None

        _start_common_patchers(self)

    def test_initialization_eagle_graph(self):
        """method="eagle" with VLLM_COMPILE and eager disabled uses graphs."""
        self.vllm_config.speculative_config.method = "eagle"
        self.vllm_config.speculative_config.draft_model_config.get_hidden_size.return_value = 4096
        self.vllm_config.compilation_config.mode = CompilationMode.VLLM_COMPILE
        self.vllm_config.model_config.enforce_eager = False
        self.vllm_config.speculative_config.enforce_eager = False
        self.vllm_config.scheduler_config.async_scheduling = False
        init_ascend_config(self.vllm_config)

        proposer = EagleProposer(vllm_config=self.vllm_config,
                                 device=self.device,
                                 runner=self.runner)

        self.assertEqual(proposer.hidden_size, 4096)
        self.assertTrue(proposer.use_cuda_graph)

        # Buffer sizes are expected to follow max_num_batched_tokens (1024).
        self.assertEqual(proposer.input_ids.shape, (1024, ))
        self.assertEqual(proposer.positions.shape, (1024, ))
        self.assertEqual(proposer.hidden_states.shape, (1024, 4096))
        self.assertEqual(proposer.arange.shape, (1024, ))

    def test_initialization_eagle3_enforce_eager(self):
        """method="eagle3" with enforce_eager=True must not use graphs."""
        self.vllm_config.speculative_config.method = "eagle3"
        self.vllm_config.speculative_config.draft_model_config.get_hidden_size.return_value = 2048
        self.vllm_config.compilation_config.mode = CompilationMode.NONE
        self.vllm_config.model_config.enforce_eager = True
        init_ascend_config(self.vllm_config)

        proposer = EagleProposer(vllm_config=self.vllm_config,
                                 device=self.device,
                                 runner=self.runner)

        self.assertEqual(proposer.hidden_size, 2048)
        self.assertFalse(proposer.use_cuda_graph)
        self.assertEqual(proposer.hidden_states.shape, (1024, 2048))

    def test_initialization_eagle3_full_graph_async(self):
        """method="eagle3" with async scheduling is expected to skip graphs."""
        self.vllm_config.speculative_config.method = "eagle3"
        self.vllm_config.speculative_config.draft_model_config.get_hidden_size.return_value = 2048
        self.vllm_config.compilation_config.mode = CompilationMode.VLLM_COMPILE
        self.vllm_config.model_config.enforce_eager = False
        self.vllm_config.speculative_config.enforce_eager = False
        self.vllm_config.scheduler_config.async_scheduling = True
        init_ascend_config(self.vllm_config)

        proposer = EagleProposer(vllm_config=self.vllm_config,
                                 device=self.device,
                                 runner=self.runner)

        self.assertEqual(proposer.hidden_size, 2048)
        self.assertFalse(proposer.use_cuda_graph)
        self.assertEqual(proposer.hidden_states.shape, (1024, 2048))


class TestEagleProposerLoadModel(TestBase):
    """Exercises ``EagleProposer.load_model`` with mocked loader helpers."""

    def setUp(self):
        """Build the mocked config, start patchers and create the proposer."""
        self.vllm_config = MagicMock(spec=VllmConfig)
        self.vllm_config.speculative_config = MagicMock()
        self.vllm_config.speculative_config.method = "eagle"
        self.device = torch.device("cpu")
        self.runner = MagicMock()

        self.vllm_config.cache_config.block_size = 16
        self.vllm_config.scheduler_config.max_num_batched_tokens = 1024
        self.vllm_config.scheduler_config.max_num_seqs = 32
        self.vllm_config.model_config.dtype = torch.float16
        self.vllm_config.model_config.max_model_len = 2048
        self.vllm_config.model_config.uses_mrope = False
        self.vllm_config.speculative_config.num_speculative_tokens = 2
        self.vllm_config.speculative_config.speculative_token_tree = str([
            (i + 1) * (0, ) for i in range(2)
        ])
        self.vllm_config.additional_config = None
        init_ascend_config(self.vllm_config)

        _start_common_patchers(self)

        self.proposer = EagleProposer(vllm_config=self.vllm_config,
                                      device=self.device,
                                      runner=self.runner)

    @patch(
        "vllm_ascend.spec_decode.eagle_proposer.get_layers_from_vllm_config")
    @patch("vllm_ascend.spec_decode.eagle_proposer.get_model")
    @patch("vllm_ascend.spec_decode.eagle_proposer.get_pp_group")
    def test_load_model_pp1(self, mock_pp_group, mock_get_model,
                            mock_get_layers):
        """With PP world size 1 the draft model reuses the target embedding."""
        mock_pp_group.return_value.world_size = 1
        mock_target_layer1 = MagicMock()
        mock_target_layer2 = MagicMock()
        mock_draft_layer1 = MagicMock()
        mock_draft_layer3 = MagicMock()
        # One dict per successive get_layers_from_vllm_config call.
        mock_get_layers.side_effect = [{
            "layer1": mock_target_layer1,
            "layer2": mock_target_layer2
        }, {}, {}, {
            "layer1": mock_draft_layer1,
            "layer3": mock_draft_layer3
        }]

        mock_model = MagicMock()
        mock_model.model.embed_tokens = MagicMock()
        mock_model.lm_head = MagicMock()
        mock_model.multimodal_cpu_fields = None
        mock_model.merge_by_field_config = None
        mock_get_model.return_value = MagicMock()
        self.proposer.name = SpecDcodeType.EAGLE

        self.proposer.load_model(mock_model)

        mock_get_model.assert_called_once()
        # "layer3" is the only name in the last dict but not in the first.
        self.assertEqual(self.proposer.attn_layer_name, ["layer3"])
        self.assertIs(self.proposer.model.model.embed_tokens,
                      mock_model.model.embed_tokens)

    @patch(
        "vllm_ascend.spec_decode.eagle_proposer.get_layers_from_vllm_config")
    @patch("vllm_ascend.spec_decode.eagle_proposer.get_model")
    @patch("vllm_ascend.spec_decode.eagle_proposer.get_pp_group")
    def test_load_model_pp_gt1(self, mock_pp_group, mock_get_model,
                               mock_get_layers):
        """With PP world size 2 the draft model keeps its own embedding."""
        mock_pp_group.return_value.world_size = 2
        mock_target_layer1 = MagicMock()
        mock_draft_layer2 = MagicMock()
        mock_get_layers.side_effect = [{
            "layer1": mock_target_layer1
        }, {}, {}, {
            "layer2": mock_draft_layer2
        }]

        mock_model = MagicMock()
        original_embed = MagicMock()
        mock_model.multimodal_cpu_fields = None
        mock_model.merge_by_field_config = None
        mock_get_model.return_value = MagicMock(model=MagicMock(
            embed_tokens=original_embed))

        self.proposer.load_model(mock_model)

        self.assertIsNot(self.proposer.model.model.embed_tokens,
                         mock_model.model.embed_tokens)
        self.assertEqual(self.proposer.attn_layer_name, ["layer2"])

    @patch(
        "vllm_ascend.spec_decode.eagle_proposer.get_layers_from_vllm_config")
    @patch("vllm_ascend.spec_decode.eagle_proposer.get_model")
    @patch("vllm_ascend.spec_decode.eagle_proposer.get_pp_group")
    @patch("vllm_ascend.spec_decode.eagle_proposer.supports_multimodal")
    def test_load_model_multimodal(self, mock_supports_multi, mock_pp_group,
                                   mock_get_model, mock_get_layers):
        """For a multimodal target, lm_head comes from get_language_model()."""
        mock_model = MagicMock()
        mock_model.get_language_model.return_value.lm_head = MagicMock()
        mock_supports_multi.return_value = True
        original_embed = MagicMock()
        mock_get_model.return_value = MagicMock(model=MagicMock(
            embed_tokens=original_embed))

        mock_target_layer1 = MagicMock()
        mock_draft_layer2 = MagicMock()
        mock_get_layers.side_effect = [{
            "layer1": mock_target_layer1
        }, {}, {}, {
            "layer2": mock_draft_layer2
        }]
        mock_pp_group.return_value.world_size = 2

        self.proposer.model = MagicMock()
        self.proposer.name = SpecDcodeType.EAGLE

        self.proposer.load_model(mock_model)

        mock_model.get_language_model.assert_called_once()
        self.assertIs(self.proposer.model.lm_head,
                      mock_model.get_language_model.return_value.lm_head)


class TestEagleProposerDummyRun(TestBase):
    """Exercises ``EagleProposer.dummy_run`` with a mocked draft model."""

    def setUp(self):
        """Create a proposer configured for four speculative tokens."""
        self.vllm_config = MagicMock(spec=VllmConfig)
        self.vllm_config.speculative_config = MagicMock()
        self.vllm_config.speculative_config.num_speculative_tokens = 4
        self.device = torch.device("cpu")
        self.runner = MagicMock()

        self.vllm_config.cache_config.block_size = 16
        self.vllm_config.scheduler_config.max_num_batched_tokens = 1024
        self.vllm_config.scheduler_config.max_num_seqs = 32
        self.vllm_config.model_config.dtype = torch.float16
        self.vllm_config.model_config.max_model_len = 2048
        self.vllm_config.model_config.uses_mrope = False
        self.vllm_config.speculative_config.speculative_token_tree = str([
            (i + 1) * (0, ) for i in range(4)
        ])
        self.vllm_config.additional_config = None
        init_ascend_config(self.vllm_config)

        _start_common_patchers(self)

        self.proposer = EagleProposer(vllm_config=self.vllm_config,
                                      device=self.device,
                                      runner=self.runner)
        self.proposer.model = MagicMock()
        self.proposer.update_stream = MagicMock()

    def _force_cuda_graph(self):
        """Set ``use_cuda_graph`` to True and restore it after the test.

        The restore is registered as a cleanup so it also runs when an
        assertion in the test body fails.
        """
        self.addCleanup(setattr, self.proposer, "use_cuda_graph",
                        self.proposer.use_cuda_graph)
        self.proposer.use_cuda_graph = True

    @patch("vllm_ascend.spec_decode.eagle_proposer.get_forward_context")
    @patch("vllm_ascend.spec_decode.eagle_proposer.set_ascend_forward_context")
    def test_dummy_run_basic(self, mock_context, mock_get_context):
        """A decode-only dummy run calls the model four times."""
        num_tokens = 32
        with_prefill = False

        # cpu does not support `torch.ops.vllm.maybe_pad_and_reduce`
        self.proposer.enable_shared_expert_dp = False
        self.proposer.dummy_run(num_tokens=num_tokens,
                                with_prefill=with_prefill)

        # Four calls matches num_speculative_tokens=4 set in setUp.
        self.assertEqual(self.proposer.model.call_count, 4)

    @patch("vllm_ascend.spec_decode.eagle_proposer.get_forward_context")
    @patch("vllm_ascend.spec_decode.eagle_proposer.set_ascend_forward_context")
    def test_dummy_run_with_prefill(self, mock_context, mock_get_context):
        """A dummy run with prefill also calls the model four times."""
        mock_context.return_value.__enter__.return_value = None

        # cpu does not support `torch.ops.vllm.maybe_pad_and_reduce`
        self.proposer.enable_shared_expert_dp = False
        self.proposer.dummy_run(num_tokens=64, with_prefill=True, num_reqs=4)

        self.assertEqual(self.proposer.model.call_count, 4)

    @patch("vllm_ascend.spec_decode.eagle_proposer.update_attn_params")
    @patch("vllm_ascend.spec_decode.eagle_proposer.get_forward_context")
    @patch("vllm_ascend.spec_decode.eagle_proposer.set_ascend_forward_context")
    def test_dummy_run_in_graph_capture(self, mock_context, mock_get_context,
                                        mock_update_attn_params):
        """While capturing a FULL graph, attention params are not updated."""
        mock_return_context = MagicMock()
        mock_return_context.cudagraph_runtime_mode = CUDAGraphMode.FULL
        mock_return_context.capturing = True
        mock_get_context.return_value = mock_return_context
        self._force_cuda_graph()

        # cpu does not support `torch.ops.vllm.maybe_pad_and_reduce`
        self.proposer.enable_shared_expert_dp = False
        self.proposer.dummy_run(num_tokens=64,
                                in_graph_capturing=True,
                                aclgraph_runtime_mode=CUDAGraphMode.FULL)

        self.assertEqual(self.proposer.model.call_count, 4)
        mock_update_attn_params.assert_not_called()

    @patch("vllm_ascend.spec_decode.eagle_proposer.update_attn_params")
    @patch("vllm_ascend.spec_decode.eagle_proposer.get_forward_context")
    @patch("vllm_ascend.spec_decode.eagle_proposer.set_ascend_forward_context")
    def test_dummy_run_in_graph_run(self, mock_context, mock_get_context,
                                    mock_update_attn_params):
        """Outside capture in FULL mode, attention params update four times."""
        mock_return_context = MagicMock()
        mock_return_context.cudagraph_runtime_mode = CUDAGraphMode.FULL
        mock_return_context.capturing = False
        mock_get_context.return_value = mock_return_context
        self._force_cuda_graph()

        # cpu does not support `torch.ops.vllm.maybe_pad_and_reduce`
        self.proposer.enable_shared_expert_dp = False
        self.proposer.dummy_run(num_tokens=64,
                                in_graph_capturing=False,
                                aclgraph_runtime_mode=CUDAGraphMode.FULL)

        self.assertEqual(self.proposer.model.call_count, 4)
        self.assertEqual(mock_update_attn_params.call_count, 4)


class TestEagleProposerHelperMethods(TestBase):
    """Tests for ``EagleProposer`` helper methods."""

    # TODO: Can add some tests about prepare_next_token_ids in future.

    def setUp(self):
        """Create a proposer whose runner mock reports a three-request batch."""
        self.vllm_config = MagicMock(spec=VllmConfig)
        # max_num_seqs=3 here is overwritten with 32 further down.
        self.vllm_config.scheduler_config = MagicMock(max_num_seqs=3)
        self.device = torch.device("cpu")
        self.runner = MagicMock()
        self.runner.input_batch = MagicMock()
        self.runner.input_batch.req_ids = [0, 1, 2]
        self.runner.arange_np = np.arange(10)
        self.runner.input_batch.num_reqs = 3

        self.vllm_config.cache_config.block_size = 16
        self.vllm_config.scheduler_config.max_num_batched_tokens = 1024
        self.vllm_config.scheduler_config.max_num_seqs = 32
        self.vllm_config.model_config.dtype = torch.float16
        self.vllm_config.model_config.max_model_len = 2048
        self.vllm_config.model_config.uses_mrope = False
        self.vllm_config.speculative_config.num_speculative_tokens = 2
        self.vllm_config.speculative_config.speculative_token_tree = str([
            (i + 1) * (0, ) for i in range(2)
        ])
        self.vllm_config.additional_config = None
        init_ascend_config(self.vllm_config)

        _start_common_patchers(self)

        self.proposer = EagleProposer(vllm_config=self.vllm_config,
                                      device=self.device,
                                      runner=self.runner)

    # TODO: This is equivalent to disable_padded_drafter_batch=True.
    # We need to add a test_prepare_inputs_padded in future.
    def test_prepare_inputs(self):
        """Call ``prepare_inputs`` and check the values it returns.

        NOTE(review): ``prepare_inputs`` itself is replaced by
        ``patch.object`` below, so this test only verifies the mock's
        canned return value and never runs the real implementation. It
        should be rewritten to call the real method once suitable inputs
        are available.
        """
        self.proposer.token_arange_np = np.arange(10)
        mock_attn = MagicMock()
        mock_attn.slot_mapping = torch.tensor([0, 1, 2, 3, 4, 5])
        num_rejected = torch.tensor([1, 0, 1], device=self.device)
        mock_return_attn = MagicMock()

        with patch.object(self.proposer,
                          'prepare_inputs',
                          return_value=(mock_return_attn,
                                        torch.tensor([1, 2, 4]))):
            return_attn, indices = self.proposer.prepare_inputs(
                mock_attn, num_rejected)

        self.assertIs(return_attn, mock_return_attn)
        self.assertEqual(indices.tolist(), [1, 2, 4])