[Bug Fix] Add partial rotary factor support for Phi-4 and upgrade to transformers v4.50.0 (#3984)
Co-authored-by: Chayenne <zhaochen20@outlook.com>
This commit is contained in:
committed by
GitHub
parent
ecbfe58bb0
commit
f8f9244a61
@@ -35,7 +35,7 @@ runtime_common = [
|
|||||||
"python-multipart",
|
"python-multipart",
|
||||||
"pyzmq>=25.1.2",
|
"pyzmq>=25.1.2",
|
||||||
"torchao>=0.7.0",
|
"torchao>=0.7.0",
|
||||||
"transformers==4.48.3",
|
"transformers==4.50.0",
|
||||||
"uvicorn",
|
"uvicorn",
|
||||||
"uvloop",
|
"uvloop",
|
||||||
"xgrammar==0.1.16",
|
"xgrammar==0.1.16",
|
||||||
|
|||||||
@@ -2,21 +2,12 @@ from sglang.srt.configs.chatglm import ChatGLMConfig
|
|||||||
from sglang.srt.configs.dbrx import DbrxConfig
|
from sglang.srt.configs.dbrx import DbrxConfig
|
||||||
from sglang.srt.configs.deepseekvl2 import DeepseekVL2Config
|
from sglang.srt.configs.deepseekvl2 import DeepseekVL2Config
|
||||||
from sglang.srt.configs.exaone import ExaoneConfig
|
from sglang.srt.configs.exaone import ExaoneConfig
|
||||||
from sglang.srt.configs.gemma3 import Gemma3Config, Gemma3TextConfig
|
|
||||||
from sglang.srt.configs.janus_pro import MultiModalityConfig
|
from sglang.srt.configs.janus_pro import MultiModalityConfig
|
||||||
from sglang.srt.configs.qwen2_5_vl_config import (
|
|
||||||
Qwen2_5_VLConfig,
|
|
||||||
Qwen2_5_VLVisionConfig,
|
|
||||||
)
|
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"ExaoneConfig",
|
"ExaoneConfig",
|
||||||
"ChatGLMConfig",
|
"ChatGLMConfig",
|
||||||
"DbrxConfig",
|
"DbrxConfig",
|
||||||
"DeepseekVL2Config",
|
"DeepseekVL2Config",
|
||||||
"Qwen2_5_VLConfig",
|
|
||||||
"Qwen2_5_VLVisionConfig",
|
|
||||||
"MultiModalityConfig",
|
"MultiModalityConfig",
|
||||||
"Gemma3Config",
|
|
||||||
"Gemma3TextConfig",
|
|
||||||
]
|
]
|
||||||
|
|||||||
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
@@ -35,10 +35,7 @@ from sglang.srt.configs import (
|
|||||||
DbrxConfig,
|
DbrxConfig,
|
||||||
DeepseekVL2Config,
|
DeepseekVL2Config,
|
||||||
ExaoneConfig,
|
ExaoneConfig,
|
||||||
Gemma3Config,
|
|
||||||
Gemma3TextConfig,
|
|
||||||
MultiModalityConfig,
|
MultiModalityConfig,
|
||||||
Qwen2_5_VLConfig,
|
|
||||||
)
|
)
|
||||||
from sglang.srt.connector import create_remote_connector
|
from sglang.srt.connector import create_remote_connector
|
||||||
from sglang.srt.utils import is_remote_url
|
from sglang.srt.utils import is_remote_url
|
||||||
@@ -47,11 +44,8 @@ _CONFIG_REGISTRY: Dict[str, Type[PretrainedConfig]] = {
|
|||||||
ChatGLMConfig.model_type: ChatGLMConfig,
|
ChatGLMConfig.model_type: ChatGLMConfig,
|
||||||
DbrxConfig.model_type: DbrxConfig,
|
DbrxConfig.model_type: DbrxConfig,
|
||||||
ExaoneConfig.model_type: ExaoneConfig,
|
ExaoneConfig.model_type: ExaoneConfig,
|
||||||
Qwen2_5_VLConfig.model_type: Qwen2_5_VLConfig,
|
|
||||||
DeepseekVL2Config.model_type: DeepseekVL2Config,
|
DeepseekVL2Config.model_type: DeepseekVL2Config,
|
||||||
MultiModalityConfig.model_type: MultiModalityConfig,
|
MultiModalityConfig.model_type: MultiModalityConfig,
|
||||||
Gemma3Config.model_type: Gemma3Config,
|
|
||||||
Gemma3TextConfig.model_type: Gemma3TextConfig,
|
|
||||||
}
|
}
|
||||||
|
|
||||||
for name, cls in _CONFIG_REGISTRY.items():
|
for name, cls in _CONFIG_REGISTRY.items():
|
||||||
@@ -223,11 +217,26 @@ def get_processor(
|
|||||||
tokenizer_revision: Optional[str] = None,
|
tokenizer_revision: Optional[str] = None,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
):
|
):
|
||||||
|
# pop 'revision' from kwargs if present.
|
||||||
|
revision = kwargs.pop("revision", tokenizer_revision)
|
||||||
|
|
||||||
|
config = AutoConfig.from_pretrained(
|
||||||
|
tokenizer_name,
|
||||||
|
trust_remote_code=trust_remote_code,
|
||||||
|
revision=revision,
|
||||||
|
**kwargs,
|
||||||
|
)
|
||||||
|
|
||||||
|
# fix: for Qwen2-VL model, inject default 'size' if not provided.
|
||||||
|
if config.model_type in {"qwen2_vl"}:
|
||||||
|
if "size" not in kwargs:
|
||||||
|
kwargs["size"] = {"shortest_edge": 3136, "longest_edge": 1003520}
|
||||||
|
|
||||||
processor = AutoProcessor.from_pretrained(
|
processor = AutoProcessor.from_pretrained(
|
||||||
tokenizer_name,
|
tokenizer_name,
|
||||||
*args,
|
*args,
|
||||||
trust_remote_code=trust_remote_code,
|
trust_remote_code=trust_remote_code,
|
||||||
tokenizer_revision=tokenizer_revision,
|
revision=revision,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@@ -441,16 +441,12 @@ class Phi3LongRoPEScaledRotaryEmbedding(nn.Module):
|
|||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
if rotary_dim != head_size:
|
|
||||||
raise ValueError(
|
|
||||||
f"`Phi3LongRoPEScaledRotaryEmbedding` does not support \
|
|
||||||
rotary_dim != head_size ({rotary_dim}!={head_size})."
|
|
||||||
)
|
|
||||||
if is_neox_style is False:
|
if is_neox_style is False:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"`Phi3LongRoPEScaledRotaryEmbedding` only supports neox_style."
|
"`Phi3LongRoPEScaledRotaryEmbedding` only supports neox_style."
|
||||||
)
|
)
|
||||||
|
|
||||||
|
self.rotary_dim = rotary_dim
|
||||||
self.head_size = head_size
|
self.head_size = head_size
|
||||||
self.max_position_embeddings = max_position_embeddings
|
self.max_position_embeddings = max_position_embeddings
|
||||||
self.original_max_position_embeddings = original_max_position_embeddings
|
self.original_max_position_embeddings = original_max_position_embeddings
|
||||||
@@ -499,8 +495,8 @@ class Phi3LongRoPEScaledRotaryEmbedding(nn.Module):
|
|||||||
* (
|
* (
|
||||||
self.base
|
self.base
|
||||||
** (
|
** (
|
||||||
torch.arange(0, self.head_size, 2, dtype=torch.float)
|
torch.arange(0, self.rotary_dim, 2, dtype=torch.float)
|
||||||
/ self.head_size
|
/ self.rotary_dim
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
@@ -549,8 +545,15 @@ class Phi3LongRoPEScaledRotaryEmbedding(nn.Module):
|
|||||||
cos = cos.repeat(1, 2).unsqueeze(-2)
|
cos = cos.repeat(1, 2).unsqueeze(-2)
|
||||||
sin = sin.repeat(1, 2).unsqueeze(-2)
|
sin = sin.repeat(1, 2).unsqueeze(-2)
|
||||||
|
|
||||||
query = query * cos + _rotate_neox(query) * sin
|
query_rot = query[..., : self.rotary_dim]
|
||||||
key = key * cos + _rotate_neox(key) * sin
|
query_pass = query[..., self.rotary_dim :]
|
||||||
|
query_rot = query_rot * cos + _rotate_neox(query_rot) * sin
|
||||||
|
query = torch.cat((query_rot, query_pass), dim=-1)
|
||||||
|
|
||||||
|
key_rot = key[..., : self.rotary_dim]
|
||||||
|
key_pass = key[..., self.rotary_dim :]
|
||||||
|
key_rot = key_rot * cos + _rotate_neox(key_rot) * sin
|
||||||
|
key = torch.cat((key_rot, key_pass), dim=-1)
|
||||||
|
|
||||||
return query.flatten(-2), key.flatten(-2)
|
return query.flatten(-2), key.flatten(-2)
|
||||||
|
|
||||||
|
|||||||
@@ -21,11 +21,11 @@ from torch import nn
|
|||||||
from transformers import (
|
from transformers import (
|
||||||
ROPE_INIT_FUNCTIONS,
|
ROPE_INIT_FUNCTIONS,
|
||||||
AutoModel,
|
AutoModel,
|
||||||
|
Gemma3TextConfig,
|
||||||
PretrainedConfig,
|
PretrainedConfig,
|
||||||
PreTrainedModel,
|
PreTrainedModel,
|
||||||
)
|
)
|
||||||
|
|
||||||
from sglang.srt.configs.gemma3 import Gemma3TextConfig
|
|
||||||
from sglang.srt.distributed import get_tensor_model_parallel_world_size
|
from sglang.srt.distributed import get_tensor_model_parallel_world_size
|
||||||
from sglang.srt.layers.activation import GeluAndMul
|
from sglang.srt.layers.activation import GeluAndMul
|
||||||
from sglang.srt.layers.layernorm import Gemma3RMSNorm
|
from sglang.srt.layers.layernorm import Gemma3RMSNorm
|
||||||
|
|||||||
@@ -21,9 +21,15 @@ from typing import Dict, Iterable, List, Optional, Set, Tuple, TypedDict
|
|||||||
|
|
||||||
import torch
|
import torch
|
||||||
from torch import nn
|
from torch import nn
|
||||||
from transformers import AutoModel, PreTrainedModel
|
from transformers import (
|
||||||
|
AutoModel,
|
||||||
|
BatchFeature,
|
||||||
|
Gemma3Config,
|
||||||
|
Gemma3Processor,
|
||||||
|
PreTrainedModel,
|
||||||
|
)
|
||||||
|
from transformers.models.gemma3.processing_gemma3 import Gemma3ProcessorKwargs
|
||||||
|
|
||||||
from sglang.srt.configs import Gemma3Config
|
|
||||||
from sglang.srt.hf_transformers_utils import get_processor
|
from sglang.srt.hf_transformers_utils import get_processor
|
||||||
from sglang.srt.layers.layernorm import Gemma3RMSNorm
|
from sglang.srt.layers.layernorm import Gemma3RMSNorm
|
||||||
from sglang.srt.layers.logits_processor import LogitsProcessor
|
from sglang.srt.layers.logits_processor import LogitsProcessor
|
||||||
|
|||||||
@@ -129,6 +129,8 @@ class LlamaAttention(nn.Module):
|
|||||||
self.head_dim = getattr(
|
self.head_dim = getattr(
|
||||||
config, "head_dim", self.hidden_size // self.total_num_heads
|
config, "head_dim", self.hidden_size // self.total_num_heads
|
||||||
)
|
)
|
||||||
|
partial_rotary_factor = getattr(config, "partial_rotary_factor", 1)
|
||||||
|
self.rotary_dim = int(partial_rotary_factor * self.head_dim)
|
||||||
self.q_size = self.num_heads * self.head_dim
|
self.q_size = self.num_heads * self.head_dim
|
||||||
self.kv_size = self.num_kv_heads * self.head_dim
|
self.kv_size = self.num_kv_heads * self.head_dim
|
||||||
self.scaling = self.head_dim**-0.5
|
self.scaling = self.head_dim**-0.5
|
||||||
@@ -154,7 +156,7 @@ class LlamaAttention(nn.Module):
|
|||||||
|
|
||||||
self.rotary_emb = get_rope(
|
self.rotary_emb = get_rope(
|
||||||
self.head_dim,
|
self.head_dim,
|
||||||
rotary_dim=self.head_dim,
|
rotary_dim=self.rotary_dim,
|
||||||
max_position=max_position_embeddings,
|
max_position=max_position_embeddings,
|
||||||
base=rope_theta,
|
base=rope_theta,
|
||||||
rope_scaling=rope_scaling,
|
rope_scaling=rope_scaling,
|
||||||
|
|||||||
@@ -34,8 +34,15 @@ from einops import rearrange
|
|||||||
from transformers import AutoModel, Qwen2VLConfig
|
from transformers import AutoModel, Qwen2VLConfig
|
||||||
from transformers.activations import ACT2FN
|
from transformers.activations import ACT2FN
|
||||||
from transformers.models.qwen2.modeling_qwen2 import Qwen2RMSNorm
|
from transformers.models.qwen2.modeling_qwen2 import Qwen2RMSNorm
|
||||||
|
from transformers.models.qwen2_5_vl import Qwen2_5_VLProcessor
|
||||||
|
from transformers.models.qwen2_5_vl.configuration_qwen2_5_vl import (
|
||||||
|
Qwen2_5_VLConfig,
|
||||||
|
Qwen2_5_VLVisionConfig,
|
||||||
|
)
|
||||||
|
from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import (
|
||||||
|
Qwen2_5_VLForConditionalGeneration,
|
||||||
|
)
|
||||||
|
|
||||||
from sglang.srt.configs import Qwen2_5_VLConfig, Qwen2_5_VLVisionConfig
|
|
||||||
from sglang.srt.distributed import (
|
from sglang.srt.distributed import (
|
||||||
get_tensor_model_parallel_rank,
|
get_tensor_model_parallel_rank,
|
||||||
get_tensor_model_parallel_world_size,
|
get_tensor_model_parallel_world_size,
|
||||||
@@ -714,4 +721,3 @@ class Qwen2_5_VLForConditionalGeneration(nn.Module):
|
|||||||
|
|
||||||
|
|
||||||
EntryClass = [Qwen2_5_VLForConditionalGeneration]
|
EntryClass = [Qwen2_5_VLForConditionalGeneration]
|
||||||
AutoModel.register(Qwen2_5_VLConfig, Qwen2_5_VLForConditionalGeneration)
|
|
||||||
|
|||||||
@@ -20,7 +20,7 @@ pip install flashinfer_python==0.2.3 --find-links ${FLASHINFER_REPO} --force-rei
|
|||||||
|
|
||||||
pip install torch_memory_saver --force-reinstall
|
pip install torch_memory_saver --force-reinstall
|
||||||
|
|
||||||
pip install transformers==4.48.3 sentence_transformers accelerate==1.4.0 peft pandas datasets
|
pip install transformers==4.50.0 sentence_transformers accelerate==1.4.0 peft pandas datasets
|
||||||
|
|
||||||
# For compling xgrammar kernels
|
# For compling xgrammar kernels
|
||||||
pip install cuda-python nvidia-cuda-nvrtc-cu12
|
pip install cuda-python nvidia-cuda-nvrtc-cu12
|
||||||
|
|||||||
Reference in New Issue
Block a user