Files
xc-llm-ascend/vllm_ascend/spec_decode/medusa_proposer.py
SILONG ZENG 4fb3d5e1b2 [Lint]Style: Convert vllm-ascend/ to ruff format(Batch #8) (#6129)
### What this PR does / why we need it?
**Scope of Changes**:
| File Path |
| :--- |
| vllm_ascend/ops/\_\_init\_\_.py |
| vllm_ascend/ops/activation.py |
| vllm_ascend/ops/flashcomm2_oshard_manager.py |
| vllm_ascend/ops/layernorm.py |
| vllm_ascend/ops/mla.py |
| vllm_ascend/ops/mm_encoder_attention.py |
| vllm_ascend/ops/register_custom_ops.py |
| vllm_ascend/ops/vocab_parallel_embedding.py |
| vllm_ascend/ops/weight_prefetch.py |
| vllm_ascend/spec_decode/\_\_init\_\_.py |
| vllm_ascend/spec_decode/eagle_proposer.py |
| vllm_ascend/spec_decode/interface.py |
| vllm_ascend/spec_decode/mtp_proposer.py |
| vllm_ascend/spec_decode/ngram_proposer.py |
| vllm_ascend/spec_decode/suffix_proposer.py |

### Does this PR introduce _any_ user-facing change?

### How was this patch tested?

- vLLM version: v0.13.0
- vLLM main: d68209402d

Signed-off-by: MrZ20 <2609716663@qq.com>
Signed-off-by: SILONG ZENG <2609716663@qq.com>
2026-02-06 15:25:08 +08:00

92 lines
3.2 KiB
Python

import torch
from vllm.config import CUDAGraphMode, VllmConfig
from vllm.logger import init_logger
from vllm.v1.sample.metadata import SamplingMetadata
from vllm.v1.spec_decode.medusa import MedusaProposer as VllmMedusaProposer
from vllm.v1.spec_decode.metadata import SpecDecodeMetadata
from vllm_ascend.ascend_forward_context import set_ascend_forward_context
from vllm_ascend.spec_decode.interface import SpecDcodeType
# Module-level logger named after this module. NOTE(review): it is not
# referenced by the class below; keep it for future diagnostics or remove it.
logger = init_logger(__name__)
class MedusaProposer(VllmMedusaProposer):
    """Ascend variant of vLLM's Medusa speculative-decoding proposer.

    Extends the upstream ``VllmMedusaProposer``; the actual draft-token
    proposal is delegated to the inherited ``propose`` method. This class adds
    an Ascend-specific dummy run and the logic that selects, for each request,
    the target-model hidden state the draft is generated from.

    NOTE(review): ``self.model`` is used by ``dummy_run`` but is not assigned
    in this class; presumably it is set by an inherited loader — confirm.
    """

    def __init__(
        self,
        vllm_config: VllmConfig,
        device: torch.device,
        runner,
    ) -> None:
        """Store the configuration values used by the proposer.

        Args:
            vllm_config: Global vLLM configuration; scheduler, speculative and
                model sub-configs are read from it.
            device: Device on which dummy/index tensors are placed.
            runner: The owning model runner; stored as ``self.runner``.

        NOTE(review): ``super().__init__`` is not called; the attributes are
        assigned directly here, presumably mirroring the upstream initializer.
        Re-check this when upgrading vLLM.
        """
        # Save config parameters
        self.name = SpecDcodeType.MEDUSA
        self.vllm_config = vllm_config
        self.device = device
        self.max_num_tokens = vllm_config.scheduler_config.max_num_batched_tokens
        self.hidden_size = vllm_config.speculative_config.draft_model_config.get_hidden_size()
        self.dtype = vllm_config.model_config.dtype
        self.runner = runner

    @torch.inference_mode()
    def dummy_run(
        self,
        num_tokens: int,
        with_prefill: bool = False,
        in_graph_capturing: bool = False,
        num_reqs: int = 0,
        num_tokens_across_dp: torch.Tensor | None = None,
        aclgraph_runtime_mode: CUDAGraphMode = CUDAGraphMode.NONE,
        batch_descriptor=None,
        dummy_compute_logits=lambda hidden_states: None,
        is_profile: bool = False,
    ) -> None:
        """Run the draft model once on zero-filled hidden states.

        Args:
            num_tokens: Token count reported to the Ascend forward context.
            with_prefill: Unused here.
            in_graph_capturing: Unused here.
            num_reqs: Unused here.
            num_tokens_across_dp: Unused here. NOTE(review): these four are
                presumably kept for signature parity with the other Ascend
                proposers.
            aclgraph_runtime_mode: Forwarded to ``set_ascend_forward_context``.
            batch_descriptor: Forwarded to ``set_ascend_forward_context``.
            dummy_compute_logits: Callable invoked on the dummy hidden states
                after the model forward pass.
            is_profile: Forwarded as ``in_profile_run``.
        """
        # NOTE(review): the buffer always has ``max_num_tokens`` rows, not
        # ``num_tokens`` rows; confirm this is intended for graph capture.
        hidden_states = torch.zeros(
            (self.max_num_tokens, self.hidden_size),
            dtype=self.dtype,
            device=self.device,
        )
        with set_ascend_forward_context(
            None,
            self.vllm_config,
            num_tokens=num_tokens,
            num_actual_tokens=0,
            in_profile_run=is_profile,
            batch_descriptor=batch_descriptor,
            aclgraph_runtime_mode=aclgraph_runtime_mode,
            is_draft_model=True,
        ):
            self.model(hidden_states)
            dummy_compute_logits(hidden_states)

    def generate_token_ids(
        self,
        valid_sampled_token_ids: list[list[int]],
        sampling_metadata: SamplingMetadata,
        spec_decode_metadata: SpecDecodeMetadata,
        sample_hidden_states: torch.Tensor,
        *args,
        **kwargs,
    ):
        """Propose draft tokens from the target model's sampled hidden states.

        Args:
            valid_sampled_token_ids: Per-request lists of token ids accepted in
                this step; only their lengths are used here.
            sampling_metadata: Forwarded unchanged to ``propose``.
            spec_decode_metadata: Supplies ``num_draft_tokens``, the per-request
                draft-token counts. It is only read when
                ``sample_hidden_states`` has a different number of rows than
                there are requests.
            sample_hidden_states: Target-model hidden states. It has either one
                row per request, or ``num_draft_tokens[i] + 1`` consecutive
                rows for request ``i``.
            *args: Ignored.
            **kwargs: Ignored. NOTE(review): both presumably exist for
                call-compatibility with the other proposers.

        Returns:
            The result of the inherited ``propose`` for the selected hidden
            states.

        Raises:
            ValueError: The hidden states contain draft rows but
                ``spec_decode_metadata`` is ``None``.
        """
        if sample_hidden_states.shape[0] == len(valid_sampled_token_ids):
            # The input to the target model does not include draft tokens.
            hidden_states = sample_hidden_states
        else:
            if spec_decode_metadata is None:
                raise ValueError(
                    "Medusa requires spec_decode_metadata when the target model input contains draft tokens."
                )
            # Request i owns num_draft_tokens[i] + 1 consecutive rows; pick the
            # row of its last accepted token. The index arithmetic runs on host
            # tensors so only the final indices are copied to the device
            # (one host-to-device transfer instead of two).
            # NOTE(review): a request with no accepted tokens yields offset - 1
            # (the previous request's row, or -1 -> last row), as before;
            # presumably its draft is discarded by the caller — confirm.
            num_accepted_tokens = torch.tensor([len(t) for t in valid_sampled_token_ids], dtype=torch.long)
            rows_per_request = torch.tensor(spec_decode_metadata.num_draft_tokens, dtype=torch.long) + 1
            # Exclusive prefix sum: index of the first row owned by each request.
            offsets = torch.cumsum(rows_per_request, dim=0) - rows_per_request
            indices = (offsets + num_accepted_tokens - 1).to(self.device)
            hidden_states = sample_hidden_states[indices]

        spec_token_ids = self.propose(
            target_hidden_states=hidden_states,
            sampling_metadata=sampling_metadata,
        )
        return spec_token_ids