diff --git a/vllm_ascend/worker/worker.py b/vllm_ascend/worker/worker.py index 303dae36..2937aa4a 100644 --- a/vllm_ascend/worker/worker.py +++ b/vllm_ascend/worker/worker.py @@ -170,23 +170,25 @@ class NPUWorker(WorkerBase): hidden_size = self.vllm_config.model_config.hf_config.hidden_size model = self.model_runner.model - for name, param in model.named_parameters(): - if 'w2_weight' in name and param.shape[2] == hidden_size: - parts = name.split('.') - param_name = parts[-1] - parent_module = model.get_submodule(".".join(parts[:-1])) + if tags is None or "weights" in tags: + for name, param in model.named_parameters(): + if 'w2_weight' in name and param.shape[2] == hidden_size: + parts = name.split('.') + param_name = parts[-1] + parent_module = model.get_submodule(".".join(parts[:-1])) - w2_data = param.transpose(1, 2) - w2_data = torch.nn.Parameter(w2_data, requires_grad=False) - setattr(parent_module, param_name, w2_data) - elif 'w13_weight' in name and param.shape[1] == hidden_size: - parts = name.split('.') - param_name = parts[-1] - parent_module = model.get_submodule(".".join(parts[:-1])) + w2_data = param.transpose(1, 2) + w2_data = torch.nn.Parameter(w2_data, requires_grad=False) + setattr(parent_module, param_name, w2_data) + elif 'w13_weight' in name and param.shape[1] == hidden_size: + parts = name.split('.') + param_name = parts[-1] + parent_module = model.get_submodule(".".join(parts[:-1])) - w13_data = param.transpose(1, 2) - w13_data = torch.nn.Parameter(w13_data, requires_grad=False) - setattr(parent_module, param_name, w13_data) + w13_data = param.transpose(1, 2) + w13_data = torch.nn.Parameter(w13_data, + requires_grad=False) + setattr(parent_module, param_name, w13_data) # Restore the buffers after level 2 sleep if len(self._sleep_saved_buffers):