[Spec Decode] Clean up spec decode interface (#6947)

This pull request refactors the speculative decoding proposer interface
to align with upstream vLLM, removing the local `Proposer` interface and
renaming methods to `propose`.

This is the first step. In the future, we should remove the class
registry and just add a few Ascend-specific methods once the architecture
in vLLM is ready.

- vLLM version: v0.16.0
- vLLM main:
15d76f74e2

Signed-off-by: wangxiyuan <wangxiyuan1007@gmail.com>
This commit is contained in:
wangxiyuan
2026-03-05 14:30:10 +08:00
committed by GitHub
parent 2bd9c35788
commit 13777bf3f0
11 changed files with 194 additions and 315 deletions

View File

@@ -6,8 +6,7 @@ from vllm.config import CacheConfig, CompilationMode, CUDAGraphMode, VllmConfig,
from tests.ut.base import TestBase
from vllm_ascend.ascend_config import init_ascend_config
from vllm_ascend.spec_decode.eagle_proposer import EagleProposer
from vllm_ascend.spec_decode.interface import SpecDcodeType
from vllm_ascend.spec_decode.eagle_proposer import AscendEagleProposer
class TestEagleProposerInitialization(TestBase):
@@ -79,7 +78,7 @@ class TestEagleProposerInitialization(TestBase):
init_ascend_config(self.vllm_config)
with set_current_vllm_config(self.vllm_config):
proposer = EagleProposer(vllm_config=self.vllm_config,
proposer = AscendEagleProposer(vllm_config=self.vllm_config,
device=self.device,
runner=self.runner)
@@ -102,7 +101,7 @@ class TestEagleProposerInitialization(TestBase):
init_ascend_config(self.vllm_config)
with set_current_vllm_config(self.vllm_config):
proposer = EagleProposer(vllm_config=self.vllm_config,
proposer = AscendEagleProposer(vllm_config=self.vllm_config,
device=self.device,
runner=self.runner)
@@ -121,7 +120,7 @@ class TestEagleProposerInitialization(TestBase):
init_ascend_config(self.vllm_config)
with set_current_vllm_config(self.vllm_config):
proposer = EagleProposer(vllm_config=self.vllm_config,
proposer = AscendEagleProposer(vllm_config=self.vllm_config,
device=self.device,
runner=self.runner)
@@ -140,7 +139,7 @@ class TestEagleProposerInitialization(TestBase):
init_ascend_config(self.vllm_config)
with set_current_vllm_config(self.vllm_config):
proposer = EagleProposer(vllm_config=self.vllm_config,
proposer = AscendEagleProposer(vllm_config=self.vllm_config,
device=self.device,
runner=self.runner)
@@ -196,7 +195,7 @@ class TestEagleProposerLoadModel(TestBase):
# Set the current vllm config
set_current_vllm_config(self.vllm_config)
self.proposer = EagleProposer(vllm_config=self.vllm_config,
self.proposer = AscendEagleProposer(vllm_config=self.vllm_config,
device=self.device,
runner=self.runner)
@@ -235,7 +234,6 @@ class TestEagleProposerLoadModel(TestBase):
mock_model.model.embed_tokens = MagicMock()
mock_model.model.embed_tokens.weight = weight
self.proposer.name = SpecDcodeType.EAGLE
mock_get_model.return_value = MagicMock()
mock_get_model.return_value.model.embed_tokens.weight = weight
@@ -301,7 +299,6 @@ class TestEagleProposerLoadModel(TestBase):
mock_pp_group.return_value.world_size = 2
self.proposer.model = MagicMock()
self.proposer.name = SpecDcodeType.EAGLE
with set_current_vllm_config(self.vllm_config):
self.proposer.load_model(mock_model)
@@ -373,7 +370,7 @@ class TestEagleProposerDummyRun(TestBase):
# Set the current vllm config
set_current_vllm_config(self.vllm_config)
self.proposer = EagleProposer(vllm_config=self.vllm_config,
self.proposer = AscendEagleProposer(vllm_config=self.vllm_config,
device=self.device,
runner=self.runner)
self.proposer.model = MagicMock()
@@ -514,7 +511,7 @@ class TestEagleProposerHelperMethods(TestBase):
# Set the current vllm config
set_current_vllm_config(self.vllm_config)
self.proposer = EagleProposer(vllm_config=self.vllm_config,
self.proposer = AscendEagleProposer(vllm_config=self.vllm_config,
device=self.device,
runner=self.runner)

View File

@@ -12,7 +12,7 @@ from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch
from vllm_ascend.ascend_config import init_ascend_config
from vllm_ascend.attention.utils import AscendCommonAttentionMetadata
from vllm_ascend.spec_decode.mtp_proposer import MtpProposer
from vllm_ascend.spec_decode.mtp_proposer import AscendMtpProposer
class TestMtpProposer:
@@ -96,7 +96,7 @@ class TestMtpProposer:
# Test basic initialization
with set_current_vllm_config(vllm_config):
proposer = MtpProposer(vllm_config, torch.device("cpu"), runner)
proposer = AscendMtpProposer(vllm_config, torch.device("cpu"), runner)
assert proposer.vllm_config == vllm_config
assert proposer.device == torch.device("cpu")
@@ -118,7 +118,7 @@ class TestMtpProposer:
vllm_config.scheduler_config.async_scheduling = False
vllm_config.speculative_config.enforce_eager = False
with set_current_vllm_config(vllm_config):
proposer = MtpProposer(vllm_config, torch.device("cpu"), runner)
proposer = AscendMtpProposer(vllm_config, torch.device("cpu"), runner)
assert proposer.use_cuda_graph is True
@@ -133,7 +133,7 @@ class TestMtpProposer:
mock_cpu_gpu_buffer.return_value = mock_buffer_instance
mock_dp_group.return_value.world_size = 1
with set_current_vllm_config(vllm_config):
proposer = MtpProposer(vllm_config, torch.device("cpu"), runner)
proposer = AscendMtpProposer(vllm_config, torch.device("cpu"), runner)
# Mock _runnable to prevent actual execution
proposer._runnable = MagicMock()
@@ -165,7 +165,7 @@ class TestMtpProposer:
mock_cpu_gpu_buffer.return_value = mock_buffer_instance
mock_dp_group.return_value.world_size = 1
with set_current_vllm_config(vllm_config):
proposer = MtpProposer(vllm_config, torch.device("cpu"), runner)
proposer = AscendMtpProposer(vllm_config, torch.device("cpu"), runner)
# Mock _runnable to prevent actual execution
proposer._runnable = MagicMock()
@@ -197,9 +197,9 @@ class TestMtpProposer:
mock_gpu_batch.req_ids = ["req1", "req2", "req3"]
mock_num_scheduled = {"req1": 0, "req2": 0, "req3": 0}
proposer = MagicMock(spec=MtpProposer)
proposer = MagicMock(spec=AscendMtpProposer)
proposer.input_ids = MagicMock(device=torch.device("cpu"))
proposer.prepare_next_token_ids_cpu = MtpProposer.prepare_next_token_ids_cpu.__get__(
proposer.prepare_next_token_ids_cpu = AscendMtpProposer.prepare_next_token_ids_cpu.__get__(
proposer)
result = proposer.prepare_next_token_ids_cpu(
sampled_token_ids=sampled_token_ids,
@@ -253,10 +253,10 @@ class TestMtpProposer:
mock_backup.copy_to_gpu = MagicMock()
mock_cpu_gpu_buffer.return_value = mock_backup
proposer = MagicMock(spec=MtpProposer)
proposer = MagicMock(spec=AscendMtpProposer)
proposer.backup_next_token_ids = mock_backup
proposer.input_ids = MagicMock(device=torch.device("cpu"))
proposer.prepare_next_token_ids_padded = MtpProposer.prepare_next_token_ids_padded.__get__(
proposer.prepare_next_token_ids_padded = AscendMtpProposer.prepare_next_token_ids_padded.__get__(
proposer)
discard_request_indices = torch.tensor([1, 3], dtype=torch.int64)
@@ -327,11 +327,11 @@ class TestMtpProposer:
mock_runner.pcp_size = 1
mock_runner.decode_token_per_req = MagicMock()
proposer = MagicMock(spec=MtpProposer)
proposer = MagicMock(spec=AscendMtpProposer)
proposer.runner = mock_runner
proposer.pcp_size = 1
proposer.arange = torch.arange(100, dtype=torch.int32)
proposer.prepare_inputs_padded = MtpProposer.prepare_inputs_padded.__get__(
proposer.prepare_inputs_padded = AscendMtpProposer.prepare_inputs_padded.__get__(
proposer)
mock_valid_sampled_tokens_count = torch.tensor([2, 1, 2],