diff --git a/benchmark/gsm8k/bench_other.py b/benchmark/gsm8k/bench_other.py index 8d61858c6..c80c17a24 100644 --- a/benchmark/gsm8k/bench_other.py +++ b/benchmark/gsm8k/bench_other.py @@ -65,7 +65,7 @@ def main(args): def get_one_answer(i): answer = call_generate( prompt=few_shot_examples + questions[i], - #prompt="System: " + few_shot_examples + "<|separator|>\n\n" + questions[i], + # prompt="System: " + few_shot_examples + "<|separator|>\n\n" + questions[i], temperature=0, max_tokens=256, stop="Question", diff --git a/benchmark/latency_throughput/bench_throughput.py b/benchmark/latency_throughput/bench_throughput.py index 530cfd3ab..286f1fb12 100644 --- a/benchmark/latency_throughput/bench_throughput.py +++ b/benchmark/latency_throughput/bench_throughput.py @@ -158,7 +158,9 @@ async def send_request( timeout = aiohttp.ClientTimeout(total=3 * 3600) async with aiohttp.ClientSession(timeout=timeout) as session: while True: - async with session.post(api_url, headers=headers, json=pload) as response: + async with session.post( + api_url, headers=headers, json=pload + ) as response: chunks = [] async for chunk, _ in response.content.iter_chunks(): chunks.append(chunk) @@ -228,19 +230,32 @@ def main(args: argparse.Namespace): np.random.seed(args.seed) api_url = f"http://{args.host}:{args.port}/generate" - tokenizer = AutoTokenizer.from_pretrained(args.tokenizer, trust_remote_code=args.trust_remote_code) + tokenizer = AutoTokenizer.from_pretrained( + args.tokenizer, trust_remote_code=args.trust_remote_code + ) if args.dataset: input_requests = sample_requests(args.dataset, args.num_prompts, tokenizer) else: input_lens = np.random.randint( - int(args.input_len * args.range_ratio), args.input_len + 1, size=args.num_prompts) + int(args.input_len * args.range_ratio), + args.input_len + 1, + size=args.num_prompts, + ) output_lens = np.random.randint( - int(args.output_len * args.range_ratio), args.output_len + 1, size=args.num_prompts) + int(args.output_len * args.range_ratio), + args.output_len + 1, + size=args.num_prompts, + ) offsets = np.random.randint(0, tokenizer.vocab_size, size=args.num_prompts) input_requests = [] for i in range(args.num_prompts): - prompt = tokenizer.decode([(offsets[i] + i + j) % tokenizer.vocab_size for j in range(input_lens[i])]) + prompt = tokenizer.decode( + [ + (offsets[i] + i + j) % tokenizer.vocab_size + for j in range(input_lens[i]) + ] + ) input_requests.append((prompt, int(input_lens[i]), int(output_lens[i]))) benchmark_start_time = time.perf_counter() @@ -287,16 +302,15 @@ if __name__ == "__main__": ) parser.add_argument("--host", type=str, default="localhost") parser.add_argument("--port", type=int, default=30000) - parser.add_argument( - "--dataset", type=str, help="Path to the dataset." - ) + parser.add_argument("--dataset", type=str, help="Path to the dataset.") parser.add_argument("--input-len", type=int, default=2048) parser.add_argument("--output-len", type=int, default=256) parser.add_argument("--range-ratio", type=float, default=1.0) parser.add_argument( - "--tokenizer", type=str, + "--tokenizer", + type=str, default="NousResearch/Meta-Llama-3-8B", - help="Name or path of the tokenizer." + help="Name or path of the tokenizer.", ) parser.add_argument( "--best-of", diff --git a/benchmark/mmlu/bench_other.py b/benchmark/mmlu/bench_other.py index 1799744f0..c5d48dac6 100644 --- a/benchmark/mmlu/bench_other.py +++ b/benchmark/mmlu/bench_other.py @@ -170,4 +170,4 @@ if __name__ == "__main__": parser.add_argument("--data_dir", type=str, default="data") parser.add_argument("--nsub", type=int, default=60) args = add_common_other_args_and_parse(parser) - main(args) \ No newline at end of file + main(args) diff --git a/python/sglang/__init__.py b/python/sglang/__init__.py index d6303c86a..556b9eb33 100644 --- a/python/sglang/__init__.py +++ b/python/sglang/__init__.py @@ -24,10 +24,10 @@ from sglang.api import ( # SGL Backends from sglang.backend.anthropic import Anthropic +from sglang.backend.litellm import LiteLLM from sglang.backend.openai import OpenAI from sglang.backend.runtime_endpoint import RuntimeEndpoint from sglang.backend.vertexai import VertexAI -from sglang.backend.litellm import LiteLLM # Global Configurations from sglang.global_config import global_config diff --git a/python/sglang/backend/litellm.py b/python/sglang/backend/litellm.py index dc89dc16d..eef6b0cda 100644 --- a/python/sglang/backend/litellm.py +++ b/python/sglang/backend/litellm.py @@ -33,7 +33,8 @@ class LiteLLM(BaseBackend): self.model_name = model_name self.chat_template = chat_template or get_chat_template_by_model_path( - model_name) + model_name + ) self.client_params = { "api_key": api_key, diff --git a/python/sglang/backend/openai.py b/python/sglang/backend/openai.py index 2cb5992d8..6f65f4eab 100644 --- a/python/sglang/backend/openai.py +++ b/python/sglang/backend/openai.py @@ -1,7 +1,7 @@ +import dataclasses import logging import time import warnings -import dataclasses from typing import Callable, List, Optional, Union import numpy as np @@ -105,14 +105,16 @@ class OpenAI(BaseBackend): def get_chat_template(self): return self.chat_template - def _prepare_spec_execution(self, sampling_params: SglSamplingParams, - num_api_spec_tokens: int, spec_var_name: str): + def _prepare_spec_execution( + self, + sampling_params: SglSamplingParams, + num_api_spec_tokens: int, + spec_var_name: str, + ): if "max_tokens" not in self.spec_kwargs: self.spec_kwargs["max_tokens"] = num_api_spec_tokens else: - assert ( - self.spec_kwargs["max_tokens"] == num_api_spec_tokens - ) + assert self.spec_kwargs["max_tokens"] == num_api_spec_tokens params = sampling_params.to_openai_kwargs() for key, value in params.items(): @@ -151,8 +153,9 @@ class OpenAI(BaseBackend): ) prompt = s.messages_ else: - return self._prepare_spec_execution(sampling_params, - s.num_api_spec_tokens, spec_var_name) + return self._prepare_spec_execution( + sampling_params, s.num_api_spec_tokens, spec_var_name + ) else: prompt = s.text_ @@ -325,7 +328,7 @@ class OpenAI(BaseBackend): ret_str = ret.choices[0].text ret_token = self.tokenizer.encode(ret_str)[0] self.token_usage.prompt_tokens += ret.usage.prompt_tokens - self.token_usage.completion_tokens= ret.usage.completion_tokens + self.token_usage.completion_tokens = ret.usage.completion_tokens # TODO: # 1. return logits as the scores @@ -355,7 +358,9 @@ class OpenAI(BaseBackend): return decision, scores, None, None -def openai_completion(client, token_usage, is_chat=None, retries=3, prompt=None, **kwargs): +def openai_completion( + client, token_usage, is_chat=None, retries=3, prompt=None, **kwargs +): for attempt in range(retries): try: if is_chat: @@ -385,15 +390,19 @@ def openai_completion(client, token_usage, is_chat=None, retries=3, prompt=None, return comp -def openai_completion_stream(client, token_usage, is_chat=None, retries=3, prompt=None, **kwargs): +def openai_completion_stream( + client, token_usage, is_chat=None, retries=3, prompt=None, **kwargs +): for attempt in range(retries): try: if is_chat: if "stop" in kwargs and kwargs["stop"] is None: kwargs.pop("stop") generator = client.chat.completions.create( - messages=prompt, stream=True, stream_options={"include_usage": True}, - **kwargs + messages=prompt, + stream=True, + stream_options={"include_usage": True}, + **kwargs, ) for ret in generator: if len(ret.choices) == 0: @@ -405,8 +414,10 @@ def openai_completion_stream(client, token_usage, is_chat=None, retries=3, promp yield content or "", {} else: generator = client.completions.create( - prompt=prompt, stream=True, stream_options={"include_usage": True}, - **kwargs + prompt=prompt, + stream=True, + stream_options={"include_usage": True}, + **kwargs, ) for ret in generator: if len(ret.choices) == 0: diff --git a/python/sglang/lang/interpreter.py b/python/sglang/lang/interpreter.py index 789879f00..4f5bfa3ed 100644 --- a/python/sglang/lang/interpreter.py +++ b/python/sglang/lang/interpreter.py @@ -507,7 +507,7 @@ class StreamExecutor: ) return - else: # Speculative execution on models with completion interface + else: # Speculative execution on models with completion interface comp, meta_info = self._spec_gen(sampling_params) self.text_ += comp diff --git a/python/sglang/lang/ir.py b/python/sglang/lang/ir.py index c2b041fe3..0567689e0 100644 --- a/python/sglang/lang/ir.py +++ b/python/sglang/lang/ir.py @@ -81,12 +81,10 @@ class SglSamplingParams: "top_p": self.top_p, "top_k": self.top_k, } - + def to_litellm_kwargs(self): if self.regex is not None: - warnings.warn( - "Regular expression is not supported in the LiteLLM backend." - ) + warnings.warn("Regular expression is not supported in the LiteLLM backend.") return { "max_tokens": self.max_new_tokens, "stop": self.stop or None, diff --git a/python/sglang/launch_server.py b/python/sglang/launch_server.py index 3b4ee3ed8..a8c1e2feb 100644 --- a/python/sglang/launch_server.py +++ b/python/sglang/launch_server.py @@ -10,4 +10,4 @@ if __name__ == "__main__": args = parser.parse_args() server_args = ServerArgs.from_cli_args(args) - launch_server(server_args, None) \ No newline at end of file + launch_server(server_args, None) diff --git a/python/sglang/launch_server_llavavid.py b/python/sglang/launch_server_llavavid.py index 294a4fa70..a048c3dec 100644 --- a/python/sglang/launch_server_llavavid.py +++ b/python/sglang/launch_server_llavavid.py @@ -1,4 +1,5 @@ """Launch the inference server for Llava-video model.""" + import argparse import multiprocessing as mp diff --git a/python/sglang/srt/constrained/__init__.py b/python/sglang/srt/constrained/__init__.py index ab6f56e5d..988f502c8 100644 --- a/python/sglang/srt/constrained/__init__.py +++ b/python/sglang/srt/constrained/__init__.py @@ -4,7 +4,7 @@ from typing import Dict, Optional, Union from outlines.caching import cache as disk_cache from outlines.caching import disable_cache from outlines.fsm.guide import RegexGuide -from outlines.fsm.regex import FSMInfo, make_deterministic_fsm, make_byte_level_fsm +from outlines.fsm.regex import FSMInfo, make_byte_level_fsm, make_deterministic_fsm from outlines.models.transformers import TransformerTokenizer from pydantic import BaseModel diff --git a/python/sglang/srt/constrained/fsm_cache.py b/python/sglang/srt/constrained/fsm_cache.py index 387ccf024..8789d5ffe 100644 --- a/python/sglang/srt/constrained/fsm_cache.py +++ b/python/sglang/srt/constrained/fsm_cache.py @@ -1,4 +1,5 @@ """Cache for the compressed finite state machine.""" + from sglang.srt.constrained import RegexGuide, TransformerTokenizer from sglang.srt.constrained.base_cache import BaseCache diff --git a/python/sglang/srt/constrained/jump_forward.py b/python/sglang/srt/constrained/jump_forward.py index 39356a71a..f17b187f3 100644 --- a/python/sglang/srt/constrained/jump_forward.py +++ b/python/sglang/srt/constrained/jump_forward.py @@ -8,11 +8,12 @@ from collections import defaultdict import interegular import outlines.caching + from sglang.srt.constrained import ( FSMInfo, disk_cache, - make_deterministic_fsm, make_byte_level_fsm, + make_deterministic_fsm, ) from sglang.srt.constrained.base_cache import BaseCache diff --git a/python/sglang/srt/conversation.py b/python/sglang/srt/conversation.py index 992c2021b..b3988cc82 100644 --- a/python/sglang/srt/conversation.py +++ b/python/sglang/srt/conversation.py @@ -1,4 +1,5 @@ """Conversation templates.""" + # Adapted from # https://github.com/lm-sys/FastChat/blob/main/fastchat/conversation.py import dataclasses diff --git a/python/sglang/srt/hf_transformers_utils.py b/python/sglang/srt/hf_transformers_utils.py index ec9b1dccd..50ad25e76 100644 --- a/python/sglang/srt/hf_transformers_utils.py +++ b/python/sglang/srt/hf_transformers_utils.py @@ -1,10 +1,10 @@ """Utilities for Huggingface Transformers.""" +import functools import json import os import warnings -import functools -from typing import Optional, Union, AbstractSet, Collection, Literal +from typing import AbstractSet, Collection, Literal, Optional, Union from huggingface_hub import snapshot_download from transformers import ( @@ -179,6 +179,7 @@ def get_processor( class TiktokenTokenizer: def __init__(self, tokenizer_path): import tiktoken + PAT_STR_B = r"""(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+""" # Read JSON @@ -190,7 +191,8 @@ class TiktokenTokenizer: bytes(item["bytes"]): item["token"] for item in tok_dict["regular_tokens"] } special_tokens = { - bytes(item["bytes"]).decode(): item["token"] for item in tok_dict["special_tokens"] + bytes(item["bytes"]).decode(): item["token"] + for item in tok_dict["special_tokens"] } assert tok_dict["word_split"] == "V1" @@ -202,7 +204,10 @@ class TiktokenTokenizer: } if "default_allowed_special" in tok_dict: default_allowed_special = set( - [bytes(bytes_list).decode() for bytes_list in tok_dict["default_allowed_special"]] + [ + bytes(bytes_list).decode() + for bytes_list in tok_dict["default_allowed_special"] + ] ) else: default_allowed_special = None @@ -216,14 +221,20 @@ class TiktokenTokenizer: self, text: str, *, - allowed_special: Union[Literal["all"], AbstractSet[str]] = set(), # noqa: B006 + allowed_special: Union[ + Literal["all"], AbstractSet[str] + ] = set(), # noqa: B006 disallowed_special: Union[Literal["all"], Collection[str]] = "all", ) -> list[int]: if isinstance(allowed_special, set): allowed_special |= self._default_allowed_special return tiktoken.Encoding.encode( - self, text, allowed_special=allowed_special, disallowed_special=disallowed_special + self, + text, + allowed_special=allowed_special, + disallowed_special=disallowed_special, ) + tokenizer.encode = functools.partial(encode_patched, tokenizer) # Convert to HF interface @@ -237,10 +248,14 @@ class TiktokenTokenizer: def decode(self, x): return self.tokenizer.decode(x) - def batch_decode(self, batch, skip_special_tokens=True, spaces_between_special_tokens=False): + def batch_decode( + self, batch, skip_special_tokens=True, spaces_between_special_tokens=False + ): if isinstance(batch[0], int): batch = [[x] for x in batch] return self.tokenizer.decode_batch(batch) def convert_ids_to_tokens(self, index): - return self.tokenizer.decode_single_token_bytes(index).decode("utf-8", errors="ignore") \ No newline at end of file + return self.tokenizer.decode_single_token_bytes(index).decode( + "utf-8", errors="ignore" + ) diff --git a/python/sglang/srt/layers/fused_moe.py b/python/sglang/srt/layers/fused_moe.py index 776194710..60e22f28c 100644 --- a/python/sglang/srt/layers/fused_moe.py +++ b/python/sglang/srt/layers/fused_moe.py @@ -9,7 +9,6 @@ from typing import Any, Dict, Optional, Tuple import torch import triton import triton.language as tl - from vllm import _custom_ops as ops from vllm.logger import init_logger from vllm.utils import is_hip @@ -109,12 +108,16 @@ def fused_moe_kernel( offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N offs_k = tl.arange(0, BLOCK_SIZE_K) - a_ptrs = a_ptr + (offs_token[:, None] // top_k * stride_am + - offs_k[None, :] * stride_ak) + a_ptrs = a_ptr + ( + offs_token[:, None] // top_k * stride_am + offs_k[None, :] * stride_ak + ) off_experts = tl.load(expert_ids_ptr + pid_m) - b_ptrs = b_ptr + off_experts * stride_be + (offs_k[:, None] * stride_bk + - offs_bn[None, :] * stride_bn) + b_ptrs = ( + b_ptr + + off_experts * stride_be + + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn) + ) if use_fp8: a_scale = tl.load(a_scale_ptr) @@ -130,13 +133,12 @@ def fused_moe_kernel( for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)): # Load the next block of A and B, generate a mask by checking the # K dimension. - a = tl.load(a_ptrs, - mask=token_mask[:, None] & - (offs_k[None, :] < K - k * BLOCK_SIZE_K), - other=0.0) - b = tl.load(b_ptrs, - mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, - other=0.0) + a = tl.load( + a_ptrs, + mask=token_mask[:, None] & (offs_k[None, :] < K - k * BLOCK_SIZE_K), + other=0.0, + ) + b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0) # We accumulate along the K dimension. if use_fp8: accumulator = tl.dot(a, b, acc=accumulator) @@ -147,9 +149,7 @@ def fused_moe_kernel( b_ptrs += BLOCK_SIZE_K * stride_bk if MUL_ROUTED_WEIGHT: - moe_weight = tl.load(topk_weights_ptr + offs_token, - mask=token_mask, - other=0) + moe_weight = tl.load(topk_weights_ptr + offs_token, mask=token_mask, other=0) accumulator = accumulator * moe_weight[:, None] if use_fp8: @@ -159,15 +159,14 @@ def fused_moe_kernel( # ----------------------------------------------------------- # Write back the block of the output offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) - c_ptrs = c_ptr + stride_cm * offs_token[:, None] + stride_cn * offs_cn[ - None, :] + c_ptrs = c_ptr + stride_cm * offs_token[:, None] + stride_cn * offs_cn[None, :] c_mask = token_mask[:, None] & (offs_cn[None, :] < N) tl.store(c_ptrs, accumulator, mask=c_mask) def moe_align_block_size( - topk_ids: torch.Tensor, block_size: int, - num_experts: int) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + topk_ids: torch.Tensor, block_size: int, num_experts: int +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: """ Aligns the token distribution across experts to be compatible with block size for matrix multiplication. @@ -206,32 +205,38 @@ def moe_align_block_size( by block_size for proper block matrix operations. """ max_num_tokens_padded = topk_ids.numel() + num_experts * (block_size - 1) - sorted_ids = torch.empty((max_num_tokens_padded, ), - dtype=torch.int32, - device=topk_ids.device) + sorted_ids = torch.empty( + (max_num_tokens_padded,), dtype=torch.int32, device=topk_ids.device + ) sorted_ids.fill_(topk_ids.numel()) max_num_m_blocks = triton.cdiv(max_num_tokens_padded, block_size) - expert_ids = torch.empty((max_num_m_blocks, ), - dtype=torch.int32, - device=topk_ids.device) - num_tokens_post_pad = torch.empty((1), - dtype=torch.int32, - device=topk_ids.device) - ops.moe_align_block_size(topk_ids, num_experts, block_size, sorted_ids, - expert_ids, num_tokens_post_pad) + expert_ids = torch.empty( + (max_num_m_blocks,), dtype=torch.int32, device=topk_ids.device + ) + num_tokens_post_pad = torch.empty((1), dtype=torch.int32, device=topk_ids.device) + ops.moe_align_block_size( + topk_ids, num_experts, block_size, sorted_ids, expert_ids, num_tokens_post_pad + ) return sorted_ids, expert_ids, num_tokens_post_pad -def invoke_fused_moe_kernel(A: torch.Tensor, B: torch.Tensor, C: torch.Tensor, - A_scale: Optional[torch.Tensor], - B_scale: Optional[torch.Tensor], - topk_weights: torch.Tensor, topk_ids: torch.Tensor, - sorted_token_ids: torch.Tensor, - expert_ids: torch.Tensor, - num_tokens_post_padded: torch.Tensor, - mul_routed_weight: bool, top_k: int, - config: Dict[str, Any], compute_type: tl.dtype, - use_fp8: bool) -> None: +def invoke_fused_moe_kernel( + A: torch.Tensor, + B: torch.Tensor, + C: torch.Tensor, + A_scale: Optional[torch.Tensor], + B_scale: Optional[torch.Tensor], + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + sorted_token_ids: torch.Tensor, + expert_ids: torch.Tensor, + num_tokens_post_padded: torch.Tensor, + mul_routed_weight: bool, + top_k: int, + config: Dict[str, Any], + compute_type: tl.dtype, + use_fp8: bool, +) -> None: assert topk_weights.stride(1) == 1 assert sorted_token_ids.stride(0) == 1 @@ -242,8 +247,10 @@ def invoke_fused_moe_kernel(A: torch.Tensor, B: torch.Tensor, C: torch.Tensor, A, A_scale = ops.scaled_fp8_quant(A, A_scale) assert B_scale is not None - grid = lambda META: (triton.cdiv(sorted_token_ids.shape[0], META[ - 'BLOCK_SIZE_M']) * triton.cdiv(B.shape[1], META['BLOCK_SIZE_N']), ) + grid = lambda META: ( + triton.cdiv(sorted_token_ids.shape[0], META["BLOCK_SIZE_M"]) + * triton.cdiv(B.shape[1], META["BLOCK_SIZE_N"]), + ) fused_moe_kernel[grid]( A, @@ -281,8 +288,7 @@ def get_config_file_name(E: int, N: int, dtype: Optional[str]) -> str: @functools.lru_cache -def get_moe_configs(E: int, N: int, - dtype: Optional[str]) -> Optional[Dict[int, Any]]: +def get_moe_configs(E: int, N: int, dtype: Optional[str]) -> Optional[Dict[int, Any]]: """ Return optimized configurations for the fused MoE kernel. @@ -297,11 +303,11 @@ def get_moe_configs(E: int, N: int, json_file_name = get_config_file_name(E, N, dtype) config_file_path = os.path.join( - os.path.dirname(os.path.realpath(__file__)), "configs", json_file_name) + os.path.dirname(os.path.realpath(__file__)), "configs", json_file_name + ) if os.path.exists(config_file_path): with open(config_file_path) as f: - logger.info("Using configuration from %s for MoE layer.", - config_file_path) + logger.info("Using configuration from %s for MoE layer.", config_file_path) # If a configuration has been found, return it return {int(key): val for key, val in json.load(f).items()} @@ -352,40 +358,30 @@ def fused_moe( - torch.Tensor: The output tensor after applying the MoE layer. """ # Check constraints. - assert hidden_states.shape[0] == gating_output.shape[0], ( - "Number of tokens mismatch") + assert hidden_states.shape[0] == gating_output.shape[0], "Number of tokens mismatch" assert hidden_states.shape[1] == w1.shape[2], "Hidden size mismatch" assert gating_output.shape[1] == w1.shape[0], "Number of experts mismatch" assert hidden_states.is_contiguous(), "Hidden_states must be contiguous" assert w1.is_contiguous(), "Expert weights1 must be contiguous" assert w2.is_contiguous(), "Expert weights2 must be contiguous" - assert hidden_states.dtype in [ - torch.float32, torch.float16, torch.bfloat16 - ] + assert hidden_states.dtype in [torch.float32, torch.float16, torch.bfloat16] M, _ = hidden_states.shape E, N, _ = w1.shape if is_hip(): # The MoE kernels are not yet supported on ROCm. - routing_weights = torch.softmax(gating_output, - dim=-1, - dtype=torch.float32) + routing_weights = torch.softmax(gating_output, dim=-1, dtype=torch.float32) topk_weights, topk_ids = torch.topk(routing_weights, topk, dim=-1) else: import vllm._moe_C as moe_kernels - topk_weights = torch.empty(M, - topk, - dtype=torch.float32, - device=hidden_states.device) - topk_ids = torch.empty(M, - topk, - dtype=torch.int32, - device=hidden_states.device) - token_expert_indicies = torch.empty(M, - topk, - dtype=torch.int32, - device=hidden_states.device) + topk_weights = torch.empty( + M, topk, dtype=torch.float32, device=hidden_states.device + ) + topk_ids = torch.empty(M, topk, dtype=torch.int32, device=hidden_states.device) + token_expert_indicies = torch.empty( + M, topk, dtype=torch.int32, device=hidden_states.device + ) moe_kernels.topk_softmax( topk_weights, topk_ids, @@ -400,8 +396,7 @@ def fused_moe( config = override_config else: # First try to load optimal config from the file - configs = get_moe_configs(E, w2.shape[2], - "float8" if use_fp8 else None) + configs = get_moe_configs(E, w2.shape[2], "float8" if use_fp8 else None) if configs: # If an optimal configuration map has been found, look up the @@ -415,7 +410,7 @@ def fused_moe( "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 1, "num_warps": 4, - "num_stages": 4 + "num_stages": 4, } if M <= E: @@ -425,61 +420,72 @@ def fused_moe( "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 16, "num_warps": 8, - "num_stages": 4 + "num_stages": 4, } - intermediate_cache1 = torch.empty((M, topk_ids.shape[1], N), - device=hidden_states.device, - dtype=hidden_states.dtype) - intermediate_cache2 = torch.empty((M * topk_ids.shape[1], N // 2), - device=hidden_states.device, - dtype=hidden_states.dtype) - intermediate_cache3 = torch.empty((M, topk_ids.shape[1], w2.shape[1]), - device=hidden_states.device, - dtype=hidden_states.dtype) + intermediate_cache1 = torch.empty( + (M, topk_ids.shape[1], N), + device=hidden_states.device, + dtype=hidden_states.dtype, + ) + intermediate_cache2 = torch.empty( + (M * topk_ids.shape[1], N // 2), + device=hidden_states.device, + dtype=hidden_states.dtype, + ) + intermediate_cache3 = torch.empty( + (M, topk_ids.shape[1], w2.shape[1]), + device=hidden_states.device, + dtype=hidden_states.dtype, + ) sorted_token_ids, expert_ids, num_tokens_post_padded = moe_align_block_size( - topk_ids, config['BLOCK_SIZE_M'], E) - compute_type = (tl.bfloat16 - if hidden_states.dtype == torch.bfloat16 else tl.float16) + topk_ids, config["BLOCK_SIZE_M"], E + ) + compute_type = tl.bfloat16 if hidden_states.dtype == torch.bfloat16 else tl.float16 - invoke_fused_moe_kernel(hidden_states, - w1, - intermediate_cache1, - a1_scale, - w1_scale, - topk_weights, - topk_ids, - sorted_token_ids, - expert_ids, - num_tokens_post_padded, - False, - topk_ids.shape[1], - config, - compute_type=compute_type, - use_fp8=use_fp8) + invoke_fused_moe_kernel( + hidden_states, + w1, + intermediate_cache1, + a1_scale, + w1_scale, + topk_weights, + topk_ids, + sorted_token_ids, + expert_ids, + num_tokens_post_padded, + False, + topk_ids.shape[1], + config, + compute_type=compute_type, + use_fp8=use_fp8, + ) ops.gelu_and_mul(intermediate_cache2, intermediate_cache1.view(-1, N)) - invoke_fused_moe_kernel(intermediate_cache2, - w2, - intermediate_cache3, - a2_scale, - w2_scale, - topk_weights, - topk_ids, - sorted_token_ids, - expert_ids, - num_tokens_post_padded, - True, - 1, - config, - compute_type=compute_type, - use_fp8=use_fp8) + invoke_fused_moe_kernel( + intermediate_cache2, + w2, + intermediate_cache3, + a2_scale, + w2_scale, + topk_weights, + topk_ids, + sorted_token_ids, + expert_ids, + num_tokens_post_padded, + True, + 1, + config, + compute_type=compute_type, + use_fp8=use_fp8, + ) if inplace: - return torch.sum(intermediate_cache3.view(*intermediate_cache3.shape), - dim=1, - out=hidden_states) - return torch.sum(intermediate_cache3.view(*intermediate_cache3.shape), - dim=1) \ No newline at end of file + return torch.sum( + intermediate_cache3.view(*intermediate_cache3.shape), + dim=1, + out=hidden_states, + ) + return torch.sum(intermediate_cache3.view(*intermediate_cache3.shape), dim=1) diff --git a/python/sglang/srt/layers/logits_processor.py b/python/sglang/srt/layers/logits_processor.py index eb32ff7b1..f613f44e5 100644 --- a/python/sglang/srt/layers/logits_processor.py +++ b/python/sglang/srt/layers/logits_processor.py @@ -1,4 +1,5 @@ """Logits processing.""" + import torch from torch import nn from vllm.distributed import ( diff --git a/python/sglang/srt/layers/radix_attention.py b/python/sglang/srt/layers/radix_attention.py index 651349735..605ec643b 100644 --- a/python/sglang/srt/layers/radix_attention.py +++ b/python/sglang/srt/layers/radix_attention.py @@ -1,6 +1,7 @@ """Radix attention.""" -import torch + import numpy as np +import torch from torch import nn from sglang.srt.layers.context_flashattention_nopad import context_attention_fwd @@ -10,7 +11,9 @@ from sglang.srt.managers.controller.model_runner import ForwardMode, InputMetada class RadixAttention(nn.Module): - def __init__(self, num_heads, head_dim, scaling, num_kv_heads, layer_id, logit_cap=-1): + def __init__( + self, num_heads, head_dim, scaling, num_kv_heads, layer_id, logit_cap=-1 + ): super().__init__() self.tp_q_head_num = num_heads self.tp_k_head_num = num_kv_heads diff --git a/python/sglang/srt/managers/controller/dp_worker.py b/python/sglang/srt/managers/controller/dp_worker.py index ca2a03cf2..3b6becfd2 100644 --- a/python/sglang/srt/managers/controller/dp_worker.py +++ b/python/sglang/srt/managers/controller/dp_worker.py @@ -4,7 +4,7 @@ import asyncio import logging import queue import threading -from typing import List, Callable +from typing import Callable, List import uvloop import zmq @@ -70,7 +70,9 @@ class DataParallelWorkerThread(threading.Thread): # async sleep for receiving the subsequent request and avoiding cache miss if len(out_pyobjs) != 0: - has_finished = any([obj.finished_reason is not None for obj in out_pyobjs]) + has_finished = any( + [obj.finished_reason is not None for obj in out_pyobjs] + ) if has_finished: await asyncio.sleep(self.request_dependency_delay) await asyncio.sleep(global_config.wait_for_new_request_delay) @@ -108,4 +110,4 @@ def start_data_parallel_worker( step_func=model_tp_client.step, ) worker_thread.start() - return worker_thread \ No newline at end of file + return worker_thread diff --git a/python/sglang/srt/managers/controller/infer_batch.py b/python/sglang/srt/managers/controller/infer_batch.py index 7ff9406ea..0cab5455d 100644 --- a/python/sglang/srt/managers/controller/infer_batch.py +++ b/python/sglang/srt/managers/controller/infer_batch.py @@ -1,17 +1,17 @@ """Meta data for requests and batches""" +import warnings from dataclasses import dataclass from enum import IntEnum, auto from typing import List -import warnings import numpy as np import torch +from sglang.srt.constrained import RegexGuide +from sglang.srt.constrained.jump_forward import JumpForwardMap from sglang.srt.managers.controller.radix_cache import RadixCache from sglang.srt.memory_pool import ReqToTokenPool, TokenToKVPool -from sglang.srt.constrained.jump_forward import JumpForwardMap -from sglang.srt.constrained import RegexGuide INIT_INCREMENTAL_DETOKENIZATION_OFFSET = 5 diff --git a/python/sglang/srt/managers/controller/manager_multi.py b/python/sglang/srt/managers/controller/manager_multi.py index 83f45b9a8..72e3bed80 100644 --- a/python/sglang/srt/managers/controller/manager_multi.py +++ b/python/sglang/srt/managers/controller/manager_multi.py @@ -13,15 +13,15 @@ import zmq import zmq.asyncio from sglang.global_config import global_config +from sglang.srt.managers.controller.dp_worker import ( + DataParallelWorkerThread, + start_data_parallel_worker, +) from sglang.srt.managers.io_struct import ( AbortReq, FlushCacheReq, TokenizedGenerateReqInput, ) -from sglang.srt.managers.controller.dp_worker import ( - DataParallelWorkerThread, - start_data_parallel_worker, -) from sglang.srt.server_args import PortArgs, ServerArgs from sglang.utils import get_exception_traceback @@ -136,7 +136,7 @@ class Controller: self.recv_reqs = [] if next_step_input: await self.dispatching(next_step_input) - #else: + # else: # logger.error("There is no live worker.") await asyncio.sleep(global_config.wait_for_new_request_delay) diff --git a/python/sglang/srt/managers/controller/manager_single.py b/python/sglang/srt/managers/controller/manager_single.py index d1c49c6e2..e4d02d036 100644 --- a/python/sglang/srt/managers/controller/manager_single.py +++ b/python/sglang/srt/managers/controller/manager_single.py @@ -1,4 +1,5 @@ """A controller that manages a group of tensor parallel workers.""" + import asyncio import logging import time @@ -49,7 +50,9 @@ class ControllerSingle: # async sleep for receiving the subsequent request and avoiding cache miss slept = False if len(out_pyobjs) != 0: - has_finished = any([obj.finished_reason is not None for obj in out_pyobjs]) + has_finished = any( + [obj.finished_reason is not None for obj in out_pyobjs] + ) if has_finished: if self.request_dependency_delay > 0: slept = True @@ -94,4 +97,4 @@ def start_controller_process( except Exception: logger.error("Exception in ControllerSingle:\n" + get_exception_traceback()) finally: - kill_parent_process() \ No newline at end of file + kill_parent_process() diff --git a/python/sglang/srt/managers/controller/model_runner.py b/python/sglang/srt/managers/controller/model_runner.py index 11c198be4..692ed7ac3 100644 --- a/python/sglang/srt/managers/controller/model_runner.py +++ b/python/sglang/srt/managers/controller/model_runner.py @@ -1,4 +1,5 @@ """ModelRunner runs the forward passes of the models.""" + import importlib import importlib.resources import logging @@ -12,15 +13,18 @@ import torch import torch.nn as nn from vllm.config import DeviceConfig, LoadConfig from vllm.config import ModelConfig as VllmModelConfig -from vllm.distributed import initialize_model_parallel, init_distributed_environment +from vllm.distributed import init_distributed_environment, initialize_model_parallel from vllm.model_executor.model_loader import get_model from vllm.model_executor.models import ModelRegistry from sglang.srt.managers.controller.infer_batch import Batch, ForwardMode from sglang.srt.memory_pool import ReqToTokenPool, TokenToKVPool from sglang.srt.server_args import ServerArgs -from sglang.srt.utils import get_available_gpu_memory, is_multimodal_model, monkey_patch_vllm_p2p_access_check - +from sglang.srt.utils import ( + get_available_gpu_memory, + is_multimodal_model, + monkey_patch_vllm_p2p_access_check, +) logger = logging.getLogger("srt.model_runner") @@ -441,7 +445,9 @@ def import_model_classes(): module = importlib.import_module(name) if hasattr(module, "EntryClass"): entry = module.EntryClass - if isinstance(entry, list): # To support multiple model classes in one module + if isinstance( + entry, list + ): # To support multiple model classes in one module for tmp in entry: model_arch_name_to_cls[tmp.__name__] = tmp else: @@ -449,7 +455,9 @@ def import_model_classes(): # compat: some models such as chatglm has incorrect class set in config.json # usage: [ tuple("From_Entry_Class_Name": EntryClass), ] - if hasattr(module, "EntryClassRemapping") and isinstance(module.EntryClassRemapping, list): + if hasattr(module, "EntryClassRemapping") and isinstance( + module.EntryClassRemapping, list + ): for remap in module.EntryClassRemapping: if isinstance(remap, tuple) and len(remap) == 2: model_arch_name_to_cls[remap[0]] = remap[1] diff --git a/python/sglang/srt/managers/controller/radix_cache.py b/python/sglang/srt/managers/controller/radix_cache.py index 04a184c10..ab8d6b446 100644 --- a/python/sglang/srt/managers/controller/radix_cache.py +++ b/python/sglang/srt/managers/controller/radix_cache.py @@ -1,6 +1,7 @@ """ The radix tree data structure for managing the KV cache. """ + import heapq import time from collections import defaultdict diff --git a/python/sglang/srt/managers/controller/schedule_heuristic.py b/python/sglang/srt/managers/controller/schedule_heuristic.py index 6e75a7ad4..4ae1a7069 100644 --- a/python/sglang/srt/managers/controller/schedule_heuristic.py +++ b/python/sglang/srt/managers/controller/schedule_heuristic.py @@ -1,4 +1,5 @@ """Request scheduler heuristic.""" + import random from collections import defaultdict diff --git a/python/sglang/srt/managers/controller/tp_worker.py b/python/sglang/srt/managers/controller/tp_worker.py index 3d4c48e51..2f3e86593 100644 --- a/python/sglang/srt/managers/controller/tp_worker.py +++ b/python/sglang/srt/managers/controller/tp_worker.py @@ -15,22 +15,22 @@ from sglang.global_config import global_config from sglang.srt.constrained.fsm_cache import FSMCache from sglang.srt.constrained.jump_forward import JumpForwardCache from sglang.srt.hf_transformers_utils import get_processor, get_tokenizer -from sglang.srt.managers.io_struct import ( - AbortReq, - BatchTokenIDOut, - FlushCacheReq, - TokenizedGenerateReqInput, -) from sglang.srt.managers.controller.infer_batch import ( + FINISH_ABORT, BaseFinishReason, Batch, - FINISH_ABORT, ForwardMode, Req, ) from sglang.srt.managers.controller.model_runner import ModelRunner from sglang.srt.managers.controller.radix_cache import RadixCache from sglang.srt.managers.controller.schedule_heuristic import ScheduleHeuristic +from sglang.srt.managers.io_struct import ( + AbortReq, + BatchTokenIDOut, + FlushCacheReq, + TokenizedGenerateReqInput, +) from sglang.srt.model_config import ModelConfig from sglang.srt.server_args import ModelPortArgs, ServerArgs from sglang.srt.utils import ( @@ -96,13 +96,13 @@ class ModelTpServer: trust_remote_code=server_args.trust_remote_code, ) self.max_total_num_tokens = self.model_runner.max_total_num_tokens - self.max_prefill_tokens = max( - self.model_config.context_len, - ( - min(self.max_total_num_tokens // 6, 65536) - if server_args.max_prefill_tokens is None - else server_args.max_prefill_tokens - ), + self.max_prefill_tokens = ( + max( + self.model_config.context_len, + min(self.max_total_num_tokens // 6, 65536), + ) + if server_args.max_prefill_tokens is None + else server_args.max_prefill_tokens ) self.max_running_requests = ( self.max_total_num_tokens // 2 diff --git a/python/sglang/srt/managers/detokenizer_manager.py b/python/sglang/srt/managers/detokenizer_manager.py index b5231e69a..ecba679e2 100644 --- a/python/sglang/srt/managers/detokenizer_manager.py +++ b/python/sglang/srt/managers/detokenizer_manager.py @@ -1,4 +1,5 @@ """DetokenizerManager is a process that detokenizes the token ids.""" + import asyncio import inspect @@ -7,10 +8,10 @@ import zmq import zmq.asyncio from sglang.srt.hf_transformers_utils import get_tokenizer +from sglang.srt.managers.controller.infer_batch import FINISH_MATCHED_STR from sglang.srt.managers.io_struct import BatchStrOut, BatchTokenIDOut from sglang.srt.server_args import PortArgs, ServerArgs from sglang.utils import get_exception_traceback, graceful_registry -from sglang.srt.managers.controller.infer_batch import FINISH_MATCHED_STR asyncio.set_event_loop_policy(uvloop.EventLoopPolicy()) diff --git a/python/sglang/srt/managers/io_struct.py b/python/sglang/srt/managers/io_struct.py index 1897a2c41..20590bc24 100644 --- a/python/sglang/srt/managers/io_struct.py +++ b/python/sglang/srt/managers/io_struct.py @@ -7,8 +7,8 @@ import uuid from dataclasses import dataclass from typing import Dict, List, Optional, Union -from sglang.srt.sampling_params import SamplingParams from sglang.srt.managers.controller.infer_batch import BaseFinishReason +from sglang.srt.sampling_params import SamplingParams @dataclass diff --git a/python/sglang/srt/managers/tokenizer_manager.py b/python/sglang/srt/managers/tokenizer_manager.py index 8fe3ff8dc..42f970370 100644 --- a/python/sglang/srt/managers/tokenizer_manager.py +++ b/python/sglang/srt/managers/tokenizer_manager.py @@ -1,11 +1,12 @@ """TokenizerManager is a process that tokenizes the text.""" + import asyncio import concurrent.futures import dataclasses import logging import multiprocessing as mp import os -from typing import List, Dict +from typing import Dict, List import numpy as np import transformers @@ -23,11 +24,11 @@ from sglang.srt.hf_transformers_utils import ( from sglang.srt.managers.io_struct import ( AbortReq, BatchStrOut, + BatchTokenIDOut, FlushCacheReq, GenerateReqInput, TokenizedGenerateReqInput, ) -from sglang.srt.managers.io_struct import BatchTokenIDOut from sglang.srt.mm_utils import expand2square, process_anyres_image from sglang.srt.sampling_params import SamplingParams from sglang.srt.server_args import PortArgs, ServerArgs @@ -91,7 +92,7 @@ class TokenizerManager: ) self.to_create_loop = True - self.rid_to_state: Dict[str, ReqState] = {} + self.rid_to_state: Dict[str, ReqState] = {} async def get_pixel_values(self, image_data): aspect_ratio = getattr(self.hf_config, "image_aspect_ratio", None) @@ -322,7 +323,6 @@ class TokenizerManager: state.finished = recv_obj.finished_reason[i] is not None state.event.set() - def convert_logprob_style( self, ret, return_logprob, top_logprobs_num, return_text_in_logprobs ): diff --git a/python/sglang/srt/model_config.py b/python/sglang/srt/model_config.py index 3c0062bae..315ab4163 100644 --- a/python/sglang/srt/model_config.py +++ b/python/sglang/srt/model_config.py @@ -1,8 +1,9 @@ from typing import Optional -from sglang.srt.hf_transformers_utils import get_config, get_context_length from transformers import PretrainedConfig +from sglang.srt.hf_transformers_utils import get_config, get_context_length + class ModelConfig: def __init__( @@ -17,8 +18,12 @@ class ModelConfig: self.trust_remote_code = trust_remote_code self.revision = revision self.model_overide_args = model_overide_args - self.hf_config = get_config(self.path, trust_remote_code, revision, - model_overide_args=model_overide_args) + self.hf_config = get_config( + self.path, + trust_remote_code, + revision, + model_overide_args=model_overide_args, + ) self.hf_text_config = get_hf_text_config(self.hf_config) if context_length is not None: self.context_len = context_length @@ -55,18 +60,23 @@ class ModelConfig: # KV heads. falcon_model_types = ["falcon", "RefinedWeb", "RefinedWebModel"] new_decoder_arch_falcon = ( - self.hf_config.model_type in falcon_model_types - and getattr(self.hf_config, "new_decoder_architecture", False)) - if not new_decoder_arch_falcon and getattr(self.hf_text_config, - "multi_query", False): + self.hf_config.model_type in falcon_model_types + and getattr(self.hf_config, "new_decoder_architecture", False) + ) + if not new_decoder_arch_falcon and getattr( + self.hf_text_config, "multi_query", False + ): # Multi-query attention, only one KV head. # Currently, tensor parallelism is not supported in this case. return 1 # For DBRX and MPT if self.hf_config.model_type in ["dbrx", "mpt"]: - return getattr(self.hf_config.attn_config, "kv_n_heads", - self.hf_config.num_attention_heads) + return getattr( + self.hf_config.attn_config, + "kv_n_heads", + self.hf_config.num_attention_heads, + ) attributes = [ # For Falcon: @@ -94,13 +104,12 @@ class ModelConfig: # the tensor parallel size. We will replicate the KV heads in the # case where the number of KV heads is smaller than the tensor # parallel size so each GPU has at least one KV head. - return max(1, - total_num_kv_heads // tensor_parallel_size) + return max(1, total_num_kv_heads // tensor_parallel_size) def get_hf_text_config(config: PretrainedConfig): """Get the "sub" config relevant to llm for multi modal models. - No op for pure text models. + No op for pure text models. """ if hasattr(config, "text_config"): # The code operates under the assumption that text_config should have diff --git a/python/sglang/srt/models/chatglm.py b/python/sglang/srt/models/chatglm.py index 415542dce..a15cc3d4c 100644 --- a/python/sglang/srt/models/chatglm.py +++ b/python/sglang/srt/models/chatglm.py @@ -5,30 +5,32 @@ from typing import Iterable, List, Optional, Tuple import torch -from sglang.srt.layers.radix_attention import RadixAttention -from sglang.srt.managers.controller.model_runner import InputMetadata -from sglang.srt.layers.logits_processor import LogitsProcessor from torch import nn from torch.nn import LayerNorm - from vllm.config import CacheConfig from vllm.distributed import get_tensor_model_parallel_world_size from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.layernorm import RMSNorm -from vllm.model_executor.layers.linear import (MergedColumnParallelLinear, - QKVParallelLinear, - RowParallelLinear) -from vllm.model_executor.layers.quantization.base_config import ( - QuantizationConfig) +from vllm.model_executor.layers.linear import ( + MergedColumnParallelLinear, + QKVParallelLinear, + RowParallelLinear, +) +from vllm.model_executor.layers.quantization.base_config import QuantizationConfig from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( - ParallelLMHead, VocabParallelEmbedding) + ParallelLMHead, + VocabParallelEmbedding, +) from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.sequence import SamplerOutput from vllm.transformers_utils.configs import ChatGLMConfig +from sglang.srt.layers.logits_processor import LogitsProcessor +from sglang.srt.layers.radix_attention import RadixAttention +from sglang.srt.managers.controller.model_runner import InputMetadata LoraConfig = None @@ -49,9 +51,11 @@ class GLMAttention(nn.Module): assert self.total_num_heads % tp_size == 0 self.num_heads = self.total_num_heads // tp_size self.multi_query_attention = config.multi_query_attention - self.total_num_kv_heads = (config.multi_query_group_num - if config.multi_query_attention else - config.num_attention_heads) + self.total_num_kv_heads = ( + config.multi_query_group_num + if config.multi_query_attention + else config.num_attention_heads + ) if self.total_num_kv_heads >= tp_size: # Number of KV heads is greater than TP size, so we partition # the KV heads across multiple tensor parallel GPUs. @@ -91,11 +95,13 @@ class GLMAttention(nn.Module): base=10000 * rope_ratio, is_neox_style=False, ) - self.attn = RadixAttention(self.num_heads, - self.head_dim, - self.scaling, - num_kv_heads=self.num_kv_heads, - layer_id=layer_id) + self.attn = RadixAttention( + self.num_heads, + self.head_dim, + self.scaling, + num_kv_heads=self.num_kv_heads, + layer_id=layer_id, + ) def forward( self, @@ -176,14 +182,16 @@ class GLMBlock(nn.Module): ): super().__init__() self.apply_residual_connection_post_layernorm = ( - config.apply_residual_connection_post_layernorm) + config.apply_residual_connection_post_layernorm + ) self.fp32_residual_connection = config.fp32_residual_connection layer_norm_func = RMSNorm if config.rmsnorm else LayerNorm # Layernorm on the input data. - self.input_layernorm = layer_norm_func(config.hidden_size, - eps=config.layernorm_epsilon) + self.input_layernorm = layer_norm_func( + config.hidden_size, eps=config.layernorm_epsilon + ) # Self attention. self.self_attention = GLMAttention(config, layer_id, cache_config, quant_config) @@ -191,7 +199,8 @@ class GLMBlock(nn.Module): # Layernorm on the attention output self.post_attention_layernorm = layer_norm_func( - config.hidden_size, eps=config.layernorm_epsilon) + config.hidden_size, eps=config.layernorm_epsilon + ) # MLP self.mlp = GLMMLP(config, quant_config) @@ -250,16 +259,19 @@ class GLMTransformer(nn.Module): self.num_layers = config.num_layers # Transformer layers. - self.layers = nn.ModuleList([ - GLMBlock(config, i, cache_config, quant_config) - for i in range(self.num_layers) - ]) + self.layers = nn.ModuleList( + [ + GLMBlock(config, i, cache_config, quant_config) + for i in range(self.num_layers) + ] + ) if self.post_layer_norm: layer_norm_func = RMSNorm if config.rmsnorm else LayerNorm # Final layer norm before output. self.final_layernorm = layer_norm_func( - config.hidden_size, eps=config.layernorm_epsilon) + config.hidden_size, eps=config.layernorm_epsilon + ) def forward( self, @@ -291,16 +303,16 @@ class ChatGLMModel(nn.Module): ): super().__init__() - self.embedding = VocabParallelEmbedding(config.padded_vocab_size, - config.hidden_size) + self.embedding = VocabParallelEmbedding( + config.padded_vocab_size, config.hidden_size + ) self.num_layers = config.num_layers self.multi_query_group_num = config.multi_query_group_num self.kv_channels = config.kv_channels self.encoder = GLMTransformer(config, cache_config, quant_config) - self.output_layer = ParallelLMHead(config.padded_vocab_size, - config.hidden_size) + self.output_layer = ParallelLMHead(config.padded_vocab_size, config.hidden_size) def forward( self, @@ -322,7 +334,7 @@ class ChatGLMModel(nn.Module): class ChatGLMForCausalLM(nn.Module): packed_modules_mapping = { "query_key_value": ["query_key_value"], - "dense_h_to_4h": ["dense_h_to_4h"] + "dense_h_to_4h": ["dense_h_to_4h"], } # LoRA specific attributes supported_lora_modules = [ @@ -344,8 +356,7 @@ class ChatGLMForCausalLM(nn.Module): super().__init__() self.config: ChatGLMConfig = config self.quant_config = quant_config - self.max_position_embeddings = getattr(config, "max_sequence_length", - 8192) + self.max_position_embeddings = getattr(config, "max_sequence_length", 8192) self.transformer = ChatGLMModel(config, cache_config, quant_config) self.lm_head = self.transformer.output_layer self.logits_processor = LogitsProcessor(config) @@ -357,8 +368,7 @@ class ChatGLMForCausalLM(nn.Module): positions: torch.Tensor, input_metadata: InputMetadata, ) -> torch.Tensor: - hidden_states = self.transformer(input_ids, positions, - input_metadata) + hidden_states = self.transformer(input_ids, positions, input_metadata) return self.logits_processor( input_ids, hidden_states, self.lm_head.weight, input_metadata ) @@ -382,10 +392,10 @@ class ChatGLMForCausalLM(nn.Module): if name.endswith(".bias") and name not in params_dict: continue param = params_dict[name] - weight_loader = getattr(param, "weight_loader", - default_weight_loader) + weight_loader = getattr(param, "weight_loader", default_weight_loader) weight_loader(param, loaded_weight) + EntryClass = ChatGLMForCausalLM # compat: glm model.config class == ChatGLMModel EntryClassRemapping = [("ChatGLMModel", ChatGLMForCausalLM)] diff --git a/python/sglang/srt/models/commandr.py b/python/sglang/srt/models/commandr.py index c08e7eb3a..2757645e1 100644 --- a/python/sglang/srt/models/commandr.py +++ b/python/sglang/srt/models/commandr.py @@ -23,7 +23,7 @@ # This file is based on the LLama model definition file in transformers """PyTorch Cohere model.""" -from typing import Optional, Tuple, Iterable +from typing import Iterable, Optional, Tuple import torch import torch.utils.checkpoint @@ -44,8 +44,8 @@ from vllm.model_executor.layers.linear import ( from vllm.model_executor.layers.quantization.base_config import QuantizationConfig from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding -from vllm.model_executor.utils import set_weight_attrs from vllm.model_executor.model_loader.weight_utils import default_weight_loader +from vllm.model_executor.utils import set_weight_attrs from sglang.srt.layers.logits_processor import LogitsProcessor from sglang.srt.layers.radix_attention import RadixAttention diff --git a/python/sglang/srt/models/dbrx.py b/python/sglang/srt/models/dbrx.py index 8436386c1..b21142d2e 100644 --- a/python/sglang/srt/models/dbrx.py +++ b/python/sglang/srt/models/dbrx.py @@ -24,8 +24,8 @@ from vllm.model_executor.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding, ) -from vllm.model_executor.utils import set_weight_attrs from vllm.model_executor.model_loader.weight_utils import default_weight_loader +from vllm.model_executor.utils import set_weight_attrs from vllm.transformers_utils.configs.dbrx import DbrxConfig from sglang.srt.layers.logits_processor import LogitsProcessor diff --git a/python/sglang/srt/models/gemma.py b/python/sglang/srt/models/gemma.py index e150a56ca..b8896ef88 100644 --- a/python/sglang/srt/models/gemma.py +++ b/python/sglang/srt/models/gemma.py @@ -6,7 +6,7 @@ from typing import Iterable, Optional, Tuple import torch from torch import nn from transformers import PretrainedConfig -from vllm.config import LoRAConfig, CacheConfig +from vllm.config import CacheConfig, LoRAConfig from vllm.distributed import get_tensor_model_parallel_world_size from vllm.model_executor.layers.activation import GeluAndMul from vllm.model_executor.layers.layernorm import RMSNorm diff --git a/python/sglang/srt/models/grok.py b/python/sglang/srt/models/grok.py index a2dcfec8e..9cae0b105 100644 --- a/python/sglang/srt/models/grok.py +++ b/python/sglang/srt/models/grok.py @@ -1,7 +1,7 @@ # Adapted from # https://github.com/vllm-project/vllm/blob/c7f2cf2b7f67bce5842fedfdba508440fe257375/vllm/model_executor/models/mixtral.py#L1 """Inference-only Grok1 model.""" -from typing import Iterable, Optional, Tuple, List +from typing import Iterable, List, Optional, Tuple import numpy as np import torch @@ -9,7 +9,6 @@ import torch.nn.functional as F import tqdm from torch import nn from transformers import PretrainedConfig - from vllm import _custom_ops as ops from vllm.config import CacheConfig from vllm.distributed import ( @@ -35,12 +34,11 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.utils import set_weight_attrs from vllm.utils import print_warning_once -from sglang.srt.layers.logits_processor import LogitsProcessor from sglang.srt.layers.fused_moe import fused_moe +from sglang.srt.layers.logits_processor import LogitsProcessor from sglang.srt.layers.radix_attention import RadixAttention from sglang.srt.managers.controller.model_runner import InputMetadata - use_fused = True @@ -134,9 +132,12 @@ class Grok1MoEUnfused(nn.Module): final_hidden_states = torch.zeros( (hidden_states.shape[0], hidden_dim), - dtype=hidden_states.dtype, device=hidden_states.device + dtype=hidden_states.dtype, + device=hidden_states.device, ) - expert_mask = torch.nn.functional.one_hot(selected_experts, num_classes=self.num_total_experts).permute(2, 1, 0) + expert_mask = torch.nn.functional.one_hot( + selected_experts, num_classes=self.num_total_experts + ).permute(2, 1, 0) for expert_idx in self.expert_indicies: expert_layer = self.experts[expert_idx] @@ -153,7 +154,10 @@ class Grok1MoEUnfused(nn.Module): # the current expert. We need to make sure to multiply the output hidden # states by `routing_weights` on the corresponding tokens (top-1 and top-2) current_state = hidden_states[None, top_x_list].reshape(-1, hidden_dim) - current_hidden_states = expert_layer(current_state) * routing_weights[top_x_list, idx_list, None] + current_hidden_states = ( + expert_layer(current_state) + * routing_weights[top_x_list, idx_list, None] + ) # However `index_add_` only support torch tensors for indexing so we'll use # the `top_x` tensor here. @@ -198,32 +202,46 @@ class Grok1MoE(nn.Module): self.params_dtype = params_dtype # Gate always runs at half / full precision for now. - self.gate = ReplicatedLinear(self.hidden_size, - self.num_total_experts, - bias=False, - params_dtype=self.params_dtype, - quant_config=None) + self.gate = ReplicatedLinear( + self.hidden_size, + self.num_total_experts, + bias=False, + params_dtype=self.params_dtype, + quant_config=None, + ) if self.use_fp8 and self.quant_config.is_checkpoint_fp8_serialized: params_dtype = torch.float8_e4m3fn self.w13_weight = nn.Parameter( - torch.empty(self.num_total_experts, - 2 * self.intermediate_size, - self.hidden_size, - dtype=params_dtype)) + torch.empty( + self.num_total_experts, + 2 * self.intermediate_size, + self.hidden_size, + dtype=params_dtype, + ) + ) self.w2_weight = nn.Parameter( - torch.empty(self.num_total_experts, - self.hidden_size, - self.intermediate_size, - dtype=params_dtype)) + torch.empty( + self.num_total_experts, + self.hidden_size, + self.intermediate_size, + dtype=params_dtype, + ) + ) - set_weight_attrs(self.w13_weight, { - "weight_loader": self.weight_loader, - }) - set_weight_attrs(self.w2_weight, { - "weight_loader": self.weight_loader, - }) + set_weight_attrs( + self.w13_weight, + { + "weight_loader": self.weight_loader, + }, + ) + set_weight_attrs( + self.w2_weight, + { + "weight_loader": self.weight_loader, + }, + ) # Used for fp8. self.w13_scale = None @@ -233,46 +251,69 @@ class Grok1MoE(nn.Module): if self.use_fp8: # WEIGHT_SCALE (for fp8) - self.w13_scale = nn.Parameter(torch.ones(self.num_total_experts, - dtype=torch.float32), - requires_grad=False) - self.w2_scale = nn.Parameter(torch.ones(self.num_total_experts, - dtype=torch.float32), - requires_grad=False) + self.w13_scale = nn.Parameter( + torch.ones(self.num_total_experts, dtype=torch.float32), + requires_grad=False, + ) + self.w2_scale = nn.Parameter( + torch.ones(self.num_total_experts, dtype=torch.float32), + requires_grad=False, + ) # If loading fp8 checkpoint, pass the weight loaders. # If loading an fp16 checkpoint, do not (we will quantize in # process_weights_after_loading() if quant_config.is_checkpoint_fp8_serialized: - set_weight_attrs(self.w13_scale, { - "weight_loader": self.weight_loader, - }) - set_weight_attrs(self.w2_scale, { - "weight_loader": self.weight_loader, - }) + set_weight_attrs( + self.w13_scale, + { + "weight_loader": self.weight_loader, + }, + ) + set_weight_attrs( + self.w2_scale, + { + "weight_loader": self.weight_loader, + }, + ) # ACT_SCALE (for fp8) if quant_config.activation_scheme == "static": if not quant_config.is_checkpoint_fp8_serialized: raise ValueError( "Found static activation scheme for checkpoint that " - "was not serialized fp8.") - self.a13_scale = nn.Parameter(torch.zeros( - self.num_total_experts, dtype=torch.float32), - requires_grad=False) - self.a2_scale = nn.Parameter(torch.zeros( - self.num_total_experts, dtype=torch.float32), - requires_grad=False) + "was not serialized fp8." + ) + self.a13_scale = nn.Parameter( + torch.zeros(self.num_total_experts, dtype=torch.float32), + requires_grad=False, + ) + self.a2_scale = nn.Parameter( + torch.zeros(self.num_total_experts, dtype=torch.float32), + requires_grad=False, + ) - set_weight_attrs(self.a13_scale, { - "weight_loader": self.weight_loader, - }) - set_weight_attrs(self.a2_scale, { - "weight_loader": self.weight_loader, - }) + set_weight_attrs( + self.a13_scale, + { + "weight_loader": self.weight_loader, + }, + ) + set_weight_attrs( + self.a2_scale, + { + "weight_loader": self.weight_loader, + }, + ) - def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor, - weight_name: str, expert_id: int, pre_sharded: bool): + def weight_loader( + self, + param: nn.Parameter, + loaded_weight: torch.Tensor, + weight_name: str, + expert_id: int, + pre_sharded: bool, + ): param_data = param.data shard_size = self.intermediate_size if pre_sharded: @@ -284,8 +325,9 @@ class Grok1MoE(nn.Module): if weight_name.endswith("w1.weight"): param_data[expert_id, 0:shard_size, :] = loaded_weight[shard, :] if weight_name.endswith("w3.weight"): - param_data[expert_id, - shard_size:2 * shard_size, :] = loaded_weight[shard, :] + param_data[expert_id, shard_size : 2 * shard_size, :] = loaded_weight[ + shard, : + ] if weight_name.endswith("w2.weight"): param_data[expert_id, :, :] = loaded_weight[:, shard] if "act_scale" in weight_name or "weight_scale" in weight_name: @@ -298,17 +340,17 @@ class Grok1MoE(nn.Module): # If checkpoint is fp16, quantize here. if not self.quant_config.is_checkpoint_fp8_serialized: - w13_weight = torch.empty_like(self.w13_weight.data, - dtype=torch.float8_e4m3fn) - w2_weight = torch.empty_like(self.w2_weight.data, - dtype=torch.float8_e4m3fn) + w13_weight = torch.empty_like( + self.w13_weight.data, dtype=torch.float8_e4m3fn + ) + w2_weight = torch.empty_like(self.w2_weight.data, dtype=torch.float8_e4m3fn) for expert in range(self.num_total_experts): - w13_weight[expert, :, :], self.w13_scale[ - expert] = ops.scaled_fp8_quant( - self.w13_weight.data[expert, :, :]) - w2_weight[expert, :, :], self.w2_scale[ - expert] = ops.scaled_fp8_quant( - self.w2_weight.data[expert, :, :]) + w13_weight[expert, :, :], self.w13_scale[expert] = ops.scaled_fp8_quant( + self.w13_weight.data[expert, :, :] + ) + w2_weight[expert, :, :], self.w2_scale[expert] = ops.scaled_fp8_quant( + self.w2_weight.data[expert, :, :] + ) self.w13_weight = nn.Parameter(w13_weight, requires_grad=False) self.w2_weight = nn.Parameter(w2_weight, requires_grad=False) @@ -319,40 +361,40 @@ class Grok1MoE(nn.Module): if self.a13_scale is None or self.a2_scale is None: raise ValueError( "QuantConfig has static quantization, but found " - "activation scales are None.") + "activation scales are None." + ) - if (not all_close_1d(self.a13_scale) - or not all_close_1d(self.a2_scale)): + if not all_close_1d(self.a13_scale) or not all_close_1d(self.a2_scale): print_warning_once( "Found act_scales that are not equal for fp8 MoE layer. " - "Using the maximum across experts for each layer. ") + "Using the maximum across experts for each layer. " + ) - self.a13_scale = nn.Parameter(self.a13_scale.max(), - requires_grad=False) - self.a2_scale = nn.Parameter(self.a2_scale.max(), - requires_grad=False) + self.a13_scale = nn.Parameter(self.a13_scale.max(), requires_grad=False) + self.a2_scale = nn.Parameter(self.a2_scale.max(), requires_grad=False) def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: num_tokens, hidden_size = hidden_states.shape hidden_states = hidden_states.view(-1, self.hidden_size) # router_logits: (num_tokens, n_experts) router_logits, _ = self.gate(hidden_states) - final_hidden_states = fused_moe(hidden_states, - self.w13_weight, - self.w2_weight, - router_logits, - self.top_k, - renormalize=False, - inplace=True, - use_fp8=self.use_fp8, - w1_scale=self.w13_scale, - w2_scale=self.w2_scale, - a1_scale=self.a13_scale, - a2_scale=self.a2_scale) + final_hidden_states = fused_moe( + hidden_states, + self.w13_weight, + self.w2_weight, + router_logits, + self.top_k, + renormalize=False, + inplace=True, + use_fp8=self.use_fp8, + w1_scale=self.w13_scale, + w2_scale=self.w2_scale, + a1_scale=self.a13_scale, + a2_scale=self.a2_scale, + ) if self.tp_size > 1: - final_hidden_states = tensor_model_parallel_all_reduce( - final_hidden_states) + final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states) return final_hidden_states.view(num_tokens, hidden_size) @@ -462,10 +504,12 @@ class Grok1DecoderLayer(nn.Module): top_k=config.num_experts_per_tok, hidden_size=config.hidden_size, intermediate_size=config.intermediate_size, - quant_config=quant_config) + quant_config=quant_config, + ) else: self.block_sparse_moe = Grok1MoEUnfused( - config=config, quant_config=quant_config) + config=config, quant_config=quant_config + ) self.pre_attn_norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.post_attn_norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.pre_moe_norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) @@ -478,12 +522,21 @@ class Grok1DecoderLayer(nn.Module): input_metadata: InputMetadata, ) -> torch.Tensor: - hidden_states = self.post_attn_norm(self.self_attn( - positions=positions, hidden_states=self.pre_attn_norm(hidden_states), - input_metadata=input_metadata, - )) + hidden_states + hidden_states = ( + self.post_attn_norm( + self.self_attn( + positions=positions, + hidden_states=self.pre_attn_norm(hidden_states), + input_metadata=input_metadata, + ) + ) + + hidden_states + ) - hidden_states = self.post_moe_norm(self.block_sparse_moe(self.pre_moe_norm(hidden_states))) + hidden_states + hidden_states = ( + self.post_moe_norm(self.block_sparse_moe(self.pre_moe_norm(hidden_states))) + + hidden_states + ) return hidden_states @@ -525,9 +578,7 @@ class Grok1Model(nn.Module): hidden_states.mul_(self.config.embedding_multiplier_scale) for i in range(len(self.layers)): - hidden_states = self.layers[i]( - positions, hidden_states, input_metadata - ) + hidden_states = self.layers[i](positions, hidden_states, input_metadata) hidden_states = self.norm(hidden_states) hidden_states.mul_(self.config.output_multiplier_scale) @@ -572,28 +623,41 @@ class Grok1ModelForCausalLM(nn.Module): ] if use_fused: - expert_params_mapping = [ - # These are the weight scales for the experts - # (param_name, weight_name, expert_id) - ("w13_scale" if weight_name in ["w1", "w3"] else "w2_scale", - f"experts.{expert_id}.{weight_name}.weight_scale", expert_id) - for expert_id in range(self.config.num_local_experts) - for weight_name in ["w1", "w2", "w3"] - ] + [ - # These are the weights for the experts - # (param_name, weight_name, expert_id) - ("w13_weight" if weight_name in ["w1", "w3"] else "w2_weight", - f"experts.{expert_id}.{weight_name}.weight", expert_id) - for expert_id in range(self.config.num_local_experts) - for weight_name in ["w1", "w2", "w3"] - ] + [ - # These are the activation scales for the experts - # (param_name, weight_name, expert_id) - ("a13_scale" if weight_name in ["w1", "w3"] else "a2_scale", - f"experts.{expert_id}.{weight_name}.act_scale", expert_id) - for expert_id in range(self.config.num_local_experts) - for weight_name in ["w1", "w2", "w3"] - ] + expert_params_mapping = ( + [ + # These are the weight scales for the experts + # (param_name, weight_name, expert_id) + ( + "w13_scale" if weight_name in ["w1", "w3"] else "w2_scale", + f"experts.{expert_id}.{weight_name}.weight_scale", + expert_id, + ) + for expert_id in range(self.config.num_local_experts) + for weight_name in ["w1", "w2", "w3"] + ] + + [ + # These are the weights for the experts + # (param_name, weight_name, expert_id) + ( + "w13_weight" if weight_name in ["w1", "w3"] else "w2_weight", + f"experts.{expert_id}.{weight_name}.weight", + expert_id, + ) + for expert_id in range(self.config.num_local_experts) + for weight_name in ["w1", "w2", "w3"] + ] + + [ + # These are the activation scales for the experts + # (param_name, weight_name, expert_id) + ( + "a13_scale" if weight_name in ["w1", "w3"] else "a2_scale", + f"experts.{expert_id}.{weight_name}.act_scale", + expert_id, + ) + for expert_id in range(self.config.num_local_experts) + for weight_name in ["w1", "w2", "w3"] + ] + ) else: expert_params_mapping = [] @@ -601,11 +665,11 @@ class Grok1ModelForCausalLM(nn.Module): if get_tensor_model_parallel_rank() == 0: weights = tqdm.tqdm(weights, total=int(len(params_dict) * 3.4)) for name, loaded_weight in weights: - #print(get_tensor_model_parallel_rank(), name) + # print(get_tensor_model_parallel_rank(), name) if "rotary_emb.inv_freq" in name: continue - for (param_name, weight_name, shard_id) in stacked_params_mapping: + for param_name, weight_name, shard_id in stacked_params_mapping: if weight_name not in name: continue name = name.replace(weight_name, param_name) @@ -623,19 +687,22 @@ class Grok1ModelForCausalLM(nn.Module): name = name.replace(weight_name, param_name) param = params_dict[name] weight_loader = param.weight_loader - weight_loader(param, - loaded_weight, - weight_name, - expert_id=expert_id, - pre_sharded=get_tensor_model_parallel_world_size() > 1) + weight_loader( + param, + loaded_weight, + weight_name, + expert_id=expert_id, + pre_sharded=get_tensor_model_parallel_world_size() > 1, + ) break else: # Skip loading extra bias for GPTQ models. if name.endswith(".bias") and name not in params_dict: continue param = params_dict[name] - weight_loader = getattr(param, "weight_loader", - default_weight_loader) + weight_loader = getattr( + param, "weight_loader", default_weight_loader + ) weight_loader(param, loaded_weight) @@ -645,10 +712,11 @@ def all_close_1d(x: torch.Tensor) -> bool: old_prepare_weights = getattr(DefaultModelLoader, "_prepare_weights") -def _prepare_presharded_weights(self, - model_name_or_path: str, - revision: Optional[str], - fall_back_to_pt: bool) -> Tuple[str, List[str], bool]: + + +def _prepare_presharded_weights( + self, model_name_or_path: str, revision: Optional[str], fall_back_to_pt: bool +) -> Tuple[str, List[str], bool]: import glob import os @@ -668,4 +736,4 @@ def _prepare_presharded_weights(self, return hf_folder, hf_weights_files, use_safetensors -EntryClass = Grok1ModelForCausalLM \ No newline at end of file +EntryClass = Grok1ModelForCausalLM diff --git a/python/sglang/srt/models/llama2.py b/python/sglang/srt/models/llama2.py index d0f162c31..051036525 100644 --- a/python/sglang/srt/models/llama2.py +++ b/python/sglang/srt/models/llama2.py @@ -1,7 +1,7 @@ # Adapted from # https://github.com/vllm-project/vllm/blob/c7f2cf2b7f67bce5842fedfdba508440fe257375/vllm/model_executor/models/llama.py#L1 """Inference-only LLaMA model compatible with HuggingFace weights.""" -from typing import Any, Dict, Optional, Tuple, Iterable +from typing import Any, Dict, Iterable, Optional, Tuple import torch import tqdm @@ -10,7 +10,7 @@ from transformers import LlamaConfig from vllm.config import CacheConfig from vllm.distributed import ( get_tensor_model_parallel_rank, - get_tensor_model_parallel_world_size + get_tensor_model_parallel_world_size, ) from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.layernorm import RMSNorm @@ -158,9 +158,11 @@ class LlamaDecoderLayer(nn.Module): rope_theta = getattr(config, "rope_theta", 10000) rope_scaling = getattr(config, "rope_scaling", None) if rope_scaling is not None and getattr( - config, "original_max_position_embeddings", None): + config, "original_max_position_embeddings", None + ): rope_scaling["original_max_position_embeddings"] = ( - config.original_max_position_embeddings) + config.original_max_position_embeddings + ) max_position_embeddings = getattr(config, "max_position_embeddings", 8192) self.self_attn = LlamaAttention( hidden_size=self.hidden_size, diff --git a/python/sglang/srt/models/llava.py b/python/sglang/srt/models/llava.py index 258cc8c2f..915c9bee0 100644 --- a/python/sglang/srt/models/llava.py +++ b/python/sglang/srt/models/llava.py @@ -1,11 +1,17 @@ """Inference-only LLaVa model compatible with HuggingFace weights.""" -from typing import List, Iterable, Optional, Tuple +from typing import Iterable, List, Optional, Tuple import numpy as np import torch from torch import nn -from transformers import CLIPVisionModel, CLIPVisionConfig, LlavaConfig, Qwen2Config, MistralConfig +from transformers import ( + CLIPVisionConfig, + CLIPVisionModel, + LlavaConfig, + MistralConfig, + Qwen2Config, +) from transformers.models.llava.modeling_llava import LlavaMultiModalProjector from vllm.config import CacheConfig from vllm.model_executor.layers.quantization.base_config import QuantizationConfig @@ -19,8 +25,8 @@ from sglang.srt.mm_utils import ( unpad_image_shape, ) from sglang.srt.models.llama2 import LlamaForCausalLM -from sglang.srt.models.qwen2 import Qwen2ForCausalLM from sglang.srt.models.mistral import MistralForCausalLM +from sglang.srt.models.qwen2 import Qwen2ForCausalLM class LlavaLlamaForCausalLM(nn.Module): @@ -359,6 +365,7 @@ class LlavaMistralForCausalLM(LlavaLlamaForCausalLM): first_call = True + def clip_vision_embed_forward(self, pixel_values: torch.FloatTensor) -> torch.Tensor: batch_size = pixel_values.shape[0] @@ -388,8 +395,4 @@ def monkey_path_clip_vision_embed_forward(): ) -EntryClass = [ - LlavaLlamaForCausalLM, - LlavaQwenForCausalLM, - LlavaMistralForCausalLM -] +EntryClass = [LlavaLlamaForCausalLM, LlavaQwenForCausalLM, LlavaMistralForCausalLM] diff --git a/python/sglang/srt/models/llavavid.py b/python/sglang/srt/models/llavavid.py index 541258811..47e20583c 100644 --- a/python/sglang/srt/models/llavavid.py +++ b/python/sglang/srt/models/llavavid.py @@ -1,6 +1,6 @@ """Inference-only LLaVa video model compatible with HuggingFace weights.""" -from typing import List, Iterable, Optional, Tuple +from typing import Iterable, List, Optional, Tuple import numpy as np import torch diff --git a/python/sglang/srt/models/mixtral.py b/python/sglang/srt/models/mixtral.py index 2c14dd142..abcde6de5 100644 --- a/python/sglang/srt/models/mixtral.py +++ b/python/sglang/srt/models/mixtral.py @@ -33,13 +33,11 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.utils import set_weight_attrs from vllm.utils import print_warning_once - from sglang.srt.layers.logits_processor import LogitsProcessor from sglang.srt.layers.radix_attention import RadixAttention from sglang.srt.managers.controller.model_runner import InputMetadata - class MixtralMoE(nn.Module): """A tensor-parallel MoE implementation for Mixtral that shards each expert across all ranks. @@ -76,32 +74,46 @@ class MixtralMoE(nn.Module): self.params_dtype = params_dtype # Gate always runs at half / full precision for now. - self.gate = ReplicatedLinear(self.hidden_size, - self.num_total_experts, - bias=False, - params_dtype=self.params_dtype, - quant_config=None) + self.gate = ReplicatedLinear( + self.hidden_size, + self.num_total_experts, + bias=False, + params_dtype=self.params_dtype, + quant_config=None, + ) if self.use_fp8 and self.quant_config.is_checkpoint_fp8_serialized: params_dtype = torch.float8_e4m3fn self.w13_weight = nn.Parameter( - torch.empty(self.num_total_experts, - 2 * self.intermediate_size, - self.hidden_size, - dtype=params_dtype)) + torch.empty( + self.num_total_experts, + 2 * self.intermediate_size, + self.hidden_size, + dtype=params_dtype, + ) + ) self.w2_weight = nn.Parameter( - torch.empty(self.num_total_experts, - self.hidden_size, - self.intermediate_size, - dtype=params_dtype)) + torch.empty( + self.num_total_experts, + self.hidden_size, + self.intermediate_size, + dtype=params_dtype, + ) + ) - set_weight_attrs(self.w13_weight, { - "weight_loader": self.weight_loader, - }) - set_weight_attrs(self.w2_weight, { - "weight_loader": self.weight_loader, - }) + set_weight_attrs( + self.w13_weight, + { + "weight_loader": self.weight_loader, + }, + ) + set_weight_attrs( + self.w2_weight, + { + "weight_loader": self.weight_loader, + }, + ) # Used for fp8. self.w13_scale = None @@ -111,46 +123,68 @@ class MixtralMoE(nn.Module): if self.use_fp8: # WEIGHT_SCALE (for fp8) - self.w13_scale = nn.Parameter(torch.ones(self.num_total_experts, - dtype=torch.float32), - requires_grad=False) - self.w2_scale = nn.Parameter(torch.ones(self.num_total_experts, - dtype=torch.float32), - requires_grad=False) + self.w13_scale = nn.Parameter( + torch.ones(self.num_total_experts, dtype=torch.float32), + requires_grad=False, + ) + self.w2_scale = nn.Parameter( + torch.ones(self.num_total_experts, dtype=torch.float32), + requires_grad=False, + ) # If loading fp8 checkpoint, pass the weight loaders. # If loading an fp16 checkpoint, do not (we will quantize in # process_weights_after_loading() if quant_config.is_checkpoint_fp8_serialized: - set_weight_attrs(self.w13_scale, { - "weight_loader": self.weight_loader, - }) - set_weight_attrs(self.w2_scale, { - "weight_loader": self.weight_loader, - }) + set_weight_attrs( + self.w13_scale, + { + "weight_loader": self.weight_loader, + }, + ) + set_weight_attrs( + self.w2_scale, + { + "weight_loader": self.weight_loader, + }, + ) # ACT_SCALE (for fp8) if quant_config.activation_scheme == "static": if not quant_config.is_checkpoint_fp8_serialized: raise ValueError( "Found static activation scheme for checkpoint that " - "was not serialized fp8.") - self.a13_scale = nn.Parameter(torch.zeros( - self.num_total_experts, dtype=torch.float32), - requires_grad=False) - self.a2_scale = nn.Parameter(torch.zeros( - self.num_total_experts, dtype=torch.float32), - requires_grad=False) + "was not serialized fp8." + ) + self.a13_scale = nn.Parameter( + torch.zeros(self.num_total_experts, dtype=torch.float32), + requires_grad=False, + ) + self.a2_scale = nn.Parameter( + torch.zeros(self.num_total_experts, dtype=torch.float32), + requires_grad=False, + ) - set_weight_attrs(self.a13_scale, { - "weight_loader": self.weight_loader, - }) - set_weight_attrs(self.a2_scale, { - "weight_loader": self.weight_loader, - }) + set_weight_attrs( + self.a13_scale, + { + "weight_loader": self.weight_loader, + }, + ) + set_weight_attrs( + self.a2_scale, + { + "weight_loader": self.weight_loader, + }, + ) - def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor, - weight_name: str, expert_id: int): + def weight_loader( + self, + param: nn.Parameter, + loaded_weight: torch.Tensor, + weight_name: str, + expert_id: int, + ): tp_rank = get_tensor_model_parallel_rank() param_data = param.data shard_size = self.intermediate_size @@ -158,8 +192,9 @@ class MixtralMoE(nn.Module): if weight_name.endswith("w1.weight"): param_data[expert_id, 0:shard_size, :] = loaded_weight[shard, :] if weight_name.endswith("w3.weight"): - param_data[expert_id, - shard_size:2 * shard_size, :] = loaded_weight[shard, :] + param_data[expert_id, shard_size : 2 * shard_size, :] = loaded_weight[ + shard, : + ] if weight_name.endswith("w2.weight"): param_data[expert_id, :, :] = loaded_weight[:, shard] if "act_scale" in weight_name or "weight_scale" in weight_name: @@ -172,17 +207,17 @@ class MixtralMoE(nn.Module): # If checkpoint is fp16, quantize here. if not self.quant_config.is_checkpoint_fp8_serialized: - w13_weight = torch.empty_like(self.w13_weight.data, - dtype=torch.float8_e4m3fn) - w2_weight = torch.empty_like(self.w2_weight.data, - dtype=torch.float8_e4m3fn) + w13_weight = torch.empty_like( + self.w13_weight.data, dtype=torch.float8_e4m3fn + ) + w2_weight = torch.empty_like(self.w2_weight.data, dtype=torch.float8_e4m3fn) for expert in range(self.num_total_experts): - w13_weight[expert, :, :], self.w13_scale[ - expert] = ops.scaled_fp8_quant( - self.w13_weight.data[expert, :, :]) - w2_weight[expert, :, :], self.w2_scale[ - expert] = ops.scaled_fp8_quant( - self.w2_weight.data[expert, :, :]) + w13_weight[expert, :, :], self.w13_scale[expert] = ops.scaled_fp8_quant( + self.w13_weight.data[expert, :, :] + ) + w2_weight[expert, :, :], self.w2_scale[expert] = ops.scaled_fp8_quant( + self.w2_weight.data[expert, :, :] + ) self.w13_weight = nn.Parameter(w13_weight, requires_grad=False) self.w2_weight = nn.Parameter(w2_weight, requires_grad=False) @@ -193,40 +228,40 @@ class MixtralMoE(nn.Module): if self.a13_scale is None or self.a2_scale is None: raise ValueError( "QuantConfig has static quantization, but found " - "activation scales are None.") + "activation scales are None." + ) - if (not all_close_1d(self.a13_scale) - or not all_close_1d(self.a2_scale)): + if not all_close_1d(self.a13_scale) or not all_close_1d(self.a2_scale): print_warning_once( "Found act_scales that are not equal for fp8 MoE layer. " - "Using the maximum across experts for each layer. ") + "Using the maximum across experts for each layer. " + ) - self.a13_scale = nn.Parameter(self.a13_scale.max(), - requires_grad=False) - self.a2_scale = nn.Parameter(self.a2_scale.max(), - requires_grad=False) + self.a13_scale = nn.Parameter(self.a13_scale.max(), requires_grad=False) + self.a2_scale = nn.Parameter(self.a2_scale.max(), requires_grad=False) def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: num_tokens, hidden_size = hidden_states.shape hidden_states = hidden_states.view(-1, self.hidden_size) # router_logits: (num_tokens, n_experts) router_logits, _ = self.gate(hidden_states) - final_hidden_states = fused_moe(hidden_states, - self.w13_weight, - self.w2_weight, - router_logits, - self.top_k, - renormalize=True, - inplace=True, - use_fp8=self.use_fp8, - w1_scale=self.w13_scale, - w2_scale=self.w2_scale, - a1_scale=self.a13_scale, - a2_scale=self.a2_scale) + final_hidden_states = fused_moe( + hidden_states, + self.w13_weight, + self.w2_weight, + router_logits, + self.top_k, + renormalize=True, + inplace=True, + use_fp8=self.use_fp8, + w1_scale=self.w13_scale, + w2_scale=self.w2_scale, + a1_scale=self.a13_scale, + a2_scale=self.a2_scale, + ) if self.tp_size > 1: - final_hidden_states = tensor_model_parallel_all_reduce( - final_hidden_states) + final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states) return final_hidden_states.view(num_tokens, hidden_size) @@ -335,7 +370,8 @@ class MixtralDecoderLayer(nn.Module): top_k=config.num_experts_per_tok, hidden_size=config.hidden_size, intermediate_size=config.intermediate_size, - quant_config=quant_config) + quant_config=quant_config, + ) self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.post_attention_layernorm = RMSNorm( config.hidden_size, eps=config.rms_norm_eps @@ -444,35 +480,48 @@ class MixtralForCausalLM(nn.Module): ("qkv_proj", "v_proj", "v"), ] - expert_params_mapping = [ - # These are the weight scales for the experts - # (param_name, weight_name, expert_id) - ("w13_scale" if weight_name in ["w1", "w3"] else "w2_scale", - f"experts.{expert_id}.{weight_name}.weight_scale", expert_id) - for expert_id in range(self.config.num_local_experts) - for weight_name in ["w1", "w2", "w3"] - ] + [ - # These are the weights for the experts - # (param_name, weight_name, expert_id) - ("w13_weight" if weight_name in ["w1", "w3"] else "w2_weight", - f"experts.{expert_id}.{weight_name}.weight", expert_id) - for expert_id in range(self.config.num_local_experts) - for weight_name in ["w1", "w2", "w3"] - ] + [ - # These are the activation scales for the experts - # (param_name, weight_name, expert_id) - ("a13_scale" if weight_name in ["w1", "w3"] else "a2_scale", - f"experts.{expert_id}.{weight_name}.act_scale", expert_id) - for expert_id in range(self.config.num_local_experts) - for weight_name in ["w1", "w2", "w3"] - ] + expert_params_mapping = ( + [ + # These are the weight scales for the experts + # (param_name, weight_name, expert_id) + ( + "w13_scale" if weight_name in ["w1", "w3"] else "w2_scale", + f"experts.{expert_id}.{weight_name}.weight_scale", + expert_id, + ) + for expert_id in range(self.config.num_local_experts) + for weight_name in ["w1", "w2", "w3"] + ] + + [ + # These are the weights for the experts + # (param_name, weight_name, expert_id) + ( + "w13_weight" if weight_name in ["w1", "w3"] else "w2_weight", + f"experts.{expert_id}.{weight_name}.weight", + expert_id, + ) + for expert_id in range(self.config.num_local_experts) + for weight_name in ["w1", "w2", "w3"] + ] + + [ + # These are the activation scales for the experts + # (param_name, weight_name, expert_id) + ( + "a13_scale" if weight_name in ["w1", "w3"] else "a2_scale", + f"experts.{expert_id}.{weight_name}.act_scale", + expert_id, + ) + for expert_id in range(self.config.num_local_experts) + for weight_name in ["w1", "w2", "w3"] + ] + ) params_dict = dict(self.named_parameters()) for name, loaded_weight in weights: if "rotary_emb.inv_freq" in name: continue - for (param_name, weight_name, shard_id) in stacked_params_mapping: + for param_name, weight_name, shard_id in stacked_params_mapping: if weight_name not in name: continue name = name.replace(weight_name, param_name) @@ -490,18 +539,18 @@ class MixtralForCausalLM(nn.Module): name = name.replace(weight_name, param_name) param = params_dict[name] weight_loader = param.weight_loader - weight_loader(param, - loaded_weight, - weight_name, - expert_id=expert_id) + weight_loader( + param, loaded_weight, weight_name, expert_id=expert_id + ) break else: # Skip loading extra bias for GPTQ models. if name.endswith(".bias") and name not in params_dict: continue param = params_dict[name] - weight_loader = getattr(param, "weight_loader", - default_weight_loader) + weight_loader = getattr( + param, "weight_loader", default_weight_loader + ) weight_loader(param, loaded_weight) diff --git a/python/sglang/srt/models/mixtral_quant.py b/python/sglang/srt/models/mixtral_quant.py index 94df124e8..aa8f8a759 100644 --- a/python/sglang/srt/models/mixtral_quant.py +++ b/python/sglang/srt/models/mixtral_quant.py @@ -28,7 +28,6 @@ from vllm.model_executor.layers.vocab_parallel_embedding import ( ) from vllm.model_executor.model_loader.weight_utils import default_weight_loader - from sglang.srt.layers.logits_processor import LogitsProcessor from sglang.srt.layers.radix_attention import RadixAttention from sglang.srt.managers.controller.model_runner import InputMetadata diff --git a/python/sglang/srt/models/qwen.py b/python/sglang/srt/models/qwen.py index c2ff0aeea..9c59d14fe 100644 --- a/python/sglang/srt/models/qwen.py +++ b/python/sglang/srt/models/qwen.py @@ -1,6 +1,6 @@ # Adapted from # https://github.com/vllm-project/vllm/blob/c7f2cf2b7f67bce5842fedfdba508440fe257375/vllm/model_executor/models/qwen.py#L1 -from typing import Any, Dict, Optional, Iterable, Tuple +from typing import Any, Dict, Iterable, Optional, Tuple import torch from torch import nn diff --git a/python/sglang/srt/models/qwen2.py b/python/sglang/srt/models/qwen2.py index 5d0115522..dc50075ca 100644 --- a/python/sglang/srt/models/qwen2.py +++ b/python/sglang/srt/models/qwen2.py @@ -1,7 +1,7 @@ # Adapted from llama2.py # Modify details for the adaptation of Qwen2 model. """Inference-only Qwen2 model compatible with HuggingFace weights.""" -from typing import Any, Dict, Optional, Tuple, Iterable +from typing import Any, Dict, Iterable, Optional, Tuple import torch from torch import nn diff --git a/python/sglang/srt/models/stablelm.py b/python/sglang/srt/models/stablelm.py index 72fa1508d..875ddd70b 100644 --- a/python/sglang/srt/models/stablelm.py +++ b/python/sglang/srt/models/stablelm.py @@ -2,7 +2,7 @@ # https://github.com/vllm-project/vllm/blob/c7f2cf2b7f67bce5842fedfdba508440fe257375/vllm/model_executor/models/stablelm.py#L1 """Inference-only StableLM-2 (https://huggingface.co/stabilityai/stablelm-2-1_6b) model compatible with HuggingFace weights.""" -from typing import Optional, Tuple, Iterable +from typing import Iterable, Optional, Tuple import torch from torch import nn diff --git a/python/sglang/srt/models/yivl.py b/python/sglang/srt/models/yivl.py index 2675502b0..3016bfe13 100644 --- a/python/sglang/srt/models/yivl.py +++ b/python/sglang/srt/models/yivl.py @@ -1,14 +1,14 @@ """Inference-only Yi-VL model.""" -from typing import Tuple, Iterable, Optional +from typing import Iterable, Optional, Tuple import torch import torch.nn as nn from transformers import CLIPVisionModel, LlavaConfig from vllm.config import CacheConfig +from vllm.model_executor.layers.quantization.base_config import QuantizationConfig from vllm.model_executor.model_loader.weight_utils import default_weight_loader -from vllm.model_executor.layers.quantization.base_config import QuantizationConfig from sglang.srt.models.llava import ( LlavaLlamaForCausalLM, monkey_path_clip_vision_embed_forward, diff --git a/python/sglang/srt/openai_api_adapter.py b/python/sglang/srt/openai_api_adapter.py index 1230dc07c..75656f324 100644 --- a/python/sglang/srt/openai_api_adapter.py +++ b/python/sglang/srt/openai_api_adapter.py @@ -6,7 +6,7 @@ import os from http import HTTPStatus from fastapi import Request -from fastapi.responses import StreamingResponse, JSONResponse +from fastapi.responses import JSONResponse, StreamingResponse from sglang.srt.conversation import ( Conversation, @@ -40,21 +40,18 @@ chat_template_name = None def create_error_response( message: str, err_type: str = "BadRequestError", - status_code: HTTPStatus = HTTPStatus.BAD_REQUEST): - error = ErrorResponse(message=message, - type=err_type, - code=status_code.value) - return JSONResponse(content=error.model_dump(), - status_code=error.code) + status_code: HTTPStatus = HTTPStatus.BAD_REQUEST, +): + error = ErrorResponse(message=message, type=err_type, code=status_code.value) + return JSONResponse(content=error.model_dump(), status_code=error.code) def create_streaming_error_response( message: str, err_type: str = "BadRequestError", - status_code: HTTPStatus = HTTPStatus.BAD_REQUEST) -> str: - error = ErrorResponse(message=message, - type=err_type, - code=status_code.value) + status_code: HTTPStatus = HTTPStatus.BAD_REQUEST, +) -> str: + error = ErrorResponse(message=message, type=err_type, code=status_code.value) json_str = json.dumps({"error": error.model_dump()}) return json_str @@ -125,7 +122,8 @@ async def v1_completions(tokenizer_manager, raw_request: Request): n_prev_token = 0 try: async for content in tokenizer_manager.generate_request( - adapted_request, raw_request): + adapted_request, raw_request + ): text = content["text"] prompt_tokens = content["meta_info"]["prompt_tokens"] completion_tokens = content["meta_info"]["completion_tokens"] @@ -154,12 +152,14 @@ async def v1_completions(tokenizer_manager, raw_request: Request): decode_token_logprobs=content["meta_info"][ "decode_token_logprobs" ][n_prev_token:], - decode_top_logprobs=content["meta_info"]["decode_top_logprobs"][ - n_prev_token: - ], + decode_top_logprobs=content["meta_info"][ + "decode_top_logprobs" + ][n_prev_token:], ) - n_prev_token = len(content["meta_info"]["decode_token_logprobs"]) + n_prev_token = len( + content["meta_info"]["decode_token_logprobs"] + ) else: logprobs = None @@ -188,13 +188,17 @@ async def v1_completions(tokenizer_manager, raw_request: Request): yield f"data: {error}\n\n" yield "data: [DONE]\n\n" - return StreamingResponse(generate_stream_resp(), media_type="text/event-stream", - background=tokenizer_manager.create_abort_task(adapted_request)) + return StreamingResponse( + generate_stream_resp(), + media_type="text/event-stream", + background=tokenizer_manager.create_abort_task(adapted_request), + ) # Non-streaming response. try: ret = await tokenizer_manager.generate_request( - adapted_request, raw_request).__anext__() + adapted_request, raw_request + ).__anext__() except ValueError as e: return create_error_response(str(e)) @@ -299,7 +303,9 @@ async def v1_chat_completions(tokenizer_manager, raw_request: Request): stream_buffer = "" try: - async for content in tokenizer_manager.generate_request(adapted_request, raw_request): + async for content in tokenizer_manager.generate_request( + adapted_request, raw_request + ): if is_first: # First chunk with role is_first = False @@ -334,13 +340,17 @@ async def v1_chat_completions(tokenizer_manager, raw_request: Request): yield f"data: {error}\n\n" yield "data: [DONE]\n\n" - return StreamingResponse(generate_stream_resp(), media_type="text/event-stream", - background=tokenizer_manager.create_abort_task(adapted_request)) + return StreamingResponse( + generate_stream_resp(), + media_type="text/event-stream", + background=tokenizer_manager.create_abort_task(adapted_request), + ) # Non-streaming response. try: ret = await tokenizer_manager.generate_request( - adapted_request, raw_request).__anext__() + adapted_request, raw_request + ).__anext__() except ValueError as e: return create_error_response(str(e)) diff --git a/python/sglang/srt/server.py b/python/sglang/srt/server.py index 7b6dca68f..c98a760c5 100644 --- a/python/sglang/srt/server.py +++ b/python/sglang/srt/server.py @@ -13,7 +13,7 @@ import sys import threading import time from http import HTTPStatus -from typing import Optional, Dict +from typing import Dict, Optional # Fix a bug of Python threading setattr(threading, "_register_atexit", lambda *args, **kwargs: None) @@ -29,10 +29,14 @@ from fastapi.responses import JSONResponse, Response, StreamingResponse from sglang.backend.runtime_endpoint import RuntimeEndpoint from sglang.srt.constrained import disable_cache from sglang.srt.hf_transformers_utils import get_tokenizer +from sglang.srt.managers.controller.manager_multi import ( + start_controller_process as start_controller_process_multi, +) +from sglang.srt.managers.controller.manager_single import ( + start_controller_process as start_controller_process_single, +) from sglang.srt.managers.detokenizer_manager import start_detokenizer_process from sglang.srt.managers.io_struct import GenerateReqInput -from sglang.srt.managers.controller.manager_single import start_controller_process as start_controller_process_single -from sglang.srt.managers.controller.manager_multi import start_controller_process as start_controller_process_multi from sglang.srt.managers.tokenizer_manager import TokenizerManager from sglang.srt.openai_api_adapter import ( load_chat_template_for_openai_api, @@ -97,8 +101,11 @@ async def generate_request(obj: GenerateReqInput, request: Request): yield f"data: {json.dumps(out, ensure_ascii=False)}\n\n" yield "data: [DONE]\n\n" - return StreamingResponse(stream_results(), media_type="text/event-stream", - background=tokenizer_manager.create_abort_task(obj)) + return StreamingResponse( + stream_results(), + media_type="text/event-stream", + background=tokenizer_manager.create_abort_task(obj), + ) else: try: ret = await tokenizer_manager.generate_request(obj, request).__anext__() diff --git a/python/sglang/srt/utils.py b/python/sglang/srt/utils.py index 43ae6f62a..8a7f33eb6 100644 --- a/python/sglang/srt/utils.py +++ b/python/sglang/srt/utils.py @@ -1,8 +1,8 @@ """Common utilities.""" import base64 -import multiprocessing import logging +import multiprocessing import os import random import socket @@ -17,12 +17,11 @@ import requests import rpyc import torch import triton -from rpyc.utils.server import ThreadedServer from fastapi.responses import JSONResponse from packaging import version as pkg_version +from rpyc.utils.server import ThreadedServer from starlette.middleware.base import BaseHTTPMiddleware - logger = logging.getLogger(__name__) @@ -377,7 +376,7 @@ def init_rpyc_service(service: rpyc.Service, port: int): protocol_config={ "allow_public_attrs": True, "allow_pickle": True, - "sync_request_timeout": 3600 + "sync_request_timeout": 3600, }, ) t.logger.setLevel(logging.WARN) @@ -396,7 +395,7 @@ def connect_to_rpyc_service(port, host="localhost"): config={ "allow_public_attrs": True, "allow_pickle": True, - "sync_request_timeout": 3600 + "sync_request_timeout": 3600, }, ) break @@ -423,7 +422,9 @@ def suppress_other_loggers(): vllm_default_logger.setLevel(logging.WARN) logging.getLogger("vllm.config").setLevel(logging.ERROR) - logging.getLogger("vllm.distributed.device_communicators.pynccl").setLevel(logging.WARN) + logging.getLogger("vllm.distributed.device_communicators.pynccl").setLevel( + logging.WARN + ) logging.getLogger("vllm.selector").setLevel(logging.WARN) logging.getLogger("vllm.utils").setLevel(logging.WARN) @@ -464,6 +465,7 @@ def monkey_patch_vllm_p2p_access_check(gpu_id: int): device_name = torch.cuda.get_device_name(gpu_id) if "RTX 40" not in device_name: import vllm.distributed.device_communicators.custom_all_reduce_utils as tgt + setattr(tgt, "gpu_p2p_access_check", lambda *arg, **kwargs: True) @@ -485,4 +487,3 @@ class APIKeyValidatorMiddleware(BaseHTTPMiddleware): ) response = await call_next(request) return response - diff --git a/python/sglang/test/test_programs.py b/python/sglang/test/test_programs.py index 4ad480887..6fa8f8214 100644 --- a/python/sglang/test/test_programs.py +++ b/python/sglang/test/test_programs.py @@ -356,16 +356,25 @@ def test_completion_speculative(): s += "Construct a character within the following format:\n" s += "Name: Steve Jobs.\nBirthday: February 24, 1955.\nJob: Apple CEO.\n" s += "\nPlease generate new Name, Birthday and Job.\n" - s += "Name:" + sgl.gen("name", stop="\n") + "\nBirthday:" + sgl.gen("birthday", stop="\n") + s += ( + "Name:" + + sgl.gen("name", stop="\n") + + "\nBirthday:" + + sgl.gen("birthday", stop="\n") + ) s += "\nJob:" + sgl.gen("job", stop="\n") + "\n" - @sgl.function def gen_character_no_spec(s): s += "Construct a character within the following format:\n" s += "Name: Steve Jobs.\nBirthday: February 24, 1955.\nJob: Apple CEO.\n" s += "\nPlease generate new Name, Birthday and Job.\n" - s += "Name:" + sgl.gen("name", stop="\n") + "\nBirthday:" + sgl.gen("birthday", stop="\n") + s += ( + "Name:" + + sgl.gen("name", stop="\n") + + "\nBirthday:" + + sgl.gen("birthday", stop="\n") + ) s += "\nJob:" + sgl.gen("job", stop="\n") + "\n" token_usage = sgl.global_config.default_backend.token_usage @@ -378,7 +387,9 @@ def test_completion_speculative(): gen_character_no_spec().sync() usage_with_no_spec = token_usage.prompt_tokens - assert usage_with_spec < usage_with_no_spec, f"{usage_with_spec} vs {usage_with_no_spec}" + assert ( + usage_with_spec < usage_with_no_spec + ), f"{usage_with_spec} vs {usage_with_no_spec}" def test_chat_completion_speculative(): @@ -386,8 +397,17 @@ def test_chat_completion_speculative(): def gen_character_spec(s): s += sgl.system("You are a helpful assistant.") s += sgl.user("Construct a character within the following format:") - s += sgl.assistant("Name: Steve Jobs.\nBirthday: February 24, 1955.\nJob: Apple CEO.\n") + s += sgl.assistant( + "Name: Steve Jobs.\nBirthday: February 24, 1955.\nJob: Apple CEO.\n" + ) s += sgl.user("Please generate new Name, Birthday and Job.\n") - s += sgl.assistant("Name:" + sgl.gen("name", stop="\n") + "\nBirthday:" + sgl.gen("birthday", stop="\n") + "\nJob:" + sgl.gen("job", stop="\n")) + s += sgl.assistant( + "Name:" + + sgl.gen("name", stop="\n") + + "\nBirthday:" + + sgl.gen("birthday", stop="\n") + + "\nJob:" + + sgl.gen("job", stop="\n") + ) - gen_character_spec().sync() \ No newline at end of file + gen_character_spec().sync() diff --git a/python/sglang/utils.py b/python/sglang/utils.py index e4a3e9adb..0f5fd4390 100644 --- a/python/sglang/utils.py +++ b/python/sglang/utils.py @@ -15,7 +15,6 @@ from json import dumps import numpy as np import requests - logger = logging.getLogger(__name__) @@ -255,8 +254,10 @@ def run_with_timeout(func, args=(), kwargs=None, timeout=None): def graceful_registry(sub_module_name): def graceful_shutdown(signum, frame): - logger.info(f"{sub_module_name} Received signal to shutdown. Performing graceful shutdown...") + logger.info( + f"{sub_module_name} Received signal to shutdown. Performing graceful shutdown..." + ) if signum == signal.SIGTERM: logger.info(f"{sub_module_name} recive sigterm") - signal.signal(signal.SIGTERM, graceful_shutdown) \ No newline at end of file + signal.signal(signal.SIGTERM, graceful_shutdown) diff --git a/test/lang/test_openai_backend.py b/test/lang/test_openai_backend.py index b028bc0ab..d35495e4d 100644 --- a/test/lang/test_openai_backend.py +++ b/test/lang/test_openai_backend.py @@ -2,6 +2,8 @@ import unittest from sglang import OpenAI, set_default_backend from sglang.test.test_programs import ( + test_chat_completion_speculative, + test_completion_speculative, test_decode_int, test_decode_json, test_expert_answer, @@ -14,8 +16,6 @@ from sglang.test.test_programs import ( test_select, test_stream, test_tool_use, - test_completion_speculative, - test_chat_completion_speculative ) @@ -97,4 +97,4 @@ if __name__ == "__main__": # global_config.verbosity = 2 # t = TestOpenAIBackend() # t.setUp() - # t.test_chat_completion_speculative() \ No newline at end of file + # t.test_chat_completion_speculative()