[Feature] Initial support for multi-LoRA serving (#1307)
This commit is contained in:
62
scripts/playground/lora/lora_hf_play.py
Normal file
62
scripts/playground/lora/lora_hf_play.py
Normal file
@@ -0,0 +1,62 @@
|
||||
import torch
|
||||
from peft import PeftModel
|
||||
from transformers import LlamaForCausalLM, LlamaTokenizer
|
||||
|
||||
MODEL = "mistralai/Mistral-7B-Instruct-v0.3"
# ADAPTER = "winddude/wizardLM-LlaMA-LoRA-7B"
ADAPTER = "/home/ying/test_lora"
HF_TOKEN = "..."


prompt = """
### Instruction:
Write a poem about the transformers Python library.
Mention the word "large language models" in that poem.
### Response:
The Transformers are large language models,
They're used to make predictions on text.
"""


def _greedy_generate(model, tokenizer, input_ids, max_new_tokens=32):
    """Greedily decode a continuation of ``input_ids`` and return it as text.

    Args:
        model: Object exposing ``generate`` (the base model or the PEFT wrapper).
        tokenizer: Tokenizer used to turn the generated ids back into text.
        input_ids: Prompt token ids, already on the device the model expects.
        max_new_tokens: Upper bound on the number of generated tokens.

    Returns:
        The decoded sequence (prompt plus continuation) without special tokens.
    """
    # Inference only: skip autograd bookkeeping.
    with torch.no_grad():
        output_tensors = model.generate(
            input_ids=input_ids,
            max_new_tokens=max_new_tokens,
            do_sample=False,
        )[0]
    return tokenizer.decode(output_tensors, skip_special_tokens=True)


tokenizer = LlamaTokenizer.from_pretrained(MODEL)

# device_map="auto" already places the weights on the available device(s).
# Do not chain .cuda() afterwards: moving a model that has been dispatched
# this way conflicts with the placement chosen at load time.
base_model = LlamaForCausalLM.from_pretrained(
    MODEL,
    device_map="auto",
    # load_in_8bit=True,
    torch_dtype=torch.float16,
    # use_auth_token=HF_TOKEN,
)

# Tokenize once and reuse the same ids for both runs so the base and adapter
# outputs are compared on identical input.
input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(base_model.device)

# base model generate (must run before the adapter is attached below)
output = _greedy_generate(base_model, tokenizer, input_ids)
print("======= base output ========")
print(output)

# peft model generate
model = PeftModel.from_pretrained(
    base_model,
    ADAPTER,
    torch_dtype=torch.float16,
    is_trainable=False,
)

output = _greedy_generate(model, tokenizer, input_ids)
print("======= peft output ========")
print(output)
|
||||
30
scripts/playground/lora/lora_vllm_play.py
Normal file
30
scripts/playground/lora/lora_vllm_play.py
Normal file
@@ -0,0 +1,30 @@
|
||||
from vllm import LLM, SamplingParams
|
||||
from vllm.lora.request import LoRARequest
|
||||
|
||||
# Base checkpoint and the local LoRA adapter directory applied on top of it.
MODEL = "mistralai/Mistral-7B-Instruct-v0.3"
ADAPTER = "/home/ying/test_lora"
prompt = """
### Instruction:
Write a poem about the transformers Python library.
Mention the word "large language models" in that poem.
### Response:
The Transformers are large language models,
They're used to make predictions on text.
"""

# The engine has to be created with LoRA support to accept per-request adapters.
llm = LLM(model=MODEL, enable_lora=True)

# temperature=0 with a 32-token cap on the completion.
sampling_params = SamplingParams(temperature=0, max_tokens=32)

prompts = [prompt]
adapter_request = LoRARequest("test_lora", 1, ADAPTER)

outputs = llm.generate(prompts, sampling_params, lora_request=adapter_request)

# Only one prompt was submitted; show it followed by its first completion.
first_result = outputs[0]
print(first_result.prompt)
print(first_result.outputs[0].text)
|
||||
55
scripts/playground/lora/test_lora.py
Normal file
55
scripts/playground/lora/test_lora.py
Normal file
@@ -0,0 +1,55 @@
|
||||
import json
|
||||
|
||||
import openai
|
||||
import requests
|
||||
|
||||
import sglang as sgl
|
||||
|
||||
# Adapter passed with every /generate request.
lora_path = "/home/ying/test_lora"
# JSON file read by the driver loop at the bottom of this script.
prompt_file = "/home/ying/test_prompt/dialogue_choice_prompts.json"
# Base URL of the server this script talks to.
server_url = "http://127.0.0.1:30000"

# OpenAI-compatible client pointed at the same server, under the /v1 prefix.
client = openai.Client(base_url=f"{server_url}/v1", api_key="EMPTY")
|
||||
|
||||
|
||||
# @sgl.function
|
||||
# def generate(s, prompt):
|
||||
# s += prompt
|
||||
# s += sgl.gen("ans")
|
||||
|
||||
# sgl.set_default_backend(sgl.RuntimeEndpoint(server_url))
|
||||
|
||||
|
||||
def generate(prompt, lora_path):
    """POST ``prompt`` to the server's ``/generate`` endpoint with a LoRA adapter.

    Args:
        prompt: Text sent as the ``text`` field of the request.
        lora_path: Adapter identifier sent as the ``lora_path`` field.

    Returns:
        The server's JSON reply, re-serialized to a string.
    """
    payload = dict(
        text=prompt,
        sampling_params={},
        return_logprob=False,
        logprob_start_len=None,
        top_logprobs_num=None,
        lora_path=lora_path,
    )
    endpoint = f"{server_url}/generate"
    reply = requests.post(endpoint, json=payload)
    body = reply.json()
    return json.dumps(body)
|
||||
|
||||
|
||||
# JSON is UTF-8 by specification; pin the encoding so decoding does not depend
# on the platform's locale default.
with open(prompt_file, "r", encoding="utf-8") as f:
    samples = json.load(f)

# Only the first sample is exercised. Each sample is indexed as a list of
# chat turns: a user turn followed by the reference assistant turn.
for sample in samples[:1]:
    assert sample[0]["role"] == "user"
    prompt = sample[0]["content"]
    assert sample[1]["role"] == "assistant"
    ref = sample[1]["content"]

    state = generate(prompt, lora_path)
    # Print the reference answer first, then the raw server response.
    print("================================")
    print(ref)
    print("--------------------------------")
    # print(state["ans"])
    print(state)
    print()
|
||||
Reference in New Issue
Block a user