From 780fbf2f389c01912e0452644a80169d96f2c826 Mon Sep 17 00:00:00 2001 From: Lifu Huang Date: Tue, 14 Oct 2025 20:25:56 -0700 Subject: [PATCH] [Fix] Fix accuracy bug in CSGMV kernel caching key. (#11579) --- python/sglang/srt/lora/triton_ops/chunked_sgmv_expand.py | 2 +- python/sglang/srt/lora/triton_ops/chunked_sgmv_shrink.py | 6 ++++-- 2 files changed, 5 insertions(+), 3 deletions(-) diff --git a/python/sglang/srt/lora/triton_ops/chunked_sgmv_expand.py b/python/sglang/srt/lora/triton_ops/chunked_sgmv_expand.py index 1767c5ee4..414f704a7 100644 --- a/python/sglang/srt/lora/triton_ops/chunked_sgmv_expand.py +++ b/python/sglang/srt/lora/triton_ops/chunked_sgmv_expand.py @@ -9,7 +9,7 @@ from sglang.srt.utils import cached_triton_kernel @cached_triton_kernel(lambda _, kwargs: (kwargs["NUM_SLICES"], kwargs["BLOCK_M"])) -@triton.jit +@triton.jit(do_not_specialize=["num_segs"]) def _chunked_lora_expand_kernel( # Pointers to matrices x, diff --git a/python/sglang/srt/lora/triton_ops/chunked_sgmv_shrink.py b/python/sglang/srt/lora/triton_ops/chunked_sgmv_shrink.py index e0ef41fb7..b0ffdb763 100644 --- a/python/sglang/srt/lora/triton_ops/chunked_sgmv_shrink.py +++ b/python/sglang/srt/lora/triton_ops/chunked_sgmv_shrink.py @@ -6,8 +6,10 @@ from sglang.srt.lora.utils import LoRABatchInfo from sglang.srt.utils import cached_triton_kernel -@cached_triton_kernel(lambda _, kwargs: (kwargs["NUM_SLICES"], kwargs["BLOCK_M"])) -@triton.jit +@cached_triton_kernel( + lambda _, kwargs: (kwargs["K"], kwargs["NUM_SLICES"], kwargs["BLOCK_M"]) +) +@triton.jit(do_not_specialize=["num_segs"]) def _chunked_lora_shrink_kernel( # Pointers to matrices x,