testing dynamic register

This commit is contained in:
Chranos
2026-02-05 18:36:03 +08:00
parent df848b4284
commit 2068984bde

View File

@@ -146,6 +146,7 @@ class LinearBase(torch.nn.Module):
skip_bias_add: If true, skip adding bias but instead return it.
params_dtype: Data type for the parameters.
quant_config: Quantization configure.
return_bias: If False, return only the output tensor instead of an (output, bias) tuple.
"""
def __init__(
@@ -156,6 +157,7 @@ class LinearBase(torch.nn.Module):
params_dtype: Optional[torch.dtype] = None,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
return_bias: bool = True,
):
super().__init__()
@@ -163,6 +165,7 @@ class LinearBase(torch.nn.Module):
self.input_size = input_size
self.output_size = output_size
self.skip_bias_add = skip_bias_add
self.return_bias = return_bias
if params_dtype is None:
params_dtype = torch.get_default_dtype()
self.params_dtype = params_dtype
@@ -198,13 +201,15 @@ class ReplicatedLinear(LinearBase):
skip_bias_add: bool = False,
params_dtype: Optional[torch.dtype] = None,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = ""):
prefix: str = "",
return_bias: bool = True):
super().__init__(input_size,
output_size,
skip_bias_add,
params_dtype,
quant_config,
prefix=prefix)
prefix=prefix,
return_bias=return_bias)
# All the linear layer supports quant method.
assert self.quant_method is not None
@@ -238,6 +243,9 @@ class ReplicatedLinear(LinearBase):
bias = self.bias if not self.skip_bias_add else None
assert self.quant_method is not None
output = self.quant_method.apply(self, x, bias)
if not self.return_bias:
return output
output_bias = self.bias if self.skip_bias_add else None
return output, output_bias
@@ -281,9 +289,10 @@ class ColumnParallelLinear(LinearBase):
params_dtype: Optional[torch.dtype] = None,
quant_config: Optional[QuantizationConfig] = None,
output_sizes: Optional[List[int]] = None,
prefix: str = ""):
prefix: str = "",
return_bias: bool = True):
super().__init__(input_size, output_size, skip_bias_add, params_dtype,
quant_config, prefix)
quant_config, prefix, return_bias=return_bias)
self.gather_output = gather_output
@@ -375,6 +384,9 @@ class ColumnParallelLinear(LinearBase):
output = tensor_model_parallel_all_gather(output_parallel)
else:
output = output_parallel
if not self.return_bias:
return output
output_bias = self.bias if self.skip_bias_add else None
return output, output_bias
@@ -418,7 +430,8 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
skip_bias_add: bool = False,
params_dtype: Optional[torch.dtype] = None,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = ""):
prefix: str = "",
return_bias: bool = True):
self.output_sizes = output_sizes
tp_size = get_tensor_model_parallel_world_size()
assert all(output_size % tp_size == 0 for output_size in output_sizes)
@@ -429,7 +442,8 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
skip_bias_add=skip_bias_add,
params_dtype=params_dtype,
quant_config=quant_config,
prefix=prefix)
prefix=prefix,
return_bias=return_bias)
def weight_loader(self,
param: Parameter,
@@ -653,7 +667,8 @@ class QKVParallelLinear(ColumnParallelLinear):
skip_bias_add: bool = False,
params_dtype: Optional[torch.dtype] = None,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = ""):
prefix: str = "",
return_bias: bool = True):
self.hidden_size = hidden_size
self.head_size = head_size
self.total_num_heads = total_num_heads
@@ -686,7 +701,8 @@ class QKVParallelLinear(ColumnParallelLinear):
skip_bias_add=skip_bias_add,
params_dtype=params_dtype,
quant_config=quant_config,
prefix=prefix)
prefix=prefix,
return_bias=return_bias)
def _get_shard_offset_mapping(self, loaded_shard_id: str):
shard_offset_mapping = {
@@ -980,9 +996,10 @@ class RowParallelLinear(LinearBase):
params_dtype: Optional[torch.dtype] = None,
reduce_results: bool = True,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = ""):
prefix: str = "",
return_bias: bool = True):
super().__init__(input_size, output_size, skip_bias_add, params_dtype,
quant_config, prefix)
quant_config, prefix, return_bias=return_bias)
self.input_is_parallel = input_is_parallel
self.reduce_results = reduce_results
@@ -1086,8 +1103,9 @@ class RowParallelLinear(LinearBase):
else:
output = output_parallel
if not self.return_bias:
return output
output_bias = self.bias if self.skip_bias_add else None
return output, output_bias
def extra_repr(self) -> str: