[PERF]support H2P communication optimization for PanguProMoe (#1463)
### What this PR does / why we need it? In this PR, we support H2P communication optimization when running PanguProMoE with dp_size > 1. H2P uses `reduce_scatter` and `all_gather` to replace `all_reduce` to improve performance: original layer: input_layernorm --> attn --> tp all_reduce --> post_attention_layernorm --> dp all_gather --> moe/mlp --> dp reduce_scatter --> tp all_reduce now: input_layernorm --> tp all_gather --> attn --> tp reduce_scatter --> post_attention_layernorm --> all_rank all_gather --> moe/mlp --> all_rank reduce_scatter Besides, because `reduce_scatter` requires num_tokens that can be divided by group size, we need to pad the seqs based on `max_tokens_across_dp`. ### Does this PR introduce _any_ user-facing change? No. ### How was this patch tested? This PR has been tested with both offline and online inference using PanguProMoE-72B. --------- Signed-off-by: angazenn <zengyanjia@huawei.com> Co-authored-by: angazenn <zengyanjia@huawei.com>
This commit is contained in:
@@ -18,20 +18,26 @@
|
||||
from typing import Any, Dict, Iterable, List, Optional, Set, Tuple, Union
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
import torch.nn.functional as F
|
||||
from torch import nn
|
||||
from torch.nn import Parameter
|
||||
from transformers import PretrainedConfig
|
||||
from vllm.attention import Attention, AttentionMetadata
|
||||
from vllm.compilation.decorators import support_torch_compile
|
||||
from vllm.config import CacheConfig, VllmConfig
|
||||
from vllm.distributed import (get_pp_group,
|
||||
from vllm.distributed import (divide, get_pp_group,
|
||||
get_tensor_model_parallel_world_size,
|
||||
tensor_model_parallel_all_reduce)
|
||||
from vllm.distributed.parallel_state import (get_dp_group, get_tp_group,
|
||||
get_world_group)
|
||||
from vllm.forward_context import get_forward_context
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.activation import SiluAndMul
|
||||
from vllm.model_executor.layers.fused_moe import FusedMoE
|
||||
from vllm.model_executor.layers.layernorm import RMSNorm
|
||||
from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
|
||||
from vllm.model_executor.layers.linear import (LinearBase,
|
||||
MergedColumnParallelLinear,
|
||||
QKVParallelLinear,
|
||||
ReplicatedLinear,
|
||||
RowParallelLinear)
|
||||
@@ -47,6 +53,7 @@ from vllm.model_executor.models.utils import (
|
||||
extract_layer_index, is_pp_missing_parameter,
|
||||
make_empty_intermediate_tensors_factory, make_layers, maybe_prefix)
|
||||
from vllm.model_executor.sampling_metadata import SamplingMetadata
|
||||
from vllm.model_executor.utils import set_weight_attrs
|
||||
from vllm.sequence import IntermediateTensors
|
||||
|
||||
from vllm_ascend.distributed.parallel_state import get_ep_group
|
||||
@@ -56,6 +63,225 @@ logger = init_logger(__name__)
|
||||
_ROUTER_SCALE = None
|
||||
|
||||
|
||||
def use_h2p():
    # H2P swaps all-reduce for reduce-scatter/all-gather pairs; the optimization
    # only pays off when the model runs with data parallelism (dp_size > 1).
    return get_dp_group().world_size > 1
|
||||
|
||||
|
||||
# This class is adapted from vllm.model_executor.layers.linear.MergedColumnParallelLinear.
# It is used to customize parallelism of certain linear (e.g., shared experts with
# all-rank tp): the merged weight is column-sharded across the *world* group rather
# than the tensor-parallel group, and forward() performs no communication — callers
# (the H2P schedule) own the collectives.
class CustomMergedColumnParallelLinear(LinearBase):

    def __init__(
        self,
        input_size: int,
        output_sizes: list[int],
        bias: bool = True,
        gather_output: bool = False,
        skip_bias_add: bool = False,
        params_dtype: Optional[torch.dtype] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
        *,
        return_bias: bool = True,
    ):
        # Divide the weight matrix along the last dimension, sharding over the
        # whole world group (all-rank tp) instead of the regular tp group.
        output_size = sum(output_sizes)
        self.output_sizes = output_sizes
        self.tp_size = get_world_group().world_size
        self.input_size_per_partition = input_size
        self.output_size_per_partition = divide(output_size, self.tp_size)
        # Merged linear: the quant method needs the per-partition size of every
        # merged shard (e.g. gate and up projections) to create the weights.
        # (The original guarded this with `hasattr(self, "output_sizes")`, which
        # is always true here since the attribute is assigned above.)
        self.output_partition_sizes = [
            divide(shard_size, self.tp_size)
            for shard_size in self.output_sizes
        ]

        super().__init__(input_size,
                         output_size,
                         skip_bias_add,
                         params_dtype,
                         quant_config,
                         prefix,
                         return_bias=return_bias)

        self.gather_output = gather_output

        # NOTE: the original also had `if output_sizes is None: output_sizes =
        # [output_size]` here — unreachable, since `sum(output_sizes)` above
        # would already have raised on None; removed.

        assert self.quant_method is not None
        self.quant_method.create_weights(
            layer=self,
            input_size_per_partition=self.input_size_per_partition,
            output_partition_sizes=self.output_partition_sizes,
            input_size=self.input_size,
            output_size=self.output_size,
            params_dtype=self.params_dtype,
            weight_loader=(self.weight_loader))
        if bias:
            self.bias = Parameter(
                torch.empty(self.output_size_per_partition,
                            dtype=params_dtype))
            set_weight_attrs(self.bias, {
                "output_dim": 0,
                "weight_loader": self.weight_loader,
            })
        else:
            self.register_parameter("bias", None)

    def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor,
                      loaded_shard_id: int):
        """Copy one merged shard (selected by ``loaded_shard_id``) of a
        checkpoint weight into this rank's partition of ``param``.

        Sharding uses the world group, matching how the weight was created
        in ``__init__``.
        """
        param_data = param.data
        output_dim = getattr(param, "output_dim", None)

        assert loaded_shard_id < len(self.output_sizes)

        tp_rank = get_world_group().rank_in_group
        tp_size = get_world_group().world_size
        if output_dim is not None:
            # Destination offset and size of this merged shard inside param.
            shard_offset = sum(self.output_sizes[:loaded_shard_id]) // tp_size
            shard_size = self.output_sizes[loaded_shard_id] // tp_size

            is_sharded_weight = getattr(param, "is_sharded_weight", False)
            param_data = param_data.narrow(output_dim, shard_offset,
                                           shard_size)
            start_idx = tp_rank * shard_size
            # A pre-sharded checkpoint weight is copied as-is; otherwise take
            # this rank's slice of the full weight.
            if not is_sharded_weight:
                loaded_weight = loaded_weight.narrow(output_dim, start_idx,
                                                     shard_size)
        else:
            # No output_dim: the weight is assumed replicated on all ranks.
            ignore_warning = getattr(param, "ignore_warning", False)
            if not ignore_warning:
                logger.warning(
                    "Loading a weight without `output_dim` attribute in "
                    "MergedColumnParallelLinear, assume the weight is "
                    "the same for all partitions.")

        assert param_data.shape == loaded_weight.shape
        param_data.copy_(loaded_weight)

    def forward(
        self, input_
    ) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[Parameter]]]:
        """Column-parallel matmul with NO output gather; the H2P schedule in
        the decoder layer is responsible for any communication."""
        bias = self.bias if not self.skip_bias_add else None

        # Matrix multiply.
        assert self.quant_method is not None
        output = self.quant_method.apply(self, input_, bias)
        output_bias = self.bias if self.skip_bias_add else None
        if not self.return_bias:
            return output
        return output, output_bias
