Fix flashinfer >= 0.0.3 compat (#282)
This commit is contained in:
@@ -1,5 +1,6 @@
|
|||||||
import importlib
|
import importlib
|
||||||
import logging
|
import logging
|
||||||
|
import inspect
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from functools import lru_cache
|
from functools import lru_cache
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
@@ -124,14 +125,21 @@ class InputMetadata:
|
|||||||
self.prefill_wrapper = BatchPrefillWithPagedKVCacheWrapper(
|
self.prefill_wrapper = BatchPrefillWithPagedKVCacheWrapper(
|
||||||
workspace_buffer, "NHD"
|
workspace_buffer, "NHD"
|
||||||
)
|
)
|
||||||
self.prefill_wrapper.begin_forward(
|
args = [
|
||||||
self.qo_indptr,
|
self.qo_indptr,
|
||||||
self.kv_indptr,
|
self.kv_indptr,
|
||||||
self.kv_indices,
|
self.kv_indices,
|
||||||
self.kv_last_page_len,
|
self.kv_last_page_len,
|
||||||
self.model_runner.model_config.num_attention_heads // tp_size,
|
self.model_runner.model_config.num_attention_heads // tp_size,
|
||||||
self.model_runner.model_config.num_key_value_heads // tp_size,
|
self.model_runner.model_config.num_key_value_heads // tp_size,
|
||||||
)
|
]
|
||||||
|
|
||||||
|
# flashinfer >= 0.0.3: begin_forward takes an extra head_dim argument (detected below by its 7-parameter signature)
|
||||||
|
# FIXME: Drop this when flashinfer updates to 0.0.4
|
||||||
|
if len(inspect.signature(self.prefill_wrapper.begin_forward).parameters) == 7:
|
||||||
|
args.append(self.model_runner.model_config.head_dim)
|
||||||
|
|
||||||
|
self.prefill_wrapper.begin_forward(*args)
|
||||||
else:
|
else:
|
||||||
self.decode_wrapper = BatchDecodeWithPagedKVCacheWrapper(
|
self.decode_wrapper = BatchDecodeWithPagedKVCacheWrapper(
|
||||||
workspace_buffer, "NHD"
|
workspace_buffer, "NHD"
|
||||||
|
|||||||
Reference in New Issue
Block a user