testing dynamic register
This commit is contained in:
@@ -146,6 +146,7 @@ class LinearBase(torch.nn.Module):
|
|||||||
skip_bias_add: If true, skip adding bias but instead return it.
|
skip_bias_add: If true, skip adding bias but instead return it.
|
||||||
params_dtype: Data type for the parameters.
|
params_dtype: Data type for the parameters.
|
||||||
quant_config: Quantization configure.
|
quant_config: Quantization configure.
|
||||||
|
return_bias: If False, return only output tensor instead of (output, bias) tuple.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
@@ -156,6 +157,7 @@ class LinearBase(torch.nn.Module):
|
|||||||
params_dtype: Optional[torch.dtype] = None,
|
params_dtype: Optional[torch.dtype] = None,
|
||||||
quant_config: Optional[QuantizationConfig] = None,
|
quant_config: Optional[QuantizationConfig] = None,
|
||||||
prefix: str = "",
|
prefix: str = "",
|
||||||
|
return_bias: bool = True,
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
@@ -163,6 +165,7 @@ class LinearBase(torch.nn.Module):
|
|||||||
self.input_size = input_size
|
self.input_size = input_size
|
||||||
self.output_size = output_size
|
self.output_size = output_size
|
||||||
self.skip_bias_add = skip_bias_add
|
self.skip_bias_add = skip_bias_add
|
||||||
|
self.return_bias = return_bias
|
||||||
if params_dtype is None:
|
if params_dtype is None:
|
||||||
params_dtype = torch.get_default_dtype()
|
params_dtype = torch.get_default_dtype()
|
||||||
self.params_dtype = params_dtype
|
self.params_dtype = params_dtype
|
||||||
@@ -198,13 +201,15 @@ class ReplicatedLinear(LinearBase):
|
|||||||
skip_bias_add: bool = False,
|
skip_bias_add: bool = False,
|
||||||
params_dtype: Optional[torch.dtype] = None,
|
params_dtype: Optional[torch.dtype] = None,
|
||||||
quant_config: Optional[QuantizationConfig] = None,
|
quant_config: Optional[QuantizationConfig] = None,
|
||||||
prefix: str = ""):
|
prefix: str = "",
|
||||||
|
return_bias: bool = True):
|
||||||
super().__init__(input_size,
|
super().__init__(input_size,
|
||||||
output_size,
|
output_size,
|
||||||
skip_bias_add,
|
skip_bias_add,
|
||||||
params_dtype,
|
params_dtype,
|
||||||
quant_config,
|
quant_config,
|
||||||
prefix=prefix)
|
prefix=prefix,
|
||||||
|
return_bias=return_bias)
|
||||||
|
|
||||||
# All the linear layer supports quant method.
|
# All the linear layer supports quant method.
|
||||||
assert self.quant_method is not None
|
assert self.quant_method is not None
|
||||||
@@ -238,6 +243,9 @@ class ReplicatedLinear(LinearBase):
|
|||||||
bias = self.bias if not self.skip_bias_add else None
|
bias = self.bias if not self.skip_bias_add else None
|
||||||
assert self.quant_method is not None
|
assert self.quant_method is not None
|
||||||
output = self.quant_method.apply(self, x, bias)
|
output = self.quant_method.apply(self, x, bias)
|
||||||
|
|
||||||
|
if not self.return_bias:
|
||||||
|
return output
|
||||||
output_bias = self.bias if self.skip_bias_add else None
|
output_bias = self.bias if self.skip_bias_add else None
|
||||||
return output, output_bias
|
return output, output_bias
|
||||||
|
|
||||||
@@ -281,9 +289,10 @@ class ColumnParallelLinear(LinearBase):
|
|||||||
params_dtype: Optional[torch.dtype] = None,
|
params_dtype: Optional[torch.dtype] = None,
|
||||||
quant_config: Optional[QuantizationConfig] = None,
|
quant_config: Optional[QuantizationConfig] = None,
|
||||||
output_sizes: Optional[List[int]] = None,
|
output_sizes: Optional[List[int]] = None,
|
||||||
prefix: str = ""):
|
prefix: str = "",
|
||||||
|
return_bias: bool = True):
|
||||||
super().__init__(input_size, output_size, skip_bias_add, params_dtype,
|
super().__init__(input_size, output_size, skip_bias_add, params_dtype,
|
||||||
quant_config, prefix)
|
quant_config, prefix, return_bias=return_bias)
|
||||||
|
|
||||||
self.gather_output = gather_output
|
self.gather_output = gather_output
|
||||||
|
|
||||||
@@ -375,6 +384,9 @@ class ColumnParallelLinear(LinearBase):
|
|||||||
output = tensor_model_parallel_all_gather(output_parallel)
|
output = tensor_model_parallel_all_gather(output_parallel)
|
||||||
else:
|
else:
|
||||||
output = output_parallel
|
output = output_parallel
|
||||||
|
|
||||||
|
if not self.return_bias:
|
||||||
|
return output
|
||||||
output_bias = self.bias if self.skip_bias_add else None
|
output_bias = self.bias if self.skip_bias_add else None
|
||||||
return output, output_bias
|
return output, output_bias
|
||||||
|
|
||||||
@@ -418,7 +430,8 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
|
|||||||
skip_bias_add: bool = False,
|
skip_bias_add: bool = False,
|
||||||
params_dtype: Optional[torch.dtype] = None,
|
params_dtype: Optional[torch.dtype] = None,
|
||||||
quant_config: Optional[QuantizationConfig] = None,
|
quant_config: Optional[QuantizationConfig] = None,
|
||||||
prefix: str = ""):
|
prefix: str = "",
|
||||||
|
return_bias: bool = True):
|
||||||
self.output_sizes = output_sizes
|
self.output_sizes = output_sizes
|
||||||
tp_size = get_tensor_model_parallel_world_size()
|
tp_size = get_tensor_model_parallel_world_size()
|
||||||
assert all(output_size % tp_size == 0 for output_size in output_sizes)
|
assert all(output_size % tp_size == 0 for output_size in output_sizes)
|
||||||
@@ -429,7 +442,8 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
|
|||||||
skip_bias_add=skip_bias_add,
|
skip_bias_add=skip_bias_add,
|
||||||
params_dtype=params_dtype,
|
params_dtype=params_dtype,
|
||||||
quant_config=quant_config,
|
quant_config=quant_config,
|
||||||
prefix=prefix)
|
prefix=prefix,
|
||||||
|
return_bias=return_bias)
|
||||||
|
|
||||||
def weight_loader(self,
|
def weight_loader(self,
|
||||||
param: Parameter,
|
param: Parameter,
|
||||||
@@ -653,7 +667,8 @@ class QKVParallelLinear(ColumnParallelLinear):
|
|||||||
skip_bias_add: bool = False,
|
skip_bias_add: bool = False,
|
||||||
params_dtype: Optional[torch.dtype] = None,
|
params_dtype: Optional[torch.dtype] = None,
|
||||||
quant_config: Optional[QuantizationConfig] = None,
|
quant_config: Optional[QuantizationConfig] = None,
|
||||||
prefix: str = ""):
|
prefix: str = "",
|
||||||
|
return_bias: bool = True):
|
||||||
self.hidden_size = hidden_size
|
self.hidden_size = hidden_size
|
||||||
self.head_size = head_size
|
self.head_size = head_size
|
||||||
self.total_num_heads = total_num_heads
|
self.total_num_heads = total_num_heads
|
||||||
@@ -686,7 +701,8 @@ class QKVParallelLinear(ColumnParallelLinear):
|
|||||||
skip_bias_add=skip_bias_add,
|
skip_bias_add=skip_bias_add,
|
||||||
params_dtype=params_dtype,
|
params_dtype=params_dtype,
|
||||||
quant_config=quant_config,
|
quant_config=quant_config,
|
||||||
prefix=prefix)
|
prefix=prefix,
|
||||||
|
return_bias=return_bias)
|
||||||
|
|
||||||
def _get_shard_offset_mapping(self, loaded_shard_id: str):
|
def _get_shard_offset_mapping(self, loaded_shard_id: str):
|
||||||
shard_offset_mapping = {
|
shard_offset_mapping = {
|
||||||
@@ -980,9 +996,10 @@ class RowParallelLinear(LinearBase):
|
|||||||
params_dtype: Optional[torch.dtype] = None,
|
params_dtype: Optional[torch.dtype] = None,
|
||||||
reduce_results: bool = True,
|
reduce_results: bool = True,
|
||||||
quant_config: Optional[QuantizationConfig] = None,
|
quant_config: Optional[QuantizationConfig] = None,
|
||||||
prefix: str = ""):
|
prefix: str = "",
|
||||||
|
return_bias: bool = True):
|
||||||
super().__init__(input_size, output_size, skip_bias_add, params_dtype,
|
super().__init__(input_size, output_size, skip_bias_add, params_dtype,
|
||||||
quant_config, prefix)
|
quant_config, prefix, return_bias=return_bias)
|
||||||
|
|
||||||
self.input_is_parallel = input_is_parallel
|
self.input_is_parallel = input_is_parallel
|
||||||
self.reduce_results = reduce_results
|
self.reduce_results = reduce_results
|
||||||
@@ -1086,8 +1103,9 @@ class RowParallelLinear(LinearBase):
|
|||||||
else:
|
else:
|
||||||
output = output_parallel
|
output = output_parallel
|
||||||
|
|
||||||
|
if not self.return_bias:
|
||||||
|
return output
|
||||||
output_bias = self.bias if self.skip_bias_add else None
|
output_bias = self.bias if self.skip_bias_add else None
|
||||||
|
|
||||||
return output, output_bias
|
return output, output_bias
|
||||||
|
|
||||||
def extra_repr(self) -> str:
|
def extra_repr(self) -> str:
|
||||||
|
|||||||
Reference in New Issue
Block a user