Add DeepSeek-V3 and Llama 4
This commit is contained in:
@@ -194,6 +194,11 @@ def decoder_model_forward_base_pp(
|
||||
hidden_states = inputs_embeds
|
||||
else:
|
||||
hidden_states = get_input_embeddings(input_ids)
|
||||
# MLU F.embedding may output float32 even with float16 weights;
|
||||
# cast to model dtype to avoid dtype mismatches downstream.
|
||||
target_dtype = next(layers[start_layer].parameters()).dtype
|
||||
if hidden_states.dtype != target_dtype:
|
||||
hidden_states = hidden_states.to(target_dtype)
|
||||
else:
|
||||
assert intermediate_tensors is not None
|
||||
hidden_states = intermediate_tensors["hidden_states"]
|
||||
|
||||
Reference in New Issue
Block a user