init v0.11.0rc0
@@ -97,6 +97,7 @@ class AscendVocabParallelEmbedding(VocabParallelEmbedding):
        if params_dtype is None:
            params_dtype = torch.get_default_dtype()
        self.params_dtype = params_dtype
        # Divide the weight matrix along the vocabulary dimension.
        self.num_added_embeddings = self.num_embeddings - self.org_vocab_size
        self.num_embeddings_per_partition = divide(self.num_embeddings_padded,
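For reference, the partition arithmetic in the hunk above splits the padded vocabulary evenly across tensor-parallel ranks. The standalone sketch below illustrates that calculation; the divide helper and the example sizes are illustrative stand-ins, not the actual vLLM utilities or model values.

# Illustrative sketch of the vocabulary-partition arithmetic above.
def divide(numerator: int, denominator: int) -> int:
    # Exact integer division that asserts divisibility, mirroring what
    # such a helper typically does.
    assert numerator % denominator == 0, (
        f"{numerator} is not evenly divisible by {denominator}")
    return numerator // denominator

# Example sizes (hypothetical): base vocab, vocab with added tokens,
# and a vocab padded so each tensor-parallel rank gets an equal slice.
org_vocab_size = 32_000
num_embeddings = 32_016          # e.g. 16 extra tokens appended
num_embeddings_padded = 32_256   # padded to divide evenly by tp_size
tp_size = 8

num_added_embeddings = num_embeddings - org_vocab_size
num_embeddings_per_partition = divide(num_embeddings_padded, tp_size)
print(num_added_embeddings, num_embeddings_per_partition)  # 16 4032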
@@ -252,3 +253,16 @@ class AscendLogitsProcessor(LogitsProcessor):
            logits = logits[..., :self.org_vocab_size]
        return logits

    def forward(
        self,
        lm_head: VocabParallelEmbedding,
        hidden_states: torch.Tensor,
        # keep this for version compatibility
        sampling_metadata=None,  # type: ignore
        embedding_bias: Optional[torch.Tensor] = None,
    ) -> Optional[torch.Tensor]:
        return LogitsProcessor.forward(self,
                                       lm_head,
                                       hidden_states,
                                       embedding_bias=embedding_bias)
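The forward override added in the second hunk is a thin compatibility shim: it still accepts a sampling_metadata keyword so call sites written against the older signature keep working, but it simply drops that argument and delegates to the parent class. A minimal standalone sketch of that pattern, using simplified stand-in classes rather than the real vLLM ones:

from typing import Optional

class BaseLogitsProcessor:
    # Newer-style base signature: no sampling_metadata parameter.
    def forward(self, lm_head, hidden_states, embedding_bias=None):
        return (hidden_states, embedding_bias)

class CompatLogitsProcessor(BaseLogitsProcessor):
    def forward(
        self,
        lm_head,
        hidden_states,
        # Accepted but unused, so older callers that still pass
        # sampling_metadata do not break.
        sampling_metadata=None,
        embedding_bias: Optional[object] = None,
    ):
        return BaseLogitsProcessor.forward(
            self, lm_head, hidden_states, embedding_bias=embedding_bias)

proc = CompatLogitsProcessor()
# An old-style call site still works; sampling_metadata is ignored.
print(proc.forward(None, "hidden", sampling_metadata={}, embedding_bias=None))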