[Feature] Initial support for multi-LoRA serving (#1307)

Author: Ying Sheng
Date: 2024-09-12 16:46:14 -07:00
Committed by: GitHub
Parent: c33d82a211
Commit: 712216928f
21 changed files with 1435 additions and 22 deletions


@@ -21,6 +21,7 @@ from typing import List, Union
 import torch
 import torch.nn.functional as F
+from peft import PeftModel
 from transformers import AutoModelForCausalLM, AutoTokenizer
 from sglang.srt.server import Runtime
@@ -52,6 +53,7 @@ def get_dtype_str(torch_dtype):
 def get_top_logprobs(logits, k):
     logprobs = F.log_softmax(logits, dim=-1, dtype=torch.float32)
+    del logits
     logprobs, top_indices = torch.topk(logprobs, k=k, dim=-1)
     return logprobs
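
As a sanity check, the helper above returns only the top-k logprob values and discards the indices. A minimal sketch of its behavior on toy logits (the tensor values are illustrative, not from this commit):

import torch
import torch.nn.functional as F

logits = torch.tensor([2.0, 1.0, 0.5, -1.0])  # fake logits over a 4-token vocab
logprobs = F.log_softmax(logits, dim=-1, dtype=torch.float32)
# torch.topk returns (values, indices); the helper keeps only the values.
top_values, top_indices = torch.topk(logprobs, k=2, dim=-1)
print(top_values)   # the two largest logprobs, sorted descending
print(top_indices)  # token ids of those logprobs, dropped by the helper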
@@ -71,8 +73,10 @@ class HFRunner:
         model_path,
         torch_dtype,
         is_generation,
+        output_str_only=False,
     ):
         self.is_generation = is_generation
+        self.output_str_only = output_str_only
         self.in_queue = mp.Queue()
         self.out_queue = mp.Queue()
@@ -95,7 +99,7 @@
         )
         if self.is_generation:
-            self.model = AutoModelForCausalLM.from_pretrained(
+            self.base_model = AutoModelForCausalLM.from_pretrained(
                 model_path,
                 torch_dtype=torch_dtype,
                 trust_remote_code=False,
@@ -110,13 +114,16 @@
             )
         while True:
-            prompts, max_new_tokens = in_queue.get()
+            prompts, max_new_tokens, lora_paths = in_queue.get()
+            if lora_paths is not None:
+                assert len(prompts) == len(lora_paths)
            if prompts is not None:
                 if self.is_generation:
                     output_strs = []
                     top_input_logprobs = []
                     top_output_logprobs = []
-                    for p in prompts:
+                    for i, p in enumerate(prompts):
                         if isinstance(p, str):
                             input_ids = self.tokenizer.encode(
                                 p, return_tensors="pt"
@@ -124,6 +131,16 @@
                         else:
                             input_ids = torch.tensor([p], device="cuda")
+                        if lora_paths is not None and lora_paths[i] is not None:
+                            self.model = PeftModel.from_pretrained(
+                                self.base_model,
+                                lora_paths[i],
+                                torch_dtype=torch_dtype,
+                                is_trainable=False,
+                            )
+                        else:
+                            self.model = self.base_model
+
                         outputs = self.model.generate(
                             input_ids,
                             do_sample=False,
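
The hunk above hot-swaps a LoRA adapter per prompt on the Hugging Face reference side: when a prompt requests an adapter, the base model is wrapped with PeftModel; otherwise the bare base model is used. A minimal standalone sketch of the same pattern (the model and adapter paths are placeholders, not from this commit):

import torch
from peft import PeftModel
from transformers import AutoModelForCausalLM

base_model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-2-7b-hf",  # placeholder base model
    torch_dtype=torch.float16,
)

def select_model(lora_path):
    # is_trainable=False loads the adapter in inference mode, as in the diff.
    if lora_path is not None:
        return PeftModel.from_pretrained(base_model, lora_path, is_trainable=False)
    return base_model  # no adapter requested: fall back to the base model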
@@ -131,25 +148,30 @@
                             top_p=None,
                             max_new_tokens=max_new_tokens,
                             return_dict_in_generate=True,
-                            output_scores=True,
+                            output_scores=(not self.output_str_only),
                         )
                         output_strs.append(
                             self.tokenizer.decode(outputs[0][0][len(input_ids[0]) :])
                         )
-                        # outputs.scores: (num_token, 1, vocab_size)
-                        top_output_logprobs.append(
-                            [
-                                get_top_logprobs(logits[0], NUM_TOP_LOGPROBS).tolist()
-                                for logits in outputs.scores
-                            ]
-                        )
-                        del outputs
+                        if not self.output_str_only:
+                            # outputs.scores: (num_token, 1, vocab_size)
+                            top_output_logprobs.append(
+                                [
+                                    get_top_logprobs(
+                                        logits[0], NUM_TOP_LOGPROBS
+                                    ).tolist()
+                                    for logits in outputs.scores
+                                ]
+                            )
+                            del outputs
-                        input_logits = self.model.forward(input_ids).logits[0]
-                        top_input_logprobs.append(
-                            get_top_logprobs(input_logits, NUM_TOP_LOGPROBS).tolist()
-                        )
-                        del input_logits
+                            input_logits = self.model.forward(input_ids).logits[0]
+                            top_input_logprobs.append(
+                                get_top_logprobs(
+                                    input_logits, NUM_TOP_LOGPROBS
+                                ).tolist()
+                            )
+                            del input_logits
                         out_queue.put(
                             ModelOutput(
@@ -160,6 +182,7 @@ class HFRunner:
                     )
             else:
+                assert not self.output_str_only
                 logits = self.model.encode(prompts).tolist()
                 out_queue.put(ModelOutput(embed_logits=logits))
@@ -167,8 +190,9 @@ class HFRunner:
         self,
         prompts: Union[List[str], List[torch.Tensor]] = DEFAULT_PROMPTS,
         max_new_tokens=8,
+        lora_paths=None,
     ):
-        self.in_queue.put((prompts, max_new_tokens))
+        self.in_queue.put((prompts, max_new_tokens, lora_paths))
         return self.out_queue.get()

     def terminate(self):
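
With the extra lora_paths argument, the reference runner can now pair each prompt with its own adapter, or with none. A hedged usage sketch; the model and adapter paths are placeholders:

# Hypothetical call into the extended HFRunner.forward.
runner = HFRunner("meta-llama/Llama-2-7b-hf", torch.float16, is_generation=True)
out = runner.forward(
    prompts=["Hello", "Bonjour"],
    max_new_tokens=8,
    lora_paths=["/path/to/adapter_a", None],  # second prompt runs on the base model
)
print(out.output_strs)
runner.terminate()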
@@ -191,6 +215,10 @@ class SRTRunner:
         is_generation,
         tp_size=1,
         port=DEFAULT_PORT_FOR_SRT_TEST_RUNNER,
+        lora_paths=None,
+        max_loras_per_batch=4,
+        disable_cuda_graph=False,
+        disable_radix_cache=False,
     ):
         self.is_generation = is_generation
         self.runtime = Runtime(
@@ -201,12 +229,17 @@ class SRTRunner:
             mem_fraction_static=0.69,
             trust_remote_code=False,
             is_embedding=not self.is_generation,
+            lora_paths=lora_paths,
+            max_loras_per_batch=max_loras_per_batch,
+            disable_cuda_graph=disable_cuda_graph,
+            disable_radix_cache=disable_radix_cache,
         )

     def forward(
         self,
         prompts: Union[List[str], List[torch.Tensor]] = DEFAULT_PROMPTS,
         max_new_tokens=8,
+        lora_paths=None,
     ):
         if self.is_generation:
             # the return value contains logprobs from prefill
@@ -214,9 +247,10 @@ class SRTRunner:
             top_input_logprobs = []
             top_output_logprobs = []
             sampling_params = {"max_new_tokens": max_new_tokens, "temperature": 0}
-            for prompt in prompts:
+            for i, prompt in enumerate(prompts):
                 response = self.runtime.generate(
                     prompt,
+                    lora_path=lora_paths[i] if lora_paths else None,
                     sampling_params=sampling_params,
                     return_logprob=True,
                     logprob_start_len=0,
@@ -256,6 +290,37 @@ class SRTRunner:
             logits = [x["embedding"] for x in response]
             return ModelOutput(embed_logits=logits)
+
+    def batch_forward(
+        self,
+        prompts: Union[List[str], List[torch.Tensor]] = DEFAULT_PROMPTS,
+        max_new_tokens=8,
+        lora_paths=None,
+    ):
+        """
+        Test serving by sending all prompts at once; returns only output
+        strings, no logprobs.
+        """
+        if self.is_generation:
+            output_strs = []
+            sampling_params = {"max_new_tokens": max_new_tokens, "temperature": 0}
+            response = self.runtime.generate(
+                prompts,
+                lora_path=lora_paths if lora_paths else None,
+                sampling_params=sampling_params,
+            )
+            response = json.loads(response)
+            output_strs = [r["text"] for r in response]
+
+            return ModelOutput(
+                output_strs=output_strs,
+            )
+        else:
+            response = self.runtime.encode(prompts)
+            response = json.loads(response)
+            logits = [x["embedding"] for x in response]
+            return ModelOutput(embed_logits=logits)

     def __enter__(self):
         return self
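
Taken together, the SRTRunner changes let a test launch the SGLang runtime with a set of adapters and route each prompt to one of them, either one request at a time (forward) or as a single batch (batch_forward). A hedged end-to-end sketch; the model and adapter paths are placeholders, and the disable flags reflect what this initial implementation appears to require:

with SRTRunner(
    "meta-llama/Llama-2-7b-hf",  # placeholder base model
    torch.float16,
    is_generation=True,
    lora_paths=["/path/to/adapter_a", "/path/to/adapter_b"],
    max_loras_per_batch=4,
    disable_cuda_graph=True,
    disable_radix_cache=True,
) as srt:
    # Each prompt names its adapter; the runtime batches them together.
    out = srt.batch_forward(
        prompts=["Hello", "Bonjour"],
        max_new_tokens=8,
        lora_paths=["/path/to/adapter_a", "/path/to/adapter_b"],
    )
    print(out.output_strs)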