from unittest.mock import MagicMock, patch

import numpy as np
import torch
from vllm.config import CacheConfig, CompilationMode, VllmConfig

from tests.ut.base import TestBase
from vllm_ascend.spec_decode.eagle_proposer import EagleProposer
from vllm_ascend.spec_decode.interface import SpecDcodeType


class TestEagleProposerInitialization(TestBase):
    """Construction-time behavior of EagleProposer (buffer shapes, flags)."""

    def setUp(self):
        # A fully mocked VllmConfig with just the scalar knobs the
        # proposer's __init__ reads.
        self.vllm_config = MagicMock(spec=VllmConfig)
        self.vllm_config.speculative_config = MagicMock()
        self.vllm_config.cache_config = MagicMock(spec=CacheConfig)
        self.vllm_config.scheduler_config = MagicMock()
        self.vllm_config.model_config = MagicMock()
        self.device = torch.device("cpu")
        self.runner = MagicMock()
        self.vllm_config.cache_config.block_size = 16
        self.vllm_config.scheduler_config.max_num_batched_tokens = 1024
        self.vllm_config.scheduler_config.max_num_seqs = 32
        self.vllm_config.model_config.dtype = torch.float16
        self.vllm_config.model_config.max_model_len = 2048

    def test_initialization_eagle(self):
        """eagle + VLLM_COMPILE: graph mode on, buffers sized from config."""
        self.vllm_config.speculative_config.method = "eagle"
        self.vllm_config.speculative_config.draft_model_config.get_hidden_size.return_value = 4096
        self.vllm_config.compilation_config.mode = CompilationMode.VLLM_COMPILE
        self.vllm_config.model_config.enforce_eager = False

        proposer = EagleProposer(vllm_config=self.vllm_config,
                                 device=self.device,
                                 runner=self.runner)

        self.assertEqual(proposer.name, SpecDcodeType.EAGLE)
        self.assertEqual(proposer.block_size, 16)
        self.assertEqual(proposer.hidden_size, 4096)
        self.assertTrue(proposer.use_cuda_graph)
        # Flat token buffers sized by max_num_batched_tokens.
        self.assertEqual(proposer.input_ids.shape, (1024, ))
        self.assertEqual(proposer.positions.shape, (1024, ))
        self.assertEqual(proposer.hidden_states.shape, (1024, 4096))
        # arange covers max_num_seqs + 1 entries.
        self.assertEqual(proposer.arange.shape, (33, ))

    def test_initialization_eagle3(self):
        """eagle3 + eager mode: graph mode off, hidden size from draft cfg."""
        self.vllm_config.speculative_config.method = "eagle3"
        self.vllm_config.speculative_config.draft_model_config.get_hidden_size.return_value = 2048
        self.vllm_config.compilation_config.mode = CompilationMode.NONE
        self.vllm_config.model_config.enforce_eager = True

        proposer = EagleProposer(vllm_config=self.vllm_config,
                                 device=self.device,
                                 runner=self.runner)

        self.assertEqual(proposer.name, SpecDcodeType.EAGLE3)
        self.assertEqual(proposer.hidden_size, 2048)
        self.assertFalse(proposer.use_cuda_graph)
        self.assertEqual(proposer.hidden_states.shape, (1024, 2048))


