[Fix] the issue of random order when input is a list (#1199)
@@ -437,13 +437,13 @@ class TokenizerManager:
         is_stream = hasattr(obj, "stream") and obj.stream

         tasks = [asyncio.create_task(gen.__anext__()) for gen in generators]
-        output_list = []
+        output_list = [None] * len(tasks)

         while tasks:
             done, _ = await asyncio.wait(tasks, return_when=asyncio.FIRST_COMPLETED)

             for task in done:
-                gen_index = tasks.index(task)
+                cur_index = tasks.index(task)

                 try:
                     result = task.result()
@@ -451,14 +451,14 @@ class TokenizerManager:
                     if is_stream:
                         yield result
                     else:
-                        output_list.append(result)
+                        output_list[result["index"]] = result

-                    tasks[gen_index] = asyncio.create_task(
-                        generators[gen_index].__anext__()
+                    tasks[cur_index] = asyncio.create_task(
+                        generators[cur_index].__anext__()
                     )
                 except StopAsyncIteration:
-                    del generators[gen_index]
-                    del tasks[gen_index]
+                    del generators[cur_index]
+                    del tasks[cur_index]

         if not is_stream:
             yield output_list
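The old code appended each finished result to output_list in completion order, so a batch's outputs could come back shuffled whenever requests finished out of order. The fix preallocates the list and writes each result at its request's original position. A minimal standalone sketch of the pattern (not sglang code; it assumes, as the diff does, that every yielded result dict carries an "index" field with the request's position in the input batch):

import asyncio

async def fake_generator(index: int, delay: float):
    # Stand-in for a per-request generator: finishes after `delay` seconds
    # and tags its single result with the request's original position.
    await asyncio.sleep(delay)
    yield {"index": index, "text": f"result-{index}"}

async def gather_in_order():
    # Request 0 is the slowest and completes last, but writing by
    # result["index"] still puts it first in the output list.
    generators = [fake_generator(0, 0.3), fake_generator(1, 0.1), fake_generator(2, 0.2)]
    tasks = [asyncio.create_task(gen.__anext__()) for gen in generators]
    output_list = [None] * len(tasks)

    while tasks:
        done, _ = await asyncio.wait(tasks, return_when=asyncio.FIRST_COMPLETED)
        for task in done:
            cur_index = tasks.index(task)
            try:
                result = task.result()
                output_list[result["index"]] = result
                tasks[cur_index] = asyncio.create_task(generators[cur_index].__anext__())
            except StopAsyncIteration:
                del generators[cur_index]
                del tasks[cur_index]
    return output_list

print(asyncio.run(gather_in_order()))  # ordered 0, 1, 2 even though they finish 1, 2, 0

Note that tasks and generators stay index-aligned because both are updated or deleted at cur_index together. This also explains the rename from gen_index to cur_index: after deletions, a task's position in the list no longer matches its original generator number, so the position cannot be used to order results.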
@@ -591,7 +591,7 @@ class Runtime:

     def generate(
         self,
-        prompt: str,
+        prompt: Union[str, List[str]],
         sampling_params: Optional[Dict] = None,
         return_logprob: Optional[Union[List[bool], bool]] = False,
         logprob_start_len: Optional[Union[List[int], int]] = None,
@@ -612,7 +612,7 @@ class Runtime:

     def encode(
         self,
-        prompt: str,
+        prompt: Union[str, List[str]],
     ):
         json_data = {
             "text": prompt,
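Both Runtime.generate and Runtime.encode now advertise that they accept either a single prompt or a list of prompts. A hypothetical usage sketch (the model path and sampling parameters are illustrative assumptions, not taken from this commit):

from sglang.srt.server import Runtime

# Illustrative model path; any model supported by this sglang version works.
runtime = Runtime(model_path="meta-llama/Llama-2-7b-chat-hf")

prompts = ["The capital of France is", "The capital of Japan is"]

# A list is now a valid `prompt` argument; with the ordering fix above,
# the i-th output corresponds to prompts[i].
outputs = runtime.generate(prompts, sampling_params={"max_new_tokens": 16})
print(outputs)

runtime.shutdown()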
@@ -28,10 +28,10 @@ from sglang.srt.server import Runtime
 DEFAULT_PROMPTS = [
     # the output of gemma-2-2b from SRT is unstable on the commented prompt
     # "The capital of France is",
+    "Apple is red. Banana is Yellow. " * 800 + "Apple is",
     "The capital of the United Kindom is",
     "Today is a sunny day and I like",
     "AI is a field of computer science focused on",
-    "Apple is red. Banana is Yellow. " * 800 + "Apple is",
 ]

 dirpath = os.path.dirname(__file__)
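The repeated prompt that moves to the front of DEFAULT_PROMPTS is the long-context case the new test tolerance targets: each repetition is 32 characters, so the prompt is well past the 1000-character cutoff used in the test below. A quick check:

long_prompt = "Apple is red. Banana is Yellow. " * 800 + "Apple is"
print(len(long_prompt))  # 32 * 800 + 8 = 25608 characters, far above the 1000-char cutoff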
@@ -20,7 +20,7 @@ import torch
 from sglang.test.runners import DEFAULT_PROMPTS, HFRunner, SRTRunner
 from sglang.test.test_utils import get_similarities

-MODELS = [("intfloat/e5-mistral-7b-instruct", 1)]
+MODELS = [("intfloat/e5-mistral-7b-instruct", 1, 0.2)]
 TORCH_DTYPES = [torch.float16]


@@ -32,6 +32,7 @@ class TestEmbeddingModels(unittest.TestCase):
         model_path,
         tp_size,
         torch_dtype,
+        long_context_tolerance,
     ) -> None:
         with HFRunner(
             model_path, torch_dtype=torch_dtype, is_generation_model=False
@@ -52,20 +53,22 @@ class TestEmbeddingModels(unittest.TestCase):
             hf_logits = torch.Tensor(hf_outputs.embed_logits[i])
             srt_logits = torch.Tensor(srt_outputs.embed_logits[i])

-            similarities = torch.tensor(get_similarities(hf_logits, srt_logits))
-            print("max similarity diff", torch.max(abs(similarities - 1)))
+            similarity = torch.tensor(get_similarities(hf_logits, srt_logits))
+            print("similarity diff", abs(similarity - 1))

-            if hf_logits.shape[0] <= 100:
-                tolerance = 1e-2
-                assert torch.all(
-                    abs(similarities - 1) < tolerance
-                ), "embeddings are not all close"
+            if len(prompts[i]) <= 1000:
+                tolerance = 1e-5
+            else:
+                tolerance = long_context_tolerance
+            assert torch.all(
+                abs(similarity - 1) < tolerance
+            ), "embeddings are not all close"

     def test_prefill_logits(self):
-        for model, tp_size in MODELS:
+        for model, tp_size, long_context_tolerance in MODELS:
             for torch_dtype in TORCH_DTYPES:
                 self.assert_close_prefill_logits(
-                    DEFAULT_PROMPTS, model, tp_size, torch_dtype
+                    DEFAULT_PROMPTS, model, tp_size, torch_dtype, long_context_tolerance
                 )


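The test now keys the tolerance on prompt length rather than embedding width: short prompts must match HuggingFace almost exactly (1e-5), while the 25k-character prompt is allowed the per-model long_context_tolerance carried in the MODELS tuple (0.2 for e5-mistral). A standalone sketch of the check, with a stand-in for get_similarities (assumed here to be a cosine similarity; the real helper lives in sglang.test.test_utils):

import torch
import torch.nn.functional as F

def get_similarities_sketch(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    # Stand-in for sglang.test.test_utils.get_similarities: cosine similarity
    # between two flattened embedding tensors (assumed behavior).
    return F.cosine_similarity(a.flatten(), b.flatten(), dim=0)

def assert_embeddings_close(prompt: str, hf_logits: torch.Tensor,
                            srt_logits: torch.Tensor,
                            long_context_tolerance: float = 0.2) -> None:
    similarity = get_similarities_sketch(hf_logits, srt_logits)
    # Short prompts get a tight bound; only the very long prompt falls back
    # to the per-model tolerance.
    tolerance = 1e-5 if len(prompt) <= 1000 else long_context_tolerance
    assert torch.all(abs(similarity - 1) < tolerance), "embeddings are not all close"

# Identical embeddings pass the tight bound for a short prompt.
emb = torch.randn(4096)
assert_embeddings_close("Today is a sunny day and I like", emb, emb.clone())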