Support llamafy/Qwen-Qwen2.5-7B-Instruct-llamafied (#2748)

commit 2329e1ddd0
parent 0f3eb1d294
Author: Xu-Chen
Date:   2025-01-07 05:56:28 +08:00
Committed-by: GitHub
Co-authored-by: chenxu02 <chenxu02@zhihu.com>


@@ -100,6 +100,7 @@ class LlamaAttention(nn.Module):
         max_position_embeddings: int = 8192,
         quant_config: Optional[QuantizationConfig] = None,
         prefix: str = "",
+        bias: bool = False,
     ) -> None:
         super().__init__()
         self.hidden_size = hidden_size
@@ -132,14 +133,14 @@ class LlamaAttention(nn.Module):
             self.head_dim,
             self.total_num_heads,
             self.total_num_kv_heads,
-            bias=False,
+            bias=bias,
             quant_config=quant_config,
             prefix=f"{prefix}.qkv_proj",
         )
         self.o_proj = RowParallelLinear(
             self.total_num_heads * self.head_dim,
             hidden_size,
-            bias=False,
+            bias=bias,
             quant_config=quant_config,
             prefix=f"{prefix}.o_proj",
         )
@@ -194,6 +195,11 @@ class LlamaDecoderLayer(nn.Module):
         )
         rope_is_neox_style = getattr(config, "rope_is_neox_style", True)
         max_position_embeddings = getattr(config, "max_position_embeddings", 8192)
+        # Support llamafy/Qwen-Qwen2.5-7B-Instruct-llamafied with attention_bias
+        # Support internlm/internlm-7b with bias
+        attention_bias = getattr(config, "attention_bias", False) or getattr(
+            config, "bias", False
+        )
         self.self_attn = LlamaAttention(
             config=config,
             hidden_size=self.hidden_size,
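
To see how that two-step fallback resolves for the affected model families, here is a minimal sketch (the resolve_attention_bias helper and the SimpleNamespace stand-ins for Hugging Face config objects are illustrative, not part of this commit):

from types import SimpleNamespace

def resolve_attention_bias(config) -> bool:
    # Same fallback as above: prefer the newer attention_bias field,
    # then fall back to the legacy bias field used by InternLM configs.
    return getattr(config, "attention_bias", False) or getattr(config, "bias", False)

qwen_llamafied = SimpleNamespace(attention_bias=True)  # llamafied Qwen2.5 style
internlm_style = SimpleNamespace(bias=True)            # internlm/internlm-7b style
plain_llama = SimpleNamespace()                        # neither field set

assert resolve_attention_bias(qwen_llamafied) is True
assert resolve_attention_bias(internlm_style) is True
assert resolve_attention_bias(plain_llama) is False
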
@@ -206,6 +212,7 @@ class LlamaDecoderLayer(nn.Module):
             max_position_embeddings=max_position_embeddings,
             quant_config=quant_config,
             prefix=f"{prefix}.self_attn",
+            bias=attention_bias,
         )
         self.mlp = LlamaMLP(
             hidden_size=self.hidden_size,
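
For intuition about what the threaded flag controls, here is a stripped-down sketch using plain torch.nn.Linear in place of vLLM's tensor-parallel layers (TinyAttentionProj and its shapes are assumptions for illustration, not the real LlamaAttention):

from torch import nn

class TinyAttentionProj(nn.Module):
    # Simplified stand-in for LlamaAttention's qkv_proj / o_proj wiring.
    def __init__(self, hidden_size: int, bias: bool = False) -> None:
        super().__init__()
        # The constructor-level bias flag is threaded into both projections,
        # mirroring bias=bias in the diff above.
        self.qkv_proj = nn.Linear(hidden_size, 3 * hidden_size, bias=bias)
        self.o_proj = nn.Linear(hidden_size, hidden_size, bias=bias)

# Qwen-style checkpoints ship bias tensors, so the layers must allocate them:
attn = TinyAttentionProj(hidden_size=64, bias=True)
assert attn.qkv_proj.bias is not None and attn.o_proj.bias is not None

# Plain Llama checkpoints do not, and the default keeps bias off:
assert TinyAttentionProj(hidden_size=64).qkv_proj.bias is None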