Sync from v0.13
This commit is contained in:
0
tests/v1/logits_processors/__init__.py
Normal file
0
tests/v1/logits_processors/__init__.py
Normal file
706
tests/v1/logits_processors/test_correctness.py
Normal file
706
tests/v1/logits_processors/test_correctness.py
Normal file
@@ -0,0 +1,706 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import random
|
||||
from collections.abc import Callable
|
||||
from typing import NamedTuple, TypeAlias
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from tests.utils import create_new_process_for_each_test
|
||||
from tests.v1.sample.utils import (
|
||||
LogitsprocsTestFakes,
|
||||
create_fake_logits,
|
||||
create_penalty_tensor,
|
||||
create_prompt_tokens_tensor,
|
||||
fake_apply_logitsprocs,
|
||||
fake_update_logitsprocs_state,
|
||||
)
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.sampling_params import SamplingParams
|
||||
from vllm.utils.platform_utils import is_pin_memory_available
|
||||
from vllm.v1.sample.logits_processor import (
|
||||
BatchUpdate,
|
||||
BatchUpdateBuilder,
|
||||
LogitBiasLogitsProcessor,
|
||||
LogitsProcessor,
|
||||
MinPLogitsProcessor,
|
||||
MinTokensLogitsProcessor,
|
||||
MoveDirectionality,
|
||||
build_logitsprocs,
|
||||
)
|
||||
from vllm.v1.sample.metadata import SamplingMetadata
|
||||
|
||||
# Whether pinned (page-locked) host memory is usable on this platform.
PIN_MEMORY_AVAILABLE = is_pin_memory_available()
# Upper bound on concurrent requests assumed by the fake batch machinery.
MAX_NUM_REQS = 256
# Vocabulary size for all fake logits tensors in this test module.
VOCAB_SIZE = 1024
# Number of fake output tokens generated per request.
NUM_OUTPUT_TOKENS = 20
# Devices to parametrize over: one device, or two when multiple are present.
CUDA_DEVICES = [
    f"{current_platform.device_type}:{i}"
    for i in range(1 if current_platform.device_count() == 1 else 2)
]
# Fake prompt lengths are drawn from [1, MAX_NUM_PROMPT_TOKENS).
MAX_NUM_PROMPT_TOKENS = 64
# `min_tokens` threshold used by the min-tokens logitproc tests.
MIN_TOKENS_LEN_THRESHOLD = 5
# Number of requests generated per logitproc under test.
REQS_PER_LOGITPROC = 50
# Sentinel meaning "request enables no logits processor".
STR_NO_LOGITPROC = "none"

# LogitsProcessor subclass or "none"
LogitprocType: TypeAlias = type[LogitsProcessor] | str
|
||||
|
||||
|
||||
class LogitsProcsRequestParams:
    """Encapsulates key params for a single request in a batch.

    Params can be customized based on the enabled logitproc
    """

    workload_index: int
    logitproc_type: LogitprocType  # Logitproc enabled, specified by str id
    out_tokens: list[int]  # Output tokens required for min tokens test
    prompt_tokens: list[int]  # Dummy prompt tokens placeholder
    params: SamplingParams  # Settings customized for logitproc

    def __init__(self, workload_index: int, logitproc_type: LogitprocType):
        self.workload_index = workload_index
        self.logitproc_type = logitproc_type
        # Number of output tokens is randomly 0x, 1x or 2x the min-tokens
        # threshold which will be used in testing. Output token values
        # don't matter *for these tests* so use 0 as a dummy value
        self.out_tokens = [0] * (MIN_TOKENS_LEN_THRESHOLD * random.randint(0, 2))
        self.prompt_tokens = []
        self.params = _sampling_params_from_logitproc(logitproc_type)

    def __str__(self):
        """For debugging"""
        summ = ", ".join(f"{k}={v}" for k, v in vars(self).items())
        # Report the actual class name; previously this was hard-coded to
        # the misleading placeholder "MyClass".
        return f"{type(self).__name__}({summ})"
|
||||
|
||||
|
||||
def _generate_fake_sampling_metadata(
    num_output_tokens: int,
    batch_size: int,
    vocab_size: int,
    device: torch.device,
) -> SamplingMetadata:
    """Generate fake sampling metadata with fake logitsprocs

    Args:
        num_output_tokens: number of random output token ids per request
        batch_size: number of fake requests in the batch
        vocab_size: token ids are drawn uniformly from [0, vocab_size)
        device: device the logitsprocs and tensors should target

    Returns:
        `SamplingMetadata` configured for greedy sampling with all
        penalties disabled, wrapping freshly-built builtin logitsprocs.
    """
    output_token_ids: list[list[int]] = []
    prompt_token_ids: list[list[int]] = []
    for _ in range(batch_size):
        output_token_ids.append(
            np.random.randint(0, vocab_size, size=num_output_tokens).tolist()
        )
        # Prompt lengths are random in [1, MAX_NUM_PROMPT_TOKENS)
        prompt_token_ids.append(
            np.random.randint(
                0, vocab_size, size=np.random.randint(1, MAX_NUM_PROMPT_TOKENS)
            ).tolist()
        )
    # Instantiate the default builtin logitsprocs for a non-pooling model
    logitsprocs = build_logitsprocs(
        vllm_config=VllmConfig(),
        device=device,
        is_pin_memory=PIN_MEMORY_AVAILABLE,
        is_pooling_model=False,
    )
    # Greedy (temperature 0) metadata with penalties disabled so only the
    # logitsprocs under test can modify the logits
    fake_sampling_metadata = SamplingMetadata(
        temperature=torch.full((batch_size,), 0.0),
        all_greedy=True,
        all_random=False,
        top_p=None,
        top_k=None,
        generators={},
        max_num_logprobs=0,
        prompt_token_ids=create_prompt_tokens_tensor(
            prompt_token_ids, vocab_size, device
        ),
        output_token_ids=output_token_ids,
        frequency_penalties=create_penalty_tensor(batch_size, 0.0, device),
        presence_penalties=create_penalty_tensor(batch_size, 0.0, device),
        repetition_penalties=create_penalty_tensor(batch_size, 1.0, device),
        no_penalties=True,
        allowed_token_ids_mask=None,
        bad_words_token_ids={},
        logitsprocs=logitsprocs,
    )
    return fake_sampling_metadata
|
||||
|
||||
|
||||
def _generate_test_fakes(batch_size: int, device: str) -> LogitsprocsTestFakes:
    """Build the fake logits tensor and fake sampling metadata for a test."""
    logits = create_fake_logits(batch_size, VOCAB_SIZE)
    # Give every row one dominant token (token 0) so the min-p validator has
    # an unambiguous "must survive" token; all other tokens stay uniformly low.
    logits[:, 0] = 10.0
    logits[:, 1:] = 1e-2
    metadata = _generate_fake_sampling_metadata(
        NUM_OUTPUT_TOKENS, batch_size, VOCAB_SIZE, torch.device(device)
    )
    return LogitsprocsTestFakes(logits=logits, sampling_metadata=metadata)
