diff --git a/examples/runtime/lora.py b/examples/runtime/lora.py index 183cfb484..bf3fc2d9e 100644 --- a/examples/runtime/lora.py +++ b/examples/runtime/lora.py @@ -1,5 +1,5 @@ # launch server -# python -m sglang.launch_server --model mistralai/Mistral-7B-Instruct-v0.3 --lora-paths /home/ying/test_lora /home/ying/test_lora_1 /home/ying/test_lora_2 lora3=/home/ying/test_lora_3 lora4=/home/ying/test_lora_4 --disable-radix --disable-cuda-graph --max-loras-per-batch 4 +# python -m sglang.launch_server --model mistralai/Mistral-7B-Instruct-v0.3 --lora-paths /home/ying/test_lora lora1=/home/ying/test_lora_1 lora2=/home/ying/test_lora_2 --disable-radix --disable-cuda-graph --max-loras-per-batch 4 # send requests # lora_path[i] specifies the LoRA used for text[i], so make sure they have the same length @@ -22,12 +22,12 @@ json_data = { "sampling_params": {"max_new_tokens": 32}, "lora_path": [ "/home/ying/test_lora", - "/home/ying/test_lora_1", - "/home/ying/test_lora_2", - "lora3", - "lora4", - "/home/ying/test_lora", - "/home/ying/test_lora_1", + "lora1", + "lora2", + "lora1", + "lora2", + None, + None, ], } response = requests.post( diff --git a/python/sglang/srt/lora/lora.py b/python/sglang/srt/lora/lora.py index dcf6450a0..379b233bd 100644 --- a/python/sglang/srt/lora/lora.py +++ b/python/sglang/srt/lora/lora.py @@ -28,18 +28,18 @@ from typing import Any, Dict, List, Optional, Tuple import safetensors.torch import torch from torch import nn -from vllm.model_executor.layers.linear import ( - ColumnParallelLinear, - MergedColumnParallelLinear, - QKVParallelLinear, - RowParallelLinear, -) from vllm.model_executor.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding, ) from vllm.model_executor.model_loader.loader import DefaultModelLoader +from sglang.srt.layers.linear import ( + ColumnParallelLinear, + MergedColumnParallelLinear, + QKVParallelLinear, + RowParallelLinear, +) from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py index 409065177..cf54df1b1 100644 --- a/python/sglang/srt/server_args.py +++ b/python/sglang/srt/server_args.py @@ -594,6 +594,16 @@ class ServerArgs: "Please use sglang<=0.3.2 or wait for later updates." ) + if isinstance(self.lora_paths, list): + lora_paths = self.lora_paths + self.lora_paths = {} + for lora_path in lora_paths: + if "=" in lora_path: + name, path = lora_path.split("=", 1) + self.lora_paths[name] = path + else: + self.lora_paths[lora_path] = lora_path + def prepare_server_args(argv: List[str]) -> ServerArgs: """ diff --git a/test/srt/models/test_lora.py b/test/srt/models/test_lora.py index e044c4c0b..85a963893 100644 --- a/test/srt/models/test_lora.py +++ b/test/srt/models/test_lora.py @@ -97,9 +97,7 @@ class TestLoRA(unittest.TestCase): ) with HFRunner( - base_path, - torch_dtype=torch_dtype, - is_generation=True, + base_path, torch_dtype=torch_dtype, model_type="generation" ) as hf_runner: hf_outputs = hf_runner.forward( prompts, max_new_tokens=max_new_tokens, lora_paths=batch_lora_paths @@ -108,7 +106,7 @@ class TestLoRA(unittest.TestCase): with HFRunner( base_path, torch_dtype=torch_dtype, - is_generation=True, + model_type="generation", ) as hf_runner: hf_no_lora_outputs = hf_runner.forward( prompts, max_new_tokens=max_new_tokens @@ -118,7 +116,7 @@ class TestLoRA(unittest.TestCase): base_path, tp_size=tp_size, torch_dtype=torch_dtype, - is_generation=True, + model_type="generation", ) as srt_runner: srt_no_lora_outputs = srt_runner.forward( prompts, max_new_tokens=max_new_tokens @@ -198,7 +196,7 @@ class TestLoRA(unittest.TestCase): base_path, tp_size=tp_size, torch_dtype=torch_dtype, - is_generation=True, + model_type="generation", lora_paths=all_lora_paths, max_loras_per_batch=3, disable_cuda_graph=True, @@ -211,7 +209,7 @@ class TestLoRA(unittest.TestCase): with HFRunner( base_path, torch_dtype=torch_dtype, - is_generation=True, + model_type="generation", output_str_only=True, ) as hf_runner: hf_outputs = hf_runner.forward( @@ -237,7 +235,7 @@ class TestLoRA(unittest.TestCase): base_path, tp_size=tp_size, torch_dtype=torch_dtype, - is_generation=True, + model_type="generation", ) as srt_runner: srt_no_lora_outputs = srt_runner.forward( prompts, max_new_tokens=max_new_tokens @@ -247,7 +245,7 @@ class TestLoRA(unittest.TestCase): base_path, tp_size=tp_size, torch_dtype=torch_dtype, - is_generation=True, + model_type="generation", lora_paths=all_lora_paths, ) as srt_runner: srt_outputs = srt_runner.forward( diff --git a/test/srt/run_suite.py b/test/srt/run_suite.py index 4e6ce73a5..bfa5f0cc7 100644 --- a/test/srt/run_suite.py +++ b/test/srt/run_suite.py @@ -7,7 +7,7 @@ suites = { "minimal": [ "models/test_embedding_models.py", "models/test_generation_models.py", - # "models/test_lora.py", + "models/test_lora.py", "models/test_reward_models.py", "sampling/penaltylib", "test_chunked_prefill.py",