Use model loader from vllm (#459)
This commit is contained in:
@@ -1,6 +1,6 @@
|
||||
"""Inference-only LLaVa model compatible with HuggingFace weights."""
|
||||
|
||||
from typing import List, Optional
|
||||
from typing import List, Iterable, Optional, Tuple
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
@@ -8,6 +8,7 @@ from torch import nn
|
||||
from transformers import CLIPVisionModel, LlavaConfig
|
||||
from transformers.models.llava.modeling_llava import LlavaMultiModalProjector
|
||||
from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
|
||||
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
||||
|
||||
from sglang.srt.managers.router.infer_batch import ForwardMode
|
||||
from sglang.srt.managers.router.model_runner import InputMetadata
|
||||
@@ -17,7 +18,6 @@ from sglang.srt.mm_utils import (
|
||||
unpad_image_shape,
|
||||
)
|
||||
from sglang.srt.models.llama2 import LlamaForCausalLM
|
||||
from sglang.srt.weight_utils import default_weight_loader, hf_model_weights_iterator
|
||||
|
||||
|
||||
class LlavaLlamaForCausalLM(nn.Module):
|
||||
@@ -233,13 +233,7 @@ class LlavaLlamaForCausalLM(nn.Module):
|
||||
elif input_metadata.forward_mode == ForwardMode.DECODE:
|
||||
return self.language_model(input_ids, positions, input_metadata)
|
||||
|
||||
def load_weights(
|
||||
self,
|
||||
model_name_or_path: str,
|
||||
cache_dir: Optional[str] = None,
|
||||
load_format: str = "auto",
|
||||
revision: Optional[str] = None,
|
||||
):
|
||||
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
|
||||
# load clip vision model by cfg['mm_vision_tower']:
|
||||
# huggingface_name or path_of_clip_relative_to_llava_model_dir
|
||||
vision_path = self.config.mm_vision_tower
|
||||
@@ -272,9 +266,8 @@ class LlavaLlamaForCausalLM(nn.Module):
|
||||
"model.vision_tower.vision_tower": "vision_tower", # Update the vision tower weights if we find them in the checkpoint (it may be finetuned).
|
||||
}
|
||||
params_dict = dict(self.named_parameters())
|
||||
for name, loaded_weight in hf_model_weights_iterator(
|
||||
model_name_or_path, cache_dir, load_format, revision
|
||||
):
|
||||
weights = list(weights)
|
||||
for name, loaded_weight in weights:
|
||||
# FIXME: why are the projector weights read twice?
|
||||
if "projector" in name or "vision_tower" in name:
|
||||
for weight_name, param_name in projector_weights.items():
|
||||
@@ -285,9 +278,7 @@ class LlavaLlamaForCausalLM(nn.Module):
|
||||
weight_loader(param, loaded_weight)
|
||||
|
||||
# load language model
|
||||
self.language_model.load_weights(
|
||||
model_name_or_path, cache_dir, load_format, revision
|
||||
)
|
||||
self.language_model.load_weights(weights)
|
||||
|
||||
monkey_path_clip_vision_embed_forward()
|
||||
|
||||
|
||||
Reference in New Issue
Block a user