|
||||
|
||||
|
||||
def _sampling_params_from_logitproc(logitproc_type: LogitprocType) -> SamplingParams:
    """Build request `SamplingParams` tailored to the given logitproc."""
    # Baseline corresponds to a request with no logitproc enabled.
    sp_kwargs: dict = {"min_p": 0.0, "logit_bias": None, "min_tokens": 0}
    customize = logitsprocs_test_mapping[logitproc_type].gen_request_fxn
    if customize is not None:
        # Let the logitproc-specific helper mutate the kwargs in place.
        customize(sp_kwargs)
    return SamplingParams(**sp_kwargs)
|
||||
|
||||
|
||||
def _generate_mixed_logitsprocs_batch_params(
    reqs_per_logitproc: int,
    logitsprocs_types: list[str],
) -> list[LogitsProcsRequestParams]:
    """Define key params for a batch of requests, one logitproc per request.

    The batch contains `reqs_per_logitproc` requests for every entry in
    `logitsprocs_types` (which may include the no-logitproc sentinel), so
    the total batch size is `reqs_per_logitproc * len(logitsprocs_types)`.
    Logitproc assignment is shuffled across the batch.

    Args:
        reqs_per_logitproc: number of requests using each logitproc
        logitsprocs_types: logitsprocs under test

    Returns:
        List of per-request params which configure the engine for that
        request's enabled logitproc
    """
    batch_size = len(logitsprocs_types) * reqs_per_logitproc
    # Draw a random permutation of batch slots; slot `idx` is assigned the
    # logitproc that permuted position `pdx` falls into, which shuffles the
    # repeated logitproc groups across the batch.
    shuffled = random.sample(range(batch_size), k=batch_size)
    batch_params: list[LogitsProcsRequestParams] = []
    for idx, pdx in enumerate(shuffled):
        batch_params.append(
            LogitsProcsRequestParams(
                workload_index=idx,
                logitproc_type=logitsprocs_types[pdx // reqs_per_logitproc],
            )
        )
    return batch_params
|
||||
|
||||
|
||||
def _raise_error_invalid(
    msg_suffix: str,
    batch_index: int,
    request_params: LogitsProcsRequestParams,
    step_idx: int,
    err_cls: type[Exception] = ValueError,
) -> None:
    """Raise `err_cls` with a message pinpointing the failing request."""
    detail = (
        f"Validation failed for step={step_idx}, "
        f"batch_index={batch_index}, "
        f"workload_index={request_params.workload_index}, "
        f"req_params={request_params}. Reason: {msg_suffix}"
    )
    raise err_cls(detail)
|
||||
|
||||
|
||||
def _logit_bias_params(kwargs: dict) -> None:
    """Logit bias config: bias a single random token up or down."""
    # Note: RNG draw order (token first, then bias) matches construction
    # order of the original dict literal, keeping seeded runs reproducible.
    token_id = random.randint(0, VOCAB_SIZE - 1)
    bias = random.choice([-0.1, 0.2])
    kwargs["logit_bias"] = {token_id: bias}
|
||||
|
||||
|
||||
def _logit_bias_validate(
    test_fakes: LogitsprocsTestFakes,
    persistent_batch: list[LogitsProcsRequestParams],
    logits_new: torch.Tensor,
    batch_index: int,
    request_params: LogitsProcsRequestParams,
    step_idx: int,
) -> None:
    """Validate logit bias logitproc applied correctly"""
    logit_bias = request_params.params.logit_bias
    # Pre-logitsprocs logits row for this request, selected via the
    # request's workload index; moved to CPU for scalar comparisons.
    logits_old = test_fakes.logits[persistent_batch[batch_index].workload_index].cpu()
    logits_new = logits_new[batch_index].cpu()
    for token_id in range(VOCAB_SIZE):
        logit_old_value = logits_old[token_id]
        logit_new_value = logits_new[token_id]
        if token_id in logit_bias:
            # Biased token: new logit must equal old logit plus the bias.
            bias_value = logit_bias[token_id]
            exp_value = bias_value + logit_old_value
            if logit_new_value != pytest.approx(exp_value):
                _raise_error_invalid(
                    msg_suffix=(
                        f"Biased token {token_id} logit value {logit_new_value} "
                        f"does not match expected value {exp_value} "
                        f"given bias {bias_value}"
                    ),
                    batch_index=batch_index,
                    request_params=request_params,
                    step_idx=step_idx,
                )

        else:
            # Unbiased token: logit must pass through unchanged.
            if logit_new_value != pytest.approx(logit_old_value):
                _raise_error_invalid(
                    msg_suffix=(
                        f"Unbiased token {token_id} logit value {logit_new_value} "
                        f"does not match expected value {logit_old_value}"
                    ),
                    batch_index=batch_index,
                    request_params=request_params,
                    step_idx=step_idx,
                )
|
||||
|
||||
|
||||
def _min_p_params(kwargs: dict) -> None:
|
||||
"""Min-p logitproc config"""
|
||||
kwargs["min_p"] = 0.1
|
||||
|
||||
|
||||
def _min_p_validate(
    test_fakes: LogitsprocsTestFakes,
    persistent_batch: list[LogitsProcsRequestParams],
    logits_new: torch.Tensor,
    batch_index: int,
    request_params: LogitsProcsRequestParams,
    step_idx: int,
) -> None:
    """Validate min-p logitproc applied correctly"""
    min_p_enabled = request_params.params.min_p > 0.0
    neg_inf = -float("inf")
    for token_id in range(VOCAB_SIZE):
        token_logit = logits_new[batch_index][token_id]
        if token_id == 0:
            # The dominant token must never be masked, regardless of min-p.
            if token_logit == neg_inf:
                _raise_error_invalid(
                    msg_suffix="Invalid: dominant token 0 masked (-inf)",
                    batch_index=batch_index,
                    request_params=request_params,
                    step_idx=step_idx,
                )
        elif min_p_enabled:
            # With min-p active, every non-dominant token must be masked.
            if token_logit != neg_inf:
                _raise_error_invalid(
                    msg_suffix=f"Invalid: non-dominant token {token_id} not masked",
                    batch_index=batch_index,
                    request_params=request_params,
                    step_idx=step_idx,
                )
        elif token_logit == neg_inf:
            # With min-p disabled, no token may be masked.
            _raise_error_invalid(
                msg_suffix=f"Invalid: token {token_id} masked when min_p=0.0",
                batch_index=batch_index,
                request_params=request_params,
                step_idx=step_idx,
            )
|
||||
|
||||
|
||||
def _min_tokens_params(kwargs: dict) -> None:
    """Min-tokens logitproc config: set the threshold and random stop tokens."""
    kwargs["min_tokens"] = MIN_TOKENS_LEN_THRESHOLD
    # Draw the stop-token count first, then each stop token id, matching the
    # RNG consumption order of the original comprehension.
    num_stop_tokens = np.random.randint(0, VOCAB_SIZE)
    kwargs["stop_token_ids"] = [
        np.random.randint(0, VOCAB_SIZE - 1) for _ in range(num_stop_tokens)
    ]
|
||||
|
||||
|
||||
def _min_tokens_validate(
    test_fakes: LogitsprocsTestFakes,
    persistent_batch: list[LogitsProcsRequestParams],
    logits_new: torch.Tensor,
    batch_index: int,
    request_params: LogitsProcsRequestParams,
    step_idx: int,
) -> None:
    """Validate min-tokens logitsproc applied correctly"""
    ref_num_out_tokens = len(request_params.out_tokens)
    # Once a request's output length reaches the threshold, its stop tokens
    # must stop being masked and the request must leave the processor state.
    min_reached = ref_num_out_tokens >= MIN_TOKENS_LEN_THRESHOLD
    ref_all_stop_token_ids = request_params.params.all_stop_token_ids
    # Locate the single MinTokensLogitsProcessor instance among the fakes.
    mt_lp: MinTokensLogitsProcessor = next(
        test_fakes.get_logitsprocs_by_cls(MinTokensLogitsProcessor)
    )
    assert isinstance(mt_lp, MinTokensLogitsProcessor)
    # Per-request state tracked by the processor, keyed on batch index;
    # None/absent means the processor is not tracking this request.
    min_tok = mt_lp.min_toks.get(batch_index, None)

    # Validate min-token logits processor state
    if min_tok:
        (_, out_tok, all_stop_token_ids) = min_tok
        num_out_tokens = len(out_tok)
        if num_out_tokens != ref_num_out_tokens:
            _raise_error_invalid(
                msg_suffix=(
                    "Number of output tokens in min-token logit processor "
                    f"request metadata ({num_out_tokens}) does not match "
                    f"reference ({ref_num_out_tokens})."
                ),
                batch_index=batch_index,
                request_params=request_params,
                step_idx=step_idx,
            )
        if ref_all_stop_token_ids != all_stop_token_ids:
            _raise_error_invalid(
                msg_suffix=(
                    "Stop token ids do not match reference; all_stop_token_ids: "
                    f"{sorted(all_stop_token_ids)}, ref_all_stop_token_ids: "
                    f"{sorted(ref_all_stop_token_ids)}"
                ),
                batch_index=batch_index,
                request_params=request_params,
                step_idx=step_idx,
            )
        if min_reached:
            # Processor should have dropped this request once min was reached.
            _raise_error_invalid(
                msg_suffix=(
                    "Expected min-tokens request with min reached, but batch "
                    "index is recognized by min-tokens logits processor."
                ),
                batch_index=batch_index,
                request_params=request_params,
                step_idx=step_idx,
                err_cls=RuntimeError,
            )

    elif not min_reached:
        # Processor is not tracking a request that still needs masking.
        _raise_error_invalid(
            msg_suffix=(
                "Expected min-tokens request with min not reached, but batch "
                "index is not recognized by min-tokens logits processor."
            ),
            batch_index=batch_index,
            request_params=request_params,
            step_idx=step_idx,
            err_cls=RuntimeError,
        )

    # Validate min-token logits
    for token_id in range(VOCAB_SIZE):
        logits_for_token = logits_new[batch_index][token_id]
        if token_id in ref_all_stop_token_ids and not min_reached:
            # Stop tokens must be masked while output is below the threshold.
            if logits_for_token != -float("inf"):
                _raise_error_invalid(
                    msg_suffix=(
                        f"Token {token_id} is a stop token and "
                        "the sequence has not reached min length, "
                        "but the token is not masked "
                        f"(logit={logits_for_token})"
                    ),
                    batch_index=batch_index,
                    request_params=request_params,
                    step_idx=step_idx,
                )
        else:
            # All other tokens must pass through unmasked.
            if logits_for_token == -float("inf"):
                _raise_error_invalid(
                    msg_suffix=(
                        f"Token {token_id} should not be masked but "
                        f"is (output len={ref_num_out_tokens})"
                    ),
                    batch_index=batch_index,
                    request_params=request_params,
                    step_idx=step_idx,
                )
|
||||
|
||||
|
||||
def _none_validate(
    test_fakes: LogitsprocsTestFakes,
    persistent_batch: list[LogitsProcsRequestParams],
    logits_new: torch.Tensor,
    batch_index: int,
    request_params: LogitsProcsRequestParams,
    step_idx: int,
) -> None:
    """Validate that no logits processors are applied"""
    expected = test_fakes.logits[persistent_batch[batch_index].workload_index].cpu()
    actual = logits_new[batch_index]
    if torch.all(actual == expected):
        # Fast path: logits passed through untouched, as required.
        return
    # Build a readable description of every mismatched token.
    # NOTE: `token`/`val`/`ref_val` names are embedded in the `{x=}` f-string
    # output, so they must keep these exact names.
    mismatch_strs = []
    for token in (actual != expected).nonzero(as_tuple=True)[0].tolist():
        val = float(expected[token])
        ref_val = float(actual[token])
        mismatch_strs.append(f"({token=},{val=},{ref_val=})")
    _raise_error_invalid(
        msg_suffix=(f"Unexpected modification of logits: {','.join(mismatch_strs)}"),
        batch_index=batch_index,
        request_params=request_params,
        step_idx=step_idx,
    )
|
||||
|
||||
|
||||
class LogitsprocTestHelpers(NamedTuple):
|
||||
"""Supports setting up and validating logitsprocs unit tests."""
|
||||
|
||||
eval_fxn: Callable
|
||||
gen_request_fxn: Callable | None = None
|
||||
|
||||
|
||||
# Maps each logitproc under test (keyed by processor class, or the "none"
# sentinel) to its helpers: an optional SamplingParams customizer and the
# required logits validation function.
logitsprocs_test_mapping = {
    STR_NO_LOGITPROC: LogitsprocTestHelpers(eval_fxn=_none_validate),
    LogitBiasLogitsProcessor: LogitsprocTestHelpers(
        gen_request_fxn=_logit_bias_params, eval_fxn=_logit_bias_validate
    ),
    MinPLogitsProcessor: LogitsprocTestHelpers(
        gen_request_fxn=_min_p_params, eval_fxn=_min_p_validate
    ),
    MinTokensLogitsProcessor: LogitsprocTestHelpers(
        gen_request_fxn=_min_tokens_params, eval_fxn=_min_tokens_validate
    ),
}
|
||||
|
||||
|
||||
def _get_test_cases() -> list[list[str]]:
    """Each test case is a set of logitsprocs"""
    all_types = list(logitsprocs_test_mapping.keys())
    # Case 1: a batch with no logitsprocs enabled at all.
    cases: list[list[str]] = [[STR_NO_LOGITPROC]]
    # One case per logitproc, mixed with no-logitproc requests.
    for logitproc_type in all_types:
        if logitproc_type != STR_NO_LOGITPROC:
            cases.append([logitproc_type, STR_NO_LOGITPROC])
    # Final case: every logitproc at once.
    cases.append(all_types)
    return cases
|
||||
|
||||
|
||||
def _generate_fake_step_update(
    persistent_batch: list[LogitsProcsRequestParams],
    workload_params: list[LogitsProcsRequestParams],
    wdx: int,
    batch_update_builder: BatchUpdateBuilder,
) -> tuple[BatchUpdate | None, int, int]:
    """Simulate one engine step's persistent-batch update.

    Randomly removes requests from `persistent_batch`, adds new requests
    drawn from `workload_params`, condenses the batch (moving requests down
    into vacated slots), and applies random swaps. `persistent_batch` is
    mutated in place and all operations are recorded in
    `batch_update_builder`.

    Args:
        persistent_batch: current batch state; modified in place
        workload_params: full workload to draw added requests from
        wdx: index of the next workload request to add
        batch_update_builder: accumulates removals/adds/moves; reset here

    Returns:
        Tuple of (batch update or None, updated `wdx`, number of workload
        requests still not added).
    """
    batch_size = len(persistent_batch)
    workload_size = len(workload_params)
    workload_reqs_remaining = workload_size - wdx
    max_add_remove_per_step = max(1, int(0.2 * workload_size))

    # 50% of steps: add no reqs
    # Other 50%: add a limited number of reqs (less than the number
    # of workload reqs remaining, less than an arbitrary max)
    # If no workload reqs remain: 100% of steps have 0 adds
    num_step_add = (
        random.choice(
            [
                0,
                random.randint(
                    1, min(max_add_remove_per_step, workload_reqs_remaining)
                ),
            ]
        )
        if workload_reqs_remaining
        else 0
    )

    # 50% of steps: remove no requests
    # Other 50%: remove a limited number of reqs (less than the number
    # persistent batch reqs remaining, less than an arbitrary max)
    # If persistent batch is empty: 100% of steps have 0 removals until
    # more requests are added. Assume that removed requests are always
    # drawn from the current batch, before new adds
    num_step_remove = (
        random.choice([0, random.randint(1, min(max_add_remove_per_step, batch_size))])
        if batch_size
        else 0
    )

    # Number of adds that can directly reuse a slot vacated by a removal
    num_step_add_replace = min(num_step_add, num_step_remove)

    # Generate fake removed request indices drawn from persistent batch indices
    for removal in random.sample(range(batch_size), num_step_remove):
        batch_update_builder.removed_append(removal)

    # Get added requests from workload
    for add_req_params in workload_params[wdx : (wdx + num_step_add_replace)]:
        # Replace as many removed requests as possible with added requests
        add_remove_idx = batch_update_builder.pop_removed()
        batch_update_builder.added.append(
            (
                add_remove_idx,
                add_req_params.params,
                add_req_params.prompt_tokens,
                add_req_params.out_tokens,
            )
        )
        persistent_batch[add_remove_idx] = add_req_params

    # Append remaining added requests to end of batch
    add_reqs_append = workload_params[
        (wdx + num_step_add_replace) : (wdx + num_step_add)
    ]
    batch_update_builder.added.extend(
        [
            (
                adx + batch_size,
                add_req_params.params,
                add_req_params.prompt_tokens,
                add_req_params.out_tokens,
            )
            for adx, add_req_params in enumerate(add_reqs_append)
        ]
    )
    persistent_batch.extend(add_reqs_append)
    pre_condense_batch_size = len(persistent_batch)
    wdx += num_step_add  # Update workload offset

    # Simulate condensing persistent batch
    last_nonempty_index = pre_condense_batch_size - 1
    condensed_to_idxs = set()
    while batch_update_builder.removed:
        if (
            last_nonempty_index in batch_update_builder.removed
            or last_nonempty_index in condensed_to_idxs
        ):
            # Slot is itself empty (removed, or its occupant already moved
            # down); keep scanning downward for a movable request
            last_nonempty_index -= 1
            continue
        # last_nonempty_index is the highest persistent batch index that was
        # not removed
        first_empty_index = batch_update_builder.peek_removed()
        assert first_empty_index is not None
        if first_empty_index > last_nonempty_index:
            # All gaps are above the last occupied slot; batch is condensed
            break
        # first_empty_index is the lowest removed persistent batch index
        # that is less than last_nonempty_index
        #
        # move last_nonempty_index -> first_empty_index
        batch_update_builder.pop_removed()
        condensed_to_idxs.add(first_empty_index)
        persistent_batch[first_empty_index] = persistent_batch[last_nonempty_index]
        batch_update_builder.moved.append(
            (last_nonempty_index, first_empty_index, MoveDirectionality.UNIDIRECTIONAL)
        )

        last_nonempty_index -= 1

    # Now removed requests & gaps left by non-removed requests that got
    # moved downward are grouped consecutively in the upper indices of
    # the persistent batch. Truncate them to get condensed persistent batch
    condensed_batch_size = batch_size + num_step_add - num_step_remove
    persistent_batch[:] = persistent_batch[0:condensed_batch_size]

    if condensed_batch_size > 1:
        # Simulate arbitrary batch ordering in the kernel backend
        # Generate a random number k of non-overlapping swap tuples
        k = random.randint(0, condensed_batch_size // 2)
        idxs = list(range(condensed_batch_size))
        random.shuffle(idxs)
        swaps = [tuple(sorted([idxs[2 * i], idxs[2 * i + 1]])) for i in range(k)]
        batch_update_builder.moved.extend(
            [(sw[0], sw[1], MoveDirectionality.SWAP) for sw in swaps]
        )
        for adx, bdx in swaps:
            persistent_batch[adx], persistent_batch[bdx] = (
                persistent_batch[bdx],
                persistent_batch[adx],
            )

    return (
        batch_update_builder.get_and_reset(condensed_batch_size),
        wdx,
        workload_size - wdx,
    )
|
||||
|
||||
|
||||
def _assert_valid(
    batch_size: int,
    persistent_batch: list[LogitsProcsRequestParams],
    test_fakes: LogitsprocsTestFakes,
    slice_idxs: list[int],
    logits_w_lp: torch.Tensor,
    step_idx: int,
) -> None:
    """Validate post-logitsprocs logits for every request in the batch."""
    if not slice_idxs:
        # Trivial case: the persistent batch is empty, so the logitsprocs
        # output must be an empty batch as well.
        assert len(persistent_batch) == 0
        if logits_w_lp.shape[0] != 0:
            raise ValueError(
                "Fake persistent batch is empty but logitsprocs "
                f"output batch has shape {logits_w_lp.shape}"
            )
        return

    # Dispatch each request to the validator registered for its enabled
    # logitproc (or the "none" validator).
    for batch_index, request_params in enumerate(persistent_batch[:batch_size]):
        validate = logitsprocs_test_mapping[request_params.logitproc_type].eval_fxn
        validate(
            test_fakes=test_fakes,
            persistent_batch=persistent_batch,
            logits_new=logits_w_lp,
            batch_index=batch_index,
            request_params=request_params,
            step_idx=step_idx,
        )
|
||||
|
||||
|
||||
@create_new_process_for_each_test()
@pytest.mark.parametrize("device", CUDA_DEVICES)
@pytest.mark.parametrize("reqs_per_logitproc", [REQS_PER_LOGITPROC])
@pytest.mark.parametrize("logitsprocs_under_test", _get_test_cases())
def test_logitsprocs(
    device: str, reqs_per_logitproc: int, logitsprocs_under_test: list[str]
):
    """Correctness test for the builtin logitsprocs.

    Streams a shuffled workload of requests through a fake persistent batch,
    applying a randomized add/remove/condense/swap batch update each step,
    then validates every request's post-logitsprocs logits using the
    validator registered for its enabled logitproc. Runs until the whole
    workload has been added and the batch has drained.
    """
    random.seed(40)
    torch.set_default_device(device)

    # Define a shuffled batch of requests which individually use a different
    # logitproc, or no logitproc at all
    workload_params = _generate_mixed_logitsprocs_batch_params(
        reqs_per_logitproc=reqs_per_logitproc, logitsprocs_types=logitsprocs_under_test
    )
    workload_size = len(workload_params)

    # Create fake test data structures for testing.
    test_fakes = _generate_test_fakes(workload_size, device)

    wdx = 0  # Next request index in workload to add
    persistent_batch: list[
        LogitsProcsRequestParams
    ] = []  # Persistent batch state, as list of workload indices

    # Generate fake removed request indices from current persistent
    # batch before adds
    batch_update_builder = BatchUpdateBuilder()

    # Break when entire workload has been added previously and persistent
    # batch is empty
    workload_reqs_remaining = workload_size
    batch_size = 0
    step_idx = 0
    while True:
        if not (workload_reqs_remaining or batch_size):
            break

        (
            batch_update,
            wdx,
            workload_reqs_remaining,
        ) = _generate_fake_step_update(
            persistent_batch=persistent_batch,
            workload_params=workload_params,
            wdx=wdx,
            batch_update_builder=batch_update_builder,
        )
        batch_size = len(persistent_batch)

        # Apply fake batch update to logitsprocs
        fake_update_logitsprocs_state(test_fakes, batch_update)

        # Emulate application of logits processors in engine
        slice_idxs = [req.workload_index for req in persistent_batch]
        logits_w_lp = fake_apply_logitsprocs(test_fakes, slice_idxs).cpu()

        _assert_valid(
            batch_size=batch_size,
            persistent_batch=persistent_batch,
            test_fakes=test_fakes,
            slice_idxs=slice_idxs,
            logits_w_lp=logits_w_lp,
            step_idx=step_idx,
        )

        step_idx += 1
||||
295
tests/v1/logits_processors/test_custom_offline.py
Normal file
295
tests/v1/logits_processors/test_custom_offline.py
Normal file
@@ -0,0 +1,295 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import random
|
||||
from typing import Any
|
||||
|
||||
import pytest
|
||||
|
||||
from tests.utils import create_new_process_for_each_test
|
||||
from tests.v1.logits_processors.utils import (
|
||||
DUMMY_LOGITPROC_ARG,
|
||||
DUMMY_LOGITPROC_FQCN,
|
||||
MAX_TOKENS,
|
||||
MODEL_NAME,
|
||||
POOLING_MODEL_NAME,
|
||||
TEMP_GREEDY,
|
||||
CustomLogitprocSource,
|
||||
DummyLogitsProcessor,
|
||||
WrappedPerReqLogitsProcessor,
|
||||
prompts,
|
||||
)
|
||||
from tests.v1.logits_processors.utils import entry_points as fake_entry_points
|
||||
from vllm import LLM, SamplingParams
|
||||
from vllm.v1.sample.logits_processor import (
|
||||
STR_POOLING_REJECTS_LOGITSPROCS,
|
||||
STR_SPEC_DEC_REJECTS_LOGITSPROCS,
|
||||
LogitsProcessor,
|
||||
)
|
||||
|
||||
# Create a mixture of requests which do and don't utilize the dummy logitproc:
# requests alternate between passing a target token via `extra_args` (which
# activates the dummy logitproc when it is loaded) and passing no extra args.
sampling_params_list = [
    SamplingParams(
        temperature=TEMP_GREEDY,
        max_tokens=MAX_TOKENS,
        extra_args={DUMMY_LOGITPROC_ARG: 128},
    ),
    SamplingParams(temperature=TEMP_GREEDY, max_tokens=MAX_TOKENS),
    SamplingParams(
        temperature=TEMP_GREEDY,
        max_tokens=MAX_TOKENS,
        extra_args={DUMMY_LOGITPROC_ARG: 67},
    ),
    SamplingParams(temperature=TEMP_GREEDY, max_tokens=MAX_TOKENS),
]
|
||||
|
||||
|
||||
def _run_test(kwargs: dict, logitproc_loaded: bool) -> None:
    """Compare `LLM` instance initialized with specified `kwargs` against
    reference `LLM` instance.

    Two scenarios:
    1. Server has loaded dummy logitproc; test that requests which specify
       dummy logitproc arg value behave as if logitproc is operating (output
       token value should repeat), while requests that don't specify dummy
       logitproc arg value should match reference `LLM` output.
    2. Server has *not* loaded dummy logitproc; test that all requests
       behave as if logitproc is *not* operating (output matches reference
       `LLM` output.)

    Args:
        kwargs: `LLM` constructor kwargs
        logitproc_loaded: server has loaded dummy logitproc if True
    """

    # Create a vLLM instance and load custom logitproc
    llm_logitproc = LLM(
        model=MODEL_NAME,
        gpu_memory_utilization=0.1,
        **kwargs,
    )

    # Create a reference vLLM instance without custom logitproc
    llm_ref = LLM(model=MODEL_NAME, gpu_memory_utilization=0.1)

    # Run inference with logitproc loaded
    outputs_logitproc = llm_logitproc.generate(prompts, sampling_params_list)

    # Reference run
    outputs_ref = llm_ref.generate(prompts, sampling_params_list)

    # Validate outputs
    for bdx, (out_lp, out_ref, params) in enumerate(
        zip(outputs_logitproc, outputs_ref, sampling_params_list)
    ):
        lp_toks = out_lp.outputs[0].token_ids
        if logitproc_loaded and params.extra_args:
            # This request exercises custom logitproc; validate that logitproc
            # forces `target_token` to be decoded in each step
            target_token = params.extra_args[DUMMY_LOGITPROC_ARG]
            if not all(x == target_token for x in lp_toks):
                raise AssertionError(
                    f"Request {bdx} generated {lp_toks}, should all be {target_token}"
                )
        else:
            # This request does not exercise custom logitproc (or custom
            # logitproc is not enabled on this server); validate against
            # reference result
            ref_toks = out_ref.outputs[0].token_ids
            if lp_toks != ref_toks:
                raise AssertionError(
                    f"Request {bdx} generated {lp_toks}, should match {ref_toks}"
                )
|
||||
|
||||
|
||||
@create_new_process_for_each_test()
@pytest.mark.parametrize("logitproc_source", list(CustomLogitprocSource))
def test_custom_logitsprocs(monkeypatch, logitproc_source: CustomLogitprocSource):
    """Test offline Python interface for passing custom logitsprocs

    Construct an `LLM` instance which loads a custom logitproc that has a
    well-defined behavior (mask out all tokens except one `target_token`)

    Construct a reference `LLM` instance with no custom logitproc

    Pass in a batch of requests, 50% of which pass a `target_token` value
    in through `SamplingParams.extra_args`, 50% of which do not.

    Validate that
    * Requests which do not activate the custom logitproc, yield the same
      results for both `LLM` instances
    * Requests which activate the custom logitproc, only output `target_token`

    Test four scenarios, corresponding to `logitproc_source` value
    * No logitsprocs loaded - test that generated tokens match reference `LLM`
      instance output
    * Logitproc passed in via {entrypoint, class object, fully-qualified class
      name (FQCN)} - test that dummy logitproc is utilized correctly when
      provided via any of these three possible sources

    Args:
        monkeypatch: for setting env vars
        logitproc_source: what source (entrypoint, fully-qualified class name
                          (FQCN), class object, or None) the user pulls the
                          logitproc from
    """

    # Test that logitproc info is passed to workers
    monkeypatch.setenv("VLLM_ENABLE_V1_MULTIPROCESSING", "1")
    random.seed(40)

    # Choose LLM args based on logitproc source
    if logitproc_source == CustomLogitprocSource.LOGITPROC_SOURCE_NONE:
        # Scenario: the server does not load any custom logitproc
        # Every other scenario is a different way of loading a custom logitproc
        _run_test({}, logitproc_loaded=False)
        return

    if logitproc_source == CustomLogitprocSource.LOGITPROC_SOURCE_ENTRYPOINT:
        # Scenario: vLLM loads a logitproc from a preconfigured entrypoint
        # To that end, mock a dummy logitproc entrypoint
        import importlib.metadata

        importlib.metadata.entry_points = fake_entry_points  # type: ignore

        # fork is required for workers to see entrypoint patch
        monkeypatch.setenv("VLLM_WORKER_MULTIPROC_METHOD", "fork")
        _run_test({}, logitproc_loaded=True)
        return

    kwargs: dict[str, list[str | type[LogitsProcessor]]] = {}
    if logitproc_source == CustomLogitprocSource.LOGITPROC_SOURCE_FQCN:
        # Scenario: load logitproc based on fully-qualified class name (FQCN)
        kwargs["logits_processors"] = [DUMMY_LOGITPROC_FQCN]
    elif logitproc_source == CustomLogitprocSource.LOGITPROC_SOURCE_CLASS:
        # Scenario: load logitproc from provided class object
        kwargs["logits_processors"] = [DummyLogitsProcessor]

    _run_test(kwargs, logitproc_loaded=True)
|
||||
|
||||
|
||||
@create_new_process_for_each_test()
def test_custom_logitsprocs_req(monkeypatch):
    """Test passing request-level logits processor to offline Python interface

    A request-level logits processor with well-defined behavior (mask out all
    tokens except one `target_token`) is wrapped into a batch-level logits
    processor and handed to `LLM` as a class object.

    A reference `LLM` instance is constructed with no custom logitproc.

    A batch of requests is submitted; half supply a `target_token` value via
    `SamplingParams.extra_args`, half do not.

    Expectations:
    * Requests which do not activate the custom logitproc match the reference
      `LLM` instance's outputs
    * Requests which activate the custom logitproc emit only `target_token`

    Args:
        monkeypatch: for setting env vars
    """

    # Multiprocessing on, so the logitproc must survive the trip to workers
    monkeypatch.setenv("VLLM_ENABLE_V1_MULTIPROCESSING", "1")
    random.seed(40)
    llm_kwargs = {"logits_processors": [WrappedPerReqLogitsProcessor]}
    _run_test(llm_kwargs, logitproc_loaded=True)
|
||||
|
||||
|
||||
@create_new_process_for_each_test()
@pytest.mark.parametrize("model_scenario", ["pooling", "spec_dec"])
@pytest.mark.parametrize(
    "logitproc_source",
    [
        CustomLogitprocSource.LOGITPROC_SOURCE_ENTRYPOINT,
        CustomLogitprocSource.LOGITPROC_SOURCE_FQCN,
        CustomLogitprocSource.LOGITPROC_SOURCE_CLASS,
    ],
)
def test_rejects_custom_logitsprocs(
    monkeypatch, model_scenario: str, logitproc_source: CustomLogitprocSource
):
    """Validate that vLLM engine initialization properly rejects custom
    logitsprocs when the model is a pooling model or speculative decoding is
    enabled.

    Uses the `LLM` entrypoint; initialization is expected to fail before any
    logitproc is actually loaded.

    Scenario 1 (entrypoint):
    * Mock a logitproc entrypoint
    * Validate that `LLM` comes up WITHOUT loading the logitproc

    Scenario 2 (explicit logitproc):
    * Pass custom logitproc to the `LLM` constructor
      * 2a: via fully-qualified class name (FQCN)
      * 2b: via class object
    * Validate that initialization fails with the appropriate exception

    Args:
        monkeypatch: used to set environment variables
        model_scenario: "pooling" or "spec_dec"
        logitproc_source: what source (entrypoint, fully-qualified class name
                          (FQCN), or class object) the user pulls the
                          logitproc from
    """
    monkeypatch.setenv("VLLM_ENABLE_V1_MULTIPROCESSING", "0")
    random.seed(40)

    # Per-scenario model/runner/expected-error configuration
    scenario_params: dict[str, dict[str, Any]] = {
        "pooling": {
            "runner": "pooling",
            "model": POOLING_MODEL_NAME,
            "error_message": STR_POOLING_REJECTS_LOGITSPROCS,
            "speculative_config": None,
        },
        "spec_dec": {
            "runner": "auto",
            "model": MODEL_NAME,
            "error_message": STR_SPEC_DEC_REJECTS_LOGITSPROCS,
            "speculative_config": {"model": "ngram", "num_speculative_tokens": 1},
        },
    }
    cfg = scenario_params[model_scenario]

    llm_kwargs: dict[str, Any] = {
        "runner": cfg["runner"],
        "model": cfg["model"],
        "gpu_memory_utilization": 0.1,
        "speculative_config": cfg["speculative_config"],
    }

    if logitproc_source == CustomLogitprocSource.LOGITPROC_SOURCE_ENTRYPOINT:
        # The engine should come up but ignore the logitproc that is
        # available at the (patched-in) entrypoint
        import importlib.metadata

        importlib.metadata.entry_points = fake_entry_points  # type: ignore

        # fork keeps the entrypoint patch visible to workers, although they
        # should ignore the entrypoint patch anyway
        monkeypatch.setenv("VLLM_WORKER_MULTIPROC_METHOD", "fork")

        llm = LLM(**llm_kwargs)
        # Require that no logitsprocs have been loaded
        worker = llm.llm_engine.model_executor.driver_worker.worker
        loaded = worker.model_runner.input_batch.logitsprocs.all
        assert sum(1 for _ in loaded) == 0
        return

    if logitproc_source == CustomLogitprocSource.LOGITPROC_SOURCE_FQCN:
        # Pass the logitproc by fully-qualified class name (FQCN)
        llm_kwargs["logits_processors"] = [DUMMY_LOGITPROC_FQCN]
    elif logitproc_source == CustomLogitprocSource.LOGITPROC_SOURCE_CLASS:
        # Pass the logitproc as a class object
        llm_kwargs["logits_processors"] = [DummyLogitsProcessor]

    with pytest.raises(ValueError, match=cfg["error_message"]):
        # Loading a model alongside the logitproc must raise the
        # appropriate exception
        LLM(**llm_kwargs)
|
||||
200
tests/v1/logits_processors/test_custom_online.py
Normal file
200
tests/v1/logits_processors/test_custom_online.py
Normal file
@@ -0,0 +1,200 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import os
|
||||
import random
|
||||
import sys
|
||||
from typing import Any
|
||||
|
||||
import openai
|
||||
import pytest
|
||||
import pytest_asyncio
|
||||
|
||||
from tests.utils import RemoteOpenAIServerCustom, create_new_process_for_each_test
|
||||
from tests.v1.logits_processors.utils import (
|
||||
DUMMY_LOGITPROC_ARG,
|
||||
DUMMY_LOGITPROC_FQCN,
|
||||
MAX_TOKENS,
|
||||
MODEL_NAME,
|
||||
TEMP_GREEDY,
|
||||
prompts,
|
||||
)
|
||||
from tests.v1.logits_processors.utils import entry_points as fake_entry_points
|
||||
|
||||
|
||||
def _server_with_logitproc_entrypoint(
    env_dict: dict[str, str] | None,
    model: str,
    vllm_serve_args: list[str],
) -> None:
    """Start vLLM server, inject dummy logitproc entrypoint"""

    # Patch `entry_points` before vLLM is imported, so the dummy logitproc
    # entrypoint is discoverable at load time
    import importlib.metadata

    importlib.metadata.entry_points = fake_entry_points  # type: ignore
    from vllm.entrypoints.cli import main

    # Workers only inherit the entrypoint patch under fork
    os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "fork"
    if env_dict is not None:
        os.environ.update(env_dict)

    # Emulate `vllm serve <model> <CLI args>`
    sys.argv = ["vllm", "serve", model, *vllm_serve_args]
    main.main()
|
||||
|
||||
|
||||
def _server_with_logitproc_fqcn(
    env_dict: dict[str, str] | None,
    model: str,
    vllm_serve_args: list[str],
) -> None:
    """Start vLLM server, inject module with dummy logitproc"""
    from vllm.entrypoints.cli import main

    if env_dict is not None:
        os.environ.update(env_dict)

    # Emulate `vllm serve <model> <CLI args>`
    sys.argv = ["vllm", "serve", model, *vllm_serve_args]
    main.main()
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
def default_server_args():
    """CLI args shared by every server configuration in this module."""
    # Half precision plus a small context window and batch size keep CI
    # memory usage and runtime down
    args = ["--dtype", "bfloat16"]
    args += ["--max-model-len", "2048"]
    args += ["--max-num-seqs", "128"]
    return args
|
||||
|
||||
|
||||
@pytest.fixture(
    scope="function", params=[[], ["--logits-processors", DUMMY_LOGITPROC_FQCN]]
)
def server(default_server_args, request, monkeypatch):
    """Consider two server configurations:
    (1) --logits-processors cli arg specifies dummy logits processor via fully-
    qualified class name (FQCN); patch in a dummy logits processor module
    (2) No --logits-processors cli arg; patch in a dummy logits processor
    entrypoint
    """

    # Multiprocessing on, so logitproc info must be shipped to workers
    monkeypatch.setenv("VLLM_ENABLE_V1_MULTIPROCESSING", "1")

    if request.param:
        # FQCN flavor: append the --logits-processors CLI arg and launch via
        # the module-injecting helper
        launch_args = [*default_server_args, *request.param]
        launcher = _server_with_logitproc_fqcn
    else:
        # Entrypoint flavor: plain args, launch via the entrypoint-patching
        # helper
        launch_args = default_server_args
        launcher = _server_with_logitproc_entrypoint

    with RemoteOpenAIServerCustom(MODEL_NAME, launch_args, launcher) as remote_server:
        yield remote_server
|
||||
|
||||
|
||||
@pytest_asyncio.fixture
async def client(server):
    """Yield an async OpenAI client bound to the running test server."""
    async with server.get_async_client() as async_client:
        yield async_client
|
||||
|
||||
|
||||
# General request argument values for these tests
api_keyword_args = {
    # Greedy sampling ensures that requests which receive the `target_token`
    # arg will decode it in every step
    "temperature": TEMP_GREEDY,
    # Cap generation length, since EOS will never be decoded (unless
    # `target_token` happens to be EOS) and decoding would otherwise not stop
    "max_tokens": MAX_TOKENS,
    # Return decoded token logprobs (as a way of getting token id)
    "logprobs": 0,
}
|
||||
|
||||
|
||||
@create_new_process_for_each_test()
@pytest.mark.asyncio
@pytest.mark.parametrize(
    "model_name",
    [MODEL_NAME],
)
async def test_custom_logitsprocs(client: openai.AsyncOpenAI, model_name: str):
    """Test custom logitsprocs when starting OpenAI server from CLI

    Launch vLLM OpenAI-compatible server, configured to load a custom logitproc
    that has a well-defined behavior (mask out all tokens except one
    `target_token`).

    Pass in requests, 50% of which pass a `target_token` value
    in through `extra_body["vllm_xargs"]`, 50% of which do not.

    Validate that requests which activate the custom logitproc, repeat the same
    token

    Args:
        client: async OpenAI client bound to the test server
        model_name: model the requests target
    """

    use_dummy_logitproc = True
    for prompt in prompts:
        # Build request arguments
        request_keyword_args: dict[str, Any] = {
            **api_keyword_args,
        }
        if use_dummy_logitproc:
            # 50% of requests pass target_token custom arg
            # For requests which activate the dummy logitproc, choose one of
            # two `target_token` values which are known not to be EOS tokens
            target_token = random.choice([128, 67])
            request_keyword_args["extra_body"] = {
                "vllm_xargs": {DUMMY_LOGITPROC_ARG: target_token}
            }
        batch = await client.completions.create(
            model=model_name,
            prompt=prompt,
            **request_keyword_args,
        )

        if use_dummy_logitproc:
            # Only for requests which activate dummy logitproc - validate that
            # output token is repeated
            choices: list[openai.types.CompletionChoice] = batch.choices
            toks = choices[0].logprobs.tokens
            if not all([x == toks[0] for x in toks]):
                raise AssertionError(f"Generated {toks} should all be {toks[0]}")

        # Alternate whether to activate dummy logitproc for each request
        use_dummy_logitproc = not use_dummy_logitproc
|
||||
|
||||
|
||||
@pytest.mark.asyncio
@pytest.mark.parametrize(
    "model_name",
    [MODEL_NAME],
)
async def test_invalid_custom_logitsproc_arg(
    client: openai.AsyncOpenAI, model_name: str
):
    """Test that request with invalid custom logitsproc is rejected"""

    # A non-int `target_token` value must be rejected by the dummy logits
    # processor's param validation before any decoding happens
    bad_request_args: dict[str, Any] = dict(api_keyword_args)
    bad_request_args["extra_body"] = {
        "vllm_xargs": {DUMMY_LOGITPROC_ARG: "invalid_target_token_value"}
    }

    with pytest.raises(openai.OpenAIError) as err_info:
        await client.completions.create(
            model=model_name,
            prompt="Hello, my name is",
            **bad_request_args,
        )

    assert "is not int" in str(err_info.value)
|
||||
191
tests/v1/logits_processors/utils.py
Normal file
191
tests/v1/logits_processors/utils.py
Normal file
@@ -0,0 +1,191 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import types
|
||||
from enum import Enum, auto
|
||||
from typing import Any
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.logger import init_logger
|
||||
from vllm.sampling_params import SamplingParams
|
||||
from vllm.v1.sample.logits_processor import (
|
||||
LOGITSPROCS_GROUP,
|
||||
AdapterLogitsProcessor,
|
||||
BatchUpdate,
|
||||
LogitsProcessor,
|
||||
RequestLogitsProcessor,
|
||||
)
|
||||
from vllm.v1.sample.logits_processor.builtin import process_dict_updates
|
||||
|
||||
logger = init_logger(__name__)

# Decoder model exercised by the custom-logitproc tests
MODEL_NAME = "facebook/opt-125m"
# Pooling model used to check that pooling runners reject logitsprocs
POOLING_MODEL_NAME = "BAAI/bge-base-en-v1.5"
# Name of the custom arg (in SamplingParams.extra_args / vllm_xargs) that
# activates the dummy logitproc
DUMMY_LOGITPROC_ARG = "target_token"
# Temperature 0.0 -> greedy sampling
TEMP_GREEDY = 0.0
MAX_TOKENS = 20
DUMMY_LOGITPROC_ENTRYPOINT = "dummy_logitproc"
DUMMY_LOGITPROC_MODULE = "tests.v1.logits_processors.utils"
# "module:Class" form consumed by the --logits-processors / logits_processors
# FQCN loading path
DUMMY_LOGITPROC_FQCN = f"{DUMMY_LOGITPROC_MODULE}:DummyLogitsProcessor"
|
||||
|
||||
|
||||
class CustomLogitprocSource(Enum):
    """How to source a logitproc for testing purposes"""

    LOGITPROC_SOURCE_NONE = auto()  # No custom logitproc
    LOGITPROC_SOURCE_ENTRYPOINT = auto()  # Via entrypoint
    LOGITPROC_SOURCE_FQCN = auto()  # Via fully-qualified class name (FQCN)
    LOGITPROC_SOURCE_CLASS = auto()  # Via provided class object
|
||||
|
||||
|
||||
# Sample prompts. Tests iterate over these, alternating whether the dummy
# logitproc is activated, yielding a 50/50 activation split.
prompts = [
    "Hello, my name is",
    "The president of the United States is",
    "The capital of France is",
    "The future of AI is",
]
|
||||
|
||||
|
||||
class DummyLogitsProcessor(LogitsProcessor):
    """Fake logit processor to support unit testing and examples.

    For any request that supplies an integer `target_token` custom arg, masks
    out all logits except the one at index `target_token`.
    """

    @classmethod
    def validate_params(cls, params: SamplingParams):
        """Reject requests whose `target_token` custom arg is not an int.

        Args:
            params: per-request sampling params

        Raises:
            ValueError: if `target_token` is present and not an int
        """
        # Explicit None-guard: the previous `extra_args and extra_args.get(...)`
        # idiom evaluated to `{}` (not None) when extra_args was an empty dict,
        # so the isinstance check below raised spuriously for such requests.
        target_token: int | None = (
            params.extra_args.get("target_token") if params.extra_args else None
        )
        if target_token is not None and not isinstance(target_token, int):
            raise ValueError(
                f"target_token value {target_token} {type(target_token)} is not int"
            )

    def __init__(
        self, vllm_config: "VllmConfig", device: torch.device, is_pin_memory: bool
    ):
        # Maps batch index -> target token id for requests with the logitproc
        # active; config/device/pin-memory args are part of the LogitsProcessor
        # construction contract but unused here
        self.req_info: dict[int, int] = {}

    def is_argmax_invariant(self) -> bool:
        """Never impacts greedy sampling"""
        return False

    def update_state(self, batch_update: BatchUpdate | None):
        def extract_extra_arg(params: SamplingParams) -> int | None:
            self.validate_params(params)
            # Same None-guard as validate_params: empty extra_args means
            # "logitproc not activated", not a `{}` sentinel
            if not params.extra_args:
                return None
            return params.extra_args.get("target_token")

        process_dict_updates(
            self.req_info,
            batch_update,
            # Only the sampling params are needed to decide activation
            lambda params, _, __: extract_extra_arg(params),
        )

    def apply(self, logits: torch.Tensor) -> torch.Tensor:
        if not self.req_info:
            return logits

        # Save target values before modification
        cols = torch.tensor(
            list(self.req_info.values()), dtype=torch.long, device=logits.device
        )
        rows = torch.tensor(
            list(self.req_info.keys()), dtype=torch.long, device=logits.device
        )
        values_to_keep = logits[rows, cols].clone()

        # Mask all but target tokens
        logits[rows] = float("-inf")
        logits[rows, cols] = values_to_keep

        return logits
|
||||
|
||||
|
||||
"""Dummy module with dummy logitproc class"""
|
||||
dummy_module = types.ModuleType(DUMMY_LOGITPROC_MODULE)
|
||||
dummy_module.DummyLogitsProcessor = DummyLogitsProcessor # type: ignore
|
||||
|
||||
|
||||
class EntryPoint:
    """Dummy entrypoint class for logitsprocs testing"""

    def __init__(self):
        # Attribute shape mirrors importlib.metadata.EntryPoint:
        # a name plus a "module:Class" value
        self.name, self.value = DUMMY_LOGITPROC_ENTRYPOINT, DUMMY_LOGITPROC_FQCN

    def load(self):
        """Resolve the entrypoint to the dummy logitproc class."""
        return DummyLogitsProcessor
|
||||
|
||||
|
||||
class EntryPoints(list):
    """Dummy EntryPoints class for logitsprocs testing"""

    def __init__(self, group: str):
        # Only the logitsprocs group exposes the (single) dummy entrypoint;
        # list inheritance emulates the iterable behavior
        members = [EntryPoint()] if group == LOGITSPROCS_GROUP else []
        super().__init__(members)
        # `names` attribute mirrors importlib.metadata.EntryPoints
        self.names = [member.name for member in members]
|
||||
|
||||
|
||||
class DummyPerReqLogitsProcessor:
    """Request-level logits processor which masks out every logit except the
    one at index `target_token`."""

    def __init__(self, target_token: int) -> None:
        """Store the token id whose logit survives masking."""
        self.target_token = target_token

    def __call__(
        self,
        output_ids: list[int],
        logits: torch.Tensor,
    ) -> torch.Tensor:
        # Remember the surviving logit, flood the tensor with -inf in place,
        # then restore the saved value
        keep = logits[self.target_token].item()
        logits.fill_(float("-inf"))
        logits[self.target_token] = keep
        return logits
|
||||
|
||||
|
||||
class WrappedPerReqLogitsProcessor(AdapterLogitsProcessor):
    """Example of wrapping a fake request-level logit processor to create a
    batch-level logits processor"""

    def is_argmax_invariant(self) -> bool:
        # Masking logits changes the argmax, so greedy sampling is affected
        return False

    def new_req_logits_processor(
        self,
        params: SamplingParams,
    ) -> RequestLogitsProcessor | None:
        """This method returns a new request-level logits processor, customized
        to the `target_token` value associated with a particular request.

        Returns None if the logits processor should not be applied to the
        particular request. To use the logits processor the request must have
        a "target_token" custom argument with an integer value.

        Args:
            params: per-request sampling params

        Returns:
            `Callable` request logits processor, or None
        """
        raw: Any | None = params.extra_args and params.extra_args.get(
            "target_token"
        )
        if raw is None:
            return None
        if isinstance(raw, int):
            return DummyPerReqLogitsProcessor(raw)
        # Non-int values are skipped with a warning rather than failing
        logger.warning(
            "target_token value %s is not int; not applying logits"
            " processor to request.",
            raw,
        )
        return None
|
||||
|
||||
|
||||
"""Fake version of importlib.metadata.entry_points"""
|
||||
entry_points = lambda group: EntryPoints(group)
|
||||
Reference in New Issue
Block a user