Fix Llava model (#594)
This commit is contained in:
@@ -1,15 +1,14 @@
|
|||||||
import json
|
import json
|
||||||
from typing import Callable, List, Optional, Union
|
from typing import List, Optional
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import requests
|
|
||||||
|
|
||||||
from sglang.backend.base_backend import BaseBackend
|
from sglang.backend.base_backend import BaseBackend
|
||||||
from sglang.global_config import global_config
|
from sglang.global_config import global_config
|
||||||
from sglang.lang.chat_template import get_chat_template_by_model_path
|
from sglang.lang.chat_template import get_chat_template_by_model_path
|
||||||
from sglang.lang.interpreter import StreamExecutor
|
from sglang.lang.interpreter import StreamExecutor
|
||||||
from sglang.lang.ir import SglArgument, SglSamplingParams
|
from sglang.lang.ir import SglSamplingParams
|
||||||
from sglang.utils import encode_image_base64, find_printable_text, http_request
|
from sglang.utils import find_printable_text, http_request
|
||||||
|
|
||||||
|
|
||||||
class RuntimeEndpoint(BaseBackend):
|
class RuntimeEndpoint(BaseBackend):
|
||||||
|
|||||||
@@ -523,9 +523,9 @@ class StreamExecutor:
|
|||||||
self, sampling_params=sampling_params
|
self, sampling_params=sampling_params
|
||||||
)
|
)
|
||||||
|
|
||||||
|
self.variables[name] = ""
|
||||||
self.stream_var_event[name].set()
|
self.stream_var_event[name].set()
|
||||||
|
|
||||||
self.variables[name] = ""
|
|
||||||
for comp, meta_info in generator:
|
for comp, meta_info in generator:
|
||||||
self.text_ += comp
|
self.text_ += comp
|
||||||
self.variables[name] += comp
|
self.variables[name] += comp
|
||||||
|
|||||||
@@ -3,7 +3,7 @@
|
|||||||
import warnings
|
import warnings
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from enum import IntEnum, auto
|
from enum import IntEnum, auto
|
||||||
from typing import List
|
from typing import List, Union
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
@@ -31,7 +31,7 @@ class BaseFinishReason:
|
|||||||
|
|
||||||
|
|
||||||
class FINISH_MATCHED_TOKEN(BaseFinishReason):
|
class FINISH_MATCHED_TOKEN(BaseFinishReason):
|
||||||
def __init__(self, matched: int | List[int]):
|
def __init__(self, matched: Union[int, List[int]]):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.matched = matched
|
self.matched = matched
|
||||||
|
|
||||||
|
|||||||
@@ -115,6 +115,12 @@ def get_hf_text_config(config: PretrainedConfig):
|
|||||||
"""Get the "sub" config relevant to llm for multi modal models.
|
"""Get the "sub" config relevant to llm for multi modal models.
|
||||||
No op for pure text models.
|
No op for pure text models.
|
||||||
"""
|
"""
|
||||||
|
class_name = config.architectures[0]
|
||||||
|
if class_name.startswith("Llava") and class_name.endswith("ForCausalLM"):
|
||||||
|
# We support non-hf version of llava models, so we do not want to
|
||||||
|
# read the wrong values from the unused default text_config.
|
||||||
|
return config
|
||||||
|
|
||||||
if hasattr(config, "text_config"):
|
if hasattr(config, "text_config"):
|
||||||
# The code operates under the assumption that text_config should have
|
# The code operates under the assumption that text_config should have
|
||||||
# `num_attention_heads` (among others). Assert here to fail early
|
# `num_attention_heads` (among others). Assert here to fail early
|
||||||
|
|||||||
@@ -1,10 +1,10 @@
|
|||||||
# Adapted from:
|
# Adapted from:
|
||||||
# https://github.com/vllm-project/vllm/blob/56b325e977435af744f8b3dca7af0ca209663558/vllm/model_executor/models/gemma2.py
|
# https://github.com/vllm-project/vllm/blob/56b325e977435af744f8b3dca7af0ca209663558/vllm/model_executor/models/gemma2.py
|
||||||
from typing import Iterable, List, Optional, Set, Tuple, Union
|
from typing import Iterable, Optional, Set, Tuple, Union
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from torch import nn
|
from torch import nn
|
||||||
from transformers import Gemma2Config
|
from transformers import PretrainedConfig
|
||||||
from vllm.config import CacheConfig, LoRAConfig
|
from vllm.config import CacheConfig, LoRAConfig
|
||||||
from vllm.distributed import get_tensor_model_parallel_world_size
|
from vllm.distributed import get_tensor_model_parallel_world_size
|
||||||
|
|
||||||
@@ -131,7 +131,7 @@ class Gemma2Attention(nn.Module):
|
|||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
layer_idx: int,
|
layer_idx: int,
|
||||||
config: Gemma2Config,
|
config: PretrainedConfig,
|
||||||
hidden_size: int,
|
hidden_size: int,
|
||||||
num_heads: int,
|
num_heads: int,
|
||||||
num_kv_heads: int,
|
num_kv_heads: int,
|
||||||
@@ -222,7 +222,7 @@ class Gemma2DecoderLayer(nn.Module):
|
|||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
layer_idx: int,
|
layer_idx: int,
|
||||||
config: Gemma2Config,
|
config: PretrainedConfig,
|
||||||
cache_config: Optional[CacheConfig] = None,
|
cache_config: Optional[CacheConfig] = None,
|
||||||
quant_config: Optional[QuantizationConfig] = None,
|
quant_config: Optional[QuantizationConfig] = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
@@ -290,7 +290,7 @@ class Gemma2Model(nn.Module):
|
|||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
config: Gemma2Config,
|
config: PretrainedConfig,
|
||||||
cache_config: Optional[CacheConfig] = None,
|
cache_config: Optional[CacheConfig] = None,
|
||||||
quant_config: Optional[QuantizationConfig] = None,
|
quant_config: Optional[QuantizationConfig] = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
@@ -369,7 +369,7 @@ class Gemma2ForCausalLM(nn.Module):
|
|||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
config: Gemma2Config,
|
config: PretrainedConfig,
|
||||||
cache_config: Optional[CacheConfig] = None,
|
cache_config: Optional[CacheConfig] = None,
|
||||||
quant_config: Optional[QuantizationConfig] = None,
|
quant_config: Optional[QuantizationConfig] = None,
|
||||||
lora_config: Optional[LoRAConfig] = None,
|
lora_config: Optional[LoRAConfig] = None,
|
||||||
|
|||||||
Reference in New Issue
Block a user