Fix Llava model (#594)

This commit is contained in:
Mingyi
2024-07-06 00:58:46 -07:00
committed by GitHub
parent dc1b8bcfaa
commit c0982ac553
5 changed files with 18 additions and 13 deletions

View File

@@ -1,15 +1,14 @@
import json
from typing import Callable, List, Optional, Union
from typing import List, Optional
import numpy as np
import requests
from sglang.backend.base_backend import BaseBackend
from sglang.global_config import global_config
from sglang.lang.chat_template import get_chat_template_by_model_path
from sglang.lang.interpreter import StreamExecutor
from sglang.lang.ir import SglArgument, SglSamplingParams
from sglang.utils import encode_image_base64, find_printable_text, http_request
from sglang.lang.ir import SglSamplingParams
from sglang.utils import find_printable_text, http_request
class RuntimeEndpoint(BaseBackend):

View File

@@ -523,9 +523,9 @@ class StreamExecutor:
self, sampling_params=sampling_params
)
self.variables[name] = ""
self.stream_var_event[name].set()
self.variables[name] = ""
for comp, meta_info in generator:
self.text_ += comp
self.variables[name] += comp

View File

@@ -3,7 +3,7 @@
import warnings
from dataclasses import dataclass
from enum import IntEnum, auto
from typing import List
from typing import List, Union
import numpy as np
import torch
@@ -31,7 +31,7 @@ class BaseFinishReason:
class FINISH_MATCHED_TOKEN(BaseFinishReason):
def __init__(self, matched: int | List[int]):
def __init__(self, matched: Union[int, List[int]]):
super().__init__()
self.matched = matched

View File

@@ -115,6 +115,12 @@ def get_hf_text_config(config: PretrainedConfig):
"""Get the "sub" config relevant to llm for multi modal models.
No op for pure text models.
"""
class_name = config.architectures[0]
if class_name.startswith("Llava") and class_name.endswith("ForCausalLM"):
# We support non-hf version of llava models, so we do not want to
# read the wrong values from the unused default text_config.
return config
if hasattr(config, "text_config"):
# The code operates under the assumption that text_config should have
# `num_attention_heads` (among others). Assert here to fail early

View File

@@ -1,10 +1,10 @@
# Adapted from:
# https://github.com/vllm-project/vllm/blob/56b325e977435af744f8b3dca7af0ca209663558/vllm/model_executor/models/gemma2.py
from typing import Iterable, List, Optional, Set, Tuple, Union
from typing import Iterable, Optional, Set, Tuple, Union
import torch
from torch import nn
from transformers import Gemma2Config
from transformers import PretrainedConfig
from vllm.config import CacheConfig, LoRAConfig
from vllm.distributed import get_tensor_model_parallel_world_size
@@ -131,7 +131,7 @@ class Gemma2Attention(nn.Module):
def __init__(
self,
layer_idx: int,
config: Gemma2Config,
config: PretrainedConfig,
hidden_size: int,
num_heads: int,
num_kv_heads: int,
@@ -222,7 +222,7 @@ class Gemma2DecoderLayer(nn.Module):
def __init__(
self,
layer_idx: int,
config: Gemma2Config,
config: PretrainedConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
) -> None:
@@ -290,7 +290,7 @@ class Gemma2Model(nn.Module):
def __init__(
self,
config: Gemma2Config,
config: PretrainedConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
) -> None:
@@ -369,7 +369,7 @@ class Gemma2ForCausalLM(nn.Module):
def __init__(
self,
config: Gemma2Config,
config: PretrainedConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
lora_config: Optional[LoRAConfig] = None,