[MTP] Force greedy sampling on AMD (#9127)
This commit is contained in:
@@ -49,6 +49,8 @@ SIMULATE_ACC_METHOD = os.environ.get("SIMULATE_ACC_METHOD", "multinomial")
|
||||
|
||||
TREE_TRAVERSE_TIME_THRESHOLD = 1 # TODO: set this properly
|
||||
|
||||
# True when the name `tree_speculative_sampling_target_only` is bound in this
# module's globals at import time. The verify path uses this flag to fall back
# to greedy (argmax) verification when the kernel is missing.
# NOTE(review): presumably the name comes from a conditional/platform-specific
# import earlier in this file (not visible here) — confirm.
TREE_SPEC_KERNEL_AVAILABLE = "tree_speculative_sampling_target_only" in globals()
|
||||
|
||||
|
||||
@dataclass
|
||||
class EagleDraftInput:
|
||||
@@ -423,8 +425,15 @@ class EagleVerifyInput:
|
||||
logits=logits_output.next_token_logits, vocab_mask=vocab_mask
|
||||
)
|
||||
|
||||
# Sample tokens
|
||||
if batch.sampling_info.is_all_greedy:
|
||||
# Sample tokens. Force greedy verification whenever the tree speculative sampling kernel is unavailable (e.g. on AMD/HIP builds).
|
||||
is_all_greedy = sampling_info.is_all_greedy
|
||||
if (not is_all_greedy) and (not TREE_SPEC_KERNEL_AVAILABLE):
|
||||
logger.warning(
|
||||
"Tree speculative sampling kernel unavailable (likely AMD/HIP build). "
|
||||
"Falling back to greedy verification."
|
||||
)
|
||||
|
||||
if is_all_greedy or not TREE_SPEC_KERNEL_AVAILABLE:
|
||||
target_predict = torch.argmax(logits_output.next_token_logits, dim=-1)
|
||||
target_predict = target_predict.reshape(bs, self.draft_token_num)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user