Resolves the 404 Not Found error when running compile_deep_gemm.py in multi-node setups (#5720)
This commit is contained in:
@@ -88,8 +88,36 @@ def launch_server_process_and_send_one_request(
|
||||
headers = {
|
||||
"Content-Type": "application/json; charset=utf-8",
|
||||
}
|
||||
response = requests.get(f"{base_url}/v1/models", headers=headers)
|
||||
if server_args.node_rank == 0:
|
||||
response = requests.get(f"{base_url}/v1/models", headers=headers)
|
||||
else:
|
||||
# This http api is created by launch_dummy_health_check_server for none-rank0 node.
|
||||
response = requests.get(f"{base_url}/health", headers=headers)
|
||||
if response.status_code == 200:
|
||||
# Rank-0 node send a request to sync with other node and then return.
|
||||
if server_args.node_rank == 0:
|
||||
response = requests.post(
|
||||
f"{base_url}/generate",
|
||||
json={
|
||||
"input_ids": [0, 1, 2, 3],
|
||||
"sampling_params": {
|
||||
"max_new_tokens": 8,
|
||||
"temperature": 0,
|
||||
},
|
||||
},
|
||||
timeout=600,
|
||||
)
|
||||
if response.status_code != 200:
|
||||
error = response.json()
|
||||
raise RuntimeError(f"Sync request failed: {error}")
|
||||
# Other nodes should wait for the exit signal from Rank-0 node.
|
||||
else:
|
||||
start_time_waiting = time.time()
|
||||
while proc.is_alive():
|
||||
if time.time() - start_time_waiting < timeout:
|
||||
time.sleep(10)
|
||||
else:
|
||||
raise TimeoutError("Waiting for main node timeout!")
|
||||
return proc
|
||||
except requests.RequestException:
|
||||
pass
|
||||
@@ -122,10 +150,19 @@ def run_compile(server_args: ServerArgs, compile_args: CompileArgs):
|
||||
|
||||
proc = launch_server_process_and_send_one_request(server_args, compile_args)
|
||||
|
||||
kill_process_tree(proc.pid)
|
||||
|
||||
print("\nDeepGEMM Kernels compilation finished successfully.")
|
||||
|
||||
# Sleep for safety
|
||||
time.sleep(10)
|
||||
if proc.is_alive():
|
||||
# This is the rank0 node.
|
||||
kill_process_tree(proc.pid)
|
||||
else:
|
||||
try:
|
||||
kill_process_tree(proc.pid)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
|
||||
Reference in New Issue
Block a user