Share target model embed and head weights for nextn (#4033)
This commit is contained in:
@@ -280,7 +280,8 @@ class ForwardBatch:
|
|||||||
).to(device, non_blocking=True)
|
).to(device, non_blocking=True)
|
||||||
if (
|
if (
|
||||||
model_runner.server_args.attention_backend != "torch_native"
|
model_runner.server_args.attention_backend != "torch_native"
|
||||||
and model_runner.server_args.speculative_algorithm != "NEXTN"
|
# TODO: Fix triton kernel illegal memory access for EAGLE
|
||||||
|
and model_runner.server_args.speculative_algorithm != "EAGLE"
|
||||||
):
|
):
|
||||||
ret.extend_num_tokens = batch.extend_num_tokens
|
ret.extend_num_tokens = batch.extend_num_tokens
|
||||||
positions, ret.extend_start_loc = compute_position_triton(
|
positions, ret.extend_start_loc = compute_position_triton(
|
||||||
|
|||||||
@@ -116,14 +116,14 @@ class DeepseekV3ForCausalLMNextN(DeepseekV3ForCausalLM):
|
|||||||
self.model = DeepseekModelNextN(config, quant_config)
|
self.model = DeepseekModelNextN(config, quant_config)
|
||||||
|
|
||||||
if global_server_args_dict["enable_dp_attention"]:
|
if global_server_args_dict["enable_dp_attention"]:
|
||||||
self.model.shared_head.head = ReplicatedLinear(
|
self.lm_head = ReplicatedLinear(
|
||||||
config.hidden_size,
|
config.hidden_size,
|
||||||
config.vocab_size,
|
config.vocab_size,
|
||||||
bias=False,
|
bias=False,
|
||||||
)
|
)
|
||||||
self.logits_processor = LogitsProcessor(config, skip_all_gather=True)
|
self.logits_processor = LogitsProcessor(config, skip_all_gather=True)
|
||||||
else:
|
else:
|
||||||
self.model.shared_head.head = ParallelLMHead(
|
self.lm_head = ParallelLMHead(
|
||||||
config.vocab_size,
|
config.vocab_size,
|
||||||
config.hidden_size,
|
config.hidden_size,
|
||||||
quant_config=quant_config,
|
quant_config=quant_config,
|
||||||
@@ -139,7 +139,7 @@ class DeepseekV3ForCausalLMNextN(DeepseekV3ForCausalLM):
|
|||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
hidden_states = self.model(input_ids, positions, forward_batch)
|
hidden_states = self.model(input_ids, positions, forward_batch)
|
||||||
return self.logits_processor(
|
return self.logits_processor(
|
||||||
input_ids, hidden_states, self.model.shared_head.head, forward_batch
|
input_ids, hidden_states, self.lm_head, forward_batch
|
||||||
)
|
)
|
||||||
|
|
||||||
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
|
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
|
||||||
@@ -168,10 +168,8 @@ class DeepseekV3ForCausalLMNextN(DeepseekV3ForCausalLM):
|
|||||||
|
|
||||||
nextn_layer_prefix = "model.layers.0"
|
nextn_layer_prefix = "model.layers.0"
|
||||||
nextn_spec_weight_names = [
|
nextn_spec_weight_names = [
|
||||||
"shared_head.head",
|
|
||||||
"shared_head.norm",
|
"shared_head.norm",
|
||||||
"eh_proj",
|
"eh_proj",
|
||||||
"embed_tokens",
|
|
||||||
"enorm",
|
"enorm",
|
||||||
"hnorm",
|
"hnorm",
|
||||||
]
|
]
|
||||||
@@ -180,7 +178,11 @@ class DeepseekV3ForCausalLMNextN(DeepseekV3ForCausalLM):
|
|||||||
for name, loaded_weight in weights:
|
for name, loaded_weight in weights:
|
||||||
if not name.startswith(nextn_layer_prefix):
|
if not name.startswith(nextn_layer_prefix):
|
||||||
continue
|
continue
|
||||||
else:
|
|
||||||
|
# Use shared head and embed weights from target model
|
||||||
|
if "shared_head.head" in name or "embed_tokens" in name:
|
||||||
|
continue
|
||||||
|
|
||||||
is_decoder = True
|
is_decoder = True
|
||||||
# For nextn specific weights
|
# For nextn specific weights
|
||||||
for weight_name in nextn_spec_weight_names:
|
for weight_name in nextn_spec_weight_names:
|
||||||
|
|||||||
@@ -1179,6 +1179,17 @@ class DeepseekV2ForCausalLM(nn.Module):
|
|||||||
if is_hip_:
|
if is_hip_:
|
||||||
self_attn.w_scale *= 2.0
|
self_attn.w_scale *= 2.0
|
||||||
|
|
||||||
|
def get_embed_and_head(self):
|
||||||
|
return self.model.embed_tokens.weight, self.lm_head.weight
|
||||||
|
|
||||||
|
def set_embed_and_head(self, embed, head):
|
||||||
|
del self.model.embed_tokens.weight
|
||||||
|
del self.lm_head.weight
|
||||||
|
self.model.embed_tokens.weight = embed
|
||||||
|
self.lm_head.weight = head
|
||||||
|
torch.cuda.empty_cache()
|
||||||
|
torch.cuda.synchronize()
|
||||||
|
|
||||||
|
|
||||||
class DeepseekV3ForCausalLM(DeepseekV2ForCausalLM):
|
class DeepseekV3ForCausalLM(DeepseekV2ForCausalLM):
|
||||||
pass
|
pass
|
||||||
|
|||||||
@@ -270,10 +270,11 @@ class ServerArgs:
|
|||||||
)
|
)
|
||||||
|
|
||||||
# Speculative Decoding
|
# Speculative Decoding
|
||||||
if (
|
if self.speculative_algorithm == "NEXTN":
|
||||||
self.speculative_algorithm == "EAGLE"
|
# NEXTN shares the same implementation of EAGLE
|
||||||
or self.speculative_algorithm == "NEXTN"
|
self.speculative_algorithm = "EAGLE"
|
||||||
):
|
|
||||||
|
if self.speculative_algorithm == "EAGLE":
|
||||||
self.disable_overlap_schedule = True
|
self.disable_overlap_schedule = True
|
||||||
self.prefill_only_one_req = True
|
self.prefill_only_one_req = True
|
||||||
self.disable_cuda_graph_padding = True
|
self.disable_cuda_graph_padding = True
|
||||||
|
|||||||
@@ -83,7 +83,6 @@ class EAGLEWorker(TpModelWorker):
|
|||||||
self.server_args = server_args
|
self.server_args = server_args
|
||||||
|
|
||||||
# Share the embedding and lm_head
|
# Share the embedding and lm_head
|
||||||
if not self.speculative_algorithm.is_nextn():
|
|
||||||
embed, head = self.target_worker.model_runner.model.get_embed_and_head()
|
embed, head = self.target_worker.model_runner.model.get_embed_and_head()
|
||||||
if server_args.speculative_token_map is not None:
|
if server_args.speculative_token_map is not None:
|
||||||
head = head.clone()
|
head = head.clone()
|
||||||
@@ -94,12 +93,6 @@ class EAGLEWorker(TpModelWorker):
|
|||||||
else:
|
else:
|
||||||
self.hot_token_id = None
|
self.hot_token_id = None
|
||||||
self.model_runner.model.set_embed_and_head(embed, head)
|
self.model_runner.model.set_embed_and_head(embed, head)
|
||||||
else:
|
|
||||||
if server_args.speculative_token_map is not None:
|
|
||||||
raise NotImplementedError(
|
|
||||||
"NEXTN does not support speculative-token-map now"
|
|
||||||
)
|
|
||||||
self.hot_token_id = None
|
|
||||||
self.model_runner.server_args.disable_cuda_graph = backup_disable_cuda_graph
|
self.model_runner.server_args.disable_cuda_graph = backup_disable_cuda_graph
|
||||||
|
|
||||||
# Create multi-step attn backends and cuda graph runners
|
# Create multi-step attn backends and cuda graph runners
|
||||||
|
|||||||
@@ -5,24 +5,16 @@ class SpeculativeAlgorithm(IntEnum):
|
|||||||
NONE = auto()
|
NONE = auto()
|
||||||
EAGLE = auto()
|
EAGLE = auto()
|
||||||
|
|
||||||
# NEXTN spec decoding is for DeepSeek V3/R1
|
|
||||||
# currently it's implemented based on EAGLE
|
|
||||||
NEXTN = auto()
|
|
||||||
|
|
||||||
def is_none(self):
|
def is_none(self):
|
||||||
return self == SpeculativeAlgorithm.NONE
|
return self == SpeculativeAlgorithm.NONE
|
||||||
|
|
||||||
def is_eagle(self):
|
def is_eagle(self):
|
||||||
return self == SpeculativeAlgorithm.EAGLE or self == SpeculativeAlgorithm.NEXTN
|
return self == SpeculativeAlgorithm.EAGLE
|
||||||
|
|
||||||
def is_nextn(self):
|
|
||||||
return self == SpeculativeAlgorithm.NEXTN
|
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def from_string(name: str):
|
def from_string(name: str):
|
||||||
name_map = {
|
name_map = {
|
||||||
"EAGLE": SpeculativeAlgorithm.EAGLE,
|
"EAGLE": SpeculativeAlgorithm.EAGLE,
|
||||||
"NEXTN": SpeculativeAlgorithm.NEXTN,
|
|
||||||
None: SpeculativeAlgorithm.NONE,
|
None: SpeculativeAlgorithm.NONE,
|
||||||
}
|
}
|
||||||
if name is not None:
|
if name is not None:
|
||||||
|
|||||||
@@ -62,6 +62,8 @@ def export_nextn_layer_parameters(input_dir, output_dir, nextn_layer_id):
|
|||||||
continue
|
continue
|
||||||
|
|
||||||
for key in matching_keys:
|
for key in matching_keys:
|
||||||
|
if "embed_tokens" in key or "shared_head.head" in key:
|
||||||
|
continue
|
||||||
new_key = key.replace(prefix, "model.layers.0")
|
new_key = key.replace(prefix, "model.layers.0")
|
||||||
params[new_key] = f.get_tensor(key)
|
params[new_key] = f.get_tensor(key)
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user