diff --git a/.github/workflows/unit-test.yml b/.github/workflows/unit-test.yml
index 140e874b6..8b7f88e8b 100644
--- a/.github/workflows/unit-test.yml
+++ b/.github/workflows/unit-test.yml
@@ -35,6 +35,7 @@ jobs:
         pip install -e "python[all]"
         pip install flashinfer -i https://flashinfer.ai/whl/cu121/torch2.3/ --force-reinstall
         pip install --upgrade transformers
+        pip install accelerate
 
     - name: Test Frontend Language with SRT Backend
       run: |
@@ -50,6 +51,7 @@ jobs:
       run: |
         cd test/srt
         python3 test_eval_accuracy.py
+        python3 models/test_causal_models.py
 
     - name: Test Frontend Language with OpenAI Backend
       run: |
diff --git a/python/sglang/srt/server.py b/python/sglang/srt/server.py
index 9ac18206c..18ff22432 100644
--- a/python/sglang/srt/server.py
+++ b/python/sglang/srt/server.py
@@ -28,7 +28,7 @@ import sys
 import threading
 import time
 from http import HTTPStatus
-from typing import Dict, Optional
+from typing import Dict, List, Optional, Union
 
 # Fix a bug of Python threading
 setattr(threading, "_register_atexit", lambda *args, **kwargs: None)
@@ -481,10 +481,10 @@ class Runtime:
             trust_remote_code=self.server_args.trust_remote_code,
         )
 
-    async def add_request(
+    async def async_generate(
         self,
         prompt: str,
-        sampling_params: Dict,
+        sampling_params: Optional[Dict] = None,
     ):
         json_data = {
             "text": prompt,
@@ -507,5 +507,26 @@
                             yield cur
                         pos += len(cur)
 
+    add_request = async_generate
+
+    def generate(
+        self,
+        prompt: str,
+        sampling_params: Optional[Dict] = None,
+        return_logprob: Optional[Union[List[bool], bool]] = False,
+        top_logprobs_num: Optional[Union[List[int], int]] = None,
+    ):
+        json_data = {
+            "text": prompt,
+            "sampling_params": sampling_params,
+            "return_logprob": return_logprob,
+            "top_logprobs_num": top_logprobs_num,
+        }
+        response = requests.post(
+            self.url + "/generate",
+            json=json_data,
+        )
+        return json.dumps(response.json())
+
     def __del__(self):
         self.shutdown()
diff --git a/python/sglang/test/runners.py b/python/sglang/test/runners.py
new file mode 100644
index 000000000..3a8cff213
--- /dev/null
+++ b/python/sglang/test/runners.py
@@ -0,0 +1,237 @@
+"""
+Copyright 2023-2024 SGLang Team
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+"""
+
+import json
+import multiprocessing
+from dataclasses import dataclass
+from typing import List, Union
+
+import torch
+import torch.nn.functional as F
+from transformers import AutoModelForCausalLM, AutoTokenizer
+
+from sglang.srt.server import Runtime
+
+DEFAULT_PROMPTS = [
+    "The capital of France is",
+    "The capital of the United Kingdom is",
+    "Today is a sunny day and I like",
+]
+
+NUM_TOP_LOGPROBS = 5
+
+
+def _is_embedding_model(model_path):
+    # FIXME: incomplete list
+    if "e5-mistral-7b-instruct" in model_path.lower():
+        return True
+    return False
+
+
+def get_dtype_str(torch_dtype):
+    if torch_dtype is torch.float16:
+        return "float16"
+    else:
+        raise NotImplementedError()
+
+
+@dataclass
+class ModelOutput:
+    output_strs: str = None
+    top_input_logprobs: torch.Tensor = None
+    top_output_logprobs: torch.Tensor = None
+    embed_logits: torch.Tensor = None
+
+
+class HFRunner:
+    def __init__(
+        self,
+        model_path,
+        torch_dtype=torch.float16,
+        is_embedding_model=None,
+    ):
+        self.in_queue = multiprocessing.Queue()
+        self.out_queue = multiprocessing.Queue()
+
+        self.model_proc = multiprocessing.Process(
+            target=self.start_model_process,
+            args=(
+                self.in_queue,
+                self.out_queue,
+                model_path,
+                torch_dtype,
+                is_embedding_model,
+            ),
+        )
+        self.model_proc.start()
+
+    def start_model_process(
+        self, in_queue, out_queue, model_path, torch_dtype, is_embedding_model
+    ):
+        self.tokenizer = AutoTokenizer.from_pretrained(
+            model_path,
+            torch_dtype=torch_dtype,
+            trust_remote_code=True,
+        )
+
+        self.is_embedding_model = (
+            _is_embedding_model(model_path)
+            if is_embedding_model is None
+            else is_embedding_model
+        )
+        if not self.is_embedding_model:
+            self.model = AutoModelForCausalLM.from_pretrained(
+                model_path,
+                torch_dtype=torch_dtype,
+                low_cpu_mem_usage=True,
+                trust_remote_code=True,
+            ).cuda()
+        else:
+            from sentence_transformers import SentenceTransformer
+
+            self.model = SentenceTransformer(
+                model_path,
+                device="cpu",
+            ).to(dtype=torch_dtype)
+
+        while True:
+            prompts, max_new_tokens = in_queue.get()
+            if prompts is not None:
+                if not self.is_embedding_model:
+                    output_strs = []
+                    prefill_logprobs = []
+                    for p in prompts:
+                        if isinstance(p, str):
+                            input_ids = self.tokenizer.encode(
+                                p, return_tensors="pt"
+                            ).cuda()
+                        else:
+                            input_ids = torch.tensor([p], device="cuda")
+
+                        output_ids = self.model.generate(
+                            input_ids, do_sample=False, max_new_tokens=max_new_tokens
+                        )
+                        output_strs.append(self.tokenizer.decode(output_ids[0]))
+
+                        logits = self.model.forward(input_ids).logits[0]
+                        logprobs = F.log_softmax(
+                            logits, dim=-1, dtype=torch.float32
+                        ).tolist()
+                        # index_of_max = (lambda nums: nums.index(max(nums)))(logprobs[-1])
+                        # print("index", index_of_max)
+                        logprobs = [
+                            sorted(token_logprobs, reverse=True)[:NUM_TOP_LOGPROBS]
+                            for token_logprobs in logprobs
+                        ]
+                        prefill_logprobs.append(logprobs)
+
+                    out_queue.put(
+                        ModelOutput(
+                            output_strs=output_strs, top_input_logprobs=prefill_logprobs
+                        )
+                    )
+
+                else:
+                    assert all(isinstance(p, str) for p in prompts)
+                    logits = self.model.encode(prompts).tolist()
+
+                    out_queue.put(ModelOutput(embed_logits=logits))
+
+    def forward(
+        self,
+        prompts: Union[List[str], List[torch.Tensor]] = DEFAULT_PROMPTS,
+        max_new_tokens=64,
+    ):
+        self.in_queue.put((prompts, max_new_tokens))
+        return self.out_queue.get()
+
+    def terminate(self):
+        self.model_proc.terminate()
+        self.in_queue = self.out_queue = None
+
+    def __enter__(self):
+        return self
+
+    def __exit__(self, exc_type, exc_value, traceback):
+        self.model_proc.terminate()
+        self.in_queue = self.out_queue = None
+
+
+class SRTRunner:
+    def __init__(
+        self,
+        model_path,
+        tp_size=1,
+        torch_dtype=torch.float16,
+        is_embedding_model=None,
+    ):
+        self.is_embedding_model = (
+            _is_embedding_model(model_path)
+            if is_embedding_model is None
+            else is_embedding_model
+        )
+        if self.is_embedding_model:
+            raise NotImplementedError()
+
+        self.runtime = Runtime(
+            model_path=model_path,
+            tp_size=tp_size,
+            dtype=get_dtype_str(torch_dtype),
+        )
+
+    def forward(
+        self,
+        prompts: Union[List[str], List[torch.Tensor]] = DEFAULT_PROMPTS,
+        max_new_tokens=64,
+    ):
+        # the return value contains logprobs from prefill
+        output_strs = []
+        top_input_logprobs = []
+        sampling_params = {"max_new_tokens": max_new_tokens, "temperature": 0}
+        for prompt in prompts:
+            response = self.runtime.generate(
+                prompt,
+                sampling_params=sampling_params,
+                return_logprob=True,
+                top_logprobs_num=NUM_TOP_LOGPROBS,
+            )
+            response = json.loads(response)
+            output_strs.append(response["text"])
+            top_input_logprobs.append(
+                [
+                    [tup[0] for tup in x[:NUM_TOP_LOGPROBS]]
+                    for x in response["meta_info"]["input_top_logprobs"][1:]
+                ]
+                + [
+                    [
+                        tup[0]
+                        for tup in response["meta_info"]["output_top_logprobs"][0][
+                            :NUM_TOP_LOGPROBS
+                        ]
+                    ]
+                ]
+            )
+            # print(response["meta_info"]["output_top_logprobs"][0])
+
+        return ModelOutput(
+            output_strs=output_strs, top_input_logprobs=top_input_logprobs
+        )
+
+    def __enter__(self):
+        return self
+
+    def __exit__(self, exc_type, exc_value, traceback):
+        self.runtime.shutdown()
+        del self.runtime
diff --git a/test/srt/models/test_causal_models.py b/test/srt/models/test_causal_models.py
new file mode 100644
index 000000000..3cec4490a
--- /dev/null
+++ b/test/srt/models/test_causal_models.py
@@ -0,0 +1,67 @@
+"""
+Copyright 2023-2024 SGLang Team
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+    http://www.apache.org/licenses/LICENSE-2.0
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+"""
+
+import unittest
+
+import torch
+
+from sglang.test.runners import DEFAULT_PROMPTS, HFRunner, SRTRunner
+
+MODELS = [
+    ("meta-llama/Meta-Llama-3.1-8B-Instruct", 1),
+    # ("meta-llama/Meta-Llama-3.1-8B-Instruct", 2),
+]
+TORCH_DTYPES = [torch.float16]
+
+
+class TestCausalModels(unittest.TestCase):
+
+    def assert_close_prefill_logits(
+        self,
+        prompts,
+        model_path,
+        tp_size,
+        torch_dtype,
+    ) -> None:
+        with HFRunner(
+            model_path, torch_dtype=torch_dtype, is_embedding_model=False
+        ) as hf_runner:
+            hf_outputs = hf_runner.forward(prompts)
+
+        with SRTRunner(
+            model_path,
+            tp_size=tp_size,
+            torch_dtype=torch_dtype,
+            is_embedding_model=False,
+        ) as srt_runner:
+            srt_outputs = srt_runner.forward(prompts)
+
+        for i in range(len(prompts)):
+            hf_logprobs = torch.Tensor(hf_outputs.top_input_logprobs[i])
+            srt_logprobs = torch.Tensor(srt_outputs.top_input_logprobs[i])
+
+            tolerance = 2e-2
+            assert torch.all(
+                abs(hf_logprobs - srt_logprobs) < tolerance
+            ), f"prefill logprobs are not close for prompt {i} of {model_path}"
+
+    def test_prefill_logits(self):
+        for model, tp_size in MODELS:
+            for torch_dtype in TORCH_DTYPES:
+                self.assert_close_prefill_logits(
+                    DEFAULT_PROMPTS, model, tp_size, torch_dtype
+                )
+
+
+if __name__ == "__main__":
+    unittest.main(warnings="ignore")
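
Usage sketch (illustrative, not part of the patch itself): the new HFRunner/SRTRunner pair can also be driven outside of unittest to compare prefill logprobs for a single prompt. The model path, prompt, and max_new_tokens below are arbitrary assumptions, and a CUDA GPU plus the dependencies installed by the workflow above are assumed.

    # compare_prefill_logprobs.py -- a minimal sketch, not part of this diff
    import torch
    from sglang.test.runners import HFRunner, SRTRunner

    MODEL = "meta-llama/Meta-Llama-3.1-8B-Instruct"  # assumption: any causal LM path works
    prompts = ["The capital of France is"]

    # HFRunner loads the model with transformers in a subprocess; SRTRunner
    # launches an SRT Runtime and posts to /generate with return_logprob=True.
    with HFRunner(MODEL, is_embedding_model=False) as hf:
        hf_out = hf.forward(prompts, max_new_tokens=8)
    with SRTRunner(MODEL, tp_size=1, is_embedding_model=False) as srt:
        srt_out = srt.forward(prompts, max_new_tokens=8)

    # top_input_logprobs holds the top-5 prefill logprobs per input position
    hf_lp = torch.Tensor(hf_out.top_input_logprobs[0])
    srt_lp = torch.Tensor(srt_out.top_input_logprobs[0])
    print("max abs diff of top prefill logprobs:", (hf_lp - srt_lp).abs().max().item())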