Fix qwen config (#261)
This commit is contained in:
@@ -19,7 +19,7 @@ dependencies = [
|
||||
|
||||
[project.optional-dependencies]
|
||||
srt = ["aiohttp", "fastapi", "psutil", "rpyc", "torch", "uvloop", "uvicorn",
|
||||
"zmq", "vllm>=0.2.5", "interegular", "lark", "numba",
|
||||
"zmq", "vllm>=0.3.3", "interegular", "lark", "numba",
|
||||
"pydantic", "referencing", "diskcache", "cloudpickle", "pillow", "outlines>=0.0.27"]
|
||||
openai = ["openai>=1.0", "numpy"]
|
||||
anthropic = ["anthropic", "numpy"]
|
||||
|
||||
@@ -5,6 +5,7 @@ from sglang.srt.layers.logits_processor import LogitsProcessor
|
||||
from sglang.srt.layers.radix_attention import RadixAttention
|
||||
from sglang.srt.managers.router.model_runner import InputMetadata
|
||||
from torch import nn
|
||||
from transformers import PretrainedConfig
|
||||
from vllm.model_executor.layers.activation import SiluAndMul
|
||||
from vllm.model_executor.layers.layernorm import RMSNorm
|
||||
from vllm.model_executor.layers.linear import (
|
||||
@@ -25,7 +26,6 @@ from vllm.model_executor.weight_utils import (
|
||||
default_weight_loader,
|
||||
hf_model_weights_iterator,
|
||||
)
|
||||
from vllm.transformers_utils.configs.qwen import QWenConfig
|
||||
|
||||
|
||||
class QWenMLP(nn.Module):
|
||||
@@ -130,7 +130,7 @@ class QWenAttention(nn.Module):
|
||||
|
||||
|
||||
class QWenBlock(nn.Module):
|
||||
def __init__(self, config: QWenConfig, layer_id, linear_method=None):
|
||||
def __init__(self, config: PretrainedConfig, layer_id, linear_method=None):
|
||||
super().__init__()
|
||||
self.ln_1 = RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon)
|
||||
|
||||
@@ -179,7 +179,7 @@ class QWenBlock(nn.Module):
|
||||
|
||||
|
||||
class QWenModel(nn.Module):
|
||||
def __init__(self, config: QWenConfig, linear_method=None):
|
||||
def __init__(self, config: PretrainedConfig, linear_method=None):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.vocab_size = config.vocab_size
|
||||
@@ -216,7 +216,7 @@ class QWenModel(nn.Module):
|
||||
|
||||
|
||||
class QWenLMHeadModel(nn.Module):
|
||||
def __init__(self, config: QWenConfig, linear_method=None):
|
||||
def __init__(self, config: PretrainedConfig, linear_method=None):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.transformer = QWenModel(config, linear_method=linear_method)
|
||||
|
||||
Reference in New Issue
Block a user