Add typo checker in pre-commit (#6179)
Co-authored-by: Brayden Zhong <b8zhong@uwaterloo.ca>
This commit is contained in:
@@ -100,7 +100,7 @@ class LoRAManager:
|
||||
self.configs[name] = LoRAConfig(path)
|
||||
self.hf_target_names.update(self.configs[name].target_modules)
|
||||
|
||||
# Target lora weight names for lora_a and lora_b modules repectively.
|
||||
# Target lora weight names for lora_a and lora_b modules respectively.
|
||||
# e.g., {("qkv_proj", "q_proj"), ("qkv_proj", "kv_proj")}
|
||||
self.lora_weight_names: Set[Tuple[str]] = set(
|
||||
[get_stacked_name(module) for module in self.hf_target_names]
|
||||
|
||||
@@ -50,15 +50,15 @@ class LoRAMemoryPool:
|
||||
self.uid_to_buffer_id: Dict[Optional[str], int] = {}
|
||||
|
||||
# Buffer idx -> lora uid in memory pool
|
||||
# All uids are initalized as empty strings for empty buffer slots
|
||||
# Here we don't initalize to None since None is a valid uid
|
||||
# All uids are initialized as empty strings for empty buffer slots
|
||||
# Here we don't initialize to None since None is a valid uid
|
||||
self.buffer_id_to_uid: List[Optional[str]] = [""] * self.max_loras_per_batch
|
||||
|
||||
def get_lora_A_shape(
|
||||
self, module_name: str, base_model: torch.nn.Module
|
||||
) -> Tuple[int]:
|
||||
"""
|
||||
Given a module_name (might be a stacked name), return the hidden dims of modules's input and output.
|
||||
Given a module_name (might be a stacked name), return the hidden dims of modules' input and output.
|
||||
"""
|
||||
input_dim, _ = get_hidden_dim(module_name, self.base_hf_config, base_model)
|
||||
c = get_stacked_multiply(module_name)
|
||||
@@ -75,7 +75,7 @@ class LoRAMemoryPool:
|
||||
self, module_name: str, base_model: torch.nn.Module
|
||||
) -> Tuple[int]:
|
||||
"""
|
||||
Given a module_name (might be a stacked name), return the hidden dims of modules's input and output.
|
||||
Given a module_name (might be a stacked name), return the hidden dims of modules' input and output.
|
||||
"""
|
||||
_, output_dim = get_hidden_dim(module_name, self.base_hf_config, base_model)
|
||||
c = get_stacked_multiply(module_name)
|
||||
|
||||
@@ -77,7 +77,7 @@ def _gate_up_lora_b_kernel(
|
||||
k_offset[:, None] * w_stride_2 + n_offset[None, :] * w_stride_1
|
||||
)
|
||||
|
||||
# Iteate to compute the block in output matrix
|
||||
# Iterate to compute the block in output matrix
|
||||
partial_sum = tl.zeros((BLOCK_S, BLOCK_N), dtype=tl.float32)
|
||||
for k in range(0, tl.cdiv(K, BLOCK_K)):
|
||||
x_tile = tl.load(
|
||||
|
||||
@@ -79,7 +79,7 @@ def _qkv_lora_b_kernel(
|
||||
k_offset[:, None] * w_stride_2 + n_offset[None, :] * w_stride_1
|
||||
)
|
||||
|
||||
# Iteate to compute the block in output matrix
|
||||
# Iterate to compute the block in output matrix
|
||||
partial_sum = tl.zeros((BLOCK_S, BLOCK_N), dtype=tl.float32)
|
||||
for k in range(0, tl.cdiv(K, BLOCK_K)):
|
||||
x_tile = tl.load(
|
||||
|
||||
@@ -67,7 +67,7 @@ def _sgemm_lora_a_kernel(
|
||||
k_offset[:, None] * w_stride_2 + n_offset[None, :] * w_stride_1
|
||||
)
|
||||
|
||||
# Iteate to compute the block in output matrix
|
||||
# Iterate to compute the block in output matrix
|
||||
partial_sum = tl.zeros((BLOCK_S, BLOCK_N), dtype=tl.float32)
|
||||
for k in range(0, tl.cdiv(K, BLOCK_K)):
|
||||
x_tile = tl.load(
|
||||
|
||||
@@ -69,7 +69,7 @@ def _sgemm_lora_b_kernel(
|
||||
k_offset[:, None] * w_stride_2 + n_offset[None, :] * w_stride_1
|
||||
)
|
||||
|
||||
# Iteate to compute the block in output matrix
|
||||
# Iterate to compute the block in output matrix
|
||||
partial_sum = tl.zeros((BLOCK_S, BLOCK_N), dtype=tl.float32)
|
||||
for k in range(0, tl.cdiv(K, BLOCK_K)):
|
||||
x_tile = tl.load(
|
||||
|
||||
@@ -79,7 +79,7 @@ def get_hidden_dim(
|
||||
module_name: str, config: AutoConfig, base_model: torch.nn.Module
|
||||
) -> Tuple[int]:
|
||||
"""
|
||||
Given a module_name (might be a stacked name), return the hidden dims of modules's input and output.
|
||||
Given a module_name (might be a stacked name), return the hidden dims of modules' input and output.
|
||||
"""
|
||||
|
||||
if hasattr(base_model, "get_hidden_dim"):
|
||||
|
||||
Reference in New Issue
Block a user