Compat with latest VLLM 0.4.2 main + fork.number rename + Flashinfer 0.0.4 (#380)

Co-authored-by: ZX <zx@lbx.dev>
Co-authored-by: ZhouXingg <165115237+ZhouXingg@users.noreply.github.com>
Qubitium authored 2024-05-12 07:37:49 +08:00, committed by GitHub
parent a511a2d089
commit 33b242df30
20 changed files with 611 additions and 187 deletions


@@ -5,10 +5,11 @@ from typing import List, Optional
 import numpy as np
 import torch
 from torch import nn
-from transformers import CLIPVisionModel, LlamaConfig, LlavaConfig
+from transformers import CLIPVisionModel, LlavaConfig
 from transformers.models.llava.modeling_llava import LlavaMultiModalProjector
-from vllm.model_executor.layers.linear import LinearMethodBase
-from vllm.model_executor.weight_utils import (
+from vllm.model_executor.layers.quantization.base_config import (
+    QuantizationConfig)
+from sglang.srt.weight_utils import (
     default_weight_loader,
     hf_model_weights_iterator,
 )
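The two weight-loading helpers kept in this hunk, default_weight_loader and hf_model_weights_iterator, are now imported from sglang.srt.weight_utils rather than vllm.model_executor.weight_utils, whose old path is not part of the vLLM 0.4.2 API this commit targets. As a minimal sketch only, assuming the helper keeps the semantics of vLLM's original default loader (this is not the exact sglang implementation), such a loader simply copies the checkpoint tensor into the matching parameter:

import torch

def default_weight_loader(param: torch.Tensor, loaded_weight: torch.Tensor) -> None:
    # Copy the checkpoint tensor into the existing model parameter in place,
    # after checking that the shapes match.
    assert param.size() == loaded_weight.size()
    param.data.copy_(loaded_weight)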
@@ -27,7 +28,7 @@ class LlavaLlamaForCausalLM(nn.Module):
     def __init__(
         self,
         config: LlavaConfig,
-        linear_method: Optional[LinearMethodBase] = None,
+        quant_config: Optional[QuantizationConfig] = None,
     ) -> None:
         super().__init__()
         self.config = config
@@ -35,7 +36,7 @@ class LlavaLlamaForCausalLM(nn.Module):
         self.config.vision_config.hidden_size = config.mm_hidden_size
         self.config.text_config.hidden_size = config.hidden_size
         self.multi_modal_projector = LlavaMultiModalProjector(config)
-        self.language_model = LlamaForCausalLM(config, linear_method)
+        self.language_model = LlamaForCausalLM(config, quant_config=quant_config)
         if "unpad" in getattr(config, "mm_patch_merge_type", ""):
             self.language_model.model.image_newline = nn.Parameter(
                 torch.empty(config.text_config.hidden_size, dtype=torch.float16)
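The last two hunks apply the same rename: the constructor takes quant_config: Optional[QuantizationConfig] instead of linear_method: Optional[LinearMethodBase], and the value is forwarded to the wrapped LlamaForCausalLM as a keyword argument. A minimal sketch of the pattern, using a hypothetical ExampleForCausalLM class rather than code from this commit:

from typing import Optional

from torch import nn
from vllm.model_executor.layers.quantization.base_config import QuantizationConfig


class ExampleForCausalLM(nn.Module):
    # Hypothetical model class illustrating the renamed constructor argument.
    def __init__(self, config, quant_config: Optional[QuantizationConfig] = None) -> None:
        super().__init__()
        self.config = config
        # The QuantizationConfig (or None for unquantized weights) is threaded
        # down to submodules in place of the old LinearMethodBase object.
        self.quant_config = quant_config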