[Bugfix] Fix masked_fill_ function typo (#769)
### What this PR does / why we need it? Fix function name typo, make `mask_fill_` to `masked_fill_` ### Does this PR introduce _any_ user-facing change? No ### How was this patch tested? CI passed Signed-off-by: ApsarasX <apsarax@outlook.com>
This commit is contained in:
@@ -135,7 +135,7 @@ class AttentionMaskBuilder:
|
||||
context_len:] = self.splitfuse_mask_value
|
||||
right_tensor = attn_mask[current_row:current_row + q_len,
|
||||
context_len:seq_len]
|
||||
right_tensor.mask_fill_(
|
||||
right_tensor.masked_fill_(
|
||||
right_tensor.tril() == self.splitfuse_mask_value, 0)
|
||||
current_row += q_len
|
||||
|
||||
|
||||
Reference in New Issue
Block a user