diff --git a/docs/start/install.md b/docs/start/install.md index 2d778bf72..88a920363 100644 --- a/docs/start/install.md +++ b/docs/start/install.md @@ -22,7 +22,7 @@ pip install sgl-kernel --force-reinstall --no-deps pip install -e "python[all]" --find-links https://flashinfer.ai/whl/cu124/torch2.5/flashinfer/ ``` -Note: Please check the [FlashInfer installation doc](https://docs.flashinfer.ai/installation.html) to install the proper version according to your PyTorch and CUDA versions. If you meet with issue like **ImportError: cannot import name `_grouped_size_compiled_for_decode_kernels`**, installing FlashInfer with some older version like 0.1.6 instead of the latest version could solve it. +Note: Please check the [FlashInfer installation doc](https://docs.flashinfer.ai/installation.html) to install the proper version according to your PyTorch and CUDA versions. Note: To AMD ROCm system with Instinct/MI GPUs, do following instead: diff --git a/python/sglang/srt/layers/attention/flashinfer_backend.py b/python/sglang/srt/layers/attention/flashinfer_backend.py index 75ed9b3fc..e899bcb7e 100644 --- a/python/sglang/srt/layers/attention/flashinfer_backend.py +++ b/python/sglang/srt/layers/attention/flashinfer_backend.py @@ -1077,21 +1077,6 @@ def should_use_tensor_core( if env_override is not None: return env_override.lower() == "true" - # Try to use _grouped_size_compiled_for_decode_kernels if available - # This is for flashinfer <=0.1.6. Otherwise, there is an accuracy bug - try: - from flashinfer.decode import _grouped_size_compiled_for_decode_kernels - - if not _grouped_size_compiled_for_decode_kernels( - num_attention_heads, - num_kv_heads, - ): - return True - else: - return False - except (ImportError, AttributeError): - pass - # Calculate GQA group size gqa_group_size = num_attention_heads // num_kv_heads