[Fix] Reduce memory usage for loading llava model & Remove EntryClassRemapping (#1308)

This commit is contained in:
Lianmin Zheng
2024-09-02 21:44:45 -07:00
committed by GitHub
parent a5a134f39f
commit f64eae3a29
17 changed files with 105 additions and 158 deletions

View File

@@ -41,7 +41,7 @@ from sglang.srt.mm_utils import (
unpad_image_shape,
)
from sglang.srt.model_executor.forward_batch_info import ForwardMode, InputMetadata
from sglang.srt.models.llama2 import LlamaForCausalLM
from sglang.srt.models.llama import LlamaForCausalLM
from sglang.srt.models.mistral import MistralForCausalLM
from sglang.srt.models.qwen2 import Qwen2ForCausalLM
@@ -395,21 +395,19 @@ class LlavaBaseForCausalLM(nn.Module):
"model.mm_projector.0": "multi_modal_projector.linear_1",
"model.mm_projector.2": "multi_modal_projector.linear_2",
"model.vision_tower.vision_tower": "vision_tower", # Update the vision tower weights if we find them in the checkpoint (it may be finetuned).
"model.image_newline": "language_model.model.image_newline",
}
params_dict = dict(self.named_parameters())
weights = list(weights)
for name, loaded_weight in weights:
# FIXME: why projector weights read two times?
if "projector" in name or "vision_tower" in name:
if "projector" in name or "vision_tower" in name or "image_newline" in name:
for weight_name, param_name in projector_weights.items():
if weight_name in name:
name = name.replace(weight_name, param_name)
param = params_dict[name]
weight_loader = getattr(param, "weight_loader", default_weight_loader)
weight_loader(param, loaded_weight)
# load language model
self.language_model.load_weights(weights)
else:
self.language_model.load_weights([(name, loaded_weight)])
@property
def num_patches_per_side(self):
@@ -429,6 +427,7 @@ class LlavaLlamaForCausalLM(LlavaBaseForCausalLM):
self.vision_tower = None
self.config.vision_config.hidden_size = config.mm_hidden_size
self.config.text_config.hidden_size = config.hidden_size
self.multi_modal_projector = LlavaMultiModalProjector(config)
self.language_model = LlamaForCausalLM(config, quant_config=quant_config)
if "unpad" in getattr(config, "mm_patch_merge_type", ""):
@@ -448,9 +447,9 @@ class LlavaQwenForCausalLM(LlavaBaseForCausalLM):
self.config = config
self.vision_tower = None
if getattr(self.config, "vision_config", None) is None:
self.config.vision_config = CLIPVisionConfig(self.config.mm_vision_tower)
if getattr(self.config, "text_config", None) is None:
self.config.text_config = Qwen2Config(self.config._name_or_path)
@@ -459,7 +458,6 @@ class LlavaQwenForCausalLM(LlavaBaseForCausalLM):
if getattr(self.config, "projector_hidden_act", None) is None:
self.config.projector_hidden_act = "gelu"
if getattr(self.config, "image_token_index", None) is None:
self.config.image_token_index = 151646
@@ -482,9 +480,9 @@ class LlavaMistralForCausalLM(LlavaBaseForCausalLM):
self.config = config
self.vision_tower = None
if getattr(self.config, "vision_config", None) is None:
self.config.vision_config = CLIPVisionConfig(self.config.mm_vision_tower)
if getattr(self.config, "text_config", None) is None:
self.config.text_config = MistralConfig(self.config._name_or_path)
@@ -493,7 +491,6 @@ class LlavaMistralForCausalLM(LlavaBaseForCausalLM):
if getattr(self.config, "projector_hidden_act", None) is None:
self.config.projector_hidden_act = "gelu"
if getattr(self.config, "image_token_index", None) is None:
self.config.image_token_index = 32000