[Fix] Fix OOM in llava base class (#1249)
@@ -1,7 +1,7 @@
 """
 Usage:
 
-python3 -m sglang.launch_server --model-path lmms-lab/llava-onevision-qwen2-72b-ov --port=30000 --tp-size=8 --chat-template=chatml-llava --chunked-prefill-size=16384
+python3 -m sglang.launch_server --model-path lmms-lab/llava-onevision-qwen2-72b-ov --port=30000 --tp-size=8 --chat-template=chatml-llava
 
 python3 http_llava_onevision_test.py
 """
@@ -46,25 +46,7 @@ from sglang.srt.models.mistral import MistralForCausalLM
 from sglang.srt.models.qwen2 import Qwen2ForCausalLM
 
 
-class LlavaLlamaForCausalLM(nn.Module):
-    def __init__(
-        self,
-        config: LlavaConfig,
-        quant_config: Optional[QuantizationConfig] = None,
-        cache_config: Optional[CacheConfig] = None,
-    ) -> None:
-        super().__init__()
-        self.config = config
-        self.vision_tower = None
-        self.config.vision_config.hidden_size = config.mm_hidden_size
-        self.config.text_config.hidden_size = config.hidden_size
-        self.multi_modal_projector = LlavaMultiModalProjector(config)
-        self.language_model = LlamaForCausalLM(config, quant_config=quant_config)
-        if "unpad" in getattr(config, "mm_patch_merge_type", ""):
-            self.language_model.model.image_newline = nn.Parameter(
-                torch.empty(config.text_config.hidden_size, dtype=torch.float16)
-            )
-
+class LlavaBaseForCausalLM(nn.Module):
     def pad_input_ids(
         self,
         input_ids: List[int],
@@ -434,14 +416,36 @@ class LlavaLlamaForCausalLM(nn.Module):
         return self.image_size // self.patch_size
 
 
-class LlavaQwenForCausalLM(LlavaLlamaForCausalLM):
+class LlavaLlamaForCausalLM(LlavaBaseForCausalLM):
     def __init__(
         self,
         config: LlavaConfig,
         quant_config: Optional[QuantizationConfig] = None,
         cache_config: Optional[CacheConfig] = None,
     ) -> None:
-        super().__init__(config, quant_config=quant_config, cache_config=cache_config)
+        super().__init__()
+
+        self.config = config
+        self.vision_tower = None
+        self.config.vision_config.hidden_size = config.mm_hidden_size
+        self.config.text_config.hidden_size = config.hidden_size
+        self.multi_modal_projector = LlavaMultiModalProjector(config)
+        self.language_model = LlamaForCausalLM(config, quant_config=quant_config)
+        if "unpad" in getattr(config, "mm_patch_merge_type", ""):
+            self.language_model.model.image_newline = nn.Parameter(
+                torch.empty(config.text_config.hidden_size, dtype=torch.float16)
+            )
+
+
+class LlavaQwenForCausalLM(LlavaBaseForCausalLM):
+    def __init__(
+        self,
+        config: LlavaConfig,
+        quant_config: Optional[QuantizationConfig] = None,
+        cache_config: Optional[CacheConfig] = None,
+    ) -> None:
+        super().__init__()
+
         self.config = config
         self.vision_tower = None
         if getattr(self.config, "vision_config", None) is None:
@@ -467,14 +471,15 @@ class LlavaQwenForCausalLM(LlavaLlamaForCausalLM):
             )
 
 
-class LlavaMistralForCausalLM(LlavaLlamaForCausalLM):
+class LlavaMistralForCausalLM(LlavaBaseForCausalLM):
     def __init__(
         self,
         config: LlavaConfig,
         quant_config: Optional[QuantizationConfig] = None,
        cache_config: Optional[CacheConfig] = None,
     ) -> None:
-        super().__init__(config, quant_config=quant_config, cache_config=cache_config)
+        super().__init__()
+
         self.config = config
         self.vision_tower = None
         if getattr(self.config, "vision_config", None) is None:
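Why the refactor fixes the OOM: in the old hierarchy, LlavaQwenForCausalLM and LlavaMistralForCausalLM subclassed LlavaLlamaForCausalLM and ran its constructor, which built a full LlamaForCausalLM before the subclass replaced self.language_model with its own backbone, so two language models were alive at peak during initialization. The new LlavaBaseForCausalLM holds only the shared multimodal helpers, and each subclass constructs exactly one backbone. A minimal, self-contained sketch of the pattern follows; the small Backbone modules are stand-ins for the real sglang model classes, not the actual implementation:

import torch
import torch.nn as nn


class LlamaBackbone(nn.Module):
    # Stand-in for LlamaForCausalLM; allocating it is the expensive step.
    def __init__(self):
        super().__init__()
        self.weight = nn.Parameter(torch.empty(1024, 1024))


class Qwen2Backbone(nn.Module):
    # Stand-in for Qwen2ForCausalLM.
    def __init__(self):
        super().__init__()
        self.weight = nn.Parameter(torch.empty(1024, 1024))


class LlavaBaseForCausalLM(nn.Module):
    # Shared multimodal helpers only; __init__ allocates no weights, so a
    # subclass never pays for a backbone it would immediately discard.
    def pad_input_ids(self, input_ids):
        return input_ids  # placeholder for the shared padding logic


class LlavaQwenForCausalLM(LlavaBaseForCausalLM):
    def __init__(self):
        super().__init__()  # plain nn.Module init, no model weights
        self.language_model = Qwen2Backbone()  # exactly one backbone is built

Under the old pattern, LlavaQwenForCausalLM.__init__ would first have materialized a LlamaBackbone and then overwritten it, doubling peak memory at load time, which is fatal for a 72B checkpoint.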
@@ -421,7 +421,7 @@ def _set_envs_and_config(server_args: ServerArgs):
     if not server_args.disable_flashinfer:
         assert_pkg_version(
             "flashinfer",
-            "0.1.5",
+            "0.1.6",
             "Please uninstall the old version and "
             "reinstall the latest version by following the instructions "
             "at https://docs.flashinfer.ai/installation.html.",
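The hunk above raises the minimum required flashinfer version from 0.1.5 to 0.1.6. For illustration, a generic guard along the lines of assert_pkg_version can be built from importlib.metadata plus the packaging library; this is a hedged sketch of the pattern, not sglang's actual helper:

from importlib.metadata import PackageNotFoundError, version

from packaging.version import Version


def assert_pkg_version(pkg: str, min_version: str, message: str) -> None:
    # Illustrative fail-fast guard for a minimum dependency version.
    try:
        installed = version(pkg)
    except PackageNotFoundError:
        raise RuntimeError(f"{pkg} is not installed. {message}")
    if Version(installed) < Version(min_version):
        raise RuntimeError(
            f"{pkg}=={installed} is older than the required {min_version}. {message}"
        )


# Mirrors the call site in the hunk above; raises if flashinfer is absent or < 0.1.6.
assert_pkg_version(
    "flashinfer",
    "0.1.6",
    "Please uninstall the old version and reinstall the latest version "
    "by following the instructions at https://docs.flashinfer.ai/installation.html.",
)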