diff --git a/python/sglang/srt/layers/attention/flashattention_backend.py b/python/sglang/srt/layers/attention/flashattention_backend.py
index a2425f1a2..040393c17 100644
--- a/python/sglang/srt/layers/attention/flashattention_backend.py
+++ b/python/sglang/srt/layers/attention/flashattention_backend.py
@@ -383,7 +383,7 @@ class FlashAttentionBackend(AttentionBackend):
             metadata.page_table = forward_batch.req_to_token_pool.req_to_token[
                 forward_batch.req_pool_indices, : metadata.max_seq_len_k
             ]
-        elif forward_batch.forward_mode.is_extend_or_draft_extend():
+        elif forward_batch.forward_mode.is_extend_or_draft_extend_or_mixed():
             metadata.cache_seqlens_int32 = seqlens_in_batch.to(torch.int32)
             metadata.max_seq_len_k = forward_batch.seq_lens_cpu.max().item()
             metadata.cu_seqlens_k = torch.nn.functional.pad(
diff --git a/python/sglang/srt/model_executor/forward_batch_info.py b/python/sglang/srt/model_executor/forward_batch_info.py
index a0ead1784..ba861b850 100644
--- a/python/sglang/srt/model_executor/forward_batch_info.py
+++ b/python/sglang/srt/model_executor/forward_batch_info.py
@@ -78,7 +78,7 @@ class ForwardMode(IntEnum):
             self == ForwardMode.EXTEND
             or self == ForwardMode.MIXED
             or self == ForwardMode.DRAFT_EXTEND
-            or self == self.TARGET_VERIFY
+            or self == ForwardMode.TARGET_VERIFY
         )

     def is_decode(self):
@@ -96,6 +96,13 @@
     def is_draft_extend(self):
         return self == ForwardMode.DRAFT_EXTEND

+    def is_extend_or_draft_extend_or_mixed(self):
+        return (
+            self == ForwardMode.EXTEND
+            or self == ForwardMode.DRAFT_EXTEND
+            or self == ForwardMode.MIXED
+        )
+
     def is_cuda_graph(self):
         return (
             self == ForwardMode.DECODE
@@ -103,9 +110,6 @@
             or self == ForwardMode.IDLE
         )

-    def is_extend_or_draft_extend(self):
-        return self == ForwardMode.EXTEND or self == ForwardMode.DRAFT_EXTEND
-
     def is_dummy_first(self):
         return self == ForwardMode.DUMMY_FIRST

diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py
index 995c613ce..41a245a10 100644
--- a/python/sglang/srt/model_executor/model_runner.py
+++ b/python/sglang/srt/model_executor/model_runner.py
@@ -78,9 +78,11 @@ from sglang.srt.utils import (
     get_available_gpu_memory,
     init_custom_process_group,
     is_cuda,
+    is_fa3_default_architecture,
     is_flashinfer_available,
     is_hip,
     is_hopper_with_cuda_12_3,
+    is_no_spec_infer_or_topk_one,
     monkey_patch_p2p_access_check,
     monkey_patch_vllm_gguf_config,
     set_cpu_offload_max_bytes,
@@ -242,18 +244,21 @@ class ModelRunner:
         elif server_args.attention_backend is None:
             # By default, use flashinfer for non-mla attention and triton for mla attention
             if not self.use_mla_backend:
-                server_args.attention_backend = (
-                    "flashinfer" if is_flashinfer_available() else "triton"
-                )
+                if (
+                    is_hopper_with_cuda_12_3()
+                    and is_no_spec_infer_or_topk_one(server_args)
+                    and is_fa3_default_architecture(self.model_config.hf_config)
+                ):
+                    server_args.attention_backend = "fa3"
+                else:
+                    server_args.attention_backend = (
+                        "flashinfer" if is_flashinfer_available() else "triton"
+                    )
             else:
-                if is_hopper_with_cuda_12_3():
-                    if server_args.speculative_eagle_topk is None or (
-                        server_args.speculative_eagle_topk is not None
-                        and server_args.speculative_eagle_topk == 1
-                    ):
-                        server_args.attention_backend = "fa3"
-                    else:
-                        server_args.attention_backend = "triton"
+                if is_hopper_with_cuda_12_3() and is_no_spec_infer_or_topk_one(
+                    server_args
+                ):
+                    server_args.attention_backend = "fa3"
                 else:
                     server_args.attention_backend = "triton"
             logger.info(
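The net effect of the `model_runner.py` hunk above is easier to see outside diff form. Below is a minimal, self-contained sketch of the resulting default-backend decision tree; `pick_attention_backend`, `FA3_DEFAULT_ARCHS`, and the `SimpleNamespace` stand-ins are illustrative names (not real SGLang types), and the hardware/library probes are passed in as plain booleans rather than calling the real `is_hopper_with_cuda_12_3()` / `is_flashinfer_available()`.

```python
# Sketch of the new default attention-backend selection. The two boolean
# parameters stub out the environment probes so the branching can be run
# anywhere; everything else mirrors the helpers introduced in this diff.
from types import SimpleNamespace

FA3_DEFAULT_ARCHS = {
    "Qwen2ForCausalLM",
    "Llama4ForConditionalGeneration",
    "LlamaForCausalLM",
    "MistralForCausalLM",
}


def pick_attention_backend(
    server_args,
    hf_config,
    use_mla_backend: bool,
    on_hopper_cuda_12_3: bool,
    flashinfer_available: bool,
) -> str:
    # Mirrors is_no_spec_infer_or_topk_one(): no speculative decoding at
    # all, or eagle top-k == 1 with a KV-cache page size of 1.
    spec_ok = server_args.speculative_eagle_topk is None or (
        server_args.speculative_eagle_topk == 1 and server_args.page_size == 1
    )
    # Mirrors is_fa3_default_architecture(): keyed off architectures[0].
    archs = getattr(hf_config, "architectures", None) or []
    fa3_arch = bool(archs) and archs[0] in FA3_DEFAULT_ARCHS

    if not use_mla_backend:
        if on_hopper_cuda_12_3 and spec_ok and fa3_arch:
            return "fa3"
        return "flashinfer" if flashinfer_available else "triton"
    # MLA models keep the old rule: no architecture allowlist.
    if on_hopper_cuda_12_3 and spec_ok:
        return "fa3"
    return "triton"


args = SimpleNamespace(speculative_eagle_topk=None, page_size=1)
cfg = SimpleNamespace(architectures=["LlamaForCausalLM"])
print(pick_attention_backend(args, cfg, False, True, True))   # fa3
print(pick_attention_backend(args, cfg, False, False, True))  # flashinfer
```

Pulling the eagle-topk/page-size predicate into `is_no_spec_infer_or_topk_one` also lets the MLA and non-MLA branches share one condition instead of duplicating the nested `if`.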
diff --git a/python/sglang/srt/utils.py b/python/sglang/srt/utils.py
index 0deea5be3..e812a7802 100644
--- a/python/sglang/srt/utils.py
+++ b/python/sglang/srt/utils.py
@@ -569,7 +569,7 @@ def encode_video(video_path, frame_count_limit=None):


 def load_image(
-    image_file: Union[Image.Image, str, bytes]
+    image_file: Union[Image.Image, str, bytes],
 ) -> tuple[Image.Image, tuple[int, int]]:
     image = image_size = None
     if isinstance(image_file, Image.Image):
@@ -1905,3 +1905,28 @@ def get_local_ip_by_remote() -> str:
         return s.getsockname()[0]
     except Exception:
         raise ValueError(f"Can not get local ip")
+
+
+def is_page_size_one(server_args):
+    return server_args.page_size == 1
+
+
+def is_no_spec_infer_or_topk_one(server_args):
+    return server_args.speculative_eagle_topk is None or (
+        server_args.speculative_eagle_topk is not None
+        and server_args.speculative_eagle_topk == 1
+        and is_page_size_one(server_args)
+    )
+
+
+def is_fa3_default_architecture(hf_config):
+    architectures = getattr(hf_config, "architectures", None)
+    if not isinstance(architectures, list) or not architectures:
+        return False
+    default_archs = {
+        "Qwen2ForCausalLM",
+        "Llama4ForConditionalGeneration",
+        "LlamaForCausalLM",
+        "MistralForCausalLM",
+    }
+    return architectures[0] in default_archs
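For completeness, a few sanity checks for the three helpers added to `utils.py`. This assumes the patch is applied and `sglang` is importable; the `SimpleNamespace` mocks stand in for the real `ServerArgs` and Hugging Face config objects purely for illustration.

```python
from types import SimpleNamespace

from sglang.srt.utils import (
    is_fa3_default_architecture,
    is_no_spec_infer_or_topk_one,
    is_page_size_one,
)

# No speculative decoding configured: qualifies regardless of page size.
args = SimpleNamespace(speculative_eagle_topk=None, page_size=16)
assert is_no_spec_infer_or_topk_one(args)

# topk == 1 qualifies only when the KV-cache page size is also 1.
args = SimpleNamespace(speculative_eagle_topk=1, page_size=16)
assert not is_no_spec_infer_or_topk_one(args)
args = SimpleNamespace(speculative_eagle_topk=1, page_size=1)
assert is_page_size_one(args) and is_no_spec_infer_or_topk_one(args)

# The FA3 allowlist keys off architectures[0] and rejects empty configs.
assert is_fa3_default_architecture(SimpleNamespace(architectures=["LlamaForCausalLM"]))
assert not is_fa3_default_architecture(SimpleNamespace(architectures=[]))
```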