Use custom allreduce w/ torch.compile (#2185)

This commit is contained in:
Lianmin Zheng
2024-11-25 14:55:01 -08:00
committed by GitHub
parent 4d62bca542
commit c4336b2b60

View File

@@ -65,7 +65,8 @@ def patch_model(
_to_torch(model)
monkey_patch_vllm_all_gather()
backup_ca_comm = tp_group.ca_comm
tp_group.ca_comm = None
# Use custom-allreduce here
# tp_group.ca_comm = None
yield torch.compile(
torch.no_grad()(model.forward), mode="max-autotune-no-cudagraphs"
)