[Feature] Initial support for multi-LoRA serving (#1307)
@@ -27,7 +27,7 @@ srt = ["aiohttp", "decord", "fastapi", "hf_transfer", "huggingface_hub", "intere
 openai = ["openai>=1.0", "tiktoken"]
 anthropic = ["anthropic>=0.20.0"]
 litellm = ["litellm>=1.0.0"]
-test = ["jsonlines", "matplotlib", "pandas", "sentence_transformers", "accelerate"]
+test = ["jsonlines", "matplotlib", "pandas", "sentence_transformers", "accelerate", "peft"]
 all = ["sglang[srt]", "sglang[openai]", "sglang[anthropic]", "sglang[litellm]"]
 dev = ["sglang[all]", "sglang[test]"]
403 python/sglang/srt/lora/lora.py Normal file
@@ -0,0 +1,403 @@
"""
Copyright 2023-2024 SGLang Team
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""

# Integrates "S-LoRA: Serving Thousands of Concurrent LoRA Adapters"
# and "Punica: Multi-Tenant LoRA Serving"

# LoRA layers class inheritance adapted from:
# https://github.com/vllm-project/vllm/blob/4abf6336ec65c270343eb895e7b18786e9274176/vllm/lora/layers.py


import json
import os
import re
from typing import Any, Dict, List, Optional, Tuple

import safetensors.torch
import torch
from torch import nn

# Tensor-parallel communication helpers used by the forward() methods below
# (not visible in the diff view; assumed to come from vllm.distributed).
from vllm.distributed import (
    get_tensor_model_parallel_rank,
    split_tensor_along_last_dim,
    tensor_model_parallel_all_gather,
    tensor_model_parallel_all_reduce,
)
from vllm.model_executor.layers.linear import (
    ColumnParallelLinear,
    MergedColumnParallelLinear,
    QKVParallelLinear,
    RowParallelLinear,
)
from vllm.model_executor.layers.vocab_parallel_embedding import (
    ParallelLMHead,
    VocabParallelEmbedding,
)
from vllm.model_executor.model_loader.loader import DefaultModelLoader

from sglang.srt.model_executor.forward_batch_info import ForwardMode, InputMetadata


class BaseLayerWithLoRA(nn.Module):
    def __init__(self, base_layer, segment_gemm, lora_rank, scaling):
        super().__init__()
        self.base_layer = base_layer
        self.segment_gemm = segment_gemm
        self.lora_rank = lora_rank
        self.scaling = scaling
        self.set_lora = False

    def forward(self, x: torch.Tensor):
        return self.base_layer.forward(x)

    def set_lora_info(self, *args):
        pass


class VocabParallelEmbeddingWithLoRA(BaseLayerWithLoRA):
    def __init__(
        self, base_layer: VocabParallelEmbedding, segment_gemm, lora_rank, scaling
    ) -> None:
        super().__init__(base_layer, segment_gemm, lora_rank, scaling)
        self.weight = base_layer.weight


class ColumnParallelLinearWithLoRA(BaseLayerWithLoRA):
    def __init__(
        self, base_layer: ColumnParallelLinear, segment_gemm, lora_rank, scaling
    ) -> None:
        super().__init__(base_layer, segment_gemm, lora_rank, scaling)

    def apply_lora(self, output: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
        # TODO
        return output

    def forward(self, input_: torch.Tensor):
        # duplicate the logic in ColumnParallelLinear
        bias = self.base_layer.bias if not self.base_layer.skip_bias_add else None
        output_parallel = self.base_layer.quant_method.apply(
            self.base_layer, input_, bias
        )

        if self.set_lora:
            output_parallel = self.apply_lora(output_parallel, input_)

        if self.base_layer.gather_output:
            output = tensor_model_parallel_all_gather(output_parallel)
        else:
            output = output_parallel
        output_bias = self.base_layer.bias if self.base_layer.skip_bias_add else None
        return output, output_bias


class MergedColumnParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
    def __init__(
        self, base_layer: MergedColumnParallelLinear, segment_gemm, lora_rank, scaling
    ) -> None:
        super().__init__(base_layer, segment_gemm, lora_rank, scaling)

    def set_lora_info(self, A_buffer, B_buffer, bs, seq_lens, weight_indices):
        self.set_lora = True
        self.A_buffer = A_buffer
        self.B_buffer = B_buffer
        self.bs = bs
        self.seq_lens = seq_lens
        self.weight_indices = weight_indices

    def apply_lora(self, base_output: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
        lora_a_output = self.segment_gemm.run(
            x=x,
            weights=self.A_buffer,
            batch_size=self.bs,
            weight_column_major=True,
            seg_lens=self.seq_lens,
            weight_indices=self.weight_indices,
        )
        # FIXME
        assert lora_a_output.shape[-1] == self.lora_rank * 2
        lora_output = torch.empty_like(base_output)
        output_dim = lora_output.shape[-1] // 2
        for i in range(2):
            left = output_dim * i
            right = left + output_dim
            lora_output[:, left:right] = self.segment_gemm.run(
                x=lora_a_output[
                    :, self.lora_rank * i : self.lora_rank * (i + 1)
                ].contiguous(),
                weights=self.B_buffer[:, left:right, :].contiguous(),
                batch_size=self.bs,
                weight_column_major=True,
                seg_lens=self.seq_lens,
                weight_indices=self.weight_indices,
            )
        return base_output + lora_output * self.scaling


class QKVParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
    def __init__(
        self, base_layer: QKVParallelLinear, segment_gemm, lora_rank, scaling
    ) -> None:
        super().__init__(base_layer, segment_gemm, lora_rank, scaling)

    def set_lora_info(
        self, A_buffer_qkv, B_buffer_q, B_buffer_kv, bs, seq_lens, weight_indices
    ):
        self.set_lora = True
        self.A_buffer_qkv = A_buffer_qkv
        self.B_buffer_q = B_buffer_q
        self.B_buffer_kv = B_buffer_kv
        self.bs = bs
        self.seq_lens = seq_lens
        self.weight_indices = weight_indices

    def apply_lora(self, base_output: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
        lora_a_output = self.segment_gemm.run(
            x=x,
            weights=self.A_buffer_qkv,
            batch_size=self.bs,
            weight_column_major=True,
            seg_lens=self.seq_lens,
            weight_indices=self.weight_indices,
        )
        # FIXME parallelize qkv
        lora_output = torch.empty_like(base_output)
        # q
        output_dim_q = self.B_buffer_q.shape[-2]
        lora_output[:, :output_dim_q] = self.segment_gemm.run(
            x=lora_a_output[:, : self.lora_rank].contiguous(),
            weights=self.B_buffer_q,
            batch_size=self.bs,
            weight_column_major=True,
            seg_lens=self.seq_lens,
            weight_indices=self.weight_indices,
        )
        # kv
        output_dim_kv = self.B_buffer_kv.shape[-2] // 2
        for i in range(2):
            left = output_dim_kv * i
            right = left + output_dim_kv
            lora_output[:, output_dim_q + left : output_dim_q + right] = (
                self.segment_gemm.run(
                    x=lora_a_output[
                        :, self.lora_rank * (i + 1) : self.lora_rank * (i + 2)
                    ].contiguous(),
                    weights=self.B_buffer_kv[:, left:right, :].contiguous(),
                    batch_size=self.bs,
                    weight_column_major=True,
                    seg_lens=self.seq_lens,
                    weight_indices=self.weight_indices,
                )
            )
        return base_output + lora_output * self.scaling


class RowParallelLinearWithLoRA(BaseLayerWithLoRA):
    def __init__(
        self, base_layer: RowParallelLinear, segment_gemm, lora_rank, scaling
    ) -> None:
        super().__init__(base_layer, segment_gemm, lora_rank, scaling)

    def set_lora_info(self, A_buffer, B_buffer, bs, seq_lens, weight_indices):
        self.set_lora = True
        self.A_buffer = A_buffer
        self.B_buffer = B_buffer
        self.bs = bs
        self.seq_lens = seq_lens
        self.weight_indices = weight_indices

    def apply_lora(self, base_output: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
        lora_output = self.segment_gemm.run(
            x=x,
            weights=self.A_buffer,
            batch_size=self.bs,
            weight_column_major=True,
            seg_lens=self.seq_lens,
            weight_indices=self.weight_indices,
        )
        lora_output = self.segment_gemm.run(
            x=lora_output,
            weights=self.B_buffer,
            batch_size=self.bs,
            weight_column_major=True,
            seg_lens=self.seq_lens,
            weight_indices=self.weight_indices,
        )
        return base_output + lora_output * self.scaling

    def forward(self, input_):
        # duplicate the logic in RowParallelLinear
        if self.base_layer.input_is_parallel:
            input_parallel = input_
        else:
            tp_rank = get_tensor_model_parallel_rank()
            splitted_input = split_tensor_along_last_dim(
                input_, num_partitions=self.base_layer.tp_size
            )
            input_parallel = splitted_input[tp_rank].contiguous()
        output_parallel = self.base_layer.quant_method.apply(
            self.base_layer, input_parallel
        )

        if self.set_lora:
            output_parallel = self.apply_lora(output_parallel, input_parallel)

        if self.base_layer.reduce_results and self.base_layer.tp_size > 1:
            output_ = tensor_model_parallel_all_reduce(output_parallel)
        else:
            output_ = output_parallel

        if not self.base_layer.skip_bias_add:
            output = (
                output_ + self.base_layer.bias
                if self.base_layer.bias is not None
                else output_
            )
            output_bias = None
        else:
            output = output_
            output_bias = self.base_layer.bias
        return output, output_bias


def get_lora_layer(
    layer: nn.Module, segment_gemm, lora_rank, scaling
) -> BaseLayerWithLoRA:
    supported_layer_types = {
        # the order matters
        VocabParallelEmbedding: VocabParallelEmbeddingWithLoRA,
        QKVParallelLinear: QKVParallelLinearWithLoRA,
        MergedColumnParallelLinear: MergedColumnParallelLinearWithLoRA,
        ColumnParallelLinear: ColumnParallelLinearWithLoRA,
        RowParallelLinear: RowParallelLinearWithLoRA,
    }
    for src_layer_type, lora_layer_type in supported_layer_types.items():
        if isinstance(layer, src_layer_type):  # pylint: disable=unidiomatic-typecheck
            ret = lora_layer_type(layer, segment_gemm, lora_rank, scaling)
            return ret
    raise Exception(f"No corresponding LoRA layer supported for {type(layer)}.")


def get_mapped_params(module_names):
    # NOTE: params_mapping is not defined in this file; this helper appears
    # unused here and relies on a mapping provided elsewhere.
    ret = set()
    for module_name in module_names:
        ret.add(params_mapping(module_name))
    return list(ret)


class LoRALayer(nn.Module):
    def __init__(self, config, base_hf_config):
        super().__init__()
        self.config = config
        self.base_hf_config = base_hf_config
        self.weights = {}
        self.weight_gpu = {}

    def load_to_gpu(self):
        for name, weight in self.weights.items():
            self.weight_gpu[name] = weight.to(torch.float16).to("cuda")

    def offload_from_gpu(self):
        for name, weight in self.weights.items():
            self.weight_gpu[name] = None


class LoRAAdapter(nn.Module):
    def __init__(self, uid, config, base_hf_config, load_config):
        super().__init__()
        self.uid = uid
        self.config = config
        assert self.config.hf_config["peft_type"].lower() == "lora"
        self.base_hf_config = base_hf_config
        self.load_config = load_config
        self.scaling = self.config.lora_alpha / self.config.r

        self.layers = nn.ModuleList(
            [
                LoRALayer(config, base_hf_config)
                for i in range(base_hf_config.num_hidden_layers)
            ]
        )

        self.weights = {}
        self.weights_gpu = {}

    def get_stacked_multiply(self, module_name):
        stacked_rank = {
            "qkv_proj": 3,
            "kv_proj": 2,
            "gate_up_proj": 2,
        }
        return stacked_rank[module_name] if module_name in stacked_rank else 1

    def load_to_gpu(self):
        for name, weight in self.weights.items():
            self.weights_gpu[name] = weight.to(torch.float16).to("cuda")
        for layer in self.layers:
            layer.load_to_gpu()

    def offload_from_gpu(self):
        for name, weight in self.weights.items():
            self.weights_gpu[name] = None
        for layer in self.layers:
            layer.offload_from_gpu()

    # initialize the LoRA weights to cpu
    def initialize_weights(self):
        model_path = self.config.path
        loader = DefaultModelLoader(self.load_config)
        revision = getattr(self.config.hf_config, "revision", None)
        for name, loaded_weight in loader._get_weights_iterator(
            model_path, revision=revision, fall_back_to_pt=True
        ):
            match = re.search(r"layers\.(\d+)\.", name)
            if match is not None:
                layer_id = int(match.group(1))
                self.layers[layer_id].weights[name] = loaded_weight.cpu()
            else:
                self.weights[name] = loaded_weight.cpu()

        # stack kv_proj and gate_up_proj
        for i in range(self.base_hf_config.num_hidden_layers):
            layer = self.layers[i]
            weight_names = [name for name, _ in layer.weights.items()]
            for weight_name in weight_names:
                if "k_proj" in weight_name:
                    q_name = weight_name.replace("k_proj", "q_proj")
                    v_name = weight_name.replace("k_proj", "v_proj")
                    kv_name = weight_name.replace("k_proj", "kv_proj")
                    qkv_name = weight_name.replace("k_proj", "qkv_proj")
                    if "lora_A" in weight_name:
                        layer.weights[qkv_name] = torch.cat(
                            (
                                layer.weights[q_name],
                                layer.weights[weight_name],
                                layer.weights[v_name],
                            ),
                            0,
                        )
                        layer.weights.pop(q_name)
                        layer.weights.pop(weight_name)
                        layer.weights.pop(v_name)
                    else:
                        layer.weights[kv_name] = torch.cat(
                            (
                                layer.weights[weight_name],
                                layer.weights[v_name],
                            ),
                            0,
                        )
                        layer.weights.pop(weight_name)
                        layer.weights.pop(v_name)
                elif "gate_proj" in weight_name:
                    up_name = weight_name.replace("gate_proj", "up_proj")
                    gate_up_name = weight_name.replace("gate_proj", "gate_up_proj")
                    layer.weights[gate_up_name] = torch.cat(
                        (layer.weights[weight_name], layer.weights[up_name]), 0
                    )
                    layer.weights.pop(weight_name)
                    layer.weights.pop(up_name)
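The apply_lora methods above all build on one primitive: a segmented GEMM in which each request's token segment is multiplied by the adapter weights selected by weight_indices. A minimal PyTorch sketch of that semantics, assuming weight_column_major=True means each weight slice is stored as (out_dim, in_dim):

import torch

def segment_gemm_reference(x, weights, seg_lens, weight_indices):
    # x: (total_tokens, in_dim); weights: (num_loras, out_dim, in_dim).
    # The seg_lens[i] tokens of request i all use weights[weight_indices[i]].
    out = x.new_empty(x.shape[0], weights.shape[1])
    start = 0
    for i, n in enumerate(seg_lens.tolist()):
        out[start : start + n] = x[start : start + n] @ weights[weight_indices[i]].t()
        start += n
    return out

# With A: (num_loras, r, in_dim) and B: (num_loras, out_dim, r), the LoRA delta is
# segment_gemm_reference(segment_gemm_reference(x, A, ...), B, ...) * scaling,
# which is what RowParallelLinearWithLoRA.apply_lora computes above.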
43 python/sglang/srt/lora/lora_config.py Normal file
@@ -0,0 +1,43 @@
"""
Copyright 2023-2024 SGLang Team
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""

import json
import os

from huggingface_hub import snapshot_download


class LoRAConfig:
    def __init__(
        self,
        path: str,
    ) -> None:
        self.path = path
        self.hf_config = self.get_lora_config()
        self.target_modules = self.hf_config["target_modules"]
        self.r = self.hf_config["r"]
        self.lora_alpha = self.hf_config["lora_alpha"]

    def get_lora_config(self, dummy=False):
        if dummy:
            raise NotImplementedError()
        else:
            if not os.path.isdir(self.path):
                weights_dir = snapshot_download(self.path, allow_patterns=["*.json"])
            else:
                weights_dir = self.path
            config_name = "adapter_config.json"
            with open(os.path.join(weights_dir, config_name), "r") as f:
                return json.load(f)
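A short usage sketch: LoRAConfig only reads adapter_config.json (from a local directory or a hub snapshot), and LoRAAdapter derives its scaling factor from the two numbers it exposes.

config = LoRAConfig("winddude/wizardLM-LlaMA-LoRA-7B")  # HF repo id or local dir
print(config.target_modules)            # e.g. ["q_proj", "k_proj", "v_proj", ...]
scaling = config.lora_alpha / config.r  # matches LoRAAdapter.scaling above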
256 python/sglang/srt/lora/lora_manager.py Normal file
@@ -0,0 +1,256 @@
"""
Copyright 2023-2024 SGLang Team
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""

# Integrates "S-LoRA: Serving Thousands of Concurrent LoRA Adapters"
# and "Punica: Multi-Tenant LoRA Serving"


import re
from dataclasses import dataclass

import torch
from flashinfer import SegmentGEMMWrapper

from sglang.srt.lora.lora import LoRAAdapter, get_lora_layer
from sglang.srt.lora.lora_config import LoRAConfig
from sglang.srt.model_executor.forward_batch_info import ForwardMode
from sglang.srt.utils import replace_submodule


def get_stacked_name(name):
    # origin name -> (name for A, name for B)
    params_mapping = {
        "q_proj": ("qkv_proj", "q_proj"),
        "k_proj": ("qkv_proj", "kv_proj"),
        "v_proj": ("qkv_proj", "kv_proj"),
        "gate_proj": ("gate_up_proj", "gate_up_proj"),
        "up_proj": ("gate_up_proj", "gate_up_proj"),
    }
    return params_mapping.get(name, (name, name))


def get_layer_id(name):
    match = re.search(r"layers\.(\d+)\.", name)
    if match is None:
        return None
    return int(match.group(1))


class LoRAManager:
    def __init__(
        self,
        base_model,
        lora_paths,
        base_hf_config,
        max_loras_per_batch,
        load_config,
        dtype,
    ):
        self.base_model = base_model
        self.lora_paths = lora_paths
        self.base_hf_config = base_hf_config
        self.max_loras_per_batch = max_loras_per_batch
        self.load_config = load_config
        self.dtype = dtype

        workspace_buffer = torch.empty(1 * 1024 * 1024, dtype=torch.int8, device="cuda")
        self.segment_gemm = SegmentGEMMWrapper(workspace_buffer)

        self.init_loras()
        self.init_lora_memory_pool()
        self.init_lora_batch()

    def match_target_modules(self, module_name):
        for target_module in self.target_modules:
            if module_name.split(".")[-1] == target_module:
                return True
        return False

    def get_target_modules(self):
        modules = []
        for module_name, module in self.base_model.named_modules():
            if self.match_target_modules(module_name):
                modules.append((module_name, module))
        return modules

    def set_lora_module(self, module_name, module):
        lora_module = get_lora_layer(
            module, self.segment_gemm, self.max_lora_dim, self.scaling
        )
        replace_submodule(self.base_model, module_name, lora_module)
        return lora_module

    def init_loras(self):
        # get configs and target modules
        self.configs = {}
        self.origin_target_modules = set()
        for path in self.lora_paths:
            self.configs[path] = LoRAConfig(path)
            self.origin_target_modules = set(self.origin_target_modules) | set(
                self.configs[path].target_modules
            )
        self.target_modules = set(
            [
                self.base_model.get_module_name(module)
                for module in self.origin_target_modules
            ]
        )
        self.target_weights = set(
            [get_stacked_name(module) for module in self.origin_target_modules]
        )

        # load all weights to cpu
        self.loras = []
        self.lora_id = {}
        for path in self.lora_paths:
            self.lora_id[path] = len(self.loras)
            self.loras.append(
                LoRAAdapter(
                    path, self.configs[path], self.base_hf_config, self.load_config
                )
            )
            self.loras[-1].initialize_weights()

        # misc lora configs
        self.max_lora_dim = max([x.hf_config["r"] for x in self.configs.values()])
        self.scaling = self.loras[0].scaling
        # FIXME remove the restrictions
        assert all(x.hf_config["r"] == self.max_lora_dim for x in self.configs.values())
        assert all(x.scaling == self.scaling for x in self.loras)

        # monkey patch to use the LoRA version
        self.lora_modules = []
        for module_name, module in self.get_target_modules():
            self.lora_modules.append(
                (module_name, self.set_lora_module(module_name, module))
            )

    def init_lora_memory_pool(self):
        # preallocate lora memory pool
        self.A_buffer = {}
        self.B_buffer = {}
        num_layer = self.base_hf_config.num_hidden_layers
        for module_A, module_B in self.target_weights:
            # init A tensor, column_major=True
            hidden_dim_A, _ = self.base_model.get_hidden_dim(module_A)
            c = self.loras[-1].get_stacked_multiply(module_A)
            if module_A not in self.A_buffer:
                self.A_buffer[module_A] = [
                    torch.empty(
                        (
                            self.max_loras_per_batch,
                            self.max_lora_dim * c,
                            hidden_dim_A,
                        ),
                        dtype=self.dtype,
                        device="cuda",
                    )
                    for i in range(num_layer)
                ]
            # init B tensor, column_major=True
            _, hidden_dim_B = self.base_model.get_hidden_dim(module_B)
            c = self.loras[-1].get_stacked_multiply(module_B)
            if module_B not in self.B_buffer:
                self.B_buffer[module_B] = [
                    torch.empty(
                        (
                            self.max_loras_per_batch,
                            hidden_dim_B * c,
                            self.max_lora_dim,
                        ),
                        dtype=self.dtype,
                        device="cuda",
                    )
                    for i in range(num_layer)
                ]

    def init_lora_batch(self):
        self.active_uids = set()  # set of active loras
        self.buffer_id = {}  # lora uid -> idx in memory pool

    def get_weight_name(self, name, idx):
        for target_weight_name in self.target_weights:
            if target_weight_name[idx] in name:
                return target_weight_name[idx]

    def load_lora(self, uid, buffer_id):
        num_layer = self.base_hf_config.num_hidden_layers
        if uid is None:
            # Base-only request: zero out the slot so the LoRA delta vanishes.
            for i in range(num_layer):
                for k in self.A_buffer.keys():
                    self.A_buffer[k][i][buffer_id] *= 0
            return

        for i in range(num_layer):
            layer_weights = self.loras[self.lora_id[uid]].layers[i].weights
            for name, weights in layer_weights.items():
                if "lora_A" in name:
                    lora_weight_name = self.get_weight_name(name, 0)
                    if lora_weight_name:
                        self.A_buffer[lora_weight_name][i][buffer_id].copy_(weights)
                else:
                    lora_weight_name = self.get_weight_name(name, 1)
                    if lora_weight_name:
                        self.B_buffer[lora_weight_name][i][buffer_id].copy_(weights)

    def prepare_lora_batch(self, batch, extend_seq_lens=None):
        # load active loras into lora memory pool
        cur_uids = set([req.lora_path for req in batch.reqs])
        assert len(cur_uids) <= self.max_loras_per_batch
        i = 0
        evictable_uids = list(self.active_uids)
        for uid in cur_uids:
            if uid not in self.active_uids:
                while i < len(evictable_uids) and evictable_uids[i] in cur_uids:
                    i += 1
                if i < len(evictable_uids):
                    self.active_uids.remove(evictable_uids[i])
                    self.buffer_id.pop(evictable_uids[i])
                self.load_lora(uid, i)
                self.active_uids.add(uid)
                self.buffer_id[uid] = i
                i += 1

        if cur_uids == set([None]):
            return

        # setup lora in forward modules
        bs = len(batch.reqs)
        seg_lens = extend_seq_lens if batch.forward_mode.is_extend() else torch.ones(bs)
        weight_indices = torch.empty((bs,), dtype=torch.int64, device="cuda")
        for i, req in enumerate(batch.reqs):
            weight_indices[i] = self.buffer_id[req.lora_path]

        for module_name, module in self.lora_modules:
            layer_id = get_layer_id(module_name)

            if "qkv_proj" not in module_name:
                weight_name = self.get_weight_name(module_name, 0)
                module.set_lora_info(
                    self.A_buffer[weight_name][layer_id],
                    self.B_buffer[weight_name][layer_id],
                    bs,
                    seg_lens,
                    weight_indices,
                )
            else:
                module.set_lora_info(
                    self.A_buffer["qkv_proj"][layer_id],
                    self.B_buffer["q_proj"][layer_id],
                    self.B_buffer["kv_proj"][layer_id],
                    bs,
                    seg_lens,
                    weight_indices,
                )
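To make the pool shapes concrete, a hypothetical Llama-2-7B-style setup (hidden_size 4096, 32 layers, 32 attention heads, 32 kv heads) with rank-16 adapters:

max_loras_per_batch, r, hidden_size = 8, 16, 4096
# qkv_proj A side stacks q/k/v, so c = 3: one such tensor per layer.
a_qkv_shape = (max_loras_per_batch, r * 3, hidden_size)      # (8, 48, 4096)
# kv_proj B side stacks k and v (c = 2); without GQA each of k/v
# projects to 4096, so per layer:
b_kv_shape = (max_loras_per_batch, 4096 * 2, r)              # (8, 8192, 16)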
@@ -55,6 +55,9 @@ class GenerateReqInput:
 
     is_single: bool = True
 
+    # LoRA related
+    lora_path: Optional[Union[List[Optional[str]], Optional[str]]] = None
+
     def post_init(self):
         if (self.text is None and self.input_ids is None) or (
             self.text is not None and self.input_ids is not None
@@ -184,6 +187,9 @@ class TokenizedGenerateReqInput:
     # Modalities of the input images
     modalites: Optional[List[str]] = None
 
+    # LoRA related
+    lora_path: Optional[str] = None  # None means just use the base model
+
 
 @dataclass
 class EmbeddingReqInput:
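A sketch of the /generate payload this field accepts: lora_path is positional with respect to text, and None selects the base model (the adapter path below is hypothetical).

json_data = {
    "text": ["What is LoRA?", "What is LoRA?"],
    "sampling_params": {"max_new_tokens": 32, "temperature": 0},
    "lora_path": [None, "/path/to/adapter"],  # hypothetical adapter path
}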
@@ -98,7 +98,7 @@ class FINISH_ABORT(BaseFinishReason):
 class Req:
     """Store all information of a request."""
 
-    def __init__(self, rid, origin_input_text, origin_input_ids):
+    def __init__(self, rid, origin_input_text, origin_input_ids, lora_path=None):
         # Input and output info
         self.rid = rid
         self.origin_input_text = origin_input_text
@@ -106,6 +106,7 @@ class Req:
         self.origin_input_ids = origin_input_ids
         self.output_ids = []  # Each decode stage's output ids
         self.fill_ids = None  # fill_ids = origin_input_ids + output_ids
+        self.lora_path = lora_path
 
         # Memory info
         self.req_pool_idx = None
@@ -266,6 +266,11 @@ class TokenizerManager:
                 top_logprobs_num,
                 obj.stream,
                 modalities,
+                (
+                    obj.lora_path[index]
+                    if isinstance(obj.lora_path, list)
+                    else obj.lora_path
+                ),
             )
         else:  # is embedding
             tokenized_obj = TokenizedEmbeddingReqInput(
@@ -364,6 +369,11 @@ class TokenizerManager:
                 obj.top_logprobs_num[index],
                 obj.stream,
                 modalities,
+                (
+                    obj.lora_path[index]
+                    if isinstance(obj.lora_path, list)
+                    else obj.lora_path
+                ),
             )
         else:
             tokenized_obj = TokenizedEmbeddingReqInput(
@@ -87,6 +87,8 @@ class ModelTpServer:
         self.dp_size = server_args.dp_size
         self.schedule_policy = server_args.schedule_policy
         self.disable_regex_jump_forward = server_args.disable_regex_jump_forward
+        self.lora_paths = server_args.lora_paths
+        self.max_loras_per_batch = server_args.max_loras_per_batch
 
         # Init model and tokenizer
         self.model_config = ModelConfig(
@@ -323,7 +325,15 @@ class ModelTpServer:
         self,
         recv_req: TokenizedGenerateReqInput,
     ):
-        req = Req(recv_req.rid, recv_req.input_text, recv_req.input_ids)
+        if isinstance(recv_req, TokenizedGenerateReqInput):
+            req = Req(
+                recv_req.rid,
+                recv_req.input_text,
+                recv_req.input_ids,
+                lora_path=recv_req.lora_path,
+            )
+        else:
+            req = Req(recv_req.rid, recv_req.input_text, recv_req.input_ids)
         req.tokenizer = self.tokenizer
         req.sampling_params = recv_req.sampling_params
         req.pixel_values = recv_req.pixel_values
@@ -442,10 +452,27 @@ class ModelTpServer:
                 self.current_inflight_req
             )
 
+        if self.lora_paths is not None:
+            lora_set = (
+                set([req.lora_path for req in self.running_batch.reqs])
+                if self.running_batch is not None
+                else set([])
+            )
+
         for req in self.waiting_queue:
             if adder.no_remaining_tokens():
                 break
             req.init_next_round_input(None if prefix_computed else self.tree_cache)
+            if (
+                self.lora_paths is not None
+                and len(
+                    lora_set
+                    | set([req.lora_path for req in adder.can_run_list])
+                    | set([req.lora_path])
+                )
+                > self.max_loras_per_batch
+            ):
+                break
             res = adder.add_one_req(req)
             if (
                 not res
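In words: admission stops once the running batch plus the newly admitted requests would span more than max_loras_per_batch distinct lora_path values; base-only requests (lora_path=None) count as one of those values, since the None slot also occupies a buffer in the memory pool. A condensed restatement with the variable names assumed from the hunk above:

candidate = lora_set | {r.lora_path for r in adder.can_run_list} | {req.lora_path}
if len(candidate) > self.max_loras_per_batch:
    pass  # stop admitting; the pool cannot hold another adapter this round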
@@ -41,6 +41,7 @@ from sglang.srt.configs.model_config import AttentionArch, ModelConfig
 from sglang.srt.layers.attention_backend import FlashInferAttnBackend, TritonAttnBackend
 from sglang.srt.layers.logits_processor import LogitsProcessorOutput
 from sglang.srt.layers.sampler import SampleOutput
+from sglang.srt.lora.lora_manager import LoRAManager
 from sglang.srt.managers.schedule_batch import ScheduleBatch, global_server_args_dict
 from sglang.srt.mem_cache.memory_pool import (
     MHATokenToKVPool,
@@ -107,6 +108,8 @@ class ModelRunner:
         # Init components
         min_per_gpu_memory = self.init_torch_distributed()
         self.load_model()
+        if server_args.lora_paths is not None:
+            self.init_lora_manager()
         self.init_memory_pool(
             min_per_gpu_memory,
             server_args.max_running_requests,
@@ -312,6 +315,17 @@ class ModelRunner:
         logger.info("Update weights end.")
         return True, "Succeeded to update model weights"
 
+    def init_lora_manager(self):
+        self.lora_manager = LoRAManager(
+            base_model=self.model,
+            lora_paths=self.server_args.lora_paths,
+            base_hf_config=self.model_config.hf_config,
+            max_loras_per_batch=self.server_args.max_loras_per_batch,
+            load_config=self.load_config,
+            dtype=self.dtype,
+        )
+        logger.info("LoRA manager ready.")
+
     def profile_max_num_token(self, total_gpu_memory: int):
         available_gpu_memory = get_available_gpu_memory(
             self.gpu_id, distributed=self.tp_size > 1
@@ -450,6 +464,8 @@ class ModelRunner:
 
     @torch.inference_mode()
     def forward_decode(self, batch: ScheduleBatch):
+        if self.server_args.lora_paths is not None:
+            self.lora_manager.prepare_lora_batch(batch)
         if (
             self.cuda_graph_runner
             and self.cuda_graph_runner.can_run(len(batch.reqs))
@@ -466,6 +482,9 @@ class ModelRunner:
     @torch.inference_mode()
     def forward_extend(self, batch: ScheduleBatch):
         input_metadata = InputMetadata.from_schedule_batch(self, batch)
+        if self.server_args.lora_paths is not None:
+            self.lora_manager.prepare_lora_batch(batch, input_metadata.extend_seq_lens)
 
         if self.is_generation:
             return self.model.forward(
                 batch.input_ids, input_metadata.positions, input_metadata
@@ -324,6 +324,51 @@ class LlamaForCausalLM(nn.Module):
         sample_output = self.sampler(logits_output, input_metadata.sampling_info)
         return sample_output, logits_output
 
+    def get_hidden_dim(self, module_name):
+        if module_name in ["q_proj", "o_proj", "qkv_proj"]:
+            return self.config.hidden_size, self.config.hidden_size
+        elif module_name in ["kv_proj"]:
+            return self.config.hidden_size, self.config.hidden_size // (
+                self.config.num_attention_heads // self.config.num_key_value_heads
+            )
+        elif module_name == "gate_up_proj":
+            return self.config.hidden_size, self.config.intermediate_size
+        elif module_name == "down_proj":
+            return self.config.intermediate_size, self.config.hidden_size
+        else:
+            raise NotImplementedError()
+
+    def get_module_name(self, name):
+        params_mapping = {
+            "q_proj": "qkv_proj",
+            "k_proj": "qkv_proj",
+            "v_proj": "qkv_proj",
+            "gate_proj": "gate_up_proj",
+            "up_proj": "gate_up_proj",
+        }
+        return params_mapping.get(name, name)
+
+    def get_module_name_from_weight_name(self, name):
+        stacked_params_mapping = [
+            # (param_name, shard_name, shard_id, num_shard)
+            ("qkv_proj", "q_proj", "q", 3),
+            ("qkv_proj", "k_proj", "k", 3),
+            ("qkv_proj", "v_proj", "v", 3),
+            ("gate_up_proj", "gate_proj", 0, 2),
+            ("gate_up_proj", "up_proj", 1, 2),
+        ]
+        for param_name, weight_name, shard_id, num_shard in stacked_params_mapping:
+            if weight_name in name:
+                return (
+                    name.replace(weight_name, param_name)[: -len(".weight")],
+                    num_shard,
+                )
+        return name[: -len(".weight")], 1
+
+    def get_num_params(self):
+        params_dict = dict(self.named_parameters())
+        return len(params_dict)
+
     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         stacked_params_mapping = [
             # (param_name, shard_name, shard_id)
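A worked check of get_hidden_dim with hypothetical GQA numbers:

hidden_size, num_attention_heads, num_key_value_heads = 4096, 32, 8  # hypothetical config
kv_dim = hidden_size // (num_attention_heads // num_key_value_heads)
assert kv_dim == 1024  # k and v each project 4096 -> 1024, so get_hidden_dim("kv_proj") == (4096, 1024)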
@@ -611,6 +611,7 @@ class Runtime:
         return_logprob: Optional[Union[List[bool], bool]] = False,
         logprob_start_len: Optional[Union[List[int], int]] = None,
         top_logprobs_num: Optional[Union[List[int], int]] = None,
+        lora_path: Optional[List[Optional[str]]] = None,
     ):
         json_data = {
             "text": prompt,
@@ -618,7 +619,9 @@ class Runtime:
             "return_logprob": return_logprob,
             "logprob_start_len": logprob_start_len,
             "top_logprobs_num": top_logprobs_num,
+            "lora_path": lora_path,
         }
+        assert not isinstance(lora_path, list) or len(lora_path) == len(prompt)
         response = requests.post(
             self.url + "/generate",
             json=json_data,
|
||||
enable_mla: bool = False
|
||||
triton_attention_reduce_in_fp32: bool = False
|
||||
|
||||
# LoRA
|
||||
lora_paths: Optional[List[str]] = None
|
||||
max_loras_per_batch: int = 8
|
||||
|
||||
def __post_init__(self):
|
||||
# Set missing default values
|
||||
if self.tokenizer_path is None:
|
||||
@@ -522,6 +526,21 @@ class ServerArgs:
|
||||
help="Turn on memory efficient weight loading with quantization (quantize per layer during loading).",
|
||||
)
|
||||
|
||||
# LoRA options
|
||||
parser.add_argument(
|
||||
"--lora-paths",
|
||||
type=str,
|
||||
nargs="*",
|
||||
default=None,
|
||||
help="The list of LoRA adapters.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--max-loras-per-batch",
|
||||
type=int,
|
||||
default=8,
|
||||
help="Maximum number of adapters for a running batch, include base-only request",
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def from_cli_args(cls, args: argparse.Namespace):
|
||||
args.tp_size = args.tensor_parallel_size
|
||||
@@ -539,6 +558,12 @@ class ServerArgs:
|
||||
assert not (
|
||||
self.dp_size > 1 and self.node_rank is not None
|
||||
), "multi-node data parallel is not supported"
|
||||
assert (
|
||||
self.max_loras_per_batch > 0
|
||||
# FIXME
|
||||
and (self.lora_paths is None or self.disable_cuda_graph)
|
||||
and (self.lora_paths is None or self.disable_radix_cache)
|
||||
), "compatibility of lora and cuda graph and radix attention is in progress"
|
||||
|
||||
|
||||
def prepare_server_args(argv: List[str]) -> ServerArgs:
|
||||
|
||||
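Putting the new flags together, a sketch of launching a LoRA-enabled runtime from Python (the adapter paths are hypothetical; CUDA graph and radix cache must stay disabled per the assertion above):

from sglang.srt.server import Runtime

runtime = Runtime(
    model_path="meta-llama/Llama-2-7b-hf",
    lora_paths=["/path/to/adapter_a", "/path/to/adapter_b"],
    max_loras_per_batch=4,
    disable_cuda_graph=True,
    disable_radix_cache=True,
)

The equivalent server launch passes --lora-paths and --max-loras-per-batch on the command line.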
@@ -35,6 +35,7 @@ import torch
 import torch.distributed as dist
 from fastapi.responses import JSONResponse
 from packaging import version as pkg_version
+from torch import nn
 from torch.nn.parameter import Parameter
 from triton.runtime.cache import (
     FileCacheManager,
@@ -714,3 +715,14 @@ def configure_logger(server_args, prefix: str = ""):
         datefmt="%H:%M:%S",
         force=True,
     )
+
+
+# source: https://github.com/vllm-project/vllm/blob/93b38bea5dd03e1b140ca997dfaadef86f8f1855/vllm/lora/utils.py#L9
+def replace_submodule(
+    model: nn.Module, module_name: str, new_module: nn.Module
+) -> nn.Module:
+    """Replace a submodule in a model with a new module."""
+    parent = model.get_submodule(".".join(module_name.split(".")[:-1]))
+    target_name = module_name.split(".")[-1]
+    setattr(parent, target_name, new_module)
+    return new_module
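A toy demonstration of replace_submodule, mirroring how LoRAManager.set_lora_module swaps each matched layer for its *WithLoRA counterpart (module names here are hypothetical):

import torch.nn as nn

class Block(nn.Module):
    def __init__(self):
        super().__init__()
        self.proj = nn.Linear(4, 4)

class Toy(nn.Module):
    def __init__(self):
        super().__init__()
        self.block = Block()

m = Toy()
new = replace_submodule(m, "block.proj", nn.Linear(4, 4))
assert m.block.proj is new  # the parent now holds the replacement module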
@@ -21,6 +21,7 @@ from typing import List, Union
 
 import torch
 import torch.nn.functional as F
+from peft import PeftModel
 from transformers import AutoModelForCausalLM, AutoTokenizer
 
 from sglang.srt.server import Runtime
@@ -52,6 +53,7 @@ def get_dtype_str(torch_dtype):
 
 def get_top_logprobs(logits, k):
     logprobs = F.log_softmax(logits, dim=-1, dtype=torch.float32)
+    del logits
     logprobs, top_indices = torch.topk(logprobs, k=k, dim=-1)
     return logprobs
 
@@ -71,8 +73,10 @@ class HFRunner:
         model_path,
         torch_dtype,
         is_generation,
+        output_str_only=False,
     ):
         self.is_generation = is_generation
+        self.output_str_only = output_str_only
 
         self.in_queue = mp.Queue()
         self.out_queue = mp.Queue()
@@ -95,7 +99,7 @@ class HFRunner:
         )
 
         if self.is_generation:
-            self.model = AutoModelForCausalLM.from_pretrained(
+            self.base_model = AutoModelForCausalLM.from_pretrained(
                 model_path,
                 torch_dtype=torch_dtype,
                 trust_remote_code=False,
@@ -110,13 +114,16 @@ class HFRunner:
         )
 
         while True:
-            prompts, max_new_tokens = in_queue.get()
+            prompts, max_new_tokens, lora_paths = in_queue.get()
+            if lora_paths is not None:
+                assert len(prompts) == len(lora_paths)
 
            if prompts is not None:
                 if self.is_generation:
                     output_strs = []
                     top_input_logprobs = []
                     top_output_logprobs = []
-                    for p in prompts:
+                    for i, p in enumerate(prompts):
                         if isinstance(p, str):
                             input_ids = self.tokenizer.encode(
                                 p, return_tensors="pt"
@@ -124,6 +131,16 @@ class HFRunner:
                         else:
                             input_ids = torch.tensor([p], device="cuda")
 
+                        if lora_paths is not None and lora_paths[i] is not None:
+                            self.model = PeftModel.from_pretrained(
+                                self.base_model,
+                                lora_paths[i],
+                                torch_dtype=torch_dtype,
+                                is_trainable=False,
+                            )
+                        else:
+                            self.model = self.base_model
+
                         outputs = self.model.generate(
                             input_ids,
                             do_sample=False,
@@ -131,25 +148,30 @@ class HFRunner:
                             top_p=None,
                             max_new_tokens=max_new_tokens,
                             return_dict_in_generate=True,
-                            output_scores=True,
+                            output_scores=(not self.output_str_only),
                         )
                         output_strs.append(
                             self.tokenizer.decode(outputs[0][0][len(input_ids[0]) :])
                         )
-                        # outputs.scores: (num_token, 1, vocab_size)
-                        top_output_logprobs.append(
-                            [
-                                get_top_logprobs(logits[0], NUM_TOP_LOGPROBS).tolist()
-                                for logits in outputs.scores
-                            ]
-                        )
-                        del outputs
+                        if not self.output_str_only:
+                            # outputs.scores: (num_token, 1, vocab_size)
+                            top_output_logprobs.append(
+                                [
+                                    get_top_logprobs(
+                                        logits[0], NUM_TOP_LOGPROBS
+                                    ).tolist()
+                                    for logits in outputs.scores
+                                ]
+                            )
+                            del outputs
 
-                        input_logits = self.model.forward(input_ids).logits[0]
-                        top_input_logprobs.append(
-                            get_top_logprobs(input_logits, NUM_TOP_LOGPROBS).tolist()
-                        )
-                        del input_logits
+                            input_logits = self.model.forward(input_ids).logits[0]
+                            top_input_logprobs.append(
+                                get_top_logprobs(
+                                    input_logits, NUM_TOP_LOGPROBS
+                                ).tolist()
+                            )
+                            del input_logits
 
                     out_queue.put(
                         ModelOutput(
@@ -160,6 +182,7 @@ class HFRunner:
                         )
 
                 else:
+                    assert not self.output_str_only
                     logits = self.model.encode(prompts).tolist()
                     out_queue.put(ModelOutput(embed_logits=logits))
 
@@ -167,8 +190,9 @@ class HFRunner:
         self,
         prompts: Union[List[str], List[torch.Tensor]] = DEFAULT_PROMPTS,
         max_new_tokens=8,
+        lora_paths=None,
     ):
-        self.in_queue.put((prompts, max_new_tokens))
+        self.in_queue.put((prompts, max_new_tokens, lora_paths))
         return self.out_queue.get()
 
     def terminate(self):
@@ -191,6 +215,10 @@ class SRTRunner:
         is_generation,
         tp_size=1,
         port=DEFAULT_PORT_FOR_SRT_TEST_RUNNER,
+        lora_paths=None,
+        max_loras_per_batch=4,
+        disable_cuda_graph=False,
+        disable_radix_cache=False,
     ):
         self.is_generation = is_generation
         self.runtime = Runtime(
@@ -201,12 +229,17 @@ class SRTRunner:
             mem_fraction_static=0.69,
             trust_remote_code=False,
             is_embedding=not self.is_generation,
+            lora_paths=lora_paths,
+            max_loras_per_batch=max_loras_per_batch,
+            disable_cuda_graph=disable_cuda_graph,
+            disable_radix_cache=disable_radix_cache,
         )
 
     def forward(
         self,
         prompts: Union[List[str], List[torch.Tensor]] = DEFAULT_PROMPTS,
         max_new_tokens=8,
+        lora_paths=None,
     ):
         if self.is_generation:
             # the return value contains logprobs from prefill
@@ -214,9 +247,10 @@ class SRTRunner:
             top_input_logprobs = []
             top_output_logprobs = []
             sampling_params = {"max_new_tokens": max_new_tokens, "temperature": 0}
-            for prompt in prompts:
+            for i, prompt in enumerate(prompts):
                 response = self.runtime.generate(
                     prompt,
+                    lora_path=lora_paths[i] if lora_paths else None,
                     sampling_params=sampling_params,
                     return_logprob=True,
                     logprob_start_len=0,
@@ -256,6 +290,37 @@ class SRTRunner:
             logits = [x["embedding"] for x in response]
             return ModelOutput(embed_logits=logits)
 
+    def batch_forward(
+        self,
+        prompts: Union[List[str], List[torch.Tensor]] = DEFAULT_PROMPTS,
+        max_new_tokens=8,
+        lora_paths=None,
+    ):
+        """
+        Test serving by sending all prompts at once;
+        only returns output strings, no logprobs.
+        """
+        if self.is_generation:
+            output_strs = []
+            sampling_params = {"max_new_tokens": max_new_tokens, "temperature": 0}
+            response = self.runtime.generate(
+                prompts,
+                lora_path=lora_paths if lora_paths else None,
+                sampling_params=sampling_params,
+            )
+            response = json.loads(response)
+            output_strs = [r["text"] for r in response]
+
+            return ModelOutput(
+                output_strs=output_strs,
+            )
+        else:
+            response = self.runtime.encode(prompts)
+            response = json.loads(response)
+            logits = [x["embedding"] for x in response]
+            return ModelOutput(embed_logits=logits)
+
     def __enter__(self):
         return self
62 scripts/playground/lora/lora_hf_play.py Normal file
@@ -0,0 +1,62 @@
import torch
from peft import PeftModel
from transformers import LlamaForCausalLM, LlamaTokenizer

MODEL = "mistralai/Mistral-7B-Instruct-v0.3"
# ADAPTER = "winddude/wizardLM-LlaMA-LoRA-7B"
ADAPTER = "/home/ying/test_lora"
HF_TOKEN = "..."


prompt = """
### Instruction:
Write a poem about the transformers Python library.
Mention the word "large language models" in that poem.
### Response:
The Transformers are large language models,
They're used to make predictions on text.
"""


tokenizer = LlamaTokenizer.from_pretrained(MODEL)

base_model = LlamaForCausalLM.from_pretrained(
    MODEL,
    device_map="auto",
    # load_in_8bit=True,
    torch_dtype=torch.float16,
    # use_auth_token=HF_TOKEN,
).cuda()


# base model generate
with torch.no_grad():
    output_tensors = base_model.generate(
        input_ids=tokenizer(prompt, return_tensors="pt").input_ids.cuda(),
        max_new_tokens=32,
        do_sample=False,
    )[0]

output = tokenizer.decode(output_tensors, skip_special_tokens=True)
print("======= base output ========")
print(output)


# peft model generate
model = PeftModel.from_pretrained(
    base_model,
    ADAPTER,
    torch_dtype=torch.float16,
    is_trainable=False,
)

with torch.no_grad():
    output_tensors = model.generate(
        input_ids=tokenizer(prompt, return_tensors="pt").input_ids.cuda(),
        max_new_tokens=32,
        do_sample=False,
    )[0]

output = tokenizer.decode(output_tensors, skip_special_tokens=True)
print("======= peft output ========")
print(output)
30 scripts/playground/lora/lora_vllm_play.py Normal file
@@ -0,0 +1,30 @@
from vllm import LLM, SamplingParams
from vllm.lora.request import LoRARequest

MODEL = "mistralai/Mistral-7B-Instruct-v0.3"
ADAPTER = "/home/ying/test_lora"
prompt = """
### Instruction:
Write a poem about the transformers Python library.
Mention the word "large language models" in that poem.
### Response:
The Transformers are large language models,
They're used to make predictions on text.
"""


llm = LLM(model=MODEL, enable_lora=True)

sampling_params = SamplingParams(
    temperature=0,
    max_tokens=32,
)

prompts = [prompt]

outputs = llm.generate(
    prompts, sampling_params, lora_request=LoRARequest("test_lora", 1, ADAPTER)
)

print(outputs[0].prompt)
print(outputs[0].outputs[0].text)
55 scripts/playground/lora/test_lora.py Normal file
@@ -0,0 +1,55 @@
import json

import openai
import requests

import sglang as sgl

lora_path = "/home/ying/test_lora"
prompt_file = "/home/ying/test_prompt/dialogue_choice_prompts.json"
server_url = "http://127.0.0.1:30000"

client = openai.Client(base_url=server_url + "/v1", api_key="EMPTY")


# @sgl.function
# def generate(s, prompt):
#     s += prompt
#     s += sgl.gen("ans")

# sgl.set_default_backend(sgl.RuntimeEndpoint(server_url))


def generate(prompt, lora_path):
    json_data = {
        "text": prompt,
        "sampling_params": {},
        "return_logprob": False,
        "logprob_start_len": None,
        "top_logprobs_num": None,
        "lora_path": lora_path,
    }
    response = requests.post(
        server_url + "/generate",
        json=json_data,
    )
    return json.dumps(response.json())


with open(prompt_file, "r") as f:
    samples = json.load(f)


for sample in samples[:1]:
    assert sample[0]["role"] == "user"
    prompt = sample[0]["content"]
    assert sample[1]["role"] == "assistant"
    ref = sample[1]["content"]

    state = generate(prompt, lora_path)
    print("================================")
    print(ref)
    print("--------------------------------")
    # print(state["ans"])
    print(state)
    print()
52 test/srt/models/compare.py Normal file
@@ -0,0 +1,52 @@
"""
Used for debugging via tensor comparison.
Dump {name: tensor} into "log_hf.jsonl" and "log_srt.jsonl";
use the same name for two tensors that are supposed to be close.
Recommended names look like: "layer 2 after mlp".
"""

import json
import sys

import torch

if len(sys.argv) > 1:
    assert sys.argv[1] == "base"
    hf_log = "base_log_hf.jsonl"
    srt_log = "base_log_srt.jsonl"
else:
    hf_log = "log_hf.jsonl"
    srt_log = "log_srt.jsonl"


def load_data(filepath):
    tensors = {}
    with open(filepath, "r") as f:
        lines = f.readlines()
        for line in lines:
            data = json.loads(line)
            for k, v in data.items():
                tensors[k] = torch.tensor(v)
    return tensors


hf_tensors = load_data(hf_log)
srt_tensors = load_data(srt_log)


def get_diff(t1, t2):
    t1 = t1.reshape(t2.shape)
    max_diff = torch.max(abs(t1 - t2))
    l2_dis = torch.dist(t1, t2, p=2)
    return l2_dis, max_diff


for k, _ in srt_tensors.items():
    l2_dis, max_diff = get_diff(hf_tensors[k], srt_tensors[k])
    print(f"{k} {l2_dis=} {max_diff=}")
    if k == "layer 1 attn":
        print(hf_tensors[k])
        print(srt_tensors[k])
    if k == "layer 0 prefill k":
        print(srt_tensors[k].shape)
        print(hf_tensors[k].shape)
@@ -76,6 +76,7 @@ class TestGenerationModels(unittest.TestCase):
     ) -> None:
         if model_path == "Alibaba-NLP/gte-Qwen2-1.5B-instruct":
             prompts = prompts[:-1]
+
         with HFRunner(
             model_path, torch_dtype=torch_dtype, is_generation=True
         ) as hf_runner:
297 test/srt/models/test_lora.py Normal file
@@ -0,0 +1,297 @@
"""
Copyright 2023-2024 SGLang Team
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""

import multiprocessing as mp
import unittest
import uuid

import torch

from sglang.test.runners import HFRunner, SRTRunner

LORA_SETS = [
    # {
    #     "base": "meta-llama/Llama-2-7b-hf",
    #     "loras": ["RuterNorway/Llama-2-7b-chat-norwegian-LoRa"],
    # },
    {"base": "meta-llama/Llama-2-7b-hf", "loras": ["winddude/wizardLM-LlaMA-LoRA-7B"]},
    # {"base": "mistralai/Mistral-7B-Instruct-v0.3", "loras": ["/home/ying/test_lora"]},
    # {
    #     "base": "mistralai/Mistral-7B-Instruct-v0.3",
    #     "loras": [
    #         "/home/ying/test_lora",
    #         "/home/ying/test_lora_1",
    #         "/home/ying/test_lora_2",
    #         "/home/ying/test_lora_3",
    #         "/home/ying/test_lora_4",
    #     ],
    # },
    # {"base": "meta-llama/Llama-2-7b-hf", "loras": ["yard1/llama-2-7b-sql-lora-test"]},
]
TORCH_DTYPES = [torch.float16]

PROMPTS = [
    """
### Instruction:
Write a poem about the transformers Python library.
Mention the word "large language models" in that poem.
### Response:
The Transformers are large language models,
They're used to make predictions on text.
""",
    """
### Instruction:
Tell me about llamas and alpacas
### Response:
Llamas are large, long-necked animals with a woolly coat. They have two toes on each foot instead of three like other camelids (camels, dromedaries). Llamas live in the Andean mountains of South America where they graze on grasses and shrubs. Alpaca is another name for domesticated llama. The word "alpaca" comes from an Incan language meaning "golden fleece." Alpacas look very similar to llamas but are smaller than their wild relatives. Both species were used by ancient people as pack animals and for meat. Today both llamas and alpacas are raised primarily for their fiber which can be spun into yarn or knitted into clothing.
### Question 2:
What do you know about llamas?
### Answer:
""",
]

# import json
#
# with open("/home/ying/test_prompt/dialogue_choice_prompts.json", "r") as f:
#     samples = json.load(f)
# for sample in samples[:5]:
#     assert sample[0]["role"] == "user"
#     PROMPTS.append(sample[0]["content"][:2000])


class TestLoRA(unittest.TestCase):

    def inference(self, prompts, lora_set, tp_size, torch_dtype, max_new_tokens):
        print("=================== testing inference =======================")
        base_path = lora_set["base"]
        all_lora_paths = lora_set["loras"]
        batch_lora_paths = [None]
        i = 0
        for _ in range(len(prompts) - 1):
            batch_lora_paths.append(all_lora_paths[i])
            i = (i + 1) % len(all_lora_paths)

        with SRTRunner(
            base_path,
            tp_size=tp_size,
            torch_dtype=torch_dtype,
            is_generation=True,
            lora_paths=all_lora_paths,
            max_loras_per_batch=3,
            disable_cuda_graph=True,
            disable_radix_cache=True,
        ) as srt_runner:
            srt_outputs = srt_runner.forward(
                prompts, max_new_tokens=max_new_tokens, lora_paths=batch_lora_paths
            )

        with HFRunner(
            base_path,
            torch_dtype=torch_dtype,
            is_generation=True,
        ) as hf_runner:
            hf_outputs = hf_runner.forward(
                prompts, max_new_tokens=max_new_tokens, lora_paths=batch_lora_paths
            )

        with HFRunner(
            base_path,
            torch_dtype=torch_dtype,
            is_generation=True,
        ) as hf_runner:
            hf_no_lora_outputs = hf_runner.forward(
                prompts, max_new_tokens=max_new_tokens
            )

        with SRTRunner(
            base_path,
            tp_size=tp_size,
            torch_dtype=torch_dtype,
            is_generation=True,
        ) as srt_runner:
            srt_no_lora_outputs = srt_runner.forward(
                prompts, max_new_tokens=max_new_tokens
            )

        for i in range(len(prompts)):
            # compare input logprobs
            hf_logprobs = torch.Tensor(hf_outputs.top_input_logprobs[i])
            srt_logprobs = torch.Tensor(srt_outputs.top_input_logprobs[i])
            hf_no_lora_logprobs = torch.Tensor(hf_no_lora_outputs.top_input_logprobs[i])
            srt_no_lora_logprobs = torch.Tensor(
                srt_no_lora_outputs.top_input_logprobs[i]
            )
            print(
                "max input diff between hf_lora and srt_lora",
                torch.max(abs(hf_logprobs - srt_logprobs)),
            )
            print(
                "max input diff between srt_base and srt_lora",
                torch.max(abs(srt_no_lora_logprobs - srt_logprobs)),
            )
            print(
                "max input diff between srt_base and hf_base",
                torch.max(abs(srt_no_lora_logprobs - hf_no_lora_logprobs)),
            )
            print(
                "max input diff between hf_lora and hf_base",
                torch.max(abs(hf_logprobs - hf_no_lora_logprobs)),
            )

            # compare output logprobs
            hf_logprobs = torch.Tensor(hf_outputs.top_output_logprobs[i])
            srt_logprobs = torch.Tensor(srt_outputs.top_output_logprobs[i])
            # print(
            #     "\noutput logprobs diff",
            #     [
            #         float(torch.max(abs(hf_logprobs[j] - srt_logprobs[j])))
            #         for j in range(max_new_tokens)
            #     ],
            # )
            print(
                "max output diff between hf_lora and srt_lora",
                torch.max(abs(hf_logprobs - srt_logprobs)),
                "\n",
            )

        # compare output strings
        print(f"{hf_outputs.output_strs=}")
        print(f"{srt_outputs.output_strs=}")
        print(f"{hf_no_lora_outputs.output_strs=}")
        print(f"{srt_no_lora_outputs.output_strs=}")
        for i in range(len(prompts)):
            assert srt_outputs.output_strs[i].strip(" ") == hf_outputs.output_strs[i], (
                srt_outputs.output_strs[i].strip(" "),
                hf_outputs.output_strs[i],
            )
            # assert (
            #     srt_no_lora_outputs.output_strs[i].strip(" ")
            #     == hf_no_lora_outputs.output_strs[i]
            # ), (
            #     srt_no_lora_outputs.output_strs[i].strip(" "),
            #     hf_no_lora_outputs.output_strs[i],
            # )

    def serving(self, prompts, lora_set, tp_size, torch_dtype, max_new_tokens):
        print("=================== testing serving =======================")
        # test batch forward
        base_path = lora_set["base"]
        all_lora_paths = lora_set["loras"]
        batch_lora_paths = [None]
        i = 0
        for _ in range(len(prompts) - 1):
            batch_lora_paths.append(all_lora_paths[i])
            i = (i + 1) % len(all_lora_paths)

        with SRTRunner(
            base_path,
            tp_size=tp_size,
            torch_dtype=torch_dtype,
            is_generation=True,
            lora_paths=all_lora_paths,
            max_loras_per_batch=3,
            disable_cuda_graph=True,
            disable_radix_cache=True,
        ) as srt_runner:
            srt_outputs = srt_runner.batch_forward(
                prompts, max_new_tokens=max_new_tokens, lora_paths=batch_lora_paths
            )

        with HFRunner(
            base_path,
            torch_dtype=torch_dtype,
            is_generation=True,
            output_str_only=True,
        ) as hf_runner:
            hf_outputs = hf_runner.forward(
                prompts, max_new_tokens=max_new_tokens, lora_paths=batch_lora_paths
            )

        # compare output strings
        print(f"{hf_outputs.output_strs=}")
        print(f"{srt_outputs.output_strs=}")
        for i in range(len(prompts)):
            assert srt_outputs.output_strs[i].strip(" ") == hf_outputs.output_strs[i], (
                srt_outputs.output_strs[i].strip(" "),
                hf_outputs.output_strs[i],
            )

    def base_inference(self, prompts, lora_set, tp_size, torch_dtype, max_new_tokens):
        print("=================== testing base inference =======================")
        base_path = lora_set["base"]
        all_lora_paths = lora_set["loras"]
        batch_lora_paths = [None] * len(prompts)

        with SRTRunner(
            base_path,
            tp_size=tp_size,
            torch_dtype=torch_dtype,
            is_generation=True,
        ) as srt_runner:
            srt_no_lora_outputs = srt_runner.forward(
                prompts, max_new_tokens=max_new_tokens
            )

        with SRTRunner(
            base_path,
            tp_size=tp_size,
            torch_dtype=torch_dtype,
            is_generation=True,
            lora_paths=all_lora_paths,
        ) as srt_runner:
            srt_outputs = srt_runner.forward(
                prompts, max_new_tokens=max_new_tokens, lora_paths=batch_lora_paths
            )

        for i in range(len(prompts)):
            srt_no_lora_logprobs = torch.Tensor(
                srt_no_lora_outputs.top_input_logprobs[i]
            )
            srt_logprobs = torch.Tensor(srt_outputs.top_input_logprobs[i])
            print("max_diff", torch.max(abs(srt_no_lora_logprobs - srt_logprobs)))

        print(f"{srt_no_lora_outputs.output_strs=}")
        print(f"{srt_outputs.output_strs=}")

        for i in range(len(prompts)):
            # Base-only requests through the LoRA-enabled server should match
            # the plain server exactly.
            assert (
                srt_outputs.output_strs[i].strip(" ")
                == srt_no_lora_outputs.output_strs[i]
            ), (
                srt_outputs.output_strs[i].strip(" "),
                srt_no_lora_outputs.output_strs[i],
            )

    def test_all(self):
        for lora_set in LORA_SETS:
            # self.load_lora_adapter(lora_set, 1)
            for torch_dtype in TORCH_DTYPES:
                tp_size = 1
                max_new_tokens = 32
                self.inference(PROMPTS, lora_set, tp_size, torch_dtype, max_new_tokens)
                # self.serving(PROMPTS, lora_set, tp_size, torch_dtype, max_new_tokens)
                # self.base_inference(
                #     PROMPTS, lora_set, tp_size, torch_dtype, max_new_tokens
                # )


if __name__ == "__main__":
    try:
        mp.set_start_method("spawn")
    except RuntimeError:
        pass

    unittest.main(warnings="ignore")
@@ -7,6 +7,7 @@ suites = {
     "minimal": [
         "models/test_embedding_models.py",
         "models/test_generation_models.py",
+        "models/test_lora.py",
         "sampling/penaltylib",
         "test_chunked_prefill.py",
         "test_embedding_openai_server.py",