[Fix] Fix accuracy bug in CSGMV kernel caching key. (#11579)

This commit is contained in:
Lifu Huang
2025-10-14 20:25:56 -07:00
committed by GitHub
parent 825432fce6
commit 780fbf2f38
2 changed files with 5 additions and 3 deletions

View File

@@ -9,7 +9,7 @@ from sglang.srt.utils import cached_triton_kernel
@cached_triton_kernel(lambda _, kwargs: (kwargs["NUM_SLICES"], kwargs["BLOCK_M"]))
@triton.jit
@triton.jit(do_not_specialize=["num_segs"])
def _chunked_lora_expand_kernel(
# Pointers to matrices
x,

View File

@@ -6,8 +6,10 @@ from sglang.srt.lora.utils import LoRABatchInfo
from sglang.srt.utils import cached_triton_kernel
@cached_triton_kernel(lambda _, kwargs: (kwargs["NUM_SLICES"], kwargs["BLOCK_M"]))
@triton.jit
@cached_triton_kernel(
lambda _, kwargs: (kwargs["K"], kwargs["NUM_SLICES"], kwargs["BLOCK_M"])
)
@triton.jit(do_not_specialize=["num_segs"])
def _chunked_lora_shrink_kernel(
# Pointers to matrices
x,