diff --git a/vllm_ascend/sample/sampler.py b/vllm_ascend/sample/sampler.py index 6a5c130..9cceda6 100644 --- a/vllm_ascend/sample/sampler.py +++ b/vllm_ascend/sample/sampler.py @@ -29,7 +29,8 @@ class AscendTopKTopPSampler(TopKTopPSampler): p: torch.Tensor, ) -> torch.Tensor: # npu_top_k_top_p uses the operator aclnnApplyTopKTopP, but aclnnApplyTopKTopP currently does not support 310P - if not is_310p() and p is not None and k is not None: + if not is_310p() and p is not None and k is not None and 1 <= int( + k.max()) <= 1024: # npu_top_k_top_p's parameter order is (logits, p, k), not (logits, k, p) return torch_npu.npu_top_k_top_p(logits, p, k)