diff --git a/python/sglang/srt/managers/tokenizer_manager.py b/python/sglang/srt/managers/tokenizer_manager.py index 3ba5f210b..d1b5fa37a 100644 --- a/python/sglang/srt/managers/tokenizer_manager.py +++ b/python/sglang/srt/managers/tokenizer_manager.py @@ -25,7 +25,6 @@ import uuid from typing import Dict, List, Optional, Tuple, Union import fastapi -import torch import uvloop import zmq import zmq.asyncio @@ -337,6 +336,12 @@ class TokenizerManager: rids.append(tmp_obj.rid) else: # FIXME: When using batch and parallel_sample_num together, the perf is not optimal. + if batch_size > 128: + logger.warning( + "Sending a single large batch with parallel sampling (n > 1) has not been well optimized. " + "The performance might be better if you just duplicate the requests n times or use " + "many threads to send them one by one with parallel sampling (n > 1)." + ) # Tokenize all requests objs = [obj[i] for i in range(batch_size)] @@ -494,9 +499,7 @@ class TokenizerManager: result = await self.parameter_update_result return result.success, result.message else: - logger.error( - f"Another parameter update is in progress in tokenizer manager" - ) + logger.error("Another parameter update is in progress in tokenizer manager") return ( False, "Another parameter update is in progress. Please try again later.", @@ -597,7 +600,68 @@ class TokenizerManager: InitWeightsUpdateGroupReqOutput, ] = await self.recv_from_detokenizer.recv_pyobj() - if isinstance(recv_obj, UpdateWeightFromDiskReqOutput): + if isinstance(recv_obj, (BatchStrOut, BatchEmbeddingOut, BatchTokenIDOut)): + for i, rid in enumerate(recv_obj.rids): + state = self.rid_to_state.get(rid, None) + if state is None: + continue + + recv_obj.meta_info[i]["id"] = rid + if isinstance(recv_obj, BatchStrOut): + out_dict = { + "text": recv_obj.output_strs[i], + "meta_info": recv_obj.meta_info[i], + } + elif isinstance(recv_obj, BatchTokenIDOut): + out_dict = { + "token_ids": recv_obj.output_ids[i], + "meta_info": recv_obj.meta_info[i], + } + else: + assert isinstance(recv_obj, BatchEmbeddingOut) + out_dict = { + "embedding": recv_obj.embeddings[i], + "meta_info": recv_obj.meta_info[i], + } + state.out_list.append(out_dict) + state.finished = recv_obj.finished_reason[i] is not None + state.event.set() + + if self.enable_metrics: + completion_tokens = recv_obj.meta_info[i]["completion_tokens"] + + if state.first_token_time is None: + state.first_token_time = time.time() + self.metrics_collector.observe_time_to_first_token( + state.first_token_time - state.created_time + ) + else: + if completion_tokens >= 2: + self.metrics_collector.observe_time_per_output_token( + (time.time() - state.first_token_time) + / (completion_tokens - 1) + ) + + if state.finished: + self.metrics_collector.inc_prompt_tokens( + recv_obj.meta_info[i]["prompt_tokens"] + ) + self.metrics_collector.inc_generation_tokens( + completion_tokens + ) + self.metrics_collector.observe_e2e_request_latency( + time.time() - state.created_time + ) + if completion_tokens >= 1: + self.metrics_collector.observe_time_per_output_token( + (time.time() - state.created_time) + / completion_tokens + ) + elif isinstance(recv_obj, OpenSessionReqOutput): + self.session_futures[recv_obj.session_id].set_result( + recv_obj.session_id + ) + elif isinstance(recv_obj, UpdateWeightFromDiskReqOutput): if self.server_args.dp_size == 1: self.model_update_result.set_result(recv_obj) else: # self.server_args.dp_size > 1 @@ -605,13 +669,16 @@ class TokenizerManager: # set future if the all results are recevied if len(self.model_update_tmp) == self.server_args.dp_size: self.model_update_result.set_result(self.model_update_tmp) - continue + elif isinstance(recv_obj, InitWeightsUpdateGroupReqOutput): + assert ( + self.server_args.dp_size == 1 + ), "dp_size must be 1 for init parameter update group" + self.init_weights_update_group_result.set_result(recv_obj) elif isinstance(recv_obj, UpdateWeightsFromDistributedReqOutput): assert ( self.server_args.dp_size == 1 ), "dp_size must be 1 for update weights from distributed" self.parameter_update_result.set_result(recv_obj) - continue elif isinstance(recv_obj, GetWeightsByNameReqOutput): if self.server_args.dp_size == 1: self.get_weights_by_name_result.set_result(recv_obj) @@ -621,76 +688,8 @@ class TokenizerManager: self.get_weights_by_name_result.set_result( self.get_weights_by_name_tmp ) - continue - elif isinstance(recv_obj, InitWeightsUpdateGroupReqOutput): - assert ( - self.server_args.dp_size == 1 - ), "dp_size must be 1 for init parameter update group" - self.init_weights_update_group_result.set_result(recv_obj) - continue - elif isinstance(recv_obj, OpenSessionReqOutput): - self.session_futures[recv_obj.session_id].set_result( - recv_obj.session_id - ) - continue - - assert isinstance( - recv_obj, (BatchStrOut, BatchEmbeddingOut, BatchTokenIDOut) - ), f"Unexpected obj received: {type(recv_obj)}" - - for i, rid in enumerate(recv_obj.rids): - state = self.rid_to_state.get(rid, None) - if state is None: - continue - - recv_obj.meta_info[i]["id"] = rid - if isinstance(recv_obj, BatchStrOut): - out_dict = { - "text": recv_obj.output_strs[i], - "meta_info": recv_obj.meta_info[i], - } - elif isinstance(recv_obj, BatchTokenIDOut): - out_dict = { - "token_ids": recv_obj.output_ids[i], - "meta_info": recv_obj.meta_info[i], - } - else: - assert isinstance(recv_obj, BatchEmbeddingOut) - out_dict = { - "embedding": recv_obj.embeddings[i], - "meta_info": recv_obj.meta_info[i], - } - state.out_list.append(out_dict) - state.finished = recv_obj.finished_reason[i] is not None - state.event.set() - - if self.enable_metrics: - completion_tokens = recv_obj.meta_info[i]["completion_tokens"] - - if state.first_token_time is None: - state.first_token_time = time.time() - self.metrics_collector.observe_time_to_first_token( - state.first_token_time - state.created_time - ) - else: - if completion_tokens >= 2: - self.metrics_collector.observe_time_per_output_token( - (time.time() - state.first_token_time) - / (completion_tokens - 1) - ) - - if state.finished: - self.metrics_collector.inc_prompt_tokens( - recv_obj.meta_info[i]["prompt_tokens"] - ) - self.metrics_collector.inc_generation_tokens(completion_tokens) - self.metrics_collector.observe_e2e_request_latency( - time.time() - state.created_time - ) - if completion_tokens >= 1: - self.metrics_collector.observe_time_per_output_token( - (time.time() - state.created_time) / completion_tokens - ) + else: + raise ValueError(f"Invalid object: {recv_obj=}") def convert_logprob_style( self, diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py index 5c4f5c81b..3fffa2047 100644 --- a/python/sglang/srt/model_executor/model_runner.py +++ b/python/sglang/srt/model_executor/model_runner.py @@ -218,16 +218,6 @@ class ModelRunner: ) self.tp_group = get_tp_group() - # Currently, there is a bug with mulit-node tensor parallelsim + padded cuda graph, - # so we disable padding in cuda graph. - if self.device == "cuda" and not all( - in_the_same_node_as(self.tp_group.cpu_group, source_rank=0) - ): - self.server_args.disable_cuda_graph_padding = True - logger.info( - "Setting disable_cuda_graph_padding to True because of multi-node tensor parallelism." - ) - # Check memory for tensor parallelism if self.tp_size > 1: local_gpu_memory = get_available_gpu_memory(self.device, self.gpu_id) diff --git a/python/sglang/srt/server.py b/python/sglang/srt/server.py index a750d90e2..fc8ac150b 100644 --- a/python/sglang/srt/server.py +++ b/python/sglang/srt/server.py @@ -82,7 +82,6 @@ from sglang.srt.utils import ( assert_pkg_version, configure_logger, delete_directory, - init_custom_process_group, is_port_available, kill_process_tree, maybe_set_triton_cache_manager, @@ -154,13 +153,11 @@ async def get_model_info(): @app.get("/get_server_info") async def get_server_info(): - try: - return await _get_server_info() - - except Exception as e: - return ORJSONResponse( - {"error": {"message": str(e)}}, status_code=HTTPStatus.BAD_REQUEST - ) + return { + **dataclasses.asdict(tokenizer_manager.server_args), # server args + **scheduler_info, + "version": __version__, + } @app.post("/flush_cache") @@ -567,14 +564,6 @@ def launch_server( t.join() -async def _get_server_info(): - return { - **dataclasses.asdict(tokenizer_manager.server_args), # server args - **scheduler_info, - "version": __version__, - } - - def _set_envs_and_config(server_args: ServerArgs): # Set global environments os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3" @@ -687,11 +676,218 @@ def _wait_and_warmup(server_args, pipe_finish_writer): delete_directory(server_args.model_path) +STREAM_END_SYMBOL = b"data: [DONE]" +STREAM_CHUNK_START_SYMBOL = b"data:" + + +class Engine: + """ + SRT Engine without an HTTP server layer. + + This class provides a direct inference engine without the need for an HTTP server. It is designed for use cases where + launching the HTTP server adds unnecessary complexity or overhead, + """ + + def __init__(self, log_level: str = "error", *args, **kwargs): + """See the arguments in server_args.py::ServerArgs""" + + # before python program terminates, call shutdown implicitly. Therefore, users don't have to explicitly call .shutdown() + atexit.register(self.shutdown) + + server_args = ServerArgs(*args, log_level=log_level, **kwargs) + launch_engine(server_args=server_args) + + def generate( + self, + # The input prompt. It can be a single prompt or a batch of prompts. + prompt: Optional[Union[List[str], str]] = None, + sampling_params: Optional[Union[List[Dict], Dict]] = None, + # The token ids for text; one can either specify text or input_ids. + input_ids: Optional[Union[List[List[int]], List[int]]] = None, + return_logprob: Optional[Union[List[bool], bool]] = False, + logprob_start_len: Optional[Union[List[int], int]] = None, + top_logprobs_num: Optional[Union[List[int], int]] = None, + lora_path: Optional[List[Optional[str]]] = None, + stream: bool = False, + ): + obj = GenerateReqInput( + text=prompt, + input_ids=input_ids, + sampling_params=sampling_params, + return_logprob=return_logprob, + logprob_start_len=logprob_start_len, + top_logprobs_num=top_logprobs_num, + lora_path=lora_path, + stream=stream, + ) + + # get the current event loop + loop = asyncio.get_event_loop() + ret = loop.run_until_complete(generate_request(obj, None)) + + if stream is True: + + def generator_wrapper(): + offset = 0 + loop = asyncio.get_event_loop() + generator = ret.body_iterator + while True: + chunk = loop.run_until_complete(generator.__anext__()) + + if chunk.startswith(STREAM_END_SYMBOL): + break + else: + data = json.loads(chunk[len(STREAM_CHUNK_START_SYMBOL) :]) + data["text"] = data["text"][offset:] + offset += len(data["text"]) + yield data + + # we cannot yield in the scope of generate() because python does not allow yield + return in the same function + # however, it allows to wrap the generator as a subfunction and return + return generator_wrapper() + else: + return ret + + async def async_generate( + self, + # The input prompt. It can be a single prompt or a batch of prompts. + prompt: Optional[Union[List[str], str]] = None, + sampling_params: Optional[Dict] = None, + # The token ids for text; one can either specify text or input_ids. + input_ids: Optional[Union[List[List[int]], List[int]]] = None, + return_logprob: Optional[Union[List[bool], bool]] = False, + logprob_start_len: Optional[Union[List[int], int]] = None, + top_logprobs_num: Optional[Union[List[int], int]] = None, + lora_path: Optional[List[Optional[str]]] = None, + stream: bool = False, + ): + obj = GenerateReqInput( + text=prompt, + input_ids=input_ids, + sampling_params=sampling_params, + return_logprob=return_logprob, + logprob_start_len=logprob_start_len, + top_logprobs_num=top_logprobs_num, + lora_path=lora_path, + stream=stream, + ) + + ret = await generate_request(obj, None) + + if stream is True: + generator = ret.body_iterator + + async def generator_wrapper(): + + offset = 0 + + while True: + chunk = await generator.__anext__() + + if chunk.startswith(STREAM_END_SYMBOL): + break + else: + data = json.loads(chunk[len(STREAM_CHUNK_START_SYMBOL) :]) + data["text"] = data["text"][offset:] + offset += len(data["text"]) + yield data + + return generator_wrapper() + else: + return ret + + def shutdown(self): + kill_process_tree(os.getpid(), include_parent=False) + + def get_tokenizer(self): + global tokenizer_manager + + if tokenizer_manager is None: + raise ReferenceError("Tokenizer Manager is not initialized.") + else: + return tokenizer_manager.tokenizer + + def encode( + self, + prompt: Union[str, List[str], List[Dict], List[List[Dict]]], + ): + obj = EmbeddingReqInput(text=prompt) + + # get the current event loop + loop = asyncio.get_event_loop() + return loop.run_until_complete(encode_request(obj, None)) + + def start_profile(self): + tokenizer_manager.start_profile() + + def stop_profile(self): + tokenizer_manager.stop_profile() + + def get_server_info(self): + return { + **dataclasses.asdict(tokenizer_manager.server_args), # server args + **scheduler_info, + "version": __version__, + } + + def init_weights_update_group( + self, + master_address: str, + master_port: int, + rank_offset: int, + world_size: int, + group_name: str, + backend: str = "nccl", + ): + """Initialize parameter update group.""" + obj = InitWeightsUpdateGroupReqInput( + master_address=master_address, + master_port=master_port, + rank_offset=rank_offset, + world_size=world_size, + group_name=group_name, + backend=backend, + ) + + async def _init_group(): + return await tokenizer_manager.init_weights_update_group(obj, None) + + loop = asyncio.get_event_loop() + return loop.run_until_complete(_init_group()) + + def update_weights_from_distributed(self, name, dtype, shape): + """Update weights from distributed source.""" + obj = UpdateWeightsFromDistributedReqInput( + name=name, + dtype=dtype, + shape=shape, + ) + + async def _update_weights(): + return await tokenizer_manager.update_weights_from_distributed(obj, None) + + loop = asyncio.get_event_loop() + return loop.run_until_complete(_update_weights()) + + def get_weights_by_name(self, name, truncate_size=100): + """Get weights by parameter name.""" + obj = GetWeightsByNameReqInput(name=name, truncate_size=truncate_size) + + async def _get_weights(): + return await tokenizer_manager.get_weights_by_name(obj, None) + + loop = asyncio.get_event_loop() + return loop.run_until_complete(_get_weights()) + + class Runtime: """ - A wrapper for the server. + A wrapper for the HTTP server. This is used for launching the server in a python program without using the commond line interface. + + It is mainly used for the frontend language. + You should use the Engine class if you want to do normal offline processing. """ def __init__( @@ -839,201 +1035,3 @@ class Runtime: def __del__(self): self.shutdown() - - -STREAM_END_SYMBOL = b"data: [DONE]" -STREAM_CHUNK_START_SYMBOL = b"data:" - - -class Engine: - """ - SRT Engine without an HTTP server layer. - - This class provides a direct inference engine without the need for an HTTP server. It is designed for use cases where - launching the HTTP server adds unnecessary complexity or overhead, - """ - - def __init__(self, log_level: str = "error", *args, **kwargs): - # before python program terminates, call shutdown implicitly. Therefore, users don't have to explicitly call .shutdown() - atexit.register(self.shutdown) - - server_args = ServerArgs(*args, log_level=log_level, **kwargs) - launch_engine(server_args=server_args) - - def generate( - self, - # The input prompt. It can be a single prompt or a batch of prompts. - prompt: Optional[Union[List[str], str]] = None, - sampling_params: Optional[Union[List[Dict], Dict]] = None, - # The token ids for text; one can either specify text or input_ids. - input_ids: Optional[Union[List[List[int]], List[int]]] = None, - return_logprob: Optional[Union[List[bool], bool]] = False, - logprob_start_len: Optional[Union[List[int], int]] = None, - top_logprobs_num: Optional[Union[List[int], int]] = None, - lora_path: Optional[List[Optional[str]]] = None, - stream: bool = False, - ): - obj = GenerateReqInput( - text=prompt, - input_ids=input_ids, - sampling_params=sampling_params, - return_logprob=return_logprob, - logprob_start_len=logprob_start_len, - top_logprobs_num=top_logprobs_num, - lora_path=lora_path, - stream=stream, - ) - - # get the current event loop - loop = asyncio.get_event_loop() - ret = loop.run_until_complete(generate_request(obj, None)) - - if stream is True: - - def generator_wrapper(): - offset = 0 - loop = asyncio.get_event_loop() - generator = ret.body_iterator - while True: - chunk = loop.run_until_complete(generator.__anext__()) - - if chunk.startswith(STREAM_END_SYMBOL): - break - else: - data = json.loads(chunk[len(STREAM_CHUNK_START_SYMBOL) :]) - data["text"] = data["text"][offset:] - offset += len(data["text"]) - yield data - - # we cannot yield in the scope of generate() because python does not allow yield + return in the same function - # however, it allows to wrap the generator as a subfunction and return - return generator_wrapper() - else: - return ret - - async def async_generate( - self, - # The input prompt. It can be a single prompt or a batch of prompts. - prompt: Optional[Union[List[str], str]] = None, - sampling_params: Optional[Dict] = None, - # The token ids for text; one can either specify text or input_ids. - input_ids: Optional[Union[List[List[int]], List[int]]] = None, - return_logprob: Optional[Union[List[bool], bool]] = False, - logprob_start_len: Optional[Union[List[int], int]] = None, - top_logprobs_num: Optional[Union[List[int], int]] = None, - lora_path: Optional[List[Optional[str]]] = None, - stream: bool = False, - ): - obj = GenerateReqInput( - text=prompt, - input_ids=input_ids, - sampling_params=sampling_params, - return_logprob=return_logprob, - logprob_start_len=logprob_start_len, - top_logprobs_num=top_logprobs_num, - lora_path=lora_path, - stream=stream, - ) - - ret = await generate_request(obj, None) - - if stream is True: - generator = ret.body_iterator - - async def generator_wrapper(): - - offset = 0 - - while True: - chunk = await generator.__anext__() - - if chunk.startswith(STREAM_END_SYMBOL): - break - else: - data = json.loads(chunk[len(STREAM_CHUNK_START_SYMBOL) :]) - data["text"] = data["text"][offset:] - offset += len(data["text"]) - yield data - - return generator_wrapper() - else: - return ret - - def shutdown(self): - kill_process_tree(os.getpid(), include_parent=False) - - def get_tokenizer(self): - global tokenizer_manager - - if tokenizer_manager is None: - raise ReferenceError("Tokenizer Manager is not initialized.") - else: - return tokenizer_manager.tokenizer - - def encode( - self, - prompt: Union[str, List[str], List[Dict], List[List[Dict]]], - ): - obj = EmbeddingReqInput(text=prompt) - - # get the current event loop - loop = asyncio.get_event_loop() - return loop.run_until_complete(encode_request(obj, None)) - - def start_profile(self): - tokenizer_manager.start_profile() - - def stop_profile(self): - tokenizer_manager.stop_profile() - - async def get_server_info(self): - return await _get_server_info() - - def init_weights_update_group( - self, - master_address: str, - master_port: int, - rank_offset: int, - world_size: int, - group_name: str, - backend: str = "nccl", - ): - """Initialize parameter update group.""" - obj = InitWeightsUpdateGroupReqInput( - master_address=master_address, - master_port=master_port, - rank_offset=rank_offset, - world_size=world_size, - group_name=group_name, - backend=backend, - ) - - async def _init_group(): - return await tokenizer_manager.init_weights_update_group(obj, None) - - loop = asyncio.get_event_loop() - return loop.run_until_complete(_init_group()) - - def update_weights_from_distributed(self, name, dtype, shape): - """Update weights from distributed source.""" - obj = UpdateWeightsFromDistributedReqInput( - name=name, - dtype=dtype, - shape=shape, - ) - - async def _update_weights(): - return await tokenizer_manager.update_weights_from_distributed(obj, None) - - loop = asyncio.get_event_loop() - return loop.run_until_complete(_update_weights()) - - def get_weights_by_name(self, name, truncate_size=100): - """Get weights by parameter name.""" - obj = GetWeightsByNameReqInput(name=name, truncate_size=truncate_size) - - async def _get_weights(): - return await tokenizer_manager.get_weights_by_name(obj, None) - - loop = asyncio.get_event_loop() - return loop.run_until_complete(_get_weights()) diff --git a/test/srt/test_get_weights_by_name.py b/test/srt/test_get_weights_by_name.py index 1494483c7..6dcb1d249 100644 --- a/test/srt/test_get_weights_by_name.py +++ b/test/srt/test_get_weights_by_name.py @@ -67,7 +67,7 @@ class TestGetWeightsByName(unittest.TestCase): terminate_process(self.process) def assert_tie_word_embeddings(self, truncate_size): - print(f"assert_tie_word_embeddings") + print("assert_tie_word_embeddings") if self.backend == "Engine": backend_ret = _process_return( self.engine.get_weights_by_name("lm_head.weight", truncate_size) @@ -79,7 +79,7 @@ class TestGetWeightsByName(unittest.TestCase): json={"name": "lm_head.weight", "truncate_size": truncate_size}, ).json() ) - print(f"assert_tie_word_embeddings of hf and backend") + print("assert_tie_word_embeddings of hf and backend") assert np.allclose( self.hf_model.get_parameter("model.embed_tokens.weight") .cpu() diff --git a/test/srt/test_update_weights_from_distributed.py b/test/srt/test_update_weights_from_distributed.py index a4fe17813..7acbe9fb3 100644 --- a/test/srt/test_update_weights_from_distributed.py +++ b/test/srt/test_update_weights_from_distributed.py @@ -127,7 +127,7 @@ def init_process_hf( hf_instruct_params = [] hf_base_params = [] - print(f"get parameter in hf instruct model and base model") + print("get parameter in hf instruct model and base model") for parameter_name in checking_parameters: hf_instruct_params.append( hf_instruct_model.get_parameter(parameter_name)[:truncate_size] @@ -186,7 +186,6 @@ def init_process_hf( param_queue.put(("broadcast_time", broadcast_time)) # Delete the huggingface models to free up memory. - del hf_instruct_model del hf_base_model gc.collect() @@ -238,7 +237,6 @@ def init_process_sgl( print(f"rank {rank} init server on url: {url}") # Get weights of instruct model, i.e. pre-training weights. - instruct_params = [] for parameter_name in checking_parameters: instruct_params.append( @@ -253,7 +251,6 @@ def init_process_sgl( param_queue.put((f"sgl_dp_{rank}_instruct_params", instruct_params)) # Init weight update group with the training engine. - if backend == "Engine": engine.init_weights_update_group( master_address="localhost", @@ -282,7 +279,6 @@ def init_process_sgl( # The last parameter is lm_head.weight, which is tied # with embed_tokens.weight. Actually, we only need # to update embed_tokens.weight once. - tie_word_embeddings = ( True if model_name == DEFAULT_SMALL_MODEL_NAME_FOR_TEST else False ) @@ -291,7 +287,6 @@ def init_process_sgl( update_parameters.remove("lm_head.weight") # Get weights from the training engine and update the inference engine. - for parameter_name in update_parameters: if backend == "Engine": engine.update_weights_from_distributed( @@ -312,7 +307,6 @@ def init_process_sgl( time_end_update = time.time() # Measure the latency of broadcast/weights update. - update_time = time_end_update - time_begin_update print( f"fully update model_name {model_name} rank {rank} parameter from distributed time: {update_time:.3f}s" @@ -320,7 +314,6 @@ def init_process_sgl( param_queue.put((f"update_sgl_dp_{rank}_time", update_time)) # Get the weights of post-training model after weights update for correctness check. - base_params = [] for parameter_name in checking_parameters: if backend == "Engine": @@ -340,7 +333,6 @@ def init_process_sgl( param_queue.put((f"sgl_dp_{rank}_base_params", base_params)) # Shutdown the engine or terminate the server process. - if backend == "Engine": engine.shutdown() else: @@ -426,7 +418,6 @@ def test_update_weights_from_distributed( # Check the correctness of weights update by verifying # the weights of instruct model and base model. - for i in range(len(params["hf_instruct"])): verify_params_close( params["hf_instruct"][i], @@ -463,7 +454,6 @@ def test_update_weights_from_distributed( ), "hf_instruct_params and hf_base_params have different lengths" # Check if the weights of lm_head are tied with embed_tokens. - params_to_check = [ ( params["hf_instruct"], @@ -509,7 +499,6 @@ def test_update_weights_from_distributed( # Time limit for broadcast and update on CI is 3 / 6 # On local H100, it's 1 / 2 - time_limit = 3 if model_name == DEFAULT_SMALL_MODEL_NAME_FOR_TEST else 6 assert ( @@ -526,7 +515,6 @@ def test_update_weights_from_distributed( ), f"update_sgl_dp_two_time exceeds time limit {time_limit}s" # Delete the context and close the parameter queue. - del context param_queue.close() param_queue.join_thread()