[Feature] Initial support for multi-LoRA serving (#1307)
@@ -21,6 +21,7 @@ from typing import List, Union

 import torch
 import torch.nn.functional as F
+from peft import PeftModel
 from transformers import AutoModelForCausalLM, AutoTokenizer

 from sglang.srt.server import Runtime
@@ -52,6 +53,7 @@ def get_dtype_str(torch_dtype):

 def get_top_logprobs(logits, k):
     logprobs = F.log_softmax(logits, dim=-1, dtype=torch.float32)
+    del logits
     logprobs, top_indices = torch.topk(logprobs, k=k, dim=-1)
     return logprobs

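For reference, get_top_logprobs keeps only the k largest log-probabilities per position and discards the indices. A minimal self-contained check of that behavior (toy values, not taken from the test suite):

    import torch
    import torch.nn.functional as F

    logits = torch.tensor([[2.0, 1.0, 0.1, -1.0]])  # one position, vocab size 4
    logprobs = F.log_softmax(logits, dim=-1, dtype=torch.float32)
    top_vals, top_idx = torch.topk(logprobs, k=2, dim=-1)
    # top_idx is tensor([[0, 1]]): the two most likely token ids; top_vals holds their log-probabilities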
@@ -71,8 +73,10 @@ class HFRunner:
         model_path,
         torch_dtype,
         is_generation,
+        output_str_only=False,
     ):
         self.is_generation = is_generation
+        self.output_str_only = output_str_only

         self.in_queue = mp.Queue()
         self.out_queue = mp.Queue()
@@ -95,7 +99,7 @@ class HFRunner:
         )

         if self.is_generation:
-            self.model = AutoModelForCausalLM.from_pretrained(
+            self.base_model = AutoModelForCausalLM.from_pretrained(
                 model_path,
                 torch_dtype=torch_dtype,
                 trust_remote_code=False,
@@ -110,13 +114,16 @@ class HFRunner:
             )

         while True:
-            prompts, max_new_tokens = in_queue.get()
+            prompts, max_new_tokens, lora_paths = in_queue.get()
+            if lora_paths is not None:
+                assert len(prompts) == len(lora_paths)
+
             if prompts is not None:
                 if self.is_generation:
                     output_strs = []
                     top_input_logprobs = []
                     top_output_logprobs = []
-                    for p in prompts:
+                    for i, p in enumerate(prompts):
                         if isinstance(p, str):
                             input_ids = self.tokenizer.encode(
                                 p, return_tensors="pt"
@@ -124,6 +131,16 @@ class HFRunner:
                         else:
                             input_ids = torch.tensor([p], device="cuda")

+                        if lora_paths is not None and lora_paths[i] is not None:
+                            self.model = PeftModel.from_pretrained(
+                                self.base_model,
+                                lora_paths[i],
+                                torch_dtype=torch_dtype,
+                                is_trainable=False,
+                            )
+                        else:
+                            self.model = self.base_model
+
                         outputs = self.model.generate(
                             input_ids,
                             do_sample=False,
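Note on the hunk above: the HF reference runner keeps one shared base model and, for each prompt, wraps it with PeftModel.from_pretrained using that prompt's adapter, falling back to the bare base model when the adapter path is None. A standalone sketch of the same pattern; the model name and adapter path below are placeholders, not values from this commit:

    import torch
    from peft import PeftModel
    from transformers import AutoModelForCausalLM, AutoTokenizer

    base_model = AutoModelForCausalLM.from_pretrained(
        "meta-llama/Llama-2-7b-hf", torch_dtype=torch.float16
    ).cuda()
    tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")

    # Wrap the shared base weights with one LoRA adapter; is_trainable=False keeps it inference-only.
    model = PeftModel.from_pretrained(
        base_model, "path/to/lora_adapter", torch_dtype=torch.float16, is_trainable=False
    )
    input_ids = tokenizer("The capital of France is", return_tensors="pt").input_ids.cuda()
    outputs = model.generate(input_ids, do_sample=False, max_new_tokens=8)
    print(tokenizer.decode(outputs[0][len(input_ids[0]):]))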
@@ -131,25 +148,30 @@ class HFRunner:
                             top_p=None,
                             max_new_tokens=max_new_tokens,
                             return_dict_in_generate=True,
-                            output_scores=True,
+                            output_scores=(not self.output_str_only),
                         )
                         output_strs.append(
                             self.tokenizer.decode(outputs[0][0][len(input_ids[0]) :])
                         )
-                        # outputs.scores: (num_token, 1, vocab_size)
-                        top_output_logprobs.append(
-                            [
-                                get_top_logprobs(logits[0], NUM_TOP_LOGPROBS).tolist()
-                                for logits in outputs.scores
-                            ]
-                        )
-                        del outputs
+                        if not self.output_str_only:
+                            # outputs.scores: (num_token, 1, vocab_size)
+                            top_output_logprobs.append(
+                                [
+                                    get_top_logprobs(
+                                        logits[0], NUM_TOP_LOGPROBS
+                                    ).tolist()
+                                    for logits in outputs.scores
+                                ]
+                            )
+                            del outputs

-                        input_logits = self.model.forward(input_ids).logits[0]
-                        top_input_logprobs.append(
-                            get_top_logprobs(input_logits, NUM_TOP_LOGPROBS).tolist()
-                        )
-                        del input_logits
+                            input_logits = self.model.forward(input_ids).logits[0]
+                            top_input_logprobs.append(
+                                get_top_logprobs(
+                                    input_logits, NUM_TOP_LOGPROBS
+                                ).tolist()
+                            )
+                            del input_logits

                         out_queue.put(
                             ModelOutput(
@@ -160,6 +182,7 @@ class HFRunner:
                     )

                 else:
+                    assert not self.output_str_only
                     logits = self.model.encode(prompts).tolist()
                     out_queue.put(ModelOutput(embed_logits=logits))

@@ -167,8 +190,9 @@ class HFRunner:
         self,
         prompts: Union[List[str], List[torch.Tensor]] = DEFAULT_PROMPTS,
         max_new_tokens=8,
+        lora_paths=None,
     ):
-        self.in_queue.put((prompts, max_new_tokens))
+        self.in_queue.put((prompts, max_new_tokens, lora_paths))
         return self.out_queue.get()

     def terminate(self):
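With the new lora_paths argument, HFRunner.forward can pair each prompt with its own adapter, where None selects the plain base model. A hedged usage sketch; the model path and adapter paths are placeholders:

    import torch

    hf_runner = HFRunner(
        "meta-llama/Llama-2-7b-hf",
        torch_dtype=torch.float16,
        is_generation=True,
    )
    out = hf_runner.forward(
        prompts=["What is the capital of France?", "Write a haiku about spring."],
        max_new_tokens=8,
        lora_paths=["path/to/lora_adapter_a", None],  # second prompt runs on the base model
    )
    print(out.output_strs)
    hf_runner.terminate()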
@@ -191,6 +215,10 @@ class SRTRunner:
         is_generation,
         tp_size=1,
         port=DEFAULT_PORT_FOR_SRT_TEST_RUNNER,
+        lora_paths=None,
+        max_loras_per_batch=4,
+        disable_cuda_graph=False,
+        disable_radix_cache=False,
     ):
         self.is_generation = is_generation
         self.runtime = Runtime(
@@ -201,12 +229,17 @@ class SRTRunner:
             mem_fraction_static=0.69,
             trust_remote_code=False,
             is_embedding=not self.is_generation,
+            lora_paths=lora_paths,
+            max_loras_per_batch=max_loras_per_batch,
+            disable_cuda_graph=disable_cuda_graph,
+            disable_radix_cache=disable_radix_cache,
         )

     def forward(
         self,
         prompts: Union[List[str], List[torch.Tensor]] = DEFAULT_PROMPTS,
         max_new_tokens=8,
+        lora_paths=None,
     ):
         if self.is_generation:
             # the return value contains logprobs from prefill
@@ -214,9 +247,10 @@ class SRTRunner:
             top_input_logprobs = []
             top_output_logprobs = []
             sampling_params = {"max_new_tokens": max_new_tokens, "temperature": 0}
-            for prompt in prompts:
+            for i, prompt in enumerate(prompts):
                 response = self.runtime.generate(
                     prompt,
+                    lora_path=lora_paths[i] if lora_paths else None,
                     sampling_params=sampling_params,
                     return_logprob=True,
                     logprob_start_len=0,
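On the SGLang side, adapters are registered once at server launch (lora_paths and max_loras_per_batch passed to Runtime above) and then selected per request via the lora_path argument of runtime.generate. A hedged end-to-end sketch, assuming SRTRunner's leading parameters are model_path and torch_dtype as in HFRunner; the model name and adapter paths are placeholders:

    import torch

    srt_runner = SRTRunner(
        "meta-llama/Llama-2-7b-hf",
        torch_dtype=torch.float16,
        is_generation=True,
        lora_paths=["path/to/lora_adapter_a", "path/to/lora_adapter_b"],
        max_loras_per_batch=4,  # how many distinct adapters one running batch may mix
    )
    out = srt_runner.forward(
        prompts=["What is the capital of France?", "Write a haiku about spring."],
        max_new_tokens=8,
        lora_paths=["path/to/lora_adapter_a", "path/to/lora_adapter_b"],
    )
    print(out.output_strs)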
@@ -256,6 +290,37 @@ class SRTRunner:
             logits = [x["embedding"] for x in response]
             return ModelOutput(embed_logits=logits)

+    def batch_forward(
+        self,
+        prompts: Union[List[str], List[torch.Tensor]] = DEFAULT_PROMPTS,
+        max_new_tokens=8,
+        lora_paths=None,
+    ):
+        """
+        testing serving by sending all prompts once
+        only return output strings and no logprobs
+        """
+        if self.is_generation:
+            # the return value contains logprobs from prefill
+            output_strs = []
+            sampling_params = {"max_new_tokens": max_new_tokens, "temperature": 0}
+            response = self.runtime.generate(
+                prompts,
+                lora_path=lora_paths if lora_paths else None,
+                sampling_params=sampling_params,
+            )
+            response = json.loads(response)
+            output_strs = [r["text"] for r in response]
+
+            return ModelOutput(
+                output_strs=output_strs,
+            )
+        else:
+            response = self.runtime.encode(prompts)
+            response = json.loads(response)
+            logits = [x["embedding"] for x in response]
+            return ModelOutput(embed_logits=logits)
+
     def __enter__(self):
         return self

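Unlike forward, which issues one request per prompt, the new batch_forward above submits all prompts (with their per-prompt adapters) in a single generate call, so the server must batch different LoRA adapters together; it returns output strings only. A hedged usage sketch with placeholder adapter paths, reusing the runner from the previous sketch:

    out = srt_runner.batch_forward(
        prompts=["What is the capital of France?", "Write a haiku about spring."],
        max_new_tokens=8,
        lora_paths=["path/to/lora_adapter_a", "path/to/lora_adapter_b"],
    )
    print(out.output_strs)  # no logprobs are returned on this path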