forked from EngineX-Cambricon/enginex-mlu370-vllm
add qwen3
0    vllm-v0.6.2/tests/tokenization/__init__.py    Normal file
22   vllm-v0.6.2/tests/tokenization/test_cached_tokenizer.py    Normal file
@@ -0,0 +1,22 @@
from copy import deepcopy

from transformers import AutoTokenizer

from vllm.transformers_utils.tokenizer import get_cached_tokenizer


def test_cached_tokenizer():
    reference_tokenizer = AutoTokenizer.from_pretrained("gpt2")
    reference_tokenizer.add_special_tokens({"cls_token": "<CLS>"})
    reference_tokenizer.add_special_tokens(
        {"additional_special_tokens": ["<SEP>"]})
    cached_tokenizer = get_cached_tokenizer(deepcopy(reference_tokenizer))

    assert reference_tokenizer.encode("prompt") == cached_tokenizer.encode(
        "prompt")
    assert set(reference_tokenizer.all_special_ids) == set(
        cached_tokenizer.all_special_ids)
    assert set(reference_tokenizer.all_special_tokens) == set(
        cached_tokenizer.all_special_tokens)
    assert set(reference_tokenizer.all_special_tokens_extended) == set(
        cached_tokenizer.all_special_tokens_extended)
320  vllm-v0.6.2/tests/tokenization/test_detokenize.py    Normal file
@@ -0,0 +1,320 @@
from typing import Any, Dict, Generator, List, Optional

import pytest
from transformers import AutoTokenizer

from vllm.inputs import token_inputs
from vllm.sequence import Logprob, SamplingParams, Sequence, SequenceGroup
from vllm.transformers_utils.detokenizer import (Detokenizer,
                                                 detokenize_incrementally)
from vllm.transformers_utils.tokenizer_group import get_tokenizer_group
from vllm.transformers_utils.tokenizers.mistral import MistralTokenizer

TRUTH = [
    "Hello here, this is a simple test",
    "vLLM is a high-throughput and memory-efficient inference and serving engine for LLMs. It is designed to be used in production environments, where inference and serving",  # noqa
    "我很感谢你的热情",
    # Burmese text triggers an edge-case for Mistral's V3-Tekken tokenizer (eg.
    # for mistralai/Pixtral-12B-2409) where tokens may map to bytes with
    # incomplete UTF-8 characters
    # see https://github.com/vllm-project/vllm/pull/9625
    "ပုံပြင်လေးပြောပြပါ်",
]
TOKENIZERS = [
    "facebook/opt-125m",
    "gpt2",
    "bigcode/tiny_starcoder_py",
    "EleutherAI/gpt-j-6b",
    "EleutherAI/pythia-70m",
    "bigscience/bloom-560m",
    "mosaicml/mpt-7b",
    "tiiuae/falcon-7b",
    "meta-llama/Llama-2-7b-hf",
    "codellama/CodeLlama-7b-hf",
    # "mistralai/Pixtral-12B-2409",
]


def _run_incremental_decode(tokenizer, all_input_ids,
                            skip_special_tokens: bool, starting_index: int):
    decoded_text = ""
    offset = 0
    token_offset = 0
    prev_tokens = None
    for i in range(starting_index, len(all_input_ids)):
        new_tokens, text, offset, token_offset = detokenize_incrementally(
            tokenizer,
            all_input_ids[:i + 1],
            prev_tokens,
            offset,
            token_offset,
            skip_special_tokens=skip_special_tokens)
        decoded_text += text
        if prev_tokens is None:
            prev_tokens = new_tokens
        else:
            prev_tokens += new_tokens
    return decoded_text


@pytest.fixture
def tokenizer(tokenizer_name):
    return (MistralTokenizer.from_pretrained(tokenizer_name)
            if "mistral" in tokenizer_name else
            AutoTokenizer.from_pretrained(tokenizer_name))


