[Feature] Support Tensor Parallelism and Weight Slicing for Lora (#4274)
Co-authored-by: ShenAo1111 <1377693092@qq.com>
Co-authored-by: Baizhou Zhang <sobereddiezhang@gmail.com>
@@ -782,6 +782,8 @@ class QKVParallelLinear(ColumnParallelLinear):
        else:
            self.num_kv_heads = divide(self.total_num_kv_heads, tp_size)
            self.num_kv_head_replicas = 1
        self.q_proj_shard_size = self.num_heads * self.head_size
        self.kv_proj_shard_size = self.num_kv_heads * self.head_size
        input_size = self.hidden_size
        output_size = (
            (self.num_heads + 2 * self.num_kv_heads) * tp_size * self.head_size
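
A minimal sketch (not part of the diff) of how the new shard sizes work out, assuming a hypothetical config with 32 query heads, 8 KV heads, head size 128, and tp_size 4:

    # Illustrative only; mirrors the q/kv shard-size bookkeeping added above.
    total_num_heads, total_num_kv_heads, head_size, tp_size = 32, 8, 128, 4
    num_heads = total_num_heads // tp_size                  # 8 query heads per rank
    num_kv_heads = max(1, total_num_kv_heads // tp_size)    # 2 KV heads per rank
    q_proj_shard_size = num_heads * head_size               # 1024 q_proj rows per rank
    kv_proj_shard_size = num_kv_heads * head_size           # 256 k_proj/v_proj rows per rank
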
@@ -1,3 +1,5 @@
from typing import List, Tuple

import torch
from torch import nn

@@ -38,8 +40,22 @@ class BaseLayerWithLoRA(nn.Module):
    def set_lora_info(self, *args):
        pass

    def slice_lora_a_weights(self, A: torch.Tensor, tp_rank: int):
        pass

    def slice_lora_b_weights(self, B: torch.Tensor, tp_rank: int):
        pass


class VocabParallelEmbeddingWithLoRA(BaseLayerWithLoRA):
    """
    Vocab parallel embedding layer with support for LoRA (Low-Rank Adaptation).

    Note: The current version does not yet implement the LoRA functionality.
    This class behaves exactly the same as the base VocabParallelEmbedding.
    Future versions will integrate LoRA functionality to support efficient parameter fine-tuning.
    """

    def __init__(
        self,
        base_layer: VocabParallelEmbedding,
@@ -101,6 +117,16 @@ class ColumnParallelLinearWithLoRA(BaseLayerWithLoRA):
        output_bias = self.base_layer.bias if self.base_layer.skip_bias_add else None
        return output, output_bias

    def slice_lora_a_weights(self, A: torch.Tensor, tp_rank: int):
        return A

    def slice_lora_b_weights(self, B: torch.Tensor, tp_rank: int):
        shard_size = self.base_layer.output_partition_sizes[0]
        start_idx = tp_rank * shard_size
        end_idx = (tp_rank + 1) * shard_size
        B = B[start_idx:end_idx, :]
        return B
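
A minimal sketch (not part of the diff) of the column-parallel convention above: LoRA A stays replicated while LoRA B is split along the output dimension, one shard per rank; all sizes below are hypothetical.

    import torch

    output_dim, r, tp_size = 1024, 16, 4
    B = torch.randn(output_dim, r)
    shard_size = output_dim // tp_size  # stands in for output_partition_sizes[0]
    shards = [B[i * shard_size:(i + 1) * shard_size, :] for i in range(tp_size)]
    assert torch.equal(torch.cat(shards, dim=0), B)  # the shards tile the full LoRA B
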
class MergedColumnParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
    def __init__(

@@ -120,6 +146,7 @@ class MergedColumnParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
        self.set_lora = True
        self.A_buffer_gate_up = A_buffer
        if self.lora_backend.fuse_stacked_lora_b:
            # TODO: avoid using contiguous() on GPU.
            # B_buffer_gate_up: (num_lora, 2 * output_dim, r)
            self.B_buffer_gate_up = torch.cat(
                (B_buffer[0], B_buffer[1]), dim=-2

@@ -142,6 +169,16 @@ class MergedColumnParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
            else base_output + lora_output * self.scaling
        )

    def slice_lora_a_weights(self, A: torch.Tensor, tp_rank: int):
        return A

    def slice_lora_b_weights(self, B: torch.Tensor, tp_rank: int):
        # The output shard sizes of gate and up are identical, so either entry works here.
        shard_size = self.base_layer.output_partition_sizes[0]
        start_idx = tp_rank * shard_size
        end_idx = (tp_rank + 1) * shard_size
        return B[:, start_idx:end_idx, :]
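
A minimal sketch (not part of the diff) of the merged gate_up case above: the stacked LoRA B tensor has shape (2, output_dim, r), and both stacked halves share one shard size, so a single slice covers gate and up alike; sizes are hypothetical.

    import torch

    output_dim, r, tp_size, tp_rank = 2048, 16, 4, 0
    B = torch.randn(2, output_dim, r)       # stacked (gate, up) LoRA B
    shard_size = output_dim // tp_size
    B_local = B[:, tp_rank * shard_size:(tp_rank + 1) * shard_size, :]
    assert B_local.shape == (2, shard_size, r)
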
class QKVParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
    def __init__(

@@ -210,6 +247,27 @@ class QKVParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
            else base_output + lora_output * self.scaling
        )

    def slice_lora_a_weights(self, A: torch.Tensor, tp_rank: int):
        return A

    def slice_lora_b_weights(
        self, B: List[torch.Tensor], tp_rank: int
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        B_q, B_kv = B
        base_layer = self.base_layer
        q_proj_shard_size = base_layer.q_proj_shard_size
        kv_proj_shard_size = base_layer.kv_proj_shard_size
        num_kv_head_replicas = base_layer.num_kv_head_replicas

        q_start_idx = q_proj_shard_size * tp_rank
        q_end_idx = q_start_idx + q_proj_shard_size

        kv_shard_id = tp_rank // num_kv_head_replicas
        kv_start_idx = kv_proj_shard_size * kv_shard_id
        kv_end_idx = kv_start_idx + kv_proj_shard_size

        return B_q[q_start_idx:q_end_idx, :], B_kv[:, kv_start_idx:kv_end_idx, :]
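
A minimal sketch (not part of the diff) of the KV replica arithmetic above: when tp_size exceeds the number of KV heads, several ranks map to the same KV shard; the values are hypothetical.

    tp_size, num_kv_head_replicas = 8, 4
    for tp_rank in range(tp_size):
        kv_shard_id = tp_rank // num_kv_head_replicas
        print(f"rank {tp_rank} -> KV shard {kv_shard_id}")
    # ranks 0-3 read KV shard 0, ranks 4-7 read KV shard 1
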
class RowParallelLinearWithLoRA(BaseLayerWithLoRA):
    def __init__(

@@ -274,6 +332,16 @@ class RowParallelLinearWithLoRA(BaseLayerWithLoRA):
        output_bias = self.base_layer.bias
        return output, output_bias

    def slice_lora_a_weights(self, A: torch.Tensor, tp_rank: int):
        shard_size = self.base_layer.input_size_per_partition
        start_idx = tp_rank * shard_size
        end_idx = (tp_rank + 1) * shard_size
        A = A[:, start_idx:end_idx].contiguous()
        return A

    def slice_lora_b_weights(self, B: torch.Tensor, tp_rank: int):
        return B
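
A minimal sketch (not part of the diff) of the row-parallel convention above: the input dimension is split across ranks, so LoRA A is sliced along its columns while LoRA B stays replicated; sizes are hypothetical.

    import torch

    input_dim, r, tp_size, tp_rank = 4096, 16, 4, 1
    A = torch.randn(r, input_dim)
    shard_size = input_dim // tp_size  # stands in for input_size_per_partition
    A_local = A[:, tp_rank * shard_size:(tp_rank + 1) * shard_size].contiguous()
    assert A_local.shape == (r, shard_size)
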
def get_lora_layer(
    layer: nn.Module, lora_rank: int, scaling: int, lora_backend: BaseLoRABackend

@@ -39,16 +39,9 @@ class LoRALayer(nn.Module):
        super().__init__()
        self.config: LoRAConfig = config
        self.base_hf_config: AutoConfig = base_hf_config

        # LoRA weights on CPU. The weights are loaded from the checkpoint.
        self.weights: Dict[str, torch.Tensor] = {}
        self.weight_gpu: Dict[str, torch.Tensor] = {}

    def load_to_gpu(self):
        for name, weight in self.weights.items():
            self.weight_gpu[name] = weight.to(torch.float16).to("cuda")

    def offload_from_gpu(self):
        for name, weight in self.weights.items():
            self.weight_gpu[name] = None


class LoRAAdapter(nn.Module):

@@ -77,19 +70,6 @@ class LoRAAdapter(nn.Module):
        )

        self.weights: Dict[str, torch.Tensor] = {}
        self.weights_gpu: Dict[str, torch.Tensor] = {}

    def load_to_gpu(self):
        for name, weight in self.weights.items():
            self.weights_gpu[name] = weight.to(torch.float16).to("cuda")
        for layer in self.layers:
            layer.load_to_gpu()

    def offload_from_gpu(self):
        for name, weight in self.weights.items():
            self.weights_gpu[name] = None
        for layer in self.layers:
            layer.offload_from_gpu()

    # Initialize the LoRA weights on CPU
    def initialize_weights(self):
@@ -23,7 +23,7 @@ import torch
from sglang.srt.configs.load_config import LoadConfig
from sglang.srt.hf_transformers_utils import AutoConfig
from sglang.srt.lora.backend import BaseLoRABackend, get_backend_from_name
from sglang.srt.lora.layers import get_lora_layer
from sglang.srt.lora.layers import BaseLayerWithLoRA, get_lora_layer
from sglang.srt.lora.lora import LoRAAdapter
from sglang.srt.lora.lora_config import LoRAConfig
from sglang.srt.lora.mem_pool import LoRAMemoryPool

@@ -51,6 +51,8 @@ class LoRAManager:
        load_config: LoadConfig,
        dtype: torch.dtype,
        lora_backend: str = "triton",
        tp_size: int = 1,
        tp_rank: int = 0,
    ):
        self.base_model: torch.nn.Module = base_model
        self.lora_paths: Dict[str, str] = lora_paths

@@ -58,6 +60,9 @@ class LoRAManager:
        self.max_loras_per_batch: int = max_loras_per_batch
        self.load_config: LoadConfig = load_config
        self.dtype: torch.dtype = dtype
        self.device: torch.device = next(self.base_model.parameters()).device
        self.tp_size: int = tp_size
        self.tp_rank: int = tp_rank

        # LoRA backend for running sgemm kernels
        logger.info(f"Using {lora_backend} as backend of LoRA kernels.")

@@ -110,7 +115,13 @@ class LoRAManager:
    def init_lora_memory_pool(self):
        # Initialize memory pool
        self.memory_pool = LoRAMemoryPool(
            self.base_hf_config, self.max_loras_per_batch, self.max_lora_dim, self.dtype
            self.base_hf_config,
            self.max_loras_per_batch,
            self.max_lora_dim,
            self.dtype,
            self.tp_size,
            self.tp_rank,
            self.lora_modules,
        )

        # Initialize target lora modules in memory pool
@@ -131,12 +142,12 @@ class LoRAManager:
        seg_lens = (
            forward_batch.extend_seq_lens
            if forward_batch.forward_mode.is_extend()
            else torch.ones(bs, device="cuda")
            else torch.ones(bs, device=self.device)
        )
        seg_indptr = torch.zeros((bs + 1,), dtype=torch.int32, device="cuda")
        seg_indptr = torch.zeros((bs + 1,), dtype=torch.int32, device=self.device)
        seg_indptr[1:] = torch.cumsum(seg_lens, dim=0)
        max_len = int(torch.max(seg_lens))
        weight_indices = torch.empty((bs,), dtype=torch.int64, device="cuda")
        weight_indices = torch.empty((bs,), dtype=torch.int64, device=self.device)
        for i, lora_path in enumerate(forward_batch.lora_paths):
            weight_indices[i] = self.memory_pool.get_buffer_id(lora_path)
@@ -150,22 +161,32 @@ class LoRAManager:
        self.lora_backend.set_batch_info(batch_info)

        # call set_lora_info for each lora module
        for module_name, module in self.lora_modules:
            layer_id = get_layer_id(module_name)
            if "qkv_proj" not in module_name:
                weight_name = get_weight_name(
                    module_name, self.lora_weight_names, LoRAType.LORA_A
                )
                module.set_lora_info(
                    self.memory_pool.get_tensor(weight_name, layer_id, LoRAType.LORA_A),
                    self.memory_pool.get_tensor(weight_name, layer_id, LoRAType.LORA_B),
                )
            else:
                module.set_lora_info(
                    self.memory_pool.get_tensor("qkv_proj", layer_id, LoRAType.LORA_A),
                    self.memory_pool.get_tensor("q_proj", layer_id, LoRAType.LORA_B),
                    self.memory_pool.get_tensor("kv_proj", layer_id, LoRAType.LORA_B),
                )
        for layer_id, modules in self.lora_modules.items():
            for module_name, module in modules:
                if "qkv_proj" in module_name:
                    module.set_lora_info(
                        self.memory_pool.get_tensor(
                            "qkv_proj", layer_id, LoRAType.LORA_A
                        ),
                        self.memory_pool.get_tensor(
                            "q_proj", layer_id, LoRAType.LORA_B
                        ),
                        self.memory_pool.get_tensor(
                            "kv_proj", layer_id, LoRAType.LORA_B
                        ),
                    )
                else:
                    weight_name = get_weight_name(
                        module_name, self.lora_weight_names, LoRAType.LORA_A
                    )
                    module.set_lora_info(
                        self.memory_pool.get_tensor(
                            weight_name, layer_id, LoRAType.LORA_A
                        ),
                        self.memory_pool.get_tensor(
                            weight_name, layer_id, LoRAType.LORA_B
                        ),
                    )

    def set_lora_module(self, module_name, module):
        lora_module = get_lora_layer(

@@ -182,10 +203,13 @@ class LoRAManager:
        )

        # Monkey patch to use the LoRA version layers
        self.lora_modules: List[Tuple[str, torch.nn.Module]] = []
        self.lora_modules: Dict[int, List[Tuple[str, BaseLayerWithLoRA]]] = {
            i: [] for i in range(self.base_hf_config.num_hidden_layers)
        }
        for module_name, module in self.base_model.named_modules():
            # The module should be converted if it is included in target_names
            if module_name.split(".")[-1] in customized_target_names:
                self.lora_modules.append(
                layer_id = get_layer_id(module_name)
                self.lora_modules[layer_id].append(
                    (module_name, self.set_lora_module(module_name, module))
                )
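
A minimal sketch (not part of the diff) of the per-layer grouping above: get_layer_id presumably extracts the layer index from the module path, so lora_modules can be keyed by layer; the regex-based helper here is hypothetical.

    import re

    def layer_id_from_name(module_name: str) -> int:
        match = re.search(r"layers\.(\d+)\.", module_name)
        assert match is not None, f"no layer id found in {module_name}"
        return int(match.group(1))

    print(layer_id_from_name("model.layers.3.self_attn.qkv_proj"))  # 3
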
@@ -2,9 +2,12 @@ from typing import Dict, List, Optional, Set, Tuple

import torch

from sglang.srt.distributed import divide
from sglang.srt.hf_transformers_utils import AutoConfig
from sglang.srt.lora.layers import BaseLayerWithLoRA
from sglang.srt.lora.lora import LoRAAdapter
from sglang.srt.lora.utils import (
    ROW_PARALLELISM_LINEAR_LORA_NAMES,
    LoRAType,
    get_hidden_dim,
    get_stacked_multiply,

@@ -21,6 +24,9 @@ class LoRAMemoryPool:
        max_loras_per_batch: int,
        max_lora_dim: int,
        dtype: torch.dtype,
        tp_size: int,
        tp_rank: int,
        lora_modules: Dict[int, List[Tuple[str, BaseLayerWithLoRA]]],
    ):

        self.base_hf_config: AutoConfig = base_hf_config

@@ -28,6 +34,9 @@ class LoRAMemoryPool:
        self.max_loras_per_batch: int = max_loras_per_batch
        self.max_lora_dim: int = max_lora_dim
        self.dtype: torch.dtype = dtype
        self.tp_size: int = tp_size
        self.tp_rank: int = tp_rank
        self.lora_modules: Dict[int, List[Tuple[str, BaseLayerWithLoRA]]] = lora_modules

        # Both A_buffer and B_buffer map lora weight names to their buffer space.
        # A_buffer contains num_layer number of row-major tensors with shape

@@ -45,6 +54,41 @@ class LoRAMemoryPool:
        # Here we don't initialize to None since None is a valid uid
        self.buffer_id_to_uid: List[Optional[str]] = [""] * self.max_loras_per_batch

    def get_lora_A_shape(
        self, module_name: str, base_model: torch.nn.Module
    ) -> Tuple[int]:
        """
        Given a module_name (which might be a stacked name), return the shape of its LoRA A buffer.
        """
        input_dim, _ = get_hidden_dim(module_name, self.base_hf_config, base_model)
        c = get_stacked_multiply(module_name)
        if self.tp_size > 1:
            if module_name in ROW_PARALLELISM_LINEAR_LORA_NAMES:
                input_dim = divide(input_dim, self.tp_size)
        return (
            self.max_loras_per_batch,
            self.max_lora_dim * c,
            input_dim,
        )

    def get_lora_B_shape(
        self, module_name: str, base_model: torch.nn.Module
    ) -> Tuple[int]:
        """
        Given a module_name (which might be a stacked name), return the shape of its LoRA B buffer.
        """
        _, output_dim = get_hidden_dim(module_name, self.base_hf_config, base_model)
        c = get_stacked_multiply(module_name)
        if self.tp_size > 1:
            if module_name not in ROW_PARALLELISM_LINEAR_LORA_NAMES:
                output_dim = divide(output_dim, self.tp_size)
        return (
            c,
            self.max_loras_per_batch,
            output_dim,
            self.max_lora_dim,
        )
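
A minimal sketch (not part of the diff) of the shapes these two helpers produce, using hypothetical sizes (hidden 4096, intermediate 11008, rank 16, 4 adapters, tp_size 2):

    max_loras_per_batch, max_lora_dim, tp_size = 4, 16, 2
    hidden, intermediate = 4096, 11008

    # o_proj is row-parallel, so the LoRA A input dim is divided across ranks.
    lora_A_shape_o_proj = (max_loras_per_batch, max_lora_dim, hidden // tp_size)

    # gate_up_proj stacks two modules (c = 2) and is column-parallel, so the
    # LoRA B output dim is divided across ranks.
    lora_B_shape_gate_up = (2, max_loras_per_batch, intermediate // tp_size, max_lora_dim)
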
    def init_buffers(
        self,
        lora_weight_names: Set[Tuple[str]],

@@ -54,42 +98,31 @@ class LoRAMemoryPool:
        # lora_weight_names is a set of name pairs indicating each pair of lora modules to load
        # e.g., {("qkv_proj", "q_proj"), ("qkv_proj", "kv_proj"), ("o_proj", "o_proj")}
        self.lora_weight_names: Set[Tuple[str]] = lora_weight_names

        for module_A, module_B in lora_weight_names:
            # Init A tensor, column_major=False
            input_dim, _ = get_hidden_dim(module_A, self.base_hf_config, base_model)
            c = get_stacked_multiply(module_A)
            if module_A not in self.A_buffer:
                self.A_buffer[module_A] = [
                    torch.empty(
                        (
                            self.max_loras_per_batch,
                            self.max_lora_dim * c,
                            input_dim,
                        ),
                        dtype=self.dtype,
                        device="cuda",
                    )
                    for i in range(self.num_layer)
                ]

            # Init B tensor, column_major=True
            _, output_dim = get_hidden_dim(module_B, self.base_hf_config, base_model)
            c = get_stacked_multiply(module_B)
            if module_B not in self.B_buffer:
                self.B_buffer[module_B] = [
                    torch.empty(
                        (
                            c,  # stacked lora_b modules might need separation
                            self.max_loras_per_batch,
                            output_dim,
                            self.max_lora_dim,
                        ),
                        dtype=self.dtype,
                        device="cuda",
                    )
                    for i in range(self.num_layer)
                ]
        device = next(base_model.parameters()).device
        lora_module_A_names = set([name[0] for name in lora_weight_names])
        lora_module_B_names = set([name[1] for name in lora_weight_names])
        # Init A tensor, column_major=False
        for module_A in lora_module_A_names:
            lora_A_shape = self.get_lora_A_shape(module_A, base_model)
            self.A_buffer[module_A] = [
                torch.empty(
                    lora_A_shape,
                    dtype=self.dtype,
                    device=device,
                )
                for i in range(self.num_layer)
            ]
        # Init B tensor, column_major=True
        for module_B in lora_module_B_names:
            lora_B_shape = self.get_lora_B_shape(module_B, base_model)
            self.B_buffer[module_B] = [
                torch.empty(
                    lora_B_shape,
                    dtype=self.dtype,
                    device=device,
                )
                for _ in range(self.num_layer)
            ]
    def prepare_lora_batch(
        self,

@@ -136,30 +169,56 @@ class LoRAMemoryPool:
        assert lora_adapter is not None
        for layer_id in range(self.num_layer):
            layer_weights = lora_adapter.layers[layer_id].weights
            temp_A_buffer: Dict[str, torch.Tensor] = {}
            temp_B_buffer: Dict[str, torch.Tensor] = {}
            for name, weights in layer_weights.items():
                if "lora_A" in name:
                    lora_weight_name = get_weight_name(
                        name, self.lora_weight_names, LoRAType.LORA_A
                    )
                    if lora_weight_name:
                        self.A_buffer[lora_weight_name][layer_id][buffer_id].copy_(
                            weights
                        )
                    temp_A_buffer[lora_weight_name] = weights
                else:
                    lora_weight_name = get_weight_name(
                        name, self.lora_weight_names, LoRAType.LORA_B
                    )
                    if lora_weight_name:
                        c = get_stacked_multiply(lora_weight_name)
                        if c > 1:
                            for stacked_id in range(c):
                                self.B_buffer[lora_weight_name][layer_id][stacked_id][
                                    buffer_id
                                ].copy_(weights[stacked_id])
                        else:
                            self.B_buffer[lora_weight_name][layer_id][0][
                                buffer_id
                            ].copy_(weights)
                    temp_B_buffer[lora_weight_name] = weights

            if self.tp_size > 1:
                cur_layer_modules = self.lora_modules[layer_id]
                for module_name, module in cur_layer_modules:
                    if "qkv_proj" in module_name:
                        temp_A_buffer["qkv_proj"] = module.slice_lora_a_weights(
                            temp_A_buffer["qkv_proj"], self.tp_rank
                        )
                        temp_B_buffer["q_proj"], temp_B_buffer["kv_proj"] = (
                            module.slice_lora_b_weights(
                                [temp_B_buffer["q_proj"], temp_B_buffer["kv_proj"]],
                                self.tp_rank,
                            )
                        )
                    else:
                        weight_name = get_weight_name(
                            module_name, self.lora_weight_names, LoRAType.LORA_A
                        )
                        temp_A_buffer[weight_name] = module.slice_lora_a_weights(
                            temp_A_buffer[weight_name], self.tp_rank
                        )
                        temp_B_buffer[weight_name] = module.slice_lora_b_weights(
                            temp_B_buffer[weight_name], self.tp_rank
                        )

            for name, weights in temp_A_buffer.items():
                self.A_buffer[name][layer_id][buffer_id].copy_(weights)

            for name, weights in temp_B_buffer.items():
                c = get_stacked_multiply(name)
                if c > 1:
                    for stacked_id in range(c):
                        self.B_buffer[name][layer_id][stacked_id][buffer_id].copy_(
                            weights[stacked_id]
                        )
                else:
                    self.B_buffer[name][layer_id][0][buffer_id].copy_(weights)

    def get_tensor(
        self, weight_name: str, layer_id: int, lora_type: LoRAType
@@ -133,9 +133,20 @@ def get_weight_name(
    target_name is the name of a given module,
    lora_weight_names is a set of lora stacked name pairs (see get_stacked_name method above)
    If there is a weight name in lora_weight_names that can match target_name, return this name
    Else return None
    Else raise ValueError.
    """
    idx = 0 if lora_type == LoRAType.LORA_A else 1
    for weight_name_pair in lora_weight_names:
        if weight_name_pair[idx] in target_name:
            return weight_name_pair[idx]
    raise ValueError(
        f"Cannot find weight name for {target_name} in {lora_weight_names}"
    )


# TODO: [PR #4274] For future use to simplify the mapping between HF module names and customized module names.
VOCAB_PARALLELISM_EMBEDDING_NAMES = ["embeddings"]
COLUMN_PARALLELISM_LINEAR_LORA_NAMES = ["gate_proj", "up_proj"]
MERGED_COLUMN_PARALLELISM_LINEAR_LORA_NAMES = ["gate_up_proj"]
QKV_PARALLELISM_LINEAR_LORA_NAMES = ["qkv_proj"]
ROW_PARALLELISM_LINEAR_LORA_NAMES = ["o_proj", "down_proj"]
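
A minimal sketch (not part of the diff) of how the name lists above drive the slicing direction used elsewhere in this PR; the helper is hypothetical.

    def tp_slice_axis(module_name: str) -> str:
        if module_name in ROW_PARALLELISM_LINEAR_LORA_NAMES:
            return "input"   # LoRA A is sliced along the input dimension
        return "output"      # column-parallel and QKV: LoRA B is sliced along the output dimension

    print(tp_slice_axis("o_proj"), tp_slice_axis("qkv_proj"))  # input output
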
@@ -188,9 +188,6 @@ class ModelRunner:
        supports_torch_tp = getattr(self.model, "supports_torch_tp", False)
        if self.tp_size > 1 and supports_torch_tp:
            self.apply_torch_tp()
            self.torch_tp_applied = True
        else:
            self.torch_tp_applied = False

        # Init lora
        if server_args.lora_paths is not None:

@@ -624,6 +621,8 @@ class ModelRunner:
            load_config=self.load_config,
            dtype=self.dtype,
            lora_backend=self.server_args.lora_backend,
            tp_size=self.tp_size,
            tp_rank=self.tp_rank,
        )
        logger.info("LoRA manager ready.")

@@ -257,7 +257,7 @@ def get_available_gpu_memory(device, gpu_id, distributed=False, empty_cache=True
    When distributed is True, the available memory is the minimum available memory of all GPUs.
    """
    if device == "cuda":
        num_gpus = torch.cuda.device_count()
        num_gpus = cuda_device_count_stateless()
        assert gpu_id < num_gpus

        if torch.cuda.current_device() != gpu_id:

@@ -437,6 +437,7 @@ class SRTRunner:
        speculative_eagle_topk: Optional[int] = None,
        speculative_num_draft_tokens: Optional[int] = None,
        disable_overlap_schedule: bool = False,
        disable_custom_all_reduce: bool = False,
    ):
        self.model_type = model_type
        self.is_generation = model_type == "generation"

@@ -470,6 +471,7 @@ class SRTRunner:
            enable_ep_moe=enable_ep_moe,
            disable_overlap_schedule=disable_overlap_schedule,
            cuda_graph_max_bs=4,
            disable_custom_all_reduce=disable_custom_all_reduce,
            **spec_kwargs,
        )