Allow overwrite flashinfer use_tensorcore (#2169)
This commit is contained in:
@@ -1,4 +1,5 @@
|
|||||||
# SGLang Documentation
|
# SGLang Documentation
|
||||||
|
This is the documentation repository for SGLang. It is auto-generated from https://github.com/sgl-project/sglang/tree/main/docs.
|
||||||
|
|
||||||
## Build the documentation website
|
## Build the documentation website
|
||||||
|
|
||||||
|
|||||||
@@ -407,7 +407,7 @@ async def async_request_profile(api_url: str) -> RequestFuncOutput:
|
|||||||
|
|
||||||
|
|
||||||
def get_model(pretrained_model_name_or_path: str) -> str:
|
def get_model(pretrained_model_name_or_path: str) -> str:
|
||||||
if os.getenv("SGLANG_USE_MODELSCOPE", "False").lower() == "true":
|
if os.getenv("SGLANG_USE_MODELSCOPE", "false").lower() == "true":
|
||||||
import huggingface_hub.constants
|
import huggingface_hub.constants
|
||||||
from modelscope import snapshot_download
|
from modelscope import snapshot_download
|
||||||
|
|
||||||
|
|||||||
@@ -7,6 +7,7 @@ FlashInfer is faster and Triton is easier to customize.
|
|||||||
Each backend supports two operators: extend (i.e. prefill with cached prefix) and decode.
|
Each backend supports two operators: extend (i.e. prefill with cached prefix) and decode.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
import os
|
||||||
from enum import Enum, auto
|
from enum import Enum, auto
|
||||||
from typing import TYPE_CHECKING, List
|
from typing import TYPE_CHECKING, List
|
||||||
|
|
||||||
@@ -45,13 +46,19 @@ class FlashInferAttnBackend(AttentionBackend):
|
|||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
# Parse constants
|
# Parse constants
|
||||||
if not _grouped_size_compiled_for_decode_kernels(
|
if "SGLANG_FLASHINFER_USE_TENSOR_CORE" in os.environ:
|
||||||
model_runner.model_config.num_attention_heads // model_runner.tp_size,
|
self.decode_use_tensor_cores = (
|
||||||
model_runner.model_config.get_num_kv_heads(model_runner.tp_size),
|
os.environ["SGLANG_FLASHINFER_USE_TENSOR_CORE"].lower() == "true"
|
||||||
):
|
)
|
||||||
self.decode_use_tensor_cores = True
|
|
||||||
else:
|
else:
|
||||||
self.decode_use_tensor_cores = False
|
if not _grouped_size_compiled_for_decode_kernels(
|
||||||
|
model_runner.model_config.num_attention_heads // model_runner.tp_size,
|
||||||
|
model_runner.model_config.get_num_kv_heads(model_runner.tp_size),
|
||||||
|
):
|
||||||
|
self.decode_use_tensor_cores = True
|
||||||
|
else:
|
||||||
|
self.decode_use_tensor_cores = False
|
||||||
|
|
||||||
self.max_context_len = model_runner.model_config.context_len
|
self.max_context_len = model_runner.model_config.context_len
|
||||||
|
|
||||||
assert not (
|
assert not (
|
||||||
|
|||||||
@@ -81,7 +81,7 @@ from sglang.utils import get_exception_traceback
|
|||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
# Test retract decode
|
# Test retract decode
|
||||||
test_retract = os.getenv("SGLANG_TEST_RETRACT", "false") == "true"
|
test_retract = os.getenv("SGLANG_TEST_RETRACT", "false").lower() == "true"
|
||||||
|
|
||||||
|
|
||||||
class Scheduler:
|
class Scheduler:
|
||||||
|
|||||||
@@ -930,7 +930,7 @@ def get_nvgpu_memory_capacity():
|
|||||||
|
|
||||||
def crash_on_warnings():
|
def crash_on_warnings():
|
||||||
# Crash on warning if we are running CI tests
|
# Crash on warning if we are running CI tests
|
||||||
return os.getenv("SGLANG_IS_IN_CI", "false") == "true"
|
return os.getenv("SGLANG_IS_IN_CI", "false").lower() == "true"
|
||||||
|
|
||||||
|
|
||||||
def get_device_name(device_id: int = 0) -> str:
|
def get_device_name(device_id: int = 0) -> str:
|
||||||
|
|||||||
@@ -44,7 +44,7 @@ DEFAULT_MODEL_NAME_FOR_NIGHTLY_EVAL_QUANT_TP1 = "hugging-quants/Meta-Llama-3.1-8
|
|||||||
|
|
||||||
def is_in_ci():
|
def is_in_ci():
|
||||||
"""Return whether it is in CI runner."""
|
"""Return whether it is in CI runner."""
|
||||||
return os.getenv("SGLANG_IS_IN_CI", "false") == "true"
|
return os.getenv("SGLANG_IS_IN_CI", "false").lower() == "true"
|
||||||
|
|
||||||
|
|
||||||
if is_in_ci():
|
if is_in_ci():
|
||||||
|
|||||||
Reference in New Issue
Block a user