refactor: bug fixes and refactor for vlm (#4661)

Author: Mick
Date:   2025-03-23 13:48:49 +08:00 (committed by GitHub)
Parent: ca75741e86
Commit: 11577cedb7
31 changed files with 770 additions and 735 deletions


@@ -37,11 +37,8 @@ from sglang.srt.layers.linear import (
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.radix_attention import RadixAttention
-from sglang.srt.layers.rotary_embedding import apply_rotary_pos_emb, get_rope
-from sglang.srt.layers.vocab_parallel_embedding import (
-    ParallelLMHead,
-    VocabParallelEmbedding,
-)
+from sglang.srt.layers.rotary_embedding import apply_rotary_pos_emb
+from sglang.srt.layers.vocab_parallel_embedding import ParallelLMHead
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.model_loader.weight_utils import (
     default_weight_loader,
@@ -511,7 +508,7 @@ class Gemma3TextModel(PreTrainedModel):
         else:
             hidden_states = input_embeds

-        if len(positions.shape) == 1:
+        if positions.dim() == 1:
             positions = einops.rearrange(positions, "s -> 1 s")

         position_embeddings_global = self.rotary_emb(hidden_states, positions)
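
The change above swaps `len(positions.shape) == 1` for the equivalent, more idiomatic `positions.dim() == 1`. A minimal standalone sketch (not part of this diff) of the normalization that branch performs: a 1-D positions tensor is lifted to shape (1, s) so the rotary-embedding call below it can assume a leading batch axis.

import einops
import torch

positions = torch.arange(6)              # shape (6,): a flat position index
if positions.dim() == 1:                 # same test as len(positions.shape) == 1
    positions = einops.rearrange(positions, "s -> 1 s")
print(positions.shape)                   # torch.Size([1, 6])
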
@@ -609,11 +606,11 @@ class Gemma3ForCausalLM(PreTrainedModel):
         )
         self.post_init()

-    def get_input_embeddings(self):
+    def get_input_embeddings(self) -> nn.Embedding:
         return self.model.embed_tokens

     def dtype(self) -> torch.dtype:
-        return self.model.layers[0].mlp.gate_up_proj.weight.dtype
+        return next(self.parameters()).dtype

     @torch.no_grad()
     def forward(
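
Two small hardening changes land in this hunk: `get_input_embeddings` gains a return annotation, and `dtype` now reads the dtype off the first parameter instead of reaching into `self.model.layers[0].mlp.gate_up_proj`, so it no longer depends on a particular layer layout (which may not hold, e.g., under quantization, where the expected weight attribute can be absent). A minimal sketch of the idiom with a hypothetical module (`TinyLM` is illustrative only, not from the repo):

import torch
from torch import nn

class TinyLM(nn.Module):
    """Hypothetical stand-in for a causal LM; only the dtype logic matters."""

    def __init__(self) -> None:
        super().__init__()
        self.embed_tokens = nn.Embedding(10, 4, dtype=torch.bfloat16)

    @property
    def dtype(self) -> torch.dtype:
        # next(self.parameters()) yields the first registered parameter,
        # whatever the module layout is, so no specific layer is assumed.
        return next(self.parameters()).dtype

print(TinyLM().dtype)  # torch.bfloat16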