diff --git a/python/sglang/srt/managers/router/model_runner.py b/python/sglang/srt/managers/router/model_runner.py index 1f07286ed..6ea1ac9d7 100644 --- a/python/sglang/srt/managers/router/model_runner.py +++ b/python/sglang/srt/managers/router/model_runner.py @@ -1,5 +1,6 @@ import importlib import logging +import inspect from dataclasses import dataclass from functools import lru_cache from pathlib import Path @@ -124,14 +125,21 @@ class InputMetadata: self.prefill_wrapper = BatchPrefillWithPagedKVCacheWrapper( workspace_buffer, "NHD" ) - self.prefill_wrapper.begin_forward( + args = [ self.qo_indptr, self.kv_indptr, self.kv_indices, self.kv_last_page_len, self.model_runner.model_config.num_attention_heads // tp_size, self.model_runner.model_config.num_key_value_heads // tp_size, - ) + ] + + # flashinfer >= 0.0.3 + # FIXME: Drop this when flashinfer updates to 0.0.4 + if len(inspect.signature(self.prefill_wrapper.begin_forward).parameters) == 7: + args.append(self.model_runner.model_config.head_dim) + + self.prefill_wrapper.begin_forward(*args) else: self.decode_wrapper = BatchDecodeWithPagedKVCacheWrapper( workspace_buffer, "NHD"