[Bug Fix] Add partial rotary factor support for Phi-4 and upgrade to transformers v4.50.0 (#3984)
Co-authored-by: Chayenne <zhaochen20@outlook.com>
committed by GitHub
parent ecbfe58bb0
commit f8f9244a61
@@ -35,7 +35,7 @@ runtime_common = [
     "python-multipart",
     "pyzmq>=25.1.2",
     "torchao>=0.7.0",
-    "transformers==4.48.3",
+    "transformers==4.50.0",
     "uvicorn",
     "uvloop",
     "xgrammar==0.1.16",
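The version bump above is what enables the config clean-up in the following hunks: transformers 4.50.0 ships the Gemma3 and Qwen2.5-VL configuration classes upstream. A minimal post-upgrade sanity check is sketched below (not part of this commit, and it assumes `packaging` is installed):

```python
# Hypothetical post-upgrade check: confirm the pinned transformers version and
# that the config classes removed from sglang's local copies import from upstream.
from packaging.version import Version

import transformers
from transformers import Gemma3Config, Gemma3TextConfig  # shipped since 4.50.0
from transformers.models.qwen2_5_vl.configuration_qwen2_5_vl import (
    Qwen2_5_VLConfig,
    Qwen2_5_VLVisionConfig,
)

assert Version(transformers.__version__) >= Version("4.50.0"), transformers.__version__
```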
@@ -2,21 +2,12 @@ from sglang.srt.configs.chatglm import ChatGLMConfig
 from sglang.srt.configs.dbrx import DbrxConfig
 from sglang.srt.configs.deepseekvl2 import DeepseekVL2Config
 from sglang.srt.configs.exaone import ExaoneConfig
-from sglang.srt.configs.gemma3 import Gemma3Config, Gemma3TextConfig
 from sglang.srt.configs.janus_pro import MultiModalityConfig
-from sglang.srt.configs.qwen2_5_vl_config import (
-    Qwen2_5_VLConfig,
-    Qwen2_5_VLVisionConfig,
-)

 __all__ = [
     "ExaoneConfig",
     "ChatGLMConfig",
     "DbrxConfig",
     "DeepseekVL2Config",
-    "Qwen2_5_VLConfig",
-    "Qwen2_5_VLVisionConfig",
     "MultiModalityConfig",
-    "Gemma3Config",
-    "Gemma3TextConfig",
 ]
(Two file diffs suppressed because they are too large.)
@@ -35,10 +35,7 @@ from sglang.srt.configs import (
     DbrxConfig,
     DeepseekVL2Config,
     ExaoneConfig,
-    Gemma3Config,
-    Gemma3TextConfig,
     MultiModalityConfig,
-    Qwen2_5_VLConfig,
 )
 from sglang.srt.connector import create_remote_connector
 from sglang.srt.utils import is_remote_url
@@ -47,11 +44,8 @@ _CONFIG_REGISTRY: Dict[str, Type[PretrainedConfig]] = {
     ChatGLMConfig.model_type: ChatGLMConfig,
     DbrxConfig.model_type: DbrxConfig,
     ExaoneConfig.model_type: ExaoneConfig,
-    Qwen2_5_VLConfig.model_type: Qwen2_5_VLConfig,
     DeepseekVL2Config.model_type: DeepseekVL2Config,
     MultiModalityConfig.model_type: MultiModalityConfig,
-    Gemma3Config.model_type: Gemma3Config,
-    Gemma3TextConfig.model_type: Gemma3TextConfig,
 }

 for name, cls in _CONFIG_REGISTRY.items():
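For context, a compact sketch of the local-registration pattern this hunk trims (illustrative names, not the full sglang registry): only config classes that transformers 4.50 does not ship natively stay registered locally, while Gemma3 and Qwen2.5-VL no longer need it because upstream registers them itself.

```python
# Illustrative sketch; ToyConfig is a hypothetical stand-in for the classes
# that remain in the registry (e.g. ExaoneConfig).
from typing import Dict, Type

from transformers import AutoConfig, PretrainedConfig


class ToyConfig(PretrainedConfig):
    model_type = "toy_model"


_CONFIG_REGISTRY: Dict[str, Type[PretrainedConfig]] = {
    ToyConfig.model_type: ToyConfig,
}

for name, cls in _CONFIG_REGISTRY.items():
    # Register with AutoConfig so AutoConfig.from_pretrained can resolve the
    # custom model_type.
    AutoConfig.register(name, cls, exist_ok=True)
```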
@@ -223,11 +217,26 @@ def get_processor(
     tokenizer_revision: Optional[str] = None,
     **kwargs,
 ):
+    # pop 'revision' from kwargs if present.
+    revision = kwargs.pop("revision", tokenizer_revision)
+
+    config = AutoConfig.from_pretrained(
+        tokenizer_name,
+        trust_remote_code=trust_remote_code,
+        revision=revision,
+        **kwargs,
+    )
+
+    # fix: for Qwen2-VL model, inject default 'size' if not provided.
+    if config.model_type in {"qwen2_vl"}:
+        if "size" not in kwargs:
+            kwargs["size"] = {"shortest_edge": 3136, "longest_edge": 1003520}
+
     processor = AutoProcessor.from_pretrained(
         tokenizer_name,
         *args,
         trust_remote_code=trust_remote_code,
-        tokenizer_revision=tokenizer_revision,
+        revision=revision,
         **kwargs,
     )
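A self-contained sketch of the loading path added above (helper name and comments are illustrative, not sglang's actual API): the loader resolves a single `revision`, inspects the config first, and injects a default image `size` only for Qwen2-VL checkpoints that do not provide one.

```python
# Standalone sketch mirroring the logic above; not sglang's real helper.
from transformers import AutoConfig, AutoProcessor


def load_processor_sketch(model_id: str, trust_remote_code: bool = False, **kwargs):
    revision = kwargs.pop("revision", None)
    config = AutoConfig.from_pretrained(
        model_id, trust_remote_code=trust_remote_code, revision=revision, **kwargs
    )
    if config.model_type in {"qwen2_vl"} and "size" not in kwargs:
        # Same default bounds as the hunk above: 3136 = 4 * 28 * 28 pixels,
        # 1003520 = 1280 * 28 * 28 pixels.
        kwargs["size"] = {"shortest_edge": 3136, "longest_edge": 1003520}
    return AutoProcessor.from_pretrained(
        model_id, trust_remote_code=trust_remote_code, revision=revision, **kwargs
    )
```

For a Qwen2-VL checkpoint (e.g. Qwen/Qwen2-VL-7B-Instruct) the injected default keeps the image processor's pixel bounds stable across transformers versions; other model types are unaffected.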
@@ -441,16 +441,12 @@ class Phi3LongRoPEScaledRotaryEmbedding(nn.Module):
     ):
         super().__init__()

-        if rotary_dim != head_size:
-            raise ValueError(
-                f"`Phi3LongRoPEScaledRotaryEmbedding` does not support \
-                    rotary_dim != head_size ({rotary_dim}!={head_size})."
-            )
         if is_neox_style is False:
             raise ValueError(
                 "`Phi3LongRoPEScaledRotaryEmbedding` only supports neox_style."
             )

+        self.rotary_dim = rotary_dim
         self.head_size = head_size
         self.max_position_embeddings = max_position_embeddings
         self.original_max_position_embeddings = original_max_position_embeddings
@@ -499,8 +495,8 @@ class Phi3LongRoPEScaledRotaryEmbedding(nn.Module):
             * (
                 self.base
                 ** (
-                    torch.arange(0, self.head_size, 2, dtype=torch.float)
-                    / self.head_size
+                    torch.arange(0, self.rotary_dim, 2, dtype=torch.float)
+                    / self.rotary_dim
                 )
             )
         )
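The effect of switching the exponent base from `head_size` to `rotary_dim` is easiest to see numerically. A minimal sketch with illustrative values (not read from a Phi-4 checkpoint):

```python
# Sketch: inverse frequencies are now computed over the rotated slice only.
import torch

head_size, rotary_dim, base = 128, 96, 10000.0  # illustrative values

inv_freq = 1.0 / (
    base ** (torch.arange(0, rotary_dim, 2, dtype=torch.float) / rotary_dim)
)
assert inv_freq.shape == (rotary_dim // 2,)  # 48 frequencies for 96 rotated dims
```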
@@ -549,8 +545,15 @@ class Phi3LongRoPEScaledRotaryEmbedding(nn.Module):
         cos = cos.repeat(1, 2).unsqueeze(-2)
         sin = sin.repeat(1, 2).unsqueeze(-2)

-        query = query * cos + _rotate_neox(query) * sin
-        key = key * cos + _rotate_neox(key) * sin
+        query_rot = query[..., : self.rotary_dim]
+        query_pass = query[..., self.rotary_dim :]
+        query_rot = query_rot * cos + _rotate_neox(query_rot) * sin
+        query = torch.cat((query_rot, query_pass), dim=-1)
+
+        key_rot = key[..., : self.rotary_dim]
+        key_pass = key[..., self.rotary_dim :]
+        key_rot = key_rot * cos + _rotate_neox(key_rot) * sin
+        key = torch.cat((key_rot, key_pass), dim=-1)

         return query.flatten(-2), key.flatten(-2)
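A compact, self-contained sketch of the rotate-then-concatenate pattern introduced above; `rotate_neox` stands in for the module's `_rotate_neox` helper and the shapes are illustrative:

```python
import torch


def rotate_neox(x: torch.Tensor) -> torch.Tensor:
    # NeoX-style rotation: split the last dim in half, swap, negate second half.
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)


def apply_partial_rotary(q, k, cos, sin, rotary_dim):
    # Rotate only the first `rotary_dim` channels; the rest pass through unchanged.
    q_rot, q_pass = q[..., :rotary_dim], q[..., rotary_dim:]
    k_rot, k_pass = k[..., :rotary_dim], k[..., rotary_dim:]
    q_rot = q_rot * cos + rotate_neox(q_rot) * sin
    k_rot = k_rot * cos + rotate_neox(k_rot) * sin
    return torch.cat((q_rot, q_pass), dim=-1), torch.cat((k_rot, k_pass), dim=-1)


# Illustrative shapes: 128-dim heads, 96 rotated dims (partial factor 0.75).
q = torch.randn(1, 4, 128)                     # (tokens, heads, head_size)
k = torch.randn(1, 4, 128)
angles = torch.rand(1, 1, 48).repeat(1, 1, 2)  # duplicated halves, NeoX layout
q_out, k_out = apply_partial_rotary(q, k, angles.cos(), angles.sin(), rotary_dim=96)
assert torch.equal(q_out[..., 96:], q[..., 96:])  # pass-through dims untouched
```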
@@ -21,11 +21,11 @@ from torch import nn
 from transformers import (
     ROPE_INIT_FUNCTIONS,
     AutoModel,
+    Gemma3TextConfig,
     PretrainedConfig,
     PreTrainedModel,
 )

-from sglang.srt.configs.gemma3 import Gemma3TextConfig
 from sglang.srt.distributed import get_tensor_model_parallel_world_size
 from sglang.srt.layers.activation import GeluAndMul
 from sglang.srt.layers.layernorm import Gemma3RMSNorm
@@ -21,9 +21,15 @@ from typing import Dict, Iterable, List, Optional, Set, Tuple, TypedDict

 import torch
 from torch import nn
-from transformers import AutoModel, PreTrainedModel
+from transformers import (
+    AutoModel,
+    BatchFeature,
+    Gemma3Config,
+    Gemma3Processor,
+    PreTrainedModel,
+)
+from transformers.models.gemma3.processing_gemma3 import Gemma3ProcessorKwargs

-from sglang.srt.configs import Gemma3Config
 from sglang.srt.hf_transformers_utils import get_processor
 from sglang.srt.layers.layernorm import Gemma3RMSNorm
 from sglang.srt.layers.logits_processor import LogitsProcessor
@@ -129,6 +129,8 @@ class LlamaAttention(nn.Module):
         self.head_dim = getattr(
             config, "head_dim", self.hidden_size // self.total_num_heads
         )
+        partial_rotary_factor = getattr(config, "partial_rotary_factor", 1)
+        self.rotary_dim = int(partial_rotary_factor * self.head_dim)
         self.q_size = self.num_heads * self.head_dim
         self.kv_size = self.num_kv_heads * self.head_dim
         self.scaling = self.head_dim**-0.5
@@ -154,7 +156,7 @@ class LlamaAttention(nn.Module):

         self.rotary_emb = get_rope(
             self.head_dim,
-            rotary_dim=self.head_dim,
+            rotary_dim=self.rotary_dim,
             max_position=max_position_embeddings,
             base=rope_theta,
             rope_scaling=rope_scaling,
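Taken together, the two Llama hunks mean the rotary dimension now tracks `partial_rotary_factor` instead of always equalling `head_dim`. A runnable sketch with hypothetical config values (the factor 0.75 is only an example, not read from a Phi-4 checkpoint):

```python
# Hypothetical config; sglang reads the real values from the HF config object.
from types import SimpleNamespace

config = SimpleNamespace(
    hidden_size=5120,
    num_attention_heads=40,
    head_dim=128,
    partial_rotary_factor=0.75,  # example value; 1 reproduces the old behaviour
)

head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
partial_rotary_factor = getattr(config, "partial_rotary_factor", 1)
rotary_dim = int(partial_rotary_factor * head_dim)

print(head_dim, rotary_dim)  # 128 96: get_rope(..., rotary_dim=96) rotates 96 of 128 dims
```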
@@ -34,8 +34,15 @@ from einops import rearrange
 from transformers import AutoModel, Qwen2VLConfig
 from transformers.activations import ACT2FN
 from transformers.models.qwen2.modeling_qwen2 import Qwen2RMSNorm
+from transformers.models.qwen2_5_vl import Qwen2_5_VLProcessor
+from transformers.models.qwen2_5_vl.configuration_qwen2_5_vl import (
+    Qwen2_5_VLConfig,
+    Qwen2_5_VLVisionConfig,
+)
+from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import (
+    Qwen2_5_VLForConditionalGeneration,
+)

-from sglang.srt.configs import Qwen2_5_VLConfig, Qwen2_5_VLVisionConfig
 from sglang.srt.distributed import (
     get_tensor_model_parallel_rank,
     get_tensor_model_parallel_world_size,
@@ -714,4 +721,3 @@ class Qwen2_5_VLForConditionalGeneration(nn.Module):


 EntryClass = [Qwen2_5_VLForConditionalGeneration]
-AutoModel.register(Qwen2_5_VLConfig, Qwen2_5_VLForConditionalGeneration)