[Fix] Fix bugs and refactor code in lora for better scalability. (#3652)

Co-authored-by: ShenAo1111 <1377693092@qq.com>
Co-authored-by: zhaochenyang20 <zhaochen20@outlook.com>
This commit is contained in:
aoshen524
2025-02-20 14:51:57 -05:00
committed by GitHub
parent ac05310098
commit e79f7420be
11 changed files with 459 additions and 200 deletions

View File

@@ -18,6 +18,7 @@
# LoRA layers class inheritance adapted from:
# https://github.com/vllm-project/vllm/blob/4abf6336ec65c270343eb895e7b18786e9274176/vllm/lora/layers.py
import logging
import re
from typing import Dict, List
@@ -30,6 +31,8 @@ from sglang.srt.lora.backend import BaseLoRABackend
from sglang.srt.lora.lora_config import LoRAConfig
from sglang.srt.model_loader.loader import DefaultModelLoader
logger = logging.getLogger(__name__)
class LoRALayer(nn.Module):
def __init__(self, config: LoRAConfig, base_hf_config: AutoConfig):
@@ -173,6 +176,18 @@ class LoRAAdapter(nn.Module):
if "gate_proj" in weight_name:
up_name = weight_name.replace("gate_proj", "up_proj")
gate_up_name = weight_name.replace("gate_proj", "gate_up_proj")
if up_name not in weights:
logger.warning(
f"Gate projection {weight_name} does not have a corresponding up projection {up_name}. "
f"Initializing up projection to zero."
)
weights[up_name] = torch.zeros_like(weights[weight_name])
# FIXME: Add gate-only support for flashinfer in future implementations
assert self.lora_backend.name == "triton", (
f"LoRA weight initialization currently only supported for 'triton' backend. "
f"Received backend: {self.lora_backend.name}. Please verify your backend configuration "
f"or consider implementing custom initialization logic for other backends."
)
if "lora_A" in weight_name:
weights[gate_up_name] = torch.cat(
(weights[weight_name], weights[up_name]), 0
@@ -182,4 +197,5 @@ class LoRAAdapter(nn.Module):
[weights[weight_name], weights[up_name]], dim=0
)
weights.pop(weight_name)
weights.pop(up_name)
if up_name in weights:
weights.pop(up_name)

View File

@@ -26,6 +26,11 @@ class LoRAConfig:
self.path = path
self.hf_config = self.get_lora_config()
self.target_modules = self.hf_config["target_modules"]
# TODO: Support more modules
if any(module in self.target_modules for module in ["embed_tokens", "lm_head"]):
raise ValueError("Not supported yet")
self.r = self.hf_config["r"]
self.lora_alpha = self.hf_config["lora_alpha"]

View File

@@ -76,9 +76,7 @@ class LoRAManager:
self.hf_target_names: Set[str] = set()
for name, path in self.lora_paths.items():
self.configs[name] = LoRAConfig(path)
self.hf_target_names = set(self.hf_target_names) | set(
self.configs[name].target_modules
)
self.hf_target_names.update(self.configs[name].target_modules)
    # Target lora weight names for lora_a and lora_b modules respectively.
# e.g., {("qkv_proj", "q_proj"), ("qkv_proj", "kv_proj")}