Support Alibaba-NLP/gte-Qwen2-7B-instruct embedding Model (#1186)
Co-authored-by: Ying Sheng <sqy1415@gmail.com>
This commit is contained in:
2
.github/workflows/accuracy-test.yml
vendored
2
.github/workflows/accuracy-test.yml
vendored
@@ -43,4 +43,4 @@ jobs:
|
|||||||
run: |
|
run: |
|
||||||
cd test/srt
|
cd test/srt
|
||||||
python3 test_eval_accuracy_large.py
|
python3 test_eval_accuracy_large.py
|
||||||
timeout-minutes: 10
|
timeout-minutes: 20
|
||||||
|
|||||||
2
.github/workflows/unit-test.yml
vendored
2
.github/workflows/unit-test.yml
vendored
@@ -41,7 +41,7 @@ jobs:
|
|||||||
run: |
|
run: |
|
||||||
cd test/srt
|
cd test/srt
|
||||||
python3 run_suite.py --suite minimal
|
python3 run_suite.py --suite minimal
|
||||||
timeout-minutes: 18
|
timeout-minutes: 20
|
||||||
|
|
||||||
- name: Test Frontend Language
|
- name: Test Frontend Language
|
||||||
run: |
|
run: |
|
||||||
|
|||||||
15
README.md
15
README.md
@@ -187,6 +187,13 @@ response = client.chat.completions.create(
|
|||||||
max_tokens=64,
|
max_tokens=64,
|
||||||
)
|
)
|
||||||
print(response)
|
print(response)
|
||||||
|
|
||||||
|
# Text embedding
|
||||||
|
response = client.embeddings.create(
|
||||||
|
model="default",
|
||||||
|
input="How are you today",
|
||||||
|
)
|
||||||
|
print(response)
|
||||||
```
|
```
|
||||||
|
|
||||||
It supports streaming, vision, and most features of the Chat/Completions/Models/Batch endpoints specified by the [OpenAI API Reference](https://platform.openai.com/docs/api-reference/).
|
It supports streaming, vision, and most features of the Chat/Completions/Models/Batch endpoints specified by the [OpenAI API Reference](https://platform.openai.com/docs/api-reference/).
|
||||||
@@ -223,6 +230,8 @@ python -m sglang.launch_server --model-path meta-llama/Meta-Llama-3-8B-Instruct
|
|||||||
|
|
||||||
### Supported Models
|
### Supported Models
|
||||||
|
|
||||||
|
**Generative Models**
|
||||||
|
|
||||||
- Llama / Llama 2 / Llama 3 / Llama 3.1
|
- Llama / Llama 2 / Llama 3 / Llama 3.1
|
||||||
- Mistral / Mixtral / Mistral NeMo
|
- Mistral / Mixtral / Mistral NeMo
|
||||||
- Gemma / Gemma 2
|
- Gemma / Gemma 2
|
||||||
@@ -243,6 +252,12 @@ python -m sglang.launch_server --model-path meta-llama/Meta-Llama-3-8B-Instruct
|
|||||||
- ChatGLM
|
- ChatGLM
|
||||||
- InternLM 2
|
- InternLM 2
|
||||||
|
|
||||||
|
**Embedding Models**
|
||||||
|
|
||||||
|
- e5-mistral
|
||||||
|
- gte-Qwen2
|
||||||
|
- `python -m sglang.launch_server --model-path Alibaba-NLP/gte-Qwen2-7B-instruct --is-embedding`
|
||||||
|
|
||||||
Instructions for supporting a new model are [here](https://github.com/sgl-project/sglang/blob/main/docs/en/model_support.md).
|
Instructions for supporting a new model are [here](https://github.com/sgl-project/sglang/blob/main/docs/en/model_support.md).
|
||||||
|
|
||||||
#### Use Models From ModelScope
|
#### Use Models From ModelScope
|
||||||
|
|||||||
@@ -94,7 +94,10 @@ class TokenizerManager:
|
|||||||
trust_remote_code=server_args.trust_remote_code,
|
trust_remote_code=server_args.trust_remote_code,
|
||||||
model_overide_args=model_overide_args,
|
model_overide_args=model_overide_args,
|
||||||
)
|
)
|
||||||
self.is_generation = is_generation_model(self.hf_config.architectures)
|
|
||||||
|
self.is_generation = is_generation_model(
|
||||||
|
self.hf_config.architectures, self.server_args.is_embedding
|
||||||
|
)
|
||||||
|
|
||||||
if server_args.context_length is not None:
|
if server_args.context_length is not None:
|
||||||
self.context_len = server_args.context_length
|
self.context_len = server_args.context_length
|
||||||
|
|||||||
@@ -94,6 +94,7 @@ class ModelTpServer:
|
|||||||
context_length=server_args.context_length,
|
context_length=server_args.context_length,
|
||||||
model_overide_args=model_overide_args,
|
model_overide_args=model_overide_args,
|
||||||
)
|
)
|
||||||
|
|
||||||
self.model_runner = ModelRunner(
|
self.model_runner = ModelRunner(
|
||||||
model_config=self.model_config,
|
model_config=self.model_config,
|
||||||
mem_fraction_static=server_args.mem_fraction_static,
|
mem_fraction_static=server_args.mem_fraction_static,
|
||||||
|
|||||||
@@ -204,7 +204,7 @@ class ModelRunner:
|
|||||||
else None
|
else None
|
||||||
)
|
)
|
||||||
self.is_generation = is_generation_model(
|
self.is_generation = is_generation_model(
|
||||||
self.model_config.hf_config.architectures
|
self.model_config.hf_config.architectures, self.server_args.is_embedding
|
||||||
)
|
)
|
||||||
|
|
||||||
logger.info(
|
logger.info(
|
||||||
@@ -522,9 +522,18 @@ class ModelRunner:
|
|||||||
batch,
|
batch,
|
||||||
forward_mode=ForwardMode.EXTEND,
|
forward_mode=ForwardMode.EXTEND,
|
||||||
)
|
)
|
||||||
return self.model.forward(
|
if self.is_generation:
|
||||||
batch.input_ids, input_metadata.positions, input_metadata
|
return self.model.forward(
|
||||||
)
|
batch.input_ids, input_metadata.positions, input_metadata
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
# Only embedding models have get_embedding parameter
|
||||||
|
return self.model.forward(
|
||||||
|
batch.input_ids,
|
||||||
|
input_metadata.positions,
|
||||||
|
input_metadata,
|
||||||
|
get_embedding=True,
|
||||||
|
)
|
||||||
|
|
||||||
@torch.inference_mode()
|
@torch.inference_mode()
|
||||||
def forward_extend_multi_modal(self, batch: ScheduleBatch):
|
def forward_extend_multi_modal(self, batch: ScheduleBatch):
|
||||||
|
|||||||
@@ -29,7 +29,11 @@ class LlamaEmbeddingModel(nn.Module):
|
|||||||
positions: torch.Tensor,
|
positions: torch.Tensor,
|
||||||
input_metadata: InputMetadata,
|
input_metadata: InputMetadata,
|
||||||
input_embeds: torch.Tensor = None,
|
input_embeds: torch.Tensor = None,
|
||||||
|
get_embedding: bool = True,
|
||||||
) -> EmbeddingPoolerOutput:
|
) -> EmbeddingPoolerOutput:
|
||||||
|
assert (
|
||||||
|
get_embedding
|
||||||
|
), "LlamaEmbeddingModel / MistralModel is only used for embedding"
|
||||||
hidden_states = self.model(input_ids, positions, input_metadata, input_embeds)
|
hidden_states = self.model(input_ids, positions, input_metadata, input_embeds)
|
||||||
return self.pooler(hidden_states, input_metadata)
|
return self.pooler(hidden_states, input_metadata)
|
||||||
|
|
||||||
|
|||||||
@@ -38,6 +38,7 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
|||||||
from sglang.srt.layers.activation import SiluAndMul
|
from sglang.srt.layers.activation import SiluAndMul
|
||||||
from sglang.srt.layers.layernorm import RMSNorm
|
from sglang.srt.layers.layernorm import RMSNorm
|
||||||
from sglang.srt.layers.logits_processor import LogitsProcessor
|
from sglang.srt.layers.logits_processor import LogitsProcessor
|
||||||
|
from sglang.srt.layers.pooler import EmbeddingPoolerOutput, Pooler, PoolingType
|
||||||
from sglang.srt.layers.radix_attention import RadixAttention
|
from sglang.srt.layers.radix_attention import RadixAttention
|
||||||
from sglang.srt.model_executor.forward_batch_info import InputMetadata
|
from sglang.srt.model_executor.forward_batch_info import InputMetadata
|
||||||
|
|
||||||
@@ -275,6 +276,7 @@ class Qwen2ForCausalLM(nn.Module):
|
|||||||
self.model = Qwen2Model(config, quant_config=quant_config)
|
self.model = Qwen2Model(config, quant_config=quant_config)
|
||||||
self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
|
self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
|
||||||
self.logits_processor = LogitsProcessor(config)
|
self.logits_processor = LogitsProcessor(config)
|
||||||
|
self.pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True)
|
||||||
|
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
def forward(
|
def forward(
|
||||||
@@ -283,11 +285,15 @@ class Qwen2ForCausalLM(nn.Module):
|
|||||||
positions: torch.Tensor,
|
positions: torch.Tensor,
|
||||||
input_metadata: InputMetadata,
|
input_metadata: InputMetadata,
|
||||||
input_embeds: torch.Tensor = None,
|
input_embeds: torch.Tensor = None,
|
||||||
|
get_embedding: bool = False,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
hidden_states = self.model(input_ids, positions, input_metadata, input_embeds)
|
hidden_states = self.model(input_ids, positions, input_metadata, input_embeds)
|
||||||
return self.logits_processor(
|
if not get_embedding:
|
||||||
input_ids, hidden_states, self.lm_head.weight, input_metadata
|
return self.logits_processor(
|
||||||
)
|
input_ids, hidden_states, self.lm_head.weight, input_metadata
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
return self.pooler(hidden_states, input_metadata)
|
||||||
|
|
||||||
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
|
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
|
||||||
stacked_params_mapping = [
|
stacked_params_mapping = [
|
||||||
|
|||||||
@@ -333,11 +333,13 @@ def launch_server(
|
|||||||
start_process = start_controller_process_single
|
start_process = start_controller_process_single
|
||||||
else:
|
else:
|
||||||
start_process = start_controller_process_multi
|
start_process = start_controller_process_multi
|
||||||
|
|
||||||
proc_controller = mp.Process(
|
proc_controller = mp.Process(
|
||||||
target=start_process,
|
target=start_process,
|
||||||
args=(server_args, port_args, pipe_controller_writer, model_overide_args),
|
args=(server_args, port_args, pipe_controller_writer, model_overide_args),
|
||||||
)
|
)
|
||||||
proc_controller.start()
|
proc_controller.start()
|
||||||
|
|
||||||
proc_detoken = mp.Process(
|
proc_detoken = mp.Process(
|
||||||
target=start_detokenizer_process,
|
target=start_detokenizer_process,
|
||||||
args=(
|
args=(
|
||||||
@@ -515,6 +517,7 @@ class Runtime:
|
|||||||
|
|
||||||
self.pid = None
|
self.pid = None
|
||||||
pipe_reader, pipe_writer = mp.Pipe(duplex=False)
|
pipe_reader, pipe_writer = mp.Pipe(duplex=False)
|
||||||
|
|
||||||
proc = mp.Process(
|
proc = mp.Process(
|
||||||
target=launch_server,
|
target=launch_server,
|
||||||
args=(self.server_args, model_overide_args, pipe_writer),
|
args=(self.server_args, model_overide_args, pipe_writer),
|
||||||
|
|||||||
@@ -38,6 +38,7 @@ class ServerArgs:
|
|||||||
quantization: Optional[str] = None
|
quantization: Optional[str] = None
|
||||||
served_model_name: Optional[str] = None
|
served_model_name: Optional[str] = None
|
||||||
chat_template: Optional[str] = None
|
chat_template: Optional[str] = None
|
||||||
|
is_embedding: bool = False
|
||||||
|
|
||||||
# Port
|
# Port
|
||||||
host: str = "127.0.0.1"
|
host: str = "127.0.0.1"
|
||||||
@@ -200,6 +201,11 @@ class ServerArgs:
|
|||||||
action="store_true",
|
action="store_true",
|
||||||
help="Whether or not to allow for custom models defined on the Hub in their own modeling files.",
|
help="Whether or not to allow for custom models defined on the Hub in their own modeling files.",
|
||||||
)
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--is-embedding",
|
||||||
|
action="store_true",
|
||||||
|
help="Whether to use a CausalLM as an embedding model.",
|
||||||
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--context-length",
|
"--context-length",
|
||||||
type=int,
|
type=int,
|
||||||
@@ -458,6 +464,11 @@ class ServerArgs:
|
|||||||
assert not (
|
assert not (
|
||||||
self.dp_size > 1 and self.node_rank is not None
|
self.dp_size > 1 and self.node_rank is not None
|
||||||
), "multi-node data parallel is not supported"
|
), "multi-node data parallel is not supported"
|
||||||
|
if "Alibaba-NLP/gte-Qwen2-1.5B-instruct" == self.model_path:
|
||||||
|
logger.info(
|
||||||
|
"Not sure why, the tokenizer will add an additional token at the end of the prompt when trust_remote_mode=True"
|
||||||
|
)
|
||||||
|
self.trust_remote_code = False
|
||||||
if "gemma-2" in self.model_path.lower():
|
if "gemma-2" in self.model_path.lower():
|
||||||
logger.info("When using sliding window in gemma-2, turn on flashinfer.")
|
logger.info("When using sliding window in gemma-2, turn on flashinfer.")
|
||||||
self.disable_flashinfer = False
|
self.disable_flashinfer = False
|
||||||
|
|||||||
@@ -224,13 +224,18 @@ def is_multimodal_model(model):
|
|||||||
raise ValueError("unrecognized type")
|
raise ValueError("unrecognized type")
|
||||||
|
|
||||||
|
|
||||||
def is_generation_model(model_architectures):
|
def is_generation_model(model_architectures, is_embedding: bool = False):
|
||||||
|
# We have two ways to determine whether a model is a generative model.
|
||||||
|
# 1. Check the model architectue
|
||||||
|
# 2. check the `is_embedding` server args
|
||||||
|
|
||||||
if (
|
if (
|
||||||
"LlamaEmbeddingModel" in model_architectures
|
"LlamaEmbeddingModel" in model_architectures
|
||||||
or "MistralModel" in model_architectures
|
or "MistralModel" in model_architectures
|
||||||
):
|
):
|
||||||
return False
|
return False
|
||||||
return True
|
else:
|
||||||
|
return not is_embedding
|
||||||
|
|
||||||
|
|
||||||
def decode_video_base64(video_base64):
|
def decode_video_base64(video_base64):
|
||||||
|
|||||||
@@ -14,7 +14,7 @@ limitations under the License.
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
import json
|
import json
|
||||||
import multiprocessing
|
import multiprocessing as mp
|
||||||
import os
|
import os
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import List, Union
|
from typing import List, Union
|
||||||
@@ -63,37 +63,35 @@ class HFRunner:
|
|||||||
self,
|
self,
|
||||||
model_path,
|
model_path,
|
||||||
torch_dtype,
|
torch_dtype,
|
||||||
is_generation_model,
|
is_generation,
|
||||||
):
|
):
|
||||||
self.in_queue = multiprocessing.Queue()
|
self.is_generation = is_generation
|
||||||
self.out_queue = multiprocessing.Queue()
|
|
||||||
|
|
||||||
self.model_proc = multiprocessing.Process(
|
self.in_queue = mp.Queue()
|
||||||
|
self.out_queue = mp.Queue()
|
||||||
|
|
||||||
|
self.model_proc = mp.Process(
|
||||||
target=self.start_model_process,
|
target=self.start_model_process,
|
||||||
args=(
|
args=(
|
||||||
self.in_queue,
|
self.in_queue,
|
||||||
self.out_queue,
|
self.out_queue,
|
||||||
model_path,
|
model_path,
|
||||||
torch_dtype,
|
torch_dtype,
|
||||||
is_generation_model,
|
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
self.model_proc.start()
|
self.model_proc.start()
|
||||||
|
|
||||||
def start_model_process(
|
def start_model_process(self, in_queue, out_queue, model_path, torch_dtype):
|
||||||
self, in_queue, out_queue, model_path, torch_dtype, is_generation_model
|
|
||||||
):
|
|
||||||
self.tokenizer = AutoTokenizer.from_pretrained(
|
self.tokenizer = AutoTokenizer.from_pretrained(
|
||||||
model_path,
|
model_path,
|
||||||
torch_dtype=torch_dtype,
|
torch_dtype=torch_dtype,
|
||||||
)
|
)
|
||||||
|
|
||||||
self.is_generation_model = is_generation_model
|
if self.is_generation:
|
||||||
|
|
||||||
if self.is_generation_model:
|
|
||||||
self.model = AutoModelForCausalLM.from_pretrained(
|
self.model = AutoModelForCausalLM.from_pretrained(
|
||||||
model_path,
|
model_path,
|
||||||
torch_dtype=torch_dtype,
|
torch_dtype=torch_dtype,
|
||||||
|
trust_remote_code=False,
|
||||||
low_cpu_mem_usage=True,
|
low_cpu_mem_usage=True,
|
||||||
).cuda()
|
).cuda()
|
||||||
else:
|
else:
|
||||||
@@ -107,7 +105,7 @@ class HFRunner:
|
|||||||
while True:
|
while True:
|
||||||
prompts, max_new_tokens = in_queue.get()
|
prompts, max_new_tokens = in_queue.get()
|
||||||
if prompts is not None:
|
if prompts is not None:
|
||||||
if self.is_generation_model:
|
if self.is_generation:
|
||||||
output_strs = []
|
output_strs = []
|
||||||
prefill_logprobs = []
|
prefill_logprobs = []
|
||||||
for p in prompts:
|
for p in prompts:
|
||||||
@@ -171,17 +169,19 @@ class SRTRunner:
|
|||||||
self,
|
self,
|
||||||
model_path,
|
model_path,
|
||||||
torch_dtype,
|
torch_dtype,
|
||||||
is_generation_model,
|
is_generation,
|
||||||
tp_size=1,
|
tp_size=1,
|
||||||
port=5157,
|
port=5157,
|
||||||
):
|
):
|
||||||
self.is_generation_model = is_generation_model
|
self.is_generation = is_generation
|
||||||
self.runtime = Runtime(
|
self.runtime = Runtime(
|
||||||
model_path=model_path,
|
model_path=model_path,
|
||||||
tp_size=tp_size,
|
tp_size=tp_size,
|
||||||
dtype=get_dtype_str(torch_dtype),
|
dtype=get_dtype_str(torch_dtype),
|
||||||
port=port,
|
port=port,
|
||||||
mem_fraction_static=0.7,
|
mem_fraction_static=0.7,
|
||||||
|
trust_remote_code=False,
|
||||||
|
is_embedding=not self.is_generation,
|
||||||
)
|
)
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
@@ -189,7 +189,7 @@ class SRTRunner:
|
|||||||
prompts: Union[List[str], List[torch.Tensor]] = DEFAULT_PROMPTS,
|
prompts: Union[List[str], List[torch.Tensor]] = DEFAULT_PROMPTS,
|
||||||
max_new_tokens=8,
|
max_new_tokens=8,
|
||||||
):
|
):
|
||||||
if self.is_generation_model:
|
if self.is_generation:
|
||||||
# the return value contains logprobs from prefill
|
# the return value contains logprobs from prefill
|
||||||
output_strs = []
|
output_strs = []
|
||||||
top_input_logprobs = []
|
top_input_logprobs = []
|
||||||
|
|||||||
@@ -20,7 +20,10 @@ import torch
|
|||||||
from sglang.test.runners import DEFAULT_PROMPTS, HFRunner, SRTRunner
|
from sglang.test.runners import DEFAULT_PROMPTS, HFRunner, SRTRunner
|
||||||
from sglang.test.test_utils import get_similarities
|
from sglang.test.test_utils import get_similarities
|
||||||
|
|
||||||
MODELS = [("intfloat/e5-mistral-7b-instruct", 1, 0.2)]
|
MODELS = [
|
||||||
|
("Alibaba-NLP/gte-Qwen2-1.5B-instruct", 1, 1e-5),
|
||||||
|
("intfloat/e5-mistral-7b-instruct", 1, 1e-5),
|
||||||
|
]
|
||||||
TORCH_DTYPES = [torch.float16]
|
TORCH_DTYPES = [torch.float16]
|
||||||
|
|
||||||
|
|
||||||
@@ -32,10 +35,10 @@ class TestEmbeddingModels(unittest.TestCase):
|
|||||||
model_path,
|
model_path,
|
||||||
tp_size,
|
tp_size,
|
||||||
torch_dtype,
|
torch_dtype,
|
||||||
long_context_tolerance,
|
prefill_tolerance,
|
||||||
) -> None:
|
) -> None:
|
||||||
with HFRunner(
|
with HFRunner(
|
||||||
model_path, torch_dtype=torch_dtype, is_generation_model=False
|
model_path, torch_dtype=torch_dtype, is_generation=False
|
||||||
) as hf_runner:
|
) as hf_runner:
|
||||||
hf_outputs = hf_runner.forward(prompts)
|
hf_outputs = hf_runner.forward(prompts)
|
||||||
|
|
||||||
@@ -43,11 +46,9 @@ class TestEmbeddingModels(unittest.TestCase):
|
|||||||
model_path,
|
model_path,
|
||||||
tp_size=tp_size,
|
tp_size=tp_size,
|
||||||
torch_dtype=torch_dtype,
|
torch_dtype=torch_dtype,
|
||||||
is_generation_model=False,
|
is_generation=False,
|
||||||
) as srt_runner:
|
) as srt_runner:
|
||||||
srt_outputs = srt_runner.forward(
|
srt_outputs = srt_runner.forward(prompts)
|
||||||
prompts,
|
|
||||||
)
|
|
||||||
|
|
||||||
for i in range(len(prompts)):
|
for i in range(len(prompts)):
|
||||||
hf_logits = torch.Tensor(hf_outputs.embed_logits[i])
|
hf_logits = torch.Tensor(hf_outputs.embed_logits[i])
|
||||||
@@ -57,18 +58,15 @@ class TestEmbeddingModels(unittest.TestCase):
|
|||||||
print("similarity diff", abs(similarity - 1))
|
print("similarity diff", abs(similarity - 1))
|
||||||
|
|
||||||
if len(prompts[i]) <= 1000:
|
if len(prompts[i]) <= 1000:
|
||||||
tolerance = 1e-5
|
assert torch.all(
|
||||||
else:
|
abs(similarity - 1) < prefill_tolerance
|
||||||
tolerance = long_context_tolerance
|
), "embeddings are not all close"
|
||||||
assert torch.all(
|
|
||||||
abs(similarity - 1) < tolerance
|
|
||||||
), "embeddings are not all close"
|
|
||||||
|
|
||||||
def test_prefill_logits(self):
|
def test_prefill_logits(self):
|
||||||
for model, tp_size, long_context_tolerance in MODELS:
|
for model, tp_size, prefill_tolerance in MODELS:
|
||||||
for torch_dtype in TORCH_DTYPES:
|
for torch_dtype in TORCH_DTYPES:
|
||||||
self.assert_close_prefill_logits(
|
self.assert_close_prefill_logits(
|
||||||
DEFAULT_PROMPTS, model, tp_size, torch_dtype, long_context_tolerance
|
DEFAULT_PROMPTS, model, tp_size, torch_dtype, prefill_tolerance
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -20,12 +20,46 @@ import torch
|
|||||||
from sglang.test.runners import DEFAULT_PROMPTS, HFRunner, SRTRunner
|
from sglang.test.runners import DEFAULT_PROMPTS, HFRunner, SRTRunner
|
||||||
|
|
||||||
MODELS = [
|
MODELS = [
|
||||||
("meta-llama/Meta-Llama-3.1-8B-Instruct", 1, 1.1),
|
("meta-llama/Meta-Llama-3.1-8B-Instruct", 1, 1.1, 3e-2, 1),
|
||||||
("google/gemma-2-2b", 1, 3),
|
("google/gemma-2-2b", 1, 3, 3e-2, 1),
|
||||||
|
("Alibaba-NLP/gte-Qwen2-1.5B-instruct", 1, None, 6e-2, 1),
|
||||||
]
|
]
|
||||||
TORCH_DTYPES = [torch.float16]
|
TORCH_DTYPES = [torch.float16]
|
||||||
|
|
||||||
|
|
||||||
|
def lcs(X, Y):
|
||||||
|
m = len(X)
|
||||||
|
n = len(Y)
|
||||||
|
L = [[0] * (n + 1) for _ in range(m + 1)]
|
||||||
|
|
||||||
|
for i in range(m + 1):
|
||||||
|
for j in range(n + 1):
|
||||||
|
if i == 0 or j == 0:
|
||||||
|
L[i][j] = 0
|
||||||
|
elif X[i - 1] == Y[j - 1]:
|
||||||
|
L[i][j] = L[i - 1][j - 1] + 1
|
||||||
|
else:
|
||||||
|
L[i][j] = max(L[i - 1][j], L[i][j - 1])
|
||||||
|
|
||||||
|
return L[m][n]
|
||||||
|
|
||||||
|
|
||||||
|
def calculate_rouge_l(output_strs_list1, output_strs_list2):
|
||||||
|
rouge_l_scores = []
|
||||||
|
|
||||||
|
for s1, s2 in zip(output_strs_list1, output_strs_list2):
|
||||||
|
lcs_len = lcs(s1, s2)
|
||||||
|
precision = lcs_len / len(s1) if len(s1) > 0 else 0
|
||||||
|
recall = lcs_len / len(s2) if len(s2) > 0 else 0
|
||||||
|
if precision + recall > 0:
|
||||||
|
fmeasure = (2 * precision * recall) / (precision + recall)
|
||||||
|
else:
|
||||||
|
fmeasure = 0.0
|
||||||
|
rouge_l_scores.append(fmeasure)
|
||||||
|
|
||||||
|
return rouge_l_scores
|
||||||
|
|
||||||
|
|
||||||
class TestGenerationModels(unittest.TestCase):
|
class TestGenerationModels(unittest.TestCase):
|
||||||
|
|
||||||
def assert_close_prefill_logits_and_output_strs(
|
def assert_close_prefill_logits_and_output_strs(
|
||||||
@@ -35,10 +69,14 @@ class TestGenerationModels(unittest.TestCase):
|
|||||||
tp_size,
|
tp_size,
|
||||||
torch_dtype,
|
torch_dtype,
|
||||||
max_new_tokens,
|
max_new_tokens,
|
||||||
|
prefill_tolerance,
|
||||||
|
rouge_threshold,
|
||||||
long_context_tolerance,
|
long_context_tolerance,
|
||||||
) -> None:
|
) -> None:
|
||||||
|
if model_path == "Alibaba-NLP/gte-Qwen2-1.5B-instruct":
|
||||||
|
prompts = prompts[:-1]
|
||||||
with HFRunner(
|
with HFRunner(
|
||||||
model_path, torch_dtype=torch_dtype, is_generation_model=True
|
model_path, torch_dtype=torch_dtype, is_generation=True
|
||||||
) as hf_runner:
|
) as hf_runner:
|
||||||
hf_outputs = hf_runner.forward(prompts, max_new_tokens=max_new_tokens)
|
hf_outputs = hf_runner.forward(prompts, max_new_tokens=max_new_tokens)
|
||||||
|
|
||||||
@@ -46,7 +84,7 @@ class TestGenerationModels(unittest.TestCase):
|
|||||||
model_path,
|
model_path,
|
||||||
tp_size=tp_size,
|
tp_size=tp_size,
|
||||||
torch_dtype=torch_dtype,
|
torch_dtype=torch_dtype,
|
||||||
is_generation_model=True,
|
is_generation=True,
|
||||||
) as srt_runner:
|
) as srt_runner:
|
||||||
srt_outputs = srt_runner.forward(prompts, max_new_tokens=max_new_tokens)
|
srt_outputs = srt_runner.forward(prompts, max_new_tokens=max_new_tokens)
|
||||||
|
|
||||||
@@ -56,17 +94,34 @@ class TestGenerationModels(unittest.TestCase):
|
|||||||
|
|
||||||
print("max_diff", torch.max(abs(hf_logprobs - srt_logprobs)))
|
print("max_diff", torch.max(abs(hf_logprobs - srt_logprobs)))
|
||||||
if hf_logprobs.shape[0] <= 100:
|
if hf_logprobs.shape[0] <= 100:
|
||||||
tolerance = 3e-2
|
|
||||||
assert torch.all(
|
assert torch.all(
|
||||||
abs(hf_logprobs - srt_logprobs) < tolerance
|
abs(hf_logprobs - srt_logprobs) < prefill_tolerance
|
||||||
), "prefill logprobs are not all close"
|
), "prefill logprobs are not all close"
|
||||||
|
|
||||||
print(hf_outputs.output_strs)
|
print(hf_outputs.output_strs)
|
||||||
print(srt_outputs.output_strs)
|
print(srt_outputs.output_strs)
|
||||||
assert hf_outputs.output_strs == srt_outputs.output_strs
|
rouge_l_scores = calculate_rouge_l(
|
||||||
|
hf_outputs.output_strs, srt_outputs.output_strs
|
||||||
|
)
|
||||||
|
assert all(
|
||||||
|
score >= rouge_threshold for score in rouge_l_scores
|
||||||
|
), f"Not all ROUGE-L scores are greater than {rouge_threshold}"
|
||||||
|
|
||||||
def test_prefill_logits_and_output_strs(self):
|
def test_prefill_logits_and_output_strs(self):
|
||||||
for model, tp_size, long_context_tolerance in MODELS:
|
import multiprocessing as mp
|
||||||
|
|
||||||
|
try:
|
||||||
|
mp.set_start_method("spawn")
|
||||||
|
except RuntimeError:
|
||||||
|
pass
|
||||||
|
|
||||||
|
for (
|
||||||
|
model,
|
||||||
|
tp_size,
|
||||||
|
long_context_tolerance,
|
||||||
|
prefill_tolerance,
|
||||||
|
rouge_threshold,
|
||||||
|
) in MODELS:
|
||||||
for torch_dtype in TORCH_DTYPES:
|
for torch_dtype in TORCH_DTYPES:
|
||||||
max_new_tokens = 8
|
max_new_tokens = 8
|
||||||
self.assert_close_prefill_logits_and_output_strs(
|
self.assert_close_prefill_logits_and_output_strs(
|
||||||
@@ -75,6 +130,8 @@ class TestGenerationModels(unittest.TestCase):
|
|||||||
tp_size,
|
tp_size,
|
||||||
torch_dtype,
|
torch_dtype,
|
||||||
max_new_tokens,
|
max_new_tokens,
|
||||||
|
prefill_tolerance=prefill_tolerance,
|
||||||
|
rouge_threshold=rouge_threshold,
|
||||||
long_context_tolerance=long_context_tolerance,
|
long_context_tolerance=long_context_tolerance,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@@ -5,6 +5,9 @@ from sglang.test.test_utils import run_unittest_files
|
|||||||
|
|
||||||
suites = {
|
suites = {
|
||||||
"minimal": [
|
"minimal": [
|
||||||
|
"models/test_embedding_models.py",
|
||||||
|
"models/test_generation_models.py",
|
||||||
|
"sampling/penaltylib",
|
||||||
"test_chunked_prefill.py",
|
"test_chunked_prefill.py",
|
||||||
"test_embedding_openai_server.py",
|
"test_embedding_openai_server.py",
|
||||||
"test_eval_accuracy_mini.py",
|
"test_eval_accuracy_mini.py",
|
||||||
@@ -13,11 +16,8 @@ suites = {
|
|||||||
"test_skip_tokenizer_init.py",
|
"test_skip_tokenizer_init.py",
|
||||||
"test_torch_compile.py",
|
"test_torch_compile.py",
|
||||||
"test_triton_attn_backend.py",
|
"test_triton_attn_backend.py",
|
||||||
"test_vision_openai_server.py",
|
|
||||||
"test_update_weights.py",
|
"test_update_weights.py",
|
||||||
"models/test_generation_models.py",
|
"test_vision_openai_server.py",
|
||||||
"models/test_embedding_models.py",
|
|
||||||
"sampling/penaltylib",
|
|
||||||
],
|
],
|
||||||
"sampling/penaltylib": glob.glob(
|
"sampling/penaltylib": glob.glob(
|
||||||
"sampling/penaltylib/**/test_*.py", recursive=True
|
"sampling/penaltylib/**/test_*.py", recursive=True
|
||||||
|
|||||||
Reference in New Issue
Block a user