[Minor] more code cleanup (#4077)
This commit is contained in:
@@ -818,8 +818,8 @@ def all_gather(
|
||||
if world_size == 1:
|
||||
return input_tensor
|
||||
|
||||
all_lens = forward_batch.global_num_tokens
|
||||
max_len = max(forward_batch.global_num_tokens)
|
||||
all_lens = forward_batch.global_num_tokens_cpu
|
||||
max_len = max(forward_batch.global_num_tokens_cpu)
|
||||
|
||||
padded_tensor = torch.nn.functional.pad(
|
||||
input_tensor, (0, 0, 0, max_len - input_tensor.shape[0])
|
||||
|
||||
Reference in New Issue
Block a user