gemma3: impl get_attention_sliding_window_size for attn init (#4823)
This commit is contained in:
@@ -47,6 +47,12 @@ from sglang.srt.model_loader.weight_utils import (
|
||||
from sglang.srt.utils import add_prefix, make_layers
|
||||
|
||||
|
||||
# Aligned with HF's implementation, using sliding window inclusive with the last token
|
||||
# SGLang assumes exclusive
|
||||
def get_attention_sliding_window_size(config):
|
||||
return config.sliding_window - 1
|
||||
|
||||
|
||||
# Adapted from:
|
||||
# https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/models/gemma3.py
|
||||
def extract_layer_index(prefix: str) -> int:
|
||||
@@ -170,7 +176,7 @@ class Gemma3Attention(nn.Module):
|
||||
self.rope_scaling = {"rope_type": "default"}
|
||||
# FIXME(mick): idk why vllm does this
|
||||
# self.sliding_window = config.interleaved_sliding_window
|
||||
self.sliding_window = config.sliding_window
|
||||
self.sliding_window = get_attention_sliding_window_size(config)
|
||||
else:
|
||||
# Global attention. Use the values in config.json.
|
||||
self.rope_theta = config.rope_theta
|
||||
@@ -184,6 +190,8 @@ class Gemma3Attention(nn.Module):
|
||||
num_kv_heads=self.num_kv_heads,
|
||||
layer_id=layer_id,
|
||||
logit_cap=getattr(self.config, "attn_logit_softcapping", None),
|
||||
# Module must also define `get_attention_sliding_window_size` to correctly initialize
|
||||
# attention backend in `ForwardBatch`.
|
||||
sliding_window_size=self.sliding_window,
|
||||
prefix=add_prefix("attn", prefix),
|
||||
)
|
||||
@@ -609,6 +617,9 @@ class Gemma3ForCausalLM(PreTrainedModel):
|
||||
def get_input_embeddings(self) -> nn.Embedding:
|
||||
return self.model.embed_tokens
|
||||
|
||||
def get_attention_sliding_window_size(self):
|
||||
return get_attention_sliding_window_size(self.config)
|
||||
|
||||
def dtype(self) -> torch.dtype:
|
||||
return next(self.parameters()).dtype
|
||||
|
||||
@@ -621,7 +632,6 @@ class Gemma3ForCausalLM(PreTrainedModel):
|
||||
input_embeds: torch.Tensor = None,
|
||||
**kwargs,
|
||||
) -> LogitsProcessor:
|
||||
|
||||
hidden_states = self.model(
|
||||
input_ids, positions, forward_batch, input_embeds, **kwargs
|
||||
)
|
||||
|
||||
@@ -268,6 +268,12 @@ class Gemma3ForConditionalGeneration(PreTrainedModel):
|
||||
def get_input_embeddings(self) -> nn.Embedding:
|
||||
return self.language_model.get_input_embeddings()
|
||||
|
||||
def get_attention_sliding_window_size(self):
|
||||
"""
|
||||
This value is used to initialize attention backends in `ForwardBatch`.
|
||||
"""
|
||||
return self.language_model.get_attention_sliding_window_size()
|
||||
|
||||
def get_image_feature(self, image_input: MultimodalInputs):
|
||||
"""
|
||||
Projects the last hidden state from the vision model into language model space.
|
||||
|
||||
Reference in New Issue
Block a user