Support double sparsity (#1459)

Shuo Yang
2024-10-14 02:00:41 -07:00
committed by GitHub
parent 0c1e87964b
commit 061e546313
8 changed files with 1269 additions and 1 deletions
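
For context: double sparsity is a post-training sparse attention technique. Offline calibration identifies a few "heavy" channels per head; at decode time, attention scores are first approximated using only those channels, and exact attention is then computed over just the top-scoring "heavy" tokens. A minimal sketch of the idea in plain PyTorch (shapes and names here are illustrative assumptions, not this PR's kernels):

    import torch

    def sparse_decode_attention(q, k_cache, v_cache, heavy_channels, heavy_token_num):
        # q: [head_num, head_dim]                  current decode query
        # k_cache, v_cache: [seq_len, head_num, head_dim]
        # heavy_channels: LongTensor [head_num, heavy_channel_num],
        #   offline-calibrated channel indices, most important first
        seq_len, head_num, head_dim = k_cache.shape

        # 1) Approximate scores using only the heavy channels of q and k.
        ch = heavy_channels.unsqueeze(0).expand(seq_len, -1, -1)       # [seq, head, hc]
        k_label = torch.gather(k_cache, 2, ch)                         # [seq, head, hc]
        q_label = torch.gather(q, 1, heavy_channels)                   # [head, hc]
        approx_scores = torch.einsum("hc,shc->hs", q_label, k_label)   # [head, seq]

        # 2) Run exact attention over only the top-k "heavy" tokens per head.
        topk = approx_scores.topk(min(heavy_token_num, seq_len), dim=-1).indices
        out = torch.empty_like(q)
        for h in range(head_num):
            ks, vs = k_cache[topk[h], h], v_cache[topk[h], h]          # [k, head_dim]
            p = torch.softmax(q[h] @ ks.T / head_dim**0.5, dim=-1)
            out[h] = p @ vs
        return out

    # Example: calibrate heavy channels by mean absolute key magnitude.
    # q = torch.randn(32, 128); k_cache = v_cache = torch.randn(1024, 32, 128)
    # heavy_channels = torch.argsort(k_cache.abs().mean(0), dim=-1, descending=True)[:, :16]
    # out = sparse_decode_attention(q, k_cache, v_cache, heavy_channels, heavy_token_num=256)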


@@ -18,6 +18,7 @@ limitations under the License.
 import gc
 import importlib
 import importlib.resources
+import json
 import logging
 import pkgutil
 from functools import lru_cache
@@ -39,6 +40,7 @@ from vllm.model_executor.models import ModelRegistry
 from sglang.srt.configs.model_config import AttentionArch, ModelConfig
 from sglang.srt.constrained import disable_cache
+from sglang.srt.layers.attention.double_sparsity_backend import DoubleSparseAttnBackend
 from sglang.srt.layers.attention.flashinfer_backend import FlashInferAttnBackend
 from sglang.srt.layers.attention.triton_backend import TritonAttnBackend
 from sglang.srt.layers.logits_processor import LogitsProcessorOutput
@@ -46,6 +48,7 @@ from sglang.srt.layers.sampler import Sampler
 from sglang.srt.lora.lora_manager import LoRAManager
 from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.mem_cache.memory_pool import (
+    DoubleSparseTokenToKVPool,
     MHATokenToKVPool,
     MLATokenToKVPool,
     ReqToTokenPool,
@@ -99,6 +102,20 @@ class ModelRunner:
             logger.info("MLA optimization is turned on. Use triton backend.")
             self.server_args.attention_backend = "triton"
 
+        if self.server_args.enable_double_sparsity:
+            logger.info(
+                "Double sparsity optimization is turned on. Use triton backend without CUDA graph."
+            )
+            self.server_args.attention_backend = "triton"
+            self.server_args.disable_cuda_graph = True
+            if self.server_args.ds_heavy_channel_type is None:
+                raise ValueError(
+                    "Please specify the heavy channel type for double sparsity optimization."
+                )
+            self.init_double_sparsity_channel_config(
+                self.server_args.ds_heavy_channel_type
+            )
+
         if self.is_multimodal_model:
             logger.info(
                 "Automatically turn off --chunked-prefill-size and adjust --mem-fraction-static for multimodal models."
@@ -439,6 +456,16 @@ class ModelRunner:
                 layer_num=self.model_config.num_hidden_layers,
                 device=self.device,
             )
+        elif self.server_args.enable_double_sparsity:
+            self.token_to_kv_pool = DoubleSparseTokenToKVPool(
+                self.max_total_num_tokens,
+                dtype=self.kv_cache_dtype,
+                head_num=self.model_config.get_num_kv_heads(self.tp_size),
+                head_dim=self.model_config.head_dim,
+                layer_num=self.model_config.num_hidden_layers,
+                device=self.device,
+                heavy_channel_num=self.server_args.ds_heavy_channel_num,
+            )
         else:
             self.token_to_kv_pool = MHATokenToKVPool(
                 self.max_total_num_tokens,
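
The new pool takes one extra argument, heavy_channel_num. A plausible reading, sketched below, is that it mirrors MHATokenToKVPool and additionally allocates a per-layer "label" buffer holding just the heavy-channel slice of each cached key, so score approximation never touches the full K cache (this sketch is an assumption; the real class lives in sglang.srt.mem_cache.memory_pool):

    import torch

    class DoubleSparseTokenToKVPoolSketch:
        def __init__(self, size, dtype, head_num, head_dim, layer_num, device,
                     heavy_channel_num):
            # Full K/V caches, one buffer per layer: [size, head_num, head_dim].
            self.k_buffer = [torch.empty(size, head_num, head_dim, dtype=dtype, device=device)
                             for _ in range(layer_num)]
            self.v_buffer = [torch.empty(size, head_num, head_dim, dtype=dtype, device=device)
                             for _ in range(layer_num)]
            # Extra "label" cache: only the heavy channels of each cached key,
            # [size, head_num, heavy_channel_num], for cheap score estimates.
            self.label_buffer = [
                torch.empty(size, head_num, heavy_channel_num, dtype=dtype, device=device)
                for _ in range(layer_num)
            ]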
@@ -475,12 +502,33 @@ class ModelRunner:
                     "Cross attention is not supported in the triton attention backend. "
                     "Please use `--attention-backend flashinfer`."
                 )
-            self.attn_backend = TritonAttnBackend(self)
+            if self.server_args.enable_double_sparsity:
+                self.attn_backend = DoubleSparseAttnBackend(self)
+            else:
+                self.attn_backend = TritonAttnBackend(self)
         else:
             raise ValueError(
                 f"Invalid attention backend: {self.server_args.attention_backend}"
             )
 
+    def init_double_sparsity_channel_config(self, selected_channel):
+        selected_channel = "." + selected_channel + "_proj"
+        self.sorted_channels = []
+        # load channel config
+        with open(self.server_args.ds_channel_config_path, "r") as f:
+            channel_config = json.load(f)
+
+        for i in range(self.model_config.num_hidden_layers):
+            key = "model.layers." + str(i) + ".self_attn" + selected_channel
+            self.sorted_channels.append(
+                torch.tensor(channel_config[key])[
+                    :, : self.server_args.ds_heavy_channel_num
+                ]
+                .contiguous()
+                .cuda()
+            )
+
     def init_cuda_graphs(self):
         """Capture cuda graphs."""
         from sglang.srt.model_executor.cuda_graph_runner import CudaGraphRunner
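
init_double_sparsity_channel_config implies a contract for the channel-config file: a JSON object mapping "model.layers.{i}.self_attn.{type}_proj" keys to a per-head list of channel indices sorted by importance, most important first. A self-contained sketch of that contract with a fake two-layer, two-head config (the .cuda() call is dropped so this runs on CPU; the file layout itself is inferred from the loader above):

    import torch

    # Fake config: per-head channel indices for each layer's q_proj,
    # sorted by importance. Real files come from offline calibration.
    channel_config = {
        f"model.layers.{i}.self_attn.q_proj": [[3, 0, 7, 1], [5, 2, 4, 6]]
        for i in range(2)
    }

    ds_heavy_channel_num = 2
    sorted_channels = []
    for i in range(2):
        key = f"model.layers.{i}.self_attn.q_proj"
        # Same slicing as the loader: keep the top ds_heavy_channel_num
        # channels for every head.
        sorted_channels.append(
            torch.tensor(channel_config[key])[:, :ds_heavy_channel_num].contiguous()
        )

    print(sorted_channels[0])  # tensor([[3, 0], [5, 2]])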