|
||||
|
||||
|
||||
# This class is adapted from vllm.model_executor.layers.linear.RowParallelLinear.
# It is used to customize parallelism of certain linear (e.g., shared experts with
# all-rank tp) and detach communication to enable customized communication
# algorithms (e.g., H2P): forward() never all-reduces; the caller owns the
# collectives. The shard group defaults to the world group but can be overridden.
class CustomRowParallelLinear(LinearBase):

    def __init__(
        self,
        input_size: int,
        output_size: int,
        bias: bool = True,
        input_is_parallel: bool = True,
        skip_bias_add: bool = False,
        params_dtype: Optional[torch.dtype] = None,
        reduce_results: bool = True,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
        *,
        return_bias: bool = True,
        group=None,
    ):
        # Divide the weight matrix along the first dimension. Default shard
        # group is all ranks (world group); e.g. o_proj passes the regular tp
        # group instead.
        self.group = group if group is not None else get_world_group()
        self.tp_rank = self.group.rank_in_group
        self.tp_size = self.group.world_size
        self.input_size_per_partition = divide(input_size, self.tp_size)
        self.output_size_per_partition = output_size
        self.output_partition_sizes = [output_size]

        super().__init__(input_size,
                         output_size,
                         skip_bias_add,
                         params_dtype,
                         quant_config,
                         prefix,
                         return_bias=return_bias)

        self.input_is_parallel = input_is_parallel
        self.reduce_results = reduce_results

        assert self.quant_method is not None
        self.quant_method.create_weights(
            layer=self,
            input_size_per_partition=self.input_size_per_partition,
            output_partition_sizes=self.output_partition_sizes,
            input_size=self.input_size,
            output_size=self.output_size,
            params_dtype=self.params_dtype,
            weight_loader=(self.weight_loader))
        # Without a reduce, each rank only holds a partial sum; adding the bias
        # on every rank would count it tp_size times.
        if not reduce_results and (bias and not skip_bias_add):
            raise ValueError("When not reduce the results, adding bias to the "
                             "results can lead to incorrect results")

        if bias:
            self.bias = Parameter(
                torch.empty(self.output_size, dtype=params_dtype))
            set_weight_attrs(self.bias, {
                "output_dim": 0,
                "weight_loader": self.weight_loader,
            })
        else:
            self.register_parameter("bias", None)

    def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor):
        """Copy this rank's input-dim slice of a checkpoint weight into
        ``param``, using ``self.group`` for the shard index."""
        tp_rank = self.group.rank_in_group
        input_dim = getattr(param, "input_dim", None)
        is_sharded_weight = getattr(param, "is_sharded_weight", False)
        # (Removed a no-op `is_sharded_weight = is_sharded_weight`
        # self-assignment from the original.)

        param_data = param.data
        if input_dim is not None and not is_sharded_weight:
            shard_size = param_data.shape[input_dim]
            start_idx = tp_rank * shard_size
            loaded_weight = loaded_weight.narrow(input_dim, start_idx,
                                                 shard_size)

        # Special case for loading scales off disk, which often do not
        # have a shape (such as in the case of AutoFP8).
        if len(loaded_weight.shape) == 0:
            loaded_weight = loaded_weight.reshape(1)

        assert param_data.shape == loaded_weight.shape
        param_data.copy_(loaded_weight)

    def forward(
        self, input_
    ) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[Parameter]]]:
        """Row-parallel matmul returning the *partial* (un-reduced) result;
        the H2P schedule performs the reduce-scatter outside this layer."""
        input_parallel = input_

        # Matrix multiply.
        assert self.quant_method is not None
        # Only fuse bias add into GEMM for rank 0 (this ensures that
        # bias will not get added more than once in TP>1 case)
        bias_ = None if (self.tp_rank > 0 or self.skip_bias_add) else self.bias
        output = self.quant_method.apply(self, input_parallel, bias=bias_)

        output_bias = self.bias if self.skip_bias_add else None

        if not self.return_bias:
            return output
        return output, output_bias
