diff --git a/vllm_kunlun/v1/sample/ops/topk_topp_sampler.py b/vllm_kunlun/v1/sample/ops/topk_topp_sampler.py index b080524..4aa5012 100644 --- a/vllm_kunlun/v1/sample/ops/topk_topp_sampler.py +++ b/vllm_kunlun/v1/sample/ops/topk_topp_sampler.py @@ -11,6 +11,7 @@ from vllm import envs from vllm.logger import init_logger from vllm.platforms import current_platform import xtorch_ops +import os logger = init_logger(__name__) @@ -28,6 +29,7 @@ class TopKTopPSampler(nn.Module): logger.info_once( "Using FlashInfer for top-p & top-k sampling.") self.forward = self.forward_kunlun + self.apply_top_k_top_p = apply_top_k_top_p def forward_native( self,