Feat/support rerank (#6058)
This commit is contained in:
@@ -224,6 +224,9 @@ class ForwardBatch:
|
||||
# For input embeddings
|
||||
input_embeds: Optional[torch.tensor] = None
|
||||
|
||||
# For cross-encoder model
|
||||
token_type_ids: Optional[torch.Tensor] = None
|
||||
|
||||
# Sampling info
|
||||
sampling_info: SamplingBatchInfo = None
|
||||
|
||||
@@ -300,6 +303,7 @@ class ForwardBatch:
|
||||
spec_info=batch.spec_info,
|
||||
capture_hidden_mode=batch.capture_hidden_mode,
|
||||
input_embeds=batch.input_embeds,
|
||||
token_type_ids=batch.token_type_ids,
|
||||
tbo_split_seq_index=batch.tbo_split_seq_index,
|
||||
)
|
||||
device = model_runner.device
|
||||
@@ -356,8 +360,8 @@ class ForwardBatch:
|
||||
ret.extend_prefix_lens = torch.tensor(
|
||||
batch.extend_prefix_lens, dtype=torch.int32
|
||||
).to(device, non_blocking=True)
|
||||
ret.extend_num_tokens = batch.extend_num_tokens
|
||||
if support_triton(model_runner.server_args.attention_backend):
|
||||
ret.extend_num_tokens = batch.extend_num_tokens
|
||||
positions, ret.extend_start_loc = compute_position_triton(
|
||||
ret.extend_prefix_lens,
|
||||
ret.extend_seq_lens,
|
||||
|
||||
Reference in New Issue
Block a user