diff --git a/python/sglang/srt/lora/triton_ops/chunked_sgmv_expand.py b/python/sglang/srt/lora/triton_ops/chunked_sgmv_expand.py index 1767c5ee4..414f704a7 100644 --- a/python/sglang/srt/lora/triton_ops/chunked_sgmv_expand.py +++ b/python/sglang/srt/lora/triton_ops/chunked_sgmv_expand.py @@ -9,7 +9,7 @@ from sglang.srt.utils import cached_triton_kernel @cached_triton_kernel(lambda _, kwargs: (kwargs["NUM_SLICES"], kwargs["BLOCK_M"])) -@triton.jit +@triton.jit(do_not_specialize=["num_segs"]) def _chunked_lora_expand_kernel( # Pointers to matrices x, diff --git a/python/sglang/srt/lora/triton_ops/chunked_sgmv_shrink.py b/python/sglang/srt/lora/triton_ops/chunked_sgmv_shrink.py index e0ef41fb7..b0ffdb763 100644 --- a/python/sglang/srt/lora/triton_ops/chunked_sgmv_shrink.py +++ b/python/sglang/srt/lora/triton_ops/chunked_sgmv_shrink.py @@ -6,8 +6,10 @@ from sglang.srt.lora.utils import LoRABatchInfo from sglang.srt.utils import cached_triton_kernel -@cached_triton_kernel(lambda _, kwargs: (kwargs["NUM_SLICES"], kwargs["BLOCK_M"])) -@triton.jit +@cached_triton_kernel( + lambda _, kwargs: (kwargs["K"], kwargs["NUM_SLICES"], kwargs["BLOCK_M"]) +) +@triton.jit(do_not_specialize=["num_segs"]) def _chunked_lora_shrink_kernel( # Pointers to matrices x,