Make constrained decoding work for overlap scheduler (#2095)

This commit is contained in:
Lianmin Zheng
2024-11-19 15:04:43 -08:00
committed by GitHub
parent 55bd97f3e5
commit ffd20fcd03
8 changed files with 119 additions and 95 deletions

View File

@@ -142,7 +142,6 @@ class ModelRunner:
"triton_attention_reduce_in_fp32": server_args.triton_attention_reduce_in_fp32,
"disable_mla": server_args.disable_mla,
"torchao_config": server_args.torchao_config,
"disable_penalizer": server_args.disable_penalizer,
"enable_nan_detection": server_args.enable_nan_detection,
"enable_dp_attention": server_args.enable_dp_attention,
}
@@ -636,10 +635,18 @@ class ModelRunner:
def sample(
self, logits_output: LogitsProcessorOutput, forward_batch: ForwardBatch
) -> torch.Tensor:
# Put CPU-heavy tasks here. They will be overlapped with the forward pass.
sampling_info = forward_batch.sampling_info
sampling_info.update_regex_vocab_mask()
sampling_info.update_penalties()
if sampling_info.sampling_info_done:
# Overlap mode: the function update_regex_vocab_mask was executed
# in process_batch_result of the last batch.
if sampling_info.grammars:
sampling_info.sampling_info_done.wait()
sampling_info.update_penalties()
else:
# Normal mode: Put CPU-heavy tasks here. They will be overlapped with the forward pass.
sampling_info.update_regex_vocab_mask()
sampling_info.update_penalties()
logits = self.apply_logits_bias(logits_output.next_token_logits, sampling_info)
# Sample the next tokens.