forked from EngineX-Cambricon/enginex-mlu370-vllm
testing dynamic register
This commit is contained in:
@@ -364,6 +364,7 @@ class Base(nn.Module):
|
||||
Replace modules with vLLM optimized versions.
|
||||
|
||||
Uses tp_plan for tensor parallel style selection.
|
||||
Note: lm_head is NOT replaced here - it's created at wrapper level by CausalMixin.
|
||||
"""
|
||||
logger.info("Replacing modules with vLLM optimized versions...")
|
||||
replaced_count = 0
|
||||
@@ -371,6 +372,9 @@ class Base(nn.Module):
|
||||
# Get tensor parallel plan
|
||||
tp_plan = self._get_tp_plan() if self.tp_group.world_size > 1 else {}
|
||||
|
||||
# Modules to skip replacement (handled at wrapper level)
|
||||
skip_modules = {"lm_head", "score", "classifier"}
|
||||
|
||||
def _recursive_replace(module: nn.Module, prefix: str = ""):
|
||||
nonlocal replaced_count
|
||||
|
||||
@@ -378,6 +382,11 @@ class Base(nn.Module):
|
||||
# Skip PPMissingLayer
|
||||
if isinstance(child, PPMissingLayer):
|
||||
continue
|
||||
|
||||
# Skip modules that are handled at wrapper level
|
||||
if name in skip_modules:
|
||||
logger.debug("Skipping %s (handled at wrapper level)", name)
|
||||
continue
|
||||
|
||||
qual_name = maybe_prefix(prefix, name)
|
||||
new_module = None
|
||||
@@ -552,20 +561,22 @@ class Base(nn.Module):
|
||||
model_inputs["position_ids"] = positions
|
||||
|
||||
# Forward through model
|
||||
# Note: return_dict=False returns tuple, first element is last hidden state
|
||||
with torch.no_grad():
|
||||
outputs = self.model(
|
||||
**model_inputs,
|
||||
use_cache=False,
|
||||
return_dict=True,
|
||||
output_hidden_states=True,
|
||||
return_dict=False,
|
||||
attention_instances=self.attention_instances,
|
||||
)
|
||||
|
||||
# Get hidden states
|
||||
if outputs.hidden_states is not None:
|
||||
hidden_states = outputs.hidden_states[-1]
|
||||
# Get hidden states from model output
|
||||
# For models using return_dict=False, outputs is a tuple
|
||||
# outputs[0] is usually the last hidden state
|
||||
if isinstance(outputs, tuple):
|
||||
hidden_states = outputs[0]
|
||||
else:
|
||||
hidden_states = outputs.logits
|
||||
hidden_states = outputs
|
||||
|
||||
# Remove batch dimension
|
||||
if hidden_states.dim() == 3 and hidden_states.size(0) == 1:
|
||||
|
||||
Reference in New Issue
Block a user