Allow overwrite flashinfer use_tensorcore (#2169)

This commit is contained in:
Lianmin Zheng
2024-11-24 20:58:17 -08:00
committed by GitHub
parent dd44173dad
commit 8e1adb8441
6 changed files with 18 additions and 10 deletions

View File

@@ -7,6 +7,7 @@ FlashInfer is faster and Triton is easier to customize.
Each backend supports two operators: extend (i.e. prefill with cached prefix) and decode.
"""
+import os
from enum import Enum, auto
from typing import TYPE_CHECKING, List
@@ -45,13 +46,19 @@ class FlashInferAttnBackend(AttentionBackend):
super().__init__()
# Parse constants
if not _grouped_size_compiled_for_decode_kernels(
model_runner.model_config.num_attention_heads // model_runner.tp_size,
model_runner.model_config.get_num_kv_heads(model_runner.tp_size),
):
self.decode_use_tensor_cores = True
if "SGLANG_FLASHINFER_USE_TENSOR_CORE" in os.environ:
self.decode_use_tensor_cores = (
os.environ["SGLANG_FLASHINFER_USE_TENSOR_CORE"].lower() == "true"
)
else:
self.decode_use_tensor_cores = False
if not _grouped_size_compiled_for_decode_kernels(
model_runner.model_config.num_attention_heads // model_runner.tp_size,
model_runner.model_config.get_num_kv_heads(model_runner.tp_size),
):
self.decode_use_tensor_cores = True
else:
self.decode_use_tensor_cores = False
self.max_context_len = model_runner.model_config.context_len
assert not (

View File

@@ -81,7 +81,7 @@ from sglang.utils import get_exception_traceback
logger = logging.getLogger(__name__)
# Test retract decode
-test_retract = os.getenv("SGLANG_TEST_RETRACT", "false") == "true"
+test_retract = os.getenv("SGLANG_TEST_RETRACT", "false").lower() == "true"
class Scheduler:

View File

@@ -930,7 +930,7 @@ def get_nvgpu_memory_capacity():
def crash_on_warnings():
# Crash on warning if we are running CI tests
-return os.getenv("SGLANG_IS_IN_CI", "false") == "true"
+return os.getenv("SGLANG_IS_IN_CI", "false").lower() == "true"
def get_device_name(device_id: int = 0) -> str: