model_runner simplify (#329)
This commit is contained in:
@@ -407,9 +407,7 @@ class ModelRpcServer:
|
|||||||
prefill_logprobs,
|
prefill_logprobs,
|
||||||
normalized_logprobs,
|
normalized_logprobs,
|
||||||
last_logprobs,
|
last_logprobs,
|
||||||
) = self.model_runner.forward(
|
) = self.model_runner.forward(batch, ForwardMode.EXTEND)
|
||||||
batch, ForwardMode.EXTEND, batch.return_logprob
|
|
||||||
)
|
|
||||||
if prefill_logprobs is not None:
|
if prefill_logprobs is not None:
|
||||||
logprobs = prefill_logprobs.cpu().tolist()
|
logprobs = prefill_logprobs.cpu().tolist()
|
||||||
normalized_logprobs = normalized_logprobs.cpu().tolist()
|
normalized_logprobs = normalized_logprobs.cpu().tolist()
|
||||||
@@ -496,9 +494,7 @@ class ModelRpcServer:
|
|||||||
|
|
||||||
# Forward
|
# Forward
|
||||||
logits, (_, _, last_logprobs) = self.model_runner.forward(
|
logits, (_, _, last_logprobs) = self.model_runner.forward(
|
||||||
batch,
|
batch, ForwardMode.DECODE
|
||||||
ForwardMode.DECODE,
|
|
||||||
batch.return_logprob,
|
|
||||||
)
|
)
|
||||||
next_token_ids, _ = batch.sample(logits)
|
next_token_ids, _ = batch.sample(logits)
|
||||||
next_token_ids = next_token_ids.cpu().tolist()
|
next_token_ids = next_token_ids.cpu().tolist()
|
||||||
|
|||||||
@@ -367,148 +367,88 @@ class ModelRunner:
|
|||||||
)
|
)
|
||||||
|
|
||||||
@torch.inference_mode()
|
@torch.inference_mode()
|
||||||
def forward_prefill(
|
def forward_prefill(self, batch: Batch):
|
||||||
self,
|
|
||||||
input_ids,
|
|
||||||
req_pool_indices,
|
|
||||||
seq_lens,
|
|
||||||
prefix_lens,
|
|
||||||
position_ids_offsets,
|
|
||||||
out_cache_loc,
|
|
||||||
return_logprob,
|
|
||||||
):
|
|
||||||
input_metadata = InputMetadata.create(
|
input_metadata = InputMetadata.create(
|
||||||
self,
|
self,
|
||||||
forward_mode=ForwardMode.PREFILL,
|
forward_mode=ForwardMode.PREFILL,
|
||||||
tp_size=self.tp_size,
|
tp_size=self.tp_size,
|
||||||
req_pool_indices=req_pool_indices,
|
req_pool_indices=batch.req_pool_indices,
|
||||||
seq_lens=seq_lens,
|
seq_lens=batch.seq_lens,
|
||||||
prefix_lens=prefix_lens,
|
prefix_lens=batch.prefix_lens,
|
||||||
position_ids_offsets=position_ids_offsets,
|
position_ids_offsets=batch.position_ids_offsets,
|
||||||
out_cache_loc=out_cache_loc,
|
out_cache_loc=batch.out_cache_loc,
|
||||||
return_logprob=return_logprob,
|
return_logprob=batch.return_logprob,
|
||||||
|
)
|
||||||
|
return self.model.forward(
|
||||||
|
batch.input_ids, input_metadata.positions, input_metadata
|
||||||
)
|
)
|
||||||
return self.model.forward(input_ids, input_metadata.positions, input_metadata)
|
|
||||||
|
|
||||||
@torch.inference_mode()
|
@torch.inference_mode()
|
||||||
def forward_extend(
|
def forward_extend(self, batch: Batch):
|
||||||
self,
|
|
||||||
input_ids,
|
|
||||||
req_pool_indices,
|
|
||||||
seq_lens,
|
|
||||||
prefix_lens,
|
|
||||||
position_ids_offsets,
|
|
||||||
out_cache_loc,
|
|
||||||
return_logprob,
|
|
||||||
):
|
|
||||||
input_metadata = InputMetadata.create(
|
input_metadata = InputMetadata.create(
|
||||||
self,
|
self,
|
||||||
forward_mode=ForwardMode.EXTEND,
|
forward_mode=ForwardMode.EXTEND,
|
||||||
tp_size=self.tp_size,
|
tp_size=self.tp_size,
|
||||||
req_pool_indices=req_pool_indices,
|
req_pool_indices=batch.req_pool_indices,
|
||||||
seq_lens=seq_lens,
|
seq_lens=batch.seq_lens,
|
||||||
prefix_lens=prefix_lens,
|
prefix_lens=batch.prefix_lens,
|
||||||
position_ids_offsets=position_ids_offsets,
|
position_ids_offsets=batch.position_ids_offsets,
|
||||||
out_cache_loc=out_cache_loc,
|
out_cache_loc=batch.out_cache_loc,
|
||||||
return_logprob=return_logprob,
|
return_logprob=batch.return_logprob,
|
||||||
|
)
|
||||||
|
return self.model.forward(
|
||||||
|
batch.input_ids, input_metadata.positions, input_metadata
|
||||||
)
|
)
|
||||||
return self.model.forward(input_ids, input_metadata.positions, input_metadata)
|
|
||||||
|
|
||||||
@torch.inference_mode()
|
@torch.inference_mode()
|
||||||
def forward_decode(
|
def forward_decode(self, batch: Batch):
|
||||||
self,
|
|
||||||
input_ids,
|
|
||||||
req_pool_indices,
|
|
||||||
seq_lens,
|
|
||||||
prefix_lens,
|
|
||||||
position_ids_offsets,
|
|
||||||
out_cache_loc,
|
|
||||||
out_cache_cont_start,
|
|
||||||
out_cache_cont_end,
|
|
||||||
return_logprob,
|
|
||||||
):
|
|
||||||
input_metadata = InputMetadata.create(
|
input_metadata = InputMetadata.create(
|
||||||
self,
|
self,
|
||||||
forward_mode=ForwardMode.DECODE,
|
forward_mode=ForwardMode.DECODE,
|
||||||
tp_size=self.tp_size,
|
tp_size=self.tp_size,
|
||||||
req_pool_indices=req_pool_indices,
|
req_pool_indices=batch.req_pool_indices,
|
||||||
seq_lens=seq_lens,
|
seq_lens=batch.seq_lens,
|
||||||
prefix_lens=prefix_lens,
|
prefix_lens=batch.prefix_lens,
|
||||||
position_ids_offsets=position_ids_offsets,
|
position_ids_offsets=batch.position_ids_offsets,
|
||||||
out_cache_loc=out_cache_loc,
|
out_cache_loc=batch.out_cache_loc,
|
||||||
out_cache_cont_start=out_cache_cont_start,
|
out_cache_cont_start=batch.out_cache_cont_start,
|
||||||
out_cache_cont_end=out_cache_cont_end,
|
out_cache_cont_end=batch.out_cache_cont_end,
|
||||||
return_logprob=return_logprob,
|
return_logprob=batch.return_logprob,
|
||||||
|
)
|
||||||
|
return self.model.forward(
|
||||||
|
batch.input_ids, input_metadata.positions, input_metadata
|
||||||
)
|
)
|
||||||
return self.model.forward(input_ids, input_metadata.positions, input_metadata)
|
|
||||||
|
|
||||||
@torch.inference_mode()
|
@torch.inference_mode()
|
||||||
def forward_extend_multi_modal(
|
def forward_extend_multi_modal(self, batch: Batch):
|
||||||
self,
|
|
||||||
input_ids,
|
|
||||||
pixel_values,
|
|
||||||
image_sizes,
|
|
||||||
image_offsets,
|
|
||||||
req_pool_indices,
|
|
||||||
seq_lens,
|
|
||||||
prefix_lens,
|
|
||||||
position_ids_offsets,
|
|
||||||
out_cache_loc,
|
|
||||||
return_logprob,
|
|
||||||
):
|
|
||||||
input_metadata = InputMetadata.create(
|
input_metadata = InputMetadata.create(
|
||||||
self,
|
self,
|
||||||
forward_mode=ForwardMode.EXTEND,
|
forward_mode=ForwardMode.EXTEND,
|
||||||
tp_size=self.tp_size,
|
tp_size=self.tp_size,
|
||||||
req_pool_indices=req_pool_indices,
|
req_pool_indices=batch.req_pool_indices,
|
||||||
seq_lens=seq_lens,
|
seq_lens=batch.seq_lens,
|
||||||
prefix_lens=prefix_lens,
|
prefix_lens=batch.prefix_lens,
|
||||||
position_ids_offsets=position_ids_offsets,
|
position_ids_offsets=batch.position_ids_offsets,
|
||||||
out_cache_loc=out_cache_loc,
|
out_cache_loc=batch.out_cache_loc,
|
||||||
return_logprob=return_logprob,
|
return_logprob=batch.return_logprob,
|
||||||
)
|
)
|
||||||
return self.model.forward(
|
return self.model.forward(
|
||||||
input_ids,
|
batch.input_ids,
|
||||||
input_metadata.positions,
|
input_metadata.positions,
|
||||||
input_metadata,
|
input_metadata,
|
||||||
pixel_values,
|
batch.pixel_values,
|
||||||
image_sizes,
|
batch.image_sizes,
|
||||||
image_offsets,
|
batch.image_offsets,
|
||||||
)
|
)
|
||||||
|
|
||||||
def forward(self, batch: Batch, forward_mode: ForwardMode, return_logprob=False):
|
def forward(self, batch: Batch, forward_mode: ForwardMode):
|
||||||
if self.is_multimodal_model and forward_mode == ForwardMode.EXTEND:
|
if self.is_multimodal_model and forward_mode == ForwardMode.EXTEND:
|
||||||
kwargs = {
|
return self.forward_extend_multi_modal(batch)
|
||||||
"input_ids": batch.input_ids,
|
elif forward_mode == ForwardMode.DECODE:
|
||||||
"pixel_values": batch.pixel_values,
|
return self.forward_decode(batch)
|
||||||
"image_sizes": batch.image_sizes,
|
|
||||||
"image_offsets": batch.image_offsets,
|
|
||||||
"req_pool_indices": batch.req_pool_indices,
|
|
||||||
"seq_lens": batch.seq_lens,
|
|
||||||
"prefix_lens": batch.prefix_lens,
|
|
||||||
"position_ids_offsets": batch.position_ids_offsets,
|
|
||||||
"out_cache_loc": batch.out_cache_loc,
|
|
||||||
"return_logprob": return_logprob,
|
|
||||||
}
|
|
||||||
return self.forward_extend_multi_modal(**kwargs)
|
|
||||||
else:
|
|
||||||
kwargs = {
|
|
||||||
"input_ids": batch.input_ids,
|
|
||||||
"req_pool_indices": batch.req_pool_indices,
|
|
||||||
"seq_lens": batch.seq_lens,
|
|
||||||
"prefix_lens": batch.prefix_lens,
|
|
||||||
"position_ids_offsets": batch.position_ids_offsets,
|
|
||||||
"out_cache_loc": batch.out_cache_loc,
|
|
||||||
"return_logprob": return_logprob,
|
|
||||||
}
|
|
||||||
|
|
||||||
if forward_mode == ForwardMode.DECODE:
|
|
||||||
kwargs["out_cache_cont_start"] = batch.out_cache_cont_start
|
|
||||||
kwargs["out_cache_cont_end"] = batch.out_cache_cont_end
|
|
||||||
return self.forward_decode(**kwargs)
|
|
||||||
elif forward_mode == ForwardMode.EXTEND:
|
elif forward_mode == ForwardMode.EXTEND:
|
||||||
return self.forward_extend(**kwargs)
|
return self.forward_extend(batch)
|
||||||
elif forward_mode == ForwardMode.PREFILL:
|
elif forward_mode == ForwardMode.PREFILL:
|
||||||
return self.forward_prefill(**kwargs)
|
return self.forward_prefill(batch)
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"Invaid forward mode: {forward_mode}")
|
raise ValueError(f"Invaid forward mode: {forward_mode}")
|
||||||
|
|||||||
Reference in New Issue
Block a user