vlm: optimize tensor transport (#6003)
Co-authored-by: Xinyuan Tong <115166877+JustinTong0323@users.noreply.github.com>
@@ -3,8 +3,9 @@ Multi-modality utils
 """
 
 import hashlib
+import pickle
 from abc import abstractmethod
-from typing import Callable, Dict, List, Optional, Tuple
+from typing import Any, Callable, Dict, List, Literal, Optional, Tuple
 
 import numpy as np
 import torch
@@ -27,6 +28,130 @@ from sglang.utils import logger
 # propagation that can cause some log messages (like 'server is fired up') to not appear
 # in the console when multimodal support is enabled.
 
+# TODO(mick): nccl
+# cuda_ipc: for intranode tensor sharing
+TensorTransportMode = Literal["cuda_ipc", "auto", "default"]
+
+
+class TransportProxyTensor(torch.Tensor):
+    """
+    A convenient torch.Tensor subclass that carries extra metadata and supports
+    efficient inter-process communication.
+    """
+
+    @staticmethod
+    def __new__(
+        cls,
+        data: torch.Tensor,
+        name: Optional[str] = None,
+        fields: Optional[Dict[str, Any]] = None,
+        transport_mode: TensorTransportMode = "default",
+        *args,
+        **kwargs,
+    ):
+        if not isinstance(data, torch.Tensor):
+            raise TypeError(
+                f"Input 'data' must be a torch.Tensor, but got {type(data)}"
+            )
+
+        instance = data.as_subclass(cls)
+        instance._metadata = {
+            "name": name,
+            "fields": fields if fields is not None else {},
+            "transport_mode": transport_mode,
+        }
+        return instance
+
+    def __getstate__(self):
+        """
+        Called during pickling. Implements the serialization logic.
+        """
+        # gather all serializable metadata from _metadata
+        state = {
+            "metadata": self._metadata,
+            "tensor_data": None,
+            "ipc_extra": None,
+        }
+
+        transport_mode = self._metadata.get("transport_mode", "default")
+
+        if transport_mode == "cuda_ipc" and self.is_cuda:
+            try:
+                storage = self.untyped_storage()
+                handle = storage._share_cuda_()
+
+                state["ipc_extra"] = {
+                    "handle": handle,
+                    "shape": self.shape,
+                    "dtype": self.dtype,
+                    "stride": self.stride(),
+                    "device_index": self.device.index,
+                }
+                state["tensor_data"] = None
+            except Exception as e:
+                print_warning_once(
+                    f"Warning: Failed to get CUDA IPC handle ({e}). Falling back to default transport."
+                )
+                state["metadata"]["transport_mode"] = "default"
+                state["tensor_data"] = self.as_subclass(torch.Tensor)
+        else:
+            state["metadata"]["transport_mode"] = "default"
+            state["tensor_data"] = self.as_subclass(torch.Tensor)
+
+        return state
+
+    def __setstate__(self, state: Dict[str, Any]):
+        """
+        Called during unpickling. Implements the deserialization logic.
+        """
+        self._metadata = state["metadata"]
+
+        transport_mode = self._metadata.get("transport_mode", "default")
+
+        if transport_mode == "cuda_ipc" and state["ipc_extra"] is not None:
+            ipc_extra = state["ipc_extra"]
+            handle, shape, dtype, stride, source_device_index = (
+                ipc_extra["handle"],
+                ipc_extra["shape"],
+                ipc_extra["dtype"],
+                ipc_extra["stride"],
+                ipc_extra["device_index"],
+            )
+
+            try:
+                target_device = torch.device(f"cuda:{source_device_index}")
+                with torch.cuda.device(target_device):
+                    storage = torch.UntypedStorage._new_shared_cuda(*handle)
+                    reconstructed_tensor = torch.empty(
+                        0, dtype=dtype, device=target_device
+                    ).set_(storage, storage_offset=0, size=shape, stride=stride)
+                self.set_(reconstructed_tensor)
+            except Exception as e:
+                print(f"Error: Failed to deserialize from CUDA IPC handle ({e}).")
+                raise e
+
+        elif state["tensor_data"] is not None:
+            self.set_(state["tensor_data"])
+        else:
+            raise pickle.UnpicklingError(
+                "Invalid state for TransportProxyTensor: no tensor data found."
+            )
+
+    @property
+    def name(self) -> Optional[str]:
+        return self._metadata.get("name")
+
+    @property
+    def fields(self) -> Dict[str, Any]:
+        return self._metadata.get("fields", {})
+
+    @property
+    def transport_mode(self) -> TensorTransportMode:
+        return self._metadata.get("transport_mode", "default")
+
+
 class MultiModalityDataPaddingPattern:
     """
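Usage sketch (illustrative, not part of the diff): a producer process wraps a GPU tensor, and pickling then ships a CUDA IPC handle plus shape/dtype/stride instead of copying the data through CPU memory. Note that a CUDA IPC handle can only be opened by a process other than the one that exported it, so the two halves below belong to different processes.

    import pickle
    import torch

    # producer process: wrap the GPU feature tensor before sending
    feat = torch.randn(16, 1024, device="cuda")
    proxy = TransportProxyTensor(
        data=feat, name="pixel_values", transport_mode="cuda_ipc"
    )
    payload = pickle.dumps(proxy)  # exports an IPC handle, not the tensor bytes

    # consumer process: unpickling maps the same GPU memory into this process
    restored = pickle.loads(payload)
    assert restored.is_cuda and restored.name == "pixel_values"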
@@ -12,18 +12,6 @@ logger = logging.getLogger(__name__)
 PROCESSOR_MAPPING = {}
 
 
-class DummyMultimodalProcessor(BaseMultimodalProcessor):
-    def __init__(self):
-        pass
-
-    async def process_mm_data_async(self, *args, **kwargs):
-        return None
-
-
-def get_dummy_processor():
-    return DummyMultimodalProcessor()
-
-
 def import_processors():
     package_name = "sglang.srt.multimodal.processors"
     package = importlib.import_module(package_name)
@@ -49,11 +37,12 @@ def import_processors():
 def get_mm_processor(
-    hf_config, server_args: ServerArgs, processor
+    hf_config, server_args: ServerArgs, processor, transport_mode
 ) -> BaseMultimodalProcessor:
     for model_cls, processor_cls in PROCESSOR_MAPPING.items():
         if model_cls.__name__ in hf_config.architectures:
-            return processor_cls(hf_config, server_args, processor)
+            return processor_cls(hf_config, server_args, processor, transport_mode)
 
     raise ValueError(
         f"No processor registered for architecture: {hf_config.architectures}.\n"
         f"Registered architectures: {[model_cls.__name__ for model_cls in PROCESSOR_MAPPING.keys()]}"
     )
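Call-site sketch (hypothetical values): dispatch matches the model's architecture string against the registered model classes, and the new fourth argument threads the transport mode through to the processor.

    processor = get_mm_processor(
        hf_config,     # e.g. hf_config.architectures == ["Qwen2_5_VLForConditionalGeneration"]
        server_args,
        hf_processor,  # the HuggingFace processor instance
        "cuda_ipc",    # transport_mode, chosen once by the tokenizer manager
    )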
@@ -209,10 +209,11 @@ class MultimodalDataItem:
     hash: int = None
     pad_value: int = None
     offsets: Optional[list] = None
 
     # the raw features returned by processor, e.g. pixel_values or audio_features
     feature: Union[torch.Tensor, np.ndarray] = None
 
-    # the precomputed embeddings for the modality, e.g. image_emb for image, audio_emb for audio
+    # the precomputed embeddings, passed as final encoder embeddings
+    # One and only one of the feature and precomputed_embeddings will be empty
     precomputed_embeddings: Optional[Union[torch.Tensor, np.ndarray]] = None
 
     # Model-specific data stored in a dictionary
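Illustrative construction (assuming a `modality` field as declared elsewhere in this dataclass; the tensors are placeholders): per item, exactly one of `feature` and `precomputed_embeddings` is populated.

    # raw path: the vision encoder consumes `feature` later
    image_item = MultimodalDataItem(modality=Modality.IMAGE, feature=pixel_values)

    # precomputed path: embeddings bypass the encoder entirely
    cached_item = MultimodalDataItem(
        modality=Modality.IMAGE, precomputed_embeddings=image_emb
    )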
@@ -112,6 +112,7 @@ from sglang.srt.managers.io_struct import (
     UpdateWeightsFromTensorReqInput,
     UpdateWeightsFromTensorReqOutput,
 )
+from sglang.srt.managers.mm_utils import TensorTransportMode
 from sglang.srt.managers.multimodal_processor import get_mm_processor, import_processors
 from sglang.srt.metrics.collector import TokenizerMetricsCollector
 from sglang.srt.sampling.sampling_params import SamplingParams
@@ -166,6 +167,16 @@ class ReqState:
     output_token_ids_logprobs_idx: List = dataclasses.field(default_factory=list)
 
 
+def _determine_tensor_transport_mode(server_args: ServerArgs) -> TensorTransportMode:
+    is_cross_node = server_args.dist_init_addr
+
+    if is_cross_node:
+        # Fallback to default CPU transport for multi-node
+        return "default"
+    else:
+        return "cuda_ipc"
+
+
 class TokenizerManager:
     """TokenizerManager is a process that tokenizes the text."""
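The heuristic: `dist_init_addr` is only set for multi-node serving, where a CUDA IPC handle exported on one host is meaningless on another, so the manager falls back to plain CPU serialization. A sketch (ServerArgs construction abbreviated, model path hypothetical):

    args = ServerArgs(model_path="dummy-model")  # single node: dist_init_addr is None
    assert _determine_tensor_transport_mode(args) == "cuda_ipc"

    args.dist_init_addr = "192.0.2.10:5000"      # multi-node deployment
    assert _determine_tensor_transport_mode(args) == "default"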
@@ -216,12 +227,13 @@
             revision=server_args.revision,
             use_fast=not server_args.disable_fast_image_processor,
         )
+        transport_mode = _determine_tensor_transport_mode(self.server_args)
 
         # We want to parallelize the image pre-processing so we create an executor for it
         # We create mm_processor for any skip_tokenizer_init to make sure we still encode
         # images even with skip_tokenizer_init=False.
         self.mm_processor = get_mm_processor(
-            self.model_config.hf_config, server_args, _processor
+            self.model_config.hf_config, server_args, _processor, transport_mode
         )
 
         if server_args.skip_tokenizer_init:
@@ -12,6 +12,7 @@ import torch
 from PIL import Image
 from transformers import BaseImageProcessorFast
 
+from sglang.srt.managers.mm_utils import TransportProxyTensor
 from sglang.srt.managers.schedule_batch import Modality, MultimodalDataItem
 from sglang.srt.utils import load_audio, load_image, load_video, logger
@@ -142,11 +143,14 @@ class MultimodalSpecialTokens:
 class BaseMultimodalProcessor(ABC):
     models = []
 
-    def __init__(self, hf_config, server_args, _processor):
+    def __init__(
+        self, hf_config, server_args, _processor, transport_mode, *args, **kwargs
+    ):
         self.hf_config = hf_config
         self._processor = _processor
         self.arch = hf_config.architectures[0]
         self.server_args = server_args
+        self.transport_mode = transport_mode
 
         # FIXME: not accurate, model and image specific
         self.NUM_TOKEN_PER_FRAME = 330
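Every processor subclass below follows the same mechanical pattern: accept `*args, **kwargs` and forward them, so `transport_mode` reaches the base class without each subclass naming it. A hypothetical subclass for illustration (the model class name is invented):

    class MyModelImageProcessor(BaseMultimodalProcessor):
        models = [MyModelForConditionalGeneration]  # hypothetical model class

        def __init__(self, hf_config, server_args, _processor, *args, **kwargs):
            # forwarding keeps the subclass agnostic to new base-class parameters
            super().__init__(hf_config, server_args, _processor, *args, **kwargs)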
@@ -217,10 +221,6 @@ class BaseMultimodalProcessor(ABC):
             return_tensors="pt",
             **kwargs,
         )
-        if "pixel_values" in result and isinstance(
-            result["pixel_values"], torch.Tensor
-        ):
-            result["pixel_values"] = result["pixel_values"].to("cpu")
         return result
 
     @abstractmethod
@@ -500,7 +500,6 @@ class BaseMultimodalProcessor(ABC):
     ) -> List[MultimodalDataItem]:
         """Create mm_items directly from processor output."""
         items: dict[Modality, MultimodalDataItem] = {}
-
         for attr_name, value in data_dict.items():
             if attr_name == "input_ids":
                 continue
@@ -624,4 +623,19 @@ class BaseMultimodalProcessor(ABC):
             mm_token_id=mm_token_id,
         )
 
+        # post-process
+        for item in all_collected_items:
+            # replace the feature tensor with a proxy
+            if isinstance(item.feature, torch.Tensor) and item.feature.is_cuda:
+                item.feature = TransportProxyTensor(
+                    transport_mode=self.transport_mode, data=item.feature
+                )
+            elif (
+                isinstance(item.precomputed_embeddings, torch.Tensor)
+                and item.precomputed_embeddings.is_cuda
+            ):
+                item.precomputed_embeddings = TransportProxyTensor(
+                    transport_mode=self.transport_mode, data=item.precomputed_embeddings
+                )
+
         return all_collected_items, input_ids, ret
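The wrapper is transparent to downstream code: a TransportProxyTensor is still a torch.Tensor, so only pickling behavior changes. Illustrative check (not part of the diff):

    wrapped = TransportProxyTensor(
        data=torch.ones(2, 2, device="cuda"), transport_mode="cuda_ipc"
    )
    assert isinstance(wrapped, torch.Tensor)  # ordinary tensor ops still work
    assert wrapped.transport_mode == "cuda_ipc"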
@@ -10,8 +10,8 @@ from sglang.srt.multimodal.processors.base_processor import (
 class ClipImageProcessor(BaseMultimodalProcessor):
     models = [CLIPModel]
 
-    def __init__(self, hf_config, server_args, _processor):
-        super().__init__(hf_config, server_args, _processor)
+    def __init__(self, hf_config, server_args, _processor, *args, **kwargs):
+        super().__init__(hf_config, server_args, _processor, *args, **kwargs)
         self.mm_tokens = MultimodalSpecialTokens(image_token="<image>").build(
             _processor
         )
@@ -31,8 +31,8 @@ from sglang.srt.multimodal.processors.base_processor import (
 class DeepseekVL2ImageProcessor(BaseMultimodalProcessor):
     models = [DeepseekVL2ForCausalLM]
 
-    def __init__(self, hf_config, server_args, _processor):
-        super().__init__(hf_config, server_args, _processor)
+    def __init__(self, hf_config, server_args, _processor, *args, **kwargs):
+        super().__init__(hf_config, server_args, _processor, *args, **kwargs)
         self.mm_tokens = MultimodalSpecialTokens(
             image_token="<image>", image_token_id=self._processor.image_token_id
         ).build(_processor)
@@ -14,8 +14,8 @@ from sglang.srt.multimodal.processors.base_processor import MultimodalSpecialTok
 class Gemma3SGLangImageProcessor(SGLangBaseProcessor):
     models = [Gemma3ForConditionalGeneration]
 
-    def __init__(self, hf_config, server_args, _processor):
-        super().__init__(hf_config, server_args, _processor)
+    def __init__(self, hf_config, server_args, _processor, *args, **kwargs):
+        super().__init__(hf_config, server_args, _processor, *args, **kwargs)
         self.IM_START_TOKEN_ID = hf_config.boi_token_index
         self.IM_END_TOKEN_ID = hf_config.eoi_token_index
         self.mm_tokens = MultimodalSpecialTokens(
@@ -27,8 +27,8 @@ class Gemma3nSGLangProcessor(SGLangBaseProcessor):
 
     models = [Gemma3nForConditionalGeneration]
 
-    def __init__(self, hf_config, server_args, _processor):
-        super().__init__(hf_config, server_args, _processor)
+    def __init__(self, hf_config, server_args, _processor, *args, **kwargs):
+        super().__init__(hf_config, server_args, _processor, *args, **kwargs)
 
         self.IM_START_TOKEN_ID = hf_config.boi_token_id
         self.IM_END_TOKEN_ID = hf_config.eoi_token_id
@@ -16,8 +16,8 @@ from sglang.srt.multimodal.processors.base_processor import (
 class InternVLImageProcessor(BaseMultimodalProcessor):
     models = [InternVLChatModel]
 
-    def __init__(self, hf_config, server_args, _image_processor):
-        super().__init__(hf_config, server_args, _image_processor)
+    def __init__(self, hf_config, server_args, _image_processor, *args, **kwargs):
+        super().__init__(hf_config, server_args, _image_processor, *args, **kwargs)
         image_size = hf_config.force_image_size or hf_config.vision_config.image_size
         patch_size = hf_config.vision_config.patch_size
 
@@ -11,8 +11,8 @@ from sglang.srt.multimodal.processors.base_processor import (
 class JanusProImageProcessor(BaseMultimodalProcessor):
     models = [MultiModalityCausalLM]
 
-    def __init__(self, hf_config, server_args, _processor):
-        super().__init__(hf_config, server_args, _processor)
+    def __init__(self, hf_config, server_args, _processor, *args, **kwargs):
+        super().__init__(hf_config, server_args, _processor, *args, **kwargs)
 
         self.mm_tokens = MultimodalSpecialTokens(
             image_token=_processor.image_token,
@@ -12,8 +12,8 @@ from sglang.srt.multimodal.processors.base_processor import MultimodalSpecialTok
 class KimiVLImageProcessor(SGLangBaseProcessor):
     models = [KimiVLForConditionalGeneration]
 
-    def __init__(self, hf_config, server_args, _processor):
-        super().__init__(hf_config, server_args, _processor)
+    def __init__(self, hf_config, server_args, _processor, *args, **kwargs):
+        super().__init__(hf_config, server_args, _processor, *args, **kwargs)
         self.mm_tokens = MultimodalSpecialTokens(
             image_token="<|media_pad|>",
             # TODO: could we convert in MultimodalSpecialTokens?
@@ -30,8 +30,8 @@ class LlavaImageProcessor(BaseMultimodalProcessor):
         LlavaMistralForCausalLM,
     ]
 
-    def __init__(self, hf_config, server_args, _processor):
-        super().__init__(hf_config, server_args, _processor)
+    def __init__(self, hf_config, server_args, _processor, *args, **kwargs):
+        super().__init__(hf_config, server_args, _processor, *args, **kwargs)
 
     @staticmethod
     def _process_single_image_task(
@@ -187,7 +187,7 @@ class LlavaMultimodalProcessor(BaseMultimodalProcessor):
             f"Cannot find corresponding multimodal processor registered in sglang for model type `{model_type}`"
         )
 
-    def __init__(self, hf_config, server_args, _processor):
+    def __init__(self, hf_config, server_args, _processor, *args, **kwargs):
         assert hasattr(hf_config, "vision_config")
         assert hasattr(hf_config, "text_config")
         self.vision_config = hf_config.vision_config
@@ -196,7 +196,7 @@ class LlavaMultimodalProcessor(BaseMultimodalProcessor):
 
         if vision_type := getattr(self.vision_config, "model_type"):
             self.inner = self._get_sgl_processor_cls(vision_type)(
-                hf_config, server_args, _processor
+                hf_config, server_args, _processor, *args, **kwargs
             )
         else:
             raise ValueError(
@@ -15,8 +15,8 @@ from sglang.srt.multimodal.processors.base_processor import (
 class MiniCPMMultimodalProcessor(BaseMultimodalProcessor):
     models = [MiniCPMV, MiniCPMO]
 
-    def __init__(self, hf_config, server_args, _processor):
-        super().__init__(hf_config, server_args, _processor)
+    def __init__(self, hf_config, server_args, _processor, *args, **kwargs):
+        super().__init__(hf_config, server_args, _processor, *args, **kwargs)
         # Collect special token ids
         tokenizer = self._processor.tokenizer
         self.slice_start_id = getattr(tokenizer, "slice_start_id", None)
@@ -26,7 +26,6 @@ class MiniCPMMultimodalProcessor(BaseMultimodalProcessor):
         self.im_start_id = getattr(tokenizer, "im_start_id", None)
         self.im_end_id = getattr(tokenizer, "im_end_id", None)
         self.im_token_id = getattr(tokenizer, "unk_id", None)
-
         self.mm_tokens = MultimodalSpecialTokens(
             image_token="(<image>./</image>)",
             audio_token="(<audio>./</audio>)",
@@ -10,8 +10,8 @@ from sglang.srt.multimodal.processors.base_processor import (
 class MllamaImageProcessor(BaseMultimodalProcessor):
     models = [MllamaForConditionalGeneration]
 
-    def __init__(self, hf_config, server_args, _processor):
-        super().__init__(hf_config, server_args, _processor)
+    def __init__(self, hf_config, server_args, _processor, *args, **kwargs):
+        super().__init__(hf_config, server_args, _processor, *args, **kwargs)
         self.mm_tokens = MultimodalSpecialTokens(
             image_token=self._processor.image_token,
             image_token_id=self._processor.image_token_id,
@@ -18,8 +18,8 @@ from sglang.srt.multimodal.processors.base_processor import (
 class Mllama4ImageProcessor(BaseMultimodalProcessor):
     models = [Llama4ForConditionalGeneration]
 
-    def __init__(self, hf_config, server_args, _processor):
-        super().__init__(hf_config, server_args, _processor)
+    def __init__(self, hf_config, server_args, _processor, *args, **kwargs):
+        super().__init__(hf_config, server_args, _processor, *args, **kwargs)
         self.vision_config = hf_config.vision_config
         self.text_config = hf_config.text_config
         self.boi_token_index = hf_config.boi_token_index
@@ -47,9 +47,9 @@ class Phi4MMProcessorAdapter(ProcessorMixin):
 class Phi4MMMultimodalProcessor(BaseMultimodalProcessor):
     models = [Phi4MMForCausalLM]
 
-    def __init__(self, hf_config, server_args, _processor):
+    def __init__(self, hf_config, server_args, _processor, *args, **kwargs):
         self.processor = Phi4MMProcessorAdapter(_processor)
-        super().__init__(hf_config, server_args, self.processor)
+        super().__init__(hf_config, server_args, self.processor, *args, **kwargs)
 
         # the following CONSTANTS come from hugging-face microsoft/Phi-4-multimodal-instruct's processing_phi4mm.py file
         # ref: https://huggingface.co/microsoft/Phi-4-multimodal-instruct/blob/main/processing_phi4mm.py
@@ -42,8 +42,8 @@ class PixtralProcessor(BaseMultimodalProcessor):
 
         return ncols, nrows
 
-    def __init__(self, hf_config, server_args, _processor):
-        super().__init__(hf_config, server_args, _processor)
+    def __init__(self, hf_config, server_args, _processor, *args, **kwargs):
+        super().__init__(hf_config, server_args, _processor, *args, **kwargs)
         self.IM_TOKEN_ID = getattr(
             hf_config, "image_token_index", PixtralVisionModel.DEFAULT_IMAGE_TOKEN_ID
         )
@@ -11,8 +11,8 @@ from sglang.srt.multimodal.processors.base_processor import (
 class Qwen2AudioMultimodalProcessor(BaseMultimodalProcessor):
     models = [Qwen2AudioForConditionalGeneration]
 
-    def __init__(self, hf_config, server_args, _processor):
-        super().__init__(hf_config, server_args, _processor)
+    def __init__(self, hf_config, server_args, _processor, *args, **kwargs):
+        super().__init__(hf_config, server_args, _processor, *args, **kwargs)
         self.AUDIO_TOKEN = "<|audio_bos|><|AUDIO|><|audio_eos|>"
         self.AUDIO_TOKEN_REGEX = re.compile(
             r"<\|audio_bos\|>(?:<\|AUDIO\|>)+<\|audio_eos\|>"
@@ -201,8 +201,8 @@ async def preprocess_video(
 class Qwen2_5VLImageProcessor(SGLangBaseProcessor):
     models = [Qwen2VLForConditionalGeneration, Qwen2_5_VLForConditionalGeneration]
 
-    def __init__(self, hf_config, server_args, _processor):
-        super().__init__(hf_config, server_args, _processor)
+    def __init__(self, hf_config, server_args, _processor, *args, **kwargs):
+        super().__init__(hf_config, server_args, _processor, *args, **kwargs)
         # The regex that matches expanded image tokens.
         self.IM_START_TOKEN_ID = hf_config.vision_start_token_id
         self.IM_END_TOKEN_ID = hf_config.vision_end_token_id
@@ -34,8 +34,10 @@ class VILAMultimodalProcessor(BaseMultimodalProcessor):
         hf_config: PretrainedConfig,
         server_args: ServerArgs,
         _processor: VILAProcessor,
+        *args,
+        **kwargs,
     ) -> None:
-        super().__init__(hf_config, server_args, _processor)
+        super().__init__(hf_config, server_args, _processor, *args, **kwargs)
         self.mm_tokens = MultimodalSpecialTokens(
             image_token=self._processor.tokenizer.image_token,
             image_token_id=hf_config.image_token_id,
@@ -14,6 +14,7 @@ import traceback
 import urllib.request
 import weakref
 from concurrent.futures import ThreadPoolExecutor
+from functools import wraps
 from io import BytesIO
 from json import dumps
 from typing import Any, Callable, List, Optional, Tuple, Type, Union
@@ -28,6 +29,24 @@ from tqdm import tqdm
 logger = logging.getLogger(__name__)
 
 
+def execute_once(func):
+    has_run = None
+
+    @wraps(func)
+    def wrapper(*args, **kwargs):
+        nonlocal has_run
+        if not has_run:
+            func(*args, **kwargs)
+            has_run = True
+
+    return wrapper
+
+
+@execute_once
+def info_once(message: str):
+    logger.info(message)
+
+
 def convert_json_schema_to_str(json_schema: Union[dict, str, Type[BaseModel]]) -> str:
     """Convert a JSON schema to a string.
 
     Parameters
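Usage sketch for the new helper: repeated calls are swallowed after the first. Note that `execute_once` tracks `has_run` per decorated function, not per message, so a second, different message passed to `info_once` is also suppressed.

    info_once("mm feature tensors will be shared via CUDA IPC")  # logged
    info_once("mm feature tensors will be shared via CUDA IPC")  # suppressed
    info_once("a different message")                             # also suppressed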