diff --git a/python/sglang/srt/mem_cache/memory_pool.py b/python/sglang/srt/mem_cache/memory_pool.py index 3b567dc76..99f637730 100644 --- a/python/sglang/srt/mem_cache/memory_pool.py +++ b/python/sglang/srt/mem_cache/memory_pool.py @@ -1070,7 +1070,7 @@ def copy_all_layer_kv_cache( num_loop = tl.cdiv(stride, BLOCK_SIZE) for i in range(num_loop): copy_offset = tl.arange(0, BLOCK_SIZE) + i * BLOCK_SIZE - mask = (num_locs_offset < num_locs)[:, None] and (copy_offset < stride)[None, :] + mask = (num_locs_offset < num_locs)[:, None] & (copy_offset < stride)[None, :] value = tl.load( data_ptr + src_locs[:, None] * stride + copy_offset[None, :], mask=mask )