Online weight updates from torch.distributed (#2279)
This commit is contained in:
@@ -365,6 +365,41 @@ class UpdateWeightFromDiskReqOutput:
|
||||
message: str
|
||||
|
||||
|
||||
@dataclass
class UpdateWeightsFromDistributedReqInput:
    """Request payload describing one weight to be updated via a distributed group.

    The payload carries only metadata (name, dtype, shape); no tensor data is
    stored on this object.
    """

    # Identifier of the weight to update.
    # NOTE(review): presumably the model parameter name — confirm against the caller.
    name: str
    # Data type of the weight, given as a string.
    # NOTE(review): the accepted spellings (e.g. "bfloat16") are not visible here — confirm.
    dtype: str
    # Shape of the weight as a list of integer dimensions.
    shape: List[int]
|
||||
|
||||
|
||||
@dataclass
class UpdateWeightsFromDistributedReqOutput:
    """Result of a distributed weight-update request."""

    # Whether the update succeeded.
    success: bool
    # Human-readable status or error text accompanying the result.
    message: str
|
||||
|
||||
|
||||
@dataclass
class InitWeightsUpdateGroupReqInput:
    """Request payload for initializing the process group used for weight updates.

    The first four fields are required; ``group_name`` and ``backend`` have
    defaults.
    """

    # The master address.
    # NOTE(review): presumably the rendezvous host for the process group — confirm.
    master_address: str
    # The master port.
    master_port: int
    # The rank offset.
    # NOTE(review): how the offset is combined with local ranks is not visible here — confirm.
    rank_offset: int
    # The world size (total number of participants in the group).
    world_size: int
    # The group name; defaults to "weight_update_group".
    group_name: str = "weight_update_group"
    # The communication backend; defaults to "nccl".
    backend: str = "nccl"
|
||||
|
||||
|
||||
@dataclass
class InitWeightsUpdateGroupReqOutput:
    """Result of a weight-update group initialization request."""

    # Whether the group initialization succeeded.
    success: bool
    # Human-readable status or error text accompanying the result.
    message: str
|
||||
|
||||
|
||||
@dataclass
|
||||
class GetWeightsByNameReqInput:
|
||||
name: str
|
||||
|
||||
Reference in New Issue
Block a user