Allow disabling flashinfer sampling kernel (#778)

This commit is contained in:
Lianmin Zheng
2024-07-27 20:18:56 -07:00
committed by GitHub
parent 30db99b3d9
commit 752e643007
6 changed files with 41 additions and 26 deletions

View File

@@ -65,9 +65,6 @@ asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())
# Application object; route handlers attach to it through its decorators.
app = FastAPI()
# Module-level placeholder, initially unset.
tokenizer_manager = None
# Holds a few server arguments for easy access; starts out empty.
global_server_args_dict = {}
@app.get("/health")
async def health() -> Response:
@@ -150,14 +147,6 @@ def available_models():
return ModelList(data=model_cards)
def _set_global_server_args(server_args: ServerArgs):
    """Rebind the module-level ``global_server_args_dict`` from *server_args*.

    A fresh dict is built from the ``disable_flashinfer`` and
    ``attention_reduce_in_fp32`` attributes of *server_args* and assigned to
    the global name. The previously bound dict object is not mutated.
    """
    global global_server_args_dict
    # Only these options are mirrored; order matches the attribute reads.
    mirrored_options = ("disable_flashinfer", "attention_reduce_in_fp32")
    global_server_args_dict = {
        option: getattr(server_args, option) for option in mirrored_options
    }
def _set_torch_compile_config():
# The following configurations are for torch compile optimizations
import torch._dynamo.config
@@ -213,8 +202,6 @@ def launch_server(
if server_args.enable_torch_compile:
_set_torch_compile_config()
_set_global_server_args(server_args)
# Allocate ports
server_args.port, server_args.additional_ports = allocate_init_ports(
server_args.port,