[Fix] use torch.cat instead of torch.concat to prevent entering the Autograd backends. (#4466)
This commit is contained in:
@@ -1244,14 +1244,14 @@ class ScheduleBatch:
|
||||
self.encoder_lens = torch.cat([self.encoder_lens, other.encoder_lens])
|
||||
self.encoder_lens_cpu.extend(other.encoder_lens_cpu)
|
||||
|
||||
self.req_pool_indices = torch.concat(
|
||||
self.req_pool_indices = torch.cat(
|
||||
[self.req_pool_indices, other.req_pool_indices]
|
||||
)
|
||||
self.seq_lens = torch.concat([self.seq_lens, other.seq_lens])
|
||||
self.seq_lens = torch.cat([self.seq_lens, other.seq_lens])
|
||||
self.out_cache_loc = None
|
||||
self.seq_lens_sum += other.seq_lens_sum
|
||||
if self.output_ids is not None:
|
||||
self.output_ids = torch.concat([self.output_ids, other.output_ids])
|
||||
self.output_ids = torch.cat([self.output_ids, other.output_ids])
|
||||
if self.return_logprob and other.return_logprob:
|
||||
self.top_logprobs_nums.extend(other.top_logprobs_nums)
|
||||
self.token_ids_logprobs.extend(other.token_ids_logprobs)
|
||||
|
||||
@@ -303,7 +303,7 @@ class HiRadixCache(RadixCache):
|
||||
|
||||
value, last_node = self._match_prefix_helper(self.root_node, key)
|
||||
if value:
|
||||
value = torch.concat(value)
|
||||
value = torch.cat(value)
|
||||
else:
|
||||
value = torch.tensor([], dtype=torch.int32)
|
||||
|
||||
|
||||
@@ -172,7 +172,7 @@ class TokenToKVPoolAllocator:
|
||||
return
|
||||
|
||||
if self.is_not_in_free_group:
|
||||
self.free_slots = torch.concat((self.free_slots, free_index))
|
||||
self.free_slots = torch.cat((self.free_slots, free_index))
|
||||
else:
|
||||
self.free_group.append(free_index)
|
||||
|
||||
@@ -183,7 +183,7 @@ class TokenToKVPoolAllocator:
|
||||
def free_group_end(self):
|
||||
self.is_not_in_free_group = True
|
||||
if self.free_group:
|
||||
self.free(torch.concat(self.free_group))
|
||||
self.free(torch.cat(self.free_group))
|
||||
|
||||
def clear(self):
|
||||
# The padded slot 0 is used for writing dummy outputs from padded tokens.
|
||||
@@ -739,7 +739,7 @@ class HostKVCache(abc.ABC):
|
||||
@synchronized
|
||||
def free(self, indices: torch.Tensor) -> int:
|
||||
self.mem_state[indices] = MemoryStateInt.IDLE
|
||||
self.free_slots = torch.concat([self.free_slots, indices])
|
||||
self.free_slots = torch.cat([self.free_slots, indices])
|
||||
self.can_use_mem_size += len(indices)
|
||||
return len(indices)
|
||||
|
||||
|
||||
@@ -272,7 +272,7 @@ class PagedTokenToKVPoolAllocator:
|
||||
def free_group_end(self):
|
||||
self.is_not_in_free_group = True
|
||||
if self.free_group:
|
||||
self.free(torch.concat(self.free_group))
|
||||
self.free(torch.cat(self.free_group))
|
||||
|
||||
def clear(self):
|
||||
# The padded slot 0 is used for writing dummy outputs from padded tokens.
|
||||
|
||||
@@ -152,7 +152,7 @@ class RadixCache(BasePrefixCache):
|
||||
|
||||
value, last_node = self._match_prefix_helper(self.root_node, key)
|
||||
if value:
|
||||
value = torch.concat(value)
|
||||
value = torch.cat(value)
|
||||
else:
|
||||
value = torch.empty((0,), dtype=torch.int32, device=self.device)
|
||||
return value, last_node
|
||||
@@ -317,7 +317,7 @@ class RadixCache(BasePrefixCache):
|
||||
_dfs_helper(child)
|
||||
|
||||
_dfs_helper(self.root_node)
|
||||
return torch.concat(values)
|
||||
return torch.cat(values)
|
||||
|
||||
##### Internal Helper Functions #####
|
||||
|
||||
|
||||
@@ -383,7 +383,7 @@ class ForwardBatch:
|
||||
batch.image_inputs[i].mrope_position_delta = mrope_position_delta
|
||||
mrope_positions_list[i] = mrope_positions
|
||||
|
||||
self.mrope_positions = torch.concat(
|
||||
self.mrope_positions = torch.cat(
|
||||
[torch.tensor(pos, device=device) for pos in mrope_positions_list],
|
||||
axis=1,
|
||||
)
|
||||
@@ -449,7 +449,7 @@ def compute_position_kernel(
|
||||
def compute_position_torch(
|
||||
extend_prefix_lens: torch.Tensor, extend_seq_lens: torch.Tensor
|
||||
):
|
||||
positions = torch.concat(
|
||||
positions = torch.cat(
|
||||
[
|
||||
torch.arange(
|
||||
prefix_len, prefix_len + extend_len, device=extend_prefix_lens.device
|
||||
|
||||
@@ -1289,7 +1289,7 @@ class MlpProjector(nn.Module):
|
||||
high_x, low_x = x_or_tuple
|
||||
high_x = self.high_up_proj(high_x)
|
||||
low_x = self.low_up_proj(low_x)
|
||||
x = torch.concat([high_x, low_x], dim=-1)
|
||||
x = torch.cat([high_x, low_x], dim=-1)
|
||||
else:
|
||||
x = x_or_tuple
|
||||
|
||||
|
||||
@@ -828,7 +828,7 @@ class MiniCPMVBaseModel(nn.Module):
|
||||
)
|
||||
|
||||
if isinstance(image_embeds, list):
|
||||
image_embeds = torch.concat(image_embeds)
|
||||
image_embeds = torch.cat(image_embeds)
|
||||
|
||||
return MiniCPMVImageEmbeddingInputs(
|
||||
image_bounds=image_bounds,
|
||||
|
||||
@@ -306,7 +306,7 @@ class SamplingBatchInfo:
|
||||
]:
|
||||
self_val = getattr(self, item, None)
|
||||
other_val = getattr(other, item, None)
|
||||
setattr(self, item, torch.concat([self_val, other_val]))
|
||||
setattr(self, item, torch.cat([self_val, other_val]))
|
||||
|
||||
self.is_all_greedy |= other.is_all_greedy
|
||||
self.need_min_p_sampling |= other.need_min_p_sampling
|
||||
|
||||
@@ -59,7 +59,7 @@ class EagleDraftInput:
|
||||
pt = 0
|
||||
for i, extend_len in enumerate(batch.extend_lens):
|
||||
input_ids = batch.input_ids[pt : pt + extend_len]
|
||||
batch.input_ids[pt : pt + extend_len] = torch.concat(
|
||||
batch.input_ids[pt : pt + extend_len] = torch.cat(
|
||||
(input_ids[1:], self.verified_id[i].reshape(1))
|
||||
)
|
||||
pt += extend_len
|
||||
|
||||
Reference in New Issue
Block a user