Add more support for intel Gaudi accelerators (#2357)
This commit is contained in:
@@ -27,6 +27,7 @@ from typing import List, Tuple, Union
|
||||
import torch
|
||||
|
||||
from sglang.srt.layers.radix_attention import RadixAttention
|
||||
from sglang.srt.utils import get_compiler_backend
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -129,6 +130,9 @@ class BaseTokenToKVPool:
|
||||
return select_index.to(self.device, non_blocking=True)
|
||||
|
||||
def free(self, free_index: torch.Tensor):
|
||||
if free_index.numel() == 0:
|
||||
return
|
||||
|
||||
if self.is_not_in_free_group:
|
||||
self.free_slots = torch.concat((self.free_slots, free_index.cpu()))
|
||||
else:
|
||||
@@ -234,7 +238,7 @@ class MHATokenToKVPool(BaseTokenToKVPool):
|
||||
|
||||
# This compiled version is slower in the unit test
|
||||
# python3 -m unittest test_bench_serving.TestBenchServing.test_offline_throughput_non_stream_small_batch_size
|
||||
# Compiled exactly once, with the backend chosen by get_compiler_backend();
# stacking a second, backend-less torch.compile on top would compile the
# function twice and bypass that backend selection.
@torch.compile(dynamic=True, backend=get_compiler_backend())
def copy_two_array(loc, dst_1, src_1, dst_2, src_2, dtype, store_dtype):
    """Store two sources into their destinations at the same index ``loc``.

    Each source is converted with ``.to(dtype)`` and the result is passed
    through ``.view(store_dtype)`` before being assigned into the matching
    destination at ``loc``. Both destinations are modified in place.

    Args:
        loc: Index applied to both ``dst_1`` and ``dst_2`` to select the
            slots that are overwritten.
        dst_1: First destination buffer (presumably a torch.Tensor --
            confirm against callers).
        src_1: Values written into ``dst_1[loc]``.
        dst_2: Second destination buffer.
        src_2: Values written into ``dst_2[loc]``.
        dtype: Dtype each source is converted to before storing.
        store_dtype: Dtype passed to ``.view`` on the converted source;
            presumably the dtype the destination buffers are held in --
            confirm against callers.

    Returns:
        None.
    """
    dst_1[loc] = src_1.to(dtype).view(store_dtype)
    dst_2[loc] = src_2.to(dtype).view(store_dtype)
|
||||
|
||||
Reference in New Issue
Block a user