From ad1dd74673a2e918a39d869865c1830fb634d150 Mon Sep 17 00:00:00 2001
From: Qubitium <417764+Qubitium@users.noreply.github.com>
Date: Tue, 12 Mar 2024 21:45:58 +0800
Subject: [PATCH] Fix flashinfer >= 0.0.3 compat (#282)

---
 python/sglang/srt/managers/router/model_runner.py | 12 ++++++++++--
 1 file changed, 10 insertions(+), 2 deletions(-)

diff --git a/python/sglang/srt/managers/router/model_runner.py b/python/sglang/srt/managers/router/model_runner.py
index 1f07286ed..6ea1ac9d7 100644
--- a/python/sglang/srt/managers/router/model_runner.py
+++ b/python/sglang/srt/managers/router/model_runner.py
@@ -1,5 +1,6 @@
 import importlib
 import logging
+import inspect
 from dataclasses import dataclass
 from functools import lru_cache
 from pathlib import Path
@@ -124,14 +125,21 @@ class InputMetadata:
             self.prefill_wrapper = BatchPrefillWithPagedKVCacheWrapper(
                 workspace_buffer, "NHD"
             )
-            self.prefill_wrapper.begin_forward(
+            args = [
                 self.qo_indptr,
                 self.kv_indptr,
                 self.kv_indices,
                 self.kv_last_page_len,
                 self.model_runner.model_config.num_attention_heads // tp_size,
                 self.model_runner.model_config.num_key_value_heads // tp_size,
-            )
+            ]
+
+            # flashinfer >= 0.0.3
+            # FIXME: Drop this when flashinfer updates to 0.0.4
+            if len(inspect.signature(self.prefill_wrapper.begin_forward).parameters) == 7:
+                args.append(self.model_runner.model_config.head_dim)
+
+            self.prefill_wrapper.begin_forward(*args)
         else:
             self.decode_wrapper = BatchDecodeWithPagedKVCacheWrapper(
                 workspace_buffer, "NHD"