[router]: Add Embedding routing logic (#10129)
Signed-off-by: Jintao Zhang <zhangjintao9020@gmail.com> Co-authored-by: Waël Boukhobza <wawa_wael@live.fr>
This commit is contained in:
@@ -20,7 +20,12 @@ import torch
|
||||
from transformers import AutoConfig, AutoTokenizer
|
||||
|
||||
from sglang.test.runners import DEFAULT_PROMPTS, HFRunner, SRTRunner
|
||||
from sglang.test.test_utils import CustomTestCase, get_similarities, is_in_ci
|
||||
from sglang.test.test_utils import (
|
||||
CustomTestCase,
|
||||
get_similarities,
|
||||
is_in_amd_ci,
|
||||
is_in_ci,
|
||||
)
|
||||
|
||||
MODELS = [
|
||||
("Alibaba-NLP/gte-Qwen2-1.5B-instruct", 1, 1e-5),
|
||||
@@ -74,11 +79,13 @@ class TestEmbeddingModels(CustomTestCase):
|
||||
) as hf_runner:
|
||||
hf_outputs = hf_runner.forward(truncated_prompts)
|
||||
|
||||
attention_backend = "triton" if is_in_amd_ci() else None
|
||||
with SRTRunner(
|
||||
model_path,
|
||||
tp_size=tp_size,
|
||||
torch_dtype=torch_dtype,
|
||||
model_type="embedding",
|
||||
attention_backend=attention_backend,
|
||||
) as srt_runner:
|
||||
srt_outputs = srt_runner.forward(truncated_prompts)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user