Iluvatar-mrv100 SDK 4.3.0
vllm/model_executor/layers/fused_moe/utils.py (new file, 48 lines)
@@ -0,0 +1,48 @@
# SPDX-License-Identifier: Apache-2.0
from math import prod
from typing import List, Optional, Tuple

import torch

from vllm import _custom_ops as ops
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
    per_token_group_quant_fp8)
from vllm.utils import cdiv


def _resize_cache(x: torch.Tensor, v: Tuple[int, ...]) -> torch.Tensor:
    """
    Shrink the given tensor and apply the given view to it. This is
    used to resize the intermediate fused_moe caches.
    """
    assert prod(v) <= x.numel()
    return x.flatten()[:prod(v)].view(*v)
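_resize_cache never allocates: it trims a pre-allocated workspace and re-views it, so the result aliases the original storage. A minimal standalone sketch, not part of the diff (the helper body is copied from above so it runs on its own):

from math import prod

import torch

# Copy of _resize_cache from the diff above, inlined so this runs standalone.
def _resize_cache(x: torch.Tensor, v) -> torch.Tensor:
    assert prod(v) <= x.numel()
    return x.flatten()[:prod(v)].view(*v)

cache = torch.empty(4096)                    # pre-allocated MoE workspace
view = _resize_cache(cache, (16, 32))        # uses 512 of the 4096 elements
assert view.shape == (16, 32)
assert view.data_ptr() == cache.data_ptr()   # a view into cache, not a copy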

def _fp8_quantize(
    A: torch.Tensor,
    A_scale: Optional[torch.Tensor],
    block_shape: Optional[List[int]],
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Perform fp8 quantization on the inputs. If a block_shape
    is provided, the output will be blocked.
    """
    if block_shape is None:
        A, A_scale = ops.scaled_fp8_quant(A, A_scale)
    else:
        assert len(block_shape) == 2
        _, block_k = block_shape[0], block_shape[1]
        A, A_scale = per_token_group_quant_fp8(A, block_k)
        assert cdiv(A.shape[-1], block_k) == A_scale.shape[-1]
    return A, A_scale
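In the blocked path, per_token_group_quant_fp8 produces one scale per (token, K-group), which is exactly what the cdiv assert checks; block_shape = [128, 128] is the common DeepSeek-style setting. A hypothetical pure-torch stand-in, not the real vLLM kernel (per_token_group_quant_fp8_sketch is invented here purely to show the shapes, and assumes K is divisible by the group size):

import torch

def cdiv(a: int, b: int) -> int:
    return -(a // -b)  # ceiling division, mirrors vllm.utils.cdiv

# Hypothetical stand-in for per_token_group_quant_fp8: one scale per
# (token, K-group) pair. 448.0 is the max representable fp8e4m3 value.
def per_token_group_quant_fp8_sketch(A: torch.Tensor, group_k: int):
    M, K = A.shape
    groups = cdiv(K, group_k)
    A_grouped = A.view(M, groups, group_k)
    scales = A_grouped.abs().amax(dim=-1).clamp(min=1e-12) / 448.0
    A_q = (A_grouped / scales.unsqueeze(-1)).to(torch.float8_e4m3fn)
    return A_q.view(M, K), scales

A = torch.randn(8, 256)
A_q, A_scale = per_token_group_quant_fp8_sketch(A, group_k=128)
assert A_scale.shape[-1] == cdiv(A.shape[-1], 128)  # the assert in _fp8_quantize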

def _fp8_perm(m: torch.Tensor, idx: torch.Tensor) -> torch.Tensor:
    """
    A permutation routine that works on fp8 types.
    """
    if torch.is_floating_point(m) and torch.finfo(m.dtype).bits == 8:
        return m.view(dtype=torch.uint8)[idx, ...].view(dtype=m.dtype)
    else:
        return m[idx, ...]
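_fp8_perm exists because fancy indexing is not implemented for fp8 dtypes on every backend, so the bytes are reinterpreted as uint8 (same one-byte element size), permuted, and reinterpreted back. A short sketch of the trick, not part of the diff (assumes PyTorch >= 2.1 for torch.float8_e4m3fn):

import torch

m = torch.randn(4, 8).to(torch.float8_e4m3fn)
idx = torch.tensor([2, 0, 3, 1])

# Reinterpret as uint8, gather rows, reinterpret back to fp8.
permuted = m.view(dtype=torch.uint8)[idx, ...].view(dtype=torch.float8_e4m3fn)
assert permuted.shape == m.shape
# Row 0 of the result is row 2 of the input, byte-for-byte.
assert torch.equal(permuted[0].view(dtype=torch.uint8),
                   m[2].view(dtype=torch.uint8))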