Feat/support rerank (#6058)

This commit is contained in:
woodx
2025-06-17 01:50:01 +08:00
committed by GitHub
parent 91a066ec6a
commit e30ef368ab
20 changed files with 684 additions and 30 deletions

View File

@@ -224,6 +224,9 @@ class ForwardBatch:
# For input embeddings
input_embeds: Optional[torch.tensor] = None
# For cross-encoder model
token_type_ids: Optional[torch.Tensor] = None
# Sampling info
sampling_info: SamplingBatchInfo = None
@@ -300,6 +303,7 @@ class ForwardBatch:
spec_info=batch.spec_info,
capture_hidden_mode=batch.capture_hidden_mode,
input_embeds=batch.input_embeds,
token_type_ids=batch.token_type_ids,
tbo_split_seq_index=batch.tbo_split_seq_index,
)
device = model_runner.device
@@ -356,8 +360,8 @@ class ForwardBatch:
ret.extend_prefix_lens = torch.tensor(
batch.extend_prefix_lens, dtype=torch.int32
).to(device, non_blocking=True)
ret.extend_num_tokens = batch.extend_num_tokens
if support_triton(model_runner.server_args.attention_backend):
ret.extend_num_tokens = batch.extend_num_tokens
positions, ret.extend_start_loc = compute_position_triton(
ret.extend_prefix_lens,
ret.extend_seq_lens,