Handle empty input string for embedding models (#5621)
Co-authored-by: Ravi Theja Desetty <ravitheja@Ravis-MacBook-Pro.local>
This commit is contained in:
@@ -175,6 +175,32 @@ def guess_chat_template_name_from_model_path(model_path):
|
||||
)
|
||||
|
||||
|
||||
def _validate_prompt(prompt):
    """Reject an embedding input that is empty or contains only whitespace.

    Args:
        prompt: The request input. Strings and lists are inspected; a list
            may hold strings or nested lists. Values of any other type are
            passed through unchecked.

    Returns:
        The ``prompt`` object unchanged when it passes validation.

    Raises:
        HTTPException: With status 400 when ``prompt`` is a blank string, an
            empty list, or a list in which any item is a blank string or an
            empty list.
    """

    def _is_blank(item):
        # Only strings and lists can be "empty"; any other item is never
        # treated as blank.
        if isinstance(item, str):
            return not item.strip()
        if isinstance(item, list):
            return not item
        return False

    if isinstance(prompt, list):
        # Check every item, not just a lone first one, so inputs such as
        # ["text", ""] or [[1, 2], []] cannot slip through.
        is_invalid = not prompt or any(_is_blank(item) for item in prompt)
    else:
        is_invalid = _is_blank(prompt)

    if is_invalid:
        raise HTTPException(
            status_code=400,
            detail="Input cannot be empty or contain only whitespace.",
        )

    return prompt
|
||||
|
||||
|
||||
async def v1_files_create(
|
||||
file: UploadFile, purpose: str, file_storage_path: str = None
|
||||
):
|
||||
@@ -1753,6 +1779,8 @@ def v1_embedding_request(all_requests, tokenizer_manager):
|
||||
|
||||
for request in all_requests:
|
||||
prompt = request.input
|
||||
# Check for empty/whitespace string
|
||||
prompt = _validate_prompt(request.input)
|
||||
assert (
|
||||
type(prompt) is first_prompt_type
|
||||
), "All prompts must be of the same type in file input settings"
|
||||
|
||||
@@ -676,6 +676,22 @@ class TestOpenAIEmbedding(CustomTestCase):
|
||||
self.assertTrue(len(response.data[0].embedding) > 0)
|
||||
self.assertTrue(len(response.data[1].embedding) > 0)
|
||||
|
||||
def test_empty_string_embedding(self):
    """An empty-string embedding input must be rejected with HTTP 400."""
    api_client = openai.Client(api_key=self.api_key, base_url=self.base_url)
    empty_input = ""

    # The request itself is expected to fail with a BadRequestError.
    with self.assertRaises(openai.BadRequestError) as raised:
        api_client.embeddings.create(model=self.model, input=empty_input)

    # Inspect the captured error only after the context manager has exited.
    self.assertEqual(raised.exception.status_code, 400)
|
||||
|
||||
|
||||
class TestOpenAIServerIgnoreEOS(CustomTestCase):
|
||||
@classmethod
|
||||
|
||||
Reference in New Issue
Block a user