[refactor] Remove unnecessary attributes from set_ascend_forward_context (#5204)

### What this PR does / why we need it?
Remove the following unnecessary attributes from set_ascend_forward_context:
1. prefetch_stream
2. weight_prefetch_method
### Does this PR introduce _any_ user-facing change?
no
### How was this patch tested?

- vLLM version: v0.12.0
- vLLM main:
ad32e3e19c

Signed-off-by: Wang Kunpeng <1289706727@qq.com>
This commit is contained in:
Wang Kunpeng
2025-12-23 08:49:52 +08:00
committed by GitHub
parent 95e8a52156
commit c3a8d13ca7
10 changed files with 55 additions and 83 deletions

View File

@@ -286,8 +286,8 @@ def test_select_experts(
with patch("vllm_ascend.ops.fused_moe.experts_selector._native_grouped_topk" with patch("vllm_ascend.ops.fused_moe.experts_selector._native_grouped_topk"
) as mock_native_grouped_topk, \ ) as mock_native_grouped_topk, \
patch('vllm_ascend.ops.fused_moe.experts_selector.get_forward_context', patch('vllm_ascend.ops.fused_moe.experts_selector.get_weight_prefetch_method',
return_value=MagicMock(weight_prefetch_method=MagicMock())): return_value=MagicMock()):
mock_native_grouped_topk.side_effect = lambda x, num_groups, k: torch.randn_like( mock_native_grouped_topk.side_effect = lambda x, num_groups, k: torch.randn_like(
x) x)
@@ -323,8 +323,8 @@ def test_select_experts(
@pytest.mark.parametrize("device", DEVICE) @pytest.mark.parametrize("device", DEVICE)
def test_select_experts_invalid_scoring_func(device: str): def test_select_experts_invalid_scoring_func(device: str):
with patch('vllm_ascend.ops.fused_moe.experts_selector.get_forward_context', with patch('vllm_ascend.ops.fused_moe.experts_selector.get_weight_prefetch_method',
return_value=MagicMock(weight_prefetch_method=MagicMock())), \ return_value=MagicMock()), \
pytest.raises(ValueError, pytest.raises(ValueError,
match="Unsupported scoring function: invalid"): match="Unsupported scoring function: invalid"):
select_experts(hidden_states=torch.randn(1, 128, device=device), select_experts(hidden_states=torch.randn(1, 128, device=device),
@@ -336,17 +336,3 @@ def test_select_experts_invalid_scoring_func(device: str):
gc.collect() gc.collect()
torch.npu.empty_cache() torch.npu.empty_cache()
torch.npu.reset_peak_memory_stats() torch.npu.reset_peak_memory_stats()
@pytest.mark.parametrize("device", DEVICE)
def test_select_experts_missing_group_params(device: str):
with pytest.raises(AssertionError):
select_experts(hidden_states=torch.randn(1, 128, device=device),
router_logits=torch.randn(1, 64, device=device),
top_k=2,
use_grouped_topk=True,
renormalize=False,
scoring_func="softmax")
gc.collect()
torch.npu.empty_cache()
torch.npu.reset_peak_memory_stats()

View File

@@ -92,15 +92,14 @@ def mock_dist_env(mocker: MockerFixture):
mock_moe_comm_method.finalize.side_effect = mock_finalize mock_moe_comm_method.finalize.side_effect = mock_finalize
dp_metadata = MagicMock(num_tokens_across_dp_cpu=[5, 5]) dp_metadata = MagicMock(num_tokens_across_dp_cpu=[5, 5])
mock_weight_prefetch_method = MagicMock() mock_weight_prefetch_method = MagicMock()
mock_forward_context_obj = MagicMock( mock_forward_context_obj = MagicMock(moe_comm_method=mock_moe_comm_method,
moe_comm_method=mock_moe_comm_method,
moe_comm_type=MoECommType.MC2, moe_comm_type=MoECommType.MC2,
max_tokens_across_dp=10, max_tokens_across_dp=10,
dp_metadata=dp_metadata, dp_metadata=dp_metadata,
mc2_mask=torch.zeros(16, dtype=torch.bool), mc2_mask=torch.zeros(
16, dtype=torch.bool),
padded_num_tokens=16, padded_num_tokens=16,
with_quant=False, with_quant=False)
weight_prefetch_method=mock_weight_prefetch_method)
with patch('torch.distributed.get_rank', return_value=0), \ with patch('torch.distributed.get_rank', return_value=0), \
patch('torch.distributed.get_world_size', return_value=4), \ patch('torch.distributed.get_world_size', return_value=4), \
@@ -133,8 +132,8 @@ def mock_dist_env(mocker: MockerFixture):
return_value=None), \ return_value=None), \
patch('vllm_ascend.ops.fused_moe.moe_comm_method.AllGatherCommImpl._get_token_dispatcher', patch('vllm_ascend.ops.fused_moe.moe_comm_method.AllGatherCommImpl._get_token_dispatcher',
return_value=None), \ return_value=None), \
patch('vllm_ascend.ops.fused_moe.experts_selector.get_forward_context', patch('vllm_ascend.ops.fused_moe.experts_selector.get_weight_prefetch_method',
return_value=mock_forward_context_obj): return_value=mock_weight_prefetch_method):
yield { yield {
'mock_forward_context_obj': mock_forward_context_obj, 'mock_forward_context_obj': mock_forward_context_obj,

View File

@@ -63,21 +63,18 @@ class TestAscendW8A8LinearMethod(TestBase):
self.assertEqual(params['weight_scale'].shape, (10, 1)) self.assertEqual(params['weight_scale'].shape, (10, 1))
self.assertEqual(params['weight_offset'].shape, (10, 1)) self.assertEqual(params['weight_offset'].shape, (10, 1))
@patch("vllm_ascend.quantization.w8a8.get_forward_context") @patch("vllm_ascend.quantization.w8a8.get_weight_prefetch_method")
@patch("torch.ops.vllm.quantize") @patch("torch.ops.vllm.quantize")
@patch("torch_npu.npu_quant_matmul") @patch("torch_npu.npu_quant_matmul")
def test_apply_with_x_not_int8(self, mock_npu_quant_matmul, mock_quantize, def test_apply_with_x_not_int8(self, mock_npu_quant_matmul, mock_quantize,
mock_get_forward_context): mock_get_weight_prefetch_method):
layer = MagicMock() layer = MagicMock()
layer.aclnn_input_scale = 0.1 layer.aclnn_input_scale = 0.1
layer.aclnn_input_offset = 0.2 layer.aclnn_input_offset = 0.2
layer.weight = torch.randn(128, 256) layer.weight = torch.randn(128, 256)
layer.deq_scale = 0.3 layer.deq_scale = 0.3
mock_forward_context = MagicMock() mock_get_weight_prefetch_method.return_value = MagicMock()
mock_get_forward_context.return_value = mock_forward_context
mock_weight_prefetch_method = MagicMock()
mock_forward_context.weight_prefetch_method = mock_weight_prefetch_method
x = torch.randn(32, 128) x = torch.randn(32, 128)
bias = torch.randn(256) bias = torch.randn(256)

View File

@@ -1,7 +1,7 @@
import math import math
from contextlib import contextmanager from contextlib import contextmanager
from enum import Enum from enum import Enum
from typing import TYPE_CHECKING, Any, Optional from typing import Any, Optional
import torch import torch
from vllm.config import CUDAGraphMode, VllmConfig from vllm.config import CUDAGraphMode, VllmConfig
@@ -16,11 +16,6 @@ from vllm_ascend.utils import (AscendDeviceType, enable_sp, flashcomm2_enable,
get_ascend_device_type, has_layer_idx, get_ascend_device_type, has_layer_idx,
is_moe_model) is_moe_model)
if TYPE_CHECKING:
from vllm_ascend.ops.weight_prefetch import WeightPrefetchMethod
else:
WeightPrefetchMethod = None
class MoECommType(Enum): class MoECommType(Enum):
ALLGATHER = 0 ALLGATHER = 0
@@ -41,9 +36,7 @@ def set_ascend_forward_context(
num_actual_tokens: Optional[int] = None, num_actual_tokens: Optional[int] = None,
aclgraph_runtime_mode: CUDAGraphMode = CUDAGraphMode.NONE, aclgraph_runtime_mode: CUDAGraphMode = CUDAGraphMode.NONE,
batch_descriptor: Optional[BatchDescriptor] = None, batch_descriptor: Optional[BatchDescriptor] = None,
prefetch_stream: torch.npu.Stream = None,
model_instance: torch.nn.Module = None, model_instance: torch.nn.Module = None,
weight_prefetch_method: Optional[WeightPrefetchMethod] = None,
is_mtp_model=False): is_mtp_model=False):
"""A context manager that stores the current forward context, """A context manager that stores the current forward context,
can be attention metadata, etc. can be attention metadata, etc.
@@ -116,13 +109,10 @@ def set_ascend_forward_context(
forward_context.layer_idx is not None and \ forward_context.layer_idx is not None and \
num_tokens is not None and num_tokens < 500 num_tokens is not None and num_tokens < 500
if prefetch_mlp_enabled: if prefetch_mlp_enabled:
forward_context.prefetch_stream = prefetch_stream
forward_context.model_instance = model_instance
forward_context.prefetch_mlp_gate_up_proj = False forward_context.prefetch_mlp_gate_up_proj = False
forward_context.prefetch_mlp_down_proj = False forward_context.prefetch_mlp_down_proj = False
forward_context.prefetch_mlp_enabled = prefetch_mlp_enabled forward_context.prefetch_mlp_enabled = prefetch_mlp_enabled
forward_context.model_instance = model_instance forward_context.model_instance = model_instance
forward_context.weight_prefetch_method = weight_prefetch_method
forward_context.is_mtp_model = is_mtp_model forward_context.is_mtp_model = is_mtp_model
if num_tokens is None and attn_metadata is not None: if num_tokens is None and attn_metadata is not None:

View File

@@ -18,7 +18,8 @@ from typing import Callable, Optional
import torch import torch
import torch_npu import torch_npu
from vllm.forward_context import get_forward_context
from vllm_ascend.utils import get_weight_prefetch_method
def select_experts(hidden_states: torch.Tensor, def select_experts(hidden_states: torch.Tensor,
@@ -56,7 +57,7 @@ def select_experts(hidden_states: torch.Tensor,
topk_ids: selected expert IDs of shape (num_tokens, top_k). topk_ids: selected expert IDs of shape (num_tokens, top_k).
""" """
# prefetch w1_w3_proj.weight preprocess # prefetch w1_w3_proj.weight preprocess
weight_prefetch_method = get_forward_context().weight_prefetch_method weight_prefetch_method = get_weight_prefetch_method()
if weight_prefetch_method: if weight_prefetch_method:
weight_prefetch_method.maybe_prefetch_moe_weight_preprocess( weight_prefetch_method.maybe_prefetch_moe_weight_preprocess(
hidden_states, "gate_up") hidden_states, "gate_up")

View File

@@ -24,7 +24,8 @@ from vllm.triton_utils import HAS_TRITON
from vllm_ascend.ascend_forward_context import MoECommType from vllm_ascend.ascend_forward_context import MoECommType
from vllm_ascend.utils import (AscendDeviceType, dispose_tensor, from vllm_ascend.utils import (AscendDeviceType, dispose_tensor,
enable_custom_op, get_ascend_device_type) enable_custom_op, get_ascend_device_type,
get_weight_prefetch_method)
def _custom_gmm_swiglu_enabled(fusion, dynamic_eplb): def _custom_gmm_swiglu_enabled(fusion, dynamic_eplb):
@@ -100,7 +101,7 @@ def quant_apply_mlp(hidden_states: torch.Tensor,
bias1, bias2 = None, None bias1, bias2 = None, None
_output_dtype = w2_scale[0].dtype _output_dtype = w2_scale[0].dtype
weight_prefetch_method = get_forward_context().weight_prefetch_method weight_prefetch_method = get_weight_prefetch_method()
if weight_prefetch_method: if weight_prefetch_method:
weight_prefetch_method.maybe_prefetch_moe_weight_postprocess( weight_prefetch_method.maybe_prefetch_moe_weight_postprocess(
hidden_states) hidden_states)

View File

@@ -119,16 +119,16 @@ def _maybe_prefetch_mlp_gate_up_proj_impl(x_dependency: torch.Tensor,
if not getattr(forward_context, 'prefetch_mlp_enabled', False): if not getattr(forward_context, 'prefetch_mlp_enabled', False):
return return
model_instance = forward_context.model_instance model_instance = forward_context.model_instance
prefetch_stream = forward_context.prefetch_stream weight_prefetch_stream = prefetch_stream()
layer_idx = int(prefix.split('.')[2]) layer_idx = int(prefix.split('.')[2])
# start point of gate_up_proj weight prefetch # start point of gate_up_proj weight prefetch
if prefix.split('.')[-2] == "self_attn": if prefix.split('.')[-2] == "self_attn":
forward_context.prefetch_mlp_gate_up_proj = True forward_context.prefetch_mlp_gate_up_proj = True
if forward_context.prefetch_mlp_gate_up_proj: if forward_context.prefetch_mlp_gate_up_proj:
prefetch_stream.wait_stream(torch.npu.current_stream()) weight_prefetch_stream.wait_stream(torch.npu.current_stream())
with torch.npu.stream(prefetch_stream): with torch.npu.stream(weight_prefetch_stream):
mlp_gate_up_prefetch_size = envs_ascend.VLLM_ASCEND_MLP_GATE_UP_PREFETCH_SIZE mlp_gate_up_prefetch_size = envs_ascend.VLLM_ASCEND_MLP_GATE_UP_PREFETCH_SIZE
torch_npu.npu_prefetch( torch_npu.npu_prefetch(
model_instance.model.layers[layer_idx].mlp.gate_up_proj.weight, model_instance.model.layers[layer_idx].mlp.gate_up_proj.weight,
@@ -178,13 +178,13 @@ def _maybe_prefetch_mlp_down_proj_impl(x_dependency: torch.Tensor) -> None:
return return
forward_context.prefetch_mlp_down_proj = True forward_context.prefetch_mlp_down_proj = True
model_instance = forward_context.model_instance model_instance = forward_context.model_instance
prefetch_stream = forward_context.prefetch_stream weight_prefetch_stream = prefetch_stream()
layer_idx = forward_context.layer_idx layer_idx = forward_context.layer_idx
# start point of down_proj weight prefetch # start point of down_proj weight prefetch
prefetch_stream.wait_stream(torch.npu.current_stream()) weight_prefetch_stream.wait_stream(torch.npu.current_stream())
with torch.npu.stream(prefetch_stream): with torch.npu.stream(weight_prefetch_stream):
mlp_down_prefetch_size = envs_ascend.VLLM_ASCEND_MLP_DOWN_PREFETCH_SIZE mlp_down_prefetch_size = envs_ascend.VLLM_ASCEND_MLP_DOWN_PREFETCH_SIZE
torch_npu.npu_prefetch( torch_npu.npu_prefetch(
model_instance.model.layers[layer_idx].mlp.down_proj.weight, model_instance.model.layers[layer_idx].mlp.down_proj.weight,
@@ -208,9 +208,9 @@ def _maybe_wait_prefetch_done_impl(x: torch.Tensor) -> None:
return return
if forward_context.prefetch_mlp_gate_up_proj or \ if forward_context.prefetch_mlp_gate_up_proj or \
forward_context.prefetch_mlp_down_proj: forward_context.prefetch_mlp_down_proj:
prefetch_stream = forward_context.prefetch_stream weight_prefetch_stream = prefetch_stream()
# wait until prefetch done # wait until prefetch done
torch.npu.current_stream().wait_stream(prefetch_stream) torch.npu.current_stream().wait_stream(weight_prefetch_stream)
forward_context.prefetch_mlp_gate_up_proj = False forward_context.prefetch_mlp_gate_up_proj = False
forward_context.prefetch_mlp_down_proj = False forward_context.prefetch_mlp_down_proj = False
return return

View File

@@ -19,10 +19,10 @@ from typing import Any, Dict, Optional
import torch import torch
import torch_npu import torch_npu
from vllm.forward_context import get_forward_context
from vllm_ascend.utils import (COMPRESSED_TENSORS_METHOD, AscendDeviceType, from vllm_ascend.utils import (COMPRESSED_TENSORS_METHOD, AscendDeviceType,
get_ascend_device_type, maybe_trans_nz) get_ascend_device_type,
get_weight_prefetch_method, maybe_trans_nz)
def quant_per_tensor(in_tensor: torch.Tensor, def quant_per_tensor(in_tensor: torch.Tensor,
@@ -98,12 +98,7 @@ class AscendW8A8LinearMethod:
) -> torch.Tensor: ) -> torch.Tensor:
if x.dtype != torch.int8: if x.dtype != torch.int8:
layer_cls_name = layer.__class__.__name__ layer_cls_name = layer.__class__.__name__
try: weight_prefetch_method = get_weight_prefetch_method()
weight_prefetch_method = get_forward_context(
).weight_prefetch_method
except AssertionError:
weight_prefetch_method = None
# prefetch qkvo_proj.weight preprocess # prefetch qkvo_proj.weight preprocess
if weight_prefetch_method: if weight_prefetch_method:
weight_prefetch_method.maybe_prefetch_attn_weight_preprocess( weight_prefetch_method.maybe_prefetch_attn_weight_preprocess(

View File

@@ -34,7 +34,7 @@ from vllm.logger import logger
from vllm.sequence import IntermediateTensors from vllm.sequence import IntermediateTensors
import vllm_ascend.envs as envs_ascend import vllm_ascend.envs as envs_ascend
from vllm_ascend.ascend_config import get_ascend_config from vllm_ascend.ascend_config import WeightPrefetchConfig, get_ascend_config
if TYPE_CHECKING: if TYPE_CHECKING:
from vllm.config import VllmConfig from vllm.config import VllmConfig
@@ -52,6 +52,7 @@ ACL_FORMAT_FRACTAL_NZ = 29
_CUSTOM_OP_ENABLED = None _CUSTOM_OP_ENABLED = None
_CURRENT_STREAM = None _CURRENT_STREAM = None
_PREFETCH_STREAM = None _PREFETCH_STREAM = None
_WEIGHT_PREFETCH_METHOD = None
_GLOBAL_STREAM = None _GLOBAL_STREAM = None
_SHARED_EXPERTS_CALCULATION_STREAM = None _SHARED_EXPERTS_CALCULATION_STREAM = None
_ASCEND_CUSTOMOP_IS_REIGISTERED = False _ASCEND_CUSTOMOP_IS_REIGISTERED = False
@@ -309,6 +310,18 @@ def prefetch_stream() -> torch.npu.Stream:
return _PREFETCH_STREAM return _PREFETCH_STREAM
def set_weight_prefetch_method(weight_prefetch_config: WeightPrefetchConfig):
global _WEIGHT_PREFETCH_METHOD
if _WEIGHT_PREFETCH_METHOD is None:
from vllm_ascend.ops.weight_prefetch import WeightPrefetchMethod
_WEIGHT_PREFETCH_METHOD = WeightPrefetchMethod(weight_prefetch_config)
return _WEIGHT_PREFETCH_METHOD
def get_weight_prefetch_method():
return _WEIGHT_PREFETCH_METHOD
def global_stream() -> torch.npu.Stream: def global_stream() -> torch.npu.Stream:
global _GLOBAL_STREAM global _GLOBAL_STREAM
if _GLOBAL_STREAM is None: if _GLOBAL_STREAM is None:

View File

@@ -82,7 +82,6 @@ from vllm.v1.worker.gpu_model_runner import (AsyncGPUModelRunnerOutput,
from vllm.v1.worker.kv_connector_model_runner_mixin import KVConnectorOutput from vllm.v1.worker.kv_connector_model_runner_mixin import KVConnectorOutput
from vllm.v1.worker.utils import AttentionGroup from vllm.v1.worker.utils import AttentionGroup
import vllm_ascend.envs as envs_ascend
from vllm_ascend.ascend_config import get_ascend_config from vllm_ascend.ascend_config import get_ascend_config
from vllm_ascend.attention.attention_mask import AttentionMaskBuilder from vllm_ascend.attention.attention_mask import AttentionMaskBuilder
from vllm_ascend.attention.attention_v1 import AscendAttentionState from vllm_ascend.attention.attention_v1 import AscendAttentionState
@@ -106,7 +105,6 @@ from vllm_ascend.eplb.core.eplb_worker import EplbProcess
from vllm_ascend.eplb.eplb_updator import EplbUpdator from vllm_ascend.eplb.eplb_updator import EplbUpdator
from vllm_ascend.eplb.utils import model_register from vllm_ascend.eplb.utils import model_register
from vllm_ascend.ops.rotary_embedding import set_cos_and_sin, update_cos_sin from vllm_ascend.ops.rotary_embedding import set_cos_and_sin, update_cos_sin
from vllm_ascend.ops.weight_prefetch import WeightPrefetchMethod
from vllm_ascend.patch.worker.patch_module import patch_torch_npu_argsort from vllm_ascend.patch.worker.patch_module import patch_torch_npu_argsort
from vllm_ascend.sample.logits_processor import build_logitsprocs from vllm_ascend.sample.logits_processor import build_logitsprocs
from vllm_ascend.sample.sampler import AscendSampler from vllm_ascend.sample.sampler import AscendSampler
@@ -115,7 +113,8 @@ from vllm_ascend.spec_decode.eagle_proposer import EagleProposer
from vllm_ascend.spec_decode.mtp_proposer import MtpProposer from vllm_ascend.spec_decode.mtp_proposer import MtpProposer
from vllm_ascend.utils import (AscendDeviceType, ProfileExecuteDuration, from vllm_ascend.utils import (AscendDeviceType, ProfileExecuteDuration,
enable_sp, get_ascend_device_type, is_moe_model, enable_sp, get_ascend_device_type, is_moe_model,
lmhead_tp_enable, maybe_trans_nz) lmhead_tp_enable, maybe_trans_nz,
set_weight_prefetch_method)
from vllm_ascend.worker.npu_input_batch import NPUInputBatch from vllm_ascend.worker.npu_input_batch import NPUInputBatch
from vllm_ascend.ascend_forward_context import ( # isort: skip from vllm_ascend.ascend_forward_context import ( # isort: skip
@@ -209,18 +208,13 @@ class NPUModelRunner(GPUModelRunner):
self.pcp_rank = 0 self.pcp_rank = 0
if self.pcp_size > 1: if self.pcp_size > 1:
self.model_config.max_model_len += 2 * self.pcp_size * self.max_num_reqs self.model_config.max_model_len += 2 * self.pcp_size * self.max_num_reqs
if envs_ascend.VLLM_ASCEND_ENABLE_PREFETCH_MLP:
self.prefetch_stream = torch.npu.Stream(device=device)
else:
self.prefetch_stream = None
self.sampler = AscendSampler() self.sampler = AscendSampler()
self.attn_mask = None self.attn_mask = None
self.attn_state = None self.attn_state = None
# Ascend-specific configurations # Ascend-specific configurations
self.ascend_config = get_ascend_config() self.ascend_config = get_ascend_config()
self.weight_prefetch_method = WeightPrefetchMethod( set_weight_prefetch_method(self.ascend_config.weight_prefetch_config)
self.ascend_config.weight_prefetch_config)
# Dump / PrecisionDebugger configuration now comes from AscendConfig # Dump / PrecisionDebugger configuration now comes from AscendConfig
dump_cfg = self.ascend_config.dump_config dump_cfg = self.ascend_config.dump_config
self.dump_enable = dump_cfg.enable_dump self.dump_enable = dump_cfg.enable_dump
@@ -1420,9 +1414,7 @@ class NPUModelRunner(GPUModelRunner):
batch_descriptor=batch_descriptor, batch_descriptor=batch_descriptor,
num_actual_tokens=scheduler_output. num_actual_tokens=scheduler_output.
total_num_scheduled_tokens, total_num_scheduled_tokens,
prefetch_stream=self.prefetch_stream, model_instance=self.model):
model_instance=self.model,
weight_prefetch_method=self.weight_prefetch_method):
self.maybe_setup_kv_connector(scheduler_output) self.maybe_setup_kv_connector(scheduler_output)
hidden_states = self._generate_process_reqs_hidden_states( hidden_states = self._generate_process_reqs_hidden_states(
@@ -2133,9 +2125,7 @@ class NPUModelRunner(GPUModelRunner):
num_actual_tokens=0, num_actual_tokens=0,
aclgraph_runtime_mode=aclgraph_runtime_mode, aclgraph_runtime_mode=aclgraph_runtime_mode,
batch_descriptor=batch_descriptor, batch_descriptor=batch_descriptor,
prefetch_stream=self.prefetch_stream, model_instance=self.model):
model_instance=self.model,
weight_prefetch_method=self.weight_prefetch_method):
hidden_states = self._generate_dummy_run_hidden_states( hidden_states = self._generate_dummy_run_hidden_states(
input_ids, positions, num_tokens_padded, input_ids, positions, num_tokens_padded,
intermediate_tensors, inputs_embeds) intermediate_tensors, inputs_embeds)