Use custom allreduce w/ torch.compile (#2185)
This commit is contained in:
@@ -65,7 +65,8 @@ def patch_model(
|
|||||||
_to_torch(model)
|
_to_torch(model)
|
||||||
monkey_patch_vllm_all_gather()
|
monkey_patch_vllm_all_gather()
|
||||||
backup_ca_comm = tp_group.ca_comm
|
backup_ca_comm = tp_group.ca_comm
|
||||||
tp_group.ca_comm = None
|
# Use custom-allreduce here
|
||||||
|
# tp_group.ca_comm = None
|
||||||
yield torch.compile(
|
yield torch.compile(
|
||||||
torch.no_grad()(model.forward), mode="max-autotune-no-cudagraphs"
|
torch.no_grad()(model.forward), mode="max-autotune-no-cudagraphs"
|
||||||
)
|
)
|
||||||
|
|||||||
Reference in New Issue
Block a user