Clean up server_args, triton cache manager (#8332)

This commit is contained in:
Lianmin Zheng
2025-07-25 14:14:51 -07:00
committed by GitHub
parent f8260f2539
commit ed2e313eb6
12 changed files with 128 additions and 204 deletions

View File

@@ -80,7 +80,7 @@ class ServerArgs:
schedule_policy: str = "fcfs"
schedule_conservativeness: float = 1.0
cpu_offload_gb: int = 0
page_size: int = 1
page_size: Optional[int] = None
hybrid_kvcache_ratio: Optional[float] = None
swa_full_tokens_ratio: float = 0.8
disable_hybrid_swa_memory: bool = False
@@ -266,31 +266,20 @@ class ServerArgs:
def __post_init__(self):
# Expert parallelism
# We put it here first due to some internal ckpt conversation issues.
if self.enable_ep_moe:
self.ep_size = self.tp_size
logger.warning(
f"EP MoE is enabled. The expert parallel size is adjusted to be the same as the tensor parallel size[{self.tp_size}]."
)
if self.enable_flashinfer_moe:
assert (
self.quantization == "modelopt_fp4"
), "modelopt_fp4 quantization is required for Flashinfer MOE"
os.environ["TRTLLM_ENABLE_PDL"] = "1"
self.disable_shared_experts_fusion = True
logger.warning(
f"Flashinfer MoE is enabled. Shared expert fusion is disabled."
)
# Set missing default values
if self.tokenizer_path is None:
self.tokenizer_path = self.model_path
if self.device is None:
self.device = get_device()
if self.served_model_name is None:
self.served_model_name = self.model_path
if self.device is None:
self.device = get_device()
if self.random_seed is None:
self.random_seed = random.randint(0, 1 << 30)
@@ -359,7 +348,6 @@ class ServerArgs:
self.chunked_prefill_size = 16384
else:
self.chunked_prefill_size = 4096
assert self.chunked_prefill_size % self.page_size == 0
# Set cuda graph max batch size
if self.cuda_graph_max_bs is None:
@@ -410,6 +398,14 @@ class ServerArgs:
)
self.page_size = 128
# Set page size
if self.page_size is None:
self.page_size = 1
# AMD-specific Triton attention KV splits default number
if is_hip():
self.triton_attention_num_kv_splits = 16
# Choose grammar backend
if self.grammar_backend is None:
self.grammar_backend = "xgrammar"
@@ -431,6 +427,17 @@ class ServerArgs:
self.enable_dp_attention
), "Please enable dp attention when setting enable_dp_lm_head. "
# MoE kernel
if self.enable_flashinfer_moe:
assert (
self.quantization == "modelopt_fp4"
), "modelopt_fp4 quantization is required for Flashinfer MOE"
os.environ["TRTLLM_ENABLE_PDL"] = "1"
self.disable_shared_experts_fusion = True
logger.warning(
f"Flashinfer MoE is enabled. Shared expert fusion is disabled."
)
# DeepEP MoE
if self.enable_deepep_moe:
if self.deepep_mode == "normal":
@@ -502,14 +509,6 @@ class ServerArgs:
logger.warning(
"DeepSeek MTP does not require setting speculative_draft_model_path."
)
elif "Llama4" in model_arch:
# TODO: remove this after Llama4 supports in other backends
if self.attention_backend != "fa3":
self.attention_backend = "fa3"
logger.warning(
"Llama4 requires using fa3 attention backend. "
"Attention backend is automatically set to fa3."
)
# Auto choose parameters
if self.speculative_num_steps is None:
@@ -542,12 +541,11 @@ class ServerArgs:
) and check_gguf_file(self.model_path):
self.quantization = self.load_format = "gguf"
# Model loading
if is_remote_url(self.model_path):
self.load_format = "remote"
# AMD-specific Triton attention KV splits default number
if is_hip():
self.triton_attention_num_kv_splits = 16
if self.custom_weight_loader is None:
self.custom_weight_loader = []
# PD disaggregation
if self.disaggregation_mode == "decode":
@@ -572,6 +570,7 @@ class ServerArgs:
self.disable_cuda_graph = True
logger.warning("Cuda graph is disabled for prefill server")
# Propagate env vars
os.environ["SGLANG_ENABLE_TORCH_COMPILE"] = (
"1" if self.enable_torch_compile else "0"
)
@@ -580,9 +579,6 @@ class ServerArgs:
"1" if self.disable_outlines_disk_cache else "0"
)
if self.custom_weight_loader is None:
self.custom_weight_loader = []
@staticmethod
def add_cli_args(parser: argparse.ArgumentParser):
# Model and tokenizer
@@ -1227,6 +1223,13 @@ class ServerArgs:
default=ServerArgs.grammar_backend,
help="Choose the backend for grammar-guided decoding.",
)
parser.add_argument(
"--mm-attention-backend",
type=str,
choices=["sdpa", "fa3", "triton_attn"],
default=ServerArgs.mm_attention_backend,
help="Set multimodal attention backend.",
)
# Speculative decoding
parser.add_argument(
@@ -1276,13 +1279,6 @@ class ServerArgs:
help="The path of the draft model's small vocab table.",
default=ServerArgs.speculative_token_map,
)
parser.add_argument(
"--mm-attention-backend",
type=str,
choices=["sdpa", "fa3", "triton_attn"],
default=ServerArgs.mm_attention_backend,
help="Set multimodal attention backend.",
)
# Expert parallelism
parser.add_argument(
@@ -1530,11 +1526,6 @@ class ServerArgs:
action="store_true",
help="Disable the overlap scheduler, which overlaps the CPU scheduler with GPU model worker.",
)
parser.add_argument(
"--disable-overlap-cg-plan",
action="store_true",
help="Disable the overlap optimization for cudagraph preparation in eagle verify.",
)
parser.add_argument(
"--enable-mixed-chunk",
action="store_true",
@@ -1792,11 +1783,11 @@ class ServerArgs:
return hf_config
def check_server_args(self):
# Check parallel size constraints
assert (
self.tp_size * self.pp_size
) % self.nnodes == 0, "tp_size must be divisible by number of nodes"
# FIXME pp constraints
if self.pp_size > 1:
assert (
self.disable_overlap_schedule
@@ -1807,11 +1798,7 @@ class ServerArgs:
assert not (
self.dp_size > 1 and self.nnodes != 1 and not self.enable_dp_attention
), "multi-node data parallel is not supported unless dp attention!"
assert (
self.max_loras_per_batch > 0
# FIXME
and (self.lora_paths is None or self.disable_radix_cache)
), "compatibility of lora and radix attention is in progress"
assert self.base_gpu_id >= 0, "base_gpu_id must be non-negative"
assert self.gpu_id_step >= 1, "gpu_id_step must be positive"
@@ -1820,9 +1807,32 @@ class ServerArgs:
None,
}, "moe_dense_tp_size only support 1 and None currently"
# Check model architecture
model_arch = self.get_hf_config().architectures[0]
if "Llama4" in model_arch:
assert self.attention_backend == "fa3", "fa3 is required for Llama4 model"
# Check LoRA
self.check_lora_server_args()
# Check speculative decoding
if self.speculative_algorithm is not None:
assert (
not self.enable_mixed_chunk
), "enable_mixed_chunk is required for speculative decoding"
# Check chunked prefill
assert (
self.chunked_prefill_size % self.page_size == 0
), "chunked_prefill_size must be divisible by page_size"
def check_lora_server_args(self):
assert (
self.max_loras_per_batch > 0
# FIXME
and (self.lora_paths is None or self.disable_radix_cache)
), "compatibility of lora and radix attention is in progress"
# Enable LoRA if any LoRA paths are provided for backward compatibility.
if self.lora_paths:
if self.enable_lora is None: