[Fix] Fix OOM in llava base class (#1249)
@@ -46,25 +46,7 @@ from sglang.srt.models.mistral import MistralForCausalLM
 from sglang.srt.models.qwen2 import Qwen2ForCausalLM


-class LlavaLlamaForCausalLM(nn.Module):
-    def __init__(
-        self,
-        config: LlavaConfig,
-        quant_config: Optional[QuantizationConfig] = None,
-        cache_config: Optional[CacheConfig] = None,
-    ) -> None:
-        super().__init__()
-        self.config = config
-        self.vision_tower = None
-        self.config.vision_config.hidden_size = config.mm_hidden_size
-        self.config.text_config.hidden_size = config.hidden_size
-        self.multi_modal_projector = LlavaMultiModalProjector(config)
-        self.language_model = LlamaForCausalLM(config, quant_config=quant_config)
-        if "unpad" in getattr(config, "mm_patch_merge_type", ""):
-            self.language_model.model.image_newline = nn.Parameter(
-                torch.empty(config.text_config.hidden_size, dtype=torch.float16)
-            )
-
+class LlavaBaseForCausalLM(nn.Module):
     def pad_input_ids(
         self,
         input_ids: List[int],
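Read together with the hunks below, this extraction is the core of the fix. Before this commit, LlavaQwenForCausalLM and LlavaMistralForCausalLM inherited from LlavaLlamaForCausalLM and passed the config up to its constructor, so building either model first constructed a full LlamaForCausalLM and then immediately replaced it with the Qwen2 or Mistral model. The new LlavaBaseForCausalLM keeps only the shared multimodal helpers (pad_input_ids and the methods elided here) and inherits nn.Module's weight-free constructor. A minimal self-contained sketch of the pre-fix shape of the problem, using stand-in classes rather than the real sglang modules:

import torch.nn as nn


class FakeLlama(nn.Module):
    # Stand-in for LlamaForCausalLM; pretend this holds many GB of weights.
    def __init__(self):
        super().__init__()
        self.weights = nn.Linear(4096, 4096)


class FakeQwen(nn.Module):
    # Stand-in for Qwen2ForCausalLM; comparable size.
    def __init__(self):
        super().__init__()
        self.weights = nn.Linear(4096, 4096)


class LlavaLlamaOld(nn.Module):
    def __init__(self):
        super().__init__()
        self.language_model = FakeLlama()  # allocation #1


class LlavaQwenOld(LlavaLlamaOld):
    def __init__(self):
        super().__init__()  # parent builds the Llama weights first...
        self.language_model = FakeQwen()  # ...then they are rebuilt as Qwen

Until the discarded Llama weights are garbage-collected, both language models are resident at once, which is presumably the load-time OOM the commit title refers to.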
@@ -434,14 +416,36 @@ class LlavaLlamaForCausalLM(nn.Module):
         return self.image_size // self.patch_size


-class LlavaQwenForCausalLM(LlavaLlamaForCausalLM):
+class LlavaLlamaForCausalLM(LlavaBaseForCausalLM):
     def __init__(
         self,
         config: LlavaConfig,
         quant_config: Optional[QuantizationConfig] = None,
         cache_config: Optional[CacheConfig] = None,
     ) -> None:
-        super().__init__(config, quant_config=quant_config, cache_config=cache_config)
+        super().__init__()
+
+        self.config = config
+        self.vision_tower = None
+        self.config.vision_config.hidden_size = config.mm_hidden_size
+        self.config.text_config.hidden_size = config.hidden_size
+        self.multi_modal_projector = LlavaMultiModalProjector(config)
+        self.language_model = LlamaForCausalLM(config, quant_config=quant_config)
+        if "unpad" in getattr(config, "mm_patch_merge_type", ""):
+            self.language_model.model.image_newline = nn.Parameter(
+                torch.empty(config.text_config.hidden_size, dtype=torch.float16)
+            )
+
+
+class LlavaQwenForCausalLM(LlavaBaseForCausalLM):
+    def __init__(
+        self,
+        config: LlavaConfig,
+        quant_config: Optional[QuantizationConfig] = None,
+        cache_config: Optional[CacheConfig] = None,
+    ) -> None:
+        super().__init__()
+
         self.config = config
         self.vision_tower = None
         if getattr(self.config, "vision_config", None) is None:
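With the hierarchy re-parented as above, each subclass calls the argument-free nn.Module constructor and builds exactly one language model. Continuing the same stand-in sketch (placeholder classes again, not the real sglang modules):

import torch.nn as nn


class FakeQwen(nn.Module):
    # Same stand-in as in the previous sketch.
    def __init__(self):
        super().__init__()
        self.weights = nn.Linear(4096, 4096)


class LlavaBaseSketch(nn.Module):
    # Only the shared helpers would live here; nn.Module.__init__
    # allocates no model weights.
    pass


class LlavaQwenNew(LlavaBaseSketch):
    def __init__(self):
        super().__init__()  # allocates nothing
        self.language_model = FakeQwen()  # the single allocation

Peak construction memory drops from roughly two language models to one; the next hunk applies the identical change to the Mistral variant.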
@@ -467,14 +471,15 @@ class LlavaQwenForCausalLM(LlavaLlamaForCausalLM):
             )


-class LlavaMistralForCausalLM(LlavaLlamaForCausalLM):
+class LlavaMistralForCausalLM(LlavaBaseForCausalLM):
     def __init__(
         self,
         config: LlavaConfig,
         quant_config: Optional[QuantizationConfig] = None,
         cache_config: Optional[CacheConfig] = None,
     ) -> None:
-        super().__init__(config, quant_config=quant_config, cache_config=cache_config)
+        super().__init__()
+
         self.config = config
         self.vision_tower = None
         if getattr(self.config, "vision_config", None) is None:
@@ -421,7 +421,7 @@ def _set_envs_and_config(server_args: ServerArgs):
     if not server_args.disable_flashinfer:
         assert_pkg_version(
             "flashinfer",
-            "0.1.5",
+            "0.1.6",
             "Please uninstall the old version and "
             "reinstall the latest version by following the instructions "
             "at https://docs.flashinfer.ai/installation.html.",
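The final hunk, from a second file, only raises the required flashinfer version from 0.1.5 to 0.1.6. For readers unfamiliar with the helper, a guard of this shape can be written as below; this is an illustrative reimplementation under assumed semantics, not sglang's actual assert_pkg_version:

from importlib.metadata import PackageNotFoundError, version

from packaging.version import Version


def assert_pkg_version(pkg: str, min_version: str, message: str) -> None:
    # Fail fast if the package is missing or older than required.
    try:
        installed = Version(version(pkg))
    except PackageNotFoundError:
        raise RuntimeError(f"{pkg} is not installed. {message}") from None
    if installed < Version(min_version):
        raise RuntimeError(
            f"{pkg} {installed} is installed but >= {min_version} is required. {message}"
        )

With the bump, a call like assert_pkg_version("flashinfer", "0.1.6", "...") rejects a 0.1.5 installation at server startup instead of failing later at runtime.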