First commit
This commit is contained in:
19
Dockerfile
Normal file
19
Dockerfile
Normal file
@@ -0,0 +1,19 @@
|
||||
# Base image: Iluvatar BI100 vLLM inference image (Corex 3.2.1 stack, py3.10).
FROM zibo.harbor.iluvatar.com.cn:30000/saas/bi100-3.2.1-x86-ubuntu20.04-py3.10-poc-llm-infer:20250731115755

# Install the upstream triton wheel so pip metadata/dependencies are satisfied;
# the actual package files are overwritten by the Corex builds below.
RUN pip install --no-cache-dir triton==2.1.0

# Overlay the Corex-built triton and xformers packages (and their dist-info)
# into the Corex dist-packages directory so they shadow the stock wheels.
COPY pkgs/triton /usr/local/corex/lib64/python3/dist-packages/triton

COPY pkgs/triton-2.1.0+corex.4.1.2.dist-info /usr/local/corex/lib64/python3/dist-packages/triton-2.1.0+corex.4.1.2.dist-info

COPY pkgs/xformers-0.0.22+corex.4.1.2.dist-info /usr/local/corex/lib64/python3/dist-packages/xformers-0.0.22+corex.4.1.2.dist-info

COPY pkgs/xformers /usr/local/corex/lib64/python3/dist-packages/xformers

# Patch the installed vLLM in place with the local attention implementations.
COPY paged_attn.py /usr/local/lib/python3.10/site-packages/vllm/attention/ops/paged_attn.py

COPY __init__.py /usr/local/lib/python3.10/site-packages/vllm/triton_utils/__init__.py

COPY prefix_prefill.py /usr/local/lib/python3.10/site-packages/vllm/attention/ops/prefix_prefill.py

# Service working directory and entrypoint script.
RUN mkdir /workspace

WORKDIR /workspace/

COPY ./launch_service /workspace/launch_service

ENTRYPOINT ["./launch_service"]
|
||||
9
__init__.py
Normal file
9
__init__.py
Normal file
@@ -0,0 +1,9 @@
|
||||
"""Patched ``vllm.triton_utils`` package init for the Corex port."""

from vllm.triton_utils.importing import HAS_TRITON

__all__ = ["HAS_TRITON"]

# from vllm.triton_utils.custom_cache_manager import (
#     maybe_set_triton_cache_manager)
# from vllm.triton_utils.libentry import libentry

# BUGFIX: the line below must stay commented out together with the imports
# above. Extending __all__ with names that are never defined makes
# ``from vllm.triton_utils import *`` raise AttributeError.
# __all__ += ["maybe_set_triton_cache_manager", "libentry"]
|
||||
542
attention.py
Normal file
542
attention.py
Normal file
@@ -0,0 +1,542 @@
|
||||
"""Multi-head attention."""
|
||||
import os
|
||||
enable_infer_paged_attn = os.getenv("ENABLE_INFER_PAGED_ATTN",None)
|
||||
from typing import List, Optional
|
||||
|
||||
import importlib
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from ixformer.contrib.xformers import ops as xops
|
||||
from ixformer.contrib.xformers.ops.fmha.attn_bias import (BlockDiagonalCausalMask,
|
||||
LowerTriangularMaskWithTensorBias)
|
||||
|
||||
from vllm._C import ops
|
||||
from vllm._C import cache_ops
|
||||
from vllm.model_executor.input_metadata import InputMetadata
|
||||
from vllm.model_executor.layers.triton_kernel.prefix_prefill import (
|
||||
context_attention_fwd)
|
||||
from vllm.utils import is_hip
|
||||
|
||||
# _SUPPORTED_HEAD_SIZES = [64, 80, 96, 112, 128, 256]
|
||||
# # Should be the same as PARTITION_SIZE in `paged_attention_v2_launcher`.
|
||||
# _PARTITION_SIZE = 512
|
||||
_SUPPORTED_HEAD_SIZES = [64, 128, 256]
|
||||
# Should be the same as PARTITION_SIZE in `paged_attention_v2_launcher`.
|
||||
_PARTITION_SIZE = 256
|
||||
|
||||
|
||||
class PagedAttention(nn.Module):
    """MHA/MQA/GQA layer with PagedAttention.

    This class takes query, key, and value tensors as input. The input tensors
    can either contain prompt tokens or generation tokens.
    The class does the following:

    1. Reshape and store the input key and value tensors in the KV cache.
    2. Perform (multi-head/multi-query/grouped-query) attention using either
        xformers or the PagedAttention custom op.
    3. Return the output tensor.
    """

    def __init__(
        self,
        num_heads: int,
        head_size: int,
        scale: float,
        num_kv_heads: Optional[int] = None,
        alibi_slopes: Optional[List[float]] = None,
        sliding_window: Optional[int] = None,
    ) -> None:
        super().__init__()
        self.num_heads = num_heads
        self.head_size = head_size
        self.scale = float(scale)
        # MHA when num_kv_heads is None; MQA/GQA otherwise.
        self.num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads
        self.sliding_window = sliding_window
        if alibi_slopes is not None:
            alibi_slopes = torch.tensor(alibi_slopes, dtype=torch.float32)
        # Registered as a non-persistent buffer so it follows .to()/device
        # moves but is not saved in state_dict. May be None (no ALiBi).
        self.register_buffer("alibi_slopes", alibi_slopes, persistent=False)

        assert self.num_heads % self.num_kv_heads == 0
        self.num_queries_per_kv = self.num_heads // self.num_kv_heads

        if self.head_size not in _SUPPORTED_HEAD_SIZES:
            raise ValueError(f"head_size ({self.head_size}) is not supported. "
                             f"Supported head sizes: {_SUPPORTED_HEAD_SIZES}.")

        self.use_ref_attention = self.check_use_ref_attention()

        # TODO align vllm do not need those
        # NOTE(review): head_mapping maps each query head to its KV head
        # (e.g. [0,0,1,1] for 4 heads / 2 kv heads); the Corex paged-attention
        # kernels take this tensor where upstream vLLM takes num_kv_heads.
        self.attn_op = xops.fmha.flash.FwOp()
        head_mapping = torch.repeat_interleave(
            torch.arange(self.num_kv_heads, dtype=torch.int32),
            self.num_queries_per_kv)
        self.register_buffer("head_mapping", head_mapping, persistent=False)

    def check_use_ref_attention(self) -> bool:
        """Return True only on ROCm builds without flash_attn installed."""
        if not is_hip():
            return False
        # For ROCm, check whether flash attention is installed or not.
        # if not, use_ref_attention needs to be True
        return importlib.util.find_spec("flash_attn") is None

    def ref_masked_attention(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
    ) -> torch.Tensor:
        """Reference (pure PyTorch) causal attention fallback.

        Builds an explicit upper-triangular mask filled with the dtype's
        minimum and runs softmax(QK^T * scale + mask) @ V via einsum.
        """
        query = query.view(-1, self.num_heads, self.head_size)
        key = key.view(-1, self.num_kv_heads, self.head_size)
        value = value.view(-1, self.num_kv_heads, self.head_size)

        seq_len, _, _ = query.shape
        attn_mask = torch.triu(torch.ones(seq_len,
                                          seq_len,
                                          dtype=query.dtype,
                                          device=query.device),
                               diagonal=1)
        attn_mask = attn_mask * torch.finfo(query.dtype).min

        attn_weights = self.scale * torch.einsum("qhd,khd->hqk", query,
                                                 key).float()
        attn_weights = attn_weights + attn_mask.float()
        attn_weights = torch.softmax(attn_weights, dim=-1).to(value.dtype)
        out = torch.einsum("hqk,khd->qhd", attn_weights, value)
        return out

    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        key_cache: Optional[torch.Tensor],
        value_cache: Optional[torch.Tensor],
        input_metadata: InputMetadata,
    ) -> torch.Tensor:
        """PagedAttention forward pass.

        Args:
            query: shape = [num_tokens, num_heads * head_size]
            key: shape = [num_tokens, num_kv_heads * head_size]
            value: shape = [num_tokens, num_kv_heads * head_size]
            key_cache: shape = [num_blocks, num_kv_heads, head_size/x,
                block_size, x]
            value_cache: shape = [num_blocks, num_kv_heads, head_size,
                block_size]
            input_metadata: metadata for the inputs.
            cache_event: event to wait for the cache operations to finish.
        Returns:
            shape = [batch_size, seq_len, num_heads * head_size]
        """
        num_tokens, hidden_size = query.shape
        # Reshape the query, key, and value tensors.
        query = query.view(-1, self.num_heads, self.head_size)
        key = key.view(-1, self.num_kv_heads, self.head_size)
        value = value.view(-1, self.num_kv_heads, self.head_size)
        slot_mapping = input_metadata.slot_mapping

        # Reshape the keys and values and store them in the cache.
        # If key_cache and value_cache are not provided, the new key and value
        # vectors will not be cached. This happens during the initial memory
        # profiling run.
        if key_cache is not None and value_cache is not None:
            cache_ops.reshape_and_cache(
                key,
                value,
                key_cache,
                value_cache,
                slot_mapping,
            )

        if input_metadata.is_prompt:
            # normal attention
            if (key_cache is None or value_cache is None
                    or input_metadata.block_tables.numel() == 0):
                # Build the (cached) attention bias once per iteration.
                if input_metadata.attn_bias is None:
                    if self.alibi_slopes is None:
                        attn_bias = BlockDiagonalCausalMask.from_seqlens(input_metadata.prompt_lens)
                        if self.sliding_window is not None:
                            attn_bias = attn_bias.make_local_attention(
                                self.sliding_window)
                        input_metadata.attn_bias = attn_bias
                    else:
                        # NOTE(review): with ALiBi the causal mask is still a
                        # plain BlockDiagonalCausalMask; the slopes are passed
                        # directly to the fused kernel below (an ixformer
                        # extension) instead of being baked into the bias.
                        attn_bias = BlockDiagonalCausalMask.from_seqlens(input_metadata.prompt_lens)
                        input_metadata.attn_bias = attn_bias

                if self.use_ref_attention:
                    output = self.ref_masked_attention(
                        query,
                        key,
                        value,
                    )
                    # Using view got RuntimeError: view size is not compatible with input tensor's size and stride
                    # (at least one dimension spans across two contiguous subspaces). Use reshape instead
                    return output.reshape(num_tokens, hidden_size)

                # TODO(woosuk): Too many view operations. Let's try to reduce
                # them in the future for code readability.
                query = query.unsqueeze(0)
                key = key.unsqueeze(0)
                value = value.unsqueeze(0)

                # NOTE(review): alibi_slopes is not a standard xformers
                # argument; this relies on the ixformer fork's signature.
                out = xops.memory_efficient_attention_forward(
                    query,
                    key,
                    value,
                    attn_bias=input_metadata.attn_bias,
                    p=0.0,
                    scale=self.scale,
                    op=self.attn_op,
                    alibi_slopes=self.alibi_slopes
                )
                output = out.view_as(query)
            else:
                # prefix-enabled attention
                output = torch.empty_like(query)
                context_attention_fwd(
                    query,
                    key,
                    value,
                    output,
                    key_cache,
                    value_cache,
                    input_metadata.block_tables,  # [BS, max_block_per_request]
                    input_metadata.start_loc,
                    input_metadata.prompt_lens,
                    input_metadata.context_lens,
                    input_metadata.max_seq_len,
                    getattr(self, "alibi_slopes", None),
                )
        else:
            # Decoding run.
            output = _paged_attention(
                query,
                key_cache,
                value_cache,
                input_metadata,
                self.head_mapping,  # self.num_kv_heads
                self.scale,
                self.alibi_slopes,
            )

        # Reshape the output tensor.
        return output.view(num_tokens, hidden_size)
|
||||
# TODO align
# NOTE: a verbatim copy of the upstream vLLM PagedAttention.forward (the
# batch_size x seq_len variant, with MQA/GQA head expansion and
# _make_alibi_bias) used to be kept here inside a module-level string literal
# for reference. It was dead code — an unused string expression evaluated and
# discarded at import time — and has been removed. See the upstream vllm
# repository (vllm/model_executor/layers/attention.py) for the original.
|
||||
|
||||
|
||||
def _make_alibi_bias(
    alibi_slopes: torch.Tensor,
    num_kv_heads: int,
    batch_size: int,
    seq_len: int,
    dtype: torch.dtype,
) -> LowerTriangularMaskWithTensorBias:
    """Build a per-head ALiBi bias wrapped in a lower-triangular mask.

    The bias for position pair (i, j) is slope_h * (j - i), materialized as a
    [batch_size, num_heads, seq_len, seq_len] tensor (grouped per KV head when
    num_heads != num_kv_heads).
    """
    positions = torch.arange(seq_len, dtype=dtype)
    # NOTE(zhuohan): HF uses
    # `bias = bias[None, :].repeat(prompt_len, 1)`
    # here. We find that both biases give the same results, but
    # the bias below more accurately follows the original ALiBi
    # paper.
    relative = positions[None, :] - positions[:, None]

    # When using custom attention bias, xformers requires the bias to
    # be sliced from a tensor whose length is a multiple of 8.
    padded_len = -(-seq_len // 8) * 8
    num_heads = alibi_slopes.shape[0]
    storage = torch.empty(
        batch_size,
        num_heads,
        seq_len,
        padded_len,
        device=alibi_slopes.device,
        dtype=dtype,
    )
    bias = storage[..., :seq_len].copy_(relative)
    bias.mul_(alibi_slopes[:, None, None])
    if num_heads != num_kv_heads:
        # Group query heads under their KV head for MQA/GQA.
        bias = bias.unflatten(1, (num_kv_heads, num_heads // num_kv_heads))
    return LowerTriangularMaskWithTensorBias(bias)
|
||||
|
||||
|
||||
def _paged_attention(
    query: torch.Tensor,
    key_cache: torch.Tensor,
    value_cache: torch.Tensor,
    input_metadata: InputMetadata,
    head_mapping: torch.Tensor,  # num_kv_heads: int,
    scale: float,
    alibi_slopes: Optional[torch.Tensor],
    use_sqrt_alibi: bool = False  # NOTE(review): currently unused
) -> torch.Tensor:
    """Decode-phase attention over the paged KV cache.

    Dispatches to the Corex paged_attention_v1/v2 custom ops. V2 (partitioned
    reduction, partition size _PARTITION_SIZE) is used only when
    ENABLE_INFER_PAGED_ATTN is unset AND the key cache is 4-D; otherwise V1.
    Returns a tensor shaped like ``query``.
    """
    output = torch.empty_like(query)

    use_v2 = enable_infer_paged_attn is None and key_cache.dim() == 4
    if not use_v2:
        # NOTE(review): V1 reads block_size from value_cache.shape[3] while V2
        # reads shape[2] — presumably the two kernels expect different cache
        # layouts on this port. TODO confirm against the Corex ops.
        block_size = value_cache.shape[3]
        # Run PagedAttention V1.
        ops.paged_attention_v1(
            output,
            query,
            key_cache,
            value_cache,
            head_mapping,  # num_kv_heads
            scale,
            input_metadata.block_tables,
            input_metadata.context_lens,
            block_size,
            input_metadata.max_context_len,
            alibi_slopes,
            input_metadata.kv_cache_dtype,
        )
    else:
        # Run PagedAttention V2.
        block_size = value_cache.shape[2]
        num_seqs, num_heads, head_size = query.shape
        max_num_partitions = (
            (input_metadata.max_context_len + _PARTITION_SIZE - 1) //
            _PARTITION_SIZE)
        # Scratch buffers for the two-pass partitioned reduction.
        tmp_output = torch.empty(
            size=(num_seqs, num_heads, max_num_partitions, head_size),
            dtype=output.dtype,
            device=output.device,
        )
        exp_sums = torch.empty(
            size=(num_seqs, num_heads, max_num_partitions),
            dtype=torch.float32,
            device=output.device,
        )
        max_logits = torch.empty_like(exp_sums)
        ops.paged_attention_v2(
            output,
            exp_sums,
            max_logits,
            tmp_output,
            query,
            key_cache,
            value_cache,
            head_mapping,  # num_kv_heads
            scale,
            input_metadata.block_tables,
            input_metadata.context_lens,
            block_size,
            input_metadata.max_context_len,
            alibi_slopes,
            input_metadata.kv_cache_dtype,
        )
    return output
|
||||
|
||||
|
||||
# ↓ add for smoothquant
|
||||
class DequantPagedAttention(PagedAttention):
    """PagedAttention variant whose output is quantized to int8 (SmoothQuant).

    Runs the parent forward pass, then quantizes the fp output with
    ``ops.quant``. Returns ``(quant_out, scale)`` with per-token scales when
    ``use_per_token_quant`` is True, otherwise ``(quant_out,)`` using the
    static ``quant_scale``.
    """

    def __init__(
        self,
        num_heads: int,
        head_size: int,
        scale: float,
        num_kv_heads: Optional[int] = None,
        alibi_slopes: Optional[List[float]] = None,
        sliding_window: Optional[int] = None,
        quant_kv_cache: bool = False,
        kv_quant_params: torch.Tensor = None,
        quant_scale: float = 1.0,
        use_per_token_quant: bool = True,
    ) -> None:
        # NOTE(review): quant_kv_cache and kv_quant_params are accepted for
        # interface compatibility but are currently unused.
        super().__init__(num_heads,
                         head_size,
                         scale,
                         num_kv_heads,
                         alibi_slopes,
                         sliding_window)
        # BUGFIX: requires_grad must be set on the Parameter itself —
        # nn.Parameter defaults to requires_grad=True regardless of the flag
        # on the wrapped tensor, so the original left quant_scale trainable.
        self.register_parameter(
            "quant_scale",
            torch.nn.Parameter(
                torch.tensor(quant_scale, dtype=torch.float32),
                requires_grad=False))
        self.use_per_token_quant = use_per_token_quant

    def _apply(self, fn):
        super()._apply(fn)
        # Keep quant_scale on CPU even when the module moves to an
        # accelerator — presumably because it is read via .item() on the
        # host in forward(); TODO confirm.
        self.quant_scale.data = self.quant_scale.cpu()
        return self

    def to(self, *args, **kwargs):
        super().to(*args, **kwargs)
        self.quant_scale.data = self.quant_scale.to(*args, **kwargs)
        # Force float32 in case the .to() above changed the dtype.
        self.quant_scale.data = self.quant_scale.to(torch.float32)
        return self

    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        key_cache: Optional[torch.Tensor],
        value_cache: Optional[torch.Tensor],
        input_metadata: InputMetadata,
    ) -> torch.Tensor:
        # Run the regular (fp) paged attention first.
        out = super().forward(
            query,
            key,
            value,
            key_cache,
            value_cache,
            input_metadata,
        )
        quant_out = torch.empty_like(out, dtype=torch.int8)
        if self.use_per_token_quant:
            # One fp32 scale per token (rows = numel / last dim).
            scale = torch.empty(out.numel() // out.shape[-1],
                                dtype=torch.float32,
                                device=out.device)
            ops.quant(quant_out, out, scale)
            # NOTE(review): 2-tuple here vs 1-tuple below — callers must
            # handle both shapes.
            return quant_out, scale
        else:
            ops.quant(quant_out, out, self.quant_scale.item())
            return (quant_out, )
|
||||
79
launch_service
Executable file
79
launch_service
Executable file
@@ -0,0 +1,79 @@
|
||||
#!/bin/bash
# Entry point: configures the Corex environment, prints diagnostics, and
# starts the vLLM OpenAI-compatible API server with env-overridable defaults.

export PYTHONPATH=/usr/local/corex/lib64/python3/dist-packages
export LD_LIBRARY_PATH=/usr/local/corex/lib64:/usr/local/openmpi/lib
export PATH=/usr/local/corex/bin:/usr/local/sbin:/usr/local/bin:/usr/sbin:/usr/bin:/sbin:/bin:/usr/local/corex/lib64/python3/dist-packages/bin:/usr/local/openmpi/bin
# JMeter/JDK paths kept for in-container benchmarking tooling.
export JAVA_HOME=/root/apps/jdk1.8.0_411
export JRE_HOME=/root/apps/jdk1.8.0_411/jre
export JMETER_HOME=/root/apps/apache-jmeter-5.6.3
export CLASSPATH=.:/root/apps/jdk1.8.0_411/lib/dt.jar:/root/apps/jdk1.8.0_411/lib/tools.jar:/root/apps/apache-jmeter-5.6.3/lib/ext/ApacheJMeter_core.jar:/root/apps/apache-jmeter-5.6.3/lib/jorphan.jar:/root/apps/apache-jmeter-5.6.3/lib/logkit-2.0.jar:
export PATH=/root/apps/apache-jmeter-5.6.3/bin:/root/apps/jdk1.8.0_411/bin:/usr/local/corex/bin:/usr/local/sbin:/usr/local/bin:/usr/sbin:/usr/bin:/sbin:/bin:/usr/local/corex/lib64/python3/dist-packages/bin:/usr/local/openmpi/bin
/iluvatar/welcome.sh

# Startup diagnostics: timestamp, CPU info, GPU status, environment dump.
# BUGFIX: was `data` (not a command); `date` was intended.
date
cat /proc/cpuinfo | tail -n 50
ixsmi
# Let vLLM see all devices regardless of inherited restrictions.
unset CUDA_VISIBLE_DEVICES
export
date

# Defaults; each can be overridden by the environment variable of the same
# name (minus the DEFAULT_ prefix).
DEFAULT_HOST="0.0.0.0"
DEFAULT_PORT="80"
DEFAULT_SERVED_MODEL_NAME="llm"
DEFAULT_MODEL_PATH="/model"
DEFAULT_MAX_MODEL_LEN="10000"
DEFAULT_TENSOR_PARALLEL_SIZE="1"
DEFAULT_MAX_NUM_SEQS="64"
DEFAULT_ENFORCE_EAGER="true"
DEFAULT_DISABLE_LOG_REQUESTS="true"
DEFAULT_PREFIX_CACHING="true"

HOST_VAL=${HOST:-$DEFAULT_HOST}
PORT_VAL=${PORT:-$DEFAULT_PORT}
SERVED_MODEL_NAME_VAL=${SERVED_MODEL_NAME:-$DEFAULT_SERVED_MODEL_NAME}
MODEL_PATH_VAL=${MODEL_PATH:-$DEFAULT_MODEL_PATH}
MAX_MODEL_LEN_VAL=${MAX_MODEL_LEN:-$DEFAULT_MAX_MODEL_LEN}
TENSOR_PARALLEL_SIZE_VAL=${TENSOR_PARALLEL_SIZE:-$DEFAULT_TENSOR_PARALLEL_SIZE}
MAX_NUM_SEQS_VAL=${MAX_NUM_SEQS:-$DEFAULT_MAX_NUM_SEQS}
INCLUDE_ENFORCE_EAGER_FLAG=${ENFORCE_EAGER:-$DEFAULT_ENFORCE_EAGER}
INCLUDE_DISABLE_LOG_REQUESTS_FLAG=${DISABLE_LOG_REQUESTS:-$DEFAULT_DISABLE_LOG_REQUESTS}
INCLUDE_PREFIX_CACHING_FLAG=${PREFIX_CACHING:-$DEFAULT_PREFIX_CACHING}

CMD_ARGS=()
CMD_ARGS+=(--host "$HOST_VAL")
CMD_ARGS+=(--port "$PORT_VAL")

# Boolean flags are included unless the env var is "false" or "0".
if [[ "$INCLUDE_ENFORCE_EAGER_FLAG" != "false" && "$INCLUDE_ENFORCE_EAGER_FLAG" != "0" ]]; then
  CMD_ARGS+=(--enforce-eager)
fi
if [[ "$INCLUDE_DISABLE_LOG_REQUESTS_FLAG" != "false" && "$INCLUDE_DISABLE_LOG_REQUESTS_FLAG" != "0" ]]; then
  CMD_ARGS+=(--disable-log-requests)
fi
if [[ "$INCLUDE_PREFIX_CACHING_FLAG" != "false" && "$INCLUDE_PREFIX_CACHING_FLAG" != "0" ]]; then
  CMD_ARGS+=(--enable-prefix-caching)
fi

CMD_ARGS+=(--served-model-name "$SERVED_MODEL_NAME_VAL")
CMD_ARGS+=(--model "$MODEL_PATH_VAL")
CMD_ARGS+=(--max-model-len "$MAX_MODEL_LEN_VAL")
CMD_ARGS+=(--tensor-parallel-size "$TENSOR_PARALLEL_SIZE_VAL")
CMD_ARGS+=(--max-num-seqs "$MAX_NUM_SEQS_VAL")

echo "--------------------------------------------------"
echo "Starting VLLM OpenAI API Server..."
echo "Using effective arguments:"
echo "  Host (--host): $HOST_VAL"
echo "  Port (--port): $PORT_VAL"
echo "  Enforce Eager (--enforce-eager):" $([[ "$INCLUDE_ENFORCE_EAGER_FLAG" != "false" && "$INCLUDE_ENFORCE_EAGER_FLAG" != "0" ]] && echo "Enabled" || echo "Disabled (Env: ENFORCE_EAGER=$ENFORCE_EAGER)")
echo "  Disable Log Req (--disable-log-requests):" $([[ "$INCLUDE_DISABLE_LOG_REQUESTS_FLAG" != "false" && "$INCLUDE_DISABLE_LOG_REQUESTS_FLAG" != "0" ]] && echo "Enabled" || echo "Disabled (Env: DISABLE_LOG_REQUESTS=$DISABLE_LOG_REQUESTS)")
echo "  Served Model Name (--served-model-name): $SERVED_MODEL_NAME_VAL"
echo "  Model Path (--model): $MODEL_PATH_VAL"
echo "  Max Model Length (--max-model-len): $MAX_MODEL_LEN_VAL"
echo "  Tensor Parallel Size (--tensor-parallel-size): $TENSOR_PARALLEL_SIZE_VAL"
echo "  Max Num Seqs (--max-num-seqs): $MAX_NUM_SEQS_VAL"
echo "--------------------------------------------------"
echo "Full cmd:"
echo "python3 -m vllm.entrypoints.openai.api_server ${CMD_ARGS[*]}"
echo "--------------------------------------------------"

python3 -m vllm.entrypoints.openai.api_server "${CMD_ARGS[@]}"
|
||||
242
paged_attn.py
Normal file
242
paged_attn.py
Normal file
@@ -0,0 +1,242 @@
|
||||
from dataclasses import dataclass
|
||||
from typing import List, Optional, Tuple
|
||||
|
||||
import torch
|
||||
|
||||
from vllm import _custom_ops as ops
|
||||
from vllm.attention.ops.prefix_prefill import context_attention_fwd
|
||||
|
||||
# Should be the same as PARTITION_SIZE in `paged_attention_v2_launcher`.
|
||||
_PARTITION_SIZE = 512
|
||||
|
||||
|
||||
@dataclass
class PagedAttentionMetadata:
    """Metadata for PagedAttention.

    Carries the per-batch decode-phase bookkeeping needed by the paged
    attention kernels: sequence lengths and the block tables that map each
    sequence to its physical KV-cache blocks.
    """
    # (batch_size,). The length of sequences (entire tokens seen so far) per
    # sequence.
    seq_lens_tensor: Optional[torch.Tensor]
    # Maximum sequence length in the batch. 0 if it is prefill-only batch.
    max_decode_seq_len: int
    # (batch_size, max_blocks_per_seq).
    # Block addresses per sequence. (Seq id -> list of physical block)
    # E.g., [0, 1, 2] means tokens are stored in 0th, 1st, and 2nd blocks
    # in the kv cache. Each block can contain up to block_size tokens.
    # 2nd dimensions are padded up to max_blocks_per_seq if it is cuda-graph
    # captured.
    block_tables: Optional[torch.Tensor]
|
||||
|
||||
|
||||
class PagedAttention:
|
||||
|
||||
@staticmethod
|
||||
def get_supported_head_sizes() -> List[int]:
|
||||
return [64, 80, 96, 112, 120, 128, 192, 256]
|
||||
|
||||
@staticmethod
|
||||
def get_kv_cache_shape(
|
||||
num_blocks: int,
|
||||
block_size: int,
|
||||
num_kv_heads: int,
|
||||
head_size: int,
|
||||
) -> Tuple[int, ...]:
|
||||
return (2, num_blocks, block_size * num_kv_heads * head_size)
|
||||
|
||||
@staticmethod
|
||||
def split_kv_cache(
|
||||
kv_cache: torch.Tensor,
|
||||
num_kv_heads: int,
|
||||
head_size: int,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
x = 16 // kv_cache.element_size()
|
||||
num_blocks = kv_cache.shape[1]
|
||||
|
||||
key_cache = kv_cache[0]
|
||||
key_cache = key_cache.view(num_blocks, num_kv_heads, head_size // x,
|
||||
-1, x)
|
||||
value_cache = kv_cache[1]
|
||||
value_cache = value_cache.view(num_blocks, num_kv_heads, head_size, -1)
|
||||
return key_cache, value_cache
|
||||
|
||||
@staticmethod
|
||||
def write_to_paged_cache(
|
||||
key: torch.Tensor,
|
||||
value: torch.Tensor,
|
||||
key_cache: torch.Tensor,
|
||||
value_cache: torch.Tensor,
|
||||
slot_mapping: torch.Tensor,
|
||||
kv_cache_dtype: str,
|
||||
k_scale: float,
|
||||
v_scale: float,
|
||||
) -> None:
|
||||
ops.reshape_and_cache(
|
||||
key,
|
||||
value,
|
||||
key_cache,
|
||||
value_cache,
|
||||
slot_mapping.flatten(),
|
||||
kv_cache_dtype,
|
||||
k_scale,
|
||||
v_scale,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def forward_decode(
|
||||
query: torch.Tensor,
|
||||
key_cache: torch.Tensor,
|
||||
value_cache: torch.Tensor,
|
||||
block_tables: torch.Tensor,
|
||||
seq_lens: torch.Tensor,
|
||||
max_seq_len: int,
|
||||
kv_cache_dtype: str,
|
||||
num_kv_heads: int,
|
||||
scale: float,
|
||||
alibi_slopes: Optional[torch.Tensor],
|
||||
k_scale: float,
|
||||
v_scale: float,
|
||||
tp_rank: int = 0,
|
||||
blocksparse_local_blocks: int = 0,
|
||||
blocksparse_vert_stride: int = 0,
|
||||
blocksparse_block_size: int = 64,
|
||||
blocksparse_head_sliding_step: int = 0,
|
||||
) -> torch.Tensor:
|
||||
if blocksparse_vert_stride is not None and blocksparse_vert_stride > 1:
|
||||
# use blocksparse paged attention
|
||||
block_size = value_cache.size(-1)
|
||||
assert (blocksparse_block_size > 0 and
|
||||
blocksparse_block_size % block_size == 0), \
|
||||
(f"{blocksparse_block_size=} needs to be a multiple of"
|
||||
f"{block_size=} used in block_tables.")
|
||||
|
||||
output = torch.empty_like(query)
|
||||
block_size = value_cache.shape[3]
|
||||
num_seqs, num_heads, head_size = query.shape
|
||||
max_num_partitions = ((max_seq_len + _PARTITION_SIZE - 1) //
|
||||
_PARTITION_SIZE)
|
||||
# NOTE(woosuk): We use a simple heuristic to decide whether to use
|
||||
# PagedAttention V1 or V2. If the number of partitions is 1, we use
|
||||
# V1 to avoid the overhead of reduction. Also, if the number of
|
||||
# sequences or heads is large, we use V1 since there is enough work
|
||||
# to parallelize.
|
||||
# TODO(woosuk): Tune this heuristic.
|
||||
# For context len > 8192, use V2 kernel to avoid shared memory shortage.
|
||||
use_v1 = (max_seq_len <= 8192
|
||||
and (max_num_partitions == 1 or num_seqs * num_heads > 512))
|
||||
use_v1 = True
|
||||
if use_v1:
|
||||
# Run PagedAttention V1.
|
||||
ops.paged_attention_v1(
|
||||
output,
|
||||
query,
|
||||
key_cache,
|
||||
value_cache,
|
||||
num_kv_heads,
|
||||
scale,
|
||||
block_tables,
|
||||
seq_lens,
|
||||
block_size,
|
||||
max_seq_len,
|
||||
alibi_slopes,
|
||||
)
|
||||
else:
|
||||
# Run PagedAttention V2.
|
||||
assert _PARTITION_SIZE % block_size == 0
|
||||
tmp_output = torch.empty(
|
||||
size=(num_seqs, num_heads, max_num_partitions, head_size),
|
||||
dtype=output.dtype,
|
||||
device=output.device,
|
||||
)
|
||||
exp_sums = torch.empty(
|
||||
size=(num_seqs, num_heads, max_num_partitions),
|
||||
dtype=torch.float32,
|
||||
device=output.device,
|
||||
)
|
||||
max_logits = torch.empty_like(exp_sums)
|
||||
ops.paged_attention_v2(
|
||||
output,
|
||||
exp_sums,
|
||||
max_logits,
|
||||
tmp_output,
|
||||
query,
|
||||
key_cache,
|
||||
value_cache,
|
||||
num_kv_heads,
|
||||
scale,
|
||||
block_tables,
|
||||
seq_lens,
|
||||
block_size,
|
||||
max_seq_len,
|
||||
alibi_slopes,
|
||||
kv_cache_dtype,
|
||||
k_scale,
|
||||
v_scale,
|
||||
tp_rank,
|
||||
blocksparse_local_blocks,
|
||||
blocksparse_vert_stride,
|
||||
blocksparse_block_size,
|
||||
blocksparse_head_sliding_step,
|
||||
)
|
||||
return output
|
||||
|
||||
@staticmethod
|
||||
def forward_prefix(
|
||||
query: torch.Tensor,
|
||||
key: torch.Tensor,
|
||||
value: torch.Tensor,
|
||||
kv_cache_dtype: str,
|
||||
key_cache: torch.Tensor,
|
||||
value_cache: torch.Tensor,
|
||||
block_tables: torch.Tensor,
|
||||
query_start_loc: torch.Tensor,
|
||||
seq_lens_tensor: torch.Tensor,
|
||||
context_lens: torch.Tensor,
|
||||
max_query_len: int,
|
||||
alibi_slopes: Optional[torch.Tensor],
|
||||
sliding_window: Optional[int],
|
||||
k_scale: float,
|
||||
v_scale: float,
|
||||
) -> torch.Tensor:
|
||||
output = torch.empty_like(query)
|
||||
context_attention_fwd(
|
||||
query,
|
||||
key,
|
||||
value,
|
||||
output,
|
||||
kv_cache_dtype,
|
||||
key_cache,
|
||||
value_cache,
|
||||
block_tables,
|
||||
# query_start_loc is (batch_size + 1,)
|
||||
query_start_loc[:-1],
|
||||
seq_lens_tensor,
|
||||
context_lens,
|
||||
max_query_len,
|
||||
k_scale,
|
||||
v_scale,
|
||||
alibi_slopes,
|
||||
sliding_window,
|
||||
)
|
||||
return output
|
||||
|
||||
@staticmethod
|
||||
def swap_blocks(
|
||||
src_kv_cache: torch.Tensor,
|
||||
dst_kv_cache: torch.Tensor,
|
||||
src_to_dst: torch.Tensor,
|
||||
) -> None:
|
||||
src_key_cache = src_kv_cache[0]
|
||||
dst_key_cache = dst_kv_cache[0]
|
||||
ops.swap_blocks(src_key_cache, dst_key_cache, src_to_dst)
|
||||
|
||||
src_value_cache = src_kv_cache[1]
|
||||
dst_value_cache = dst_kv_cache[1]
|
||||
ops.swap_blocks(src_value_cache, dst_value_cache, src_to_dst)
|
||||
|
||||
@staticmethod
def copy_blocks(
    kv_caches: List[torch.Tensor],
    src_to_dists: torch.Tensor,
) -> None:
    """Duplicate cache blocks inside every layer's paged KV cache.

    Splits each per-layer cache into its key plane (index 0) and value
    plane (index 1), then hands both lists to the fused copy kernel.
    """
    key_planes = [layer_cache[0] for layer_cache in kv_caches]
    value_planes = [layer_cache[1] for layer_cache in kv_caches]
    ops.copy_blocks(key_planes, value_planes, src_to_dists)
|
||||
1
pkgs/triton-2.1.0+corex.4.1.2.dist-info/INSTALLER
Normal file
1
pkgs/triton-2.1.0+corex.4.1.2.dist-info/INSTALLER
Normal file
@@ -0,0 +1 @@
|
||||
pip
|
||||
33
pkgs/triton-2.1.0+corex.4.1.2.dist-info/METADATA
Normal file
33
pkgs/triton-2.1.0+corex.4.1.2.dist-info/METADATA
Normal file
@@ -0,0 +1,33 @@
|
||||
Metadata-Version: 2.1
|
||||
Name: triton
|
||||
Version: 2.1.0+corex.4.1.2
|
||||
Summary: A language and compiler for custom Deep Learning operations
|
||||
Home-page: https://github.com/openai/triton/
|
||||
Author: Philippe Tillet
|
||||
Author-email: phil@openai.com
|
||||
Keywords: Compiler,Deep Learning
|
||||
Classifier: Development Status :: 4 - Beta
|
||||
Classifier: Intended Audience :: Developers
|
||||
Classifier: Topic :: Software Development :: Build Tools
|
||||
Classifier: License :: OSI Approved :: MIT License
|
||||
Classifier: Programming Language :: Python :: 3.7
|
||||
Classifier: Programming Language :: Python :: 3.8
|
||||
Classifier: Programming Language :: Python :: 3.9
|
||||
Classifier: Programming Language :: Python :: 3.10
|
||||
Classifier: Programming Language :: Python :: 3.11
|
||||
Requires-Dist: filelock
|
||||
Provides-Extra: build
|
||||
Requires-Dist: cmake>=3.18; extra == "build"
|
||||
Requires-Dist: lit; extra == "build"
|
||||
Provides-Extra: tests
|
||||
Requires-Dist: autopep8; extra == "tests"
|
||||
Requires-Dist: flake8; extra == "tests"
|
||||
Requires-Dist: isort; extra == "tests"
|
||||
Requires-Dist: numpy; extra == "tests"
|
||||
Requires-Dist: pytest; extra == "tests"
|
||||
Requires-Dist: scipy>=1.7.1; extra == "tests"
|
||||
Provides-Extra: tutorials
|
||||
Requires-Dist: matplotlib; extra == "tutorials"
|
||||
Requires-Dist: pandas; extra == "tutorials"
|
||||
Requires-Dist: tabulate; extra == "tutorials"
|
||||
|
||||
156
pkgs/triton-2.1.0+corex.4.1.2.dist-info/RECORD
Normal file
156
pkgs/triton-2.1.0+corex.4.1.2.dist-info/RECORD
Normal file
@@ -0,0 +1,156 @@
|
||||
triton-2.1.0+corex.4.1.2.dist-info/INSTALLER,sha256=zuuue4knoyJ-UwPPXg8fezS7VCrXJQrAP7zeNuwvFQg,4
|
||||
triton-2.1.0+corex.4.1.2.dist-info/METADATA,sha256=WsBBXGUNH3GMUl-Sr8PVEA9wKr8pKyeQQ7BeQEq4Moc,1271
|
||||
triton-2.1.0+corex.4.1.2.dist-info/RECORD,,
|
||||
triton-2.1.0+corex.4.1.2.dist-info/REQUESTED,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
||||
triton-2.1.0+corex.4.1.2.dist-info/WHEEL,sha256=jkyzK-PMDGe6NpviSR1tpgon24csQEXDXqxhDgOHeYM,105
|
||||
triton-2.1.0+corex.4.1.2.dist-info/direct_url.json,sha256=rC0cgRYuAV96ZhqohICZYGcWgh85fJNAUbExKOBJuaE,284
|
||||
triton-2.1.0+corex.4.1.2.dist-info/top_level.txt,sha256=g7iYLuhGmQFd2TwONpkqhhg7KJl6GMjbwGZvSp3_tnE,275
|
||||
triton/_C/libtriton.so,sha256=TabhXpykxBhWhTTiPjVzF4qoFgKy2B5-GV04hVDF1HY,256934256
|
||||
triton/__init__.py,sha256=J6RjARnKOLfAU59wgewEy82ZoICORUQe_mfHWHQ87lY,1267
|
||||
triton/__pycache__/__init__.cpython-310.pyc,,
|
||||
triton/__pycache__/testing.cpython-310.pyc,,
|
||||
triton/common/__init__.py,sha256=uG0IrQy1GiwfCkg9FHqVcYHQ9dspmDtma8SSdFqKXew,48
|
||||
triton/common/__pycache__/__init__.cpython-310.pyc,,
|
||||
triton/common/__pycache__/build.cpython-310.pyc,,
|
||||
triton/common/build.py,sha256=Ssz_enawYCeVbyHsmJRk9e81dvlFCIWNKikWcWJjZBo,4609
|
||||
triton/compiler/__init__.py,sha256=VGVviF6fZSXDMUrcO9o0H96j9gQXcYDjnJx8TG18xhU,144
|
||||
triton/compiler/__pycache__/__init__.cpython-310.pyc,,
|
||||
triton/compiler/__pycache__/code_generator.cpython-310.pyc,,
|
||||
triton/compiler/__pycache__/compiler.cpython-310.pyc,,
|
||||
triton/compiler/__pycache__/errors.cpython-310.pyc,,
|
||||
triton/compiler/__pycache__/make_launcher.cpython-310.pyc,,
|
||||
triton/compiler/code_generator.py,sha256=2T0mFAy4TUOh1v2Mh3csLtztT9Aky0zNH-fjfLqx5W8,49616
|
||||
triton/compiler/compiler.py,sha256=0_H9l2n4MlHKA9zfploeC9nF-oWC8sw5oUOvgGirRwU,23506
|
||||
triton/compiler/errors.py,sha256=PiquMxHuHayRvdH3hMXSQlOnDswK3TGNWENB29YX_FU,1666
|
||||
triton/compiler/make_launcher.py,sha256=dhi2y8cTVkaFUAFfH-s7gBq40ukhq18_Rl674lsS408,12458
|
||||
triton/debugger/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
||||
triton/debugger/__pycache__/__init__.cpython-310.pyc,,
|
||||
triton/debugger/__pycache__/core.cpython-310.pyc,,
|
||||
triton/debugger/__pycache__/debugger.cpython-310.pyc,,
|
||||
triton/debugger/__pycache__/memory_map.cpython-310.pyc,,
|
||||
triton/debugger/__pycache__/tl_lang.cpython-310.pyc,,
|
||||
triton/debugger/__pycache__/torch_wrapper.cpython-310.pyc,,
|
||||
triton/debugger/core.py,sha256=zjYVVAleV3EdBMOhMUR-c1X9H3613fu1ttOsG5n0VgQ,150
|
||||
triton/debugger/debugger.py,sha256=-G9B4apnYC1woN-dTwdBkk4EdS-bavFnJ763_kvUr6o,5399
|
||||
triton/debugger/memory_map.py,sha256=7xvtTJsJoGYqX4bwsJJyxJvl8H5YOnA3nnqwF5zifi8,3165
|
||||
triton/debugger/tl_lang.py,sha256=P2ZZa9QNkLxp63PY7kSdTrEWU4mF5zzB3hsBIm1Xum0,17980
|
||||
triton/debugger/torch_wrapper.py,sha256=iD0qbPnsR-rMG0anMJRoEhIDvQaVzFmBuFTKkPmnj_g,353
|
||||
triton/language/__init__.py,sha256=lLiZyE7r2PJ3vQcUfw2EAkFaBJHLzG4YRUlJKGlOxrk,2884
|
||||
triton/language/__pycache__/__init__.cpython-310.pyc,,
|
||||
triton/language/__pycache__/core.cpython-310.pyc,,
|
||||
triton/language/__pycache__/math.cpython-310.pyc,,
|
||||
triton/language/__pycache__/random.cpython-310.pyc,,
|
||||
triton/language/__pycache__/semantic.cpython-310.pyc,,
|
||||
triton/language/__pycache__/standard.cpython-310.pyc,,
|
||||
triton/language/core.py,sha256=cxhJ3qG4h2T85FUuaVdDTAdB7eXkTHzmaerFZd57cow,54553
|
||||
triton/language/extra/__init__.py,sha256=zyhlj6Mo9_KD7E5szGxI18Ko3b31E00-s1ybOM5KzkQ,39
|
||||
triton/language/extra/__pycache__/__init__.cpython-310.pyc,,
|
||||
triton/language/extra/__pycache__/cuda.cpython-310.pyc,,
|
||||
triton/language/extra/cuda.bc,sha256=03xfoq03DK6PfMS2fm05V90ei5jlnDdZaD6kZffQEas,1808
|
||||
triton/language/extra/cuda.py,sha256=7UCCXfkwgLPpl87FRZsXyrHN6tt0xRoQeeBksSMdgyE,641
|
||||
triton/language/math.py,sha256=g1z24oHGwyfZZNPkp9OUQ0_nXlnGEQCwAMcKKrRt1kE,75450
|
||||
triton/language/random.py,sha256=yxzlT8jCKAv0yo4aVcWL5kBj4OOMyRcpbuGXMUJw0E0,5630
|
||||
triton/language/semantic.py,sha256=hmyErm5bWEvlt2NgWPzubVUNBh4Bbfrb5BJ6y42VScM,62939
|
||||
triton/language/standard.py,sha256=jpc-ha22ylEHusr-f4OUSLo7iuso3yF4cTjOTpSMqo0,2320
|
||||
triton/ops/__init__.py,sha256=KM696wNNUi4gHanlZo54WcfmYnRWkykHKGnuhRoqAgQ,370
|
||||
triton/ops/__pycache__/__init__.cpython-310.pyc,,
|
||||
triton/ops/__pycache__/bmm_matmul.cpython-310.pyc,,
|
||||
triton/ops/__pycache__/cross_entropy.cpython-310.pyc,,
|
||||
triton/ops/__pycache__/flash_attention.cpython-310.pyc,,
|
||||
triton/ops/__pycache__/matmul.cpython-310.pyc,,
|
||||
triton/ops/__pycache__/matmul_perf_model.cpython-310.pyc,,
|
||||
triton/ops/blocksparse/__init__.py,sha256=6YEVQNzipgQCpoO_7B8H7ckaSW2Idt1244s7IyLWAwc,100
|
||||
triton/ops/blocksparse/__pycache__/__init__.cpython-310.pyc,,
|
||||
triton/ops/blocksparse/__pycache__/matmul.cpython-310.pyc,,
|
||||
triton/ops/blocksparse/__pycache__/softmax.cpython-310.pyc,,
|
||||
triton/ops/blocksparse/matmul.py,sha256=uBOKf2pS5fFtGokQvrWS0GrVHNipIrsTZ_-vPr5lfOY,15617
|
||||
triton/ops/blocksparse/softmax.py,sha256=eFvPebQZ__8L4pCmks-PsqpCG2pgheRL2f0Z4WTKrts,7893
|
||||
triton/ops/bmm_matmul.py,sha256=grVD0Z800EWMC35TtZZoGgQihZqRUfnjJ9PfOpJmhfA,6381
|
||||
triton/ops/cross_entropy.py,sha256=NRZdHy8nh70E0B_7BIHc-3-rkQF5Nb9nBnbfxfjHUAs,3457
|
||||
triton/ops/flash_attention.py,sha256=AFjILTHWCkYXHDNkEBdrcsYIQ4OmYvYsKh7G2StXfH8,10335
|
||||
triton/ops/matmul.py,sha256=bMoTeSnGBrR5oeSgtUr6syeLNPepydVpy1l9Ts-5E-g,8447
|
||||
triton/ops/matmul_perf_model.py,sha256=KpYV6bJ9dRqg2js-1vLm_uPQyOytbUT8jaSVeleIXmI,6672
|
||||
triton/runtime/__init__.py,sha256=4N2bT8rp10BAMHQzF9CiIt276t7XhA22c1uJkNB2ejk,516
|
||||
triton/runtime/__pycache__/__init__.cpython-310.pyc,,
|
||||
triton/runtime/__pycache__/autotuner.cpython-310.pyc,,
|
||||
triton/runtime/__pycache__/cache.cpython-310.pyc,,
|
||||
triton/runtime/__pycache__/driver.cpython-310.pyc,,
|
||||
triton/runtime/__pycache__/errors.cpython-310.pyc,,
|
||||
triton/runtime/__pycache__/jit.cpython-310.pyc,,
|
||||
triton/runtime/autotuner.py,sha256=CYq8EkWf2mzAgubAvHheLtpsCYbrEdu7BbwKxYIhips,12718
|
||||
triton/runtime/backends/cuda.c,sha256=5s9niGGe1tCvt1AW3O66jqiYQtpHQM9aUlKXppnTvO0,4680
|
||||
triton/runtime/backends/hip.c,sha256=L37Uz-DsX5ntV_nI3OknmtB_xuufIWrfZAvW4SafFTA,4005
|
||||
triton/runtime/cache.py,sha256=kpcFPdEbH3cO60HVdzhYtnOuIycXJnbDqWhbuF5hkFQ,4190
|
||||
triton/runtime/driver.py,sha256=UIsodifreXt2d9NLErHR7wu7Fbr4nCpkhfOtYniMhnE,5100
|
||||
triton/runtime/errors.py,sha256=XuA6URwCy4e3iYYbX9k35X89wnoasp7_K1LAPvsTbMQ,591
|
||||
triton/runtime/jit.py,sha256=9LQhkSlhcKoWzx_x6E9E51eByfAj1U7yjwAyj-Mh57g,21468
|
||||
triton/testing.py,sha256=akS7bFcE0LVMi56WCZSt7Wu4CfkRGsasJ9_tL1Zj3Vo,15520
|
||||
triton/third_party/cuda/bin/ptxas,sha256=65yHgyKTqojU5uWOzQ134l3rk9pcGou5HefXzVsCgaw,20495600
|
||||
triton/third_party/cuda/include/cuda.h,sha256=yCBNCpZt_qVabzyVattc2qrEq36p5MeL3kDzvNdeV4A,790319
|
||||
triton/third_party/cuda/lib/libdevice.10.bc,sha256=XC-uN8huaMOjhgWpX1EtfRLV89uYYxC-R_VzBKpype4,473728
|
||||
triton/third_party/rocm/lib/bitcode/asanrtl.bc,sha256=cHu7_b3BolfsDJMlWnGTY71qI6fNNdSVPDMFIojOh9k,24708
|
||||
triton/third_party/rocm/lib/bitcode/cuda2gcn.bc,sha256=UDl-zwwWB5Ad__lIlhReRQBx61m0nMORxdppdTxkjQo,41568
|
||||
triton/third_party/rocm/lib/bitcode/cuda2gcn.patch,sha256=EuHo5MOHjwbqdnuv53532LaFAuRh3C3J1QfRo0B3NJM,600
|
||||
triton/third_party/rocm/lib/bitcode/hip.bc,sha256=yttUFcAzSE3ydohSCvOlTp5QZ3SjSH4juoSOYXz_zGg,2372
|
||||
triton/third_party/rocm/lib/bitcode/make_cuda2gcn.sh,sha256=9jMtbtlZBIY6yyEp8V12-6Nw4WahdrCPFyFnRb1kpls,348
|
||||
triton/third_party/rocm/lib/bitcode/ockl.bc,sha256=S3XZYIoykPyJ8XOFeVdqyRcdqxoBu50-AaU8oTdAhS8,227392
|
||||
triton/third_party/rocm/lib/bitcode/oclc_abi_version_400.bc,sha256=K3TaLRC6uoHXTMK6WS7NgLJzjXeKKfvf_8a6lqfNwAs,1920
|
||||
triton/third_party/rocm/lib/bitcode/oclc_abi_version_500.bc,sha256=9e6Z5f4lm_5asnMSHuPOrrnwlV6jlStgrwYTEiOU5tc,1920
|
||||
triton/third_party/rocm/lib/bitcode/oclc_correctly_rounded_sqrt_off.bc,sha256=DiRiGcUfL17ukJS40RV8OHCWvEU-Oi4oS5NCn-SRAFY,1936
|
||||
triton/third_party/rocm/lib/bitcode/oclc_correctly_rounded_sqrt_on.bc,sha256=cVJuiMvffHLxRqvPmzpbMp8nX7yxtUojn9A-iBfCmaY,1936
|
||||
triton/third_party/rocm/lib/bitcode/oclc_daz_opt_off.bc,sha256=hcifwxlJIx5HlOcj_LjOUeqezaZ2v1xzX6LHNSQSkbE,1920
|
||||
triton/third_party/rocm/lib/bitcode/oclc_daz_opt_on.bc,sha256=jYa1MyGdVRH6etbC6gicJ8egrwwV4wVHbe4aHRQ27R8,1920
|
||||
triton/third_party/rocm/lib/bitcode/oclc_finite_only_off.bc,sha256=gWpVuB2OXig5vPczOQshtvZ_oJO72R5fFRiAYTILhgM,1928
|
||||
triton/third_party/rocm/lib/bitcode/oclc_finite_only_on.bc,sha256=z9bPfFcxKqMILV3SsRJUIk6ExpZHak2Asv8FP9Ni9bw,1928
|
||||
triton/third_party/rocm/lib/bitcode/oclc_isa_version_1010.bc,sha256=0vFumpEzejM0Xz7hSSbEYT-iZCrPrDeTRicpD-H-jXM,1920
|
||||
triton/third_party/rocm/lib/bitcode/oclc_isa_version_1011.bc,sha256=VqawQ_trHRs7rn9NQEkKQ0gDO1BW73hkl9ZdP5Ixybc,1920
|
||||
triton/third_party/rocm/lib/bitcode/oclc_isa_version_1012.bc,sha256=2C8uRNtmebw0dHh8Vg5_PuFLiY4x19wcydnS7VvadOo,1920
|
||||
triton/third_party/rocm/lib/bitcode/oclc_isa_version_1013.bc,sha256=lZfO_8vUqSCn1o3VDALvFN-Cclq3eOYDe7-FthD0Gng,1920
|
||||
triton/third_party/rocm/lib/bitcode/oclc_isa_version_1030.bc,sha256=WIadAgcRBfWphRmcrQJUYewZXWtSN7npAAMAkNQ0C1w,1920
|
||||
triton/third_party/rocm/lib/bitcode/oclc_isa_version_1031.bc,sha256=ofxROlV79qeW1SyCo9VmJ-W9CVK7zxZtbWLDA-nKRh0,1920
|
||||
triton/third_party/rocm/lib/bitcode/oclc_isa_version_1032.bc,sha256=IXOzwf4sdsw453PcColeLVjEQek4H1X3G-IfqvoZfhA,1920
|
||||
triton/third_party/rocm/lib/bitcode/oclc_isa_version_1033.bc,sha256=4KNBcZx7R7e6CJE-4-Ge3v4xzOIuGOOOdWhpga_txbU,1920
|
||||
triton/third_party/rocm/lib/bitcode/oclc_isa_version_1034.bc,sha256=nmMcRhEdWZcSzMbO0VCj98VTujkBsswVyQiozF4Pgo4,1920
|
||||
triton/third_party/rocm/lib/bitcode/oclc_isa_version_1035.bc,sha256=EfSk25ZwCSNNBVOhKoGvkKpPXmW5WSYPTCqD0PCwAB0,1920
|
||||
triton/third_party/rocm/lib/bitcode/oclc_isa_version_1036.bc,sha256=L7hI21xe_IX3IFyQf_A73k1YXEf95w16rLV1BB6FwQE,1920
|
||||
triton/third_party/rocm/lib/bitcode/oclc_isa_version_1100.bc,sha256=_UxQflZQXU8FJcg7B6VKJtJJwJRKdDY2qiauHkP6nFs,2096
|
||||
triton/third_party/rocm/lib/bitcode/oclc_isa_version_1101.bc,sha256=Lb83xl4VE1cQjOxz0gMwWS1_p-sq1wHjs67DpEI2RL0,2096
|
||||
triton/third_party/rocm/lib/bitcode/oclc_isa_version_1102.bc,sha256=Xqalbf-GktJPs_ZOrg54FCCdxzMAZ4U7Ny7HNqh4Blc,2096
|
||||
triton/third_party/rocm/lib/bitcode/oclc_isa_version_1103.bc,sha256=KMSqCG4J8tDGDRUoYK97xw-Gj0CeFJN0bYVn7Y8ot70,2096
|
||||
triton/third_party/rocm/lib/bitcode/oclc_isa_version_600.bc,sha256=mkocG9cx-o5vwHCT98Fcq4v_fol3RUk24VPaJ2KNiJ0,1920
|
||||
triton/third_party/rocm/lib/bitcode/oclc_isa_version_601.bc,sha256=Ib-dn5dosOuXVbBdG59-RERNDk2eMX4vvWUm0Z8Qqpc,1920
|
||||
triton/third_party/rocm/lib/bitcode/oclc_isa_version_602.bc,sha256=QH5FGGkLH3fFhPJneGRrucoP44FBTKjm7VMKzGSHeDA,1920
|
||||
triton/third_party/rocm/lib/bitcode/oclc_isa_version_700.bc,sha256=ibu033_NCSnqAYqqD70CRVLeym0RFpgm-PX0q-Y2yDs,1920
|
||||
triton/third_party/rocm/lib/bitcode/oclc_isa_version_701.bc,sha256=P3d62xvb3_OSmACax3BeobuBAOP04X_0FjXQDKIL7bk,1920
|
||||
triton/third_party/rocm/lib/bitcode/oclc_isa_version_702.bc,sha256=WmuD3Kq09F2KQ3u6mFCfE0MoDZWQAfjwDkOHwZpBnsA,1920
|
||||
triton/third_party/rocm/lib/bitcode/oclc_isa_version_703.bc,sha256=3_iPbveGJ9fJa1Kvu-uEEGbuRAOrbTMauJGJNFPBvHI,1920
|
||||
triton/third_party/rocm/lib/bitcode/oclc_isa_version_704.bc,sha256=0PEze_BCWRcbGBeM_Jaote9xmzsW7hDCwbeuHs55PGo,1920
|
||||
triton/third_party/rocm/lib/bitcode/oclc_isa_version_705.bc,sha256=QBqWlg1VDWgQPXDvuPzWQ8--WRP3oSyk25nv70yQXMU,1920
|
||||
triton/third_party/rocm/lib/bitcode/oclc_isa_version_801.bc,sha256=iQKrQ9rQx2UiET0BQ-1RgibWE4_aouTdPoaTg0gUEZk,1920
|
||||
triton/third_party/rocm/lib/bitcode/oclc_isa_version_802.bc,sha256=9euokmFJ4pmx6IfkcBCpnHoRJrHqQ1I_zZMIYY7j7Ks,1920
|
||||
triton/third_party/rocm/lib/bitcode/oclc_isa_version_803.bc,sha256=-v7D1qhgp4YwsFgYL2R3gN4cN0vR-6RGxMGxMm7RpKM,1920
|
||||
triton/third_party/rocm/lib/bitcode/oclc_isa_version_805.bc,sha256=MalaxEN85FaTuJ9CUahDNKaCR2CZj81YNzVeXOHsDOA,1920
|
||||
triton/third_party/rocm/lib/bitcode/oclc_isa_version_810.bc,sha256=8Dbv2N9xQB93CzJ0DWe1YLpstI6SpEH71XABjEJu7BQ,1920
|
||||
triton/third_party/rocm/lib/bitcode/oclc_isa_version_900.bc,sha256=j3tpNoOQGtuHd4isYNNPdtYVGSmeXbIiHfLgPRioS8Q,1920
|
||||
triton/third_party/rocm/lib/bitcode/oclc_isa_version_902.bc,sha256=3SlZ_VGBX9y8cRm9hF3XdtadPFJFLsIaoOnGdnqtdFY,1920
|
||||
triton/third_party/rocm/lib/bitcode/oclc_isa_version_904.bc,sha256=OLEXkdaKEefZgzsommw6c5378MDasYKH_PSkGoU7hGU,1920
|
||||
triton/third_party/rocm/lib/bitcode/oclc_isa_version_906.bc,sha256=KeOq_NoSwkh3D-1eYbzEvYcLqx-a4XHlMib3GQhtoFY,1920
|
||||
triton/third_party/rocm/lib/bitcode/oclc_isa_version_908.bc,sha256=PFAdOwLxneisb1dnPMU8nkm8kslmNuCGWbVFOP9bgj0,1920
|
||||
triton/third_party/rocm/lib/bitcode/oclc_isa_version_909.bc,sha256=J4Cwry1PZErUfvaYu4ZNCct1NQMDle0mz84ycbKJp6g,1920
|
||||
triton/third_party/rocm/lib/bitcode/oclc_isa_version_90a.bc,sha256=ojZ7tLg2N7Hq1omZIrCkBRzia0bYtOdGSU4CmkkdfhE,1920
|
||||
triton/third_party/rocm/lib/bitcode/oclc_isa_version_90c.bc,sha256=6mVoOg82NG_NfT2Su6fB18DFgRBRtOSQnRL2QgjRdHY,1920
|
||||
triton/third_party/rocm/lib/bitcode/oclc_isa_version_940.bc,sha256=Z5ahEe7z2mwBgxuBySEOG7-rRRdpVACwawch5PE6s2I,1920
|
||||
triton/third_party/rocm/lib/bitcode/oclc_unsafe_math_off.bc,sha256=evf88YMj0hBicj04zDdTM0Xlu4KG9fpUlj3iGS0F87A,1928
|
||||
triton/third_party/rocm/lib/bitcode/oclc_unsafe_math_on.bc,sha256=bYF53Lu769UGewikZaQMfoxOerQBGs3SfLpQXoj3zVk,1928
|
||||
triton/third_party/rocm/lib/bitcode/oclc_wavefrontsize64_off.bc,sha256=XlcwAnD8EeEHQ6FksQG3-Rxc07zMt7aFXVd-xmGBeDo,1928
|
||||
triton/third_party/rocm/lib/bitcode/oclc_wavefrontsize64_on.bc,sha256=s7t5si7PLY1kIAs0oJGY-9uSL-lMKtcXdGZqzXyFteI,1928
|
||||
triton/third_party/rocm/lib/bitcode/ocml.bc,sha256=bWAX1FuKc-cJrmLlC6kYhvs5fzRoBbc0Y7BP5qgg0BA,192228
|
||||
triton/third_party/rocm/lib/bitcode/opencl.bc,sha256=mL2MGdVis93KLQWVYYj0pT6d8iS99DKlbnEe-1VGdss,2841028
|
||||
triton/tools/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
||||
triton/tools/__pycache__/__init__.cpython-310.pyc,,
|
||||
triton/tools/__pycache__/aot.cpython-310.pyc,,
|
||||
triton/tools/__pycache__/build_extern.cpython-310.pyc,,
|
||||
triton/tools/__pycache__/disasm.cpython-310.pyc,,
|
||||
triton/tools/aot.py,sha256=DUb9edluQKqfvY6aeIyQk2dN1Ljv3i6GYystyObqc8o,4752
|
||||
triton/tools/build_extern.py,sha256=t5vdo6a3bcFfawMQBnTderw7Xi-L49x4udomoOWRUtw,14661
|
||||
triton/tools/disasm.py,sha256=FOtqCFjUedLoZtb1hektBD32L1YxC43qnWKZr6VXc5E,4593
|
||||
0
pkgs/triton-2.1.0+corex.4.1.2.dist-info/REQUESTED
Normal file
0
pkgs/triton-2.1.0+corex.4.1.2.dist-info/REQUESTED
Normal file
5
pkgs/triton-2.1.0+corex.4.1.2.dist-info/WHEEL
Normal file
5
pkgs/triton-2.1.0+corex.4.1.2.dist-info/WHEEL
Normal file
@@ -0,0 +1,5 @@
|
||||
Wheel-Version: 1.0
|
||||
Generator: bdist_wheel (0.44.0)
|
||||
Root-Is-Purelib: false
|
||||
Tag: cp310-cp310-linux_x86_64
|
||||
|
||||
1
pkgs/triton-2.1.0+corex.4.1.2.dist-info/direct_url.json
Normal file
1
pkgs/triton-2.1.0+corex.4.1.2.dist-info/direct_url.json
Normal file
@@ -0,0 +1 @@
|
||||
{"archive_info": {"hash": "sha256=0b2734afe117d3c56cda0a7cd1f4752510dc91d8678638321dcbaee07110da40", "hashes": {"sha256": "0b2734afe117d3c56cda0a7cd1f4752510dc91d8678638321dcbaee07110da40"}}, "url": "file:///usr/local/apps_whl/triton-2.1.0%2Bcorex.4.1.2-cp310-cp310-linux_x86_64.whl"}
|
||||
15
pkgs/triton-2.1.0+corex.4.1.2.dist-info/top_level.txt
Normal file
15
pkgs/triton-2.1.0+corex.4.1.2.dist-info/top_level.txt
Normal file
@@ -0,0 +1,15 @@
|
||||
triton
|
||||
triton/_C
|
||||
triton/common
|
||||
triton/compiler
|
||||
triton/debugger
|
||||
triton/language
|
||||
triton/language/extra
|
||||
triton/ops
|
||||
triton/ops/blocksparse
|
||||
triton/runtime
|
||||
triton/runtime/backends
|
||||
triton/third_party/cuda/bin
|
||||
triton/third_party/cuda/include
|
||||
triton/third_party/cuda/lib
|
||||
triton/tools
|
||||
BIN
pkgs/triton/_C/libtriton.so
Executable file
BIN
pkgs/triton/_C/libtriton.so
Executable file
Binary file not shown.
68
pkgs/triton/__init__.py
Normal file
68
pkgs/triton/__init__.py
Normal file
@@ -0,0 +1,68 @@
|
||||
"""isort:skip_file"""
|
||||
__version__ = '2.1.0'
|
||||
|
||||
# ---------------------------------------
|
||||
# Note: import order is significant here.
|
||||
|
||||
# submodules
|
||||
from .runtime import (
|
||||
autotune,
|
||||
Config,
|
||||
heuristics,
|
||||
JITFunction,
|
||||
KernelInterface,
|
||||
reinterpret,
|
||||
TensorWrapper,
|
||||
OutOfResources,
|
||||
MockTensor,
|
||||
)
|
||||
from .runtime.jit import jit
|
||||
from .compiler import compile, CompilationError
|
||||
from .debugger.debugger import program_ids_from_grid
|
||||
|
||||
from . import language
|
||||
from . import testing
|
||||
|
||||
__all__ = [
|
||||
"autotune",
|
||||
"cdiv",
|
||||
"CompilationError",
|
||||
"compile",
|
||||
"Config",
|
||||
"heuristics",
|
||||
"impl",
|
||||
"jit",
|
||||
"JITFunction",
|
||||
"KernelInterface",
|
||||
"language",
|
||||
"MockTensor",
|
||||
"next_power_of_2",
|
||||
"ops",
|
||||
"OutOfResources",
|
||||
"reinterpret",
|
||||
"runtime",
|
||||
"TensorWrapper",
|
||||
"testing",
|
||||
"program_ids_from_grid",
|
||||
]
|
||||
|
||||
|
||||
# -------------------------------------
|
||||
# misc. utilities that don't fit well
|
||||
# into any specific module
|
||||
# -------------------------------------
|
||||
|
||||
def cdiv(x, y):
    """Ceiling integer division: the smallest q with q * y >= x (for y > 0)."""
    shifted = x + y - 1
    return shifted // y
|
||||
|
||||
|
||||
def next_power_of_2(n):
    """Return the smallest power of 2 greater than or equal to n.

    For n <= 0 this returns 0, matching the previous bit-twiddling
    implementation. Unlike that version — whose widest shift was
    ``n >> 16`` and which therefore only propagated bits through a
    32-bit word, returning wrong results for n > 2**32 — this works for
    arbitrarily large Python ints via ``int.bit_length()``.
    """
    if n <= 0:
        # Preserves the old behavior: 0 and negatives collapse to 0.
        return 0
    return 1 << (n - 1).bit_length()
|
||||
BIN
pkgs/triton/__pycache__/__init__.cpython-310.pyc
Normal file
BIN
pkgs/triton/__pycache__/__init__.cpython-310.pyc
Normal file
Binary file not shown.
BIN
pkgs/triton/__pycache__/testing.cpython-310.pyc
Normal file
BIN
pkgs/triton/__pycache__/testing.cpython-310.pyc
Normal file
Binary file not shown.
3
pkgs/triton/common/__init__.py
Normal file
3
pkgs/triton/common/__init__.py
Normal file
@@ -0,0 +1,3 @@
|
||||
from .build import _build
|
||||
|
||||
__all__ = ["_build"]
|
||||
BIN
pkgs/triton/common/__pycache__/__init__.cpython-310.pyc
Normal file
BIN
pkgs/triton/common/__pycache__/__init__.cpython-310.pyc
Normal file
Binary file not shown.
BIN
pkgs/triton/common/__pycache__/build.cpython-310.pyc
Normal file
BIN
pkgs/triton/common/__pycache__/build.cpython-310.pyc
Normal file
Binary file not shown.
137
pkgs/triton/common/build.py
Normal file
137
pkgs/triton/common/build.py
Normal file
@@ -0,0 +1,137 @@
|
||||
import contextlib
|
||||
import functools
|
||||
import io
|
||||
import os
|
||||
import shutil
|
||||
import subprocess
|
||||
import sys
|
||||
import sysconfig
|
||||
|
||||
import setuptools
|
||||
|
||||
|
||||
# TODO: is_hip shouldn't be here
|
||||
def is_hip():
    """Return True when the installed torch was built against ROCm/HIP."""
    import torch
    hip_version = torch.version.hip
    return hip_version is not None
|
||||
|
||||
|
||||
def is_corex():
    """Return True when this torch build carries the Iluvatar CoreX marker."""
    import torch
    return getattr(torch, "corex", None) == True
|
||||
|
||||
|
||||
@functools.lru_cache()
def cuda_home_dirs():
    """Resolve the CUDA installation root.

    Honors $CUDA_HOME when set; only when it is unset do we probe for
    clang++ via ``whereis`` and take its grandparent directory. The
    previous version always ran the ``whereis`` probe first and crashed
    with an IndexError when clang++ was absent — even when $CUDA_HOME
    was set and the probe result would have been discarded.

    Raises:
        RuntimeError: if $CUDA_HOME is unset and clang++ cannot be found.
    """
    cuda_home = os.getenv("CUDA_HOME")
    if cuda_home is not None:
        return cuda_home
    words = subprocess.check_output(["whereis", "clang++"]).decode().strip().split()
    if len(words) < 2:
        raise RuntimeError(
            "Cannot locate clang++ to infer the CUDA home; set CUDA_HOME.")
    return os.path.dirname(os.path.dirname(words[1]))
|
||||
|
||||
|
||||
@functools.lru_cache()
def libcuda_dirs():
    """Return the directories that ``whereis`` reports for libcuda.so.

    May be empty when no libcuda.so is installed; the first token of the
    ``whereis`` output (the query echo) is discarded.
    """
    probe = subprocess.check_output(["whereis", "libcuda.so"])
    hits = probe.decode().strip().split()[1:]
    return [os.path.dirname(hit) for hit in hits]
|
||||
|
||||
|
||||
@functools.lru_cache()
def rocm_path_dir():
    """ROCm install root: $ROCM_PATH when set, defaulting to /opt/rocm."""
    return os.environ.get("ROCM_PATH", "/opt/rocm")
|
||||
|
||||
|
||||
@contextlib.contextmanager
def quiet():
    """Silence stdout/stderr for the duration of the with-block.

    Output written inside the block goes to throwaway StringIO buffers;
    the real streams are restored even if the block raises.
    """
    saved_streams = (sys.stdout, sys.stderr)
    sys.stdout = io.StringIO()
    sys.stderr = io.StringIO()
    try:
        yield
    finally:
        sys.stdout, sys.stderr = saved_streams
|
||||
|
||||
|
||||
def _build(name, src, srcdir):
    """Compile the generated launcher stub ``src`` into a shared object.

    Tries a direct compiler invocation first and falls back to
    setuptools only if that invocation returns non-zero. Returns the
    path of the built ``.so`` inside ``srcdir``.

    NOTE(review): the indentation below is reconstructed from a
    formatting-mangled source; the cuda.h fallback section is assumed to
    belong to the non-CoreX branch (as in upstream Triton) — confirm
    against the original file.
    """
    # Pick include/library directories for the active backend.
    if is_hip():
        hip_lib_dir = os.path.join(rocm_path_dir(), "lib")
        hip_include_dir = os.path.join(rocm_path_dir(), "include")
    else:
        if is_corex():
            # Iluvatar CoreX: headers and libs live under the CUDA home.
            cuda_path = cuda_home_dirs()
            cu_include_dir = os.path.join(cuda_path, "include")
            cuda_lib_dirs = [os.path.join(cuda_path, "lib64")]
        else:
            cuda_lib_dirs = libcuda_dirs()
            base_dir = os.path.join(os.path.dirname(__file__), os.path.pardir)
            cuda_path = os.path.join(base_dir, "third_party", "cuda")

            cu_include_dir = os.path.join(cuda_path, "include")
            triton_include_dir = os.path.join(os.path.dirname(__file__), "include")
            cuda_header = os.path.join(cu_include_dir, "cuda.h")
            triton_cuda_header = os.path.join(triton_include_dir, "cuda.h")
            # Fall back to the in-tree cuda.h when the bundled one is missing.
            if not os.path.exists(cuda_header) and os.path.exists(triton_cuda_header):
                cu_include_dir = triton_include_dir

    suffix = sysconfig.get_config_var('EXT_SUFFIX')
    so = os.path.join(srcdir, '{name}{suffix}'.format(name=name, suffix=suffix))
    # try to avoid setuptools if possible
    cc = os.environ.get("CC")
    if cc is None:
        # TODO: support more things here.
        clang = shutil.which("clang")
        gcc = shutil.which("gcc")
        # CoreX prefers clang; everything else prefers gcc.
        if is_corex():
            cc = clang if clang is not None else gcc
        else:
            cc = gcc if gcc is not None else clang
        if cc is None:
            raise RuntimeError("Failed to find C compiler. Please specify via CC environment variable.")
    # This function was renamed and made public in Python 3.10
    if hasattr(sysconfig, 'get_default_scheme'):
        scheme = sysconfig.get_default_scheme()
    else:
        scheme = sysconfig._get_default_scheme()
    # 'posix_local' is a custom scheme on Debian. However, starting Python 3.10, the default install
    # path changes to include 'local'. This change is required to use triton with system-wide python.
    if scheme == 'posix_local':
        scheme = 'posix_prefix'
    py_include_dir = sysconfig.get_paths(scheme=scheme)["include"]

    # Direct compiler invocation; check_call raises on a non-zero exit,
    # so reaching the `ret == 0` test below implies success.
    if is_hip():
        ret = subprocess.check_call([cc, src, f"-I{hip_include_dir}", f"-I{py_include_dir}", f"-I{srcdir}", "-shared", "-fPIC", f"-L{hip_lib_dir}", "-lamdhip64", "-o", so])
    else:
        cc_cmd = [cc, src, "-O3", f"-I{cu_include_dir}", f"-I{py_include_dir}", f"-I{srcdir}", "-shared", "-fPIC", "-lcuda", "-o", so]
        cc_cmd += [f"-L{dir}" for dir in cuda_lib_dirs]
        ret = subprocess.check_call(cc_cmd)

    if ret == 0:
        return so
    # fallback on setuptools
    extra_compile_args = []
    library_dirs = cuda_lib_dirs
    include_dirs = [srcdir, cu_include_dir]
    libraries = ['cuda']
    # extra arguments
    extra_link_args = []
    # create extension module
    ext = setuptools.Extension(
        name=name,
        language='c',
        sources=[src],
        include_dirs=include_dirs,
        extra_compile_args=extra_compile_args + ['-O3'],
        extra_link_args=extra_link_args,
        library_dirs=library_dirs,
        libraries=libraries,
    )
    # build extension module
    args = ['build_ext']
    args.append('--build-temp=' + srcdir)
    args.append('--build-lib=' + srcdir)
    args.append('-q')
    args = dict(
        name=name,
        ext_modules=[ext],
        script_args=args,
    )
    with quiet():
        setuptools.setup(**args)
    return so
|
||||
4
pkgs/triton/compiler/__init__.py
Normal file
4
pkgs/triton/compiler/__init__.py
Normal file
@@ -0,0 +1,4 @@
|
||||
from .compiler import CompiledKernel, compile
|
||||
from .errors import CompilationError
|
||||
|
||||
__all__ = ["compile", "CompiledKernel", "CompilationError"]
|
||||
BIN
pkgs/triton/compiler/__pycache__/__init__.cpython-310.pyc
Normal file
BIN
pkgs/triton/compiler/__pycache__/__init__.cpython-310.pyc
Normal file
Binary file not shown.
BIN
pkgs/triton/compiler/__pycache__/code_generator.cpython-310.pyc
Normal file
BIN
pkgs/triton/compiler/__pycache__/code_generator.cpython-310.pyc
Normal file
Binary file not shown.
BIN
pkgs/triton/compiler/__pycache__/compiler.cpython-310.pyc
Normal file
BIN
pkgs/triton/compiler/__pycache__/compiler.cpython-310.pyc
Normal file
Binary file not shown.
BIN
pkgs/triton/compiler/__pycache__/errors.cpython-310.pyc
Normal file
BIN
pkgs/triton/compiler/__pycache__/errors.cpython-310.pyc
Normal file
Binary file not shown.
BIN
pkgs/triton/compiler/__pycache__/make_launcher.cpython-310.pyc
Normal file
BIN
pkgs/triton/compiler/__pycache__/make_launcher.cpython-310.pyc
Normal file
Binary file not shown.
1133
pkgs/triton/compiler/code_generator.py
Normal file
1133
pkgs/triton/compiler/code_generator.py
Normal file
File diff suppressed because it is too large
Load Diff
631
pkgs/triton/compiler/compiler.py
Normal file
631
pkgs/triton/compiler/compiler.py
Normal file
@@ -0,0 +1,631 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import functools
|
||||
import hashlib
|
||||
import json
|
||||
import os
|
||||
import re
|
||||
import subprocess
|
||||
import tempfile
|
||||
from collections import namedtuple
|
||||
from pathlib import Path
|
||||
from typing import Any, Tuple
|
||||
|
||||
import triton
|
||||
import triton._C.libtriton.triton as _triton
|
||||
from ..runtime import driver
|
||||
# TODO: runtime.errors
|
||||
from ..runtime.autotuner import OutOfResources
|
||||
from ..runtime.cache import get_cache_manager
|
||||
from ..tools.disasm import extract
|
||||
from .code_generator import ast_to_ttir
|
||||
from .make_launcher import make_stub
|
||||
|
||||
|
||||
def is_corex():
    """Return True when this torch build carries the Iluvatar CoreX marker."""
    import torch
    return getattr(torch, "corex", None) == True
|
||||
|
||||
CUDA_DEFAULT_WARP_SIZE = 64 if is_corex() else 32
|
||||
|
||||
def inline_triton_ir(mod):
    """Run the inliner pass over a Triton-IR module, mutating it in place."""
    manager = _triton.ir.pass_manager(mod.context)
    manager.enable_debug()
    manager.add_inliner_pass()
    manager.run(mod)
    return mod
|
||||
|
||||
|
||||
def ttir_compute_capability_rewrite(mod, arch):
    """Rewrite block-(tensor-)pointer loads/stores into tensors of pointers.

    Required on hardware without block-pointer support; the rewrite pass
    is only registered for CUDA targets (per ``_is_cuda``), but the pass
    manager always runs.
    """
    manager = _triton.ir.pass_manager(mod.context)
    manager.enable_debug()
    if _is_cuda(arch):
        manager.add_rewrite_tensor_pointer_pass(arch)
    manager.run(mod)
    return mod
|
||||
|
||||
|
||||
def optimize_ttir(mod, arch):
    """Standard TTIR cleanup pipeline.

    First inlines and applies the compute-capability rewrite, then runs
    inliner, combine, canonicalize, CSE, LICM and symbol-DCE in that
    exact order.
    """
    mod = inline_triton_ir(mod)
    mod = ttir_compute_capability_rewrite(mod, arch)
    manager = _triton.ir.pass_manager(mod.context)
    manager.enable_debug()
    # Registration order is the execution order — do not reorder.
    for register in (
        manager.add_inliner_pass,
        manager.add_triton_combine_pass,
        manager.add_canonicalizer_pass,
        manager.add_cse_pass,
        manager.add_licm_pass,
        manager.add_symbol_dce_pass,
    ):
        register()
    manager.run(mod)
    return mod
|
||||
|
||||
|
||||
def ttir_to_ttgir(mod, num_warps, warpsize):
    """Lower a Triton-IR module to the TritonGPU dialect."""
    manager = _triton.ir.pass_manager(mod.context)
    manager.enable_debug()
    manager.add_convert_triton_to_tritongpu_pass(num_warps, warpsize)
    manager.run(mod)
    return mod
|
||||
|
||||
|
||||
def optimize_ttgir(mod, num_stages, arch, use_sme = 0):
    """Run the TritonGPU optimization pipeline on ``mod`` and return it.

    Pass order matters; backend-specific passes are gated on the CUDA /
    HIP / CoreX checks below.

    NOTE(review): indentation is reconstructed from a formatting-mangled
    source; the remove-layout-conversions call after the SME-load pass
    is assumed to sit inside the ``is_corex()`` branch — confirm against
    the original file.
    """
    pm = _triton.ir.pass_manager(mod.context)
    pm.enable_debug()
    pm.add_tritongpu_coalesce_pass()
    pm.add_tritongpu_remove_layout_conversions_pass()
    if _is_cuda(arch):
        # `use_sme` is forwarded to the matmul acceleration pass.
        pm.add_tritongpu_accelerate_matmul_pass(arch, use_sme)
    # TODO change interface of accelerate_matmul_pass
    if is_hip() and gpu_has_mfma():
        # HIP path hard-codes capability 80 for MFMA-capable GPUs.
        pm.add_tritongpu_accelerate_matmul_pass(80)
    pm.add_tritongpu_remove_layout_conversions_pass()
    pm.add_tritongpu_optimize_dot_operands_pass()
    if is_corex():
        pm.add_tritongpu_matmul_smeload_pass(arch)  # BI 70 MR > 71, only MR support sme
        pm.add_tritongpu_remove_layout_conversions_pass()
    # TODO enable this pass for AMD GPU when it is ready
    if not is_hip():
        pm.add_tritongpu_pipeline_pass(num_stages)
    if not is_corex():
        pm.add_tritongpu_prefetch_pass()
    pm.add_tritongpu_optimize_dot_operands_pass()
    pm.add_tritongpu_remove_layout_conversions_pass()
    pm.add_tritongpu_decompose_conversions_pass()
    pm.add_tritongpu_reorder_instructions_pass()
    pm.add_cse_pass()
    pm.add_symbol_dce_pass()
    pm.run(mod)
    return mod
|
||||
|
||||
|
||||
def _add_external_libs(mod, libs):
|
||||
for name, path in libs.items():
|
||||
if len(name) == 0 or len(path) == 0:
|
||||
return
|
||||
_triton.add_external_libs(mod, list(libs.keys()), list(libs.values()))
|
||||
|
||||
|
||||
def ttgir_to_llir(mod, extern_libs, arch):
    """Lower TritonGPU IR to LLVM IR, registering extern libs first."""
    if extern_libs:
        _add_external_libs(mod, extern_libs)
    # TODO: separate tritongpu_to_llvmir for different backends
    if _is_cuda(arch):
        return _triton.translate_triton_gpu_to_llvmir(mod, arch, False)
    return _triton.translate_triton_gpu_to_llvmir(mod, 0, True)
|
||||
|
||||
|
||||
# PTX translation
|
||||
|
||||
@functools.lru_cache()
def ptx_get_version(cuda_version) -> int:
    """Return the highest PTX version supported by the given CUDA release.

    ``cuda_version`` is a '<major>.<minor>' string; each supported major
    release maps to a PTX base version to which the minor is added.

    Raises:
        RuntimeError: for CUDA releases older than 10.0.
    """
    assert isinstance(cuda_version, str)
    major, minor = map(int, cuda_version.split('.'))
    ptx_base_by_major = {12: 80, 11: 70, 10: 63}
    try:
        return ptx_base_by_major[major] + minor
    except KeyError:
        raise RuntimeError("Triton only support CUDA 10.0 or higher")
|
||||
|
||||
|
||||
@functools.lru_cache()
def path_to_ptxas():
    """Locate a usable ptxas binary and return ``(path, version_string)``.

    Checks ``$TRITON_PTXAS_PATH`` first, then the copy bundled under
    ``third_party/cuda/bin``. Raises RuntimeError if none responds to
    ``--version`` with a parseable release string.
    """
    base_dir = os.path.join(os.path.dirname(__file__), os.pardir)
    candidates = (
        os.environ.get("TRITON_PTXAS_PATH", ""),
        os.path.join(base_dir, "third_party", "cuda", "bin", "ptxas"),
    )

    for candidate in candidates:
        if not (os.path.exists(candidate) and os.path.isfile(candidate)):
            continue
        output = subprocess.check_output([candidate, "--version"], stderr=subprocess.STDOUT)
        if output is None:
            continue
        release = re.search(r".*release (\d+\.\d+).*", output.decode("utf-8"), flags=re.MULTILINE)
        if release is not None:
            return candidate, release.group(1)
    raise RuntimeError("Cannot find ptxas")
|
||||
|
||||
|
||||
def llir_to_cubin(mod: Any, arch: int):
    '''
    Lower an LLVM module directly to a cubin binary (CoreX path).

    :param mod: an LLVM module
    :param arch: target architecture
    :return: str
    '''
    cubin = _triton.translate_llvmir_to_cubin(mod, arch)
    return cubin
|
||||
|
||||
|
||||
def llir_to_ptx(mod: Any, arch: int, ptx_version: int = None) -> str:
    '''
    Translate an LLVM-IR module (from TritonGPU) into PTX text.

    :param mod: the LLVM module
    :param arch: target compute capability
    :param ptx_version: PTX ISA version; derived from the local ptxas if None
    :return: PTX code
    '''
    if ptx_version is None:
        # Infer the ISA version from whatever CUDA toolkit ships ptxas.
        _, cuda_version = path_to_ptxas()
        ptx_version = ptx_get_version(cuda_version)
    return _triton.translate_llvmir_to_ptx(mod, arch, ptx_version)
|
||||
|
||||
|
||||
def ptx_to_cubin(ptx: str, arch: int):
    '''
    Assemble PTX text into a cubin using the locally discovered ptxas.

    :param ptx: ptx code
    :param arch: compute capability to assemble for
    :return: str
    '''
    assembler, _ = path_to_ptxas()
    return _triton.compile_ptx_to_cubin(ptx, assembler, arch)
|
||||
|
||||
|
||||
# AMDGCN translation
|
||||
|
||||
def get_amdgcn_bitcode_paths(arch):
    """Collect the ROCm device bitcode libraries present on disk.

    Returns a dict ``{"library_0": path, ...}`` (keys numbered in discovery
    order) containing only the libraries that actually exist under
    ``third_party/rocm/lib/bitcode/``. The ISA-specific library for
    ``arch[1]`` (e.g. ``gfx906``) is checked last.
    """
    arch_agnostic_libs = [
        "opencl.bc",
        "ocml.bc",
        "ockl.bc",
        "oclc_finite_only_off.bc",
        "oclc_daz_opt_off.bc",
        "oclc_correctly_rounded_sqrt_on.bc",
        "oclc_unsafe_math_off.bc",
        "oclc_wavefrontsize64_on.bc",
        "oclc_abi_version_400.bc",
    ]

    gfx_arch = arch[1]
    gfx_arch_id = re.search('gfx(\\w+)', gfx_arch).group(1).strip()
    arch_specific_lib = 'oclc_isa_version_' + gfx_arch_id + ".bc"
    bitcode_dir = os.path.join(Path(__file__).parent.parent.resolve(), "third_party/rocm/lib/bitcode/")

    found = {}
    for lib in arch_agnostic_libs + [arch_specific_lib]:
        candidate = bitcode_dir + lib
        if os.path.exists(candidate):
            found['library_' + str(len(found))] = candidate
    return found
|
||||
|
||||
|
||||
def get_amdgpu_arch_fulldetails():
    """
    Get the full AMDGPU ISA details needed for compilation.

    Returns ``[arch_triple, arch_name, arch_features, warpsize]``, e.g.
    ``['amdgcn-amd-amdhsa', 'gfx906', '+sramecc,-xnack', 64]``, or None when
    the information cannot be obtained from the driver.
    """
    try:
        # TODO: package rocm.cc with Triton
        arch_info = _triton.get_arch_info()
        warpsize = _triton.get_warp_size()
        gfx_arch_details = re.search('amd.*', arch_info).group(0).strip().split('--')
        arch_triple = gfx_arch_details[0]
        arch_name_features = gfx_arch_details[1].split(':')
        arch_name = arch_name_features[0]
        arch_features = ""

        # The arch string looks like "gfx90a:sramecc+:xnack-"; translate the
        # two trailing feature flags into the "+feat,-feat" form.
        if (len(arch_name_features) == 3):
            arch_features = "+" + re.search('\\w+', arch_name_features[1]).group(0) + ","\
                "-" + re.search('\\w+', arch_name_features[2]).group(0)
        return [arch_triple, arch_name, arch_features, warpsize]
    except Exception:
        # BUGFIX: was `except BaseException`, which also swallowed
        # KeyboardInterrupt/SystemExit. Exception is broad enough to cover
        # the "driver not available" case this fallback is meant for.
        return None
|
||||
|
||||
|
||||
def llir_to_amdgcn_and_hsaco(mod: Any, gfx_arch: str, gfx_triple: str, gfx_features: str) -> Tuple[str, str]:
    '''
    Lower an LLVM module to AMDGCN assembly plus an on-disk HSACO object,
    using the full GPU architecture description.

    :param mod: a TritonGPU dialect module
    :return:
        - AMDGCN code
        - Path to HSACO object
    '''
    return _triton.translate_llvmir_to_hsaco(mod, gfx_arch, gfx_triple, gfx_features)
|
||||
|
||||
|
||||
# ------------------------------------------------------------------------------
|
||||
# compiler
|
||||
# ------------------------------------------------------------------------------
|
||||
def get_kernel_name(src: str, pattern: str, llir: bool = False) -> str:
    '''
    Extract the (mangled) kernel symbol from generated PTX or LLVM-IR text.

    PTX codegen mangles names, so the original Triton IR name is not usable
    for launching; this recovers the real symbol. Scans for the first line
    starting with *pattern*: for PTX the symbol is the last token, for LLIR
    (``llir=True``) it is the ``@name`` before the argument list.
    Returns None implicitly when no line matches.
    '''
    assert src
    for raw_line in src.split('\n'):
        stripped = raw_line.strip()
        if not stripped.startswith(pattern):
            continue
        if llir:
            return stripped.split("(")[0].split("@")[-1]
        return stripped.split()[-1]
|
||||
|
||||
|
||||
def convert_type_repr(x):
    """Convert an MLIR type string to Triton's repr: ``!tt.ptr<T>`` -> ``*T``.

    Recurses so nested pointers become repeated ``*``; non-pointer types pass
    through unchanged.
    """
    pointee = re.search(r'!tt\.ptr<(.*)>', x)
    if pointee is None:
        return x
    return '*' + convert_type_repr(pointee.group(1))
|
||||
|
||||
|
||||
def make_hash(fn, arch, **kwargs):
    """Produce a stable md5 cache key for a compilation request.

    JIT functions hash their source key plus every option that affects
    codegen (signature, specializations, warps/stages, arch, sme); IR files
    hash their contents plus the runtime version.
    """
    if not isinstance(fn, triton.runtime.JITFunction):
        assert isinstance(fn, str)
        return hashlib.md5((Path(fn).read_text() + triton.runtime.jit.version_key()).encode("utf-8")).hexdigest()

    configs = kwargs["configs"]
    signature = kwargs["signature"]
    constants = kwargs.get("constants", dict())
    num_warps = kwargs.get("num_warps", 4)
    num_stages = kwargs.get("num_stages", 3)
    debug = kwargs.get("debug", False)
    use_sme = kwargs.get("use_sme", 0)
    # Specialization sets are sorted so the key is order-independent.
    get_conf_key = lambda conf: (sorted(conf.divisible_by_16), sorted(conf.equal_to_1))
    configs_key = [get_conf_key(conf) for conf in configs]
    key = f"{fn.cache_key}-{''.join(signature.values())}-{configs_key}-{constants}-{num_warps}-{num_stages}-{debug}-{arch}-{use_sme}"
    return hashlib.md5(key.encode("utf-8")).hexdigest()
|
||||
|
||||
|
||||
# Regexes used to recover a kernel's name and argument types when compiling
# straight from an IR file (see `compile` when `fn` is a path).
#
# - ^\s*tt\.func\s+ : match the start of the string, any leading whitespace, the keyword func,
#   and any following whitespace
# - (public\s+)? : optionally match the keyword public and any following whitespace
# - (@\w+) : match an @ symbol followed by one or more word characters
#   (letters, digits, or underscores), and capture it as group 1 (the function name)
# - (\((?:%\w+: \S+(?: \{\S+ = \S+ : \S+\})?(?:, )?)*\)) : match a pair of parentheses enclosing
#   zero or more arguments separated by commas, and capture it as group 2 (the argument list)
mlir_prototype_pattern = r'^\s*tt\.func\s+(?:public\s+)?(@\w+)(\((?:%\w+: \S+(?: \{\S+ = \S+ : \S+\})?(?:, )?)*\))\s*\{\s*$'
# PTX entry points: `.visible .entry name(...)` or `.extern .func name(...)`.
ptx_prototype_pattern = r"\.(?:visible|extern)\s+\.(?:entry|func)\s+(\w+)\s*\(([^)]*)\)"
# File-extension -> prototype regex (ttir and ttgir share the MLIR form).
prototype_pattern = {
    "ttir": mlir_prototype_pattern,
    "ttgir": mlir_prototype_pattern,
    "ptx": ptx_prototype_pattern,
}

# Per-argument type extraction from the captured argument list above.
mlir_arg_type_pattern = r'%\w+: ([^,^\)\s]+)(?: \{\S+ = \S+ : \S+\})?,?'
ptx_arg_type_pattern = r"\.param\s+\.(\w+)"
arg_type_pattern = {
    "ttir": mlir_arg_type_pattern,
    "ttgir": mlir_arg_type_pattern,
    "ptx": ptx_arg_type_pattern,
}

# Extracts the "triton_gpu.num-warps" module attribute from ttgir text.
ttgir_num_warps_pattern = r'"triton_gpu.num-warps"\s?=\s?(\d+)\s?:'
|
||||
|
||||
|
||||
def _get_jsonable_constants(constants):
|
||||
def _is_jsonable(x):
|
||||
try:
|
||||
json.dumps(x)
|
||||
return True
|
||||
except (TypeError, OverflowError):
|
||||
return False
|
||||
serialized_constants = {}
|
||||
for constant in constants:
|
||||
if _is_jsonable(constants[constant]):
|
||||
serialized_constants[constant] = constants[constant]
|
||||
return serialized_constants
|
||||
|
||||
|
||||
def parse_mlir_module(path, context):
    """Parse an MLIR file and bind the resulting module to *context*."""
    module = _triton.ir.parse_mlir_module(path, context)
    # The module takes ownership of the context, so keep a reference on it.
    module.context = context
    return module
|
||||
|
||||
|
||||
# Per-call argument specialization: `divisible_by_16` / `equal_to_1` hold the
# indices of arguments known to have those properties (both participate in the
# cache key via make_hash). NOTE(review): the two default `set()` objects are
# created once and shared by every descriptor built from defaults — safe only
# as long as they are never mutated.
instance_descriptor = namedtuple("instance_descriptor", ["divisible_by_16", "equal_to_1"], defaults=[set(), set()])
|
||||
|
||||
|
||||
# TODO: architecture descriptor class
|
||||
def _is_cuda(arch):
|
||||
return isinstance(arch, int)
|
||||
|
||||
def is_hip():
    """Return True iff the installed PyTorch is a ROCm (HIP) build."""
    try:
        import torch
    except ImportError:
        raise ImportError("Triton requires PyTorch to be installed")
    hip_version = torch.version.hip
    return hip_version is not None
|
||||
|
||||
from ..language.semantic import gpu_has_mfma
|
||||
|
||||
def get_architecture_descriptor(capability):
    """Resolve the target architecture descriptor.

    A non-None *capability* is passed through untouched. Otherwise, CUDA
    builds query the current device's compute capability (returned as an int,
    e.g. 80), while ROCm builds return the full AMD arch details list.
    """
    try:
        import torch
    except ImportError:
        raise ImportError("Triton requires PyTorch to be installed")
    if capability is not None:
        return capability
    if torch.version.hip is not None:
        return get_amdgpu_arch_fulldetails()
    device = triton.runtime.jit.get_current_device()
    major_minor = triton.runtime.jit.get_device_capability(device)
    return major_minor[0] * 10 + major_minor[1]
|
||||
|
||||
|
||||
def add_rocm_stages(arch, extern_libs, stages):
    """Register the AMD backend stage (llir -> amdgcn + hsaco) in *stages*,
    adding the ROCm device bitcode libraries to *extern_libs*."""
    extern_libs.update(get_amdgcn_bitcode_paths(arch))

    # Drop placeholder entries so only real library paths get linked.
    for lib_name, lib_path in list(extern_libs.items()):
        if lib_path == '' or lib_path is None:
            extern_libs.pop(lib_name)

    gfx_arch_full_details = arch
    # The target GPU can be overridden via the environment.
    gfx_arch = os.environ.get('MI_GPU_ARCH', gfx_arch_full_details[1])
    if gfx_arch is None:
        raise RuntimeError('gfx_arch is None (not specified)')
    stages["amdgcn"] = (lambda path: Path(path).read_text(),
                        lambda src: llir_to_amdgcn_and_hsaco(src, gfx_arch,
                                                             gfx_arch_full_details[0],
                                                             gfx_arch_full_details[2]))
|
||||
|
||||
|
||||
def add_cuda_stages(arch, extern_libs, stages):
    """Register the NVIDIA backend stages (llir -> ptx -> cubin) in *stages*."""
    read_text = lambda path: Path(path).read_text()
    read_bytes = lambda path: Path(path).read_bytes()
    stages["ptx"] = (read_text, lambda src: llir_to_ptx(src, arch))
    stages["cubin"] = (read_bytes, lambda src: ptx_to_cubin(src, arch))
|
||||
|
||||
|
||||
def add_iluvatar_stages(arch, extern_libs, stages):
    """Register the Iluvatar CoreX backend stage (llir -> cubin, no PTX) in *stages*."""
    read_bytes = lambda path: Path(path).read_bytes()
    stages["cubin"] = (read_bytes, lambda src: llir_to_cubin(src, arch))
|
||||
|
||||
|
||||
def compile(fn, **kwargs):
    """Compile *fn* down to a launchable kernel.

    *fn* is either a ``triton.runtime.JITFunction`` (compiled from its AST)
    or a path to an IR file named ``<name>.<ttir|ttgir|llir|ptx>``, in which
    case the pipeline starts from that stage. Each stage's artifact and the
    kernel metadata are cached; cached artifacts are re-parsed instead of
    recompiled.

    Returns a CompiledKernel wrapping the launcher stub, metadata and asm.
    """
    # ---- resolve the target architecture ---------------------------------
    if is_hip():
        capability = None
    else:
        capability = kwargs.get("cc", None)
    arch = get_architecture_descriptor(capability)
    is_cuda = _is_cuda(arch)
    warp_size = CUDA_DEFAULT_WARP_SIZE if _is_cuda(arch) else arch[3]
    context = _triton.ir.context()
    asm = dict()
    constants = kwargs.get("constants", dict())
    num_warps = kwargs.get("num_warps", 4)
    num_stages = kwargs.get("num_stages", 3 if is_cuda and arch >= 75 else 2)
    extern_libs = kwargs.get("extern_libs", dict())
    use_sme = kwargs.get("use_sme", 0)
    if extern_libs is None:
        extern_libs = dict()
    debug = kwargs.get("debug", False)

    # ---- build compilation stages -----------------------------------------
    # Each entry maps stage name -> (parse-from-cached-file, compile-from-prev).
    # NOTE: the ttir lambda closes over `signature`/`configs`, which are only
    # assigned below — safe because the lambdas run later, after assignment.
    stages = dict()
    stages["ast"] = (lambda path: fn, None)
    stages["ttir"] = (lambda path: parse_mlir_module(path, context),
                      lambda src: optimize_ttir(ast_to_ttir(src, signature, configs[0], constants, debug=debug), arch))
    stages["ttgir"] = (lambda path: parse_mlir_module(path, context),
                       lambda src: optimize_ttgir(ttir_to_ttgir(src, num_warps, warp_size), num_stages, arch, use_sme))
    stages["llir"] = (lambda path: Path(path).read_text(),
                      lambda src: ttgir_to_llir(src, extern_libs, arch))
    if is_cuda:
        if is_corex():
            add_iluvatar_stages(arch, extern_libs, stages)
        else:
            add_cuda_stages(arch, extern_libs, stages)
    else:
        add_rocm_stages(arch, extern_libs, stages)

    # ---- find out the signature of the function ---------------------------
    if isinstance(fn, triton.runtime.JITFunction):
        configs = kwargs.get("configs", None)
        signature = kwargs["signature"]
        if configs is None:
            configs = [instance_descriptor()]
        assert len(configs) == 1
        kwargs["configs"] = configs
        name = fn.__name__
        first_stage = 0
        if isinstance(signature, str):
            signature = {k: v.strip() for k, v in enumerate(signature.split(","))}
        kwargs["signature"] = signature
    else:
        # Compiling from an IR file: recover name/signature from its text.
        assert isinstance(fn, str)
        _, ir = os.path.basename(fn).split(".")
        src = Path(fn).read_text()
        import re
        match = re.search(prototype_pattern[ir], src, re.MULTILINE)
        name, signature = match.group(1), match.group(2)
        types = re.findall(arg_type_pattern[ir], signature)
        if ir == 'ttgir':
            num_warps_matches = re.findall(ttgir_num_warps_pattern, src)
            assert len(num_warps_matches) == 1, "Expected exactly one match for num_warps"
            assert "num_warps" not in kwargs or int(num_warps_matches[0]) == num_warps, "num_warps in ttgir does not match num_warps in compile"
            num_warps = int(num_warps_matches[0])
        param_tys = [convert_type_repr(ty) for ty in types]
        signature = {k: v for k, v in enumerate(param_tys)}
        first_stage = list(stages.keys()).index(ir)

    # ---- cache managers ----------------------------------------------------
    so_path = make_stub(name, signature, constants)
    fn_cache_manager = get_cache_manager(make_hash(fn, arch, **kwargs))
    # determine name and extension type of provided function
    if isinstance(fn, triton.runtime.JITFunction):
        name, ext = fn.__name__, "ast"
    else:
        name, ext = os.path.basename(fn).split(".")

    # ---- load metadata if any ----------------------------------------------
    metadata = None
    metadata_filename = f"{name}.json"
    # The whole artifact group is addressed by the metadata file.
    metadata_group = fn_cache_manager.get_group(
        metadata_filename
    ) or {}
    metadata_path = metadata_group.get(metadata_filename)
    if metadata_path is not None:
        with open(metadata_path) as f:
            metadata = json.load(f)
    else:
        metadata = {"num_warps": num_warps,
                    "warp_size": warp_size,
                    "num_stages": num_stages,
                    "constants": _get_jsonable_constants(constants),
                    "debug": debug,
                    "use_sme": use_sme}
        if ext == "ptx":
            assert "shared" in kwargs, "ptx compilation must provide shared memory size"
            metadata["shared"] = kwargs["shared"]

    first_stage = list(stages.keys()).index(ext)
    asm = dict()
    module = fn
    # ---- run the pipeline and populate metadata ----------------------------
    for ir, (parse, compile_kernel) in list(stages.items())[first_stage:]:
        ir_filename = f"{name}.{ir}"
        if ir == ext:
            next_module = parse(fn)
        else:
            path = metadata_group.get(ir_filename)
            if path is None:
                next_module = compile_kernel(module)
                if ir == "amdgcn":
                    # amdgcn yields (asm text, hsaco path): cache both parts.
                    extra_file_name = f"{name}.hsaco_path"
                    metadata_group[ir_filename] = fn_cache_manager.put(next_module[0], ir_filename)
                    metadata_group[extra_file_name] = fn_cache_manager.put(next_module[1], extra_file_name)
                else:
                    # BUGFIX: the artifact was previously written to the cache
                    # twice in a row (a second bare put() followed this line);
                    # a single put is sufficient.
                    metadata_group[ir_filename] = fn_cache_manager.put(next_module, ir_filename)
            else:
                if ir == "amdgcn":
                    extra_file_name = f"{name}.hsaco_path"
                    hasco_path = metadata_group.get(extra_file_name)
                    assert hasco_path is not None, "Expected to have hsaco_path in metadata when we have the amdgcn"
                    next_module = (parse(path), parse(hasco_path))
                else:
                    next_module = parse(path)

        if ir == "llir":
            # CoreX launches straight from LLVM IR, so grab the symbol here;
            # on other backends this pattern matches nothing and the name is
            # overwritten by the ptx/amdgcn stages below.
            metadata["name"] = get_kernel_name(next_module, pattern="define iluvatar_kernel void", llir=True)
        if ir == "cubin":
            asm[ir] = next_module
        elif ir == "amdgcn":
            asm[ir] = str(next_module[0])
        else:
            asm[ir] = str(next_module)
        if ir == "llir" and "shared" not in metadata:
            metadata["shared"] = _triton.get_shared_memory_size(module)
        if ir == "ptx":
            metadata["name"] = get_kernel_name(next_module, pattern='// .globl')
        if ir == "amdgcn":
            metadata["name"] = get_kernel_name(next_module[0], pattern='.globl')
            asm["hsaco_path"] = next_module[1]
        module = next_module
    # write-back metadata, if it didn't come from the cache
    if metadata_path is None:
        metadata_group[metadata_filename] = fn_cache_manager.put(json.dumps(metadata), metadata_filename, binary=False)
        fn_cache_manager.put_group(metadata_filename, metadata_group)

    # return handle to compiled kernel
    return CompiledKernel(fn, so_path, metadata, asm)
|
||||
|
||||
|
||||
class CompiledKernel:
    """Handle to a fully compiled Triton kernel.

    Wraps the generated launcher stub (.so), the compiled binary
    (cubin/hsaco) and its metadata. Device-side handles are created lazily
    on first launch/attribute access rather than at construction time.
    """

    # Hooks for external tools to monitor the execution of triton kernels
    launch_enter_hook = None
    launch_exit_hook = None

    def __init__(self, fn, so_path, metadata, asm):
        # initialize launcher: import the compiled C stub .so as a module
        import importlib.util
        spec = importlib.util.spec_from_file_location("__triton_launcher", so_path)
        mod = importlib.util.module_from_spec(spec)
        self.fn = fn
        spec.loader.exec_module(mod)
        self.c_wrapper = getattr(mod, "launch")
        # initialize metadata
        self.shared = metadata["shared"]
        self.num_warps = metadata["num_warps"]
        self.warp_size = metadata["warp_size"]
        self.num_stages = metadata["num_stages"]
        self.constants = metadata["constants"]
        # initialize asm dict
        self.asm = asm
        # binaries are lazily initialized
        # because it involves doing runtime things
        # (e.g., checking amount of shared memory on current device)
        self.metadata = metadata
        self.cu_module = None
        self.cu_function = None

    def _init_handles(self):
        """Load the compiled binary onto the current device (idempotent)."""
        if self.cu_module is not None:
            return
        device = triton.runtime.jit.get_current_device()
        # Pick the artifact matching the active driver backend.
        bin_path = {
            driver.HIP: "hsaco_path",
            driver.CUDA: "cubin"
        }[driver.backend]
        max_shared = driver.utils.get_device_properties(device)["max_shared_mem"]
        if self.shared > max_shared:
            raise OutOfResources(self.shared, max_shared, "shared memory")
        mod, func, n_regs, n_spills = driver.utils.load_binary(self.metadata["name"], self.asm[bin_path], self.shared, device)

        self.n_spills = n_spills
        self.n_regs = n_regs
        self.cu_module = mod
        self.cu_function = func

    def __getattribute__(self, name):
        # Reading the launcher forces device handles to exist first.
        if name == 'c_wrapper':
            self._init_handles()
        return super().__getattribute__(name)

    def __getitem__(self, grid):
        """Return a launcher bound to *grid*: ``kernel[grid](*args, stream=...)``."""
        self._init_handles()

        def runner(*args, stream=None):
            if stream is None:
                stream = triton.runtime.jit.get_cuda_stream()
            self.c_wrapper(grid[0], grid[1], grid[2], self.num_warps, self.shared, stream, self.cu_function,
                           CompiledKernel.launch_enter_hook, CompiledKernel.launch_exit_hook, self, *args)
        return runner

    def get_sass(self, fun=None):
        """Disassemble the cubin to SASS, caching the result in ``self.asm['sass']``."""
        if 'sass' in self.asm:
            return self.asm['sass']
        fd, path = tempfile.mkstemp()
        try:
            # `extract` needs the cubin on disk; write it to a temp file.
            with open(fd, 'wb') as cubin:
                cubin.write(self.asm['cubin'])
            self.sass = extract(path, fun)
        finally:
            os.remove(path)
        self.asm['sass'] = self.sass
        return self.sass
|
||||
52
pkgs/triton/compiler/errors.py
Normal file
52
pkgs/triton/compiler/errors.py
Normal file
@@ -0,0 +1,52 @@
|
||||
import ast
|
||||
from typing import Optional, Union
|
||||
|
||||
|
||||
class CompilationError(Exception):
    """Raised when Triton fails to compile a kernel from its Python AST.

    The message embeds an excerpt of the offending source ending at the
    failing line, with a caret pointing at the failing column.
    """
    # Cap on how many source lines are echoed back in the message.
    source_line_count_max_in_message = 12

    def _format_message(self) -> str:
        """Render 'at line:col: <excerpt>' (plus the error message, if any)."""
        node = self.node
        if self.src is None:
            source_excerpt = " <source unavailable>"
        else:
            # Keep up to N lines ending at the node's line, then add a caret
            # aligned under the failing column.
            source_excerpt = self.src.split('\n')[:node.lineno][-self.source_line_count_max_in_message:]
            if source_excerpt:
                source_excerpt.append(' ' * node.col_offset + '^')
                source_excerpt = '\n'.join(source_excerpt)
            else:
                source_excerpt = " <source empty>"

        message = "at {}:{}:{}".format(node.lineno, node.col_offset, source_excerpt)
        if self.error_message:
            message += '\n' + self.error_message
        return message

    def __init__(self, src: Optional[str], node: ast.AST, error_message: Union[str, None]):
        self.src = src
        self.node = node
        self.error_message = error_message
        self.message = self._format_message()

    def set_source_code(self, src: Optional[str]):
        """Attach (or replace) the source text and re-render the message."""
        self.src = src
        self.message = self._format_message()

    def __str__(self):
        return self.message

    def __repr__(self):
        return "{}({!r})".format(type(self).__name__, self.message)

    def __reduce__(self):
        # this is necessary to make CompilationError picklable
        return type(self), (self.src, self.node, self.error_message)
|
||||
|
||||
|
||||
class CompileTimeAssertionFailure(CompilationError):
    """Specific exception for failed tests in `static_assert` invocations"""
    pass
|
||||
|
||||
|
||||
class UnsupportedLanguageConstruct(CompilationError):
    # Raised for Python language constructs the Triton frontend does not
    # support (carries the same location/excerpt info as CompilationError).
    pass
|
||||
392
pkgs/triton/compiler/make_launcher.py
Normal file
392
pkgs/triton/compiler/make_launcher.py
Normal file
@@ -0,0 +1,392 @@
|
||||
import hashlib
|
||||
import os
|
||||
import tempfile
|
||||
|
||||
from ..common import _build
|
||||
from ..runtime.cache import get_cache_manager
|
||||
from ..runtime.jit import version_key
|
||||
|
||||
|
||||
def is_hip():
    """True when the installed PyTorch targets AMD ROCm (HIP)."""
    import torch
    hip_version = torch.version.hip
    return hip_version is not None
|
||||
|
||||
|
||||
def is_corex():
    """True when the installed PyTorch is an Iluvatar CoreX build
    (exposes a truthy ``torch.corex`` flag)."""
    import torch
    # Keep the original `== True` comparison: the flag may be a non-bool
    # sentinel that compares equal to True.
    return getattr(torch, "corex", None) == True  # noqa: E712
|
||||
|
||||
|
||||
# ----- stub --------
|
||||
|
||||
|
||||
def make_so_cache_key(version_hash, signature, constants):
    """Derive the cache key (md5 hex digest) for a launcher stub .so.

    All pointer types collapse to 'ptr': the stub only cares whether an
    argument is a pointer, not what it points to.
    """
    normalized = {k: 'ptr' if v[0] == '*' else v for k, v in signature.items()}
    raw_key = f"{version_hash}-{''.join(normalized.values())}{constants}"
    return hashlib.md5(raw_key.encode("utf-8")).hexdigest()
||||
|
||||
|
||||
def make_stub(name, signature, constants):
    """Build (or fetch from cache) the C launcher stub .so for a kernel.

    Returns the path to the cached shared object.
    """
    # name of files that are cached
    so_cache_key = make_so_cache_key(version_key(), signature, constants)
    so_cache_manager = get_cache_manager(so_cache_key)
    so_name = f"{name}.so"
    # retrieve stub from cache if it exists
    cached_so = so_cache_manager.get_file(so_name)
    if cached_so is not None:
        return cached_so
    # Cache miss: generate the launcher C source, compile it in a scratch
    # directory, and persist both the source and the shared object.
    with tempfile.TemporaryDirectory() as tmpdir:
        src = generate_launcher(constants, signature)
        src_path = os.path.join(tmpdir, "main.c")
        with open(src_path, "w") as f:
            f.write(src)
        so = _build(name, src_path, tmpdir)
        so_cache_manager.put(src, f"{name}.c", binary=False)
        with open(so, "rb") as f:
            return so_cache_manager.put(f.read(), so_name, binary=True)
|
||||
|
||||
# ----- source code generation --------
|
||||
|
||||
|
||||
def ty_to_cpp(ty):
    """Map a Triton signature type string to the C type used in the launcher stub."""
    if ty[0] == '*':
        # Pointer arguments are passed as raw device addresses.
        return "hipDeviceptr_t" if is_hip() else "CUdeviceptr"
    # Scalars; note fp16/bf16 map to C float here.
    cpp_by_triton_ty = {
        "i1": "int32_t",
        "i8": "int8_t",
        "i16": "int16_t",
        "i32": "int32_t",
        "i64": "int64_t",
        "u32": "uint32_t",
        "u64": "uint64_t",
        "fp16": "float",
        "bf16": "float",
        "fp32": "float",
        "f32": "float",
        "fp64": "double",
    }
    return cpp_by_triton_ty[ty]
|
||||
|
||||
|
||||
def generate_launcher(constants, signature):
|
||||
arg_decls = ', '.join(f"{ty_to_cpp(ty)} arg{i}" for i, ty in signature.items())
|
||||
|
||||
def _extracted_type(ty):
|
||||
if ty[0] == '*':
|
||||
return "PyObject*"
|
||||
return {
|
||||
'i1': 'int32_t',
|
||||
'i32': 'int32_t',
|
||||
'i64': 'int64_t',
|
||||
'u32': 'uint32_t',
|
||||
'u64': 'uint64_t',
|
||||
'fp16': 'float',
|
||||
'bf16': 'float',
|
||||
'fp32': 'float',
|
||||
'f32': 'float',
|
||||
'fp64': 'double',
|
||||
}[ty]
|
||||
|
||||
def format_of(ty):
|
||||
return {
|
||||
"PyObject*": "O",
|
||||
"float": "f",
|
||||
"double": "d",
|
||||
"long": "l",
|
||||
"uint32_t": "I",
|
||||
"int32_t": "i",
|
||||
"uint64_t": "K",
|
||||
"int64_t": "L",
|
||||
}[ty]
|
||||
|
||||
format = "iiiiiKKOOO" + ''.join([format_of(_extracted_type(ty)) for ty in signature.values()])
|
||||
|
||||
# generate glue code
|
||||
if is_hip():
|
||||
src = f"""
|
||||
#define __HIP_PLATFORM_AMD__
|
||||
#include <hip/hip_runtime.h>
|
||||
#include <Python.h>
|
||||
#include <stdio.h>
|
||||
|
||||
static inline void gpuAssert(hipError_t code, const char *file, int line)
|
||||
{{
|
||||
if (code != HIP_SUCCESS)
|
||||
{{
|
||||
const char* prefix = "Triton Error [HIP]: ";
|
||||
const char* str = hipGetErrorString(code);
|
||||
char err[1024] = {{0}};
|
||||
snprintf(err, 1024, "%s Code: %d, Messsage: %s", prefix, code, str );
|
||||
PyErr_SetString(PyExc_RuntimeError, err);
|
||||
}}
|
||||
}}
|
||||
|
||||
#define HIP_CHECK(ans) {{ gpuAssert((ans), __FILE__, __LINE__); }}
|
||||
|
||||
static int getWarpSize(hipStream_t stream)
|
||||
{{
|
||||
int device_id = hipGetStreamDeviceId(stream);
|
||||
gpuAssert(device_id >= 0 ? hipSuccess : hipErrorInvalidDevice, __FILE__, __LINE__);
|
||||
hipDeviceProp_t prop;
|
||||
HIP_CHECK(hipGetDeviceProperties(&prop, device_id));
|
||||
return prop.warpSize;
|
||||
}}
|
||||
|
||||
static void _launch(int gridX, int gridY, int gridZ, int num_warps, int shared_memory, hipStream_t stream, hipFunction_t function, {arg_decls}) {{
|
||||
void *params[] = {{ {', '.join(f"&arg{i}" for i in signature.keys() if i not in constants)} }};
|
||||
if (gridX*gridY*gridZ > 0) {{
|
||||
int warp_size = getWarpSize(stream);
|
||||
HIP_CHECK(hipModuleLaunchKernel(function, gridX, gridY, gridZ, num_warps * warp_size, 1, 1, shared_memory, stream, params, 0));
|
||||
}}
|
||||
}}
|
||||
|
||||
typedef struct _DevicePtrInfo {{
|
||||
hipDeviceptr_t dev_ptr;
|
||||
bool valid;
|
||||
}} DevicePtrInfo;
|
||||
|
||||
static inline DevicePtrInfo getPointer(PyObject *obj, int idx) {{
|
||||
DevicePtrInfo ptr_info;
|
||||
ptr_info.dev_ptr = 0;
|
||||
ptr_info.valid = true;
|
||||
|
||||
if (PyLong_Check(obj)) {{
|
||||
ptr_info.dev_ptr = (hipDeviceptr_t)PyLong_AsUnsignedLongLong(obj);
|
||||
return ptr_info;
|
||||
}}
|
||||
|
||||
if (obj == Py_None) {{
|
||||
// valid nullptr
|
||||
return ptr_info;
|
||||
}}
|
||||
|
||||
PyObject *ptr = PyObject_GetAttrString(obj, "data_ptr");
|
||||
|
||||
if (ptr) {{
|
||||
PyObject *empty_tuple = PyTuple_New(0);
|
||||
PyObject *ret = PyObject_Call(ptr, empty_tuple, NULL);
|
||||
Py_DECREF(empty_tuple);
|
||||
Py_DECREF(ptr);
|
||||
|
||||
if (!PyLong_Check(ret)) {{
|
||||
PyErr_SetString(PyExc_TypeError, "data_ptr method of Pointer object must return 64-bit int");
|
||||
ptr_info.valid = false;
|
||||
return ptr_info;
|
||||
}}
|
||||
|
||||
ptr_info.dev_ptr = (hipDeviceptr_t)PyLong_AsUnsignedLongLong(ret);
|
||||
|
||||
if (!ptr_info.dev_ptr)
|
||||
return ptr_info;
|
||||
|
||||
uint64_t dev_ptr;
|
||||
hipError_t status = hipPointerGetAttribute(&dev_ptr, HIP_POINTER_ATTRIBUTE_DEVICE_POINTER, ptr_info.dev_ptr);
|
||||
if (status == hipErrorInvalidValue) {{
|
||||
PyErr_Format(PyExc_ValueError,
|
||||
"Pointer argument (at %d) cannot be accessed from Triton (cpu tensor?)", idx);
|
||||
ptr_info.valid = false;
|
||||
}}
|
||||
|
||||
ptr_info.dev_ptr = (hipDeviceptr_t)dev_ptr;
|
||||
return ptr_info;
|
||||
}}
|
||||
|
||||
PyErr_SetString(PyExc_TypeError, "Pointer argument must be either uint64 or have data_ptr method");
|
||||
return ptr_info;
|
||||
}}
|
||||
|
||||
static PyObject* launch(PyObject* self, PyObject* args) {{
|
||||
|
||||
int gridX, gridY, gridZ;
|
||||
uint64_t _stream;
|
||||
uint64_t _function;
|
||||
int num_warps;
|
||||
int shared_memory;
|
||||
PyObject *launch_enter_hook = NULL;
|
||||
PyObject *launch_exit_hook = NULL;
|
||||
PyObject *compiled_kernel = NULL;
|
||||
|
||||
{' '.join([f"{_extracted_type(ty)} _arg{i}; " for i, ty in signature.items()])}
|
||||
if (!PyArg_ParseTuple(args, \"{format}\", &gridX, &gridY, &gridZ, &num_warps, &shared_memory, &_stream, &_function, &launch_enter_hook, &launch_exit_hook, &compiled_kernel, {', '.join(f"&_arg{i}" for i, ty in signature.items())})) {{
|
||||
return NULL;
|
||||
}}
|
||||
|
||||
if (launch_enter_hook != Py_None) {{
|
||||
PyObject_CallObject(launch_enter_hook, args);
|
||||
}}
|
||||
|
||||
// raise exception asap
|
||||
{"; ".join([f"DevicePtrInfo ptr_info{i} = getPointer(_arg{i}, {i}); if (!ptr_info{i}.valid) return NULL;" if ty[0] == "*" else "" for i, ty in signature.items()])};
|
||||
_launch(gridX, gridY, gridZ, num_warps, shared_memory, (hipStream_t)_stream, (hipFunction_t)_function, {', '.join(f"ptr_info{i}.dev_ptr" if ty[0]=="*" else f"_arg{i}" for i, ty in signature.items())});
|
||||
if (launch_exit_hook != Py_None) {{
|
||||
PyObject_CallObject(launch_exit_hook, args);
|
||||
}}
|
||||
if (PyErr_Occurred()) {{
|
||||
return NULL;
|
||||
}}
|
||||
|
||||
// return None
|
||||
Py_INCREF(Py_None);
|
||||
return Py_None;
|
||||
}}
|
||||
|
||||
static PyMethodDef ModuleMethods[] = {{
|
||||
{{"launch", launch, METH_VARARGS, "Entry point for all kernels with this signature"}},
|
||||
{{NULL, NULL, 0, NULL}} // sentinel
|
||||
}};
|
||||
|
||||
static struct PyModuleDef ModuleDef = {{
|
||||
PyModuleDef_HEAD_INIT,
|
||||
\"__triton_launcher\",
|
||||
NULL, //documentation
|
||||
-1, //size
|
||||
ModuleMethods
|
||||
}};
|
||||
|
||||
PyMODINIT_FUNC PyInit___triton_launcher(void) {{
|
||||
PyObject *m = PyModule_Create(&ModuleDef);
|
||||
if(m == NULL) {{
|
||||
return NULL;
|
||||
}}
|
||||
PyModule_AddFunctions(m, ModuleMethods);
|
||||
return m;
|
||||
}}
|
||||
"""
|
||||
else:
|
||||
warp_size = 64 if is_corex() else 32
|
||||
src = f"""
|
||||
#include \"cuda.h\"
|
||||
#include <stdbool.h>
|
||||
#include <Python.h>
|
||||
|
||||
static inline void gpuAssert(CUresult code, const char *file, int line)
|
||||
{{
|
||||
if (code != CUDA_SUCCESS)
|
||||
{{
|
||||
const char* prefix = "Triton Error [CUDA]: ";
|
||||
const char* str;
|
||||
cuGetErrorString(code, &str);
|
||||
char err[1024] = {{0}};
|
||||
strcat(err, prefix);
|
||||
strcat(err, str);
|
||||
PyErr_SetString(PyExc_RuntimeError, err);
|
||||
}}
|
||||
}}
|
||||
|
||||
#define CUDA_CHECK(ans) {{ gpuAssert((ans), __FILE__, __LINE__); }}
|
||||
|
||||
static void _launch(int gridX, int gridY, int gridZ, int num_warps, int shared_memory, CUstream stream, CUfunction function, {arg_decls}) {{
|
||||
void *params[] = {{ {', '.join(f"&arg{i}" for i in signature.keys() if i not in constants)} }};
|
||||
if(gridX*gridY*gridZ > 0){{
|
||||
CUDA_CHECK(cuLaunchKernel(function, gridX, gridY, gridZ, num_warps * {warp_size}, 1, 1, shared_memory, stream, params, 0));
|
||||
}}
|
||||
}}
|
||||
|
||||
typedef struct _DevicePtrInfo {{
|
||||
CUdeviceptr dev_ptr;
|
||||
bool valid;
|
||||
}} DevicePtrInfo;
|
||||
|
||||
static inline DevicePtrInfo getPointer(PyObject *obj, int idx) {{
|
||||
DevicePtrInfo ptr_info;
|
||||
ptr_info.dev_ptr = 0;
|
||||
ptr_info.valid = true;
|
||||
if (PyLong_Check(obj)) {{
|
||||
ptr_info.dev_ptr = PyLong_AsUnsignedLongLong(obj);
|
||||
return ptr_info;
|
||||
}}
|
||||
if (obj == Py_None) {{
|
||||
// valid nullptr
|
||||
return ptr_info;
|
||||
}}
|
||||
PyObject *ptr = PyObject_GetAttrString(obj, "data_ptr");
|
||||
if(ptr){{
|
||||
PyObject *empty_tuple = PyTuple_New(0);
|
||||
PyObject *ret = PyObject_Call(ptr, empty_tuple, NULL);
|
||||
Py_DECREF(empty_tuple);
|
||||
Py_DECREF(ptr);
|
||||
if (!PyLong_Check(ret)) {{
|
||||
PyErr_SetString(PyExc_TypeError, "data_ptr method of Pointer object must return 64-bit int");
|
||||
ptr_info.valid = false;
|
||||
return ptr_info;
|
||||
}}
|
||||
ptr_info.dev_ptr = PyLong_AsUnsignedLongLong(ret);
|
||||
if(!ptr_info.dev_ptr)
|
||||
return ptr_info;
|
||||
/*
|
||||
uint64_t dev_ptr;
|
||||
int status = cuPointerGetAttribute(&dev_ptr, CU_POINTER_ATTRIBUTE_DEVICE_POINTER, ptr_info.dev_ptr);
|
||||
if (status == CUDA_ERROR_INVALID_VALUE) {{
|
||||
PyErr_Format(PyExc_ValueError,
|
||||
"Pointer argument (at %d) cannot be accessed from Triton (cpu tensor?)", idx);
|
||||
ptr_info.valid = false;
|
||||
}}
|
||||
ptr_info.dev_ptr = dev_ptr;
|
||||
*/
|
||||
Py_DECREF(ret); // Thanks ChatGPT!
|
||||
return ptr_info;
|
||||
}}
|
||||
PyErr_SetString(PyExc_TypeError, "Pointer argument must be either uint64 or have data_ptr method");
|
||||
return ptr_info;
|
||||
}}
|
||||
|
||||
static PyObject* launch(PyObject* self, PyObject* args) {{
|
||||
int gridX, gridY, gridZ;
|
||||
uint64_t _stream;
|
||||
uint64_t _function;
|
||||
int num_warps;
|
||||
int shared_memory;
|
||||
PyObject *launch_enter_hook = NULL;
|
||||
PyObject *launch_exit_hook = NULL;
|
||||
PyObject *compiled_kernel = NULL;
|
||||
{' '.join([f"{_extracted_type(ty)} _arg{i}; " for i, ty in signature.items()])}
|
||||
if(!PyArg_ParseTuple(args, \"{format}\", &gridX, &gridY, &gridZ, &num_warps, &shared_memory, &_stream, &_function, &launch_enter_hook, &launch_exit_hook, &compiled_kernel, {', '.join(f"&_arg{i}" for i, ty in signature.items())})) {{
|
||||
return NULL;
|
||||
}}
|
||||
|
||||
if (launch_enter_hook != Py_None) {{
|
||||
PyObject_CallObject(launch_enter_hook, args);
|
||||
}}
|
||||
|
||||
|
||||
// raise exception asap
|
||||
{"; ".join([f"DevicePtrInfo ptr_info{i} = getPointer(_arg{i}, {i}); if (!ptr_info{i}.valid) return NULL;" if ty[0] == "*" else "" for i, ty in signature.items()])};
|
||||
_launch(gridX, gridY, gridZ, num_warps, shared_memory, (CUstream)_stream, (CUfunction)_function, {', '.join(f"ptr_info{i}.dev_ptr" if ty[0]=="*" else f"_arg{i}"for i, ty in signature.items())});
|
||||
|
||||
if (launch_exit_hook != Py_None) {{
|
||||
PyObject_CallObject(launch_exit_hook, args);
|
||||
}}
|
||||
|
||||
if(PyErr_Occurred()) {{
|
||||
return NULL;
|
||||
}}
|
||||
// return None
|
||||
Py_INCREF(Py_None);
|
||||
return Py_None;
|
||||
}}
|
||||
|
||||
static PyMethodDef ModuleMethods[] = {{
|
||||
{{"launch", launch, METH_VARARGS, "Entry point for all kernels with this signature"}},
|
||||
{{NULL, NULL, 0, NULL}} // sentinel
|
||||
}};
|
||||
|
||||
static struct PyModuleDef ModuleDef = {{
|
||||
PyModuleDef_HEAD_INIT,
|
||||
\"__triton_launcher\",
|
||||
NULL, //documentation
|
||||
-1, //size
|
||||
ModuleMethods
|
||||
}};
|
||||
|
||||
PyMODINIT_FUNC PyInit___triton_launcher(void) {{
|
||||
PyObject *m = PyModule_Create(&ModuleDef);
|
||||
if(m == NULL) {{
|
||||
return NULL;
|
||||
}}
|
||||
PyModule_AddFunctions(m, ModuleMethods);
|
||||
return m;
|
||||
}}
|
||||
"""
|
||||
return src
|
||||
0
pkgs/triton/debugger/__init__.py
Normal file
0
pkgs/triton/debugger/__init__.py
Normal file
BIN
pkgs/triton/debugger/__pycache__/__init__.cpython-310.pyc
Normal file
BIN
pkgs/triton/debugger/__pycache__/__init__.cpython-310.pyc
Normal file
Binary file not shown.
BIN
pkgs/triton/debugger/__pycache__/core.cpython-310.pyc
Normal file
BIN
pkgs/triton/debugger/__pycache__/core.cpython-310.pyc
Normal file
Binary file not shown.
BIN
pkgs/triton/debugger/__pycache__/debugger.cpython-310.pyc
Normal file
BIN
pkgs/triton/debugger/__pycache__/debugger.cpython-310.pyc
Normal file
Binary file not shown.
BIN
pkgs/triton/debugger/__pycache__/memory_map.cpython-310.pyc
Normal file
BIN
pkgs/triton/debugger/__pycache__/memory_map.cpython-310.pyc
Normal file
Binary file not shown.
BIN
pkgs/triton/debugger/__pycache__/tl_lang.cpython-310.pyc
Normal file
BIN
pkgs/triton/debugger/__pycache__/tl_lang.cpython-310.pyc
Normal file
Binary file not shown.
BIN
pkgs/triton/debugger/__pycache__/torch_wrapper.cpython-310.pyc
Normal file
BIN
pkgs/triton/debugger/__pycache__/torch_wrapper.cpython-310.pyc
Normal file
Binary file not shown.
9
pkgs/triton/debugger/core.py
Normal file
9
pkgs/triton/debugger/core.py
Normal file
@@ -0,0 +1,9 @@
|
||||
from typing import Tuple
|
||||
|
||||
import dataclasses
|
||||
|
||||
|
||||
@dataclasses.dataclass
class ExecutionContext:
    """Grid coordinates of one simulated Triton program instance.

    These back ``tl.program_id`` / ``tl.num_programs`` in the debugger's
    TritonLangProxy.
    """

    # Index of this program along each grid axis.
    # Fixed annotation: ``Tuple[int]`` means a 1-tuple, but grids may have
    # 1..3 axes, so a variadic tuple is the correct type.
    program_id: Tuple[int, ...]
    # Total number of programs along each grid axis.
    program_size: Tuple[int, ...]
|
||||
170
pkgs/triton/debugger/debugger.py
Normal file
170
pkgs/triton/debugger/debugger.py
Normal file
@@ -0,0 +1,170 @@
|
||||
import itertools
|
||||
import random
|
||||
from typing import Tuple
|
||||
|
||||
import triton
|
||||
import triton.language as tl
|
||||
from .core import ExecutionContext
|
||||
from .memory_map import MemoryMap
|
||||
from .tl_lang import (TritonLangProxy, WrappedTensor, _primitive_to_tensor,
|
||||
debugger_constexpr)
|
||||
from triton.debugger import torch_wrapper
|
||||
|
||||
torch = torch_wrapper.torch
|
||||
tl_method_backup = {}
|
||||
|
||||
|
||||
def get_proxy_method(proxy, name):
    """Return a plain function forwarding every call to ``proxy.<name>``.

    Handing out a fresh closure (instead of the bound method itself) gives a
    stable callable object to patch into the ``tl`` module.
    """
    bound = getattr(proxy, name)

    def forward(*args, **kwargs):
        return bound(*args, **kwargs)

    return forward
|
||||
|
||||
|
||||
def attach_triton(module, proxy):
    """Monkey-patch every public TritonLangProxy member onto *module*.

    Each replaced attribute is remembered in ``tl_method_backup`` so that
    ``detach_triton`` can restore the module afterwards.
    """
    public_names = (n for n in dir(TritonLangProxy) if not n.startswith("_"))
    for name in public_names:
        if not hasattr(module, name):
            continue
        original = getattr(module, name)
        tl_method_backup[name] = original
        if callable(original):
            replacement = get_proxy_method(proxy, name)
        else:
            replacement = getattr(proxy, name)
        setattr(module, name, replacement)
|
||||
|
||||
|
||||
def detach_triton(module):
    """Undo ``attach_triton`` by restoring the recorded original attributes."""
    for name in tl_method_backup:
        setattr(module, name, tl_method_backup[name])
|
||||
|
||||
|
||||
def program_ids_from_grid(grid: Tuple[int, ...]) -> Tuple[int, ...]:
    """Yield every program id of *grid* exactly once, in random order.

    Axes are enumerated over the reversed grid dimensions; randomising the
    order helps surface hidden inter-program ordering assumptions in kernels.
    """
    axis_ranges = [range(extent) for extent in reversed(grid)]
    all_ids = list(itertools.product(*axis_ranges))
    random.shuffle(all_ids)
    yield from all_ids
|
||||
|
||||
|
||||
class DebuggerFunction:
    """Callable wrapper that runs a Triton kernel eagerly, one program at a time.

    A call converts the arguments (tensors become registered "pointers",
    constexpr parameters become debugger_constexpr), then executes the wrapped
    function once per program id of the grid with the ``tl`` module patched to
    a TritonLangProxy for the duration of each run.
    """

    def __init__(self, func, grid=(1,)):
        # func: the kernel body; grid: a tuple of axis sizes, or a callable
        # mapping the kwargs dict to such a tuple (mirrors triton's launch grid).
        self.func = func
        self.grid = grid

    def _is_constexpr(self, name):
        # True when parameter *name* is annotated as tl.constexpr.
        return name in self.func.__annotations__ and self.func.__annotations__[name] is triton.language.core.constexpr

    def _get_constexpr(self):
        """Return the names of all parameters annotated as tl.constexpr."""
        result = []
        for name, annotation in self.func.__annotations__.items():
            if annotation is triton.language.core.constexpr:
                result.append(name)
        return result

    def _assert_constexpr(self, **kwargs):
        # constexpr parameters must be passed by keyword so convert_arg
        # below can recognise and wrap them.
        constexp = self._get_constexpr()
        missing = [i for i in constexp if i not in kwargs.keys()]
        assert len(missing) == 0, f"You must specify constexpr {missing}"

    def _get_grid(self, **kwargs):
        # A callable grid receives the kwargs dict, as with real triton.
        if callable(self.grid):
            return self.grid(kwargs)
        else:
            return self.grid

    def __call__(self, *args, **kwargs):
        self._assert_constexpr(**kwargs)

        # Fresh memory map per launch: every tensor argument is registered so
        # integer "pointers" can be resolved by the proxy's load/store.
        memory = MemoryMap()

        def convert_arg(v):
            name, arg = v
            if torch.is_tensor(arg):
                # Replace the tensor by a wrapped int64 holding its base pointer.
                ptr = memory.add_tensor(arg)
                return WrappedTensor(torch.tensor([ptr], dtype=torch.int64, device="cuda"))
            if self._is_constexpr(name):
                return debugger_constexpr(arg)
            return WrappedTensor(_primitive_to_tensor(arg))

        # Positional args are matched to parameter names via co_varnames.
        new_args = tuple(map(convert_arg, zip(self.func.__code__.co_varnames, args)))
        # num_warps / num_stages are launch hints with no meaning here.
        new_kwargs = {k: convert_arg((k, v)) for (k, v) in kwargs.items() if k not in ["num_warps", "num_stages"]}

        grid = self._get_grid(**kwargs)
        # Run the kernel body once per program id, monkey-patching tl for the
        # duration of each run and restoring it at the end.
        for program_id in program_ids_from_grid(grid):
            proxy = TritonLangProxy(memory, ExecutionContext(program_id, grid))
            attach_triton(tl, proxy)
            self.func(*new_args, **new_kwargs)
        detach_triton(tl)
|
||||
|
||||
|
||||
class GridSelector:
    """
    Entry point of the debugger

    Wraps a kernel so ``kernel[grid](...)`` — or a plain call, which uses the
    default ``(1,)`` grid — executes eagerly through DebuggerFunction.
    """

    def __init__(self, func):
        # Only torch 2.x is supported (presumably for storage/tensor APIs the
        # debugger relies on — confirm against memory_map.py).
        version = torch.__version__
        assert version[0] == "2", f"Triton Debugger only supports torch >= 2.0, using {version}"
        self.func = func

    def __getitem__(self, grid):
        # Mirrors triton's kernel[grid] launch syntax.
        return DebuggerFunction(self.func, grid)

    def __call__(self, *args, **kwargs):
        # Direct call: run with the default (1,) grid.
        return DebuggerFunction(self.func)(*args, **kwargs)
|
||||
|
||||
|
||||
class AutotuneGridSelector:
    """Entry point for autotuned kernels.

    Pairs the kernel with its autotune parameters and defers launching
    (with or without an explicit grid) to AutotuneRunner.
    """

    def __init__(self, func, autotune_params):
        self.func = func
        self.autotune_params = autotune_params

    def __getitem__(self, grid):
        # kernel[grid] launch syntax.
        return AutotuneRunner(self.func, self.autotune_params, grid)

    def __call__(self, *args, **kwargs):
        # Direct call without an explicit grid.
        return AutotuneRunner(self.func, self.autotune_params)(*args, **kwargs)
|
||||
|
||||
|
||||
class AutotuneRunner:
    """Runs an autotuned kernel under the debugger.

    Unlike real autotuning, nothing is timed: every non-primary config is
    executed once on *clones* of the tensor arguments (so its side effects are
    discarded), then the first config runs last on the real arguments and its
    effects are kept.
    """

    def __init__(self, func, autotune_params, grid=None):
        self.func = func
        self.autotune_params = autotune_params
        # grid is None when the kernel is called without [grid].
        self.grid = grid

    def __call__(self, *args, **kwargs):
        assert len(self.autotune_params["configs"]) >= 1

        # Exercise every secondary config on cloned tensors.
        for config in self.autotune_params["configs"][1:]:

            def convert_arg(v):
                # Clone tensors so secondary runs cannot alter real outputs.
                if torch.is_tensor(v):
                    return torch.clone(v)
                return v

            new_args = tuple(map(convert_arg, args))
            new_kwargs = {k: convert_arg(v) for k, v in kwargs.items()}
            if self.grid:
                self.func[self.grid](*new_args, **new_kwargs, **config.kwargs)
            else:
                self.func(*new_args, **new_kwargs, **config.kwargs)

        # The first config is the "real" one: run it on the original
        # (uncloned) arguments so its writes are visible to the caller.
        main_config = self.autotune_params["configs"][0]
        if self.grid:
            self.func[self.grid](*args, **kwargs, **main_config.kwargs)
        else:
            self.func(*args, **kwargs, **main_config.kwargs)
|
||||
|
||||
|
||||
def triton_debug_autotune(**kwargs):
    """Decorator mirroring ``@triton.autotune`` for the debugger.

    The collected keyword parameters (e.g. ``configs``) are handed to
    AutotuneGridSelector; every config is executed rather than timed.
    """

    def wrapper(func):
        return AutotuneGridSelector(func, kwargs)

    return wrapper
|
||||
100
pkgs/triton/debugger/memory_map.py
Normal file
100
pkgs/triton/debugger/memory_map.py
Normal file
@@ -0,0 +1,100 @@
|
||||
import dataclasses
|
||||
|
||||
from triton.debugger import torch_wrapper
|
||||
|
||||
torch = torch_wrapper.torch
|
||||
|
||||
|
||||
@dataclasses.dataclass
class RegisteredStorage:
    """Bookkeeping for one tensor storage registered with the MemoryMap.

    ``ptr`` and ``size`` are captured at registration time so integer
    "pointers" handed to the kernel can later be mapped back to offsets
    into this storage.
    """

    # Underlying storage object of the registered tensor.
    storage: torch.Storage
    # Element dtype used to view the raw storage in access_tensor.
    dtype: torch.dtype
    # storage.size() captured at registration time.
    size: int
    # storage.data_ptr() captured at registration time.
    ptr: int

    @property
    def end_ptr(self) -> int:
        # One-past-the-end address. NOTE(review): assumes ``ptr`` and
        # ``size`` share the same unit (bytes for an untyped storage) —
        # confirm against how load/store compute offsets.
        return self.ptr + self.size

    @property
    def access_tensor(self) -> torch.Tensor:
        # A tensor view over the whole storage, used for element reads/writes.
        return torch.tensor(self.storage, dtype=self.dtype, device=self.storage.device)

    def ensure_immutable(self):
        # Guard against the tensor having been resized or reallocated since
        # registration, which would invalidate the recorded ptr/size.
        assert self.storage.data_ptr() == self.ptr and self.storage.size() == self.size
|
||||
|
||||
|
||||
class MemoryMap:
    """Resolves integer "pointers" back to registered tensor storages.

    Kernel tensor arguments are registered via add_tensor; ``load`` and
    ``store`` then emulate ``tl.load``/``tl.store`` by translating pointer
    tensors into element offsets within the matching storage.
    """

    # All storages registered for the current launch.
    # (Annotation fixed: the original ``[RegisteredStorage]`` list literal is
    # not a valid type.)
    storages: list[RegisteredStorage]

    def __init__(self):
        self.storages = []

    def _get_registered_storage(self, pointer: torch.Tensor):
        """Return the unique registered storage containing every pointer value.

        Raises if the pointers are out of range or span multiple tensors.
        """
        max_pointer = torch.max(pointer).item()
        min_pointer = torch.min(pointer).item()

        # A storage matches only if the whole pointer range fits inside it.
        registered_storage = next(
            filter(
                lambda registered: min_pointer >= registered.ptr and max_pointer < registered.end_ptr, self.storages
            ),
            None,
        )
        if registered_storage is None:
            raise Exception("Storage not found or pointers spanning multiple tensors")
        registered_storage.ensure_immutable()
        return registered_storage

    def add_tensor(self, t: torch.Tensor):
        """Register tensor *t* and return its base data pointer."""
        storage = t.untyped_storage()
        self.storages.append(RegisteredStorage(storage, t.dtype, storage.size(), storage.data_ptr()))
        return t.data_ptr()

    def load(
        self,
        pointer: torch.Tensor,
        mask: torch.Tensor = None,
        other=0.0,
    ):
        """Emulate ``tl.load``: gather the elements addressed by *pointer*.

        Masked-off lanes receive *other*. Pointers must be int64 CUDA
        tensors of rank 1 or 2; mask (if given) likewise, and is broadcast
        to the pointer's shape.
        """
        assert pointer.is_cuda
        assert 0 < pointer.dim() < 3
        assert pointer.dtype == torch.int64

        if mask is None:
            mask = torch.ones_like(pointer).bool()
        assert mask.is_cuda
        assert 0 < mask.dim() < 3
        assert mask.dtype == torch.bool
        mask = mask.expand(pointer.size())

        if torch.all(~mask):
            # Todo: The type is wrong here, we can't determine the correct type
            return torch.full_like(pointer, fill_value=other, dtype=torch.float16, device="cuda")

        registered_storage = self._get_registered_storage(pointer[mask])
        access_tensor = registered_storage.access_tensor

        # Translate raw pointers to offsets relative to the storage base.
        # NOTE(review): this treats pointer deltas as *element* indices into
        # access_tensor — confirm the unit matches RegisteredStorage.ptr/size.
        index_tensor = pointer - registered_storage.ptr

        block = torch.full_like(pointer, fill_value=other, dtype=access_tensor.dtype, device="cuda")
        block[mask] = access_tensor[index_tensor[mask]]
        return block

    def store(self, pointer: torch.Tensor, value: torch.Tensor, mask=None):
        """Emulate ``tl.store``: scatter *value* to the addressed elements."""
        assert 0 < pointer.dim() < 3
        assert pointer.dtype == torch.int64

        if mask is None:
            mask = torch.ones_like(pointer).bool()
        assert 0 < mask.dim() < 3
        assert mask.dtype == torch.bool
        mask = mask.expand(pointer.size())

        if torch.all(~mask):
            # Fully masked store is a no-op.
            return

        registered_storage = self._get_registered_storage(pointer[mask])
        access_tensor = registered_storage.access_tensor

        # Values are cast to the storage dtype before being written back.
        index_tensor = pointer - registered_storage.ptr
        access_tensor[index_tensor[mask]] = value[mask].to(access_tensor.dtype)
|
||||
621
pkgs/triton/debugger/tl_lang.py
Normal file
621
pkgs/triton/debugger/tl_lang.py
Normal file
@@ -0,0 +1,621 @@
|
||||
import triton
|
||||
from .core import ExecutionContext
|
||||
from .memory_map import MemoryMap
|
||||
from triton.debugger import torch_wrapper
|
||||
|
||||
torch = torch_wrapper.torch
|
||||
|
||||
|
||||
def _primitive_to_tensor(x):
    """
    Converts various Python primitive data types to PyTorch tensor.
    """
    # bool must be tested before int (bool subclasses int).
    if isinstance(x, bool):
        return torch.tensor([x], dtype=torch.bool, device="cuda")
    if isinstance(x, int):
        if -(2**31) <= x < 2**31:
            int_dtype = torch.int32
        elif -(2**63) <= x < 2**63:
            int_dtype = torch.int64
        else:
            raise RuntimeError(f"Nonrepresentable integer {x}.")
        return torch.tensor([x], dtype=int_dtype, device="cuda")
    if isinstance(x, float):
        return torch.tensor([x], dtype=torch.float32, device="cuda")
    # Tensors and already-wrapped tensors pass through untouched.
    if torch.is_tensor(x) or isinstance(x, WrappedTensor):
        return x
    if isinstance(x, debugger_constexpr):
        return None if x.value is None else _primitive_to_tensor(x.value)
    if x is None:
        return None
    assert False, f"cannot convert {x} of type {type(x)} to tensor"
|
||||
|
||||
|
||||
def _infer_tensor(func):
    """
    Decorator harmonizing function args:
    - primitives are converted to PyTorch tensors
    - raw tensors are wrapped in WrappedTensor before the call
    """

    def wrapper(*args):
        converted = [_primitive_to_tensor(a) for a in args]
        wrapped = [WrappedTensor(a) if torch.is_tensor(a) else a for a in converted]
        return func(*wrapped)

    return wrapper
|
||||
|
||||
|
||||
def _tensor_operation(func):
    """
    A decorator function to unwrap WrappedTensors and debugger_constexpr before calling the function.
    Can be combined with _infer_tensor decorator to harmonize args (everything to torch tensor).
    """
    def wrapper(*args, **kwargs):
        # Raw tensors must have been wrapped upstream (e.g. by _infer_tensor).
        for arg in args:
            assert not torch.is_tensor(arg), "unexpected tensor argument"

        def unwrap_tensor(v):
            # WrappedTensor -> underlying tensor; constexpr -> raw value.
            if isinstance(v, WrappedTensor):
                return v.tensor
            if isinstance(v, debugger_constexpr):
                return v.value
            return v

        new_args = tuple(map(unwrap_tensor, args))
        new_kwargs = {k: unwrap_tensor(v) for k, v in kwargs.items()}

        # The first positional arg (self) is deliberately passed through
        # *wrapped*; only the remaining args are unwrapped.
        result = func(args[0], *new_args[1:], **new_kwargs)
        # Re-wrap raw tensor results so operator chains keep composing.
        return WrappedTensor(result) if torch.is_tensor(result) else result

    return wrapper
|
||||
|
||||
|
||||
class debugger_constexpr:
    """Debugger stand-in for ``tl.constexpr``: holds a plain Python value and
    supports the comparison / bitwise operators kernels use on constexprs."""

    @staticmethod
    def _raw(other):
        # Peel a wrapping debugger_constexpr off the right-hand operand.
        return other.value if isinstance(other, debugger_constexpr) else other

    def __init__(self, value):
        # Flatten nested constexprs so .value is always a raw Python value.
        self.value = value.value if isinstance(value, debugger_constexpr) else value

    def __str__(self) -> str:
        return "debugger_constexpr(" + str(self.value) + ")"

    def __index__(self) -> int:
        return self.value

    def __bool__(self):
        return bool(self.value)

    def __ge__(self, other):
        return self.value >= self._raw(other)

    def __gt__(self, other):
        return self.value > self._raw(other)

    def __le__(self, other):
        return self.value <= self._raw(other)

    def __lt__(self, other):
        return self.value < self._raw(other)

    def __eq__(self, other):
        return self.value == self._raw(other)

    def __or__(self, other):
        return self.value | self._raw(other)

    def __ror__(self, other):
        return self.value | self._raw(other)

    def __and__(self, other):
        return self.value & self._raw(other)

    def __rand__(self, other):
        return self.value & self._raw(other)

    def to(self, dtype, bitcast=False, _builder=None):
        # Map a torch dtype to the matching Python scalar constructor.
        if dtype in [torch.int64]:
            ret_ty = int
        elif dtype == torch.bool:
            ret_ty = bool
        elif dtype in [torch.float64]:
            ret_ty = float
        else:
            raise ValueError("dtype not supported in debugger")
        return debugger_constexpr(ret_ty(self.value))
|
||||
|
||||
|
||||
class WrappedTensor:
    """Gives torch tensors triton-tensor-like operator semantics.

    Every arithmetic/comparison dunder goes through @_infer_tensor (coerce
    primitives and raw tensors) and @_tensor_operation (unwrap operands, call,
    re-wrap tensor results), so expressions inside a debugged kernel compose
    like tl tensors.
    """

    def __init__(self, tensor):
        # The underlying torch tensor.
        self.tensor = tensor

    def __index__(self) -> int:
        return self.tensor.item()

    def __str__(self) -> str:
        return "wrapped_" + str(self.tensor)

    def __bool__(self) -> bool:
        # True only if *all* elements are true.
        return torch.all(self.tensor == True).item()  # noqa: E712

    @property
    def dtype(self):
        return self.tensor.dtype

    @_infer_tensor
    @_tensor_operation
    def __add__(self, other):
        return torch.add(self.tensor, other)

    @_infer_tensor
    @_tensor_operation
    def __radd__(self, other):
        return self.__add__(other)

    @_infer_tensor
    @_tensor_operation
    def __sub__(self, other):
        return torch.sub(self.tensor, other)

    @_infer_tensor
    @_tensor_operation
    def __rsub__(self, other):
        return torch.sub(other, self.tensor)

    @_infer_tensor
    @_tensor_operation
    def __mul__(self, other):
        return torch.mul(self.tensor, other)

    @_infer_tensor
    @_tensor_operation
    def __rmul__(self, other):
        return self.__mul__(other)

    @_infer_tensor
    @_tensor_operation
    def __truediv__(self, other):
        return torch.div(self.tensor, other)

    @_infer_tensor
    @_tensor_operation
    def __rtruediv__(self, other):
        return torch.div(other, self.tensor)

    @_infer_tensor
    @_tensor_operation
    def __floordiv__(self, other):
        return torch.floor_divide(self.tensor, other)

    @_infer_tensor
    @_tensor_operation
    def __rfloordiv__(self, other):
        return torch.floor_divide(other, self.tensor)

    @_infer_tensor
    @_tensor_operation
    def __mod__(self, other):
        return torch.remainder(self.tensor, other)

    @_infer_tensor
    @_tensor_operation
    def __rmod__(self, other):
        return torch.remainder(other, self.tensor)

    @_infer_tensor
    @_tensor_operation
    def __neg__(self):
        return -self.tensor

    @_infer_tensor
    @_tensor_operation
    def __invert__(self):
        return ~self.tensor

    @_infer_tensor
    @_tensor_operation
    def __and__(self, other):
        return torch.bitwise_and(self.tensor, other)

    @_infer_tensor
    @_tensor_operation
    def __or__(self, other):
        return torch.bitwise_or(self.tensor, other)

    @_infer_tensor
    @_tensor_operation
    def __xor__(self, other):
        return torch.bitwise_xor(self.tensor, other)

    @_infer_tensor
    @_tensor_operation
    def __lshift__(self, other):
        return torch.bitwise_left_shift(self.tensor, other)

    @_infer_tensor
    @_tensor_operation
    def __rshift__(self, other):
        return torch.bitwise_right_shift(self.tensor, other)

    @_infer_tensor
    @_tensor_operation
    def __gt__(self, other):
        return self.tensor > other

    @_infer_tensor
    @_tensor_operation
    def __rgt__(self, other):
        return other > self.tensor

    @_infer_tensor
    @_tensor_operation
    def __ge__(self, other):
        return self.tensor >= other

    @_infer_tensor
    @_tensor_operation
    def __rge__(self, other):
        return other >= self.tensor

    @_infer_tensor
    @_tensor_operation
    def __lt__(self, other):
        return self.tensor < other

    @_infer_tensor
    @_tensor_operation
    def __rlt__(self, other):
        return other < self.tensor

    @_infer_tensor
    @_tensor_operation
    def __le__(self, other):
        return self.tensor <= other

    @_infer_tensor
    @_tensor_operation
    def __rle__(self, other):
        return other <= self.tensor

    # NOTE(review): unlike the element-wise ordering ops above, __eq__/__ne__
    # use torch.equal (whole-tensor equality, a single bool) — confirm this
    # asymmetry is intended.
    @_infer_tensor
    @_tensor_operation
    def __eq__(self, other):
        return torch.equal(self.tensor, other)

    @_infer_tensor
    @_tensor_operation
    def __ne__(self, other):
        return not torch.equal(self.tensor, other)

    @_tensor_operation
    def __getitem__(self, slices):
        return self.tensor.__getitem__(slices)
        # if isinstance(slices, slice):
        #     slices = [slices]
        # src_shape = self.shape
        # dst_shape = []
        # curr = 0
        # for sl in slices:
        #     if isinstance(sl, constexpr) and sl.value is None:
        #         dst_shape.append(1)
        #     elif sl == slice(None, None, None):
        #         dst_shape.append(src_shape[curr].value)
        #         curr += 1
        # ret = torch.reshape(self.tensor, dst_shape, )
        # return ret

    @_tensor_operation
    def to(self, dtype, bitcast=False):
        # bitcast is accepted but ignored in the debugger.
        return self.tensor.to(dtype)
        # if isinstance(bitcast, constexpr):
        #     bitcast = bitcast.value
        # if bitcast:
        #     return semantic.bitcast(self, dtype, )
        # return semantic.cast(self, dtype, )
|
||||
|
||||
|
||||
def _constexpr_to_value(v):
    """Return the raw value held by a debugger_constexpr; pass anything else through."""
    return v.value if isinstance(v, debugger_constexpr) else v
|
||||
|
||||
|
||||
class TritonLangProxy:
|
||||
_memory_map: MemoryMap
|
||||
_context: ExecutionContext
|
||||
|
||||
def __init__(self, memory_map: MemoryMap, context: ExecutionContext):
|
||||
self._memory_map = memory_map
|
||||
self._context = context
|
||||
|
||||
# Types
|
||||
# Removed void, int1, float8, uint16, uint32, uint64, pi32_t
|
||||
|
||||
# constexpr = debugger_constexpr
|
||||
|
||||
# Program functions
|
||||
|
||||
@_tensor_operation
|
||||
def load(
|
||||
self,
|
||||
pointer: torch.Tensor,
|
||||
mask: torch.Tensor = None,
|
||||
other=0.0,
|
||||
cache_modifier="",
|
||||
eviction_policy="",
|
||||
volatile=False,
|
||||
):
|
||||
return self._memory_map.load(pointer, mask, other)
|
||||
|
||||
@_tensor_operation
|
||||
def store(self, pointer: torch.Tensor, value: torch.Tensor, mask=None):
|
||||
return self._memory_map.store(pointer, value, mask)
|
||||
|
||||
@_tensor_operation
|
||||
def program_id(self, axis):
|
||||
assert axis < len(self._context.program_id)
|
||||
return torch.tensor([self._context.program_id[axis]], dtype=torch.int32, device="cuda")
|
||||
|
||||
@_tensor_operation
|
||||
def num_programs(self, axis):
|
||||
assert axis < len(self._context.program_size)
|
||||
return torch.tensor([self._context.program_size[axis]], dtype=torch.int32, device="cuda")
|
||||
|
||||
@_tensor_operation
|
||||
def arange(self, start, end):
|
||||
return torch.arange(start=start, end=end, dtype=torch.int32, device="cuda")
|
||||
|
||||
@_tensor_operation
|
||||
def zeros(self, shape, dtype):
|
||||
for i, d in enumerate(shape):
|
||||
if not isinstance(d, debugger_constexpr):
|
||||
raise TypeError(f"Shape element {i} must have type `constexpr`")
|
||||
if not isinstance(d.value, int):
|
||||
raise TypeError(f"Shape element {i} must have type `constexpr[int]`, got `constexpr[{type(d.value)}]")
|
||||
shape = [x.value for x in shape]
|
||||
if isinstance(dtype, triton.language.core.dtype):
|
||||
if dtype.is_fp32():
|
||||
dtype = torch.float32
|
||||
elif dtype.is_fp16():
|
||||
dtype = torch.float16
|
||||
elif dtype.is_bf16():
|
||||
dtype = torch.bfloat16
|
||||
elif dtype.is_int32():
|
||||
dtype = torch.int32
|
||||
elif dtype.is_int16():
|
||||
dtype = torch.int16
|
||||
elif dtype.is_int8():
|
||||
dtype = torch.int8
|
||||
else:
|
||||
raise TypeError(f"Unsupported dtype {dtype}")
|
||||
return torch.zeros(size=shape, dtype=dtype, device="cuda")
|
||||
|
||||
@_tensor_operation
|
||||
def dequantize(self, input, scale, shift, nbit, dst_ty=torch.float16):
|
||||
raise NotImplementedError()
|
||||
|
||||
@_tensor_operation
|
||||
def broadcast(self, input, other):
|
||||
raise NotImplementedError()
|
||||
|
||||
@_tensor_operation
|
||||
def broadcast_to(self, input, shape):
|
||||
raise NotImplementedError()
|
||||
|
||||
@_tensor_operation
|
||||
def cat(self, input, shape):
|
||||
raise NotImplementedError()
|
||||
|
||||
@_tensor_operation
|
||||
def reshape(self, input, shape):
|
||||
raise NotImplementedError()
|
||||
|
||||
@_tensor_operation
|
||||
def dot(self, input, other, trans_a=False, trans_b=False, allow_tf32=True):
|
||||
assert input.dtype == other.dtype
|
||||
if trans_a:
|
||||
input = input.T
|
||||
if trans_b:
|
||||
other = other.T
|
||||
return torch.matmul(input=input, other=other)
|
||||
|
||||
@_tensor_operation
|
||||
def atomic_cas(self, pointer, cmp, val):
|
||||
stored = self._memory_map.load(pointer, None, 0.0)
|
||||
if not isinstance(cmp, torch.Tensor):
|
||||
cmp = torch.tensor([cmp], dtype=stored.dtype, device="cuda")
|
||||
if not isinstance(val, torch.Tensor):
|
||||
val = torch.tensor([val], dtype=stored.dtype, device="cuda")
|
||||
if stored == cmp:
|
||||
self._memory_map.store(pointer, val, None)
|
||||
return stored
|
||||
|
||||
@_tensor_operation
|
||||
def atomic_xchg(self, pointer, val, mask=None):
|
||||
if isinstance(val, int):
|
||||
val = torch.tensor([val], dtype=torch.int32, device="cuda")
|
||||
stored = self._memory_map.load(pointer, mask, 0.0)
|
||||
self._memory_map.store(pointer, val, mask)
|
||||
return stored
|
||||
|
||||
@_tensor_operation
|
||||
def atomic_add(self, pointer, val, mask=None):
|
||||
# arbitrary other value as it will masked during storing
|
||||
stored = self._memory_map.load(pointer, mask, 0.0)
|
||||
result = stored + val
|
||||
self._memory_map.store(pointer, result, mask)
|
||||
return stored
|
||||
|
||||
@_tensor_operation
|
||||
def atomic_max(self, pointer, val, mask=None):
|
||||
stored = self._memory_map.load(pointer, mask, 0.0)
|
||||
result = torch.maximum(stored, val)
|
||||
self._memory_map.store(pointer, result, mask)
|
||||
return stored
|
||||
|
||||
@_tensor_operation
|
||||
def atomic_min(self, pointer, val, mask=None):
|
||||
stored = self._memory_map.load(pointer, mask, 0.0)
|
||||
result = torch.minimum(stored, val)
|
||||
self._memory_map.store(pointer, result, mask)
|
||||
return stored
|
||||
|
||||
@_tensor_operation
|
||||
def atomic_and(self, pointer, val, mask=None):
|
||||
stored = self._memory_map.load(pointer, mask, 0)
|
||||
result = torch.bitwise_and(stored, val)
|
||||
self._memory_map.store(pointer, result, mask)
|
||||
return stored
|
||||
|
||||
@_tensor_operation
|
||||
def atomic_or(self, pointer, val, mask=None):
|
||||
stored = self._memory_map.load(pointer, mask, 0)
|
||||
result = torch.bitwise_or(stored, val)
|
||||
self._memory_map.store(pointer, result, mask)
|
||||
return stored
|
||||
|
||||
@_tensor_operation
|
||||
def atomic_xor(self, pointer, val, mask=None):
|
||||
stored = self._memory_map.load(pointer, mask, 0)
|
||||
result = torch.bitwise_xor(stored, val)
|
||||
self._memory_map.store(pointer, result, mask)
|
||||
return stored
|
||||
|
||||
@_tensor_operation
|
||||
def where(self, condition, x, y):
|
||||
condition = _primitive_to_tensor(condition)
|
||||
x = _primitive_to_tensor(x)
|
||||
y = _primitive_to_tensor(y)
|
||||
return torch.where(condition, x, y)
|
||||
|
||||
@_tensor_operation
|
||||
def umulhi(self, x, y):
|
||||
raise NotImplementedError()
|
||||
|
||||
@_tensor_operation
|
||||
def fdiv(self, x, y, ieee_rounding=False):
|
||||
raise NotImplementedError()
|
||||
|
||||
@_tensor_operation
|
||||
def exp(self, x):
|
||||
return torch.exp(x)
|
||||
|
||||
@_tensor_operation
|
||||
def log(self, x):
|
||||
return torch.log(x)
|
||||
|
||||
@_tensor_operation
|
||||
def cos(self, x):
|
||||
return torch.cos(x)
|
||||
|
||||
@_tensor_operation
|
||||
def sin(self, x):
|
||||
return torch.sin(x)
|
||||
|
||||
@_tensor_operation
|
||||
def sqrt(self, x):
|
||||
return torch.sqrt(x)
|
||||
|
||||
@_tensor_operation
|
||||
def globaltimer(self):
|
||||
raise NotImplementedError()
|
||||
|
||||
@_tensor_operation
|
||||
def clock(self):
|
||||
raise NotImplementedError()
|
||||
|
||||
@_tensor_operation
|
||||
def debug_barrier(self):
|
||||
raise NotImplementedError()
|
||||
|
||||
@_tensor_operation
|
||||
def multiple_of(self, input, values):
|
||||
return input
|
||||
|
||||
@_tensor_operation
|
||||
def max_contiguous(self, input, values):
|
||||
return input
|
||||
|
||||
@_tensor_operation
|
||||
def abs(self, x):
|
||||
return torch.abs(x)
|
||||
|
||||
@_tensor_operation
|
||||
def cdiv(self, x, div):
|
||||
return (x + div - 1) // div
|
||||
|
||||
@_tensor_operation
|
||||
def minimum(self, x, y):
|
||||
if isinstance(x, int):
|
||||
x = torch.tensor(x, device="cuda")
|
||||
if isinstance(y, int):
|
||||
y = torch.tensor(y, device="cuda")
|
||||
return torch.minimum(x, y)
|
||||
|
||||
@_tensor_operation
|
||||
def maximum(self, x, y):
|
||||
return torch.maximum(x, y)
|
||||
|
||||
@_tensor_operation
|
||||
def sigmoid(self, x):
|
||||
raise NotImplementedError()
|
||||
|
||||
@_tensor_operation
|
||||
def softmax(self, x, ieee_rounding=False):
|
||||
raise NotImplementedError()
|
||||
|
||||
@_tensor_operation
|
||||
def ravel(self, x):
|
||||
raise NotImplementedError()
|
||||
|
||||
@_tensor_operation
|
||||
def swizzle2d(self, i, j, size_i, size_j, size_g):
|
||||
raise NotImplementedError()
|
||||
|
||||
@_tensor_operation
|
||||
def zeros_like(self, input):
|
||||
raise NotImplementedError()
|
||||
|
||||
@_tensor_operation
|
||||
def max(self, input, axis=None):
|
||||
if axis is None:
|
||||
return torch.max(input)
|
||||
return torch.max(input, dim=axis).values
|
||||
|
||||
@_tensor_operation
|
||||
def argmax(self, input, axis):
|
||||
raise NotImplementedError()
|
||||
|
||||
@_tensor_operation
|
||||
def min(self, input, axis=None):
    """Min reduction: over all elements when `axis` is None, else along `axis`."""
    if axis is not None:
        return torch.min(input, dim=axis).values
    return torch.min(input)
|
||||
|
||||
@_tensor_operation
|
||||
def argmin(self, input, axis):
    # Not yet implemented in the host interpreter.
    raise NotImplementedError()
|
||||
|
||||
@_tensor_operation
|
||||
def sum(self, input, axis=None):
    """Sum reduction: over all elements when `axis` is None, else along `axis`."""
    if axis is not None:
        return torch.sum(input, dim=axis)
    return torch.sum(input)
|
||||
|
||||
@_tensor_operation
|
||||
def xor_sum(self, input, axis):
    # XOR reduction is not yet implemented in the host interpreter.
    raise NotImplementedError()
|
||||
18
pkgs/triton/debugger/torch_wrapper.py
Normal file
18
pkgs/triton/debugger/torch_wrapper.py
Normal file
@@ -0,0 +1,18 @@
|
||||
try:
|
||||
import torch as _torch
|
||||
except ImportError:
|
||||
_torch = None
|
||||
|
||||
|
||||
class TorchWrapper:
    """
    Helps in making torch an optional dependency.

    Attribute access is proxied to the real `torch` module when it was
    importable (module global `_torch`); otherwise an ImportError is
    raised lazily, only when torch is actually used.
    """

    def __getattr__(self, name):
        # Defer the failure to first use rather than import time.
        if _torch is None:
            raise ImportError("Triton requires PyTorch to be installed")
        return getattr(_torch, name)
|
||||
|
||||
|
||||
torch = TorchWrapper()
|
||||
201
pkgs/triton/language/__init__.py
Normal file
201
pkgs/triton/language/__init__.py
Normal file
@@ -0,0 +1,201 @@
|
||||
"""isort:skip_file"""
|
||||
# Import order is significant here.
|
||||
|
||||
from . import math
|
||||
from . import extra
|
||||
from .standard import (
|
||||
cdiv,
|
||||
sigmoid,
|
||||
softmax,
|
||||
ravel,
|
||||
swizzle2d,
|
||||
zeros,
|
||||
zeros_like,
|
||||
)
|
||||
from .core import (
|
||||
abs,
|
||||
advance,
|
||||
arange,
|
||||
argmin,
|
||||
argmax,
|
||||
atomic_add,
|
||||
atomic_and,
|
||||
atomic_cas,
|
||||
atomic_max,
|
||||
atomic_min,
|
||||
atomic_or,
|
||||
atomic_xchg,
|
||||
atomic_xor,
|
||||
bfloat16,
|
||||
block_type,
|
||||
broadcast,
|
||||
broadcast_to,
|
||||
cat,
|
||||
constexpr,
|
||||
cos,
|
||||
debug_barrier,
|
||||
device_assert,
|
||||
device_print,
|
||||
dot,
|
||||
dtype,
|
||||
exp,
|
||||
expand_dims,
|
||||
full,
|
||||
fdiv,
|
||||
float16,
|
||||
float32,
|
||||
float64,
|
||||
float8e4,
|
||||
float8e5,
|
||||
function_type,
|
||||
int1,
|
||||
int16,
|
||||
int32,
|
||||
int64,
|
||||
int8,
|
||||
load,
|
||||
log,
|
||||
make_block_ptr,
|
||||
max,
|
||||
max_contiguous,
|
||||
maximum,
|
||||
min,
|
||||
minimum,
|
||||
multiple_of,
|
||||
num_programs,
|
||||
pi32_t,
|
||||
pointer_type,
|
||||
program_id,
|
||||
reduce,
|
||||
reshape,
|
||||
sin,
|
||||
sqrt,
|
||||
static_assert,
|
||||
static_print,
|
||||
store,
|
||||
sum,
|
||||
static_range,
|
||||
tensor,
|
||||
trans,
|
||||
triton,
|
||||
uint16,
|
||||
uint32,
|
||||
uint64,
|
||||
uint8,
|
||||
umulhi,
|
||||
view,
|
||||
void,
|
||||
where,
|
||||
xor_sum,
|
||||
)
|
||||
from .random import (
|
||||
pair_uniform_to_normal,
|
||||
philox,
|
||||
philox_impl,
|
||||
rand,
|
||||
rand4x,
|
||||
randint,
|
||||
randint4x,
|
||||
randn,
|
||||
randn4x,
|
||||
uint32_to_uniform_float,
|
||||
)
|
||||
|
||||
|
||||
__all__ = [
|
||||
"abs",
|
||||
"advance",
|
||||
"arange",
|
||||
"argmin",
|
||||
"argmax",
|
||||
"atomic_add",
|
||||
"atomic_and",
|
||||
"atomic_cas",
|
||||
"atomic_max",
|
||||
"atomic_min",
|
||||
"atomic_or",
|
||||
"atomic_xchg",
|
||||
"atomic_xor",
|
||||
"bfloat16",
|
||||
"block_type",
|
||||
"broadcast",
|
||||
"broadcast_to",
|
||||
"builtin",
|
||||
"cat",
|
||||
"cdiv",
|
||||
"constexpr",
|
||||
"cos",
|
||||
"debug_barrier",
|
||||
"device_assert",
|
||||
"device_print",
|
||||
"dot",
|
||||
"dtype",
|
||||
"exp",
|
||||
"expand_dims",
|
||||
"extra",
|
||||
"fdiv",
|
||||
"float16",
|
||||
"float32",
|
||||
"float64",
|
||||
"float8e4",
|
||||
"float8e5",
|
||||
"full",
|
||||
"function_type",
|
||||
"int1",
|
||||
"int16",
|
||||
"int32",
|
||||
"int64",
|
||||
"int8",
|
||||
"ir",
|
||||
"math",
|
||||
"load",
|
||||
"log",
|
||||
"make_block_ptr",
|
||||
"max",
|
||||
"max_contiguous",
|
||||
"maximum",
|
||||
"min",
|
||||
"minimum",
|
||||
"multiple_of",
|
||||
"num_programs",
|
||||
"pair_uniform_to_normal",
|
||||
"philox",
|
||||
"philox_impl",
|
||||
"pi32_t",
|
||||
"pointer_type",
|
||||
"program_id",
|
||||
"rand",
|
||||
"rand4x",
|
||||
"randint",
|
||||
"randint4x",
|
||||
"randn",
|
||||
"randn4x",
|
||||
"ravel",
|
||||
"reduce",
|
||||
"reshape",
|
||||
"sigmoid",
|
||||
"sin",
|
||||
"softmax",
|
||||
"sqrt",
|
||||
"static_range",
|
||||
"static_assert",
|
||||
"static_print",
|
||||
"store",
|
||||
"sum",
|
||||
"swizzle2d",
|
||||
"tensor",
|
||||
"trans",
|
||||
"triton",
|
||||
"uint16",
|
||||
"uint32",
|
||||
"uint32_to_uniform_float",
|
||||
"uint64",
|
||||
"uint8",
|
||||
"umulhi",
|
||||
"view",
|
||||
"void",
|
||||
"where",
|
||||
"xor_sum",
|
||||
"zeros",
|
||||
"zeros_like",
|
||||
]
|
||||
BIN
pkgs/triton/language/__pycache__/__init__.cpython-310.pyc
Normal file
BIN
pkgs/triton/language/__pycache__/__init__.cpython-310.pyc
Normal file
Binary file not shown.
BIN
pkgs/triton/language/__pycache__/core.cpython-310.pyc
Normal file
BIN
pkgs/triton/language/__pycache__/core.cpython-310.pyc
Normal file
Binary file not shown.
BIN
pkgs/triton/language/__pycache__/math.cpython-310.pyc
Normal file
BIN
pkgs/triton/language/__pycache__/math.cpython-310.pyc
Normal file
Binary file not shown.
BIN
pkgs/triton/language/__pycache__/random.cpython-310.pyc
Normal file
BIN
pkgs/triton/language/__pycache__/random.cpython-310.pyc
Normal file
Binary file not shown.
BIN
pkgs/triton/language/__pycache__/semantic.cpython-310.pyc
Normal file
BIN
pkgs/triton/language/__pycache__/semantic.cpython-310.pyc
Normal file
Binary file not shown.
BIN
pkgs/triton/language/__pycache__/standard.cpython-310.pyc
Normal file
BIN
pkgs/triton/language/__pycache__/standard.cpython-310.pyc
Normal file
Binary file not shown.
1729
pkgs/triton/language/core.py
Normal file
1729
pkgs/triton/language/core.py
Normal file
File diff suppressed because it is too large
Load Diff
3
pkgs/triton/language/extra/__init__.py
Normal file
3
pkgs/triton/language/extra/__init__.py
Normal file
@@ -0,0 +1,3 @@
|
||||
from . import cuda
|
||||
|
||||
__all__ = ['cuda']
|
||||
BIN
pkgs/triton/language/extra/__pycache__/__init__.cpython-310.pyc
Normal file
BIN
pkgs/triton/language/extra/__pycache__/__init__.cpython-310.pyc
Normal file
Binary file not shown.
BIN
pkgs/triton/language/extra/__pycache__/cuda.cpython-310.pyc
Normal file
BIN
pkgs/triton/language/extra/__pycache__/cuda.cpython-310.pyc
Normal file
Binary file not shown.
BIN
pkgs/triton/language/extra/cuda.bc
Normal file
BIN
pkgs/triton/language/extra/cuda.bc
Normal file
Binary file not shown.
19
pkgs/triton/language/extra/cuda.py
Normal file
19
pkgs/triton/language/extra/cuda.py
Normal file
@@ -0,0 +1,19 @@
|
||||
import os
|
||||
|
||||
from .. import core
|
||||
|
||||
__path__ = os.path.dirname(os.path.abspath(__file__))
|
||||
|
||||
|
||||
@core.extern
def globaltimer(_builder=None):
    """Read the device's global nanosecond timer (extern from cuda.bc)."""
    # is_pure=False: the timer advances between calls, so it must not be CSE'd.
    return core.extern_elementwise("cuda", os.path.join(__path__, "cuda.bc"), [],
                                   {tuple(): ("globaltimer", core.dtype("int64")),
                                    }, is_pure=False, _builder=_builder)
|
||||
|
||||
|
||||
@core.extern
def smid(_builder=None):
    """Return the ID of the streaming multiprocessor executing the program
    (extern from cuda.bc)."""
    return core.extern_elementwise("cuda", os.path.join(__path__, "cuda.bc"), [],
                                   {tuple(): ("smid", core.dtype("int32")),
                                    }, is_pure=True, _builder=_builder)
|
||||
1537
pkgs/triton/language/math.py
Normal file
1537
pkgs/triton/language/math.py
Normal file
File diff suppressed because it is too large
Load Diff
178
pkgs/triton/language/random.py
Normal file
178
pkgs/triton/language/random.py
Normal file
@@ -0,0 +1,178 @@
|
||||
import triton
|
||||
from . import core as tl
|
||||
|
||||
PHILOX_KEY_A: tl.constexpr = 0x9E3779B9
|
||||
PHILOX_KEY_B: tl.constexpr = 0xBB67AE85
|
||||
PHILOX_ROUND_A: tl.constexpr = 0xD2511F53
|
||||
PHILOX_ROUND_B: tl.constexpr = 0xCD9E8D57
|
||||
N_ROUNDS_DEFAULT = 10 # Default number of rounds for philox
|
||||
|
||||
# -------------------
|
||||
# randint
|
||||
# -------------------
|
||||
|
||||
|
||||
@triton.jit
def philox_impl(c0, c1, c2, c3, k0, k1, n_rounds: tl.constexpr = N_ROUNDS_DEFAULT):
    """
    Run `n_rounds` rounds of Philox for state (c0, c1, c2, c3) and key (k0, k1).
    """
    # static_range unrolls the loop at compile time (n_rounds is constexpr).
    for _ in tl.static_range(n_rounds):
        # for _ in range(n_rounds):
        # update random state
        A = PHILOX_ROUND_A
        B = PHILOX_ROUND_B
        _c0, _c2 = c0, c2
        # One Philox-4x32 round: multiply-high mixes, XORed with key words.
        c0 = tl.umulhi(B, _c2) ^ c1 ^ k0
        c2 = tl.umulhi(A, _c0) ^ c3 ^ k1
        c1 = B * _c2
        c3 = A * _c0
        # raise key (Weyl sequence on the key words)
        k0 = k0 + PHILOX_KEY_A
        k1 = k1 + PHILOX_KEY_B
    return c0, c1, c2, c3
|
||||
|
||||
|
||||
@triton.jit
def philox(seed, c0, c1, c2, c3, n_rounds: tl.constexpr = N_ROUNDS_DEFAULT):
    """Split the 64-bit seed into two 32-bit key words and run Philox on the
    (c0..c3) counter block, returning four 32-bit outputs."""
    seed = seed.to(tl.uint64)
    seed_hi = ((seed >> 32) & 0xffffffff).to(tl.uint32)
    seed_lo = (seed & 0xffffffff).to(tl.uint32)
    # bitcast: reinterpret counter bits as uint32 without value conversion
    c0 = c0.to(tl.uint32, bitcast=True)
    c1 = c1.to(tl.uint32, bitcast=True)
    c2 = c2.to(tl.uint32, bitcast=True)
    c3 = c3.to(tl.uint32, bitcast=True)
    return philox_impl(c0, c1, c2, c3, seed_lo, seed_hi, n_rounds)
|
||||
|
||||
|
||||
@triton.jit
def randint(seed, offset, n_rounds: tl.constexpr = N_ROUNDS_DEFAULT):
    """
    Given a :code:`seed` scalar and an :code:`offset` block, returns a single
    block of random :code:`int32`.

    If you need multiple streams of random numbers,
    using `randint4x` is likely to be faster than calling `randint` 4 times.

    :param seed: The seed for generating random numbers.
    :param offsets: The offsets to generate random numbers for.
    """
    # Philox produces 4 outputs per counter; keep the first, discard the rest.
    ret, _, _, _ = randint4x(seed, offset, n_rounds)
    return ret
|
||||
|
||||
|
||||
@triton.jit
def randint4x(seed, offset, n_rounds: tl.constexpr = N_ROUNDS_DEFAULT):
    """
    Given a :code:`seed` scalar and an :code:`offset` block, returns four
    blocks of random :code:`int32`.

    This is the maximally efficient entry point
    to Triton's Philox pseudo-random number generator.

    :param seed: The seed for generating random numbers.
    :param offsets: The offsets to generate random numbers for.
    """
    # _0 = tl.zeros(offset.shape, offset.dtype)
    # Zero block with the same shape/dtype as `offset`, used to pad the counter.
    _0 = offset * 0
    return philox(seed, offset, _0, _0, _0, n_rounds)
|
||||
|
||||
|
||||
# -------------------
|
||||
# rand
|
||||
# -------------------
|
||||
|
||||
# @triton.jit
|
||||
# def uint32_to_uniform_float(x):
|
||||
# """
|
||||
# Numerically stable function to convert a random uint32 into a random float uniformly sampled in [0, 1).
|
||||
# """
|
||||
# two_to_the_minus_32: tl.constexpr = 2.328306e-10
|
||||
# return x * two_to_the_minus_32
|
||||
|
||||
@triton.jit
def uint32_to_uniform_float(x):
    """
    Numerically stable function to convert a random uint32 into a random float uniformly sampled in [0, 1).
    """
    x = x.to(tl.int32, bitcast=True)
    # maximum value such that `MAX_INT * scale < 1.0` (with float rounding)
    scale = 4.6566127342e-10
    # Fold negative ints onto [0, 2**31): -x - 1 maps [-2**31, -1] to [0, 2**31 - 1].
    x = tl.where(x < 0, -x - 1, x)
    return x * scale
|
||||
|
||||
|
||||
@triton.jit
def rand(seed, offset, n_rounds: tl.constexpr = N_ROUNDS_DEFAULT):
    """
    Given a :code:`seed` scalar and an :code:`offset` block,
    returns a block of random :code:`float32` in :math:`U(0, 1)`.

    :param seed: The seed for generating random numbers.
    :param offsets: The offsets to generate random numbers for.
    """
    # Counter must be uint32 for Philox; bitcast preserves the bit pattern.
    offset = offset.to(tl.uint32, bitcast=True)
    source = randint(seed, offset, n_rounds)
    return uint32_to_uniform_float(source)
|
||||
|
||||
|
||||
@triton.jit
def rand4x(seed, offsets, n_rounds: tl.constexpr = N_ROUNDS_DEFAULT):
    """
    Given a :code:`seed` scalar and an :code:`offsets` block,
    returns a 4 blocks of random :code:`float32` in :math:`U(0, 1)`.

    :param seed: The seed for generating random numbers.
    :param offsets: The offsets to generate random numbers for.
    """
    offsets = offsets.to(tl.uint32, bitcast=True)
    i1, i2, i3, i4 = randint4x(seed, offsets, n_rounds)
    u1 = uint32_to_uniform_float(i1)
    u2 = uint32_to_uniform_float(i2)
    u3 = uint32_to_uniform_float(i3)
    u4 = uint32_to_uniform_float(i4)
    return u1, u2, u3, u4
|
||||
|
||||
# -------------------
|
||||
# randn
|
||||
# -------------------
|
||||
|
||||
|
||||
@triton.jit
def pair_uniform_to_normal(u1, u2):
    """Box-Muller transform: map two U(0,1) samples to two N(0,1) samples."""
    # Clamp away from 0 so log(u1) stays finite.
    u1 = tl.maximum(1.0e-7, u1)
    th = 6.283185307179586 * u2  # 2*pi*u2
    r = tl.sqrt(-2.0 * tl.log(u1))
    return r * tl.cos(th), r * tl.sin(th)
|
||||
|
||||
|
||||
@triton.jit
def randn(seed, offset, n_rounds: tl.constexpr = N_ROUNDS_DEFAULT):
    """
    Given a :code:`seed` scalar and an :code:`offset` block,
    returns a block of random :code:`float32` in :math:`\\mathcal{N}(0, 1)`.

    :param seed: The seed for generating random numbers.
    :param offsets: The offsets to generate random numbers for.
    """
    # Box-Muller consumes two uniforms per normal; two Philox outputs are unused.
    i1, i2, _, _ = randint4x(seed, offset, n_rounds)
    u1 = uint32_to_uniform_float(i1)
    u2 = uint32_to_uniform_float(i2)
    n1, _ = pair_uniform_to_normal(u1, u2)
    return n1
|
||||
|
||||
|
||||
@triton.jit
def randn4x(seed, offset, n_rounds: tl.constexpr = N_ROUNDS_DEFAULT):
    """
    Given a :code:`seed` scalar and an :code:`offset` block,
    returns a 4 blocks of random :code:`float32` in :math:`\\mathcal{N}(0, 1)`.

    :param seed: The seed for generating random numbers.
    :param offsets: The offsets to generate random numbers for.
    """
    u1, u2, u3, u4 = rand4x(seed, offset, n_rounds)
    n1, n2 = pair_uniform_to_normal(u1, u2)
    n3, n4 = pair_uniform_to_normal(u3, u4)
    return n1, n2, n3, n4
|
||||
1491
pkgs/triton/language/semantic.py
Normal file
1491
pkgs/triton/language/semantic.py
Normal file
File diff suppressed because it is too large
Load Diff
98
pkgs/triton/language/standard.py
Normal file
98
pkgs/triton/language/standard.py
Normal file
@@ -0,0 +1,98 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from ..runtime.jit import jit
|
||||
from . import core
|
||||
|
||||
# -----------------------
|
||||
# Standard library
|
||||
# -----------------------
|
||||
|
||||
|
||||
@jit
def cdiv(x, div):
    """
    Computes the ceiling division of :code:`x` by :code:`div`

    :param x: the input number
    :type input: Block
    :param div: the divisor
    :param div: Block
    """
    return (x + div - 1) // div
|
||||
|
||||
|
||||
@jit
@core._add_math_1arg_docstr("sigmoid")
def sigmoid(x):
    # Logistic function 1 / (1 + e^-x).
    return 1 / (1 + core.exp(-x))
|
||||
|
||||
|
||||
@jit
@core._add_math_1arg_docstr("softmax")
def softmax(x, ieee_rounding=False):
    # Subtract the max along axis 0 for numerical stability before exponentiating.
    z = x - core.max(x, 0)
    num = core.exp(z)
    den = core.sum(num, 0)
    return core.fdiv(num, den, ieee_rounding)
|
||||
|
||||
|
||||
@jit
def ravel(x):
    """
    Returns a contiguous flattened view of :code:`x`.

    :param x: the input tensor
    :type x: Block
    """
    return core.view(x, [x.numel])
|
||||
|
||||
|
||||
@jit
def swizzle2d(i, j, size_i, size_j, size_g):
    """
    Transforms indices of a row-major size_i*size_j matrix into those
    of one where indices are row major for each group of size_j rows.
    For example, for size_i = size_j = 4 and size_g = 2, it will transform
    [[0 , 1 , 2 , 3 ],
     [4 , 5 , 6 , 7 ],
     [8 , 9 , 10, 11],
     [12, 13, 14, 15]]
    into
    [[0, 2,  4 , 6 ],
     [1, 3,  5 , 7 ],
     [8, 10, 12, 14],
     [9, 11, 13, 15]]
    """
    # "unrolled index in array"
    ij = i * size_j + j
    # number of elements in `size_g` groups
    # of `size_j` columns
    size_gj = size_g * size_j
    # index of the group in which (i,j) is
    group_id = ij // size_gj
    # row-index of the first element of this group
    off_i = group_id * size_g
    # last group may have fewer rows
    size_g = core.minimum(size_i - off_i, size_g)
    # new row and column indices
    new_i = off_i + (ij % size_g)
    new_j = (ij % size_gj) // size_g
    return new_i, new_j
|
||||
|
||||
|
||||
@jit
def zeros(shape, dtype):
    """
    Returns a tensor filled with the scalar value 0 for the given :code:`shape` and :code:`dtype`.

    :param shape: Shape of the new array, e.g., (8, 16) or (8, )
    :type shape: tuple of ints
    :param dtype: Data-type of the new array, e.g., :code:`tl.float16`
    :type dtype: DType
    """
    return core.full(shape, 0, dtype)
|
||||
|
||||
|
||||
@jit
def zeros_like(input):
    """Returns a zero-filled tensor with the same shape and dtype as `input`."""
    return zeros(input.shape, input.dtype)
|
||||
17
pkgs/triton/ops/__init__.py
Normal file
17
pkgs/triton/ops/__init__.py
Normal file
@@ -0,0 +1,17 @@
|
||||
# from .conv import _conv, conv
|
||||
from . import blocksparse
|
||||
from .cross_entropy import _cross_entropy, cross_entropy
|
||||
from .flash_attention import attention
|
||||
from .matmul import _matmul, matmul
|
||||
from .bmm_matmul import _bmm, bmm
|
||||
|
||||
__all__ = [
|
||||
"blocksparse",
|
||||
"_cross_entropy",
|
||||
"cross_entropy",
|
||||
"_matmul",
|
||||
"matmul",
|
||||
"_bmm",
|
||||
"bmm",
|
||||
"attention",
|
||||
]
|
||||
BIN
pkgs/triton/ops/__pycache__/__init__.cpython-310.pyc
Normal file
BIN
pkgs/triton/ops/__pycache__/__init__.cpython-310.pyc
Normal file
Binary file not shown.
BIN
pkgs/triton/ops/__pycache__/bmm_matmul.cpython-310.pyc
Normal file
BIN
pkgs/triton/ops/__pycache__/bmm_matmul.cpython-310.pyc
Normal file
Binary file not shown.
BIN
pkgs/triton/ops/__pycache__/cross_entropy.cpython-310.pyc
Normal file
BIN
pkgs/triton/ops/__pycache__/cross_entropy.cpython-310.pyc
Normal file
Binary file not shown.
BIN
pkgs/triton/ops/__pycache__/flash_attention.cpython-310.pyc
Normal file
BIN
pkgs/triton/ops/__pycache__/flash_attention.cpython-310.pyc
Normal file
Binary file not shown.
BIN
pkgs/triton/ops/__pycache__/matmul.cpython-310.pyc
Normal file
BIN
pkgs/triton/ops/__pycache__/matmul.cpython-310.pyc
Normal file
Binary file not shown.
BIN
pkgs/triton/ops/__pycache__/matmul_perf_model.cpython-310.pyc
Normal file
BIN
pkgs/triton/ops/__pycache__/matmul_perf_model.cpython-310.pyc
Normal file
Binary file not shown.
7
pkgs/triton/ops/blocksparse/__init__.py
Normal file
7
pkgs/triton/ops/blocksparse/__init__.py
Normal file
@@ -0,0 +1,7 @@
|
||||
from .matmul import matmul
|
||||
from .softmax import softmax
|
||||
|
||||
__all__ = [
|
||||
"matmul",
|
||||
"softmax",
|
||||
]
|
||||
BIN
pkgs/triton/ops/blocksparse/__pycache__/__init__.cpython-310.pyc
Normal file
BIN
pkgs/triton/ops/blocksparse/__pycache__/__init__.cpython-310.pyc
Normal file
Binary file not shown.
BIN
pkgs/triton/ops/blocksparse/__pycache__/matmul.cpython-310.pyc
Normal file
BIN
pkgs/triton/ops/blocksparse/__pycache__/matmul.cpython-310.pyc
Normal file
Binary file not shown.
BIN
pkgs/triton/ops/blocksparse/__pycache__/softmax.cpython-310.pyc
Normal file
BIN
pkgs/triton/ops/blocksparse/__pycache__/softmax.cpython-310.pyc
Normal file
Binary file not shown.
437
pkgs/triton/ops/blocksparse/matmul.py
Normal file
437
pkgs/triton/ops/blocksparse/matmul.py
Normal file
@@ -0,0 +1,437 @@
|
||||
import torch
|
||||
|
||||
import triton
|
||||
import triton.language as tl
|
||||
|
||||
# ********************************************************
|
||||
# --------------------------------------------------------
|
||||
# Sparse = Dense x Dense (SDD)
|
||||
# This operation uses super-blocking to make sure that
|
||||
# it's done efficiently when small blocks can be grouped
|
||||
# together
|
||||
# --------------------------------------------------------
|
||||
# ********************************************************
|
||||
|
||||
|
||||
@triton.heuristics({
    'EVEN_K': lambda nargs: nargs['K'] % nargs['TILE_K'] == 0,
})
@triton.jit
def _sdd_kernel(
    A, B, C,
    stride_za, stride_ha, stride_ma, stride_ak,
    stride_zb, stride_hb, stride_bk, stride_nb,
    stride_zc, stride_hc, stride_mc, stride_nc,
    K, grid_offset, lut,
    TILE_M: tl.constexpr, TILE_N: tl.constexpr, TILE_K: tl.constexpr,
    BLOCK: tl.constexpr, EVEN_K: tl.constexpr
):
    """Sparse = Dense x Dense: each program computes one nonzero output block,
    whose (head, row, col) coordinates come from a 3-int LUT entry."""
    # ------------ #
    # - Prologue - #
    # ------------ #
    block_id = tl.program_id(0) + grid_offset
    lut += block_id * 3  # each LUT record is 3 ints: head, row-block, col-block
    # offsets
    off_z = tl.program_id(2)  # batch
    off_h = tl.load(lut + 0)  # head

    # initialize pointers to A
    start_am = tl.load(lut + 1)
    # `% BLOCK` keeps the tile inside one sparse block even when TILE_M > BLOCK
    offs_am = start_am * BLOCK + (tl.arange(0, TILE_M) % BLOCK)
    offs_ak = tl.arange(0, TILE_K)
    a_ptrs = A \
        + off_z * stride_za \
        + off_h * stride_ha \
        + offs_am[:, None] * stride_ma \
        + offs_ak[None, :] * stride_ak
    # initialize pointers to B
    start_bn = tl.load(lut + 2)
    offs_bn = start_bn * BLOCK + (tl.arange(0, TILE_N) % BLOCK)
    offs_bk = tl.arange(0, TILE_K)
    b_ptrs = B \
        + off_z * stride_zb \
        + off_h * stride_hb \
        + offs_bn[None, :] * stride_nb \
        + offs_bk[:, None] * stride_bk
    # ---------------- #
    #    Inner Loop    #
    # ---------------- #
    acc = tl.zeros((TILE_M, TILE_N), dtype=tl.float32)
    for k in range(K, 0, -TILE_K):
        if EVEN_K:
            # K divisible by TILE_K: no masking needed (heuristic above)
            a = tl.load(a_ptrs)
            b = tl.load(b_ptrs)
        else:
            a = tl.load(a_ptrs, mask=offs_ak[None, :] < k, other=0.)
            b = tl.load(b_ptrs, mask=offs_bk[:, None] < k, other=0.)
        acc += tl.dot(a, b, out_dtype=tl.float32)
        a_ptrs += TILE_K * stride_ak
        b_ptrs += TILE_K * stride_bk
    c = acc.to(C.dtype.element_ty)
    # ---------------- #
    #     Epilogue     #
    # ---------------- #
    # C is stored block-compressed: block_id indexes the packed block dimension.
    offs_cm = tl.arange(0, TILE_M) % BLOCK
    offs_cn = tl.arange(0, TILE_N) % BLOCK
    pc = C \
        + off_z * stride_zc \
        + block_id * stride_hc \
        + offs_cm[:, None] * stride_mc \
        + offs_cn[None, :] * stride_nc
    tl.store(pc, c, mask=True)
|
||||
|
||||
|
||||
def sdd_matmul(a, b, trans_a, trans_b, trans_c, spdims, block, lut, widths, out=None):
    """Launch the SDD (Sparse = Dense x Dense) kernel.

    Returns a block-compressed tensor of shape (batch, nnz_blocks, block, block)
    holding the nonzero output blocks listed in `lut`.
    """
    # Kernel needs the innermost dim contiguous in at least one layout.
    if a.stride(2) != 1 and a.stride(3) != 1:
        a = a.contiguous()
    if b.stride(2) != 1 and b.stride(3) != 1:
        b = b.contiguous()
    # (A * B)^T = B^T * A^T
    if trans_c:
        a, b = b, a
        trans_a, trans_b = not trans_b, not trans_a
    # shape constraints
    a_dim = -2 if trans_a else -1
    b_dim = -1 if trans_b else -2
    Ka, Kb = a.shape[a_dim], b.shape[b_dim]
    if Ka != Kb:
        raise ValueError(f"Inner dimension mismatch (A: {Ka} vs B: {Kb})")
    # allocate output
    if out is None:
        c = torch.empty((a.shape[0], lut.shape[0], block, block), dtype=a.dtype, device=a.device)
    else:
        assert out.shape == (a.shape[0], lut.shape[0], block, block)
        c = out
    # one program per nonzero block, per batch element
    grid = [c.shape[1], 1, c.shape[0]]
    _sdd_kernel[grid](
        a, b, c,
        a.stride(0), a.stride(1), a.stride(3 if trans_a else 2), a.stride(2 if trans_a else 3),
        b.stride(0), b.stride(1), b.stride(3 if trans_b else 2), b.stride(2 if trans_b else 3),
        c.stride(0), c.stride(1), c.stride(2), c.stride(3),
        Ka, 0, lut,
        TILE_M=block, TILE_N=block, TILE_K=32, BLOCK=block, num_stages=4,
        num_warps=4,
    )
    return c
|
||||
|
||||
|
||||
def sdd_lut(layout, block, device):
    """Build the SDD look-up table.

    One (head, row-block, col-block) int32 triple per nonzero entry of
    `layout`, moved to `device`. The second return value (locks) is
    always None for SDD.
    """
    nonzero_coords = layout.nonzero(as_tuple=False)
    lut = nonzero_coords.to(device).int().contiguous()
    return lut, None
|
||||
|
||||
# -----------------------------
|
||||
# Dense = Sparse x Dense (DSD)
|
||||
# This operation uses a look-up table that contains pre-computed pointer increments
|
||||
# in order to minimize computations in the inner loop of the matmul kernel.
|
||||
# -----------------------------
|
||||
|
||||
|
||||
@triton.jit
def _dsd_kernel(
    A, B, C,
    stride_az, stride_ha, stride_am, stride_ak,
    stride_zb, stride_hb, stride_bk, stride_bn,
    stride_zc, stride_hc, stride_cm, stride_cn,
    DS0, DS1, lut,
    TILE_M: tl.constexpr, TILE_N: tl.constexpr, TILE_K: tl.constexpr,
    GROUP_SIZE_M: tl.constexpr, BLOCK: tl.constexpr
):
    """Dense = Sparse x Dense: pointer increments for the inner loop are
    pre-computed in `lut` (header of 4 ints per output row-block, followed
    by interleaved (B, A) increment pairs)."""
    # ------------ #
    # - Prologue - #
    # ------------ #
    pid_m = tl.program_id(0)
    pid_n = tl.program_id(1)
    num_pid_m = tl.num_programs(0)
    num_pid_n = tl.num_programs(1)
    # group programs for better L2 reuse
    pid_n, pid_m = tl.swizzle2d(pid_n, pid_m, num_pid_n, num_pid_m, GROUP_SIZE_M)
    pidz = tl.program_id(2)
    # header record: [offset into increments, K extent, output column, head]
    header = lut + pid_n * 4
    offset = tl.load(header + 0)
    K = tl.load(header + 1)
    column = tl.load(header + 2)
    off_h = tl.load(header + 3)
    pinc = lut + offset
    # initialize pointers to A (sparse)
    block_id = tl.load(pinc + 1)
    block_id = tl.multiple_of(block_id, 8)  # compiler hint
    offs_am = tl.arange(0, TILE_M)
    offs_ak = tl.arange(0, TILE_K)
    pa = A + pidz * stride_az \
        + block_id * stride_ha \
        + offs_am[:, None] * stride_am \
        + offs_ak[None, :] * stride_ak
    # initialize pointers to B (dense)
    offs_bn = pid_m * TILE_N + tl.arange(0, TILE_N)
    offs_bn = tl.max_contiguous(tl.multiple_of(offs_bn % DS0, TILE_N), TILE_N)
    start_bk = tl.load(pinc)
    start_bk = tl.multiple_of(start_bk, 8)  # compiler hint
    offs_bk = start_bk + tl.arange(0, TILE_K)
    pb = B + pidz * stride_zb \
        + off_h * stride_hb \
        + offs_bn[None, :] * stride_bn \
        + offs_bk[:, None] * stride_bk
    # ---------------- #
    #    Inner Loop    #
    # ---------------- #
    acc = tl.zeros((TILE_M, TILE_N), dtype=tl.float32)
    # increments are pre-fetched one iteration ahead
    pinc += 2
    inc_a = tl.load(pinc + 1)
    inc_a = tl.multiple_of(inc_a, 8)
    inc_b = tl.load(pinc)
    inc_b = tl.multiple_of(inc_b, 8)
    for k in range(K, 0, -TILE_K):
        a = tl.load(pa)
        b = tl.load(pb)
        acc += tl.dot(a, b, out_dtype=tl.float32)
        pa += inc_a
        pb += inc_b * stride_bk
        pinc += 2
        inc_a = tl.load(pinc + 1)
        inc_a = tl.multiple_of(inc_a, 8)
        inc_b = tl.load(pinc)
        inc_b = tl.multiple_of(inc_b, 8)
    c = acc.to(C.dtype.element_ty)
    # initialize pointers to C
    offs_cm = column * TILE_M + tl.arange(0, TILE_M)
    offs_cn = pid_m * TILE_N + tl.arange(0, TILE_N)
    pc = C \
        + off_h * stride_hc \
        + pidz * stride_zc \
        + offs_cm[:, None] * stride_cm \
        + offs_cn[None, :] * stride_cn
    # mask out columns past the dense dimension DS0
    tl.store(pc, c, mask=offs_cn[None, :] < DS0)
|
||||
|
||||
|
||||
def dsd_matmul(a, b, trans_a, trans_b, trans_c, spdims, block, lut, width, out=None):
    """Launch the DSD (Dense = Sparse x Dense) kernel.

    `a` is block-compressed sparse, `b` is dense; `lut`/`width` come
    from `dsd_lut`. Returns a dense 4-D tensor.
    """
    if a.stride(2) != 1 and a.stride(3) != 1:
        a = a.contiguous()
    if b.stride(2) != 1 and b.stride(3) != 1:
        b = b.contiguous()
    # shapes / dtypes
    AS1 = block * spdims[2 if trans_a else 1]
    BS0 = b.size(0)
    BS1 = b.size(1)
    BS3 = b.size(2 if trans_b else 3)
    dtype = a.dtype
    # allocate output
    CS0 = BS0
    CS1 = BS1
    CS2 = BS3 if trans_c else AS1
    CS3 = AS1 if trans_c else BS3
    if out is None:
        c = torch.empty((CS0, CS1, CS2, CS3), dtype=dtype, device=a.device)
    else:
        assert out.shape == (CS0, CS1, CS2, CS3)
        c = out
    # meta-parameter heuristics
    TILE_N = 128
    # compute output: one program per (N-tile, output row-block, batch)
    grid = lambda meta: [triton.cdiv(BS3, meta['TILE_N']), width, BS0]
    _dsd_kernel[grid](
        a, b, c,
        a.stride(0), a.stride(1), a.stride(3 if trans_a else 2), a.stride(2 if trans_a else 3),
        b.stride(0), b.stride(1), b.stride(3 if trans_b else 2), b.stride(2 if trans_b else 3),
        c.stride(0), c.stride(1), c.stride(3 if trans_c else 2), c.stride(2 if trans_c else 3),
        BS3, AS1, lut,
        TILE_M=block, TILE_N=TILE_N, TILE_K=min(block, 32), BLOCK=block, num_stages=4,
        num_warps=4, GROUP_SIZE_M=4,
    )
    return c
|
||||
|
||||
|
||||
def dsd_lut(layout, block, step, trans, device):
    """
    Generates the look-up table for incrementing pointers in the DSD/DDS matmul.
    Example (BLOCK=32, STEP=16)
    [[1, 0, 0, 1, 0],
     [0, 1, 1, 0, 1],
     [1, 0, 1, 0, 0]]

    Then the offsets for A are
     [0 , 16, 32, 48] <- row 0
      \\----/  \\----/
      col=0   col=3
     [64, 80, 96, 112, 128, 144] <- row 1
      \\----/   \\----/  \\------/
       col=1    col=2    col=3
     [160, 176, 192, 208]
    which leads to increments table
    [0, 16, 16, 16, || 64, 16, 16, 16, 16, 16, || 160, 16, 16, 16]

    Because B is dense, the offsets are
    [0, 16, 96, 112] <- row 0
    [32, 48, 64, 80] <- row 1
    [0, 16, 64, 80]  <- row 2
    """
    # nonzero blocks per output row-block (per head)
    sizes = torch.sum(layout, 2 if trans else 1)
    head_id, col_id = torch.ones_like(sizes).nonzero(as_tuple=True)
    sizes = sizes.flatten()
    segments = sizes * step
    # pointer increments
    if trans:
        nnz = layout.nonzero(as_tuple=False)
    else:
        nnz = layout.transpose(1, 2).nonzero(as_tuple=False)
    num_blocks = nnz.size(0)
    # start index of each row-block's run in the nnz list
    offsets = torch.zeros_like(sizes)
    offsets[1:] = torch.cumsum(sizes[:-1], dim=0)
    offsets = torch.min(offsets, (num_blocks - 1) * torch.ones_like(offsets))
    # -------------------------------
    # dense input pointer increments
    # -------------------------------
    # Note that the inner loop matmul kernel may have a fixed step size (e.g., TILE_K)
    # that is smaller than the block size, so we need to do a bit of extra work
    # to handle this case
    B_idx = nnz[:, 2] * block
    B_incs = B_idx.clone()
    B_incs[1:] -= B_idx[:-1]
    div = block // step
    B_incs = B_incs.view(-1, 1).repeat(1, div)
    B_incs[:, 1:] = step
    B_incs[:, 0] -= (div - 1) * step
    # first increment for each reduction is actually the offset
    B_incs[offsets[segments > 0], 0] = B_idx[offsets[segments > 0]]
    B_incs = B_incs.view(-1)
    # -------------------------------
    # sparse input pointer increments
    # -------------------------------
    # same as above, except that the increments are in the sparse memory layout
    if trans:
        A_idx = torch.arange(num_blocks, device=layout.device)
    else:
        # walk heads, renumbering nonzeros in transposed order
        A_idx = torch.tensor([], dtype=torch.int64, device=layout.device)
        current_offset = 0
        for z in range(layout.size(0)):
            layoutw = layout[z, :, :].clone().long()
            msum = layoutw.sum()
            layoutw[layoutw > 0] = 1 + torch.arange(msum, device=layout.device)
            A_idx = torch.cat((A_idx, current_offset + layoutw.T[layoutw.T > 0] - 1))
            current_offset += msum
    A_incs = A_idx * block * block
    A_incs[1:] -= A_idx[:-1] * block * block
    A_incs = A_incs.view(-1, 1).repeat(1, div)
    if trans:
        A_incs[:, 1:] = step
        A_incs[:, 0] -= (div - 1) * step
    else:
        A_incs[:, 1:] = step * block
        A_incs[:, 0] -= (div - 1) * step * block
    A_incs[offsets[segments > 0], 0] = A_idx[offsets[segments > 0]]
    A_incs = A_incs.view(-1)
    # create header: per row-block [offset, segments, col_id, head_id]
    width = col_id.size(0)
    offsets = offsets * 2 * div + 4 * width
    segments = segments * div
    header = torch.stack((offsets, segments, col_id, head_id), dim=1).view(-1).contiguous()
    # create increments (interleaved B/A pairs, as read by the kernel)
    incs = torch.stack((B_incs, A_incs), dim=1).view(-1).contiguous()
    # pad by a factor 2*MAX_NUM_STAGES
    # to accommodate pre-fetching inside the kernel
    pad = torch.zeros(20, device=incs.device, dtype=incs.dtype)
    incs = torch.cat((incs, pad))
    # create lut
    lut = torch.cat((header, incs))
    lut = lut.type(torch.int32).to(device)
    return lut, width
|
||||
|
||||
# -----------------------------
|
||||
# Dense = Dense x Sparse (DDS)
|
||||
# -----------------------------
|
||||
# AB = (B^T A^T)^T
|
||||
|
||||
|
||||
def dds_matmul(a, b, trans_a, trans_b, trans_c, spdims, block, lut, width, out=None):
    """Dense = dense x sparse (DDS) block-sparse matmul.

    Implemented by delegating to the DSD kernel via the identity
    AB = (B^T A^T)^T: operands are swapped and every transpose flag flipped.
    `lut`/`width` must be the lookup table built for this mode.
    """
    return dsd_matmul(b, a, not trans_b, not trans_a, not trans_c, spdims, block, lut, width, out=out)
|
||||
|
||||
##############
|
||||
# MAIN API #
|
||||
##############
|
||||
|
||||
|
||||
class _matmul(torch.autograd.Function):
    """Autograd wrapper dispatching to the three block-sparse matmul modes.

    Mode strings are 3 characters ('s'=sparse, 'd'=dense) naming the layouts
    of (output, lhs, rhs); the gradient of each mode is computed by permuting
    the mode string and reusing the same kernels.
    """

    # dispatch table: mode string -> kernel entry point
    fn = {'sdd': sdd_matmul, 'dsd': dsd_matmul, 'dds': dds_matmul}

    @staticmethod
    def forward(
        ctx, a, b, trans_a, trans_b, trans_c, mode, spdims, block,
        c_lut, c_width, da_lut, da_width, db_lut, db_width, out
    ):
        c = _matmul.fn[mode](a, b, trans_a, trans_b, trans_c, spdims, block, c_lut, c_width, out=out)
        # save for backward: tensors via save_for_backward, plain python
        # values (LUTs are precomputed int tensors, treated as constants)
        # directly on ctx
        ctx.save_for_backward(a, b)
        ctx.da_lut = da_lut
        ctx.da_width = da_width
        ctx.db_lut = db_lut
        ctx.db_width = db_width
        ctx.mode = mode
        ctx.spdims = spdims
        ctx.block = block
        ctx.trans_a = trans_a
        ctx.trans_b = trans_b
        ctx.trans_c = trans_c
        ctx.has_out = out is not None
        return c

    @staticmethod
    def backward(ctx, dc):
        # saved for backward
        a, b = ctx.saved_tensors
        da, db = None, None
        mode = ctx.mode
        # gradients w.r.t. a: da = dc @ b^T, expressed by permuting the mode
        # string (output layout of da is the lhs layout of the forward mode)
        if ctx.needs_input_grad[0]:
            mode_da = mode[1] + mode[0] + mode[2]
            da = _matmul.fn[mode_da](
                dc, b, ctx.trans_c, not ctx.trans_b, ctx.trans_a, ctx.spdims, ctx.block, ctx.da_lut, ctx.da_width,
            )
        # gradients w.r.t. b: db = a^T @ dc, same permutation trick
        if ctx.needs_input_grad[1]:
            mode_db = mode[2] + mode[1] + mode[0]
            db = _matmul.fn[mode_db](
                a, dc, not ctx.trans_a, ctx.trans_c, ctx.trans_b, ctx.spdims, ctx.block, ctx.db_lut, ctx.db_width,
            )
        # `out` (the last forward input) only receives a gradient when it was
        # actually supplied
        dout = dc if ctx.has_out else None
        return da, db, None, None, None,\
            None, None, None, None,\
            None, None, None, None, None, dout
|
||||
|
||||
|
||||
class matmul:
    """User-facing block-sparse matmul.

    Precomputes, at construction time, the lookup tables needed for the
    forward product and for both backward products, so that repeated calls
    with the same sparsity `layout` and `block` size pay the LUT cost once.
    """

    def __init__(self, layout, block, mode, device, trans_a=False, trans_b=False, trans_c=False):
        if mode not in ['sdd', 'dsd', 'dds']:
            raise NotImplementedError('Supported modes are: sdd, dsd, dds')
        self.block = block
        self.mode = mode
        self.trans_a = trans_a
        self.trans_b = trans_b
        self.trans_c = trans_c
        self.layout = layout
        self.spdims = layout.shape
        # inner-loop step of the DSD kernel; capped at 32 regardless of block
        step = min(block, 32)
        # one LUT for the forward output (c) and one per input gradient
        # (da, db); which builder is used depends on which operand is sparse
        if self.mode == 'sdd':
            self.c_lut, self.c_width = sdd_lut(layout, block, device)
            self.da_lut, self.da_width = dsd_lut(layout, block, step, True, device)
            self.db_lut, self.db_width = dsd_lut(layout, block, step, False, device)
        if self.mode == 'dsd':
            self.c_lut, self.c_width = dsd_lut(layout, block, step, not self.trans_a, device)
            self.da_lut, self.da_width = sdd_lut(layout, block, device)
            self.db_lut, self.db_width = dsd_lut(layout, block, step, self.trans_a, device)
        if self.mode == 'dds':
            self.c_lut, self.c_width = dsd_lut(layout, block, step, self.trans_b, device)
            self.da_lut, self.da_width = dsd_lut(layout, block, step, not self.trans_b, device)
            self.db_lut, self.db_width = sdd_lut(layout, block, device)

    def __call__(self, a, b, out=None):
        """Compute the block-sparse product of `a` and `b` (optionally into `out`)."""
        c = _matmul.apply(
            a, b, self.trans_a, self.trans_b, self.trans_c, self.mode, self.spdims, self.block,
            self.c_lut, self.c_width,
            self.da_lut, self.da_width,
            self.db_lut, self.db_width,
            out
        )
        return c
|
||||
==== new file: pkgs/triton/ops/blocksparse/softmax.py (239 lines) ====
|
||||
import torch
|
||||
|
||||
import triton
|
||||
import triton.language as tl
|
||||
|
||||
|
||||
def num_warps(n):
    """Heuristic launch warp count for a softmax row of length *n*.

    Wider rows get more warps, saturating at 16.
    """
    thresholds = ((128, 1), (256, 2), (512, 4), (4096, 8))
    for limit, warps in thresholds:
        if n <= limit:
            return warps
    return 16
|
||||
|
||||
|
||||
# Forward block-sparse softmax kernel. One program instance handles one row
# (grid = [heads, rows, batch]); ROW_SIZE covers the padded row of non-zero
# blocks, BLOCK_SIZE is the sparsity block size.
@triton.jit
def _blocksparse_softmax_fwd(
    Out, A, stride_xz, LUT,
    R, extent, stride_zr, stride_hr,  # relative attention
    scale, is_causal,
    ROW_SIZE: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,
    IS_DENSE: tl.constexpr,
):
    h = tl.program_id(0)
    m = tl.program_id(1)
    z = tl.program_id(2)
    # create index ranges: lane_n is the column within a block, block_n the
    # block index along the row
    hm = h * tl.num_programs(1) + m
    lane_n = tl.arange(0, ROW_SIZE) % BLOCK_SIZE
    block_n = tl.arange(0, ROW_SIZE) // BLOCK_SIZE
    # extract information from LUT: per-row (size, offset) header pairs
    header = LUT + (hm // BLOCK_SIZE) * 2
    size = tl.load(header + 0)
    offset = tl.load(header + 1)
    # pointer offset into the packed block storage of A
    off_a = z * stride_xz
    off_a += (offset + block_n) * BLOCK_SIZE * BLOCK_SIZE  # block indx
    off_a += (m % BLOCK_SIZE) * BLOCK_SIZE  # row indx
    # do not need to read column indices in the dense case
    if IS_DENSE:
        ns = tl.arange(0, ROW_SIZE)
    else:
        off_lut = offset + 2 * tl.num_programs(0) * tl.num_programs(1) // BLOCK_SIZE
        start_n = tl.load(LUT + off_lut + block_n, mask=block_n < size, other=0)
        ns = start_n * BLOCK_SIZE + lane_n
    # load X; out-of-row lanes read -inf so they vanish in the softmax
    mask = block_n < size
    a = tl.load(A + off_a + lane_n, mask=mask, other=-float("inf"))
    a = a.to(tl.float32)
    # compute scaled logits
    out = a
    out *= scale
    # apply relative attention (skipped entirely when R is None)
    if R is not None:
        R += z * stride_zr
        R += h * stride_hr
        off_lo = (extent - m - 1) + ns
        mask_lo = (off_lo >= 0) & (off_lo < extent)
        rel_logits = tl.load(R + m * extent + off_lo, mask=mask_lo, other=0.0)
        out += rel_logits
    out = out.to(tl.float32)
    # apply causal mask: columns past the current row become -inf
    out = tl.where((ns > m) & is_causal, -float("inf"), out)
    # computation
    out = tl.softmax(out)
    # write-back in the same packed layout as the input
    tl.store(Out + off_a + lane_n, out, mask=mask)
|
||||
|
||||
|
||||
# Backward block-sparse softmax kernel: given the forward output and the
# incoming gradient, computes dA = a * (dout - sum(a * dout)) per row, and
# optionally scatters gradients into the relative-attention table DR.
@triton.jit
def _blocksparse_softmax_bwd(
    DA, stride_zdx,
    DOut, stride_zdout,
    Out, stride_zout,
    scale,
    LUT,
    DR, extent, stride_zr, stride_hr, stride_er,
    is_causal,
    ROW_SIZE: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,
    IS_DENSE: tl.constexpr,
):
    h = tl.program_id(0)
    m = tl.program_id(1)
    z = tl.program_id(2)
    # create index ranges (same decomposition as the forward kernel)
    hm = h * tl.num_programs(1) + m
    lane_n = tl.arange(0, ROW_SIZE) % BLOCK_SIZE
    block_n = tl.arange(0, ROW_SIZE) // BLOCK_SIZE
    # extract information from LUT
    header = LUT + (hm // BLOCK_SIZE) * 2
    size = tl.load(header + 0)
    offset = tl.load(header + 1)
    # row-col offset into the packed block storage
    off_mn = (offset + block_n) * BLOCK_SIZE * BLOCK_SIZE
    off_mn += (m % BLOCK_SIZE) * BLOCK_SIZE
    mask = block_n < size
    # pointers
    As = Out + z * stride_zout + off_mn
    DOuts = DOut + z * stride_zdout + off_mn
    # do not need to read column indices in the dense case
    if IS_DENSE:
        ns = tl.arange(0, ROW_SIZE)
    else:
        off_lut = offset + 2 * tl.num_programs(0) * tl.num_programs(1) // BLOCK_SIZE
        start_n = tl.load(LUT + off_lut + block_n, mask=mask, other=0)
        ns = start_n * BLOCK_SIZE + lane_n
    # load data
    a = tl.load(As + lane_n, mask=mask, other=0.0)
    a = a.to(tl.float32)
    dout = tl.load(DOuts + lane_n, mask=mask, other=0.0)
    dout = dout.to(tl.float32)
    # compute; the (a == a) term also zeroes NaN entries in the causal region
    a = tl.where((ns > m) & is_causal & (a == a), 0., a)
    da = a * (dout - tl.sum(a * dout, 0))
    # apply relative attention: store the unscaled gradient into DR before
    # rescaling da for the logits
    if DR is not None:
        DR += z * stride_zr
        DR += h * stride_hr
        off_lo = (extent - m - 1) + ns
        mask_lo = (off_lo >= 0) & (off_lo < extent) & mask
        tl.store(DR + m * extent + off_lo, da, mask=mask_lo)
    da = da * scale
    # convert da
    # write-back
    DAs = DA + z * stride_zdx + off_mn
    tl.store(DAs + lane_n, da, mask=mask)
|
||||
|
||||
|
||||
class _softmax(torch.autograd.Function):
    """Autograd wrapper around the block-sparse softmax kernels."""

    @staticmethod
    def make_lut(layout, block, device):
        """Build the lookup table for a sparsity `layout`.

        Returns (lut, max_row_width) where lut packs per-row (size, offset)
        header pairs followed by the non-zero column indices, and
        max_row_width is the widest row in elements (blocks * block).
        """
        _empty = torch.tensor([], dtype=torch.int64, device=layout.device)
        sizes = _empty.clone()
        # sizes along rows
        for h in range(layout.shape[0]):
            sizes = torch.cat((sizes, layout[h, :, :].sum(-1)))
        total_sizes = sizes * block
        # offsets in block format (exclusive cumulative sum of row sizes)
        offsets = torch.zeros_like(sizes)
        offsets[1:] = torch.cumsum(sizes[:-1], dim=0)
        # block indices
        columns = layout.nonzero(as_tuple=False)[:, 2]
        header = torch.stack((sizes, offsets), dim=1).view(-1)
        lut = torch.cat((header, columns)).type(torch.int32).to(device)
        return lut, int(total_sizes.max())

    @staticmethod
    def forward(
        ctx, a, scale, rel_logits, is_causal,
        spdims, block, lut, maxlut, is_dense
    ):
        # a CPU scalar tensor scale is unwrapped to a python float before
        # being passed into the kernel
        if scale is not None and isinstance(scale, torch.Tensor):
            assert scale.device.type == "cpu"
            scale = scale.item()
        M = a.shape[0]
        # one program per (head, row, batch)
        grid = [spdims[0], spdims[1] * block, M]
        rel_shape = (1, 1, 1, 1) if rel_logits is None else rel_logits.shape
        rel_strides = (1, 1, 1, 1) if rel_logits is None else rel_logits.stride()
        # enqueue kernel
        out = torch.empty_like(a)
        _blocksparse_softmax_fwd[grid](
            out, a, a.stride(0), lut,
            rel_logits, rel_shape[-1], rel_strides[0], rel_strides[1],  # relative attn
            scale,
            is_causal,
            BLOCK_SIZE=block,
            ROW_SIZE=triton.next_power_of_2(maxlut),
            IS_DENSE=is_dense,
            num_warps=num_warps(maxlut)
        )
        # save to context for the backward pass
        ctx.save_for_backward(out, lut)
        ctx.spdims = spdims
        ctx.block = block
        ctx.maxlut = maxlut
        ctx.scale = scale
        ctx.rel_shape = rel_shape
        ctx.rel_strides = rel_strides
        ctx.rel_dtype = a.dtype
        ctx.is_dense = is_dense
        ctx.is_causal = is_causal
        return out

    @staticmethod
    def backward(ctx, dout):
        # retrieve from context
        out, lut = ctx.saved_tensors
        # relative logits gradients
        # NOTE(review): needs_input_grad[3] is is_causal by position
        # (a=0, scale=1, rel_logits=2) — looks like this index and the slot
        # of dr in the return tuple were kept from an older, longer forward
        # signature; confirm against upstream before relying on dr.
        dr = None
        if ctx.needs_input_grad[3]:
            dr = torch.zeros(ctx.rel_shape, dtype=ctx.rel_dtype, device=out.device)
        # run kernel
        M = out.shape[0]
        grid = (ctx.spdims[0], ctx.spdims[1] * ctx.block, M)
        da = torch.empty_like(dout)
        _blocksparse_softmax_bwd[grid](
            da, da.stride(0),
            dout, dout.stride(0),
            out, out.stride(0),
            ctx.scale,
            lut,
            dr, ctx.rel_shape[-1], ctx.rel_strides[0], ctx.rel_strides[1], ctx.rel_strides[2],
            ctx.is_causal,
            BLOCK_SIZE=ctx.block,
            ROW_SIZE=triton.next_power_of_2(ctx.maxlut),
            IS_DENSE=ctx.is_dense,
            num_warps=num_warps(ctx.maxlut)
        )
        # trailing Nones cover the remaining (non-differentiable) inputs
        return (da, None, None, dr, None,
                None, None, None, None, None,
                None,
                None, None, None,
                None,
                None, None, None
                )
|
||||
|
||||
|
||||
class softmax:
    """User-facing block-sparse softmax.

    Precomputes the lookup table for a fixed sparsity `layout` and `block`
    size once at construction, then applies the fused kernel on each call.
    """

    def __init__(self, layout, block, device, is_dense=False):
        self.spdims = layout.shape
        self.layout = layout
        self.block = block
        self.lut, self.maxlut = _softmax.make_lut(self.layout, self.block, device)
        self.is_dense = is_dense

    def __call__(self, a, *, scale=1.0, rel_logits=None, is_causal=False):
        """Softmax over the packed block-sparse tensor `a`.

        `rel_logits`, when given, must match `a`'s dtype (the kernel adds it
        to the scaled logits before normalizing).
        """
        if rel_logits is not None and rel_logits.dtype != a.dtype:
            raise ValueError(f"relative position embedding must be {a.dtype}")
        a = _softmax.apply(
            a, scale, rel_logits, is_causal,
            self.spdims, self.block, self.lut, self.maxlut, self.is_dense,
        )
        return a
|
||||
==== new file: pkgs/triton/ops/bmm_matmul.py (163 lines) ====
|
||||
import torch
|
||||
|
||||
import triton
|
||||
import triton.language as tl
|
||||
from .matmul_perf_model import early_config_prune, estimate_matmul_time
|
||||
|
||||
|
||||
def init_to_zero(name):
    """Return an autotune pre-hook that zero-fills the kernel argument *name* in place."""
    def _zero_arg(nargs):
        return nargs[name].zero_()
    return _zero_arg
|
||||
|
||||
def get_configs_io_bound():
    """Enumerate IO-bound autotune configs for the batched matmul kernel.

    Block-16 tiles are only tried off HIP/corex hardware (MFMA dot does not
    support them there); SPLIT_K is fixed at 1.
    """
    # TODO support block size 16 for MFMA dot op
    block_m_choices = [16, 32] if torch.version.hip is None and not hasattr(torch, "corex") else [32, 64]
    return [
        triton.Config(
            {'BLOCK_M': bm, 'BLOCK_N': bn, 'BLOCK_K': bk, 'SPLIT_K': 1},
            num_stages=stages,
            num_warps=4 if bn <= 64 else 8,
        )
        for stages in [1]
        for bm in block_m_choices
        for bk in [32, 64]
        for bn in [32, 64, 128, 256]
    ]
|
||||
|
||||
def get_configs_compute_bound():
    """Enumerate compute-bound autotune configs for the batched matmul kernel."""
    return [
        triton.Config(
            {'BLOCK_M': bm, 'BLOCK_N': bn, 'BLOCK_K': bk, 'SPLIT_K': 1},
            num_stages=1,
            num_warps=8 if bn <= 64 else 16,
        )
        for bm in [64, 128, 256]
        for bn in [64, 128, 256]
        for bk in [32, 64, 128]
    ]
|
||||
|
||||
|
||||
# Batched matmul kernel: grid axis 0 enumerates (M, N) tiles, axis 1 the
# batch dimension. Autotuned over the IO/compute-bound config spaces and
# pruned by the shared matmul performance model.
@triton.autotune(
    configs=[
    ] + get_configs_compute_bound() + get_configs_io_bound(),
    key=['M', 'N', 'K'],
    prune_configs_by={
        'early_config_prune': early_config_prune,
        'perf_model': estimate_matmul_time,
        'top_k': 10
    },
)
@triton.heuristics({
    'EVEN_K': lambda args: args['K'] % args['BLOCK_K'] == 0,
})
@triton.jit
def _bmm_kernel(A, B, C, M, N, K,
                stride_aq, stride_am, stride_ak,
                stride_bq, stride_bk, stride_bn,
                stride_cq, stride_cm, stride_cn,
                dot_out_dtype: tl.constexpr,
                BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
                GROUP_M: tl.constexpr, SPLIT_K: tl.constexpr, EVEN_K: tl.constexpr,
                ):
    pid = tl.program_id(0)
    grid_m = (M + BLOCK_M - 1) // BLOCK_M
    grid_n = (N + BLOCK_N - 1) // BLOCK_N

    # re-order program ID for better L2 performance (grouped tile ordering)
    width = GROUP_M * grid_n
    group_id = pid // width
    group_size = min(grid_m - group_id * GROUP_M, GROUP_M)
    pid_m = group_id * GROUP_M + (pid % group_size)
    pid_n = (pid % width) // (group_size)

    rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
    # alignment hints: wrapped indices are contiguous multiples of the tile size
    ram = tl.max_contiguous(tl.multiple_of(rm % M, BLOCK_M), BLOCK_M)
    rbn = tl.max_contiguous(tl.multiple_of(rn % N, BLOCK_N), BLOCK_N)
    rk = tl.arange(0, BLOCK_K)

    idx_q = tl.program_id(1)  # batch dimension for BMM
    A = A + (ram[:, None] * stride_am + rk[None, :] * stride_ak + idx_q*stride_aq)
    B = B + (rk[:, None] * stride_bk + rbn[None, :] * stride_bn + idx_q*stride_bq)

    acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=dot_out_dtype)
    # main K loop; EVEN_K skips the tail masking when K divides BLOCK_K
    for k in range(K, 0, -BLOCK_K):
        if EVEN_K:
            a = tl.load(A)
            b = tl.load(B)
        else:
            a = tl.load(A, mask=rk[None, :] < k, other=0.)
            b = tl.load(B, mask=rk[:, None] < k, other=0.)
        acc += tl.dot(a, b)
        A += BLOCK_K * stride_ak
        B += BLOCK_K * stride_bk

    # rematerialize rm and rn to save registers
    rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
    idx_q = tl.program_id(1)  # batch dimension for BMM
    idx_m = rm[:, None]
    idx_n = rn[None, :]
    C = C + (rm[:, None] * stride_cm + rn[None, :] * stride_cn + idx_q * stride_cq)
    mask = (idx_m < M) & (idx_n < N)
    # handles write-back with reduction-splitting
    tl.store(C, acc, mask=mask)
|
||||
|
||||
class _bmm(torch.autograd.Function):
    """Forward-only autograd wrapper around the batched matmul kernel.

    No backward is defined; use only in inference paths.
    """

    kernel = _bmm_kernel

    _locks = {}

    @staticmethod
    def _call(a, b, dot_out_dtype):
        """Run c = a @ b for 3-D tensors a:(B, M, K), b:(B, K, N) -> (B, M, N)."""
        device = a.device
        # handle non-contiguous inputs if necessary
        # NOTE(review): these stride checks look inherited from a 2-D matmul
        # (they test the batch and row strides of a 3-D tensor) — confirm
        # intent before changing.
        if a.stride(0) > 1 and a.stride(1) > 1:
            a = a.contiguous()
        if b.stride(0) > 1 and b.stride(1) > 1:
            b = b.contiguous()

        # only MR support Trans layout: older corex parts (< capability 71)
        # additionally force contiguous inputs
        if hasattr(torch, "corex"):
            capability = torch.cuda.get_device_capability(device)
            capability = capability[0] * 10 + capability[1]
            if (capability < 71):
                if a.stride(0) >= 1 and a.stride(1) > 1:
                    a = a.contiguous()
                if b.stride(0) >= 1 and b.stride(1) > 1:
                    b = b.contiguous()
        # checks constraints
        assert a.shape[0] == b.shape[0], "incompatible dimensions"
        assert a.shape[2] == b.shape[1], "incompatible dimensions"
        B, M, K = a.shape
        _, _, N = b.shape
        # allocates output
        c = torch.empty((B, M, N), device=device, dtype=a.dtype)
        # pick the accumulator dtype: fp32 for float inputs, int32 otherwise,
        # unless the caller specified one explicitly
        if dot_out_dtype is None:
            if a.dtype in [torch.float16, torch.float32, torch.bfloat16]:
                dot_out_dtype = tl.float32
            else:
                dot_out_dtype = tl.int32
        else:
            assert isinstance(dot_out_dtype, torch.dtype), "dot_out_dtype must be a torch.dtype"
            if dot_out_dtype == torch.float16:
                dot_out_dtype = tl.float16
            elif dot_out_dtype in [torch.float32, torch.bfloat16]:
                dot_out_dtype = tl.float32
            else:
                dot_out_dtype = tl.int32
        # launch kernel: tiles on axis 0, batch on axis 1
        grid = lambda META: (triton.cdiv(M, META['BLOCK_M']) * triton.cdiv(N, META['BLOCK_N']), B, 1)
        _bmm_kernel[grid](a, b, c, M, N, K,
                          a.stride(0), a.stride(1), a.stride(2),
                          b.stride(0), b.stride(1), b.stride(2),
                          c.stride(0), c.stride(1), c.stride(2),
                          dot_out_dtype=dot_out_dtype,
                          GROUP_M=8)
        return c

    @staticmethod
    def forward(ctx, a, b, dot_out_dtype=None):
        return _bmm._call(a, b, dot_out_dtype=dot_out_dtype)


# public entry point
bmm = _bmm.apply
|
||||
==== new file: pkgs/triton/ops/cross_entropy.py (94 lines) ====
|
||||
import torch
|
||||
|
||||
import triton
|
||||
import triton.language as tl
|
||||
|
||||
|
||||
def num_warps(N):
    """Pick the launch warp count from the logits row width *N*."""
    if N >= 8192:
        return 16
    return 8 if N >= 2048 else 4
|
||||
|
||||
|
||||
# Forward cross-entropy kernel: one program per row; writes the full vector
# of negative log-probs into PROBS and the loss for the target index into LOSS.
@triton.heuristics({'num_warps': lambda nargs: num_warps(nargs['N'])})
@triton.heuristics({'BLOCK': lambda nargs: triton.next_power_of_2(nargs['N'])})
@triton.jit
def _forward(LOGITS, PROBS, IDX, LOSS, N, BLOCK: tl.constexpr):
    row = tl.program_id(0)
    cols = tl.arange(0, BLOCK)
    idx = tl.load(IDX + row)
    # pointers to logit and probs
    LOGITS = LOGITS + row * N + cols
    WRIT_PROBS = PROBS + row * N + cols
    READ_PROBS = PROBS + row * N + idx
    # write-back negative log-probs: -log softmax via the stabilized
    # log-sum-exp (max is subtracted first)
    logits = tl.load(LOGITS, mask=cols < N, other=-float('inf'))
    logits = logits.to(tl.float32)
    logits = logits - tl.max(logits, 0)
    probs = tl.log(tl.sum(tl.exp(logits), 0)) - logits
    tl.store(WRIT_PROBS, probs, mask=cols < N)
    # There is a bug in the compiler, which fails to insert a barrier here.
    # We add it explicitly for now. Will be fixed soon.
    tl.debug_barrier()
    # write-back loss: the negative log-prob at the target index
    probs = tl.load(READ_PROBS)
    tl.store(LOSS + row, probs)
|
||||
|
||||
|
||||
# Backward cross-entropy kernel: overwrites PROBS in place with the gradient
# (softmax(logits) - one_hot(idx)) * dout for each row.
@triton.heuristics({'num_warps': lambda nargs: num_warps(nargs['N'])})
@triton.heuristics({'BLOCK': lambda nargs: triton.next_power_of_2(nargs['N'])})
@triton.jit
def _backward(PROBS, IDX, DPROBS, N, BLOCK: tl.constexpr):
    row = tl.program_id(0)
    cols = tl.arange(0, BLOCK)
    idx = tl.load(IDX + row)
    # pointers to probs
    PROBS = PROBS + row * N + cols
    # We know d(-log(p[i])/dlogit[k] = -id_mat[i,k] + p[k]
    # and we have -log(p[k]) stored in PROBS, so this is easy
    probs = -tl.load(PROBS, mask=cols < N, other=float('inf'))
    probs = tl.exp(probs.to(tl.float32))
    delta = cols == idx
    # write result in-place in PROBS
    dout = tl.load(DPROBS + row)
    din = (probs - delta) * dout
    tl.store(PROBS, din.to(PROBS.dtype.element_ty), mask=cols < N)
|
||||
|
||||
|
||||
class _cross_entropy(torch.autograd.Function):
    """Fused cross-entropy loss over the last dimension of `logits`.

    NOTE(review): forward/backward are declared as @classmethod with an
    explicit `ctx` parameter (unusual for autograd.Function but works,
    since `.apply` still supplies ctx as the first argument).
    """

    @classmethod
    def forward(cls, ctx, logits, indices):
        """Return per-row losses; `indices` selects the target class per row."""
        # make sure we can use triton
        assert (indices.dtype == torch.int64), "Indices are expected to be of type long."
        # make kernel
        device, dtype = logits.device, logits.dtype
        n_cols = logits.shape[-1]
        # run the kernel: one program per row of logits
        result = torch.empty_like(indices, dtype=dtype, device=device)
        neg_logprobs = torch.empty_like(logits, dtype=dtype, device=device)
        grid = lambda opt: (logits.numel() // n_cols, )
        _forward[grid](logits, neg_logprobs, indices, result, n_cols)
        # save for backward
        ctx.save_for_backward(neg_logprobs, indices)
        return result

    @classmethod
    def backward(cls, ctx, dneg_logprobs):
        """We know d(-log(p[i])/dlogit[k] = -id_mat[i,k] + p[k]
        so we initialize the gradient as neg_logprobs, so we can just exponentiate
        to get p[k], which is most of what we need... neg_logprobs will be
        modified in place to become the gradient we want
        """
        # load saved tensors
        neg_logprobs, indices = ctx.saved_tensors
        # run the kernel
        # neg_logprobs will be modified in place to become our gradient:
        n_cols = neg_logprobs.shape[-1]
        grid = lambda opt: (neg_logprobs.numel() // n_cols, )
        _backward[grid](neg_logprobs, indices, dneg_logprobs, n_cols)
        return neg_logprobs, None


# public entry point
cross_entropy = _cross_entropy.apply
|
||||
==== new file: pkgs/triton/ops/flash_attention.py (271 lines) ====
|
||||
"""
|
||||
Fused Attention
|
||||
===============
|
||||
This is a Triton implementation of the Flash Attention algorithm
|
||||
(see: Dao et al., https://arxiv.org/pdf/2205.14135v2.pdf; Rabe and Staats https://arxiv.org/pdf/2112.05682v2.pdf)
|
||||
"""
|
||||
|
||||
import torch
|
||||
|
||||
import triton
|
||||
import triton.language as tl
|
||||
from triton.common.build import is_corex
|
||||
|
||||
|
||||
# Flash-attention forward kernel: each program handles one BLOCK_M slab of
# query rows for one (batch, head) pair, streaming over K/V blocks with the
# online-softmax recurrence.
# NOTE(review): off_k reuses stride_qh and off_v reuses stride_qm/stride_qk —
# this assumes q, k and v share an identical layout; confirm before reusing
# with differently-strided inputs.
@triton.jit
def _fwd_kernel(
    Q, K, V, sm_scale,
    L, M,
    Out,
    stride_qz, stride_qh, stride_qm, stride_qk,
    stride_kz, stride_kh, stride_kn, stride_kk,
    stride_vz, stride_vh, stride_vk, stride_vn,
    stride_oz, stride_oh, stride_om, stride_on,
    Z, H, N_CTX,
    BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr,
    BLOCK_N: tl.constexpr,
):
    start_m = tl.program_id(0)
    off_hz = tl.program_id(1)
    # initialize offsets
    offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
    offs_n = tl.arange(0, BLOCK_N)
    offs_d = tl.arange(0, BLOCK_DMODEL)
    off_q = off_hz * stride_qh + offs_m[:, None] * stride_qm + offs_d[None, :] * stride_qk
    off_k = off_hz * stride_qh + offs_n[None, :] * stride_kn + offs_d[:, None] * stride_kk
    off_v = off_hz * stride_qh + offs_n[:, None] * stride_qm + offs_d[None, :] * stride_qk
    # Initialize pointers to Q, K, V
    q_ptrs = Q + off_q
    k_ptrs = K + off_k
    v_ptrs = V + off_v
    # initialize pointer to m and l (running max / running denominator)
    m_prev = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
    l_prev = tl.zeros([BLOCK_M], dtype=tl.float32)
    acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)
    # load q: it will stay in SRAM throughout
    q = tl.load(q_ptrs)
    # loop over k, v and update accumulator; the causal bound limits the scan
    # to blocks at or before the current query slab
    for start_n in range(0, (start_m + 1) * BLOCK_M, BLOCK_N):
        # -- compute qk ----
        k = tl.load(k_ptrs)
        qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
        qk += tl.dot(q, k)
        qk *= sm_scale
        qk = tl.where(offs_m[:, None] >= (start_n + offs_n[None, :]), qk, float("-inf"))
        # compute new m
        m_curr = tl.maximum(tl.max(qk, 1), m_prev)
        # correct old l
        l_prev *= tl.exp(m_prev - m_curr)
        # attention weights
        p = tl.exp(qk - m_curr[:, None])
        l_curr = tl.sum(p, 1) + l_prev
        # rescale operands of matmuls so acc stays normalized each step
        l_rcp = 1. / l_curr
        p *= l_rcp[:, None]
        acc *= (l_prev * l_rcp)[:, None]
        # update acc
        p = p.to(Q.dtype.element_ty)
        v = tl.load(v_ptrs)
        acc += tl.dot(p, v)
        # update m_i and l_i
        l_prev = l_curr
        m_prev = m_curr
        # update pointers
        k_ptrs += BLOCK_N * stride_kn
        v_ptrs += BLOCK_N * stride_vk
    # rematerialize offsets to save registers
    start_m = tl.program_id(0)
    offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
    # write back l and m (consumed by the backward pass)
    l_ptrs = L + off_hz * N_CTX + offs_m
    m_ptrs = M + off_hz * N_CTX + offs_m
    tl.store(l_ptrs, l_prev)
    tl.store(m_ptrs, m_prev)
    # initialize pointers to output
    offs_n = tl.arange(0, BLOCK_DMODEL)
    off_o = off_hz * stride_oh + offs_m[:, None] * stride_om + offs_n[None, :] * stride_on
    out_ptrs = Out + off_o
    tl.store(out_ptrs, acc)
|
||||
|
||||
|
||||
# Backward preprocessing: renormalize dO by the softmax denominator l and
# precompute delta = sum(o * do) per row, both consumed by _bwd_kernel.
@triton.jit
def _bwd_preprocess(
    Out, DO, L,
    NewDO, Delta,
    BLOCK_M: tl.constexpr, D_HEAD: tl.constexpr,
):
    off_m = tl.program_id(0) * BLOCK_M + tl.arange(0, BLOCK_M)
    off_n = tl.arange(0, D_HEAD)
    # load
    o = tl.load(Out + off_m[:, None] * D_HEAD + off_n[None, :]).to(tl.float32)
    do = tl.load(DO + off_m[:, None] * D_HEAD + off_n[None, :]).to(tl.float32)
    denom = tl.load(L + off_m).to(tl.float32)
    # compute
    do = do / denom[:, None]
    delta = tl.sum(o * do, axis=1)
    # write-back
    tl.store(NewDO + off_m[:, None] * D_HEAD + off_n[None, :], do)
    tl.store(Delta + off_m, delta)
|
||||
|
||||
|
||||
# Flash-attention backward kernel: one program per (batch, head); the outer
# loop fixes a K/V column block, the inner loop streams query rows from the
# causal frontier onwards, accumulating dK/dV on-chip and updating dQ in
# global memory.
# NOTE(review): K, V, DO, DQ, DK, DV offsets all reuse Q's strides
# (stride_qz/stride_qh, and v/dv use stride_qm/stride_qk) — assumes all
# tensors share one layout; confirm before generalizing.
@triton.jit
def _bwd_kernel(
    Q, K, V, sm_scale, Out, DO,
    DQ, DK, DV,
    L, M,
    D,
    stride_qz, stride_qh, stride_qm, stride_qk,
    stride_kz, stride_kh, stride_kn, stride_kk,
    stride_vz, stride_vh, stride_vk, stride_vn,
    Z, H, N_CTX,
    num_block,
    BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr,
    BLOCK_N: tl.constexpr,
):
    off_hz = tl.program_id(0)
    off_z = off_hz // H
    off_h = off_hz % H
    # offset pointers for batch/head
    Q += off_z * stride_qz + off_h * stride_qh
    K += off_z * stride_qz + off_h * stride_qh
    V += off_z * stride_qz + off_h * stride_qh
    DO += off_z * stride_qz + off_h * stride_qh
    DQ += off_z * stride_qz + off_h * stride_qh
    DK += off_z * stride_qz + off_h * stride_qh
    DV += off_z * stride_qz + off_h * stride_qh
    for start_n in range(0, num_block):
        # causal: rows before this column block contribute nothing
        lo = start_n * BLOCK_M
        # initialize row/col offsets
        offs_qm = lo + tl.arange(0, BLOCK_M)
        offs_n = start_n * BLOCK_M + tl.arange(0, BLOCK_M)
        offs_m = tl.arange(0, BLOCK_N)
        offs_k = tl.arange(0, BLOCK_DMODEL)
        # initialize pointers to value-like data
        q_ptrs = Q + (offs_qm[:, None] * stride_qm + offs_k[None, :] * stride_qk)
        k_ptrs = K + (offs_n[:, None] * stride_kn + offs_k[None, :] * stride_kk)
        v_ptrs = V + (offs_n[:, None] * stride_qm + offs_k[None, :] * stride_qk)
        do_ptrs = DO + (offs_qm[:, None] * stride_qm + offs_k[None, :] * stride_qk)
        dq_ptrs = DQ + (offs_qm[:, None] * stride_qm + offs_k[None, :] * stride_qk)
        # pointer to row-wise quantities in value-like data
        D_ptrs = D + off_hz * N_CTX
        m_ptrs = M + off_hz * N_CTX
        # initialize dv amd dk
        dv = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)
        dk = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)
        # k and v stay in SRAM throughout
        k = tl.load(k_ptrs)
        v = tl.load(v_ptrs)
        # loop over rows
        for start_m in range(lo, num_block * BLOCK_M, BLOCK_M):
            offs_m_curr = start_m + offs_m
            # load q, k, v, do on-chip
            q = tl.load(q_ptrs)
            # recompute p = softmax(qk, dim=-1).T
            # NOTE: `do` is pre-divided by `l`; no normalization here
            qk = tl.dot(q, tl.trans(k))
            qk = tl.where(offs_m_curr[:, None] >= (offs_n[None, :]), qk, float("-inf"))
            m = tl.load(m_ptrs + offs_m_curr)
            p = tl.exp(qk * sm_scale - m[:, None])
            # compute dv
            do = tl.load(do_ptrs)
            dv += tl.dot(tl.trans(p.to(Q.dtype.element_ty)), do)
            # compute dp = dot(v, do)
            Di = tl.load(D_ptrs + offs_m_curr)
            dp = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) - Di[:, None]
            dp += tl.dot(do, tl.trans(v))
            # compute ds = p * (dp - delta[:, None])
            ds = p * dp * sm_scale
            # compute dk = dot(ds.T, q)
            dk += tl.dot(tl.trans(ds.to(Q.dtype.element_ty)), q)
            # compute dq (read-modify-write; safe because each program owns
            # its whole (batch, head) slice)
            dq = tl.load(dq_ptrs)
            dq += tl.dot(ds.to(Q.dtype.element_ty), k)
            tl.store(dq_ptrs, dq)
            # increment pointers
            dq_ptrs += BLOCK_M * stride_qm
            q_ptrs += BLOCK_M * stride_qm
            do_ptrs += BLOCK_M * stride_qm
        # write-back accumulated dv/dk for this column block
        dv_ptrs = DV + (offs_n[:, None] * stride_qm + offs_k[None, :] * stride_qk)
        dk_ptrs = DK + (offs_n[:, None] * stride_kn + offs_k[None, :] * stride_kk)
        tl.store(dv_ptrs, dv)
        tl.store(dk_ptrs, dk)
|
||||
|
||||
|
||||
class _attention(torch.autograd.Function):
    """Autograd wrapper around the causal flash-attention kernels.

    Inputs are (Z, H, N_CTX, D) tensors; D (head dim) must be one of
    {16, 32, 64, 128} and equal across q, k, v.
    """

    @staticmethod
    def forward(ctx, q, k, v, sm_scale):
        # only support for Ampere now
        capability = torch.cuda.get_device_capability()
        if not is_corex():
            if capability[0] < 8:
                raise RuntimeError("Flash attention currently only supported for compute capability >= 80")
            BLOCK = 128
        else:
            BLOCK = 64  # FIXME: currently BLOCK=128 has issues, BLOCK=64 works for common cases.
        # shape constraints
        Lq, Lk, Lv = q.shape[-1], k.shape[-1], v.shape[-1]
        assert Lq == Lk and Lk == Lv
        assert Lk in {16, 32, 64, 128}
        o = torch.empty_like(q)
        # one program per (query tile, batch*head)
        grid = (triton.cdiv(q.shape[2], BLOCK), q.shape[0] * q.shape[1], 1)
        # per-row softmax statistics, saved for the backward pass
        L = torch.empty((q.shape[0] * q.shape[1], q.shape[2]), device=q.device, dtype=torch.float32)
        m = torch.empty((q.shape[0] * q.shape[1], q.shape[2]), device=q.device, dtype=torch.float32)
        num_warps = 4

        _fwd_kernel[grid](
            q, k, v, sm_scale,
            L, m,
            o,
            q.stride(0), q.stride(1), q.stride(2), q.stride(3),
            k.stride(0), k.stride(1), k.stride(2), k.stride(3),
            v.stride(0), v.stride(1), v.stride(2), v.stride(3),
            o.stride(0), o.stride(1), o.stride(2), o.stride(3),
            q.shape[0], q.shape[1], q.shape[2],
            BLOCK_M=BLOCK, BLOCK_N=BLOCK,
            BLOCK_DMODEL=Lk, num_warps=num_warps,
            num_stages=2 if not is_corex() else 1,
        )

        ctx.save_for_backward(q, k, v, o, L, m)
        ctx.grid = grid
        ctx.sm_scale = sm_scale
        ctx.BLOCK_DMODEL = Lk
        return o

    @staticmethod
    def backward(ctx, do):
        BLOCK = 128 if not is_corex() else 64  # FIXME: currently BLOCK=128 has issues, BLOCK=64 works for common cases.
        num_warps = 16 if is_corex() and ctx.BLOCK_DMODEL > 64 else 8
        q, k, v, o, l, m = ctx.saved_tensors
        do = do.contiguous()
        # dq is accumulated with read-modify-write in the kernel, so it must
        # start at zero; dk/dv are fully overwritten
        dq = torch.zeros_like(q, dtype=torch.float32)
        dk = torch.empty_like(k)
        dv = torch.empty_like(v)
        do_scaled = torch.empty_like(do)
        delta = torch.empty_like(l)
        _bwd_preprocess[(ctx.grid[0] * ctx.grid[1], )](
            o, do, l,
            do_scaled, delta,
            BLOCK_M=BLOCK, D_HEAD=ctx.BLOCK_DMODEL,
        )
        _bwd_kernel[(ctx.grid[1],)](
            q, k, v, ctx.sm_scale,
            o, do_scaled,
            dq, dk, dv,
            l, m,
            delta,
            q.stride(0), q.stride(1), q.stride(2), q.stride(3),
            k.stride(0), k.stride(1), k.stride(2), k.stride(3),
            v.stride(0), v.stride(1), v.stride(2), v.stride(3),
            q.shape[0], q.shape[1], q.shape[2],
            ctx.grid[0],
            BLOCK_M=BLOCK, BLOCK_N=BLOCK,
            BLOCK_DMODEL=ctx.BLOCK_DMODEL, num_warps=num_warps,
            num_stages=1,
        )
        return dq, dk, dv, None


# public entry point
attention = _attention.apply
|
||||
==== new file: pkgs/triton/ops/matmul.py (184 lines) ====
|
||||
import torch
|
||||
|
||||
import triton
|
||||
import triton.language as tl
|
||||
from .matmul_perf_model import early_config_prune, estimate_matmul_time
|
||||
|
||||
|
||||
def init_to_zero(name):
|
||||
return lambda nargs: nargs[name].zero_()
|
||||
|
||||
def get_configs_io_bound():
|
||||
configs = []
|
||||
if hasattr(torch, "corex"):
|
||||
return configs
|
||||
for num_stages in [1, 2]:
|
||||
# TODO support block size 16 for MFMA dot op
|
||||
for block_m in [16, 32] if torch.version.hip is None and not hasattr(torch, "corex") else [32, 64]:
|
||||
for block_k in [32, 64]:
|
||||
for block_n in [32, 64, 128, 256]:
|
||||
num_warps = 2 if block_n <= 64 else 4
|
||||
configs.append(
|
||||
triton.Config({'BLOCK_M': block_m, 'BLOCK_N': block_n, 'BLOCK_K': block_k, 'SPLIT_K': 1},
|
||||
num_stages=num_stages, num_warps=num_warps))
|
||||
# split_k
|
||||
for split_k in [2, 4, 8, 16]:
|
||||
configs.append(triton.Config({'BLOCK_M': block_m, 'BLOCK_N': block_n, 'BLOCK_K': block_k, 'SPLIT_K': split_k},
|
||||
num_stages=num_stages, num_warps=num_warps, pre_hook=init_to_zero('C')))
|
||||
return configs
|
||||
|
||||
def get_configs_compute_bound():
|
||||
configs = []
|
||||
if hasattr(torch, "corex"):
|
||||
for block_m in [32, 64, 128, 256]:
|
||||
for block_n in [32, 64, 128, 256]:
|
||||
for block_k in [32, 64, 128, 256]:
|
||||
for num_stages in [1, 2]:
|
||||
num_warps = 16 if block_m >= 128 or block_n >=128 or block_k >= 128 else 8
|
||||
configs.append(
|
||||
triton.Config({'BLOCK_M': block_m, 'BLOCK_N': block_n, 'BLOCK_K': block_k, 'SPLIT_K': 1},
|
||||
num_stages=num_stages, num_warps=num_warps))
|
||||
return configs
|
||||
|
||||
def get_nv_config():
|
||||
configs = []
|
||||
if hasattr(torch, "corex"):
|
||||
return configs
|
||||
configs = [# basic configs for compute-bound matmuls
|
||||
triton.Config({'BLOCK_M': 128, 'BLOCK_N': 256, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=3, num_warps=8),
|
||||
triton.Config({'BLOCK_M': 256, 'BLOCK_N': 128, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=3, num_warps=8),
|
||||
triton.Config({'BLOCK_M': 256, 'BLOCK_N': 64, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
|
||||
triton.Config({'BLOCK_M': 64, 'BLOCK_N': 256, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
|
||||
triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
|
||||
triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
|
||||
triton.Config({'BLOCK_M': 64, 'BLOCK_N': 128, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
|
||||
triton.Config({'BLOCK_M': 128, 'BLOCK_N': 32, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
|
||||
triton.Config({'BLOCK_M': 64, 'BLOCK_N': 32, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=5, num_warps=2),
|
||||
# good for int8
|
||||
triton.Config({'BLOCK_M': 128, 'BLOCK_N': 256, 'BLOCK_K': 128, 'SPLIT_K': 1}, num_stages=3, num_warps=8),
|
||||
triton.Config({'BLOCK_M': 256, 'BLOCK_N': 128, 'BLOCK_K': 128, 'SPLIT_K': 1}, num_stages=3, num_warps=8),
|
||||
triton.Config({'BLOCK_M': 256, 'BLOCK_N': 64, 'BLOCK_K': 128, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
|
||||
triton.Config({'BLOCK_M': 64, 'BLOCK_N': 256, 'BLOCK_K': 128, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
|
||||
triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'BLOCK_K': 128, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
|
||||
triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'BLOCK_K': 64, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
|
||||
triton.Config({'BLOCK_M': 64, 'BLOCK_N': 128, 'BLOCK_K': 64, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
|
||||
triton.Config({'BLOCK_M': 128, 'BLOCK_N': 32, 'BLOCK_K': 64, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
|
||||
triton.Config({'BLOCK_M': 64, 'BLOCK_N': 32, 'BLOCK_K': 64, 'SPLIT_K': 1}, num_stages=5, num_warps=2),
|
||||
]
|
||||
return configs
|
||||
|
||||
@triton.autotune(
|
||||
configs=get_nv_config() + get_configs_compute_bound() + get_configs_io_bound(),
|
||||
key=['M', 'N', 'K'],
|
||||
prune_configs_by={
|
||||
'early_config_prune': early_config_prune,
|
||||
'perf_model': estimate_matmul_time,
|
||||
'top_k': 10
|
||||
},
|
||||
)
|
||||
@triton.heuristics({
|
||||
'EVEN_K': lambda args: args['K'] % (args['BLOCK_K'] * args['SPLIT_K']) == 0,
|
||||
})
|
||||
@triton.jit
|
||||
def _kernel(A, B, C, M, N, K,
|
||||
stride_am, stride_ak,
|
||||
stride_bk, stride_bn,
|
||||
stride_cm, stride_cn,
|
||||
dot_out_dtype: tl.constexpr,
|
||||
BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
|
||||
GROUP_M: tl.constexpr, SPLIT_K: tl.constexpr, EVEN_K: tl.constexpr,
|
||||
):
|
||||
# matrix multiplication
|
||||
pid = tl.program_id(0)
|
||||
pid_z = tl.program_id(1)
|
||||
grid_m = tl.cdiv(M, BLOCK_M)
|
||||
grid_n = tl.cdiv(N, BLOCK_N)
|
||||
# re-order program ID for better L2 performance
|
||||
width = GROUP_M * grid_n
|
||||
group_id = pid // width
|
||||
group_size = min(grid_m - group_id * GROUP_M, GROUP_M)
|
||||
pid_m = group_id * GROUP_M + (pid % group_size)
|
||||
pid_n = (pid % width) // (group_size)
|
||||
# do matrix multiplication
|
||||
rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
|
||||
rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
|
||||
ram = tl.max_contiguous(tl.multiple_of(rm % M, BLOCK_M), BLOCK_M)
|
||||
rbn = tl.max_contiguous(tl.multiple_of(rn % N, BLOCK_N), BLOCK_N)
|
||||
rk = pid_z * BLOCK_K + tl.arange(0, BLOCK_K)
|
||||
# pointers
|
||||
A = A + (ram[:, None] * stride_am + rk[None, :] * stride_ak)
|
||||
B = B + (rk[:, None] * stride_bk + rbn[None, :] * stride_bn)
|
||||
acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=dot_out_dtype)
|
||||
for k in range(0, tl.cdiv(K, BLOCK_K * SPLIT_K)):
|
||||
if EVEN_K:
|
||||
a = tl.load(A)
|
||||
b = tl.load(B)
|
||||
else:
|
||||
k_remaining = K - k * (BLOCK_K * SPLIT_K)
|
||||
a = tl.load(A, mask=rk[None, :] < k_remaining, other=0.)
|
||||
b = tl.load(B, mask=rk[:, None] < k_remaining, other=0.)
|
||||
acc += tl.dot(a, b, out_dtype=dot_out_dtype)
|
||||
A += BLOCK_K * SPLIT_K * stride_ak
|
||||
B += BLOCK_K * SPLIT_K * stride_bk
|
||||
acc = acc.to(C.dtype.element_ty)
|
||||
# rematerialize rm and rn to save registers
|
||||
rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
|
||||
rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
|
||||
C = C + (rm[:, None] * stride_cm + rn[None, :] * stride_cn)
|
||||
mask = (rm < M)[:, None] & (rn < N)[None, :]
|
||||
# handles write-back with reduction-splitting
|
||||
if SPLIT_K == 1:
|
||||
tl.store(C, acc, mask=mask)
|
||||
else:
|
||||
tl.atomic_add(C, acc, mask=mask)
|
||||
|
||||
|
||||
class _matmul(torch.autograd.Function):
|
||||
kernel = _kernel
|
||||
|
||||
_locks = {}
|
||||
|
||||
@staticmethod
|
||||
def _call(a, b, dot_out_dtype):
|
||||
device = a.device
|
||||
# handle non-contiguous inputs if necessary
|
||||
if a.stride(0) > 1 and a.stride(1) > 1:
|
||||
a = a.contiguous()
|
||||
if b.stride(0) > 1 and b.stride(1) > 1:
|
||||
b = b.contiguous()
|
||||
|
||||
# checks constraints
|
||||
assert a.shape[1] == b.shape[0], "incompatible dimensions"
|
||||
M, K = a.shape
|
||||
_, N = b.shape
|
||||
# allocates output
|
||||
c = torch.empty((M, N), device=device, dtype=a.dtype)
|
||||
if dot_out_dtype is None:
|
||||
if a.dtype in [torch.float16, torch.float32, torch.bfloat16]:
|
||||
dot_out_dtype = tl.float32
|
||||
else:
|
||||
dot_out_dtype = tl.int32
|
||||
else:
|
||||
assert isinstance(dot_out_dtype, torch.dtype), "dot_out_dtype must be a torch.dtype"
|
||||
if dot_out_dtype == torch.float16:
|
||||
dot_out_dtype = tl.float16
|
||||
elif dot_out_dtype in [torch.float32, torch.bfloat16]:
|
||||
dot_out_dtype = tl.float32
|
||||
else:
|
||||
dot_out_dtype = tl.int32
|
||||
# launch kernel
|
||||
grid = lambda META: (triton.cdiv(M, META['BLOCK_M']) * triton.cdiv(N, META['BLOCK_N']), META['SPLIT_K'])
|
||||
_kernel[grid](a, b, c, M, N, K,
|
||||
a.stride(0), a.stride(1),
|
||||
b.stride(0), b.stride(1),
|
||||
c.stride(0), c.stride(1),
|
||||
dot_out_dtype=dot_out_dtype,
|
||||
GROUP_M=8)
|
||||
return c
|
||||
|
||||
@staticmethod
|
||||
def forward(ctx, a, b, dot_out_dtype=None):
|
||||
return _matmul._call(a, b, dot_out_dtype=dot_out_dtype)
|
||||
|
||||
|
||||
matmul = _matmul.apply
|
||||
164
pkgs/triton/ops/matmul_perf_model.py
Normal file
164
pkgs/triton/ops/matmul_perf_model.py
Normal file
@@ -0,0 +1,164 @@
|
||||
import heapq
|
||||
|
||||
import torch
|
||||
|
||||
import triton
|
||||
import triton._C.libtriton.triton as _triton
|
||||
from triton.runtime import driver
|
||||
from triton.testing import get_dram_gbps, get_max_simd_tflops, get_max_tensorcore_tflops
|
||||
|
||||
|
||||
def get_tensorcore_tflops(backend, device, num_ctas, num_warps, dtype):
|
||||
''' return compute throughput in TOPS '''
|
||||
total_warps = num_ctas * min(num_warps, 4)
|
||||
num_subcores = driver.utils.get_device_properties(device)["multiprocessor_count"] * 4 # on recent GPUs
|
||||
tflops = min(num_subcores, total_warps) / num_subcores * get_max_tensorcore_tflops(dtype, backend, device)
|
||||
return tflops
|
||||
|
||||
|
||||
def get_simd_tflops(backend, device, num_ctas, num_warps, dtype):
|
||||
''' return compute throughput in TOPS '''
|
||||
total_warps = num_ctas * min(num_warps, 4)
|
||||
num_subcores = driver.utils.get_device_properties(device)["multiprocessor_count"] * 4 # on recent GPUs
|
||||
tflops = min(num_subcores, total_warps) / num_subcores * get_max_simd_tflops(dtype, backend, device)
|
||||
return tflops
|
||||
|
||||
|
||||
def get_tflops(backend, device, num_ctas, num_warps, dtype):
|
||||
capability = torch.cuda.get_device_capability(device)
|
||||
if capability[0] < 8 and dtype == torch.float32:
|
||||
return get_simd_tflops(backend, device, num_ctas, num_warps, dtype)
|
||||
return get_tensorcore_tflops(backend, device, num_ctas, num_warps, dtype)
|
||||
|
||||
|
||||
def estimate_matmul_time(
|
||||
# backend, device,
|
||||
num_warps, num_stages,
|
||||
A, B, C,
|
||||
M, N, K,
|
||||
BLOCK_M, BLOCK_N, BLOCK_K, SPLIT_K,
|
||||
debug=False, **kwargs
|
||||
):
|
||||
''' return estimated running time in ms
|
||||
= max(compute, loading) + store '''
|
||||
backend = _triton.runtime.backend.CUDA
|
||||
device = torch.cuda.current_device()
|
||||
dtype = A.dtype
|
||||
dtsize = A.element_size()
|
||||
|
||||
num_cta_m = triton.cdiv(M, BLOCK_M)
|
||||
num_cta_n = triton.cdiv(N, BLOCK_N)
|
||||
num_cta_k = SPLIT_K
|
||||
num_ctas = num_cta_m * num_cta_n * num_cta_k
|
||||
|
||||
# If the input is smaller than the block size
|
||||
M, N = max(M, BLOCK_M), max(N, BLOCK_N)
|
||||
|
||||
# time to compute
|
||||
total_ops = 2 * M * N * K / (1024 * 1024 * 1024) # GOPS
|
||||
tput = get_tflops(backend, device, num_ctas, num_warps, dtype)
|
||||
compute_ms = total_ops / tput
|
||||
|
||||
# time to load data
|
||||
num_sm = driver.utils.get_device_properties(device)["multiprocessor_count"]
|
||||
active_cta_ratio = min(1, num_ctas / num_sm)
|
||||
active_cta_ratio_bw1 = min(1, num_ctas / 32) # 32 active ctas are enough to saturate
|
||||
active_cta_ratio_bw2 = max(min(1, (num_ctas - 32) / (108 - 32)), 0) # 32-108, remaining 5%
|
||||
dram_bw = get_dram_gbps(backend, device) * (active_cta_ratio_bw1 * 0.95 + active_cta_ratio_bw2 * 0.05) # in GB/s
|
||||
l2_bw = dram_bw * 4 # rough estimation (should be 4.7 for A100?)
|
||||
# assume 80% of (following) loads are in L2 cache
|
||||
load_a_dram = M * K * dtsize * (1 + 0.2 * (num_cta_n - 1))
|
||||
load_a_l2 = M * K * dtsize * 0.8 * (num_cta_n - 1)
|
||||
load_b_dram = N * K * dtsize * (1 + 0.2 * (num_cta_m - 1))
|
||||
load_b_l2 = N * K * dtsize * 0.8 * (num_cta_m - 1)
|
||||
# total
|
||||
total_dram = (load_a_dram + load_b_dram) / (1024 * 1024) # MB
|
||||
total_l2 = (load_a_l2 + load_b_l2) / (1024 * 1024)
|
||||
# loading time in ms
|
||||
load_ms = total_dram / dram_bw + total_l2 / l2_bw
|
||||
|
||||
# estimate storing time
|
||||
store_bw = dram_bw * 0.6 # :o
|
||||
store_c_dram = M * N * dtsize * SPLIT_K / (1024 * 1024) # MB
|
||||
if SPLIT_K == 1:
|
||||
store_ms = store_c_dram / store_bw
|
||||
else:
|
||||
reduce_bw = store_bw
|
||||
store_ms = store_c_dram / reduce_bw
|
||||
# c.zero_()
|
||||
zero_ms = M * N * 2 / (1024 * 1024) / store_bw
|
||||
store_ms += zero_ms
|
||||
|
||||
total_time_ms = compute_ms + load_ms + store_ms
|
||||
if debug:
|
||||
print(f'Total time: {total_time_ms}ms, compute time: {compute_ms}ms, '
|
||||
f'loading time: {load_ms}ms, store time: {store_ms}ms, '
|
||||
f'Activate CTAs: {active_cta_ratio*100}%')
|
||||
return total_time_ms
|
||||
|
||||
|
||||
def early_config_prune(configs, named_args):
|
||||
device = torch.cuda.current_device()
|
||||
capability = torch.cuda.get_device_capability()
|
||||
# BLOCK_M, BLOCK_N, BLOCK_K, SPLIT_K, num_warps, num_stages
|
||||
dtsize = named_args['A'].element_size()
|
||||
dtype = named_args['A'].dtype
|
||||
|
||||
# 1. make sure we have enough smem
|
||||
pruned_configs = []
|
||||
for config in configs:
|
||||
kw = config.kwargs
|
||||
BLOCK_M, BLOCK_N, BLOCK_K, num_stages = \
|
||||
kw['BLOCK_M'], kw['BLOCK_N'], kw['BLOCK_K'], config.num_stages
|
||||
|
||||
max_shared_memory = driver.utils.get_device_properties(device)["max_shared_mem"]
|
||||
required_shared_memory = (BLOCK_M + BLOCK_N) * BLOCK_K * num_stages * dtsize
|
||||
if required_shared_memory <= max_shared_memory:
|
||||
pruned_configs.append(config)
|
||||
configs = pruned_configs
|
||||
|
||||
# Some dtypes do not allow atomic_add
|
||||
if dtype not in [torch.float16, torch.float32]:
|
||||
configs = [config for config in configs if config.kwargs['SPLIT_K'] == 1]
|
||||
|
||||
# group configs by (BLOCK_M,_N,_K, SPLIT_K, num_warps)
|
||||
configs_map = {}
|
||||
for config in configs:
|
||||
kw = config.kwargs
|
||||
BLOCK_M, BLOCK_N, BLOCK_K, SPLIT_K, num_warps, num_stages = \
|
||||
kw['BLOCK_M'], kw['BLOCK_N'], kw['BLOCK_K'], kw['SPLIT_K'], config.num_warps, config.num_stages
|
||||
|
||||
key = (BLOCK_M, BLOCK_N, BLOCK_K, SPLIT_K, num_warps)
|
||||
if key in configs_map:
|
||||
configs_map[key].append((config, num_stages))
|
||||
else:
|
||||
configs_map[key] = [(config, num_stages)]
|
||||
|
||||
pruned_configs = []
|
||||
for k, v in configs_map.items():
|
||||
BLOCK_M, BLOCK_N, BLOCK_K, SPLIT_K, num_warps = k
|
||||
if capability[0] >= 8:
|
||||
# compute cycles (only works for ampere GPUs)
|
||||
mmas = BLOCK_M * BLOCK_N * BLOCK_K / (16 * 8 * 16)
|
||||
mma_cycles = mmas / min(4, num_warps) * 8
|
||||
|
||||
ldgsts_latency = 300 # Does this matter?
|
||||
optimal_num_stages = ldgsts_latency / mma_cycles
|
||||
|
||||
# nearest stages, prefer large #stages
|
||||
nearest = heapq.nsmallest(2, v, key=lambda x: 10 + abs(x[1] - optimal_num_stages)
|
||||
if (x[1] - optimal_num_stages) < 0 else x[1] - optimal_num_stages)
|
||||
|
||||
for n in nearest:
|
||||
pruned_configs.append(n[0])
|
||||
else: # Volta & Turing only supports num_stages <= 2
|
||||
if hasattr(torch, "corex"):
|
||||
for stage in range(len(v)):
|
||||
random_config = v[stage][0]
|
||||
random_config.num_stages = v[stage][1]
|
||||
pruned_configs.append(random_config)
|
||||
else:
|
||||
random_config = v[0][0]
|
||||
random_config.num_stages = 2
|
||||
pruned_configs.append(random_config)
|
||||
return pruned_configs
|
||||
21
pkgs/triton/runtime/__init__.py
Normal file
21
pkgs/triton/runtime/__init__.py
Normal file
@@ -0,0 +1,21 @@
|
||||
from .autotuner import (Autotuner, Config, Heuristics, OutOfResources, autotune,
|
||||
heuristics)
|
||||
from .driver import driver
|
||||
from .jit import (JITFunction, KernelInterface, MockTensor, TensorWrapper, reinterpret,
|
||||
version_key)
|
||||
|
||||
__all__ = [
|
||||
"driver",
|
||||
"Config",
|
||||
"Heuristics",
|
||||
"autotune",
|
||||
"heuristics",
|
||||
"JITFunction",
|
||||
"KernelInterface",
|
||||
"version_key",
|
||||
"reinterpret",
|
||||
"TensorWrapper",
|
||||
"OutOfResources",
|
||||
"MockTensor",
|
||||
"Autotuner",
|
||||
]
|
||||
BIN
pkgs/triton/runtime/__pycache__/__init__.cpython-310.pyc
Normal file
BIN
pkgs/triton/runtime/__pycache__/__init__.cpython-310.pyc
Normal file
Binary file not shown.
BIN
pkgs/triton/runtime/__pycache__/autotuner.cpython-310.pyc
Normal file
BIN
pkgs/triton/runtime/__pycache__/autotuner.cpython-310.pyc
Normal file
Binary file not shown.
BIN
pkgs/triton/runtime/__pycache__/cache.cpython-310.pyc
Normal file
BIN
pkgs/triton/runtime/__pycache__/cache.cpython-310.pyc
Normal file
Binary file not shown.
BIN
pkgs/triton/runtime/__pycache__/driver.cpython-310.pyc
Normal file
BIN
pkgs/triton/runtime/__pycache__/driver.cpython-310.pyc
Normal file
Binary file not shown.
BIN
pkgs/triton/runtime/__pycache__/errors.cpython-310.pyc
Normal file
BIN
pkgs/triton/runtime/__pycache__/errors.cpython-310.pyc
Normal file
Binary file not shown.
BIN
pkgs/triton/runtime/__pycache__/jit.cpython-310.pyc
Normal file
BIN
pkgs/triton/runtime/__pycache__/jit.cpython-310.pyc
Normal file
Binary file not shown.
305
pkgs/triton/runtime/autotuner.py
Normal file
305
pkgs/triton/runtime/autotuner.py
Normal file
@@ -0,0 +1,305 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import builtins
|
||||
import time
|
||||
from typing import Dict
|
||||
|
||||
import json
|
||||
import os
|
||||
import hashlib
|
||||
|
||||
from ..testing import do_bench
|
||||
from .jit import KernelInterface
|
||||
from .cache import default_cache_dir
|
||||
|
||||
|
||||
def build_best_config_hash(args_names, key):
|
||||
cache_dir = os.environ.get('TRITON_CACHE_DIR', default_cache_dir())
|
||||
hasher = hashlib.sha256()
|
||||
hasher.update(f"{'_'.join(args_names) + str(key)}\n".encode())
|
||||
cfg_hash = hasher.hexdigest()
|
||||
cfg_hash_dir = os.path.join(cache_dir, cfg_hash)
|
||||
cfg_hash_file = os.path.splitext(cfg_hash)[0] + ".best_config"
|
||||
cfg_hash_file = os.path.join(cfg_hash_dir, cfg_hash_file)
|
||||
return cfg_hash_dir, cfg_hash_file
|
||||
|
||||
|
||||
def load_best_config(args_names, key):
|
||||
_, cfg_hash_file = build_best_config_hash(args_names, key)
|
||||
if os.path.exists(cfg_hash_file):
|
||||
with open(cfg_hash_file) as fd:
|
||||
best_config = json.loads(fd.read())
|
||||
num_warps = best_config.pop('num_warps') if 'num_warps' in best_config else 4
|
||||
num_stages = best_config.pop('num_stages') if 'num_stages' in best_config else 1
|
||||
return best_config, num_warps, num_stages
|
||||
return None
|
||||
|
||||
|
||||
def save_best_config(cfg, args_names, key):
|
||||
cfg_hash_dir, cfg_hash_file = build_best_config_hash(args_names, key)
|
||||
if os.path.exists(cfg_hash_dir):
|
||||
return
|
||||
os.makedirs(cfg_hash_dir, exist_ok=True)
|
||||
with open(cfg_hash_file, "w") as fd:
|
||||
fd.write(
|
||||
json.dumps(
|
||||
{
|
||||
**cfg.kwargs,
|
||||
"num_warps": cfg.num_warps,
|
||||
"num_stages": cfg.num_stages,
|
||||
}
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
class OutOfResources(Exception):
|
||||
def __init__(self, required, limit, name):
|
||||
self.message = f'out of resource: {name}, '\
|
||||
f'Required: {required}, '\
|
||||
f'Hardware limit: {limit}'
|
||||
self.message += '. Reducing block sizes or `num_stages` may help.'
|
||||
self.required = required
|
||||
self.limit = limit
|
||||
self.name = name
|
||||
super().__init__(self.message)
|
||||
|
||||
def __reduce__(self):
|
||||
# this is necessary to make CompilationError picklable
|
||||
return (type(self), (self.required, self.limit, self.name))
|
||||
|
||||
|
||||
class Autotuner(KernelInterface):
|
||||
def __init__(self, fn, arg_names, configs, key, reset_to_zero, prune_configs_by: Dict = None):
|
||||
'''
|
||||
:param prune_configs_by: a dict of functions that are used to prune configs, fields:
|
||||
'perf_model': performance model used to predicate running time with different configs, returns running time
|
||||
'top_k': number of configs to bench
|
||||
'prune_num_stages_by'(optional): a function used to prune num_stages. It takes configs:List[Config] as its input, and returns pruned configs.
|
||||
'''
|
||||
if not configs:
|
||||
self.configs = [Config({}, num_warps=4, num_stages=2)]
|
||||
else:
|
||||
self.configs = configs
|
||||
self.key_idx = [arg_names.index(k) for k in key]
|
||||
self.cache = {}
|
||||
# hook to reset all required tensor to zeros before relaunching a kernel
|
||||
self.hook = lambda args: 0
|
||||
if reset_to_zero is not None:
|
||||
self.reset_idx = [arg_names.index(k) for k in reset_to_zero]
|
||||
|
||||
def _hook(args):
|
||||
for i in self.reset_idx:
|
||||
args[i].zero_()
|
||||
self.hook = _hook
|
||||
self.arg_names = arg_names
|
||||
# prune configs
|
||||
if prune_configs_by:
|
||||
perf_model, top_k = prune_configs_by['perf_model'], prune_configs_by['top_k']
|
||||
if 'early_config_prune' in prune_configs_by:
|
||||
early_config_prune = prune_configs_by['early_config_prune']
|
||||
else:
|
||||
perf_model, top_k, early_config_prune = None, None, None
|
||||
self.perf_model, self.configs_top_k = perf_model, top_k
|
||||
self.early_config_prune = early_config_prune
|
||||
self.fn = fn
|
||||
|
||||
def _bench(self, *args, config, **meta):
|
||||
# check for conflicts, i.e. meta-parameters both provided
|
||||
# as kwargs and by the autotuner
|
||||
conflicts = meta.keys() & config.kwargs.keys()
|
||||
if conflicts:
|
||||
raise ValueError(
|
||||
f"Conflicting meta-parameters: {', '.join(conflicts)}."
|
||||
" Make sure that you don't re-define auto-tuned symbols."
|
||||
)
|
||||
# augment meta-parameters with tunable ones
|
||||
current = dict(meta, **config.kwargs)
|
||||
|
||||
def kernel_call():
|
||||
if config.pre_hook:
|
||||
config.pre_hook(self.nargs)
|
||||
self.hook(args)
|
||||
self.fn.run(*args, num_warps=config.num_warps, num_stages=config.num_stages, **current)
|
||||
try:
|
||||
return do_bench(kernel_call, quantiles=(0.5, 0.2, 0.8))
|
||||
except OutOfResources:
|
||||
return [float('inf'), float('inf'), float('inf')]
|
||||
|
||||
def run(self, *args, **kwargs):
|
||||
self.nargs = dict(zip(self.arg_names, args))
|
||||
if len(self.configs) > 1:
|
||||
all_args = {**self.nargs, **kwargs}
|
||||
_args = []
|
||||
for name in self.arg_names:
|
||||
if name in all_args:
|
||||
_args.append(all_args[name])
|
||||
key = [_args[i] for i in self.key_idx]
|
||||
divisibility = 16
|
||||
for arg in args:
|
||||
if hasattr(arg, "data_ptr"):
|
||||
key.append(arg.dtype)
|
||||
key.append(arg.data_ptr() % divisibility == 0)
|
||||
elif isinstance(arg, int):
|
||||
key.append(arg)
|
||||
key = tuple(key)
|
||||
if key not in self.cache:
|
||||
load_config = load_best_config(self.arg_names, key)
|
||||
if load_config:
|
||||
best_config, num_warps, num_stages = load_config
|
||||
config = Config(best_config, num_warps, num_stages)
|
||||
self.cache[key] = config
|
||||
self.hook(args)
|
||||
else:
|
||||
# prune configs
|
||||
pruned_configs = self.prune_configs(kwargs)
|
||||
bench_start = time.time()
|
||||
timings = {config: self._bench(*args, config=config, **kwargs)
|
||||
for config in pruned_configs}
|
||||
bench_end = time.time()
|
||||
self.bench_time = bench_end - bench_start
|
||||
self.cache[key] = builtins.min(timings, key=timings.get)
|
||||
save_best_config(self.cache[key], self.arg_names, key)
|
||||
self.hook(args)
|
||||
self.configs_timings = timings
|
||||
config = self.cache[key]
|
||||
else:
|
||||
config = self.configs[0]
|
||||
self.best_config = config
|
||||
if config.pre_hook is not None:
|
||||
config.pre_hook(self.nargs)
|
||||
self.nargs = None
|
||||
return self.fn.run(*args, num_warps=config.num_warps, num_stages=config.num_stages, **kwargs, **config.kwargs)
|
||||
|
||||
def prune_configs(self, kwargs):
|
||||
pruned_configs = self.configs
|
||||
if self.early_config_prune:
|
||||
pruned_configs = self.early_config_prune(self.configs, self.nargs)
|
||||
if self.perf_model:
|
||||
top_k = self.configs_top_k
|
||||
if isinstance(top_k, float) and top_k <= 1.0:
|
||||
top_k = int(len(self.configs) * top_k)
|
||||
if len(pruned_configs) > top_k:
|
||||
est_timing = {
|
||||
config: self.perf_model(**self.nargs, **kwargs, **config.kwargs, num_stages=config.num_stages,
|
||||
num_warps=config.num_warps)
|
||||
for config in pruned_configs
|
||||
}
|
||||
pruned_configs = sorted(est_timing.keys(), key=lambda x: est_timing[x])[:top_k]
|
||||
return pruned_configs
|
||||
|
||||
def warmup(self, *args, **kwargs):
|
||||
self.nargs = dict(zip(self.arg_names, args))
|
||||
for config in self.prune_configs(kwargs):
|
||||
self.fn.warmup(
|
||||
*args,
|
||||
num_warps=config.num_warps,
|
||||
num_stages=config.num_stages,
|
||||
**kwargs,
|
||||
**config.kwargs,
|
||||
)
|
||||
self.nargs = None
|
||||
|
||||
|
||||
class Config:
|
||||
"""
|
||||
An object that represents a possible kernel configuration for the auto-tuner to try.
|
||||
|
||||
:ivar meta: a dictionary of meta-parameters to pass to the kernel as keyword arguments.
|
||||
:type meta: dict[Str, Any]
|
||||
:ivar num_warps: the number of warps to use for the kernel when compiled for GPUs. For example, if
|
||||
`num_warps=8`, then each kernel instance will be automatically parallelized to
|
||||
cooperatively execute using `8 * 32 = 256` threads.
|
||||
:type num_warps: int
|
||||
:ivar num_stages: the number of stages that the compiler should use when software-pipelining loops.
|
||||
Mostly useful for matrix multiplication workloads on SM80+ GPUs.
|
||||
:type num_stages: int
|
||||
:ivar pre_hook: a function that will be called before the kernel is called. Parameters of this
|
||||
function are args.
|
||||
"""
|
||||
|
||||
def __init__(self, kwargs, num_warps=4, num_stages=2, pre_hook=None):
|
||||
self.kwargs = kwargs
|
||||
self.num_warps = num_warps
|
||||
self.num_stages = num_stages
|
||||
self.pre_hook = pre_hook
|
||||
|
||||
def __str__(self):
|
||||
res = []
|
||||
for k, v in self.kwargs.items():
|
||||
res.append(f'{k}: {v}')
|
||||
res.append(f'num_warps: {self.num_warps}')
|
||||
res.append(f'num_stages: {self.num_stages}')
|
||||
return ', '.join(res)
|
||||
|
||||
|
||||
def autotune(configs, key, prune_configs_by=None, reset_to_zero=None):
|
||||
"""
|
||||
Decorator for auto-tuning a :code:`triton.jit`'d function.
|
||||
|
||||
.. highlight:: python
|
||||
.. code-block:: python
|
||||
|
||||
@triton.autotune(configs=[
|
||||
triton.Config(meta={'BLOCK_SIZE': 128}, num_warps=4),
|
||||
triton.Config(meta={'BLOCK_SIZE': 1024}, num_warps=8),
|
||||
],
|
||||
key=['x_size'] # the two above configs will be evaluated anytime
|
||||
# the value of x_size changes
|
||||
)
|
||||
@triton.jit
|
||||
def kernel(x_ptr, x_size, **META):
|
||||
BLOCK_SIZE = META['BLOCK_SIZE']
|
||||
:note: When all the configurations are evaluated, the kernel will run multiple times.
|
||||
This means that whatever value the kernel updates will be updated multiple times.
|
||||
To avoid this undesired behavior, you can use the `reset_to_zero` argument, which
|
||||
resets the value of the provided tensor to `zero` before running any configuration.
|
||||
:param configs: a list of :code:`triton.Config` objects
|
||||
:type configs: list[triton.Config]
|
||||
:param key: a list of argument names whose change in value will trigger the evaluation of all provided configs.
|
||||
:type key: list[str]
|
||||
:param prune_configs_by: a dict of functions that are used to prune configs, fields:
|
||||
'perf_model': performance model used to predicate running time with different configs, returns running time
|
||||
'top_k': number of configs to bench
|
||||
'early_config_prune'(optional): a function used to do early prune (eg, num_stages). It takes configs:List[Config] as its input, and returns pruned configs.
|
||||
:param reset_to_zero: a list of argument names whose value will be reset to zero before evaluating any configs.
|
||||
:type reset_to_zero: list[str]
|
||||
"""
|
||||
def decorator(fn):
|
||||
return Autotuner(fn, fn.arg_names, configs, key, reset_to_zero, prune_configs_by)
|
||||
|
||||
return decorator
|
||||
|
||||
|
||||
class Heuristics(KernelInterface):
|
||||
|
||||
def __init__(self, fn, arg_names, values) -> None:
|
||||
self.fn = fn
|
||||
self.values = values
|
||||
self.arg_names = arg_names
|
||||
|
||||
def run(self, *args, **kwargs):
|
||||
for v, heur in self.values.items():
|
||||
kwargs[v] = heur({**dict(zip(self.arg_names, args)), **kwargs})
|
||||
return self.fn.run(*args, **kwargs)
|
||||
|
||||
|
||||
def heuristics(values):
|
||||
"""
|
||||
Decorator for specifying how the values of certain meta-parameters may be computed.
|
||||
This is useful for cases where auto-tuning is prohibitevely expensive, or just not applicable.
|
||||
|
||||
.. highlight:: python
|
||||
.. code-block:: python
|
||||
|
||||
@triton.heuristics(values={'BLOCK_SIZE': lambda args: 2 ** int(math.ceil(math.log2(args[1])))})
|
||||
@triton.jit
|
||||
def kernel(x_ptr, x_size, **META):
|
||||
BLOCK_SIZE = META['BLOCK_SIZE'] # smallest power-of-two >= x_size
|
||||
:param values: a dictionary of meta-parameter names and functions that compute the value of the meta-parameter.
|
||||
each such function takes a list of positional arguments as input.
|
||||
:type values: dict[str, Callable[[list[Any]], Any]]
|
||||
"""
|
||||
def decorator(fn):
|
||||
return Heuristics(fn, fn.arg_names, values)
|
||||
|
||||
return decorator
|
||||
131
pkgs/triton/runtime/backends/cuda.c
Normal file
131
pkgs/triton/runtime/backends/cuda.c
Normal file
@@ -0,0 +1,131 @@
|
||||
#include "cuda.h"
|
||||
#define PY_SSIZE_T_CLEAN
|
||||
#include <Python.h>
|
||||
|
||||
static inline void gpuAssert(CUresult code, const char *file, int line) {
|
||||
if (code != CUDA_SUCCESS) {
|
||||
const char *prefix = "Triton Error [CUDA]: ";
|
||||
const char *str;
|
||||
cuGetErrorString(code, &str);
|
||||
char err[1024] = {0};
|
||||
strcat(err, prefix);
|
||||
strcat(err, str);
|
||||
PyErr_SetString(PyExc_RuntimeError, err);
|
||||
}
|
||||
}
|
||||
|
||||
#define CUDA_CHECK(ans) \
|
||||
{ \
|
||||
gpuAssert((ans), __FILE__, __LINE__); \
|
||||
if (PyErr_Occurred()) \
|
||||
return NULL; \
|
||||
}
|
||||
|
||||
static PyObject *getDeviceProperties(PyObject *self, PyObject *args) {
|
||||
int device_id;
|
||||
if (!PyArg_ParseTuple(args, "i", &device_id))
|
||||
return NULL;
|
||||
// Get device handle
|
||||
CUdevice device;
|
||||
cuDeviceGet(&device, device_id);
|
||||
|
||||
// create a struct to hold device properties
|
||||
int max_shared_mem;
|
||||
int multiprocessor_count;
|
||||
int sm_clock_rate;
|
||||
int mem_clock_rate;
|
||||
int mem_bus_width;
|
||||
CUDA_CHECK(cuDeviceGetAttribute(
|
||||
&max_shared_mem, CU_DEVICE_ATTRIBUTE_MAX_SHARED_MEMORY_PER_BLOCK_OPTIN,
|
||||
device));
|
||||
CUDA_CHECK(cuDeviceGetAttribute(
|
||||
&multiprocessor_count, CU_DEVICE_ATTRIBUTE_MULTIPROCESSOR_COUNT, device));
|
||||
CUDA_CHECK(cuDeviceGetAttribute(&sm_clock_rate,
|
||||
CU_DEVICE_ATTRIBUTE_CLOCK_RATE, device));
|
||||
CUDA_CHECK(cuDeviceGetAttribute(
|
||||
&mem_clock_rate, CU_DEVICE_ATTRIBUTE_MEMORY_CLOCK_RATE, device));
|
||||
CUDA_CHECK(cuDeviceGetAttribute(
|
||||
&mem_bus_width, CU_DEVICE_ATTRIBUTE_GLOBAL_MEMORY_BUS_WIDTH, device));
|
||||
|
||||
return Py_BuildValue("{s:i, s:i, s:i, s:i, s:i}", "max_shared_mem",
|
||||
max_shared_mem, "multiprocessor_count",
|
||||
multiprocessor_count, "sm_clock_rate", sm_clock_rate,
|
||||
"mem_clock_rate", mem_clock_rate, "mem_bus_width",
|
||||
mem_bus_width);
|
||||
}
|
||||
|
||||
// Python: load_binary(name, data, shared, device) ->
//   (module_handle, function_handle, n_regs, n_spills)
// `name` is the kernel symbol, `data` the cubin image bytes, `shared` the
// dynamic shared memory requested, `device` the target device ordinal.
static PyObject *loadBinary(PyObject *self, PyObject *args) {
  const char *name;
  const char *data;
  Py_ssize_t data_size;
  int shared;
  int device;
  if (!PyArg_ParseTuple(args, "ss#ii", &name, &data, &data_size, &shared,
                        &device)) {
    return NULL;
  }
  CUfunction fun;
  CUmodule mod;
  int32_t n_regs = 0;
  int32_t n_spills = 0;
  // create driver handles
  CUcontext pctx = 0;
  CUDA_CHECK(cuCtxGetCurrent(&pctx));
  if (!pctx) {
    // No current context (e.g. fresh thread): retain the device's primary
    // context so module loading has a context to live in.
    CUDA_CHECK(cuDevicePrimaryCtxRetain(&pctx, device));
    CUDA_CHECK(cuCtxSetCurrent(pctx));
  }

  CUDA_CHECK(cuModuleLoadData(&mod, data));
  CUDA_CHECK(cuModuleGetFunction(&fun, mod, name));
  // get allocated registers and spilled registers from the function
  CUDA_CHECK(cuFuncGetAttribute(&n_regs, CU_FUNC_ATTRIBUTE_NUM_REGS, fun));
  CUDA_CHECK(
      cuFuncGetAttribute(&n_spills, CU_FUNC_ATTRIBUTE_LOCAL_SIZE_BYTES, fun));
  // local size is reported in bytes; divide by 4 to express it in 32-bit
  // register units (presumably -- TODO confirm against the driver docs).
  n_spills /= 4;
  // set dynamic shared memory if necessary
  int shared_optin;
  CUDA_CHECK(cuDeviceGetAttribute(
      &shared_optin, CU_DEVICE_ATTRIBUTE_MAX_SHARED_MEMORY_PER_BLOCK_OPTIN,
      device));
  if (shared > 49152 && shared_optin > 49152) {
    // Kernel wants more than the default 48KB carve-out: opt in to the
    // device maximum, leaving room for the kernel's static shared usage.
    CUDA_CHECK(cuFuncSetCacheConfig(fun, CU_FUNC_CACHE_PREFER_SHARED));
    int shared_total, shared_static;
    CUDA_CHECK(cuDeviceGetAttribute(
        &shared_total, CU_DEVICE_ATTRIBUTE_MAX_SHARED_MEMORY_PER_MULTIPROCESSOR,
        device));
    CUDA_CHECK(cuFuncGetAttribute(&shared_static,
                                  CU_FUNC_ATTRIBUTE_SHARED_SIZE_BYTES, fun));
    CUDA_CHECK(
        cuFuncSetAttribute(fun, CU_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES,
                           shared_optin - shared_static));
  }

  if (PyErr_Occurred()) {
    return NULL;
  }
  // Handles are returned as raw 64-bit integers; Python side keeps them
  // opaque and passes them back for launches.
  return Py_BuildValue("(KKii)", (uint64_t)mod, (uint64_t)fun, n_regs,
                       n_spills);
}
|
||||
|
||||
static PyMethodDef ModuleMethods[] = {
|
||||
{"load_binary", loadBinary, METH_VARARGS,
|
||||
"Load provided cubin into CUDA driver"},
|
||||
{"get_device_properties", getDeviceProperties, METH_VARARGS,
|
||||
"Get the properties for a given device"},
|
||||
{NULL, NULL, 0, NULL} // sentinel
|
||||
};
|
||||
|
||||
static struct PyModuleDef ModuleDef = {PyModuleDef_HEAD_INIT, "cuda_utils",
|
||||
NULL, // documentation
|
||||
-1, // size
|
||||
ModuleMethods};
|
||||
|
||||
PyMODINIT_FUNC PyInit_cuda_utils(void) {
|
||||
PyObject *m = PyModule_Create(&ModuleDef);
|
||||
if (m == NULL) {
|
||||
return NULL;
|
||||
}
|
||||
PyModule_AddFunctions(m, ModuleMethods);
|
||||
return m;
|
||||
}
|
||||
120
pkgs/triton/runtime/backends/hip.c
Normal file
120
pkgs/triton/runtime/backends/hip.c
Normal file
@@ -0,0 +1,120 @@
|
||||
#define __HIP_PLATFORM_AMD__
|
||||
#include <hip/hip_runtime.h>
|
||||
#define PY_SSIZE_T_CLEAN
|
||||
#include <Python.h>
|
||||
#include <stdio.h>
|
||||
#include <stdlib.h>
|
||||
|
||||
// Translate a failing HIP call into a pending Python RuntimeError.
// `file`/`line` are currently unused but kept for parity with the CUDA
// variant and for future logging.
static inline void gpuAssert(hipError_t code, const char *file, int line) {
  if (code != HIP_SUCCESS) {
    const char *prefix = "Triton Error [HIP]: ";
    const char *str = hipGetErrorString(code);
    char err[1024] = {0};
    // Fix: "Messsage" -> "Message" in the formatted error text.
    snprintf(err, 1024, "%s Code: %d, Message: %s", prefix, code, str);
    PyErr_SetString(PyExc_RuntimeError, err);
  }
}
|
||||
|
||||
// Same contract as the CUDA variant's CUDA_CHECK: convert a failing HIP
// call into a pending Python RuntimeError (via gpuAssert) and return NULL
// from the enclosing function once an exception is set.
#define HIP_CHECK(ans)                                                         \
  {                                                                            \
    gpuAssert((ans), __FILE__, __LINE__);                                      \
    if (PyErr_Occurred())                                                      \
      return NULL;                                                             \
  }
|
||||
|
||||
// Python: get_device_properties(device_id) -> dict of HIP device properties.
static PyObject *getDeviceProperties(PyObject *self, PyObject *args) {
  int device_id;
  if (!PyArg_ParseTuple(args, "i", &device_id))
    return NULL;

  hipDeviceProp_t props;
  HIP_CHECK(hipGetDeviceProperties(&props, device_id));

  // create a struct to hold device properties
  // NOTE(review): sharedMemPerBlock is size_t but is passed through the "i"
  // (int) format of Py_BuildValue via varargs -- verify this is safe for the
  // targeted 64-bit platforms.
  return Py_BuildValue("{s:i, s:i, s:i, s:i, s:i}", "max_shared_mem",
                       props.sharedMemPerBlock, "multiprocessor_count",
                       props.multiProcessorCount, "sm_clock_rate",
                       props.clockRate, "mem_clock_rate", props.memoryClockRate,
                       "mem_bus_width", props.memoryBusWidth);
}
|
||||
|
||||
// Python: load_binary(name, hsaco_path, shared, device) ->
//   (module_handle, function_handle, n_regs, n_spills)
// Unlike the CUDA variant, `data` here is the *path* of the compiled HSACO
// file, which is read from disk and JIT-loaded into the HIP driver.
static PyObject *loadBinary(PyObject *self, PyObject *args) {
  const char *name;
  const char *data;
  Py_ssize_t data_size;
  int shared;
  int device;
  if (!PyArg_ParseTuple(args, "ss#ii", &name, &data, &data_size, &shared,
                        &device)) {
    return NULL;
  }

  // Open HSACO file. Fix: raise a proper OSError instead of returning NULL
  // with no exception set (which CPython turns into a SystemError).
  FILE *hsaco_file;
  if ((hsaco_file = fopen(data, "rb")) == NULL) {
    PyErr_SetFromErrnoWithFilename(PyExc_OSError, data);
    return NULL;
  }

  // Read the whole HSACO file into a heap buffer; check allocation and the
  // actual number of bytes read (both were previously unchecked).
  fseek(hsaco_file, 0L, SEEK_END);
  size_t hsaco_file_size = ftell(hsaco_file);
  unsigned char *hsaco =
      (unsigned char *)malloc(hsaco_file_size * sizeof(unsigned char));
  if (hsaco == NULL) {
    fclose(hsaco_file);
    return PyErr_NoMemory();
  }
  rewind(hsaco_file);
  size_t n_read = fread(hsaco, sizeof(unsigned char), hsaco_file_size,
                        hsaco_file);
  fclose(hsaco_file);
  if (n_read != hsaco_file_size) {
    free(hsaco);
    PyErr_SetString(PyExc_RuntimeError, "loadBinary: short read of HSACO file");
    return NULL;
  }

  // set HIP options: capture error/info logs from the module loader.
  hipJitOption opt[] = {hipJitOptionErrorLogBufferSizeBytes,
                        hipJitOptionErrorLogBuffer,
                        hipJitOptionInfoLogBufferSizeBytes,
                        hipJitOptionInfoLogBuffer, hipJitOptionLogVerbose};
  const unsigned int errbufsize = 8192;
  const unsigned int logbufsize = 8192;
  char _err[errbufsize];
  char _log[logbufsize];
  void *optval[] = {(void *)(uintptr_t)errbufsize, (void *)_err,
                    (void *)(uintptr_t)logbufsize, (void *)_log, (void *)1};

  // Load the module and resolve the kernel. Fix: surface driver errors as
  // Python exceptions instead of returning dangling handles; free the file
  // buffer before any early return so error paths do not leak it.
  hipModule_t mod;
  hipFunction_t fun;
  hipError_t load_status = hipModuleLoadDataEx(&mod, hsaco, 5, opt, optval);
  free(hsaco);
  HIP_CHECK(load_status);
  HIP_CHECK(hipModuleGetFunction(&fun, mod, name));

  // Register/spill counts are not queried on this backend yet.
  int n_regs = 0;
  int n_spills = 0;
  if (PyErr_Occurred()) {
    return NULL;
  }
  return Py_BuildValue("(KKii)", (uint64_t)mod, (uint64_t)fun, n_regs,
                       n_spills);
}
|
||||
|
||||
static PyMethodDef ModuleMethods[] = {
|
||||
{"load_binary", loadBinary, METH_VARARGS,
|
||||
"Load provided hsaco into HIP driver"},
|
||||
{"get_device_properties", getDeviceProperties, METH_VARARGS,
|
||||
"Get the properties for a given device"},
|
||||
{NULL, NULL, 0, NULL} // sentinel
|
||||
};
|
||||
|
||||
static struct PyModuleDef ModuleDef = {PyModuleDef_HEAD_INIT, "hip_utils",
|
||||
NULL, // documentation
|
||||
-1, // size
|
||||
ModuleMethods};
|
||||
|
||||
PyMODINIT_FUNC PyInit_hip_utils(void) {
|
||||
PyObject *m = PyModule_Create(&ModuleDef);
|
||||
if (m == NULL) {
|
||||
return NULL;
|
||||
}
|
||||
PyModule_AddFunctions(m, ModuleMethods);
|
||||
return m;
|
||||
}
|
||||
131
pkgs/triton/runtime/cache.py
Normal file
131
pkgs/triton/runtime/cache.py
Normal file
@@ -0,0 +1,131 @@
|
||||
import json
|
||||
import os
|
||||
import random
|
||||
from abc import ABC, abstractmethod
|
||||
from pathlib import Path
|
||||
from typing import Dict, Optional
|
||||
|
||||
|
||||
def default_cache_dir():
    """Default on-disk location for Triton's compilation cache."""
    return os.path.join(Path.home(), ".triton", "cache")


class CacheManager(ABC):
    """Abstract interface for storing/retrieving compiled-kernel artifacts.

    A manager is created per cache *key* (a hash of the kernel source and
    options); files inside that keyed namespace are addressed by filename.
    """

    def __init__(self, key):
        pass

    @abstractmethod
    def get_file(self, filename) -> Optional[str]:
        """Return the path of a cached file, or None if absent."""
        pass

    @abstractmethod
    def has_file(self, filename) -> bool:
        """Return True if *filename* exists in this cache namespace."""
        pass

    @abstractmethod
    def put(self, data, filename, binary=True) -> str:
        """Store *data* under *filename* and return the resulting path."""
        pass

    @abstractmethod
    def get_group(self, filename: str) -> Optional[Dict[str, str]]:
        """Return {child_name: path} for the group anchored at *filename*."""
        pass

    @abstractmethod
    def put_group(self, filename: str, group: Dict[str, str]):
        """Record that the files in *group* belong together under *filename*."""
        pass


class FileCacheManager(CacheManager):
    """CacheManager backed by a plain directory tree.

    Layout: <TRITON_CACHE_DIR or ~/.triton/cache>/<key>/<filename>.
    """

    def __init__(self, key):
        self.key = key
        self.lock_path = None
        # create cache directory if it doesn't exist
        self.cache_dir = os.environ.get('TRITON_CACHE_DIR', default_cache_dir())
        if self.cache_dir:
            self.cache_dir = os.path.join(self.cache_dir, self.key)
            self.lock_path = os.path.join(self.cache_dir, "lock")
            os.makedirs(self.cache_dir, exist_ok=True)

    def _make_path(self, filename) -> str:
        return os.path.join(self.cache_dir, filename)

    def has_file(self, filename):
        if not self.cache_dir:
            return False
        return os.path.exists(self._make_path(filename))

    def get_file(self, filename) -> Optional[str]:
        if self.has_file(filename):
            return self._make_path(filename)
        else:
            return None

    def get_group(self, filename: str) -> Optional[Dict[str, str]]:
        # Fix: the group-manifest name must embed the anchor filename; the
        # previous f-string had no placeholder, so every group shared one
        # manifest file and collided.
        grp_filename = f"__grp__{filename}"
        if not self.has_file(grp_filename):
            return None
        grp_filepath = self._make_path(grp_filename)
        with open(grp_filepath) as f:
            grp_data = json.load(f)
        child_paths = grp_data.get("child_paths", None)
        # Invalid group data.
        if child_paths is None:
            return None
        result = {}
        for c in child_paths:
            p = self._make_path(c)
            if not os.path.exists(p):
                raise Exception(f"Group file {p} does not exist from group {grp_filename} ")
            result[c] = p
        return result

    # Note a group of pushed files as being part of a group
    def put_group(self, filename: str, group: Dict[str, str]):
        if not self.cache_dir:
            return
        grp_contents = json.dumps({"child_paths": sorted(list(group.keys()))})
        # Fix: same placeholder bug as get_group (see above).
        grp_filename = f"__grp__{filename}"
        return self.put(grp_contents, grp_filename, binary=False)

    def put(self, data, filename, binary=True) -> str:
        if not self.cache_dir:
            return
        # The `binary` argument is advisory only: the actual mode is derived
        # from the payload type, so bytes are always written binary.
        binary = isinstance(data, bytes)
        if not binary:
            data = str(data)
        assert self.lock_path is not None
        filepath = self._make_path(filename)
        # Random ID to avoid any collisions
        rnd_id = random.randint(0, 1000000)
        # we use the PID incase a bunch of these around so we can see what PID made it
        pid = os.getpid()
        # use tempfile to be robust against program interruptions
        temp_path = f"{filepath}.tmp.pid_{pid}_{rnd_id}"
        mode = "wb" if binary else "w"
        with open(temp_path, mode) as f:
            f.write(data)
        # Replace is guaranteed to be atomic on POSIX systems if it succeeds
        # so filepath cannot see a partial write
        os.replace(temp_path, filepath)
        return filepath
||||
|
||||
|
||||
# Active CacheManager class and the TRITON_CACHE_MANAGER value it was
# resolved from ("DEFAULT" means the built-in FileCacheManager).
__cache_cls = FileCacheManager
__cache_cls_nme = "DEFAULT"


def get_cache_manager(key) -> CacheManager:
    """Instantiate the active CacheManager class for *key*.

    TRITON_CACHE_MANAGER may name a replacement as "module.path:ClassName";
    the resolved class is memoized in module globals until the env var value
    changes.  The redundant function-local `import os` was removed (os is
    imported at module level).
    """
    user_cache_manager = os.environ.get("TRITON_CACHE_MANAGER", None)
    global __cache_cls
    global __cache_cls_nme

    if user_cache_manager is not None and user_cache_manager != __cache_cls_nme:
        import importlib
        module_path, clz_nme = user_cache_manager.split(":")
        module = importlib.import_module(module_path)
        __cache_cls = getattr(module, clz_nme)
        __cache_cls_nme = user_cache_manager

    return __cache_cls(key)
|
||||
175
pkgs/triton/runtime/driver.py
Normal file
175
pkgs/triton/runtime/driver.py
Normal file
@@ -0,0 +1,175 @@
|
||||
import abc
|
||||
import hashlib
|
||||
import os
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
|
||||
from ..common.build import _build
|
||||
from .cache import get_cache_manager
|
||||
|
||||
|
||||
class DriverBase(metaclass=abc.ABCMeta):
    """Common base for GPU driver wrappers; also enumerates backend ids."""

    # Backend identifiers stored in `self.backend` by concrete drivers.
    CUDA = 0
    HIP = 1

    @staticmethod
    def third_party_dir():
        """Path of the `third_party` directory shipped next to this package."""
        package_root = os.path.dirname(os.path.abspath(__file__))
        return os.path.join(package_root, "..", "third_party")

    def __init__(self) -> None:
        pass
||||
|
||||
|
||||
class CudaUtils(object):
    """Singleton that compiles backends/cuda.c on first use and exposes its
    `load_binary` / `get_device_properties` entry points.

    The built extension is cached under the Triton cache dir, keyed by an
    md5 of the C source, so recompilation only happens when cuda.c changes.
    """

    def __new__(cls):
        # Classic singleton: reuse the one instance across constructions.
        if not hasattr(cls, 'instance'):
            cls.instance = super(CudaUtils, cls).__new__(cls)
        return cls.instance

    def __init__(self):
        # NOTE(review): __init__ still runs on every CudaUtils() call even
        # though __new__ returns the cached instance, so the cache lookup
        # below is repeated per construction.
        dirname = os.path.dirname(os.path.realpath(__file__))
        src = Path(os.path.join(dirname, "backends", "cuda.c")).read_text()
        key = hashlib.md5(src.encode("utf-8")).hexdigest()
        cache = get_cache_manager(key)
        fname = "cuda_utils.so"
        cache_path = cache.get_file(fname)
        if cache_path is None:
            # Cache miss: compile the helper extension in a temp dir and
            # store both the C source and the shared object in the cache.
            with tempfile.TemporaryDirectory() as tmpdir:
                src_path = os.path.join(tmpdir, "main.c")
                with open(src_path, "w") as f:
                    f.write(src)
                so = _build("cuda_utils", src_path, tmpdir)
                cache.put(src, "main.c", binary=False)
                with open(so, "rb") as f:
                    cache_path = cache.put(f.read(), fname, binary=True)
        # Load the compiled .so as a Python module and lift its functions.
        import importlib.util
        spec = importlib.util.spec_from_file_location("cuda_utils", cache_path)
        mod = importlib.util.module_from_spec(spec)
        spec.loader.exec_module(mod)
        self.load_binary = mod.load_binary
        self.get_device_properties = mod.get_device_properties
||||
|
||||
|
||||
class CudaDriver(DriverBase):
    """Singleton driver wrapper for the CUDA backend."""

    def __new__(cls):
        # Lazily create and memoize the single shared instance.
        if not hasattr(cls, 'instance'):
            inst = super(CudaDriver, cls).__new__(cls)
            cls.instance = inst
        return cls.instance

    def __init__(self):
        # Compiled helper utilities plus the backend id from DriverBase.
        self.utils = CudaUtils()
        self.backend = self.CUDA
||||
|
||||
# -----------------------------
|
||||
# HIP
|
||||
# -----------------------------
|
||||
|
||||
|
||||
class HIPUtils(object):
    """Singleton that compiles backends/hip.c on first use and exposes its
    `load_binary` / `get_device_properties` entry points (HIP analogue of
    CudaUtils; the built .so is cached keyed by an md5 of the C source).
    """

    def __new__(cls):
        # Classic singleton: reuse the one instance across constructions.
        if not hasattr(cls, 'instance'):
            cls.instance = super(HIPUtils, cls).__new__(cls)
        return cls.instance

    def __init__(self):
        dirname = os.path.dirname(os.path.realpath(__file__))
        src = Path(os.path.join(dirname, "backends", "hip.c")).read_text()
        key = hashlib.md5(src.encode("utf-8")).hexdigest()
        cache = get_cache_manager(key)
        fname = "hip_utils.so"
        cache_path = cache.get_file(fname)
        if cache_path is None:
            # Cache miss: compile the helper extension in a temp dir.
            # NOTE(review): unlike CudaUtils, the C source itself is not
            # cached here ("main.c" is not put) -- confirm if intentional.
            with tempfile.TemporaryDirectory() as tmpdir:
                src_path = os.path.join(tmpdir, "main.c")
                with open(src_path, "w") as f:
                    f.write(src)
                so = _build("hip_utils", src_path, tmpdir)
                with open(so, "rb") as f:
                    cache_path = cache.put(f.read(), fname, binary=True)
        # Load the compiled .so as a Python module and lift its functions.
        import importlib.util
        spec = importlib.util.spec_from_file_location("hip_utils", cache_path)
        mod = importlib.util.module_from_spec(spec)
        spec.loader.exec_module(mod)
        self.load_binary = mod.load_binary
        self.get_device_properties = mod.get_device_properties
|
||||
|
||||
|
||||
class HIPDriver(DriverBase):
    """Singleton driver wrapper for the HIP backend."""

    def __new__(cls):
        # Lazily create and memoize the single shared instance.
        if not hasattr(cls, 'instance'):
            inst = super(HIPDriver, cls).__new__(cls)
            cls.instance = inst
        return cls.instance

    def __init__(self):
        # Compiled helper utilities plus the backend id from DriverBase.
        self.utils = HIPUtils()
        self.backend = self.HIP
|
||||
|
||||
|
||||
class UnsupportedDriver(DriverBase):
    """Placeholder driver used when neither CUDA nor HIP is available."""

    def __new__(cls):
        # Lazily create and memoize the single shared instance.
        if not hasattr(cls, 'instance'):
            inst = super(UnsupportedDriver, cls).__new__(cls)
            cls.instance = inst
        return cls.instance

    def __init__(self):
        # No compiled utilities and no backend id on unsupported platforms.
        self.utils = None
        self.backend = None
|
||||
|
||||
# -----------------------------
|
||||
# Driver
|
||||
# -----------------------------
|
||||
|
||||
|
||||
class LazyProxy:
    """Defer construction of an object until it is first used.

    Attribute reads, writes, and deletes are all forwarded to the wrapped
    object, which is built on demand by calling *init_fn*.
    """

    def __init__(self, init_fn):
        self._init_fn = init_fn
        self._obj = None

    def _initialize_obj(self):
        # Build the target lazily; later calls are no-ops.
        if self._obj is None:
            self._obj = self._init_fn()

    def __getattr__(self, name):
        self._initialize_obj()
        return getattr(self._obj, name)

    def __setattr__(self, name, value):
        # The proxy's own two slots bypass forwarding; everything else is
        # stored on the wrapped object.
        if name in ('_init_fn', '_obj'):
            super().__setattr__(name, value)
            return
        self._initialize_obj()
        setattr(self._obj, name, value)

    def __delattr__(self, name):
        self._initialize_obj()
        delattr(self._obj, name)

    def __repr__(self):
        if self._obj is None:
            return f"<{self.__class__.__name__} for {self._init_fn} not yet initialized>"
        return repr(self._obj)

    def __str__(self):
        self._initialize_obj()
        return str(self._obj)
|
||||
|
||||
|
||||
def initialize_driver():
    """Pick the driver matching the installed torch build."""
    import torch
    if torch.version.hip is not None:
        return HIPDriver()
    if torch.cuda.is_available():
        return CudaDriver()
    return UnsupportedDriver()


# Materialized on first attribute access, so importing this module stays
# cheap and does not touch the GPU.
driver = LazyProxy(initialize_driver)
|
||||
15
pkgs/triton/runtime/errors.py
Normal file
15
pkgs/triton/runtime/errors.py
Normal file
@@ -0,0 +1,15 @@
|
||||
|
||||
class OutOfResources(Exception):
    """Raised when a kernel requires more of a hardware resource than exists."""

    def __init__(self, required, limit, name):
        self.message = (
            f'out of resource: {name}, '
            f'Required: {required}, '
            f'Hardware limit: {limit}'
            '. Reducing block sizes or `num_stages` may help.'
        )
        self.required = required
        self.limit = limit
        self.name = name
        super().__init__(self.message)

    def __reduce__(self):
        # this is necessary to make CompilationError picklable
        return (type(self), (self.required, self.limit, self.name))
|
||||
573
pkgs/triton/runtime/jit.py
Normal file
573
pkgs/triton/runtime/jit.py
Normal file
@@ -0,0 +1,573 @@
|
||||
from __future__ import annotations, division
|
||||
|
||||
import ast
|
||||
import functools
|
||||
import hashlib
|
||||
import inspect
|
||||
import os
|
||||
import subprocess
|
||||
import textwrap
|
||||
from collections import defaultdict, namedtuple
|
||||
from typing import Callable, Generic, Iterable, Optional, TypeVar, Union, cast, overload
|
||||
|
||||
import torch
|
||||
import triton
|
||||
|
||||
|
||||
def get_disable_sme():
    """Return "1" when SME must be disabled (env override or ivcore10 GPU)."""
    disable_sme = os.getenv("TRITON_DISABLE_SME", default="0")
    major, minor = torch.cuda.get_device_capability()
    if major * 10 + minor == 70:  # for ivcore10
        disable_sme = "1"
    return disable_sme
|
||||
|
||||
|
||||
def get_corex_sme(args, tl_args, enable_sme=True):
    """Compute a bitmask of launch arguments eligible for SME (corex builds).

    Bit *k* of the result is set when the k-th non-skipped argument is a
    2D+ tensor of a supported dtype whose trailing (or leading, if
    non-contiguous) dimension is a multiple of the 64-byte SME granule.
    Returns 0 when SME is disabled or torch is not a corex build.

    args: ordered {arg_name: value} mapping of all launch arguments.
    tl_args: indices of constexpr arguments to skip.
    """
    can_use_sme = 0
    if not enable_sme:
        return can_use_sme
    import torch
    # Idiom fix: single getattr instead of hasattr + attribute access; the
    # `== True` comparison is kept so non-True truthy values still disable.
    if not (getattr(torch, "corex", False) == True):  # noqa: E712
        return can_use_sme
    if get_disable_sme() == "1":
        return can_use_sme
    index = 0
    for i, arg_name in enumerate(args):
        arg = args.get(arg_name)
        if i in tl_args:
            continue
        # Skip scalars equal to 1 (they specialize away entirely).
        if isinstance(arg, int) and arg == 1:
            continue
        if torch.is_tensor(arg) and arg.dtype in [torch.float16, torch.float32, torch.bfloat16, torch.int8] and arg.dim() >= 2:
            dim_M = arg.shape[-2]
            dim_K = arg.shape[-1]
            # 64 bytes per SME access, expressed in elements of this dtype.
            sme_dim = 64 / arg.element_size()
            if (arg.is_contiguous() and dim_K % sme_dim == 0) or \
                    (not arg.is_contiguous() and dim_M % sme_dim == 0):
                can_use_sme = (1 << index) | can_use_sme
        # Index counts every non-skipped argument, tensor or not.
        index += 1
    return can_use_sme
|
||||
|
||||
|
||||
def get_cuda_stream(idx=None):
    """Return the raw CUDA stream handle for device *idx* (default: current)."""
    if idx is None:
        idx = get_current_device()
    try:
        # Fast path: torch's private accessor returns the raw handle directly.
        from torch._C import _cuda_getCurrentRawStream
        return _cuda_getCurrentRawStream(idx)
    except ImportError:
        # Fallback for torch builds without the private accessor.
        import torch
        return torch.cuda.current_stream(idx).cuda_stream
|
||||
|
||||
|
||||
def get_current_device():
    """Index of the currently selected CUDA device."""
    import torch
    return torch.cuda.current_device()
|
||||
|
||||
|
||||
def set_current_device(idx):
    """Select CUDA device *idx* as the current device."""
    import torch
    torch.cuda.set_device(idx)
|
||||
|
||||
|
||||
def get_device_capability(idx):
    """(major, minor) compute capability of CUDA device *idx*."""
    import torch
    return torch.cuda.get_device_capability(idx)
|
||||
|
||||
|
||||
T = TypeVar('T')
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Dependencies Finder
|
||||
# -----------------------------------------------------------------------------
|
||||
|
||||
|
||||
class DependenciesFinder(ast.NodeVisitor):
    """
    This AST visitor is used to find dependencies of a JITFunction. This can
    be used to invalidate a JITFunction's hash when its source code -- or
    that of its dependencies -- changes.
    """

    def __init__(self, globals, src) -> None:
        super().__init__()
        # Running hash, seeded with the function's own source text.
        self.ret = hashlib.md5(src.encode("utf-8")).hexdigest()
        self.globals = globals

    def visit_Name(self, node):
        # Resolve a bare name against the function's globals (None if local).
        return self.globals.get(node.id, None)

    def visit_Attribute(self, node):
        # Walk an attribute chain (a.b.c) down to its base object.
        lhs = self.visit(node.value)
        while isinstance(lhs, ast.Attribute):
            lhs = self.visit(lhs.value)
        if lhs is None or lhs is triton:
            # Unresolvable, or an attribute of the triton module itself --
            # neither contributes to the dependency hash.
            return None
        return getattr(lhs, node.attr)

    def visit_Call(self, node):
        func = self.visit(node.func)
        if func is None:
            # Locally-defined or unresolvable callee: nothing to hash.
            return
        if inspect.isbuiltin(func):
            return
        if func.__module__ and func.__module__.startswith('triton.'):
            # Triton's own helpers are covered by version_key(), not here.
            return
        assert isinstance(func, JITFunction), f"Function \"{func.__name__}\" is being called from a Triton function but is not a Triton function itself. Decorate it with @triton.jit to fix this"
        if func.hash is None:
            # Recursively hash the callee's own dependencies (memoized on
            # the callee so shared helpers are only walked once).
            tree = ast.parse(func.src)
            finder = DependenciesFinder(func.__globals__, func.src)
            finder.visit(tree)
            func.hash = finder.ret
        noinline = str(getattr(func, 'noinline', False))
        # Fold the callee's hash (and its noinline flag) into our hash.
        self.ret = (self.ret + func.hash + noinline).encode("utf-8")
        self.ret = hashlib.md5(self.ret).hexdigest()
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# JITFunction
|
||||
# -----------------------------------------------------------------------------
|
||||
|
||||
|
||||
@functools.lru_cache()
def version_key():
    """Fingerprint of the Triton installation itself.

    Combines hashes of this file, the compiler/ and language/ submodules,
    the native libtriton binary, and the ptxas version, so kernel caches
    are invalidated whenever any part of the toolchain changes.
    """
    import pkgutil
    contents = []
    # frontend
    with open(__file__, "rb") as f:
        contents += [hashlib.md5(f.read()).hexdigest()]
    # compiler
    compiler_path = os.path.join(*triton.__path__, 'compiler')
    for lib in pkgutil.iter_modules([compiler_path]):
        with open(lib.module_finder.find_spec(lib.name).origin, "rb") as f:
            contents += [hashlib.md5(f.read()).hexdigest()]
    # backend
    with open(triton._C.libtriton.__file__, "rb") as f:
        contents += [hashlib.md5(f.read()).hexdigest()]
    # language
    language_path = os.path.join(*triton.__path__, 'language')
    for lib in pkgutil.iter_modules([language_path]):
        with open(lib.module_finder.find_spec(lib.name).origin, "rb") as f:
            contents += [hashlib.md5(f.read()).hexdigest()]
    # ptxas version
    try:
        ptxas_version = hashlib.md5(subprocess.check_output(["ptxas", "--version"])).hexdigest()
    except Exception:
        # ptxas not on PATH -- fine, the key just omits it.
        ptxas_version = ''
    # NOTE(review): '-'.join over the version *string* interleaves '-'
    # between individual characters; it is deterministic so it works as a
    # cache key, but confirm intent before changing (a change invalidates
    # all existing caches).
    return '-'.join(triton.__version__) + '-' + ptxas_version + '-' + '-'.join(contents)
|
||||
|
||||
|
||||
class KernelInterface(Generic[T]):
    # The launch callable; concrete subclasses (e.g. JITFunction) assign it.
    run: T

    def __getitem__(self, grid) -> T:
        """
        A JIT function is launched with: fn[grid](*args, **kwargs).
        Hence JITFunction.__getitem__ returns a callable proxy that
        memorizes the grid.
        """
        return cast(T, functools.partial(cast(Callable, self.run), grid=grid))
|
||||
|
||||
|
||||
class JITFunction(KernelInterface[T]):
|
||||
|
||||
# Hook for inspecting compiled functions and modules
|
||||
cache_hook = None
|
||||
divisibility = 16
|
||||
|
||||
@staticmethod
|
||||
def _key_of(arg):
|
||||
if hasattr(arg, "dtype"):
|
||||
return arg.dtype
|
||||
elif isinstance(arg, bool):
|
||||
return "i1"
|
||||
elif isinstance(arg, int):
|
||||
if -2**31 <= arg and arg <= 2**31 - 1:
|
||||
return "i32"
|
||||
elif 2**63 <= arg and arg <= 2**64 - 1:
|
||||
return "u64"
|
||||
else:
|
||||
return "i64"
|
||||
elif isinstance(arg, float):
|
||||
return 'fp32'
|
||||
elif arg is None:
|
||||
return None
|
||||
else:
|
||||
raise TypeError(f'Unsupported type {type(arg)} for {arg}')
|
||||
|
||||
@staticmethod
|
||||
def _spec_of(arg):
|
||||
if hasattr(arg, "data_ptr"):
|
||||
return (arg.data_ptr() % JITFunction.divisibility == 0)
|
||||
elif isinstance(arg, int):
|
||||
return (arg % 16 == 0, arg == 1)
|
||||
return (arg is None, )
|
||||
|
||||
    def _get_config(self, *args):
        """Build the specialization descriptor for one concrete call.

        Records which positional args are divisibility-aligned (pointer or
        int) and which ints equal exactly 1, excluding indices listed in
        `do_not_specialize`.
        """
        def is_divisible_by_16(x):
            if hasattr(x, "data_ptr"):
                # Tensor-like: specialize on pointer alignment.
                return x.data_ptr() % JITFunction.divisibility == 0
            elif isinstance(x, int):
                return x % JITFunction.divisibility == 0
            if x is None:
                # None lowers to a nullptr, which is trivially aligned.
                return True
            return False
        divisible_by_16 = {i for i, arg in enumerate(args) if is_divisible_by_16(arg) and i not in self.do_not_specialize}
        # bools are ints in Python; exclude them so True is not treated as 1.
        equal_to_1 = {i for i, arg in enumerate(args) if not isinstance(arg, bool) and isinstance(arg, int) and arg == 1 and i not in self.do_not_specialize}
        return namedtuple("instance_descriptor", ["divisible_by_16", "equal_to_1"])(tuple(divisible_by_16), tuple(equal_to_1))
        # return _triton.code_gen.instance_descriptor(divisible_by_16, equal_to_1)
|
||||
|
||||
    @staticmethod
    def _type_of(key):
        """Map a dtype key (from _key_of) to triton's type-string spelling."""
        # None are nullptr -- implicitly converted to *i8
        if key is None:
            return '*i8'
        dtype_str = str(key).split(".")[-1]
        tys = {
            "bool": "i1",
            "float8e5": "fp8e5",
            "float8e4": "fp8e4",
            "float16": "fp16",
            "bfloat16": "bf16",
            "float32": "fp32",
            "float64": "fp64",
            "int8": "i8",
            "int16": "i16",
            "int32": "i32",
            "int64": "i64",
            "uint8": "u8",
            "uint16": "u16",
            "uint32": "u32",
            "uint64": "u64",
        }
        # reinterpret can create triton type
        for v in list(tys.values()):
            tys[v] = v
        # Strings are already triton type names and pass through unchanged;
        # torch dtypes become pointer types ("*fp16", etc.).
        return key if isinstance(key, str) else f"*{tys[dtype_str]}"
|
||||
|
||||
def _make_signature(self, sig_key):
|
||||
signature = ",".join([self._type_of(k) for i, k in enumerate(sig_key)])
|
||||
return signature
|
||||
|
||||
def _make_constants(self, constexpr_key):
|
||||
constants = dict(zip(self.constexprs, constexpr_key))
|
||||
return constants
|
||||
|
||||
def _call_hook(self, key, signature, device, constants, num_warps, num_stages, extern_libs, configs):
|
||||
if JITFunction.cache_hook is None:
|
||||
return False
|
||||
name = self.fn.__name__
|
||||
module = self.fn.__module__
|
||||
arg_reprs = ', '.join([f'{name}: {ty}' for name, ty in zip(self.arg_names, key[1])])
|
||||
repr = f"{name}[num_warps={num_warps}, num_stages={num_stages}]({arg_reprs})"
|
||||
key = str(key)
|
||||
|
||||
class LegacyCompiler:
|
||||
def __init__(self, module, name):
|
||||
self.module = module
|
||||
self.name = name
|
||||
pass
|
||||
|
||||
kwargs = dict(signature=signature, device=device, constants=constants,
|
||||
num_warps=num_warps, num_stages=num_stages, extern_libs=extern_libs,
|
||||
configs=configs)
|
||||
|
||||
return JITFunction.cache_hook(key=key, repr=repr, fn=LegacyCompiler(module, name), compile={"key": key, **kwargs}, is_manual_warmup=False, already_compiled=False)
|
||||
|
||||
    def _get_arg_specialization_key(self, arg) -> str:
        """Source-text expression computing the specialization key for *arg*.

        The returned string is interpolated into the generated launcher; a
        type annotation lets us emit a cheaper, pre-narrowed expression.
        """
        arg_annotation = self.__annotations__.get(arg, '')
        if arg_annotation == '':
            # No annotation: decide tensor-vs-int at launch time.
            return f'({arg}.data_ptr() % {JITFunction.divisibility} == 0) if hasattr({arg}, "data_ptr") \
                    else ({arg} % {JITFunction.divisibility} == 0, {arg} == 1) if isinstance({arg}, int) \
                    else (False,)'
        elif 'Tensor' in arg_annotation:
            return f'({arg}.data_ptr() % {JITFunction.divisibility} == 0)'
        elif arg_annotation == 'int':
            return f'({arg} % {JITFunction.divisibility} == 0, {arg} == 1)'
        else:
            return '(False,)'
|
||||
|
||||
def _get_arg_sig_key(self, arg) -> str:
|
||||
arg_annotation = self.__annotations__.get(arg, '')
|
||||
if 'Tensor' in arg_annotation:
|
||||
return f'{arg}.dtype'
|
||||
elif arg_annotation == 'bool':
|
||||
return "i1"
|
||||
elif arg_annotation == 'float':
|
||||
return 'fp32'
|
||||
else:
|
||||
return f'_key_of({arg})'
|
||||
|
||||
    def _make_launcher(self):
        """Generate and exec() a specialized launch function for this kernel.

        The generated function computes the cache key (signature, constexpr
        values, specialization, SME mask), launches a cached binary when one
        exists, and otherwise compiles, caches, then launches.
        NOTE(review): the template's internal indentation was reconstructed
        from context; confirm against the generated source at runtime.
        """
        # Partition parameters into runtime arguments and constexprs.
        regular_args = [f'{arg}' for i, arg in enumerate(self.arg_names) if i not in self.constexprs]
        constexpr_args = [f'{arg}' for i, arg in enumerate(self.arg_names) if i in self.constexprs]
        args = ', '.join(regular_args)
        # cache key for regular argument type
        sig_keys = ', '.join([self._get_arg_sig_key(arg) for arg in regular_args])
        # cache key for constexpr argument values
        constexpr_keys = ', '.join(constexpr_args)
        # cache key for argument specialization
        specializations = []
        for i, arg in enumerate(regular_args):
            if i in self.do_not_specialize:
                continue
            specializations += [self._get_arg_specialization_key(arg)]

        spec_keys = ', '.join(specializations)
        grid_args = ','.join([f'"{arg}": {arg}' for arg in self.arg_names])

        # Source template of the launcher; names like version_key, cache and
        # get_corex_sme are resolved through the `scope` dict passed to exec.
        src = f"""
def {self.fn.__name__}({', '.join(self.arg_names)}, grid, num_warps=4, num_stages=3, enable_sme=True, extern_libs=None, stream=None, warmup=False, device=None):
    sig_key = {sig_keys},
    constexpr_key = {f'{constexpr_keys},' if len(constexpr_keys) > 0 else ()}
    spec_key = {f'{spec_keys},' if len(spec_keys) > 0 else ()}
    use_sme = get_corex_sme({{{grid_args}}}, self.constexprs, enable_sme)
    key = (version_key, sig_key, constexpr_key, spec_key, num_warps, num_stages, self.debug, use_sme)
    if not extern_libs is None:
        key = (key, tuple(extern_libs.items()))
    assert num_warps > 0 and (num_warps & (num_warps - 1)) == 0, "num_warps must be a power of 2"
    if callable(grid):
        grid = grid({{{grid_args}}})
    grid_size = len(grid)
    grid_0 = grid[0]
    grid_1 = grid[1] if grid_size > 1 else 1
    grid_2 = grid[2] if grid_size > 2 else 1
    if device is None:
        device = get_current_device()
        set_current_device(device)
    if stream is None and not warmup:
        stream = get_cuda_stream(device)
    bin = cache[device].get(key, None)
    if bin is not None:
        if not warmup:
            bin.c_wrapper(grid_0, grid_1, grid_2, bin.num_warps, bin.shared, stream, bin.cu_function, triton.compiler.CompiledKernel.launch_enter_hook, triton.compiler.CompiledKernel.launch_exit_hook, bin, {args})
        return bin
    # kernel not cached -- compile
    else:
        # build dict of constant values
        args = [{args}]
        all_args = {', '.join([f'{arg}' for arg in self.arg_names])},
        configs = self._get_config(*all_args),
        constants = self._make_constants(constexpr_key)
        constants.update({{i: None for i, arg in enumerate(all_args) if arg is None}})
        constants.update({{i: 1 for i in configs[0].equal_to_1}})
        # build kernel signature -- doesn't include specialized arguments
        signature = {{ i: self._type_of(_key_of(arg)) for i, arg in enumerate(all_args) if i not in self.constexprs }}
        # build stub signature -- includes arguments that are specialized
        for i, arg in constants.items():
            if callable(arg):
                raise TypeError(f"Callable constexpr at index {{i}} is not supported")
        if not self._call_hook(key, signature, device, constants, num_warps, num_stages, extern_libs, configs):
            bin = triton.compile(self, signature=signature, device=device, constants=constants, num_warps=num_warps, num_stages=num_stages, extern_libs=extern_libs, configs=configs, debug=self.debug, use_sme=use_sme)
            if not warmup:
                bin.c_wrapper(grid_0, grid_1, grid_2, bin.num_warps, bin.shared, stream, bin.cu_function, triton.compiler.CompiledKernel.launch_enter_hook, triton.compiler.CompiledKernel.launch_exit_hook, bin, *args)
            self.cache[device][key] = bin
            return bin
        return None
"""
        # Names the generated source resolves at call time.
        scope = {"version_key": version_key(), "get_cuda_stream": get_cuda_stream,
                 "self": self, "_spec_of": self._spec_of, "_key_of": self._key_of,
                 "cache": self.cache, "triton": triton,
                 "get_current_device": get_current_device,
                 "set_current_device": set_current_device,
                 "get_corex_sme": get_corex_sme}
        exec(src, scope)
        return scope[self.fn.__name__]
|
||||
|
||||
def __init__(self, fn, version=None, do_not_specialize=None, debug=None, noinline=None):
    """Wrap *fn* as a lazily-compiled Triton kernel.

    :param fn: the Python function to JIT-compile
    :param version: optional version tag stored on the instance
    :param do_not_specialize: parameter names or indices to exclude from specialization
    :param debug: debug flag (force-enabled by ``TRITON_DEBUG=1``)
    :param noinline: stored for the compiler (usage not visible here)
    """
    self.fn = fn
    self.module = fn.__module__
    self.version = version
    # function signature information
    signature = inspect.signature(fn)
    self.arg_names = [v.name for v in signature.parameters.values()]
    self.has_defaults = any(v.default != inspect._empty for v in signature.parameters.values())
    # specialization hints
    self.do_not_specialize = [] if do_not_specialize is None else do_not_specialize
    # normalize names to parameter *indices* so later lookups are positional
    self.do_not_specialize = {self.arg_names.index(arg) if isinstance(arg, str) else arg for arg in self.do_not_specialize}
    # function source code (without decorators)
    self.src = textwrap.dedent(inspect.getsource(fn))
    self.src = self.src[self.src.find("def"):]
    # cache of just-in-time compiled kernels, keyed per device then per specialization key
    self.cache = defaultdict(dict)
    self.hash = None
    # JITFunction can be instantiated as kernel
    # when called with a grid using __getitem__
    self.kernel_decorators = []
    self.kernel = None
    # environment variable TRITON_DEBUG=1 overrides the constructor argument
    self.debug = True if os.environ.get("TRITON_DEBUG", "0") == "1" else debug
    self.noinline = noinline
    # annotations: normalize type objects to their names
    normalize_ty = lambda ty: ty.__name__ if isinstance(ty, type) else ty
    self.__annotations__ = {name: normalize_ty(ty) for name, ty in fn.__annotations__.items()}
    # index of constexprs among the parameters
    self.constexprs = [self.arg_names.index(name) for name, ty in self.__annotations__.items() if 'constexpr' in ty]
    # launcher: generated source specialized for this signature
    self.run = self._make_launcher()
    # re-use docs / identity of wrapped function so the wrapper is transparent
    self.__doc__ = fn.__doc__
    self.__name__ = fn.__name__
    self.__globals__ = fn.__globals__
    self.__module__ = fn.__module__
||||
|
||||
@property
def cache_key(self):
    """Lazily-computed fingerprint of this function's source, its
    dependencies (via ``DependenciesFinder``) and the Triton version."""
    # TODO : hash should be attribute of `self`
    if self.hash is None:
        # Walk the function AST so the key also reflects anything it references.
        dependencies_finder = DependenciesFinder(globals=self.__globals__, src=self.src)
        dependencies_finder.visit(self.parse())
        self.hash = dependencies_finder.ret + version_key()
    return self.hash
|
||||
|
||||
def warmup(self, *args, **kwargs):
    """Compile the kernel for these argument dtypes without launching it.

    torch dtypes among *args* are swapped for :class:`MockTensor` stand-ins
    so no real device memory is required; ``warmup=True`` makes the launcher
    stop after compilation.
    """
    mocked = [MockTensor.wrap_dtype(arg) for arg in args]
    return self.run(*mocked, **kwargs, warmup=True)
|
||||
|
||||
def parse(self):
    """Parse ``self.src`` and return the resulting :class:`ast.Module`.

    Parsing happens here rather than in the constructor because callers
    (e.g. unit tests) may monkey-patch ``self.src`` after construction.
    The source must contain exactly one function definition.
    """
    module_ast = ast.parse(self.src)
    assert isinstance(module_ast, ast.Module)
    assert len(module_ast.body) == 1
    assert isinstance(module_ast.body[0], ast.FunctionDef)
    return module_ast
|
||||
|
||||
def __call__(self, *args, **kwargs):
    """Reject direct invocation: kernels must be launched via the runtime."""
    message = "Cannot call @triton.jit'd outside of the scope of a kernel"
    raise RuntimeError(message)
|
||||
|
||||
def __setattr__(self, name, value):
    """Attribute hook: invalidate derived caches when key attributes change."""
    # - when kernel decorators change, cached kernel
    #   needs to be cleared
    if name == 'kernel_decorators':
        self.kernel = None
    super(JITFunction, self).__setattr__(name, value)
    # - when `.src` attribute is set, cache path needs
    #   to be reinitialized (cache_key recomputes lazily from self.hash)
    if name == 'src':
        self.hash = None
|
||||
|
||||
def __repr__(self):
    """Concise ``module:function`` identifier of the wrapped function."""
    module_name = self.module
    fn_name = self.fn.__name__
    return f"JITFunction({module_name}:{fn_name})"
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# `jit` decorator
|
||||
# -----------------------------------------------------------------------------
|
||||
|
||||
|
||||
@overload
def jit(fn: T) -> JITFunction[T]:
    # Overload: bare ``@jit`` usage — decorates *fn* directly.
    ...


@overload
def jit(
    *,
    version=None,
    do_not_specialize: Optional[Iterable[int]] = None,
    debug: Optional[bool] = None,
    noinline: Optional[bool] = None,
) -> Callable[[T], JITFunction[T]]:
    # Overload: parameterized ``@jit(...)`` usage — returns a decorator.
    ...
|
||||
|
||||
|
||||
def jit(
    fn: Optional[T] = None,
    *,
    version=None,
    do_not_specialize: Optional[Iterable[int]] = None,
    debug: Optional[bool] = None,
    noinline: Optional[bool] = None,
    interpret: Optional[bool] = None,
) -> Union[JITFunction[T], Callable[[T], JITFunction[T]]]:
    """
    Decorator for JIT-compiling a function using the Triton compiler.

    Supports both bare ``@jit`` and parameterized ``@jit(...)`` usage.

    :note: When a jit'd function is called, arguments are
        implicitly converted to pointers if they have a :code:`.data_ptr()`
        method and a `.dtype` attribute.

    :note: This function will be compiled and run on the GPU. It will only
        have access to:

        * python primitives,
        * builtins within the triton package,
        * arguments to this function,
        * other jit'd functions

    :param fn: the function to be jit-compiled
    :type fn: Callable
    """

    def decorator(fn: T) -> JITFunction[T]:
        assert callable(fn)
        if interpret:
            # Debugger path: run the kernel through the Python interpreter.
            from ..debugger.debugger import GridSelector
            return GridSelector(fn)
        return JITFunction(
            fn,
            version=version,
            do_not_specialize=do_not_specialize,
            debug=debug,
            noinline=noinline,
        )

    # Bare usage decorates immediately; parameterized usage returns the decorator.
    return decorator(fn) if fn is not None else decorator
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Utilities for mocking tensors
|
||||
# -----------------------------------------------------------------------------
|
||||
|
||||
|
||||
class MockTensor:
    """
    Can be used in place of real tensors when calling:
        kernel.warmup(MockTensor(torch.float32), ...)
    """

    @staticmethod
    def wrap_dtype(arg):
        # Duck-typed test (no torch import needed): a ``torch.dtype`` becomes
        # a MockTensor carrying it; everything else passes through unchanged.
        looks_like_torch_dtype = (arg.__class__.__name__ == "dtype"
                                  and arg.__module__ == "torch")
        return MockTensor(arg) if looks_like_torch_dtype else arg

    def __init__(self, dtype):
        # The dtype is the only piece of state the launcher inspects.
        self.dtype = dtype

    @staticmethod
    def data_ptr():
        # Zero pointer: optimistically assumed to be a multiple of 16.
        return 0
|
||||
|
||||
|
||||
class TensorWrapper:
    """Wraps an existing tensor, relabeling its element type as ``dtype``
    while sharing the same underlying storage and device attributes."""

    def __init__(self, base, dtype):
        self.base = base
        self.dtype = dtype
        # Mirror the device-related attributes the launcher reads.
        self.is_cuda = base.is_cuda
        self.device = base.device

    def data_ptr(self):
        # Same memory as the wrapped tensor — only the dtype label differs.
        return self.base.data_ptr()

    def __str__(self) -> str:
        return f'TensorWrapper[{self.dtype}]({self.base})'
|
||||
|
||||
|
||||
def reinterpret(tensor, dtype):
    """Return a view of *tensor* whose element type is relabeled as *dtype*.

    Unwraps when reinterpreting a :class:`TensorWrapper` back to its base
    dtype; otherwise (re-)wraps the underlying storage. Raises
    :class:`TypeError` for objects without a ``data_ptr`` method.
    """
    if isinstance(tensor, TensorWrapper):
        if dtype == tensor.base.dtype:
            # Round-trip back to the original interpretation: unwrap.
            return tensor.base
        # Same storage, different dtype label.
        return TensorWrapper(tensor.base, dtype)
    if hasattr(tensor, "data_ptr"):
        # First wrap around a plain tensor.
        return TensorWrapper(tensor, dtype)
    raise TypeError(f'Cannot reinterpret a {type(tensor)}.')
|
||||
424
pkgs/triton/testing.py
Normal file
424
pkgs/triton/testing.py
Normal file
@@ -0,0 +1,424 @@
|
||||
import functools
|
||||
import os
|
||||
import subprocess
|
||||
import sys
|
||||
from contextlib import contextmanager
|
||||
|
||||
import triton._C.libtriton.triton as _triton
|
||||
from triton.common.build import is_corex
|
||||
|
||||
|
||||
def nvsmi(attrs):
    """Query ``nvidia-smi`` on GPU 0 for *attrs* and return them as ints."""
    query = ','.join(attrs)
    cmd = ['nvidia-smi', '-i', '0', '--query-gpu=' + query, '--format=csv,noheader,nounits']
    raw = subprocess.check_output(cmd)
    fields = raw.decode(sys.stdout.encoding).split(',')
    return [int(field) for field in fields]
|
||||
|
||||
|
||||
def do_bench(fn, warmup=25, rep=100, grad_to_none=None,
             quantiles=None,
             fast_flush=True,
             return_mode="mean"):
    """Benchmark ``fn`` on the current CUDA device; times are in milliseconds."""
    assert return_mode in ["min", "max", "mean", "median"]
    import torch
    """
    Benchmark the runtime of the provided function. By default, return the median runtime of :code:`fn` along with
    the 20-th and 80-th performance percentile.

    :param fn: Function to benchmark
    :type fn: Callable
    :param warmup: Warmup time (in ms)
    :type warmup: int
    :param rep: Repetition time (in ms)
    :type rep: int
    :param grad_to_none: Reset the gradient of the provided tensor to None
    :type grad_to_none: torch.tensor, optional
    :param quantiles: Performance percentile to return in addition to the median.
    :type quantiles: list[float]
    :param fast_flush: Use faster kernel to flush L2 between measurements
    :type fast_flush: bool
    """

    # One un-timed call so compilation/caching does not pollute the estimate.
    fn()
    torch.cuda.synchronize()

    # We maintain a buffer of 256 MB that we clear
    # before each kernel call to make sure that the L2
    # doesn't contain any input data before the run
    if fast_flush:
        cache = torch.empty(int(256e6 // 4), dtype=torch.int, device='cuda')
    else:
        cache = torch.empty(int(256e6), dtype=torch.int8, device='cuda')

    # Estimate the runtime of the function (average over 5 flushed calls)
    start_event = torch.cuda.Event(enable_timing=True)
    end_event = torch.cuda.Event(enable_timing=True)
    start_event.record()
    for _ in range(5):
        cache.zero_()
        fn()
    end_event.record()
    torch.cuda.synchronize()
    estimate_ms = start_event.elapsed_time(end_event) / 5

    # compute number of warmup and repeat iterations from the time budgets
    n_warmup = max(1, int(warmup / estimate_ms))
    n_repeat = max(1, int(rep / estimate_ms))
    start_event = [torch.cuda.Event(enable_timing=True) for i in range(n_repeat)]
    end_event = [torch.cuda.Event(enable_timing=True) for i in range(n_repeat)]
    # Warm-up
    for _ in range(n_warmup):
        fn()
    # Benchmark
    for i in range(n_repeat):
        # we don't want `fn` to accumulate gradient values
        # if it contains a backward pass. So we clear the
        # provided gradients
        if grad_to_none is not None:
            for x in grad_to_none:
                x.grad = None
        # we clear the L2 cache before each run
        cache.zero_()
        # record time of `fn` with per-iteration CUDA events
        start_event[i].record()
        fn()
        end_event[i].record()
    # Record clocks
    torch.cuda.synchronize()
    times = torch.tensor([s.elapsed_time(e) for s, e in zip(start_event, end_event)])
    if quantiles is not None:
        ret = torch.quantile(times, torch.tensor(quantiles)).tolist()
        if len(ret) == 1:
            # single quantile requested: return the scalar, not a 1-element list
            ret = ret[0]
        return ret
    # reduce with the requested statistic (torch.min/max/mean/median)
    return getattr(torch, return_mode)(times).item()
|
||||
|
||||
|
||||
def assert_close(x, y, atol=None, rtol=None, err_msg=''):
    """Assert that *x* and *y* are numerically close.

    Inputs are canonicalized to tensors; *atol*/*rtol* may be callables
    taking the dtype. Raises ``AssertionError`` on mismatch.
    """
    import numpy as np
    import torch

    # canonicalize arguments to be tensors
    if not isinstance(x, torch.Tensor):
        x = torch.tensor(x)
    if not isinstance(y, torch.Tensor):
        y = torch.tensor(y)
    # absolute tolerance (default 1e-2), possibly dtype-dependent
    if atol is None:
        atol = 1e-2
    if callable(atol):
        atol = atol(x.dtype)
    # relative tolerance hook (default 0)
    if rtol is None:
        rtol = 0.
    if callable(rtol):
        rtol = rtol(x.dtype)
    # we use numpy instead of pytorch: it seems more memory efficient
    # (pytorch tends to OOM on large tensors)
    if x.dtype == torch.bfloat16:
        x = x.float()
    x = x.cpu().detach().numpy()
    if y.dtype == torch.bfloat16:
        y = y.float()
    y = y.cpu().detach().numpy()
    # handle size==1 separately so we can give a better error message
    if x.size > 1 or y.size > 1:
        np.testing.assert_allclose(x, y, atol=atol, rtol=rtol, equal_nan=True)
        return
    if not np.allclose(x, y, atol=atol, rtol=rtol):
        raise AssertionError(f'{err_msg} {x} is not close to {y} (atol={atol}, rtol={rtol})')
|
||||
|
||||
|
||||
class Benchmark:
    """
    This class is used by the :code:`perf_report` function to generate line plots with a concise API.
    """

    def __init__(
        self,
        x_names,
        x_vals,
        line_arg,
        line_vals,
        line_names,
        plot_name,
        args,
        xlabel='',
        ylabel='',
        x_log=False,
        y_log=False,
        color=None,
        styles=None,
    ):
        """
        Constructor.

        :param x_names: argument names that appear on the x axis; if several
            are given they all receive the same value.
        :param x_vals: values to sweep for the :code:`x_names` arguments.
        :param line_arg: argument name whose values become separate lines.
        :param line_vals: values of :code:`line_arg`, one per line.
        :param line_names: display labels for the lines.
        :param plot_name: name of the plot (also used for output files).
        :param args: arguments held fixed throughout the benchmark.
        :param xlabel: x-axis label (defaults to the joined x_names).
        :param ylabel: y-axis label.
        :param x_log: use a logarithmic x axis.
        :param y_log: use a logarithmic y axis.
        :param color: accepted for API compatibility; not stored.
        :param styles: optional per-line (color, linestyle) pairs.
        """
        # sweep configuration
        self.x_names = x_names
        self.x_vals = x_vals
        self.line_arg = line_arg
        self.line_vals = line_vals
        self.line_names = line_names
        self.args = args
        # plot info
        self.plot_name = plot_name
        self.xlabel = xlabel
        self.ylabel = ylabel
        self.x_log = x_log
        self.y_log = y_log
        self.styles = styles
|
||||
|
||||
|
||||
class Mark:
    """Binds a benchmarked function to its :class:`Benchmark` configuration(s).

    Instances are produced by the :func:`perf_report` decorator; call
    :meth:`run` to execute the benchmarks and render tables/plots.
    """

    def __init__(self, fn, benchmarks):
        self.fn = fn                  # the function being benchmarked
        self.benchmarks = benchmarks  # a Benchmark or a list of Benchmarks

    def _run(self, bench, save_path, show_plots, print_data):
        """Run one Benchmark config: collect timings into a DataFrame, then plot/print/save as requested."""
        import os

        import matplotlib.pyplot as plt
        import pandas as pd
        # column layout: x value, then per-line mean, then per-line min/max bands
        y_mean = bench.line_names
        y_min = [f'{x}-min' for x in bench.line_names]
        y_max = [f'{x}-max' for x in bench.line_names]
        df = pd.DataFrame(columns=[bench.x_names[0]] + y_mean + y_min + y_max)
        for x in bench.x_vals:
            # every name in x_names receives the same value x
            x_args = {x_name: x for x_name in bench.x_names}
            row_mean, row_min, row_max = [], [], []
            for y in bench.line_vals:
                ret = self.fn(**x_args, **{bench.line_arg: y}, **bench.args)
                try:
                    # the benchmarked fn may return a (mean, min, max) triple...
                    y_mean, y_min, y_max = ret
                except TypeError:
                    # ...or a single scalar, in which case there is no error band
                    y_mean, y_min, y_max = ret, None, None
                row_mean += [y_mean]
                row_min += [y_min]
                row_max += [y_max]
            df.loc[len(df)] = [x] + row_mean + row_min + row_max
        if bench.plot_name:
            plt.figure()
            ax = plt.subplot()
            x = bench.x_names[0]
            for i, y in enumerate(bench.line_names):
                y_min, y_max = df[y + '-min'], df[y + '-max']
                # styles, when given, carry (color, linestyle) per line
                col = bench.styles[i][0] if bench.styles else None
                sty = bench.styles[i][1] if bench.styles else None
                ax.plot(df[x], df[y], label=y, color=col, ls=sty)
                if y_min is not None and y_max is not None:
                    # shaded band between the min and max timings
                    ax.fill_between(df[x], y_min, y_max, alpha=0.15, color=col)
            ax.legend()
            xlabel = bench.xlabel if bench.xlabel else " = ".join(bench.x_names)
            ax.set_xlabel(xlabel)
            ax.set_ylabel(bench.ylabel)
            # ax.set_title(bench.plot_name)
            ax.set_xscale("log" if bench.x_log else "linear")
            ax.set_yscale("log" if bench.y_log else "linear")
            if show_plots:
                plt.show()
            if save_path:
                plt.savefig(os.path.join(save_path, f"{bench.plot_name}.png"))
        # keep only the x column and the mean columns for printing/CSV
        df = df[[bench.x_names[0]] + bench.line_names]
        if print_data:
            print(bench.plot_name + ':')
            print(df)
        if save_path:
            df.to_csv(os.path.join(save_path, f"{bench.plot_name}.csv"), float_format='%.1f', index=False)

    def run(self, show_plots=False, print_data=False, save_path=''):
        """Run every configured Benchmark; with save_path, also write results.html linking the plot images."""
        has_single_bench = isinstance(self.benchmarks, Benchmark)
        benchmarks = [self.benchmarks] if has_single_bench else self.benchmarks
        if save_path:
            # NOTE(review): this file handle is never explicitly closed/flushed — verify.
            html = open(os.path.join(save_path, "results.html"), "w")
            html.write("<html><body>\n")
        for bench in benchmarks:
            self._run(bench, save_path, show_plots, print_data)
            if save_path:
                html.write(f"<image src=\"{bench.plot_name}.png\"/>\n")
        if save_path:
            html.write("</body></html>\n")
|
||||
|
||||
def perf_report(benchmarks):
    """
    Mark a function for benchmarking. The benchmark can then be executed by using the :code:`.run` method on the return value.

    :param benchmarks: Benchmarking configurations.
    :type benchmarks: List of :class:`Benchmark`
    """
    def wrapper(fn):
        return Mark(fn, benchmarks)
    return wrapper
|
||||
|
||||
|
||||
def get_dram_gbps(backend=None, device=None):
    """Return the theoretical peak DRAM bandwidth of *device* in GB/s."""
    import torch

    from .runtime import driver
    if not backend:
        backend = _triton.runtime.backend.CUDA
    if not device:
        device = torch.cuda.current_device()
    mem_clock_khz = driver.utils.get_device_properties(device)["mem_clock_rate"]  # in kHz
    bus_width = driver.utils.get_device_properties(device)["mem_bus_width"]  # in bits
    # kHz * bits * 2 (DDR: two transfers per clock) -> GB/s (/1e6 kHz->GHz-ish scale, /8 bits->bytes)
    bw_gbps = mem_clock_khz * bus_width * 2 / 1e6 / 8  # In GB/s
    return bw_gbps
|
||||
|
||||
|
||||
def get_max_tensorcore_tflops(dtype, backend=None, device=None, clock_rate=None):
    """Return the theoretical peak tensor-core throughput (TFLOPS) of *device* for *dtype*."""
    import torch

    from .runtime import driver
    if not backend:
        backend = _triton.runtime.backend.CUDA
    if not device:
        device = torch.cuda.current_device()

    # 4 sub-cores (tensor-core units) per multiprocessor
    num_subcores = driver.utils.get_device_properties(device)["multiprocessor_count"] * 4
    if not clock_rate:
        clock_rate = driver.utils.get_device_properties(device)["sm_clock_rate"]  # in kHz
    capability = torch.cuda.get_device_capability(device)
    if capability[0] < 8:
        # pre-Ampere tensor cores only handle fp16 (corex devices are exempted)
        assert dtype == torch.float16 or is_corex()
        ops_per_sub_core = 256  # 2 4x4x4 Tensor Cores
    else:
        # Ampere and newer: throughput scales with element width
        if dtype == torch.float32:
            ops_per_sub_core = 256
        elif dtype in [torch.float16, torch.bfloat16]:
            ops_per_sub_core = 512
        elif dtype == torch.int8:
            ops_per_sub_core = 1024
        else:
            raise RuntimeError("dtype not supported")
    # kHz * ops -> TFLOPS (1e-9 = 1e3 kHz->Hz over 1e12 FLOPS->TFLOPS)
    tflops = num_subcores * clock_rate * ops_per_sub_core * 1e-9
    return tflops
|
||||
|
||||
# create decorator that wraps test function into
|
||||
# a cuda-memcheck system call
|
||||
|
||||
|
||||
def cuda_memcheck(**target_kwargs):
    """Decorator factory: re-run the wrapped pytest test under ``cuda-memcheck``
    when the test's kwargs include all of *target_kwargs*; otherwise run it
    directly in-process."""
    def decorator(test_fn):
        @functools.wraps(test_fn)
        def wrapper(*args, **kwargs):
            import psutil
            ppid_name = psutil.Process(os.getppid()).name()
            run_cuda_memcheck = target_kwargs.items() <= kwargs.items()
            # skip spawning when the parent IS cuda-memcheck: that is the
            # re-entered invocation, which must execute the test body itself
            if run_cuda_memcheck and ppid_name != "cuda-memcheck":
                path = os.path.realpath(test_fn.__globals__["__file__"])
                # get path of current file
                env = {"PATH": os.environ["PATH"], "PYTORCH_NO_CUDA_MEMORY_CACHING": "1"}
                assert 'request' in kwargs, "memcheck'ed test must have a (possibly unused) `request` fixture"
                test_id = kwargs['request'].node.callspec.id
                # rebuild the exact pytest node id for the subprocess run
                cmd = f"{path}::{test_fn.__name__}[{test_id}]"
                out = subprocess.run(["cuda-memcheck", "pytest", "-vs", cmd], capture_output=True, env=env)
                assert out.returncode == 0, "cuda-memcheck returned an error: bounds checking failed"
                assert "ERROR SUMMARY: 0 errors" in str(out.stdout)
            else:
                test_fn(*args, **kwargs)
        return wrapper
    return decorator
|
||||
|
||||
|
||||
def nvsmi_attr(attrs):
    """Query ``nvidia-smi`` on GPU 0 for *attrs* and return them as ints.

    This was a byte-for-byte duplicate of :func:`nvsmi` (same command, same
    parsing); it now simply delegates to it, keeping the public name for
    existing callers.
    """
    return nvsmi(attrs)
|
||||
|
||||
|
||||
@contextmanager
def set_gpu_clock(ref_sm_clock=1350, ref_mem_clock=1215):
    """Context manager: lock GPU 0's SM and memory clocks to the given MHz
    values, yield the corresponding theoretical peaks ``(tflops, gbps)``,
    and always restore default clock management on exit.

    :param ref_sm_clock: target SM clock in MHz
    :param ref_mem_clock: target memory clock in MHz
    :raises AssertionError: if the clocks did not settle within 10 MHz
    """
    try:
        # enable persistence mode so the locked clocks stick
        subprocess.check_output(["nvidia-smi", "-i", "0", "-pm", "1"])
        subprocess.check_output(
            [
                "nvidia-smi",
                "-i",
                "0",
                f"--lock-gpu-clocks={ref_sm_clock},{ref_sm_clock}",
            ]
        )
        subprocess.check_output(
            [
                "nvidia-smi",
                "-i",
                "0",
                f"--lock-memory-clocks={ref_mem_clock},{ref_mem_clock}",
            ]
        )
        cur_sm_clock = nvsmi_attr(["clocks.current.sm"])[0]
        cur_mem_clock = nvsmi_attr(["clocks.current.memory"])[0]
        assert abs(cur_sm_clock - ref_sm_clock) < 10, f"GPU SMs must run at {ref_sm_clock} MHz"
        # bug fix: this checks the *memory* clock — the old message said "GPU SMs"
        assert abs(cur_mem_clock - ref_mem_clock) < 10, f"GPU memory must run at {ref_mem_clock} MHz"
        # peak formulas hard-code 108 SMs / 640-bit-scale bus — presumably A100-class; TODO confirm
        tflops = 1e-6 * 2 * 108 * 4 * 256 * ref_sm_clock
        gbps = 640 * 2 * ref_mem_clock * 1e-3
        yield tflops, gbps
    finally:
        # restore defaults even if the body raised
        subprocess.check_output(["nvidia-smi", "-i", "0", "-pm", "0"])
        subprocess.check_output(["nvidia-smi", "-i", "0", "-rgc"])
        subprocess.check_output(["nvidia-smi", "-i", "0", "-rmc"])
|
||||
|
||||
|
||||
def get_max_simd_tflops(dtype, backend=None, device=None):
    """Return the theoretical peak SIMD (non-tensor-core) throughput in TFLOPS."""
    import torch

    from .runtime import driver
    if not backend:
        backend = _triton.runtime.backend.CUDA
    if not device:
        device = torch.cuda.current_device()

    # 4 sub-cores per multiprocessor
    num_subcores = driver.utils.get_device_properties(device)["multiprocessor_count"] * 4
    clock_rate = driver.utils.get_device_properties(device)["sm_clock_rate"]  # in kHz
    capability = torch.cuda.get_device_capability()
    if capability[0] < 8:
        if dtype == torch.float32:
            ops_per_sub_core = 32  # 2*16
        elif dtype == torch.float16:
            ops_per_sub_core = 64
        else:
            raise RuntimeError("dtype not supported")
    else:
        # Ampere+: same fp32 rate; bfloat16 additionally supported at the fp16 rate
        if dtype == torch.float32:
            ops_per_sub_core = 32
        elif dtype in [torch.float16, torch.bfloat16]:
            ops_per_sub_core = 64
        else:
            raise RuntimeError("dtype not supported")
    # kHz * ops -> TFLOPS (1e-9 = 1e3 kHz->Hz over 1e12 FLOPS->TFLOPS)
    tflops = num_subcores * clock_rate * ops_per_sub_core * 1e-9
    return tflops
|
||||
BIN
pkgs/triton/third_party/cuda/bin/ptxas
vendored
Executable file
BIN
pkgs/triton/third_party/cuda/bin/ptxas
vendored
Executable file
Binary file not shown.
19348
pkgs/triton/third_party/cuda/include/cuda.h
vendored
Executable file
19348
pkgs/triton/third_party/cuda/include/cuda.h
vendored
Executable file
File diff suppressed because it is too large
Load Diff
BIN
pkgs/triton/third_party/cuda/lib/libdevice.10.bc
vendored
Executable file
BIN
pkgs/triton/third_party/cuda/lib/libdevice.10.bc
vendored
Executable file
Binary file not shown.
BIN
pkgs/triton/third_party/rocm/lib/bitcode/asanrtl.bc
vendored
Normal file
BIN
pkgs/triton/third_party/rocm/lib/bitcode/asanrtl.bc
vendored
Normal file
Binary file not shown.
BIN
pkgs/triton/third_party/rocm/lib/bitcode/cuda2gcn.bc
vendored
Executable file
BIN
pkgs/triton/third_party/rocm/lib/bitcode/cuda2gcn.bc
vendored
Executable file
Binary file not shown.
24
pkgs/triton/third_party/rocm/lib/bitcode/cuda2gcn.patch
vendored
Normal file
24
pkgs/triton/third_party/rocm/lib/bitcode/cuda2gcn.patch
vendored
Normal file
@@ -0,0 +1,24 @@
|
||||
diff --git a/CMakeLists.txt b/CMakeLists.txt
|
||||
index b65f1b5..19cc5a9 100644
|
||||
--- a/CMakeLists.txt
|
||||
+++ b/CMakeLists.txt
|
||||
@@ -62,12 +62,13 @@ include(OCL)
|
||||
set(AMDGCN_LIB_LIST)
|
||||
set(AMDGCN_DEP_LIST)
|
||||
add_subdirectory(irif)
|
||||
-add_subdirectory(oclc)
|
||||
-add_subdirectory(ocml)
|
||||
-add_subdirectory(ockl)
|
||||
-add_subdirectory(opencl)
|
||||
-add_subdirectory(hip)
|
||||
-add_subdirectory(asanrtl)
|
||||
+#add_subdirectory(oclc)
|
||||
+#add_subdirectory(ocml)
|
||||
+#add_subdirectory(ockl)
|
||||
+#add_subdirectory(opencl)
|
||||
+#add_subdirectory(hip)
|
||||
+#add_subdirectory(asanrtl)
|
||||
+add_subdirectory(cuda2gcn)
|
||||
|
||||
enable_testing()
|
||||
add_subdirectory(test/compile)
|
||||
BIN
pkgs/triton/third_party/rocm/lib/bitcode/hip.bc
vendored
Normal file
BIN
pkgs/triton/third_party/rocm/lib/bitcode/hip.bc
vendored
Normal file
Binary file not shown.
17
pkgs/triton/third_party/rocm/lib/bitcode/make_cuda2gcn.sh
vendored
Executable file
17
pkgs/triton/third_party/rocm/lib/bitcode/make_cuda2gcn.sh
vendored
Executable file
@@ -0,0 +1,17 @@
|
||||
#!/bin/bash
# Build cuda2gcn.bc: clone the ROCm device-library sources, apply the local
# patch (which disables the stock sub-libraries and enables cuda2gcn), build
# with the Triton-bundled LLVM, then copy the resulting bitcode here.

set -e

pushd .

git clone https://github.com/dfukalov/ROCm-Device-Libs.git
cd ROCm-Device-Libs
# cuda2gcn.patch lives next to this script (see sibling file)
git apply ../cuda2gcn.patch
mkdir build
cd build
cmake .. -DCMAKE_PREFIX_PATH=$HOME/.triton/llvm/clang+llvm-14.0.0-x86_64-linux-gnu-ubuntu-18.04
make -j4

popd
# keep only the built bitcode; discard the source tree
cp ROCm-Device-Libs/build/amdgcn/bitcode/cuda2gcn.bc .
rm -rf ROCm-Device-Libs
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user