From 33b242df303e03886835d08a583fefe979a3ee88 Mon Sep 17 00:00:00 2001 From: Qubitium <417764+Qubitium@users.noreply.github.com> Date: Sun, 12 May 2024 07:37:49 +0800 Subject: [PATCH] Compat with latest VLLM 0.4.2 main + fork.number rename + Flashinfer 0.0.4 (#380) Co-authored-by: ZX Co-authored-by: ZhouXingg <165115237+ZhouXingg@users.noreply.github.com> --- python/pyproject.toml | 2 +- python/sglang/lang/interpreter.py | 14 +- python/sglang/lang/tracer.py | 10 +- python/sglang/srt/layers/logits_processor.py | 2 +- .../sglang/srt/managers/router/model_rpc.py | 7 +- .../srt/managers/router/model_runner.py | 20 +- python/sglang/srt/models/commandr.py | 37 +- python/sglang/srt/models/dbrx.py | 39 +- python/sglang/srt/models/gemma.py | 35 +- python/sglang/srt/models/llama2.py | 37 +- python/sglang/srt/models/llava.py | 11 +- python/sglang/srt/models/mixtral.py | 45 +- python/sglang/srt/models/qwen.py | 37 +- python/sglang/srt/models/qwen2.py | 37 +- python/sglang/srt/models/stablelm.py | 31 +- python/sglang/srt/models/yivl.py | 2 +- python/sglang/srt/server.py | 4 + python/sglang/srt/utils.py | 14 + python/sglang/srt/weight_utils.py | 402 ++++++++++++++++++ python/sglang/test/test_programs.py | 12 +- 20 files changed, 611 insertions(+), 187 deletions(-) create mode 100644 python/sglang/srt/weight_utils.py diff --git a/python/pyproject.toml b/python/pyproject.toml index 073b642a3..6966d1452 100644 --- a/python/pyproject.toml +++ b/python/pyproject.toml @@ -20,7 +20,7 @@ dependencies = [ [project.optional-dependencies] srt = ["aiohttp", "fastapi", "psutil", "rpyc", "torch", "uvloop", "uvicorn", - "zmq", "vllm>=0.3.3,<=0.4.0", "interegular", "pydantic", "pillow", "outlines>=0.0.27"] + "zmq", "vllm>=0.4.2", "interegular", "pydantic", "pillow", "outlines>=0.0.27", "flashinfer>=0.0.4", "packaging"] openai = ["openai>=1.0", "numpy", "tiktoken"] anthropic = ["anthropic>=0.20.0", "numpy"] all = ["sglang[srt]", "sglang[openai]", "sglang[anthropic]"] diff --git a/python/sglang/lang/interpreter.py b/python/sglang/lang/interpreter.py index ebc65a0e0..fa5800e09 100644 --- a/python/sglang/lang/interpreter.py +++ b/python/sglang/lang/interpreter.py @@ -266,14 +266,14 @@ class StreamExecutor: def fork( self, - number: int, + size: int = 1, position_ids_offset: Optional[List[int]] = None, ): - if number > 1: + if size > 1: self.submit(SglCommitLazy()) self.sync() - number = int(number) + size = int(size) exes = [ StreamExecutor( @@ -283,9 +283,9 @@ class StreamExecutor: self.chat_template, self.stream, ) - for _ in range(number) + for _ in range(size) ] - for i in range(number): + for i in range(size): exes[i].variables = dict(self.variables) exes[i].text_ = str(self.text_) exes[i].messages_ = list(self.messages_) @@ -656,10 +656,10 @@ class ProgramState: def fork( self, - number: int = 1, + size: int = 1, position_ids_offset: Optional[List[int]] = None, ): - stream_executors = self.stream_executor.fork(number, position_ids_offset) + stream_executors = self.stream_executor.fork(size, position_ids_offset) states = [ProgramState(x) for x in stream_executors] state_group = ProgramStateGroup(states, self) return state_group diff --git a/python/sglang/lang/tracer.py b/python/sglang/lang/tracer.py index 74ac9b998..b506c44e1 100644 --- a/python/sglang/lang/tracer.py +++ b/python/sglang/lang/tracer.py @@ -109,19 +109,21 @@ class TracerProgramState(ProgramState): ########### Public API ########### ################################## - def fork(self, number: int, position_ids_offset: Optional[List[int]] = None): + def 
fork(self, size: int = 1, position_ids_offset: Optional[List[int]] = None): + assert (size >= 1) + if self.only_trace_prefix: raise StopTracing() - fork_node = SglFork(number) + fork_node = SglFork(size) fork_node.prev_node = self.last_node states = [ TracerProgramState(self.backend, self.arguments, self.only_trace_prefix) - for _ in range(number) + for _ in range(size) ] - for i in range(number): + for i in range(size): node = SglGetForkItem(i) node.prev_node = fork_node states[i].last_node = node diff --git a/python/sglang/srt/layers/logits_processor.py b/python/sglang/srt/layers/logits_processor.py index 617dcdf3e..f95c30786 100644 --- a/python/sglang/srt/layers/logits_processor.py +++ b/python/sglang/srt/layers/logits_processor.py @@ -1,6 +1,6 @@ import torch from torch import nn -from vllm.model_executor.parallel_utils.communication_op import ( +from vllm.distributed import ( get_tensor_model_parallel_world_size, tensor_model_parallel_all_gather, ) diff --git a/python/sglang/srt/managers/router/model_rpc.py b/python/sglang/srt/managers/router/model_rpc.py index 2d30409e5..f283635c3 100644 --- a/python/sglang/srt/managers/router/model_rpc.py +++ b/python/sglang/srt/managers/router/model_rpc.py @@ -10,7 +10,10 @@ import rpyc import torch from rpyc.utils.classic import obtain from rpyc.utils.server import ThreadedServer -from vllm.logger import _default_handler as vllm_default_handler +try: + from vllm.logger import _default_handler as vllm_default_logger +except ImportError: + from vllm.logger import logger as vllm_default_logger from sglang.srt.constrained.fsm_cache import FSMCache from sglang.srt.constrained.jump_forward import JumpForwardCache @@ -50,7 +53,7 @@ class ModelRpcServer: self.tp_size = server_args.tp_size self.schedule_heuristic = server_args.schedule_heuristic self.disable_regex_jump_forward = server_args.disable_regex_jump_forward - vllm_default_handler.setLevel( + vllm_default_logger.setLevel( level=getattr(logging, server_args.log_level.upper()) ) diff --git a/python/sglang/srt/managers/router/model_runner.py b/python/sglang/srt/managers/router/model_runner.py index 0837c51bb..b2a0daf5b 100644 --- a/python/sglang/srt/managers/router/model_runner.py +++ b/python/sglang/srt/managers/router/model_runner.py @@ -12,8 +12,8 @@ import torch from vllm.model_executor.layers.quantization.awq import AWQConfig from vllm.model_executor.layers.quantization.gptq import GPTQConfig from vllm.model_executor.layers.quantization.marlin import MarlinConfig -from vllm.model_executor.model_loader import _set_default_torch_dtype -from vllm.model_executor.parallel_utils.parallel_state import initialize_model_parallel +from vllm.model_executor.model_loader.utils import set_default_torch_dtype +from vllm.distributed import initialize_model_parallel from sglang.srt.managers.router.infer_batch import Batch, ForwardMode from sglang.srt.memory_pool import ReqToTokenPool, TokenToKVPool @@ -142,16 +142,9 @@ class InputMetadata: self.kv_last_page_len, self.model_runner.model_config.num_attention_heads // tp_size, self.model_runner.model_config.num_key_value_heads // tp_size, + self.model_runner.model_config.head_dim ] - # flashinfer >= 0.0.3 - # FIXME: Drop this when flashinfer updates to 0.0.4 - if ( - len(inspect.signature(self.prefill_wrapper.begin_forward).parameters) - == 7 - ): - args.append(self.model_runner.model_config.head_dim) - self.prefill_wrapper.begin_forward(*args) else: self.decode_wrapper = BatchDecodeWithPagedKVCacheWrapper( @@ -304,7 +297,7 @@ class ModelRunner: logger.info(f"Rank 
{self.tp_rank}: load weight begin.") # Load weights - linear_method = None + quant_config = None quant_cfg = getattr(self.model_config.hf_config, "quantization_config", None) if quant_cfg is not None: @@ -326,12 +319,11 @@ class ModelRunner: quant_config = quant_config_class.from_config(quant_cfg) logger.info(f"quant_config: {quant_config}") - linear_method = quant_config.get_linear_method() - with _set_default_torch_dtype(torch.float16): + with set_default_torch_dtype(torch.float16): with torch.device("cuda"): model = model_class( - config=self.model_config.hf_config, linear_method=linear_method + config=self.model_config.hf_config, quant_config=quant_config ) model.load_weights( self.model_config.path, diff --git a/python/sglang/srt/models/commandr.py b/python/sglang/srt/models/commandr.py index 74bf9dcdf..631e8c7f4 100644 --- a/python/sglang/srt/models/commandr.py +++ b/python/sglang/srt/models/commandr.py @@ -20,7 +20,7 @@ # This file is based on the LLama model definition file in transformers """PyTorch Cohere model.""" -from typing import List, Optional, Tuple +from typing import Optional, Tuple import torch import torch.utils.checkpoint @@ -29,19 +29,20 @@ from torch.nn.parameter import Parameter from transformers import PretrainedConfig from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.linear import ( - LinearMethodBase, MergedColumnParallelLinear, QKVParallelLinear, RowParallelLinear, ) +from vllm.model_executor.layers.quantization.base_config import ( + QuantizationConfig) from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding -from vllm.model_executor.parallel_utils.parallel_state import ( +from vllm.distributed import ( get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size, ) from vllm.model_executor.utils import set_weight_attrs -from vllm.model_executor.weight_utils import ( +from sglang.srt.weight_utils import ( default_weight_loader, hf_model_weights_iterator, ) @@ -92,7 +93,7 @@ class CohereMLP(nn.Module): def __init__( self, config, - linear_method: Optional[LinearMethodBase] = None, + quant_config: Optional[QuantizationConfig] = None, ): super().__init__() self.config = config @@ -102,13 +103,13 @@ class CohereMLP(nn.Module): self.hidden_size, [self.intermediate_size] * 2, bias=False, - linear_method=linear_method, + quant_config=quant_config, ) self.down_proj = RowParallelLinear( self.intermediate_size, self.hidden_size, bias=False, - linear_method=linear_method, + quant_config=quant_config, ) self.act_fn = SiluAndMul() @@ -124,7 +125,7 @@ class CohereAttention(nn.Module): self, config: PretrainedConfig, layer_id: int = 0, - linear_method: Optional[LinearMethodBase] = None, + quant_config: Optional[QuantizationConfig] = None, ): super().__init__() tp_size = get_tensor_model_parallel_world_size() @@ -159,13 +160,13 @@ class CohereAttention(nn.Module): self.total_num_heads, self.total_num_kv_heads, bias=False, - linear_method=linear_method, + quant_config=quant_config, ) self.o_proj = RowParallelLinear( self.total_num_heads * self.head_dim, self.hidden_size, bias=False, - linear_method=linear_method, + quant_config=quant_config, ) self.rotary_emb = get_rope( self.head_dim, @@ -221,16 +222,16 @@ class CohereDecoderLayer(nn.Module): self, config: PretrainedConfig, layer_id: int = 0, - linear_method: Optional[LinearMethodBase] = None, + quant_config: Optional[QuantizationConfig] = None, ): super().__init__() 
self.hidden_size = config.hidden_size self.self_attn = CohereAttention( - config, layer_id=layer_id, linear_method=linear_method + config, layer_id=layer_id, quant_config=quant_config ) - self.mlp = CohereMLP(config, linear_method=linear_method) + self.mlp = CohereMLP(config, quant_config=quant_config) self.input_layernorm = LayerNorm( param_shape=(config.hidden_size), eps=config.layer_norm_eps ) @@ -261,7 +262,7 @@ class CohereModel(nn.Module): def __init__( self, config: PretrainedConfig, - linear_method: Optional[LinearMethodBase] = None, + quant_config: Optional[QuantizationConfig] = None, ): super().__init__() self.config = config @@ -271,7 +272,7 @@ class CohereModel(nn.Module): ) self.layers = nn.ModuleList( [ - CohereDecoderLayer(config, i, linear_method=linear_method) + CohereDecoderLayer(config, i, quant_config=quant_config) for i in range(config.num_hidden_layers) ] ) @@ -303,13 +304,13 @@ class CohereForCausalLM(nn.Module): def __init__( self, config: PretrainedConfig, - linear_method: Optional[LinearMethodBase] = None, + quant_config: Optional[QuantizationConfig] = None, ) -> None: super().__init__() self.config = config - self.linear_method = linear_method + self.quant_config = quant_config self.logits_processor = LogitsProcessor(config) - self.model = CohereModel(config, linear_method) + self.model = CohereModel(config, quant_config) @torch.no_grad() def forward( diff --git a/python/sglang/srt/models/dbrx.py b/python/sglang/srt/models/dbrx.py index 50215a2ef..4b30a1f57 100644 --- a/python/sglang/srt/models/dbrx.py +++ b/python/sglang/srt/models/dbrx.py @@ -7,26 +7,27 @@ import torch import torch.nn as nn from vllm.model_executor.layers.fused_moe import fused_moe from vllm.model_executor.layers.linear import ( - LinearMethodBase, QKVParallelLinear, ReplicatedLinear, RowParallelLinear, ) +from vllm.model_executor.layers.quantization.base_config import ( + QuantizationConfig) from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.vocab_parallel_embedding import ( DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding, ) -from vllm.model_executor.parallel_utils.communication_op import ( +from vllm.distributed import ( tensor_model_parallel_all_reduce, ) -from vllm.model_executor.parallel_utils.parallel_state import ( +from vllm.distributed import ( get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size, ) from vllm.model_executor.utils import set_weight_attrs -from vllm.model_executor.weight_utils import ( +from sglang.srt.weight_utils import ( default_weight_loader, hf_model_weights_iterator, ) @@ -56,7 +57,7 @@ class DbrxRouter(nn.Module): self.num_total_experts, bias=False, params_dtype=params_dtype, - linear_method=None, + quant_config=None, ) def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: @@ -75,7 +76,7 @@ class DbrxExperts(nn.Module): def __init__( self, config: DbrxConfig, - linear_method: Optional[LinearMethodBase] = None, + quant_config: Optional[QuantizationConfig] = None, params_dtype: Optional[torch.dtype] = None, ): super().__init__() @@ -176,7 +177,7 @@ class DbrxAttention(nn.Module): self, config: DbrxConfig, layer_id: int = 0, - linear_method: Optional[LinearMethodBase] = None, + quant_config: Optional[QuantizationConfig] = None, ): super().__init__() self.d_model = config.d_model @@ -194,13 +195,13 @@ class DbrxAttention(nn.Module): self.total_num_heads, self.total_num_kv_heads, bias=False, - linear_method=linear_method, + quant_config=quant_config, ) self.out_proj = 
RowParallelLinear( self.d_model, self.d_model, bias=False, - linear_method=linear_method, + quant_config=quant_config, ) self.rotary_emb = get_rope( self.head_dim, @@ -255,11 +256,11 @@ class DbrxFusedNormAttention(nn.Module): self, config: DbrxConfig, layer_id: int = 0, - linear_method: Optional[LinearMethodBase] = None, + quant_config: Optional[QuantizationConfig] = None, ): super().__init__() self.d_model = config.d_model - self.attn = DbrxAttention(config, layer_id, linear_method) + self.attn = DbrxAttention(config, layer_id, quant_config=quant_config) self.norm_1 = nn.LayerNorm(self.d_model) self.norm_2 = nn.LayerNorm(self.d_model) @@ -287,11 +288,11 @@ class DbrxBlock(nn.Module): self, config: DbrxConfig, layer_id: int = 0, - linear_method: Optional[LinearMethodBase] = None, + quant_config: Optional[QuantizationConfig] = None, ): super().__init__() - self.norm_attn_norm = DbrxFusedNormAttention(config, layer_id, linear_method) - self.ffn = DbrxExperts(config, linear_method) + self.norm_attn_norm = DbrxFusedNormAttention(config, layer_id, quant_config=quant_config) + self.ffn = DbrxExperts(config, quant_config=quant_config) def forward( self, @@ -313,7 +314,7 @@ class DbrxModel(nn.Module): def __init__( self, config: DbrxConfig, - linear_method: Optional[LinearMethodBase] = None, + quant_config: Optional[QuantizationConfig] = None, ): super().__init__() self.wte = VocabParallelEmbedding( @@ -321,7 +322,7 @@ class DbrxModel(nn.Module): config.d_model, ) self.blocks = nn.ModuleList( - [DbrxBlock(config, i, linear_method) for i in range(config.n_layers)] + [DbrxBlock(config, i, quant_config=quant_config) for i in range(config.n_layers)] ) self.norm_f = nn.LayerNorm(config.d_model, eps=1e-5) for module in self.modules(): @@ -351,13 +352,13 @@ class DbrxForCausalLM(nn.Module): def __init__( self, config: DbrxConfig, - linear_method: Optional[LinearMethodBase] = None, + quant_config: Optional[QuantizationConfig] = None, ): super().__init__() self.config = config - self.linear_method = linear_method + self.quant_config = quant_config self.unpadded_vocab_size = config.vocab_size - self.transformer = DbrxModel(config, linear_method) + self.transformer = DbrxModel(config, quant_config=quant_config) self.lm_head = ParallelLMHead( config.vocab_size, config.d_model, diff --git a/python/sglang/srt/models/gemma.py b/python/sglang/srt/models/gemma.py index 37b352803..4b0b00479 100644 --- a/python/sglang/srt/models/gemma.py +++ b/python/sglang/srt/models/gemma.py @@ -10,17 +10,18 @@ from vllm.config import LoRAConfig from vllm.model_executor.layers.activation import GeluAndMul from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.linear import ( - LinearMethodBase, MergedColumnParallelLinear, QKVParallelLinear, RowParallelLinear, ) +from vllm.model_executor.layers.quantization.base_config import ( + QuantizationConfig) from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding -from vllm.model_executor.parallel_utils.parallel_state import ( +from vllm.distributed import ( get_tensor_model_parallel_world_size, ) -from vllm.model_executor.weight_utils import ( +from sglang.srt.weight_utils import ( default_weight_loader, hf_model_weights_iterator, ) @@ -35,17 +36,17 @@ class GemmaMLP(nn.Module): self, hidden_size: int, intermediate_size: int, - linear_method: Optional[LinearMethodBase] = None, + quant_config: Optional[QuantizationConfig] = None, ) -> None: super().__init__() 
self.gate_up_proj = MergedColumnParallelLinear( hidden_size, [intermediate_size] * 2, bias=False, - linear_method=linear_method, + quant_config=quant_config, ) self.down_proj = RowParallelLinear( - intermediate_size, hidden_size, bias=False, linear_method=linear_method + intermediate_size, hidden_size, bias=False, quant_config=quant_config, ) self.act_fn = GeluAndMul() @@ -66,7 +67,7 @@ class GemmaAttention(nn.Module): layer_id: int = 0, max_position_embeddings: int = 8192, rope_theta: float = 10000, - linear_method: Optional[LinearMethodBase] = None, + quant_config: Optional[QuantizationConfig] = None, ) -> None: super().__init__() self.hidden_size = hidden_size @@ -96,13 +97,13 @@ class GemmaAttention(nn.Module): self.total_num_heads, self.total_num_kv_heads, bias=False, - linear_method=linear_method, + quant_config=quant_config, ) self.o_proj = RowParallelLinear( self.total_num_heads * self.head_dim, hidden_size, bias=False, - linear_method=linear_method, + quant_config=quant_config, ) self.rotary_emb = get_rope( @@ -139,7 +140,7 @@ class GemmaDecoderLayer(nn.Module): self, config: PretrainedConfig, layer_id: int = 0, - linear_method: Optional[LinearMethodBase] = None, + quant_config: Optional[QuantizationConfig] = None, ) -> None: super().__init__() self.hidden_size = config.hidden_size @@ -151,12 +152,12 @@ class GemmaDecoderLayer(nn.Module): layer_id=layer_id, max_position_embeddings=config.max_position_embeddings, rope_theta=config.rope_theta, - linear_method=linear_method, + quant_config=quant_config, ) self.mlp = GemmaMLP( hidden_size=self.hidden_size, intermediate_size=config.intermediate_size, - linear_method=linear_method, + quant_config=quant_config, ) self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.post_attention_layernorm = RMSNorm( @@ -192,7 +193,7 @@ class GemmaModel(nn.Module): def __init__( self, config: PretrainedConfig, - linear_method: Optional[LinearMethodBase] = None, + quant_config: Optional[QuantizationConfig] = None, ) -> None: super().__init__() self.config = config @@ -203,7 +204,7 @@ class GemmaModel(nn.Module): ) self.layers = nn.ModuleList( [ - GemmaDecoderLayer(config, i, linear_method) + GemmaDecoderLayer(config, i, quant_config=quant_config) for i in range(config.num_hidden_layers) ] ) @@ -264,14 +265,14 @@ class GemmaForCausalLM(nn.Module): def __init__( self, config: PretrainedConfig, - linear_method: Optional[LinearMethodBase] = None, + quant_config: Optional[QuantizationConfig] = None, lora_config: Optional[LoRAConfig] = None, ) -> None: del lora_config # Unused. 
super().__init__() self.config = config - self.linear_method = linear_method - self.model = GemmaModel(config, linear_method) + self.quant_config = quant_config + self.model = GemmaModel(config, quant_config=quant_config) self.logits_processor = LogitsProcessor(config) @torch.no_grad() diff --git a/python/sglang/srt/models/llama2.py b/python/sglang/srt/models/llama2.py index 2f366d158..26c412871 100644 --- a/python/sglang/srt/models/llama2.py +++ b/python/sglang/srt/models/llama2.py @@ -1,7 +1,7 @@ # Adapted from # https://github.com/vllm-project/vllm/blob/671af2b1c0b3ed6d856d37c21a561cc429a10701/vllm/model_executor/models/llama.py#L1 """Inference-only LLaMA model compatible with HuggingFace weights.""" -from typing import Any, Dict, List, Optional, Tuple +from typing import Any, Dict, Optional, Tuple import torch from torch import nn @@ -9,20 +9,21 @@ from transformers import LlamaConfig from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.linear import ( - LinearMethodBase, MergedColumnParallelLinear, QKVParallelLinear, RowParallelLinear, ) +from vllm.model_executor.layers.quantization.base_config import ( + QuantizationConfig) from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding, ) -from vllm.model_executor.parallel_utils.parallel_state import ( +from vllm.distributed import ( get_tensor_model_parallel_world_size, ) -from vllm.model_executor.weight_utils import ( +from sglang.srt.weight_utils import ( default_weight_loader, hf_model_weights_iterator, ) @@ -38,17 +39,17 @@ class LlamaMLP(nn.Module): hidden_size: int, intermediate_size: int, hidden_act: str, - linear_method: Optional[LinearMethodBase] = None, + quant_config: Optional[QuantizationConfig] = None, ) -> None: super().__init__() self.gate_up_proj = MergedColumnParallelLinear( hidden_size, [intermediate_size] * 2, bias=False, - linear_method=linear_method, + quant_config=quant_config, ) self.down_proj = RowParallelLinear( - intermediate_size, hidden_size, bias=False, linear_method=linear_method + intermediate_size, hidden_size, bias=False, quant_config=quant_config, ) if hidden_act != "silu": raise ValueError( @@ -74,7 +75,7 @@ class LlamaAttention(nn.Module): rope_theta: float = 10000, rope_scaling: Optional[Dict[str, Any]] = None, max_position_embeddings: int = 8192, - linear_method: Optional[LinearMethodBase] = None, + quant_config: Optional[QuantizationConfig] = None, ) -> None: super().__init__() self.hidden_size = hidden_size @@ -105,13 +106,13 @@ class LlamaAttention(nn.Module): self.total_num_heads, self.total_num_kv_heads, bias=False, - linear_method=linear_method, + quant_config=quant_config, ) self.o_proj = RowParallelLinear( self.total_num_heads * self.head_dim, hidden_size, bias=False, - linear_method=linear_method, + quant_config=quant_config, ) self.rotary_emb = get_rope( @@ -148,7 +149,7 @@ class LlamaDecoderLayer(nn.Module): self, config: LlamaConfig, layer_id: int = 0, - linear_method: Optional[LinearMethodBase] = None, + quant_config: Optional[QuantizationConfig] = None, ) -> None: super().__init__() self.hidden_size = config.hidden_size @@ -163,13 +164,13 @@ class LlamaDecoderLayer(nn.Module): rope_theta=rope_theta, rope_scaling=rope_scaling, max_position_embeddings=max_position_embeddings, - linear_method=linear_method, + quant_config=quant_config, ) self.mlp = LlamaMLP( 
hidden_size=self.hidden_size, intermediate_size=config.intermediate_size, hidden_act=config.hidden_act, - linear_method=linear_method, + quant_config=quant_config, ) self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.post_attention_layernorm = RMSNorm( @@ -205,7 +206,7 @@ class LlamaModel(nn.Module): def __init__( self, config: LlamaConfig, - linear_method: Optional[LinearMethodBase] = None, + quant_config: Optional[QuantizationConfig] = None, ) -> None: super().__init__() self.config = config @@ -217,7 +218,7 @@ class LlamaModel(nn.Module): ) self.layers = nn.ModuleList( [ - LlamaDecoderLayer(config, i, linear_method) + LlamaDecoderLayer(config, i, quant_config=quant_config) for i in range(config.num_hidden_layers) ] ) @@ -251,12 +252,12 @@ class LlamaForCausalLM(nn.Module): def __init__( self, config: LlamaConfig, - linear_method: Optional[LinearMethodBase] = None, + quant_config: Optional[QuantizationConfig] = None, ) -> None: super().__init__() self.config = config - self.linear_method = linear_method - self.model = LlamaModel(config, linear_method) + self.quant_config = quant_config + self.model = LlamaModel(config, quant_config=quant_config) self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size) self.logits_processor = LogitsProcessor(config) diff --git a/python/sglang/srt/models/llava.py b/python/sglang/srt/models/llava.py index aca97d3b4..232aee1d3 100644 --- a/python/sglang/srt/models/llava.py +++ b/python/sglang/srt/models/llava.py @@ -5,10 +5,11 @@ from typing import List, Optional import numpy as np import torch from torch import nn -from transformers import CLIPVisionModel, LlamaConfig, LlavaConfig +from transformers import CLIPVisionModel, LlavaConfig from transformers.models.llava.modeling_llava import LlavaMultiModalProjector -from vllm.model_executor.layers.linear import LinearMethodBase -from vllm.model_executor.weight_utils import ( +from vllm.model_executor.layers.quantization.base_config import ( + QuantizationConfig) +from sglang.srt.weight_utils import ( default_weight_loader, hf_model_weights_iterator, ) @@ -27,7 +28,7 @@ class LlavaLlamaForCausalLM(nn.Module): def __init__( self, config: LlavaConfig, - linear_method: Optional[LinearMethodBase] = None, + quant_config: Optional[QuantizationConfig] = None, ) -> None: super().__init__() self.config = config @@ -35,7 +36,7 @@ class LlavaLlamaForCausalLM(nn.Module): self.config.vision_config.hidden_size = config.mm_hidden_size self.config.text_config.hidden_size = config.hidden_size self.multi_modal_projector = LlavaMultiModalProjector(config) - self.language_model = LlamaForCausalLM(config, linear_method) + self.language_model = LlamaForCausalLM(config, quant_config=quant_config) if "unpad" in getattr(config, "mm_patch_merge_type", ""): self.language_model.model.image_newline = nn.Parameter( torch.empty(config.text_config.hidden_size, dtype=torch.float16) diff --git a/python/sglang/srt/models/mixtral.py b/python/sglang/srt/models/mixtral.py index ed7ef24d0..99d81ce74 100644 --- a/python/sglang/srt/models/mixtral.py +++ b/python/sglang/srt/models/mixtral.py @@ -1,7 +1,7 @@ # Adapted from # https://github.com/vllm-project/vllm/blob/d0215a58e78572d91dadafe9d832a2db89b09a13/vllm/model_executor/models/mixtral.py#L1 """Inference-only Mixtral model.""" -from typing import List, Optional, Tuple +from typing import Optional import numpy as np import torch @@ -10,24 +10,25 @@ from torch import nn from transformers import MixtralConfig from vllm.model_executor.layers.layernorm import 
RMSNorm from vllm.model_executor.layers.linear import ( - LinearMethodBase, QKVParallelLinear, ReplicatedLinear, RowParallelLinear, ) +from vllm.model_executor.layers.quantization.base_config import ( + QuantizationConfig) from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding, ) -from vllm.model_executor.parallel_utils.communication_op import ( +from vllm.distributed import ( tensor_model_parallel_all_reduce, ) -from vllm.model_executor.parallel_utils.parallel_state import ( +from vllm.distributed import ( get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size, ) -from vllm.model_executor.weight_utils import ( +from sglang.srt.weight_utils import ( default_weight_loader, hf_model_weights_iterator, ) @@ -43,7 +44,7 @@ class MixtralMLP(nn.Module): num_experts: int, hidden_size: int, intermediate_size: int, - linear_method: Optional[LinearMethodBase] = None, + quant_config: Optional[QuantizationConfig] = None, ) -> None: super().__init__() self.num_experts = num_experts @@ -51,13 +52,13 @@ class MixtralMLP(nn.Module): self.hidden_dim = hidden_size self.w1 = ReplicatedLinear( - self.hidden_dim, self.ffn_dim, bias=False, linear_method=linear_method + self.hidden_dim, self.ffn_dim, bias=False, quant_config=quant_config ) self.w2 = ReplicatedLinear( - self.ffn_dim, self.hidden_dim, bias=False, linear_method=linear_method + self.ffn_dim, self.hidden_dim, bias=False, quant_config=quant_config ) self.w3 = ReplicatedLinear( - self.hidden_dim, self.ffn_dim, bias=False, linear_method=linear_method + self.hidden_dim, self.ffn_dim, bias=False, quant_config=quant_config ) # TODO: Use vllm's SiluAndMul @@ -76,7 +77,7 @@ class MixtralMoE(nn.Module): def __init__( self, config: MixtralConfig, - linear_method: Optional[LinearMethodBase] = None, + quant_config: Optional[QuantizationConfig] = None, ): super().__init__() self.config = config @@ -103,7 +104,7 @@ class MixtralMoE(nn.Module): self.num_total_experts, config.hidden_size, config.intermediate_size, - linear_method=linear_method, + quant_config=quant_config, ) if idx in self.expert_indicies else None @@ -148,7 +149,7 @@ class MixtralAttention(nn.Module): layer_id: int = 0, max_position: int = 4096 * 32, rope_theta: float = 10000, - linear_method: Optional[LinearMethodBase] = None, + quant_config: Optional[QuantizationConfig] = None, sliding_window: Optional[int] = None, ) -> None: super().__init__() @@ -180,13 +181,13 @@ class MixtralAttention(nn.Module): self.total_num_heads, self.total_num_kv_heads, bias=False, - linear_method=linear_method, + quant_config=quant_config, ) self.o_proj = RowParallelLinear( self.total_num_heads * self.head_dim, hidden_size, bias=False, - linear_method=linear_method, + quant_config=quant_config, ) self.rotary_emb = get_rope( self.head_dim, @@ -222,7 +223,7 @@ class MixtralDecoderLayer(nn.Module): self, config: MixtralConfig, layer_id: int = 0, - linear_method: Optional[LinearMethodBase] = None, + quant_config: Optional[QuantizationConfig] = None, ) -> None: super().__init__() self.hidden_size = config.hidden_size @@ -236,9 +237,9 @@ class MixtralDecoderLayer(nn.Module): layer_id=layer_id, rope_theta=rope_theta, sliding_window=config.sliding_window, - linear_method=linear_method, + quant_config=quant_config, ) - self.block_sparse_moe = MixtralMoE(config=config, linear_method=linear_method) + self.block_sparse_moe = MixtralMoE(config=config, quant_config=quant_config) self.input_layernorm = 
RMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.post_attention_layernorm = RMSNorm( config.hidden_size, eps=config.rms_norm_eps @@ -273,7 +274,7 @@ class MixtralModel(nn.Module): def __init__( self, config: MixtralConfig, - linear_method: Optional[LinearMethodBase] = None, + quant_config: Optional[QuantizationConfig] = None, ) -> None: super().__init__() self.padding_idx = config.pad_token_id @@ -286,7 +287,7 @@ class MixtralModel(nn.Module): # config.num_hidden_layers=16 self.layers = nn.ModuleList( [ - MixtralDecoderLayer(config, i, linear_method=linear_method) + MixtralDecoderLayer(config, i, quant_config=quant_config) for i in range(config.num_hidden_layers) ] ) @@ -317,12 +318,12 @@ class MixtralForCausalLM(nn.Module): def __init__( self, config: MixtralConfig, - linear_method: Optional[LinearMethodBase] = None, + quant_config: Optional[QuantizationConfig] = None, ) -> None: super().__init__() self.config = config - self.linear_method = linear_method - self.model = MixtralModel(config, linear_method) + self.quant_config = quant_config + self.model = MixtralModel(config, quant_config=quant_config) self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size) self.logits_processor = LogitsProcessor(config) diff --git a/python/sglang/srt/models/qwen.py b/python/sglang/srt/models/qwen.py index e7fee4a92..9d157f81f 100644 --- a/python/sglang/srt/models/qwen.py +++ b/python/sglang/srt/models/qwen.py @@ -1,4 +1,4 @@ -from typing import Any, Dict, List, Optional, Tuple +from typing import Any, Dict, Optional import torch from torch import nn @@ -6,20 +6,21 @@ from transformers import PretrainedConfig from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.linear import ( - LinearMethodBase, MergedColumnParallelLinear, QKVParallelLinear, RowParallelLinear, ) +from vllm.model_executor.layers.quantization.base_config import ( + QuantizationConfig) from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding, ) -from vllm.model_executor.parallel_utils.parallel_state import ( +from vllm.distributed import ( get_tensor_model_parallel_world_size, ) -from vllm.model_executor.weight_utils import ( +from sglang.srt.weight_utils import ( default_weight_loader, hf_model_weights_iterator, ) @@ -35,7 +36,7 @@ class QWenMLP(nn.Module): hidden_size: int, intermediate_size: int, hidden_act: str = "silu", - linear_method: Optional[LinearMethodBase] = None, + quant_config: Optional[QuantizationConfig] = None, ): super().__init__() self.gate_up_proj = MergedColumnParallelLinear( @@ -43,14 +44,14 @@ class QWenMLP(nn.Module): 2 * [intermediate_size], bias=False, gather_output=False, - linear_method=linear_method, + quant_config=quant_config, ) self.c_proj = RowParallelLinear( intermediate_size, hidden_size, bias=False, input_is_parallel=True, - linear_method=linear_method, + quant_config=quant_config, ) if hidden_act != "silu": raise ValueError( @@ -75,7 +76,7 @@ class QWenAttention(nn.Module): layer_id: int = 0, rope_theta: float = 10000, rope_scaling: Optional[Dict[str, Any]] = None, - linear_method: Optional[LinearMethodBase] = None, + quant_config: Optional[QuantizationConfig] = None, ): super().__init__() self.hidden_size = hidden_size @@ -91,14 +92,14 @@ class QWenAttention(nn.Module): self.head_dim, self.total_num_heads, bias=True, - linear_method=linear_method, + quant_config=quant_config, 
) self.c_proj = RowParallelLinear( self.total_num_heads * self.head_dim, hidden_size, bias=False, input_is_parallel=True, - linear_method=linear_method, + quant_config=quant_config, ) self.rotary_emb = get_rope( self.head_dim, @@ -131,7 +132,7 @@ class QWenAttention(nn.Module): class QWenBlock(nn.Module): - def __init__(self, config: PretrainedConfig, layer_id, linear_method=None): + def __init__(self, config: PretrainedConfig, layer_id, quant_config: Optional[QuantizationConfig] = None,): super().__init__() self.ln_1 = RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon) @@ -144,7 +145,7 @@ class QWenBlock(nn.Module): rope_theta=rope_theta, rope_scaling=rope_scaling, layer_id=layer_id, - linear_method=linear_method, + quant_config=quant_config, ) self.ln_2 = RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon) @@ -152,7 +153,7 @@ class QWenBlock(nn.Module): self.mlp = QWenMLP( config.hidden_size, config.intermediate_size // 2, - linear_method=linear_method, + quant_config=quant_config, ) def forward( @@ -180,7 +181,7 @@ class QWenBlock(nn.Module): class QWenModel(nn.Module): - def __init__(self, config: PretrainedConfig, linear_method=None): + def __init__(self, config: PretrainedConfig, quant_config: Optional[QuantizationConfig] = None,): super().__init__() self.config = config self.vocab_size = config.vocab_size @@ -192,7 +193,7 @@ class QWenModel(nn.Module): ) self.h = nn.ModuleList( [ - QWenBlock(config, i, linear_method=linear_method) + QWenBlock(config, i, quant_config=quant_config) for i in range(config.num_hidden_layers) ] ) @@ -217,10 +218,10 @@ class QWenModel(nn.Module): class QWenLMHeadModel(nn.Module): - def __init__(self, config: PretrainedConfig, linear_method=None): + def __init__(self, config: PretrainedConfig, quant_config: Optional[QuantizationConfig] = None,): super().__init__() self.config = config - self.transformer = QWenModel(config, linear_method=linear_method) + self.transformer = QWenModel(config, quant_config=quant_config) vocab_size = ((config.vocab_size + 63) // 64) * 64 self.lm_head = ParallelLMHead(vocab_size, config.hidden_size) self.logits_processor = LogitsProcessor(config) @@ -275,4 +276,4 @@ class QWenLMHeadModel(nn.Module): weight_loader(param, loaded_weight) -EntryClass = QWenLMHeadModel +EntryClass = QWenLMHeadModel \ No newline at end of file diff --git a/python/sglang/srt/models/qwen2.py b/python/sglang/srt/models/qwen2.py index e38941990..dc1dd0de3 100644 --- a/python/sglang/srt/models/qwen2.py +++ b/python/sglang/srt/models/qwen2.py @@ -1,27 +1,28 @@ # Adapted from llama2.py # Modify details for the adaptation of Qwen2 model. 
"""Inference-only Qwen2 model compatible with HuggingFace weights.""" -from typing import Any, Dict, List, Optional, Tuple +from typing import Any, Dict, Optional, Tuple import torch from torch import nn from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.linear import ( - LinearMethodBase, MergedColumnParallelLinear, QKVParallelLinear, RowParallelLinear, ) +from vllm.model_executor.layers.quantization.base_config import ( + QuantizationConfig) from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding, ) -from vllm.model_executor.parallel_utils.parallel_state import ( +from vllm.distributed import ( get_tensor_model_parallel_world_size, ) -from vllm.model_executor.weight_utils import ( +from sglang.srt.weight_utils import ( default_weight_loader, hf_model_weights_iterator, ) @@ -39,17 +40,17 @@ class Qwen2MLP(nn.Module): hidden_size: int, intermediate_size: int, hidden_act: str, - linear_method: Optional[LinearMethodBase] = None, + quant_config: Optional[QuantizationConfig] = None, ) -> None: super().__init__() self.gate_up_proj = MergedColumnParallelLinear( hidden_size, [intermediate_size] * 2, bias=False, - linear_method=linear_method, + quant_config=quant_config, ) self.down_proj = RowParallelLinear( - intermediate_size, hidden_size, bias=False, linear_method=linear_method + intermediate_size, hidden_size, bias=False, quant_config=quant_config, ) if hidden_act != "silu": raise ValueError( @@ -75,7 +76,7 @@ class Qwen2Attention(nn.Module): rope_theta: float = 1000000, rope_scaling: Optional[Dict[str, Any]] = None, max_position_embeddings: int = 32768, - linear_method: Optional[LinearMethodBase] = None, + quant_config: Optional[QuantizationConfig] = None, ) -> None: super().__init__() self.hidden_size = hidden_size @@ -106,13 +107,13 @@ class Qwen2Attention(nn.Module): self.total_num_heads, self.total_num_kv_heads, bias=True, - linear_method=linear_method, + quant_config=quant_config, ) self.o_proj = RowParallelLinear( self.total_num_heads * self.head_dim, hidden_size, bias=False, - linear_method=linear_method, + quant_config=quant_config, ) self.rotary_emb = get_rope( @@ -149,7 +150,7 @@ class Qwen2DecoderLayer(nn.Module): self, config: Qwen2Config, layer_id: int = 0, - linear_method: Optional[LinearMethodBase] = None, + quant_config: Optional[QuantizationConfig] = None, ) -> None: super().__init__() self.hidden_size = config.hidden_size @@ -164,13 +165,13 @@ class Qwen2DecoderLayer(nn.Module): rope_theta=rope_theta, rope_scaling=rope_scaling, max_position_embeddings=max_position_embeddings, - linear_method=linear_method, + quant_config=quant_config, ) self.mlp = Qwen2MLP( hidden_size=self.hidden_size, intermediate_size=config.intermediate_size, hidden_act=config.hidden_act, - linear_method=linear_method, + quant_config=quant_config, ) self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.post_attention_layernorm = RMSNorm( @@ -206,7 +207,7 @@ class Qwen2Model(nn.Module): def __init__( self, config: Qwen2Config, - linear_method: Optional[LinearMethodBase] = None, + quant_config: Optional[QuantizationConfig] = None, ) -> None: super().__init__() self.config = config @@ -218,7 +219,7 @@ class Qwen2Model(nn.Module): ) self.layers = nn.ModuleList( [ - Qwen2DecoderLayer(config, i, linear_method) + Qwen2DecoderLayer(config, i, quant_config=quant_config) for i 
in range(config.num_hidden_layers) ] ) @@ -252,12 +253,12 @@ class Qwen2ForCausalLM(nn.Module): def __init__( self, config: Qwen2Config, - linear_method: Optional[LinearMethodBase] = None, + quant_config: Optional[QuantizationConfig] = None, ) -> None: super().__init__() self.config = config - self.linear_method = linear_method - self.model = Qwen2Model(config, linear_method) + self.quant_config = quant_config + self.model = Qwen2Model(config, quant_config=quant_config) self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size) self.logits_processor = LogitsProcessor(config) diff --git a/python/sglang/srt/models/stablelm.py b/python/sglang/srt/models/stablelm.py index 9d559ecfa..7ad495c95 100644 --- a/python/sglang/srt/models/stablelm.py +++ b/python/sglang/srt/models/stablelm.py @@ -9,20 +9,21 @@ from torch import nn from transformers import PretrainedConfig from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.linear import ( - LinearMethodBase, MergedColumnParallelLinear, QKVParallelLinear, RowParallelLinear, ) +from vllm.model_executor.layers.quantization.base_config import ( + QuantizationConfig) from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding, ) -from vllm.model_executor.parallel_utils.parallel_state import ( +from vllm.distributed import ( get_tensor_model_parallel_world_size, ) -from vllm.model_executor.weight_utils import ( +from sglang.srt.weight_utils import ( default_weight_loader, hf_model_weights_iterator, ) @@ -34,7 +35,7 @@ from sglang.srt.managers.router.model_runner import InputMetadata class StablelmMLP(nn.Module): def __init__( - self, config: PretrainedConfig, linear_method: Optional[LinearMethodBase] = None + self, config: PretrainedConfig, quant_config: Optional[QuantizationConfig] = None, ) -> None: super().__init__() self.config = config @@ -44,10 +45,10 @@ class StablelmMLP(nn.Module): config.hidden_size, [config.intermediate_size] * 2, bias=False, - linear_method=linear_method, + quant_config=quant_config, ) self.down_proj = RowParallelLinear( - config.intermediate_size, config.hidden_size, bias=False + config.intermediate_size, config.hidden_size, bias=False, quant_config=quant_config, ) self.act_fn = SiluAndMul() @@ -63,7 +64,7 @@ class StablelmAttention(nn.Module): self, config: PretrainedConfig, layer_id: int = 0, - linear_method: Optional[LinearMethodBase] = None, + quant_config: Optional[QuantizationConfig] = None, ) -> None: super().__init__() self.config = config @@ -105,13 +106,11 @@ class StablelmAttention(nn.Module): self.total_num_heads, self.total_num_key_value_heads, self.qkv_bias, - linear_method=linear_method, ) self.o_proj = RowParallelLinear( self.total_num_heads * self.head_dim, self.hidden_size, bias=False, - linear_method=linear_method, ) self.rotary_emb = get_rope( self.head_dim, @@ -146,11 +145,11 @@ class StablelmDecoderLayer(nn.Module): self, config: PretrainedConfig, layer_id: int = 0, - linear_method: Optional[LinearMethodBase] = None, + quant_config: Optional[QuantizationConfig] = None, ) -> None: super().__init__() self.self_attn = StablelmAttention(config, layer_id=layer_id) - self.mlp = StablelmMLP(config, linear_method) + self.mlp = StablelmMLP(config, quant_config=quant_config) norm_eps = getattr(config, "norm_eps", getattr(config, "layer_norm_eps", 1e-05)) self.input_layernorm = nn.LayerNorm(config.hidden_size, eps=norm_eps) self.post_attention_layernorm = 
nn.LayerNorm(config.hidden_size, eps=norm_eps)
@@ -182,7 +181,7 @@ class StablelmDecoderLayer(nn.Module):
 
 class StableLMEpochModel(nn.Module):
     def __init__(
-        self, config: PretrainedConfig, linear_method: Optional[LinearMethodBase] = None
+        self, config: PretrainedConfig, quant_config: Optional[QuantizationConfig] = None,
     ) -> None:
         super().__init__()
         self.embed_tokens = VocabParallelEmbedding(
@@ -191,7 +190,7 @@ class StableLMEpochModel(nn.Module):
         )
         self.layers = nn.ModuleList(
             [
-                StablelmDecoderLayer(config, i, linear_method)
+                StablelmDecoderLayer(config, i, quant_config=quant_config)
                 for i in range(config.num_hidden_layers)
             ]
         )
@@ -224,12 +223,12 @@ class StableLmForCausalLM(nn.Module):
     def __init__(
         self,
         config: PretrainedConfig,
-        linear_method: Optional[LinearMethodBase] = None,
+        quant_config: Optional[QuantizationConfig] = None,
     ) -> None:
         super().__init__()
         self.config = config
-        self.linear_method = linear_method
-        self.model = StableLMEpochModel(config, linear_method)
+        self.quant_config = quant_config
+        self.model = StableLMEpochModel(config, quant_config=quant_config)
         self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
         self.logits_processor = LogitsProcessor(config)
 
diff --git a/python/sglang/srt/models/yivl.py b/python/sglang/srt/models/yivl.py
index 0e6c87811..6f4a9b59f 100644
--- a/python/sglang/srt/models/yivl.py
+++ b/python/sglang/srt/models/yivl.py
@@ -6,7 +6,7 @@ from typing import List, Optional
 import torch
 import torch.nn as nn
 from transformers import CLIPVisionModel, LlavaConfig
-from vllm.model_executor.weight_utils import (
+from sglang.srt.weight_utils import (
     default_weight_loader,
     hf_model_weights_iterator,
 )
diff --git a/python/sglang/srt/server.py b/python/sglang/srt/server.py
index 79ed26c93..4ad5701c5 100644
--- a/python/sglang/srt/server.py
+++ b/python/sglang/srt/server.py
@@ -504,6 +504,10 @@ def launch_server(server_args: ServerArgs, pipe_finish_writer):
     global tokenizer_manager
     global chat_template_name
 
+    if server_args.enable_flashinfer:
+        from sglang.srt.utils import assert_pkg_version
+        assert_pkg_version("flashinfer", "0.0.4")
+
     # start show time thread
     if server_args.show_time_cost:
         enable_show_time_cost()
diff --git a/python/sglang/srt/utils.py b/python/sglang/srt/utils.py
index 0f7322bb6..6b2c258d1 100644
--- a/python/sglang/srt/utils.py
+++ b/python/sglang/srt/utils.py
@@ -5,12 +5,14 @@ import socket
 import sys
 import time
 import traceback
+from importlib.metadata import PackageNotFoundError, version
 from io import BytesIO
 from typing import List, Optional
 
 import numpy as np
 import requests
 import torch
+from packaging import version as pkg_version
 
 show_time_cost = False
 time_infos = {}
@@ -267,3 +269,15 @@ def load_image(image_file):
         image = Image.open(BytesIO(base64.b64decode(image_file)))
 
     return image
+
+
+def assert_pkg_version(pkg: str, min_version: str):
+    try:
+        installed_version = version(pkg)
+        if pkg_version.parse(installed_version) < pkg_version.parse(min_version):
+            raise Exception(
+                f"{pkg} is installed with version {installed_version} which "
+                f"is less than the minimum required version {min_version}"
+            )
+    except PackageNotFoundError:
+        raise Exception(f"{pkg} with minimum required version {min_version} is not installed")
diff --git a/python/sglang/srt/weight_utils.py b/python/sglang/srt/weight_utils.py
new file mode 100644
index 000000000..0df3468c2
--- /dev/null
+++ b/python/sglang/srt/weight_utils.py
@@ -0,0 +1,402 @@
+# The PR(https://github.com/vllm-project/vllm/pull/4097) of vllm broke the
sglang code. +# In order to adapt to the latest code without modifying too much code, +# copied the previous vllm/model_executor/weight_utils.py +# Copied in https://github.com/vllm-project/vllm/blob/05434764cd99990035779cf9a4ed86623b528825/vllm/model_executor/weight_utils.py + +"""Utilities for downloading and initializing model weights.""" +import fnmatch +import glob +import hashlib +import json +import os +from collections import defaultdict +from typing import Any, Iterable, Iterator, List, Optional, Tuple, Union + +import filelock +import huggingface_hub.constants +import numpy as np +import torch +from huggingface_hub import HfFileSystem, snapshot_download +from safetensors.torch import load_file, safe_open, save_file +from tqdm.auto import tqdm + +from vllm.config import ModelConfig +from vllm.logger import init_logger +from vllm.model_executor.layers.quantization import (QuantizationConfig, + get_quantization_config) +from vllm.model_executor.layers.quantization.schema import QuantParamSchema + +logger = init_logger(__name__) + +# use system-level temp directory for file locks, so that multiple users +# can share the same lock without error. +# lock files in the temp directory will be automatically deleted when the +# system reboots, so users will not complain about annoying lock files +temp_dir = os.environ.get('TMPDIR') or os.environ.get( + 'TEMP') or os.environ.get('TMP') or "/tmp/" + + +def enable_hf_transfer(): + """automatically activates hf_transfer + """ + if "HF_HUB_ENABLE_HF_TRANSFER" not in os.environ: + try: + # enable hf hub transfer if available + import hf_transfer # type: ignore # noqa + huggingface_hub.constants.HF_HUB_ENABLE_HF_TRANSFER = True + except ImportError: + pass + + +enable_hf_transfer() + + +class Disabledtqdm(tqdm): + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs, disable=True) + + +def get_lock(model_name_or_path: str, cache_dir: Optional[str] = None): + lock_dir = cache_dir or temp_dir + os.makedirs(os.path.dirname(lock_dir), exist_ok=True) + model_name = model_name_or_path.replace("/", "-") + hash_name = hashlib.sha256(model_name.encode()).hexdigest() + # add hash to avoid conflict with old users' lock files + lock_file_name = hash_name + model_name + ".lock" + # mode 0o666 is required for the filelock to be shared across users + lock = filelock.FileLock(os.path.join(lock_dir, lock_file_name), + mode=0o666) + return lock + + +def _shared_pointers(tensors): + ptrs = defaultdict(list) + for k, v in tensors.items(): + ptrs[v.data_ptr()].append(k) + failing = [] + for _, names in ptrs.items(): + if len(names) > 1: + failing.append(names) + return failing + + +def convert_bin_to_safetensor_file( + pt_filename: str, + sf_filename: str, +) -> None: + loaded = torch.load(pt_filename, map_location="cpu") + if "state_dict" in loaded: + loaded = loaded["state_dict"] + shared = _shared_pointers(loaded) + for shared_weights in shared: + for name in shared_weights[1:]: + loaded.pop(name) + + # For tensors to be contiguous + loaded = {k: v.contiguous() for k, v in loaded.items()} + + dirname = os.path.dirname(sf_filename) + os.makedirs(dirname, exist_ok=True) + save_file(loaded, sf_filename, metadata={"format": "pt"}) + + # check file size + sf_size = os.stat(sf_filename).st_size + pt_size = os.stat(pt_filename).st_size + if (sf_size - pt_size) / pt_size > 0.01: + raise RuntimeError(f"""The file size different is more than 1%: + - {sf_filename}: {sf_size} + - {pt_filename}: {pt_size} + """) + + # check if the tensors are the same + 
reloaded = load_file(sf_filename) + for k in loaded: + pt_tensor = loaded[k] + sf_tensor = reloaded[k] + if not torch.equal(pt_tensor, sf_tensor): + raise RuntimeError(f"The output tensors do not match for key {k}") + + +# TODO(woosuk): Move this to other place. +def get_quant_config(model_config: ModelConfig) -> QuantizationConfig: + quant_cls = get_quantization_config(model_config.quantization) + # Read the quantization config from the HF model config, if available. + hf_quant_config = getattr(model_config.hf_config, "quantization_config", + None) + if hf_quant_config is not None: + return quant_cls.from_config(hf_quant_config) + model_name_or_path = model_config.model + is_local = os.path.isdir(model_name_or_path) + if not is_local: + # Download the config files. + with get_lock(model_name_or_path, model_config.download_dir): + hf_folder = snapshot_download(model_name_or_path, + revision=model_config.revision, + allow_patterns="*.json", + cache_dir=model_config.download_dir, + tqdm_class=Disabledtqdm) + else: + hf_folder = model_name_or_path + config_files = glob.glob(os.path.join(hf_folder, "*.json")) + + quant_config_files = [ + f for f in config_files if any( + f.endswith(x) for x in quant_cls.get_config_filenames()) + ] + if len(quant_config_files) == 0: + raise ValueError( + f"Cannot find the config file for {model_config.quantization}") + if len(quant_config_files) > 1: + raise ValueError( + f"Found multiple config files for {model_config.quantization}: " + f"{quant_config_files}") + + quant_config_file = quant_config_files[0] + with open(quant_config_file, "r") as f: + config = json.load(f) + return quant_cls.from_config(config) + + +def prepare_hf_model_weights( + model_name_or_path: str, + cache_dir: Optional[str] = None, + load_format: str = "auto", + fall_back_to_pt: bool = True, + revision: Optional[str] = None, +) -> Tuple[str, List[str], bool]: + # Download model weights from huggingface. + is_local = os.path.isdir(model_name_or_path) \ + and load_format != "tensorizer" + use_safetensors = False + # Some quantized models use .pt files for storing the weights. + if load_format == "auto": + allow_patterns = ["*.safetensors", "*.bin"] + elif load_format == "safetensors": + use_safetensors = True + allow_patterns = ["*.safetensors"] + elif load_format == "pt": + allow_patterns = ["*.pt"] + elif load_format == "npcache": + allow_patterns = ["*.bin"] + elif load_format == "tensorizer": + allow_patterns = ["*.tensors"] + else: + raise ValueError(f"Unknown load_format: {load_format}") + + if fall_back_to_pt: + allow_patterns += ["*.pt"] + + if not is_local and load_format != "tensorizer": + # Before we download we look at that is available: + fs = HfFileSystem() + file_list = fs.ls(model_name_or_path, detail=False, revision=revision) + + # depending on what is available we download different things + for pattern in allow_patterns: + matching = fnmatch.filter(file_list, pattern) + if len(matching) > 0: + allow_patterns = [pattern] + break + + logger.info(f"Using model weights format {allow_patterns}") + # Use file lock to prevent multiple processes from + # downloading the same model weights at the same time. 
+ with get_lock(model_name_or_path, cache_dir): + hf_folder = snapshot_download(model_name_or_path, + allow_patterns=allow_patterns, + cache_dir=cache_dir, + tqdm_class=Disabledtqdm, + revision=revision) + else: + hf_folder = model_name_or_path + hf_weights_files: List[str] = [] + for pattern in allow_patterns: + hf_weights_files += glob.glob(os.path.join(hf_folder, pattern)) + if len(hf_weights_files) > 0: + if pattern == "*.safetensors": + use_safetensors = True + break + if not use_safetensors: + # Exclude files that are not needed for inference. + # https://github.com/huggingface/transformers/blob/v4.34.0/src/transformers/trainer.py#L227-L233 + blacklist = [ + "training_args.bin", + "optimizer.bin", + "optimizer.pt", + "scheduler.pt", + "scaler.pt", + ] + hf_weights_files = [ + f for f in hf_weights_files + if not any(f.endswith(x) for x in blacklist) + ] + + if load_format == "tensorizer": + return hf_folder, hf_weights_files, use_safetensors + + if len(hf_weights_files) == 0: + raise RuntimeError( + f"Cannot find any model weights with `{model_name_or_path}`") + + return hf_folder, hf_weights_files, use_safetensors + + +def hf_model_weights_iterator( + model_name_or_path: str, + cache_dir: Optional[str] = None, + load_format: Union[Tuple, str] = "auto", + revision: Optional[str] = None, + fall_back_to_pt: Optional[bool] = True, +) -> Iterator[Tuple[str, torch.Tensor]]: + hf_folder, hf_weights_files, use_safetensors = prepare_hf_model_weights( + model_name_or_path, + cache_dir=cache_dir, + load_format=load_format, + fall_back_to_pt=fall_back_to_pt, + revision=revision) + + if load_format == "npcache": + # Currently np_cache only support *.bin checkpoints + assert use_safetensors is False + + # Convert the model weights from torch tensors to numpy arrays for + # faster loading. + np_folder = os.path.join(hf_folder, "np") + os.makedirs(np_folder, exist_ok=True) + weight_names_file = os.path.join(np_folder, "weight_names.json") + # Use file lock to prevent multiple processes from + # dumping the same model weights to numpy at the same time. + with get_lock(model_name_or_path, cache_dir): + if not os.path.exists(weight_names_file): + weight_names = [] + for bin_file in hf_weights_files: + state = torch.load(bin_file, map_location="cpu") + for name, param in state.items(): + param_path = os.path.join(np_folder, name) + with open(param_path, "wb") as f: + np.save(f, param.cpu().detach().numpy()) + weight_names.append(name) + with open(weight_names_file, "w") as f: + json.dump(weight_names, f) + + with open(weight_names_file, "r") as f: + weight_names = json.load(f) + + for name in weight_names: + param_path = os.path.join(np_folder, name) + with open(param_path, "rb") as f: + param = np.load(f) + yield name, torch.from_numpy(param) + elif load_format == "tensorizer": + from vllm.model_executor.tensorizer_loader import (TensorDeserializer, + open_stream, + tensorizer_warning) + tensorizer_args = load_format.params + tensorizer_warning( + "Deserializing HuggingFace models is not optimized for " + "loading on vLLM, as tensorizer is forced to load to CPU. " + "Consider deserializing a vLLM model instead for faster " + "load times. 
See the examples/tensorize_vllm_model.py example " + "script for serializing vLLM models.") + + deserializer_args = tensorizer_args.deserializer_params + stream_params = tensorizer_args.stream_params + stream = open_stream(tensorizer_args.tensorizer_uri, **stream_params) + with TensorDeserializer(stream, **deserializer_args, + device="cpu") as state: + for name, param in state.items(): + yield name, param + del state + elif use_safetensors: + for st_file in hf_weights_files: + with safe_open(st_file, framework="pt") as f: + for name in f.keys(): # noqa: SIM118 + param = f.get_tensor(name) + yield name, param + else: + for bin_file in hf_weights_files: + state = torch.load(bin_file, map_location="cpu") + for name, param in state.items(): + yield name, param + del state + torch.cuda.empty_cache() + + +def kv_cache_scales_loader( + filename: str, tp_rank: int, tp_size: int, num_hidden_layers: int, + model_type: Optional[str]) -> Iterable[Tuple[int, float]]: + """ + A simple utility to read in KV cache scaling factors that have been + previously serialized to disk. Used by the model to populate the appropriate + KV cache scaling factors. The serialization should represent a dictionary + whose keys are the TP ranks and values are another dictionary mapping layers + to their KV cache scaling factors. + Keep this function in sync with the output of examples/fp8/extract_scales.py + """ + try: + with open(filename) as f: + context = { + "model_type": model_type, + "num_hidden_layers": num_hidden_layers, + "tp_rank": tp_rank, + "tp_size": tp_size, + } + schema_dct = json.load(f) + schema = QuantParamSchema.model_validate(schema_dct, + context=context) + layer_scales_map = schema.kv_cache.scaling_factor[tp_rank] + return layer_scales_map.items() + + except FileNotFoundError: + logger.error(f"File or directory '{filename}' not found.") + except json.JSONDecodeError: + logger.error(f"Error decoding JSON in file '{filename}'.") + except Exception as e: + logger.error(f"An error occurred while reading '{filename}': {e}") + # This section is reached if and only if any of the excepts are hit + # Return an empty iterable (list) => no KV cache scales are loaded + # which ultimately defaults to 1.0 scales + logger.warning("Defaulting to KV cache scaling factors = 1.0 " + f"for all layers in TP rank {tp_rank} " + "as an error occurred during loading.") + return [] + + +def convert_pyslice_to_tensor(x: Any) -> torch.Tensor: + """convert PySafeSlice object from safetensors to torch.Tensor + + PySafeSlice object supports indexing, which is done before loading the + actual tensor and can reduce the amount of memory being read into the + memory. However, it does not support more advanced functionalities + like `.view()` or `.t()`. Therefore, if we need to modify the loaded + tensor with these more complicated operators, we need to convert to + tensor first. + """ + if not isinstance(x, torch.Tensor): + x = x[:] + return x + + +def default_weight_loader(param: torch.Tensor, + loaded_weight: torch.Tensor) -> None: + """Default weight loader.""" + assert param.size() == loaded_weight.size() + param.data.copy_(loaded_weight) + + +def initialize_dummy_weights( + model: torch.nn.Module, + low: float = -1e-3, + high: float = 1e-3, +) -> None: + """Initialize model weights with random values. + + The model weights must be randomly initialized for accurate performance + measurements. Additionally, the model weights should not cause NaNs in the + forward pass. 
We empirically found that initializing the weights with + values between -1e-3 and 1e-3 works well for most models. + """ + for param in model.state_dict().values(): + if torch.is_floating_point(param): + param.data.uniform_(low, high) \ No newline at end of file diff --git a/python/sglang/test/test_programs.py b/python/sglang/test/test_programs.py index a1f29e4c3..32c319166 100644 --- a/python/sglang/test/test_programs.py +++ b/python/sglang/test/test_programs.py @@ -226,7 +226,7 @@ Action 3: Finish [United States].\n def test_parallel_decoding(): max_tokens = 64 - number = 5 + fork_size = 5 @sgl.function def parallel_decoding(s, topic): @@ -234,17 +234,17 @@ def test_parallel_decoding(): s += "USER: Give some tips for " + topic + ".\n" s += ( "ASSISTANT: Okay. Here are " - + str(number) + + str(fork_size) + " concise tips, each under 8 words:\n" ) # Generate skeleton - for i in range(1, 1 + number): + for i in range(1, 1 + fork_size): s += f"{i}." + sgl.gen(max_tokens=16, stop=[".", "\n"]) + ".\n" # Generate detailed tips - forks = s.fork(number) - for i in range(number): + forks = s.fork(fork_size) + for i in range(fork_size): forks[ i ] += f"Now, I expand tip {i+1} into a detailed paragraph:\nTip {i+1}:" @@ -253,7 +253,7 @@ def test_parallel_decoding(): # Concatenate tips and summarize s += "Here are these tips with detailed explanation:\n" - for i in range(number): + for i in range(fork_size): s += f"Tip {i+1}:" + forks[i]["detailed_tip"] + "\n" s += "\nIn summary," + sgl.gen("summary", max_tokens=512)
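Usage note for downstream callers: this patch renames the fork() parameter from `number` to `size` (defaulting to 1) in StreamExecutor.fork, TracerProgramState.fork, and ProgramState.fork, so code that passed the argument by keyword must be updated; positional calls keep working. Below is a minimal sketch of the new calling pattern, modeled on the parallel-decoding test above — the program name, prompt text, and generation parameters are illustrative only and not part of the patch.

import sglang as sgl

@sgl.function
def parallel_tips(s, topic):
    s += "Give 3 concise tips for " + topic + ".\n"
    # Renamed keyword: fork(size=...) replaces fork(number=...); positional fork(3) is unchanged.
    forks = s.fork(size=3)
    for i in range(3):
        forks[i] += f"Expand tip {i + 1} into a short paragraph:\n"
        forks[i] += sgl.gen("detail", max_tokens=64)
    # Accessing forks[i]["detail"] waits for that branch's generation to finish.
    for i in range(3):
        s += f"Tip {i + 1}: " + forks[i]["detail"] + "\n"

# Example invocation (assumes a running sglang backend):
# state = parallel_tips.run(topic="staying healthy")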