Refactor LoRAManager and LoRAMemoryPool state management logic for dynamic LoRA loading support (#7412)
This commit is contained in:
@@ -16,7 +16,7 @@
|
||||
# and "Punica: Multi-Tenant LoRA Serving"
|
||||
|
||||
import logging
|
||||
from typing import Dict, List, Set, Tuple
|
||||
from typing import Dict, Set, Tuple
|
||||
|
||||
import torch
|
||||
|
||||
@@ -45,7 +45,6 @@ class LoRAManager:
|
||||
def __init__(
|
||||
self,
|
||||
base_model: torch.nn.Module,
|
||||
lora_paths: Dict[str, str],
|
||||
base_hf_config: AutoConfig,
|
||||
max_loras_per_batch: int,
|
||||
load_config: LoadConfig,
|
||||
@@ -55,7 +54,6 @@ class LoRAManager:
|
||||
tp_rank: int = 0,
|
||||
):
|
||||
self.base_model: torch.nn.Module = base_model
|
||||
self.lora_paths: Dict[str, str] = lora_paths
|
||||
self.base_hf_config: AutoConfig = base_hf_config
|
||||
self.max_loras_per_batch: int = max_loras_per_batch
|
||||
self.load_config: LoadConfig = load_config
|
||||
@@ -69,8 +67,8 @@ class LoRAManager:
|
||||
backend_type = get_backend_from_name(lora_backend)
|
||||
self.lora_backend: BaseLoRABackend = backend_type(lora_backend)
|
||||
|
||||
self.init_loras()
|
||||
self.init_lora_memory_pool()
|
||||
# Initialize mutable internal state of the LoRAManager.
|
||||
self.init_state()
|
||||
|
||||
def init_cuda_graph_batch_info(self, max_bs_in_cuda_graph: int):
|
||||
self.max_bs_in_cuda_graph = max_bs_in_cuda_graph
|
||||
@@ -100,72 +98,49 @@ class LoRAManager:
|
||||
],
|
||||
)
|
||||
|
||||
def init_loras(self):
|
||||
# Config of each LoRA adapter
|
||||
self.configs: Dict[str, LoRAConfig] = {}
|
||||
def load_lora_adapters(self, lora_paths: Dict[str, str]):
|
||||
"""
|
||||
Load LoRA adapters from the specified paths.
|
||||
TODO (lifuhuang): This method should be exposed to the server/engine API to support dynamic LoRA loading.
|
||||
|
||||
# Target module names in huggingface lora configs.
|
||||
# e.g., {"k_proj", "q_proj", "v_proj", "o_proj"}
|
||||
self.hf_target_names: Set[str] = set()
|
||||
for name, path in self.lora_paths.items():
|
||||
self.configs[name] = LoRAConfig(path)
|
||||
self.hf_target_names.update(self.configs[name].target_modules)
|
||||
Args:
|
||||
lora_paths (Dict[str, str]): A dictionary mapping LoRA adapter names to their file paths.
|
||||
If a LoRA adapter is already loaded, it will be skipped with a warning.
|
||||
"""
|
||||
|
||||
# Target lora weight names for lora_a and lora_b modules respectively.
|
||||
weights_A: List[str] = []
|
||||
weights_B: List[str] = []
|
||||
for module in self.hf_target_names:
|
||||
lora_A, lora_B = get_normalized_lora_weight_names(module)
|
||||
weights_A += lora_A
|
||||
weights_B += lora_B
|
||||
self.lora_weight_names: Tuple[Set[str]] = set(weights_A), set(weights_B)
|
||||
for lora_name, lora_path in lora_paths.items():
|
||||
if lora_name in self.loras:
|
||||
logger.warning(
|
||||
f"LoRA adapter {lora_name} is already loaded."
|
||||
"If you want to reload it, please unload it first."
|
||||
)
|
||||
continue
|
||||
|
||||
# load all weights to cpu
|
||||
self.loras: Dict[str, LoRAAdapter] = {}
|
||||
for name in self.lora_paths.keys():
|
||||
lora_adapter = LoRAAdapter(
|
||||
name,
|
||||
self.configs[name],
|
||||
self.base_hf_config,
|
||||
self.load_config,
|
||||
self.lora_backend,
|
||||
)
|
||||
lora_adapter.initialize_weights()
|
||||
self.loras[name] = lora_adapter
|
||||
self.configs[lora_name] = LoRAConfig(lora_path)
|
||||
|
||||
# misc lora configs
|
||||
self.max_lora_dim: int = max([x.hf_config["r"] for x in self.configs.values()])
|
||||
self.update_state_from_configs()
|
||||
|
||||
if self.lora_backend == "flashinfer":
|
||||
# FIXME remove the restrictions after supporting multi-rank for flashinfer backend
|
||||
max_lora_dim = max([x.hf_config["r"] for x in self.configs.values()])
|
||||
scaling = list(self.loras.values())[0].scaling
|
||||
assert all(x.hf_config["r"] == max_lora_dim for x in self.configs.values())
|
||||
assert all(x.scaling == scaling for x in self.loras.values())
|
||||
def unload_lora_adapters(self, lora_names: Set[str]):
|
||||
"""
|
||||
Unload LoRA adapters by their names. This will remove the adapters from the memory pool and
|
||||
delete the corresponding LoRA modules.
|
||||
|
||||
# Convert original model layers to layers with LoRA
|
||||
self.convert_to_lora_layers()
|
||||
Args:
|
||||
lora_names (Set[str]): A set of LoRA adapter names to unload.
|
||||
"""
|
||||
for lora_name in lora_names:
|
||||
if lora_name in self.loras:
|
||||
del self.configs[lora_name]
|
||||
else:
|
||||
logger.warning(f"LoRA adapter {lora_name} is not loaded.")
|
||||
|
||||
def init_lora_memory_pool(self):
|
||||
# Initialize memory pool
|
||||
self.memory_pool = LoRAMemoryPool(
|
||||
self.base_hf_config,
|
||||
self.max_loras_per_batch,
|
||||
self.max_lora_dim,
|
||||
self.dtype,
|
||||
self.tp_size,
|
||||
self.tp_rank,
|
||||
self.lora_modules,
|
||||
)
|
||||
|
||||
# Initialize target lora modules in memory pool
|
||||
self.memory_pool.init_buffers(self.lora_weight_names, self.base_model)
|
||||
self.update_state_from_configs()
|
||||
|
||||
def prepare_lora_batch(self, forward_batch: ForwardBatch):
|
||||
# load active loras into lora memory pool
|
||||
cur_uids = set(forward_batch.lora_paths)
|
||||
assert len(cur_uids) <= self.max_loras_per_batch
|
||||
self.memory_pool.prepare_lora_batch(cur_uids, self.loras)
|
||||
self.memory_pool.prepare_lora_batch(cur_uids, self.loras, self.lora_modules)
|
||||
|
||||
# set up batch info shared by all lora modules
|
||||
bs = forward_batch.batch_size
|
||||
@@ -267,9 +242,16 @@ class LoRAManager:
|
||||
)
|
||||
self.lora_backend.set_batch_info(batch_info)
|
||||
|
||||
# call set_lora_info for each lora modules
|
||||
for layer_id, modules in self.lora_modules.items():
|
||||
for module_name, module in modules:
|
||||
# TODO (lifuhuang): one potential perf optimization that is worth considering is to see if we can call
|
||||
# this method only when loading/unloading LoRA adapters, instead of calling it for every micro-batch.
|
||||
self.update_lora_info()
|
||||
|
||||
def update_lora_info(self):
|
||||
"""
|
||||
Update all LoRA modules to associate them with the latest memory buffer.
|
||||
"""
|
||||
for layer_id, layer_modules in self.lora_modules.items():
|
||||
for module_name, module in layer_modules.items():
|
||||
if "qkv_proj" in module_name:
|
||||
module.set_lora_info(
|
||||
self.memory_pool.get_tensor(
|
||||
@@ -295,23 +277,139 @@ class LoRAManager:
|
||||
),
|
||||
)
|
||||
|
||||
def init_state(self):
|
||||
"""
|
||||
Initialize the internal (mutable) state of the LoRAManager.
|
||||
|
||||
These states are mutable via the `update_state_from_configs` as LoRA adapters are loaded and unloaded dynamically.
|
||||
"""
|
||||
|
||||
# Configs of all active LoRA adapters.
|
||||
self.configs: Dict[str, LoRAConfig] = {}
|
||||
|
||||
# LoRA adapter weights cached in CPU memory.
|
||||
self.loras: Dict[str, LoRAAdapter] = {}
|
||||
|
||||
# Supported weight names (e.g., qkv_proj) for LoRA A and B respectively.
|
||||
self.lora_weight_names: Tuple[Set[str]] = (set(), set())
|
||||
|
||||
# Look-up table that essentially maps (layer_index, module_name) to the corresponding LoRA module.
|
||||
self.lora_modules: Dict[int, Dict[str, BaseLayerWithLoRA]] = {
|
||||
i: {} for i in range(self.base_hf_config.num_hidden_layers)
|
||||
}
|
||||
|
||||
# Initialize memory pool
|
||||
self.memory_pool = LoRAMemoryPool(
|
||||
self.base_hf_config,
|
||||
self.max_loras_per_batch,
|
||||
self.dtype,
|
||||
self.tp_size,
|
||||
self.tp_rank,
|
||||
)
|
||||
|
||||
def update_state_from_configs(self):
|
||||
"""
|
||||
Update the internal state of the LoRAManager based on the current `self.configs`. This method
|
||||
should be called whenever `self.configs` is modified (e.g., when new LoRA adapters are loaded).
|
||||
|
||||
This includes:
|
||||
- Initializing LoRA adapters if they are not already loaded.
|
||||
- Collect all LoRA weight names based on the current loaded adapters.
|
||||
- Lazily monkey-patching the base model to use LoRA layers where applicable.
|
||||
- Preparing the GPU buffer pool for active LoRA weights.
|
||||
"""
|
||||
|
||||
# Target module names in huggingface lora configs.
|
||||
# e.g., {"k_proj", "q_proj", "v_proj", "o_proj"}
|
||||
hf_target_module_names: Set[str] = set()
|
||||
for config in self.configs.values():
|
||||
hf_target_module_names.update(config.target_modules)
|
||||
max_lora_dim: int = max([x.hf_config["r"] for x in self.configs.values()])
|
||||
|
||||
# Loads / unloads LoRA adapters based on the latest configs.
|
||||
self.update_lora_adapters()
|
||||
|
||||
# Lazily update states for new LoRA weight name (e.g., qkv_proj) as needed.
|
||||
#
|
||||
# Please note that the following update operations are "monotonic" by design, meaning that we update
|
||||
# multiple places to support the new weight names when the first adapter targeting such weight names
|
||||
# is loaded. However, we never "rollback" the support (e.g., convert LoRA layer back to base layer)
|
||||
# even if the associated adapters are unloaded later for both simplicity and practicality reasons: the
|
||||
# list of LoRA weight names is expected to be extremely finite and stable.
|
||||
self.update_lora_weight_names(hf_target_module_names)
|
||||
self.update_lora_modules(hf_target_module_names)
|
||||
self.update_memory_buffers(max_lora_dim)
|
||||
|
||||
def update_lora_weight_names(self, hf_target_names: Set[str]):
|
||||
"""
|
||||
Add new LoRA weight names if needed based on the current `self.configs`.
|
||||
"""
|
||||
|
||||
# Target lora weight names for lora_a and lora_b modules respectively.
|
||||
for module in hf_target_names:
|
||||
lora_A, lora_B = get_normalized_lora_weight_names(module)
|
||||
self.lora_weight_names[0].update(lora_A)
|
||||
self.lora_weight_names[1].update(lora_B)
|
||||
|
||||
def update_lora_adapters(self):
|
||||
"""
|
||||
Update the LoRA adapters in CPU memory based on the current `self.configs`.
|
||||
It loads any new adapters that are not already loaded, and unloads any adapters
|
||||
that are no longer in `self.configs` (e.g., unloaded).
|
||||
"""
|
||||
|
||||
# Load new adapter weights to cpu
|
||||
for name, config in self.configs.items():
|
||||
if name not in self.loras:
|
||||
logger.info(f"Loading weight of LoRA adapter {name} from {config.path}")
|
||||
lora_adapter = LoRAAdapter(
|
||||
name,
|
||||
config,
|
||||
self.base_hf_config,
|
||||
self.load_config,
|
||||
self.lora_backend,
|
||||
)
|
||||
lora_adapter.initialize_weights()
|
||||
self.loras[name] = lora_adapter
|
||||
|
||||
# Clean up unused LoRA adapters
|
||||
for name in self.loras:
|
||||
if name not in self.configs:
|
||||
logger.info(f"Unloading LoRA adapter {name}")
|
||||
del self.loras[name]
|
||||
|
||||
# Additional checks for flashinfer backend
|
||||
# FIXME remove the restrictions after supporting multi-rank for flashinfer backend
|
||||
if self.lora_backend == "flashinfer":
|
||||
lora_dims = set(x.hf_config["r"] for x in self.configs.values())
|
||||
scalings = set(x.scaling for x in self.loras.values())
|
||||
assert (
|
||||
len(lora_dims) == 1 and len(scalings) == 1
|
||||
), "Flashinfer backend currently only supports single LoRA rank and scaling across all adapters. "
|
||||
|
||||
def update_memory_buffers(self, max_lora_dim: int):
|
||||
"""
|
||||
Update the LoRA memory pool buffers based on the current LoRA configurations and update
|
||||
LoRA modules to use the new buffers. This method should be called after the LoRA configurations
|
||||
are set or updated.
|
||||
"""
|
||||
|
||||
self.memory_pool.init_buffers(
|
||||
self.lora_weight_names, self.base_model, max_lora_dim
|
||||
)
|
||||
|
||||
def set_lora_module(self, module_name, module):
|
||||
lora_module = get_lora_layer(module, self.lora_backend)
|
||||
replace_submodule(self.base_model, module_name, lora_module)
|
||||
return lora_module
|
||||
|
||||
def convert_to_lora_layers(self):
|
||||
def update_lora_modules(self, hf_target_names: Set[str]):
|
||||
# Target module names of customized layers defined in python/sglang/srt/layers
|
||||
# e.g., {"qkv_proj", "o_proj"}
|
||||
customized_target_names = get_customized_names_from_hf_names(
|
||||
self.hf_target_names, self.base_model
|
||||
hf_target_names, self.base_model
|
||||
)
|
||||
|
||||
# Monkey patch to use the LoRA version layers
|
||||
self.lora_modules: Dict[int, List[Tuple[str, BaseLayerWithLoRA]]] = {
|
||||
i: [] for i in range(self.base_hf_config.num_hidden_layers)
|
||||
}
|
||||
|
||||
for module_name, module in self.base_model.named_modules():
|
||||
# TODO (lifuhuang): in the future, we should consider generalizing the
|
||||
# should_apply_lora function to support mapping by full module name instead
|
||||
@@ -326,6 +424,7 @@ class LoRAManager:
|
||||
# The module should be converted if it is included in target_names
|
||||
if module_name.split(".")[-1] in customized_target_names:
|
||||
layer_id = get_layer_id(module_name)
|
||||
self.lora_modules[layer_id].append(
|
||||
(module_name, self.set_lora_module(module_name, module))
|
||||
)
|
||||
if module_name not in self.lora_modules[layer_id]:
|
||||
self.lora_modules[layer_id][module_name] = self.set_lora_module(
|
||||
module_name, module
|
||||
)
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
from typing import Dict, List, Optional, Set, Tuple
|
||||
from typing import Callable, Dict, List, Optional, Set, Tuple
|
||||
|
||||
import torch
|
||||
|
||||
@@ -22,21 +22,16 @@ class LoRAMemoryPool:
|
||||
self,
|
||||
base_hf_config: AutoConfig,
|
||||
max_loras_per_batch: int,
|
||||
max_lora_dim: int,
|
||||
dtype: torch.dtype,
|
||||
tp_size: int,
|
||||
tp_rank: int,
|
||||
lora_modules: Dict[int, List[Tuple[str, BaseLayerWithLoRA]]],
|
||||
):
|
||||
|
||||
self.base_hf_config: AutoConfig = base_hf_config
|
||||
self.num_layer: int = base_hf_config.num_hidden_layers
|
||||
self.max_loras_per_batch: int = max_loras_per_batch
|
||||
self.max_lora_dim: int = max_lora_dim
|
||||
self.dtype: torch.dtype = dtype
|
||||
self.tp_size: int = tp_size
|
||||
self.tp_rank: int = tp_rank
|
||||
self.lora_modules: Dict[int, List[Tuple[str, BaseLayerWithLoRA]]] = lora_modules
|
||||
|
||||
# Both A_buffer and B_buffer maps lora weight names to its buffer space.
|
||||
# A_buffer contains num_layer number of row-major tensors with shape
|
||||
@@ -55,79 +50,84 @@ class LoRAMemoryPool:
|
||||
self.buffer_id_to_uid: List[Optional[str]] = [""] * self.max_loras_per_batch
|
||||
|
||||
def get_lora_A_shape(
|
||||
self, module_name: str, base_model: torch.nn.Module
|
||||
self, module_name: str, base_model: torch.nn.Module, max_lora_dim: int
|
||||
) -> Tuple[int]:
|
||||
"""
|
||||
Given a module_name (might be a stacked name), return the hidden dims of modules' input and output.
|
||||
"""
|
||||
input_dim, _ = get_hidden_dim(module_name, self.base_hf_config, base_model)
|
||||
c = get_stacked_multiply(module_name)
|
||||
if self.tp_size > 1:
|
||||
if module_name in ROW_PARALLELISM_LINEAR_LORA_NAMES:
|
||||
input_dim = divide(input_dim, self.tp_size)
|
||||
if self.tp_size > 1 and module_name in ROW_PARALLELISM_LINEAR_LORA_NAMES:
|
||||
input_dim = divide(input_dim, self.tp_size)
|
||||
return (
|
||||
self.max_loras_per_batch,
|
||||
self.max_lora_dim * c,
|
||||
max_lora_dim * c,
|
||||
input_dim,
|
||||
)
|
||||
|
||||
def get_lora_B_shape(
|
||||
self, module_name: str, base_model: torch.nn.Module
|
||||
self, module_name: str, base_model: torch.nn.Module, max_lora_dim: int
|
||||
) -> Tuple[int]:
|
||||
"""
|
||||
Given a module_name (might be a stacked name), return the hidden dims of modules' input and output.
|
||||
"""
|
||||
_, output_dim = get_hidden_dim(module_name, self.base_hf_config, base_model)
|
||||
c = get_stacked_multiply(module_name)
|
||||
if self.tp_size > 1:
|
||||
if module_name not in ROW_PARALLELISM_LINEAR_LORA_NAMES:
|
||||
output_dim = divide(output_dim, self.tp_size)
|
||||
if self.tp_size > 1 and module_name not in ROW_PARALLELISM_LINEAR_LORA_NAMES:
|
||||
output_dim = divide(output_dim, self.tp_size)
|
||||
return (
|
||||
c,
|
||||
self.max_loras_per_batch,
|
||||
output_dim,
|
||||
self.max_lora_dim,
|
||||
max_lora_dim,
|
||||
)
|
||||
|
||||
def init_buffers(
|
||||
self,
|
||||
lora_weight_names: Tuple[Set[str]],
|
||||
base_model: torch.nn.Module,
|
||||
max_lora_dim: int,
|
||||
):
|
||||
|
||||
# lora_weight_names is a set of name pairs indicating each pair of lora modules to load
|
||||
# e.g., {("qkv_proj", "q_proj"), ("qkv_proj", "kv_proj"), ("o_proj", "o_proj")}
|
||||
self.lora_weight_names: Tuple[Set[str]] = lora_weight_names
|
||||
device = next(base_model.parameters()).device
|
||||
# Init A tensor, column_major=False
|
||||
for module_A in lora_weight_names[0]:
|
||||
lora_A_shape = self.get_lora_A_shape(module_A, base_model)
|
||||
self.A_buffer[module_A] = [
|
||||
torch.empty(
|
||||
lora_A_shape,
|
||||
dtype=self.dtype,
|
||||
device=device,
|
||||
)
|
||||
for _ in range(self.num_layer)
|
||||
]
|
||||
# Init B tensor, column_major=True
|
||||
for module_B in lora_weight_names[1]:
|
||||
lora_B_shape = self.get_lora_B_shape(module_B, base_model)
|
||||
self.B_buffer[module_B] = [
|
||||
torch.empty(
|
||||
lora_B_shape,
|
||||
dtype=self.dtype,
|
||||
device=device,
|
||||
)
|
||||
for _ in range(self.num_layer)
|
||||
]
|
||||
|
||||
def update_buffer(
|
||||
buffer: Dict[str, List[torch.Tensor]],
|
||||
lora_weight_names: Set[str],
|
||||
get_lora_shape_fn: Callable[[str, torch.nn.Module, int], Tuple[int]],
|
||||
):
|
||||
new_weight_names = lora_weight_names - buffer.keys()
|
||||
for module_name in new_weight_names:
|
||||
lora_shape = get_lora_shape_fn(module_name, base_model, max_lora_dim)
|
||||
buffer[module_name] = [
|
||||
torch.empty(
|
||||
lora_shape,
|
||||
dtype=self.dtype,
|
||||
device=device,
|
||||
)
|
||||
for _ in range(self.num_layer)
|
||||
]
|
||||
|
||||
update_buffer(
|
||||
self.A_buffer,
|
||||
lora_weight_names[0],
|
||||
self.get_lora_A_shape,
|
||||
)
|
||||
|
||||
update_buffer(
|
||||
self.B_buffer,
|
||||
lora_weight_names[1],
|
||||
self.get_lora_B_shape,
|
||||
)
|
||||
|
||||
def prepare_lora_batch(
|
||||
self,
|
||||
cur_uids: Set[Optional[str]],
|
||||
lora_adapters: Dict[str, LoRAAdapter],
|
||||
lora_modules: Dict[int, Dict[str, BaseLayerWithLoRA]],
|
||||
):
|
||||
|
||||
def get_available_buffer_slot():
|
||||
for buffer_id in range(self.max_loras_per_batch):
|
||||
# Prioritize empty slots
|
||||
@@ -147,14 +147,19 @@ class LoRAMemoryPool:
|
||||
for uid in cur_uids:
|
||||
if uid not in self.uid_to_buffer_id:
|
||||
buffer_id = get_available_buffer_slot()
|
||||
lora_adapter = lora_adapters.get(uid, None)
|
||||
self.load_lora_weight_to_buffer(
|
||||
uid, buffer_id, lora_adapters.get(uid, None)
|
||||
uid, buffer_id, lora_adapter, lora_modules
|
||||
)
|
||||
self.uid_to_buffer_id[uid] = buffer_id
|
||||
self.buffer_id_to_uid[buffer_id] = uid
|
||||
|
||||
def load_lora_weight_to_buffer(
|
||||
self, uid: str, buffer_id: int, lora_adapter: LoRAAdapter = None
|
||||
self,
|
||||
uid: str,
|
||||
buffer_id: int,
|
||||
lora_adapter: LoRAAdapter,
|
||||
lora_modules: Dict[int, Dict[str, BaseLayerWithLoRA]],
|
||||
):
|
||||
def check_lora_weight_shape(buffer_view: torch.Tensor, weight: torch.Tensor):
|
||||
assert (
|
||||
@@ -186,8 +191,8 @@ class LoRAMemoryPool:
|
||||
temp_B_buffer[lora_weight_name] = weights
|
||||
|
||||
if self.tp_size > 1:
|
||||
cur_layer_modules = self.lora_modules[layer_id]
|
||||
for module_name, module in cur_layer_modules:
|
||||
cur_layer_modules = lora_modules[layer_id]
|
||||
for module_name, module in cur_layer_modules.items():
|
||||
if "qkv_proj" in module_name:
|
||||
temp_A_buffer["qkv_proj"] = module.slice_lora_a_weights(
|
||||
temp_A_buffer["qkv_proj"], self.tp_rank
|
||||
@@ -236,7 +241,6 @@ class LoRAMemoryPool:
|
||||
def get_tensor(
|
||||
self, weight_name: str, layer_id: int, lora_type: LoRAType
|
||||
) -> torch.Tensor:
|
||||
|
||||
if lora_type == LoRAType.LORA_A:
|
||||
return self.A_buffer[weight_name][layer_id]
|
||||
|
||||
|
||||
@@ -108,7 +108,7 @@ def get_hidden_dim(
|
||||
|
||||
def get_normalized_lora_weight_names(name: str) -> Tuple[List[str], List[str]]:
|
||||
"""
|
||||
Mapping a target module name to names of the normized LoRA weights.
|
||||
Mapping a target module name to names of the normalized LoRA weights.
|
||||
Returned tuple contains (name for Lora A, name for Lora B)
|
||||
"""
|
||||
params_mapping = {
|
||||
|
||||
@@ -278,6 +278,10 @@ class ModelRunner:
|
||||
self.apply_torch_tp()
|
||||
|
||||
# Init lora
|
||||
# TODO (lifuhuang): when we support dynamic LoRA loading / unloading, we should add
|
||||
# a new server arg `enable_lora` to control whether to init LoRA manager to be more
|
||||
# explicit, as it is perfectly valid to start a server with an empty lora_paths and
|
||||
# load LoRA adapters dynamically later.
|
||||
if server_args.lora_paths is not None:
|
||||
self.init_lora_manager()
|
||||
|
||||
@@ -796,7 +800,6 @@ class ModelRunner:
|
||||
def init_lora_manager(self):
|
||||
self.lora_manager = LoRAManager(
|
||||
base_model=self.model,
|
||||
lora_paths=self.server_args.lora_paths,
|
||||
base_hf_config=self.model_config.hf_config,
|
||||
max_loras_per_batch=self.server_args.max_loras_per_batch,
|
||||
load_config=self.load_config,
|
||||
@@ -805,6 +808,7 @@ class ModelRunner:
|
||||
tp_size=self.tp_size,
|
||||
tp_rank=self.tp_rank,
|
||||
)
|
||||
self.lora_manager.load_lora_adapters(self.server_args.lora_paths)
|
||||
logger.info("LoRA manager ready.")
|
||||
|
||||
def profile_max_num_token(self, total_gpu_memory: int):
|
||||
|
||||
Reference in New Issue
Block a user