Compare commits

10 Commits

9d18371bb7...main

| Author | SHA1 | Date |
|---|---|---|
|  | ab292e63f6 |  |
|  | 70f63772f7 |  |
|  | 99ffbdb4d5 |  |
|  | 90b80e3bcb |  |
|  | 32ad8fb98f |  |
|  | 3c69575c72 |  |
|  | 1b78ebefdd |  |
|  | 56f0b5b81d |  |
|  | a575a38552 |  |
|  | 8bc7005d63 |  |
BIN 026_0010.jpg (Executable file)
Binary file not shown.
After Width: | Height: | Size: 26 KiB
49 Dockerfile
@@ -1,49 +0,0 @@
FROM harbor.4pd.io/inf/base-python3.8-ubuntu:1.1.0
MAINTAINER shiguangchuan@4paradigm.com

WORKDIR /workspace

COPY ssh-keygen /bin

RUN wget -q ftp://ftp.4pd.io/pub/pico/temp/pynini-2.1.6-cp38-cp38-manylinux_2_31_x86_64.whl && pip install pynini-2.1.6-cp38-cp38-manylinux_2_31_x86_64.whl && rm -f pynini-2.1.6-cp38-cp38-manylinux_2_31_x86_64.whl

ADD ./requirements.txt /workspace
RUN pip install -r ./requirements.txt -i https://nexus.4pd.io/repository/pypi-all/simple --trusted-host nexus.4pd.io --extra-index-url https://mirrors.aliyun.com/pypi/simple/ \
    && pip cache purge \
    && ssh-keygen -f /workspace/ssh-key-ecdsa -t ecdsa -b 521 -q -N ""

ADD . /workspace

EXPOSE 80

CMD ["python3", "run_callback.py"]

###########################
## Dockerfile (updated)
#FROM harbor.4pd.io/lab-platform/inf/python:3.9

#WORKDIR /app

## Install dependencies
##RUN pip install torch librosa flask

##RUN pip config set global.index-url https://mirrors.aliyun.com/pypi/simple/ && \
##    pip cache purge && \
##    pip --default-timeout=1000 install torch librosa flask

## Removed the original COPY pytorch_model.bin /app/

#COPY inference.py /app/
# Only the startup script needs to be copied

#EXPOSE 80

#CMD ["python", "inference.py"]
####################

##############################Update 0731#################################
@@ -1,5 +1,5 @@
-FROM zibo.harbor.iluvatar.com.cn:30000/saas/bi100-3.2.1-x86-ubuntu20.04-py3.10-poc-llm-infer:v1.2.2
+FROM git.modelhub.org.cn:9443/enginex-iluvatar/bi100-3.2.1-x86-ubuntu20.04-py3.10-poc-llm-infer:v1.2.2

WORKDIR /workspace/
COPY ./model_test_caltech_http.py /workspace/
65 README.md
@@ -1,2 +1,67 @@
# image-classification-transformers

## Iluvatar CoreX Tiangai-100 (天垓100) vision classification

The transformers framework supports many image classification models. This project adapts the transformers framework to the Tiangai-100 accelerator and brings it into the domestic-compute (信创) benchmark framework: vision classification models are run and benchmarked on the Iluvatar Tiangai-100 card. Note that models used with this test framework must be compatible with the transformers library.

## Quick Start

1. First, download a vision classification model from ModelScope, e.g. microsoft/beit-base-patch16-224:

```bash
modelscope download --model microsoft/beit-base-patch16-224 README.md --local_dir /mnt/contest_ceph/zhoushasha/models/microsoft/beit_base_patch16_224_pt22k_ft22k
```

2. Build the image with the Dockerfile.

Download the base image bi100-3.2.1-x86-ubuntu20.04-py3.10-poc-llm-infer:v1.2.2 from the repository's "Packages" (软件包) section.

Use Dockerfile_bi100 to build an image, e.g. bi100-3.2.1-x86-ubuntu20.04-py3.10-poc-llm-infer:test (a build command is sketched below).

Note that Dockerfile_bi100 already places the model microsoft_beit_base_patch16_224_pt22k_ft22k under /model.
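
For reference, the build step might look like this (the tag comes from the example above; run it from the directory containing Dockerfile_bi100):

```bash
docker build -f Dockerfile_bi100 -t bi100-3.2.1-x86-ubuntu20.04-py3.10-poc-llm-infer:test .
```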

3. Start the Docker container:
```bash
docker run -it --rm \
    -p 10086:80 \
    --name test_zss \
    -v /mnt/contest_ceph/zhoushasha/models/image_models/microsoft_beit_base_patch16_224_pt22k_ft22k:/model:rw \
    --privileged bi100-3.2.1-x86-ubuntu20.04-py3.10-poc-llm-infer:test
```

Here /mnt/contest_ceph/zhoushasha/models/image_models/microsoft_beit_base_patch16_224_pt22k_ft22k is the actual path of your model files.

4. Test the service:

```bash
curl -X POST http://localhost:10086/v1/private/s782b4996 \
    -F "image=@/home/zhoushasha/models/026_0010.jpg"
```

## How the vision classification test service works

The service uses the utility classes AutoImageProcessor and AutoModelForImageClassification from the Hugging Face transformers library.

AutoImageProcessor automatically loads the image processor that matches a pretrained model. It is bound to the model: loading it via from_pretrained(model_path) reads the preprocessing configuration used when the model was trained (input size, normalization parameters, and so on), and it handles the image preprocessing (resizing, normalization, etc.).

AutoModelForImageClassification is an "auto model" class that loads the network architecture matching the pretrained model's type (ViT, ResNet, etc.). AutoModelForImageClassification.from_pretrained(model_path) loads the pretrained image classification model from model_path; it must receive the tensors produced by AutoImageProcessor as input. It performs the core classification computation: given the preprocessed tensors, it outputs the classification result (e.g. class probabilities).
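A minimal sketch of this flow, assuming the model directory is mounted at /model as in the docker run step above:

```python
import torch
from PIL import Image
from transformers import AutoImageProcessor, AutoModelForImageClassification

model_path = "/model"  # assumed local model directory

# from_pretrained reads the preprocessing config saved with the checkpoint
processor = AutoImageProcessor.from_pretrained(model_path)
model = AutoModelForImageClassification.from_pretrained(model_path)
model.eval()

image = Image.open("026_0010.jpg").convert("RGB")  # sample image from this repo
inputs = processor(images=image, return_tensors="pt")  # resize + normalize

with torch.no_grad():
    logits = model(**inputs).logits

# The class with the highest probability is the final prediction
predicted = logits.argmax(-1).item()
print(model.config.id2label[predicted])
```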
## How to use the vision classification test framework

The code implements a vision classification HTTP service that receives an image and returns the class with the highest probability as the final result. The service is packaged into a Docker image on top of the zibo.harbor.iluvatar.com.cn:30000/saas/bi100-3.2.1-x86-ubuntu20.04-py3.10-poc-llm-infer:v1.2.2 base image, and the sut container in the k8s cluster calls this HTTP service (a sketch of such a service is given after the list below).

The vision classification model types already tested with this framework are:

1. Convolutional neural network (CNN) models: ResNet
2. Transformer models: ViT (Vision Transformer), Swin Transformer, DeiT (Data-efficient Image Transformers), BEiT (BERT Pre-training of Image Transformers)
3. Lightweight models: the MobileNet family
4. Other special designs: ConvNeXt
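
A hypothetical sketch of such a service, wrapping the classification flow above in a minimal Flask endpoint. The route and form field mirror the curl example; everything else is illustrative and is not the repository's actual model_test_caltech_http.py:

```python
import io

import torch
from flask import Flask, jsonify, request
from PIL import Image
from transformers import AutoImageProcessor, AutoModelForImageClassification

MODEL_PATH = "/model"  # mount point used in the docker run example

app = Flask(__name__)
processor = AutoImageProcessor.from_pretrained(MODEL_PATH)
model = AutoModelForImageClassification.from_pretrained(MODEL_PATH).eval()


@app.route("/v1/private/s782b4996", methods=["POST"])
def classify():
    # Read the uploaded form file, preprocess it, and run the model
    image = Image.open(io.BytesIO(request.files["image"].read())).convert("RGB")
    inputs = processor(images=image, return_tensors="pt")
    with torch.no_grad():
        logits = model(**inputs).logits
    idx = logits.argmax(-1).item()
    # Return the highest-probability class as the final result
    return jsonify({"label": model.config.id2label[idx]})


if __name__ == "__main__":
    app.run(host="0.0.0.0", port=80)  # the container exposes port 80
```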
## Tiangai-100 vision classification adaptation status

| Model URL | Type | Status | Tiangai-100 accuracy | Tiangai-100 throughput (img/s) | CPU accuracy | CPU throughput (4C) (img/s) | Submit Id |
| ---- | ---- | ---- | ---- | ---- | ---- | ---- | ---- |
| https://www.modelscope.cn/models/apple/mobilevit-x-small | MobileViT | Succeeded | 22.6667% | 31.6415 | 22.6667% | 2.6574 | 249973 |
| https://www.modelscope.cn/models/facebook/convnextv2-tiny-22k-384 | ConvNeXt V2 (improved ConvNeXt) | Succeeded | 29.3333% | 25.1330 | 29.3333% | 0.7301 | 249985 |
| https://www.modelscope.cn/models/google/vit-base-patch16-224 | ViT (Vision Transformer) | Succeeded | 29.3333% | 40.0226 | 29.3333% | 1.1306 | 249992 |
| https://www.modelscope.cn/models/microsoft/beit-base-patch16-224-pt22k-ft22k | BEiT (BERT Pre-training of Image Transformers) | Succeeded | 34.0000% | 23.7485 | 34.0000% | 0.9773 | 249537 |
| https://www.modelscope.cn/models/microsoft/swinv2-tiny-patch4-window16-256 | Swin Transformer V2 (based on Swin Transformer) | Succeeded | 29.3333% | 13.8379 | 29.3333% | 1.0331 | 249557 |
| https://www.modelscope.cn/models/facebook/deit-small-patch16-224 | DeiT (Data-efficient Image Transformer, from Facebook AI) | Succeeded | 29.3333% | 40.5675 | 29.3333% | 3.2749 | 250034 |
| https://www.modelscope.cn/models/microsoft/dit-base-finetuned-rvlcdip | DiT (Document Image Transformer) | Succeeded | 0.0000% | 35.5122 | 0.0000% | 1.0823 | 250035 |
| https://www.modelscope.cn/models/microsoft/cvt-13 | CvT (Convolutional Vision Transformer) | Succeeded | 29.3333% | 27.1214 | 29.3333% | 1.7240 | 250039 |
| https://www.modelscope.cn/models/google/efficientnet-b7 | EfficientNet (CNN-based) | Succeeded | 28.6667% | 10.0449 | 28.6667% | 0.1541 | 250042 |
| https://www.modelscope.cn/models/microsoft/resnet-18 | ResNet (Residual Network) | Succeeded | 22.6667% | 43.5976 | 22.6667% | 7.3915 | 250047 |
BIN helm-chart/.DS_Store (vendored)
Binary file not shown.
@@ -1,77 +0,0 @@
## Requirements for the judgeflow chart

### values.yaml must contain the following fields, and the templates must reference them

```
podLabels
env
volumeMounts
volumes
affinity
```

### values.yaml must declare the following volumes in volumeMounts

```
workspace
submit
datafile
```

## Requirements for the system-under-test (sut) chart

### values.yaml must contain the following fields, and the resource templates must reference them

```
podLabels
affinity
```

For the podLabels field, the format in values.yaml is:

```
podLabels: {}
```

Examples are given below.

podLabels

values.yaml

templates/deployment.yaml

```
metadata:
  labels:
    {{- with .Values.podLabels }}
    {{- toYaml . | nindent 4 }}
    {{- end }}
```

affinity

values.yaml

```
affinity: {}
```

templates/deployment.yaml

```
spec:
  template:
    spec:
      {{- with .Values.affinity }}
      affinity:
        {{- toYaml . | nindent 8 }}
      {{- end }}
```

### If the sut needs shared storage, the sut chart's values.yaml must also contain the following fields, and the templates must reference them

```
volumeMounts
volumes
```
@@ -1,23 +0,0 @@
# Patterns to ignore when building packages.
# This supports shell glob matching, relative path matching, and
# negation (prefixed with !). Only one pattern per line.
.DS_Store
# Common VCS dirs
.git/
.gitignore
.bzr/
.bzrignore
.hg/
.hgignore
.svn/
# Common backup files
*.swp
*.bak
*.tmp
*.orig
*~
# Various IDEs
.project
.idea/
*.tmproj
.vscode/
@@ -1,24 +0,0 @@
apiVersion: v2
name: ${chartName}
description: Leaderboard judgeflow helm chart for demo

# A chart can be either an 'application' or a 'library' chart.
#
# Application charts are a collection of templates that can be packaged into versioned archives
# to be deployed.
#
# Library charts provide useful utilities or functions for the chart developer. They're included as
# a dependency of application charts to inject those utilities and functions into the rendering
# pipeline. Library charts do not define any templates and therefore cannot be deployed.
type: application

# This is the chart version. This version number should be incremented each time you make changes
# to the chart and its templates, including the app version.
# Versions are expected to follow Semantic Versioning (https://semver.org/)
version: ${version}

# This is the version number of the application being deployed. This version number should be
# incremented each time you make changes to the application. Versions are not expected to
# follow Semantic Versioning. They should reflect the version the application is using.
# It is recommended to use it with quotes.
appVersion: "${appVersion}"
@@ -1,62 +0,0 @@
{{/*
Expand the name of the chart.
*/}}
{{- define "judgeflow.name" -}}
{{- default .Chart.Name .Values.nameOverride | trunc 63 | trimSuffix "-" }}
{{- end }}

{{/*
Create a default fully qualified app name.
We truncate at 63 chars because some Kubernetes name fields are limited to this (by the DNS naming spec).
If release name contains chart name it will be used as a full name.
*/}}
{{- define "judgeflow.fullname" -}}
{{- if .Values.fullnameOverride }}
{{- .Values.fullnameOverride | trunc 63 | trimSuffix "-" }}
{{- else }}
{{- $name := default .Chart.Name .Values.nameOverride }}
{{- if contains $name .Release.Name }}
{{- .Release.Name | trunc 63 | trimSuffix "-" }}
{{- else }}
{{- printf "%s-%s" .Release.Name $name | trunc 63 | trimSuffix "-" }}
{{- end }}
{{- end }}
{{- end }}

{{/*
Create chart name and version as used by the chart label.
*/}}
{{- define "judgeflow.chart" -}}
{{- printf "%s-%s" .Chart.Name .Chart.Version | replace "+" "_" | trunc 63 | trimSuffix "-" }}
{{- end }}

{{/*
Common labels
*/}}
{{- define "judgeflow.labels" -}}
helm.sh/chart: {{ include "judgeflow.chart" . }}
{{ include "judgeflow.selectorLabels" . }}
{{- if .Chart.AppVersion }}
app.kubernetes.io/version: {{ .Chart.AppVersion | quote }}
{{- end }}
app.kubernetes.io/managed-by: {{ .Release.Service }}
{{- end }}

{{/*
Selector labels
*/}}
{{- define "judgeflow.selectorLabels" -}}
app.kubernetes.io/name: {{ include "judgeflow.name" . }}
app.kubernetes.io/instance: {{ .Release.Name }}
{{- end }}

{{/*
Create the name of the service account to use
*/}}
{{- define "judgeflow.serviceAccountName" -}}
{{- if .Values.serviceAccount.create }}
{{- default (include "judgeflow.fullname" .) .Values.serviceAccount.name }}
{{- else }}
{{- default "default" .Values.serviceAccount.name }}
{{- end }}
{{- end }}
@@ -1,32 +0,0 @@
{{- if .Values.autoscaling.enabled }}
apiVersion: autoscaling/v2
kind: HorizontalPodAutoscaler
metadata:
  name: {{ include "judgeflow.fullname" . }}
  labels:
    {{- include "judgeflow.labels" . | nindent 4 }}
spec:
  scaleTargetRef:
    apiVersion: apps/v1
    kind: Deployment
    name: {{ include "judgeflow.fullname" . }}
  minReplicas: {{ .Values.autoscaling.minReplicas }}
  maxReplicas: {{ .Values.autoscaling.maxReplicas }}
  metrics:
    {{- if .Values.autoscaling.targetCPUUtilizationPercentage }}
    - type: Resource
      resource:
        name: cpu
        target:
          type: Utilization
          averageUtilization: {{ .Values.autoscaling.targetCPUUtilizationPercentage }}
    {{- end }}
    {{- if .Values.autoscaling.targetMemoryUtilizationPercentage }}
    - type: Resource
      resource:
        name: memory
        target:
          type: Utilization
          averageUtilization: {{ .Values.autoscaling.targetMemoryUtilizationPercentage }}
    {{- end }}
{{- end }}
@@ -1,61 +0,0 @@
{{- if .Values.ingress.enabled -}}
{{- $fullName := include "judgeflow.fullname" . -}}
{{- $svcPort := .Values.service.port -}}
{{- if and .Values.ingress.className (not (semverCompare ">=1.18-0" .Capabilities.KubeVersion.GitVersion)) }}
  {{- if not (hasKey .Values.ingress.annotations "kubernetes.io/ingress.class") }}
  {{- $_ := set .Values.ingress.annotations "kubernetes.io/ingress.class" .Values.ingress.className}}
  {{- end }}
{{- end }}
{{- if semverCompare ">=1.19-0" .Capabilities.KubeVersion.GitVersion -}}
apiVersion: networking.k8s.io/v1
{{- else if semverCompare ">=1.14-0" .Capabilities.KubeVersion.GitVersion -}}
apiVersion: networking.k8s.io/v1beta1
{{- else -}}
apiVersion: extensions/v1beta1
{{- end }}
kind: Ingress
metadata:
  name: {{ $fullName }}
  labels:
    {{- include "judgeflow.labels" . | nindent 4 }}
  {{- with .Values.ingress.annotations }}
  annotations:
    {{- toYaml . | nindent 4 }}
  {{- end }}
spec:
  {{- if and .Values.ingress.className (semverCompare ">=1.18-0" .Capabilities.KubeVersion.GitVersion) }}
  ingressClassName: {{ .Values.ingress.className }}
  {{- end }}
  {{- if .Values.ingress.tls }}
  tls:
    {{- range .Values.ingress.tls }}
    - hosts:
        {{- range .hosts }}
        - {{ . | quote }}
        {{- end }}
      secretName: {{ .secretName }}
    {{- end }}
  {{- end }}
  rules:
    {{- range .Values.ingress.hosts }}
    - host: {{ .host | quote }}
      http:
        paths:
          {{- range .paths }}
          - path: {{ .path }}
            {{- if and .pathType (semverCompare ">=1.18-0" $.Capabilities.KubeVersion.GitVersion) }}
            pathType: {{ .pathType }}
            {{- end }}
            backend:
              {{- if semverCompare ">=1.19-0" $.Capabilities.KubeVersion.GitVersion }}
              service:
                name: {{ $fullName }}
                port:
                  number: {{ $svcPort }}
              {{- else }}
              serviceName: {{ $fullName }}
              servicePort: {{ $svcPort }}
              {{- end }}
          {{- end }}
    {{- end }}
{{- end }}
@@ -1,63 +0,0 @@
apiVersion: batch/v1
kind: Job
metadata:
  name: {{ include "judgeflow.fullname" . }}
  labels:
    {{- include "judgeflow.labels" . | nindent 4 }}
    {{- with .Values.podLabels }}
    {{- toYaml . | nindent 4 }}
    {{- end }}
spec:
  template:
    metadata:
      labels:
        {{- include "judgeflow.labels" . | nindent 8 }}
        {{- with .Values.podLabels }}
        {{- toYaml . | nindent 8 }}
        {{- end }}
    spec:
      {{- with .Values.priorityclassname }}
      priorityClassName: "{{ . }}"
      {{- end }}
      containers:
        - name: {{ .Chart.Name }}
          securityContext:
            {{- toYaml .Values.securityContext | nindent 12 }}
          image: "{{ .Values.image.repository }}:{{ .Values.image.tag | default .Chart.AppVersion }}"
          imagePullPolicy: {{ .Values.image.pullPolicy }}
          {{- with .Values.env }}
          env:
            {{- toYaml . | nindent 12 }}
          {{- end }}
          {{- if and (hasKey .Values "service") (hasKey .Values.service "ports") }}
          ports:
            {{- range .Values.service.ports }}
            - name: {{ .name }}
              containerPort: {{ .port }}
            {{- end }}
          {{- end }}
          {{- if hasKey .Values "command" }}
          command: {{ .Values.command }}
          {{- end }}
          volumeMounts:
            {{- toYaml .Values.volumeMounts | nindent 12 }}
          resources:
            {{- toYaml .Values.resources | nindent 12 }}
      restartPolicy: Never
      {{- with .Values.volumes }}
      volumes:
        {{- toYaml . | nindent 8 }}
      {{- end }}
      {{- with .Values.affinity }}
      affinity:
        {{- toYaml . | nindent 8 }}
      {{- end }}
      {{- with .Values.nodeSelector }}
      nodeSelector:
        {{- toYaml . | nindent 8 }}
      {{- end }}
      {{- with .Values.tolerations }}
      tolerations:
        {{- toYaml . | nindent 8 }}
      {{- end }}
  backoffLimit: 0
@@ -1,10 +0,0 @@
{{- if .Values.priorityclassname }}
apiVersion: scheduling.k8s.io/v1
kind: PriorityClass
metadata:
  name: "{{ .Values.priorityclassname }}"
value: {{ .Values.priorityclassvalue }}
globalDefault: false
preemptionPolicy: "Never"
description: "This is a priority class."
{{- end }}
@@ -1,22 +0,0 @@
{{- if and (hasKey .Values "service") (hasKey .Values.service "type") }}
apiVersion: v1
kind: Service
metadata:
  name: {{ include "judgeflow.fullname" . }}
  labels:
    {{- include "judgeflow.labels" . | nindent 4 }}
    {{- with .Values.podLabels }}
    {{- toYaml . | nindent 4 }}
    {{- end }}
spec:
  type: {{ .Values.service.type }}
  ports:
    {{- range .Values.service.ports }}
    - port: {{ .port }}
      targetPort: {{ .port }}
      protocol: TCP
      name: {{ .name }}
    {{- end }}
  selector:
    {{- include "judgeflow.selectorLabels" . | nindent 4 }}
{{- end }}
@@ -1,13 +0,0 @@
{{- if .Values.serviceAccount.create -}}
apiVersion: v1
kind: ServiceAccount
metadata:
  name: {{ include "judgeflow.serviceAccountName" . }}
  labels:
    {{- include "judgeflow.labels" . | nindent 4 }}
  {{- with .Values.serviceAccount.annotations }}
  annotations:
    {{- toYaml . | nindent 4 }}
  {{- end }}
automountServiceAccountToken: {{ .Values.serviceAccount.automount }}
{{- end }}
@@ -1,15 +0,0 @@
apiVersion: v1
kind: Pod
metadata:
  name: {{ include "judgeflow.fullname" . }}-test-connection
  labels:
    {{- include "judgeflow.labels" . | nindent 4 }}
  annotations:
    "helm.sh/hook": test
spec:
  containers:
    - name: wget
      image: busybox
      command: ['wget']
      args: ['{{ include "judgeflow.fullname" . }}:{{ .Values.service.port }}']
  restartPolicy: Never
@@ -1,124 +0,0 @@
# Default values for job_demo.
# This is a YAML-formatted file.
# Declare variables to be passed into your templates.

replicaCount: 1

image:
  repository: "${imageRepo}"
  pullPolicy: IfNotPresent
  # Overrides the image tag whose default is the chart appVersion.
  tag: "${imageTag}"

imagePullSecrets: []
nameOverride: ""
fullnameOverride: ""

serviceAccount:
  # Specifies whether a service account should be created
  create: true
  # Annotations to add to the service account
  annotations: {}
  # The name of the service account to use.
  # If not set and create is true, a name is generated using the fullname template
  name: ""

podAnnotations: {}

podLabels:
  contest.4pd.io/leaderboard-resource-type: judge_flow
  contest.4pd.io/leaderboard-job-id: "0"
  contest.4pd.io/leaderboard-submit-id: "0"

podSecurityContext: {}
  # fsGroup: 2000

securityContext: {}
  # capabilities:
  #   drop:
  #   - ALL
  # readOnlyRootFilesystem: true
  # runAsNonRoot: true
  # runAsUser: 1000

service:
  type: ClusterIP
  ports:
    - name: http
      port: 80

ingress:
  enabled: false
  className: ""
  annotations: {}
    # kubernetes.io/ingress.class: nginx
    # kubernetes.io/tls-acme: "true"
  hosts:
    - host: chart-example.local
      paths:
        - path: /
          pathType: ImplementationSpecific
  tls: []
  #  - secretName: chart-example-tls
  #    hosts:
  #      - chart-example.local

resources:
  # We usually recommend not to specify default resources and to leave this as a conscious
  # choice for the user. This also increases chances charts run on environments with little
  # resources, such as Minikube. If you do want to specify resources, uncomment the following
  # lines, adjust them as necessary, and remove the curly braces after 'resources:'.
  limits:
    cpu: 3000m
    memory: 16Gi
  requests:
    cpu: 3000m
    memory: 16Gi

autoscaling:
  enabled: false
  minReplicas: 1
  maxReplicas: 100
  targetCPUUtilizationPercentage: 80
  # targetMemoryUtilizationPercentage: 80

nodeSelector:
  juicefs: "on"
  contest.4pd.io/cpu: INTEL-8358

tolerations: []

affinity: {}

env:
  - name: TZ
    value: Asia/Shanghai
  - name: MY_POD_IP
    valueFrom:
      fieldRef:
        fieldPath: status.podIP

#command: '["python","run.py"]'

volumeMounts:
  - name: workspace
    mountPath: /tmp/workspace
  - name: datafile
    mountPath: /tmp/datafile
  - name: submit
    mountPath: /tmp/submit_config
  - name: juicefs-pv
    mountPath: /tmp/juicefs
  - name: customer
    mountPath: /tmp/customer
  - name: submit-private
    mountPath: /tmp/submit_private

volumes:
  - name: juicefs-pv
    persistentVolumeClaim:
      claimName: juicefs-pvc


priorityclassname: ''
priorityclassvalue: '0'
BIN helm-chart/sut/.DS_Store (vendored)
Binary file not shown.
@@ -1,23 +0,0 @@
# Patterns to ignore when building packages.
# This supports shell glob matching, relative path matching, and
# negation (prefixed with !). Only one pattern per line.
.DS_Store
# Common VCS dirs
.git/
.gitignore
.bzr/
.bzrignore
.hg/
.hgignore
.svn/
# Common backup files
*.swp
*.bak
*.tmp
*.orig
*~
# Various IDEs
.project
.idea/
*.tmproj
.vscode/
@@ -1,24 +0,0 @@
apiVersion: v2
name: sut
description: A Helm chart for Kubernetes

# A chart can be either an 'application' or a 'library' chart.
#
# Application charts are a collection of templates that can be packaged into versioned archives
# to be deployed.
#
# Library charts provide useful utilities or functions for the chart developer. They're included as
# a dependency of application charts to inject those utilities and functions into the rendering
# pipeline. Library charts do not define any templates and therefore cannot be deployed.
type: application

# This is the chart version. This version number should be incremented each time you make changes
# to the chart and its templates, including the app version.
# Versions are expected to follow Semantic Versioning (https://semver.org/)
version: 0.1.0

# This is the version number of the application being deployed. This version number should be
# incremented each time you make changes to the application. Versions are not expected to
# follow Semantic Versioning. They should reflect the version the application is using.
# It is recommended to use it with quotes.
appVersion: "0.1.0"
@@ -1,62 +0,0 @@
{{/*
Expand the name of the chart.
*/}}
{{- define "sut.name" -}}
{{- default .Chart.Name .Values.nameOverride | trunc 63 | trimSuffix "-" }}
{{- end }}

{{/*
Create a default fully qualified app name.
We truncate at 63 chars because some Kubernetes name fields are limited to this (by the DNS naming spec).
If release name contains chart name it will be used as a full name.
*/}}
{{- define "sut.fullname" -}}
{{- if .Values.fullnameOverride }}
{{- .Values.fullnameOverride | trunc 63 | trimSuffix "-" }}
{{- else }}
{{- $name := default .Chart.Name .Values.nameOverride }}
{{- if contains $name .Release.Name }}
{{- .Release.Name | trunc 63 | trimSuffix "-" }}
{{- else }}
{{- printf "%s-%s" .Release.Name $name | trunc 63 | trimSuffix "-" }}
{{- end }}
{{- end }}
{{- end }}

{{/*
Create chart name and version as used by the chart label.
*/}}
{{- define "sut.chart" -}}
{{- printf "%s-%s" .Chart.Name .Chart.Version | replace "+" "_" | trunc 63 | trimSuffix "-" }}
{{- end }}

{{/*
Common labels
*/}}
{{- define "sut.labels" -}}
helm.sh/chart: {{ include "sut.chart" . }}
{{ include "sut.selectorLabels" . }}
{{- if .Chart.AppVersion }}
app.kubernetes.io/version: {{ .Chart.AppVersion | quote }}
{{- end }}
app.kubernetes.io/managed-by: {{ .Release.Service }}
{{- end }}

{{/*
Selector labels
*/}}
{{- define "sut.selectorLabels" -}}
app.kubernetes.io/name: {{ include "sut.name" . }}
app.kubernetes.io/instance: {{ .Release.Name }}
{{- end }}

{{/*
Create the name of the service account to use
*/}}
{{- define "sut.serviceAccountName" -}}
{{- if .Values.serviceAccount.create }}
{{- default (include "sut.fullname" .) .Values.serviceAccount.name }}
{{- else }}
{{- default "default" .Values.serviceAccount.name }}
{{- end }}
{{- end }}
@@ -1,94 +0,0 @@
apiVersion: apps/v1
kind: Deployment
metadata:
  name: {{ include "sut.fullname" . }}
  labels:
    {{- include "sut.labels" . | nindent 4 }}
    {{- with .Values.podLabels }}
    {{- toYaml . | nindent 4 }}
    {{- end }}
spec:
  {{- if not .Values.autoscaling.enabled }}
  replicas: {{ .Values.replicaCount }}
  {{- end }}
  selector:
    matchLabels:
      {{- include "sut.selectorLabels" . | nindent 6 }}
  template:
    metadata:
      {{- with .Values.podAnnotations }}
      annotations:
        {{- toYaml . | nindent 8 }}
      {{- end }}
      labels:
        {{- include "sut.labels" . | nindent 8 }}
        {{- with .Values.podLabels }}
        {{- toYaml . | nindent 8 }}
        {{- end }}
    spec:
      {{- with .Values.imagePullSecrets }}
      imagePullSecrets:
        {{- toYaml . | nindent 8 }}
      {{- end }}
      serviceAccountName: {{ include "sut.serviceAccountName" . }}
      securityContext:
        {{- toYaml .Values.podSecurityContext | nindent 8 }}
      {{- with .Values.priorityclassname }}
      priorityClassName: "{{ . }}"
      {{- end }}
      containers:
        - name: {{ .Chart.Name }}
          securityContext:
            {{- toYaml .Values.securityContext | nindent 12 }}
          image: "{{ .Values.image.repository }}:{{ .Values.image.tag | default .Chart.AppVersion }}"
          imagePullPolicy: {{ .Values.image.pullPolicy }}
          {{- with .Values.env }}
          env:
            {{- toYaml . | nindent 12 }}
          {{- end }}
          ports:
            - name: http
              containerPort: {{ .Values.service.port }}
              protocol: TCP
          {{- with .Values.command }}
          command:
            {{- toYaml . | nindent 12 }}
          {{- end }}
          resources:
            {{- toYaml .Values.resources | nindent 12 }}
          {{- with .Values.volumeMounts }}
          volumeMounts:
            {{- toYaml . | nindent 12 }}
          {{- end }}

          {{- with .Values.livenessProbe }}
          livenessProbe:
            {{- toYaml . | nindent 12 }}
          {{- end }}
          {{- with .Values.readinessProbe }}
          readinessProbe:
            {{- toYaml . | nindent 12 }}
          {{- end }}
          {{- with .Values.startupProbe }}
          startupProbe:
            {{- toYaml . | nindent 12 }}
          {{- end }}

      volumes:
        {{- with .Values.volumes }}
        {{- toYaml . | nindent 8 }}
        {{- end }}

      {{- with .Values.nodeSelector }}
      nodeSelector:
        {{- toYaml . | nindent 8 }}
      {{- end }}
      {{- with .Values.affinity }}
      affinity:
        {{- toYaml . | nindent 8 }}
      {{- end }}
      tolerations:
        - key: "hosttype"
          operator: "Equal"
          value: "iluvatar"
          effect: "NoSchedule"
@@ -1,32 +0,0 @@
{{- if .Values.autoscaling.enabled }}
apiVersion: autoscaling/v2
kind: HorizontalPodAutoscaler
metadata:
  name: {{ include "sut.fullname" . }}
  labels:
    {{- include "sut.labels" . | nindent 4 }}
spec:
  scaleTargetRef:
    apiVersion: apps/v1
    kind: Deployment
    name: {{ include "sut.fullname" . }}
  minReplicas: {{ .Values.autoscaling.minReplicas }}
  maxReplicas: {{ .Values.autoscaling.maxReplicas }}
  metrics:
    {{- if .Values.autoscaling.targetCPUUtilizationPercentage }}
    - type: Resource
      resource:
        name: cpu
        target:
          type: Utilization
          averageUtilization: {{ .Values.autoscaling.targetCPUUtilizationPercentage }}
    {{- end }}
    {{- if .Values.autoscaling.targetMemoryUtilizationPercentage }}
    - type: Resource
      resource:
        name: memory
        target:
          type: Utilization
          averageUtilization: {{ .Values.autoscaling.targetMemoryUtilizationPercentage }}
    {{- end }}
{{- end }}
@@ -1,61 +0,0 @@
{{- if .Values.ingress.enabled -}}
{{- $fullName := include "sut.fullname" . -}}
{{- $svcPort := .Values.service.port -}}
{{- if and .Values.ingress.className (not (semverCompare ">=1.18-0" .Capabilities.KubeVersion.GitVersion)) }}
  {{- if not (hasKey .Values.ingress.annotations "kubernetes.io/ingress.class") }}
  {{- $_ := set .Values.ingress.annotations "kubernetes.io/ingress.class" .Values.ingress.className}}
  {{- end }}
{{- end }}
{{- if semverCompare ">=1.19-0" .Capabilities.KubeVersion.GitVersion -}}
apiVersion: networking.k8s.io/v1
{{- else if semverCompare ">=1.14-0" .Capabilities.KubeVersion.GitVersion -}}
apiVersion: networking.k8s.io/v1beta1
{{- else -}}
apiVersion: extensions/v1beta1
{{- end }}
kind: Ingress
metadata:
  name: {{ $fullName }}
  labels:
    {{- include "sut.labels" . | nindent 4 }}
  {{- with .Values.ingress.annotations }}
  annotations:
    {{- toYaml . | nindent 4 }}
  {{- end }}
spec:
  {{- if and .Values.ingress.className (semverCompare ">=1.18-0" .Capabilities.KubeVersion.GitVersion) }}
  ingressClassName: {{ .Values.ingress.className }}
  {{- end }}
  {{- if .Values.ingress.tls }}
  tls:
    {{- range .Values.ingress.tls }}
    - hosts:
        {{- range .hosts }}
        - {{ . | quote }}
        {{- end }}
      secretName: {{ .secretName }}
    {{- end }}
  {{- end }}
  rules:
    {{- range .Values.ingress.hosts }}
    - host: {{ .host | quote }}
      http:
        paths:
          {{- range .paths }}
          - path: {{ .path }}
            {{- if and .pathType (semverCompare ">=1.18-0" $.Capabilities.KubeVersion.GitVersion) }}
            pathType: {{ .pathType }}
            {{- end }}
            backend:
              {{- if semverCompare ">=1.19-0" $.Capabilities.KubeVersion.GitVersion }}
              service:
                name: {{ $fullName }}
                port:
                  number: {{ $svcPort }}
              {{- else }}
              serviceName: {{ $fullName }}
              servicePort: {{ $svcPort }}
              {{- end }}
          {{- end }}
    {{- end }}
{{- end }}
@@ -1,18 +0,0 @@
apiVersion: v1
kind: Service
metadata:
  name: {{ include "sut.fullname" . }}
  labels:
    {{- include "sut.labels" . | nindent 4 }}
    {{- with .Values.podLabels }}
    {{- toYaml . | nindent 4 }}
    {{- end }}
spec:
  type: {{ .Values.service.type }}
  ports:
    - port: {{ .Values.service.port }}
      targetPort: http
      protocol: TCP
      name: socket
  selector:
    {{- include "sut.selectorLabels" . | nindent 4 }}
@@ -1,13 +0,0 @@
{{- if .Values.serviceAccount.create -}}
apiVersion: v1
kind: ServiceAccount
metadata:
  name: {{ include "sut.serviceAccountName" . }}
  labels:
    {{- include "sut.labels" . | nindent 4 }}
  {{- with .Values.serviceAccount.annotations }}
  annotations:
    {{- toYaml . | nindent 4 }}
  {{- end }}
automountServiceAccountToken: {{ .Values.serviceAccount.automount }}
{{- end }}
@@ -1,15 +0,0 @@
apiVersion: v1
kind: Pod
metadata:
  name: "{{ include "sut.fullname" . }}-test-connection"
  labels:
    {{- include "sut.labels" . | nindent 4 }}
  annotations:
    "helm.sh/hook": test
spec:
  containers:
    - name: wget
      image: busybox
      command: ['wget']
      args: ['{{ include "sut.fullname" . }}:{{ .Values.service.port }}']
  restartPolicy: Never
@@ -1,144 +0,0 @@
# Default values for sut.
# This is a YAML-formatted file.
# Declare variables to be passed into your templates.

replicaCount: 1

image:
  repository: harbor.4pd.io/lab-platform/inf/python
  pullPolicy: IfNotPresent
  # Overrides the image tag whose default is the chart appVersion.
  tag: 3.9

imagePullSecrets: []
nameOverride: ""
fullnameOverride: ""

serviceAccount:
  # Specifies whether a service account should be created
  create: true
  # Automatically mount a ServiceAccount's API credentials?
  automount: true
  # Annotations to add to the service account
  annotations: {}
  # The name of the service account to use.
  # If not set and create is true, a name is generated using the fullname template
  name: ""

podAnnotations: {}
podLabels: {}
podSecurityContext: {}
  # fsGroup: 2000

securityContext: {}
  # capabilities:
  #   drop:
  #   - ALL
  # readOnlyRootFilesystem: true
  # runAsNonRoot: true
  # runAsUser: 1000

service:
  type: ClusterIP
  port: 80

ingress:
  enabled: false
  className: ""
  annotations: {}
    # kubernetes.io/ingress.class: nginx
    # kubernetes.io/tls-acme: "true"
  hosts:
    - host: chart-example.local
      paths:
        - path: /
          pathType: ImplementationSpecific
  tls: []
  #  - secretName: chart-example-tls
  #    hosts:
  #      - chart-example.local

resources:
  # We usually recommend not to specify default resources and to leave this as a conscious
  # choice for the user. This also increases chances charts run on environments with little
  # resources, such as Minikube. If you do want to specify resources, uncomment the following
  # lines, adjust them as necessary, and remove the curly braces after 'resources:'.
  limits:
    cpu: 1000m
    memory: 4096Mi
  requests:
    cpu: 1000m
    memory: 4096Mi

autoscaling:
  enabled: false
  minReplicas: 1
  maxReplicas: 100
  targetCPUUtilizationPercentage: 80
  # targetMemoryUtilizationPercentage: 80

# Additional volumes on the output Deployment definition.
volumes: []
# - name: foo
#   secret:
#     secretName: mysecret
#     optional: false

# Additional volumeMounts on the output Deployment definition.
volumeMounts: []
# - name: foo
#   mountPath: "/etc/foo"
#   readOnly: true

nodeSelector:
  contest.4pd.io/accelerator: iluvatar-BI-V100

tolerations:
  - key: hosttype
    operator: Equal
    value: iluvatar
    effect: NoSchedule


affinity: {}

readinessProbe:
  failureThreshold: 1000
  httpGet:
    path: /health
    port: 80
    scheme: HTTP

#readinessProbe:
#  httpGet:
#    path: /health
#    port: 80
#    scheme: HTTP
#  initialDelaySeconds: 5  # wait 5 seconds after the app starts before probing
#  failureThreshold: 5     # mark not-ready after 5 consecutive failures
#  successThreshold: 1     # mark ready after 1 consecutive success

env:
  - name: TZ
    value: Asia/Shanghai
  - name: MY_POD_NAME
    valueFrom:
      fieldRef:
        fieldPath: metadata.name
  - name: MY_POD_NAMESPACE
    valueFrom:
      fieldRef:
        fieldPath: metadata.namespace
  - name: MY_POD_IP
    valueFrom:
      fieldRef:
        fieldPath: status.podIP
  - name: MY_NODE_IP
    valueFrom:
      fieldRef:
        fieldPath: status.hostIP

#command: ''


priorityclassname: ''
@@ -1,64 +0,0 @@
import os
import tempfile
import shutil

if os.path.exists("/tmp/submit_private"):
    shutil.rmtree("/tmp/submit_private")

with tempfile.TemporaryDirectory() as tempdir:
    config_path = os.path.join(tempdir, "config.json")

    assert not os.system(f"ssh-keygen -f {tempdir}/ssh-key-ecdsa -t ecdsa -b 521 -q -N \"\"")

    config = """
model: whisper
model_key: whisper
config.json:
  name: 'faster-whisper-server:latest'
  support_devices:
    - cpu
  model_path: ''
  port: 8080
  other_ports: []
  other_ports_count: 1
  entrypoint: start.bat
  MIN_CHUNK: 2.5
  MIN_ADD_CHUNK: 2.5
  COMPUTE_TYPE: int8
  NUM_WORKERS: 1
  CPU_THREADS: 2
  BEAM_SIZE: 5
  BATCH: 1
  LANG: auto
  DEVICE: cpu
  CHUNK_LENGTH: 5
  CLASS_MODEL: ./models/faster-whisper-base
  EN_MODEL: ./models/faster-whisper-base
  ZH_MODEL: ./models/faster-whisper-base
  RU_MODEL: ./models/faster-whisper-base
  PT_MODEL: ./models/faster-whisper-base
  AR_MODEL: ./models/faster-whisper-base
  NEW_VERSION: 1
  NEED_RESET: 0
leaderboard_options:
  nfs:
    - name: whisper
      srcRelativePath: leaderboard/pc_asr/en.tar.gz
      mountPoint: /tmp
      source: ceph_customer
"""

    with open(config_path, "w") as f:
        f.write(config)

    os.environ["SSH_KEY_DIR"] = tempdir
    os.environ["SUBMIT_CONFIG_FILEPATH"] = config_path
    os.environ["MODEL_MAPPING"] = '{"whisper": "edge-ml.tar.gz"}'

    from run_async_a10 import get_sut_url_windows

    print(get_sut_url_windows())

    import time
    time.sleep(3600)
18 microsoft_beit_base_patch16_224_pt22k_ft22k/.gitattributes (vendored, Normal file)
@@ -0,0 +1,18 @@
*.bin.* filter=lfs diff=lfs merge=lfs -text
*.lfs.* filter=lfs diff=lfs merge=lfs -text
*.bin filter=lfs diff=lfs merge=lfs -text
*.h5 filter=lfs diff=lfs merge=lfs -text
*.tflite filter=lfs diff=lfs merge=lfs -text
*.tar.gz filter=lfs diff=lfs merge=lfs -text
*.ot filter=lfs diff=lfs merge=lfs -text
*.onnx filter=lfs diff=lfs merge=lfs -text
*.arrow filter=lfs diff=lfs merge=lfs -text
*.ftz filter=lfs diff=lfs merge=lfs -text
*.joblib filter=lfs diff=lfs merge=lfs -text
*.model filter=lfs diff=lfs merge=lfs -text
*.msgpack filter=lfs diff=lfs merge=lfs -text
*.pb filter=lfs diff=lfs merge=lfs -text
*.pt filter=lfs diff=lfs merge=lfs -text
*.pth filter=lfs diff=lfs merge=lfs -text
*tfevents* filter=lfs diff=lfs merge=lfs -text
*.msgpack filter=lfs diff=lfs merge=lfs -text
104 microsoft_beit_base_patch16_224_pt22k_ft22k/README.md (Normal file)
@@ -0,0 +1,104 @@
---
license: apache-2.0
tags:
- image-classification
- vision
datasets:
- imagenet
- imagenet-21k
---

# BEiT (base-sized model, fine-tuned on ImageNet-22k)

BEiT model pre-trained in a self-supervised fashion on ImageNet-22k - also called ImageNet-21k (14 million images, 21,841 classes) at resolution 224x224, and fine-tuned on the same dataset at resolution 224x224. It was introduced in the paper [BEIT: BERT Pre-Training of Image Transformers](https://arxiv.org/abs/2106.08254) by Hangbo Bao, Li Dong and Furu Wei and first released in [this repository](https://github.com/microsoft/unilm/tree/master/beit).

Disclaimer: The team releasing BEiT did not write a model card for this model so this model card has been written by the Hugging Face team.

## Model description

The BEiT model is a Vision Transformer (ViT), which is a transformer encoder model (BERT-like). In contrast to the original ViT model, BEiT is pretrained on a large collection of images in a self-supervised fashion, namely ImageNet-21k, at a resolution of 224x224 pixels. The pre-training objective for the model is to predict visual tokens from the encoder of OpenAI's DALL-E's VQ-VAE, based on masked patches.
Next, the model was fine-tuned in a supervised fashion on ImageNet (also referred to as ILSVRC2012), a dataset comprising 1 million images and 1,000 classes, also at resolution 224x224.

Images are presented to the model as a sequence of fixed-size patches (resolution 16x16), which are linearly embedded. Contrary to the original ViT models, BEiT models do use relative position embeddings (similar to T5) instead of absolute position embeddings, and perform classification of images by mean-pooling the final hidden states of the patches, instead of placing a linear layer on top of the final hidden state of the [CLS] token.

By pre-training the model, it learns an inner representation of images that can then be used to extract features useful for downstream tasks: if you have a dataset of labeled images for instance, you can train a standard classifier by placing a linear layer on top of the pre-trained encoder. One typically places a linear layer on top of the [CLS] token, as the last hidden state of this token can be seen as a representation of an entire image. Alternatively, one can mean-pool the final hidden states of the patch embeddings, and place a linear layer on top of that.

## Intended uses & limitations

You can use the raw model for image classification. See the [model hub](https://huggingface.co/models?search=microsoft/beit) to look for fine-tuned versions on a task that interests you.

### How to use

Here is how to use this model to classify an image of the COCO 2017 dataset into one of the 1,000 ImageNet classes:

```python
from transformers import BeitImageProcessor, BeitForImageClassification
from PIL import Image
import requests

url = 'http://images.cocodataset.org/val2017/000000039769.jpg'
image = Image.open(requests.get(url, stream=True).raw)

processor = BeitImageProcessor.from_pretrained('microsoft/beit-base-patch16-224-pt22k-ft22k')
model = BeitForImageClassification.from_pretrained('microsoft/beit-base-patch16-224-pt22k-ft22k')

inputs = processor(images=image, return_tensors="pt")
outputs = model(**inputs)
logits = outputs.logits
# model predicts one of the 21,841 ImageNet-22k classes
predicted_class_idx = logits.argmax(-1).item()
print("Predicted class:", model.config.id2label[predicted_class_idx])
```

Currently, both the feature extractor and model support PyTorch.

## Training data

The BEiT model was pretrained on [ImageNet-21k](http://www.image-net.org/), a dataset consisting of 14 million images and 21k classes, and fine-tuned on the same dataset.

## Training procedure

### Preprocessing

The exact details of preprocessing of images during training/validation can be found [here](https://github.com/microsoft/unilm/blob/master/beit/datasets.py).

Images are resized/rescaled to the same resolution (224x224) and normalized across the RGB channels with mean (0.5, 0.5, 0.5) and standard deviation (0.5, 0.5, 0.5).

### Pretraining

For all pre-training related hyperparameters, we refer to page 15 of the [original paper](https://arxiv.org/abs/2106.08254).

## Evaluation results

For evaluation results on several image classification benchmarks, we refer to tables 1 and 2 of the original paper. Note that for fine-tuning, the best results are obtained with a higher resolution. Of course, increasing the model size will result in better performance.

### BibTeX entry and citation info

```bibtex
@article{DBLP:journals/corr/abs-2106-08254,
  author    = {Hangbo Bao and
               Li Dong and
               Furu Wei},
  title     = {BEiT: {BERT} Pre-Training of Image Transformers},
  journal   = {CoRR},
  volume    = {abs/2106.08254},
  year      = {2021},
  url       = {https://arxiv.org/abs/2106.08254},
  archivePrefix = {arXiv},
  eprint    = {2106.08254},
  timestamp = {Tue, 29 Jun 2021 16:55:04 +0200},
  biburl    = {https://dblp.org/rec/journals/corr/abs-2106-08254.bib},
  bibsource = {dblp computer science bibliography, https://dblp.org}
}
```

```bibtex
@inproceedings{deng2009imagenet,
  title={Imagenet: A large-scale hierarchical image database},
  author={Deng, Jia and Dong, Wei and Socher, Richard and Li, Li-Jia and Li, Kai and Fei-Fei, Li},
  booktitle={2009 IEEE conference on computer vision and pattern recognition},
  pages={248--255},
  year={2009},
  organization={Ieee}
}
```
43111 microsoft_beit_base_patch16_224_pt22k_ft22k/config.json (Normal file)
File diff suppressed because it is too large.
BIN microsoft_beit_base_patch16_224_pt22k_ft22k/flax_model.msgpack (Normal file)
Binary file not shown.
@@ -0,0 +1,19 @@
{
  "crop_size": 224,
  "do_center_crop": false,
  "do_normalize": true,
  "do_resize": true,
  "feature_extractor_type": "BeitFeatureExtractor",
  "image_mean": [
    0.5,
    0.5,
    0.5
  ],
  "image_std": [
    0.5,
    0.5,
    0.5
  ],
  "resample": 2,
  "size": 224
}
BIN microsoft_beit_base_patch16_224_pt22k_ft22k/pytorch_model.bin (Normal file)
Binary file not shown.
@@ -1,8 +0,0 @@
#!/bin/bash

export DATASET_FILEPATH=dataset/formatted1/de.zip
export RESULT_FILEPATH=out/result.json
export DETAILED_CASES_FILEPATH=out/detail_cases.json
export SUBMIT_CONFIG_FILEPATH=
export BENCHMARK_NAME=
export MY_POD_IP=127.0.0.1
@@ -1,24 +0,0 @@
[tool.black]
line-length = 80
target-version = ['py39']

[tool.flake8]
max-line-length = 88
count=true
per-file-ignores="./annotation/manager.py:F401"
exclude=["./label", "__pycache__", "./migrations", "./logs", "./pids", "./resources"]
ignore=["W503", "E203"]
enable-extensions="G"
application-import-names=["flake8-isort", "flake8-logging-format", "flake8-builtins"]
import-order-style="edited"
extend-ignore = ["E203", "E701"]

[tool.isort]
py_version=39
profile="black"
multi_line_output=9
line_length=80
group_by_package=true
case_sensitive=true
skip_gitignore=true
114 run.py
@@ -1,114 +0,0 @@
import gc
import json
import os
import sys
import time
import zipfile

import yaml
from schemas.context import ASRContext
from utils.client import Client
from utils.evaluator import BaseEvaluator
from utils.logger import logger
from utils.service import register_sut

IN_TEST = os.getenv("SUBMIT_CONFIG_FILEPATH", None) is None
UNIT_TEST = os.getenv("UNIT_TEST", 0)


def main():
    logger.info("Running...")

    dataset_filepath = os.getenv(
        "DATASET_FILEPATH",
        "./tests/resources/en.zip",
    )
    submit_config_filepath = os.getenv("SUBMIT_CONFIG_FILEPATH", "./tests/resources/submit_config")
    result_filepath = os.getenv("RESULT_FILEPATH", "./out/result")
    bad_cases_filepath = os.getenv("BAD_CASES_FILEPATH", "./out/badcase")
    detail_cases_filepath = os.getenv("DETAILED_CASES_FILEPATH", "./out/detailcase.jsonl")

    resource_name = os.getenv("BENCHMARK_NAME")

    # Submit config & start the system under test (sut)
    if os.getenv("DATASET_FILEPATH", ""):
        from utils.helm import resource_check

        with open(submit_config_filepath, "r") as fp:
            st_config = yaml.safe_load(fp)
        st_config["values"] = resource_check(st_config.get("values", {}))
        if 'docker_images' in st_config:
            sut_url = "ws://172.26.1.75:9827"
            os.environ['test'] = '1'
        elif 'docker_image' in st_config:
            sut_url = register_sut(st_config, resource_name)
        elif UNIT_TEST:
            sut_url = "ws://172.27.231.36:80"
        else:
            logger.error("Invalid config: no docker_image")
            os._exit(1)
    else:
        os.environ['test'] = '1'
        sut_url = "ws://172.27.231.36:80"
    if UNIT_TEST:
        exit(0)

    """
    # Dataset handling
    local_dataset_path = "./dataset"
    os.makedirs(local_dataset_path, exist_ok=True)
    with zipfile.ZipFile(dataset_filepath) as zf:
        zf.extractall(local_dataset_path)
    config_path = os.path.join(local_dataset_path, "data.yaml")
    with open(config_path, "r") as fp:
        dataset_config = yaml.safe_load(fp)

    # Dataset info
    dataset_global_config = dataset_config.get("global", {})
    dataset_query = dataset_config.get("query_data", {})

    evaluator = BaseEvaluator()

    # Start predicting
    for idx, query_item in enumerate(dataset_query):
        gc.collect()
        logger.info(f"Processing data item {idx}")

        context = ASRContext(**dataset_global_config)
        context.lang = query_item.get("lang", context.lang)
        context.file_path = os.path.join(local_dataset_path, query_item["file"])
        # context.audio_length = query_item["audio_length"]

        interactions = Client(sut_url, context).action()
        context.append_labels(query_item["voice"])
        context.append_preds(
            interactions["predict_data"],
            interactions["send_time"],
            interactions["recv_time"],
        )
        context.fail = interactions["fail"]
        if IN_TEST:
            with open('output.txt', 'w') as fp:
                original_stdout = sys.stdout
                sys.stdout = fp
                print(context)
                sys.stdout = original_stdout
        evaluator.evaluate(context)
        detail_case = evaluator.gen_detail_case()
        with open(detail_cases_filepath, "a") as fp:
            fp.write(json.dumps(detail_case.to_dict(), ensure_ascii=False) + "\n")
        time.sleep(4)

    evaluator.post_evaluate()
    output_result = evaluator.gen_result()
    # print(evaluator.__dict__)
    logger.info(f"Finished. Result = {output_result}")

    with open(result_filepath, "w") as fp:
        json.dump(output_result, fp, indent=2, ensure_ascii=False)
    with open(bad_cases_filepath, "w") as fp:
        fp.write("This leaderboard has no bad cases\n")
    """


if __name__ == "__main__":
    main()
757 run_async_a10.py
@@ -1,757 +0,0 @@
import atexit
import concurrent.futures
import fcntl
import gc
import glob
import json
import os
import random
import signal
import sys
import tempfile
import threading
import time
import zipfile
from concurrent.futures import ThreadPoolExecutor

import yaml
from fabric import Connection
from vmplatform import VMOS, Client, VMDataDisk

from schemas.context import ASRContext
from utils.client_async import ClientAsync
from utils.evaluator import BaseEvaluator
from utils.logger import logger
from utils.service import register_sut

IN_TEST = os.getenv("SUBMIT_CONFIG_FILEPATH", None) is None
UNIT_TEST = os.getenv("UNIT_TEST", 0)

DATASET_NUM = os.getenv("DATASET_NUM")

# VM leaderboard parameters
SUT_TYPE = os.getenv("SUT_TYPE", "kubernetes")
SHARE_SUT = os.getenv("SHARE_SUT", "true") == "true"
VM_ID = 0
VM_IP = ""
do_deploy_chart = True
VM_CPU = int(os.getenv("VM_CPU", "2"))
VM_MEM = int(os.getenv("VM_MEM", "4096"))
MODEL_BASEPATH = os.getenv("MODEL_BASEPATH", "/tmp/customer/leaderboard/pc_asr")
MODEL_MAPPING = json.loads(os.getenv("MODEL_MAPPING", "{}"))
SSH_KEY_DIR = os.getenv("SSH_KEY_DIR", "/workspace")
SSH_PUBLIC_KEY_FILE = os.path.join(SSH_KEY_DIR, "ssh-key-ecdsa.pub")
SSH_KEY_FILE = os.path.join(SSH_KEY_DIR, "ssh-key-ecdsa")

CONNECT_KWARGS = {"key_filename": SSH_KEY_FILE}

# Shared-SUT parameters
JOB_ID = os.getenv("JOB_ID")
dirname = "/tmp/submit_private/sut_share"
os.makedirs(dirname, exist_ok=True)
SUT_SHARE_LOCK = os.path.join(dirname, "lock.lock")
SUT_SHARE_USE_LOCK = os.path.join(dirname, "use.lock")
SUT_SHARE_STATUS = os.path.join(dirname, "status.json")
SUT_SHARE_JOB_STATUS = os.path.join(dirname, f"job_status.{JOB_ID}")
SUT_SHARE_PUBLIC_FAIL = os.path.join(dirname, "one_job_failed")
fd_lock = open(SUT_SHARE_USE_LOCK, "a")


def clean_vm_atexit():
    global VM_ID, do_deploy_chart
    if not VM_ID:
        return
    if not do_deploy_chart:
        return
    logger.info("Deleting VM")
    vmclient = Client()
    err_msg = vmclient.delete_vm(VM_ID)
    if err_msg:
        logger.warning(f"Failed to delete VM: {err_msg}")


def put_file_to_vm(c: Connection, local_path: str, remote_path: str):
    logger.info(f"uploading file {local_path} to {remote_path}")
    result = c.put(local_path, remote_path)
    logger.info("uploaded {0.local} to {0.remote}".format(result))


def deploy_windows_sut():
    global VM_ID
    global VM_IP

    submit_config_filepath = os.getenv("SUBMIT_CONFIG_FILEPATH", "")
    with open(submit_config_filepath, "r") as fp:
        st_config = yaml.safe_load(fp)
    assert "model" in st_config, "model is not configured"
    assert "model_key" in st_config, "model_key is not configured"
    assert "config.json" in st_config, "config.json is not configured"
    nfs = st_config.get("leaderboard_options", {}).get("nfs", [])
    assert len(nfs) > 0, "nfs is not configured"
    assert st_config["model"] in MODEL_MAPPING, "submitted model is not among the available models"

    model = st_config["model"]
    model_key = st_config["model_key"]
    model_path = ""
    config = st_config["config.json"]
    exist = False
    for nfs_item in nfs:
        if nfs_item["name"] == model_key:
            exist = True
            if nfs_item["source"] == "ceph_customer":
                model_path = os.path.join(
                    "/tmp/customer",
                    nfs_item["srcRelativePath"],
                )
            else:
                model_path = os.path.join(
                    "/tmp/juicefs",
                    nfs_item["srcRelativePath"],
                )
            break
    if not exist:
        raise RuntimeError(f"nfs entry not found: name={model_key}")
    config_path = os.path.join(tempfile.mkdtemp(), "config.json")
    model_dir = os.path.basename(model_path).split(".")[0]
    config["model_path"] = f"E:\\model\\{model_dir}"
    with open(config_path, "w") as fp:
        json.dump(config, fp, ensure_ascii=False, indent=4)

    vmclient = Client()
    with open(SSH_PUBLIC_KEY_FILE, "r") as fp:
        sshpublickey = fp.read().rstrip()
    VM_ID = vmclient.create_vm(
        "amd64",
        VMOS.windows10,
        VM_CPU,
        VM_MEM,
        "leaderboard-%s-submit-%s-job-%s"
        % (
            os.getenv("BENCHMARK_NAME"),
            os.getenv("SUBMIT_ID"),
            os.getenv("JOB_ID"),
        ),
        sshpublickey,
        datadisks=[
            VMDataDisk(
                size=50,
                disk_type="ssd",
                mount_path="/",
                filesystem="NTFS",
            )
        ],
    )
    atexit.register(clean_vm_atexit)
    signal.signal(signal.SIGTERM, lambda signum, _: sys.exit(signum))
    VM_IP = vmclient.wait_until_vm_running(VM_ID)
    logger.info("vm created successfully, vm_ip: %s", VM_IP)

    def sut_startup():
        with Connection(
            VM_IP,
            "administrator",
            connect_kwargs=CONNECT_KWARGS,
        ) as c:
            script_path = "E:\\base\\asr\\faster-whisper\\server"
            script_path = "E:\\install\\asr\\sensevoice\\server"
            bat_filepath = f"{script_path}\\start.bat"
            config_filepath = "E:\\submit\\config.json"
            result = c.run("")
            assert result.ok
            c.run(
                f'cd /d {script_path} & set "EDGE_ML_ENV_HOME=E:\\install" & {bat_filepath} {config_filepath}',
                warn=True,
            )

    with Connection(
        VM_IP,
        "administrator",
        connect_kwargs=CONNECT_KWARGS,
    ) as c:
        model_filepath = os.path.join(MODEL_BASEPATH, MODEL_MAPPING[model])
        filename = os.path.basename(model_filepath)
        put_file_to_vm(c, model_filepath, "/E:/")

        result = c.run("mkdir E:\\base")
        assert result.ok
        result = c.run("mkdir E:\\model")
        assert result.ok
        result = c.run("mkdir E:\\submit")
        assert result.ok

        result = c.run(
            f"tar zxvf E:\\{filename} -C E:\\base --strip-components 1"
        )
        assert result.ok

        result = c.run("E:\\base\\setup-win.bat E:\\install")
        assert result.ok

        put_file_to_vm(c, config_path, "/E:/submit")
        put_file_to_vm(c, model_path, "/E:/model")
        result = c.run(
            f"tar zxvf E:\\model\\{os.path.basename(model_path)} -C E:\\model"
        )
        assert result.ok
    threading.Thread(target=sut_startup, daemon=True).start()
    time.sleep(60)

    return f"ws://{VM_IP}:{config['port']}"


def deploy_macos_sut():
    global VM_ID
    global VM_IP

    submit_config_filepath = os.getenv("SUBMIT_CONFIG_FILEPATH", "")
    with open(submit_config_filepath, "r") as fp:
        st_config = yaml.safe_load(fp)
    assert "model" in st_config, "model is not configured"
    assert "model_key" in st_config, "model_key is not configured"
    assert "config.json" in st_config, "config.json is not configured"
    nfs = st_config.get("leaderboard_options", {}).get("nfs", [])
    assert len(nfs) > 0, "nfs is not configured"
    assert st_config["model"] in MODEL_MAPPING, "submitted model is not among the available models"

    model = st_config["model"]
    model_key = st_config["model_key"]
    model_path = ""
    config = st_config["config.json"]
    exist = False
    for nfs_item in nfs:
        if nfs_item["name"] == model_key:
            exist = True
            if nfs_item["source"] == "ceph_customer":
                model_path = os.path.join(
                    "/tmp/customer",
                    nfs_item["srcRelativePath"],
                )
            else:
                model_path = os.path.join(
                    "/tmp/juicefs",
                    nfs_item["srcRelativePath"],
                )
            break
    if not exist:
        raise RuntimeError(f"nfs entry not found: name={model_key}")
    config_path = os.path.join(tempfile.mkdtemp(), "config.json")
    model_dir = os.path.basename(model_path).split(".")[0]

    vmclient = Client()
    with open(SSH_PUBLIC_KEY_FILE, "r") as fp:
        sshpublickey = fp.read().rstrip()
    VM_ID = vmclient.create_vm(
        "amd64",
        VMOS.macos12,
        VM_CPU,
        VM_MEM,
        "leaderboard-%s-submit-%s-job-%s"
        % (
            os.getenv("BENCHMARK_NAME"),
            os.getenv("SUBMIT_ID"),
            os.getenv("JOB_ID"),
        ),
        sshpublickey,
        datadisks=[
            VMDataDisk(
                size=50,
                disk_type="ssd",
                mount_path="/",
                filesystem="apfs",
            )
        ],
    )
    atexit.register(clean_vm_atexit)
    signal.signal(signal.SIGTERM, lambda signum, _: sys.exit(signum))
    VM_IP = vmclient.wait_until_vm_running(VM_ID)
    logger.info("vm created successfully, vm_ip: %s", VM_IP)

    with Connection(
        VM_IP,
        "admin",
        connect_kwargs=CONNECT_KWARGS,
    ) as c:
        result = c.run("ls -d /Volumes/data*")
        assert result.ok
        volume_path = result.stdout.strip()

    config["model_path"] = f"{volume_path}/model/{model_dir}"
    with open(config_path, "w") as fp:
        json.dump(config, fp, ensure_ascii=False, indent=4)

    def sut_startup():
        with Connection(
            VM_IP,
            "admin",
            connect_kwargs=CONNECT_KWARGS,
        ) as c:
            script_path = f"{volume_path}/install/asr/sensevoice/server"
            startsh = f"{script_path}/start.sh"
            config_filepath = f"{volume_path}/submit/config.json"
            c.run(
                f"cd {script_path} && sh {startsh} {config_filepath}",
                warn=True,
            )

    with Connection(
        VM_IP,
        "admin",
        connect_kwargs=CONNECT_KWARGS,
    ) as c:
        model_filepath = os.path.join(MODEL_BASEPATH, MODEL_MAPPING[model])
        filename = os.path.basename(model_filepath)
        put_file_to_vm(c, model_filepath, f"{volume_path}")

        result = c.run(f"mkdir {volume_path}/base")
        assert result.ok
        result = c.run(f"mkdir {volume_path}/model")
        assert result.ok
        result = c.run(f"mkdir {volume_path}/submit")
        assert result.ok

        result = c.run(
            f"tar zxvf {volume_path}/{filename} -C {volume_path}/base --strip-components 1"  # noqa: E501
        )
        assert result.ok

        result = c.run(
            f"sh {volume_path}/base/setup-mac.sh {volume_path}/install x64"
        )
        assert result.ok

        put_file_to_vm(c, config_path, f"{volume_path}/submit")
        put_file_to_vm(c, model_path, f"{volume_path}/model")
        result = c.run(
            f"tar zxvf {volume_path}/model/{os.path.basename(model_path)} -C {volume_path}/model"  # noqa: E501
        )
        assert result.ok
    threading.Thread(target=sut_startup, daemon=True).start()
    time.sleep(60)

    return f"ws://{VM_IP}:{config['port']}"


def get_sut_url_vm(vm_type: str):
    global VM_ID
    global VM_IP
    global do_deploy_chart

    do_deploy_chart = True
    # Bring up the SUT

    def check_job_failed():
        while True:
            time.sleep(30)
            if os.path.exists(SUT_SHARE_PUBLIC_FAIL):
                logger.error("there is a job failed in current submit")
                sys.exit(1)

    sut_url = ""
    threading.Thread(target=check_job_failed, daemon=True).start()
    if SHARE_SUT:
        time.sleep(10 * random.random())
        try:
            open(SUT_SHARE_LOCK, "x").close()
        except Exception:
            do_deploy_chart = False

    start_at = time.time()

    def file_last_updated_at(file: str):
        return os.stat(file).st_mtime if os.path.exists(file) else start_at

    if not do_deploy_chart:
        with open(SUT_SHARE_JOB_STATUS, "w") as f:
            f.write("waiting")
        while (
            time.time() - file_last_updated_at(SUT_SHARE_STATUS)
            <= 60 * 60 * 24
        ):
            logger.info(
                "Waiting sut application to be deployed by another job"
            )
            time.sleep(10 + random.random())
            if os.path.exists(SUT_SHARE_STATUS):
                get_status = False
                for _ in range(10):
                    try:
                        with open(SUT_SHARE_STATUS, "r") as f:
                            status = json.load(f)
                        get_status = True
                        break
                    except Exception:
                        time.sleep(1 + random.random())
                        continue
                if not get_status:
                    raise RuntimeError(
                        "Failed to get status of sut application"
                    )
                assert (
                    status.get("status") != "failed"
                ), "Failed to deploy sut application, please check other job logs"
                if status.get("status") == "running":
                    VM_ID = status.get("vmid")
                    VM_IP = status.get("vmip")
                    sut_url = status.get("sut_url")
                    with open(SSH_PUBLIC_KEY_FILE, "w") as fp:
                        fp.write(status.get("pubkey"))
                    with open(SSH_KEY_FILE, "w") as fp:
                        fp.write(status.get("prikey"))
                    logger.info("Successfully got the deployed sut application")
                    break

    if do_deploy_chart:
        try:
            fcntl.flock(fd_lock, fcntl.LOCK_EX)
            with open(SUT_SHARE_JOB_STATUS, "w") as f:
                f.write("waiting")
            pending = True

            def update_status():
                while pending:
                    time.sleep(30)
                    if not pending:
                        break
                    with open(SUT_SHARE_STATUS, "w") as f:
                        json.dump({"status": "pending"}, f)

            threading.Thread(target=update_status, daemon=True).start()
            if vm_type == "windows":
                sut_url = deploy_windows_sut()
            else:
                sut_url = deploy_macos_sut()
        except Exception:
            open(SUT_SHARE_PUBLIC_FAIL, "w").close()
            with open(SUT_SHARE_STATUS, "w") as f:
                json.dump({"status": "failed"}, f)
            raise
        finally:
            pending = False
        with open(SUT_SHARE_STATUS, "w") as f:
            pubkey = ""
            with open(SSH_PUBLIC_KEY_FILE, "r") as fp:
                pubkey = fp.read().rstrip()
            prikey = ""
            with open(SSH_KEY_FILE, "r") as fp:
                prikey = fp.read()
            json.dump(
                {
                    "status": "running",
                    "vmid": VM_ID,
                    "vmip": VM_IP,
                    "pubkey": pubkey,
                    "sut_url": sut_url,
                    "prikey": prikey,
                },
                f,
            )
    else:
        while True:
            time.sleep(5 + random.random())
            try:
                fcntl.flock(fd_lock, fcntl.LOCK_EX | fcntl.LOCK_NB)
                break
            except Exception:
                logger.info("Failed to grab the SUT usage lock, waiting another 5s ...")

    with open(SUT_SHARE_JOB_STATUS, "w") as f:
        f.write("running")

    return sut_url


def get_sut_url():
    if SUT_TYPE in ("windows", "macos"):
        return get_sut_url_vm(SUT_TYPE)

    submit_config_filepath = os.getenv(
        "SUBMIT_CONFIG_FILEPATH", "./tests/resources/submit_config"
    )
    CPU = os.getenv("SUT_CPU", "2")
    MEMORY = os.getenv("SUT_MEMORY", "4Gi")
    resource_name = os.getenv("BENCHMARK_NAME")

    # Task info
    # Slavic languages: Russian, Polish
    # Germanic languages: English, German, Dutch
    # Romance languages: Spanish, Portuguese, French, Italian
    # Semitic languages: Arabic, Hebrew

    # Submission config & start the system under test
    if os.getenv("DATASET_FILEPATH", ""):
        with open(submit_config_filepath, "r") as fp:
            st_config = yaml.safe_load(fp)
        if "values" not in st_config:
            st_config["values"] = {}
        st_config["values"]["resources"] = {}
        st_config["values"]["resources"]["limits"] = {}
        st_config["values"]["resources"]["limits"]["cpu"] = CPU
        st_config["values"]["resources"]["limits"]["memory"] = MEMORY
        # st_config["values"]['resources']['limits']['nvidia.com/gpu'] = '1'
        # st_config["values"]['resources']['limits']['nvidia.com/gpumem'] = "1843"
        # st_config["values"]['resources']['limits']['nvidia.com/gpucores'] = "8"
        st_config["values"]["resources"]["requests"] = {}
        st_config["values"]["resources"]["requests"]["cpu"] = CPU
        st_config["values"]["resources"]["requests"]["memory"] = MEMORY
        # st_config["values"]['resources']['requests']['nvidia.com/gpu'] = '1'
        # st_config["values"]['resources']['requests']['nvidia.com/gpumem'] = "1843"
        # st_config["values"]['resources']['requests']['nvidia.com/gpucores'] = "8"
        # st_config['values']['nodeSelector'] = {}
        # st_config["values"]["nodeSelector"][
        #     "contest.4pd.io/accelerator"
        # ] = "A10vgpu"
        # st_config['values']['tolerations'] = []
        # toleration_item = {}
        # toleration_item['key'] = 'hosttype'
        # toleration_item['operator'] = 'Equal'
        # toleration_item['value'] = 'vgpu'
        # toleration_item['effect'] = 'NoSchedule'
        # st_config['values']['tolerations'].append(toleration_item)
        if os.getenv("RESOURCE_TYPE", "cpu") == "cpu":
            values = st_config["values"]
            limits = values.get("resources", {}).get("limits", {})
            requests = values.get("resources", {}).get("requests", {})
            if (
                "nvidia.com/gpu" in limits
                or "nvidia.com/gpumem" in limits
                or "nvidia.com/gpucores" in limits
                or "nvidia.com/gpu" in requests
                or "nvidia.com/gpumem" in requests
                or "nvidia.com/gpucores" in requests
            ):
                raise Exception("GPU usage is forbidden!")
        else:
            vgpu_num = int(os.getenv("SUT_VGPU", "3"))
            st_config["values"]["resources"]["limits"]["nvidia.com/gpu"] = (
                str(vgpu_num)
            )
            st_config["values"]["resources"]["limits"][
                "nvidia.com/gpumem"
            ] = str(1843 * vgpu_num)
            st_config["values"]["resources"]["limits"][
                "nvidia.com/gpucores"
            ] = str(8 * vgpu_num)
            st_config["values"]["resources"]["requests"][
                "nvidia.com/gpu"
            ] = str(vgpu_num)
            st_config["values"]["resources"]["requests"][
                "nvidia.com/gpumem"
            ] = str(1843 * vgpu_num)
            st_config["values"]["resources"]["requests"][
                "nvidia.com/gpucores"
            ] = str(8 * vgpu_num)
            st_config["values"]["nodeSelector"] = {}
            st_config["values"]["nodeSelector"][
                "contest.4pd.io/accelerator"
            ] = "A10vgpu"
            st_config["values"]["tolerations"] = []
            toleration_item = {}
            toleration_item["key"] = "hosttype"
            toleration_item["operator"] = "Equal"
            toleration_item["value"] = "vgpu"
            toleration_item["effect"] = "NoSchedule"
            st_config["values"]["tolerations"].append(toleration_item)
        if "docker_images" in st_config:
            sut_url = "ws://172.26.1.75:9827"
            os.environ["test"] = "1"
        elif "docker_image" in st_config:
            sut_url = register_sut(st_config, resource_name)
        elif UNIT_TEST:
            sut_url = "ws://172.27.231.36:80"
        else:
            logger.error("Invalid config: docker_image is missing")
            os._exit(1)
        return sut_url
    else:
        os.environ["test"] = "1"
        sut_url = "ws://172.27.231.36:80"
        sut_url = "ws://172.26.1.75:9827"
        return sut_url


def load_merge_dataset(dataset_filepath: str) -> dict:
    local_dataset_path = "./dataset"
    os.makedirs(local_dataset_path, exist_ok=True)
    with zipfile.ZipFile(dataset_filepath) as zf:
        zf.extractall(local_dataset_path)

    config = {}
    sub_datasets = os.listdir(local_dataset_path)
    for sub_dataset in sub_datasets:
        if sub_dataset.startswith("asr."):
            lang = sub_dataset[4:]
            lang_path = os.path.join(local_dataset_path, lang)
            os.makedirs(lang_path, exist_ok=True)
            with zipfile.ZipFile(
                os.path.join(local_dataset_path, sub_dataset)
            ) as zf:
                zf.extractall(lang_path)
            lang_config_path = os.path.join(lang_path, "data.yaml")
            with open(lang_config_path, "r") as fp:
                lang_config = yaml.safe_load(fp)
            audio_lengths = {}
            for query_item in lang_config.get("query_data", []):
                audio_path = os.path.join(
                    lang_path,
                    query_item["file"],
                )
                query_item["file"] = audio_path
                audio_lengths[query_item["file"]] = os.path.getsize(
                    audio_path,
                )
            lang_config["query_data"] = sorted(
                lang_config.get("query_data", []),
                key=lambda x: audio_lengths[x["file"]],
                reverse=True,
            )

            idx = 0
            length = 0.0
            for query_item in lang_config["query_data"]:
                audio_length = audio_lengths[query_item["file"]]
                length += audio_length / 32000
                idx += 1
                # Cap each language at half an hour of audio
                if length >= 30 * 60:
                    break

            lang_config["query_data"] = lang_config["query_data"][:idx]
            config[lang] = lang_config

    config["query_data"] = []
    for lang, lang_config in config.items():
        if lang == "query_data":
            continue
        for query_item in lang_config["query_data"]:
            config["query_data"].append(
                {
                    **query_item,
                    "lang": lang,
                }
            )
    random.Random(0).shuffle(config["query_data"])

    return config


def postprocess_failed():
    open(SUT_SHARE_PUBLIC_FAIL, "w").close()


def main():
    dataset_filepath = os.getenv(
        "DATASET_FILEPATH",
        "/Users/4paradigm/Projects/dataset/asr/de.zip",
        # "./tests/resources/en.zip",
    )
    result_filepath = os.getenv("RESULT_FILEPATH", "./out/result")
    bad_cases_filepath = os.getenv("BAD_CASES_FILEPATH", "./out/badcase")
    detail_cases_filepath = os.getenv(
        "DETAILED_CASES_FILEPATH", "./out/detailcase.jsonl"
    )
    thread_num = int(os.getenv("THREAD_NUM", "1"))

    # Dataset handling
    config = {}
    if os.getenv("MERGE_DATASET", "1"):
        config = load_merge_dataset(dataset_filepath)
        dataset_query = config["query_data"]
    else:
        local_dataset_path = "./dataset"
        os.makedirs(local_dataset_path, exist_ok=True)
        with zipfile.ZipFile(dataset_filepath) as zf:
            zf.extractall(local_dataset_path)
        config_path = os.path.join(local_dataset_path, "data.yaml")
        with open(config_path, "r") as fp:
            dataset_config = yaml.safe_load(fp)
        # Read every audio file to get its total length, then sort query_data by audio length, descending
        lang = os.getenv("lang")
        if lang is None:
            lang = dataset_config.get("global", {}).get("lang", "en")
        audio_lengths = []
        for query_item in dataset_config.get("query_data", []):
            query_item["lang"] = lang
            audio_path = os.path.join(local_dataset_path, query_item["file"])
            query_item["file"] = audio_path
            audio_lengths.append(os.path.getsize(audio_path) / 1024 / 1024)
        dataset_config["query_data"] = sorted(
            dataset_config.get("query_data", []),
            key=lambda x: audio_lengths[dataset_config["query_data"].index(x)],
            reverse=True,
        )
        # Dataset info
        # dataset_global_config = dataset_config.get("global", {})
        dataset_query = dataset_config.get("query_data", {})
        config[lang] = dataset_config

    # sut url
    sut_url = get_sut_url()

    try:
        # Start the test
        logger.info("Run started")
        evaluator = BaseEvaluator()
        future_list = []
        with ThreadPoolExecutor(max_workers=thread_num) as executor:
            for idx, query_item in enumerate(dataset_query):
                context = ASRContext(
                    **config[query_item["lang"]].get("global", {}),
                )
                context.lang = query_item["lang"]
                context.file_path = query_item["file"]
                context.append_labels(query_item["voice"])
                future = executor.submit(
                    ClientAsync(sut_url, context, idx).action
                )
                future_list.append(future)
            for future in concurrent.futures.as_completed(future_list):
                context = future.result()
                evaluator.evaluate(context)
                detail_case = evaluator.gen_detail_case()
                with open(detail_cases_filepath, "a") as fp:
                    fp.write(
                        json.dumps(
                            detail_case.to_dict(),
                            ensure_ascii=False,
                        )
                        + "\n",
                    )
                del context
                gc.collect()

        evaluator.post_evaluate()
        output_result = evaluator.gen_result()
        logger.info("Run finished")

        with open(result_filepath, "w") as fp:
            json.dump(output_result, fp, indent=2, ensure_ascii=False)
        with open(bad_cases_filepath, "w") as fp:
            fp.write("No bad cases for this leaderboard\n")

        if SHARE_SUT:
            with open(SUT_SHARE_JOB_STATUS, "w") as f:
                f.write("success")

        fcntl.flock(fd_lock, fcntl.LOCK_UN)
        fd_lock.close()
        while SHARE_SUT and do_deploy_chart:
            time.sleep(30)
            success_num = 0
            for job_status_file in glob.glob(dirname + "/job_status.*"):
                with open(job_status_file, "r") as f:
                    job_status = f.read()
                success_num += job_status == "success"
            if success_num == int(DATASET_NUM):
                break
            logger.info("Waiting for all jobs to complete")
    except Exception:
        if SHARE_SUT:
            postprocess_failed()
        raise
    sys.exit(0)


if __name__ == "__main__":
    main()
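The shared-SUT handshake above combines two POSIX primitives: `open(SUT_SHARE_LOCK, "x")` acts as a create-once election (the first job to create the file becomes the deployer), while `fcntl.flock` on the use-lock serializes which job may drive the SUT at any moment. A minimal, self-contained sketch of the same pattern; the directory, file names, and `work` callback here are hypothetical, not the leaderboard's layout:

```python
import fcntl
import os
import time

LOCK_DIR = "/tmp/demo_share"  # hypothetical directory, not the leaderboard path
os.makedirs(LOCK_DIR, exist_ok=True)


def elect_deployer() -> bool:
    """First process to create the marker file wins; mode "x" makes it atomic."""
    try:
        open(os.path.join(LOCK_DIR, "lock.lock"), "x").close()
        return True   # we deploy the SUT
    except FileExistsError:
        return False  # someone else is deploying; we wait and reuse


def run_exclusively(fd, work):
    """Serialize SUT usage across jobs with an advisory flock."""
    while True:
        try:
            fcntl.flock(fd, fcntl.LOCK_EX | fcntl.LOCK_NB)
            break
        except BlockingIOError:
            time.sleep(5)  # lock is held by another job; retry
    try:
        work()
    finally:
        fcntl.flock(fd, fcntl.LOCK_UN)


if __name__ == "__main__":
    deployer = elect_deployer()
    fd = open(os.path.join(LOCK_DIR, "use.lock"), "a")
    run_exclusively(fd, lambda: print("deployer" if deployer else "reuser", "ran"))
```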
@@ -1,90 +0,0 @@
import os
from copy import deepcopy
from typing import Dict, List, Optional

from pydantic import BaseModel, Field

from schemas.stream import StreamDataModel


class LabelContext(BaseModel):
    start: float
    end: float
    answer: str


class PredContext(BaseModel):
    recognition_results: StreamDataModel
    recv_time: Optional[float] = Field(None)
    send_time: Optional[float] = Field(None)


class ASRContext:
    def __init__(self, **kwargs):
        self.bits = kwargs.get("bits", 16)
        self.channel = kwargs.get("channel", 1)
        self.sample_rate = kwargs.get("sample_rate", 16000)
        self.audio_format = kwargs.get("format", "wav")
        self.enable_words = kwargs.get("enable_words", True)
        self.char_contains_rate = kwargs.get("char_contains_rate", 0.8)
        self.lang = os.getenv("lang")
        if self.lang is None:
            self.lang = kwargs.get("lang", "en")
        self.stream = kwargs.get("stream", True)

        self.wait_time = float(os.getenv("wait_time", 0.1))
        self.chunk_size = self.sample_rate * self.bits / 8 * self.wait_time
        if int(os.getenv("chunk_size_set", 0)):
            self.chunk_size = int(os.getenv("chunk_size_set", 0))

        self.audio_length = 0
        self.file_path = ""

        self.labels: List[LabelContext] = kwargs.get("labels", [])
        self.preds: List[PredContext] = kwargs.get("preds", [])

        self.label_sentences: List[str] = []
        self.pred_sentences: List[str] = []

        self.send_time_start_end = []
        self.recv_time_start_end = []

        self.fail = False
        self.fail_char_contains_rate_num = 0

        self.punctuation_num = 0
        self.pred_punctuation_num = 0

    def append_labels(self, voices: List[Dict]):
        for voice_data in voices:
            label_context = LabelContext(**voice_data)
            self.labels.append(label_context)

    def append_preds(
        self,
        predict_data: List[StreamDataModel],
        send_time: List[float],
        recv_time: List[float],
    ):
        self.send_time_start_end = [send_time[0], send_time[-1]] if len(send_time) > 0 else []
        self.recv_time_start_end = [recv_time[0], recv_time[-1]] if len(recv_time) > 0 else []
        for pred_item, send_time_item, recv_time_item in zip(predict_data, send_time, recv_time):
            pred_item = deepcopy(pred_item)
            pred_context = PredContext(recognition_results=pred_item.model_dump())
            pred_context.send_time = send_time_item
            pred_context.recv_time = recv_time_item
            self.preds.append(pred_context)

    def to_dict(self):
        return {
            "bits": self.bits,
            "channel": self.channel,
            "sample_rate": self.sample_rate,
            "audio_format": self.audio_format,
            "enable_words": self.enable_words,
            "stream": self.stream,
            "wait_time": self.wait_time,
            "chunk_size": self.chunk_size,
            "labels": [item.model_dump_json() for item in self.labels],
            "preds": [item.model_dump_json() for item in self.preds],
        }
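With the defaults above (16 kHz sample rate, 16-bit mono, `wait_time` of 0.1 s), each streamed chunk carries 100 ms of audio: 16000 * 16 / 8 = 32000 bytes per second, the same figure `load_merge_dataset` divides by to turn file sizes into seconds. A quick sanity check of the formula:

```python
# Bytes per streamed chunk: sample_rate * (bits / 8) * wait_time
sample_rate, bits, wait_time = 16000, 16, 0.1
bytes_per_second = sample_rate * bits / 8      # 32000, as used in load_merge_dataset
chunk_size = bytes_per_second * wait_time
assert chunk_size == 3200.0                    # 100 ms of 16-bit mono PCM at 16 kHz
```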
@@ -1,18 +0,0 @@
from typing import List

from pydantic import BaseModel, Field


class QueryDataSentence(BaseModel):
    answer: str = Field(description="text label")
    start: float = Field(description="sentence start time")
    end: float = Field(description="sentence end time")


class QueryData(BaseModel):
    lang: str = Field(description="language")
    file: str = Field(description="audio file location")
    duration: float = Field(description="audio length")
    voice: List[QueryDataSentence] = Field(
        description="text labels for the audio file"
    )
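For reference, a query entry that validates against this schema; the values and file path are illustrative, and the snippet assumes the schema above is importable as `schemas.dataset`:

```python
from schemas.dataset import QueryData  # assumes the schema file above is on the path

item = QueryData.model_validate({
    "lang": "de",
    "file": "audio/0001.mp3",  # illustrative path
    "duration": 2.0,
    "voice": [{"answer": "hallo welt", "start": 0.0, "end": 1.0}],
})
assert item.voice[0].end == 1.0
```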
@@ -1,66 +0,0 @@
from typing import List

from pydantic import BaseModel, field_validator, model_validator


class StreamWordsModel(BaseModel):
    text: str
    start_time: float
    end_time: float

    @model_validator(mode="after")
    def check_result(self):
        if self.end_time < self.start_time:
            raise ValueError("end_time is smaller than start_time")
        return self


class StreamDataModel(BaseModel):
    text: str
    language: str
    final_result: bool
    para_seq: int
    start_time: float
    end_time: float
    words: List[StreamWordsModel]

    @model_validator(mode="after")
    def check_result(self):
        if self.end_time < self.start_time:
            raise ValueError("end_time is smaller than start_time")
        return self


class StreamResultModel(BaseModel):
    asr_results: StreamDataModel

    @field_validator("asr_results", mode="after")
    def convert_to_seconds(cls, v: StreamDataModel):
        # Convert millisecond timestamps to seconds
        v.end_time = v.end_time / 1000
        v.start_time = v.start_time / 1000
        for word in v.words:
            word.start_time /= 1000
            word.end_time /= 1000
        return v

    class Config:
        validate_assignment = True


class NonStreamDataModel(BaseModel):
    text: str
    para_seq: int
    start_time: float
    end_time: float

    @model_validator(mode="after")
    def check_result(self):
        if self.end_time < self.start_time:
            raise ValueError("end_time is smaller than start_time")
        return self


class NonStreamResultModel(BaseModel):
    contents: List[NonStreamDataModel]
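The field validator on `StreamResultModel` lets the SUT report timestamps in milliseconds while the evaluator reads them back in seconds. A minimal sketch of that round-trip, assuming the schemas above are importable as `schemas.stream`:

```python
from schemas.stream import StreamResultModel  # assumes the schema file above is on the path

payload = {
    "asr_results": {
        "text": "hello", "language": "en", "final_result": True,
        "para_seq": 0, "start_time": 6300, "end_time": 6421,
        "words": [{"text": "hello", "start_time": 6300, "end_time": 6421}],
    }
}
result = StreamResultModel(**payload)
assert result.asr_results.start_time == 6.3         # converted from ms to s
assert result.asr_results.words[0].end_time == 6.421
```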
@@ -1,53 +0,0 @@
import os
import sys
from collections import defaultdict

import yaml


def main(dataset_dir):
    dirs = os.listdir(dataset_dir)
    dirs = list(
        filter(lambda x: os.path.isdir(os.path.join(dataset_dir, x)), dirs)
    )

    problem_dirs = set()
    problem_count = defaultdict(int)
    for dir in dirs:
        with open(os.path.join(dataset_dir, dir, "data.yaml"), "r") as f:
            data = yaml.full_load(f)
        for query_i, query in enumerate(data["query_data"]):
            voices = sorted(query["voice"], key=lambda x: x["start"])
            if voices != query["voice"]:
                print("-----", dir)
            if voices[0]["start"] > voices[0]["end"]:
                print(
                    "err1: %s query %s voice %d has start greater than end: %s"
                    % (dir, query_i, 0, voices[0]["answer"])
                )
                problem_dirs.add(dir)
            for voice_i in range(1, len(voices)):
                voice = voices[voice_i]
                if voice["start"] > voice["end"]:
                    print(
                        "err1: %s query %s voice %d has start greater than end: %s"
                        % (dir, query_i, voice_i, voice["answer"])
                    )
                    problem_dirs.add(dir)
                if voice["start"] < voices[voice_i - 1]["end"]:
                    print(
                        "err2: %s query %s voice %d starts before the previous voice ends: %s"
                        % (dir, query_i, voice_i, voice["answer"])
                    )
                    problem_dirs.add(dir)
                    problem_count[dir] += 1
    print(len(dirs))
    print(problem_dirs)
    print(problem_count)


if __name__ == "__main__":
    if len(sys.argv) < 2:
        print("Specify the test dataset folder")
        sys.exit(1)
    main(sys.argv[1])
@@ -1,108 +0,0 @@
import json
import os
import shutil
import sys
import zipfile

import yaml

"""
target
{
    "global": {
        "lang": ""
    },
    "query_data": [
        {
            "file": "",
            "duration": 2.0,
            "voice": [
                {
                    "answer": "",
                    "start": 0.0,
                    "end": 1.0
                }
            ]
        }
    ]
}
"""


def situation_a(meta, dataset_folder, output_folder):
    """
    {
        "combined": {
            "en": [
                {
                    "wav": "*.wav",
                    "transcriptions": [
                        {
                            "text": "",
                            "start": 0.0,
                            "end": 1.0
                        }
                    ],
                    "duration": 2.0
                }
            ]
        }
    }
    """
    meta = meta["combined"]

    for lang, arr in meta.items():
        print("processing", lang)
        assert len(lang) == 2
        lang_folder = os.path.join(output_folder, lang)
        os.makedirs(lang_folder, exist_ok=True)
        data = {"global": {"lang": lang}, "query_data": []}
        query_data = data["query_data"]
        for item in arr:
            os.makedirs(
                os.path.join(lang_folder, os.path.dirname(item["wav"])),
                exist_ok=True,
            )
            mp3_file = item["wav"][:-4] + ".mp3"
            shutil.copyfile(
                os.path.join(dataset_folder, mp3_file),
                os.path.join(lang_folder, mp3_file),
            )
            query_data_item = {
                "file": mp3_file,
                "duration": float(item["duration"]),
                "voice": [],
            }
            query_data.append(query_data_item)
            voice = query_data_item["voice"]
            for v in item["transcriptions"]:
                voice.append(
                    {
                        "answer": v["text"],
                        "start": float(v["start"]),
                        "end": float(v["end"]),
                    }
                )
        with open(os.path.join(lang_folder, "data.yaml"), "w") as f:
            yaml.dump(data, f, indent=2, allow_unicode=True, encoding="utf-8")
        with zipfile.ZipFile(
            os.path.join(output_folder, lang + ".zip"), "w"
        ) as ziper:
            dirname = lang_folder
            for path, _, files in os.walk(dirname):
                for file in files:
                    ziper.write(
                        os.path.join(path, file),
                        os.path.join(path[len(dirname):], file),
                        zipfile.ZIP_DEFLATED,
                    )


if __name__ == "__main__":
    if len(sys.argv) < 3:
        print("Specify the dataset folder path and the output path")
        sys.exit(1)
    dataset_folder = sys.argv[1]
    output_folder = sys.argv[2]

    with open(os.path.join(dataset_folder, "meta.json")) as f:
        meta = json.load(f)
    situation_a(meta, dataset_folder, output_folder)
@@ -1,56 +0,0 @@
import json
import sys

from schemas.dataset import QueryData
from schemas.stream import StreamDataModel
from utils.evaluator_plus import evaluate_editops


def main(detailcase_file: str):
    with open(detailcase_file) as f:
        d = json.load(f)[0]
    preds = d["preds"]
    preds = list(map(lambda x: StreamDataModel(**x), preds))
    preds = list(filter(lambda x: x.final_result, preds))
    label = d["label"]
    label = QueryData(**label)
    print(evaluate_editops(label, preds))


def evaluate_from_record(detailcase_file: str, record_path: str):
    with open(detailcase_file) as f:
        d = json.load(f)[0]
    label = d["label"]
    label = QueryData(**label)
    with open(record_path) as f:
        record = json.load(f)
    tokens_pred = record["tokens_pred"]
    tokens_label = record["tokens_label"]
    recognition_results = record["recognition_results"]
    recognition_results = list(
        map(lambda x: StreamDataModel(**x), recognition_results)
    )
    a, b = [], []
    for i, rr in enumerate(recognition_results):
        if rr.final_result:
            a.append(tokens_pred[i])
            b.append(rr)
    tokens_pred = a
    recognition_results = b

    print(
        evaluate_editops(
            label,
            recognition_results,
            tokens_pred,
            tokens_label,
        )
    )


if __name__ == "__main__":
    if len(sys.argv) < 2:
        print("Please specify the detailcase file path")
        sys.exit(1)
    main(sys.argv[1])
    # evaluate_from_record(sys.argv[1], sys.argv[2])
BIN
ssh-keygen
BIN
ssh-keygen
Binary file not shown.
@@ -1,11 +0,0 @@
FROM harbor.4pd.io/inf/base-python3.8-ubuntu:1.1.0

WORKDIR /workspace

ADD ./requirements.txt /workspace
RUN pip install -r ./requirements.txt -i https://nexus.4pd.io/repository/pypi-all/simple --trusted-host nexus.4pd.io --extra-index-url https://mirrors.aliyun.com/pypi/simple/ \
    && pip cache purge

ADD . /workspace

CMD ["python", "main.py"]
@@ -1,313 +0,0 @@
import logging
import os
import threading
import time
from typing import Optional

import flask
import requests
from werkzeug.datastructures import FileStorage

app = flask.Flask(__name__)
heartbeat_active = False

log = logging.getLogger(__name__)

log.propagate = False

level = logging.INFO

log.setLevel(level)

formatter = logging.Formatter(
    "[%(asctime)s] %(levelname)s : %(pathname)s:%(lineno)d - %(message)s",
    "%Y-%m-%d %H:%M:%S",
)

streamHandler = logging.StreamHandler()
streamHandler.setLevel(level)
streamHandler.setFormatter(formatter)
log.addHandler(streamHandler)


def heartbeat(url):
    global heartbeat_active
    if heartbeat_active:
        return
    heartbeat_active = True
    while True:
        try:
            requests.post(url, json={"status": "RUNNING"})
        except Exception:
            pass
        time.sleep(10)


def asr(
    audio_file: FileStorage,
    language: Optional[str],
    progressCallbackUrl: str,
    taskId: str,
):
    """TODO: read audio_file, call the speech-recognition service, and stream back recognition results in real time."""

    # ignore BEGIN
    # Used for local leaderboard testing only
    if os.getenv("LOCAL_TEST"):
        return local_test(progressCallbackUrl, taskId)
    # ignore END

    language = "de"
    # A single recognition callback
    requests.post(
        progressCallbackUrl,
        json={
            "taskId": taskId,
            "status": "RUNNING",
            # Send incremental results; if status is FINISHED or ERROR, omit this field
            "recognition_results": {
                "text": "最先启动的还是",
                "final_result": True,
                "para_seq": 0,
                "language": language,
                "start_time": 6300,
                "end_time": 6421,
                "words": [
                    {"text": "最", "start_time": 6300, "end_time": 6321},
                    {"text": "先", "start_time": 6321, "end_time": 6345},
                    {"text": "启", "start_time": 6345, "end_time": 6350},
                    {"text": "动", "start_time": 6350, "end_time": 6370},
                    {"text": "的", "start_time": 6370, "end_time": 6386},
                    {"text": "还", "start_time": 6386, "end_time": 6421},
                    {"text": "是", "start_time": 6421, "end_time": 6435},
                ],
            },
        },
    )
    # ... done returning recognition results

    # Recognition finished
    requests.post(
        progressCallbackUrl,
        json={
            "taskId": taskId,
            "status": "FINISHED",
        },
    )


@app.post("/predict")
def predict():
    body = flask.request.form
    language = body.get("language")
    if language is None:
        pass  # detect the language yourself
    taskId = body["taskId"]
    progressCallbackUrl = body["progressCallbackUrl"]
    heartbeatUrl = body["heartbeatUrl"]

    threading.Thread(
        target=heartbeat, args=(heartbeatUrl,), daemon=True
    ).start()

    audio_file = flask.request.files["file"]
    # audio_file.stream  # read the file stream
    # audio_file.save("audio.mp3")  # save the file
    threading.Thread(
        target=asr,
        args=(audio_file, language, progressCallbackUrl, taskId),
        daemon=True,
    ).start()
    return flask.jsonify({"status": "OK"})


# ignore BEGIN
def local_test(progressCallbackUrl: str, taskId: str):
    """Ignore this function; it is only for local leaderboard debugging."""
    import random
    import re

    import yaml

    def callback(content):
        try:
            if content is None:
                requests.post(
                    progressCallbackUrl,
                    json={"taskId": taskId, "status": "FINISHED"},
                )
            else:
                requests.post(
                    progressCallbackUrl,
                    json={
                        "taskId": taskId,
                        "status": "RUNNING",
                        "recognition_results": content,
                    },
                )
        except Exception:
            pass

    with open(
        os.getenv("LOCAL_TEST_DATA_PATH", "../dataset/out/data.yaml")
    ) as f:
        data = yaml.full_load(f)

    voices = data["query_data"][0]["voice"]

    # First send
    first_send_time = random.randint(3, 5)
    send_interval = random.random() * 0
    log.info("first send in %ss, send interval %ss" % (first_send_time, send_interval))
    time.sleep(first_send_time)

    # Merge sentences together
    if random.random() < 0.3:
        log.info("Merging some sentences into single ones, at most 3 sentences per merge")
        rand_idx = 0
        rand_sep = [0, len(voices) - 1]
        while rand_sep[rand_idx] + 1 <= rand_sep[rand_idx + 1] - 1:
            rand_cursep = random.randint(
                rand_sep[rand_idx] + 1,
                min(rand_sep[rand_idx + 1] - 1, rand_sep[rand_idx] + 1 + 3),
            )
            rand_sep.insert(rand_idx + 1, rand_cursep)
            rand_idx += 1
        merged_voices = []
        for i, cur_sep in enumerate(rand_sep[:-1]):
            voice = voices[cur_sep]
            for j in range(cur_sep + 1, rand_sep[i + 1]):
                voice["answer"] += voices[j]["answer"]
                voice["end"] = voices[j]["end"]
            merged_voices.append(voice)
        merged_voices.append(voices[rand_sep[-1]])
        voices = merged_voices

    def split_and_keep(text, delimiters):
        # Build a regex pattern that matches either text or a delimiter
        pattern = "|".join(re.escape(delimiter) for delimiter in delimiters)
        pattern = f"(?:[^{pattern}]+|[{pattern}])"
        return re.findall(pattern, text)

    puncs = [",", ".", "?", "!", ";", ":"]

    para_seq = 0
    for voice in voices:
        answer: str = voice["answer"]
        start_time: float = voice["start"]
        end_time: float = voice["end"]
        words = split_and_keep(answer, puncs)
        temp_words = []
        for i, word in enumerate(words):
            if i > 0 and i < len(words) - 1 and random.random() < 0.15:
                log.info("Randomly dropping a word")
                continue
            temp_words.extend(word.split(" "))
        if len(temp_words) == 0:
            temp_words = words[0].split(" ")
        words = temp_words
        answer = " ".join(words)
        words = list(map(lambda x: x.strip(), words))
        words = list(filter(lambda x: len(x) > 0, words))

        # Spread the time evenly across the words
        words_withtime = []
        word_unittime = (end_time - start_time) / len(words)
        for i, word in enumerate(words):
            word_start = start_time + word_unittime * i
            word_end = word_start + word_unittime
            words_withtime.append(
                {
                    "text": word,
                    "start_time": word_start * 1000,
                    "end_time": word_end * 1000,
                }
            )

        # Extend the leading/trailing punctuation times onto the words; punctuation itself is instantaneous
        punc_at = 0
        while punc_at < len(words) and words[punc_at] in puncs:
            punc_at += 1
        if punc_at < len(words):
            words_withtime[punc_at]["start_time"] = words_withtime[0][
                "start_time"
            ]
        for i in range(0, punc_at):
            words_withtime[i]["start_time"] = words_withtime[0]["start_time"]
            words_withtime[i]["end_time"] = words_withtime[0]["start_time"]
        punc_at = len(words) - 1
        while punc_at >= 0 and words[punc_at] in puncs:
            punc_at -= 1
        if punc_at >= 0:
            words_withtime[punc_at]["end_time"] = words_withtime[-1]["end_time"]
        for i in range(punc_at + 1, len(words)):
            words_withtime[i]["start_time"] = (
                words_withtime[-1]["end_time"] + 0.1
            )
            words_withtime[i]["end_time"] = words_withtime[-1]["end_time"] + 0.1

        if random.random() < 0.4 and len(words_withtime) > 1:
            log.info("Sending one final_result=False update")
            rand_idx = random.randint(1, len(words_withtime) - 1)
            recognition_result = {
                "text": " ".join(
                    map(lambda x: x["text"], words_withtime[:rand_idx])
                ),
                "final_result": False,
                "para_seq": para_seq,
                "language": "de",
                "start_time": start_time * 1000,
                "end_time": end_time * 1000,
                "words": words_withtime[:rand_idx],
            }
            callback(recognition_result)

        recognition_result = {
            "text": answer,
            "final_result": True,
            "para_seq": para_seq,
            "language": "de",
            "start_time": start_time * 1000,
            "end_time": end_time * 1000,
            "words": words_withtime,
        }
        callback(recognition_result)
        para_seq += 1
        log.info("send %s" % para_seq)

        time.sleep(send_interval)

    callback(None)


# ignore END

if __name__ == "__main__":
    app.run(host="0.0.0.0", port=80)
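As the `/predict` handler above shows, the server takes a multipart form with `taskId`, `progressCallbackUrl`, `heartbeatUrl`, an optional `language`, and the audio under `file`, and returns immediately; results then stream back to the callback URL. A minimal client sketch against a locally running copy; the addresses and file name here are illustrative:

```python
import requests

resp = requests.post(
    "http://localhost:80/predict",  # illustrative address for the server above
    data={
        "taskId": "task-1",
        "language": "de",
        "progressCallbackUrl": "http://localhost:9000/progress",  # illustrative
        "heartbeatUrl": "http://localhost:9000/heartbeat",        # illustrative
    },
    files={"file": ("audio.mp3", open("audio.mp3", "rb"), "audio/mpeg")},
)
print(resp.json())  # {"status": "OK"}; recognition results arrive via the callback
```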
@@ -1,3 +0,0 @@
flask
requests
pyyaml
@@ -1,16 +0,0 @@
import json

from schemas.dataset import QueryData
from schemas.stream import StreamDataModel
from utils.evaluator_plus import evaluate_editops

with open("out/detail_cases.json") as f:
    detail_cases = json.load(f)

detail_case = detail_cases[0]
preds = []
for pred in detail_case["preds"]:
    preds.append(StreamDataModel.model_validate(pred))
label = QueryData.model_validate(detail_case["label"])

print(evaluate_editops(label, preds))
@@ -1,93 +0,0 @@
"""
f(a, b) computes the edit distance from a to b using the method from the previous ASR leaderboard.
g(a, b) computes the edit distance from a to b using the plain edit-distance algorithm.
test() cross-checks the two implementations against each other.
"""

import random
import string
from copy import deepcopy
from typing import List, Tuple

import Levenshtein


def mapping(gt: str, dt: str):
    return [i for i in gt], [i for i in dt]


def token_mapping(
    tokens_gt: List[str], tokens_dt: List[str]
) -> Tuple[List[str], List[str]]:
    arr1 = deepcopy(tokens_gt)
    arr2 = deepcopy(tokens_dt)
    operations = Levenshtein.editops(arr1, arr2)
    for op in operations[::-1]:
        if op[0] == "insert":
            arr1.insert(op[1], None)
        elif op[0] == "delete":
            arr2.insert(op[2], None)
    return arr1, arr2


def cer(tokens_gt_mapping: List[str], tokens_dt_mapping: List[str]):
    """Takes the two edit-distance-aligned token sequences and counts (replace, delete, insert)."""
    insert = sum(1 for item in tokens_gt_mapping if item is None)
    delete = sum(1 for item in tokens_dt_mapping if item is None)
    equal = sum(
        1
        for token_gt, token_dt in zip(tokens_gt_mapping, tokens_dt_mapping)
        if token_gt == token_dt
    )
    replace = len(tokens_gt_mapping) - insert - equal  # - delete
    return replace, delete, insert


def f(a, b):
    return cer(*token_mapping(*mapping(a, b)))


def raw(tokens_gt, tokens_dt):
    arr1 = deepcopy(tokens_gt)
    arr2 = deepcopy(tokens_dt)
    operations = Levenshtein.editops(arr1, arr2)
    insert = 0
    delete = 0
    replace = 0
    for op in operations:
        if op[0] == "insert":
            insert += 1
        if op[0] == "delete":
            delete += 1
        if op[0] == "replace":
            replace += 1
    return replace, delete, insert


def g(a, b):
    return raw(*mapping(a, b))


def check(a, b):
    ff = f(a, b)
    gg = g(a, b)
    if ff != gg:
        print(ff, gg)
    return ff == gg


def random_string(length):
    letters = string.ascii_lowercase
    return "".join(random.choice(letters) for i in range(length))


def test():
    for _ in range(10000):
        a = random_string(30)
        b = random_string(30)
        if not check(a, b):
            print(a, b)
            break


test()
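Worked through by hand under the definitions above: for `a = "abc"`, `b = "axc"` the only edit is one substitution, so both paths agree on `(replace, delete, insert) = (1, 0, 0)`. Once deletions appear the two can diverge, because `cer()` computes `replace = len(gt) - insert - equal` with the `- delete` term commented out, so a deleted position is also counted as a replacement; surfacing exactly this kind of discrepancy is what the cross-check is for. Assuming `f` and `g` from the file above are in scope:

```python
# One substitution: both implementations agree.
assert f("abc", "axc") == g("abc", "axc") == (1, 0, 0)

# One deletion: g counts (0 replace, 1 delete, 0 insert), while f also
# counts the deleted slot as a replace because `- delete` is commented out.
assert g("abc", "ab") == (0, 1, 0)
assert f("abc", "ab") == (1, 1, 0)
```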