[Bug Fix] Add partial rotary factor support for Phi-4 and upgrade to transformers v4.50.0 (#3984)

Co-authored-by: Chayenne <zhaochen20@outlook.com>
This commit is contained in:
Adarsh Shirawalmath
2025-03-23 02:57:39 +05:30
committed by GitHub
parent ecbfe58bb0
commit f8f9244a61
11 changed files with 50 additions and 2125 deletions

View File

@@ -35,7 +35,7 @@ runtime_common = [
"python-multipart",
"pyzmq>=25.1.2",
"torchao>=0.7.0",
"transformers==4.48.3",
"transformers==4.50.0",
"uvicorn",
"uvloop",
"xgrammar==0.1.16",

View File

@@ -2,21 +2,12 @@ from sglang.srt.configs.chatglm import ChatGLMConfig
from sglang.srt.configs.dbrx import DbrxConfig
from sglang.srt.configs.deepseekvl2 import DeepseekVL2Config
from sglang.srt.configs.exaone import ExaoneConfig
from sglang.srt.configs.gemma3 import Gemma3Config, Gemma3TextConfig
from sglang.srt.configs.janus_pro import MultiModalityConfig
from sglang.srt.configs.qwen2_5_vl_config import (
Qwen2_5_VLConfig,
Qwen2_5_VLVisionConfig,
)
__all__ = [
"ExaoneConfig",
"ChatGLMConfig",
"DbrxConfig",
"DeepseekVL2Config",
"Qwen2_5_VLConfig",
"Qwen2_5_VLVisionConfig",
"MultiModalityConfig",
"Gemma3Config",
"Gemma3TextConfig",
]

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@@ -35,10 +35,7 @@ from sglang.srt.configs import (
DbrxConfig,
DeepseekVL2Config,
ExaoneConfig,
Gemma3Config,
Gemma3TextConfig,
MultiModalityConfig,
Qwen2_5_VLConfig,
)
from sglang.srt.connector import create_remote_connector
from sglang.srt.utils import is_remote_url
@@ -47,11 +44,8 @@ _CONFIG_REGISTRY: Dict[str, Type[PretrainedConfig]] = {
ChatGLMConfig.model_type: ChatGLMConfig,
DbrxConfig.model_type: DbrxConfig,
ExaoneConfig.model_type: ExaoneConfig,
Qwen2_5_VLConfig.model_type: Qwen2_5_VLConfig,
DeepseekVL2Config.model_type: DeepseekVL2Config,
MultiModalityConfig.model_type: MultiModalityConfig,
Gemma3Config.model_type: Gemma3Config,
Gemma3TextConfig.model_type: Gemma3TextConfig,
}
for name, cls in _CONFIG_REGISTRY.items():
@@ -223,11 +217,26 @@ def get_processor(
tokenizer_revision: Optional[str] = None,
**kwargs,
):
# pop 'revision' from kwargs if present.
revision = kwargs.pop("revision", tokenizer_revision)
config = AutoConfig.from_pretrained(
tokenizer_name,
trust_remote_code=trust_remote_code,
revision=revision,
**kwargs,
)
# fix: for Qwen2-VL model, inject default 'size' if not provided.
if config.model_type in {"qwen2_vl"}:
if "size" not in kwargs:
kwargs["size"] = {"shortest_edge": 3136, "longest_edge": 1003520}
processor = AutoProcessor.from_pretrained(
tokenizer_name,
*args,
trust_remote_code=trust_remote_code,
tokenizer_revision=tokenizer_revision,
revision=revision,
**kwargs,
)

View File

@@ -441,16 +441,12 @@ class Phi3LongRoPEScaledRotaryEmbedding(nn.Module):
):
super().__init__()
if rotary_dim != head_size:
raise ValueError(
f"`Phi3LongRoPEScaledRotaryEmbedding` does not support \
rotary_dim != head_size ({rotary_dim}!={head_size})."
)
if is_neox_style is False:
raise ValueError(
"`Phi3LongRoPEScaledRotaryEmbedding` only supports neox_style."
)
self.rotary_dim = rotary_dim
self.head_size = head_size
self.max_position_embeddings = max_position_embeddings
self.original_max_position_embeddings = original_max_position_embeddings
@@ -499,8 +495,8 @@ class Phi3LongRoPEScaledRotaryEmbedding(nn.Module):
* (
self.base
** (
torch.arange(0, self.head_size, 2, dtype=torch.float)
/ self.head_size
torch.arange(0, self.rotary_dim, 2, dtype=torch.float)
/ self.rotary_dim
)
)
)
@@ -549,8 +545,15 @@ class Phi3LongRoPEScaledRotaryEmbedding(nn.Module):
cos = cos.repeat(1, 2).unsqueeze(-2)
sin = sin.repeat(1, 2).unsqueeze(-2)
query = query * cos + _rotate_neox(query) * sin
key = key * cos + _rotate_neox(key) * sin
query_rot = query[..., : self.rotary_dim]
query_pass = query[..., self.rotary_dim :]
query_rot = query_rot * cos + _rotate_neox(query_rot) * sin
query = torch.cat((query_rot, query_pass), dim=-1)
key_rot = key[..., : self.rotary_dim]
key_pass = key[..., self.rotary_dim :]
key_rot = key_rot * cos + _rotate_neox(key_rot) * sin
key = torch.cat((key_rot, key_pass), dim=-1)
return query.flatten(-2), key.flatten(-2)

