From 564a898ad975192b593be81387d11faf15cb1d3e Mon Sep 17 00:00:00 2001 From: Liangsheng Yin Date: Sat, 13 Jul 2024 23:39:37 -0700 Subject: [PATCH] Optimize mem indices mangement (#619) --- benchmark/latency_throughput/bench_one.py | 9 +- python/sglang/backend/runtime_endpoint.py | 18 +- python/sglang/bench_latency.py | 1 - python/sglang/global_config.py | 1 + python/sglang/lang/chat_template.py | 4 +- python/sglang/lang/ir.py | 6 +- .../managers/controller/cuda_graph_runner.py | 48 +++- .../srt/managers/controller/infer_batch.py | 46 ++-- .../srt/managers/controller/model_runner.py | 19 +- .../srt/managers/controller/radix_cache.py | 3 +- .../srt/managers/controller/tp_worker.py | 4 +- python/sglang/srt/memory_pool.py | 32 +-- python/sglang/srt/models/minicpm.py | 9 +- python/sglang/srt/models/qwen2_moe.py | 227 ++++++++++-------- python/sglang/srt/utils.py | 2 +- 15 files changed, 251 insertions(+), 178 deletions(-) diff --git a/benchmark/latency_throughput/bench_one.py b/benchmark/latency_throughput/bench_one.py index b912d3a02..36ae8a436 100644 --- a/benchmark/latency_throughput/bench_one.py +++ b/benchmark/latency_throughput/bench_one.py @@ -17,7 +17,8 @@ def run_one_batch_size(bs): if args.input_len: input_ids = [ - [int(x) for x in np.random.randint(0, high=16384, size=(args.input_len,))] for _ in range(bs) + [int(x) for x in np.random.randint(0, high=16384, size=(args.input_len,))] + for _ in range(bs) ] else: text = [f"{i, }" for i in range(bs)] @@ -116,9 +117,11 @@ if __name__ == "__main__": parser.add_argument("--port", type=int, default=None) parser.add_argument("--backend", type=str, default="srt") parser.add_argument("--input-len", type=int, default=None) - parser.add_argument("--batch-size", type=int, nargs='*', default=[1]) + parser.add_argument("--batch-size", type=int, nargs="*", default=[1]) parser.add_argument("--max-tokens", type=int, default=256) - parser.add_argument("--vllm-model-name", type=str, default="meta-llama/Meta-Llama-3-70B") + parser.add_argument( + "--vllm-model-name", type=str, default="meta-llama/Meta-Llama-3-70B" + ) args = parser.parse_args() if args.port is None: diff --git a/python/sglang/backend/runtime_endpoint.py b/python/sglang/backend/runtime_endpoint.py index da27a57e9..d845e8116 100644 --- a/python/sglang/backend/runtime_endpoint.py +++ b/python/sglang/backend/runtime_endpoint.py @@ -12,7 +12,6 @@ from sglang.utils import http_request class RuntimeEndpoint(BaseBackend): - def __init__( self, base_url: str, @@ -38,7 +37,8 @@ class RuntimeEndpoint(BaseBackend): self.model_info = res.json() self.chat_template = get_chat_template_by_model_path( - self.model_info["model_path"]) + self.model_info["model_path"] + ) def get_model_name(self): return self.model_info["model_path"] @@ -124,7 +124,12 @@ class RuntimeEndpoint(BaseBackend): else: raise RuntimeError(f"Invalid dtype: {sampling_params.dtype}") - for item in ["return_logprob", "logprob_start_len", "top_logprobs_num", "return_text_in_logprobs"]: + for item in [ + "return_logprob", + "logprob_start_len", + "top_logprobs_num", + "return_text_in_logprobs", + ]: value = getattr(sampling_params, item, None) if value is not None: data[item] = value @@ -171,7 +176,12 @@ class RuntimeEndpoint(BaseBackend): else: raise RuntimeError(f"Invalid dtype: {sampling_params.dtype}") - for item in ["return_logprob", "logprob_start_len", "top_logprobs_num", "return_text_in_logprobs"]: + for item in [ + "return_logprob", + "logprob_start_len", + "top_logprobs_num", + "return_text_in_logprobs", + ]: value = 
getattr(sampling_params, item, None) if value is not None: data[item] = value diff --git a/python/sglang/bench_latency.py b/python/sglang/bench_latency.py index 23ec11a34..c4c6d0ecf 100644 --- a/python/sglang/bench_latency.py +++ b/python/sglang/bench_latency.py @@ -32,7 +32,6 @@ import logging import multiprocessing import time - import numpy as np import torch import torch.distributed as dist diff --git a/python/sglang/global_config.py b/python/sglang/global_config.py index 662cb4a6f..ba2895a9d 100644 --- a/python/sglang/global_config.py +++ b/python/sglang/global_config.py @@ -44,4 +44,5 @@ class GlobalConfig: # adjust_cache: Adjust the position embedding of KV cache. self.concate_and_append_mode = "no_adjust" + global_config = GlobalConfig() diff --git a/python/sglang/lang/chat_template.py b/python/sglang/lang/chat_template.py index 273eb8c3b..bfde4bbdb 100644 --- a/python/sglang/lang/chat_template.py +++ b/python/sglang/lang/chat_template.py @@ -84,7 +84,7 @@ register_chat_template( "system": ("SYSTEM:", "\n"), "user": ("USER:", "\n"), "assistant": ("ASSISTANT:", "\n"), - } + }, ) ) @@ -177,7 +177,7 @@ register_chat_template( "assistant": ("", "<|im_end|>\n"), }, style=ChatTemplateStyle.PLAIN, - stop_str=("<|im_end|>",) + stop_str=("<|im_end|>",), ) ) diff --git a/python/sglang/lang/ir.py b/python/sglang/lang/ir.py index 83c6f79b0..e5d5e837a 100644 --- a/python/sglang/lang/ir.py +++ b/python/sglang/lang/ir.py @@ -24,9 +24,9 @@ class SglSamplingParams: presence_penalty: float = 0.0 ignore_eos: bool = False return_logprob: Optional[bool] = None - logprob_start_len: Optional[int] = None, - top_logprobs_num: Optional[int] = None, - return_text_in_logprobs: Optional[bool] = None, + logprob_start_len: Optional[int] = (None,) + top_logprobs_num: Optional[int] = (None,) + return_text_in_logprobs: Optional[bool] = (None,) # for constrained generation, not included in to_xxx_kwargs dtype: Optional[str] = None diff --git a/python/sglang/srt/managers/controller/cuda_graph_runner.py b/python/sglang/srt/managers/controller/cuda_graph_runner.py index 2e37e55b5..7218936be 100644 --- a/python/sglang/srt/managers/controller/cuda_graph_runner.py +++ b/python/sglang/srt/managers/controller/cuda_graph_runner.py @@ -8,7 +8,10 @@ from vllm.distributed.parallel_state import graph_capture from sglang.global_config import global_config from sglang.srt.layers.logits_processor import LogitProcessorOutput from sglang.srt.managers.controller.infer_batch import ( - Batch, ForwardMode, InputMetadata, init_flashinfer_args + Batch, + ForwardMode, + InputMetadata, + init_flashinfer_args, ) @@ -24,18 +27,28 @@ class CudaGraphRunner: # Common inputs self.max_bs = max_batch_size_to_capture self.input_ids = torch.zeros((self.max_bs,), dtype=torch.int32, device="cuda") - self.req_pool_indices = torch.zeros((self.max_bs,), dtype=torch.int32, device="cuda") + self.req_pool_indices = torch.zeros( + (self.max_bs,), dtype=torch.int32, device="cuda" + ) self.seq_lens = torch.ones((self.max_bs,), dtype=torch.int32, device="cuda") - self.position_ids_offsets = torch.zeros((self.max_bs,), dtype=torch.int32, device="cuda") - self.out_cache_loc = torch.zeros((self.max_bs,), dtype=torch.int32, device="cuda") + self.position_ids_offsets = torch.zeros( + (self.max_bs,), dtype=torch.int32, device="cuda" + ) + self.out_cache_loc = torch.zeros( + (self.max_bs,), dtype=torch.int32, device="cuda" + ) # FlashInfer inputs - self.flashinfer_workspace_buffer = self.model_runner.flashinfer_workspace_buffers[0] + self.flashinfer_workspace_buffer = 
( + self.model_runner.flashinfer_workspace_buffers[0] + ) self.flashinfer_kv_indptr = torch.zeros( (self.max_bs + 1,), dtype=torch.int32, device="cuda" ) self.flashinfer_kv_indices = torch.zeros( - (self.max_bs * model_runner.model_config.context_len,), dtype=torch.int32, device="cuda" + (self.max_bs * model_runner.model_config.context_len,), + dtype=torch.int32, + device="cuda", ) self.flashinfer_kv_last_page_len = torch.ones( (self.max_bs,), dtype=torch.int32, device="cuda" @@ -49,7 +62,12 @@ class CudaGraphRunner: with graph_capture() as graph_capture_context: self.stream = graph_capture_context.stream for bs in batch_size_list: - graph, input_buffers, output_buffers, flashinfer_handler = self.capture_one_batch_size(bs) + ( + graph, + input_buffers, + output_buffers, + flashinfer_handler, + ) = self.capture_one_batch_size(bs) self.graphs[bs] = graph self.input_buffers[bs] = input_buffers self.output_buffers[bs] = output_buffers @@ -71,17 +89,19 @@ class CudaGraphRunner: # FlashInfer inputs if not _grouped_size_compiled_for_decode_kernels( - self.model_runner.model_config.num_attention_heads // self.model_runner.tp_size, + self.model_runner.model_config.num_attention_heads + // self.model_runner.tp_size, self.model_runner.model_config.get_num_kv_heads(self.model_runner.tp_size), ): use_tensor_cores = True else: use_tensor_cores = False flashinfer_decode_wrapper = BatchDecodeWithPagedKVCacheWrapper( - self.flashinfer_workspace_buffer, "NHD", + self.flashinfer_workspace_buffer, + "NHD", use_cuda_graph=True, use_tensor_cores=use_tensor_cores, - paged_kv_indptr_buffer=self.flashinfer_kv_indptr[:bs+1], + paged_kv_indptr_buffer=self.flashinfer_kv_indptr[: bs + 1], paged_kv_indices_buffer=self.flashinfer_kv_indices, paged_kv_last_page_len_buffer=self.flashinfer_kv_last_page_len[:bs], ) @@ -163,10 +183,14 @@ class CudaGraphRunner: else: output = LogitProcessorOutput( next_token_logits=output.next_token_logits[:raw_bs], - next_token_logprobs=output.next_token_logprobs[:raw_bs] if output.next_token_logprobs is not None else None, + next_token_logprobs=output.next_token_logprobs[:raw_bs] + if output.next_token_logprobs is not None + else None, normalized_prompt_logprobs=None, prefill_token_logprobs=None, prefill_top_logprobs=None, - decode_top_logprobs=output.decode_top_logprobs[:raw_bs] if output.decode_top_logprobs is not None else None, + decode_top_logprobs=output.decode_top_logprobs[:raw_bs] + if output.decode_top_logprobs is not None + else None, ) return output diff --git a/python/sglang/srt/managers/controller/infer_batch.py b/python/sglang/srt/managers/controller/infer_batch.py index d89e9786e..375ec6eeb 100644 --- a/python/sglang/srt/managers/controller/infer_batch.py +++ b/python/sglang/srt/managers/controller/infer_batch.py @@ -668,7 +668,9 @@ class Batch: sampled_index = torch.multinomial(probs_sort, num_samples=1) except RuntimeError as e: warnings.warn(f"Ignore errors in sampling: {e}") - sampled_index = torch.ones(probs_sort.shape[:-1] + (1,), dtype=torch.int64, device=probs.device) + sampled_index = torch.ones( + probs_sort.shape[:-1] + (1,), dtype=torch.int64, device=probs.device + ) batch_next_token_ids = torch.gather(probs_idx, dim=1, index=sampled_index).view( -1 ) @@ -749,8 +751,14 @@ class InputMetadata: skip_flashinfer_init=False, ): if not skip_flashinfer_init and not model_runner.server_args.disable_flashinfer: - init_flashinfer_args(forward_mode, model_runner, req_pool_indices, seq_lens, prefix_lens, - model_runner.flashinfer_decode_wrapper) + init_flashinfer_args( + 
forward_mode, + model_runner, + req_pool_indices, + seq_lens, + prefix_lens, + model_runner.flashinfer_decode_wrapper, + ) batch_size = len(req_pool_indices) @@ -807,16 +815,24 @@ class InputMetadata: ) if model_runner.server_args.disable_flashinfer: - (ret.triton_max_seq_len, - ret.triton_max_extend_len, - ret.triton_start_loc, - ret.triton_prefix_lens) = init_triton_args(forward_mode, seq_lens, prefix_lens) + ( + ret.triton_max_seq_len, + ret.triton_max_extend_len, + ret.triton_start_loc, + ret.triton_prefix_lens, + ) = init_triton_args(forward_mode, seq_lens, prefix_lens) return ret -def init_flashinfer_args(forward_mode, model_runner, req_pool_indices, seq_lens, prefix_lens, - flashinfer_decode_wrapper): +def init_flashinfer_args( + forward_mode, + model_runner, + req_pool_indices, + seq_lens, + prefix_lens, + flashinfer_decode_wrapper, +): num_qo_heads = model_runner.model_config.num_attention_heads // model_runner.tp_size num_kv_heads = model_runner.model_config.get_num_kv_heads(model_runner.tp_size) head_dim = model_runner.model_config.head_dim @@ -827,9 +843,7 @@ def init_flashinfer_args(forward_mode, model_runner, req_pool_indices, seq_lens, else: paged_kernel_lens = prefix_lens - kv_indptr = torch.zeros( - (batch_size + 1,), dtype=torch.int32, device="cuda" - ) + kv_indptr = torch.zeros((batch_size + 1,), dtype=torch.int32, device="cuda") kv_indptr[1:] = torch.cumsum(paged_kernel_lens, dim=0) req_pool_indices_cpu = req_pool_indices.cpu().numpy() paged_kernel_lens_cpu = paged_kernel_lens.cpu().numpy() @@ -842,9 +856,7 @@ def init_flashinfer_args(forward_mode, model_runner, req_pool_indices, seq_lens, ], dim=0, ).contiguous() - kv_last_page_len = torch.ones( - (batch_size,), dtype=torch.int32, device="cuda" - ) + kv_last_page_len = torch.ones((batch_size,), dtype=torch.int32, device="cuda") if forward_mode == ForwardMode.DECODE: flashinfer_decode_wrapper.end_forward() @@ -859,9 +871,7 @@ def init_flashinfer_args(forward_mode, model_runner, req_pool_indices, seq_lens, ) else: # extend part - qo_indptr = torch.zeros( - (batch_size + 1,), dtype=torch.int32, device="cuda" - ) + qo_indptr = torch.zeros((batch_size + 1,), dtype=torch.int32, device="cuda") qo_indptr[1:] = torch.cumsum(seq_lens - prefix_lens, dim=0) model_runner.flashinfer_prefill_wrapper_ragged.end_forward() diff --git a/python/sglang/srt/managers/controller/model_runner.py b/python/sglang/srt/managers/controller/model_runner.py index 315dd4d66..d68d9af32 100644 --- a/python/sglang/srt/managers/controller/model_runner.py +++ b/python/sglang/srt/managers/controller/model_runner.py @@ -16,7 +16,12 @@ from vllm.model_executor.model_loader import get_model from vllm.model_executor.models import ModelRegistry from sglang.global_config import global_config -from sglang.srt.managers.controller.infer_batch import Batch, ForwardMode, InputMetadata, global_server_args_dict +from sglang.srt.managers.controller.infer_batch import ( + Batch, + ForwardMode, + InputMetadata, + global_server_args_dict, +) from sglang.srt.memory_pool import ReqToTokenPool, TokenToKVPool from sglang.srt.server_args import ServerArgs from sglang.srt.utils import ( @@ -83,7 +88,9 @@ class ModelRunner: # Set some global args global_server_args_dict["disable_flashinfer"] = server_args.disable_flashinfer - global_server_args_dict["attention_reduce_in_fp32"] = server_args.attention_reduce_in_fp32 + global_server_args_dict[ + "attention_reduce_in_fp32" + ] = server_args.attention_reduce_in_fp32 # Load the model and create memory pool self.load_model() @@ -217,7 
+224,9 @@ class ModelRunner: self.flashinfer_workspace_buffers[1], "NHD" ) self.flashinfer_decode_wrapper = BatchDecodeWithPagedKVCacheWrapper( - self.flashinfer_workspace_buffers[0], "NHD", use_tensor_cores=use_tensor_cores + self.flashinfer_workspace_buffers[0], + "NHD", + use_tensor_cores=use_tensor_cores, ) def init_cuda_graphs(self): @@ -229,7 +238,9 @@ class ModelRunner: logger.info(f"[gpu_id={self.gpu_id}] Capture cuda graph begin.") batch_size_list = [1, 2, 4] + [i * 8 for i in range(1, 16)] - self.cuda_graph_runner = CudaGraphRunner(self, max_batch_size_to_capture=max(batch_size_list)) + self.cuda_graph_runner = CudaGraphRunner( + self, max_batch_size_to_capture=max(batch_size_list) + ) self.cuda_graph_runner.capture(batch_size_list) @torch.inference_mode() diff --git a/python/sglang/srt/managers/controller/radix_cache.py b/python/sglang/srt/managers/controller/radix_cache.py index ab8d6b446..bc7b758dd 100644 --- a/python/sglang/srt/managers/controller/radix_cache.py +++ b/python/sglang/srt/managers/controller/radix_cache.py @@ -125,7 +125,8 @@ class RadixCache: if x.lock_ref > 0: continue - num_evicted += evict_callback(x.value) + evict_callback(x.value) + num_evicted += len(x.value) self._delete_leaf(x) if len(x.parent.children) == 0: diff --git a/python/sglang/srt/managers/controller/tp_worker.py b/python/sglang/srt/managers/controller/tp_worker.py index b572e120e..12c278fd5 100644 --- a/python/sglang/srt/managers/controller/tp_worker.py +++ b/python/sglang/srt/managers/controller/tp_worker.py @@ -314,7 +314,9 @@ class ModelTpServer: self.forward_queue.append(req) def get_new_fill_batch(self) -> Optional[Batch]: - running_bs = len(self.running_batch.reqs) if self.running_batch is not None else 0 + running_bs = ( + len(self.running_batch.reqs) if self.running_batch is not None else 0 + ) if running_bs >= self.max_running_requests: return diff --git a/python/sglang/srt/memory_pool.py b/python/sglang/srt/memory_pool.py index 245e6ef08..46010ccf7 100644 --- a/python/sglang/srt/memory_pool.py +++ b/python/sglang/srt/memory_pool.py @@ -39,10 +39,12 @@ class ReqToTokenPool: class TokenToKVPool: def __init__(self, size, dtype, head_num, head_dim, layer_num): self.size = size - # mem_state is the reference counter. + # This can be promised: + # assert torch.all(mem_state <= 1) and torch.all(mem_state >= 0) # We also add one slot. This slot is used for writing dummy output from padded tokens. 
- self.mem_state = torch.zeros((self.size + 1,), dtype=torch.int16, device="cuda") - self.total_ref_ct = 0 + self.mem_state = torch.zeros((self.size + 1,), dtype=torch.bool, device="cuda") + self.total_size = self.size + self.total_alloc = 0 # [size, key/value, head_num, head_dim] for each layer self.kv_data = [ @@ -71,7 +73,9 @@ class TokenToKVPool: addition_size = need_size - buffer_len alloc_size = max(addition_size, self.prefetch_chunk_size) - select_index = torch.nonzero(self.mem_state == 0).squeeze(1)[:alloc_size].to(torch.int32) + select_index = ( + torch.nonzero(self.mem_state == 0).squeeze(1)[:alloc_size].to(torch.int32) + ) if select_index.shape[0] < addition_size: return None @@ -105,26 +109,22 @@ class TokenToKVPool: return select_index.to(torch.int32), start_loc, start_loc + need_size def used_size(self): - return len(torch.nonzero(self.mem_state).squeeze(1)) + return self.total_alloc def available_size(self): - return torch.sum(self.mem_state == 0).item() + len(self.prefetch_buffer) + return self.total_size - self.total_alloc + len(self.prefetch_buffer) def add_refs(self, token_index: torch.Tensor): - self.total_ref_ct += len(token_index) - self.mem_state[token_index] += 1 + self.total_alloc += len(token_index) + self.mem_state[token_index] ^= True def dec_refs(self, token_index: torch.Tensor): - self.total_ref_ct -= len(token_index) - self.mem_state[token_index] -= 1 - - num_freed = torch.sum(self.mem_state[token_index] == 0) - - return num_freed + self.total_alloc -= len(token_index) + self.mem_state[token_index] ^= True def clear(self): self.mem_state.fill_(0) - self.total_ref_ct = 0 + self.total_alloc = 0 # We also add one slot. This slot is used for writing dummy output from padded tokens. - self.add_refs(torch.tensor([0], dtype=torch.int32)) + self.mem_state[0] = True diff --git a/python/sglang/srt/models/minicpm.py b/python/sglang/srt/models/minicpm.py index 072bf99ab..3f16c95f9 100644 --- a/python/sglang/srt/models/minicpm.py +++ b/python/sglang/srt/models/minicpm.py @@ -5,12 +5,9 @@ from typing import Any, Dict, Iterable, Optional, Tuple import torch from torch import nn - from vllm.config import CacheConfig from vllm.distributed import get_tensor_model_parallel_world_size - from vllm.model_executor.layers.activation import SiluAndMul - from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.linear import ( MergedColumnParallelLinear, @@ -31,7 +28,6 @@ from sglang.srt.managers.controller.model_runner import InputMetadata class MiniCPMMLP(nn.Module): - def __init__( self, hidden_size: int, @@ -67,7 +63,6 @@ class MiniCPMMLP(nn.Module): class MiniCPMAttention(nn.Module): - def __init__( self, hidden_size: int, @@ -152,7 +147,6 @@ class MiniCPMAttention(nn.Module): class MiniCPMDecoderLayer(nn.Module): - def __init__( self, config, @@ -217,7 +211,6 @@ class MiniCPMDecoderLayer(nn.Module): class MiniCPMModel(nn.Module): - def __init__( self, config, @@ -274,7 +267,7 @@ class MiniCPMForCausalLM(nn.Module): ) -> None: super().__init__() self.config = config - + self.num_experts = getattr(self.config, "num_experts", 0) self.quant_config = quant_config self.model = MiniCPMModel(config, quant_config=quant_config) diff --git a/python/sglang/srt/models/qwen2_moe.py b/python/sglang/srt/models/qwen2_moe.py index 79187cd43..072002c6f 100644 --- a/python/sglang/srt/models/qwen2_moe.py +++ b/python/sglang/srt/models/qwen2_moe.py @@ -8,24 +8,28 @@ import torch import torch.nn.functional as F from torch import nn from transformers import 
PretrainedConfig - from vllm.config import CacheConfig -from vllm.distributed import (get_tensor_model_parallel_world_size, - tensor_model_parallel_all_reduce) +from vllm.distributed import ( + get_tensor_model_parallel_world_size, + tensor_model_parallel_all_reduce, +) from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.fused_moe import FusedMoE from vllm.model_executor.layers.layernorm import RMSNorm -from vllm.model_executor.layers.linear import (MergedColumnParallelLinear, - QKVParallelLinear, - ReplicatedLinear, - RowParallelLinear) +from vllm.model_executor.layers.linear import ( + MergedColumnParallelLinear, + QKVParallelLinear, + ReplicatedLinear, + RowParallelLinear, +) from vllm.model_executor.layers.logits_processor import LogitsProcessor -from vllm.model_executor.layers.quantization.base_config import ( - QuantizationConfig) +from vllm.model_executor.layers.quantization.base_config import QuantizationConfig from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( - ParallelLMHead, VocabParallelEmbedding) + ParallelLMHead, + VocabParallelEmbedding, +) from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.sequence import IntermediateTensors, SamplerOutput @@ -34,8 +38,8 @@ from sglang.srt.layers.logits_processor import LogitsProcessor from sglang.srt.layers.radix_attention import RadixAttention from sglang.srt.managers.controller.model_runner import InputMetadata -class Qwen2MoeMLP(nn.Module): +class Qwen2MoeMLP(nn.Module): def __init__( self, hidden_size: int, @@ -46,17 +50,20 @@ class Qwen2MoeMLP(nn.Module): ) -> None: super().__init__() self.gate_up_proj = MergedColumnParallelLinear( - hidden_size, [intermediate_size] * 2, + hidden_size, [intermediate_size] * 2, bias=False, quant_config=quant_config + ) + self.down_proj = RowParallelLinear( + intermediate_size, + hidden_size, bias=False, - quant_config=quant_config) - self.down_proj = RowParallelLinear(intermediate_size, - hidden_size, - bias=False, - quant_config=quant_config, - reduce_results=reduce_results) + quant_config=quant_config, + reduce_results=reduce_results, + ) if hidden_act != "silu": - raise ValueError(f"Unsupported activation: {hidden_act}. " - "Only silu is supported for now.") + raise ValueError( + f"Unsupported activation: {hidden_act}. " + "Only silu is supported for now." + ) self.act_fn = SiluAndMul() def forward(self, x): @@ -67,7 +74,6 @@ class Qwen2MoeMLP(nn.Module): class Qwen2MoeSparseMoeBlock(nn.Module): - def __init__( self, config: PretrainedConfig, @@ -79,20 +85,22 @@ class Qwen2MoeSparseMoeBlock(nn.Module): if self.tp_size > config.num_experts: raise ValueError( f"Tensor parallel size {self.tp_size} is greater than " - f"the number of experts {config.num_experts}.") + f"the number of experts {config.num_experts}." 
+ ) - self.experts = FusedMoE(num_experts=config.num_experts, - top_k=config.num_experts_per_tok, - hidden_size=config.hidden_size, - intermediate_size=config.moe_intermediate_size, - reduce_results=False, - renormalize=config.norm_topk_prob, - quant_config=quant_config) + self.experts = FusedMoE( + num_experts=config.num_experts, + top_k=config.num_experts_per_tok, + hidden_size=config.hidden_size, + intermediate_size=config.moe_intermediate_size, + reduce_results=False, + renormalize=config.norm_topk_prob, + quant_config=quant_config, + ) - self.gate = ReplicatedLinear(config.hidden_size, - config.num_experts, - bias=False, - quant_config=None) + self.gate = ReplicatedLinear( + config.hidden_size, config.num_experts, bias=False, quant_config=None + ) if config.shared_expert_intermediate_size > 0: self.shared_expert = Qwen2MoeMLP( hidden_size=config.hidden_size, @@ -103,9 +111,7 @@ class Qwen2MoeSparseMoeBlock(nn.Module): ) else: self.shared_expert = None - self.shared_expert_gate = torch.nn.Linear(config.hidden_size, - 1, - bias=False) + self.shared_expert_gate = torch.nn.Linear(config.hidden_size, 1, bias=False) def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: num_tokens, hidden_dim = hidden_states.shape @@ -114,24 +120,24 @@ class Qwen2MoeSparseMoeBlock(nn.Module): if self.shared_expert is not None: shared_output = self.shared_expert(hidden_states) if self.shared_expert_gate is not None: - shared_output = F.sigmoid( - self.shared_expert_gate(hidden_states)) * shared_output + shared_output = ( + F.sigmoid(self.shared_expert_gate(hidden_states)) * shared_output + ) # router_logits: (num_tokens, n_experts) router_logits, _ = self.gate(hidden_states) - final_hidden_states = self.experts(hidden_states=hidden_states, - router_logits=router_logits) + final_hidden_states = self.experts( + hidden_states=hidden_states, router_logits=router_logits + ) if shared_output is not None: final_hidden_states = final_hidden_states + shared_output if self.tp_size > 1: - final_hidden_states = tensor_model_parallel_all_reduce( - final_hidden_states) + final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states) return final_hidden_states.view(num_tokens, hidden_dim) class Qwen2MoeAttention(nn.Module): - def __init__( self, hidden_size: int, @@ -190,17 +196,19 @@ class Qwen2MoeAttention(nn.Module): base=rope_theta, rope_scaling=rope_scaling, ) - self.attn = RadixAttention(self.num_heads, - self.head_dim, - self.scaling, - num_kv_heads=self.num_kv_heads, - layer_id=layer_id) + self.attn = RadixAttention( + self.num_heads, + self.head_dim, + self.scaling, + num_kv_heads=self.num_kv_heads, + layer_id=layer_id, + ) def forward( self, positions: torch.Tensor, hidden_states: torch.Tensor, - input_metadata: InputMetadata + input_metadata: InputMetadata, ) -> torch.Tensor: qkv, _ = self.qkv_proj(hidden_states) q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) @@ -211,7 +219,6 @@ class Qwen2MoeAttention(nn.Module): class Qwen2MoeDecoderLayer(nn.Module): - def __init__( self, config: PretrainedConfig, @@ -223,8 +230,7 @@ class Qwen2MoeDecoderLayer(nn.Module): self.hidden_size = config.hidden_size rope_theta = getattr(config, "rope_theta", 10000) rope_scaling = getattr(config, "rope_scaling", None) - max_position_embeddings = getattr(config, "max_position_embeddings", - 8192) + max_position_embeddings = getattr(config, "max_position_embeddings", 8192) self.self_attn = Qwen2MoeAttention( hidden_size=self.hidden_size, num_heads=config.num_attention_heads, @@ -239,13 
+245,13 @@ class Qwen2MoeDecoderLayer(nn.Module): # Note: Qwen/Qwen2-57B-A14B-Instruct does not have # `mlp_only_layers` in the config. - mlp_only_layers = ([] if not hasattr(config, "mlp_only_layers") else - config.mlp_only_layers) + mlp_only_layers = ( + [] if not hasattr(config, "mlp_only_layers") else config.mlp_only_layers + ) if (layer_id not in mlp_only_layers) and ( - config.num_experts > 0 and - (layer_id + 1) % config.decoder_sparse_step == 0): - self.mlp = Qwen2MoeSparseMoeBlock(config=config, - quant_config=quant_config) + config.num_experts > 0 and (layer_id + 1) % config.decoder_sparse_step == 0 + ): + self.mlp = Qwen2MoeSparseMoeBlock(config=config, quant_config=quant_config) else: self.mlp = Qwen2MoeMLP( hidden_size=config.hidden_size, @@ -253,10 +259,10 @@ class Qwen2MoeDecoderLayer(nn.Module): hidden_act=config.hidden_act, quant_config=quant_config, ) - self.input_layernorm = RMSNorm(config.hidden_size, - eps=config.rms_norm_eps) - self.post_attention_layernorm = RMSNorm(config.hidden_size, - eps=config.rms_norm_eps) + self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.post_attention_layernorm = RMSNorm( + config.hidden_size, eps=config.rms_norm_eps + ) def forward( self, @@ -270,23 +276,20 @@ class Qwen2MoeDecoderLayer(nn.Module): residual = hidden_states hidden_states = self.input_layernorm(hidden_states) else: - hidden_states, residual = self.input_layernorm( - hidden_states, residual) + hidden_states, residual = self.input_layernorm(hidden_states, residual) hidden_states = self.self_attn( positions=positions, hidden_states=hidden_states, - input_metadata=input_metadata + input_metadata=input_metadata, ) # Fully Connected - hidden_states, residual = self.post_attention_layernorm( - hidden_states, residual) + hidden_states, residual = self.post_attention_layernorm(hidden_states, residual) hidden_states = self.mlp(hidden_states) return hidden_states, residual class Qwen2MoeModel(nn.Module): - def __init__( self, config: PretrainedConfig, @@ -301,13 +304,14 @@ class Qwen2MoeModel(nn.Module): config.vocab_size, config.hidden_size, ) - self.layers = nn.ModuleList([ - Qwen2MoeDecoderLayer(config, - layer_id, - cache_config, - quant_config=quant_config) - for layer_id in range(config.num_hidden_layers) - ]) + self.layers = nn.ModuleList( + [ + Qwen2MoeDecoderLayer( + config, layer_id, cache_config, quant_config=quant_config + ) + for layer_id in range(config.num_hidden_layers) + ] + ) self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) def forward( @@ -315,7 +319,7 @@ class Qwen2MoeModel(nn.Module): input_ids: torch.Tensor, positions: torch.Tensor, input_metadata: InputMetadata, - input_embeds: torch.Tensor = None + input_embeds: torch.Tensor = None, ) -> torch.Tensor: if input_embeds is None: hidden_states = self.embed_tokens(input_ids) @@ -324,10 +328,9 @@ class Qwen2MoeModel(nn.Module): residual = None for i in range(len(self.layers)): layer = self.layers[i] - hidden_states, residual = layer(positions, - hidden_states, - input_metadata, - residual) + hidden_states, residual = layer( + positions, hidden_states, input_metadata, residual + ) hidden_states, _ = self.norm(hidden_states, residual) return hidden_states @@ -346,9 +349,9 @@ class Qwen2MoeForCausalLM(nn.Module): self.config = config self.quant_config = quant_config self.model = Qwen2MoeModel(config, cache_config, quant_config) - self.lm_head = ParallelLMHead(config.vocab_size, - config.hidden_size, - quant_config=quant_config) + self.lm_head = ParallelLMHead( + 
config.vocab_size, config.hidden_size, quant_config=quant_config + ) self.logits_processor = LogitsProcessor(config) self.sampler = Sampler() @@ -357,17 +360,22 @@ class Qwen2MoeForCausalLM(nn.Module): input_ids: torch.Tensor, positions: torch.Tensor, input_metadata: InputMetadata, - input_embeds: torch.Tensor = None + input_embeds: torch.Tensor = None, ) -> torch.Tensor: - hidden_states = self.model(input_ids, positions, input_metadata, - input_embeds) - return self.logits_processor(input_ids, hidden_states, self.lm_head.weight, - input_metadata) + hidden_states = self.model(input_ids, positions, input_metadata, input_embeds) + return self.logits_processor( + input_ids, hidden_states, self.lm_head.weight, input_metadata + ) - def compute_logits(self, input_ids: torch.Tensor, hidden_states: torch.Tensor, - input_metadata: InputMetadata) -> torch.Tensor: - logits = self.logits_processor(input_ids, hidden_states, self.lm_head.weight, - input_metadata) + def compute_logits( + self, + input_ids: torch.Tensor, + hidden_states: torch.Tensor, + input_metadata: InputMetadata, + ) -> torch.Tensor: + logits = self.logits_processor( + input_ids, hidden_states, self.lm_head.weight, input_metadata + ) return logits def sample( @@ -391,11 +399,18 @@ class Qwen2MoeForCausalLM(nn.Module): expert_params_mapping = [ # These are the weights for the experts # (param_name, weight_name, expert_id, shard_id) - ("experts.w13_weight" if weight_name in ["gate_proj", "up_proj"] - else "experts.w2_weight", - f"experts.{expert_id}.{weight_name}.weight", expert_id, shard_id) - for expert_id in range(self.config.num_experts) for shard_id, - weight_name in enumerate(["gate_proj", "down_proj", "up_proj"]) + ( + "experts.w13_weight" + if weight_name in ["gate_proj", "up_proj"] + else "experts.w2_weight", + f"experts.{expert_id}.{weight_name}.weight", + expert_id, + shard_id, + ) + for expert_id in range(self.config.num_experts) + for shard_id, weight_name in enumerate( + ["gate_proj", "down_proj", "up_proj"] + ) ] params_dict = dict(self.named_parameters()) @@ -433,11 +448,13 @@ class Qwen2MoeForCausalLM(nn.Module): name = name.replace(weight_name, param_name) param = params_dict[name] weight_loader = param.weight_loader - weight_loader(param, - loaded_weight, - weight_name, - shard_id=shard_id, - expert_id=expert_id) + weight_loader( + param, + loaded_weight, + weight_name, + shard_id=shard_id, + expert_id=expert_id, + ) break else: # Skip loading extra bias for GPTQ models. @@ -447,8 +464,10 @@ class Qwen2MoeForCausalLM(nn.Module): continue param = params_dict[name] - weight_loader = getattr(param, "weight_loader", - default_weight_loader) + weight_loader = getattr( + param, "weight_loader", default_weight_loader + ) weight_loader(param, loaded_weight) + EntryClass = Qwen2MoeForCausalLM diff --git a/python/sglang/srt/utils.py b/python/sglang/srt/utils.py index 03a2d60ab..981b5e218 100644 --- a/python/sglang/srt/utils.py +++ b/python/sglang/srt/utils.py @@ -474,9 +474,9 @@ def monkey_patch_vllm_dummy_weight_loader(): DummyModelLoader, LoRAConfig, ModelConfig, + MultiModalConfig, ParallelConfig, SchedulerConfig, - MultiModalConfig, _initialize_model, initialize_dummy_weights, nn,
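
Note on the memory-pool change above: python/sglang/srt/memory_pool.py replaces the int16 per-slot reference counter (mem_state) with a boolean occupancy mask plus a running total_alloc counter. Since every slot is now held at most once (the invariant the new comment asserts: 0 <= mem_state <= 1), add_refs/dec_refs can simply toggle the mask with XOR, and used_size()/available_size() become O(1) counter reads instead of scans over mem_state. The sketch below is a minimal, self-contained illustration of that bookkeeping pattern only; the class name BoolIndexPool and the alloc() helper are illustrative, and it omits the prefetch buffer, contiguous allocation, and the KV tensors of the real TokenToKVPool.

    import torch

    class BoolIndexPool:
        """Illustrative sketch of the occupancy-mask bookkeeping (not part of the patch)."""

        def __init__(self, size: int):
            self.size = size
            # One extra slot is reserved for writing dummy output from padded tokens.
            self.mem_state = torch.zeros((size + 1,), dtype=torch.bool)
            self.total_alloc = 0
            self.clear()

        def alloc(self, need_size: int):
            # Free slots are where the mask is still False.
            free_index = torch.nonzero(~self.mem_state).squeeze(1)[:need_size]
            if free_index.shape[0] < need_size:
                return None
            self.add_refs(free_index)
            return free_index.to(torch.int32)

        def add_refs(self, token_index: torch.Tensor):
            self.total_alloc += len(token_index)
            self.mem_state[token_index] ^= True   # False -> True (slot taken)

        def dec_refs(self, token_index: torch.Tensor):
            self.total_alloc -= len(token_index)
            self.mem_state[token_index] ^= True   # True -> False (slot freed)

        def used_size(self):
            return self.total_alloc               # O(1), no scan of mem_state

        def available_size(self):
            return self.size - self.total_alloc   # O(1)

        def clear(self):
            self.mem_state.fill_(False)
            self.total_alloc = 0
            self.mem_state[0] = True              # keep the reserved dummy slot occupied

Because dec_refs no longer returns a freed count under this scheme, the radix_cache.py hunk computes the evicted length as len(x.value) directly instead of relying on the eviction callback's return value.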