Refactor attention backend (#1381)

This commit is contained in:
Lianmin Zheng
2024-09-11 11:44:26 -07:00
committed by GitHub
parent c03cece42f
commit fec185ce0c
16 changed files with 568 additions and 564 deletions

View File

@@ -25,12 +25,6 @@ from typing import Optional, Tuple, Type
import torch
import torch.nn as nn
from flashinfer import (
BatchDecodeWithPagedKVCacheWrapper,
BatchPrefillWithPagedKVCacheWrapper,
BatchPrefillWithRaggedKVCacheWrapper,
)
from flashinfer.decode import _grouped_size_compiled_for_decode_kernels
from vllm.config import DeviceConfig, LoadConfig
from vllm.config import ModelConfig as VllmModelConfig
from vllm.distributed import (
@@ -43,8 +37,8 @@ from vllm.distributed.parallel_state import in_the_same_node_as
from vllm.model_executor.model_loader import get_model
from vllm.model_executor.models import ModelRegistry
from sglang.global_config import global_config
from sglang.srt.configs.model_config import AttentionArch, ModelConfig
from sglang.srt.layers.attention_backend import FlashInferAttnBackend, TritonAttnBackend
from sglang.srt.layers.logits_processor import LogitsProcessorOutput
from sglang.srt.layers.sampler import SampleOutput
from sglang.srt.managers.schedule_batch import ScheduleBatch, global_server_args_dict
@@ -69,6 +63,8 @@ logger = logging.getLogger(__name__)
class ModelRunner:
"""ModelRunner runs the forward passes of the models."""
def __init__(
self,
model_config: ModelConfig,
@@ -100,6 +96,7 @@ class ModelRunner:
}
)
# Model-specific adjustment
if self.is_multimodal_model:
logger.info(
"Automatically turn off --chunked-prefill-size and adjust --mem-fraction-static for multimodal models."
@@ -107,6 +104,7 @@ class ModelRunner:
server_args.chunked_prefill_size = None
server_args.mem_fraction_static *= 0.95
# Init components
min_per_gpu_memory = self.init_torch_distributed()
self.load_model()
self.init_memory_pool(
@@ -115,7 +113,7 @@ class ModelRunner:
server_args.max_total_tokens,
)
self.init_cublas()
self.init_flashinfer()
self.init_attention_backend()
self.init_cuda_graphs()
def init_torch_distributed(self):
@@ -397,9 +395,6 @@ class ModelRunner:
qk_rope_head_dim=self.model_config.qk_rope_head_dim,
layer_num=self.model_config.num_hidden_layers,
)
logger.info("using MLA Triton implementaion, flashinfer is disabled")
# FIXME: temporarily only Triton MLA is supported
self.server_args.attention_backend = "triton"
else:
self.token_to_kv_pool = MHATokenToKVPool(
self.max_total_num_tokens,
@@ -422,106 +417,42 @@ class ModelRunner:
c = a @ b
return c
def init_flashinfer(self):
"""Init flashinfer attention kernel wrappers."""
if self.server_args.attention_backend != "flashinfer":
assert (
self.sliding_window_size is None
), "turn on flashinfer to support window attention"
self.flashinfer_prefill_wrapper_ragged = None
self.flashinfer_prefill_wrapper_paged = None
self.flashinfer_decode_wrapper = None
return
if not _grouped_size_compiled_for_decode_kernels(
self.model_config.num_attention_heads // self.tp_size,
self.model_config.get_num_kv_heads(self.tp_size),
):
use_tensor_cores = True
def init_attention_backend(self):
"""Init attention kernel backend."""
if self.server_args.attention_backend == "flashinfer":
self.attn_backend = FlashInferAttnBackend(self)
elif self.server_args.attention_backend == "triton":
assert self.sliding_window_size is None, (
"Window attention is not supported in the triton attention backend. "
"Please use `--attention-backend flashinfer`."
)
self.attn_backend = TritonAttnBackend(self)
else:
use_tensor_cores = False
if self.sliding_window_size is None:
self.flashinfer_workspace_buffer = torch.empty(
global_config.flashinfer_workspace_size,
dtype=torch.uint8,
device="cuda",
raise ValueError(
f"Invalid attention backend: {self.server_args.attention_backend}"
)
self.flashinfer_prefill_wrapper_ragged = (
BatchPrefillWithRaggedKVCacheWrapper(
self.flashinfer_workspace_buffer, "NHD"
)
)
self.flashinfer_prefill_wrapper_paged = BatchPrefillWithPagedKVCacheWrapper(
self.flashinfer_workspace_buffer, "NHD"
)
self.flashinfer_decode_wrapper = BatchDecodeWithPagedKVCacheWrapper(
self.flashinfer_workspace_buffer,
"NHD",
use_tensor_cores=use_tensor_cores,
)
else:
self.flashinfer_workspace_buffer = torch.empty(
global_config.flashinfer_workspace_size,
dtype=torch.uint8,
device="cuda",
)
self.flashinfer_prefill_wrapper_ragged = None
self.flashinfer_prefill_wrapper_paged = []
self.flashinfer_decode_wrapper = []
for i in range(2):
self.flashinfer_prefill_wrapper_paged.append(
BatchPrefillWithPagedKVCacheWrapper(
self.flashinfer_workspace_buffer, "NHD"
)
)
self.flashinfer_decode_wrapper.append(
BatchDecodeWithPagedKVCacheWrapper(
self.flashinfer_workspace_buffer,
"NHD",
use_tensor_cores=use_tensor_cores,
)
)
def init_cuda_graphs(self):
"""Capture cuda graphs."""
from sglang.srt.model_executor.cuda_graph_runner import CudaGraphRunner
self.cuda_graph_runner = None
if not self.is_generation:
# TODO: Currently, cuda graph only captures decode steps, which only exists for generation models
return
from sglang.srt.model_executor.cuda_graph_runner import CudaGraphRunner
if self.server_args.disable_cuda_graph:
return
if (
self.server_args.disable_cuda_graph
or self.server_args.attention_backend != "flashinfer"
):
self.cuda_graph_runner = None
if self.server_args.attention_backend != "flashinfer":
logger.warning(
f"Cuda graph is not supported for attention backend: {self.server_args.attention_backend}"
)
return
logger.info("Capture cuda graph begin. This can take up to several minutes.")
if self.server_args.disable_cuda_graph_padding:
batch_size_list = list(range(1, 32)) + [64, 128]
else:
batch_size_list = [1, 2, 4] + [i * 8 for i in range(1, 21)]
self.cuda_graph_runner = CudaGraphRunner(
self,
max_batch_size_to_capture=max(batch_size_list),
use_torch_compile=self.server_args.enable_torch_compile,
disable_padding=self.server_args.disable_cuda_graph_padding,
)
try:
self.cuda_graph_runner.capture(batch_size_list)
except RuntimeError as e:
raise Exception(
f"Capture cuda graph failed: {e}\n"
"Possible solutions:\n"
"1. disable cuda graph by --disable-cuda-graph\n"
"2. set --mem-fraction-static to a smaller value\n"
"3. disable torch compile by not using --enable-torch-compile\n"
"Open an issue on GitHub https://github.com/sgl-project/sglang/issues/new/choose \n"
)
self.cuda_graph_runner = CudaGraphRunner(self)
@torch.inference_mode()
def forward_decode(self, batch: ScheduleBatch):