[Fix] use torch.cat instead of torch.concat to prevent entering the Autograd backends. (#4466)

This commit is contained in:
JieXin Liang
2025-03-16 15:02:47 +08:00
committed by GitHub
parent 81f431eded
commit 1a3fa75f2f
14 changed files with 20 additions and 20 deletions

View File

@@ -303,7 +303,7 @@ class HiRadixCache(RadixCache):
value, last_node = self._match_prefix_helper(self.root_node, key)
if value:
-            value = torch.concat(value)
+            value = torch.cat(value)
else:
value = torch.tensor([], dtype=torch.int32)

View File

@@ -172,7 +172,7 @@ class TokenToKVPoolAllocator:
return
if self.is_not_in_free_group:
-            self.free_slots = torch.concat((self.free_slots, free_index))
+            self.free_slots = torch.cat((self.free_slots, free_index))
else:
self.free_group.append(free_index)
@@ -183,7 +183,7 @@ class TokenToKVPoolAllocator:
def free_group_end(self):
self.is_not_in_free_group = True
if self.free_group:
-            self.free(torch.concat(self.free_group))
+            self.free(torch.cat(self.free_group))
def clear(self):
# The padded slot 0 is used for writing dummy outputs from padded tokens.
@@ -739,7 +739,7 @@ class HostKVCache(abc.ABC):
@synchronized
def free(self, indices: torch.Tensor) -> int:
self.mem_state[indices] = MemoryStateInt.IDLE
-        self.free_slots = torch.concat([self.free_slots, indices])
+        self.free_slots = torch.cat([self.free_slots, indices])
self.can_use_mem_size += len(indices)
return len(indices)

View File

@@ -272,7 +272,7 @@ class PagedTokenToKVPoolAllocator:
def free_group_end(self):
self.is_not_in_free_group = True
if self.free_group:
-            self.free(torch.concat(self.free_group))
+            self.free(torch.cat(self.free_group))
def clear(self):
# The padded slot 0 is used for writing dummy outputs from padded tokens.

View File

@@ -152,7 +152,7 @@ class RadixCache(BasePrefixCache):
value, last_node = self._match_prefix_helper(self.root_node, key)
if value:
-            value = torch.concat(value)
+            value = torch.cat(value)
else:
value = torch.empty((0,), dtype=torch.int32, device=self.device)
return value, last_node
@@ -317,7 +317,7 @@ class RadixCache(BasePrefixCache):
_dfs_helper(child)
_dfs_helper(self.root_node)
-        return torch.concat(values)
+        return torch.cat(values)
##### Internal Helper Functions #####