First commit

This commit is contained in:
2025-08-05 19:02:46 +08:00
parent 9efe891f99
commit 99fb9f5cb0
1412 changed files with 203615 additions and 0 deletions

19
Dockerfile Normal file
View File

@@ -0,0 +1,19 @@
FROM zibo.harbor.iluvatar.com.cn:30000/saas/bi100-3.2.1-x86-ubuntu20.04-py3.10-poc-llm-infer:20250731115755
RUN pip install --no-cache-dir triton==2.1.0
COPY pkgs/triton /usr/local/corex/lib64/python3/dist-packages/triton
COPY pkgs/triton-2.1.0+corex.4.1.2.dist-info /usr/local/corex/lib64/python3/dist-packages/triton-2.1.0+corex.4.1.2.dist-info
COPY pkgs/xformers-0.0.22+corex.4.1.2.dist-info /usr/local/corex/lib64/python3/dist-packages/xformers-0.0.22+corex.4.1.2.dist-info
COPY pkgs/xformers /usr/local/corex/lib64/python3/dist-packages/xformers
COPY paged_attn.py /usr/local/lib/python3.10/site-packages/vllm/attention/ops/paged_attn.py
COPY __init__.py /usr/local/lib/python3.10/site-packages/vllm/triton_utils/__init__.py
COPY prefix_prefill.py /usr/local/lib/python3.10/site-packages/vllm/attention/ops/prefix_prefill.py
RUN mkdir /workspace
WORKDIR /workspace/
COPY ./launch_service /workspace/launch_service
ENTRYPOINT ["./launch_service"]

9
__init__.py Normal file
View File

@@ -0,0 +1,9 @@
from vllm.triton_utils.importing import HAS_TRITON
__all__ = ["HAS_TRITON"]
#from vllm.triton_utils.custom_cache_manager import (
# maybe_set_triton_cache_manager)
#from vllm.triton_utils.libentry import libentry
__all__ += ["maybe_set_triton_cache_manager", "libentry"]

542
attention.py Normal file
View File

