add deepseekv3 and llama4
This commit is contained in:
@@ -194,6 +194,11 @@ def decoder_model_forward_base_pp(
|
|||||||
hidden_states = inputs_embeds
|
hidden_states = inputs_embeds
|
||||||
else:
|
else:
|
||||||
hidden_states = get_input_embeddings(input_ids)
|
hidden_states = get_input_embeddings(input_ids)
|
||||||
|
# MLU F.embedding may output float32 even with float16 weights;
|
||||||
|
# cast to model dtype to avoid dtype mismatches downstream.
|
||||||
|
target_dtype = next(layers[start_layer].parameters()).dtype
|
||||||
|
if hidden_states.dtype != target_dtype:
|
||||||
|
hidden_states = hidden_states.to(target_dtype)
|
||||||
else:
|
else:
|
||||||
assert intermediate_tensors is not None
|
assert intermediate_tensors is not None
|
||||||
hidden_states = intermediate_tensors["hidden_states"]
|
hidden_states = intermediate_tensors["hidden_states"]
|
||||||
|
|||||||
Reference in New Issue
Block a user