Make constrained decoding work for overlap scheduler (#2095)
This commit is contained in:
@@ -52,15 +52,19 @@ if TYPE_CHECKING:
|
||||
class ForwardMode(IntEnum):
|
||||
# Prefill a new sequence. This is deprecated now. "EXTEND" covers this case.
|
||||
PREFILL = auto()
|
||||
# Extend a sequence. The KV cache of the first part of the sequence is already computed (e.g., system prompt).
|
||||
# Extend a sequence. The KV cache of the beginning part of the sequence is already computed (e.g., system prompt).
|
||||
EXTEND = auto()
|
||||
# Decode one token.
|
||||
DECODE = auto()
|
||||
# Contains both EXTEND and DECODE.
|
||||
# Contains both EXTEND and DECODE when doing chunked prefill.
|
||||
MIXED = auto()
|
||||
# No sequence to forward. For data parallel attention, some workers wil be IDLE if no sequence allocated.
|
||||
# No sequence to forward. For data parallel attention, some workers will be IDLE if no sequences are allocated.
|
||||
IDLE = auto()
|
||||
|
||||
# A dummy first batch to start the pipeline for overlap scheduler.
|
||||
# It is now used for triggering the sampling_info_done event for the first prefill batch.
|
||||
DUMMY_FIRST = auto()
|
||||
|
||||
def is_prefill(self):
|
||||
return self == ForwardMode.PREFILL
|
||||
|
||||
@@ -76,6 +80,9 @@ class ForwardMode(IntEnum):
|
||||
def is_idle(self):
|
||||
return self == ForwardMode.IDLE
|
||||
|
||||
def is_dummy_first(self):
|
||||
return self == ForwardMode.DUMMY_FIRST
|
||||
|
||||
|
||||
@dataclass
|
||||
class ForwardBatch:
|
||||
|
||||
@@ -142,7 +142,6 @@ class ModelRunner:
|
||||
"triton_attention_reduce_in_fp32": server_args.triton_attention_reduce_in_fp32,
|
||||
"disable_mla": server_args.disable_mla,
|
||||
"torchao_config": server_args.torchao_config,
|
||||
"disable_penalizer": server_args.disable_penalizer,
|
||||
"enable_nan_detection": server_args.enable_nan_detection,
|
||||
"enable_dp_attention": server_args.enable_dp_attention,
|
||||
}
|
||||
@@ -636,10 +635,18 @@ class ModelRunner:
|
||||
def sample(
|
||||
self, logits_output: LogitsProcessorOutput, forward_batch: ForwardBatch
|
||||
) -> torch.Tensor:
|
||||
# Put CPU-heavy tasks here. They will be overlapped with the forward pass.
|
||||
sampling_info = forward_batch.sampling_info
|
||||
sampling_info.update_regex_vocab_mask()
|
||||
sampling_info.update_penalties()
|
||||
|
||||
if sampling_info.sampling_info_done:
|
||||
# Overlap mode: the function update_regex_vocab_mask was executed
|
||||
# in process_batch_result of the last batch.
|
||||
if sampling_info.grammars:
|
||||
sampling_info.sampling_info_done.wait()
|
||||
sampling_info.update_penalties()
|
||||
else:
|
||||
# Normal mode: Put CPU-heavy tasks here. They will be overlapped with the forward pass.
|
||||
sampling_info.update_regex_vocab_mask()
|
||||
sampling_info.update_penalties()
|
||||
logits = self.apply_logits_bias(logits_output.next_token_logits, sampling_info)
|
||||
|
||||
# Sample the next tokens.
|
||||
|
||||
Reference in New Issue
Block a user