[Bug Fix] Add partial rotary factor support for Phi-4 and upgrade to transformers v4.50.0 (#3984)

Co-authored-by: Chayenne <zhaochen20@outlook.com>
This commit is contained in:
Adarsh Shirawalmath
2025-03-23 02:57:39 +05:30
committed by GitHub
parent ecbfe58bb0
commit f8f9244a61
11 changed files with 50 additions and 2125 deletions

View File

@@ -35,7 +35,7 @@ runtime_common = [
"python-multipart", "python-multipart",
"pyzmq>=25.1.2", "pyzmq>=25.1.2",
"torchao>=0.7.0", "torchao>=0.7.0",
"transformers==4.48.3", "transformers==4.50.0",
"uvicorn", "uvicorn",
"uvloop", "uvloop",
"xgrammar==0.1.16", "xgrammar==0.1.16",

View File

@@ -2,21 +2,12 @@ from sglang.srt.configs.chatglm import ChatGLMConfig
from sglang.srt.configs.dbrx import DbrxConfig from sglang.srt.configs.dbrx import DbrxConfig
from sglang.srt.configs.deepseekvl2 import DeepseekVL2Config from sglang.srt.configs.deepseekvl2 import DeepseekVL2Config
from sglang.srt.configs.exaone import ExaoneConfig from sglang.srt.configs.exaone import ExaoneConfig
from sglang.srt.configs.gemma3 import Gemma3Config, Gemma3TextConfig
from sglang.srt.configs.janus_pro import MultiModalityConfig from sglang.srt.configs.janus_pro import MultiModalityConfig
from sglang.srt.configs.qwen2_5_vl_config import (
Qwen2_5_VLConfig,
Qwen2_5_VLVisionConfig,
)
__all__ = [ __all__ = [
"ExaoneConfig", "ExaoneConfig",
"ChatGLMConfig", "ChatGLMConfig",
"DbrxConfig", "DbrxConfig",
"DeepseekVL2Config", "DeepseekVL2Config",
"Qwen2_5_VLConfig",
"Qwen2_5_VLVisionConfig",
"MultiModalityConfig", "MultiModalityConfig",
"Gemma3Config",
"Gemma3TextConfig",
] ]

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@@ -35,10 +35,7 @@ from sglang.srt.configs import (
DbrxConfig, DbrxConfig,
DeepseekVL2Config, DeepseekVL2Config,
ExaoneConfig, ExaoneConfig,
Gemma3Config,
Gemma3TextConfig,
MultiModalityConfig, MultiModalityConfig,
Qwen2_5_VLConfig,
) )
from sglang.srt.connector import create_remote_connector from sglang.srt.connector import create_remote_connector
from sglang.srt.utils import is_remote_url from sglang.srt.utils import is_remote_url
@@ -47,11 +44,8 @@ _CONFIG_REGISTRY: Dict[str, Type[PretrainedConfig]] = {
ChatGLMConfig.model_type: ChatGLMConfig, ChatGLMConfig.model_type: ChatGLMConfig,
DbrxConfig.model_type: DbrxConfig, DbrxConfig.model_type: DbrxConfig,
ExaoneConfig.model_type: ExaoneConfig, ExaoneConfig.model_type: ExaoneConfig,
Qwen2_5_VLConfig.model_type: Qwen2_5_VLConfig,
DeepseekVL2Config.model_type: DeepseekVL2Config, DeepseekVL2Config.model_type: DeepseekVL2Config,
MultiModalityConfig.model_type: MultiModalityConfig, MultiModalityConfig.model_type: MultiModalityConfig,
Gemma3Config.model_type: Gemma3Config,
Gemma3TextConfig.model_type: Gemma3TextConfig,
} }
for name, cls in _CONFIG_REGISTRY.items(): for name, cls in _CONFIG_REGISTRY.items():
@@ -223,11 +217,26 @@ def get_processor(
tokenizer_revision: Optional[str] = None, tokenizer_revision: Optional[str] = None,
**kwargs, **kwargs,
): ):
# pop 'revision' from kwargs if present.
revision = kwargs.pop("revision", tokenizer_revision)
config = AutoConfig.from_pretrained(
tokenizer_name,
trust_remote_code=trust_remote_code,
revision=revision,
**kwargs,
)
# fix: for Qwen2-VL model, inject default 'size' if not provided.
if config.model_type in {"qwen2_vl"}:
if "size" not in kwargs:
kwargs["size"] = {"shortest_edge": 3136, "longest_edge": 1003520}
processor = AutoProcessor.from_pretrained( processor = AutoProcessor.from_pretrained(
tokenizer_name, tokenizer_name,
*args, *args,
trust_remote_code=trust_remote_code, trust_remote_code=trust_remote_code,
tokenizer_revision=tokenizer_revision, revision=revision,
**kwargs, **kwargs,
) )

