[Minor] Fix code style (#2311)
This commit is contained in:
@@ -67,7 +67,7 @@ class TestGetWeightsByName(unittest.TestCase):
|
||||
terminate_process(self.process)
|
||||
|
||||
def assert_tie_word_embeddings(self, truncate_size):
|
||||
print(f"assert_tie_word_embeddings")
|
||||
print("assert_tie_word_embeddings")
|
||||
if self.backend == "Engine":
|
||||
backend_ret = _process_return(
|
||||
self.engine.get_weights_by_name("lm_head.weight", truncate_size)
|
||||
@@ -79,7 +79,7 @@ class TestGetWeightsByName(unittest.TestCase):
|
||||
json={"name": "lm_head.weight", "truncate_size": truncate_size},
|
||||
).json()
|
||||
)
|
||||
print(f"assert_tie_word_embeddings of hf and backend")
|
||||
print("assert_tie_word_embeddings of hf and backend")
|
||||
assert np.allclose(
|
||||
self.hf_model.get_parameter("model.embed_tokens.weight")
|
||||
.cpu()
|
||||
|
||||
@@ -127,7 +127,7 @@ def init_process_hf(
|
||||
hf_instruct_params = []
|
||||
hf_base_params = []
|
||||
|
||||
print(f"get parameter in hf instruct model and base model")
|
||||
print("get parameter in hf instruct model and base model")
|
||||
for parameter_name in checking_parameters:
|
||||
hf_instruct_params.append(
|
||||
hf_instruct_model.get_parameter(parameter_name)[:truncate_size]
|
||||
@@ -186,7 +186,6 @@ def init_process_hf(
|
||||
param_queue.put(("broadcast_time", broadcast_time))
|
||||
|
||||
# Delete the huggingface models to free up memory.
|
||||
|
||||
del hf_instruct_model
|
||||
del hf_base_model
|
||||
gc.collect()
|
||||
@@ -238,7 +237,6 @@ def init_process_sgl(
|
||||
print(f"rank {rank} init server on url: {url}")
|
||||
|
||||
# Get weights of instruct model, i.e. pre-training weights.
|
||||
|
||||
instruct_params = []
|
||||
for parameter_name in checking_parameters:
|
||||
instruct_params.append(
|
||||
@@ -253,7 +251,6 @@ def init_process_sgl(
|
||||
param_queue.put((f"sgl_dp_{rank}_instruct_params", instruct_params))
|
||||
|
||||
# Init weight update group with the training engine.
|
||||
|
||||
if backend == "Engine":
|
||||
engine.init_weights_update_group(
|
||||
master_address="localhost",
|
||||
@@ -282,7 +279,6 @@ def init_process_sgl(
|
||||
# The last parameter is lm_head.weight, which is tied
|
||||
# with embed_tokens.weight. Actually, we only need
|
||||
# to update embed_tokens.weight once.
|
||||
|
||||
tie_word_embeddings = (
|
||||
True if model_name == DEFAULT_SMALL_MODEL_NAME_FOR_TEST else False
|
||||
)
|
||||
@@ -291,7 +287,6 @@ def init_process_sgl(
|
||||
update_parameters.remove("lm_head.weight")
|
||||
|
||||
# Get weights from the training engine and update the inference engine.
|
||||
|
||||
for parameter_name in update_parameters:
|
||||
if backend == "Engine":
|
||||
engine.update_weights_from_distributed(
|
||||
@@ -312,7 +307,6 @@ def init_process_sgl(
|
||||
time_end_update = time.time()
|
||||
|
||||
# Measure the latency of broadcast/weights update.
|
||||
|
||||
update_time = time_end_update - time_begin_update
|
||||
print(
|
||||
f"fully update model_name {model_name} rank {rank} parameter from distributed time: {update_time:.3f}s"
|
||||
@@ -320,7 +314,6 @@ def init_process_sgl(
|
||||
param_queue.put((f"update_sgl_dp_{rank}_time", update_time))
|
||||
|
||||
# Get the weights of post-training model after weights update for correctness check.
|
||||
|
||||
base_params = []
|
||||
for parameter_name in checking_parameters:
|
||||
if backend == "Engine":
|
||||
@@ -340,7 +333,6 @@ def init_process_sgl(
|
||||
param_queue.put((f"sgl_dp_{rank}_base_params", base_params))
|
||||
|
||||
# Shutdown the engine or terminate the server process.
|
||||
|
||||
if backend == "Engine":
|
||||
engine.shutdown()
|
||||
else:
|
||||
@@ -426,7 +418,6 @@ def test_update_weights_from_distributed(
|
||||
|
||||
# Check the correctness of weights update by verifying
|
||||
# the weights of instruct model and base model.
|
||||
|
||||
for i in range(len(params["hf_instruct"])):
|
||||
verify_params_close(
|
||||
params["hf_instruct"][i],
|
||||
@@ -463,7 +454,6 @@ def test_update_weights_from_distributed(
|
||||
), "hf_instruct_params and hf_base_params have different lengths"
|
||||
|
||||
# Check if the weights of lm_head are tied with embed_tokens.
|
||||
|
||||
params_to_check = [
|
||||
(
|
||||
params["hf_instruct"],
|
||||
@@ -509,7 +499,6 @@ def test_update_weights_from_distributed(
|
||||
|
||||
# Time limit for broadcast and update on CI is 3 / 6
|
||||
# On local H100, it's 1 / 2
|
||||
|
||||
time_limit = 3 if model_name == DEFAULT_SMALL_MODEL_NAME_FOR_TEST else 6
|
||||
|
||||
assert (
|
||||
@@ -526,7 +515,6 @@ def test_update_weights_from_distributed(
|
||||
), f"update_sgl_dp_two_time exceeds time limit {time_limit}s"
|
||||
|
||||
# Delete the context and close the parameter queue.
|
||||
|
||||
del context
|
||||
param_queue.close()
|
||||
param_queue.join_thread()
|
||||
|
||||
Reference in New Issue
Block a user