diff --git a/python/sglang/compile_deep_gemm.py b/python/sglang/compile_deep_gemm.py index 71324d315..b86086e20 100644 --- a/python/sglang/compile_deep_gemm.py +++ b/python/sglang/compile_deep_gemm.py @@ -30,6 +30,8 @@ multiprocessing.set_start_method("spawn", force=True) os.environ["SGL_IN_DEEP_GEMM_PRE_COMPILE_STAGE"] = "1" # Force enable deep gemm os.environ["SGL_ENABLE_JIT_DEEPGEMM"] = "1" +# Force-enable MHA chunked KV for DeepSeek V3 (by setting the chunked-prefix-cache threshold to 0) so the kv_b_proj DeepGEMM case is not missed +os.environ["SGL_CHUNKED_PREFIX_CACHE_THRESHOLD"] = "0" @dataclasses.dataclass diff --git a/python/sglang/srt/models/deepseek_v2.py b/python/sglang/srt/models/deepseek_v2.py index a9cb3c5b9..b299c3037 100644 --- a/python/sglang/srt/models/deepseek_v2.py +++ b/python/sglang/srt/models/deepseek_v2.py @@ -80,7 +80,15 @@ from sglang.srt.managers.expert_distribution import ExpertDistributionRecorder from sglang.srt.managers.schedule_batch import global_server_args_dict from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode from sglang.srt.model_loader.weight_utils import default_weight_loader -from sglang.srt.utils import BumpAllocator, DeepEPMode, add_prefix, is_cuda, is_hip +from sglang.srt.utils import ( + BumpAllocator, + DeepEPMode, + add_prefix, + get_bool_env_var, + get_int_env_var, + is_cuda, + is_hip, +) _is_hip = is_hip() _is_cuda = is_cuda() @@ -549,10 +557,14 @@ class DeepseekV2AttentionMLA(nn.Module): "disable_chunked_prefix_cache" ] self.attention_backend = global_server_args_dict["attention_backend"] - self.rocm_fused_decode_mla = os.getenv("SGLANG_ROCM_FUSED_DECODE_MLA") == "1" + self.rocm_fused_decode_mla = get_bool_env_var( + "SGLANG_ROCM_FUSED_DECODE_MLA", "false" + ) # TODO: Design a finer way to determine the threshold - self.chunked_prefix_cache_threshold = 8192 + self.chunked_prefix_cache_threshold = get_int_env_var( + "SGL_CHUNKED_PREFIX_CACHE_THRESHOLD", 8192 + ) def dispatch_attn_forward_method( self, forward_batch: ForwardBatch