feat: update model_specific_adjustment (#5344)
Co-authored-by: hebiao064 <hebiaobuaa@gmail.com>
This commit is contained in:
@@ -383,7 +383,7 @@ class FlashAttentionBackend(AttentionBackend):
|
|||||||
forward_batch.req_pool_indices, : metadata.max_seq_len_k
|
forward_batch.req_pool_indices, : metadata.max_seq_len_k
|
||||||
]
|
]
|
||||||
|
|
||||||
elif forward_batch.forward_mode.is_extend_or_draft_extend():
|
elif forward_batch.forward_mode.is_extend_or_draft_extend_or_mixed():
|
||||||
metadata.cache_seqlens_int32 = seqlens_in_batch.to(torch.int32)
|
metadata.cache_seqlens_int32 = seqlens_in_batch.to(torch.int32)
|
||||||
metadata.max_seq_len_k = forward_batch.seq_lens_cpu.max().item()
|
metadata.max_seq_len_k = forward_batch.seq_lens_cpu.max().item()
|
||||||
metadata.cu_seqlens_k = torch.nn.functional.pad(
|
metadata.cu_seqlens_k = torch.nn.functional.pad(
|
||||||
|
|||||||
@@ -78,7 +78,7 @@ class ForwardMode(IntEnum):
|
|||||||
self == ForwardMode.EXTEND
|
self == ForwardMode.EXTEND
|
||||||
or self == ForwardMode.MIXED
|
or self == ForwardMode.MIXED
|
||||||
or self == ForwardMode.DRAFT_EXTEND
|
or self == ForwardMode.DRAFT_EXTEND
|
||||||
or self == self.TARGET_VERIFY
|
or self == ForwardMode.TARGET_VERIFY
|
||||||
)
|
)
|
||||||
|
|
||||||
def is_decode(self):
|
def is_decode(self):
|
||||||
@@ -96,6 +96,13 @@ class ForwardMode(IntEnum):
|
|||||||
def is_draft_extend(self):
|
def is_draft_extend(self):
|
||||||
return self == ForwardMode.DRAFT_EXTEND
|
return self == ForwardMode.DRAFT_EXTEND
|
||||||
|
|
||||||
|
def is_extend_or_draft_extend_or_mixed(self):
    """Return True when this mode is EXTEND, DRAFT_EXTEND, or MIXED."""
    matching_modes = (
        ForwardMode.EXTEND,
        ForwardMode.DRAFT_EXTEND,
        ForwardMode.MIXED,
    )
    return self in matching_modes
|
||||||
|
|
||||||
def is_cuda_graph(self):
|
def is_cuda_graph(self):
|
||||||
return (
|
return (
|
||||||
self == ForwardMode.DECODE
|
self == ForwardMode.DECODE
|
||||||
@@ -103,9 +110,6 @@ class ForwardMode(IntEnum):
|
|||||||
or self == ForwardMode.IDLE
|
or self == ForwardMode.IDLE
|
||||||
)
|
)
|
||||||
|
|
||||||
def is_extend_or_draft_extend(self):
|
|
||||||
return self == ForwardMode.EXTEND or self == ForwardMode.DRAFT_EXTEND
|
|
||||||
|
|
||||||
def is_dummy_first(self):
|
def is_dummy_first(self):
|
||||||
return self == ForwardMode.DUMMY_FIRST
|
return self == ForwardMode.DUMMY_FIRST
|
||||||
|
|
||||||
|
|||||||
@@ -78,9 +78,11 @@ from sglang.srt.utils import (
|
|||||||
get_available_gpu_memory,
|
get_available_gpu_memory,
|
||||||
init_custom_process_group,
|
init_custom_process_group,
|
||||||
is_cuda,
|
is_cuda,
|
||||||
|
is_fa3_default_architecture,
|
||||||
is_flashinfer_available,
|
is_flashinfer_available,
|
||||||
is_hip,
|
is_hip,
|
||||||
is_hopper_with_cuda_12_3,
|
is_hopper_with_cuda_12_3,
|
||||||
|
is_no_spec_infer_or_topk_one,
|
||||||
monkey_patch_p2p_access_check,
|
monkey_patch_p2p_access_check,
|
||||||
monkey_patch_vllm_gguf_config,
|
monkey_patch_vllm_gguf_config,
|
||||||
set_cpu_offload_max_bytes,
|
set_cpu_offload_max_bytes,
|
||||||
@@ -242,18 +244,21 @@ class ModelRunner:
|
|||||||
elif server_args.attention_backend is None:
|
elif server_args.attention_backend is None:
|
||||||
# By default, use flashinfer for non-mla attention and triton for mla attention
|
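# By default, use fa3 when the Hopper/CUDA and speculative-decoding checks below pass (non-mla also requires a default FA3 architecture); otherwise use flashinfer (or triton if unavailable) for non-mla attention and triton for mla attention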
|
||||||
if not self.use_mla_backend:
|
if not self.use_mla_backend:
|
||||||
server_args.attention_backend = (
|
if (
|
||||||
"flashinfer" if is_flashinfer_available() else "triton"
|
is_hopper_with_cuda_12_3()
|
||||||
)
|
and is_no_spec_infer_or_topk_one(server_args)
|
||||||
|
and is_fa3_default_architecture(self.model_config.hf_config)
|
||||||
|
):
|
||||||
|
server_args.attention_backend = "fa3"
|
||||||
|
else:
|
||||||
|
server_args.attention_backend = (
|
||||||
|
"flashinfer" if is_flashinfer_available() else "triton"
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
if is_hopper_with_cuda_12_3():
|
if is_hopper_with_cuda_12_3() and is_no_spec_infer_or_topk_one(
|
||||||
if server_args.speculative_eagle_topk is None or (
|
server_args
|
||||||
server_args.speculative_eagle_topk is not None
|
):
|
||||||
and server_args.speculative_eagle_topk == 1
|
server_args.attention_backend = "fa3"
|
||||||
):
|
|
||||||
server_args.attention_backend = "fa3"
|
|
||||||
else:
|
|
||||||
server_args.attention_backend = "triton"
|
|
||||||
else:
|
else:
|
||||||
server_args.attention_backend = "triton"
|
server_args.attention_backend = "triton"
|
||||||
logger.info(
|
logger.info(
|
||||||
|
|||||||
@@ -569,7 +569,7 @@ def encode_video(video_path, frame_count_limit=None):
|
|||||||
|
|
||||||
|
|
||||||
def load_image(
|
def load_image(
|
||||||
image_file: Union[Image.Image, str, bytes]
|
image_file: Union[Image.Image, str, bytes],
|
||||||
) -> tuple[Image.Image, tuple[int, int]]:
|
) -> tuple[Image.Image, tuple[int, int]]:
|
||||||
image = image_size = None
|
image = image_size = None
|
||||||
if isinstance(image_file, Image.Image):
|
if isinstance(image_file, Image.Image):
|
||||||
@@ -1905,3 +1905,28 @@ def get_local_ip_by_remote() -> str:
|
|||||||
return s.getsockname()[0]
|
return s.getsockname()[0]
|
||||||
except Exception:
|
except Exception:
|
||||||
raise ValueError(f"Can not get local ip")
|
raise ValueError(f"Can not get local ip")
|
||||||
|
|
||||||
|
|
||||||
|
def is_page_size_one(server_args):
    """Return True when ``server_args.page_size`` equals 1.

    Args:
        server_args: Object exposing a ``page_size`` attribute.

    Returns:
        The result of comparing ``server_args.page_size`` with ``1``.
    """
    page_size = server_args.page_size
    return page_size == 1
|
||||||
|
|
||||||
|
|
||||||
|
def is_no_spec_infer_or_topk_one(server_args):
    """Check whether EAGLE top-k is unset, or is 1 with a page size of 1.

    Despite the name, the top-k == 1 case additionally requires
    ``is_page_size_one(server_args)`` to hold.

    Args:
        server_args: Object exposing ``speculative_eagle_topk`` and
            ``page_size`` attributes.

    Returns:
        True when ``speculative_eagle_topk`` is None; otherwise whether it
        equals 1 and the page size is 1.
    """
    topk = server_args.speculative_eagle_topk
    if topk is None:
        return True
    # The page-size check is only evaluated once top-k is known to be 1.
    return topk == 1 and is_page_size_one(server_args)
|
||||||
|
|
||||||
|
|
||||||
|
def is_fa3_default_architecture(hf_config):
    """Tell whether the config's first architecture is in the FA3 default set.

    Args:
        hf_config: Object that may expose an ``architectures`` attribute.

    Returns:
        False when ``architectures`` is missing, is not a list, or is empty;
        otherwise whether its first entry is one of the listed names.
    """
    archs = getattr(hf_config, "architectures", None)
    if isinstance(archs, list) and archs:
        # Only the first listed architecture is considered.
        return archs[0] in {
            "Qwen2ForCausalLM",
            "Llama4ForConditionalGeneration",
            "LlamaForCausalLM",
            "MistralForCausalLM",
        }
    return False
|
||||||
|
|||||||
Reference in New Issue
Block a user