@pytest.mark.skip("Do not support Pixtral-12B-2409.")
@pytest.mark.parametrize("tokenizer_name", ["mistralai/Pixtral-12B-2409"])
@pytest.mark.parametrize(
    "truth",
    [
        # Burmese text triggers an edge-case where tokens may map to bytes with
        # incomplete UTF-8 characters
        "ပုံပြင်လေးပြောပြပါ",
        # Using "URGENCY" since "CY" has token id 130282
        "URGENCY🌶️",
    ])
def test_mistral_edge_case(tokenizer, truth):
    """Test for a specific edge cases with V3-Tekken MistralTokenizer.

    See https://github.com/vllm-project/vllm/pull/9625
    """
    starting_index = 0
    all_input_ids = tokenizer(truth, add_special_tokens=False).input_ids

    decoded_text = _run_incremental_decode(tokenizer,
                                           all_input_ids,
                                           skip_special_tokens=True,
                                           starting_index=starting_index)
    assert decoded_text == truth


@pytest.fixture
def skip_special_tokens(request, tokenizer_name) -> Generator[bool, Any, None]:
    if "mistral" in tokenizer_name:
        yield (
            True if request.param else
            pytest.skip("mistral doesn't support skip_special_tokens=False"))
    else:
        yield bool(request.param)


