Fix flashinfer >= 0.0.3 compat (#282)
This commit is contained in:
@@ -1,5 +1,6 @@
|
||||
import importlib
|
||||
import logging
|
||||
import inspect
|
||||
from dataclasses import dataclass
|
||||
from functools import lru_cache
|
||||
from pathlib import Path
|
||||
@@ -124,14 +125,21 @@ class InputMetadata:
|
||||
self.prefill_wrapper = BatchPrefillWithPagedKVCacheWrapper(
|
||||
workspace_buffer, "NHD"
|
||||
)
|
||||
self.prefill_wrapper.begin_forward(
|
||||
args = [
|
||||
self.qo_indptr,
|
||||
self.kv_indptr,
|
||||
self.kv_indices,
|
||||
self.kv_last_page_len,
|
||||
self.model_runner.model_config.num_attention_heads // tp_size,
|
||||
self.model_runner.model_config.num_key_value_heads // tp_size,
|
||||
)
|
||||
]
|
||||
|
||||
# flashinfer >= 0.0.3: begin_forward takes a 7th argument (head_dim)
|
||||
# FIXME: Drop this when flashinfer updates to 0.0.4
|
||||
if len(inspect.signature(self.prefill_wrapper.begin_forward).parameters) == 7:
|
||||
args.append(self.model_runner.model_config.head_dim)
|
||||
|
||||
self.prefill_wrapper.begin_forward(*args)
|
||||
else:
|
||||
self.decode_wrapper = BatchDecodeWithPagedKVCacheWrapper(
|
||||
workspace_buffer, "NHD"
|
||||
|
||||
Reference in New Issue
Block a user