diff --git a/docs/source/user_guide/feature_guide/Multi_Token_Prediction.md b/docs/source/user_guide/feature_guide/Multi_Token_Prediction.md index 7052acc9..4bf312e4 100644 --- a/docs/source/user_guide/feature_guide/Multi_Token_Prediction.md +++ b/docs/source/user_guide/feature_guide/Multi_Token_Prediction.md @@ -47,7 +47,6 @@ mtp_proposer.py ├── Proposer │ ├── load_model │ ├── dummy_run -│ ├── generate_token_ids │ ├── _prepare_inputs │ ├── _propose ``` @@ -86,11 +85,11 @@ def get_spec_decode_method(method, device, runner): if method == "ngram": - return NgramProposer(vllm_config, device, runner) + return AscendNgramProposer(vllm_config, device, runner) elif method in ["eagle", "eagle3"]: - return EagleProposer(vllm_config, device, runner) + return AscendEagleProposer(vllm_config, device, runner) elif method == 'mtp': - return MtpProposer(vllm_config, device, runner) + return AscendMtpProposer(vllm_config, device, runner) else: raise ValueError("Unknown speculative decoding method: " f"{method}") diff --git a/tests/ut/spec_decode/test_eagle_proposer.py b/tests/ut/spec_decode/test_eagle_proposer.py index 5f9ab2f7..45233672 100644 --- a/tests/ut/spec_decode/test_eagle_proposer.py +++ b/tests/ut/spec_decode/test_eagle_proposer.py @@ -6,8 +6,7 @@ from vllm.config import CacheConfig, CompilationMode, CUDAGraphMode, VllmConfig, from tests.ut.base import TestBase from vllm_ascend.ascend_config import init_ascend_config -from vllm_ascend.spec_decode.eagle_proposer import EagleProposer -from vllm_ascend.spec_decode.interface import SpecDcodeType +from vllm_ascend.spec_decode.eagle_proposer import AscendEagleProposer class TestEagleProposerInitialization(TestBase): @@ -79,7 +78,7 @@ class TestEagleProposerInitialization(TestBase): init_ascend_config(self.vllm_config) with set_current_vllm_config(self.vllm_config): - proposer = EagleProposer(vllm_config=self.vllm_config, + proposer = AscendEagleProposer(vllm_config=self.vllm_config, device=self.device, runner=self.runner) @@ -102,7 +101,7 @@ class TestEagleProposerInitialization(TestBase): init_ascend_config(self.vllm_config) with set_current_vllm_config(self.vllm_config): - proposer = EagleProposer(vllm_config=self.vllm_config, + proposer = AscendEagleProposer(vllm_config=self.vllm_config, device=self.device, runner=self.runner) @@ -121,7 +120,7 @@ class TestEagleProposerInitialization(TestBase): init_ascend_config(self.vllm_config) with set_current_vllm_config(self.vllm_config): - proposer = EagleProposer(vllm_config=self.vllm_config, + proposer = AscendEagleProposer(vllm_config=self.vllm_config, device=self.device, runner=self.runner) @@ -140,7 +139,7 @@ class TestEagleProposerInitialization(TestBase): init_ascend_config(self.vllm_config) with set_current_vllm_config(self.vllm_config): - proposer = EagleProposer(vllm_config=self.vllm_config, + proposer = AscendEagleProposer(vllm_config=self.vllm_config, device=self.device, runner=self.runner) @@ -196,7 +195,7 @@ class TestEagleProposerLoadModel(TestBase): # Set the current vllm config set_current_vllm_config(self.vllm_config) - self.proposer = EagleProposer(vllm_config=self.vllm_config, + self.proposer = AscendEagleProposer(vllm_config=self.vllm_config, device=self.device, runner=self.runner) @@ -235,7 +234,6 @@ class TestEagleProposerLoadModel(TestBase): mock_model.model.embed_tokens = MagicMock() mock_model.model.embed_tokens.weight = weight - self.proposer.name = SpecDcodeType.EAGLE mock_get_model.return_value = MagicMock() mock_get_model.return_value.model.embed_tokens.weight = weight @@ 
-301,7 +299,6 @@ class TestEagleProposerLoadModel(TestBase): mock_pp_group.return_value.world_size = 2 self.proposer.model = MagicMock() - self.proposer.name = SpecDcodeType.EAGLE with set_current_vllm_config(self.vllm_config): self.proposer.load_model(mock_model) @@ -373,7 +370,7 @@ class TestEagleProposerDummyRun(TestBase): # Set the current vllm config set_current_vllm_config(self.vllm_config) - self.proposer = EagleProposer(vllm_config=self.vllm_config, + self.proposer = AscendEagleProposer(vllm_config=self.vllm_config, device=self.device, runner=self.runner) self.proposer.model = MagicMock() @@ -514,7 +511,7 @@ class TestEagleProposerHelperMethods(TestBase): # Set the current vllm config set_current_vllm_config(self.vllm_config) - self.proposer = EagleProposer(vllm_config=self.vllm_config, + self.proposer = AscendEagleProposer(vllm_config=self.vllm_config, device=self.device, runner=self.runner) diff --git a/tests/ut/spec_decode/test_mtp_proposer.py b/tests/ut/spec_decode/test_mtp_proposer.py index 0c7e7265..6a4d88f8 100644 --- a/tests/ut/spec_decode/test_mtp_proposer.py +++ b/tests/ut/spec_decode/test_mtp_proposer.py @@ -12,7 +12,7 @@ from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch from vllm_ascend.ascend_config import init_ascend_config from vllm_ascend.attention.utils import AscendCommonAttentionMetadata -from vllm_ascend.spec_decode.mtp_proposer import MtpProposer +from vllm_ascend.spec_decode.mtp_proposer import AscendMtpProposer class TestMtpProposer: @@ -96,7 +96,7 @@ class TestMtpProposer: # Test basic initialization with set_current_vllm_config(vllm_config): - proposer = MtpProposer(vllm_config, torch.device("cpu"), runner) + proposer = AscendMtpProposer(vllm_config, torch.device("cpu"), runner) assert proposer.vllm_config == vllm_config assert proposer.device == torch.device("cpu") @@ -118,7 +118,7 @@ class TestMtpProposer: vllm_config.scheduler_config.async_scheduling = False vllm_config.speculative_config.enforce_eager = False with set_current_vllm_config(vllm_config): - proposer = MtpProposer(vllm_config, torch.device("cpu"), runner) + proposer = AscendMtpProposer(vllm_config, torch.device("cpu"), runner) assert proposer.use_cuda_graph is True @@ -133,7 +133,7 @@ class TestMtpProposer: mock_cpu_gpu_buffer.return_value = mock_buffer_instance mock_dp_group.return_value.world_size = 1 with set_current_vllm_config(vllm_config): - proposer = MtpProposer(vllm_config, torch.device("cpu"), runner) + proposer = AscendMtpProposer(vllm_config, torch.device("cpu"), runner) # Mock _runnable to prevent actual execution proposer._runnable = MagicMock() @@ -165,7 +165,7 @@ class TestMtpProposer: mock_cpu_gpu_buffer.return_value = mock_buffer_instance mock_dp_group.return_value.world_size = 1 with set_current_vllm_config(vllm_config): - proposer = MtpProposer(vllm_config, torch.device("cpu"), runner) + proposer = AscendMtpProposer(vllm_config, torch.device("cpu"), runner) # Mock _runnable to prevent actual execution proposer._runnable = MagicMock() @@ -197,9 +197,9 @@ class TestMtpProposer: mock_gpu_batch.req_ids = ["req1", "req2", "req3"] mock_num_scheduled = {"req1": 0, "req2": 0, "req3": 0} - proposer = MagicMock(spec=MtpProposer) + proposer = MagicMock(spec=AscendMtpProposer) proposer.input_ids = MagicMock(device=torch.device("cpu")) - proposer.prepare_next_token_ids_cpu = MtpProposer.prepare_next_token_ids_cpu.__get__( + proposer.prepare_next_token_ids_cpu = AscendMtpProposer.prepare_next_token_ids_cpu.__get__( proposer) result = 
proposer.prepare_next_token_ids_cpu( sampled_token_ids=sampled_token_ids, @@ -253,10 +253,10 @@ class TestMtpProposer: mock_backup.copy_to_gpu = MagicMock() mock_cpu_gpu_buffer.return_value = mock_backup - proposer = MagicMock(spec=MtpProposer) + proposer = MagicMock(spec=AscendMtpProposer) proposer.backup_next_token_ids = mock_backup proposer.input_ids = MagicMock(device=torch.device("cpu")) - proposer.prepare_next_token_ids_padded = MtpProposer.prepare_next_token_ids_padded.__get__( + proposer.prepare_next_token_ids_padded = AscendMtpProposer.prepare_next_token_ids_padded.__get__( proposer) discard_request_indices = torch.tensor([1, 3], dtype=torch.int64) @@ -327,11 +327,11 @@ class TestMtpProposer: mock_runner.pcp_size = 1 mock_runner.decode_token_per_req = MagicMock() - proposer = MagicMock(spec=MtpProposer) + proposer = MagicMock(spec=AscendMtpProposer) proposer.runner = mock_runner proposer.pcp_size = 1 proposer.arange = torch.arange(100, dtype=torch.int32) - proposer.prepare_inputs_padded = MtpProposer.prepare_inputs_padded.__get__( + proposer.prepare_inputs_padded = AscendMtpProposer.prepare_inputs_padded.__get__( proposer) mock_valid_sampled_tokens_count = torch.tensor([2, 1, 2], diff --git a/vllm_ascend/spec_decode/__init__.py b/vllm_ascend/spec_decode/__init__.py index 6a1a66c9..5cfc6a70 100644 --- a/vllm_ascend/spec_decode/__init__.py +++ b/vllm_ascend/spec_decode/__init__.py @@ -16,23 +16,23 @@ # This file is a part of the vllm-ascend project. # Adapted from vllm-project/vllm/vllm/worker/gpu_model_runner.py # -from vllm_ascend.spec_decode.eagle_proposer import EagleProposer -from vllm_ascend.spec_decode.medusa_proposer import MedusaProposer -from vllm_ascend.spec_decode.mtp_proposer import MtpProposer -from vllm_ascend.spec_decode.ngram_proposer import NgramProposer -from vllm_ascend.spec_decode.suffix_proposer import SuffixDecodingProposer +from vllm_ascend.spec_decode.eagle_proposer import AscendEagleProposer +from vllm_ascend.spec_decode.medusa_proposer import AscendMedusaProposer +from vllm_ascend.spec_decode.mtp_proposer import AscendMtpProposer +from vllm_ascend.spec_decode.ngram_proposer import AscendNgramProposer +from vllm_ascend.spec_decode.suffix_proposer import AscendSuffixDecodingProposer def get_spec_decode_method(method, vllm_config, device, runner): if method == "ngram": - return NgramProposer(vllm_config, device, runner) - elif method in ("eagle", "eagle3"): - return EagleProposer(vllm_config, device, runner) - elif method == "mtp": - return MtpProposer(vllm_config, device, runner) + return AscendNgramProposer(vllm_config, runner) elif method == "suffix": - return SuffixDecodingProposer(vllm_config, device, runner) + return AscendSuffixDecodingProposer(vllm_config, runner) elif method == "medusa": - return MedusaProposer(vllm_config, device, runner) + return AscendMedusaProposer(vllm_config, device) + elif method in ("eagle", "eagle3"): + return AscendEagleProposer(vllm_config, device, runner) + elif method == "mtp": + return AscendMtpProposer(vllm_config, device, runner) else: raise ValueError(f"Unknown speculative decoding method: {method}") diff --git a/vllm_ascend/spec_decode/eagle_proposer.py b/vllm_ascend/spec_decode/eagle_proposer.py index 89ecb2e5..db6032c7 100644 --- a/vllm_ascend/spec_decode/eagle_proposer.py +++ b/vllm_ascend/spec_decode/eagle_proposer.py @@ -30,8 +30,7 @@ from vllm.utils.platform_utils import is_pin_memory_available from vllm.v1.attention.backends.utils import CommonAttentionMetadata from vllm.v1.core.sched.output import 
SchedulerOutput from vllm.v1.sample.metadata import SamplingMetadata -from vllm.v1.spec_decode.eagle import PADDING_SLOT_ID -from vllm.v1.spec_decode.eagle import EagleProposer as VllmEagleProposer +from vllm.v1.spec_decode.eagle import PADDING_SLOT_ID, EagleProposer from vllm.v1.spec_decode.metadata import SpecDecodeMetadata from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch @@ -81,7 +80,7 @@ def split_inputs_tp_to_sp(hidden_states, out): return out[:padded_num_tokens_per_rank] -class EagleProposer(VllmEagleProposer): +class AscendEagleProposer(EagleProposer): _runnable: ACLGraphWrapper | Callable def __init__(self, vllm_config: VllmConfig, device: torch.device, runner=None): diff --git a/vllm_ascend/spec_decode/interface.py b/vllm_ascend/spec_decode/interface.py deleted file mode 100644 index 42ac576a..00000000 --- a/vllm_ascend/spec_decode/interface.py +++ /dev/null @@ -1,53 +0,0 @@ -import enum - -import torch -from vllm.config import CUDAGraphMode, VllmConfig -from vllm.v1.core.sched.output import SchedulerOutput -from vllm.v1.sample.metadata import SamplingMetadata -from vllm.v1.spec_decode.metadata import SpecDecodeMetadata - - -class SpecDcodeType(enum.Enum): - NGRAM = 0 - EAGLE = 1 - EAGLE3 = 2 - MTP = 4 - SUFFIX = 5 - MEDUSA = 6 - - -class Proposer: - def __init__(self, vllm_config: VllmConfig, device: torch.device = None, runner=None): - pass - - def load_model(self, model): - """Called by load_model in model_runner""" - raise NotImplementedError - - @torch.inference_mode() - def dummy_run( - self, - num_tokens: int, - with_prefill: bool = False, - in_graph_capturing: bool = False, - num_reqs: int = 0, - num_tokens_across_dp: torch.Tensor | None = None, - aclgraph_runtime_mode: CUDAGraphMode = CUDAGraphMode.NONE, - batch_descriptor=None, - ): - """Called by dummy_run in model_runner""" - raise NotImplementedError - - def generate_token_ids( - self, - valid_sampled_token_ids: list[list[int]], - sampling_metadata: SamplingMetadata = None, - scheduler_output: SchedulerOutput = None, - spec_decode_metadata: SpecDecodeMetadata = None, - positions: torch.Tensor = None, - num_scheduled_tokens: int = 0, - hidden_states: torch.Tensor = None, - aux_hidden_states: torch.Tensor = None, - ): - """Called by execute_model in model_runner""" - raise NotImplementedError diff --git a/vllm_ascend/spec_decode/medusa_proposer.py b/vllm_ascend/spec_decode/medusa_proposer.py index ff727cc8..de62cfd2 100644 --- a/vllm_ascend/spec_decode/medusa_proposer.py +++ b/vllm_ascend/spec_decode/medusa_proposer.py @@ -1,36 +1,20 @@ import torch -from vllm.config import CUDAGraphMode, VllmConfig +from vllm.config import CUDAGraphMode from vllm.logger import init_logger from vllm.v1.sample.metadata import SamplingMetadata -from vllm.v1.spec_decode.medusa import MedusaProposer as VllmMedusaProposer +from vllm.v1.spec_decode.medusa import MedusaProposer from vllm.v1.spec_decode.metadata import SpecDecodeMetadata from vllm_ascend.ascend_forward_context import set_ascend_forward_context -from vllm_ascend.spec_decode.interface import SpecDcodeType logger = init_logger(__name__) -class MedusaProposer(VllmMedusaProposer): +class AscendMedusaProposer(MedusaProposer): """ Medusa proposer class for generating token sequences """ - def __init__( - self, - vllm_config: VllmConfig, - device: torch.device, - runner, - ): - # Save config parameters - self.name = SpecDcodeType.MEDUSA - self.vllm_config = vllm_config - self.device = device - self.max_num_tokens = 
vllm_config.scheduler_config.max_num_batched_tokens - self.hidden_size = vllm_config.speculative_config.draft_model_config.get_hidden_size() - self.dtype = vllm_config.model_config.dtype - self.runner = runner - @torch.inference_mode() def dummy_run( self, @@ -62,14 +46,12 @@ class MedusaProposer(VllmMedusaProposer): self.model(hidden_states) dummy_compute_logits(hidden_states) - def generate_token_ids( + def propose( self, valid_sampled_token_ids: list[list[int]], sampling_metadata: SamplingMetadata, spec_decode_metadata: SpecDecodeMetadata, sample_hidden_states: torch.Tensor, - *args, - **kwargs, ): if sample_hidden_states.shape[0] == len(valid_sampled_token_ids): # The input to the target model does not include draft tokens. @@ -84,7 +66,7 @@ class MedusaProposer(VllmMedusaProposer): indices = offsets + num_accepted_tokens - 1 hidden_states = sample_hidden_states[indices] - spec_token_ids = self.propose( + spec_token_ids = super().propose( target_hidden_states=hidden_states, sampling_metadata=sampling_metadata, ) diff --git a/vllm_ascend/spec_decode/mtp_proposer.py b/vllm_ascend/spec_decode/mtp_proposer.py index 5da99cae..249e4502 100644 --- a/vllm_ascend/spec_decode/mtp_proposer.py +++ b/vllm_ascend/spec_decode/mtp_proposer.py @@ -15,11 +15,11 @@ from vllm_ascend.attention.attention_v1 import AscendAttentionState from vllm_ascend.attention.utils import AscendCommonAttentionMetadata from vllm_ascend.compilation.acl_graph import ACLGraphWrapper from vllm_ascend.ops.rotary_embedding import get_cos_and_sin_mla, update_cos_sin -from vllm_ascend.spec_decode.eagle_proposer import EagleProposer +from vllm_ascend.spec_decode.eagle_proposer import AscendEagleProposer from vllm_ascend.utils import lmhead_tp_enable -class MtpProposer(EagleProposer): +class AscendMtpProposer(AscendEagleProposer): # TODO: Find out why ModelRunner does not this explicit typing? model: nn.Module | ACLGraphWrapper diff --git a/vllm_ascend/spec_decode/ngram_proposer.py b/vllm_ascend/spec_decode/ngram_proposer.py index 280d2ca8..3d698b84 100644 --- a/vllm_ascend/spec_decode/ngram_proposer.py +++ b/vllm_ascend/spec_decode/ngram_proposer.py @@ -1,16 +1,11 @@ import torch -from vllm.config import CUDAGraphMode -from vllm.v1.spec_decode.ngram_proposer import NgramProposer as VllmNgramProposer - -from vllm_ascend.spec_decode.interface import Proposer, SpecDcodeType +from vllm.v1.spec_decode.ngram_proposer import NgramProposer -class NgramProposer(VllmNgramProposer, Proposer): - def __init__(self, vllm_config, device, runner): - super().__init__(vllm_config) - self.name = SpecDcodeType.NGRAM - self.device = device +class AscendNgramProposer(NgramProposer): + def __init__(self, vllm_config, runner): self.runner = runner + super().__init__(vllm_config) def load_model(self, *args, **kwargs): # No model to load. 
@@ -24,26 +19,22 @@ class NgramProposer(VllmNgramProposer, Proposer): in_graph_capturing=None, num_reqs=None, num_tokens_across_dp=None, - aclgraph_runtime_mode: CUDAGraphMode = CUDAGraphMode.NONE, + aclgraph_runtime_mode=None, batch_descriptor=None, dummy_compute_logits=lambda hidden_states: None, is_profile=False, ): pass - def generate_token_ids( + def propose( self, - valid_sampled_token_ids, - sampling_metadata=None, - scheduler_output=None, - spec_decode_metadata=None, - positions=None, - num_scheduled_tokens=None, - hidden_states=None, - aux_hidden_states=None, + sampled_token_ids: list[list[int]], + num_tokens_no_spec=None, + token_ids_cpu=None, + slot_masks: dict[str, torch.Tensor] | list[dict[str, torch.Tensor]] | None = None, ) -> list[list[int]]: valid_ngram_requests = [] - for i, sampled_ids in enumerate(valid_sampled_token_ids): + for i, sampled_ids in enumerate(sampled_token_ids): num_sampled_ids = len(sampled_ids) if not num_sampled_ids: continue @@ -64,7 +55,7 @@ class NgramProposer(VllmNgramProposer, Proposer): valid_ngram_requests.append(i) draft_token_ids = self.batch_propose( - len(valid_sampled_token_ids), + len(sampled_token_ids), valid_ngram_requests, self.runner.input_batch.num_tokens_no_spec, self.runner.input_batch.token_ids_cpu, diff --git a/vllm_ascend/spec_decode/suffix_proposer.py b/vllm_ascend/spec_decode/suffix_proposer.py index 1cdbec3c..bb621a6e 100644 --- a/vllm_ascend/spec_decode/suffix_proposer.py +++ b/vllm_ascend/spec_decode/suffix_proposer.py @@ -1,22 +1,11 @@ -import torch -from vllm.config import CUDAGraphMode -from vllm.v1.spec_decode.suffix_decoding import SuffixDecodingProposer as VllmSuffixDecodingProposer - -from vllm_ascend.spec_decode.interface import Proposer, SpecDcodeType +from vllm.v1.spec_decode.suffix_decoding import SuffixDecodingProposer -class SuffixDecodingProposer(VllmSuffixDecodingProposer, Proposer): - def __init__(self, vllm_config, device, runner): +class AscendSuffixDecodingProposer(SuffixDecodingProposer): + def __init__(self, vllm_config, runner): super().__init__(vllm_config) - self.name = SpecDcodeType.SUFFIX - self.device = device self.runner = runner - def load_model(self, *args, **kwargs): - # No model to load. 
- pass - - @torch.inference_mode() def dummy_run( self, num_tokens, @@ -24,23 +13,12 @@ class SuffixDecodingProposer(VllmSuffixDecodingProposer, Proposer): in_graph_capturing=None, num_reqs=None, num_tokens_across_dp=None, - aclgraph_runtime_mode: CUDAGraphMode = CUDAGraphMode.NONE, + aclgraph_runtime_mode=None, batch_descriptor=None, dummy_compute_logits=lambda hidden_states: None, is_profile=False, ): pass - def generate_token_ids( - self, - valid_sampled_token_ids, - sampling_metadata=None, - scheduler_output=None, - spec_decode_metadata=None, - positions=None, - num_scheduled_tokens=None, - hidden_states=None, - aux_hidden_states=None, - ) -> list[list[int]]: - draft_token_ids = self.propose(self.runner.input_batch, valid_sampled_token_ids) - return draft_token_ids + def propose(self, valid_sampled_token_ids): + return super().propose(self.runner.input_batch, valid_sampled_token_ids) diff --git a/vllm_ascend/worker/model_runner_v1.py b/vllm_ascend/worker/model_runner_v1.py index af41b31a..748c7e01 100644 --- a/vllm_ascend/worker/model_runner_v1.py +++ b/vllm_ascend/worker/model_runner_v1.py @@ -74,8 +74,6 @@ from vllm.v1.sample.logits_processor import build_logitsprocs from vllm.v1.sample.metadata import SamplingMetadata from vllm.v1.sample.rejection_sampler import RejectionSampler from vllm.v1.spec_decode.metadata import SpecDecodeMetadata -from vllm.v1.spec_decode.ngram_proposer import NgramProposer -from vllm.v1.spec_decode.suffix_decoding import SuffixDecodingProposer from vllm.v1.structured_output.utils import apply_grammar_bitmask from vllm.v1.utils import record_function_or_nullcontext from vllm.v1.worker.gpu_model_runner import AsyncGPUModelRunnerOutput, GPUModelRunner @@ -109,9 +107,11 @@ from vllm_ascend.patch.worker.patch_module import patch_torch_npu_argsort from vllm_ascend.patch.worker.patch_qwen3_quarot import patch_load_weights from vllm_ascend.sample.sampler import AscendSampler from vllm_ascend.spec_decode import get_spec_decode_method -from vllm_ascend.spec_decode.eagle_proposer import EagleProposer -from vllm_ascend.spec_decode.medusa_proposer import MedusaProposer -from vllm_ascend.spec_decode.mtp_proposer import MtpProposer +from vllm_ascend.spec_decode.eagle_proposer import AscendEagleProposer +from vllm_ascend.spec_decode.medusa_proposer import AscendMedusaProposer +from vllm_ascend.spec_decode.mtp_proposer import AscendMtpProposer +from vllm_ascend.spec_decode.ngram_proposer import AscendNgramProposer +from vllm_ascend.spec_decode.suffix_proposer import AscendSuffixDecodingProposer from vllm_ascend.utils import ( check_gdn_layer, enable_sp, @@ -402,9 +402,14 @@ class NPUModelRunner(GPUModelRunner): def _set_up_drafter(self): # Set up speculative decoding. 
- self.drafter: NgramProposer | EagleProposer | MtpProposer | SuffixDecodingProposer | MedusaProposer | None = ( - None - ) + self.drafter: ( + AscendNgramProposer + | AscendEagleProposer + | AscendMtpProposer + | AscendSuffixDecodingProposer + | AscendMedusaProposer + | None + ) = None self.actual_seq_lengths_q: list[int] = [] self.decode_token_per_req = 1 if self.speculative_config: @@ -414,7 +419,7 @@ class NPUModelRunner(GPUModelRunner): if get_pp_group().is_last_rank: self.drafter = self._get_drafter() if self.speculative_config.method == "eagle3": - assert isinstance(self.drafter, EagleProposer) + assert isinstance(self.drafter, AscendEagleProposer) self.use_aux_hidden_state_outputs = self.drafter.eagle3_use_aux_hidden_state self.rejection_sampler = RejectionSampler(self.sampler) self.actual_seq_lengths_q = list( @@ -946,152 +951,134 @@ class NPUModelRunner(GPUModelRunner): positions: torch.Tensor, num_scheduled_tokens: int, hidden_states: torch.Tensor, - attn_metadata: list[dict[str, Any]] | dict[str, Any], aux_hidden_states: torch.Tensor = None, sample_hidden_states: torch.Tensor = None, ) -> list[list[int]] | None: if not self.drafter: # Speculative decoding is not enabled. draft_token_ids = None - else: - if self.speculative_config.method in ("suffix", "ngram"): - draft_token_ids = self.drafter.generate_token_ids( - valid_sampled_token_ids, - sampling_metadata, - scheduler_output, - spec_decode_metadata, - positions, - num_scheduled_tokens, - hidden_states, - aux_hidden_states, - ) - elif isinstance(self.drafter, MedusaProposer): - draft_token_ids = self.drafter.generate_token_ids( - valid_sampled_token_ids, sampling_metadata, spec_decode_metadata, sample_hidden_states - ) - elif self.speculative_config.use_eagle(): - common_attn_metadata = spec_decode_common_attn_metadata - sampled_token_ids = valid_sampled_token_ids + elif isinstance(self.drafter, (AscendNgramProposer, AscendSuffixDecodingProposer)): + draft_token_ids = self.drafter.propose(valid_sampled_token_ids) + elif isinstance(self.drafter, AscendMedusaProposer): + draft_token_ids = self.drafter.propose( + valid_sampled_token_ids, sampling_metadata, spec_decode_metadata, sample_hidden_states + ) + elif self.speculative_config.use_eagle(): + common_attn_metadata = spec_decode_common_attn_metadata + sampled_token_ids = valid_sampled_token_ids - if self.vllm_config.speculative_config.disable_padded_drafter_batch: - # When padded-batch is disabled, the sampled_token_ids should be - # the cpu-side list[list[int]] of valid sampled tokens for each - # request, with invalid requests having empty lists. - assert isinstance(sampled_token_ids, list), ( - "sampled_token_ids should be a python list whenpadded-batch is disabled." - ) - assert self.drafter is not None - next_token_ids = self.drafter.prepare_next_token_ids_cpu( - sampled_token_ids, self.requests, self.input_batch, scheduler_output.num_scheduled_tokens - ) - else: - # When using padded-batch, the sampled_token_ids should be - # the gpu tensor of sampled tokens for each request, of shape - # (num_reqs, num_spec_tokens + 1) with rejected tokens having - # value -1. - assert isinstance(sampled_token_ids, torch.Tensor), ( - "sampled_token_ids should be a torch.Tensor whenpadded-batch is enabled." 
- ) - assert self.drafter is not None - next_token_ids, valid_sampled_tokens_count = self.drafter.prepare_next_token_ids_padded( - common_attn_metadata, - sampled_token_ids, - self.requests, - self.input_batch, - self.discard_request_indices.gpu, - self.num_discarded_requests, - ) - self._copy_valid_sampled_token_count(next_token_ids, valid_sampled_tokens_count) - - req_scheduled_tokens = scheduler_output.num_scheduled_tokens - if self.use_cp: - long_seq_metadata = self.long_seq_metadata # type: ignore - input_ids_pcp_full = self.pcp_manager.input_ids_pcp_full.gpu - query_start_loc_pcp_full = self.pcp_manager.query_start_loc_pcp_full.gpu - query_start_loc_pcp_full_cpu = self.pcp_manager.query_start_loc_pcp_full.cpu - num_reqs = self.input_batch.num_reqs - num_prefill_reqs = self.pcp_manager.num_prefill_reqs - num_decode_reqs = self.pcp_manager.num_decode_reqs - else: - long_seq_metadata = None # type: ignore - num_prefill_reqs = 0 - num_decode_reqs = 0 - if spec_decode_metadata is None: - # update pcp related params - if self.pcp_size > 1: - token_indices_to_sample = query_start_loc_pcp_full[1 : num_reqs + 1] - 1 - target_token_ids = input_ids_pcp_full[:num_scheduled_tokens] - target_positions = self._get_positions(num_scheduled_tokens) - target_hidden_states = hidden_states - if self.use_aux_hidden_state_outputs: - target_hidden_states = torch.cat( - [h[:num_scheduled_tokens] for h in aux_hidden_states], dim=-1 - ) - else: - token_indices_to_sample = None - # input_ids can be None for multimodal models. - target_token_ids = self.input_ids.gpu[:num_scheduled_tokens] - target_positions = self._get_positions(num_scheduled_tokens) - if self.use_aux_hidden_state_outputs: - target_hidden_states = torch.cat( - [h[:num_scheduled_tokens] for h in aux_hidden_states], dim=-1 - ) - else: - target_hidden_states = hidden_states[:num_scheduled_tokens] - else: - if self.pcp_size > 1: - assert common_attn_metadata is not None - common_attn_metadata.query_start_loc_cpu[: num_reqs + 1] = query_start_loc_pcp_full_cpu[ - : num_reqs + 1 - ] - assert common_attn_metadata is not None - common_attn_metadata.query_start_loc[: num_reqs + 1] = query_start_loc_pcp_full[: num_reqs + 1] - if self.vllm_config.speculative_config.disable_padded_drafter_batch: - # NOTE: Currently, MTP-fullgraph is incompatibility with pcp - token_indices_to_sample = None - assert self.drafter is not None - common_attn_metadata, token_indices = self.drafter.prepare_inputs( - common_attn_metadata, sampled_token_ids, spec_decode_metadata.num_draft_tokens - ) - else: - assert self.drafter is not None - common_attn_metadata, token_indices, token_indices_to_sample = ( - self.drafter.prepare_inputs_padded( - common_attn_metadata, spec_decode_metadata, valid_sampled_tokens_count - ) - ) - if self.pcp_size > 1: - target_token_ids = input_ids_pcp_full[token_indices] - target_positions = positions - target_hidden_states = hidden_states - if self.use_aux_hidden_state_outputs: - target_hidden_states = torch.cat([h[token_indices] for h in aux_hidden_states], dim=-1) - else: - target_token_ids = self.input_ids.gpu[token_indices] - target_positions = self._get_positions(token_indices) - if self.use_aux_hidden_state_outputs: - target_hidden_states = torch.cat([h[token_indices] for h in aux_hidden_states], dim=-1) - else: - target_hidden_states = hidden_states[token_indices] + if self.vllm_config.speculative_config.disable_padded_drafter_batch: + # When padded-batch is disabled, the sampled_token_ids should be + # the cpu-side list[list[int]] of valid 
sampled tokens for each
+                # request, with invalid requests having empty lists.
+                assert isinstance(sampled_token_ids, list), (
+                    "sampled_token_ids should be a python list when padded-batch is disabled."
+                )
                 assert self.drafter is not None
+                next_token_ids = self.drafter.prepare_next_token_ids_cpu(
+                    sampled_token_ids, self.requests, self.input_batch, scheduler_output.num_scheduled_tokens
                 )
+            else:
+                # When using padded-batch, the sampled_token_ids should be
+                # the gpu tensor of sampled tokens for each request, of shape
+                # (num_reqs, num_spec_tokens + 1) with rejected tokens having
+                # value -1.
+                assert isinstance(sampled_token_ids, torch.Tensor), (
+                    "sampled_token_ids should be a torch.Tensor when padded-batch is enabled."
+                )
+                assert self.drafter is not None
+                next_token_ids, valid_sampled_tokens_count = self.drafter.prepare_next_token_ids_padded(
+                    common_attn_metadata,
+                    sampled_token_ids,
+                    self.requests,
+                    self.input_batch,
+                    self.discard_request_indices.gpu,
+                    self.num_discarded_requests,
+                )
+                self._copy_valid_sampled_token_count(next_token_ids, valid_sampled_tokens_count)
+
+            req_scheduled_tokens = scheduler_output.num_scheduled_tokens
+            if self.use_cp:
+                long_seq_metadata = self.long_seq_metadata  # type: ignore
+                input_ids_pcp_full = self.pcp_manager.input_ids_pcp_full.gpu
+                query_start_loc_pcp_full = self.pcp_manager.query_start_loc_pcp_full.gpu
+                query_start_loc_pcp_full_cpu = self.pcp_manager.query_start_loc_pcp_full.cpu
+                num_reqs = self.input_batch.num_reqs
+                num_prefill_reqs = self.pcp_manager.num_prefill_reqs
+                num_decode_reqs = self.pcp_manager.num_decode_reqs
+            else:
+                long_seq_metadata = None  # type: ignore
+                num_prefill_reqs = 0
+                num_decode_reqs = 0
+            if spec_decode_metadata is None:
+                # update pcp related params
+                if self.pcp_size > 1:
+                    token_indices_to_sample = query_start_loc_pcp_full[1 : num_reqs + 1] - 1
+                    target_token_ids = input_ids_pcp_full[:num_scheduled_tokens]
+                    target_positions = self._get_positions(num_scheduled_tokens)
+                    target_hidden_states = hidden_states
+                    if self.use_aux_hidden_state_outputs:
+                        target_hidden_states = torch.cat([h[:num_scheduled_tokens] for h in aux_hidden_states], dim=-1)
+                else:
+                    token_indices_to_sample = None
+                    # input_ids can be None for multimodal models.
+                    target_token_ids = self.input_ids.gpu[:num_scheduled_tokens]
+                    target_positions = self._get_positions(num_scheduled_tokens)
+                    if self.use_aux_hidden_state_outputs:
+                        target_hidden_states = torch.cat([h[:num_scheduled_tokens] for h in aux_hidden_states], dim=-1)
+                    else:
+                        target_hidden_states = hidden_states[:num_scheduled_tokens]
+            else:
+                if self.pcp_size > 1:
+                    assert common_attn_metadata is not None
+                    common_attn_metadata.query_start_loc_cpu[: num_reqs + 1] = query_start_loc_pcp_full_cpu[
+                        : num_reqs + 1
+                    ]
+                    assert common_attn_metadata is not None
+                    common_attn_metadata.query_start_loc[: num_reqs + 1] = query_start_loc_pcp_full[: num_reqs + 1]
+                if self.vllm_config.speculative_config.disable_padded_drafter_batch:
+                    # NOTE: Currently, MTP-fullgraph is incompatible with pcp
+                    token_indices_to_sample = None
+                    assert self.drafter is not None
+                    common_attn_metadata, token_indices = self.drafter.prepare_inputs(
+                        common_attn_metadata, sampled_token_ids, spec_decode_metadata.num_draft_tokens
+                    )
+                else:
+                    assert self.drafter is not None
+                    common_attn_metadata, token_indices, token_indices_to_sample = self.drafter.prepare_inputs_padded(
+                        common_attn_metadata, spec_decode_metadata, valid_sampled_tokens_count
+                    )
+                if self.pcp_size > 1:
+                    target_token_ids = input_ids_pcp_full[token_indices]
+                    target_positions = positions
+                    target_hidden_states = hidden_states
+                    if self.use_aux_hidden_state_outputs:
+                        target_hidden_states = torch.cat([h[token_indices] for h in aux_hidden_states], dim=-1)
+                else:
+                    target_token_ids = self.input_ids.gpu[token_indices]
+                    target_positions = self._get_positions(token_indices)
+                    if self.use_aux_hidden_state_outputs:
+                        target_hidden_states = torch.cat([h[token_indices] for h in aux_hidden_states], dim=-1)
+                    else:
+                        target_hidden_states = hidden_states[token_indices]
+            assert self.drafter is not None
+            draft_token_ids = self.drafter._propose(
+                target_token_ids=target_token_ids,
+                target_positions=target_positions,
+                target_hidden_states=target_hidden_states,
+                next_token_ids=next_token_ids,
+                last_token_indices=token_indices_to_sample,
+                common_attn_metadata=common_attn_metadata,
+                sampling_metadata=sampling_metadata,
+                req_scheduled_tokens=req_scheduled_tokens,
+                long_seq_metadata=long_seq_metadata,
+                num_prefill_reqs=num_prefill_reqs,
+                num_decode_reqs=num_decode_reqs,
+                scheduler_output=scheduler_output,
+                num_scheduled_tokens=num_scheduled_tokens,
+            )
+        else:
+            raise ValueError(f"Unknown speculative decoding method: {self.speculative_config.method}")
         return draft_token_ids
@@ -1460,7 +1447,6 @@ class NPUModelRunner(GPUModelRunner):
                 positions,
                 scheduler_output.total_num_scheduled_tokens,
                 hidden_states,
-                attn_metadata,
                 aux_hidden_states,
                 sample_hidden_states,
             )
@@ -2088,7 +2074,7 @@ class NPUModelRunner(GPUModelRunner):
             if kv_cache_gid > 0:
                 cm.block_table_tensor, cm.slot_mapping = _get_block_table_and_slot_mapping(kv_cache_gid)
             if self.speculative_config and spec_decode_common_attn_metadata is None:
-                if isinstance(self.drafter, EagleProposer):
+                if isinstance(self.drafter, AscendEagleProposer):
                     if self.drafter.attn_layer_names[0] in kv_cache_group.layer_names:
                         spec_decode_common_attn_metadata = cm
                     else:
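
The renamed dispatch helper in `vllm_ascend/spec_decode/__init__.py` keeps the `(method, vllm_config, device, runner)` signature while routing to the new `Ascend*` proposer classes. A minimal usage sketch follows; `build_drafter` is a hypothetical wrapper for illustration only, not code added by this patch.

```python
# Illustrative sketch: obtaining a drafter through the renamed dispatch helper.
from vllm_ascend.spec_decode import get_spec_decode_method


def build_drafter(vllm_config, device, runner):
    # method is one of "ngram", "suffix", "medusa", "eagle", "eagle3", "mtp";
    # any other value raises ValueError inside get_spec_decode_method.
    method = vllm_config.speculative_config.method
    return get_spec_decode_method(method, vllm_config, device, runner)
```

Note that after this change the ngram and suffix proposers are constructed without a `device` argument and the medusa proposer without a `runner`, which is why the model runner now dispatches on proposer type instead of a shared `generate_token_ids` interface.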
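The comments added in the eagle/MTP drafting path describe two layouts for `sampled_token_ids`, selected by `disable_padded_drafter_batch`. The sketch below only illustrates those two shapes; the token values are made up for the example.

```python
# Illustrative only: the two sampled_token_ids layouts asserted on above.
import torch

# padded-batch disabled: CPU-side list of per-request token lists; requests
# whose samples were discarded have empty lists.
sampled_token_ids_cpu: list[list[int]] = [[101, 102], [], [7]]

# padded-batch enabled: a (num_reqs, num_spec_tokens + 1) tensor in which
# rejected draft positions hold -1.
sampled_token_ids_padded = torch.tensor(
    [
        [101, 102, -1],
        [7, -1, -1],
    ]
)
```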