Revert "fix some typos" (#6244)
This commit is contained in:
@@ -41,13 +41,13 @@ class BaseLoRABackend:
|
||||
def run_lora_a_sgemm(
|
||||
self, x: torch.Tensor, weights: torch.Tensor, *args, **kwargs
|
||||
) -> torch.Tensor:
|
||||
"""Run segment Gemm of LoRA a modules with current backend.
|
||||
"""Run segment Gemm of lora a modules with current backend.
|
||||
The definition of segment Gemm can be referred to https://docs.flashinfer.ai/api/gemm.html.
|
||||
|
||||
Args:
|
||||
x: input matrix with shape (s, input_dim), here s is the sum of all sequence lengths
|
||||
weights: a set of LoRA weights with shape (num_lora, c * r, input_dim),
|
||||
here r is LoRA rank, c is a multiplier for stacked modules (e.g., c=3 for qkv_proj, c=2 for gate_up_proj)
|
||||
weights: a set of lora weights with shape (num_lora, c * r, input_dim),
|
||||
here r is lora rank, c is a multiplier for stacked modules (e.g., c=3 for qkv_proj, c=2 for gate_up_proj)
|
||||
usually input_dim is much larger than r
|
||||
Returns:
|
||||
result with shape (s, c * r)
|
||||
@@ -57,12 +57,12 @@ class BaseLoRABackend:
|
||||
def run_lora_b_sgemm(
|
||||
self, x: torch.Tensor, weights: torch.Tensor, *args, **kwargs
|
||||
) -> torch.Tensor:
|
||||
"""Run segment Gemm of LoRA b modules with current backend.
|
||||
"""Run segment Gemm of lora b modules with current backend.
|
||||
The definition of segment Gemm can be referred to https://docs.flashinfer.ai/api/gemm.html.
|
||||
|
||||
Args:
|
||||
x: input matrix with shape (s, r), here s is the sum of all sequence lengths, r is LoRA rank
|
||||
weights: a set of LoRA weights with shape (num_lora, output_dim, r)
|
||||
x: input matrix with shape (s, r), here s is the sum of all sequence lengths, r is lora rank
|
||||
weights: a set of lora weights with shape (num_lora, output_dim, r)
|
||||
usually output_dim is much larger than r
|
||||
Returns:
|
||||
result with shape (s, output_dim)
|
||||
@@ -77,7 +77,7 @@ class BaseLoRABackend:
|
||||
*args,
|
||||
**kwargs,
|
||||
) -> torch.Tensor:
|
||||
"""Run the LoRA pass for QKV Layer.
|
||||
"""Run the lora pass for QKV Layer.
|
||||
|
||||
Args:
|
||||
x: input matrix with shape (s, input_dim), here s is the sum of all sequence lengths
|
||||
@@ -100,7 +100,7 @@ class BaseLoRABackend:
|
||||
*args,
|
||||
**kwargs,
|
||||
) -> torch.Tensor:
|
||||
"""Run the LoRA pass for gate_up_proj, usually attached to MergedColumnParallelLayer.
|
||||
"""Run the lora pass for gate_up_proj, usually attached to MergedColumnParallelLayer.
|
||||
|
||||
Args:
|
||||
x: input matrix with shape (s, input_dim), here s is the sum of all sequence lengths
|
||||
|
||||
@@ -117,7 +117,7 @@ class FlashInferLoRABackend(BaseLoRABackend):
|
||||
dtype=x.dtype,
|
||||
)
|
||||
|
||||
# Compute LoRA for gate and up proj respectively
|
||||
# Compute lora for gate and up proj respectively
|
||||
lora_output[:, :output_dim] = self.run_lora_b_sgemm(
|
||||
x=lora_a_output[:, :lora_rank].contiguous(),
|
||||
weights=gate_up_lora_b[0],
|
||||
|
||||
@@ -198,7 +198,7 @@ class QKVParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
|
||||
if self.lora_backend.fuse_stacked_lora_b:
|
||||
assert (
|
||||
B_buffer_q.shape[-1] == B_buffer_kv.shape[-1]
|
||||
), "The LoRA rank of q and kv should be the same when enabling fusion of qkv lora_b"
|
||||
), "The lora rank of q and kv should be the same when enabling fusion of qkv lora_b"
|
||||
output_dim_q, output_dim_kv = B_buffer_q.shape[-2], B_buffer_kv.shape[-2]
|
||||
|
||||
# B_buffer_qkv: (num_lora, output_dim_q + 2 * output_dim_kv, r)
|
||||
|
||||
@@ -40,7 +40,7 @@ class LoRALayer(nn.Module):
|
||||
self.config: LoRAConfig = config
|
||||
self.base_hf_config: AutoConfig = base_hf_config
|
||||
|
||||
# LoRA weights in cpu. The weights are loaded from checkpoint.
|
||||
# lora weights in cpu. The weights are loaded from checkpoint.
|
||||
self.weights: Dict[str, torch.Tensor] = {}
|
||||
|
||||
|
||||
@@ -97,7 +97,7 @@ class LoRAAdapter(nn.Module):
|
||||
|
||||
def stack_qkv_proj(self, weight_names: List[str], weights: Dict[str, torch.Tensor]):
|
||||
|
||||
# Collect target q/k/v modules. This process is necessary since there might be no LoRA attached to k_proj
|
||||
# Collect target q/k/v modules. This process is necessary since there might be no lora attached to k_proj
|
||||
target_module = set()
|
||||
for weight_name in weight_names:
|
||||
if "k_proj" in weight_name:
|
||||
@@ -110,7 +110,7 @@ class LoRAAdapter(nn.Module):
|
||||
return
|
||||
|
||||
for weight_name in weight_names:
|
||||
# We assume every LoRA adaptor should contain LoRA modules for q_proj
|
||||
# We assume every lora adaptor should contain lora modules for q_proj
|
||||
if "q_proj" in weight_name:
|
||||
q_name = weight_name
|
||||
k_name = weight_name.replace("q_proj", "k_proj")
|
||||
@@ -118,7 +118,7 @@ class LoRAAdapter(nn.Module):
|
||||
kv_name = weight_name.replace("q_proj", "kv_proj")
|
||||
qkv_name = weight_name.replace("q_proj", "qkv_proj")
|
||||
|
||||
# If k_proj doesn't have LoRA, initialize it to zero
|
||||
# If k_proj doesn't have lora, initialize it to zero
|
||||
k_proj_weight = (
|
||||
weights[k_name]
|
||||
if "k_proj" in target_module
|
||||
|
||||
@@ -93,14 +93,14 @@ class LoRAManager:
|
||||
# Config of each LoRA adapter
|
||||
self.configs: Dict[str, LoRAConfig] = {}
|
||||
|
||||
# Target module names in HuggingFace LoRA configs.
|
||||
# Target module names in huggingface lora configs.
|
||||
# e.g., {"k_proj", "q_proj", "v_proj", "o_proj"}
|
||||
self.hf_target_names: Set[str] = set()
|
||||
for name, path in self.lora_paths.items():
|
||||
self.configs[name] = LoRAConfig(path)
|
||||
self.hf_target_names.update(self.configs[name].target_modules)
|
||||
|
||||
# Target LoRA weight names for lora_a and lora_b modules respectively.
|
||||
# Target lora weight names for lora_a and lora_b modules respectively.
|
||||
# e.g., {("qkv_proj", "q_proj"), ("qkv_proj", "kv_proj")}
|
||||
self.lora_weight_names: Set[Tuple[str]] = set(
|
||||
[get_stacked_name(module) for module in self.hf_target_names]
|
||||
@@ -119,11 +119,11 @@ class LoRAManager:
|
||||
lora_adapter.initialize_weights()
|
||||
self.loras[name] = lora_adapter
|
||||
|
||||
# misc LoRA configs
|
||||
# misc lora configs
|
||||
self.max_lora_dim: int = max([x.hf_config["r"] for x in self.configs.values()])
|
||||
|
||||
if self.lora_backend == "flashinfer":
|
||||
# FIXME: remove the restrictions after supporting multi-rank for flashinfer backend
|
||||
# FIXME remove the restrictions after supporting multi-rank for flashinfer backend
|
||||
max_lora_dim = max([x.hf_config["r"] for x in self.configs.values()])
|
||||
scaling = list(self.loras.values())[0].scaling
|
||||
assert all(x.hf_config["r"] == max_lora_dim for x in self.configs.values())
|
||||
@@ -144,16 +144,16 @@ class LoRAManager:
|
||||
self.lora_modules,
|
||||
)
|
||||
|
||||
# Initialize target LoRA modules in memory pool
|
||||
# Initialize target lora modules in memory pool
|
||||
self.memory_pool.init_buffers(self.lora_weight_names, self.base_model)
|
||||
|
||||
def prepare_lora_batch(self, forward_batch: ForwardBatch):
|
||||
# load active LoRAs into LoRA memory pool
|
||||
# load active loras into lora memory pool
|
||||
cur_uids = set(forward_batch.lora_paths)
|
||||
assert len(cur_uids) <= self.max_loras_per_batch
|
||||
self.memory_pool.prepare_lora_batch(cur_uids, self.loras)
|
||||
|
||||
# set up batch info shared by all LoRA modules
|
||||
# set up batch info shared by all lora modules
|
||||
bs = forward_batch.batch_size
|
||||
|
||||
if (
|
||||
@@ -221,7 +221,7 @@ class LoRAManager:
|
||||
)
|
||||
self.lora_backend.set_batch_info(batch_info)
|
||||
|
||||
# call set_lora_info for each LoRA modules
|
||||
# call set_lora_info for each lora modules
|
||||
for layer_id, modules in self.lora_modules.items():
|
||||
for module_name, module in modules:
|
||||
if "qkv_proj" in module_name:
|
||||
|
||||
@@ -16,7 +16,7 @@ from sglang.srt.lora.utils import (
|
||||
|
||||
|
||||
class LoRAMemoryPool:
|
||||
"""Class for memory pool management of LoRA modules"""
|
||||
"""Class for memory pool management of lora modules"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@@ -38,7 +38,7 @@ class LoRAMemoryPool:
|
||||
self.tp_rank: int = tp_rank
|
||||
self.lora_modules: Dict[int, List[Tuple[str, BaseLayerWithLoRA]]] = lora_modules
|
||||
|
||||
# Both A_buffer and B_buffer maps LoRA weight names to its buffer space.
|
||||
# Both A_buffer and B_buffer maps lora weight names to its buffer space.
|
||||
# A_buffer contains num_layer number of row-major tensors with shape
|
||||
# (max_loras_per_batch, stacked_num * max_lora_dim, input_dim)
|
||||
# B_buffer contains num_layer number of column-major tensors with shape
|
||||
@@ -46,10 +46,10 @@ class LoRAMemoryPool:
|
||||
self.A_buffer: Dict[str, List[torch.Tensor]] = {}
|
||||
self.B_buffer: Dict[str, List[torch.Tensor]] = {}
|
||||
|
||||
# LoRA uid -> buffer idx in memory pool
|
||||
# Lora uid -> buffer idx in memory pool
|
||||
self.uid_to_buffer_id: Dict[Optional[str], int] = {}
|
||||
|
||||
# Buffer idx -> LoRA uid in memory pool
|
||||
# Buffer idx -> lora uid in memory pool
|
||||
# All uids are initialized as empty strings for empty buffer slots
|
||||
# Here we don't initialize to None since None is a valid uid
|
||||
self.buffer_id_to_uid: List[Optional[str]] = [""] * self.max_loras_per_batch
|
||||
@@ -95,7 +95,7 @@ class LoRAMemoryPool:
|
||||
base_model: torch.nn.Module,
|
||||
):
|
||||
|
||||
# lora_weight_names is a set of name pairs indicating each pair of LoRA modules to load
|
||||
# lora_weight_names is a set of name pairs indicating each pair of lora modules to load
|
||||
# e.g., {("qkv_proj", "q_proj"), ("qkv_proj", "kv_proj"), ("o_proj", "o_proj")}
|
||||
self.lora_weight_names: Set[Tuple[str]] = lora_weight_names
|
||||
device = next(base_model.parameters()).device
|
||||
@@ -137,7 +137,7 @@ class LoRAMemoryPool:
|
||||
return buffer_id, ""
|
||||
|
||||
for buffer_id in range(self.max_loras_per_batch):
|
||||
# Evict unneeded LoRA
|
||||
# Evict unneeded lora
|
||||
if self.buffer_id_to_uid[buffer_id] not in cur_uids:
|
||||
return buffer_id, self.buffer_id_to_uid[buffer_id]
|
||||
|
||||
|
||||
@@ -37,7 +37,7 @@ def _gate_up_lora_b_kernel(
|
||||
):
|
||||
# This kernel packs 2 sgemms (gate/up) into a single kernel.
|
||||
|
||||
# x: (s, 2 * K), s is the sum of sequence lengths, K equals to LoRA rank
|
||||
# x: (s, 2 * K), s is the sum of sequence lengths, K equals to lora rank
|
||||
# weights: (num_lora, 2 * output_dim, K)
|
||||
# output: (s, 2 * output_dim)
|
||||
# output_dim >> K
|
||||
|
||||
@@ -39,7 +39,7 @@ def _qkv_lora_b_kernel(
|
||||
):
|
||||
# This kernel packs 3 sgemms (q/k/v) into a single kernel.
|
||||
|
||||
# x: (s, 3 * K), s is the sum of sequence lengths, K equals to LoRA rank
|
||||
# x: (s, 3 * K), s is the sum of sequence lengths, K equals to lora rank
|
||||
# weights: (num_lora, N_Q + 2 * N_KV, K)
|
||||
# output: (s, N_Q + 2 * N_KV)
|
||||
# N_Q >> K, N_KV >> K
|
||||
|
||||
@@ -22,13 +22,13 @@ class LoRABatchInfo:
|
||||
# Maximum sequence length of current batch
|
||||
max_len: int
|
||||
|
||||
# The index of LoRA adapter used by each sequence, in shape (bs,)
|
||||
# The index of lora adapter used by each sequence, in shape (bs,)
|
||||
weight_indices: torch.Tensor
|
||||
|
||||
# ranks of each LoRA adapter, in shape (lora_num,)
|
||||
# ranks of each lora adapter, in shape (lora_num,)
|
||||
lora_ranks: torch.Tensor
|
||||
|
||||
# scaling of each LoRA adapter, in shape (lora_num,)
|
||||
# scaling of each lora adapter, in shape (lora_num,)
|
||||
scalings: torch.Tensor
|
||||
|
||||
|
||||
@@ -51,9 +51,9 @@ def get_customized_names_from_hf_names(
|
||||
hf_module_names: Set[str], base_model: torch.nn.Module
|
||||
) -> Set[str]:
|
||||
"""
|
||||
This function takes in a set of HuggingFace style module names:
|
||||
This function takes in a set of huggingface style module names:
|
||||
e.g., {"k_proj", "q_proj", "v_proj", "o_proj"}
|
||||
and outputs a set of module names of customized SGLang layers:
|
||||
and outputs a set of module names of customized sglang layers:
|
||||
e.g., {"qkv_proj", "o_proj"}
|
||||
"""
|
||||
if hasattr(base_model, "get_module_name"):
|
||||
@@ -87,7 +87,7 @@ def get_hidden_dim(
|
||||
else:
|
||||
"""
|
||||
WARNING: get_hidden_dim() is not defined,
|
||||
which is used to get the hidden dim for different LoRA modules
|
||||
which is used to get the hidden dim for different lora modules
|
||||
Use the default one, but please check if it is correct for your model.
|
||||
Please implement the function in the model class if it is not.
|
||||
You can reference this function in llama.py.
|
||||
@@ -108,7 +108,7 @@ def get_hidden_dim(
|
||||
|
||||
def get_stacked_name(name: str) -> Tuple[str]:
|
||||
"""
|
||||
Mapping a target LoRA module name to (stacked name for LoRA A, stacked name for LoRA B)
|
||||
Mapping a target module name to (stacked name for Lora A, stacked name for Lora B)
|
||||
"""
|
||||
params_mapping = {
|
||||
"q_proj": ("qkv_proj", "q_proj"),
|
||||
@@ -122,7 +122,7 @@ def get_stacked_name(name: str) -> Tuple[str]:
|
||||
|
||||
def get_stacked_multiply(module_name: str) -> int:
|
||||
"""
|
||||
Mapping a module name to its magnification at output dimension
|
||||
Mapping a lora module name to its magnification at output dimension
|
||||
"""
|
||||
stacked_rank = {
|
||||
"qkv_proj": 3,
|
||||
@@ -137,7 +137,7 @@ def get_weight_name(
|
||||
) -> Optional[str]:
|
||||
"""
|
||||
target_name is name of a given module,
|
||||
lora_weight_names is a set of LoRA stacked name pairs (see get_stacked_name method above)
|
||||
lora_weight_names is a set of lora stacked name pairs (see get_stacked_name method above)
|
||||
If there is a weight name in lora_weight_names that can match target_name, return this name
|
||||
Else raise ValueError.
|
||||
"""
|
||||
|
||||
Reference in New Issue
Block a user