Fix Llava model (#594)

@@ -1,15 +1,14 @@
 import json
-from typing import Callable, List, Optional, Union
+from typing import List, Optional

 import numpy as np
 import requests

 from sglang.backend.base_backend import BaseBackend
 from sglang.global_config import global_config
 from sglang.lang.chat_template import get_chat_template_by_model_path
 from sglang.lang.interpreter import StreamExecutor
-from sglang.lang.ir import SglArgument, SglSamplingParams
-from sglang.utils import encode_image_base64, find_printable_text, http_request
+from sglang.lang.ir import SglSamplingParams
+from sglang.utils import find_printable_text, http_request


 class RuntimeEndpoint(BaseBackend):

@@ -523,9 +523,9 @@ class StreamExecutor:
                 self, sampling_params=sampling_params
             )

+            self.variables[name] = ""
             self.stream_var_event[name].set()

-            self.variables[name] = ""
             for comp, meta_info in generator:
                 self.text_ += comp
                 self.variables[name] += comp
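
Note on the hunk above: the assignment self.variables[name] = "" now runs before self.stream_var_event[name].set(), an init-before-signal ordering that presumably keeps a consumer woken by the event from reading a variable that does not exist yet. A minimal sketch of the pattern, assuming the entries of stream_var_event behave like threading.Event (all names below are illustrative, not taken from the codebase):

import threading

variables = {}
event = threading.Event()

def consumer():
    event.wait()                # wakes as soon as the producer signals
    print(variables["answer"])  # safe only if the key already exists

t = threading.Thread(target=consumer)
t.start()

variables["answer"] = ""        # initialize first ...
event.set()                     # ... then signal
t.join()

With the opposite ordering, a fast consumer could wake and index the dict before the key is created, raising KeyError.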

@@ -3,7 +3,7 @@
 import warnings
 from dataclasses import dataclass
 from enum import IntEnum, auto
-from typing import List
+from typing import List, Union

 import numpy as np
 import torch

@@ -31,7 +31,7 @@ class BaseFinishReason:


 class FINISH_MATCHED_TOKEN(BaseFinishReason):
-    def __init__(self, matched: int | List[int]):
+    def __init__(self, matched: Union[int, List[int]]):
         super().__init__()
         self.matched = matched
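
The two hunks above replace the PEP 604 annotation int | List[int] with typing.Union and import Union accordingly. The X | Y spelling is only accepted on Python 3.10+, so this reads as a compatibility fix for older interpreters (an inference, not stated in the commit). A short sketch of the difference:

from typing import List, Union

# Accepted on Python 3.8+:
def matched_old_style(matched: Union[int, List[int]]) -> None:
    ...

# The PEP 604 spelling below is evaluated at definition time and raises
# TypeError on interpreters older than 3.10, unless postponed annotations
# are enabled via `from __future__ import annotations`:
# def matched_new_style(matched: int | List[int]) -> None: ...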

@@ -115,6 +115,12 @@ def get_hf_text_config(config: PretrainedConfig):
     """Get the "sub" config relevant to llm for multi modal models.
     No op for pure text models.
     """
+    class_name = config.architectures[0]
+    if class_name.startswith("Llava") and class_name.endswith("ForCausalLM"):
+        # We support non-hf version of llava models, so we do not want to
+        # read the wrong values from the unused default text_config.
+        return config
+
     if hasattr(config, "text_config"):
         # The code operates under the assumption that text_config should have
         # `num_attention_heads` (among others). Assert here to fail early
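
The hunk above short-circuits get_hf_text_config for Llava-style architectures. Read together with the context lines, the helper seems to behave as in the following sketch; the final hasattr branch and fallback are assumed from the visible comments rather than shown in the diff, and the _sketch suffix marks the function as illustrative:

from transformers import PretrainedConfig

def get_hf_text_config_sketch(config: PretrainedConfig):
    # Mirrors the patched branch: sglang's non-hf llava checkpoints keep the
    # language-model fields on the top-level config, so return it as-is.
    class_name = config.architectures[0]
    if class_name.startswith("Llava") and class_name.endswith("ForCausalLM"):
        return config
    # hf-style multi-modal configs nest the language model under text_config.
    if hasattr(config, "text_config"):
        return config.text_config
    return config

So a config whose architectures entry is, say, LlavaLlamaForCausalLM is returned unchanged, while an hf-style multi-modal config still resolves to its nested text_config.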

@@ -1,10 +1,10 @@
 # Adapted from:
 # https://github.com/vllm-project/vllm/blob/56b325e977435af744f8b3dca7af0ca209663558/vllm/model_executor/models/gemma2.py
-from typing import Iterable, List, Optional, Set, Tuple, Union
+from typing import Iterable, Optional, Set, Tuple, Union

 import torch
 from torch import nn
-from transformers import Gemma2Config
+from transformers import PretrainedConfig
 from vllm.config import CacheConfig, LoRAConfig
 from vllm.distributed import get_tensor_model_parallel_world_size
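
The remaining hunks widen the config annotations in gemma2.py from Gemma2Config to PretrainedConfig. Every transformers config class derives from PretrainedConfig, so existing Gemma2 call sites stay valid while other config objects carrying the same fields (presumably the Llava wrapper config this commit targets) can be passed through as well, as the snippet below illustrates:

from transformers import Gemma2Config, PretrainedConfig

cfg = Gemma2Config()
assert isinstance(cfg, PretrainedConfig)  # existing Gemma2 call sites remain valid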

@@ -131,7 +131,7 @@ class Gemma2Attention(nn.Module):
     def __init__(
         self,
         layer_idx: int,
-        config: Gemma2Config,
+        config: PretrainedConfig,
         hidden_size: int,
         num_heads: int,
         num_kv_heads: int,

@@ -222,7 +222,7 @@ class Gemma2DecoderLayer(nn.Module):
     def __init__(
         self,
         layer_idx: int,
-        config: Gemma2Config,
+        config: PretrainedConfig,
         cache_config: Optional[CacheConfig] = None,
         quant_config: Optional[QuantizationConfig] = None,
     ) -> None:

@@ -290,7 +290,7 @@ class Gemma2Model(nn.Module):

     def __init__(
         self,
-        config: Gemma2Config,
+        config: PretrainedConfig,
         cache_config: Optional[CacheConfig] = None,
         quant_config: Optional[QuantizationConfig] = None,
     ) -> None:

@@ -369,7 +369,7 @@ class Gemma2ForCausalLM(nn.Module):

     def __init__(
         self,
-        config: Gemma2Config,
+        config: PretrainedConfig,
         cache_config: Optional[CacheConfig] = None,
         quant_config: Optional[QuantizationConfig] = None,
         lora_config: Optional[LoRAConfig] = None,