[Minor] Fix code style (#2311)

This commit is contained in:
Lianmin Zheng
2024-12-02 02:27:36 -08:00
committed by GitHub
parent c54bda300a
commit 18108abe5d
5 changed files with 292 additions and 317 deletions

View File

@@ -67,7 +67,7 @@ class TestGetWeightsByName(unittest.TestCase):
terminate_process(self.process)
def assert_tie_word_embeddings(self, truncate_size):
print(f"assert_tie_word_embeddings")
print("assert_tie_word_embeddings")
if self.backend == "Engine":
backend_ret = _process_return(
self.engine.get_weights_by_name("lm_head.weight", truncate_size)
@@ -79,7 +79,7 @@ class TestGetWeightsByName(unittest.TestCase):
json={"name": "lm_head.weight", "truncate_size": truncate_size},
).json()
)
print(f"assert_tie_word_embeddings of hf and backend")
print("assert_tie_word_embeddings of hf and backend")
assert np.allclose(
self.hf_model.get_parameter("model.embed_tokens.weight")
.cpu()

View File

@@ -127,7 +127,7 @@ def init_process_hf(
hf_instruct_params = []
hf_base_params = []
print(f"get parameter in hf instruct model and base model")
print("get parameter in hf instruct model and base model")
for parameter_name in checking_parameters:
hf_instruct_params.append(
hf_instruct_model.get_parameter(parameter_name)[:truncate_size]
@@ -186,7 +186,6 @@ def init_process_hf(
param_queue.put(("broadcast_time", broadcast_time))
# Delete the huggingface models to free up memory.
del hf_instruct_model
del hf_base_model
gc.collect()
@@ -238,7 +237,6 @@ def init_process_sgl(
print(f"rank {rank} init server on url: {url}")
# Get weights of instruct model, i.e. pre-training weights.
instruct_params = []
for parameter_name in checking_parameters:
instruct_params.append(
@@ -253,7 +251,6 @@ def init_process_sgl(
param_queue.put((f"sgl_dp_{rank}_instruct_params", instruct_params))
# Init weight update group with the training engine.
if backend == "Engine":
engine.init_weights_update_group(
master_address="localhost",
@@ -282,7 +279,6 @@ def init_process_sgl(
# The last parameter is lm_head.weight, which is tied
# with embed_tokens.weight. Actually, we only need
# to update embed_tokens.weight once.
tie_word_embeddings = (
True if model_name == DEFAULT_SMALL_MODEL_NAME_FOR_TEST else False
)
@@ -291,7 +287,6 @@ def init_process_sgl(
update_parameters.remove("lm_head.weight")
# Get weights from the training engine and update the inference engine.
for parameter_name in update_parameters:
if backend == "Engine":
engine.update_weights_from_distributed(
@@ -312,7 +307,6 @@ def init_process_sgl(
time_end_update = time.time()
# Measure the latency of broadcast/weights update.
update_time = time_end_update - time_begin_update
print(
f"fully update model_name {model_name} rank {rank} parameter from distributed time: {update_time:.3f}s"
@@ -320,7 +314,6 @@ def init_process_sgl(
param_queue.put((f"update_sgl_dp_{rank}_time", update_time))
# Get the weights of post-training model after weights update for correctness check.
base_params = []
for parameter_name in checking_parameters:
if backend == "Engine":
@@ -340,7 +333,6 @@ def init_process_sgl(
param_queue.put((f"sgl_dp_{rank}_base_params", base_params))
# Shutdown the engine or terminate the server process.
if backend == "Engine":
engine.shutdown()
else:
@@ -426,7 +418,6 @@ def test_update_weights_from_distributed(
# Check the correctness of weights update by verifying
# the weights of instruct model and base model.
for i in range(len(params["hf_instruct"])):
verify_params_close(
params["hf_instruct"][i],
@@ -463,7 +454,6 @@ def test_update_weights_from_distributed(
), "hf_instruct_params and hf_base_params have different lengths"
# Check if the weights of lm_head are tied with embed_tokens.
params_to_check = [
(
params["hf_instruct"],
@@ -509,7 +499,6 @@ def test_update_weights_from_distributed(
# Time limit for broadcast and update on CI is 3 / 6
# On local H100, it's 1 / 2
time_limit = 3 if model_name == DEFAULT_SMALL_MODEL_NAME_FOR_TEST else 6
assert (
@@ -526,7 +515,6 @@ def test_update_weights_from_distributed(
), f"update_sgl_dp_two_time exceeds time limit {time_limit}s"
# Delete the context and close the parameter queue.
del context
param_queue.close()
param_queue.join_thread()