Files
sglang/python/sglang/test/runners.py
2024-11-03 13:27:12 -08:00

393 lines
14 KiB
Python

"""
Copyright 2023-2024 SGLang Team
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import json
import multiprocessing as mp
import os
from dataclasses import dataclass
from typing import List, Optional, Union

import torch
import torch.nn.functional as F
from transformers import AutoModelForCausalLM

from sglang.srt.hf_transformers_utils import get_tokenizer
from sglang.srt.server import Runtime
from sglang.test.test_utils import DEFAULT_PORT_FOR_SRT_TEST_RUNNER
# Prompts used by the runners' forward() methods when none are given.
DEFAULT_PROMPTS = [
    "Apple is red. Banana is Yellow. " * 800 + "Apple is",
    "The capital of the United Kingdom is",
    "Today is a sunny day and I like",
    "AI is a field of computer science focused on",
    # the output of gemma-2-2b from SRT is unstable on the commented prompt
    # "The capital of France is",
]

# Append the contents of long_prompt.txt (stored next to this module) as an
# additional default prompt. The encoding is explicit so the read does not
# depend on the platform's locale default.
dirpath = os.path.dirname(__file__)
with open(os.path.join(dirpath, "long_prompt.txt"), "r", encoding="utf-8") as f:
    long_prompt = f.read()
DEFAULT_PROMPTS.append(long_prompt)

# How many top logprobs per position both runners report for comparison.
NUM_TOP_LOGPROBS = 5
def get_dtype_str(torch_dtype):
    """Map a torch dtype to the dtype string passed to the SRT ``Runtime``.

    Args:
        torch_dtype: A ``torch.dtype`` (float16, bfloat16 or float32).

    Returns:
        The matching dtype name, e.g. ``"float16"``.

    Raises:
        NotImplementedError: If ``torch_dtype`` has no known mapping.
    """
    supported = {
        torch.float16: "float16",
        torch.bfloat16: "bfloat16",
        torch.float32: "float32",
    }
    try:
        return supported[torch_dtype]
    except KeyError:
        raise NotImplementedError(f"Unsupported dtype: {torch_dtype}") from None
def get_top_logprobs(logits, k):
    """Return the ``k`` largest log-probabilities along the last dimension.

    The log-softmax is computed in float32 regardless of the input dtype, so
    half-precision logits are normalized at full precision.

    Args:
        logits: Tensor of unnormalized scores, normalized over its last
            dimension.
        k: Number of top entries to keep.

    Returns:
        A float32 tensor shaped like ``logits`` except that the last dimension
        is ``k``, with values sorted in descending order.
    """
    logprobs = F.log_softmax(logits, dim=-1, dtype=torch.float32)
    return torch.topk(logprobs, k=k, dim=-1).values
@dataclass
class ModelOutput:
    """Result container returned by ``HFRunner`` and ``SRTRunner``.

    Every field defaults to ``None``. Each runner fills in only the fields
    that apply to its model type.
    """

    # Decoded completion text, one per prompt.
    output_strs: Optional[List[str]] = None
    # Generated token ids (not set by HFRunner/SRTRunner).
    output_ids: Optional[List[int]] = None
    # Per prompt: a list over input positions of the top-k logprobs (floats).
    top_input_logprobs: Optional[List[List[List[float]]]] = None
    # Per prompt: a list over generated tokens of the top-k logprobs (floats).
    top_output_logprobs: Optional[List[List[List[float]]]] = None
    # Per prompt: the embedding vector produced by an embedding model.
    embed_logits: Optional[List[List[float]]] = None
    # Per prompt: the scalar score produced by a reward model.
    scores: Optional[List[float]] = None
class HFRunner:
    """Reference runner that serves a HuggingFace model in a child process.

    The constructor starts a ``multiprocessing.Process`` running
    :meth:`start_model_process`. :meth:`forward` sends a
    ``(prompts, max_new_tokens, lora_paths)`` request over ``in_queue`` and
    blocks until the resulting :class:`ModelOutput` arrives on ``out_queue``.

    Args:
        model_path: HuggingFace model id or local path.
        torch_dtype: dtype the model weights are loaded in.
        model_type: ``"generation"``, ``"embedding"`` or ``"reward"``.
        output_str_only: For generation models, return only the decoded
            strings and skip all logprob computation.

    Raises:
        Exception: If ``model_type`` is not one of the supported types.
    """

    def __init__(
        self,
        model_path: str,
        torch_dtype: torch.dtype,
        model_type: str = "generation",
        output_str_only: bool = False,
    ):
        # Validate in the parent. Otherwise an unknown type only fails inside
        # the child process, and forward() blocks forever on out_queue.get().
        if model_type not in ("generation", "embedding", "reward"):
            raise Exception(f"Unrecognized model type {model_type}")
        self.model_type = model_type
        self.output_str_only = output_str_only

        self.in_queue = mp.Queue()
        self.out_queue = mp.Queue()

        self.model_proc = mp.Process(
            target=self.start_model_process,
            args=(
                self.in_queue,
                self.out_queue,
                model_path,
                torch_dtype,
            ),
        )
        self.model_proc.start()

    def needs_trust_remote_code(self, model_path):
        """Return True if ``model_path`` is on the list needing ``trust_remote_code``."""
        models_needs_trust_remote = [
            "LxzGordon/URM-LLaMa-3.1-8B",
        ]
        return model_path in models_needs_trust_remote

    def start_model_process(self, in_queue, out_queue, model_path, torch_dtype):
        """Child-process entry point: load the model, then serve requests forever.

        Each request read from ``in_queue`` is a
        ``(prompts, max_new_tokens, lora_paths)`` tuple. Requests whose
        ``prompts`` is ``None`` are ignored. Every other request produces
        exactly one :class:`ModelOutput` on ``out_queue``.
        """
        # Apply model-specific patches
        monkey_patch_gemma2_sdpa()

        self._load_model(model_path, torch_dtype)
        # Pass the requested dtype (previously the ``torch.dtype`` class
        # itself was passed here by mistake).
        self.tokenizer = get_tokenizer(model_path, torch_dtype=torch_dtype)

        # Inference only: disabling autograd stops the explicit forward passes
        # below from building (and holding memory for) gradient graphs.
        with torch.no_grad():
            while True:
                prompts, max_new_tokens, lora_paths = in_queue.get()
                if prompts is None:
                    continue
                if lora_paths is not None:
                    assert len(prompts) == len(lora_paths)

                if self.model_type == "generation":
                    output = self._generate(
                        prompts, max_new_tokens, lora_paths, torch_dtype
                    )
                elif self.model_type == "embedding":
                    assert not self.output_str_only
                    logits = self.model.encode(prompts).tolist()
                    output = ModelOutput(embed_logits=logits)
                elif self.model_type == "reward":
                    output = self._score(prompts)
                else:
                    raise Exception(f"Unrecognized model type {self.model_type}")
                out_queue.put(output)

    def _load_model(self, model_path, torch_dtype):
        """Load the model for ``self.model_type`` onto the GPU.

        Generation models are stored as ``self.base_model``, so LoRA adapters
        can be layered on top per request. Other types go to ``self.model``.
        """
        trust_remote_code = self.needs_trust_remote_code(model_path)
        if self.model_type == "generation":
            self.base_model = AutoModelForCausalLM.from_pretrained(
                model_path,
                torch_dtype=torch_dtype,
                trust_remote_code=trust_remote_code,
                low_cpu_mem_usage=True,
            ).cuda()
        elif self.model_type == "embedding":
            from sentence_transformers import SentenceTransformer

            self.model = SentenceTransformer(
                model_path,
                model_kwargs={"torch_dtype": torch_dtype},
            ).cuda()
        elif self.model_type == "reward":
            from transformers import AutoModelForSequenceClassification

            self.model = AutoModelForSequenceClassification.from_pretrained(
                model_path,
                torch_dtype=torch_dtype,
                trust_remote_code=trust_remote_code,
            ).cuda()
        else:
            raise Exception(f"Unrecognized model type {self.model_type}")

    def _generate(self, prompts, max_new_tokens, lora_paths, torch_dtype):
        """Greedily decode each prompt and, unless ``output_str_only``, collect top-k logprobs.

        ``prompts`` entries may be strings (tokenized here) or token-id
        sequences. ``lora_paths[i]``, when given and not None, selects a PEFT
        adapter applied on top of ``self.base_model`` for prompt ``i``.
        """
        output_strs = []
        top_input_logprobs = []
        top_output_logprobs = []
        for i, p in enumerate(prompts):
            if isinstance(p, str):
                input_ids = self.tokenizer.encode(p, return_tensors="pt").cuda()
            else:
                input_ids = torch.tensor([p], device="cuda")

            if lora_paths is not None and lora_paths[i] is not None:
                from peft import PeftModel

                self.model = PeftModel.from_pretrained(
                    self.base_model,
                    lora_paths[i],
                    torch_dtype=torch_dtype,
                    is_trainable=False,
                )
            else:
                self.model = self.base_model

            outputs = self.model.generate(
                input_ids,
                do_sample=False,
                temperature=None,
                top_p=None,
                max_new_tokens=max_new_tokens,
                return_dict_in_generate=True,
                output_scores=(not self.output_str_only),
            )
            # Decode only the newly generated tokens, not the echoed prompt.
            output_strs.append(
                self.tokenizer.decode(outputs[0][0][len(input_ids[0]) :])
            )

            if not self.output_str_only:
                # outputs.scores: (num_token, 1, vocab_size)
                top_output_logprobs.append(
                    [
                        get_top_logprobs(logits[0], NUM_TOP_LOGPROBS).tolist()
                        for logits in outputs.scores
                    ]
                )
                del outputs

                # A separate full forward pass yields logits for every prompt
                # position, which generate() does not return.
                input_logits = self.model.forward(input_ids).logits[0]
                top_input_logprobs.append(
                    get_top_logprobs(input_logits, NUM_TOP_LOGPROBS).tolist()
                )
                del input_logits

        return ModelOutput(
            output_strs=output_strs,
            top_input_logprobs=top_input_logprobs,
            top_output_logprobs=top_output_logprobs,
        )

    def _score(self, prompts):
        """Score each conversation in ``prompts`` with the reward model.

        Each conversation is rendered with the tokenizer's chat template. The
        score is element ``[0][0]`` of the model's output logits.
        """
        scores = []
        for conv in prompts:
            conv_formatted = self.tokenizer.apply_chat_template(conv, tokenize=False)
            conv_tokenized = self.tokenizer(
                conv_formatted, return_tensors="pt"
            ).to("cuda")
            scores.append(float(self.model(**conv_tokenized).logits[0][0].item()))
        return ModelOutput(scores=scores)

    def forward(
        self,
        prompts: Union[List[str], List[torch.Tensor]] = DEFAULT_PROMPTS,
        max_new_tokens=8,
        lora_paths=None,
    ):
        """Send one request to the model process and wait for its ModelOutput."""
        self.in_queue.put((prompts, max_new_tokens, lora_paths))
        return self.out_queue.get()

    def terminate(self):
        """Terminate the model process and drop the queues."""
        self.model_proc.terminate()
        self.in_queue = self.out_queue = None

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        self.terminate()
class SRTRunner:
    """Test runner that drives an SGLang ``Runtime`` started in ``__init__``.

    Args:
        model_path: HuggingFace model id or local path.
        torch_dtype: Model dtype, converted with ``get_dtype_str``.
        model_type: ``"generation"`` selects text generation. Any other value
            starts the runtime in embedding mode. ``"embedding"`` returns
            vectors; other non-generation types return scalar scores.
        tp_size: Tensor-parallel size.
        port: Port for the runtime server.
        lora_paths: LoRA adapters to load into the runtime.
        max_loras_per_batch: Maximum number of distinct adapters per batch.
        disable_cuda_graph: Forwarded to the runtime.
        disable_radix_cache: Forwarded to the runtime.
    """

    def __init__(
        self,
        model_path: str,
        torch_dtype: torch.dtype,
        model_type: str,
        tp_size: int = 1,
        port: int = DEFAULT_PORT_FOR_SRT_TEST_RUNNER,
        lora_paths: List[str] = None,
        max_loras_per_batch: int = 4,
        disable_cuda_graph: bool = False,
        disable_radix_cache: bool = False,
    ):
        self.model_type = model_type
        self.is_generation = model_type == "generation"
        self.runtime = Runtime(
            model_path=model_path,
            tp_size=tp_size,
            dtype=get_dtype_str(torch_dtype),
            port=port,
            mem_fraction_static=0.65,
            trust_remote_code=False,
            is_embedding=not self.is_generation,
            lora_paths=lora_paths,
            max_loras_per_batch=max_loras_per_batch,
            disable_cuda_graph=disable_cuda_graph,
            disable_radix_cache=disable_radix_cache,
        )
        self.tokenizer = get_tokenizer(model_path)

    def forward(
        self,
        prompts: Union[List[str], List[torch.Tensor]] = DEFAULT_PROMPTS,
        max_new_tokens=8,
        lora_paths=None,
    ):
        """Send prompts one at a time, collecting text and top-k logprobs.

        For non-generation models the whole list is sent to ``encode`` in a
        single call instead.
        """
        if not self.is_generation:
            return self._encode(prompts)

        def top_values(entries):
            # Keep field 0 of each of the first NUM_TOP_LOGPROBS entries
            # (presumably the logprob value).
            return [tup[0] for tup in entries[:NUM_TOP_LOGPROBS]]

        # the return value contains logprobs from prefill
        output_strs = []
        top_input_logprobs = []
        top_output_logprobs = []
        sampling_params = {"max_new_tokens": max_new_tokens, "temperature": 0}
        for i, prompt in enumerate(prompts):
            response = self.runtime.generate(
                prompt,
                lora_path=lora_paths[i] if lora_paths else None,
                sampling_params=sampling_params,
                return_logprob=True,
                logprob_start_len=0,
                top_logprobs_num=NUM_TOP_LOGPROBS,
            )
            response = json.loads(response)
            meta_info = response["meta_info"]
            output_strs.append(response["text"])
            # Skip input position 0 and append the first output position.
            # This presumably aligns with HFRunner's per-position next-token
            # logprobs over the prompt.
            top_input_logprobs.append(
                [top_values(x) for x in meta_info["input_top_logprobs"][1:]]
                + [top_values(meta_info["output_top_logprobs"][0])]
            )
            top_output_logprobs.append(
                [top_values(x) for x in meta_info["output_top_logprobs"]]
            )

        return ModelOutput(
            output_strs=output_strs,
            top_input_logprobs=top_input_logprobs,
            top_output_logprobs=top_output_logprobs,
        )

    def batch_forward(
        self,
        prompts: Union[List[str], List[torch.Tensor]] = DEFAULT_PROMPTS,
        max_new_tokens=8,
        lora_paths=None,
    ):
        """
        testing serving by sending all prompts once
        only return output strings and no logprobs
        """
        if not self.is_generation:
            return self._encode(prompts)

        sampling_params = {"max_new_tokens": max_new_tokens, "temperature": 0}
        response = self.runtime.generate(
            prompts,
            lora_path=lora_paths or None,
            sampling_params=sampling_params,
        )
        response = json.loads(response)
        return ModelOutput(output_strs=[r["text"] for r in response])

    def _encode(self, prompts):
        """Send ``prompts`` to ``runtime.encode`` and wrap the result.

        ``"embedding"`` models return the full ``"embedding"`` field per
        prompt. Any other non-generation type returns element 0 of that field
        as a score.
        """
        response = json.loads(self.runtime.encode(prompts))
        if self.model_type == "embedding":
            return ModelOutput(embed_logits=[x["embedding"] for x in response])
        return ModelOutput(scores=[x["embedding"][0] for x in response])

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        self.runtime.shutdown()
        del self.runtime
def monkey_patch_gemma2_sdpa():
    """Make Gemma2 models use SDPA attention by default, to fix an OOM issue.

    Effectively reverts this transformers commit:
    https://github.com/huggingface/transformers/commit/975b988bfe6e7ebb47390cd9a1556c6888804883#diff-5f76eac6f18f4b491521314c318a9692318feb4d19228e9576cce7bde4240834R660
    """
    from transformers.models.gemma2.modeling_gemma2 import Gemma2PreTrainedModel

    def _check_and_enable_sdpa(config, hard_check_only: bool = False):
        # Always select SDPA. ``hard_check_only`` is accepted but ignored,
        # presumably so the replacement stays call-compatible with the
        # original method.
        config._attn_implementation = "sdpa"
        return config

    # Replace the class attribute in place; this affects every Gemma2 model
    # loaded afterwards in this process.
    Gemma2PreTrainedModel._check_and_enable_sdpa = _check_and_enable_sdpa