refactor: bug fixes and refactor for vlm (#4661)
@@ -37,11 +37,8 @@ from sglang.srt.layers.linear import (
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.radix_attention import RadixAttention
-from sglang.srt.layers.rotary_embedding import apply_rotary_pos_emb, get_rope
-from sglang.srt.layers.vocab_parallel_embedding import (
-    ParallelLMHead,
-    VocabParallelEmbedding,
-)
+from sglang.srt.layers.rotary_embedding import apply_rotary_pos_emb
+from sglang.srt.layers.vocab_parallel_embedding import ParallelLMHead
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.model_loader.weight_utils import (
     default_weight_loader,
@@ -511,7 +508,7 @@ class Gemma3TextModel(PreTrainedModel):
         else:
             hidden_states = input_embeds
 
-        if len(positions.shape) == 1:
+        if positions.dim() == 1:
             positions = einops.rearrange(positions, "s -> 1 s")
 
         position_embeddings_global = self.rotary_emb(hidden_states, positions)
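Note on the hunk above: for a torch.Tensor, positions.dim() == 1 and len(positions.shape) == 1 are equivalent; the new spelling is the more idiomatic form of the check that promotes a flat positions tensor to a batched one before the rotary embedding is applied. A minimal standalone sketch of that behaviour (illustrative only, not part of the diff; the positions name follows the hunk):

import torch
import einops

positions = torch.arange(6)                           # flat positions, shape (6,)
assert positions.dim() == len(positions.shape) == 1   # the two checks agree
positions = einops.rearrange(positions, "s -> 1 s")   # add the batch axis
print(positions.shape)                                # torch.Size([1, 6])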
@@ -609,11 +606,11 @@ class Gemma3ForCausalLM(PreTrainedModel):
         )
         self.post_init()
 
-    def get_input_embeddings(self):
+    def get_input_embeddings(self) -> nn.Embedding:
         return self.model.embed_tokens
 
     def dtype(self) -> torch.dtype:
-        return self.model.layers[0].mlp.gate_up_proj.weight.dtype
+        return next(self.parameters()).dtype
 
     @torch.no_grad()
     def forward(
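Note on the hunk above: reading the dtype from next(self.parameters()) avoids hard-coding a path to one specific weight (model.layers[0].mlp.gate_up_proj), which is fragile if that projection is fused, quantized, or otherwise restructured. A minimal standalone sketch of the pattern (illustrative only, using a generic nn.Sequential rather than the sglang class):

import torch
import torch.nn as nn

model = nn.Sequential(nn.Linear(4, 4), nn.Linear(4, 2)).to(torch.bfloat16)
print(next(model.parameters()).dtype)   # torch.bfloat16, no layer-specific path needed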