Fix: Minicpm (#7612)
Signed-off-by: Xinyuan Tong <justinning0323@outlook.com>
This commit is contained in:
@@ -32,7 +32,7 @@ from transformers.activations import ACT2FN
|
||||
from transformers.cache_utils import DynamicCache, EncoderDecoderCache
|
||||
from transformers.modeling_outputs import BaseModelOutputWithPast, ModelOutput
|
||||
from transformers.models.whisper.modeling_whisper import (
|
||||
WHISPER_ATTENTION_CLASSES,
|
||||
WhisperAttention,
|
||||
WhisperConfig,
|
||||
WhisperEncoder,
|
||||
)
|
||||
@@ -1090,7 +1090,7 @@ class MiniCPMWhisperEncoderLayer(nn.Module):
|
||||
def __init__(self, config: WhisperConfig, layer_idx: int = None):
|
||||
super().__init__()
|
||||
self.embed_dim = config.d_model
|
||||
self.self_attn = WHISPER_ATTENTION_CLASSES[config._attn_implementation](
|
||||
self.self_attn = WhisperAttention(
|
||||
embed_dim=self.embed_dim,
|
||||
num_heads=config.encoder_attention_heads,
|
||||
dropout=config.attention_dropout,
|
||||
|
||||
Reference in New Issue
Block a user