[CI] Fix sample backward compatibility problem (#648)
b411418ff0
this vllm commit change the sample usage. This PR adapt the change for
main and make sure it works for 0.8.4 as well.
Signed-off-by: wangxiyuan <wangxiyuan1007@gmail.com>
This commit is contained in:
@@ -28,6 +28,7 @@ from vllm.worker.model_runner_base import (ModelRunnerBase,
|
|||||||
ModelRunnerWrapperBase)
|
ModelRunnerWrapperBase)
|
||||||
|
|
||||||
from vllm_ascend.attention.attention import AscendMetadata
|
from vllm_ascend.attention.attention import AscendMetadata
|
||||||
|
from vllm_ascend.utils import vllm_version_is
|
||||||
|
|
||||||
# A flag to enable debug prints for the updated input tensors
|
# A flag to enable debug prints for the updated input tensors
|
||||||
# before each step.
|
# before each step.
|
||||||
@@ -286,10 +287,17 @@ class TP1DraftModelRunner(ModelRunnerWrapperBase):
|
|||||||
if not self.is_driver_worker:
|
if not self.is_driver_worker:
|
||||||
return []
|
return []
|
||||||
# Sample the next token.
|
# Sample the next token.
|
||||||
output = self.model.sample(
|
if vllm_version_is("0.8.4"):
|
||||||
logits=logits,
|
output = self.model.sample(
|
||||||
sampling_metadata=model_input.sampling_metadata,
|
logits=logits,
|
||||||
)
|
sampling_metadata=model_input.sampling_metadata,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
assert self.sampler is not None
|
||||||
|
output = self.sampler(
|
||||||
|
logits=logits,
|
||||||
|
sampling_metadata=model_input.sampling_metadata,
|
||||||
|
)
|
||||||
outputs.append(output)
|
outputs.append(output)
|
||||||
|
|
||||||
if model_input.attn_metadata.num_prefills == 0 \
|
if model_input.attn_metadata.num_prefills == 0 \
|
||||||
|
|||||||
@@ -937,6 +937,12 @@ class NPUModelRunnerBase(ModelRunnerBase[TModelInputForNPU]):
|
|||||||
SamplingMetadataCache() \
|
SamplingMetadataCache() \
|
||||||
if self.parallel_config.pipeline_parallel_size == 1 else None
|
if self.parallel_config.pipeline_parallel_size == 1 else None
|
||||||
|
|
||||||
|
if vllm_version_is("0.8.4"):
|
||||||
|
self.sampler = None
|
||||||
|
else:
|
||||||
|
from vllm.model_executor.layers.sampler import get_sampler
|
||||||
|
self.sampler = get_sampler()
|
||||||
|
|
||||||
def get_model(self) -> nn.Module:
|
def get_model(self) -> nn.Module:
|
||||||
return self.model
|
return self.model
|
||||||
|
|
||||||
@@ -1404,10 +1410,17 @@ class NPUModelRunner(NPUModelRunnerBase[ModelInputForNPUWithSamplingMetadata]):
|
|||||||
model_input.async_callback()
|
model_input.async_callback()
|
||||||
|
|
||||||
# Sample the next token.
|
# Sample the next token.
|
||||||
output: SamplerOutput = self.model.sample(
|
if vllm_version_is("0.8.4"):
|
||||||
logits=logits,
|
output = self.model.sample(
|
||||||
sampling_metadata=model_input.sampling_metadata,
|
logits=logits,
|
||||||
)
|
sampling_metadata=model_input.sampling_metadata,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
assert self.sampler is not None
|
||||||
|
output = self.sampler(
|
||||||
|
logits=logits,
|
||||||
|
sampling_metadata=model_input.sampling_metadata,
|
||||||
|
)
|
||||||
if (self.observability_config is not None
|
if (self.observability_config is not None
|
||||||
and self.observability_config.collect_model_forward_time
|
and self.observability_config.collect_model_forward_time
|
||||||
and output is not None):
|
and output is not None):
|
||||||
|
|||||||
@@ -53,6 +53,7 @@ from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch
|
|||||||
from vllm_ascend.attention.attention import AttentionMaskBuilder
|
from vllm_ascend.attention.attention import AttentionMaskBuilder
|
||||||
from vllm_ascend.attention.attention_v1 import AscendAttentionState
|
from vllm_ascend.attention.attention_v1 import AscendAttentionState
|
||||||
from vllm_ascend.platform import NPUPlatform
|
from vllm_ascend.platform import NPUPlatform
|
||||||
|
from vllm_ascend.utils import vllm_version_is
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
import xgrammar as xgr # type: ignore[import-untyped]
|
import xgrammar as xgr # type: ignore[import-untyped]
|
||||||
@@ -290,6 +291,12 @@ class NPUModelRunner:
|
|||||||
self.attn_mask_builder = AttentionMaskBuilder.initialize_from_len(
|
self.attn_mask_builder = AttentionMaskBuilder.initialize_from_len(
|
||||||
self.attn_mask_len, self.dtype)
|
self.attn_mask_len, self.dtype)
|
||||||
|
|
||||||
|
if vllm_version_is("0.8.4"):
|
||||||
|
self.sampler = None
|
||||||
|
else:
|
||||||
|
from vllm.v1.sample.sampler import Sampler
|
||||||
|
self.sampler = Sampler()
|
||||||
|
|
||||||
def _update_states(self, scheduler_output: "SchedulerOutput") -> None:
|
def _update_states(self, scheduler_output: "SchedulerOutput") -> None:
|
||||||
"""Update the cached states and the persistent batch with the scheduler
|
"""Update the cached states and the persistent batch with the scheduler
|
||||||
output.
|
output.
|
||||||
@@ -645,10 +652,17 @@ class NPUModelRunner:
|
|||||||
|
|
||||||
# Sample the next token and get logprobs if needed.
|
# Sample the next token and get logprobs if needed.
|
||||||
sampling_metadata = self.input_batch.sampling_metadata
|
sampling_metadata = self.input_batch.sampling_metadata
|
||||||
sampler_output = self.model.sample(
|
if vllm_version_is("0.8.4"):
|
||||||
logits=logits,
|
sampler_output = self.model.sample(
|
||||||
sampling_metadata=sampling_metadata,
|
logits=logits,
|
||||||
)
|
sampling_metadata=sampling_metadata,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
assert self.sampler is not None
|
||||||
|
sampler_output = self.sampler(
|
||||||
|
logits=logits,
|
||||||
|
sampling_metadata=sampling_metadata,
|
||||||
|
)
|
||||||
|
|
||||||
# TODO(woosuk): The following loop can be slow since it iterates over
|
# TODO(woosuk): The following loop can be slow since it iterates over
|
||||||
# the requests one by one. Optimize.
|
# the requests one by one. Optimize.
|
||||||
|
|||||||
Reference in New Issue
Block a user