[Fix] Fix bugs and refactor code in LoRA for better scalability. (#3652)
Co-authored-by: ShenAo1111 <1377693092@qq.com> Co-authored-by: zhaochenyang20 <zhaochen20@outlook.com>
This commit is contained in:
@@ -18,6 +18,7 @@
|
||||
# LoRA layers class inheritance adapted from:
|
||||
# https://github.com/vllm-project/vllm/blob/4abf6336ec65c270343eb895e7b18786e9274176/vllm/lora/layers.py
|
||||
|
||||
import logging
|
||||
import re
|
||||
from typing import Dict, List
|
||||
|
||||
@@ -30,6 +31,8 @@ from sglang.srt.lora.backend import BaseLoRABackend
|
||||
from sglang.srt.lora.lora_config import LoRAConfig
|
||||
from sglang.srt.model_loader.loader import DefaultModelLoader
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class LoRALayer(nn.Module):
|
||||
def __init__(self, config: LoRAConfig, base_hf_config: AutoConfig):
|
||||
@@ -173,6 +176,18 @@ class LoRAAdapter(nn.Module):
|
||||
if "gate_proj" in weight_name:
|
||||
up_name = weight_name.replace("gate_proj", "up_proj")
|
||||
gate_up_name = weight_name.replace("gate_proj", "gate_up_proj")
|
||||
if up_name not in weights:
|
||||
logger.warning(
|
||||
f"Gate projection {weight_name} does not have a corresponding up projection {up_name}. "
|
||||
f"Initializing up projection to zero."
|
||||
)
|
||||
weights[up_name] = torch.zeros_like(weights[weight_name])
|
||||
# FIXME: Add gate-only support for flashinfer in future implementations
|
||||
assert self.lora_backend.name == "triton", (
|
||||
f"LoRA weight initialization currently only supported for 'triton' backend. "
|
||||
f"Received backend: {self.lora_backend.name}. Please verify your backend configuration "
|
||||
f"or consider implementing custom initialization logic for other backends."
|
||||
)
|
||||
if "lora_A" in weight_name:
|
||||
weights[gate_up_name] = torch.cat(
|
||||
(weights[weight_name], weights[up_name]), 0
|
||||
@@ -182,4 +197,5 @@ class LoRAAdapter(nn.Module):
|
||||
[weights[weight_name], weights[up_name]], dim=0
|
||||
)
|
||||
weights.pop(weight_name)
|
||||
weights.pop(up_name)
|
||||
if up_name in weights:
|
||||
weights.pop(up_name)
|
||||
|
||||
@@ -26,6 +26,11 @@ class LoRAConfig:
|
||||
self.path = path
|
||||
self.hf_config = self.get_lora_config()
|
||||
self.target_modules = self.hf_config["target_modules"]
|
||||
|
||||
# TODO: Support more modules
|
||||
if any(module in self.target_modules for module in ["embed_tokens", "lm_head"]):
|
||||
raise ValueError("Not supported yet")
|
||||
|
||||
self.r = self.hf_config["r"]
|
||||
self.lora_alpha = self.hf_config["lora_alpha"]
|
||||
|
||||
|
||||
@@ -76,9 +76,7 @@ class LoRAManager:
|
||||
self.hf_target_names: Set[str] = set()
|
||||
for name, path in self.lora_paths.items():
|
||||
self.configs[name] = LoRAConfig(path)
|
||||
self.hf_target_names = set(self.hf_target_names) | set(
|
||||
self.configs[name].target_modules
|
||||
)
|
||||
self.hf_target_names.update(self.configs[name].target_modules)
|
||||
|
||||
# Target lora weight names for lora_a and lora_b modules respectively.
|
||||
# e.g., {("qkv_proj", "q_proj"), ("qkv_proj", "kv_proj")}
|
||||
|
||||
@@ -189,9 +189,17 @@ class HFRunner:
|
||||
return_dict_in_generate=True,
|
||||
output_scores=(not self.output_str_only),
|
||||
)
|
||||
output_strs.append(
|
||||
self.tokenizer.decode(outputs[0][0][len(input_ids[0]) :])
|
||||
|
||||
text = self.tokenizer.decode(
|
||||
outputs[0][0][len(input_ids[0]) :], skip_special_tokens=True
|
||||
)
|
||||
# Check if the text is empty or only whitespace.
|
||||
if not text.strip():
|
||||
raise ValueError(
|
||||
"Received an empty text response. Please verify your input or model configuration."
|
||||
)
|
||||
output_strs.append(text)
|
||||
|
||||
if not self.output_str_only:
|
||||
# outputs.scores: (num_token, 1, vocab_size)
|
||||
top_output_logprobs.append(
|
||||
@@ -275,6 +283,7 @@ class SRTRunner:
|
||||
lora_backend: str = "triton",
|
||||
disable_cuda_graph: bool = False,
|
||||
disable_radix_cache: bool = False,
|
||||
mem_fraction_static: float = 0.65,
|
||||
):
|
||||
self.model_type = model_type
|
||||
self.is_generation = model_type == "generation"
|
||||
@@ -283,7 +292,7 @@ class SRTRunner:
|
||||
tp_size=tp_size,
|
||||
dtype=get_dtype_str(torch_dtype),
|
||||
port=port,
|
||||
mem_fraction_static=0.65,
|
||||
mem_fraction_static=mem_fraction_static,
|
||||
trust_remote_code=False,
|
||||
is_embedding=not self.is_generation,
|
||||
lora_paths=lora_paths,
|
||||
@@ -315,7 +324,15 @@ class SRTRunner:
|
||||
logprob_start_len=0,
|
||||
top_logprobs_num=NUM_TOP_LOGPROBS,
|
||||
)
|
||||
output_strs.append(response["text"])
|
||||
text = response["text"]
|
||||
|
||||
# Check if the text is empty or only whitespace.
|
||||
if not text.strip():
|
||||
raise ValueError(
|
||||
"Received an empty text response. Please verify your input or model configuration."
|
||||
)
|
||||
output_strs.append(text)
|
||||
|
||||
top_input_logprobs.append(
|
||||
[
|
||||
[tup[0] for tup in x[:NUM_TOP_LOGPROBS]]
|
||||
|
||||
Reference in New Issue
Block a user