[Fix] use torch.cat instead of torch.concat to prevent entering the Autograd backends. (#4466)
This commit is contained in:
@@ -303,7 +303,7 @@ class HiRadixCache(RadixCache):
|
||||
|
||||
value, last_node = self._match_prefix_helper(self.root_node, key)
|
||||
if value:
|
||||
- value = torch.concat(value)
|
||||
+ value = torch.cat(value)
|
||||
else:
|
||||
value = torch.tensor([], dtype=torch.int32)
|
||||
|
||||
|
||||
@@ -172,7 +172,7 @@ class TokenToKVPoolAllocator:
|
||||
return
|
||||
|
||||
if self.is_not_in_free_group:
|
||||
- self.free_slots = torch.concat((self.free_slots, free_index))
|
||||
+ self.free_slots = torch.cat((self.free_slots, free_index))
|
||||
else:
|
||||
self.free_group.append(free_index)
|
||||
|
||||
@@ -183,7 +183,7 @@ class TokenToKVPoolAllocator:
|
||||
def free_group_end(self):
|
||||
self.is_not_in_free_group = True
|
||||
if self.free_group:
|
||||
- self.free(torch.concat(self.free_group))
|
||||
+ self.free(torch.cat(self.free_group))
|
||||
|
||||
def clear(self):
|
||||
# The padded slot 0 is used for writing dummy outputs from padded tokens.
|
||||
@@ -739,7 +739,7 @@ class HostKVCache(abc.ABC):
|
||||
@synchronized
|
||||
def free(self, indices: torch.Tensor) -> int:
|
||||
self.mem_state[indices] = MemoryStateInt.IDLE
|
||||
- self.free_slots = torch.concat([self.free_slots, indices])
|
||||
+ self.free_slots = torch.cat([self.free_slots, indices])
|
||||
self.can_use_mem_size += len(indices)
|
||||
return len(indices)
|
||||
|
||||
|
||||
@@ -272,7 +272,7 @@ class PagedTokenToKVPoolAllocator:
|
||||
def free_group_end(self):
|
||||
self.is_not_in_free_group = True
|
||||
if self.free_group:
|
||||
- self.free(torch.concat(self.free_group))
|
||||
+ self.free(torch.cat(self.free_group))
|
||||
|
||||
def clear(self):
|
||||
# The padded slot 0 is used for writing dummy outputs from padded tokens.
|
||||
|
||||
@@ -152,7 +152,7 @@ class RadixCache(BasePrefixCache):
|
||||
|
||||
value, last_node = self._match_prefix_helper(self.root_node, key)
|
||||
if value:
|
||||
- value = torch.concat(value)
|
||||
+ value = torch.cat(value)
|
||||
else:
|
||||
value = torch.empty((0,), dtype=torch.int32, device=self.device)
|
||||
return value, last_node
|
||||
@@ -317,7 +317,7 @@ class RadixCache(BasePrefixCache):
|
||||
_dfs_helper(child)
|
||||
|
||||
_dfs_helper(self.root_node)
|
||||
- return torch.concat(values)
|
||||
+ return torch.cat(values)
|
||||
|
||||
##### Internal Helper Functions #####
|
||||
|
||||
|
||||
Reference in New Issue
Block a user