diff --git a/python/sglang/bench_latency.py b/python/sglang/bench_latency.py index be6795846..bfe739432 100644 --- a/python/sglang/bench_latency.py +++ b/python/sglang/bench_latency.py @@ -57,9 +57,9 @@ import pandas as pd import torch import torch.distributed as dist +from sglang.srt.configs.model_config import ModelConfig from sglang.srt.hf_transformers_utils import get_tokenizer from sglang.srt.managers.schedule_batch import Req, ScheduleBatch -from sglang.srt.model_config import ModelConfig from sglang.srt.model_executor.model_runner import ModelRunner from sglang.srt.sampling.sampling_params import SamplingParams from sglang.srt.server_args import ServerArgs diff --git a/python/sglang/srt/model_config.py b/python/sglang/srt/configs/model_config.py similarity index 100% rename from python/sglang/srt/model_config.py rename to python/sglang/srt/configs/model_config.py diff --git a/python/sglang/srt/constrained/__init__.py b/python/sglang/srt/constrained/__init__.py index 7e097c6fc..c47c5c8dd 100644 --- a/python/sglang/srt/constrained/__init__.py +++ b/python/sglang/srt/constrained/__init__.py @@ -13,6 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. """ +"""For constrained decoding.""" + import json from typing import Dict, Optional, Union diff --git a/python/sglang/srt/conversation.py b/python/sglang/srt/conversation.py index 9a1227218..341551eca 100644 --- a/python/sglang/srt/conversation.py +++ b/python/sglang/srt/conversation.py @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. """ -"""Conversation templates.""" +"""Conversation chat templates.""" # Adapted from # https://github.com/lm-sys/FastChat/blob/main/fastchat/conversation.py diff --git a/python/sglang/srt/hf_transformers_utils.py b/python/sglang/srt/hf_transformers_utils.py index ae3070c5a..f6c414ec3 100644 --- a/python/sglang/srt/hf_transformers_utils.py +++ b/python/sglang/srt/hf_transformers_utils.py @@ -16,11 +16,9 @@ limitations under the License. """Utilities for Huggingface Transformers.""" import contextlib -import functools -import json import os import warnings -from typing import AbstractSet, Collection, Dict, List, Literal, Optional, Type, Union +from typing import Dict, Optional, Type, Union from huggingface_hub import snapshot_download from transformers import ( diff --git a/python/sglang/srt/layers/radix_attention.py b/python/sglang/srt/layers/radix_attention.py index 1a2feacd3..adada7cda 100644 --- a/python/sglang/srt/layers/radix_attention.py +++ b/python/sglang/srt/layers/radix_attention.py @@ -22,13 +22,20 @@ from flashinfer.cascade import merge_state from torch import nn from sglang.global_config import global_config -from sglang.srt.layers.decode_attention import decode_attention_fwd -from sglang.srt.layers.extend_attention import extend_attention_fwd +from sglang.srt.layers.triton_attention.decode_attention import decode_attention_fwd +from sglang.srt.layers.triton_attention.extend_attention import extend_attention_fwd from sglang.srt.model_executor.forward_batch_info import ForwardMode, InputMetadata from sglang.srt.model_executor.model_runner import global_server_args_dict class RadixAttention(nn.Module): + """ + The attention layer implementation. + Now it has two backends: FlashInfer and Triton. + FlashInfer is faster and Triton is easier to customize. + It supports two operators: extend (i.e. prefill with cached prefix) and decode. + """ + def __init__( self, num_heads: int, @@ -49,8 +56,10 @@ class RadixAttention(nn.Module): self.v_head_dim = v_head_dim if v_head_dim != -1 else head_dim self.scaling = scaling self.layer_id = layer_id + self.logit_cap = logit_cap if logit_cap is not None and logit_cap > 0 else 0 self.sliding_window_size = sliding_window_size if sliding_window_size else -1 + # Choose backend if ( not global_server_args_dict.get("disable_flashinfer", False) and self.qk_head_dim == self.v_head_dim @@ -61,8 +70,6 @@ class RadixAttention(nn.Module): self.extend_forward = self.extend_forward_triton self.decode_forward = self.decode_forward_triton - self.logit_cap = logit_cap if logit_cap is not None and logit_cap > 0 else 0 - def extend_forward_triton(self, q, k, v, input_metadata: InputMetadata): if self.qk_head_dim != self.v_head_dim: o = q.new_empty((q.shape[0], self.tp_q_head_num * self.v_head_dim)) diff --git a/python/sglang/srt/layers/decode_attention.py b/python/sglang/srt/layers/triton_attention/decode_attention.py similarity index 100% rename from python/sglang/srt/layers/decode_attention.py rename to python/sglang/srt/layers/triton_attention/decode_attention.py diff --git a/python/sglang/srt/layers/extend_attention.py b/python/sglang/srt/layers/triton_attention/extend_attention.py similarity index 99% rename from python/sglang/srt/layers/extend_attention.py rename to python/sglang/srt/layers/triton_attention/extend_attention.py index 5c8e51c5f..81039e676 100644 --- a/python/sglang/srt/layers/extend_attention.py +++ b/python/sglang/srt/layers/triton_attention/extend_attention.py @@ -22,7 +22,7 @@ import torch import triton import triton.language as tl -from sglang.srt.layers.prefill_attention import context_attention_fwd +from sglang.srt.layers.triton_attention.prefill_attention import context_attention_fwd CUDA_CAPABILITY = torch.cuda.get_device_capability() diff --git a/python/sglang/srt/layers/prefill_attention.py b/python/sglang/srt/layers/triton_attention/prefill_attention.py similarity index 100% rename from python/sglang/srt/layers/prefill_attention.py rename to python/sglang/srt/layers/triton_attention/prefill_attention.py diff --git a/python/sglang/srt/managers/tp_worker.py b/python/sglang/srt/managers/tp_worker.py index 736929a65..fe7c4bcab 100644 --- a/python/sglang/srt/managers/tp_worker.py +++ b/python/sglang/srt/managers/tp_worker.py @@ -29,6 +29,7 @@ import torch.distributed import torch.distributed as dist from sglang.global_config import global_config +from sglang.srt.configs.model_config import ModelConfig from sglang.srt.constrained.fsm_cache import FSMCache from sglang.srt.constrained.jump_forward import JumpForwardCache from sglang.srt.hf_transformers_utils import get_processor, get_tokenizer @@ -52,7 +53,6 @@ from sglang.srt.managers.schedule_batch import ( ) from sglang.srt.mem_cache.chunk_cache import ChunkCache from sglang.srt.mem_cache.radix_cache import RadixCache -from sglang.srt.model_config import ModelConfig from sglang.srt.model_executor.model_runner import ModelRunner from sglang.srt.server_args import ServerArgs from sglang.srt.utils import ( diff --git a/python/sglang/srt/model_executor/forward_batch_info.py b/python/sglang/srt/model_executor/forward_batch_info.py index c1fb23357..867bd95a1 100644 --- a/python/sglang/srt/model_executor/forward_batch_info.py +++ b/python/sglang/srt/model_executor/forward_batch_info.py @@ -15,7 +15,7 @@ See the License for the specific language governing permissions and limitations under the License. """ -"""ModelRunner runs the forward passes of the models.""" +"""Meta data for a forward pass.""" from dataclasses import dataclass from enum import IntEnum, auto from typing import TYPE_CHECKING, List diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py index 3cb123c48..3033a7ce4 100644 --- a/python/sglang/srt/model_executor/model_runner.py +++ b/python/sglang/srt/model_executor/model_runner.py @@ -18,7 +18,6 @@ limitations under the License. import gc import importlib import importlib.resources -import json import logging import pkgutil from functools import lru_cache @@ -45,6 +44,7 @@ from vllm.model_executor.model_loader import get_model from vllm.model_executor.models import ModelRegistry from sglang.global_config import global_config +from sglang.srt.configs.model_config import AttentionArch, ModelConfig from sglang.srt.layers.logits_processor import LogitsProcessorOutput from sglang.srt.layers.sampler import SampleOutput from sglang.srt.managers.schedule_batch import ScheduleBatch, global_server_args_dict @@ -53,7 +53,6 @@ from sglang.srt.mem_cache.memory_pool import ( MLATokenToKVPool, ReqToTokenPool, ) -from sglang.srt.model_config import AttentionArch, ModelConfig from sglang.srt.model_executor.forward_batch_info import ForwardMode, InputMetadata from sglang.srt.server_args import ServerArgs from sglang.srt.utils import ( diff --git a/scripts/deprecated/test_flashinfer.py b/scripts/deprecated/test_flashinfer.py index 638647677..7f0a081f6 100644 --- a/scripts/deprecated/test_flashinfer.py +++ b/scripts/deprecated/test_flashinfer.py @@ -6,8 +6,11 @@ from flashinfer import ( ) from flashinfer.decode import _grouped_size_compiled_for_decode_kernels -from sglang.srt.layers.extend_attention import extend_attention_fwd, redundant_attention from sglang.srt.layers.token_attention import token_attention_fwd +from sglang.srt.layers.triton_attention.extend_attention import ( + extend_attention_fwd, + redundant_attention, +) flashinfer_prefill_wrapper = None flashinfer_decode_wrapper = None