[Fix] Fix accuracy bug in CSGMV kernel caching key. (#11579)
This commit is contained in:
@@ -9,7 +9,7 @@ from sglang.srt.utils import cached_triton_kernel
|
||||
|
||||
|
||||
@cached_triton_kernel(lambda _, kwargs: (kwargs["NUM_SLICES"], kwargs["BLOCK_M"]))
|
||||
-@triton.jit
|
||||
+@triton.jit(do_not_specialize=["num_segs"])
|
||||
def _chunked_lora_expand_kernel(
|
||||
# Pointers to matrices
|
||||
x,
|
||||
|
||||
@@ -6,8 +6,10 @@ from sglang.srt.lora.utils import LoRABatchInfo
|
||||
from sglang.srt.utils import cached_triton_kernel
|
||||
|
||||
|
||||
-@cached_triton_kernel(lambda _, kwargs: (kwargs["NUM_SLICES"], kwargs["BLOCK_M"]))
|
||||
-@triton.jit
|
||||
+@cached_triton_kernel(
|
||||
+    lambda _, kwargs: (kwargs["K"], kwargs["NUM_SLICES"], kwargs["BLOCK_M"])
|
||||
+)
|
||||
+@triton.jit(do_not_specialize=["num_segs"])
|
||||
def _chunked_lora_shrink_kernel(
|
||||
# Pointers to matrices
|
||||
x,
|
||||
|
||||
Reference in New Issue
Block a user