Fix allgather ops inside cuda graphs (#3709)
@@ -139,6 +139,27 @@ if supports_custom_op():
         fake_impl=outplace_all_reduce_fake,
     )
 
+    def reg_all_gather_into_tensor(
+        output: torch.Tensor, input: torch.Tensor, group_name: str
+    ) -> None:
+        assert group_name in _groups, f"Group {group_name} is not found."
+        group = _groups[group_name]()
+        if group is None:
+            raise ValueError(f"Group {group_name} is destroyed.")
+        group._all_gather_into_tensor(output, input)
+
+    def reg_all_gather_into_tensor_fake(
+        output: torch.Tensor, input: torch.Tensor, group_name: str
+    ) -> None:
+        pass
+
+    direct_register_custom_op(
+        op_name="reg_all_gather_into_tensor",
+        op_func=reg_all_gather_into_tensor,
+        mutates_args=[],
+        fake_impl=reg_all_gather_into_tensor_fake,
+    )
+
 
 class GroupCoordinator:
     """
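The new op is registered with a no-op fake implementation so that tracing machinery (torch.compile, CUDA graph capture through the dispatcher) sees the collective as an opaque call instead of tracing into NCCL. Below is a minimal, self-contained sketch of the same registration pattern using the public torch.library API (PyTorch 2.4+); the demo:: namespace and the fake single-process "collective" are illustrative stand-ins, not sglang's direct_register_custom_op.

import torch

# Illustrative stand-in for a collective: a single-process "all-gather"
# that just tiles the input into a pre-allocated output buffer.
@torch.library.custom_op("demo::reg_all_gather_demo", mutates_args=("output",))
def reg_all_gather_demo(output: torch.Tensor, input: torch.Tensor) -> None:
    world_size = output.shape[0] // input.shape[0]
    output.copy_(input.repeat(world_size, 1))

# The fake impl describes the op to tracing systems without touching NCCL;
# it mirrors reg_all_gather_into_tensor_fake in the hunk above.
@reg_all_gather_demo.register_fake
def _(output: torch.Tensor, input: torch.Tensor) -> None:
    pass

inp = torch.ones(4, 8)
out = torch.empty(8, 8)
torch.ops.demo.reg_all_gather_demo(out, inp)  # dispatches through the op registry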
@@ -414,6 +435,23 @@ class GroupCoordinator:
         else:
             torch.distributed.all_reduce(input_, group=self.device_group)
 
+    def _all_gather_into_tensor(self, output: torch.Tensor, input: torch.Tensor):
+        pynccl_comm = self.pynccl_comm
+        if pynccl_comm is not None and not pynccl_comm.disabled:
+            pynccl_comm.all_gather(output, input)
+        else:
+            torch.distributed.all_gather_into_tensor(
+                output, input, group=self.device_group
+            )
+
+    def all_gather_into_tensor(self, output: torch.Tensor, input: torch.Tensor):
+        if not supports_custom_op():
+            self._all_gather_into_tensor(output, input)
+        else:
+            torch.ops.sglang.reg_all_gather_into_tensor(
+                output, input, group_name=self.unique_name
+            )
+
     def all_gather(self, input_: torch.Tensor, dim: int = -1) -> torch.Tensor:
         world_size = self.world_size
         # Bypass the function if we are using only 1 GPU.
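Routing the gather through a registered op is what lets it live inside a CUDA graph: replay re-issues the recorded kernels against fixed buffer addresses, so the gather must write into a persistent, pre-allocated output. A hedged single-GPU sketch of that capture/replay discipline follows; static_in, static_out, and run_gather are illustrative names, and a plain copy stands in for the collective.

import torch

static_in = torch.randn(8, 16, device="cuda")
static_out = torch.empty(16, 16, device="cuda")  # world_size == 2 assumed

def run_gather():
    # In sglang this would be group.all_gather_into_tensor(static_out, static_in).
    static_out[:8].copy_(static_in)
    static_out[8:].copy_(static_in)

# Warm up on a side stream before capture, per the CUDA graphs guidance.
s = torch.cuda.Stream()
s.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(s):
    run_gather()
torch.cuda.current_stream().wait_stream(s)

g = torch.cuda.CUDAGraph()
with torch.cuda.graph(g):
    run_gather()

static_in.copy_(torch.ones_like(static_in))
g.replay()  # re-runs the captured kernels against the same static buffers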
@@ -441,9 +479,7 @@ class GroupCoordinator:
             output_size, dtype=input_.dtype, device=input_.device
         )
         # All-gather.
-        torch.distributed.all_gather_into_tensor(
-            output_tensor, input_, group=self.device_group
-        )
+        self.all_gather_into_tensor(output_tensor, input_)
         # Reshape
         output_tensor = output_tensor.reshape((world_size,) + input_size)
         output_tensor = output_tensor.movedim(0, dim)
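The reshape/movedim in this hunk is what places each rank's chunk at the requested dim of the gathered result. A single-process sketch of that layout math, with no process group involved; chunks stands in for the per-rank inputs.

import torch

world_size = 2
dim = 1  # negative dims are normalized before this point
chunks = [torch.full((3, 4), float(rank)) for rank in range(world_size)]
input_size = chunks[0].shape

# all_gather_into_tensor fills one flat buffer with every rank's data in order.
output_tensor = torch.cat([c.reshape(-1) for c in chunks])

output_tensor = output_tensor.reshape((world_size,) + tuple(input_size))
output_tensor = output_tensor.movedim(0, dim)
print(output_tensor.shape)                           # torch.Size([3, 2, 4])
print(torch.equal(output_tensor[:, 0], chunks[0]))   # True: rank 0's slice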
@@ -824,9 +824,7 @@ def all_gather(
         input_tensor, (0, 0, 0, max_len - input_tensor.shape[0])
     )
 
-    torch.distributed.all_gather_into_tensor(
-        forward_batch.gathered_buffer, padded_tensor, group=group
-    )
+    group.all_gather_into_tensor(forward_batch.gathered_buffer, padded_tensor)
 
     gathered_tensors = torch.concat(
         [
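Before the gather, each rank pads its hidden states up to the longest sequence in the batch so every contribution has the same static shape and fits the pre-allocated gathered_buffer. A small sketch of that padding step; the token counts and hidden size are made up.

import torch
import torch.nn.functional as F

global_num_tokens = [5, 3]    # hypothetical per-rank token counts
max_len = max(global_num_tokens)
hidden = torch.randn(3, 16)   # this rank holds 3 tokens, hidden size 16

# Pads rows (dim 0) up to max_len; the leading (0, 0) leaves dim 1 untouched.
padded = F.pad(hidden, (0, 0, 0, max_len - hidden.shape[0]))
print(padded.shape)           # torch.Size([5, 16]); the last 2 rows are zeros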
@@ -862,7 +860,7 @@ class DeepseekV2DecoderLayer(nn.Module):
         if self.enable_dp_attention:
             self.tp_rank = get_tensor_model_parallel_rank()
             self.tp_size = get_tensor_model_parallel_world_size()
-            self.tp_group = get_tp_group().device_group
+            self.tp_group = get_tp_group()
         if not global_server_args_dict["disable_mla"]:
             self.self_attn = DeepseekV2AttentionMLA(
                 config=config,