Improve LoRA Perf by Deprecating FlashInfer and Eliminating Redundant Tensor Ops (#8940)

This commit is contained in:
Lifu Huang
2025-08-10 01:04:45 -07:00
committed by GitHub
parent 6b847a9a05
commit f8a173bb50
10 changed files with 137 additions and 525 deletions

View File

@@ -47,34 +47,6 @@ def get_layer_id(name: str) -> int:
return int(match.group(1))
def get_customized_names_from_hf_names(
hf_module_names: Set[str], base_model: torch.nn.Module
) -> Set[str]:
"""
This function takes in a set of huggingface style module names:
e.g., {"k_proj", "q_proj", "v_proj", "o_proj"}
and outputs a set of module names of customized sglang layers:
e.g., {"qkv_proj", "o_proj"}
"""
if hasattr(base_model, "get_module_name"):
return {base_model.get_module_name(name) for name in hf_module_names}
else:
"""
Fallback solution of mapping from config module name to module name in model class.
Please check if it aligns with your base model.
Please implement the function in the model class if it is not.
You can reference this function in llama.py.
"""
params_mapping = {
"q_proj": "qkv_proj",
"k_proj": "qkv_proj",
"v_proj": "qkv_proj",
"gate_proj": "gate_up_proj",
"up_proj": "gate_up_proj",
}
return {params_mapping.get(name, name) for name in hf_module_names}
def get_hidden_dim(
module_name: str, config: AutoConfig, base_model: torch.nn.Module
) -> Tuple[int]:
@@ -95,22 +67,9 @@ def get_hidden_dim(
head_dim = getattr(
config, "head_dim", config.hidden_size // config.num_attention_heads
)
# TODO: the special handling of qkv will be addressed in #8940.
if module_name == "qkv_proj":
return (
config.hidden_size,
None, # qkv_proj is only used in LoRA A
)
elif module_name == "kv_proj":
return (
None, # kv_proj is only used in LoRA B
head_dim * config.num_key_value_heads,
)
elif module_name == "q_proj":
return (
None, # q_proj is only used in LoRA B
head_dim * config.num_attention_heads,
return config.hidden_size, head_dim * (
config.num_attention_heads + config.num_key_value_heads * 2
)
elif module_name == "o_proj":
return (
@@ -118,7 +77,7 @@ def get_hidden_dim(
config.hidden_size,
)
elif module_name == "gate_up_proj":
return config.hidden_size, config.intermediate_size
return config.hidden_size, config.intermediate_size * 2
elif module_name == "down_proj":
return config.intermediate_size, config.hidden_size
else:
@@ -127,26 +86,22 @@ def get_hidden_dim(
def get_normalized_lora_weight_names(
target_modules: Iterable[str],
) -> Tuple[set[str], set[str]]:
) -> set[str]:
"""
Mapping a list of target module name to names of the normalized LoRA weights.
Returned tuple contains (name for Lora A, name for Lora B)
"""
params_mapping = {
"q_proj": (["qkv_proj"], ["q_proj"]),
"k_proj": (["qkv_proj"], ["kv_proj"]),
"v_proj": (["qkv_proj"], ["kv_proj"]),
"gate_proj": (["gate_up_proj"], ["gate_up_proj"]),
"up_proj": (["gate_up_proj"], ["gate_up_proj"]),
"qkv_proj": (["qkv_proj"], ["q_proj", "kv_proj"]),
"gate_up_proj": (["gate_up_proj"], ["gate_up_proj"]),
"q_proj": "qkv_proj",
"k_proj": "qkv_proj",
"v_proj": "qkv_proj",
"gate_proj": "gate_up_proj",
"up_proj": "gate_up_proj",
}
result = (set(), set())
result = set()
for name in target_modules:
lora_a, lora_b = params_mapping.get(name, ([name], [name]))
result[0].update(lora_a)
result[1].update(lora_b)
weight_name = params_mapping.get(name, name)
result.add(weight_name)
return result
@@ -156,23 +111,21 @@ def get_stacked_multiply(module_name: str) -> int:
"""
stacked_rank = {
"qkv_proj": 3,
"kv_proj": 2,
"gate_up_proj": 2,
}
return stacked_rank[module_name] if module_name in stacked_rank else 1
def get_weight_name(
target_name: str, lora_weight_names: Tuple[Set[str]], lora_type: LoRAType
target_name: str, lora_weight_names: Tuple[Set[str]]
) -> Optional[str]:
"""
target_name is name of a given module,
lora_weight_names is a set of lora stacked name pairs (see get_stacked_name method above)
Get the weight name in lora_weight_names that can match target_name.
If there is a weight name in lora_weight_names that can match target_name, return this name
Else raise ValueError.
"""
idx = 0 if lora_type == LoRAType.LORA_A else 1
for weight_name in lora_weight_names[idx]:
for weight_name in lora_weight_names:
if weight_name in target_name:
return weight_name
raise ValueError(
@@ -180,9 +133,4 @@ def get_weight_name(
)
# TODO: [PR #4274] For future use to simplify the mapping between HF module names and customized module names.
VOCAB_PARALLELISM_EMBEDDING_NAMES = ["embeddings"]
COLUMN_PARALLELISM_LINEAR_LORA_NAMES = ["gate_proj", "up_proj"]
MERGED_COLUMN_PARALLELISM_LINEAR_LORA_NAMES = ["gate_up_proj"]
QKV_PARALLELISM_LINEAR_LORA_NAMES = ["qkv_proj"]
ROW_PARALLELISM_LINEAR_LORA_NAMES = ["o_proj", "down_proj"]