Fix the overhead due to penalizer in bench_latency (#1496)
This commit is contained in:
@@ -260,7 +260,7 @@ def correctness_test(
|
|||||||
|
|
||||||
# Decode
|
# Decode
|
||||||
output_ids = [input_ids[i] + [next_token_ids[i]] for i in range(len(input_ids))]
|
output_ids = [input_ids[i] + [next_token_ids[i]] for i in range(len(input_ids))]
|
||||||
for _ in range(bench_args.output_len[0]):
|
for _ in range(bench_args.output_len[0] - 1):
|
||||||
next_token_ids, _ = decode(next_token_ids, batch, model_runner)
|
next_token_ids, _ = decode(next_token_ids, batch, model_runner)
|
||||||
for i in range(len(reqs)):
|
for i in range(len(reqs)):
|
||||||
output_ids[i].append(next_token_ids[i])
|
output_ids[i].append(next_token_ids[i])
|
||||||
@@ -311,7 +311,7 @@ def latency_test_run_once(
|
|||||||
|
|
||||||
# Decode
|
# Decode
|
||||||
decode_latencies = []
|
decode_latencies = []
|
||||||
for i in range(output_len):
|
for i in range(output_len - 1):
|
||||||
torch.cuda.synchronize()
|
torch.cuda.synchronize()
|
||||||
tic = time.time()
|
tic = time.time()
|
||||||
next_token_ids, _ = decode(next_token_ids, batch, model_runner)
|
next_token_ids, _ = decode(next_token_ids, batch, model_runner)
|
||||||
|
|||||||
@@ -429,7 +429,7 @@ class ScheduleBatch:
|
|||||||
def prepare_for_extend(self, vocab_size: int):
|
def prepare_for_extend(self, vocab_size: int):
|
||||||
self.forward_mode = ForwardMode.EXTEND
|
self.forward_mode = ForwardMode.EXTEND
|
||||||
|
|
||||||
bs = self.batch_size()
|
bs = len(self.reqs)
|
||||||
reqs = self.reqs
|
reqs = self.reqs
|
||||||
input_ids = [r.fill_ids[len(r.prefix_indices) :] for r in reqs]
|
input_ids = [r.fill_ids[len(r.prefix_indices) :] for r in reqs]
|
||||||
extend_num_tokens = sum(len(ids) for ids in input_ids)
|
extend_num_tokens = sum(len(ids) for ids in input_ids)
|
||||||
@@ -509,7 +509,7 @@ class ScheduleBatch:
|
|||||||
self.extend_logprob_start_lens_cpu.extend([0] * running_bs)
|
self.extend_logprob_start_lens_cpu.extend([0] * running_bs)
|
||||||
|
|
||||||
def check_decode_mem(self):
|
def check_decode_mem(self):
|
||||||
bs = self.batch_size()
|
bs = len(self.reqs)
|
||||||
if self.token_to_kv_pool.available_size() >= bs:
|
if self.token_to_kv_pool.available_size() >= bs:
|
||||||
return True
|
return True
|
||||||
|
|
||||||
@@ -680,14 +680,12 @@ class ScheduleBatch:
|
|||||||
r.output_ids[-1] if r.output_ids else r.origin_input_ids[-1]
|
r.output_ids[-1] if r.output_ids else r.origin_input_ids[-1]
|
||||||
for r in self.reqs
|
for r in self.reqs
|
||||||
]
|
]
|
||||||
else:
|
|
||||||
self.sampling_info.penalizer_orchestrator.cumulate_input_tokens(input_ids)
|
|
||||||
|
|
||||||
self.input_ids = torch.tensor(input_ids, dtype=torch.int32, device="cuda")
|
self.input_ids = torch.tensor(input_ids, dtype=torch.int32, device="cuda")
|
||||||
self.seq_lens.add_(1)
|
self.seq_lens.add_(1)
|
||||||
|
|
||||||
# Alloc mem
|
# Alloc mem
|
||||||
bs = self.batch_size()
|
bs = len(self.reqs)
|
||||||
self.out_cache_loc = self.alloc_token_slots(bs)
|
self.out_cache_loc = self.alloc_token_slots(bs)
|
||||||
|
|
||||||
self.req_to_token_pool.req_to_token[
|
self.req_to_token_pool.req_to_token[
|
||||||
|
|||||||
@@ -215,6 +215,7 @@ class ModelTpServer:
|
|||||||
self.new_token_ratio_decay = global_config.new_token_ratio_decay
|
self.new_token_ratio_decay = global_config.new_token_ratio_decay
|
||||||
self.do_not_get_new_batch = False
|
self.do_not_get_new_batch = False
|
||||||
|
|
||||||
|
@torch.inference_mode()
|
||||||
def exposed_step(self, recv_reqs: List):
|
def exposed_step(self, recv_reqs: List):
|
||||||
try:
|
try:
|
||||||
# Recv requests
|
# Recv requests
|
||||||
@@ -246,7 +247,6 @@ class ModelTpServer:
|
|||||||
self.out_pyobjs = []
|
self.out_pyobjs = []
|
||||||
return ret
|
return ret
|
||||||
|
|
||||||
@torch.inference_mode()
|
|
||||||
def forward_step(self):
|
def forward_step(self):
|
||||||
if self.do_not_get_new_batch and self.current_inflight_req is None:
|
if self.do_not_get_new_batch and self.current_inflight_req is None:
|
||||||
new_batch = None
|
new_batch = None
|
||||||
|
|||||||
@@ -97,14 +97,12 @@ class InputMetadata:
|
|||||||
self.modalities = [r.modalities for r in reqs]
|
self.modalities = [r.modalities for r in reqs]
|
||||||
|
|
||||||
def compute_positions(self, batch: ScheduleBatch):
|
def compute_positions(self, batch: ScheduleBatch):
|
||||||
position_ids_offsets = batch.position_ids_offsets
|
|
||||||
|
|
||||||
if self.forward_mode.is_decode():
|
if self.forward_mode.is_decode():
|
||||||
if True:
|
if True:
|
||||||
self.positions = self.seq_lens - 1
|
self.positions = self.seq_lens - 1
|
||||||
else:
|
else:
|
||||||
# Deprecated
|
# Deprecated
|
||||||
self.positions = (self.seq_lens - 1) + position_ids_offsets
|
self.positions = (self.seq_lens - 1) + batch.position_ids_offsets
|
||||||
else:
|
else:
|
||||||
if True:
|
if True:
|
||||||
self.positions = torch.tensor(
|
self.positions = torch.tensor(
|
||||||
@@ -119,7 +117,7 @@ class InputMetadata:
|
|||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
# Deprecated
|
# Deprecated
|
||||||
position_ids_offsets_cpu = position_ids_offsets.cpu().numpy()
|
position_ids_offsets_cpu = batch.position_ids_offsets.cpu().numpy()
|
||||||
self.positions = torch.tensor(
|
self.positions = torch.tensor(
|
||||||
np.concatenate(
|
np.concatenate(
|
||||||
[
|
[
|
||||||
|
|||||||
@@ -467,7 +467,6 @@ class ModelRunner:
|
|||||||
logger.info("Capture cuda graph begin. This can take up to several minutes.")
|
logger.info("Capture cuda graph begin. This can take up to several minutes.")
|
||||||
self.cuda_graph_runner = CudaGraphRunner(self)
|
self.cuda_graph_runner = CudaGraphRunner(self)
|
||||||
|
|
||||||
@torch.inference_mode()
|
|
||||||
def forward_decode(self, batch: ScheduleBatch):
|
def forward_decode(self, batch: ScheduleBatch):
|
||||||
if self.server_args.lora_paths is not None:
|
if self.server_args.lora_paths is not None:
|
||||||
self.lora_manager.prepare_lora_batch(batch)
|
self.lora_manager.prepare_lora_batch(batch)
|
||||||
@@ -481,7 +480,6 @@ class ModelRunner:
|
|||||||
batch.input_ids, input_metadata.positions, input_metadata
|
batch.input_ids, input_metadata.positions, input_metadata
|
||||||
)
|
)
|
||||||
|
|
||||||
@torch.inference_mode()
|
|
||||||
def forward_extend(self, batch: ScheduleBatch):
|
def forward_extend(self, batch: ScheduleBatch):
|
||||||
input_metadata = InputMetadata.from_schedule_batch(self, batch)
|
input_metadata = InputMetadata.from_schedule_batch(self, batch)
|
||||||
if self.server_args.lora_paths is not None:
|
if self.server_args.lora_paths is not None:
|
||||||
@@ -500,7 +498,6 @@ class ModelRunner:
|
|||||||
get_embedding=True,
|
get_embedding=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
@torch.inference_mode()
|
|
||||||
def forward_extend_multi_modal(self, batch: ScheduleBatch):
|
def forward_extend_multi_modal(self, batch: ScheduleBatch):
|
||||||
input_metadata = InputMetadata.from_schedule_batch(self, batch)
|
input_metadata = InputMetadata.from_schedule_batch(self, batch)
|
||||||
return self.model.forward(
|
return self.model.forward(
|
||||||
|
|||||||
@@ -45,7 +45,7 @@ def normal_text(args):
|
|||||||
"The capital of the United Kindom is",
|
"The capital of the United Kindom is",
|
||||||
"Today is a sunny day and I like",
|
"Today is a sunny day and I like",
|
||||||
]
|
]
|
||||||
max_new_tokens = 17
|
max_new_tokens = 16
|
||||||
|
|
||||||
torch.cuda.set_device(0)
|
torch.cuda.set_device(0)
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user