Fix failed ci tests on long prompts; Better error messages for embedding models (#1700)
This commit is contained in:
@@ -56,6 +56,9 @@ class GenerateReqInput:
|
|||||||
# LoRA related
|
# LoRA related
|
||||||
lora_path: Optional[Union[List[Optional[str]], Optional[str]]] = None
|
lora_path: Optional[Union[List[Optional[str]], Optional[str]]] = None
|
||||||
|
|
||||||
|
# Whether it is a single request or a batch request
|
||||||
|
is_single: bool = True
|
||||||
|
|
||||||
def post_init(self):
|
def post_init(self):
|
||||||
if (self.text is None and self.input_ids is None) or (
|
if (self.text is None and self.input_ids is None) or (
|
||||||
self.text is not None and self.input_ids is not None
|
self.text is not None and self.input_ids is not None
|
||||||
|
|||||||
@@ -150,9 +150,13 @@ class TokenizerManager:
|
|||||||
while self.model_update_lock.locked():
|
while self.model_update_lock.locked():
|
||||||
await asyncio.sleep(0.001)
|
await asyncio.sleep(0.001)
|
||||||
|
|
||||||
|
if isinstance(obj, EmbeddingReqInput) and self.is_generation:
|
||||||
|
raise ValueError(
|
||||||
|
"This model does not appear to be an embedding model by default. Please add `--is-embedding` when launching the server or try another model."
|
||||||
|
)
|
||||||
|
|
||||||
obj.post_init()
|
obj.post_init()
|
||||||
is_single = obj.is_single
|
is_single = obj.is_single
|
||||||
|
|
||||||
if is_single:
|
if is_single:
|
||||||
async for response in self._handle_single_request(obj, request):
|
async for response in self._handle_single_request(obj, request):
|
||||||
yield response
|
yield response
|
||||||
|
|||||||
@@ -542,8 +542,6 @@ def _wait_and_warmup(server_args, pipe_finish_writer, pid):
|
|||||||
kill_child_process(pid, including_parent=False)
|
kill_child_process(pid, including_parent=False)
|
||||||
return
|
return
|
||||||
|
|
||||||
print(f"{res.json()=}")
|
|
||||||
|
|
||||||
logger.info("The server is fired up and ready to roll!")
|
logger.info("The server is fired up and ready to roll!")
|
||||||
if pipe_finish_writer is not None:
|
if pipe_finish_writer is not None:
|
||||||
pipe_finish_writer.send("ready")
|
pipe_finish_writer.send("ready")
|
||||||
|
|||||||
@@ -40,20 +40,23 @@ class ModelCase:
|
|||||||
prefill_tolerance: float = 5e-2
|
prefill_tolerance: float = 5e-2
|
||||||
decode_tolerance: float = 5e-2
|
decode_tolerance: float = 5e-2
|
||||||
rouge_l_tolerance: float = 1
|
rouge_l_tolerance: float = 1
|
||||||
|
skip_long_prompt: bool = False
|
||||||
|
|
||||||
|
|
||||||
# Popular models that run on the CI
|
# Popular models that run on the CI
|
||||||
CI_MODELS = [
|
CI_MODELS = [
|
||||||
ModelCase("meta-llama/Llama-3.1-8B-Instruct"),
|
ModelCase("meta-llama/Llama-3.1-8B-Instruct"),
|
||||||
ModelCase("google/gemma-2-2b"),
|
ModelCase(
|
||||||
|
"google/gemma-2-2b", skip_long_prompt=True
|
||||||
|
), # There is a bug with new transformers library. This can only run with transformers==4.44
|
||||||
]
|
]
|
||||||
|
|
||||||
# All other models that do not run on the CI
|
# All other models that do not run on the CI
|
||||||
ALL_OTHER_MODELS = [
|
ALL_OTHER_MODELS = [
|
||||||
ModelCase("Qwen/Qwen2-1.5B"),
|
ModelCase("Qwen/Qwen2-1.5B"),
|
||||||
ModelCase("Qwen/Qwen2.5-14B-Instruct"),
|
ModelCase("Qwen/Qwen2.5-14B-Instruct"),
|
||||||
ModelCase("HuggingFaceTB/SmolLM-135M-Instruct"),
|
ModelCase("HuggingFaceTB/SmolLM-135M-Instruct", skip_long_prompt=True),
|
||||||
ModelCase("allenai/OLMo-1B-0724-hf", decode_tolerance=8e-2),
|
ModelCase("allenai/OLMo-1B-0724-hf", decode_tolerance=8e-2, skip_long_prompt=True),
|
||||||
]
|
]
|
||||||
|
|
||||||
TORCH_DTYPES = [torch.float16]
|
TORCH_DTYPES = [torch.float16]
|
||||||
@@ -136,8 +139,15 @@ class TestGenerationModels(unittest.TestCase):
|
|||||||
def test_ci_models(self):
|
def test_ci_models(self):
|
||||||
for model_case in CI_MODELS:
|
for model_case in CI_MODELS:
|
||||||
for torch_dtype in TORCH_DTYPES:
|
for torch_dtype in TORCH_DTYPES:
|
||||||
|
|
||||||
|
# Skip long prompts for models that do not have a long context
|
||||||
|
prompts = DEFAULT_PROMPTS
|
||||||
|
if model_case.skip_long_prompt:
|
||||||
|
prompts = [p for p in DEFAULT_PROMPTS if len(p) < 1000]
|
||||||
|
|
||||||
|
# Assert the logits and output strs are close
|
||||||
self.assert_close_logits_and_output_strs(
|
self.assert_close_logits_and_output_strs(
|
||||||
DEFAULT_PROMPTS, model_case, torch_dtype
|
prompts, model_case, torch_dtype
|
||||||
)
|
)
|
||||||
|
|
||||||
def test_others(self):
|
def test_others(self):
|
||||||
@@ -152,13 +162,9 @@ class TestGenerationModels(unittest.TestCase):
|
|||||||
):
|
):
|
||||||
continue
|
continue
|
||||||
|
|
||||||
# Skip long prompts for models that does not have a long context
|
# Skip long prompts for models that do not have a long context
|
||||||
prompts = DEFAULT_PROMPTS
|
prompts = DEFAULT_PROMPTS
|
||||||
if model_case.model_path in [
|
if model_case.skip_long_prompt:
|
||||||
"HuggingFaceTB/SmolLM-135M-Instruct",
|
|
||||||
"allenai/OLMo-1B-0724-hf",
|
|
||||||
"google/gemma-2-2b", # There is a bug with new transformers library. This can only run with transformers==4.44
|
|
||||||
]:
|
|
||||||
prompts = [p for p in DEFAULT_PROMPTS if len(p) < 1000]
|
prompts = [p for p in DEFAULT_PROMPTS if len(p) < 1000]
|
||||||
|
|
||||||
# Assert the logits and output strs are close
|
# Assert the logits and output strs are close
|
||||||
|
|||||||
Reference in New Issue
Block a user