diff --git a/vllm_kunlun/v1/sample/ops/topk_topp_sampler.py b/vllm_kunlun/v1/sample/ops/topk_topp_sampler.py index db5fa2f..08a33f9 100644 --- a/vllm_kunlun/v1/sample/ops/topk_topp_sampler.py +++ b/vllm_kunlun/v1/sample/ops/topk_topp_sampler.py @@ -42,7 +42,7 @@ class TopKTopPSampler(nn.Module): """ logits = apply_top_k_top_p(logits, k, p) probs = logits.softmax(dim=-1, dtype=torch.float32) - return random_sample(probs, generators) + return random_sample(probs, generators), None def forward_kunlun( self, @@ -199,4 +199,4 @@ def flashinfer_sample( next_token_ids = xtorch_ops.top_k_top_p_sampling_from_probs( probs, top_k=k, top_p=p, deterministic=True) - return next_token_ids.view(-1) \ No newline at end of file + return next_token_ids.view(-1)