From 92353c064311b627078046c65735b795a1d6c880 Mon Sep 17 00:00:00 2001
From: Zetong Li <48438720+slippersss@users.noreply.github.com>
Date: Mon, 29 Dec 2025 16:25:52 +0800
Subject: [PATCH] [Refactor][EAGLE] 1/N delete __init__ in mtp_proposer (#5176)

### What this PR does / why we need it?

This PR refactors the eagle-related modules in vllm-ascend and is the first
PR of the eagle refactoring series. Given the three proposers (vllm-eagle,
ascend-eagle and ascend-mtp), we first let ascend-mtp inherit from
ascend-eagle and let ascend-eagle inherit from vllm-eagle. As an initial
step, this PR deletes `__init__` in mtp_proposer and simplifies the
corresponding logic in eagle_proposer.

Based on the chain "vllm-eagle <----- ascend-eagle <----- ascend-mtp", the
goal is to gradually delete ascend-mtp and let ascend-eagle converge to
vllm-eagle, so the main workspace is eagle_proposer. This layering also lets
contributors refactor eagle concurrently. A minimal sketch of the target
class hierarchy is appended after the patch below.

Upcoming changes:
1. delete the common methods shared by vllm-eagle, ascend-eagle and
   ascend-mtp
2. delete `load_model` in mtp_proposer
3. delete `dummy_run` and `propose` in mtp_proposer
4. ......

RFC: #5467

### Does this PR introduce _any_ user-facing change?

N/A

### How was this patch tested?

Tested by CI.

- vLLM version: v0.12.0
- vLLM main: https://github.com/vllm-project/vllm/commit/ad32e3e19ccf0526cb6744a5fed09a138a5fb2f9

---------

Signed-off-by: Zetong Li
---
 tests/ut/spec_decode/test_eagle_proposer.py |  76 ++++++++++++--
 tests/ut/spec_decode/test_mtp_proposer.py   |  30 ++++--
 vllm_ascend/spec_decode/eagle_proposer.py   |  83 +++++----------
 vllm_ascend/spec_decode/mtp_proposer.py     | 106 ++------------------
 4 files changed, 119 insertions(+), 176 deletions(-)

diff --git a/tests/ut/spec_decode/test_eagle_proposer.py b/tests/ut/spec_decode/test_eagle_proposer.py
index 9c0d29dc..0c85d9d7 100644
--- a/tests/ut/spec_decode/test_eagle_proposer.py
+++ b/tests/ut/spec_decode/test_eagle_proposer.py
@@ -5,6 +5,7 @@ import torch
 from vllm.config import CacheConfig, CompilationMode, CUDAGraphMode, VllmConfig
 
 from tests.ut.base import TestBase
+from vllm_ascend.ascend_config import init_ascend_config
 from vllm_ascend.spec_decode.eagle_proposer import EagleProposer
 from vllm_ascend.spec_decode.interface import SpecDcodeType
@@ -25,13 +26,24 @@ class TestEagleProposerInitialization(TestBase):
         self.vllm_config.scheduler_config.max_num_seqs = 32
         self.vllm_config.model_config.dtype = torch.float16
         self.vllm_config.model_config.max_model_len = 2048
+        self.vllm_config.model_config.uses_mrope = False
+        self.vllm_config.speculative_config.num_speculative_tokens = 2
+        self.vllm_config.speculative_config.speculative_token_tree = str([
+            (i + 1) * (0, ) for i in range(2)
+        ])
+        self.vllm_config.additional_config = None
 
         self.mock_cpugpubuffer = patch(
-            "vllm_ascend.spec_decode.eagle_proposer.CpuGpuBuffer")
+            "vllm.v1.spec_decode.eagle.CpuGpuBuffer")
         self.mock_cpugpubuffer.start()
+        self.mock_supports_multimodal_inputs = patch(
+            "vllm.multimodal.registry.MultiModalRegistry.supports_multimodal_inputs"
+        )
+        self.mock_supports_multimodal_inputs.start()
 
     def tearDown(self):
         self.mock_cpugpubuffer.stop()
+        self.mock_supports_multimodal_inputs.stop()
 
     def test_initialization_eagle_graph(self):
         self.vllm_config.speculative_config.method = "eagle"
@@ -40,12 +52,12 @@ class TestEagleProposerInitialization(TestBase):
         self.vllm_config.model_config.enforce_eager = False
         self.vllm_config.speculative_config.enforce_eager = False
         self.vllm_config.scheduler_config.async_scheduling = False
+        init_ascend_config(self.vllm_config)
 
         proposer = EagleProposer(vllm_config=self.vllm_config,
                                  device=self.device,
                                  runner=self.runner)
 
-        self.assertEqual(proposer.name, SpecDcodeType.EAGLE)
         self.assertEqual(proposer.block_size, 16)
         self.assertEqual(proposer.hidden_size, 4096)
         self.assertTrue(proposer.use_cuda_graph)
@@ -60,12 +72,12 @@ class TestEagleProposerInitialization(TestBase):
         self.vllm_config.speculative_config.draft_model_config.get_hidden_size.return_value = 2048
         self.vllm_config.compilation_config.mode = CompilationMode.NONE
         self.vllm_config.model_config.enforce_eager = True
+        init_ascend_config(self.vllm_config)
 
         proposer = EagleProposer(vllm_config=self.vllm_config,
                                  device=self.device,
                                  runner=self.runner)
 
-        self.assertEqual(proposer.name, SpecDcodeType.EAGLE3)
         self.assertEqual(proposer.hidden_size, 2048)
         self.assertFalse(proposer.use_cuda_graph)
         self.assertEqual(proposer.hidden_states.shape, (1024, 2048))
@@ -77,12 +89,12 @@ class TestEagleProposerInitialization(TestBase):
         self.vllm_config.model_config.enforce_eager = False
         self.vllm_config.speculative_config.enforce_eager = False
         self.vllm_config.scheduler_config.async_scheduling = True
+        init_ascend_config(self.vllm_config)
 
         proposer = EagleProposer(vllm_config=self.vllm_config,
                                  device=self.device,
                                  runner=self.runner)
 
-        self.assertEqual(proposer.name, SpecDcodeType.EAGLE3)
         self.assertEqual(proposer.hidden_size, 2048)
         self.assertFalse(proposer.use_cuda_graph)
         self.assertEqual(proposer.hidden_states.shape, (1024, 2048))
@@ -102,16 +114,28 @@ class TestEagleProposerLoadModel(TestBase):
         self.vllm_config.scheduler_config.max_num_seqs = 32
         self.vllm_config.model_config.dtype = torch.float16
         self.vllm_config.model_config.max_model_len = 2048
+        self.vllm_config.model_config.uses_mrope = False
+        self.vllm_config.speculative_config.num_speculative_tokens = 2
+        self.vllm_config.speculative_config.speculative_token_tree = str([
+            (i + 1) * (0, ) for i in range(2)
+        ])
+        self.vllm_config.additional_config = None
+        init_ascend_config(self.vllm_config)
 
         self.mock_cpugpubuffer = patch(
-            "vllm_ascend.spec_decode.eagle_proposer.CpuGpuBuffer")
+            "vllm.v1.spec_decode.eagle.CpuGpuBuffer")
         self.mock_cpugpubuffer.start()
+        self.mock_supports_multimodal_inputs = patch(
+            "vllm.multimodal.registry.MultiModalRegistry.supports_multimodal_inputs"
+        )
+        self.mock_supports_multimodal_inputs.start()
 
         self.proposer = EagleProposer(vllm_config=self.vllm_config,
                                       device=self.device,
                                       runner=self.runner)
 
     def tearDown(self):
         self.mock_cpugpubuffer.stop()
+        self.mock_supports_multimodal_inputs.stop()
 
     @patch(
         "vllm_ascend.spec_decode.eagle_proposer.get_layers_from_vllm_config")
@@ -204,10 +228,20 @@ class TestEagleProposerDummyRun(TestBase):
         self.vllm_config.scheduler_config.max_num_seqs = 32
         self.vllm_config.model_config.dtype = torch.float16
         self.vllm_config.model_config.max_model_len = 2048
+        self.vllm_config.model_config.uses_mrope = False
+        self.vllm_config.speculative_config.speculative_token_tree = str([
+            (i + 1) * (0, ) for i in range(4)
+        ])
+        self.vllm_config.additional_config = None
+        init_ascend_config(self.vllm_config)
 
         self.mock_cpugpubuffer = patch(
-            "vllm_ascend.spec_decode.eagle_proposer.CpuGpuBuffer")
+            "vllm.v1.spec_decode.eagle.CpuGpuBuffer")
         self.mock_cpugpubuffer.start()
+        self.mock_supports_multimodal_inputs = patch(
+            "vllm.multimodal.registry.MultiModalRegistry.supports_multimodal_inputs"
+        )
+        self.mock_supports_multimodal_inputs.start()
 
         self.proposer = EagleProposer(vllm_config=self.vllm_config,
                                       device=self.device,
                                       runner=self.runner)
@@ -216,6 +250,7 @@ class TestEagleProposerDummyRun(TestBase):
 
     def tearDown(self):
         self.mock_cpugpubuffer.stop()
+        self.mock_supports_multimodal_inputs.stop()
 
     @patch("vllm_ascend.spec_decode.eagle_proposer.get_forward_context")
     @patch("vllm_ascend.spec_decode.eagle_proposer.set_ascend_forward_context")
@@ -287,16 +322,28 @@ class TestEagleProposerGenerateTokenIds(TestBase):
             1: MagicMock(get_token_id=lambda x: 101),
             2: MagicMock(get_token_id=lambda x: 102),
         }
+        self.runner.pcp_size = 1
         self.vllm_config.cache_config.block_size = 16
         self.vllm_config.scheduler_config.max_num_batched_tokens = 1024
         self.vllm_config.scheduler_config.max_num_seqs = 32
         self.vllm_config.model_config.dtype = torch.float16
         self.vllm_config.model_config.max_model_len = 2048
+        self.vllm_config.model_config.uses_mrope = False
+        self.vllm_config.speculative_config.num_speculative_tokens = 2
+        self.vllm_config.speculative_config.speculative_token_tree = str([
+            (i + 1) * (0, ) for i in range(2)
+        ])
+        self.vllm_config.additional_config = None
+        init_ascend_config(self.vllm_config)
 
         self.mock_cpugpubuffer = patch(
-            "vllm_ascend.spec_decode.eagle_proposer.CpuGpuBuffer")
+            "vllm.v1.spec_decode.eagle.CpuGpuBuffer")
         self.mock_cpugpubuffer.start()
+        self.mock_supports_multimodal_inputs = patch(
+            "vllm.multimodal.registry.MultiModalRegistry.supports_multimodal_inputs"
+        )
+        self.mock_supports_multimodal_inputs.start()
 
         self.proposer = EagleProposer(vllm_config=self.vllm_config,
                                       device=self.device,
                                       runner=self.runner)
@@ -306,6 +353,7 @@ class TestEagleProposerGenerateTokenIds(TestBase):
 
     def tearDown(self):
         self.mock_cpugpubuffer.stop()
+        self.mock_supports_multimodal_inputs.stop()
 
     # TODO: This is equivalent to disable_padded_drafter_batch=True.
     # We need to add some cases about disable_padded_drafter_batch=False in future.
@@ -355,16 +403,28 @@ class TestEagleProposerHelperMethods(TestBase):
         self.vllm_config.scheduler_config.max_num_seqs = 32
         self.vllm_config.model_config.dtype = torch.float16
         self.vllm_config.model_config.max_model_len = 2048
+        self.vllm_config.model_config.uses_mrope = False
+        self.vllm_config.speculative_config.num_speculative_tokens = 2
+        self.vllm_config.speculative_config.speculative_token_tree = str([
+            (i + 1) * (0, ) for i in range(2)
+        ])
+        self.vllm_config.additional_config = None
+        init_ascend_config(self.vllm_config)
 
         self.mock_cpugpubuffer = patch(
-            "vllm_ascend.spec_decode.eagle_proposer.CpuGpuBuffer")
+            "vllm.v1.spec_decode.eagle.CpuGpuBuffer")
         self.mock_cpugpubuffer.start()
+        self.mock_supports_multimodal_inputs = patch(
+            "vllm.multimodal.registry.MultiModalRegistry.supports_multimodal_inputs"
+        )
+        self.mock_supports_multimodal_inputs.start()
 
         self.proposer = EagleProposer(vllm_config=self.vllm_config,
                                       device=self.device,
                                       runner=self.runner)
 
     def tearDown(self):
         self.mock_cpugpubuffer.stop()
+        self.mock_supports_multimodal_inputs.stop()
 
     # TODO: This is equivalent to disable_padded_drafter_batch=True.
     # We need to add a test_prepare_inputs_padded in future.
diff --git a/tests/ut/spec_decode/test_mtp_proposer.py b/tests/ut/spec_decode/test_mtp_proposer.py
index 9e9eb295..dfec5a3c 100644
--- a/tests/ut/spec_decode/test_mtp_proposer.py
+++ b/tests/ut/spec_decode/test_mtp_proposer.py
@@ -16,12 +16,18 @@ from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch
 
 from vllm_ascend.ascend_config import init_ascend_config
 from vllm_ascend.attention.utils import AscendCommonAttentionMetadata
-from vllm_ascend.spec_decode.interface import SpecDcodeType
 from vllm_ascend.spec_decode.mtp_proposer import MtpProposer
 
 
 class TestMtpProposer:
 
+    @pytest.fixture(autouse=True)
+    def patch_supports_multimodal_inputs(self):
+        with patch(
+                "vllm.multimodal.registry.MultiModalRegistry.supports_multimodal_inputs"
+        ):
+            yield
+
     @pytest.fixture
     def vllm_config(self):
         config = MagicMock(spec=VllmConfig)
@@ -31,6 +37,9 @@ class TestMtpProposer:
         config.speculative_config.method = "deepseek_mtp"
         config.speculative_config.draft_model_config = MagicMock()
         config.speculative_config.draft_model_config.get_hidden_size.return_value = 4096
+        config.speculative_config.speculative_token_tree = str([
+            (i + 1) * (0, ) for i in range(2)
+        ])
 
         config.model_config = MagicMock(spec=ModelConfig)
         config.model_config.dtype = torch.float16
@@ -68,7 +77,7 @@ class TestMtpProposer:
         runner.reserved_mc2_mask = None
         return runner
 
-    @patch("vllm_ascend.spec_decode.mtp_proposer.CpuGpuBuffer")
+    @patch("vllm.v1.spec_decode.eagle.CpuGpuBuffer")
     def test_init(self, mock_cpu_gpu_buffer, vllm_config, runner):
         mock_buffer_instance = MagicMock()
         mock_cpu_gpu_buffer.return_value = mock_buffer_instance
@@ -76,7 +85,6 @@ class TestMtpProposer:
         # Test basic initialization
         proposer = MtpProposer(vllm_config, torch.device("cpu"), runner)
 
-        assert proposer.name == SpecDcodeType.MTP
         assert proposer.vllm_config == vllm_config
         assert proposer.device == torch.device("cpu")
         assert proposer.dtype == torch.float16
@@ -89,7 +97,7 @@ class TestMtpProposer:
         assert not hasattr(proposer, "mrope_positions")
         assert proposer.use_sparse is False
 
-    @patch("vllm_ascend.spec_decode.mtp_proposer.CpuGpuBuffer")
+    @patch("vllm.v1.spec_decode.eagle.CpuGpuBuffer")
     def test_init_with_aclgraph(self, mock_cpu_gpu_buffer, vllm_config,
                                 runner):
         mock_buffer_instance = MagicMock()
@@ -105,7 +113,7 @@ class TestMtpProposer:
         "vllm_ascend.spec_decode.mtp_proposer.process_weights_after_loading")
     @patch("vllm_ascend.spec_decode.mtp_proposer.set_default_torch_dtype")
     @patch("vllm_ascend.spec_decode.mtp_proposer.set_current_vllm_config")
-    @patch("vllm_ascend.spec_decode.mtp_proposer.CpuGpuBuffer")
+    @patch("vllm.v1.spec_decode.eagle.CpuGpuBuffer")
     def test_load_model(self, mock_cpu_gpu_buffer, mock_set_config,
                         mock_set_dtype, mock_process_weights, mock_get_loader,
                         mock_get_layers, vllm_config, runner):
@@ -148,7 +156,7 @@ class TestMtpProposer:
 
     @patch("vllm_ascend.spec_decode.mtp_proposer.get_forward_context")
     @patch("vllm_ascend.spec_decode.mtp_proposer.set_ascend_forward_context")
-    @patch("vllm_ascend.spec_decode.mtp_proposer.CpuGpuBuffer")
+    @patch("vllm.v1.spec_decode.eagle.CpuGpuBuffer")
     def test_dummy_run(self, mock_cpu_gpu_buffer, mock_set_context,
                        mock_get_forward_context, vllm_config, runner):
         mock_buffer_instance = MagicMock()
@@ -173,7 +181,7 @@ class TestMtpProposer:
 
     @patch("vllm_ascend.spec_decode.mtp_proposer.get_forward_context")
     @patch("vllm_ascend.spec_decode.mtp_proposer.set_ascend_forward_context")
-    @patch("vllm_ascend.spec_decode.mtp_proposer.CpuGpuBuffer")
+    @patch("vllm.v1.spec_decode.eagle.CpuGpuBuffer")
     def test_dummy_run_full_graph(self, mock_cpu_gpu_buffer, mock_set_context,
                                   mock_get_forward_context, vllm_config,
                                   runner):
@@ -201,7 +209,7 @@ class TestMtpProposer:
         # Check that model was called correct number of times
         assert proposer.model.call_count == vllm_config.speculative_config.num_speculative_tokens
 
-    @patch("vllm_ascend.spec_decode.mtp_proposer.CpuGpuBuffer")
+    @patch("vllm.v1.spec_decode.eagle.CpuGpuBuffer")
     def test_generate_token_ids(self, mock_cpu_gpu_buffer):
         mock_buffer_instance = MagicMock()
         mock_cpu_gpu_buffer.return_value = mock_buffer_instance
@@ -272,7 +280,7 @@ class TestMtpProposer:
         proposer._propose.assert_called_once()
         assert torch.equal(draft_token_ids, proposer._propose.return_value)
 
-    @patch("vllm_ascend.spec_decode.mtp_proposer.CpuGpuBuffer")
+    @patch("vllm.v1.spec_decode.eagle.CpuGpuBuffer")
     def test_prepare_next_token_ids_cpu(self, mock_cpu_gpu_buffer):
         mock_buffer_instance = MagicMock()
         mock_cpu_gpu_buffer.return_value = mock_buffer_instance
@@ -295,7 +303,7 @@ class TestMtpProposer:
         assert torch.all(
             result == torch.tensor([30, 50, 60], dtype=torch.int32))
 
-    @patch("vllm_ascend.spec_decode.mtp_proposer.CpuGpuBuffer")
+    @patch("vllm.v1.spec_decode.eagle.CpuGpuBuffer")
    def test_prepare_next_token_ids_padded(self, mock_cpu_gpu_buffer):
         mock_common_attn_metadata = MagicMock(spec=CommonAttentionMetadata)
         mock_common_attn_metadata.seq_lens_cpu = torch.tensor(
@@ -377,7 +385,7 @@ class TestMtpProposer:
                                        device=torch.device("cpu"))
         assert torch.equal(next_token_ids, expected_next_tokens)
 
-    @patch("vllm_ascend.spec_decode.mtp_proposer.CpuGpuBuffer")
+    @patch("vllm.v1.spec_decode.eagle.CpuGpuBuffer")
     def test_prepare_inputs_padded(self, mock_cpu_gpu_buffer):
         mock_buffer_instance = MagicMock()
         mock_cpu_gpu_buffer.return_value = mock_buffer_instance
diff --git a/vllm_ascend/spec_decode/eagle_proposer.py b/vllm_ascend/spec_decode/eagle_proposer.py
index 9546208a..19de161c 100644
--- a/vllm_ascend/spec_decode/eagle_proposer.py
+++ b/vllm_ascend/spec_decode/eagle_proposer.py
@@ -18,8 +18,8 @@ from vllm.utils.platform_utils import is_pin_memory_available
 from vllm.v1.attention.backends.utils import CommonAttentionMetadata
 from vllm.v1.core.sched.output import SchedulerOutput
 from vllm.v1.sample.metadata import SamplingMetadata
+from vllm.v1.spec_decode.eagle import EagleProposer as VllmEagleProposer
 from vllm.v1.spec_decode.metadata import SpecDecodeMetadata
-from vllm.v1.utils import CpuGpuBuffer
 from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch
 
 from vllm_ascend.ascend_forward_context import set_ascend_forward_context
@@ -29,7 +29,7 @@ from vllm_ascend.attention.utils import AscendCommonAttentionMetadata
 from vllm_ascend.compilation.acl_graph import (ACLGraphWrapper,
                                                update_attn_params)
 from vllm_ascend.ops.rotary_embedding import update_cos_sin
-from vllm_ascend.spec_decode.interface import Proposer, SpecDcodeType
+from vllm_ascend.utils import shared_expert_dp_enabled
 
 PADDING_SLOT_ID = -1
 
@@ -38,29 +38,15 @@ _DEFAULT_FIRST_LAYER = 'model.layers.0.self_attn.attn'
 _FIRST_LAYERS = {"Qwen3NextForCausalLM": 'model.layers.3.self_attn.attn'}
 
 
-class EagleProposer(Proposer):
+class EagleProposer(VllmEagleProposer):
 
     def __init__(self,
                  vllm_config: VllmConfig,
                  device: torch.device,
                  runner=None):
-        self.name = SpecDcodeType.EAGLE if vllm_config.speculative_config.method == "eagle" else SpecDcodeType.EAGLE3
-        self.vllm_config = vllm_config
-        self.device = device
-        self.runner = runner
-        self.speculative_config = vllm_config.speculative_config
-        self.draft_model_config = self.speculative_config.draft_model_config
-        self.method = self.speculative_config.method
-        self.num_speculative_tokens = vllm_config.speculative_config.num_speculative_tokens
+        super().__init__(vllm_config, device, runner)
+
         self.use_async_scheduling = self.vllm_config.scheduler_config.async_scheduling
-
-        self.block_size = vllm_config.cache_config.block_size
-        # We need to get the hidden size from the draft model config because
-        # the draft model's hidden size can be different from the target model's
-        # hidden size (e.g., Llama 3.3 70B).
-        self.hidden_size = vllm_config.speculative_config.draft_model_config.get_hidden_size(
-        )
-
         # there is synchronization between mtp steps when enabling aclgraph,
         # disable aclgraph when use async scheduling to avoid the
         # synchronization overhead.
@@ -77,45 +63,28 @@ class EagleProposer(Proposer):
             sorted(
                 self.vllm_config.compilation_config.cudagraph_capture_sizes))
 
-        max_batch_size = vllm_config.scheduler_config.max_num_seqs
-        # Currently we do not use pcp. This is used to adapt the pcp branch.
-        self.pcp_size = 0
-        self.backup_next_token_ids = CpuGpuBuffer(
-            max_batch_size,
-            dtype=torch.int32,
-            pin_memory=is_pin_memory_available(),
-            device=device,
-            with_numpy=True,
-        )
+        self.pcp_size = self.runner.pcp_size
         self.decode_threshold = 1 + self.num_speculative_tokens
-        # persistent buffers for cuda graph
-        self.input_ids = torch.zeros(
-            self.vllm_config.scheduler_config.max_num_batched_tokens,
-            dtype=torch.int32,
-            device=device)
-        self.positions = torch.zeros(
-            self.vllm_config.scheduler_config.max_num_batched_tokens,
-            dtype=torch.int64,
-            device=device)
-        self.hidden_states = torch.zeros(
-            (self.vllm_config.scheduler_config.max_num_batched_tokens,
-             self.hidden_size),
-            dtype=self.vllm_config.model_config.dtype,
-            device=device)
-        self.max_num_tokens = (
-            vllm_config.scheduler_config.max_num_batched_tokens)
-        self.token_arange_np = np.arange(self.max_num_tokens)
-        max_num_slots_for_arange = max(self.max_num_tokens, max_batch_size + 1)
-        self.arange = torch.arange(max_num_slots_for_arange,
-                                   device=device,
-                                   dtype=torch.int32)
-        self.arange_cpu = torch.arange(max_num_slots_for_arange,
+        self.arange_cpu = torch.arange(self.arange.shape[0],
                                        device="cpu",
                                        dtype=torch.int32)
         self.attn_mask_builder = AttentionMaskBuilder(self.device)
-        self.eagle3_use_aux_hidden_state: bool = (
-            self._get_eagle3_use_aux_hidden_state_from_config())
+
+        self.enable_shared_expert_dp = shared_expert_dp_enabled()
+
+        self.dcp_size = self.runner.dcp_size
+        self.pcp_rank = self.runner.pcp_rank
+        self.dcp_rank = self.runner.dcp_rank
+
+        self.use_aclgraph = self.runner._use_aclgraph()
+
+        self.full_indices = range(
+            self.runner.max_num_tokens * self.pcp_size * self.dcp_size +
+            self.pcp_size * self.dcp_size * self.runner.max_num_reqs)
+
+        self.use_sparse = hasattr(vllm_config.model_config.hf_config,
+                                  "index_topk")
 
     def _get_eagle3_use_aux_hidden_state_from_config(self) -> bool:
         """
@@ -165,7 +134,7 @@ class EagleProposer(Proposer):
         # share lm_head with the target model if needed
         # some model definition do not define lm_head explicitly
         # and reuse embed_tokens for lm_head, e.g., CohereForCausalLM
-        if self.name == SpecDcodeType.EAGLE and hasattr(model, "lm_head"):
+        if self.method == "eagle" and hasattr(model, "lm_head"):
             logger.info("Loading EAGLE LM head weights from the target model.")
             if supports_multimodal(model):
                 self.model.lm_head = model.get_language_model().lm_head
@@ -337,7 +306,7 @@ class EagleProposer(Proposer):
             target_token_ids = self.runner.input_ids.gpu[:num_scheduled_tokens]
             target_positions = positions[:num_scheduled_tokens]
-            if self.name == SpecDcodeType.EAGLE3:
+            if self.method == "eagle3":
                 target_hidden_states = torch.cat(
                     [h[:num_scheduled_tokens] for h in aux_hidden_states],
                     dim=-1)
@@ -371,7 +340,7 @@
         else:
             target_token_ids = self.runner.input_ids.gpu[token_indices]
             target_positions = positions[token_indices]
-            if self.name == SpecDcodeType.EAGLE3:
+            if self.method == "eagle3":
                 target_hidden_states = torch.cat(
                     [h[token_indices] for h in aux_hidden_states], dim=-1)
             else:
@@ -424,7 +393,7 @@
         if last_token_indices is None:
             last_token_indices = common_attn_metadata.query_start_loc[1:] - 1
 
-        if self.name == SpecDcodeType.EAGLE3:
+        if self.method == "eagle3":
             assert isinstance(self.get_model(), Eagle3LlamaForCausalLM)
             target_hidden_states = self.model.combine_hidden_states(
                 target_hidden_states)
diff --git a/vllm_ascend/spec_decode/mtp_proposer.py b/vllm_ascend/spec_decode/mtp_proposer.py
index 26ac172a..84049d8a 100644
--- a/vllm_ascend/spec_decode/mtp_proposer.py
+++ b/vllm_ascend/spec_decode/mtp_proposer.py
@@ -5,8 +5,8 @@ import numpy as np
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
-from vllm.config import (CUDAGraphMode, VllmConfig,
-                         get_layers_from_vllm_config, set_current_vllm_config)
+from vllm.config import (CUDAGraphMode, get_layers_from_vllm_config,
+                         set_current_vllm_config)
 from vllm.distributed import get_pcp_group
 from vllm.distributed.parallel_state import get_pp_group
 from vllm.forward_context import get_forward_context
@@ -20,12 +20,10 @@ from vllm.model_executor.models.llama_eagle3 import Eagle3LlamaForCausalLM
 from vllm.utils.math_utils import cdiv
 from vllm.utils.platform_utils import is_pin_memory_available
 from vllm.utils.torch_utils import set_default_torch_dtype
-from vllm.v1.attention.backends.utils import (AttentionMetadataBuilder,
-                                              CommonAttentionMetadata)
+from vllm.v1.attention.backends.utils import CommonAttentionMetadata
 from vllm.v1.core.sched.output import SchedulerOutput
 from vllm.v1.sample.metadata import SamplingMetadata
 from vllm.v1.spec_decode.metadata import SpecDecodeMetadata
-from vllm.v1.utils import CpuGpuBuffer
 from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch
 
 from vllm_ascend.ascend_forward_context import set_ascend_forward_context
@@ -35,9 +33,8 @@ from vllm_ascend.compilation.acl_graph import (ACLGraphWrapper,
                                                update_mla_attn_dcp_pcp_params,
                                                update_mla_attn_params)
 from vllm_ascend.ops.rotary_embedding import get_cos_and_sin_mla
-from vllm_ascend.spec_decode.interface import Proposer, SpecDcodeType
-from vllm_ascend.utils import (ProfileExecuteDuration, lmhead_tp_enable,
-                               shared_expert_dp_enabled)
+from vllm_ascend.spec_decode.eagle_proposer import EagleProposer
+from vllm_ascend.utils import ProfileExecuteDuration, lmhead_tp_enable
 
 logger = init_logger(__name__)
 
@@ -64,102 +61,11 @@ def _load_model(architecture):
     return model
 
 
-class MtpProposer(Proposer):
+class MtpProposer(EagleProposer):
 
     # TODO: Find out why ModelRunner does not need this explicit typing?
     model: Union[nn.Module, ACLGraphWrapper]
 
-    def __init__(
-        self,
-        vllm_config: VllmConfig,
-        device,
-        runner,
-    ):
-        self.name = SpecDcodeType.MTP
-        self.vllm_config = vllm_config
-        self.speculative_config = vllm_config.speculative_config
-        assert self.speculative_config is not None
-        self.draft_model_config = self.speculative_config.draft_model_config
-        self.method = self.speculative_config.method
-
-        self.runner = runner
-        self.device = device
-        self.dtype = vllm_config.model_config.dtype
-        self.max_model_len = vllm_config.model_config.max_model_len
-        self.block_size = vllm_config.cache_config.block_size
-        self.num_speculative_tokens = self.speculative_config.num_speculative_tokens
-        self.decode_threshold = 1 + self.num_speculative_tokens
-        self.max_num_tokens = vllm_config.scheduler_config.max_num_batched_tokens
-        self.token_arange_np = np.arange(self.max_num_tokens)
-        # We need to get the hidden size from the draft model config because
-        # the draft model's hidden size can be different from the target model's
-        # hidden size (e.g., Llama 3.3 70B).
-        self.hidden_size = self.draft_model_config.get_hidden_size()
-        self.enable_shared_expert_dp = shared_expert_dp_enabled()
-
-        self.pcp_size = self.runner.pcp_size
-        self.dcp_size = self.runner.dcp_size
-        self.pcp_rank = self.runner.pcp_rank
-        self.dcp_rank = self.runner.dcp_rank
-
-        self.attn_metadata_builder: Optional[AttentionMetadataBuilder] = None
-        self.draft_indexer_metadata_builder: Optional[
-            AttentionMetadataBuilder] = None
-        self.attn_layer_names: list[str] = []
-        self.indexer_layer_names: list[str] = []
-
-        self.use_aclgraph = self.runner._use_aclgraph()
-
-        # persistent buffers for aclgraph graph
-        self.input_ids = torch.zeros(self.max_num_tokens,
-                                     dtype=torch.int32,
-                                     device=device)
-        self.uses_mrope = self.vllm_config.model_config.uses_mrope
-        if self.uses_mrope:
-            # M-RoPE need (3, max_num_tokens)
-            self.mrope_positions = torch.zeros((3, self.max_num_tokens),
-                                               dtype=torch.int64,
-                                               device=device)
-        else:
-            # RoPE need (max_num_tokens,)
-            self.positions = torch.zeros(self.max_num_tokens,
-                                         dtype=torch.int64,
-                                         device=device)
-        self.hidden_states = torch.zeros(
-            (self.max_num_tokens, self.hidden_size),
-            dtype=self.dtype,
-            device=device)
-        self.full_indices = range(
-            self.runner.max_num_tokens * self.pcp_size * self.dcp_size +
-            self.pcp_size * self.dcp_size * self.runner.max_num_reqs)
-
-        # We need +1 here because the arange is used to set query_start_loc,
-        # which has one more element than batch_size.
-        max_batch_size = vllm_config.scheduler_config.max_num_seqs
-        max_num_slots_for_arange = max(max_batch_size + 1, self.max_num_tokens)
-        self.arange = torch.arange(max_num_slots_for_arange,
-                                   device=device,
-                                   dtype=torch.int32)
-        self.arange_cpu = torch.arange(max_num_slots_for_arange,
-                                       device="cpu",
-                                       dtype=torch.int32)
-
-        self.inputs_embeds = torch.zeros(
-            (self.max_num_tokens, self.hidden_size),
-            dtype=self.dtype,
-            device=device)
-
-        self.backup_next_token_ids = CpuGpuBuffer(
-            max_batch_size,
-            dtype=torch.int32,
-            pin_memory=is_pin_memory_available(),
-            device=device,
-            with_numpy=True,
-        )
-        self.use_sparse = hasattr(vllm_config.model_config.hf_config,
-                                  "index_topk")
-        self.use_async_scheduling = self.vllm_config.scheduler_config.async_scheduling
-
     def load_model(self, model) -> None:
         loader = get_model_loader(self.vllm_config.load_config)
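
For orientation, here is a minimal sketch of the inheritance chain described
above. It is illustrative only, not the real implementation: the attribute
assignments shown are a small subset of what the actual `__init__` methods set
up, and constructing these classes for real requires a fully populated
`VllmConfig` plus an Ascend model runner that exposes `pcp_size`/`dcp_size`.

```python
# Sketch of the hierarchy this series converges on (simplified):
#
#   vllm.v1.spec_decode.eagle.EagleProposer                    "vllm-eagle"
#     -> vllm_ascend.spec_decode.eagle_proposer.EagleProposer  "ascend-eagle"
#       -> vllm_ascend.spec_decode.mtp_proposer.MtpProposer    "ascend-mtp"

import torch
from vllm.config import VllmConfig
from vllm.v1.spec_decode.eagle import EagleProposer as VllmEagleProposer


class EagleProposer(VllmEagleProposer):
    """Ascend EAGLE proposer: common state now comes from vllm-eagle."""

    def __init__(self,
                 vllm_config: VllmConfig,
                 device: torch.device,
                 runner=None):
        # vllm-eagle's __init__ allocates the shared persistent buffers
        # (input_ids, positions, hidden_states, backup_next_token_ids, ...).
        super().__init__(vllm_config, device, runner)
        # Only Ascend-specific state is layered on top; the two assignments
        # below are examples taken from the patch, not the complete list.
        self.pcp_size = self.runner.pcp_size
        self.dcp_size = self.runner.dcp_size


class MtpProposer(EagleProposer):
    """Ascend MTP proposer: after this PR it defines no __init__ of its own,
    so construction falls through to the Ascend EagleProposer above."""
    # Follow-up PRs in the series plan to drop load_model/dummy_run/propose
    # the same way, until this class converges to vllm-eagle.
```

The intended effect is that every field MtpProposer used to duplicate is now
initialized exactly once in the parent classes, so later PRs can delete the
remaining MTP overrides without changing construction behavior.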