[CI] Fix sample backward compatibility problem (#648)

b411418ff0
This vLLM commit changed the sample usage. This PR adapts to that change for
main and makes sure it works with 0.8.4 as well.

Signed-off-by: wangxiyuan <wangxiyuan1007@gmail.com>
This commit is contained in:
wangxiyuan
2025-04-25 11:53:26 +08:00
committed by GitHub
parent d785e78563
commit 3879d9cad9
3 changed files with 47 additions and 12 deletions

View File

@@ -28,6 +28,7 @@ from vllm.worker.model_runner_base import (ModelRunnerBase,
ModelRunnerWrapperBase)
from vllm_ascend.attention.attention import AscendMetadata
from vllm_ascend.utils import vllm_version_is
# A flag to enable debug prints for the updated input tensors
# before each step.
@@ -286,10 +287,17 @@ class TP1DraftModelRunner(ModelRunnerWrapperBase):
if not self.is_driver_worker:
return []
# Sample the next token.
output = self.model.sample(
logits=logits,
sampling_metadata=model_input.sampling_metadata,
)
if vllm_version_is("0.8.4"):
output = self.model.sample(
logits=logits,
sampling_metadata=model_input.sampling_metadata,
)
else:
assert self.sampler is not None
output = self.sampler(
logits=logits,
sampling_metadata=model_input.sampling_metadata,
)
outputs.append(output)
if model_input.attn_metadata.num_prefills == 0 \