Fix perf regression for set_kv_buffer (#1765)
This commit is contained in:
@@ -221,17 +221,19 @@ class MHATokenToKVPool(BaseTokenToKVPool):
|
|||||||
cache_v: torch.Tensor,
|
cache_v: torch.Tensor,
|
||||||
):
|
):
|
||||||
layer_id = layer.layer_id
|
layer_id = layer.layer_id
|
||||||
copy_two_array(
|
if cache_k.dtype != self.dtype:
|
||||||
loc,
|
cache_k = cache_k.to(self.dtype)
|
||||||
self.k_buffer[layer_id],
|
cache_v = cache_v.to(self.dtype)
|
||||||
cache_k,
|
if self.store_dtype != self.dtype:
|
||||||
self.v_buffer[layer_id],
|
self.k_buffer[layer_id][loc] = cache_k.view(self.store_dtype)
|
||||||
cache_v,
|
self.v_buffer[layer_id][loc] = cache_v.view(self.store_dtype)
|
||||||
self.dtype,
|
else:
|
||||||
self.store_dtype,
|
self.k_buffer[layer_id][loc] = cache_k
|
||||||
)
|
self.v_buffer[layer_id][loc] = cache_v
|
||||||
|
|
||||||
|
|
||||||
|
# This compiled version is slower in the unit test
|
||||||
|
# python3 -m unittest test_bench_serving.TestBenchServing.test_offline_throughput_non_stream_small_batch_size
|
||||||
@torch.compile(dynamic=True)
|
@torch.compile(dynamic=True)
|
||||||
def copy_two_array(loc, dst_1, src_1, dst_2, src_2, dtype, store_dtype):
|
def copy_two_array(loc, dst_1, src_1, dst_2, src_2, dtype, store_dtype):
|
||||||
dst_1[loc] = src_1.to(dtype).view(store_dtype)
|
dst_1[loc] = src_1.to(dtype).view(store_dtype)
|
||||||
|
|||||||
Reference in New Issue
Block a user