refactor: multimodal data (#4754)

This commit is contained in:
Mick
2025-04-01 00:57:51 +08:00
committed by GitHub
parent c7457191a0
commit 5cb552b1d4
36 changed files with 989 additions and 1138 deletions

View File

@@ -355,11 +355,6 @@ class ForwardBatch:
for mm_input in valid_inputs[1:]:
merged.merge(mm_input)
if isinstance(merged.pixel_values, np.ndarray):
merged.pixel_values = torch.from_numpy(merged.pixel_values)
if isinstance(merged.audio_features, np.ndarray):
merged.audio_features = torch.from_numpy(merged.audio_features)
return merged
def contains_image_inputs(self) -> bool:

View File

@@ -251,17 +251,16 @@ class ModelRunner:
self.init_double_sparsity_channel_config(server_args.ds_heavy_channel_type)
if self.is_multimodal:
self.mem_fraction_static *= 0.95
self.mem_fraction_static *= 0.90
logger.info(
f"Automatically reduce --mem-fraction-static to {self.mem_fraction_static:.3f} "
f"because this is a multimodal model."
)
if self.model_config.hf_config.architectures == [
"MllamaForConditionalGeneration"
]:
logger.info("Automatically turn off --chunked-prefill-size for mllama.")
server_args.chunked_prefill_size = -1
logger.info(
"Automatically turn off --chunked-prefill-size for multimodal model."
)
server_args.chunked_prefill_size = -1
if self.model_config.hf_config.architectures == [
"Qwen2VLForConditionalGeneration"
@@ -269,18 +268,7 @@ class ModelRunner:
"Qwen2_5_VLForConditionalGeneration"
]:
# TODO: qwen2-vl series does not support radix cache now, set disable_radix_cache=True automatically
logger.info(
"Automatically turn off --chunked-prefill-size and disable radix cache for qwen-vl series."
)
server_args.chunked_prefill_size = -1
server_args.disable_radix_cache = True
if self.model_config.hf_config.architectures == ["DeepseekVL2ForCausalLM"]:
# TODO: deepseek-vl2 does not support radix cache now, set disable_radix_cache=True automatically
logger.info(
"Automatically turn off --chunked-prefill-size and disable radix cache for deepseek-vl2."
)
server_args.chunked_prefill_size = -1
logger.info("Automatically disable radix cache for qwen-vl series.")
server_args.disable_radix_cache = True
if server_args.enable_deepep_moe: