From 712216928fa252d6592a1518579018a69cb72bfe Mon Sep 17 00:00:00 2001 From: Ying Sheng Date: Thu, 12 Sep 2024 16:46:14 -0700 Subject: [PATCH] [Feature] Initial support for multi-LoRA serving (#1307) --- python/pyproject.toml | 2 +- python/sglang/srt/lora/lora.py | 403 ++++++++++++++++++ python/sglang/srt/lora/lora_config.py | 43 ++ python/sglang/srt/lora/lora_manager.py | 256 +++++++++++ python/sglang/srt/managers/io_struct.py | 6 + python/sglang/srt/managers/schedule_batch.py | 3 +- .../sglang/srt/managers/tokenizer_manager.py | 10 + python/sglang/srt/managers/tp_worker.py | 29 +- .../sglang/srt/model_executor/model_runner.py | 19 + python/sglang/srt/models/llama.py | 45 ++ python/sglang/srt/server.py | 3 + python/sglang/srt/server_args.py | 25 ++ python/sglang/srt/utils.py | 12 + python/sglang/test/runners.py | 103 ++++- scripts/playground/lora/lora_hf_play.py | 62 +++ scripts/playground/lora/lora_vllm_play.py | 30 ++ scripts/playground/lora/test_lora.py | 55 +++ test/srt/models/compare.py | 52 +++ test/srt/models/test_generation_models.py | 1 + test/srt/models/test_lora.py | 297 +++++++++++++ test/srt/run_suite.py | 1 + 21 files changed, 1435 insertions(+), 22 deletions(-) create mode 100644 python/sglang/srt/lora/lora.py create mode 100644 python/sglang/srt/lora/lora_config.py create mode 100644 python/sglang/srt/lora/lora_manager.py create mode 100644 scripts/playground/lora/lora_hf_play.py create mode 100644 scripts/playground/lora/lora_vllm_play.py create mode 100644 scripts/playground/lora/test_lora.py create mode 100644 test/srt/models/compare.py create mode 100644 test/srt/models/test_lora.py diff --git a/python/pyproject.toml b/python/pyproject.toml index 1389822a3..dd6f8ece7 100644 --- a/python/pyproject.toml +++ b/python/pyproject.toml @@ -27,7 +27,7 @@ srt = ["aiohttp", "decord", "fastapi", "hf_transfer", "huggingface_hub", "intere openai = ["openai>=1.0", "tiktoken"] anthropic = ["anthropic>=0.20.0"] litellm = ["litellm>=1.0.0"] -test = ["jsonlines", "matplotlib", "pandas", "sentence_transformers", "accelerate"] +test = ["jsonlines", "matplotlib", "pandas", "sentence_transformers", "accelerate", "peft"] all = ["sglang[srt]", "sglang[openai]", "sglang[anthropic]", "sglang[litellm]"] dev = ["sglang[all]", "sglang[test]"] diff --git a/python/sglang/srt/lora/lora.py b/python/sglang/srt/lora/lora.py new file mode 100644 index 000000000..6cc1f0348 --- /dev/null +++ b/python/sglang/srt/lora/lora.py @@ -0,0 +1,403 @@ +""" +Copyright 2023-2024 SGLang Team +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" + +# Integrates "S-LoRA: Serving Thousands of Concurrent LoRA Adapters" +# and "Punica: Multi-Tenant LoRA Serving" + +# LoRA layers class inheritance adapted from: +# https://github.com/vllm-project/vllm/blob/4abf6336ec65c270343eb895e7b18786e9274176/vllm/lora/layers.py + + +import json +import os +import re +from typing import Any, Dict, List, Optional, Tuple + +import safetensors.torch +import torch +from torch import nn +from vllm.model_executor.layers.linear import ( + ColumnParallelLinear, + MergedColumnParallelLinear, + QKVParallelLinear, + RowParallelLinear, +) +from vllm.model_executor.layers.vocab_parallel_embedding import ( + ParallelLMHead, + VocabParallelEmbedding, +) +from vllm.model_executor.model_loader.loader import DefaultModelLoader + +from sglang.srt.model_executor.forward_batch_info import ForwardMode, InputMetadata + + +class BaseLayerWithLoRA(nn.Module): + def __init__(self, base_layer, segment_gemm, lora_rank, scaling): + super().__init__() + self.base_layer = base_layer + self.segment_gemm = segment_gemm + self.lora_rank = lora_rank + self.scaling = scaling + self.set_lora = False + + def forward(self, x: torch.Tensor): + return self.base_layer.forward(x) + + def set_lora_info(self, *args): + pass + + +class VocabParallelEmbeddingWithLoRA(BaseLayerWithLoRA): + def __init__( + self, base_layer: VocabParallelEmbedding, segment_gemm, lora_rank, scaling + ) -> None: + super().__init__(base_layer, segment_gemm, lora_rank, scaling) + self.weight = base_layer.weight + + +class ColumnParallelLinearWithLoRA(BaseLayerWithLoRA): + def __init__( + self, base_layer: ColumnParallelLinear, segment_gemm, lora_rank, scaling + ) -> None: + super().__init__(base_layer, segment_gemm, lora_rank, scaling) + + def apply_lora(self, output: torch.Tensor, x: torch.Tensor) -> torch.Tensor: + # TODO + return output + + def forward(self, input_: torch.Tensor): + # duplicate the logic in ColumnParallelLinear + bias = self.base_layer.bias if not self.base_layer.skip_bias_add else None + output_parallel = self.base_layer.quant_method.apply( + self.base_layer, input_, bias + ) + + if self.set_lora: + output_parallel = self.apply_lora(output_parallel, input_) + + if self.base_layer.gather_output: + output = tensor_model_parallel_all_gather(output_parallel) + else: + output = output_parallel + output_bias = self.base_layer.bias if self.base_layer.skip_bias_add else None + return output, output_bias + + +class MergedColumnParallelLinearWithLoRA(ColumnParallelLinearWithLoRA): + def __init__( + self, base_layer: MergedColumnParallelLinear, segment_gemm, lora_rank, scaling + ) -> None: + super().__init__(base_layer, segment_gemm, lora_rank, scaling) + + def set_lora_info(self, A_buffer, B_buffer, bs, seq_lens, weight_indices): + self.set_lora = True + self.A_buffer = A_buffer + self.B_buffer = B_buffer + self.bs = bs + self.seq_lens = seq_lens + self.weight_indices = weight_indices + + def apply_lora(self, base_output: torch.Tensor, x: torch.Tensor) -> torch.Tensor: + lora_a_output = self.segment_gemm.run( + x=x, + weights=self.A_buffer, + batch_size=self.bs, + weight_column_major=True, + seg_lens=self.seq_lens, + weight_indices=self.weight_indices, + ) + # FIXME + assert lora_a_output.shape[-1] == self.lora_rank * 2 + lora_output = torch.empty_like(base_output) + output_dim = lora_output.shape[-1] // 2 + for i in range(2): + left = output_dim * i + right = left + output_dim + lora_output[:, left:right] = self.segment_gemm.run( + x=lora_a_output[ + :, self.lora_rank * i : self.lora_rank * (i + 1) + ].contiguous(), + weights=self.B_buffer[:, left:right, :].contiguous(), + batch_size=self.bs, + weight_column_major=True, + seg_lens=self.seq_lens, + weight_indices=self.weight_indices, + ) + return base_output + lora_output * self.scaling + + +class QKVParallelLinearWithLoRA(ColumnParallelLinearWithLoRA): + def __init__( + self, base_layer: QKVParallelLinear, segment_gemm, lora_rank, scaling + ) -> None: + super().__init__(base_layer, segment_gemm, lora_rank, scaling) + + def set_lora_info( + self, A_buffer_qkv, B_buffer_q, B_buffer_kv, bs, seq_lens, weight_indices + ): + self.set_lora = True + self.A_buffer_qkv = A_buffer_qkv + self.B_buffer_q = B_buffer_q + self.B_buffer_kv = B_buffer_kv + self.bs = bs + self.seq_lens = seq_lens + self.weight_indices = weight_indices + + def apply_lora(self, base_output: torch.Tensor, x: torch.Tensor) -> torch.Tensor: + lora_a_output = self.segment_gemm.run( + x=x, + weights=self.A_buffer_qkv, + batch_size=self.bs, + weight_column_major=True, + seg_lens=self.seq_lens, + weight_indices=self.weight_indices, + ) + # FIXME parallelize qkv + lora_output = torch.empty_like(base_output) + # q + output_dim_q = self.B_buffer_q.shape[-2] + lora_output[:, :output_dim_q] = self.segment_gemm.run( + x=lora_a_output[:, : self.lora_rank].contiguous(), + weights=self.B_buffer_q, + batch_size=self.bs, + weight_column_major=True, + seg_lens=self.seq_lens, + weight_indices=self.weight_indices, + ) + # kv + output_dim_kv = self.B_buffer_kv.shape[-2] // 2 + for i in range(2): + left = output_dim_kv * i + right = left + output_dim_kv + lora_output[:, output_dim_q + left : output_dim_q + right] = ( + self.segment_gemm.run( + x=lora_a_output[ + :, self.lora_rank * (i + 1) : self.lora_rank * (i + 2) + ].contiguous(), + weights=self.B_buffer_kv[:, left:right, :].contiguous(), + batch_size=self.bs, + weight_column_major=True, + seg_lens=self.seq_lens, + weight_indices=self.weight_indices, + ) + ) + return base_output + lora_output * self.scaling + + +class RowParallelLinearWithLoRA(BaseLayerWithLoRA): + def __init__( + self, base_layer: RowParallelLinear, segment_gemm, lora_rank, scaling + ) -> None: + super().__init__(base_layer, segment_gemm, lora_rank, scaling) + + def set_lora_info(self, A_buffer, B_buffer, bs, seq_lens, weight_indices): + self.set_lora = True + self.A_buffer = A_buffer + self.B_buffer = B_buffer + self.bs = bs + self.seq_lens = seq_lens + self.weight_indices = weight_indices + + def apply_lora(self, base_output: torch.Tensor, x: torch.Tensor) -> torch.Tensor: + lora_output = self.segment_gemm.run( + x=x, + weights=self.A_buffer, + batch_size=self.bs, + weight_column_major=True, + seg_lens=self.seq_lens, + weight_indices=self.weight_indices, + ) + lora_output = self.segment_gemm.run( + x=lora_output, + weights=self.B_buffer, + batch_size=self.bs, + weight_column_major=True, + seg_lens=self.seq_lens, + weight_indices=self.weight_indices, + ) + return base_output + lora_output * self.scaling + + def forward(self, input_): + # duplicate the logic in RowParallelLinear + if self.base_layer.input_is_parallel: + input_parallel = input_ + else: + tp_rank = get_tensor_model_parallel_rank() + splitted_input = split_tensor_along_last_dim( + input_, num_partitions=self.base_layer.tp_size + ) + input_parallel = splitted_input[tp_rank].contiguous() + output_parallel = self.base_layer.quant_method.apply( + self.base_layer, input_parallel + ) + + if self.set_lora: + output_parallel = self.apply_lora(output_parallel, input_parallel) + + if self.base_layer.reduce_results and self.base_layer.tp_size > 1: + output_ = tensor_model_parallel_all_reduce(output_parallel) + else: + output_ = output_parallel + + if not self.base_layer.skip_bias_add: + output = ( + output_ + self.base_layer.bias + if self.base_layer.bias is not None + else output_ + ) + output_bias = None + else: + output = output_ + output_bias = self.base_layer.bias + return output, output_bias + + +def get_lora_layer( + layer: nn.Module, segment_gemm, lora_rank, scaling +) -> BaseLayerWithLoRA: + supported_layer_types = { + # the order matters + VocabParallelEmbedding: VocabParallelEmbeddingWithLoRA, + QKVParallelLinear: QKVParallelLinearWithLoRA, + MergedColumnParallelLinear: MergedColumnParallelLinearWithLoRA, + ColumnParallelLinear: ColumnParallelLinearWithLoRA, + RowParallelLinear: RowParallelLinearWithLoRA, + } + for src_layer_type, lora_layer_type in supported_layer_types.items(): + if isinstance(layer, src_layer_type): # pylint: disable=unidiomatic-typecheck + ret = lora_layer_type(layer, segment_gemm, lora_rank, scaling) + return ret + raise Exception(f"No corresponding LoRA layer supported for {type(layer)}.") + + +def get_mapped_params(module_names): + ret = set() + for module_name in module_names: + ret.add(params_mapping(module_name)) + return list(ret) + + +class LoRALayer(nn.Module): + def __init__(self, config, base_hf_config): + super().__init__() + self.config = config + self.base_hf_config = base_hf_config + self.weights = {} + self.weight_gpu = {} + + def load_to_gpu(self): + for name, weight in self.weights.items(): + self.weight_gpu[name] = weight.to(torch.float16).to("cuda") + + def offload_from_gpu(self): + for name, weight in self.weights.items(): + self.weight_gpu[name] = None + + +class LoRAAdapter(nn.Module): + def __init__(self, uid, config, base_hf_config, load_config): + super().__init__() + self.uid = uid + self.config = config + assert self.config.hf_config["peft_type"].lower() == "lora" + self.base_hf_config = base_hf_config + self.load_config = load_config + self.scaling = self.config.lora_alpha / self.config.r + + self.layers = nn.ModuleList( + [ + LoRALayer(config, base_hf_config) + for i in range(base_hf_config.num_hidden_layers) + ] + ) + + self.weights = {} + self.weights_gpu = {} + + def get_stacked_multiply(self, module_name): + stacked_rank = { + "qkv_proj": 3, + "kv_proj": 2, + "gate_up_proj": 2, + } + return stacked_rank[module_name] if module_name in stacked_rank else 1 + + def load_to_gpu(self): + for name, weight in self.weights.items(): + self.weights_gpu[name] = weight.to(torch.float16).to("cuda") + for layer in self.layers: + layer.load_to_gpu() + + def offload_from_gpu(self): + for name, weight in self.weights.items(): + self.weights_gpu[name] = None + for layer in self.layers: + layer.offload_from_gpu() + + # initialize the LoRA weights to cpu + def initialize_weights(self): + model_path = self.config.path + loader = DefaultModelLoader(self.load_config) + revision = getattr(self.config.hf_config, "revision", None) + for name, loaded_weight in loader._get_weights_iterator( + model_path, revision=revision, fall_back_to_pt=True + ): + match = re.search(r"layers\.(\d+)\.", name) + if match is not None: + layer_id = int(match.group(1)) + self.layers[layer_id].weights[name] = loaded_weight.cpu() + else: + self.weights[name] = loaded_weight.cpu() + + # stack kv_proj and gate_up_proj + for i in range(self.base_hf_config.num_hidden_layers): + layer = self.layers[i] + weight_names = [name for name, _ in layer.weights.items()] + for weight_name in weight_names: + if "k_proj" in weight_name: + q_name = weight_name.replace("k_proj", "q_proj") + v_name = weight_name.replace("k_proj", "v_proj") + kv_name = weight_name.replace("k_proj", "kv_proj") + qkv_name = weight_name.replace("k_proj", "qkv_proj") + if "lora_A" in weight_name: + layer.weights[qkv_name] = torch.cat( + ( + layer.weights[q_name], + layer.weights[weight_name], + layer.weights[v_name], + ), + 0, + ) + layer.weights.pop(q_name) + layer.weights.pop(weight_name) + layer.weights.pop(v_name) + else: + layer.weights[kv_name] = torch.cat( + ( + layer.weights[weight_name], + layer.weights[v_name], + ), + 0, + ) + layer.weights.pop(weight_name) + layer.weights.pop(v_name) + elif "gate_proj" in weight_name: + up_name = weight_name.replace("gate_proj", "up_proj") + gate_up_name = weight_name.replace("gate_proj", "gate_up_proj") + layer.weights[gate_up_name] = torch.cat( + (layer.weights[weight_name], layer.weights[up_name]), 0 + ) + layer.weights.pop(weight_name) + layer.weights.pop(up_name) diff --git a/python/sglang/srt/lora/lora_config.py b/python/sglang/srt/lora/lora_config.py new file mode 100644 index 000000000..59af0c3a9 --- /dev/null +++ b/python/sglang/srt/lora/lora_config.py @@ -0,0 +1,43 @@ +""" +Copyright 2023-2024 SGLang Team +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" + +import json +import os + +from huggingface_hub import snapshot_download + + +class LoRAConfig: + def __init__( + self, + path: str, + ) -> None: + self.path = path + self.hf_config = self.get_lora_config() + self.target_modules = self.hf_config["target_modules"] + self.r = self.hf_config["r"] + self.lora_alpha = self.hf_config["lora_alpha"] + + def get_lora_config(self, dummy=False): + if dummy: + raise NotImplementedError() + else: + if not os.path.isdir(self.path): + weights_dir = snapshot_download(self.path, allow_patterns=["*.json"]) + else: + weights_dir = self.path + config_name = "adapter_config.json" + with open(os.path.join(weights_dir, config_name), "r") as f: + return json.load(f) diff --git a/python/sglang/srt/lora/lora_manager.py b/python/sglang/srt/lora/lora_manager.py new file mode 100644 index 000000000..5e11280a4 --- /dev/null +++ b/python/sglang/srt/lora/lora_manager.py @@ -0,0 +1,256 @@ +""" +Copyright 2023-2024 SGLang Team +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" + +# Integrates "S-LoRA: Serving Thousands of Concurrent LoRA Adapters" +# and "Punica: Multi-Tenant LoRA Serving" + + +import re +from dataclasses import dataclass + +import torch +from flashinfer import SegmentGEMMWrapper + +from sglang.srt.lora.lora import LoRAAdapter, get_lora_layer +from sglang.srt.lora.lora_config import LoRAConfig +from sglang.srt.model_executor.forward_batch_info import ForwardMode +from sglang.srt.utils import replace_submodule + + +def get_stacked_name(name): + # origin name -> (name for A, name for B) + params_mapping = { + "q_proj": ("qkv_proj", "q_proj"), + "k_proj": ("qkv_proj", "kv_proj"), + "v_proj": ("qkv_proj", "kv_proj"), + "gate_proj": ("gate_up_proj", "gate_up_proj"), + "up_proj": ("gate_up_proj", "gate_up_proj"), + } + return params_mapping.get(name, (name, name)) + + +def get_layer_id(name): + match = re.search(r"layers\.(\d+)\.", name) + if match is None: + return None + return int(match.group(1)) + + +class LoRAManager: + def __init__( + self, + base_model, + lora_paths, + base_hf_config, + max_loras_per_batch, + load_config, + dtype, + ): + self.base_model = base_model + self.lora_paths = lora_paths + self.base_hf_config = base_hf_config + self.max_loras_per_batch = max_loras_per_batch + self.load_config = load_config + self.dtype = dtype + + workspace_buffer = torch.empty(1 * 1024 * 1024, dtype=torch.int8, device="cuda") + self.segment_gemm = SegmentGEMMWrapper(workspace_buffer) + + self.init_loras() + self.init_lora_memory_pool() + self.init_lora_batch() + + def match_target_modules(self, module_name): + for target_module in self.target_modules: + if module_name.split(".")[-1] == target_module: + return True + return False + + def get_target_modules(self): + modules = [] + for module_name, module in self.base_model.named_modules(): + if self.match_target_modules(module_name): + modules.append((module_name, module)) + return modules + + def set_lora_module(self, module_name, module): + lora_module = get_lora_layer( + module, self.segment_gemm, self.max_lora_dim, self.scaling + ) + replace_submodule(self.base_model, module_name, lora_module) + return lora_module + + def init_loras(self): + # get configs and target modules + self.configs = {} + self.origin_target_modules = set() + for path in self.lora_paths: + self.configs[path] = LoRAConfig(path) + self.origin_target_modules = set(self.origin_target_modules) | set( + self.configs[path].target_modules + ) + self.target_modules = set( + [ + self.base_model.get_module_name(module) + for module in self.origin_target_modules + ] + ) + self.target_weights = set( + [get_stacked_name(module) for module in self.origin_target_modules] + ) + + # load all weights to cpu + self.loras = [] + self.lora_id = {} + for path in self.lora_paths: + self.lora_id[path] = len(self.loras) + self.loras.append( + LoRAAdapter( + path, self.configs[path], self.base_hf_config, self.load_config + ) + ) + self.loras[-1].initialize_weights() + + # misc lora configs + self.max_lora_dim = max([x.hf_config["r"] for x in self.configs.values()]) + self.scaling = self.loras[0].scaling + # FIXME remove the restrictions + assert all(x.hf_config["r"] == self.max_lora_dim for x in self.configs.values()) + assert all(x.scaling == self.scaling for x in self.loras) + + # monkey patch to use the LoRA version + self.lora_modules = [] + for module_name, module in self.get_target_modules(): + self.lora_modules.append( + (module_name, self.set_lora_module(module_name, module)) + ) + + def init_lora_memory_pool(self): + # preallocate lora memory pool + self.A_buffer = {} + self.B_buffer = {} + num_layer = self.base_hf_config.num_hidden_layers + for module_A, module_B in self.target_weights: + # init A tensor, column_major=True + hidden_dim_A, _ = self.base_model.get_hidden_dim(module_A) + c = self.loras[-1].get_stacked_multiply(module_A) + if module_A not in self.A_buffer: + self.A_buffer[module_A] = [ + torch.empty( + ( + self.max_loras_per_batch, + self.max_lora_dim * c, + hidden_dim_A, + ), + dtype=self.dtype, + device="cuda", + ) + for i in range(num_layer) + ] + # init B tensor, column_major=True + _, hidden_dim_B = self.base_model.get_hidden_dim(module_B) + c = self.loras[-1].get_stacked_multiply(module_B) + if module_B not in self.B_buffer: + self.B_buffer[module_B] = [ + torch.empty( + ( + self.max_loras_per_batch, + hidden_dim_B * c, + self.max_lora_dim, + ), + dtype=self.dtype, + device="cuda", + ) + for i in range(num_layer) + ] + + def init_lora_batch(self): + self.active_uids = set() # set of active loras + self.buffer_id = {} # lora uid -> idx in memory pool + + def get_weight_name(self, name, idx): + for target_weight_name in self.target_weights: + if target_weight_name[idx] in name: + return target_weight_name[idx] + + def load_lora(self, uid, buffer_id): + num_layer = self.base_hf_config.num_hidden_layers + if uid is None: + for i in range(num_layer): + for k in self.A_buffer.keys(): + self.A_buffer[k][i][buffer_id] *= 0 + return + + for i in range(num_layer): + layer_weights = self.loras[self.lora_id[uid]].layers[i].weights + for name, weights in layer_weights.items(): + if "lora_A" in name: + lora_weight_name = self.get_weight_name(name, 0) + if lora_weight_name: + self.A_buffer[lora_weight_name][i][buffer_id].copy_(weights) + else: + lora_weight_name = self.get_weight_name(name, 1) + if lora_weight_name: + self.B_buffer[lora_weight_name][i][buffer_id].copy_(weights) + + def prepare_lora_batch(self, batch, extend_seq_lens=None): + # load active loras into lora memory pool + cur_uids = set([req.lora_path for req in batch.reqs]) + assert len(cur_uids) <= self.max_loras_per_batch + i = 0 + evictable_uids = list(self.active_uids) + for uid in cur_uids: + if uid not in self.active_uids: + while i < len(evictable_uids) and evictable_uids[i] in cur_uids: + i += 1 + if i < len(evictable_uids): + self.active_uids.remove(evictable_uids[i]) + self.buffer_id.pop(evictable_uids[i]) + self.load_lora(uid, i) + self.active_uids.add(uid) + self.buffer_id[uid] = i + i += 1 + + if cur_uids == set([None]): + return + + # setup lora in forward modules + bs = len(batch.reqs) + seg_lens = extend_seq_lens if batch.forward_mode.is_extend() else torch.ones(bs) + weight_indices = torch.empty((bs,), dtype=torch.int64, device="cuda") + for i, req in enumerate(batch.reqs): + weight_indices[i] = self.buffer_id[req.lora_path] + + for module_name, module in self.lora_modules: + layer_id = get_layer_id(module_name) + + if "qkv_proj" not in module_name: + weight_name = self.get_weight_name(module_name, 0) + module.set_lora_info( + self.A_buffer[weight_name][layer_id], + self.B_buffer[weight_name][layer_id], + bs, + seg_lens, + weight_indices, + ) + else: + module.set_lora_info( + self.A_buffer["qkv_proj"][layer_id], + self.B_buffer["q_proj"][layer_id], + self.B_buffer["kv_proj"][layer_id], + bs, + seg_lens, + weight_indices, + ) diff --git a/python/sglang/srt/managers/io_struct.py b/python/sglang/srt/managers/io_struct.py index f5279eb8d..abd10a9f1 100644 --- a/python/sglang/srt/managers/io_struct.py +++ b/python/sglang/srt/managers/io_struct.py @@ -55,6 +55,9 @@ class GenerateReqInput: is_single: bool = True + # LoRA related + lora_path: Optional[Union[List[Optional[str]], Optional[str]]] = None + def post_init(self): if (self.text is None and self.input_ids is None) or ( self.text is not None and self.input_ids is not None @@ -184,6 +187,9 @@ class TokenizedGenerateReqInput: # Modalities of the input images modalites: Optional[List[str]] = None + # LoRA related + lora_path: Optional[str] = None # None means just use the base model + @dataclass class EmbeddingReqInput: diff --git a/python/sglang/srt/managers/schedule_batch.py b/python/sglang/srt/managers/schedule_batch.py index 0a5eb3cdf..6a8e4d9f1 100644 --- a/python/sglang/srt/managers/schedule_batch.py +++ b/python/sglang/srt/managers/schedule_batch.py @@ -98,7 +98,7 @@ class FINISH_ABORT(BaseFinishReason): class Req: """Store all inforamtion of a request.""" - def __init__(self, rid, origin_input_text, origin_input_ids): + def __init__(self, rid, origin_input_text, origin_input_ids, lora_path=None): # Input and output info self.rid = rid self.origin_input_text = origin_input_text @@ -106,6 +106,7 @@ class Req: self.origin_input_ids = origin_input_ids self.output_ids = [] # Each decode stage's output ids self.fill_ids = None # fill_ids = origin_input_ids + output_ids + self.lora_path = lora_path # Memory info self.req_pool_idx = None diff --git a/python/sglang/srt/managers/tokenizer_manager.py b/python/sglang/srt/managers/tokenizer_manager.py index d2fa67601..54b08b337 100644 --- a/python/sglang/srt/managers/tokenizer_manager.py +++ b/python/sglang/srt/managers/tokenizer_manager.py @@ -266,6 +266,11 @@ class TokenizerManager: top_logprobs_num, obj.stream, modalities, + ( + obj.lora_path[index] + if isinstance(obj.lora_path, list) + else obj.lora_path + ), ) else: # is embedding tokenized_obj = TokenizedEmbeddingReqInput( @@ -364,6 +369,11 @@ class TokenizerManager: obj.top_logprobs_num[index], obj.stream, modalities, + ( + obj.lora_path[index] + if isinstance(obj.lora_path, list) + else obj.lora_path + ), ) else: tokenized_obj = TokenizedEmbeddingReqInput( diff --git a/python/sglang/srt/managers/tp_worker.py b/python/sglang/srt/managers/tp_worker.py index b1131b011..096c13108 100644 --- a/python/sglang/srt/managers/tp_worker.py +++ b/python/sglang/srt/managers/tp_worker.py @@ -87,6 +87,8 @@ class ModelTpServer: self.dp_size = server_args.dp_size self.schedule_policy = server_args.schedule_policy self.disable_regex_jump_forward = server_args.disable_regex_jump_forward + self.lora_paths = server_args.lora_paths + self.max_loras_per_batch = server_args.max_loras_per_batch # Init model and tokenizer self.model_config = ModelConfig( @@ -323,7 +325,15 @@ class ModelTpServer: self, recv_req: TokenizedGenerateReqInput, ): - req = Req(recv_req.rid, recv_req.input_text, recv_req.input_ids) + if isinstance(recv_req, TokenizedGenerateReqInput): + req = Req( + recv_req.rid, + recv_req.input_text, + recv_req.input_ids, + lora_path=recv_req.lora_path, + ) + else: + req = Req(recv_req.rid, recv_req.input_text, recv_req.input_ids) req.tokenizer = self.tokenizer req.sampling_params = recv_req.sampling_params req.pixel_values = recv_req.pixel_values @@ -442,10 +452,27 @@ class ModelTpServer: self.current_inflight_req ) + if self.lora_paths is not None: + lora_set = ( + set([req.lora_path for req in self.running_batch.reqs]) + if self.running_batch is not None + else set([]) + ) + for req in self.waiting_queue: if adder.no_remaining_tokens(): break req.init_next_round_input(None if prefix_computed else self.tree_cache) + if ( + self.lora_paths is not None + and len( + lora_set + | set([req.lora_path for req in adder.can_run_list]) + | set([req.lora_path]) + ) + > self.max_loras_per_batch + ): + break res = adder.add_one_req(req) if ( not res diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py index 5868b0074..9fcb85454 100644 --- a/python/sglang/srt/model_executor/model_runner.py +++ b/python/sglang/srt/model_executor/model_runner.py @@ -41,6 +41,7 @@ from sglang.srt.configs.model_config import AttentionArch, ModelConfig from sglang.srt.layers.attention_backend import FlashInferAttnBackend, TritonAttnBackend from sglang.srt.layers.logits_processor import LogitsProcessorOutput from sglang.srt.layers.sampler import SampleOutput +from sglang.srt.lora.lora_manager import LoRAManager from sglang.srt.managers.schedule_batch import ScheduleBatch, global_server_args_dict from sglang.srt.mem_cache.memory_pool import ( MHATokenToKVPool, @@ -107,6 +108,8 @@ class ModelRunner: # Init componnets min_per_gpu_memory = self.init_torch_distributed() self.load_model() + if server_args.lora_paths is not None: + self.init_lora_manager() self.init_memory_pool( min_per_gpu_memory, server_args.max_running_requests, @@ -312,6 +315,17 @@ class ModelRunner: logger.info("Update weights end.") return True, "Succeeded to update model weights" + def init_lora_manager(self): + self.lora_manager = LoRAManager( + base_model=self.model, + lora_paths=self.server_args.lora_paths, + base_hf_config=self.model_config.hf_config, + max_loras_per_batch=self.server_args.max_loras_per_batch, + load_config=self.load_config, + dtype=self.dtype, + ) + logger.info("LoRA manager ready.") + def profile_max_num_token(self, total_gpu_memory: int): available_gpu_memory = get_available_gpu_memory( self.gpu_id, distributed=self.tp_size > 1 @@ -450,6 +464,8 @@ class ModelRunner: @torch.inference_mode() def forward_decode(self, batch: ScheduleBatch): + if self.server_args.lora_paths is not None: + self.lora_manager.prepare_lora_batch(batch) if ( self.cuda_graph_runner and self.cuda_graph_runner.can_run(len(batch.reqs)) @@ -466,6 +482,9 @@ class ModelRunner: @torch.inference_mode() def forward_extend(self, batch: ScheduleBatch): input_metadata = InputMetadata.from_schedule_batch(self, batch) + if self.server_args.lora_paths is not None: + self.lora_manager.prepare_lora_batch(batch, input_metadata.extend_seq_lens) + if self.is_generation: return self.model.forward( batch.input_ids, input_metadata.positions, input_metadata diff --git a/python/sglang/srt/models/llama.py b/python/sglang/srt/models/llama.py index ac53712fc..c45d9bcd8 100644 --- a/python/sglang/srt/models/llama.py +++ b/python/sglang/srt/models/llama.py @@ -324,6 +324,51 @@ class LlamaForCausalLM(nn.Module): sample_output = self.sampler(logits_output, input_metadata.sampling_info) return sample_output, logits_output + def get_hidden_dim(self, module_name): + if module_name in ["q_proj", "o_proj", "qkv_proj"]: + return self.config.hidden_size, self.config.hidden_size + elif module_name in ["kv_proj"]: + return self.config.hidden_size, self.config.hidden_size // ( + self.config.num_attention_heads // self.config.num_key_value_heads + ) + elif module_name == "gate_up_proj": + return self.config.hidden_size, self.config.intermediate_size + elif module_name == "down_proj": + return self.config.intermediate_size, self.config.hidden_size + else: + raise NotImplementedError() + + def get_module_name(self, name): + params_mapping = { + "q_proj": "qkv_proj", + "k_proj": "qkv_proj", + "v_proj": "qkv_proj", + "gate_proj": "gate_up_proj", + "up_proj": "gate_up_proj", + } + return params_mapping.get(name, name) + + def get_module_name_from_weight_name(self, name): + stacked_params_mapping = [ + # (param_name, shard_name, shard_id, num_shard) + ("qkv_proj", "q_proj", "q", 3), + ("qkv_proj", "k_proj", "k", 3), + ("qkv_proj", "v_proj", "v", 3), + ("gate_up_proj", "gate_proj", 0, 2), + ("gate_up_proj", "up_proj", 1, 2), + ] + for param_name, weight_name, shard_id, num_shard in stacked_params_mapping: + if weight_name in name: + return ( + name.replace(weight_name, param_name)[: -len(".weight")], + num_shard, + ) + return name[: -len(".weight")], 1 + + def get_num_params(self): + params_dict = dict(self.named_parameters()) + return len(params_dict) + def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): stacked_params_mapping = [ # (param_name, shard_name, shard_id) diff --git a/python/sglang/srt/server.py b/python/sglang/srt/server.py index 52806daa9..d2a248a92 100644 --- a/python/sglang/srt/server.py +++ b/python/sglang/srt/server.py @@ -611,6 +611,7 @@ class Runtime: return_logprob: Optional[Union[List[bool], bool]] = False, logprob_start_len: Optional[Union[List[int], int]] = None, top_logprobs_num: Optional[Union[List[int], int]] = None, + lora_path: Optional[List[Optional[str]]] = None, ): json_data = { "text": prompt, @@ -618,7 +619,9 @@ class Runtime: "return_logprob": return_logprob, "logprob_start_len": logprob_start_len, "top_logprobs_num": top_logprobs_num, + "lora_path": lora_path, } + assert not isinstance(lora_path, list) or len(lora_path) == len(prompt) response = requests.post( self.url + "/generate", json=json_data, diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py index 36a16bc9f..a35c5b423 100644 --- a/python/sglang/srt/server_args.py +++ b/python/sglang/srt/server_args.py @@ -101,6 +101,10 @@ class ServerArgs: enable_mla: bool = False triton_attention_reduce_in_fp32: bool = False + # LoRA + lora_paths: Optional[List[str]] = None + max_loras_per_batch: int = 8 + def __post_init__(self): # Set missing default values if self.tokenizer_path is None: @@ -522,6 +526,21 @@ class ServerArgs: help="Turn on memory efficient weight loading with quantization (quantize per layer during loading).", ) + # LoRA options + parser.add_argument( + "--lora-paths", + type=str, + nargs="*", + default=None, + help="The list of LoRA adapters.", + ) + parser.add_argument( + "--max-loras-per-batch", + type=int, + default=8, + help="Maximum number of adapters for a running batch, include base-only request", + ) + @classmethod def from_cli_args(cls, args: argparse.Namespace): args.tp_size = args.tensor_parallel_size @@ -539,6 +558,12 @@ class ServerArgs: assert not ( self.dp_size > 1 and self.node_rank is not None ), "multi-node data parallel is not supported" + assert ( + self.max_loras_per_batch > 0 + # FIXME + and (self.lora_paths is None or self.disable_cuda_graph) + and (self.lora_paths is None or self.disable_radix_cache) + ), "compatibility of lora and cuda graph and radix attention is in progress" def prepare_server_args(argv: List[str]) -> ServerArgs: diff --git a/python/sglang/srt/utils.py b/python/sglang/srt/utils.py index 66a5679d7..125bb556f 100644 --- a/python/sglang/srt/utils.py +++ b/python/sglang/srt/utils.py @@ -35,6 +35,7 @@ import torch import torch.distributed as dist from fastapi.responses import JSONResponse from packaging import version as pkg_version +from torch import nn from torch.nn.parameter import Parameter from triton.runtime.cache import ( FileCacheManager, @@ -714,3 +715,14 @@ def configure_logger(server_args, prefix: str = ""): datefmt="%H:%M:%S", force=True, ) + + +# source: https://github.com/vllm-project/vllm/blob/93b38bea5dd03e1b140ca997dfaadef86f8f1855/vllm/lora/utils.py#L9 +def replace_submodule( + model: nn.Module, module_name: str, new_module: nn.Module +) -> nn.Module: + """Replace a submodule in a model with a new module.""" + parent = model.get_submodule(".".join(module_name.split(".")[:-1])) + target_name = module_name.split(".")[-1] + setattr(parent, target_name, new_module) + return new_module diff --git a/python/sglang/test/runners.py b/python/sglang/test/runners.py index 1d18d305f..68dd43dad 100644 --- a/python/sglang/test/runners.py +++ b/python/sglang/test/runners.py @@ -21,6 +21,7 @@ from typing import List, Union import torch import torch.nn.functional as F +from peft import PeftModel from transformers import AutoModelForCausalLM, AutoTokenizer from sglang.srt.server import Runtime @@ -52,6 +53,7 @@ def get_dtype_str(torch_dtype): def get_top_logprobs(logits, k): logprobs = F.log_softmax(logits, dim=-1, dtype=torch.float32) + del logits logprobs, top_indices = torch.topk(logprobs, k=k, dim=-1) return logprobs @@ -71,8 +73,10 @@ class HFRunner: model_path, torch_dtype, is_generation, + output_str_only=False, ): self.is_generation = is_generation + self.output_str_only = output_str_only self.in_queue = mp.Queue() self.out_queue = mp.Queue() @@ -95,7 +99,7 @@ class HFRunner: ) if self.is_generation: - self.model = AutoModelForCausalLM.from_pretrained( + self.base_model = AutoModelForCausalLM.from_pretrained( model_path, torch_dtype=torch_dtype, trust_remote_code=False, @@ -110,13 +114,16 @@ class HFRunner: ) while True: - prompts, max_new_tokens = in_queue.get() + prompts, max_new_tokens, lora_paths = in_queue.get() + if lora_paths is not None: + assert len(prompts) == len(lora_paths) + if prompts is not None: if self.is_generation: output_strs = [] top_input_logprobs = [] top_output_logprobs = [] - for p in prompts: + for i, p in enumerate(prompts): if isinstance(p, str): input_ids = self.tokenizer.encode( p, return_tensors="pt" @@ -124,6 +131,16 @@ class HFRunner: else: input_ids = torch.tensor([p], device="cuda") + if lora_paths is not None and lora_paths[i] is not None: + self.model = PeftModel.from_pretrained( + self.base_model, + lora_paths[i], + torch_dtype=torch_dtype, + is_trainable=False, + ) + else: + self.model = self.base_model + outputs = self.model.generate( input_ids, do_sample=False, @@ -131,25 +148,30 @@ class HFRunner: top_p=None, max_new_tokens=max_new_tokens, return_dict_in_generate=True, - output_scores=True, + output_scores=(not self.output_str_only), ) output_strs.append( self.tokenizer.decode(outputs[0][0][len(input_ids[0]) :]) ) - # outputs.scores: (num_token, 1, vocab_size) - top_output_logprobs.append( - [ - get_top_logprobs(logits[0], NUM_TOP_LOGPROBS).tolist() - for logits in outputs.scores - ] - ) - del outputs + if not self.output_str_only: + # outputs.scores: (num_token, 1, vocab_size) + top_output_logprobs.append( + [ + get_top_logprobs( + logits[0], NUM_TOP_LOGPROBS + ).tolist() + for logits in outputs.scores + ] + ) + del outputs - input_logits = self.model.forward(input_ids).logits[0] - top_input_logprobs.append( - get_top_logprobs(input_logits, NUM_TOP_LOGPROBS).tolist() - ) - del input_logits + input_logits = self.model.forward(input_ids).logits[0] + top_input_logprobs.append( + get_top_logprobs( + input_logits, NUM_TOP_LOGPROBS + ).tolist() + ) + del input_logits out_queue.put( ModelOutput( @@ -160,6 +182,7 @@ class HFRunner: ) else: + assert not self.output_str_only logits = self.model.encode(prompts).tolist() out_queue.put(ModelOutput(embed_logits=logits)) @@ -167,8 +190,9 @@ class HFRunner: self, prompts: Union[List[str], List[torch.Tensor]] = DEFAULT_PROMPTS, max_new_tokens=8, + lora_paths=None, ): - self.in_queue.put((prompts, max_new_tokens)) + self.in_queue.put((prompts, max_new_tokens, lora_paths)) return self.out_queue.get() def terminate(self): @@ -191,6 +215,10 @@ class SRTRunner: is_generation, tp_size=1, port=DEFAULT_PORT_FOR_SRT_TEST_RUNNER, + lora_paths=None, + max_loras_per_batch=4, + disable_cuda_graph=False, + disable_radix_cache=False, ): self.is_generation = is_generation self.runtime = Runtime( @@ -201,12 +229,17 @@ class SRTRunner: mem_fraction_static=0.69, trust_remote_code=False, is_embedding=not self.is_generation, + lora_paths=lora_paths, + max_loras_per_batch=max_loras_per_batch, + disable_cuda_graph=disable_cuda_graph, + disable_radix_cache=disable_radix_cache, ) def forward( self, prompts: Union[List[str], List[torch.Tensor]] = DEFAULT_PROMPTS, max_new_tokens=8, + lora_paths=None, ): if self.is_generation: # the return value contains logprobs from prefill @@ -214,9 +247,10 @@ class SRTRunner: top_input_logprobs = [] top_output_logprobs = [] sampling_params = {"max_new_tokens": max_new_tokens, "temperature": 0} - for prompt in prompts: + for i, prompt in enumerate(prompts): response = self.runtime.generate( prompt, + lora_path=lora_paths[i] if lora_paths else None, sampling_params=sampling_params, return_logprob=True, logprob_start_len=0, @@ -256,6 +290,37 @@ class SRTRunner: logits = [x["embedding"] for x in response] return ModelOutput(embed_logits=logits) + def batch_forward( + self, + prompts: Union[List[str], List[torch.Tensor]] = DEFAULT_PROMPTS, + max_new_tokens=8, + lora_paths=None, + ): + """ + testing serving by sending all prompts once + only return output strings and no logprobs + """ + if self.is_generation: + # the return value contains logprobs from prefill + output_strs = [] + sampling_params = {"max_new_tokens": max_new_tokens, "temperature": 0} + response = self.runtime.generate( + prompts, + lora_path=lora_paths if lora_paths else None, + sampling_params=sampling_params, + ) + response = json.loads(response) + output_strs = [r["text"] for r in response] + + return ModelOutput( + output_strs=output_strs, + ) + else: + response = self.runtime.encode(prompts) + response = json.loads(response) + logits = [x["embedding"] for x in response] + return ModelOutput(embed_logits=logits) + def __enter__(self): return self diff --git a/scripts/playground/lora/lora_hf_play.py b/scripts/playground/lora/lora_hf_play.py new file mode 100644 index 000000000..127cea696 --- /dev/null +++ b/scripts/playground/lora/lora_hf_play.py @@ -0,0 +1,62 @@ +import torch +from peft import PeftModel +from transformers import LlamaForCausalLM, LlamaTokenizer + +MODEL = "mistralai/Mistral-7B-Instruct-v0.3" +# ADAPTER = "winddude/wizardLM-LlaMA-LoRA-7B" +ADAPTER = "/home/ying/test_lora" +HF_TOKEN = "..." + + +prompt = """ +### Instruction: +Write a poem about the transformers Python library. +Mention the word "large language models" in that poem. +### Response: +The Transformers are large language models, +They're used to make predictions on text. +""" + + +tokenizer = LlamaTokenizer.from_pretrained(MODEL) + +base_model = LlamaForCausalLM.from_pretrained( + MODEL, + device_map="auto", + # load_in_8bit=True, + torch_dtype=torch.float16, + # use_auth_token=HF_TOKEN, +).cuda() + + +# base model generate +with torch.no_grad(): + output_tensors = base_model.generate( + input_ids=tokenizer(prompt, return_tensors="pt").input_ids.cuda(), + max_new_tokens=32, + do_sample=False, + )[0] + +output = tokenizer.decode(output_tensors, skip_special_tokens=True) +print("======= base output ========") +print(output) + + +# peft model generate +model = PeftModel.from_pretrained( + base_model, + ADAPTER, + torch_dtype=torch.float16, + is_trainable=False, +) + +with torch.no_grad(): + output_tensors = model.generate( + input_ids=tokenizer(prompt, return_tensors="pt").input_ids.cuda(), + max_new_tokens=32, + do_sample=False, + )[0] + +output = tokenizer.decode(output_tensors, skip_special_tokens=True) +print("======= peft output ========") +print(output) diff --git a/scripts/playground/lora/lora_vllm_play.py b/scripts/playground/lora/lora_vllm_play.py new file mode 100644 index 000000000..75762d739 --- /dev/null +++ b/scripts/playground/lora/lora_vllm_play.py @@ -0,0 +1,30 @@ +from vllm import LLM, SamplingParams +from vllm.lora.request import LoRARequest + +MODEL = "mistralai/Mistral-7B-Instruct-v0.3" +ADAPTER = "/home/ying/test_lora" +prompt = """ +### Instruction: +Write a poem about the transformers Python library. +Mention the word "large language models" in that poem. +### Response: +The Transformers are large language models, +They're used to make predictions on text. +""" + + +llm = LLM(model=MODEL, enable_lora=True) + +sampling_params = SamplingParams( + temperature=0, + max_tokens=32, +) + +prompts = [prompt] + +outputs = llm.generate( + prompts, sampling_params, lora_request=LoRARequest("test_lora", 1, ADAPTER) +) + +print(outputs[0].prompt) +print(outputs[0].outputs[0].text) diff --git a/scripts/playground/lora/test_lora.py b/scripts/playground/lora/test_lora.py new file mode 100644 index 000000000..069020c42 --- /dev/null +++ b/scripts/playground/lora/test_lora.py @@ -0,0 +1,55 @@ +import json + +import openai +import requests + +import sglang as sgl + +lora_path = "/home/ying/test_lora" +prompt_file = "/home/ying/test_prompt/dialogue_choice_prompts.json" +server_url = "http://127.0.0.1:30000" + +client = openai.Client(base_url=server_url + "/v1", api_key="EMPTY") + + +# @sgl.function +# def generate(s, prompt): +# s += prompt +# s += sgl.gen("ans") + +# sgl.set_default_backend(sgl.RuntimeEndpoint(server_url)) + + +def generate(prompt, lora_path): + json_data = { + "text": prompt, + "sampling_params": {}, + "return_logprob": False, + "logprob_start_len": None, + "top_logprobs_num": None, + "lora_path": lora_path, + } + response = requests.post( + server_url + "/generate", + json=json_data, + ) + return json.dumps(response.json()) + + +with open(prompt_file, "r") as f: + samples = json.load(f) + + +for sample in samples[:1]: + assert sample[0]["role"] == "user" + prompt = sample[0]["content"] + assert sample[1]["role"] == "assistant" + ref = sample[1]["content"] + + state = generate(prompt, lora_path) + print("================================") + print(ref) + print("--------------------------------") + # print(state["ans"]) + print(state) + print() diff --git a/test/srt/models/compare.py b/test/srt/models/compare.py new file mode 100644 index 000000000..2fe35357c --- /dev/null +++ b/test/srt/models/compare.py @@ -0,0 +1,52 @@ +""" +used for debug using tensor comparison +dump {name: tensor} into "log_hf.jsonl" and "log_srt.jsonl" +use the same name for two tensors that supposed to be close +recommend name like: "layer 2 after mlp" +""" + +import json +import sys + +import torch + +if len(sys.argv) > 1: + assert sys.argv[1] == "base" + hf_log = "base_log_hf.jsonl" + srt_log = "base_log_srt.jsonl" +else: + hf_log = "log_hf.jsonl" + srt_log = "log_srt.jsonl" + + +def load_data(filepath): + tensors = {} + with open(filepath, "r") as f: + lines = f.readlines() + for line in lines: + data = json.loads(line) + for k, v in data.items(): + tensors[k] = torch.tensor(v) + return tensors + + +hf_tensors = load_data(hf_log) +srt_tensors = load_data(srt_log) + + +def get_diff(t1, t2): + t1 = t1.reshape(t2.shape) + max_diff = torch.max(abs(t1.reshape(t2.shape) - t2)) + l2_dis = torch.dist(t1, t2, p=2) + return l2_dis, max_diff + + +for k, _ in srt_tensors.items(): + l2_dis, max_diff = get_diff(hf_tensors[k], srt_tensors[k]) + print(f"{k} {l2_dis=} {max_diff=}") + if k == "layer 1 attn": + print(hf_tensors[k]) + print(srt_tensors[k]) + if k == "layer 0 prefill k": + print(srt_tensors[k].shape) + print(hf_tensors[k].shape) diff --git a/test/srt/models/test_generation_models.py b/test/srt/models/test_generation_models.py index 46854b3e8..341b856e3 100644 --- a/test/srt/models/test_generation_models.py +++ b/test/srt/models/test_generation_models.py @@ -76,6 +76,7 @@ class TestGenerationModels(unittest.TestCase): ) -> None: if model_path == "Alibaba-NLP/gte-Qwen2-1.5B-instruct": prompts = prompts[:-1] + with HFRunner( model_path, torch_dtype=torch_dtype, is_generation=True ) as hf_runner: diff --git a/test/srt/models/test_lora.py b/test/srt/models/test_lora.py new file mode 100644 index 000000000..51f20e492 --- /dev/null +++ b/test/srt/models/test_lora.py @@ -0,0 +1,297 @@ +""" +Copyright 2023-2024 SGLang Team +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" + +import multiprocessing as mp +import unittest +import uuid + +import torch + +from sglang.test.runners import HFRunner, SRTRunner + +LORA_SETS = [ + # { + # "base": "meta-llama/Llama-2-7b-hf", + # "loras": ["RuterNorway/Llama-2-7b-chat-norwegian-LoRa"], + # }, + {"base": "meta-llama/Llama-2-7b-hf", "loras": ["winddude/wizardLM-LlaMA-LoRA-7B"]}, + # {"base": "mistralai/Mistral-7B-Instruct-v0.3", "loras": ["/home/ying/test_lora"]}, + # { + # "base": "mistralai/Mistral-7B-Instruct-v0.3", + # "loras": [ + # "/home/ying/test_lora", + # "/home/ying/test_lora_1", + # "/home/ying/test_lora_2", + # "/home/ying/test_lora_3", + # "/home/ying/test_lora_4", + # ], + # }, + # {"base": "meta-llama/Llama-2-7b-hf", "loras": ["yard1/llama-2-7b-sql-lora-test"]}, +] +TORCH_DTYPES = [torch.float16] + +PROMPTS = [ + """ +### Instruction: +Write a poem about the transformers Python library. +Mention the word "large language models" in that poem. +### Response: +The Transformers are large language models, +They're used to make predictions on text. +""", + """ +### Instruction: +Tell me about llamas and alpacas +### Response: +Llamas are large, long-necked animals with a woolly coat. They have two toes on each foot instead of three like other camelids (camels, dromedaries). Llamas live in the Andean mountains of South America where they graze on grasses and shrubs. Alpaca is another name for domesticated llama. The word "alpaca" comes from an Incan language meaning "golden fleece." Alpacas look very similar to llamas but are smaller than their wild relatives. Both species were used by ancient people as pack animals and for meat. Today both llamas and alpacas are raised primarily for their fiber which can be spun into yarn or knitted into clothing. +### Question 2: +What do you know about llamas? +### Answer: +""", +] + +# import json +# +# with open("/home/ying/test_prompt/dialogue_choice_prompts.json", "r") as f: +# samples = json.load(f) +# for sample in samples[:5]: +# assert sample[0]["role"] == "user" +# PROMPTS.append(sample[0]["content"][:2000]) + + +class TestLoRA(unittest.TestCase): + + def inference(self, prompts, lora_set, tp_size, torch_dtype, max_new_tokens): + print("=================== testing inference =======================") + base_path = lora_set["base"] + all_lora_paths = lora_set["loras"] + batch_lora_paths = [None] + i = 0 + for _ in range(len(prompts) - 1): + batch_lora_paths.append(all_lora_paths[i]) + i = (i + 1) % len(all_lora_paths) + + with SRTRunner( + base_path, + tp_size=tp_size, + torch_dtype=torch_dtype, + is_generation=True, + lora_paths=all_lora_paths, + max_loras_per_batch=3, + disable_cuda_graph=True, + disable_radix_cache=True, + ) as srt_runner: + srt_outputs = srt_runner.forward( + prompts, max_new_tokens=max_new_tokens, lora_paths=batch_lora_paths + ) + + with HFRunner( + base_path, + torch_dtype=torch_dtype, + is_generation=True, + ) as hf_runner: + hf_outputs = hf_runner.forward( + prompts, max_new_tokens=max_new_tokens, lora_paths=batch_lora_paths + ) + + with HFRunner( + base_path, + torch_dtype=torch_dtype, + is_generation=True, + ) as hf_runner: + hf_no_lora_outputs = hf_runner.forward( + prompts, max_new_tokens=max_new_tokens + ) + + with SRTRunner( + base_path, + tp_size=tp_size, + torch_dtype=torch_dtype, + is_generation=True, + ) as srt_runner: + srt_no_lora_outputs = srt_runner.forward( + prompts, max_new_tokens=max_new_tokens + ) + + for i in range(len(prompts)): + # compare input logprobs + hf_logprobs = torch.Tensor(hf_outputs.top_input_logprobs[i]) + srt_logprobs = torch.Tensor(srt_outputs.top_input_logprobs[i]) + hf_no_lora_logprobs = torch.Tensor(hf_no_lora_outputs.top_input_logprobs[i]) + srt_no_lora_logprobs = torch.Tensor( + srt_no_lora_outputs.top_input_logprobs[i] + ) + print( + "max input diff between hf_lora and srt_lora", + torch.max(abs(hf_logprobs - srt_logprobs)), + ) + print( + "max input diff between srt_base and srt_lora", + torch.max(abs(srt_no_lora_logprobs - srt_logprobs)), + ) + print( + "max input diff between srt_base and hf_base", + torch.max(abs(srt_no_lora_logprobs - hf_no_lora_logprobs)), + ) + print( + "max input diff between hf_lora and hf_base", + torch.max(abs(hf_logprobs - hf_no_lora_logprobs)), + ) + + # compare output logprobs + hf_logprobs = torch.Tensor(hf_outputs.top_output_logprobs[i]) + srt_logprobs = torch.Tensor(srt_outputs.top_output_logprobs[i]) + # print( + # "\noutput logprobs diff", + # [ + # float(torch.max(abs(hf_logprobs[j] - srt_logprobs[j]))) + # for j in range(max_new_tokens) + # ], + # ) + print( + "max output diff between hf_lora and srt_lora", + torch.max(abs(hf_logprobs - srt_logprobs)), + "\n", + ) + + # compare output strings + print(f"{hf_outputs.output_strs=}") + print(f"{srt_outputs.output_strs=}") + print(f"{hf_no_lora_outputs.output_strs=}") + print(f"{srt_no_lora_outputs.output_strs=}") + for i in range(len(prompts)): + assert srt_outputs.output_strs[i].strip(" ") == hf_outputs.output_strs[i], ( + str_outputs.output_strs[i].strip(" "), + hf_outputs.output_strs[i], + ) + # assert ( + # srt_no_lora_outputs.output_strs[i].strip(" ") + # == hf_no_lora_outputs.output_strs[i] + # ), ( + # srt_no_lora_outputs.output_strs[i].strip(" "), + # hf_no_lora_outputs.output_strs[i], + # ) + + def serving(self, prompts, lora_set, tp_size, torch_dtype, max_new_tokens): + print("=================== testing serving =======================") + # test batch forward + base_path = lora_set["base"] + all_lora_paths = lora_set["loras"] + batch_lora_paths = [None] + i = 0 + for _ in range(len(prompts) - 1): + batch_lora_paths.append(all_lora_paths[i]) + i = (i + 1) % len(all_lora_paths) + + with SRTRunner( + base_path, + tp_size=tp_size, + torch_dtype=torch_dtype, + is_generation=True, + lora_paths=all_lora_paths, + max_loras_per_batch=3, + disable_cuda_graph=True, + disable_radix_cache=True, + ) as srt_runner: + srt_outputs = srt_runner.batch_forward( + prompts, max_new_tokens=max_new_tokens, lora_paths=batch_lora_paths + ) + + with HFRunner( + base_path, + torch_dtype=torch_dtype, + is_generation=True, + output_str_only=True, + ) as hf_runner: + hf_outputs = hf_runner.forward( + prompts, max_new_tokens=max_new_tokens, lora_paths=batch_lora_paths + ) + + # compare output strings + print(f"{hf_outputs.output_strs=}") + print(f"{srt_outputs.output_strs=}") + for i in range(len(prompts)): + assert srt_outputs.output_strs[i].strip(" ") == hf_outputs.output_strs[i], ( + srt_outputs.output_strs[i].strip(" "), + hf_outputs.output_strs[i], + ) + + def base_inference(self, prompts, lora_set, tp_size, torch_dtype, max_new_tokens): + print("=================== testing base inference =======================") + base_path = lora_set["base"] + all_lora_paths = lora_set["loras"] + batch_lora_paths = [None] * len(prompts) + + with SRTRunner( + base_path, + tp_size=tp_size, + torch_dtype=torch_dtype, + is_generation=True, + ) as srt_runner: + srt_no_lora_outputs = srt_runner.forward( + prompts, max_new_tokens=max_new_tokens + ) + + with SRTRunner( + base_path, + tp_size=tp_size, + torch_dtype=torch_dtype, + is_generation=True, + lora_paths=all_lora_paths, + ) as srt_runner: + srt_outputs = srt_runner.forward( + prompts, max_new_tokens=max_new_tokens, lora_paths=batch_lora_paths + ) + + for i in range(len(prompts)): + srt_no_lora_logprobs = torch.Tensor( + srt_no_lora_outputs.top_input_logprobs[i] + ) + srt_logprobs = torch.uensor(srt_outputs.top_input_logprobs[i]) + print("max_diff", torch.max(abs(srt_no_lora_logprobs - srt_logprobs))) + + print(f"{srt_no_lora_outputs.output_strs=}") + print(f"{srt_outputs.output_strs=}") + + for i in range(len(prompts)): + assert srt_outputs.output_strs[i].strip(" ") == hf_outputs.output_strs[i], ( + str_outputs.output_strs[i].strip(" "), + hf_outputs.output_strs[i], + ) + assert ( + srt_no_lora_outputs[i].output_strs.strip(" ") + == hf_no_lora_outputs[i].output_strs + ) + + def test_all(self): + for lora_set in LORA_SETS: + # self.load_lora_adapter(lora_set, 1) + for torch_dtype in TORCH_DTYPES: + tp_size = 1 + max_new_tokens = 32 + self.inference(PROMPTS, lora_set, tp_size, torch_dtype, max_new_tokens) + # self.serving(PROMPTS, lora_set, tp_size, torch_dtype, max_new_tokens) + # self.base_inference( + # PROMPTS, lora_set, tp_size, torch_dtype, max_new_tokens + # ) + + +if __name__ == "__main__": + try: + mp.set_start_method("spawn") + except RuntimeError: + pass + + unittest.main(warnings="ignore") diff --git a/test/srt/run_suite.py b/test/srt/run_suite.py index d5982844c..bbea7215b 100644 --- a/test/srt/run_suite.py +++ b/test/srt/run_suite.py @@ -7,6 +7,7 @@ suites = { "minimal": [ "models/test_embedding_models.py", "models/test_generation_models.py", + "models/test_lora.py", "sampling/penaltylib", "test_chunked_prefill.py", "test_embedding_openai_server.py",