Improve benchmark scripts & add more models (#484)
This commit is contained in:
@@ -11,13 +11,13 @@ from io import BytesIO
|
||||
from typing import List, Optional
|
||||
|
||||
import numpy as np
|
||||
import pydantic
|
||||
import requests
|
||||
import torch
|
||||
import triton
|
||||
from fastapi.responses import JSONResponse
|
||||
from packaging import version as pkg_version
|
||||
from pydantic import BaseModel
|
||||
from starlette.middleware.base import BaseHTTPMiddleware
|
||||
import torch.distributed as dist
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -178,7 +178,8 @@ def get_int_token_logit_bias(tokenizer, vocab_size):
|
||||
|
||||
def wrap_kernel_launcher(kernel):
|
||||
"""A faster launcher for triton kernels."""
|
||||
import torch.distributed as dist
|
||||
if int(triton.__version__.split(".")[0]) >= 3:
|
||||
return None
|
||||
|
||||
if dist.is_initialized():
|
||||
rank = dist.get_rank()
|
||||
|
||||
Reference in New Issue
Block a user