Follow-up fixes for flashinfer 0.0.5 (#556)
This commit is contained in:
@@ -67,7 +67,7 @@ class InputMetadata:
|
||||
flashinfer_prefill_wrapper: "BatchPrefillWithPagedKVCacheWrapper" = None
|
||||
flashinfer_decode_wrapper: "BatchDecodeWithPagedKVCacheWrapper" = None
|
||||
|
||||
def init_flashinfer_args(self, num_attention_heads, num_key_value_heads, head_dim):
|
||||
def init_flashinfer_args(self, num_qo_heads, num_kv_heads, head_dim):
|
||||
self.kv_indptr = torch.zeros(
|
||||
(self.batch_size + 1,), dtype=torch.int32, device="cuda"
|
||||
)
|
||||
@@ -102,8 +102,8 @@ class InputMetadata:
|
||||
self.kv_indptr,
|
||||
self.kv_indices,
|
||||
self.kv_last_page_len,
|
||||
num_attention_heads,
|
||||
num_key_value_heads,
|
||||
num_qo_heads,
|
||||
num_kv_heads,
|
||||
head_dim,
|
||||
1
|
||||
)
|
||||
@@ -113,8 +113,8 @@ class InputMetadata:
|
||||
self.kv_indptr,
|
||||
self.kv_indices,
|
||||
self.kv_last_page_len,
|
||||
num_attention_heads,
|
||||
num_key_value_heads,
|
||||
num_qo_heads,
|
||||
num_kv_heads,
|
||||
head_dim,
|
||||
1,
|
||||
pos_encoding_mode="NONE",
|
||||
@@ -203,7 +203,7 @@ class InputMetadata:
|
||||
if global_server_args_dict.get("enable_flashinfer", False):
|
||||
ret.init_flashinfer_args(
|
||||
model_runner.model_config.num_attention_heads // tp_size,
|
||||
model_runner.model_config.num_key_value_heads // tp_size,
|
||||
model_runner.model_config.get_num_kv_heads(tp_size),
|
||||
model_runner.model_config.head_dim
|
||||
)
|
||||
|
||||
@@ -350,6 +350,15 @@ class ModelRunner:
|
||||
BatchPrefillWithPagedKVCacheWrapper,
|
||||
BatchDecodeWithPagedKVCacheWrapper,
|
||||
)
|
||||
from flashinfer.decode import _grouped_size_compiled_for_decode_kernels
|
||||
|
||||
if not _grouped_size_compiled_for_decode_kernels(
|
||||
self.model_config.num_attention_heads // self.tp_size,
|
||||
self.model_config.get_num_kv_heads(self.tp_size)):
|
||||
use_tensor_cores = True
|
||||
else:
|
||||
use_tensor_cores = False
|
||||
|
||||
workspace_buffer = torch.empty(
|
||||
32 * 1024 * 1024, dtype=torch.int8, device="cuda"
|
||||
)
|
||||
@@ -357,8 +366,10 @@ class ModelRunner:
|
||||
workspace_buffer, "NHD"
|
||||
)
|
||||
self.flashinfer_decode_wrapper = BatchDecodeWithPagedKVCacheWrapper(
|
||||
workspace_buffer, "NHD"
|
||||
workspace_buffer, "NHD", use_tensor_cores=use_tensor_cores
|
||||
)
|
||||
else:
|
||||
self.flashinfer_prefill_wrapper = self.flashinfer_decode_wrapper = None
|
||||
|
||||
@torch.inference_mode()
|
||||
def forward_prefill(self, batch: Batch):
|
||||
|
||||
Reference in New Issue
Block a user