Add typo checker in pre-commit (#6179)

Co-authored-by: Brayden Zhong <b8zhong@uwaterloo.ca>
This commit is contained in:
applesaucethebun
2025-05-11 00:55:00 -04:00
committed by GitHub
parent de167cf5fa
commit 2ce8793519
99 changed files with 154 additions and 144 deletions

View File

@@ -201,7 +201,7 @@ class EAGLEWorker(TpModelWorker):
self.has_prefill_wrapper_verify = False
else:
raise ValueError(
f"EAGLE is not supportted in attention backend {self.server_args.attention_backend}"
f"EAGLE is not supported in attention backend {self.server_args.attention_backend}"
)
self.draft_model_runner.draft_attn_backend = self.draft_attn_backend
@@ -245,8 +245,8 @@ class EAGLEWorker(TpModelWorker):
Args:
batch: The batch to run forward. The state of the batch is modified as it runs.
Returns:
A tuple of the final logit output of the target model, next tokens accepeted,
the batch id (used for overlap schedule), and number of accepeted tokens.
A tuple of the final logit output of the target model, next tokens accepted,
the batch id (used for overlap schedule), and number of accepted tokens.
"""
if batch.forward_mode.is_decode():
with self.draft_tp_context(self.draft_model_runner.tp_group):
@@ -491,11 +491,11 @@ class EAGLEWorker(TpModelWorker):
)
# Post process based on verified outputs.
# Pick indices that we care (accepeted)
# Pick indices that we care (accepted)
logits_output.next_token_logits = logits_output.next_token_logits[
res.accepeted_indices
res.accepted_indices
]
logits_output.hidden_states = logits_output.hidden_states[res.accepeted_indices]
logits_output.hidden_states = logits_output.hidden_states[res.accepted_indices]
# Prepare the batch for the next draft forwards.
batch.forward_mode = ForwardMode.DECODE
@@ -597,7 +597,7 @@ class EAGLEWorker(TpModelWorker):
self.capture_for_decode(logits_output, forward_batch.spec_info)
def forward_draft_extend_after_decode(self, batch: ScheduleBatch):
# Backup fileds that will be modified in-place
# Backup fields that will be modified in-place
seq_lens_backup = batch.seq_lens.clone()
req_pool_indices_backup = batch.req_pool_indices
accept_length_backup = batch.spec_info.accept_length