Allow disabling flashinfer sampling kernel (#778)
This commit is contained in:
@@ -65,9 +65,6 @@ asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())
|
||||
app = FastAPI()
|
||||
tokenizer_manager = None
|
||||
|
||||
# Put some args here for easy access
|
||||
global_server_args_dict = {}
|
||||
|
||||
|
||||
@app.get("/health")
|
||||
async def health() -> Response:
|
||||
@@ -150,14 +147,6 @@ def available_models():
|
||||
return ModelList(data=model_cards)
|
||||
|
||||
|
||||
def _set_global_server_args(server_args: ServerArgs):
    """Snapshot selected server arguments into ``global_server_args_dict``.

    Rebinds the module-level ``global_server_args_dict`` to a fresh dict
    holding the ``disable_flashinfer`` and ``attention_reduce_in_fp32``
    values read from ``server_args``.
    """
    global global_server_args_dict

    # Attribute names double as the dictionary keys; order matches the reads.
    exported_keys = ("disable_flashinfer", "attention_reduce_in_fp32")
    global_server_args_dict = {
        key: getattr(server_args, key) for key in exported_keys
    }
|
||||
|
||||
|
||||
def _set_torch_compile_config():
|
||||
# The following configurations are for torch compile optimizations
|
||||
import torch._dynamo.config
|
||||
@@ -213,8 +202,6 @@ def launch_server(
|
||||
if server_args.enable_torch_compile:
|
||||
_set_torch_compile_config()
|
||||
|
||||
_set_global_server_args(server_args)
|
||||
|
||||
# Allocate ports
|
||||
server_args.port, server_args.additional_ports = allocate_init_ports(
|
||||
server_args.port,
|
||||
|
||||
Reference in New Issue
Block a user