Fix gathered_buffer issues in tbo (#7531)

This commit is contained in:
Qiaolin Yu
2025-06-25 14:42:21 -07:00
committed by GitHub
parent a1c1ebe935
commit b8df43ab9c

View File

@@ -346,7 +346,10 @@ class TboForwardBatchPreparer:
)
# TODO improve, e.g. unify w/ `init_raw`
if global_server_args_dict["moe_dense_tp_size"] == 1:
if (
global_server_args_dict["moe_dense_tp_size"] == 1
and batch.gathered_buffer is not None
):
sum_len = end_token_index - start_token_index
gathered_buffer = torch.zeros(
(sum_len, batch.gathered_buffer.shape[1]),