Improve LoRA Perf by Deprecating FlashInfer and Eliminating Redundant Tensor Ops (#8940)
This commit is contained in:
@@ -47,34 +47,6 @@ def get_layer_id(name: str) -> int:
|
||||
return int(match.group(1))
|
||||
|
||||
|
||||
def get_customized_names_from_hf_names(
|
||||
hf_module_names: Set[str], base_model: torch.nn.Module
|
||||
) -> Set[str]:
|
||||
"""
|
||||
This function takes in a set of huggingface style module names:
|
||||
e.g., {"k_proj", "q_proj", "v_proj", "o_proj"}
|
||||
and outputs a set of module names of customized sglang layers:
|
||||
e.g., {"qkv_proj", "o_proj"}
|
||||
"""
|
||||
if hasattr(base_model, "get_module_name"):
|
||||
return {base_model.get_module_name(name) for name in hf_module_names}
|
||||
else:
|
||||
"""
|
||||
Fallback solution of mapping from config module name to module name in model class.
|
||||
Please check if it aligns with your base model.
|
||||
Please implement the function in the model class if it is not.
|
||||
You can reference this function in llama.py.
|
||||
"""
|
||||
params_mapping = {
|
||||
"q_proj": "qkv_proj",
|
||||
"k_proj": "qkv_proj",
|
||||
"v_proj": "qkv_proj",
|
||||
"gate_proj": "gate_up_proj",
|
||||
"up_proj": "gate_up_proj",
|
||||
}
|
||||
return {params_mapping.get(name, name) for name in hf_module_names}
|
||||
|
||||
|
||||
def get_hidden_dim(
|
||||
module_name: str, config: AutoConfig, base_model: torch.nn.Module
|
||||
) -> Tuple[int]:
|
||||
@@ -95,22 +67,9 @@ def get_hidden_dim(
|
||||
head_dim = getattr(
|
||||
config, "head_dim", config.hidden_size // config.num_attention_heads
|
||||
)
|
||||
|
||||
# TODO: the special handling of qkv will be addressed in #8940.
|
||||
if module_name == "qkv_proj":
|
||||
return (
|
||||
config.hidden_size,
|
||||
None, # qkv_proj is only used in LoRA A
|
||||
)
|
||||
elif module_name == "kv_proj":
|
||||
return (
|
||||
None, # kv_proj is only used in LoRA B
|
||||
head_dim * config.num_key_value_heads,
|
||||
)
|
||||
elif module_name == "q_proj":
|
||||
return (
|
||||
None, # q_proj is only used in LoRA B
|
||||
head_dim * config.num_attention_heads,
|
||||
return config.hidden_size, head_dim * (
|
||||
config.num_attention_heads + config.num_key_value_heads * 2
|
||||
)
|
||||
elif module_name == "o_proj":
|
||||
return (
|
||||
@@ -118,7 +77,7 @@ def get_hidden_dim(
|
||||
config.hidden_size,
|
||||
)
|
||||
elif module_name == "gate_up_proj":
|
||||
return config.hidden_size, config.intermediate_size
|
||||
return config.hidden_size, config.intermediate_size * 2
|
||||
elif module_name == "down_proj":
|
||||
return config.intermediate_size, config.hidden_size
|
||||
else:
|
||||
@@ -127,26 +86,22 @@ def get_hidden_dim(
|
||||
|
||||
def get_normalized_lora_weight_names(
|
||||
target_modules: Iterable[str],
|
||||
) -> Tuple[set[str], set[str]]:
|
||||
) -> set[str]:
|
||||
"""
|
||||
Mapping a list of target module name to names of the normalized LoRA weights.
|
||||
Returned tuple contains (name for Lora A, name for Lora B)
|
||||
"""
|
||||
params_mapping = {
|
||||
"q_proj": (["qkv_proj"], ["q_proj"]),
|
||||
"k_proj": (["qkv_proj"], ["kv_proj"]),
|
||||
"v_proj": (["qkv_proj"], ["kv_proj"]),
|
||||
"gate_proj": (["gate_up_proj"], ["gate_up_proj"]),
|
||||
"up_proj": (["gate_up_proj"], ["gate_up_proj"]),
|
||||
"qkv_proj": (["qkv_proj"], ["q_proj", "kv_proj"]),
|
||||
"gate_up_proj": (["gate_up_proj"], ["gate_up_proj"]),
|
||||
"q_proj": "qkv_proj",
|
||||
"k_proj": "qkv_proj",
|
||||
"v_proj": "qkv_proj",
|
||||
"gate_proj": "gate_up_proj",
|
||||
"up_proj": "gate_up_proj",
|
||||
}
|
||||
|
||||
result = (set(), set())
|
||||
result = set()
|
||||
for name in target_modules:
|
||||
lora_a, lora_b = params_mapping.get(name, ([name], [name]))
|
||||
result[0].update(lora_a)
|
||||
result[1].update(lora_b)
|
||||
weight_name = params_mapping.get(name, name)
|
||||
result.add(weight_name)
|
||||
return result
|
||||
|
||||
|
||||
@@ -156,23 +111,21 @@ def get_stacked_multiply(module_name: str) -> int:
|
||||
"""
|
||||
stacked_rank = {
|
||||
"qkv_proj": 3,
|
||||
"kv_proj": 2,
|
||||
"gate_up_proj": 2,
|
||||
}
|
||||
return stacked_rank[module_name] if module_name in stacked_rank else 1
|
||||
|
||||
|
||||
def get_weight_name(
|
||||
target_name: str, lora_weight_names: Tuple[Set[str]], lora_type: LoRAType
|
||||
target_name: str, lora_weight_names: Tuple[Set[str]]
|
||||
) -> Optional[str]:
|
||||
"""
|
||||
target_name is name of a given module,
|
||||
lora_weight_names is a set of lora stacked name pairs (see get_stacked_name method above)
|
||||
Get the weight name in lora_weight_names that can match target_name.
|
||||
|
||||
If there is a weight name in lora_weight_names that can match target_name, return this name
|
||||
Else raise ValueError.
|
||||
"""
|
||||
idx = 0 if lora_type == LoRAType.LORA_A else 1
|
||||
for weight_name in lora_weight_names[idx]:
|
||||
for weight_name in lora_weight_names:
|
||||
if weight_name in target_name:
|
||||
return weight_name
|
||||
raise ValueError(
|
||||
@@ -180,9 +133,4 @@ def get_weight_name(
|
||||
)
|
||||
|
||||
|
||||
# TODO: [PR #4274] For future use to simplify the mapping between HF module names and customized module names.
|
||||
VOCAB_PARALLELISM_EMBEDDING_NAMES = ["embeddings"]
|
||||
COLUMN_PARALLELISM_LINEAR_LORA_NAMES = ["gate_proj", "up_proj"]
|
||||
MERGED_COLUMN_PARALLELISM_LINEAR_LORA_NAMES = ["gate_up_proj"]
|
||||
QKV_PARALLELISM_LINEAR_LORA_NAMES = ["qkv_proj"]
|
||||
ROW_PARALLELISM_LINEAR_LORA_NAMES = ["o_proj", "down_proj"]
|
||||
|
||||
Reference in New Issue
Block a user