diff --git a/vllm_ascend/sample/sampler.py b/vllm_ascend/sample/sampler.py index 862bd03..c082f98 100644 --- a/vllm_ascend/sample/sampler.py +++ b/vllm_ascend/sample/sampler.py @@ -3,6 +3,8 @@ import torch_npu from vllm.v1.sample.ops.topk_topp_sampler import TopKTopPSampler, random_sample from vllm.v1.sample.sampler import Sampler +from vllm_ascend.utils import is_310p + class AscendSampler(Sampler): @@ -20,7 +22,8 @@ class AscendTopKTopPSampler(TopKTopPSampler): k: torch.Tensor, p: torch.Tensor, ) -> torch.Tensor: - if p is not None and k is not None: + # npu_top_k_top_p uses the operator aclnnApplyTopKTopP, but aclnnApplyTopKTopP currently does not support 310P + if not is_310p() and p is not None and k is not None: # npu_top_k_top_p's parameter order is (logits, p, k), not (logits, k, p) return torch_npu.npu_top_k_top_p(logits, p, k)