[Misc] Code clean up (#1674)
Remove useless function
- vLLM version: v0.9.2
- vLLM main:
b942c094e3
Signed-off-by: wangxiyuan <wangxiyuan1007@gmail.com>
This commit is contained in:
@@ -12,43 +12,6 @@ from vllm_ascend.attention.mla_v1 import CommonAttentionMetadata
|
||||
from vllm_ascend.models.deepseek_mtp import CustomDeepSeekMTP
|
||||
|
||||
|
||||
# FIXME(woosuk): The logic here is duplicated with the main sampling code.
|
||||
# We should refactor this to reuse the same sampling implementation.
|
||||
def compute_probs_and_sample_next_token(
|
||||
logits: torch.Tensor,
|
||||
sampling_metadata: SamplingMetadata,
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
if sampling_metadata.all_greedy:
|
||||
# For greedy requests, draft_probs is not used in rejection sampling.
|
||||
# Therefore, we can just return the logits.
|
||||
probs = logits
|
||||
next_token_ids = logits.argmax(dim=-1)
|
||||
return next_token_ids, probs
|
||||
|
||||
is_greedy = sampling_metadata.temperature == -1
|
||||
temperature = torch.where(is_greedy, 1.0, sampling_metadata.temperature)
|
||||
logits.div_(temperature.view(-1, 1))
|
||||
probs = logits.softmax(dim=-1, dtype=torch.float32)
|
||||
|
||||
# NOTE(woosuk): Currently, we ignore most of the sampling parameters in
|
||||
# generating the draft tokens. We only use the temperature. While this
|
||||
# could degrade the acceptance rate, it does not affect the distribution
|
||||
# of the generated tokens after rejection sampling.
|
||||
|
||||
# TODO(woosuk): Consider seeds.
|
||||
q = torch.empty_like(probs)
|
||||
q.exponential_()
|
||||
next_token_ids = probs.div_(q).argmax(dim=-1).view(-1)
|
||||
if not sampling_metadata.all_random:
|
||||
greedy_token_ids = probs.argmax(dim=-1)
|
||||
next_token_ids = torch.where(
|
||||
is_greedy,
|
||||
greedy_token_ids,
|
||||
next_token_ids,
|
||||
)
|
||||
return next_token_ids, probs
|
||||
|
||||
|
||||
class MtpProposer:
|
||||
|
||||
def __init__(
|
||||
@@ -121,7 +84,7 @@ class MtpProposer:
|
||||
# [batch_size, max_num_blocks_per_req]
|
||||
block_table: torch.Tensor,
|
||||
sampling_metadata: SamplingMetadata,
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
) -> torch.Tensor:
|
||||
num_tokens = target_token_ids.shape[0]
|
||||
batch_size = next_token_ids.shape[0]
|
||||
last_token_indices = cu_num_tokens[1:] - 1
|
||||
|
||||
Reference in New Issue
Block a user