View File

@@ -441,16 +441,12 @@ class Phi3LongRoPEScaledRotaryEmbedding(nn.Module):
): ):
super().__init__() super().__init__()
if rotary_dim != head_size:
raise ValueError(
f"`Phi3LongRoPEScaledRotaryEmbedding` does not support \
rotary_dim != head_size ({rotary_dim}!={head_size})."
)
if is_neox_style is False: if is_neox_style is False:
raise ValueError( raise ValueError(
"`Phi3LongRoPEScaledRotaryEmbedding` only supports neox_style." "`Phi3LongRoPEScaledRotaryEmbedding` only supports neox_style."
) )
self.rotary_dim = rotary_dim
self.head_size = head_size self.head_size = head_size
self.max_position_embeddings = max_position_embeddings self.max_position_embeddings = max_position_embeddings
self.original_max_position_embeddings = original_max_position_embeddings self.original_max_position_embeddings = original_max_position_embeddings
@@ -499,8 +495,8 @@ class Phi3LongRoPEScaledRotaryEmbedding(nn.Module):
* ( * (
self.base self.base
** ( ** (
torch.arange(0, self.head_size, 2, dtype=torch.float) torch.arange(0, self.rotary_dim, 2, dtype=torch.float)
/ self.head_size / self.rotary_dim
) )
) )
) )
@@ -549,8 +545,15 @@ class Phi3LongRoPEScaledRotaryEmbedding(nn.Module):
cos = cos.repeat(1, 2).unsqueeze(-2) cos = cos.repeat(1, 2).unsqueeze(-2)
sin = sin.repeat(1, 2).unsqueeze(-2) sin = sin.repeat(1, 2).unsqueeze(-2)
query = query * cos + _rotate_neox(query) * sin query_rot = query[..., : self.rotary_dim]
key = key * cos + _rotate_neox(key) * sin query_pass = query[..., self.rotary_dim :]
query_rot = query_rot * cos + _rotate_neox(query_rot) * sin
query = torch.cat((query_rot, query_pass), dim=-1)
key_rot = key[..., : self.rotary_dim]
key_pass = key[..., self.rotary_dim :]
key_rot = key_rot * cos + _rotate_neox(key_rot) * sin
key = torch.cat((key_rot, key_pass), dim=-1)
return query.flatten(-2), key.flatten(-2) return query.flatten(-2), key.flatten(-2)

View File

@@ -21,11 +21,11 @@ from torch import nn
from transformers import ( from transformers import (
ROPE_INIT_FUNCTIONS, ROPE_INIT_FUNCTIONS,
AutoModel, AutoModel,
Gemma3TextConfig,
PretrainedConfig, PretrainedConfig,
PreTrainedModel, PreTrainedModel,
) )
from sglang.srt.configs.gemma3 import Gemma3TextConfig
from sglang.srt.distributed import get_tensor_model_parallel_world_size from sglang.srt.distributed import get_tensor_model_parallel_world_size
from sglang.srt.layers.activation import GeluAndMul from sglang.srt.layers.activation import GeluAndMul
from sglang.srt.layers.layernorm import Gemma3RMSNorm from sglang.srt.layers.layernorm import Gemma3RMSNorm

View File