View File

@@ -21,11 +21,11 @@ from torch import nn
from transformers import (
ROPE_INIT_FUNCTIONS,
AutoModel,
Gemma3TextConfig,
PretrainedConfig,
PreTrainedModel,
)
from sglang.srt.configs.gemma3 import Gemma3TextConfig
from sglang.srt.distributed import get_tensor_model_parallel_world_size
from sglang.srt.layers.activation import GeluAndMul
from sglang.srt.layers.layernorm import Gemma3RMSNorm

View File

@@ -21,9 +21,15 @@ from typing import Dict, Iterable, List, Optional, Set, Tuple, TypedDict
import torch
from torch import nn
from transformers import AutoModel, PreTrainedModel
from transformers import (
AutoModel,
BatchFeature,
Gemma3Config,
Gemma3Processor,
PreTrainedModel,
)
from transformers.models.gemma3.processing_gemma3 import Gemma3ProcessorKwargs
from sglang.srt.configs import Gemma3Config
from sglang.srt.hf_transformers_utils import get_processor
from sglang.srt.layers.layernorm import Gemma3RMSNorm
from sglang.srt.layers.logits_processor import LogitsProcessor

View File

@@ -129,6 +129,8 @@ class LlamaAttention(nn.Module):
self.head_dim = getattr(
config, "head_dim", self.hidden_size // self.total_num_heads
)
partial_rotary_factor = getattr(config, "partial_rotary_factor", 1)
self.rotary_dim = int(partial_rotary_factor * self.head_dim)
self.q_size = self.num_heads * self.head_dim
self.kv_size = self.num_kv_heads * self.head_dim
self.scaling = self.head_dim**-0.5
@@ -154,7 +156,7 @@ class LlamaAttention(nn.Module):
self.rotary_emb = get_rope(
self.head_dim,
rotary_dim=self.head_dim,
rotary_dim=self.rotary_dim,
max_position=max_position_embeddings,
base=rope_theta,
rope_scaling=rope_scaling,

View File

@@ -34,8 +34,15 @@ from einops import rearrange
from transformers import AutoModel, Qwen2VLConfig
from transformers.activations import ACT2FN
from transformers.models.qwen2.modeling_qwen2 import Qwen2RMSNorm
from transformers.models.qwen2_5_vl import Qwen2_5_VLProcessor
from transformers.models.qwen2_5_vl.configuration_qwen2_5_vl import (
Qwen2_5_VLConfig,
Qwen2_5_VLVisionConfig,
)
from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import (
Qwen2_5_VLForConditionalGeneration,
)
from sglang.srt.configs import Qwen2_5_VLConfig, Qwen2_5_VLVisionConfig
from sglang.srt.distributed import (
get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size,
@@ -714,4 +721,3 @@ class Qwen2_5_VLForConditionalGeneration(nn.Module):
EntryClass = [Qwen2_5_VLForConditionalGeneration]
AutoModel.register(Qwen2_5_VLConfig, Qwen2_5_VLForConditionalGeneration)