[Fix, LoRA] fix LoRA with updates in main (#1545)
This commit is contained in:
@@ -1,5 +1,5 @@
|
||||
# launch server
|
||||
# python -m sglang.launch_server --model mistralai/Mistral-7B-Instruct-v0.3 --lora-paths /home/ying/test_lora /home/ying/test_lora_1 /home/ying/test_lora_2 lora3=/home/ying/test_lora_3 lora4=/home/ying/test_lora_4 --disable-radix --disable-cuda-graph --max-loras-per-batch 4
|
||||
# python -m sglang.launch_server --model mistralai/Mistral-7B-Instruct-v0.3 --lora-paths /home/ying/test_lora lora1=/home/ying/test_lora_1 lora2=/home/ying/test_lora_2 --disable-radix --disable-cuda-graph --max-loras-per-batch 4
|
||||
|
||||
# send requests
|
||||
# lora_path[i] specifies the LoRA used for text[i], so make sure they have the same length
|
||||
@@ -22,12 +22,12 @@ json_data = {
|
||||
"sampling_params": {"max_new_tokens": 32},
|
||||
"lora_path": [
|
||||
"/home/ying/test_lora",
|
||||
"/home/ying/test_lora_1",
|
||||
"/home/ying/test_lora_2",
|
||||
"lora3",
|
||||
"lora4",
|
||||
"/home/ying/test_lora",
|
||||
"/home/ying/test_lora_1",
|
||||
"lora1",
|
||||
"lora2",
|
||||
"lora1",
|
||||
"lora2",
|
||||
None,
|
||||
None,
|
||||
],
|
||||
}
|
||||
response = requests.post(
|
||||
|
||||
@@ -28,18 +28,18 @@ from typing import Any, Dict, List, Optional, Tuple
|
||||
import safetensors.torch
|
||||
import torch
|
||||
from torch import nn
|
||||
from vllm.model_executor.layers.linear import (
|
||||
ColumnParallelLinear,
|
||||
MergedColumnParallelLinear,
|
||||
QKVParallelLinear,
|
||||
RowParallelLinear,
|
||||
)
|
||||
from vllm.model_executor.layers.vocab_parallel_embedding import (
|
||||
ParallelLMHead,
|
||||
VocabParallelEmbedding,
|
||||
)
|
||||
from vllm.model_executor.model_loader.loader import DefaultModelLoader
|
||||
|
||||
from sglang.srt.layers.linear import (
|
||||
ColumnParallelLinear,
|
||||
MergedColumnParallelLinear,
|
||||
QKVParallelLinear,
|
||||
RowParallelLinear,
|
||||
)
|
||||
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
|
||||
|
||||
|
||||
|
||||
@@ -594,6 +594,16 @@ class ServerArgs:
|
||||
"Please use sglang<=0.3.2 or wait for later updates."
|
||||
)
|
||||
|
||||
if isinstance(self.lora_paths, list):
|
||||
lora_paths = self.lora_paths
|
||||
self.lora_paths = {}
|
||||
for lora_path in lora_paths:
|
||||
if "=" in lora_path:
|
||||
name, path = lora_path.split("=", 1)
|
||||
self.lora_paths[name] = path
|
||||
else:
|
||||
self.lora_paths[lora_path] = lora_path
|
||||
|
||||
|
||||
def prepare_server_args(argv: List[str]) -> ServerArgs:
|
||||
"""
|
||||
|
||||
@@ -97,9 +97,7 @@ class TestLoRA(unittest.TestCase):
|
||||
)
|
||||
|
||||
with HFRunner(
|
||||
base_path,
|
||||
torch_dtype=torch_dtype,
|
||||
is_generation=True,
|
||||
base_path, torch_dtype=torch_dtype, model_type="generation"
|
||||
) as hf_runner:
|
||||
hf_outputs = hf_runner.forward(
|
||||
prompts, max_new_tokens=max_new_tokens, lora_paths=batch_lora_paths
|
||||
@@ -108,7 +106,7 @@ class TestLoRA(unittest.TestCase):
|
||||
with HFRunner(
|
||||
base_path,
|
||||
torch_dtype=torch_dtype,
|
||||
is_generation=True,
|
||||
model_type="generation",
|
||||
) as hf_runner:
|
||||
hf_no_lora_outputs = hf_runner.forward(
|
||||
prompts, max_new_tokens=max_new_tokens
|
||||
@@ -118,7 +116,7 @@ class TestLoRA(unittest.TestCase):
|
||||
base_path,
|
||||
tp_size=tp_size,
|
||||
torch_dtype=torch_dtype,
|
||||
is_generation=True,
|
||||
model_type="generation",
|
||||
) as srt_runner:
|
||||
srt_no_lora_outputs = srt_runner.forward(
|
||||
prompts, max_new_tokens=max_new_tokens
|
||||
@@ -198,7 +196,7 @@ class TestLoRA(unittest.TestCase):
|
||||
base_path,
|
||||
tp_size=tp_size,
|
||||
torch_dtype=torch_dtype,
|
||||
is_generation=True,
|
||||
model_type="generation",
|
||||
lora_paths=all_lora_paths,
|
||||
max_loras_per_batch=3,
|
||||
disable_cuda_graph=True,
|
||||
@@ -211,7 +209,7 @@ class TestLoRA(unittest.TestCase):
|
||||
with HFRunner(
|
||||
base_path,
|
||||
torch_dtype=torch_dtype,
|
||||
is_generation=True,
|
||||
model_type="generation",
|
||||
output_str_only=True,
|
||||
) as hf_runner:
|
||||
hf_outputs = hf_runner.forward(
|
||||
@@ -237,7 +235,7 @@ class TestLoRA(unittest.TestCase):
|
||||
base_path,
|
||||
tp_size=tp_size,
|
||||
torch_dtype=torch_dtype,
|
||||
is_generation=True,
|
||||
model_type="generation",
|
||||
) as srt_runner:
|
||||
srt_no_lora_outputs = srt_runner.forward(
|
||||
prompts, max_new_tokens=max_new_tokens
|
||||
@@ -247,7 +245,7 @@ class TestLoRA(unittest.TestCase):
|
||||
base_path,
|
||||
tp_size=tp_size,
|
||||
torch_dtype=torch_dtype,
|
||||
is_generation=True,
|
||||
model_type="generation",
|
||||
lora_paths=all_lora_paths,
|
||||
) as srt_runner:
|
||||
srt_outputs = srt_runner.forward(
|
||||
|
||||
@@ -7,7 +7,7 @@ suites = {
|
||||
"minimal": [
|
||||
"models/test_embedding_models.py",
|
||||
"models/test_generation_models.py",
|
||||
# "models/test_lora.py",
|
||||
"models/test_lora.py",
|
||||
"models/test_reward_models.py",
|
||||
"sampling/penaltylib",
|
||||
"test_chunked_prefill.py",
|
||||
|
||||
Reference in New Issue
Block a user