Allow disabling flashinfer sampling kernel (#778)
This commit is contained in:
@@ -7,8 +7,11 @@ from torch import nn
|
|||||||
from sglang.global_config import global_config
|
from sglang.global_config import global_config
|
||||||
from sglang.srt.layers.extend_attention import extend_attention_fwd
|
from sglang.srt.layers.extend_attention import extend_attention_fwd
|
||||||
from sglang.srt.layers.token_attention import token_attention_fwd
|
from sglang.srt.layers.token_attention import token_attention_fwd
|
||||||
from sglang.srt.managers.controller.model_runner import ForwardMode, InputMetadata
|
from sglang.srt.managers.controller.model_runner import (
|
||||||
from sglang.srt.server import global_server_args_dict
|
ForwardMode,
|
||||||
|
InputMetadata,
|
||||||
|
global_server_args_dict,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class RadixAttention(nn.Module):
|
class RadixAttention(nn.Module):
|
||||||
|
|||||||
@@ -5,7 +5,7 @@ import torch
|
|||||||
import triton
|
import triton
|
||||||
import triton.language as tl
|
import triton.language as tl
|
||||||
|
|
||||||
from sglang.srt.server import global_server_args_dict
|
from sglang.srt.managers.controller.infer_batch import global_server_args_dict
|
||||||
|
|
||||||
if global_server_args_dict.get("attention_reduce_in_fp32", False):
|
if global_server_args_dict.get("attention_reduce_in_fp32", False):
|
||||||
REDUCE_TRITON_TYPE = tl.float32
|
REDUCE_TRITON_TYPE = tl.float32
|
||||||
|
|||||||
@@ -17,6 +17,13 @@ from sglang.srt.memory_pool import ReqToTokenPool, TokenToKVPool
|
|||||||
|
|
||||||
INIT_INCREMENTAL_DETOKENIZATION_OFFSET = 5
|
INIT_INCREMENTAL_DETOKENIZATION_OFFSET = 5
|
||||||
|
|
||||||
|
# Put some global args for easy access
|
||||||
|
global_server_args_dict = {
|
||||||
|
"disable_flashinfer": False,
|
||||||
|
"disable_flashinfer_sampling": False,
|
||||||
|
"attention_reduce_in_fp32": False,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
class ForwardMode(IntEnum):
|
class ForwardMode(IntEnum):
|
||||||
# Prefill a new sequence. This is deprecated now. "EXTEND" covers this case.
|
# Prefill a new sequence. This is deprecated now. "EXTEND" covers this case.
|
||||||
@@ -687,7 +694,7 @@ class Batch:
|
|||||||
# TODO(lmzheng): apply penalty
|
# TODO(lmzheng): apply penalty
|
||||||
probs = torch.softmax(logits, dim=-1)
|
probs = torch.softmax(logits, dim=-1)
|
||||||
|
|
||||||
if True:
|
if not global_server_args_dict["disable_flashinfer_sampling"]:
|
||||||
max_top_k_round, batch_size = 32, probs.shape[0]
|
max_top_k_round, batch_size = 32, probs.shape[0]
|
||||||
uniform_samples = torch.rand(
|
uniform_samples = torch.rand(
|
||||||
(max_top_k_round, batch_size), device=probs.device
|
(max_top_k_round, batch_size), device=probs.device
|
||||||
|
|||||||
@@ -25,7 +25,12 @@ from vllm.distributed import (
|
|||||||
from vllm.model_executor.models import ModelRegistry
|
from vllm.model_executor.models import ModelRegistry
|
||||||
|
|
||||||
from sglang.global_config import global_config
|
from sglang.global_config import global_config
|
||||||
from sglang.srt.managers.controller.infer_batch import Batch, ForwardMode, InputMetadata
|
from sglang.srt.managers.controller.infer_batch import (
|
||||||
|
Batch,
|
||||||
|
ForwardMode,
|
||||||
|
InputMetadata,
|
||||||
|
global_server_args_dict,
|
||||||
|
)
|
||||||
from sglang.srt.memory_pool import ReqToTokenPool, TokenToKVPool
|
from sglang.srt.memory_pool import ReqToTokenPool, TokenToKVPool
|
||||||
from sglang.srt.server_args import ServerArgs
|
from sglang.srt.server_args import ServerArgs
|
||||||
from sglang.srt.utils import (
|
from sglang.srt.utils import (
|
||||||
@@ -60,7 +65,13 @@ class ModelRunner:
|
|||||||
self.nccl_port = nccl_port
|
self.nccl_port = nccl_port
|
||||||
self.server_args = server_args
|
self.server_args = server_args
|
||||||
self.is_multimodal_model = is_multimodal_model(self.model_config)
|
self.is_multimodal_model = is_multimodal_model(self.model_config)
|
||||||
monkey_patch_vllm_dummy_weight_loader()
|
global_server_args_dict.update(
|
||||||
|
{
|
||||||
|
"disable_flashinfer": server_args.disable_flashinfer,
|
||||||
|
"disable_flashinfer_sampling": server_args.disable_flashinfer_sampling,
|
||||||
|
"attention_reduce_in_fp32": server_args.attention_reduce_in_fp32,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
# Init torch distributed
|
# Init torch distributed
|
||||||
torch.cuda.set_device(self.gpu_id)
|
torch.cuda.set_device(self.gpu_id)
|
||||||
@@ -108,6 +119,7 @@ class ModelRunner:
|
|||||||
f"avail mem={get_available_gpu_memory(self.gpu_id):.2f} GB"
|
f"avail mem={get_available_gpu_memory(self.gpu_id):.2f} GB"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
monkey_patch_vllm_dummy_weight_loader()
|
||||||
device_config = DeviceConfig()
|
device_config = DeviceConfig()
|
||||||
load_config = LoadConfig(load_format=self.server_args.load_format)
|
load_config = LoadConfig(load_format=self.server_args.load_format)
|
||||||
vllm_model_config = VllmModelConfig(
|
vllm_model_config = VllmModelConfig(
|
||||||
|
|||||||
@@ -65,9 +65,6 @@ asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())
|
|||||||
app = FastAPI()
|
app = FastAPI()
|
||||||
tokenizer_manager = None
|
tokenizer_manager = None
|
||||||
|
|
||||||
# Put some args for easily access
|
|
||||||
global_server_args_dict = {}
|
|
||||||
|
|
||||||
|
|
||||||
@app.get("/health")
|
@app.get("/health")
|
||||||
async def health() -> Response:
|
async def health() -> Response:
|
||||||
@@ -150,14 +147,6 @@ def available_models():
|
|||||||
return ModelList(data=model_cards)
|
return ModelList(data=model_cards)
|
||||||
|
|
||||||
|
|
||||||
def _set_global_server_args(server_args: ServerArgs):
|
|
||||||
global global_server_args_dict
|
|
||||||
global_server_args_dict = {
|
|
||||||
"disable_flashinfer": server_args.disable_flashinfer,
|
|
||||||
"attention_reduce_in_fp32": server_args.attention_reduce_in_fp32,
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
def _set_torch_compile_config():
|
def _set_torch_compile_config():
|
||||||
# The following configurations are for torch compile optimizations
|
# The following configurations are for torch compile optimizations
|
||||||
import torch._dynamo.config
|
import torch._dynamo.config
|
||||||
@@ -213,8 +202,6 @@ def launch_server(
|
|||||||
if server_args.enable_torch_compile:
|
if server_args.enable_torch_compile:
|
||||||
_set_torch_compile_config()
|
_set_torch_compile_config()
|
||||||
|
|
||||||
_set_global_server_args(server_args)
|
|
||||||
|
|
||||||
# Allocate ports
|
# Allocate ports
|
||||||
server_args.port, server_args.additional_ports = allocate_init_ports(
|
server_args.port, server_args.additional_ports = allocate_init_ports(
|
||||||
server_args.port,
|
server_args.port,
|
||||||
|
|||||||
@@ -52,13 +52,14 @@ class ServerArgs:
|
|||||||
|
|
||||||
# Optimization/debug options
|
# Optimization/debug options
|
||||||
disable_flashinfer: bool = False
|
disable_flashinfer: bool = False
|
||||||
|
disable_flashinfer_sampling: bool = False
|
||||||
disable_radix_cache: bool = False
|
disable_radix_cache: bool = False
|
||||||
disable_regex_jump_forward: bool = False
|
disable_regex_jump_forward: bool = False
|
||||||
disable_cuda_graph: bool = False
|
disable_cuda_graph: bool = False
|
||||||
disable_disk_cache: bool = False
|
disable_disk_cache: bool = False
|
||||||
enable_torch_compile: bool = False
|
enable_torch_compile: bool = False
|
||||||
attention_reduce_in_fp32: bool = False
|
|
||||||
enable_p2p_check: bool = False
|
enable_p2p_check: bool = False
|
||||||
|
attention_reduce_in_fp32: bool = False
|
||||||
efficient_weight_load: bool = False
|
efficient_weight_load: bool = False
|
||||||
|
|
||||||
# Distributed args
|
# Distributed args
|
||||||
@@ -303,7 +304,12 @@ class ServerArgs:
|
|||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--disable-flashinfer",
|
"--disable-flashinfer",
|
||||||
action="store_true",
|
action="store_true",
|
||||||
help="Disable flashinfer inference kernels.",
|
help="Disable flashinfer attention kernels.",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--disable-flashinfer-sampling",
|
||||||
|
action="store_true",
|
||||||
|
help="Disable flashinfer sampling kernels.",
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--disable-radix-cache",
|
"--disable-radix-cache",
|
||||||
@@ -330,17 +336,17 @@ class ServerArgs:
|
|||||||
action="store_true",
|
action="store_true",
|
||||||
help="Optimize the model with torch.compile, experimental feature.",
|
help="Optimize the model with torch.compile, experimental feature.",
|
||||||
)
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--enable-p2p-check",
|
||||||
|
action="store_true",
|
||||||
|
help="Enable P2P check for GPU access, otherwise the p2p access is allowed by default.",
|
||||||
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--attention-reduce-in-fp32",
|
"--attention-reduce-in-fp32",
|
||||||
action="store_true",
|
action="store_true",
|
||||||
help="Cast the intermidiate attention results to fp32 to avoid possible crashes related to fp16."
|
help="Cast the intermidiate attention results to fp32 to avoid possible crashes related to fp16."
|
||||||
"This only affects Triton attention kernels",
|
"This only affects Triton attention kernels",
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
|
||||||
"--enable-p2p-check",
|
|
||||||
action="store_true",
|
|
||||||
help="Enable P2P check for GPU access, otherwise the p2p access is allowed by default.",
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--efficient-weight-load",
|
"--efficient-weight-load",
|
||||||
action="store_true",
|
action="store_true",
|
||||||
|
|||||||
Reference in New Issue
Block a user