Fix model loading & format code (#125)
This commit is contained in:
@@ -63,7 +63,9 @@ class Req:
|
|||||||
# FIXME: This logic does not really solve the problem of determining whether
|
# FIXME: This logic does not really solve the problem of determining whether
|
||||||
# there should be a leading space.
|
# there should be a leading space.
|
||||||
first_token = self.tokenizer.convert_ids_to_tokens(self.output_ids[0])
|
first_token = self.tokenizer.convert_ids_to_tokens(self.output_ids[0])
|
||||||
first_token = first_token.decode() if isinstance(first_token, bytes) else first_token
|
first_token = (
|
||||||
|
first_token.decode() if isinstance(first_token, bytes) else first_token
|
||||||
|
)
|
||||||
if first_token.startswith("▁"):
|
if first_token.startswith("▁"):
|
||||||
old_output_str = " " + old_output_str
|
old_output_str = " " + old_output_str
|
||||||
new_input_string = (
|
new_input_string = (
|
||||||
|
|||||||
@@ -344,9 +344,13 @@ class ModelRpcServer(rpyc.Service):
|
|||||||
return None
|
return None
|
||||||
|
|
||||||
if self.tp_rank == 0:
|
if self.tp_rank == 0:
|
||||||
running_req = 0 if self.running_batch is None else len(self.running_batch.reqs)
|
running_req = (
|
||||||
|
0 if self.running_batch is None else len(self.running_batch.reqs)
|
||||||
|
)
|
||||||
hit_tokens = sum(len(x.prefix_indices) for x in can_run_list)
|
hit_tokens = sum(len(x.prefix_indices) for x in can_run_list)
|
||||||
self.tree_cache_metrics["total"] += (hit_tokens + new_batch_input_tokens) / 10**9
|
self.tree_cache_metrics["total"] += (
|
||||||
|
hit_tokens + new_batch_input_tokens
|
||||||
|
) / 10**9
|
||||||
self.tree_cache_metrics["hit"] += hit_tokens / 10**9
|
self.tree_cache_metrics["hit"] += hit_tokens / 10**9
|
||||||
tree_cache_hit_rate = (
|
tree_cache_hit_rate = (
|
||||||
self.tree_cache_metrics["hit"] / self.tree_cache_metrics["total"]
|
self.tree_cache_metrics["hit"] / self.tree_cache_metrics["total"]
|
||||||
@@ -584,7 +588,7 @@ def start_model_process(port):
|
|||||||
t = ThreadedServer(
|
t = ThreadedServer(
|
||||||
ModelRpcServer(),
|
ModelRpcServer(),
|
||||||
port=port,
|
port=port,
|
||||||
protocol_config={"allow_pickle": True, "sync_request_timeout": 600},
|
protocol_config={"allow_pickle": True, "sync_request_timeout": 1800},
|
||||||
)
|
)
|
||||||
t.start()
|
t.start()
|
||||||
|
|
||||||
@@ -598,7 +602,7 @@ def start_model_process(port):
|
|||||||
con = rpyc.connect(
|
con = rpyc.connect(
|
||||||
"localhost",
|
"localhost",
|
||||||
port,
|
port,
|
||||||
config={"allow_pickle": True, "sync_request_timeout": 600},
|
config={"allow_pickle": True, "sync_request_timeout": 1800},
|
||||||
)
|
)
|
||||||
break
|
break
|
||||||
except ConnectionRefusedError:
|
except ConnectionRefusedError:
|
||||||
|
|||||||
@@ -351,7 +351,11 @@ class MixtralForCausalLM(nn.Module):
|
|||||||
|
|
||||||
params_dict = dict(self.named_parameters())
|
params_dict = dict(self.named_parameters())
|
||||||
for name, loaded_weight in hf_model_weights_iterator(
|
for name, loaded_weight in hf_model_weights_iterator(
|
||||||
model_name_or_path, cache_dir, load_format, revision
|
model_name_or_path,
|
||||||
|
cache_dir,
|
||||||
|
load_format,
|
||||||
|
revision,
|
||||||
|
fall_back_to_pt=False,
|
||||||
):
|
):
|
||||||
if "rotary_emb.inv_freq" in name:
|
if "rotary_emb.inv_freq" in name:
|
||||||
continue
|
continue
|
||||||
|
|||||||
Reference in New Issue
Block a user