[CI] Fix sample backward compatibility problem (#648)

b411418ff0
this vllm commit change the sample usage. This PR adapt the change for
main and make sure it works for 0.8.4 as well.

Signed-off-by: wangxiyuan <wangxiyuan1007@gmail.com>
This commit is contained in:
wangxiyuan
2025-04-25 11:53:26 +08:00
committed by GitHub
parent d785e78563
commit 3879d9cad9
3 changed files with 47 additions and 12 deletions

View File

@@ -53,6 +53,7 @@ from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch
from vllm_ascend.attention.attention import AttentionMaskBuilder
from vllm_ascend.attention.attention_v1 import AscendAttentionState
from vllm_ascend.platform import NPUPlatform
from vllm_ascend.utils import vllm_version_is
if TYPE_CHECKING:
import xgrammar as xgr # type: ignore[import-untyped]
@@ -290,6 +291,12 @@ class NPUModelRunner:
self.attn_mask_builder = AttentionMaskBuilder.initialize_from_len(
self.attn_mask_len, self.dtype)
if vllm_version_is("0.8.4"):
self.sampler = None
else:
from vllm.v1.sample.sampler import Sampler
self.sampler = Sampler()
def _update_states(self, scheduler_output: "SchedulerOutput") -> None:
"""Update the cached states and the persistent batch with the scheduler
output.
@@ -645,10 +652,17 @@ class NPUModelRunner:
# Sample the next token and get logprobs if needed.
sampling_metadata = self.input_batch.sampling_metadata
sampler_output = self.model.sample(
logits=logits,
sampling_metadata=sampling_metadata,
)
if vllm_version_is("0.8.4"):
sampler_output = self.model.sample(
logits=logits,
sampling_metadata=sampling_metadata,
)
else:
assert self.sampler is not None
sampler_output = self.sampler(
logits=logits,
sampling_metadata=sampling_metadata,
)
# TODO(woosuk): The following loop can be slow since it iterates over
# the requests one by one. Optimize.