model: support intern-s1 (#8350)
Signed-off-by: Xinyuan Tong <xinyuantong.cs@gmail.com> Co-authored-by: zxy <zhou0493@e.ntu.edu.sg> Co-authored-by: Xinyuan Tong <xinyuantong.cs@gmail.com> Co-authored-by: Mick <mickjagger19@icloud.com> Co-authored-by: Xinyuan Tong <115166877+JustinTong0323@users.noreply.github.com>
This commit is contained in:
@@ -448,6 +448,19 @@ register_chat_template(
|
||||
)
|
||||
)
|
||||
|
||||
# Chat template registered under the name "interns1". Each role is wrapped in
# `<|im_start|>{role}\n` ... `<|im_end|>\n`, and both `<|im_end|>` and
# `<|action_end|>` are listed as stop strings.
register_chat_template(
    ChatTemplate(
        name="interns1",
        default_system_prompt="You are an AI assistant whose name is Intern-S1 (书生大模型).\n- Intern-S1 (书生大模型) is a vision-language model that is developed by Shanghai AI Laboratory (上海人工智能实验室). It is designed to be helpful, honest, and harmless.\n- Intern-S1 (书生大模型) can understand and communicate fluently in the language chosen by the user such as English and 中文.\nYou are an expert reasoner with extensive experience in all areas. You approach problems through systematic thinking and rigorous reasoning. Your response should reflect deep understanding and precise logical thinking, making your solution path and reasoning clear to others. Please put your thinking process within <think>...</think> tags.",
        role_prefix_and_suffix={
            "system": ("<|im_start|>system\n", "<|im_end|>\n"),
            "user": ("<|im_start|>user\n", "<|im_end|>\n"),
            "assistant": ("<|im_start|>assistant\n", "<|im_end|>\n"),
        },
        stop_str=["<|im_end|>", "<|action_end|>"],
    )
)
|
||||
|
||||
register_chat_template(
|
||||
ChatTemplate(
|
||||
name="granite-3-instruct",
|
||||
@@ -609,6 +622,14 @@ def match_internvl_chat(model_path: str):
|
||||
return "internvl-2-5"
|
||||
|
||||
|
||||
@register_chat_template_matching_function
def match_interns1_chat(model_path: str):
    """Return the ``"interns1"`` chat template name for Intern-S1 model paths.

    The path is searched case-insensitively for either spelling,
    ``intern-s1`` or ``interns1``. When neither occurs the function falls
    through and implicitly returns ``None``.
    """
    spellings = (r"intern-s1", r"interns1")
    if any(re.search(pattern, model_path, re.IGNORECASE) for pattern in spellings):
        return "interns1"
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
messages = [
|
||||
{"role": "system", "content": None}, # None means default
|
||||
|
||||
@@ -10,6 +10,7 @@ from transformers import (
|
||||
PretrainedConfig,
|
||||
PreTrainedTokenizer,
|
||||
Qwen2Config,
|
||||
Qwen3Config,
|
||||
)
|
||||
|
||||
from sglang.utils import logger
|
||||
@@ -314,6 +315,8 @@ class InternVLChatConfig(PretrainedConfig):
|
||||
self.llm_config = InternLM2Config(**llm_config)
|
||||
elif llm_config.get("architectures")[0] == "Qwen2ForCausalLM":
|
||||
self.llm_config = Qwen2Config(**llm_config)
|
||||
elif llm_config.get("architectures")[0] == "Qwen3MoeForCausalLM":
|
||||
self.llm_config = Qwen3Config(**llm_config)
|
||||
else:
|
||||
raise ValueError(
|
||||
"Unsupported architecture: {}".format(
|
||||
|
||||
@@ -635,6 +635,7 @@ multimodal_model_archs = [
|
||||
"Qwen2_5_VLForConditionalGeneration",
|
||||
"KimiVLForConditionalGeneration",
|
||||
"InternVLChatModel",
|
||||
"InternS1ForConditionalGeneration",
|
||||
"Phi4MMForCausalLM",
|
||||
"VILAForConditionalGeneration",
|
||||
]
|
||||
|
||||
@@ -623,7 +623,7 @@ def generate_chat_conv(
|
||||
real_content += content.text
|
||||
elif content.type == "image_url":
|
||||
# NOTE: works for llava and intervl2_5
|
||||
if conv.name == "internvl-2-5":
|
||||
if conv.name in ["internvl-2-5", "interns1"]:
|
||||
real_content = image_token + real_content
|
||||
else:
|
||||
real_content += image_token
|
||||
@@ -817,6 +817,19 @@ register_conv_template(
|
||||
)
|
||||
)
|
||||
|
||||
# Conversation template registered under the name "interns1": the system
# message is rendered after `<|im_start|>system\n`, user/assistant turns use the
# `<|im_start|>{role}\n` prefixes, turns are separated by `<|im_end|>\n`, and
# images are referenced with the `<image>` token.
register_conv_template(
    Conversation(
        name="interns1",
        system_template="<|im_start|>system\n{system_message}",
        system_message="You are an AI assistant whose name is Intern-S1 (书生大模型).\n- Intern-S1 (书生大模型) is a vision-language model that is developed by Shanghai AI Laboratory (上海人工智能实验室). It is designed to be helpful, honest, and harmless.\n- Intern-S1 (书生大模型) can understand and communicate fluently in the language chosen by the user such as English and 中文.\nYou are an expert reasoner with extensive experience in all areas. You approach problems through systematic thinking and rigorous reasoning. Your response should reflect deep understanding and precise logical thinking, making your solution path and reasoning clear to others. Please put your thinking process within <think>...</think> tags.",
        roles=("<|im_start|>user\n", "<|im_start|>assistant\n"),
        sep_style=SeparatorStyle.MPT,
        sep="<|im_end|>\n",
        stop_str=["<|im_end|>", "<|action_end|>"],
        image_token="<image>",
    )
)
|
||||
|
||||
# Reference: https://huggingface.co/docs/transformers/main/model_doc/qwen2_vl#usage-example
|
||||
register_conv_template(
|
||||
Conversation(
|
||||
@@ -986,6 +999,8 @@ register_conv_template(
|
||||
def match_internvl(model_path: str):
    """Pick the conversation template name for InternVL / Intern-S1 model paths.

    Args:
        model_path: Model path or hub identifier; matched case-insensitively.

    Returns:
        ``"internvl-2-5"`` when the path contains ``internvl``, ``"interns1"``
        when it contains ``intern-s1`` or ``interns1``, otherwise ``None``.
    """
    if re.search(r"internvl", model_path, re.IGNORECASE):
        return "internvl-2-5"
    # Accept the hyphenated spelling too (e.g. "internlm/Intern-S1"), matching
    # what the chat-template matcher for this model already accepts.
    if re.search(r"intern-?s1", model_path, re.IGNORECASE):
        return "interns1"
    return None
|
||||
|
||||
|
||||
@register_conv_template_matching_function
|
||||
|
||||
@@ -3,7 +3,7 @@ from __future__ import annotations
|
||||
import dataclasses
|
||||
import functools
|
||||
import math
|
||||
from functools import lru_cache
|
||||
from functools import lru_cache, partial
|
||||
from typing import Any, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
@@ -18,11 +18,16 @@ _is_cuda = is_cuda()
|
||||
if _is_cuda:
|
||||
from sgl_kernel.flash_attn import flash_attn_varlen_func
|
||||
|
||||
from sglang.srt.distributed import parallel_state
|
||||
from sglang.srt.distributed import (
|
||||
parallel_state,
|
||||
split_tensor_along_last_dim,
|
||||
tensor_model_parallel_all_gather,
|
||||
)
|
||||
from sglang.srt.distributed import utils as dist_utils
|
||||
from sglang.srt.layers.attention.triton_ops.prefill_attention import (
|
||||
context_attention_fwd,
|
||||
)
|
||||
from sglang.srt.layers.layernorm import RMSNorm
|
||||
from sglang.srt.layers.linear import (
|
||||
ColumnParallelLinear,
|
||||
QKVParallelLinear,
|
||||
@@ -349,25 +354,44 @@ class VisionAttention(nn.Module):
|
||||
flatten_batch: bool = False,
|
||||
prefix: str = "",
|
||||
proj_bias: bool = True,
|
||||
num_dummy_heads: int = 0,
|
||||
qkv_bias: bool = True,
|
||||
qk_normalization: bool = False,
|
||||
layer_norm_eps: float = 1e-06,
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__()
|
||||
world_size = parallel_state.get_tensor_model_parallel_world_size()
|
||||
self.tp_size = world_size
|
||||
self.tp_rank = parallel_state.get_tensor_model_parallel_rank()
|
||||
self.dropout = dropout
|
||||
self.head_size = embed_dim // num_heads
|
||||
self.hidden_size_per_attention_head = dist_utils.divide(
|
||||
projection_size, num_heads
|
||||
)
|
||||
self.num_attention_heads_per_partition = dist_utils.divide(
|
||||
num_heads, world_size
|
||||
num_dummy_heads + num_heads, world_size
|
||||
)
|
||||
self.num_attention_kv_heads_per_partition = dist_utils.divide(
|
||||
num_heads, world_size
|
||||
num_dummy_heads + num_heads, world_size
|
||||
)
|
||||
|
||||
self.q_size = self.num_attention_heads_per_partition * self.head_size
|
||||
self.kv_size = self.num_attention_kv_heads_per_partition * self.head_size
|
||||
|
||||
self.qk_normalization = qk_normalization
|
||||
|
||||
# Additional dummy heads are used to enable TP for common GPU counts.
|
||||
self.dummy_dim = (num_dummy_heads + num_heads) * self.head_size
|
||||
|
||||
if self.qk_normalization:
|
||||
self.q_norm = RMSNorm(
|
||||
self.dummy_dim, eps=layer_norm_eps, var_hidden_size=embed_dim
|
||||
)
|
||||
self.k_norm = RMSNorm(
|
||||
self.dummy_dim, eps=layer_norm_eps, var_hidden_size=embed_dim
|
||||
)
|
||||
|
||||
if global_server_args_dict["mm_attention_backend"] is None:
|
||||
if qkv_backend is None:
|
||||
qkv_backend = "sdpa"
|
||||
@@ -391,26 +415,46 @@ class VisionAttention(nn.Module):
|
||||
self.qkv_proj = QKVParallelLinear(
|
||||
hidden_size=embed_dim,
|
||||
head_size=self.head_size,
|
||||
total_num_heads=num_heads,
|
||||
total_num_kv_heads=num_heads,
|
||||
total_num_heads=num_dummy_heads + num_heads,
|
||||
total_num_kv_heads=num_dummy_heads + num_heads,
|
||||
bias=qkv_bias,
|
||||
quant_config=quant_config,
|
||||
prefix=add_prefix("qkv_proj", prefix),
|
||||
)
|
||||
else:
|
||||
self.qkv_proj = ColumnParallelLinear(
|
||||
input_size=embed_dim,
|
||||
output_size=3 * projection_size,
|
||||
output_size=3 * self.dummy_dim,
|
||||
bias=qkv_bias,
|
||||
quant_config=quant_config,
|
||||
prefix=add_prefix("qkv_proj", prefix),
|
||||
)
|
||||
self.proj = RowParallelLinear(
|
||||
input_size=embed_dim,
|
||||
input_size=self.dummy_dim,
|
||||
output_size=embed_dim,
|
||||
bias=proj_bias,
|
||||
quant_config=quant_config,
|
||||
prefix=add_prefix("proj", prefix),
|
||||
)
|
||||
|
||||
def _apply_qk_norm(self, q: torch.Tensor, k: torch.Tensor):
|
||||
"""apply qk norm for internvl vit attn"""
|
||||
q = q.flatten(1, 2)
|
||||
k = k.flatten(1, 2)
|
||||
|
||||
if self.tp_size > 1:
|
||||
q = tensor_model_parallel_all_gather(q.contiguous())
|
||||
k = tensor_model_parallel_all_gather(k.contiguous())
|
||||
q = self.q_norm(q)
|
||||
k = self.k_norm(k)
|
||||
if self.tp_size > 1:
|
||||
splitter = partial(split_tensor_along_last_dim, num_partitions=self.tp_size)
|
||||
q = splitter(q)[self.tp_rank]
|
||||
k = splitter(k)[self.tp_rank]
|
||||
q = q.unflatten(-1, (-1, self.head_size))
|
||||
k = k.unflatten(-1, (-1, self.head_size))
|
||||
return q, k
|
||||
|
||||
def forward(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
@@ -489,6 +533,10 @@ class VisionAttention(nn.Module):
|
||||
assert k.dim() == 3, k.dim()
|
||||
assert v.dim() == 3, v.dim()
|
||||
|
||||
# internvl
|
||||
if self.qk_normalization:
|
||||
q, k = self._apply_qk_norm(q, k)
|
||||
|
||||
output = self.qkv_backend.forward(
|
||||
q=q,
|
||||
k=k,
|
||||
|
||||
@@ -61,10 +61,15 @@ class RMSNorm(CustomOp):
|
||||
self,
|
||||
hidden_size: int,
|
||||
eps: float = 1e-6,
|
||||
var_hidden_size: Optional[int] = None,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.weight = nn.Parameter(torch.ones(hidden_size))
|
||||
self.variance_epsilon = eps
|
||||
self.hidden_size = hidden_size
|
||||
self.variance_size_override = (
|
||||
None if var_hidden_size == hidden_size else var_hidden_size
|
||||
)
|
||||
if _use_aiter:
|
||||
self._forward_method = self.forward_aiter
|
||||
|
||||
@@ -73,6 +78,8 @@ class RMSNorm(CustomOp):
|
||||
x: torch.Tensor,
|
||||
residual: Optional[torch.Tensor] = None,
|
||||
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
|
||||
if self.variance_size_override is not None:
|
||||
return self.forward_native(x, residual)
|
||||
if residual is not None:
|
||||
fused_add_rmsnorm(x, residual, self.weight.data, self.variance_epsilon)
|
||||
return x, residual
|
||||
@@ -138,7 +145,25 @@ class RMSNorm(CustomOp):
|
||||
x = x + residual.to(torch.float32)
|
||||
residual = x.to(orig_dtype)
|
||||
|
||||
variance = x.pow(2).mean(dim=-1, keepdim=True)
|
||||
hidden_size = x.shape[-1]
|
||||
if hidden_size != self.hidden_size:
|
||||
raise ValueError(
|
||||
"Expected hidden_size to be "
|
||||
f"{self.hidden_size}, but found: {hidden_size}"
|
||||
)
|
||||
|
||||
if self.variance_size_override is None:
|
||||
x_var = x
|
||||
else:
|
||||
if hidden_size < self.variance_size_override:
|
||||
raise ValueError(
|
||||
"Expected hidden_size to be at least "
|
||||
f"{self.variance_size_override}, but found: {hidden_size}"
|
||||
)
|
||||
|
||||
x_var = x[..., : self.variance_size_override]
|
||||
|
||||
variance = x_var.pow(2).mean(dim=-1, keepdim=True)
|
||||
x = x * torch.rsqrt(variance + self.variance_epsilon)
|
||||
x = (x * self.weight).to(orig_dtype)
|
||||
if residual is None:
|
||||
|
||||
328
python/sglang/srt/models/interns1.py
Normal file
328
python/sglang/srt/models/interns1.py
Normal file
@@ -0,0 +1,328 @@
|
||||
from typing import Iterable, List, Optional, Set, Tuple
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
from transformers import PretrainedConfig
|
||||
|
||||
from sglang.srt.distributed import parallel_state
|
||||
from sglang.srt.layers.moe.ep_moe.layer import get_moe_impl_class
|
||||
from sglang.srt.layers.quantization.base_config import QuantizationConfig
|
||||
from sglang.srt.managers.mm_utils import (
|
||||
MultiModalityDataPaddingPatternTokenPairs,
|
||||
general_mm_embed_routine,
|
||||
)
|
||||
from sglang.srt.managers.schedule_batch import (
|
||||
Modality,
|
||||
MultimodalDataItem,
|
||||
MultimodalInputs,
|
||||
)
|
||||
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
|
||||
from sglang.srt.model_loader.weight_utils import default_weight_loader
|
||||
from sglang.srt.models.internvl import InternVisionModel
|
||||
from sglang.srt.models.qwen2 import Qwen2ForCausalLM
|
||||
from sglang.srt.models.qwen3_moe import Qwen3MoeForCausalLM
|
||||
from sglang.utils import logger
|
||||
|
||||
|
||||
class InternS1ForConditionalGeneration(nn.Module):
|
||||
def __init__(
    self,
    config: PretrainedConfig,
    quant_config: Optional[QuantizationConfig] = None,
    use_flash_attn=True,
) -> None:
    """Build the vision tower, the language model and the projector MLP.

    Args:
        config: Model config; ``vision_config`` and ``text_config`` are read
            from it and partly updated in place.
        quant_config: Optional quantization config forwarded to the language
            model.
        use_flash_attn: Toggles ``vision_config.use_flash_attn`` and selects
            ``"flash_attention_2"`` vs ``"eager"`` for the text config.

    Raises:
        NotImplementedError: If ``text_config.architectures[0]`` is neither
            ``Qwen2ForCausalLM`` nor ``Qwen3MoeForCausalLM``.
    """
    super().__init__()
    self.config = config
    self.quant_config = quant_config
    # Needs self.config: stores head_dim / num_dummy_heads on the vision config.
    self._update_hf_config()

    vision_cfg = config.vision_config
    text_cfg = config.text_config

    image_size = getattr(config, "force_image_size", None) or vision_cfg.image_size
    if isinstance(image_size, list):
        image_size = image_size[0]
    patch_size = vision_cfg.patch_size
    if isinstance(patch_size, list):
        patch_size = patch_size[0]

    self.patch_size = patch_size
    self.select_layer = config.vision_feature_layer
    patches_per_side = image_size // patch_size
    self.num_image_token = int(patches_per_side**2 * (config.downsample_ratio**2))
    self.downsample_ratio = config.downsample_ratio
    self.ps_version = getattr(config, "ps_version", "v1")

    vision_cfg.use_flash_attn = bool(use_flash_attn)
    text_cfg._attn_implementation = "flash_attention_2" if use_flash_attn else "eager"

    logger.info(f"num_image_token: {self.num_image_token}")
    logger.info(f"ps_version: {self.ps_version}")

    self.vision_model = InternVisionModel(vision_cfg)

    # Pick the language model class from the first declared architecture.
    arch = text_cfg.architectures[0]
    if arch == "Qwen2ForCausalLM":
        lm_cls = Qwen2ForCausalLM
    elif arch == "Qwen3MoeForCausalLM":
        lm_cls = Qwen3MoeForCausalLM
    else:
        raise NotImplementedError(f"{arch} is not implemented.")
    self.language_model = lm_cls(config=text_cfg, quant_config=quant_config)

    vit_hidden_size = vision_cfg.hidden_size
    llm_hidden_size = text_cfg.hidden_size

    # Input width of the projector: ViT width times (1 / downsample_ratio)^2.
    projector_in = vit_hidden_size * int(1 / self.downsample_ratio) ** 2
    self.mlp1 = nn.Sequential(
        nn.LayerNorm(projector_in),
        nn.Linear(projector_in, llm_hidden_size),
        nn.GELU(),
        nn.Linear(llm_hidden_size, llm_hidden_size),
    )
|
||||
|
||||
def _update_hf_config(self):
    """Record ``head_dim`` and ``num_dummy_heads`` on the vision config for TP.

    ``num_dummy_heads`` is the number of extra heads needed to round the
    vision attention head count up to a multiple of the tensor-parallel world
    size; it is 0 when the head count already divides evenly.
    """
    tp_world = parallel_state.get_tensor_model_parallel_world_size()
    vision_cfg = self.config.vision_config
    heads = vision_cfg.num_attention_heads

    vision_cfg.head_dim = vision_cfg.hidden_size // heads
    # (-heads) mod tp_world == distance from `heads` up to the next multiple.
    vision_cfg.num_dummy_heads = -heads % tp_world
|
||||
|
||||
def pixel_shuffle(self, x, scale_factor=0.5):
    """Rearrange a 4-D feature map, trading spatial size for channel depth.

    Args:
        x: Tensor unpacked as ``(n, w, h, c)``.
        scale_factor: Factor applied to both spatial dims; the channel dim is
            divided by ``scale_factor ** 2``.

    Returns:
        Tensor with spatial dims scaled by ``scale_factor`` and channels
        ``c / scale_factor**2``. For ``ps_version == "v1"`` the final
        spatial transpose is skipped and a warning is logged instead.
    """
    n, w, h, c = x.size()
    # N, W, H, C --> N, W, H * scale, C // scale
    x = x.view(n, w, int(h * scale_factor), int(c / scale_factor))
    # N, W, H * scale, C // scale --> N, H * scale, W, C // scale
    x = x.permute(0, 2, 1, 3).contiguous()
    # N, H * scale, W, C // scale --> N, H * scale, W * scale, C // (scale ** 2)
    x = x.view(
        n,
        int(h * scale_factor),
        int(w * scale_factor),
        int(c / (scale_factor * scale_factor)),
    )
    if self.ps_version == "v1":
        # `Logger.warn` is a deprecated alias; use `warning`.
        logger.warning(
            "In ps_version 'v1', the height and width have not been swapped back, "
            "which results in a transposed image."
        )
    else:
        x = x.permute(0, 2, 1, 3).contiguous()
    return x
|
||||
|
||||
def extract_feature(self, pixel_values):
    """Encode pixels with the vision model and project them through ``mlp1``.

    Uses ``last_hidden_state`` when ``select_layer == -1``, otherwise the
    hidden state at index ``select_layer``. The first token along dim 1 is
    dropped (presumably the class token -- confirm against the vision model),
    the remaining tokens are reshaped to a square grid, pixel-shuffled with
    ``downsample_ratio``, flattened again and passed through ``mlp1``.
    """
    use_last = self.select_layer == -1
    vit_out = self.vision_model(
        pixel_values=pixel_values,
        output_hidden_states=not use_last,
        return_dict=True,
    )
    if use_last:
        tokens = vit_out.last_hidden_state
    else:
        tokens = vit_out.hidden_states[self.select_layer]
    tokens = tokens[:, 1:, :]

    side = int(tokens.shape[1] ** 0.5)
    grid = tokens.reshape(tokens.shape[0], side, side, -1)
    grid = self.pixel_shuffle(grid, scale_factor=self.downsample_ratio)
    merged = grid.reshape(grid.shape[0], -1, grid.shape[-1])
    return self.mlp1(merged)
|
||||
|
||||
def get_image_feature(self, items: List[MultimodalDataItem]):
    """Concatenate the items' ``feature`` tensors and run ``extract_feature``.

    Returns:
        The tensor produced by ``extract_feature`` for the concatenated
        pixel values.
    """
    per_item = [item.feature for item in items]
    return self.extract_feature(torch.cat(per_item))
|
||||
|
||||
@torch.no_grad()
def forward(
    self,
    input_ids: torch.Tensor,
    positions: torch.Tensor,
    forward_batch: ForwardBatch,
    input_embeds: torch.Tensor = None,
) -> torch.Tensor:
    """Delegate to ``general_mm_embed_routine`` with the image embedder.

    Image items are embedded with ``get_image_feature``; the routine is given
    ``self.language_model`` together with the ids, positions and batch.
    ``input_embeds`` is accepted but not used here.
    """
    embedders = {Modality.IMAGE: self.get_image_feature}
    return general_mm_embed_routine(
        input_ids=input_ids,
        forward_batch=forward_batch,
        language_model=self.language_model,
        data_embedding_funcs=embedders,
        positions=positions,
    )
|
||||
|
||||
def pad_input_ids(self, input_ids: List[int], mm_inputs: MultimodalInputs):
    """Pad ``input_ids`` using the image start/end token pair of ``mm_inputs``.

    Builds a ``MultiModalityDataPaddingPatternTokenPairs`` helper from
    ``(im_start_id, im_end_id)`` and returns its ``pad_input_tokens`` result.
    """
    image_span = (mm_inputs.im_start_id, mm_inputs.im_end_id)
    padder = MultiModalityDataPaddingPatternTokenPairs([image_span])
    return padder.pad_input_tokens(input_ids, mm_inputs)
|
||||
|
||||
def _pad_vit_attn_dummy_heads(self, name: str, loaded_weight: torch.Tensor):
|
||||
"""pad attn qkv weights for dummy heads"""
|
||||
num_dummy_heads = self.config.vision_config.num_dummy_heads
|
||||
if num_dummy_heads == 0:
|
||||
return loaded_weight
|
||||
head_dim = self.config.vision_config.head_dim
|
||||
|
||||
if any([_ in name for _ in ["attn.q_proj", "attn.k_proj", "attn.v_proj"]]):
|
||||
if name.endswith(".weight"):
|
||||
dummy_shape = [num_dummy_heads, head_dim, loaded_weight.shape[-1]]
|
||||
elif name.endswith(".bias"):
|
||||
dummy_shape = [num_dummy_heads, head_dim]
|
||||
else:
|
||||
raise RuntimeError(f"Unsupported weight with name={name}")
|
||||
padded_weight = loaded_weight.new_zeros(dummy_shape)
|
||||
loaded_weight = torch.cat(
|
||||
[loaded_weight.unflatten(0, (-1, head_dim)), padded_weight], dim=0
|
||||
).flatten(0, 1)
|
||||
if "attn.proj.weight" in name:
|
||||
padded_weight = loaded_weight.new_zeros(
|
||||
loaded_weight.shape[0], head_dim * num_dummy_heads
|
||||
)
|
||||
loaded_weight = torch.cat([loaded_weight, padded_weight], dim=-1)
|
||||
if "attn.q_norm.weight" in name or "attn.k_norm.weight" in name:
|
||||
padded_weight = loaded_weight.new_zeros(head_dim * num_dummy_heads)
|
||||
loaded_weight = torch.cat([loaded_weight, padded_weight], dim=0)
|
||||
return loaded_weight
|
||||
|
||||
def _mapping_interns1_name(self, name):
|
||||
names_map = {
|
||||
"lm_head.weight": "language_model.lm_head.weight",
|
||||
"model.multi_modal_projector.layer_norm.bias": "mlp1.0.bias",
|
||||
"model.multi_modal_projector.layer_norm.weight": "mlp1.0.weight",
|
||||
"model.multi_modal_projector.linear_1.bias": "mlp1.1.bias",
|
||||
"model.multi_modal_projector.linear_1.weight": "mlp1.1.weight",
|
||||
"model.multi_modal_projector.linear_2.bias": "mlp1.3.bias",
|
||||
"model.multi_modal_projector.linear_2.weight": "mlp1.3.weight",
|
||||
"model.vision_tower.embeddings.cls_token": "vision_model.embeddings.class_embedding",
|
||||
"model.vision_tower.embeddings.patch_embeddings.projection.bias": "vision_model.embeddings.patch_embedding.bias",
|
||||
"model.vision_tower.embeddings.patch_embeddings.projection.weight": "vision_model.embeddings.patch_embedding.weight",
|
||||
"model.vision_tower.embeddings.position_embeddings": "vision_model.embeddings.position_embedding",
|
||||
}
|
||||
if name in names_map:
|
||||
name = names_map[name]
|
||||
elif name.startswith("model.language_model."):
|
||||
name = "language_model.model." + name[len("model.language_model.") :]
|
||||
elif name.startswith("model.vision_tower."):
|
||||
name = "vision_model." + name[len("model.vision_tower.") :]
|
||||
|
||||
if name.startswith("vision_model.encoder.layer"):
|
||||
|
||||
name = name.replace(r".layer.", r".layers.")
|
||||
name = name.replace(r".attention.", r".attn.attn.")
|
||||
name = name.replace(r".projection_layer.", r".proj.")
|
||||
name = name.replace(r".lambda_1", r".ls1")
|
||||
name = name.replace(r".lambda_2", r".ls2")
|
||||
name = name.replace(r".layernorm_before.", r".norm1.")
|
||||
name = name.replace(r".layernorm_after.", r".norm2.")
|
||||
return name
|
||||
|
||||
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
    """Load checkpoint tensors into this module's parameters.

    Each checkpoint name is first remapped with ``_mapping_interns1_name`` and
    vision tensors are padded with ``_pad_vit_attn_dummy_heads``. A tensor is
    then dispatched to exactly one of three paths: a stacked (fused)
    parameter, an expert parameter, or a plain parameter.

    Args:
        weights: Iterable of ``(checkpoint_name, tensor)`` pairs.

    Returns:
        The set of parameter names that were loaded.

    Raises:
        RuntimeError: If any parameter of this module was not loaded.
    """
    stacked_params_mapping = [
        # (param_name, shard_name, shard_id)
        ("qkv_proj", "q_proj", "q"),
        ("qkv_proj", "k_proj", "k"),
        ("qkv_proj", "v_proj", "v"),
        ("gate_up_proj", "gate_proj", 0),
        ("gate_up_proj", "up_proj", 1),
    ]
    # Expert mapping is only built for the MoE language model.
    expert_params_mapping = []
    if "Qwen3MoeForCausalLM" in self.config.text_config.architectures:
        # NOTE(review): num_experts is read from the top-level config, not
        # text_config -- confirm the attribute is exposed there.
        expert_params_mapping = get_moe_impl_class().make_expert_params_mapping(
            ckpt_gate_proj_name="gate_proj",
            ckpt_down_proj_name="down_proj",
            ckpt_up_proj_name="up_proj",
            num_experts=self.config.num_experts,
        )

    params_dict = dict(self.named_parameters())
    loaded_params: Set[str] = set()

    for name, loaded_weight in weights:
        if "rotary_emb.inv_freq" in name:
            continue
        name = self._mapping_interns1_name(name)
        if "vision_model" in name:
            loaded_weight = self._pad_vit_attn_dummy_heads(name, loaded_weight)

        # Path 1: shards of fused parameters (qkv_proj / gate_up_proj).
        for param_name, weight_name, shard_id in stacked_params_mapping:
            if weight_name not in name:
                continue
            # We have mlp.experts[0].gate_proj in the checkpoint.
            # Since we handle the experts below in expert_params_mapping,
            # we need to skip here BEFORE we update the name, otherwise
            # name will be updated to mlp.experts[0].gate_up_proj, which
            # will then be updated below in expert_params_mapping
            # for mlp.experts[0].gate_gate_up_proj, which breaks load.
            if "mlp.experts" in name:
                continue
            name = name.replace(weight_name, param_name)
            # Skip loading extra bias for GPTQ models.
            if name.endswith(".bias") and name not in params_dict:
                continue
            param = params_dict[name]
            weight_loader = param.weight_loader
            weight_loader(param, loaded_weight, shard_id)
            break
        else:
            # Path 2: expert weights; reached only when no stacked entry
            # loaded the tensor (the loop above did not `break`).
            for mapping in expert_params_mapping:
                param_name, weight_name, expert_id, shard_id = mapping
                if weight_name not in name:
                    continue
                name = name.replace(weight_name, param_name)
                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(
                    param,
                    loaded_weight,
                    name,
                    shard_id=shard_id,
                    expert_id=expert_id,
                )
                break
            else:
                # Path 3: plain parameter; falls back to default_weight_loader
                # when the parameter defines no loader of its own.
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                param = params_dict[name]
                weight_loader = getattr(
                    param, "weight_loader", default_weight_loader
                )
                weight_loader(param, loaded_weight)

        loaded_params.add(name)
    # Every parameter of the module must have been covered by the checkpoint.
    unloaded_params = params_dict.keys() - loaded_params
    if unloaded_params:
        raise RuntimeError(
            f"Some weights are not initialized from checkpoints: {unloaded_params}"
        )
    return loaded_params
|
||||
|
||||
|
||||
EntryClass = [InternS1ForConditionalGeneration]
|
||||
@@ -1,16 +1,3 @@
|
||||
# Copyright 2023-2024 SGLang Team
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==========================582====================================================
|
||||
from typing import Iterable, List, Optional, Set, Tuple, Union
|
||||
|
||||
import torch
|
||||
@@ -23,7 +10,9 @@ from transformers import PretrainedConfig, PreTrainedModel
|
||||
from transformers.activations import ACT2FN
|
||||
from transformers.modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling
|
||||
|
||||
from sglang.srt.distributed import parallel_state
|
||||
from sglang.srt.layers.attention.vision import SingletonCache, VisionAttention
|
||||
from sglang.srt.layers.moe.ep_moe.layer import get_moe_impl_class
|
||||
from sglang.srt.layers.quantization.base_config import QuantizationConfig
|
||||
from sglang.srt.managers.mm_utils import (
|
||||
MultiModalityDataPaddingPatternTokenPairs,
|
||||
@@ -39,6 +28,7 @@ from sglang.srt.model_loader.weight_utils import default_weight_loader
|
||||
from sglang.srt.models.deepseek_janus_pro import DropPath
|
||||
from sglang.srt.models.internlm2 import InternLM2ForCausalLM
|
||||
from sglang.srt.models.qwen2 import Qwen2ForCausalLM
|
||||
from sglang.srt.models.qwen3_moe import Qwen3MoeForCausalLM
|
||||
from sglang.utils import logger
|
||||
|
||||
|
||||
@@ -53,7 +43,6 @@ class InternAttention(nn.Module):
|
||||
self.embed_dim = config.hidden_size
|
||||
self.num_heads = config.num_attention_heads
|
||||
self.head_dim = self.embed_dim // self.num_heads
|
||||
|
||||
self.scale = self.head_dim**-0.5
|
||||
|
||||
self.attn = VisionAttention(
|
||||
@@ -64,18 +53,16 @@ class InternAttention(nn.Module):
|
||||
use_qkv_parallel=True,
|
||||
quant_config=quant_config,
|
||||
dropout=getattr(config, "dropout", 0.0),
|
||||
proj_bias=getattr(config, "qkv_bias", True),
|
||||
qkv_bias=getattr(config, "qkv_bias", False)
|
||||
or getattr(config, "attention_bias", False),
|
||||
num_dummy_heads=getattr(config, "num_dummy_heads", 0),
|
||||
qk_normalization=getattr(config, "qk_normalization", False)
|
||||
or getattr(config, "use_qk_norm", False),
|
||||
flatten_batch=False,
|
||||
)
|
||||
|
||||
self.proj_drop = nn.Dropout(config.dropout)
|
||||
|
||||
self.qk_normalization = config.qk_normalization
|
||||
|
||||
if self.qk_normalization:
|
||||
self.q_norm = InternRMSNorm(self.embed_dim, eps=config.layer_norm_eps)
|
||||
self.k_norm = InternRMSNorm(self.embed_dim, eps=config.layer_norm_eps)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
@@ -91,8 +78,16 @@ class InternVisionEmbeddings(nn.Module):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.embed_dim = config.hidden_size
|
||||
self.image_size = config.image_size
|
||||
self.patch_size = config.patch_size
|
||||
self.image_size = (
|
||||
config.image_size
|
||||
if isinstance(config.image_size, int)
|
||||
else config.image_size[0]
|
||||
)
|
||||
self.patch_size = (
|
||||
config.patch_size
|
||||
if isinstance(config.patch_size, int)
|
||||
else config.patch_size[0]
|
||||
)
|
||||
|
||||
self.class_embedding = nn.Parameter(
|
||||
torch.randn(1, 1, self.embed_dim),
|
||||
@@ -199,7 +194,7 @@ class InternVisionEncoderLayer(nn.Module):
|
||||
self.embed_dim = config.hidden_size
|
||||
self.intermediate_size = config.intermediate_size
|
||||
self.norm_type = config.norm_type
|
||||
self.attn = InternAttention(config)
|
||||
self.attn = InternAttention(config=config, quant_config=quant_config)
|
||||
self.mlp = InternMLP(config)
|
||||
self.norm1 = NORM2FN[self.norm_type](self.embed_dim, eps=config.layer_norm_eps)
|
||||
self.norm2 = NORM2FN[self.norm_type](self.embed_dim, eps=config.layer_norm_eps)
|
||||
@@ -417,7 +412,7 @@ class InternVLChatModel(nn.Module):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.quant_config = quant_config
|
||||
|
||||
self._update_vision_config()
|
||||
image_size = config.force_image_size or config.vision_config.image_size
|
||||
patch_size = config.vision_config.patch_size
|
||||
self.patch_size = patch_size
|
||||
@@ -446,6 +441,10 @@ class InternVLChatModel(nn.Module):
|
||||
self.language_model = InternLM2ForCausalLM(
|
||||
config=config.llm_config, quant_config=quant_config
|
||||
)
|
||||
elif config.llm_config.architectures[0] == "Qwen3MoeForCausalLM":
|
||||
self.language_model = Qwen3MoeForCausalLM(
|
||||
config=config.llm_config, quant_config=quant_config
|
||||
)
|
||||
else:
|
||||
raise NotImplementedError(
|
||||
f"{config.llm_config.architectures[0]} is not implemented."
|
||||
@@ -463,6 +462,21 @@ class InternVLChatModel(nn.Module):
|
||||
nn.Linear(llm_hidden_size, llm_hidden_size),
|
||||
)
|
||||
|
||||
def _update_vision_config(self):
    """Record ``head_dim`` and ``num_dummy_heads`` on the vision config for TP.

    ``num_dummy_heads`` is the number of extra heads needed to round the
    vision attention head count up to a multiple of the tensor-parallel world
    size; it is 0 when the head count already divides evenly.
    """
    tp_world = parallel_state.get_tensor_model_parallel_world_size()
    vision_cfg = self.config.vision_config
    heads = vision_cfg.num_attention_heads

    vision_cfg.head_dim = vision_cfg.hidden_size // heads
    # (-heads) mod tp_world == distance from `heads` up to the next multiple.
    vision_cfg.num_dummy_heads = -heads % tp_world
|
||||
|
||||
def pixel_shuffle(self, x, scale_factor=0.5):
|
||||
n, w, h, c = x.size()
|
||||
# N, W, H, C --> N, W, H * scale, C // scale
|
||||
@@ -545,7 +559,38 @@ class InternVLChatModel(nn.Module):
|
||||
|
||||
return helper.pad_input_tokens(input_ids, mm_inputs)
|
||||
|
||||
def _pad_vit_attn_dummy_heads(self, name: str, loaded_weight: torch.Tensor):
    """Zero-pad a vision attention tensor to make room for the dummy heads.

    Reads ``num_dummy_heads`` and ``head_dim`` from
    ``self.config.vision_config``. When ``num_dummy_heads`` is 0 the tensor is
    returned untouched. Otherwise, depending on substrings of ``name``:

    * ``attn.qkv_proj``: the tensor is split into three equal chunks along
      dim 0; each chunk is viewed per head, extended with ``num_dummy_heads``
      all-zero heads, flattened back, and the three chunks are re-concatenated.
      Only names ending in ``.weight`` or ``.bias`` are accepted.
    * ``attn.proj.weight``: ``head_dim * num_dummy_heads`` zero columns are
      appended along the last dim.
    * ``attn.q_norm.weight`` / ``attn.k_norm.weight``:
      ``head_dim * num_dummy_heads`` zeros are appended along dim 0.

    Args:
        name: Parameter name used to decide which padding applies.
        loaded_weight: Tensor read from the checkpoint.

    Returns:
        The padded tensor, or ``loaded_weight`` itself if nothing applies.

    Raises:
        RuntimeError: If a ``attn.qkv_proj`` name ends in neither ``.weight``
            nor ``.bias``.
    """
    vision_config = self.config.vision_config
    extra_heads = vision_config.num_dummy_heads
    if extra_heads == 0:
        return loaded_weight
    head_dim = vision_config.head_dim
    extra_width = head_dim * extra_heads

    if "attn.qkv_proj" in name:
        q_part, k_part, v_part = loaded_weight.chunk(3, dim=0)
        if name.endswith(".weight"):
            zeros_shape = [extra_heads, head_dim, q_part.shape[-1]]
        elif name.endswith(".bias"):
            zeros_shape = [extra_heads, head_dim]
        else:
            raise RuntimeError(f"Unsupported weight with name={name}")

        def append_dummy_heads(part):
            # View as (heads, head_dim, ...), add all-zero heads, flatten back.
            per_head = part.unflatten(0, (-1, head_dim))
            zero_heads = part.new_zeros(zeros_shape)
            return torch.cat([per_head, zero_heads], dim=0).flatten(0, 1)

        loaded_weight = torch.cat(
            [append_dummy_heads(part) for part in (q_part, k_part, v_part)],
            dim=0,
        )

    if "attn.proj.weight" in name:
        zero_columns = loaded_weight.new_zeros(loaded_weight.shape[0], extra_width)
        loaded_weight = torch.cat([loaded_weight, zero_columns], dim=-1)

    if "attn.q_norm.weight" in name or "attn.k_norm.weight" in name:
        zero_tail = loaded_weight.new_zeros(extra_width)
        loaded_weight = torch.cat([loaded_weight, zero_tail], dim=0)

    return loaded_weight
|
||||
|
||||
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
|
||||
expert_params_mapping = []
|
||||
if "InternLM2ForCausalLM" in self.config.llm_config.architectures:
|
||||
stacked_params_mapping = [
|
||||
# (param_name, shard_name, shard_id)
|
||||
@@ -561,15 +606,41 @@ class InternVLChatModel(nn.Module):
|
||||
("gate_up_proj", "gate_proj", 0),
|
||||
("gate_up_proj", "up_proj", 1),
|
||||
]
|
||||
elif "Qwen3MoeForCausalLM" in self.config.llm_config.architectures:
|
||||
stacked_params_mapping = [
|
||||
# (param_name, shard_name, shard_id)
|
||||
("qkv_proj", "q_proj", "q"),
|
||||
("qkv_proj", "k_proj", "k"),
|
||||
("qkv_proj", "v_proj", "v"),
|
||||
("gate_up_proj", "gate_proj", 0),
|
||||
("gate_up_proj", "up_proj", 1),
|
||||
]
|
||||
|
||||
expert_params_mapping = get_moe_impl_class().make_expert_params_mapping(
|
||||
ckpt_gate_proj_name="gate_proj",
|
||||
ckpt_down_proj_name="down_proj",
|
||||
ckpt_up_proj_name="up_proj",
|
||||
num_experts=self.config.num_experts,
|
||||
)
|
||||
|
||||
params_dict = dict(self.named_parameters())
|
||||
loaded_params: Set[str] = set()
|
||||
|
||||
for name, loaded_weight in weights:
|
||||
if "rotary_emb.inv_freq" in name:
|
||||
continue
|
||||
|
||||
for param_name, weight_name, shard_id in stacked_params_mapping:
|
||||
if weight_name not in name:
|
||||
continue
|
||||
# We have mlp.experts[0].gate_proj in the checkpoint.
|
||||
# Since we handle the experts below in expert_params_mapping,
|
||||
# we need to skip here BEFORE we update the name, otherwise
|
||||
# name will be updated to mlp.experts[0].gate_up_proj, which
|
||||
# will then be updated below in expert_params_mapping
|
||||
# for mlp.experts[0].gate_gate_up_proj, which breaks load.
|
||||
if "mlp.experts" in name:
|
||||
continue
|
||||
name = name.replace(weight_name, param_name)
|
||||
# Skip loading extra bias for GPTQ models.
|
||||
if name.endswith(".bias") and name not in params_dict:
|
||||
@@ -584,30 +655,55 @@ class InternVLChatModel(nn.Module):
|
||||
name = name.replace(r"attn.", r"attn.attn.")
|
||||
name = name.replace(r"qkv.", r"qkv_proj.")
|
||||
|
||||
# Skip loading extra bias for GPTQ models.
|
||||
if name.endswith(".bias") and name not in params_dict:
|
||||
continue
|
||||
param = params_dict[name]
|
||||
if "wqkv" in name:
|
||||
config = self.config
|
||||
kv_groups = config.num_attention_heads // config.num_key_value_heads
|
||||
head_dim = config.hidden_size // config.num_attention_heads
|
||||
loaded_weight = loaded_weight.view(
|
||||
-1, 2 + kv_groups, head_dim, loaded_weight.shape[-1]
|
||||
)
|
||||
wq, wk, wv = torch.split(loaded_weight, [kv_groups, 1, 1], dim=1)
|
||||
wq = wq.reshape(-1, wq.shape[-1])
|
||||
wk = wk.reshape(-1, wk.shape[-1])
|
||||
wv = wv.reshape(-1, wv.shape[-1])
|
||||
for mapping in expert_params_mapping:
|
||||
param_name, weight_name, expert_id, shard_id = mapping
|
||||
if weight_name not in name:
|
||||
continue
|
||||
name = name.replace(weight_name, param_name)
|
||||
param = params_dict[name]
|
||||
weight_loader = param.weight_loader
|
||||
weight_loader(param, wq, "q")
|
||||
weight_loader(param, wk, "k")
|
||||
weight_loader(param, wv, "v")
|
||||
else:
|
||||
weight_loader = getattr(
|
||||
param, "weight_loader", default_weight_loader
|
||||
weight_loader(
|
||||
param,
|
||||
loaded_weight,
|
||||
name,
|
||||
shard_id=shard_id,
|
||||
expert_id=expert_id,
|
||||
)
|
||||
weight_loader(param, loaded_weight)
|
||||
break
|
||||
else:
|
||||
# Skip loading extra bias for GPTQ models.
|
||||
if name.endswith(".bias") and name not in params_dict:
|
||||
continue
|
||||
param = params_dict[name]
|
||||
if "wqkv" in name:
|
||||
config = self.config
|
||||
kv_groups = (
|
||||
config.num_attention_heads // config.num_key_value_heads
|
||||
)
|
||||
head_dim = config.hidden_size // config.num_attention_heads
|
||||
loaded_weight = loaded_weight.view(
|
||||
-1, 2 + kv_groups, head_dim, loaded_weight.shape[-1]
|
||||
)
|
||||
wq, wk, wv = torch.split(
|
||||
loaded_weight, [kv_groups, 1, 1], dim=1
|
||||
)
|
||||
wq = wq.reshape(-1, wq.shape[-1])
|
||||
wk = wk.reshape(-1, wk.shape[-1])
|
||||
wv = wv.reshape(-1, wv.shape[-1])
|
||||
weight_loader = param.weight_loader
|
||||
weight_loader(param, wq, "q")
|
||||
weight_loader(param, wk, "k")
|
||||
weight_loader(param, wv, "v")
|
||||
else:
|
||||
weight_loader = getattr(
|
||||
param, "weight_loader", default_weight_loader
|
||||
)
|
||||
if "vision_model" in name:
|
||||
loaded_weight = self._pad_vit_attn_dummy_heads(
|
||||
name, loaded_weight
|
||||
)
|
||||
weight_loader(param, loaded_weight)
|
||||
|
||||
loaded_params.add(name)
|
||||
unloaded_params = params_dict.keys() - loaded_params
|
||||
if unloaded_params:
|
||||
|
||||
@@ -707,6 +707,9 @@ class Qwen3MoeForCausalLM(nn.Module):
|
||||
self.logits_processor = LogitsProcessor(config)
|
||||
self.capture_aux_hidden_states = False
|
||||
|
||||
def get_input_embeddings(self) -> nn.Embedding:
    """Return ``self.model.embed_tokens``, the wrapped model's embedding module."""
    embed_tokens = self.model.embed_tokens
    return embed_tokens
|
||||
|
||||
@torch.no_grad()
|
||||
def forward(
|
||||
self,
|
||||
|
||||
@@ -6,6 +6,7 @@ from decord import VideoReader, cpu
|
||||
from PIL import Image
|
||||
|
||||
from sglang.srt.managers.schedule_batch import Modality, MultimodalDataItem
|
||||
from sglang.srt.models.interns1 import InternS1ForConditionalGeneration
|
||||
from sglang.srt.models.internvl import InternVLChatModel
|
||||
from sglang.srt.multimodal.processors.base_processor import (
|
||||
BaseMultimodalProcessor,
|
||||
@@ -14,12 +15,19 @@ from sglang.srt.multimodal.processors.base_processor import (
|
||||
|
||||
|
||||
class InternVLImageProcessor(BaseMultimodalProcessor):
|
||||
models = [InternVLChatModel]
|
||||
models = [InternVLChatModel, InternS1ForConditionalGeneration]
|
||||
|
||||
def __init__(self, hf_config, server_args, _image_processor, *args, **kwargs):
|
||||
super().__init__(hf_config, server_args, _image_processor, *args, **kwargs)
|
||||
image_size = hf_config.force_image_size or hf_config.vision_config.image_size
|
||||
image_size = (
|
||||
getattr(hf_config, "force_image_size", None)
|
||||
or hf_config.vision_config.image_size
|
||||
)
|
||||
patch_size = hf_config.vision_config.patch_size
|
||||
if isinstance(image_size, list):
|
||||
image_size = image_size[0]
|
||||
if isinstance(patch_size, list):
|
||||
patch_size = patch_size[0]
|
||||
|
||||
self.IMG_CONTEXT_TOKEN = "<IMG_CONTEXT>"
|
||||
self.IMG_START_TOKEN = "<img>"
|
||||
@@ -27,8 +35,12 @@ class InternVLImageProcessor(BaseMultimodalProcessor):
|
||||
self.num_image_token = int(
|
||||
(image_size // patch_size) ** 2 * (hf_config.downsample_ratio**2)
|
||||
)
|
||||
if hasattr(self._processor, "tokenizer"):
|
||||
tokenizer = self._processor.tokenizer
|
||||
else:
|
||||
tokenizer = self._processor
|
||||
self.tokenizer = tokenizer
|
||||
|
||||
tokenizer = self._processor
|
||||
self.img_start_token_id = tokenizer.convert_tokens_to_ids(self.IMG_START_TOKEN)
|
||||
self.img_end_token_id = tokenizer.convert_tokens_to_ids(self.IMG_END_TOKEN)
|
||||
self.mm_tokens = MultimodalSpecialTokens(
|
||||
@@ -195,7 +207,7 @@ class InternVLImageProcessor(BaseMultimodalProcessor):
|
||||
try:
|
||||
# TODO: video input
|
||||
raw_image = process_image_internvl(image)
|
||||
pixel_value = [raw_image.to(torch.bfloat16).cuda()]
|
||||
pixel_value = [raw_image.to(torch.bfloat16)]
|
||||
pixel_values += pixel_value
|
||||
num_patches = raw_image.shape[0]
|
||||
num_patches_list += [num_patches]
|
||||
@@ -214,8 +226,9 @@ class InternVLImageProcessor(BaseMultimodalProcessor):
|
||||
)
|
||||
input_text = input_text.replace("<image>", image_tokens, 1)
|
||||
|
||||
tokenizer = self._processor
|
||||
input_ids = tokenizer(input_text, return_tensors="pt")["input_ids"].flatten()
|
||||
input_ids = self.tokenizer(input_text, return_tensors="pt")[
|
||||
"input_ids"
|
||||
].flatten()
|
||||
image_offsets = self.get_mm_items_offset(
|
||||
input_ids=input_ids,
|
||||
mm_token_id=self.mm_tokens.image_token_id,
|
||||
|
||||
Reference in New Issue
Block a user