[Spec Decode] Clean up spec decode interface (#6947)
This pull request refactors the speculative decoding proposer interface
to align with upstream vLLM, removing the local `Proposer` interface and
renaming methods to `propose`.
This is the first step. In the future we should remove the class
registry and just add a few Ascend-specific methods once the
architecture in vLLM is ready.
- vLLM version: v0.16.0
- vLLM main:
15d76f74e2
Signed-off-by: wangxiyuan <wangxiyuan1007@gmail.com>
This commit is contained in:
@@ -6,8 +6,7 @@ from vllm.config import CacheConfig, CompilationMode, CUDAGraphMode, VllmConfig,
|
||||
|
||||
from tests.ut.base import TestBase
|
||||
from vllm_ascend.ascend_config import init_ascend_config
|
||||
from vllm_ascend.spec_decode.eagle_proposer import EagleProposer
|
||||
from vllm_ascend.spec_decode.interface import SpecDcodeType
|
||||
from vllm_ascend.spec_decode.eagle_proposer import AscendEagleProposer
|
||||
|
||||
|
||||
class TestEagleProposerInitialization(TestBase):
|
||||
@@ -79,7 +78,7 @@ class TestEagleProposerInitialization(TestBase):
|
||||
init_ascend_config(self.vllm_config)
|
||||
|
||||
with set_current_vllm_config(self.vllm_config):
|
||||
proposer = EagleProposer(vllm_config=self.vllm_config,
|
||||
proposer = AscendEagleProposer(vllm_config=self.vllm_config,
|
||||
device=self.device,
|
||||
runner=self.runner)
|
||||
|
||||
@@ -102,7 +101,7 @@ class TestEagleProposerInitialization(TestBase):
|
||||
init_ascend_config(self.vllm_config)
|
||||
|
||||
with set_current_vllm_config(self.vllm_config):
|
||||
proposer = EagleProposer(vllm_config=self.vllm_config,
|
||||
proposer = AscendEagleProposer(vllm_config=self.vllm_config,
|
||||
device=self.device,
|
||||
runner=self.runner)
|
||||
|
||||
@@ -121,7 +120,7 @@ class TestEagleProposerInitialization(TestBase):
|
||||
init_ascend_config(self.vllm_config)
|
||||
|
||||
with set_current_vllm_config(self.vllm_config):
|
||||
proposer = EagleProposer(vllm_config=self.vllm_config,
|
||||
proposer = AscendEagleProposer(vllm_config=self.vllm_config,
|
||||
device=self.device,
|
||||
runner=self.runner)
|
||||
|
||||
@@ -140,7 +139,7 @@ class TestEagleProposerInitialization(TestBase):
|
||||
init_ascend_config(self.vllm_config)
|
||||
|
||||
with set_current_vllm_config(self.vllm_config):
|
||||
proposer = EagleProposer(vllm_config=self.vllm_config,
|
||||
proposer = AscendEagleProposer(vllm_config=self.vllm_config,
|
||||
device=self.device,
|
||||
runner=self.runner)
|
||||
|
||||
@@ -196,7 +195,7 @@ class TestEagleProposerLoadModel(TestBase):
|
||||
|
||||
# Set the current vllm config
|
||||
set_current_vllm_config(self.vllm_config)
|
||||
self.proposer = EagleProposer(vllm_config=self.vllm_config,
|
||||
self.proposer = AscendEagleProposer(vllm_config=self.vllm_config,
|
||||
device=self.device,
|
||||
runner=self.runner)
|
||||
|
||||
@@ -235,7 +234,6 @@ class TestEagleProposerLoadModel(TestBase):
|
||||
mock_model.model.embed_tokens = MagicMock()
|
||||
mock_model.model.embed_tokens.weight = weight
|
||||
|
||||
self.proposer.name = SpecDcodeType.EAGLE
|
||||
mock_get_model.return_value = MagicMock()
|
||||
mock_get_model.return_value.model.embed_tokens.weight = weight
|
||||
|
||||
@@ -301,7 +299,6 @@ class TestEagleProposerLoadModel(TestBase):
|
||||
mock_pp_group.return_value.world_size = 2
|
||||
|
||||
self.proposer.model = MagicMock()
|
||||
self.proposer.name = SpecDcodeType.EAGLE
|
||||
|
||||
with set_current_vllm_config(self.vllm_config):
|
||||
self.proposer.load_model(mock_model)
|
||||
@@ -373,7 +370,7 @@ class TestEagleProposerDummyRun(TestBase):
|
||||
|
||||
# Set the current vllm config
|
||||
set_current_vllm_config(self.vllm_config)
|
||||
self.proposer = EagleProposer(vllm_config=self.vllm_config,
|
||||
self.proposer = AscendEagleProposer(vllm_config=self.vllm_config,
|
||||
device=self.device,
|
||||
runner=self.runner)
|
||||
self.proposer.model = MagicMock()
|
||||
@@ -514,7 +511,7 @@ class TestEagleProposerHelperMethods(TestBase):
|
||||
|
||||
# Set the current vllm config
|
||||
set_current_vllm_config(self.vllm_config)
|
||||
self.proposer = EagleProposer(vllm_config=self.vllm_config,
|
||||
self.proposer = AscendEagleProposer(vllm_config=self.vllm_config,
|
||||
device=self.device,
|
||||
runner=self.runner)
|
||||
|
||||
|
||||
@@ -12,7 +12,7 @@ from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch
|
||||
|
||||
from vllm_ascend.ascend_config import init_ascend_config
|
||||
from vllm_ascend.attention.utils import AscendCommonAttentionMetadata
|
||||
from vllm_ascend.spec_decode.mtp_proposer import MtpProposer
|
||||
from vllm_ascend.spec_decode.mtp_proposer import AscendMtpProposer
|
||||
|
||||
|
||||
class TestMtpProposer:
|
||||
@@ -96,7 +96,7 @@ class TestMtpProposer:
|
||||
|
||||
# Test basic initialization
|
||||
with set_current_vllm_config(vllm_config):
|
||||
proposer = MtpProposer(vllm_config, torch.device("cpu"), runner)
|
||||
proposer = AscendMtpProposer(vllm_config, torch.device("cpu"), runner)
|
||||
|
||||
assert proposer.vllm_config == vllm_config
|
||||
assert proposer.device == torch.device("cpu")
|
||||
@@ -118,7 +118,7 @@ class TestMtpProposer:
|
||||
vllm_config.scheduler_config.async_scheduling = False
|
||||
vllm_config.speculative_config.enforce_eager = False
|
||||
with set_current_vllm_config(vllm_config):
|
||||
proposer = MtpProposer(vllm_config, torch.device("cpu"), runner)
|
||||
proposer = AscendMtpProposer(vllm_config, torch.device("cpu"), runner)
|
||||
|
||||
assert proposer.use_cuda_graph is True
|
||||
|
||||
@@ -133,7 +133,7 @@ class TestMtpProposer:
|
||||
mock_cpu_gpu_buffer.return_value = mock_buffer_instance
|
||||
mock_dp_group.return_value.world_size = 1
|
||||
with set_current_vllm_config(vllm_config):
|
||||
proposer = MtpProposer(vllm_config, torch.device("cpu"), runner)
|
||||
proposer = AscendMtpProposer(vllm_config, torch.device("cpu"), runner)
|
||||
|
||||
# Mock _runnable to prevent actual execution
|
||||
proposer._runnable = MagicMock()
|
||||
@@ -165,7 +165,7 @@ class TestMtpProposer:
|
||||
mock_cpu_gpu_buffer.return_value = mock_buffer_instance
|
||||
mock_dp_group.return_value.world_size = 1
|
||||
with set_current_vllm_config(vllm_config):
|
||||
proposer = MtpProposer(vllm_config, torch.device("cpu"), runner)
|
||||
proposer = AscendMtpProposer(vllm_config, torch.device("cpu"), runner)
|
||||
|
||||
# Mock _runnable to prevent actual execution
|
||||
proposer._runnable = MagicMock()
|
||||
@@ -197,9 +197,9 @@ class TestMtpProposer:
|
||||
mock_gpu_batch.req_ids = ["req1", "req2", "req3"]
|
||||
mock_num_scheduled = {"req1": 0, "req2": 0, "req3": 0}
|
||||
|
||||
proposer = MagicMock(spec=MtpProposer)
|
||||
proposer = MagicMock(spec=AscendMtpProposer)
|
||||
proposer.input_ids = MagicMock(device=torch.device("cpu"))
|
||||
proposer.prepare_next_token_ids_cpu = MtpProposer.prepare_next_token_ids_cpu.__get__(
|
||||
proposer.prepare_next_token_ids_cpu = AscendMtpProposer.prepare_next_token_ids_cpu.__get__(
|
||||
proposer)
|
||||
result = proposer.prepare_next_token_ids_cpu(
|
||||
sampled_token_ids=sampled_token_ids,
|
||||
@@ -253,10 +253,10 @@ class TestMtpProposer:
|
||||
mock_backup.copy_to_gpu = MagicMock()
|
||||
mock_cpu_gpu_buffer.return_value = mock_backup
|
||||
|
||||
proposer = MagicMock(spec=MtpProposer)
|
||||
proposer = MagicMock(spec=AscendMtpProposer)
|
||||
proposer.backup_next_token_ids = mock_backup
|
||||
proposer.input_ids = MagicMock(device=torch.device("cpu"))
|
||||
proposer.prepare_next_token_ids_padded = MtpProposer.prepare_next_token_ids_padded.__get__(
|
||||
proposer.prepare_next_token_ids_padded = AscendMtpProposer.prepare_next_token_ids_padded.__get__(
|
||||
proposer)
|
||||
|
||||
discard_request_indices = torch.tensor([1, 3], dtype=torch.int64)
|
||||
@@ -327,11 +327,11 @@ class TestMtpProposer:
|
||||
mock_runner.pcp_size = 1
|
||||
mock_runner.decode_token_per_req = MagicMock()
|
||||
|
||||
proposer = MagicMock(spec=MtpProposer)
|
||||
proposer = MagicMock(spec=AscendMtpProposer)
|
||||
proposer.runner = mock_runner
|
||||
proposer.pcp_size = 1
|
||||
proposer.arange = torch.arange(100, dtype=torch.int32)
|
||||
proposer.prepare_inputs_padded = MtpProposer.prepare_inputs_padded.__get__(
|
||||
proposer.prepare_inputs_padded = AscendMtpProposer.prepare_inputs_padded.__get__(
|
||||
proposer)
|
||||
|
||||
mock_valid_sampled_tokens_count = torch.tensor([2, 1, 2],
|
||||
|
||||
Reference in New Issue
Block a user