Revert "fix some typos" (#6244)
This commit is contained in:
@@ -22,13 +22,13 @@ class LoRABatchInfo:
|
||||
# Maximum sequence length of current batch
|
||||
max_len: int
|
||||
|
||||
# The index of LoRA adapter used by each sequence, in shape (bs,)
|
||||
# The index of lora adapter used by each sequence, in shape (bs,)
|
||||
weight_indices: torch.Tensor
|
||||
|
||||
# ranks of each LoRA adapter, in shape (lora_num,)
|
||||
# ranks of each lora adapter, in shape (lora_num,)
|
||||
lora_ranks: torch.Tensor
|
||||
|
||||
# scaling of each LoRA adapter, in shape (lora_num,)
|
||||
# scaling of each lora adapter, in shape (lora_num,)
|
||||
scalings: torch.Tensor
|
||||
|
||||
|
||||
@@ -51,9 +51,9 @@ def get_customized_names_from_hf_names(
|
||||
hf_module_names: Set[str], base_model: torch.nn.Module
|
||||
) -> Set[str]:
|
||||
"""
|
||||
This function takes in a set of HuggingFace style module names:
|
||||
This function takes in a set of huggingface style module names:
|
||||
e.g., {"k_proj", "q_proj", "v_proj", "o_proj"}
|
||||
and outputs a set of module names of customized SGLang layers:
|
||||
and outputs a set of module names of customized sglang layers:
|
||||
e.g., {"qkv_proj", "o_proj"}
|
||||
"""
|
||||
if hasattr(base_model, "get_module_name"):
|
||||
@@ -87,7 +87,7 @@ def get_hidden_dim(
|
||||
else:
|
||||
"""
|
||||
WARNING: get_hidden_dim() is not defined,
|
||||
which is used to get the hidden dim for different LoRA modules
|
||||
which is used to get the hidden dim for different lora modules
|
||||
Use the default one, but please check if it is correct for your model.
|
||||
Please implement the function in the model class if it is not.
|
||||
You can reference this function in llama.py.
|
||||
@@ -108,7 +108,7 @@ def get_hidden_dim(
|
||||
|
||||
def get_stacked_name(name: str) -> Tuple[str]:
|
||||
"""
|
||||
Mapping a target LoRA module name to (stacked name for LoRA A, stacked name for LoRA B)
|
||||
Mapping a target module name to (stacked name for Lora A, stacked name for Lora B)
|
||||
"""
|
||||
params_mapping = {
|
||||
"q_proj": ("qkv_proj", "q_proj"),
|
||||
@@ -122,7 +122,7 @@ def get_stacked_name(name: str) -> Tuple[str]:
|
||||
|
||||
def get_stacked_multiply(module_name: str) -> int:
|
||||
"""
|
||||
Mapping a module name to its magnification at output dimension
|
||||
Mapping a lora module name to its magnification at output dimension
|
||||
"""
|
||||
stacked_rank = {
|
||||
"qkv_proj": 3,
|
||||
@@ -137,7 +137,7 @@ def get_weight_name(
|
||||
) -> Optional[str]:
|
||||
"""
|
||||
target_name is name of a given module,
|
||||
lora_weight_names is a set of LoRA stacked name pairs (see get_stacked_name method above)
|
||||
lora_weight_names is a set of lora stacked name pairs (see get_stacked_name method above)
|
||||
If there is a weight name in lora_weight_names that can match target_name, return this name
|
||||
Else raise ValueError.
|
||||
"""
|
||||
|
||||
Reference in New Issue
Block a user