gemma3: impl get_attention_sliding_window_size for attn init (#4823)
This commit is contained in:
@@ -47,6 +47,12 @@ from sglang.srt.model_loader.weight_utils import (
|
|||||||
from sglang.srt.utils import add_prefix, make_layers
|
from sglang.srt.utils import add_prefix, make_layers
|
||||||
|
|
||||||
|
|
||||||
|
# Aligned with HF's implementation, whose sliding window is inclusive of the last token.
# SGLang assumes an exclusive window, so the configured size is reduced by one.
def get_attention_sliding_window_size(config):
    """Return the sliding-window size in the convention SGLang expects.

    Args:
        config: Model config object exposing a ``sliding_window`` attribute
            (HF convention: inclusive of the last token).

    Returns:
        ``config.sliding_window - 1``, or ``None`` when ``sliding_window`` is
        ``None``.
    """
    window = config.sliding_window
    if window is None:
        # Avoid a TypeError on ``None - 1``.
        # NOTE(review): assumes callers treat None as "no sliding window" -- confirm.
        return None
    return window - 1
|
||||||
|
|
||||||
|
|
||||||
# Adapted from:
|
# Adapted from:
|
||||||
# https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/models/gemma3.py
|
# https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/models/gemma3.py
|
||||||
def extract_layer_index(prefix: str) -> int:
|
def extract_layer_index(prefix: str) -> int:
|
||||||
@@ -170,7 +176,7 @@ class Gemma3Attention(nn.Module):
|
|||||||
self.rope_scaling = {"rope_type": "default"}
|
self.rope_scaling = {"rope_type": "default"}
|
||||||
# FIXME(mick): idk why vllm does this
|
# FIXME(mick): idk why vllm does this
|
||||||
# self.sliding_window = config.interleaved_sliding_window
|
# self.sliding_window = config.interleaved_sliding_window
|
||||||
self.sliding_window = config.sliding_window
|
self.sliding_window = get_attention_sliding_window_size(config)
|
||||||
else:
|
else:
|
||||||
# Global attention. Use the values in config.json.
|
# Global attention. Use the values in config.json.
|
||||||
self.rope_theta = config.rope_theta
|
self.rope_theta = config.rope_theta
|
||||||
@@ -184,6 +190,8 @@ class Gemma3Attention(nn.Module):
|
|||||||
num_kv_heads=self.num_kv_heads,
|
num_kv_heads=self.num_kv_heads,
|
||||||
layer_id=layer_id,
|
layer_id=layer_id,
|
||||||
logit_cap=getattr(self.config, "attn_logit_softcapping", None),
|
logit_cap=getattr(self.config, "attn_logit_softcapping", None),
|
||||||
|
# Module must also define `get_attention_sliding_window_size` to correctly initialize
|
||||||
|
# attention backend in `ForwardBatch`.
|
||||||
sliding_window_size=self.sliding_window,
|
sliding_window_size=self.sliding_window,
|
||||||
prefix=add_prefix("attn", prefix),
|
prefix=add_prefix("attn", prefix),
|
||||||
)
|
)
|
||||||
@@ -609,6 +617,9 @@ class Gemma3ForCausalLM(PreTrainedModel):
|
|||||||
def get_input_embeddings(self) -> nn.Embedding:
    """Return the embedding module stored at ``self.model.embed_tokens``."""
    return self.model.embed_tokens
|
||||||
|
|
||||||
|
def get_attention_sliding_window_size(self):
    """Return the attention sliding-window size derived from ``self.config``.

    Delegates to the module-level ``get_attention_sliding_window_size`` helper,
    passing this model's config.
    """
    return get_attention_sliding_window_size(self.config)
|
||||||
|
|
||||||
def dtype(self) -> torch.dtype:
    """Return the dtype of the first parameter yielded by ``self.parameters()``."""
    return next(self.parameters()).dtype
|
||||||
|
|
||||||
@@ -621,7 +632,6 @@ class Gemma3ForCausalLM(PreTrainedModel):
|
|||||||
input_embeds: torch.Tensor = None,
|
input_embeds: torch.Tensor = None,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
) -> LogitsProcessor:
|
) -> LogitsProcessor:
|
||||||
|
|
||||||
hidden_states = self.model(
|
hidden_states = self.model(
|
||||||
input_ids, positions, forward_batch, input_embeds, **kwargs
|
input_ids, positions, forward_batch, input_embeds, **kwargs
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -268,6 +268,12 @@ class Gemma3ForConditionalGeneration(PreTrainedModel):
|
|||||||
def get_input_embeddings(self) -> nn.Embedding:
    """Return the input embeddings reported by ``self.language_model``."""
    return self.language_model.get_input_embeddings()
|
||||||
|
|
||||||
|
def get_attention_sliding_window_size(self):
    """
    Return the sliding-window size reported by ``self.language_model``.

    This value is used to initialize attention backends in `ForwardBatch`.
    """
    return self.language_model.get_attention_sliding_window_size()
|
||||||
|
|
||||||
def get_image_feature(self, image_input: MultimodalInputs):
|
def get_image_feature(self, image_input: MultimodalInputs):
|
||||||
"""
|
"""
|
||||||
Projects the last hidden state from the vision model into language model space.
|
Projects the last hidden state from the vision model into language model space.
|
||||||
|
|||||||
Reference in New Issue
Block a user