[Misc] Code clean up (#1674)

Remove useless function
- vLLM version: v0.9.2
- vLLM main:
b942c094e3

Signed-off-by: wangxiyuan <wangxiyuan1007@gmail.com>
This commit is contained in:
wangxiyuan
2025-07-09 08:54:12 +08:00
committed by GitHub
parent 830332ebfc
commit cc1588be50
2 changed files with 1 additions and 81 deletions

View File

@@ -12,43 +12,6 @@ from vllm_ascend.attention.mla_v1 import CommonAttentionMetadata
from vllm_ascend.models.deepseek_mtp import CustomDeepSeekMTP
# FIXME(woosuk): The logic here is duplicated with the main sampling code.
# We should refactor this to reuse the same sampling implementation.
def compute_probs_and_sample_next_token(
logits: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> tuple[torch.Tensor, torch.Tensor]:
if sampling_metadata.all_greedy:
# For greedy requests, draft_probs is not used in rejection sampling.
# Therefore, we can just return the logits.
probs = logits
next_token_ids = logits.argmax(dim=-1)
return next_token_ids, probs
is_greedy = sampling_metadata.temperature == -1
temperature = torch.where(is_greedy, 1.0, sampling_metadata.temperature)
logits.div_(temperature.view(-1, 1))
probs = logits.softmax(dim=-1, dtype=torch.float32)
# NOTE(woosuk): Currently, we ignore most of the sampling parameters in
# generating the draft tokens. We only use the temperature. While this
# could degrade the acceptance rate, it does not affect the distribution
# of the generated tokens after rejection sampling.
# TODO(woosuk): Consider seeds.
q = torch.empty_like(probs)
q.exponential_()
next_token_ids = probs.div_(q).argmax(dim=-1).view(-1)
if not sampling_metadata.all_random:
greedy_token_ids = probs.argmax(dim=-1)
next_token_ids = torch.where(
is_greedy,
greedy_token_ids,
next_token_ids,
)
return next_token_ids, probs
class MtpProposer:
def __init__(
@@ -121,7 +84,7 @@ class MtpProposer:
# [batch_size, max_num_blocks_per_req]
block_table: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> tuple[torch.Tensor, torch.Tensor]:
) -> torch.Tensor:
num_tokens = target_token_ids.shape[0]
batch_size = next_token_ids.shape[0]
last_token_indices = cu_num_tokens[1:] - 1