Resolves the 404 Not Found error when running compile_deep_gemm.py in multi-node setups (#5720)
This commit is contained in:
@@ -88,8 +88,36 @@ def launch_server_process_and_send_one_request(
|
|||||||
headers = {
|
headers = {
|
||||||
"Content-Type": "application/json; charset=utf-8",
|
"Content-Type": "application/json; charset=utf-8",
|
||||||
}
|
}
|
||||||
response = requests.get(f"{base_url}/v1/models", headers=headers)
|
if server_args.node_rank == 0:
|
||||||
|
response = requests.get(f"{base_url}/v1/models", headers=headers)
|
||||||
|
else:
|
||||||
|
# This HTTP API is created by launch_dummy_health_check_server for non-rank-0 nodes.
|
||||||
|
response = requests.get(f"{base_url}/health", headers=headers)
|
||||||
if response.status_code == 200:
|
if response.status_code == 200:
|
||||||
|
# The rank-0 node sends a request to sync with the other nodes and then returns.
|
||||||
|
if server_args.node_rank == 0:
|
||||||
|
response = requests.post(
|
||||||
|
f"{base_url}/generate",
|
||||||
|
json={
|
||||||
|
"input_ids": [0, 1, 2, 3],
|
||||||
|
"sampling_params": {
|
||||||
|
"max_new_tokens": 8,
|
||||||
|
"temperature": 0,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
timeout=600,
|
||||||
|
)
|
||||||
|
if response.status_code != 200:
|
||||||
|
error = response.json()
|
||||||
|
raise RuntimeError(f"Sync request failed: {error}")
|
||||||
|
# Other nodes should wait for the exit signal from the rank-0 node.
|
||||||
|
else:
|
||||||
|
start_time_waiting = time.time()
|
||||||
|
while proc.is_alive():
|
||||||
|
if time.time() - start_time_waiting < timeout:
|
||||||
|
time.sleep(10)
|
||||||
|
else:
|
||||||
|
raise TimeoutError("Waiting for main node timeout!")
|
||||||
return proc
|
return proc
|
||||||
except requests.RequestException:
|
except requests.RequestException:
|
||||||
pass
|
pass
|
||||||
@@ -122,10 +150,19 @@ def run_compile(server_args: ServerArgs, compile_args: CompileArgs):
|
|||||||
|
|
||||||
proc = launch_server_process_and_send_one_request(server_args, compile_args)
|
proc = launch_server_process_and_send_one_request(server_args, compile_args)
|
||||||
|
|
||||||
kill_process_tree(proc.pid)
|
|
||||||
|
|
||||||
print("\nDeepGEMM Kernels compilation finished successfully.")
|
print("\nDeepGEMM Kernels compilation finished successfully.")
|
||||||
|
|
||||||
|
# Sleep for safety
|
||||||
|
time.sleep(10)
|
||||||
|
if proc.is_alive():
|
||||||
|
# This is the rank0 node.
|
||||||
|
kill_process_tree(proc.pid)
|
||||||
|
else:
|
||||||
|
try:
|
||||||
|
kill_process_tree(proc.pid)
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
parser = argparse.ArgumentParser()
|
parser = argparse.ArgumentParser()
|
||||||
|
|||||||
Reference in New Issue
Block a user