[Fix] use torch.cat instead of torch.concat to prevent entering the Autograd backends. (#4466)
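Functionally the two spellings are equivalent: torch.concat is an alias of torch.cat with the same signature. The motivation, per the commit title, is dispatch overhead: the alias takes an extra hop through the Autograd backend on every call, which adds up on the hot paths this diff touches (decode loops, batch merging, KV-cache allocator bookkeeping). A minimal micro-benchmark sketch to compare the per-call cost of the two spellings (tensor shapes and iteration counts are illustrative, not taken from the commit):

    import time

    import torch


    def bench(fn, xs, iters=10_000):
        # Warm up, then time repeated calls to isolate dispatch overhead.
        for _ in range(100):
            fn(xs, dim=-1)
        start = time.perf_counter()
        for _ in range(iters):
            fn(xs, dim=-1)
        return (time.perf_counter() - start) / iters


    # Small tensors, so per-call dispatch cost dominates the actual copy.
    xs = [torch.randn(8, 16) for _ in range(4)]

    with torch.inference_mode():
        t_cat = bench(torch.cat, xs)
        t_concat = bench(torch.concat, xs)

    print(f"torch.cat:    {t_cat * 1e6:.2f} us/call")
    print(f"torch.concat: {t_concat * 1e6:.2f} us/call")

On inputs this small any difference between the aliases is easiest to see, since the kernel itself is nearly free.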
@@ -235,7 +235,7 @@ class MiniMaxText01LightningAttention(nn.Module):
     "... n e, ... e d -> ... n d", q[:, :, i : i + 1], kv.to(q.dtype)
 )
 output.append(qkv)
-output = torch.concat(output, dim=-2)
+output = torch.cat(output, dim=-2)

 # reshape
 output = rearrange(output, "b h n d -> b n (h d)")
@@ -403,7 +403,7 @@ class MiniMaxText01LightningAttention(nn.Module):
     "... n e, ... e d -> ... n d", q[:, :, i : i + 1], kv.to(q.dtype)
 )
 output.append(qkv)
-output = torch.concat(output, dim=-2)
+output = torch.cat(output, dim=-2)
 # reshape
 output = rearrange(output, "b h n d -> b n (h d)")
 # normalize
@@ -1244,14 +1244,14 @@ class ScheduleBatch:
 self.encoder_lens = torch.cat([self.encoder_lens, other.encoder_lens])
 self.encoder_lens_cpu.extend(other.encoder_lens_cpu)

-self.req_pool_indices = torch.concat(
+self.req_pool_indices = torch.cat(
     [self.req_pool_indices, other.req_pool_indices]
 )
-self.seq_lens = torch.concat([self.seq_lens, other.seq_lens])
+self.seq_lens = torch.cat([self.seq_lens, other.seq_lens])
 self.out_cache_loc = None
 self.seq_lens_sum += other.seq_lens_sum
 if self.output_ids is not None:
-    self.output_ids = torch.concat([self.output_ids, other.output_ids])
+    self.output_ids = torch.cat([self.output_ids, other.output_ids])
 if self.return_logprob and other.return_logprob:
     self.top_logprobs_nums.extend(other.top_logprobs_nums)
     self.token_ids_logprobs.extend(other.token_ids_logprobs)
@@ -303,7 +303,7 @@ class HiRadixCache(RadixCache):

 value, last_node = self._match_prefix_helper(self.root_node, key)
 if value:
-    value = torch.concat(value)
+    value = torch.cat(value)
 else:
     value = torch.tensor([], dtype=torch.int32)

@@ -172,7 +172,7 @@ class TokenToKVPoolAllocator:
 return

 if self.is_not_in_free_group:
-    self.free_slots = torch.concat((self.free_slots, free_index))
+    self.free_slots = torch.cat((self.free_slots, free_index))
 else:
     self.free_group.append(free_index)

@@ -183,7 +183,7 @@ class TokenToKVPoolAllocator:
 def free_group_end(self):
     self.is_not_in_free_group = True
     if self.free_group:
-        self.free(torch.concat(self.free_group))
+        self.free(torch.cat(self.free_group))

 def clear(self):
     # The padded slot 0 is used for writing dummy outputs from padded tokens.
@@ -739,7 +739,7 @@ class HostKVCache(abc.ABC):
 @synchronized
 def free(self, indices: torch.Tensor) -> int:
     self.mem_state[indices] = MemoryStateInt.IDLE
-    self.free_slots = torch.concat([self.free_slots, indices])
+    self.free_slots = torch.cat([self.free_slots, indices])
     self.can_use_mem_size += len(indices)
     return len(indices)

@@ -272,7 +272,7 @@ class PagedTokenToKVPoolAllocator:
 def free_group_end(self):
     self.is_not_in_free_group = True
     if self.free_group:
-        self.free(torch.concat(self.free_group))
+        self.free(torch.cat(self.free_group))

 def clear(self):
     # The padded slot 0 is used for writing dummy outputs from padded tokens.
@@ -152,7 +152,7 @@ class RadixCache(BasePrefixCache):

 value, last_node = self._match_prefix_helper(self.root_node, key)
 if value:
-    value = torch.concat(value)
+    value = torch.cat(value)
 else:
     value = torch.empty((0,), dtype=torch.int32, device=self.device)
 return value, last_node
@@ -317,7 +317,7 @@ class RadixCache(BasePrefixCache):
     _dfs_helper(child)

 _dfs_helper(self.root_node)
-return torch.concat(values)
+return torch.cat(values)

 ##### Internal Helper Functions #####

@@ -383,7 +383,7 @@ class ForwardBatch:
 batch.image_inputs[i].mrope_position_delta = mrope_position_delta
 mrope_positions_list[i] = mrope_positions

-self.mrope_positions = torch.concat(
+self.mrope_positions = torch.cat(
     [torch.tensor(pos, device=device) for pos in mrope_positions_list],
     axis=1,
 )
@@ -449,7 +449,7 @@ def compute_position_kernel(
 def compute_position_torch(
     extend_prefix_lens: torch.Tensor, extend_seq_lens: torch.Tensor
 ):
-    positions = torch.concat(
+    positions = torch.cat(
         [
             torch.arange(
                 prefix_len, prefix_len + extend_len, device=extend_prefix_lens.device
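For context, compute_position_torch builds the positions tensor by concatenating one arange per request, each starting at that request's prefix length. A self-contained illustration of the pattern (the hunk above is truncated, so the per-request loop below is an assumption about the surrounding code, and the input values are made up):

    import torch

    # Illustrative inputs: three requests with different prefix/extend lengths.
    extend_prefix_lens = torch.tensor([2, 0, 5])
    extend_seq_lens = torch.tensor([3, 2, 1])

    # One arange per request -- [2,3,4], [0,1], [5] -- concatenated into one tensor.
    positions = torch.cat(
        [
            torch.arange(int(p), int(p) + int(e))
            for p, e in zip(extend_prefix_lens, extend_seq_lens)
        ],
        dim=0,
    )
    print(positions)  # tensor([2, 3, 4, 0, 1, 5])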
@@ -1289,7 +1289,7 @@ class MlpProjector(nn.Module):
     high_x, low_x = x_or_tuple
     high_x = self.high_up_proj(high_x)
     low_x = self.low_up_proj(low_x)
-    x = torch.concat([high_x, low_x], dim=-1)
+    x = torch.cat([high_x, low_x], dim=-1)
 else:
     x = x_or_tuple

@@ -828,7 +828,7 @@ class MiniCPMVBaseModel(nn.Module):
 )

 if isinstance(image_embeds, list):
-    image_embeds = torch.concat(image_embeds)
+    image_embeds = torch.cat(image_embeds)

 return MiniCPMVImageEmbeddingInputs(
     image_bounds=image_bounds,
@@ -306,7 +306,7 @@ class SamplingBatchInfo:
 ]:
     self_val = getattr(self, item, None)
     other_val = getattr(other, item, None)
-    setattr(self, item, torch.concat([self_val, other_val]))
+    setattr(self, item, torch.cat([self_val, other_val]))

 self.is_all_greedy |= other.is_all_greedy
 self.need_min_p_sampling |= other.need_min_p_sampling
@@ -59,7 +59,7 @@ class EagleDraftInput:
 pt = 0
 for i, extend_len in enumerate(batch.extend_lens):
     input_ids = batch.input_ids[pt : pt + extend_len]
-    batch.input_ids[pt : pt + extend_len] = torch.concat(
+    batch.input_ids[pt : pt + extend_len] = torch.cat(
         (input_ids[1:], self.verified_id[i].reshape(1))
     )
     pt += extend_len
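The EagleDraftInput hunk rotates each request's slice of input_ids left by one token and appends the verified token at the end. A toy version of that rotation on a single slice (values are made up for the example):

    import torch

    input_ids = torch.tensor([10, 11, 12, 13])  # one request's extend segment
    verified_id = torch.tensor(99)              # token confirmed by verification

    # Drop the first token, append the verified one: same length in, same length out.
    rotated = torch.cat((input_ids[1:], verified_id.reshape(1)))
    print(rotated)  # tensor([11, 12, 13, 99])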
@@ -148,7 +148,7 @@ def lightning_attention_decode_naive(q, k, v, past_kv, slope):
     kv.to(torch.float32),
 )
 output.append(qkv)
-output = torch.concat(output, dim=-2)
+output = torch.cat(output, dim=-2)

 return output.to(original_dtype), kv

@@ -24,7 +24,7 @@ def naive_lightning_attention_decode(q, k, v, past_kv, slope):
     kv.to(torch.float32),
 )
 output.append(qkv)
-output = torch.concat(output, dim=-2)
+output = torch.cat(output, dim=-2)

 return output.to(original_dtype), kv

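Since the change is purely mechanical, a small repo scan can keep the alias from creeping back in. A sketch of such a check (the python/sglang root and the script itself are assumptions, not part of this commit):

    # Hypothetical lint helper: flag torch.concat / torch.concatenate call sites.
    import re
    import sys
    from pathlib import Path

    PATTERN = re.compile(r"\btorch\.concat(?:enate)?\s*\(")

    hits = [
        f"{path}:{lineno}: {line.strip()}"
        for path in Path("python/sglang").rglob("*.py")
        for lineno, line in enumerate(path.read_text().splitlines(), start=1)
        if PATTERN.search(line)
    ]
    print("\n".join(hits))
    sys.exit(1 if hits else 0)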