class TestEagleProposerLoadModel(TestBase):
    """load_model(): layer-name discovery and embedding/lm_head sharing."""

    def setUp(self):
        self.vllm_config = MagicMock(spec=VllmConfig)
        self.vllm_config.speculative_config = MagicMock()
        self.vllm_config.speculative_config.method = "eagle"
        self.device = torch.device("cpu")
        self.runner = MagicMock()
        self.vllm_config.cache_config.block_size = 16
        self.vllm_config.scheduler_config.max_num_batched_tokens = 1024
        self.vllm_config.scheduler_config.max_num_seqs = 32
        self.vllm_config.model_config.dtype = torch.float16
        self.vllm_config.model_config.max_model_len = 2048
        self.proposer = EagleProposer(vllm_config=self.vllm_config,
                                      device=self.device,
                                      runner=self.runner)

    @patch(
        "vllm_ascend.spec_decode.eagle_proposer.get_layers_from_vllm_config")
    @patch("vllm_ascend.spec_decode.eagle_proposer.get_model")
    @patch("vllm_ascend.spec_decode.eagle_proposer.get_pp_group")
    def test_load_model_pp1(self, mock_pp_group, mock_get_model,
                            mock_get_layers):
        """With pp world_size == 1 the draft model shares target embeddings."""
        mock_pp_group.return_value.world_size = 1
        # Draft-only layer ("layer3") must be picked as attn_layer_name.
        mock_target_layers = {"layer1": MagicMock(), "layer2": MagicMock()}
        mock_draft_layers = {"layer1": MagicMock(), "layer3": MagicMock()}
        mock_get_layers.side_effect = [mock_target_layers, mock_draft_layers]

        mock_model = MagicMock()
        mock_model.model.embed_tokens = MagicMock()
        mock_model.lm_head = MagicMock()
        mock_get_model.return_value = MagicMock()

        self.proposer.name = SpecDcodeType.EAGLE
        self.proposer.load_model(mock_model)

        mock_get_model.assert_called_once()
        self.assertEqual(self.proposer.attn_layer_name, "layer3")
        # Embeddings are shared (same object) with the target model.
        self.assertIs(self.proposer.model.model.embed_tokens,
                      mock_model.model.embed_tokens)

    @patch(
        "vllm_ascend.spec_decode.eagle_proposer.get_layers_from_vllm_config")
    @patch("vllm_ascend.spec_decode.eagle_proposer.get_model")
    @patch("vllm_ascend.spec_decode.eagle_proposer.get_pp_group")
    def test_load_model_pp_gt1(self, mock_pp_group, mock_get_model,
                               mock_get_layers):
        """With pp world_size > 1 the draft keeps its own embeddings."""
        mock_pp_group.return_value.world_size = 2
        mock_target_layers = {"layer1": MagicMock()}
        mock_draft_layers = {"layer2": MagicMock()}
        mock_get_layers.side_effect = [mock_target_layers, mock_draft_layers]

        mock_model = MagicMock()
        original_embed = MagicMock()
        mock_get_model.return_value = MagicMock(model=MagicMock(
            embed_tokens=original_embed))

        self.proposer.load_model(mock_model)

        self.assertIsNot(self.proposer.model.model.embed_tokens,
                         mock_model.model.embed_tokens)
        self.assertEqual(self.proposer.attn_layer_name, "layer2")

    @patch(
        "vllm_ascend.spec_decode.eagle_proposer.get_layers_from_vllm_config")
    @patch("vllm_ascend.spec_decode.eagle_proposer.get_model")
    @patch("vllm_ascend.spec_decode.eagle_proposer.get_pp_group")
    @patch("vllm_ascend.spec_decode.eagle_proposer.supports_multimodal")
    def test_load_model_multimodal(self, mock_supports_multi, mock_pp_group,
                                   mock_get_model, mock_get_layers):
        """Multimodal targets: lm_head comes from the language sub-model."""
        mock_model = MagicMock()
        mock_model.get_language_model.return_value.lm_head = MagicMock()
        mock_supports_multi.return_value = True

        original_embed = MagicMock()
        mock_get_model.return_value = MagicMock(model=MagicMock(
            embed_tokens=original_embed))
        mock_target_layers = {"layer1": MagicMock()}
        mock_draft_layers = {"layer2": MagicMock()}
        mock_get_layers.side_effect = [mock_target_layers, mock_draft_layers]
        mock_pp_group.return_value.world_size = 2

        self.proposer.model = MagicMock()
        self.proposer.name = SpecDcodeType.EAGLE
        self.proposer.load_model(mock_model)

        mock_model.get_language_model.assert_called_once()
        self.assertIs(self.proposer.model.lm_head,
                      mock_model.get_language_model.return_value.lm_head)


