[Feature] Initial support for multi-LoRA serving (#1307)
This commit is contained in:
@@ -27,7 +27,7 @@ srt = ["aiohttp", "decord", "fastapi", "hf_transfer", "huggingface_hub", "intere
|
||||
openai = ["openai>=1.0", "tiktoken"]
|
||||
anthropic = ["anthropic>=0.20.0"]
|
||||
litellm = ["litellm>=1.0.0"]
|
||||
test = ["jsonlines", "matplotlib", "pandas", "sentence_transformers", "accelerate"]
|
||||
test = ["jsonlines", "matplotlib", "pandas", "sentence_transformers", "accelerate", "peft"]
|
||||
all = ["sglang[srt]", "sglang[openai]", "sglang[anthropic]", "sglang[litellm]"]
|
||||
dev = ["sglang[all]", "sglang[test]"]
|
||||
|
||||
|
||||
403
python/sglang/srt/lora/lora.py
Normal file
403
python/sglang/srt/lora/lora.py
Normal file
@@ -0,0 +1,403 @@
|
||||
"""
|
||||
Copyright 2023-2024 SGLang Team
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
"""
|
||||
|
||||
# Integrates "S-LoRA: Serving Thousands of Concurrent LoRA Adapters"
|
||||
# and "Punica: Multi-Tenant LoRA Serving"
|
||||
|
||||
# LoRA layers class inheritance adapted from:
|
||||
# https://github.com/vllm-project/vllm/blob/4abf6336ec65c270343eb895e7b18786e9274176/vllm/lora/layers.py
|
||||
|
||||
|
||||
import json
|
||||
import os
|
||||
import re
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
|
||||
import safetensors.torch
|
||||
import torch
|
||||
from torch import nn
|
||||
from vllm.model_executor.layers.linear import (
|
||||
ColumnParallelLinear,
|
||||
MergedColumnParallelLinear,
|
||||
QKVParallelLinear,
|
||||
RowParallelLinear,
|
||||
)
|
||||
from vllm.model_executor.layers.vocab_parallel_embedding import (
|
||||
ParallelLMHead,
|
||||
VocabParallelEmbedding,
|
||||
)
|
||||
from vllm.model_executor.model_loader.loader import DefaultModelLoader
|
||||
|
||||
from sglang.srt.model_executor.forward_batch_info import ForwardMode, InputMetadata
|
||||
|
||||
|
||||
class BaseLayerWithLoRA(nn.Module):
    """Wrap a base layer and hold the state shared by LoRA-enabled layers.

    This base wrapper performs no LoRA computation itself: ``forward``
    delegates to the wrapped layer and ``set_lora_info`` is a no-op hook
    that subclasses override.
    """

    def __init__(self, base_layer, segment_gemm, lora_rank, scaling):
        """Store the wrapped layer, the segment-GEMM runner and LoRA settings."""
        super().__init__()
        self.base_layer = base_layer
        self.segment_gemm = segment_gemm
        self.lora_rank, self.scaling = lora_rank, scaling
        # Stays False until a subclass's ``set_lora_info`` switches it on.
        self.set_lora = False

    def forward(self, x: torch.Tensor):
        """Return the wrapped layer's ``forward`` result for ``x``."""
        base_forward = self.base_layer.forward
        return base_forward(x)

    def set_lora_info(self, *args):
        """Accept and ignore any per-batch LoRA info (overridden by subclasses)."""
        return None
|
||||
|
||||
|
||||
class VocabParallelEmbeddingWithLoRA(BaseLayerWithLoRA):
    """LoRA wrapper for a ``VocabParallelEmbedding`` layer.

    No LoRA math is added here; the wrapper only re-exposes the base layer's
    ``weight`` so code that reads ``<module>.weight`` keeps working after the
    module has been swapped for this wrapper.
    """

    def __init__(
        self,
        base_layer: VocabParallelEmbedding,
        segment_gemm,
        lora_rank,
        scaling,
    ) -> None:
        super().__init__(base_layer, segment_gemm, lora_rank, scaling)
        # Alias (not a copy) of the wrapped embedding's weight.
        self.weight = base_layer.weight
|
||||
|
||||
|
||||
class ColumnParallelLinearWithLoRA(BaseLayerWithLoRA):
    """LoRA wrapper for a ``ColumnParallelLinear`` layer.

    ``forward`` re-implements the base layer's forward pass so that the LoRA
    delta can be added to the per-rank output before any all-gather.
    ``apply_lora`` is still a placeholder in this class; the stacked
    subclasses provide real implementations.
    """

    def __init__(
        self, base_layer: ColumnParallelLinear, segment_gemm, lora_rank, scaling
    ) -> None:
        super().__init__(base_layer, segment_gemm, lora_rank, scaling)

    def apply_lora(self, output: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
        """Return ``output`` unchanged (LoRA for plain column-parallel is TODO)."""
        # TODO
        return output

    def forward(self, input_: torch.Tensor):
        """Run the base linear layer, add the LoRA delta if one has been set.

        Returns:
            ``(output, output_bias)``. ``output_bias`` is the base layer's
            bias when ``skip_bias_add`` is set, otherwise ``None``.
        """
        # duplicate the logic in ColumnParallelLinear
        bias = self.base_layer.bias if not self.base_layer.skip_bias_add else None
        output_parallel = self.base_layer.quant_method.apply(
            self.base_layer, input_, bias
        )

        if self.set_lora:
            output_parallel = self.apply_lora(output_parallel, input_)

        if self.base_layer.gather_output:
            # The module-level imports do not bring this helper into scope,
            # so import it here; otherwise this branch raises NameError.
            from vllm.distributed import tensor_model_parallel_all_gather

            output = tensor_model_parallel_all_gather(output_parallel)
        else:
            output = output_parallel
        output_bias = self.base_layer.bias if self.base_layer.skip_bias_add else None
        return output, output_bias
|
||||
|
||||
|
||||
class MergedColumnParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
    """LoRA wrapper for a ``MergedColumnParallelLinear`` holding two stacked shards."""

    def __init__(
        self,
        base_layer: MergedColumnParallelLinear,
        segment_gemm,
        lora_rank,
        scaling,
    ) -> None:
        super().__init__(base_layer, segment_gemm, lora_rank, scaling)

    def set_lora_info(self, A_buffer, B_buffer, bs, seq_lens, weight_indices):
        """Record the LoRA buffers and batch layout used by ``apply_lora``."""
        self.set_lora = True
        self.A_buffer, self.B_buffer = A_buffer, B_buffer
        self.bs = bs
        self.seq_lens = seq_lens
        self.weight_indices = weight_indices

    def apply_lora(self, base_output: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
        """Return ``base_output`` plus the scaled LoRA delta for both shards.

        One segment GEMM applies the stacked A weights; its output is then
        split into two ``lora_rank``-wide column groups, each multiplied by
        the matching half of ``B_buffer`` (split along dim 1).
        """
        # Keyword arguments shared by every segment-GEMM call below.
        batch_kwargs = dict(
            batch_size=self.bs,
            weight_column_major=True,
            seg_lens=self.seq_lens,
            weight_indices=self.weight_indices,
        )
        rank = self.lora_rank
        lora_a_output = self.segment_gemm.run(
            x=x, weights=self.A_buffer, **batch_kwargs
        )
        # FIXME
        assert lora_a_output.shape[-1] == self.lora_rank * 2
        lora_output = torch.empty_like(base_output)
        shard_width = lora_output.shape[-1] // 2
        for shard in range(2):
            col_lo = shard_width * shard
            col_hi = col_lo + shard_width
            shard_input = lora_a_output[:, rank * shard : rank * (shard + 1)]
            shard_weights = self.B_buffer[:, col_lo:col_hi, :]
            lora_output[:, col_lo:col_hi] = self.segment_gemm.run(
                x=shard_input.contiguous(),
                weights=shard_weights.contiguous(),
                **batch_kwargs,
            )
        return base_output + lora_output * self.scaling
|
||||
|
||||
|
||||
class QKVParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
    """LoRA wrapper for a fused ``QKVParallelLinear`` layer.

    The A weights for q/k/v live in one stacked buffer; the B weights are
    kept as a q buffer plus a combined kv buffer.
    """

    def __init__(
        self, base_layer: QKVParallelLinear, segment_gemm, lora_rank, scaling
    ) -> None:
        super().__init__(base_layer, segment_gemm, lora_rank, scaling)

    def set_lora_info(
        self, A_buffer_qkv, B_buffer_q, B_buffer_kv, bs, seq_lens, weight_indices
    ):
        """Record the LoRA buffers and batch layout used by ``apply_lora``."""
        self.set_lora = True
        self.A_buffer_qkv = A_buffer_qkv
        self.B_buffer_q, self.B_buffer_kv = B_buffer_q, B_buffer_kv
        self.bs = bs
        self.seq_lens = seq_lens
        self.weight_indices = weight_indices

    def apply_lora(self, base_output: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
        """Return ``base_output`` plus the scaled LoRA delta for q and the kv pair.

        The stacked-A output is consumed in ``lora_rank``-wide column groups:
        the first feeds the q projection, the next two feed the two halves of
        ``B_buffer_kv`` (split along dim 1).
        """
        # Keyword arguments shared by every segment-GEMM call below.
        batch_kwargs = dict(
            batch_size=self.bs,
            weight_column_major=True,
            seg_lens=self.seq_lens,
            weight_indices=self.weight_indices,
        )
        rank = self.lora_rank
        lora_a_output = self.segment_gemm.run(
            x=x, weights=self.A_buffer_qkv, **batch_kwargs
        )
        # FIXME parallelize qkv
        lora_output = torch.empty_like(base_output)

        # q occupies the leading columns of the fused output
        q_width = self.B_buffer_q.shape[-2]
        lora_output[:, :q_width] = self.segment_gemm.run(
            x=lora_a_output[:, :rank].contiguous(),
            weights=self.B_buffer_q,
            **batch_kwargs,
        )

        # kv: two equal-width shards written right after the q columns
        kv_width = self.B_buffer_kv.shape[-2] // 2
        for shard in range(2):
            row_lo = kv_width * shard
            row_hi = row_lo + kv_width
            rank_lo = rank * (shard + 1)
            rank_hi = rank * (shard + 2)
            lora_output[:, q_width + row_lo : q_width + row_hi] = (
                self.segment_gemm.run(
                    x=lora_a_output[:, rank_lo:rank_hi].contiguous(),
                    weights=self.B_buffer_kv[:, row_lo:row_hi, :].contiguous(),
                    **batch_kwargs,
                )
            )
        return base_output + lora_output * self.scaling
|
||||
|
||||
|
||||
class RowParallelLinearWithLoRA(BaseLayerWithLoRA):
    """LoRA wrapper for a ``RowParallelLinear`` layer.

    ``forward`` re-implements the base layer's forward pass so the LoRA delta
    is added to the per-rank output before any all-reduce.
    """

    def __init__(
        self, base_layer: RowParallelLinear, segment_gemm, lora_rank, scaling
    ) -> None:
        super().__init__(base_layer, segment_gemm, lora_rank, scaling)

    def set_lora_info(self, A_buffer, B_buffer, bs, seq_lens, weight_indices):
        """Record the LoRA buffers and batch layout used by ``apply_lora``."""
        self.set_lora = True
        self.A_buffer = A_buffer
        self.B_buffer = B_buffer
        self.bs = bs
        self.seq_lens = seq_lens
        self.weight_indices = weight_indices

    def apply_lora(self, base_output: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
        """Return ``base_output`` plus the scaled result of the A then B GEMMs."""
        lora_output = self.segment_gemm.run(
            x=x,
            weights=self.A_buffer,
            batch_size=self.bs,
            weight_column_major=True,
            seg_lens=self.seq_lens,
            weight_indices=self.weight_indices,
        )
        lora_output = self.segment_gemm.run(
            x=lora_output,
            weights=self.B_buffer,
            batch_size=self.bs,
            weight_column_major=True,
            seg_lens=self.seq_lens,
            weight_indices=self.weight_indices,
        )
        return base_output + lora_output * self.scaling

    def forward(self, input_):
        """Run the base linear layer, add the LoRA delta if one has been set.

        Returns:
            ``(output, output_bias)``. When ``skip_bias_add`` is set the bias
            is returned separately instead of being added to ``output``.
        """
        # The module-level imports do not bring these helpers into scope, so
        # import them here; otherwise the branches below raise NameError.
        from vllm.distributed import (
            get_tensor_model_parallel_rank,
            split_tensor_along_last_dim,
            tensor_model_parallel_all_reduce,
        )

        # duplicate the logic in RowParallelLinear
        if self.base_layer.input_is_parallel:
            input_parallel = input_
        else:
            tp_rank = get_tensor_model_parallel_rank()
            splitted_input = split_tensor_along_last_dim(
                input_, num_partitions=self.base_layer.tp_size
            )
            input_parallel = splitted_input[tp_rank].contiguous()
        output_parallel = self.base_layer.quant_method.apply(
            self.base_layer, input_parallel
        )

        if self.set_lora:
            output_parallel = self.apply_lora(output_parallel, input_parallel)

        if self.base_layer.reduce_results and self.base_layer.tp_size > 1:
            output_ = tensor_model_parallel_all_reduce(output_parallel)
        else:
            output_ = output_parallel

        if not self.base_layer.skip_bias_add:
            output = (
                output_ + self.base_layer.bias
                if self.base_layer.bias is not None
                else output_
            )
            output_bias = None
        else:
            output = output_
            output_bias = self.base_layer.bias
        return output, output_bias
|
||||
|
||||
|
||||
def get_lora_layer(
    layer: nn.Module, segment_gemm, lora_rank, scaling
) -> BaseLayerWithLoRA:
    """Build the LoRA wrapper that corresponds to ``layer``'s type.

    Candidates are tried in a fixed order with ``isinstance`` and the first
    match wins, so the order of the table below is significant.

    Raises:
        Exception: if ``layer`` is not an instance of any supported type.
    """
    # the order matters
    wrapper_by_type = (
        (VocabParallelEmbedding, VocabParallelEmbeddingWithLoRA),
        (QKVParallelLinear, QKVParallelLinearWithLoRA),
        (MergedColumnParallelLinear, MergedColumnParallelLinearWithLoRA),
        (ColumnParallelLinear, ColumnParallelLinearWithLoRA),
        (RowParallelLinear, RowParallelLinearWithLoRA),
    )
    for base_type, wrapper_type in wrapper_by_type:
        if not isinstance(layer, base_type):
            continue
        return wrapper_type(layer, segment_gemm, lora_rank, scaling)
    raise Exception(f"No corresponding LoRA layer supported for {type(layer)}.")
|
||||
|
||||
|
||||
def get_mapped_params(module_names):
    """Return the distinct ``params_mapping(name)`` results as a list.

    The order of the returned list is unspecified because duplicates are
    removed through a set.
    """
    # NOTE(review): ``params_mapping`` is not defined in the visible part of
    # this file; confirm it exists at module level, otherwise any non-empty
    # input raises NameError here.
    unique_names = {params_mapping(module_name) for module_name in module_names}
    return list(unique_names)
|
||||
|
||||
|
||||
class LoRALayer(nn.Module):
    """Hold the LoRA weights that belong to one decoder layer.

    ``weights`` maps weight name to the tensor kept by the adapter loader;
    ``weight_gpu`` maps the same names to their float16 CUDA copies, or to
    ``None`` after offloading.
    """

    def __init__(self, config, base_hf_config):
        super().__init__()
        self.config = config
        self.base_hf_config = base_hf_config
        self.weights = {}
        self.weight_gpu = {}

    def load_to_gpu(self):
        """Create a float16 CUDA copy of every weight in ``self.weights``."""
        for name in self.weights:
            half_weight = self.weights[name].to(torch.float16)
            self.weight_gpu[name] = half_weight.to("cuda")

    def offload_from_gpu(self):
        """Drop the CUDA copies by resetting every entry to ``None``."""
        for name in self.weights:
            self.weight_gpu[name] = None
|
||||
|
||||
|
||||
class LoRAAdapter(nn.Module):
    """One LoRA adapter: its config plus per-layer and layer-independent weights."""

    def __init__(self, uid, config, base_hf_config, load_config):
        super().__init__()
        self.uid = uid
        self.config = config
        assert self.config.hf_config["peft_type"].lower() == "lora"
        self.base_hf_config = base_hf_config
        self.load_config = load_config
        self.scaling = self.config.lora_alpha / self.config.r

        # One weight container per hidden layer of the base model.
        per_layer = [
            LoRALayer(config, base_hf_config)
            for _ in range(base_hf_config.num_hidden_layers)
        ]
        self.layers = nn.ModuleList(per_layer)

        # Weights whose name carries no ``layers.<idx>.`` component.
        self.weights = {}
        self.weights_gpu = {}

    def get_stacked_multiply(self, module_name):
        """Return how many projections are stacked in ``module_name`` (default 1)."""
        stacked_rank = {
            "qkv_proj": 3,
            "kv_proj": 2,
            "gate_up_proj": 2,
        }
        return stacked_rank.get(module_name, 1)

    def load_to_gpu(self):
        """Copy every weight (adapter-level, then per layer) to CUDA as float16."""
        for name in self.weights:
            half_weight = self.weights[name].to(torch.float16)
            self.weights_gpu[name] = half_weight.to("cuda")
        for layer in self.layers:
            layer.load_to_gpu()

    def offload_from_gpu(self):
        """Reset every CUDA copy (adapter-level, then per layer) to ``None``."""
        for name in self.weights:
            self.weights_gpu[name] = None
        for layer in self.layers:
            layer.offload_from_gpu()

    # initialize the LoRA weights to cpu
    def initialize_weights(self):
        """Load the adapter's weights to CPU and stack the fused projections.

        Weights are routed to the matching ``LoRALayer`` by the layer index
        found in their name. Afterwards, inside each layer:

        * ``lora_A`` of q/k/v is concatenated (dim 0) into ``qkv_proj``;
        * the other k/v weights are concatenated into ``kv_proj``;
        * gate/up weights are concatenated into ``gate_up_proj``.

        The source entries are removed once stacked. Stacking is triggered by
        ``k_proj`` / ``gate_proj`` names and looks up the sibling weights by
        key, so a missing sibling raises ``KeyError``.
        """
        loader = DefaultModelLoader(self.load_config)
        # NOTE(review): ``hf_config`` is indexed like a dict in ``__init__``;
        # if it is a plain dict this ``getattr`` always yields None — confirm.
        revision = getattr(self.config.hf_config, "revision", None)
        weight_iter = loader._get_weights_iterator(
            self.config.path, revision=revision, fall_back_to_pt=True
        )
        for name, loaded_weight in weight_iter:
            found = re.search(r"layers\.(\d+)\.", name)
            if found is None:
                self.weights[name] = loaded_weight.cpu()
                continue
            layer_id = int(found.group(1))
            self.layers[layer_id].weights[name] = loaded_weight.cpu()

        # stack kv_proj and gate_up_proj
        for layer in self.layers:
            layer_weights = layer.weights
            # Snapshot the names: entries are added and removed while looping.
            for weight_name in list(layer_weights):
                if "k_proj" in weight_name:
                    v_name = weight_name.replace("k_proj", "v_proj")
                    if "lora_A" in weight_name:
                        q_name = weight_name.replace("k_proj", "q_proj")
                        stacked_name = weight_name.replace("k_proj", "qkv_proj")
                        part_names = (q_name, weight_name, v_name)
                    else:
                        stacked_name = weight_name.replace("k_proj", "kv_proj")
                        part_names = (weight_name, v_name)
                elif "gate_proj" in weight_name:
                    up_name = weight_name.replace("gate_proj", "up_proj")
                    stacked_name = weight_name.replace("gate_proj", "gate_up_proj")
                    part_names = (weight_name, up_name)
                else:
                    continue
                layer_weights[stacked_name] = torch.cat(
                    tuple(layer_weights[part] for part in part_names), 0
                )
                for part in part_names:
                    layer_weights.pop(part)
|
||||
43
python/sglang/srt/lora/lora_config.py
Normal file
43
python/sglang/srt/lora/lora_config.py
Normal file
@@ -0,0 +1,43 @@
|
||||
"""
|
||||
Copyright 2023-2024 SGLang Team
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
"""
|
||||
|
||||
import json
|
||||
import os
|
||||
|
||||
from huggingface_hub import snapshot_download
|
||||
|
||||
|
||||
class LoRAConfig:
    """Configuration of one LoRA adapter, read from ``adapter_config.json``.

    Attributes set by ``__init__``:
        path: local directory, or an identifier handed to
            ``snapshot_download`` when it is not a local directory.
        hf_config: the parsed ``adapter_config.json`` content.
        target_modules, r, lora_alpha: the same-named ``hf_config`` entries.
    """

    def __init__(
        self,
        path: str,
    ) -> None:
        self.path = path
        self.hf_config = self.get_lora_config()
        self.target_modules = self.hf_config["target_modules"]
        self.r = self.hf_config["r"]
        self.lora_alpha = self.hf_config["lora_alpha"]

    def get_lora_config(self, dummy=False):
        """Load and return the parsed ``adapter_config.json`` for ``self.path``.

        Args:
            dummy: not supported; a truthy value raises.

        Raises:
            NotImplementedError: if ``dummy`` is truthy.
        """
        if dummy:
            raise NotImplementedError()

        if os.path.isdir(self.path):
            weights_dir = self.path
        else:
            # Only the JSON files are fetched; the weights are not needed here.
            weights_dir = snapshot_download(self.path, allow_patterns=["*.json"])
        config_name = "adapter_config.json"
        # Decode explicitly as UTF-8 so parsing does not depend on the
        # platform's locale encoding.
        with open(
            os.path.join(weights_dir, config_name), "r", encoding="utf-8"
        ) as f:
            return json.load(f)
|
||||
256
python/sglang/srt/lora/lora_manager.py
Normal file
256
python/sglang/srt/lora/lora_manager.py
Normal file
@@ -0,0 +1,256 @@
|
||||
"""
|
||||
Copyright 2023-2024 SGLang Team
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
"""
|
||||
|
||||
# Integrates "S-LoRA: Serving Thousands of Concurrent LoRA Adapters"
|
||||
# and "Punica: Multi-Tenant LoRA Serving"
|
||||
|
||||
|
||||
import re
|
||||
from dataclasses import dataclass
|
||||
|
||||
import torch
|
||||
from flashinfer import SegmentGEMMWrapper
|
||||
|
||||
from sglang.srt.lora.lora import LoRAAdapter, get_lora_layer
|
||||
from sglang.srt.lora.lora_config import LoRAConfig
|
||||
from sglang.srt.model_executor.forward_batch_info import ForwardMode
|
||||
from sglang.srt.utils import replace_submodule
|
||||
|
||||
|
||||
def get_stacked_name(name):
    """Map an original module name to ``(name for A, name for B)``.

    Names that are not part of a stacked projection map to ``(name, name)``.
    """
    # origin name -> (name for A, name for B)
    stacked_names = {
        "q_proj": ("qkv_proj", "q_proj"),
        "k_proj": ("qkv_proj", "kv_proj"),
        "v_proj": ("qkv_proj", "kv_proj"),
        "gate_proj": ("gate_up_proj", "gate_up_proj"),
        "up_proj": ("gate_up_proj", "gate_up_proj"),
    }
    try:
        return stacked_names[name]
    except KeyError:
        return (name, name)
|
||||
|
||||
|
||||
def get_layer_id(name):
    """Return the integer after the first ``layers.`` in ``name``, or ``None``.

    Only a ``layers.<digits>.`` component (with the trailing dot) counts.
    """
    found = re.search(r"layers\.(\d+)\.", name)
    return None if found is None else int(found.group(1))
|
||||
|
||||
|
||||
class LoRAManager:
    """Load LoRA adapters and wire the per-batch adapter weights into the model.

    On construction the manager reads every adapter's config and weights,
    replaces the targeted modules of ``base_model`` with their LoRA wrappers,
    and preallocates a CUDA memory pool with ``max_loras_per_batch`` slots.
    ``prepare_lora_batch`` then fills pool slots with the adapters a batch
    needs and hands the buffers to the wrapped modules.
    """

    def __init__(
        self,
        base_model,
        lora_paths,
        base_hf_config,
        max_loras_per_batch,
        load_config,
        dtype,
    ):
        self.base_model = base_model
        self.lora_paths = lora_paths
        self.base_hf_config = base_hf_config
        self.max_loras_per_batch = max_loras_per_batch
        self.load_config = load_config
        self.dtype = dtype

        # Scratch space handed to the segment-GEMM wrapper (1 MiB of int8).
        workspace_buffer = torch.empty(1 * 1024 * 1024, dtype=torch.int8, device="cuda")
        self.segment_gemm = SegmentGEMMWrapper(workspace_buffer)

        self.init_loras()
        self.init_lora_memory_pool()
        self.init_lora_batch()

    def match_target_modules(self, module_name):
        """Return True if the last dotted component of ``module_name`` is targeted."""
        return module_name.split(".")[-1] in self.target_modules

    def get_target_modules(self):
        """Return ``(name, module)`` for every targeted module of the base model."""
        return [
            (module_name, module)
            for module_name, module in self.base_model.named_modules()
            if self.match_target_modules(module_name)
        ]

    def set_lora_module(self, module_name, module):
        """Swap ``module`` inside the base model for its LoRA wrapper and return it."""
        lora_module = get_lora_layer(
            module, self.segment_gemm, self.max_lora_dim, self.scaling
        )
        replace_submodule(self.base_model, module_name, lora_module)
        return lora_module

    def init_loras(self):
        """Read all adapter configs and weights, then patch the base model."""
        # get configs and target modules
        self.configs = {}
        self.origin_target_modules = set()
        for path in self.lora_paths:
            self.configs[path] = LoRAConfig(path)
            self.origin_target_modules = set(self.origin_target_modules) | set(
                self.configs[path].target_modules
            )
        # Module names as they appear in the (possibly fused) base model.
        self.target_modules = set(
            [
                self.base_model.get_module_name(module)
                for module in self.origin_target_modules
            ]
        )
        # Set of (name for A, name for B) pairs used to key the memory pool.
        self.target_weights = set(
            [get_stacked_name(module) for module in self.origin_target_modules]
        )

        # load all weights to cpu
        self.loras = []
        self.lora_id = {}
        for path in self.lora_paths:
            self.lora_id[path] = len(self.loras)
            self.loras.append(
                LoRAAdapter(
                    path, self.configs[path], self.base_hf_config, self.load_config
                )
            )
            self.loras[-1].initialize_weights()

        # misc lora configs
        self.max_lora_dim = max([x.hf_config["r"] for x in self.configs.values()])
        self.scaling = self.loras[0].scaling
        # FIXME remove the restrictions
        assert all(x.hf_config["r"] == self.max_lora_dim for x in self.configs.values())
        assert all(x.scaling == self.scaling for x in self.loras)

        # monkey patch to use the LoRA version
        self.lora_modules = []
        for module_name, module in self.get_target_modules():
            self.lora_modules.append(
                (module_name, self.set_lora_module(module_name, module))
            )

    def init_lora_memory_pool(self):
        """Preallocate one A and one B buffer per target weight and layer.

        Each buffer's leading dimension has ``max_loras_per_batch`` slots.
        The buffers come from ``torch.empty`` and are therefore uninitialized
        until ``load_lora`` writes to a slot.
        """
        # preallocate lora memory pool
        self.A_buffer = {}
        self.B_buffer = {}
        num_layer = self.base_hf_config.num_hidden_layers
        for module_A, module_B in self.target_weights:
            # init A tensor, column_major=True
            hidden_dim_A, _ = self.base_model.get_hidden_dim(module_A)
            c = self.loras[-1].get_stacked_multiply(module_A)
            if module_A not in self.A_buffer:
                self.A_buffer[module_A] = [
                    torch.empty(
                        (
                            self.max_loras_per_batch,
                            self.max_lora_dim * c,
                            hidden_dim_A,
                        ),
                        dtype=self.dtype,
                        device="cuda",
                    )
                    for i in range(num_layer)
                ]
            # init B tensor, column_major=True
            _, hidden_dim_B = self.base_model.get_hidden_dim(module_B)
            c = self.loras[-1].get_stacked_multiply(module_B)
            if module_B not in self.B_buffer:
                self.B_buffer[module_B] = [
                    torch.empty(
                        (
                            self.max_loras_per_batch,
                            hidden_dim_B * c,
                            self.max_lora_dim,
                        ),
                        dtype=self.dtype,
                        device="cuda",
                    )
                    for i in range(num_layer)
                ]

    def init_lora_batch(self):
        """Reset the bookkeeping of which adapter occupies which pool slot."""
        self.active_uids = set()  # set of active loras
        self.buffer_id = {}  # lora uid -> idx in memory pool

    def get_weight_name(self, name, idx):
        """Return the first target weight name (A if idx==0, B if idx==1) found in ``name``.

        Matching is by substring. Returns ``None`` when nothing matches.
        """
        for target_weight_name in self.target_weights:
            if target_weight_name[idx] in name:
                return target_weight_name[idx]
        return None

    def load_lora(self, uid, buffer_id):
        """Fill pool slot ``buffer_id`` with adapter ``uid``'s weights.

        ``uid is None`` means "base model only": the slot is cleared so the
        LoRA delta computed from it is zero.
        """
        num_layer = self.base_hf_config.num_hidden_layers
        if uid is None:
            for i in range(num_layer):
                # ``zero_()`` rather than ``*= 0``: the pool comes from
                # ``torch.empty`` and NaN/Inf garbage survives a multiply.
                # B is cleared too so that 0 @ B cannot produce NaN.
                for k in self.A_buffer:
                    self.A_buffer[k][i][buffer_id].zero_()
                for k in self.B_buffer:
                    self.B_buffer[k][i][buffer_id].zero_()
            return

        for i in range(num_layer):
            layer_weights = self.loras[self.lora_id[uid]].layers[i].weights
            for name, weights in layer_weights.items():
                if "lora_A" in name:
                    lora_weight_name = self.get_weight_name(name, 0)
                    if lora_weight_name:
                        self.A_buffer[lora_weight_name][i][buffer_id].copy_(weights)
                else:
                    lora_weight_name = self.get_weight_name(name, 1)
                    if lora_weight_name:
                        self.B_buffer[lora_weight_name][i][buffer_id].copy_(weights)

    def _assign_buffer_slots(self, cur_uids):
        """Make every uid in ``cur_uids`` resident in the memory pool.

        Unused slots are taken first. When none is left, an active uid that
        the current batch does not need is evicted and its own slot is
        reused, so a slot is never shared by two active uids.
        """
        occupied = set(self.buffer_id.values())
        free_slots = [
            slot for slot in range(self.max_loras_per_batch) if slot not in occupied
        ]
        evictable_uids = [uid for uid in self.active_uids if uid not in cur_uids]
        for uid in cur_uids:
            if uid in self.active_uids:
                continue
            if free_slots:
                slot = free_slots.pop(0)
            elif evictable_uids:
                victim = evictable_uids.pop()
                self.active_uids.remove(victim)
                slot = self.buffer_id.pop(victim)
            else:
                raise RuntimeError("No LoRA memory pool slot available for the batch.")
            self.load_lora(uid, slot)
            self.active_uids.add(uid)
            self.buffer_id[uid] = slot

    def prepare_lora_batch(self, batch, extend_seq_lens=None):
        """Load the adapters ``batch`` needs and configure the LoRA modules.

        Args:
            batch: object exposing ``reqs`` (each with a ``lora_path``) and
                ``forward_mode``.
            extend_seq_lens: segment lengths used when the batch is in extend
                mode; otherwise every request counts as one token.
        """
        # load active loras into lora memory pool
        cur_uids = set([req.lora_path for req in batch.reqs])
        assert len(cur_uids) <= self.max_loras_per_batch
        self._assign_buffer_slots(cur_uids)

        # Base-model-only batch: nothing to configure on the modules.
        if cur_uids == set([None]):
            return

        # setup lora in forward modules
        bs = len(batch.reqs)
        seg_lens = extend_seq_lens if batch.forward_mode.is_extend() else torch.ones(bs)
        weight_indices = torch.empty((bs,), dtype=torch.int64, device="cuda")
        for i, req in enumerate(batch.reqs):
            weight_indices[i] = self.buffer_id[req.lora_path]

        for module_name, module in self.lora_modules:
            layer_id = get_layer_id(module_name)

            if "qkv_proj" not in module_name:
                weight_name = self.get_weight_name(module_name, 0)
                module.set_lora_info(
                    self.A_buffer[weight_name][layer_id],
                    self.B_buffer[weight_name][layer_id],
                    bs,
                    seg_lens,
                    weight_indices,
                )
            else:
                module.set_lora_info(
                    self.A_buffer["qkv_proj"][layer_id],
                    self.B_buffer["q_proj"][layer_id],
                    self.B_buffer["kv_proj"][layer_id],
                    bs,
                    seg_lens,
                    weight_indices,
                )
|
||||
@@ -55,6 +55,9 @@ class GenerateReqInput:
|
||||
|
||||
is_single: bool = True
|
||||
|
||||
# LoRA related
|
||||
lora_path: Optional[Union[List[Optional[str]], Optional[str]]] = None
|
||||
|
||||
def post_init(self):
|
||||
if (self.text is None and self.input_ids is None) or (
|
||||
self.text is not None and self.input_ids is not None
|
||||
@@ -184,6 +187,9 @@ class TokenizedGenerateReqInput:
|
||||
# Modalities of the input images
|
||||
modalites: Optional[List[str]] = None
|
||||
|
||||
# LoRA related
|
||||
lora_path: Optional[str] = None # None means just use the base model
|
||||
|
||||
|
||||
@dataclass
|
||||
class EmbeddingReqInput:
|
||||
|
||||
@@ -98,7 +98,7 @@ class FINISH_ABORT(BaseFinishReason):
|
||||
class Req:
|
||||
"""Store all inforamtion of a request."""
|
||||
|
||||
def __init__(self, rid, origin_input_text, origin_input_ids):
|
||||
def __init__(self, rid, origin_input_text, origin_input_ids, lora_path=None):
|
||||
# Input and output info
|
||||
self.rid = rid
|
||||
self.origin_input_text = origin_input_text
|
||||
@@ -106,6 +106,7 @@ class Req:
|
||||
self.origin_input_ids = origin_input_ids
|
||||
self.output_ids = [] # Each decode stage's output ids
|
||||
self.fill_ids = None # fill_ids = origin_input_ids + output_ids
|
||||
self.lora_path = lora_path
|
||||
|
||||
# Memory info
|
||||
self.req_pool_idx = None
|
||||
|
||||
@@ -266,6 +266,11 @@ class TokenizerManager:
|
||||
top_logprobs_num,
|
||||
obj.stream,
|
||||
modalities,
|
||||
(
|
||||
obj.lora_path[index]
|
||||
if isinstance(obj.lora_path, list)
|
||||
else obj.lora_path
|
||||
),
|
||||
)
|
||||
else: # is embedding
|
||||
tokenized_obj = TokenizedEmbeddingReqInput(
|
||||
@@ -364,6 +369,11 @@ class TokenizerManager:
|
||||
obj.top_logprobs_num[index],
|
||||
obj.stream,
|
||||
modalities,
|
||||
(
|
||||
obj.lora_path[index]
|
||||
if isinstance(obj.lora_path, list)
|
||||
else obj.lora_path
|
||||
),
|
||||
)
|
||||
else:
|
||||
tokenized_obj = TokenizedEmbeddingReqInput(
|
||||
|
||||
@@ -87,6 +87,8 @@ class ModelTpServer:
|
||||
self.dp_size = server_args.dp_size
|
||||
self.schedule_policy = server_args.schedule_policy
|
||||
self.disable_regex_jump_forward = server_args.disable_regex_jump_forward
|
||||
self.lora_paths = server_args.lora_paths
|
||||
self.max_loras_per_batch = server_args.max_loras_per_batch
|
||||
|
||||
# Init model and tokenizer
|
||||
self.model_config = ModelConfig(
|
||||
@@ -323,7 +325,15 @@ class ModelTpServer:
|
||||
self,
|
||||
recv_req: TokenizedGenerateReqInput,
|
||||
):
|
||||
req = Req(recv_req.rid, recv_req.input_text, recv_req.input_ids)
|
||||
if isinstance(recv_req, TokenizedGenerateReqInput):
|
||||
req = Req(
|
||||
recv_req.rid,
|
||||
recv_req.input_text,
|
||||
recv_req.input_ids,
|
||||
lora_path=recv_req.lora_path,
|
||||
)
|
||||
else:
|
||||
req = Req(recv_req.rid, recv_req.input_text, recv_req.input_ids)
|
||||
req.tokenizer = self.tokenizer
|
||||
req.sampling_params = recv_req.sampling_params
|
||||
req.pixel_values = recv_req.pixel_values
|
||||
@@ -442,10 +452,27 @@ class ModelTpServer:
|
||||
self.current_inflight_req
|
||||
)
|
||||
|
||||
if self.lora_paths is not None:
|
||||
lora_set = (
|
||||
set([req.lora_path for req in self.running_batch.reqs])
|
||||
if self.running_batch is not None
|
||||
else set([])
|
||||
)
|
||||
|
||||
for req in self.waiting_queue:
|
||||
if adder.no_remaining_tokens():
|
||||
break
|
||||
req.init_next_round_input(None if prefix_computed else self.tree_cache)
|
||||
if (
|
||||
self.lora_paths is not None
|
||||
and len(
|
||||
lora_set
|
||||
| set([req.lora_path for req in adder.can_run_list])
|
||||
| set([req.lora_path])
|
||||
)
|
||||
> self.max_loras_per_batch
|
||||
):
|
||||
break
|
||||
res = adder.add_one_req(req)
|
||||
if (
|
||||
not res
|
||||
|
||||
@@ -41,6 +41,7 @@ from sglang.srt.configs.model_config import AttentionArch, ModelConfig
|
||||
from sglang.srt.layers.attention_backend import FlashInferAttnBackend, TritonAttnBackend
|
||||
from sglang.srt.layers.logits_processor import LogitsProcessorOutput
|
||||
from sglang.srt.layers.sampler import SampleOutput
|
||||
from sglang.srt.lora.lora_manager import LoRAManager
|
||||
from sglang.srt.managers.schedule_batch import ScheduleBatch, global_server_args_dict
|
||||
from sglang.srt.mem_cache.memory_pool import (
|
||||
MHATokenToKVPool,
|
||||
@@ -107,6 +108,8 @@ class ModelRunner:
|
||||
# Init components
|
||||
min_per_gpu_memory = self.init_torch_distributed()
|
||||
self.load_model()
|
||||
if server_args.lora_paths is not None:
|
||||
self.init_lora_manager()
|
||||
self.init_memory_pool(
|
||||
min_per_gpu_memory,
|
||||
server_args.max_running_requests,
|
||||
@@ -312,6 +315,17 @@ class ModelRunner:
|
||||
logger.info("Update weights end.")
|
||||
return True, "Succeeded to update model weights"
|
||||
|
||||
def init_lora_manager(self):
    """Create the ``LoRAManager`` for the loaded model and store it on ``self``."""
    manager_kwargs = {
        "base_model": self.model,
        "lora_paths": self.server_args.lora_paths,
        "base_hf_config": self.model_config.hf_config,
        "max_loras_per_batch": self.server_args.max_loras_per_batch,
        "load_config": self.load_config,
        "dtype": self.dtype,
    }
    self.lora_manager = LoRAManager(**manager_kwargs)
    logger.info("LoRA manager ready.")
|
||||
|
||||
def profile_max_num_token(self, total_gpu_memory: int):
|
||||
available_gpu_memory = get_available_gpu_memory(
|
||||
self.gpu_id, distributed=self.tp_size > 1
|
||||
@@ -450,6 +464,8 @@ class ModelRunner:
|
||||
|
||||
@torch.inference_mode()
|
||||
def forward_decode(self, batch: ScheduleBatch):
|
||||
if self.server_args.lora_paths is not None:
|
||||
self.lora_manager.prepare_lora_batch(batch)
|
||||
if (
|
||||
self.cuda_graph_runner
|
||||
and self.cuda_graph_runner.can_run(len(batch.reqs))
|
||||
@@ -466,6 +482,9 @@ class ModelRunner:
|
||||
@torch.inference_mode()
|
||||
def forward_extend(self, batch: ScheduleBatch):
|
||||
input_metadata = InputMetadata.from_schedule_batch(self, batch)
|
||||
if self.server_args.lora_paths is not None:
|
||||
self.lora_manager.prepare_lora_batch(batch, input_metadata.extend_seq_lens)
|
||||
|
||||
if self.is_generation:
|
||||
return self.model.forward(
|
||||
batch.input_ids, input_metadata.positions, input_metadata
|
||||
|
||||
@@ -324,6 +324,51 @@ class LlamaForCausalLM(nn.Module):
|
||||
sample_output = self.sampler(logits_output, input_metadata.sampling_info)
|
||||
return sample_output, logits_output
|
||||
|
||||
def get_hidden_dim(self, module_name):
    """Return a pair of dimensions for the given LoRA target module.

    The pair is read from ``self.config``:

    * ``q_proj`` / ``o_proj`` / ``qkv_proj``: ``(hidden_size, hidden_size)``
    * ``kv_proj``: ``(hidden_size, hidden_size // (num_attention_heads //
      num_key_value_heads))``
    * ``gate_up_proj``: ``(hidden_size, intermediate_size)``
    * ``down_proj``: ``(intermediate_size, hidden_size)``

    Raises:
        NotImplementedError: if ``module_name`` is none of the names above.
    """
    if module_name in ("q_proj", "o_proj", "qkv_proj"):
        hidden = self.config.hidden_size
        return hidden, self.config.hidden_size

    if module_name in ("kv_proj",):
        # Query heads per key/value head; the second dimension shrinks by
        # this factor.
        return self.config.hidden_size, self.config.hidden_size // (
            self.config.num_attention_heads // self.config.num_key_value_heads
        )

    if module_name == "gate_up_proj":
        return self.config.hidden_size, self.config.intermediate_size

    if module_name == "down_proj":
        return self.config.intermediate_size, self.config.hidden_size

    raise NotImplementedError()
|
||||
|
||||
def get_module_name(self, name):
    """Translate a per-shard projection name into its fused module name.

    ``q_proj``, ``k_proj`` and ``v_proj`` map to ``qkv_proj``; ``gate_proj``
    and ``up_proj`` map to ``gate_up_proj``. Any other name is returned
    unchanged.
    """
    fused_name_of = {
        **dict.fromkeys(("q_proj", "k_proj", "v_proj"), "qkv_proj"),
        **dict.fromkeys(("gate_proj", "up_proj"), "gate_up_proj"),
    }
    return fused_name_of.get(name, name)
|
||||
|
||||
def get_module_name_from_weight_name(self, name):
    """Derive the (fused) module name and shard count from a weight name.

    The shard names are tried in a fixed order (``q_proj``, ``k_proj``,
    ``v_proj``, ``gate_proj``, ``up_proj``); the first one that occurs as a
    substring of ``name`` is replaced by its fused module name.

    Returns:
        A ``(module_name, num_shard)`` tuple. ``num_shard`` is 3 for the
        ``qkv_proj`` shards, 2 for the ``gate_up_proj`` shards and 1 when no
        shard name matches.

    Note:
        The trailing ``len(".weight")`` characters are always dropped; the
        name is not checked to actually end with ``".weight"``.
    """
    suffix_len = len(".weight")
    # (shard_name, fused_name, num_shard) -- order matters, first hit wins.
    shard_table = (
        ("q_proj", "qkv_proj", 3),
        ("k_proj", "qkv_proj", 3),
        ("v_proj", "qkv_proj", 3),
        ("gate_proj", "gate_up_proj", 2),
        ("up_proj", "gate_up_proj", 2),
    )
    hit = next((entry for entry in shard_table if entry[0] in name), None)
    if hit is None:
        return name[:-suffix_len], 1

    shard_name, fused_name, shard_count = hit
    return name.replace(shard_name, fused_name)[:-suffix_len], shard_count
|
||||
|
||||
def get_num_params(self):
    """Return how many distinctly named parameters the module exposes.

    Names come from ``self.named_parameters()``; a name that is yielded more
    than once is counted a single time.
    """
    return len({param_name for param_name, _ in self.named_parameters()})
|
||||
|
||||
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
|
||||
stacked_params_mapping = [
|
||||
# (param_name, shard_name, shard_id)
|
||||
|
||||
@@ -611,6 +611,7 @@ class Runtime:
|
||||
return_logprob: Optional[Union[List[bool], bool]] = False,
|
||||
logprob_start_len: Optional[Union[List[int], int]] = None,
|
||||
top_logprobs_num: Optional[Union[List[int], int]] = None,
|
||||
lora_path: Optional[List[Optional[str]]] = None,
|
||||
):
|
||||
json_data = {
|
||||
"text": prompt,
|
||||
@@ -618,7 +619,9 @@ class Runtime:
|
||||
"return_logprob": return_logprob,
|
||||
"logprob_start_len": logprob_start_len,
|
||||
"top_logprobs_num": top_logprobs_num,
|
||||
"lora_path": lora_path,
|
||||
}
|
||||
assert not isinstance(lora_path, list) or len(lora_path) == len(prompt)
|
||||
response = requests.post(
|
||||
self.url + "/generate",
|
||||
json=json_data,
|
||||
|
||||
@@ -101,6 +101,10 @@ class ServerArgs:
|
||||
enable_mla: bool = False
|
||||
triton_attention_reduce_in_fp32: bool = False
|
||||
|
||||
# LoRA
|
||||
lora_paths: Optional[List[str]] = None
|
||||
max_loras_per_batch: int = 8
|
||||
|
||||
def __post_init__(self):
|
||||
# Set missing default values
|
||||
if self.tokenizer_path is None:
|
||||
@@ -522,6 +526,21 @@ class ServerArgs:
|
||||
help="Turn on memory efficient weight loading with quantization (quantize per layer during loading).",
|
||||
)
|
||||
|
||||
# LoRA options
|
||||
parser.add_argument(
|
||||
"--lora-paths",
|
||||
type=str,
|
||||
nargs="*",
|
||||
default=None,
|
||||
help="The list of LoRA adapters.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--max-loras-per-batch",
|
||||
type=int,
|
||||
default=8,
|
||||
help="Maximum number of adapters for a running batch, include base-only request",
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def from_cli_args(cls, args: argparse.Namespace):
|
||||
args.tp_size = args.tensor_parallel_size
|
||||
@@ -539,6 +558,12 @@ class ServerArgs:
|
||||
assert not (
|
||||
self.dp_size > 1 and self.node_rank is not None
|
||||
), "multi-node data parallel is not supported"
|
||||
assert (
|
||||
self.max_loras_per_batch > 0
|
||||
# FIXME
|
||||
and (self.lora_paths is None or self.disable_cuda_graph)
|
||||
and (self.lora_paths is None or self.disable_radix_cache)
|
||||
), "compatibility of lora and cuda graph and radix attention is in progress"
|
||||
|
||||
|
||||
def prepare_server_args(argv: List[str]) -> ServerArgs:
|
||||
|
||||
@@ -35,6 +35,7 @@ import torch
|
||||
import torch.distributed as dist
|
||||
from fastapi.responses import JSONResponse
|
||||
from packaging import version as pkg_version
|
||||
from torch import nn
|
||||
from torch.nn.parameter import Parameter
|
||||
from triton.runtime.cache import (
|
||||
FileCacheManager,
|
||||
@@ -714,3 +715,14 @@ def configure_logger(server_args, prefix: str = ""):
|
||||
datefmt="%H:%M:%S",
|
||||
force=True,
|
||||
)
|
||||
|
||||
|
||||
# source: https://github.com/vllm-project/vllm/blob/93b38bea5dd03e1b140ca997dfaadef86f8f1855/vllm/lora/utils.py#L9
|
||||
def replace_submodule(
    model: nn.Module, module_name: str, new_module: nn.Module
) -> nn.Module:
    """Swap the submodule at dotted path ``module_name`` for ``new_module``.

    The parent is looked up with ``model.get_submodule`` using everything
    before the last dot (an empty path when there is no dot), and the final
    path component is assigned on it via ``setattr``.

    Returns:
        ``new_module``, the module that was installed.
    """
    parent_path, _, leaf_name = module_name.rpartition(".")
    owner = model.get_submodule(parent_path)
    setattr(owner, leaf_name, new_module)
    return new_module
|
||||
|
||||
@@ -21,6 +21,7 @@ from typing import List, Union
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from peft import PeftModel
|
||||
from transformers import AutoModelForCausalLM, AutoTokenizer
|
||||
|
||||
from sglang.srt.server import Runtime
|
||||
@@ -52,6 +53,7 @@ def get_dtype_str(torch_dtype):
|
||||
|
||||
def get_top_logprobs(logits, k):
    """Return the ``k`` largest log-probabilities along the last dimension.

    ``logits`` is turned into float32 log-probabilities with a log-softmax
    over the last dimension; only the top-``k`` values are returned (the
    matching indices are discarded).
    """
    all_logprobs = F.log_softmax(logits, dim=-1, dtype=torch.float32)
    # Drop the local reference to the input before running top-k.
    del logits
    return torch.topk(all_logprobs, k=k, dim=-1).values
|
||||
|
||||
@@ -71,8 +73,10 @@ class HFRunner:
|
||||
model_path,
|
||||
torch_dtype,
|
||||
is_generation,
|
||||
output_str_only=False,
|
||||
):
|
||||
self.is_generation = is_generation
|
||||
self.output_str_only = output_str_only
|
||||
|
||||
self.in_queue = mp.Queue()
|
||||
self.out_queue = mp.Queue()
|
||||
@@ -95,7 +99,7 @@ class HFRunner:
|
||||
)
|
||||
|
||||
if self.is_generation:
|
||||
self.model = AutoModelForCausalLM.from_pretrained(
|
||||
self.base_model = AutoModelForCausalLM.from_pretrained(
|
||||
model_path,
|
||||
torch_dtype=torch_dtype,
|
||||
trust_remote_code=False,
|
||||
@@ -110,13 +114,16 @@ class HFRunner:
|
||||
)
|
||||
|
||||
while True:
|
||||
prompts, max_new_tokens = in_queue.get()
|
||||
prompts, max_new_tokens, lora_paths = in_queue.get()
|
||||
if lora_paths is not None:
|
||||
assert len(prompts) == len(lora_paths)
|
||||
|
||||
if prompts is not None:
|
||||
if self.is_generation:
|
||||
output_strs = []
|
||||
top_input_logprobs = []
|
||||
top_output_logprobs = []
|
||||
for p in prompts:
|
||||
for i, p in enumerate(prompts):
|
||||
if isinstance(p, str):
|
||||
input_ids = self.tokenizer.encode(
|
||||
p, return_tensors="pt"
|
||||
@@ -124,6 +131,16 @@ class HFRunner:
|
||||
else:
|
||||
input_ids = torch.tensor([p], device="cuda")
|
||||
|
||||
if lora_paths is not None and lora_paths[i] is not None:
|
||||
self.model = PeftModel.from_pretrained(
|
||||
self.base_model,
|
||||
lora_paths[i],
|
||||
torch_dtype=torch_dtype,
|
||||
is_trainable=False,
|
||||
)
|
||||
else:
|
||||
self.model = self.base_model
|
||||
|
||||
outputs = self.model.generate(
|
||||
input_ids,
|
||||
do_sample=False,
|
||||
@@ -131,25 +148,30 @@ class HFRunner:
|
||||
top_p=None,
|
||||
max_new_tokens=max_new_tokens,
|
||||
return_dict_in_generate=True,
|
||||
output_scores=True,
|
||||
output_scores=(not self.output_str_only),
|
||||
)
|
||||
output_strs.append(
|
||||
self.tokenizer.decode(outputs[0][0][len(input_ids[0]) :])
|
||||
)
|
||||
# outputs.scores: (num_token, 1, vocab_size)
|
||||
top_output_logprobs.append(
|
||||
[
|
||||
get_top_logprobs(logits[0], NUM_TOP_LOGPROBS).tolist()
|
||||
for logits in outputs.scores
|
||||
]
|
||||
)
|
||||
del outputs
|
||||
if not self.output_str_only:
|
||||
# outputs.scores: (num_token, 1, vocab_size)
|
||||
top_output_logprobs.append(
|
||||
[
|
||||
get_top_logprobs(
|
||||
logits[0], NUM_TOP_LOGPROBS
|
||||
).tolist()
|
||||
for logits in outputs.scores
|
||||
]
|
||||
)
|
||||
del outputs
|
||||
|
||||
input_logits = self.model.forward(input_ids).logits[0]
|
||||
top_input_logprobs.append(
|
||||
get_top_logprobs(input_logits, NUM_TOP_LOGPROBS).tolist()
|
||||
)
|
||||
del input_logits
|
||||
input_logits = self.model.forward(input_ids).logits[0]
|
||||
top_input_logprobs.append(
|
||||
get_top_logprobs(
|
||||
input_logits, NUM_TOP_LOGPROBS
|
||||
).tolist()
|
||||
)
|
||||
del input_logits
|
||||
|
||||
out_queue.put(
|
||||
ModelOutput(
|
||||
@@ -160,6 +182,7 @@ class HFRunner:
|
||||
)
|
||||
|
||||
else:
|
||||
assert not self.output_str_only
|
||||
logits = self.model.encode(prompts).tolist()
|
||||
out_queue.put(ModelOutput(embed_logits=logits))
|
||||
|
||||
@@ -167,8 +190,9 @@ class HFRunner:
|
||||
self,
|
||||
prompts: Union[List[str], List[torch.Tensor]] = DEFAULT_PROMPTS,
|
||||
max_new_tokens=8,
|
||||
lora_paths=None,
|
||||
):
|
||||
self.in_queue.put((prompts, max_new_tokens))
|
||||
self.in_queue.put((prompts, max_new_tokens, lora_paths))
|
||||
return self.out_queue.get()
|
||||
|
||||
def terminate(self):
|
||||
@@ -191,6 +215,10 @@ class SRTRunner:
|
||||
is_generation,
|
||||
tp_size=1,
|
||||
port=DEFAULT_PORT_FOR_SRT_TEST_RUNNER,
|
||||
lora_paths=None,
|
||||
max_loras_per_batch=4,
|
||||
disable_cuda_graph=False,
|
||||
disable_radix_cache=False,
|
||||
):
|
||||
self.is_generation = is_generation
|
||||
self.runtime = Runtime(
|
||||
@@ -201,12 +229,17 @@ class SRTRunner:
|
||||
mem_fraction_static=0.69,
|
||||
trust_remote_code=False,
|
||||
is_embedding=not self.is_generation,
|
||||
lora_paths=lora_paths,
|
||||
max_loras_per_batch=max_loras_per_batch,
|
||||
disable_cuda_graph=disable_cuda_graph,
|
||||
disable_radix_cache=disable_radix_cache,
|
||||
)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
prompts: Union[List[str], List[torch.Tensor]] = DEFAULT_PROMPTS,
|
||||
max_new_tokens=8,
|
||||
lora_paths=None,
|
||||
):
|
||||
if self.is_generation:
|
||||
# the return value contains logprobs from prefill
|
||||
@@ -214,9 +247,10 @@ class SRTRunner:
|
||||
top_input_logprobs = []
|
||||
top_output_logprobs = []
|
||||
sampling_params = {"max_new_tokens": max_new_tokens, "temperature": 0}
|
||||
for prompt in prompts:
|
||||
for i, prompt in enumerate(prompts):
|
||||
response = self.runtime.generate(
|
||||
prompt,
|
||||
lora_path=lora_paths[i] if lora_paths else None,
|
||||
sampling_params=sampling_params,
|
||||
return_logprob=True,
|
||||
logprob_start_len=0,
|
||||
@@ -256,6 +290,37 @@ class SRTRunner:
|
||||
logits = [x["embedding"] for x in response]
|
||||
return ModelOutput(embed_logits=logits)
|
||||
|
||||
def batch_forward(
    self,
    prompts: Union[List[str], List[torch.Tensor]] = DEFAULT_PROMPTS,
    max_new_tokens=8,
    lora_paths=None,
):
    """
    Send all prompts to the runtime in one request.

    For a generation model the prompts (and ``lora_paths``, passed through
    as ``lora_path``; an empty/falsy value is sent as ``None``) go to
    ``runtime.generate`` with temperature 0, and only the generated strings
    are returned -- no logprobs. Otherwise the prompts go to
    ``runtime.encode`` and the embeddings are returned.
    """
    if not self.is_generation:
        encoded = json.loads(self.runtime.encode(prompts))
        return ModelOutput(embed_logits=[item["embedding"] for item in encoded])

    raw_response = self.runtime.generate(
        prompts,
        lora_path=lora_paths or None,
        sampling_params={"max_new_tokens": max_new_tokens, "temperature": 0},
    )
    generations = json.loads(raw_response)
    return ModelOutput(
        output_strs=[item["text"] for item in generations],
    )
|
||||
|
||||
def __enter__(self):
|
||||
return self
|
||||
|
||||
|
||||
Reference in New Issue
Block a user