|
||||
|
||||
|
||||
class PanguProMoEMLP(nn.Module):
|
||||
|
||||
def __init__(
|
||||
@@ -68,21 +294,39 @@ class PanguProMoEMLP(nn.Module):
|
||||
prefix: str = "",
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.gate_up_proj = MergedColumnParallelLinear(
|
||||
hidden_size,
|
||||
[intermediate_size] * 2,
|
||||
bias=False,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.gate_up_proj",
|
||||
)
|
||||
self.down_proj = RowParallelLinear(
|
||||
intermediate_size,
|
||||
hidden_size,
|
||||
bias=False,
|
||||
quant_config=quant_config,
|
||||
reduce_results=reduce_results,
|
||||
prefix=f"{prefix}.down_proj",
|
||||
)
|
||||
if not use_h2p():
|
||||
self.gate_up_proj = MergedColumnParallelLinear(
|
||||
hidden_size,
|
||||
[intermediate_size] * 2,
|
||||
bias=False,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.gate_up_proj",
|
||||
)
|
||||
self.down_proj = RowParallelLinear(
|
||||
intermediate_size,
|
||||
hidden_size,
|
||||
bias=False,
|
||||
quant_config=quant_config,
|
||||
reduce_results=reduce_results,
|
||||
prefix=f"{prefix}.down_proj",
|
||||
)
|
||||
else:
|
||||
self.gate_up_proj = CustomMergedColumnParallelLinear(
|
||||
hidden_size,
|
||||
[intermediate_size] * 2,
|
||||
bias=False,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.gate_up_proj",
|
||||
)
|
||||
self.down_proj = CustomRowParallelLinear(
|
||||
intermediate_size,
|
||||
hidden_size,
|
||||
bias=False,
|
||||
quant_config=quant_config,
|
||||
reduce_results=reduce_results,
|
||||
prefix=f"{prefix}.down_proj",
|
||||
)
|
||||
|
||||
if hidden_act != "silu":
|
||||
raise ValueError(f"Unsupported activation: {hidden_act}. "
|
||||
"Only silu is supported for now.")
|
||||
@@ -207,11 +451,30 @@ class PanguProMoESparseMoeBlock(nn.Module):
|
||||
router_logits, _ = self.gate(hidden_states)
|
||||
global _ROUTER_SCALE
|
||||
_ROUTER_SCALE = self.router_scale
|
||||
final_hidden_states = self.experts(hidden_states=hidden_states,
|
||||
router_logits=router_logits)
|
||||
if not use_h2p():
|
||||
final_hidden_states = self.experts(hidden_states=hidden_states,
|
||||
router_logits=router_logits)
|
||||
else:
|
||||
# TODO: when using h2p, we have to skip communication in vLLM
|
||||
# native FusedMoE. here we need to design a better FusedMoE
|
||||
# (maybe using AscendFusedMoE) to enable these different
|
||||
# communication schema.
|
||||
final_hidden_states = self.experts.quant_method(
|
||||
layer=self.experts,
|
||||
x=hidden_states,
|
||||
router_logits=router_logits,
|
||||
top_k=self.experts.top_k,
|
||||
renormalize=False,
|
||||
use_grouped_topk=False,
|
||||
global_num_experts=self.experts.global_num_experts,
|
||||
expert_map=self.experts.expert_map,
|
||||
custom_routing_function=self.experts.custom_routing_function,
|
||||
apply_router_weight_on_input=self.experts.
|
||||
apply_router_weight_on_input)
|
||||
|
||||
if shared_output is not None:
|
||||
final_hidden_states = final_hidden_states + shared_output
|
||||
if self.tp_size > 1:
|
||||
if not use_h2p():
|
||||
final_hidden_states = tensor_model_parallel_all_reduce(
|
||||
final_hidden_states)
|
||||
|
||||
@@ -265,13 +528,22 @@ class PanguProMoEAttention(nn.Module):
|
||||
prefix=f"{prefix}.qkv_proj",
|
||||
)
|
||||
|
||||
self.o_proj = RowParallelLinear(
|
||||
self.total_num_heads * self.head_dim,
|
||||
hidden_size,
|
||||
bias=True,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.o_proj",
|
||||
)
|
||||
if use_h2p():
|
||||
self.o_proj = CustomRowParallelLinear(self.total_num_heads *
|
||||
self.head_dim,
|
||||
hidden_size,
|
||||
bias=True,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.o_proj",
|
||||
group=get_tp_group())
|
||||
else:
|
||||
self.o_proj = RowParallelLinear(
|
||||
self.total_num_heads * self.head_dim,
|
||||
hidden_size,
|
||||
bias=True,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.o_proj",
|
||||
)
|
||||
|
||||
self.rotary_emb = get_rope(
|
||||
self.head_dim,
|
||||
@@ -337,8 +609,7 @@ class PanguProMoEDecoderLayer(nn.Module):
|
||||
layer_idx = extract_layer_index(prefix)
|
||||
mlp_only_layers = ([] if not hasattr(config, "mlp_only_layers") else
|
||||
config.mlp_only_layers)
|
||||
if (layer_idx
|
||||
not in mlp_only_layers) and (config.num_experts > 0): ### ???
|
||||
if (layer_idx not in mlp_only_layers) and (config.num_experts > 0):
|
||||
self.mlp = PanguProMoESparseMoeBlock(
|
||||
config=config,
|
||||
quant_config=quant_config,
|
||||
@@ -364,7 +635,14 @@ class PanguProMoEDecoderLayer(nn.Module):
|
||||
residual: Optional[torch.Tensor],
|
||||
kv_cache: Optional[torch.Tensor] = None,
|
||||
attn_metadata: Optional[AttentionMetadata] = None,
|
||||
h2p_unpad_idx: Optional[torch.Tensor] = None,
|
||||
h2p_pad_idx: Optional[torch.Tensor] = None,
|
||||
is_start_layer: Optional[bool] = False,
|
||||
) -> torch.Tensor:
|
||||
need_h2p_pad = h2p_unpad_idx is not None and h2p_pad_idx is not None \
|
||||
and h2p_unpad_idx.shape[0] < h2p_pad_idx.shape[0]
|
||||
tp_size = get_tp_group().world_size
|
||||
|
||||
# Self Attention
|
||||
if residual is None:
|
||||
residual = hidden_states
|
||||
@@ -372,16 +650,64 @@ class PanguProMoEDecoderLayer(nn.Module):
|
||||
else:
|
||||
hidden_states, residual = self.input_layernorm(
|
||||
hidden_states, residual)
|
||||
if use_h2p():
|
||||
if is_start_layer:
|
||||
if need_h2p_pad:
|
||||
residual = residual.index_select(dim=0, index=h2p_pad_idx)
|
||||
residual = torch.tensor_split(
|
||||
residual, tp_size)[get_tp_group().rank_in_group]
|
||||
else:
|
||||
if tp_size > 1:
|
||||
hidden_states = get_tp_group().all_gather(hidden_states, 0)
|
||||
if need_h2p_pad:
|
||||
hidden_states = hidden_states.index_select(
|
||||
dim=0, index=h2p_unpad_idx)
|
||||
|
||||
hidden_states = self.self_attn(
|
||||
positions=positions,
|
||||
hidden_states=hidden_states,
|
||||
kv_cache=kv_cache,
|
||||
attn_metadata=attn_metadata,
|
||||
)
|
||||
|
||||
if use_h2p():
|
||||
if need_h2p_pad:
|
||||
hidden_states = hidden_states.index_select(dim=0,
|
||||
index=h2p_pad_idx)
|
||||
if tp_size > 1:
|
||||
hidden_states = dist._functional_collectives.reduce_scatter_tensor(
|
||||
hidden_states,
|
||||
"sum",
|
||||
scatter_dim=0,
|
||||
group=get_tp_group().device_group)
|
||||
|
||||
# Fully Connected
|
||||
hidden_states, residual = self.post_attention_layernorm(
|
||||
hidden_states, residual)
|
||||
hidden_states = self.mlp(hidden_states)
|
||||
|
||||
if use_h2p():
|
||||
all_rank_group = get_world_group().device_group
|
||||
output_size = (hidden_states.shape[0] *
|
||||
get_world_group().world_size,
|
||||
hidden_states.shape[1])
|
||||
# Allocate output tensor.
|
||||
output_tensor = torch.empty(output_size,
|
||||
dtype=hidden_states.dtype,
|
||||
device=hidden_states.device)
|
||||
# All-gather.
|
||||
dist.all_gather_into_tensor(output_tensor,
|
||||
hidden_states,
|
||||
group=all_rank_group)
|
||||
hidden_states = output_tensor
|
||||
|
||||
hidden_states = self.mlp(hidden_states, attn_metadata=attn_metadata)
|
||||
|
||||
if use_h2p():
|
||||
hidden_states = dist._functional_collectives.reduce_scatter_tensor(
|
||||
hidden_states,
|
||||
"sum",
|
||||
scatter_dim=0,
|
||||
group=get_world_group().device_group)
|
||||
|
||||
return hidden_states, residual
|
||||
|
||||
@@ -440,19 +766,61 @@ class PanguProMoEModel(nn.Module):
|
||||
assert intermediate_tensors is not None
|
||||
hidden_states = intermediate_tensors["hidden_states"]
|
||||
residual = intermediate_tensors["residual"]
|
||||
|
||||
if use_h2p():
|
||||
# calculate necessary padding/unpadding idx before model forward.
|
||||
|
||||
# the attn_metadata will be passed directly when using torchair.
|
||||
# if attn_metadata is not passed, we try to get it from forward_context.
|
||||
if attn_metadata is None:
|
||||
attn_metadata = get_forward_context().attn_metadata
|
||||
if attn_metadata is None:
|
||||
# when attn_metadata is None, it is in profile_run. num_tokens on all dp ranks
|
||||
# are same.
|
||||
max_tokens_across_dp = hidden_states.shape[0]
|
||||
else:
|
||||
max_tokens_across_dp = attn_metadata.max_num_tokens_across_dp
|
||||
|
||||
tp_size = get_tp_group().world_size
|
||||
# reduce scatter will split the input tensor into equal sizes and then scatter them on all ranks.
|
||||
# we need pad it before if the shape can't be divided by group size.
|
||||
# for h2p, we need pad it so that it can be divided by tp_size.
|
||||
h2p_padded_len = (
|
||||
tp_size - (max_tokens_across_dp % tp_size)
|
||||
) % tp_size + max_tokens_across_dp - hidden_states.shape[0]
|
||||
h2p_unpad_idx = torch.arange(hidden_states.shape[0],
|
||||
device=hidden_states.device,
|
||||
dtype=torch.int32)
|
||||
h2p_pad_idx = torch.cat([
|
||||
h2p_unpad_idx,
|
||||
torch.zeros(h2p_padded_len,
|
||||
dtype=torch.int32,
|
||||
device=hidden_states.device)
|
||||
])
|
||||
else:
|
||||
h2p_unpad_idx = None
|
||||
h2p_pad_idx = None
|
||||
|
||||
for i in range(self.start_layer, self.end_layer):
|
||||
layer = self.layers[i]
|
||||
hidden_states, residual = layer(
|
||||
positions, hidden_states, residual,
|
||||
kv_caches[i -
|
||||
self.start_layer] if kv_caches is not None else None,
|
||||
attn_metadata)
|
||||
attn_metadata, h2p_unpad_idx, h2p_pad_idx,
|
||||
i == self.start_layer)
|
||||
if not get_pp_group().is_last_rank:
|
||||
return IntermediateTensors({
|
||||
"hidden_states": hidden_states,
|
||||
"residual": residual
|
||||
})
|
||||
hidden_states, _ = self.norm(hidden_states, residual)
|
||||
if use_h2p():
|
||||
if get_tp_group().world_size > 1:
|
||||
hidden_states = get_tp_group().all_gather(hidden_states, 0)
|
||||
if h2p_unpad_idx.shape[0] < h2p_pad_idx.shape[0]:
|
||||
hidden_states = hidden_states.index_select(dim=0,
|
||||
index=h2p_unpad_idx)
|
||||
return hidden_states
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user