Revert "fix some typos" (#6244)

This commit is contained in:
Lianmin Zheng
2025-05-12 12:53:26 -07:00
committed by GitHub
parent bad7c26fdc
commit e8e18dcdcc
95 changed files with 276 additions and 276 deletions

View File

@@ -41,13 +41,13 @@ class BaseLoRABackend:
def run_lora_a_sgemm(
self, x: torch.Tensor, weights: torch.Tensor, *args, **kwargs
) -> torch.Tensor:
"""Run segment Gemm of LoRA a modules with current backend.
"""Run segment Gemm of lora a modules with current backend.
The definition of segment Gemm can be referred to https://docs.flashinfer.ai/api/gemm.html.
Args:
x: input matrix with shape (s, input_dim), here s is the sum of all sequence lengths
weights: a set of LoRA weights with shape (num_lora, c * r, input_dim),
here r is LoRA rank, c is a multiplier for stacked modules (e.g., c=3 for qkv_proj, c=2 for gate_up_proj)
weights: a set of lora weights with shape (num_lora, c * r, input_dim),
here r is lora rank, c is a multiplier for stacked modules (e.g., c=3 for qkv_proj, c=2 for gate_up_proj)
usually input_dim is much larger than r
Returns:
result with shape (s, c * r)
@@ -57,12 +57,12 @@ class BaseLoRABackend:
def run_lora_b_sgemm(
self, x: torch.Tensor, weights: torch.Tensor, *args, **kwargs
) -> torch.Tensor:
"""Run segment Gemm of LoRA b modules with current backend.
"""Run segment Gemm of lora b modules with current backend.
The definition of segment Gemm can be referred to https://docs.flashinfer.ai/api/gemm.html.
Args:
x: input matrix with shape (s, r), here s is the sum of all sequence lengths, r is LoRA rank
weights: a set of LoRA weights with shape (num_lora, output_dim, r)
x: input matrix with shape (s, r), here s is the sum of all sequence lengths, r is lora rank
weights: a set of lora weights with shape (num_lora, output_dim, r)
usually output_dim is much larger than r
Returns:
result with shape (s, output_dim)
@@ -77,7 +77,7 @@ class BaseLoRABackend:
*args,
**kwargs,
) -> torch.Tensor:
"""Run the LoRA pass for QKV Layer.
"""Run the lora pass for QKV Layer.
Args:
x: input matrix with shape (s, input_dim), here s is the sum of all sequence lengths
@@ -100,7 +100,7 @@ class BaseLoRABackend:
*args,
**kwargs,
) -> torch.Tensor:
"""Run the LoRA pass for gate_up_proj, usually attached to MergedColumnParallelLayer.
"""Run the lora pass for gate_up_proj, usually attached to MergedColumnParallelLayer.
Args:
x: input matrix with shape (s, input_dim), here s is the sum of all sequence lengths

View File

@@ -117,7 +117,7 @@ class FlashInferLoRABackend(BaseLoRABackend):
dtype=x.dtype,
)
# Compute LoRA for gate and up proj respectively
# Compute lora for gate and up proj respectively
lora_output[:, :output_dim] = self.run_lora_b_sgemm(
x=lora_a_output[:, :lora_rank].contiguous(),
weights=gate_up_lora_b[0],

View File

@@ -198,7 +198,7 @@ class QKVParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
if self.lora_backend.fuse_stacked_lora_b:
assert (
B_buffer_q.shape[-1] == B_buffer_kv.shape[-1]
), "The LoRA rank of q and kv should be the same when enabling fusion of qkv lora_b"
), "The lora rank of q and kv should be the same when enabling fusion of qkv lora_b"
output_dim_q, output_dim_kv = B_buffer_q.shape[-2], B_buffer_kv.shape[-2]
# B_buffer_qkv: (num_lora, output_dim_q + 2 * output_dim_kv, r)

View File

