Clean up v0.9.1 code (#1672)
vllm has released 0.9.2. This PR drop 0.9.1 support.
- vLLM version: v0.9.1
- vLLM main:
b942c094e3
Signed-off-by: wangxiyuan <wangxiyuan1007@gmail.com>
This commit is contained in:
@@ -44,6 +44,7 @@ from vllm.logger import logger
|
||||
from vllm.model_executor.layers.fused_moe import FusedMoE
|
||||
from vllm.model_executor.layers.rotary_embedding import MRotaryEmbedding
|
||||
from vllm.model_executor.model_loader import get_model
|
||||
from vllm.model_executor.models.interfaces import has_step_pooler
|
||||
from vllm.multimodal import MULTIMODAL_REGISTRY
|
||||
from vllm.multimodal.inputs import MultiModalKwargs, PlaceholderRange
|
||||
from vllm.multimodal.utils import group_mm_inputs_by_modality
|
||||
@@ -79,7 +80,7 @@ from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_ND, ACL_FORMAT_FRACTAL_NZ,
|
||||
ProfileExecuteDuration,
|
||||
check_torchair_cache_exist, is_310p,
|
||||
maybe_converting_weight_acl_format,
|
||||
vllm_version_is, write_kv_cache_bytes_to_file)
|
||||
write_kv_cache_bytes_to_file)
|
||||
from vllm_ascend.worker.eagle_proposer_v1 import EagleProposer
|
||||
from vllm_ascend.worker.mtp_proposer_v1 import MtpProposer
|
||||
from vllm_ascend.worker.npu_input_batch import CachedRequestState, InputBatch
|
||||
@@ -95,9 +96,6 @@ import vllm.envs as envs_vllm
|
||||
|
||||
import vllm_ascend.envs as envs_ascend
|
||||
|
||||
if vllm_version_is("0.9.1"):
|
||||
from vllm.v1.spec_decode.utils import is_spec_decode_supported
|
||||
|
||||
if is_310p():
|
||||
torch_npu.npu.set_compile_mode(jit_compile=False)
|
||||
|
||||
@@ -408,16 +406,13 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
||||
else:
|
||||
generator = None
|
||||
|
||||
# For vllm v0.9.1 version compatibility, we check if
|
||||
# `pooling_params` is present in the new request data.
|
||||
pooling_params = getattr(new_req_data, "pooling_params", None)
|
||||
self.requests[req_id] = CachedRequestState(
|
||||
req_id=req_id,
|
||||
prompt_token_ids=new_req_data.prompt_token_ids,
|
||||
mm_inputs=new_req_data.mm_inputs,
|
||||
mm_positions=new_req_data.mm_positions,
|
||||
sampling_params=sampling_params,
|
||||
pooling_params=pooling_params,
|
||||
pooling_params=new_req_data.pooling_params,
|
||||
generator=generator,
|
||||
block_ids=new_req_data.block_ids,
|
||||
num_computed_tokens=new_req_data.num_computed_tokens,
|
||||
@@ -465,62 +460,59 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
||||
req_ids_to_add.append(req_id)
|
||||
|
||||
# Update the states of the running/resumed requests.
|
||||
if vllm_version_is("0.9.1"):
|
||||
for req_data in scheduler_output.scheduled_cached_reqs:
|
||||
req_id = req_data.req_id
|
||||
req_state = self.requests[req_id]
|
||||
req_data = scheduler_output.scheduled_cached_reqs
|
||||
is_last_rank = get_pp_group().is_last_rank
|
||||
for i, req_id in enumerate(req_data.req_ids):
|
||||
req_state = self.requests[req_id]
|
||||
num_computed_tokens = req_data.num_computed_tokens[i]
|
||||
new_block_ids = req_data.new_block_ids[i]
|
||||
resumed_from_preemption = req_data.resumed_from_preemption[i]
|
||||
|
||||
# Update the cached states.
|
||||
num_computed_tokens = req_data.num_computed_tokens
|
||||
req_state.num_computed_tokens = num_computed_tokens
|
||||
req_state.num_computed_tokens = num_computed_tokens
|
||||
if not is_last_rank:
|
||||
new_token_ids = req_data.new_token_ids[i]
|
||||
# Add the sampled token(s) from the previous step (if any).
|
||||
# This doesn't include "unverified" tokens like spec decode tokens.
|
||||
num_new_tokens = (num_computed_tokens +
|
||||
len(req_data.new_token_ids) -
|
||||
num_new_tokens = (num_computed_tokens + len(new_token_ids) -
|
||||
req_state.num_tokens)
|
||||
if num_new_tokens == 1:
|
||||
# Avoid slicing list in most common case.
|
||||
req_state.output_token_ids.append(
|
||||
req_data.new_token_ids[-1])
|
||||
req_state.output_token_ids.append(new_token_ids[-1])
|
||||
elif num_new_tokens > 0:
|
||||
req_state.output_token_ids.extend(
|
||||
req_data.new_token_ids[-num_new_tokens:])
|
||||
# Update the block IDs.
|
||||
if not req_data.resumed_from_preemption:
|
||||
# Append the new blocks to the existing block IDs.
|
||||
for block_ids, new_block_ids in zip( # type: ignore[call-overload]
|
||||
req_state.block_ids,
|
||||
req_data.new_block_ids,
|
||||
strict=True):
|
||||
block_ids.extend(new_block_ids)
|
||||
else:
|
||||
# The request is resumed from preemption.
|
||||
# Replace the existing block IDs with the new ones.
|
||||
req_state.block_ids = req_data.new_block_ids
|
||||
new_token_ids[-num_new_tokens:])
|
||||
# Update the block IDs.
|
||||
if not resumed_from_preemption:
|
||||
# Append the new blocks to the existing block IDs.
|
||||
for block_ids, new_ids in zip( # type: ignore[call-overload]
|
||||
req_state.block_ids, new_block_ids):
|
||||
block_ids.extend(new_ids)
|
||||
else:
|
||||
# The request is resumed from preemption.
|
||||
# Replace the existing block IDs with the new ones.
|
||||
req_state.block_ids = new_block_ids
|
||||
|
||||
req_index = self.input_batch.req_id_to_index.get(req_id)
|
||||
if req_index is None:
|
||||
# The request is not in the persistent batch.
|
||||
# The request was either preempted and resumed later, or was not
|
||||
# scheduled in the previous step and needs to be added again.
|
||||
req_ids_to_add.append(req_id)
|
||||
continue
|
||||
req_index = self.input_batch.req_id_to_index.get(req_id)
|
||||
if req_index is None:
|
||||
# The request is not in the persistent batch.
|
||||
# The request was either preempted and resumed later, or was not
|
||||
# scheduled in the previous step and needs to be added again.
|
||||
req_ids_to_add.append(req_id)
|
||||
continue
|
||||
|
||||
# Update the persistent batch.
|
||||
self.input_batch.num_computed_tokens_cpu[req_index] = (
|
||||
num_computed_tokens)
|
||||
# Update the persistent batch.
|
||||
self.input_batch.num_computed_tokens_cpu[req_index] = (
|
||||
num_computed_tokens)
|
||||
|
||||
start_index = (len(req_state.block_ids) -
|
||||
len(req_data.new_block_ids))
|
||||
self.input_batch.block_table.append_row(
|
||||
req_data.new_block_ids, req_index)
|
||||
self.input_batch.block_table.append_row(new_block_ids, req_index)
|
||||
|
||||
if not is_last_rank:
|
||||
# Add new_token_ids to token_ids_cpu.
|
||||
start_token_index = num_computed_tokens
|
||||
end_token_index = num_computed_tokens + len(
|
||||
req_data.new_token_ids)
|
||||
end_token_index = num_computed_tokens + len(new_token_ids)
|
||||
self.input_batch.token_ids_cpu[
|
||||
req_index,
|
||||
start_token_index:end_token_index] = req_data.new_token_ids
|
||||
start_token_index:end_token_index] = new_token_ids
|
||||
self.input_batch.num_tokens_no_spec[
|
||||
req_index] = end_token_index
|
||||
# Add spec_token_ids to token_ids_cpu.
|
||||
@@ -534,75 +526,6 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
||||
start_index:end_token_index] = spec_token_ids
|
||||
# NOTE(woosuk): `num_tokens` here may include spec decode tokens.
|
||||
self.input_batch.num_tokens[req_index] = end_token_index
|
||||
else:
|
||||
req_data = scheduler_output.scheduled_cached_reqs
|
||||
is_last_rank = get_pp_group().is_last_rank
|
||||
for i, req_id in enumerate(req_data.req_ids):
|
||||
req_state = self.requests[req_id]
|
||||
num_computed_tokens = req_data.num_computed_tokens[i]
|
||||
new_block_ids = req_data.new_block_ids[i]
|
||||
resumed_from_preemption = req_data.resumed_from_preemption[i]
|
||||
|
||||
req_state.num_computed_tokens = num_computed_tokens
|
||||
if not is_last_rank:
|
||||
new_token_ids = req_data.new_token_ids[i]
|
||||
# Add the sampled token(s) from the previous step (if any).
|
||||
# This doesn't include "unverified" tokens like spec decode tokens.
|
||||
num_new_tokens = (num_computed_tokens +
|
||||
len(new_token_ids) -
|
||||
req_state.num_tokens)
|
||||
if num_new_tokens == 1:
|
||||
# Avoid slicing list in most common case.
|
||||
req_state.output_token_ids.append(new_token_ids[-1])
|
||||
elif num_new_tokens > 0:
|
||||
req_state.output_token_ids.extend(
|
||||
new_token_ids[-num_new_tokens:])
|
||||
# Update the block IDs.
|
||||
if not resumed_from_preemption:
|
||||
# Append the new blocks to the existing block IDs.
|
||||
for block_ids, new_ids in zip( # type: ignore[call-overload]
|
||||
req_state.block_ids, new_block_ids):
|
||||
block_ids.extend(new_ids)
|
||||
else:
|
||||
# The request is resumed from preemption.
|
||||
# Replace the existing block IDs with the new ones.
|
||||
req_state.block_ids = new_block_ids
|
||||
|
||||
req_index = self.input_batch.req_id_to_index.get(req_id)
|
||||
if req_index is None:
|
||||
# The request is not in the persistent batch.
|
||||
# The request was either preempted and resumed later, or was not
|
||||
# scheduled in the previous step and needs to be added again.
|
||||
req_ids_to_add.append(req_id)
|
||||
continue
|
||||
|
||||
# Update the persistent batch.
|
||||
self.input_batch.num_computed_tokens_cpu[req_index] = (
|
||||
num_computed_tokens)
|
||||
|
||||
self.input_batch.block_table.append_row(
|
||||
new_block_ids, req_index)
|
||||
|
||||
if not is_last_rank:
|
||||
# Add new_token_ids to token_ids_cpu.
|
||||
start_token_index = num_computed_tokens
|
||||
end_token_index = num_computed_tokens + len(new_token_ids)
|
||||
self.input_batch.token_ids_cpu[
|
||||
req_index,
|
||||
start_token_index:end_token_index] = new_token_ids
|
||||
self.input_batch.num_tokens_no_spec[
|
||||
req_index] = end_token_index
|
||||
# Add spec_token_ids to token_ids_cpu.
|
||||
spec_token_ids = scheduler_output.scheduled_spec_decode_tokens.get(
|
||||
req_id, ())
|
||||
if spec_token_ids:
|
||||
start_index = end_token_index
|
||||
end_token_index += len(spec_token_ids)
|
||||
self.input_batch.token_ids_cpu[
|
||||
req_index,
|
||||
start_index:end_token_index] = spec_token_ids
|
||||
# NOTE(woosuk): `num_tokens` here may include spec decode tokens.
|
||||
self.input_batch.num_tokens[req_index] = end_token_index
|
||||
|
||||
# Check if the batch has changed. If not, we can skip copying the
|
||||
# sampling metadata from CPU to GPU.
|
||||
@@ -835,25 +758,13 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
||||
# compute completion's mrope_positions on-the-fly
|
||||
dst_start = mrope_pos_ptr
|
||||
dst_end = mrope_pos_ptr + completion_part_len
|
||||
|
||||
if vllm_version_is("0.9.1"):
|
||||
self.mrope_positions_cpu[:, dst_start:dst_end] = \
|
||||
MRotaryEmbedding.get_next_input_positions_tensor(
|
||||
req.mrope_position_delta,
|
||||
context_len=num_computed_tokens +
|
||||
prompt_part_len,
|
||||
seq_len=num_computed_tokens +
|
||||
prompt_part_len +
|
||||
completion_part_len,
|
||||
)
|
||||
else:
|
||||
MRotaryEmbedding.get_next_input_positions_tensor(
|
||||
out=self.mrope_positions_np,
|
||||
out_offset=dst_start,
|
||||
mrope_position_delta=req.mrope_position_delta,
|
||||
context_len=num_computed_tokens + prompt_part_len,
|
||||
num_new_tokens=completion_part_len,
|
||||
)
|
||||
MRotaryEmbedding.get_next_input_positions_tensor(
|
||||
out=self.mrope_positions_np,
|
||||
out_offset=dst_start,
|
||||
mrope_position_delta=req.mrope_position_delta,
|
||||
context_len=num_computed_tokens + prompt_part_len,
|
||||
num_new_tokens=completion_part_len,
|
||||
)
|
||||
|
||||
mrope_pos_ptr += completion_part_len
|
||||
|
||||
@@ -1661,30 +1572,29 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
||||
|
||||
for i in discard_sampled_tokens_req_indices:
|
||||
valid_sampled_token_ids[i].clear()
|
||||
if not vllm_version_is("0.9.1"):
|
||||
# Cache the sampled tokens in the model runner, so that the schedulerAdd commentMore actions
|
||||
# doesn't need to send them back.
|
||||
# NOTE(woosuk): As an exception, when using PP, the scheduler sends
|
||||
# the sampled tokens back, because there's no direct communication
|
||||
# between the first-stage worker and the last-stage worker.
|
||||
for req_idx, sampled_ids in enumerate(valid_sampled_token_ids):
|
||||
if not sampled_ids:
|
||||
continue
|
||||
# Cache the sampled tokens in the model runner, so that the schedulerAdd commentMore actions
|
||||
# doesn't need to send them back.
|
||||
# NOTE(woosuk): As an exception, when using PP, the scheduler sends
|
||||
# the sampled tokens back, because there's no direct communication
|
||||
# between the first-stage worker and the last-stage worker.
|
||||
for req_idx, sampled_ids in enumerate(valid_sampled_token_ids):
|
||||
if not sampled_ids:
|
||||
continue
|
||||
|
||||
start_idx = self.input_batch.num_tokens_no_spec[req_idx]
|
||||
end_idx = start_idx + len(sampled_ids)
|
||||
assert end_idx <= self.model_config.max_model_len, (
|
||||
"Sampled token IDs exceed the max model length. "
|
||||
f"Total number of tokens: {end_idx} > max_model_len: "
|
||||
f"{self.model_config.max_model_len}")
|
||||
start_idx = self.input_batch.num_tokens_no_spec[req_idx]
|
||||
end_idx = start_idx + len(sampled_ids)
|
||||
assert end_idx <= self.model_config.max_model_len, (
|
||||
"Sampled token IDs exceed the max model length. "
|
||||
f"Total number of tokens: {end_idx} > max_model_len: "
|
||||
f"{self.model_config.max_model_len}")
|
||||
|
||||
self.input_batch.token_ids_cpu[
|
||||
req_idx, start_idx:end_idx] = sampled_ids
|
||||
self.input_batch.num_tokens_no_spec[req_idx] = end_idx
|
||||
self.input_batch.num_tokens[req_idx] = end_idx
|
||||
req_id = self.input_batch.req_ids[req_idx]
|
||||
req_state = self.requests[req_id]
|
||||
req_state.output_token_ids.extend(sampled_ids)
|
||||
self.input_batch.token_ids_cpu[req_idx,
|
||||
start_idx:end_idx] = sampled_ids
|
||||
self.input_batch.num_tokens_no_spec[req_idx] = end_idx
|
||||
self.input_batch.num_tokens[req_idx] = end_idx
|
||||
req_id = self.input_batch.req_ids[req_idx]
|
||||
req_state = self.requests[req_id]
|
||||
req_state.output_token_ids.extend(sampled_ids)
|
||||
|
||||
spec_token_ids = self._get_spec_token_ids(
|
||||
valid_sampled_token_ids,
|
||||
@@ -1697,25 +1607,16 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
||||
attn_metadata,
|
||||
aux_hidden_states,
|
||||
)
|
||||
if vllm_version_is("0.9.1"):
|
||||
model_runner_output = ModelRunnerOutput(
|
||||
req_ids=self.input_batch.req_ids,
|
||||
req_id_to_index=self.input_batch.req_id_to_index,
|
||||
sampled_token_ids=valid_sampled_token_ids,
|
||||
spec_token_ids=spec_token_ids,
|
||||
logprobs=logprobs_lists,
|
||||
prompt_logprobs_dict=prompt_logprobs_dict,
|
||||
)
|
||||
else:
|
||||
model_runner_output = ModelRunnerOutput(
|
||||
req_ids=self.input_batch.req_ids,
|
||||
req_id_to_index=self.input_batch.req_id_to_index,
|
||||
sampled_token_ids=valid_sampled_token_ids,
|
||||
spec_token_ids=spec_token_ids,
|
||||
logprobs=logprobs_lists,
|
||||
prompt_logprobs_dict=prompt_logprobs_dict,
|
||||
pooler_output=[],
|
||||
)
|
||||
|
||||
model_runner_output = ModelRunnerOutput(
|
||||
req_ids=self.input_batch.req_ids,
|
||||
req_id_to_index=self.input_batch.req_id_to_index,
|
||||
sampled_token_ids=valid_sampled_token_ids,
|
||||
spec_token_ids=spec_token_ids,
|
||||
logprobs=logprobs_lists,
|
||||
prompt_logprobs_dict=prompt_logprobs_dict,
|
||||
pooler_output=[],
|
||||
)
|
||||
|
||||
durations = ProfileExecuteDuration().pop_captured_sync()
|
||||
if durations:
|
||||
@@ -2024,15 +1925,8 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
||||
QKVParallelLinear, RowParallelLinear)):
|
||||
module.weight.data = torch_npu.npu_format_cast(
|
||||
module.weight.data, ACL_FORMAT_FRACTAL_NZ)
|
||||
|
||||
try:
|
||||
# For version compatibility, remove this after we abort vllm v0.9.1 support
|
||||
from vllm.model_executor.models.interfaces import \
|
||||
has_step_pooler # type: ignore
|
||||
if has_step_pooler(self.model):
|
||||
self.input_batch.logits_processing_needs_token_ids = True
|
||||
except ImportError:
|
||||
pass
|
||||
if has_step_pooler(self.model):
|
||||
self.input_batch.logits_processing_needs_token_ids = True
|
||||
if self.drafter:
|
||||
logger.info("Loading drafter model...")
|
||||
if self.use_aux_hidden_state_outputs:
|
||||
@@ -2362,14 +2256,9 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
||||
|
||||
# Skip requests that require top-p, top-k, etc.
|
||||
req_id = self.input_batch.req_ids[i]
|
||||
if vllm_version_is("0.9.1"):
|
||||
if not is_spec_decode_supported(req_id, self.input_batch):
|
||||
draft_token_ids.append([])
|
||||
continue
|
||||
else:
|
||||
if req_id in self.input_batch.spec_decode_unsupported_reqs:
|
||||
draft_token_ids.append([])
|
||||
continue
|
||||
if req_id in self.input_batch.spec_decode_unsupported_reqs:
|
||||
draft_token_ids.append([])
|
||||
continue
|
||||
|
||||
# Add sampled_token_ids to token_ids_cpu.
|
||||
start_idx = self.input_batch.num_tokens_no_spec[i]
|
||||
|
||||
@@ -28,15 +28,13 @@ from vllm.pooling_params import PoolingParams
|
||||
from vllm.sampling_params import SamplingParams, SamplingType
|
||||
from vllm.utils import swap_dict_values
|
||||
from vllm.v1.outputs import LogprobsTensors
|
||||
from vllm.v1.sample.logits_processor import init_builtin_logitsprocs
|
||||
from vllm.v1.sample.metadata import SamplingMetadata
|
||||
from vllm.v1.spec_decode.utils import is_spec_decode_unsupported
|
||||
from vllm.v1.utils import copy_slice
|
||||
from vllm.v1.worker.block_table import MultiGroupBlockTable
|
||||
|
||||
from vllm_ascend.pool.metadata import PoolingMetadata
|
||||
from vllm_ascend.utils import vllm_version_is
|
||||
|
||||
if not vllm_version_is("0.9.1"):
|
||||
from vllm.v1.spec_decode.utils import is_spec_decode_unsupported
|
||||
|
||||
_SAMPLING_EPS = 1e-5
|
||||
|
||||
@@ -253,17 +251,13 @@ class InputBatch:
|
||||
|
||||
self.req_output_token_ids: list[Optional[list[int]]] = []
|
||||
|
||||
if not vllm_version_is("0.9.1"):
|
||||
from vllm.v1.sample.logits_processor import \
|
||||
init_builtin_logitsprocs
|
||||
|
||||
# Define logits processors.
|
||||
# TODO(andy): logits processor list should be extensible via engine
|
||||
# constructor argument; for now the list is fixed.
|
||||
self.logitsprocs = init_builtin_logitsprocs(
|
||||
pin_memory_available=pin_memory,
|
||||
max_num_reqs=max_num_reqs + 1,
|
||||
device=device)
|
||||
# Define logits processors.
|
||||
# TODO(andy): logits processor list should be extensible via engine
|
||||
# constructor argument; for now the list is fixed.
|
||||
self.logitsprocs = init_builtin_logitsprocs(
|
||||
pin_memory_available=pin_memory,
|
||||
max_num_reqs=max_num_reqs + 1,
|
||||
device=device)
|
||||
|
||||
# This is updated each time the batch constituents change.
|
||||
self.sampling_metadata = self._make_sampling_metadata()
|
||||
@@ -314,8 +308,8 @@ class InputBatch:
|
||||
self.block_table.add_row(request.block_ids, req_index)
|
||||
|
||||
if sampling_params := request.sampling_params:
|
||||
if ((not vllm_version_is("0.9.1")) and self.is_spec_decode
|
||||
and is_spec_decode_unsupported(sampling_params)):
|
||||
if self.is_spec_decode and is_spec_decode_unsupported(
|
||||
sampling_params):
|
||||
self.spec_decode_unsupported_reqs.add(req_id)
|
||||
if sampling_params.sampling_type == SamplingType.GREEDY:
|
||||
# Avoid later division by zero.
|
||||
@@ -641,48 +635,24 @@ class InputBatch:
|
||||
self.allowed_token_ids_mask, num_reqs)
|
||||
allowed_token_ids_mask = self.allowed_token_ids_mask[:num_reqs]
|
||||
|
||||
if vllm_version_is("0.9.1"):
|
||||
return SamplingMetadata(
|
||||
temperature=temperature,
|
||||
all_greedy=self.all_greedy,
|
||||
all_random=self.all_random,
|
||||
top_p=None if self.no_top_p else self.top_p[:num_reqs],
|
||||
top_k=None if self.no_top_k else self.top_k[:num_reqs],
|
||||
min_p=None if self.no_min_p else self.min_p[:num_reqs],
|
||||
generators=self.generators,
|
||||
max_num_logprobs=self.max_num_logprobs,
|
||||
prompt_token_ids=prompt_token_ids,
|
||||
frequency_penalties=self.frequency_penalties[:num_reqs],
|
||||
presence_penalties=self.presence_penalties[:num_reqs],
|
||||
repetition_penalties=self.repetition_penalties[:num_reqs],
|
||||
output_token_ids=cast(list[list[int]],
|
||||
self.req_output_token_ids),
|
||||
min_tokens=self.min_tokens,
|
||||
no_penalties=self.no_penalties,
|
||||
logit_bias=self.logit_bias[:num_reqs],
|
||||
allowed_token_ids_mask=allowed_token_ids_mask,
|
||||
bad_words_token_ids=self.bad_words_token_ids,
|
||||
)
|
||||
else:
|
||||
return SamplingMetadata(
|
||||
temperature=temperature,
|
||||
all_greedy=self.all_greedy,
|
||||
all_random=self.all_random,
|
||||
top_p=None if self.no_top_p else self.top_p[:num_reqs],
|
||||
top_k=None if self.no_top_k else self.top_k[:num_reqs],
|
||||
generators=self.generators,
|
||||
max_num_logprobs=self.max_num_logprobs,
|
||||
prompt_token_ids=prompt_token_ids,
|
||||
frequency_penalties=self.frequency_penalties[:num_reqs],
|
||||
presence_penalties=self.presence_penalties[:num_reqs],
|
||||
repetition_penalties=self.repetition_penalties[:num_reqs],
|
||||
output_token_ids=cast(list[list[int]],
|
||||
self.req_output_token_ids),
|
||||
no_penalties=self.no_penalties,
|
||||
allowed_token_ids_mask=allowed_token_ids_mask,
|
||||
bad_words_token_ids=self.bad_words_token_ids,
|
||||
logitsprocs=self.logitsprocs,
|
||||
)
|
||||
return SamplingMetadata(
|
||||
temperature=temperature,
|
||||
all_greedy=self.all_greedy,
|
||||
all_random=self.all_random,
|
||||
top_p=None if self.no_top_p else self.top_p[:num_reqs],
|
||||
top_k=None if self.no_top_k else self.top_k[:num_reqs],
|
||||
generators=self.generators,
|
||||
max_num_logprobs=self.max_num_logprobs,
|
||||
prompt_token_ids=prompt_token_ids,
|
||||
frequency_penalties=self.frequency_penalties[:num_reqs],
|
||||
presence_penalties=self.presence_penalties[:num_reqs],
|
||||
repetition_penalties=self.repetition_penalties[:num_reqs],
|
||||
output_token_ids=cast(list[list[int]], self.req_output_token_ids),
|
||||
no_penalties=self.no_penalties,
|
||||
allowed_token_ids_mask=allowed_token_ids_mask,
|
||||
bad_words_token_ids=self.bad_words_token_ids,
|
||||
logitsprocs=self.logitsprocs,
|
||||
)
|
||||
|
||||
@property
|
||||
def pooling_metadata(self) -> PoolingMetadata:
|
||||
|
||||
Reference in New Issue
Block a user