Fix gathered_buffer issues in tbo (#7531)
This commit is contained in:
@@ -346,7 +346,10 @@ class TboForwardBatchPreparer:
|
||||
)
|
||||
|
||||
# TODO improve, e.g. unify w/ `init_raw`
|
||||
if global_server_args_dict["moe_dense_tp_size"] == 1:
|
||||
if (
|
||||
global_server_args_dict["moe_dense_tp_size"] == 1
|
||||
and batch.gathered_buffer is not None
|
||||
):
|
||||
sum_len = end_token_index - start_token_index
|
||||
gathered_buffer = torch.zeros(
|
||||
(sum_len, batch.gathered_buffer.shape[1]),
|
||||
|
||||
Reference in New Issue
Block a user