[Fix] the issue of random order when input is a list (#1199)

This commit is contained in:
Ying Sheng
2024-08-24 21:43:03 -07:00
committed by GitHub
parent e61d13acdf
commit 1cb4da5c5f
4 changed files with 23 additions and 20 deletions

View File

@@ -20,7 +20,7 @@ import torch
from sglang.test.runners import DEFAULT_PROMPTS, HFRunner, SRTRunner
from sglang.test.test_utils import get_similarities
MODELS = [("intfloat/e5-mistral-7b-instruct", 1)]
MODELS = [("intfloat/e5-mistral-7b-instruct", 1, 0.2)]
TORCH_DTYPES = [torch.float16]
@@ -32,6 +32,7 @@ class TestEmbeddingModels(unittest.TestCase):
model_path,
tp_size,
torch_dtype,
long_context_tolerance,
) -> None:
with HFRunner(
model_path, torch_dtype=torch_dtype, is_generation_model=False
@@ -52,20 +53,22 @@ class TestEmbeddingModels(unittest.TestCase):
hf_logits = torch.Tensor(hf_outputs.embed_logits[i])
srt_logits = torch.Tensor(srt_outputs.embed_logits[i])
similarities = torch.tensor(get_similarities(hf_logits, srt_logits))
print("max similarity diff", torch.max(abs(similarities - 1)))
similarity = torch.tensor(get_similarities(hf_logits, srt_logits))
print("similarity diff", abs(similarity - 1))
if hf_logits.shape[0] <= 100:
tolerance = 1e-2
assert torch.all(
abs(similarities - 1) < tolerance
), "embeddings are not all close"
if len(prompts[i]) <= 1000:
tolerance = 1e-5
else:
tolerance = long_context_tolerance
assert torch.all(
abs(similarity - 1) < tolerance
), "embeddings are not all close"
def test_prefill_logits(self):
for model, tp_size in MODELS:
for model, tp_size, long_context_tolerance in MODELS:
for torch_dtype in TORCH_DTYPES:
self.assert_close_prefill_logits(
DEFAULT_PROMPTS, model, tp_size, torch_dtype
DEFAULT_PROMPTS, model, tp_size, torch_dtype, long_context_tolerance
)