Fix gathered_buffer issues in tbo (#7531)
This commit is contained in:
@@ -346,7 +346,10 @@ class TboForwardBatchPreparer:
|
|||||||
)
|
)
|
||||||
|
|
||||||
# TODO improve, e.g. unify w/ `init_raw`
|
# TODO improve, e.g. unify w/ `init_raw`
|
||||||
if global_server_args_dict["moe_dense_tp_size"] == 1:
|
if (
|
||||||
|
global_server_args_dict["moe_dense_tp_size"] == 1
|
||||||
|
and batch.gathered_buffer is not None
|
||||||
|
):
|
||||||
sum_len = end_token_index - start_token_index
|
sum_len = end_token_index - start_token_index
|
||||||
gathered_buffer = torch.zeros(
|
gathered_buffer = torch.zeros(
|
||||||
(sum_len, batch.gathered_buffer.shape[1]),
|
(sum_len, batch.gathered_buffer.shape[1]),
|
||||||
|
|||||||
Reference in New Issue
Block a user