diff --git a/vllm_ascend/sample/sampler.py b/vllm_ascend/sample/sampler.py index 37abdd4..7d015f1 100644 --- a/vllm_ascend/sample/sampler.py +++ b/vllm_ascend/sample/sampler.py @@ -24,14 +24,14 @@ class AscendTopKTopPSampler(TopKTopPSampler): k: torch.Tensor, p: torch.Tensor, ) -> torch.Tensor: - # npu_top_k_top_p uses the operator aclnnApplyTopKTopP, but aclnnApplyTopKTopP currently does not support 310P - if not is_310p() and p is not None and k is not None and 1 <= int( - k.max()) <= 1024: - # npu_top_k_top_p's parameter order is (logits, p, k), not (logits, k, p) - return torch_npu.npu_top_k_top_p(logits, p, k) - if p is None and k is None: return logits + # npu_top_k_top_p uses the operator aclnnApplyTopKTopP, but aclnnApplyTopKTopP currently does not support 310P + if not is_310p(): + # npu_top_k_top_p requires parameter k ranged from 1 to 1024 + if k is None or 1 <= int(k.max()) <= 1024: + # npu_top_k_top_p's parameter order is (logits, p, k), not (logits, k, p) + return torch_npu.npu_top_k_top_p(logits, p, k) probs = logits.softmax(dim=-1) probs_sort, _ = probs.sort(dim=-1, descending=False)