diff --git a/python/sglang/backend/runtime_endpoint.py b/python/sglang/backend/runtime_endpoint.py
index e0211c3b5..f7b8f7b5d 100644
--- a/python/sglang/backend/runtime_endpoint.py
+++ b/python/sglang/backend/runtime_endpoint.py
@@ -1,15 +1,14 @@
 import json
-from typing import Callable, List, Optional, Union
+from typing import List, Optional
 
 import numpy as np
-import requests
 
 from sglang.backend.base_backend import BaseBackend
 from sglang.global_config import global_config
 from sglang.lang.chat_template import get_chat_template_by_model_path
 from sglang.lang.interpreter import StreamExecutor
-from sglang.lang.ir import SglArgument, SglSamplingParams
-from sglang.utils import encode_image_base64, find_printable_text, http_request
+from sglang.lang.ir import SglSamplingParams
+from sglang.utils import find_printable_text, http_request
 
 
 class RuntimeEndpoint(BaseBackend):
diff --git a/python/sglang/lang/interpreter.py b/python/sglang/lang/interpreter.py
index 4f5bfa3ed..36418a6cc 100644
--- a/python/sglang/lang/interpreter.py
+++ b/python/sglang/lang/interpreter.py
@@ -523,9 +523,9 @@ class StreamExecutor:
                 self, sampling_params=sampling_params
             )
 
+            self.variables[name] = ""
             self.stream_var_event[name].set()
-            self.variables[name] = ""
 
             for comp, meta_info in generator:
                 self.text_ += comp
                 self.variables[name] += comp
diff --git a/python/sglang/srt/managers/controller/infer_batch.py b/python/sglang/srt/managers/controller/infer_batch.py
index 773d6a500..ec4730061 100644
--- a/python/sglang/srt/managers/controller/infer_batch.py
+++ b/python/sglang/srt/managers/controller/infer_batch.py
@@ -3,7 +3,7 @@
 import warnings
 from dataclasses import dataclass
 from enum import IntEnum, auto
-from typing import List
+from typing import List, Union
 
 import numpy as np
 import torch
@@ -31,7 +31,7 @@ class BaseFinishReason:
 
 
 class FINISH_MATCHED_TOKEN(BaseFinishReason):
-    def __init__(self, matched: int | List[int]):
+    def __init__(self, matched: Union[int, List[int]]):
         super().__init__()
         self.matched = matched
diff --git a/python/sglang/srt/model_config.py b/python/sglang/srt/model_config.py
index 715b7fd21..c2cf7d47e 100644
--- a/python/sglang/srt/model_config.py
+++ b/python/sglang/srt/model_config.py
@@ -115,6 +115,12 @@ def get_hf_text_config(config: PretrainedConfig):
     """Get the "sub" config relevant to llm for multi modal models.
     No op for pure text models.
     """
+    class_name = config.architectures[0]
+    if class_name.startswith("Llava") and class_name.endswith("ForCausalLM"):
+        # We support non-hf version of llava models, so we do not want to
+        # read the wrong values from the unused default text_config.
+        return config
+
     if hasattr(config, "text_config"):
         # The code operates under the assumption that text_config should have
         # `num_attention_heads` (among others). Assert here to fail early
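
Note on the interpreter.py hunk: the two-line swap matters because `self.stream_var_event[name]` is a `threading.Event` that a consumer thread waits on before reading `self.variables[name]`. If the event is set before the dict entry is created, the woken consumer can hit a KeyError. A minimal, self-contained sketch of the hazard, using illustrative names rather than SGLang's real classes:

    import threading

    variables = {}
    ready = threading.Event()

    def consumer():
        ready.wait()                 # returns as soon as the event is set
        print(variables["answer"])   # KeyError if set() ran before the write

    t = threading.Thread(target=consumer)
    t.start()

    # Buggy order (pre-patch): ready.set(); variables["answer"] = ""
    # The consumer can wake between those two statements and miss the key.
    # Fixed order (post-patch): write first, then signal.
    variables["answer"] = ""
    ready.set()
    t.join()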
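
Note on the infer_batch.py hunk: this is a compatibility fix, not a behavior change. PEP 604 syntax (`int | List[int]`) is evaluated when the `def` statement runs, so it raises `TypeError: unsupported operand type(s) for |` on Python 3.8 and 3.9 unless `from __future__ import annotations` is active, whereas `typing.Union` spells the same type on all supported interpreters. A small demonstration:

    from typing import List, Union

    class FinishedDemo:
        # Works on Python 3.7+:
        def __init__(self, matched: Union[int, List[int]]):
            self.matched = matched

        # Equivalent, but only valid at runtime on Python 3.10+
        # (or behind `from __future__ import annotations`):
        #     def __init__(self, matched: int | List[int]): ...

    print(FinishedDemo([1, 2]).matched)  # [1, 2]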
diff --git a/python/sglang/srt/models/gemma2.py b/python/sglang/srt/models/gemma2.py
index 2d92b53c9..4593a5731 100644
--- a/python/sglang/srt/models/gemma2.py
+++ b/python/sglang/srt/models/gemma2.py
@@ -1,10 +1,10 @@
 # Adapted from:
 # https://github.com/vllm-project/vllm/blob/56b325e977435af744f8b3dca7af0ca209663558/vllm/model_executor/models/gemma2.py
-from typing import Iterable, List, Optional, Set, Tuple, Union
+from typing import Iterable, Optional, Set, Tuple, Union
 
 import torch
 from torch import nn
-from transformers import Gemma2Config
+from transformers import PretrainedConfig
 
 from vllm.config import CacheConfig, LoRAConfig
 from vllm.distributed import get_tensor_model_parallel_world_size
@@ -131,7 +131,7 @@ class Gemma2Attention(nn.Module):
     def __init__(
         self,
         layer_idx: int,
-        config: Gemma2Config,
+        config: PretrainedConfig,
         hidden_size: int,
         num_heads: int,
         num_kv_heads: int,
@@ -222,7 +222,7 @@ class Gemma2DecoderLayer(nn.Module):
     def __init__(
         self,
         layer_idx: int,
-        config: Gemma2Config,
+        config: PretrainedConfig,
         cache_config: Optional[CacheConfig] = None,
         quant_config: Optional[QuantizationConfig] = None,
     ) -> None:
@@ -290,7 +290,7 @@ class Gemma2Model(nn.Module):
 
     def __init__(
         self,
-        config: Gemma2Config,
+        config: PretrainedConfig,
         cache_config: Optional[CacheConfig] = None,
         quant_config: Optional[QuantizationConfig] = None,
     ) -> None:
@@ -369,7 +369,7 @@ class Gemma2ForCausalLM(nn.Module):
 
     def __init__(
        self,
-        config: Gemma2Config,
+        config: PretrainedConfig,
         cache_config: Optional[CacheConfig] = None,
         quant_config: Optional[QuantizationConfig] = None,
         lora_config: Optional[LoRAConfig] = None,
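
Note on the gemma2.py hunks: `Gemma2Config` exists only in transformers releases that ship Gemma 2 support, so importing it at module load breaks the whole file on older installs even though the name is used purely as a type hint. `PretrainedConfig` has been available for years, and attribute access works by duck typing. If the narrower hint were ever wanted back, a guarded import would keep the module importable; the sketch below is an illustration, not part of this patch (`hidden_size_of` is a hypothetical helper):

    from typing import TYPE_CHECKING

    from transformers import PretrainedConfig

    if TYPE_CHECKING:
        # Static checkers still see the precise type; runtime never imports it.
        from transformers import Gemma2Config

    def hidden_size_of(config: PretrainedConfig) -> int:
        # Duck typing: any text-model config exposes hidden_size.
        return config.hidden_size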