Add model accuracy test - step 1 (#866)
.github/workflows/unit-test.yml (vendored, 2 additions)
@@ -35,6 +35,7 @@ jobs:
         pip install -e "python[all]"
         pip install flashinfer -i https://flashinfer.ai/whl/cu121/torch2.3/ --force-reinstall
         pip install --upgrade transformers
+        pip install accelerate

     - name: Test Frontend Language with SRT Backend
       run: |
@@ -50,6 +51,7 @@ jobs:
       run: |
         cd test/srt
         python3 test_eval_accuracy.py
+        python3 models/test_causal_models.py

     - name: Test Frontend Language with OpenAI Backend
       run: |
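The two workflow additions install accelerate (presumably for the `low_cpu_mem_usage=True` load path in the HuggingFace runner below) and run the new model accuracy test right after the existing eval-accuracy test. A minimal sketch of reproducing that CI step locally, assuming the repository root as the working directory and the test dependencies installed:

    # Sketch only: mirrors the new CI step from an sglang checkout.
    import subprocess

    subprocess.run(
        ["python3", "models/test_causal_models.py"],
        cwd="test/srt",
        check=True,  # surface a non-zero exit, as CI would
    )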
python/sglang/srt/server.py
@@ -28,7 +28,7 @@ import sys
 import threading
 import time
 from http import HTTPStatus
-from typing import Dict, Optional
+from typing import Dict, List, Optional, Union

 # Fix a bug of Python threading
 setattr(threading, "_register_atexit", lambda *args, **kwargs: None)
@@ -481,10 +481,10 @@ class Runtime:
             trust_remote_code=self.server_args.trust_remote_code,
         )

-    async def add_request(
+    async def async_generate(
         self,
         prompt: str,
-        sampling_params: Dict,
+        sampling_params: Optional[Dict] = None,
     ):
         json_data = {
             "text": prompt,
@@ -507,5 +507,26 @@ class Runtime:
                     yield cur
                     pos += len(cur)

+    add_request = async_generate
+
+    def generate(
+        self,
+        prompt: str,
+        sampling_params: Optional[Dict] = None,
+        return_logprob: Optional[Union[List[bool], bool]] = False,
+        top_logprobs_num: Optional[Union[List[int], int]] = None,
+    ):
+        json_data = {
+            "text": prompt,
+            "sampling_params": sampling_params,
+            "return_logprob": return_logprob,
+            "top_logprobs_num": top_logprobs_num,
+        }
+        response = requests.post(
+            self.url + "/generate",
+            json=json_data,
+        )
+        return json.dumps(response.json())
+
     def __del__(self):
         self.shutdown()
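`Runtime.generate` is a new blocking counterpart to `async_generate` (the old `add_request` name is kept as an alias): it POSTs to the server's `/generate` endpoint and returns the response JSON re-serialized as a string. A minimal usage sketch, assuming a local GPU; the model path and sampling values are illustrative:

    import json

    from sglang.srt.server import Runtime

    runtime = Runtime(model_path="meta-llama/Llama-2-7b-hf")  # illustrative path
    response = json.loads(
        runtime.generate(
            "The capital of France is",
            sampling_params={"max_new_tokens": 16, "temperature": 0},
            return_logprob=True,
            top_logprobs_num=5,
        )
    )
    print(response["text"])  # generated continuation
    print(response["meta_info"]["input_top_logprobs"])  # per-token top logprobs
    runtime.shutdown()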
python/sglang/test/runners.py (new file, 237 lines)

@@ -0,0 +1,237 @@
+"""
+Copyright 2023-2024 SGLang Team
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+"""
+
+import json
+import multiprocessing
+from dataclasses import dataclass
+from typing import List, Union
+
+import torch
+import torch.nn.functional as F
+from transformers import AutoModelForCausalLM, AutoTokenizer
+
+from sglang.srt.server import Runtime
+
+DEFAULT_PROMPTS = [
+    "The capital of France is",
+    "The capital of the United Kingdom is",
+    "Today is a sunny day and I like",
+]
+
+NUM_TOP_LOGPROBS = 5
+
+
+def is_embedding_model(model_path):
+    # FIXME: incomplete list
+    if "e5-mistral-7b-instruct" in model_path.lower():
+        return True
+    return False
+
+
+# Module-level alias: the runners below take an `is_embedding_model` keyword
+# argument that shadows the helper above, so they call it through this name.
+_is_embedding_model = is_embedding_model
+
+
+def get_dtype_str(torch_dtype):
+    if torch_dtype is torch.float16:
+        return "float16"
+    else:
+        raise NotImplementedError()
+
+
+@dataclass
+class ModelOutput:
+    output_strs: List[str] = None
+    top_input_logprobs: List = None
+    top_output_logprobs: List = None
+    embed_logits: List = None
+
+
+class HFRunner:
+    def __init__(
+        self,
+        model_path,
+        torch_dtype=torch.float16,
+        is_embedding_model=None,
+    ):
+        self.in_queue = multiprocessing.Queue()
+        self.out_queue = multiprocessing.Queue()
+
+        self.model_proc = multiprocessing.Process(
+            target=self.start_model_process,
+            args=(
+                self.in_queue,
+                self.out_queue,
+                model_path,
+                torch_dtype,
+                is_embedding_model,
+            ),
+        )
+        self.model_proc.start()
+
+    def start_model_process(
+        self, in_queue, out_queue, model_path, torch_dtype, is_embedding_model
+    ):
+        self.tokenizer = AutoTokenizer.from_pretrained(
+            model_path,
+            torch_dtype=torch_dtype,
+            trust_remote_code=True,
+        )
+
+        self.is_embedding_model = (
+            _is_embedding_model(model_path)
+            if is_embedding_model is None
+            else is_embedding_model
+        )
+        if not self.is_embedding_model:
+            self.model = AutoModelForCausalLM.from_pretrained(
+                model_path,
+                torch_dtype=torch_dtype,
+                low_cpu_mem_usage=True,
+                trust_remote_code=True,
+            ).cuda()
+        else:
+            from sentence_transformers import SentenceTransformer
+
+            self.model = SentenceTransformer(
+                model_path,
+                device="cpu",
+            ).to(dtype=torch_dtype)
+
+        while True:
+            prompts, max_new_tokens = in_queue.get()
+            if prompts is not None:
+                if not self.is_embedding_model:
+                    output_strs = []
+                    prefill_logprobs = []
+                    for p in prompts:
+                        if isinstance(p, str):
+                            input_ids = self.tokenizer.encode(
+                                p, return_tensors="pt"
+                            ).cuda()
+                        else:
+                            input_ids = torch.tensor([p], device="cuda")
+
+                        output_ids = self.model.generate(
+                            input_ids, do_sample=False, max_new_tokens=max_new_tokens
+                        )
+                        output_strs.append(self.tokenizer.decode(output_ids[0]))
+
+                        logits = self.model.forward(input_ids).logits[0]
+                        logprobs = F.log_softmax(
+                            logits, dim=-1, dtype=torch.float32
+                        ).tolist()
+                        # Keep only the top-k logprobs for each prompt token.
+                        logprobs = [
+                            sorted(token_logprobs, reverse=True)[:NUM_TOP_LOGPROBS]
+                            for token_logprobs in logprobs
+                        ]
+                        prefill_logprobs.append(logprobs)
+
+                    out_queue.put(
+                        ModelOutput(
+                            output_strs=output_strs, top_input_logprobs=prefill_logprobs
+                        )
+                    )
+
+                else:
+                    assert isinstance(prompts, list)
+                    logits = self.model.encode(prompts).tolist()
+
+                    out_queue.put(ModelOutput(embed_logits=logits))
+
+    def forward(
+        self,
+        prompts: Union[List[str], List[torch.Tensor]] = DEFAULT_PROMPTS,
+        max_new_tokens=64,
+    ):
+        self.in_queue.put((prompts, max_new_tokens))
+        return self.out_queue.get()
+
+    def terminate(self):
+        self.model_proc.terminate()
+        self.in_queue = self.out_queue = None
+
+    def __enter__(self):
+        return self
+
+    def __exit__(self, exc_type, exc_value, traceback):
+        self.model_proc.terminate()
+        self.in_queue = self.out_queue = None
+
+
+class SRTRunner:
+    def __init__(
+        self,
+        model_path,
+        tp_size=1,
+        torch_dtype=torch.float16,
+        is_embedding_model=None,
+    ):
+        self.is_embedding_model = (
+            _is_embedding_model(model_path)
+            if is_embedding_model is None
+            else is_embedding_model
+        )
+        if self.is_embedding_model:
+            raise NotImplementedError()
+
+        self.runtime = Runtime(
+            model_path=model_path,
+            tp_size=tp_size,
+            dtype=get_dtype_str(torch_dtype),
+        )
+
+    def forward(
+        self,
+        prompts: Union[List[str], List[torch.Tensor]] = DEFAULT_PROMPTS,
+        max_new_tokens=64,
+    ):
+        # The return value contains the top logprobs from prefill.
+        output_strs = []
+        top_input_logprobs = []
+        sampling_params = {"max_new_tokens": max_new_tokens, "temperature": 0}
+        for prompt in prompts:
+            response = self.runtime.generate(
+                prompt,
+                sampling_params=sampling_params,
+                return_logprob=True,
+                top_logprobs_num=NUM_TOP_LOGPROBS,
+            )
+            response = json.loads(response)
+            output_strs.append(response["text"])
+            top_input_logprobs.append(
+                [
+                    [tup[0] for tup in x[:NUM_TOP_LOGPROBS]]
+                    for x in response["meta_info"]["input_top_logprobs"][1:]
+                ]
+                + [
+                    [
+                        tup[0]
+                        for tup in response["meta_info"]["output_top_logprobs"][0][
+                            :NUM_TOP_LOGPROBS
+                        ]
+                    ]
+                ]
+            )
+
+        return ModelOutput(
+            output_strs=output_strs, top_input_logprobs=top_input_logprobs
+        )
+
+    def __enter__(self):
+        return self
+
+    def __exit__(self, exc_type, exc_value, traceback):
+        self.runtime.shutdown()
+        del self.runtime
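`HFRunner` hosts the HuggingFace reference model in a child process fed through queues (terminating the process reliably frees its GPU memory between runners) and extracts per-token top-5 prefill logprobs via `log_softmax`; `SRTRunner` collects the same quantities through `Runtime.generate`. A minimal sketch of comparing the two on the default prompts, assuming a GPU machine and an illustrative model path:

    import torch

    from sglang.test.runners import DEFAULT_PROMPTS, HFRunner, SRTRunner

    MODEL = "meta-llama/Llama-2-7b-hf"  # illustrative

    with HFRunner(MODEL, torch_dtype=torch.float16, is_embedding_model=False) as hf:
        hf_out = hf.forward(DEFAULT_PROMPTS, max_new_tokens=8)

    with SRTRunner(
        MODEL, tp_size=1, torch_dtype=torch.float16, is_embedding_model=False
    ) as srt:
        srt_out = srt.forward(DEFAULT_PROMPTS, max_new_tokens=8)

    print(hf_out.output_strs[0])
    print(srt_out.output_strs[0])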
test/srt/models/test_causal_models.py (new file, 67 lines)

@@ -0,0 +1,67 @@
+"""
+Copyright 2023-2024 SGLang Team
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+http://www.apache.org/licenses/LICENSE-2.0
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+"""
+
+import unittest
+
+import torch
+
+from sglang.test.runners import DEFAULT_PROMPTS, HFRunner, SRTRunner
+
+MODELS = [
+    ("meta-llama/Meta-Llama-3.1-8B-Instruct", 1),
+    # ("meta-llama/Meta-Llama-3.1-8B-Instruct", 2),
+]
+TORCH_DTYPES = [torch.float16]
+
+
+class TestCausalModels(unittest.TestCase):
+
+    def assert_close_prefill_logits(
+        self,
+        prompts,
+        model_path,
+        tp_size,
+        torch_dtype,
+    ) -> None:
+        with HFRunner(
+            model_path, torch_dtype=torch_dtype, is_embedding_model=False
+        ) as hf_runner:
+            hf_outputs = hf_runner.forward(prompts)
+
+        with SRTRunner(
+            model_path,
+            tp_size=tp_size,
+            torch_dtype=torch_dtype,
+            is_embedding_model=False,
+        ) as srt_runner:
+            srt_outputs = srt_runner.forward(prompts)
+
+        for i in range(len(prompts)):
+            hf_logprobs = torch.Tensor(hf_outputs.top_input_logprobs[i])
+            srt_logprobs = torch.Tensor(srt_outputs.top_input_logprobs[i])
+
+            tolerance = 2e-2
+            assert torch.all(
+                abs(hf_logprobs - srt_logprobs) < tolerance
+            ), "prefill logprobs are not all close"
+
+    def test_prefill_logits(self):
+        for model, tp_size in MODELS:
+            for torch_dtype in TORCH_DTYPES:
+                self.assert_close_prefill_logits(
+                    DEFAULT_PROMPTS, model, tp_size, torch_dtype
+                )
+
+
+if __name__ == "__main__":
+    unittest.main(warnings="ignore")
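The core check is an elementwise absolute-difference comparison of the two runners' top-5 prefill logprobs against a 2e-2 tolerance. In isolation, with made-up numbers standing in for `top_input_logprobs[i]`:

    import torch

    # Illustrative stand-ins for hf_outputs / srt_outputs .top_input_logprobs[i]
    hf_logprobs = torch.tensor([[-0.10, -2.30], [-0.50, -1.90]])
    srt_logprobs = torch.tensor([[-0.11, -2.29], [-0.49, -1.91]])

    tolerance = 2e-2
    assert torch.all(torch.abs(hf_logprobs - srt_logprobs) < tolerance)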