Add DeepSeek-V3 and Llama 4

This commit is contained in:
Chranos
2026-02-11 15:44:44 +08:00
parent 8657cbec87
commit d860f71e4d

View File

@@ -194,6 +194,11 @@ def decoder_model_forward_base_pp(
hidden_states = inputs_embeds hidden_states = inputs_embeds
else: else:
hidden_states = get_input_embeddings(input_ids) hidden_states = get_input_embeddings(input_ids)
# On MLU, F.embedding may return float32 output even when the weights are
# float16; cast to the dtype of the first parameter of layers[start_layer]
# to avoid dtype mismatches downstream.
target_dtype = next(layers[start_layer].parameters()).dtype
if hidden_states.dtype != target_dtype:
hidden_states = hidden_states.to(target_dtype)
else: else:
assert intermediate_tensors is not None assert intermediate_tensors is not None
hidden_states = intermediate_tensors["hidden_states"] hidden_states = intermediate_tensors["hidden_states"]