Testing dynamic register

This commit is contained in:
Chranos
2026-02-05 18:21:31 +08:00
parent 05605419e3
commit 4d0da98b9e
2 changed files with 69 additions and 38 deletions

View File

@@ -364,6 +364,7 @@ class Base(nn.Module):
Replace modules with vLLM optimized versions.
Uses tp_plan for tensor parallel style selection.
Note: lm_head is NOT replaced here - it's created at wrapper level by CausalMixin.
"""
logger.info("Replacing modules with vLLM optimized versions...")
replaced_count = 0
@@ -371,6 +372,9 @@ class Base(nn.Module):
# Get tensor parallel plan
tp_plan = self._get_tp_plan() if self.tp_group.world_size > 1 else {}
# Modules to skip replacement (handled at wrapper level)
skip_modules = {"lm_head", "score", "classifier"}
def _recursive_replace(module: nn.Module, prefix: str = ""):
nonlocal replaced_count
@@ -378,6 +382,11 @@ class Base(nn.Module):
# Skip PPMissingLayer
if isinstance(child, PPMissingLayer):
continue
# Skip modules that are handled at wrapper level
if name in skip_modules:
logger.debug("Skipping %s (handled at wrapper level)", name)
continue
qual_name = maybe_prefix(prefix, name)
new_module = None
@@ -552,20 +561,22 @@ class Base(nn.Module):
model_inputs["position_ids"] = positions
# Forward through model
# Note: return_dict=False returns tuple, first element is last hidden state
with torch.no_grad():
outputs = self.model(
**model_inputs,
use_cache=False,
return_dict=True,
output_hidden_states=True,
return_dict=False,
attention_instances=self.attention_instances,
)
# Get hidden states
if outputs.hidden_states is not None:
hidden_states = outputs.hidden_states[-1]
# Get hidden states from model output
# For models using return_dict=False, outputs is a tuple
# outputs[0] is usually the last hidden state
if isinstance(outputs, tuple):
hidden_states = outputs[0]
else:
hidden_states = outputs.logits
hidden_states = outputs
# Remove batch dimension
if hidden_states.dim() == 3 and hidden_states.size(0) == 1: