Update flashinfer to 0.0.5 (#554)
@@ -12,7 +12,8 @@ from sglang.srt.managers.controller.model_runner import ForwardMode, InputMetadata
 class RadixAttention(nn.Module):
     def __init__(
-        self, num_heads, head_dim, scaling, num_kv_heads, layer_id, logit_cap=-1
+        self, num_heads: int, head_dim: int, scaling: float, num_kv_heads: int,
+        layer_id: int, logit_cap: int = -1
     ):
         super().__init__()
 
         self.tp_q_head_num = num_heads
@@ -20,7 +21,6 @@ class RadixAttention(nn.Module):
         self.tp_v_head_num = num_kv_heads
         self.head_dim = head_dim
         self.layer_id = layer_id
-        self.logit_cap = logit_cap
 
         assert np.allclose(scaling, 1.0 / (head_dim**0.5))
 
@@ -30,10 +30,17 @@ class RadixAttention(nn.Module):
             self.prefill_forward = self.prefill_forward_flashinfer
             self.extend_forward = self.prefill_forward_flashinfer
             self.decode_forward = self.decode_forward_flashinfer
+            # flashinfer only accepts a boolean logit_cap argument
+            if logit_cap > 0:
+                assert logit_cap == 30
+                self.logit_cap = True
+            else:
+                self.logit_cap = False
         else:
             self.prefill_forward = self.prefill_forward_triton
             self.extend_forward = self.extend_forward_triton
             self.decode_forward = self.decode_forward_triton
+            self.logit_cap = logit_cap
 
     def prefill_forward_triton(self, q, k, v, input_metadata: InputMetadata):
         o = torch.empty_like(q)
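The flashinfer branch above collapses the numeric logit_cap into a boolean flag and asserts that the only supported value is 30. For intuition, a tanh-style soft cap with a fixed bound of 30 (as used by models like Grok-1) computes something like the sketch below; this is only an illustration of what such a cap does, not flashinfer's kernel code.

import torch

def soft_cap_logits(attn_logits: torch.Tensor, cap: float = 30.0) -> torch.Tensor:
    # Squash attention logits smoothly into (-cap, cap).
    # Illustrative only; the exact capping applied inside the flashinfer
    # kernels when logits_cap=True may differ in detail.
    return cap * torch.tanh(attn_logits / cap)

# Example: very large logits are limited to just under the cap.
print(soft_cap_logits(torch.tensor([1.0, 50.0, 500.0])))  # roughly [1.0, 27.9, 30.0]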
@@ -100,9 +107,10 @@ class RadixAttention(nn.Module):
     def prefill_forward_flashinfer(self, q, k, v, input_metadata: InputMetadata):
         self.store_kv_cache(k, v, input_metadata)
 
-        o = input_metadata.prefill_wrapper.forward(
+        o = input_metadata.flashinfer_prefill_wrapper.forward(
             q.contiguous().view(-1, self.tp_q_head_num, self.head_dim),
             input_metadata.token_to_kv_pool.kv_data[self.layer_id],
+            logits_cap=self.logit_cap,
         )
 
         return o.view(-1, self.tp_q_head_num * self.head_dim)
@@ -110,9 +118,10 @@ class RadixAttention(nn.Module):
     def decode_forward_flashinfer(self, q, k, v, input_metadata: InputMetadata):
         self.store_kv_cache(k, v, input_metadata)
 
-        o = input_metadata.decode_wrapper.forward(
+        o = input_metadata.flashinfer_decode_wrapper.forward(
             q.contiguous().view(-1, self.tp_q_head_num, self.head_dim),
             input_metadata.token_to_kv_pool.kv_data[self.layer_id],
+            logits_cap=self.logit_cap,
         )
 
         return o.view(-1, self.tp_q_head_num * self.head_dim)
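Both wrapper calls reshape the packed hidden states into per-head form for flashinfer and flatten the output again for the caller. A minimal shape check of that round trip, with made-up sizes, is:

import torch

num_tokens, tp_q_head_num, head_dim = 7, 8, 128   # made-up sizes for illustration
q = torch.randn(num_tokens, tp_q_head_num * head_dim)

q_heads = q.contiguous().view(-1, tp_q_head_num, head_dim)  # shape handed to the wrapper
o = q_heads                                                  # stand-in for the attention output
o_flat = o.view(-1, tp_q_head_num * head_dim)                # shape returned to the model

assert q_heads.shape == (num_tokens, tp_q_head_num, head_dim)
assert o_flat.shape == q.shape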
@@ -6,7 +6,7 @@ import logging
 import pkgutil
 from dataclasses import dataclass
 from functools import lru_cache
-from typing import List, Optional, Type
+from typing import List, Optional, Type, Any
 
 import numpy as np
 import torch
@@ -34,7 +34,6 @@ global_server_args_dict = {}
 
 @dataclass
 class InputMetadata:
-    model_runner: "ModelRunner"
     forward_mode: ForwardMode
     batch_size: int
     total_num_tokens: int
@@ -65,15 +64,10 @@ class InputMetadata:
     kv_indptr: torch.Tensor = None
     kv_indices: torch.Tensor = None
     kv_last_page_len: torch.Tensor = None
-    prefill_wrapper = None
-    decode_wrapper = None
-
-    def init_flashinfer_args(self, tp_size):
-        from flashinfer import (
-            BatchDecodeWithPagedKVCacheWrapper,
-            BatchPrefillWithPagedKVCacheWrapper,
-        )
-
+    flashinfer_prefill_wrapper: "BatchPrefillWithPagedKVCacheWrapper" = None
+    flashinfer_decode_wrapper: "BatchDecodeWithPagedKVCacheWrapper" = None
+
+    def init_flashinfer_args(self, num_attention_heads, num_key_value_heads, head_dim):
         self.kv_indptr = torch.zeros(
             (self.batch_size + 1,), dtype=torch.int32, device="cuda"
         )
@@ -93,9 +87,6 @@ class InputMetadata:
             dim=0,
         ).contiguous()
 
-        workspace_buffer = torch.empty(
-            32 * 1024 * 1024, dtype=torch.int8, device="cuda"
-        )
         if (
             self.forward_mode == ForwardMode.PREFILL
             or self.forward_mode == ForwardMode.EXTEND
@@ -104,34 +95,30 @@ class InputMetadata:
                 (self.batch_size + 1,), dtype=torch.int32, device="cuda"
             )
             self.qo_indptr[1:] = torch.cumsum(self.extend_seq_lens, dim=0)
-            self.prefill_wrapper = BatchPrefillWithPagedKVCacheWrapper(
-                workspace_buffer, "NHD"
-            )
-            args = [
+
+            self.flashinfer_prefill_wrapper.end_forward()
+            self.flashinfer_prefill_wrapper.begin_forward(
                 self.qo_indptr,
                 self.kv_indptr,
                 self.kv_indices,
                 self.kv_last_page_len,
-                self.model_runner.model_config.num_attention_heads // tp_size,
-                self.model_runner.model_config.num_key_value_heads // tp_size,
-                self.model_runner.model_config.head_dim,
-            ]
-
-            self.prefill_wrapper.begin_forward(*args)
-        else:
-            self.decode_wrapper = BatchDecodeWithPagedKVCacheWrapper(
-                workspace_buffer, "NHD"
+                num_attention_heads,
+                num_key_value_heads,
+                head_dim,
+                1
             )
-            self.decode_wrapper.begin_forward(
+        else:
+            self.flashinfer_decode_wrapper.end_forward()
+            self.flashinfer_decode_wrapper.begin_forward(
                 self.kv_indptr,
                 self.kv_indices,
                 self.kv_last_page_len,
-                self.model_runner.model_config.num_attention_heads // tp_size,
-                self.model_runner.model_config.num_key_value_heads // tp_size,
-                self.model_runner.model_config.head_dim,
+                num_attention_heads,
+                num_key_value_heads,
+                head_dim,
                 1,
-                "NONE",
-                "float16",
+                pos_encoding_mode="NONE",
+                data_type="float16",
             )
 
     def init_extend_args(self):
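begin_forward describes the paged KV cache layout to flashinfer before any attention calls run. With page size 1 (the trailing 1 argument), the indptr and last-page-length tensors follow directly from the per-request KV lengths; a small CPU-only sketch with invented lengths:

import torch

# Invented per-request KV lengths; CPU tensors for illustration only.
seq_lens = torch.tensor([3, 5, 2], dtype=torch.int32)
batch_size = seq_lens.numel()

kv_indptr = torch.zeros((batch_size + 1,), dtype=torch.int32)
kv_indptr[1:] = torch.cumsum(seq_lens, dim=0)        # tensor([0, 3, 8, 10])

# With one token per page, the last page of every request holds exactly one token.
kv_last_page_len = torch.ones((batch_size,), dtype=torch.int32)

# kv_indices (not built here) would concatenate, per request, the physical slot id
# of every cached token in the token-to-KV pool.
print(kv_indptr, kv_last_page_len)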
@@ -155,6 +142,8 @@ class InputMetadata:
         out_cache_cont_end=None,
         top_logprobs_nums=None,
         return_logprob=False,
+        flashinfer_prefill_wrapper=None,
+        flashinfer_decode_wrapper=None,
     ):
         batch_size = len(req_pool_indices)
         start_loc = torch.zeros((batch_size,), dtype=torch.int32, device="cuda")
@@ -187,7 +176,6 @@ class InputMetadata:
             other_kv_index = None
 
         ret = cls(
-            model_runner=model_runner,
             forward_mode=forward_mode,
             batch_size=batch_size,
             total_num_tokens=total_num_tokens,
@@ -205,13 +193,19 @@ class InputMetadata:
             other_kv_index=other_kv_index,
             return_logprob=return_logprob,
             top_logprobs_nums=top_logprobs_nums,
+            flashinfer_prefill_wrapper=flashinfer_prefill_wrapper,
+            flashinfer_decode_wrapper=flashinfer_decode_wrapper,
         )
 
         if forward_mode == ForwardMode.EXTEND:
             ret.init_extend_args()
 
         if global_server_args_dict.get("enable_flashinfer", False):
-            ret.init_flashinfer_args(tp_size)
+            ret.init_flashinfer_args(
+                model_runner.model_config.num_attention_heads // tp_size,
+                model_runner.model_config.num_key_value_heads // tp_size,
+                model_runner.model_config.head_dim
+            )
 
         return ret
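The arguments passed above are per-tensor-parallel-rank head counts rather than the model totals. A worked example with invented numbers (the real values come from model_runner.model_config):

num_attention_heads = 32   # total query heads in the model (invented)
num_key_value_heads = 8    # total KV heads, grouped-query attention (invented)
head_dim = 128
tp_size = 4                # tensor-parallel world size

heads_per_rank = num_attention_heads // tp_size      # 8 query heads on each GPU
kv_heads_per_rank = num_key_value_heads // tp_size   # 2 KV heads on each GPU
print(heads_per_rank, kv_heads_per_rank, head_dim)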
@@ -234,12 +228,7 @@ class ModelRunner:
         self.tp_size = tp_size
         self.nccl_port = nccl_port
         self.server_args = server_args
-
-        global global_server_args_dict
-        global_server_args_dict = {
-            "enable_flashinfer": server_args.enable_flashinfer,
-            "attention_reduce_in_fp32": server_args.attention_reduce_in_fp32,
-        }
-        self.is_multimodal_model = is_multimodal_model(self.model_config)
 
         # Init torch distributed
         logger.info(f"[gpu_id={self.gpu_id}] Set cuda device.")
@@ -269,9 +258,17 @@ class ModelRunner:
                 "The memory capacity is unbalanced. Some GPUs may be occupied by other processes."
             )
 
+        # Set some global args
+        global global_server_args_dict
+        global_server_args_dict = {
+            "enable_flashinfer": server_args.enable_flashinfer,
+            "attention_reduce_in_fp32": server_args.attention_reduce_in_fp32,
+        }
+
         # Load the model and create memory pool
         self.load_model()
         self.init_memory_pool(total_gpu_memory)
+        self.is_multimodal_model = is_multimodal_model(self.model_config)
+        self.init_flash_infer()
 
     def load_model(self):
         logger.info(
@@ -347,6 +344,22 @@ class ModelRunner:
             f"avail mem={get_available_gpu_memory(self.gpu_id):.2f} GB"
         )
 
+    def init_flash_infer(self):
+        if global_server_args_dict.get("enable_flashinfer", False):
+            from flashinfer import (
+                BatchPrefillWithPagedKVCacheWrapper,
+                BatchDecodeWithPagedKVCacheWrapper,
+            )
+            workspace_buffer = torch.empty(
+                32 * 1024 * 1024, dtype=torch.int8, device="cuda"
+            )
+            self.flashinfer_prefill_wrapper = BatchPrefillWithPagedKVCacheWrapper(
+                workspace_buffer, "NHD"
+            )
+            self.flashinfer_decode_wrapper = BatchDecodeWithPagedKVCacheWrapper(
+                workspace_buffer, "NHD"
+            )
+
     @torch.inference_mode()
     def forward_prefill(self, batch: Batch):
         input_metadata = InputMetadata.create(
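Creating the wrappers once on the ModelRunner, instead of per InputMetadata, lets the 32 MB workspace buffer and the wrapper objects be reused across batches; each new batch only re-plans via end_forward/begin_forward. A hedged sketch of that per-batch pattern, mirroring the arguments shown in this diff (the function name and standalone form are inventions for illustration, not sglang code):

def run_decode_batch(decode_wrapper, q, kv_data, kv_indptr, kv_indices,
                     kv_last_page_len, num_qo_heads, num_kv_heads, head_dim):
    # decode_wrapper is assumed to be a BatchDecodeWithPagedKVCacheWrapper created
    # once at startup, as in init_flash_infer() above.
    decode_wrapper.end_forward()           # drop the plan from the previous batch
    decode_wrapper.begin_forward(
        kv_indptr, kv_indices, kv_last_page_len,
        num_qo_heads, num_kv_heads, head_dim, 1,
        pos_encoding_mode="NONE",
        data_type="float16",
    )
    # One forward per attention layer would follow; a single call is shown here.
    return decode_wrapper.forward(q, kv_data, logits_cap=False)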
@@ -360,6 +373,8 @@ class ModelRunner:
             out_cache_loc=batch.out_cache_loc,
             top_logprobs_nums=batch.top_logprobs_nums,
             return_logprob=batch.return_logprob,
+            flashinfer_prefill_wrapper=self.flashinfer_prefill_wrapper,
+            flashinfer_decode_wrapper=self.flashinfer_decode_wrapper,
         )
         return self.model.forward(
             batch.input_ids, input_metadata.positions, input_metadata
@@ -378,6 +393,8 @@ class ModelRunner:
             out_cache_loc=batch.out_cache_loc,
             top_logprobs_nums=batch.top_logprobs_nums,
             return_logprob=batch.return_logprob,
+            flashinfer_prefill_wrapper=self.flashinfer_prefill_wrapper,
+            flashinfer_decode_wrapper=self.flashinfer_decode_wrapper,
         )
         return self.model.forward(
             batch.input_ids, input_metadata.positions, input_metadata
@@ -398,6 +415,8 @@ class ModelRunner:
             out_cache_cont_end=batch.out_cache_cont_end,
             top_logprobs_nums=batch.top_logprobs_nums,
             return_logprob=batch.return_logprob,
+            flashinfer_prefill_wrapper=self.flashinfer_prefill_wrapper,
+            flashinfer_decode_wrapper=self.flashinfer_decode_wrapper,
         )
         return self.model.forward(
             batch.input_ids, input_metadata.positions, input_metadata
@@ -416,6 +435,8 @@ class ModelRunner:
             out_cache_loc=batch.out_cache_loc,
             top_logprobs_nums=batch.top_logprobs_nums,
             return_logprob=batch.return_logprob,
+            flashinfer_prefill_wrapper=self.flashinfer_prefill_wrapper,
+            flashinfer_decode_wrapper=self.flashinfer_decode_wrapper,
         )
         return self.model.forward(
             batch.input_ids,
@@ -150,7 +150,7 @@ def launch_server(server_args: ServerArgs, pipe_finish_writer, model_overide_arg
     if server_args.disable_disk_cache:
         disable_cache()
     if server_args.enable_flashinfer:
-        assert_pkg_version("flashinfer", "0.0.4")
+        assert_pkg_version("flashinfer", "0.0.5")
     if server_args.chat_template:
         # TODO: replace this with huggingface transformers template
         load_chat_template_for_openai_api(server_args.chat_template)
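The server now refuses to start with an older flashinfer when the flag is enabled. As a rough illustration of what a version gate like assert_pkg_version does (this is a stand-in, not sglang's actual helper):

from importlib.metadata import version
from packaging.version import parse  # assumes the packaging library is available

def assert_pkg_version_sketch(pkg: str, min_version: str) -> None:
    # Illustrative stand-in for sglang's assert_pkg_version; the real helper
    # may differ in wording and behavior.
    installed = version(pkg)
    if parse(installed) < parse(min_version):
        raise RuntimeError(
            f"{pkg}=={installed} is installed, but >= {min_version} is required; "
            f"try: pip install -U {pkg}"
        )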