diff --git a/vllm_ascend/spec_decode/__init__.py b/vllm_ascend/spec_decode/__init__.py
index df5015f1..89f558db 100644
--- a/vllm_ascend/spec_decode/__init__.py
+++ b/vllm_ascend/spec_decode/__init__.py
@@ -17,6 +17,7 @@
 # Adapted from vllm-project/vllm/vllm/worker/gpu_model_runner.py
 #
 from vllm_ascend.spec_decode.eagle_proposer import EagleProposer
+from vllm_ascend.spec_decode.medusa_proposer import MedusaProposer
 from vllm_ascend.spec_decode.mtp_proposer import MtpProposer
 from vllm_ascend.spec_decode.ngram_proposer import NgramProposer
 from vllm_ascend.spec_decode.suffix_proposer import SuffixDecodingProposer
@@ -31,6 +32,8 @@ def get_spec_decode_method(method, vllm_config, device, runner):
         return MtpProposer(vllm_config, device, runner)
     elif method == 'suffix':
         return SuffixDecodingProposer(vllm_config, device, runner)
+    elif method == "medusa":
+        return MedusaProposer(vllm_config, device, runner)
     else:
         raise ValueError("Unknown speculative decoding method: "
                          f"{method}")
diff --git a/vllm_ascend/spec_decode/interface.py b/vllm_ascend/spec_decode/interface.py
index f7f92ddb..feec5bcf 100644
--- a/vllm_ascend/spec_decode/interface.py
+++ b/vllm_ascend/spec_decode/interface.py
@@ -14,6 +14,7 @@ class SpecDcodeType(enum.Enum):
     EAGLE3 = 2
     MTP = 4
     SUFFIX = 5
+    MEDUSA = 6
 
 
 class Proposer:
diff --git a/vllm_ascend/spec_decode/medusa_proposer.py b/vllm_ascend/spec_decode/medusa_proposer.py
new file mode 100644
index 00000000..eda2e41a
--- /dev/null
+++ b/vllm_ascend/spec_decode/medusa_proposer.py
@@ -0,0 +1,98 @@
+from typing import Optional
+
+import torch
+import torch.nn as nn
+from vllm.config import CUDAGraphMode, VllmConfig
+from vllm.logger import init_logger
+from vllm.model_executor.model_loader import get_model
+from vllm.model_executor.models.interfaces import is_mixture_of_experts
+from vllm.v1.sample.metadata import SamplingMetadata
+from vllm.v1.spec_decode.medusa import MedusaProposer as VllmMedusaProposer
+from vllm.v1.spec_decode.metadata import SpecDecodeMetadata
+
+from vllm_ascend.ascend_forward_context import set_ascend_forward_context
+from vllm_ascend.spec_decode.interface import SpecDcodeType
+
+logger = init_logger(__name__)
+
+
+class MedusaProposer(VllmMedusaProposer):
+    """
+    Medusa proposer class for generating draft token sequences.
+    """
+
+    def __init__(
+        self,
+        vllm_config: VllmConfig,
+        device: torch.device,
+        runner,
+    ):
+        # Save config parameters
+        self.name = SpecDcodeType.MEDUSA
+        self.vllm_config = vllm_config
+        self.device = device
+        self.max_num_tokens = (
+            vllm_config.scheduler_config.max_num_batched_tokens)
+        self.hidden_size = (vllm_config.speculative_config.draft_model_config.
+                            get_hidden_size())
+        self.dtype = vllm_config.model_config.dtype
+        self.runner = runner
+
+    @torch.inference_mode()
+    def dummy_run(self,
+                  num_tokens: int,
+                  with_prefill: bool = False,
+                  in_graph_capturing: bool = False,
+                  num_reqs: int = 0,
+                  num_tokens_across_dp: Optional[torch.Tensor] = None,
+                  aclgraph_runtime_mode: CUDAGraphMode = CUDAGraphMode.NONE,
+                  batch_descriptor=None,
+                  dummy_compute_logits=lambda hidden_states: None,
+                  is_profile=False):
+        hidden_states = torch.zeros(
+            (self.max_num_tokens, self.hidden_size),
+            dtype=self.dtype,
+            device=self.device,
+        )
+        with set_ascend_forward_context(
+                None,
+                self.vllm_config,
+                num_tokens=num_tokens,
+                num_actual_tokens=0,
+                in_profile_run=is_profile,
+                batch_descriptor=batch_descriptor,
+                aclgraph_runtime_mode=aclgraph_runtime_mode,
+                is_draft_model=True):
+            self.model(hidden_states)
+            dummy_compute_logits(hidden_states)
+
+    def generate_token_ids(self, valid_sampled_token_ids: list[list[int]],
+                           sampling_metadata: SamplingMetadata,
+                           spec_decode_metadata: SpecDecodeMetadata,
+                           sample_hidden_states: torch.Tensor,
+                           *args, **kwargs):
+        if sample_hidden_states.shape[0] == len(valid_sampled_token_ids):
+            # The input to the target model does not include draft tokens.
+            hidden_states = sample_hidden_states
+        else:
+            # Recover, for each request, the hidden state of its last
+            # accepted token from the flat sampled batch.
+            num_accepted_tokens = torch.tensor(
+                [len(t) for t in valid_sampled_token_ids],
+                device=self.device,
+                dtype=torch.long)
+            num_draft_tokens = torch.tensor(
+                spec_decode_metadata.num_draft_tokens,
+                device=self.device,
+                dtype=torch.long)
+
+            offsets = torch.cumsum(num_draft_tokens + 1,
+                                   dim=0) - (num_draft_tokens + 1)
+            indices = offsets + num_accepted_tokens - 1
+            hidden_states = sample_hidden_states[indices]
+
+        spec_token_ids = self.propose(
+            target_hidden_states=hidden_states,
+            sampling_metadata=sampling_metadata,
+        )
+        return spec_token_ids
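The index arithmetic in `generate_token_ids` is the least obvious part of the new file, so here is a standalone sketch of it (illustrative values, not part of the patch). Each request contributes `num_draft_tokens + 1` rows to `sample_hidden_states`, and the drafter wants the row of the last accepted token per request:

```python
import torch

# Two requests: request 0 was scheduled with 2 draft tokens, request 1
# with 3, so the flat batch holds (2+1) + (3+1) = 7 sampled positions.
num_draft_tokens = torch.tensor([2, 3])
# Request 0 accepted 1 token, request 1 accepted 3 (bonus token included).
num_accepted_tokens = torch.tensor([1, 3])

# Start offset of each request's slice of the flat batch.
offsets = torch.cumsum(num_draft_tokens + 1, dim=0) - (num_draft_tokens + 1)
# offsets == tensor([0, 3])

# Row of the last accepted token per request: the hidden state the
# Medusa heads condition on for the next proposal round.
indices = offsets + num_accepted_tokens - 1
# indices == tensor([0, 5])

sample_hidden_states = torch.randn(7, 8)       # toy hidden_size of 8
hidden_states = sample_hidden_states[indices]  # shape [2, 8]
```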
diff --git a/vllm_ascend/worker/model_runner_v1.py b/vllm_ascend/worker/model_runner_v1.py
index a0e291bf..cd728982 100644
--- a/vllm_ascend/worker/model_runner_v1.py
+++ b/vllm_ascend/worker/model_runner_v1.py
@@ -100,6 +100,7 @@ from vllm_ascend.patch.worker.patch_module import patch_torch_npu_argsort
 from vllm_ascend.sample.sampler import AscendSampler
 from vllm_ascend.spec_decode import get_spec_decode_method
 from vllm_ascend.spec_decode.eagle_proposer import EagleProposer
+from vllm_ascend.spec_decode.medusa_proposer import MedusaProposer
 from vllm_ascend.spec_decode.mtp_proposer import MtpProposer
 from vllm_ascend.utils import (AscendDeviceType, ProfileExecuteDuration,
                                enable_sp, get_ascend_device_type, is_moe_model,
@@ -363,7 +364,8 @@ class NPUModelRunner(GPUModelRunner):
     def _set_up_drafter(self):
         # Set up speculative decoding.
         self.drafter: Optional[Union[NgramProposer, EagleProposer, MtpProposer,
-                                     SuffixDecodingProposer]] = None
+                                     SuffixDecodingProposer,
+                                     MedusaProposer]] = None
         self.actual_seq_lengths_q: list[int] = []
         self.decode_token_per_req = 1
         if self.speculative_config:
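For context, the `Union` above is populated via the `get_spec_decode_method` factory extended in the first hunk of this patch. A minimal sketch of the hand-off (a hypothetical helper paraphrasing `NPUModelRunner._set_up_drafter`, not the literal runner code; `SpeculativeConfig.method` is the standard vLLM field the factory keys off):

```python
from vllm_ascend.spec_decode import get_spec_decode_method

def make_drafter(runner):
    """Hypothetical helper mirroring NPUModelRunner._set_up_drafter."""
    cfg = runner.vllm_config
    # With this patch, method == "medusa" resolves to the new
    # MedusaProposer; unknown methods still raise ValueError.
    return get_spec_decode_method(cfg.speculative_config.method, cfg,
                                  runner.device, runner)
```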
@@ -1288,6 +1290,7 @@ class NPUModelRunner(GPUModelRunner):
         hidden_states: torch.Tensor,
         attn_metadata: dict[str, Any],
         aux_hidden_states: torch.Tensor = None,
+        sample_hidden_states: torch.Tensor = None,
     ) -> Optional[list[list[int]]]:
         if not self.drafter:
             # Speculative decoding is not enabled.
@@ -1298,7 +1301,10 @@ class NPUModelRunner(GPUModelRunner):
                 valid_sampled_token_ids, sampling_metadata, scheduler_output,
                 spec_decode_metadata, positions, num_scheduled_tokens,
                 hidden_states, aux_hidden_states)
-
+        elif isinstance(self.drafter, MedusaProposer):
+            draft_token_ids = self.drafter.generate_token_ids(
+                valid_sampled_token_ids, sampling_metadata,
+                spec_decode_metadata, sample_hidden_states)
         elif self.speculative_config.use_eagle():
             common_attn_metadata = self.spec_decode_common_attn_metadata
             sampled_token_ids = valid_sampled_token_ids
@@ -1660,6 +1666,7 @@ class NPUModelRunner(GPUModelRunner):
             hidden_states,
             attn_metadata,
             aux_hidden_states,
+            sample_hidden_states,
         )
         self._copy_draft_token_ids_to_cpu(scheduler_output)
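To close the loop, a hedged sketch of exercising the new path from the user side (both model paths are placeholders; the config keys follow vLLM's `SpeculativeConfig`, and `method: "medusa"` is what this patch newly routes on Ascend):

```python
from vllm import LLM, SamplingParams

# Illustrative only: both paths are placeholders.
llm = LLM(
    model="path/to/target-model",
    speculative_config={
        "method": "medusa",               # dispatches to MedusaProposer
        "model": "path/to/medusa-heads",  # draft checkpoint with Medusa heads
        "num_speculative_tokens": 3,      # proposals per decoding step
    },
)
outputs = llm.generate(["Hello, my name is"], SamplingParams(max_tokens=32))
```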