update
This commit is contained in:
49
Dockerfile
Normal file
49
Dockerfile
Normal file
@@ -0,0 +1,49 @@
|
|||||||
|
|
||||||
|
# Image for the callback service: builds on the internal Python 3.8 / Ubuntu
# base, installs pynini plus everything in requirements.txt, and starts
# run_callback.py.
FROM harbor.4pd.io/inf/base-python3.8-ubuntu:1.1.0

# MAINTAINER is deprecated; the same contact is recorded as a label instead.
LABEL maintainer="shiguangchuan@4paradigm.com"

WORKDIR /workspace

# Ship the ssh-keygen binary from the build context; it is invoked further down.
COPY ssh-keygen /bin

# Install the prebuilt pynini wheel and delete the downloaded file in the same
# layer so it does not stay in the image. The file name passed to `rm` must
# match the downloaded wheel exactly.
RUN wget -q ftp://ftp.4pd.io/pub/pico/temp/pynini-2.1.6-cp38-cp38-manylinux_2_31_x86_64.whl \
    && pip install --no-cache-dir pynini-2.1.6-cp38-cp38-manylinux_2_31_x86_64.whl \
    && rm -f pynini-2.1.6-cp38-cp38-manylinux_2_31_x86_64.whl

# Copy only the dependency manifest first so the dependency layer below is
# reused from cache until requirements.txt itself changes.
COPY ./requirements.txt /workspace/

# Install Python dependencies from the internal index (Aliyun mirror as a
# fallback), drop pip's cache, and generate a passphrase-less ECDSA key pair.
# NOTE(review): the private key created here is stored in an image layer, so
# every container started from this image shares it - confirm this is intended.
RUN pip install -r ./requirements.txt -i https://nexus.4pd.io/repository/pypi-all/simple --trusted-host nexus.4pd.io --extra-index-url https://mirrors.aliyun.com/pypi/simple/ \
    && pip cache purge \
    && ssh-keygen -f /workspace/ssh-key-ecdsa -t ecdsa -b 521 -q -N ""

# Application source goes last because it changes most often.
COPY . /workspace

EXPOSE 80

CMD ["python3", "run_callback.py"]
|
||||||
|
|
||||||
|
|
||||||
|
###########################
|
||||||
|
## Dockerfile(更新后)
|
||||||
|
#FROM harbor.4pd.io/lab-platform/inf/python:3.9
|
||||||
|
|
||||||
|
#WORKDIR /app
|
||||||
|
|
||||||
|
## 安装依赖
|
||||||
|
##RUN pip install torch librosa flask
|
||||||
|
|
||||||
|
##RUN pip config set global.index-url https://mirrors.aliyun.com/pypi/simple/ && \
|
||||||
|
## pip cache purge && \
|
||||||
|
## pip --default-timeout=1000 install torch librosa flask
|
||||||
|
|
||||||
|
## 删除原来的 COPY pytorch_model.bin /app/
|
||||||
|
|
||||||
|
#COPY inference.py /app/
|
||||||
|
# 只需要复制启动脚本
|
||||||
|
|
||||||
|
#EXPOSE 80
|
||||||
|
|
||||||
|
#CMD ["python", "inference.py"]
|
||||||
|
####################
|
||||||
|
|
||||||
|
|
||||||
|
##############################更新0731#################################
|
||||||
|
|
||||||
|
|
||||||
Submodule asr-tco_image deleted from 8f9a14f472
6
config.yaml
Normal file
6
config.yaml
Normal file
@@ -0,0 +1,6 @@
|
|||||||
|
leaderboard_options:
|
||||||
|
nfs:
|
||||||
|
- name: sid_model
|
||||||
|
srcRelativePath: zhoushasha/models/image_models/apple_mobilevit-small
|
||||||
|
mountPoint: /model
|
||||||
|
source: ceph_customer
|
||||||
BIN
helm-chart/.DS_Store
vendored
Normal file
BIN
helm-chart/.DS_Store
vendored
Normal file
Binary file not shown.
77
helm-chart/README.md
Normal file
77
helm-chart/README.md
Normal file
@@ -0,0 +1,77 @@
|
|||||||
|
## judgeflow chart 的要求
|
||||||
|
|
||||||
|
### values.yaml 文件必须包含如下字段,并且模板中必须引用 values.yaml 中的如下字段
|
||||||
|
|
||||||
|
```
|
||||||
|
podLabels
|
||||||
|
env
|
||||||
|
volumeMounts
|
||||||
|
volumes
|
||||||
|
affinity
|
||||||
|
```
|
||||||
|
|
||||||
|
### values.yaml 文件必须在 volumeMounts 中声明如下卷
|
||||||
|
|
||||||
|
```
|
||||||
|
workspace
|
||||||
|
submit
|
||||||
|
datafile
|
||||||
|
```
|
||||||
|
|
||||||
|
## 被测服务(sut) chart 的要求
|
||||||
|
|
||||||
|
### values.yaml 文件必须包含如下字段,并且资源模板中必须引用 values.yaml 中的如下字段
|
||||||
|
|
||||||
|
```
|
||||||
|
podLabels
|
||||||
|
affinity
|
||||||
|
```
|
||||||
|
|
||||||
|
针对 podLabels 字段,values.yaml 中配置格式如下:
|
||||||
|
|
||||||
|
```
|
||||||
|
podLabels: {}
|
||||||
|
```
|
||||||
|
|
||||||
|
下面给出示例
|
||||||
|
|
||||||
|
podLabels
|
||||||
|
|
||||||
|
values.yaml

```
podLabels: {}
```
|
||||||
|
|
||||||
|
templates/deployment.yaml
|
||||||
|
|
||||||
|
```
|
||||||
|
metadata:
|
||||||
|
labels:
|
||||||
|
{{- with .Values.podLabels }}
|
||||||
|
{{- toYaml . | nindent 4 }}
|
||||||
|
{{- end }}
|
||||||
|
```
|
||||||
|
|
||||||
|
affinity
|
||||||
|
|
||||||
|
values.yaml
|
||||||
|
|
||||||
|
```
|
||||||
|
affinity: {}
|
||||||
|
```
|
||||||
|
|
||||||
|
templates/deployment.yaml
|
||||||
|
|
||||||
|
```
|
||||||
|
spec:
|
||||||
|
template:
|
||||||
|
spec:
|
||||||
|
{{- with .Values.affinity }}
|
||||||
|
affinity:
|
||||||
|
{{- toYaml . | nindent 8 }}
|
||||||
|
{{- end }}
|
||||||
|
```
|
||||||
|
|
||||||
|
### 如果需要在 sut 中使用共享存储,则 sut chart 的 values.yaml 也必须包含如下字段,且模板中必须引用 values.yaml 中的如下字段
|
||||||
|
|
||||||
|
```
|
||||||
|
volumeMounts
|
||||||
|
volumes
|
||||||
|
```
|
||||||
23
helm-chart/asr-tco/.helmignore
Normal file
23
helm-chart/asr-tco/.helmignore
Normal file
@@ -0,0 +1,23 @@
|
|||||||
|
# Patterns to ignore when building packages.
|
||||||
|
# This supports shell glob matching, relative path matching, and
|
||||||
|
# negation (prefixed with !). Only one pattern per line.
|
||||||
|
.DS_Store
|
||||||
|
# Common VCS dirs
|
||||||
|
.git/
|
||||||
|
.gitignore
|
||||||
|
.bzr/
|
||||||
|
.bzrignore
|
||||||
|
.hg/
|
||||||
|
.hgignore
|
||||||
|
.svn/
|
||||||
|
# Common backup files
|
||||||
|
*.swp
|
||||||
|
*.bak
|
||||||
|
*.tmp
|
||||||
|
*.orig
|
||||||
|
*~
|
||||||
|
# Various IDEs
|
||||||
|
.project
|
||||||
|
.idea/
|
||||||
|
*.tmproj
|
||||||
|
.vscode/
|
||||||
24
helm-chart/asr-tco/Chart.yaml.tmpl
Normal file
24
helm-chart/asr-tco/Chart.yaml.tmpl
Normal file
@@ -0,0 +1,24 @@
|
|||||||
|
apiVersion: v2
|
||||||
|
name: ${chartName}
|
||||||
|
description: Leaderboard judgeflow helm chart for demo
|
||||||
|
|
||||||
|
# A chart can be either an 'application' or a 'library' chart.
|
||||||
|
#
|
||||||
|
# Application charts are a collection of templates that can be packaged into versioned archives
|
||||||
|
# to be deployed.
|
||||||
|
#
|
||||||
|
# Library charts provide useful utilities or functions for the chart developer. They're included as
|
||||||
|
# a dependency of application charts to inject those utilities and functions into the rendering
|
||||||
|
# pipeline. Library charts do not define any templates and therefore cannot be deployed.
|
||||||
|
type: application
|
||||||
|
|
||||||
|
# This is the chart version. This version number should be incremented each time you make changes
|
||||||
|
# to the chart and its templates, including the app version.
|
||||||
|
# Versions are expected to follow Semantic Versioning (https://semver.org/)
|
||||||
|
version: ${version}
|
||||||
|
|
||||||
|
# This is the version number of the application being deployed. This version number should be
|
||||||
|
# incremented each time you make changes to the application. Versions are not expected to
|
||||||
|
# follow Semantic Versioning. They should reflect the version the application is using.
|
||||||
|
# It is recommended to use it with quotes.
|
||||||
|
appVersion: "${appVersion}"
|
||||||
62
helm-chart/asr-tco/templates/_helpers.tpl
Normal file
62
helm-chart/asr-tco/templates/_helpers.tpl
Normal file
@@ -0,0 +1,62 @@
|
|||||||
|
{{/*
|
||||||
|
Expand the name of the chart.
|
||||||
|
*/}}
|
||||||
|
{{- define "judgeflow.name" -}}
|
||||||
|
{{- default .Chart.Name .Values.nameOverride | trunc 63 | trimSuffix "-" }}
|
||||||
|
{{- end }}
|
||||||
|
|
||||||
|
{{/*
|
||||||
|
Create a default fully qualified app name.
|
||||||
|
We truncate at 63 chars because some Kubernetes name fields are limited to this (by the DNS naming spec).
|
||||||
|
If release name contains chart name it will be used as a full name.
|
||||||
|
*/}}
|
||||||
|
{{- define "judgeflow.fullname" -}}
|
||||||
|
{{- if .Values.fullnameOverride }}
|
||||||
|
{{- .Values.fullnameOverride | trunc 63 | trimSuffix "-" }}
|
||||||
|
{{- else }}
|
||||||
|
{{- $name := default .Chart.Name .Values.nameOverride }}
|
||||||
|
{{- if contains $name .Release.Name }}
|
||||||
|
{{- .Release.Name | trunc 63 | trimSuffix "-" }}
|
||||||
|
{{- else }}
|
||||||
|
{{- printf "%s-%s" .Release.Name $name | trunc 63 | trimSuffix "-" }}
|
||||||
|
{{- end }}
|
||||||
|
{{- end }}
|
||||||
|
{{- end }}
|
||||||
|
|
||||||
|
{{/*
|
||||||
|
Create chart name and version as used by the chart label.
|
||||||
|
*/}}
|
||||||
|
{{- define "judgeflow.chart" -}}
|
||||||
|
{{- printf "%s-%s" .Chart.Name .Chart.Version | replace "+" "_" | trunc 63 | trimSuffix "-" }}
|
||||||
|
{{- end }}
|
||||||
|
|
||||||
|
{{/*
|
||||||
|
Common labels
|
||||||
|
*/}}
|
||||||
|
{{- define "judgeflow.labels" -}}
|
||||||
|
helm.sh/chart: {{ include "judgeflow.chart" . }}
|
||||||
|
{{ include "judgeflow.selectorLabels" . }}
|
||||||
|
{{- if .Chart.AppVersion }}
|
||||||
|
app.kubernetes.io/version: {{ .Chart.AppVersion | quote }}
|
||||||
|
{{- end }}
|
||||||
|
app.kubernetes.io/managed-by: {{ .Release.Service }}
|
||||||
|
{{- end }}
|
||||||
|
|
||||||
|
{{/*
|
||||||
|
Selector labels
|
||||||
|
*/}}
|
||||||
|
{{- define "judgeflow.selectorLabels" -}}
|
||||||
|
app.kubernetes.io/name: {{ include "judgeflow.name" . }}
|
||||||
|
app.kubernetes.io/instance: {{ .Release.Name }}
|
||||||
|
{{- end }}
|
||||||
|
|
||||||
|
{{/*
|
||||||
|
Create the name of the service account to use
|
||||||
|
*/}}
|
||||||
|
{{- define "judgeflow.serviceAccountName" -}}
|
||||||
|
{{- if .Values.serviceAccount.create }}
|
||||||
|
{{- default (include "judgeflow.fullname" .) .Values.serviceAccount.name }}
|
||||||
|
{{- else }}
|
||||||
|
{{- default "default" .Values.serviceAccount.name }}
|
||||||
|
{{- end }}
|
||||||
|
{{- end }}
|
||||||
32
helm-chart/asr-tco/templates/hpa.yaml
Normal file
32
helm-chart/asr-tco/templates/hpa.yaml
Normal file
@@ -0,0 +1,32 @@
|
|||||||
|
{{- if .Values.autoscaling.enabled }}
|
||||||
|
apiVersion: autoscaling/v2
|
||||||
|
kind: HorizontalPodAutoscaler
|
||||||
|
metadata:
|
||||||
|
name: {{ include "judgeflow.fullname" . }}
|
||||||
|
labels:
|
||||||
|
{{- include "judgeflow.labels" . | nindent 4 }}
|
||||||
|
spec:
|
||||||
|
scaleTargetRef:
|
||||||
|
apiVersion: apps/v1
|
||||||
|
kind: Deployment
|
||||||
|
name: {{ include "judgeflow.fullname" . }}
|
||||||
|
minReplicas: {{ .Values.autoscaling.minReplicas }}
|
||||||
|
maxReplicas: {{ .Values.autoscaling.maxReplicas }}
|
||||||
|
metrics:
|
||||||
|
{{- if .Values.autoscaling.targetCPUUtilizationPercentage }}
|
||||||
|
- type: Resource
|
||||||
|
resource:
|
||||||
|
name: cpu
|
||||||
|
target:
|
||||||
|
type: Utilization
|
||||||
|
averageUtilization: {{ .Values.autoscaling.targetCPUUtilizationPercentage }}
|
||||||
|
{{- end }}
|
||||||
|
{{- if .Values.autoscaling.targetMemoryUtilizationPercentage }}
|
||||||
|
- type: Resource
|
||||||
|
resource:
|
||||||
|
name: memory
|
||||||
|
target:
|
||||||
|
type: Utilization
|
||||||
|
averageUtilization: {{ .Values.autoscaling.targetMemoryUtilizationPercentage }}
|
||||||
|
{{- end }}
|
||||||
|
{{- end }}
|
||||||
61
helm-chart/asr-tco/templates/ingress.yaml
Normal file
61
helm-chart/asr-tco/templates/ingress.yaml
Normal file
@@ -0,0 +1,61 @@
|
|||||||
|
{{- if .Values.ingress.enabled -}}
|
||||||
|
{{- $fullName := include "judgeflow.fullname" . -}}
|
||||||
|
{{- /* values.yaml defines service.ports as a list (there is no service.port key); route to the first declared port. */ -}}
{{- $svcPort := (index .Values.service.ports 0).port -}}
|
||||||
|
{{- if and .Values.ingress.className (not (semverCompare ">=1.18-0" .Capabilities.KubeVersion.GitVersion)) }}
|
||||||
|
{{- if not (hasKey .Values.ingress.annotations "kubernetes.io/ingress.class") }}
|
||||||
|
{{- $_ := set .Values.ingress.annotations "kubernetes.io/ingress.class" .Values.ingress.className}}
|
||||||
|
{{- end }}
|
||||||
|
{{- end }}
|
||||||
|
{{- if semverCompare ">=1.19-0" .Capabilities.KubeVersion.GitVersion -}}
|
||||||
|
apiVersion: networking.k8s.io/v1
|
||||||
|
{{- else if semverCompare ">=1.14-0" .Capabilities.KubeVersion.GitVersion -}}
|
||||||
|
apiVersion: networking.k8s.io/v1beta1
|
||||||
|
{{- else -}}
|
||||||
|
apiVersion: extensions/v1beta1
|
||||||
|
{{- end }}
|
||||||
|
kind: Ingress
|
||||||
|
metadata:
|
||||||
|
name: {{ $fullName }}
|
||||||
|
labels:
|
||||||
|
{{- include "judgeflow.labels" . | nindent 4 }}
|
||||||
|
{{- with .Values.ingress.annotations }}
|
||||||
|
annotations:
|
||||||
|
{{- toYaml . | nindent 4 }}
|
||||||
|
{{- end }}
|
||||||
|
spec:
|
||||||
|
{{- if and .Values.ingress.className (semverCompare ">=1.18-0" .Capabilities.KubeVersion.GitVersion) }}
|
||||||
|
ingressClassName: {{ .Values.ingress.className }}
|
||||||
|
{{- end }}
|
||||||
|
{{- if .Values.ingress.tls }}
|
||||||
|
tls:
|
||||||
|
{{- range .Values.ingress.tls }}
|
||||||
|
- hosts:
|
||||||
|
{{- range .hosts }}
|
||||||
|
- {{ . | quote }}
|
||||||
|
{{- end }}
|
||||||
|
secretName: {{ .secretName }}
|
||||||
|
{{- end }}
|
||||||
|
{{- end }}
|
||||||
|
rules:
|
||||||
|
{{- range .Values.ingress.hosts }}
|
||||||
|
- host: {{ .host | quote }}
|
||||||
|
http:
|
||||||
|
paths:
|
||||||
|
{{- range .paths }}
|
||||||
|
- path: {{ .path }}
|
||||||
|
{{- if and .pathType (semverCompare ">=1.18-0" $.Capabilities.KubeVersion.GitVersion) }}
|
||||||
|
pathType: {{ .pathType }}
|
||||||
|
{{- end }}
|
||||||
|
backend:
|
||||||
|
{{- if semverCompare ">=1.19-0" $.Capabilities.KubeVersion.GitVersion }}
|
||||||
|
service:
|
||||||
|
name: {{ $fullName }}
|
||||||
|
port:
|
||||||
|
number: {{ $svcPort }}
|
||||||
|
{{- else }}
|
||||||
|
serviceName: {{ $fullName }}
|
||||||
|
servicePort: {{ $svcPort }}
|
||||||
|
{{- end }}
|
||||||
|
{{- end }}
|
||||||
|
{{- end }}
|
||||||
|
{{- end }}
|
||||||
63
helm-chart/asr-tco/templates/job.yaml
Normal file
63
helm-chart/asr-tco/templates/job.yaml
Normal file
@@ -0,0 +1,63 @@
|
|||||||
|
apiVersion: batch/v1
|
||||||
|
kind: Job
|
||||||
|
metadata:
|
||||||
|
name: {{ include "judgeflow.fullname" . }}
|
||||||
|
labels:
|
||||||
|
{{- include "judgeflow.labels" . | nindent 4 }}
|
||||||
|
{{- with .Values.podLabels }}
|
||||||
|
{{- toYaml . | nindent 4 }}
|
||||||
|
{{- end }}
|
||||||
|
spec:
|
||||||
|
template:
|
||||||
|
metadata:
|
||||||
|
labels:
|
||||||
|
{{- include "judgeflow.labels" . | nindent 8 }}
|
||||||
|
{{- with .Values.podLabels }}
|
||||||
|
{{- toYaml . | nindent 8 }}
|
||||||
|
{{- end }}
|
||||||
|
spec:
|
||||||
|
{{- with .Values.priorityclassname }}
|
||||||
|
priorityClassName: "{{ . }}"
|
||||||
|
{{- end }}
|
||||||
|
containers:
|
||||||
|
- name: {{ .Chart.Name }}
|
||||||
|
securityContext:
|
||||||
|
{{- toYaml .Values.securityContext | nindent 12 }}
|
||||||
|
image: "{{ .Values.image.repository }}:{{ .Values.image.tag | default .Chart.AppVersion }}"
|
||||||
|
imagePullPolicy: {{ .Values.image.pullPolicy }}
|
||||||
|
{{- with .Values.env }}
|
||||||
|
env:
|
||||||
|
{{- toYaml . | nindent 12 }}
|
||||||
|
{{- end }}
|
||||||
|
{{- if and (hasKey .Values "service") (hasKey .Values.service "ports") }}
|
||||||
|
ports:
|
||||||
|
{{- range .Values.service.ports }}
|
||||||
|
- name: {{ .name }}
|
||||||
|
containerPort: {{ .port }}
|
||||||
|
{{- end }}
|
||||||
|
{{- end }}
|
||||||
|
{{- if hasKey .Values "command" }}
|
||||||
|
command: {{ .Values.command }}
|
||||||
|
{{- end }}
|
||||||
|
volumeMounts:
|
||||||
|
{{- toYaml .Values.volumeMounts | nindent 12 }}
|
||||||
|
resources:
|
||||||
|
{{- toYaml .Values.resources | nindent 12 }}
|
||||||
|
restartPolicy: Never
|
||||||
|
{{- with .Values.volumes }}
|
||||||
|
volumes:
|
||||||
|
{{- toYaml . | nindent 8 }}
|
||||||
|
{{- end }}
|
||||||
|
{{- with .Values.affinity }}
|
||||||
|
affinity:
|
||||||
|
{{- toYaml . | nindent 8 }}
|
||||||
|
{{- end }}
|
||||||
|
{{- with .Values.nodeSelector }}
|
||||||
|
nodeSelector:
|
||||||
|
{{- toYaml . | nindent 8 }}
|
||||||
|
{{- end }}
|
||||||
|
{{- with .Values.tolerations }}
|
||||||
|
tolerations:
|
||||||
|
{{- toYaml . | nindent 8 }}
|
||||||
|
{{- end }}
|
||||||
|
backoffLimit: 0
|
||||||
10
helm-chart/asr-tco/templates/priorityclass.yaml
Normal file
10
helm-chart/asr-tco/templates/priorityclass.yaml
Normal file
@@ -0,0 +1,10 @@
|
|||||||
|
{{- if .Values.priorityclassname }}
|
||||||
|
apiVersion: scheduling.k8s.io/v1
|
||||||
|
kind: PriorityClass
|
||||||
|
metadata:
|
||||||
|
name: "{{ .Values.priorityclassname }}"
|
||||||
|
value: {{ .Values.priorityclassvalue }}
|
||||||
|
globalDefault: false
|
||||||
|
preemptionPolicy: "Never"
|
||||||
|
description: "This is a priority class."
|
||||||
|
{{- end }}
|
||||||
22
helm-chart/asr-tco/templates/service.yaml
Normal file
22
helm-chart/asr-tco/templates/service.yaml
Normal file
@@ -0,0 +1,22 @@
|
|||||||
|
{{- if and (hasKey .Values "service") (hasKey .Values.service "type") }}
|
||||||
|
apiVersion: v1
|
||||||
|
kind: Service
|
||||||
|
metadata:
|
||||||
|
name: {{ include "judgeflow.fullname" . }}
|
||||||
|
labels:
|
||||||
|
{{- include "judgeflow.labels" . | nindent 4 }}
|
||||||
|
{{- with .Values.podLabels }}
|
||||||
|
{{- toYaml . | nindent 4 }}
|
||||||
|
{{- end }}
|
||||||
|
spec:
|
||||||
|
type: {{ .Values.service.type }}
|
||||||
|
ports:
|
||||||
|
{{- range .Values.service.ports }}
|
||||||
|
- port: {{ .port }}
|
||||||
|
targetPort: {{ .port }}
|
||||||
|
protocol: TCP
|
||||||
|
name: {{ .name }}
|
||||||
|
{{- end }}
|
||||||
|
selector:
|
||||||
|
{{- include "judgeflow.selectorLabels" . | nindent 4 }}
|
||||||
|
{{- end }}
|
||||||
13
helm-chart/asr-tco/templates/serviceaccount.yaml
Normal file
13
helm-chart/asr-tco/templates/serviceaccount.yaml
Normal file
@@ -0,0 +1,13 @@
|
|||||||
|
{{- if .Values.serviceAccount.create -}}
|
||||||
|
apiVersion: v1
|
||||||
|
kind: ServiceAccount
|
||||||
|
metadata:
|
||||||
|
name: {{ include "judgeflow.serviceAccountName" . }}
|
||||||
|
labels:
|
||||||
|
{{- include "judgeflow.labels" . | nindent 4 }}
|
||||||
|
{{- with .Values.serviceAccount.annotations }}
|
||||||
|
annotations:
|
||||||
|
{{- toYaml . | nindent 4 }}
|
||||||
|
{{- end }}
|
||||||
|
automountServiceAccountToken: {{ .Values.serviceAccount.automount }}
|
||||||
|
{{- end }}
|
||||||
15
helm-chart/asr-tco/templates/tests/test-connection.yaml
Normal file
15
helm-chart/asr-tco/templates/tests/test-connection.yaml
Normal file
@@ -0,0 +1,15 @@
|
|||||||
|
apiVersion: v1
|
||||||
|
kind: Pod
|
||||||
|
metadata:
|
||||||
|
name: {{ include "judgeflow.fullname" . }}-test-connection
|
||||||
|
labels:
|
||||||
|
{{- include "judgeflow.labels" . | nindent 4 }}
|
||||||
|
annotations:
|
||||||
|
"helm.sh/hook": test
|
||||||
|
spec:
|
||||||
|
containers:
|
||||||
|
- name: wget
|
||||||
|
image: busybox
|
||||||
|
command: ['wget']
|
||||||
|
# values.yaml defines service.ports as a list (there is no service.port key); probe the first declared port.
args: ['{{ include "judgeflow.fullname" . }}:{{ (index .Values.service.ports 0).port }}']
|
||||||
|
restartPolicy: Never
|
||||||
124
helm-chart/asr-tco/values.yaml.tmpl
Normal file
124
helm-chart/asr-tco/values.yaml.tmpl
Normal file
@@ -0,0 +1,124 @@
|
|||||||
|
# Default values for job_demo.
|
||||||
|
# This is a YAML-formatted file.
|
||||||
|
# Declare variables to be passed into your templates.
|
||||||
|
|
||||||
|
replicaCount: 1
|
||||||
|
|
||||||
|
image:
|
||||||
|
repository: "${imageRepo}"
|
||||||
|
pullPolicy: IfNotPresent
|
||||||
|
# Overrides the image tag whose default is the chart appVersion.
|
||||||
|
tag: "${imageTag}"
|
||||||
|
|
||||||
|
imagePullSecrets: []
|
||||||
|
nameOverride: ""
|
||||||
|
fullnameOverride: ""
|
||||||
|
|
||||||
|
serviceAccount:
|
||||||
|
# Specifies whether a service account should be created
|
||||||
|
create: true
|
||||||
|
# Annotations to add to the service account
|
||||||
|
annotations: {}
|
||||||
|
# The name of the service account to use.
|
||||||
|
# If not set and create is true, a name is generated using the fullname template
|
||||||
|
name: ""
|
||||||
|
|
||||||
|
podAnnotations: {}
|
||||||
|
|
||||||
|
podLabels:
|
||||||
|
contest.4pd.io/leaderboard-resource-type: judge_flow
|
||||||
|
contest.4pd.io/leaderboard-job-id: "0"
|
||||||
|
contest.4pd.io/leaderboard-submit-id: "0"
|
||||||
|
|
||||||
|
podSecurityContext: {}
|
||||||
|
# fsGroup: 2000
|
||||||
|
|
||||||
|
securityContext: {}
|
||||||
|
# capabilities:
|
||||||
|
# drop:
|
||||||
|
# - ALL
|
||||||
|
# readOnlyRootFilesystem: true
|
||||||
|
# runAsNonRoot: true
|
||||||
|
# runAsUser: 1000
|
||||||
|
|
||||||
|
service:
|
||||||
|
type: ClusterIP
|
||||||
|
ports:
|
||||||
|
- name: http
|
||||||
|
port: 80
|
||||||
|
|
||||||
|
ingress:
|
||||||
|
enabled: false
|
||||||
|
className: ""
|
||||||
|
annotations: {}
|
||||||
|
# kubernetes.io/ingress.class: nginx
|
||||||
|
# kubernetes.io/tls-acme: "true"
|
||||||
|
hosts:
|
||||||
|
- host: chart-example.local
|
||||||
|
paths:
|
||||||
|
- path: /
|
||||||
|
pathType: ImplementationSpecific
|
||||||
|
tls: []
|
||||||
|
# - secretName: chart-example-tls
|
||||||
|
# hosts:
|
||||||
|
# - chart-example.local
|
||||||
|
|
||||||
|
resources:
|
||||||
|
# We usually recommend not to specify default resources and to leave this as a conscious
|
||||||
|
# choice for the user. This also increases chances charts run on environments with little
|
||||||
|
# resources, such as Minikube. If you do want to specify resources, uncomment the following
|
||||||
|
# lines, adjust them as necessary, and remove the curly braces after 'resources:'.
|
||||||
|
limits:
|
||||||
|
cpu: 3000m
|
||||||
|
memory: 16Gi
|
||||||
|
requests:
|
||||||
|
cpu: 3000m
|
||||||
|
memory: 16Gi
|
||||||
|
|
||||||
|
autoscaling:
|
||||||
|
enabled: false
|
||||||
|
minReplicas: 1
|
||||||
|
maxReplicas: 100
|
||||||
|
targetCPUUtilizationPercentage: 80
|
||||||
|
# targetMemoryUtilizationPercentage: 80
|
||||||
|
|
||||||
|
nodeSelector:
|
||||||
|
juicefs: "on"
|
||||||
|
contest.4pd.io/cpu: INTEL-8358
|
||||||
|
|
||||||
|
tolerations: []
|
||||||
|
|
||||||
|
affinity: {}
|
||||||
|
|
||||||
|
env:
|
||||||
|
- name: TZ
|
||||||
|
value: Asia/Shanghai
|
||||||
|
- name: MY_POD_IP
|
||||||
|
valueFrom:
|
||||||
|
fieldRef:
|
||||||
|
fieldPath: status.podIP
|
||||||
|
|
||||||
|
#command: '["python","run.py"]'
|
||||||
|
|
||||||
|
volumeMounts:
|
||||||
|
- name: workspace
|
||||||
|
mountPath: /tmp/workspace
|
||||||
|
- name: datafile
|
||||||
|
mountPath: /tmp/datafile
|
||||||
|
- name: submit
|
||||||
|
mountPath: /tmp/submit_config
|
||||||
|
- name: juicefs-pv
|
||||||
|
mountPath: /tmp/juicefs
|
||||||
|
- name: customer
|
||||||
|
mountPath: /tmp/customer
|
||||||
|
- name: submit-private
|
||||||
|
mountPath: /tmp/submit_private
|
||||||
|
|
||||||
|
volumes:
|
||||||
|
- name: juicefs-pv
|
||||||
|
persistentVolumeClaim:
|
||||||
|
claimName: juicefs-pvc
|
||||||
|
|
||||||
|
|
||||||
|
priorityclassname: ''
|
||||||
|
priorityclassvalue: '0'
|
||||||
BIN
helm-chart/sut/.DS_Store
vendored
Normal file
BIN
helm-chart/sut/.DS_Store
vendored
Normal file
Binary file not shown.
23
helm-chart/sut/.helmignore
Normal file
23
helm-chart/sut/.helmignore
Normal file
@@ -0,0 +1,23 @@
|
|||||||
|
# Patterns to ignore when building packages.
|
||||||
|
# This supports shell glob matching, relative path matching, and
|
||||||
|
# negation (prefixed with !). Only one pattern per line.
|
||||||
|
.DS_Store
|
||||||
|
# Common VCS dirs
|
||||||
|
.git/
|
||||||
|
.gitignore
|
||||||
|
.bzr/
|
||||||
|
.bzrignore
|
||||||
|
.hg/
|
||||||
|
.hgignore
|
||||||
|
.svn/
|
||||||
|
# Common backup files
|
||||||
|
*.swp
|
||||||
|
*.bak
|
||||||
|
*.tmp
|
||||||
|
*.orig
|
||||||
|
*~
|
||||||
|
# Various IDEs
|
||||||
|
.project
|
||||||
|
.idea/
|
||||||
|
*.tmproj
|
||||||
|
.vscode/
|
||||||
24
helm-chart/sut/Chart.yaml
Normal file
24
helm-chart/sut/Chart.yaml
Normal file
@@ -0,0 +1,24 @@
|
|||||||
|
apiVersion: v2
|
||||||
|
name: sut
|
||||||
|
description: A Helm chart for Kubernetes
|
||||||
|
|
||||||
|
# A chart can be either an 'application' or a 'library' chart.
|
||||||
|
#
|
||||||
|
# Application charts are a collection of templates that can be packaged into versioned archives
|
||||||
|
# to be deployed.
|
||||||
|
#
|
||||||
|
# Library charts provide useful utilities or functions for the chart developer. They're included as
|
||||||
|
# a dependency of application charts to inject those utilities and functions into the rendering
|
||||||
|
# pipeline. Library charts do not define any templates and therefore cannot be deployed.
|
||||||
|
type: application
|
||||||
|
|
||||||
|
# This is the chart version. This version number should be incremented each time you make changes
|
||||||
|
# to the chart and its templates, including the app version.
|
||||||
|
# Versions are expected to follow Semantic Versioning (https://semver.org/)
|
||||||
|
version: 0.1.0
|
||||||
|
|
||||||
|
# This is the version number of the application being deployed. This version number should be
|
||||||
|
# incremented each time you make changes to the application. Versions are not expected to
|
||||||
|
# follow Semantic Versioning. They should reflect the version the application is using.
|
||||||
|
# It is recommended to use it with quotes.
|
||||||
|
appVersion: "0.1.0"
|
||||||
62
helm-chart/sut/templates/_helpers.tpl
Normal file
62
helm-chart/sut/templates/_helpers.tpl
Normal file
@@ -0,0 +1,62 @@
|
|||||||
|
{{/*
|
||||||
|
Expand the name of the chart.
|
||||||
|
*/}}
|
||||||
|
{{- define "sut.name" -}}
|
||||||
|
{{- default .Chart.Name .Values.nameOverride | trunc 63 | trimSuffix "-" }}
|
||||||
|
{{- end }}
|
||||||
|
|
||||||
|
{{/*
|
||||||
|
Create a default fully qualified app name.
|
||||||
|
We truncate at 63 chars because some Kubernetes name fields are limited to this (by the DNS naming spec).
|
||||||
|
If release name contains chart name it will be used as a full name.
|
||||||
|
*/}}
|
||||||
|
{{- define "sut.fullname" -}}
|
||||||
|
{{- if .Values.fullnameOverride }}
|
||||||
|
{{- .Values.fullnameOverride | trunc 63 | trimSuffix "-" }}
|
||||||
|
{{- else }}
|
||||||
|
{{- $name := default .Chart.Name .Values.nameOverride }}
|
||||||
|
{{- if contains $name .Release.Name }}
|
||||||
|
{{- .Release.Name | trunc 63 | trimSuffix "-" }}
|
||||||
|
{{- else }}
|
||||||
|
{{- printf "%s-%s" .Release.Name $name | trunc 63 | trimSuffix "-" }}
|
||||||
|
{{- end }}
|
||||||
|
{{- end }}
|
||||||
|
{{- end }}
|
||||||
|
|
||||||
|
{{/*
|
||||||
|
Create chart name and version as used by the chart label.
|
||||||
|
*/}}
|
||||||
|
{{- define "sut.chart" -}}
|
||||||
|
{{- printf "%s-%s" .Chart.Name .Chart.Version | replace "+" "_" | trunc 63 | trimSuffix "-" }}
|
||||||
|
{{- end }}
|
||||||
|
|
||||||
|
{{/*
|
||||||
|
Common labels
|
||||||
|
*/}}
|
||||||
|
{{- define "sut.labels" -}}
|
||||||
|
helm.sh/chart: {{ include "sut.chart" . }}
|
||||||
|
{{ include "sut.selectorLabels" . }}
|
||||||
|
{{- if .Chart.AppVersion }}
|
||||||
|
app.kubernetes.io/version: {{ .Chart.AppVersion | quote }}
|
||||||
|
{{- end }}
|
||||||
|
app.kubernetes.io/managed-by: {{ .Release.Service }}
|
||||||
|
{{- end }}
|
||||||
|
|
||||||
|
{{/*
|
||||||
|
Selector labels
|
||||||
|
*/}}
|
||||||
|
{{- define "sut.selectorLabels" -}}
|
||||||
|
app.kubernetes.io/name: {{ include "sut.name" . }}
|
||||||
|
app.kubernetes.io/instance: {{ .Release.Name }}
|
||||||
|
{{- end }}
|
||||||
|
|
||||||
|
{{/*
|
||||||
|
Create the name of the service account to use
|
||||||
|
*/}}
|
||||||
|
{{- define "sut.serviceAccountName" -}}
|
||||||
|
{{- if .Values.serviceAccount.create }}
|
||||||
|
{{- default (include "sut.fullname" .) .Values.serviceAccount.name }}
|
||||||
|
{{- else }}
|
||||||
|
{{- default "default" .Values.serviceAccount.name }}
|
||||||
|
{{- end }}
|
||||||
|
{{- end }}
|
||||||
94
helm-chart/sut/templates/deployment.yaml
Normal file
94
helm-chart/sut/templates/deployment.yaml
Normal file
@@ -0,0 +1,94 @@
|
|||||||
|
apiVersion: apps/v1
|
||||||
|
kind: Deployment
|
||||||
|
metadata:
|
||||||
|
name: {{ include "sut.fullname" . }}
|
||||||
|
labels:
|
||||||
|
{{- include "sut.labels" . | nindent 4 }}
|
||||||
|
{{- with .Values.podLabels }}
|
||||||
|
{{- toYaml . | nindent 4 }}
|
||||||
|
{{- end }}
|
||||||
|
spec:
|
||||||
|
{{- if not .Values.autoscaling.enabled }}
|
||||||
|
replicas: {{ .Values.replicaCount }}
|
||||||
|
{{- end }}
|
||||||
|
selector:
|
||||||
|
matchLabels:
|
||||||
|
{{- include "sut.selectorLabels" . | nindent 6 }}
|
||||||
|
template:
|
||||||
|
metadata:
|
||||||
|
{{- with .Values.podAnnotations }}
|
||||||
|
annotations:
|
||||||
|
{{- toYaml . | nindent 8 }}
|
||||||
|
{{- end }}
|
||||||
|
labels:
|
||||||
|
{{- include "sut.labels" . | nindent 8 }}
|
||||||
|
{{- with .Values.podLabels }}
|
||||||
|
{{- toYaml . | nindent 8 }}
|
||||||
|
{{- end }}
|
||||||
|
spec:
|
||||||
|
{{- with .Values.imagePullSecrets }}
|
||||||
|
imagePullSecrets:
|
||||||
|
{{- toYaml . | nindent 8 }}
|
||||||
|
{{- end }}
|
||||||
|
serviceAccountName: {{ include "sut.serviceAccountName" . }}
|
||||||
|
securityContext:
|
||||||
|
{{- toYaml .Values.podSecurityContext | nindent 8 }}
|
||||||
|
{{- with .Values.priorityclassname }}
|
||||||
|
priorityClassName: "{{ . }}"
|
||||||
|
{{- end }}
|
||||||
|
containers:
|
||||||
|
- name: {{ .Chart.Name }}
|
||||||
|
securityContext:
|
||||||
|
{{- toYaml .Values.securityContext | nindent 12 }}
|
||||||
|
image: "{{ .Values.image.repository }}:{{ .Values.image.tag | default .Chart.AppVersion }}"
|
||||||
|
imagePullPolicy: {{ .Values.image.pullPolicy }}
|
||||||
|
{{- with .Values.env }}
|
||||||
|
env:
|
||||||
|
{{- toYaml . | nindent 12 }}
|
||||||
|
{{- end }}
|
||||||
|
ports:
|
||||||
|
- name: http
|
||||||
|
containerPort: {{ .Values.service.port }}
|
||||||
|
protocol: TCP
|
||||||
|
{{- with .Values.command }}
|
||||||
|
command:
|
||||||
|
{{- toYaml . | nindent 12 }}
|
||||||
|
{{- end }}
|
||||||
|
resources:
|
||||||
|
{{- toYaml .Values.resources | nindent 12 }}
|
||||||
|
{{- with .Values.volumeMounts }}
|
||||||
|
volumeMounts:
|
||||||
|
{{- toYaml . | nindent 12 }}
|
||||||
|
{{- end }}
|
||||||
|
|
||||||
|
{{- with .Values.livenessProbe }}
|
||||||
|
livenessProbe:
|
||||||
|
{{- toYaml . | nindent 12 }}
|
||||||
|
{{- end }}
|
||||||
|
{{- with .Values.readinessProbe }}
|
||||||
|
readinessProbe:
|
||||||
|
{{- toYaml . | nindent 12 }}
|
||||||
|
{{- end }}
|
||||||
|
{{- with .Values.startupProbe }}
|
||||||
|
startupProbe:
|
||||||
|
{{- toYaml . | nindent 12 }}
|
||||||
|
{{- end }}
|
||||||
|
|
||||||
|
volumes:
|
||||||
|
{{- with .Values.volumes }}
|
||||||
|
{{- toYaml . | nindent 8 }}
|
||||||
|
{{- end }}
|
||||||
|
|
||||||
|
{{- with .Values.nodeSelector }}
|
||||||
|
nodeSelector:
|
||||||
|
{{- toYaml . | nindent 8 }}
|
||||||
|
{{- end }}
|
||||||
|
{{- with .Values.affinity }}
|
||||||
|
affinity:
|
||||||
|
{{- toYaml . | nindent 8 }}
|
||||||
|
{{- end }}
|
||||||
|
tolerations:
|
||||||
|
- key: "hosttype"
|
||||||
|
operator: "Equal"
|
||||||
|
value: "iluvatar"
|
||||||
|
effect: "NoSchedule"
|
||||||
32
helm-chart/sut/templates/hpa.yaml
Normal file
32
helm-chart/sut/templates/hpa.yaml
Normal file
@@ -0,0 +1,32 @@
|
|||||||
|
{{- if .Values.autoscaling.enabled }}
|
||||||
|
apiVersion: autoscaling/v2
|
||||||
|
kind: HorizontalPodAutoscaler
|
||||||
|
metadata:
|
||||||
|
name: {{ include "sut.fullname" . }}
|
||||||
|
labels:
|
||||||
|
{{- include "sut.labels" . | nindent 4 }}
|
||||||
|
spec:
|
||||||
|
scaleTargetRef:
|
||||||
|
apiVersion: apps/v1
|
||||||
|
kind: Deployment
|
||||||
|
name: {{ include "sut.fullname" . }}
|
||||||
|
minReplicas: {{ .Values.autoscaling.minReplicas }}
|
||||||
|
maxReplicas: {{ .Values.autoscaling.maxReplicas }}
|
||||||
|
metrics:
|
||||||
|
{{- if .Values.autoscaling.targetCPUUtilizationPercentage }}
|
||||||
|
- type: Resource
|
||||||
|
resource:
|
||||||
|
name: cpu
|
||||||
|
target:
|
||||||
|
type: Utilization
|
||||||
|
averageUtilization: {{ .Values.autoscaling.targetCPUUtilizationPercentage }}
|
||||||
|
{{- end }}
|
||||||
|
{{- if .Values.autoscaling.targetMemoryUtilizationPercentage }}
|
||||||
|
- type: Resource
|
||||||
|
resource:
|
||||||
|
name: memory
|
||||||
|
target:
|
||||||
|
type: Utilization
|
||||||
|
averageUtilization: {{ .Values.autoscaling.targetMemoryUtilizationPercentage }}
|
||||||
|
{{- end }}
|
||||||
|
{{- end }}
|
||||||
61
helm-chart/sut/templates/ingress.yaml
Normal file
61
helm-chart/sut/templates/ingress.yaml
Normal file
@@ -0,0 +1,61 @@
|
|||||||
|
{{- if .Values.ingress.enabled -}}
{{- /*
Ingress for the chart's Service. The apiVersion and the backend/pathType
fields are selected from the cluster's Kubernetes version so that the same
template renders for networking.k8s.io/v1, v1beta1 and extensions/v1beta1.
*/ -}}
{{- $fullName := include "sut.fullname" . -}}
{{- $svcPort := .Values.service.port -}}
{{- $kubeVersion := .Capabilities.KubeVersion.GitVersion -}}
{{- /* Clusters older than 1.18 have no ingressClassName field: fall back to the annotation. */ -}}
{{- if and .Values.ingress.className (not (semverCompare ">=1.18-0" $kubeVersion)) }}
  {{- if not (hasKey .Values.ingress.annotations "kubernetes.io/ingress.class") }}
  {{- $_ := set .Values.ingress.annotations "kubernetes.io/ingress.class" .Values.ingress.className }}
  {{- end }}
{{- end }}
{{- if semverCompare ">=1.19-0" $kubeVersion -}}
apiVersion: networking.k8s.io/v1
{{- else if semverCompare ">=1.14-0" $kubeVersion -}}
apiVersion: networking.k8s.io/v1beta1
{{- else -}}
apiVersion: extensions/v1beta1
{{- end }}
kind: Ingress
metadata:
  name: {{ $fullName }}
  labels:
    {{- include "sut.labels" . | nindent 4 }}
  {{- with .Values.ingress.annotations }}
  annotations:
    {{- toYaml . | nindent 4 }}
  {{- end }}
spec:
  {{- if and .Values.ingress.className (semverCompare ">=1.18-0" $kubeVersion) }}
  ingressClassName: {{ .Values.ingress.className }}
  {{- end }}
  {{- if .Values.ingress.tls }}
  tls:
    {{- range .Values.ingress.tls }}
    - hosts:
        {{- range .hosts }}
        - {{ . | quote }}
        {{- end }}
      secretName: {{ .secretName }}
    {{- end }}
  {{- end }}
  rules:
    {{- range .Values.ingress.hosts }}
    - host: {{ .host | quote }}
      http:
        paths:
          {{- range .paths }}
          - path: {{ .path }}
            {{- if and .pathType (semverCompare ">=1.18-0" $kubeVersion) }}
            pathType: {{ .pathType }}
            {{- end }}
            backend:
              {{- if semverCompare ">=1.19-0" $kubeVersion }}
              service:
                name: {{ $fullName }}
                port:
                  number: {{ $svcPort }}
              {{- else }}
              serviceName: {{ $fullName }}
              servicePort: {{ $svcPort }}
              {{- end }}
          {{- end }}
    {{- end }}
{{- end }}
|
||||||
18
helm-chart/sut/templates/service.yaml
Normal file
18
helm-chart/sut/templates/service.yaml
Normal file
@@ -0,0 +1,18 @@
|
|||||||
|
{{- /*
Service in front of the chart's pods.
The single port forwards .Values.service.port to the container port named
"http"; NOTE(review): confirm the Deployment declares a port with that name.
Any .Values.podLabels are appended to the Service's own labels.
*/ -}}
apiVersion: v1
kind: Service
metadata:
  name: {{ include "sut.fullname" . }}
  labels:
    {{- include "sut.labels" . | nindent 4 }}
    {{- if .Values.podLabels }}
    {{- toYaml .Values.podLabels | nindent 4 }}
    {{- end }}
spec:
  type: {{ .Values.service.type }}
  ports:
    - name: socket
      protocol: TCP
      port: {{ .Values.service.port }}
      targetPort: http
  selector:
    {{- include "sut.selectorLabels" . | nindent 4 }}
|
||||||
13
helm-chart/sut/templates/serviceaccount.yaml
Normal file
13
helm-chart/sut/templates/serviceaccount.yaml
Normal file
@@ -0,0 +1,13 @@
|
|||||||
|
{{- /*
ServiceAccount for the release; rendered only when
.Values.serviceAccount.create is truthy.
*/ -}}
{{- $sa := .Values.serviceAccount -}}
{{- if $sa.create -}}
apiVersion: v1
kind: ServiceAccount
metadata:
  name: {{ include "sut.serviceAccountName" . }}
  labels:
    {{- include "sut.labels" . | nindent 4 }}
  {{- with $sa.annotations }}
  annotations:
    {{- toYaml . | nindent 4 }}
  {{- end }}
automountServiceAccountToken: {{ $sa.automount }}
{{- end }}
|
||||||
15
helm-chart/sut/templates/tests/test-connection.yaml
Normal file
15
helm-chart/sut/templates/tests/test-connection.yaml
Normal file
@@ -0,0 +1,15 @@
|
|||||||
|
{{- /*
Helm test hook: a one-shot busybox pod that runs wget against the release's
Service on .Values.service.port.
*/ -}}
apiVersion: v1
kind: Pod
metadata:
  name: "{{ include "sut.fullname" . }}-test-connection"
  labels:
    {{- include "sut.labels" . | nindent 4 }}
  annotations:
    "helm.sh/hook": test
spec:
  restartPolicy: Never
  containers:
    - name: wget
      image: busybox
      command:
        - 'wget'
      args:
        - '{{ include "sut.fullname" . }}:{{ .Values.service.port }}'
|
||||||
144
helm-chart/sut/values.yaml.tmpl
Normal file
144
helm-chart/sut/values.yaml.tmpl
Normal file
@@ -0,0 +1,144 @@
|
|||||||
|
# Default values for sut.
# This is a YAML-formatted file.
# Declare variables to be passed into your templates.

replicaCount: 1

image:
  repository: harbor.4pd.io/lab-platform/inf/python
  pullPolicy: IfNotPresent
  # Overrides the image tag whose default is the chart appVersion.
  # Quoted on purpose: an unquoted 3.9 is parsed as a YAML float, so a later
  # bump to e.g. 3.10 would silently render as "3.1".
  tag: "3.9"

imagePullSecrets: []
nameOverride: ""
fullnameOverride: ""

serviceAccount:
  # Specifies whether a service account should be created
  create: true
  # Automatically mount a ServiceAccount's API credentials?
  automount: true
  # Annotations to add to the service account
  annotations: {}
  # The name of the service account to use.
  # If not set and create is true, a name is generated using the fullname template
  name: ""

podAnnotations: {}
podLabels: {}
podSecurityContext: {}
  # fsGroup: 2000

securityContext: {}
  # capabilities:
  #   drop:
  #   - ALL
  # readOnlyRootFilesystem: true
  # runAsNonRoot: true
  # runAsUser: 1000

service:
  type: ClusterIP
  port: 80

ingress:
  enabled: false
  className: ""
  annotations: {}
    # kubernetes.io/ingress.class: nginx
    # kubernetes.io/tls-acme: "true"
  hosts:
    - host: chart-example.local
      paths:
        - path: /
          pathType: ImplementationSpecific
  tls: []
  #  - secretName: chart-example-tls
  #    hosts:
  #      - chart-example.local

# Requests equal limits for both CPU and memory. Adjust both together if the
# workload needs a different size.
resources:
  limits:
    cpu: 1000m
    memory: 4096Mi
  requests:
    cpu: 1000m
    memory: 4096Mi

autoscaling:
  enabled: false
  minReplicas: 1
  maxReplicas: 100
  targetCPUUtilizationPercentage: 80
  # targetMemoryUtilizationPercentage: 80

# Additional volumes on the output Deployment definition.
volumes: []
# - name: foo
#   secret:
#     secretName: mysecret
#     optional: false

# Additional volumeMounts on the output Deployment definition.
volumeMounts: []
# - name: foo
#   mountPath: "/etc/foo"
#   readOnly: true

nodeSelector:
  contest.4pd.io/accelerator: iluvatar-BI-V100

tolerations:
  - key: hosttype
    operator: Equal
    value: iluvatar
    effect: NoSchedule


affinity: {}

readinessProbe:
  failureThreshold: 1000
  httpGet:
    path: /health
    port: 80
    scheme: HTTP

# Stricter alternative probe, kept for reference:
#readinessProbe:
#  httpGet:
#    path: /health
#    port: 80
#    scheme: HTTP
#  initialDelaySeconds: 5  # wait 5 seconds after start-up before the first probe
#  failureThreshold: 5     # mark the pod unready after 5 consecutive failures
#  successThreshold: 1     # mark the pod ready after 1 success

env:
  - name: TZ
    value: Asia/Shanghai
  - name: MY_POD_NAME
    valueFrom:
      fieldRef:
        fieldPath: metadata.name
  - name: MY_POD_NAMESPACE
    valueFrom:
      fieldRef:
        fieldPath: metadata.namespace
  - name: MY_POD_IP
    valueFrom:
      fieldRef:
        fieldPath: status.podIP
  - name: MY_NODE_IP
    valueFrom:
      fieldRef:
        fieldPath: status.hostIP

#command: ''


priorityclassname: ''
|
||||||
64
local_test.py
Normal file
64
local_test.py
Normal file
@@ -0,0 +1,64 @@
|
|||||||
|
import os
import shutil
import subprocess
import tempfile
import time

# Start from a clean slate: drop whatever a previous run left behind.
if os.path.exists("/tmp/submit_private"):
    shutil.rmtree("/tmp/submit_private")

with tempfile.TemporaryDirectory() as tempdir:
    config_path = os.path.join(tempdir, "config.json")

    # Generate a throw-away ECDSA key pair with an empty passphrase.
    # Run as an argument list with check=True instead of `assert not os.system(...)`:
    # an assert is stripped under `python -O`, which would silently skip key
    # generation, and the list form needs no shell quoting.
    subprocess.run(
        ["ssh-keygen", "-f", f"{tempdir}/ssh-key-ecdsa", "-t", "ecdsa", "-b", "521", "-q", "-N", ""],
        check=True,
    )

    # NOTE(review): the nesting of this YAML document was reconstructed from a
    # flattened copy of the file -- confirm it against what the consumer of
    # SUBMIT_CONFIG_FILEPATH expects before relying on it.
    config = """
model: whisper
model_key: whisper
config.json:
  name: 'faster-whisper-server:latest'
  support_devices:
  - cpu
  model_path: ''
  port: 8080
  other_ports: []
  other_ports_count: 1
  entrypoint: start.bat
  MIN_CHUNK: 2.5
  MIN_ADD_CHUNK: 2.5
  COMPUTE_TYPE: int8
  NUM_WORKERS: 1
  CPU_THREADS: 2
  BEAM_SIZE: 5
  BATCH: 1
  LANG: auto
  DEVICE: cpu
  CHUNK_LENGTH: 5
  CLASS_MODEL: ./models/faster-whisper-base
  EN_MODEL: ./models/faster-whisper-base
  ZH_MODEL: ./models/faster-whisper-base
  RU_MODEL: ./models/faster-whisper-base
  PT_MODEL: ./models/faster-whisper-base
  AR_MODEL: ./models/faster-whisper-base
  NEW_VERSION: 1
  NEED_RESET: 0
leaderboard_options:
  nfs:
    - name: whisper
      srcRelativePath: leaderboard/pc_asr/en.tar.gz
      mountPoint: /tmp
      source: ceph_customer
"""

    with open(config_path, "w") as f:
        f.write(config)

    # The environment is populated before the import below; presumably
    # run_async_a10 reads these variables at import time -- verify.
    os.environ["SSH_KEY_DIR"] = tempdir
    os.environ["SUBMIT_CONFIG_FILEPATH"] = config_path
    os.environ["MODEL_MAPPING"] = '{"whisper": "edge-ml.tar.gz"}'

    from run_async_a10 import get_sut_url_windows

    print(get_sut_url_windows())

    # Keep the process (and therefore the temporary directory) alive for an hour.
    time.sleep(3600)
|
||||||
8
mock_env.sh
Normal file
8
mock_env.sh
Normal file
@@ -0,0 +1,8 @@
|
|||||||
|
#!/bin/bash
# Mock environment for running the benchmark locally. Source this file
# (`. mock_env.sh`) so the variables land in the calling shell.

# Input dataset and output locations.
DATASET_FILEPATH="dataset/formatted1/de.zip"
RESULT_FILEPATH="out/result.json"
DETAILED_CASES_FILEPATH="out/detail_cases.json"

# Intentionally left empty in the mock environment.
SUBMIT_CONFIG_FILEPATH=""
BENCHMARK_NAME=""

# Loopback address used as the pod IP.
MY_POD_IP="127.0.0.1"

export DATASET_FILEPATH RESULT_FILEPATH DETAILED_CASES_FILEPATH \
       SUBMIT_CONFIG_FILEPATH BENCHMARK_NAME MY_POD_IP
|
||||||
215
model_test_caltech_3.py
Normal file
215
model_test_caltech_3.py
Normal file
@@ -0,0 +1,215 @@
|
|||||||
|
import requests
|
||||||
|
import json
|
||||||
|
import torch
|
||||||
|
from PIL import Image
|
||||||
|
from io import BytesIO
|
||||||
|
from transformers import BeitImageProcessor, BeitForImageClassification
|
||||||
|
# 根据模型实际架构选择类
|
||||||
|
from transformers import ViTForImageClassification, BeitForImageClassification
|
||||||
|
from tqdm import tqdm
|
||||||
|
from transformers import AutoConfig
|
||||||
|
from transformers import AutoImageProcessor, AutoModelForImageClassification
|
||||||
|
import os
|
||||||
|
import random
|
||||||
|
import time # 新增导入时间模块
|
||||||
|
|
||||||
|
# Run on the GPU when torch reports CUDA as available (the original note
# mentions Iluvatar GPU acceleration); otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(f"当前使用的设备: {device}")  # debug output

# Report when more than one GPU is visible; the classifier wraps its model in
# DataParallel in that case.
if not torch.cuda.device_count() <= 1:
    print(f"使用 {torch.cuda.device_count()} 块 GPU 进行计算")
|
||||||
|
|
||||||
|
class COCOImageClassifier:
    """Image classifier built on a locally stored Hugging Face model.

    The processor and model are loaded from ``model_path`` and moved to the
    module-level ``device``. When more than one CUDA device is visible the
    model is wrapped in ``torch.nn.DataParallel``.
    """

    def __init__(self, model_path: str, local_image_paths: list):
        """Load the processor and model and remember the images to classify.

        Args:
            model_path: Path passed to ``from_pretrained`` for both the image
                processor and the classification model.
            local_image_paths: Paths of the local image files to predict.
        """
        self.processor = AutoImageProcessor.from_pretrained(model_path)
        self.model = AutoModelForImageClassification.from_pretrained(model_path)

        # Move the model to the selected device.
        self.model = self.model.to(device)
        print(f"模型是否在 GPU 上: {next(self.model.parameters()).is_cuda}")  # debug output

        # With several GPUs, spread the work with DataParallel.
        if torch.cuda.device_count() > 1:
            self.model = torch.nn.DataParallel(self.model)

        # DataParallel hides the real model (and its config) behind ``.module``.
        self.id2label = self.model.module.config.id2label if hasattr(self.model, 'module') else self.model.config.id2label
        self.local_image_paths = local_image_paths

    def predict_image_path(self, image_path: str, top_k: int = 5) -> dict:
        """Predict the classes of one local image file.

        Args:
            image_path: Path of the local image file.
            top_k: Number of highest-confidence classes to return. Values
                larger than the number of classes are clamped to that number.

        Returns:
            A dict with ``image_path`` and ``predictions`` (a list of
            ``class_id`` / ``class_name`` / ``confidence`` dicts), or ``None``
            when the image could not be processed (the error is printed).
        """
        try:
            # Close the file handle as soon as the pixels are converted.
            with Image.open(image_path) as raw_image:
                image = raw_image.convert("RGB")

            # Preprocess and move the tensors to the selected device.
            inputs = self.processor(images=image, return_tensors="pt")
            inputs = inputs.to(device)

            # Inference without gradient tracking.
            with torch.no_grad():
                outputs = self.model(**inputs)

            logits = outputs.logits
            probs = torch.nn.functional.softmax(logits, dim=1)

            # topk raises when k exceeds the class count; clamp instead of
            # losing the whole prediction.
            k = min(top_k, probs.shape[1])
            top_probs, top_indices = probs.topk(k, dim=1)

            predictions = []
            for i in range(k):
                class_idx = top_indices[0, i].item()
                confidence = top_probs[0, i].item()
                predictions.append({
                    "class_id": class_idx,
                    "class_name": self.id2label[class_idx],
                    "confidence": confidence
                })

            return {
                "image_path": image_path,
                "predictions": predictions
            }

        except Exception as e:
            # Best effort: report the failure and let the batch continue.
            print(f"处理图片文件 {image_path} 时出错: {e}")
            return None

    def batch_predict(self, limit: int = 20, top_k: int = 5) -> list:
        """Predict the stored local images one by one.

        Args:
            limit: Maximum number of images to process.
            top_k: Number of highest-confidence classes to return per image.

        Returns:
            The list of successful prediction dicts (failed images are skipped).
        """
        results = []
        local_image_paths = self.local_image_paths[:limit]

        print(f"开始预测 {len(local_image_paths)} 张本地图片...")
        start_time = time.perf_counter()
        for image_path in tqdm(local_image_paths):
            result = self.predict_image_path(image_path, top_k)
            if result:
                results.append(result)
        total_time = time.perf_counter() - start_time

        # An empty batch (or a coarse clock) can give zero elapsed time.
        images_per_second = len(results) / total_time if total_time > 0 else 0.0
        print(f"模型每秒可以处理 {images_per_second:.2f} 张图片")
        return results

    def save_results(self, results: list, output_file: str = "caltech_predictions.json"):
        """Write the prediction results to a JSON file.

        Args:
            results: List of prediction results.
            output_file: Name of the output file.
        """
        with open(output_file, 'w', encoding='utf-8') as f:
            json.dump(results, f, ensure_ascii=False, indent=2)

        print(f"结果已保存到 {output_file}")
|
||||||
|
|
||||||
|
# 主程序
|
||||||
|
if __name__ == "__main__":
    # Replace with the local model path.
    LOCAL_MODEL_PATH = "/home/zhoushasha/models/microsoft_beit_base_patch16_224_pt22k_ft22k"

    # Replace with the Caltech 256 dataset folder.
    CALTECH_256_PATH = "/home/zhoushasha/models/256ObjectCategoriesNew"

    local_image_paths = []
    true_labels = {}

    # Walk every category folder of the dataset.
    for folder in os.listdir(CALTECH_256_PATH):
        folder_path = os.path.join(CALTECH_256_PATH, folder)
        if not os.path.isdir(folder_path):
            continue
        # The class name is the part after the first dot in the folder name;
        # a folder without a dot is used as-is instead of raising IndexError.
        _, dot, suffix = folder.partition('.')
        class_name = suffix if dot else folder
        # Collect the image files, matching the extension case-insensitively.
        image_files = [
            os.path.join(folder_path, f)
            for f in os.listdir(folder_path)
            if f.lower().endswith(('.jpg', '.jpeg', '.png'))
        ]
        # Randomly pick up to 3 images per category.
        selected_images = random.sample(image_files, min(3, len(image_files)))
        for image_path in selected_images:
            local_image_paths.append(image_path)
            true_labels[image_path] = class_name

    # Create the classifier.
    classifier = COCOImageClassifier(LOCAL_MODEL_PATH, local_image_paths)

    # Batch prediction.
    results = classifier.batch_predict(limit=len(local_image_paths), top_k=3)

    # Save the results.
    classifier.save_results(results)

    # Short summary.
    print(f"\n处理完成: 成功预测 {len(results)} 张图片")
    if results:
        print("\n示例预测结果:")
        sample = results[0]
        print(f"图片路径: {sample['image_path']}")
        for i, pred in enumerate(sample['predictions'], 1):
            print(f"{i}. {pred['class_name']} (置信度: {pred['confidence']:.2%})")

    correct_count = 0
    total_count = len(results)

    # Per-class number of samples and number of correct predictions.
    class_actual_count = {}
    class_correct_count = {}

    for prediction in results:
        image_path = prediction['image_path']
        top1_prediction = max(prediction['predictions'], key=lambda x: x['confidence'])
        predicted_class = top1_prediction['class_name'].lower()
        true_class = true_labels[image_path].lower()

        class_actual_count[true_class] = class_actual_count.get(true_class, 0) + 1

        # A prediction counts as correct when any whitespace-separated word of
        # the predicted class name contains the true label.
        if any(true_class in word for word in predicted_class.split()):
            correct_count += 1
            class_correct_count[true_class] = class_correct_count.get(true_class, 0) + 1

    # Guard against an empty result list (no images found or all failed).
    accuracy = correct_count / total_count if total_count else 0.0
    print(f"\nAccuracy: {accuracy * 100:.2f}%")

    # Recall of each class.
    recall_per_class = {}
    for class_name in class_actual_count:
        recall_per_class[class_name] = class_correct_count.get(class_name, 0) / class_actual_count[class_name]

    # Average recall over the classes (0 when there are none).
    average_recall = sum(recall_per_class.values()) / len(recall_per_class) if recall_per_class else 0.0
    print(f"\nAverage Recall: {average_recall * 100:.2f}%")
|
||||||
197
model_test_caltech_cpu1.py
Normal file
197
model_test_caltech_cpu1.py
Normal file
@@ -0,0 +1,197 @@
|
|||||||
|
import requests
|
||||||
|
import json
|
||||||
|
import torch
|
||||||
|
from PIL import Image
|
||||||
|
from io import BytesIO
|
||||||
|
from transformers import AutoImageProcessor, AutoModelForImageClassification
|
||||||
|
from tqdm import tqdm
|
||||||
|
import os
|
||||||
|
import random
|
||||||
|
import time
|
||||||
|
|
||||||
|
# 强制使用CPU
|
||||||
|
device = torch.device("cpu")
|
||||||
|
print(f"当前使用的设备: {device}")
|
||||||
|
|
||||||
|
class COCOImageClassifier:
    """Image classifier built on a locally stored Hugging Face model.

    The processor and model are loaded from ``model_path`` and the model is
    moved to the module-level ``device``.
    """

    def __init__(self, model_path: str, local_image_paths: list):
        """Load the processor and model and remember the images to classify.

        Args:
            model_path: Path passed to ``from_pretrained`` for both the image
                processor and the classification model.
            local_image_paths: Paths of the local image files to predict.
        """
        self.processor = AutoImageProcessor.from_pretrained(model_path)
        self.model = AutoModelForImageClassification.from_pretrained(model_path)

        # Move the model to the selected device.
        self.model = self.model.to(device)
        self.id2label = self.model.config.id2label
        self.local_image_paths = local_image_paths

    def predict_image_path(self, image_path: str, top_k: int = 5) -> dict:
        """Predict the classes of one local image file.

        Args:
            image_path: Path of the local image file.
            top_k: Number of highest-confidence classes to return. Values
                larger than the number of classes are clamped to that number.

        Returns:
            A dict with ``image_path`` and ``predictions`` (a list of
            ``class_id`` / ``class_name`` / ``confidence`` dicts), or ``None``
            when the image could not be processed (the error is printed).
        """
        try:
            # Close the file handle as soon as the pixels are converted.
            with Image.open(image_path) as raw_image:
                image = raw_image.convert("RGB")

            # Preprocess and move the tensors to the selected device.
            inputs = self.processor(images=image, return_tensors="pt")
            inputs = inputs.to(device)

            # Inference without gradient tracking.
            with torch.no_grad():
                outputs = self.model(**inputs)

            logits = outputs.logits
            probs = torch.nn.functional.softmax(logits, dim=1)

            # topk raises when k exceeds the class count; clamp instead of
            # losing the whole prediction.
            k = min(top_k, probs.shape[1])
            top_probs, top_indices = probs.topk(k, dim=1)

            predictions = []
            for i in range(k):
                class_idx = top_indices[0, i].item()
                confidence = top_probs[0, i].item()
                predictions.append({
                    "class_id": class_idx,
                    "class_name": self.id2label[class_idx],
                    "confidence": confidence
                })

            return {
                "image_path": image_path,
                "predictions": predictions
            }

        except Exception as e:
            # Best effort: report the failure and let the batch continue.
            print(f"处理图片文件 {image_path} 时出错: {e}")
            return None

    def batch_predict(self, limit: int = 20, top_k: int = 5) -> list:
        """Predict the stored local images one by one.

        Args:
            limit: Maximum number of images to process.
            top_k: Number of highest-confidence classes to return per image.

        Returns:
            The list of successful prediction dicts (failed images are skipped).
        """
        results = []
        local_image_paths = self.local_image_paths[:limit]

        print(f"开始预测 {len(local_image_paths)} 张本地图片...")
        start_time = time.perf_counter()
        for image_path in tqdm(local_image_paths):
            result = self.predict_image_path(image_path, top_k)
            if result:
                results.append(result)
        elapsed = time.perf_counter() - start_time

        # Throughput; an empty batch (or a coarse clock) can give zero elapsed time.
        throughput = len(results) / elapsed if elapsed > 0 else 0.0
        print(f"模型每秒可以处理 {throughput:.2f} 张图片")

        return results

    def save_results(self, results: list, output_file: str = "celtech_cpu_predictions.json"):
        """Write the prediction results to a JSON file.

        Args:
            results: List of prediction results.
            output_file: Name of the output file.
        """
        with open(output_file, 'w', encoding='utf-8') as f:
            json.dump(results, f, ensure_ascii=False, indent=2)

        print(f"结果已保存到 {output_file}")
|
||||||
|
|
||||||
|
# 主程序
|
||||||
|
if __name__ == "__main__":
    # Replace with the local model path.
    LOCAL_MODEL_PATH = "/home/zhoushasha/models/microsoft_beit_base_patch16_224_pt22k_ft22k"

    # Replace with the Caltech 256 dataset folder.
    CALTECH_256_PATH = "/home/zhoushasha/models/256ObjectCategoriesNew"

    local_image_paths = []
    true_labels = {}

    # Walk every category folder of the dataset.
    for folder in os.listdir(CALTECH_256_PATH):
        folder_path = os.path.join(CALTECH_256_PATH, folder)
        if not os.path.isdir(folder_path):
            continue
        # The class name is the part after the first dot in the folder name;
        # a folder without a dot is used as-is instead of raising IndexError.
        _, dot, suffix = folder.partition('.')
        class_name = suffix if dot else folder
        # Collect the image files, matching the extension case-insensitively.
        image_files = [
            os.path.join(folder_path, f)
            for f in os.listdir(folder_path)
            if f.lower().endswith(('.jpg', '.jpeg', '.png'))
        ]
        # Randomly pick up to 3 images per category.
        selected_images = random.sample(image_files, min(3, len(image_files)))
        for image_path in selected_images:
            local_image_paths.append(image_path)
            true_labels[image_path] = class_name

    # Create the classifier.
    classifier = COCOImageClassifier(LOCAL_MODEL_PATH, local_image_paths)

    # Batch prediction.
    results = classifier.batch_predict(limit=len(local_image_paths), top_k=3)

    # Save the results.
    classifier.save_results(results)

    # Short summary.
    print(f"\n处理完成: 成功预测 {len(results)} 张图片")
    if results:
        print("\n示例预测结果:")
        sample = results[0]
        print(f"图片路径: {sample['image_path']}")
        for i, pred in enumerate(sample['predictions'], 1):
            print(f"{i}. {pred['class_name']} (置信度: {pred['confidence']:.2%})")

    correct_count = 0
    total_count = len(results)
    class_true_positives = {}
    class_false_negatives = {}

    for prediction in results:
        image_path = prediction['image_path']
        top1_prediction = max(prediction['predictions'], key=lambda x: x['confidence'])
        predicted_class = top1_prediction['class_name'].lower()
        true_class = true_labels[image_path].lower()

        class_true_positives.setdefault(true_class, 0)
        class_false_negatives.setdefault(true_class, 0)

        # A prediction is a hit when any whitespace-separated word of the
        # predicted class name contains the true label; otherwise the image is
        # counted once as a false negative.
        # NOTE(review): this follows the for/else reading of the original
        # loop (one miss per image, not one per word) -- confirm.
        if any(true_class in word for word in predicted_class.split()):
            correct_count += 1
            class_true_positives[true_class] += 1
        else:
            class_false_negatives[true_class] += 1

    # Guard against an empty result list (no images found or all failed).
    accuracy = correct_count / total_count if total_count else 0.0
    print(f"\nAccuracy: {accuracy * 100:.2f}%")

    # Recall over all classes: TP / (TP + FN), 0 when nothing was evaluated.
    total_true_positives = sum(class_true_positives.values())
    total_false_negatives = sum(class_false_negatives.values())
    evaluated = total_true_positives + total_false_negatives

    recall = total_true_positives / evaluated if evaluated else 0.0
    print(f"Recall: {recall * 100:.2f}%")
|
||||||
166
model_test_caltech_http.py
Normal file
166
model_test_caltech_http.py
Normal file
@@ -0,0 +1,166 @@
|
|||||||
|
import torch
|
||||||
|
import time
|
||||||
|
import os
|
||||||
|
import multiprocessing
|
||||||
|
from PIL import Image
|
||||||
|
from transformers import AutoImageProcessor, AutoModelForImageClassification
|
||||||
|
from flask import Flask, request, jsonify
|
||||||
|
from io import BytesIO
|
||||||
|
|
||||||
|
# Cap the native math backends at 4 threads.
# NOTE(review): these variables are set after torch has already been imported;
# confirm the OpenMP/MKL runtimes still pick them up at this point.
os.environ.update(dict.fromkeys(
    (
        "OMP_NUM_THREADS",
        "MKL_NUM_THREADS",
        "NUMEXPR_NUM_THREADS",
        "OPENBLAS_NUM_THREADS",
        "VECLIB_MAXIMUM_THREADS",
    ),
    "4",
))
torch.set_num_threads(4)  # PyTorch CPU thread count

# Device configuration: the "CUDA" device falls back to the CPU when CUDA is
# unavailable; the CPU device is always the CPU.
if torch.cuda.is_available():
    device_cuda = torch.device("cuda")
else:
    device_cuda = torch.device("cpu")
device_cpu = torch.device("cpu")
print(f"当前CUDA设备: {device_cuda}, CPU设备: {device_cpu}")
print(f"CPU核心数设置: {torch.get_num_threads()}")
|
||||||
|
|
||||||
|
class ImageClassifier:
    """Classifies one image with a CUDA copy (when available) and a CPU copy of the same model."""

    def __init__(self, model_path: str):
        """Load the processor and one model instance per device from ``model_path``."""
        self.processor = AutoImageProcessor.from_pretrained(model_path)

        # Separate model instances for GPU and CPU; the GPU one is skipped
        # entirely when CUDA is not available.
        self.model_cuda = (
            AutoModelForImageClassification.from_pretrained(model_path).to(device_cuda)
            if device_cuda.type == "cuda"
            else None
        )
        self.model_cpu = AutoModelForImageClassification.from_pretrained(model_path).to(device_cpu)

        # Class-index -> label mapping, taken from the CPU model's config.
        self.id2label = self.model_cpu.config.id2label

    def _predict_with_model(self, image, model, device) -> dict:
        """Run one prediction on the given model/device and time it.

        Returns a dict with the top-1 ``class_id``, ``class_name`` and
        ``confidence``, the ``device_used`` and the ``processing_time`` in
        seconds. Any exception is reported through an error dict
        (``class_id`` -1, ``error`` message) instead of being raised.
        """
        try:
            started = time.perf_counter()  # high-resolution timer

            # Preprocess the image and move the tensors to the target device.
            batch = self.processor(images=image, return_tensors="pt").to(device)

            with torch.no_grad():
                logits = model(**batch).logits

            best_prob, best_idx = torch.max(torch.nn.functional.softmax(logits, dim=1), dim=1)
            label_id = best_idx.item()

            # Elapsed seconds, rounded to 6 decimal places.
            elapsed = round(time.perf_counter() - started, 6)

            return {
                "class_id": label_id,
                "class_name": self.id2label[label_id],
                "confidence": float(best_prob.item()),
                "device_used": str(device),
                "processing_time": elapsed  # processing time
            }
        except Exception as e:
            return {
                "class_id": -1,
                "class_name": "error",
                "confidence": 0.0,
                "device_used": str(device),
                "processing_time": 0.0,
                "error": str(e)
            }

    def predict_single_image(self, image) -> dict:
        """Predict one image with the GPU model (if loaded) and the CPU model.

        Returns a dict with ``status`` plus ``cuda_prediction`` and
        ``cpu_prediction`` entries.
        """
        if self.model_cuda is None:
            # No CUDA model was loaded: report that instead of a prediction.
            gpu_outcome = {
                "class_id": -1,
                "class_name": "error",
                "confidence": 0.0,
                "device_used": str(device_cuda),
                "processing_time": 0.0,
                "error": "CUDA设备不可用,未加载CUDA模型"
            }
        else:
            gpu_outcome = self._predict_with_model(image, self.model_cuda, device_cuda)

        # CPU prediction (thread count capped at module level).
        cpu_outcome = self._predict_with_model(image, self.model_cpu, device_cpu)

        return {
            "status": "success",
            "cuda_prediction": gpu_outcome,
            "cpu_prediction": cpu_outcome,
        }
|
||||||
|
|
||||||
|
# Service bootstrap: Flask app plus one classifier loaded at import time.
app = Flask(__name__)
MODEL_PATH = os.environ.get("MODEL_PATH", "/model")  # model path (environment variable or default)
classifier = ImageClassifier(MODEL_PATH)


@app.route('/v1/private/s782b4996', methods=['POST'])
def predict_single():
    """Accept one uploaded image (form field ``image``) and return both predictions.

    Responds 400 when the request carries no ``image`` file and 500 when
    decoding or prediction raises; both error bodies carry the same message
    for the CUDA and the CPU entry.
    """

    def _failure(message):
        # Error body with the same shape as a successful response.
        def _entry(dev):
            return {
                "class_id": -1,
                "class_name": "error",
                "confidence": 0.0,
                "device_used": str(dev),
                "processing_time": 0.0,
                "error": message
            }

        return {
            "status": "error",
            "cuda_prediction": _entry(device_cuda),
            "cpu_prediction": _entry(device_cpu)
        }

    if 'image' not in request.files:
        return jsonify(_failure("请求中未包含图片")), 400

    image_file = request.files['image']
    try:
        raw_bytes = image_file.read()
        image = Image.open(BytesIO(raw_bytes)).convert("RGB")
        return jsonify(classifier.predict_single_image(image))
    except Exception as e:
        return jsonify(_failure(str(e))), 500
|
||||||
|
|
||||||
|
@app.route('/health', methods=['GET'])
def health_check():
    """Report that the service is up, together with the devices it is configured with."""
    report = {
        "status": "healthy",
        "cuda_available": device_cuda.type == "cuda",
        "cuda_device": str(device_cuda),
        "cpu_device": str(device_cpu),
        # CPU thread count as currently reported by torch.
        "cpu_threads": torch.get_num_threads(),
    }
    return jsonify(report), 200
|
||||||
|
|
||||||
|
if __name__ == "__main__":
    # Bind to all interfaces on port 80 with debug mode off.
    listen_options = dict(host='0.0.0.0', port=80, debug=False)
    app.run(**listen_options)
|
||||||
163
model_test_caltech_http_1.py
Normal file
163
model_test_caltech_http_1.py
Normal file
@@ -0,0 +1,163 @@
|
|||||||
|
import requests
|
||||||
|
import json
|
||||||
|
import torch
|
||||||
|
from PIL import Image
|
||||||
|
from io import BytesIO
|
||||||
|
from transformers import AutoImageProcessor, AutoModelForImageClassification
|
||||||
|
from tqdm import tqdm
|
||||||
|
import os
|
||||||
|
import random
|
||||||
|
import time
|
||||||
|
from flask import Flask, request, jsonify # 引入Flask
|
||||||
|
|
||||||
|
# Device selection: use CUDA when torch reports it as available, otherwise the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(f"当前使用的设备: {device}")
|
||||||
|
|
||||||
|
class COCOImageClassifier:
    """Image classifier built on a ``transformers`` processor/model pair, with batch evaluation."""

    def __init__(self, model_path: str):
        """Load the processor and model from ``model_path`` and move the model to ``device``.

        Image paths are not taken here; they are supplied per call.
        """
        self.processor = AutoImageProcessor.from_pretrained(model_path)
        self.model = AutoModelForImageClassification.from_pretrained(model_path)
        self.model = self.model.to(device)

        if torch.cuda.device_count() > 1:
            print(f"使用 {torch.cuda.device_count()} 块GPU")
            self.model = torch.nn.DataParallel(self.model)

        # DataParallel exposes the wrapped model as ``.module``; read the config from the real model.
        core_model = self.model.module if hasattr(self.model, 'module') else self.model
        self.id2label = core_model.config.id2label

    def predict_image_path(self, image_path: str, top_k: int = 5) -> dict:
        """Classify the image stored at ``image_path``.

        Args:
            image_path: Path of the image file to open.
            top_k: Number of highest-probability classes to return. Clamped to
                the number of classes the model outputs.

        Returns:
            ``{"image_path": ..., "predictions": [...]}`` where each prediction
            has "class_id", "class_name" and "confidence"; or ``None`` when
            anything fails (the error is printed, not raised).
        """
        try:
            # The context manager closes the file handle once the pixels are converted.
            with Image.open(image_path) as raw_image:
                image = raw_image.convert("RGB")
            inputs = self.processor(images=image, return_tensors="pt").to(device)

            with torch.no_grad():
                outputs = self.model(**inputs)

            probs = torch.nn.functional.softmax(outputs.logits, dim=1)
            # topk raises when k exceeds the class dimension; without the clamp
            # that error was swallowed below and every image silently dropped.
            k = min(top_k, probs.shape[1])
            top_probs, top_indices = probs.topk(k, dim=1)

            predictions = []
            for rank in range(k):
                class_idx = top_indices[0, rank].item()
                predictions.append({
                    "class_id": class_idx,
                    "class_name": self.id2label[class_idx],
                    "confidence": top_probs[0, rank].item()
                })

            return {
                "image_path": image_path,
                "predictions": predictions
            }
        except Exception as e:
            # Best effort: report and let the caller skip this image.
            print(f"处理图片 {image_path} 出错: {e}")
            return None

    def batch_predict_and_evaluate(self, image_paths: list, true_labels: dict, top_k: int = 3) -> dict:
        """Predict every image in ``image_paths`` and compute accuracy and recall.

        Args:
            image_paths: Image file paths to classify.
            true_labels: Mapping of image path to ground-truth class name.
            top_k: Forwarded to ``predict_image_path``.

        Returns:
            A dict with "status", "metrics" (accuracy and average recall in
            percent, image counts, throughput) and up to three
            "sample_predictions".

        A prediction counts as correct when the lower-cased true class is a
        substring of any whitespace-separated word of the top-1 class name.
        Images without a non-empty ground-truth label stay in "total_images"
        but are never counted as correct and are left out of per-class recall.
        """
        results = []
        start_time = time.time()

        for image_path in tqdm(image_paths):
            result = self.predict_image_path(image_path, top_k)
            if result:
                results.append(result)

        total_time = time.time() - start_time
        images_per_second = len(results) / total_time if total_time > 0 else 0

        correct_count = 0
        total_count = len(results)
        class_actual_count = {}
        class_correct_count = {}

        for prediction in results:
            true_class = true_labels.get(prediction['image_path'], "").lower()
            if not true_class:
                # "" is a substring of every word, so an unlabeled image used to
                # be scored as a hit. It cannot be verified, so it is skipped.
                continue

            # Count the actual samples seen per class.
            class_actual_count[true_class] = class_actual_count.get(true_class, 0) + 1

            candidates = prediction['predictions']
            if not candidates:
                # No candidates (e.g. top_k == 0): nothing to score, and max() would raise.
                continue
            top1_prediction = max(candidates, key=lambda x: x['confidence'])
            predicted_class = top1_prediction['class_name'].lower()

            if any(true_class in word for word in predicted_class.split()):
                correct_count += 1
                class_correct_count[true_class] = class_correct_count.get(true_class, 0) + 1

        accuracy = correct_count / total_count if total_count > 0 else 0
        recall_per_class = {
            class_name: class_correct_count.get(class_name, 0) / actual
            for class_name, actual in class_actual_count.items()
        }
        average_recall = sum(recall_per_class.values()) / len(recall_per_class) if recall_per_class else 0

        return {
            "status": "success",
            "metrics": {
                "accuracy": round(accuracy * 100, 2),  # percent
                "average_recall": round(average_recall * 100, 2),  # percent
                "total_images": total_count,
                "correct_predictions": correct_count,
                "speed_images_per_second": round(images_per_second, 2)
            },
            "sample_predictions": results[:3]  # a few example predictions
        }
|
||||||
|
|
||||||
|
# Flask service bootstrap: create the app and load the classifier once, at import time.
app = Flask(__name__)
# Model directory; the MODEL_PATH environment variable overrides the default.
MODEL_PATH = os.getenv("MODEL_PATH", "/model")
# Dataset directory; the DATASET_PATH environment variable overrides the default.
DATASET_PATH = os.getenv("DATASET_PATH", "/app/dataset")
classifier = COCOImageClassifier(MODEL_PATH)
|
||||||
|
|
||||||
|
@app.route('/v1/private/s782b4996', methods=['POST'])
def evaluate():
    """Sample images from the dataset, classify them and return evaluation metrics.

    Optional JSON body: ``{"limit": <non-negative int>}`` caps the number of
    images processed (default 20). A missing or non-JSON body uses the default.

    The dataset directory is expected to contain one sub-folder per class,
    named ``<prefix>.<class name>``. Up to three randomly chosen images are
    taken from each folder.

    Returns:
        * 400 when ``limit`` is not a non-negative integer.
        * 500 with the exception text when anything else fails.
        * Otherwise the dict from ``classifier.batch_predict_and_evaluate``.
    """
    try:
        # silent=True: an absent or malformed body yields None instead of raising.
        payload = request.get_json(silent=True)
        if not isinstance(payload, dict):
            payload = {}
        limit = payload.get("limit", 20)
        # bool is an int subclass; reject it explicitly. A negative limit would
        # silently slice from the end of the list.
        if isinstance(limit, bool) or not isinstance(limit, int) or limit < 0:
            return jsonify({
                "status": "error",
                "message": "limit must be a non-negative integer"
            }), 400

        local_image_paths = []
        true_labels = {}
        for folder in os.listdir(DATASET_PATH):
            folder_path = os.path.join(DATASET_PATH, folder)
            if not os.path.isdir(folder_path):
                continue
            # Skip folders without the "<prefix>.<class>" shape rather than
            # failing the whole request with an IndexError.
            _, separator, class_name = folder.partition('.')
            if not separator:
                continue
            image_files = [
                os.path.join(folder_path, f)
                for f in os.listdir(folder_path)
                # Case-insensitive so ".JPG" / ".PNG" files are not dropped.
                if f.lower().endswith(('.jpg', '.jpeg', '.png'))
            ]
            selected_images = random.sample(image_files, min(3, len(image_files)))
            for image_path in selected_images:
                local_image_paths.append(image_path)
                true_labels[image_path] = class_name

        # NOTE(review): this keeps the first `limit` paths in listing order, so
        # a small limit covers only the first few class folders.
        local_image_paths = local_image_paths[:limit]

        result = classifier.batch_predict_and_evaluate(local_image_paths, true_labels, top_k=3)
        return jsonify(result)

    except Exception as e:
        return jsonify({
            "status": "error",
            "message": str(e)
        }), 500
|
||||||
|
|
||||||
|
@app.route('/health', methods=['GET'])
def health_check():
    """Report that the service is up and which torch device it selected."""
    report = {"status": "healthy", "device": str(device)}
    return jsonify(report), 200
|
||||||
|
|
||||||
|
if __name__ == "__main__":
    # Bind to all interfaces on port 8000 with debug mode off.
    listen_options = dict(host='0.0.0.0', port=8000, debug=False)
    app.run(**listen_options)
|
||||||
89
model_test_caltech_http_3.py
Normal file
89
model_test_caltech_http_3.py
Normal file
@@ -0,0 +1,89 @@
|
|||||||
|
import torch
|
||||||
|
from PIL import Image
|
||||||
|
from transformers import AutoImageProcessor, AutoModelForImageClassification
|
||||||
|
import os
|
||||||
|
from flask import Flask, request, jsonify
|
||||||
|
from io import BytesIO
|
||||||
|
|
||||||
|
# Device selection: use CUDA when torch reports it as available, otherwise the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(f"当前使用的设备: {device}")
|
||||||
|
|
||||||
|
class ImageClassifier:
    """Single-image classifier built on a ``transformers`` processor/model pair."""

    def __init__(self, model_path: str):
        """Load the model found in a sub-directory of ``model_path``.

        The model files are assumed to live in a sub-directory; when there are
        several, the alphabetically first one is used.

        Raises:
            ValueError: if ``model_path`` contains no sub-directory.
        """
        # os.listdir returns entries in arbitrary order; sorting makes the
        # choice stable across machines and restarts.
        subdirs = sorted(
            entry
            for entry in os.listdir(model_path)
            if os.path.isdir(os.path.join(model_path, entry))
        )
        if not subdirs:
            raise ValueError(f"在 {model_path} 下未找到任何子目录,无法加载模型")

        # Directory that actually holds the model files.
        actual_model_path = os.path.join(model_path, subdirs[0])
        print(f"加载模型从: {actual_model_path}")

        self.processor = AutoImageProcessor.from_pretrained(actual_model_path)
        self.model = AutoModelForImageClassification.from_pretrained(actual_model_path)
        self.model = self.model.to(device)

        if torch.cuda.device_count() > 1:
            print(f"使用 {torch.cuda.device_count()} 块GPU")
            self.model = torch.nn.DataParallel(self.model)

        # DataParallel exposes the wrapped model as ``.module``; read the config from the real model.
        core_model = self.model.module if hasattr(self.model, 'module') else self.model
        self.id2label = core_model.config.id2label

    def predict_single_image(self, image) -> dict:
        """Classify ``image`` and return the highest-probability class.

        Returns:
            ``{"status": "success", "top_prediction": {...}}`` with "class_id",
            "class_name" and "confidence"; or ``{"status": "error",
            "message": ...}`` when anything raises.
        """
        try:
            inputs = self.processor(images=image, return_tensors="pt").to(device)

            with torch.no_grad():
                outputs = self.model(**inputs)

            logits = outputs.logits
            probs = torch.nn.functional.softmax(logits, dim=1)
            # Highest probability and its class index along the class dimension.
            max_prob, max_idx = probs.max(dim=1)
            class_idx = max_idx.item()

            return {
                "status": "success",
                "top_prediction": {
                    "class_id": class_idx,
                    "class_name": self.id2label[class_idx],
                    "confidence": max_prob.item()
                }
            }
        except Exception as e:
            return {
                "status": "error",
                "message": str(e)
            }
|
||||||
|
|
||||||
|
# Service bootstrap: build the Flask app and load the classifier once, at import time.
app = Flask(__name__)
# Model root directory; the MODEL_PATH environment variable overrides the "/model" default.
MODEL_PATH = os.getenv("MODEL_PATH", "/model")
classifier = ImageClassifier(MODEL_PATH)
|
||||||
|
|
||||||
|
@app.route('/v1/private/s782b4996', methods=['POST'])
def predict_single():
    """Classify one uploaded image and return the top prediction as JSON.

    Expects a multipart upload with the file under the "image" field.

    Returns:
        * 400 when the "image" field is missing.
        * 500 when decoding raises, or when the classifier reports
          ``{"status": "error"}``.
        * 200 with the classifier's result otherwise.
    """
    # Reject requests that carry no uploaded image.
    if 'image' not in request.files:
        return jsonify({"status": "error", "message": "请求中未包含图片"}), 400

    image_file = request.files['image']
    try:
        # Decode from memory and normalise to 3-channel RGB before inference.
        image = Image.open(BytesIO(image_file.read())).convert("RGB")
        result = classifier.predict_single_image(image)
        # predict_single_image catches its own exceptions and reports them in
        # the body. Map that to 500 so a failed prediction is not sent with 200.
        status_code = 500 if result.get("status") == "error" else 200
        return jsonify(result), status_code
    except Exception as e:
        return jsonify({"status": "error", "message": str(e)}), 500
|
||||||
|
|
||||||
|
@app.route('/health', methods=['GET'])
def health_check():
    """Report that the service is up and which torch device it selected."""
    report = {"status": "healthy", "device": str(device)}
    return jsonify(report), 200
|
||||||
|
|
||||||
|
if __name__ == "__main__":
    # Bind to all interfaces on port 8000 with debug mode off.
    listen_options = dict(host='0.0.0.0', port=8000, debug=False)
    app.run(**listen_options)
|
||||||
80
model_test_caltech_http_cuda.py
Normal file
80
model_test_caltech_http_cuda.py
Normal file
@@ -0,0 +1,80 @@
|
|||||||
|
import torch
|
||||||
|
from PIL import Image
|
||||||
|
from transformers import AutoImageProcessor, AutoModelForImageClassification
|
||||||
|
import os
|
||||||
|
from flask import Flask, request, jsonify
|
||||||
|
from io import BytesIO
|
||||||
|
|
||||||
|
# Device selection: use CUDA when torch reports it as available, otherwise the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(f"当前使用的设备: {device}")
|
||||||
|
|
||||||
|
class ImageClassifier:
    """Single-image classifier built on a ``transformers`` processor/model pair."""

    def __init__(self, model_path: str):
        """Load processor and model from ``model_path`` and place the model on ``device``."""
        self.processor = AutoImageProcessor.from_pretrained(model_path)
        loaded_model = AutoModelForImageClassification.from_pretrained(model_path)
        self.model = loaded_model.to(device)

        gpu_count = torch.cuda.device_count()
        if gpu_count > 1:
            print(f"使用 {torch.cuda.device_count()} 块GPU")
            self.model = torch.nn.DataParallel(self.model)

        # DataParallel exposes the wrapped model as ``.module``; the label map lives on the real model.
        if hasattr(self.model, 'module'):
            self.id2label = self.model.module.config.id2label
        else:
            self.id2label = self.model.config.id2label

    def predict_single_image(self, image) -> dict:
        """Classify ``image`` and return only the highest-probability class.

        Returns a success dict holding "top_prediction", or an error dict with
        the exception text when anything raises.
        """
        try:
            batch = self.processor(images=image, return_tensors="pt").to(device)

            with torch.no_grad():
                model_output = self.model(**batch)

            class_probs = torch.nn.functional.softmax(model_output.logits, dim=1)
            # Best probability and its index along the class dimension.
            best_prob, best_index = class_probs.max(dim=1)
            label_id = best_index.item()

            top_prediction = {
                "class_id": label_id,
                "class_name": self.id2label[label_id],
                "confidence": best_prob.item(),
            }
            return {"status": "success", "top_prediction": top_prediction}
        except Exception as exc:
            return {"status": "error", "message": str(exc)}
|
||||||
|
|
||||||
|
# Service bootstrap: build the Flask app and load the classifier once, at import time.
app = Flask(__name__)
# Model directory; the MODEL_PATH environment variable overrides the "/model" default.
MODEL_PATH = os.getenv("MODEL_PATH", "/model")
classifier = ImageClassifier(MODEL_PATH)
|
||||||
|
|
||||||
|
@app.route('/v1/private/s782b4996', methods=['POST'])
def predict_single():
    """Classify one uploaded image and return the top prediction as JSON.

    Expects a multipart upload with the file under the "image" field.

    Returns:
        * 400 when the "image" field is missing.
        * 500 when decoding raises, or when the classifier reports
          ``{"status": "error"}``.
        * 200 with the classifier's result otherwise.
    """
    # Reject requests that carry no uploaded image.
    if 'image' not in request.files:
        return jsonify({"status": "error", "message": "请求中未包含图片"}), 400

    image_file = request.files['image']
    try:
        # Decode from memory and normalise to 3-channel RGB before inference.
        image = Image.open(BytesIO(image_file.read())).convert("RGB")
        result = classifier.predict_single_image(image)
        # predict_single_image catches its own exceptions and reports them in
        # the body. Map that to 500 so a failed prediction is not sent with 200.
        status_code = 500 if result.get("status") == "error" else 200
        return jsonify(result), status_code
    except Exception as e:
        return jsonify({"status": "error", "message": str(e)}), 500
|
||||||
|
|
||||||
|
@app.route('/health', methods=['GET'])
def health_check():
    """Report that the service is up and which torch device it selected."""
    report = {"status": "healthy", "device": str(device)}
    return jsonify(report), 200
|
||||||
|
|
||||||
|
if __name__ == "__main__":
    # Bind to all interfaces on port 80 with debug mode off.
    listen_options = dict(host='0.0.0.0', port=80, debug=False)
    app.run(**listen_options)
|
||||||
24
pyproject.toml
Normal file
24
pyproject.toml
Normal file
@@ -0,0 +1,24 @@
|
|||||||
|
# Formatter settings for black.
[tool.black]
|
||||||
|
line-length = 80
|
||||||
|
target-version = ['py39']
|
||||||
|
|
||||||
|
# NOTE(review): flake8 does not read pyproject.toml on its own; presumably a
# plugin such as Flake8-pyproject is installed. Confirm, otherwise this table is ignored.
[tool.flake8]
|
||||||
|
# NOTE(review): 88 here differs from the 80 used by black and isort in this file.
max-line-length = 88
|
||||||
|
count=true
|
||||||
|
per-file-ignores="./annotation/manager.py:F401"
|
||||||
|
exclude=["./label", "__pycache__", "./migrations", "./logs", "./pids", "./resources"]
|
||||||
|
ignore=["W503", "E203"]
|
||||||
|
enable-extensions="G"
|
||||||
|
# NOTE(review): these look like flake8 plugin package names rather than the
# project's own import names. Confirm the intended value.
application-import-names=["flake8-isort", "flake8-logging-format", "flake8-builtins"]
|
||||||
|
import-order-style="edited"
|
||||||
|
# E203 is also listed under `ignore` above.
extend-ignore = ["E203", "E701"]
|
||||||
|
|
||||||
|
# Import sorting settings for isort.
[tool.isort]
|
||||||
|
py_version=39
|
||||||
|
profile="black"
|
||||||
|
multi_line_output=9
|
||||||
|
line_length=80
|
||||||
|
group_by_package=true
|
||||||
|
case_sensitive=true
|
||||||
|
skip_gitignore=true
|
||||||
|
|
||||||
13
requirements.txt
Normal file
13
requirements.txt
Normal file
@@ -0,0 +1,13 @@
|
|||||||
|
requests
|
||||||
|
ruamel.yaml
|
||||||
|
regex
|
||||||
|
pyyaml
|
||||||
|
websocket-client==0.44.0
|
||||||
|
pydantic==2.6.4
|
||||||
|
pydantic_core==2.16.3
|
||||||
|
Levenshtein
|
||||||
|
numpy
|
||||||
|
websockets
|
||||||
|
fabric
|
||||||
|
vmplatform==0.0.4
|
||||||
|
flask
|
||||||
114
run.py
Normal file
114
run.py
Normal file
@@ -0,0 +1,114 @@
|
|||||||
|
import gc
|
||||||
|
import json
|
||||||
|
import os
|
||||||
|
import sys
|
||||||
|
import time
|
||||||
|
import zipfile
|
||||||
|
|
||||||
|
import yaml
|
||||||
|
from schemas.context import ASRContext
|
||||||
|
from utils.client import Client
|
||||||
|
from utils.evaluator import BaseEvaluator
|
||||||
|
from utils.logger import logger
|
||||||
|
from utils.service import register_sut
|
||||||
|
|
||||||
|
# True when SUBMIT_CONFIG_FILEPATH is not set in the environment.
IN_TEST = os.getenv("SUBMIT_CONFIG_FILEPATH", None) is None
# Boolean switch read from the UNIT_TEST environment variable. The raw string
# must not be used for truth tests: "0" is a non-empty string and therefore
# truthy. Unset, empty, "0", "false", "no" and "off" (any case) all mean False.
UNIT_TEST = os.getenv("UNIT_TEST", "0").strip().lower() not in ("", "0", "false", "no", "off")
|
||||||
|
|
||||||
|
|
||||||
|
def main():
    """Resolve the URL of the service under test (SUT) from the environment.

    When DATASET_FILEPATH is set, the submit config is loaded and the SUT URL
    is chosen from its keys ("docker_images", "docker_image") or from
    UNIT_TEST. Otherwise a fixed test URL is used. The process exits with
    status 1 when the config names no docker image, and with status 0 right
    after resolution when UNIT_TEST is set.

    The dataset/evaluation pipeline below is currently disabled (kept as an
    inert string), so several locals are only read by that disabled code.
    """
    logger.info("执行……")

    dataset_filepath = os.getenv(
        "DATASET_FILEPATH",
        "./tests/resources/en.zip",
    )
    submit_config_filepath = os.getenv("SUBMIT_CONFIG_FILEPATH", "./tests/resources/submit_config")
    result_filepath = os.getenv("RESULT_FILEPATH", "./out/result")
    bad_cases_filepath = os.getenv("BAD_CASES_FILEPATH", "./out/badcase")
    detail_cases_filepath = os.getenv("DETAILED_CASES_FILEPATH", "./out/detailcase.jsonl")

    resource_name = os.getenv("BENCHMARK_NAME")

    # Load the submit config and start (or locate) the service under test.
    if os.getenv("DATASET_FILEPATH", ""):
        from utils.helm import resource_check

        with open(submit_config_filepath, "r") as fp:
            # An empty YAML file parses to None; fall back to {} so the
            # "no docker_image" branch reports it instead of an AttributeError.
            st_config = yaml.safe_load(fp) or {}
        st_config["values"] = resource_check(st_config.get("values", {}))
        if 'docker_images' in st_config:
            sut_url = "ws://172.26.1.75:9827"
            os.environ['test'] = '1'
        elif 'docker_image' in st_config:
            sut_url = register_sut(st_config, resource_name)
        elif UNIT_TEST:
            sut_url = "ws://172.27.231.36:80"
        else:
            logger.error("config 配置错误,没有 docker_image")
            # sys.exit instead of os._exit: os._exit skips buffer flushing and
            # exit handlers, so the error just logged could be lost.
            sys.exit(1)
    else:
        os.environ['test'] = '1'
        sut_url = "ws://172.27.231.36:80"
    if UNIT_TEST:
        # The bare exit() builtin is injected by the site module and is not
        # guaranteed to exist; sys.exit always is.
        sys.exit(0)

    # Disabled pipeline, kept as an inert string expression (never executed).
    """
    # 数据集处理
    local_dataset_path = "./dataset"
    os.makedirs(local_dataset_path, exist_ok=True)
    with zipfile.ZipFile(dataset_filepath) as zf:
        zf.extractall(local_dataset_path)
    config_path = os.path.join(local_dataset_path, "data.yaml")
    with open(config_path, "r") as fp:
        dataset_config = yaml.safe_load(fp)

    # 数据集信息
    dataset_global_config = dataset_config.get("global", {})
    dataset_query = dataset_config.get("query_data", {})

    evaluator = BaseEvaluator()

    # 开始预测
    for idx, query_item in enumerate(dataset_query):
        gc.collect()
        logger.info(f"开始执行 {idx} 条数据")

        context = ASRContext(**dataset_global_config)
        context.lang = query_item.get("lang", context.lang)
        context.file_path = os.path.join(local_dataset_path, query_item["file"])
        # context.audio_length = query_item["audio_length"]

        interactions = Client(sut_url, context).action()
        context.append_labels(query_item["voice"])
        context.append_preds(
            interactions["predict_data"],
            interactions["send_time"],
            interactions["recv_time"],
        )
        context.fail = interactions["fail"]
        if IN_TEST:
            with open('output.txt', 'w') as fp:
                original_stdout = sys.stdout
                sys.stdout = fp
                print(context)
                sys.stdout = original_stdout
        evaluator.evaluate(context)
        detail_case = evaluator.gen_detail_case()
        with open(detail_cases_filepath, "a") as fp:
            fp.write(json.dumps(detail_case.to_dict(), ensure_ascii=False) + "\n")
        time.sleep(4)

    evaluator.post_evaluate()
    output_result = evaluator.gen_result()
    # print(evaluator.__dict__)
    logger.info(f"执行完成. Result = {output_result}")

    with open(result_filepath, "w") as fp:
        json.dump(output_result, fp, indent=2, ensure_ascii=False)
    with open(bad_cases_filepath, "w") as fp:
        fp.write("当前榜单不存在 Bad Case\n")
    """
|
||||||
|
|
||||||
|
# Run main() only when this file is executed as a script.
if __name__ == "__main__":
    main()
|
||||||
757
run_async_a10.py
Normal file
757
run_async_a10.py
Normal file
@@ -0,0 +1,757 @@
|
|||||||
|
import atexit
|
||||||
|
import concurrent.futures
|
||||||
|
import fcntl
|
||||||
|
import gc
|
||||||
|
import glob
|
||||||
|
import json
|
||||||
|
import os
|
||||||
|
import random
|
||||||
|
import signal
|
||||||
|
import sys
|
||||||
|
import tempfile
|
||||||
|
import threading
|
||||||
|
import time
|
||||||
|
import zipfile
|
||||||
|
from concurrent.futures import ThreadPoolExecutor
|
||||||
|
|
||||||
|
import yaml
|
||||||
|
from fabric import Connection
|
||||||
|
from vmplatform import VMOS, Client, VMDataDisk
|
||||||
|
|
||||||
|
from schemas.context import ASRContext
|
||||||
|
from utils.client_async import ClientAsync
|
||||||
|
from utils.evaluator import BaseEvaluator
|
||||||
|
from utils.logger import logger
|
||||||
|
from utils.service import register_sut
|
||||||
|
|
||||||
|
# True when SUBMIT_CONFIG_FILEPATH is not set in the environment.
IN_TEST = os.getenv("SUBMIT_CONFIG_FILEPATH", None) is None
# Boolean switch read from the UNIT_TEST environment variable. The raw string
# must not be used for truth tests: "0" is a non-empty string and therefore
# truthy. Unset, empty, "0", "false", "no" and "off" (any case) all mean False.
UNIT_TEST = os.getenv("UNIT_TEST", "0").strip().lower() not in ("", "0", "false", "no", "off")

DATASET_NUM = os.getenv("DATASET_NUM")

# Parameters for the VM-based leaderboard.
SUT_TYPE = os.getenv("SUT_TYPE", "kubernetes")
SHARE_SUT = os.getenv("SHARE_SUT", "true") == "true"
# Module-level state filled in by the deploy functions once a VM exists.
VM_ID = 0
VM_IP = ""
do_deploy_chart = True
VM_CPU = int(os.getenv("VM_CPU", "2"))
VM_MEM = int(os.getenv("VM_MEM", "4096"))
MODEL_BASEPATH = os.getenv("MODEL_BASEPATH", "/tmp/customer/leaderboard/pc_asr")
# JSON object from the environment; the deploy functions look submitted model names up in it.
MODEL_MAPPING = json.loads(os.getenv("MODEL_MAPPING", "{}"))
SSH_KEY_DIR = os.getenv("SSH_KEY_DIR", "/workspace")
SSH_PUBLIC_KEY_FILE = os.path.join(SSH_KEY_DIR, "ssh-key-ecdsa.pub")
SSH_KEY_FILE = os.path.join(SSH_KEY_DIR, "ssh-key-ecdsa")

# Passed to fabric Connection(...) so SSH authenticates with the private key above.
CONNECT_KWARGS = {"key_filename": SSH_KEY_FILE}

# Parameters for sharing one SUT between jobs.
JOB_ID = os.getenv("JOB_ID")
dirname = "/tmp/submit_private/sut_share"
os.makedirs(dirname, exist_ok=True)
SUT_SHARE_LOCK = os.path.join(dirname, "lock.lock")
SUT_SHARE_USE_LOCK = os.path.join(dirname, "use.lock")
SUT_SHARE_STATUS = os.path.join(dirname, "status.json")
SUT_SHARE_JOB_STATUS = os.path.join(dirname, f"job_status.{JOB_ID}")
SUT_SHARE_PUBLIC_FAIL = os.path.join(dirname, "one_job_failed")
# Opened at import and kept open for the process lifetime.
# NOTE(review): presumably used with fcntl locking elsewhere in this file; confirm.
fd_lock = open(SUT_SHARE_USE_LOCK, "a")
|
||||||
|
|
||||||
|
|
||||||
|
def clean_vm_atexit():
    """Exit hook: delete the VM recorded in ``VM_ID``.

    Does nothing when no VM id is recorded or when ``do_deploy_chart`` is
    false. A deletion failure is logged as a warning, not raised.
    """
    if not (VM_ID and do_deploy_chart):
        return

    logger.info("删除vm")
    failure = Client().delete_vm(VM_ID)
    if failure:
        logger.warning(f"删除vm失败: {failure}")
|
||||||
|
|
||||||
|
|
||||||
|
def put_file_to_vm(c: Connection, local_path: str, remote_path: str):
    """Upload ``local_path`` to ``remote_path`` over connection ``c``, logging before and after."""
    logger.info(f"uploading file {local_path} to {remote_path}")
    transfer = c.put(local_path, remote_path)
    # The transfer result exposes the resolved local and remote paths.
    logger.info("uploaded {0.local} to {0.remote}".format(transfer))
|
||||||
|
|
||||||
|
|
||||||
|
def deploy_windows_sut():
    """Provision a Windows 10 VM, install the ASR runtime and model, and start the SUT.

    Reads the submit config named by SUBMIT_CONFIG_FILEPATH. It must contain
    "model", "model_key", "config.json" (including "port") and an nfs entry
    whose name equals ``model_key``. Sets the module globals ``VM_ID`` and
    ``VM_IP``.

    Returns:
        The websocket URL ``ws://<vm ip>:<port>`` of the started service.

    Raises:
        AssertionError: on missing config keys or a failed remote command.
        RuntimeError: when no nfs entry matches ``model_key``.
        KeyError: when "config.json" has no "port".
    """
    global VM_ID
    global VM_IP

    submit_config_filepath = os.getenv("SUBMIT_CONFIG_FILEPATH", "")
    with open(submit_config_filepath, "r") as fp:
        # An empty YAML file parses to None; use {} so the assertions below
        # report the missing keys instead of failing with a TypeError.
        st_config = yaml.safe_load(fp) or {}
    assert "model" in st_config, "未配置model"
    assert "model_key" in st_config, "未配置model_key"
    assert "config.json" in st_config, "未配置config.json"
    nfs = st_config.get("leaderboard_options", {}).get("nfs", [])
    assert len(nfs) > 0, "未配置nfs"
    assert st_config["model"] in MODEL_MAPPING, "提交模型不在可用模型范围内"

    model = st_config["model"]
    model_key = st_config["model_key"]
    model_path = ""
    config = st_config["config.json"]
    # Read the port up front so a missing key fails before a VM is provisioned.
    # Previously the same KeyError only surfaced on the final return, after the
    # whole deployment had run.
    sut_port = config["port"]

    # Find the nfs entry for the submitted model and map it to its local mount.
    exist = False
    for nfs_item in nfs:
        if nfs_item["name"] == model_key:
            exist = True
            if nfs_item["source"] == "ceph_customer":
                model_path = os.path.join(
                    "/tmp/customer",
                    nfs_item["srcRelativePath"],
                )
            else:
                model_path = os.path.join(
                    "/tmp/juicefs",
                    nfs_item["srcRelativePath"],
                )
            break
    if not exist:
        raise RuntimeError(f"未找到nfs配置项 name={model_key}")

    # Write the config that will be uploaded, pointing at the model location on the VM.
    config_path = os.path.join(tempfile.mkdtemp(), "config.json")
    model_dir = os.path.basename(model_path).split(".")[0]
    config["model_path"] = f"E:\\model\\{model_dir}"
    with open(config_path, "w") as fp:
        json.dump(config, fp, ensure_ascii=False, indent=4)

    vmclient = Client()
    with open(SSH_PUBLIC_KEY_FILE, "r") as fp:
        sshpublickey = fp.read().rstrip()
    VM_ID = vmclient.create_vm(
        "amd64",
        VMOS.windows10,
        VM_CPU,
        VM_MEM,
        "leaderboard-%s-submit-%s-job-%s"
        % (
            os.getenv("BENCHMARK_NAME"),
            os.getenv("SUBMIT_ID"),
            os.getenv("JOB_ID"),
        ),
        sshpublickey,
        datadisks=[
            VMDataDisk(
                size=50,
                disk_type="ssd",
                mount_path="/",
                filesystem="NTFS",
            )
        ],
    )
    # Delete the VM on interpreter exit. SIGTERM is turned into sys.exit so the
    # atexit hook also runs when the job is terminated.
    atexit.register(clean_vm_atexit)
    signal.signal(signal.SIGTERM, lambda signum, _: sys.exit(signum))
    VM_IP = vmclient.wait_until_vm_running(VM_ID)
    logger.info("vm created successfully, vm_ip: %s", VM_IP)

    def sut_startup():
        """Run the server start script on the VM (started on a daemon thread below)."""
        with Connection(
            VM_IP,
            "administrator",
            connect_kwargs=CONNECT_KWARGS,
        ) as c:
            # The dead assignment to the faster-whisper path was removed; it
            # was overwritten on the next line.
            script_path = "E:\\install\\asr\\sensevoice\\server"
            bat_filepath = f"{script_path}\\start.bat"
            config_filepath = "E:\\submit\\config.json"
            # NOTE(review): empty command kept from the original; presumably a
            # connectivity probe. Confirm it is still needed.
            result = c.run("")
            assert result.ok
            c.run(
                f'cd /d {script_path} & set "EDGE_ML_ENV_HOME=E:\\install" & {bat_filepath} {config_filepath}',
                warn=True,
            )

    with Connection(
        VM_IP,
        "administrator",
        connect_kwargs=CONNECT_KWARGS,
    ) as c:
        # Upload and unpack the runtime package selected by the submitted model name.
        model_filepath = os.path.join(MODEL_BASEPATH, MODEL_MAPPING[model])
        filename = os.path.basename(model_filepath)
        put_file_to_vm(c, model_filepath, "/E:/")

        result = c.run("mkdir E:\\base")
        assert result.ok
        result = c.run("mkdir E:\\model")
        assert result.ok
        result = c.run("mkdir E:\\submit")
        assert result.ok

        result = c.run(
            f"tar zxvf E:\\{filename} -C E:\\base --strip-components 1"
        )
        assert result.ok

        result = c.run("E:\\base\\setup-win.bat E:\\install")
        assert result.ok

        # Upload the generated config and the submitted model archive, then unpack the model.
        put_file_to_vm(c, config_path, "/E:/submit")
        put_file_to_vm(c, model_path, "/E:/model")
        result = c.run(
            f"tar zxvf E:\\model\\{os.path.basename(model_path)} -C E:\\model"
        )
        assert result.ok

    threading.Thread(target=sut_startup, daemon=True).start()
    # Fixed wait after launching the start script; there is no readiness check.
    time.sleep(60)

    return f"ws://{VM_IP}:{sut_port}"
|
||||||
|
|
||||||
|
|
||||||
|
def deploy_macos_sut():
    """Provision a macOS 12 VM, install the ASR runtime and model, and start the SUT.

    Reads the submit config named by SUBMIT_CONFIG_FILEPATH. It must contain
    "model", "model_key", "config.json" (including "port") and an nfs entry
    whose name equals ``model_key``. Sets the module globals ``VM_ID`` and
    ``VM_IP``.

    Returns:
        The websocket URL ``ws://<vm ip>:<port>`` of the started service.

    Raises:
        AssertionError: on missing config keys, a failed remote command, or
            when no ``/Volumes/data*`` directory is found.
        RuntimeError: when no nfs entry matches ``model_key``.
        KeyError: when "config.json" has no "port".
    """
    global VM_ID
    global VM_IP

    submit_config_filepath = os.getenv("SUBMIT_CONFIG_FILEPATH", "")
    with open(submit_config_filepath, "r") as fp:
        # An empty YAML file parses to None; use {} so the assertions below
        # report the missing keys instead of failing with a TypeError.
        st_config = yaml.safe_load(fp) or {}
    assert "model" in st_config, "未配置model"
    assert "model_key" in st_config, "未配置model_key"
    assert "config.json" in st_config, "未配置config.json"
    nfs = st_config.get("leaderboard_options", {}).get("nfs", [])
    assert len(nfs) > 0, "未配置nfs"
    assert st_config["model"] in MODEL_MAPPING, "提交模型不在可用模型范围内"

    model = st_config["model"]
    model_key = st_config["model_key"]
    model_path = ""
    config = st_config["config.json"]
    # Read the port up front so a missing key fails before a VM is provisioned.
    # Previously the same KeyError only surfaced on the final return, after the
    # whole deployment had run.
    sut_port = config["port"]

    # Find the nfs entry for the submitted model and map it to its local mount.
    exist = False
    for nfs_item in nfs:
        if nfs_item["name"] == model_key:
            exist = True
            if nfs_item["source"] == "ceph_customer":
                model_path = os.path.join(
                    "/tmp/customer",
                    nfs_item["srcRelativePath"],
                )
            else:
                model_path = os.path.join(
                    "/tmp/juicefs",
                    nfs_item["srcRelativePath"],
                )
            break
    if not exist:
        raise RuntimeError(f"未找到nfs配置项 name={model_key}")
    config_path = os.path.join(tempfile.mkdtemp(), "config.json")
    model_dir = os.path.basename(model_path).split(".")[0]

    vmclient = Client()
    with open(SSH_PUBLIC_KEY_FILE, "r") as fp:
        sshpublickey = fp.read().rstrip()
    VM_ID = vmclient.create_vm(
        "amd64",
        VMOS.macos12,
        VM_CPU,
        VM_MEM,
        "leaderboard-%s-submit-%s-job-%s"
        % (
            os.getenv("BENCHMARK_NAME"),
            os.getenv("SUBMIT_ID"),
            os.getenv("JOB_ID"),
        ),
        sshpublickey,
        datadisks=[
            VMDataDisk(
                size=50,
                disk_type="ssd",
                mount_path="/",
                filesystem="apfs",
            )
        ],
    )
    # Delete the VM on interpreter exit. SIGTERM is turned into sys.exit so the
    # atexit hook also runs when the job is terminated.
    atexit.register(clean_vm_atexit)
    signal.signal(signal.SIGTERM, lambda signum, _: sys.exit(signum))
    VM_IP = vmclient.wait_until_vm_running(VM_ID)
    logger.info("vm created successfully, vm_ip: %s", VM_IP)

    with Connection(
        VM_IP,
        "admin",
        connect_kwargs=CONNECT_KWARGS,
    ) as c:
        result = c.run("ls -d /Volumes/data*")
        assert result.ok
        # The glob may match several directories (one per output line). Use the
        # first one, so embedded newlines cannot corrupt every derived path.
        volumes = [line.strip() for line in result.stdout.splitlines() if line.strip()]
        assert volumes, "未找到数据盘挂载目录"
        volume_path = volumes[0]

    # Write the config that will be uploaded, pointing at the model location on the VM.
    config["model_path"] = f"{volume_path}/model/{model_dir}"
    with open(config_path, "w") as fp:
        json.dump(config, fp, ensure_ascii=False, indent=4)

    def sut_startup():
        """Run the server start script on the VM (started on a daemon thread below)."""
        with Connection(
            VM_IP,
            "admin",
            connect_kwargs=CONNECT_KWARGS,
        ) as c:
            script_path = f"{volume_path}/install/asr/sensevoice/server"
            startsh = f"{script_path}/start.sh"
            config_filepath = f"{volume_path}/submit/config.json"
            c.run(
                f"cd {script_path} && sh {startsh} {config_filepath}",
                warn=True,
            )

    with Connection(
        VM_IP,
        "admin",
        connect_kwargs=CONNECT_KWARGS,
    ) as c:
        # Upload and unpack the runtime package selected by the submitted model name.
        model_filepath = os.path.join(MODEL_BASEPATH, MODEL_MAPPING[model])
        filename = os.path.basename(model_filepath)
        put_file_to_vm(c, model_filepath, f"{volume_path}")

        result = c.run(f"mkdir {volume_path}/base")
        assert result.ok
        result = c.run(f"mkdir {volume_path}/model")
        assert result.ok
        result = c.run(f"mkdir {volume_path}/submit")
        assert result.ok

        result = c.run(
            f"tar zxvf {volume_path}/{filename} -C {volume_path}/base --strip-components 1"  # noqa: E501
        )
        assert result.ok

        result = c.run(
            f"sh {volume_path}/base/setup-mac.sh {volume_path}/install x64"
        )
        assert result.ok

        # Upload the generated config and the submitted model archive, then unpack the model.
        put_file_to_vm(c, config_path, f"{volume_path}/submit")
        put_file_to_vm(c, model_path, f"{volume_path}/model")
        result = c.run(
            f"tar zxvf {volume_path}/model/{os.path.basename(model_path)} -C {volume_path}/model"  # noqa: E501
        )
        assert result.ok

    threading.Thread(target=sut_startup, daemon=True).start()
    # Fixed wait after launching the start script; there is no readiness check.
    time.sleep(60)

    return f"ws://{VM_IP}:{sut_port}"
|
||||||
|
|
||||||
|
|
||||||
|
def get_sut_url_vm(vm_type: str):
    """Deploy the VM-hosted SUT, or attach to one deployed by another job.

    With ``SHARE_SUT`` set, jobs race to create ``SUT_SHARE_LOCK``; the job
    that creates it deploys the SUT and publishes its details in
    ``SUT_SHARE_STATUS``. Every other job polls that file until it reports
    ``running`` and copies the VM id/ip, SUT URL and SSH keys from it.

    Args:
        vm_type: ``"windows"`` selects ``deploy_windows_sut``; any other
            value selects ``deploy_macos_sut``.

    Returns:
        The SUT URL. Stays ``""`` for a waiting job whose polling loop ended
        without ever seeing a ``running`` status.

    Raises:
        RuntimeError: the shared status file could not be parsed after 10 tries.
        AssertionError: the deploying job reported ``failed``.
    """
    global VM_ID
    global VM_IP
    global do_deploy_chart

    do_deploy_chart = True

    def check_job_failed():
        """Watchdog: abort this job once any job of the submit has failed."""
        while True:
            time.sleep(30)
            if os.path.exists(SUT_SHARE_PUBLIC_FAIL):
                logger.error("there is a job failed in current submit")
                # sys.exit() would only raise SystemExit inside this watcher
                # thread and leave the job running; os._exit() terminates
                # the whole process.
                os._exit(1)

    sut_url = ""
    threading.Thread(target=check_job_failed, daemon=True).start()
    if SHARE_SUT:
        # Random delay before racing for the lock file.
        time.sleep(10 * random.random())
        try:
            # Mode "x" fails when the file already exists, so only the job
            # that creates it keeps do_deploy_chart set.
            open(SUT_SHARE_LOCK, "x").close()
        except Exception:
            do_deploy_chart = False

        start_at = time.time()

        def file_last_updated_at(file: str):
            """Return the file's mtime, or this job's start time if it is absent."""
            return os.stat(file).st_mtime if os.path.exists(file) else start_at

        if not do_deploy_chart:
            with open(SUT_SHARE_JOB_STATUS, "w") as f:
                f.write("waiting")
            # Keep polling while the status file was refreshed within the last 24h.
            while (
                time.time() - file_last_updated_at(SUT_SHARE_STATUS)
                <= 60 * 60 * 24
            ):
                logger.info(
                    "Waiting sut application to be deployed by another job"
                )
                time.sleep(10 + random.random())
                if os.path.exists(SUT_SHARE_STATUS):
                    get_status = False
                    # The deployer rewrites this file periodically, so a read
                    # can fail; retry a few times.
                    for _ in range(10):
                        try:
                            with open(SUT_SHARE_STATUS, "r") as f:
                                status = json.load(f)
                            get_status = True
                            break
                        except Exception:
                            time.sleep(1 + random.random())
                            continue
                    if not get_status:
                        raise RuntimeError(
                            "Failed to get status of sut application"
                        )
                    assert status.get("status") != "failed", (
                        "Failed to deploy sut application, "
                        "please check other job logs"
                    )
                    if status.get("status") == "running":
                        VM_ID = status.get("vmid")
                        VM_IP = status.get("vmip")
                        sut_url = status.get("sut_url")
                        with open(SSH_PUBLIC_KEY_FILE, "w") as fp:
                            fp.write(status.get("pubkey"))
                        with open(SSH_KEY_FILE, "w") as fp:
                            fp.write(status.get("prikey"))
                        logger.info("Successfully get deployed sut application")
                        break

    if do_deploy_chart:
        try:
            fcntl.flock(fd_lock, fcntl.LOCK_EX)
            with open(SUT_SHARE_JOB_STATUS, "w") as f:
                f.write("waiting")
            pending = True

            def update_status():
                """Refresh the status file every 30s while deployment is pending."""
                while pending:
                    time.sleep(30)
                    if not pending:
                        break
                    with open(SUT_SHARE_STATUS, "w") as f:
                        json.dump({"status": "pending"}, f)

            threading.Thread(target=update_status, daemon=True).start()
            if vm_type == "windows":
                sut_url = deploy_windows_sut()
            else:
                sut_url = deploy_macos_sut()
        except Exception:
            # Publish the failure so that the other jobs stop as well.
            open(SUT_SHARE_PUBLIC_FAIL, "w").close()
            with open(SUT_SHARE_STATUS, "w") as f:
                json.dump({"status": "failed"}, f)
            raise
        finally:
            pending = False
        # Publish connection details and SSH keys for the waiting jobs.
        with open(SUT_SHARE_STATUS, "w") as f:
            pubkey = ""
            with open(SSH_PUBLIC_KEY_FILE, "r") as fp:
                pubkey = fp.read().rstrip()
            prikey = ""
            with open(SSH_KEY_FILE, "r") as fp:
                prikey = fp.read()
            json.dump(
                {
                    "status": "running",
                    "vmid": VM_ID,
                    "vmip": VM_IP,
                    "pubkey": pubkey,
                    "sut_url": sut_url,
                    "prikey": prikey,
                },
                f,
            )
    else:
        # Non-deploying jobs take turns on the SUT: retry a non-blocking
        # exclusive lock until it is acquired.
        while True:
            time.sleep(5 + random.random())
            try:
                fcntl.flock(fd_lock, fcntl.LOCK_EX | fcntl.LOCK_NB)
                break
            except Exception:
                logger.info("尝试抢占调用sut失败,继续等待 5s ...")

    with open(SUT_SHARE_JOB_STATUS, "w") as f:
        f.write("running")

    return sut_url
|
||||||
|
|
||||||
|
|
||||||
|
def get_sut_url():
    """Resolve the URL of the system under test (SUT).

    VM-based SUT types are delegated to ``get_sut_url_vm``. Otherwise, when
    ``DATASET_FILEPATH`` is set, the submit config is loaded, its resource
    section is overwritten from ``SUT_CPU`` / ``SUT_MEMORY`` (plus vGPU
    settings unless ``RESOURCE_TYPE`` is ``cpu``) and the SUT is registered.
    Without ``DATASET_FILEPATH`` a fixed test endpoint is returned.

    Returns:
        The websocket URL of the SUT.
    """
    if SUT_TYPE in ("windows", "macos"):
        return get_sut_url_vm(SUT_TYPE)

    submit_config_filepath = os.getenv(
        "SUBMIT_CONFIG_FILEPATH", "./tests/resources/submit_config"
    )
    CPU = os.getenv("SUT_CPU", "2")
    MEMORY = os.getenv("SUT_MEMORY", "4Gi")
    resource_name = os.getenv("BENCHMARK_NAME")

    # Language families covered by the task:
    #   Slavic: Russian, Polish
    #   Germanic: English, German, Dutch
    #   Romance: Spanish, Portuguese, French, Italian
    #   Semitic: Arabic, Hebrew

    if not os.getenv("DATASET_FILEPATH", ""):
        # No dataset configured: local test mode with a fixed endpoint.
        os.environ["test"] = "1"
        return "ws://172.26.1.75:9827"

    # Load the submit config and start the service under test.
    with open(submit_config_filepath, "r") as fp:
        st_config = yaml.safe_load(fp)
    values = st_config.setdefault("values", {})

    # Any resources in the submitted config are replaced.
    limits = {"cpu": CPU, "memory": MEMORY}
    requests = {"cpu": CPU, "memory": MEMORY}
    values["resources"] = {"limits": limits, "requests": requests}

    gpu_keys = ("nvidia.com/gpu", "nvidia.com/gpumem", "nvidia.com/gpucores")
    if os.getenv("RESOURCE_TYPE", "cpu") == "cpu":
        if any(
            key in section for section in (limits, requests) for key in gpu_keys
        ):
            raise Exception("禁止使用GPU!")
    else:
        vgpu_num = int(os.getenv("SUT_VGPU", "3"))
        gpu_resources = {
            "nvidia.com/gpu": str(vgpu_num),
            "nvidia.com/gpumem": str(1843 * vgpu_num),
            "nvidia.com/gpucores": str(8 * vgpu_num),
        }
        limits.update(gpu_resources)
        requests.update(gpu_resources)
        values["nodeSelector"] = {"contest.4pd.io/accelerator": "A10vgpu"}
        values["tolerations"] = [
            {
                "key": "hosttype",
                "operator": "Equal",
                "value": "vgpu",
                "effect": "NoSchedule",
            }
        ]

    if "docker_images" in st_config:
        sut_url = "ws://172.26.1.75:9827"
        os.environ["test"] = "1"
    elif "docker_image" in st_config:
        sut_url = register_sut(st_config, resource_name)
    elif UNIT_TEST:
        sut_url = "ws://172.27.231.36:80"
    else:
        logger.error("config 配置错误,没有 docker_image")
        os._exit(1)
    return sut_url
|
||||||
|
|
||||||
|
|
||||||
|
def load_merge_dataset(dataset_filepath: str) -> dict:
    """Unpack a merged multi-language dataset and build one shuffled query list.

    The outer zip is extracted into ``./dataset``. Every entry there whose
    name starts with ``asr.`` is treated as a per-language zip: it is
    extracted into ``./dataset/<lang>``, its ``data.yaml`` is loaded, and its
    queries are sorted by audio file size (largest first) and cut off once the
    accumulated length estimate reaches 30 minutes.

    Args:
        dataset_filepath: path of the outer dataset zip.

    Returns:
        A dict mapping each language to its loaded config, plus a
        ``"query_data"`` key holding all kept queries (each tagged with
        ``"lang"``) shuffled with a fixed seed.
    """
    local_dataset_path = "./dataset"
    os.makedirs(local_dataset_path, exist_ok=True)
    with zipfile.ZipFile(dataset_filepath) as zf:
        zf.extractall(local_dataset_path)

    config = {}
    # os.listdir() returns entries in arbitrary order; sort them so that the
    # seeded shuffle below yields the same order on every run.
    for sub_dataset in sorted(os.listdir(local_dataset_path)):
        if not sub_dataset.startswith("asr."):
            continue
        lang = sub_dataset[4:]
        lang_path = os.path.join(local_dataset_path, lang)
        os.makedirs(lang_path, exist_ok=True)
        with zipfile.ZipFile(
            os.path.join(local_dataset_path, sub_dataset)
        ) as zf:
            zf.extractall(lang_path)
        lang_config_path = os.path.join(lang_path, "data.yaml")
        with open(lang_config_path, "r", encoding="utf-8") as fp:
            lang_config = yaml.safe_load(fp)

        # Rewrite each query's file to its extracted path and record its size.
        audio_lengths = {}
        for query_item in lang_config.get("query_data", []):
            audio_path = os.path.join(lang_path, query_item["file"])
            query_item["file"] = audio_path
            audio_lengths[audio_path] = os.path.getsize(audio_path)
        lang_config["query_data"] = sorted(
            lang_config.get("query_data", []),
            key=lambda x: audio_lengths[x["file"]],
            reverse=True,
        )

        # Limit each language to half an hour of audio.
        idx = 0
        length = 0.0
        for query_item in lang_config["query_data"]:
            # bytes / 32000 is used as the duration estimate in seconds
            # (presumably 16 kHz, 16-bit mono audio; confirm against the data).
            length += audio_lengths[query_item["file"]] / 32000
            idx += 1
            if length >= 30 * 60:
                break

        lang_config["query_data"] = lang_config["query_data"][:idx]
        config[lang] = lang_config

    merged_queries = []
    for lang, lang_config in config.items():
        for query_item in lang_config["query_data"]:
            merged_queries.append({**query_item, "lang": lang})
    config["query_data"] = merged_queries
    # Fixed seed: languages are interleaved in a repeatable order.
    random.Random(0).shuffle(config["query_data"])

    return config
|
||||||
|
|
||||||
|
|
||||||
|
def postprocess_failed():
    """Create (or truncate) the ``SUT_SHARE_PUBLIC_FAIL`` marker file."""
    with open(SUT_SHARE_PUBLIC_FAIL, "w"):
        pass
|
||||||
|
|
||||||
|
|
||||||
|
def main():
    """Run the ASR benchmark: load the dataset, start the SUT, evaluate, write results.

    Reads its settings from environment variables (``DATASET_FILEPATH``,
    ``RESULT_FILEPATH``, ``BAD_CASES_FILEPATH``, ``DETAILED_CASES_FILEPATH``,
    ``THREAD_NUM``, ``MERGE_DATASET``, ``lang``). One detail case is appended
    to the detail file per finished query; the aggregate result is written at
    the end. With ``SHARE_SUT`` the job status file is updated, and the
    deploying job waits until all ``DATASET_NUM`` jobs report success.

    Always ends with ``sys.exit(0)`` on success; on any exception the shared
    failure marker is written (when ``SHARE_SUT``) and the exception re-raised.
    """
    dataset_filepath = os.getenv(
        "DATASET_FILEPATH",
        "/Users/4paradigm/Projects/dataset/asr/de.zip",
        # "./tests/resources/en.zip",
    )
    result_filepath = os.getenv("RESULT_FILEPATH", "./out/result")
    bad_cases_filepath = os.getenv("BAD_CASES_FILEPATH", "./out/badcase")
    detail_cases_filepath = os.getenv(
        "DETAILED_CASES_FILEPATH", "./out/detailcase.jsonl"
    )
    thread_num = int(os.getenv("THREAD_NUM", "1"))

    # Dataset preparation.
    config = {}
    # Parse MERGE_DATASET as a flag: a bare truthiness test would treat
    # "0" or "false" as enabled. Unset defaults to enabled; empty disables.
    merge_flag = os.getenv("MERGE_DATASET", "1").strip().lower()
    if merge_flag not in ("", "0", "false", "no", "off"):
        config = load_merge_dataset(dataset_filepath)
        dataset_query = config["query_data"]
    else:
        local_dataset_path = "./dataset"
        os.makedirs(local_dataset_path, exist_ok=True)
        with zipfile.ZipFile(dataset_filepath) as zf:
            zf.extractall(local_dataset_path)
        config_path = os.path.join(local_dataset_path, "data.yaml")
        with open(config_path, "r") as fp:
            dataset_config = yaml.safe_load(fp)
        lang = os.getenv("lang")
        if lang is None:
            lang = dataset_config.get("global", {}).get("lang", "en")
        # Measure every audio file, then sort query_data by size, largest first.
        query_items = dataset_config.get("query_data", [])
        audio_lengths = []
        for query_item in query_items:
            query_item["lang"] = lang
            audio_path = os.path.join(local_dataset_path, query_item["file"])
            query_item["file"] = audio_path
            audio_lengths.append(os.path.getsize(audio_path) / 1024 / 1024)
        # Sort positions instead of calling list.index() per element, which
        # made the sort quadratic; the resulting order is identical.
        order = sorted(
            range(len(query_items)),
            key=lambda i: audio_lengths[i],
            reverse=True,
        )
        dataset_config["query_data"] = [query_items[i] for i in order]
        dataset_query = dataset_config.get("query_data", {})
        config[lang] = dataset_config

    # sut url
    sut_url = get_sut_url()

    try:
        # Start the test run.
        logger.info("开始执行")
        evaluator = BaseEvaluator()
        future_list = []
        with ThreadPoolExecutor(max_workers=thread_num) as executor:
            for idx, query_item in enumerate(dataset_query):
                context = ASRContext(
                    **config[query_item["lang"]].get("global", {}),
                )
                context.lang = query_item["lang"]
                context.file_path = query_item["file"]
                context.append_labels(query_item["voice"])
                future = executor.submit(
                    ClientAsync(sut_url, context, idx).action
                )
                future_list.append(future)
            for future in concurrent.futures.as_completed(future_list):
                context = future.result()
                evaluator.evaluate(context)
                detail_case = evaluator.gen_detail_case()
                with open(detail_cases_filepath, "a") as fp:
                    fp.write(
                        json.dumps(
                            detail_case.to_dict(),
                            ensure_ascii=False,
                        )
                        + "\n",
                    )
                # Drop the finished context right away to keep memory bounded.
                del context
                gc.collect()

        evaluator.post_evaluate()
        output_result = evaluator.gen_result()
        logger.info("执行完成")

        with open(result_filepath, "w") as fp:
            json.dump(output_result, fp, indent=2, ensure_ascii=False)
        with open(bad_cases_filepath, "w") as fp:
            fp.write("当前榜单不存在 Bad Case\n")

        if SHARE_SUT:
            with open(SUT_SHARE_JOB_STATUS, "w") as f:
                f.write("success")

            fcntl.flock(fd_lock, fcntl.LOCK_UN)
            fd_lock.close()
        # The deploying job stays alive until every job has reported success.
        while SHARE_SUT and do_deploy_chart:
            time.sleep(30)
            success_num = 0
            for job_status_file in glob.glob(dirname + "/job_status.*"):
                with open(job_status_file, "r") as f:
                    job_status = f.read()
                success_num += job_status == "success"
            if success_num == int(DATASET_NUM):
                break
            logger.info("Waiting for all jobs to complete")
    except Exception:
        if SHARE_SUT:
            postprocess_failed()
        raise
    sys.exit(0)
|
||||||
|
|
||||||
|
|
||||||
|
# Run the benchmark only when executed as a script, not on import.
if "__main__" == __name__:
    main()
|
||||||
923
run_callback.py
Normal file
923
run_callback.py
Normal file
@@ -0,0 +1,923 @@
|
|||||||
|
import json
|
||||||
|
import os
|
||||||
|
import sys
|
||||||
|
import time
|
||||||
|
import tempfile
|
||||||
|
import zipfile
|
||||||
|
import threading
|
||||||
|
from collections import defaultdict
|
||||||
|
from typing import Dict, List
|
||||||
|
|
||||||
|
import yaml
|
||||||
|
from pydantic import ValidationError
|
||||||
|
|
||||||
|
from schemas.dataset import QueryData
|
||||||
|
from utils.client_callback import ClientCallback, EvaluateResult, StopException
|
||||||
|
from utils.logger import log
|
||||||
|
from utils.service import register_sut
|
||||||
|
from utils.update_submit import change_product_available
|
||||||
|
from utils.file import dump_json, load_yaml, unzip_dir, load_json, write_file, dump_yaml
|
||||||
|
from utils.leaderboard import change_product_unavailable
|
||||||
|
|
||||||
|
|
||||||
|
# Module-level lock (its users are outside this section).
lck = threading.Lock()

# --- Settings injected by the leaderboard (the first five are required) ---
DATASET_FILEPATH = os.environ["DATASET_FILEPATH"]
RESULT_FILEPATH = os.environ["RESULT_FILEPATH"]
DETAILED_CASES_FILEPATH = os.environ["DETAILED_CASES_FILEPATH"]
SUBMIT_CONFIG_FILEPATH = os.environ["SUBMIT_CONFIG_FILEPATH"]
BENCHMARK_NAME = os.environ["BENCHMARK_NAME"]
TEST_CONCURRENCY = int(os.environ.get("TEST_CONCURRENCY", 1))
THRESHOLD_OMCER = float(os.environ.get("THRESHOLD_OMCER", 0.8))

log.info("DATASET_FILEPATH: {}".format(DATASET_FILEPATH))
# Directory the dataset archive is unpacked into.
workspace_path = "/tmp/workspace"

# --- Settings injected by Kubernetes ---
MY_POD_IP = os.environ["MY_POD_IP"]

# --- Constants ---
RESOURCE_NAME = BENCHMARK_NAME

# --- Settings coming from judge_flow_config ---
LANG = os.environ.get("lang")
SUT_CPU = os.environ.get("SUT_CPU", "2")
SUT_MEMORY = os.environ.get("SUT_MEMORY", "4Gi")
SUT_VGPU = os.environ.get("SUT_VGPU", "1")
# SUT_VGPU_MEM = os.getenv("SUT_VGPU_MEM", str(1843 * int(SUT_VGPU)))
# SUT_VGPU_CORES = os.getenv("SUT_VGPU_CORES", str(8 * int(SUT_VGPU)))
SUT_VGPU_ACCELERATOR = os.environ.get("SUT_VGPU_ACCELERATOR", "iluvatar-BI-V100")
RESOURCE_TYPE = os.environ.get("RESOURCE_TYPE", "vgpu")
assert RESOURCE_TYPE in (
    "cpu",
    "vgpu",
), "benchmark judge_flow_config error: RESOURCE_TYPE should be cpu or vgpu"

# Unpack the dataset at import time.
unzip_dir(DATASET_FILEPATH, workspace_path)
|
||||||
|
|
||||||
|
def get_sut_url_kubernetes():
    """Register the SUT on Kubernetes from the submit config and return its URL.

    Loads ``SUBMIT_CONFIG_FILEPATH``, overwrites the ``containers``,
    ``resources`` and ``nodeSelector`` entries under ``values`` and passes
    the config to ``register_sut``.

    Returns:
        The URL returned by ``register_sut`` with ``ws://`` replaced by
        ``http://``.

    Raises:
        AssertionError: the submit config does not parse to a mapping.
    """
    with open(SUBMIT_CONFIG_FILEPATH, "r") as f:
        submit_config = yaml.safe_load(f)
    assert isinstance(submit_config, dict)

    values = submit_config.setdefault("values", {})

    values["containers"] = [
        {
            "name": "corex-container",
            "image": "harbor.4pd.io/lab-platform/inf/python:3.9",  # image
            "command": ["sleep"],  # replace with the model start command
            "args": ["3600"],  # replace with the model arguments
        }
    ]

    # Disabled alternatives that used to live here as unused string literals:
    #   - a "model-volume" volume backed by PVC "sid-model-pvc" (mounted at /model)
    #   - injecting SUT_CPU / SUT_MEMORY into the resources
    #   - branching on RESOURCE_TYPE (rejecting GPU keys in cpu mode)
    #   - hosttype tolerations (iluvatar, arm64, myinit, middleware) and the
    #     default master / not-ready / unreachable tolerations

    limits = {}
    # Not named "requests": that would shadow the imported requests module.
    resource_requests = {}
    values["resources"] = {
        "requests": resource_requests,
        "limits": limits,
    }

    # GPU resource key for iluvatar devices. Further resources (e.g.
    # "iluvatar.ai/gpumem") can be added here if the cluster supports them.
    vgpu_resource = {"iluvatar.ai/gpu": SUT_VGPU}
    limits.update(vgpu_resource)
    resource_requests.update(vgpu_resource)

    # Node selector: taken from SUT_VGPU_ACCELERATOR, whose default is the
    # previously hard-coded "iluvatar-BI-V100".
    values["nodeSelector"] = {
        "contest.4pd.io/accelerator": SUT_VGPU_ACCELERATOR
    }

    log.info(f"submit_config: {submit_config}")
    log.info(f"RESOURCE_NAME: {RESOURCE_NAME}")

    return register_sut(submit_config, RESOURCE_NAME).replace(
        "ws://", "http://"
    )
|
||||||
|
|
||||||
|
|
||||||
|
def get_sut_url():
    """Return the SUT URL obtained from ``get_sut_url_kubernetes``."""
    sut_service_url = get_sut_url_kubernetes()
    return sut_service_url
|
||||||
|
|
||||||
|
#SUT_URL = get_sut_url()
|
||||||
|
#os.environ["SUT_URL"] = SUT_URL
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
#############################################################################
|
||||||
|
|
||||||
|
import requests
|
||||||
|
import base64
|
||||||
|
|
||||||
|
def gen_req_body(apiname, APPId, file_path=None, featureId=None, featureInfo=None, dstFeatureId=None):
    """
    Build the request body for one voiceprint API call.

    :param apiname: one of createFeature, createGroup, deleteFeature,
        queryFeatureList, searchFea, searchScoreFea, updateFeature, deleteGroup
    :param APPId: application id, placed in the request header
    :param file_path: audio file path; read and base64-encoded for
        createFeature, searchFea, searchScoreFea and updateFeature
    :param featureId: feature id, used by createFeature only
    :param featureInfo: feature description, used by createFeature only
    :param dstFeatureId: target feature id, used by searchScoreFea only
    :return: dict with "header" and "parameter", plus "payload" for the
        APIs that send audio
    :raises Exception: if apiname is not one of the supported names
    """
    # apiname -> (API-specific parameters, whether the audio payload is sent).
    # The group/feature ids are the fixed sample values each API used before.
    api_specs = {
        'createFeature': (
            {
                "groupId": "test_voiceprint_e",
                "featureId": featureId,
                "featureInfo": featureInfo,
            },
            True,
        ),
        'createGroup': (
            {
                "groupId": "test_voiceprint_e",
                "groupName": "vip_user",
                "groupInfo": "store_vip_user_voiceprint",
            },
            False,
        ),
        'deleteFeature': (
            {
                "groupId": "iFLYTEK_examples_groupId",
                "featureId": "iFLYTEK_examples_featureId",
            },
            False,
        ),
        'queryFeatureList': ({"groupId": "user_voiceprint_2"}, False),
        'searchFea': ({"groupId": "test_voiceprint_e", "topK": 1}, True),
        'searchScoreFea': (
            {"groupId": "test_voiceprint_e", "dstFeatureId": dstFeatureId},
            True,
        ),
        'updateFeature': (
            {
                "groupId": "iFLYTEK_examples_groupId",
                "featureId": "iFLYTEK_examples_featureId",
                "featureInfo": "iFLYTEK_examples_featureInfo_update",
            },
            True,
        ),
        'deleteGroup': ({"groupId": "iFLYTEK_examples_groupId"}, False),
    }
    if apiname not in api_specs:
        raise Exception(
            "输入的apiname不在[createFeature, createGroup, deleteFeature, queryFeatureList, searchFea, searchScoreFea, updateFeature, deleteGroup]内,请检查")

    extra_params, needs_audio = api_specs[apiname]
    body = {
        "header": {
            "app_id": APPId,
            "status": 3
        },
        "parameter": {
            "s782b4996": {
                "func": apiname,
                **extra_params,
                # Every API names its result descriptor "<apiname>Res".
                apiname + "Res": {
                    "encoding": "utf8",
                    "compress": "raw",
                    "format": "json"
                }
            }
        }
    }
    if needs_audio:
        with open(file_path, "rb") as f:
            audio_bytes = f.read()
        body["payload"] = {
            "resource": {
                "encoding": "lame",
                "sample_rate": 16000,
                "channels": 1,
                "bit_depth": 16,
                "status": 3,
                "audio": str(base64.b64encode(audio_bytes), 'UTF-8')
            }
        }
    return body
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
log.info("开始请求获取到SUT服务URL")
# Obtain the SUT service URL (this registers the SUT at import time).
sut_url = get_sut_url()
print("获取到的SUT_URL: {}".format(sut_url))  # debug output
log.info("获取到SUT服务URL: {}".format(sut_url))

from urllib.parse import urlparse

# Module-level global declared by req_url.
text_decoded = None
|
||||||
|
|
||||||
|
###################################新增新增################################
|
||||||
|
def req_url(api_name, APPId, file_path=None, featureId=None, featureInfo=None, dstFeatureId=None):
    """
    Check the SUT health endpoint, send the evaluation request and store the reply.

    The JSON reply of ``POST {sut_url}/v1/private/s782b4996`` is written to
    ``RESULT_FILEPATH``.

    :param api_name: API name, validated through gen_req_body
    :param APPId: APPID
    :param file_path: audio file path passed to gen_req_body
    :param featureId: passed to gen_req_body
    :param featureInfo: passed to gen_req_body
    :param dstFeatureId: passed to gen_req_body
    :return: None
    """
    global text_decoded  # only assigned by the disabled base64-decoding path

    # The generated body is not sent (the request below uses a fixed body).
    # The call is kept so that an unsupported api_name or an unreadable
    # file_path still raises here as before.
    gen_req_body(apiname=api_name, APPId=APPId, file_path=file_path, featureId=featureId, featureInfo=featureInfo, dstFeatureId=dstFeatureId)

    # 1. Probe the service health endpoint first.
    response = requests.get(f"{sut_url}/health")
    print(response.status_code, response.text)

    # 2. Send the evaluation request.
    headers = {"Content-Type": "application/json"}
    # Optional "limit": caps the total number of items the SUT processes.
    body = {"limit": 20}
    response = requests.post(
        f"{sut_url}/v1/private/s782b4996",
        data=json.dumps(body),
        headers=headers
    )

    # NOTE: no credentials are embedded here. If the endpoint ever requires
    # basic auth, read the user/password from the environment.

    # 3. Parse the reply (once) and report it.
    if response.status_code == 200:
        result = response.json()
        print("预测评估结果:")
        print(f"准确率: {result['metrics']['accuracy']}%")
        print(f"平均召回率: {result['metrics']['average_recall']}%")
        print(f"处理图片总数: {result['metrics']['total_images']}")
    else:
        print(f"请求失败,状态码: {response.status_code}")
        print(f"错误信息: {response.text}")
        # As before, the reply is still parsed and stored on failure;
        # a non-JSON error body raises here.
        result = response.json()

    # 4. Persist the JSON reply as the benchmark result.
    with open(RESULT_FILEPATH, "w") as f:
        json.dump(result, f, indent=4, ensure_ascii=False)
    print(f"结果已成功写入 {RESULT_FILEPATH}")
|
||||||
|
|
||||||
|
# Input/output locations; each can be overridden through an environment variable.
submit_config_filepath = os.environ.get(
    "SUBMIT_CONFIG_FILEPATH", "./tests/resources/submit_config"
)
result_filepath = os.environ.get("RESULT_FILEPATH", "./out/result")
bad_cases_filepath = os.environ.get("BAD_CASES_FILEPATH", "./out/badcase")
|
||||||
|
#detail_cases_filepath = os.getenv("DETAILED_CASES_FILEPATH", "./out/detailcase.jsonl")
|
||||||
|
|
||||||
|
from typing import Any, Dict, List
|
||||||
|
|
||||||
|
def result2file(
    result: Dict[str, Any],
    detail_cases: List[Dict[str, Any]] = None
):
    """
    Write the evaluation result as pretty-printed JSON to ``result_filepath``.

    :param result: metrics dictionary to persist; nothing is written when None
    :param detail_cases: accepted for interface compatibility; currently not written
    :return: None
    """
    assert result_filepath is not None
    assert bad_cases_filepath is not None

    if result is not None:
        # The default target lives in ./out, which may not exist yet.
        parent = os.path.dirname(result_filepath)
        if parent:
            os.makedirs(parent, exist_ok=True)
        # ensure_ascii=False emits raw non-ASCII text, so the encoding must be explicit.
        with open(result_filepath, "w", encoding="utf-8") as f:
            json.dump(result, f, indent=4, ensure_ascii=False)
|
||||||
|
|
||||||
|
|
||||||
|
def test_image_prediction(sut_url, image_path):
|
||||||
|
"""发送单张图片到服务端预测"""
|
||||||
|
url = f"{sut_url}/v1/private/s782b4996"
|
||||||
|
|
||||||
|
try:
|
||||||
|
with open(image_path, 'rb') as f:
|
||||||
|
files = {'image': f}
|
||||||
|
response = requests.post(url, files=files, timeout=30)
|
||||||
|
|
||||||
|
result = response.json()
|
||||||
|
if result.get('status') != 'success':
|
||||||
|
return None, f"服务端错误: {result.get('message')}"
|
||||||
|
|
||||||
|
return result, None
|
||||||
|
except Exception as e:
|
||||||
|
return None, f"请求错误: {str(e)}"
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
import random
|
||||||
|
import time
|
||||||
|
#from tqdm import tqdm
|
||||||
|
import os
|
||||||
|
import requests
|
||||||
|
|
||||||
|
if __name__ == '__main__':
    # Driver: health-check the service, sample a few images per class from the
    # dataset, send each to the service and compare its GPU and CPU predictions.
    print(f"\n===== main开始请求接口 ===============================================")

    # 1. Health check of the service under test.
    print(f"\n===== 服务健康检查 ===================================================")
    response = requests.get(f"{sut_url}/health", timeout=30)
    print(response.status_code, response.text)

    dataset_root = "/tmp/workspace/256ObjectCategoriesNew"  # dataset root directory
    samples_per_class = 3  # images sampled from every class folder
    image_extensions = ('.jpg', '.jpeg', '.png', '.bmp', '.gif')  # accepted image types

    # Counters
    total_samples = 0

    gpu_true_positives = 0
    gpu_false_positives = 0
    gpu_false_negatives = 0
    gpu_total_processing_time = 0.0

    cpu_true_positives = 0
    cpu_false_positives = 0
    cpu_false_negatives = 0
    cpu_total_processing_time = 0.0

    # Walk every class folder.
    for folder_name in os.listdir(dataset_root):
        folder_path = os.path.join(dataset_root, folder_name)

        # Skip anything that is not a directory.
        if not os.path.isdir(folder_path):
            continue

        # Folder names look like "<index>.<name>"; the class is the <name> part.
        try:
            class_name = folder_name.split('.', 1)[1].strip().lower()
        except IndexError:
            print(f"警告:文件夹 {folder_name} 命名格式不正确,跳过该文件夹")
            continue

        # Collect the image files of this class.
        image_files = []
        for file in os.listdir(folder_path):
            file_path = os.path.join(folder_path, file)
            if os.path.isfile(file_path) and file.lower().endswith(image_extensions):
                image_files.append(file_path)

        # Random sample; take everything when the folder holds fewer images.
        selected_images = random.sample(
            image_files,
            min(samples_per_class, len(image_files))
        )

        for img_path in selected_images:
            # NOTE(review): failed requests are still counted in total_samples,
            # so they lower accuracy and are included in the throughput
            # numerator — confirm this is intended.
            total_samples += 1

            prediction, error = test_image_prediction(sut_url, img_path)

            print(f"test_image_prediction返回的prediction: {prediction}")
            print(f"test_image_prediction返回的error: {error}")

            if error:
                print(f"处理图片 {img_path} 失败: {error}")
                continue

            # GPU prediction
            gpu_pred = prediction.get('cuda_prediction', {})
            gpu_pred_class = gpu_pred.get('class_name', '').lower()
            gpu_processing_time = gpu_pred.get('processing_time', 0.0)

            # CPU prediction
            cpu_pred = prediction.get('cpu_prediction', {})
            cpu_pred_class = cpu_pred.get('class_name', '').lower()
            cpu_processing_time = cpu_pred.get('processing_time', 0.0)

            # A prediction counts as correct when the true class name is a
            # substring of the predicted class name.
            gpu_is_correct = class_name in gpu_pred_class
            if gpu_is_correct:
                gpu_true_positives += 1
            else:
                gpu_false_positives += 1
                gpu_false_negatives += 1

            cpu_is_correct = class_name in cpu_pred_class
            if cpu_is_correct:
                cpu_true_positives += 1
            else:
                cpu_false_positives += 1
                cpu_false_negatives += 1

            # Accumulate processing time.
            gpu_total_processing_time += gpu_processing_time
            cpu_total_processing_time += cpu_processing_time

            # Per-image report.
            print(f"图片: {os.path.basename(img_path)} | 真实: {class_name}")
            print(f"GPU预测: {gpu_pred_class} | {'正确' if gpu_is_correct else '错误'} | 耗时: {gpu_processing_time:.6f}s")
            print(f"CPU预测: {cpu_pred_class} | {'正确' if cpu_is_correct else '错误'} | 耗时: {cpu_processing_time:.6f}s")
            print("-" * 50)

    # GPU metrics. Every ratio is guarded so that an empty dataset yields
    # zeros instead of a ZeroDivisionError.
    gpu_accuracy = gpu_true_positives / total_samples * 100 if total_samples > 0 else 0.0
    gpu_recall_denominator = gpu_true_positives + gpu_false_negatives
    gpu_recall = gpu_true_positives / gpu_recall_denominator * 100 if gpu_recall_denominator > 0 else 0
    gpu_throughput = total_samples / gpu_total_processing_time if gpu_total_processing_time > 1e-6 else 0

    # CPU metrics.
    cpu_accuracy = cpu_true_positives / total_samples * 100 if total_samples > 0 else 0.0
    cpu_recall_denominator = cpu_true_positives + cpu_false_negatives
    cpu_recall = cpu_true_positives / cpu_recall_denominator * 100 if cpu_recall_denominator > 0 else 0
    cpu_throughput = total_samples / cpu_total_processing_time if cpu_total_processing_time > 1e-6 else 0

    # Result dictionary (key order is kept stable for the written JSON file).
    result = {
        "gpu_accuracy": round(gpu_accuracy, 6),
        "gpu_recall": round(gpu_recall, 6),
        "gpu_running_time": round(gpu_total_processing_time, 6),
        "gpu_throughput": round(gpu_throughput, 6),

        "cpu_accuracy": round(cpu_accuracy, 6),
        "cpu_recall": round(cpu_recall, 6),
        "cpu_running_time": round(cpu_total_processing_time, 6),
        "cpu_throughput": round(cpu_throughput, 6)
    }

    # Final summary.
    print("\n" + "="*50)
    print(f"总样本数: {total_samples}")
    print("\nGPU指标:")
    print(f"准确率: {result['gpu_accuracy']:.4f}%")
    print(f"召回率: {result['gpu_recall']:.4f}%")
    print(f"总运行时间: {result['gpu_running_time']:.6f}s")
    print(f"吞吐量: {result['gpu_throughput']:.2f}张/秒")

    print("\nCPU指标:")
    print(f"准确率: {result['cpu_accuracy']:.4f}%")
    print(f"召回率: {result['cpu_recall']:.4f}%")
    print(f"总运行时间: {result['cpu_running_time']:.6f}s")
    print(f"吞吐量: {result['cpu_throughput']:.2f}张/秒")
    print("="*50)

    result2file(result)

    # GPU and CPU accuracy must agree within 3 percentage points.
    if abs(gpu_accuracy - cpu_accuracy) > 3:
        log.error(f"gpu与cpu准确率差别超过3%,模型结果不正确")
        change_product_unavailable()

    exit_code = 0
|
||||||
|
|
||||||
|
|
||||||
1193
run_callback_cuda.py
Normal file
1193
run_callback_cuda.py
Normal file
File diff suppressed because it is too large
Load Diff
1296
run_callback_new.py
Normal file
1296
run_callback_new.py
Normal file
File diff suppressed because it is too large
Load Diff
0
schemas/__init__.py
Normal file
0
schemas/__init__.py
Normal file
90
schemas/context.py
Normal file
90
schemas/context.py
Normal file
@@ -0,0 +1,90 @@
|
|||||||
|
import os
|
||||||
|
from copy import deepcopy
|
||||||
|
from typing import Dict, List, Optional
|
||||||
|
|
||||||
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
|
from schemas.stream import StreamDataModel
|
||||||
|
|
||||||
|
|
||||||
|
class LabelContext(BaseModel):
    """One labelled segment: a time span together with its reference text."""

    # Segment boundaries (units are not stated here — presumably seconds; TODO confirm).
    start: float
    end: float
    # Reference transcript for the segment.
    answer: str
|
||||||
|
|
||||||
|
|
||||||
|
class PredContext(BaseModel):
    """One recognition result plus the optional times it was sent and received."""

    recognition_results: StreamDataModel
    # Both timestamps default to None until they are filled in.
    recv_time: Optional[float] = None
    send_time: Optional[float] = None
|
||||||
|
|
||||||
|
|
||||||
|
class ASRContext:
    """
    Per-audio evaluation context: audio parameters, streaming settings,
    labels, predictions and bookkeeping counters.
    """

    def __init__(self, **kwargs):
        """Read settings from ``kwargs``; some can be overridden by environment variables."""
        option = kwargs.get

        # Audio parameters.
        self.bits = option("bits", 16)
        self.channel = option("channel", 1)
        self.sample_rate = option("sample_rate", 16000)
        self.audio_format = option("format", "wav")
        self.enable_words = option("enable_words", True)
        self.char_contains_rate = option("char_contains_rate", 0.8)

        # Environment variable "lang" wins over the keyword argument.
        env_lang = os.getenv("lang")
        self.lang = env_lang
        if env_lang is None:
            self.lang = option("lang", "en")
        self.stream = option("stream", True)

        # Chunk size is derived from sample rate, bit depth and wait time,
        # unless the "chunk_size_set" environment variable holds a non-zero value.
        self.wait_time = float(os.getenv("wait_time", 0.1))
        self.chunk_size = self.sample_rate * self.bits / 8 * self.wait_time
        forced_chunk_size = int(os.getenv('chunk_size_set', 0))
        if forced_chunk_size:
            self.chunk_size = forced_chunk_size

        self.audio_length = 0
        self.file_path = ""

        self.labels: List[LabelContext] = option("labels", [])
        self.preds: List[PredContext] = option("preds", [])

        self.label_sentences: List[str] = []
        self.pred_sentences: List[str] = []

        self.send_time_start_end = []
        self.recv_time_start_end = []

        self.fail = False
        self.fail_char_contains_rate_num = 0

        self.punctuation_num = 0
        self.pred_punctuation_num = 0

    def append_labels(self, voices: List[Dict]):
        """Convert each dict in ``voices`` to a LabelContext and append it to ``labels``."""
        self.labels.extend(LabelContext(**voice_data) for voice_data in voices)

    def append_preds(
        self,
        predict_data: List[StreamDataModel],
        send_time: List[float],
        recv_time: List[float],
    ):
        """
        Record the first/last send and receive times, then append one
        PredContext per (prediction, send time, receive time) triple.
        """
        if len(send_time) > 0:
            self.send_time_start_end = [send_time[0], send_time[-1]]
        else:
            self.send_time_start_end = []

        if len(recv_time) > 0:
            self.recv_time_start_end = [recv_time[0], recv_time[-1]]
        else:
            self.recv_time_start_end = []

        for prediction, sent_at, received_at in zip(predict_data, send_time, recv_time):
            # Work on a copy so the caller's object is left untouched.
            snapshot = deepcopy(prediction)
            context = PredContext(recognition_results=snapshot.model_dump())
            context.send_time = sent_at
            context.recv_time = received_at
            self.preds.append(context)

    def to_dict(self):
        """Return the main settings as a dict; labels and preds are serialized to JSON strings."""
        serialized_labels = [label.model_dump_json() for label in self.labels]
        serialized_preds = [pred.model_dump_json() for pred in self.preds]
        return {
            "bits": self.bits,
            "channel": self.channel,
            "sample_rate": self.sample_rate,
            "audio_format": self.audio_format,
            "enable_words": self.enable_words,
            "stream": self.stream,
            "wait_time": self.wait_time,
            "chunk_size": self.chunk_size,
            "labels": serialized_labels,
            "preds": serialized_preds,
        }
|
||||||
18
schemas/dataset.py
Normal file
18
schemas/dataset.py
Normal file
@@ -0,0 +1,18 @@
|
|||||||
|
from typing import List
|
||||||
|
|
||||||
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
|
|
||||||
|
class QueryDataSentence(BaseModel):
    """One labelled sentence of an audio file."""

    # Reference text of the sentence.
    answer: str = Field(description="文本label")
    # Start time of the sentence.
    start: float = Field(description="句子开始时间")
    # End time of the sentence.
    end: float = Field(description="句子结束时间")
|
||||||
|
|
||||||
|
|
||||||
|
class QueryData(BaseModel):
    """One audio file of the dataset together with its labelled sentences."""

    # Language of the audio.
    lang: str = Field(description="语言")
    # Location of the audio file.
    file: str = Field(description="音频文件位置")
    # Length of the audio.
    duration: float = Field(description="音频长度")
    # Labelled sentences contained in the audio file.
    voice: List[QueryDataSentence] = Field(description="音频文件的文本label内容")
|
||||||
66
schemas/stream.py
Normal file
66
schemas/stream.py
Normal file
@@ -0,0 +1,66 @@
|
|||||||
|
from typing import List
|
||||||
|
|
||||||
|
from pydantic import BaseModel, ValidationError, field_validator
|
||||||
|
from pydantic import model_validator
|
||||||
|
|
||||||
|
|
||||||
|
class StreamWordsModel(BaseModel):
    """A single recognized word with its time span."""

    text: str
    start_time: float
    end_time: float

    @model_validator(mode="after")
    def check_result(self):
        """Reject a word whose end time precedes its start time."""
        if self.end_time < self.start_time:
            # Validators must raise ValueError (or AssertionError); pydantic
            # turns it into a ValidationError for the caller. Constructing
            # ValidationError directly from a message is not supported.
            raise ValueError("end-time 小于 start-time, error")
        return self
|
||||||
|
|
||||||
|
|
||||||
|
class StreamDataModel(BaseModel):
    """One streaming recognition result: text, language, time span and words."""

    text: str
    language: str
    final_result: bool
    para_seq: int
    start_time: float
    end_time: float
    words: List[StreamWordsModel]

    @model_validator(mode="after")
    def check_result(self):
        """Reject a result whose end time precedes its start time."""
        if self.end_time < self.start_time:
            # Validators must raise ValueError (or AssertionError); pydantic
            # turns it into a ValidationError for the caller. Constructing
            # ValidationError directly from a message is not supported.
            raise ValueError("end-time 小于 start-time, error")
        return self
|
||||||
|
|
||||||
|
|
||||||
|
class StreamResultModel(BaseModel):
    """Wrapper around one streaming result as received from the service."""

    asr_results: StreamDataModel

    @field_validator('asr_results', mode="after")
    @classmethod
    def convert_to_seconds(cls, v: StreamDataModel, values):
        """
        Divide every time value of the result and of its words by 1000.

        Presumably the service reports milliseconds and the rest of the code
        works in seconds — confirm against the service contract.
        The nested model is modified in place and returned.
        ``values`` is the validation info supplied by pydantic; it is unused.
        """
        v.end_time = v.end_time / 1000
        v.start_time = v.start_time / 1000
        for word in v.words:
            word.start_time /= 1000
            word.end_time /= 1000
        return v

    class Config:
        # Re-run validation when attributes are assigned after construction.
        validate_assignment = True
|
||||||
|
|
||||||
|
|
||||||
|
class NonStreamDataModel(BaseModel):
    """One paragraph of a non-streaming recognition result."""

    text: str
    para_seq: int
    start_time: float
    end_time: float

    @model_validator(mode="after")
    def check_result(self):
        """Reject a paragraph whose end time precedes its start time."""
        if self.end_time < self.start_time:
            # Validators must raise ValueError (or AssertionError); pydantic
            # turns it into a ValidationError for the caller. Constructing
            # ValidationError directly from a message is not supported.
            raise ValueError("end-time 小于 start-time, error")
        return self
|
||||||
|
|
||||||
|
|
||||||
|
class NonStreamResultModel(BaseModel):
    """Complete non-streaming recognition result: a list of paragraphs."""

    # Recognized paragraphs.
    contents: List[NonStreamDataModel]
|
||||||
53
scripts/check_dataset_time.py
Normal file
53
scripts/check_dataset_time.py
Normal file
@@ -0,0 +1,53 @@
|
|||||||
|
import os
|
||||||
|
import sys
|
||||||
|
from collections import defaultdict
|
||||||
|
|
||||||
|
import yaml
|
||||||
|
|
||||||
|
|
||||||
|
def main(dataset_dir):
    """
    Check the voice time spans of every dataset under ``dataset_dir``.

    Each sub-directory must contain a ``data.yaml`` with a ``query_data``
    list. For every query the voices are sorted by start time and two
    problems are reported on stdout:

    * err1 – a voice whose start is greater than its end
    * err2 – a voice that starts before the previous voice ends

    Finally the number of datasets, the set of problematic datasets and the
    err2 count per dataset are printed.

    :param dataset_dir: directory holding one sub-directory per dataset
    """
    dataset_names = [
        name
        for name in os.listdir(dataset_dir)
        if os.path.isdir(os.path.join(dataset_dir, name))
    ]

    problem_dirs = set()
    problem_count = defaultdict(int)
    for name in dataset_names:
        # NOTE(review): full_load can construct arbitrary Python objects; fine
        # for trusted local datasets, consider safe_load otherwise.
        with open(os.path.join(dataset_dir, name, "data.yaml"), "r", encoding="utf-8") as f:
            data = yaml.full_load(f)

        for query_i, query in enumerate(data["query_data"]):
            voices = sorted(query["voice"], key=lambda x: x["start"])
            if voices != query["voice"]:
                # The voices of this query were not stored in start-time order.
                print("-----", name)

            # A query without voices simply has nothing to check.
            for voice_i, voice in enumerate(voices):
                if voice["start"] > voice["end"]:
                    print(
                        "err1: %s 第%s个query的第%d个voice的start大于end: %s"
                        % (name, query_i, voice_i, voice["answer"])
                    )
                    problem_dirs.add(name)
                if voice_i > 0 and voice["start"] < voices[voice_i - 1]["end"]:
                    print(
                        "err2: %s 第%s个query的第%d个voice的start小于前一个voice的end: %s"
                        % (name, query_i, voice_i, voice["answer"])
                    )
                    problem_dirs.add(name)
                    problem_count[name] += 1

    print(len(dataset_names))
    print(problem_dirs)
    print(problem_count)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
    # Usage: check_dataset_time.py <dataset directory>
    cli_args = sys.argv[1:]
    if not cli_args:
        print("指定 测试数据集文件夹")
        sys.exit(1)
    main(cli_args[0])
|
||||||
108
scripts/convert_callback_dataset.py
Normal file
108
scripts/convert_callback_dataset.py
Normal file
@@ -0,0 +1,108 @@
|
|||||||
|
import json
|
||||||
|
import os
|
||||||
|
import shutil
|
||||||
|
import sys
|
||||||
|
import zipfile
|
||||||
|
|
||||||
|
import yaml
|
||||||
|
|
||||||
|
"""
|
||||||
|
target
|
||||||
|
{
|
||||||
|
"global": {
|
||||||
|
"lang": ""
|
||||||
|
},
|
||||||
|
"query_data": [
|
||||||
|
"file": "",
|
||||||
|
"duration": 2.0,
|
||||||
|
"voice": [
|
||||||
|
{
|
||||||
|
"answer": "",
|
||||||
|
"start": 0.0,
|
||||||
|
"end": 1.0
|
||||||
|
}
|
||||||
|
]
|
||||||
|
]
|
||||||
|
}
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
def situation_a(meta, dataset_folder, output_folder):
    """
    Convert a "combined" meta description into one dataset per language.

    Expected shape of ``meta``:
    {
        "combined": {
            "en": [
                {
                    "wav": "*.wav",
                    "transcriptions": [
                        {
                            "text": "",
                            "start": 0.0,
                            "end": 1.0
                        }
                    ],
                    "duration": 2.0
                }
            ]
        }
    }

    For every language a folder ``<output_folder>/<lang>`` is created that
    holds the audio files (the ``.mp3`` counterparts of the listed wav
    names, copied from ``dataset_folder``) and a ``data.yaml``; the folder
    is then packed into ``<output_folder>/<lang>.zip``.

    :param meta: parsed meta.json content
    :param dataset_folder: folder that contains the source mp3 files
    :param output_folder: folder that receives the per-language output
    """
    meta = meta["combined"]

    for lang, arr in meta.items():
        print("processing", lang)
        assert len(lang) == 2  # language keys are two-letter codes
        lang_folder = os.path.join(output_folder, lang)
        os.makedirs(lang_folder, exist_ok=True)

        query_data = []
        data = {"global": {"lang": lang}, "query_data": query_data}
        for item in arr:
            os.makedirs(
                os.path.join(lang_folder, os.path.dirname(item["wav"])),
                exist_ok=True,
            )
            # Same relative path as the wav entry, with an .mp3 extension.
            mp3_file = os.path.splitext(item["wav"])[0] + ".mp3"
            shutil.copyfile(
                os.path.join(dataset_folder, mp3_file),
                os.path.join(lang_folder, mp3_file),
            )
            query_data.append(
                {
                    "file": mp3_file,
                    "duration": float(item["duration"]),
                    "voice": [
                        {
                            "answer": v["text"],
                            "start": float(v["start"]),
                            "end": float(v["end"]),
                        }
                        for v in item["transcriptions"]
                    ],
                }
            )

        # allow_unicode=True writes non-ASCII text verbatim, so the file must
        # be opened as UTF-8 instead of the locale's default encoding.
        with open(os.path.join(lang_folder, "data.yaml"), "w", encoding="utf-8") as f:
            yaml.dump(data, f, indent=2, allow_unicode=True, encoding="utf-8")

        # Pack the language folder; archive names are relative to that folder.
        with zipfile.ZipFile(
            os.path.join(output_folder, lang + ".zip"), "w"
        ) as ziper:
            for path, _, files in os.walk(lang_folder):
                relative_dir = os.path.relpath(path, lang_folder)
                for file in files:
                    ziper.write(
                        os.path.join(path, file),
                        os.path.join(relative_dir, file),
                        zipfile.ZIP_DEFLATED,
                    )
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
    # Usage: convert_callback_dataset.py <dataset folder> <output folder>
    if len(sys.argv) < 3:
        print("指定 数据集文件夹路径 输出路径")
        sys.exit(1)
    dataset_folder = sys.argv[1]
    output_folder = sys.argv[2]

    # JSON is UTF-8 by specification; do not depend on the locale encoding,
    # the transcriptions may contain non-ASCII text.
    with open(os.path.join(dataset_folder, "meta.json"), encoding="utf-8") as f:
        meta = json.load(f)
    situation_a(meta, dataset_folder, output_folder)
|
||||||
56
scripts/debug_detailcase.py
Normal file
56
scripts/debug_detailcase.py
Normal file
@@ -0,0 +1,56 @@
|
|||||||
|
import json
|
||||||
|
import sys
|
||||||
|
|
||||||
|
from schemas.dataset import QueryData
|
||||||
|
from schemas.stream import StreamDataModel
|
||||||
|
from utils.evaluator_plus import evaluate_editops
|
||||||
|
|
||||||
|
|
||||||
|
def main(detailcase_file: str):
    """
    Re-evaluate the first case stored in a detail-case JSON file.

    Only predictions flagged as final results are passed to
    ``evaluate_editops``; its return value is printed.
    """
    with open(detailcase_file) as f:
        case = json.load(f)[0]

    all_preds = [StreamDataModel(**raw) for raw in case["preds"]]
    final_preds = [pred for pred in all_preds if pred.final_result]
    label = QueryData(**case["label"])
    print(evaluate_editops(label, final_preds))
|
||||||
|
|
||||||
|
|
||||||
|
def evaluate_from_record(detailcase_file: str, record_path: str):
    """
    Re-evaluate a case using tokens stored in a separate record file.

    The label comes from the first case of ``detailcase_file``; predictions
    and tokens come from ``record_path``. Only final results — and the
    predicted tokens at the same positions — are passed to
    ``evaluate_editops``; its return value is printed.
    """
    with open(detailcase_file) as f:
        case = json.load(f)[0]
    label = QueryData(**case["label"])

    with open(record_path) as f:
        record = json.load(f)
    tokens_pred = record["tokens_pred"]
    tokens_label = record["tokens_label"]
    recognition_results = [
        StreamDataModel(**raw) for raw in record["recognition_results"]
    ]

    # Keep each final result together with the predicted tokens at its index.
    kept = [
        (tokens_pred[index], result)
        for index, result in enumerate(recognition_results)
        if result.final_result
    ]
    tokens_pred = [tokens for tokens, _ in kept]
    recognition_results = [result for _, result in kept]

    outcome = evaluate_editops(
        label,
        recognition_results,
        tokens_pred,
        tokens_label,
    )
    print(outcome)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
    # Usage: debug_detailcase.py <detailcase file>
    cli_args = sys.argv[1:]
    if not cli_args:
        print("请指定 detailcase 文件路径")
        sys.exit(1)
    main(cli_args[0])
    # evaluate_from_record(sys.argv[1], sys.argv[2])
|
||||||
BIN
ssh-keygen
Executable file
BIN
ssh-keygen
Executable file
Binary file not shown.
11
starting_kit/Dockerfile
Normal file
11
starting_kit/Dockerfile
Normal file
@@ -0,0 +1,11 @@
|
|||||||
|
# Image for the starter-kit service (runs main.py).
FROM harbor.4pd.io/inf/base-python3.8-ubuntu:1.1.0

WORKDIR /workspace

# Copy the dependency manifest on its own so the install layer stays cached
# until requirements.txt changes. COPY is used because no ADD-only feature
# (archive extraction, remote URL) is needed.
COPY requirements.txt /workspace/

# --no-cache-dir keeps the pip cache out of the layer, replacing the separate
# "pip cache purge" step.
RUN pip install --no-cache-dir -r ./requirements.txt \
        -i https://nexus.4pd.io/repository/pypi-all/simple \
        --trusted-host nexus.4pd.io \
        --extra-index-url https://mirrors.aliyun.com/pypi/simple/

# Application source changes most often, so it is copied last.
COPY . /workspace/

CMD ["python", "main.py"]
|
||||||
313
starting_kit/main.py
Normal file
313
starting_kit/main.py
Normal file
@@ -0,0 +1,313 @@
|
|||||||
|
import logging
|
||||||
|
import os
|
||||||
|
import threading
|
||||||
|
import time
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
import flask
|
||||||
|
import requests
|
||||||
|
from werkzeug.datastructures import FileStorage
|
||||||
|
|
||||||
|
# Flask application and the flag that records whether a heartbeat loop is running.
app = flask.Flask(__name__)
heartbeat_active = False

# Module logger with its own stream handler; records are not passed on to
# the root logger.
level = logging.INFO

formatter = logging.Formatter(
    "[%(asctime)s] %(levelname)s : %(pathname)s:%(lineno)d - %(message)s",
    "%Y-%m-%d %H:%M:%S",
)

streamHandler = logging.StreamHandler()
streamHandler.setFormatter(formatter)
streamHandler.setLevel(level)

log = logging.getLogger(__name__)
log.setLevel(level)
log.propagate = False
log.addHandler(streamHandler)
|
||||||
|
|
||||||
|
|
||||||
|
def heartbeat(url):
    """
    Post a RUNNING status to ``url`` every 10 seconds, forever.

    Only the first call starts the loop; later calls return immediately
    because ``heartbeat_active`` is already set. Intended to run in a
    daemon thread.

    :param url: heartbeat endpoint
    """
    global heartbeat_active
    if heartbeat_active:
        return
    heartbeat_active = True
    while True:
        try:
            # Without a timeout a stalled connection would block the loop and
            # silently stop all further heartbeats.
            requests.post(url, json={"status": "RUNNING"}, timeout=10)
        except Exception as exc:
            # Best effort: a failed heartbeat must not end the loop, but it
            # should be visible in the log.
            log.warning("heartbeat request failed: %s", exc)
        time.sleep(10)
|
||||||
|
|
||||||
|
|
||||||
|
def asr(
    audio_file: FileStorage,
    language: Optional[str],
    progressCallbackUrl: str,
    taskId: str,
):
    """TODO: read audio_file, call the speech recognition service and report results as they arrive."""

    # ignore BEGIN
    # Used only for local testing of the leaderboard.
    if os.getenv("LOCAL_TEST"):
        return local_test(progressCallbackUrl, taskId)
    # ignore END

    language = "de"

    # Sample word timings (text, start_time, end_time) of one recognition result.
    sample_words = [
        ("最", 6300, 6321),
        ("先", 6321, 6345),
        ("启", 6345, 6350),
        ("动", 6350, 6370),
        ("的", 6370, 6386),
        ("还", 6386, 6421),
        ("是", 6421, 6435),
    ]

    # One incremental result. When status is FINISHED or ERROR the
    # "recognition_results" field must not be sent.
    running_update = {
        "taskId": taskId,
        "status": "RUNNING",
        "recognition_results": {
            "text": "最先启动的还是",
            "final_result": True,
            "para_seq": 0,
            "language": language,
            "start_time": 6300,
            "end_time": 6421,
            "words": [
                {"text": word, "start_time": begin, "end_time": finish}
                for word, begin, finish in sample_words
            ],
        },
    }
    requests.post(progressCallbackUrl, json=running_update)
    # ... all recognition results have been reported

    # Recognition finished.
    finished_update = {
        "taskId": taskId,
        "status": "FINISHED",
    }
    requests.post(progressCallbackUrl, json=finished_update)
|
||||||
|
|
||||||
|
|
||||||
|
@app.post("/predict")
def predict():
    """Accept one recognition task and process it in background threads.

    Reads ``language`` (optional), ``taskId``, ``progressCallbackUrl`` and
    ``heartbeatUrl`` from the form and the upload from the ``file`` field,
    starts a heartbeat thread and a recognition thread, then immediately
    answers ``{"status": "OK"}``.
    """
    form = flask.request.form
    language = form.get("language")
    if language is None:
        # Placeholder branch: the service is expected to detect the language itself.
        "自行判断语种"
    taskId = form["taskId"]
    progressCallbackUrl = form["progressCallbackUrl"]
    heartbeatUrl = form["heartbeatUrl"]

    heartbeat_thread = threading.Thread(
        target=heartbeat, args=(heartbeatUrl,), daemon=True
    )
    heartbeat_thread.start()

    audio_file = flask.request.files["file"]
    # audio_file.stream  # read the file stream
    # audio_file.save("audio.mp3")  # save the file
    # NOTE(review): audio_file belongs to the current request; confirm it is
    # still readable after this handler returns, otherwise read/save it here.
    asr_thread = threading.Thread(
        target=asr,
        args=(audio_file, language, progressCallbackUrl, taskId),
        daemon=True,
    )
    asr_thread.start()
    return flask.jsonify({"status": "OK"})
|
||||||
|
|
||||||
|
|
||||||
|
# ignore BEGIN
|
||||||
|
def local_test(progressCallbackUrl: str, taskId: str):
    """Ignore this function; it is only used for local debugging of the leaderboard.

    Replays the labelled sentences of the first query in a YAML data file as
    if they were recognition results: sentences may be randomly merged,
    inner text chunks randomly dropped, and a partial (``final_result=False``)
    result randomly sent before each final one. Every result is POSTed to
    ``progressCallbackUrl``; a ``FINISHED`` callback is sent at the end.

    Args:
        progressCallbackUrl: URL that receives the JSON callbacks.
        taskId: Identifier echoed back in every callback payload.
    """
    import random
    import re

    import yaml

    def callback(content):
        """POST one result, or the FINISHED notice when ``content`` is None (best effort)."""
        if content is None:
            payload = {"taskId": taskId, "status": "FINISHED"}
        else:
            payload = {
                "taskId": taskId,
                "status": "RUNNING",
                "recognition_results": content,
            }
        try:
            requests.post(progressCallbackUrl, json=payload)
        except Exception as exc:
            # Still best effort, but leave a trace instead of failing silently.
            log.info("callback failed: %s" % exc)

    # NOTE(review): yaml.full_load is more permissive than yaml.safe_load; the
    # file is expected to be trusted local test data.
    with open(
        os.getenv("LOCAL_TEST_DATA_PATH", "../dataset/out/data.yaml")
    ) as f:
        data = yaml.full_load(f)

    voices = data["query_data"][0]["voice"]

    # Delay before the first send.
    first_send_time = random.randint(3, 5)
    send_interval = random.random() * 0
    log.info("首次发送%ss 发送间隔%ss" % (first_send_time, send_interval))
    time.sleep(first_send_time)

    # Randomly merge neighbouring sentences. With fewer than two sentences
    # there is nothing to merge; the original code duplicated a lone sentence
    # and crashed on an empty list.
    if random.random() < 0.3 and len(voices) > 1:
        log.info("将部分句子合并成单句 每次合并的句子不超过3句")
        rand_idx = 0
        rand_sep = [0, len(voices) - 1]
        while rand_sep[rand_idx] + 1 <= rand_sep[rand_idx + 1] - 1:
            rand_cursep = random.randint(
                rand_sep[rand_idx] + 1,
                min(rand_sep[rand_idx + 1] - 1, rand_sep[rand_idx] + 1 + 3),
            )
            rand_sep.insert(rand_idx + 1, rand_cursep)
            rand_idx += 1
        merged_voices = []
        for i, cur_sep in enumerate(rand_sep[:-1]):
            # Fold voices[cur_sep + 1 : next separator] into voices[cur_sep].
            voice = voices[cur_sep]
            for j in range(cur_sep + 1, rand_sep[i + 1]):
                voice["answer"] += voices[j]["answer"]
                voice["end"] = voices[j]["end"]
            merged_voices.append(voice)
        merged_voices.append(voices[rand_sep[-1]])
        voices = merged_voices

    def split_and_keep(text, delimiters):
        """Split ``text`` into runs of non-delimiters and single delimiter characters."""
        # Concatenate (not "|"-join) the escaped delimiters: inside a character
        # class "|" is a literal and would itself become a delimiter.
        char_class = "".join(re.escape(delimiter) for delimiter in delimiters)
        pattern = f"(?:[^{char_class}]+|[{char_class}])"
        return re.findall(pattern, text)

    puncs = [",", ".", "?", "!", ";", ":"]

    para_seq = 0
    for voice in voices:
        answer: str = voice["answer"]
        start_time: float = voice["start"]
        end_time: float = voice["end"]

        # Randomly drop inner chunks; the first and last chunk are always kept.
        chunks = split_and_keep(answer, puncs)
        kept = []
        for i, chunk in enumerate(chunks):
            if i > 0 and i < len(chunks) - 1 and random.random() < 0.15:
                log.info("随机删除word")
                continue
            kept.extend(chunk.split(" "))
        answer = " ".join(kept)
        words = [word.strip() for word in kept]
        words = [word for word in words if len(word) > 0]
        if not words:
            # Empty or whitespace-only answer: nothing to send, and the
            # per-word time split below would divide by zero.
            log.info("skip empty answer")
            continue

        # Spread the sentence duration evenly over the words.
        words_withtime = []
        word_unittime = (end_time - start_time) / len(words)
        for i, word in enumerate(words):
            word_start = start_time + word_unittime * i
            word_end = word_start + word_unittime
            words_withtime.append(
                {
                    "text": word,
                    "start_time": word_start * 1000,
                    "end_time": word_end * 1000,
                }
            )

        # Leading punctuation becomes instantaneous; its time goes to the first real word.
        punc_at = 0
        while punc_at < len(words) and words[punc_at] in puncs:
            punc_at += 1
        if punc_at < len(words):
            words_withtime[punc_at]["start_time"] = words_withtime[0]["start_time"]
            for i in range(0, punc_at):
                words_withtime[i]["start_time"] = words_withtime[0]["start_time"]
                words_withtime[i]["end_time"] = words_withtime[0]["start_time"]

        # Same for trailing punctuation, attached to the last real word.
        punc_at = len(words) - 1
        while punc_at >= 0 and words[punc_at] in puncs:
            punc_at -= 1
        if punc_at >= 0:
            words_withtime[punc_at]["end_time"] = words_withtime[-1]["end_time"]
            for i in range(punc_at + 1, len(words)):
                words_withtime[i]["start_time"] = (
                    words_withtime[-1]["end_time"] + 0.1
                )
                words_withtime[i]["end_time"] = words_withtime[-1]["end_time"] + 0.1

        # Occasionally send a partial result first.
        if random.random() < 0.4 and len(words_withtime) > 1:
            log.info("发送一次final_result=False")
            rand_idx = random.randint(1, len(words_withtime) - 1)
            recognition_result = {
                "text": " ".join(
                    map(lambda x: x["text"], words_withtime[:rand_idx])
                ),
                "final_result": False,
                "para_seq": para_seq,
                "language": "de",
                "start_time": start_time * 1000,
                "end_time": end_time * 1000,
                "words": words_withtime[:rand_idx],
            }
            callback(recognition_result)

        recognition_result = {
            "text": answer,
            "final_result": True,
            "para_seq": para_seq,
            "language": "de",
            "start_time": start_time * 1000,
            "end_time": end_time * 1000,
            "words": words_withtime,
        }
        callback(recognition_result)
        para_seq += 1
        log.info("send %s" % para_seq)

        time.sleep(send_interval)

    callback(None)
|
||||||
|
|
||||||
|
|
||||||
|
# ignore END
|
||||||
|
|
||||||
|
if __name__ == "__main__":
    # Run the Flask development server on every interface, port 80.
    app.run(port=80, host="0.0.0.0")
|
||||||
3
starting_kit/requirements.txt
Normal file
3
starting_kit/requirements.txt
Normal file
@@ -0,0 +1,3 @@
|
|||||||
|
flask
|
||||||
|
requests
|
||||||
|
pyyaml
|
||||||
16
tests/test_callback_editops.py
Normal file
16
tests/test_callback_editops.py
Normal file
@@ -0,0 +1,16 @@
|
|||||||
|
import json
|
||||||
|
|
||||||
|
from schemas.dataset import QueryData
|
||||||
|
from schemas.stream import StreamDataModel
|
||||||
|
from utils.evaluator_plus import evaluate_editops
|
||||||
|
|
||||||
|
# Load the dumped evaluation cases and re-run the edit-operation evaluation
# on the first one.
with open("out/detail_cases.json") as f:
    detail_cases = json.load(f)

detail_case = detail_cases[0]
preds = [StreamDataModel.model_validate(pred) for pred in detail_case["preds"]]
label = QueryData.model_validate(detail_case["label"])

print(evaluate_editops(label, preds))
|
||||||
93
tests/test_cer.py
Normal file
93
tests/test_cer.py
Normal file
@@ -0,0 +1,93 @@
|
|||||||
|
"""
|
||||||
|
f(a, b) 计算 a -> b 的编辑距离,使用的方法是之前asr榜单的方法
|
||||||
|
g(a, b) 计算 a -> b 的编辑距离,使用的是原始的编辑距离计算方法
|
||||||
|
test() 是对拍程序
|
||||||
|
"""
|
||||||
|
|
||||||
|
import random
|
||||||
|
import string
|
||||||
|
from copy import deepcopy
|
||||||
|
from typing import List, Tuple
|
||||||
|
|
||||||
|
import Levenshtein
|
||||||
|
|
||||||
|
|
||||||
|
def mapping(gt: str, dt: str):
    """Split both strings into lists of single characters.

    Returns:
        ``(gt_chars, dt_chars)``.
    """
    return list(gt), list(dt)
|
||||||
|
|
||||||
|
|
||||||
|
def token_mapping(
    tokens_gt: List[str], tokens_dt: List[str]
) -> Tuple[List[str], List[str]]:
    """Pad copies of both token lists with ``None`` according to their edit script.

    For every ``insert`` operation a ``None`` is inserted into the ground-truth
    copy, for every ``delete`` operation into the detected copy. The inputs
    are not modified.
    """
    padded_gt = deepcopy(tokens_gt)
    padded_dt = deepcopy(tokens_dt)
    # Apply the operations back to front so earlier positions stay valid.
    for kind, gt_pos, dt_pos in Levenshtein.editops(padded_gt, padded_dt)[::-1]:
        if kind == "insert":
            padded_gt.insert(gt_pos, None)
        elif kind == "delete":
            padded_dt.insert(dt_pos, None)
    return padded_gt, padded_dt
|
||||||
|
|
||||||
|
|
||||||
|
def cer(tokens_gt_mapping: List[str], tokens_dt_mapping: List[str]):
    """Count edit operations from two ``None``-padded token sequences.

    Returns:
        ``(replace, delete, insert)`` where ``insert`` is the number of
        ``None`` entries in the ground-truth sequence, ``delete`` the number
        in the detected sequence, and ``replace`` is the ground-truth length
        minus ``insert`` and the number of equal pairs. ``delete`` is not
        subtracted from ``replace``.
    """
    insert = len([token for token in tokens_gt_mapping if token is None])
    delete = len([token for token in tokens_dt_mapping if token is None])
    equal = 0
    for token_gt, token_dt in zip(tokens_gt_mapping, tokens_dt_mapping):
        if token_gt == token_dt:
            equal += 1
    replace = len(tokens_gt_mapping) - insert - equal  # - delete
    return replace, delete, insert
|
||||||
|
|
||||||
|
|
||||||
|
def f(a, b):
    """Edit counts for ``a`` -> ``b`` via the aligned-sequence method (``mapping`` -> ``token_mapping`` -> ``cer``)."""
    chars_a, chars_b = mapping(a, b)
    aligned_a, aligned_b = token_mapping(chars_a, chars_b)
    return cer(aligned_a, aligned_b)
|
||||||
|
|
||||||
|
|
||||||
|
def raw(tokens_gt, tokens_dt):
    """Count the ``Levenshtein.editops`` operations between two token lists.

    Returns:
        ``(replace, delete, insert)``.
    """
    gt_copy = deepcopy(tokens_gt)
    dt_copy = deepcopy(tokens_dt)
    tally = {"replace": 0, "delete": 0, "insert": 0}
    for op in Levenshtein.editops(gt_copy, dt_copy):
        kind = op[0]
        if kind in tally:
            tally[kind] += 1
    return tally["replace"], tally["delete"], tally["insert"]
|
||||||
|
|
||||||
|
|
||||||
|
def g(a, b):
    """Edit counts for ``a`` -> ``b`` taken directly from the edit script (``mapping`` -> ``raw``)."""
    chars_a, chars_b = mapping(a, b)
    return raw(chars_a, chars_b)
|
||||||
|
|
||||||
|
|
||||||
|
def check(a, b):
    """Return whether ``f`` and ``g`` agree on ``(a, b)``; print both results when they differ."""
    via_alignment = f(a, b)
    via_editops = g(a, b)
    agree = via_alignment == via_editops
    if not agree:
        print(via_alignment, via_editops)
    return agree
|
||||||
|
|
||||||
|
|
||||||
|
def random_string(length):
    """Return ``length`` random lowercase ASCII letters as one string."""
    alphabet = string.ascii_lowercase
    picked = []
    for _ in range(length):
        picked.append(random.choice(alphabet))
    return "".join(picked)
|
||||||
|
|
||||||
|
|
||||||
|
def test():
    """Compare ``f`` and ``g`` on up to 10000 random string pairs.

    Stops at the first pair on which they disagree and prints that pair.
    """
    remaining = 10000
    while remaining > 0:
        remaining -= 1
        left = random_string(30)
        right = random_string(30)
        if check(left, right):
            continue
        print(left, right)
        break


test()
|
||||||
0
utils/__init__.py
Normal file
0
utils/__init__.py
Normal file
57
utils/asr_ter.py
Normal file
57
utils/asr_ter.py
Normal file
@@ -0,0 +1,57 @@
|
|||||||
|
# copy from
|
||||||
|
# https://gitlab.4pd.io/scene_lab/leaderboard/judge_flows/foundamental_capability/blob/master/utils/asr_ter.py
|
||||||
|
|
||||||
|
|
||||||
|
def calc_ter_speechio(pred, ref, language="zh"):
    """Compute the token error rate of ``pred`` against ``ref`` with the speechio helpers.

    Both texts are normalised with ``textnorm_zh.TextNorm`` and tokenised per
    character before the edit distance is computed.

    Args:
        pred: Predicted text; ``None`` is treated as an empty string.
        ref: Reference text; must be non-empty.
        language: Only ``"zh"`` is supported.

    Returns:
        Dict with ``ter``, ``err_token_cnt`` and ``ref_all_token_cnt``.

    Raises:
        AssertionError: If ``language`` is not ``"zh"`` or ``ref`` is empty.
    """
    assert language == "zh", "Unsupported language %s" % language
    assert ref is not None and ref != "", "Reference script cannot be empty"
    if language != "zh":
        assert False, "Bug, not reachable"
        return None

    from .speechio import error_rate_zh as error_rate
    from .speechio import textnorm_zh as textnorm

    normalizer = textnorm.TextNorm(
        to_banjiao=True,
        to_upper=True,
        to_lower=False,
        remove_fillers=True,
        remove_erhua=True,
        check_chars=False,
        remove_space=False,
        cc_mode="",
    )
    hyp_text = "" if pred is None else pred
    norm_pred = normalizer(hyp_text)
    norm_ref = normalizer(ref)
    tokenizer = "char"
    ref_tokens = error_rate.tokenize_text(norm_ref, tokenizer)
    pred_tokens = error_rate.tokenize_text(norm_pred, tokenizer)
    alignment, score = error_rate.EditDistance(ref_tokens, pred_tokens)
    # presumably correct / substituted / inserted / deleted counts; verify against error_rate_zh
    c, s, i, d = error_rate.CountEdits(alignment)
    ter = error_rate.ComputeTokenErrorRate(c, s, i, d) / 100.0
    return {"ter": ter, "err_token_cnt": s + d + i, "ref_all_token_cnt": s + d + c}
|
||||||
|
|
||||||
|
|
||||||
|
def calc_ter_wjs(pred, ref, language="zh"):
    """Compute the token error rate of ``pred`` against ``ref`` with ``wjs_asr_wer``.

    Args:
        pred: Predicted text; ``None`` is treated as an empty string.
        ref: Reference text; must be non-empty.
        language: Only ``"zh"`` is supported.

    Returns:
        Dict with ``ter`` (1.0 when the calculator reports zero tokens),
        ``err_token_cnt`` and ``ref_all_token_cnt``.

    Raises:
        AssertionError: If ``language`` is not ``"zh"`` or ``ref`` is empty.
    """
    assert language == "zh", "Unsupported language %s" % language
    assert ref is not None and ref != "", "Reference script cannot be empty"
    from . import wjs_asr_wer

    ignore_words = set()
    case_sensitive = False
    split = None
    calculator = wjs_asr_wer.Calculator()

    pred_chars = wjs_asr_wer.characterize(pred if pred is not None else "")
    norm_pred = wjs_asr_wer.normalize(pred_chars, ignore_words, case_sensitive, split)
    ref_chars = wjs_asr_wer.characterize(ref)
    norm_ref = wjs_asr_wer.normalize(ref_chars, ignore_words, case_sensitive, split)

    # NOTE(review): the prediction is passed first and the reference second;
    # confirm this matches the argument order Calculator.calculate expects.
    result = calculator.calculate(norm_pred, norm_ref)
    err_token_cnt = result["ins"] + result["sub"] + result["del"]
    if result["all"] != 0:
        ter = (result["ins"] + result["sub"] + result["del"]) * 1.0 / result["all"]
    else:
        ter = 1.0
    return {
        "ter": ter,
        "err_token_cnt": err_token_cnt,
        "ref_all_token_cnt": result["all"],
    }
|
||||||
224
utils/client.py
Normal file
224
utils/client.py
Normal file
@@ -0,0 +1,224 @@
|
|||||||
|
import json
|
||||||
|
import os
|
||||||
|
import threading
|
||||||
|
import time
|
||||||
|
import traceback
|
||||||
|
from copy import deepcopy
|
||||||
|
from typing import Any, List
|
||||||
|
|
||||||
|
import websocket
|
||||||
|
from pydantic_core import ValidationError
|
||||||
|
from websocket import create_connection
|
||||||
|
|
||||||
|
from schemas.context import ASRContext
|
||||||
|
from schemas.stream import StreamDataModel, StreamResultModel
|
||||||
|
from utils.logger import logger
|
||||||
|
|
||||||
|
IN_TEST = os.getenv("SUBMIT_CONFIG_FILEPATH", None) is None
|
||||||
|
|
||||||
|
|
||||||
|
class Client:
    """Blocking websocket client that streams one audio file to the service under test.

    ``action()`` opens the connection (up to five attempts), starts a receiver
    thread, streams the audio in chunks, waits until the receiver signals
    completion or the deadline passes, and returns timings and predictions.
    """

    def __init__(self, sut_url: str, context: ASRContext) -> None:
        """Store the ``/recognition`` endpoint and a private copy of ``context``."""
        # base_url = "ws://127.0.0.1:5003"
        self.base_url = sut_url + "/recognition"
        logger.info(f"{self.base_url}")
        # Deep copy so this client never mutates the caller's context.
        self.context: ASRContext = deepcopy(context)
        self.connect_num = 0
        self.exception = False
        # Effectively "never" until _send() sets the real deadline.
        self.close_time = 10**50
        self.send_time: List[float] = []
        self.recv_time: List[float] = []
        self.predict_data: List[Any] = []
        # Also serves as the "keep running" flag read by _send/_recv/_close.
        self.success = True

    def action(self):
        """Run the whole session and return the result dict of ``_gen_result``.

        Exits the process with status -1 when all five connection attempts fail.
        """
        connect_success = False
        for i in range(5):
            try:
                self._connect_init()
                connect_success = True
                break
            except Exception as e:
                logger.error(f"第 {i+1} 次连接失败,原因:{e}")
                time.sleep(int(os.getenv("connect_sleep", 10)))
        if not connect_success:
            exit(-1)
        self.trecv = threading.Thread(target=self._recv)
        self.trecv.start()
        self._send()
        self._close()
        return self._gen_result()

    def _connect_init(self):
        """Open the websocket, send the init message and wait for a success reply.

        The wait window is ``end_time`` seconds (environment variable, default 2).

        Raises:
            Exception: If the connection is refused or closed, the wait times
                out, the reply is not JSON, or no success reply arrives.
        """
        end_time = time.time() + float(os.getenv("end_time", 2))
        success = False
        try:
            self.ws = create_connection(self.base_url)
            self.ws.send(json.dumps(self._gen_init_data()))
            while time.time() < end_time and not success:
                data = self.ws.recv()
                logger.info(f"data {data}")
                if len(data) == 0:
                    time.sleep(1)
                    continue
                if isinstance(data, str):
                    try:
                        data = json.loads(data)
                    except Exception:
                        raise Exception("初始化阶段,数据不是 json 字符串格式,终止流程")
                if isinstance(data, dict):
                    success = data.get("success", False)
                    if not success:
                        logger.error(f"初始化失败,返回的结果为 {data},终止流程")
                    else:
                        break
                logger.error("初始化阶段,数据不是 json 字符串格式,终止流程")
                exit(-1)
        # A tuple is required here: `except A or B` evaluates to `except A`,
        # so TimeoutError was never handled by this clause.
        except (websocket.WebSocketConnectionClosedException, TimeoutError) as exc:
            raise Exception("初始化阶段连接中断,终止流程") from exc
        except ConnectionRefusedError as exc:
            raise Exception("初始化阶段,连接失败,等待 10s 后重试,最多重试 5 次") from exc
        if not success:
            raise Exception("初始化阶段 60s 没有返回数据,时间太长,终止流程")
        else:
            logger.info("建立连接成功")
            self.connect_num = 0

    def _send(self):
        """Stream the audio payload in paced chunks, then send the end marker.

        A single send taking longer than ``send_interval`` seconds
        (environment variable, default 60) stops the stream early. On return
        the close deadline is set 20 minutes ahead.
        """
        send_ts = float(os.getenv("send_interval", 60))
        if not self.success:
            return

        # Payload starts 8 bytes after the first b"data" marker in the file.
        with open(self.context.file_path, "rb") as fp:
            wav_data = fp.read()
            meta_length = wav_data.index(b"data") + 8

        try:
            with open(self.context.file_path, "rb") as fp:
                # Skip the wav header.
                fp.read(meta_length)
                # Time the previous chunk was sent (-1: nothing sent yet).
                last_send_time = -1
                while True:
                    now_time = time.perf_counter()
                    if last_send_time == -1:
                        chunk = fp.read(int(self.context.chunk_size))
                    else:
                        # Running late: read several chunks' worth at once.
                        interval_cnt = max(
                            int((now_time - last_send_time) / self.context.wait_time),
                            1,
                        )
                        chunk = fp.read(int(self.context.chunk_size * interval_cnt))
                    if not chunk:
                        break
                    send_time_start = time.perf_counter()
                    self.ws.send(chunk, websocket.ABNF.OPCODE_BINARY)
                    self.send_time.append(send_time_start)
                    last_send_time = send_time_start
                    send_time_end = time.perf_counter()
                    if send_time_end - send_time_start > send_ts:
                        logger.error(f"发送延迟已经超过 {send_ts}s, 终止当前音频发送")
                        break
                    if (sleep_time := self.context.wait_time + now_time - send_time_end) > 0:
                        time.sleep(sleep_time)
            logger.info("当条语音数据发送完成")
            self.ws.send(json.dumps({"end": True}))
            logger.info("2s 后关闭双向连接.")
        except BrokenPipeError:
            logger.error("发送数据出错,被测服务出现故障")
        except Exception as e:
            logger.error(f"Exception: {e}")
            # format_exc() returns the traceback text; print_exc() returns
            # None, which used to be logged as the string "None".
            logger.error(traceback.format_exc())
            logger.error("发送数据失败")
            self.success = False
        # self.close_time = time.perf_counter() + int(os.getenv("api_timeout", 2))
        self.close_time = time.perf_counter() + 20 * 60

    def _recv(self):
        """Receiver thread: collect result messages until told to stop.

        A final result whose start_time, end_time and para_seq are all 0 ends
        the session; it is not stored. A response that fails validation
        terminates the whole process.
        """
        try:
            while self.ws.connected and self.success:
                recv_data = self.ws.recv()
                if isinstance(recv_data, str):
                    if recv_data := str(recv_data):
                        self.recv_time.append(time.perf_counter())
                        # Close only after the final merged result has arrived.
                        recognition_results = StreamResultModel(**json.loads(recv_data)).recognition_results
                        if (
                            recognition_results.final_result
                            and recognition_results.start_time == 0
                            and recognition_results.end_time == 0
                            and recognition_results.para_seq == 0
                        ):
                            self.success = False
                        else:
                            self.predict_data.append(recv_data)
                else:
                    self.success = False
                    raise Exception("返回的结果不是字符串形式")
        except websocket.WebSocketConnectionClosedException:
            # NOTE(review): self.success stays True here, so _close() keeps
            # waiting until close_time; confirm that is intended.
            logger.error("WebSocketConnectionClosedException")
        except ValidationError as e:
            logger.error("返回的结果不符合格式")
            logger.error(f"Exception is {e}")
            os._exit(1)
        except OSError:
            pass
        except Exception:
            logger.error(traceback.format_exc())
            logger.error("处理被测服务返回数据时出错")
            self.success = False

    def _close(self):
        """Wait until the session is flagged finished or the deadline passes, then close the socket."""
        while time.perf_counter() < self.close_time and self.success:
            time.sleep(1)
        try:
            self.ws.close()
        except Exception as e:
            logger.error(f"Exception: {e}")

    def _gen_result(self) -> dict:
        """Parse the collected raw messages and return the result dict.

        Keys: ``fail`` (True when nothing was received), ``send_time``,
        ``recv_time`` and ``predict_data``.
        """
        if not self.predict_data:
            logger.error("没有任何数据返回")
        self.predict_data = [
            StreamResultModel(**json.loads(data)).recognition_results for data in self.predict_data
        ]
        return {
            "fail": not self.predict_data,
            "send_time": self.send_time,
            "recv_time": self.recv_time,
            "predict_data": self.predict_data,
        }

    def _gen_init_data(self) -> dict:
        """Build the init message from the context's audio parameters."""
        return {
            "parameter": {
                "lang": self.context.lang,
                "sample_rate": self.context.sample_rate,
                "channel": self.context.channel,
                "format": self.context.audio_format,
                "bits": self.context.bits,
                "enable_words": self.context.enable_words,
            }
        }
|
||||||
277
utils/client_async.py
Normal file
277
utils/client_async.py
Normal file
@@ -0,0 +1,277 @@
|
|||||||
|
import asyncio
|
||||||
|
import json
|
||||||
|
import os
|
||||||
|
import time
|
||||||
|
import traceback
|
||||||
|
from copy import deepcopy
|
||||||
|
from enum import Enum
|
||||||
|
from typing import Any, List
|
||||||
|
|
||||||
|
import websockets
|
||||||
|
from pydantic_core import ValidationError
|
||||||
|
|
||||||
|
from schemas.context import ASRContext
|
||||||
|
from schemas.stream import StreamResultModel, StreamWordsModel
|
||||||
|
from utils.logger import logger
|
||||||
|
|
||||||
|
IN_TEST = os.getenv("SUBMIT_CONFIG_FILEPATH", None) is None
|
||||||
|
|
||||||
|
|
||||||
|
class STATUS_DATA(str, Enum):
    """String-valued status markers for the init handshake.

    Because the enum also derives from ``str``, members compare equal to
    their plain string values.
    """

    # First init attempt.
    WAITING_FIRST_INIT = "waiting_first_init"
    FIRST_FAIL = "fail"
    # Second init attempt.
    WAITING_SECOND_INIT = "waiting_second_init"
    SECOND_INIT = "second_fail"
    # Third init attempt.
    WAITING_THIRD_INIT = "waiting_third_init"
    THIRD_INIT = "third_fail"
    # Terminal states.
    SUCCESS = "success"
    CLOSED = "closed"
|
||||||
|
|
||||||
|
|
||||||
|
class ClientAsync:
    """Asyncio websocket client that streams one audio file to the service under test.

    ``action()`` runs the session (retrying the init handshake until
    ``fail_count`` reaches 3), then stores predictions, timings and
    punctuation statistics on the context and returns it.
    """

    # Punctuation marks counted per prediction language; other languages use the default.
    _DEFAULT_PUNCTUATIONS = [",", ".", "!", "?"]
    _PUNCTUATIONS = {
        "zh": [",", "。", "!", "?"],
        "ja": ["、", "。", "!", "?"],
        "ar": ["،", ".", "!", "؟"],
        "fa": ["،", ".", "!", "؟"],
        "el": [",", ".", "!", ";"],
        "ti": ["།"],
    }

    def __init__(self, sut_url: str, context: ASRContext, idx: int) -> None:
        """Store the ``/recognition`` endpoint, a private copy of ``context`` and the sample index."""
        # base_url = "ws://127.0.0.1:5003"
        self.base_url = sut_url + "/recognition"
        # Deep copy so this client never mutates the caller's context.
        self.context: ASRContext = deepcopy(context)
        self.idx = idx
        self.fail_count = 0
        self.close_time = 10**50
        self.send_time: List[float] = []
        self.recv_time: List[float] = []
        self.predict_data: List[Any] = []

    async def _sender(
        self, websocket: websockets.WebSocketClientProtocol, send_queue: asyncio.Queue, recv_queue: asyncio.Queue
    ):
        """Send the init message, wait for the handshake verdict, then stream the audio.

        While streaming, 3 seconds after each label's start time has been
        reached in the audio, the total predicted text length is compared
        with the total label text length up to that label;
        ``context.fail_char_contains_rate_num`` counts the shortfalls.
        """
        # Raise the websocket write-buffer limit.
        websocket.transport.set_write_buffer_limits(1024 * 1024 * 1024)

        # Init handshake: announce it, then wait for the receiver's verdict.
        await websocket.send(json.dumps(self._gen_init_data()))
        await send_queue.put(STATUS_DATA.WAITING_FIRST_INIT)
        connect_status = await recv_queue.get()
        if connect_status == STATUS_DATA.FIRST_FAIL:
            return

        # Payload starts 8 bytes after the first b"data" marker in the file.
        with open(self.context.file_path, "rb") as fp:
            wav_data = fp.read()
            meta_length = wav_data.index(b"data") + 8
        try:
            with open(self.context.file_path, "rb") as fp:
                # Skip the wav header.
                fp.read(meta_length)
                wav_time = 0.0
                label_id = 0
                char_contains_rate_checktime = []
                char_contains_rate_checktime_id = 0
                while True:
                    now_time = time.perf_counter()
                    chunk = fp.read(int(self.context.chunk_size))
                    if not chunk:
                        break
                    wav_time += self.context.wait_time
                    try:
                        self.send_time.append(time.perf_counter())
                        await asyncio.wait_for(websocket.send(chunk), timeout=0.08)
                    except asyncio.exceptions.TimeoutError:
                        # A send that does not finish within 80 ms is abandoned.
                        pass
                    # Schedule a check 3 s after each label whose start was reached.
                    while label_id < len(self.context.labels) and wav_time >= self.context.labels[label_id].start:
                        char_contains_rate_checktime.append(now_time + 3.0)
                        label_id += 1
                    predict_text_len = sum(map(lambda x: len(x.text), self.predict_data))
                    while char_contains_rate_checktime_id < len(char_contains_rate_checktime) and \
                            char_contains_rate_checktime[char_contains_rate_checktime_id] <= now_time:
                        label_text_len = sum(
                            map(lambda x: len(x.answer),
                                self.context.labels[:char_contains_rate_checktime_id+1]))
                        if predict_text_len / self.context.char_contains_rate < label_text_len:
                            self.context.fail_char_contains_rate_num += 1
                        char_contains_rate_checktime_id += 1
                    await asyncio.sleep(max(0, self.context.wait_time - (time.perf_counter() - now_time)))
                await websocket.send(json.dumps({"end": True}))
            logger.info(f"第 {self.idx} 条数据,当条语音数据发送完成")
            logger.info(f"第 {self.idx} 条数据,3s 后关闭双向连接.")
            self.close_time = time.perf_counter() + 3
        except websockets.exceptions.ConnectionClosedError:
            logger.error(f"第 {self.idx} 条数据发送过程中,连接断开")
        except Exception:
            # format_exc() returns the traceback text; print_exc() returns
            # None, which used to be logged as the string "None".
            logger.error(traceback.format_exc())
            logger.error(f"第 {self.idx} 条数据,发送数据失败")

    async def _recv(
        self, websocket: websockets.WebSocketClientProtocol, send_queue: asyncio.Queue, recv_queue: asyncio.Queue
    ):
        """Confirm the init handshake, then collect recognition results.

        ``_action`` passes the two queues in swapped order relative to
        ``_sender``, so this method's ``send_queue`` is the queue the sender
        waits on. Failed handshakes are counted by ``_action``, not here.
        """
        await recv_queue.get()
        try:
            await asyncio.wait_for(websocket.recv(), timeout=2)
        except asyncio.exceptions.TimeoutError:
            await send_queue.put(STATUS_DATA.FIRST_FAIL)
            logger.info(f"第 {self.idx} 条数据,初始化阶段, 2s 没收到 success 返回,超时了")
            return
        except Exception as e:
            await send_queue.put(STATUS_DATA.FIRST_FAIL)
            logger.error(f"第 {self.idx} 条数据,初始化阶段, 收到异常:{e}")
            return
        else:
            await send_queue.put(STATUS_DATA.SUCCESS)

        # Receive recognition results until the connection closes.
        try:
            while websocket.open:
                recv_data = await websocket.recv()
                if isinstance(recv_data, str):
                    self.recv_time.append(time.perf_counter())
                    recv_data = str(recv_data)
                    recv_data = json.loads(recv_data)
                    result = StreamResultModel(**recv_data)
                    recognition_results = result.asr_results
                    if (
                        recognition_results.final_result
                        and not recognition_results.language
                        and recognition_results.start_time == 0
                        and recognition_results.end_time == 0
                        and recognition_results.para_seq == 0
                    ):
                        # Final result with no language and all-zero fields: not stored.
                        pass
                    else:
                        self.predict_data.append(recognition_results)
                else:
                    raise Exception("返回的结果不是字符串形式")
        except websockets.exceptions.ConnectionClosedOK:
            pass
        except websockets.exceptions.ConnectionClosedError:
            pass
        except ValidationError as e:
            logger.error(f"第 {self.idx} 条数据,返回的结果不符合格式")
            logger.error(f"Exception is {e}")
            os._exit(1)
        except OSError:
            pass
        except Exception:
            logger.error(traceback.format_exc())
            logger.error(f"第 {self.idx} 条数据,处理被测服务返回数据时出错")

    async def _action(self):
        """Run sender and receiver, retrying until audio was sent or 3 attempts failed."""
        logger.info(f"第 {self.idx} 条数据开始测试")

        while self.fail_count < 3:
            send_queue = asyncio.Queue()
            recv_queue = asyncio.Queue()

            # Start every attempt with clean measurements.
            self.send_time: List[float] = []
            self.recv_time: List[float] = []
            self.predict_data: List[Any] = []

            async with websockets.connect(self.base_url) as websocket:
                send_task = asyncio.create_task(self._sender(websocket, send_queue, recv_queue))
                # Queues are swapped on purpose: the receiver answers on the sender's recv_queue.
                recv_task = asyncio.create_task(self._recv(websocket, recv_queue, send_queue))

                await asyncio.gather(send_task)
                await asyncio.sleep(3)

            await asyncio.gather(recv_task)

            if self.send_time:
                break
            else:
                # Single place where a failed attempt is counted. It used to
                # be counted in _recv as well, so the loop gave up after two
                # attempts instead of three.
                self.fail_count += 1
                logger.info(f"第 {self.idx} 条数据,初始化阶段, 第 {self.fail_count} 次失败, 1s 后重试")
                # Non-blocking pause; time.sleep() would stall the event loop.
                await asyncio.sleep(1)

    def action(self):
        """Run the session to completion and return the updated context."""
        asyncio.run(self._action())
        return self._gen_result()

    def _gen_result(self) -> ASRContext:
        """Store predictions and timings on the context and count punctuation hits.

        For every label a window is built: from the label's end to the next
        label's start, or ``end`` +/- 0.7 for the last label. A label counts
        towards ``pred_punctuation_num`` when some predicted punctuation word
        starts inside the window or ends inside it.
        """
        import bisect  # local import: not among the imports this block can rely on

        if not self.predict_data:
            logger.error(f"第 {self.idx} 条数据,没有任何数据返回")
        self.context.append_preds(self.predict_data, self.send_time, self.recv_time)
        self.context.fail = not self.predict_data

        punctuation_words: List[StreamWordsModel] = []
        for pred in self.predict_data:
            punctuations = self._PUNCTUATIONS.get(pred.language, self._DEFAULT_PUNCTUATIONS)
            for word in pred.words:
                if word.text in punctuations:
                    punctuation_words.append(word)
        start_times = sorted(word.start_time for word in punctuation_words)
        end_times = sorted(word.end_time for word in punctuation_words)

        labels = self.context.labels
        self.context.punctuation_num = len(labels)
        label_n = len(labels)
        for i, label in enumerate(labels):
            if i < label_n - 1:
                label_left = label.end
                label_right = labels[i + 1].start
            else:
                label_left = label.end - 0.7
                label_right = label.end + 0.7

            # Earliest punctuation start that is >= label_left.
            first_start = bisect.bisect_left(start_times, label_left)
            starts_inside = first_start < len(start_times) and start_times[first_start] <= label_right
            # Latest punctuation end that is <= label_right.
            last_end = bisect.bisect_right(end_times, label_right) - 1
            ends_inside = last_end >= 0 and end_times[last_end] >= label_left

            if starts_inside or ends_inside:
                self.context.pred_punctuation_num += 1
        return self.context

    def _gen_init_data(self) -> dict:
        """Build the init message; ``lang`` is always sent as ``None``."""
        return {
            "parameter": {
                "lang": None,
                "sample_rate": self.context.sample_rate,
                "channel": self.context.channel,
                "format": self.context.audio_format,
                "bits": self.context.bits,
                "enable_words": self.context.enable_words,
            }
        }
|
||||||
409
utils/client_callback.py
Normal file
409
utils/client_callback.py
Normal file
@@ -0,0 +1,409 @@
|
|||||||
|
import logging
|
||||||
|
import os
|
||||||
|
import threading
|
||||||
|
import time
|
||||||
|
from typing import Dict, List, Optional
|
||||||
|
|
||||||
|
import requests
|
||||||
|
from flask import Flask, abort, request
|
||||||
|
from pydantic import BaseModel, Field, ValidationError, field_validator
|
||||||
|
|
||||||
|
from schemas.dataset import QueryData
|
||||||
|
from schemas.stream import StreamDataModel
|
||||||
|
from utils.evaluator_plus import evaluate_editops, evaluate_punctuation
|
||||||
|
|
||||||
|
from .logger import log
|
||||||
|
|
||||||
|
# Address of this process, read from the MY_POD_IP environment variable.
# The lookup uses os.environ[...] so a missing variable raises KeyError at import time.
MY_POD_IP = os.environ["MY_POD_IP"]
|
||||||
|
|
||||||
|
|
||||||
|
# Raised to abort the current run, e.g. when a prediction request cannot be
# started or when evaluation is missing its timing data.
class StopException(Exception): ...
|
||||||
|
|
||||||
|
|
||||||
|
class EvaluateResult(BaseModel):
    """Metrics and raw data produced by evaluating one audio query."""

    # Language of the evaluated query.
    lang: str
    # Character error rate.
    cer: float
    # Sentence-first-word alignment: time difference threshold (ms) -> aligned count.
    align_start: Dict[int, int] = Field(
        description="句首字对齐时间差值(ms) -> 对齐数"
    )
    # Sentence-last-word alignment: time difference threshold (ms) -> aligned count.
    align_end: Dict[int, int] = Field(
        description="句尾字对齐时间差值(ms) -> 对齐数"
    )
    # Sum of sentence-first-word distances, in seconds.
    first_word_distance_sum: float = Field(description="句首字距离总和(s)")
    # Sum of sentence-last-word distances, in seconds.
    last_word_distance_sum: float = Field(description="句尾字距离总和(s)")
    # Processing speed (the field description literally reads "translation speed").
    rtf: float = Field(description="翻译速度")
    # Delay until the first result was received, in seconds.
    first_receive_delay: float = Field(description="首包接收延迟(s)")
    # Number of audio queries.
    query_count: int = Field(description="音频数")
    # Number of sentences.
    voice_count: int = Field(description="句子数")
    # Number of predicted punctuation marks.
    pred_punctuation_num: int = Field(description="预测标点数")
    # Number of labelled punctuation marks.
    label_punctuation_num: int = Field(description="标注标点数")
    # Number of predicted sentence punctuation marks.
    pred_sentence_punctuation_num: int = Field(description="预测句子标点数")
    # Number of labelled sentence punctuation marks (the field name spelling is part of the schema).
    label_setence_punctuation_num: int = Field(description="标注句子标点数")
    # Prediction results.
    preds: List[StreamDataModel] = Field(description="预测结果")
    # Labelled (reference) data.
    label: QueryData = Field(description="标注结果")
|
||||||
|
|
||||||
|
|
||||||
|
class ResultModel(BaseModel):
    """Payload of a recognition callback: task id, status, message and an optional result."""

    taskId: str
    status: str
    message: str = Field("")
    recognition_results: Optional[StreamDataModel] = Field(None)

    @field_validator("recognition_results", mode="after")
    def convert_to_seconds(cls, v: Optional[StreamDataModel], values):
        """Divide every timestamp of the validated result by 1000, in place.

        Returns ``v`` unchanged when it is ``None``; otherwise returns the same
        object after scaling its own start/end times and those of each word.
        """
        # Handle the divide-by-1000 conversion here.
        if v is None:
            return v
        v.end_time = v.end_time / 1000
        v.start_time = v.start_time / 1000
        for word in v.words:
            word.start_time /= 1000
            word.end_time /= 1000
        return v
|
||||||
|
|
||||||
|
|
||||||
|
class ClientCallback:
|
||||||
|
def __init__(self, sut_url: str, port: int):
    """Create the callback client and start its HTTP server in a daemon thread.

    Args:
        sut_url: Base URL of the ASR service (e.g. http://asr-service:8080).
        port: Local port the Flask app listens on to receive callbacks.

    The lock, the ``finished`` event and the per-run state are created
    *before* the server thread is started, so a request dispatched
    immediately after start-up never sees a partially constructed object.
    The heartbeat route is bound to ``self.heartbeat``, which therefore has
    to exist on the class.
    """
    self.sut_url = sut_url
    self.port = port

    # Shared state first: the view functions read these attributes.
    self.mutex = threading.Lock()
    self.finished = threading.Event()
    self.product_avaiable = True
    self.reset()

    # Create the Flask app and register the routes.
    self.app = Flask(__name__)
    # Route 1: receives recognition-result callbacks; taskId is a path
    # parameter identifying the task.
    self.app.add_url_rule(
        "/api/asr/batch-callback/<taskId>",
        view_func=self.asr_callback,
        methods=["POST"],
    )
    # Route 2: receives heartbeat reports.
    self.app.add_url_rule(
        "/api/asr-runner/report",
        view_func=self.heartbeat,
        methods=["POST"],
    )

    logging.getLogger("werkzeug").disabled = True
    threading.Thread(
        target=self.app.run, args=("0.0.0.0", port), daemon=True
    ).start()
|
||||||
|
|
||||||
|
def reset(self):
    """Return all per-run state to its initial values and clear ``finished``."""
    # Timing markers of the run.
    self.begin_time = self.end_time = None
    self.first_receive_time = self.last_heartbeat_time = None
    # Run flags and paragraph counter.
    self.app_on = False
    self.para_seq = 0
    self.finished.clear()
    # Error text and collected results.
    self.error: Optional[str] = None
    self.last_recognition_result: Optional[StreamDataModel] = None
    self.recognition_results: List[StreamDataModel] = []
|
||||||
|
|
||||||
|
def asr_callback(self, taskId: str):
    """Flask view handling a recognition-result callback.

    Aborts with 400 when no run is active, and with 404 when the body is not
    JSON, fails ``ResultModel`` validation, has a status other than
    FINISHED/RUNNING, lacks ``recognition_results``, or violates the ordering
    checks below. A FINISHED status stops the run. Otherwise the result is
    stored and ``"ok"`` is returned.
    """
    if self.app_on is False:
        abort(400)

    # Parse JSON silently; a failure yields None.
    payload = request.get_json(silent=True)
    if payload is None:
        abort(404)
    try:
        result = ResultModel.model_validate(payload)
    except ValidationError as exc:
        log.error("asr_callback: 结果格式错误: %s", exc)
        abort(404)

    # Task finished.
    if result.status == "FINISHED":
        with self.mutex:
            self.stop()
        return "ok"
    # Any status other than RUNNING is an error.
    if result.status != "RUNNING":
        log.error(
            "asr_callback: 结果状态错误: %s, message: %s",
            result.status,
            result.message,
        )
        abort(404)

    recognition_result = result.recognition_results
    if recognition_result is None:
        log.error("asr_callback: 结果中没有recognition_results字段")
        abort(404)

    def reject(reason: str):
        # Log, keep only the first error, stop the run and answer 404.
        log.error(reason)
        if self.error is None:
            self.error = reason
        self.stop()
        abort(404)

    with self.mutex:
        if not self.app_on:
            log.error("asr_callback: 应用已结束")
            abort(400)

        if recognition_result.para_seq < self.para_seq:
            reject(
                "asr_callback: 结果中para_seq小于上一次的: %d < %d"
                % (recognition_result.para_seq, self.para_seq)
            )
        if recognition_result.para_seq > self.para_seq + 1:
            reject(
                "asr_callback: 结果中para_seq大于上一次的+1 "
                "说明存在para_seq = %d没有final_result为True确认"
                % (self.para_seq + 1,)
            )
        previous = self.last_recognition_result
        if previous is not None and recognition_result.start_time < previous.end_time:
            reject(
                "asr_callback: 结果中start_time小于上一次的end_time: %s < %s"
                % (recognition_result.start_time, previous.end_time)
            )

        self.recognition_results.append(recognition_result)
        if recognition_result.final_result is True:
            self.para_seq = recognition_result.para_seq
        # NOTE(review): the nesting of the two statements below relative to the
        # final_result check should be confirmed against the original file.
        if self.last_recognition_result is None:
            self.first_receive_time = time.time()
        self.last_recognition_result = recognition_result

    return "ok"
|
||||||
|
|
||||||
|
"""
|
||||||
|
def heartbeat(self):
|
||||||
|
if self.app_on is False:
|
||||||
|
abort(400)
|
||||||
|
body = request.get_json(silent=True)
|
||||||
|
if body is None:
|
||||||
|
abort(404)
|
||||||
|
status = body.get("status")
|
||||||
|
if status != "RUNNING":
|
||||||
|
message = body.get("message", "")
|
||||||
|
if message:
|
||||||
|
message = ", message: " + message
|
||||||
|
log.error("heartbeat: 状态错误: %s%s", status, message)
|
||||||
|
return "ok"
|
||||||
|
|
||||||
|
with self.mutex:
|
||||||
|
self.last_heartbeat_time = time.time()
|
||||||
|
return "ok"
|
||||||
|
|
||||||
|
"""
|
||||||
|
|
||||||
|
def predict(
    self,
    language: Optional[str],
    audio_file: str,
    audio_duration: float,
    task_id: str,
):
    """Submit one audio file to the service and start the watchdog threads.

    Args:
        language: Language passed through to the service (may be None).
        audio_file: Path of the audio file to upload.
        audio_duration: Duration handed to ``dead_line_check``.
        task_id: Task identifier, also embedded in the callback URL.

    Raises:
        StopException: if a previous run is still active, the service answers
            with a non-200 HTTP code, or the JSON ``status`` is not ``"OK"``.
    """
    # The lock guards the check-and-switch of the run state.
    with self.mutex:
        if self.app_on:
            log.error("上一音频尚未完成处理,流程出现异常")
            raise StopException()
        self.reset()
        self.app_on = True

    # POST to <sut_url>/predict; the file handle is closed as soon as the
    # request has completed instead of being left open.
    with open(audio_file, "rb") as audio_fp:
        resp = requests.post(
            self.sut_url + "/predict",
            data={
                "language": language,
                "taskId": task_id,
                "progressCallbackUrl": "http://%s:%d/api/asr/batch-callback/%s"
                % (MY_POD_IP, self.port, task_id),
                "heartbeatUrl": "http://%s:%d/api/asr-runner/report" % (MY_POD_IP, self.port),
            },
            files={"file": (audio_file, audio_fp)},
            timeout=60,
        )

    # Response handling.
    if resp.status_code != 200:
        log.error("/predict接口返回http code %s", resp.status_code)
        raise StopException()
    resp.raise_for_status()

    status = resp.json().get("status")
    if status != "OK":
        log.error("/predict接口返回非OK状态: %s", status)
        raise StopException()
    # Helper threads: processing-deadline watchdog and heartbeat watchdog.
    threading.Thread(
        target=self.dead_line_check, args=(audio_duration,), daemon=True
    ).start()
    threading.Thread(target=self.heartbeat_check, daemon=True).start()
|
||||||
|
|
||||||
|
def dead_line_check(self, audio_duration: float):
    """Watchdog that stops the run when results arrive too late or too slowly.

    Records ``begin_time``, then:
      1. at begin + 10 s, stops with an error if no result has been stored;
      2. until ``ddl = begin + max(audio_duration / 6 + 10, 30)``, repeatedly
         compares the last result's ``end_time`` with a moving expected
         position and marks the product unavailable when it falls behind;
      3. if the run is still unfinished at ``ddl``, marks the product
         unavailable, waits until ``begin + max(audio_duration / 3 + 10, 30)``
         and then stops with an error.

    NOTE(review): how much of the loop body runs while holding ``self.mutex``
    should be confirmed against the original file.
    """
    begin_time = time.time()
    self.begin_time = begin_time

    # Initial check after 10 s.
    self.sleep_to(begin_time + 10)
    with self.mutex:
        if self.last_recognition_result is None:
            error = "首包延迟内未收到返回"
            log.error(error)
            if self.error is None:
                self.error = error
            self.stop()
            return

    # First progress check at 30 s.
    next_checktime = begin_time + 30
    ddl = begin_time + max((audio_duration / 6) + 10, 30)
    while time.time() < ddl:
        self.sleep_to(next_checktime)
        with self.mutex:
            if self.finished.is_set():
                return
            if self.last_recognition_result is None:
                error = "检测追赶线过程中获取最后一次识别结果异常"
                log.error(error)
                if self.error is None:
                    self.error = error
                self.stop()
                return
            last_end_time = self.last_recognition_result.end_time
        # Expected position grows by 5.4 per second elapsed after begin + 30 s.
        expect_end_time = (next_checktime - begin_time - 30) * 5.4
        if last_end_time < expect_end_time:
            log.warning(
                "识别时间位置 %s 被死亡追赶线 %s 已追上,将置为产品不可用",
                last_end_time,
                expect_end_time,
            )
            self.product_avaiable = False
            self.sleep_to(ddl)
            break
        # Next check: one second after the expected position would reach the
        # current progress, capped at ddl.
        next_checktime = last_end_time / 5.4 + begin_time + 30 + 1
        next_checktime = min(next_checktime, ddl)
    with self.mutex:
        if self.finished.is_set():
            return

    log.warning("识别速度rtf低于1/6, 将置为产品不可用")
    self.product_avaiable = False
    self.sleep_to(begin_time + max((audio_duration / 3) + 10, 30))
    with self.mutex:
        if self.finished.is_set():
            return
        # The message reports ddl - begin_time, not the later time waited for above.
        error = "处理时间超过ddl %s " % (ddl - begin_time)
        log.error(error)
        if self.error is None:
            self.error = error
        self.stop()
        return
|
||||||
|
|
||||||
|
def heartbeat_check(self):
    """Watchdog loop that stops the run when heartbeats stop arriving.

    Stamps ``last_heartbeat_time`` with the current time, then polls every
    5 s: returns once ``finished`` is set, or records an error and stops the
    run when more than 30 s have passed since the last heartbeat.
    """
    self.last_heartbeat_time = time.time()
    while True:
        with self.mutex:
            if self.finished.is_set():
                break
            if time.time() - self.last_heartbeat_time > 30:
                reason = "asr_runner 心跳超时 %s" % (
                    time.time() - self.last_heartbeat_time
                )
                log.error(reason)
                if self.error is None:
                    self.error = reason
                self.stop()
                break
        # NOTE(review): the sleep is assumed to happen outside the lock.
        time.sleep(5)
|
||||||
|
|
||||||
|
def sleep_to(self, to: float):
    """Sleep until the wall-clock timestamp ``to``; return at once if it has passed."""
    remaining = to - time.time()
    if not remaining <= 0:
        time.sleep(remaining)
|
||||||
|
|
||||||
|
def stop(self):
    """Mark the run as ended: stamp ``end_time``, set ``finished``, clear ``app_on``."""
    stopped_at = time.time()
    self.end_time = stopped_at
    self.finished.set()
    self.app_on = False
|
||||||
|
|
||||||
|
def evaluate(self, query_data: QueryData):
    """Compute the evaluation metrics for the run that just ended.

    Args:
        query_data: Labelled data of the evaluated audio.

    Returns:
        An ``EvaluateResult`` built from the timing data, the punctuation and
        edit-operation metrics, all received results and the label.

    Raises:
        StopException: if the begin, end or first-receive time is missing.
    """
    log.info("开始评估")
    timing_checks = (
        (self.begin_time, "评估流程异常 无开始时间"),
        (self.end_time, "评估流程异常 无结束时间"),
        (self.first_receive_time, "评估流程异常 无首次接收时间"),
    )
    problems = [text for value, text in timing_checks if value is None]
    if problems:
        for text in problems:
            log.error(text)
        raise StopException()

    elapsed = self.end_time - self.begin_time
    rtf = max(elapsed - 10, 0) / query_data.duration
    first_receive_delay = max(self.first_receive_time - self.begin_time, 0)
    voice_count = len(query_data.voice)

    # Keep every received result for the report; metrics use final results only.
    preds = self.recognition_results
    self.recognition_results = [item for item in preds if item.final_result]
    finals = self.recognition_results

    punctuation = evaluate_punctuation(query_data, finals)
    (
        pred_punctuation_num,
        label_punctuation_num,
        pred_sentence_punctuation_num,
        label_setence_punctuation_num,
    ) = punctuation

    editops = evaluate_editops(query_data, finals)
    (
        cer,
        _,
        align_start,
        align_end,
        first_word_distance_sum,
        last_word_distance_sum,
    ) = editops

    if align_start[300] / voice_count < 0.8:
        log.warning(
            "评估结果首字300ms对齐率 %s < 0.8, 将置为产品不可用",
            align_start[300] / voice_count,
        )
        self.product_avaiable = False

    return EvaluateResult(
        lang=query_data.lang,
        cer=cer,
        align_start=align_start,
        align_end=align_end,
        first_word_distance_sum=first_word_distance_sum,
        last_word_distance_sum=last_word_distance_sum,
        rtf=rtf,
        first_receive_delay=first_receive_delay,
        query_count=1,
        voice_count=voice_count,
        pred_punctuation_num=pred_punctuation_num,
        label_punctuation_num=label_punctuation_num,
        pred_sentence_punctuation_num=pred_sentence_punctuation_num,
        label_setence_punctuation_num=label_setence_punctuation_num,
        preds=preds,
        label=query_data,
    )
|
||||||
445
utils/evaluate.py
Normal file
445
utils/evaluate.py
Normal file
@@ -0,0 +1,445 @@
|
|||||||
|
import os
|
||||||
|
import subprocess
|
||||||
|
from collections import defaultdict
|
||||||
|
from typing import Dict, List
|
||||||
|
|
||||||
|
from utils import asr_ter
|
||||||
|
from utils.logger import logger
|
||||||
|
|
||||||
|
# True when the environment variable "log" parses to the integer 1; used to
# gate logging of intermediate evaluation results.
log_mid_result = int(os.getenv("log", 0)) == 1
|
||||||
|
|
||||||
|
|
||||||
|
class AsrEvaluator:
|
||||||
|
def __init__(self) -> None:
|
||||||
|
self.query_count = 0 # query 数目(语音数目)
|
||||||
|
self.voice_count = 0 # 有开始和结束时间的语音条数(用于 RTF 计算)
|
||||||
|
self.cut_punc = [] # 切分标点符号,需要注意切分的时候根据列表中的顺序进行切分,比如 ... 应该放到 . 之前。
|
||||||
|
# cer 属性
|
||||||
|
self.one_minus_cer = 0 # 每个 query 的 1 - cer 和
|
||||||
|
self.token_count = 0 # 每个 query 的字数/词数和
|
||||||
|
# 句子切分率属性
|
||||||
|
self.miss_count = 0 # 每个 query miss-count 和
|
||||||
|
self.more_count = 0 # 每个 query more-count 和
|
||||||
|
self.cut_count = 0 # 每个 query cut-count 和
|
||||||
|
self.rate = 0 # 每个 query 的 cut-rate 和
|
||||||
|
# detail case
|
||||||
|
self.result = []
|
||||||
|
|
||||||
|
def evaluate(self, eval_result):
    """Hook for evaluating one query; the base implementation does nothing."""
    return None
|
||||||
|
|
||||||
|
def post_evaluate(self):
    """Hook run after evaluation; the base implementation does nothing."""
    return None
|
||||||
|
|
||||||
|
def gen_result(self) -> tuple:
    """Assemble the aggregated metrics and the per-query detail records.

    Returns:
        ``(output_result, detail_case)``: a dict of totals and averaged
        ``*_metrics`` values, and the list stored in ``self.result``.

    Averages are reported as 0.0 when their divisor (``query_count`` or
    ``voice_count``) is zero, instead of raising ``ZeroDivisionError``.
    ``rtf`` / ``rtf_end`` default to 0 when they were never assigned.
    """

    def _mean(total, count):
        # Average that tolerates an empty evaluation run.
        return total / count if count else 0.0

    rtf = getattr(self, "rtf", 0)
    rtf_end = getattr(self, "rtf_end", 0)

    output_result = {
        "query_count": self.query_count,
        "voice_count": self.voice_count,
        "token_cnt": self.token_count,
        "one_minus_cer": self.one_minus_cer,
        "one_minus_cer_metrics": _mean(self.one_minus_cer, self.query_count),
        "miss_count": self.miss_count,
        "more_count": self.more_count,
        "cut_count": self.cut_count,
        "cut_rate": self.rate,
        "cut_rate_metrics": _mean(self.rate, self.query_count),
        "rtf": rtf,
        "rtf_end": rtf_end,
        "rtf_metrics": _mean(rtf, self.voice_count),
        "rtf_end_metrics": _mean(rtf_end, self.voice_count),
    }

    detail_case = self.result
    return output_result, detail_case
|
||||||
|
|
||||||
|
def _get_predict_final_sentences(self, predict_data: List[Dict]) -> List[str]:
|
||||||
|
"""
|
||||||
|
获取 predict data 数据,然后将其中 final 的句子拿出来,放到列表里。
|
||||||
|
"""
|
||||||
|
return [
|
||||||
|
item["recoginition_results"]["text"]
|
||||||
|
for item in predict_data
|
||||||
|
if item["recoginition_results"]["final_result"]
|
||||||
|
]
|
||||||
|
|
||||||
|
def _sentence_final_index(self, sentences: List[str], tokens: List[str], tokenizer="word") -> List[int]:
    """
    Return, for each sentence, the index in ``tokens`` reached after scanning
    that sentence's tokens (intended as the index of the sentence's last token).

    A single cursor moves forward through ``tokens`` across all sentences.
    Sentence tokens that do not occur anywhere in ``tokens`` are skipped.

    NOTE(review): ``Tokenizer`` is not defined in the visible part of this
    file — confirm where it comes from.
    NOTE(review): the forward scan has no bounds check, so a token that only
    occurs before the cursor would run past the end of ``tokens`` — verify.
    NOTE(review): the append is assumed to run once per sentence — confirm
    against the original file.
    """
    token_index_list = []
    token_idx = 0
    for sentence in sentences:
        for token in Tokenizer.tokenize(sentence, tokenizer):
            if token not in tokens:
                continue
            # Advance the shared cursor until it sits on this token.
            while tokens[token_idx] != token:
                token_idx += 1
        token_index_list.append(token_idx)
    return token_index_list
|
||||||
|
|
||||||
|
def _voice_to_cut_sentence(self, voice_sentences: List[str]) -> Dict:
|
||||||
|
"""
|
||||||
|
将数据集的语音片段转换为最小切分单元列表。
|
||||||
|
使用 cut_punc 中的所有 punc 进行依次切分,最后去除掉完全空的内容
|
||||||
|
示例:
|
||||||
|
["你好,你好呀", "你好,我在写抽象的代码逻辑"]
|
||||||
|
->
|
||||||
|
cut_sentences: ["你好", "你好呀", "你好", "我在写抽象的代码逻辑"]
|
||||||
|
cut_sentence_index_list: [1, 3] ("你好呀" 对应 1-idx, "我在写抽象的代码逻辑" 对应 3-idx)
|
||||||
|
"""
|
||||||
|
voice_sentences_result = defaultdict(list)
|
||||||
|
for voice_sentence in voice_sentences:
|
||||||
|
sentence_list = [voice_sentence]
|
||||||
|
sentence_tmp_list = []
|
||||||
|
for punc in self.cut_punc:
|
||||||
|
for sentence in sentence_list:
|
||||||
|
sentence_tmp_list.extend(sentence.split(punc))
|
||||||
|
sentence_list, sentence_tmp_list = sentence_tmp_list, []
|
||||||
|
sentence_list = [item for item in sentence_list if item]
|
||||||
|
# 切分后的句子单元
|
||||||
|
voice_sentences_result["cut_sentences"].extend(sentence_list)
|
||||||
|
# 每个语音单元最后一个字对应的句子单元的索引
|
||||||
|
voice_sentences_result["cut_sentence_index_list"].append(len(voice_sentences_result["cut_sentences"]) - 1)
|
||||||
|
return voice_sentences_result
|
||||||
|
|
||||||
|
def _voice_bytes_index(self, timestamp, sample_rate=16000, bit_depth=16, channels=1):
|
||||||
|
"""
|
||||||
|
timestamp: 时间, 单位秒
|
||||||
|
"""
|
||||||
|
bytes_per_sample = bit_depth // 8
|
||||||
|
return timestamp * sample_rate * bytes_per_sample * channels
|
||||||
|
|
||||||
|
|
||||||
|
class AsrZhEvaluator(AsrEvaluator):
|
||||||
|
"""
|
||||||
|
中文的评估方式
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self):
    """Set up the punctuation lists used to split sentences.

    Longer marks are listed before the shorter marks they contain, because
    splitting follows list order.
    """
    super().__init__()
    zh_marks = ["······", "......", "。", ",", "?", "!", ";", ":"]
    en_marks = ["...", ".", ",", "?", "!", ";", ":"]
    self.cut_zh_punc = zh_marks
    self.cut_en_punc = en_marks
    self.cut_punc = zh_marks + en_marks
|
||||||
|
|
||||||
|
def evaluate(self, eval_result) -> None:
    """Evaluate one query and add its metrics to the accumulators.

    Args:
        eval_result: Mapping with ``"voice"`` (items carrying the labelled
            ``"answer"`` text) and ``"predict_data"`` (prediction items).

    Nothing is returned; totals are updated on ``self`` and one detail record
    is appended to ``self.result``. The record's ``"token_count"`` now holds
    the token count (it previously duplicated ``one_minus_cer``).
    """
    self.query_count += 1
    self.voice_count += len(eval_result["voice"])

    # Labelled text per voice segment (not yet split into sentence units).
    label_voice_sentences = [item["answer"] for item in eval_result["voice"]]
    # Split the segments into sentence units.
    voice_to_cut_info = self._voice_to_cut_sentence(label_voice_sentences)
    label_sentences = voice_to_cut_info["cut_sentences"]
    # Normalise the labelled sentence units.
    label_sentences = [self._sentence_norm(sentence) for sentence in label_sentences]
    if log_mid_result:
        logger.info(f"label_sentences {label_sentences}")

    # Predicted sentence units, then normalised.
    predict_sentences_raw = self._get_predict_final_sentences(eval_result["predict_data"])
    predict_sentences = [self._sentence_norm(sentence) for sentence in predict_sentences_raw]
    if log_mid_result:
        logger.info(f"predict_sentences {predict_sentences}")

    # Align label and prediction tokens by minimum edit distance.
    label_tokens, predict_tokens = self._sentence_transfer("".join(label_sentences), "".join(predict_sentences))

    # CER
    cer_info = self.cer(label_sentences, predict_sentences)
    if log_mid_result:
        logger.info(f"cer_info {cer_info}")
    self.one_minus_cer += cer_info["one_minus_cer"]
    self.token_count += cer_info["token_count"]

    # Sentence-split precision/recall
    cut_info = self.cut_rate(label_sentences, predict_sentences, label_tokens, predict_tokens)
    if log_mid_result:
        logger.info(f"{cut_info['miss_count']}, {cut_info['more_count']}, {cut_info['rate']}")
    self.miss_count += cut_info["miss_count"]
    self.more_count += cut_info["more_count"]
    self.cut_count += cut_info["cut_count"]
    self.rate += cut_info["rate"]

    self.result.append(
        {
            "label_tokens": label_tokens,
            "predict_tokens": predict_tokens,
            "one_minus_cer": cer_info["one_minus_cer"],
            "token_count": cer_info["token_count"],
            "miss_count": cut_info["miss_count"],
            "more_count": cut_info["more_count"],
            "cut_count": cut_info["cut_count"],
            "rate": cut_info["rate"],
        }
    )
|
||||||
|
|
||||||
|
def cer(self, label_sentences, predict_sentences):
    """Score the joined prediction text against the joined label text.

    Both sentence lists are concatenated without a separator (a ``None``
    prediction list counts as empty text) and passed to
    ``asr_ter.calc_ter_speechio``.

    Returns:
        ``{"one_minus_cer": max(1 - ter, 0), "token_count": <reference token count>}``.
    """
    if predict_sentences is not None:
        hypothesis = ''.join(predict_sentences)
    else:
        hypothesis = ''
    reference = ''.join(label_sentences)
    ter_info = asr_ter.calc_ter_speechio(hypothesis, reference)
    return {
        "one_minus_cer": max(1.0 - ter_info['ter'], 0),
        "token_count": ter_info['ref_all_token_cnt'],
    }
|
||||||
|
|
||||||
|
def cut_rate(self, label_sentences, predict_sentences, label_tokens, predict_tokens):
    """Compare sentence-boundary positions of label and prediction.

    Boundary indices come from ``self._sentence_final_index``; the score is
    ``max(1 - (miss + 2 * more) / label_boundaries, 0)``.

    When there are no label boundaries the ratio is undefined; instead of
    raising ``ZeroDivisionError`` the rate is 1 if the prediction has no
    extra boundaries either, and 0 otherwise.

    Returns:
        Dict with ``miss_count``, ``more_count``, ``cut_count``, ``rate`` and
        both boundary index sets.
    """
    label_final_index_list = set(self._sentence_final_index(label_sentences, label_tokens))
    pred_final_index_list = set(self._sentence_final_index(predict_sentences, predict_tokens))
    label_sentence_count = len(label_final_index_list)
    miss_count = len(label_final_index_list - pred_final_index_list)
    more_count = len(pred_final_index_list - label_final_index_list)
    if label_sentence_count:
        rate = max(1 - (miss_count + more_count * 2) / label_sentence_count, 0)
    else:
        # No reference boundaries: perfect only if nothing extra was predicted.
        rate = 0 if more_count else 1
    return {
        "miss_count": miss_count,
        "more_count": more_count,
        "cut_count": label_sentence_count,
        "rate": rate,
        "label_final_index_list": label_final_index_list,
        "pred_final_index_list": pred_final_index_list,
    }
|
||||||
|
|
||||||
|
def _sentence_norm(self, sentence, tokenizer="word"):
|
||||||
|
"""
|
||||||
|
对句子进行 norm 操作
|
||||||
|
"""
|
||||||
|
from utils.speechio import textnorm_zh as textnorm
|
||||||
|
|
||||||
|
if tokenizer == "word":
|
||||||
|
normalizer = textnorm.TextNorm(
|
||||||
|
to_banjiao=True,
|
||||||
|
to_upper=True,
|
||||||
|
to_lower=False,
|
||||||
|
remove_fillers=True,
|
||||||
|
remove_erhua=False, # 这里同批量识别不同,改成了 False
|
||||||
|
check_chars=False,
|
||||||
|
remove_space=False,
|
||||||
|
cc_mode="",
|
||||||
|
)
|
||||||
|
return normalizer(sentence)
|
||||||
|
else:
|
||||||
|
logger.error("tokenizer error, not support.")
|
||||||
|
|
||||||
|
def _sentence_transfer(self, label_sentence: str, predict_sentence: str, tokenizer="char"):
|
||||||
|
"""
|
||||||
|
基于最小编辑距离,将 label 和 predict 进行字的位置匹配,并生成转换后的结果
|
||||||
|
args:
|
||||||
|
label: "今天的通话质量不错呀昨天的呢"
|
||||||
|
predict: "今天的通话质量不错昨天呢星期"
|
||||||
|
tokenizer: 分词方式
|
||||||
|
return:
|
||||||
|
label: ["今", "天", "的", "通", "话", "质", "量", "不", "错", "呀", "昨", "天", "的", "呢", None, None]
|
||||||
|
predict: ["今", "天", "的", "通", "话", "质", "量", "不", "错", None, "昨", "天", None, "呢", "星", "期"]
|
||||||
|
"""
|
||||||
|
from utils.speechio import error_rate_zh as error_rate
|
||||||
|
|
||||||
|
if tokenizer == "char":
|
||||||
|
alignment, score = error_rate.EditDistance(
|
||||||
|
error_rate.tokenize_text(label_sentence, tokenizer),
|
||||||
|
error_rate.tokenize_text(predict_sentence, tokenizer),
|
||||||
|
)
|
||||||
|
label_tokens, pred_tokens = [], []
|
||||||
|
for align in alignment:
|
||||||
|
# print(align.__dict__)
|
||||||
|
label_tokens.append(align.ref)
|
||||||
|
pred_tokens.append(align.hyp)
|
||||||
|
return (label_tokens, pred_tokens)
|
||||||
|
else:
|
||||||
|
logger.error("tokenizer 出错了,暂时不支持其它的")
|
||||||
|
|
||||||
|
def _pred_data_transfer(self, predict_data, recv_time):
|
||||||
|
"""
|
||||||
|
predict_data = [
|
||||||
|
{"recoginition_results": {"text": "1", "final_result": False, "para_seq": 0}},
|
||||||
|
{"recoginition_results": {"text": "12", "final_result": False, "para_seq": 0}},
|
||||||
|
{"recoginition_results": {"text": "123", "final_result": True, "para_seq": 0}},
|
||||||
|
{"recoginition_results": {"text": "4", "final_result": False, "para_seq": 0}},
|
||||||
|
{"recoginition_results": {"text": "45", "final_result": False, "para_seq": 0}},
|
||||||
|
{"recoginition_results": {"text": "456", "final_result": True, "para_seq": 0}},
|
||||||
|
]
|
||||||
|
recv_time = [1, 3, 5, 6, 7, 8]
|
||||||
|
|
||||||
|
->
|
||||||
|
|
||||||
|
[
|
||||||
|
[{'text': '1', 'time': 1}, {'text': '12', 'time': 3}, {'text': '123', 'time': 5}],
|
||||||
|
[{'text': '4', 'time': 6}, {'text': '45', 'time': 7}, {'text': '456', 'time': 8}],
|
||||||
|
]
|
||||||
|
"""
|
||||||
|
pred_sentence_info = []
|
||||||
|
pred_sentence_index = 0
|
||||||
|
for predict_item, recv_time_item in zip(predict_data, recv_time):
|
||||||
|
if len(pred_sentence_info) == pred_sentence_index:
|
||||||
|
pred_sentence_info.append([])
|
||||||
|
pred_sentence_info[pred_sentence_index].append(
|
||||||
|
{
|
||||||
|
"text": predict_item["recoginition_results"]["text"],
|
||||||
|
"time": recv_time_item,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
if predict_item["recoginition_results"]["final_result"]:
|
||||||
|
pred_sentence_index += 1
|
||||||
|
return pred_sentence_info
|
||||||
|
|
||||||
|
|
||||||
|
class AsrEnEvaluator(AsrEvaluator):
|
||||||
|
"""
|
||||||
|
英文的评估方式
|
||||||
|
"""
|
||||||
|
|
||||||
|
def evaluate(self, eval_result) -> None:
    """Evaluate one query and add its metrics to the accumulators.

    Args:
        eval_result: Mapping with ``"voice"`` (items carrying the labelled
            ``"answer"`` text) and ``"predict_data"`` (prediction items).

    Nothing is returned; totals are updated on ``self`` and one detail record
    is appended to ``self.result``. The record's ``"token_count"`` now holds
    the token count (it previously duplicated ``one_minus_cer``).
    """
    self.query_count += 1
    self.voice_count += len(eval_result["voice"])

    # Labelled text per voice segment (not yet split into sentence units).
    label_voice_sentences = [item["answer"] for item in eval_result["voice"]]
    # Split the segments into sentence units.
    voice_to_cut_info = self._voice_to_cut_sentence(label_voice_sentences)
    label_sentences = voice_to_cut_info["cut_sentences"]
    # Normalise the labelled sentence units as one batch.
    label_sentences = self._sentence_list_norm(label_sentences)
    if log_mid_result:
        logger.info(f"label_sentences {label_sentences}")

    # Predicted sentence units, then normalised as one batch.
    predict_sentences_raw = self._get_predict_final_sentences(eval_result["predict_data"])
    predict_sentences = self._sentence_list_norm(predict_sentences_raw)
    if log_mid_result:
        logger.info(f"predict_sentences {predict_sentences}")

    # Align label and prediction tokens; sentences are joined with a space.
    label_tokens, predict_tokens = self._sentence_transfer(" ".join(label_sentences), " ".join(predict_sentences))

    # CER computed on the aligned token lists.
    cer_info = self.cer(label_tokens, predict_tokens)
    if log_mid_result:
        logger.info(f"cer_info {cer_info}")
    self.one_minus_cer += cer_info["one_minus_cer"]
    self.token_count += cer_info["token_count"]

    # Sentence-split precision/recall
    cut_info = self.cut_rate(label_sentences, predict_sentences, label_tokens, predict_tokens)
    if log_mid_result:
        logger.info(f"{cut_info['miss_count']}, {cut_info['more_count']}, {cut_info['rate']}")
    self.miss_count += cut_info["miss_count"]
    self.more_count += cut_info["more_count"]
    self.cut_count += cut_info["cut_count"]
    self.rate += cut_info["rate"]

    self.result.append(
        {
            "label_tokens": label_tokens,
            "predict_tokens": predict_tokens,
            "one_minus_cer": cer_info["one_minus_cer"],
            "token_count": cer_info["token_count"],
            "miss_count": cut_info["miss_count"],
            "more_count": cut_info["more_count"],
            "cut_count": cut_info["cut_count"],
            "rate": cut_info["rate"],
        }
    )
|
||||||
|
|
||||||
|
def cer(self, label_tokens, predict_tokens):
    """Compute 1 - error rate from two aligned token lists.

    Positions are compared pairwise: equal tokens are correct, a ``None``
    prediction is a deletion, a ``None`` label is an insertion, anything else
    is a substitution. The error rate is
    ``(sub + del + ins) / (sub + del + correct)``.

    When the reference has no tokens the ratio is undefined; instead of
    raising ``ZeroDivisionError`` the score is 1.0 if there were no
    insertions and 0.0 otherwise.

    Returns:
        ``{"one_minus_cer": <score clamped at 0>, "token_count": sub + del + correct}``.
    """
    s, d, i, c = 0, 0, 0, 0
    for label_token, predict_token in zip(label_tokens, predict_tokens):
        if label_token == predict_token:
            c += 1
        elif predict_token is None:
            d += 1
        elif label_token is None:
            i += 1
        else:
            s += 1
    token_count = s + d + c
    if token_count:
        cer = (s + d + i) / token_count
        one_minus_cer = max(1.0 - cer, 0)
    else:
        # Empty reference: only insertions can be wrong.
        one_minus_cer = 0.0 if i else 1.0
    return {"one_minus_cer": one_minus_cer, "token_count": token_count}
|
||||||
|
|
||||||
|
def cut_rate(self, label_sentences, predict_sentences, label_tokens, predict_tokens):
    """Score how well predicted sentence boundaries match the labelled ones.

    Sentence-final indices are obtained from ``self._sentence_final_index``
    (called with the ``"whitespace"`` tokenizer) for both sides and compared
    as sets.

    Args:
        label_sentences: reference sentences.
        predict_sentences: predicted sentences.
        label_tokens: aligned reference tokens.
        predict_tokens: aligned predicted tokens.

    Returns:
        dict with ``miss_count`` (label boundaries absent from the prediction),
        ``more_count`` (predicted boundaries absent from the label),
        ``cut_count`` (number of label boundaries), ``rate``
        (``max(1 - (miss + 2 * more) / cut_count, 0)``) and the two index sets.
    """
    label_final_index_list = set(self._sentence_final_index(label_sentences, label_tokens, "whitespace"))
    pred_final_index_list = set(self._sentence_final_index(predict_sentences, predict_tokens, "whitespace"))
    label_sentence_count = len(label_final_index_list)
    miss_count = len(label_final_index_list - pred_final_index_list)
    more_count = len(pred_final_index_list - label_final_index_list)

    # Extra boundaries are penalised twice as heavily as missed ones.
    penalty = miss_count + more_count * 2
    if label_sentence_count > 0:
        rate = max(1 - penalty / label_sentence_count, 0)
    else:
        # No labelled boundary: the ratio is undefined. Avoid the
        # ZeroDivisionError and give full marks only when the prediction
        # did not invent any boundary either.
        rate = 1.0 if penalty == 0 else 0
    return {
        "miss_count": miss_count,
        "more_count": more_count,
        "cut_count": label_sentence_count,
        "rate": rate,
        "label_final_index_list": label_final_index_list,
        "pred_final_index_list": pred_final_index_list,
    }
|
||||||
|
|
||||||
|
def _sentence_list_norm(self, sentence_list, tokenizer="whitespace"):
    """Normalise sentences with the bundled ``textnorm_en.py`` script.

    The sentences are written to ``./predict.txt`` as ``<index>\\t<sentence>``
    lines, the script is run with ``--has_key --to_upper`` and its output is
    read back from ``./predict_norm.txt``.

    Args:
        sentence_list: sentences to normalise.
        tokenizer: accepted for interface compatibility; not used here.

    Returns:
        list of normalised sentences. Output lines without a text field are
        skipped, so the result may be shorter than ``sentence_list``.

    Raises:
        subprocess.CalledProcessError: if the script exits with non-zero status.
    """
    pwd = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
    speechio_dir = f"{pwd}/utils/speechio"

    with open('./predict.txt', 'w', encoding='utf-8') as fp:
        for idx, sentence in enumerate(sentence_list):
            fp.write('%s\t%s\n' % (idx, sentence))

    # Run without a shell: passing argv as a list and PYTHONPATH through `env`
    # keeps the call working when the install path contains spaces or shell
    # metacharacters.
    subprocess.run(
        [
            "python",
            f"{speechio_dir}/textnorm_en.py",
            "--has_key",
            "--to_upper",
            "./predict.txt",
            "./predict_norm.txt",
        ],
        env={**os.environ, "PYTHONPATH": speechio_dir},
        check=True,
    )

    sentence_norm = []
    with open('./predict_norm.txt', 'r', encoding='utf-8') as fp:
        for line in fp:
            line_split_result = line.strip().split('\t', 1)
            # A sentence can vanish entirely after normalisation, leaving only
            # the key on the line; such lines are skipped.
            if len(line_split_result) >= 2:
                sentence_norm.append(line_split_result[1])
    return sentence_norm
|
||||||
|
|
||||||
|
def _sentence_transfer(self, label_sentence: str, predict_sentence: str, tokenizer="whitespace"):
    """Align a label sentence with a predicted sentence token by token.

    Both sentences are tokenised and aligned by minimum edit distance; the
    aligned reference and hypothesis tokens are returned position by position.

    Example:
        label:   "HELLO WORLD ARE U OK YEP"
        predict: "HELLO WORLD U ARE U OK YEP"
        returns:
            label:   ["HELLO", "WORLD", None, "ARE", "U", "OK", "YEP"]
            predict: ["HELLO", "WORLD", "U", "ARE", "U", "OK", "YEP"]

    Args:
        label_sentence: reference sentence.
        predict_sentence: predicted sentence.
        tokenizer: tokenisation mode; only ``"whitespace"`` is supported.

    Returns:
        ``(label_tokens, pred_tokens)`` for the whitespace tokenizer. For any
        other tokenizer an error is logged and ``None`` is returned.
    """
    from utils.speechio import error_rate_zh as error_rate

    if tokenizer != "whitespace":
        logger.error("tokenizer 出错了,暂时不支持其它的")
        return None

    reference = error_rate.tokenize_text(label_sentence, tokenizer)
    hypothesis = error_rate.tokenize_text(predict_sentence, tokenizer)
    alignment, _score = error_rate.EditDistance(reference, hypothesis)

    # Walk the alignment once, then split it into the two parallel lists.
    pairs = [(step.ref, step.hyp) for step in alignment]
    label_tokens = [ref for ref, _ in pairs]
    pred_tokens = [hyp for _, hyp in pairs]
    return (label_tokens, pred_tokens)
|
||||||
|
|
||||||
|
def post_evaluate(self) -> Dict:
    """Post-evaluation hook; this implementation does nothing and returns ``None``."""
    return None
|
||||||
195
utils/evaluator.py
Normal file
195
utils/evaluator.py
Normal file
@@ -0,0 +1,195 @@
|
|||||||
|
# coding: utf-8
|
||||||
|
|
||||||
|
import os
|
||||||
|
from collections import Counter, defaultdict
|
||||||
|
from itertools import chain
|
||||||
|
from typing import List
|
||||||
|
|
||||||
|
from schemas.context import ASRContext
|
||||||
|
from utils.logger import logger
|
||||||
|
from utils.metrics import cer, cut_rate, cut_sentence, first_delay
|
||||||
|
from utils.metrics import mean_on_counter, patch_unique_token_count
|
||||||
|
from utils.metrics import revision_delay, text_align, token_mapping
|
||||||
|
from utils.metrics import var_on_counter
|
||||||
|
from utils.tokenizer import TOKENIZER_MAPPING, Tokenizer
|
||||||
|
from utils.update_submit import change_product_available
|
||||||
|
|
||||||
|
# True when SUBMIT_CONFIG_FILEPATH is not set in the environment; enables the
# debug prints in BaseEvaluator.evaluate. The previous form passed a default of
# 1 to os.getenv, so the result could never be None and the flag was always False.
IN_TEST = os.getenv("SUBMIT_CONFIG_FILEPATH") is None
|
||||||
|
|
||||||
|
|
||||||
|
class BaseEvaluator:
    """Accumulates ASR metrics query by query and renders a summary dict.

    ``evaluate`` folds one ``ASRContext`` into the running totals and
    ``gen_result`` turns the totals into the final result dictionary.
    """

    def __init__(self) -> None:
        self.query_count = 0  # number of queries (audio clips)
        self.voice_count = 0  # total number of labels over all queries
        self.fail_count = 0  # number of failed queries
        # first-token delay
        self.first_delay_sum = 0
        self.first_delay_cnt = 0
        # revision delay
        self.revision_delay_sum = 0
        self.revision_delay_cnt = 0
        # patch token statistics
        self.patch_unique_cnt_counter = Counter()
        # text align count
        self.start_time_align_count = 0
        self.end_time_align_count = 0
        self.start_end_count = 0
        # 1-cer
        self.one_minus_cer = 0
        self.token_count = 0
        # 1-cer per language
        self.one_minus_cer_lang = defaultdict(int)
        self.query_count_lang = defaultdict(int)
        # sentence-cut
        self.miss_count = 0
        self.more_count = 0
        self.sentence_count = 0
        self.cut_rate = 0
        # detail-case (the most recently evaluated context)
        self.context = ASRContext()
        # latency
        self.send_interval = []
        self.last_recv_interval = []
        # number of cases whose character-containment rate is below standard
        self.fail_char_contains_rate_num = 0
        # punctuation
        self.punctuation_num = 0
        self.pred_punctuation_num = 0

    @staticmethod
    def _safe_ratio(numerator, denominator, default=0.0):
        """Return ``numerator / denominator``, or ``default`` when the denominator is zero."""
        return numerator / denominator if denominator else default

    def evaluate(self, context: ASRContext):
        """Fold the metrics of one query into the running totals.

        Args:
            context: the finished query, holding labels, predictions and timing data.
        """
        self.query_count += 1
        self.query_count_lang[context.lang] += 1

        voice_count = len(context.labels)
        self.voice_count += voice_count

        self.punctuation_num += context.punctuation_num
        self.pred_punctuation_num += context.pred_punctuation_num

        if not context.fail:
            # first-token delay
            first_delay_sum, first_delay_cnt = first_delay(context)
            self.first_delay_sum += first_delay_sum
            self.first_delay_cnt += first_delay_cnt

            # revision delay
            revision_delay_sum, revision_delay_cnt = revision_delay(context)
            self.revision_delay_sum += revision_delay_sum
            self.revision_delay_cnt += revision_delay_cnt

            # patch token statistics
            counter = patch_unique_token_count(context)
            self.patch_unique_cnt_counter += counter
        else:
            self.fail_count += 1

        self.fail_char_contains_rate_num += context.fail_char_contains_rate_num

        # text align count
        start_time_align_count, end_time_align_count, start_end_count = text_align(context)
        self.start_time_align_count += start_time_align_count
        self.end_time_align_count += end_time_align_count
        self.start_end_count += start_end_count

        # cer, wer: only predictions flagged as final results are scored
        sentences_gt: List[str] = [item.answer for item in context.labels]
        sentences_dt: List[str] = [
            item.recognition_results.text for item in context.preds if item.recognition_results.final_result
        ]
        if IN_TEST:
            print(sentences_gt)
            print(sentences_dt)

        sentences_gt: List[str] = cut_sentence(sentences_gt, TOKENIZER_MAPPING.get(context.lang))
        sentences_dt: List[str] = cut_sentence(sentences_dt, TOKENIZER_MAPPING.get(context.lang))
        if IN_TEST:
            print(sentences_gt)
            print(sentences_dt)

        # norm & tokenize
        tokens_gt: List[List[str]] = Tokenizer.norm_and_tokenize(sentences_gt, context.lang)
        tokens_dt: List[List[str]] = Tokenizer.norm_and_tokenize(sentences_dt, context.lang)
        if IN_TEST:
            print(tokens_gt)
            print(tokens_dt)

        # cer
        tokens_gt_mapping, tokens_dt_mapping = token_mapping(list(chain(*tokens_gt)), list(chain(*tokens_dt)))
        one_minue_cer, token_count = cer(tokens_gt_mapping, tokens_dt_mapping)
        self.one_minus_cer += one_minue_cer
        self.token_count += token_count
        self.one_minus_cer_lang[context.lang] += one_minue_cer

        # cut-rate
        rate, sentence_cnt, miss_cnt, more_cnt = cut_rate(tokens_gt, tokens_dt, tokens_gt_mapping, tokens_dt_mapping)
        self.cut_rate += rate
        self.sentence_count += sentence_cnt
        self.miss_count += miss_cnt
        self.more_count += more_cnt

        # detail-case
        self.context = context

        # latency
        if self.context.send_time_start_end and self.context.recv_time_start_end:
            send_interval = self.context.send_time_start_end[1] - self.context.send_time_start_end[0]
            # Measured from the first send (not the first receive) to the entry at index 1 of the receive times.
            recv_interval = self.context.recv_time_start_end[1] - self.context.send_time_start_end[0]
            self.send_interval.append(send_interval)
            self.last_recv_interval.append(recv_interval)
            logger.info(
                f"""第一次发送时间{self.context.send_time_start_end[0]}, \
最后一次发送时间{self.context.send_time_start_end[-1]}, \
发送间隔 {send_interval},
最后一次接收时间{self.context.recv_time_start_end[-1]}, \
接收间隔 {recv_interval}
"""
            )

    def post_evaluate(self):
        """Hook run after evaluation; a no-op in this class."""
        pass

    def gen_result(self):
        """Build the summary dictionary from the accumulated totals.

        Delay means fall back to 10 when no sample was recorded. Every other
        ratio falls back to 0.0 when its denominator is zero, so an empty run
        or a data set without punctuation no longer raises ZeroDivisionError.

        Side effect: calls ``change_product_available()`` when the mean first
        delay exceeds ``FIRST_DELAY_THRESHOLD`` (environment variable, default
        ``"5"``) or when more than 10% of the voices failed the
        character-containment check.

        Returns:
            dict of aggregated metrics, plus one ``one_minus_cer_<lang>`` entry
            per evaluated language.
        """
        ratio = self._safe_ratio
        result = {
            "query_count": self.query_count,
            "voice_count": self.voice_count,
            "pred_voice_count": self.first_delay_cnt,
            "first_delay_mean": self.first_delay_sum / self.first_delay_cnt if self.first_delay_cnt > 0 else 10,
            "revision_delay_mean": (
                self.revision_delay_sum / self.revision_delay_cnt if self.revision_delay_cnt > 0 else 10
            ),
            "patch_token_mean": mean_on_counter(self.patch_unique_cnt_counter),
            "patch_token_var": var_on_counter(self.patch_unique_cnt_counter),
            "start_time_align_count": self.start_time_align_count,
            "end_time_align_count": self.end_time_align_count,
            "start_time_align_rate": ratio(self.start_time_align_count, self.sentence_count),
            "end_time_align_rate": ratio(self.end_time_align_count, self.sentence_count),
            "start_end_count": self.start_end_count,
            "one_minus_cer": ratio(self.one_minus_cer, self.query_count),
            "token_count": self.token_count,
            "miss_count": self.miss_count,
            "more_count": self.more_count,
            "sentence_count": self.sentence_count,
            "cut_rate": ratio(self.cut_rate, self.query_count),
            "fail_count": self.fail_count,
            "send_interval": self.send_interval,
            "last_recv_interval": self.last_recv_interval,
            "fail_char_contains_rate_num": self.fail_char_contains_rate_num,
            "punctuation_rate": ratio(self.pred_punctuation_num, self.punctuation_num),
        }
        for lang in self.one_minus_cer_lang:
            result["one_minus_cer_" + lang] = \
                self.one_minus_cer_lang[lang] / self.query_count_lang[lang]

        if (
            result["first_delay_mean"]
            > float(os.getenv("FIRST_DELAY_THRESHOLD", "5"))
            or
            ratio(self.fail_char_contains_rate_num, self.voice_count) > 0.1
            # or
            # result["punctuation_rate"] < 0.8
        ):
            change_product_available()
        return result

    def gen_detail_case(self):
        """Return the most recently evaluated context."""
        return self.context
|
||||||
293
utils/evaluator_plus.py
Normal file
293
utils/evaluator_plus.py
Normal file
@@ -0,0 +1,293 @@
|
|||||||
|
from collections import defaultdict
|
||||||
|
from copy import deepcopy
|
||||||
|
from itertools import chain
|
||||||
|
from typing import Dict, List, Tuple
|
||||||
|
|
||||||
|
import Levenshtein
|
||||||
|
|
||||||
|
from schemas.dataset import QueryData
|
||||||
|
from schemas.stream import StreamDataModel, StreamWordsModel
|
||||||
|
from utils.metrics import Tokenizer
|
||||||
|
from utils.metrics_plus import replace_general_punc
|
||||||
|
from utils.tokenizer import TOKENIZER_MAPPING
|
||||||
|
|
||||||
|
|
||||||
|
def evaluate_editops(
    query_data: QueryData, recognition_results: List[StreamDataModel]
) -> Tuple[float, int, Dict[int, int], Dict[int, int], float, float]:
    """Compute CER and first/last-word time alignment for one query.

    Returns, in order: the CER, the number of label sentences, the first-word
    alignment counts, the last-word alignment counts, the sum of first-word
    time differences and the sum of last-word time differences.
    Each alignment dict maps a time-difference threshold to the number of
    label sentences whose difference is within that threshold.

    NOTE(review): raises ZeroDivisionError when the labels yield no tokens, and
    the ``pred_word_time`` lookups look like they would raise IndexError when
    the predictions yield no tokens — confirm callers never pass such input.
    """
    # Work on a copy: the words of every result are replaced further down.
    recognition_results = deepcopy(recognition_results)
    lang = query_data.lang
    voices = query_data.voice
    sentences_pred = [
        recognition_result.text for recognition_result in recognition_results
    ]
    sentences_label = [item.answer for item in voices]

    tokenizer_type = TOKENIZER_MAPPING[lang]
    sentences_pred = replace_general_punc(sentences_pred, tokenizer_type)
    sentences_label = replace_general_punc(sentences_label, tokenizer_type)

    # norm & tokenize
    tokens_pred = Tokenizer.norm_and_tokenize(sentences_pred, lang)
    tokens_label = Tokenizer.norm_and_tokenize(sentences_label, lang)

    # Collect the word texts of all results into one flat list so they can be
    # normalised in a single call.
    normed_words = []
    for recognition_result in recognition_results:
        words = list(map(lambda x: x.text, recognition_result.words))
        normed_words.extend(words)
    normed_words = replace_general_punc(normed_words, tokenizer_type)
    normed_words = Tokenizer.norm(normed_words, lang)

    # Apply the same norm and tokenize steps to the predicted words.
    normed_word_index = 0
    for recognition_result in recognition_results:
        # Slice out the normalised words that belong to this result.
        next_index = normed_word_index + len(recognition_result.words)
        tokens_words = Tokenizer.tokenize(
            normed_words[normed_word_index:next_index], lang
        )
        normed_word_index = next_index
        stream_words: List[StreamWordsModel] = []
        # Every token produced from an original word inherits that word's timestamps.
        for raw_stream_word, tokens_word in zip(
            recognition_result.words, tokens_words
        ):
            for word in tokens_word:
                stream_words.append(
                    StreamWordsModel(
                        text=word,
                        start_time=raw_stream_word.start_time,
                        end_time=raw_stream_word.end_time,
                    )
                )
        recognition_result.words = stream_words

    # Map the words onto the sentence-level tokens so that every token gets timestamps.
    pred_word_time: List[StreamWordsModel] = []
    for token_pred, recognition_result in zip(tokens_pred, recognition_results):
        word_index = 0
        for word in recognition_result.words:
            try:
                # Search forward only; tokens skipped over on the way to the
                # match receive the matching word's timestamps.
                token_index = token_pred.index(word.text, word_index)
                for i in range(word_index, token_index + 1):
                    pred_word_time.append(
                        StreamWordsModel(
                            text=token_pred[i],
                            start_time=word.start_time,
                            end_time=word.end_time,
                        )
                    )
                word_index = token_index + 1
            except ValueError:
                # The word does not occur in the remaining tokens: skip it.
                pass
        # Tokens left over after the last match take the timestamps of the
        # final word, or of the whole result when it has no words.
        if len(recognition_result.words) > 0:
            word = recognition_result.words[-1]
            start_time = word.start_time
            end_time = word.end_time
        else:
            start_time = recognition_result.start_time
            end_time = recognition_result.end_time
        for i in range(word_index, len(token_pred)):
            pred_word_time.append(
                StreamWordsModel(
                    text=token_pred[i],
                    start_time=start_time,
                    end_time=end_time,
                )
            )

    # Record, for every label sentence, the position of its first and last
    # token in the flattened label token list.
    index = 0
    label_firstword_index: List[int] = []
    label_lastword_index: List[int] = []
    for token_label in tokens_label:
        label_firstword_index.append(index)
        index += len(token_label)
        label_lastword_index.append(index - 1)

    # cer
    flat_tokens_pred = list(chain(*tokens_pred))
    flat_tokens_label = list(chain(*tokens_label))
    ops = Levenshtein.editops(flat_tokens_pred, flat_tokens_label)
    insert = len(list(filter(lambda x: x[0] == "insert", ops)))
    delete = len(list(filter(lambda x: x[0] == "delete", ops)))
    replace = len(list(filter(lambda x: x[0] == "replace", ops)))
    cer = (insert + delete + replace) / len(flat_tokens_label)

    # Compute the index of every token after the edit operations are applied.
    # Each "insert" shifts the prediction side and each "delete" shifts the
    # label side — presumably so both sequences share one aligned coordinate
    # system; confirm against the Levenshtein.editops position semantics.
    pred_offset = [0] * (len(flat_tokens_pred) + 1)
    label_offset = [0] * (len(flat_tokens_label) + 1)
    for op in ops:
        if op[0] == "insert":
            pred_offset[op[1]] += 1
        elif op[0] == "delete":
            label_offset[op[2]] += 1
    pred_indexs = [pred_offset[0]]
    for i in range(1, len(flat_tokens_pred)):
        pred_indexs.append(pred_indexs[i - 1] + pred_offset[i] + 1)
    label_indexs = [label_offset[0]]
    for i in range(1, len(flat_tokens_label)):
        label_indexs.append(label_indexs[i - 1] + label_offset[i] + 1)

    # Compute the time matched to the first and last word of every label.
    # Keys are thresholds; they are compared as ``limit / 1000`` below, so they
    # are presumably milliseconds checked against times in seconds — TODO confirm.
    align_start = {100: 0, 200: 0, 300: 0, 500: 0}
    align_end = {100: 0, 200: 0, 300: 0, 500: 0}
    first_word_distance_sum = 0.0
    last_word_distance_sum = 0.0
    for firstword_index, lastword_index, voice in zip(
        label_firstword_index, label_lastword_index, voices
    ):
        # First word: take the first predicted token at or after the label
        # position and also try its left neighbour, keeping the smaller gap.
        label_index = label_indexs[firstword_index]
        label_in_pred_index = upper_bound(label_index, pred_indexs)
        if label_in_pred_index != -1:
            distance = abs(
                voice.start - pred_word_time[label_in_pred_index].start_time
            )
            if label_in_pred_index > 0:
                distance = min(
                    distance,
                    abs(
                        voice.start
                        - pred_word_time[label_in_pred_index - 1].start_time
                    ),
                )
        else:
            # No predicted token at or after the label position: use the last one.
            distance = abs(voice.start - pred_word_time[-1].start_time)
        for limit in align_start.keys():
            if distance <= limit / 1000:
                align_start[limit] += 1
        first_word_distance_sum += distance

        # Last word: take the last predicted token at or before the label
        # position and also try its right neighbour, keeping the smaller gap.
        label_index = label_indexs[lastword_index]
        label_in_pred_index = lower_bound(label_index, pred_indexs)
        if label_in_pred_index != -1:
            distance = abs(
                voice.end - pred_word_time[label_in_pred_index].end_time
            )
            if label_in_pred_index < len(pred_word_time) - 1:
                distance = min(
                    distance,
                    abs(
                        voice.end
                        - pred_word_time[label_in_pred_index + 1].end_time
                    ),
                )
        else:
            # No predicted token at or before the label position: use the first one.
            distance = abs(voice.end - pred_word_time[0].end_time)
        for limit in align_end.keys():
            if distance <= limit / 1000:
                align_end[limit] += 1
        last_word_distance_sum += distance
    return (
        cer,
        len(voices),
        align_start,
        align_end,
        first_word_distance_sum,
        last_word_distance_sum,
    )
|
||||||
|
|
||||||
|
|
||||||
|
def evaluate_punctuation(
    query_data: QueryData, recognition_results: List[StreamDataModel]
) -> Tuple[int, int, int, int]:
    """Evaluate punctuation metrics for one query.

    Returns, in order: the number of predicted words containing punctuation,
    the number of punctuation characters in the labels, the number of label
    sentence ends that have a predicted punctuation nearby, and the number of
    label sentences.
    """
    marks_by_lang = defaultdict(lambda: [",", ".", "!", "?"])
    marks_by_lang.update(
        {
            "zh": [",", "。", "!", "?"],
            "ja": ["、", "。", "!", "?"],
            "ar": ["،", ".", "!", "؟"],
            "fa": ["،", ".", "!", "؟"],
            "el": [",", ".", "!", ";"],
            "ti": ["།"],
            "th": [" ", ",", ".", "!", "?"],
        }
    )
    punctuations = marks_by_lang[query_data.lang]

    # A predicted word counts once if any of its characters is a punctuation mark.
    punctuation_words: List[StreamWordsModel] = [
        word
        for recognition_result in recognition_results
        for word in recognition_result.words
        if any(char in punctuations for char in word.text)
    ]
    punctuation_start_times = sorted(word.start_time for word in punctuation_words)
    punctuation_end_times = sorted(word.end_time for word in punctuation_words)

    voices = query_data.voice
    label_len = len(voices)
    pred_punctuation_num = len(punctuation_words)
    # Labels are counted per character, not per word.
    label_punctuation_num = sum(
        1
        for label_voice in voices
        for char in label_voice.answer
        if char in punctuations
    )

    pred_sentence_punctuation_num = 0
    label_setence_punctuation_num = label_len
    for i, label_voice in enumerate(voices):
        # The search window is the gap up to the next label, or +/- 0.7 around
        # the end of the final label.
        if i >= label_len - 1:
            label_left = label_voice.end - 0.7
            label_right = label_voice.end + 0.7
        else:
            label_left = label_voice.end
            label_right = voices[i + 1].start

        left_in_pred = upper_bound(label_left, punctuation_start_times)
        starts_inside = (
            left_in_pred != -1
            and punctuation_start_times[left_in_pred] <= label_right
        )
        right_in_pred = lower_bound(label_right, punctuation_end_times)
        ends_inside = (
            right_in_pred != -1
            and punctuation_end_times[right_in_pred] >= label_left
        )
        if starts_inside or ends_inside:
            pred_sentence_punctuation_num += 1

    return (
        pred_punctuation_num,
        label_punctuation_num,
        pred_sentence_punctuation_num,
        label_setence_punctuation_num,
    )
|
||||||
|
|
||||||
|
|
||||||
|
def upper_bound(x: float, lst: List[float]) -> int:
    """Binary-search ``lst`` for the first element ``>= x``.

    Returns that element's index, or -1 when no element qualifies.
    """
    lo, hi = 0, len(lst) - 1
    found = -1
    while lo <= hi:
        probe = (lo + hi) // 2
        if lst[probe] >= x:
            # Candidate found; keep looking to the left for an earlier one.
            found, hi = probe, probe - 1
        else:
            lo = probe + 1
    return found
|
||||||
|
|
||||||
|
|
||||||
|
def lower_bound(x: float, lst: List[float]) -> int:
    """Binary-search ``lst`` for the last element ``<= x``.

    Returns that element's index, or -1 when no element qualifies.
    """
    lo, hi = 0, len(lst) - 1
    found = -1
    while lo <= hi:
        probe = (lo + hi) // 2
        if lst[probe] <= x:
            # Candidate found; keep looking to the right for a later one.
            found, lo = probe, probe + 1
        else:
            hi = probe - 1
    return found
|
||||||
151
utils/file.py
Normal file
151
utils/file.py
Normal file
@@ -0,0 +1,151 @@
|
|||||||
|
import json
|
||||||
|
import os
|
||||||
|
import shutil
|
||||||
|
import tarfile
|
||||||
|
import tempfile
|
||||||
|
import zipfile
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
import yaml
|
||||||
|
|
||||||
|
|
||||||
|
def load_json(path: str, raise_for_invalid: bool = False) -> Any:
    """Read the UTF-8 JSON file at ``path`` and return the decoded object.

    When ``raise_for_invalid`` is true, the constants handled by the decoder's
    ``parse_constant`` hook (e.g. ``NaN``) raise ``ValueError`` instead of
    being accepted.
    """

    def _reject_constant(token: str):
        raise ValueError("非法json字符: %s" % token)

    constant_hook = _reject_constant if raise_for_invalid else None
    with open(path, "r", encoding="utf-8") as fp:
        return json.load(fp, parse_constant=constant_hook)
|
||||||
|
|
||||||
|
|
||||||
|
def dump_json(path: str, obj: Any):
    """Serialise ``obj`` as JSON into the file at ``path``.

    The file is UTF-8, indented by four spaces, with non-ASCII characters
    written as-is.
    """
    with open(path, mode="w", encoding="utf-8") as fp:
        json.dump(obj, fp, indent=4, ensure_ascii=False)
|
||||||
|
|
||||||
|
|
||||||
|
def load_yaml(path: str) -> Any:
    """Read the UTF-8 YAML file at ``path`` and return the decoded object.

    NOTE(review): this uses PyYAML's full loader; if the file can come from an
    untrusted source, consider whether ``safe_load`` would be sufficient.
    """
    with open(path, mode="r", encoding="utf-8") as stream:
        document = yaml.load(stream, Loader=yaml.FullLoader)
    return document
|
||||||
|
|
||||||
|
|
||||||
|
def dump_yaml(path: str, obj: Any):
    """Serialise ``obj`` as YAML into the UTF-8 file at ``path``.

    Uses a two-space indent, keeps key order, writes non-ASCII characters
    as-is and uses ``\\n`` line breaks.
    """
    dump_options = {"indent": 2, "allow_unicode": True, "sort_keys": False, "line_break": "\n"}
    with open(path, mode="w", encoding="utf-8") as stream:
        yaml.dump(obj, stream, **dump_options)
|
||||||
|
|
||||||
|
|
||||||
|
def dumps_yaml(obj: Any) -> str:
    """Serialise ``obj`` to a YAML string.

    Uses a two-space indent, keeps key order, writes non-ASCII characters
    as-is and uses ``\\n`` line breaks.
    """
    dump_options = {"indent": 2, "allow_unicode": True, "sort_keys": False, "line_break": "\n"}
    text = yaml.dump(obj, **dump_options)
    return text
|
||||||
|
|
||||||
|
|
||||||
|
def read_file(path: str) -> str:
    """Read the text file at ``path`` and return its content as ``str``.

    The file is decoded as UTF-8 explicitly, like the JSON/YAML helpers in
    this module, so the result does not depend on the process locale.
    """
    with open(path, "r", encoding="utf-8") as f:
        return f.read()
|
||||||
|
|
||||||
|
|
||||||
|
def write_bfile(path: str, data: bytes):
    """Write the bytes ``data`` to the file at ``path``, replacing any existing content."""
    with open(path, mode="wb") as sink:
        sink.write(data)
|
||||||
|
|
||||||
|
|
||||||
|
def write_file(path: str, data: str):
    """Write the string ``data`` to the file at ``path``, replacing any existing content.

    The file is encoded as UTF-8 explicitly, like the JSON/YAML helpers in
    this module, so non-ASCII text is written the same way under any locale.
    """
    with open(path, "w", encoding="utf-8") as f:
        f.write(data)
|
||||||
|
|
||||||
|
|
||||||
|
def tail_file(path: str, tail: int) -> str:
    """Return the last ``tail`` lines of the file at ``path``.

    The file is read from the end in a window that starts at 1024 bytes and
    doubles until it holds more than ``tail`` lines or covers the whole file.

    Args:
        path: file to read (opened in binary mode, decoded with the default
            ``bytes.decode`` codec).
        tail: number of trailing lines wanted; values ``<= 0`` yield ``""``.

    Returns:
        the requested lines joined into one string, line endings included.
    """
    block = 1024
    with open(path, "rb") as f:
        if tail <= 0:
            # lines[-0:] would be the whole window rather than nothing.
            return ""
        f.seek(0, os.SEEK_END)
        filesize = f.tell()
        while True:
            block = min(block, filesize)
            f.seek(filesize - block, os.SEEK_SET)
            lines = f.readlines()
            # Requiring more than `tail` lines guarantees that the first line
            # of the window, which may be cut in half, is never returned.
            if len(lines) > tail or filesize <= block:
                return "".join(line.decode() for line in lines[-tail:])
            block *= 2
|
||||||
|
|
||||||
|
|
||||||
|
def zip_dir(zip_path: str, dirname: str):
    """Pack every file under ``dirname`` into the zip archive ``zip_path``.

    Files are stored deflated under their path relative to ``dirname``.
    Directories that contain no files are not recorded.
    """
    with zipfile.ZipFile(zip_path, "w") as ziper:
        for path, _, files in os.walk(dirname):
            for file in files:
                full_path = os.path.join(path, file)
                # os.path.relpath replaces str.removeprefix, which only exists
                # on Python 3.9+ while the image is built on Python 3.8.
                ziper.write(full_path, os.path.relpath(full_path, dirname), zipfile.ZIP_DEFLATED)
|
||||||
|
|
||||||
|
|
||||||
|
def zip_files(name: str, zipfile_paths: list):
    """Create the zip archive ``name`` from ``(archive name, file path)`` pairs.

    Each file is stored deflated under its given archive name.
    """
    with zipfile.ZipFile(name, mode="w") as archive:
        for entry in zipfile_paths:
            arcname, source_path = entry
            archive.write(source_path, arcname, zipfile.ZIP_DEFLATED)
|
||||||
|
|
||||||
|
|
||||||
|
def zip_strs(name: str, zipfile_strs: list):
    """Create the zip archive ``name`` from ``(file name, content)`` pairs."""
    with zipfile.ZipFile(name, mode="w") as archive:
        for entry in zipfile_strs:
            member_name, payload = entry
            archive.writestr(member_name, payload)
|
||||||
|
|
||||||
|
|
||||||
|
def zip_zipers(name: str, ziper_paths: list):
    """Bundle several archives, files or directories into the zip archive ``name``.

    Args:
        name: path of the zip archive to create.
        ziper_paths: ``(name inside the bundle, source path)`` pairs. A zip
            source is extracted into a directory of that name, a plain file is
            copied under that name and a directory is copied as a tree.
            Sources that do not exist are skipped.

    The content is staged in a temporary directory created next to ``name``.
    The staging directory is removed afterwards, also when an error occurs.
    """
    # The prefix must be a bare file name: a prefix that carries a directory
    # part is joined onto `dir` again and points at a path that may not exist.
    temp_dirname = tempfile.mkdtemp(prefix=os.path.basename(name), dir=os.path.dirname(name))
    try:
        for subname, ziper_path in ziper_paths:
            sub_dirname = os.path.join(temp_dirname, subname)
            if not os.path.exists(ziper_path):
                continue
            if zipfile.is_zipfile(ziper_path):
                # archive: extract it
                os.makedirs(sub_dirname, exist_ok=True)
                unzip_dir(ziper_path, sub_dirname)
            elif os.path.isfile(ziper_path):
                # plain file: make sure a nested target directory exists first
                os.makedirs(os.path.dirname(sub_dirname), exist_ok=True)
                shutil.copyfile(ziper_path, sub_dirname)
            else:
                # directory
                shutil.copytree(ziper_path, sub_dirname)
        zip_dir(name, temp_dirname)
    finally:
        shutil.rmtree(temp_dirname)
|
||||||
|
|
||||||
|
|
||||||
|
def unzip_dir(zip_path: str, dirname: str, catch_exc: bool = True):
    """Extract the zip archive ``zip_path`` into ``dirname``.

    If extraction fails and ``catch_exc`` is true, the error is written to
    ``unzip_error.log`` inside ``dirname`` and the archive itself is copied
    there. With ``catch_exc`` false the exception is re-raised.
    """
    with zipfile.ZipFile(zip_path, mode="r") as archive:
        try:
            archive.extractall(dirname)
        except Exception as exc:
            if not catch_exc:
                raise exc
            error_log = os.path.join(dirname, "unzip_error.log")
            write_file(error_log, "%r" % exc)
            archive_copy = os.path.join(dirname, os.path.basename(zip_path))
            shutil.copyfile(zip_path, archive_copy)
|
||||||
|
|
||||||
|
|
||||||
|
def tar_dir(zip_path: str, dirname: str):
    """Pack every file under ``dirname`` into the gzip-compressed tarball ``zip_path``.

    Files are stored under their path relative to ``dirname``. Directories
    that contain no files are not recorded.
    """
    with tarfile.open(zip_path, "w:gz") as ziper:
        for path, _, files in os.walk(dirname):
            for file in files:
                full_path = os.path.join(path, file)
                # os.path.relpath replaces str.removeprefix, which only exists
                # on Python 3.9+ while the image is built on Python 3.8.
                ziper.add(full_path, os.path.relpath(full_path, dirname))
|
||||||
|
|
||||||
|
|
||||||
|
def untar_dir(zip_path: str, dirname: str):
    """Extract the tar archive ``zip_path`` into ``dirname``.

    NOTE(review): members are extracted without any filtering; confirm that
    the archives handled here always come from a trusted source.
    """
    with tarfile.open(zip_path) as archive:
        archive.extractall(path=dirname)
|
||||||
331
utils/helm.py
Normal file
331
utils/helm.py
Normal file
@@ -0,0 +1,331 @@
|
|||||||
|
# -*- coding: utf-8 -*-
|
||||||
|
import copy
|
||||||
|
import io
|
||||||
|
import json
|
||||||
|
import os
|
||||||
|
import re
|
||||||
|
import tarfile
|
||||||
|
import time
|
||||||
|
from collections import defaultdict
|
||||||
|
from typing import Any, Dict, Optional
|
||||||
|
|
||||||
|
import requests
|
||||||
|
from ruamel.yaml import YAML
|
||||||
|
|
||||||
|
from utils.logger import logger
|
||||||
|
|
||||||
|
# Directory of the SUT helm chart: "helm-chart/sut" under the directory two
# levels above this file.
sut_chart_root = os.path.join(os.path.dirname(os.path.dirname(os.path.abspath(__file__))), "helm-chart", "sut")
# Authorization header built from LEADERBOARD_API_TOKEN; None when the
# variable is unset or empty.
headers = (
    {'Authorization': 'Bearer ' + os.getenv("LEADERBOARD_API_TOKEN")} if os.getenv("LEADERBOARD_API_TOKEN") else None
)
# NOTE(review): a defaultdict created without a default_factory behaves like a
# plain dict (missing keys raise KeyError) — confirm this is intended.
pull_num: defaultdict = defaultdict()
# Job id read from the environment; -1 when JOB_ID is not set.
JOB_ID = int(os.getenv("JOB_ID", "-1"))
# Endpoint URLs read from the environment; None when not set.
LOAD_SUT_URL = os.getenv("LOAD_SUT_URL")
GET_JOB_SUT_INFO_URL = os.getenv("GET_JOB_SUT_INFO_URL")
|
||||||
|
|
||||||
|
|
||||||
|
def apply_env_to_values(values, envs):
    """Apply environment variables to the ``env`` list of a values mapping.

    ``values["env"]`` is a list of ``{"name": ..., "value": ...}`` entries and
    is created when missing. For every ``name -> value`` pair in ``envs`` the
    first entry with that name has its value overwritten; otherwise a new
    entry is appended.

    Args:
        values: values mapping, modified in place.
        envs: mapping of environment variable names to values.

    Returns:
        the same ``values`` object.
    """
    if "env" not in values:
        values["env"] = []
    # Snapshot of the names present before applying; its positions match values["env"].
    existing_names = [entry["name"] for entry in values["env"]]
    for env_name, env_value in envs.items():
        if env_name in existing_names:
            position = existing_names.index(env_name)
            values["env"][position]["value"] = env_value
        else:
            values["env"].append({"name": env_name, "value": env_value})
    return values
|
||||||
|
|
||||||
|
|
||||||
|
def merge_values(base_value, incr_value):
    """Merge ``incr_value`` into ``base_value`` and return the merged result.

    Two dicts are merged key by key, recursively, in place. Two lists are
    concatenated in place. For any other combination ``incr_value`` replaces
    ``base_value``.
    """
    both_dicts = isinstance(base_value, dict) and isinstance(incr_value, dict)
    if both_dicts:
        for key in incr_value:
            if key in base_value:
                base_value[key] = merge_values(base_value[key], incr_value[key])
            else:
                base_value[key] = incr_value[key]
        return base_value

    both_lists = isinstance(base_value, list) and isinstance(incr_value, list)
    if both_lists:
        base_value.extend(incr_value)
        return base_value

    return incr_value
|
||||||
|
|
||||||
|
|
||||||
|
def gen_chart_tarball(docker_image):
    """Pin ``docker_image`` in the SUT chart values and package the chart.

    The values template ``values.yaml.tmpl`` in ``sut_chart_root`` is loaded,
    its ``image.repository`` / ``image.tag`` are filled in, the result is
    written to ``values.yaml`` in the same directory and the whole chart
    directory is packed into an in-memory ``.tar.gz``.

    When the ``GET_IMAGE_HASH_URL`` environment variable is set, that service
    is asked to resolve the image and the image string from its response is
    used. Otherwise ``docker_image`` is split at its last ``:`` into
    repository and tag. Any failure raises.

    Args:
        docker_image (str): docker image reference, ``<repository>:<tag>``.

    Returns:
        tuple[BytesIO, dict]: the chart tarball (positioned at offset 0) and
        the values that were written.

    Raises:
        AssertionError: the hash service answered with an error, or the image
            it returned does not split into three ``:``-separated parts.
        RuntimeError: ``docker_image`` has no ``:`` (no hash service configured).
    """
    # load values template
    with open(os.path.join(sut_chart_root, "values.yaml.tmpl")) as fp:
        yaml = YAML(typ="rt")
        values = yaml.load(fp)
    # update docker_image
    get_image_hash_url = os.getenv("GET_IMAGE_HASH_URL", None)
    logger.info(f"get_image_hash_url: {get_image_hash_url}")
    if get_image_hash_url is not None:
        # Convert tag to hash for docker_image. The image passed by the caller
        # is used: a hard-coded debug image used to overwrite it here, which
        # made every job deploy the same image.
        resp = requests.get(get_image_hash_url, headers=headers, params={"image": docker_image}, timeout=600)

        logger.info(f"resp.text: {resp.text}")
        assert resp.status_code == 200, "Convert tag to hash for docker image failed, API retcode %d" % resp.status_code
        resp = resp.json()
        assert resp["success"], "Convert tag to hash for docker image failed, response: %s" % str(resp)
        # Split from the right into repository plus two further parts; the
        # last two parts are re-joined to form the tag.
        token = resp["data"]["image"].rsplit(":", 2)
        assert len(token) == 3, "Invalid docker image %s" % resp["data"]["image"]
        values["image"]["repository"] = token[0]
        values["image"]["tag"] = ":".join(token[1:])
    else:
        token = docker_image.rsplit(":", 1)
        if len(token) != 2:
            raise RuntimeError("Invalid docker image %s" % docker_image)
        values["image"]["repository"] = token[0]
        values["image"]["tag"] = token[1]
    # output values.yaml
    with open(os.path.join(sut_chart_root, "values.yaml"), "w") as fp:
        yaml = YAML(typ="rt")
        yaml.dump(values, fp)
    # tarball
    tarfp = io.BytesIO()
    with tarfile.open(fileobj=tarfp, mode="w:gz") as tar:
        tar.add(sut_chart_root, arcname=os.path.basename(sut_chart_root), recursive=True)
    tarfp.seek(0)
    logger.debug(f"Generated chart using values: {values}")
    return tarfp, values
|
||||||
|
|
||||||
|
|
||||||
|
def deploy_chart(
    name_suffix,
    readiness_timeout,
    chart_str=None,
    chart_fileobj=None,
    extra_values=None,
    restart_count_limit=3,
    pullimage_count_limit=3,
):
    """Deploy the SUT and block until it reports ready; raise on failure.

    Args:
        name_suffix (str): distinguishes this SUT when one job has several.
        readiness_timeout (int): readiness timeout in seconds.
        chart_str (str, optional): chart URL; when not None, ``chart_fileobj``
            is ignored. Defaults to None.
        chart_fileobj (BytesIO, optional): helm chart tarball, used when
            ``chart_str`` is None. Defaults to None.
        extra_values (dict, optional): additional helm values. Defaults to None.
        restart_count_limit (int, optional): allowed container restarts before
            giving up. Defaults to 3.
        pullimage_count_limit (int, optional): allowed image-pull failures
            before giving up. Defaults to 3.

    Returns:
        tuple[str, str]: the service name used to reach the SUT and the SUT
        name reported by the deploy endpoint.

    Raises:
        AssertionError: neither ``chart_str`` nor ``chart_fileobj`` was given.
        RuntimeError: the deploy request failed, or the readiness check
            (check_sut_ready_from_resp) gave up.
    """
    logger.info(f"Deploying SUT application for JOB {JOB_ID}, name_suffix {name_suffix}, extra_values {extra_values}")
    # deploy
    payload = {
        "job_id": JOB_ID,
        "resource_name": name_suffix,
        "priorityclassname": os.environ.get("priorityclassname"),
    }
    extra_values = {} if not extra_values else extra_values
    payload["values"] = json.dumps(extra_values, ensure_ascii=False)
    if chart_str is not None:
        payload["helm_chart"] = chart_str
        resp = requests.post(LOAD_SUT_URL, data=payload, headers=headers, timeout=600)
    else:
        assert chart_fileobj is not None, "Either chart_str or chart_fileobj should be set"

        logger.info(f"LOAD_SUT_URL: {LOAD_SUT_URL}")
        logger.info(f"payload: {payload}")
        # Log header names only: the values may carry the bearer token.
        logger.info(f"headers: {list(headers) if headers else headers}")

        resp = requests.post(
            LOAD_SUT_URL,
            data=payload,
            headers=headers,
            files=[("helm_chart_file", (name_suffix + ".tgz", chart_fileobj))],
            timeout=600,
        )

    if resp.status_code != 200:
        raise RuntimeError("Failed to deploy application status_code %d %s" % (resp.status_code, resp.text))
    resp = resp.json()
    if not resp["success"]:
        logger.error("Failed to deploy application response %r", resp)
        # Previously execution continued and read resp["data"] of a failed
        # response; fail explicitly instead.
        raise RuntimeError("Failed to deploy application response %r" % (resp,))
    service_name = resp["data"]["service_name"]
    sut_name = resp["data"]["sut_name"]
    logger.info(f"SUT application deployed with service_name {service_name}")
    # waiting for application ready
    running_at = None
    retry_count = 0
    retry_interval = 10
    while True:
        # Emit the progress message on every 20th poll. The counter used to be
        # incremented only inside this branch, so it never left 0.
        if retry_count % 20 == 19:
            logger.info(f"Waiting {retry_interval} seconds to check whether SUT application {service_name} is ready...")
            logger.info("20 retrys log this message again.")
        retry_count += 1
        time.sleep(retry_interval)
        check_result, running_at = check_sut_ready_from_resp(
            service_name,
            running_at,
            readiness_timeout,
            restart_count_limit,
            pullimage_count_limit,
        )
        if check_result:
            break

    logger.info(f"SUT application for JOB {JOB_ID} name_suffix {name_suffix} is ready, service_name {service_name}")
    return service_name, sut_name
|
||||||
|
|
||||||
|
|
||||||
|
def check_sut_ready_from_resp(
    service_name,
    running_at,
    readiness_timeout,
    restart_count_limit,
    pullimage_count_limit,
):
    """Query the job's SUT status once and decide whether every pod is ready.

    Args:
        service_name (str): SUT service name, used in log and error messages.
        running_at (float | None): time at which the SUT was first seen
            Running, or None if that has not happened yet.
        readiness_timeout (int): seconds a Running pod may stay not-Ready.
        restart_count_limit (int): allowed restarts per container.
        pullimage_count_limit (int): allowed polls observing an image-pull
            failure per pod.

    Returns:
        tuple[bool, float | None]: ``(ready, running_at)``; ``running_at`` is
        set the first time a pod is observed Running. Request errors and
        unsuccessful responses yield ``(False, running_at)``.

    Raises:
        RuntimeError: a pod terminated or is in an unexpected phase, a
            container restarted too often, an image could not be pulled, or
            readiness timed out.
    """
    try:
        resp = requests.get(
            f"{GET_JOB_SUT_INFO_URL}/{JOB_ID}",
            headers=headers,
            params={"with_detail": True},
            timeout=600,
        )
    except Exception as e:
        # The exception used to be passed as a stray logging argument with no
        # placeholder, which makes the logging module report a format error.
        logger.warning("Exception occured while getting SUT application %s status: %s", service_name, e)
        return False, running_at
    if resp.status_code != 200:
        logger.warning(f"Get SUT application {service_name} status failed with status_code {resp.status_code}")
        return False, running_at
    resp = resp.json()
    if not resp["success"]:
        logger.warning(f"Get SUT application {service_name} status failed with response {resp}")
        return False, running_at
    if len(resp["data"]["sut"]) == 0:
        logger.warning("Empty SUT application status")
        return False, running_at
    # Log a copy without the "detail" payload.
    resp_data_sut = copy.deepcopy(resp["data"]["sut"])
    for status in resp_data_sut:
        del status["detail"]
    logger.info(f"Got SUT application status: {resp_data_sut}")
    for status in resp["data"]["sut"]:
        if status["phase"] in ["Succeeded", "Failed"]:
            raise RuntimeError(f"Some pods of SUT application {service_name} terminated with status {status}")
        elif status["phase"] in ["Pending", "Unknown"]:
            # Returns before the container statuses below are inspected.
            return False, running_at
        elif status["phase"] != "Running":
            raise RuntimeError(f"Unexcepted pod status {status} of SUT application {service_name}")
        if running_at is None:
            running_at = time.time()
        pod_name = status["pod_name"]
        for ct in status["detail"]["status"]["container_statuses"]:
            if ct["restart_count"] > 0:
                logger.info(f"pod {pod_name} restart count = {ct['restart_count']}")
            if ct["restart_count"] > restart_count_limit:
                raise RuntimeError(f"pod {pod_name} restart too many times(over {restart_count_limit})")
            waiting = ct["state"]["waiting"]
            if (
                waiting is not None
                and "reason" in waiting
                and waiting["reason"] in ["ImagePullBackOff", "ErrImagePull"]
            ):
                # .get() keeps this working whether or not the module-level
                # counter was created with a default factory.
                pull_num[pod_name] = pull_num.get(pod_name, 0) + 1
                # The count was previously written as "{...}" inside a plain
                # %-format string, so it was logged literally.
                logger.info(
                    "pod %s has %d times inspect pulling image info: %s",
                    pod_name,
                    pull_num[pod_name],
                    waiting,
                )
                if pull_num[pod_name] > pullimage_count_limit:
                    raise RuntimeError(f"pod {pod_name} cannot pull image")
        if not status["conditions"]["Ready"]:
            if running_at is not None and time.time() - running_at > readiness_timeout:
                raise RuntimeError(f"SUT Application readiness has exceeded readiness_timeout:{readiness_timeout}s")
            return False, running_at
    return True, running_at
|
||||||
|
|
||||||
|
|
||||||
|
def parse_resource(resource):
    """Convert a resource quantity such as ``"500m"`` or ``"4Gi"`` to a number.

    Args:
        resource: ``-1`` (returned unchanged), a plain number (as YAML yields
            for e.g. ``cpu: 4``), or a string made of a decimal number followed
            by an optional unit suffix.

    Returns:
        ``-1`` for the ``-1`` sentinel, otherwise the quantity as a float in
        base units (decimal suffixes ``m K M G T P E``, binary suffixes
        ``Ki Mi Gi Ti Pi Ei``).

    Raises:
        ValueError: the string does not start with a number or carries an
            unknown unit suffix.
    """
    if resource == -1:
        return -1
    # Unquoted YAML quantities arrive as numbers; re.match would raise
    # TypeError on them.
    if isinstance(resource, (int, float)):
        return float(resource)
    match = re.match(r"([\d\.]+)([mKMGTPENi]*)", resource)
    if match is None:
        raise ValueError(f"Invalid resource quantity: {resource!r}")
    value, unit = match.groups()
    value = float(value)
    unit_mapping = {
        "": 1,
        "m": 1e-3,
        "K": 1e3,
        "M": 1e6,
        "G": 1e9,
        "T": 1e12,
        "P": 1e15,
        "E": 1e18,
        "Ki": 2**10,
        "Mi": 2**20,
        "Gi": 2**30,
        "Ti": 2**40,
        "Pi": 2**50,
        "Ei": 2**60,
    }
    if unit not in unit_mapping:
        raise ValueError(f"Unknown resources unit: {unit}")
    return value * unit_mapping[unit]
|
||||||
|
|
||||||
|
|
||||||
|
def limit_resources(resource):
    """Clamp the CPU and memory limits of a resources mapping.

    A CPU limit above 30 cores is replaced by ``"30"`` and a memory limit
    above 100 GiB by ``"100Gi"``; each adjustment is logged as an error.

    Args:
        resource: mapping that may contain a ``limits`` mapping with ``cpu``
            and/or ``memory`` quantities.

    Returns:
        The same ``resource`` object, modified in place. (The normal path used
        to fall off the end and return None while the early path returned
        the mapping.)
    """
    if "limits" not in resource:
        return resource
    if "cpu" in resource["limits"]:
        cpu_limit = parse_resource(resource["limits"]["cpu"])
        if cpu_limit > 30:
            logger.error("CPU limit exceeded. Adjusting to 30 cores.")
            resource["limits"]["cpu"] = "30"
    if "memory" in resource["limits"]:
        memory_limit = parse_resource(resource["limits"]["memory"])
        if memory_limit > 100 * 2**30:
            logger.error("Memory limit exceeded, adjusting to 100Gi")
            resource["limits"]["memory"] = "100Gi"
    return resource
|
||||||
|
|
||||||
|
|
||||||
|
def consistent_resources(resource):
    """Make ``requests`` and ``limits`` of a resources mapping identical.

    When ``limits`` is present it is assigned to ``requests`` (overwriting any
    existing requests); otherwise an existing ``requests`` is assigned to
    ``limits``. Both keys then refer to the same object. A mapping with
    neither key is left untouched.

    Returns:
        The same ``resource`` object.
    """
    if "limits" in resource:
        resource["requests"] = resource["limits"]
    elif "requests" in resource:
        resource["limits"] = resource["requests"]
    return resource
|
||||||
|
|
||||||
|
|
||||||
|
def resource_check(values: Dict[str, Any]):
    """Validate and complete the GPU-related parts of helm values.

    Values without a positive ``nvidia.com/gpu`` limit are returned unchanged.
    Otherwise the vGPU memory/core limits are set, ``requests`` is filled from
    ``limits`` (cpu and memory only when not already requested), the
    accelerator node selector defaults to ``A100-SXM4-80GBvgpu``, and a
    ``hosttype=vgpu`` toleration is appended.

    Args:
        values: helm values mapping; modified in place.

    Returns:
        The same ``values`` object.

    Raises:
        RuntimeError: the accelerator is not ``A100-SXM4-80GBvgpu`` or the GPU
            count is not 1.
    """
    resources = values.get("resources", {}).get("limits", {})
    if "nvidia.com/gpu" in resources and int(resources["nvidia.com/gpu"]) > 0:
        limits = values["resources"]["limits"]
        limits["nvidia.com/gpumem"] = 8192
        limits["nvidia.com/gpucores"] = 10
        values["resources"]["requests"] = values["resources"].get("requests", {})
        requests_ = values["resources"]["requests"]
        if "cpu" not in requests_ and "cpu" in limits:
            requests_["cpu"] = limits["cpu"]
        if "memory" not in requests_ and "memory" in limits:
            requests_["memory"] = limits["memory"]
        requests_["nvidia.com/gpu"] = limits["nvidia.com/gpu"]
        requests_["nvidia.com/gpumem"] = 8192
        requests_["nvidia.com/gpucores"] = 10

        values["nodeSelector"] = values.get("nodeSelector", {})
        if "contest.4pd.io/accelerator" not in values["nodeSelector"]:
            values["nodeSelector"]["contest.4pd.io/accelerator"] = "A100-SXM4-80GBvgpu"
        gpu_type = values["nodeSelector"]["contest.4pd.io/accelerator"]
        # Compare as int, like the gate above: a quoted "1" passed the gate but
        # was rejected by the raw `!= 1` comparison.
        gpu_num = int(resources["nvidia.com/gpu"])
        if gpu_type != "A100-SXM4-80GBvgpu":
            raise RuntimeError("GPU类型只能为A100-SXM4-80GBvgpu")
        if gpu_num != 1:
            raise RuntimeError("GPU个数只能为1")
        values["tolerations"] = values.get("tolerations", [])
        values["tolerations"].append(
            {
                "key": "hosttype",
                "operator": "Equal",
                "value": "vgpu",
                "effect": "NoSchedule",
            }
        )
    return values
|
||||||
38
utils/leaderboard.py
Normal file
38
utils/leaderboard.py
Normal file
@@ -0,0 +1,38 @@
|
|||||||
|
from utils.request import requests_retry_session
|
||||||
|
import os
|
||||||
|
import json
|
||||||
|
import traceback
|
||||||
|
from utils.logger import logger
|
||||||
|
|
||||||
|
# Headers sent with every leaderboard request; the bearer token is attached
# only when LEADERBOARD_API_TOKEN is set to a non-empty value.
lb_headers = {"Content-Type": "application/json"}
if os.getenv("LEADERBOARD_API_TOKEN"):
    lb_headers.update(Authorization="Bearer " + os.getenv("LEADERBOARD_API_TOKEN"))
|
||||||
|
|
||||||
|
|
||||||
|
def change_product_unavailable() -> None:
    """Report to the leaderboard that the submission's product is unavailable.

    Posts ``{SUBMIT_ID: {"product_avaliable": 0}}`` as JSON to
    ``UPDATE_SUBMIT_URL`` (with a built-in default URL). Best effort: any
    exception is logged and swallowed.
    """
    logger.info("更改为产品不可用...")
    submit_id = str(os.getenv("SUBMIT_ID", -1))
    try:
        session = requests_retry_session()
        update_url = os.getenv("UPDATE_SUBMIT_URL", "http://contest.4pd.io:8080/submit/update")
        body = json.dumps({submit_id: {"product_avaliable": 0}})
        session.post(update_url, data=body, headers=lb_headers)
    except Exception as e:
        logger.error(traceback.format_exc())
        logger.error(f"change product avaliable error, {e}")
|
||||||
|
|
||||||
|
|
||||||
|
def mark_evaluating(task_id) -> None:
    """Report the EVALUATING state of the current job to the leaderboard.

    Posts ``{"task_id": task_id}`` as JSON to ``REGISTER_MARK_TASK_URL`` (with
    a built-in default URL) followed by ``/<JOB_ID>``; ``"-1"`` is used when
    JOB_ID is unset or empty. Best effort: any exception is logged and
    swallowed.
    """
    logger.info("上报EVALUATING状态...")
    job_id = os.getenv('JOB_ID') or "-1"
    base_url = os.getenv("REGISTER_MARK_TASK_URL", "http://contest.4pd.io:8080/job/register_mark_task")
    url = f"{base_url}/{job_id}"
    try:
        session = requests_retry_session()
        body = json.dumps({"task_id": task_id})
        session.post(url, data=body, headers=lb_headers)
    except Exception as e:
        logger.error(traceback.format_exc())
        logger.error(f"mark evaluating error, {e}")
|
||||||
30
utils/logger.py
Normal file
30
utils/logger.py
Normal file
@@ -0,0 +1,30 @@
|
|||||||
|
# -*- coding: utf-8 -*-
|
||||||
|
import logging
|
||||||
|
import os
|
||||||
|
|
||||||
|
# Root logging configuration; the level comes from LOGLEVEL (default INFO).
logging.basicConfig(
    level=os.environ.get("LOGLEVEL", "INFO"),
    datefmt="%Y-%m-%d %H:%M:%S",
    format="%(asctime)s %(name)-12s %(levelname)-4s %(message)s",
)
# General-purpose logger, named after this file's path.
logger = logging.getLogger(__file__)

# Second logger with its own stream handler and a format that includes the
# call site; it does not propagate to the root logger.
level = logging.INFO
formatter = logging.Formatter(
    fmt="[%(asctime)s] %(levelname)s : %(pathname)s:%(lineno)d - %(message)s",
    datefmt="%Y-%m-%d %H:%M:%S",
)

streamHandler = logging.StreamHandler()
streamHandler.setFormatter(formatter)
streamHandler.setLevel(level)

log = logging.getLogger("detailed_logger")
log.setLevel(level)
log.propagate = False
log.addHandler(streamHandler)
|
||||||
320
utils/metrics.py
Normal file
320
utils/metrics.py
Normal file
@@ -0,0 +1,320 @@
|
|||||||
|
# coding: utf-8
|
||||||
|
|
||||||
|
import os
|
||||||
|
from collections import Counter
|
||||||
|
from copy import deepcopy
|
||||||
|
from typing import List, Tuple
|
||||||
|
|
||||||
|
import Levenshtein
|
||||||
|
import numpy as np
|
||||||
|
from schemas.context import ASRContext
|
||||||
|
from utils.logger import logger
|
||||||
|
from utils.tokenizer import Tokenizer, TokenizerType
|
||||||
|
from utils.update_submit import change_product_available
|
||||||
|
|
||||||
|
# True when no submit config path is provided via SUBMIT_CONFIG_FILEPATH;
# the metric functions use it to enable extra debug printing.
IN_TEST = "SUBMIT_CONFIG_FILEPATH" not in os.environ
|
||||||
|
|
||||||
|
|
||||||
|
def text_align(context: ASRContext) -> Tuple:
    """Compare predicted sentence boundaries with the labelled ones.

    A predicted sentence starts at the first result after a final result (or
    at the very first result) and ends at each final result.

    Returns:
        tuple: ``(start_time_align_count, end_time_align_count,
        start_end_count)`` — the number of label start times, and of label end
        times, that have a prediction within 0.3 of them, and the number of
        places where the interleaved predicted start/end sequence goes
        backwards by more than 0.6.
    """
    label_spans = [(label_item.start, label_item.end) for label_item in context.labels]
    label_starts = [span[0] for span in label_spans]
    label_ends = [span[1] for span in label_spans]

    pred_starts, pred_ends = [], []
    awaiting_start = True
    for pred_item in context.preds:
        result = pred_item.recognition_results
        if awaiting_start:
            pred_starts.append(result.start_time)
        if result.final_result:
            pred_ends.append(result.end_time)
        awaiting_start = result.final_result

    # check start0 < end0 < start1 < end1 < start2 < end2 - ...
    if IN_TEST:
        print(pred_starts)
        print(pred_ends)
    timeline = []
    for begin, finish in zip(pred_starts, pred_ends):
        timeline.extend((begin, finish))
    if len(pred_starts) > len(pred_ends):
        # A sentence that started but never got a final result.
        timeline.append(pred_starts[-1])

    start_end_count = 0
    for earlier, later in zip(timeline, timeline[1:]):
        # Allow 600 ms of slack.
        if later < earlier - 0.6:
            logger.error("识别的 start、end 不符合 start0 < end0 < start1 < end1 < start2 < end2 ...")
            logger.error(
                f"当前识别的每个句子开始和结束时间分别为: "
                f"开始时间:{pred_starts}, "
                f"结束时间:{pred_ends}"
            )
            start_end_count += 1
            # change_product_available()

    # A label counts as aligned when some prediction lies within 300 ms of it.
    start_time_align_count = sum(
        1
        for label_time in label_starts
        if any(label_time - 0.3 <= pred_time <= label_time + 0.3 for pred_time in pred_starts)
    )
    end_time_align_count = sum(
        1
        for label_time in label_ends
        if any(label_time - 0.3 <= pred_time <= label_time + 0.3 for pred_time in pred_ends)
    )
    logger.info(
        f"start-time 对齐个数 {start_time_align_count}, "
        f"end-time 对齐个数 {end_time_align_count}"
        f"数据集中句子总数 {len(label_starts)}"
    )
    return start_time_align_count, end_time_align_count, start_end_count
|
||||||
|
|
||||||
|
|
||||||
|
def first_delay(context: ASRContext) -> Tuple:
    """Sum the first-result delay of every predicted sentence.

    For the first result of each sentence the delay is
    ``recv_time - first_send_time - start_time``, where ``first_send_time`` is
    the ``send_time`` of the very first prediction.

    Returns:
        tuple: ``(sum_of_delays, sentence_count)``; ``(0.0, 0)`` when there are
        no predictions at all (this used to raise IndexError).
    """
    if not context.preds:
        return 0.0, 0
    first_send_time = context.preds[0].send_time
    first_delay_list = []
    sentence_start = True
    for pred_context in context.preds:
        if sentence_start:
            sentence_begin_time = pred_context.recognition_results.start_time
            first_delay_time = pred_context.recv_time - first_send_time - sentence_begin_time
            first_delay_list.append(first_delay_time)
        # The result after a final one opens the next sentence.
        sentence_start = pred_context.recognition_results.final_result
    if IN_TEST:
        print(f"当前音频的首字延迟为{first_delay_list}")
    logger.info(f"当前音频的首字延迟均值为 {np.mean(first_delay_list)}s")
    return np.sum(first_delay_list), len(first_delay_list)
|
||||||
|
|
||||||
|
|
||||||
|
def revision_delay(context: ASRContext):
    """Sum the final-result delay of every predicted sentence.

    For each final result the delay is
    ``recv_time - first_send_time - end_time``, where ``first_send_time`` is
    the ``send_time`` of the very first prediction.

    Returns:
        tuple: ``(sum_of_delays, final_result_count)``; ``(0.0, 0)`` when there
        are no predictions at all (this used to raise IndexError).
    """
    if not context.preds:
        return 0.0, 0
    first_send_time = context.preds[0].send_time
    revision_delay_list = []
    for pred_context in context.preds:
        if pred_context.recognition_results.final_result:
            sentence_end_time = pred_context.recognition_results.end_time
            revision_delay_time = pred_context.recv_time - first_send_time - sentence_end_time
            revision_delay_list.append(revision_delay_time)

    if IN_TEST:
        print(revision_delay_list)
    # np.mean([]) is nan but also emits a RuntimeWarning; report nan directly
    # when no final result was received.
    mean_delay = np.mean(revision_delay_list) if revision_delay_list else float("nan")
    logger.info(f"当前音频的修正延迟均值为 {mean_delay}s")
    return np.sum(revision_delay_list), len(revision_delay_list)
|
||||||
|
|
||||||
|
|
||||||
|
def patch_unique_token_count(context: ASRContext):
    """Check for late edits of already-returned tokens and count patch sizes.

    Every prediction text is normalised and tokenised. Within a sentence, once
    3 s have passed since its first result, the leading tokens of each result
    are compared with those of the previous result; a difference inside the
    frozen prefix is logged as an error. The check is skipped for languages
    ``ar`` and ``he``.

    Returns:
        Counter: for each distinct token count seen among a sentence's
        results, the number of sentences in which it occurred.

    NOTE(review): the final log line calls mean_on_counter / var_on_counter on
    the returned counter — confirm how those behave when ``context.preds`` is
    empty.
    """
    # print(context.__dict__)
    # Tokenize every returned result.
    pred_text_list = [pred_context.recognition_results.text for pred_context in context.preds]
    pred_text_tokenized_list = Tokenizer.norm_and_tokenize(pred_text_list, lang=context.lang)
    # print(pred_text_list)
    # print(pred_text_tokenized_list)

    # Detect whether tokens returned more than 3 s ago were modified.
    ## recv time of the first result of the current sentence
    first_recv_time = None
    ## number of leading tokens that may no longer be modified
    unmodified_token_cnt = 0
    ## index of the oldest result not yet older than 3 s
    time_token_idx = 0
    ## True when the next result opens a new sentence
    final_sentence = True

    ## set once a frozen token was modified
    is_unmodified_token = False

    for idx, (now_tokens, pred_context) in enumerate(zip(pred_text_tokenized_list, context.preds)):
        ## first result of a sentence: reset the per-sentence state
        if final_sentence:
            first_recv_time = pred_context.recv_time
            unmodified_token_cnt = 0
            time_token_idx = idx
            final_sentence = pred_context.recognition_results.final_result
            continue
        final_sentence = pred_context.recognition_results.final_result
        ## recv time of the current prediction
        pred_recv_time = pred_context.recv_time
        ## the first 3 s of a sentence are ignored
        if pred_recv_time - first_recv_time < 3:
            continue
        ## from the earlier results, get the longest frozen prefix length
        while time_token_idx < idx:
            context_pred_tmp = context.preds[time_token_idx]
            context_pred_tmp_recv_time = context_pred_tmp.recv_time
            tmp_tokens = pred_text_tokenized_list[time_token_idx]
            if pred_recv_time - context_pred_tmp_recv_time >= 3:
                unmodified_token_cnt = max(unmodified_token_cnt, len(tmp_tokens))
                time_token_idx += 1
            else:
                break
        ## compare with the previous result: the first unmodified_token_cnt
        ## tokens are expected to be unchanged
        last_tokens = pred_text_tokenized_list[idx - 1]
        if context.lang in ['ar', 'he']:
            tokens_check_pre, tokens_check_now = last_tokens[::-1], now_tokens[::-1]
            # The `continue` skips the comparison for these languages, so the
            # reversed lists above are never used.
            continue
        else:
            tokens_check_pre, tokens_check_now = last_tokens, now_tokens
        for token_a, token_b in zip(tokens_check_pre[:unmodified_token_cnt], tokens_check_now[:unmodified_token_cnt]):
            if token_a != token_b:
                is_unmodified_token = True
                break

        if is_unmodified_token and int(os.getenv('test', 0)):
            logger.error(
                f"{idx}-{unmodified_token_cnt}-{last_tokens[:unmodified_token_cnt]}-{now_tokens[:unmodified_token_cnt]}"
            )
        if is_unmodified_token:
            break

    if is_unmodified_token:
        logger.error("修改了不可修改的文字范围")
        # change_product_available()
    if int(os.getenv('test', 0)):
        # Debug dump: per sentence, (tokens, recv time relative to the first result).
        final_result = True
        result_list = []
        for tokens, pred in zip(pred_text_tokenized_list, context.preds):
            if final_result:
                result_list.append([])
            result_list[-1].append((tokens, pred.recv_time - context.preds[0].recv_time))
            final_result = pred.recognition_results.final_result
        for item in result_list:
            logger.info(str(item))

    # Record the token count of each patch.
    patch_unique_cnt_counter = Counter()
    patch_unique_cnt_in_one_sentence = set()
    for pred_text_tokenized, pred_context in zip(pred_text_tokenized_list, context.preds):
        token_cnt = len(pred_text_tokenized)
        patch_unique_cnt_in_one_sentence.add(token_cnt)
        if pred_context.recognition_results.final_result:
            for unique_cnt in patch_unique_cnt_in_one_sentence:
                patch_unique_cnt_counter[unique_cnt] += 1
            patch_unique_cnt_in_one_sentence.clear()
    # Flush the trailing sentence when it never received a final result.
    if context.preds and not context.preds[-1].recognition_results.final_result:
        for unique_cnt in patch_unique_cnt_in_one_sentence:
            patch_unique_cnt_counter[unique_cnt] += 1
    # print(patch_unique_cnt_counter)
    logger.info(
        f"当前音频的 patch token 均值为 {mean_on_counter(patch_unique_cnt_counter)}, \
当前音频的 patch token 方差为 {var_on_counter(patch_unique_cnt_counter)}"
    )
    return patch_unique_cnt_counter
|
||||||
|
|
||||||
|
|
||||||
|
def mean_on_counter(counter: Counter):
    """Return the mean of the values described by a ``{value: count}`` counter.

    Returns:
        float: the count-weighted mean of the keys, or ``0.0`` for an empty
        counter (which used to raise ZeroDivisionError).
    """
    total_count = sum(counter.values())
    if total_count == 0:
        return 0.0
    total_sum = sum(key * count for key, count in counter.items())
    return total_sum * 1.0 / total_count
|
||||||
|
|
||||||
|
|
||||||
|
def var_on_counter(counter: Counter):
    """Return the population variance described by a ``{value: count}`` counter.

    Returns:
        float: the count-weighted variance of the keys, or ``0.0`` for an
        empty counter (which used to raise ZeroDivisionError).
    """
    total_count = sum(counter.values())
    if total_count == 0:
        return 0.0
    total_sum = sum(key * count for key, count in counter.items())
    mean = total_sum * 1.0 / total_count
    return sum((key - mean) ** 2 * count for key, count in counter.items()) / total_count
|
||||||
|
|
||||||
|
|
||||||
|
def edit_distance(arr1: List, arr2: List):
    """Count the edit operations that turn ``arr1`` into ``arr2``.

    Returns:
        tuple: ``(s, d, i, c)`` — the number of ``replace``, ``delete`` and
        ``insert`` operations reported by ``Levenshtein.editops``, and
        ``c = len(arr1) - s - d``, the elements of ``arr1`` left untouched.
    """
    op_kinds = Counter(operation[0] for operation in Levenshtein.editops(arr1, arr2))
    s = op_kinds["replace"]
    d = op_kinds["delete"]
    i = op_kinds["insert"]
    c = len(arr1) - s - d
    return s, d, i, c
|
||||||
|
|
||||||
|
|
||||||
|
def cer(tokens_gt_mapping: List[str], tokens_dt_mapping: List[str]):
    """Compute ``1 - CER`` (or WER) from two edit-distance-aligned sequences.

    Both inputs have the same length. ``None`` in the ground-truth sequence
    marks a token inserted by the prediction; ``None`` in the prediction
    sequence marks a ground-truth token that was deleted.

    Returns:
        tuple: ``(1 - error_rate, token_count)`` where ``token_count`` is the
        number of ground-truth tokens and
        ``error_rate = (replace + delete + insert) / token_count``.

    Raises:
        ZeroDivisionError: the ground truth contains no tokens.
    """
    insert = sum(1 for item in tokens_gt_mapping if item is None)
    delete = sum(1 for item in tokens_dt_mapping if item is None)
    equal = sum(1 for token_gt, token_dt in zip(tokens_gt_mapping, tokens_dt_mapping) if token_gt == token_dt)
    # Every aligned position is exactly one of insert / delete / equal /
    # replace. Deleted positions used to be counted as replacements too, which
    # double-counted them in both the error count and the token count.
    replace = len(tokens_gt_mapping) - insert - delete - equal

    token_count = replace + equal + delete
    cer_value = (replace + delete + insert) * 1.0 / token_count
    logger.info(f"当前音频的 cer/wer 值为 {cer_value}, token 个数为 {token_count}")
    return 1 - cer_value, token_count
|
||||||
|
|
||||||
|
|
||||||
|
def cut_rate(
    tokens_gt: List[List[str]],
    tokens_dt: List[List[str]],
    tokens_gt_mapping: List[str],
    tokens_dt_mapping: List[str],
):
    """Score how well predicted sentence breaks match the labelled ones.

    Sentence-final token positions of both sides are located in the aligned
    sequences and compared as sets. A missed break costs 1 and a spurious
    break costs 2, relative to the number of labelled breaks.

    Returns:
        tuple: ``(rate, sentence_count_gt, miss_count, more_count)`` with
        ``rate`` clamped at 0 from below.

    Raises:
        ZeroDivisionError: there is no labelled sentence break.
    """
    gt_breaks = set(sentence_final_token_index(tokens_gt, tokens_gt_mapping))
    dt_breaks = set(sentence_final_token_index(tokens_dt, tokens_dt_mapping))
    sentence_count_gt = len(gt_breaks)
    miss_count = len(gt_breaks - dt_breaks)
    more_count = len(dt_breaks - gt_breaks)
    penalty = (miss_count + more_count * 2) / sentence_count_gt
    rate = max(1 - penalty, 0)
    return rate, sentence_count_gt, miss_count, more_count
|
||||||
|
|
||||||
|
|
||||||
|
def token_mapping(tokens_gt: List[str], tokens_dt: List[str]) -> Tuple[List[str], List[str]]:
    """Align two token sequences by padding them with ``None``.

    The edit operations from ``Levenshtein.editops`` are applied back to
    front: for an ``insert`` a ``None`` is placed into the ground-truth copy
    at the source position, for a ``delete`` a ``None`` is placed into the
    prediction copy at the destination position. The inputs are not modified.

    Returns:
        tuple: the padded ground-truth and prediction sequences.
    """
    padded_gt = deepcopy(tokens_gt)
    padded_dt = deepcopy(tokens_dt)
    # Walking backwards keeps the positions of earlier operations valid.
    for kind, src_pos, dst_pos in Levenshtein.editops(padded_gt, padded_dt)[::-1]:
        if kind == "insert":
            padded_gt.insert(src_pos, None)
        elif kind == "delete":
            padded_dt.insert(dst_pos, None)
    return padded_gt, padded_dt
|
||||||
|
|
||||||
|
|
||||||
|
def sentence_final_token_index(tokens: List[List[str]], tokens_mapping: List[str]) -> List[int]:
    """Locate the last token of every sentence inside the aligned sequence.

    Args:
        tokens: the sentences, each a list of tokens.
        tokens_mapping: the flattened tokens padded with ``None`` placeholders.

    Returns:
        For every sentence, the index in ``tokens_mapping`` reached after
        consuming its tokens (placeholders are skipped before each token).
        An empty sentence repeats the previous position, ``-1`` at the start.
    """
    final_positions = []
    cursor = 0
    mapping_length = len(tokens_mapping)
    for sentence_tokens in tokens:
        for _ in sentence_tokens:
            # Skip alignment placeholders in front of the next real token.
            while cursor < mapping_length and tokens_mapping[cursor] is None:
                cursor += 1
            cursor += 1
        final_positions.append(cursor - 1)
    return final_positions
|
||||||
|
|
||||||
|
|
||||||
|
def cut_sentence(sentences: List[str], tokenizerType: TokenizerType) -> List[str]:
    """Remove general punctuation from every sentence.

    Each sentence is split on every punctuation mark in turn; the non-empty
    pieces are joined with a space for ``TokenizerType.whitespace`` and with
    nothing otherwise.

    Returns:
        One cleaned string per input sentence, in the same order.
    """
    # Multi-character marks come before their single-character prefixes.
    # The full-width marks and the ellipsis character are written as escapes
    # so they cannot be collapsed into their ASCII look-alikes (the ASCII
    # marks were previously listed twice instead).
    puncs = [
        "······",
        "......",
        "\u2026",  # horizontal ellipsis
        "。",
        "\uff0c",  # full-width comma
        "\uff1f",  # full-width question mark
        "\uff01",  # full-width exclamation mark
        "\uff1b",  # full-width semicolon
        "\uff1a",  # full-width colon
        "...",
        ".",
        ",",
        "?",
        "!",
        ";",
        ":",
    ]
    joiner = " " if tokenizerType == TokenizerType.whitespace else ""
    sentence_cut_list = []
    for sentence in sentences:
        pieces = [sentence]
        for punc in puncs:
            pieces = [part for piece in pieces for part in piece.split(punc)]
        sentence_cut_list.append(joiner.join(piece for piece in pieces if piece))
    return sentence_cut_list
|
||||||
50
utils/metrics_plus.py
Normal file
50
utils/metrics_plus.py
Normal file
@@ -0,0 +1,50 @@
|
|||||||
|
from typing import List
|
||||||
|
|
||||||
|
from utils.tokenizer import TokenizerType
|
||||||
|
|
||||||
|
|
||||||
|
def replace_general_punc(
    sentences: List[str], tokenizer: TokenizerType
) -> List[str]:
    """Replacement for ``utils.metrics.cut_sentence``.

    Every character that occurs in the general punctuation marks is replaced
    by a space for ``TokenizerType.whitespace`` and removed otherwise; the
    result is then stripped and lower-cased.

    Returns:
        One cleaned string per input sentence, in the same order.
    """
    # The full-width marks and the ellipsis character are written as escapes
    # so they cannot be collapsed into their ASCII look-alikes (the ASCII
    # marks were previously listed twice instead).
    general_puncs = [
        "······",
        "......",
        "\u2026",  # horizontal ellipsis
        "。",
        "\uff0c",  # full-width comma
        "\uff1f",  # full-width question mark
        "\uff01",  # full-width exclamation mark
        "\uff1b",  # full-width semicolon
        "\uff1a",  # full-width colon
        "...",
        ".",
        ",",
        "?",
        "!",
        ";",
        ":",
    ]
    if tokenizer == TokenizerType.whitespace:
        replacer = " "
    else:
        replacer = ""
    # The translation table works per character, so multi-character marks
    # only contribute their individual characters.
    trans = str.maketrans(dict.fromkeys("".join(general_puncs), replacer))
    return [sentence.translate(trans).strip().lower() for sentence in sentences]
|
||||||
|
|
||||||
|
|
||||||
|
def distance_point_line(
    point: float, line_start: float, line_end: float
) -> float:
    """Distance from a point to the segment ``[line_start, line_end]``.

    Returns:
        ``0`` when the point lies inside the segment, otherwise the absolute
        distance to ``line_start`` (point before the start) or to ``line_end``.
    """
    if line_start <= point <= line_end:
        return 0
    nearest_end = line_start if point < line_start else line_end
    return abs(point - nearest_end)
|
||||||
93
utils/pynini/Dockerfile
Normal file
93
utils/pynini/Dockerfile
Normal file
@@ -0,0 +1,93 @@
|
|||||||
|
# Dockerfile
|
||||||
|
# Pierre-André Noël, May 12th 2020
|
||||||
|
# Copyright © Element AI Inc. All rights reserved.
|
||||||
|
# Apache License, Version 2.0
|
||||||
|
#
|
||||||
|
# This builds `manylinux_2_28_x86_64` Python wheels for `pynini`, wrapping
|
||||||
|
# all its dependencies.
|
||||||
|
#
|
||||||
|
# This Dockerfile uses multi-stage builds; for more information, see:
|
||||||
|
# https://docs.docker.com/develop/develop-images/multistage-build/
|
||||||
|
#
|
||||||
|
# The recommended installation method for Pynini is through Conda-Forge. This gives Linux
|
||||||
|
# x86-64 users another option: installing a precompiled module from PyPI.
|
||||||
|
#
|
||||||
|
#
|
||||||
|
# To build wheels and run Pynini's tests, run:
|
||||||
|
#
|
||||||
|
# docker build --target=run-tests -t build-pynini-wheels .
|
||||||
|
#
|
||||||
|
# To extract the resulting wheels from the Docker image, run:
|
||||||
|
#
|
||||||
|
# docker run --rm -v `pwd`:/io build-pynini-wheels cp -r /wheelhouse /io
|
||||||
|
#
|
||||||
|
# Notice that this also generates Cython wheels.
|
||||||
|
#
|
||||||
|
# Then, `twine` (https://twine.readthedocs.io/en/latest/) can be used to
|
||||||
|
# publish the resulting Pynini wheels.
|
||||||
|
|
||||||
|
# ******************************************************
|
||||||
|
# *** All the following images are based on this one ***
|
||||||
|
# ******************************************************
|
||||||
|
#from quay.io/pypa/manylinux_2_28_x86_64 AS common
|
||||||
|
|
||||||
|
# ***********************************************************************
|
||||||
|
# *** Image providing all the requirements for building Pynini wheels ***
|
||||||
|
# ***********************************************************************
|
||||||
|
FROM harbor.4pd.io/inf/base-python3.8-ubuntu:1.1.0

# The versions we want in the wheels.
ENV FST_VERSION="1.8.3" \
    PYNINI_VERSION="2.1.6"

# Location of OpenFst and Pynini.
ENV FST_DOWNLOAD_PREFIX="https://www.openfst.org/twiki/pub/FST/FstDownload" \
    PYNINI_DOWNLOAD_PREFIX="https://www.opengrm.org/twiki/pub/GRM/PyniniDownload"

# Installs the build toolchain. `apt-get` is used because the `apt` CLI is not
# meant for scripts, and the package lists are removed in the same layer so
# they do not end up in the image.
RUN apt-get update \
    && apt-get install -y wget gcc-9 g++-9 make \
    && rm -rf /var/lib/apt/lists/* \
    && ln -s "$(which gcc-9)" /usr/bin/gcc \
    && ln -s "$(which g++-9)" /usr/bin/g++

# Note that our certificates are not known to the version of wget available in this image.

# Gets, unpacks, compiles and installs OpenFst. Everything happens in one layer
# so the tarball and the source tree never persist in the image.
RUN cd /tmp \
    && wget -q --no-check-certificate "${FST_DOWNLOAD_PREFIX}/openfst-${FST_VERSION}.tar.gz" \
    && tar -xzf "openfst-${FST_VERSION}.tar.gz" \
    && rm "openfst-${FST_VERSION}.tar.gz" \
    && cd "/tmp/openfst-${FST_VERSION}" \
    && ./configure --enable-grm \
    && make --jobs 4 install \
    && cd /tmp \
    && rm -rf "/tmp/openfst-${FST_VERSION}"

# Gets and unpacks Pynini source.
RUN mkdir -p /src && cd /src \
    && wget -q --no-check-certificate "${PYNINI_DOWNLOAD_PREFIX}/pynini-${PYNINI_VERSION}.tar.gz" \
    && tar -xzf "pynini-${PYNINI_VERSION}.tar.gz" \
    && rm "pynini-${PYNINI_VERSION}.tar.gz"

# Installs Pynini's build requirements. A failing RUN already aborts the build,
# so no explicit `|| exit` is needed.
RUN pip install --no-cache-dir -i https://nexus.4pd.io/repository/pypi-all/simple \
        -r "/src/pynini-${PYNINI_VERSION}/requirements.txt"

# **********************************************************
# *** Image making pynini wheels (placed in /wheelhouse) ***
# **********************************************************
#FROM wheel-building-env AS build-wheels

# Tools needed to repair the wheels: patchelf (unpacked into the current
# directory, its bin/ is put on PATH below) and auditwheel.
RUN wget ftp://ftp.4pd.io/pub/pico/temp/patchelf-0.18.0-x86_64.tar.gz \
    && tar xzf patchelf-0.18.0-x86_64.tar.gz \
    && rm -f patchelf-0.18.0-x86_64.tar.gz
RUN pip install --no-cache-dir -i https://nexus.4pd.io/repository/pypi-all/simple auditwheel

# Compiles the wheels to a temporary directory, bundles external shared
# libraries into the pynini wheels, then removes the non-repaired wheels.
# All three steps share one layer so /tmp/wheelhouse is never stored.
# See https://github.com/pypa/manylinux/tree/manylinux2014
RUN pip wheel --no-cache-dir -i https://nexus.4pd.io/repository/pypi-all/simple \
        -v "/src/pynini-${PYNINI_VERSION}" -w /tmp/wheelhouse/ \
    && for WHL in /tmp/wheelhouse/pynini*.whl; do \
         PATH="$(pwd)/bin:${PATH}" auditwheel repair --plat manylinux_2_31_x86_64 "${WHL}" -w /wheelhouse/ || exit 1; \
       done \
    && rm -rf /tmp/wheelhouse

# Alternative without auditwheel (copies the raw wheels instead of repairing them):
#RUN mkdir -p /wheelhouse && for WHL in /tmp/wheelhouse/pynini*.whl; do \
#    cp "${WHL}" /wheelhouse/; \
#done
|
||||||
|
|
||||||
17
utils/pynini/README.md
Normal file
17
utils/pynini/README.md
Normal file
@@ -0,0 +1,17 @@
|
|||||||
|
# pynini
|
||||||
|
|
||||||
|
## 背景
|
||||||
|
|
||||||
|
SpeechIO对英文ASR的评估工具依赖第三方库pynini(https://github.com/kylebgorman/pynini),该库强绑定OS和gcc版本,需要在运行环境中编译生成wheel包,本文说明编译pynini生成wheel包的方法
|
||||||
|
|
||||||
|
## 编译
|
||||||
|
|
||||||
|
```shell
|
||||||
|
docker build -t build-pynini-wheels .
|
||||||
|
```
|
||||||
|
|
||||||
|
## 获取wheel包
|
||||||
|
|
||||||
|
```shell
|
||||||
|
docker run --rm -v `pwd`:/io build-pynini-wheels cp -r /wheelhouse /io
|
||||||
|
```
|
||||||
40
utils/request.py
Normal file
40
utils/request.py
Normal file
@@ -0,0 +1,40 @@
|
|||||||
|
import requests
|
||||||
|
from requests.adapters import HTTPAdapter
|
||||||
|
from requests.packages.urllib3.util.retry import Retry
|
||||||
|
|
||||||
|
DEFAULT_TIMEOUT = 2 * 60 # seconds
|
||||||
|
|
||||||
|
|
||||||
|
class TimeoutHTTPAdapter(HTTPAdapter):
    """HTTPAdapter that supplies a timeout for requests sent without one.

    The timeout is taken from the ``timeout`` keyword given to the
    constructor, or DEFAULT_TIMEOUT when that keyword is absent.
    """

    def __init__(self, *args, **kwargs):
        # "timeout" is consumed here and is not forwarded to HTTPAdapter.
        self.timeout = kwargs.pop("timeout", DEFAULT_TIMEOUT)
        super().__init__(*args, **kwargs)

    def send(self, request, **kwargs):
        """Send ``request``, filling in the adapter timeout when none is set."""
        if kwargs.get("timeout") is None:
            kwargs["timeout"] = self.timeout
        return super().send(request, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
def requests_retry_session(
    retries=3,
    backoff_factor=1,
    status_forcelist=[500, 502, 504, 404, 403],
    session=None,
):
    """Return a requests session whose HTTP(S) traffic retries and times out.

    Args:
        retries: value used for the total, read and connect retry budgets.
        backoff_factor: passed through to ``Retry``.
        status_forcelist: HTTP status codes passed through to ``Retry``.
        session: session to configure; a new ``requests.Session`` is created
            when this is falsy.

    Returns:
        The session, with one TimeoutHTTPAdapter mounted on both schemes.
    """
    if not session:
        session = requests.Session()
    retry_policy = Retry(
        total=retries,
        read=retries,
        connect=retries,
        backoff_factor=backoff_factor,
        status_forcelist=status_forcelist,
    )
    timeout_adapter = TimeoutHTTPAdapter(max_retries=retry_policy)
    # The same adapter instance serves both schemes.
    for scheme in ('http://', 'https://'):
        session.mount(scheme, timeout_adapter)
    return session
|
||||||
65
utils/service.py
Normal file
65
utils/service.py
Normal file
@@ -0,0 +1,65 @@
|
|||||||
|
# -*- coding: utf-8 -*-
|
||||||
|
import os
|
||||||
|
import sys
|
||||||
|
|
||||||
|
from utils.helm import deploy_chart, gen_chart_tarball
|
||||||
|
from utils.logger import logger
|
||||||
|
|
||||||
|
|
||||||
|
def register_sut(st_config, resource_name, **kwargs):
    """Deploy the system-under-test chart and return its websocket URL.

    Args:
        st_config: configuration mapping; only its optional "values" entry is
            read here and forwarded to the chart as extra values.
        resource_name: name handed to ``deploy_chart``.
        **kwargs: accepted but not used.

    Returns:
        ``"ws://<service name>:<port>"``. The port comes from
        ``st_config["values"]["service"]["port"]`` when present, otherwise
        from the generated chart's own values.

    Environment:
        readiness_timeout: readiness timeout passed to ``deploy_chart``
            (default 180).
        restart_count: restart limit passed to ``deploy_chart`` (default 3).
    """
    # NOTE(review): the image is hard-coded; any "docker_image" entry in
    # st_config is currently ignored (the lookup was commented out upstream).
    docker_image = "10.255.143.18:5000/speaker_identification:wo_model_v0"
    st_config_values = st_config.get("values", {})

    chart_tar_fp, chart_values = gen_chart_tarball(docker_image)
    try:
        sut_service_name, _ = deploy_chart(
            resource_name,
            int(os.getenv("readiness_timeout", 60 * 3)),
            chart_fileobj=chart_tar_fp,
            extra_values=st_config_values,
            restart_count_limit=int(os.getenv('restart_count', 3)),
        )
    finally:
        # Release the tarball handle even when the deployment fails.
        chart_tar_fp.close()

    if st_config_values is not None and "service" in st_config_values and "port" in st_config_values["service"]:
        sut_service_port = str(st_config_values["service"]["port"])
    else:
        sut_service_port = str(chart_values["service"]["port"])
    return "ws://{}:{}".format(sut_service_name, sut_service_port)
|
||||||
|
|
||||||
|
|
||||||
|
"""
|
||||||
|
elif "chart_repo" in st_config:
|
||||||
|
logger.info(f"正在使用 helm-chart 配置,内容为 {st_config}")
|
||||||
|
chart_repo = st_config.get("chart_repo", None)
|
||||||
|
chart_name = st_config.get("chart_name", None)
|
||||||
|
chart_version = st_config.get("chart_version", None)
|
||||||
|
if chart_repo is None or chart_name is None or chart_version is None:
|
||||||
|
logger.error("chart_repo, chart_name, chart_version cant be none")
|
||||||
|
logger.info(f"{chart_repo} {chart_name} {chart_version}")
|
||||||
|
chart_str = os.path.join(chart_repo, chart_name) + ':' + chart_version
|
||||||
|
|
||||||
|
st_cfg_values = st_config.get('values', {})
|
||||||
|
st_config["values"] = st_cfg_values
|
||||||
|
|
||||||
|
sut_service_name, _ = deploy_chart(
|
||||||
|
resource_name,
|
||||||
|
600,
|
||||||
|
chart_str=chart_str,
|
||||||
|
extra_values=st_cfg_values,
|
||||||
|
)
|
||||||
|
sut_service_name = f"asr-{job_id}"
|
||||||
|
if st_cfg_values is not None and 'service' in st_cfg_values and 'port' in st_cfg_values['service']:
|
||||||
|
sut_service_port = str(st_cfg_values['service']['port'])
|
||||||
|
else:
|
||||||
|
sut_service_port = '80'
|
||||||
|
return 'ws://%s:%s' % (sut_service_name, sut_service_port)
|
||||||
|
else:
|
||||||
|
logger.error("配置信息错误,缺少 docker_image 属性")
|
||||||
|
#sys.exit(-1)
|
||||||
|
|
||||||
|
|
||||||
|
"""
|
||||||
3
utils/speechio/__init__.py
Normal file
3
utils/speechio/__init__.py
Normal file
@@ -0,0 +1,3 @@
|
|||||||
|
'''
|
||||||
|
reference: https://github.com/SpeechColab/Leaderboard/tree/f287a992dc359d1c021bfc6ce810e5e36608e057/utils
|
||||||
|
'''
|
||||||
551
utils/speechio/error_rate_en.py
Normal file
551
utils/speechio/error_rate_en.py
Normal file
@@ -0,0 +1,551 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
# coding=utf8
|
||||||
|
# Copyright 2022 Zhenxiang MA, Jiayu DU (SpeechColab)
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import csv
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
import os
|
||||||
|
import sys
|
||||||
|
from typing import Iterable
|
||||||
|
|
||||||
|
logging.basicConfig(stream=sys.stderr, level=logging.ERROR, format='[%(levelname)s] %(message)s')
|
||||||
|
|
||||||
|
import pynini
|
||||||
|
from pynini.lib import pynutil
|
||||||
|
|
||||||
|
|
||||||
|
# reference: https://github.com/kylebgorman/pynini/blob/master/pynini/lib/edit_transducer.py
|
||||||
|
# to import original lib:
|
||||||
|
# from pynini.lib.edit_transducer import EditTransducer
|
||||||
|
class EditTransducer:
    """Factored edit transducer over a fixed vocabulary.

    The transducer is stored as a left factor (``_e_i``) and a right factor
    (``_e_o``). Each factor carries half of every edit cost, so composing
    input @ left @ right @ output charges the full cost per edit. The right
    factor additionally accepts a zero-cost mapping from each token ``x`` to
    its auxiliary form ``x#``.

    Args:
        symbol_table: symbol table used as ``token_type`` for every token.
        vocab: tokens of the alphabet. It is materialised into a list
            because it is traversed more than once.
        insert_cost: cost of an insertion.
        delete_cost: cost of a deletion.
        substitute_cost: cost of a substitution.
        bound: when non-zero, the maximum number of edits allowed; 0 means
            unbounded.
    """

    DELETE = "<delete>"
    INSERT = "<insert>"
    SUBSTITUTE = "<substitute>"

    def __init__(
        self,
        symbol_table,
        vocab: Iterable[str],
        insert_cost: float = 1.0,
        delete_cost: float = 1.0,
        substitute_cost: float = 1.0,
        bound: int = 0,
    ):
        # vocab is iterated twice below; a one-shot iterator would otherwise
        # leave the second traversal empty.
        vocab = list(vocab)

        # Left factor; note that we divide the edit costs by two because they also
        # will be incurred when traversing the right factor.
        sigma = pynini.union(
            *[pynini.accep(token, token_type=symbol_table) for token in vocab],
        ).optimize()

        insert = pynutil.insert(f"[{self.INSERT}]", weight=insert_cost / 2)
        delete = pynini.cross(sigma, pynini.accep(f"[{self.DELETE}]", weight=delete_cost / 2))
        substitute = pynini.cross(sigma, pynini.accep(f"[{self.SUBSTITUTE}]", weight=substitute_cost / 2))

        edit = pynini.union(insert, delete, substitute).optimize()

        if bound:
            sigma_star = pynini.closure(sigma)
            self._e_i = sigma_star.copy()
            for _ in range(bound):
                self._e_i.concat(edit.ques).concat(sigma_star)
        else:
            self._e_i = edit.union(sigma).closure()

        self._e_i.optimize()

        right_factor_std = EditTransducer._right_factor(self._e_i)
        # right_factor_ext allows 0-cost matching between token's raw form & auxiliary form
        # e.g.: 'I' -> 'I#', 'AM' -> 'AM#'
        right_factor_ext = (
            pynini.union(
                *[
                    pynini.cross(
                        pynini.accep(x, token_type=symbol_table),
                        pynini.accep(x + '#', token_type=symbol_table),
                    )
                    for x in vocab
                ]
            )
            .optimize()
            .closure()
        )
        self._e_o = pynini.union(right_factor_std, right_factor_ext).closure().optimize()

    @staticmethod
    def _right_factor(ifst: pynini.Fst) -> pynini.Fst:
        """Build the right factor: invert ``ifst`` and swap insert/delete labels."""
        ofst = pynini.invert(ifst)
        syms = pynini.generated_symbols()
        insert_label = syms.find(EditTransducer.INSERT)
        delete_label = syms.find(EditTransducer.DELETE)
        pairs = [(insert_label, delete_label), (delete_label, insert_label)]
        right_factor = ofst.relabel_pairs(ipairs=pairs)
        return right_factor

    def create_lattice(self, iexpr: pynini.FstLike, oexpr: pynini.FstLike) -> pynini.Fst:
        """Compose ``iexpr`` and ``oexpr`` through both factors.

        Raises:
            RuntimeError: if the resulting lattice is empty.
        """
        lattice = (iexpr @ self._e_i) @ (self._e_o @ oexpr)
        EditTransducer.check_wellformed_lattice(lattice)
        return lattice

    @staticmethod
    def check_wellformed_lattice(lattice: pynini.Fst) -> None:
        """Raise RuntimeError when ``lattice`` has no start state."""
        if lattice.start() == pynini.NO_STATE_ID:
            raise RuntimeError("Edit distance composition lattice is empty.")

    def compute_distance(self, iexpr: pynini.FstLike, oexpr: pynini.FstLike) -> float:
        """Return the edit distance between ``iexpr`` and ``oexpr``."""
        lattice = self.create_lattice(iexpr, oexpr)
        # The shortest cost from all final states to the start state is
        # equivalent to the cost of the shortest path.
        start = lattice.start()
        return float(pynini.shortestdistance(lattice, reverse=True)[start])

    def compute_alignment(self, iexpr: pynini.FstLike, oexpr: pynini.FstLike) -> pynini.FstLike:
        """Return the optimized single best alignment path between the inputs."""
        # Diagnostic dumps go through logging rather than stdout, so they only
        # appear when debug logging is enabled.
        logging.debug('%s', iexpr)
        logging.debug('%s', oexpr)
        lattice = self.create_lattice(iexpr, oexpr)
        alignment = pynini.shortestpath(lattice, nshortest=1, unique=True)
        return alignment.optimize()
|
||||||
|
|
||||||
|
|
||||||
|
class ErrorStats:
    """Accumulator for utterance counts, edit counts and derived error rates."""

    def __init__(self):
        # Utterance counters.
        self.num_ref_utts = 0
        self.num_hyp_utts = 0
        self.num_eval_utts = 0  # in both ref & hyp
        self.num_hyp_without_ref = 0

        # Token-level edit counts: correct, substitutions, insertions, deletions.
        self.C = 0
        self.S = 0
        self.I = 0
        self.D = 0
        self.token_error_rate = 0.0
        self.modified_token_error_rate = 0.0

        # Sentence-level counters.
        self.num_utts_with_error = 0
        self.sentence_error_rate = 0.0

    def to_json(self):
        """Serialize every attribute as a single-line JSON object."""
        return json.dumps(vars(self))

    def to_kaldi(self):
        """Format the token and sentence error rates as %WER / %SER lines."""
        wer_line = F'%WER {self.token_error_rate:.2f} [ {self.S + self.D + self.I} / {self.C + self.S + self.D}, {self.I} ins, {self.D} del, {self.S} sub ]\n'
        ser_line = F'%SER {self.sentence_error_rate:.2f} [ {self.num_utts_with_error} / {self.num_eval_utts} ]\n'
        return wer_line + ser_line

    def to_summary(self):
        """Format all statistics as a multi-line, human-readable report."""
        lines = [
            '==================== Overall Statistics ====================\n',
            F'num_ref_utts: {self.num_ref_utts}\n',
            F'num_hyp_utts: {self.num_hyp_utts}\n',
            F'num_hyp_without_ref: {self.num_hyp_without_ref}\n',
            F'num_eval_utts: {self.num_eval_utts}\n',
            F'sentence_error_rate: {self.sentence_error_rate:.2f}%\n',
            F'token_error_rate: {self.token_error_rate:.2f}%\n',
            F'modified_token_error_rate: {self.modified_token_error_rate:.2f}%\n',
            F'token_stats:\n',
            F' - tokens:{self.C + self.S + self.D:>7}\n',
            F' - edits: {self.S + self.I + self.D:>7}\n',
            F' - cor: {self.C:>7}\n',
            F' - sub: {self.S:>7}\n',
            F' - ins: {self.I:>7}\n',
            F' - del: {self.D:>7}\n',
            '============================================================\n',
        ]
        return ''.join(lines)
|
||||||
|
|
||||||
|
|
||||||
|
class Utterance:
    """A single utterance: its id paired with its transcript text."""

    def __init__(self, uid, text):
        self.uid, self.text = uid, text
|
||||||
|
|
||||||
|
|
||||||
|
def LoadKaldiArc(filepath):
    """Read a kaldi-style text archive into a ``{uid: Utterance}`` dict.

    Each non-blank line is ``<uid> <text>``; a line holding only a uid yields
    an utterance with empty text.

    Raises:
        RuntimeError: if the same uid appears on more than one line.
    """
    table = {}
    with open(filepath, 'r', encoding='utf8') as fin:
        for raw in fin:
            entry = raw.strip()
            if not entry:
                continue
            fields = entry.split(maxsplit=1)
            assert len(fields) in (1, 2)
            uid = fields[0]
            text = fields[1] if len(fields) == 2 else ''
            # Stored values are always Utterance objects, so membership is
            # the same check as "get(uid) is not None".
            if uid in table:
                raise RuntimeError(F'Found duplicated utterence id {uid}')
            table[uid] = Utterance(uid, text)
    return table
|
||||||
|
|
||||||
|
|
||||||
|
def BreakHyphen(token: str):
    """Return the extra vocabulary entries implied by a hyphenated token.

    For 'T-SHIRT' this yields the hyphen-separated parts followed by the
    joined form: ['T', 'SHIRT', 'TSHIRT'].
    """
    assert '-' in token
    return [*token.split('-'), token.replace('-', '')]
|
||||||
|
|
||||||
|
|
||||||
|
def LoadGLM(rel_path):
    '''Load GLM rewrite rules from a CSV file into a ``{rule_name: phrases}`` dict.

    Each CSV row is one rule listing equivalent phrases, e.g.:

        I'VE,I HAVE
        GOING TO,GONNA
        ...
        T-SHIRT,T SHIRT,TSHIRT

    Rows are named by their 0-based position, zero-padded to six digits, and
    every phrase is stripped of surrounding whitespace:

        {
            '<RULE_000000>': ["I'VE", 'I HAVE'],
            '<RULE_000001>': ['GOING TO', 'GONNA'],
            ...
        }

    Args:
        rel_path: path of the CSV file, resolved against the directory that
            contains this module. An absolute path is used as given.

    Returns:
        dict mapping rule name to its list of phrases.
    '''
    logging.info(f'Loading GLM from {rel_path} ...')

    abs_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), rel_path)
    # The context manager closes the file; newline='' is what the csv module
    # asks for when it is given a file object.
    with open(abs_path, encoding="utf-8", newline='') as f:
        rows = list(csv.reader(f, delimiter=','))

    glm = {}
    for k, rule in enumerate(rows):
        rule_name = f'<RULE_{k:06d}>'
        glm[rule_name] = [phrase.strip() for phrase in rule]
    logging.info(f' #rule: {len(glm)}')

    return glm
|
||||||
|
|
||||||
|
|
||||||
|
def SymbolEQ(symbol_table, i1, i2):
    """Tell whether two symbol ids name the same token, ignoring '#' at either end."""
    lhs = symbol_table.find(i1).strip('#')
    rhs = symbol_table.find(i2).strip('#')
    return lhs == rhs
|
||||||
|
|
||||||
|
|
||||||
|
def PrintSymbolTable(symbol_table: pynini.SymbolTable):
    """Print every (id, symbol) pair of ``symbol_table`` to stdout."""
    print('SYMBOL_TABLE:')
    for idx in range(symbol_table.num_symbols()):
        name = symbol_table.find(idx)
        # symbol table's find can be used for bi-directional lookup (id <-> sym)
        assert symbol_table.find(name) == idx
        print(idx, name)
    print()
|
||||||
|
|
||||||
|
|
||||||
|
def BuildSymbolTable(vocab) -> pynini.SymbolTable:
    """Create a symbol table holding '<epsilon>' first, then each vocab entry in order."""
    logging.info('Building symbol table ...')
    table = pynini.SymbolTable()
    for sym in ['<epsilon>', *vocab]:
        table.add_symbol(sym)
    logging.info(f' #symbols: {table.num_symbols()}')

    # Debugging aids:
    # PrintSymbolTable(table)
    # table.write_text('symbol_table.txt')
    return table
|
||||||
|
|
||||||
|
|
||||||
|
def BuildGLMTagger(glm, symbol_table) -> pynini.Fst:
    """Build an FST that wraps each GLM phrase occurrence in its rule tag.

    For every phrase of every rule, a tagger is built that inserts the rule
    tag, accepts the phrase, and inserts the rule tag again. The union of
    these taggers is applied as a context-free rewrite over the closure of
    all non-epsilon symbols.
    """
    logging.info('Building GLM tagger ...')

    def _tag_phrase(tag, phrase):
        # <tag> phrase <tag>, with both tags inserted on the output side.
        return (
            pynutil.insert(pynini.accep(tag, token_type=symbol_table))
            + pynini.accep(phrase, token_type=symbol_table)
            + pynutil.insert(pynini.accep(tag, token_type=symbol_table))
        )

    rule_taggers = [
        _tag_phrase(rule_tag, phrase)
        for rule_tag, rule in glm.items()
        for phrase in rule
    ]

    non_epsilon = [
        pynini.accep(sym, token_type=symbol_table)
        for k, sym in symbol_table
        if k != 0
    ]
    alphabet = pynini.union(*non_epsilon).optimize()

    # could be slow with large vocabulary
    rewrite_rule = pynini.union(*rule_taggers).optimize()
    tagger = pynini.cdrewrite(rewrite_rule, '', '', alphabet.closure()).optimize()
    return tagger
|
||||||
|
|
||||||
|
|
||||||
|
def TokenWidth(token: str):
    """Return the display width of ``token``.

    Characters in the range U+4E00..U+9FA5 count as 2 columns, every other
    character as 1.
    """
    width = 0
    for ch in token:
        width += 2 if '\u4e00' <= ch <= '\u9fa5' else 1
    return width
|
||||||
|
|
||||||
|
|
||||||
|
def PrintPrettyAlignment(raw_hyp, edit_ali, ref_ali, hyp_ali, stream=sys.stderr):
    """Print a column-aligned view of an alignment to ``stream``.

    Four lines are written: the raw hypothesis, then the aligned hypothesis,
    reference and edit-tag rows. Each column is padded to the widest of its
    three cells plus one space; the 'C' (correct) tag is shown as blank.
    """
    assert len(edit_ali) == len(ref_ali) and len(ref_ali) == len(hyp_ali)

    hyp_row = ' HYP# : '
    ref_row = ' REF : '
    edit_row = ' EDIT : '
    for tag, ref_tok, hyp_tok in zip(edit_ali, ref_ali, hyp_ali):
        if tag == 'C':
            tag = ''  # don't bother printing correct edit-tag

        ref_w = TokenWidth(ref_tok)
        hyp_w = TokenWidth(hyp_tok)
        tag_w = TokenWidth(tag)
        col_w = max(ref_w, hyp_w, tag_w) + 1

        hyp_row += hyp_tok + ' ' * (col_w - hyp_w)
        ref_row += ref_tok + ' ' * (col_w - ref_w)
        edit_row += tag + ' ' * (col_w - tag_w)

    print(F' HYP : {raw_hyp}', file=stream)
    for row in (hyp_row, ref_row, edit_row):
        print(row, file=stream)
|
||||||
|
|
||||||
|
|
||||||
|
def ComputeTokenErrorRate(c, s, i, d):
    """Return ``(TER, mTER)`` as percentages from edit counts.

    Args:
        c, s, i, d: counts of correct, substituted, inserted and deleted tokens.

    Returns:
        TER is edits / reference length; mTER divides by the longer of the
        reference and hypothesis lengths instead.
    """
    assert (s + d + c) != 0
    edits = s + d + i
    n_ref = c + s + d
    n_hyp = c + s + i
    ter = 100.0 * edits / n_ref
    mter = 100.0 * edits / max(n_ref, n_hyp)
    return ter, mter
|
||||||
|
|
||||||
|
|
||||||
|
def ComputeSentenceErrorRate(num_err_utts, num_utts):
    """Return the percentage of utterances that contain at least one error."""
    assert num_utts != 0
    ratio = num_err_utts / num_utts
    return 100.0 * num_err_utts / num_utts if ratio == ratio else ratio
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
    # Command-line entry point: evaluates the token / sentence error rate of a
    # hypothesis kaldi archive against a reference archive, applying GLM
    # rules to the hypothesis, and writes the corpus statistics as one JSON
    # line to result_file.
    parser = argparse.ArgumentParser()
    parser.add_argument('--logk', type=int, default=500, help='logging interval')
    parser.add_argument(
        '--tokenizer', choices=['whitespace', 'char'], default='whitespace', help='whitespace for WER, char for CER'
    )
    parser.add_argument('--glm', type=str, default='glm_en.csv', help='glm')
    parser.add_argument('--ref', type=str, required=True, help='reference kaldi arc file')
    parser.add_argument('--hyp', type=str, required=True, help='hypothesis kaldi arc file')
    parser.add_argument('result_file', type=str)
    args = parser.parse_args()
    logging.info(args)

    stats = ErrorStats()

    logging.info('Generating tokenizer ...')
    if args.tokenizer == 'whitespace':

        def word_tokenizer(text):
            return text.strip().split()

        tokenizer = word_tokenizer
    elif args.tokenizer == 'char':

        def char_tokenizer(text):
            return [c for c in text.strip().replace(' ', '')]

        tokenizer = char_tokenizer
    else:
        tokenizer = None
    assert tokenizer

    logging.info('Loading REF & HYP ...')
    ref_utts = LoadKaldiArc(args.ref)
    hyp_utts = LoadKaldiArc(args.hyp)

    # check valid utterances in hyp that have matched non-empty reference
    uids = []
    for uid in sorted(hyp_utts.keys()):
        if uid in ref_utts.keys():
            if ref_utts[uid].text.strip():  # non-empty reference
                uids.append(uid)
            else:
                logging.warning(F'Found {uid} with empty reference, skipping...')
        else:
            logging.warning(F'Found {uid} without reference, skipping...')
            stats.num_hyp_without_ref += 1

    stats.num_hyp_utts = len(hyp_utts)
    stats.num_ref_utts = len(ref_utts)
    stats.num_eval_utts = len(uids)
    logging.info(f' #hyp:{stats.num_hyp_utts}, #ref:{stats.num_ref_utts}, #utts_to_evaluate:{stats.num_eval_utts}')
    print(f' #hyp:{stats.num_hyp_utts}, #ref:{stats.num_ref_utts}, #utts_to_evaluate:{stats.num_eval_utts}')

    # Vocabulary seen in the evaluated utterances (hyphenated tokens also
    # contribute their split and joined forms).
    tokens = []
    for uid in uids:
        ref_tokens = tokenizer(ref_utts[uid].text)
        hyp_tokens = tokenizer(hyp_utts[uid].text)
        for t in ref_tokens + hyp_tokens:
            tokens.append(t)
            if '-' in t:
                tokens.extend(BreakHyphen(t))
    vocab_from_utts = list(set(tokens))
    logging.info(f' HYP&REF vocab size: {len(vocab_from_utts)}')
    print(f' HYP&REF vocab size: {len(vocab_from_utts)}')

    assert args.glm
    glm = LoadGLM(args.glm)

    # Vocabulary contributed by the GLM phrases.
    tokens = []
    for rule in glm.values():
        for phrase in rule:
            for t in tokenizer(phrase):
                tokens.append(t)
                if '-' in t:
                    tokens.extend(BreakHyphen(t))
    vocab_from_glm = list(set(tokens))
    logging.info(f' GLM vocab size: {len(vocab_from_glm)}')
    print(f' GLM vocab size: {len(vocab_from_glm)}')

    vocab = list(set(vocab_from_utts + vocab_from_glm))
    logging.info(f'Global vocab size: {len(vocab)}')
    print(f'Global vocab size: {len(vocab)}')

    symtab = BuildSymbolTable(
        # Normal evaluation vocab + auxiliary form for alternative paths + GLM tags
        vocab
        + [x + '#' for x in vocab]
        + [x for x in glm.keys()]
    )
    glm_tagger = BuildGLMTagger(glm, symtab)
    edit_transducer = EditTransducer(symbol_table=symtab, vocab=vocab)

    logging.info('Evaluating error rate ...')
    print('Evaluating error rate ...')
    # The context manager closes the result file even if evaluation fails.
    with open(args.result_file, 'w+', encoding='utf8') as fo:
        ndone = 0
        for uid in uids:
            ref = ref_utts[uid].text
            raw_hyp = hyp_utts[uid].text

            ref_fst = pynini.accep(' '.join(tokenizer(ref)), token_type=symtab)
            # Per-utterance FST dumps are debug output, not results: keep them
            # out of stdout.
            logging.debug('%s', ref_fst)

            raw_hyp_fst = pynini.accep(' '.join(tokenizer(raw_hyp)), token_type=symtab)

            # Say, we have:
            #   RULE_001: "I'M" <-> "I AM"
            #   REF: HEY I AM HERE
            #   HYP: HEY I'M HERE
            #
            # We want to expand HYP with GLM rules(marked with auxiliary #)
            #   HYP#: HEY {I'M | I# AM#} HERE
            # REF is honored to keep its original form.
            #
            # This could be considered as a flexible on-the-fly TN towards HYP.

            # 1. GLM rule tagging:
            #   HEY I'M HERE
            #   ->
            #   HEY <RULE_001> I'M <RULE_001> HERE
            lattice = (raw_hyp_fst @ glm_tagger).optimize()
            tagged_ir = pynini.shortestpath(lattice, nshortest=1, unique=True).string(token_type=symtab)

            # 2. GLM rule expansion:
            #   HEY <RULE_001> I'M <RULE_001> HERE
            #   ->
            #   sausage-like fst: HEY {I'M | I# AM#} HERE
            tokens = tagged_ir.split()
            sausage = pynini.accep('', token_type=symtab)
            i = 0
            while i < len(tokens):  # invariant: tokens[0, i) has been built into fst
                forms = []
                if tokens[i].startswith('<RULE_') and tokens[i].endswith('>'):  # rule segment
                    rule_name = tokens[i]
                    rule = glm[rule_name]
                    # pre-condition: i -> ltag
                    raw_form = ''
                    for j in range(i + 1, len(tokens)):
                        if tokens[j] == rule_name:
                            raw_form = ' '.join(tokens[i + 1 : j])
                            break
                    assert raw_form
                    # post-condition: i -> ltag, j -> rtag

                    forms.append(raw_form)
                    for phrase in rule:
                        if phrase != raw_form:
                            forms.append(' '.join([x + '#' for x in phrase.split()]))
                    i = j + 1
                else:  # normal token segment
                    token = tokens[i]
                    forms.append(token)
                    if "-" in token:  # token with hyphen yields extra forms
                        forms.append(' '.join([x + '#' for x in token.split('-')]))  # 'T-SHIRT' -> 'T# SHIRT#'
                        forms.append(token.replace('-', '') + '#')  # 'T-SHIRT' -> 'TSHIRT#'
                    i += 1

                sausage_segment = pynini.union(*[pynini.accep(x, token_type=symtab) for x in forms]).optimize()
                sausage += sausage_segment
            hyp_fst = sausage.optimize()
            logging.debug('%s', hyp_fst)

            # Utterance-Level error rate evaluation
            alignment = edit_transducer.compute_alignment(ref_fst, hyp_fst)
            logging.debug('alignment %s', alignment)

            distance = 0.0
            C, S, I, D = 0, 0, 0, 0  # Cor, Sub, Ins, Del
            edit_ali, ref_ali, hyp_ali = [], [], []
            for state in alignment.states():
                for arc in alignment.arcs(state):
                    # Label 0 is epsilon: an epsilon input is an insertion,
                    # an epsilon output is a deletion.
                    i, o = arc.ilabel, arc.olabel
                    if i != 0 and o != 0 and SymbolEQ(symtab, i, o):
                        e = 'C'
                        r, h = symtab.find(i), symtab.find(o)

                        C += 1
                        distance += 0.0
                    elif i != 0 and o != 0 and not SymbolEQ(symtab, i, o):
                        e = 'S'
                        r, h = symtab.find(i), symtab.find(o)

                        S += 1
                        distance += 1.0
                    elif i == 0 and o != 0:
                        e = 'I'
                        r, h = '*', symtab.find(o)

                        I += 1
                        distance += 1.0
                    elif i != 0 and o == 0:
                        e = 'D'
                        r, h = symtab.find(i), '*'

                        D += 1
                        distance += 1.0
                    else:
                        raise RuntimeError

                    edit_ali.append(e)
                    ref_ali.append(r)
                    hyp_ali.append(h)
            # assert(distance == edit_transducer.compute_distance(ref_fst, sausage))

            utt_ter, utt_mter = ComputeTokenErrorRate(C, S, I, D)
            # print(F'{{"uid":{uid}, "score":{-distance}, "TER":{utt_ter:.2f}, "mTER":{utt_mter:.2f}, "cor":{C}, "sub":{S}, "ins":{I}, "del":{D}}}', file=fo)
            # PrintPrettyAlignment(raw_hyp, edit_ali, ref_ali, hyp_ali, fo)

            if utt_ter > 0:
                stats.num_utts_with_error += 1

            stats.C += C
            stats.S += S
            stats.I += I
            stats.D += D

            ndone += 1
            if ndone % args.logk == 0:
                logging.info(f'{ndone} utts evaluated.')
        logging.info(f'{ndone} utts evaluated in total.')

        # Corpus-Level evaluation
        stats.token_error_rate, stats.modified_token_error_rate = ComputeTokenErrorRate(stats.C, stats.S, stats.I, stats.D)
        stats.sentence_error_rate = ComputeSentenceErrorRate(stats.num_utts_with_error, stats.num_eval_utts)

        print(stats.to_json(), file=fo)
        # print(stats.to_kaldi())
        # print(stats.to_summary(), file=fo)
|
||||||
370
utils/speechio/error_rate_zh.py
Normal file
370
utils/speechio/error_rate_zh.py
Normal file
@@ -0,0 +1,370 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
# coding=utf8
|
||||||
|
|
||||||
|
# Copyright 2021 Jiayu DU
|
||||||
|
|
||||||
|
import sys
|
||||||
|
import argparse
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
logging.basicConfig(stream=sys.stderr, level=logging.INFO, format='[%(levelname)s] %(message)s')
|
||||||
|
|
||||||
|
DEBUG = None
|
||||||
|
|
||||||
|
def GetEditType(ref_token, hyp_token):
    """Classify one alignment step from its reference / hypothesis tokens.

    Returns:
        'I' when only the reference token is None (insertion),
        'D' when only the hypothesis token is None (deletion),
        'C' when the two tokens are equal (this includes both being None),
        'S' otherwise (substitution).
    """
    # None marks a missing token, so it is tested by identity rather than
    # with ==.
    if ref_token is None and hyp_token is not None:
        return 'I'
    if ref_token is not None and hyp_token is None:
        return 'D'
    if ref_token == hyp_token:
        return 'C'
    return 'S'
|
||||||
|
|
||||||
|
class AlignmentArc:
    """One arc of an alignment: endpoints, the token pair, and its edit type."""

    def __init__(self, src, dst, ref, hyp):
        self.src, self.dst = src, dst
        self.ref, self.hyp = ref, hyp
        # Derived from the token pair via GetEditType.
        self.edit_type = GetEditType(ref, hyp)
|
||||||
|
|
||||||
|
def similarity_score_function(ref_token, hyp_token):
    """Score a token pair: 0 when the tokens are equal, -1.0 otherwise."""
    if ref_token == hyp_token:
        return 0
    return -1.0
|
||||||
|
|
||||||
|
def insertion_score_function(token):
    """Score an insertion; every token gets the same penalty of -1.0."""
    penalty = -1.0
    return penalty
|
||||||
|
|
||||||
|
def deletion_score_function(token):
    """Score a deletion; every token gets the same penalty of -1.0."""
    penalty = -1.0
    return penalty
|
||||||
|
|
||||||
|
def EditDistance(
        ref,
        hyp,
        similarity_score_function = similarity_score_function,
        insertion_score_function = insertion_score_function,
        deletion_score_function = deletion_score_function):
    """Align `ref` against `hyp` by dynamic programming and return the best path.

    The search maximises a score over a (len(ref)+1) x (len(hyp)+1) grid.
    With the default score functions a match scores 0 and every substitution,
    insertion or deletion scores -1.0.

    Args:
        ref: sequence of reference tokens; must be non-empty.
        hyp: sequence of hypothesis tokens; may be empty.
        similarity_score_function: callable(ref_token, hyp_token) -> score
            for a diagonal (correct / substitution) step.
        insertion_score_function: callable(hyp_token) -> score for a step
            that consumes only a hypothesis token.
        deletion_score_function: callable(ref_token) -> score for a step
            that consumes only a reference token.

    Returns:
        (best_path, best_score): best_path is a list of AlignmentArc ordered
        from the grid origin to the final cell; best_score is the score
        stored in the final cell (len(ref), len(hyp)).

    Raises:
        AssertionError: if `ref` is empty.
        RuntimeError: if a backpointer does not describe a legal step.
    """
    assert(len(ref) != 0)

    class DPState:
        def __init__(self):
            self.score = -float('inf')
            # backpointer
            self.prev_r = None
            self.prev_h = None

    def print_search_grid(S, R, H, fstream):
        print(file=fstream)
        for r in range(R):
            for h in range(H):
                print(F'[{r},{h}]:{S[r][h].score:4.3f}:({S[r][h].prev_r},{S[r][h].prev_h}) ', end='', file=fstream)
            print(file=fstream)

    R = len(ref) + 1
    H = len(hyp) + 1

    # Construct DP search space, a (R x H) grid
    S = [[DPState() for _ in range(H)] for _ in range(R)]

    # The origin S(r = 0, h = 0) keeps its None backpointers; that is what
    # terminates the backtrace below.
    S[0][0].score = 0.0

    # initialize REF axis: only deletions can reach (r, 0)
    for r in range(1, R):
        S[r][0].score = S[r-1][0].score + deletion_score_function(ref[r-1])
        S[r][0].prev_r = r-1
        S[r][0].prev_h = 0

    # initialize HYP axis: only insertions can reach (0, h)
    for h in range(1, H):
        S[0][h].score = S[0][h-1].score + insertion_score_function(hyp[h-1])
        S[0][h].prev_r = 0
        S[0][h].prev_h = h-1

    # Fill the interior. The three candidates are tried in the order
    # substitution/correct, deletion, insertion, each with ">=", so on a tie
    # the candidate tried later replaces the earlier one.
    for r in range(1, R):
        for h in range(1, H):
            state = S[r][h]

            new_score = S[r-1][h-1].score + similarity_score_function(ref[r-1], hyp[h-1])
            if new_score >= state.score:
                state.score = new_score
                state.prev_r = r-1
                state.prev_h = h-1

            new_score = S[r-1][h].score + deletion_score_function(ref[r-1])
            if new_score >= state.score:
                state.score = new_score
                state.prev_r = r-1
                state.prev_h = h

            new_score = S[r][h-1].score + insertion_score_function(hyp[h-1])
            if new_score >= state.score:
                state.score = new_score
                state.prev_r = r
                state.prev_h = h-1

    best_score = S[R-1][H-1].score

    if DEBUG:
        print_search_grid(S, R, H, sys.stderr)

    # Backtracing best alignment path, i.e. a list of arcs
    # arc = (src, dst, ref, hyp, edit_type)
    # src/dst = (r, h), where r/h refers to search grid state-id along Ref/Hyp axis
    best_path = []
    r, h = R-1, H-1
    prev_r, prev_h = S[r][h].prev_r, S[r][h].prev_h
    # loop invariant: (prev_r, prev_h) -> (r, h) is a "forward arc" on the
    # best alignment path
    while prev_r is not None or prev_h is not None:
        src = (prev_r, prev_h)
        dst = (r, h)
        if (r == prev_r + 1 and h == prev_h + 1): # Substitution or correct
            arc = AlignmentArc(src, dst, ref[prev_r], hyp[prev_h])
        elif (r == prev_r + 1 and h == prev_h): # Deletion
            arc = AlignmentArc(src, dst, ref[prev_r], None)
        elif (r == prev_r and h == prev_h + 1): # Insertion
            arc = AlignmentArc(src, dst, None, hyp[prev_h])
        else:
            raise RuntimeError
        best_path.append(arc)
        r, h = prev_r, prev_h
        prev_r, prev_h = S[r][h].prev_r, S[r][h].prev_h

    # Arcs were collected end-to-start; flip them into forward order.
    best_path.reverse()
    return (best_path, best_score)
|
||||||
|
|
||||||
|
def PrettyPrintAlignment(alignment, stream = sys.stderr):
    """Print an alignment as three column-aligned rows: REF, HYP and EDIT.

    Each arc contributes one column. A missing token (None) is shown as "*".
    The EDIT row shows the arc's edit_type, left blank for 'C'. Columns are
    padded so the three rows line up, counting characters in the range
    U+4E00..U+9FA5 as two display cells.

    Args:
        alignment: iterable of objects exposing .ref, .hyp and .edit_type.
        stream: writable text stream; defaults to sys.stderr.
    """
    def get_token_str(token):
        return "*" if token is None else token

    def is_double_width_char(ch):
        # codepoint range for Chinese chars
        # TODO: support other double-width-char language such as Japanese, Korean
        return '\u4e00' <= ch <= '\u9fa5'

    def display_width(token_str):
        return sum(2 if is_double_width_char(c) else 1 for c in token_str)

    # Collect the cells of each row and join once at the end instead of
    # growing three strings with "+=" inside the loop.
    ref_row = [' REF : ']
    hyp_row = [' HYP : ']
    edit_row = [' EDIT : ']
    for arc in alignment:
        r = get_token_str(arc.ref)
        h = get_token_str(arc.hyp)
        e = arc.edit_type if arc.edit_type != 'C' else ''

        nr, nh, ne = display_width(r), display_width(h), display_width(e)
        # Column width: widest of the three cells plus one separating space.
        n = max(nr, nh, ne) + 1

        ref_row.append(r + ' ' * (n - nr))
        hyp_row.append(h + ' ' * (n - nh))
        edit_row.append(e + ' ' * (n - ne))

    print(''.join(ref_row), file=stream)
    print(''.join(hyp_row), file=stream)
    print(''.join(edit_row), file=stream)
|
||||||
|
|
||||||
|
def CountEdits(alignment):
    """Tally the edit types found along an alignment.

    Args:
        alignment: iterable of objects exposing .edit_type.

    Returns:
        (c, s, i, d): number of arcs typed 'C', 'S', 'I' and 'D'.

    Raises:
        RuntimeError: on the first arc whose edit_type is none of those four.
    """
    tally = {'C': 0, 'S': 0, 'I': 0, 'D': 0}
    for arc in alignment:
        kind = arc.edit_type
        if kind not in tally:
            raise RuntimeError
        tally[kind] += 1
    return (tally['C'], tally['S'], tally['I'], tally['D'])
|
||||||
|
|
||||||
|
def ComputeTokenErrorRate(c, s, i, d):
    """Token error rate as a percentage.

    Edits (substitutions + deletions + insertions) divided by
    substitutions + deletions + corrects, times 100. The denominator is not
    checked here, so it must be non-zero.
    """
    num_edits = s + d + i
    denominator = s + d + c
    return 100.0 * num_edits / denominator
|
||||||
|
|
||||||
|
def ComputeSentenceErrorRate(num_err_utts, num_utts):
    """Percentage of utterances that contain at least one error.

    Args:
        num_err_utts: number of utterances with errors.
        num_utts: total number of utterances; asserted to be non-zero.
    """
    assert(num_utts != 0)
    error_fraction = num_err_utts / num_utts
    return 100.0 * error_fraction
|
||||||
|
|
||||||
|
|
||||||
|
class EvaluationResult:
    """Container for corpus-level counters and rates, with text renderers.

    All fields start at zero and are filled in by the caller. to_json()
    serialises the instance dictionary, so the order in which the fields are
    assigned below is also the key order of the JSON output.
    """

    def __init__(self):
        # utterance counts
        self.num_ref_utts = 0
        self.num_hyp_utts = 0
        self.num_eval_utts = 0 # seen in both ref & hyp
        self.num_hyp_without_ref = 0

        # token-level edit counts: correct / substitution / insertion / deletion
        self.C = 0
        self.S = 0
        self.I = 0
        self.D = 0
        self.token_error_rate = 0.0

        # sentence-level statistics
        self.num_utts_with_error = 0
        self.sentence_error_rate = 0.0

    def to_json(self):
        """Return every field as a JSON object string."""
        return json.dumps(vars(self))

    def to_kaldi(self):
        """Return a two-line '%WER ... / %SER ...' report."""
        num_edits = self.S + self.D + self.I
        num_tokens = self.C + self.S + self.D
        wer_line = F'%WER {self.token_error_rate:.2f} [ {num_edits} / {num_tokens}, {self.I} ins, {self.D} del, {self.S} sub ]\n'
        ser_line = F'%SER {self.sentence_error_rate:.2f} [ {self.num_utts_with_error} / {self.num_eval_utts} ]\n'
        return wer_line + ser_line

    def to_sclite(self):
        """Not implemented yet; returns the placeholder string."""
        return "TODO"

    def to_espnet(self):
        """Not implemented yet; returns the placeholder string."""
        return "TODO"

    def to_summary(self):
        """Return a multi-line, human-readable block with all statistics."""
        #return json.dumps(self.__dict__, indent=4)
        num_tokens = self.C + self.S + self.D
        num_edits = self.S + self.I + self.D
        rows = [
            '==================== Overall Statistics ====================\n',
            F'num_ref_utts: {self.num_ref_utts}\n',
            F'num_hyp_utts: {self.num_hyp_utts}\n',
            F'num_hyp_without_ref: {self.num_hyp_without_ref}\n',
            F'num_eval_utts: {self.num_eval_utts}\n',
            F'sentence_error_rate: {self.sentence_error_rate:.2f}%\n',
            F'token_error_rate: {self.token_error_rate:.2f}%\n',
            F'token_stats:\n',
            F' - tokens:{num_tokens:>7}\n',
            F' - edits: {num_edits:>7}\n',
            F' - cor: {self.C:>7}\n',
            F' - sub: {self.S:>7}\n',
            F' - ins: {self.I:>7}\n',
            F' - del: {self.D:>7}\n',
            '============================================================\n',
        ]
        return ''.join(rows)
|
||||||
|
|
||||||
|
|
||||||
|
class Utterance:
    """A single utterance: its id together with its text."""

    def __init__(self, uid, text):
        self.uid, self.text = uid, text
|
||||||
|
|
||||||
|
|
||||||
|
def LoadUtterances(filepath, format):
    """Read an utterance file into a dict mapping uid -> Utterance.

    Only format 'text' is supported: each non-blank line is
    "utt_id word1 word2 ...". The first whitespace-separated column is the
    utterance id and the rest of the line is its text (empty when the line
    holds only an id). Blank lines are skipped.

    Args:
        filepath: path of the file to read; opened as UTF-8.
        format: file format name; must be 'text'.

    Returns:
        dict of uid -> Utterance.

    Raises:
        RuntimeError: if `format` is not 'text', or an id occurs twice.
    """
    if format != 'text':
        raise RuntimeError(F'Unsupported text format {format}')

    utts = {}
    with open(filepath, 'r', encoding='utf8') as f:
        for raw_line in f:
            line = raw_line.strip()
            if not line:
                continue
            # utt_id word1 word2 ...
            cols = line.split(maxsplit=1)
            assert(len(cols) in (1, 2))
            uid = cols[0]
            text = cols[1] if len(cols) == 2 else ''
            if uid in utts:
                raise RuntimeError(F'Found duplicated utterence id {uid}')
            utts[uid] = Utterance(uid, text)
    return utts
|
||||||
|
|
||||||
|
|
||||||
|
def tokenize_text(text, tokenizer):
    """Split `text` into a list of tokens.

    Args:
        text: the string to tokenize.
        tokenizer: 'whitespace' -> whitespace-separated words;
                   'char'       -> individual characters, with all
                                   whitespace removed first.

    Raises:
        RuntimeError: for any other tokenizer name.
    """
    words = None
    if tokenizer in ('whitespace', 'char'):
        words = text.split()
    if tokenizer == 'whitespace':
        return words
    if tokenizer == 'char':
        return list(''.join(words))
    raise RuntimeError(F'ERROR: Unsupported tokenizer {tokenizer}')
|
||||||
|
|
||||||
|
|
||||||
|
# Command-line entry point: score a hypothesis file against a reference file,
# write per-utterance alignments plus an overall summary to `result_file`,
# and print the JSON and Kaldi-style summaries to stdout.
if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    # optional
    parser.add_argument('--tokenizer', choices=['whitespace', 'char'], default='whitespace', help='whitespace for WER, char for CER')
    parser.add_argument('--ref-format', choices=['text'], default='text', help='reference format, first col is utt_id, the rest is text')
    parser.add_argument('--hyp-format', choices=['text'], default='text', help='hypothesis format, first col is utt_id, the rest is text')
    # required
    parser.add_argument('--ref', type=str, required=True, help='input reference file')
    parser.add_argument('--hyp', type=str, required=True, help='input hypothesis file')

    parser.add_argument('result_file', type=str)
    args = parser.parse_args()
    logging.info(args)

    ref_utts = LoadUtterances(args.ref, args.ref_format)
    hyp_utts = LoadUtterances(args.hyp, args.hyp_format)

    r = EvaluationResult()

    # check valid utterances in hyp that have matched non-empty reference
    eval_utts = []
    r.num_hyp_without_ref = 0
    for uid in sorted(hyp_utts.keys()):
        if uid in ref_utts:
            if ref_utts[uid].text.strip(): # non-empty reference
                eval_utts.append(uid)
            else:
                # logging.warn is a deprecated alias; logging.warning is the
                # supported spelling.
                logging.warning(F'Found {uid} with empty reference, skipping...')
        else:
            logging.warning(F'Found {uid} without reference, skipping...')
            r.num_hyp_without_ref += 1

    r.num_hyp_utts = len(hyp_utts)
    r.num_ref_utts = len(ref_utts)
    r.num_eval_utts = len(eval_utts)

    with open(args.result_file, 'w+', encoding='utf8') as fo:
        for uid in eval_utts:
            ref = ref_utts[uid]
            hyp = hyp_utts[uid]

            alignment, score = EditDistance(
                tokenize_text(ref.text, args.tokenizer),
                tokenize_text(hyp.text, args.tokenizer)
            )

            c, s, i, d = CountEdits(alignment)
            utt_ter = ComputeTokenErrorRate(c, s, i, d)

            # utt-level evaluation result
            print(F'{{"uid":{uid}, "score":{score}, "ter":{utt_ter:.2f}, "cor":{c}, "sub":{s}, "ins":{i}, "del":{d}}}', file=fo)
            PrettyPrintAlignment(alignment, fo)

            r.C += c
            r.S += s
            r.I += i
            r.D += d

            if utt_ter > 0:
                r.num_utts_with_error += 1

        # corpus level evaluation result
        # NOTE: ComputeSentenceErrorRate asserts that at least one utterance
        # was evaluated, so an empty eval set stops the script here.
        r.sentence_error_rate = ComputeSentenceErrorRate(r.num_utts_with_error, r.num_eval_utts)
        r.token_error_rate = ComputeTokenErrorRate(r.C, r.S, r.I, r.D)

        print(r.to_summary(), file=fo)

    print(r.to_json())
    print(r.to_kaldi())
|
||||||
744
utils/speechio/glm_en.csv
Normal file
744
utils/speechio/glm_en.csv
Normal file
@@ -0,0 +1,744 @@
|
|||||||
|
I'M,I AM
|
||||||
|
I'LL,I WILL
|
||||||
|
I'D,I HAD
|
||||||
|
I'VE,I HAVE
|
||||||
|
I WOULD'VE,I'D HAVE
|
||||||
|
YOU'RE,YOU ARE
|
||||||
|
YOU'LL,YOU WILL
|
||||||
|
YOU'D,YOU WOULD
|
||||||
|
YOU'VE,YOU HAVE
|
||||||
|
HE'S,HE IS,HE WAS
|
||||||
|
HE'LL,HE WILL
|
||||||
|
HE'D,HE HAD
|
||||||
|
SHE'S,SHE IS,SHE WAS
|
||||||
|
SHE'LL,SHE WILL
|
||||||
|
SHE'D,SHE HAD
|
||||||
|
IT'S,IT IS,IT WAS
|
||||||
|
IT'LL,IT WILL
|
||||||
|
WE'RE,WE ARE,WE WERE
|
||||||
|
WE'LL,WE WILL
|
||||||
|
WE'D,WE WOULD
|
||||||
|
WE'VE,WE HAVE
|
||||||
|
WHO'LL,WHO WILL
|
||||||
|
THEY'RE,THEY ARE
|
||||||
|
THEY'LL,THEY WILL
|
||||||
|
THAT'S,THAT IS,THAT WAS
|
||||||
|
THAT'LL,THAT WILL
|
||||||
|
HERE'S,HERE IS,HERE WAS
|
||||||
|
THERE'S,THERE IS,THERE WAS
|
||||||
|
WHERE'S,WHERE IS,WHERE WAS
|
||||||
|
WHAT'S,WHAT IS,WHAT WAS
|
||||||
|
LET'S,LET US
|
||||||
|
WHO'S,WHO IS
|
||||||
|
ONE'S,ONE IS
|
||||||
|
THERE'LL,THERE WILL
|
||||||
|
SOMEBODY'S,SOMEBODY IS
|
||||||
|
EVERYBODY'S,EVERYBODY IS
|
||||||
|
WOULD'VE,WOULD HAVE
|
||||||
|
CAN'T,CANNOT,CAN NOT
|
||||||
|
HADN'T,HAD NOT
|
||||||
|
HASN'T,HAS NOT
|
||||||
|
HAVEN'T,HAVE NOT
|
||||||
|
ISN'T,IS NOT
|
||||||
|
AREN'T,ARE NOT
|
||||||
|
WON'T,WILL NOT
|
||||||
|
WOULDN'T,WOULD NOT
|
||||||
|
SHOULDN'T,SHOULD NOT
|
||||||
|
DON'T,DO NOT
|
||||||
|
DIDN'T,DID NOT
|
||||||
|
GOTTA,GOT TO
|
||||||
|
GONNA,GOING TO
|
||||||
|
WANNA,WANT TO
|
||||||
|
LEMME,LET ME
|
||||||
|
GIMME,GIVE ME
|
||||||
|
DUNNO,DON'T KNOW
|
||||||
|
GOTCHA,GOT YOU
|
||||||
|
KINDA,KIND OF
|
||||||
|
MYSELF,MY SELF
|
||||||
|
YOURSELF,YOUR SELF
|
||||||
|
HIMSELF,HIM SELF
|
||||||
|
HERSELF,HER SELF
|
||||||
|
ITSELF,IT SELF
|
||||||
|
OURSELVES,OUR SELVES
|
||||||
|
OKAY,OK,O K
|
||||||
|
Y'ALL,YALL,YOU ALL
|
||||||
|
'CAUSE,'COS,CUZ,BECAUSE
|
||||||
|
FUCKIN',FUCKING
|
||||||
|
KILLING,KILLIN'
|
||||||
|
EVERYDAY,EVERY DAY
|
||||||
|
DOCTOR,DR,DR.
|
||||||
|
MRS,MISSES,MISSUS
|
||||||
|
MR,MR.,MISTER
|
||||||
|
SR,SR.,SENIOR
|
||||||
|
JR,JR.,JUNIOR
|
||||||
|
ST,ST.,SAINT
|
||||||
|
VOL,VOL.,VOLUME
|
||||||
|
CM,CENTIMETER,CENTIMETRE
|
||||||
|
MM,MILLIMETER,MILLIMETRE
|
||||||
|
KM,KILOMETER,KILOMETRE
|
||||||
|
KB,KILOBYTES,KILO BYTES,K B
|
||||||
|
MB,MEGABYTES,MEGA BYTES
|
||||||
|
GB,GIGABYTES,GIGA BYTES,G B
|
||||||
|
THOUSAND,THOUSAND AND
|
||||||
|
HUNDRED,HUNDRED AND
|
||||||
|
A HUNDRED,ONE HUNDRED
|
||||||
|
TWO THOUSAND AND,TWENTY,TWO THOUSAND
|
||||||
|
STORYTELLER,STORY TELLER
|
||||||
|
TSHIRT,T SHIRT
|
||||||
|
TSHIRTS,T SHIRTS
|
||||||
|
LEUKAEMIA,LEUKEMIA
|
||||||
|
OESTROGEN,ESTROGEN
|
||||||
|
ACKNOWLEDGMENT,ACKNOWLEDGEMENT
|
||||||
|
JUDGMENT,JUDGEMENT
|
||||||
|
MAMMA,MAMA
|
||||||
|
DINING,DINNING
|
||||||
|
FLACK,FLAK
|
||||||
|
LEARNT,LEARNED
|
||||||
|
BLONDE,BLOND
|
||||||
|
JUMPSTART,JUMP START
|
||||||
|
RIGHTNOW,RIGHT NOW
|
||||||
|
EVERYONE,EVERY ONE
|
||||||
|
NAME'S,NAME IS
|
||||||
|
FAMILY'S,FAMILY IS
|
||||||
|
COMPANY'S,COMPANY HAS
|
||||||
|
GRANDKID,GRAND KID
|
||||||
|
GRANDKIDS,GRAND KIDS
|
||||||
|
MEALTIMES,MEAL TIMES
|
||||||
|
ALRIGHT,ALL RIGHT
|
||||||
|
GROWNUP,GROWN UP
|
||||||
|
GROWNUPS,GROWN UPS
|
||||||
|
SCHOOLDAYS,SCHOOL DAYS
|
||||||
|
SCHOOLCHILDREN,SCHOOL CHILDREN
|
||||||
|
CASEBOOK,CASE BOOK
|
||||||
|
HUNGOVER,HUNG OVER
|
||||||
|
HANDCLAPS,HAND CLAPS
|
||||||
|
HANDCLAP,HAND CLAP
|
||||||
|
HEATWAVE,HEAT WAVE
|
||||||
|
ADDON,ADD ON
|
||||||
|
ONTO,ON TO
|
||||||
|
INTO,IN TO
|
||||||
|
GOTO,GO TO
|
||||||
|
GUNSHOT,GUN SHOT
|
||||||
|
MOTHERFUCKER,MOTHER FUCKER
|
||||||
|
OFTENTIMES,OFTEN TIMES
|
||||||
|
SARTRE'S,SARTRE IS
|
||||||
|
NONSTARTER,NON STARTER
|
||||||
|
NONSTARTERS,NON STARTERS
|
||||||
|
LONGTIME,LONG TIME
|
||||||
|
POLICYMAKERS,POLICY MAKERS
|
||||||
|
ANYMORE,ANY MORE
|
||||||
|
CANADA'S,CANADA IS
|
||||||
|
CELLPHONE,CELL PHONE
|
||||||
|
WORKPLACE,WORK PLACE
|
||||||
|
UNDERESTIMATING,UNDER ESTIMATING
|
||||||
|
CYBERSECURITY,CYBER SECURITY
|
||||||
|
NORTHEAST,NORTH EAST
|
||||||
|
ANYTIME,ANY TIME
|
||||||
|
LIVESTREAM,LIVE STREAM
|
||||||
|
LIVESTREAMS,LIVE STREAMS
|
||||||
|
WEBCAM,WEB CAM
|
||||||
|
EMAIL,E MAIL
|
||||||
|
ECAM,E CAM
|
||||||
|
VMIX,V MIX
|
||||||
|
SETUP,SET UP
|
||||||
|
SMARTPHONE,SMART PHONE
|
||||||
|
MULTICASTING,MULTI CASTING
|
||||||
|
CHITCHAT,CHIT CHAT
|
||||||
|
SEMIFINAL,SEMI FINAL
|
||||||
|
SEMIFINALS,SEMI FINALS
|
||||||
|
BBQ,BARBECUE
|
||||||
|
STORYLINE,STORY LINE
|
||||||
|
STORYLINES,STORY LINES
|
||||||
|
BRO,BROTHER
|
||||||
|
BROS,BROTHERS
|
||||||
|
OVERPROTECTIIVE,OVER PROTECTIVE
|
||||||
|
TIMEOUT,TIME OUT
|
||||||
|
ADVISOR,ADVISER
|
||||||
|
TIMBERWOLVES,TIMBER WOLVES
|
||||||
|
WEBPAGE,WEB PAGE
|
||||||
|
NEWCOMER,NEW COMER
|
||||||
|
DELMAR,DEL MAR
|
||||||
|
NETPLAY,NET PLAY
|
||||||
|
STREETSIDE,STREET SIDE
|
||||||
|
COLOURED,COLORED
|
||||||
|
COLOURFUL,COLORFUL
|
||||||
|
O,ZERO
|
||||||
|
ETCETERA,ET CETERA
|
||||||
|
FUNDRAISING,FUND RAISING
|
||||||
|
RAINFOREST,RAIN FOREST
|
||||||
|
BREATHTAKING,BREATH TAKING
|
||||||
|
WIKIPAGE,WIKI PAGE
|
||||||
|
OVERTIME,OVER TIME
|
||||||
|
TRAIN'S,TRAIN IS
|
||||||
|
ANYONE,ANY ONE
|
||||||
|
PHYSIOTHERAPY,PHYSIO THERAPY
|
||||||
|
ANYBODY,ANY BODY
|
||||||
|
BOTTLECAPS,BOTTLE CAPS
|
||||||
|
BOTTLECAP,BOTTLE CAP
|
||||||
|
STEPFATHER'S,STEP FATHER'S
|
||||||
|
STEPFATHER,STEP FATHER
|
||||||
|
WARTIME,WAR TIME
|
||||||
|
SCREENSHOT,SCREEN SHOT
|
||||||
|
TIMELINE,TIME LINE
|
||||||
|
CITY'S,CITY IS
|
||||||
|
NONPROFIT,NON PROFIT
|
||||||
|
KPOP,K POP
|
||||||
|
HOMEBASE,HOME BASE
|
||||||
|
LIFELONG,LIFE LONG
|
||||||
|
LAWSUITS,LAW SUITS
|
||||||
|
MULTIBILLION,MULTI BILLION
|
||||||
|
ROADMAP,ROAD MAP
|
||||||
|
GUY'S,GUY IS
|
||||||
|
CHECKOUT,CHECK OUT
|
||||||
|
SQUARESPACE,SQUARE SPACE
|
||||||
|
REDLINING,RED LINING
|
||||||
|
BASE'S,BASE IS
|
||||||
|
TAKEAWAY,TAKE AWAY
|
||||||
|
CANDYLAND,CANDY LAND
|
||||||
|
ANTISOCIAL,ANTI SOCIAL
|
||||||
|
CASEWORK,CASE WORK
|
||||||
|
RIGOR,RIGOUR
|
||||||
|
ORGANIZATIONS,ORGANISATIONS
|
||||||
|
ORGANIZATION,ORGANISATION
|
||||||
|
SIGNPOST,SIGN POST
|
||||||
|
WWII,WORLD WAR TWO
|
||||||
|
WINDOWPANE,WINDOW PANE
|
||||||
|
SUREFIRE,SURE FIRE
|
||||||
|
MOUNTAINTOP,MOUNTAIN TOP
|
||||||
|
SALESPERSON,SALES PERSON
|
||||||
|
NETWORK,NET WORK
|
||||||
|
MINISERIES,MINI SERIES
|
||||||
|
EDWARDS'S,EDWARDS IS
|
||||||
|
INTERSUBJECTIVITY,INTER SUBJECTIVITY
|
||||||
|
LIBERALISM'S,LIBERALISM IS
|
||||||
|
TAGLINE,TAG LINE
|
||||||
|
SHINETHEORY,SHINE THEORY
|
||||||
|
CALLYOURGIRLFRIEND,CALL YOUR GIRLFRIEND
|
||||||
|
STARTUP,START UP
|
||||||
|
BREAKUP,BREAK UP
|
||||||
|
RADIOTOPIA,RADIO TOPIA
|
||||||
|
HEARTBREAKING,HEART BREAKING
|
||||||
|
AUTOIMMUNE,AUTO IMMUNE
|
||||||
|
SINISE'S,SINISE IS
|
||||||
|
KICKBACK,KICK BACK
|
||||||
|
FOGHORN,FOG HORN
|
||||||
|
BADASS,BAD ASS
|
||||||
|
POWERAMERICAFORWARD,POWER AMERICA FORWARD
|
||||||
|
GOOGLE'S,GOOGLE IS
|
||||||
|
ROLEPLAY,ROLE PLAY
|
||||||
|
PRICE'S,PRICE IS
|
||||||
|
STANDOFF,STAND OFF
|
||||||
|
FOREVER,FOR EVER
|
||||||
|
GENERAL'S,GENERAL IS
|
||||||
|
DOG'S,DOG IS
|
||||||
|
AUDIOBOOK,AUDIO BOOK
|
||||||
|
ANYWAY,ANY WAY
|
||||||
|
PIGEONHOLE,PIGEON HOLE
|
||||||
|
EGGSHELLS,EGG SHELLS
|
||||||
|
VACCINE'S,VACCINE IS
|
||||||
|
WORKOUT,WORK OUT
|
||||||
|
ADMINISTRATOR'S,ADMINISTRATOR IS
|
||||||
|
FUCKUP,FUCK UP
|
||||||
|
RUNOFFS,RUN OFFS
|
||||||
|
COLORWAY,COLOR WAY
|
||||||
|
WAITLIST,WAIT LIST
|
||||||
|
HEALTHCARE,HEALTH CARE
|
||||||
|
TEXTBOOK,TEXT BOOK
|
||||||
|
CALLBACK,CALL BACK
|
||||||
|
PARTYGOERS,PARTY GOERS
|
||||||
|
SOMEDAY,SOME DAY
|
||||||
|
NIGHTGOWN,NIGHT GOWN
|
||||||
|
STANDALONG,STAND ALONG
|
||||||
|
BUSSINESSWOMAN,BUSSINESS WOMAN
|
||||||
|
STORYTELLING,STORY TELLING
|
||||||
|
MARKETPLACE,MARKET PLACE
|
||||||
|
CRATEJOY,CRATE JOY
|
||||||
|
OUTPERFORMED,OUT PERFORMED
|
||||||
|
TRUEBOTANICALS,TRUE BOTANICALS
|
||||||
|
NONFICTION,NON FICTION
|
||||||
|
SPINOFF,SPIN OFF
|
||||||
|
MOTHERFUCKING,MOTHER FUCKING
|
||||||
|
TRACKLIST,TRACK LIST
|
||||||
|
GODDAMN,GOD DAMN
|
||||||
|
PORNHUB,PORN HUB
|
||||||
|
UNDERAGE,UNDER AGE
|
||||||
|
GOODBYE,GOOD BYE
|
||||||
|
HARDCORE,HARD CORE
|
||||||
|
TRUCK'S,TRUCK IS
|
||||||
|
COUNTERSTEERING,COUNTER STEERING
|
||||||
|
BUZZWORD,BUZZ WORD
|
||||||
|
SUBCOMPONENTS,SUB COMPONENTS
|
||||||
|
MOREOVER,MORE OVER
|
||||||
|
PICKUP,PICK UP
|
||||||
|
NEWSLETTER,NEWS LETTER
|
||||||
|
KEYWORD,KEY WORD
|
||||||
|
LOGIN,LOG IN
|
||||||
|
TOOLBOX,TOOL BOX
|
||||||
|
LINK'S,LINK IS
|
||||||
|
PRIMIALVIDEO,PRIMAL VIDEO
|
||||||
|
DOTNET,DOT NET
|
||||||
|
AIRSTRIKE,AIR STRIKE
|
||||||
|
HAIRSTYLE,HAIR STYLE
|
||||||
|
TOWNSFOLK,TOWNS FOLK
|
||||||
|
GOLDFISH,GOLD FISH
|
||||||
|
TOM'S,TOM IS
|
||||||
|
HOMETOWN,HOME TOWN
|
||||||
|
CORONAVIRUS,CORONA VIRUS
|
||||||
|
PLAYSTATION,PLAY STATION
|
||||||
|
TOMORROW,TO MORROW
|
||||||
|
TIMECONSUMING,TIME CONSUMING
|
||||||
|
POSTWAR,POST WAR
|
||||||
|
HANDSON,HANDS ON
|
||||||
|
SHAKEUP,SHAKE UP
|
||||||
|
ECOMERS,E COMERS
|
||||||
|
COFOUNDER,CO FOUNDER
|
||||||
|
HIGHEND,HIGH END
|
||||||
|
INPERSON,IN PERSON
|
||||||
|
GROWNUP,GROWN UP
|
||||||
|
SELFREGULATION,SELF REGULATION
|
||||||
|
INDEPTH,IN DEPTH
|
||||||
|
ALLTIME,ALL TIME
|
||||||
|
LONGTERM,LONG TERM
|
||||||
|
SOCALLED,SO CALLED
|
||||||
|
SELFCONFIDENCE,SELF CONFIDENCE
|
||||||
|
STANDUP,STAND UP
|
||||||
|
MINDBOGGLING,MIND BOGGLING
|
||||||
|
BEINGFOROTHERS,BEING FOR OTHERS
|
||||||
|
COWROTE,CO WROTE
|
||||||
|
COSTARRED,CO STARRED
|
||||||
|
EDITORINCHIEF,EDITOR IN CHIEF
|
||||||
|
HIGHSPEED,HIGH SPEED
|
||||||
|
DECISIONMAKING,DECISION MAKING
|
||||||
|
WELLBEING,WELL BEING
|
||||||
|
NONTRIVIAL,NON TRIVIAL
|
||||||
|
PREEXISTING,PRE EXISTING
|
||||||
|
STATEOWNED,STATE OWNED
|
||||||
|
PLUGIN,PLUG IN
|
||||||
|
PROVERSION,PRO VERSION
|
||||||
|
OPTIN,OPT IN
|
||||||
|
FOLLOWUP,FOLLOW UP
|
||||||
|
FOLLOWUPS,FOLLOW UPS
|
||||||
|
WIFI,WI FI
|
||||||
|
THIRDPARTY,THIRD PARTY
|
||||||
|
PROFESSIONALLOOKING,PROFESSIONAL LOOKING
|
||||||
|
FULLSCREEN,FULL SCREEN
|
||||||
|
BUILTIN,BUILT IN
|
||||||
|
MULTISTREAM,MULTI STREAM
|
||||||
|
LOWCOST,LOW COST
|
||||||
|
RESTREAM,RE STREAM
|
||||||
|
GAMECHANGER,GAME CHANGER
|
||||||
|
WELLDEVELOPED,WELL DEVELOPED
|
||||||
|
QUARTERINCH,QUARTER INCH
|
||||||
|
FASTFASHION,FAST FASHION
|
||||||
|
ECOMMERCE,E COMMERCE
|
||||||
|
PRIZEWINNING,PRIZE WINNING
|
||||||
|
NEVERENDING,NEVER ENDING
|
||||||
|
MINDBLOWING,MIND BLOWING
|
||||||
|
REALLIFE,REAL LIFE
|
||||||
|
REOPEN,RE OPEN
|
||||||
|
ONDEMAND,ON DEMAND
|
||||||
|
PROBLEMSOLVING,PROBLEM SOLVING
|
||||||
|
HEAVYHANDED,HEAVY HANDED
|
||||||
|
OPENENDED,OPEN ENDED
|
||||||
|
SELFCONTROL,SELF CONTROL
|
||||||
|
WELLMEANING,WELL MEANING
|
||||||
|
COHOST,CO HOST
|
||||||
|
RIGHTSBASED,RIGHTS BASED
|
||||||
|
HALFBROTHER,HALF BROTHER
|
||||||
|
FATHERINLAW,FATHER IN LAW
|
||||||
|
COAUTHOR,CO AUTHOR
|
||||||
|
REELECTION,RE ELECTION
|
||||||
|
SELFHELP,SELF HELP
|
||||||
|
PROLIFE,PRO LIFE
|
||||||
|
ANTIDUKE,ANTI DUKE
|
||||||
|
POSTSTRUCTURALIST,POST STRUCTURALIST
|
||||||
|
COFOUNDED,CO FOUNDED
|
||||||
|
XRAY,X RAY
|
||||||
|
ALLAROUND,ALL AROUND
|
||||||
|
HIGHTECH,HIGH TECH
|
||||||
|
TMOBILE,T MOBILE
|
||||||
|
INHOUSE,IN HOUSE
|
||||||
|
POSTMORTEM,POST MORTEM
|
||||||
|
LITTLEKNOWN,LITTLE KNOWN
|
||||||
|
FALSEPOSITIVE,FALSE POSITIVE
|
||||||
|
ANTIVAXXER,ANTI VAXXER
|
||||||
|
EMAILS,E MAILS
|
||||||
|
DRIVETHROUGH,DRIVE THROUGH
|
||||||
|
DAYTODAY,DAY TO DAY
|
||||||
|
COSTAR,CO STAR
|
||||||
|
EBAY,E BAY
|
||||||
|
KOOLAID,KOOL AID
|
||||||
|
ANTIDEMOCRATIC,ANTI DEMOCRATIC
|
||||||
|
MIDDLEAGED,MIDDLE AGED
|
||||||
|
SHORTLIVED,SHORT LIVED
|
||||||
|
BESTSELLING,BEST SELLING
|
||||||
|
TICTACS,TIC TACS
|
||||||
|
UHHUH,UH HUH
|
||||||
|
MULTITANK,MULTI TANK
|
||||||
|
JAWDROPPING,JAW DROPPING
|
||||||
|
LIVESTREAMING,LIVE STREAMING
|
||||||
|
HARDWORKING,HARD WORKING
|
||||||
|
BOTTOMDWELLING,BOTTOM DWELLING
|
||||||
|
PRESHOW,PRE SHOW
|
||||||
|
HANDSFREE,HANDS FREE
|
||||||
|
TRICKORTREATING,TRICK OR TREATING
|
||||||
|
PRERECORDED,PRE RECORDED
|
||||||
|
DOGOODERS,DO GOODERS
|
||||||
|
WIDERANGING,WIDE RANGING
|
||||||
|
LIFESAVING,LIFE SAVING
|
||||||
|
SKIREPORT,SKI REPORT
|
||||||
|
SNOWBASE,SNOW BASE
|
||||||
|
JAYZ,JAY Z
|
||||||
|
SPIDERMAN,SPIDER MAN
|
||||||
|
FREEKICK,FREE KICK
|
||||||
|
EDWARDSHELAIRE,EDWARDS HELAIRE
|
||||||
|
SHORTTERM,SHORT TERM
|
||||||
|
HAVENOTS,HAVE NOTS
|
||||||
|
SELFINTEREST,SELF INTEREST
|
||||||
|
SELFINTERESTED,SELF INTERESTED
|
||||||
|
SELFCOMPASSION,SELF COMPASSION
|
||||||
|
MACHINELEARNING,MACHINE LEARNING
|
||||||
|
COAUTHORED,CO AUTHORED
|
||||||
|
NONGOVERNMENT,NON GOVERNMENT
|
||||||
|
SUBSAHARAN,SUB SAHARAN
|
||||||
|
COCHAIR,CO CHAIR
|
||||||
|
LARGESCALE,LARGE SCALE
|
||||||
|
VIDEOONDEMAND,VIDEO ON DEMAND
|
||||||
|
FIRSTCLASS,FIRST CLASS
|
||||||
|
COFOUNDERS,CO FOUNDERS
|
||||||
|
COOP,CO OP
|
||||||
|
PREORDERS,PRE ORDERS
|
||||||
|
DOUBLEENTRY,DOUBLE ENTRY
|
||||||
|
SELFCONFIDENT,SELF CONFIDENT
|
||||||
|
SELFPORTRAIT,SELF PORTRAIT
|
||||||
|
NONWHITE,NON WHITE
|
||||||
|
ONBOARD,ON BOARD
|
||||||
|
HALFLIFE,HALF LIFE
|
||||||
|
ONCOURT,ON COURT
|
||||||
|
SCIFI,SCI FI
|
||||||
|
XMEN,X MEN
|
||||||
|
DAYLEWIS,DAY LEWIS
|
||||||
|
LALALAND,LA LA LAND
|
||||||
|
AWARDWINNING,AWARD WINNING
|
||||||
|
BOXOFFICE,BOX OFFICE
|
||||||
|
TRIDACTYLS,TRI DACTYLS
|
||||||
|
TRIDACTYL,TRI DACTYL
|
||||||
|
MEDIUMSIZED,MEDIUM SIZED
|
||||||
|
POSTSECONDARY,POST SECONDARY
|
||||||
|
FULLTIME,FULL TIME
|
||||||
|
GOKART,GO KART
|
||||||
|
OPENAIR,OPEN AIR
|
||||||
|
WELLKNOWN,WELL KNOWN
|
||||||
|
ICECREAM,ICE CREAM
|
||||||
|
EARTHMOON,EARTH MOON
|
||||||
|
STATEOFTHEART,STATE OF THE ART
|
||||||
|
BSIDE,B SIDE
|
||||||
|
EASTWEST,EAST WEST
|
||||||
|
ALLSTAR,ALL STAR
|
||||||
|
RUNNERUP,RUNNER UP
|
||||||
|
HORSEDRAWN,HORSE DRAWN
|
||||||
|
OPENSOURCE,OPEN SOURCE
|
||||||
|
PURPOSEBUILT,PURPOSE BUILT
|
||||||
|
SQUAREFREE,SQUARE FREE
|
||||||
|
PRESENTDAY,PRESENT DAY
|
||||||
|
CANADAUNITED,CANADA UNITED
|
||||||
|
HOTCHPOTCH,HOTCH POTCH
|
||||||
|
LOWLYING,LOW LYING
|
||||||
|
RIGHTHANDED,RIGHT HANDED
|
||||||
|
PEARSHAPED,PEAR SHAPED
|
||||||
|
BESTKNOWN,BEST KNOWN
|
||||||
|
FULLLENGTH,FULL LENGTH
|
||||||
|
YEARROUND,YEAR ROUND
|
||||||
|
PREELECTION,PRE ELECTION
|
||||||
|
RERECORD,RE RECORD
|
||||||
|
MINIALBUM,MINI ALBUM
|
||||||
|
LONGESTRUNNING,LONGEST RUNNING
|
||||||
|
ALLIRELAND,ALL IRELAND
|
||||||
|
NORTHWESTERN,NORTH WESTERN
|
||||||
|
PARTTIME,PART TIME
|
||||||
|
NONGOVERNMENTAL,NON GOVERNMENTAL
|
||||||
|
ONLINE,ON LINE
|
||||||
|
ONAIR,ON AIR
|
||||||
|
NORTHSOUTH,NORTH SOUTH
|
||||||
|
RERELEASED,RE RELEASED
|
||||||
|
LEFTHANDED,LEFT HANDED
|
||||||
|
BSIDES,B SIDES
|
||||||
|
ANGLOSAXON,ANGLO SAXON
|
||||||
|
SOUTHSOUTHEAST,SOUTH SOUTHEAST
|
||||||
|
CROSSCOUNTRY,CROSS COUNTRY
|
||||||
|
REBUILT,RE BUILT
|
||||||
|
FREEFORM,FREE FORM
|
||||||
|
SCOOBYDOO,SCOOBY DOO
|
||||||
|
ATLARGE,AT LARGE
|
||||||
|
COUNCILMANAGER,COUNCIL MANAGER
|
||||||
|
LONGRUNNING,LONG RUNNING
|
||||||
|
PREWAR,PRE WAR
|
||||||
|
REELECTED,RE ELECTED
|
||||||
|
HIGHSCHOOL,HIGH SCHOOL
|
||||||
|
RUNNERSUP,RUNNERS UP
|
||||||
|
NORTHWEST,NORTH WEST
|
||||||
|
WEBBASED,WEB BASED
|
||||||
|
HIGHQUALITY,HIGH QUALITY
|
||||||
|
RIGHTWING,RIGHT WING
|
||||||
|
LANEFOX,LANE FOX
|
||||||
|
PAYPERVIEW,PAY PER VIEW
|
||||||
|
COPRODUCTION,CO PRODUCTION
|
||||||
|
NONPARTISAN,NON PARTISAN
|
||||||
|
FIRSTPERSON,FIRST PERSON
|
||||||
|
WORLDRENOWNED,WORLD RENOWNED
|
||||||
|
VICEPRESIDENT,VICE PRESIDENT
|
||||||
|
PROROMAN,PRO ROMAN
|
||||||
|
COPRODUCED,CO PRODUCED
|
||||||
|
LOWPOWER,LOW POWER
|
||||||
|
SELFESTEEM,SELF ESTEEM
|
||||||
|
SEMITRANSPARENT,SEMI TRANSPARENT
|
||||||
|
SECONDINCOMMAND,SECOND IN COMMAND
|
||||||
|
HIGHRISE,HIGH RISE
|
||||||
|
COHOSTED,CO HOSTED
|
||||||
|
AFRICANAMERICAN,AFRICAN AMERICAN
|
||||||
|
SOUTHWEST,SOUTH WEST
|
||||||
|
WELLPRESERVED,WELL PRESERVED
|
||||||
|
FEATURELENGTH,FEATURE LENGTH
|
||||||
|
HIPHOP,HIP HOP
|
||||||
|
ALLBIG,ALL BIG
|
||||||
|
SOUTHEAST,SOUTH EAST
|
||||||
|
COUNTERATTACK,COUNTER ATTACK
|
||||||
|
QUARTERFINALS,QUARTER FINALS
|
||||||
|
STABLEDOOR,STABLE DOOR
|
||||||
|
DARKEYED,DARK EYED
|
||||||
|
ALLAMERICAN,ALL AMERICAN
|
||||||
|
THIRDPERSON,THIRD PERSON
|
||||||
|
LOWLEVEL,LOW LEVEL
|
||||||
|
NTERMINAL,N TERMINAL
|
||||||
|
DRIEDUP,DRIED UP
|
||||||
|
AFRICANAMERICANS,AFRICAN AMERICANS
|
||||||
|
ANTIAPARTHEID,ANTI APARTHEID
|
||||||
|
STOKEONTRENT,STOKE ON TRENT
|
||||||
|
NORTHNORTHEAST,NORTH NORTHEAST
|
||||||
|
BRANDNEW,BRAND NEW
|
||||||
|
RIGHTANGLED,RIGHT ANGLED
|
||||||
|
GOVERNMENTOWNED,GOVERNMENT OWNED
|
||||||
|
SONINLAW,SON IN LAW
|
||||||
|
SUBJECTOBJECTVERB,SUBJECT OBJECT VERB
|
||||||
|
LEFTARM,LEFT ARM
|
||||||
|
LONGLIVED,LONG LIVED
|
||||||
|
REDEYE,RED EYE
|
||||||
|
TPOSE,T POSE
|
||||||
|
NIGHTVISION,NIGHT VISION
|
||||||
|
SOUTHEASTERN,SOUTH EASTERN
|
||||||
|
WELLRECEIVED,WELL RECEIVED
|
||||||
|
ALFAYOUM,AL FAYOUM
|
||||||
|
TIMEBASED,TIME BASED
|
||||||
|
KETTLEDRUMS,KETTLE DRUMS
|
||||||
|
BRIGHTEYED,BRIGHT EYED
|
||||||
|
REDBROWN,RED BROWN
|
||||||
|
SAMESEX,SAME SEX
|
||||||
|
PORTDEPAIX,PORT DE PAIX
|
||||||
|
CLEANUP,CLEAN UP
|
||||||
|
PERCENT,PERCENT SIGN
|
||||||
|
TAKEOUT,TAKE OUT
|
||||||
|
KNOWHOW,KNOW HOW
|
||||||
|
FISHBONE,FISH BONE
|
||||||
|
FISHSTICKS,FISH STICKS
|
||||||
|
PAPERWORK,PAPER WORK
|
||||||
|
NICKNACKS,NICK NACKS
|
||||||
|
STREETTALKING,STREET TALKING
|
||||||
|
NONACADEMIC,NON ACADEMIC
|
||||||
|
SHELLY,SHELLEY
|
||||||
|
SHELLY'S,SHELLEY'S
|
||||||
|
JIMMY,JIMMIE
|
||||||
|
JIMMY'S,JIMMIE'S
|
||||||
|
DRUGSTORE,DRUG STORE
|
||||||
|
THRU,THROUGH
|
||||||
|
PLAYDATE,PLAY DATE
|
||||||
|
MICROLIFE,MICRO LIFE
|
||||||
|
SKILLSET,SKILL SET
|
||||||
|
SKILLSETS,SKILL SETS
|
||||||
|
TRADEOFF,TRADE OFF
|
||||||
|
TRADEOFFS,TRADE OFFS
|
||||||
|
ONSCREEN,ON SCREEN
|
||||||
|
PLAYBACK,PLAY BACK
|
||||||
|
ARTWORK,ART WORK
|
||||||
|
COWORKER,CO WORKER
|
||||||
|
COWORKERS,CO WORKERS
|
||||||
|
SOMETIME,SOME TIME
|
||||||
|
SOMETIMES,SOME TIMES
|
||||||
|
CROWDFUNDING,CROWD FUNDING
|
||||||
|
AM,A.M.,A M
|
||||||
|
PM,P.M.,P M
|
||||||
|
TV,T V
|
||||||
|
MBA,M B A
|
||||||
|
USA,U S A
|
||||||
|
US,U S
|
||||||
|
UK,U K
|
||||||
|
CEO,C E O
|
||||||
|
CFO,C F O
|
||||||
|
COO,C O O
|
||||||
|
CIO,C I O
|
||||||
|
FM,F M
|
||||||
|
GMC,G M C
|
||||||
|
FSC,F S C
|
||||||
|
NPD,N P D
|
||||||
|
APM,A P M
|
||||||
|
NGO,N G O
|
||||||
|
TD,T D
|
||||||
|
LOL,L O L
|
||||||
|
IPO,I P O
|
||||||
|
CNBC,C N B C
|
||||||
|
IPOS,I P OS
|
||||||
|
CNBC'S,C N B C'S
|
||||||
|
JT,J T
|
||||||
|
NPR,N P R
|
||||||
|
NPR'S,N P R'S
|
||||||
|
MP,M P
|
||||||
|
IOI,I O I
|
||||||
|
DW,D W
|
||||||
|
CNN,C N N
|
||||||
|
WSM,W S M
|
||||||
|
ET,E T
|
||||||
|
IT,I T
|
||||||
|
RJ,R J
|
||||||
|
DVD,D V D
|
||||||
|
DVD'S,D V D'S
|
||||||
|
HBO,H B O
|
||||||
|
LA,L A
|
||||||
|
XC,X C
|
||||||
|
SUV,S U V
|
||||||
|
NBA,N B A
|
||||||
|
NBA'S,N B A'S
|
||||||
|
ESPN,E S P N
|
||||||
|
ESPN'S,E S P N'S
|
||||||
|
ADT,A D T
|
||||||
|
HD,H D
|
||||||
|
VIP,V I P
|
||||||
|
TMZ,T M Z
|
||||||
|
CBC,C B C
|
||||||
|
NPO,N P O
|
||||||
|
BBC,B B C
|
||||||
|
LA'S,L A'S
|
||||||
|
TMZ'S,T M Z'S
|
||||||
|
HIV,H I V
|
||||||
|
FTC,F T C
|
||||||
|
EU,E U
|
||||||
|
PHD,P H D
|
||||||
|
AI,A I
|
||||||
|
FHI,F H I
|
||||||
|
ICML,I C M L
|
||||||
|
ICLR,I C L R
|
||||||
|
BMW,B M W
|
||||||
|
EV,E V
|
||||||
|
CR,C R
|
||||||
|
API,A P I
|
||||||
|
ICO,I C O
|
||||||
|
LTE,L T E
|
||||||
|
OBS,O B S
|
||||||
|
PC,P C
|
||||||
|
IO,I O
|
||||||
|
CRM,C R M
|
||||||
|
RTMP,R T M P
|
||||||
|
ASMR,A S M R
|
||||||
|
GG,G G
|
||||||
|
WWW,W W W
|
||||||
|
PEI,P E I
|
||||||
|
JJ,J J
|
||||||
|
PT,P T
|
||||||
|
DJ,D J
|
||||||
|
SD,S D
|
||||||
|
POW,P.O.W.,P O W
|
||||||
|
FYI,F Y I
|
||||||
|
DC,D C,D.C
|
||||||
|
ABC,A B C
|
||||||
|
TJ,T J
|
||||||
|
WMDT,W M D T
|
||||||
|
WDTN,W D T N
|
||||||
|
TY,T Y
|
||||||
|
EJ,E J
|
||||||
|
CJ,C J
|
||||||
|
ACL,A C L
|
||||||
|
UK'S,U K'S
|
||||||
|
GTV,G T V
|
||||||
|
MDMA,M D M A
|
||||||
|
DFW,D F W
|
||||||
|
WTF,W T F
|
||||||
|
AJ,A J
|
||||||
|
MD,M D
|
||||||
|
PH,P H
|
||||||
|
ID,I D
|
||||||
|
SEO,S E O
|
||||||
|
UTM'S,U T M'S
|
||||||
|
EC,E C
|
||||||
|
UFC,U F C
|
||||||
|
RV,R V
|
||||||
|
UTM,U T M
|
||||||
|
CSV,C S V
|
||||||
|
SMS,S M S
|
||||||
|
GRB,G R B
|
||||||
|
GT,G T
|
||||||
|
LEM,L E M
|
||||||
|
XR,X R
|
||||||
|
EDU,E D U
|
||||||
|
NBC,N B C
|
||||||
|
EMS,E M S
|
||||||
|
CDC,C D C
|
||||||
|
MLK,M L K
|
||||||
|
IE,I E
|
||||||
|
OC,O C
|
||||||
|
HR,H R
|
||||||
|
MA,M A
|
||||||
|
DEE,D E E
|
||||||
|
AP,A P
|
||||||
|
UFO,U F O
|
||||||
|
DE,D E
|
||||||
|
LGBTQ,L G B T Q
|
||||||
|
PTA,P T A
|
||||||
|
NHS,N H S
|
||||||
|
CMA,C M A
|
||||||
|
MGM,M G M
|
||||||
|
AKA,A K A
|
||||||
|
HW,H W
|
||||||
|
GOP,G O P
|
||||||
|
GOP'S,G O P'S
|
||||||
|
FBI,F B I
|
||||||
|
PRX,P R X
|
||||||
|
CTO,C T O
|
||||||
|
URL,U R L
|
||||||
|
EIN,E I N
|
||||||
|
MLS,M L S
|
||||||
|
CSI,C S I
|
||||||
|
AOC,A O C
|
||||||
|
CND,C N D
|
||||||
|
CP,C P
|
||||||
|
PP,P P
|
||||||
|
CLI,C L I
|
||||||
|
PB,P B
|
||||||
|
FDA,F D A
|
||||||
|
MRNA,M R N A
|
||||||
|
PR,P R
|
||||||
|
VP,V P
|
||||||
|
DNC,D N C
|
||||||
|
MSNBC,M S N B C
|
||||||
|
GQ,G Q
|
||||||
|
UT,U T
|
||||||
|
XXI,X X I
|
||||||
|
HRV,H R V
|
||||||
|
WHO,W H O
|
||||||
|
CRO,C R O
|
||||||
|
DPA,D P A
|
||||||
|
PPE,P P E
|
||||||
|
EVA,E V A
|
||||||
|
BP,B P
|
||||||
|
GPS,G P S
|
||||||
|
AR,A R
|
||||||
|
PJ,P J
|
||||||
|
MLM,M L M
|
||||||
|
OLED,O L E D
|
||||||
|
BO,B O
|
||||||
|
VE,V E
|
||||||
|
UN,U N
|
||||||
|
SLS,S L S
|
||||||
|
DM,D M
|
||||||
|
DM'S,D M'S
|
||||||
|
ASAP,A S A P
|
||||||
|
ETA,E T A
|
||||||
|
DOB,D O B
|
||||||
|
BMW,B M W
|
||||||
|
20
utils/speechio/interjections_en.csv
Normal file
20
utils/speechio/interjections_en.csv
Normal file
@@ -0,0 +1,20 @@
|
|||||||
|
ach
|
||||||
|
ah
|
||||||
|
eee
|
||||||
|
eh
|
||||||
|
er
|
||||||
|
ew
|
||||||
|
ha
|
||||||
|
hee
|
||||||
|
hm
|
||||||
|
hmm
|
||||||
|
hmmm
|
||||||
|
huh
|
||||||
|
mm
|
||||||
|
mmm
|
||||||
|
oof
|
||||||
|
uh
|
||||||
|
uhh
|
||||||
|
um
|
||||||
|
oh
|
||||||
|
hum
|
||||||
|
1
utils/speechio/nemo_text_processing/README.md
Normal file
1
utils/speechio/nemo_text_processing/README.md
Normal file
@@ -0,0 +1 @@
|
|||||||
|
nemo_version from commit:eae1684f7f33c2a18de9ecfa42ec7db93d39e631
|
||||||
13
utils/speechio/nemo_text_processing/__init__.py
Normal file
13
utils/speechio/nemo_text_processing/__init__.py
Normal file
@@ -0,0 +1,13 @@
|
|||||||
|
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
@@ -0,0 +1,10 @@
|
|||||||
|
# Text Normalization
|
||||||
|
|
||||||
|
Text Normalization is part of NeMo's `nemo_text_processing` - a Python package that is installed with the `nemo_toolkit`.
|
||||||
|
It converts text from written form into its verbalized form, e.g. "123" -> "one hundred twenty three".
|
||||||
|
|
||||||
|
See [NeMo documentation](https://docs.nvidia.com/deeplearning/nemo/user-guide/docs/en/stable/nlp/text_normalization/wfst/wfst_text_normalization.html) for details.
|
||||||
|
|
||||||
|
Tutorial with overview of the package capabilities: [Text_(Inverse)_Normalization.ipynb](https://colab.research.google.com/github/NVIDIA/NeMo/blob/stable/tutorials/text_processing/Text_(Inverse)_Normalization.ipynb)
|
||||||
|
|
||||||
|
Tutorial on how to customize the underlying grammars: [WFST_Tutorial.ipynb](https://colab.research.google.com/github/NVIDIA/NeMo/blob/stable/tutorials/text_processing/WFST_Tutorial.ipynb)
|
||||||
@@ -0,0 +1,13 @@
|
|||||||
|
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
@@ -0,0 +1,350 @@
|
|||||||
|
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
|
||||||
|
import json
|
||||||
|
import re
|
||||||
|
import string
|
||||||
|
from collections import defaultdict, namedtuple
|
||||||
|
from typing import Dict, List, Optional, Set, Tuple
|
||||||
|
from unicodedata import category
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
EOS_TYPE = "EOS"
|
||||||
|
PUNCT_TYPE = "PUNCT"
|
||||||
|
PLAIN_TYPE = "PLAIN"
|
||||||
|
Instance = namedtuple('Instance', 'token_type un_normalized normalized')
|
||||||
|
known_types = [
|
||||||
|
"PLAIN",
|
||||||
|
"DATE",
|
||||||
|
"CARDINAL",
|
||||||
|
"LETTERS",
|
||||||
|
"VERBATIM",
|
||||||
|
"MEASURE",
|
||||||
|
"DECIMAL",
|
||||||
|
"ORDINAL",
|
||||||
|
"DIGIT",
|
||||||
|
"MONEY",
|
||||||
|
"TELEPHONE",
|
||||||
|
"ELECTRONIC",
|
||||||
|
"FRACTION",
|
||||||
|
"TIME",
|
||||||
|
"ADDRESS",
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
def _load_kaggle_text_norm_file(file_path: str) -> List[Instance]:
    """
    Read one file in the Kaggle Google text normalization format into a flat instance list.

    Format reference:
    https://www.kaggle.com/richardwilliamsproat/text-normalization-for-english-russian-and-polish

    Every line holds three tab-separated fields:
    <semiotic class> <TAB> <unnormalized text> <TAB> <`self` if trivial class or normalized text>
    and a sentence ends with a line whose first field is "<eos>", e.g.
        PLAIN   Brillantaisia   <self>
        PLAIN   is              <self>
        PUNCT   .               sil
        <eos>   <eos>

    Args:
        file_path: file path to text file

    Returns: flat list of instances. Tokens are lower-cased, PLAIN tokens are
        normalized to themselves, PUNCT tokens are dropped and every "<eos>" line
        becomes an empty EOS instance.
    """
    instances = []
    with open(file_path, 'r') as fp:
        for raw_line in fp:
            fields = raw_line.strip().split("\t")
            if fields[0] == "<eos>":
                instances.append(Instance(token_type=EOS_TYPE, un_normalized="", normalized=""))
                continue

            # any non-EOS line must carry exactly three fields
            semiotic_class, written, spoken = fields
            written = written.lower()
            spoken = spoken.lower()

            if semiotic_class == PLAIN_TYPE:
                # plain tokens are their own normalization
                instances.append(Instance(token_type=semiotic_class, un_normalized=written, normalized=written))
            elif semiotic_class != PUNCT_TYPE:
                instances.append(Instance(token_type=semiotic_class, un_normalized=written, normalized=spoken))
    return instances
|
||||||
|
|
||||||
|
|
||||||
|
def load_files(file_paths: List[str], load_func=_load_kaggle_text_norm_file) -> List[Instance]:
    """
    Load every file in `file_paths` with `load_func` and concatenate the results.

    Args:
        file_paths: list of file paths
        load_func: loading function, called as load_func(file_path=...)

    Returns: flat list of instances, in the order of `file_paths`
    """
    return [instance for path in file_paths for instance in load_func(file_path=path)]
|
||||||
|
|
||||||
|
|
||||||
|
def clean_generic(text: str) -> str:
    """
    Generic clean-up that does not touch semiotic classes: trims surrounding
    whitespace and lower-cases the text.

    Args:
        text: string

    Returns: cleaned string
    """
    return text.strip().lower()
|
||||||
|
|
||||||
|
|
||||||
|
def evaluate(preds: List[str], labels: List[str], input: Optional[List[str]] = None, verbose: bool = True) -> float:
    """
    Evaluates accuracy given predictions and labels.

    Predictions and labels are compared after `clean_generic` (strip + lower-case).

    Args:
        preds: predictions
        labels: labels, aligned with `preds`
        input: optional, only needed for verbosity
        verbose: if true prints [input], golden labels and predictions for every mismatch

    Returns accuracy in [0, 1]; 0.0 when `preds` is empty
    """
    total = len(preds)
    if total == 0:
        # nothing to score; avoid ZeroDivisionError
        return 0.0

    correct = 0
    for i in range(total):
        pred_norm = clean_generic(preds[i])
        label_norm = clean_generic(labels[i])
        if pred_norm == label_norm:
            correct += 1
        elif verbose:
            # mismatches are reported only when the caller asked for it
            if input:
                print(f"inpu: {json.dumps(input[i])}")
            print(f"gold: {json.dumps(label_norm)}")
            print(f"pred: {json.dumps(pred_norm)}")
    return correct / total
|
||||||
|
|
||||||
|
|
||||||
|
def training_data_to_tokens(
    data: List[Instance], category: Optional[str] = None
) -> Dict[str, Tuple[List[str], List[str]]]:
    """
    Group instances by token type, optionally keeping a single semiotic class only.

    Args:
        data: list of instances
        category: optional semiotic class category name; when given, other classes are skipped

    Returns Dict: token type -> (list of un_normalized strings, list of normalized strings).
        EOS instances are never included.
    """
    grouped = defaultdict(lambda: ([], []))
    for item in data:
        if item.token_type == EOS_TYPE:
            continue
        if category is not None and item.token_type != category:
            continue
        written_list, spoken_list = grouped[item.token_type]
        written_list.append(item.un_normalized)
        spoken_list.append(item.normalized)
    return grouped
|
||||||
|
|
||||||
|
|
||||||
|
def training_data_to_sentences(data: List[Instance]) -> Tuple[List[str], List[str], List[Set[str]]]:
    """
    Split the instance stream into sentences at every EOS instance.

    Args:
        data: list of instances

    Returns (list of unnormalized sentences, list of normalized sentences, list of sets of categories in a sentence).
        Tokens are joined with single spaces. Instances after the last EOS are not emitted.
    """
    sentences, categories = [], []
    current_tokens, current_types = [], set()

    for item in data:
        if item.token_type != EOS_TYPE:
            current_tokens.append(item)
            current_types.add(item.token_type)
            continue
        # EOS reached: close the running sentence and start a new one
        sentences.append(current_tokens)
        categories.append(current_types)
        current_tokens, current_types = [], set()

    un_normalized = [" ".join(token.un_normalized for token in tokens) for tokens in sentences]
    normalized = [" ".join(token.normalized for token in tokens) for tokens in sentences]
    return un_normalized, normalized, categories
|
||||||
|
|
||||||
|
|
||||||
|
def post_process_punctuation(text: str) -> str:
    """
    Normalizes quotes and spaces.

    Removes the space just inside brackets, collapses double spaces, maps
    typographic quotes to ASCII quotes, drops backslashes, joins "- -" into "--"
    and removes the space in front of "!,.:;?".

    Args:
        text: text

    Returns: text with normalized spaces and quotes, stripped of surrounding whitespace
    """
    # Applied strictly in this order: later rules see the output of earlier ones.
    replacements = (
        ('( ', '('),
        (' )', ')'),
        ('{ ', '{'),
        (' }', '}'),
        ('[ ', '['),
        (' ]', ']'),
        ('  ', ' '),  # collapse a double space into a single one
        ('”', '"'),
        ("’", "'"),
        ("»", '"'),
        ("«", '"'),
        ("\\", ""),
        ("„", '"'),
        ("´", "'"),
        ('“', '"'),
        ("‘", "'"),
        ('`', "'"),
        ('- -', "--"),
    )
    for old, new in replacements:
        text = text.replace(old, new)

    # no space before sentence punctuation
    for punct in "!,.:;?":
        text = text.replace(f' {punct}', punct)
    return text.strip()
|
||||||
|
|
||||||
|
|
||||||
|
def pre_process(text: str) -> str:
    """
    Optional text preprocessing before normalization (part of TTS TN pipeline).

    Puts a space on both sides of every square bracket, then collapses runs of
    spaces into one.

    Args:
        text: string that may include semiotic classes

    Returns: text with spaces around the "[" and "]" punctuation marks
    """
    for bracket in ('[', ']'):
        text = text.replace(bracket, f' {bracket} ')

    # remove extra space
    return re.sub(r' +', ' ', text)
|
||||||
|
|
||||||
|
|
||||||
|
def load_file(file_path: str) -> List[str]:
    """
    Read a text file into a list with one entry per line.

    Args:
        file_path: file path

    Returns: flat list of string; lines are returned as read, i.e. line endings are kept
    """
    with open(file_path, 'r') as fp:
        return list(fp)
|
||||||
|
|
||||||
|
|
||||||
|
def write_file(file_path: str, data: List[str]):
    """
    Write each string of `data` to `file_path`, one per line.

    Args:
        file_path: file path (overwritten if it exists)
        data: list of string; a newline is appended to every entry
    """
    with open(file_path, 'w') as fp:
        fp.writelines(line + '\n' for line in data)
|
||||||
|
|
||||||
|
|
||||||
|
def post_process_punct(input: str, normalized_text: str, add_unicode_punct: bool = False):
    """
    Post-processing of the normalized output to match input in terms of spaces around punctuation marks.
    After NN normalization, Moses detokenization puts a space after
    punctuation marks, and attaches an opening quote "'" to the word to the right.
    E.g., input to the TN NN model is "12 test' example",
    after normalization and detokenization -> "twelve test 'example" (the quote is considered to be an opening quote,
    but it doesn't match the input and can cause issues during TTS voice generation.)
    The current function will match the punctuation and spaces of the normalized text with the input sequence.
    "12 test' example" -> "twelve test 'example" -> "twelve test' example" (the quote was shifted to match the input).

    Args:
        input: input text (original input to the NN, before normalization or tokenization)
        normalized_text: output text (output of the TN NN model)
        add_unicode_punct: set to True to handle unicode punctuation marks as well as default string.punctuation

    Returns: `normalized_text` with the spacing around punctuation aligned to `input`
        and runs of spaces collapsed. Alignment is best effort: a mark that cannot be
        aligned is left as is.
    """
    # in the post-processing WFST graph "``" are replaced with '"' quotes (otherwise single quotes "`" won't be
    # handled correctly); this function fixes spaces around them based on the input sequence, so the same double
    # quote replacement is made here to make sure these new double quotes work with this function
    if "``" in input and "``" not in normalized_text:
        input = input.replace("``", '"')

    # work on lists of characters so that single positions can be blanked or extended in place
    input = list(input)
    normalized_text = list(normalized_text)
    punct_marks = [x for x in string.punctuation if x in input]

    if add_unicode_punct:
        # Only characters that occur in the input can qualify, so inspect those
        # rather than scanning the whole Unicode range.
        punct_unicode = sorted(
            ch for ch in set(input) if category(ch).startswith("P") and ch not in string.punctuation
        )
        punct_marks.extend(punct_unicode)

    def _is_valid(idx_out, idx_in):
        """Check if previous or next word match (for cases when punctuation marks are part of
        semiotic token, i.e. some punctuation can be missing in the normalized text)"""
        return (idx_out > 0 and idx_in > 0 and normalized_text[idx_out - 1] == input[idx_in - 1]) or (
            idx_out < len(normalized_text) - 1
            and idx_in < len(input) - 1
            and normalized_text[idx_out + 1] == input[idx_in + 1]
        )

    for punct in punct_marks:
        try:
            equal = input.count(punct) == normalized_text.count(punct)
            idx_in, idx_out = 0, 0
            while punct in input[idx_in:]:
                idx_out = normalized_text.index(punct, idx_out)
                idx_in = input.index(punct, idx_in)

                if not equal and not _is_valid(idx_out, idx_in):
                    # counts differ and the neighbours do not match: skip this input occurrence
                    idx_in += 1
                    continue

                # align the character in front of the punctuation mark
                if idx_in > 0 and idx_out > 0:
                    if normalized_text[idx_out - 1] == " " and input[idx_in - 1] != " ":
                        normalized_text[idx_out - 1] = ""
                    elif normalized_text[idx_out - 1] != " " and input[idx_in - 1] == " ":
                        normalized_text[idx_out - 1] += " "

                # align the character behind the punctuation mark
                if idx_in < len(input) - 1 and idx_out < len(normalized_text) - 1:
                    if normalized_text[idx_out + 1] == " " and input[idx_in + 1] != " ":
                        normalized_text[idx_out + 1] = ""
                    elif normalized_text[idx_out + 1] != " " and input[idx_in + 1] == " ":
                        normalized_text[idx_out] = normalized_text[idx_out] + " "
                idx_out += 1
                idx_in += 1
        except Exception:
            # Best effort: e.g. list.index raises ValueError when the mark is missing from
            # the normalized text. Leave this mark untouched and move on to the next one.
            pass

    normalized_text = "".join(normalized_text)
    return re.sub(r' +', ' ', normalized_text)
|
||||||
@@ -0,0 +1,17 @@
|
|||||||
|
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
from nemo_text_processing.text_normalization.en.taggers.tokenize_and_classify import ClassifyFst
|
||||||
|
from nemo_text_processing.text_normalization.en.verbalizers.verbalize import VerbalizeFst
|
||||||
|
from nemo_text_processing.text_normalization.en.verbalizers.verbalize_final import VerbalizeFinalFst
|
||||||
@@ -0,0 +1,342 @@
|
|||||||
|
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
from argparse import ArgumentParser
|
||||||
|
from typing import List
|
||||||
|
|
||||||
|
import regex as re
|
||||||
|
from nemo_text_processing.text_normalization.data_loader_utils import (
|
||||||
|
EOS_TYPE,
|
||||||
|
Instance,
|
||||||
|
load_files,
|
||||||
|
training_data_to_sentences,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
"""
|
||||||
|
This file is for evaluation purposes.
|
||||||
|
filter_loaded_data() cleans data (list of instances) for text normalization. Filters and cleaners can be specified for each semiotic class individually.
|
||||||
|
For example, normalized text should only include characters and whitespace characters but no punctuation.
|
||||||
|
Cardinal unnormalized instances should contain at least one integer and all other characters are removed.
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
class Filter:
    """
    Pairs a semiotic class with a keep/drop predicate and a clean-up transformation.

    Args:
        class_type: semiotic class used in dataset
        process_func: function to transform text
        filter_func: function to filter text
    """

    def __init__(self, class_type: str, process_func: object, filter_func: object):
        self.class_type, self.process_func, self.filter_func = class_type, process_func, filter_func

    def filter(self, instance: Instance) -> bool:
        """
        Apply the filter function to an instance of this filter's class.

        Args:
            instance: instance to check

        Returns: True if given instance fulfills criteria or does not belong to class type
        """
        if instance.token_type == self.class_type:
            return self.filter_func(instance)
        return True

    def process(self, instance: Instance) -> Instance:
        """
        Apply the process function to an instance of this filter's class.

        Args:
            instance: instance to transform

        Returns: processed instance if instance belongs to expected class type or original instance
        """
        if instance.token_type == self.class_type:
            return self.process_func(instance)
        return instance
|
||||||
|
|
||||||
|
|
||||||
|
def filter_cardinal_1(instance: Instance) -> bool:
    """Keep CARDINAL instances whose written form has a digit; returns the match object (truthy) or None."""
    return re.search(r"[0-9]", instance.un_normalized)
|
||||||
|
|
||||||
|
|
||||||
|
def process_cardinal_1(instance: Instance) -> Instance:
    """Reduce the written form to digits and the spoken form to lowercase letters and spaces."""
    digits_only = re.sub(r"[^0-9]", "", instance.un_normalized)
    letters_only = re.sub(r"[^a-z ]", "", instance.normalized)
    return Instance(token_type=instance.token_type, un_normalized=digits_only, normalized=letters_only)
|
||||||
|
|
||||||
|
|
||||||
|
def filter_ordinal_1(instance: Instance) -> bool:
    """Keep ORDINAL instances whose written form ends in st/nd/rd/th; returns the match object or None."""
    return re.search(r"(st|nd|rd|th)\s*$", instance.un_normalized)
|
||||||
|
|
||||||
|
|
||||||
|
def process_ordinal_1(instance: Instance) -> Instance:
    """Drop commas and whitespace from the written form; keep only lowercase letters and spaces in the spoken form."""
    written = re.sub(r"[,\s]", "", instance.un_normalized)
    spoken = re.sub(r"[^a-z ]", "", instance.normalized)
    return Instance(token_type=instance.token_type, un_normalized=written, normalized=spoken)
|
||||||
|
|
||||||
|
|
||||||
|
def filter_decimal_1(instance: Instance) -> bool:
    """Keep DECIMAL instances whose written form has a digit; returns the match object (truthy) or None."""
    return re.search(r"[0-9]", instance.un_normalized)
|
||||||
|
|
||||||
|
|
||||||
|
def process_decimal_1(instance: Instance) -> Instance:
    """Remove commas from the written form; keep only lowercase letters and spaces in the spoken form."""
    written = re.sub(r",", "", instance.un_normalized)
    spoken = re.sub(r"[^a-z ]", "", instance.normalized)
    return Instance(token_type=instance.token_type, un_normalized=written, normalized=spoken)
|
||||||
|
|
||||||
|
|
||||||
|
def filter_measure_1(instance: Instance) -> bool:
    """Accept every MEASURE instance."""
    return True
|
||||||
|
|
||||||
|
|
||||||
|
def process_measure_1(instance: Instance) -> Instance:
    """
    Clean a MEASURE instance.

    Written form: commas removed, "m2" rewritten as "m²", and a space inserted
    between a digit and a following character that is not a digit, dot or whitespace.
    Spoken form: reduced to lowercase letters and spaces, with the final "s" of a
    trailing "per <unit>s" removed.
    """
    written = instance.un_normalized
    for pattern, replacement in ((r",", ""), (r"m2", "m²"), (r"(\d)([^\d.\s])", r"\1 \2")):
        written = re.sub(pattern, replacement, written)

    spoken = instance.normalized
    for pattern, replacement in ((r"[^a-z\s]", ""), (r"per ([a-z\s]*)s$", r"per \1"), (r"[^a-z ]", "")):
        spoken = re.sub(pattern, replacement, spoken)

    return Instance(token_type=instance.token_type, un_normalized=written, normalized=spoken)
|
||||||
|
|
||||||
|
|
||||||
|
def filter_money_1(instance: Instance) -> bool:
    """Keep MONEY instances whose written form has a digit; returns the match object (truthy) or None."""
    return re.search(r"[0-9]", instance.un_normalized)
|
||||||
|
|
||||||
|
|
||||||
|
def process_money_1(instance: Instance) -> Instance:
    """
    Clean a MONEY instance.

    Written form: commas removed, "a$" and "us$" shortened to "$", and a trailing
    "m" / "b" / "bn" after a digit expanded to " million" / " billion".
    Spoken form: reduced to lowercase letters and spaces.
    """
    rewrites = (
        (r",", ""),
        (r"a\$", r"$"),
        (r"us\$", r"$"),
        (r"(\d)m\s*$", r"\1 million"),
        (r"(\d)bn?\s*$", r"\1 billion"),
    )
    written = instance.un_normalized
    for pattern, replacement in rewrites:
        written = re.sub(pattern, replacement, written)
    spoken = re.sub(r"[^a-z ]", "", instance.normalized)
    return Instance(token_type=instance.token_type, un_normalized=written, normalized=spoken)
|
||||||
|
|
||||||
|
|
||||||
|
def filter_time_1(instance: Instance) -> bool:
    """Keep TIME instances whose written form has a digit; returns the match object (truthy) or None."""
    return re.search(r"[0-9]", instance.un_normalized)
|
||||||
|
|
||||||
|
|
||||||
|
def process_time_1(instance: Instance) -> Instance:
    """
    Clean a TIME instance.

    Written form: ": " tightened to ":", and "am" / "pm" after a digit (with
    optional single whitespace characters around the letters) rewritten as
    " a.m." / " p.m.".
    Spoken form: reduced to lowercase letters and spaces.
    """
    written = instance.un_normalized
    for pattern, replacement in (
        (r": ", ":"),
        (r"(\d)\s?a\s?m\s?", r"\1 a.m."),
        (r"(\d)\s?p\s?m\s?", r"\1 p.m."),
    ):
        written = re.sub(pattern, replacement, written)
    spoken = re.sub(r"[^a-z ]", "", instance.normalized)
    return Instance(token_type=instance.token_type, un_normalized=written, normalized=spoken)
|
||||||
|
|
||||||
|
|
||||||
|
def filter_plain_1(instance: Instance) -> bool:
    """Accept every PLAIN instance."""
    return True
|
||||||
|
|
||||||
|
|
||||||
|
def process_plain_1(instance: Instance) -> Instance:
    """Return a copy of the PLAIN instance with all three fields unchanged."""
    return Instance(
        token_type=instance.token_type,
        un_normalized=instance.un_normalized,
        normalized=instance.normalized,
    )
|
||||||
|
|
||||||
|
|
||||||
|
def filter_punct_1(instance: Instance) -> bool:
    """Accept every PUNCT instance."""
    return True
|
||||||
|
|
||||||
|
|
||||||
|
def process_punct_1(instance: Instance) -> Instance:
    """Return a copy of the PUNCT instance with all three fields unchanged."""
    return Instance(
        token_type=instance.token_type,
        un_normalized=instance.un_normalized,
        normalized=instance.normalized,
    )
|
||||||
|
|
||||||
|
|
||||||
|
def filter_date_1(instance: Instance) -> bool:
    """Accept every DATE instance."""
    return True
|
||||||
|
|
||||||
|
|
||||||
|
def process_date_1(instance: Instance) -> Instance:
    """Remove commas from the written form; keep only lowercase letters and spaces in the spoken form."""
    written = re.sub(r",", "", instance.un_normalized)
    spoken = re.sub(r"[^a-z ]", "", instance.normalized)
    return Instance(token_type=instance.token_type, un_normalized=written, normalized=spoken)
|
||||||
|
|
||||||
|
|
||||||
|
def filter_letters_1(instance: Instance) -> bool:
    """Accept every LETTERS instance."""
    return True
|
||||||
|
|
||||||
|
|
||||||
|
def process_letters_1(instance: Instance) -> Instance:
    """Leave the written form untouched; keep only lowercase letters and spaces in the spoken form."""
    spoken = re.sub(r"[^a-z ]", "", instance.normalized)
    return Instance(token_type=instance.token_type, un_normalized=instance.un_normalized, normalized=spoken)
|
||||||
|
|
||||||
|
|
||||||
|
def filter_verbatim_1(instance: Instance) -> bool:
    """Accept every VERBATIM instance."""
    return True
|
||||||
|
|
||||||
|
|
||||||
|
def process_verbatim_1(instance: Instance) -> Instance:
    """Return a copy of the VERBATIM instance with all three fields unchanged."""
    return Instance(
        token_type=instance.token_type,
        un_normalized=instance.un_normalized,
        normalized=instance.normalized,
    )
|
||||||
|
|
||||||
|
|
||||||
|
def filter_digit_1(instance: Instance) -> bool:
    """Keep DIGIT instances whose written form has a digit; returns the match object (truthy) or None."""
    return re.search(r"[0-9]", instance.un_normalized)
|
||||||
|
|
||||||
|
|
||||||
|
def process_digit_1(instance: Instance) -> Instance:
    """Leave the written form untouched; keep only lowercase letters and spaces in the spoken form."""
    spoken = re.sub(r"[^a-z ]", "", instance.normalized)
    return Instance(token_type=instance.token_type, un_normalized=instance.un_normalized, normalized=spoken)
|
||||||
|
|
||||||
|
|
||||||
|
def filter_telephone_1(instance: Instance) -> bool:
    """Keep TELEPHONE instances whose written form has a digit; returns the match object (truthy) or None."""
    return re.search(r"[0-9]", instance.un_normalized)
|
||||||
|
|
||||||
|
|
||||||
|
def process_telephone_1(instance: Instance) -> Instance:
    """Leave the written form untouched; keep only lowercase letters and spaces in the spoken form."""
    spoken = re.sub(r"[^a-z ]", "", instance.normalized)
    return Instance(token_type=instance.token_type, un_normalized=instance.un_normalized, normalized=spoken)
|
||||||
|
|
||||||
|
|
||||||
|
def filter_electronic_1(instance: Instance) -> bool:
    """Keep ELECTRONIC instances whose written form has a digit; returns the match object (truthy) or None."""
    # NOTE(review): instances without any digit are rejected by this rule; confirm that is intended
    return re.search(r"[0-9]", instance.un_normalized)
|
||||||
|
|
||||||
|
|
||||||
|
def process_electronic_1(instance: Instance) -> Instance:
    """Leave the written form untouched; keep only lowercase letters and spaces in the spoken form."""
    spoken = re.sub(r"[^a-z ]", "", instance.normalized)
    return Instance(token_type=instance.token_type, un_normalized=instance.un_normalized, normalized=spoken)
|
||||||
|
|
||||||
|
|
||||||
|
def filter_fraction_1(instance: Instance) -> bool:
    """Keep FRACTION instances whose written form has a digit; returns the match object (truthy) or None."""
    return re.search(r"[0-9]", instance.un_normalized)
|
||||||
|
|
||||||
|
|
||||||
|
def process_fraction_1(instance: Instance) -> Instance:
    """Leave the written form untouched; keep only lowercase letters and spaces in the spoken form."""
    spoken = re.sub(r"[^a-z ]", "", instance.normalized)
    return Instance(token_type=instance.token_type, un_normalized=instance.un_normalized, normalized=spoken)
|
||||||
|
|
||||||
|
|
||||||
|
def filter_address_1(instance: Instance) -> bool:
    """Accept every ADDRESS instance."""
    return True
|
||||||
|
|
||||||
|
|
||||||
|
def process_address_1(instance: Instance) -> Instance:
    """Leave the written form untouched; keep only lowercase letters and spaces in the spoken form."""
    spoken = re.sub(r"[^a-z ]", "", instance.normalized)
    return Instance(token_type=instance.token_type, un_normalized=instance.un_normalized, normalized=spoken)
|
||||||
|
|
||||||
|
|
||||||
|
# Registry of per-class filters used by filter_loaded_data(): one entry per semiotic
# class, plus a pass-through entry for EOS so that sentence boundaries are kept.
filters = [
    Filter(class_type="CARDINAL", process_func=process_cardinal_1, filter_func=filter_cardinal_1),
    Filter(class_type="ORDINAL", process_func=process_ordinal_1, filter_func=filter_ordinal_1),
    Filter(class_type="DECIMAL", process_func=process_decimal_1, filter_func=filter_decimal_1),
    Filter(class_type="MEASURE", process_func=process_measure_1, filter_func=filter_measure_1),
    Filter(class_type="MONEY", process_func=process_money_1, filter_func=filter_money_1),
    Filter(class_type="TIME", process_func=process_time_1, filter_func=filter_time_1),
    Filter(class_type="DATE", process_func=process_date_1, filter_func=filter_date_1),
    Filter(class_type="PLAIN", process_func=process_plain_1, filter_func=filter_plain_1),
    Filter(class_type="PUNCT", process_func=process_punct_1, filter_func=filter_punct_1),
    Filter(class_type="LETTERS", process_func=process_letters_1, filter_func=filter_letters_1),
    Filter(class_type="VERBATIM", process_func=process_verbatim_1, filter_func=filter_verbatim_1),
    Filter(class_type="DIGIT", process_func=process_digit_1, filter_func=filter_digit_1),
    Filter(class_type="TELEPHONE", process_func=process_telephone_1, filter_func=filter_telephone_1),
    Filter(class_type="ELECTRONIC", process_func=process_electronic_1, filter_func=filter_electronic_1),
    Filter(class_type="FRACTION", process_func=process_fraction_1, filter_func=filter_fraction_1),
    Filter(class_type="ADDRESS", process_func=process_address_1, filter_func=filter_address_1),
    Filter(class_type=EOS_TYPE, process_func=lambda x: x, filter_func=lambda x: True),
]
|
||||||
|
|
||||||
|
|
||||||
|
def filter_loaded_data(data: List[Instance], verbose: bool = False) -> List[Instance]:
    """
    Runs every instance through the module-level ``filters`` and keeps the ones that matched

    An instance is kept when at least one filter has the same class type as the
    instance's ``token_type`` and that filter accepts it. Every matching filter
    is applied in turn, each one receiving the output of the previous one.

    Args:
        data: list of instances
        verbose: when set, prints each kept instance after processing

    Returns: list of the kept instances in their processed form, in input order
    """
    kept = []
    for current in data:
        matched = False
        for candidate in filters:
            same_class = candidate.class_type == current.token_type
            if same_class and candidate.filter(current):
                # Later filters see the already-processed instance.
                current = candidate.process(current)
                matched = True
        if not matched:
            continue
        if verbose:
            print(current)
        kept.append(current)
    return kept
|
||||||
|
|
||||||
|
|
||||||
|
def parse_args():
    """
    Builds the command line parser and parses the process arguments

    Returns: namespace with ``input`` (path of the file to load) and ``verbose`` (bool flag)
    """
    cli = ArgumentParser()
    # (flag, keyword settings) pairs, registered in order
    options = (
        ("--input", dict(help="input file path", type=str, default='./en_with_types/output-00001-of-00100')),
        ("--verbose", dict(help="print filtered instances", action='store_true')),
    )
    for flag, settings in options:
        cli.add_argument(flag, **settings)
    return cli.parse_args()
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
    # Script entry point: load the file given by --input, run the filters over
    # it, and pass the surviving instances to training_data_to_sentences.
    args = parse_args()

    print("Loading training data: " + args.input)
    # load_files takes a list of paths; only the single --input path is passed
    training_data_to_sentences(filter_loaded_data(load_files([args.input]), args.verbose))
|
||||||
@@ -0,0 +1,13 @@
|
|||||||
|
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
@@ -0,0 +1,13 @@
|
|||||||
|
# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
@@ -0,0 +1,14 @@
|
|||||||
|
st Street
|
||||||
|
street Street
|
||||||
|
expy Expressway
|
||||||
|
fwy Freeway
|
||||||
|
hwy Highway
|
||||||
|
dr Drive
|
||||||
|
ct Court
|
||||||
|
ave Avenue
|
||||||
|
av Avenue
|
||||||
|
cir Circle
|
||||||
|
blvd Boulevard
|
||||||
|
alley Alley
|
||||||
|
way Way
|
||||||
|
jct Junction
|
||||||
|
@@ -0,0 +1,52 @@
|
|||||||
|
Alabama AL
|
||||||
|
Alaska AK
|
||||||
|
Arizona AZ
|
||||||
|
Arkansas AR
|
||||||
|
California CA
|
||||||
|
Colorado CO
|
||||||
|
Connecticut CT
|
||||||
|
Delaware DE
|
||||||
|
Florida FL
|
||||||
|
Georgia GA
|
||||||
|
Hawaii HI
|
||||||
|
Idaho ID
|
||||||
|
Illinois IL
|
||||||
|
Indiana IN
|
||||||
|
Indiana IND
|
||||||
|
Iowa IA
|
||||||
|
Kansas KS
|
||||||
|
Kentucky KY
|
||||||
|
Louisiana LA
|
||||||
|
Maine ME
|
||||||
|
Maryland MD
|
||||||
|
Massachusetts MA
|
||||||
|
Michigan MI
|
||||||
|
Minnesota MN
|
||||||
|
Mississippi MS
|
||||||
|
Missouri MO
|
||||||
|
Montana MT
|
||||||
|
Nebraska NE
|
||||||
|
Nevada NV
|
||||||
|
New Hampshire NH
|
||||||
|
New Jersey NJ
|
||||||
|
New Mexico NM
|
||||||
|
New York NY
|
||||||
|
North Carolina NC
|
||||||
|
North Dakota ND
|
||||||
|
Ohio OH
|
||||||
|
Oklahoma OK
|
||||||
|
Oregon OR
|
||||||
|
Pennsylvania PA
|
||||||
|
Rhode Island RI
|
||||||
|
South Carolina SC
|
||||||
|
South Dakota SD
|
||||||
|
Tennessee TN
|
||||||
|
Tennessee TENN
|
||||||
|
Texas TX
|
||||||
|
Utah UT
|
||||||
|
Vermont VT
|
||||||
|
Virginia VA
|
||||||
|
Washington WA
|
||||||
|
West Virginia WV
|
||||||
|
Wisconsin WI
|
||||||
|
Wyoming WY
|
||||||
|
@@ -0,0 +1,13 @@
|
|||||||
|
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
@@ -0,0 +1,31 @@
|
|||||||
|
one
|
||||||
|
two
|
||||||
|
three
|
||||||
|
four
|
||||||
|
five
|
||||||
|
six
|
||||||
|
seven
|
||||||
|
eight
|
||||||
|
nine
|
||||||
|
ten
|
||||||
|
eleven
|
||||||
|
twelve
|
||||||
|
thirteen
|
||||||
|
fourteen
|
||||||
|
fifteen
|
||||||
|
sixteen
|
||||||
|
seventeen
|
||||||
|
eighteen
|
||||||
|
nineteen
|
||||||
|
twenty
|
||||||
|
twenty one
|
||||||
|
twenty two
|
||||||
|
twenty three
|
||||||
|
twenty four
|
||||||
|
twenty five
|
||||||
|
twenty six
|
||||||
|
twenty seven
|
||||||
|
twenty eight
|
||||||
|
twenty nine
|
||||||
|
thirty
|
||||||
|
thirty one
|
||||||
|
@@ -0,0 +1,12 @@
|
|||||||
|
jan january
|
||||||
|
feb february
|
||||||
|
mar march
|
||||||
|
apr april
|
||||||
|
jun june
|
||||||
|
jul july
|
||||||
|
aug august
|
||||||
|
sep september
|
||||||
|
sept september
|
||||||
|
oct october
|
||||||
|
nov november
|
||||||
|
dec december
|
||||||
|
@@ -0,0 +1,12 @@
|
|||||||
|
january
|
||||||
|
february
|
||||||
|
march
|
||||||
|
april
|
||||||
|
may
|
||||||
|
june
|
||||||
|
july
|
||||||
|
august
|
||||||
|
september
|
||||||
|
october
|
||||||
|
november
|
||||||
|
december
|
||||||
|
@@ -0,0 +1,24 @@
|
|||||||
|
1 january
|
||||||
|
2 february
|
||||||
|
3 march
|
||||||
|
4 april
|
||||||
|
5 may
|
||||||
|
6 june
|
||||||
|
7 july
|
||||||
|
8 august
|
||||||
|
9 september
|
||||||
|
10 october
|
||||||
|
11 november
|
||||||
|
12 december
|
||||||
|
01 january
|
||||||
|
02 february
|
||||||
|
03 march
|
||||||
|
04 april
|
||||||
|
05 may
|
||||||
|
06 june
|
||||||
|
07 july
|
||||||
|
08 august
|
||||||
|
09 september
|
||||||
|
10 october
|
||||||
|
11 november
|
||||||
|
12 december
|
||||||
|
@@ -0,0 +1,16 @@
|
|||||||
|
A. D AD
|
||||||
|
A.D AD
|
||||||
|
a. d AD
|
||||||
|
a.d AD
|
||||||
|
a. d. AD
|
||||||
|
a.d. AD
|
||||||
|
B. C BC
|
||||||
|
B.C BC
|
||||||
|
b. c BC
|
||||||
|
b.c BC
|
||||||
|
A. D. AD
|
||||||
|
A.D. AD
|
||||||
|
B. C. BC
|
||||||
|
B.C. BC
|
||||||
|
b. c. BC
|
||||||
|
b.c. BC
|
||||||
|
@@ -0,0 +1,13 @@
|
|||||||
|
# Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
@@ -0,0 +1,12 @@
|
|||||||
|
.com dot com
|
||||||
|
.org dot org
|
||||||
|
.gov dot gov
|
||||||
|
.uk dot UK
|
||||||
|
.fr dot FR
|
||||||
|
.net dot net
|
||||||
|
.br dot BR
|
||||||
|
.in dot IN
|
||||||
|
.ru dot RU
|
||||||
|
.de dot DE
|
||||||
|
.it dot IT
|
||||||
|
.jpg dot jpeg
|
||||||
|
@@ -0,0 +1,21 @@
|
|||||||
|
. dot
|
||||||
|
- dash
|
||||||
|
_ underscore
|
||||||
|
! exclamation mark
|
||||||
|
# number sign
|
||||||
|
$ dollar sign
|
||||||
|
% percent sign
|
||||||
|
& ampersand
|
||||||
|
' quote
|
||||||
|
* asterisk
|
||||||
|
+ plus
|
||||||
|
/ slash
|
||||||
|
= equal sign
|
||||||
|
? question mark
|
||||||
|
^ circumflex
|
||||||
|
` right single quote
|
||||||
|
{ left brace
|
||||||
|
| vertical bar
|
||||||
|
} right brace
|
||||||
|
~ tilde
|
||||||
|
, comma
|
||||||
|
@@ -0,0 +1,13 @@
|
|||||||
|
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user