[Fix] Reduce memory usage for loading llava model & Remove EntryClassRemapping (#1308)
@@ -41,7 +41,7 @@ from sglang.srt.mm_utils import (
     unpad_image_shape,
 )
 from sglang.srt.model_executor.forward_batch_info import ForwardMode, InputMetadata
-from sglang.srt.models.llama2 import LlamaForCausalLM
+from sglang.srt.models.llama import LlamaForCausalLM
 from sglang.srt.models.mistral import MistralForCausalLM
 from sglang.srt.models.qwen2 import Qwen2ForCausalLM
 
@@ -395,21 +395,19 @@ class LlavaBaseForCausalLM(nn.Module):
             "model.mm_projector.0": "multi_modal_projector.linear_1",
             "model.mm_projector.2": "multi_modal_projector.linear_2",
             "model.vision_tower.vision_tower": "vision_tower",  # Update the vision tower weights if we find them in the checkpoint (it may be finetuned).
+            "model.image_newline": "language_model.model.image_newline",
         }
         params_dict = dict(self.named_parameters())
-        weights = list(weights)
         for name, loaded_weight in weights:
-            # FIXME: why projector weights read two times?
-            if "projector" in name or "vision_tower" in name:
+            if "projector" in name or "vision_tower" in name or "image_newline" in name:
                 for weight_name, param_name in projector_weights.items():
                     if weight_name in name:
                         name = name.replace(weight_name, param_name)
                 param = params_dict[name]
                 weight_loader = getattr(param, "weight_loader", default_weight_loader)
                 weight_loader(param, loaded_weight)
-
-        # load language model
-        self.language_model.load_weights(weights)
+            else:
+                self.language_model.load_weights([(name, loaded_weight)])
 
     @property
     def num_patches_per_side(self):
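The hunk above is the core of the memory fix. Previously the weights generator was materialized with list(weights) so it could be traversed twice: once to pick out projector/vision-tower tensors, and once more when the whole list was handed to self.language_model.load_weights. Materializing keeps every checkpoint tensor alive in host memory at the same time; the new code makes a single pass and routes each tensor as it arrives. A minimal sketch of that pattern, with hypothetical names (load_weights_streaming and name_map are illustrative, not sglang API):

# Sketch only -- not the sglang implementation. `weights` is typically a
# generator of (name, tensor) pairs streamed from checkpoint shards;
# consuming it one pair at a time lets each tensor be released after it
# is copied into the model.
from typing import Iterable, Tuple

import torch
import torch.nn as nn


def load_weights_streaming(
    model: nn.Module,
    weights: Iterable[Tuple[str, torch.Tensor]],
    name_map: dict,  # checkpoint prefix -> module prefix, like projector_weights above
) -> None:
    params = dict(model.named_parameters())
    for name, tensor in weights:  # only one checkpoint tensor resident at a time
        for old, new in name_map.items():
            if old in name:
                name = name.replace(old, new)
        param = params[name]
        param.data.copy_(tensor)  # simplified stand-in for a weight_loader hook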
@@ -429,6 +427,7 @@ class LlavaLlamaForCausalLM(LlavaBaseForCausalLM):
         self.vision_tower = None
         self.config.vision_config.hidden_size = config.mm_hidden_size
         self.config.text_config.hidden_size = config.hidden_size
 
         self.multi_modal_projector = LlavaMultiModalProjector(config)
         self.language_model = LlamaForCausalLM(config, quant_config=quant_config)
         if "unpad" in getattr(config, "mm_patch_merge_type", ""):
@@ -448,9 +447,9 @@ class LlavaQwenForCausalLM(LlavaBaseForCausalLM):
 
         self.config = config
         self.vision_tower = None
 
         if getattr(self.config, "vision_config", None) is None:
             self.config.vision_config = CLIPVisionConfig(self.config.mm_vision_tower)
 
         if getattr(self.config, "text_config", None) is None:
             self.config.text_config = Qwen2Config(self.config._name_or_path)
 
@@ -459,7 +458,6 @@ class LlavaQwenForCausalLM(LlavaBaseForCausalLM):
 
         if getattr(self.config, "projector_hidden_act", None) is None:
             self.config.projector_hidden_act = "gelu"
 
         if getattr(self.config, "image_token_index", None) is None:
             self.config.image_token_index = 151646
 
@@ -482,9 +480,9 @@ class LlavaMistralForCausalLM(LlavaBaseForCausalLM):
 
         self.config = config
         self.vision_tower = None
 
         if getattr(self.config, "vision_config", None) is None:
             self.config.vision_config = CLIPVisionConfig(self.config.mm_vision_tower)
 
         if getattr(self.config, "text_config", None) is None:
             self.config.text_config = MistralConfig(self.config._name_or_path)
 
@@ -493,7 +491,6 @@ class LlavaMistralForCausalLM(LlavaBaseForCausalLM):
 
         if getattr(self.config, "projector_hidden_act", None) is None:
             self.config.projector_hidden_act = "gelu"
 
         if getattr(self.config, "image_token_index", None) is None:
             self.config.image_token_index = 32000
 
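The Qwen and Mistral hunks above share a defaulting pattern: each config attribute is synthesized only when the loaded checkpoint's config omits it, so checkpoints that ship their own values are left untouched. A hedged sketch of the idea (apply_llava_mistral_defaults is illustrative, not sglang API; from_pretrained stands in here for the positional construction used in the diff):

# Sketch only -- fills in sub-configs/ids a bare llava-Mistral checkpoint may omit.
from transformers import CLIPVisionConfig, MistralConfig


def apply_llava_mistral_defaults(config):
    if getattr(config, "vision_config", None) is None:
        config.vision_config = CLIPVisionConfig.from_pretrained(config.mm_vision_tower)
    if getattr(config, "text_config", None) is None:
        config.text_config = MistralConfig.from_pretrained(config._name_or_path)
    if getattr(config, "projector_hidden_act", None) is None:
        config.projector_hidden_act = "gelu"
    if getattr(config, "image_token_index", None) is None:
        config.image_token_index = 32000  # Mistral-based llava default, per the hunk above
    return config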