[Minor] move triton attention kernels into a separate folder (#1379)
@@ -57,9 +57,9 @@ import pandas as pd
 import torch
 import torch.distributed as dist

+from sglang.srt.configs.model_config import ModelConfig
 from sglang.srt.hf_transformers_utils import get_tokenizer
 from sglang.srt.managers.schedule_batch import Req, ScheduleBatch
-from sglang.srt.model_config import ModelConfig
 from sglang.srt.model_executor.model_runner import ModelRunner
 from sglang.srt.sampling.sampling_params import SamplingParams
 from sglang.srt.server_args import ServerArgs
@@ -13,6 +13,8 @@ See the License for the specific language governing permissions and
 limitations under the License.
 """

+"""For constrained decoding."""
+
 import json
 from typing import Dict, Optional, Union
@@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
 limitations under the License.
 """

-"""Conversation templates."""
+"""Conversation chat templates."""

 # Adapted from
 # https://github.com/lm-sys/FastChat/blob/main/fastchat/conversation.py
@@ -16,11 +16,9 @@ limitations under the License.
 """Utilities for Huggingface Transformers."""

 import contextlib
-import functools
-import json
 import os
 import warnings
-from typing import AbstractSet, Collection, Dict, List, Literal, Optional, Type, Union
+from typing import Dict, Optional, Type, Union

 from huggingface_hub import snapshot_download
 from transformers import (
@@ -22,13 +22,20 @@ from flashinfer.cascade import merge_state
 from torch import nn

 from sglang.global_config import global_config
-from sglang.srt.layers.decode_attention import decode_attention_fwd
-from sglang.srt.layers.extend_attention import extend_attention_fwd
+from sglang.srt.layers.triton_attention.decode_attention import decode_attention_fwd
+from sglang.srt.layers.triton_attention.extend_attention import extend_attention_fwd
 from sglang.srt.model_executor.forward_batch_info import ForwardMode, InputMetadata
 from sglang.srt.model_executor.model_runner import global_server_args_dict


 class RadixAttention(nn.Module):
+    """
+    The attention layer implementation.
+    Now it has two backends: FlashInfer and Triton.
+    FlashInfer is faster and Triton is easier to customize.
+    It supports two operators: extend (i.e. prefill with cached prefix) and decode.
+    """
+
     def __init__(
         self,
         num_heads: int,
@@ -49,8 +56,10 @@ class RadixAttention(nn.Module):
         self.v_head_dim = v_head_dim if v_head_dim != -1 else head_dim
         self.scaling = scaling
         self.layer_id = layer_id
+        self.logit_cap = logit_cap if logit_cap is not None and logit_cap > 0 else 0
         self.sliding_window_size = sliding_window_size if sliding_window_size else -1

+        # Choose backend
         if (
             not global_server_args_dict.get("disable_flashinfer", False)
             and self.qk_head_dim == self.v_head_dim
@@ -61,8 +70,6 @@
             self.extend_forward = self.extend_forward_triton
             self.decode_forward = self.decode_forward_triton

-        self.logit_cap = logit_cap if logit_cap is not None and logit_cap > 0 else 0
-
     def extend_forward_triton(self, q, k, v, input_metadata: InputMetadata):
         if self.qk_head_dim != self.v_head_dim:
             o = q.new_empty((q.shape[0], self.tp_q_head_num * self.v_head_dim))
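The two RadixAttention hunks above read more easily with the surrounding pattern in mind: the constructor selects a backend once and binds the extend/decode entry points to bound methods, so the forward path does not re-check flags on every call. Below is a minimal sketch of that dispatch, assuming simplified constructor arguments and stub kernels; the real class also wires up FlashInfer wrappers, sliding windows, and the logit cap shown above.

# Sketch only: a trimmed-down RadixAttention showing constructor-time backend dispatch.
class RadixAttentionSketch:
    def __init__(self, disable_flashinfer: bool, qk_head_dim: int, v_head_dim: int):
        # FlashInfer is faster but assumes qk_head_dim == v_head_dim;
        # Triton is easier to customize and covers the remaining cases.
        if not disable_flashinfer and qk_head_dim == v_head_dim:
            self.extend_forward = self.extend_forward_flashinfer
            self.decode_forward = self.decode_forward_flashinfer
        else:
            self.extend_forward = self.extend_forward_triton
            self.decode_forward = self.decode_forward_triton

    # Stubs standing in for the kernels this commit relocates.
    def extend_forward_triton(self, q, k, v, input_metadata):
        raise NotImplementedError  # would call extend_attention_fwd

    def decode_forward_triton(self, q, k, v, input_metadata):
        raise NotImplementedError  # would call decode_attention_fwd

    def extend_forward_flashinfer(self, q, k, v, input_metadata):
        raise NotImplementedError

    def decode_forward_flashinfer(self, q, k, v, input_metadata):
        raise NotImplementedError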
@@ -22,7 +22,7 @@ import torch
 import triton
 import triton.language as tl

-from sglang.srt.layers.prefill_attention import context_attention_fwd
+from sglang.srt.layers.triton_attention.prefill_attention import context_attention_fwd

 CUDA_CAPABILITY = torch.cuda.get_device_capability()
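This hunk completes the relocation: together with the decode and extend kernels above, all three Triton attention kernels now resolve under layers/triton_attention/. Collecting the old and new import paths that appear in this commit (the package's __init__.py is assumed, as it is not shown in the diff):

# Before this commit:
#   from sglang.srt.layers.decode_attention import decode_attention_fwd
#   from sglang.srt.layers.extend_attention import extend_attention_fwd
#   from sglang.srt.layers.prefill_attention import context_attention_fwd
# After the move into the triton_attention package:
from sglang.srt.layers.triton_attention.decode_attention import decode_attention_fwd
from sglang.srt.layers.triton_attention.extend_attention import extend_attention_fwd
from sglang.srt.layers.triton_attention.prefill_attention import context_attention_fwd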
@@ -29,6 +29,7 @@ import torch.distributed
 import torch.distributed as dist

 from sglang.global_config import global_config
+from sglang.srt.configs.model_config import ModelConfig
 from sglang.srt.constrained.fsm_cache import FSMCache
 from sglang.srt.constrained.jump_forward import JumpForwardCache
 from sglang.srt.hf_transformers_utils import get_processor, get_tokenizer
@@ -52,7 +53,6 @@ from sglang.srt.managers.schedule_batch import (
 )
 from sglang.srt.mem_cache.chunk_cache import ChunkCache
 from sglang.srt.mem_cache.radix_cache import RadixCache
-from sglang.srt.model_config import ModelConfig
 from sglang.srt.model_executor.model_runner import ModelRunner
 from sglang.srt.server_args import ServerArgs
 from sglang.srt.utils import (
@@ -15,7 +15,7 @@ See the License for the specific language governing permissions and
 limitations under the License.
 """

-"""ModelRunner runs the forward passes of the models."""
+"""Meta data for a forward pass."""
 from dataclasses import dataclass
 from enum import IntEnum, auto
 from typing import TYPE_CHECKING, List
@@ -18,7 +18,6 @@ limitations under the License.
 import gc
 import importlib
 import importlib.resources
-import json
 import logging
 import pkgutil
 from functools import lru_cache
@@ -45,6 +44,7 @@ from vllm.model_executor.model_loader import get_model
 from vllm.model_executor.models import ModelRegistry

 from sglang.global_config import global_config
+from sglang.srt.configs.model_config import AttentionArch, ModelConfig
 from sglang.srt.layers.logits_processor import LogitsProcessorOutput
 from sglang.srt.layers.sampler import SampleOutput
 from sglang.srt.managers.schedule_batch import ScheduleBatch, global_server_args_dict
@@ -53,7 +53,6 @@ from sglang.srt.mem_cache.memory_pool import (
     MLATokenToKVPool,
     ReqToTokenPool,
 )
-from sglang.srt.model_config import AttentionArch, ModelConfig
 from sglang.srt.model_executor.forward_batch_info import ForwardMode, InputMetadata
 from sglang.srt.server_args import ServerArgs
 from sglang.srt.utils import (
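Alongside the kernel move, the hunks above consistently retarget the ModelConfig import: every caller now pulls it (and AttentionArch, where used) from the configs package instead of the old top-level module. Summarizing just the paths visible in this diff:

# Old module path (imports removed in this commit):
#   from sglang.srt.model_config import AttentionArch, ModelConfig
# New module path (imports added in this commit):
from sglang.srt.configs.model_config import AttentionArch, ModelConfig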