From 3879d9cad95c14e3cce8fc053540e369a39cd341 Mon Sep 17 00:00:00 2001
From: wangxiyuan
Date: Fri, 25 Apr 2025 11:53:26 +0800
Subject: [PATCH] [CI] Fix sample backward compatibility problem (#648)

https://github.com/vllm-project/vllm/commit/b411418ff090a168c85eab243b14b7350bf73db4
changed how the sampler is invoked in vLLM. This PR adapts vllm-ascend to
that change on main and makes sure it still works with vLLM 0.8.4 as well.

Signed-off-by: wangxiyuan
---
 vllm_ascend/worker/draft_model_runner.py | 16 ++++++++++++----
 vllm_ascend/worker/model_runner.py       | 21 +++++++++++++++++----
 vllm_ascend/worker/model_runner_v1.py    | 22 ++++++++++++++++++----
 3 files changed, 47 insertions(+), 12 deletions(-)

diff --git a/vllm_ascend/worker/draft_model_runner.py b/vllm_ascend/worker/draft_model_runner.py
index 162c1ee..504d94e 100644
--- a/vllm_ascend/worker/draft_model_runner.py
+++ b/vllm_ascend/worker/draft_model_runner.py
@@ -28,6 +28,7 @@ from vllm.worker.model_runner_base import (ModelRunnerBase,
                                            ModelRunnerWrapperBase)
 
 from vllm_ascend.attention.attention import AscendMetadata
+from vllm_ascend.utils import vllm_version_is
 
 # A flag to enable debug prints for the updated input tensors
 # before each step.
@@ -286,10 +287,17 @@ class TP1DraftModelRunner(ModelRunnerWrapperBase):
             if not self.is_driver_worker:
                 return []
             # Sample the next token.
-            output = self.model.sample(
-                logits=logits,
-                sampling_metadata=model_input.sampling_metadata,
-            )
+            if vllm_version_is("0.8.4"):
+                output = self.model.sample(
+                    logits=logits,
+                    sampling_metadata=model_input.sampling_metadata,
+                )
+            else:
+                assert self.sampler is not None
+                output = self.sampler(
+                    logits=logits,
+                    sampling_metadata=model_input.sampling_metadata,
+                )
             outputs.append(output)
 
             if model_input.attn_metadata.num_prefills == 0 \
diff --git a/vllm_ascend/worker/model_runner.py b/vllm_ascend/worker/model_runner.py
index bfcdc14..f1425d4 100644
--- a/vllm_ascend/worker/model_runner.py
+++ b/vllm_ascend/worker/model_runner.py
@@ -937,6 +937,12 @@ class NPUModelRunnerBase(ModelRunnerBase[TModelInputForNPU]):
             SamplingMetadataCache() \
             if self.parallel_config.pipeline_parallel_size == 1 else None
 
+        if vllm_version_is("0.8.4"):
+            self.sampler = None
+        else:
+            from vllm.model_executor.layers.sampler import get_sampler
+            self.sampler = get_sampler()
+
     def get_model(self) -> nn.Module:
         return self.model
 
@@ -1404,10 +1410,17 @@ class NPUModelRunner(NPUModelRunnerBase[ModelInputForNPUWithSamplingMetadata]):
             model_input.async_callback()
 
         # Sample the next token.
-        output: SamplerOutput = self.model.sample(
-            logits=logits,
-            sampling_metadata=model_input.sampling_metadata,
-        )
+        if vllm_version_is("0.8.4"):
+            output = self.model.sample(
+                logits=logits,
+                sampling_metadata=model_input.sampling_metadata,
+            )
+        else:
+            assert self.sampler is not None
+            output = self.sampler(
+                logits=logits,
+                sampling_metadata=model_input.sampling_metadata,
+            )
         if (self.observability_config is not None
                 and self.observability_config.collect_model_forward_time
                 and output is not None):
diff --git a/vllm_ascend/worker/model_runner_v1.py b/vllm_ascend/worker/model_runner_v1.py
index ca157ab..5a9f78f 100644
--- a/vllm_ascend/worker/model_runner_v1.py
+++ b/vllm_ascend/worker/model_runner_v1.py
@@ -53,6 +53,7 @@ from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch
 from vllm_ascend.attention.attention import AttentionMaskBuilder
 from vllm_ascend.attention.attention_v1 import AscendAttentionState
 from vllm_ascend.platform import NPUPlatform
+from vllm_ascend.utils import vllm_version_is
 
 if TYPE_CHECKING:
     import xgrammar as xgr  # type: ignore[import-untyped]
@@ -290,6 +291,12 @@ class NPUModelRunner:
         self.attn_mask_builder = AttentionMaskBuilder.initialize_from_len(
             self.attn_mask_len, self.dtype)
 
+        if vllm_version_is("0.8.4"):
+            self.sampler = None
+        else:
+            from vllm.v1.sample.sampler import Sampler
+            self.sampler = Sampler()
+
     def _update_states(self, scheduler_output: "SchedulerOutput") -> None:
         """Update the cached states and the persistent batch with the scheduler
         output.
@@ -645,10 +652,17 @@ class NPUModelRunner:
 
         # Sample the next token and get logprobs if needed.
         sampling_metadata = self.input_batch.sampling_metadata
-        sampler_output = self.model.sample(
-            logits=logits,
-            sampling_metadata=sampling_metadata,
-        )
+        if vllm_version_is("0.8.4"):
+            sampler_output = self.model.sample(
+                logits=logits,
+                sampling_metadata=sampling_metadata,
+            )
+        else:
+            assert self.sampler is not None
+            sampler_output = self.sampler(
+                logits=logits,
+                sampling_metadata=sampling_metadata,
+            )
 
         # TODO(woosuk): The following loop can be slow since it iterates over
         # the requests one by one. Optimize.
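
--
A minimal, self-contained sketch of the version-gating pattern the three
runners above now share, written against a V0-style runner. The helper
names `make_sampler` and `sample_tokens` are hypothetical; `vllm_version_is`
and the sampler imports are the ones the patch itself uses:

    from vllm_ascend.utils import vllm_version_is


    def make_sampler():
        # On vLLM 0.8.4 the model still owns the sampler, so there is
        # nothing to construct; on newer vLLM the runner holds its own.
        if vllm_version_is("0.8.4"):
            return None
        from vllm.model_executor.layers.sampler import get_sampler
        return get_sampler()


    def sample_tokens(runner, logits, sampling_metadata):
        # Dispatch to whichever sampling entry point this vLLM version
        # exposes.
        if vllm_version_is("0.8.4"):
            # Before vllm-project/vllm@b411418, sampling goes through
            # the model itself.
            return runner.model.sample(
                logits=logits, sampling_metadata=sampling_metadata)
        # Afterwards, the runner calls the sampler built in its __init__.
        assert runner.sampler is not None
        return runner.sampler(
            logits=logits, sampling_metadata=sampling_metadata)

The V1 runner applies the same split, constructing Sampler() from
vllm.v1.sample.sampler in place of get_sampler().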