Refactor attention backend (#1381)
@@ -25,12 +25,6 @@ from typing import Optional, Tuple, Type
 
 import torch
 import torch.nn as nn
-from flashinfer import (
-    BatchDecodeWithPagedKVCacheWrapper,
-    BatchPrefillWithPagedKVCacheWrapper,
-    BatchPrefillWithRaggedKVCacheWrapper,
-)
-from flashinfer.decode import _grouped_size_compiled_for_decode_kernels
 from vllm.config import DeviceConfig, LoadConfig
 from vllm.config import ModelConfig as VllmModelConfig
 from vllm.distributed import (
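The removed imports are the FlashInfer prefill/decode wrappers that ModelRunner used to drive directly; after this refactor they are owned by FlashInferAttnBackend, imported in the next hunk. The new sglang.srt.layers.attention_backend module is not part of this excerpt, so the following is only a rough sketch of the interface the two backends presumably share; the AttentionBackend base class and the method names are assumptions for illustration, only FlashInferAttnBackend and TritonAttnBackend are named in the diff.

# Hypothetical sketch, not the actual module from this commit.
from abc import ABC, abstractmethod

import torch


class AttentionBackend(ABC):
    """Owns kernel wrappers and per-batch metadata for one attention implementation."""

    @abstractmethod
    def init_forward_metadata(self, batch) -> None:
        """Precompute indices and wrapper state for the upcoming forward pass."""

    @abstractmethod
    def forward_extend(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, layer, batch) -> torch.Tensor:
        """Attention over prefill/extend tokens."""

    @abstractmethod
    def forward_decode(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, layer, batch) -> torch.Tensor:
        """Attention for single-token decode steps."""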
@@ -43,8 +37,8 @@ from vllm.distributed.parallel_state import in_the_same_node_as
 from vllm.model_executor.model_loader import get_model
 from vllm.model_executor.models import ModelRegistry
 
 from sglang.global_config import global_config
 from sglang.srt.configs.model_config import AttentionArch, ModelConfig
+from sglang.srt.layers.attention_backend import FlashInferAttnBackend, TritonAttnBackend
 from sglang.srt.layers.logits_processor import LogitsProcessorOutput
 from sglang.srt.layers.sampler import SampleOutput
 from sglang.srt.managers.schedule_batch import ScheduleBatch, global_server_args_dict
@@ -69,6 +63,8 @@ logger = logging.getLogger(__name__)
 
+
+
 class ModelRunner:
     """ModelRunner runs the forward passes of the models."""
 
     def __init__(
         self,
         model_config: ModelConfig,
@@ -100,6 +96,7 @@ class ModelRunner:
             }
         )
 
+        # Model-specific adjustment
        if self.is_multimodal_model:
             logger.info(
                 "Automatically turn off --chunked-prefill-size and adjust --mem-fraction-static for multimodal models."
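The multimodal adjustment trades a slice of static memory for the extra activations of the vision encoder. A quick illustration of the arithmetic; the 0.88 starting value is only an example, not a documented sglang default:

# Illustrative only: shrinking --mem-fraction-static by 5% for multimodal models.
mem_fraction_static = 0.88            # example starting value
mem_fraction_static *= 0.95           # the adjustment applied above
print(round(mem_fraction_static, 3))  # 0.836, i.e. ~5% more headroom outside the KV cache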
@@ -107,6 +104,7 @@ class ModelRunner:
             server_args.chunked_prefill_size = None
+            server_args.mem_fraction_static *= 0.95
 
         # Init components
         min_per_gpu_memory = self.init_torch_distributed()
         self.load_model()
         self.init_memory_pool(
@@ -115,7 +113,7 @@ class ModelRunner:
             server_args.max_total_tokens,
         )
         self.init_cublas()
-        self.init_flashinfer()
+        self.init_attention_backend()
         self.init_cuda_graphs()
 
     def init_torch_distributed(self):
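With init_flashinfer() replaced by a single init_attention_backend() call, the backend is chosen purely from server_args.attention_backend, i.e. the --attention-backend flag that the error messages later in this diff reference. A hedged usage sketch; ServerArgs fields other than the two shown are elided, and the model path is a placeholder:

# Sketch, assuming ServerArgs exposes these fields as in the surrounding code.
from sglang.srt.server_args import ServerArgs

server_args = ServerArgs(
    model_path="my-org/my-model",  # placeholder
    attention_backend="triton",    # or "flashinfer"; anything else raises ValueError
)
# ModelRunner.__init__ then runs: distributed init -> load_model -> memory pool
# -> cublas warmup -> init_attention_backend() -> init_cuda_graphs().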
@@ -397,9 +395,6 @@ class ModelRunner:
                 qk_rope_head_dim=self.model_config.qk_rope_head_dim,
                 layer_num=self.model_config.num_hidden_layers,
             )
-            logger.info("using MLA Triton implementation, flashinfer is disabled")
-            # FIXME: temporarily only Triton MLA is supported
-            self.server_args.attention_backend = "triton"
         else:
             self.token_to_kv_pool = MHATokenToKVPool(
                 self.max_total_num_tokens,
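This hunk appears to drop the lines that pinned the backend to Triton for MLA models (multi-head latent attention, as in DeepSeek-V2-style architectures), since the FlashInfer wrappers do not cover MLA; presumably the equivalent override now lives with the new backend selection. A hedged reconstruction of the guard those lines implemented; pick_backend is a hypothetical helper, not sglang API:

from sglang.srt.configs.model_config import AttentionArch


def pick_backend(model_config, requested: str) -> str:
    # Hypothetical helper restating the dropped lines: MLA models can only
    # run on the Triton implementation for now ("FIXME: temporarily only
    # Triton MLA is supported").
    if model_config.attention_arch == AttentionArch.MLA:
        return "triton"
    return requested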
@@ -422,106 +417,42 @@ class ModelRunner:
         c = a @ b
         return c
 
-    def init_flashinfer(self):
-        """Init flashinfer attention kernel wrappers."""
-        if self.server_args.attention_backend != "flashinfer":
-            assert (
-                self.sliding_window_size is None
-            ), "turn on flashinfer to support window attention"
-            self.flashinfer_prefill_wrapper_ragged = None
-            self.flashinfer_prefill_wrapper_paged = None
-            self.flashinfer_decode_wrapper = None
-            return
-
-        if not _grouped_size_compiled_for_decode_kernels(
-            self.model_config.num_attention_heads // self.tp_size,
-            self.model_config.get_num_kv_heads(self.tp_size),
-        ):
-            use_tensor_cores = True
-        else:
-            use_tensor_cores = False
-
-        if self.sliding_window_size is None:
-            self.flashinfer_workspace_buffer = torch.empty(
-                global_config.flashinfer_workspace_size,
-                dtype=torch.uint8,
-                device="cuda",
-            )
-            self.flashinfer_prefill_wrapper_ragged = (
-                BatchPrefillWithRaggedKVCacheWrapper(
-                    self.flashinfer_workspace_buffer, "NHD"
-                )
-            )
-            self.flashinfer_prefill_wrapper_paged = BatchPrefillWithPagedKVCacheWrapper(
-                self.flashinfer_workspace_buffer, "NHD"
-            )
-            self.flashinfer_decode_wrapper = BatchDecodeWithPagedKVCacheWrapper(
-                self.flashinfer_workspace_buffer,
-                "NHD",
-                use_tensor_cores=use_tensor_cores,
-            )
-        else:
-            self.flashinfer_workspace_buffer = torch.empty(
-                global_config.flashinfer_workspace_size,
-                dtype=torch.uint8,
-                device="cuda",
-            )
-            self.flashinfer_prefill_wrapper_ragged = None
-            self.flashinfer_prefill_wrapper_paged = []
-            self.flashinfer_decode_wrapper = []
-            for i in range(2):
-                self.flashinfer_prefill_wrapper_paged.append(
-                    BatchPrefillWithPagedKVCacheWrapper(
-                        self.flashinfer_workspace_buffer, "NHD"
-                    )
-                )
-                self.flashinfer_decode_wrapper.append(
-                    BatchDecodeWithPagedKVCacheWrapper(
-                        self.flashinfer_workspace_buffer,
-                        "NHD",
-                        use_tensor_cores=use_tensor_cores,
-                    )
-                )
+    def init_attention_backend(self):
+        """Init attention kernel backend."""
+        if self.server_args.attention_backend == "flashinfer":
+            self.attn_backend = FlashInferAttnBackend(self)
+        elif self.server_args.attention_backend == "triton":
+            assert self.sliding_window_size is None, (
+                "Window attention is not supported in the triton attention backend. "
+                "Please use `--attention-backend flashinfer`."
+            )
+            self.attn_backend = TritonAttnBackend(self)
+        else:
+            raise ValueError(
+                f"Invalid attention backend: {self.server_args.attention_backend}"
+            )
 
     def init_cuda_graphs(self):
         """Capture cuda graphs."""
         from sglang.srt.model_executor.cuda_graph_runner import CudaGraphRunner
 
+        self.cuda_graph_runner = None
+
         if not self.is_generation:
             # TODO: Currently, cuda graph only captures decode steps, which only exists for generation models
             return
 
-        if (
-            self.server_args.disable_cuda_graph
-            or self.server_args.attention_backend != "flashinfer"
-        ):
-            self.cuda_graph_runner = None
+        if self.server_args.disable_cuda_graph:
+            return
+
+        if self.server_args.attention_backend != "flashinfer":
+            logger.warning(
+                f"Cuda graph is not supported for attention backend: {self.server_args.attention_backend}"
+            )
             return
 
-        logger.info("Capture cuda graph begin. This can take up to several minutes.")
-
-        if self.server_args.disable_cuda_graph_padding:
-            batch_size_list = list(range(1, 32)) + [64, 128]
-        else:
-            batch_size_list = [1, 2, 4] + [i * 8 for i in range(1, 21)]
-
-        self.cuda_graph_runner = CudaGraphRunner(
-            self,
-            max_batch_size_to_capture=max(batch_size_list),
-            use_torch_compile=self.server_args.enable_torch_compile,
-            disable_padding=self.server_args.disable_cuda_graph_padding,
-        )
-        try:
-            self.cuda_graph_runner.capture(batch_size_list)
-        except RuntimeError as e:
-            raise Exception(
-                f"Capture cuda graph failed: {e}\n"
-                "Possible solutions:\n"
-                "1. disable cuda graph by --disable-cuda-graph\n"
-                "2. set --mem-fraction-static to a smaller value\n"
-                "3. disable torch compile by not using --enable-torch-compile\n"
-                "Open an issue on GitHub https://github.com/sgl-project/sglang/issues/new/choose \n"
-            )
+        self.cuda_graph_runner = CudaGraphRunner(self)
 
     @torch.inference_mode()
     def forward_decode(self, batch: ScheduleBatch):
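Two things collapse in this final hunk. First, the old init_flashinfer body carried the tensor-core heuristic (use_tensor_cores is enabled when FlashInfer has no pre-compiled grouped-size decode kernel for the query/KV head ratio) and the sliding-window case that kept two wrapper sets; all of that moves behind FlashInferAttnBackend. Second, batch-size selection, capture, and error handling for CUDA graphs leave init_cuda_graphs for CudaGraphRunner itself, which now takes only the runner. For readers unfamiliar with what "capture" means here, a minimal, self-contained sketch of the capture/replay mechanism that CudaGraphRunner wraps, using plain torch.cuda.CUDAGraph rather than sglang API; it requires a CUDA device:

# Minimal CUDA graph capture/replay with plain PyTorch (illustration only).
import torch

assert torch.cuda.is_available(), "CUDA graphs require a GPU"

model = torch.nn.Linear(1024, 1024).cuda()
static_in = torch.zeros(8, 1024, device="cuda")

# Warm up on a side stream before capture, as the PyTorch CUDA graphs docs advise.
s = torch.cuda.Stream()
s.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(s):
    for _ in range(3):
        model(static_in)
torch.cuda.current_stream().wait_stream(s)

# Capture one forward pass into a graph; tensors used here become static buffers.
g = torch.cuda.CUDAGraph()
with torch.cuda.graph(g):
    static_out = model(static_in)

# Replay: refill the static input in place, then relaunch the recorded kernels.
static_in.copy_(torch.randn(8, 1024, device="cuda"))
g.replay()
torch.cuda.synchronize()
print(static_out.sum().item())

The static buffers are why capture happens per batch size: the removed batch_size_list enumerated exactly those sizes, and after this commit CudaGraphRunner presumably chooses and captures them internally, which is why the call shrinks to CudaGraphRunner(self).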