Support double sparsity (#1459)
This commit is contained in:
@@ -18,6 +18,7 @@ limitations under the License.
|
||||
import gc
|
||||
import importlib
|
||||
import importlib.resources
|
||||
import json
|
||||
import logging
|
||||
import pkgutil
|
||||
from functools import lru_cache
|
||||
@@ -39,6 +40,7 @@ from vllm.model_executor.models import ModelRegistry
|
||||
|
||||
from sglang.srt.configs.model_config import AttentionArch, ModelConfig
|
||||
from sglang.srt.constrained import disable_cache
|
||||
from sglang.srt.layers.attention.double_sparsity_backend import DoubleSparseAttnBackend
|
||||
from sglang.srt.layers.attention.flashinfer_backend import FlashInferAttnBackend
|
||||
from sglang.srt.layers.attention.triton_backend import TritonAttnBackend
|
||||
from sglang.srt.layers.logits_processor import LogitsProcessorOutput
|
||||
@@ -46,6 +48,7 @@ from sglang.srt.layers.sampler import Sampler
|
||||
from sglang.srt.lora.lora_manager import LoRAManager
|
||||
from sglang.srt.managers.schedule_batch import global_server_args_dict
|
||||
from sglang.srt.mem_cache.memory_pool import (
|
||||
DoubleSparseTokenToKVPool,
|
||||
MHATokenToKVPool,
|
||||
MLATokenToKVPool,
|
||||
ReqToTokenPool,
|
||||
@@ -99,6 +102,20 @@ class ModelRunner:
|
||||
logger.info("MLA optimization is turned on. Use triton backend.")
|
||||
self.server_args.attention_backend = "triton"
|
||||
|
||||
if self.server_args.enable_double_sparsity:
|
||||
logger.info(
|
||||
"Double sparsity optimization is turned on. Use triton backend without CUDA graph."
|
||||
)
|
||||
self.server_args.attention_backend = "triton"
|
||||
self.server_args.disable_cuda_graph = True
|
||||
if self.server_args.ds_heavy_channel_type is None:
|
||||
raise ValueError(
|
||||
"Please specify the heavy channel type for double sparsity optimization."
|
||||
)
|
||||
self.init_double_sparsity_channel_config(
|
||||
self.server_args.ds_heavy_channel_type
|
||||
)
|
||||
|
||||
if self.is_multimodal_model:
|
||||
logger.info(
|
||||
"Automatically turn off --chunked-prefill-size and adjust --mem-fraction-static for multimodal models."
|
||||
@@ -439,6 +456,16 @@ class ModelRunner:
|
||||
layer_num=self.model_config.num_hidden_layers,
|
||||
device=self.device,
|
||||
)
|
||||
elif self.server_args.enable_double_sparsity:
|
||||
self.token_to_kv_pool = DoubleSparseTokenToKVPool(
|
||||
self.max_total_num_tokens,
|
||||
dtype=self.kv_cache_dtype,
|
||||
head_num=self.model_config.get_num_kv_heads(self.tp_size),
|
||||
head_dim=self.model_config.head_dim,
|
||||
layer_num=self.model_config.num_hidden_layers,
|
||||
device=self.device,
|
||||
heavy_channel_num=self.server_args.ds_heavy_channel_num,
|
||||
)
|
||||
else:
|
||||
self.token_to_kv_pool = MHATokenToKVPool(
|
||||
self.max_total_num_tokens,
|
||||
@@ -475,12 +502,33 @@ class ModelRunner:
|
||||
"Cross attention is not supported in the triton attention backend. "
|
||||
"Please use `--attention-backend flashinfer`."
|
||||
)
|
||||
self.attn_backend = TritonAttnBackend(self)
|
||||
if self.server_args.enable_double_sparsity:
|
||||
self.attn_backend = DoubleSparseAttnBackend(self)
|
||||
else:
|
||||
self.attn_backend = TritonAttnBackend(self)
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Invalid attention backend: {self.server_args.attention_backend}"
|
||||
)
|
||||
|
||||
def init_double_sparsity_channel_config(self, selected_channel):
    """Load the per-layer heavy-channel tables used by double sparsity.

    Reads the JSON file at ``self.server_args.ds_channel_config_path`` and,
    for every hidden layer, looks up the entry keyed
    ``model.layers.<i>.self_attn.<selected_channel>_proj``. Each entry is
    turned into a tensor, truncated to its first
    ``self.server_args.ds_heavy_channel_num`` columns, made contiguous and
    moved to the runner's device. The resulting list (one tensor per layer,
    in layer order) is stored on ``self.sorted_channels``.

    Args:
        selected_channel: Projection name without the ``_proj`` suffix; it is
            spliced into the per-layer lookup key.

    Raises:
        ValueError: If no channel config path was configured.
        KeyError: If the config file lacks an entry for one or more layers.
    """
    config_path = self.server_args.ds_channel_config_path
    if config_path is None:
        # Fail with an actionable message instead of the opaque TypeError
        # that open(None) would produce.
        raise ValueError(
            "Please specify the channel config path for double sparsity optimization."
        )

    # load channel config
    with open(config_path, "r", encoding="utf-8") as f:
        channel_config = json.load(f)

    proj_suffix = f".{selected_channel}_proj"
    layer_keys = [
        f"model.layers.{layer_idx}.self_attn{proj_suffix}"
        for layer_idx in range(self.model_config.num_hidden_layers)
    ]

    # Check every layer up front so a mismatched config is reported in full
    # rather than as a bare KeyError on the first missing layer.
    missing_keys = [key for key in layer_keys if key not in channel_config]
    if missing_keys:
        raise KeyError(
            f"Double sparsity channel config {config_path!r} is missing "
            f"entries for: {missing_keys}"
        )

    heavy_channel_num = self.server_args.ds_heavy_channel_num
    # Follow the runner's configured device, as the other allocations in this
    # class do; fall back to CUDA when the runner has not set one.
    device = getattr(self, "device", "cuda")

    # Slice on the host first so only the kept columns are transferred.
    self.sorted_channels = [
        torch.tensor(channel_config[key])[:, :heavy_channel_num]
        .contiguous()
        .to(device)
        for key in layer_keys
    ]
|
||||
|
||||
def init_cuda_graphs(self):
|
||||
"""Capture cuda graphs."""
|
||||
from sglang.srt.model_executor.cuda_graph_runner import CudaGraphRunner
|
||||
|
||||
Reference in New Issue
Block a user