class TestEagleProposerDummyRun(TestBase):
    """dummy_run(): forward-context setup and model invocation."""

    def setUp(self):
        self.vllm_config = MagicMock(spec=VllmConfig)
        self.vllm_config.speculative_config = MagicMock()
        self.device = torch.device("cpu")
        self.runner = MagicMock()
        self.runner._select_moe_comm_method.return_value = "alltoall"
        self.vllm_config.cache_config.block_size = 16
        self.vllm_config.scheduler_config.max_num_batched_tokens = 1024
        self.vllm_config.scheduler_config.max_num_seqs = 32
        self.vllm_config.model_config.dtype = torch.float16
        self.vllm_config.model_config.max_model_len = 2048
        self.proposer = EagleProposer(vllm_config=self.vllm_config,
                                      device=self.device,
                                      runner=self.runner)
        # Replace the (unloaded) draft model so dummy_run can call it.
        self.proposer.model = MagicMock()

    @patch("vllm_ascend.spec_decode.eagle_proposer.set_ascend_forward_context")
    def test_dummy_run_basic(self, mock_context):
        """A decode-only dummy run enters the forward context exactly once."""
        self.proposer.dummy_run(num_tokens=32, with_prefill=False)
        mock_context.assert_called_once()

    @patch("vllm_ascend.spec_decode.eagle_proposer.set_ascend_forward_context")
    def test_dummy_run_with_prefill(self, mock_context):
        """Prefill dummy run selects the MoE comm method and runs the model."""
        mock_context.return_value.__enter__.return_value = None
        self.proposer.dummy_run(num_tokens=64, with_prefill=True, num_reqs=4)
        self.runner._select_moe_comm_method.assert_called_with(64)
        self.proposer.model.assert_called_once()


class TestEagleProposerGenerateTokenIds(TestBase):
    """generate_token_ids(): draft-token proposal with/without spec metadata."""

    def setUp(self):
        self.vllm_config = MagicMock(spec=VllmConfig)
        self.vllm_config.speculative_config = MagicMock()
        self.vllm_config.speculative_config.method = "eagle"
        self.device = torch.device("cpu")
        self.runner = MagicMock()
        self.runner.input_batch = MagicMock()
        self.runner.input_batch.req_ids = [0, 1, 2]
        # Each request returns a fixed next-token id regardless of position.
        self.runner.requests = {
            0: MagicMock(get_token_id=lambda x: 100),
            1: MagicMock(get_token_id=lambda x: 101),
            2: MagicMock(get_token_id=lambda x: 102),
        }
        self.vllm_config.cache_config.block_size = 16
        self.vllm_config.scheduler_config.max_num_batched_tokens = 1024
        self.vllm_config.scheduler_config.max_num_seqs = 32
        self.vllm_config.model_config.dtype = torch.float16
        self.vllm_config.model_config.max_model_len = 2048
        self.proposer = EagleProposer(vllm_config=self.vllm_config,
                                      device=self.device,
                                      runner=self.runner)
        self.proposer.attn_layer_name = "layer_0"
        # Canned proposal: two draft tokens for each of three requests.
        self.proposer._propose = MagicMock(
            return_value=torch.tensor([[1, 2], [3, 4], [5, 6]]))

    def test_generate_token_ids_without_metadata(self):
        """No spec-decode metadata: _propose output is returned verbatim."""
        valid_sampled = [[20, 30, 40]]
        scheduler_output = MagicMock()
        scheduler_output.num_scheduled_tokens = [2, 1, 3]
        positions = torch.tensor([0, 1, 2, 3, 4, 5])
        hidden_states = torch.randn(6, 4096)
        num_scheduled = 6

        mock_attn_metadata = MagicMock()
        mock_attn_metadata.slot_mapping = torch.tensor([0, 1, 2, 3, 4, 5])
        mock_attn_metadata.query_start_loc = torch.tensor([0, 2, 3, 6])
        mock_attn_metadata.block_tables = MagicMock()
        self.proposer._get_eagle_atten_dict = MagicMock(
            return_value={"layer_0": mock_attn_metadata})

        result = self.proposer.generate_token_ids(
            valid_sampled_token_ids=valid_sampled,
            scheduler_output=scheduler_output,
            positions=positions,
            num_scheduled_tokens=num_scheduled,
            hidden_states=hidden_states,
        )

        self.proposer._propose.assert_called_once()
        self.assertEqual(result, [[1, 2], [3, 4], [5, 6]])

    def test_generate_token_ids_with_metadata(self):
        """With spec-decode metadata _prepare_inputs filters rejected tokens."""
        valid_sampled = [[5], [6, 7], [8, 9, 10]]
        spec_metadata = MagicMock()
        spec_metadata.num_draft_tokens = [2, 3, 4]

        mock_attn_metadata = MagicMock()
        mock_attn_metadata.slot_mapping = torch.tensor([0, 1, 2, 3, 4, 5])
        mock_attn_metadata.query_start_loc = torch.tensor([0, 1, 3, 6])
        mock_attn_metadata.block_tables = MagicMock()
        self.proposer._get_eagle_atten_dict = MagicMock(
            return_value={"layer_0": mock_attn_metadata})
        self.proposer._prepare_inputs = MagicMock(
            return_value=(torch.tensor([0, 2, 5]), torch.tensor([1, 3, 5])))

        result = self.proposer.generate_token_ids(
            valid_sampled_token_ids=valid_sampled,
            spec_decode_metadata=spec_metadata,
            positions=torch.randn(6, 1),
            hidden_states=torch.randn(6, 4096),
        )

        self.proposer._prepare_inputs.assert_called_once()
        self.assertEqual(self.proposer._propose.call_count, 1)
        self.assertEqual(len(result), 3)


