fix: fix video input for qwen3-vl (#11442)
This commit is contained in:
@@ -1142,6 +1142,13 @@ class MRotaryEmbedding(RotaryEmbedding):
|
|||||||
second_per_grid_ts: Optional[torch.Tensor] = None,
|
second_per_grid_ts: Optional[torch.Tensor] = None,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||||
|
if (
|
||||||
|
model_type.startswith("qwen3_vl") or model_type.startswith("qwen3_vl_moe")
|
||||||
|
) and video_grid_thw is not None:
|
||||||
|
video_grid_thw = torch.repeat_interleave(
|
||||||
|
video_grid_thw, video_grid_thw[:, 0], dim=0
|
||||||
|
)
|
||||||
|
video_grid_thw[:, 0] = 1
|
||||||
mrope_position_deltas = []
|
mrope_position_deltas = []
|
||||||
if input_ids is not None and (
|
if input_ids is not None and (
|
||||||
image_grid_thw is not None or video_grid_thw is not None
|
image_grid_thw is not None or video_grid_thw is not None
|
||||||
|
|||||||
@@ -25,7 +25,6 @@ import signal
|
|||||||
import sys
|
import sys
|
||||||
import threading
|
import threading
|
||||||
import time
|
import time
|
||||||
import uuid
|
|
||||||
from collections import deque
|
from collections import deque
|
||||||
from contextlib import nullcontext
|
from contextlib import nullcontext
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
@@ -360,7 +359,8 @@ class TokenizerManager(TokenizerCommunicatorMixin):
|
|||||||
(
|
(
|
||||||
FreezeGCReq,
|
FreezeGCReq,
|
||||||
lambda x: None,
|
lambda x: None,
|
||||||
), # For handling case when scheduler skips detokenizer and forwards back to the tokenizer manager, we ignore it.
|
),
|
||||||
|
# For handling case when scheduler skips detokenizer and forwards back to the tokenizer manager, we ignore it.
|
||||||
(HealthCheckOutput, lambda x: None),
|
(HealthCheckOutput, lambda x: None),
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
@@ -587,9 +587,9 @@ class TokenizerManager(TokenizerCommunicatorMixin):
|
|||||||
)
|
)
|
||||||
|
|
||||||
if self.mm_processor and obj.contains_mm_input():
|
if self.mm_processor and obj.contains_mm_input():
|
||||||
if not isinstance(obj.image_data, list):
|
if not isinstance(obj.image_data, list) and obj.image_data:
|
||||||
obj.image_data = [obj.image_data]
|
obj.image_data = [obj.image_data]
|
||||||
if not isinstance(obj.audio_data, list):
|
if not isinstance(obj.audio_data, list) and obj.audio_data:
|
||||||
obj.audio_data = [obj.audio_data]
|
obj.audio_data = [obj.audio_data]
|
||||||
mm_inputs: Dict = await self.mm_processor.process_mm_data_async(
|
mm_inputs: Dict = await self.mm_processor.process_mm_data_async(
|
||||||
image_data=obj.image_data,
|
image_data=obj.image_data,
|
||||||
|
|||||||
@@ -196,7 +196,6 @@ MAMBA_CACHE_SIZE_MAX_RUNNING_REQUESTS_RATIO = 3
|
|||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
if _is_npu:
|
if _is_npu:
|
||||||
import torch_npu
|
import torch_npu
|
||||||
|
|
||||||
@@ -636,6 +635,22 @@ class ModelRunner:
|
|||||||
"Setting hicache_io_backend to vanilla I/O, which may lead to suboptimal performance with small page sizes."
|
"Setting hicache_io_backend to vanilla I/O, which may lead to suboptimal performance with small page sizes."
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if self.model_config.hf_config.model_type == "qwen3_vl_moe":
|
||||||
|
if (
|
||||||
|
quantization_config := getattr(
|
||||||
|
self.model_config.hf_config, "quantization_config", None
|
||||||
|
)
|
||||||
|
) is not None:
|
||||||
|
text_config = self.model_config.hf_text_config
|
||||||
|
weight_block_size_n = quantization_config["weight_block_size"][0]
|
||||||
|
if (
|
||||||
|
text_config.moe_intermediate_size
|
||||||
|
// (self.tp_size // self.moe_ep_size)
|
||||||
|
) % weight_block_size_n != 0:
|
||||||
|
raise ValueError(
|
||||||
|
f"For qwen3-vl-fp8 models, please make sure ({text_config.moe_intermediate_size=} // ({self.tp_size=} // {self.moe_ep_size=})) % {weight_block_size_n=} == 0"
|
||||||
|
)
|
||||||
|
|
||||||
def init_torch_distributed(self):
|
def init_torch_distributed(self):
|
||||||
logger.info("Init torch distributed begin.")
|
logger.info("Init torch distributed begin.")
|
||||||
|
|
||||||
|
|||||||
@@ -50,6 +50,27 @@ class TestQwen2VLServer(ImageOpenAITestMixin, VideoOpenAITestMixin):
|
|||||||
cls.base_url += "/v1"
|
cls.base_url += "/v1"
|
||||||
|
|
||||||
|
|
||||||
|
class TestQwen3VLServer(ImageOpenAITestMixin, VideoOpenAITestMixin):
|
||||||
|
@classmethod
|
||||||
|
def setUpClass(cls):
|
||||||
|
cls.model = "Qwen/Qwen3-VL-30B-A3B-Instruct"
|
||||||
|
cls.base_url = DEFAULT_URL_FOR_TEST
|
||||||
|
cls.api_key = "sk-123456"
|
||||||
|
cls.process = popen_launch_server(
|
||||||
|
cls.model,
|
||||||
|
cls.base_url,
|
||||||
|
timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
|
||||||
|
api_key=cls.api_key,
|
||||||
|
other_args=[
|
||||||
|
"--mem-fraction-static",
|
||||||
|
"0.80",
|
||||||
|
"--cuda-graph-max-bs",
|
||||||
|
"4",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
cls.base_url += "/v1"
|
||||||
|
|
||||||
|
|
||||||
class TestQwen2_5_VLServer(ImageOpenAITestMixin, VideoOpenAITestMixin):
|
class TestQwen2_5_VLServer(ImageOpenAITestMixin, VideoOpenAITestMixin):
|
||||||
@classmethod
|
@classmethod
|
||||||
def setUpClass(cls):
|
def setUpClass(cls):
|
||||||
|
|||||||
@@ -494,7 +494,7 @@ class VideoOpenAITestMixin(TestOpenAIOmniServerBase):
|
|||||||
**(self.get_vision_request_kwargs()),
|
**(self.get_vision_request_kwargs()),
|
||||||
)
|
)
|
||||||
|
|
||||||
video_response = response.choices[0].message.content
|
video_response = response.choices[0].message.content.lower()
|
||||||
|
|
||||||
print("-" * 30)
|
print("-" * 30)
|
||||||
print(f"Video response:\n{video_response}")
|
print(f"Video response:\n{video_response}")
|
||||||
@@ -502,9 +502,10 @@ class VideoOpenAITestMixin(TestOpenAIOmniServerBase):
|
|||||||
|
|
||||||
# Add assertions to validate the video response
|
# Add assertions to validate the video response
|
||||||
assert (
|
assert (
|
||||||
"iPod" in video_response
|
"ipod" in video_response
|
||||||
or "device" in video_response
|
or "device" in video_response
|
||||||
or "microphone" in video_response
|
or "microphone" in video_response
|
||||||
|
or "phone" in video_response
|
||||||
), f"video_response: {video_response}, should contain 'iPod' or 'device'"
|
), f"video_response: {video_response}, should contain 'iPod' or 'device'"
|
||||||
assert (
|
assert (
|
||||||
"man" in video_response
|
"man" in video_response
|
||||||
|
|||||||
Reference in New Issue
Block a user