@@ -0,0 +1,542 @@
"""Multi-head attention."""
import os
enable_infer_paged_attn = os.getenv("ENABLE_INFER_PAGED_ATTN",None)
from typing import List, Optional
import importlib
import torch
import torch.nn as nn
from ixformer.contrib.xformers import ops as xops
from ixformer.contrib.xformers.ops.fmha.attn_bias import (BlockDiagonalCausalMask,
LowerTriangularMaskWithTensorBias)
from vllm._C import ops
from vllm._C import cache_ops
from vllm.model_executor.input_metadata import InputMetadata
from vllm.model_executor.layers.triton_kernel.prefix_prefill import (
context_attention_fwd)
from vllm.utils import is_hip
# _SUPPORTED_HEAD_SIZES = [64, 80, 96, 112, 128, 256]
# # Should be the same as PARTITION_SIZE in `paged_attention_v2_launcher`.
# _PARTITION_SIZE = 512
_SUPPORTED_HEAD_SIZES = [64, 128, 256]
# Should be the same as PARTITION_SIZE in `paged_attention_v2_launcher`.
_PARTITION_SIZE = 256
class PagedAttention(nn.Module):
"""MHA/MQA/GQA layer with PagedAttention.
This class takes query, key, and value tensors as input. The input tensors
can either contain prompt tokens or generation tokens.
The class does the following:
1. Reshape and store the input key and value tensors in the KV cache.
2. Perform (multi-head/multi-query/grouped-query) attention using either
xformers or the PagedAttention custom op.
3. Return the output tensor.
"""
def __init__(
self,
num_heads: int,
head_size: int,
scale: float,
num_kv_heads: Optional[int] = None,
alibi_slopes: Optional[List[float]] = None,
sliding_window: Optional[int] = None,
) -> None:
super().__init__()
self.num_heads = num_heads
self.head_size = head_size
self.scale = float(scale)
self.num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads
self.sliding_window = sliding_window
if alibi_slopes is not None:
alibi_slopes = torch.tensor(alibi_slopes, dtype=torch.float32)
self.register_buffer("alibi_slopes", alibi_slopes, persistent=False)
assert self.num_heads % self.num_kv_heads == 0
self.num_queries_per_kv = self.num_heads // self.num_kv_heads
if self.head_size not in _SUPPORTED_HEAD_SIZES:
raise ValueError(f"head_size ({self.head_size}) is not supported. "
f"Supported head sizes: {_SUPPORTED_HEAD_SIZES}.")
self.use_ref_attention = self.check_use_ref_attention()
# TODO align vllm do not need those
self.attn_op = xops.fmha.flash.FwOp()
head_mapping = torch.repeat_interleave(
torch.arange(self.num_kv_heads, dtype=torch.int32),
self.num_queries_per_kv)
self.register_buffer("head_mapping", head_mapping, persistent=False)
def check_use_ref_attention(self) -> bool:
if not is_hip():
return False
# For ROCm, check whether flash attention is installed or not.
# if not, use_ref_attention needs to be True
return importlib.util.find_spec("flash_attn") is None
def ref_masked_attention(
self,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
) -> torch.Tensor:
query = query.view(-1, self.num_heads, self.head_size)
key = key.view(-1, self.num_kv_heads, self.head_size)
value = value.view(-1, self.num_kv_heads, self.head_size)
seq_len, _, _ = query.shape
attn_mask = torch.triu(torch.ones(seq_len,
seq_len,
dtype=query.dtype,
device=query.device),
diagonal=1)
attn_mask = attn_mask * torch.finfo(query.dtype).min
attn_weights = self.scale * torch.einsum("qhd,khd->hqk", query,
key).float()
attn_weights = attn_weights + attn_mask.float()
attn_weights = torch.softmax(attn_weights, dim=-1).to(value.dtype)
out = torch.einsum("hqk,khd->qhd", attn_weights, value)
return out
def forward(
self,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
key_cache: Optional[torch.Tensor],
value_cache: Optional[torch.Tensor],
input_metadata: InputMetadata,
) -> torch.Tensor:
"""PagedAttention forward pass.
Args:
query: shape = [num_tokens, num_heads * head_size]
key: shape = [num_tokens, num_kv_heads * head_size]
value: shape = [num_tokens, num_kv_heads * head_size]
key_cache: shape = [num_blocks, num_kv_heads, head_size/x,
block_size, x]
value_cache: shape = [num_blocks, num_kv_heads, head_size,
block_size]
input_metadata: metadata for the inputs.
cache_event: event to wait for the cache operations to finish.
Returns:
shape = [batch_size, seq_len, num_heads * head_size]
"""
num_tokens, hidden_size = query.shape
# Reshape the query, key, and value tensors.
query = query.view(-1, self.num_heads, self.head_size)
key = key.view(-1, self.num_kv_heads, self.head_size)
value = value.view(-1, self.num_kv_heads, self.head_size)
slot_mapping = input_metadata.slot_mapping
# Reshape the keys and values and store them in the cache.
# If key_cache and value_cache are not provided, the new key and value
# vectors will not be cached. This happens during the initial memory
# profiling run.
if key_cache is not None and value_cache is not None:
cache_ops.reshape_and_cache(
key,
value,
key_cache,
value_cache,
slot_mapping,
)
if input_metadata.is_prompt:
# normal attention
if (key_cache is None or value_cache is None
or input_metadata.block_tables.numel() == 0):
if input_metadata.attn_bias is None:
if self.alibi_slopes is None:
attn_bias = BlockDiagonalCausalMask.from_seqlens(input_metadata.prompt_lens)
if self.sliding_window is not None:
attn_bias = attn_bias.make_local_attention(
self.sliding_window)
input_metadata.attn_bias = attn_bias
else:
attn_bias = BlockDiagonalCausalMask.from_seqlens(input_metadata.prompt_lens)
input_metadata.attn_bias = attn_bias
if self.use_ref_attention:
output = self.ref_masked_attention(
query,
key,
value,
)
# Using view got RuntimeError: view size is not compatible with input tensor's size and stride
# (at least one dimension spans across two contiguous subspaces). Use reshape instead
return output.reshape(num_tokens, hidden_size)
# TODO(woosuk): Too many view operations. Let's try to reduce
# them in the future for code readability.
query = query.unsqueeze(0)
key = key.unsqueeze(0)
value = value.unsqueeze(0)
out = xops.memory_efficient_attention_forward(
query,
key,
value,
attn_bias=input_metadata.attn_bias,
p=0.0,
scale=self.scale,
op=self.attn_op,
alibi_slopes=self.alibi_slopes
)
output = out.view_as(query)
else:
# prefix-enabled attention
output = torch.empty_like(query)
context_attention_fwd(
query,
key,
value,
output,
key_cache,
value_cache,
input_metadata.block_tables, # [BS, max_block_per_request]
input_metadata.start_loc,
input_metadata.prompt_lens,
input_metadata.context_lens,
input_metadata.max_seq_len,
getattr(self, "alibi_slopes", None),
)
else:
# Decoding run.
output = _paged_attention(
query,
key_cache,
value_cache,
input_metadata,
self.head_mapping, # self.num_kv_heads
self.scale,
self.alibi_slopes,
)
# Reshape the output tensor.
return output.view(num_tokens, hidden_size)
# TODO align
"""
def forward(
self,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
key_cache: Optional[torch.Tensor],
value_cache: Optional[torch.Tensor],
input_metadata: InputMetadata,
) -> torch.Tensor:
PagedAttention forward pass.
Args:
query: shape = [batch_size, seq_len, num_heads * head_size]
key: shape = [batch_size, seq_len, num_kv_heads * head_size]
value: shape = [batch_size, seq_len, num_kv_heads * head_size]
key_cache: shape = [num_blocks, num_kv_heads, head_size/x,
block_size, x]
value_cache: shape = [num_blocks, num_kv_heads, head_size,
block_size]
input_metadata: metadata for the inputs.
Returns:
shape = [batch_size, seq_len, num_heads * head_size]
batch_size, seq_len, hidden_size = query.shape
# Reshape the query, key, and value tensors.
query = query.view(-1, self.num_heads, self.head_size)
key = key.view(-1, self.num_kv_heads, self.head_size)
value = value.view(-1, self.num_kv_heads, self.head_size)
# Reshape the keys and values and store them in the cache.
# If key_cache and value_cache are not provided, the new key and value
# vectors will not be cached. This happens during the initial memory
# profiling run.
if key_cache is not None and value_cache is not None:
cache_ops.reshape_and_cache(
key,
value,
key_cache,
value_cache,
input_metadata.slot_mapping.flatten(),
input_metadata.kv_cache_dtype,
)
if input_metadata.is_prompt:
# normal attention
if (key_cache is None or value_cache is None
or input_metadata.block_tables.numel() == 0):
if self.num_kv_heads != self.num_heads:
# As of Nov 2023, xformers only supports MHA. For MQA/GQA,
# project the key and value tensors to the desired number of
# heads.
# TODO(woosuk): Use MQA/GQA kernels for higher performance.
query = query.view(query.shape[0], self.num_kv_heads,
self.num_queries_per_kv,
query.shape[-1])
key = key[:, :,
None, :].expand(key.shape[0], self.num_kv_heads,
self.num_queries_per_kv,
key.shape[-1])
value = value[:, :,
None, :].expand(value.shape[0],
self.num_kv_heads,
self.num_queries_per_kv,
value.shape[-1])
# Set attention bias if not provided. This typically happens at
# the very attention layer of every iteration.
# FIXME(woosuk): This is a hack.
if input_metadata.attn_bias is None:
if self.alibi_slopes is None:
attn_bias = BlockDiagonalCausalMask.from_seqlens(
[seq_len] * batch_size)
if self.sliding_window is not None:
attn_bias = attn_bias.make_local_attention(
self.sliding_window)
input_metadata.attn_bias = attn_bias
else:
input_metadata.attn_bias = _make_alibi_bias(
self.alibi_slopes, self.num_kv_heads, batch_size,
seq_len, query.dtype)
if self.use_ref_attention:
output = self.ref_masked_attention(
query,
key,
value,
)
# Using view got RuntimeError: view size is not compatible with input tensor's size and stride
# (at least one dimension spans across two contiguous subspaces). Use reshape instead
return output.reshape(batch_size, seq_len, hidden_size)
# TODO(woosuk): Too many view operations. Let's try to reduce
# them in the future for code readability.
if self.alibi_slopes is None:
query = query.unsqueeze(0)
key = key.unsqueeze(0)
value = value.unsqueeze(0)
else:
query = query.unflatten(0, (batch_size, seq_len))
key = key.unflatten(0, (batch_size, seq_len))
value = value.unflatten(0, (batch_size, seq_len))
out = xops.memory_efficient_attention_forward(
query,
key,
value,
attn_bias=input_metadata.attn_bias,
p=0.0,
scale=self.scale,
op=xops.fmha.MemoryEfficientAttentionFlashAttentionOp[0] if
(is_hip()) else None,
)
output = out.view_as(query)
else:
# prefix-enabled attention
output = torch.empty_like(query)
context_attention_fwd(
query,
key,
value,
output,
key_cache,
value_cache,
input_metadata.block_tables, # [BS, max_block_per_request]
input_metadata.start_loc,
input_metadata.prompt_lens,
input_metadata.context_lens,
input_metadata.max_seq_len,
getattr(self, "alibi_slopes", None),
)
else:
# Decoding run.
output = _paged_attention(
query,
key_cache,
value_cache,
input_metadata,
self.num_kv_heads,
self.scale,
self.alibi_slopes,
)
# Reshape the output tensor.
return output.view(batch_size, seq_len, hidden_size)
"""
def _make_alibi_bias(
alibi_slopes: torch.Tensor,
num_kv_heads: int,
batch_size: int,
seq_len: int,
dtype: torch.dtype,
) -> LowerTriangularMaskWithTensorBias:
bias = torch.arange(seq_len, dtype=dtype)
# NOTE(zhuohan): HF uses
# `bias = bias[None, :].repeat(prompt_len, 1)`
# here. We find that both biases give the same results, but
# the bias below more accurately follows the original ALiBi
# paper.
bias = bias[None, :] - bias[:, None]
# When using custom attention bias, xformers requires the bias to
# be sliced from a tensor whose length is a multiple of 8.
padded_len = (seq_len + 7) // 8 * 8
num_heads = alibi_slopes.shape[0]
bias = torch.empty(
batch_size,
num_heads,
seq_len,
padded_len,
device=alibi_slopes.device,
dtype=dtype,
)[:, :, :, :seq_len].copy_(bias)
bias.mul_(alibi_slopes[:, None, None])
if num_heads != num_kv_heads:
bias = bias.unflatten(1, (num_kv_heads, num_heads // num_kv_heads))
attn_bias = LowerTriangularMaskWithTensorBias(bias)
return attn_bias
def _paged_attention(
query: torch.Tensor,
key_cache: torch.Tensor,
value_cache: torch.Tensor,
input_metadata: InputMetadata,
head_mapping: torch.Tensor, # num_kv_heads: int,
scale: float,
alibi_slopes: Optional[torch.Tensor],
use_sqrt_alibi: bool = False
) -> torch.Tensor:
output = torch.empty_like(query)
use_v2 = enable_infer_paged_attn is None and key_cache.dim() == 4
if not use_v2:
block_size = value_cache.shape[3]
# Run PagedAttention V1.
ops.paged_attention_v1(
output,
query,
key_cache,
value_cache,
head_mapping, # num_kv_heads
scale,
input_metadata.block_tables,
input_metadata.context_lens,
block_size,
input_metadata.max_context_len,
alibi_slopes,
input_metadata.kv_cache_dtype,
)
else:
# Run PagedAttention V2.
block_size = value_cache.shape[2]
num_seqs, num_heads, head_size = query.shape
max_num_partitions = (
(input_metadata.max_context_len + _PARTITION_SIZE - 1) //
_PARTITION_SIZE)
tmp_output = torch.empty(
size=(num_seqs, num_heads, max_num_partitions, head_size),
dtype=output.dtype,
device=output.device,
)
exp_sums = torch.empty(
size=(num_seqs, num_heads, max_num_partitions),
dtype=torch.float32,
device=output.device,
)
max_logits = torch.empty_like(exp_sums)
ops.paged_attention_v2(
output,
exp_sums,
max_logits,
tmp_output,
query,
key_cache,
value_cache,
head_mapping, # num_kv_heads
scale,
input_metadata.block_tables,
input_metadata.context_lens,
block_size,
input_metadata.max_context_len,
alibi_slopes,
input_metadata.kv_cache_dtype,
)
return output
# ↓ add for smoothquant
class DequantPagedAttention(PagedAttention):
def __init__(
self,
num_heads: int,
head_size: int,
scale: float,
num_kv_heads: Optional[int] = None,
alibi_slopes: Optional[List[float]] = None,
sliding_window: Optional[int] = None,
quant_kv_cache: bool = False,
kv_quant_params: torch.Tensor = None,
quant_scale: float = 1.0,
use_per_token_quant: bool = True,
) -> None:
super().__init__(num_heads,
head_size,
scale,
num_kv_heads,
alibi_slopes,
sliding_window)
self.register_parameter(
"quant_scale",
torch.nn.Parameter(
torch.tensor(quant_scale, dtype=torch.float32,requires_grad=False))
)
self.use_per_token_quant = use_per_token_quant
def _apply(self, fn):
super()._apply(fn)
self.quant_scale.data = self.quant_scale.cpu()
return self
def to(self, *args, **kwargs):
super().to(*args, **kwargs)
self.quant_scale.data = self.quant_scale.to(*args, **kwargs)
self.quant_scale.data = self.quant_scale.to(torch.float32)
return self
def forward(
self,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
key_cache: Optional[torch.Tensor],
value_cache: Optional[torch.Tensor],
input_metadata: InputMetadata,
) -> torch.Tensor:
out = super().forward(
query,
key,
value,
key_cache,
value_cache,
input_metadata,
)
quant_out = torch.empty_like(out, dtype=torch.int8)
if self.use_per_token_quant:
scale = torch.empty(out.numel() // out.shape[-1],
dtype=torch.float32,
device=out.device)
ops.quant(quant_out, out, scale)
return quant_out, scale
else:
ops.quant(quant_out, out, self.quant_scale.item())
return (quant_out, )

79
launch_service Executable file
View File

@@ -0,0 +1,79 @@
#!/bin/bash
export PYTHONPATH=/usr/local/corex/lib64/python3/dist-packages
export LD_LIBRARY_PATH=/usr/local/corex/lib64:/usr/local/openmpi/lib
export PATH=/usr/local/corex/bin:/usr/local/sbin:/usr/local/bin:/usr/sbin:/usr/bin:/sbin:/bin:/usr/local/corex/lib64/python3/dist-packages/bin:/usr/local/openmpi/bin
export JAVA_HOME=/root/apps/jdk1.8.0_411
export JRE_HOME=/root/apps/jdk1.8.0_411/jre
export JMETER_HOME=/root/apps/apache-jmeter-5.6.3
export CLASSPATH=.:/root/apps/jdk1.8.0_411/lib/dt.jar:/root/apps/jdk1.8.0_411/lib/tools.jar:/root/apps/apache-jmeter-5.6.3/lib/ext/ApacheJMeter_core.jar:/root/apps/apache-jmeter-5.6.3/lib/jorphan.jar:/root/apps/apache-jmeter-5.6.3/lib/logkit-2.0.jar:
export PATH=/root/apps/apache-jmeter-5.6.3/bin:/root/apps/jdk1.8.0_411/bin:/usr/local/corex/bin:/usr/local/sbin:/usr/local/bin:/usr/sbin:/usr/bin:/sbin:/bin:/usr/local/corex/lib64/python3/dist-packages/bin:/usr/local/openmpi/bin
/iluvatar/welcome.sh
data
cat /proc/cpuinfo | tail -n 50
ixsmi
unset CUDA_VISIBLE_DEVICES
export
date
DEFAULT_HOST="0.0.0.0"
DEFAULT_PORT="80"
DEFAULT_SERVED_MODEL_NAME="llm"
DEFAULT_MODEL_PATH="/model"
DEFAULT_MAX_MODEL_LEN="10000"
DEFAULT_TENSOR_PARALLEL_SIZE="1"
DEFAULT_MAX_NUM_SEQS="64"
DEFAULT_ENFORCE_EAGER="true"
DEFAULT_DISABLE_LOG_REQUESTS="true"
DEFAULT_PREFIX_CACHING="true"
HOST_VAL=${HOST:-$DEFAULT_HOST}
PORT_VAL=${PORT:-$DEFAULT_PORT}
SERVED_MODEL_NAME_VAL=${SERVED_MODEL_NAME:-$DEFAULT_SERVED_MODEL_NAME}
MODEL_PATH_VAL=${MODEL_PATH:-$DEFAULT_MODEL_PATH}
MAX_MODEL_LEN_VAL=${MAX_MODEL_LEN:-$DEFAULT_MAX_MODEL_LEN}
TENSOR_PARALLEL_SIZE_VAL=${TENSOR_PARALLEL_SIZE:-$DEFAULT_TENSOR_PARALLEL_SIZE}
MAX_NUM_SEQS_VAL=${MAX_NUM_SEQS:-$DEFAULT_MAX_NUM_SEQS}
INCLUDE_ENFORCE_EAGER_FLAG=${ENFORCE_EAGER:-$DEFAULT_ENFORCE_EAGER}
INCLUDE_DISABLE_LOG_REQUESTS_FLAG=${DISABLE_LOG_REQUESTS:-$DEFAULT_DISABLE_LOG_REQUESTS}
INCLUDE_PREFIX_CACHING_FLAG=${PREFIX_CACHING:-$DEFAULT_PREFIX_CACHING}
CMD_ARGS=()
CMD_ARGS+=(--host "$HOST_VAL")
CMD_ARGS+=(--port "$PORT_VAL")
if [[ "$INCLUDE_ENFORCE_EAGER_FLAG" != "false" && "$INCLUDE_ENFORCE_EAGER_FLAG" != "0" ]]; then
CMD_ARGS+=(--enforce-eager)
fi
if [[ "$INCLUDE_DISABLE_LOG_REQUESTS_FLAG" != "false" && "$INCLUDE_DISABLE_LOG_REQUESTS_FLAG" != "0" ]]; then
CMD_ARGS+=(--disable-log-requests)
fi
if [[ "$INCLUDE_PREFIX_CACHING_FLAG" != "false" && "$INCLUDE_PREFIX_CACHING_FLAG" != "0" ]]; then
CMD_ARGS+=(--enable-prefix-caching)
fi
CMD_ARGS+=(--served-model-name "$SERVED_MODEL_NAME_VAL")
CMD_ARGS+=(--model "$MODEL_PATH_VAL")
CMD_ARGS+=(--max-model-len "$MAX_MODEL_LEN_VAL")
CMD_ARGS+=(--tensor-parallel-size "$TENSOR_PARALLEL_SIZE_VAL")
CMD_ARGS+=(--max-num-seqs "$MAX_NUM_SEQS_VAL")
echo "--------------------------------------------------"
echo "Starting VLLM OpenAI API Server..."
echo "Using effective arguments:"
echo " Host (--host): $HOST_VAL"
echo " Port (--port): $PORT_VAL"
echo " Enforce Eager (--enforce-eager):" $([[ "$INCLUDE_ENFORCE_EAGER_FLAG" != "false" && "$INCLUDE_ENFORCE_EAGER_FLAG" != "0" ]] && echo "Enabled" || echo "Disabled (Env: ENFORCE_EAGER=$ENFORCE_EAGER)")
echo " Disable Log Req (--disable-log-requests):" $([[ "$INCLUDE_DISABLE_LOG_REQUESTS_FLAG" != "false" && "$INCLUDE_DISABLE_LOG_REQUESTS_FLAG" != "0" ]] && echo "Enabled" || echo "Disabled (Env: DISABLE_LOG_REQUESTS=$DISABLE_LOG_REQUESTS)")
echo " Served Model Name (--served-model-name): $SERVED_MODEL_NAME_VAL"
echo " Model Path (--model): $MODEL_PATH_VAL"
echo " Max Model Length (--max-model-len): $MAX_MODEL_LEN_VAL"
echo " Tensor Parallel Size (--tensor-parallel-size): $TENSOR_PARALLEL_SIZE_VAL"
echo " Max Num Seqs (--max-num-seqs): $MAX_NUM_SEQS_VAL"
echo "--------------------------------------------------"
echo "Full cmd:"
echo "python3 -m vllm.entrypoints.openai.api_server ${CMD_ARGS[*]}"
echo "--------------------------------------------------"
python3 -m vllm.entrypoints.openai.api_server "${CMD_ARGS[@]}"

242
paged_attn.py Normal file
View File

@@ -0,0 +1,242 @@
from dataclasses import dataclass
from typing import List, Optional, Tuple
import torch
from vllm import _custom_ops as ops
from vllm.attention.ops.prefix_prefill import context_attention_fwd
# Should be the same as PARTITION_SIZE in `paged_attention_v2_launcher`.
_PARTITION_SIZE = 512
@dataclass
class PagedAttentionMetadata:
"""Metadata for PagedAttention."""
# (batch_size,). The length of sequences (entire tokens seen so far) per
# sequence.
seq_lens_tensor: Optional[torch.Tensor]
# Maximum sequence length in the batch. 0 if it is prefill-only batch.
max_decode_seq_len: int
# (batch_size, max_blocks_per_seq).
# Block addresses per sequence. (Seq id -> list of physical block)
# E.g., [0, 1, 2] means tokens are stored in 0th, 1st, and 2nd blocks
# in the kv cache. Each block can contain up to block_size tokens.
# 2nd dimensions are padded up to max_blocks_per_seq if it is cuda-graph
# captured.
block_tables: Optional[torch.Tensor]
class PagedAttention:
@staticmethod
def get_supported_head_sizes() -> List[int]:
return [64, 80, 96, 112, 120, 128, 192, 256]
@staticmethod
def get_kv_cache_shape(
num_blocks: int,
block_size: int,
num_kv_heads: int,
head_size: int,
) -> Tuple[int, ...]:
return (2, num_blocks, block_size * num_kv_heads * head_size)
@staticmethod
def split_kv_cache(
kv_cache: torch.Tensor,
num_kv_heads: int,
head_size: int,
) -> Tuple[torch.Tensor, torch.Tensor]:
x = 16 // kv_cache.element_size()
num_blocks = kv_cache.shape[1]
key_cache = kv_cache[0]
key_cache = key_cache.view(num_blocks, num_kv_heads, head_size // x,
-1, x)
value_cache = kv_cache[1]
value_cache = value_cache.view(num_blocks, num_kv_heads, head_size, -1)
return key_cache, value_cache
@staticmethod
def write_to_paged_cache(
key: torch.Tensor,
value: torch.Tensor,
key_cache: torch.Tensor,
value_cache: torch.Tensor,
slot_mapping: torch.Tensor,
kv_cache_dtype: str,
k_scale: float,
v_scale: float,
) -> None:
ops.reshape_and_cache(
key,
value,
key_cache,
value_cache,
slot_mapping.flatten(),
kv_cache_dtype,
k_scale,
v_scale,
)
@staticmethod
def forward_decode(
query: torch.Tensor,
key_cache: torch.Tensor,
value_cache: torch.Tensor,
block_tables: torch.Tensor,
seq_lens: torch.Tensor,
max_seq_len: int,
kv_cache_dtype: str,
num_kv_heads: int,
scale: float,
alibi_slopes: Optional[torch.Tensor],
k_scale: float,
v_scale: float,
tp_rank: int = 0,
blocksparse_local_blocks: int = 0,
blocksparse_vert_stride: int = 0,
blocksparse_block_size: int = 64,
blocksparse_head_sliding_step: int = 0,
) -> torch.Tensor:
if blocksparse_vert_stride is not None and blocksparse_vert_stride > 1:
# use blocksparse paged attention
block_size = value_cache.size(-1)
assert (blocksparse_block_size > 0 and
blocksparse_block_size % block_size == 0), \
(f"{blocksparse_block_size=} needs to be a multiple of"
f"{block_size=} used in block_tables.")
output = torch.empty_like(query)
block_size = value_cache.shape[3]
num_seqs, num_heads, head_size = query.shape
max_num_partitions = ((max_seq_len + _PARTITION_SIZE - 1) //
_PARTITION_SIZE)
# NOTE(woosuk): We use a simple heuristic to decide whether to use
# PagedAttention V1 or V2. If the number of partitions is 1, we use
# V1 to avoid the overhead of reduction. Also, if the number of
# sequences or heads is large, we use V1 since there is enough work
# to parallelize.
# TODO(woosuk): Tune this heuristic.
# For context len > 8192, use V2 kernel to avoid shared memory shortage.
use_v1 = (max_seq_len <= 8192
and (max_num_partitions == 1 or num_seqs * num_heads > 512))
use_v1 = True
if use_v1:
# Run PagedAttention V1.
ops.paged_attention_v1(
output,
query,
key_cache,
value_cache,
num_kv_heads,
scale,
block_tables,
seq_lens,
block_size,
max_seq_len,
alibi_slopes,
)
else:
# Run PagedAttention V2.
assert _PARTITION_SIZE % block_size == 0
tmp_output = torch.empty(
size=(num_seqs, num_heads, max_num_partitions, head_size),
dtype=output.dtype,
device=output.device,
)
exp_sums = torch.empty(
size=(num_seqs, num_heads, max_num_partitions),
dtype=torch.float32,
device=output.device,
)
max_logits = torch.empty_like(exp_sums)
ops.paged_attention_v2(
output,
exp_sums,
max_logits,
tmp_output,
query,
key_cache,
value_cache,
num_kv_heads,
scale,
block_tables,
seq_lens,
block_size,
max_seq_len,
alibi_slopes,
kv_cache_dtype,
k_scale,
v_scale,
tp_rank,
blocksparse_local_blocks,
blocksparse_vert_stride,
blocksparse_block_size,
blocksparse_head_sliding_step,
)
return output
@staticmethod
def forward_prefix(
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
kv_cache_dtype: str,
key_cache: torch.Tensor,
value_cache: torch.Tensor,
block_tables: torch.Tensor,
query_start_loc: torch.Tensor,
seq_lens_tensor: torch.Tensor,
context_lens: torch.Tensor,
max_query_len: int,
alibi_slopes: Optional[torch.Tensor],
sliding_window: Optional[int],
k_scale: float,
v_scale: float,
) -> torch.Tensor:
output = torch.empty_like(query)
context_attention_fwd(
query,
key,
value,
output,
kv_cache_dtype,
key_cache,
value_cache,
block_tables,
# query_start_loc is (batch_size + 1,)
query_start_loc[:-1],
seq_lens_tensor,
context_lens,
max_query_len,
k_scale,
v_scale,
alibi_slopes,
sliding_window,
)
return output
@staticmethod
def swap_blocks(
src_kv_cache: torch.Tensor,
dst_kv_cache: torch.Tensor,
src_to_dst: torch.Tensor,
) -> None:
src_key_cache = src_kv_cache[0]
dst_key_cache = dst_kv_cache[0]
ops.swap_blocks(src_key_cache, dst_key_cache, src_to_dst)
src_value_cache = src_kv_cache[1]
dst_value_cache = dst_kv_cache[1]
ops.swap_blocks(src_value_cache, dst_value_cache, src_to_dst)
@staticmethod
def copy_blocks(
kv_caches: List[torch.Tensor],
src_to_dists: torch.Tensor,
) -> None:
key_caches = [kv_cache[0] for kv_cache in kv_caches]
value_caches = [kv_cache[1] for kv_cache in kv_caches]
ops.copy_blocks(key_caches, value_caches, src_to_dists)

View File

@@ -0,0 +1 @@
pip

View File

@@ -0,0 +1,33 @@
Metadata-Version: 2.1
Name: triton
Version: 2.1.0+corex.4.1.2
Summary: A language and compiler for custom Deep Learning operations
Home-page: https://github.com/openai/triton/
Author: Philippe Tillet
Author-email: phil@openai.com
Keywords: Compiler,Deep Learning
Classifier: Development Status :: 4 - Beta
Classifier: Intended Audience :: Developers
Classifier: Topic :: Software Development :: Build Tools
Classifier: License :: OSI Approved :: MIT License
Classifier: Programming Language :: Python :: 3.7
Classifier: Programming Language :: Python :: 3.8
Classifier: Programming Language :: Python :: 3.9
Classifier: Programming Language :: Python :: 3.10
Classifier: Programming Language :: Python :: 3.11
Requires-Dist: filelock
Provides-Extra: build
Requires-Dist: cmake>=3.18; extra == "build"
Requires-Dist: lit; extra == "build"
Provides-Extra: tests
Requires-Dist: autopep8; extra == "tests"
Requires-Dist: flake8; extra == "tests"
Requires-Dist: isort; extra == "tests"
Requires-Dist: numpy; extra == "tests"
Requires-Dist: pytest; extra == "tests"
Requires-Dist: scipy>=1.7.1; extra == "tests"
Provides-Extra: tutorials
Requires-Dist: matplotlib; extra == "tutorials"
Requires-Dist: pandas; extra == "tutorials"
Requires-Dist: tabulate; extra == "tutorials"

View File

@@ -0,0 +1,156 @@
triton-2.1.0+corex.4.1.2.dist-info/INSTALLER,sha256=zuuue4knoyJ-UwPPXg8fezS7VCrXJQrAP7zeNuwvFQg,4
triton-2.1.0+corex.4.1.2.dist-info/METADATA,sha256=WsBBXGUNH3GMUl-Sr8PVEA9wKr8pKyeQQ7BeQEq4Moc,1271
triton-2.1.0+corex.4.1.2.dist-info/RECORD,,
triton-2.1.0+corex.4.1.2.dist-info/REQUESTED,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
triton-2.1.0+corex.4.1.2.dist-info/WHEEL,sha256=jkyzK-PMDGe6NpviSR1tpgon24csQEXDXqxhDgOHeYM,105
triton-2.1.0+corex.4.1.2.dist-info/direct_url.json,sha256=rC0cgRYuAV96ZhqohICZYGcWgh85fJNAUbExKOBJuaE,284
triton-2.1.0+corex.4.1.2.dist-info/top_level.txt,sha256=g7iYLuhGmQFd2TwONpkqhhg7KJl6GMjbwGZvSp3_tnE,275
triton/_C/libtriton.so,sha256=TabhXpykxBhWhTTiPjVzF4qoFgKy2B5-GV04hVDF1HY,256934256
triton/__init__.py,sha256=J6RjARnKOLfAU59wgewEy82ZoICORUQe_mfHWHQ87lY,1267
triton/__pycache__/__init__.cpython-310.pyc,,
triton/__pycache__/testing.cpython-310.pyc,,
triton/common/__init__.py,sha256=uG0IrQy1GiwfCkg9FHqVcYHQ9dspmDtma8SSdFqKXew,48
triton/common/__pycache__/__init__.cpython-310.pyc,,
triton/common/__pycache__/build.cpython-310.pyc,,
triton/common/build.py,sha256=Ssz_enawYCeVbyHsmJRk9e81dvlFCIWNKikWcWJjZBo,4609
triton/compiler/__init__.py,sha256=VGVviF6fZSXDMUrcO9o0H96j9gQXcYDjnJx8TG18xhU,144
triton/compiler/__pycache__/__init__.cpython-310.pyc,,
triton/compiler/__pycache__/code_generator.cpython-310.pyc,,
triton/compiler/__pycache__/compiler.cpython-310.pyc,,
triton/compiler/__pycache__/errors.cpython-310.pyc,,
triton/compiler/__pycache__/make_launcher.cpython-310.pyc,,
triton/compiler/code_generator.py,sha256=2T0mFAy4TUOh1v2Mh3csLtztT9Aky0zNH-fjfLqx5W8,49616
triton/compiler/compiler.py,sha256=0_H9l2n4MlHKA9zfploeC9nF-oWC8sw5oUOvgGirRwU,23506
triton/compiler/errors.py,sha256=PiquMxHuHayRvdH3hMXSQlOnDswK3TGNWENB29YX_FU,1666
triton/compiler/make_launcher.py,sha256=dhi2y8cTVkaFUAFfH-s7gBq40ukhq18_Rl674lsS408,12458
triton/debugger/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
triton/debugger/__pycache__/__init__.cpython-310.pyc,,
triton/debugger/__pycache__/core.cpython-310.pyc,,
triton/debugger/__pycache__/debugger.cpython-310.pyc,,
triton/debugger/__pycache__/memory_map.cpython-310.pyc,,
triton/debugger/__pycache__/tl_lang.cpython-310.pyc,,
triton/debugger/__pycache__/torch_wrapper.cpython-310.pyc,,
triton/debugger/core.py,sha256=zjYVVAleV3EdBMOhMUR-c1X9H3613fu1ttOsG5n0VgQ,150
triton/debugger/debugger.py,sha256=-G9B4apnYC1woN-dTwdBkk4EdS-bavFnJ763_kvUr6o,5399
triton/debugger/memory_map.py,sha256=7xvtTJsJoGYqX4bwsJJyxJvl8H5YOnA3nnqwF5zifi8,3165
triton/debugger/tl_lang.py,sha256=P2ZZa9QNkLxp63PY7kSdTrEWU4mF5zzB3hsBIm1Xum0,17980
triton/debugger/torch_wrapper.py,sha256=iD0qbPnsR-rMG0anMJRoEhIDvQaVzFmBuFTKkPmnj_g,353
triton/language/__init__.py,sha256=lLiZyE7r2PJ3vQcUfw2EAkFaBJHLzG4YRUlJKGlOxrk,2884
triton/language/__pycache__/__init__.cpython-310.pyc,,
triton/language/__pycache__/core.cpython-310.pyc,,
triton/language/__pycache__/math.cpython-310.pyc,,
triton/language/__pycache__/random.cpython-310.pyc,,
triton/language/__pycache__/semantic.cpython-310.pyc,,
triton/language/__pycache__/standard.cpython-310.pyc,,
triton/language/core.py,sha256=cxhJ3qG4h2T85FUuaVdDTAdB7eXkTHzmaerFZd57cow,54553
triton/language/extra/__init__.py,sha256=zyhlj6Mo9_KD7E5szGxI18Ko3b31E00-s1ybOM5KzkQ,39
triton/language/extra/__pycache__/__init__.cpython-310.pyc,,
triton/language/extra/__pycache__/cuda.cpython-310.pyc,,
triton/language/extra/cuda.bc,sha256=03xfoq03DK6PfMS2fm05V90ei5jlnDdZaD6kZffQEas,1808
triton/language/extra/cuda.py,sha256=7UCCXfkwgLPpl87FRZsXyrHN6tt0xRoQeeBksSMdgyE,641
triton/language/math.py,sha256=g1z24oHGwyfZZNPkp9OUQ0_nXlnGEQCwAMcKKrRt1kE,75450
triton/language/random.py,sha256=yxzlT8jCKAv0yo4aVcWL5kBj4OOMyRcpbuGXMUJw0E0,5630
triton/language/semantic.py,sha256=hmyErm5bWEvlt2NgWPzubVUNBh4Bbfrb5BJ6y42VScM,62939
triton/language/standard.py,sha256=jpc-ha22ylEHusr-f4OUSLo7iuso3yF4cTjOTpSMqo0,2320
triton/ops/__init__.py,sha256=KM696wNNUi4gHanlZo54WcfmYnRWkykHKGnuhRoqAgQ,370
triton/ops/__pycache__/__init__.cpython-310.pyc,,
triton/ops/__pycache__/bmm_matmul.cpython-310.pyc,,
triton/ops/__pycache__/cross_entropy.cpython-310.pyc,,
triton/ops/__pycache__/flash_attention.cpython-310.pyc,,
triton/ops/__pycache__/matmul.cpython-310.pyc,,
triton/ops/__pycache__/matmul_perf_model.cpython-310.pyc,,
triton/ops/blocksparse/__init__.py,sha256=6YEVQNzipgQCpoO_7B8H7ckaSW2Idt1244s7IyLWAwc,100
triton/ops/blocksparse/__pycache__/__init__.cpython-310.pyc,,
triton/ops/blocksparse/__pycache__/matmul.cpython-310.pyc,,
triton/ops/blocksparse/__pycache__/softmax.cpython-310.pyc,,
triton/ops/blocksparse/matmul.py,sha256=uBOKf2pS5fFtGokQvrWS0GrVHNipIrsTZ_-vPr5lfOY,15617
triton/ops/blocksparse/softmax.py,sha256=eFvPebQZ__8L4pCmks-PsqpCG2pgheRL2f0Z4WTKrts,7893
triton/ops/bmm_matmul.py,sha256=grVD0Z800EWMC35TtZZoGgQihZqRUfnjJ9PfOpJmhfA,6381
triton/ops/cross_entropy.py,sha256=NRZdHy8nh70E0B_7BIHc-3-rkQF5Nb9nBnbfxfjHUAs,3457
triton/ops/flash_attention.py,sha256=AFjILTHWCkYXHDNkEBdrcsYIQ4OmYvYsKh7G2StXfH8,10335
triton/ops/matmul.py,sha256=bMoTeSnGBrR5oeSgtUr6syeLNPepydVpy1l9Ts-5E-g,8447
triton/ops/matmul_perf_model.py,sha256=KpYV6bJ9dRqg2js-1vLm_uPQyOytbUT8jaSVeleIXmI,6672
triton/runtime/__init__.py,sha256=4N2bT8rp10BAMHQzF9CiIt276t7XhA22c1uJkNB2ejk,516
triton/runtime/__pycache__/__init__.cpython-310.pyc,,
triton/runtime/__pycache__/autotuner.cpython-310.pyc,,
triton/runtime/__pycache__/cache.cpython-310.pyc,,
triton/runtime/__pycache__/driver.cpython-310.pyc,,
triton/runtime/__pycache__/errors.cpython-310.pyc,,
triton/runtime/__pycache__/jit.cpython-310.pyc,,
triton/runtime/autotuner.py,sha256=CYq8EkWf2mzAgubAvHheLtpsCYbrEdu7BbwKxYIhips,12718
triton/runtime/backends/cuda.c,sha256=5s9niGGe1tCvt1AW3O66jqiYQtpHQM9aUlKXppnTvO0,4680
triton/runtime/backends/hip.c,sha256=L37Uz-DsX5ntV_nI3OknmtB_xuufIWrfZAvW4SafFTA,4005
triton/runtime/cache.py,sha256=kpcFPdEbH3cO60HVdzhYtnOuIycXJnbDqWhbuF5hkFQ,4190
triton/runtime/driver.py,sha256=UIsodifreXt2d9NLErHR7wu7Fbr4nCpkhfOtYniMhnE,5100
triton/runtime/errors.py,sha256=XuA6URwCy4e3iYYbX9k35X89wnoasp7_K1LAPvsTbMQ,591
triton/runtime/jit.py,sha256=9LQhkSlhcKoWzx_x6E9E51eByfAj1U7yjwAyj-Mh57g,21468
triton/testing.py,sha256=akS7bFcE0LVMi56WCZSt7Wu4CfkRGsasJ9_tL1Zj3Vo,15520
triton/third_party/cuda/bin/ptxas,sha256=65yHgyKTqojU5uWOzQ134l3rk9pcGou5HefXzVsCgaw,20495600
triton/third_party/cuda/include/cuda.h,sha256=yCBNCpZt_qVabzyVattc2qrEq36p5MeL3kDzvNdeV4A,790319
triton/third_party/cuda/lib/libdevice.10.bc,sha256=XC-uN8huaMOjhgWpX1EtfRLV89uYYxC-R_VzBKpype4,473728
triton/third_party/rocm/lib/bitcode/asanrtl.bc,sha256=cHu7_b3BolfsDJMlWnGTY71qI6fNNdSVPDMFIojOh9k,24708
triton/third_party/rocm/lib/bitcode/cuda2gcn.bc,sha256=UDl-zwwWB5Ad__lIlhReRQBx61m0nMORxdppdTxkjQo,41568
triton/third_party/rocm/lib/bitcode/cuda2gcn.patch,sha256=EuHo5MOHjwbqdnuv53532LaFAuRh3C3J1QfRo0B3NJM,600
triton/third_party/rocm/lib/bitcode/hip.bc,sha256=yttUFcAzSE3ydohSCvOlTp5QZ3SjSH4juoSOYXz_zGg,2372
triton/third_party/rocm/lib/bitcode/make_cuda2gcn.sh,sha256=9jMtbtlZBIY6yyEp8V12-6Nw4WahdrCPFyFnRb1kpls,348
triton/third_party/rocm/lib/bitcode/ockl.bc,sha256=S3XZYIoykPyJ8XOFeVdqyRcdqxoBu50-AaU8oTdAhS8,227392
triton/third_party/rocm/lib/bitcode/oclc_abi_version_400.bc,sha256=K3TaLRC6uoHXTMK6WS7NgLJzjXeKKfvf_8a6lqfNwAs,1920
triton/third_party/rocm/lib/bitcode/oclc_abi_version_500.bc,sha256=9e6Z5f4lm_5asnMSHuPOrrnwlV6jlStgrwYTEiOU5tc,1920
triton/third_party/rocm/lib/bitcode/oclc_correctly_rounded_sqrt_off.bc,sha256=DiRiGcUfL17ukJS40RV8OHCWvEU-Oi4oS5NCn-SRAFY,1936
triton/third_party/rocm/lib/bitcode/oclc_correctly_rounded_sqrt_on.bc,sha256=cVJuiMvffHLxRqvPmzpbMp8nX7yxtUojn9A-iBfCmaY,1936
triton/third_party/rocm/lib/bitcode/oclc_daz_opt_off.bc,sha256=hcifwxlJIx5HlOcj_LjOUeqezaZ2v1xzX6LHNSQSkbE,1920
triton/third_party/rocm/lib/bitcode/oclc_daz_opt_on.bc,sha256=jYa1MyGdVRH6etbC6gicJ8egrwwV4wVHbe4aHRQ27R8,1920
triton/third_party/rocm/lib/bitcode/oclc_finite_only_off.bc,sha256=gWpVuB2OXig5vPczOQshtvZ_oJO72R5fFRiAYTILhgM,1928
triton/third_party/rocm/lib/bitcode/oclc_finite_only_on.bc,sha256=z9bPfFcxKqMILV3SsRJUIk6ExpZHak2Asv8FP9Ni9bw,1928
triton/third_party/rocm/lib/bitcode/oclc_isa_version_1010.bc,sha256=0vFumpEzejM0Xz7hSSbEYT-iZCrPrDeTRicpD-H-jXM,1920
triton/third_party/rocm/lib/bitcode/oclc_isa_version_1011.bc,sha256=VqawQ_trHRs7rn9NQEkKQ0gDO1BW73hkl9ZdP5Ixybc,1920
triton/third_party/rocm/lib/bitcode/oclc_isa_version_1012.bc,sha256=2C8uRNtmebw0dHh8Vg5_PuFLiY4x19wcydnS7VvadOo,1920
triton/third_party/rocm/lib/bitcode/oclc_isa_version_1013.bc,sha256=lZfO_8vUqSCn1o3VDALvFN-Cclq3eOYDe7-FthD0Gng,1920
triton/third_party/rocm/lib/bitcode/oclc_isa_version_1030.bc,sha256=WIadAgcRBfWphRmcrQJUYewZXWtSN7npAAMAkNQ0C1w,1920
triton/third_party/rocm/lib/bitcode/oclc_isa_version_1031.bc,sha256=ofxROlV79qeW1SyCo9VmJ-W9CVK7zxZtbWLDA-nKRh0,1920
triton/third_party/rocm/lib/bitcode/oclc_isa_version_1032.bc,sha256=IXOzwf4sdsw453PcColeLVjEQek4H1X3G-IfqvoZfhA,1920
triton/third_party/rocm/lib/bitcode/oclc_isa_version_1033.bc,sha256=4KNBcZx7R7e6CJE-4-Ge3v4xzOIuGOOOdWhpga_txbU,1920
triton/third_party/rocm/lib/bitcode/oclc_isa_version_1034.bc,sha256=nmMcRhEdWZcSzMbO0VCj98VTujkBsswVyQiozF4Pgo4,1920
triton/third_party/rocm/lib/bitcode/oclc_isa_version_1035.bc,sha256=EfSk25ZwCSNNBVOhKoGvkKpPXmW5WSYPTCqD0PCwAB0,1920
triton/third_party/rocm/lib/bitcode/oclc_isa_version_1036.bc,sha256=L7hI21xe_IX3IFyQf_A73k1YXEf95w16rLV1BB6FwQE,1920
triton/third_party/rocm/lib/bitcode/oclc_isa_version_1100.bc,sha256=_UxQflZQXU8FJcg7B6VKJtJJwJRKdDY2qiauHkP6nFs,2096
triton/third_party/rocm/lib/bitcode/oclc_isa_version_1101.bc,sha256=Lb83xl4VE1cQjOxz0gMwWS1_p-sq1wHjs67DpEI2RL0,2096
triton/third_party/rocm/lib/bitcode/oclc_isa_version_1102.bc,sha256=Xqalbf-GktJPs_ZOrg54FCCdxzMAZ4U7Ny7HNqh4Blc,2096
triton/third_party/rocm/lib/bitcode/oclc_isa_version_1103.bc,sha256=KMSqCG4J8tDGDRUoYK97xw-Gj0CeFJN0bYVn7Y8ot70,2096
triton/third_party/rocm/lib/bitcode/oclc_isa_version_600.bc,sha256=mkocG9cx-o5vwHCT98Fcq4v_fol3RUk24VPaJ2KNiJ0,1920
triton/third_party/rocm/lib/bitcode/oclc_isa_version_601.bc,sha256=Ib-dn5dosOuXVbBdG59-RERNDk2eMX4vvWUm0Z8Qqpc,1920
triton/third_party/rocm/lib/bitcode/oclc_isa_version_602.bc,sha256=QH5FGGkLH3fFhPJneGRrucoP44FBTKjm7VMKzGSHeDA,1920
triton/third_party/rocm/lib/bitcode/oclc_isa_version_700.bc,sha256=ibu033_NCSnqAYqqD70CRVLeym0RFpgm-PX0q-Y2yDs,1920
triton/third_party/rocm/lib/bitcode/oclc_isa_version_701.bc,sha256=P3d62xvb3_OSmACax3BeobuBAOP04X_0FjXQDKIL7bk,1920
triton/third_party/rocm/lib/bitcode/oclc_isa_version_702.bc,sha256=WmuD3Kq09F2KQ3u6mFCfE0MoDZWQAfjwDkOHwZpBnsA,1920
triton/third_party/rocm/lib/bitcode/oclc_isa_version_703.bc,sha256=3_iPbveGJ9fJa1Kvu-uEEGbuRAOrbTMauJGJNFPBvHI,1920
triton/third_party/rocm/lib/bitcode/oclc_isa_version_704.bc,sha256=0PEze_BCWRcbGBeM_Jaote9xmzsW7hDCwbeuHs55PGo,1920
triton/third_party/rocm/lib/bitcode/oclc_isa_version_705.bc,sha256=QBqWlg1VDWgQPXDvuPzWQ8--WRP3oSyk25nv70yQXMU,1920
triton/third_party/rocm/lib/bitcode/oclc_isa_version_801.bc,sha256=iQKrQ9rQx2UiET0BQ-1RgibWE4_aouTdPoaTg0gUEZk,1920
triton/third_party/rocm/lib/bitcode/oclc_isa_version_802.bc,sha256=9euokmFJ4pmx6IfkcBCpnHoRJrHqQ1I_zZMIYY7j7Ks,1920
triton/third_party/rocm/lib/bitcode/oclc_isa_version_803.bc,sha256=-v7D1qhgp4YwsFgYL2R3gN4cN0vR-6RGxMGxMm7RpKM,1920
triton/third_party/rocm/lib/bitcode/oclc_isa_version_805.bc,sha256=MalaxEN85FaTuJ9CUahDNKaCR2CZj81YNzVeXOHsDOA,1920
triton/third_party/rocm/lib/bitcode/oclc_isa_version_810.bc,sha256=8Dbv2N9xQB93CzJ0DWe1YLpstI6SpEH71XABjEJu7BQ,1920
triton/third_party/rocm/lib/bitcode/oclc_isa_version_900.bc,sha256=j3tpNoOQGtuHd4isYNNPdtYVGSmeXbIiHfLgPRioS8Q,1920
triton/third_party/rocm/lib/bitcode/oclc_isa_version_902.bc,sha256=3SlZ_VGBX9y8cRm9hF3XdtadPFJFLsIaoOnGdnqtdFY,1920
triton/third_party/rocm/lib/bitcode/oclc_isa_version_904.bc,sha256=OLEXkdaKEefZgzsommw6c5378MDasYKH_PSkGoU7hGU,1920
triton/third_party/rocm/lib/bitcode/oclc_isa_version_906.bc,sha256=KeOq_NoSwkh3D-1eYbzEvYcLqx-a4XHlMib3GQhtoFY,1920
triton/third_party/rocm/lib/bitcode/oclc_isa_version_908.bc,sha256=PFAdOwLxneisb1dnPMU8nkm8kslmNuCGWbVFOP9bgj0,1920
triton/third_party/rocm/lib/bitcode/oclc_isa_version_909.bc,sha256=J4Cwry1PZErUfvaYu4ZNCct1NQMDle0mz84ycbKJp6g,1920
triton/third_party/rocm/lib/bitcode/oclc_isa_version_90a.bc,sha256=ojZ7tLg2N7Hq1omZIrCkBRzia0bYtOdGSU4CmkkdfhE,1920
triton/third_party/rocm/lib/bitcode/oclc_isa_version_90c.bc,sha256=6mVoOg82NG_NfT2Su6fB18DFgRBRtOSQnRL2QgjRdHY,1920
triton/third_party/rocm/lib/bitcode/oclc_isa_version_940.bc,sha256=Z5ahEe7z2mwBgxuBySEOG7-rRRdpVACwawch5PE6s2I,1920
triton/third_party/rocm/lib/bitcode/oclc_unsafe_math_off.bc,sha256=evf88YMj0hBicj04zDdTM0Xlu4KG9fpUlj3iGS0F87A,1928
triton/third_party/rocm/lib/bitcode/oclc_unsafe_math_on.bc,sha256=bYF53Lu769UGewikZaQMfoxOerQBGs3SfLpQXoj3zVk,1928
triton/third_party/rocm/lib/bitcode/oclc_wavefrontsize64_off.bc,sha256=XlcwAnD8EeEHQ6FksQG3-Rxc07zMt7aFXVd-xmGBeDo,1928
triton/third_party/rocm/lib/bitcode/oclc_wavefrontsize64_on.bc,sha256=s7t5si7PLY1kIAs0oJGY-9uSL-lMKtcXdGZqzXyFteI,1928
triton/third_party/rocm/lib/bitcode/ocml.bc,sha256=bWAX1FuKc-cJrmLlC6kYhvs5fzRoBbc0Y7BP5qgg0BA,192228
triton/third_party/rocm/lib/bitcode/opencl.bc,sha256=mL2MGdVis93KLQWVYYj0pT6d8iS99DKlbnEe-1VGdss,2841028
triton/tools/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
triton/tools/__pycache__/__init__.cpython-310.pyc,,
triton/tools/__pycache__/aot.cpython-310.pyc,,
triton/tools/__pycache__/build_extern.cpython-310.pyc,,
triton/tools/__pycache__/disasm.cpython-310.pyc,,
triton/tools/aot.py,sha256=DUb9edluQKqfvY6aeIyQk2dN1Ljv3i6GYystyObqc8o,4752
triton/tools/build_extern.py,sha256=t5vdo6a3bcFfawMQBnTderw7Xi-L49x4udomoOWRUtw,14661
triton/tools/disasm.py,sha256=FOtqCFjUedLoZtb1hektBD32L1YxC43qnWKZr6VXc5E,4593

View File

@@ -0,0 +1,5 @@
Wheel-Version: 1.0
Generator: bdist_wheel (0.44.0)
Root-Is-Purelib: false
Tag: cp310-cp310-linux_x86_64

View File

@@ -0,0 +1 @@
{"archive_info": {"hash": "sha256=0b2734afe117d3c56cda0a7cd1f4752510dc91d8678638321dcbaee07110da40", "hashes": {"sha256": "0b2734afe117d3c56cda0a7cd1f4752510dc91d8678638321dcbaee07110da40"}}, "url": "file:///usr/local/apps_whl/triton-2.1.0%2Bcorex.4.1.2-cp310-cp310-linux_x86_64.whl"}

View File

@@ -0,0 +1,15 @@
triton
triton/_C
triton/common
triton/compiler
triton/debugger
triton/language
triton/language/extra
triton/ops
triton/ops/blocksparse
triton/runtime
triton/runtime/backends
triton/third_party/cuda/bin
triton/third_party/cuda/include
triton/third_party/cuda/lib
triton/tools

BIN
pkgs/triton/_C/libtriton.so Executable file

Binary file not shown.

68
pkgs/triton/__init__.py Normal file
View File

@@ -0,0 +1,68 @@
"""isort:skip_file"""
__version__ = '2.1.0'
# ---------------------------------------
# Note: import order is significant here.
# submodules
from .runtime import (
autotune,
Config,
heuristics,
JITFunction,
KernelInterface,
reinterpret,
TensorWrapper,
OutOfResources,
MockTensor,
)
from .runtime.jit import jit
from .compiler import compile, CompilationError
from .debugger.debugger import program_ids_from_grid
from . import language
from . import testing
__all__ = [
"autotune",
"cdiv",
"CompilationError",
"compile",
"Config",
"heuristics",
"impl",
"jit",
"JITFunction",
"KernelInterface",
"language",
"MockTensor",
"next_power_of_2",
"ops",
"OutOfResources",
"reinterpret",
"runtime",
"TensorWrapper",
"testing",
"program_ids_from_grid",
]
# -------------------------------------
# misc. utilities that don't fit well
# into any specific module
# -------------------------------------
def cdiv(x, y):
return (x + y - 1) // y
def next_power_of_2(n):
"""Return the smallest power of 2 greater than or equal to n"""
n -= 1
n |= n >> 1
n |= n >> 2
n |= n >> 4
n |= n >> 8
n |= n >> 16
n += 1
return n

Binary file not shown.

Binary file not shown.

View File

@@ -0,0 +1,3 @@
from .build import _build
__all__ = ["_build"]

Binary file not shown.

137
pkgs/triton/common/build.py Normal file
View File

@@ -0,0 +1,137 @@
import contextlib
import functools
import io
import os
import shutil
import subprocess
import sys
import sysconfig
import setuptools
# TODO: is_hip shouldn't be here
def is_hip():
import torch
return torch.version.hip is not None
def is_corex():
import torch
return hasattr(torch, "corex") and torch.corex == True
@functools.lru_cache()
def cuda_home_dirs():
loc = subprocess.check_output(["whereis", "clang++"]).decode().strip().split()[1]
default_dir = os.path.dirname(os.path.dirname(loc))
return os.getenv("CUDA_HOME", default=default_dir)
@functools.lru_cache()
def libcuda_dirs():
locs = subprocess.check_output(["whereis", "libcuda.so"]).decode().strip().split()[1:]
return [os.path.dirname(loc) for loc in locs]
@functools.lru_cache()
def rocm_path_dir():
return os.getenv("ROCM_PATH", default="/opt/rocm")
@contextlib.contextmanager
def quiet():
old_stdout, old_stderr = sys.stdout, sys.stderr
sys.stdout, sys.stderr = io.StringIO(), io.StringIO()
try:
yield
finally:
sys.stdout, sys.stderr = old_stdout, old_stderr
def _build(name, src, srcdir):
if is_hip():
hip_lib_dir = os.path.join(rocm_path_dir(), "lib")
hip_include_dir = os.path.join(rocm_path_dir(), "include")
else:
if is_corex():
cuda_path = cuda_home_dirs()
cu_include_dir = os.path.join(cuda_path, "include")
cuda_lib_dirs = [os.path.join(cuda_path, "lib64")]
else:
cuda_lib_dirs = libcuda_dirs()
base_dir = os.path.join(os.path.dirname(__file__), os.path.pardir)
cuda_path = os.path.join(base_dir, "third_party", "cuda")
cu_include_dir = os.path.join(cuda_path, "include")
triton_include_dir = os.path.join(os.path.dirname(__file__), "include")
cuda_header = os.path.join(cu_include_dir, "cuda.h")
triton_cuda_header = os.path.join(triton_include_dir, "cuda.h")
if not os.path.exists(cuda_header) and os.path.exists(triton_cuda_header):
cu_include_dir = triton_include_dir
suffix = sysconfig.get_config_var('EXT_SUFFIX')
so = os.path.join(srcdir, '{name}{suffix}'.format(name=name, suffix=suffix))
# try to avoid setuptools if possible
cc = os.environ.get("CC")
if cc is None:
# TODO: support more things here.
clang = shutil.which("clang")
gcc = shutil.which("gcc")
if is_corex():
cc = clang if clang is not None else gcc
else:
cc = gcc if gcc is not None else clang
if cc is None:
raise RuntimeError("Failed to find C compiler. Please specify via CC environment variable.")
# This function was renamed and made public in Python 3.10
if hasattr(sysconfig, 'get_default_scheme'):
scheme = sysconfig.get_default_scheme()
else:
scheme = sysconfig._get_default_scheme()
# 'posix_local' is a custom scheme on Debian. However, starting Python 3.10, the default install
# path changes to include 'local'. This change is required to use triton with system-wide python.
if scheme == 'posix_local':
scheme = 'posix_prefix'
py_include_dir = sysconfig.get_paths(scheme=scheme)["include"]
if is_hip():
ret = subprocess.check_call([cc, src, f"-I{hip_include_dir}", f"-I{py_include_dir}", f"-I{srcdir}", "-shared", "-fPIC", f"-L{hip_lib_dir}", "-lamdhip64", "-o", so])
else:
cc_cmd = [cc, src, "-O3", f"-I{cu_include_dir}", f"-I{py_include_dir}", f"-I{srcdir}", "-shared", "-fPIC", "-lcuda", "-o", so]
cc_cmd += [f"-L{dir}" for dir in cuda_lib_dirs]
ret = subprocess.check_call(cc_cmd)
if ret == 0:
return so
# fallback on setuptools
extra_compile_args = []
library_dirs = cuda_lib_dirs
include_dirs = [srcdir, cu_include_dir]
libraries = ['cuda']
# extra arguments
extra_link_args = []
# create extension module
ext = setuptools.Extension(
name=name,
language='c',
sources=[src],
include_dirs=include_dirs,
extra_compile_args=extra_compile_args + ['-O3'],
extra_link_args=extra_link_args,
library_dirs=library_dirs,
libraries=libraries,
)
# build extension module
args = ['build_ext']
args.append('--build-temp=' + srcdir)
args.append('--build-lib=' + srcdir)
args.append('-q')
args = dict(
name=name,
ext_modules=[ext],
script_args=args,
)
with quiet():
setuptools.setup(**args)
return so

View File

@@ -0,0 +1,4 @@
from .compiler import CompiledKernel, compile
from .errors import CompilationError
__all__ = ["compile", "CompiledKernel", "CompilationError"]

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,631 @@
from __future__ import annotations
import functools
import hashlib
import json
import os
import re
import subprocess
import tempfile
from collections import namedtuple
from pathlib import Path
from typing import Any, Tuple
import triton
import triton._C.libtriton.triton as _triton
from ..runtime import driver
# TODO: runtime.errors
from ..runtime.autotuner import OutOfResources
from ..runtime.cache import get_cache_manager
from ..tools.disasm import extract
from .code_generator import ast_to_ttir
from .make_launcher import make_stub
def is_corex():
import torch
return hasattr(torch, "corex") and torch.corex == True
CUDA_DEFAULT_WARP_SIZE = 64 if is_corex() else 32
def inline_triton_ir(mod):
pm = _triton.ir.pass_manager(mod.context)
pm.enable_debug()
pm.add_inliner_pass()
pm.run(mod)
return mod
def ttir_compute_capability_rewrite(mod, arch):
# For hardware without support, we must rewrite all load/store
# with block (tensor) pointers into tensors of pointers
pm = _triton.ir.pass_manager(mod.context)
pm.enable_debug()
if _is_cuda(arch):
pm.add_rewrite_tensor_pointer_pass(arch)
pm.run(mod)
return mod
def optimize_ttir(mod, arch):
mod = inline_triton_ir(mod)
mod = ttir_compute_capability_rewrite(mod, arch)
pm = _triton.ir.pass_manager(mod.context)
pm.enable_debug()
pm.add_inliner_pass()
pm.add_triton_combine_pass()
pm.add_canonicalizer_pass()
pm.add_cse_pass()
pm.add_licm_pass()
pm.add_symbol_dce_pass()
pm.run(mod)
return mod
def ttir_to_ttgir(mod, num_warps, warpsize):
pm = _triton.ir.pass_manager(mod.context)
pm.enable_debug()
pm.add_convert_triton_to_tritongpu_pass(num_warps, warpsize)
pm.run(mod)
return mod
def optimize_ttgir(mod, num_stages, arch, use_sme = 0):
pm = _triton.ir.pass_manager(mod.context)
pm.enable_debug()
pm.add_tritongpu_coalesce_pass()
pm.add_tritongpu_remove_layout_conversions_pass()
if _is_cuda(arch):
pm.add_tritongpu_accelerate_matmul_pass(arch, use_sme)
# TODO change interface of accelerate_matmul_pass
if is_hip() and gpu_has_mfma():
pm.add_tritongpu_accelerate_matmul_pass(80)
pm.add_tritongpu_remove_layout_conversions_pass()
pm.add_tritongpu_optimize_dot_operands_pass()
if is_corex():
pm.add_tritongpu_matmul_smeload_pass(arch) #BI 70 MR > 71,only MR support sme
pm.add_tritongpu_remove_layout_conversions_pass()
# TODO enable this pass for AMD GPU when it is ready
if not is_hip():
pm.add_tritongpu_pipeline_pass(num_stages)
if not is_corex():
pm.add_tritongpu_prefetch_pass()
pm.add_tritongpu_optimize_dot_operands_pass()
pm.add_tritongpu_remove_layout_conversions_pass()
pm.add_tritongpu_decompose_conversions_pass()
pm.add_tritongpu_reorder_instructions_pass()
pm.add_cse_pass()
pm.add_symbol_dce_pass()
pm.run(mod)
return mod
def _add_external_libs(mod, libs):
for name, path in libs.items():
if len(name) == 0 or len(path) == 0:
return
_triton.add_external_libs(mod, list(libs.keys()), list(libs.values()))
def ttgir_to_llir(mod, extern_libs, arch):
if extern_libs:
_add_external_libs(mod, extern_libs)
# TODO: separate tritongpu_to_llvmir for different backends
if _is_cuda(arch):
return _triton.translate_triton_gpu_to_llvmir(mod, arch, False)
else:
return _triton.translate_triton_gpu_to_llvmir(mod, 0, True)
# PTX translation
@functools.lru_cache()
def ptx_get_version(cuda_version) -> int:
'''
Get the highest PTX version supported by the current CUDA driver.
'''
assert isinstance(cuda_version, str)
major, minor = map(int, cuda_version.split('.'))
if major == 12:
return 80 + minor
if major == 11:
return 70 + minor
if major == 10:
return 63 + minor
raise RuntimeError("Triton only support CUDA 10.0 or higher")
@functools.lru_cache()
def path_to_ptxas():
base_dir = os.path.join(os.path.dirname(__file__), os.pardir)
paths = [
os.environ.get("TRITON_PTXAS_PATH", ""),
os.path.join(base_dir, "third_party", "cuda", "bin", "ptxas")
]
for ptxas in paths:
if os.path.exists(ptxas) and os.path.isfile(ptxas):
result = subprocess.check_output([ptxas, "--version"], stderr=subprocess.STDOUT)
if result is not None:
version = re.search(r".*release (\d+\.\d+).*", result.decode("utf-8"), flags=re.MULTILINE)
if version is not None:
return ptxas, version.group(1)
raise RuntimeError("Cannot find ptxas")
def llir_to_cubin(mod: Any, arch: int):
'''
Compile LLVM module to cubin.
:param mod: a LLVM module
:return: str
'''
return _triton.translate_llvmir_to_cubin(mod, arch)
def llir_to_ptx(mod: Any, arch: int, ptx_version: int = None) -> str:
'''
Translate TritonGPU module to PTX code.
:param mod: a TritonGPU dialect module
:return: PTX code
'''
if ptx_version is None:
_, cuda_version = path_to_ptxas()
ptx_version = ptx_get_version(cuda_version)
return _triton.translate_llvmir_to_ptx(mod, arch, ptx_version)
def ptx_to_cubin(ptx: str, arch: int):
'''
Compile TritonGPU module to cubin.
:param ptx: ptx code
:param compute_capability: compute capability
:return: str
'''
ptxas, _ = path_to_ptxas()
return _triton.compile_ptx_to_cubin(ptx, ptxas, arch)
# AMDGCN translation
def get_amdgcn_bitcode_paths(arch):
gpu_arch_agnostic_bitcode_libraries = ["opencl.bc",
"ocml.bc",
"ockl.bc",
"oclc_finite_only_off.bc",
"oclc_daz_opt_off.bc",
"oclc_correctly_rounded_sqrt_on.bc",
"oclc_unsafe_math_off.bc",
"oclc_wavefrontsize64_on.bc",
"oclc_abi_version_400.bc",]
gfx_arch = arch[1]
gfx_arch_id = re.search('gfx(\\w+)', gfx_arch).group(1).strip()
gpu_arch_specific_bitcode_library = 'oclc_isa_version_' + gfx_arch_id + ".bc"
bitcode_path_dir = os.path.join(Path(__file__).parent.parent.resolve(), "third_party/rocm/lib/bitcode/")
amdgcn_bitcode_paths = {}
i = 0
for bc_lib in gpu_arch_agnostic_bitcode_libraries:
bc_path = bitcode_path_dir + bc_lib
if os.path.exists(bc_path):
amdgcn_bitcode_paths['library_' + str(i)] = bc_path
i += 1
bc_gfx_path = bitcode_path_dir + gpu_arch_specific_bitcode_library
if os.path.exists(bc_gfx_path):
amdgcn_bitcode_paths['library_' + str(i)] = bc_gfx_path
return amdgcn_bitcode_paths
def get_amdgpu_arch_fulldetails():
"""
get the amdgpu fulll ISA details for compiling:
i.e., arch_triple: amdgcn-amd-amdhsa; arch_name: gfx906; arch_features: sramecc+:xnack-
"""
try:
# TODO: package rocm.cc with Triton
arch_info = _triton.get_arch_info()
warpsize = _triton.get_warp_size()
gfx_arch_details = re.search('amd.*', arch_info).group(0).strip().split('--')
arch_triple = gfx_arch_details[0]
arch_name_features = gfx_arch_details[1].split(':')
arch_name = arch_name_features[0]
arch_features = ""
if (len(arch_name_features) == 3):
arch_features = "+" + re.search('\\w+', arch_name_features[1]).group(0) + ","\
"-" + re.search('\\w+', arch_name_features[2]).group(0)
return [arch_triple, arch_name, arch_features, warpsize]
except BaseException:
return None
def llir_to_amdgcn_and_hsaco(mod: Any, gfx_arch: str, gfx_triple: str, gfx_features: str) -> Tuple[str, str]:
'''
Translate TritonGPU module to HSACO code based on full details of gpu architecture.
:param mod: a TritonGPU dialect module
:return:
- AMDGCN code
- Path to HSACO object
'''
return _triton.translate_llvmir_to_hsaco(mod, gfx_arch, gfx_triple, gfx_features)
# ------------------------------------------------------------------------------
# compiler
# ------------------------------------------------------------------------------
def get_kernel_name(src: str, pattern: str, llir: bool = False) -> str:
'''
Get kernel name from PTX code.
This Kernel name is required when launching the kernel.
'''
# There is a name mangling in PTX codegen, so the original kernel names in Triton IR are not available in PTX/cubin.
assert src
for line in src.split('\n'):
line = line.strip()
if line.startswith(pattern):
if not llir:
return line.split()[-1]
return line.split("(")[0].split("@")[-1]
def convert_type_repr(x):
match = re.search(r'!tt\.ptr<(.*)>', x)
if match is not None:
return '*' + convert_type_repr(match.group(1))
return x
def make_hash(fn, arch, **kwargs):
if isinstance(fn, triton.runtime.JITFunction):
configs = kwargs["configs"]
signature = kwargs["signature"]
constants = kwargs.get("constants", dict())
num_warps = kwargs.get("num_warps", 4)
num_stages = kwargs.get("num_stages", 3)
debug = kwargs.get("debug", False)
use_sme = kwargs.get("use_sme", 0)
# Get unique key for the compiled code
get_conf_key = lambda conf: (sorted(conf.divisible_by_16), sorted(conf.equal_to_1))
configs_key = [get_conf_key(conf) for conf in configs]
key = f"{fn.cache_key}-{''.join(signature.values())}-{configs_key}-{constants}-{num_warps}-{num_stages}-{debug}-{arch}-{use_sme}"
return hashlib.md5(key.encode("utf-8")).hexdigest()
assert isinstance(fn, str)
return hashlib.md5((Path(fn).read_text() + triton.runtime.jit.version_key()).encode("utf-8")).hexdigest()
# - ^\s*tt\.func\s+ : match the start of the string, any leading whitespace, the keyword func,
# and any following whitespace
# - (public\s+)? : optionally match the keyword public and any following whitespace
# - (@\w+) : match an @ symbol followed by one or more word characters
# (letters, digits, or underscores), and capture it as group 1 (the function name)
# - (\((?:%\w+: \S+(?: \{\S+ = \S+ : \S+\})?(?:, )?)*\)) : match a pair of parentheses enclosing
# zero or more arguments separated by commas, and capture it as group 2 (the argument list)
mlir_prototype_pattern = r'^\s*tt\.func\s+(?:public\s+)?(@\w+)(\((?:%\w+: \S+(?: \{\S+ = \S+ : \S+\})?(?:, )?)*\))\s*\{\s*$'
ptx_prototype_pattern = r"\.(?:visible|extern)\s+\.(?:entry|func)\s+(\w+)\s*\(([^)]*)\)"
prototype_pattern = {
"ttir": mlir_prototype_pattern,
"ttgir": mlir_prototype_pattern,
"ptx": ptx_prototype_pattern,
}
mlir_arg_type_pattern = r'%\w+: ([^,^\)\s]+)(?: \{\S+ = \S+ : \S+\})?,?'
ptx_arg_type_pattern = r"\.param\s+\.(\w+)"
arg_type_pattern = {
"ttir": mlir_arg_type_pattern,
"ttgir": mlir_arg_type_pattern,
"ptx": ptx_arg_type_pattern,
}
ttgir_num_warps_pattern = r'"triton_gpu.num-warps"\s?=\s?(\d+)\s?:'
def _get_jsonable_constants(constants):
def _is_jsonable(x):
try:
json.dumps(x)
return True
except (TypeError, OverflowError):
return False
serialized_constants = {}
for constant in constants:
if _is_jsonable(constants[constant]):
serialized_constants[constant] = constants[constant]
return serialized_constants
def parse_mlir_module(path, context):
module = _triton.ir.parse_mlir_module(path, context)
# module takes ownership of the context
module.context = context
return module
instance_descriptor = namedtuple("instance_descriptor", ["divisible_by_16", "equal_to_1"], defaults=[set(), set()])
# TODO: architecture descriptor class
def _is_cuda(arch):
return isinstance(arch, int)
def is_hip():
try:
import torch
except ImportError:
raise ImportError("Triton requires PyTorch to be installed")
return torch.version.hip is not None
from ..language.semantic import gpu_has_mfma
def get_architecture_descriptor(capability):
try:
import torch
except ImportError:
raise ImportError("Triton requires PyTorch to be installed")
if capability is None:
if torch.version.hip is None:
device = triton.runtime.jit.get_current_device()
capability = triton.runtime.jit.get_device_capability(device)
capability = capability[0] * 10 + capability[1]
else:
capability = get_amdgpu_arch_fulldetails()
return capability
def add_rocm_stages(arch, extern_libs, stages):
extern_libs.update(get_amdgcn_bitcode_paths(arch))
for key in list(extern_libs):
if extern_libs[key] == '' or extern_libs[key] is None:
extern_libs.pop(key)
gfx_arch_full_details = arch
gfx_arch = os.environ.get('MI_GPU_ARCH', gfx_arch_full_details[1])
if gfx_arch is None:
raise RuntimeError('gfx_arch is None (not specified)')
stages["amdgcn"] = (lambda path: Path(path).read_text(),
lambda src: llir_to_amdgcn_and_hsaco(src, gfx_arch,
gfx_arch_full_details[0],
gfx_arch_full_details[2]))
def add_cuda_stages(arch, extern_libs, stages):
stages["ptx"] = (lambda path: Path(path).read_text(),
lambda src: llir_to_ptx(src, arch))
stages["cubin"] = (lambda path: Path(path).read_bytes(),
lambda src: ptx_to_cubin(src, arch))
def add_iluvatar_stages(arch, extern_libs, stages):
stages["cubin"] = (lambda path: Path(path).read_bytes(),
lambda src: llir_to_cubin(src, arch))
def compile(fn, **kwargs):
if is_hip():
capability = None
else:
capability = kwargs.get("cc", None)
arch = get_architecture_descriptor(capability)
is_cuda = _is_cuda(arch)
warp_size = CUDA_DEFAULT_WARP_SIZE if _is_cuda(arch) else arch[3]
context = _triton.ir.context()
asm = dict()
constants = kwargs.get("constants", dict())
num_warps = kwargs.get("num_warps", 4)
num_stages = kwargs.get("num_stages", 3 if is_cuda and arch >= 75 else 2)
extern_libs = kwargs.get("extern_libs", dict())
use_sme = kwargs.get("use_sme", 0)
if extern_libs is None:
extern_libs = dict()
debug = kwargs.get("debug", False)
# build compilation stages
stages = dict()
stages["ast"] = (lambda path: fn, None)
stages["ttir"] = (lambda path: parse_mlir_module(path, context),
lambda src: optimize_ttir(ast_to_ttir(src, signature, configs[0], constants, debug=debug), arch))
stages["ttgir"] = (lambda path: parse_mlir_module(path, context),
lambda src: optimize_ttgir(ttir_to_ttgir(src, num_warps, warp_size), num_stages, arch, use_sme))
stages["llir"] = (lambda path: Path(path).read_text(),
lambda src: ttgir_to_llir(src, extern_libs, arch))
if is_cuda:
if is_corex():
add_iluvatar_stages(arch, extern_libs, stages)
else:
add_cuda_stages(arch, extern_libs, stages)
else:
add_rocm_stages(arch, extern_libs, stages)
# find out the signature of the function
if isinstance(fn, triton.runtime.JITFunction):
configs = kwargs.get("configs", None)
signature = kwargs["signature"]
if configs is None:
configs = [instance_descriptor()]
assert len(configs) == 1
kwargs["configs"] = configs
name = fn.__name__
first_stage = 0
if isinstance(signature, str):
signature = {k: v.strip() for k, v in enumerate(signature.split(","))}
kwargs["signature"] = signature
else:
assert isinstance(fn, str)
_, ir = os.path.basename(fn).split(".")
src = Path(fn).read_text()
import re
match = re.search(prototype_pattern[ir], src, re.MULTILINE)
name, signature = match.group(1), match.group(2)
types = re.findall(arg_type_pattern[ir], signature)
if ir == 'ttgir':
num_warps_matches = re.findall(ttgir_num_warps_pattern, src)
assert len(num_warps_matches) == 1, "Expected exactly one match for num_warps"
assert "num_warps" not in kwargs or int(num_warps_matches[0]) == num_warps, "num_warps in ttgir does not match num_warps in compile"
num_warps = int(num_warps_matches[0])
param_tys = [convert_type_repr(ty) for ty in types]
signature = {k: v for k, v in enumerate(param_tys)}
first_stage = list(stages.keys()).index(ir)
# cache manager
so_path = make_stub(name, signature, constants)
# create cache manager
fn_cache_manager = get_cache_manager(make_hash(fn, arch, **kwargs))
# determine name and extension type of provided function
if isinstance(fn, triton.runtime.JITFunction):
name, ext = fn.__name__, "ast"
else:
name, ext = os.path.basename(fn).split(".")
# load metadata if any
metadata = None
metadata_filename = f"{name}.json"
# The group is addressed by the metadata
metadata_group = fn_cache_manager.get_group(
metadata_filename
) or {}
metadata_path = metadata_group.get(metadata_filename)
if metadata_path is not None:
with open(metadata_path) as f:
metadata = json.load(f)
else:
metadata = {"num_warps": num_warps,
"warp_size": warp_size,
"num_stages": num_stages,
"constants": _get_jsonable_constants(constants),
"debug": debug,
"use_sme": use_sme}
if ext == "ptx":
assert "shared" in kwargs, "ptx compilation must provide shared memory size"
metadata["shared"] = kwargs["shared"]
first_stage = list(stages.keys()).index(ext)
asm = dict()
module = fn
# run compilation pipeline and populate metadata
for ir, (parse, compile_kernel) in list(stages.items())[first_stage:]:
ir_filename = f"{name}.{ir}"
if ir == ext:
next_module = parse(fn)
else:
path = metadata_group.get(ir_filename)
if path is None:
next_module = compile_kernel(module)
if ir == "amdgcn":
extra_file_name = f"{name}.hsaco_path"
metadata_group[ir_filename] = fn_cache_manager.put(next_module[0], ir_filename)
metadata_group[extra_file_name] = fn_cache_manager.put(next_module[1], extra_file_name)
else:
metadata_group[ir_filename] = fn_cache_manager.put(next_module, ir_filename)
fn_cache_manager.put(next_module, ir_filename)
else:
if ir == "amdgcn":
extra_file_name = f"{name}.hsaco_path"
hasco_path = metadata_group.get(extra_file_name)
assert hasco_path is not None, "Expected to have hsaco_path in metadata when we have the amdgcn"
next_module = (parse(path), parse(hasco_path))
else:
next_module = parse(path)
if ir == "llir":
metadata["name"] = get_kernel_name(next_module, pattern="define iluvatar_kernel void", llir=True)
if ir == "cubin":
asm[ir] = next_module
elif ir == "amdgcn":
asm[ir] = str(next_module[0])
else:
asm[ir] = str(next_module)
if ir == "llir" and "shared" not in metadata:
metadata["shared"] = _triton.get_shared_memory_size(module)
if ir == "ptx":
metadata["name"] = get_kernel_name(next_module, pattern='// .globl')
if ir == "amdgcn":
metadata["name"] = get_kernel_name(next_module[0], pattern='.globl')
asm["hsaco_path"] = next_module[1]
module = next_module
# write-back metadata, if it didn't come from the cache
if metadata_path is None:
metadata_group[metadata_filename] = fn_cache_manager.put(json.dumps(metadata), metadata_filename, binary=False)
fn_cache_manager.put_group(metadata_filename, metadata_group)
# return handle to compiled kernel
return CompiledKernel(fn, so_path, metadata, asm)
class CompiledKernel:
# Hooks for external tools to monitor the execution of triton kernels
launch_enter_hook = None
launch_exit_hook = None
def __init__(self, fn, so_path, metadata, asm):
# initialize launcher
import importlib.util
spec = importlib.util.spec_from_file_location("__triton_launcher", so_path)
mod = importlib.util.module_from_spec(spec)
self.fn = fn
spec.loader.exec_module(mod)
self.c_wrapper = getattr(mod, "launch")
# initialize metadata
self.shared = metadata["shared"]
self.num_warps = metadata["num_warps"]
self.warp_size = metadata["warp_size"]
self.num_stages = metadata["num_stages"]
self.constants = metadata["constants"]
# initialize asm dict
self.asm = asm
# binaries are lazily initialized
# because it involves doing runtime things
# (e.g., checking amount of shared memory on current device)
self.metadata = metadata
self.cu_module = None
self.cu_function = None
def _init_handles(self):
if self.cu_module is not None:
return
device = triton.runtime.jit.get_current_device()
bin_path = {
driver.HIP: "hsaco_path",
driver.CUDA: "cubin"
}[driver.backend]
max_shared = driver.utils.get_device_properties(device)["max_shared_mem"]
if self.shared > max_shared:
raise OutOfResources(self.shared, max_shared, "shared memory")
mod, func, n_regs, n_spills = driver.utils.load_binary(self.metadata["name"], self.asm[bin_path], self.shared, device)
self.n_spills = n_spills
self.n_regs = n_regs
self.cu_module = mod
self.cu_function = func
def __getattribute__(self, name):
if name == 'c_wrapper':
self._init_handles()
return super().__getattribute__(name)
def __getitem__(self, grid):
self._init_handles()
def runner(*args, stream=None):
if stream is None:
stream = triton.runtime.jit.get_cuda_stream()
self.c_wrapper(grid[0], grid[1], grid[2], self.num_warps, self.shared, stream, self.cu_function,
CompiledKernel.launch_enter_hook, CompiledKernel.launch_exit_hook, self, *args)
return runner
def get_sass(self, fun=None):
if 'sass' in self.asm:
return self.asm['sass']
fd, path = tempfile.mkstemp()
try:
with open(fd, 'wb') as cubin:
cubin.write(self.asm['cubin'])
self.sass = extract(path, fun)
finally:
os.remove(path)
self.asm['sass'] = self.sass
return self.sass

View File

@@ -0,0 +1,52 @@
import ast
from typing import Optional, Union
class CompilationError(Exception):
source_line_count_max_in_message = 12
def _format_message(self) -> str:
node = self.node
if self.src is None:
source_excerpt = " <source unavailable>"
else:
source_excerpt = self.src.split('\n')[:node.lineno][-self.source_line_count_max_in_message:]
if source_excerpt:
source_excerpt.append(' ' * node.col_offset + '^')
source_excerpt = '\n'.join(source_excerpt)
else:
source_excerpt = " <source empty>"
message = "at {}:{}:{}".format(node.lineno, node.col_offset, source_excerpt)
if self.error_message:
message += '\n' + self.error_message
return message
def __init__(self, src: Optional[str], node: ast.AST, error_message: Union[str, None]):
self.src = src
self.node = node
self.error_message = error_message
self.message = self._format_message()
def set_source_code(self, src: Optional[str]):
self.src = src
self.message = self._format_message()
def __str__(self):
return self.message
def __repr__(self):
return "{}({!r})".format(type(self).__name__, self.message)
def __reduce__(self):
# this is necessary to make CompilationError picklable
return type(self), (self.src, self.node, self.error_message)
class CompileTimeAssertionFailure(CompilationError):
"""Specific exception for failed tests in `static_assert` invocations"""
pass
class UnsupportedLanguageConstruct(CompilationError):
pass

View File

@@ -0,0 +1,392 @@
import hashlib
import os
import tempfile
from ..common import _build
from ..runtime.cache import get_cache_manager
from ..runtime.jit import version_key
def is_hip():
import torch
return torch.version.hip is not None
def is_corex():
import torch
return hasattr(torch, "corex") and torch.corex == True
# ----- stub --------
def make_so_cache_key(version_hash, signature, constants):
# Get unique key for the compiled code
signature = {k: 'ptr' if v[0] == '*' else v for k, v in signature.items()}
key = f"{version_hash}-{''.join(signature.values())}{constants}"
key = hashlib.md5(key.encode("utf-8")).hexdigest()
return key
def make_stub(name, signature, constants):
# name of files that are cached
so_cache_key = make_so_cache_key(version_key(), signature, constants)
so_cache_manager = get_cache_manager(so_cache_key)
so_name = f"{name}.so"
# retrieve stub from cache if it exists
cache_path = so_cache_manager.get_file(so_name)
if cache_path is None:
with tempfile.TemporaryDirectory() as tmpdir:
src = generate_launcher(constants, signature)
src_path = os.path.join(tmpdir, "main.c")
with open(src_path, "w") as f:
f.write(src)
so = _build(name, src_path, tmpdir)
so_cache_manager.put(src, f"{name}.c", binary=False)
with open(so, "rb") as f:
return so_cache_manager.put(f.read(), so_name, binary=True)
else:
return cache_path
# ----- source code generation --------
def ty_to_cpp(ty):
if ty[0] == '*':
return "hipDeviceptr_t" if is_hip() else "CUdeviceptr"
return {
"i1": "int32_t",
"i8": "int8_t",
"i16": "int16_t",
"i32": "int32_t",
"i64": "int64_t",
"u32": "uint32_t",
"u64": "uint64_t",
"fp16": "float",
"bf16": "float",
"fp32": "float",
"f32": "float",
"fp64": "double",
}[ty]
def generate_launcher(constants, signature):
arg_decls = ', '.join(f"{ty_to_cpp(ty)} arg{i}" for i, ty in signature.items())
def _extracted_type(ty):
if ty[0] == '*':
return "PyObject*"
return {
'i1': 'int32_t',
'i32': 'int32_t',
'i64': 'int64_t',
'u32': 'uint32_t',
'u64': 'uint64_t',
'fp16': 'float',
'bf16': 'float',
'fp32': 'float',
'f32': 'float',
'fp64': 'double',
}[ty]
def format_of(ty):
return {
"PyObject*": "O",
"float": "f",
"double": "d",
"long": "l",
"uint32_t": "I",
"int32_t": "i",
"uint64_t": "K",
"int64_t": "L",
}[ty]
format = "iiiiiKKOOO" + ''.join([format_of(_extracted_type(ty)) for ty in signature.values()])
# generate glue code
if is_hip():
src = f"""
#define __HIP_PLATFORM_AMD__
#include <hip/hip_runtime.h>
#include <Python.h>
#include <stdio.h>
static inline void gpuAssert(hipError_t code, const char *file, int line)
{{
if (code != HIP_SUCCESS)
{{
const char* prefix = "Triton Error [HIP]: ";
const char* str = hipGetErrorString(code);
char err[1024] = {{0}};
snprintf(err, 1024, "%s Code: %d, Messsage: %s", prefix, code, str );
PyErr_SetString(PyExc_RuntimeError, err);
}}
}}
#define HIP_CHECK(ans) {{ gpuAssert((ans), __FILE__, __LINE__); }}
static int getWarpSize(hipStream_t stream)
{{
int device_id = hipGetStreamDeviceId(stream);
gpuAssert(device_id >= 0 ? hipSuccess : hipErrorInvalidDevice, __FILE__, __LINE__);
hipDeviceProp_t prop;
HIP_CHECK(hipGetDeviceProperties(&prop, device_id));
return prop.warpSize;
}}
static void _launch(int gridX, int gridY, int gridZ, int num_warps, int shared_memory, hipStream_t stream, hipFunction_t function, {arg_decls}) {{
void *params[] = {{ {', '.join(f"&arg{i}" for i in signature.keys() if i not in constants)} }};
if (gridX*gridY*gridZ > 0) {{
int warp_size = getWarpSize(stream);
HIP_CHECK(hipModuleLaunchKernel(function, gridX, gridY, gridZ, num_warps * warp_size, 1, 1, shared_memory, stream, params, 0));
}}
}}
typedef struct _DevicePtrInfo {{
hipDeviceptr_t dev_ptr;
bool valid;
}} DevicePtrInfo;
static inline DevicePtrInfo getPointer(PyObject *obj, int idx) {{
DevicePtrInfo ptr_info;
ptr_info.dev_ptr = 0;
ptr_info.valid = true;
if (PyLong_Check(obj)) {{
ptr_info.dev_ptr = (hipDeviceptr_t)PyLong_AsUnsignedLongLong(obj);
return ptr_info;
}}
if (obj == Py_None) {{
// valid nullptr
return ptr_info;
}}
PyObject *ptr = PyObject_GetAttrString(obj, "data_ptr");
if (ptr) {{
PyObject *empty_tuple = PyTuple_New(0);
PyObject *ret = PyObject_Call(ptr, empty_tuple, NULL);
Py_DECREF(empty_tuple);
Py_DECREF(ptr);
if (!PyLong_Check(ret)) {{
PyErr_SetString(PyExc_TypeError, "data_ptr method of Pointer object must return 64-bit int");
ptr_info.valid = false;
return ptr_info;
}}
ptr_info.dev_ptr = (hipDeviceptr_t)PyLong_AsUnsignedLongLong(ret);
if (!ptr_info.dev_ptr)
return ptr_info;
uint64_t dev_ptr;
hipError_t status = hipPointerGetAttribute(&dev_ptr, HIP_POINTER_ATTRIBUTE_DEVICE_POINTER, ptr_info.dev_ptr);
if (status == hipErrorInvalidValue) {{
PyErr_Format(PyExc_ValueError,
"Pointer argument (at %d) cannot be accessed from Triton (cpu tensor?)", idx);
ptr_info.valid = false;
}}
ptr_info.dev_ptr = (hipDeviceptr_t)dev_ptr;
return ptr_info;
}}
PyErr_SetString(PyExc_TypeError, "Pointer argument must be either uint64 or have data_ptr method");
return ptr_info;
}}
static PyObject* launch(PyObject* self, PyObject* args) {{
int gridX, gridY, gridZ;
uint64_t _stream;
uint64_t _function;
int num_warps;
int shared_memory;
PyObject *launch_enter_hook = NULL;
PyObject *launch_exit_hook = NULL;
PyObject *compiled_kernel = NULL;
{' '.join([f"{_extracted_type(ty)} _arg{i}; " for i, ty in signature.items()])}
if (!PyArg_ParseTuple(args, \"{format}\", &gridX, &gridY, &gridZ, &num_warps, &shared_memory, &_stream, &_function, &launch_enter_hook, &launch_exit_hook, &compiled_kernel, {', '.join(f"&_arg{i}" for i, ty in signature.items())})) {{
return NULL;
}}
if (launch_enter_hook != Py_None) {{
PyObject_CallObject(launch_enter_hook, args);
}}
// raise exception asap
{"; ".join([f"DevicePtrInfo ptr_info{i} = getPointer(_arg{i}, {i}); if (!ptr_info{i}.valid) return NULL;" if ty[0] == "*" else "" for i, ty in signature.items()])};
_launch(gridX, gridY, gridZ, num_warps, shared_memory, (hipStream_t)_stream, (hipFunction_t)_function, {', '.join(f"ptr_info{i}.dev_ptr" if ty[0]=="*" else f"_arg{i}" for i, ty in signature.items())});
if (launch_exit_hook != Py_None) {{
PyObject_CallObject(launch_exit_hook, args);
}}
if (PyErr_Occurred()) {{
return NULL;
}}
// return None
Py_INCREF(Py_None);
return Py_None;
}}
static PyMethodDef ModuleMethods[] = {{
{{"launch", launch, METH_VARARGS, "Entry point for all kernels with this signature"}},
{{NULL, NULL, 0, NULL}} // sentinel
}};
static struct PyModuleDef ModuleDef = {{
PyModuleDef_HEAD_INIT,
\"__triton_launcher\",
NULL, //documentation
-1, //size
ModuleMethods
}};
PyMODINIT_FUNC PyInit___triton_launcher(void) {{
PyObject *m = PyModule_Create(&ModuleDef);
if(m == NULL) {{
return NULL;
}}
PyModule_AddFunctions(m, ModuleMethods);
return m;
}}
"""
else:
warp_size = 64 if is_corex() else 32
src = f"""
#include \"cuda.h\"
#include <stdbool.h>
#include <Python.h>
static inline void gpuAssert(CUresult code, const char *file, int line)
{{
if (code != CUDA_SUCCESS)
{{
const char* prefix = "Triton Error [CUDA]: ";
const char* str;
cuGetErrorString(code, &str);
char err[1024] = {{0}};
strcat(err, prefix);
strcat(err, str);
PyErr_SetString(PyExc_RuntimeError, err);
}}
}}
#define CUDA_CHECK(ans) {{ gpuAssert((ans), __FILE__, __LINE__); }}
static void _launch(int gridX, int gridY, int gridZ, int num_warps, int shared_memory, CUstream stream, CUfunction function, {arg_decls}) {{
void *params[] = {{ {', '.join(f"&arg{i}" for i in signature.keys() if i not in constants)} }};
if(gridX*gridY*gridZ > 0){{
CUDA_CHECK(cuLaunchKernel(function, gridX, gridY, gridZ, num_warps * {warp_size}, 1, 1, shared_memory, stream, params, 0));
}}
}}
typedef struct _DevicePtrInfo {{
CUdeviceptr dev_ptr;
bool valid;
}} DevicePtrInfo;
static inline DevicePtrInfo getPointer(PyObject *obj, int idx) {{
DevicePtrInfo ptr_info;
ptr_info.dev_ptr = 0;
ptr_info.valid = true;
if (PyLong_Check(obj)) {{
ptr_info.dev_ptr = PyLong_AsUnsignedLongLong(obj);
return ptr_info;
}}
if (obj == Py_None) {{
// valid nullptr
return ptr_info;
}}
PyObject *ptr = PyObject_GetAttrString(obj, "data_ptr");
if(ptr){{
PyObject *empty_tuple = PyTuple_New(0);
PyObject *ret = PyObject_Call(ptr, empty_tuple, NULL);
Py_DECREF(empty_tuple);
Py_DECREF(ptr);
if (!PyLong_Check(ret)) {{
PyErr_SetString(PyExc_TypeError, "data_ptr method of Pointer object must return 64-bit int");
ptr_info.valid = false;
return ptr_info;
}}
ptr_info.dev_ptr = PyLong_AsUnsignedLongLong(ret);
if(!ptr_info.dev_ptr)
return ptr_info;
/*
uint64_t dev_ptr;
int status = cuPointerGetAttribute(&dev_ptr, CU_POINTER_ATTRIBUTE_DEVICE_POINTER, ptr_info.dev_ptr);
if (status == CUDA_ERROR_INVALID_VALUE) {{
PyErr_Format(PyExc_ValueError,
"Pointer argument (at %d) cannot be accessed from Triton (cpu tensor?)", idx);
ptr_info.valid = false;
}}
ptr_info.dev_ptr = dev_ptr;
*/
Py_DECREF(ret); // Thanks ChatGPT!
return ptr_info;
}}
PyErr_SetString(PyExc_TypeError, "Pointer argument must be either uint64 or have data_ptr method");
return ptr_info;
}}
static PyObject* launch(PyObject* self, PyObject* args) {{
int gridX, gridY, gridZ;
uint64_t _stream;
uint64_t _function;
int num_warps;
int shared_memory;
PyObject *launch_enter_hook = NULL;
PyObject *launch_exit_hook = NULL;
PyObject *compiled_kernel = NULL;
{' '.join([f"{_extracted_type(ty)} _arg{i}; " for i, ty in signature.items()])}
if(!PyArg_ParseTuple(args, \"{format}\", &gridX, &gridY, &gridZ, &num_warps, &shared_memory, &_stream, &_function, &launch_enter_hook, &launch_exit_hook, &compiled_kernel, {', '.join(f"&_arg{i}" for i, ty in signature.items())})) {{
return NULL;
}}
if (launch_enter_hook != Py_None) {{
PyObject_CallObject(launch_enter_hook, args);
}}
// raise exception asap
{"; ".join([f"DevicePtrInfo ptr_info{i} = getPointer(_arg{i}, {i}); if (!ptr_info{i}.valid) return NULL;" if ty[0] == "*" else "" for i, ty in signature.items()])};
_launch(gridX, gridY, gridZ, num_warps, shared_memory, (CUstream)_stream, (CUfunction)_function, {', '.join(f"ptr_info{i}.dev_ptr" if ty[0]=="*" else f"_arg{i}"for i, ty in signature.items())});
if (launch_exit_hook != Py_None) {{
PyObject_CallObject(launch_exit_hook, args);
}}
if(PyErr_Occurred()) {{
return NULL;
}}
// return None
Py_INCREF(Py_None);
return Py_None;
}}
static PyMethodDef ModuleMethods[] = {{
{{"launch", launch, METH_VARARGS, "Entry point for all kernels with this signature"}},
{{NULL, NULL, 0, NULL}} // sentinel
}};
static struct PyModuleDef ModuleDef = {{
PyModuleDef_HEAD_INIT,
\"__triton_launcher\",
NULL, //documentation
-1, //size
ModuleMethods
}};
PyMODINIT_FUNC PyInit___triton_launcher(void) {{
PyObject *m = PyModule_Create(&ModuleDef);
if(m == NULL) {{
return NULL;
}}
PyModule_AddFunctions(m, ModuleMethods);
return m;
}}
"""
return src

View File

Binary file not shown.

View File

@@ -0,0 +1,9 @@
from typing import Tuple
import dataclasses
@dataclasses.dataclass
class ExecutionContext:
program_id: Tuple[int]
program_size: Tuple[int]

View File

@@ -0,0 +1,170 @@
import itertools
import random
from typing import Tuple
import triton
import triton.language as tl
from .core import ExecutionContext
from .memory_map import MemoryMap
from .tl_lang import (TritonLangProxy, WrappedTensor, _primitive_to_tensor,
debugger_constexpr)
from triton.debugger import torch_wrapper
torch = torch_wrapper.torch
tl_method_backup = {}
def get_proxy_method(proxy, name):
method = getattr(proxy, name)
def fun(*args, **kwarg):
return method(*args, **kwarg)
return fun
def attach_triton(module, proxy):
method_list = [func for func in dir(TritonLangProxy) if func[0] != "_"]
for name in method_list:
if hasattr(module, name):
attr = getattr(module, name)
tl_method_backup[name] = attr
if callable(attr):
setattr(module, name, get_proxy_method(proxy, name))
else:
setattr(module, name, getattr(proxy, name))
def detach_triton(module):
for name, method in tl_method_backup.items():
setattr(module, name, method)
def program_ids_from_grid(grid: Tuple[int, ...]) -> Tuple[int, ...]:
# reverse the grid dimensions and generate the range for each dimension
reversed_grid = reversed(grid)
ranges_for_each_dimension = [range(dim) for dim in reversed_grid]
# gen all combinations
index_combinations = list(itertools.product(*ranges_for_each_dimension))
random.shuffle(index_combinations)
for index_combination in index_combinations:
yield index_combination
class DebuggerFunction:
def __init__(self, func, grid=(1,)):
self.func = func
self.grid = grid
def _is_constexpr(self, name):
return name in self.func.__annotations__ and self.func.__annotations__[name] is triton.language.core.constexpr
def _get_constexpr(self):
result = []
for name, annotation in self.func.__annotations__.items():
if annotation is triton.language.core.constexpr:
result.append(name)
return result
def _assert_constexpr(self, **kwargs):
constexp = self._get_constexpr()
missing = [i for i in constexp if i not in kwargs.keys()]
assert len(missing) == 0, f"You must specify constexpr {missing}"
def _get_grid(self, **kwargs):
if callable(self.grid):
return self.grid(kwargs)
else:
return self.grid
def __call__(self, *args, **kwargs):
self._assert_constexpr(**kwargs)
memory = MemoryMap()
def convert_arg(v):
name, arg = v
if torch.is_tensor(arg):
ptr = memory.add_tensor(arg)
return WrappedTensor(torch.tensor([ptr], dtype=torch.int64, device="cuda"))
if self._is_constexpr(name):
return debugger_constexpr(arg)
return WrappedTensor(_primitive_to_tensor(arg))
new_args = tuple(map(convert_arg, zip(self.func.__code__.co_varnames, args)))
new_kwargs = {k: convert_arg((k, v)) for (k, v) in kwargs.items() if k not in ["num_warps", "num_stages"]}
grid = self._get_grid(**kwargs)
for program_id in program_ids_from_grid(grid):
proxy = TritonLangProxy(memory, ExecutionContext(program_id, grid))
attach_triton(tl, proxy)
self.func(*new_args, **new_kwargs)
detach_triton(tl)
class GridSelector:
"""
Entry point of the debugger
"""
def __init__(self, func):
version = torch.__version__
assert version[0] == "2", f"Triton Debugger only supports torch >= 2.0, using {version}"
self.func = func
def __getitem__(self, grid):
return DebuggerFunction(self.func, grid)
def __call__(self, *args, **kwargs):
return DebuggerFunction(self.func)(*args, **kwargs)
class AutotuneGridSelector:
def __init__(self, func, autotune_params):
self.func = func
self.autotune_params = autotune_params
def __getitem__(self, grid):
return AutotuneRunner(self.func, self.autotune_params, grid)
def __call__(self, *args, **kwargs):
return AutotuneRunner(self.func, self.autotune_params)(*args, **kwargs)
class AutotuneRunner:
def __init__(self, func, autotune_params, grid=None):
self.func = func
self.autotune_params = autotune_params
self.grid = grid
def __call__(self, *args, **kwargs):
assert len(self.autotune_params["configs"]) >= 1
for config in self.autotune_params["configs"][1:]:
def convert_arg(v):
if torch.is_tensor(v):
return torch.clone(v)
return v
new_args = tuple(map(convert_arg, args))
new_kwargs = {k: convert_arg(v) for k, v in kwargs.items()}
if self.grid:
self.func[self.grid](*new_args, **new_kwargs, **config.kwargs)
else:
self.func(*new_args, **new_kwargs, **config.kwargs)
main_config = self.autotune_params["configs"][0]
if self.grid:
self.func[self.grid](*args, **kwargs, **main_config.kwargs)
else:
self.func(*args, **kwargs, **main_config.kwargs)
def triton_debug_autotune(**kwars):
def wrapper(func):
return AutotuneGridSelector(func, kwars)
return wrapper

View File

@@ -0,0 +1,100 @@
import dataclasses
from triton.debugger import torch_wrapper
torch = torch_wrapper.torch
@dataclasses.dataclass
class RegisteredStorage:
storage: torch.Storage
dtype: torch.dtype
size: int
ptr: int
@property
def end_ptr(self) -> int:
return self.ptr + self.size
@property
def access_tensor(self) -> torch.Tensor:
return torch.tensor(self.storage, dtype=self.dtype, device=self.storage.device)
def ensure_immutable(self):
assert self.storage.data_ptr() == self.ptr and self.storage.size() == self.size
class MemoryMap:
storages: [RegisteredStorage]
def __init__(self):
self.storages = []
def _get_registered_storage(self, pointer: torch.Tensor):
max_pointer = torch.max(pointer).item()
min_pointer = torch.min(pointer).item()
registered_storage = next(
filter(
lambda registered: min_pointer >= registered.ptr and max_pointer < registered.end_ptr, self.storages
),
None,
)
if registered_storage is None:
raise Exception("Storage not found or pointers spanning multiple tensors")
registered_storage.ensure_immutable()
return registered_storage
def add_tensor(self, t: torch.Tensor):
storage = t.untyped_storage()
self.storages.append(RegisteredStorage(storage, t.dtype, storage.size(), storage.data_ptr()))
return t.data_ptr()
def load(
self,
pointer: torch.Tensor,
mask: torch.Tensor = None,
other=0.0,
):
assert pointer.is_cuda
assert 0 < pointer.dim() < 3
assert pointer.dtype == torch.int64
if mask is None:
mask = torch.ones_like(pointer).bool()
assert mask.is_cuda
assert 0 < mask.dim() < 3
assert mask.dtype == torch.bool
mask = mask.expand(pointer.size())
if torch.all(~mask):
# Todo: The type is wrong here, we can't determine the correct type
return torch.full_like(pointer, fill_value=other, dtype=torch.float16, device="cuda")
registered_storage = self._get_registered_storage(pointer[mask])
access_tensor = registered_storage.access_tensor
index_tensor = pointer - registered_storage.ptr
block = torch.full_like(pointer, fill_value=other, dtype=access_tensor.dtype, device="cuda")
block[mask] = access_tensor[index_tensor[mask]]
return block
def store(self, pointer: torch.Tensor, value: torch.Tensor, mask=None):
assert 0 < pointer.dim() < 3
assert pointer.dtype == torch.int64
if mask is None:
mask = torch.ones_like(pointer).bool()
assert 0 < mask.dim() < 3
assert mask.dtype == torch.bool
mask = mask.expand(pointer.size())
if torch.all(~mask):
return
registered_storage = self._get_registered_storage(pointer[mask])
access_tensor = registered_storage.access_tensor
index_tensor = pointer - registered_storage.ptr
access_tensor[index_tensor[mask]] = value[mask].to(access_tensor.dtype)

View File

@@ -0,0 +1,621 @@
import triton
from .core import ExecutionContext
from .memory_map import MemoryMap
from triton.debugger import torch_wrapper
torch = torch_wrapper.torch
def _primitive_to_tensor(x):
"""
Converts various Python primitive data types to PyTorch tensor.
"""
tensor_args = {"device": "cuda"}
if isinstance(x, bool):
return torch.tensor([x], dtype=torch.bool, **tensor_args)
elif isinstance(x, int):
if -(2**31) <= x < 2**31:
return torch.tensor([x], dtype=torch.int32, **tensor_args)
elif -(2**63) <= x < 2**63:
return torch.tensor([x], dtype=torch.int64, **tensor_args)
else:
raise RuntimeError(f"Nonrepresentable integer {x}.")
elif isinstance(x, float):
return torch.tensor([x], dtype=torch.float32, **tensor_args)
elif torch.is_tensor(x):
return x
elif isinstance(x, WrappedTensor):
return x
elif isinstance(x, debugger_constexpr):
if x.value is None:
return None
return _primitive_to_tensor(x.value)
elif x is None:
return None
assert False, f"cannot convert {x} of type {type(x)} to tensor"
def _infer_tensor(func):
"""
A decorator function to harmonize function args:
- converts primitives to PyTorch tensors
- wraps PyTorch tensors with WrappedTensors
"""
def wrapper(*args):
new_args = tuple(map(lambda v: _primitive_to_tensor(v), args))
new_args = tuple(map(lambda v: WrappedTensor(v) if torch.is_tensor(v) else v, new_args))
return func(*new_args)
return wrapper
def _tensor_operation(func):
"""
A decorator function to unwrap WrappedTensors and debugger_constexpr before calling the function.
Can be combined with _infer_tensor decorator to harmonize args (everything to torch tensor).
"""
def wrapper(*args, **kwargs):
for arg in args:
assert not torch.is_tensor(arg), "unexpected tensor argument"
def unwrap_tensor(v):
if isinstance(v, WrappedTensor):
return v.tensor
if isinstance(v, debugger_constexpr):
return v.value
return v
new_args = tuple(map(unwrap_tensor, args))
new_kwargs = {k: unwrap_tensor(v) for k, v in kwargs.items()}
result = func(args[0], *new_args[1:], **new_kwargs)
return WrappedTensor(result) if torch.is_tensor(result) else result
return wrapper
class debugger_constexpr:
def __init__(self, value):
if isinstance(value, debugger_constexpr):
self.value = value.value
else:
self.value = value
def __str__(self) -> str:
return "debugger_constexpr(" + str(self.value) + ")"
def __index__(self) -> int:
return self.value
def __bool__(self):
return bool(self.value)
def __ge__(self, other):
other = other.value if isinstance(other, debugger_constexpr) else other
return self.value >= other
def __gt__(self, other):
other = other.value if isinstance(other, debugger_constexpr) else other
return self.value > other
def __le__(self, other):
other = other.value if isinstance(other, debugger_constexpr) else other
return self.value <= other
def __lt__(self, other):
other = other.value if isinstance(other, debugger_constexpr) else other
return self.value < other
def __eq__(self, other):
other = other.value if isinstance(other, debugger_constexpr) else other
return self.value == other
def __or__(self, other):
other = other.value if isinstance(other, debugger_constexpr) else other
return self.value | other
def __ror__(self, other):
other = other.value if isinstance(other, debugger_constexpr) else other
return self.value | other
def __and__(self, other):
other = other.value if isinstance(other, debugger_constexpr) else other
return self.value & other
def __rand__(self, other):
other = other.value if isinstance(other, debugger_constexpr) else other
return self.value & other
def to(self, dtype, bitcast=False, _builder=None):
if dtype in [torch.int64]:
ret_ty = int
elif dtype == torch.bool:
ret_ty = bool
elif dtype in [torch.float64]:
ret_ty = float
else:
raise ValueError("dtype not supported in debugger")
return debugger_constexpr(ret_ty(self.value))
class WrappedTensor:
def __init__(self, tensor):
self.tensor = tensor
def __index__(self) -> int:
return self.tensor.item()
def __str__(self) -> str:
return "wrapped_" + str(self.tensor)
def __bool__(self) -> bool:
return torch.all(self.tensor == True).item() # noqa: E712
@property
def dtype(self):
return self.tensor.dtype
@_infer_tensor
@_tensor_operation
def __add__(self, other):
return torch.add(self.tensor, other)
@_infer_tensor
@_tensor_operation
def __radd__(self, other):
return self.__add__(other)
@_infer_tensor
@_tensor_operation
def __sub__(self, other):
return torch.sub(self.tensor, other)
@_infer_tensor
@_tensor_operation
def __rsub__(self, other):
return torch.sub(other, self.tensor)
@_infer_tensor
@_tensor_operation
def __mul__(self, other):
return torch.mul(self.tensor, other)
@_infer_tensor
@_tensor_operation
def __rmul__(self, other):
return self.__mul__(other)
@_infer_tensor
@_tensor_operation
def __truediv__(self, other):
return torch.div(self.tensor, other)
@_infer_tensor
@_tensor_operation
def __rtruediv__(self, other):
return torch.div(other, self.tensor)
@_infer_tensor
@_tensor_operation
def __floordiv__(self, other):
return torch.floor_divide(self.tensor, other)
@_infer_tensor
@_tensor_operation
def __rfloordiv__(self, other):
return torch.floor_divide(other, self.tensor)
@_infer_tensor
@_tensor_operation
def __mod__(self, other):
return torch.remainder(self.tensor, other)
@_infer_tensor
@_tensor_operation
def __rmod__(self, other):
return torch.remainder(other, self.tensor)
@_infer_tensor
@_tensor_operation
def __neg__(self):
return -self.tensor
@_infer_tensor
@_tensor_operation
def __invert__(self):
return ~self.tensor
@_infer_tensor
@_tensor_operation
def __and__(self, other):
return torch.bitwise_and(self.tensor, other)
@_infer_tensor
@_tensor_operation
def __or__(self, other):
return torch.bitwise_or(self.tensor, other)
@_infer_tensor
@_tensor_operation
def __xor__(self, other):
return torch.bitwise_xor(self.tensor, other)
@_infer_tensor
@_tensor_operation
def __lshift__(self, other):
return torch.bitwise_left_shift(self.tensor, other)
@_infer_tensor
@_tensor_operation
def __rshift__(self, other):
return torch.bitwise_right_shift(self.tensor, other)
@_infer_tensor
@_tensor_operation
def __gt__(self, other):
return self.tensor > other
@_infer_tensor
@_tensor_operation
def __rgt__(self, other):
return other > self.tensor
@_infer_tensor
@_tensor_operation
def __ge__(self, other):
return self.tensor >= other
@_infer_tensor
@_tensor_operation
def __rge__(self, other):
return other >= self.tensor
@_infer_tensor
@_tensor_operation
def __lt__(self, other):
return self.tensor < other
@_infer_tensor
@_tensor_operation
def __rlt__(self, other):
return other < self.tensor
@_infer_tensor
@_tensor_operation
def __le__(self, other):
return self.tensor <= other
@_infer_tensor
@_tensor_operation
def __rle__(self, other):
return other <= self.tensor
@_infer_tensor
@_tensor_operation
def __eq__(self, other):
return torch.equal(self.tensor, other)
@_infer_tensor
@_tensor_operation
def __ne__(self, other):
return not torch.equal(self.tensor, other)
@_tensor_operation
def __getitem__(self, slices):
return self.tensor.__getitem__(slices)
# if isinstance(slices, slice):
# slices = [slices]
# src_shape = self.shape
# dst_shape = []
# curr = 0
# for sl in slices:
# if isinstance(sl, constexpr) and sl.value is None:
# dst_shape.append(1)
# elif sl == slice(None, None, None):
# dst_shape.append(src_shape[curr].value)
# curr += 1
# ret = torch.reshape(self.tensor, dst_shape, )
# return ret
@_tensor_operation
def to(self, dtype, bitcast=False):
return self.tensor.to(dtype)
# if isinstance(bitcast, constexpr):
# bitcast = bitcast.value
# if bitcast:
# return semantic.bitcast(self, dtype, )
# return semantic.cast(self, dtype, )
def _constexpr_to_value(v):
if isinstance(v, debugger_constexpr):
return v.value
return v
class TritonLangProxy:
_memory_map: MemoryMap
_context: ExecutionContext
def __init__(self, memory_map: MemoryMap, context: ExecutionContext):
self._memory_map = memory_map
self._context = context
# Types
# Removed void, int1, float8, uint16, uint32, uint64, pi32_t
# constexpr = debugger_constexpr
# Program functions
@_tensor_operation
def load(
self,
pointer: torch.Tensor,
mask: torch.Tensor = None,
other=0.0,
cache_modifier="",
eviction_policy="",
volatile=False,
):
return self._memory_map.load(pointer, mask, other)
@_tensor_operation
def store(self, pointer: torch.Tensor, value: torch.Tensor, mask=None):
return self._memory_map.store(pointer, value, mask)
@_tensor_operation
def program_id(self, axis):
assert axis < len(self._context.program_id)
return torch.tensor([self._context.program_id[axis]], dtype=torch.int32, device="cuda")
@_tensor_operation
def num_programs(self, axis):
assert axis < len(self._context.program_size)
return torch.tensor([self._context.program_size[axis]], dtype=torch.int32, device="cuda")
@_tensor_operation
def arange(self, start, end):
return torch.arange(start=start, end=end, dtype=torch.int32, device="cuda")
@_tensor_operation
def zeros(self, shape, dtype):
for i, d in enumerate(shape):
if not isinstance(d, debugger_constexpr):
raise TypeError(f"Shape element {i} must have type `constexpr`")
if not isinstance(d.value, int):
raise TypeError(f"Shape element {i} must have type `constexpr[int]`, got `constexpr[{type(d.value)}]")
shape = [x.value for x in shape]
if isinstance(dtype, triton.language.core.dtype):
if dtype.is_fp32():
dtype = torch.float32
elif dtype.is_fp16():
dtype = torch.float16
elif dtype.is_bf16():
dtype = torch.bfloat16
elif dtype.is_int32():
dtype = torch.int32
elif dtype.is_int16():
dtype = torch.int16
elif dtype.is_int8():
dtype = torch.int8
else:
raise TypeError(f"Unsupported dtype {dtype}")
return torch.zeros(size=shape, dtype=dtype, device="cuda")
@_tensor_operation
def dequantize(self, input, scale, shift, nbit, dst_ty=torch.float16):
raise NotImplementedError()
@_tensor_operation
def broadcast(self, input, other):
raise NotImplementedError()
@_tensor_operation
def broadcast_to(self, input, shape):
raise NotImplementedError()
@_tensor_operation
def cat(self, input, shape):
raise NotImplementedError()
@_tensor_operation
def reshape(self, input, shape):
raise NotImplementedError()
@_tensor_operation
def dot(self, input, other, trans_a=False, trans_b=False, allow_tf32=True):
assert input.dtype == other.dtype
if trans_a:
input = input.T
if trans_b:
other = other.T
return torch.matmul(input=input, other=other)
@_tensor_operation
def atomic_cas(self, pointer, cmp, val):
stored = self._memory_map.load(pointer, None, 0.0)
if not isinstance(cmp, torch.Tensor):
cmp = torch.tensor([cmp], dtype=stored.dtype, device="cuda")
if not isinstance(val, torch.Tensor):
val = torch.tensor([val], dtype=stored.dtype, device="cuda")
if stored == cmp:
self._memory_map.store(pointer, val, None)
return stored
@_tensor_operation
def atomic_xchg(self, pointer, val, mask=None):
if isinstance(val, int):
val = torch.tensor([val], dtype=torch.int32, device="cuda")
stored = self._memory_map.load(pointer, mask, 0.0)
self._memory_map.store(pointer, val, mask)
return stored
@_tensor_operation
def atomic_add(self, pointer, val, mask=None):
# arbitrary other value as it will masked during storing
stored = self._memory_map.load(pointer, mask, 0.0)
result = stored + val
self._memory_map.store(pointer, result, mask)
return stored
@_tensor_operation
def atomic_max(self, pointer, val, mask=None):
stored = self._memory_map.load(pointer, mask, 0.0)
result = torch.maximum(stored, val)
self._memory_map.store(pointer, result, mask)
return stored
@_tensor_operation
def atomic_min(self, pointer, val, mask=None):
stored = self._memory_map.load(pointer, mask, 0.0)
result = torch.minimum(stored, val)
self._memory_map.store(pointer, result, mask)
return stored
@_tensor_operation
def atomic_and(self, pointer, val, mask=None):
stored = self._memory_map.load(pointer, mask, 0)
result = torch.bitwise_and(stored, val)
self._memory_map.store(pointer, result, mask)
return stored
@_tensor_operation
def atomic_or(self, pointer, val, mask=None):
stored = self._memory_map.load(pointer, mask, 0)
result = torch.bitwise_or(stored, val)
self._memory_map.store(pointer, result, mask)
return stored
@_tensor_operation
def atomic_xor(self, pointer, val, mask=None):
stored = self._memory_map.load(pointer, mask, 0)
result = torch.bitwise_xor(stored, val)
self._memory_map.store(pointer, result, mask)
return stored
@_tensor_operation
def where(self, condition, x, y):
condition = _primitive_to_tensor(condition)
x = _primitive_to_tensor(x)
y = _primitive_to_tensor(y)
return torch.where(condition, x, y)
@_tensor_operation
def umulhi(self, x, y):
raise NotImplementedError()
@_tensor_operation
def fdiv(self, x, y, ieee_rounding=False):
raise NotImplementedError()
@_tensor_operation
def exp(self, x):
return torch.exp(x)
@_tensor_operation
def log(self, x):
return torch.log(x)
@_tensor_operation
def cos(self, x):
return torch.cos(x)
@_tensor_operation
def sin(self, x):
return torch.sin(x)
@_tensor_operation
def sqrt(self, x):
return torch.sqrt(x)
@_tensor_operation
def globaltimer(self):
raise NotImplementedError()
@_tensor_operation
def clock(self):
raise NotImplementedError()
@_tensor_operation
def debug_barrier(self):
raise NotImplementedError()
@_tensor_operation
def multiple_of(self, input, values):
return input
@_tensor_operation
def max_contiguous(self, input, values):
return input
@_tensor_operation
def abs(self, x):
return torch.abs(x)
@_tensor_operation
def cdiv(self, x, div):
return (x + div - 1) // div
@_tensor_operation
def minimum(self, x, y):
if isinstance(x, int):
x = torch.tensor(x, device="cuda")
if isinstance(y, int):
y = torch.tensor(y, device="cuda")
return torch.minimum(x, y)
@_tensor_operation
def maximum(self, x, y):
return torch.maximum(x, y)
@_tensor_operation
def sigmoid(self, x):
raise NotImplementedError()
@_tensor_operation
def softmax(self, x, ieee_rounding=False):
raise NotImplementedError()
@_tensor_operation
def ravel(self, x):
raise NotImplementedError()
@_tensor_operation
def swizzle2d(self, i, j, size_i, size_j, size_g):
raise NotImplementedError()
@_tensor_operation
def zeros_like(self, input):
raise NotImplementedError()
@_tensor_operation
def max(self, input, axis=None):
if axis is None:
return torch.max(input)
return torch.max(input, dim=axis).values
@_tensor_operation
def argmax(self, input, axis):
raise NotImplementedError()
@_tensor_operation
def min(self, input, axis=None):
if axis is None:
return torch.min(input)
return torch.min(input, dim=axis).values
@_tensor_operation
def argmin(self, input, axis):
raise NotImplementedError()
@_tensor_operation
def sum(self, input, axis=None):
if axis is None:
return torch.sum(input)
return torch.sum(input, dim=axis)
@_tensor_operation
def xor_sum(self, input, axis):
raise NotImplementedError()

View File

@@ -0,0 +1,18 @@
try:
import torch as _torch
except ImportError:
_torch = None
class TorchWrapper:
"""
Helps in making torch an optional dependency
"""
def __getattr__(self, name):
if _torch is None:
raise ImportError("Triton requires PyTorch to be installed")
return getattr(_torch, name)
torch = TorchWrapper()

View File

@@ -0,0 +1,201 @@
"""isort:skip_file"""
# Import order is significant here.
from . import math
from . import extra
from .standard import (
cdiv,
sigmoid,
softmax,
ravel,
swizzle2d,
zeros,
zeros_like,
)
from .core import (
abs,
advance,
arange,
argmin,
argmax,
atomic_add,
atomic_and,
atomic_cas,
atomic_max,
atomic_min,
atomic_or,
atomic_xchg,
atomic_xor,
bfloat16,
block_type,
broadcast,
broadcast_to,
cat,
constexpr,
cos,
debug_barrier,
device_assert,
device_print,
dot,
dtype,
exp,
expand_dims,
full,
fdiv,
float16,
float32,
float64,
float8e4,
float8e5,
function_type,
int1,
int16,
int32,
int64,
int8,
load,
log,
make_block_ptr,
max,
max_contiguous,
maximum,
min,
minimum,
multiple_of,
num_programs,
pi32_t,
pointer_type,
program_id,
reduce,
reshape,
sin,
sqrt,
static_assert,
static_print,
store,
sum,
static_range,
tensor,
trans,
triton,
uint16,
uint32,
uint64,
uint8,
umulhi,
view,
void,
where,
xor_sum,
)
from .random import (
pair_uniform_to_normal,
philox,
philox_impl,
rand,
rand4x,
randint,
randint4x,
randn,
randn4x,
uint32_to_uniform_float,
)
__all__ = [
"abs",
"advance",
"arange",
"argmin",
"argmax",
"atomic_add",
"atomic_and",
"atomic_cas",
"atomic_max",
"atomic_min",
"atomic_or",
"atomic_xchg",
"atomic_xor",
"bfloat16",
"block_type",
"broadcast",
"broadcast_to",
"builtin",
"cat",
"cdiv",
"constexpr",
"cos",
"debug_barrier",
"device_assert",
"device_print",
"dot",
"dtype",
"exp",
"expand_dims",
"extra",
"fdiv",
"float16",
"float32",
"float64",
"float8e4",
"float8e5",
"full",
"function_type",
"int1",
"int16",
"int32",
"int64",
"int8",
"ir",
"math",
"load",
"log",
"make_block_ptr",
"max",
"max_contiguous",
"maximum",
"min",
"minimum",
"multiple_of",
"num_programs",
"pair_uniform_to_normal",
"philox",
"philox_impl",
"pi32_t",
"pointer_type",
"program_id",
"rand",
"rand4x",
"randint",
"randint4x",
"randn",
"randn4x",
"ravel",
"reduce",
"reshape",
"sigmoid",
"sin",
"softmax",
"sqrt",
"static_range",
"static_assert",
"static_print",
"store",
"sum",
"swizzle2d",
"tensor",
"trans",
"triton",
"uint16",
"uint32",
"uint32_to_uniform_float",
"uint64",
"uint8",
"umulhi",
"view",
"void",
"where",
"xor_sum",
"zeros",
"zeros_like",
]

Binary file not shown.

Binary file not shown.

1729
pkgs/triton/language/core.py Normal file

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,3 @@
from . import cuda
__all__ = ['cuda']

Binary file not shown.

View File

@@ -0,0 +1,19 @@
import os
from .. import core
__path__ = os.path.dirname(os.path.abspath(__file__))
@core.extern
def globaltimer(_builder=None):
return core.extern_elementwise("cuda", os.path.join(__path__, "cuda.bc"), [],
{tuple(): ("globaltimer", core.dtype("int64")),
}, is_pure=False, _builder=_builder)
@core.extern
def smid(_builder=None):
return core.extern_elementwise("cuda", os.path.join(__path__, "cuda.bc"), [],
{tuple(): ("smid", core.dtype("int32")),
}, is_pure=True, _builder=_builder)

1537
pkgs/triton/language/math.py Normal file

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,178 @@
import triton
from . import core as tl
PHILOX_KEY_A: tl.constexpr = 0x9E3779B9
PHILOX_KEY_B: tl.constexpr = 0xBB67AE85
PHILOX_ROUND_A: tl.constexpr = 0xD2511F53
PHILOX_ROUND_B: tl.constexpr = 0xCD9E8D57
N_ROUNDS_DEFAULT = 10 # Default number of rounds for philox
# -------------------
# randint
# -------------------
@triton.jit
def philox_impl(c0, c1, c2, c3, k0, k1, n_rounds: tl.constexpr = N_ROUNDS_DEFAULT):
"""
Run `n_rounds` rounds of Philox for state (c0, c1, c2, c3) and key (k0, k1).
"""
for _ in tl.static_range(n_rounds):
# for _ in range(n_rounds):
# update random state
A = PHILOX_ROUND_A
B = PHILOX_ROUND_B
_c0, _c2 = c0, c2
c0 = tl.umulhi(B, _c2) ^ c1 ^ k0
c2 = tl.umulhi(A, _c0) ^ c3 ^ k1
c1 = B * _c2
c3 = A * _c0
# raise key
k0 = k0 + PHILOX_KEY_A
k1 = k1 + PHILOX_KEY_B
return c0, c1, c2, c3
@triton.jit
def philox(seed, c0, c1, c2, c3, n_rounds: tl.constexpr = N_ROUNDS_DEFAULT):
seed = seed.to(tl.uint64)
seed_hi = ((seed >> 32) & 0xffffffff).to(tl.uint32)
seed_lo = (seed & 0xffffffff).to(tl.uint32)
c0 = c0.to(tl.uint32, bitcast=True)
c1 = c1.to(tl.uint32, bitcast=True)
c2 = c2.to(tl.uint32, bitcast=True)
c3 = c3.to(tl.uint32, bitcast=True)
return philox_impl(c0, c1, c2, c3, seed_lo, seed_hi, n_rounds)
@triton.jit
def randint(seed, offset, n_rounds: tl.constexpr = N_ROUNDS_DEFAULT):
"""
Given a :code:`seed` scalar and an :code:`offset` block, returns a single
block of random :code:`int32`.
If you need multiple streams of random numbers,
using `randint4x` is likely to be faster than calling `randint` 4 times.
:param seed: The seed for generating random numbers.
:param offsets: The offsets to generate random numbers for.
"""
ret, _, _, _ = randint4x(seed, offset, n_rounds)
return ret
@triton.jit
def randint4x(seed, offset, n_rounds: tl.constexpr = N_ROUNDS_DEFAULT):
"""
Given a :code:`seed` scalar and an :code:`offset` block, returns four
blocks of random :code:`int32`.
This is the maximally efficient entry point
to Triton's Philox pseudo-random number generator.
:param seed: The seed for generating random numbers.
:param offsets: The offsets to generate random numbers for.
"""
# _0 = tl.zeros(offset.shape, offset.dtype)
_0 = offset * 0
return philox(seed, offset, _0, _0, _0, n_rounds)
# -------------------
# rand
# -------------------
# @triton.jit
# def uint32_to_uniform_float(x):
# """
# Numerically stable function to convert a random uint32 into a random float uniformly sampled in [0, 1).
# """
# two_to_the_minus_32: tl.constexpr = 2.328306e-10
# return x * two_to_the_minus_32
@triton.jit
def uint32_to_uniform_float(x):
"""
Numerically stable function to convert a random uint32 into a random float uniformly sampled in [0, 1).
"""
x = x.to(tl.int32, bitcast=True)
# maximum value such that `MAX_INT * scale < 1.0` (with float rounding)
scale = 4.6566127342e-10
x = tl.where(x < 0, -x - 1, x)
return x * scale
@triton.jit
def rand(seed, offset, n_rounds: tl.constexpr = N_ROUNDS_DEFAULT):
"""
Given a :code:`seed` scalar and an :code:`offset` block,
returns a block of random :code:`float32` in :math:`U(0, 1)`.
:param seed: The seed for generating random numbers.
:param offsets: The offsets to generate random numbers for.
"""
offset = offset.to(tl.uint32, bitcast=True)
source = randint(seed, offset, n_rounds)
return uint32_to_uniform_float(source)
@triton.jit
def rand4x(seed, offsets, n_rounds: tl.constexpr = N_ROUNDS_DEFAULT):
"""
Given a :code:`seed` scalar and an :code:`offsets` block,
returns a 4 blocks of random :code:`float32` in :math:`U(0, 1)`.
:param seed: The seed for generating random numbers.
:param offsets: The offsets to generate random numbers for.
"""
offsets = offsets.to(tl.uint32, bitcast=True)
i1, i2, i3, i4 = randint4x(seed, offsets, n_rounds)
u1 = uint32_to_uniform_float(i1)
u2 = uint32_to_uniform_float(i2)
u3 = uint32_to_uniform_float(i3)
u4 = uint32_to_uniform_float(i4)
return u1, u2, u3, u4
# -------------------
# randn
# -------------------
@triton.jit
def pair_uniform_to_normal(u1, u2):
"""Box-Muller transform"""
u1 = tl.maximum(1.0e-7, u1)
th = 6.283185307179586 * u2
r = tl.sqrt(-2.0 * tl.log(u1))
return r * tl.cos(th), r * tl.sin(th)
@triton.jit
def randn(seed, offset, n_rounds: tl.constexpr = N_ROUNDS_DEFAULT):
"""
Given a :code:`seed` scalar and an :code:`offset` block,
returns a block of random :code:`float32` in :math:`\\mathcal{N}(0, 1)`.
:param seed: The seed for generating random numbers.
:param offsets: The offsets to generate random numbers for.
"""
i1, i2, _, _ = randint4x(seed, offset, n_rounds)
u1 = uint32_to_uniform_float(i1)
u2 = uint32_to_uniform_float(i2)
n1, _ = pair_uniform_to_normal(u1, u2)
return n1
@triton.jit
def randn4x(seed, offset, n_rounds: tl.constexpr = N_ROUNDS_DEFAULT):
"""
Given a :code:`seed` scalar and an :code:`offset` block,
returns a 4 blocks of random :code:`float32` in :math:`\\mathcal{N}(0, 1)`.
:param seed: The seed for generating random numbers.
:param offsets: The offsets to generate random numbers for.
"""
u1, u2, u3, u4 = rand4x(seed, offset, n_rounds)
n1, n2 = pair_uniform_to_normal(u1, u2)
n3, n4 = pair_uniform_to_normal(u3, u4)
return n1, n2, n3, n4

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,98 @@
from __future__ import annotations
from ..runtime.jit import jit
from . import core
# -----------------------
# Standard library
# -----------------------
@jit
def cdiv(x, div):
"""
Computes the ceiling division of :code:`x` by :code:`div`
:param x: the input number
:type input: Block
:param div: the divisor
:param div: Block
"""
return (x + div - 1) // div
@jit
@core._add_math_1arg_docstr("sigmoid")
def sigmoid(x):
return 1 / (1 + core.exp(-x))
@jit
@core._add_math_1arg_docstr("softmax")
def softmax(x, ieee_rounding=False):
z = x - core.max(x, 0)
num = core.exp(z)
den = core.sum(num, 0)
return core.fdiv(num, den, ieee_rounding)
@jit
def ravel(x):
"""
Returns a contiguous flattened view of :code:`x`.
:param x: the input tensor
:type x: Block
"""
return core.view(x, [x.numel])
@jit
def swizzle2d(i, j, size_i, size_j, size_g):
"""
Transforms indices of a row-major size_i*size_j matrix into those
of one where indices are row major for each group of size_j rows.
For example, for size_i = size_j = 4 and size_g = 2, it will transform
[[0 , 1 , 2 , 3 ],
[4 , 5 , 6 , 7 ],
[8 , 9 , 10, 11],
[12, 13, 14, 15]]
into
[[0, 2, 4 , 6 ],
[1, 3, 5 , 7 ],
[8, 10, 12, 14],
[9, 11, 13, 15]]
"""
# "unrolled index in array"
ij = i * size_j + j
# number of elements in `size_g` groups
# of `size_j` columns
size_gj = size_g * size_j
# index of the group in which (i,j) is
group_id = ij // size_gj
# row-index of the first element of this group
off_i = group_id * size_g
# last group may have fewer rows
size_g = core.minimum(size_i - off_i, size_g)
# new row and column indices
new_i = off_i + (ij % size_g)
new_j = (ij % size_gj) // size_g
return new_i, new_j
@jit
def zeros(shape, dtype):
"""
Returns a tensor filled with the scalar value 0 for the given :code:`shape` and :code:`dtype`.
:param shape: Shape of the new array, e.g., (8, 16) or (8, )
:type shape: tuple of ints
:param dtype: Data-type of the new array, e.g., :code:`tl.float16`
:type dtype: DType
"""
return core.full(shape, 0, dtype)
@jit
def zeros_like(input):
return zeros(input.shape, input.dtype)

View File

@@ -0,0 +1,17 @@
# from .conv import _conv, conv
from . import blocksparse
from .cross_entropy import _cross_entropy, cross_entropy
from .flash_attention import attention
from .matmul import _matmul, matmul
from .bmm_matmul import _bmm, bmm
__all__ = [
"blocksparse",
"_cross_entropy",
"cross_entropy",
"_matmul",
"matmul",
"_bmm",
"bmm",
"attention",
]

Binary file not shown.

Binary file not shown.

Binary file not shown.

View File

@@ -0,0 +1,7 @@
from .matmul import matmul
from .softmax import softmax
__all__ = [
"matmul",
"softmax",
]

View File

@@ -0,0 +1,437 @@
import torch
import triton
import triton.language as tl
# ********************************************************
# --------------------------------------------------------
# Sparse = Dense x Dense (SDD)
# This operation uses super-blocking to make sure that
# it's done efficiently when small blocks can be grouped
# together
# --------------------------------------------------------
# ********************************************************
@triton.heuristics({
'EVEN_K': lambda nargs: nargs['K'] % nargs['TILE_K'] == 0,
})
@triton.jit
def _sdd_kernel(
A, B, C,
stride_za, stride_ha, stride_ma, stride_ak,
stride_zb, stride_hb, stride_bk, stride_nb,
stride_zc, stride_hc, stride_mc, stride_nc,
K, grid_offset, lut,
TILE_M: tl.constexpr, TILE_N: tl.constexpr, TILE_K: tl.constexpr,
BLOCK: tl.constexpr, EVEN_K: tl.constexpr
):
# ------------ #
# - Prologue - #
# ------------ #
block_id = tl.program_id(0) + grid_offset
lut += block_id * 3
# offsets
off_z = tl.program_id(2) # batch
off_h = tl.load(lut + 0) # head
# initialize pointers to A
start_am = tl.load(lut + 1)
offs_am = start_am * BLOCK + (tl.arange(0, TILE_M) % BLOCK)
offs_ak = tl.arange(0, TILE_K)
a_ptrs = A \
+ off_z * stride_za \
+ off_h * stride_ha \
+ offs_am[:, None] * stride_ma \
+ offs_ak[None, :] * stride_ak
# initialize pointers to B
start_bn = tl.load(lut + 2)
offs_bn = start_bn * BLOCK + (tl.arange(0, TILE_N) % BLOCK)
offs_bk = tl.arange(0, TILE_K)
b_ptrs = B \
+ off_z * stride_zb \
+ off_h * stride_hb \
+ offs_bn[None, :] * stride_nb \
+ offs_bk[:, None] * stride_bk
# ---------------- #
# Inner Loop #
# ---------------- #
acc = tl.zeros((TILE_M, TILE_N), dtype=tl.float32)
for k in range(K, 0, -TILE_K):
if EVEN_K:
a = tl.load(a_ptrs)
b = tl.load(b_ptrs)
else:
a = tl.load(a_ptrs, mask=offs_ak[None, :] < k, other=0.)
b = tl.load(b_ptrs, mask=offs_bk[:, None] < k, other=0.)
acc += tl.dot(a, b, out_dtype=tl.float32)
a_ptrs += TILE_K * stride_ak
b_ptrs += TILE_K * stride_bk
c = acc.to(C.dtype.element_ty)
# ---------------- #
# Epilogue #
# ---------------- #
offs_cm = tl.arange(0, TILE_M) % BLOCK
offs_cn = tl.arange(0, TILE_N) % BLOCK
pc = C \
+ off_z * stride_zc \
+ block_id * stride_hc \
+ offs_cm[:, None] * stride_mc \
+ offs_cn[None, :] * stride_nc
tl.store(pc, c, mask=True)
def sdd_matmul(a, b, trans_a, trans_b, trans_c, spdims, block, lut, widths, out=None):
if a.stride(2) != 1 and a.stride(3) != 1:
a = a.contiguous()
if b.stride(2) != 1 and b.stride(3) != 1:
b = b.contiguous()
# (A * B)^T = B^T * A^T
if trans_c:
a, b = b, a
trans_a, trans_b = not trans_b, not trans_a
# shape constraints
a_dim = -2 if trans_a else -1
b_dim = -1 if trans_b else -2
Ka, Kb = a.shape[a_dim], b.shape[b_dim]
if Ka != Kb:
raise ValueError(f"Inner dimension mismatch (A: {Ka} vs B: {Kb})")
# allocate output
if out is None:
c = torch.empty((a.shape[0], lut.shape[0], block, block), dtype=a.dtype, device=a.device)
else:
assert out.shape == (a.shape[0], lut.shape[0], block, block)
c = out
grid = [c.shape[1], 1, c.shape[0]]
_sdd_kernel[grid](
a, b, c,
a.stride(0), a.stride(1), a.stride(3 if trans_a else 2), a.stride(2 if trans_a else 3),
b.stride(0), b.stride(1), b.stride(3 if trans_b else 2), b.stride(2 if trans_b else 3),
c.stride(0), c.stride(1), c.stride(2), c.stride(3),
Ka, 0, lut,
TILE_M=block, TILE_N=block, TILE_K=32, BLOCK=block, num_stages=4,
num_warps=4,
)
return c
def sdd_lut(layout, block, device):
lut = layout.nonzero(as_tuple=False).to(device).int()
lut = lut.contiguous()
return lut, None
# -----------------------------
# Dense = Sparse x Dense (DSD)
# This operation uses a look-up table that contains pre-computed pointer increments
# in order to minimize computations in the inner loop of the matmul kernel.
# -----------------------------
@triton.jit
def _dsd_kernel(
A, B, C,
stride_az, stride_ha, stride_am, stride_ak,
stride_zb, stride_hb, stride_bk, stride_bn,
stride_zc, stride_hc, stride_cm, stride_cn,
DS0, DS1, lut,
TILE_M: tl.constexpr, TILE_N: tl.constexpr, TILE_K: tl.constexpr,
GROUP_SIZE_M: tl.constexpr, BLOCK: tl.constexpr
):
# ------------ #
# - Prologue - #
# ------------ #
pid_m = tl.program_id(0)
pid_n = tl.program_id(1)
num_pid_m = tl.num_programs(0)
num_pid_n = tl.num_programs(1)
pid_n, pid_m = tl.swizzle2d(pid_n, pid_m, num_pid_n, num_pid_m, GROUP_SIZE_M)
pidz = tl.program_id(2)
header = lut + pid_n * 4
offset = tl.load(header + 0)
K = tl.load(header + 1)
column = tl.load(header + 2)
off_h = tl.load(header + 3)
pinc = lut + offset
# initialize pointers to A (sparse)
block_id = tl.load(pinc + 1)
block_id = tl.multiple_of(block_id, 8) # compiler hint
offs_am = tl.arange(0, TILE_M)
offs_ak = tl.arange(0, TILE_K)
pa = A + pidz * stride_az \
+ block_id * stride_ha \
+ offs_am[:, None] * stride_am \
+ offs_ak[None, :] * stride_ak
# initialize pointers to B (dense)
offs_bn = pid_m * TILE_N + tl.arange(0, TILE_N)
offs_bn = tl.max_contiguous(tl.multiple_of(offs_bn % DS0, TILE_N), TILE_N)
start_bk = tl.load(pinc)
start_bk = tl.multiple_of(start_bk, 8) # compiler hint
offs_bk = start_bk + tl.arange(0, TILE_K)
pb = B + pidz * stride_zb \
+ off_h * stride_hb \
+ offs_bn[None, :] * stride_bn \
+ offs_bk[:, None] * stride_bk
# ---------------- #
# Inner Loop #
# ---------------- #
acc = tl.zeros((TILE_M, TILE_N), dtype=tl.float32)
pinc += 2
inc_a = tl.load(pinc + 1)
inc_a = tl.multiple_of(inc_a, 8)
inc_b = tl.load(pinc)
inc_b = tl.multiple_of(inc_b, 8)
for k in range(K, 0, -TILE_K):
a = tl.load(pa)
b = tl.load(pb)
acc += tl.dot(a, b, out_dtype=tl.float32)
pa += inc_a
pb += inc_b * stride_bk
pinc += 2
inc_a = tl.load(pinc + 1)
inc_a = tl.multiple_of(inc_a, 8)
inc_b = tl.load(pinc)
inc_b = tl.multiple_of(inc_b, 8)
c = acc.to(C.dtype.element_ty)
# initialize pointers to C
offs_cm = column * TILE_M + tl.arange(0, TILE_M)
offs_cn = pid_m * TILE_N + tl.arange(0, TILE_N)
pc = C \
+ off_h * stride_hc \
+ pidz * stride_zc \
+ offs_cm[:, None] * stride_cm \
+ offs_cn[None, :] * stride_cn
tl.store(pc, c, mask=offs_cn[None, :] < DS0)
def dsd_matmul(a, b, trans_a, trans_b, trans_c, spdims, block, lut, width, out=None):
if a.stride(2) != 1 and a.stride(3) != 1:
a = a.contiguous()
if b.stride(2) != 1 and b.stride(3) != 1:
b = b.contiguous()
# shapes / dtypes
AS1 = block * spdims[2 if trans_a else 1]
BS0 = b.size(0)
BS1 = b.size(1)
BS3 = b.size(2 if trans_b else 3)
dtype = a.dtype
# allocate output
CS0 = BS0
CS1 = BS1
CS2 = BS3 if trans_c else AS1
CS3 = AS1 if trans_c else BS3
if out is None:
c = torch.empty((CS0, CS1, CS2, CS3), dtype=dtype, device=a.device)
else:
assert out.shape == (CS0, CS1, CS2, CS3)
c = out
# meta-parameter heuristics
TILE_N = 128
# compute output
grid = lambda meta: [triton.cdiv(BS3, meta['TILE_N']), width, BS0]
_dsd_kernel[grid](
a, b, c,
a.stride(0), a.stride(1), a.stride(3 if trans_a else 2), a.stride(2 if trans_a else 3),
b.stride(0), b.stride(1), b.stride(3 if trans_b else 2), b.stride(2 if trans_b else 3),
c.stride(0), c.stride(1), c.stride(3 if trans_c else 2), c.stride(2 if trans_c else 3),
BS3, AS1, lut,
TILE_M=block, TILE_N=TILE_N, TILE_K=min(block, 32), BLOCK=block, num_stages=4,
num_warps=4, GROUP_SIZE_M=4,
)
# exit()
return c
def dsd_lut(layout, block, step, trans, device):
"""
Generates the look-up table for incrementing pointers in the DSD/DDS matmul.
Example (BLOCK=32, STEP=16)
[[1, 0, 0, 1, 0],
[0, 1, 1, 0, 1],
[1, 0, 1, 0, 0]]
Then the offsets for A are
[0 , 16, 32, 48] <- row 0
\\----/ \\----/
col=0 col=3
[64, 80, 96, 112, 128, 144] <- row 1
\\----/ \\----/ \\------/
col=1 col=2 col=3
[160, 176, 192, 208]
which leads to increments table
[0, 16, 16, 16, || 64, 16, 16, 16, 16, 16, || 160, 16, 16, 16]
Because B is dense, the offsets are
[0, 16, 96, 112] <- row 0
[32, 48, 64, 80] <- row 1
[0, 16, 64, 80] <- row 2
"""
sizes = torch.sum(layout, 2 if trans else 1)
head_id, col_id = torch.ones_like(sizes).nonzero(as_tuple=True)
sizes = sizes.flatten()
segments = sizes * step
# pointer increments
if trans:
nnz = layout.nonzero(as_tuple=False)
else:
nnz = layout.transpose(1, 2).nonzero(as_tuple=False)
num_blocks = nnz.size(0)
offsets = torch.zeros_like(sizes)
offsets[1:] = torch.cumsum(sizes[:-1], dim=0)
offsets = torch.min(offsets, (num_blocks - 1) * torch.ones_like(offsets))
# -------------------------------
# dense input pointer increments
# -------------------------------
# Note that the inner loop matmul kernel may have a fixed step size (e.g., TILE_K)
# that is smaller than the block size, so we need to do a bit of extra work
# to handle this case
B_idx = nnz[:, 2] * block
B_incs = B_idx.clone()
B_incs[1:] -= B_idx[:-1]
div = block // step
B_incs = B_incs.view(-1, 1).repeat(1, div)
B_incs[:, 1:] = step
B_incs[:, 0] -= (div - 1) * step
# first increment for each reduction is actually the offset
B_incs[offsets[segments > 0], 0] = B_idx[offsets[segments > 0]]
B_incs = B_incs.view(-1)
# -------------------------------
# sparse input pointer increments
# -------------------------------
# same as above, except that the increments are in the sparse memory layout
if trans:
A_idx = torch.arange(num_blocks, device=layout.device)
else:
A_idx = torch.tensor([], dtype=torch.int64, device=layout.device)
current_offset = 0
for z in range(layout.size(0)):
layoutw = layout[z, :, :].clone().long()
msum = layoutw.sum()
layoutw[layoutw > 0] = 1 + torch.arange(msum, device=layout.device)
A_idx = torch.cat((A_idx, current_offset + layoutw.T[layoutw.T > 0] - 1))
current_offset += msum
A_incs = A_idx * block * block
A_incs[1:] -= A_idx[:-1] * block * block
A_incs = A_incs.view(-1, 1).repeat(1, div)
if trans:
A_incs[:, 1:] = step
A_incs[:, 0] -= (div - 1) * step
else:
A_incs[:, 1:] = step * block
A_incs[:, 0] -= (div - 1) * step * block
A_incs[offsets[segments > 0], 0] = A_idx[offsets[segments > 0]]
A_incs = A_incs.view(-1)
# create header
width = col_id.size(0)
offsets = offsets * 2 * div + 4 * width
segments = segments * div
header = torch.stack((offsets, segments, col_id, head_id), dim=1).view(-1).contiguous()
# create increments
incs = torch.stack((B_incs, A_incs), dim=1).view(-1).contiguous()
# pad by a factor 2*MAX_NUM_STAGES
# to accommodate pre-fetching inside the kernel
pad = torch.zeros(20, device=incs.device, dtype=incs.dtype)
incs = torch.cat((incs, pad))
# create lut
lut = torch.cat((header, incs))
lut = lut.type(torch.int32).to(device)
# create locks
return lut, width
# -----------------------------
# Dense = Dense x Sparse (DDS)
# -----------------------------
# AB = (B^T A^T)^T
def dds_matmul(a, b, trans_a, trans_b, trans_c, spdims, block, lut, width, out=None):
return dsd_matmul(b, a, not trans_b, not trans_a, not trans_c, spdims, block, lut, width, out=out)
##############
# MAIN API #
##############
class _matmul(torch.autograd.Function):
fn = {'sdd': sdd_matmul, 'dsd': dsd_matmul, 'dds': dds_matmul}
@staticmethod
def forward(
ctx, a, b, trans_a, trans_b, trans_c, mode, spdims, block,
c_lut, c_width, da_lut, da_width, db_lut, db_width, out
):
c = _matmul.fn[mode](a, b, trans_a, trans_b, trans_c, spdims, block, c_lut, c_width, out=out)
# save for backward
ctx.save_for_backward(a, b)
ctx.da_lut = da_lut
ctx.da_width = da_width
ctx.db_lut = db_lut
ctx.db_width = db_width
ctx.mode = mode
ctx.spdims = spdims
ctx.block = block
ctx.trans_a = trans_a
ctx.trans_b = trans_b
ctx.trans_c = trans_c
ctx.has_out = out is not None
return c
@staticmethod
def backward(ctx, dc):
# saved for backward
a, b = ctx.saved_tensors
da, db = None, None
mode = ctx.mode
# gradients w.r.t. a
if ctx.needs_input_grad[0]:
mode_da = mode[1] + mode[0] + mode[2]
da = _matmul.fn[mode_da](
dc, b, ctx.trans_c, not ctx.trans_b, ctx.trans_a, ctx.spdims, ctx.block, ctx.da_lut, ctx.da_width,
)
# gradients w.r.t. b
if ctx.needs_input_grad[1]:
mode_db = mode[2] + mode[1] + mode[0]
db = _matmul.fn[mode_db](
a, dc, not ctx.trans_a, ctx.trans_c, ctx.trans_b, ctx.spdims, ctx.block, ctx.db_lut, ctx.db_width,
)
dout = dc if ctx.has_out else None
return da, db, None, None, None,\
None, None, None, None,\
None, None, None, None, None, dout
class matmul:
def __init__(self, layout, block, mode, device, trans_a=False, trans_b=False, trans_c=False):
if mode not in ['sdd', 'dsd', 'dds']:
raise NotImplementedError('Supported modes are: sdd, dsd, dds')
self.block = block
self.mode = mode
self.trans_a = trans_a
self.trans_b = trans_b
self.trans_c = trans_c
self.layout = layout
self.spdims = layout.shape
step = min(block, 32)
if self.mode == 'sdd':
self.c_lut, self.c_width = sdd_lut(layout, block, device)
self.da_lut, self.da_width = dsd_lut(layout, block, step, True, device)
self.db_lut, self.db_width = dsd_lut(layout, block, step, False, device)
if self.mode == 'dsd':
self.c_lut, self.c_width = dsd_lut(layout, block, step, not self.trans_a, device)
self.da_lut, self.da_width = sdd_lut(layout, block, device)
self.db_lut, self.db_width = dsd_lut(layout, block, step, self.trans_a, device)
if self.mode == 'dds':
self.c_lut, self.c_width = dsd_lut(layout, block, step, self.trans_b, device)
self.da_lut, self.da_width = dsd_lut(layout, block, step, not self.trans_b, device)
self.db_lut, self.db_width = sdd_lut(layout, block, device)
def __call__(self, a, b, out=None):
c = _matmul.apply(
a, b, self.trans_a, self.trans_b, self.trans_c, self.mode, self.spdims, self.block,
self.c_lut, self.c_width,
self.da_lut, self.da_width,
self.db_lut, self.db_width,
out
)
return c

View File

@@ -0,0 +1,239 @@
import torch
import triton
import triton.language as tl
def num_warps(n):
if n <= 128:
return 1
if n <= 256:
return 2
if n <= 512:
return 4
if n <= 4096:
return 8
return 16
@triton.jit
def _blocksparse_softmax_fwd(
Out, A, stride_xz, LUT,
R, extent, stride_zr, stride_hr, # relative attention
scale, is_causal,
ROW_SIZE: tl.constexpr,
BLOCK_SIZE: tl.constexpr,
IS_DENSE: tl.constexpr,
):
h = tl.program_id(0)
m = tl.program_id(1)
z = tl.program_id(2)
# create index ranges
hm = h * tl.num_programs(1) + m
lane_n = tl.arange(0, ROW_SIZE) % BLOCK_SIZE
block_n = tl.arange(0, ROW_SIZE) // BLOCK_SIZE
# extract information from LUT
header = LUT + (hm // BLOCK_SIZE) * 2
size = tl.load(header + 0)
offset = tl.load(header + 1)
# pointer offset
off_a = z * stride_xz
off_a += (offset + block_n) * BLOCK_SIZE * BLOCK_SIZE # block indx
off_a += (m % BLOCK_SIZE) * BLOCK_SIZE # row indx
# do not need to read column indices in the dense case
if IS_DENSE:
ns = tl.arange(0, ROW_SIZE)
else:
off_lut = offset + 2 * tl.num_programs(0) * tl.num_programs(1) // BLOCK_SIZE
start_n = tl.load(LUT + off_lut + block_n, mask=block_n < size, other=0)
ns = start_n * BLOCK_SIZE + lane_n
# load X
mask = block_n < size
a = tl.load(A + off_a + lane_n, mask=mask, other=-float("inf"))
a = a.to(tl.float32)
# compute
out = a
out *= scale
# apply relative attention
if R is not None:
R += z * stride_zr
R += h * stride_hr
off_lo = (extent - m - 1) + ns
mask_lo = (off_lo >= 0) & (off_lo < extent)
rel_logits = tl.load(R + m * extent + off_lo, mask=mask_lo, other=0.0)
out += rel_logits
out = out.to(tl.float32)
# apply causal mask
out = tl.where((ns > m) & is_causal, -float("inf"), out)
# computation
out = tl.softmax(out)
# write-back
tl.store(Out + off_a + lane_n, out, mask=mask)
@triton.jit
def _blocksparse_softmax_bwd(
DA, stride_zdx,
DOut, stride_zdout,
Out, stride_zout,
scale,
LUT,
DR, extent, stride_zr, stride_hr, stride_er,
is_causal,
ROW_SIZE: tl.constexpr,
BLOCK_SIZE: tl.constexpr,
IS_DENSE: tl.constexpr,
):
h = tl.program_id(0)
m = tl.program_id(1)
z = tl.program_id(2)
# create index ranges
hm = h * tl.num_programs(1) + m
lane_n = tl.arange(0, ROW_SIZE) % BLOCK_SIZE
block_n = tl.arange(0, ROW_SIZE) // BLOCK_SIZE
# extract information from LUT
header = LUT + (hm // BLOCK_SIZE) * 2
size = tl.load(header + 0)
offset = tl.load(header + 1)
# row-col offset
off_mn = (offset + block_n) * BLOCK_SIZE * BLOCK_SIZE
off_mn += (m % BLOCK_SIZE) * BLOCK_SIZE
mask = block_n < size
# pointers
As = Out + z * stride_zout + off_mn
DOuts = DOut + z * stride_zdout + off_mn
# do not need to read column indices in the dense case
if IS_DENSE:
ns = tl.arange(0, ROW_SIZE)
else:
off_lut = offset + 2 * tl.num_programs(0) * tl.num_programs(1) // BLOCK_SIZE
start_n = tl.load(LUT + off_lut + block_n, mask=mask, other=0)
ns = start_n * BLOCK_SIZE + lane_n
# load data
a = tl.load(As + lane_n, mask=mask, other=0.0)
a = a.to(tl.float32)
dout = tl.load(DOuts + lane_n, mask=mask, other=0.0)
dout = dout.to(tl.float32)
# compute
a = tl.where((ns > m) & is_causal & (a == a), 0., a)
da = a * (dout - tl.sum(a * dout, 0))
# apply relative attention
if DR is not None:
DR += z * stride_zr
DR += h * stride_hr
off_lo = (extent - m - 1) + ns
mask_lo = (off_lo >= 0) & (off_lo < extent) & mask
tl.store(DR + m * extent + off_lo, da, mask=mask_lo)
da = da * scale
# convert da
# write-back
DAs = DA + z * stride_zdx + off_mn
tl.store(DAs + lane_n, da, mask=mask)
class _softmax(torch.autograd.Function):
@staticmethod
def make_lut(layout, block, device):
_empty = torch.tensor([], dtype=torch.int64, device=layout.device)
sizes = _empty.clone()
# sizes along rows
for h in range(layout.shape[0]):
sizes = torch.cat((sizes, layout[h, :, :].sum(-1)))
total_sizes = sizes * block
# offsets in block format
offsets = torch.zeros_like(sizes)
offsets[1:] = torch.cumsum(sizes[:-1], dim=0)
# block indices
columns = layout.nonzero(as_tuple=False)[:, 2]
header = torch.stack((sizes, offsets), dim=1).view(-1)
lut = torch.cat((header, columns)).type(torch.int32).to(device)
return lut, int(total_sizes.max())
@staticmethod
def forward(
ctx, a, scale, rel_logits, is_causal,
spdims, block, lut, maxlut, is_dense
):
if scale is not None and isinstance(scale, torch.Tensor):
assert scale.device.type == "cpu"
scale = scale.item()
M = a.shape[0]
grid = [spdims[0], spdims[1] * block, M]
rel_shape = (1, 1, 1, 1) if rel_logits is None else rel_logits.shape
rel_strides = (1, 1, 1, 1) if rel_logits is None else rel_logits.stride()
# enqueue kernel
out = torch.empty_like(a)
_blocksparse_softmax_fwd[grid](
out, a, a.stride(0), lut,
rel_logits, rel_shape[-1], rel_strides[0], rel_strides[1], # relative attn
scale,
is_causal,
BLOCK_SIZE=block,
ROW_SIZE=triton.next_power_of_2(maxlut),
IS_DENSE=is_dense,
num_warps=num_warps(maxlut)
)
# save to context
# ctx.mark_dirty(x)
ctx.save_for_backward(out, lut)
ctx.spdims = spdims
ctx.block = block
ctx.maxlut = maxlut
ctx.scale = scale
ctx.rel_shape = rel_shape
ctx.rel_strides = rel_strides
ctx.rel_dtype = a.dtype
ctx.is_dense = is_dense
ctx.is_causal = is_causal
return out
@staticmethod
def backward(ctx, dout):
# retrieve from context
out, lut = ctx.saved_tensors
# relative logits gradients
dr = None
if ctx.needs_input_grad[3]:
dr = torch.zeros(ctx.rel_shape, dtype=ctx.rel_dtype, device=out.device)
# run kernel
M = out.shape[0]
grid = (ctx.spdims[0], ctx.spdims[1] * ctx.block, M)
da = torch.empty_like(dout)
_blocksparse_softmax_bwd[grid](
da, da.stride(0),
dout, dout.stride(0),
out, out.stride(0),
ctx.scale,
lut,
dr, ctx.rel_shape[-1], ctx.rel_strides[0], ctx.rel_strides[1], ctx.rel_strides[2],
ctx.is_causal,
BLOCK_SIZE=ctx.block,
ROW_SIZE=triton.next_power_of_2(ctx.maxlut),
IS_DENSE=ctx.is_dense,
num_warps=num_warps(ctx.maxlut)
)
return (da, None, None, dr, None,
None, None, None, None, None,
None,
None, None, None,
None,
None, None, None
)
class softmax:
def __init__(self, layout, block, device, is_dense=False):
self.spdims = layout.shape
self.layout = layout
self.block = block
self.lut, self.maxlut = _softmax.make_lut(self.layout, self.block, device)
self.is_dense = is_dense
def __call__(self, a, *, scale=1.0, rel_logits=None, is_causal=False):
if rel_logits is not None and rel_logits.dtype != a.dtype:
raise ValueError(f"relative position embedding must be {a.dtype}")
a = _softmax.apply(
a, scale, rel_logits, is_causal,
self.spdims, self.block, self.lut, self.maxlut, self.is_dense,
)
return a

View File

@@ -0,0 +1,163 @@
import torch
import triton
import triton.language as tl
from .matmul_perf_model import early_config_prune, estimate_matmul_time
def init_to_zero(name):
return lambda nargs: nargs[name].zero_()
def get_configs_io_bound():
configs = []
for num_stages in [1]:
# TODO support block size 16 for MFMA dot op
for block_m in [16, 32] if torch.version.hip is None and not hasattr(torch, "corex") else [32, 64]:
for block_k in [32, 64]:
for block_n in [32, 64, 128, 256]:
num_warps = 4 if block_n <= 64 else 8
configs.append(
triton.Config({'BLOCK_M': block_m, 'BLOCK_N': block_n, 'BLOCK_K': block_k, 'SPLIT_K': 1},
num_stages=num_stages, num_warps=num_warps))
# split_k
#for split_k in [2, 4, 8, 16]:
# configs.append(triton.Config({'BLOCK_M': block_m, 'BLOCK_N': block_n, 'BLOCK_K': block_k, 'SPLIT_K': split_k},
# num_stages=num_stages, num_warps=num_warps, pre_hook=init_to_zero('C')))
return configs
def get_configs_compute_bound():
configs = []
for block_m in [64, 128, 256]:
for block_n in [64, 128, 256]:
for block_k in [32, 64, 128]:
num_warps = 8 if block_n <= 64 else 16
configs.append(
triton.Config({'BLOCK_M': block_m, 'BLOCK_N': block_n, 'BLOCK_K': block_k, 'SPLIT_K': 1},
num_stages=1, num_warps=num_warps))
return configs
@triton.autotune(
configs=[
] + get_configs_compute_bound() + get_configs_io_bound(),
key=['M', 'N', 'K'],
prune_configs_by={
'early_config_prune': early_config_prune,
'perf_model': estimate_matmul_time,
'top_k': 10
},
)
@triton.heuristics({
'EVEN_K': lambda args: args['K'] % args['BLOCK_K'] == 0,
})
@triton.jit
def _bmm_kernel(A, B, C, M, N, K,
stride_aq, stride_am, stride_ak,
stride_bq, stride_bk, stride_bn,
stride_cq, stride_cm, stride_cn,
dot_out_dtype: tl.constexpr,
BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
GROUP_M: tl.constexpr, SPLIT_K: tl.constexpr, EVEN_K: tl.constexpr,
):
pid = tl.program_id(0)
grid_m = (M + BLOCK_M - 1) // BLOCK_M
grid_n = (N + BLOCK_N - 1) // BLOCK_N
# re-order program ID for better L2 performance
width = GROUP_M * grid_n
group_id = pid // width
group_size = min(grid_m - group_id * GROUP_M, GROUP_M)
pid_m = group_id * GROUP_M + (pid % group_size)
pid_n = (pid % width) // (group_size)
rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
ram = tl.max_contiguous(tl.multiple_of(rm % M, BLOCK_M), BLOCK_M)
rbn = tl.max_contiguous(tl.multiple_of(rn % N, BLOCK_N), BLOCK_N)
rk = tl.arange(0, BLOCK_K)
idx_q = tl.program_id(1) # batch dimension for BMM
A = A + (ram[:, None] * stride_am + rk[None, :] * stride_ak + idx_q*stride_aq)
B = B + (rk[:, None] * stride_bk + rbn[None, :] * stride_bn + idx_q*stride_bq)
acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=dot_out_dtype)
for k in range(K, 0, -BLOCK_K):
if EVEN_K:
a = tl.load(A)
b = tl.load(B)
else:
a = tl.load(A, mask=rk[None, :] < k, other=0.)
b = tl.load(B, mask=rk[:, None] < k, other=0.)
acc += tl.dot(a, b)
A += BLOCK_K * stride_ak
B += BLOCK_K * stride_bk
# rematerialize rm and rn to save registers
rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
idx_q = tl.program_id(1) # batch dimension for BMM
idx_m = rm[:, None]
idx_n = rn[None, :]
C = C + (rm[:, None] * stride_cm + rn[None, :] * stride_cn + idx_q * stride_cq)
mask = (idx_m < M) & (idx_n < N)
# handles write-back with reduction-splitting
tl.store(C, acc, mask=mask)
class _bmm(torch.autograd.Function):
kernel = _bmm_kernel
_locks = {}
@staticmethod
def _call(a, b, dot_out_dtype):
device = a.device
# handle non-contiguous inputs if necessary
if a.stride(0) > 1 and a.stride(1) > 1:
a = a.contiguous()
if b.stride(0) > 1 and b.stride(1) > 1:
b = b.contiguous()
#only MR support Trans layout
if hasattr(torch, "corex"):
capability = torch.cuda.get_device_capability(device)
capability = capability[0] * 10 + capability[1]
if (capability < 71):
if a.stride(0) >= 1 and a.stride(1) > 1:
a = a.contiguous()
if b.stride(0) >= 1 and b.stride(1) > 1:
b = b.contiguous()
# checks constraints
assert a.shape[0] == b.shape[0], "incompatible dimensions"
assert a.shape[2] == b.shape[1], "incompatible dimensions"
B, M, K = a.shape
_, _, N = b.shape
# allocates output
c = torch.empty((B, M, N), device=device, dtype=a.dtype)
if dot_out_dtype is None:
if a.dtype in [torch.float16, torch.float32, torch.bfloat16]:
dot_out_dtype = tl.float32
else:
dot_out_dtype = tl.int32
else:
assert isinstance(dot_out_dtype, torch.dtype), "dot_out_dtype must be a torch.dtype"
if dot_out_dtype == torch.float16:
dot_out_dtype = tl.float16
elif dot_out_dtype in [torch.float32, torch.bfloat16]:
dot_out_dtype = tl.float32
else:
dot_out_dtype = tl.int32
# launch kernel
grid = lambda META: (triton.cdiv(M, META['BLOCK_M']) * triton.cdiv(N, META['BLOCK_N']), B, 1)
_bmm_kernel[grid](a, b, c, M, N, K,
a.stride(0), a.stride(1), a.stride(2),
b.stride(0), b.stride(1), b.stride(2),
c.stride(0), c.stride(1), c.stride(2),
dot_out_dtype=dot_out_dtype,
GROUP_M=8)
return c
@staticmethod
def forward(ctx, a, b, dot_out_dtype=None):
return _bmm._call(a, b, dot_out_dtype=dot_out_dtype)
bmm = _bmm.apply

View File

@@ -0,0 +1,94 @@
import torch
import triton
import triton.language as tl
def num_warps(N):
if N < 2048:
return 4
elif N < 8192:
return 8
return 16
@triton.heuristics({'num_warps': lambda nargs: num_warps(nargs['N'])})
@triton.heuristics({'BLOCK': lambda nargs: triton.next_power_of_2(nargs['N'])})
@triton.jit
def _forward(LOGITS, PROBS, IDX, LOSS, N, BLOCK: tl.constexpr):
row = tl.program_id(0)
cols = tl.arange(0, BLOCK)
idx = tl.load(IDX + row)
# pointers to logit and probs
LOGITS = LOGITS + row * N + cols
WRIT_PROBS = PROBS + row * N + cols
READ_PROBS = PROBS + row * N + idx
# write-back negative log-probs
logits = tl.load(LOGITS, mask=cols < N, other=-float('inf'))
logits = logits.to(tl.float32)
logits = logits - tl.max(logits, 0)
probs = tl.log(tl.sum(tl.exp(logits), 0)) - logits
tl.store(WRIT_PROBS, probs, mask=cols < N)
# There is a bug in the compiler, which fails to insert a barrier here.
# We add it explicitly for now. Will be fixed soon.
tl.debug_barrier()
# write-back loss
probs = tl.load(READ_PROBS)
tl.store(LOSS + row, probs)
@triton.heuristics({'num_warps': lambda nargs: num_warps(nargs['N'])})
@triton.heuristics({'BLOCK': lambda nargs: triton.next_power_of_2(nargs['N'])})
@triton.jit
def _backward(PROBS, IDX, DPROBS, N, BLOCK: tl.constexpr):
row = tl.program_id(0)
cols = tl.arange(0, BLOCK)
idx = tl.load(IDX + row)
# pointers to probs
PROBS = PROBS + row * N + cols
# We know d(-log(p[i])/dlogit[k] = -id_mat[i,k] + p[k]
# and we have -log(p[k]) stored in PROBS, so this is easy
probs = -tl.load(PROBS, mask=cols < N, other=float('inf'))
probs = tl.exp(probs.to(tl.float32))
delta = cols == idx
# write result in-place in PROBS
dout = tl.load(DPROBS + row)
din = (probs - delta) * dout
tl.store(PROBS, din.to(PROBS.dtype.element_ty), mask=cols < N)
class _cross_entropy(torch.autograd.Function):
@classmethod
def forward(cls, ctx, logits, indices):
# make sure we can use triton
assert (indices.dtype == torch.int64), "Indices are expected to be of type long."
# make kernel
device, dtype = logits.device, logits.dtype
n_cols = logits.shape[-1]
# run the kernel
result = torch.empty_like(indices, dtype=dtype, device=device)
neg_logprobs = torch.empty_like(logits, dtype=dtype, device=device)
grid = lambda opt: (logits.numel() // n_cols, )
_forward[grid](logits, neg_logprobs, indices, result, n_cols)
# save for backward
ctx.save_for_backward(neg_logprobs, indices)
return result
@classmethod
def backward(cls, ctx, dneg_logprobs):
"""We know d(-log(p[i])/dlogit[k] = -id_mat[i,k] + p[k]
so we initialize the gradient as neg_logprobs, so we can just exponentiate
to get p[k], which is most of what we need... neg_logprobs will be
modified in place to become the gradient we want
"""
# load saved tensors
neg_logprobs, indices = ctx.saved_tensors
# run the kernel
# neg_logprobs will be modified in place to become our gradient:
n_cols = neg_logprobs.shape[-1]
grid = lambda opt: (neg_logprobs.numel() // n_cols, )
_backward[grid](neg_logprobs, indices, dneg_logprobs, n_cols)
return neg_logprobs, None
cross_entropy = _cross_entropy.apply

View File

@@ -0,0 +1,271 @@
"""
Fused Attention
===============
This is a Triton implementation of the Flash Attention algorithm
(see: Dao et al., https://arxiv.org/pdf/2205.14135v2.pdf; Rabe and Staats https://arxiv.org/pdf/2112.05682v2.pdf)
"""
import torch
import triton
import triton.language as tl
from triton.common.build import is_corex
@triton.jit
def _fwd_kernel(
Q, K, V, sm_scale,
L, M,
Out,
stride_qz, stride_qh, stride_qm, stride_qk,
stride_kz, stride_kh, stride_kn, stride_kk,
stride_vz, stride_vh, stride_vk, stride_vn,
stride_oz, stride_oh, stride_om, stride_on,
Z, H, N_CTX,
BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr,
BLOCK_N: tl.constexpr,
):
start_m = tl.program_id(0)
off_hz = tl.program_id(1)
# initialize offsets
offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
offs_n = tl.arange(0, BLOCK_N)
offs_d = tl.arange(0, BLOCK_DMODEL)
off_q = off_hz * stride_qh + offs_m[:, None] * stride_qm + offs_d[None, :] * stride_qk
off_k = off_hz * stride_qh + offs_n[None, :] * stride_kn + offs_d[:, None] * stride_kk
off_v = off_hz * stride_qh + offs_n[:, None] * stride_qm + offs_d[None, :] * stride_qk
# Initialize pointers to Q, K, V
q_ptrs = Q + off_q
k_ptrs = K + off_k
v_ptrs = V + off_v
# initialize pointer to m and l
m_prev = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
l_prev = tl.zeros([BLOCK_M], dtype=tl.float32)
acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)
# load q: it will stay in SRAM throughout
q = tl.load(q_ptrs)
# loop over k, v and update accumulator
for start_n in range(0, (start_m + 1) * BLOCK_M, BLOCK_N):
# -- compute qk ----
k = tl.load(k_ptrs)
qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
qk += tl.dot(q, k)
qk *= sm_scale
qk = tl.where(offs_m[:, None] >= (start_n + offs_n[None, :]), qk, float("-inf"))
# compute new m
m_curr = tl.maximum(tl.max(qk, 1), m_prev)
# correct old l
l_prev *= tl.exp(m_prev - m_curr)
# attention weights
p = tl.exp(qk - m_curr[:, None])
l_curr = tl.sum(p, 1) + l_prev
# rescale operands of matmuls
l_rcp = 1. / l_curr
p *= l_rcp[:, None]
acc *= (l_prev * l_rcp)[:, None]
# update acc
p = p.to(Q.dtype.element_ty)
v = tl.load(v_ptrs)
acc += tl.dot(p, v)
# update m_i and l_i
l_prev = l_curr
m_prev = m_curr
# update pointers
k_ptrs += BLOCK_N * stride_kn
v_ptrs += BLOCK_N * stride_vk
# rematerialize offsets to save registers
start_m = tl.program_id(0)
offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
# write back l and m
l_ptrs = L + off_hz * N_CTX + offs_m
m_ptrs = M + off_hz * N_CTX + offs_m
tl.store(l_ptrs, l_prev)
tl.store(m_ptrs, m_prev)
# initialize pointers to output
offs_n = tl.arange(0, BLOCK_DMODEL)
off_o = off_hz * stride_oh + offs_m[:, None] * stride_om + offs_n[None, :] * stride_on
out_ptrs = Out + off_o
tl.store(out_ptrs, acc)
@triton.jit
def _bwd_preprocess(
Out, DO, L,
NewDO, Delta,
BLOCK_M: tl.constexpr, D_HEAD: tl.constexpr,
):
off_m = tl.program_id(0) * BLOCK_M + tl.arange(0, BLOCK_M)
off_n = tl.arange(0, D_HEAD)
# load
o = tl.load(Out + off_m[:, None] * D_HEAD + off_n[None, :]).to(tl.float32)
do = tl.load(DO + off_m[:, None] * D_HEAD + off_n[None, :]).to(tl.float32)
denom = tl.load(L + off_m).to(tl.float32)
# compute
do = do / denom[:, None]
delta = tl.sum(o * do, axis=1)
# write-back
tl.store(NewDO + off_m[:, None] * D_HEAD + off_n[None, :], do)
tl.store(Delta + off_m, delta)
@triton.jit
def _bwd_kernel(
Q, K, V, sm_scale, Out, DO,
DQ, DK, DV,
L, M,
D,
stride_qz, stride_qh, stride_qm, stride_qk,
stride_kz, stride_kh, stride_kn, stride_kk,
stride_vz, stride_vh, stride_vk, stride_vn,
Z, H, N_CTX,
num_block,
BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr,
BLOCK_N: tl.constexpr,
):
off_hz = tl.program_id(0)
off_z = off_hz // H
off_h = off_hz % H
# offset pointers for batch/head
Q += off_z * stride_qz + off_h * stride_qh
K += off_z * stride_qz + off_h * stride_qh
V += off_z * stride_qz + off_h * stride_qh
DO += off_z * stride_qz + off_h * stride_qh
DQ += off_z * stride_qz + off_h * stride_qh
DK += off_z * stride_qz + off_h * stride_qh
DV += off_z * stride_qz + off_h * stride_qh
for start_n in range(0, num_block):
lo = start_n * BLOCK_M
# initialize row/col offsets
offs_qm = lo + tl.arange(0, BLOCK_M)
offs_n = start_n * BLOCK_M + tl.arange(0, BLOCK_M)
offs_m = tl.arange(0, BLOCK_N)
offs_k = tl.arange(0, BLOCK_DMODEL)
# initialize pointers to value-like data
q_ptrs = Q + (offs_qm[:, None] * stride_qm + offs_k[None, :] * stride_qk)
k_ptrs = K + (offs_n[:, None] * stride_kn + offs_k[None, :] * stride_kk)
v_ptrs = V + (offs_n[:, None] * stride_qm + offs_k[None, :] * stride_qk)
do_ptrs = DO + (offs_qm[:, None] * stride_qm + offs_k[None, :] * stride_qk)
dq_ptrs = DQ + (offs_qm[:, None] * stride_qm + offs_k[None, :] * stride_qk)
# pointer to row-wise quantities in value-like data
D_ptrs = D + off_hz * N_CTX
m_ptrs = M + off_hz * N_CTX
# initialize dv amd dk
dv = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)
dk = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)
# k and v stay in SRAM throughout
k = tl.load(k_ptrs)
v = tl.load(v_ptrs)
# loop over rows
for start_m in range(lo, num_block * BLOCK_M, BLOCK_M):
offs_m_curr = start_m + offs_m
# load q, k, v, do on-chip
q = tl.load(q_ptrs)
# recompute p = softmax(qk, dim=-1).T
# NOTE: `do` is pre-divided by `l`; no normalization here
qk = tl.dot(q, tl.trans(k))
qk = tl.where(offs_m_curr[:, None] >= (offs_n[None, :]), qk, float("-inf"))
m = tl.load(m_ptrs + offs_m_curr)
p = tl.exp(qk * sm_scale - m[:, None])
# compute dv
do = tl.load(do_ptrs)
dv += tl.dot(tl.trans(p.to(Q.dtype.element_ty)), do)
# compute dp = dot(v, do)
Di = tl.load(D_ptrs + offs_m_curr)
dp = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) - Di[:, None]
dp += tl.dot(do, tl.trans(v))
# compute ds = p * (dp - delta[:, None])
ds = p * dp * sm_scale
# compute dk = dot(ds.T, q)
dk += tl.dot(tl.trans(ds.to(Q.dtype.element_ty)), q)
# compute dq
dq = tl.load(dq_ptrs)
dq += tl.dot(ds.to(Q.dtype.element_ty), k)
tl.store(dq_ptrs, dq)
# increment pointers
dq_ptrs += BLOCK_M * stride_qm
q_ptrs += BLOCK_M * stride_qm
do_ptrs += BLOCK_M * stride_qm
# write-back
dv_ptrs = DV + (offs_n[:, None] * stride_qm + offs_k[None, :] * stride_qk)
dk_ptrs = DK + (offs_n[:, None] * stride_kn + offs_k[None, :] * stride_kk)
tl.store(dv_ptrs, dv)
tl.store(dk_ptrs, dk)
class _attention(torch.autograd.Function):
@staticmethod
def forward(ctx, q, k, v, sm_scale):
# only support for Ampere now
capability = torch.cuda.get_device_capability()
if not is_corex():
if capability[0] < 8:
raise RuntimeError("Flash attention currently only supported for compute capability >= 80")
BLOCK = 128
else:
BLOCK = 64 # FIXME: currently BLOCK=128 has issues, BLOCK=64 works for common cases.
# shape constraints
Lq, Lk, Lv = q.shape[-1], k.shape[-1], v.shape[-1]
assert Lq == Lk and Lk == Lv
assert Lk in {16, 32, 64, 128}
o = torch.empty_like(q)
grid = (triton.cdiv(q.shape[2], BLOCK), q.shape[0] * q.shape[1], 1)
L = torch.empty((q.shape[0] * q.shape[1], q.shape[2]), device=q.device, dtype=torch.float32)
m = torch.empty((q.shape[0] * q.shape[1], q.shape[2]), device=q.device, dtype=torch.float32)
num_warps = 4
_fwd_kernel[grid](
q, k, v, sm_scale,
L, m,
o,
q.stride(0), q.stride(1), q.stride(2), q.stride(3),
k.stride(0), k.stride(1), k.stride(2), k.stride(3),
v.stride(0), v.stride(1), v.stride(2), v.stride(3),
o.stride(0), o.stride(1), o.stride(2), o.stride(3),
q.shape[0], q.shape[1], q.shape[2],
BLOCK_M=BLOCK, BLOCK_N=BLOCK,
BLOCK_DMODEL=Lk, num_warps=num_warps,
num_stages=2 if not is_corex() else 1,
)
ctx.save_for_backward(q, k, v, o, L, m)
ctx.grid = grid
ctx.sm_scale = sm_scale
ctx.BLOCK_DMODEL = Lk
return o
@staticmethod
def backward(ctx, do):
BLOCK = 128 if not is_corex() else 64 # FIXME: currently BLOCK=128 has issues, BLOCK=64 works for common cases.
num_warps = 16 if is_corex() and ctx.BLOCK_DMODEL > 64 else 8
q, k, v, o, l, m = ctx.saved_tensors
do = do.contiguous()
dq = torch.zeros_like(q, dtype=torch.float32)
dk = torch.empty_like(k)
dv = torch.empty_like(v)
do_scaled = torch.empty_like(do)
delta = torch.empty_like(l)
_bwd_preprocess[(ctx.grid[0] * ctx.grid[1], )](
o, do, l,
do_scaled, delta,
BLOCK_M=BLOCK, D_HEAD=ctx.BLOCK_DMODEL,
)
_bwd_kernel[(ctx.grid[1],)](
q, k, v, ctx.sm_scale,
o, do_scaled,
dq, dk, dv,
l, m,
delta,
q.stride(0), q.stride(1), q.stride(2), q.stride(3),
k.stride(0), k.stride(1), k.stride(2), k.stride(3),
v.stride(0), v.stride(1), v.stride(2), v.stride(3),
q.shape[0], q.shape[1], q.shape[2],
ctx.grid[0],
BLOCK_M=BLOCK, BLOCK_N=BLOCK,
BLOCK_DMODEL=ctx.BLOCK_DMODEL, num_warps=num_warps,
num_stages=1,
)
return dq, dk, dv, None
attention = _attention.apply

184
pkgs/triton/ops/matmul.py Normal file
View File

@@ -0,0 +1,184 @@
import torch
import triton
import triton.language as tl
from .matmul_perf_model import early_config_prune, estimate_matmul_time
def init_to_zero(name):
return lambda nargs: nargs[name].zero_()
def get_configs_io_bound():
configs = []
if hasattr(torch, "corex"):
return configs
for num_stages in [1, 2]:
# TODO support block size 16 for MFMA dot op
for block_m in [16, 32] if torch.version.hip is None and not hasattr(torch, "corex") else [32, 64]:
for block_k in [32, 64]:
for block_n in [32, 64, 128, 256]:
num_warps = 2 if block_n <= 64 else 4
configs.append(
triton.Config({'BLOCK_M': block_m, 'BLOCK_N': block_n, 'BLOCK_K': block_k, 'SPLIT_K': 1},
num_stages=num_stages, num_warps=num_warps))
# split_k
for split_k in [2, 4, 8, 16]:
configs.append(triton.Config({'BLOCK_M': block_m, 'BLOCK_N': block_n, 'BLOCK_K': block_k, 'SPLIT_K': split_k},
num_stages=num_stages, num_warps=num_warps, pre_hook=init_to_zero('C')))
return configs
def get_configs_compute_bound():
configs = []
if hasattr(torch, "corex"):
for block_m in [32, 64, 128, 256]:
for block_n in [32, 64, 128, 256]:
for block_k in [32, 64, 128, 256]:
for num_stages in [1, 2]:
num_warps = 16 if block_m >= 128 or block_n >=128 or block_k >= 128 else 8
configs.append(
triton.Config({'BLOCK_M': block_m, 'BLOCK_N': block_n, 'BLOCK_K': block_k, 'SPLIT_K': 1},
num_stages=num_stages, num_warps=num_warps))
return configs
def get_nv_config():
configs = []
if hasattr(torch, "corex"):
return configs
configs = [# basic configs for compute-bound matmuls
triton.Config({'BLOCK_M': 128, 'BLOCK_N': 256, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=3, num_warps=8),
triton.Config({'BLOCK_M': 256, 'BLOCK_N': 128, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=3, num_warps=8),
triton.Config({'BLOCK_M': 256, 'BLOCK_N': 64, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
triton.Config({'BLOCK_M': 64, 'BLOCK_N': 256, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
triton.Config({'BLOCK_M': 64, 'BLOCK_N': 128, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
triton.Config({'BLOCK_M': 128, 'BLOCK_N': 32, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
triton.Config({'BLOCK_M': 64, 'BLOCK_N': 32, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=5, num_warps=2),
# good for int8
triton.Config({'BLOCK_M': 128, 'BLOCK_N': 256, 'BLOCK_K': 128, 'SPLIT_K': 1}, num_stages=3, num_warps=8),
triton.Config({'BLOCK_M': 256, 'BLOCK_N': 128, 'BLOCK_K': 128, 'SPLIT_K': 1}, num_stages=3, num_warps=8),
triton.Config({'BLOCK_M': 256, 'BLOCK_N': 64, 'BLOCK_K': 128, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
triton.Config({'BLOCK_M': 64, 'BLOCK_N': 256, 'BLOCK_K': 128, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'BLOCK_K': 128, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'BLOCK_K': 64, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
triton.Config({'BLOCK_M': 64, 'BLOCK_N': 128, 'BLOCK_K': 64, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
triton.Config({'BLOCK_M': 128, 'BLOCK_N': 32, 'BLOCK_K': 64, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
triton.Config({'BLOCK_M': 64, 'BLOCK_N': 32, 'BLOCK_K': 64, 'SPLIT_K': 1}, num_stages=5, num_warps=2),
]
return configs
@triton.autotune(
configs=get_nv_config() + get_configs_compute_bound() + get_configs_io_bound(),
key=['M', 'N', 'K'],
prune_configs_by={
'early_config_prune': early_config_prune,
'perf_model': estimate_matmul_time,
'top_k': 10
},
)
@triton.heuristics({
'EVEN_K': lambda args: args['K'] % (args['BLOCK_K'] * args['SPLIT_K']) == 0,
})
@triton.jit
def _kernel(A, B, C, M, N, K,
stride_am, stride_ak,
stride_bk, stride_bn,
stride_cm, stride_cn,
dot_out_dtype: tl.constexpr,
BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
GROUP_M: tl.constexpr, SPLIT_K: tl.constexpr, EVEN_K: tl.constexpr,
):
# matrix multiplication
pid = tl.program_id(0)
pid_z = tl.program_id(1)
grid_m = tl.cdiv(M, BLOCK_M)
grid_n = tl.cdiv(N, BLOCK_N)
# re-order program ID for better L2 performance
width = GROUP_M * grid_n
group_id = pid // width
group_size = min(grid_m - group_id * GROUP_M, GROUP_M)
pid_m = group_id * GROUP_M + (pid % group_size)
pid_n = (pid % width) // (group_size)
# do matrix multiplication
rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
ram = tl.max_contiguous(tl.multiple_of(rm % M, BLOCK_M), BLOCK_M)
rbn = tl.max_contiguous(tl.multiple_of(rn % N, BLOCK_N), BLOCK_N)
rk = pid_z * BLOCK_K + tl.arange(0, BLOCK_K)
# pointers
A = A + (ram[:, None] * stride_am + rk[None, :] * stride_ak)
B = B + (rk[:, None] * stride_bk + rbn[None, :] * stride_bn)
acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=dot_out_dtype)
for k in range(0, tl.cdiv(K, BLOCK_K * SPLIT_K)):
if EVEN_K:
a = tl.load(A)
b = tl.load(B)
else:
k_remaining = K - k * (BLOCK_K * SPLIT_K)
a = tl.load(A, mask=rk[None, :] < k_remaining, other=0.)
b = tl.load(B, mask=rk[:, None] < k_remaining, other=0.)
acc += tl.dot(a, b, out_dtype=dot_out_dtype)
A += BLOCK_K * SPLIT_K * stride_ak
B += BLOCK_K * SPLIT_K * stride_bk
acc = acc.to(C.dtype.element_ty)
# rematerialize rm and rn to save registers
rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
C = C + (rm[:, None] * stride_cm + rn[None, :] * stride_cn)
mask = (rm < M)[:, None] & (rn < N)[None, :]
# handles write-back with reduction-splitting
if SPLIT_K == 1:
tl.store(C, acc, mask=mask)
else:
tl.atomic_add(C, acc, mask=mask)
class _matmul(torch.autograd.Function):
kernel = _kernel
_locks = {}
@staticmethod
def _call(a, b, dot_out_dtype):
device = a.device
# handle non-contiguous inputs if necessary
if a.stride(0) > 1 and a.stride(1) > 1:
a = a.contiguous()
if b.stride(0) > 1 and b.stride(1) > 1:
b = b.contiguous()
# checks constraints
assert a.shape[1] == b.shape[0], "incompatible dimensions"
M, K = a.shape
_, N = b.shape
# allocates output
c = torch.empty((M, N), device=device, dtype=a.dtype)
if dot_out_dtype is None:
if a.dtype in [torch.float16, torch.float32, torch.bfloat16]:
dot_out_dtype = tl.float32
else:
dot_out_dtype = tl.int32
else:
assert isinstance(dot_out_dtype, torch.dtype), "dot_out_dtype must be a torch.dtype"
if dot_out_dtype == torch.float16:
dot_out_dtype = tl.float16
elif dot_out_dtype in [torch.float32, torch.bfloat16]:
dot_out_dtype = tl.float32
else:
dot_out_dtype = tl.int32
# launch kernel
grid = lambda META: (triton.cdiv(M, META['BLOCK_M']) * triton.cdiv(N, META['BLOCK_N']), META['SPLIT_K'])
_kernel[grid](a, b, c, M, N, K,
a.stride(0), a.stride(1),
b.stride(0), b.stride(1),
c.stride(0), c.stride(1),
dot_out_dtype=dot_out_dtype,
GROUP_M=8)
return c
@staticmethod
def forward(ctx, a, b, dot_out_dtype=None):
return _matmul._call(a, b, dot_out_dtype=dot_out_dtype)
matmul = _matmul.apply

View File

@@ -0,0 +1,164 @@
import heapq
import torch
import triton
import triton._C.libtriton.triton as _triton
from triton.runtime import driver
from triton.testing import get_dram_gbps, get_max_simd_tflops, get_max_tensorcore_tflops
def get_tensorcore_tflops(backend, device, num_ctas, num_warps, dtype):
''' return compute throughput in TOPS '''
total_warps = num_ctas * min(num_warps, 4)
num_subcores = driver.utils.get_device_properties(device)["multiprocessor_count"] * 4 # on recent GPUs
tflops = min(num_subcores, total_warps) / num_subcores * get_max_tensorcore_tflops(dtype, backend, device)
return tflops
def get_simd_tflops(backend, device, num_ctas, num_warps, dtype):
''' return compute throughput in TOPS '''
total_warps = num_ctas * min(num_warps, 4)
num_subcores = driver.utils.get_device_properties(device)["multiprocessor_count"] * 4 # on recent GPUs
tflops = min(num_subcores, total_warps) / num_subcores * get_max_simd_tflops(dtype, backend, device)
return tflops
def get_tflops(backend, device, num_ctas, num_warps, dtype):
capability = torch.cuda.get_device_capability(device)
if capability[0] < 8 and dtype == torch.float32:
return get_simd_tflops(backend, device, num_ctas, num_warps, dtype)
return get_tensorcore_tflops(backend, device, num_ctas, num_warps, dtype)
def estimate_matmul_time(
# backend, device,
num_warps, num_stages,
A, B, C,
M, N, K,
BLOCK_M, BLOCK_N, BLOCK_K, SPLIT_K,
debug=False, **kwargs
):
''' return estimated running time in ms
= max(compute, loading) + store '''
backend = _triton.runtime.backend.CUDA
device = torch.cuda.current_device()
dtype = A.dtype
dtsize = A.element_size()
num_cta_m = triton.cdiv(M, BLOCK_M)
num_cta_n = triton.cdiv(N, BLOCK_N)
num_cta_k = SPLIT_K
num_ctas = num_cta_m * num_cta_n * num_cta_k
# If the input is smaller than the block size
M, N = max(M, BLOCK_M), max(N, BLOCK_N)
# time to compute
total_ops = 2 * M * N * K / (1024 * 1024 * 1024) # GOPS
tput = get_tflops(backend, device, num_ctas, num_warps, dtype)
compute_ms = total_ops / tput
# time to load data
num_sm = driver.utils.get_device_properties(device)["multiprocessor_count"]
active_cta_ratio = min(1, num_ctas / num_sm)
active_cta_ratio_bw1 = min(1, num_ctas / 32) # 32 active ctas are enough to saturate
active_cta_ratio_bw2 = max(min(1, (num_ctas - 32) / (108 - 32)), 0) # 32-108, remaining 5%
dram_bw = get_dram_gbps(backend, device) * (active_cta_ratio_bw1 * 0.95 + active_cta_ratio_bw2 * 0.05) # in GB/s
l2_bw = dram_bw * 4 # rough estimation (should be 4.7 for A100?)
# assume 80% of (following) loads are in L2 cache
load_a_dram = M * K * dtsize * (1 + 0.2 * (num_cta_n - 1))
load_a_l2 = M * K * dtsize * 0.8 * (num_cta_n - 1)
load_b_dram = N * K * dtsize * (1 + 0.2 * (num_cta_m - 1))
load_b_l2 = N * K * dtsize * 0.8 * (num_cta_m - 1)
# total
total_dram = (load_a_dram + load_b_dram) / (1024 * 1024) # MB
total_l2 = (load_a_l2 + load_b_l2) / (1024 * 1024)
# loading time in ms
load_ms = total_dram / dram_bw + total_l2 / l2_bw
# estimate storing time
store_bw = dram_bw * 0.6 # :o
store_c_dram = M * N * dtsize * SPLIT_K / (1024 * 1024) # MB
if SPLIT_K == 1:
store_ms = store_c_dram / store_bw
else:
reduce_bw = store_bw
store_ms = store_c_dram / reduce_bw
# c.zero_()
zero_ms = M * N * 2 / (1024 * 1024) / store_bw
store_ms += zero_ms
total_time_ms = compute_ms + load_ms + store_ms
if debug:
print(f'Total time: {total_time_ms}ms, compute time: {compute_ms}ms, '
f'loading time: {load_ms}ms, store time: {store_ms}ms, '
f'Activate CTAs: {active_cta_ratio*100}%')
return total_time_ms
def early_config_prune(configs, named_args):
device = torch.cuda.current_device()
capability = torch.cuda.get_device_capability()
# BLOCK_M, BLOCK_N, BLOCK_K, SPLIT_K, num_warps, num_stages
dtsize = named_args['A'].element_size()
dtype = named_args['A'].dtype
# 1. make sure we have enough smem
pruned_configs = []
for config in configs:
kw = config.kwargs
BLOCK_M, BLOCK_N, BLOCK_K, num_stages = \
kw['BLOCK_M'], kw['BLOCK_N'], kw['BLOCK_K'], config.num_stages
max_shared_memory = driver.utils.get_device_properties(device)["max_shared_mem"]
required_shared_memory = (BLOCK_M + BLOCK_N) * BLOCK_K * num_stages * dtsize
if required_shared_memory <= max_shared_memory:
pruned_configs.append(config)
configs = pruned_configs
# Some dtypes do not allow atomic_add
if dtype not in [torch.float16, torch.float32]:
configs = [config for config in configs if config.kwargs['SPLIT_K'] == 1]
# group configs by (BLOCK_M,_N,_K, SPLIT_K, num_warps)
configs_map = {}
for config in configs:
kw = config.kwargs
BLOCK_M, BLOCK_N, BLOCK_K, SPLIT_K, num_warps, num_stages = \
kw['BLOCK_M'], kw['BLOCK_N'], kw['BLOCK_K'], kw['SPLIT_K'], config.num_warps, config.num_stages
key = (BLOCK_M, BLOCK_N, BLOCK_K, SPLIT_K, num_warps)
if key in configs_map:
configs_map[key].append((config, num_stages))
else:
configs_map[key] = [(config, num_stages)]
pruned_configs = []
for k, v in configs_map.items():
BLOCK_M, BLOCK_N, BLOCK_K, SPLIT_K, num_warps = k
if capability[0] >= 8:
# compute cycles (only works for ampere GPUs)
mmas = BLOCK_M * BLOCK_N * BLOCK_K / (16 * 8 * 16)
mma_cycles = mmas / min(4, num_warps) * 8
ldgsts_latency = 300 # Does this matter?
optimal_num_stages = ldgsts_latency / mma_cycles
# nearest stages, prefer large #stages
nearest = heapq.nsmallest(2, v, key=lambda x: 10 + abs(x[1] - optimal_num_stages)
if (x[1] - optimal_num_stages) < 0 else x[1] - optimal_num_stages)
for n in nearest:
pruned_configs.append(n[0])
else: # Volta & Turing only supports num_stages <= 2
if hasattr(torch, "corex"):
for stage in range(len(v)):
random_config = v[stage][0]
random_config.num_stages = v[stage][1]
pruned_configs.append(random_config)
else:
random_config = v[0][0]
random_config.num_stages = 2
pruned_configs.append(random_config)
return pruned_configs

View File

@@ -0,0 +1,21 @@
from .autotuner import (Autotuner, Config, Heuristics, OutOfResources, autotune,
heuristics)
from .driver import driver
from .jit import (JITFunction, KernelInterface, MockTensor, TensorWrapper, reinterpret,
version_key)
__all__ = [
"driver",
"Config",
"Heuristics",
"autotune",
"heuristics",
"JITFunction",
"KernelInterface",
"version_key",
"reinterpret",
"TensorWrapper",
"OutOfResources",
"MockTensor",
"Autotuner",
]

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

View File

@@ -0,0 +1,305 @@
from __future__ import annotations
import builtins
import time
from typing import Dict
import json
import os
import hashlib
from ..testing import do_bench
from .jit import KernelInterface
from .cache import default_cache_dir
def build_best_config_hash(args_names, key):
cache_dir = os.environ.get('TRITON_CACHE_DIR', default_cache_dir())
hasher = hashlib.sha256()
hasher.update(f"{'_'.join(args_names) + str(key)}\n".encode())
cfg_hash = hasher.hexdigest()
cfg_hash_dir = os.path.join(cache_dir, cfg_hash)
cfg_hash_file = os.path.splitext(cfg_hash)[0] + ".best_config"
cfg_hash_file = os.path.join(cfg_hash_dir, cfg_hash_file)
return cfg_hash_dir, cfg_hash_file
def load_best_config(args_names, key):
_, cfg_hash_file = build_best_config_hash(args_names, key)
if os.path.exists(cfg_hash_file):
with open(cfg_hash_file) as fd:
best_config = json.loads(fd.read())
num_warps = best_config.pop('num_warps') if 'num_warps' in best_config else 4
num_stages = best_config.pop('num_stages') if 'num_stages' in best_config else 1
return best_config, num_warps, num_stages
return None
def save_best_config(cfg, args_names, key):
cfg_hash_dir, cfg_hash_file = build_best_config_hash(args_names, key)
if os.path.exists(cfg_hash_dir):
return
os.makedirs(cfg_hash_dir, exist_ok=True)
with open(cfg_hash_file, "w") as fd:
fd.write(
json.dumps(
{
**cfg.kwargs,
"num_warps": cfg.num_warps,
"num_stages": cfg.num_stages,
}
)
)
class OutOfResources(Exception):
def __init__(self, required, limit, name):
self.message = f'out of resource: {name}, '\
f'Required: {required}, '\
f'Hardware limit: {limit}'
self.message += '. Reducing block sizes or `num_stages` may help.'
self.required = required
self.limit = limit
self.name = name
super().__init__(self.message)
def __reduce__(self):
# this is necessary to make CompilationError picklable
return (type(self), (self.required, self.limit, self.name))
class Autotuner(KernelInterface):
def __init__(self, fn, arg_names, configs, key, reset_to_zero, prune_configs_by: Dict = None):
'''
:param prune_configs_by: a dict of functions that are used to prune configs, fields:
'perf_model': performance model used to predicate running time with different configs, returns running time
'top_k': number of configs to bench
'prune_num_stages_by'(optional): a function used to prune num_stages. It takes configs:List[Config] as its input, and returns pruned configs.
'''
if not configs:
self.configs = [Config({}, num_warps=4, num_stages=2)]
else:
self.configs = configs
self.key_idx = [arg_names.index(k) for k in key]
self.cache = {}
# hook to reset all required tensor to zeros before relaunching a kernel
self.hook = lambda args: 0
if reset_to_zero is not None:
self.reset_idx = [arg_names.index(k) for k in reset_to_zero]
def _hook(args):
for i in self.reset_idx:
args[i].zero_()
self.hook = _hook
self.arg_names = arg_names
# prune configs
if prune_configs_by:
perf_model, top_k = prune_configs_by['perf_model'], prune_configs_by['top_k']
if 'early_config_prune' in prune_configs_by:
early_config_prune = prune_configs_by['early_config_prune']
else:
perf_model, top_k, early_config_prune = None, None, None
self.perf_model, self.configs_top_k = perf_model, top_k
self.early_config_prune = early_config_prune
self.fn = fn
def _bench(self, *args, config, **meta):
# check for conflicts, i.e. meta-parameters both provided
# as kwargs and by the autotuner
conflicts = meta.keys() & config.kwargs.keys()
if conflicts:
raise ValueError(
f"Conflicting meta-parameters: {', '.join(conflicts)}."
" Make sure that you don't re-define auto-tuned symbols."
)
# augment meta-parameters with tunable ones
current = dict(meta, **config.kwargs)
def kernel_call():
if config.pre_hook:
config.pre_hook(self.nargs)
self.hook(args)
self.fn.run(*args, num_warps=config.num_warps, num_stages=config.num_stages, **current)
try:
return do_bench(kernel_call, quantiles=(0.5, 0.2, 0.8))
except OutOfResources:
return [float('inf'), float('inf'), float('inf')]
def run(self, *args, **kwargs):
self.nargs = dict(zip(self.arg_names, args))
if len(self.configs) > 1:
all_args = {**self.nargs, **kwargs}
_args = []
for name in self.arg_names:
if name in all_args:
_args.append(all_args[name])
key = [_args[i] for i in self.key_idx]
divisibility = 16
for arg in args:
if hasattr(arg, "data_ptr"):
key.append(arg.dtype)
key.append(arg.data_ptr() % divisibility == 0)
elif isinstance(arg, int):
key.append(arg)
key = tuple(key)
if key not in self.cache:
load_config = load_best_config(self.arg_names, key)
if load_config:
best_config, num_warps, num_stages = load_config
config = Config(best_config, num_warps, num_stages)
self.cache[key] = config
self.hook(args)
else:
# prune configs
pruned_configs = self.prune_configs(kwargs)
bench_start = time.time()
timings = {config: self._bench(*args, config=config, **kwargs)
for config in pruned_configs}
bench_end = time.time()
self.bench_time = bench_end - bench_start
self.cache[key] = builtins.min(timings, key=timings.get)
save_best_config(self.cache[key], self.arg_names, key)
self.hook(args)
self.configs_timings = timings
config = self.cache[key]
else:
config = self.configs[0]
self.best_config = config
if config.pre_hook is not None:
config.pre_hook(self.nargs)
self.nargs = None
return self.fn.run(*args, num_warps=config.num_warps, num_stages=config.num_stages, **kwargs, **config.kwargs)
def prune_configs(self, kwargs):
pruned_configs = self.configs
if self.early_config_prune:
pruned_configs = self.early_config_prune(self.configs, self.nargs)
if self.perf_model:
top_k = self.configs_top_k
if isinstance(top_k, float) and top_k <= 1.0:
top_k = int(len(self.configs) * top_k)
if len(pruned_configs) > top_k:
est_timing = {
config: self.perf_model(**self.nargs, **kwargs, **config.kwargs, num_stages=config.num_stages,
num_warps=config.num_warps)
for config in pruned_configs
}
pruned_configs = sorted(est_timing.keys(), key=lambda x: est_timing[x])[:top_k]
return pruned_configs
def warmup(self, *args, **kwargs):
self.nargs = dict(zip(self.arg_names, args))
for config in self.prune_configs(kwargs):
self.fn.warmup(
*args,
num_warps=config.num_warps,
num_stages=config.num_stages,
**kwargs,
**config.kwargs,
)
self.nargs = None
class Config:
"""
An object that represents a possible kernel configuration for the auto-tuner to try.
:ivar meta: a dictionary of meta-parameters to pass to the kernel as keyword arguments.
:type meta: dict[Str, Any]
:ivar num_warps: the number of warps to use for the kernel when compiled for GPUs. For example, if
`num_warps=8`, then each kernel instance will be automatically parallelized to
cooperatively execute using `8 * 32 = 256` threads.
:type num_warps: int
:ivar num_stages: the number of stages that the compiler should use when software-pipelining loops.
Mostly useful for matrix multiplication workloads on SM80+ GPUs.
:type num_stages: int
:ivar pre_hook: a function that will be called before the kernel is called. Parameters of this
function are args.
"""
def __init__(self, kwargs, num_warps=4, num_stages=2, pre_hook=None):
self.kwargs = kwargs
self.num_warps = num_warps
self.num_stages = num_stages
self.pre_hook = pre_hook
def __str__(self):
res = []
for k, v in self.kwargs.items():
res.append(f'{k}: {v}')
res.append(f'num_warps: {self.num_warps}')
res.append(f'num_stages: {self.num_stages}')
return ', '.join(res)
def autotune(configs, key, prune_configs_by=None, reset_to_zero=None):
"""
Decorator for auto-tuning a :code:`triton.jit`'d function.
.. highlight:: python
.. code-block:: python
@triton.autotune(configs=[
triton.Config(meta={'BLOCK_SIZE': 128}, num_warps=4),
triton.Config(meta={'BLOCK_SIZE': 1024}, num_warps=8),
],
key=['x_size'] # the two above configs will be evaluated anytime
# the value of x_size changes
)
@triton.jit
def kernel(x_ptr, x_size, **META):
BLOCK_SIZE = META['BLOCK_SIZE']
:note: When all the configurations are evaluated, the kernel will run multiple times.
This means that whatever value the kernel updates will be updated multiple times.
To avoid this undesired behavior, you can use the `reset_to_zero` argument, which
resets the value of the provided tensor to `zero` before running any configuration.
:param configs: a list of :code:`triton.Config` objects
:type configs: list[triton.Config]
:param key: a list of argument names whose change in value will trigger the evaluation of all provided configs.
:type key: list[str]
:param prune_configs_by: a dict of functions that are used to prune configs, fields:
'perf_model': performance model used to predicate running time with different configs, returns running time
'top_k': number of configs to bench
'early_config_prune'(optional): a function used to do early prune (eg, num_stages). It takes configs:List[Config] as its input, and returns pruned configs.
:param reset_to_zero: a list of argument names whose value will be reset to zero before evaluating any configs.
:type reset_to_zero: list[str]
"""
def decorator(fn):
return Autotuner(fn, fn.arg_names, configs, key, reset_to_zero, prune_configs_by)
return decorator
class Heuristics(KernelInterface):
def __init__(self, fn, arg_names, values) -> None:
self.fn = fn
self.values = values
self.arg_names = arg_names
def run(self, *args, **kwargs):
for v, heur in self.values.items():
kwargs[v] = heur({**dict(zip(self.arg_names, args)), **kwargs})
return self.fn.run(*args, **kwargs)
def heuristics(values):
"""
Decorator for specifying how the values of certain meta-parameters may be computed.
This is useful for cases where auto-tuning is prohibitevely expensive, or just not applicable.
.. highlight:: python
.. code-block:: python
@triton.heuristics(values={'BLOCK_SIZE': lambda args: 2 ** int(math.ceil(math.log2(args[1])))})
@triton.jit
def kernel(x_ptr, x_size, **META):
BLOCK_SIZE = META['BLOCK_SIZE'] # smallest power-of-two >= x_size
:param values: a dictionary of meta-parameter names and functions that compute the value of the meta-parameter.
each such function takes a list of positional arguments as input.
:type values: dict[str, Callable[[list[Any]], Any]]
"""
def decorator(fn):
return Heuristics(fn, fn.arg_names, values)
return decorator

View File

@@ -0,0 +1,131 @@
#include "cuda.h"
#define PY_SSIZE_T_CLEAN
#include <Python.h>
static inline void gpuAssert(CUresult code, const char *file, int line) {
if (code != CUDA_SUCCESS) {
const char *prefix = "Triton Error [CUDA]: ";
const char *str;
cuGetErrorString(code, &str);
char err[1024] = {0};
strcat(err, prefix);
strcat(err, str);
PyErr_SetString(PyExc_RuntimeError, err);
}
}
#define CUDA_CHECK(ans) \
{ \
gpuAssert((ans), __FILE__, __LINE__); \
if (PyErr_Occurred()) \
return NULL; \
}
static PyObject *getDeviceProperties(PyObject *self, PyObject *args) {
int device_id;
if (!PyArg_ParseTuple(args, "i", &device_id))
return NULL;
// Get device handle
CUdevice device;
cuDeviceGet(&device, device_id);
// create a struct to hold device properties
int max_shared_mem;
int multiprocessor_count;
int sm_clock_rate;
int mem_clock_rate;
int mem_bus_width;
CUDA_CHECK(cuDeviceGetAttribute(
&max_shared_mem, CU_DEVICE_ATTRIBUTE_MAX_SHARED_MEMORY_PER_BLOCK_OPTIN,
device));
CUDA_CHECK(cuDeviceGetAttribute(
&multiprocessor_count, CU_DEVICE_ATTRIBUTE_MULTIPROCESSOR_COUNT, device));
CUDA_CHECK(cuDeviceGetAttribute(&sm_clock_rate,
CU_DEVICE_ATTRIBUTE_CLOCK_RATE, device));
CUDA_CHECK(cuDeviceGetAttribute(
&mem_clock_rate, CU_DEVICE_ATTRIBUTE_MEMORY_CLOCK_RATE, device));
CUDA_CHECK(cuDeviceGetAttribute(
&mem_bus_width, CU_DEVICE_ATTRIBUTE_GLOBAL_MEMORY_BUS_WIDTH, device));
return Py_BuildValue("{s:i, s:i, s:i, s:i, s:i}", "max_shared_mem",
max_shared_mem, "multiprocessor_count",
multiprocessor_count, "sm_clock_rate", sm_clock_rate,
"mem_clock_rate", mem_clock_rate, "mem_bus_width",
mem_bus_width);
}
static PyObject *loadBinary(PyObject *self, PyObject *args) {
const char *name;
const char *data;
Py_ssize_t data_size;
int shared;
int device;
if (!PyArg_ParseTuple(args, "ss#ii", &name, &data, &data_size, &shared,
&device)) {
return NULL;
}
CUfunction fun;
CUmodule mod;
int32_t n_regs = 0;
int32_t n_spills = 0;
// create driver handles
CUcontext pctx = 0;
CUDA_CHECK(cuCtxGetCurrent(&pctx));
if (!pctx) {
CUDA_CHECK(cuDevicePrimaryCtxRetain(&pctx, device));
CUDA_CHECK(cuCtxSetCurrent(pctx));
}
CUDA_CHECK(cuModuleLoadData(&mod, data));
CUDA_CHECK(cuModuleGetFunction(&fun, mod, name));
// get allocated registers and spilled registers from the function
CUDA_CHECK(cuFuncGetAttribute(&n_regs, CU_FUNC_ATTRIBUTE_NUM_REGS, fun));
CUDA_CHECK(
cuFuncGetAttribute(&n_spills, CU_FUNC_ATTRIBUTE_LOCAL_SIZE_BYTES, fun));
n_spills /= 4;
// set dynamic shared memory if necessary
int shared_optin;
CUDA_CHECK(cuDeviceGetAttribute(
&shared_optin, CU_DEVICE_ATTRIBUTE_MAX_SHARED_MEMORY_PER_BLOCK_OPTIN,
device));
if (shared > 49152 && shared_optin > 49152) {
CUDA_CHECK(cuFuncSetCacheConfig(fun, CU_FUNC_CACHE_PREFER_SHARED));
int shared_total, shared_static;
CUDA_CHECK(cuDeviceGetAttribute(
&shared_total, CU_DEVICE_ATTRIBUTE_MAX_SHARED_MEMORY_PER_MULTIPROCESSOR,
device));
CUDA_CHECK(cuFuncGetAttribute(&shared_static,
CU_FUNC_ATTRIBUTE_SHARED_SIZE_BYTES, fun));
CUDA_CHECK(
cuFuncSetAttribute(fun, CU_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES,
shared_optin - shared_static));
}
if (PyErr_Occurred()) {
return NULL;
}
return Py_BuildValue("(KKii)", (uint64_t)mod, (uint64_t)fun, n_regs,
n_spills);
}
static PyMethodDef ModuleMethods[] = {
{"load_binary", loadBinary, METH_VARARGS,
"Load provided cubin into CUDA driver"},
{"get_device_properties", getDeviceProperties, METH_VARARGS,
"Get the properties for a given device"},
{NULL, NULL, 0, NULL} // sentinel
};
static struct PyModuleDef ModuleDef = {PyModuleDef_HEAD_INIT, "cuda_utils",
NULL, // documentation
-1, // size
ModuleMethods};
PyMODINIT_FUNC PyInit_cuda_utils(void) {
PyObject *m = PyModule_Create(&ModuleDef);
if (m == NULL) {
return NULL;
}
PyModule_AddFunctions(m, ModuleMethods);
return m;
}

View File

@@ -0,0 +1,120 @@
#define __HIP_PLATFORM_AMD__
#include <hip/hip_runtime.h>
#define PY_SSIZE_T_CLEAN
#include <Python.h>
#include <stdio.h>
#include <stdlib.h>
static inline void gpuAssert(hipError_t code, const char *file, int line) {
{
if (code != HIP_SUCCESS) {
{
const char *prefix = "Triton Error [HIP]: ";
const char *str = hipGetErrorString(code);
char err[1024] = {0};
snprintf(err, 1024, "%s Code: %d, Messsage: %s", prefix, code, str);
PyErr_SetString(PyExc_RuntimeError, err);
}
}
}
}
#define HIP_CHECK(ans) \
{ \
gpuAssert((ans), __FILE__, __LINE__); \
if (PyErr_Occurred()) \
return NULL; \
}
static PyObject *getDeviceProperties(PyObject *self, PyObject *args) {
int device_id;
if (!PyArg_ParseTuple(args, "i", &device_id))
return NULL;
hipDeviceProp_t props;
HIP_CHECK(hipGetDeviceProperties(&props, device_id));
// create a struct to hold device properties
return Py_BuildValue("{s:i, s:i, s:i, s:i, s:i}", "max_shared_mem",
props.sharedMemPerBlock, "multiprocessor_count",
props.multiProcessorCount, "sm_clock_rate",
props.clockRate, "mem_clock_rate", props.memoryClockRate,
"mem_bus_width", props.memoryBusWidth);
}
static PyObject *loadBinary(PyObject *self, PyObject *args) {
const char *name;
const char *data;
Py_ssize_t data_size;
int shared;
int device;
if (!PyArg_ParseTuple(args, "ss#ii", &name, &data, &data_size, &shared,
&device)) {
return NULL;
}
// Open HSACO file
FILE *hsaco_file;
if ((hsaco_file = fopen(data, "rb")) == NULL) {
return NULL;
}
// Read HSCAO file into Buffer
fseek(hsaco_file, 0L, SEEK_END);
size_t hsaco_file_size = ftell(hsaco_file);
unsigned char *hsaco =
(unsigned char *)malloc(hsaco_file_size * sizeof(unsigned char));
rewind(hsaco_file);
fread(hsaco, sizeof(unsigned char), hsaco_file_size, hsaco_file);
fclose(hsaco_file);
// set HIP options
hipJitOption opt[] = {hipJitOptionErrorLogBufferSizeBytes,
hipJitOptionErrorLogBuffer,
hipJitOptionInfoLogBufferSizeBytes,
hipJitOptionInfoLogBuffer, hipJitOptionLogVerbose};
const unsigned int errbufsize = 8192;
const unsigned int logbufsize = 8192;
char _err[errbufsize];
char _log[logbufsize];
void *optval[] = {(void *)(uintptr_t)errbufsize, (void *)_err,
(void *)(uintptr_t)logbufsize, (void *)_log, (void *)1};
// launch HIP Binary
hipModule_t mod;
hipFunction_t fun;
hipModuleLoadDataEx(&mod, hsaco, 5, opt, optval);
hipModuleGetFunction(&fun, mod, name);
free(hsaco);
// get allocated registers and spilled registers from the function
int n_regs = 0;
int n_spills = 0;
if (PyErr_Occurred()) {
return NULL;
}
return Py_BuildValue("(KKii)", (uint64_t)mod, (uint64_t)fun, n_regs,
n_spills);
}
static PyMethodDef ModuleMethods[] = {
{"load_binary", loadBinary, METH_VARARGS,
"Load provided hsaco into HIP driver"},
{"get_device_properties", getDeviceProperties, METH_VARARGS,
"Get the properties for a given device"},
{NULL, NULL, 0, NULL} // sentinel
};
static struct PyModuleDef ModuleDef = {PyModuleDef_HEAD_INIT, "hip_utils",
NULL, // documentation
-1, // size
ModuleMethods};
PyMODINIT_FUNC PyInit_hip_utils(void) {
PyObject *m = PyModule_Create(&ModuleDef);
if (m == NULL) {
return NULL;
}
PyModule_AddFunctions(m, ModuleMethods);
return m;
}

View File

@@ -0,0 +1,131 @@
import json
import os
import random
from abc import ABC, abstractmethod
from pathlib import Path
from typing import Dict, Optional
def default_cache_dir():
return os.path.join(Path.home(), ".triton", "cache")
class CacheManager(ABC):
def __init__(self, key):
pass
@abstractmethod
def get_file(self, filename) -> Optional[str]:
pass
@abstractmethod
def has_file(self, filename) -> bool:
pass
@abstractmethod
def put(self, data, filename, binary=True) -> str:
pass
@abstractmethod
def get_group(self, filename: str) -> Optional[Dict[str, str]]:
pass
@abstractmethod
def put_group(self, filename: str, group: Dict[str, str]):
pass
class FileCacheManager(CacheManager):
def __init__(self, key):
self.key = key
self.lock_path = None
# create cache directory if it doesn't exist
self.cache_dir = os.environ.get('TRITON_CACHE_DIR', default_cache_dir())
if self.cache_dir:
self.cache_dir = os.path.join(self.cache_dir, self.key)
self.lock_path = os.path.join(self.cache_dir, "lock")
os.makedirs(self.cache_dir, exist_ok=True)
def _make_path(self, filename) -> str:
return os.path.join(self.cache_dir, filename)
def has_file(self, filename):
if not self.cache_dir:
return False
return os.path.exists(self._make_path(filename))
def get_file(self, filename) -> Optional[str]:
if self.has_file(filename):
return self._make_path(filename)
else:
return None
def get_group(self, filename: str) -> Optional[Dict[str, str]]:
grp_filename = f"__grp__{filename}"
if not self.has_file(grp_filename):
return None
grp_filepath = self._make_path(grp_filename)
with open(grp_filepath) as f:
grp_data = json.load(f)
child_paths = grp_data.get("child_paths", None)
# Invalid group data.
if child_paths is None:
return None
result = {}
for c in child_paths:
p = self._make_path(c)
if not os.path.exists(p):
raise Exception(f"Group file {p} does not exist from group {grp_filename} ")
result[c] = p
return result
# Note a group of pushed files as being part of a group
def put_group(self, filename: str, group: Dict[str, str]):
if not self.cache_dir:
return
grp_contents = json.dumps({"child_paths": sorted(list(group.keys()))})
grp_filename = f"__grp__{filename}"
return self.put(grp_contents, grp_filename, binary=False)
def put(self, data, filename, binary=True) -> str:
if not self.cache_dir:
return
binary = isinstance(data, bytes)
if not binary:
data = str(data)
assert self.lock_path is not None
filepath = self._make_path(filename)
# Random ID to avoid any collisions
rnd_id = random.randint(0, 1000000)
# we use the PID incase a bunch of these around so we can see what PID made it
pid = os.getpid()
# use tempfile to be robust against program interruptions
temp_path = f"{filepath}.tmp.pid_{pid}_{rnd_id}"
mode = "wb" if binary else "w"
with open(temp_path, mode) as f:
f.write(data)
# Replace is guaranteed to be atomic on POSIX systems if it succeeds
# so filepath cannot see a partial write
os.replace(temp_path, filepath)
return filepath
__cache_cls = FileCacheManager
__cache_cls_nme = "DEFAULT"
def get_cache_manager(key) -> CacheManager:
import os
user_cache_manager = os.environ.get("TRITON_CACHE_MANAGER", None)
global __cache_cls
global __cache_cls_nme
if user_cache_manager is not None and user_cache_manager != __cache_cls_nme:
import importlib
module_path, clz_nme = user_cache_manager.split(":")
module = importlib.import_module(module_path)
__cache_cls = getattr(module, clz_nme)
__cache_cls_nme = user_cache_manager
return __cache_cls(key)

View File

@@ -0,0 +1,175 @@
import abc
import hashlib
import os
import tempfile
from pathlib import Path
from ..common.build import _build
from .cache import get_cache_manager
class DriverBase(metaclass=abc.ABCMeta):
CUDA = 0
HIP = 1
@staticmethod
def third_party_dir():
return os.path.join(os.path.dirname(os.path.abspath(__file__)), "..", "third_party")
def __init__(self) -> None:
pass
# -----------------------------
# CUDA
# -----------------------------
class CudaUtils(object):
def __new__(cls):
if not hasattr(cls, 'instance'):
cls.instance = super(CudaUtils, cls).__new__(cls)
return cls.instance
def __init__(self):
dirname = os.path.dirname(os.path.realpath(__file__))
src = Path(os.path.join(dirname, "backends", "cuda.c")).read_text()
key = hashlib.md5(src.encode("utf-8")).hexdigest()
cache = get_cache_manager(key)
fname = "cuda_utils.so"
cache_path = cache.get_file(fname)
if cache_path is None:
with tempfile.TemporaryDirectory() as tmpdir:
src_path = os.path.join(tmpdir, "main.c")
with open(src_path, "w") as f:
f.write(src)
so = _build("cuda_utils", src_path, tmpdir)
cache.put(src, "main.c", binary=False)
with open(so, "rb") as f:
cache_path = cache.put(f.read(), fname, binary=True)
import importlib.util
spec = importlib.util.spec_from_file_location("cuda_utils", cache_path)
mod = importlib.util.module_from_spec(spec)
spec.loader.exec_module(mod)
self.load_binary = mod.load_binary
self.get_device_properties = mod.get_device_properties
class CudaDriver(DriverBase):
def __new__(cls):
if not hasattr(cls, 'instance'):
cls.instance = super(CudaDriver, cls).__new__(cls)
return cls.instance
def __init__(self):
self.utils = CudaUtils()
self.backend = self.CUDA
# -----------------------------
# HIP
# -----------------------------
class HIPUtils(object):
def __new__(cls):
if not hasattr(cls, 'instance'):
cls.instance = super(HIPUtils, cls).__new__(cls)
return cls.instance
def __init__(self):
dirname = os.path.dirname(os.path.realpath(__file__))
src = Path(os.path.join(dirname, "backends", "hip.c")).read_text()
key = hashlib.md5(src.encode("utf-8")).hexdigest()
cache = get_cache_manager(key)
fname = "hip_utils.so"
cache_path = cache.get_file(fname)
if cache_path is None:
with tempfile.TemporaryDirectory() as tmpdir:
src_path = os.path.join(tmpdir, "main.c")
with open(src_path, "w") as f:
f.write(src)
so = _build("hip_utils", src_path, tmpdir)
with open(so, "rb") as f:
cache_path = cache.put(f.read(), fname, binary=True)
import importlib.util
spec = importlib.util.spec_from_file_location("hip_utils", cache_path)
mod = importlib.util.module_from_spec(spec)
spec.loader.exec_module(mod)
self.load_binary = mod.load_binary
self.get_device_properties = mod.get_device_properties
class HIPDriver(DriverBase):
def __new__(cls):
if not hasattr(cls, 'instance'):
cls.instance = super(HIPDriver, cls).__new__(cls)
return cls.instance
def __init__(self):
self.utils = HIPUtils()
self.backend = self.HIP
class UnsupportedDriver(DriverBase):
def __new__(cls):
if not hasattr(cls, 'instance'):
cls.instance = super(UnsupportedDriver, cls).__new__(cls)
return cls.instance
def __init__(self):
self.utils = None
self.backend = None
# -----------------------------
# Driver
# -----------------------------
class LazyProxy:
def __init__(self, init_fn):
self._init_fn = init_fn
self._obj = None
def _initialize_obj(self):
if self._obj is None:
self._obj = self._init_fn()
def __getattr__(self, name):
self._initialize_obj()
return getattr(self._obj, name)
def __setattr__(self, name, value):
if name in ['_init_fn', '_obj']:
super().__setattr__(name, value)
else:
self._initialize_obj()
setattr(self._obj, name, value)
def __delattr__(self, name):
self._initialize_obj()
delattr(self._obj, name)
def __repr__(self):
if self._obj is None:
return f"<{self.__class__.__name__} for {self._init_fn} not yet initialized>"
return repr(self._obj)
def __str__(self):
self._initialize_obj()
return str(self._obj)
def initialize_driver():
import torch
if torch.version.hip is not None:
return HIPDriver()
elif torch.cuda.is_available():
return CudaDriver()
else:
return UnsupportedDriver()
driver = LazyProxy(initialize_driver)

View File

@@ -0,0 +1,15 @@
class OutOfResources(Exception):
def __init__(self, required, limit, name):
self.message = f'out of resource: {name}, '\
f'Required: {required}, '\
f'Hardware limit: {limit}'
self.message += '. Reducing block sizes or `num_stages` may help.'
self.required = required
self.limit = limit
self.name = name
super().__init__(self.message)
def __reduce__(self):
# this is necessary to make CompilationError picklable
return (type(self), (self.required, self.limit, self.name))

573
pkgs/triton/runtime/jit.py Normal file
View File

@@ -0,0 +1,573 @@
from __future__ import annotations, division
import ast
import functools
import hashlib
import inspect
import os
import subprocess
import textwrap
from collections import defaultdict, namedtuple
from typing import Callable, Generic, Iterable, Optional, TypeVar, Union, cast, overload
import torch
import triton
def get_disable_sme():
disable_sme = os.getenv("TRITON_DISABLE_SME", default="0")
cc = torch.cuda.get_device_capability()
cc = cc[0] * 10 + cc[1]
if cc == 70: # for ivcore10
disable_sme = "1"
return disable_sme
def get_corex_sme(args, tl_args, enable_sme=True):
can_use_sme = 0
if not enable_sme:
return can_use_sme
import torch
if not (hasattr(torch, "corex") and torch.corex == True):
return can_use_sme
close_sme = get_disable_sme()
if close_sme == "1":
return can_use_sme
index = 0
for i, arg_name in enumerate(args):
arg = args.get(arg_name)
if (i in tl_args):
continue
if (isinstance(arg, int) and arg == 1):
continue
if torch.is_tensor(arg) and arg.dtype in [torch.float16, torch.float32, torch.bfloat16, torch.int8] and arg.dim() >= 2:
dim_M = arg.shape[-2]
dim_K = arg.shape[-1]
sme_dim = 64 / arg.element_size()
if (arg.is_contiguous() and dim_K % sme_dim == 0) or \
(not arg.is_contiguous() and dim_M % sme_dim == 0):
can_use_sme = (1 << index) | can_use_sme
index += 1
return can_use_sme
def get_cuda_stream(idx=None):
if idx is None:
idx = get_current_device()
try:
from torch._C import _cuda_getCurrentRawStream
return _cuda_getCurrentRawStream(idx)
except ImportError:
import torch
return torch.cuda.current_stream(idx).cuda_stream
def get_current_device():
import torch
return torch.cuda.current_device()
def set_current_device(idx):
import torch
torch.cuda.set_device(idx)
def get_device_capability(idx):
import torch
return torch.cuda.get_device_capability(idx)
T = TypeVar('T')
# -----------------------------------------------------------------------------
# Dependencies Finder
# -----------------------------------------------------------------------------
class DependenciesFinder(ast.NodeVisitor):
"""
This AST visitor is used to find dependencies of a JITFunction. This can
be used to invalidate a JITFunction's hash when its source code -- or
that of its dependencies -- changes.
"""
def __init__(self, globals, src) -> None:
super().__init__()
self.ret = hashlib.md5(src.encode("utf-8")).hexdigest()
self.globals = globals
def visit_Name(self, node):
return self.globals.get(node.id, None)
def visit_Attribute(self, node):
lhs = self.visit(node.value)
while isinstance(lhs, ast.Attribute):
lhs = self.visit(lhs.value)
if lhs is None or lhs is triton:
return None
return getattr(lhs, node.attr)
def visit_Call(self, node):
func = self.visit(node.func)
if func is None:
return
if inspect.isbuiltin(func):
return
if func.__module__ and func.__module__.startswith('triton.'):
return
assert isinstance(func, JITFunction), f"Function \"{func.__name__}\" is being called from a Triton function but is not a Triton function itself. Decorate it with @triton.jit to fix this"
if func.hash is None:
tree = ast.parse(func.src)
finder = DependenciesFinder(func.__globals__, func.src)
finder.visit(tree)
func.hash = finder.ret
noinline = str(getattr(func, 'noinline', False))
self.ret = (self.ret + func.hash + noinline).encode("utf-8")
self.ret = hashlib.md5(self.ret).hexdigest()
# -----------------------------------------------------------------------------
# JITFunction
# -----------------------------------------------------------------------------
@functools.lru_cache()
def version_key():
import pkgutil
contents = []
# frontend
with open(__file__, "rb") as f:
contents += [hashlib.md5(f.read()).hexdigest()]
# compiler
compiler_path = os.path.join(*triton.__path__, 'compiler')
for lib in pkgutil.iter_modules([compiler_path]):
with open(lib.module_finder.find_spec(lib.name).origin, "rb") as f:
contents += [hashlib.md5(f.read()).hexdigest()]
# backend
with open(triton._C.libtriton.__file__, "rb") as f:
contents += [hashlib.md5(f.read()).hexdigest()]
# language
language_path = os.path.join(*triton.__path__, 'language')
for lib in pkgutil.iter_modules([language_path]):
with open(lib.module_finder.find_spec(lib.name).origin, "rb") as f:
contents += [hashlib.md5(f.read()).hexdigest()]
# ptxas version
try:
ptxas_version = hashlib.md5(subprocess.check_output(["ptxas", "--version"])).hexdigest()
except Exception:
ptxas_version = ''
return '-'.join(triton.__version__) + '-' + ptxas_version + '-' + '-'.join(contents)
class KernelInterface(Generic[T]):
run: T
def __getitem__(self, grid) -> T:
"""
A JIT function is launched with: fn[grid](*args, **kwargs).
Hence JITFunction.__getitem__ returns a callable proxy that
memorizes the grid.
"""
return cast(T, functools.partial(cast(Callable, self.run), grid=grid))
class JITFunction(KernelInterface[T]):
# Hook for inspecting compiled functions and modules
cache_hook = None
divisibility = 16
@staticmethod
def _key_of(arg):
if hasattr(arg, "dtype"):
return arg.dtype
elif isinstance(arg, bool):
return "i1"
elif isinstance(arg, int):
if -2**31 <= arg and arg <= 2**31 - 1:
return "i32"
elif 2**63 <= arg and arg <= 2**64 - 1:
return "u64"
else:
return "i64"
elif isinstance(arg, float):
return 'fp32'
elif arg is None:
return None
else:
raise TypeError(f'Unsupported type {type(arg)} for {arg}')
@staticmethod
def _spec_of(arg):
if hasattr(arg, "data_ptr"):
return (arg.data_ptr() % JITFunction.divisibility == 0)
elif isinstance(arg, int):
return (arg % 16 == 0, arg == 1)
return (arg is None, )
def _get_config(self, *args):
def is_divisible_by_16(x):
if hasattr(x, "data_ptr"):
return x.data_ptr() % JITFunction.divisibility == 0
elif isinstance(x, int):
return x % JITFunction.divisibility == 0
if x is None:
return True
return False
divisible_by_16 = {i for i, arg in enumerate(args) if is_divisible_by_16(arg) and i not in self.do_not_specialize}
equal_to_1 = {i for i, arg in enumerate(args) if not isinstance(arg, bool) and isinstance(arg, int) and arg == 1 and i not in self.do_not_specialize}
return namedtuple("instance_descriptor", ["divisible_by_16", "equal_to_1"])(tuple(divisible_by_16), tuple(equal_to_1))
# return _triton.code_gen.instance_descriptor(divisible_by_16, equal_to_1)
@staticmethod
def _type_of(key):
# None are nullptr -- implicitly converted to *i8
if key is None:
return '*i8'
dtype_str = str(key).split(".")[-1]
tys = {
"bool": "i1",
"float8e5": "fp8e5",
"float8e4": "fp8e4",
"float16": "fp16",
"bfloat16": "bf16",
"float32": "fp32",
"float64": "fp64",
"int8": "i8",
"int16": "i16",
"int32": "i32",
"int64": "i64",
"uint8": "u8",
"uint16": "u16",
"uint32": "u32",
"uint64": "u64",
}
# reinterpret can create triton type
for v in list(tys.values()):
tys[v] = v
return key if isinstance(key, str) else f"*{tys[dtype_str]}"
def _make_signature(self, sig_key):
signature = ",".join([self._type_of(k) for i, k in enumerate(sig_key)])
return signature
def _make_constants(self, constexpr_key):
constants = dict(zip(self.constexprs, constexpr_key))
return constants
def _call_hook(self, key, signature, device, constants, num_warps, num_stages, extern_libs, configs):
if JITFunction.cache_hook is None:
return False
name = self.fn.__name__
module = self.fn.__module__
arg_reprs = ', '.join([f'{name}: {ty}' for name, ty in zip(self.arg_names, key[1])])
repr = f"{name}[num_warps={num_warps}, num_stages={num_stages}]({arg_reprs})"
key = str(key)
class LegacyCompiler:
def __init__(self, module, name):
self.module = module
self.name = name
pass
kwargs = dict(signature=signature, device=device, constants=constants,
num_warps=num_warps, num_stages=num_stages, extern_libs=extern_libs,
configs=configs)
return JITFunction.cache_hook(key=key, repr=repr, fn=LegacyCompiler(module, name), compile={"key": key, **kwargs}, is_manual_warmup=False, already_compiled=False)
def _get_arg_specialization_key(self, arg) -> str:
arg_annotation = self.__annotations__.get(arg, '')
if arg_annotation == '':
return f'({arg}.data_ptr() % {JITFunction.divisibility} == 0) if hasattr({arg}, "data_ptr") \
else ({arg} % {JITFunction.divisibility} == 0, {arg} == 1) if isinstance({arg}, int) \
else (False,)'
elif 'Tensor' in arg_annotation:
return f'({arg}.data_ptr() % {JITFunction.divisibility} == 0)'
elif arg_annotation == 'int':
return f'({arg} % {JITFunction.divisibility} == 0, {arg} == 1)'
else:
return '(False,)'
def _get_arg_sig_key(self, arg) -> str:
arg_annotation = self.__annotations__.get(arg, '')
if 'Tensor' in arg_annotation:
return f'{arg}.dtype'
elif arg_annotation == 'bool':
return "i1"
elif arg_annotation == 'float':
return 'fp32'
else:
return f'_key_of({arg})'
def _make_launcher(self):
regular_args = [f'{arg}' for i, arg in enumerate(self.arg_names) if i not in self.constexprs]
constexpr_args = [f'{arg}' for i, arg in enumerate(self.arg_names) if i in self.constexprs]
args = ', '.join(regular_args)
# cache key for regular argument type
sig_keys = ', '.join([self._get_arg_sig_key(arg) for arg in regular_args])
# cache key for constexpr argument values
constexpr_keys = ', '.join(constexpr_args)
# cache key for argument specialization
specializations = []
for i, arg in enumerate(regular_args):
if i in self.do_not_specialize:
continue
specializations += [self._get_arg_specialization_key(arg)]
spec_keys = ', '.join(specializations)
grid_args = ','.join([f'"{arg}": {arg}' for arg in self.arg_names])
src = f"""
def {self.fn.__name__}({', '.join(self.arg_names)}, grid, num_warps=4, num_stages=3, enable_sme=True, extern_libs=None, stream=None, warmup=False, device=None):
sig_key = {sig_keys},
constexpr_key = {f'{constexpr_keys},' if len(constexpr_keys) > 0 else ()}
spec_key = {f'{spec_keys},' if len(spec_keys) > 0 else ()}
use_sme = get_corex_sme({{{grid_args}}}, self.constexprs, enable_sme)
key = (version_key, sig_key, constexpr_key, spec_key, num_warps, num_stages, self.debug, use_sme)
if not extern_libs is None:
key = (key, tuple(extern_libs.items()))
assert num_warps > 0 and (num_warps & (num_warps - 1)) == 0, "num_warps must be a power of 2"
if callable(grid):
grid = grid({{{grid_args}}})
grid_size = len(grid)
grid_0 = grid[0]
grid_1 = grid[1] if grid_size > 1 else 1
grid_2 = grid[2] if grid_size > 2 else 1
if device is None:
device = get_current_device()
set_current_device(device)
if stream is None and not warmup:
stream = get_cuda_stream(device)
bin = cache[device].get(key, None)
if bin is not None:
if not warmup:
bin.c_wrapper(grid_0, grid_1, grid_2, bin.num_warps, bin.shared, stream, bin.cu_function, triton.compiler.CompiledKernel.launch_enter_hook, triton.compiler.CompiledKernel.launch_exit_hook, bin, {args})
return bin
# kernel not cached -- compile
else:
# build dict of constant values
args = [{args}]
all_args = {', '.join([f'{arg}' for arg in self.arg_names])},
configs = self._get_config(*all_args),
constants = self._make_constants(constexpr_key)
constants.update({{i: None for i, arg in enumerate(all_args) if arg is None}})
constants.update({{i: 1 for i in configs[0].equal_to_1}})
# build kernel signature -- doesn't include specialized arguments
signature = {{ i: self._type_of(_key_of(arg)) for i, arg in enumerate(all_args) if i not in self.constexprs }}
# build stub signature -- includes arguments that are specialized
for i, arg in constants.items():
if callable(arg):
raise TypeError(f"Callable constexpr at index {{i}} is not supported")
if not self._call_hook(key, signature, device, constants, num_warps, num_stages, extern_libs, configs):
bin = triton.compile(self, signature=signature, device=device, constants=constants, num_warps=num_warps, num_stages=num_stages, extern_libs=extern_libs, configs=configs, debug=self.debug, use_sme=use_sme)
if not warmup:
bin.c_wrapper(grid_0, grid_1, grid_2, bin.num_warps, bin.shared, stream, bin.cu_function, triton.compiler.CompiledKernel.launch_enter_hook, triton.compiler.CompiledKernel.launch_exit_hook, bin, *args)
self.cache[device][key] = bin
return bin
return None
"""
scope = {"version_key": version_key(), "get_cuda_stream": get_cuda_stream,
"self": self, "_spec_of": self._spec_of, "_key_of": self._key_of,
"cache": self.cache, "triton": triton,
"get_current_device": get_current_device,
"set_current_device": set_current_device,
"get_corex_sme": get_corex_sme}
exec(src, scope)
return scope[self.fn.__name__]
def __init__(self, fn, version=None, do_not_specialize=None, debug=None, noinline=None):
self.fn = fn
self.module = fn.__module__
self.version = version
# function signature information
signature = inspect.signature(fn)
self.arg_names = [v.name for v in signature.parameters.values()]
self.has_defaults = any(v.default != inspect._empty for v in signature.parameters.values())
# specialization hints
self.do_not_specialize = [] if do_not_specialize is None else do_not_specialize
self.do_not_specialize = {self.arg_names.index(arg) if isinstance(arg, str) else arg for arg in self.do_not_specialize}
# function source code (without decorators)
self.src = textwrap.dedent(inspect.getsource(fn))
self.src = self.src[self.src.find("def"):]
# cache of just-in-time compiled kernels
self.cache = defaultdict(dict)
self.hash = None
# JITFunction can be instantiated as kernel
# when called with a grid using __getitem__
self.kernel_decorators = []
self.kernel = None
self.debug = True if os.environ.get("TRITON_DEBUG", "0") == "1" else debug
self.noinline = noinline
# annotations
normalize_ty = lambda ty: ty.__name__ if isinstance(ty, type) else ty
self.__annotations__ = {name: normalize_ty(ty) for name, ty in fn.__annotations__.items()}
# index of constexprs
self.constexprs = [self.arg_names.index(name) for name, ty in self.__annotations__.items() if 'constexpr' in ty]
# launcher
self.run = self._make_launcher()
# re-use docs of wrapped function
self.__doc__ = fn.__doc__
self.__name__ = fn.__name__
self.__globals__ = fn.__globals__
self.__module__ = fn.__module__
@property
def cache_key(self):
# TODO : hash should be attribute of `self`
if self.hash is None:
dependencies_finder = DependenciesFinder(globals=self.__globals__, src=self.src)
dependencies_finder.visit(self.parse())
self.hash = dependencies_finder.ret + version_key()
return self.hash
def warmup(self, *args, **kwargs):
return self.run(*map(MockTensor.wrap_dtype, args), **kwargs, warmup=True)
# we do not parse `src` in the constructor because
# the user might want to monkey-patch self.src dynamically.
# Our unit tests do this, for example.
def parse(self):
tree = ast.parse(self.src)
assert isinstance(tree, ast.Module)
assert len(tree.body) == 1
assert isinstance(tree.body[0], ast.FunctionDef)
return tree
def __call__(self, *args, **kwargs):
raise RuntimeError("Cannot call @triton.jit'd outside of the scope of a kernel")
def __setattr__(self, name, value):
# - when kernel decorators change, cached kernel
# needs to be cleared
if name == 'kernel_decorators':
self.kernel = None
super(JITFunction, self).__setattr__(name, value)
# - when `.src` attribute is set, cache path needs
# to be reinitialized
if name == 'src':
self.hash = None
def __repr__(self):
return f"JITFunction({self.module}:{self.fn.__name__})"
# -----------------------------------------------------------------------------
# `jit` decorator
# -----------------------------------------------------------------------------
@overload
def jit(fn: T) -> JITFunction[T]:
...
@overload
def jit(
*,
version=None,
do_not_specialize: Optional[Iterable[int]] = None,
debug: Optional[bool] = None,
noinline: Optional[bool] = None,
) -> Callable[[T], JITFunction[T]]:
...
def jit(
fn: Optional[T] = None,
*,
version=None,
do_not_specialize: Optional[Iterable[int]] = None,
debug: Optional[bool] = None,
noinline: Optional[bool] = None,
interpret: Optional[bool] = None,
) -> Union[JITFunction[T], Callable[[T], JITFunction[T]]]:
"""
Decorator for JIT-compiling a function using the Triton compiler.
:note: When a jit'd function is called, arguments are
implicitly converted to pointers if they have a :code:`.data_ptr()` method
and a `.dtype` attribute.
:note: This function will be compiled and run on the GPU. It will only have access to:
* python primitives,
* builtins within the triton package,
* arguments to this function,
* other jit'd functions
:param fn: the function to be jit-compiled
:type fn: Callable
"""
def decorator(fn: T) -> JITFunction[T]:
assert callable(fn)
if interpret:
from ..debugger.debugger import GridSelector
return GridSelector(fn)
else:
return JITFunction(
fn,
version=version,
do_not_specialize=do_not_specialize,
debug=debug,
noinline=noinline,
)
if fn is not None:
return decorator(fn)
else:
return decorator
# -----------------------------------------------------------------------------
# Utilities for mocking tensors
# -----------------------------------------------------------------------------
class MockTensor:
"""
Can be used in place of real tensors when calling:
kernel.warmup(MockTensor(torch.float32), ...)
"""
@staticmethod
def wrap_dtype(arg):
if arg.__class__.__name__ == "dtype" and\
arg.__module__ == "torch":
return MockTensor(arg)
return arg
def __init__(self, dtype):
self.dtype = dtype
@staticmethod
def data_ptr():
return 0 # optimistically assumes multiple of 16
class TensorWrapper:
def __init__(self, base, dtype):
self.dtype = dtype
self.base = base
self.is_cuda = base.is_cuda
self.device = base.device
def data_ptr(self):
return self.base.data_ptr()
def __str__(self) -> str:
return f'TensorWrapper[{self.dtype}]({self.base})'
def reinterpret(tensor, dtype):
if isinstance(tensor, TensorWrapper):
if dtype == tensor.base.dtype:
# Reinterpreting to the original interpretation; return the base.
return tensor.base
else:
# Reinterpreting a wrapped tensor to a different type.
return TensorWrapper(tensor.base, dtype)
elif hasattr(tensor, "data_ptr"):
# A new wrapper is needed around an unwrapped tensor.
return TensorWrapper(tensor, dtype)
else:
raise TypeError(f'Cannot reinterpret a {type(tensor)}.')

424
pkgs/triton/testing.py Normal file
View File

@@ -0,0 +1,424 @@
import functools
import os
import subprocess
import sys
from contextlib import contextmanager
import triton._C.libtriton.triton as _triton
from triton.common.build import is_corex
def nvsmi(attrs):
attrs = ','.join(attrs)
cmd = ['nvidia-smi', '-i', '0', '--query-gpu=' + attrs, '--format=csv,noheader,nounits']
out = subprocess.check_output(cmd)
ret = out.decode(sys.stdout.encoding).split(',')
ret = [int(x) for x in ret]
return ret
def do_bench(fn, warmup=25, rep=100, grad_to_none=None,
quantiles=None,
fast_flush=True,
return_mode="mean"):
assert return_mode in ["min", "max", "mean", "median"]
import torch
"""
Benchmark the runtime of the provided function. By default, return the median runtime of :code:`fn` along with
the 20-th and 80-th performance percentile.
:param fn: Function to benchmark
:type fn: Callable
:param warmup: Warmup time (in ms)
:type warmup: int
:param rep: Repetition time (in ms)
:type rep: int
:param grad_to_none: Reset the gradient of the provided tensor to None
:type grad_to_none: torch.tensor, optional
:param quantiles: Performance percentile to return in addition to the median.
:type quantiles: list[float]
:param fast_flush: Use faster kernel to flush L2 between measurements
:type fast_flush: bool
"""
fn()
torch.cuda.synchronize()
# We maintain a buffer of 256 MB that we clear
# before each kernel call to make sure that the L2
# doesn't contain any input data before the run
if fast_flush:
cache = torch.empty(int(256e6 // 4), dtype=torch.int, device='cuda')
else:
cache = torch.empty(int(256e6), dtype=torch.int8, device='cuda')
# Estimate the runtime of the function
start_event = torch.cuda.Event(enable_timing=True)
end_event = torch.cuda.Event(enable_timing=True)
start_event.record()
for _ in range(5):
cache.zero_()
fn()
end_event.record()
torch.cuda.synchronize()
estimate_ms = start_event.elapsed_time(end_event) / 5
# compute number of warmup and repeat
n_warmup = max(1, int(warmup / estimate_ms))
n_repeat = max(1, int(rep / estimate_ms))
start_event = [torch.cuda.Event(enable_timing=True) for i in range(n_repeat)]
end_event = [torch.cuda.Event(enable_timing=True) for i in range(n_repeat)]
# Warm-up
for _ in range(n_warmup):
fn()
# Benchmark
for i in range(n_repeat):
# we don't want `fn` to accumulate gradient values
# if it contains a backward pass. So we clear the
# provided gradients
if grad_to_none is not None:
for x in grad_to_none:
x.grad = None
# we clear the L2 cache before each run
cache.zero_()
# record time of `fn`
start_event[i].record()
fn()
end_event[i].record()
# Record clocks
torch.cuda.synchronize()
times = torch.tensor([s.elapsed_time(e) for s, e in zip(start_event, end_event)])
if quantiles is not None:
ret = torch.quantile(times, torch.tensor(quantiles)).tolist()
if len(ret) == 1:
ret = ret[0]
return ret
return getattr(torch, return_mode)(times).item()
def assert_close(x, y, atol=None, rtol=None, err_msg=''):
import numpy as np
import torch
# canonicalize arguments to be tensors
if not isinstance(x, torch.Tensor):
x = torch.tensor(x)
if not isinstance(y, torch.Tensor):
y = torch.tensor(y)
# absolute tolerance
if atol is None:
atol = 1e-2
atol = atol(x.dtype) if callable(atol) else atol
# relative tolerance hook
if rtol is None:
rtol = 0.
rtol = rtol(x.dtype) if callable(rtol) else rtol
# we use numpy instead of pytorch
# as it seems more memory efficient
# pytorch tends to oom on large tensors
if isinstance(x, torch.Tensor):
if x.dtype == torch.bfloat16:
x = x.float()
x = x.cpu().detach().numpy()
if isinstance(y, torch.Tensor):
if y.dtype == torch.bfloat16:
y = y.float()
y = y.cpu().detach().numpy()
# we handle size==1 case separately as we can
# provide better error message there
if x.size > 1 or y.size > 1:
np.testing.assert_allclose(x, y, atol=atol, rtol=rtol, equal_nan=True)
return
if not np.allclose(x, y, atol=atol, rtol=rtol):
raise AssertionError(f'{err_msg} {x} is not close to {y} (atol={atol}, rtol={rtol})')
class Benchmark:
"""
This class is used by the :code:`perf_report` function to generate line plots with a concise API.
"""
def __init__(
self,
x_names,
x_vals,
line_arg,
line_vals,
line_names,
plot_name,
args,
xlabel='',
ylabel='',
x_log=False,
y_log=False,
color=None,
styles=None,
):
"""
Constructor
:param x_names: Name of the arguments that should appear on the x axis of the plot. If the list contains more than one element, all the arguments are assumed to have the same value.
:type x_names: List[str]
:param x_vals: List of values to use for the arguments in :code:`x_names`.
:type x_vals: List[Any]
:param line_arg: Argument name for which different values correspond to different lines in the plot.
:type line_arg: str
:param line_vals: List of values to use for the arguments in :code:`line_arg`.
:type line_vals: List[str]
:param line_names: Label names for the different lines.
:type line_names: List[str]
:param plot_name: Name of the plot.
:type plot_name: str
:param args: List of arguments to remain fixed throughout the benchmark.
:type args: List[str]
:param xlabel: Label for the x axis of the plot.
:type xlabel: str, optional
:param ylabel: Label for the y axis of the plot.
:type ylabel: str, optional
:param x_log: Whether the x axis should be log scale.
:type x_log: bool, optional
:param y_log: Whether the y axis should be log scale.
:type y_log: bool, optional
"""
self.x_names = x_names
self.x_vals = x_vals
self.x_log = x_log
self.line_arg = line_arg
self.line_vals = line_vals
self.line_names = line_names
self.y_log = y_log
self.styles = styles
# plot info
self.xlabel = xlabel
self.ylabel = ylabel
self.plot_name = plot_name
self.args = args
class Mark:
def __init__(self, fn, benchmarks):
self.fn = fn
self.benchmarks = benchmarks
def _run(self, bench, save_path, show_plots, print_data):
import os
import matplotlib.pyplot as plt
import pandas as pd
y_mean = bench.line_names
y_min = [f'{x}-min' for x in bench.line_names]
y_max = [f'{x}-max' for x in bench.line_names]
df = pd.DataFrame(columns=[bench.x_names[0]] + y_mean + y_min + y_max)
for x in bench.x_vals:
x_args = {x_name: x for x_name in bench.x_names}
row_mean, row_min, row_max = [], [], []
for y in bench.line_vals:
ret = self.fn(**x_args, **{bench.line_arg: y}, **bench.args)
try:
y_mean, y_min, y_max = ret
except TypeError:
y_mean, y_min, y_max = ret, None, None
row_mean += [y_mean]
row_min += [y_min]
row_max += [y_max]
df.loc[len(df)] = [x] + row_mean + row_min + row_max
if bench.plot_name:
plt.figure()
ax = plt.subplot()
x = bench.x_names[0]
for i, y in enumerate(bench.line_names):
y_min, y_max = df[y + '-min'], df[y + '-max']
col = bench.styles[i][0] if bench.styles else None
sty = bench.styles[i][1] if bench.styles else None
ax.plot(df[x], df[y], label=y, color=col, ls=sty)
if y_min is not None and y_max is not None:
ax.fill_between(df[x], y_min, y_max, alpha=0.15, color=col)
ax.legend()
xlabel = bench.xlabel if bench.xlabel else " = ".join(bench.x_names)
ax.set_xlabel(xlabel)
ax.set_ylabel(bench.ylabel)
# ax.set_title(bench.plot_name)
ax.set_xscale("log" if bench.x_log else "linear")
ax.set_yscale("log" if bench.y_log else "linear")
if show_plots:
plt.show()
if save_path:
plt.savefig(os.path.join(save_path, f"{bench.plot_name}.png"))
df = df[[bench.x_names[0]] + bench.line_names]
if print_data:
print(bench.plot_name + ':')
print(df)
if save_path:
df.to_csv(os.path.join(save_path, f"{bench.plot_name}.csv"), float_format='%.1f', index=False)
def run(self, show_plots=False, print_data=False, save_path=''):
has_single_bench = isinstance(self.benchmarks, Benchmark)
benchmarks = [self.benchmarks] if has_single_bench else self.benchmarks
if save_path:
html = open(os.path.join(save_path, "results.html"), "w")
html.write("<html><body>\n")
for bench in benchmarks:
self._run(bench, save_path, show_plots, print_data)
if save_path:
html.write(f"<image src=\"{bench.plot_name}.png\"/>\n")
if save_path:
html.write("</body></html>\n")
def perf_report(benchmarks):
"""
Mark a function for benchmarking. The benchmark can then be executed by using the :code:`.run` method on the return value.
:param benchmarks: Benchmarking configurations.
:type benchmarks: List of :class:`Benchmark`
"""
wrapper = lambda fn: Mark(fn, benchmarks)
return wrapper
def get_dram_gbps(backend=None, device=None):
''' return DRAM bandwidth in GB/s '''
import torch
from .runtime import driver
if not backend:
backend = _triton.runtime.backend.CUDA
if not device:
device = torch.cuda.current_device()
mem_clock_khz = driver.utils.get_device_properties(device)["mem_clock_rate"] # in kHz
bus_width = driver.utils.get_device_properties(device)["mem_bus_width"]
bw_gbps = mem_clock_khz * bus_width * 2 / 1e6 / 8 # In GB/s
return bw_gbps
def get_max_tensorcore_tflops(dtype, backend=None, device=None, clock_rate=None):
import torch
from .runtime import driver
if not backend:
backend = _triton.runtime.backend.CUDA
if not device:
device = torch.cuda.current_device()
num_subcores = driver.utils.get_device_properties(device)["multiprocessor_count"] * 4
if not clock_rate:
clock_rate = driver.utils.get_device_properties(device)["sm_clock_rate"] # in kHz
capability = torch.cuda.get_device_capability(device)
if capability[0] < 8:
assert dtype == torch.float16 or is_corex()
ops_per_sub_core = 256 # 2 4x4x4 Tensor Cores
else:
if dtype == torch.float32:
ops_per_sub_core = 256
elif dtype in [torch.float16, torch.bfloat16]:
ops_per_sub_core = 512
elif dtype == torch.int8:
ops_per_sub_core = 1024
else:
raise RuntimeError("dtype not supported")
tflops = num_subcores * clock_rate * ops_per_sub_core * 1e-9
return tflops
# create decorator that wraps test function into
# a cuda-memcheck system call
def cuda_memcheck(**target_kwargs):
def decorator(test_fn):
@functools.wraps(test_fn)
def wrapper(*args, **kwargs):
import psutil
ppid_name = psutil.Process(os.getppid()).name()
run_cuda_memcheck = target_kwargs.items() <= kwargs.items()
if run_cuda_memcheck and ppid_name != "cuda-memcheck":
path = os.path.realpath(test_fn.__globals__["__file__"])
# get path of current file
env = {"PATH": os.environ["PATH"], "PYTORCH_NO_CUDA_MEMORY_CACHING": "1"}
assert 'request' in kwargs, "memcheck'ed test must have a (possibly unused) `request` fixture"
test_id = kwargs['request'].node.callspec.id
cmd = f"{path}::{test_fn.__name__}[{test_id}]"
out = subprocess.run(["cuda-memcheck", "pytest", "-vs", cmd], capture_output=True, env=env)
assert out.returncode == 0, "cuda-memcheck returned an error: bounds checking failed"
assert "ERROR SUMMARY: 0 errors" in str(out.stdout)
else:
test_fn(*args, **kwargs)
return wrapper
return decorator
def nvsmi_attr(attrs):
attrs = ",".join(attrs)
cmd = [
"nvidia-smi",
"-i",
"0",
"--query-gpu=" + attrs,
"--format=csv,noheader,nounits",
]
out = subprocess.check_output(cmd)
ret = out.decode(sys.stdout.encoding).split(",")
ret = [int(x) for x in ret]
return ret
@contextmanager
def set_gpu_clock(ref_sm_clock=1350, ref_mem_clock=1215):
try:
subprocess.check_output(["nvidia-smi", "-i", "0", "-pm", "1"])
subprocess.check_output(
[
"nvidia-smi",
"-i",
"0",
f"--lock-gpu-clocks={ref_sm_clock},{ref_sm_clock}",
]
)
subprocess.check_output(
[
"nvidia-smi",
"-i",
"0",
f"--lock-memory-clocks={ref_mem_clock},{ref_mem_clock}",
]
)
cur_sm_clock = nvsmi_attr(["clocks.current.sm"])[0]
cur_mem_clock = nvsmi_attr(["clocks.current.memory"])[0]
assert abs(cur_sm_clock - ref_sm_clock) < 10, f"GPU SMs must run at {ref_sm_clock} MHz"
assert abs(cur_mem_clock - ref_mem_clock) < 10, f"GPU SMs must run at {ref_mem_clock} MHz"
tflops = 1e-6 * 2 * 108 * 4 * 256 * ref_sm_clock
gbps = 640 * 2 * ref_mem_clock * 1e-3
yield tflops, gbps
finally:
subprocess.check_output(["nvidia-smi", "-i", "0", "-pm", "0"])
subprocess.check_output(["nvidia-smi", "-i", "0", "-rgc"])
subprocess.check_output(["nvidia-smi", "-i", "0", "-rmc"])
def get_max_simd_tflops(dtype, backend=None, device=None):
import torch
from .runtime import driver
if not backend:
backend = _triton.runtime.backend.CUDA
if not device:
device = torch.cuda.current_device()
num_subcores = driver.utils.get_device_properties(device)["multiprocessor_count"] * 4
clock_rate = driver.utils.get_device_properties(device)["sm_clock_rate"] # in kHz
capability = torch.cuda.get_device_capability()
if capability[0] < 8:
if dtype == torch.float32:
ops_per_sub_core = 32 # 2*16
elif dtype == torch.float16:
ops_per_sub_core = 64
else:
raise RuntimeError("dtype not supported")
else:
if dtype == torch.float32:
ops_per_sub_core = 32
elif dtype in [torch.float16, torch.bfloat16]:
ops_per_sub_core = 64
else:
raise RuntimeError("dtype not supported")
tflops = num_subcores * clock_rate * ops_per_sub_core * 1e-9
return tflops

BIN
pkgs/triton/third_party/cuda/bin/ptxas vendored Executable file

Binary file not shown.

19348
pkgs/triton/third_party/cuda/include/cuda.h vendored Executable file

File diff suppressed because it is too large Load Diff

Binary file not shown.

Binary file not shown.

Binary file not shown.

View File

@@ -0,0 +1,24 @@
diff --git a/CMakeLists.txt b/CMakeLists.txt
index b65f1b5..19cc5a9 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -62,12 +62,13 @@ include(OCL)
set(AMDGCN_LIB_LIST)
set(AMDGCN_DEP_LIST)
add_subdirectory(irif)
-add_subdirectory(oclc)
-add_subdirectory(ocml)
-add_subdirectory(ockl)
-add_subdirectory(opencl)
-add_subdirectory(hip)
-add_subdirectory(asanrtl)
+#add_subdirectory(oclc)
+#add_subdirectory(ocml)
+#add_subdirectory(ockl)
+#add_subdirectory(opencl)
+#add_subdirectory(hip)
+#add_subdirectory(asanrtl)
+add_subdirectory(cuda2gcn)
enable_testing()
add_subdirectory(test/compile)

Binary file not shown.

View File

@@ -0,0 +1,17 @@
#!/bin/bash
set -e
pushd .
git clone https://github.com/dfukalov/ROCm-Device-Libs.git
cd ROCm-Device-Libs
git apply ../cuda2gcn.patch
mkdir build
cd build
cmake .. -DCMAKE_PREFIX_PATH=$HOME/.triton/llvm/clang+llvm-14.0.0-x86_64-linux-gnu-ubuntu-18.04
make -j4
popd
cp ROCm-Device-Libs/build/amdgcn/bitcode/cuda2gcn.bc .
rm -rf ROCm-Device-Libs

Some files were not shown because too many files have changed in this diff Show More