class TestEagleProposerHelperMethods(TestBase):
    """Internal helper coverage for EagleProposer."""

    def setUp(self):
        self.vllm_config = MagicMock(spec=VllmConfig)
        self.vllm_config.scheduler_config = MagicMock(max_num_seqs=3)
        self.device = torch.device("cpu")
        self.runner = MagicMock()
        self.runner.input_batch = MagicMock()
        self.runner.input_batch.req_ids = [0, 1, 2]
        self.runner.arange_np = np.arange(10)
        self.runner.input_batch.num_reqs = 3
        self.vllm_config.cache_config.block_size = 16
        self.vllm_config.scheduler_config.max_num_batched_tokens = 1024
        self.vllm_config.scheduler_config.max_num_seqs = 32
        self.vllm_config.model_config.dtype = torch.float16
        self.vllm_config.model_config.max_model_len = 2048
        self.proposer = EagleProposer(vllm_config=self.vllm_config,
                                      device=self.device,
                                      runner=self.runner)

    def test_prepare_inputs(self):
        """Smoke-check _prepare_inputs' return contract.

        NOTE(review): this patches the method under test itself, so it only
        asserts the mocked return values — it does not exercise the real
        _prepare_inputs logic. Consider rewriting it against the actual
        implementation with expected cumulative-token/index outputs.
        """
        self.proposer.token_arange_np = np.arange(10)
        mock_attn = MagicMock()
        mock_attn.slot_mapping = torch.tensor([0, 1, 2, 3, 4, 5])
        num_rejected = torch.tensor([1, 0, 1], device=self.device)

        with patch.object(self.proposer,
                          '_prepare_inputs',
                          return_value=(torch.tensor([0, 2, 5]),
                                        torch.tensor([1, 2, 4]))):
            cu_num_tokens, indices = self.proposer._prepare_inputs(
                mock_attn, num_rejected)

        self.assertEqual(cu_num_tokens.tolist(), [0, 2, 5])
        self.assertEqual(indices.tolist(), [1, 2, 4])