From 0a97d7962d31728a3e4d5936b105ab27a83cd1a9 Mon Sep 17 00:00:00 2001 From: Lianmin Zheng Date: Wed, 28 Aug 2024 08:38:50 -0700 Subject: [PATCH] [Fix] Fix OOM in llava base class (#1249) --- .../http_llava_onevision_test.py | 2 +- python/sglang/srt/models/llava.py | 51 ++++++++++--------- python/sglang/srt/server.py | 2 +- 3 files changed, 30 insertions(+), 25 deletions(-) diff --git a/examples/runtime/llava_onevision/http_llava_onevision_test.py b/examples/runtime/llava_onevision/http_llava_onevision_test.py index 41d60b12a..0c93d2ce2 100644 --- a/examples/runtime/llava_onevision/http_llava_onevision_test.py +++ b/examples/runtime/llava_onevision/http_llava_onevision_test.py @@ -1,7 +1,7 @@ """ Usage: -python3 -m sglang.launch_server --model-path lmms-lab/llava-onevision-qwen2-72b-ov --port=30000 --tp-size=8 --chat-template=chatml-llava --chunked-prefill-size=16384 +python3 -m sglang.launch_server --model-path lmms-lab/llava-onevision-qwen2-72b-ov --port=30000 --tp-size=8 --chat-template=chatml-llava python3 http_llava_onevision_test.py """ diff --git a/python/sglang/srt/models/llava.py b/python/sglang/srt/models/llava.py index bc522bec9..7dcf5348b 100644 --- a/python/sglang/srt/models/llava.py +++ b/python/sglang/srt/models/llava.py @@ -46,25 +46,7 @@ from sglang.srt.models.mistral import MistralForCausalLM from sglang.srt.models.qwen2 import Qwen2ForCausalLM -class LlavaLlamaForCausalLM(nn.Module): - def __init__( - self, - config: LlavaConfig, - quant_config: Optional[QuantizationConfig] = None, - cache_config: Optional[CacheConfig] = None, - ) -> None: - super().__init__() - self.config = config - self.vision_tower = None - self.config.vision_config.hidden_size = config.mm_hidden_size - self.config.text_config.hidden_size = config.hidden_size - self.multi_modal_projector = LlavaMultiModalProjector(config) - self.language_model = LlamaForCausalLM(config, quant_config=quant_config) - if "unpad" in getattr(config, "mm_patch_merge_type", ""): - self.language_model.model.image_newline = nn.Parameter( - torch.empty(config.text_config.hidden_size, dtype=torch.float16) - ) - +class LlavaBaseForCausalLM(nn.Module): def pad_input_ids( self, input_ids: List[int], @@ -434,14 +416,36 @@ class LlavaLlamaForCausalLM(nn.Module): return self.image_size // self.patch_size -class LlavaQwenForCausalLM(LlavaLlamaForCausalLM): +class LlavaLlamaForCausalLM(LlavaBaseForCausalLM): def __init__( self, config: LlavaConfig, quant_config: Optional[QuantizationConfig] = None, cache_config: Optional[CacheConfig] = None, ) -> None: - super().__init__(config, quant_config=quant_config, cache_config=cache_config) + super().__init__() + + self.config = config + self.vision_tower = None + self.config.vision_config.hidden_size = config.mm_hidden_size + self.config.text_config.hidden_size = config.hidden_size + self.multi_modal_projector = LlavaMultiModalProjector(config) + self.language_model = LlamaForCausalLM(config, quant_config=quant_config) + if "unpad" in getattr(config, "mm_patch_merge_type", ""): + self.language_model.model.image_newline = nn.Parameter( + torch.empty(config.text_config.hidden_size, dtype=torch.float16) + ) + + +class LlavaQwenForCausalLM(LlavaBaseForCausalLM): + def __init__( + self, + config: LlavaConfig, + quant_config: Optional[QuantizationConfig] = None, + cache_config: Optional[CacheConfig] = None, + ) -> None: + super().__init__() + self.config = config self.vision_tower = None if getattr(self.config, "vision_config", None) is None: @@ -467,14 +471,15 @@ class LlavaQwenForCausalLM(LlavaLlamaForCausalLM): ) -class LlavaMistralForCausalLM(LlavaLlamaForCausalLM): +class LlavaMistralForCausalLM(LlavaBaseForCausalLM): def __init__( self, config: LlavaConfig, quant_config: Optional[QuantizationConfig] = None, cache_config: Optional[CacheConfig] = None, ) -> None: - super().__init__(config, quant_config=quant_config, cache_config=cache_config) + super().__init__() + self.config = config self.vision_tower = None if getattr(self.config, "vision_config", None) is None: diff --git a/python/sglang/srt/server.py b/python/sglang/srt/server.py index 9c36216ed..5ba2a45e7 100644 --- a/python/sglang/srt/server.py +++ b/python/sglang/srt/server.py @@ -421,7 +421,7 @@ def _set_envs_and_config(server_args: ServerArgs): if not server_args.disable_flashinfer: assert_pkg_version( "flashinfer", - "0.1.5", + "0.1.6", "Please uninstall the old version and " "reinstall the latest version by following the instructions " "at https://docs.flashinfer.ai/installation.html.",