Rename layer_idx to layer_id for consistency (#2078)
This commit is contained in:
@@ -97,7 +97,7 @@ class Gemma2MLP(nn.Module):
|
|||||||
class Gemma2Attention(nn.Module):
|
class Gemma2Attention(nn.Module):
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
layer_idx: int,
|
layer_id: int,
|
||||||
config: PretrainedConfig,
|
config: PretrainedConfig,
|
||||||
hidden_size: int,
|
hidden_size: int,
|
||||||
num_heads: int,
|
num_heads: int,
|
||||||
@@ -109,7 +109,7 @@ class Gemma2Attention(nn.Module):
|
|||||||
quant_config: Optional[QuantizationConfig] = None,
|
quant_config: Optional[QuantizationConfig] = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.layer_idx = layer_idx
|
self.layer_id = layer_id
|
||||||
self.config = config
|
self.config = config
|
||||||
self.hidden_size = hidden_size
|
self.hidden_size = hidden_size
|
||||||
tp_size = get_tensor_model_parallel_world_size()
|
tp_size = get_tensor_model_parallel_world_size()
|
||||||
@@ -156,13 +156,13 @@ class Gemma2Attention(nn.Module):
|
|||||||
dtype=torch.get_default_dtype(),
|
dtype=torch.get_default_dtype(),
|
||||||
)
|
)
|
||||||
|
|
||||||
use_sliding_window = layer_idx % 2 == 0 and hasattr(config, "sliding_window")
|
use_sliding_window = layer_id % 2 == 0 and hasattr(config, "sliding_window")
|
||||||
self.attn = RadixAttention(
|
self.attn = RadixAttention(
|
||||||
self.num_heads,
|
self.num_heads,
|
||||||
self.head_dim,
|
self.head_dim,
|
||||||
self.scaling,
|
self.scaling,
|
||||||
num_kv_heads=self.num_kv_heads,
|
num_kv_heads=self.num_kv_heads,
|
||||||
layer_id=layer_idx,
|
layer_id=layer_id,
|
||||||
logit_cap=self.config.attn_logit_softcapping,
|
logit_cap=self.config.attn_logit_softcapping,
|
||||||
sliding_window_size=(
|
sliding_window_size=(
|
||||||
get_attention_sliding_window_size(config)
|
get_attention_sliding_window_size(config)
|
||||||
@@ -188,7 +188,7 @@ class Gemma2Attention(nn.Module):
|
|||||||
class Gemma2DecoderLayer(nn.Module):
|
class Gemma2DecoderLayer(nn.Module):
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
layer_idx: int,
|
layer_id: int,
|
||||||
config: PretrainedConfig,
|
config: PretrainedConfig,
|
||||||
cache_config=None,
|
cache_config=None,
|
||||||
quant_config: Optional[QuantizationConfig] = None,
|
quant_config: Optional[QuantizationConfig] = None,
|
||||||
@@ -196,7 +196,7 @@ class Gemma2DecoderLayer(nn.Module):
|
|||||||
super().__init__()
|
super().__init__()
|
||||||
self.hidden_size = config.hidden_size
|
self.hidden_size = config.hidden_size
|
||||||
self.self_attn = Gemma2Attention(
|
self.self_attn = Gemma2Attention(
|
||||||
layer_idx=layer_idx,
|
layer_id=layer_id,
|
||||||
config=config,
|
config=config,
|
||||||
hidden_size=self.hidden_size,
|
hidden_size=self.hidden_size,
|
||||||
num_heads=config.num_attention_heads,
|
num_heads=config.num_attention_heads,
|
||||||
@@ -269,8 +269,8 @@ class Gemma2Model(nn.Module):
|
|||||||
)
|
)
|
||||||
self.layers = nn.ModuleList(
|
self.layers = nn.ModuleList(
|
||||||
[
|
[
|
||||||
Gemma2DecoderLayer(layer_idx, config, cache_config, quant_config)
|
Gemma2DecoderLayer(layer_id, config, cache_config, quant_config)
|
||||||
for layer_idx in range(config.num_hidden_layers)
|
for layer_id in range(config.num_hidden_layers)
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
self.norm = GemmaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
self.norm = GemmaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||||
|
|||||||
@@ -223,8 +223,8 @@ class OlmoModel(nn.Module):
|
|||||||
)
|
)
|
||||||
self.layers = nn.ModuleList(
|
self.layers = nn.ModuleList(
|
||||||
[
|
[
|
||||||
OlmoDecoderLayer(config, layer_idx, quant_config)
|
OlmoDecoderLayer(config, layer_id, quant_config)
|
||||||
for layer_idx in range(config.num_hidden_layers)
|
for layer_id in range(config.num_hidden_layers)
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
self.norm = nn.LayerNorm(
|
self.norm = nn.LayerNorm(
|
||||||
@@ -250,7 +250,7 @@ class OlmoModel(nn.Module):
|
|||||||
hidden_states = input_embeds
|
hidden_states = input_embeds
|
||||||
|
|
||||||
# Apply blocks one-by-one.
|
# Apply blocks one-by-one.
|
||||||
for layer_idx, decoder_layer in enumerate(self.layers):
|
for layer_id, decoder_layer in enumerate(self.layers):
|
||||||
# shape: (batch_size, seq_len, d_model)
|
# shape: (batch_size, seq_len, d_model)
|
||||||
hidden_states = decoder_layer(
|
hidden_states = decoder_layer(
|
||||||
positions,
|
positions,
|
||||||
|
|||||||
Reference in New Issue
Block a user