Compat with latest VLLM 0.4.2 main + fork.number rename + Flashinfer 0.0.4 (#380)
Co-authored-by: ZX <zx@lbx.dev> Co-authored-by: ZhouXingg <165115237+ZhouXingg@users.noreply.github.com>
This commit is contained in:
@@ -20,7 +20,7 @@ dependencies = [
|
|||||||
|
|
||||||
[project.optional-dependencies]
|
[project.optional-dependencies]
|
||||||
srt = ["aiohttp", "fastapi", "psutil", "rpyc", "torch", "uvloop", "uvicorn",
|
srt = ["aiohttp", "fastapi", "psutil", "rpyc", "torch", "uvloop", "uvicorn",
|
||||||
"zmq", "vllm>=0.3.3,<=0.4.0", "interegular", "pydantic", "pillow", "outlines>=0.0.27"]
|
"zmq", "vllm>=0.4.2", "interegular", "pydantic", "pillow", "outlines>=0.0.27", "flashinfer>=0.0.4", "packaging"]
|
||||||
openai = ["openai>=1.0", "numpy", "tiktoken"]
|
openai = ["openai>=1.0", "numpy", "tiktoken"]
|
||||||
anthropic = ["anthropic>=0.20.0", "numpy"]
|
anthropic = ["anthropic>=0.20.0", "numpy"]
|
||||||
all = ["sglang[srt]", "sglang[openai]", "sglang[anthropic]"]
|
all = ["sglang[srt]", "sglang[openai]", "sglang[anthropic]"]
|
||||||
|
|||||||
@@ -266,14 +266,14 @@ class StreamExecutor:
|
|||||||
|
|
||||||
def fork(
|
def fork(
|
||||||
self,
|
self,
|
||||||
number: int,
|
size: int = 1,
|
||||||
position_ids_offset: Optional[List[int]] = None,
|
position_ids_offset: Optional[List[int]] = None,
|
||||||
):
|
):
|
||||||
if number > 1:
|
if size > 1:
|
||||||
self.submit(SglCommitLazy())
|
self.submit(SglCommitLazy())
|
||||||
|
|
||||||
self.sync()
|
self.sync()
|
||||||
number = int(number)
|
size = int(size)
|
||||||
|
|
||||||
exes = [
|
exes = [
|
||||||
StreamExecutor(
|
StreamExecutor(
|
||||||
@@ -283,9 +283,9 @@ class StreamExecutor:
|
|||||||
self.chat_template,
|
self.chat_template,
|
||||||
self.stream,
|
self.stream,
|
||||||
)
|
)
|
||||||
for _ in range(number)
|
for _ in range(size)
|
||||||
]
|
]
|
||||||
for i in range(number):
|
for i in range(size):
|
||||||
exes[i].variables = dict(self.variables)
|
exes[i].variables = dict(self.variables)
|
||||||
exes[i].text_ = str(self.text_)
|
exes[i].text_ = str(self.text_)
|
||||||
exes[i].messages_ = list(self.messages_)
|
exes[i].messages_ = list(self.messages_)
|
||||||
@@ -656,10 +656,10 @@ class ProgramState:
|
|||||||
|
|
||||||
def fork(
|
def fork(
|
||||||
self,
|
self,
|
||||||
number: int = 1,
|
size: int = 1,
|
||||||
position_ids_offset: Optional[List[int]] = None,
|
position_ids_offset: Optional[List[int]] = None,
|
||||||
):
|
):
|
||||||
stream_executors = self.stream_executor.fork(number, position_ids_offset)
|
stream_executors = self.stream_executor.fork(size, position_ids_offset)
|
||||||
states = [ProgramState(x) for x in stream_executors]
|
states = [ProgramState(x) for x in stream_executors]
|
||||||
state_group = ProgramStateGroup(states, self)
|
state_group = ProgramStateGroup(states, self)
|
||||||
return state_group
|
return state_group
|
||||||
|
|||||||
@@ -109,19 +109,21 @@ class TracerProgramState(ProgramState):
|
|||||||
########### Public API ###########
|
########### Public API ###########
|
||||||
##################################
|
##################################
|
||||||
|
|
||||||
def fork(self, number: int, position_ids_offset: Optional[List[int]] = None):
|
def fork(self, size: int = 1, position_ids_offset: Optional[List[int]] = None):
|
||||||
|
assert (size >= 1)
|
||||||
|
|
||||||
if self.only_trace_prefix:
|
if self.only_trace_prefix:
|
||||||
raise StopTracing()
|
raise StopTracing()
|
||||||
|
|
||||||
fork_node = SglFork(number)
|
fork_node = SglFork(size)
|
||||||
fork_node.prev_node = self.last_node
|
fork_node.prev_node = self.last_node
|
||||||
|
|
||||||
states = [
|
states = [
|
||||||
TracerProgramState(self.backend, self.arguments, self.only_trace_prefix)
|
TracerProgramState(self.backend, self.arguments, self.only_trace_prefix)
|
||||||
for _ in range(number)
|
for _ in range(size)
|
||||||
]
|
]
|
||||||
|
|
||||||
for i in range(number):
|
for i in range(size):
|
||||||
node = SglGetForkItem(i)
|
node = SglGetForkItem(i)
|
||||||
node.prev_node = fork_node
|
node.prev_node = fork_node
|
||||||
states[i].last_node = node
|
states[i].last_node = node
|
||||||
|
|||||||
@@ -1,6 +1,6 @@
|
|||||||
import torch
|
import torch
|
||||||
from torch import nn
|
from torch import nn
|
||||||
from vllm.model_executor.parallel_utils.communication_op import (
|
from vllm.distributed import (
|
||||||
get_tensor_model_parallel_world_size,
|
get_tensor_model_parallel_world_size,
|
||||||
tensor_model_parallel_all_gather,
|
tensor_model_parallel_all_gather,
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -10,7 +10,10 @@ import rpyc
|
|||||||
import torch
|
import torch
|
||||||
from rpyc.utils.classic import obtain
|
from rpyc.utils.classic import obtain
|
||||||
from rpyc.utils.server import ThreadedServer
|
from rpyc.utils.server import ThreadedServer
|
||||||
from vllm.logger import _default_handler as vllm_default_handler
|
try:
|
||||||
|
from vllm.logger import _default_handler as vllm_default_logger
|
||||||
|
except ImportError:
|
||||||
|
from vllm.logger import logger as vllm_default_logger
|
||||||
|
|
||||||
from sglang.srt.constrained.fsm_cache import FSMCache
|
from sglang.srt.constrained.fsm_cache import FSMCache
|
||||||
from sglang.srt.constrained.jump_forward import JumpForwardCache
|
from sglang.srt.constrained.jump_forward import JumpForwardCache
|
||||||
@@ -50,7 +53,7 @@ class ModelRpcServer:
|
|||||||
self.tp_size = server_args.tp_size
|
self.tp_size = server_args.tp_size
|
||||||
self.schedule_heuristic = server_args.schedule_heuristic
|
self.schedule_heuristic = server_args.schedule_heuristic
|
||||||
self.disable_regex_jump_forward = server_args.disable_regex_jump_forward
|
self.disable_regex_jump_forward = server_args.disable_regex_jump_forward
|
||||||
vllm_default_handler.setLevel(
|
vllm_default_logger.setLevel(
|
||||||
level=getattr(logging, server_args.log_level.upper())
|
level=getattr(logging, server_args.log_level.upper())
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@@ -12,8 +12,8 @@ import torch
|
|||||||
from vllm.model_executor.layers.quantization.awq import AWQConfig
|
from vllm.model_executor.layers.quantization.awq import AWQConfig
|
||||||
from vllm.model_executor.layers.quantization.gptq import GPTQConfig
|
from vllm.model_executor.layers.quantization.gptq import GPTQConfig
|
||||||
from vllm.model_executor.layers.quantization.marlin import MarlinConfig
|
from vllm.model_executor.layers.quantization.marlin import MarlinConfig
|
||||||
from vllm.model_executor.model_loader import _set_default_torch_dtype
|
from vllm.model_executor.model_loader.utils import set_default_torch_dtype
|
||||||
from vllm.model_executor.parallel_utils.parallel_state import initialize_model_parallel
|
from vllm.distributed import initialize_model_parallel
|
||||||
|
|
||||||
from sglang.srt.managers.router.infer_batch import Batch, ForwardMode
|
from sglang.srt.managers.router.infer_batch import Batch, ForwardMode
|
||||||
from sglang.srt.memory_pool import ReqToTokenPool, TokenToKVPool
|
from sglang.srt.memory_pool import ReqToTokenPool, TokenToKVPool
|
||||||
@@ -142,16 +142,9 @@ class InputMetadata:
|
|||||||
self.kv_last_page_len,
|
self.kv_last_page_len,
|
||||||
self.model_runner.model_config.num_attention_heads // tp_size,
|
self.model_runner.model_config.num_attention_heads // tp_size,
|
||||||
self.model_runner.model_config.num_key_value_heads // tp_size,
|
self.model_runner.model_config.num_key_value_heads // tp_size,
|
||||||
|
self.model_runner.model_config.head_dim
|
||||||
]
|
]
|
||||||
|
|
||||||
# flashinfer >= 0.0.3
|
|
||||||
# FIXME: Drop this when flashinfer updates to 0.0.4
|
|
||||||
if (
|
|
||||||
len(inspect.signature(self.prefill_wrapper.begin_forward).parameters)
|
|
||||||
== 7
|
|
||||||
):
|
|
||||||
args.append(self.model_runner.model_config.head_dim)
|
|
||||||
|
|
||||||
self.prefill_wrapper.begin_forward(*args)
|
self.prefill_wrapper.begin_forward(*args)
|
||||||
else:
|
else:
|
||||||
self.decode_wrapper = BatchDecodeWithPagedKVCacheWrapper(
|
self.decode_wrapper = BatchDecodeWithPagedKVCacheWrapper(
|
||||||
@@ -304,7 +297,7 @@ class ModelRunner:
|
|||||||
logger.info(f"Rank {self.tp_rank}: load weight begin.")
|
logger.info(f"Rank {self.tp_rank}: load weight begin.")
|
||||||
|
|
||||||
# Load weights
|
# Load weights
|
||||||
linear_method = None
|
quant_config = None
|
||||||
|
|
||||||
quant_cfg = getattr(self.model_config.hf_config, "quantization_config", None)
|
quant_cfg = getattr(self.model_config.hf_config, "quantization_config", None)
|
||||||
if quant_cfg is not None:
|
if quant_cfg is not None:
|
||||||
@@ -326,12 +319,11 @@ class ModelRunner:
|
|||||||
|
|
||||||
quant_config = quant_config_class.from_config(quant_cfg)
|
quant_config = quant_config_class.from_config(quant_cfg)
|
||||||
logger.info(f"quant_config: {quant_config}")
|
logger.info(f"quant_config: {quant_config}")
|
||||||
linear_method = quant_config.get_linear_method()
|
|
||||||
|
|
||||||
with _set_default_torch_dtype(torch.float16):
|
with set_default_torch_dtype(torch.float16):
|
||||||
with torch.device("cuda"):
|
with torch.device("cuda"):
|
||||||
model = model_class(
|
model = model_class(
|
||||||
config=self.model_config.hf_config, linear_method=linear_method
|
config=self.model_config.hf_config, quant_config=quant_config
|
||||||
)
|
)
|
||||||
model.load_weights(
|
model.load_weights(
|
||||||
self.model_config.path,
|
self.model_config.path,
|
||||||
|
|||||||
@@ -20,7 +20,7 @@
|
|||||||
|
|
||||||
# This file is based on the LLama model definition file in transformers
|
# This file is based on the LLama model definition file in transformers
|
||||||
"""PyTorch Cohere model."""
|
"""PyTorch Cohere model."""
|
||||||
from typing import List, Optional, Tuple
|
from typing import Optional, Tuple
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.utils.checkpoint
|
import torch.utils.checkpoint
|
||||||
@@ -29,19 +29,20 @@ from torch.nn.parameter import Parameter
|
|||||||
from transformers import PretrainedConfig
|
from transformers import PretrainedConfig
|
||||||
from vllm.model_executor.layers.activation import SiluAndMul
|
from vllm.model_executor.layers.activation import SiluAndMul
|
||||||
from vllm.model_executor.layers.linear import (
|
from vllm.model_executor.layers.linear import (
|
||||||
LinearMethodBase,
|
|
||||||
MergedColumnParallelLinear,
|
MergedColumnParallelLinear,
|
||||||
QKVParallelLinear,
|
QKVParallelLinear,
|
||||||
RowParallelLinear,
|
RowParallelLinear,
|
||||||
)
|
)
|
||||||
|
from vllm.model_executor.layers.quantization.base_config import (
|
||||||
|
QuantizationConfig)
|
||||||
from vllm.model_executor.layers.rotary_embedding import get_rope
|
from vllm.model_executor.layers.rotary_embedding import get_rope
|
||||||
from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding
|
from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding
|
||||||
from vllm.model_executor.parallel_utils.parallel_state import (
|
from vllm.distributed import (
|
||||||
get_tensor_model_parallel_rank,
|
get_tensor_model_parallel_rank,
|
||||||
get_tensor_model_parallel_world_size,
|
get_tensor_model_parallel_world_size,
|
||||||
)
|
)
|
||||||
from vllm.model_executor.utils import set_weight_attrs
|
from vllm.model_executor.utils import set_weight_attrs
|
||||||
from vllm.model_executor.weight_utils import (
|
from sglang.srt.weight_utils import (
|
||||||
default_weight_loader,
|
default_weight_loader,
|
||||||
hf_model_weights_iterator,
|
hf_model_weights_iterator,
|
||||||
)
|
)
|
||||||
@@ -92,7 +93,7 @@ class CohereMLP(nn.Module):
|
|||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
config,
|
config,
|
||||||
linear_method: Optional[LinearMethodBase] = None,
|
quant_config: Optional[QuantizationConfig] = None,
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.config = config
|
self.config = config
|
||||||
@@ -102,13 +103,13 @@ class CohereMLP(nn.Module):
|
|||||||
self.hidden_size,
|
self.hidden_size,
|
||||||
[self.intermediate_size] * 2,
|
[self.intermediate_size] * 2,
|
||||||
bias=False,
|
bias=False,
|
||||||
linear_method=linear_method,
|
quant_config=quant_config,
|
||||||
)
|
)
|
||||||
self.down_proj = RowParallelLinear(
|
self.down_proj = RowParallelLinear(
|
||||||
self.intermediate_size,
|
self.intermediate_size,
|
||||||
self.hidden_size,
|
self.hidden_size,
|
||||||
bias=False,
|
bias=False,
|
||||||
linear_method=linear_method,
|
quant_config=quant_config,
|
||||||
)
|
)
|
||||||
self.act_fn = SiluAndMul()
|
self.act_fn = SiluAndMul()
|
||||||
|
|
||||||
@@ -124,7 +125,7 @@ class CohereAttention(nn.Module):
|
|||||||
self,
|
self,
|
||||||
config: PretrainedConfig,
|
config: PretrainedConfig,
|
||||||
layer_id: int = 0,
|
layer_id: int = 0,
|
||||||
linear_method: Optional[LinearMethodBase] = None,
|
quant_config: Optional[QuantizationConfig] = None,
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
tp_size = get_tensor_model_parallel_world_size()
|
tp_size = get_tensor_model_parallel_world_size()
|
||||||
@@ -159,13 +160,13 @@ class CohereAttention(nn.Module):
|
|||||||
self.total_num_heads,
|
self.total_num_heads,
|
||||||
self.total_num_kv_heads,
|
self.total_num_kv_heads,
|
||||||
bias=False,
|
bias=False,
|
||||||
linear_method=linear_method,
|
quant_config=quant_config,
|
||||||
)
|
)
|
||||||
self.o_proj = RowParallelLinear(
|
self.o_proj = RowParallelLinear(
|
||||||
self.total_num_heads * self.head_dim,
|
self.total_num_heads * self.head_dim,
|
||||||
self.hidden_size,
|
self.hidden_size,
|
||||||
bias=False,
|
bias=False,
|
||||||
linear_method=linear_method,
|
quant_config=quant_config,
|
||||||
)
|
)
|
||||||
self.rotary_emb = get_rope(
|
self.rotary_emb = get_rope(
|
||||||
self.head_dim,
|
self.head_dim,
|
||||||
@@ -221,16 +222,16 @@ class CohereDecoderLayer(nn.Module):
|
|||||||
self,
|
self,
|
||||||
config: PretrainedConfig,
|
config: PretrainedConfig,
|
||||||
layer_id: int = 0,
|
layer_id: int = 0,
|
||||||
linear_method: Optional[LinearMethodBase] = None,
|
quant_config: Optional[QuantizationConfig] = None,
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.hidden_size = config.hidden_size
|
self.hidden_size = config.hidden_size
|
||||||
|
|
||||||
self.self_attn = CohereAttention(
|
self.self_attn = CohereAttention(
|
||||||
config, layer_id=layer_id, linear_method=linear_method
|
config, layer_id=layer_id, quant_config=quant_config
|
||||||
)
|
)
|
||||||
|
|
||||||
self.mlp = CohereMLP(config, linear_method=linear_method)
|
self.mlp = CohereMLP(config, quant_config=quant_config)
|
||||||
self.input_layernorm = LayerNorm(
|
self.input_layernorm = LayerNorm(
|
||||||
param_shape=(config.hidden_size), eps=config.layer_norm_eps
|
param_shape=(config.hidden_size), eps=config.layer_norm_eps
|
||||||
)
|
)
|
||||||
@@ -261,7 +262,7 @@ class CohereModel(nn.Module):
|
|||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
config: PretrainedConfig,
|
config: PretrainedConfig,
|
||||||
linear_method: Optional[LinearMethodBase] = None,
|
quant_config: Optional[QuantizationConfig] = None,
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.config = config
|
self.config = config
|
||||||
@@ -271,7 +272,7 @@ class CohereModel(nn.Module):
|
|||||||
)
|
)
|
||||||
self.layers = nn.ModuleList(
|
self.layers = nn.ModuleList(
|
||||||
[
|
[
|
||||||
CohereDecoderLayer(config, i, linear_method=linear_method)
|
CohereDecoderLayer(config, i, quant_config=quant_config)
|
||||||
for i in range(config.num_hidden_layers)
|
for i in range(config.num_hidden_layers)
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
@@ -303,13 +304,13 @@ class CohereForCausalLM(nn.Module):
|
|||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
config: PretrainedConfig,
|
config: PretrainedConfig,
|
||||||
linear_method: Optional[LinearMethodBase] = None,
|
quant_config: Optional[QuantizationConfig] = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.config = config
|
self.config = config
|
||||||
self.linear_method = linear_method
|
self.quant_config = quant_config
|
||||||
self.logits_processor = LogitsProcessor(config)
|
self.logits_processor = LogitsProcessor(config)
|
||||||
self.model = CohereModel(config, linear_method)
|
self.model = CohereModel(config, quant_config)
|
||||||
|
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
def forward(
|
def forward(
|
||||||
|
|||||||
@@ -7,26 +7,27 @@ import torch
|
|||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
from vllm.model_executor.layers.fused_moe import fused_moe
|
from vllm.model_executor.layers.fused_moe import fused_moe
|
||||||
from vllm.model_executor.layers.linear import (
|
from vllm.model_executor.layers.linear import (
|
||||||
LinearMethodBase,
|
|
||||||
QKVParallelLinear,
|
QKVParallelLinear,
|
||||||
ReplicatedLinear,
|
ReplicatedLinear,
|
||||||
RowParallelLinear,
|
RowParallelLinear,
|
||||||
)
|
)
|
||||||
|
from vllm.model_executor.layers.quantization.base_config import (
|
||||||
|
QuantizationConfig)
|
||||||
from vllm.model_executor.layers.rotary_embedding import get_rope
|
from vllm.model_executor.layers.rotary_embedding import get_rope
|
||||||
from vllm.model_executor.layers.vocab_parallel_embedding import (
|
from vllm.model_executor.layers.vocab_parallel_embedding import (
|
||||||
DEFAULT_VOCAB_PADDING_SIZE,
|
DEFAULT_VOCAB_PADDING_SIZE,
|
||||||
ParallelLMHead,
|
ParallelLMHead,
|
||||||
VocabParallelEmbedding,
|
VocabParallelEmbedding,
|
||||||
)
|
)
|
||||||
from vllm.model_executor.parallel_utils.communication_op import (
|
from vllm.distributed import (
|
||||||
tensor_model_parallel_all_reduce,
|
tensor_model_parallel_all_reduce,
|
||||||
)
|
)
|
||||||
from vllm.model_executor.parallel_utils.parallel_state import (
|
from vllm.distributed import (
|
||||||
get_tensor_model_parallel_rank,
|
get_tensor_model_parallel_rank,
|
||||||
get_tensor_model_parallel_world_size,
|
get_tensor_model_parallel_world_size,
|
||||||
)
|
)
|
||||||
from vllm.model_executor.utils import set_weight_attrs
|
from vllm.model_executor.utils import set_weight_attrs
|
||||||
from vllm.model_executor.weight_utils import (
|
from sglang.srt.weight_utils import (
|
||||||
default_weight_loader,
|
default_weight_loader,
|
||||||
hf_model_weights_iterator,
|
hf_model_weights_iterator,
|
||||||
)
|
)
|
||||||
@@ -56,7 +57,7 @@ class DbrxRouter(nn.Module):
|
|||||||
self.num_total_experts,
|
self.num_total_experts,
|
||||||
bias=False,
|
bias=False,
|
||||||
params_dtype=params_dtype,
|
params_dtype=params_dtype,
|
||||||
linear_method=None,
|
quant_config=None,
|
||||||
)
|
)
|
||||||
|
|
||||||
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||||
@@ -75,7 +76,7 @@ class DbrxExperts(nn.Module):
|
|||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
config: DbrxConfig,
|
config: DbrxConfig,
|
||||||
linear_method: Optional[LinearMethodBase] = None,
|
quant_config: Optional[QuantizationConfig] = None,
|
||||||
params_dtype: Optional[torch.dtype] = None,
|
params_dtype: Optional[torch.dtype] = None,
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
@@ -176,7 +177,7 @@ class DbrxAttention(nn.Module):
|
|||||||
self,
|
self,
|
||||||
config: DbrxConfig,
|
config: DbrxConfig,
|
||||||
layer_id: int = 0,
|
layer_id: int = 0,
|
||||||
linear_method: Optional[LinearMethodBase] = None,
|
quant_config: Optional[QuantizationConfig] = None,
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.d_model = config.d_model
|
self.d_model = config.d_model
|
||||||
@@ -194,13 +195,13 @@ class DbrxAttention(nn.Module):
|
|||||||
self.total_num_heads,
|
self.total_num_heads,
|
||||||
self.total_num_kv_heads,
|
self.total_num_kv_heads,
|
||||||
bias=False,
|
bias=False,
|
||||||
linear_method=linear_method,
|
quant_config=quant_config,
|
||||||
)
|
)
|
||||||
self.out_proj = RowParallelLinear(
|
self.out_proj = RowParallelLinear(
|
||||||
self.d_model,
|
self.d_model,
|
||||||
self.d_model,
|
self.d_model,
|
||||||
bias=False,
|
bias=False,
|
||||||
linear_method=linear_method,
|
quant_config=quant_config,
|
||||||
)
|
)
|
||||||
self.rotary_emb = get_rope(
|
self.rotary_emb = get_rope(
|
||||||
self.head_dim,
|
self.head_dim,
|
||||||
@@ -255,11 +256,11 @@ class DbrxFusedNormAttention(nn.Module):
|
|||||||
self,
|
self,
|
||||||
config: DbrxConfig,
|
config: DbrxConfig,
|
||||||
layer_id: int = 0,
|
layer_id: int = 0,
|
||||||
linear_method: Optional[LinearMethodBase] = None,
|
quant_config: Optional[QuantizationConfig] = None,
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.d_model = config.d_model
|
self.d_model = config.d_model
|
||||||
self.attn = DbrxAttention(config, layer_id, linear_method)
|
self.attn = DbrxAttention(config, layer_id, quant_config=quant_config)
|
||||||
self.norm_1 = nn.LayerNorm(self.d_model)
|
self.norm_1 = nn.LayerNorm(self.d_model)
|
||||||
self.norm_2 = nn.LayerNorm(self.d_model)
|
self.norm_2 = nn.LayerNorm(self.d_model)
|
||||||
|
|
||||||
@@ -287,11 +288,11 @@ class DbrxBlock(nn.Module):
|
|||||||
self,
|
self,
|
||||||
config: DbrxConfig,
|
config: DbrxConfig,
|
||||||
layer_id: int = 0,
|
layer_id: int = 0,
|
||||||
linear_method: Optional[LinearMethodBase] = None,
|
quant_config: Optional[QuantizationConfig] = None,
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.norm_attn_norm = DbrxFusedNormAttention(config, layer_id, linear_method)
|
self.norm_attn_norm = DbrxFusedNormAttention(config, layer_id, quant_config=quant_config)
|
||||||
self.ffn = DbrxExperts(config, linear_method)
|
self.ffn = DbrxExperts(config, quant_config=quant_config)
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
@@ -313,7 +314,7 @@ class DbrxModel(nn.Module):
|
|||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
config: DbrxConfig,
|
config: DbrxConfig,
|
||||||
linear_method: Optional[LinearMethodBase] = None,
|
quant_config: Optional[QuantizationConfig] = None,
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.wte = VocabParallelEmbedding(
|
self.wte = VocabParallelEmbedding(
|
||||||
@@ -321,7 +322,7 @@ class DbrxModel(nn.Module):
|
|||||||
config.d_model,
|
config.d_model,
|
||||||
)
|
)
|
||||||
self.blocks = nn.ModuleList(
|
self.blocks = nn.ModuleList(
|
||||||
[DbrxBlock(config, i, linear_method) for i in range(config.n_layers)]
|
[DbrxBlock(config, i, quant_config=quant_config) for i in range(config.n_layers)]
|
||||||
)
|
)
|
||||||
self.norm_f = nn.LayerNorm(config.d_model, eps=1e-5)
|
self.norm_f = nn.LayerNorm(config.d_model, eps=1e-5)
|
||||||
for module in self.modules():
|
for module in self.modules():
|
||||||
@@ -351,13 +352,13 @@ class DbrxForCausalLM(nn.Module):
|
|||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
config: DbrxConfig,
|
config: DbrxConfig,
|
||||||
linear_method: Optional[LinearMethodBase] = None,
|
quant_config: Optional[QuantizationConfig] = None,
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.config = config
|
self.config = config
|
||||||
self.linear_method = linear_method
|
self.quant_config = quant_config
|
||||||
self.unpadded_vocab_size = config.vocab_size
|
self.unpadded_vocab_size = config.vocab_size
|
||||||
self.transformer = DbrxModel(config, linear_method)
|
self.transformer = DbrxModel(config, quant_config=quant_config)
|
||||||
self.lm_head = ParallelLMHead(
|
self.lm_head = ParallelLMHead(
|
||||||
config.vocab_size,
|
config.vocab_size,
|
||||||
config.d_model,
|
config.d_model,
|
||||||
|
|||||||
@@ -10,17 +10,18 @@ from vllm.config import LoRAConfig
|
|||||||
from vllm.model_executor.layers.activation import GeluAndMul
|
from vllm.model_executor.layers.activation import GeluAndMul
|
||||||
from vllm.model_executor.layers.layernorm import RMSNorm
|
from vllm.model_executor.layers.layernorm import RMSNorm
|
||||||
from vllm.model_executor.layers.linear import (
|
from vllm.model_executor.layers.linear import (
|
||||||
LinearMethodBase,
|
|
||||||
MergedColumnParallelLinear,
|
MergedColumnParallelLinear,
|
||||||
QKVParallelLinear,
|
QKVParallelLinear,
|
||||||
RowParallelLinear,
|
RowParallelLinear,
|
||||||
)
|
)
|
||||||
|
from vllm.model_executor.layers.quantization.base_config import (
|
||||||
|
QuantizationConfig)
|
||||||
from vllm.model_executor.layers.rotary_embedding import get_rope
|
from vllm.model_executor.layers.rotary_embedding import get_rope
|
||||||
from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding
|
from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding
|
||||||
from vllm.model_executor.parallel_utils.parallel_state import (
|
from vllm.distributed import (
|
||||||
get_tensor_model_parallel_world_size,
|
get_tensor_model_parallel_world_size,
|
||||||
)
|
)
|
||||||
from vllm.model_executor.weight_utils import (
|
from sglang.srt.weight_utils import (
|
||||||
default_weight_loader,
|
default_weight_loader,
|
||||||
hf_model_weights_iterator,
|
hf_model_weights_iterator,
|
||||||
)
|
)
|
||||||
@@ -35,17 +36,17 @@ class GemmaMLP(nn.Module):
|
|||||||
self,
|
self,
|
||||||
hidden_size: int,
|
hidden_size: int,
|
||||||
intermediate_size: int,
|
intermediate_size: int,
|
||||||
linear_method: Optional[LinearMethodBase] = None,
|
quant_config: Optional[QuantizationConfig] = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.gate_up_proj = MergedColumnParallelLinear(
|
self.gate_up_proj = MergedColumnParallelLinear(
|
||||||
hidden_size,
|
hidden_size,
|
||||||
[intermediate_size] * 2,
|
[intermediate_size] * 2,
|
||||||
bias=False,
|
bias=False,
|
||||||
linear_method=linear_method,
|
quant_config=quant_config,
|
||||||
)
|
)
|
||||||
self.down_proj = RowParallelLinear(
|
self.down_proj = RowParallelLinear(
|
||||||
intermediate_size, hidden_size, bias=False, linear_method=linear_method
|
intermediate_size, hidden_size, bias=False, quant_config=quant_config,
|
||||||
)
|
)
|
||||||
self.act_fn = GeluAndMul()
|
self.act_fn = GeluAndMul()
|
||||||
|
|
||||||
@@ -66,7 +67,7 @@ class GemmaAttention(nn.Module):
|
|||||||
layer_id: int = 0,
|
layer_id: int = 0,
|
||||||
max_position_embeddings: int = 8192,
|
max_position_embeddings: int = 8192,
|
||||||
rope_theta: float = 10000,
|
rope_theta: float = 10000,
|
||||||
linear_method: Optional[LinearMethodBase] = None,
|
quant_config: Optional[QuantizationConfig] = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.hidden_size = hidden_size
|
self.hidden_size = hidden_size
|
||||||
@@ -96,13 +97,13 @@ class GemmaAttention(nn.Module):
|
|||||||
self.total_num_heads,
|
self.total_num_heads,
|
||||||
self.total_num_kv_heads,
|
self.total_num_kv_heads,
|
||||||
bias=False,
|
bias=False,
|
||||||
linear_method=linear_method,
|
quant_config=quant_config,
|
||||||
)
|
)
|
||||||
self.o_proj = RowParallelLinear(
|
self.o_proj = RowParallelLinear(
|
||||||
self.total_num_heads * self.head_dim,
|
self.total_num_heads * self.head_dim,
|
||||||
hidden_size,
|
hidden_size,
|
||||||
bias=False,
|
bias=False,
|
||||||
linear_method=linear_method,
|
quant_config=quant_config,
|
||||||
)
|
)
|
||||||
|
|
||||||
self.rotary_emb = get_rope(
|
self.rotary_emb = get_rope(
|
||||||
@@ -139,7 +140,7 @@ class GemmaDecoderLayer(nn.Module):
|
|||||||
self,
|
self,
|
||||||
config: PretrainedConfig,
|
config: PretrainedConfig,
|
||||||
layer_id: int = 0,
|
layer_id: int = 0,
|
||||||
linear_method: Optional[LinearMethodBase] = None,
|
quant_config: Optional[QuantizationConfig] = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.hidden_size = config.hidden_size
|
self.hidden_size = config.hidden_size
|
||||||
@@ -151,12 +152,12 @@ class GemmaDecoderLayer(nn.Module):
|
|||||||
layer_id=layer_id,
|
layer_id=layer_id,
|
||||||
max_position_embeddings=config.max_position_embeddings,
|
max_position_embeddings=config.max_position_embeddings,
|
||||||
rope_theta=config.rope_theta,
|
rope_theta=config.rope_theta,
|
||||||
linear_method=linear_method,
|
quant_config=quant_config,
|
||||||
)
|
)
|
||||||
self.mlp = GemmaMLP(
|
self.mlp = GemmaMLP(
|
||||||
hidden_size=self.hidden_size,
|
hidden_size=self.hidden_size,
|
||||||
intermediate_size=config.intermediate_size,
|
intermediate_size=config.intermediate_size,
|
||||||
linear_method=linear_method,
|
quant_config=quant_config,
|
||||||
)
|
)
|
||||||
self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||||
self.post_attention_layernorm = RMSNorm(
|
self.post_attention_layernorm = RMSNorm(
|
||||||
@@ -192,7 +193,7 @@ class GemmaModel(nn.Module):
|
|||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
config: PretrainedConfig,
|
config: PretrainedConfig,
|
||||||
linear_method: Optional[LinearMethodBase] = None,
|
quant_config: Optional[QuantizationConfig] = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.config = config
|
self.config = config
|
||||||
@@ -203,7 +204,7 @@ class GemmaModel(nn.Module):
|
|||||||
)
|
)
|
||||||
self.layers = nn.ModuleList(
|
self.layers = nn.ModuleList(
|
||||||
[
|
[
|
||||||
GemmaDecoderLayer(config, i, linear_method)
|
GemmaDecoderLayer(config, i, quant_config=quant_config)
|
||||||
for i in range(config.num_hidden_layers)
|
for i in range(config.num_hidden_layers)
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
@@ -264,14 +265,14 @@ class GemmaForCausalLM(nn.Module):
|
|||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
config: PretrainedConfig,
|
config: PretrainedConfig,
|
||||||
linear_method: Optional[LinearMethodBase] = None,
|
quant_config: Optional[QuantizationConfig] = None,
|
||||||
lora_config: Optional[LoRAConfig] = None,
|
lora_config: Optional[LoRAConfig] = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
del lora_config # Unused.
|
del lora_config # Unused.
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.config = config
|
self.config = config
|
||||||
self.linear_method = linear_method
|
self.quant_config = quant_config
|
||||||
self.model = GemmaModel(config, linear_method)
|
self.model = GemmaModel(config, quant_config=quant_config)
|
||||||
self.logits_processor = LogitsProcessor(config)
|
self.logits_processor = LogitsProcessor(config)
|
||||||
|
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
|
|||||||
@@ -1,7 +1,7 @@
|
|||||||
# Adapted from
|
# Adapted from
|
||||||
# https://github.com/vllm-project/vllm/blob/671af2b1c0b3ed6d856d37c21a561cc429a10701/vllm/model_executor/models/llama.py#L1
|
# https://github.com/vllm-project/vllm/blob/671af2b1c0b3ed6d856d37c21a561cc429a10701/vllm/model_executor/models/llama.py#L1
|
||||||
"""Inference-only LLaMA model compatible with HuggingFace weights."""
|
"""Inference-only LLaMA model compatible with HuggingFace weights."""
|
||||||
from typing import Any, Dict, List, Optional, Tuple
|
from typing import Any, Dict, Optional, Tuple
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from torch import nn
|
from torch import nn
|
||||||
@@ -9,20 +9,21 @@ from transformers import LlamaConfig
|
|||||||
from vllm.model_executor.layers.activation import SiluAndMul
|
from vllm.model_executor.layers.activation import SiluAndMul
|
||||||
from vllm.model_executor.layers.layernorm import RMSNorm
|
from vllm.model_executor.layers.layernorm import RMSNorm
|
||||||
from vllm.model_executor.layers.linear import (
|
from vllm.model_executor.layers.linear import (
|
||||||
LinearMethodBase,
|
|
||||||
MergedColumnParallelLinear,
|
MergedColumnParallelLinear,
|
||||||
QKVParallelLinear,
|
QKVParallelLinear,
|
||||||
RowParallelLinear,
|
RowParallelLinear,
|
||||||
)
|
)
|
||||||
|
from vllm.model_executor.layers.quantization.base_config import (
|
||||||
|
QuantizationConfig)
|
||||||
from vllm.model_executor.layers.rotary_embedding import get_rope
|
from vllm.model_executor.layers.rotary_embedding import get_rope
|
||||||
from vllm.model_executor.layers.vocab_parallel_embedding import (
|
from vllm.model_executor.layers.vocab_parallel_embedding import (
|
||||||
ParallelLMHead,
|
ParallelLMHead,
|
||||||
VocabParallelEmbedding,
|
VocabParallelEmbedding,
|
||||||
)
|
)
|
||||||
from vllm.model_executor.parallel_utils.parallel_state import (
|
from vllm.distributed import (
|
||||||
get_tensor_model_parallel_world_size,
|
get_tensor_model_parallel_world_size,
|
||||||
)
|
)
|
||||||
from vllm.model_executor.weight_utils import (
|
from sglang.srt.weight_utils import (
|
||||||
default_weight_loader,
|
default_weight_loader,
|
||||||
hf_model_weights_iterator,
|
hf_model_weights_iterator,
|
||||||
)
|
)
|
||||||
@@ -38,17 +39,17 @@ class LlamaMLP(nn.Module):
|
|||||||
hidden_size: int,
|
hidden_size: int,
|
||||||
intermediate_size: int,
|
intermediate_size: int,
|
||||||
hidden_act: str,
|
hidden_act: str,
|
||||||
linear_method: Optional[LinearMethodBase] = None,
|
quant_config: Optional[QuantizationConfig] = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.gate_up_proj = MergedColumnParallelLinear(
|
self.gate_up_proj = MergedColumnParallelLinear(
|
||||||
hidden_size,
|
hidden_size,
|
||||||
[intermediate_size] * 2,
|
[intermediate_size] * 2,
|
||||||
bias=False,
|
bias=False,
|
||||||
linear_method=linear_method,
|
quant_config=quant_config,
|
||||||
)
|
)
|
||||||
self.down_proj = RowParallelLinear(
|
self.down_proj = RowParallelLinear(
|
||||||
intermediate_size, hidden_size, bias=False, linear_method=linear_method
|
intermediate_size, hidden_size, bias=False, quant_config=quant_config,
|
||||||
)
|
)
|
||||||
if hidden_act != "silu":
|
if hidden_act != "silu":
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
@@ -74,7 +75,7 @@ class LlamaAttention(nn.Module):
|
|||||||
rope_theta: float = 10000,
|
rope_theta: float = 10000,
|
||||||
rope_scaling: Optional[Dict[str, Any]] = None,
|
rope_scaling: Optional[Dict[str, Any]] = None,
|
||||||
max_position_embeddings: int = 8192,
|
max_position_embeddings: int = 8192,
|
||||||
linear_method: Optional[LinearMethodBase] = None,
|
quant_config: Optional[QuantizationConfig] = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.hidden_size = hidden_size
|
self.hidden_size = hidden_size
|
||||||
@@ -105,13 +106,13 @@ class LlamaAttention(nn.Module):
|
|||||||
self.total_num_heads,
|
self.total_num_heads,
|
||||||
self.total_num_kv_heads,
|
self.total_num_kv_heads,
|
||||||
bias=False,
|
bias=False,
|
||||||
linear_method=linear_method,
|
quant_config=quant_config,
|
||||||
)
|
)
|
||||||
self.o_proj = RowParallelLinear(
|
self.o_proj = RowParallelLinear(
|
||||||
self.total_num_heads * self.head_dim,
|
self.total_num_heads * self.head_dim,
|
||||||
hidden_size,
|
hidden_size,
|
||||||
bias=False,
|
bias=False,
|
||||||
linear_method=linear_method,
|
quant_config=quant_config,
|
||||||
)
|
)
|
||||||
|
|
||||||
self.rotary_emb = get_rope(
|
self.rotary_emb = get_rope(
|
||||||
@@ -148,7 +149,7 @@ class LlamaDecoderLayer(nn.Module):
|
|||||||
self,
|
self,
|
||||||
config: LlamaConfig,
|
config: LlamaConfig,
|
||||||
layer_id: int = 0,
|
layer_id: int = 0,
|
||||||
linear_method: Optional[LinearMethodBase] = None,
|
quant_config: Optional[QuantizationConfig] = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.hidden_size = config.hidden_size
|
self.hidden_size = config.hidden_size
|
||||||
@@ -163,13 +164,13 @@ class LlamaDecoderLayer(nn.Module):
|
|||||||
rope_theta=rope_theta,
|
rope_theta=rope_theta,
|
||||||
rope_scaling=rope_scaling,
|
rope_scaling=rope_scaling,
|
||||||
max_position_embeddings=max_position_embeddings,
|
max_position_embeddings=max_position_embeddings,
|
||||||
linear_method=linear_method,
|
quant_config=quant_config,
|
||||||
)
|
)
|
||||||
self.mlp = LlamaMLP(
|
self.mlp = LlamaMLP(
|
||||||
hidden_size=self.hidden_size,
|
hidden_size=self.hidden_size,
|
||||||
intermediate_size=config.intermediate_size,
|
intermediate_size=config.intermediate_size,
|
||||||
hidden_act=config.hidden_act,
|
hidden_act=config.hidden_act,
|
||||||
linear_method=linear_method,
|
quant_config=quant_config,
|
||||||
)
|
)
|
||||||
self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||||
self.post_attention_layernorm = RMSNorm(
|
self.post_attention_layernorm = RMSNorm(
|
||||||
@@ -205,7 +206,7 @@ class LlamaModel(nn.Module):
|
|||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
config: LlamaConfig,
|
config: LlamaConfig,
|
||||||
linear_method: Optional[LinearMethodBase] = None,
|
quant_config: Optional[QuantizationConfig] = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.config = config
|
self.config = config
|
||||||
@@ -217,7 +218,7 @@ class LlamaModel(nn.Module):
|
|||||||
)
|
)
|
||||||
self.layers = nn.ModuleList(
|
self.layers = nn.ModuleList(
|
||||||
[
|
[
|
||||||
LlamaDecoderLayer(config, i, linear_method)
|
LlamaDecoderLayer(config, i, quant_config=quant_config)
|
||||||
for i in range(config.num_hidden_layers)
|
for i in range(config.num_hidden_layers)
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
@@ -251,12 +252,12 @@ class LlamaForCausalLM(nn.Module):
|
|||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
config: LlamaConfig,
|
config: LlamaConfig,
|
||||||
linear_method: Optional[LinearMethodBase] = None,
|
quant_config: Optional[QuantizationConfig] = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.config = config
|
self.config = config
|
||||||
self.linear_method = linear_method
|
self.quant_config = quant_config
|
||||||
self.model = LlamaModel(config, linear_method)
|
self.model = LlamaModel(config, quant_config=quant_config)
|
||||||
self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
|
self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
|
||||||
self.logits_processor = LogitsProcessor(config)
|
self.logits_processor = LogitsProcessor(config)
|
||||||
|
|
||||||
|
|||||||
@@ -5,10 +5,11 @@ from typing import List, Optional
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
from torch import nn
|
from torch import nn
|
||||||
from transformers import CLIPVisionModel, LlamaConfig, LlavaConfig
|
from transformers import CLIPVisionModel, LlavaConfig
|
||||||
from transformers.models.llava.modeling_llava import LlavaMultiModalProjector
|
from transformers.models.llava.modeling_llava import LlavaMultiModalProjector
|
||||||
from vllm.model_executor.layers.linear import LinearMethodBase
|
from vllm.model_executor.layers.quantization.base_config import (
|
||||||
from vllm.model_executor.weight_utils import (
|
QuantizationConfig)
|
||||||
|
from sglang.srt.weight_utils import (
|
||||||
default_weight_loader,
|
default_weight_loader,
|
||||||
hf_model_weights_iterator,
|
hf_model_weights_iterator,
|
||||||
)
|
)
|
||||||
@@ -27,7 +28,7 @@ class LlavaLlamaForCausalLM(nn.Module):
|
|||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
config: LlavaConfig,
|
config: LlavaConfig,
|
||||||
linear_method: Optional[LinearMethodBase] = None,
|
quant_config: Optional[QuantizationConfig] = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.config = config
|
self.config = config
|
||||||
@@ -35,7 +36,7 @@ class LlavaLlamaForCausalLM(nn.Module):
|
|||||||
self.config.vision_config.hidden_size = config.mm_hidden_size
|
self.config.vision_config.hidden_size = config.mm_hidden_size
|
||||||
self.config.text_config.hidden_size = config.hidden_size
|
self.config.text_config.hidden_size = config.hidden_size
|
||||||
self.multi_modal_projector = LlavaMultiModalProjector(config)
|
self.multi_modal_projector = LlavaMultiModalProjector(config)
|
||||||
self.language_model = LlamaForCausalLM(config, linear_method)
|
self.language_model = LlamaForCausalLM(config, quant_config=quant_config)
|
||||||
if "unpad" in getattr(config, "mm_patch_merge_type", ""):
|
if "unpad" in getattr(config, "mm_patch_merge_type", ""):
|
||||||
self.language_model.model.image_newline = nn.Parameter(
|
self.language_model.model.image_newline = nn.Parameter(
|
||||||
torch.empty(config.text_config.hidden_size, dtype=torch.float16)
|
torch.empty(config.text_config.hidden_size, dtype=torch.float16)
|
||||||
|
|||||||
@@ -1,7 +1,7 @@
|
|||||||
# Adapted from
|
# Adapted from
|
||||||
# https://github.com/vllm-project/vllm/blob/d0215a58e78572d91dadafe9d832a2db89b09a13/vllm/model_executor/models/mixtral.py#L1
|
# https://github.com/vllm-project/vllm/blob/d0215a58e78572d91dadafe9d832a2db89b09a13/vllm/model_executor/models/mixtral.py#L1
|
||||||
"""Inference-only Mixtral model."""
|
"""Inference-only Mixtral model."""
|
||||||
from typing import List, Optional, Tuple
|
from typing import Optional
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
@@ -10,24 +10,25 @@ from torch import nn
|
|||||||
from transformers import MixtralConfig
|
from transformers import MixtralConfig
|
||||||
from vllm.model_executor.layers.layernorm import RMSNorm
|
from vllm.model_executor.layers.layernorm import RMSNorm
|
||||||
from vllm.model_executor.layers.linear import (
|
from vllm.model_executor.layers.linear import (
|
||||||
LinearMethodBase,
|
|
||||||
QKVParallelLinear,
|
QKVParallelLinear,
|
||||||
ReplicatedLinear,
|
ReplicatedLinear,
|
||||||
RowParallelLinear,
|
RowParallelLinear,
|
||||||
)
|
)
|
||||||
|
from vllm.model_executor.layers.quantization.base_config import (
|
||||||
|
QuantizationConfig)
|
||||||
from vllm.model_executor.layers.rotary_embedding import get_rope
|
from vllm.model_executor.layers.rotary_embedding import get_rope
|
||||||
from vllm.model_executor.layers.vocab_parallel_embedding import (
|
from vllm.model_executor.layers.vocab_parallel_embedding import (
|
||||||
ParallelLMHead,
|
ParallelLMHead,
|
||||||
VocabParallelEmbedding,
|
VocabParallelEmbedding,
|
||||||
)
|
)
|
||||||
from vllm.model_executor.parallel_utils.communication_op import (
|
from vllm.distributed import (
|
||||||
tensor_model_parallel_all_reduce,
|
tensor_model_parallel_all_reduce,
|
||||||
)
|
)
|
||||||
from vllm.model_executor.parallel_utils.parallel_state import (
|
from vllm.distributed import (
|
||||||
get_tensor_model_parallel_rank,
|
get_tensor_model_parallel_rank,
|
||||||
get_tensor_model_parallel_world_size,
|
get_tensor_model_parallel_world_size,
|
||||||
)
|
)
|
||||||
from vllm.model_executor.weight_utils import (
|
from sglang.srt.weight_utils import (
|
||||||
default_weight_loader,
|
default_weight_loader,
|
||||||
hf_model_weights_iterator,
|
hf_model_weights_iterator,
|
||||||
)
|
)
|
||||||
@@ -43,7 +44,7 @@ class MixtralMLP(nn.Module):
|
|||||||
num_experts: int,
|
num_experts: int,
|
||||||
hidden_size: int,
|
hidden_size: int,
|
||||||
intermediate_size: int,
|
intermediate_size: int,
|
||||||
linear_method: Optional[LinearMethodBase] = None,
|
quant_config: Optional[QuantizationConfig] = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.num_experts = num_experts
|
self.num_experts = num_experts
|
||||||
@@ -51,13 +52,13 @@ class MixtralMLP(nn.Module):
|
|||||||
self.hidden_dim = hidden_size
|
self.hidden_dim = hidden_size
|
||||||
|
|
||||||
self.w1 = ReplicatedLinear(
|
self.w1 = ReplicatedLinear(
|
||||||
self.hidden_dim, self.ffn_dim, bias=False, linear_method=linear_method
|
self.hidden_dim, self.ffn_dim, bias=False, quant_config=quant_config
|
||||||
)
|
)
|
||||||
self.w2 = ReplicatedLinear(
|
self.w2 = ReplicatedLinear(
|
||||||
self.ffn_dim, self.hidden_dim, bias=False, linear_method=linear_method
|
self.ffn_dim, self.hidden_dim, bias=False, quant_config=quant_config
|
||||||
)
|
)
|
||||||
self.w3 = ReplicatedLinear(
|
self.w3 = ReplicatedLinear(
|
||||||
self.hidden_dim, self.ffn_dim, bias=False, linear_method=linear_method
|
self.hidden_dim, self.ffn_dim, bias=False, quant_config=quant_config
|
||||||
)
|
)
|
||||||
|
|
||||||
# TODO: Use vllm's SiluAndMul
|
# TODO: Use vllm's SiluAndMul
|
||||||
@@ -76,7 +77,7 @@ class MixtralMoE(nn.Module):
|
|||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
config: MixtralConfig,
|
config: MixtralConfig,
|
||||||
linear_method: Optional[LinearMethodBase] = None,
|
quant_config: Optional[QuantizationConfig] = None,
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.config = config
|
self.config = config
|
||||||
@@ -103,7 +104,7 @@ class MixtralMoE(nn.Module):
|
|||||||
self.num_total_experts,
|
self.num_total_experts,
|
||||||
config.hidden_size,
|
config.hidden_size,
|
||||||
config.intermediate_size,
|
config.intermediate_size,
|
||||||
linear_method=linear_method,
|
quant_config=quant_config,
|
||||||
)
|
)
|
||||||
if idx in self.expert_indicies
|
if idx in self.expert_indicies
|
||||||
else None
|
else None
|
||||||
@@ -148,7 +149,7 @@ class MixtralAttention(nn.Module):
|
|||||||
layer_id: int = 0,
|
layer_id: int = 0,
|
||||||
max_position: int = 4096 * 32,
|
max_position: int = 4096 * 32,
|
||||||
rope_theta: float = 10000,
|
rope_theta: float = 10000,
|
||||||
linear_method: Optional[LinearMethodBase] = None,
|
quant_config: Optional[QuantizationConfig] = None,
|
||||||
sliding_window: Optional[int] = None,
|
sliding_window: Optional[int] = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
super().__init__()
|
super().__init__()
|
||||||
@@ -180,13 +181,13 @@ class MixtralAttention(nn.Module):
|
|||||||
self.total_num_heads,
|
self.total_num_heads,
|
||||||
self.total_num_kv_heads,
|
self.total_num_kv_heads,
|
||||||
bias=False,
|
bias=False,
|
||||||
linear_method=linear_method,
|
quant_config=quant_config,
|
||||||
)
|
)
|
||||||
self.o_proj = RowParallelLinear(
|
self.o_proj = RowParallelLinear(
|
||||||
self.total_num_heads * self.head_dim,
|
self.total_num_heads * self.head_dim,
|
||||||
hidden_size,
|
hidden_size,
|
||||||
bias=False,
|
bias=False,
|
||||||
linear_method=linear_method,
|
quant_config=quant_config,
|
||||||
)
|
)
|
||||||
self.rotary_emb = get_rope(
|
self.rotary_emb = get_rope(
|
||||||
self.head_dim,
|
self.head_dim,
|
||||||
@@ -222,7 +223,7 @@ class MixtralDecoderLayer(nn.Module):
|
|||||||
self,
|
self,
|
||||||
config: MixtralConfig,
|
config: MixtralConfig,
|
||||||
layer_id: int = 0,
|
layer_id: int = 0,
|
||||||
linear_method: Optional[LinearMethodBase] = None,
|
quant_config: Optional[QuantizationConfig] = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.hidden_size = config.hidden_size
|
self.hidden_size = config.hidden_size
|
||||||
@@ -236,9 +237,9 @@ class MixtralDecoderLayer(nn.Module):
|
|||||||
layer_id=layer_id,
|
layer_id=layer_id,
|
||||||
rope_theta=rope_theta,
|
rope_theta=rope_theta,
|
||||||
sliding_window=config.sliding_window,
|
sliding_window=config.sliding_window,
|
||||||
linear_method=linear_method,
|
quant_config=quant_config,
|
||||||
)
|
)
|
||||||
self.block_sparse_moe = MixtralMoE(config=config, linear_method=linear_method)
|
self.block_sparse_moe = MixtralMoE(config=config, quant_config=quant_config)
|
||||||
self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||||
self.post_attention_layernorm = RMSNorm(
|
self.post_attention_layernorm = RMSNorm(
|
||||||
config.hidden_size, eps=config.rms_norm_eps
|
config.hidden_size, eps=config.rms_norm_eps
|
||||||
@@ -273,7 +274,7 @@ class MixtralModel(nn.Module):
|
|||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
config: MixtralConfig,
|
config: MixtralConfig,
|
||||||
linear_method: Optional[LinearMethodBase] = None,
|
quant_config: Optional[QuantizationConfig] = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.padding_idx = config.pad_token_id
|
self.padding_idx = config.pad_token_id
|
||||||
@@ -286,7 +287,7 @@ class MixtralModel(nn.Module):
|
|||||||
# config.num_hidden_layers=16
|
# config.num_hidden_layers=16
|
||||||
self.layers = nn.ModuleList(
|
self.layers = nn.ModuleList(
|
||||||
[
|
[
|
||||||
MixtralDecoderLayer(config, i, linear_method=linear_method)
|
MixtralDecoderLayer(config, i, quant_config=quant_config)
|
||||||
for i in range(config.num_hidden_layers)
|
for i in range(config.num_hidden_layers)
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
@@ -317,12 +318,12 @@ class MixtralForCausalLM(nn.Module):
|
|||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
config: MixtralConfig,
|
config: MixtralConfig,
|
||||||
linear_method: Optional[LinearMethodBase] = None,
|
quant_config: Optional[QuantizationConfig] = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.config = config
|
self.config = config
|
||||||
self.linear_method = linear_method
|
self.quant_config = quant_config
|
||||||
self.model = MixtralModel(config, linear_method)
|
self.model = MixtralModel(config, quant_config=quant_config)
|
||||||
self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
|
self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
|
||||||
self.logits_processor = LogitsProcessor(config)
|
self.logits_processor = LogitsProcessor(config)
|
||||||
|
|
||||||
|
|||||||
@@ -1,4 +1,4 @@
|
|||||||
from typing import Any, Dict, List, Optional, Tuple
|
from typing import Any, Dict, Optional
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from torch import nn
|
from torch import nn
|
||||||
@@ -6,20 +6,21 @@ from transformers import PretrainedConfig
|
|||||||
from vllm.model_executor.layers.activation import SiluAndMul
|
from vllm.model_executor.layers.activation import SiluAndMul
|
||||||
from vllm.model_executor.layers.layernorm import RMSNorm
|
from vllm.model_executor.layers.layernorm import RMSNorm
|
||||||
from vllm.model_executor.layers.linear import (
|
from vllm.model_executor.layers.linear import (
|
||||||
LinearMethodBase,
|
|
||||||
MergedColumnParallelLinear,
|
MergedColumnParallelLinear,
|
||||||
QKVParallelLinear,
|
QKVParallelLinear,
|
||||||
RowParallelLinear,
|
RowParallelLinear,
|
||||||
)
|
)
|
||||||
|
from vllm.model_executor.layers.quantization.base_config import (
|
||||||
|
QuantizationConfig)
|
||||||
from vllm.model_executor.layers.rotary_embedding import get_rope
|
from vllm.model_executor.layers.rotary_embedding import get_rope
|
||||||
from vllm.model_executor.layers.vocab_parallel_embedding import (
|
from vllm.model_executor.layers.vocab_parallel_embedding import (
|
||||||
ParallelLMHead,
|
ParallelLMHead,
|
||||||
VocabParallelEmbedding,
|
VocabParallelEmbedding,
|
||||||
)
|
)
|
||||||
from vllm.model_executor.parallel_utils.parallel_state import (
|
from vllm.distributed import (
|
||||||
get_tensor_model_parallel_world_size,
|
get_tensor_model_parallel_world_size,
|
||||||
)
|
)
|
||||||
from vllm.model_executor.weight_utils import (
|
from sglang.srt.weight_utils import (
|
||||||
default_weight_loader,
|
default_weight_loader,
|
||||||
hf_model_weights_iterator,
|
hf_model_weights_iterator,
|
||||||
)
|
)
|
||||||
@@ -35,7 +36,7 @@ class QWenMLP(nn.Module):
|
|||||||
hidden_size: int,
|
hidden_size: int,
|
||||||
intermediate_size: int,
|
intermediate_size: int,
|
||||||
hidden_act: str = "silu",
|
hidden_act: str = "silu",
|
||||||
linear_method: Optional[LinearMethodBase] = None,
|
quant_config: Optional[QuantizationConfig] = None,
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.gate_up_proj = MergedColumnParallelLinear(
|
self.gate_up_proj = MergedColumnParallelLinear(
|
||||||
@@ -43,14 +44,14 @@ class QWenMLP(nn.Module):
|
|||||||
2 * [intermediate_size],
|
2 * [intermediate_size],
|
||||||
bias=False,
|
bias=False,
|
||||||
gather_output=False,
|
gather_output=False,
|
||||||
linear_method=linear_method,
|
quant_config=quant_config,
|
||||||
)
|
)
|
||||||
self.c_proj = RowParallelLinear(
|
self.c_proj = RowParallelLinear(
|
||||||
intermediate_size,
|
intermediate_size,
|
||||||
hidden_size,
|
hidden_size,
|
||||||
bias=False,
|
bias=False,
|
||||||
input_is_parallel=True,
|
input_is_parallel=True,
|
||||||
linear_method=linear_method,
|
quant_config=quant_config,
|
||||||
)
|
)
|
||||||
if hidden_act != "silu":
|
if hidden_act != "silu":
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
@@ -75,7 +76,7 @@ class QWenAttention(nn.Module):
|
|||||||
layer_id: int = 0,
|
layer_id: int = 0,
|
||||||
rope_theta: float = 10000,
|
rope_theta: float = 10000,
|
||||||
rope_scaling: Optional[Dict[str, Any]] = None,
|
rope_scaling: Optional[Dict[str, Any]] = None,
|
||||||
linear_method: Optional[LinearMethodBase] = None,
|
quant_config: Optional[QuantizationConfig] = None,
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.hidden_size = hidden_size
|
self.hidden_size = hidden_size
|
||||||
@@ -91,14 +92,14 @@ class QWenAttention(nn.Module):
|
|||||||
self.head_dim,
|
self.head_dim,
|
||||||
self.total_num_heads,
|
self.total_num_heads,
|
||||||
bias=True,
|
bias=True,
|
||||||
linear_method=linear_method,
|
quant_config=quant_config,
|
||||||
)
|
)
|
||||||
self.c_proj = RowParallelLinear(
|
self.c_proj = RowParallelLinear(
|
||||||
self.total_num_heads * self.head_dim,
|
self.total_num_heads * self.head_dim,
|
||||||
hidden_size,
|
hidden_size,
|
||||||
bias=False,
|
bias=False,
|
||||||
input_is_parallel=True,
|
input_is_parallel=True,
|
||||||
linear_method=linear_method,
|
quant_config=quant_config,
|
||||||
)
|
)
|
||||||
self.rotary_emb = get_rope(
|
self.rotary_emb = get_rope(
|
||||||
self.head_dim,
|
self.head_dim,
|
||||||
@@ -131,7 +132,7 @@ class QWenAttention(nn.Module):
|
|||||||
|
|
||||||
|
|
||||||
class QWenBlock(nn.Module):
|
class QWenBlock(nn.Module):
|
||||||
def __init__(self, config: PretrainedConfig, layer_id, linear_method=None):
|
def __init__(self, config: PretrainedConfig, layer_id, quant_config: Optional[QuantizationConfig] = None,):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.ln_1 = RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon)
|
self.ln_1 = RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon)
|
||||||
|
|
||||||
@@ -144,7 +145,7 @@ class QWenBlock(nn.Module):
|
|||||||
rope_theta=rope_theta,
|
rope_theta=rope_theta,
|
||||||
rope_scaling=rope_scaling,
|
rope_scaling=rope_scaling,
|
||||||
layer_id=layer_id,
|
layer_id=layer_id,
|
||||||
linear_method=linear_method,
|
quant_config=quant_config,
|
||||||
)
|
)
|
||||||
|
|
||||||
self.ln_2 = RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon)
|
self.ln_2 = RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon)
|
||||||
@@ -152,7 +153,7 @@ class QWenBlock(nn.Module):
|
|||||||
self.mlp = QWenMLP(
|
self.mlp = QWenMLP(
|
||||||
config.hidden_size,
|
config.hidden_size,
|
||||||
config.intermediate_size // 2,
|
config.intermediate_size // 2,
|
||||||
linear_method=linear_method,
|
quant_config=quant_config,
|
||||||
)
|
)
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
@@ -180,7 +181,7 @@ class QWenBlock(nn.Module):
|
|||||||
|
|
||||||
|
|
||||||
class QWenModel(nn.Module):
|
class QWenModel(nn.Module):
|
||||||
def __init__(self, config: PretrainedConfig, linear_method=None):
|
def __init__(self, config: PretrainedConfig, quant_config: Optional[QuantizationConfig] = None,):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.config = config
|
self.config = config
|
||||||
self.vocab_size = config.vocab_size
|
self.vocab_size = config.vocab_size
|
||||||
@@ -192,7 +193,7 @@ class QWenModel(nn.Module):
|
|||||||
)
|
)
|
||||||
self.h = nn.ModuleList(
|
self.h = nn.ModuleList(
|
||||||
[
|
[
|
||||||
QWenBlock(config, i, linear_method=linear_method)
|
QWenBlock(config, i, quant_config=quant_config)
|
||||||
for i in range(config.num_hidden_layers)
|
for i in range(config.num_hidden_layers)
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
@@ -217,10 +218,10 @@ class QWenModel(nn.Module):
|
|||||||
|
|
||||||
|
|
||||||
class QWenLMHeadModel(nn.Module):
|
class QWenLMHeadModel(nn.Module):
|
||||||
def __init__(self, config: PretrainedConfig, linear_method=None):
|
def __init__(self, config: PretrainedConfig, quant_config: Optional[QuantizationConfig] = None,):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.config = config
|
self.config = config
|
||||||
self.transformer = QWenModel(config, linear_method=linear_method)
|
self.transformer = QWenModel(config, quant_config=quant_config)
|
||||||
vocab_size = ((config.vocab_size + 63) // 64) * 64
|
vocab_size = ((config.vocab_size + 63) // 64) * 64
|
||||||
self.lm_head = ParallelLMHead(vocab_size, config.hidden_size)
|
self.lm_head = ParallelLMHead(vocab_size, config.hidden_size)
|
||||||
self.logits_processor = LogitsProcessor(config)
|
self.logits_processor = LogitsProcessor(config)
|
||||||
@@ -275,4 +276,4 @@ class QWenLMHeadModel(nn.Module):
|
|||||||
weight_loader(param, loaded_weight)
|
weight_loader(param, loaded_weight)
|
||||||
|
|
||||||
|
|
||||||
EntryClass = QWenLMHeadModel
|
EntryClass = QWenLMHeadModel
|
||||||
@@ -1,27 +1,28 @@
|
|||||||
# Adapted from llama2.py
|
# Adapted from llama2.py
|
||||||
# Modify details for the adaptation of Qwen2 model.
|
# Modify details for the adaptation of Qwen2 model.
|
||||||
"""Inference-only Qwen2 model compatible with HuggingFace weights."""
|
"""Inference-only Qwen2 model compatible with HuggingFace weights."""
|
||||||
from typing import Any, Dict, List, Optional, Tuple
|
from typing import Any, Dict, Optional, Tuple
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from torch import nn
|
from torch import nn
|
||||||
from vllm.model_executor.layers.activation import SiluAndMul
|
from vllm.model_executor.layers.activation import SiluAndMul
|
||||||
from vllm.model_executor.layers.layernorm import RMSNorm
|
from vllm.model_executor.layers.layernorm import RMSNorm
|
||||||
from vllm.model_executor.layers.linear import (
|
from vllm.model_executor.layers.linear import (
|
||||||
LinearMethodBase,
|
|
||||||
MergedColumnParallelLinear,
|
MergedColumnParallelLinear,
|
||||||
QKVParallelLinear,
|
QKVParallelLinear,
|
||||||
RowParallelLinear,
|
RowParallelLinear,
|
||||||
)
|
)
|
||||||
|
from vllm.model_executor.layers.quantization.base_config import (
|
||||||
|
QuantizationConfig)
|
||||||
from vllm.model_executor.layers.rotary_embedding import get_rope
|
from vllm.model_executor.layers.rotary_embedding import get_rope
|
||||||
from vllm.model_executor.layers.vocab_parallel_embedding import (
|
from vllm.model_executor.layers.vocab_parallel_embedding import (
|
||||||
ParallelLMHead,
|
ParallelLMHead,
|
||||||
VocabParallelEmbedding,
|
VocabParallelEmbedding,
|
||||||
)
|
)
|
||||||
from vllm.model_executor.parallel_utils.parallel_state import (
|
from vllm.distributed import (
|
||||||
get_tensor_model_parallel_world_size,
|
get_tensor_model_parallel_world_size,
|
||||||
)
|
)
|
||||||
from vllm.model_executor.weight_utils import (
|
from sglang.srt.weight_utils import (
|
||||||
default_weight_loader,
|
default_weight_loader,
|
||||||
hf_model_weights_iterator,
|
hf_model_weights_iterator,
|
||||||
)
|
)
|
||||||
@@ -39,17 +40,17 @@ class Qwen2MLP(nn.Module):
|
|||||||
hidden_size: int,
|
hidden_size: int,
|
||||||
intermediate_size: int,
|
intermediate_size: int,
|
||||||
hidden_act: str,
|
hidden_act: str,
|
||||||
linear_method: Optional[LinearMethodBase] = None,
|
quant_config: Optional[QuantizationConfig] = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.gate_up_proj = MergedColumnParallelLinear(
|
self.gate_up_proj = MergedColumnParallelLinear(
|
||||||
hidden_size,
|
hidden_size,
|
||||||
[intermediate_size] * 2,
|
[intermediate_size] * 2,
|
||||||
bias=False,
|
bias=False,
|
||||||
linear_method=linear_method,
|
quant_config=quant_config,
|
||||||
)
|
)
|
||||||
self.down_proj = RowParallelLinear(
|
self.down_proj = RowParallelLinear(
|
||||||
intermediate_size, hidden_size, bias=False, linear_method=linear_method
|
intermediate_size, hidden_size, bias=False, quant_config=quant_config,
|
||||||
)
|
)
|
||||||
if hidden_act != "silu":
|
if hidden_act != "silu":
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
@@ -75,7 +76,7 @@ class Qwen2Attention(nn.Module):
|
|||||||
rope_theta: float = 1000000,
|
rope_theta: float = 1000000,
|
||||||
rope_scaling: Optional[Dict[str, Any]] = None,
|
rope_scaling: Optional[Dict[str, Any]] = None,
|
||||||
max_position_embeddings: int = 32768,
|
max_position_embeddings: int = 32768,
|
||||||
linear_method: Optional[LinearMethodBase] = None,
|
quant_config: Optional[QuantizationConfig] = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.hidden_size = hidden_size
|
self.hidden_size = hidden_size
|
||||||
@@ -106,13 +107,13 @@ class Qwen2Attention(nn.Module):
|
|||||||
self.total_num_heads,
|
self.total_num_heads,
|
||||||
self.total_num_kv_heads,
|
self.total_num_kv_heads,
|
||||||
bias=True,
|
bias=True,
|
||||||
linear_method=linear_method,
|
quant_config=quant_config,
|
||||||
)
|
)
|
||||||
self.o_proj = RowParallelLinear(
|
self.o_proj = RowParallelLinear(
|
||||||
self.total_num_heads * self.head_dim,
|
self.total_num_heads * self.head_dim,
|
||||||
hidden_size,
|
hidden_size,
|
||||||
bias=False,
|
bias=False,
|
||||||
linear_method=linear_method,
|
quant_config=quant_config,
|
||||||
)
|
)
|
||||||
|
|
||||||
self.rotary_emb = get_rope(
|
self.rotary_emb = get_rope(
|
||||||
@@ -149,7 +150,7 @@ class Qwen2DecoderLayer(nn.Module):
|
|||||||
self,
|
self,
|
||||||
config: Qwen2Config,
|
config: Qwen2Config,
|
||||||
layer_id: int = 0,
|
layer_id: int = 0,
|
||||||
linear_method: Optional[LinearMethodBase] = None,
|
quant_config: Optional[QuantizationConfig] = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.hidden_size = config.hidden_size
|
self.hidden_size = config.hidden_size
|
||||||
@@ -164,13 +165,13 @@ class Qwen2DecoderLayer(nn.Module):
|
|||||||
rope_theta=rope_theta,
|
rope_theta=rope_theta,
|
||||||
rope_scaling=rope_scaling,
|
rope_scaling=rope_scaling,
|
||||||
max_position_embeddings=max_position_embeddings,
|
max_position_embeddings=max_position_embeddings,
|
||||||
linear_method=linear_method,
|
quant_config=quant_config,
|
||||||
)
|
)
|
||||||
self.mlp = Qwen2MLP(
|
self.mlp = Qwen2MLP(
|
||||||
hidden_size=self.hidden_size,
|
hidden_size=self.hidden_size,
|
||||||
intermediate_size=config.intermediate_size,
|
intermediate_size=config.intermediate_size,
|
||||||
hidden_act=config.hidden_act,
|
hidden_act=config.hidden_act,
|
||||||
linear_method=linear_method,
|
quant_config=quant_config,
|
||||||
)
|
)
|
||||||
self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||||
self.post_attention_layernorm = RMSNorm(
|
self.post_attention_layernorm = RMSNorm(
|
||||||
@@ -206,7 +207,7 @@ class Qwen2Model(nn.Module):
|
|||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
config: Qwen2Config,
|
config: Qwen2Config,
|
||||||
linear_method: Optional[LinearMethodBase] = None,
|
quant_config: Optional[QuantizationConfig] = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.config = config
|
self.config = config
|
||||||
@@ -218,7 +219,7 @@ class Qwen2Model(nn.Module):
|
|||||||
)
|
)
|
||||||
self.layers = nn.ModuleList(
|
self.layers = nn.ModuleList(
|
||||||
[
|
[
|
||||||
Qwen2DecoderLayer(config, i, linear_method)
|
Qwen2DecoderLayer(config, i, quant_config=quant_config)
|
||||||
for i in range(config.num_hidden_layers)
|
for i in range(config.num_hidden_layers)
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
@@ -252,12 +253,12 @@ class Qwen2ForCausalLM(nn.Module):
|
|||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
config: Qwen2Config,
|
config: Qwen2Config,
|
||||||
linear_method: Optional[LinearMethodBase] = None,
|
quant_config: Optional[QuantizationConfig] = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.config = config
|
self.config = config
|
||||||
self.linear_method = linear_method
|
self.quant_config = quant_config
|
||||||
self.model = Qwen2Model(config, linear_method)
|
self.model = Qwen2Model(config, quant_config=quant_config)
|
||||||
self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
|
self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
|
||||||
self.logits_processor = LogitsProcessor(config)
|
self.logits_processor = LogitsProcessor(config)
|
||||||
|
|
||||||
|
|||||||
@@ -9,20 +9,21 @@ from torch import nn
|
|||||||
from transformers import PretrainedConfig
|
from transformers import PretrainedConfig
|
||||||
from vllm.model_executor.layers.activation import SiluAndMul
|
from vllm.model_executor.layers.activation import SiluAndMul
|
||||||
from vllm.model_executor.layers.linear import (
|
from vllm.model_executor.layers.linear import (
|
||||||
LinearMethodBase,
|
|
||||||
MergedColumnParallelLinear,
|
MergedColumnParallelLinear,
|
||||||
QKVParallelLinear,
|
QKVParallelLinear,
|
||||||
RowParallelLinear,
|
RowParallelLinear,
|
||||||
)
|
)
|
||||||
|
from vllm.model_executor.layers.quantization.base_config import (
|
||||||
|
QuantizationConfig)
|
||||||
from vllm.model_executor.layers.rotary_embedding import get_rope
|
from vllm.model_executor.layers.rotary_embedding import get_rope
|
||||||
from vllm.model_executor.layers.vocab_parallel_embedding import (
|
from vllm.model_executor.layers.vocab_parallel_embedding import (
|
||||||
ParallelLMHead,
|
ParallelLMHead,
|
||||||
VocabParallelEmbedding,
|
VocabParallelEmbedding,
|
||||||
)
|
)
|
||||||
from vllm.model_executor.parallel_utils.parallel_state import (
|
from vllm.distributed import (
|
||||||
get_tensor_model_parallel_world_size,
|
get_tensor_model_parallel_world_size,
|
||||||
)
|
)
|
||||||
from vllm.model_executor.weight_utils import (
|
from sglang.srt.weight_utils import (
|
||||||
default_weight_loader,
|
default_weight_loader,
|
||||||
hf_model_weights_iterator,
|
hf_model_weights_iterator,
|
||||||
)
|
)
|
||||||
@@ -34,7 +35,7 @@ from sglang.srt.managers.router.model_runner import InputMetadata
|
|||||||
|
|
||||||
class StablelmMLP(nn.Module):
|
class StablelmMLP(nn.Module):
|
||||||
def __init__(
|
def __init__(
|
||||||
self, config: PretrainedConfig, linear_method: Optional[LinearMethodBase] = None
|
self, config: PretrainedConfig, quant_config: Optional[QuantizationConfig] = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.config = config
|
self.config = config
|
||||||
@@ -44,10 +45,10 @@ class StablelmMLP(nn.Module):
|
|||||||
config.hidden_size,
|
config.hidden_size,
|
||||||
[config.intermediate_size] * 2,
|
[config.intermediate_size] * 2,
|
||||||
bias=False,
|
bias=False,
|
||||||
linear_method=linear_method,
|
quant_config=quant_config,
|
||||||
)
|
)
|
||||||
self.down_proj = RowParallelLinear(
|
self.down_proj = RowParallelLinear(
|
||||||
config.intermediate_size, config.hidden_size, bias=False
|
config.intermediate_size, config.hidden_size, bias=False, quant_config=quant_config,
|
||||||
)
|
)
|
||||||
self.act_fn = SiluAndMul()
|
self.act_fn = SiluAndMul()
|
||||||
|
|
||||||
@@ -63,7 +64,7 @@ class StablelmAttention(nn.Module):
|
|||||||
self,
|
self,
|
||||||
config: PretrainedConfig,
|
config: PretrainedConfig,
|
||||||
layer_id: int = 0,
|
layer_id: int = 0,
|
||||||
linear_method: Optional[LinearMethodBase] = None,
|
quant_config: Optional[QuantizationConfig] = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.config = config
|
self.config = config
|
||||||
@@ -105,13 +106,11 @@ class StablelmAttention(nn.Module):
|
|||||||
self.total_num_heads,
|
self.total_num_heads,
|
||||||
self.total_num_key_value_heads,
|
self.total_num_key_value_heads,
|
||||||
self.qkv_bias,
|
self.qkv_bias,
|
||||||
linear_method=linear_method,
|
|
||||||
)
|
)
|
||||||
self.o_proj = RowParallelLinear(
|
self.o_proj = RowParallelLinear(
|
||||||
self.total_num_heads * self.head_dim,
|
self.total_num_heads * self.head_dim,
|
||||||
self.hidden_size,
|
self.hidden_size,
|
||||||
bias=False,
|
bias=False,
|
||||||
linear_method=linear_method,
|
|
||||||
)
|
)
|
||||||
self.rotary_emb = get_rope(
|
self.rotary_emb = get_rope(
|
||||||
self.head_dim,
|
self.head_dim,
|
||||||
@@ -146,11 +145,11 @@ class StablelmDecoderLayer(nn.Module):
|
|||||||
self,
|
self,
|
||||||
config: PretrainedConfig,
|
config: PretrainedConfig,
|
||||||
layer_id: int = 0,
|
layer_id: int = 0,
|
||||||
linear_method: Optional[LinearMethodBase] = None,
|
quant_config: Optional[QuantizationConfig] = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.self_attn = StablelmAttention(config, layer_id=layer_id)
|
self.self_attn = StablelmAttention(config, layer_id=layer_id)
|
||||||
self.mlp = StablelmMLP(config, linear_method)
|
self.mlp = StablelmMLP(config, quant_config=quant_config)
|
||||||
norm_eps = getattr(config, "norm_eps", getattr(config, "layer_norm_eps", 1e-05))
|
norm_eps = getattr(config, "norm_eps", getattr(config, "layer_norm_eps", 1e-05))
|
||||||
self.input_layernorm = nn.LayerNorm(config.hidden_size, eps=norm_eps)
|
self.input_layernorm = nn.LayerNorm(config.hidden_size, eps=norm_eps)
|
||||||
self.post_attention_layernorm = nn.LayerNorm(config.hidden_size, eps=norm_eps)
|
self.post_attention_layernorm = nn.LayerNorm(config.hidden_size, eps=norm_eps)
|
||||||
@@ -182,7 +181,7 @@ class StablelmDecoderLayer(nn.Module):
|
|||||||
|
|
||||||
class StableLMEpochModel(nn.Module):
|
class StableLMEpochModel(nn.Module):
|
||||||
def __init__(
|
def __init__(
|
||||||
self, config: PretrainedConfig, linear_method: Optional[LinearMethodBase] = None
|
self, config: PretrainedConfig, quant_config: Optional[QuantizationConfig] = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.embed_tokens = VocabParallelEmbedding(
|
self.embed_tokens = VocabParallelEmbedding(
|
||||||
@@ -191,7 +190,7 @@ class StableLMEpochModel(nn.Module):
|
|||||||
)
|
)
|
||||||
self.layers = nn.ModuleList(
|
self.layers = nn.ModuleList(
|
||||||
[
|
[
|
||||||
StablelmDecoderLayer(config, i, linear_method)
|
StablelmDecoderLayer(config, i, quant_config=quant_config)
|
||||||
for i in range(config.num_hidden_layers)
|
for i in range(config.num_hidden_layers)
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
@@ -224,12 +223,12 @@ class StableLmForCausalLM(nn.Module):
|
|||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
config: PretrainedConfig,
|
config: PretrainedConfig,
|
||||||
linear_method: Optional[LinearMethodBase] = None,
|
quant_config: Optional[QuantizationConfig] = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.config = config
|
self.config = config
|
||||||
self.linear_method = linear_method
|
self.quant_config = quant_config
|
||||||
self.model = StableLMEpochModel(config, linear_method)
|
self.model = StableLMEpochModel(config, quant_config=quant_config)
|
||||||
self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
|
self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
|
||||||
self.logits_processor = LogitsProcessor(config)
|
self.logits_processor = LogitsProcessor(config)
|
||||||
|
|
||||||
|
|||||||
@@ -6,7 +6,7 @@ from typing import List, Optional
|
|||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
from transformers import CLIPVisionModel, LlavaConfig
|
from transformers import CLIPVisionModel, LlavaConfig
|
||||||
from vllm.model_executor.weight_utils import (
|
from sglang.srt.weight_utils import (
|
||||||
default_weight_loader,
|
default_weight_loader,
|
||||||
hf_model_weights_iterator,
|
hf_model_weights_iterator,
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -504,6 +504,10 @@ def launch_server(server_args: ServerArgs, pipe_finish_writer):
|
|||||||
global tokenizer_manager
|
global tokenizer_manager
|
||||||
global chat_template_name
|
global chat_template_name
|
||||||
|
|
||||||
|
if server_args.enable_flashinfer:
|
||||||
|
from sglang.srt.utils import assert_pkg_version
|
||||||
|
assert_pkg_version("flashinfer", "0.0.4")
|
||||||
|
|
||||||
# start show time thread
|
# start show time thread
|
||||||
if server_args.show_time_cost:
|
if server_args.show_time_cost:
|
||||||
enable_show_time_cost()
|
enable_show_time_cost()
|
||||||
|
|||||||
@@ -5,12 +5,14 @@ import socket
|
|||||||
import sys
|
import sys
|
||||||
import time
|
import time
|
||||||
import traceback
|
import traceback
|
||||||
|
from importlib.metadata import PackageNotFoundError, version
|
||||||
from io import BytesIO
|
from io import BytesIO
|
||||||
from typing import List, Optional
|
from typing import List, Optional
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import requests
|
import requests
|
||||||
import torch
|
import torch
|
||||||
|
from packaging import version as pkg_version
|
||||||
|
|
||||||
show_time_cost = False
|
show_time_cost = False
|
||||||
time_infos = {}
|
time_infos = {}
|
||||||
@@ -267,3 +269,15 @@ def load_image(image_file):
|
|||||||
image = Image.open(BytesIO(base64.b64decode(image_file)))
|
image = Image.open(BytesIO(base64.b64decode(image_file)))
|
||||||
|
|
||||||
return image
|
return image
|
||||||
|
|
||||||
|
|
||||||
|
def assert_pkg_version(pkg: str, min_version: str):
|
||||||
|
try:
|
||||||
|
installed_version = version(pkg)
|
||||||
|
if pkg_version.parse(installed_version) < pkg_version.parse(min_version):
|
||||||
|
raise Exception(
|
||||||
|
f"{pkg} is installed with version {installed_version} which "
|
||||||
|
f"is less than the minimum required version {min_version}"
|
||||||
|
)
|
||||||
|
except PackageNotFoundError:
|
||||||
|
raise Exception(f"{pkg} with minimum required version {min_version} is not installed")
|
||||||
|
|||||||
402
python/sglang/srt/weight_utils.py
Normal file
402
python/sglang/srt/weight_utils.py
Normal file
@@ -0,0 +1,402 @@
|
|||||||
|
# The PR(https://github.com/vllm-project/vllm/pull/4097) of vllm borken the sglang code.
|
||||||
|
# In order to adapt to the latest code without modifying too much code,
|
||||||
|
# copied the previous vllm/model_executor/weight_utils.py
|
||||||
|
# Copied in https://github.com/vllm-project/vllm/blob/05434764cd99990035779cf9a4ed86623b528825/vllm/model_executor/weight_utils.py
|
||||||
|
|
||||||
|
"""Utilities for downloading and initializing model weights."""
|
||||||
|
import fnmatch
|
||||||
|
import glob
|
||||||
|
import hashlib
|
||||||
|
import json
|
||||||
|
import os
|
||||||
|
from collections import defaultdict
|
||||||
|
from typing import Any, Iterable, Iterator, List, Optional, Tuple, Union
|
||||||
|
|
||||||
|
import filelock
|
||||||
|
import huggingface_hub.constants
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
from huggingface_hub import HfFileSystem, snapshot_download
|
||||||
|
from safetensors.torch import load_file, safe_open, save_file
|
||||||
|
from tqdm.auto import tqdm
|
||||||
|
|
||||||
|
from vllm.config import ModelConfig
|
||||||
|
from vllm.logger import init_logger
|
||||||
|
from vllm.model_executor.layers.quantization import (QuantizationConfig,
|
||||||
|
get_quantization_config)
|
||||||
|
from vllm.model_executor.layers.quantization.schema import QuantParamSchema
|
||||||
|
|
||||||
|
logger = init_logger(__name__)
|
||||||
|
|
||||||
|
# use system-level temp directory for file locks, so that multiple users
|
||||||
|
# can share the same lock without error.
|
||||||
|
# lock files in the temp directory will be automatically deleted when the
|
||||||
|
# system reboots, so users will not complain about annoying lock files
|
||||||
|
temp_dir = os.environ.get('TMPDIR') or os.environ.get(
|
||||||
|
'TEMP') or os.environ.get('TMP') or "/tmp/"
|
||||||
|
|
||||||
|
|
||||||
|
def enable_hf_transfer():
|
||||||
|
"""automatically activates hf_transfer
|
||||||
|
"""
|
||||||
|
if "HF_HUB_ENABLE_HF_TRANSFER" not in os.environ:
|
||||||
|
try:
|
||||||
|
# enable hf hub transfer if available
|
||||||
|
import hf_transfer # type: ignore # noqa
|
||||||
|
huggingface_hub.constants.HF_HUB_ENABLE_HF_TRANSFER = True
|
||||||
|
except ImportError:
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
enable_hf_transfer()
|
||||||
|
|
||||||
|
|
||||||
|
class Disabledtqdm(tqdm):
|
||||||
|
|
||||||
|
def __init__(self, *args, **kwargs):
|
||||||
|
super().__init__(*args, **kwargs, disable=True)
|
||||||
|
|
||||||
|
|
||||||
|
def get_lock(model_name_or_path: str, cache_dir: Optional[str] = None):
|
||||||
|
lock_dir = cache_dir or temp_dir
|
||||||
|
os.makedirs(os.path.dirname(lock_dir), exist_ok=True)
|
||||||
|
model_name = model_name_or_path.replace("/", "-")
|
||||||
|
hash_name = hashlib.sha256(model_name.encode()).hexdigest()
|
||||||
|
# add hash to avoid conflict with old users' lock files
|
||||||
|
lock_file_name = hash_name + model_name + ".lock"
|
||||||
|
# mode 0o666 is required for the filelock to be shared across users
|
||||||
|
lock = filelock.FileLock(os.path.join(lock_dir, lock_file_name),
|
||||||
|
mode=0o666)
|
||||||
|
return lock
|
||||||
|
|
||||||
|
|
||||||
|
def _shared_pointers(tensors):
|
||||||
|
ptrs = defaultdict(list)
|
||||||
|
for k, v in tensors.items():
|
||||||
|
ptrs[v.data_ptr()].append(k)
|
||||||
|
failing = []
|
||||||
|
for _, names in ptrs.items():
|
||||||
|
if len(names) > 1:
|
||||||
|
failing.append(names)
|
||||||
|
return failing
|
||||||
|
|
||||||
|
|
||||||
|
def convert_bin_to_safetensor_file(
|
||||||
|
pt_filename: str,
|
||||||
|
sf_filename: str,
|
||||||
|
) -> None:
|
||||||
|
loaded = torch.load(pt_filename, map_location="cpu")
|
||||||
|
if "state_dict" in loaded:
|
||||||
|
loaded = loaded["state_dict"]
|
||||||
|
shared = _shared_pointers(loaded)
|
||||||
|
for shared_weights in shared:
|
||||||
|
for name in shared_weights[1:]:
|
||||||
|
loaded.pop(name)
|
||||||
|
|
||||||
|
# For tensors to be contiguous
|
||||||
|
loaded = {k: v.contiguous() for k, v in loaded.items()}
|
||||||
|
|
||||||
|
dirname = os.path.dirname(sf_filename)
|
||||||
|
os.makedirs(dirname, exist_ok=True)
|
||||||
|
save_file(loaded, sf_filename, metadata={"format": "pt"})
|
||||||
|
|
||||||
|
# check file size
|
||||||
|
sf_size = os.stat(sf_filename).st_size
|
||||||
|
pt_size = os.stat(pt_filename).st_size
|
||||||
|
if (sf_size - pt_size) / pt_size > 0.01:
|
||||||
|
raise RuntimeError(f"""The file size different is more than 1%:
|
||||||
|
- {sf_filename}: {sf_size}
|
||||||
|
- {pt_filename}: {pt_size}
|
||||||
|
""")
|
||||||
|
|
||||||
|
# check if the tensors are the same
|
||||||
|
reloaded = load_file(sf_filename)
|
||||||
|
for k in loaded:
|
||||||
|
pt_tensor = loaded[k]
|
||||||
|
sf_tensor = reloaded[k]
|
||||||
|
if not torch.equal(pt_tensor, sf_tensor):
|
||||||
|
raise RuntimeError(f"The output tensors do not match for key {k}")
|
||||||
|
|
||||||
|
|
||||||
|
# TODO(woosuk): Move this to other place.
|
||||||
|
def get_quant_config(model_config: ModelConfig) -> QuantizationConfig:
|
||||||
|
quant_cls = get_quantization_config(model_config.quantization)
|
||||||
|
# Read the quantization config from the HF model config, if available.
|
||||||
|
hf_quant_config = getattr(model_config.hf_config, "quantization_config",
|
||||||
|
None)
|
||||||
|
if hf_quant_config is not None:
|
||||||
|
return quant_cls.from_config(hf_quant_config)
|
||||||
|
model_name_or_path = model_config.model
|
||||||
|
is_local = os.path.isdir(model_name_or_path)
|
||||||
|
if not is_local:
|
||||||
|
# Download the config files.
|
||||||
|
with get_lock(model_name_or_path, model_config.download_dir):
|
||||||
|
hf_folder = snapshot_download(model_name_or_path,
|
||||||
|
revision=model_config.revision,
|
||||||
|
allow_patterns="*.json",
|
||||||
|
cache_dir=model_config.download_dir,
|
||||||
|
tqdm_class=Disabledtqdm)
|
||||||
|
else:
|
||||||
|
hf_folder = model_name_or_path
|
||||||
|
config_files = glob.glob(os.path.join(hf_folder, "*.json"))
|
||||||
|
|
||||||
|
quant_config_files = [
|
||||||
|
f for f in config_files if any(
|
||||||
|
f.endswith(x) for x in quant_cls.get_config_filenames())
|
||||||
|
]
|
||||||
|
if len(quant_config_files) == 0:
|
||||||
|
raise ValueError(
|
||||||
|
f"Cannot find the config file for {model_config.quantization}")
|
||||||
|
if len(quant_config_files) > 1:
|
||||||
|
raise ValueError(
|
||||||
|
f"Found multiple config files for {model_config.quantization}: "
|
||||||
|
f"{quant_config_files}")
|
||||||
|
|
||||||
|
quant_config_file = quant_config_files[0]
|
||||||
|
with open(quant_config_file, "r") as f:
|
||||||
|
config = json.load(f)
|
||||||
|
return quant_cls.from_config(config)
|
||||||
|
|
||||||
|
|
||||||
|
def prepare_hf_model_weights(
|
||||||
|
model_name_or_path: str,
|
||||||
|
cache_dir: Optional[str] = None,
|
||||||
|
load_format: str = "auto",
|
||||||
|
fall_back_to_pt: bool = True,
|
||||||
|
revision: Optional[str] = None,
|
||||||
|
) -> Tuple[str, List[str], bool]:
|
||||||
|
# Download model weights from huggingface.
|
||||||
|
is_local = os.path.isdir(model_name_or_path) \
|
||||||
|
and load_format != "tensorizer"
|
||||||
|
use_safetensors = False
|
||||||
|
# Some quantized models use .pt files for storing the weights.
|
||||||
|
if load_format == "auto":
|
||||||
|
allow_patterns = ["*.safetensors", "*.bin"]
|
||||||
|
elif load_format == "safetensors":
|
||||||
|
use_safetensors = True
|
||||||
|
allow_patterns = ["*.safetensors"]
|
||||||
|
elif load_format == "pt":
|
||||||
|
allow_patterns = ["*.pt"]
|
||||||
|
elif load_format == "npcache":
|
||||||
|
allow_patterns = ["*.bin"]
|
||||||
|
elif load_format == "tensorizer":
|
||||||
|
allow_patterns = ["*.tensors"]
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Unknown load_format: {load_format}")
|
||||||
|
|
||||||
|
if fall_back_to_pt:
|
||||||
|
allow_patterns += ["*.pt"]
|
||||||
|
|
||||||
|
if not is_local and load_format != "tensorizer":
|
||||||
|
# Before we download we look at that is available:
|
||||||
|
fs = HfFileSystem()
|
||||||
|
file_list = fs.ls(model_name_or_path, detail=False, revision=revision)
|
||||||
|
|
||||||
|
# depending on what is available we download different things
|
||||||
|
for pattern in allow_patterns:
|
||||||
|
matching = fnmatch.filter(file_list, pattern)
|
||||||
|
if len(matching) > 0:
|
||||||
|
allow_patterns = [pattern]
|
||||||
|
break
|
||||||
|
|
||||||
|
logger.info(f"Using model weights format {allow_patterns}")
|
||||||
|
# Use file lock to prevent multiple processes from
|
||||||
|
# downloading the same model weights at the same time.
|
||||||
|
with get_lock(model_name_or_path, cache_dir):
|
||||||
|
hf_folder = snapshot_download(model_name_or_path,
|
||||||
|
allow_patterns=allow_patterns,
|
||||||
|
cache_dir=cache_dir,
|
||||||
|
tqdm_class=Disabledtqdm,
|
||||||
|
revision=revision)
|
||||||
|
else:
|
||||||
|
hf_folder = model_name_or_path
|
||||||
|
hf_weights_files: List[str] = []
|
||||||
|
for pattern in allow_patterns:
|
||||||
|
hf_weights_files += glob.glob(os.path.join(hf_folder, pattern))
|
||||||
|
if len(hf_weights_files) > 0:
|
||||||
|
if pattern == "*.safetensors":
|
||||||
|
use_safetensors = True
|
||||||
|
break
|
||||||
|
if not use_safetensors:
|
||||||
|
# Exclude files that are not needed for inference.
|
||||||
|
# https://github.com/huggingface/transformers/blob/v4.34.0/src/transformers/trainer.py#L227-L233
|
||||||
|
blacklist = [
|
||||||
|
"training_args.bin",
|
||||||
|
"optimizer.bin",
|
||||||
|
"optimizer.pt",
|
||||||
|
"scheduler.pt",
|
||||||
|
"scaler.pt",
|
||||||
|
]
|
||||||
|
hf_weights_files = [
|
||||||
|
f for f in hf_weights_files
|
||||||
|
if not any(f.endswith(x) for x in blacklist)
|
||||||
|
]
|
||||||
|
|
||||||
|
if load_format == "tensorizer":
|
||||||
|
return hf_folder, hf_weights_files, use_safetensors
|
||||||
|
|
||||||
|
if len(hf_weights_files) == 0:
|
||||||
|
raise RuntimeError(
|
||||||
|
f"Cannot find any model weights with `{model_name_or_path}`")
|
||||||
|
|
||||||
|
return hf_folder, hf_weights_files, use_safetensors
|
||||||
|
|
||||||
|
|
||||||
|
def hf_model_weights_iterator(
|
||||||
|
model_name_or_path: str,
|
||||||
|
cache_dir: Optional[str] = None,
|
||||||
|
load_format: Union[Tuple, str] = "auto",
|
||||||
|
revision: Optional[str] = None,
|
||||||
|
fall_back_to_pt: Optional[bool] = True,
|
||||||
|
) -> Iterator[Tuple[str, torch.Tensor]]:
|
||||||
|
hf_folder, hf_weights_files, use_safetensors = prepare_hf_model_weights(
|
||||||
|
model_name_or_path,
|
||||||
|
cache_dir=cache_dir,
|
||||||
|
load_format=load_format,
|
||||||
|
fall_back_to_pt=fall_back_to_pt,
|
||||||
|
revision=revision)
|
||||||
|
|
||||||
|
if load_format == "npcache":
|
||||||
|
# Currently np_cache only support *.bin checkpoints
|
||||||
|
assert use_safetensors is False
|
||||||
|
|
||||||
|
# Convert the model weights from torch tensors to numpy arrays for
|
||||||
|
# faster loading.
|
||||||
|
np_folder = os.path.join(hf_folder, "np")
|
||||||
|
os.makedirs(np_folder, exist_ok=True)
|
||||||
|
weight_names_file = os.path.join(np_folder, "weight_names.json")
|
||||||
|
# Use file lock to prevent multiple processes from
|
||||||
|
# dumping the same model weights to numpy at the same time.
|
||||||
|
with get_lock(model_name_or_path, cache_dir):
|
||||||
|
if not os.path.exists(weight_names_file):
|
||||||
|
weight_names = []
|
||||||
|
for bin_file in hf_weights_files:
|
||||||
|
state = torch.load(bin_file, map_location="cpu")
|
||||||
|
for name, param in state.items():
|
||||||
|
param_path = os.path.join(np_folder, name)
|
||||||
|
with open(param_path, "wb") as f:
|
||||||
|
np.save(f, param.cpu().detach().numpy())
|
||||||
|
weight_names.append(name)
|
||||||
|
with open(weight_names_file, "w") as f:
|
||||||
|
json.dump(weight_names, f)
|
||||||
|
|
||||||
|
with open(weight_names_file, "r") as f:
|
||||||
|
weight_names = json.load(f)
|
||||||
|
|
||||||
|
for name in weight_names:
|
||||||
|
param_path = os.path.join(np_folder, name)
|
||||||
|
with open(param_path, "rb") as f:
|
||||||
|
param = np.load(f)
|
||||||
|
yield name, torch.from_numpy(param)
|
||||||
|
elif load_format == "tensorizer":
|
||||||
|
from vllm.model_executor.tensorizer_loader import (TensorDeserializer,
|
||||||
|
open_stream,
|
||||||
|
tensorizer_warning)
|
||||||
|
tensorizer_args = load_format.params
|
||||||
|
tensorizer_warning(
|
||||||
|
"Deserializing HuggingFace models is not optimized for "
|
||||||
|
"loading on vLLM, as tensorizer is forced to load to CPU. "
|
||||||
|
"Consider deserializing a vLLM model instead for faster "
|
||||||
|
"load times. See the examples/tensorize_vllm_model.py example "
|
||||||
|
"script for serializing vLLM models.")
|
||||||
|
|
||||||
|
deserializer_args = tensorizer_args.deserializer_params
|
||||||
|
stream_params = tensorizer_args.stream_params
|
||||||
|
stream = open_stream(tensorizer_args.tensorizer_uri, **stream_params)
|
||||||
|
with TensorDeserializer(stream, **deserializer_args,
|
||||||
|
device="cpu") as state:
|
||||||
|
for name, param in state.items():
|
||||||
|
yield name, param
|
||||||
|
del state
|
||||||
|
elif use_safetensors:
|
||||||
|
for st_file in hf_weights_files:
|
||||||
|
with safe_open(st_file, framework="pt") as f:
|
||||||
|
for name in f.keys(): # noqa: SIM118
|
||||||
|
param = f.get_tensor(name)
|
||||||
|
yield name, param
|
||||||
|
else:
|
||||||
|
for bin_file in hf_weights_files:
|
||||||
|
state = torch.load(bin_file, map_location="cpu")
|
||||||
|
for name, param in state.items():
|
||||||
|
yield name, param
|
||||||
|
del state
|
||||||
|
torch.cuda.empty_cache()
|
||||||
|
|
||||||
|
|
||||||
|
def kv_cache_scales_loader(
|
||||||
|
filename: str, tp_rank: int, tp_size: int, num_hidden_layers: int,
|
||||||
|
model_type: Optional[str]) -> Iterable[Tuple[int, float]]:
|
||||||
|
"""
|
||||||
|
A simple utility to read in KV cache scaling factors that have been
|
||||||
|
previously serialized to disk. Used by the model to populate the appropriate
|
||||||
|
KV cache scaling factors. The serialization should represent a dictionary
|
||||||
|
whose keys are the TP ranks and values are another dictionary mapping layers
|
||||||
|
to their KV cache scaling factors.
|
||||||
|
Keep this function in sync with the output of examples/fp8/extract_scales.py
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
with open(filename) as f:
|
||||||
|
context = {
|
||||||
|
"model_type": model_type,
|
||||||
|
"num_hidden_layers": num_hidden_layers,
|
||||||
|
"tp_rank": tp_rank,
|
||||||
|
"tp_size": tp_size,
|
||||||
|
}
|
||||||
|
schema_dct = json.load(f)
|
||||||
|
schema = QuantParamSchema.model_validate(schema_dct,
|
||||||
|
context=context)
|
||||||
|
layer_scales_map = schema.kv_cache.scaling_factor[tp_rank]
|
||||||
|
return layer_scales_map.items()
|
||||||
|
|
||||||
|
except FileNotFoundError:
|
||||||
|
logger.error(f"File or directory '{filename}' not found.")
|
||||||
|
except json.JSONDecodeError:
|
||||||
|
logger.error(f"Error decoding JSON in file '{filename}'.")
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"An error occurred while reading '{filename}': {e}")
|
||||||
|
# This section is reached if and only if any of the excepts are hit
|
||||||
|
# Return an empty iterable (list) => no KV cache scales are loaded
|
||||||
|
# which ultimately defaults to 1.0 scales
|
||||||
|
logger.warning("Defaulting to KV cache scaling factors = 1.0 "
|
||||||
|
f"for all layers in TP rank {tp_rank} "
|
||||||
|
"as an error occurred during loading.")
|
||||||
|
return []
|
||||||
|
|
||||||
|
|
||||||
|
def convert_pyslice_to_tensor(x: Any) -> torch.Tensor:
|
||||||
|
"""convert PySafeSlice object from safetensors to torch.Tensor
|
||||||
|
|
||||||
|
PySafeSlice object supports indexing, which is done before loading the
|
||||||
|
actual tensor and can reduce the amount of memory being read into the
|
||||||
|
memory. However, it does not support more advanced functionalities
|
||||||
|
like `.view()` or `.t()`. Therefore, if we need to modify the loaded
|
||||||
|
tensor with these more complicated operators, we need to convert to
|
||||||
|
tensor first.
|
||||||
|
"""
|
||||||
|
if not isinstance(x, torch.Tensor):
|
||||||
|
x = x[:]
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
def default_weight_loader(param: torch.Tensor,
|
||||||
|
loaded_weight: torch.Tensor) -> None:
|
||||||
|
"""Default weight loader."""
|
||||||
|
assert param.size() == loaded_weight.size()
|
||||||
|
param.data.copy_(loaded_weight)
|
||||||
|
|
||||||
|
|
||||||
|
def initialize_dummy_weights(
|
||||||
|
model: torch.nn.Module,
|
||||||
|
low: float = -1e-3,
|
||||||
|
high: float = 1e-3,
|
||||||
|
) -> None:
|
||||||
|
"""Initialize model weights with random values.
|
||||||
|
|
||||||
|
The model weights must be randomly initialized for accurate performance
|
||||||
|
measurements. Additionally, the model weights should not cause NaNs in the
|
||||||
|
forward pass. We empirically found that initializing the weights with
|
||||||
|
values between -1e-3 and 1e-3 works well for most models.
|
||||||
|
"""
|
||||||
|
for param in model.state_dict().values():
|
||||||
|
if torch.is_floating_point(param):
|
||||||
|
param.data.uniform_(low, high)
|
||||||
@@ -226,7 +226,7 @@ Action 3: Finish [United States].\n
|
|||||||
|
|
||||||
def test_parallel_decoding():
|
def test_parallel_decoding():
|
||||||
max_tokens = 64
|
max_tokens = 64
|
||||||
number = 5
|
fork_size = 5
|
||||||
|
|
||||||
@sgl.function
|
@sgl.function
|
||||||
def parallel_decoding(s, topic):
|
def parallel_decoding(s, topic):
|
||||||
@@ -234,17 +234,17 @@ def test_parallel_decoding():
|
|||||||
s += "USER: Give some tips for " + topic + ".\n"
|
s += "USER: Give some tips for " + topic + ".\n"
|
||||||
s += (
|
s += (
|
||||||
"ASSISTANT: Okay. Here are "
|
"ASSISTANT: Okay. Here are "
|
||||||
+ str(number)
|
+ str(fork_size)
|
||||||
+ " concise tips, each under 8 words:\n"
|
+ " concise tips, each under 8 words:\n"
|
||||||
)
|
)
|
||||||
|
|
||||||
# Generate skeleton
|
# Generate skeleton
|
||||||
for i in range(1, 1 + number):
|
for i in range(1, 1 + fork_size):
|
||||||
s += f"{i}." + sgl.gen(max_tokens=16, stop=[".", "\n"]) + ".\n"
|
s += f"{i}." + sgl.gen(max_tokens=16, stop=[".", "\n"]) + ".\n"
|
||||||
|
|
||||||
# Generate detailed tips
|
# Generate detailed tips
|
||||||
forks = s.fork(number)
|
forks = s.fork(fork_size)
|
||||||
for i in range(number):
|
for i in range(fork_size):
|
||||||
forks[
|
forks[
|
||||||
i
|
i
|
||||||
] += f"Now, I expand tip {i+1} into a detailed paragraph:\nTip {i+1}:"
|
] += f"Now, I expand tip {i+1} into a detailed paragraph:\nTip {i+1}:"
|
||||||
@@ -253,7 +253,7 @@ def test_parallel_decoding():
|
|||||||
|
|
||||||
# Concatenate tips and summarize
|
# Concatenate tips and summarize
|
||||||
s += "Here are these tips with detailed explanation:\n"
|
s += "Here are these tips with detailed explanation:\n"
|
||||||
for i in range(number):
|
for i in range(fork_size):
|
||||||
s += f"Tip {i+1}:" + forks[i]["detailed_tip"] + "\n"
|
s += f"Tip {i+1}:" + forks[i]["detailed_tip"] + "\n"
|
||||||
|
|
||||||
s += "\nIn summary," + sgl.gen("summary", max_tokens=512)
|
s += "\nIn summary," + sgl.gen("summary", max_tokens=512)
|
||||||
|
|||||||
Reference in New Issue
Block a user