@@ -21,9 +21,15 @@ from typing import Dict, Iterable, List, Optional, Set, Tuple, TypedDict
import torch import torch
from torch import nn from torch import nn
from transformers import AutoModel, PreTrainedModel from transformers import (
AutoModel,
BatchFeature,
Gemma3Config,
Gemma3Processor,
PreTrainedModel,
)
from transformers.models.gemma3.processing_gemma3 import Gemma3ProcessorKwargs
from sglang.srt.configs import Gemma3Config
from sglang.srt.hf_transformers_utils import get_processor from sglang.srt.hf_transformers_utils import get_processor
from sglang.srt.layers.layernorm import Gemma3RMSNorm from sglang.srt.layers.layernorm import Gemma3RMSNorm
from sglang.srt.layers.logits_processor import LogitsProcessor from sglang.srt.layers.logits_processor import LogitsProcessor

View File

@@ -129,6 +129,8 @@ class LlamaAttention(nn.Module):
self.head_dim = getattr( self.head_dim = getattr(
config, "head_dim", self.hidden_size // self.total_num_heads config, "head_dim", self.hidden_size // self.total_num_heads
) )
partial_rotary_factor = getattr(config, "partial_rotary_factor", 1)
self.rotary_dim = int(partial_rotary_factor * self.head_dim)
self.q_size = self.num_heads * self.head_dim self.q_size = self.num_heads * self.head_dim
self.kv_size = self.num_kv_heads * self.head_dim self.kv_size = self.num_kv_heads * self.head_dim
self.scaling = self.head_dim**-0.5 self.scaling = self.head_dim**-0.5
@@ -154,7 +156,7 @@ class LlamaAttention(nn.Module):
self.rotary_emb = get_rope( self.rotary_emb = get_rope(
self.head_dim, self.head_dim,
rotary_dim=self.head_dim, rotary_dim=self.rotary_dim,
max_position=max_position_embeddings, max_position=max_position_embeddings,
base=rope_theta, base=rope_theta,
rope_scaling=rope_scaling, rope_scaling=rope_scaling,

View File

@@ -34,8 +34,15 @@ from einops import rearrange
from transformers import AutoModel, Qwen2VLConfig from transformers import AutoModel, Qwen2VLConfig
from transformers.activations import ACT2FN from transformers.activations import ACT2FN
from transformers.models.qwen2.modeling_qwen2 import Qwen2RMSNorm from transformers.models.qwen2.modeling_qwen2 import Qwen2RMSNorm
from transformers.models.qwen2_5_vl import Qwen2_5_VLProcessor
from transformers.models.qwen2_5_vl.configuration_qwen2_5_vl import (
Qwen2_5_VLConfig,
Qwen2_5_VLVisionConfig,
)
from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import (
Qwen2_5_VLForConditionalGeneration,
)
from sglang.srt.configs import Qwen2_5_VLConfig, Qwen2_5_VLVisionConfig
from sglang.srt.distributed import ( from sglang.srt.distributed import (
get_tensor_model_parallel_rank, get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size, get_tensor_model_parallel_world_size,
@@ -714,4 +721,3 @@ class Qwen2_5_VLForConditionalGeneration(nn.Module):
EntryClass = [Qwen2_5_VLForConditionalGeneration] EntryClass = [Qwen2_5_VLForConditionalGeneration]
AutoModel.register(Qwen2_5_VLConfig, Qwen2_5_VLForConditionalGeneration)

View File

@@ -20,7 +20,7 @@ pip install flashinfer_python==0.2.3 --find-links ${FLASHINFER_REPO} --force-rei
pip install torch_memory_saver --force-reinstall pip install torch_memory_saver --force-reinstall
pip install transformers==4.48.3 sentence_transformers accelerate==1.4.0 peft pandas datasets pip install transformers==4.50.0 sentence_transformers accelerate==1.4.0 peft pandas datasets
# For compling xgrammar kernels # For compling xgrammar kernels
pip install cuda-python nvidia-cuda-nvrtc-cu12 pip install cuda-python nvidia-cuda-nvrtc-cu12