[Fix] Fix llava on multi images (#1247)

Lianmin Zheng
2024-08-28 06:33:05 -07:00
committed by GitHub
parent b1a540ec42
commit bf53bf5142
22 changed files with 272 additions and 488 deletions


@@ -50,7 +50,7 @@ from sglang.srt.mem_cache.memory_pool import (
     MLATokenToKVPool,
     ReqToTokenPool,
 )
-from sglang.srt.model_config import AttentionArch
+from sglang.srt.model_config import AttentionArch, ModelConfig
 from sglang.srt.model_executor.forward_batch_info import ForwardMode, InputMetadata
 from sglang.srt.server_args import ServerArgs
 from sglang.srt.utils import (
@@ -69,7 +69,7 @@ logger = logging.getLogger(__name__)
 class ModelRunner:
     def __init__(
         self,
-        model_config,
+        model_config: ModelConfig,
         mem_fraction_static: float,
         gpu_id: int,
         tp_rank: int,
@@ -85,7 +85,9 @@ class ModelRunner:
         self.tp_size = tp_size
         self.nccl_port = nccl_port
         self.server_args = server_args
-        self.is_multimodal_model = is_multimodal_model(self.model_config)
+        self.is_multimodal_model = is_multimodal_model(
+            self.model_config.hf_config.architectures
+        )
         global_server_args_dict.update(
             {
                 "disable_flashinfer": server_args.disable_flashinfer,
@@ -95,6 +97,13 @@ class ModelRunner:
             }
         )
+        if self.is_multimodal_model:
+            logger.info(
+                "Automatically turn off --chunked-prefill-size and adjust --mem-fraction-static for multimodal models."
+            )
+            server_args.chunked_prefill_size = None
+            server_args.mem_fraction_static *= 0.95
         min_per_gpu_memory = self.init_torch_distributed()
         self.load_model()
         self.init_memory_pool(
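
For multimodal models the runner now turns off chunked prefill and shaves 5% off the static memory fraction before the memory pool is sized. A rough illustration of the arithmetic is below; the starting values are placeholders, not defaults taken from this commit.

# Illustration only: placeholder starting values, not sglang defaults.
chunked_prefill_size = 8192   # hypothetical pre-adjustment value
mem_fraction_static = 0.88    # hypothetical pre-adjustment value

is_multimodal = True
if is_multimodal:
    chunked_prefill_size = None   # chunked prefill disabled entirely
    mem_fraction_static *= 0.95   # keep extra headroom for image features

print(chunked_prefill_size, round(mem_fraction_static, 3))  # None 0.836
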
@@ -507,9 +516,9 @@ class ModelRunner:
             raise Exception(
                 f"Capture cuda graph failed: {e}\n"
                 "Possible solutions:\n"
-                "1. disable torch compile by not using --enable-torch-compile\n"
-                "2. disable cuda graph by --disable-cuda-graph\n"
-                "3. set --mem-fraction-static to a smaller value\n"
+                "1. disable cuda graph by --disable-cuda-graph\n"
+                "2. set --mem-fraction-static to a smaller value\n"
+                "3. disable torch compile by not using --enable-torch-compile\n"
                 "Open an issue on GitHub https://github.com/sgl-project/sglang/issues/new/choose \n"
             )
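
The troubleshooting list now leads with disabling CUDA graphs and lowering --mem-fraction-static before suggesting turning off torch.compile. As a hedged sketch, the same remedies could be applied programmatically through ServerArgs roughly as follows; the model path is a placeholder and the field defaults may differ across sglang versions.

from sglang.srt.server_args import ServerArgs

# Sketch only: applying the remedies from the error message programmatically.
args = ServerArgs(
    model_path="path/to/model",    # placeholder model path
    disable_cuda_graph=True,       # remedy 1: skip CUDA graph capture
    mem_fraction_static=0.7,       # remedy 2: reserve less static memory
    enable_torch_compile=False,    # remedy 3: do not enable torch.compile
)
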