@@ -40,7 +40,7 @@ class LoRALayer(nn.Module):
self.config: LoRAConfig = config
self.base_hf_config: AutoConfig = base_hf_config
# LoRA weights in cpu. The weights are loaded from checkpoint.
# lora weights in cpu. The weights are loaded from checkpoint.
self.weights: Dict[str, torch.Tensor] = {}
@@ -97,7 +97,7 @@ class LoRAAdapter(nn.Module):
def stack_qkv_proj(self, weight_names: List[str], weights: Dict[str, torch.Tensor]):
# Collect target q/k/v modules. This process is necessary since there might be no LoRA attached to k_proj
# Collect target q/k/v modules. This process is necessary since there might be no lora attached to k_proj
target_module = set()
for weight_name in weight_names:
if "k_proj" in weight_name:
@@ -110,7 +110,7 @@ class LoRAAdapter(nn.Module):
return
for weight_name in weight_names:
# We assume every LoRA adaptor should contain LoRA modules for q_proj
# We assume every lora adaptor should contain lora modules for q_proj
if "q_proj" in weight_name:
q_name = weight_name
k_name = weight_name.replace("q_proj", "k_proj")
@@ -118,7 +118,7 @@ class LoRAAdapter(nn.Module):
kv_name = weight_name.replace("q_proj", "kv_proj")
qkv_name = weight_name.replace("q_proj", "qkv_proj")
# If k_proj doesn't have LoRA, initialize it to zero
# If k_proj doesn't have lora, initialize it to zero
k_proj_weight = (
weights[k_name]
if "k_proj" in target_module

View File

@@ -93,14 +93,14 @@ class LoRAManager:
# Config of each LoRA adapter
self.configs: Dict[str, LoRAConfig] = {}
# Target module names in HuggingFace LoRA configs.
# Target module names in huggingface lora configs.
# e.g., {"k_proj", "q_proj", "v_proj", "o_proj"}
self.hf_target_names: Set[str] = set()
for name, path in self.lora_paths.items():
self.configs[name] = LoRAConfig(path)
self.hf_target_names.update(self.configs[name].target_modules)
# Target LoRA weight names for lora_a and lora_b modules respectively.
# Target lora weight names for lora_a and lora_b modules respectively.
# e.g., {("qkv_proj", "q_proj"), ("qkv_proj", "kv_proj")}
self.lora_weight_names: Set[Tuple[str]] = set(
[get_stacked_name(module) for module in self.hf_target_names]
@@ -119,11 +119,11 @@ class LoRAManager:
lora_adapter.initialize_weights()
self.loras[name] = lora_adapter
# misc LoRA configs
# misc lora configs
self.max_lora_dim: int = max([x.hf_config["r"] for x in self.configs.values()])
if self.lora_backend == "flashinfer":
# FIXME: remove the restrictions after supporting multi-rank for flashinfer backend
# FIXME remove the restrictions after supporting multi-rank for flashinfer backend
max_lora_dim = max([x.hf_config["r"] for x in self.configs.values()])
scaling = list(self.loras.values())[0].scaling
assert all(x.hf_config["r"] == max_lora_dim for x in self.configs.values())
@@ -144,16 +144,16 @@ class LoRAManager:
self.lora_modules,
)
# Initialize target LoRA modules in memory pool
# Initialize target lora modules in memory pool
self.memory_pool.init_buffers(self.lora_weight_names, self.base_model)
def prepare_lora_batch(self, forward_batch: ForwardBatch):
# load active LoRAs into LoRA memory pool
# load active loras into lora memory pool
cur_uids = set(forward_batch.lora_paths)
assert len(cur_uids) <= self.max_loras_per_batch
self.memory_pool.prepare_lora_batch(cur_uids, self.loras)
# set up batch info shared by all LoRA modules
# set up batch info shared by all lora modules
bs = forward_batch.batch_size
if (
@@ -221,7 +221,7 @@ class LoRAManager:
)
self.lora_backend.set_batch_info(batch_info)
# call set_lora_info for each LoRA modules
# call set_lora_info for each lora modules
for layer_id, modules in self.lora_modules.items():
for module_name, module in modules:
if "qkv_proj" in module_name:

View File

@@ -16,7 +16,7 @@ from sglang.srt.lora.utils import (
class LoRAMemoryPool:
"""Class for memory pool management of LoRA modules"""
"""Class for memory pool management of lora modules"""
def __init__(
self,
@@ -38,7 +38,7 @@ class LoRAMemoryPool:
self.tp_rank: int = tp_rank
self.lora_modules: Dict[int, List[Tuple[str, BaseLayerWithLoRA]]] = lora_modules
# Both A_buffer and B_buffer maps LoRA weight names to its buffer space.
# Both A_buffer and B_buffer maps lora weight names to its buffer space.
# A_buffer contains num_layer number of row-major tensors with shape
# (max_loras_per_batch, stacked_num * max_lora_dim, input_dim)
# B_buffer contains num_layer number of column-major tensors with shape
@@ -46,10 +46,10 @@ class LoRAMemoryPool:
self.A_buffer: Dict[str, List[torch.Tensor]] = {}
self.B_buffer: Dict[str, List[torch.Tensor]] = {}
# LoRA uid -> buffer idx in memory pool
# Lora uid -> buffer idx in memory pool
self.uid_to_buffer_id: Dict[Optional[str], int] = {}
# Buffer idx -> LoRA uid in memory pool
# Buffer idx -> lora uid in memory pool
# All uids are initialized as empty strings for empty buffer slots
# Here we don't initialize to None since None is a valid uid
self.buffer_id_to_uid: List[Optional[str]] = [""] * self.max_loras_per_batch
@@ -95,7 +95,7 @@ class LoRAMemoryPool:
base_model: torch.nn.Module,
):
# lora_weight_names is a set of name pairs indicating each pair of LoRA modules to load
# lora_weight_names is a set of name pairs indicating each pair of lora modules to load
# e.g., {("qkv_proj", "q_proj"), ("qkv_proj", "kv_proj"), ("o_proj", "o_proj")}
self.lora_weight_names: Set[Tuple[str]] = lora_weight_names
device = next(base_model.parameters()).device
@@ -137,7 +137,7 @@ class LoRAMemoryPool:
return buffer_id, ""
for buffer_id in range(self.max_loras_per_batch):
# Evict unneeded LoRA
# Evict unneeded lora
if self.buffer_id_to_uid[buffer_id] not in cur_uids:
return buffer_id, self.buffer_id_to_uid[buffer_id]

View File

@@ -37,7 +37,7 @@ def _gate_up_lora_b_kernel(
):
# This kernel packs 2 sgemms (gate/up) into a single kernel.
# x: (s, 2 * K), s is the sum of sequence lengths, K equals to LoRA rank
# x: (s, 2 * K), s is the sum of sequence lengths, K equals to lora rank
# weights: (num_lora, 2 * output_dim, K)
# output: (s, 2 * output_dim)
# output_dim >> K

View File

@@ -39,7 +39,7 @@ def _qkv_lora_b_kernel(
):
# This kernel packs 3 sgemms (q/k/v) into a single kernel.
# x: (s, 3 * K), s is the sum of sequence lengths, K equals to LoRA rank
# x: (s, 3 * K), s is the sum of sequence lengths, K equals to lora rank
# weights: (num_lora, N_Q + 2 * N_KV, K)
# output: (s, N_Q + 2 * N_KV)
# N_Q >> K, N_KV >> K

View File

@@ -22,13 +22,13 @@ class LoRABatchInfo:
# Maximum sequence length of current batch
max_len: int
# The index of LoRA adapter used by each sequence, in shape (bs,)
# The index of lora adapter used by each sequence, in shape (bs,)
weight_indices: torch.Tensor
# ranks of each LoRA adapter, in shape (lora_num,)
# ranks of each lora adapter, in shape (lora_num,)
lora_ranks: torch.Tensor
# scaling of each LoRA adapter, in shape (lora_num,)
# scaling of each lora adapter, in shape (lora_num,)
scalings: torch.Tensor
@@ -51,9 +51,9 @@ def get_customized_names_from_hf_names(
hf_module_names: Set[str], base_model: torch.nn.Module
) -> Set[str]:
"""
This function takes in a set of HuggingFace style module names:
This function takes in a set of huggingface style module names:
e.g., {"k_proj", "q_proj", "v_proj", "o_proj"}
and outputs a set of module names of customized SGLang layers:
and outputs a set of module names of customized sglang layers:
e.g., {"qkv_proj", "o_proj"}
"""
if hasattr(base_model, "get_module_name"):
@@ -87,7 +87,7 @@ def get_hidden_dim(
else:
"""
WARNING: get_hidden_dim() is not defined,
which is used to get the hidden dim for different LoRA modules
which is used to get the hidden dim for different lora modules
Use the default one, but please check if it is correct for your model.
Please implement the function in the model class if it is not.
You can reference this function in llama.py.
@@ -108,7 +108,7 @@ def get_hidden_dim(
def get_stacked_name(name: str) -> Tuple[str]:
"""
Mapping a target LoRA module name to (stacked name for LoRA A, stacked name for LoRA B)
Mapping a target module name to (stacked name for Lora A, stacked name for Lora B)
"""
params_mapping = {
"q_proj": ("qkv_proj", "q_proj"),
@@ -122,7 +122,7 @@ def get_stacked_name(name: str) -> Tuple[str]:
def get_stacked_multiply(module_name: str) -> int:
"""
Mapping a module name to its magnification at output dimension
Mapping a lora module name to its magnification at output dimension
"""
stacked_rank = {
"qkv_proj": 3,
@@ -137,7 +137,7 @@ def get_weight_name(
) -> Optional[str]:
"""
target_name is name of a given module,
lora_weight_names is a set of LoRA stacked name pairs (see get_stacked_name method above)
lora_weight_names is a set of lora stacked name pairs (see get_stacked_name method above)
If there is a weight name in lora_weight_names that can match target_name, return this name
Else raise ValueError.
"""