Sync from v0.13
83 additions  vllm/model_executor/layers/fused_moe/moe_pallas.py  (new file)
@@ -0,0 +1,83 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import torch
import torch.nn.functional as F

def _histogram(input: torch.Tensor, min: int, max: int) -> torch.Tensor:
    """
    Compute the histogram of an int32 tensor. The bin edges are defined by the
    min and max values, with step = 1.
    """
    assert input.dtype == torch.int32, "input must be of torch.int32 dtype."
    assert min <= max, "min must be less than or equal to max."
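
    # Despite the name, this does not binary-search: each bin edge is a
    # single integer, so an equality comparison plus a row-sum counts how
    # many input values equal each edge.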
    def searchsorted(
        sorted_sequence: torch.Tensor, values_to_search: torch.Tensor
    ) -> torch.Tensor:
        return (sorted_sequence.unsqueeze(1) == values_to_search).sum(dim=1)

    bin_edges = torch.linspace(min, max, max - min + 1, dtype=input.dtype).to(
        input.device
    )
    return searchsorted(bin_edges, input).to(torch.int32)


def fused_moe(
    hidden_states: torch.Tensor,
    w1: torch.Tensor,
    w2: torch.Tensor,
    gating_output: torch.Tensor,
    topk: int,
    global_num_experts: int,
    expert_map: torch.Tensor = None,
    renormalize: bool = False,
) -> torch.Tensor:
"""
|
||||
Args:
|
||||
hidden_states: [*, hidden_size]
|
||||
w1: [num_experts, intermediate_size * 2, hidden_size]
|
||||
w2: [num_experts, hidden_size, intermediate_size]
|
||||
gating_output: [*, num_experts]
|
||||
"""
|
||||
    assert expert_map is None, "expert_map is not supported for pallas MoE."
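    # This import registers torch_xla's custom ops, including the Pallas
    # grouped matmul kernel used below as torch.ops.xla.gmm.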
    import torch_xla.experimental.custom_kernel  # noqa: F401

    orig_shape = hidden_states.shape
    hidden_size = hidden_states.shape[-1]
    num_tokens = hidden_states.shape[:-1].numel()
    num_experts = w1.shape[0]
    intermediate_size = w2.shape[-1]
    device = hidden_states.device
    dtype = hidden_states.dtype
    assert (num_tokens * topk) % 16 == 0, (
        "The Pallas GMM kernel requires num_tokens * topk to be a multiple of "
        f"16 but got {num_tokens * topk}"
    )
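
    # Routing: softmax over the expert logits, then keep each token's top-k
    # experts, optionally renormalizing their weights to sum to one.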
    hidden_states = hidden_states.view(num_tokens, hidden_size)
    gating_output = gating_output.view(num_tokens, num_experts)
    topk_weights = gating_output.softmax(dim=-1, dtype=torch.float)
    topk_weights, topk_indices = topk_weights.topk(topk, dim=-1)
    if renormalize:
        topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
    topk_weights = topk_weights.to(dtype)
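
    # Sort the flattened (token, expert) assignments by expert id so tokens
    # routed to the same expert are contiguous; the argsort of the argsort
    # is the permutation that restores the original order afterwards.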
    topk_indices = topk_indices.flatten()
    topk_argsort_indices = topk_indices.argsort()
    topk_argsort_revert_indices = topk_argsort_indices.argsort()
    token_indices = torch.arange(num_tokens, device=device).repeat_interleave(topk)
    token_indices = token_indices[topk_argsort_indices]
    group_sizes = _histogram(topk_indices.to(torch.int32), 0, num_experts - 1)
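
    # Expert computation: gather activations in expert order, apply the w1
    # grouped matmul, the SwiGLU activation (silu(gate) * up), the w2
    # grouped matmul, then scatter rows back to (token, k) order.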
    x = hidden_states[token_indices]
    x = torch.ops.xla.gmm(x, w1, group_sizes, transpose_rhs=True)
    x = F.silu(x[..., :intermediate_size]) * x[..., intermediate_size:]
    x = torch.ops.xla.gmm(x, w2, group_sizes, transpose_rhs=True)
    x = x[topk_argsort_revert_indices].reshape(-1, topk, hidden_size)
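
    # Combine: scale each expert output by its routing weight and sum over
    # the top-k experts for every token.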
    x = x * topk_weights.unsqueeze(dim=-1)
    x = x.sum(dim=-2)
    x = x.reshape(orig_shape)
    return x
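
A minimal usage sketch (hypothetical shapes; assumes all tensors live on a
torch_xla TPU device so that torch.ops.xla.gmm is available):

    # Illustrative shapes only. num_tokens * topk = 8 * 2 = 16 satisfies
    # the multiple-of-16 constraint asserted above.
    hidden_states = torch.randn(8, 1024)   # [num_tokens, hidden_size]
    w1 = torch.randn(4, 2 * 2048, 1024)    # [num_experts, intermediate_size * 2, hidden_size]
    w2 = torch.randn(4, 1024, 2048)        # [num_experts, hidden_size, intermediate_size]
    gating_output = torch.randn(8, 4)      # [num_tokens, num_experts]
    out = fused_moe(hidden_states, w1, w2, gating_output,
                    topk=2, global_num_experts=4, renormalize=True)
    assert out.shape == hidden_states.shape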