diff --git a/python/sglang/srt/models/gemma2.py b/python/sglang/srt/models/gemma2.py
index 8dc5effb4..36749f22f 100644
--- a/python/sglang/srt/models/gemma2.py
+++ b/python/sglang/srt/models/gemma2.py
@@ -97,7 +97,7 @@ class Gemma2MLP(nn.Module):
 class Gemma2Attention(nn.Module):
     def __init__(
         self,
-        layer_idx: int,
+        layer_id: int,
         config: PretrainedConfig,
         hidden_size: int,
         num_heads: int,
@@ -109,7 +109,7 @@ class Gemma2Attention(nn.Module):
         quant_config: Optional[QuantizationConfig] = None,
     ) -> None:
         super().__init__()
-        self.layer_idx = layer_idx
+        self.layer_id = layer_id
         self.config = config
         self.hidden_size = hidden_size
         tp_size = get_tensor_model_parallel_world_size()
@@ -156,13 +156,13 @@ class Gemma2Attention(nn.Module):
             dtype=torch.get_default_dtype(),
         )
 
-        use_sliding_window = layer_idx % 2 == 0 and hasattr(config, "sliding_window")
+        use_sliding_window = layer_id % 2 == 0 and hasattr(config, "sliding_window")
         self.attn = RadixAttention(
             self.num_heads,
             self.head_dim,
             self.scaling,
             num_kv_heads=self.num_kv_heads,
-            layer_id=layer_idx,
+            layer_id=layer_id,
             logit_cap=self.config.attn_logit_softcapping,
             sliding_window_size=(
                 get_attention_sliding_window_size(config)
@@ -188,7 +188,7 @@ class Gemma2Attention(nn.Module):
 class Gemma2DecoderLayer(nn.Module):
     def __init__(
         self,
-        layer_idx: int,
+        layer_id: int,
         config: PretrainedConfig,
         cache_config=None,
         quant_config: Optional[QuantizationConfig] = None,
@@ -196,7 +196,7 @@ class Gemma2DecoderLayer(nn.Module):
         super().__init__()
         self.hidden_size = config.hidden_size
         self.self_attn = Gemma2Attention(
-            layer_idx=layer_idx,
+            layer_id=layer_id,
             config=config,
             hidden_size=self.hidden_size,
             num_heads=config.num_attention_heads,
@@ -269,8 +269,8 @@ class Gemma2Model(nn.Module):
         )
         self.layers = nn.ModuleList(
             [
-                Gemma2DecoderLayer(layer_idx, config, cache_config, quant_config)
-                for layer_idx in range(config.num_hidden_layers)
+                Gemma2DecoderLayer(layer_id, config, cache_config, quant_config)
+                for layer_id in range(config.num_hidden_layers)
             ]
         )
         self.norm = GemmaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
diff --git a/python/sglang/srt/models/olmo.py b/python/sglang/srt/models/olmo.py
index 2c594acc8..d073ca3b6 100644
--- a/python/sglang/srt/models/olmo.py
+++ b/python/sglang/srt/models/olmo.py
@@ -223,8 +223,8 @@ class OlmoModel(nn.Module):
         )
         self.layers = nn.ModuleList(
             [
-                OlmoDecoderLayer(config, layer_idx, quant_config)
-                for layer_idx in range(config.num_hidden_layers)
+                OlmoDecoderLayer(config, layer_id, quant_config)
+                for layer_id in range(config.num_hidden_layers)
             ]
         )
         self.norm = nn.LayerNorm(
@@ -250,7 +250,7 @@ class OlmoModel(nn.Module):
         hidden_states = input_embeds
 
         # Apply blocks one-by-one.
-        for layer_idx, decoder_layer in enumerate(self.layers):
+        for layer_id, decoder_layer in enumerate(self.layers):
             # shape: (batch_size, seq_len, d_model)
             hidden_states = decoder_layer(
                 positions,