Add model accuracy test - step 1 (#866)

.github/workflows/unit-test.yml (vendored, 2 changes)

@@ -35,6 +35,7 @@ jobs:
         pip install -e "python[all]"
         pip install flashinfer -i https://flashinfer.ai/whl/cu121/torch2.3/ --force-reinstall
         pip install --upgrade transformers
+        pip install accelerate

     - name: Test Frontend Language with SRT Backend
       run: |
@@ -50,6 +51,7 @@ jobs:
       run: |
         cd test/srt
         python3 test_eval_accuracy.py
+        python3 models/test_causal_models.py

     - name: Test Frontend Language with OpenAI Backend
       run: |
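Note: the new `pip install accelerate` line is presumably needed because the HFRunner added below loads models with `low_cpu_mem_usage=True`, which transformers delegates to the accelerate package.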
python/sglang/srt/server.py

@@ -28,7 +28,7 @@ import sys
 import threading
 import time
 from http import HTTPStatus
-from typing import Dict, Optional
+from typing import Dict, List, Optional, Union

 # Fix a bug of Python threading
 setattr(threading, "_register_atexit", lambda *args, **kwargs: None)
@@ -481,10 +481,10 @@ class Runtime:
             trust_remote_code=self.server_args.trust_remote_code,
         )

-    async def add_request(
+    async def async_generate(
         self,
         prompt: str,
-        sampling_params: Dict,
+        sampling_params: Optional[Dict] = None,
     ):
         json_data = {
             "text": prompt,
@@ -507,5 +507,26 @@ class Runtime:
                 yield cur
                 pos += len(cur)

+    add_request = async_generate
+
+    def generate(
+        self,
+        prompt: str,
+        sampling_params: Optional[Dict] = None,
+        return_logprob: Optional[Union[List[bool], bool]] = False,
+        top_logprobs_num: Optional[Union[List[int], int]] = None,
+    ):
+        json_data = {
+            "text": prompt,
+            "sampling_params": sampling_params,
+            "return_logprob": return_logprob,
+            "top_logprobs_num": top_logprobs_num,
+        }
+        response = requests.post(
+            self.url + "/generate",
+            json=json_data,
+        )
+        return json.dumps(response.json())
+
     def __del__(self):
         self.shutdown()
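For context, a minimal usage sketch of the two Runtime entry points touched by this hunk. The model path is illustrative and the snippet assumes a machine that can actually load the model; everything else follows the signatures shown in the diff:

    import asyncio
    import json

    from sglang.srt.server import Runtime

    runtime = Runtime(model_path="meta-llama/Meta-Llama-3.1-8B-Instruct")

    # Synchronous path: generate() posts to /generate and returns a JSON string.
    out = json.loads(
        runtime.generate(
            "The capital of France is",
            sampling_params={"max_new_tokens": 16, "temperature": 0},
            return_logprob=True,
            top_logprobs_num=5,
        )
    )
    print(out["text"])

    # Streaming path: async_generate (formerly add_request) yields text chunks.
    async def stream():
        async for chunk in runtime.async_generate(
            "Today is a sunny day and I like", {"max_new_tokens": 16}
        ):
            print(chunk, end="", flush=True)

    asyncio.run(stream())
    runtime.shutdown()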
python/sglang/test/runners.py (new file, 237 lines)

@@ -0,0 +1,237 @@
"""
|
||||
Copyright 2023-2024 SGLang Team
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
"""
|
||||
|
||||
import json
import multiprocessing
from dataclasses import dataclass
from typing import List, Union

import torch
import torch.nn.functional as F
from transformers import AutoModelForCausalLM, AutoTokenizer

from sglang.srt.server import Runtime

DEFAULT_PROMPTS = [
    "The capital of France is",
    "The capital of the United Kingdom is",
    "Today is a sunny day and I like",
]

NUM_TOP_LOGPROBS = 5


# The "_" prefix avoids shadowing by the is_embedding_model keyword
# arguments of HFRunner and SRTRunner below.
def _is_embedding_model(model_path):
    # FIXME: incomplete list
    if "e5-mistral-7b-instruct" in model_path.lower():
        return True
    return False


def get_dtype_str(torch_dtype):
    if torch_dtype is torch.float16:
        return "float16"
    else:
        raise NotImplementedError()


@dataclass
class ModelOutput:
    output_strs: List[str] = None
    top_input_logprobs: List = None
    top_output_logprobs: List = None
    embed_logits: List = None


class HFRunner:
    def __init__(
        self,
        model_path,
        torch_dtype=torch.float16,
        is_embedding_model=None,
    ):
        self.in_queue = multiprocessing.Queue()
        self.out_queue = multiprocessing.Queue()

        self.model_proc = multiprocessing.Process(
            target=self.start_model_process,
            args=(
                self.in_queue,
                self.out_queue,
                model_path,
                torch_dtype,
                is_embedding_model,
            ),
        )
        self.model_proc.start()

    def start_model_process(
        self, in_queue, out_queue, model_path, torch_dtype, is_embedding_model
    ):
        self.tokenizer = AutoTokenizer.from_pretrained(
            model_path,
            torch_dtype=torch_dtype,
            trust_remote_code=True,
        )

        # Fall back to the module-level heuristic when no flag is given.
        self.is_embedding_model = (
            _is_embedding_model(model_path)
            if is_embedding_model is None
            else is_embedding_model
        )
        if not self.is_embedding_model:
            self.model = AutoModelForCausalLM.from_pretrained(
                model_path,
                torch_dtype=torch_dtype,
                low_cpu_mem_usage=True,
                trust_remote_code=True,
            ).cuda()
        else:
            from sentence_transformers import SentenceTransformer

            self.model = SentenceTransformer(
                model_path,
                device="cpu",
            ).to(dtype=torch_dtype)

        while True:
            prompts, max_new_tokens = in_queue.get()
            if prompts is not None:
                if not self.is_embedding_model:
                    output_strs = []
                    prefill_logprobs = []
                    for p in prompts:
                        if isinstance(p, str):
                            input_ids = self.tokenizer.encode(
                                p, return_tensors="pt"
                            ).cuda()
                        else:
                            input_ids = torch.tensor([p], device="cuda")

                        output_ids = self.model.generate(
                            input_ids, do_sample=False, max_new_tokens=max_new_tokens
                        )
                        output_strs.append(self.tokenizer.decode(output_ids[0]))

                        logits = self.model.forward(input_ids).logits[0]
                        logprobs = F.log_softmax(
                            logits, dim=-1, dtype=torch.float32
                        ).tolist()
                        # Keep only the top-k logprobs at each prefill position.
                        logprobs = [
                            sorted(token_logprobs, reverse=True)[:NUM_TOP_LOGPROBS]
                            for token_logprobs in logprobs
                        ]
                        prefill_logprobs.append(logprobs)

                    out_queue.put(
                        ModelOutput(
                            output_strs=output_strs, top_input_logprobs=prefill_logprobs
                        )
                    )

                else:
                    # isinstance() rejects subscripted generics such as
                    # List[str], so check the element types explicitly.
                    assert isinstance(prompts, list) and all(
                        isinstance(p, str) for p in prompts
                    )
                    logits = self.model.encode(prompts).tolist()

                    out_queue.put(ModelOutput(embed_logits=logits))

    def forward(
        self,
        prompts: Union[List[str], List[torch.Tensor]] = DEFAULT_PROMPTS,
        max_new_tokens=64,
    ):
        self.in_queue.put((prompts, max_new_tokens))
        return self.out_queue.get()

    def terminate(self):
        self.model_proc.terminate()
        self.in_queue = self.out_queue = None

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        self.model_proc.terminate()
        self.in_queue = self.out_queue = None


class SRTRunner:
    def __init__(
        self,
        model_path,
        tp_size=1,
        torch_dtype=torch.float16,
        is_embedding_model=None,
    ):
        self.is_embedding_model = (
            _is_embedding_model(model_path)
            if is_embedding_model is None
            else is_embedding_model
        )
        if self.is_embedding_model:
            raise NotImplementedError()

        self.runtime = Runtime(
            model_path=model_path,
            tp_size=tp_size,
            dtype=get_dtype_str(torch_dtype),
        )

    def forward(
        self,
        prompts: Union[List[str], List[torch.Tensor]] = DEFAULT_PROMPTS,
        max_new_tokens=64,
    ):
        # The return value contains the logprobs from the prefill pass.
        output_strs = []
        top_input_logprobs = []
        sampling_params = {"max_new_tokens": max_new_tokens, "temperature": 0}
        for prompt in prompts:
            response = self.runtime.generate(
                prompt,
                sampling_params=sampling_params,
                return_logprob=True,
                top_logprobs_num=NUM_TOP_LOGPROBS,
            )
            response = json.loads(response)
            output_strs.append(response["text"])
            # Skip the first prefill position (it has no logprob) and append
            # the top logprobs of the first decoded token, so the shape lines
            # up with the HF prefill logprobs.
            top_input_logprobs.append(
                [
                    [tup[0] for tup in x[:NUM_TOP_LOGPROBS]]
                    for x in response["meta_info"]["input_top_logprobs"][1:]
                ]
                + [
                    [
                        tup[0]
                        for tup in response["meta_info"]["output_top_logprobs"][0][
                            :NUM_TOP_LOGPROBS
                        ]
                    ]
                ]
            )

        return ModelOutput(
            output_strs=output_strs, top_input_logprobs=top_input_logprobs
        )

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        self.runtime.shutdown()
        del self.runtime
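A note on the HFRunner logprob extraction above: sorting each full-vocabulary list in Python is simple but slow for large vocabularies. A sketch of an equivalent torch.topk formulation (torch.topk returns values in descending order, so the result matches sorted(token_logprobs, reverse=True)[:NUM_TOP_LOGPROBS] up to ties; the shapes here are stand-ins):

    import torch
    import torch.nn.functional as F

    NUM_TOP_LOGPROBS = 5
    logits = torch.randn(7, 32000)  # stand-in for [seq_len, vocab] prefill logits
    logprobs = F.log_softmax(logits, dim=-1, dtype=torch.float32)
    # Values come back sorted in descending order along the last dimension.
    top_vals, top_ids = torch.topk(logprobs, NUM_TOP_LOGPROBS, dim=-1)
    print(top_vals.shape)  # torch.Size([7, 5])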
test/srt/models/test_causal_models.py (new file, 67 lines)

@@ -0,0 +1,67 @@
"""
|
||||
Copyright 2023-2024 SGLang Team
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
"""
|
||||
|
||||
import unittest

import torch

from sglang.test.runners import DEFAULT_PROMPTS, HFRunner, SRTRunner

MODELS = [
    ("meta-llama/Meta-Llama-3.1-8B-Instruct", 1),
    # ("meta-llama/Meta-Llama-3.1-8B-Instruct", 2),
]
TORCH_DTYPES = [torch.float16]


class TestCausalModels(unittest.TestCase):

    def assert_close_prefill_logits(
        self,
        prompts,
        model_path,
        tp_size,
        torch_dtype,
    ) -> None:
        with HFRunner(
            model_path, torch_dtype=torch_dtype, is_embedding_model=False
        ) as hf_runner:
            hf_outputs = hf_runner.forward(prompts)

        with SRTRunner(
            model_path,
            tp_size=tp_size,
            torch_dtype=torch_dtype,
            is_embedding_model=False,
        ) as srt_runner:
            srt_outputs = srt_runner.forward(prompts)

        for i in range(len(prompts)):
            hf_logprobs = torch.Tensor(hf_outputs.top_input_logprobs[i])
            srt_logprobs = torch.Tensor(srt_outputs.top_input_logprobs[i])

            tolerance = 2e-2
            assert torch.all(
                abs(hf_logprobs - srt_logprobs) < tolerance
            ), "prefill logprobs are not all close"

    def test_prefill_logits(self):
        for model, tp_size in MODELS:
            for torch_dtype in TORCH_DTYPES:
                self.assert_close_prefill_logits(
                    DEFAULT_PROMPTS, model, tp_size, torch_dtype
                )


if __name__ == "__main__":
    unittest.main(warnings="ignore")
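When the 2e-2 tolerance check in this test fails, the bare assert reports little. A small hypothetical helper like the following (not part of the commit) can surface the worst deviation before the assert fires:

    import torch

    def max_logprob_gap(hf_logprobs: torch.Tensor, srt_logprobs: torch.Tensor) -> float:
        """Return the largest absolute difference between two logprob tensors."""
        return (hf_logprobs - srt_logprobs).abs().max().item()

    # Example with a gap well below the test's 2e-2 tolerance:
    a = torch.tensor([[-0.10, -2.31], [-0.05, -3.02]])
    b = torch.tensor([[-0.11, -2.30], [-0.06, -3.01]])
    print(max_logprob_gap(a, b))  # ~0.01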