[Misc] Code clean up (#1674)
Remove useless function
- vLLM version: v0.9.2
- vLLM main:
b942c094e3
Signed-off-by: wangxiyuan <wangxiyuan1007@gmail.com>
This commit is contained in:
@@ -384,46 +384,3 @@ def prepare_eagle_input_sequential(out_tensor: torch.Tensor,
|
||||
(target_indices < end_pos) & \
|
||||
(offset_tensor < num_tokens)
|
||||
out_tensor[target_indices[mask]] = values_to_store[mask]
|
||||
|
||||
|
||||
# NOTE(woosuk): Currently, the below code is not used and we always use argmax
|
||||
# to sample the draft tokens. We will use this after we find a way to manage
|
||||
# the draft prob tensor.
|
||||
# Refer to https://github.com/vllm-project/vllm/pull/16899 for the details.
|
||||
# FIXME(woosuk): The logic here is duplicated with the main sampling code.
|
||||
# We should refactor this to reuse the same sampling implementation.
|
||||
def compute_probs_and_sample_next_token(
|
||||
logits: torch.Tensor,
|
||||
sampling_metadata: SamplingMetadata,
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
if sampling_metadata.all_greedy:
|
||||
# For greedy requests, draft_probs is not used in rejection sampling.
|
||||
# Therefore, we can just return the logits.
|
||||
probs = logits
|
||||
next_token_ids = logits.argmax(dim=-1)
|
||||
return next_token_ids, probs
|
||||
|
||||
is_greedy = sampling_metadata.temperature == -1
|
||||
temperature = torch.where(is_greedy, 1.0, sampling_metadata.temperature)
|
||||
logits.div_(temperature.view(-1, 1))
|
||||
probs = logits.softmax(dim=-1, dtype=torch.float32)
|
||||
|
||||
# NOTE(woosuk): Currently, we ignore most of the sampling parameters in
|
||||
# generating the draft tokens. We only use the temperature. While this
|
||||
# could degrade the acceptance rate, it does not affect the distribution
|
||||
# of the generated tokens after rejection sampling.
|
||||
|
||||
# TODO(woosuk): Consider seeds.
|
||||
q = torch.empty_like(probs)
|
||||
q.exponential_()
|
||||
# NOTE(woosuk): We shouldn't use `probs.div_(q)` because the draft_probs
|
||||
# will be used later for rejection sampling.
|
||||
next_token_ids = probs.div(q).argmax(dim=-1).view(-1)
|
||||
if not sampling_metadata.all_random:
|
||||
greedy_token_ids = probs.argmax(dim=-1)
|
||||
next_token_ids = torch.where(
|
||||
is_greedy,
|
||||
greedy_token_ids,
|
||||
next_token_ids,
|
||||
)
|
||||
return next_token_ids, probs
|
||||
|
||||
@@ -12,43 +12,6 @@ from vllm_ascend.attention.mla_v1 import CommonAttentionMetadata
|
||||
from vllm_ascend.models.deepseek_mtp import CustomDeepSeekMTP
|
||||
|
||||
|
||||
# FIXME(woosuk): The logic here is duplicated with the main sampling code.
|
||||
# We should refactor this to reuse the same sampling implementation.
|
||||
def compute_probs_and_sample_next_token(
|
||||
logits: torch.Tensor,
|
||||
sampling_metadata: SamplingMetadata,
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
if sampling_metadata.all_greedy:
|
||||
# For greedy requests, draft_probs is not used in rejection sampling.
|
||||
# Therefore, we can just return the logits.
|
||||
probs = logits
|
||||
next_token_ids = logits.argmax(dim=-1)
|
||||
return next_token_ids, probs
|
||||
|
||||
is_greedy = sampling_metadata.temperature == -1
|
||||
temperature = torch.where(is_greedy, 1.0, sampling_metadata.temperature)
|
||||
logits.div_(temperature.view(-1, 1))
|
||||
probs = logits.softmax(dim=-1, dtype=torch.float32)
|
||||
|
||||
# NOTE(woosuk): Currently, we ignore most of the sampling parameters in
|
||||
# generating the draft tokens. We only use the temperature. While this
|
||||
# could degrade the acceptance rate, it does not affect the distribution
|
||||
# of the generated tokens after rejection sampling.
|
||||
|
||||
# TODO(woosuk): Consider seeds.
|
||||
q = torch.empty_like(probs)
|
||||
q.exponential_()
|
||||
next_token_ids = probs.div_(q).argmax(dim=-1).view(-1)
|
||||
if not sampling_metadata.all_random:
|
||||
greedy_token_ids = probs.argmax(dim=-1)
|
||||
next_token_ids = torch.where(
|
||||
is_greedy,
|
||||
greedy_token_ids,
|
||||
next_token_ids,
|
||||
)
|
||||
return next_token_ids, probs
|
||||
|
||||
|
||||
class MtpProposer:
|
||||
|
||||
def __init__(
|
||||
@@ -121,7 +84,7 @@ class MtpProposer:
|
||||
# [batch_size, max_num_blocks_per_req]
|
||||
block_table: torch.Tensor,
|
||||
sampling_metadata: SamplingMetadata,
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
) -> torch.Tensor:
|
||||
num_tokens = target_token_ids.shape[0]
|
||||
batch_size = next_token_ids.shape[0]
|
||||
last_token_indices = cu_num_tokens[1:] - 1
|
||||
|
||||
Reference in New Issue
Block a user