fix black in pre-commit (#1940)
@@ -27,59 +27,67 @@ DEFAULT_VOCAB_PADDING_SIZE = 64
 class UnquantizedEmbeddingMethod(QuantizeMethodBase):
     """Unquantized method for embeddings."""
 
-    def create_weights(self, layer: torch.nn.Module,
-                       input_size_per_partition: int,
-                       output_partition_sizes: List[int], input_size: int,
-                       output_size: int, params_dtype: torch.dtype,
-                       **extra_weight_attrs):
+    def create_weights(
+        self,
+        layer: torch.nn.Module,
+        input_size_per_partition: int,
+        output_partition_sizes: List[int],
+        input_size: int,
+        output_size: int,
+        params_dtype: torch.dtype,
+        **extra_weight_attrs,
+    ):
         """Create weights for embedding layer."""
-        weight = Parameter(torch.empty(sum(output_partition_sizes),
-                                       input_size_per_partition,
-                                       dtype=params_dtype),
-                           requires_grad=False)
+        weight = Parameter(
+            torch.empty(
+                sum(output_partition_sizes),
+                input_size_per_partition,
+                dtype=params_dtype,
+            ),
+            requires_grad=False,
+        )
         set_weight_attrs(weight, {"input_dim": 1, "output_dim": 0})
         layer.register_parameter("weight", weight)
         set_weight_attrs(weight, extra_weight_attrs)
 
-    def apply(self,
-              layer: torch.nn.Module,
-              x: torch.Tensor,
-              bias: Optional[torch.Tensor] = None) -> torch.Tensor:
+    def apply(
+        self,
+        layer: torch.nn.Module,
+        x: torch.Tensor,
+        bias: Optional[torch.Tensor] = None,
+    ) -> torch.Tensor:
         return F.linear(x, layer.weight, bias)
 
-    def embedding(self, layer: torch.nn.Module,
-                  input_: torch.Tensor) -> torch.Tensor:
+    def embedding(self, layer: torch.nn.Module, input_: torch.Tensor) -> torch.Tensor:
         return F.embedding(input_, layer.weight)
 
 
-def pad_vocab_size(vocab_size: int,
-                   pad_to: int = DEFAULT_VOCAB_PADDING_SIZE) -> int:
+def pad_vocab_size(vocab_size: int, pad_to: int = DEFAULT_VOCAB_PADDING_SIZE) -> int:
     """Pad the vocab size to the given value."""
     return ((vocab_size + pad_to - 1) // pad_to) * pad_to
 
 
 def vocab_range_from_per_partition_vocab_size(
-        per_partition_vocab_size: int,
-        rank: int,
-        offset: int = 0) -> Sequence[int]:
+    per_partition_vocab_size: int, rank: int, offset: int = 0
+) -> Sequence[int]:
     index_f = rank * per_partition_vocab_size
     index_l = index_f + per_partition_vocab_size
     return index_f + offset, index_l + offset
 
 
-def vocab_range_from_global_vocab_size(global_vocab_size: int,
-                                       rank: int,
-                                       world_size: int,
-                                       offset: int = 0) -> Sequence[int]:
+def vocab_range_from_global_vocab_size(
+    global_vocab_size: int, rank: int, world_size: int, offset: int = 0
+) -> Sequence[int]:
     per_partition_vocab_size = divide(global_vocab_size, world_size)
-    return vocab_range_from_per_partition_vocab_size(per_partition_vocab_size,
-                                                     rank,
-                                                     offset=offset)
+    return vocab_range_from_per_partition_vocab_size(
+        per_partition_vocab_size, rank, offset=offset
+    )
 
 
 @dataclass
 class VocabParallelEmbeddingShardIndices:
     """Indices for a shard of a vocab parallel embedding."""
 
     padded_org_vocab_start_index: int
     padded_org_vocab_end_index: int
     padded_added_vocab_start_index: int
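As a quick reference for the padding and range helpers above, a standalone sketch (not part of the diff; it assumes the functions are imported as defined here and that divide() performs exact integer division):

pad_vocab_size(32000)  # -> 32000, already a multiple of DEFAULT_VOCAB_PADDING_SIZE
pad_vocab_size(32001)  # -> 32064, rounded up to the next multiple of 64
vocab_range_from_global_vocab_size(32064, rank=1, world_size=2)  # -> (16032, 32064)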
@@ -100,13 +108,11 @@ class VocabParallelEmbeddingShardIndices:
 
     @property
     def num_org_elements_padded(self) -> int:
-        return (self.padded_org_vocab_end_index -
-                self.padded_org_vocab_start_index)
+        return self.padded_org_vocab_end_index - self.padded_org_vocab_start_index
 
     @property
     def num_added_elements_padded(self) -> int:
-        return (self.padded_added_vocab_end_index -
-                self.padded_added_vocab_start_index)
+        return self.padded_added_vocab_end_index - self.padded_added_vocab_start_index
 
     @property
     def num_org_vocab_padding(self) -> int:
@@ -122,17 +128,14 @@ class VocabParallelEmbeddingShardIndices:
 
     def __post_init__(self):
         # sanity checks
-        assert (self.padded_org_vocab_start_index <=
-                self.padded_org_vocab_end_index)
-        assert (self.padded_added_vocab_start_index <=
-                self.padded_added_vocab_end_index)
+        assert self.padded_org_vocab_start_index <= self.padded_org_vocab_end_index
+        assert self.padded_added_vocab_start_index <= self.padded_added_vocab_end_index
 
         assert self.org_vocab_start_index <= self.org_vocab_end_index
         assert self.added_vocab_start_index <= self.added_vocab_end_index
 
         assert self.org_vocab_start_index <= self.padded_org_vocab_start_index
-        assert (self.added_vocab_start_index <=
-                self.padded_added_vocab_start_index)
+        assert self.added_vocab_start_index <= self.padded_added_vocab_start_index
         assert self.org_vocab_end_index <= self.padded_org_vocab_end_index
         assert self.added_vocab_end_index <= self.padded_added_vocab_end_index
 
@@ -142,20 +145,27 @@ class VocabParallelEmbeddingShardIndices:
 
 @torch.jit.script
 def get_masked_input_and_mask(
-        input_: torch.Tensor, org_vocab_start_index: int,
-        org_vocab_end_index: int, num_org_vocab_padding: int,
-        added_vocab_start_index: int,
-        added_vocab_end_index: int) -> Tuple[torch.Tensor, torch.Tensor]:
+    input_: torch.Tensor,
+    org_vocab_start_index: int,
+    org_vocab_end_index: int,
+    num_org_vocab_padding: int,
+    added_vocab_start_index: int,
+    added_vocab_end_index: int,
+) -> Tuple[torch.Tensor, torch.Tensor]:
     # torch.jit.script will fuse all of the pointwise ops below
     # into a single kernel, making it very fast
-    org_vocab_mask = (input_ >= org_vocab_start_index) & (input_ <
-                                                          org_vocab_end_index)
+    org_vocab_mask = (input_ >= org_vocab_start_index) & (input_ < org_vocab_end_index)
     added_vocab_mask = (input_ >= added_vocab_start_index) & (
-        input_ < added_vocab_end_index)
-    added_offset = added_vocab_start_index - (
-        org_vocab_end_index - org_vocab_start_index) - num_org_vocab_padding
-    valid_offset = (org_vocab_start_index *
-                    org_vocab_mask) + (added_offset * added_vocab_mask)
+        input_ < added_vocab_end_index
+    )
+    added_offset = (
+        added_vocab_start_index
+        - (org_vocab_end_index - org_vocab_start_index)
+        - num_org_vocab_padding
+    )
+    valid_offset = (org_vocab_start_index * org_vocab_mask) + (
+        added_offset * added_vocab_mask
+    )
     vocab_mask = org_vocab_mask | added_vocab_mask
     input_ = vocab_mask * (input_ - valid_offset)
     return input_, ~vocab_mask
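A toy check of the masking math above (standalone sketch, not part of the diff): suppose a shard owns original ids [0, 4) and added ids [8, 10), with num_org_vocab_padding=2, so its local rows are 0-3 originals, 4-5 padding, 6-7 added.

import torch

ids = torch.tensor([0, 3, 5, 8, 9])
masked, mask = get_masked_input_and_mask(ids, 0, 4, 2, 8, 10)
# masked -> tensor([0, 3, 0, 6, 7]): owned ids map to local rows (8 -> 6, 9 -> 7),
#           while id 5, owned by another rank, is clamped to 0
# mask   -> tensor([False, False, True, False, False]): True marks foreign ids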
@@ -200,15 +210,17 @@ class VocabParallelEmbedding(torch.nn.Module):
         prefix: full name of the layer in the state dict
     """  # noqa: E501
 
-    def __init__(self,
-                 num_embeddings: int,
-                 embedding_dim: int,
-                 params_dtype: Optional[torch.dtype] = None,
-                 org_num_embeddings: Optional[int] = None,
-                 padding_size: int = DEFAULT_VOCAB_PADDING_SIZE,
-                 quant_config: Optional[QuantizationConfig] = None,
-                 prefix: str = "",
-                 enable_tp: bool = True):
+    def __init__(
+        self,
+        num_embeddings: int,
+        embedding_dim: int,
+        params_dtype: Optional[torch.dtype] = None,
+        org_num_embeddings: Optional[int] = None,
+        padding_size: int = DEFAULT_VOCAB_PADDING_SIZE,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+        enable_tp: bool = True,
+    ):
         super().__init__()
 
         self.enable_tp = enable_tp
@@ -223,18 +235,22 @@ class VocabParallelEmbedding(torch.nn.Module):
         self.padding_size = padding_size
         self.org_vocab_size = org_num_embeddings or num_embeddings
         num_added_embeddings = num_embeddings - self.org_vocab_size
-        self.org_vocab_size_padded = pad_vocab_size(self.org_vocab_size,
-                                                    self.padding_size)
+        self.org_vocab_size_padded = pad_vocab_size(
+            self.org_vocab_size, self.padding_size
+        )
         self.num_embeddings_padded = pad_vocab_size(
-            self.org_vocab_size_padded + num_added_embeddings,
-            self.padding_size)
+            self.org_vocab_size_padded + num_added_embeddings, self.padding_size
+        )
         assert self.org_vocab_size_padded <= self.num_embeddings_padded
 
-        self.shard_indices = self._get_indices(self.num_embeddings_padded,
-                                               self.org_vocab_size_padded,
-                                               self.num_embeddings,
-                                               self.org_vocab_size, tp_rank,
-                                               self.tp_size)
+        self.shard_indices = self._get_indices(
+            self.num_embeddings_padded,
+            self.org_vocab_size_padded,
+            self.num_embeddings,
+            self.org_vocab_size,
+            tp_rank,
+            self.tp_size,
+        )
         self.embedding_dim = embedding_dim
 
         linear_method = None
@@ -248,11 +264,13 @@ class VocabParallelEmbedding(torch.nn.Module):
         # layer type like ParallelLMHead, this is not important.
         is_embedding_layer = type(self.__class__) is VocabParallelEmbedding
         linear_method_implements_embedding = method_has_implemented_embedding(
-            type(linear_method))
+            type(linear_method)
+        )
         if is_embedding_layer and not linear_method_implements_embedding:
             raise NotImplementedError(
                 f"The class {type(linear_method).__name__} must implement "
-                "the 'embedding' method, see UnquantizedEmbeddingMethod.")
+                "the 'embedding' method, see UnquantizedEmbeddingMethod."
+            )
 
         self.linear_method: QuantizeMethodBase = linear_method
 
@@ -260,53 +278,68 @@ class VocabParallelEmbedding(torch.nn.Module):
             params_dtype = torch.get_default_dtype()
         # Divide the weight matrix along the vocabulary dimension.
         self.num_added_embeddings = self.num_embeddings - self.org_vocab_size
-        self.num_embeddings_per_partition = divide(self.num_embeddings_padded,
-                                                   self.tp_size)
-        assert (self.shard_indices.num_elements_padded ==
-                self.num_embeddings_per_partition)
+        self.num_embeddings_per_partition = divide(
+            self.num_embeddings_padded, self.tp_size
+        )
+        assert (
+            self.shard_indices.num_elements_padded == self.num_embeddings_per_partition
+        )
         self.num_org_embeddings_per_partition = (
-            self.shard_indices.org_vocab_end_index -
-            self.shard_indices.org_vocab_start_index)
+            self.shard_indices.org_vocab_end_index
+            - self.shard_indices.org_vocab_start_index
+        )
         self.num_added_embeddings_per_partition = (
-            self.shard_indices.added_vocab_end_index -
-            self.shard_indices.added_vocab_start_index)
+            self.shard_indices.added_vocab_end_index
+            - self.shard_indices.added_vocab_start_index
+        )
 
-        self.linear_method.create_weights(self,
-                                          self.embedding_dim,
-                                          [self.num_embeddings_per_partition],
-                                          self.embedding_dim,
-                                          self.num_embeddings_padded,
-                                          params_dtype=params_dtype,
-                                          weight_loader=self.weight_loader)
+        self.linear_method.create_weights(
+            self,
+            self.embedding_dim,
+            [self.num_embeddings_per_partition],
+            self.embedding_dim,
+            self.num_embeddings_padded,
+            params_dtype=params_dtype,
+            weight_loader=self.weight_loader,
+        )
 
     @classmethod
-    def _get_indices(cls, vocab_size_padded: int, org_vocab_size_padded: int,
-                     vocab_size: int, org_vocab_size: int, tp_rank: int,
-                     tp_size: int) -> VocabParallelEmbeddingShardIndices:
+    def _get_indices(
+        cls,
+        vocab_size_padded: int,
+        org_vocab_size_padded: int,
+        vocab_size: int,
+        org_vocab_size: int,
+        tp_rank: int,
+        tp_size: int,
+    ) -> VocabParallelEmbeddingShardIndices:
         """Get start and end indices for vocab parallel embedding, following the
         layout outlined in the class docstring, based on the given tp_rank and
         tp_size."""
         num_added_embeddings_padded = vocab_size_padded - org_vocab_size_padded
         padded_org_vocab_start_index, padded_org_vocab_end_index = (
-            vocab_range_from_global_vocab_size(org_vocab_size_padded, tp_rank,
-                                               tp_size))
+            vocab_range_from_global_vocab_size(org_vocab_size_padded, tp_rank, tp_size)
+        )
         padded_added_vocab_start_index, padded_added_vocab_end_index = (
-            vocab_range_from_global_vocab_size(num_added_embeddings_padded,
-                                               tp_rank,
-                                               tp_size,
-                                               offset=org_vocab_size))
+            vocab_range_from_global_vocab_size(
+                num_added_embeddings_padded, tp_rank, tp_size, offset=org_vocab_size
+            )
+        )
         # remove padding
-        org_vocab_start_index = min(padded_org_vocab_start_index,
-                                    org_vocab_size)
+        org_vocab_start_index = min(padded_org_vocab_start_index, org_vocab_size)
         org_vocab_end_index = min(padded_org_vocab_end_index, org_vocab_size)
-        added_vocab_start_index = min(padded_added_vocab_start_index,
-                                      vocab_size)
+        added_vocab_start_index = min(padded_added_vocab_start_index, vocab_size)
         added_vocab_end_index = min(padded_added_vocab_end_index, vocab_size)
         return VocabParallelEmbeddingShardIndices(
-            padded_org_vocab_start_index, padded_org_vocab_end_index,
-            padded_added_vocab_start_index, padded_added_vocab_end_index,
-            org_vocab_start_index, org_vocab_end_index,
-            added_vocab_start_index, added_vocab_end_index)
+            padded_org_vocab_start_index,
+            padded_org_vocab_end_index,
+            padded_added_vocab_start_index,
+            padded_added_vocab_end_index,
+            org_vocab_start_index,
+            org_vocab_end_index,
+            added_vocab_start_index,
+            added_vocab_end_index,
+        )
 
     def get_sharded_to_full_mapping(self) -> Optional[List[int]]:
         """Get a mapping that can be used to reindex the gathered
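A worked example for _get_indices (hypothetical numbers, not part of the diff: vocab_size=32003, org_vocab_size=32000, padding_size=64, hence vocab_size_padded=32064 and org_vocab_size_padded=32000; tp_size=2):

# rank 0: padded org [0, 16000),      padded added [32000, 32032)
#         after clamping: org [0, 16000),      added [32000, 32003)
# rank 1: padded org [16000, 32000),  padded added [32032, 32064)
#         after clamping: org [16000, 32000),  added [32003, 32003) (empty)
VocabParallelEmbedding._get_indices(32064, 32000, 32003, 32000, 0, 2)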
@@ -326,32 +359,49 @@ class VocabParallelEmbedding(torch.nn.Module):
         added_embeddings: List[int] = []
         padding: List[int] = []
         for tp_rank in range(self.tp_size):
-            shard_indices = self._get_indices(self.num_embeddings_padded,
-                                              self.org_vocab_size_padded,
-                                              self.num_embeddings,
-                                              self.org_vocab_size, tp_rank,
-                                              self.tp_size)
+            shard_indices = self._get_indices(
+                self.num_embeddings_padded,
+                self.org_vocab_size_padded,
+                self.num_embeddings,
+                self.org_vocab_size,
+                tp_rank,
+                self.tp_size,
+            )
             range_start = self.num_embeddings_per_partition * tp_rank
             range_end = self.num_embeddings_per_partition * (tp_rank + 1)
             base_embeddings.extend(
-                range(range_start,
-                      range_start + shard_indices.num_org_elements))
+                range(range_start, range_start + shard_indices.num_org_elements)
+            )
             padding.extend(
-                range(range_start + shard_indices.num_org_elements,
-                      range_start + shard_indices.num_org_elements_padded))
+                range(
+                    range_start + shard_indices.num_org_elements,
+                    range_start + shard_indices.num_org_elements_padded,
+                )
+            )
             added_embeddings.extend(
                 range(
                     range_start + shard_indices.num_org_elements_padded,
-                    range_start + shard_indices.num_org_elements_padded +
-                    shard_indices.num_added_elements))
+                    range_start
+                    + shard_indices.num_org_elements_padded
+                    + shard_indices.num_added_elements,
+                )
+            )
             padding.extend(
                 range(
-                    range_start + shard_indices.num_org_elements_padded +
-                    shard_indices.num_added_elements,
-                    range_start + shard_indices.num_org_elements_padded +
-                    shard_indices.num_added_elements_padded))
-            assert (range_start + shard_indices.num_org_elements_padded +
-                    shard_indices.num_added_elements_padded == range_end)
+                    range_start
+                    + shard_indices.num_org_elements_padded
+                    + shard_indices.num_added_elements,
+                    range_start
+                    + shard_indices.num_org_elements_padded
+                    + shard_indices.num_added_elements_padded,
+                )
+            )
+            assert (
+                range_start
+                + shard_indices.num_org_elements_padded
+                + shard_indices.num_added_elements_padded
+                == range_end
+            )
         ret = base_embeddings + added_embeddings + padding
         assert len(ret) == self.num_embeddings_padded
         return ret
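With smaller toy numbers (hypothetical, not part of the diff: num_embeddings=10, org_num_embeddings=8, padding_size=4, tp_size=2, hence num_embeddings_padded=12 and 6 rows per rank), the mapping built above works out to:

# rank 0 rows: [org ids 0-3 | added ids 8-9]; rank 1 rows: [org ids 4-7 | 2 padding rows]
# get_sharded_to_full_mapping() -> [0, 1, 2, 3, 6, 7, 8, 9, 4, 5, 10, 11]
# i.e. gathered rows reordered so all original-vocab rows come first,
# then added-vocab rows, then padding.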
@@ -385,10 +435,14 @@ class VocabParallelEmbedding(torch.nn.Module):
             # If param packed on the same dim we are sharding on, then
             # need to adjust offsets of loaded weight by pack_factor.
             if packed_dim is not None and packed_dim == output_dim:
-                packed_factor = param.packed_factor if isinstance(
-                    param, BasevLLMParameter) else param.pack_factor
-                assert loaded_weight.shape[output_dim] == (self.org_vocab_size //
-                                                           param.packed_factor)
+                packed_factor = (
+                    param.packed_factor
+                    if isinstance(param, BasevLLMParameter)
+                    else param.pack_factor
+                )
+                assert loaded_weight.shape[output_dim] == (
+                    self.org_vocab_size // param.packed_factor
+                )
                 start_idx = start_idx // packed_factor
                 shard_size = shard_size // packed_factor
             else:
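For intuition on the packed branch above, an assumed example (not from the diff):

# Suppose 4-bit weights are packed 8 per int32, so pack_factor = 8.
# A shard covering vocab rows [16000, 32000) then reads packed rows
# 16000 // 8 = 2000 through 32000 // 8 = 4000
# (start_idx = 2000, shard_size = 2000).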
@@ -396,23 +450,24 @@ class VocabParallelEmbedding(torch.nn.Module):
 
         # Copy the data.
         loaded_weight = loaded_weight.narrow(output_dim, start_idx, shard_size)
-        param[:loaded_weight.shape[0]].data.copy_(loaded_weight)
-        param[loaded_weight.shape[0]:].data.fill_(0)
+        param[: loaded_weight.shape[0]].data.copy_(loaded_weight)
+        param[loaded_weight.shape[0] :].data.fill_(0)
 
     def forward(self, input_):
         if self.tp_size > 1:
             # Build the mask.
             masked_input, input_mask = get_masked_input_and_mask(
-                input_, self.shard_indices.org_vocab_start_index,
+                input_,
+                self.shard_indices.org_vocab_start_index,
                 self.shard_indices.org_vocab_end_index,
                 self.shard_indices.num_org_vocab_padding,
                 self.shard_indices.added_vocab_start_index,
-                self.shard_indices.added_vocab_end_index)
+                self.shard_indices.added_vocab_end_index,
+            )
         else:
             masked_input = input_
         # Get the embeddings.
-        output_parallel = self.linear_method.embedding(self,
-                                                       masked_input.long())
+        output_parallel = self.linear_method.embedding(self, masked_input.long())
         # Mask the output embedding.
         if self.tp_size > 1:
             output_parallel.masked_fill_(input_mask.unsqueeze(-1), 0)
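The tensor-parallel forward pass above, re-enacted on one rank as a simplified standalone sketch (not part of the diff; no added vocab, and the rank is assumed to own ids [0, 4) of an 8-token vocabulary):

import torch
import torch.nn.functional as F

weight = torch.randn(4, 16)              # this rank's 4-row shard of the table
ids = torch.tensor([1, 5, 3])
mask = ids >= 4                          # tokens owned by another rank
local = ids.masked_fill(mask, 0)         # clamp foreign ids to a valid local row
out = F.embedding(local, weight)
out.masked_fill_(mask.unsqueeze(-1), 0)  # zero foreign rows; in the full layer the
                                         # partial outputs are then combined across ranks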
@@ -426,9 +481,9 @@ class VocabParallelEmbedding(torch.nn.Module):
         s = f"num_embeddings={self.num_embeddings_per_partition}"
         s += f", embedding_dim={self.embedding_dim}"
         s += f", org_vocab_size={self.org_vocab_size}"
-        s += f', num_embeddings_padded={self.num_embeddings_padded}'
+        s += f", num_embeddings_padded={self.num_embeddings_padded}"
         if self.enable_tp:
-            s += f', tp_size={self.tp_size}'
+            s += f", tp_size={self.tp_size}"
         return s
 
 
@@ -448,27 +503,38 @@ class ParallelLMHead(VocabParallelEmbedding):
         padding_size: padding size for the vocabulary.
     """
 
-    def __init__(self,
-                 num_embeddings: int,
-                 embedding_dim: int,
-                 bias: bool = False,
-                 params_dtype: Optional[torch.dtype] = None,
-                 org_num_embeddings: Optional[int] = None,
-                 padding_size: int = DEFAULT_VOCAB_PADDING_SIZE,
-                 quant_config: Optional[QuantizationConfig] = None,
-                 prefix: str = ""):
-        super().__init__(num_embeddings, embedding_dim, params_dtype,
-                         org_num_embeddings, padding_size, quant_config,
-                         prefix)
+    def __init__(
+        self,
+        num_embeddings: int,
+        embedding_dim: int,
+        bias: bool = False,
+        params_dtype: Optional[torch.dtype] = None,
+        org_num_embeddings: Optional[int] = None,
+        padding_size: int = DEFAULT_VOCAB_PADDING_SIZE,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+    ):
+        super().__init__(
+            num_embeddings,
+            embedding_dim,
+            params_dtype,
+            org_num_embeddings,
+            padding_size,
+            quant_config,
+            prefix,
+        )
         self.quant_config = quant_config
         if bias:
             self.bias = Parameter(
-                torch.empty(self.num_embeddings_per_partition,
-                            dtype=params_dtype))
-            set_weight_attrs(self.bias, {
-                "output_dim": 0,
-                "weight_loader": self.weight_loader,
-            })
+                torch.empty(self.num_embeddings_per_partition, dtype=params_dtype)
+            )
+            set_weight_attrs(
+                self.bias,
+                {
+                    "output_dim": 0,
+                    "weight_loader": self.weight_loader,
+                },
+            )
         else:
             self.register_parameter("bias", None)