[Spec Decode]clean up spec decode interface (#6947)
This pull request refactors the speculative decoding proposer interface
to align with upstream vLLM, removing the local `Proposer` interface and
renaming methods to `propose`.
This is the first step. In the future we should remove the class
register and just add few Ascend specified method once the arch in vLLM
is ready.
- vLLM version: v0.16.0
- vLLM main:
15d76f74e2
Signed-off-by: wangxiyuan <wangxiyuan1007@gmail.com>
This commit is contained in:
@@ -6,8 +6,7 @@ from vllm.config import CacheConfig, CompilationMode, CUDAGraphMode, VllmConfig,
|
||||
|
||||
from tests.ut.base import TestBase
|
||||
from vllm_ascend.ascend_config import init_ascend_config
|
||||
from vllm_ascend.spec_decode.eagle_proposer import EagleProposer
|
||||
from vllm_ascend.spec_decode.interface import SpecDcodeType
|
||||
from vllm_ascend.spec_decode.eagle_proposer import AscendEagleProposer
|
||||
|
||||
|
||||
class TestEagleProposerInitialization(TestBase):
|
||||
@@ -79,7 +78,7 @@ class TestEagleProposerInitialization(TestBase):
|
||||
init_ascend_config(self.vllm_config)
|
||||
|
||||
with set_current_vllm_config(self.vllm_config):
|
||||
proposer = EagleProposer(vllm_config=self.vllm_config,
|
||||
proposer = AscendEagleProposer(vllm_config=self.vllm_config,
|
||||
device=self.device,
|
||||
runner=self.runner)
|
||||
|
||||
@@ -102,7 +101,7 @@ class TestEagleProposerInitialization(TestBase):
|
||||
init_ascend_config(self.vllm_config)
|
||||
|
||||
with set_current_vllm_config(self.vllm_config):
|
||||
proposer = EagleProposer(vllm_config=self.vllm_config,
|
||||
proposer = AscendEagleProposer(vllm_config=self.vllm_config,
|
||||
device=self.device,
|
||||
runner=self.runner)
|
||||
|
||||
@@ -121,7 +120,7 @@ class TestEagleProposerInitialization(TestBase):
|
||||
init_ascend_config(self.vllm_config)
|
||||
|
||||
with set_current_vllm_config(self.vllm_config):
|
||||
proposer = EagleProposer(vllm_config=self.vllm_config,
|
||||
proposer = AscendEagleProposer(vllm_config=self.vllm_config,
|
||||
device=self.device,
|
||||
runner=self.runner)
|
||||
|
||||
@@ -140,7 +139,7 @@ class TestEagleProposerInitialization(TestBase):
|
||||
init_ascend_config(self.vllm_config)
|
||||
|
||||
with set_current_vllm_config(self.vllm_config):
|
||||
proposer = EagleProposer(vllm_config=self.vllm_config,
|
||||
proposer = AscendEagleProposer(vllm_config=self.vllm_config,
|
||||
device=self.device,
|
||||
runner=self.runner)
|
||||
|
||||
@@ -196,7 +195,7 @@ class TestEagleProposerLoadModel(TestBase):
|
||||
|
||||
# Set the current vllm config
|
||||
set_current_vllm_config(self.vllm_config)
|
||||
self.proposer = EagleProposer(vllm_config=self.vllm_config,
|
||||
self.proposer = AscendEagleProposer(vllm_config=self.vllm_config,
|
||||
device=self.device,
|
||||
runner=self.runner)
|
||||
|
||||
@@ -235,7 +234,6 @@ class TestEagleProposerLoadModel(TestBase):
|
||||
mock_model.model.embed_tokens = MagicMock()
|
||||
mock_model.model.embed_tokens.weight = weight
|
||||
|
||||
self.proposer.name = SpecDcodeType.EAGLE
|
||||
mock_get_model.return_value = MagicMock()
|
||||
mock_get_model.return_value.model.embed_tokens.weight = weight
|
||||
|
||||
@@ -301,7 +299,6 @@ class TestEagleProposerLoadModel(TestBase):
|
||||
mock_pp_group.return_value.world_size = 2
|
||||
|
||||
self.proposer.model = MagicMock()
|
||||
self.proposer.name = SpecDcodeType.EAGLE
|
||||
|
||||
with set_current_vllm_config(self.vllm_config):
|
||||
self.proposer.load_model(mock_model)
|
||||
@@ -373,7 +370,7 @@ class TestEagleProposerDummyRun(TestBase):
|
||||
|
||||
# Set the current vllm config
|
||||
set_current_vllm_config(self.vllm_config)
|
||||
self.proposer = EagleProposer(vllm_config=self.vllm_config,
|
||||
self.proposer = AscendEagleProposer(vllm_config=self.vllm_config,
|
||||
device=self.device,
|
||||
runner=self.runner)
|
||||
self.proposer.model = MagicMock()
|
||||
@@ -514,7 +511,7 @@ class TestEagleProposerHelperMethods(TestBase):
|
||||
|
||||
# Set the current vllm config
|
||||
set_current_vllm_config(self.vllm_config)
|
||||
self.proposer = EagleProposer(vllm_config=self.vllm_config,
|
||||
self.proposer = AscendEagleProposer(vllm_config=self.vllm_config,
|
||||
device=self.device,
|
||||
runner=self.runner)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user