[CI] Fix sample backward compatibility problem (#648)
b411418ff0
This vLLM commit changed the sample usage. This PR adapts to that change on
main and makes sure it still works with 0.8.4 as well.
Signed-off-by: wangxiyuan <wangxiyuan1007@gmail.com>
This commit is contained in:
@@ -937,6 +937,12 @@ class NPUModelRunnerBase(ModelRunnerBase[TModelInputForNPU]):
|
||||
SamplingMetadataCache() \
|
||||
if self.parallel_config.pipeline_parallel_size == 1 else None
|
||||
|
||||
if vllm_version_is("0.8.4"):
|
||||
self.sampler = None
|
||||
else:
|
||||
from vllm.model_executor.layers.sampler import get_sampler
|
||||
self.sampler = get_sampler()
|
||||
|
||||
def get_model(self) -> nn.Module:
|
||||
return self.model
|
||||
|
||||
@@ -1404,10 +1410,17 @@ class NPUModelRunner(NPUModelRunnerBase[ModelInputForNPUWithSamplingMetadata]):
|
||||
model_input.async_callback()
|
||||
|
||||
# Sample the next token.
|
||||
output: SamplerOutput = self.model.sample(
|
||||
logits=logits,
|
||||
sampling_metadata=model_input.sampling_metadata,
|
||||
)
|
||||
if vllm_version_is("0.8.4"):
|
||||
output = self.model.sample(
|
||||
logits=logits,
|
||||
sampling_metadata=model_input.sampling_metadata,
|
||||
)
|
||||
else:
|
||||
assert self.sampler is not None
|
||||
output = self.sampler(
|
||||
logits=logits,
|
||||
sampling_metadata=model_input.sampling_metadata,
|
||||
)
|
||||
if (self.observability_config is not None
|
||||
and self.observability_config.collect_model_forward_time
|
||||
and output is not None):
|
||||
|
||||
Reference in New Issue
Block a user