diff --git a/python/sglang/srt/mem_cache/memory_pool.py b/python/sglang/srt/mem_cache/memory_pool.py index 181ac7eef..07f3d454e 100644 --- a/python/sglang/srt/mem_cache/memory_pool.py +++ b/python/sglang/srt/mem_cache/memory_pool.py @@ -221,17 +221,19 @@ class MHATokenToKVPool(BaseTokenToKVPool): cache_v: torch.Tensor, ): layer_id = layer.layer_id - copy_two_array( - loc, - self.k_buffer[layer_id], - cache_k, - self.v_buffer[layer_id], - cache_v, - self.dtype, - self.store_dtype, - ) + if cache_k.dtype != self.dtype: + cache_k = cache_k.to(self.dtype) + cache_v = cache_v.to(self.dtype) + if self.store_dtype != self.dtype: + self.k_buffer[layer_id][loc] = cache_k.view(self.store_dtype) + self.v_buffer[layer_id][loc] = cache_v.view(self.store_dtype) + else: + self.k_buffer[layer_id][loc] = cache_k + self.v_buffer[layer_id][loc] = cache_v +# This compiled version is slower in the unit test +# python3 -m unittest test_bench_serving.TestBenchServing.test_offline_throughput_non_stream_small_batch_size @torch.compile(dynamic=True) def copy_two_array(loc, dst_1, src_1, dst_2, src_2, dtype, store_dtype): dst_1[loc] = src_1.to(dtype).view(store_dtype)