[Feature] Support Tensor Parallelism and Weight Slicing for Lora (#4274)

Co-authored-by: ShenAo1111 <1377693092@qq.com>
Co-authored-by: Baizhou Zhang <sobereddiezhang@gmail.com>
This commit is contained in:
aoshen524
2025-03-18 23:33:07 -04:00
committed by GitHub
parent 3196999f63
commit 588865f0e0
13 changed files with 528 additions and 103 deletions

View File

@@ -782,6 +782,8 @@ class QKVParallelLinear(ColumnParallelLinear):
else:
self.num_kv_heads = divide(self.total_num_kv_heads, tp_size)
self.num_kv_head_replicas = 1
self.q_proj_shard_size = self.num_heads * self.head_size
self.kv_proj_shard_size = self.num_kv_heads * self.head_size
input_size = self.hidden_size
output_size = (
(self.num_heads + 2 * self.num_kv_heads) * tp_size * self.head_size

View File

@@ -1,3 +1,5 @@
from typing import List, Tuple
import torch
from torch import nn
@@ -38,8 +40,22 @@ class BaseLayerWithLoRA(nn.Module):
def set_lora_info(self, *args):
    """Hook for receiving LoRA A/B buffers; concrete LoRA layers override this."""
    return None
def slice_lora_a_weights(self, A: torch.Tensor, tp_rank: int):
    """Default no-op: subclasses that shard LoRA-A under TP override this."""
    return None
def slice_lora_b_weights(self, B: torch.Tensor, tp_rank: int):
    """Default no-op: subclasses that shard LoRA-B under TP override this."""
    return None
class VocabParallelEmbeddingWithLoRA(BaseLayerWithLoRA):
"""
Vocab parallel embedding layer with support for LoRA (Low-Rank Adaptation).
Note: The current version does not yet implement the LoRA functionality.
This class behaves exactly the same as the base VocabParallelEmbedding.
Future versions will integrate LoRA functionality to support efficient parameter fine-tuning.
"""
def __init__(
self,
base_layer: VocabParallelEmbedding,
@@ -101,6 +117,16 @@ class ColumnParallelLinearWithLoRA(BaseLayerWithLoRA):
output_bias = self.base_layer.bias if self.base_layer.skip_bias_add else None
return output, output_bias
def slice_lora_a_weights(self, A: torch.Tensor, tp_rank: int):
    """LoRA-A of a column-parallel layer is replicated across TP ranks; no slicing."""
    return A
def slice_lora_b_weights(self, B: torch.Tensor, tp_rank: int):
    """Return this rank's row shard of LoRA-B (the output dim is column-parallel)."""
    shard = self.base_layer.output_partition_sizes[0]
    begin = shard * tp_rank
    return B[begin : begin + shard, :]
class MergedColumnParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
def __init__(
@@ -120,6 +146,7 @@ class MergedColumnParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
self.set_lora = True
self.A_buffer_gate_up = A_buffer
if self.lora_backend.fuse_stacked_lora_b:
# TODO: avoid using contiguous() in GPU.
# B_buffer_gate_up: (num_lora, 2 * output_dim, r)
self.B_buffer_gate_up = torch.cat(
(B_buffer[0], B_buffer[1]), dim=-2
@@ -142,6 +169,16 @@ class MergedColumnParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
else base_output + lora_output * self.scaling
)
def slice_lora_a_weights(self, A: torch.Tensor, tp_rank: int):
    """LoRA-A of the merged gate/up layer is replicated across TP ranks; no slicing."""
    return A
def slice_lora_b_weights(self, B: torch.Tensor, tp_rank: int):
    """Shard the stacked gate/up LoRA-B along the output dim for this TP rank.

    gate and up have equal partition sizes, so the first entry of
    output_partition_sizes serves for both stacked sub-matrices.
    """
    shard = self.base_layer.output_partition_sizes[0]
    begin = shard * tp_rank
    return B[:, begin : begin + shard, :]
class QKVParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
def __init__(
@@ -210,6 +247,27 @@ class QKVParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
else base_output + lora_output * self.scaling
)
def slice_lora_a_weights(self, A: torch.Tensor, tp_rank: int):
    """LoRA-A of the fused QKV layer is replicated across TP ranks; no slicing."""
    return A
def slice_lora_b_weights(
    self, B: List[torch.Tensor], tp_rank: int
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Shard the (q, kv) LoRA-B weights for this TP rank.

    q is sharded directly by tp_rank; kv is sharded by the kv-head group
    (tp_rank // num_kv_head_replicas), since kv heads may be replicated
    across several ranks when tp_size exceeds the kv-head count.
    """
    B_q, B_kv = B
    base = self.base_layer
    q_shard = base.q_proj_shard_size
    kv_shard = base.kv_proj_shard_size
    q_begin = q_shard * tp_rank
    kv_begin = kv_shard * (tp_rank // base.num_kv_head_replicas)
    return (
        B_q[q_begin : q_begin + q_shard, :],
        B_kv[:, kv_begin : kv_begin + kv_shard, :],
    )
class RowParallelLinearWithLoRA(BaseLayerWithLoRA):
def __init__(
@@ -274,6 +332,16 @@ class RowParallelLinearWithLoRA(BaseLayerWithLoRA):
output_bias = self.base_layer.bias
return output, output_bias
def slice_lora_a_weights(self, A: torch.Tensor, tp_rank: int):
    """Return this rank's column shard of LoRA-A (the input dim is row-parallel).

    contiguous() is needed because a column slice is not contiguous in memory.
    """
    shard = self.base_layer.input_size_per_partition
    begin = shard * tp_rank
    return A[:, begin : begin + shard].contiguous()
def slice_lora_b_weights(self, B: torch.Tensor, tp_rank: int):
    """LoRA-B of a row-parallel layer is replicated across TP ranks; no slicing."""
    return B
def get_lora_layer(
layer: nn.Module, lora_rank: int, scaling: int, lora_backend: BaseLoRABackend

View File

@@ -39,16 +39,9 @@ class LoRALayer(nn.Module):
super().__init__()
self.config: LoRAConfig = config
self.base_hf_config: AutoConfig = base_hf_config
# lora weights in cpu. The weights are loaded from checkpoint.
self.weights: Dict[str, torch.Tensor] = {}
self.weight_gpu: Dict[str, torch.Tensor] = {}
def load_to_gpu(self):
for name, weight in self.weights.items():
self.weight_gpu[name] = weight.to(torch.float16).to("cuda")
def offload_from_gpu(self):
for name, weight in self.weights.items():
self.weight_gpu[name] = None
class LoRAAdapter(nn.Module):
@@ -77,19 +70,6 @@ class LoRAAdapter(nn.Module):
)
self.weights: Dict[str, torch.Tensor] = {}
self.weights_gpu: Dict[str, torch.Tensor] = {}
def load_to_gpu(self):
for name, weight in self.weights.items():
self.weights_gpu[name] = weight.to(torch.float16).to("cuda")
for layer in self.layers:
layer.load_to_gpu()
def offload_from_gpu(self):
for name, weight in self.weights.items():
self.weights_gpu[name] = None
for layer in self.layers:
layer.offload_from_gpu()
# initialize the LoRA weights to cpu
def initialize_weights(self):

View File

@@ -23,7 +23,7 @@ import torch
from sglang.srt.configs.load_config import LoadConfig
from sglang.srt.hf_transformers_utils import AutoConfig
from sglang.srt.lora.backend import BaseLoRABackend, get_backend_from_name
from sglang.srt.lora.layers import get_lora_layer
from sglang.srt.lora.layers import BaseLayerWithLoRA, get_lora_layer
from sglang.srt.lora.lora import LoRAAdapter
from sglang.srt.lora.lora_config import LoRAConfig
from sglang.srt.lora.mem_pool import LoRAMemoryPool
@@ -51,6 +51,8 @@ class LoRAManager:
load_config: LoadConfig,
dtype: torch.dtype,
lora_backend: str = "triton",
tp_size: int = 1,
tp_rank: int = 0,
):
self.base_model: torch.nn.Module = base_model
self.lora_paths: Dict[str, str] = lora_paths
@@ -58,6 +60,9 @@ class LoRAManager:
self.max_loras_per_batch: int = max_loras_per_batch
self.load_config: LoadConfig = load_config
self.dtype: torch.dtype = dtype
self.device: torch.device = next(self.base_model.parameters()).device
self.tp_size: int = tp_size
self.tp_rank: int = tp_rank
# LoRA backend for running sgemm kernels
logger.info(f"Using {lora_backend} as backend of LoRA kernels.")
@@ -110,7 +115,13 @@ class LoRAManager:
def init_lora_memory_pool(self):
# Initialize memory pool
self.memory_pool = LoRAMemoryPool(
self.base_hf_config, self.max_loras_per_batch, self.max_lora_dim, self.dtype
self.base_hf_config,
self.max_loras_per_batch,
self.max_lora_dim,
self.dtype,
self.tp_size,
self.tp_rank,
self.lora_modules,
)
# Initialize target lora modules in memory pool
@@ -131,12 +142,12 @@ class LoRAManager:
seg_lens = (
forward_batch.extend_seq_lens
if forward_batch.forward_mode.is_extend()
else torch.ones(bs, device="cuda")
else torch.ones(bs, device=self.device)
)
seg_indptr = torch.zeros((bs + 1,), dtype=torch.int32, device="cuda")
seg_indptr = torch.zeros((bs + 1,), dtype=torch.int32, device=self.device)
seg_indptr[1:] = torch.cumsum(seg_lens, dim=0)
max_len = int(torch.max(seg_lens))
weight_indices = torch.empty((bs,), dtype=torch.int64, device="cuda")
weight_indices = torch.empty((bs,), dtype=torch.int64, device=self.device)
for i, lora_path in enumerate(forward_batch.lora_paths):
weight_indices[i] = self.memory_pool.get_buffer_id(lora_path)
@@ -150,22 +161,32 @@ class LoRAManager:
self.lora_backend.set_batch_info(batch_info)
# call set_lora_info for each lora modules
for module_name, module in self.lora_modules:
layer_id = get_layer_id(module_name)
if "qkv_proj" not in module_name:
weight_name = get_weight_name(
module_name, self.lora_weight_names, LoRAType.LORA_A
)
module.set_lora_info(
self.memory_pool.get_tensor(weight_name, layer_id, LoRAType.LORA_A),
self.memory_pool.get_tensor(weight_name, layer_id, LoRAType.LORA_B),
)
else:
module.set_lora_info(
self.memory_pool.get_tensor("qkv_proj", layer_id, LoRAType.LORA_A),
self.memory_pool.get_tensor("q_proj", layer_id, LoRAType.LORA_B),
self.memory_pool.get_tensor("kv_proj", layer_id, LoRAType.LORA_B),
)
for layer_id, modules in self.lora_modules.items():
for module_name, module in modules:
if "qkv_proj" in module_name:
module.set_lora_info(
self.memory_pool.get_tensor(
"qkv_proj", layer_id, LoRAType.LORA_A
),
self.memory_pool.get_tensor(
"q_proj", layer_id, LoRAType.LORA_B
),
self.memory_pool.get_tensor(
"kv_proj", layer_id, LoRAType.LORA_B
),
)
else:
weight_name = get_weight_name(
module_name, self.lora_weight_names, LoRAType.LORA_A
)
module.set_lora_info(
self.memory_pool.get_tensor(
weight_name, layer_id, LoRAType.LORA_A
),
self.memory_pool.get_tensor(
weight_name, layer_id, LoRAType.LORA_B
),
)
def set_lora_module(self, module_name, module):
lora_module = get_lora_layer(
@@ -182,10 +203,13 @@ class LoRAManager:
)
# Monkey patch to use the LoRA version layers
self.lora_modules: List[Tuple[str, torch.nn.Module]] = []
self.lora_modules: Dict[int, List[Tuple[str, BaseLayerWithLoRA]]] = {
i: [] for i in range(self.base_hf_config.num_hidden_layers)
}
for module_name, module in self.base_model.named_modules():
# The module should be converted if it is included in target_names
if module_name.split(".")[-1] in customized_target_names:
self.lora_modules.append(
layer_id = get_layer_id(module_name)
self.lora_modules[layer_id].append(
(module_name, self.set_lora_module(module_name, module))
)

View File

@@ -2,9 +2,12 @@ from typing import Dict, List, Optional, Set, Tuple
import torch
from sglang.srt.distributed import divide
from sglang.srt.hf_transformers_utils import AutoConfig
from sglang.srt.lora.layers import BaseLayerWithLoRA
from sglang.srt.lora.lora import LoRAAdapter
from sglang.srt.lora.utils import (
ROW_PARALLELISM_LINEAR_LORA_NAMES,
LoRAType,
get_hidden_dim,
get_stacked_multiply,
@@ -21,6 +24,9 @@ class LoRAMemoryPool:
max_loras_per_batch: int,
max_lora_dim: int,
dtype: torch.dtype,
tp_size: int,
tp_rank: int,
lora_modules: Dict[int, List[Tuple[str, BaseLayerWithLoRA]]],
):
self.base_hf_config: AutoConfig = base_hf_config
@@ -28,6 +34,9 @@ class LoRAMemoryPool:
self.max_loras_per_batch: int = max_loras_per_batch
self.max_lora_dim: int = max_lora_dim
self.dtype: torch.dtype = dtype
self.tp_size: int = tp_size
self.tp_rank: int = tp_rank
self.lora_modules: Dict[int, List[Tuple[str, BaseLayerWithLoRA]]] = lora_modules
# Both A_buffer and B_buffer maps lora weight names to its buffer space.
# A_buffer contains num_layer number of row-major tensors with shape
@@ -45,6 +54,41 @@ class LoRAMemoryPool:
# Here we don't initialize to None since None is a valid uid
self.buffer_id_to_uid: List[Optional[str]] = [""] * self.max_loras_per_batch
def get_lora_A_shape(
    self, module_name: str, base_model: torch.nn.Module
) -> Tuple[int]:
    """Return the lora_A buffer shape for module_name:
    (max_loras_per_batch, max_lora_dim * stack, input_dim).

    Under tensor parallelism, row-parallel layers split their input dim
    across ranks, so input_dim is divided by tp_size for them.
    """
    input_dim, _ = get_hidden_dim(module_name, self.base_hf_config, base_model)
    stack = get_stacked_multiply(module_name)
    if self.tp_size > 1 and module_name in ROW_PARALLELISM_LINEAR_LORA_NAMES:
        input_dim = divide(input_dim, self.tp_size)
    return (self.max_loras_per_batch, self.max_lora_dim * stack, input_dim)
def get_lora_B_shape(
    self, module_name: str, base_model: torch.nn.Module
) -> Tuple[int]:
    """Return the lora_B buffer shape for module_name:
    (stack, max_loras_per_batch, output_dim, max_lora_dim).

    Under tensor parallelism, column-parallel layers split their output dim
    across ranks, so output_dim is divided by tp_size for every module that
    is not row-parallel.
    """
    _, output_dim = get_hidden_dim(module_name, self.base_hf_config, base_model)
    stack = get_stacked_multiply(module_name)
    if self.tp_size > 1 and module_name not in ROW_PARALLELISM_LINEAR_LORA_NAMES:
        output_dim = divide(output_dim, self.tp_size)
    return (stack, self.max_loras_per_batch, output_dim, self.max_lora_dim)
def init_buffers(
self,
lora_weight_names: Set[Tuple[str]],
@@ -54,42 +98,31 @@ class LoRAMemoryPool:
# lora_weight_names is a set of name pairs indicating each pair of lora modules to load
# e.g., {("qkv_proj", "q_proj"), ("qkv_proj", "kv_proj"), ("o_proj", "o_proj")}
self.lora_weight_names: Set[Tuple[str]] = lora_weight_names
for module_A, module_B in lora_weight_names:
# Init A tensor, column_major=False
input_dim, _ = get_hidden_dim(module_A, self.base_hf_config, base_model)
c = get_stacked_multiply(module_A)
if module_A not in self.A_buffer:
self.A_buffer[module_A] = [
torch.empty(
(
self.max_loras_per_batch,
self.max_lora_dim * c,
input_dim,
),
dtype=self.dtype,
device="cuda",
)
for i in range(self.num_layer)
]
# Init B tensor, column_major=True
_, output_dim = get_hidden_dim(module_B, self.base_hf_config, base_model)
c = get_stacked_multiply(module_B)
if module_B not in self.B_buffer:
self.B_buffer[module_B] = [
torch.empty(
(
c, # stacked lora_b modules might need separation
self.max_loras_per_batch,
output_dim,
self.max_lora_dim,
),
dtype=self.dtype,
device="cuda",
)
for i in range(self.num_layer)
]
device = next(base_model.parameters()).device
lora_module_A_names = set([name[0] for name in lora_weight_names])
lora_module_B_names = set([name[1] for name in lora_weight_names])
# Init A tensor, column_major=False
for module_A in lora_module_A_names:
lora_A_shape = self.get_lora_A_shape(module_A, base_model)
self.A_buffer[module_A] = [
torch.empty(
lora_A_shape,
dtype=self.dtype,
device=device,
)
for i in range(self.num_layer)
]
# Init B tensor, column_major=True
for module_B in lora_module_B_names:
lora_B_shape = self.get_lora_B_shape(module_B, base_model)
self.B_buffer[module_B] = [
torch.empty(
lora_B_shape,
dtype=self.dtype,
device=device,
)
for _ in range(self.num_layer)
]
def prepare_lora_batch(
self,
@@ -136,30 +169,56 @@ class LoRAMemoryPool:
assert lora_adapter is not None
for layer_id in range(self.num_layer):
layer_weights = lora_adapter.layers[layer_id].weights
temp_A_buffer: Dict[str, torch.Tensor] = {}
temp_B_buffer: Dict[str, torch.Tensor] = {}
for name, weights in layer_weights.items():
if "lora_A" in name:
lora_weight_name = get_weight_name(
name, self.lora_weight_names, LoRAType.LORA_A
)
if lora_weight_name:
self.A_buffer[lora_weight_name][layer_id][buffer_id].copy_(
weights
)
temp_A_buffer[lora_weight_name] = weights
else:
lora_weight_name = get_weight_name(
name, self.lora_weight_names, LoRAType.LORA_B
)
if lora_weight_name:
c = get_stacked_multiply(lora_weight_name)
if c > 1:
for stacked_id in range(c):
self.B_buffer[lora_weight_name][layer_id][stacked_id][
buffer_id
].copy_(weights[stacked_id])
else:
self.B_buffer[lora_weight_name][layer_id][0][
buffer_id
].copy_(weights)
temp_B_buffer[lora_weight_name] = weights
if self.tp_size > 1:
cur_layer_modules = self.lora_modules[layer_id]
for module_name, module in cur_layer_modules:
if "qkv_proj" in module_name:
temp_A_buffer["qkv_proj"] = module.slice_lora_a_weights(
temp_A_buffer["qkv_proj"], self.tp_rank
)
temp_B_buffer["q_proj"], temp_B_buffer["kv_proj"] = (
module.slice_lora_b_weights(
[temp_B_buffer["q_proj"], temp_B_buffer["kv_proj"]],
self.tp_rank,
)
)
else:
weight_name = get_weight_name(
module_name, self.lora_weight_names, LoRAType.LORA_A
)
temp_A_buffer[weight_name] = module.slice_lora_a_weights(
temp_A_buffer[weight_name], self.tp_rank
)
temp_B_buffer[weight_name] = module.slice_lora_b_weights(
temp_B_buffer[weight_name], self.tp_rank
)
for name, weights in temp_A_buffer.items():
self.A_buffer[name][layer_id][buffer_id].copy_(weights)
for name, weights in temp_B_buffer.items():
c = get_stacked_multiply(name)
if c > 1:
for stacked_id in range(c):
self.B_buffer[name][layer_id][stacked_id][buffer_id].copy_(
weights[stacked_id]
)
else:
self.B_buffer[name][layer_id][0][buffer_id].copy_(weights)
def get_tensor(
self, weight_name: str, layer_id: int, lora_type: LoRAType

View File

@@ -133,9 +133,20 @@ def get_weight_name(
target_name is name of a given module,
lora_weight_names is a set of lora stacked name pairs (see get_stacked_name method above)
If there is a weight name in lora_weight_names that can match target_name, return this name
Else return None
Else raise ValueError.
"""
idx = 0 if lora_type == LoRAType.LORA_A else 1
for weight_name_pair in lora_weight_names:
if weight_name_pair[idx] in target_name:
return weight_name_pair[idx]
raise ValueError(
f"Cannot find weight name for {target_name} in {lora_weight_names}"
)
# TODO: [PR #4274] For future use to simplify the mapping between HF module names and customized module names.
VOCAB_PARALLELISM_EMBEDDING_NAMES = ["embeddings"]
COLUMN_PARALLELISM_LINEAR_LORA_NAMES = ["gate_proj", "up_proj"]
MERGED_COLUMN_PARALLELISM_LINEAR_LORA_NAMES = ["gate_up_proj"]
QKV_PARALLELISM_LINEAR_LORA_NAMES = ["qkv_proj"]
ROW_PARALLELISM_LINEAR_LORA_NAMES = ["o_proj", "down_proj"]

View File

@@ -188,9 +188,6 @@ class ModelRunner:
supports_torch_tp = getattr(self.model, "supports_torch_tp", False)
if self.tp_size > 1 and supports_torch_tp:
self.apply_torch_tp()
self.torch_tp_applied = True
else:
self.torch_tp_applied = False
# Init lora
if server_args.lora_paths is not None:
@@ -624,6 +621,8 @@ class ModelRunner:
load_config=self.load_config,
dtype=self.dtype,
lora_backend=self.server_args.lora_backend,
tp_size=self.tp_size,
tp_rank=self.tp_rank,
)
logger.info("LoRA manager ready.")

View File

@@ -257,7 +257,7 @@ def get_available_gpu_memory(device, gpu_id, distributed=False, empty_cache=True
When distributed is True, the available memory is the minimum available memory of all GPUs.
"""
if device == "cuda":
num_gpus = torch.cuda.device_count()
num_gpus = cuda_device_count_stateless()
assert gpu_id < num_gpus
if torch.cuda.current_device() != gpu_id:

View File

@@ -437,6 +437,7 @@ class SRTRunner:
speculative_eagle_topk: Optional[int] = None,
speculative_num_draft_tokens: Optional[int] = None,
disable_overlap_schedule: bool = False,
disable_custom_all_reduce: bool = False,
):
self.model_type = model_type
self.is_generation = model_type == "generation"
@@ -470,6 +471,7 @@ class SRTRunner:
enable_ep_moe=enable_ep_moe,
disable_overlap_schedule=disable_overlap_schedule,
cuda_graph_max_bs=4,
disable_custom_all_reduce=disable_custom_all_reduce,
**spec_kwargs,
)