@pytest.mark.parametrize("truth", TRUTH)
@pytest.mark.parametrize("with_prompt", [True, False])
@pytest.mark.parametrize("tokenizer_name", TOKENIZERS)
@pytest.mark.parametrize("skip_special_tokens", (True, False), indirect=True)
def test_decode_streaming(tokenizer, truth, with_prompt, skip_special_tokens):
    if with_prompt:
        truth_tokens = tokenizer(truth, add_special_tokens=False).input_ids
        prompt_input_ids = truth_tokens[:len(truth) // 2]
        generated_input_ids = truth_tokens[len(truth) // 2:]
        all_input_ids = prompt_input_ids + generated_input_ids
        starting_index = len(prompt_input_ids)
        prompt = tokenizer.decode(prompt_input_ids,
                                  skip_special_tokens=skip_special_tokens)
        generated = truth[len(prompt):]
    else:
        generated = truth
        starting_index = 0
        all_input_ids = tokenizer(truth, add_special_tokens=False).input_ids
    if skip_special_tokens:
        if tokenizer.bos_token_id is not None:
            all_input_ids = [tokenizer.bos_token_id] + all_input_ids
            starting_index += 1
        all_input_ids = all_input_ids + [tokenizer.eos_token_id]

    decoded_text = _run_incremental_decode(
        tokenizer,
        all_input_ids,
        skip_special_tokens=skip_special_tokens,
        starting_index=starting_index)

    assert decoded_text == generated

    decoded_text = _run_incremental_decode(
        tokenizer, [len(tokenizer)],
        skip_special_tokens=skip_special_tokens,
        starting_index=starting_index)

    assert decoded_text == ''


@pytest.fixture
def detokenizer(tokenizer_name: str) -> Detokenizer:
    init_kwargs = dict(
        tokenizer_id=tokenizer_name,
        enable_lora=False,
        max_num_seqs=100,
        max_input_length=None,
        tokenizer_mode="mistral" if "mistral" in tokenizer_name else "auto",
        trust_remote_code=False,
        revision=None,
    )

    tokenizer_group = get_tokenizer_group(
        None,
        **init_kwargs,
    )

    return Detokenizer(tokenizer_group)


@pytest.fixture(name="complete_sequence_token_ids")
def create_complete_sequence_token_ids(complete_sequence: str,
                                       tokenizer) -> List[int]:
    complete_sequence_token_ids = tokenizer(complete_sequence).input_ids
    return complete_sequence_token_ids


def create_sequence(prompt_token_ids=None):
    prompt_token_ids = prompt_token_ids or [1]
    return Sequence(
        seq_id=0,
        inputs=token_inputs(prompt_token_ids, prompt="<s>"),
        block_size=16,
    )


def create_dummy_logprobs(
        complete_sequence_token_ids: List[int]) -> List[Dict[int, Logprob]]:
    return [{
        token_id: Logprob(logprob=0.0),
        token_id + 1: Logprob(logprob=0.1)
    } for token_id in complete_sequence_token_ids]


def create_dummy_prompt_logprobs(
        complete_sequence_token_ids: List[int]
) -> List[Optional[Dict[int, Any]]]:
    # logprob for the first prompt token is None.
    logprobs: List[Optional[Dict[int, Any]]] = [None]
    logprobs.extend(create_dummy_logprobs(complete_sequence_token_ids)[1:])
    return logprobs


@pytest.mark.parametrize("complete_sequence", TRUTH)
@pytest.mark.parametrize("tokenizer_name", TOKENIZERS)
@pytest.mark.parametrize("skip_special_tokens", [True, False], indirect=True)
def test_decode_sequence_logprobs(complete_sequence: str,
                                  complete_sequence_token_ids: List[int],
                                  detokenizer: Detokenizer,
                                  skip_special_tokens: bool):
    """Verify Detokenizer decodes logprobs correctly."""
    sampling_params = SamplingParams(skip_special_tokens=skip_special_tokens,
                                     logprobs=2)

    # Run sequentially.
    seq = create_sequence()
    dummy_logprobs = create_dummy_logprobs(complete_sequence_token_ids)
    sequential_logprobs_text_chosen_token: List[str] = []
    sequential_logprobs_text_other_token: List[str] = []
    for new_token, logprobs in zip(complete_sequence_token_ids,
                                   dummy_logprobs):
        seq.append_token_id(new_token, logprobs)
        detokenizer.decode_sequence_inplace(seq, sampling_params)
        sequential_logprobs_text_chosen_token.append(
            seq.output_logprobs[-1][new_token].decoded_token)
        sequential_logprobs_text_other_token.append(
            seq.output_logprobs[-1][new_token + 1].decoded_token)
    sequential_result = seq.output_text

    assert sequential_result == "".join(sequential_logprobs_text_chosen_token)
    assert sequential_result != "".join(sequential_logprobs_text_other_token)

    if skip_special_tokens:
        # Text for logprobs for the chosen token should be the same as the
        # generated text. Note that this will only be true if we skip
        # special tokens.
        assert sequential_result == complete_sequence


@pytest.mark.parametrize("complete_sequence", TRUTH)
@pytest.mark.parametrize("tokenizer_name", TOKENIZERS)
def test_decode_prompt_logprobs(complete_sequence_token_ids: List[int],
                                detokenizer: Detokenizer):
    """Verify Detokenizer decodes prompt logprobs correctly."""
    sampling_params = SamplingParams(skip_special_tokens=True,
                                     prompt_logprobs=1)

    # Run sequentially.
    seq = create_sequence(complete_sequence_token_ids)
    seq_group = SequenceGroup(request_id="1",
                              seqs=[seq],
                              sampling_params=sampling_params,
                              arrival_time=0.0)
    dummy_logprobs = create_dummy_prompt_logprobs(complete_sequence_token_ids)
    detokenizer.decode_prompt_logprobs_inplace(seq_group,
                                               dummy_logprobs,
                                               position_offset=0)
    # First logprob is None.
    decoded_prompt_logprobs: List[Dict[int, Any]] = dummy_logprobs[
        1:]  # type: ignore

    # decoded_prompt_logprobs doesn't contain the first token.
    token_ids = complete_sequence_token_ids
    tokenizer = detokenizer.get_tokenizer_for_seq(seq)
    text_full = tokenizer.decode(token_ids, skip_special_tokens=True)
    text_first = tokenizer.decode(token_ids[0], skip_special_tokens=True)
    text = text_full[len(text_first):]

    # Text for logprobs for the chosen token should be the same as the
    # prompt text. Note that the first logprob is None.
    assert text == "".join([
        logprobs[token_id].decoded_token
        for token_id, logprobs in zip(token_ids[1:], decoded_prompt_logprobs)
    ])
    assert text != "".join([
        logprobs[token_id + 1].decoded_token
        for token_id, logprobs in zip(token_ids[1:], decoded_prompt_logprobs)
    ])


@pytest.mark.parametrize("model", ["facebook/opt-125m"])
@pytest.mark.parametrize("chunked_prefill_token_size", [1, 4, 7, 16, -1])
def test_decode_prompt_logprobs_chunked_prefill(
        vllm_runner,
        model,
        chunked_prefill_token_size: int,
        example_prompts,
):
    max_num_seqs = 256
    enable_chunked_prefill = False
    max_num_batched_tokens = None
    if chunked_prefill_token_size != -1:
        enable_chunked_prefill = True
        max_num_seqs = min(chunked_prefill_token_size, max_num_seqs)
        max_num_batched_tokens = chunked_prefill_token_size

    with vllm_runner(model,
                     dtype="half",
                     max_logprobs=5,
                     gpu_memory_utilization=0.5,
                     enable_chunked_prefill=enable_chunked_prefill,
                     max_num_batched_tokens=max_num_batched_tokens,
                     max_num_seqs=max_num_seqs) as vllm_model:

        vllm_sampling_params = SamplingParams(max_tokens=10,
                                              logprobs=5,
                                              prompt_logprobs=5,
                                              temperature=0.0)
        vllm_results = vllm_model.model.generate(
            example_prompts, sampling_params=vllm_sampling_params)

        for idx, result in enumerate(vllm_results):
            assert result.prompt_logprobs is not None
            assert result.prompt_logprobs[0] is None

            # Compared detokenized prompts ids to original prompt.
            generated_string = ""
            for (prompt_token,
                 prompt_logprobs) in zip(result.prompt_token_ids[1:],
                                         result.prompt_logprobs[1:]):
                # prompt_logprobs is a dict of the token_id: logprob
                # We select the token_id corresponding to the actual prompt
                # Decoded token in the detokenized string corresponding to this
                # prompt token.
                generated_string += prompt_logprobs[prompt_token].decoded_token

            assert generated_string == example_prompts[idx], (
                "Detokenized prompt logprobs do not match original prompt")
||||
31   vllm-v0.6.2/tests/tokenization/test_get_eos.py    Normal file
@@ -0,0 +1,31 @@
"""
This test file includes some cases where it is inappropriate to
only get the `eos_token_id` from the tokenizer as defined by
:meth:`vllm.LLMEngine._get_eos_token_id`.
"""
from vllm.transformers_utils.config import try_get_generation_config
from vllm.transformers_utils.tokenizer import get_tokenizer


def test_get_llama3_eos_token():
    model_name = "meta-llama/Meta-Llama-3-8B-Instruct"

    tokenizer = get_tokenizer(model_name)
    assert tokenizer.eos_token_id == 128009

    generation_config = try_get_generation_config(model_name,
                                                  trust_remote_code=False)
    assert generation_config is not None
    assert generation_config.eos_token_id == [128001, 128009]


def test_get_blip2_eos_token():
    model_name = "Salesforce/blip2-opt-2.7b"

    tokenizer = get_tokenizer(model_name)
    assert tokenizer.eos_token_id == 2

    generation_config = try_get_generation_config(model_name,
                                                  trust_remote_code=False)
    assert generation_config is not None
    assert generation_config.eos_token_id == 50118
20   vllm-v0.6.2/tests/tokenization/test_tokenizer.py    Normal file
@@ -0,0 +1,20 @@
import pytest
from transformers import PreTrainedTokenizerBase

from vllm.transformers_utils.tokenizer import get_tokenizer

TOKENIZER_NAMES = [
    "facebook/opt-125m",
    "gpt2",
]


@pytest.mark.parametrize("tokenizer_name", TOKENIZER_NAMES)
def test_tokenizer_revision(tokenizer_name: str):
    # Assume that "main" branch always exists
    tokenizer = get_tokenizer(tokenizer_name, revision="main")
    assert isinstance(tokenizer, PreTrainedTokenizerBase)

    # Assume that "never" branch always does not exist
    with pytest.raises(OSError, match='not a valid git identifier'):
        get_tokenizer(tokenizer_name, revision="never")
214  vllm-v0.6.2/tests/tokenization/test_tokenizer_group.py    Normal file
@@ -0,0 +1,214 @@
import asyncio
import os
import sys
from typing import List, Optional
from unittest.mock import patch

import pytest
from transformers import AutoTokenizer, PreTrainedTokenizerBase

from vllm.transformers_utils.tokenizer_group import (TokenizerGroup,
                                                     get_tokenizer_group)
from vllm.transformers_utils.tokenizer_group.ray_tokenizer_group import (
    RayTokenizerGroupPool)

from ..conftest import get_tokenizer_pool_config


class CustomTokenizerGroup(TokenizerGroup):

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self._i = 0

    def encode(self, *args, **kwargs):
        self._i += 1
        return super().encode(*args, **kwargs)


@pytest.mark.asyncio
@pytest.mark.parametrize("tokenizer_group_type",
                         [None, "ray", CustomTokenizerGroup])
async def test_tokenizer_group(tokenizer_group_type):
    reference_tokenizer = AutoTokenizer.from_pretrained("gpt2")
    tokenizer_group = get_tokenizer_group(
        get_tokenizer_pool_config(tokenizer_group_type),
        tokenizer_id="gpt2",
        enable_lora=False,
        max_num_seqs=1,
        max_input_length=None,
    )
    assert reference_tokenizer.encode("prompt") == tokenizer_group.encode(
        request_id="request_id", prompt="prompt", lora_request=None)
    assert reference_tokenizer.encode(
        "prompt") == await tokenizer_group.encode_async(
            request_id="request_id", prompt="prompt", lora_request=None)
    assert isinstance(tokenizer_group.get_lora_tokenizer(None),
                      PreTrainedTokenizerBase)
    assert tokenizer_group.get_lora_tokenizer(
        None) == await tokenizer_group.get_lora_tokenizer_async(None)
    if tokenizer_group_type is CustomTokenizerGroup:
        assert tokenizer_group._i > 0


@pytest.mark.asyncio
@pytest.mark.parametrize("tokenizer_group_type", ["ray"])
async def test_tokenizer_group_pool(tokenizer_group_type):
    reference_tokenizer = AutoTokenizer.from_pretrained("gpt2")
    tokenizer_group_pool = get_tokenizer_group(
        get_tokenizer_pool_config(tokenizer_group_type),
        tokenizer_id="gpt2",
        enable_lora=False,
        max_num_seqs=1,
        max_input_length=None,
    )
    # Send multiple requests to the tokenizer group pool
    # (more than the pool size)
    # and check that all requests are processed correctly.
    num_requests = tokenizer_group_pool.pool_size * 5
    requests = [
        tokenizer_group_pool.encode_async(request_id=str(i),
                                          prompt=f"prompt {i}",
                                          lora_request=None)
        for i in range(num_requests)
    ]
    results = await asyncio.gather(*requests)
    expected_results = [
        reference_tokenizer.encode(f"prompt {i}") for i in range(num_requests)
    ]
    assert results == expected_results


@pytest.mark.asyncio
@pytest.mark.parametrize("tokenizer_group_type", ["ray"])
async def test_tokenizer_group_ray_pool_env_var_propagation(
        tokenizer_group_type):
    """Test that env vars from caller process are propagated to
    tokenizer Ray actors."""
    env_var = "MY_ENV_VAR"

    class EnvVarCheckerTokenizerGroup(TokenizerGroup):

        def ping(self):
            assert os.environ.get(env_var) == "1"
            return super().ping()

    class EnvVarCheckerRayTokenizerGroupPool(RayTokenizerGroupPool):
        _worker_cls = EnvVarCheckerTokenizerGroup

    tokenizer_pool_config = get_tokenizer_pool_config(tokenizer_group_type)
    tokenizer_pool = EnvVarCheckerRayTokenizerGroupPool.from_config(
        tokenizer_pool_config,
        tokenizer_id="gpt2",
        enable_lora=False,
        max_num_seqs=1,
        max_input_length=None)
    with pytest.raises(AssertionError):
        tokenizer_pool.ping()

    with patch.dict(os.environ, {env_var: "1"}):
        tokenizer_pool_config = get_tokenizer_pool_config(tokenizer_group_type)
        tokenizer_pool = EnvVarCheckerRayTokenizerGroupPool.from_config(
            tokenizer_pool_config,
            tokenizer_id="gpt2",
            enable_lora=False,
            max_num_seqs=1,
            max_input_length=None)
        tokenizer_pool.ping()


@pytest.mark.asyncio
@pytest.mark.parametrize("tokenizer_group_type", ["ray"])
async def test_tokenizer_group_ray_pool_fault_tolerance(tokenizer_group_type):
    """Test that Ray tokenizer pool group can recover from failures and
    if that's not possible, mark itself as unhealthy."""

    class FailingTokenizerGroup(TokenizerGroup):

        def __init__(self,
                     *args,
                     fail_at: Optional[List[int]] = None,
                     **kwargs):
            super().__init__(*args, **kwargs)
            self.i = 0
            self.fail_at = fail_at or []

        def encode(self, *args, **kwargs):
            self.i += 1
            if self.i in self.fail_at:
                sys.exit(1)
            return super().encode(*args, **kwargs)

    class FailingRayTokenizerGroupPool(RayTokenizerGroupPool):
        _worker_cls = FailingTokenizerGroup

    # Fail at first iteration
    fail_at = [1]
    tokenizer_pool_config = get_tokenizer_pool_config(tokenizer_group_type)
    tokenizer_group_pool = FailingRayTokenizerGroupPool.from_config(
        tokenizer_pool_config,
        tokenizer_id="gpt2",
        enable_lora=False,
        max_num_seqs=1,
        max_input_length=None,
        fail_at=fail_at)
    tokenizer_actors = tokenizer_group_pool.tokenizer_actors.copy()

    # Modify fail at to not fail at all (will be re-read when actor is
    # re-initialized).
    fail_at[0] = 1000

    # We should recover successfully.
    await tokenizer_group_pool.encode_async(request_id="1",
                                            prompt="prompt",
                                            lora_request=None)
    await tokenizer_group_pool.encode_async(request_id="1",
                                            prompt="prompt",
                                            lora_request=None)

    # Check that we have a new actor
    assert len(tokenizer_group_pool.tokenizer_actors) == len(tokenizer_actors)
    assert tokenizer_group_pool.tokenizer_actors != tokenizer_actors

    # Fail at first iteration
    fail_at = [1]
    tokenizer_group_pool = FailingRayTokenizerGroupPool.from_config(
        tokenizer_pool_config,
        tokenizer_id="gpt2",
        enable_lora=False,
        max_num_seqs=1,
        max_input_length=None,
        fail_at=fail_at)

    # We should fail after re-initialization.
    with pytest.raises(RuntimeError):
        await tokenizer_group_pool.encode_async(request_id="1",
                                                prompt="prompt",
                                                lora_request=None)

    # check_health should raise the same thing
    with pytest.raises(RuntimeError):
        tokenizer_group_pool.check_health()

    # Ensure that non-ActorDiedErrors are still propagated correctly and do not
    # cause a re-initialization.
    fail_at = []
    tokenizer_group_pool = FailingRayTokenizerGroupPool.from_config(
        tokenizer_pool_config,
        tokenizer_id="gpt2",
        enable_lora=False,
        max_num_seqs=1,
        max_input_length=2,
        fail_at=fail_at)
    tokenizer_actors = tokenizer_group_pool.tokenizer_actors.copy()

    # Prompt too long error
    with pytest.raises(ValueError):
        await tokenizer_group_pool.encode_async(request_id="1",
                                                prompt="prompt" * 100,
                                                lora_request=None)
    await tokenizer_group_pool.encode_async(request_id="1",
                                            prompt="prompt",
                                            lora_request=None)
    # Actors should stay the same.
    assert tokenizer_group_pool.tokenizer_actors == tokenizer_actors