Refactor LoRAManager and LoRAMemoryPool state management logic for dynamic LoRA loading support (#7412)
@@ -16,7 +16,7 @@
 # and "Punica: Multi-Tenant LoRA Serving"

 import logging
-from typing import Dict, List, Set, Tuple
+from typing import Dict, Set, Tuple

 import torch

@@ -45,7 +45,6 @@ class LoRAManager:
     def __init__(
         self,
         base_model: torch.nn.Module,
-        lora_paths: Dict[str, str],
         base_hf_config: AutoConfig,
         max_loras_per_batch: int,
         load_config: LoadConfig,
@@ -55,7 +54,6 @@ class LoRAManager:
         tp_rank: int = 0,
     ):
         self.base_model: torch.nn.Module = base_model
-        self.lora_paths: Dict[str, str] = lora_paths
         self.base_hf_config: AutoConfig = base_hf_config
         self.max_loras_per_batch: int = max_loras_per_batch
         self.load_config: LoadConfig = load_config
@@ -69,8 +67,8 @@ class LoRAManager:
         backend_type = get_backend_from_name(lora_backend)
         self.lora_backend: BaseLoRABackend = backend_type(lora_backend)

-        self.init_loras()
-        self.init_lora_memory_pool()
+        # Initialize mutable internal state of the LoRAManager.
+        self.init_state()

     def init_cuda_graph_batch_info(self, max_bs_in_cuda_graph: int):
         self.max_bs_in_cuda_graph = max_bs_in_cuda_graph
@@ -100,72 +98,49 @@ class LoRAManager:
             ],
         )

-    def init_loras(self):
-        # Config of each LoRA adapter
-        self.configs: Dict[str, LoRAConfig] = {}
-
-        # Target module names in huggingface lora configs.
-        # e.g., {"k_proj", "q_proj", "v_proj", "o_proj"}
-        self.hf_target_names: Set[str] = set()
-        for name, path in self.lora_paths.items():
-            self.configs[name] = LoRAConfig(path)
-            self.hf_target_names.update(self.configs[name].target_modules)
-
-        # Target lora weight names for lora_a and lora_b modules respectively.
-        weights_A: List[str] = []
-        weights_B: List[str] = []
-        for module in self.hf_target_names:
-            lora_A, lora_B = get_normalized_lora_weight_names(module)
-            weights_A += lora_A
-            weights_B += lora_B
-        self.lora_weight_names: Tuple[Set[str]] = set(weights_A), set(weights_B)
-
-        # load all weights to cpu
-        self.loras: Dict[str, LoRAAdapter] = {}
-        for name in self.lora_paths.keys():
-            lora_adapter = LoRAAdapter(
-                name,
-                self.configs[name],
-                self.base_hf_config,
-                self.load_config,
-                self.lora_backend,
-            )
-            lora_adapter.initialize_weights()
-            self.loras[name] = lora_adapter
-
-        # misc lora configs
-        self.max_lora_dim: int = max([x.hf_config["r"] for x in self.configs.values()])
-
-        if self.lora_backend == "flashinfer":
-            # FIXME remove the restrictions after supporting multi-rank for flashinfer backend
-            max_lora_dim = max([x.hf_config["r"] for x in self.configs.values()])
-            scaling = list(self.loras.values())[0].scaling
-            assert all(x.hf_config["r"] == max_lora_dim for x in self.configs.values())
-            assert all(x.scaling == scaling for x in self.loras.values())
-
-        # Convert original model layers to layers with LoRA
-        self.convert_to_lora_layers()
-
-    def init_lora_memory_pool(self):
-        # Initialize memory pool
-        self.memory_pool = LoRAMemoryPool(
-            self.base_hf_config,
-            self.max_loras_per_batch,
-            self.max_lora_dim,
-            self.dtype,
-            self.tp_size,
-            self.tp_rank,
-            self.lora_modules,
-        )
-
-        # Initialize target lora modules in memory pool
-        self.memory_pool.init_buffers(self.lora_weight_names, self.base_model)
+    def load_lora_adapters(self, lora_paths: Dict[str, str]):
+        """
+        Load LoRA adapters from the specified paths.
+        TODO (lifuhuang): This method should be exposed to the server/engine API to support dynamic LoRA loading.
+
+        Args:
+            lora_paths (Dict[str, str]): A dictionary mapping LoRA adapter names to their file paths.
+                If a LoRA adapter is already loaded, it will be skipped with a warning.
+        """
+
+        for lora_name, lora_path in lora_paths.items():
+            if lora_name in self.loras:
+                logger.warning(
+                    f"LoRA adapter {lora_name} is already loaded."
+                    "If you want to reload it, please unload it first."
+                )
+                continue
+
+            self.configs[lora_name] = LoRAConfig(lora_path)
+
+        self.update_state_from_configs()
+
+    def unload_lora_adapters(self, lora_names: Set[str]):
+        """
+        Unload LoRA adapters by their names. This will remove the adapters from the memory pool and
+        delete the corresponding LoRA modules.
+
+        Args:
+            lora_names (Set[str]): A set of LoRA adapter names to unload.
+        """
+        for lora_name in lora_names:
+            if lora_name in self.loras:
+                del self.configs[lora_name]
+            else:
+                logger.warning(f"LoRA adapter {lora_name} is not loaded.")
+
+        self.update_state_from_configs()

     def prepare_lora_batch(self, forward_batch: ForwardBatch):
         # load active loras into lora memory pool
         cur_uids = set(forward_batch.lora_paths)
         assert len(cur_uids) <= self.max_loras_per_batch
-        self.memory_pool.prepare_lora_batch(cur_uids, self.loras)
+        self.memory_pool.prepare_lora_batch(cur_uids, self.loras, self.lora_modules)

         # set up batch info shared by all lora modules
         bs = forward_batch.batch_size
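The two methods added in the hunk above form the new load/unload surface of LoRAManager. Below is a minimal usage sketch, not part of this commit: it assumes an already-constructed LoRAManager instance, and the adapter names/paths passed in are placeholders.

from typing import Dict, Set


def demo_dynamic_lora(lora_manager, adapters: Dict[str, str]) -> None:
    """Hypothetical driver for the new API; `lora_manager` is an existing LoRAManager."""
    # Registers the adapter configs; update_state_from_configs() runs internally,
    # loading weights to CPU, patching target modules, and growing GPU buffers.
    lora_manager.load_lora_adapters(adapters)

    # Drop the same adapters again by name; unknown names only log a warning.
    names: Set[str] = set(adapters)
    lora_manager.unload_lora_adapters(names)


# Example call (placeholder name and path):
# demo_dynamic_lora(lora_manager, {"my-adapter": "/path/to/my-adapter"})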
@@ -267,9 +242,16 @@ class LoRAManager:
         )
         self.lora_backend.set_batch_info(batch_info)

-        # call set_lora_info for each lora modules
-        for layer_id, modules in self.lora_modules.items():
-            for module_name, module in modules:
+        # TODO (lifuhuang): one potential perf optimization that is worth considering is to see if we can call
+        # this method only when loading/unloading LoRA adapters, instead of calling it for every micro-batch.
+        self.update_lora_info()
+
+    def update_lora_info(self):
+        """
+        Update all LoRA modules to associate them with the latest memory buffer.
+        """
+        for layer_id, layer_modules in self.lora_modules.items():
+            for module_name, module in layer_modules.items():
                 if "qkv_proj" in module_name:
                     module.set_lora_info(
                         self.memory_pool.get_tensor(
@@ -295,23 +277,139 @@ class LoRAManager:
                         ),
                     )

+    def init_state(self):
+        """
+        Initialize the internal (mutable) state of the LoRAManager.
+
+        These states are mutable via the `update_state_from_configs` as LoRA adapters are loaded and unloaded dynamically.
+        """
+
+        # Configs of all active LoRA adapters.
+        self.configs: Dict[str, LoRAConfig] = {}
+
+        # LoRA adapter weights cached in CPU memory.
+        self.loras: Dict[str, LoRAAdapter] = {}
+
+        # Supported weight names (e.g., qkv_proj) for LoRA A and B respectively.
+        self.lora_weight_names: Tuple[Set[str]] = (set(), set())
+
+        # Look-up table that essentially maps (layer_index, module_name) to the corresponding LoRA module.
+        self.lora_modules: Dict[int, Dict[str, BaseLayerWithLoRA]] = {
+            i: {} for i in range(self.base_hf_config.num_hidden_layers)
+        }
+
+        # Initialize memory pool
+        self.memory_pool = LoRAMemoryPool(
+            self.base_hf_config,
+            self.max_loras_per_batch,
+            self.dtype,
+            self.tp_size,
+            self.tp_rank,
+        )
+
+    def update_state_from_configs(self):
+        """
+        Update the internal state of the LoRAManager based on the current `self.configs`. This method
+        should be called whenever `self.configs` is modified (e.g., when new LoRA adapters are loaded).
+
+        This includes:
+        - Initializing LoRA adapters if they are not already loaded.
+        - Collect all LoRA weight names based on the current loaded adapters.
+        - Lazily monkey-patching the base model to use LoRA layers where applicable.
+        - Preparing the GPU buffer pool for active LoRA weights.
+        """
+
+        # Target module names in huggingface lora configs.
+        # e.g., {"k_proj", "q_proj", "v_proj", "o_proj"}
+        hf_target_module_names: Set[str] = set()
+        for config in self.configs.values():
+            hf_target_module_names.update(config.target_modules)
+        max_lora_dim: int = max([x.hf_config["r"] for x in self.configs.values()])
+
+        # Loads / unloads LoRA adapters based on the latest configs.
+        self.update_lora_adapters()
+
+        # Lazily update states for new LoRA weight name (e.g., qkv_proj) as needed.
+        #
+        # Please note that the following update operations are "monotonic" by design, meaning that we update
+        # multiple places to support the new weight names when the first adapter targeting such weight names
+        # is loaded. However, we never "rollback" the support (e.g., convert LoRA layer back to base layer)
+        # even if the associated adapters are unloaded later for both simplicity and practicality reasons: the
+        # list of LoRA weight names is expected to be extremely finite and stable.
+        self.update_lora_weight_names(hf_target_module_names)
+        self.update_lora_modules(hf_target_module_names)
+        self.update_memory_buffers(max_lora_dim)
+
+    def update_lora_weight_names(self, hf_target_names: Set[str]):
+        """
+        Add new LoRA weight names if needed based on the current `self.configs`.
+        """
+
+        # Target lora weight names for lora_a and lora_b modules respectively.
+        for module in hf_target_names:
+            lora_A, lora_B = get_normalized_lora_weight_names(module)
+            self.lora_weight_names[0].update(lora_A)
+            self.lora_weight_names[1].update(lora_B)
+
+    def update_lora_adapters(self):
+        """
+        Update the LoRA adapters in CPU memory based on the current `self.configs`.
+        It loads any new adapters that are not already loaded, and unloads any adapters
+        that are no longer in `self.configs` (e.g., unloaded).
+        """
+
+        # Load new adapter weights to cpu
+        for name, config in self.configs.items():
+            if name not in self.loras:
+                logger.info(f"Loading weight of LoRA adapter {name} from {config.path}")
+                lora_adapter = LoRAAdapter(
+                    name,
+                    config,
+                    self.base_hf_config,
+                    self.load_config,
+                    self.lora_backend,
+                )
+                lora_adapter.initialize_weights()
+                self.loras[name] = lora_adapter
+
+        # Clean up unused LoRA adapters
+        for name in self.loras:
+            if name not in self.configs:
+                logger.info(f"Unloading LoRA adapter {name}")
+                del self.loras[name]
+
+        # Additional checks for flashinfer backend
+        # FIXME remove the restrictions after supporting multi-rank for flashinfer backend
+        if self.lora_backend == "flashinfer":
+            lora_dims = set(x.hf_config["r"] for x in self.configs.values())
+            scalings = set(x.scaling for x in self.loras.values())
+            assert (
+                len(lora_dims) == 1 and len(scalings) == 1
+            ), "Flashinfer backend currently only supports single LoRA rank and scaling across all adapters. "
+
+    def update_memory_buffers(self, max_lora_dim: int):
+        """
+        Update the LoRA memory pool buffers based on the current LoRA configurations and update
+        LoRA modules to use the new buffers. This method should be called after the LoRA configurations
+        are set or updated.
+        """
+
+        self.memory_pool.init_buffers(
+            self.lora_weight_names, self.base_model, max_lora_dim
+        )
+
     def set_lora_module(self, module_name, module):
         lora_module = get_lora_layer(module, self.lora_backend)
         replace_submodule(self.base_model, module_name, lora_module)
         return lora_module

-    def convert_to_lora_layers(self):
+    def update_lora_modules(self, hf_target_names: Set[str]):
         # Target module names of customized layers defined in python/sglang/srt/layers
         # e.g., {"qkv_proj", "o_proj"}
         customized_target_names = get_customized_names_from_hf_names(
-            self.hf_target_names, self.base_model
+            hf_target_names, self.base_model
         )

-        # Monkey patch to use the LoRA version layers
-        self.lora_modules: Dict[int, List[Tuple[str, BaseLayerWithLoRA]]] = {
-            i: [] for i in range(self.base_hf_config.num_hidden_layers)
-        }
-
         for module_name, module in self.base_model.named_modules():
             # TODO (lifuhuang): in the future, we should consider generalizing the
             # should_apply_lora function to support mapping by full module name instead
@@ -326,6 +424,7 @@ class LoRAManager:
             # The module should be converted if it is included in target_names
             if module_name.split(".")[-1] in customized_target_names:
                 layer_id = get_layer_id(module_name)
-                self.lora_modules[layer_id].append(
-                    (module_name, self.set_lora_module(module_name, module))
-                )
+                if module_name not in self.lora_modules[layer_id]:
+                    self.lora_modules[layer_id][module_name] = self.set_lora_module(
+                        module_name, module
+                    )
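Because lora_modules is now keyed by module name and guarded as shown in the hunk above, re-running update_lora_modules after every adapter load never wraps the same base layer twice. A standalone sketch of that idempotency pattern follows; it is illustrative only and not taken from the diff (the module name and the string stand-in for the wrapped layer are made up).

from typing import Dict

lora_modules: Dict[int, Dict[str, str]] = {0: {}}


def register(layer_id: int, module_name: str) -> None:
    # Stand-in for set_lora_module: only wrap a module the first time it is seen.
    if module_name not in lora_modules[layer_id]:
        lora_modules[layer_id][module_name] = f"LoRA({module_name})"


register(0, "model.layers.0.self_attn.qkv_proj")
register(0, "model.layers.0.self_attn.qkv_proj")  # repeated call is a no-op
assert len(lora_modules[0]) == 1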
@@ -1,4 +1,4 @@
-from typing import Dict, List, Optional, Set, Tuple
+from typing import Callable, Dict, List, Optional, Set, Tuple

 import torch

@@ -22,21 +22,16 @@ class LoRAMemoryPool:
         self,
         base_hf_config: AutoConfig,
         max_loras_per_batch: int,
-        max_lora_dim: int,
         dtype: torch.dtype,
         tp_size: int,
         tp_rank: int,
-        lora_modules: Dict[int, List[Tuple[str, BaseLayerWithLoRA]]],
     ):
         self.base_hf_config: AutoConfig = base_hf_config
         self.num_layer: int = base_hf_config.num_hidden_layers
         self.max_loras_per_batch: int = max_loras_per_batch
-        self.max_lora_dim: int = max_lora_dim
         self.dtype: torch.dtype = dtype
         self.tp_size: int = tp_size
         self.tp_rank: int = tp_rank
-        self.lora_modules: Dict[int, List[Tuple[str, BaseLayerWithLoRA]]] = lora_modules

         # Both A_buffer and B_buffer maps lora weight names to its buffer space.
         # A_buffer contains num_layer number of row-major tensors with shape
@@ -55,79 +50,84 @@ class LoRAMemoryPool:
         self.buffer_id_to_uid: List[Optional[str]] = [""] * self.max_loras_per_batch

     def get_lora_A_shape(
-        self, module_name: str, base_model: torch.nn.Module
+        self, module_name: str, base_model: torch.nn.Module, max_lora_dim: int
     ) -> Tuple[int]:
         """
         Given a module_name (might be a stacked name), return the hidden dims of modules' input and output.
         """
         input_dim, _ = get_hidden_dim(module_name, self.base_hf_config, base_model)
         c = get_stacked_multiply(module_name)
-        if self.tp_size > 1:
-            if module_name in ROW_PARALLELISM_LINEAR_LORA_NAMES:
-                input_dim = divide(input_dim, self.tp_size)
+        if self.tp_size > 1 and module_name in ROW_PARALLELISM_LINEAR_LORA_NAMES:
+            input_dim = divide(input_dim, self.tp_size)
         return (
             self.max_loras_per_batch,
-            self.max_lora_dim * c,
+            max_lora_dim * c,
             input_dim,
         )

     def get_lora_B_shape(
-        self, module_name: str, base_model: torch.nn.Module
+        self, module_name: str, base_model: torch.nn.Module, max_lora_dim: int
     ) -> Tuple[int]:
         """
         Given a module_name (might be a stacked name), return the hidden dims of modules' input and output.
         """
         _, output_dim = get_hidden_dim(module_name, self.base_hf_config, base_model)
         c = get_stacked_multiply(module_name)
-        if self.tp_size > 1:
-            if module_name not in ROW_PARALLELISM_LINEAR_LORA_NAMES:
-                output_dim = divide(output_dim, self.tp_size)
+        if self.tp_size > 1 and module_name not in ROW_PARALLELISM_LINEAR_LORA_NAMES:
+            output_dim = divide(output_dim, self.tp_size)
         return (
             c,
             self.max_loras_per_batch,
             output_dim,
-            self.max_lora_dim,
+            max_lora_dim,
         )

     def init_buffers(
         self,
         lora_weight_names: Tuple[Set[str]],
         base_model: torch.nn.Module,
+        max_lora_dim: int,
     ):

         # lora_weight_names is a set of name pairs indicating each pair of lora modules to load
         # e.g., {("qkv_proj", "q_proj"), ("qkv_proj", "kv_proj"), ("o_proj", "o_proj")}
         self.lora_weight_names: Tuple[Set[str]] = lora_weight_names
         device = next(base_model.parameters()).device
-        # Init A tensor, column_major=False
-        for module_A in lora_weight_names[0]:
-            lora_A_shape = self.get_lora_A_shape(module_A, base_model)
-            self.A_buffer[module_A] = [
-                torch.empty(
-                    lora_A_shape,
-                    dtype=self.dtype,
-                    device=device,
-                )
-                for _ in range(self.num_layer)
-            ]
-        # Init B tensor, column_major=True
-        for module_B in lora_weight_names[1]:
-            lora_B_shape = self.get_lora_B_shape(module_B, base_model)
-            self.B_buffer[module_B] = [
-                torch.empty(
-                    lora_B_shape,
-                    dtype=self.dtype,
-                    device=device,
-                )
-                for _ in range(self.num_layer)
-            ]
+
+        def update_buffer(
+            buffer: Dict[str, List[torch.Tensor]],
+            lora_weight_names: Set[str],
+            get_lora_shape_fn: Callable[[str, torch.nn.Module, int], Tuple[int]],
+        ):
+            new_weight_names = lora_weight_names - buffer.keys()
+            for module_name in new_weight_names:
+                lora_shape = get_lora_shape_fn(module_name, base_model, max_lora_dim)
+                buffer[module_name] = [
+                    torch.empty(
+                        lora_shape,
+                        dtype=self.dtype,
+                        device=device,
+                    )
+                    for _ in range(self.num_layer)
+                ]
+
+        update_buffer(
+            self.A_buffer,
+            lora_weight_names[0],
+            self.get_lora_A_shape,
+        )
+
+        update_buffer(
+            self.B_buffer,
+            lora_weight_names[1],
+            self.get_lora_B_shape,
+        )

     def prepare_lora_batch(
         self,
         cur_uids: Set[Optional[str]],
         lora_adapters: Dict[str, LoRAAdapter],
+        lora_modules: Dict[int, Dict[str, BaseLayerWithLoRA]],
     ):

         def get_available_buffer_slot():
             for buffer_id in range(self.max_loras_per_batch):
                 # Prioritize empty slots
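The new update_buffer helper introduced above only allocates tensors for weight names that do not yet have a buffer, which is what lets init_buffers be called again after every dynamic adapter load without reallocating existing buffers. Below is a standalone sketch of that pattern; it is illustrative only, and the names, layer count, and shape are made up rather than taken from the pool's real shape logic.

from typing import Dict, List, Set, Tuple

import torch


def update_buffer_sketch(
    buffer: Dict[str, List[torch.Tensor]],
    weight_names: Set[str],
    num_layer: int,
    shape: Tuple[int, ...],
) -> None:
    # Allocate per-layer tensors only for names that are not already buffered.
    new_weight_names = weight_names - buffer.keys()
    for name in new_weight_names:
        buffer[name] = [torch.empty(shape) for _ in range(num_layer)]


a_buffer: Dict[str, List[torch.Tensor]] = {}
update_buffer_sketch(a_buffer, {"qkv_proj", "o_proj"}, num_layer=2, shape=(4, 16, 32))
update_buffer_sketch(a_buffer, {"qkv_proj"}, num_layer=2, shape=(4, 16, 32))  # no-op
assert set(a_buffer) == {"qkv_proj", "o_proj"}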
@@ -147,14 +147,19 @@ class LoRAMemoryPool:
         for uid in cur_uids:
             if uid not in self.uid_to_buffer_id:
                 buffer_id = get_available_buffer_slot()
+                lora_adapter = lora_adapters.get(uid, None)
                 self.load_lora_weight_to_buffer(
-                    uid, buffer_id, lora_adapters.get(uid, None)
+                    uid, buffer_id, lora_adapter, lora_modules
                 )
                 self.uid_to_buffer_id[uid] = buffer_id
                 self.buffer_id_to_uid[buffer_id] = uid

     def load_lora_weight_to_buffer(
-        self, uid: str, buffer_id: int, lora_adapter: LoRAAdapter = None
+        self,
+        uid: str,
+        buffer_id: int,
+        lora_adapter: LoRAAdapter,
+        lora_modules: Dict[int, Dict[str, BaseLayerWithLoRA]],
     ):
         def check_lora_weight_shape(buffer_view: torch.Tensor, weight: torch.Tensor):
             assert (
@@ -186,8 +191,8 @@ class LoRAMemoryPool:
             temp_B_buffer[lora_weight_name] = weights

         if self.tp_size > 1:
-            cur_layer_modules = self.lora_modules[layer_id]
-            for module_name, module in cur_layer_modules:
+            cur_layer_modules = lora_modules[layer_id]
+            for module_name, module in cur_layer_modules.items():
                 if "qkv_proj" in module_name:
                     temp_A_buffer["qkv_proj"] = module.slice_lora_a_weights(
                         temp_A_buffer["qkv_proj"], self.tp_rank
@@ -236,7 +241,6 @@ class LoRAMemoryPool:
     def get_tensor(
         self, weight_name: str, layer_id: int, lora_type: LoRAType
     ) -> torch.Tensor:
-
         if lora_type == LoRAType.LORA_A:
             return self.A_buffer[weight_name][layer_id]

@@ -108,7 +108,7 @@ def get_hidden_dim(

 def get_normalized_lora_weight_names(name: str) -> Tuple[List[str], List[str]]:
     """
-    Mapping a target module name to names of the normized LoRA weights.
+    Mapping a target module name to names of the normalized LoRA weights.
     Returned tuple contains (name for Lora A, name for Lora B)
     """
     params_mapping = {
@@ -278,6 +278,10 @@ class ModelRunner:
             self.apply_torch_tp()

         # Init lora
+        # TODO (lifuhuang): when we support dynamic LoRA loading / unloading, we should add
+        # a new server arg `enable_lora` to control whether to init LoRA manager to be more
+        # explicit, as it is perfectly valid to start a server with an empty lora_paths and
+        # load LoRA adapters dynamically later.
         if server_args.lora_paths is not None:
             self.init_lora_manager()

@@ -796,7 +800,6 @@ class ModelRunner:
     def init_lora_manager(self):
        self.lora_manager = LoRAManager(
            base_model=self.model,
-            lora_paths=self.server_args.lora_paths,
            base_hf_config=self.model_config.hf_config,
            max_loras_per_batch=self.server_args.max_loras_per_batch,
            load_config=self.load_config,
@@ -805,6 +808,7 @@ class ModelRunner:
            tp_size=self.tp_size,
            tp_rank=self.tp_rank,
        )
+        self.lora_manager.load_lora_adapters(self.server_args.lora_paths)
        logger.info("LoRA manager ready.")

    def profile_max_num_token(self, total_gpu_memory: int):