[fix] fix compile_deep_gemm missing kv_b_proj (#5620)
This commit is contained in:
@@ -30,6 +30,8 @@ multiprocessing.set_start_method("spawn", force=True)
|
|||||||
os.environ["SGL_IN_DEEP_GEMM_PRE_COMPILE_STAGE"] = "1"
|
os.environ["SGL_IN_DEEP_GEMM_PRE_COMPILE_STAGE"] = "1"
|
||||||
# Force enable deep gemm
|
# Force enable deep gemm
|
||||||
os.environ["SGL_ENABLE_JIT_DEEPGEMM"] = "1"
|
os.environ["SGL_ENABLE_JIT_DEEPGEMM"] = "1"
|
||||||
|
# Force enable mha chunked kv for DeepSeek V3 to avoid missing kv_b_proj DeepGEMM case
|
||||||
|
os.environ["SGL_CHUNKED_PREFIX_CACHE_THRESHOLD"] = "0"
|
||||||
|
|
||||||
|
|
||||||
@dataclasses.dataclass
|
@dataclasses.dataclass
|
||||||
|
|||||||
@@ -80,7 +80,15 @@ from sglang.srt.managers.expert_distribution import ExpertDistributionRecorder
|
|||||||
from sglang.srt.managers.schedule_batch import global_server_args_dict
|
from sglang.srt.managers.schedule_batch import global_server_args_dict
|
||||||
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
|
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
|
||||||
from sglang.srt.model_loader.weight_utils import default_weight_loader
|
from sglang.srt.model_loader.weight_utils import default_weight_loader
|
||||||
from sglang.srt.utils import BumpAllocator, DeepEPMode, add_prefix, is_cuda, is_hip
|
from sglang.srt.utils import (
|
||||||
|
BumpAllocator,
|
||||||
|
DeepEPMode,
|
||||||
|
add_prefix,
|
||||||
|
get_bool_env_var,
|
||||||
|
get_int_env_var,
|
||||||
|
is_cuda,
|
||||||
|
is_hip,
|
||||||
|
)
|
||||||
|
|
||||||
_is_hip = is_hip()
|
_is_hip = is_hip()
|
||||||
_is_cuda = is_cuda()
|
_is_cuda = is_cuda()
|
||||||
@@ -549,10 +557,14 @@ class DeepseekV2AttentionMLA(nn.Module):
|
|||||||
"disable_chunked_prefix_cache"
|
"disable_chunked_prefix_cache"
|
||||||
]
|
]
|
||||||
self.attention_backend = global_server_args_dict["attention_backend"]
|
self.attention_backend = global_server_args_dict["attention_backend"]
|
||||||
self.rocm_fused_decode_mla = os.getenv("SGLANG_ROCM_FUSED_DECODE_MLA") == "1"
|
self.rocm_fused_decode_mla = get_bool_env_var(
|
||||||
|
"SGLANG_ROCM_FUSED_DECODE_MLA", "false"
|
||||||
|
)
|
||||||
|
|
||||||
# TODO: Design a finer way to determine the threshold
|
# TODO: Design a finer way to determine the threshold
|
||||||
self.chunked_prefix_cache_threshold = 8192
|
self.chunked_prefix_cache_threshold = get_int_env_var(
|
||||||
|
"SGL_CHUNKED_PREFIX_CACHE_THRESHOLD", 8192
|
||||||
|
)
|
||||||
|
|
||||||
def dispatch_attn_forward_method(
|
def dispatch_attn_forward_method(
|
||||||
self, forward_batch: ForwardBatch
|
self, forward_batch: ForwardBatch
|
||||||
|
|||||||
Reference in New Issue
Block a user