diff --git a/vllm_ascend/sample/rejection_sampler.py b/vllm_ascend/sample/rejection_sampler.py
index b7905373..c1ef10db 100644
--- a/vllm_ascend/sample/rejection_sampler.py
+++ b/vllm_ascend/sample/rejection_sampler.py
@@ -3,14 +3,17 @@ from typing import Optional
 
 import torch
 import torch.nn as nn
+import torch_npu
 import vllm.v1.sample.rejection_sampler as rs
 from vllm.triton_utils import HAS_TRITON, tl, triton
 from vllm.v1.sample.metadata import SamplingMetadata
+from vllm.v1.sample.ops.topk_topp_sampler import apply_top_k_top_p
 from vllm.v1.sample.rejection_sampler import (RejectionSampler,
-                                              apply_sampling_constraints,
                                               generate_uniform_probs)
 from vllm.v1.spec_decode.metadata import SpecDecodeMetadata
 
+from vllm_ascend.utils import AscendDeviceType, get_ascend_device_type
+
 PLACEHOLDER_TOKEN_ID = -1
 GREEDY_TEMPERATURE = -1
 # Maximum number of speculative draft tokens allowed per request in a single
@@ -104,6 +107,70 @@ class AscendRejectionSampler(RejectionSampler, nn.Module):
         return output_token_ids
 
 
+def apply_sampling_constraints(
+    logits: torch.Tensor,  # [num_tokens, vocab_size]
+    cu_num_draft_tokens: torch.Tensor,  # [batch_size]
+    sampling_metadata: SamplingMetadata,
+) -> torch.Tensor:
+    """Process logits based on sampling metadata.
+
+    This function applies temperature scaling to the logits,
+    as well as top-k and top-p. For greedy decoding, it returns
+    the original logits.
+
+    Args:
+        logits: Input logits tensor to be processed.
+        cu_num_draft_tokens: Cumulative number of draft tokens.
+        sampling_metadata: Metadata containing sampling parameters such as
+            temperature and whether greedy sampling is used.
+
+    Returns:
+        torch.Tensor: Processed logits if non-greedy sampling is used,
+        otherwise returns the original logits.
+    """
+    assert logits.ndim == 2
+    assert cu_num_draft_tokens.ndim == 1
+    if sampling_metadata.all_greedy:
+        return logits
+
+    num_tokens = logits.shape[0]
+    temperature = expand_batch_to_tokens(
+        sampling_metadata.temperature,
+        cu_num_draft_tokens,
+        num_tokens,
+        replace_from=GREEDY_TEMPERATURE,
+        replace_to=1,
+    )
+    # NOTE(woosuk): Update `logits` in place to avoid allocating a new tensor.
+    logits.div_(temperature.unsqueeze(-1))
+
+    # Get expanded top_k and top_p tensors.
+    top_k = None
+    if sampling_metadata.top_k is not None:
+        top_k = expand_batch_to_tokens(
+            sampling_metadata.top_k,
+            cu_num_draft_tokens,
+            num_tokens,
+        )
+    top_p = None
+    if sampling_metadata.top_p is not None:
+        top_p = expand_batch_to_tokens(
+            sampling_metadata.top_p,
+            cu_num_draft_tokens,
+            num_tokens,
+        )
+
+    if get_ascend_device_type(
+    ) != AscendDeviceType._310P and top_p is not None and top_k is not None and 1 <= int(
+            top_k.max()) <= 1024:
+        return torch_npu.npu_top_k_top_p(logits, top_p.to(torch.bfloat16),
+                                         top_k)
+    else:
+        # NOTE(woosuk): `apply_top_k_top_p` uses sorting to calculate the mask,
+        # which is slow for large vocab sizes. This may cause performance issues.
+        return apply_top_k_top_p(logits, top_k, top_p)
+
+
 def rejection_sample(
     # [num_tokens]
     draft_token_ids: torch.Tensor,
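
Review note: the patch stops importing the upstream `apply_sampling_constraints` and defines a local override so that, on devices other than the 310P, eligible top-k/top-p filtering is dispatched to the fused `torch_npu.npu_top_k_top_p` kernel rather than vLLM's sort-based `apply_top_k_top_p` fallback. Below is a minimal CPU-only sketch of the eligibility gate; the request counts, tensor shapes, and values are illustrative assumptions, not taken from the patch:

```python
import torch

# Hypothetical inputs: 2 requests with 2 and 3 draft tokens (5 tokens total)
# over a toy vocab of 8.
logits = torch.randn(5, 8)
cu_num_draft_tokens = torch.tensor([2, 5])    # cumulative draft-token counts
top_k = torch.tensor([50, 50, 50, 200, 200])  # per-token, already expanded
top_p = torch.tensor([0.9, 0.9, 0.9, 0.8, 0.8])

# Mirror of the patch's gate: the fused kernel is taken only when both
# constraints are present and the largest top_k fits the 1..1024 range
# (and, in the real code, the device is not a 310P).
fused_eligible = (top_p is not None and top_k is not None
                  and 1 <= int(top_k.max()) <= 1024)
print(fused_eligible)
# True  -> the patch would call
#          torch_npu.npu_top_k_top_p(logits, top_p.to(torch.bfloat16), top_k)
# False -> it falls back to apply_top_k_top_p(logits, top_k, top_p)
```

As the carried-over NOTE in the fallback branch says, the sort-based path scales poorly with vocab size, which is why the fused kernel is preferred whenever the top-k bound allows it; note also that `top_p` is cast to bfloat16 before the fused call.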