This commit is contained in:
zhousha
2025-08-06 15:38:55 +08:00
parent 4916ad0fe0
commit 55a67e817e
193 changed files with 51647 additions and 1 deletions

BIN
.DS_Store vendored

Binary file not shown.

49
Dockerfile Normal file
View File

@@ -0,0 +1,49 @@
FROM harbor.4pd.io/inf/base-python3.8-ubuntu:1.1.0

# MAINTAINER is deprecated (hadolint DL4000); use an OCI label instead.
LABEL org.opencontainers.image.authors="shiguangchuan@4paradigm.com"

WORKDIR /workspace

# Custom ssh-keygen binary shipped in the build context.
COPY ssh-keygen /bin

# Name the wheel once so the download/install/cleanup steps cannot drift apart.
ARG PYNINI_WHEEL=pynini-2.1.6-cp38-cp38-manylinux_2_31_x86_64.whl

# Install pynini from the internal FTP mirror; download, install, and delete
# in one layer so the wheel file never persists in the image.
RUN wget -q "ftp://ftp.4pd.io/pub/pico/temp/${PYNINI_WHEEL}" \
    && pip install --no-cache-dir "${PYNINI_WHEEL}" \
    && rm -f "${PYNINI_WHEEL}"

# Copy only the dependency manifest first so the pip layer stays cached until
# requirements.txt changes (COPY preferred over ADD for plain files, DL3020).
COPY ./requirements.txt /workspace

# --no-cache-dir replaces the separate `pip cache purge` step.
# NOTE(review): the generated ECDSA private key is baked into an image layer;
# anyone with pull access to the image can read it — confirm this is intended.
RUN pip install --no-cache-dir -r ./requirements.txt \
        -i https://nexus.4pd.io/repository/pypi-all/simple \
        --trusted-host nexus.4pd.io \
        --extra-index-url https://mirrors.aliyun.com/pypi/simple/ \
    && ssh-keygen -f /workspace/ssh-key-ecdsa -t ecdsa -b 521 -q -N ""

# Application source last: it changes most often, keeping earlier layers cached.
COPY . /workspace

EXPOSE 80
CMD ["python3", "run_callback.py"]
###########################
## Dockerfile更新后
#FROM harbor.4pd.io/lab-platform/inf/python:3.9
#WORKDIR /app
## 安装依赖
##RUN pip install torch librosa flask
##RUN pip config set global.index-url https://mirrors.aliyun.com/pypi/simple/ && \
## pip cache purge && \
## pip --default-timeout=1000 install torch librosa flask
## 删除原来的 COPY pytorch_model.bin /app/
#COPY inference.py /app/
# 只需要复制启动脚本
#EXPOSE 80
#CMD ["python", "inference.py"]
####################
##############################更新0731#################################

Submodule asr-tco_image deleted from 8f9a14f472

6
config.yaml Normal file
View File

@@ -0,0 +1,6 @@
leaderboard_options:
  nfs:
    - name: sid_model
      srcRelativePath: zhoushasha/models/image_models/apple_mobilevit-small
      mountPoint: /model
      source: ceph_customer

BIN
helm-chart/.DS_Store vendored Normal file

Binary file not shown.

77
helm-chart/README.md Normal file
View File

@@ -0,0 +1,77 @@
## judgeflow chart 的要求
### values.yaml 文件必须包含如下字段,并且模板中必须引用 values.yaml 中的如下字段
```
podLabels
env
volumeMounts
volumes
affinity
```
### values.yaml 文件必须在 volumeMounts 中声明如下卷
```
workspace
submit
datafile
```
## 被测服务sut chart 的要求
### values.yaml 文件必须包含如下字段,并且资源模板中必须引用 values.yaml 中的如下字段
```
podLabels
affinity
```
针对 podLabels 字段,values.yaml 中配置格式如下:
```
podLabels: {}
```
下面给出示例
podLabels
values.yaml
templates/deployment.yaml
```
metadata:
labels:
{{- with .Values.podLabels }}
{{- toYaml . | nindent 4 }}
{{- end }}
```
affinity
values.yaml
```
affinity: {}
```
templates/deployment.yaml
```
spec:
template:
spec:
{{- with .Values.affinity }}
affinity:
{{- toYaml . | nindent 8 }}
{{- end }}
```
### 如果需要在 sut 中使用共享存储,则 sut chart 的 values.yaml 也必须包含如下字段,且模板中必须引用 values.yaml 中的如下字段
```
volumeMounts
volumes
```

View File

@@ -0,0 +1,23 @@
# Patterns to ignore when building packages.
# This supports shell glob matching, relative path matching, and
# negation (prefixed with !). Only one pattern per line.
.DS_Store
# Common VCS dirs
.git/
.gitignore
.bzr/
.bzrignore
.hg/
.hgignore
.svn/
# Common backup files
*.swp
*.bak
*.tmp
*.orig
*~
# Various IDEs
.project
.idea/
*.tmproj
.vscode/

View File

@@ -0,0 +1,24 @@
apiVersion: v2
name: ${chartName}
description: Leaderboard judgeflow helm chart for demo
# A chart can be either an 'application' or a 'library' chart.
#
# Application charts are a collection of templates that can be packaged into versioned archives
# to be deployed.
#
# Library charts provide useful utilities or functions for the chart developer. They're included as
# a dependency of application charts to inject those utilities and functions into the rendering
# pipeline. Library charts do not define any templates and therefore cannot be deployed.
type: application
# This is the chart version. This version number should be incremented each time you make changes
# to the chart and its templates, including the app version.
# Versions are expected to follow Semantic Versioning (https://semver.org/)
version: ${version}
# This is the version number of the application being deployed. This version number should be
# incremented each time you make changes to the application. Versions are not expected to
# follow Semantic Versioning. They should reflect the version the application is using.
# It is recommended to use it with quotes.
appVersion: "${appVersion}"

View File

@@ -0,0 +1,62 @@
{{/*
Expand the name of the chart.
*/}}
{{- define "judgeflow.name" -}}
{{- default .Chart.Name .Values.nameOverride | trunc 63 | trimSuffix "-" }}
{{- end }}
{{/*
Create a default fully qualified app name.
We truncate at 63 chars because some Kubernetes name fields are limited to this (by the DNS naming spec).
If release name contains chart name it will be used as a full name.
*/}}
{{- define "judgeflow.fullname" -}}
{{- if .Values.fullnameOverride }}
{{- .Values.fullnameOverride | trunc 63 | trimSuffix "-" }}
{{- else }}
{{- $name := default .Chart.Name .Values.nameOverride }}
{{- if contains $name .Release.Name }}
{{- .Release.Name | trunc 63 | trimSuffix "-" }}
{{- else }}
{{- printf "%s-%s" .Release.Name $name | trunc 63 | trimSuffix "-" }}
{{- end }}
{{- end }}
{{- end }}
{{/*
Create chart name and version as used by the chart label.
*/}}
{{- define "judgeflow.chart" -}}
{{- printf "%s-%s" .Chart.Name .Chart.Version | replace "+" "_" | trunc 63 | trimSuffix "-" }}
{{- end }}
{{/*
Common labels
*/}}
{{- define "judgeflow.labels" -}}
helm.sh/chart: {{ include "judgeflow.chart" . }}
{{ include "judgeflow.selectorLabels" . }}
{{- if .Chart.AppVersion }}
app.kubernetes.io/version: {{ .Chart.AppVersion | quote }}
{{- end }}
app.kubernetes.io/managed-by: {{ .Release.Service }}
{{- end }}
{{/*
Selector labels
*/}}
{{- define "judgeflow.selectorLabels" -}}
app.kubernetes.io/name: {{ include "judgeflow.name" . }}
app.kubernetes.io/instance: {{ .Release.Name }}
{{- end }}
{{/*
Create the name of the service account to use
*/}}
{{- define "judgeflow.serviceAccountName" -}}
{{- if .Values.serviceAccount.create }}
{{- default (include "judgeflow.fullname" .) .Values.serviceAccount.name }}
{{- else }}
{{- default "default" .Values.serviceAccount.name }}
{{- end }}
{{- end }}

View File

@@ -0,0 +1,32 @@
{{- if .Values.autoscaling.enabled }}
apiVersion: autoscaling/v2
kind: HorizontalPodAutoscaler
metadata:
name: {{ include "judgeflow.fullname" . }}
labels:
{{- include "judgeflow.labels" . | nindent 4 }}
spec:
scaleTargetRef:
apiVersion: apps/v1
kind: Deployment
name: {{ include "judgeflow.fullname" . }}
minReplicas: {{ .Values.autoscaling.minReplicas }}
maxReplicas: {{ .Values.autoscaling.maxReplicas }}
metrics:
{{- if .Values.autoscaling.targetCPUUtilizationPercentage }}
- type: Resource
resource:
name: cpu
target:
type: Utilization
averageUtilization: {{ .Values.autoscaling.targetCPUUtilizationPercentage }}
{{- end }}
{{- if .Values.autoscaling.targetMemoryUtilizationPercentage }}
- type: Resource
resource:
name: memory
target:
type: Utilization
averageUtilization: {{ .Values.autoscaling.targetMemoryUtilizationPercentage }}
{{- end }}
{{- end }}

View File

@@ -0,0 +1,61 @@
{{- if .Values.ingress.enabled -}}
{{- $fullName := include "judgeflow.fullname" . -}}
{{- $svcPort := .Values.service.port -}}
{{- if and .Values.ingress.className (not (semverCompare ">=1.18-0" .Capabilities.KubeVersion.GitVersion)) }}
{{- if not (hasKey .Values.ingress.annotations "kubernetes.io/ingress.class") }}
{{- $_ := set .Values.ingress.annotations "kubernetes.io/ingress.class" .Values.ingress.className}}
{{- end }}
{{- end }}
{{- if semverCompare ">=1.19-0" .Capabilities.KubeVersion.GitVersion -}}
apiVersion: networking.k8s.io/v1
{{- else if semverCompare ">=1.14-0" .Capabilities.KubeVersion.GitVersion -}}
apiVersion: networking.k8s.io/v1beta1
{{- else -}}
apiVersion: extensions/v1beta1
{{- end }}
kind: Ingress
metadata:
name: {{ $fullName }}
labels:
{{- include "judgeflow.labels" . | nindent 4 }}
{{- with .Values.ingress.annotations }}
annotations:
{{- toYaml . | nindent 4 }}
{{- end }}
spec:
{{- if and .Values.ingress.className (semverCompare ">=1.18-0" .Capabilities.KubeVersion.GitVersion) }}
ingressClassName: {{ .Values.ingress.className }}
{{- end }}
{{- if .Values.ingress.tls }}
tls:
{{- range .Values.ingress.tls }}
- hosts:
{{- range .hosts }}
- {{ . | quote }}
{{- end }}
secretName: {{ .secretName }}
{{- end }}
{{- end }}
rules:
{{- range .Values.ingress.hosts }}
- host: {{ .host | quote }}
http:
paths:
{{- range .paths }}
- path: {{ .path }}
{{- if and .pathType (semverCompare ">=1.18-0" $.Capabilities.KubeVersion.GitVersion) }}
pathType: {{ .pathType }}
{{- end }}
backend:
{{- if semverCompare ">=1.19-0" $.Capabilities.KubeVersion.GitVersion }}
service:
name: {{ $fullName }}
port:
number: {{ $svcPort }}
{{- else }}
serviceName: {{ $fullName }}
servicePort: {{ $svcPort }}
{{- end }}
{{- end }}
{{- end }}
{{- end }}

View File

@@ -0,0 +1,63 @@
# One-shot evaluation Job: runs the judge container to completion exactly once
# (restartPolicy: Never + backoffLimit: 0 → no retries on failure).
apiVersion: batch/v1
kind: Job
metadata:
  name: {{ include "judgeflow.fullname" . }}
  labels:
    {{- include "judgeflow.labels" . | nindent 4 }}
    {{- with .Values.podLabels }}
    {{- toYaml . | nindent 4 }}
    {{- end }}
spec:
  template:
    metadata:
      labels:
        {{- include "judgeflow.labels" . | nindent 8 }}
        {{- with .Values.podLabels }}
        {{- toYaml . | nindent 8 }}
        {{- end }}
    spec:
      {{- with .Values.priorityclassname }}
      priorityClassName: "{{ . }}"
      {{- end }}
      containers:
        - name: {{ .Chart.Name }}
          securityContext:
            {{- toYaml .Values.securityContext | nindent 12 }}
          image: "{{ .Values.image.repository }}:{{ .Values.image.tag | default .Chart.AppVersion }}"
          imagePullPolicy: {{ .Values.image.pullPolicy }}
          {{- with .Values.env }}
          env:
            {{- toYaml . | nindent 12 }}
          {{- end }}
          # Ports are optional for a Job; rendered only when values declare them.
          {{- if and (hasKey .Values "service") (hasKey .Values.service "ports") }}
          ports:
            {{- range .Values.service.ports }}
            - name: {{ .name }}
              containerPort: {{ .port }}
            {{- end }}
          {{- end }}
          # Optional command override, rendered verbatim — values must supply
          # it in a YAML-compatible form (e.g. a flow-style list).
          {{- if hasKey .Values "command" }}
          command: {{ .Values.command }}
          {{- end }}
          # volumeMounts/resources are mandatory values for this chart (the
          # chart README requires workspace/submit/datafile mounts).
          volumeMounts:
            {{- toYaml .Values.volumeMounts | nindent 12 }}
          resources:
            {{- toYaml .Values.resources | nindent 12 }}
      restartPolicy: Never
      {{- with .Values.volumes }}
      volumes:
        {{- toYaml . | nindent 8 }}
      {{- end }}
      {{- with .Values.affinity }}
      affinity:
        {{- toYaml . | nindent 8 }}
      {{- end }}
      {{- with .Values.nodeSelector }}
      nodeSelector:
        {{- toYaml . | nindent 8 }}
      {{- end }}
      {{- with .Values.tolerations }}
      tolerations:
        {{- toYaml . | nindent 8 }}
      {{- end }}
  # Never retry a failed pod — a judge run must not execute twice.
  backoffLimit: 0

View File

@@ -0,0 +1,10 @@
{{/*
Optional PriorityClass, created only when .Values.priorityclassname is set.
It is non-default (globalDefault: false) and non-preempting.
*/}}
{{- if .Values.priorityclassname }}
apiVersion: scheduling.k8s.io/v1
kind: PriorityClass
metadata:
  name: "{{ .Values.priorityclassname }}"
# NOTE(review): values.yaml supplies priorityclassvalue as the string '0'; it
# renders unquoted here and PriorityClass requires an integer — confirm.
value: {{ .Values.priorityclassvalue }}
globalDefault: false
preemptionPolicy: "Never"
description: "This is a priority class."
{{- end }}

View File

@@ -0,0 +1,22 @@
{{/*
Service for the judge pod; rendered only when values define service.type,
since the judge Job may also run without exposing any ports.
*/}}
{{- if and (hasKey .Values "service") (hasKey .Values.service "type") }}
apiVersion: v1
kind: Service
metadata:
  name: {{ include "judgeflow.fullname" . }}
  labels:
    {{- include "judgeflow.labels" . | nindent 4 }}
    {{- with .Values.podLabels }}
    {{- toYaml . | nindent 4 }}
    {{- end }}
spec:
  type: {{ .Values.service.type }}
  # One Service port per entry in .Values.service.ports; targetPort reuses the
  # same numeric port as the container.
  ports:
    {{- range .Values.service.ports }}
    - port: {{ .port }}
      targetPort: {{ .port }}
      protocol: TCP
      name: {{ .name }}
    {{- end }}
  selector:
    {{- include "judgeflow.selectorLabels" . | nindent 4 }}
{{- end }}

View File

@@ -0,0 +1,13 @@
{{- if .Values.serviceAccount.create -}}
apiVersion: v1
kind: ServiceAccount
metadata:
name: {{ include "judgeflow.serviceAccountName" . }}
labels:
{{- include "judgeflow.labels" . | nindent 4 }}
{{- with .Values.serviceAccount.annotations }}
annotations:
{{- toYaml . | nindent 4 }}
{{- end }}
automountServiceAccountToken: {{ .Values.serviceAccount.automount }}
{{- end }}

View File

@@ -0,0 +1,15 @@
apiVersion: v1
kind: Pod
metadata:
name: {{ include "judgeflow.fullname" . }}-test-connection
labels:
{{- include "judgeflow.labels" . | nindent 4 }}
annotations:
"helm.sh/hook": test
spec:
containers:
- name: wget
image: busybox
command: ['wget']
args: ['{{ include "judgeflow.fullname" . }}:{{ .Values.service.port }}']
restartPolicy: Never

View File

@@ -0,0 +1,124 @@
# Default values for job_demo.
# This is a YAML-formatted file.
# Declare variables to be passed into your templates.
replicaCount: 1
image:
repository: "${imageRepo}"
pullPolicy: IfNotPresent
# Overrides the image tag whose default is the chart appVersion.
tag: "${imageTag}"
imagePullSecrets: []
nameOverride: ""
fullnameOverride: ""
serviceAccount:
# Specifies whether a service account should be created
create: true
# Annotations to add to the service account
annotations: {}
# The name of the service account to use.
# If not set and create is true, a name is generated using the fullname template
name: ""
podAnnotations: {}
podLabels:
contest.4pd.io/leaderboard-resource-type: judge_flow
contest.4pd.io/leaderboard-job-id: "0"
contest.4pd.io/leaderboard-submit-id: "0"
podSecurityContext: {}
# fsGroup: 2000
securityContext: {}
# capabilities:
# drop:
# - ALL
# readOnlyRootFilesystem: true
# runAsNonRoot: true
# runAsUser: 1000
service:
type: ClusterIP
ports:
- name: http
port: 80
ingress:
enabled: false
className: ""
annotations: {}
# kubernetes.io/ingress.class: nginx
# kubernetes.io/tls-acme: "true"
hosts:
- host: chart-example.local
paths:
- path: /
pathType: ImplementationSpecific
tls: []
# - secretName: chart-example-tls
# hosts:
# - chart-example.local
resources:
# We usually recommend not to specify default resources and to leave this as a conscious
# choice for the user. This also increases chances charts run on environments with little
# resources, such as Minikube. If you do want to specify resources, uncomment the following
# lines, adjust them as necessary, and remove the curly braces after 'resources:'.
limits:
cpu: 3000m
memory: 16Gi
requests:
cpu: 3000m
memory: 16Gi
autoscaling:
enabled: false
minReplicas: 1
maxReplicas: 100
targetCPUUtilizationPercentage: 80
# targetMemoryUtilizationPercentage: 80
nodeSelector:
juicefs: "on"
contest.4pd.io/cpu: INTEL-8358
tolerations: []
affinity: {}
env:
- name: TZ
value: Asia/Shanghai
- name: MY_POD_IP
valueFrom:
fieldRef:
fieldPath: status.podIP
#command: '["python","run.py"]'
volumeMounts:
- name: workspace
mountPath: /tmp/workspace
- name: datafile
mountPath: /tmp/datafile
- name: submit
mountPath: /tmp/submit_config
- name: juicefs-pv
mountPath: /tmp/juicefs
- name: customer
mountPath: /tmp/customer
- name: submit-private
mountPath: /tmp/submit_private
volumes:
- name: juicefs-pv
persistentVolumeClaim:
claimName: juicefs-pvc
priorityclassname: ''
priorityclassvalue: '0'

BIN
helm-chart/sut/.DS_Store vendored Normal file

Binary file not shown.

View File

@@ -0,0 +1,23 @@
# Patterns to ignore when building packages.
# This supports shell glob matching, relative path matching, and
# negation (prefixed with !). Only one pattern per line.
.DS_Store
# Common VCS dirs
.git/
.gitignore
.bzr/
.bzrignore
.hg/
.hgignore
.svn/
# Common backup files
*.swp
*.bak
*.tmp
*.orig
*~
# Various IDEs
.project
.idea/
*.tmproj
.vscode/

24
helm-chart/sut/Chart.yaml Normal file
View File

@@ -0,0 +1,24 @@
apiVersion: v2
name: sut
description: A Helm chart for Kubernetes
# A chart can be either an 'application' or a 'library' chart.
#
# Application charts are a collection of templates that can be packaged into versioned archives
# to be deployed.
#
# Library charts provide useful utilities or functions for the chart developer. They're included as
# a dependency of application charts to inject those utilities and functions into the rendering
# pipeline. Library charts do not define any templates and therefore cannot be deployed.
type: application
# This is the chart version. This version number should be incremented each time you make changes
# to the chart and its templates, including the app version.
# Versions are expected to follow Semantic Versioning (https://semver.org/)
version: 0.1.0
# This is the version number of the application being deployed. This version number should be
# incremented each time you make changes to the application. Versions are not expected to
# follow Semantic Versioning. They should reflect the version the application is using.
# It is recommended to use it with quotes.
appVersion: "0.1.0"

View File

@@ -0,0 +1,62 @@
{{/*
Expand the name of the chart.
*/}}
{{- define "sut.name" -}}
{{- default .Chart.Name .Values.nameOverride | trunc 63 | trimSuffix "-" }}
{{- end }}
{{/*
Create a default fully qualified app name.
We truncate at 63 chars because some Kubernetes name fields are limited to this (by the DNS naming spec).
If release name contains chart name it will be used as a full name.
*/}}
{{- define "sut.fullname" -}}
{{- if .Values.fullnameOverride }}
{{- .Values.fullnameOverride | trunc 63 | trimSuffix "-" }}
{{- else }}
{{- $name := default .Chart.Name .Values.nameOverride }}
{{- if contains $name .Release.Name }}
{{- .Release.Name | trunc 63 | trimSuffix "-" }}
{{- else }}
{{- printf "%s-%s" .Release.Name $name | trunc 63 | trimSuffix "-" }}
{{- end }}
{{- end }}
{{- end }}
{{/*
Create chart name and version as used by the chart label.
*/}}
{{- define "sut.chart" -}}
{{- printf "%s-%s" .Chart.Name .Chart.Version | replace "+" "_" | trunc 63 | trimSuffix "-" }}
{{- end }}
{{/*
Common labels
*/}}
{{- define "sut.labels" -}}
helm.sh/chart: {{ include "sut.chart" . }}
{{ include "sut.selectorLabels" . }}
{{- if .Chart.AppVersion }}
app.kubernetes.io/version: {{ .Chart.AppVersion | quote }}
{{- end }}
app.kubernetes.io/managed-by: {{ .Release.Service }}
{{- end }}
{{/*
Selector labels
*/}}
{{- define "sut.selectorLabels" -}}
app.kubernetes.io/name: {{ include "sut.name" . }}
app.kubernetes.io/instance: {{ .Release.Name }}
{{- end }}
{{/*
Create the name of the service account to use
*/}}
{{- define "sut.serviceAccountName" -}}
{{- if .Values.serviceAccount.create }}
{{- default (include "sut.fullname" .) .Values.serviceAccount.name }}
{{- else }}
{{- default "default" .Values.serviceAccount.name }}
{{- end }}
{{- end }}

View File

@@ -0,0 +1,94 @@
# Deployment for the system-under-test (SUT) service container.
apiVersion: apps/v1
kind: Deployment
metadata:
  name: {{ include "sut.fullname" . }}
  labels:
    {{- include "sut.labels" . | nindent 4 }}
    {{- with .Values.podLabels }}
    {{- toYaml . | nindent 4 }}
    {{- end }}
spec:
  # Replica count is owned by the HPA when autoscaling is enabled.
  {{- if not .Values.autoscaling.enabled }}
  replicas: {{ .Values.replicaCount }}
  {{- end }}
  selector:
    matchLabels:
      {{- include "sut.selectorLabels" . | nindent 6 }}
  template:
    metadata:
      {{- with .Values.podAnnotations }}
      annotations:
        {{- toYaml . | nindent 8 }}
      {{- end }}
      labels:
        {{- include "sut.labels" . | nindent 8 }}
        {{- with .Values.podLabels }}
        {{- toYaml . | nindent 8 }}
        {{- end }}
    spec:
      {{- with .Values.imagePullSecrets }}
      imagePullSecrets:
        {{- toYaml . | nindent 8 }}
      {{- end }}
      serviceAccountName: {{ include "sut.serviceAccountName" . }}
      securityContext:
        {{- toYaml .Values.podSecurityContext | nindent 8 }}
      {{- with .Values.priorityclassname }}
      priorityClassName: "{{ . }}"
      {{- end }}
      containers:
        - name: {{ .Chart.Name }}
          securityContext:
            {{- toYaml .Values.securityContext | nindent 12 }}
          image: "{{ .Values.image.repository }}:{{ .Values.image.tag | default .Chart.AppVersion }}"
          imagePullPolicy: {{ .Values.image.pullPolicy }}
          {{- with .Values.env }}
          env:
            {{- toYaml . | nindent 12 }}
          {{- end }}
          # Single HTTP port, wired to .Values.service.port.
          ports:
            - name: http
              containerPort: {{ .Values.service.port }}
              protocol: TCP
          {{- with .Values.command }}
          command:
            {{- toYaml . | nindent 12 }}
          {{- end }}
          resources:
            {{- toYaml .Values.resources | nindent 12 }}
          {{- with .Values.volumeMounts }}
          volumeMounts:
            {{- toYaml . | nindent 12 }}
          {{- end }}
          {{- with .Values.livenessProbe }}
          livenessProbe:
            {{- toYaml . | nindent 12 }}
          {{- end }}
          {{- with .Values.readinessProbe }}
          readinessProbe:
            {{- toYaml . | nindent 12 }}
          {{- end }}
          {{- with .Values.startupProbe }}
          startupProbe:
            {{- toYaml . | nindent 12 }}
          {{- end }}
      volumes:
        {{- with .Values.volumes }}
        {{- toYaml . | nindent 8 }}
        {{- end }}
      {{- with .Values.nodeSelector }}
      nodeSelector:
        {{- toYaml . | nindent 8 }}
      {{- end }}
      {{- with .Values.affinity }}
      affinity:
        {{- toYaml . | nindent 8 }}
      {{- end }}
      # Hardcoded toleration for Iluvatar GPU nodes.
      # NOTE(review): .Values.tolerations is defined in values.yaml but is NOT
      # referenced here — this list always wins. Confirm that is intentional.
      tolerations:
        - key: "hosttype"
          operator: "Equal"
          value: "iluvatar"
          effect: "NoSchedule"

View File

@@ -0,0 +1,32 @@
{{- if .Values.autoscaling.enabled }}
apiVersion: autoscaling/v2
kind: HorizontalPodAutoscaler
metadata:
name: {{ include "sut.fullname" . }}
labels:
{{- include "sut.labels" . | nindent 4 }}
spec:
scaleTargetRef:
apiVersion: apps/v1
kind: Deployment
name: {{ include "sut.fullname" . }}
minReplicas: {{ .Values.autoscaling.minReplicas }}
maxReplicas: {{ .Values.autoscaling.maxReplicas }}
metrics:
{{- if .Values.autoscaling.targetCPUUtilizationPercentage }}
- type: Resource
resource:
name: cpu
target:
type: Utilization
averageUtilization: {{ .Values.autoscaling.targetCPUUtilizationPercentage }}
{{- end }}
{{- if .Values.autoscaling.targetMemoryUtilizationPercentage }}
- type: Resource
resource:
name: memory
target:
type: Utilization
averageUtilization: {{ .Values.autoscaling.targetMemoryUtilizationPercentage }}
{{- end }}
{{- end }}

View File

@@ -0,0 +1,61 @@
{{- if .Values.ingress.enabled -}}
{{- $fullName := include "sut.fullname" . -}}
{{- $svcPort := .Values.service.port -}}
{{- if and .Values.ingress.className (not (semverCompare ">=1.18-0" .Capabilities.KubeVersion.GitVersion)) }}
{{- if not (hasKey .Values.ingress.annotations "kubernetes.io/ingress.class") }}
{{- $_ := set .Values.ingress.annotations "kubernetes.io/ingress.class" .Values.ingress.className}}
{{- end }}
{{- end }}
{{- if semverCompare ">=1.19-0" .Capabilities.KubeVersion.GitVersion -}}
apiVersion: networking.k8s.io/v1
{{- else if semverCompare ">=1.14-0" .Capabilities.KubeVersion.GitVersion -}}
apiVersion: networking.k8s.io/v1beta1
{{- else -}}
apiVersion: extensions/v1beta1
{{- end }}
kind: Ingress
metadata:
name: {{ $fullName }}
labels:
{{- include "sut.labels" . | nindent 4 }}
{{- with .Values.ingress.annotations }}
annotations:
{{- toYaml . | nindent 4 }}
{{- end }}
spec:
{{- if and .Values.ingress.className (semverCompare ">=1.18-0" .Capabilities.KubeVersion.GitVersion) }}
ingressClassName: {{ .Values.ingress.className }}
{{- end }}
{{- if .Values.ingress.tls }}
tls:
{{- range .Values.ingress.tls }}
- hosts:
{{- range .hosts }}
- {{ . | quote }}
{{- end }}
secretName: {{ .secretName }}
{{- end }}
{{- end }}
rules:
{{- range .Values.ingress.hosts }}
- host: {{ .host | quote }}
http:
paths:
{{- range .paths }}
- path: {{ .path }}
{{- if and .pathType (semverCompare ">=1.18-0" $.Capabilities.KubeVersion.GitVersion) }}
pathType: {{ .pathType }}
{{- end }}
backend:
{{- if semverCompare ">=1.19-0" $.Capabilities.KubeVersion.GitVersion }}
service:
name: {{ $fullName }}
port:
number: {{ $svcPort }}
{{- else }}
serviceName: {{ $fullName }}
servicePort: {{ $svcPort }}
{{- end }}
{{- end }}
{{- end }}
{{- end }}

View File

@@ -0,0 +1,18 @@
apiVersion: v1
kind: Service
metadata:
name: {{ include "sut.fullname" . }}
labels:
{{- include "sut.labels" . | nindent 4 }}
{{- with .Values.podLabels }}
{{- toYaml . | nindent 4 }}
{{- end }}
spec:
type: {{ .Values.service.type }}
ports:
- port: {{ .Values.service.port }}
targetPort: http
protocol: TCP
name: socket
selector:
{{- include "sut.selectorLabels" . | nindent 4 }}

View File

@@ -0,0 +1,13 @@
{{- if .Values.serviceAccount.create -}}
apiVersion: v1
kind: ServiceAccount
metadata:
name: {{ include "sut.serviceAccountName" . }}
labels:
{{- include "sut.labels" . | nindent 4 }}
{{- with .Values.serviceAccount.annotations }}
annotations:
{{- toYaml . | nindent 4 }}
{{- end }}
automountServiceAccountToken: {{ .Values.serviceAccount.automount }}
{{- end }}

View File

@@ -0,0 +1,15 @@
apiVersion: v1
kind: Pod
metadata:
name: "{{ include "sut.fullname" . }}-test-connection"
labels:
{{- include "sut.labels" . | nindent 4 }}
annotations:
"helm.sh/hook": test
spec:
containers:
- name: wget
image: busybox
command: ['wget']
args: ['{{ include "sut.fullname" . }}:{{ .Values.service.port }}']
restartPolicy: Never

View File

@@ -0,0 +1,144 @@
# Default values for sut.
# This is a YAML-formatted file.
# Declare variables to be passed into your templates.
replicaCount: 1
image:
repository: harbor.4pd.io/lab-platform/inf/python
pullPolicy: IfNotPresent
# Overrides the image tag whose default is the chart appVersion.
tag: 3.9
imagePullSecrets: []
nameOverride: ""
fullnameOverride: ""
serviceAccount:
# Specifies whether a service account should be created
create: true
# Automatically mount a ServiceAccount's API credentials?
automount: true
# Annotations to add to the service account
annotations: {}
# The name of the service account to use.
# If not set and create is true, a name is generated using the fullname template
name: ""
podAnnotations: {}
podLabels: {}
podSecurityContext: {}
# fsGroup: 2000
securityContext: {}
# capabilities:
# drop:
# - ALL
# readOnlyRootFilesystem: true
# runAsNonRoot: true
# runAsUser: 1000
service:
type: ClusterIP
port: 80
ingress:
enabled: false
className: ""
annotations: {}
# kubernetes.io/ingress.class: nginx
# kubernetes.io/tls-acme: "true"
hosts:
- host: chart-example.local
paths:
- path: /
pathType: ImplementationSpecific
tls: []
# - secretName: chart-example-tls
# hosts:
# - chart-example.local
resources:
# We usually recommend not to specify default resources and to leave this as a conscious
# choice for the user. This also increases chances charts run on environments with little
# resources, such as Minikube. If you do want to specify resources, uncomment the following
# lines, adjust them as necessary, and remove the curly braces after 'resources:'.
limits:
cpu: 1000m
memory: 4096Mi
requests:
cpu: 1000m
memory: 4096Mi
autoscaling:
enabled: false
minReplicas: 1
maxReplicas: 100
targetCPUUtilizationPercentage: 80
# targetMemoryUtilizationPercentage: 80
# Additional volumes on the output Deployment definition.
volumes: []
# - name: foo
# secret:
# secretName: mysecret
# optional: false
# Additional volumeMounts on the output Deployment definition.
volumeMounts: []
# - name: foo
# mountPath: "/etc/foo"
# readOnly: true
nodeSelector:
contest.4pd.io/accelerator: iluvatar-BI-V100
tolerations:
- key: hosttype
operator: Equal
value: iluvatar
effect: NoSchedule
affinity: {}
readinessProbe:
failureThreshold: 1000
httpGet:
path: /health
port: 80
scheme: HTTP
#readinessProbe:
# httpGet:
# path: /health
# port: 80
# scheme: HTTP
# initialDelaySeconds: 5 # 应用启动后等待 5 秒再开始探测
# failureThreshold: 5 # 连续失败 3 次后标记为未就绪
# successThreshold: 1 # 连续成功 1 次后标记为就绪
env:
- name: TZ
value: Asia/Shanghai
- name: MY_POD_NAME
valueFrom:
fieldRef:
fieldPath: metadata.name
- name: MY_POD_NAMESPACE
valueFrom:
fieldRef:
fieldPath: metadata.namespace
- name: MY_POD_IP
valueFrom:
fieldRef:
fieldPath: status.podIP
- name: MY_NODE_IP
valueFrom:
fieldRef:
fieldPath: status.hostIP
#command: ''
priorityclassname: ''

64
local_test.py Normal file
View File

@@ -0,0 +1,64 @@
import os
import tempfile
import shutil

# Local smoke test: fabricates a submit config and a throwaway SSH key pair in
# a temp dir, then drives run_async_a10.get_sut_url_windows() against them.

# Remove state left over from a previous run so the test starts clean.
if os.path.exists("/tmp/submit_private"):
    shutil.rmtree("/tmp/submit_private")

with tempfile.TemporaryDirectory() as tempdir:
    config_path = os.path.join(tempdir, "config.json")
    # Generate a disposable ECDSA key pair; os.system returns 0 on success.
    assert not os.system(f"ssh-keygen -f {tempdir}/ssh-key-ecdsa -t ecdsa -b 521 -q -N \"\"")
    # NOTE(review): the scrape this was recovered from appears to have stripped
    # indentation inside this YAML literal (e.g. nesting under `config.json:`
    # and `leaderboard_options:`) — restore it before relying on this test.
    config = """
model: whisper
model_key: whisper
config.json:
name: 'faster-whisper-server:latest'
support_devices:
- cpu
model_path: ''
port: 8080
other_ports: []
other_ports_count: 1
entrypoint: start.bat
MIN_CHUNK: 2.5
MIN_ADD_CHUNK: 2.5
COMPUTE_TYPE: int8
NUM_WORKERS: 1
CPU_THREADS: 2
BEAM_SIZE: 5
BATCH: 1
LANG: auto
DEVICE: cpu
CHUNK_LENGTH: 5
CLASS_MODEL: ./models/faster-whisper-base
EN_MODEL: ./models/faster-whisper-base
ZH_MODEL: ./models/faster-whisper-base
RU_MODEL: ./models/faster-whisper-base
PT_MODEL: ./models/faster-whisper-base
AR_MODEL: ./models/faster-whisper-base
NEW_VERSION: 1
NEED_RESET: 0
leaderboard_options:
nfs:
- name: whisper
srcRelativePath: leaderboard/pc_asr/en.tar.gz
mountPoint: /tmp
source: ceph_customer
"""
    with open(config_path, "w") as f:
        f.write(config)

    # Environment contract expected by run_async_a10 (key dir, submit config,
    # and the model-name → archive mapping).
    os.environ["SSH_KEY_DIR"] = tempdir
    os.environ["SUBMIT_CONFIG_FILEPATH"] = config_path
    os.environ["MODEL_MAPPING"] = '{"whisper": "edge-ml.tar.gz"}'

    # Import after the environment is prepared — the module reads it on import.
    from run_async_a10 import get_sut_url_windows
    print(get_sut_url_windows())

    # Keep the temp dir (and any spawned SUT) alive for manual inspection.
    import time
    time.sleep(3600)

8
mock_env.sh Normal file
View File

@@ -0,0 +1,8 @@
#!/bin/bash
# Mock the environment variables the benchmark platform injects, so the judge
# can be exercised locally without a cluster.
export DATASET_FILEPATH=dataset/formatted1/de.zip        # input dataset archive
export RESULT_FILEPATH=out/result.json                   # aggregate result output
export DETAILED_CASES_FILEPATH=out/detail_cases.json     # per-case result output
export SUBMIT_CONFIG_FILEPATH=                           # intentionally empty locally
export BENCHMARK_NAME=                                   # intentionally empty locally
export MY_POD_IP=127.0.0.1                               # stand-in for the pod IP

215
model_test_caltech_3.py Normal file
View File

@@ -0,0 +1,215 @@
import requests
import json
import torch
from PIL import Image
from io import BytesIO
from transformers import BeitImageProcessor, BeitForImageClassification
# 根据模型实际架构选择类
from transformers import ViTForImageClassification, BeitForImageClassification
from tqdm import tqdm
from transformers import AutoConfig
from transformers import AutoImageProcessor, AutoModelForImageClassification
import os
import random
import time # 新增导入时间模块
# Use GPU acceleration when available (incl. Iluvatar devices exposing the
# CUDA API); otherwise fall back to CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"当前使用的设备: {device}")  # debug: show the selected device
# With multiple visible GPUs, DataParallel is applied inside
# COCOImageClassifier below; here we only report the count.
if torch.cuda.device_count() > 1:
    print(f"使用 {torch.cuda.device_count()} 块 GPU 进行计算")
class COCOImageClassifier:
def __init__(self, model_path: str, local_image_paths: list):
"""初始化COCO图像分类器"""
self.processor = AutoImageProcessor.from_pretrained(model_path)
self.model = AutoModelForImageClassification.from_pretrained(model_path)
# 将模型移动到设备
self.model = self.model.to(device)
print(f"模型是否在 GPU 上: {next(self.model.parameters()).is_cuda}") # 添加调试信息
# 若有多块 GPU使用 DataParallel
if torch.cuda.device_count() > 1:
self.model = torch.nn.DataParallel(self.model)
self.id2label = self.model.module.config.id2label if hasattr(self.model, 'module') else self.model.config.id2label
self.local_image_paths = local_image_paths
def predict_image_path(self, image_path: str, top_k: int = 5) -> dict:
"""
预测本地图片文件对应的图片类别
Args:
image_path: 本地图片文件路径
top_k: 返回置信度最高的前k个类别
Returns:
包含预测结果的字典
"""
try:
# 打开图片
image = Image.open(image_path).convert("RGB")
# 预处理
inputs = self.processor(images=image, return_tensors="pt")
# 将输入数据移动到设备
inputs = inputs.to(device)
# 模型推理
with torch.no_grad():
outputs = self.model(**inputs)
# 获取预测结果
logits = outputs.logits
probs = torch.nn.functional.softmax(logits, dim=1)
top_probs, top_indices = probs.topk(top_k, dim=1)
# 整理结果
predictions = []
for i in range(top_k):
class_idx = top_indices[0, i].item()
confidence = top_probs[0, i].item()
predictions.append({
"class_id": class_idx,
"class_name": self.id2label[class_idx],
"confidence": confidence
})
return {
"image_path": image_path,
"predictions": predictions
}
except Exception as e:
print(f"处理图片文件 {image_path} 时出错: {e}")
return None
def batch_predict(self, limit: int = 20, top_k: int = 5) -> list:
"""
批量预测本地图片
Args:
limit: 限制处理的图片数量
top_k: 返回置信度最高的前k个类别
Returns:
包含所有预测结果的列表
"""
results = []
local_image_paths = self.local_image_paths[:limit]
print(f"开始预测 {len(local_image_paths)} 张本地图片...")
start_time = time.time() # 记录开始时间
for image_path in tqdm(local_image_paths):
result = self.predict_image_path(image_path, top_k)
if result:
results.append(result)
end_time = time.time() # 记录结束时间
total_time = end_time - start_time # 计算总时间
images_per_second = len(results) / total_time # 计算每秒处理的图片数量
print(f"模型每秒可以处理 {images_per_second:.2f} 张图片")
return results
def save_results(self, results: list, output_file: str = "caltech_predictions.json"):
    """Write the prediction list to ``output_file`` as pretty-printed UTF-8 JSON."""
    with open(output_file, 'w', encoding='utf-8') as fh:
        json.dump(results, fh, ensure_ascii=False, indent=2)
    print(f"结果已保存到 {output_file}")
# Main program: sample Caltech-256 images, predict, and report accuracy/recall.
if __name__ == "__main__":
    # Local model / dataset locations (edit to match your environment).
    LOCAL_MODEL_PATH = "/home/zhoushasha/models/microsoft_beit_base_patch16_224_pt22k_ft22k"
    CALTECH_256_PATH = "/home/zhoushasha/models/256ObjectCategoriesNew"
    local_image_paths = []
    true_labels = {}
    # Walk every Caltech-256 category folder; folders are named "NNN.class_name".
    for folder in os.listdir(CALTECH_256_PATH):
        folder_path = os.path.join(CALTECH_256_PATH, folder)
        if os.path.isdir(folder_path):
            # Robust split: fall back to the raw folder name when no '.' is
            # present (original code raised IndexError on such folders).
            parts = folder.split('.', 1)
            class_name = parts[1] if len(parts) > 1 else parts[0]
            image_files = [os.path.join(folder_path, f) for f in os.listdir(folder_path) if f.endswith(('.jpg', '.jpeg', '.png'))]
            # Sample at most 3 images per category.
            selected_images = random.sample(image_files, min(3, len(image_files)))
            for image_path in selected_images:
                local_image_paths.append(image_path)
                true_labels[image_path] = class_name
    classifier = COCOImageClassifier(LOCAL_MODEL_PATH, local_image_paths)
    results = classifier.batch_predict(limit=len(local_image_paths), top_k=3)
    classifier.save_results(results)
    print(f"\n处理完成: 成功预测 {len(results)} 张图片")
    if results:
        print("\n示例预测结果:")
        sample = results[0]
        print(f"图片路径: {sample['image_path']}")
        for i, pred in enumerate(sample['predictions'], 1):
            print(f"{i}. {pred['class_name']} (置信度: {pred['confidence']:.2%})")
    correct_count = 0
    total_count = len(results)
    # Per-class sample counts and correct-prediction counts.
    class_actual_count = {}
    class_correct_count = {}
    for prediction in results:
        image_path = prediction['image_path']
        top1_prediction = max(prediction['predictions'], key=lambda x: x['confidence'])
        predicted_class = top1_prediction['class_name'].lower()
        # Fix: default to "" so a missing label cannot crash with AttributeError.
        true_class = true_labels.get(image_path, "").lower()
        class_actual_count[true_class] = class_actual_count.get(true_class, 0) + 1
        # A prediction is correct when any word of the predicted label contains
        # the true class name. Fix: require a non-empty label, since "" is a
        # substring of every word and previously always matched.
        if true_class:
            for word in predicted_class.split():
                if true_class in word:
                    correct_count += 1
                    class_correct_count[true_class] = class_correct_count.get(true_class, 0) + 1
                    break
    # Fix: guard all ratios against empty denominators.
    accuracy = correct_count / total_count if total_count else 0.0
    print(f"\nAccuracy: {accuracy * 100:.2f}%")
    # Per-class recall, then its unweighted mean.
    recall_per_class = {}
    for class_name in class_actual_count:
        recall_per_class[class_name] = class_correct_count.get(class_name, 0) / class_actual_count[class_name]
    average_recall = (sum(recall_per_class.values()) / len(recall_per_class)) if recall_per_class else 0.0
    print(f"\nAverage Recall: {average_recall * 100:.2f}%")

197
model_test_caltech_cpu1.py Normal file
View File

@@ -0,0 +1,197 @@
import requests
import json
import torch
from PIL import Image
from io import BytesIO
from transformers import AutoImageProcessor, AutoModelForImageClassification
from tqdm import tqdm
import os
import random
import time
# 强制使用CPU
device = torch.device("cpu")
print(f"当前使用的设备: {device}")
class COCOImageClassifier:
def __init__(self, model_path: str, local_image_paths: list):
"""初始化COCO图像分类器"""
self.processor = AutoImageProcessor.from_pretrained(model_path)
self.model = AutoModelForImageClassification.from_pretrained(model_path)
# 将模型移动到CPU
self.model = self.model.to(device)
self.id2label = self.model.config.id2label
self.local_image_paths = local_image_paths
def predict_image_path(self, image_path: str, top_k: int = 5) -> dict:
"""
预测本地图片文件对应的图片类别
Args:
image_path: 本地图片文件路径
top_k: 返回置信度最高的前k个类别
Returns:
包含预测结果的字典
"""
try:
# 打开图片
image = Image.open(image_path).convert("RGB")
# 预处理
inputs = self.processor(images=image, return_tensors="pt")
# 将输入数据移动到CPU
inputs = inputs.to(device)
# 模型推理
with torch.no_grad():
outputs = self.model(**inputs)
# 获取预测结果
logits = outputs.logits
probs = torch.nn.functional.softmax(logits, dim=1)
top_probs, top_indices = probs.topk(top_k, dim=1)
# 整理结果
predictions = []
for i in range(top_k):
class_idx = top_indices[0, i].item()
confidence = top_probs[0, i].item()
predictions.append({
"class_id": class_idx,
"class_name": self.id2label[class_idx],
"confidence": confidence
})
return {
"image_path": image_path,
"predictions": predictions
}
except Exception as e:
print(f"处理图片文件 {image_path} 时出错: {e}")
return None
def batch_predict(self, limit: int = 20, top_k: int = 5) -> list:
"""
批量预测本地图片
Args:
limit: 限制处理的图片数量
top_k: 返回置信度最高的前k个类别
Returns:
包含所有预测结果的列表
"""
results = []
local_image_paths = self.local_image_paths[:limit]
print(f"开始预测 {len(local_image_paths)} 张本地图片...")
start_time = time.time()
for image_path in tqdm(local_image_paths):
result = self.predict_image_path(image_path, top_k)
if result:
results.append(result)
end_time = time.time()
# 计算吞吐量
throughput = len(results) / (end_time - start_time)
print(f"模型每秒可以处理 {throughput:.2f} 张图片")
return results
def save_results(self, results: list, output_file: str = "celtech_cpu_predictions.json"):
"""
保存预测结果到JSON文件
Args:
results: 预测结果列表
output_file: 输出文件名
"""
with open(output_file, 'w', encoding='utf-8') as f:
json.dump(results, f, ensure_ascii=False, indent=2)
print(f"结果已保存到 {output_file}")
# 主程序
if __name__ == "__main__":
# 替换为本地模型路径
LOCAL_MODEL_PATH = "/home/zhoushasha/models/microsoft_beit_base_patch16_224_pt22k_ft22k"
# 替换为Caltech 256数据集文件夹路径 New
CALTECH_256_PATH = "/home/zhoushasha/models/256ObjectCategoriesNew"
local_image_paths = []
true_labels = {}
# 遍历Caltech 256数据集中的每个文件夹
for folder in os.listdir(CALTECH_256_PATH):
folder_path = os.path.join(CALTECH_256_PATH, folder)
if os.path.isdir(folder_path):
# 获取文件夹名称中的类别名称
class_name = folder.split('.', 1)[1]
# 获取文件夹中的所有图片文件
image_files = [os.path.join(folder_path, f) for f in os.listdir(folder_path) if f.endswith(('.jpg', '.jpeg', '.png'))]
# 随机选择3张图片
selected_images = random.sample(image_files, min(3, len(image_files)))
for image_path in selected_images:
local_image_paths.append(image_path)
true_labels[image_path] = class_name
# 创建分类器实例
classifier = COCOImageClassifier(LOCAL_MODEL_PATH, local_image_paths)
# 批量预测
results = classifier.batch_predict(limit=len(local_image_paths), top_k=3)
# 保存结果
classifier.save_results(results)
# 打印简要统计
print(f"\n处理完成: 成功预测 {len(results)} 张图片")
if results:
print("\n示例预测结果:")
sample = results[0]
print(f"图片路径: {sample['image_path']}")
for i, pred in enumerate(sample['predictions'], 1):
print(f"{i}. {pred['class_name']} (置信度: {pred['confidence']:.2%})")
correct_count = 0
total_count = len(results)
class_true_positives = {}
class_false_negatives = {}
for prediction in results:
image_path = prediction['image_path']
top1_prediction = max(prediction['predictions'], key=lambda x: x['confidence'])
predicted_class = top1_prediction['class_name'].lower()
true_class = true_labels.get(image_path).lower()
if true_class not in class_true_positives:
class_true_positives[true_class] = 0
class_false_negatives[true_class] = 0
# 检查预测类别中的每个单词是否包含真实标签
words = predicted_class.split()
for word in words:
if true_class in word:
correct_count += 1
class_true_positives[true_class] += 1
break
else:
class_false_negatives[true_class] += 1
accuracy = correct_count / total_count
print(f"\nAccuracy: {accuracy * 100:.2f}%")
# 计算召回率
total_true_positives = 0
total_false_negatives = 0
for class_name in class_true_positives:
total_true_positives += class_true_positives[class_name]
total_false_negatives += class_false_negatives[class_name]
recall = total_true_positives / (total_true_positives + total_false_negatives)
print(f"Recall: {recall * 100:.2f}%")

166
model_test_caltech_http.py Normal file
View File

@@ -0,0 +1,166 @@
import torch
import time
import os
import multiprocessing
from PIL import Image
from transformers import AutoImageProcessor, AutoModelForImageClassification
from flask import Flask, request, jsonify
from io import BytesIO
# Cap every common math backend at 4 CPU threads so CPU-side timings are
# comparable across hosts (must be set before the libraries create their pools).
os.environ["OMP_NUM_THREADS"] = "4"
os.environ["MKL_NUM_THREADS"] = "4"
os.environ["NUMEXPR_NUM_THREADS"] = "4"
os.environ["OPENBLAS_NUM_THREADS"] = "4"
os.environ["VECLIB_MAXIMUM_THREADS"] = "4"
torch.set_num_threads(4)  # PyTorch intra-op thread pool
# Device handles: prefer CUDA when available, and always keep a CPU handle
# so both devices can be benchmarked side by side.
device_cuda = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device_cpu = torch.device("cpu")
print(f"当前CUDA设备: {device_cuda}, CPU设备: {device_cpu}")
print(f"CPU核心数设置: {torch.get_num_threads()}")
class ImageClassifier:
    """Runs the same checkpoint on CUDA (when present) and on a 4-thread CPU,
    so per-request latency can be compared between the two devices."""

    def __init__(self, model_path: str):
        self.processor = AutoImageProcessor.from_pretrained(model_path)
        # One model instance per device; the CUDA copy is skipped entirely
        # when no GPU is visible.
        if device_cuda.type == "cuda":
            self.model_cuda = AutoModelForImageClassification.from_pretrained(model_path).to(device_cuda)
        else:
            self.model_cuda = None
        self.model_cpu = AutoModelForImageClassification.from_pretrained(model_path).to(device_cpu)
        # Label mapping taken from the (always present) CPU model.
        self.id2label = self.model_cpu.config.id2label

    def _predict_with_model(self, image, model, device) -> dict:
        """Run one forward pass on ``device`` and time it with perf_counter."""
        try:
            tic = time.perf_counter()  # high-resolution wall clock
            batch = self.processor(images=image, return_tensors="pt").to(device)
            with torch.no_grad():
                out = model(**batch)
            scores = torch.nn.functional.softmax(out.logits, dim=1)
            best_score, best_idx = scores.max(dim=1)
            label_id = best_idx.item()
            elapsed = round(time.perf_counter() - tic, 6)
            return {
                "class_id": label_id,
                "class_name": self.id2label[label_id],
                "confidence": float(best_score.item()),
                "device_used": str(device),
                "processing_time": elapsed
            }
        except Exception as e:
            return {
                "class_id": -1,
                "class_name": "error",
                "confidence": 0.0,
                "device_used": str(device),
                "processing_time": 0.0,
                "error": str(e)
            }

    def predict_single_image(self, image) -> dict:
        """Predict one image on both devices; always reports both results."""
        if self.model_cuda is None:
            cuda_result = {
                "class_id": -1,
                "class_name": "error",
                "confidence": 0.0,
                "device_used": str(device_cuda),
                "processing_time": 0.0,
                "error": "CUDA设备不可用未加载CUDA模型"
            }
        else:
            cuda_result = self._predict_with_model(image, self.model_cuda, device_cuda)
        return {
            "status": "success",
            "cuda_prediction": cuda_result,
            "cpu_prediction": self._predict_with_model(image, self.model_cpu, device_cpu)
        }
# Service bootstrap: one Flask app and one shared classifier instance,
# loaded once at import time.
app = Flask(__name__)
MODEL_PATH = os.environ.get("MODEL_PATH", "/model")  # model path (env var or default)
classifier = ImageClassifier(MODEL_PATH)
@app.route('/v1/private/s782b4996', methods=['POST'])
def predict_single():
    """Accept one uploaded image and return per-device predictions with timings.

    Responds 400 when no ``image`` file is attached, 500 on processing errors;
    both error responses carry the same two-device payload shape as success.
    """
    def _error_payload(message: str):
        # Fix: the same error dict was hand-copied four times in the original;
        # build the shared shape once per device slot.
        def slot(device):
            return {
                "class_id": -1,
                "class_name": "error",
                "confidence": 0.0,
                "device_used": str(device),
                "processing_time": 0.0,
                "error": message
            }
        return {
            "status": "error",
            "cuda_prediction": slot(device_cuda),
            "cpu_prediction": slot(device_cpu)
        }
    if 'image' not in request.files:
        return jsonify(_error_payload("请求中未包含图片")), 400
    image_file = request.files['image']
    try:
        image = Image.open(BytesIO(image_file.read())).convert("RGB")
        result = classifier.predict_single_image(image)
        return jsonify(result)
    except Exception as e:
        return jsonify(_error_payload(str(e))), 500
@app.route('/health', methods=['GET'])
def health_check():
    """Liveness probe: reports device availability and the CPU thread budget."""
    payload = {
        "status": "healthy",
        "cuda_available": device_cuda.type == "cuda",
        "cuda_device": str(device_cuda),
        "cpu_device": str(device_cpu),
        "cpu_threads": torch.get_num_threads()
    }
    return jsonify(payload), 200
if __name__ == "__main__":
    # Bind on all interfaces; port 80 matches the container's exposed port.
    app.run(host='0.0.0.0', port=80, debug=False)

View File

@@ -0,0 +1,163 @@
import requests
import json
import torch
from PIL import Image
from io import BytesIO
from transformers import AutoImageProcessor, AutoModelForImageClassification
from tqdm import tqdm
import os
import random
import time
from flask import Flask, request, jsonify # 引入Flask
# 设备配置
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"当前使用的设备: {device}")
class COCOImageClassifier:
    """Image classifier + evaluator used by the Flask evaluation service."""

    def __init__(self, model_path: str):
        """Load processor/model once; image paths are supplied per request."""
        self.processor = AutoImageProcessor.from_pretrained(model_path)
        self.model = AutoModelForImageClassification.from_pretrained(model_path)
        self.model = self.model.to(device)
        if torch.cuda.device_count() > 1:
            print(f"使用 {torch.cuda.device_count()} 块GPU")
            self.model = torch.nn.DataParallel(self.model)
        # DataParallel wraps the model; reach through .module when present.
        self.id2label = self.model.module.config.id2label if hasattr(self.model, 'module') else self.model.config.id2label

    def predict_image_path(self, image_path: str, top_k: int = 5) -> dict:
        """Classify one image file; returns None when it cannot be processed."""
        try:
            image = Image.open(image_path).convert("RGB")
            inputs = self.processor(images=image, return_tensors="pt").to(device)
            with torch.no_grad():
                outputs = self.model(**inputs)
            logits = outputs.logits
            probs = torch.nn.functional.softmax(logits, dim=1)
            top_probs, top_indices = probs.topk(top_k, dim=1)
            predictions = []
            for i in range(top_k):
                class_idx = top_indices[0, i].item()
                predictions.append({
                    "class_id": class_idx,
                    "class_name": self.id2label[class_idx],
                    "confidence": top_probs[0, i].item()
                })
            return {
                "image_path": image_path,
                "predictions": predictions
            }
        except Exception as e:
            print(f"处理图片 {image_path} 出错: {e}")
            return None

    def batch_predict_and_evaluate(self, image_paths: list, true_labels: dict, top_k: int = 3) -> dict:
        """Predict ``image_paths`` and compute accuracy and mean per-class recall."""
        results = []
        start_time = time.time()
        for image_path in tqdm(image_paths):
            result = self.predict_image_path(image_path, top_k)
            if result:
                results.append(result)
        total_time = time.time() - start_time
        images_per_second = len(results) / total_time if total_time > 0 else 0
        correct_count = 0
        total_count = len(results)
        class_actual_count = {}
        class_correct_count = {}
        for prediction in results:
            image_path = prediction['image_path']
            top1_prediction = max(prediction['predictions'], key=lambda x: x['confidence'])
            predicted_class = top1_prediction['class_name'].lower()
            true_class = true_labels.get(image_path, "").lower()
            class_actual_count[true_class] = class_actual_count.get(true_class, 0) + 1
            # Fix: an empty true label is a substring of every word, so
            # unlabeled images were previously always counted as correct.
            # Require a non-empty label before substring matching.
            if true_class:
                for word in predicted_class.split():
                    if true_class in word:
                        correct_count += 1
                        class_correct_count[true_class] = class_correct_count.get(true_class, 0) + 1
                        break
        accuracy = correct_count / total_count if total_count > 0 else 0
        recall_per_class = {}
        for class_name in class_actual_count:
            recall_per_class[class_name] = class_correct_count.get(class_name, 0) / class_actual_count[class_name]
        average_recall = sum(recall_per_class.values()) / len(recall_per_class) if recall_per_class else 0
        return {
            "status": "success",
            "metrics": {
                "accuracy": round(accuracy * 100, 2),
                "average_recall": round(average_recall * 100, 2),
                "total_images": total_count,
                "correct_predictions": correct_count,
                "speed_images_per_second": round(images_per_second, 2)
            },
            "sample_predictions": results[:3]
        }
# Service bootstrap: Flask app plus container paths injected via environment.
app = Flask(__name__)
MODEL_PATH = os.environ.get("MODEL_PATH", "/model")  # model path inside the container
DATASET_PATH = os.environ.get("DATASET_PATH", "/app/dataset")  # dataset path inside the container
classifier = COCOImageClassifier(MODEL_PATH)
@app.route('/v1/private/s782b4996', methods=['POST'])
def evaluate():
    """Run the evaluation over a random dataset sample and return metrics."""
    try:
        # Fix: get_json() returns None for a missing/non-JSON body, which made
        # data.get(...) raise AttributeError; fall back to {} so defaults apply.
        data = request.get_json(silent=True) or {}
        limit = data.get("limit", 20)  # number of images to evaluate
        # Load the dataset (paths inside the container).
        local_image_paths = []
        true_labels = {}
        for folder in os.listdir(DATASET_PATH):
            folder_path = os.path.join(DATASET_PATH, folder)
            if os.path.isdir(folder_path):
                # Folders are "NNN.class_name"; fix: tolerate names without a
                # dot (original raised IndexError on them).
                parts = folder.split('.', 1)
                class_name = parts[1] if len(parts) > 1 else parts[0]
                image_files = [os.path.join(folder_path, f) for f in os.listdir(folder_path) if f.endswith(('.jpg', '.jpeg', '.png'))]
                selected_images = random.sample(image_files, min(3, len(image_files)))
                for image_path in selected_images:
                    local_image_paths.append(image_path)
                    true_labels[image_path] = class_name
        # Cap the number of processed images.
        local_image_paths = local_image_paths[:limit]
        result = classifier.batch_predict_and_evaluate(local_image_paths, true_labels, top_k=3)
        return jsonify(result)
    except Exception as e:
        return jsonify({
            "status": "error",
            "message": str(e)
        }), 500
@app.route('/health', methods=['GET'])
def health_check():
    """Liveness probe exposing the active torch device."""
    body = {"status": "healthy", "device": str(device)}
    return jsonify(body), 200
if __name__ == "__main__":
    # Development entry point; production deployments should front this with a WSGI server.
    app.run(host='0.0.0.0', port=8000, debug=False)

View File

@@ -0,0 +1,89 @@
import torch
from PIL import Image
from transformers import AutoImageProcessor, AutoModelForImageClassification
import os
from flask import Flask, request, jsonify
from io import BytesIO
# 设备配置
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"当前使用的设备: {device}")
class ImageClassifier:
    """Classifier whose checkpoint lives in a subdirectory of ``model_path``."""

    def __init__(self, model_path: str):
        """Locate the model subdirectory under ``model_path`` and load it.

        Raises:
            ValueError: when ``model_path`` contains no subdirectory.
        """
        subdirs = [d for d in os.listdir(model_path) if os.path.isdir(os.path.join(model_path, d))]
        if not subdirs:
            raise ValueError(f"{model_path} 下未找到任何子目录,无法加载模型")
        # Fix: os.listdir order is arbitrary, so the chosen model was
        # nondeterministic across hosts; sort for a stable choice.
        actual_model_path = os.path.join(model_path, sorted(subdirs)[0])
        print(f"加载模型从: {actual_model_path}")
        self.processor = AutoImageProcessor.from_pretrained(actual_model_path)
        self.model = AutoModelForImageClassification.from_pretrained(actual_model_path)
        self.model = self.model.to(device)
        if torch.cuda.device_count() > 1:
            print(f"使用 {torch.cuda.device_count()} 块GPU")
            self.model = torch.nn.DataParallel(self.model)
        # DataParallel wraps the model; reach through .module when present.
        self.id2label = self.model.module.config.id2label if hasattr(self.model, 'module') else self.model.config.id2label

    def predict_single_image(self, image) -> dict:
        """Predict one PIL image and return the single best class."""
        try:
            inputs = self.processor(images=image, return_tensors="pt").to(device)
            with torch.no_grad():
                outputs = self.model(**inputs)
            logits = outputs.logits
            probs = torch.nn.functional.softmax(logits, dim=1)
            max_prob, max_idx = probs.max(dim=1)
            class_idx = max_idx.item()
            return {
                "status": "success",
                "top_prediction": {
                    "class_id": class_idx,
                    "class_name": self.id2label[class_idx],
                    "confidence": max_prob.item()
                }
            }
        except Exception as e:
            return {
                "status": "error",
                "message": str(e)
            }
# Service bootstrap: Flask app + shared classifier loaded once at import time.
app = Flask(__name__)
MODEL_PATH = os.environ.get("MODEL_PATH", "/model")  # model root (env var or default)
classifier = ImageClassifier(MODEL_PATH)
@app.route('/v1/private/s782b4996', methods=['POST'])
def predict_single():
    """Accept one uploaded image and return the top-1 prediction."""
    uploaded = request.files.get('image')
    if uploaded is None:
        return jsonify({"status": "error", "message": "请求中未包含图片"}), 400
    try:
        pil_image = Image.open(BytesIO(uploaded.read())).convert("RGB")
        return jsonify(classifier.predict_single_image(pil_image))
    except Exception as e:
        return jsonify({"status": "error", "message": str(e)}), 500
@app.route('/health', methods=['GET'])
def health_check():
    """Liveness probe exposing the active torch device."""
    payload = {"status": "healthy", "device": str(device)}
    return jsonify(payload), 200
if __name__ == "__main__":
    # Development entry point on port 8000; front with a WSGI server in production.
    app.run(host='0.0.0.0', port=8000, debug=False)

View File

@@ -0,0 +1,80 @@
import torch
from PIL import Image
from transformers import AutoImageProcessor, AutoModelForImageClassification
import os
from flask import Flask, request, jsonify
from io import BytesIO
# 设备配置
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"当前使用的设备: {device}")
class ImageClassifier:
    """Wraps a HuggingFace image-classification checkpoint for single-image inference."""

    def __init__(self, model_path: str):
        self.processor = AutoImageProcessor.from_pretrained(model_path)
        self.model = AutoModelForImageClassification.from_pretrained(model_path)
        self.model = self.model.to(device)
        if torch.cuda.device_count() > 1:
            print(f"使用 {torch.cuda.device_count()} 块GPU")
            self.model = torch.nn.DataParallel(self.model)
        # DataParallel hides config behind .module; unwrap when necessary.
        wrapped = getattr(self.model, 'module', self.model)
        self.id2label = wrapped.config.id2label

    def predict_single_image(self, image) -> dict:
        """Return the single highest-confidence class for one PIL image."""
        try:
            batch = self.processor(images=image, return_tensors="pt").to(device)
            with torch.no_grad():
                logits = self.model(**batch).logits
            scores = torch.nn.functional.softmax(logits, dim=1)
            best_score, best_idx = scores.max(dim=1)
            label_id = best_idx.item()
            return {
                "status": "success",
                "top_prediction": {
                    "class_id": label_id,
                    "class_name": self.id2label[label_id],
                    "confidence": best_score.item()
                }
            }
        except Exception as e:
            return {
                "status": "error",
                "message": str(e)
            }
# Service bootstrap: Flask app and one shared classifier instance.
app = Flask(__name__)
MODEL_PATH = os.environ.get("MODEL_PATH", "/model")  # model path (env var or default)
classifier = ImageClassifier(MODEL_PATH)
@app.route('/v1/private/s782b4996', methods=['POST'])
def predict_single():
    """Accept one uploaded image and return the top-1 prediction."""
    uploaded = request.files.get('image')
    if uploaded is None:
        return jsonify({"status": "error", "message": "请求中未包含图片"}), 400
    try:
        pil_image = Image.open(BytesIO(uploaded.read())).convert("RGB")
        return jsonify(classifier.predict_single_image(pil_image))
    except Exception as e:
        return jsonify({"status": "error", "message": str(e)}), 500
@app.route('/health', methods=['GET'])
def health_check():
    """Liveness probe exposing the active torch device."""
    body = {"status": "healthy", "device": str(device)}
    return jsonify(body), 200
if __name__ == "__main__":
    # Bind on all interfaces; port 80 matches the container's exposed port.
    app.run(host='0.0.0.0', port=80, debug=False)

24
pyproject.toml Normal file
View File

@@ -0,0 +1,24 @@
[tool.black]
line-length = 80
target-version = ['py39']
[tool.flake8]
max-line-length = 88
count=true
per-file-ignores="./annotation/manager.py:F401"
exclude=["./label", "__pycache__", "./migrations", "./logs", "./pids", "./resources"]
ignore=["W503", "E203"]
enable-extensions="G"
application-import-names=["flake8-isort", "flake8-logging-format", "flake8-builtins"]
import-order-style="edited"
extend-ignore = ["E203", "E701"]
[tool.isort]
py_version=39
profile="black"
multi_line_output=9
line_length=80
group_by_package=true
case_sensitive=true
skip_gitignore=true

13
requirements.txt Normal file
View File

@@ -0,0 +1,13 @@
requests
ruamel.yaml
regex
pyyaml
websocket-client==0.44.0
pydantic==2.6.4
pydantic_core==2.16.3
Levenshtein
numpy
websockets
fabric
vmplatform==0.0.4
flask

114
run.py Normal file
View File

@@ -0,0 +1,114 @@
import gc
import json
import os
import sys
import time
import zipfile
import yaml
from schemas.context import ASRContext
from utils.client import Client
from utils.evaluator import BaseEvaluator
from utils.logger import logger
from utils.service import register_sut
IN_TEST = os.getenv("SUBMIT_CONFIG_FILEPATH", None) is None
UNIT_TEST = os.getenv("UNIT_TEST", 0)
def main():
    """Driver for the ASR leaderboard job.

    Reads all I/O locations from the environment (with test-friendly
    defaults), deploys or locates the system-under-test (SUT), then — in the
    currently disabled section below — would run the dataset evaluation.
    """
    logger.info("执行……")
    # All file locations come from the environment with local-test defaults.
    dataset_filepath = os.getenv(
        "DATASET_FILEPATH",
        "./tests/resources/en.zip",
    )
    submit_config_filepath = os.getenv("SUBMIT_CONFIG_FILEPATH", "./tests/resources/submit_config")
    result_filepath = os.getenv("RESULT_FILEPATH", "./out/result")
    bad_cases_filepath = os.getenv("BAD_CASES_FILEPATH", "./out/badcase")
    detail_cases_filepath = os.getenv("DETAILED_CASES_FILEPATH", "./out/detailcase.jsonl")
    resource_name = os.getenv("BENCHMARK_NAME")
    # Submit config & start the system under test.
    if os.getenv("DATASET_FILEPATH", ""):
        from utils.helm import resource_check
        with open(submit_config_filepath, "r") as fp:
            st_config = yaml.safe_load(fp)
        st_config["values"] = resource_check(st_config.get("values", {}))
        if 'docker_images' in st_config:
            # NOTE(review): hard-coded endpoint — presumably a debug/test path; confirm.
            sut_url = "ws://172.26.1.75:9827"
            os.environ['test'] = '1'
        elif 'docker_image' in st_config:
            sut_url = register_sut(st_config, resource_name)
        elif UNIT_TEST:
            sut_url = "ws://172.27.231.36:80"
        else:
            logger.error("config 配置错误,没有 docker_image")
            os._exit(1)
    else:
        # No dataset configured: local test mode with a fixed SUT endpoint.
        os.environ['test'] = '1'
        sut_url = "ws://172.27.231.36:80"
    if UNIT_TEST:
        exit(0)
    # The evaluation pipeline below is deliberately disabled (string literal).
    # NOTE(review): inside it, logger.info("执行完成. Result = {output_result}")
    # is missing the f-prefix — fix when re-enabling.
    """
    # 数据集处理
    local_dataset_path = "./dataset"
    os.makedirs(local_dataset_path, exist_ok=True)
    with zipfile.ZipFile(dataset_filepath) as zf:
        zf.extractall(local_dataset_path)
    config_path = os.path.join(local_dataset_path, "data.yaml")
    with open(config_path, "r") as fp:
        dataset_config = yaml.safe_load(fp)
    # 数据集信息
    dataset_global_config = dataset_config.get("global", {})
    dataset_query = dataset_config.get("query_data", {})
    evaluator = BaseEvaluator()
    # 开始预测
    for idx, query_item in enumerate(dataset_query):
        gc.collect()
        logger.info(f"开始执行 {idx} 条数据")
        context = ASRContext(**dataset_global_config)
        context.lang = query_item.get("lang", context.lang)
        context.file_path = os.path.join(local_dataset_path, query_item["file"])
        # context.audio_length = query_item["audio_length"]
        interactions = Client(sut_url, context).action()
        context.append_labels(query_item["voice"])
        context.append_preds(
            interactions["predict_data"],
            interactions["send_time"],
            interactions["recv_time"],
        )
        context.fail = interactions["fail"]
        if IN_TEST:
            with open('output.txt', 'w') as fp:
                original_stdout = sys.stdout
                sys.stdout = fp
                print(context)
                sys.stdout = original_stdout
        evaluator.evaluate(context)
        detail_case = evaluator.gen_detail_case()
        with open(detail_cases_filepath, "a") as fp:
            fp.write(json.dumps(detail_case.to_dict(), ensure_ascii=False) + "\n")
        time.sleep(4)
    evaluator.post_evaluate()
    output_result = evaluator.gen_result()
    # print(evaluator.__dict__)
    logger.info("执行完成. Result = {output_result}")
    with open(result_filepath, "w") as fp:
        json.dump(output_result, fp, indent=2, ensure_ascii=False)
    with open(bad_cases_filepath, "w") as fp:
        fp.write("当前榜单不存在 Bad Case\n")
    """
if __name__ == "__main__":
    # Script entry point.
    main()

757
run_async_a10.py Normal file
View File

@@ -0,0 +1,757 @@
import atexit
import concurrent.futures
import fcntl
import gc
import glob
import json
import os
import random
import signal
import sys
import tempfile
import threading
import time
import zipfile
from concurrent.futures import ThreadPoolExecutor
import yaml
from fabric import Connection
from vmplatform import VMOS, Client, VMDataDisk
from schemas.context import ASRContext
from utils.client_async import ClientAsync
from utils.evaluator import BaseEvaluator
from utils.logger import logger
from utils.service import register_sut
# IN_TEST: true when no submit config is provided (local test mode).
IN_TEST = os.getenv("SUBMIT_CONFIG_FILEPATH", None) is None
UNIT_TEST = os.getenv("UNIT_TEST", 0)
DATASET_NUM = os.getenv("DATASET_NUM")
# VM leaderboard parameters (SUT = system under test).
SUT_TYPE = os.getenv("SUT_TYPE", "kubernetes")
SHARE_SUT = os.getenv("SHARE_SUT", "true") == "true"
VM_ID = 0  # set once a VM is created; 0 means "no VM owned by this process"
VM_IP = ""
do_deploy_chart = True  # False when another job already deployed the shared SUT
VM_CPU = int(os.getenv("VM_CPU", "2"))
VM_MEM = int(os.getenv("VM_MEM", "4096"))
MODEL_BASEPATH = os.getenv("MODEL_BASEPATH", "/tmp/customer/leaderboard/pc_asr")
MODEL_MAPPING = json.loads(os.getenv("MODEL_MAPPING", "{}"))
# SSH key material generated at image build time (see Dockerfile).
SSH_KEY_DIR = os.getenv("SSH_KEY_DIR", "/workspace")
SSH_PUBLIC_KEY_FILE = os.path.join(SSH_KEY_DIR, "ssh-key-ecdsa.pub")
SSH_KEY_FILE = os.path.join(SSH_KEY_DIR, "ssh-key-ecdsa")
CONNECT_KWARGS = {"key_filename": SSH_KEY_FILE}
# Shared-SUT coordination files (cross-process lock/status files under /tmp).
JOB_ID = os.getenv("JOB_ID")
dirname = "/tmp/submit_private/sut_share"
os.makedirs(dirname, exist_ok=True)
SUT_SHARE_LOCK = os.path.join(dirname, "lock.lock")
SUT_SHARE_USE_LOCK = os.path.join(dirname, "use.lock")
SUT_SHARE_STATUS = os.path.join(dirname, "status.json")
SUT_SHARE_JOB_STATUS = os.path.join(dirname, f"job_status.{JOB_ID}")
SUT_SHARE_PUBLIC_FAIL = os.path.join(dirname, "one_job_failed")
# Kept open for the lifetime of the process so file locks stay held.
fd_lock = open(SUT_SHARE_USE_LOCK, "a")
def clean_vm_atexit():
    """atexit hook: delete the VM this process created.

    Does nothing when no VM is owned (VM_ID falsy) or when this process did
    not deploy the chart itself (shared SUT owned by another job).
    """
    global VM_ID, do_deploy_chart
    if not VM_ID or not do_deploy_chart:
        return
    logger.info("删除vm")
    err_msg = Client().delete_vm(VM_ID)
    if err_msg:
        logger.warning(f"删除vm失败: {err_msg}")
def put_file_to_vm(c: Connection, local_path: str, remote_path: str):
    """Upload ``local_path`` to ``remote_path`` on the VM over the fabric connection."""
    logger.info(f"uploading file {local_path} to {remote_path}")
    transfer = c.put(local_path, remote_path)
    logger.info("uploaded {0.local} to {0.remote}".format(transfer))
def deploy_windows_sut():
    """Create a Windows-10 VM, push the submitted model/config onto it, and
    launch the ASR server in a daemon thread; returns the SUT websocket URL.

    NOTE(review): this function was reconstructed from whitespace-mangled
    source; the placement of the upload block relative to ``sut_startup``
    mirrors the macOS variant — confirm against the original file.
    """
    global VM_ID
    global VM_IP
    # Parse the submitted config and validate required fields.
    submit_config_filepath = os.getenv("SUBMIT_CONFIG_FILEPATH", "")
    with open(submit_config_filepath, "r") as fp:
        st_config = yaml.safe_load(fp)
    assert "model" in st_config, "未配置model"
    assert "model_key" in st_config, "未配置model_key"
    assert "config.json" in st_config, "未配置config.json"
    nfs = st_config.get("leaderboard_options", {}).get("nfs", [])
    assert len(nfs) > 0, "未配置nfs"
    assert st_config["model"] in MODEL_MAPPING, "提交模型不在可用模型范围内"
    model = st_config["model"]
    model_key = st_config["model_key"]
    model_path = ""
    config = st_config["config.json"]
    # Resolve the model archive path from the matching nfs mount entry.
    exist = False
    for nfs_item in nfs:
        if nfs_item["name"] == model_key:
            exist = True
            if nfs_item["source"] == "ceph_customer":
                model_path = os.path.join(
                    "/tmp/customer",
                    nfs_item["srcRelativePath"],
                )
            else:
                model_path = os.path.join(
                    "/tmp/juicefs",
                    nfs_item["srcRelativePath"],
                )
            break
    if not exist:
        raise RuntimeError(f"未找到nfs配置项 name={model_key}")
    # Write the per-submit config.json with the model path as seen from the VM.
    config_path = os.path.join(tempfile.mkdtemp(), "config.json")
    model_dir = os.path.basename(model_path).split(".")[0]
    config["model_path"] = f"E:\\model\\{model_dir}"
    with open(config_path, "w") as fp:
        json.dump(config, fp, ensure_ascii=False, indent=4)
    # Create the VM and register cleanup handlers.
    vmclient = Client()
    with open(SSH_PUBLIC_KEY_FILE, "r") as fp:
        sshpublickey = fp.read().rstrip()
    VM_ID = vmclient.create_vm(
        "amd64",
        VMOS.windows10,
        VM_CPU,
        VM_MEM,
        "leaderboard-%s-submit-%s-job-%s"
        % (
            os.getenv("BENCHMARK_NAME"),
            os.getenv("SUBMIT_ID"),
            os.getenv("JOB_ID"),
        ),
        sshpublickey,
        datadisks=[
            VMDataDisk(
                size=50,
                disk_type="ssd",
                mount_path="/",
                filesystem="NTFS",
            )
        ],
    )
    atexit.register(clean_vm_atexit)
    signal.signal(signal.SIGTERM, lambda signum, _: sys.exit(signum))
    VM_IP = vmclient.wait_until_vm_running(VM_ID)
    logger.info("vm created successfully, vm_ip: %s", VM_IP)
    def sut_startup():
        # Launch the ASR server on the VM (runs in the daemon thread below).
        with Connection(
            VM_IP,
            "administrator",
            connect_kwargs=CONNECT_KWARGS,
        ) as c:
            script_path = "E:\\base\\asr\\faster-whisper\\server"
            # NOTE(review): the assignment above is immediately overwritten —
            # dead code, presumably a leftover from the faster-whisper setup.
            script_path = "E:\\install\\asr\\sensevoice\\server"
            bat_filepath = f"{script_path}\\start.bat"
            config_filepath = "E:\\submit\\config.json"
            # NOTE(review): empty command — looks like a connectivity probe; confirm.
            result = c.run("")
            assert result.ok
            c.run(
                f'cd /d {script_path} & set "EDGE_ML_ENV_HOME=E:\\install" & {bat_filepath} {config_filepath}',
                warn=True,
            )
    # Upload base package, model archive and config to the VM.
    with Connection(
        VM_IP,
        "administrator",
        connect_kwargs=CONNECT_KWARGS,
    ) as c:
        model_filepath = os.path.join(MODEL_BASEPATH, MODEL_MAPPING[model])
        filename = os.path.basename(model_filepath)
        put_file_to_vm(c, model_filepath, "/E:/")
        result = c.run("mkdir E:\\base")
        assert result.ok
        result = c.run("mkdir E:\\model")
        assert result.ok
        result = c.run("mkdir E:\\submit")
        assert result.ok
        # NOTE(review): "(unknown)" looks like a corrupted token — almost
        # certainly the uploaded archive name (e.g. {filename}); restore from
        # the original source before use.
        result = c.run(
            f"tar zxvf E:\\(unknown) -C E:\\base --strip-components 1"
        )
        assert result.ok
        result = c.run("E:\\base\\setup-win.bat E:\\install")
        assert result.ok
        put_file_to_vm(c, config_path, "/E:/submit")
        put_file_to_vm(c, model_path, "/E:/model")
        result = c.run(
            f"tar zxvf E:\\model\\{os.path.basename(model_path)} -C E:\\model"
        )
        assert result.ok
    threading.Thread(target=sut_startup, daemon=True).start()
    time.sleep(60)  # crude wait for the server to come up
    return f"ws://{VM_IP}:{config['port']}"
def deploy_macos_sut():
    """Create a macOS-12 VM, push the submitted model/config onto it, and
    launch the ASR server in a daemon thread; returns the SUT websocket URL.

    NOTE(review): reconstructed from whitespace-mangled source; block nesting
    (first Connection block, config write, sut_startup, upload block) should
    be confirmed against the original file.
    """
    global VM_ID
    global VM_IP
    # Parse the submitted config and validate required fields.
    submit_config_filepath = os.getenv("SUBMIT_CONFIG_FILEPATH", "")
    with open(submit_config_filepath, "r") as fp:
        st_config = yaml.safe_load(fp)
    assert "model" in st_config, "未配置model"
    assert "model_key" in st_config, "未配置model_key"
    assert "config.json" in st_config, "未配置config.json"
    nfs = st_config.get("leaderboard_options", {}).get("nfs", [])
    assert len(nfs) > 0, "未配置nfs"
    assert st_config["model"] in MODEL_MAPPING, "提交模型不在可用模型范围内"
    model = st_config["model"]
    model_key = st_config["model_key"]
    model_path = ""
    config = st_config["config.json"]
    # Resolve the model archive path from the matching nfs mount entry.
    exist = False
    for nfs_item in nfs:
        if nfs_item["name"] == model_key:
            exist = True
            if nfs_item["source"] == "ceph_customer":
                model_path = os.path.join(
                    "/tmp/customer",
                    nfs_item["srcRelativePath"],
                )
            else:
                model_path = os.path.join(
                    "/tmp/juicefs",
                    nfs_item["srcRelativePath"],
                )
            break
    if not exist:
        raise RuntimeError(f"未找到nfs配置项 name={model_key}")
    config_path = os.path.join(tempfile.mkdtemp(), "config.json")
    model_dir = os.path.basename(model_path).split(".")[0]
    # Create the VM and register cleanup handlers.
    vmclient = Client()
    with open(SSH_PUBLIC_KEY_FILE, "r") as fp:
        sshpublickey = fp.read().rstrip()
    VM_ID = vmclient.create_vm(
        "amd64",
        VMOS.macos12,
        VM_CPU,
        VM_MEM,
        "leaderboard-%s-submit-%s-job-%s"
        % (
            os.getenv("BENCHMARK_NAME"),
            os.getenv("SUBMIT_ID"),
            os.getenv("JOB_ID"),
        ),
        sshpublickey,
        datadisks=[
            VMDataDisk(
                size=50,
                disk_type="ssd",
                mount_path="/",
                filesystem="apfs",
            )
        ],
    )
    atexit.register(clean_vm_atexit)
    signal.signal(signal.SIGTERM, lambda signum, _: sys.exit(signum))
    VM_IP = vmclient.wait_until_vm_running(VM_ID)
    logger.info("vm created successfully, vm_ip: %s", VM_IP)
    # Discover the data volume mount point on the VM.
    with Connection(
        VM_IP,
        "admin",
        connect_kwargs=CONNECT_KWARGS,
    ) as c:
        result = c.run("ls -d /Volumes/data*")
        assert result.ok
        volume_path = result.stdout.strip()
    # Write the per-submit config.json with the model path as seen from the VM.
    config["model_path"] = f"{volume_path}/model/{model_dir}"
    with open(config_path, "w") as fp:
        json.dump(config, fp, ensure_ascii=False, indent=4)
    def sut_startup():
        # Launch the ASR server on the VM (runs in the daemon thread below).
        with Connection(
            VM_IP,
            "admin",
            connect_kwargs=CONNECT_KWARGS,
        ) as c:
            script_path = f"{volume_path}/install/asr/sensevoice/server"
            startsh = f"{script_path}/start.sh"
            config_filepath = f"{volume_path}/submit/config.json"
            c.run(
                f"cd {script_path} && sh {startsh} {config_filepath}",
                warn=True,
            )
    # Upload base package, model archive and config to the VM.
    with Connection(
        VM_IP,
        "admin",
        connect_kwargs=CONNECT_KWARGS,
    ) as c:
        model_filepath = os.path.join(MODEL_BASEPATH, MODEL_MAPPING[model])
        filename = os.path.basename(model_filepath)
        put_file_to_vm(c, model_filepath, f"{volume_path}")
        result = c.run(f"mkdir {volume_path}/base")
        assert result.ok
        result = c.run(f"mkdir {volume_path}/model")
        assert result.ok
        result = c.run(f"mkdir {volume_path}/submit")
        assert result.ok
        # NOTE(review): "(unknown)" looks like a corrupted token — almost
        # certainly the uploaded archive name (e.g. {filename}); restore from
        # the original source before use.
        result = c.run(
            f"tar zxvf {volume_path}/(unknown) -C {volume_path}/base --strip-components 1"  # noqa: E501
        )
        assert result.ok
        result = c.run(
            f"sh {volume_path}/base/setup-mac.sh {volume_path}/install x64"
        )
        assert result.ok
        put_file_to_vm(c, config_path, f"{volume_path}/submit")
        put_file_to_vm(c, model_path, f"{volume_path}/model")
        result = c.run(
            f"tar zxvf {volume_path}/model/{os.path.basename(model_path)} -C {volume_path}/model"  # noqa: E501
        )
        assert result.ok
    threading.Thread(target=sut_startup, daemon=True).start()
    time.sleep(60)  # crude wait for the server to come up
    return f"ws://{VM_IP}:{config['port']}"
def get_sut_url_vm(vm_type: str):
    """Obtain the SUT websocket URL for a VM-based deployment.

    When SHARE_SUT is enabled, multiple jobs coordinate through lock/status
    files so that exactly one job deploys the VM while the others wait and
    reuse its connection info (keys, IP, URL). The deployer also serializes
    access to the SUT via an exclusive flock on fd_lock.

    :param vm_type: "windows" or anything else (treated as macos).
    :return: the SUT websocket URL ("" only if the waiting loop times out
        without ever seeing a "running" status — NOTE(review): confirm this
        fallthrough is intended).
    """
    global VM_ID
    global VM_IP
    global do_deploy_chart
    do_deploy_chart = True

    # Bring up the SUT; abort this job if any sibling job has failed.
    def check_job_failed():
        while True:
            time.sleep(30)
            if os.path.exists(SUT_SHARE_PUBLIC_FAIL):
                logger.error("there is a job failed in current submit")
                sys.exit(1)

    sut_url = ""
    threading.Thread(target=check_job_failed, daemon=True).start()
    if SHARE_SUT:
        # Random jitter so concurrent jobs don't race on the lock file.
        time.sleep(10 * random.random())
        try:
            # O_EXCL-style create: first job to create the lock file deploys.
            open(SUT_SHARE_LOCK, "x").close()
        except Exception:
            do_deploy_chart = False
        start_at = time.time()

        def file_last_updated_at(file: str):
            # Treat a missing status file as "just started" to keep waiting.
            return os.stat(file).st_mtime if os.path.exists(file) else start_at

        if not do_deploy_chart:
            with open(SUT_SHARE_JOB_STATUS, "w") as f:
                f.write("waiting")
            # Wait up to 24h (measured from the status file's mtime) for the
            # deploying job to publish a "running" status.
            while (
                time.time() - file_last_updated_at(SUT_SHARE_STATUS)
                <= 60 * 60 * 24
            ):
                logger.info(
                    "Waiting sut application to be deployed by another job"
                )
                time.sleep(10 + random.random())
                if os.path.exists(SUT_SHARE_STATUS):
                    # The writer may be mid-write; retry the JSON parse.
                    get_status = False
                    for _ in range(10):
                        try:
                            with open(SUT_SHARE_STATUS, "r") as f:
                                status = json.load(f)
                            get_status = True
                            break
                        except Exception:
                            time.sleep(1 + random.random())
                            continue
                    if not get_status:
                        raise RuntimeError(
                            "Failed to get status of sut application"
                        )
                    assert (
                        status.get("status") != "failed"
                    ), "Failed to deploy sut application, \
please check other job logs"
                    if status.get("status") == "running":
                        # Adopt the deployer's VM identity and SSH key pair.
                        VM_ID = status.get("vmid")
                        VM_IP = status.get("vmip")
                        sut_url = status.get("sut_url")
                        with open(SSH_PUBLIC_KEY_FILE, "w") as fp:
                            fp.write(status.get("pubkey"))
                        with open(SSH_KEY_FILE, "w") as fp:
                            fp.write(status.get("prikey"))
                        logger.info("Successfully get deployed sut application")
                        break
    if do_deploy_chart:
        try:
            # Hold the exclusive SUT lock for the duration of the deployment.
            fcntl.flock(fd_lock, fcntl.LOCK_EX)
            with open(SUT_SHARE_JOB_STATUS, "w") as f:
                f.write("waiting")
            pending = True

            def update_status():
                # Heartbeat: keep publishing "pending" so waiters' 24h
                # mtime-based timeout does not expire during a long deploy.
                while pending:
                    time.sleep(30)
                    if not pending:
                        break
                    with open(SUT_SHARE_STATUS, "w") as f:
                        json.dump({"status": "pending"}, f)

            threading.Thread(target=update_status, daemon=True).start()
            if vm_type == "windows":
                sut_url = deploy_windows_sut()
            else:
                sut_url = deploy_macos_sut()
        except Exception:
            # Signal both the public failure marker and the shared status.
            open(SUT_SHARE_PUBLIC_FAIL, "w").close()
            with open(SUT_SHARE_STATUS, "w") as f:
                json.dump({"status": "failed"}, f)
            raise
        finally:
            pending = False
        # Publish full connection info so waiting jobs can attach.
        with open(SUT_SHARE_STATUS, "w") as f:
            pubkey = ""
            with open(SSH_PUBLIC_KEY_FILE, "r") as fp:
                pubkey = fp.read().rstrip()
            prikey = ""
            with open(SSH_KEY_FILE, "r") as fp:
                prikey = fp.read()
            json.dump(
                {
                    "status": "running",
                    "vmid": VM_ID,
                    "vmip": VM_IP,
                    "pubkey": pubkey,
                    "sut_url": sut_url,
                    "prikey": prikey,
                },
                f,
            )
    else:
        # Non-deployers: poll for the exclusive lock so only one job talks to
        # the shared SUT at a time.
        while True:
            time.sleep(5 + random.random())
            try:
                fcntl.flock(fd_lock, fcntl.LOCK_EX | fcntl.LOCK_NB)
                break
            except Exception:
                logger.info("尝试抢占调用sut失败,继续等待 5s ...")
    with open(SUT_SHARE_JOB_STATUS, "w") as f:
        f.write("running")
    return sut_url
def get_sut_url():
    """Resolve the SUT endpoint for this run.

    VM-backed SUT types ("windows"/"macos") are delegated to get_sut_url_vm;
    otherwise the submit config is loaded, resource limits are injected
    (CPU-only or vgpu depending on RESOURCE_TYPE) and the SUT chart is
    registered. Without DATASET_FILEPATH the run is a local test and a
    hard-coded endpoint is returned.
    """
    if SUT_TYPE in ("windows", "macos"):
        return get_sut_url_vm(SUT_TYPE)
    submit_config_filepath = os.getenv(
        "SUBMIT_CONFIG_FILEPATH", "./tests/resources/submit_config"
    )
    CPU = os.getenv("SUT_CPU", "2")
    MEMORY = os.getenv("SUT_MEMORY", "4Gi")
    resource_name = os.getenv("BENCHMARK_NAME")
    # Task info
    # Slavic languages: Russian, Polish
    # Germanic languages: English, German, Dutch
    # Romance (Latin) languages: Spanish, Portuguese, French, Italian
    # Semitic languages: Arabic, Hebrew
    # Submit config & start the system under test
    if os.getenv("DATASET_FILEPATH", ""):
        with open(submit_config_filepath, "r") as fp:
            st_config = yaml.safe_load(fp)
        if "values" not in st_config:
            st_config["values"] = {}
        # Overwrite any submitted resources with the judge-enforced quota.
        st_config["values"]["resources"] = {}
        st_config["values"]["resources"]["limits"] = {}
        st_config["values"]["resources"]["limits"]["cpu"] = CPU
        st_config["values"]["resources"]["limits"]["memory"] = MEMORY
        # st_config["values"]['resources']['limits']['nvidia.com/gpu'] = '1'
        # st_config["values"]['resources']['limits']['nvidia.com/gpumem'] = "1843"
        # st_config["values"]['resources']['limits']['nvidia.com/gpucores'] = "8"
        st_config["values"]["resources"]["requests"] = {}
        st_config["values"]["resources"]["requests"]["cpu"] = CPU
        st_config["values"]["resources"]["requests"]["memory"] = MEMORY
        # st_config["values"]['resources']['requests']['nvidia.com/gpu'] = '1'
        # st_config["values"]['resources']['requests']['nvidia.com/gpumem'] = "1843"
        # st_config["values"]['resources']['requests']['nvidia.com/gpucores'] = "8"
        # st_config['values']['nodeSelector'] = {}
        # st_config["values"]["nodeSelector"][
        #     "contest.4pd.io/accelerator"
        # ] = "A10vgpu"
        # st_config['values']['tolerations'] = []
        # toleration_item = {}
        # toleration_item['key'] = 'hosttype'
        # toleration_item['operator'] = 'Equal'
        # toleration_item['value'] = 'vgpu'
        # toleration_item['effect'] = 'NoSchedule'
        # st_config['values']['tolerations'].append(toleration_item)
        if os.getenv("RESOURCE_TYPE", "cpu") == "cpu":
            # CPU-only run: reject any GPU resource request in the submission.
            values = st_config["values"]
            limits = values.get("resources", {}).get("limits", {})
            requests = values.get("resources", {}).get("requests", {})
            if (
                "nvidia.com/gpu" in limits
                or "nvidia.com/gpumem" in limits
                or "nvidia.com/gpucores" in limits
                or "nvidia.com/gpu" in requests
                or "nvidia.com/gpumem" in requests
                or "nvidia.com/gpucores" in requests
            ):
                raise Exception("禁止使用GPU!")
        else:
            # vgpu run: grant SUT_VGPU virtual GPUs with proportional memory
            # (1843 MiB each) and cores (8% each), pinned to A10 vgpu nodes.
            vgpu_num = int(os.getenv("SUT_VGPU", "3"))
            st_config["values"]["resources"]["limits"]["nvidia.com/gpu"] = (
                str(vgpu_num)
            )
            st_config["values"]["resources"]["limits"][
                "nvidia.com/gpumem"
            ] = str(1843 * vgpu_num)
            st_config["values"]["resources"]["limits"][
                "nvidia.com/gpucores"
            ] = str(8 * vgpu_num)
            st_config["values"]["resources"]["requests"][
                "nvidia.com/gpu"
            ] = str(vgpu_num)
            st_config["values"]["resources"]["requests"][
                "nvidia.com/gpumem"
            ] = str(1843 * vgpu_num)
            st_config["values"]["resources"]["requests"][
                "nvidia.com/gpucores"
            ] = str(8 * vgpu_num)
            st_config["values"]["nodeSelector"] = {}
            st_config["values"]["nodeSelector"][
                "contest.4pd.io/accelerator"
            ] = "A10vgpu"
            st_config["values"]["tolerations"] = []
            toleration_item = {}
            toleration_item["key"] = "hosttype"
            toleration_item["operator"] = "Equal"
            toleration_item["value"] = "vgpu"
            toleration_item["effect"] = "NoSchedule"
            st_config["values"]["tolerations"].append(toleration_item)
        if "docker_images" in st_config:
            # Multi-image test mode: fixed endpoint, mark run as a test.
            sut_url = "ws://172.26.1.75:9827"
            os.environ["test"] = "1"
        elif "docker_image" in st_config:
            sut_url = register_sut(st_config, resource_name)
        elif UNIT_TEST:
            sut_url = "ws://172.27.231.36:80"
        else:
            logger.error("config 配置错误,没有 docker_image")
            os._exit(1)
        return sut_url
    else:
        # No dataset configured: local test mode with a fixed endpoint.
        os.environ["test"] = "1"
        # NOTE(review): the next assignment is immediately overwritten —
        # the first URL is dead code; confirm which endpoint is intended.
        sut_url = "ws://172.27.231.36:80"
        sut_url = "ws://172.26.1.75:9827"
        return sut_url
def load_merge_dataset(dataset_filepath: str) -> dict:
    """Extract a merged multi-language ASR dataset zip and build one query list.

    The outer zip contains per-language zips named "asr.<lang>", each holding a
    data.yaml plus audio files. Per language, query items get absolute audio
    paths, are sorted by audio file size (largest first) and truncated to
    roughly 30 minutes of audio. The per-language lists are then merged into
    config["query_data"] (each item tagged with its lang) and shuffled with a
    fixed seed for reproducibility.

    :param dataset_filepath: path to the merged dataset zip.
    :return: {lang: lang_config, ..., "query_data": [merged query items]}
    """
    local_dataset_path = "./dataset"
    os.makedirs(local_dataset_path, exist_ok=True)
    with zipfile.ZipFile(dataset_filepath) as zf:
        zf.extractall(local_dataset_path)
    config = {}
    sub_datasets = os.listdir(local_dataset_path)
    for sub_dataset in sub_datasets:
        if sub_dataset.startswith("asr."):
            # "asr.<lang>" → extract the language's own zip into ./dataset/<lang>
            lang = sub_dataset[4:]
            lang_path = os.path.join(local_dataset_path, lang)
            os.makedirs(lang_path, exist_ok=True)
            with zipfile.ZipFile(
                os.path.join(local_dataset_path, sub_dataset)
            ) as zf:
                zf.extractall(lang_path)
            lang_config_path = os.path.join(lang_path, "data.yaml")
            with open(lang_config_path, "r") as fp:
                lang_config = yaml.safe_load(fp)
            # Map resolved audio path -> file size in bytes (used as a proxy
            # for audio duration).
            audio_lengths = {}
            for query_item in lang_config.get("query_data", []):
                audio_path = os.path.join(
                    lang_path,
                    query_item["file"],
                )
                query_item["file"] = audio_path
                audio_lengths[query_item["file"]] = os.path.getsize(
                    audio_path,
                )
            lang_config["query_data"] = sorted(
                lang_config.get("query_data", []),
                key=lambda x: audio_lengths[x["file"]],
                reverse=True,
            )
            idx = 0
            length = 0.0
            for query_item in lang_config["query_data"]:
                audio_length = audio_lengths[query_item["file"]]
                # bytes / 32000 ≈ seconds for 16 kHz 16-bit mono PCM —
                # TODO confirm this holds for the dataset's audio format.
                length += audio_length / 32000
                idx += 1
                # Cap each language at half an hour of audio
                if length >= 30 * 60:
                    break
            lang_config["query_data"] = lang_config["query_data"][:idx]
            config[lang] = lang_config
    config["query_data"] = []
    for lang, lang_config in config.items():
        if lang == "query_data":
            continue
        for query_item in lang_config["query_data"]:
            config["query_data"].append(
                {
                    **query_item,
                    "lang": lang,
                }
            )
    # Fixed seed keeps the merged order deterministic across runs.
    random.Random(0).shuffle(config["query_data"])
    return config
def postprocess_failed():
    """Create/truncate the shared failure marker so sibling jobs abort."""
    with open(SUT_SHARE_PUBLIC_FAIL, "w"):
        pass
def main():
    """Run the ASR benchmark end to end.

    Loads the dataset (merged multi-language or single-language layout),
    resolves the SUT endpoint, streams every query through a thread pool,
    evaluates each result and writes result / badcase / detail-case files.
    In shared-SUT mode the deploying job additionally waits for all sibling
    jobs to report success before exiting.
    """
    dataset_filepath = os.getenv(
        "DATASET_FILEPATH",
        "/Users/4paradigm/Projects/dataset/asr/de.zip",
        # "./tests/resources/en.zip",
    )
    result_filepath = os.getenv("RESULT_FILEPATH", "./out/result")
    bad_cases_filepath = os.getenv("BAD_CASES_FILEPATH", "./out/badcase")
    detail_cases_filepath = os.getenv(
        "DETAILED_CASES_FILEPATH", "./out/detailcase.jsonl"
    )
    thread_num = int(os.getenv("THREAD_NUM", "1"))
    # Dataset handling
    config = {}
    if os.getenv("MERGE_DATASET", "1"):
        config = load_merge_dataset(dataset_filepath)
        dataset_query = config["query_data"]
    else:
        local_dataset_path = "./dataset"
        os.makedirs(local_dataset_path, exist_ok=True)
        with zipfile.ZipFile(dataset_filepath) as zf:
            zf.extractall(local_dataset_path)
        config_path = os.path.join(local_dataset_path, "data.yaml")
        with open(config_path, "r") as fp:
            dataset_config = yaml.safe_load(fp)
        # Resolve audio paths, then sort query_data by descending file size
        # (larger/longer clips are dispatched first).
        lang = os.getenv("lang")
        if lang is None:
            lang = dataset_config.get("global", {}).get("lang", "en")
        audio_lengths = []
        for query_item in dataset_config.get("query_data", []):
            query_item["lang"] = lang
            audio_path = os.path.join(local_dataset_path, query_item["file"])
            query_item["file"] = audio_path
            audio_lengths.append(os.path.getsize(audio_path) / 1024 / 1024)
        # NOTE(review): .index(x) makes this sort O(n^2) and assumes unique
        # items — works, but fragile for large/duplicated datasets.
        dataset_config["query_data"] = sorted(
            dataset_config.get("query_data", []),
            key=lambda x: audio_lengths[dataset_config["query_data"].index(x)],
            reverse=True,
        )
        # Dataset info
        # dataset_global_config = dataset_config.get("global", {})
        dataset_query = dataset_config.get("query_data", {})
        config[lang] = dataset_config
    # sut url
    sut_url = get_sut_url()
    try:
        # Start testing
        logger.info("开始执行")
        evaluator = BaseEvaluator()
        future_list = []
        with ThreadPoolExecutor(max_workers=thread_num) as executor:
            for idx, query_item in enumerate(dataset_query):
                # One ASRContext per query, seeded with the language's
                # global settings plus the per-item path and labels.
                context = ASRContext(
                    **config[query_item["lang"]].get("global", {}),
                )
                context.lang = query_item["lang"]
                context.file_path = query_item["file"]
                context.append_labels(query_item["voice"])
                future = executor.submit(
                    ClientAsync(sut_url, context, idx).action
                )
                future_list.append(future)
            for future in concurrent.futures.as_completed(future_list):
                context = future.result()
                evaluator.evaluate(context)
                detail_case = evaluator.gen_detail_case()
                # Append one JSON line per evaluated case.
                with open(detail_cases_filepath, "a") as fp:
                    fp.write(
                        json.dumps(
                            detail_case.to_dict(),
                            ensure_ascii=False,
                        )
                        + "\n",
                    )
                # Free per-case audio/label memory eagerly.
                del context
                gc.collect()
        evaluator.post_evaluate()
        output_result = evaluator.gen_result()
        logger.info("执行完成")
        with open(result_filepath, "w") as fp:
            json.dump(output_result, fp, indent=2, ensure_ascii=False)
        with open(bad_cases_filepath, "w") as fp:
            fp.write("当前榜单不存在 Bad Case\n")
        if SHARE_SUT:
            with open(SUT_SHARE_JOB_STATUS, "w") as f:
                f.write("success")
            # Release the SUT serialization lock for the next job.
            fcntl.flock(fd_lock, fcntl.LOCK_UN)
            fd_lock.close()
        # The deploying job lingers until every sibling job reports success
        # (DATASET_NUM of them), so the shared VM is not torn down early.
        while SHARE_SUT and do_deploy_chart:
            time.sleep(30)
            success_num = 0
            for job_status_file in glob.glob(dirname + "/job_status.*"):
                with open(job_status_file, "r") as f:
                    job_status = f.read()
                success_num += job_status == "success"
            if success_num == int(DATASET_NUM):
                break
            logger.info("Waiting for all jobs to complete")
    except Exception:
        if SHARE_SUT:
            postprocess_failed()
        raise
    sys.exit(0)
# Script entry point.
if __name__ == "__main__":
    main()

923
run_callback.py Normal file
View File

@@ -0,0 +1,923 @@
import json
import os
import sys
import time
import tempfile
import zipfile
import threading
from collections import defaultdict
from typing import Dict, List
import yaml
from pydantic import ValidationError
from schemas.dataset import QueryData
from utils.client_callback import ClientCallback, EvaluateResult, StopException
from utils.logger import log
from utils.service import register_sut
from utils.update_submit import change_product_available
from utils.file import dump_json, load_yaml, unzip_dir, load_json, write_file, dump_yaml
from utils.leaderboard import change_product_unavailable
# Lock guarding shared mutable state across threads.
lck = threading.Lock()
# Environment variables by leaderboard
DATASET_FILEPATH = os.environ["DATASET_FILEPATH"]
RESULT_FILEPATH = os.environ["RESULT_FILEPATH"]
DETAILED_CASES_FILEPATH = os.environ["DETAILED_CASES_FILEPATH"]
SUBMIT_CONFIG_FILEPATH = os.environ["SUBMIT_CONFIG_FILEPATH"]
BENCHMARK_NAME = os.environ["BENCHMARK_NAME"]
TEST_CONCURRENCY = int(os.getenv('TEST_CONCURRENCY', 1))
THRESHOLD_OMCER = float(os.getenv('THRESHOLD_OMCER', 0.8))
log.info(f"DATASET_FILEPATH: {DATASET_FILEPATH}")
workspace_path = "/tmp/workspace"
# Environment variables by kubernetes
MY_POD_IP = os.environ["MY_POD_IP"]
# constants
RESOURCE_NAME = BENCHMARK_NAME
# Environment variables by judge_flow_config
LANG = os.getenv("lang")
SUT_CPU = os.getenv("SUT_CPU", "2")
SUT_MEMORY = os.getenv("SUT_MEMORY", "4Gi")
SUT_VGPU = os.getenv("SUT_VGPU", "1")
#SUT_VGPU_MEM = os.getenv("SUT_VGPU_MEM", str(1843 * int(SUT_VGPU)))
#SUT_VGPU_CORES = os.getenv("SUT_VGPU_CORES", str(8 * int(SUT_VGPU)))
SUT_VGPU_ACCELERATOR = os.getenv("SUT_VGPU_ACCELERATOR", "iluvatar-BI-V100")
RESOURCE_TYPE = os.getenv("RESOURCE_TYPE", "vgpu")
assert RESOURCE_TYPE in [
    "cpu",
    "vgpu",
], "benchmark judge_flow_config error: RESOURCE_TYPE should be cpu or vgpu"
# Extract the benchmark dataset into the workspace at import time.
unzip_dir(DATASET_FILEPATH, workspace_path)
def get_sut_url_kubernetes():
    """Build the SUT deployment values from the submit config and register it.

    Loads SUBMIT_CONFIG_FILEPATH, injects a fixed container spec plus iluvatar
    GPU resource requests/limits and a node selector, then registers the SUT
    chart and returns its HTTP base URL.

    Returns:
        str: the registered SUT URL with the ws:// scheme rewritten to http://.

    Raises:
        AssertionError: when the submit config is not a mapping.
    """
    with open(SUBMIT_CONFIG_FILEPATH, "r") as f:
        submit_config = yaml.safe_load(f)
    assert isinstance(submit_config, dict)
    submit_config.setdefault("values", {})
    # Fixed container spec; "sleep 3600" keeps the pod alive so the model can
    # be exercised over HTTP.
    submit_config["values"]["containers"] = [
        {
            "name": "corex-container",
            "image": "harbor.4pd.io/lab-platform/inf/python:3.9",
            "command": ["sleep"],
            "args": ["3600"],
        }
    ]
    # Judge-enforced resources overwrite anything the submitter provided.
    submit_config["values"]["resources"] = {
        "requests": {},
        "limits": {},
    }
    limits = submit_config["values"]["resources"]["limits"]
    requests = submit_config["values"]["resources"]["requests"]
    # iluvatar GPU allocation; SUT_VGPU comes from judge_flow_config.
    vgpu_resource = {
        "iluvatar.ai/gpu": SUT_VGPU,
    }
    limits.update(vgpu_resource)
    requests.update(vgpu_resource)
    # Pin the pod onto iluvatar BI-V100 accelerator nodes.
    submit_config["values"]["nodeSelector"] = {
        "contest.4pd.io/accelerator": "iluvatar-BI-V100"
    }
    # NOTE(review): the tolerations block in the original source was disabled
    # (wrapped in a string literal) and is intentionally not applied here —
    # confirm whether the hosttype tolerations should be re-enabled.
    log.info(f"submit_config: {submit_config}")
    log.info(f"RESOURCE_NAME: {RESOURCE_NAME}")
    return register_sut(submit_config, RESOURCE_NAME).replace(
        "ws://", "http://"
    )
def get_sut_url():
    """Resolve the SUT endpoint; kubernetes deployment is the only backend."""
    sut_endpoint = get_sut_url_kubernetes()
    return sut_endpoint
#os.environ["SUT_URL"] = SUT_URL
#############################################################################
import requests
import base64
def gen_req_body(apiname, APPId, file_path=None, featureId=None, featureInfo=None, dstFeatureId=None):
    """Build the JSON request body for one voiceprint API call.

    :param apiname: one of createFeature / createGroup / deleteFeature /
        queryFeatureList / searchFea / searchScoreFea / updateFeature /
        deleteGroup
    :param APPId: application id placed in the request header
    :param file_path: path of the audio file (required by the audio APIs)
    :param featureId: feature id used by createFeature
    :param featureInfo: feature info used by createFeature
    :param dstFeatureId: target feature id used by searchScoreFea
    :return: dict ready to be serialized as the request payload
    :raises Exception: when apiname is not a supported API
    """
    # Per-API parameters (beyond "func" and the trailing "<func>Res" spec).
    # NOTE(review): group/feature ids are hard-coded as in the vendor sample.
    api_params = {
        'createFeature': {
            "groupId": "test_voiceprint_e",
            "featureId": featureId,
            "featureInfo": featureInfo,
        },
        'createGroup': {
            "groupId": "test_voiceprint_e",
            "groupName": "vip_user",
            "groupInfo": "store_vip_user_voiceprint",
        },
        'deleteFeature': {
            "groupId": "iFLYTEK_examples_groupId",
            "featureId": "iFLYTEK_examples_featureId",
        },
        'queryFeatureList': {
            "groupId": "user_voiceprint_2",
        },
        'searchFea': {
            "groupId": "test_voiceprint_e",
            "topK": 1,
        },
        'searchScoreFea': {
            "groupId": "test_voiceprint_e",
            "dstFeatureId": dstFeatureId,
        },
        'updateFeature': {
            "groupId": "iFLYTEK_examples_groupId",
            "featureId": "iFLYTEK_examples_featureId",
            "featureInfo": "iFLYTEK_examples_featureInfo_update",
        },
        'deleteGroup': {
            "groupId": "iFLYTEK_examples_groupId",
        },
    }
    # APIs that must ship a base64-encoded audio payload.
    audio_apis = {'createFeature', 'searchFea', 'searchScoreFea', 'updateFeature'}
    if apiname not in api_params:
        raise Exception(
            "输入的apiname不在[createFeature, createGroup, deleteFeature, queryFeatureList, searchFea, searchScoreFea,updateFeature]内,请检查")
    # "func" first, API-specific params next, then the "<func>Res" output spec.
    parameter = {"func": apiname}
    parameter.update(api_params[apiname])
    parameter[apiname + "Res"] = {
        "encoding": "utf8",
        "compress": "raw",
        "format": "json"
    }
    body = {
        "header": {
            "app_id": APPId,
            "status": 3
        },
        "parameter": {
            "s782b4996": parameter
        }
    }
    if apiname in audio_apis:
        with open(file_path, "rb") as f:
            audioBytes = f.read()
        body["payload"] = {
            "resource": {
                "encoding": "lame",
                "sample_rate": 16000,
                "channels": 1,
                "bit_depth": 16,
                "status": 3,
                "audio": str(base64.b64encode(audioBytes), 'UTF-8')
            }
        }
    return body
log.info(f"开始请求获取到SUT服务URL")
# Fetch the SUT service URL (this deploys the SUT chart at import time).
sut_url = get_sut_url()
print(f"获取到的SUT_URL: {sut_url}")  # debug output
log.info(f"获取到SUT服务URL: {sut_url}")
from urllib.parse import urlparse
# Global variable holding the last base64-decoded response text (if any).
text_decoded = None
###################################新增新增################################
def req_url(api_name, APPId, file_path=None, featureId=None, featureInfo=None, dstFeatureId=None):
    """
    Issue the request sequence against the SUT and persist the JSON response.
    :param api_name: voiceprint API name (passed to gen_req_body)
    :param APPId: APPID
    :param file_path: file path used in the body
    :return: None (writes the response JSON to RESULT_FILEPATH)
    """
    global text_decoded
    body = gen_req_body(apiname=api_name, APPId=APPId, file_path=file_path, featureId=featureId, featureInfo=featureInfo, dstFeatureId=dstFeatureId)
    #request_url = 'https://ai-cloud.4paradigm.com:9443/sid/v1/private/s782b4996'
    #request_url = 'https://sut:80/sid/v1/private/s782b4996'
    #headers = {'content-type': "application/json", 'host': 'ai-cloud.4paradigm.com', 'appid': APPId}
    parsed_url = urlparse(sut_url)
    headers = {'content-type': "application/json", 'host': parsed_url.hostname, 'appid': APPId}
    # 1. First check the service health endpoint
    response = requests.get(f"{sut_url}/health")
    print(response.status_code, response.text)
    # Request headers
    # NOTE(review): this clobbers the voiceprint headers built above — confirm intended.
    headers = {"Content-Type": "application/json"}
    # Request body (optionally caps the number of images processed)
    # NOTE(review): this clobbers the gen_req_body() result built above — confirm intended.
    body = {"limit": 20 }  # optional parameter limiting total images processed
    # Send the POST request
    response = requests.post(
        f"{sut_url}/v1/private/s782b4996",
        data=json.dumps(body),
        headers=headers
    )
    # Parse the response
    if response.status_code == 200:
        result = response.json()
        print("预测评估结果:")
        print(f"准确率: {result['metrics']['accuracy']}%")
        print(f"平均召回率: {result['metrics']['average_recall']}%")
        print(f"处理图片总数: {result['metrics']['total_images']}")
    else:
        print(f"请求失败,状态码: {response.status_code}")
        print(f"错误信息: {response.text}")
    # Basic-auth credentials (currently unused below)
    # NOTE(review): credential hard-coded in source — move to a secret store.
    auth = ('llm', 'Rmf4#LcG(iFZrjU;2J')
    #response = requests.post(request_url, data=json.dumps(body), headers=headers, auth=auth)
    #response = requests.post(sut_url + "/predict", data=json.dumps(body), headers=headers, auth=auth)
    #response = requests.post(f"{sut_url}/sid/v1/private/s782b4996", data=json.dumps(body), headers=headers, auth=auth)
    """
    response = requests.post(f"{sut_url}/v1/private/s782b4996", data=json.dumps(body), headers=headers)
    """
    #print("HTTP状态码:", response.status_code)
    #print("原始响应内容:", response.text) # 先打印原始内容
    #print(f"请求URL: {sut_url + '/v1/private/s782b4996'}")
    #print(f"请求headers: {headers}")
    #print(f"请求body: {body}")
    #tempResult = json.loads(response.content.decode('utf-8'))
    #print(tempResult)
    """
    # 对text字段进行Base64解码
    if 'payload' in tempResult and 'updateFeatureRes' in tempResult['payload']:
        text_encoded = tempResult['payload']['updateFeatureRes']['text']
        text_decoded = base64.b64decode(text_encoded).decode('utf-8')
        print(f"Base64解码后的text字段内容: {text_decoded}")
    """
    #text_encoded = tempResult['payload']['updateFeatureRes']['text']
    #text_decoded = base64.b64decode(text_encoded).decode('utf-8')
    #print(f"Base64解码后的text字段内容: {text_decoded}")
    # Grab the response JSON and persist it as the benchmark result.
    result = response.json()
    with open(RESULT_FILEPATH, "w") as f:
        json.dump(result, f, indent=4, ensure_ascii=False)
    print(f"结果已成功写入 {RESULT_FILEPATH}")
# Output paths (fall back to local defaults for offline runs).
submit_config_filepath = os.getenv("SUBMIT_CONFIG_FILEPATH", "./tests/resources/submit_config")
result_filepath = os.getenv("RESULT_FILEPATH", "./out/result")
bad_cases_filepath = os.getenv("BAD_CASES_FILEPATH", "./out/badcase")
#detail_cases_filepath = os.getenv("DETAILED_CASES_FILEPATH", "./out/detailcase.jsonl")
from typing import Any, Dict, List
def result2file(
    result: Dict[str, Any],
    detail_cases: List[Dict[str, Any]] = None
):
    """Persist the metrics dict to result_filepath as pretty-printed JSON.

    :param result: final metrics mapping; ignored when None.
    :param detail_cases: currently unused (the detail-case write below is
        disabled — see the commented block).
    """
    assert result_filepath is not None
    assert bad_cases_filepath is not None
    #assert detailed_cases_filepath is not None
    if result is not None:
        with open(result_filepath, "w") as f:
            json.dump(result, f, indent=4, ensure_ascii=False)
    #if LOCAL_TEST:
    #    logger.info(f'result:\n {json.dumps(result, indent=4)}')
    """
    if detail_cases is not None:
        with open(detailed_cases_filepath, "w") as f:
            json.dump(detail_cases, f, indent=4, ensure_ascii=False)
        if LOCAL_TEST:
            logger.info(f'result:\n {json.dumps(detail_cases, indent=4)}')
    """
def test_image_prediction(sut_url, image_path):
    """POST one local image to the SUT prediction endpoint.

    :param sut_url: base URL of the SUT service.
    :param image_path: path of the image file to upload.
    :return: (result_dict, None) on success, (None, error_message) otherwise.
    """
    endpoint = f"{sut_url}/v1/private/s782b4996"
    try:
        with open(image_path, 'rb') as fh:
            resp = requests.post(endpoint, files={'image': fh}, timeout=30)
        payload = resp.json()
    except Exception as exc:
        return None, f"请求错误: {str(exc)}"
    if payload.get('status') != 'success':
        return None, f"服务端错误: {payload.get('message')}"
    return payload, None
import random
import time
#from tqdm import tqdm
import os
import requests
# Script entry: health-check the SUT, then evaluate the image dataset below.
if __name__ == '__main__':
    print(f"\n===== main开始请求接口 ===============================================")
    # 1. First check the service health endpoint
    print(f"\n===== 服务健康检查 ===================================================")
    response = requests.get(f"{sut_url}/health")
    print(response.status_code, response.text)
    """
    # 本地图片路径和真实标签(根据实际情况修改)
    image_path = "/path/to/your/test_image.jpg"
    true_label = "cat" # 图片的真实标签
    """
    """
    # 请求头
    headers = {"Content-Type": "application/json"}
    # 请求体(可指定限制处理的图片数量)
    body = {"limit": 20 } # 可选参数,限制处理的图片总数
    # 发送POST请求
    response = requests.post(
        f"{sut_url}/v1/private/s782b4996",
        data=json.dumps(body),
        headers=headers
    )
    """
    """
    # 读取图片文件
    with open(image_path, 'rb') as f:
        files = {'image': f}
        # 发送POST请求
        response = requests.post(f"{sut_url}/v1/private/s782b4996", files=files)
    # 解析响应结果
    if response.status_code == 200:
        result = response.json()
        print("预测评估结果:")
        print(f"准确率: {result['metrics']['accuracy']}%")
        print(f"平均召回率: {result['metrics']['average_recall']}%")
        print(f"处理图片总数: {result['metrics']['total_images']}")
    else:
        print(f"请求失败,状态码: {response.status_code}")
        print(f"错误信息: {response.text}")
    """
###############################################################################################
dataset_root = "/tmp/workspace/256ObjectCategoriesNew" # 数据集根目录
samples_per_class = 3 # 每个类别抽取的样本数
image_extensions = ('.jpg', '.jpeg', '.png', '.bmp', '.gif') # 支持的图片格式
# 结果统计变量
total_samples = 0
#correct_predictions = 0
# GPU统计
gpu_true_positives = 0
gpu_false_positives = 0
gpu_false_negatives = 0
gpu_total_processing_time = 0.0
# CPU统计
cpu_true_positives = 0
cpu_false_positives = 0
cpu_false_negatives = 0
cpu_total_processing_time = 0.0
"""
# 遍历所有类别文件夹
for folder_name in tqdm(os.listdir(dataset_root), desc="处理类别"):
folder_path = os.path.join(dataset_root, folder_name)
# 提取类别名(从"序号.name"格式中提取name部分
class_name = folder_name.split('.', 1)[1].strip().lower()
# 获取文件夹中所有图片
image_files = []
for file in os.listdir(folder_path):
if file.lower().endswith(image_extensions):
image_files.append(os.path.join(folder_path, file))
# 随机抽取指定数量的图片(如果不足则取全部)
selected_images = random.sample(
image_files,
min(samples_per_class, len(image_files))
)
# 处理选中的图片
for img_path in selected_images:
total_count += 1
# 发送预测请求
prediction, error = test_image_prediction(sut_url, img_path)
if error:
print(f"处理图片 {img_path} 失败: {error}")
continue
# 解析预测结果
pred_class = prediction.get('class_name', '').lower()
confidence = prediction.get('confidence', 0)
# 判断是否预测正确(真实类别是否在预测类别中)
if class_name in pred_class:
correct_predictions += 1
# 可选:打印详细结果
print(f"图片: {os.path.basename(img_path)} | 真实: {class_name} | 预测: {pred_class} | 置信度: {confidence:.4f} | {'正确' if is_correct else '错误'}")
"""
    # Iterate over every class folder in the dataset root.
    for folder_name in os.listdir(dataset_root):
        folder_path = os.path.join(dataset_root, folder_name)
        # Skip anything that is not a directory.
        if not os.path.isdir(folder_path):
            continue
        # Extract the class name (the "name" part of the "<index>.name" folder).
        try:
            class_name = folder_name.split('.', 1)[1].strip().lower()
        except IndexError:
            print(f"警告:文件夹 {folder_name} 命名格式不正确,跳过该文件夹")
            continue
        # Collect every image file in the folder.
        image_files = []
        for file in os.listdir(folder_path):
            file_path = os.path.join(folder_path, file)
            if os.path.isfile(file_path) and file.lower().endswith(image_extensions):
                image_files.append(file_path)
        # Randomly draw up to samples_per_class images (all if fewer).
        selected_images = random.sample(
            image_files,
            min(samples_per_class, len(image_files))
        )
        for img_path in selected_images:
            total_samples += 1
            # Request a prediction for this image.
            prediction, error = test_image_prediction(sut_url, img_path)
            # Log the raw prediction for debugging.
            print(f"test_image_prediction返回的prediction: {prediction}")
            print(f"test_image_prediction返回的error: {error}")
            if error:
                print(f"处理图片 {img_path} 失败: {error}")
                continue
            # Parse the GPU-side prediction.
            gpu_pred = prediction.get('cuda_prediction', {})
            gpu_pred_class = gpu_pred.get('class_name', '').lower()
            gpu_processing_time = gpu_pred.get('processing_time', 0.0)
            # Parse the CPU-side prediction.
            cpu_pred = prediction.get('cpu_prediction', {})
            cpu_pred_class = cpu_pred.get('class_name', '').lower()
            cpu_processing_time = cpu_pred.get('processing_time', 0.0)
            # A prediction counts as correct when the true class name is a
            # substring of the predicted class name.
            gpu_is_correct = class_name in gpu_pred_class
            if gpu_is_correct:
                gpu_true_positives += 1
            else:
                gpu_false_positives += 1
                gpu_false_negatives += 1
            # Same check for the CPU prediction.
            cpu_is_correct = class_name in cpu_pred_class
            if cpu_is_correct:
                cpu_true_positives += 1
            else:
                cpu_false_positives += 1
                cpu_false_negatives += 1
            # Accumulate processing times.
            gpu_total_processing_time += gpu_processing_time
            cpu_total_processing_time += cpu_processing_time
            # Print per-image details.
            print(f"图片: {os.path.basename(img_path)} | 真实: {class_name}")
            print(f"GPU预测: {gpu_pred_class} | {'正确' if gpu_is_correct else '错误'} | 耗时: {gpu_processing_time:.6f}s")
            print(f"CPU预测: {cpu_pred_class} | {'正确' if cpu_is_correct else '错误'} | 耗时: {cpu_processing_time:.6f}s")
            print("-" * 50)
"""
# 计算整体指标(在单标签场景下,准确率=召回率)
if total_samples == 0:
overall_accuracy = 0.0
overall_recall = 0.0
else:
overall_accuracy = correct_predictions / total_samples
overall_recall = correct_predictions / total_samples # 整体召回率
# 输出统计结果
print("\n" + "="*50)
print(f"测试总结:")
print(f"总测试样本数: {total_samples}")
print(f"正确预测样本数: {correct_predictions}")
print(f"整体准确率: {overall_accuracy:.4f} ({correct_predictions}/{total_samples})")
print(f"整体召回率: {overall_recall:.4f} ({correct_predictions}/{total_samples})")
print("="*50)
"""
# 初始化结果字典
result = {
# GPU指标
"gpu_accuracy": 0.0,
"gpu_recall": 0.0,
"gpu_running_time": round(gpu_total_processing_time, 6),
"gpu_throughput": 0.0,
# CPU指标
"cpu_accuracy": 0.0,
"cpu_recall": 0.0,
"cpu_running_time": round(cpu_total_processing_time, 6),
"cpu_throughput": 0.0
}
# 计算GPU指标
gpu_accuracy = gpu_true_positives / total_samples * 100
gpu_recall_denominator = gpu_true_positives + gpu_false_negatives
gpu_recall = gpu_true_positives / gpu_recall_denominator * 100 if gpu_recall_denominator > 0 else 0
gpu_throughput = total_samples / gpu_total_processing_time if gpu_total_processing_time > 1e-6 else 0
# 计算CPU指标
cpu_accuracy = cpu_true_positives / total_samples * 100
cpu_recall_denominator = cpu_true_positives + cpu_false_negatives
cpu_recall = cpu_true_positives / cpu_recall_denominator * 100 if cpu_recall_denominator > 0 else 0
cpu_throughput = total_samples / cpu_total_processing_time if cpu_total_processing_time > 1e-6 else 0
# 更新结果字典
result.update({
"gpu_accuracy": round(gpu_accuracy, 6),
"gpu_recall": round(gpu_recall, 6),
"gpu_throughput": round(gpu_throughput, 6),
"cpu_accuracy": round(cpu_accuracy, 6),
"cpu_recall": round(cpu_recall, 6),
"cpu_throughput": round(cpu_throughput, 6)
})
# 打印最终统计结果
print("\n" + "="*50)
print(f"总样本数: {total_samples}")
print("\nGPU指标:")
print(f"准确率: {result['gpu_accuracy']:.4f}%")
print(f"召回率: {result['gpu_recall']:.4f}%")
print(f"总运行时间: {result['gpu_running_time']:.6f}s")
print(f"吞吐量: {result['gpu_throughput']:.2f}张/秒")
print("\nCPU指标:")
print(f"准确率: {result['cpu_accuracy']:.4f}%")
print(f"召回率: {result['cpu_recall']:.4f}%")
print(f"总运行时间: {result['cpu_running_time']:.6f}s")
print(f"吞吐量: {result['cpu_throughput']:.2f}张/秒")
print("="*50)
#result = {}
#result['accuracy_1_1'] = 3
result2file(result)
if abs(gpu_accuracy - cpu_accuracy) > 3:
log.error(f"gpu与cpu准确率差别超过3%,模型结果不正确")
change_product_unavailable()
"""
if result['accuracy_1_1'] < 0.9:
log.error(f"1:1正确率未达到90%, 视为产品不可用")
change_product_unavailable()
if result['accuracy_1_N'] < 1:
log.error(f"1:N正确率未达到100%, 视为产品不可用")
change_product_unavailable()
if result['1_1_latency'] > 0.5:
log.error(f"1:1平均latency超过0.5s, 视为产品不可用")
change_product_unavailable()
if result['1_N_latency'] > 0.5:
log.error(f"1:N平均latency超过0.5s, 视为产品不可用")
change_product_unavailable()
if result['enroll_latency'] > 1:
log.error(f"enroll(入库)平均latency超过1s, 视为产品不可用")
change_product_unavailable()
"""
exit_code = 0

1193
run_callback_cuda.py Normal file

File diff suppressed because it is too large Load Diff

1296
run_callback_new.py Normal file

File diff suppressed because it is too large Load Diff

0
schemas/__init__.py Normal file
View File

90
schemas/context.py Normal file
View File

@@ -0,0 +1,90 @@
import os
from copy import deepcopy
from typing import Dict, List, Optional
from pydantic import BaseModel, Field
from schemas.stream import StreamDataModel
class LabelContext(BaseModel):
    """One ground-truth transcript span for an audio file.

    Times appear to be in seconds (matching QueryDataSentence) — TODO confirm.
    """

    start: float  # span start time
    end: float  # span end time
    answer: str  # reference transcript text for this span
class PredContext(BaseModel):
    """One recognition callback result plus client-side timing metadata."""

    recognition_results: StreamDataModel  # the parsed streaming result
    recv_time: Optional[float] = Field(None)  # perf-counter time result was received
    send_time: Optional[float] = Field(None)  # perf-counter time matching chunk was sent
class ASRContext:
    """Mutable per-audio test context: stream parameters (overridable via
    kwargs and environment variables), the ground-truth labels and the
    predictions collected for one audio file."""

    def __init__(self, **kwargs):
        # Audio/stream parameters; defaults describe a 16-bit mono 16 kHz WAV.
        self.bits = kwargs.get("bits", 16)
        self.channel = kwargs.get("channel", 1)
        self.sample_rate = kwargs.get("sample_rate", 16000)
        self.audio_format = kwargs.get("format", "wav")
        self.enable_words = kwargs.get("enable_words", True)
        # Minimum predicted/label character ratio checked during streaming.
        self.char_contains_rate = kwargs.get("char_contains_rate", 0.8)
        # The "lang" environment variable wins over the kwarg.
        self.lang = os.getenv("lang")
        if self.lang is None:
            self.lang = kwargs.get("lang", "en")
        self.stream = kwargs.get("stream", True)
        # Seconds between audio chunks; chunk size in bytes is derived from it.
        self.wait_time = float(os.getenv("wait_time", 0.1))
        # NOTE(review): formula ignores self.channel — correct only for mono
        # audio; confirm before using multi-channel files.
        self.chunk_size = self.sample_rate * self.bits / 8 * self.wait_time
        if int(os.getenv('chunk_size_set', 0)):
            # Explicit chunk-size override from the environment.
            self.chunk_size = int(os.getenv('chunk_size_set', 0))
        self.audio_length = 0
        self.file_path = ""
        self.labels: List[LabelContext] = kwargs.get("labels", [])
        self.preds: List[PredContext] = kwargs.get("preds", [])
        self.label_sentences: List[str] = []
        self.pred_sentences: List[str] = []
        self.send_time_start_end = []  # [first, last] send timestamps
        self.recv_time_start_end = []  # [first, last] recv timestamps
        self.fail = False
        self.fail_char_contains_rate_num = 0
        self.punctuation_num = 0
        self.pred_punctuation_num = 0

    def append_labels(self, voices: List[Dict]):
        """Convert raw label dicts into LabelContext objects and store them."""
        for voice_data in voices:
            label_context = LabelContext(**voice_data)
            self.labels.append(label_context)

    def append_preds(
        self,
        predict_data: List[StreamDataModel],
        send_time: List[float],
        recv_time: List[float],
    ):
        """Store predictions paired with their send/recv timestamps.

        Also records the first/last send and recv times. The three lists are
        zipped, so extra elements in any longer list are silently dropped.
        """
        self.send_time_start_end = [send_time[0], send_time[-1]] if len(send_time) > 0 else []
        self.recv_time_start_end = [recv_time[0], recv_time[-1]] if len(recv_time) > 0 else []
        for pred_item, send_time_item, recv_time_item in zip(predict_data, send_time, recv_time):
            # Deep-copy so later mutation of the source models cannot leak in.
            pred_item = deepcopy(pred_item)
            pred_context = PredContext(recognition_results=pred_item.model_dump())
            pred_context.send_time = send_time_item
            pred_context.recv_time = recv_time_item
            self.preds.append(pred_context)

    def to_dict(self):
        """Serialize the context for reporting; labels/preds become JSON strings."""
        return {
            "bits": self.bits,
            "channel": self.channel,
            "sample_rate": self.sample_rate,
            "audio_format": self.audio_format,
            "enable_words": self.enable_words,
            "stream": self.stream,
            "wait_time": self.wait_time,
            "chunk_size": self.chunk_size,
            "labels": [item.model_dump_json() for item in self.labels],
            "preds": [item.model_dump_json() for item in self.preds],
        }

18
schemas/dataset.py Normal file
View File

@@ -0,0 +1,18 @@
from typing import List
from pydantic import BaseModel, Field
class QueryDataSentence(BaseModel):
    """One labelled sentence inside an audio file (times in seconds)."""

    answer: str = Field(description="文本label")
    start: float = Field(description="句子开始时间")
    end: float = Field(description="句子结束时间")
class QueryData(BaseModel):
    """Ground-truth record for one audio file in the test dataset."""

    lang: str = Field(description="语言")
    file: str = Field(description="音频文件位置")
    duration: float = Field(description="音频长度")
    voice: List[QueryDataSentence] = Field(
        description="音频文件的文本label内容"
    )

66
schemas/stream.py Normal file
View File

@@ -0,0 +1,66 @@
from typing import List
from pydantic import BaseModel, ValidationError, field_validator
from pydantic import model_validator
class StreamWordsModel(BaseModel):
    """A single recognized word with its time span (ms as sent by the service)."""

    text: str  # the word text (may be a punctuation token)
    start_time: float
    end_time: float

    @model_validator(mode="after")
    def check_result(self):
        """Reject word spans whose end precedes their start."""
        if self.end_time < self.start_time:
            # BUG FIX: pydantic v2's ValidationError cannot be constructed from
            # a bare message (doing so raised TypeError instead of failing
            # validation). Raising ValueError lets pydantic wrap it properly.
            raise ValueError("end-time 小于 start-time, error")
        return self
class StreamDataModel(BaseModel):
    """One streaming recognition result (times in ms as sent by the service)."""

    text: str  # recognized text for this paragraph
    language: str  # language code reported by the service
    final_result: bool  # True when this paragraph will not be revised again
    para_seq: int  # paragraph sequence number
    start_time: float
    end_time: float
    words: List[StreamWordsModel]

    @model_validator(mode="after")
    def check_result(self):
        """Reject results whose end time precedes their start time."""
        if self.end_time < self.start_time:
            # BUG FIX: pydantic v2's ValidationError cannot be constructed from
            # a bare message (doing so raised TypeError instead of failing
            # validation). Raising ValueError lets pydantic wrap it properly.
            raise ValueError("end-time 小于 start-time, error")
        return self
class StreamResultModel(BaseModel):
    """Envelope for one streaming callback; converts service times (ms) to seconds."""

    # NOTE(review): the field is named `asr_results` here, while utils/client.py
    # reads `.recognition_results` from this model — one of the two looks stale;
    # confirm which schema the deployed service actually uses.
    asr_results: StreamDataModel

    @field_validator('asr_results', mode="after")
    def convert_to_seconds(cls, v: StreamDataModel, values):
        # Convert all millisecond timestamps to seconds in-place.
        v.end_time = v.end_time / 1000
        v.start_time = v.start_time / 1000
        for word in v.words:
            word.start_time /= 1000
            word.end_time /= 1000
        return v

    class Config:
        # Re-run validators when attributes are assigned after construction.
        validate_assignment = True
class NonStreamDataModel(BaseModel):
    """One paragraph of a non-streaming recognition response."""

    text: str  # recognized paragraph text
    para_seq: int  # paragraph sequence number
    start_time: float
    end_time: float

    @model_validator(mode="after")
    def check_result(self):
        """Reject results whose end time precedes their start time."""
        if self.end_time < self.start_time:
            # BUG FIX: pydantic v2's ValidationError cannot be constructed from
            # a bare message (doing so raised TypeError instead of failing
            # validation). Raising ValueError lets pydantic wrap it properly.
            raise ValueError("end-time 小于 start-time, error")
        return self
class NonStreamResultModel(BaseModel):
    """Non-streaming recognition response: the full list of paragraph results."""

    contents: List[NonStreamDataModel]

View File

@@ -0,0 +1,53 @@
import os
import sys
from collections import defaultdict
import yaml
def main(dataset_dir):
    """Sanity-check every dataset sub-folder's data.yaml.

    For each query, verifies that every voice span has start <= end (err1) and
    that consecutive spans do not overlap (err2). Prints offending entries, the
    number of folders scanned, the set of problem folders and a per-folder
    overlap count.
    """
    entries = os.listdir(dataset_dir)
    # Keep only sub-directories (avoid shadowing the `dir` builtin).
    subdirs = [d for d in entries if os.path.isdir(os.path.join(dataset_dir, d))]
    problem_dirs = set()
    problem_count = defaultdict(int)
    for subdir in subdirs:
        with open(os.path.join(dataset_dir, subdir, "data.yaml"), "r") as f:
            data = yaml.full_load(f)
        for query_i, query in enumerate(data["query_data"]):
            voices = sorted(query["voice"], key=lambda x: x["start"])
            if voices != query["voice"]:
                print("-----", subdir)
            if not voices:
                # Guard: an empty voice list previously crashed with IndexError.
                continue
            if voices[0]["start"] > voices[0]["end"]:
                print(
                    "err1: %s%s个query的第%d个voice的start大于end: %s"
                    % (subdir, query_i, 0, voices[0]["answer"])
                )
                problem_dirs.add(subdir)
            for voice_i in range(1, len(voices)):
                voice = voices[voice_i]
                if voice["start"] > voice["end"]:
                    print(
                        "err1: %s%s个query的第%d个voice的start大于end: %s"
                        % (subdir, query_i, voice_i, voice["answer"])
                    )
                    problem_dirs.add(subdir)
                if voice["start"] < voices[voice_i - 1]["end"]:
                    print(
                        "err2: %s%s个query的第%d个voice的start小于前一个voice的end: %s"
                        % (subdir, query_i, voice_i, voice["answer"])
                    )
                    problem_dirs.add(subdir)
                    problem_count[subdir] += 1
    print(len(subdirs))
    print(problem_dirs)
    print(problem_count)
if __name__ == "__main__":
    # CLI entry: expects the dataset folder path as the only argument.
    if len(sys.argv) < 2:
        print("指定 测试数据集文件夹")
        sys.exit(1)
    main(sys.argv[1])

View File

@@ -0,0 +1,108 @@
import json
import os
import shutil
import sys
import zipfile
import yaml
"""
target
{
"global": {
"lang": ""
},
"query_data": [
"file": "",
"duration": 2.0,
"voice": [
{
"answer": "",
"start": 0.0,
"end": 1.0
}
]
]
}
"""
def situation_a(meta, dataset_folder, output_folder):
    """Convert a "combined" meta.json into per-language datasets.

    For each language: copies the referenced audio (as .mp3), writes a
    data.yaml in the target format, then zips the language folder.

    Expected input shape:
    {
        "combined": {
            "en": [
                {
                    "wav": "*.wav",
                    "transcriptions": [
                        {
                            "text": "",
                            "start": 0.0,
                            "end": 1.0
                        }
                    ],
                    "duration": 2.0
                }
            ]
        }
    }
    """
    meta = meta["combined"]
    for lang, arr in meta.items():
        print("processing", lang)
        # Language keys are expected to be 2-letter codes.
        assert len(lang) == 2
        lang_folder = os.path.join(output_folder, lang)
        os.makedirs(lang_folder, exist_ok=True)
        data = {"global": {"lang": lang}, "query_data": []}
        query_data = data["query_data"]
        for item in arr:
            os.makedirs(
                os.path.join(lang_folder, os.path.dirname(item["wav"])),
                exist_ok=True,
            )
            # Dataset stores mp3 files under the wav name with the extension
            # swapped — assumes the name ends in a 4-char ".wav" suffix.
            mp3_file = item["wav"][:-4] + ".mp3"
            shutil.copyfile(
                os.path.join(dataset_folder, mp3_file),
                os.path.join(lang_folder, mp3_file),
            )
            query_data_item = {
                "file": mp3_file,
                "duration": float(item["duration"]),
                "voice": [],
            }
            query_data.append(query_data_item)
            voice = query_data_item["voice"]
            for v in item["transcriptions"]:
                voice.append(
                    {
                        "answer": v["text"],
                        "start": float(v["start"]),
                        "end": float(v["end"]),
                    }
                )
        with open(os.path.join(lang_folder, "data.yaml"), "w") as f:
            yaml.dump(data, f, indent=2, allow_unicode=True, encoding="utf-8")
        # Pack the language folder; arcnames are made relative to its root.
        with zipfile.ZipFile(
            os.path.join(output_folder, lang + ".zip"), "w"
        ) as ziper:
            dirname = lang_folder
            for path, _, files in os.walk(dirname):
                for file in files:
                    ziper.write(
                        os.path.join(path, file),
                        os.path.join(path[len(dirname) :], file),
                        zipfile.ZIP_DEFLATED,
                    )
if __name__ == "__main__":
    # CLI entry: expects the dataset folder path and the output path.
    if len(sys.argv) < 3:
        print("指定 数据集文件夹路径 输出路径")
        sys.exit(1)
    dataset_folder = sys.argv[1]
    output_folder = sys.argv[2]
    with open(os.path.join(dataset_folder, "meta.json")) as f:
        meta = json.load(f)
    situation_a(meta, dataset_folder, output_folder)

View File

@@ -0,0 +1,56 @@
import json
import sys
from schemas.dataset import QueryData
from schemas.stream import StreamDataModel
from utils.evaluator_plus import evaluate_editops
def main(detailcase_file: str):
    """Evaluate the first detail case in the file, keeping only final results."""
    with open(detailcase_file) as f:
        case = json.load(f)[0]
    # Parse every prediction, then keep only the final (non-revisable) ones.
    parsed_preds = [StreamDataModel(**p) for p in case["preds"]]
    final_preds = [p for p in parsed_preds if p.final_result]
    label = QueryData(**case["label"])
    print(evaluate_editops(label, final_preds))
def evaluate_from_record(detailcase_file: str, record_path: str):
    """Re-run evaluation from a saved record, using only final results."""
    with open(detailcase_file) as f:
        case = json.load(f)[0]
    label = QueryData(**case["label"])
    with open(record_path) as f:
        record = json.load(f)
    recognition_results = [
        StreamDataModel(**item) for item in record["recognition_results"]
    ]
    # Keep each prediction token together with its result, but only for
    # results flagged as final.
    kept_tokens, kept_results = [], []
    for token, result in zip(record["tokens_pred"], recognition_results):
        if result.final_result:
            kept_tokens.append(token)
            kept_results.append(result)
    print(
        evaluate_editops(
            label,
            kept_results,
            kept_tokens,
            record["tokens_label"],
        )
    )
if __name__ == "__main__":
    # CLI entry: expects the detail-case JSON path as the only argument.
    if len(sys.argv) < 2:
        print("请指定 detailcase 文件路径")
        sys.exit(1)
    main(sys.argv[1])
    # evaluate_from_record(sys.argv[1], sys.argv[2])

BIN
ssh-keygen Executable file

Binary file not shown.

11
starting_kit/Dockerfile Normal file
View File

@@ -0,0 +1,11 @@
# syntax=docker/dockerfile:1
FROM harbor.4pd.io/inf/base-python3.8-ubuntu:1.1.0
WORKDIR /workspace
# Copy only the dependency manifest first so the pip layer stays cached until
# requirements.txt itself changes. COPY is preferred over ADD for plain files.
COPY requirements.txt /workspace/
# --no-cache-dir keeps the pip download cache out of the image layer
# (replaces the weaker post-hoc `pip cache purge`).
RUN pip install --no-cache-dir -r ./requirements.txt -i https://nexus.4pd.io/repository/pypi-all/simple --trusted-host nexus.4pd.io --extra-index-url https://mirrors.aliyun.com/pypi/simple/
COPY . /workspace
CMD ["python", "main.py"]

313
starting_kit/main.py Normal file
View File

@@ -0,0 +1,313 @@
import logging
import os
import threading
import time
from typing import Optional
import flask
import requests
from werkzeug.datastructures import FileStorage
app = flask.Flask(__name__)
# Set once the heartbeat thread has started, so only one thread ever runs.
heartbeat_active = False
# Module logger writing to stderr; propagation disabled to avoid duplicate
# records from the root logger.
log = logging.getLogger(__name__)
log.propagate = False
level = logging.INFO
log.setLevel(level)
formatter = logging.Formatter(
    "[%(asctime)s] %(levelname)s : %(pathname)s:%(lineno)d - %(message)s",
    "%Y-%m-%d %H:%M:%S",
)
streamHandler = logging.StreamHandler()
streamHandler.setLevel(level)
streamHandler.setFormatter(formatter)
log.addHandler(streamHandler)
def heartbeat(url):
    """POST a RUNNING status to `url` every 10 seconds, forever.

    Only the first caller actually loops; subsequent calls return immediately
    thanks to the module-level `heartbeat_active` flag. Network errors are
    swallowed — the heartbeat is best-effort.
    """
    global heartbeat_active
    if heartbeat_active:
        return
    heartbeat_active = True
    while True:
        try:
            # BUG FIX: without a timeout a wedged endpoint could hang this
            # daemon thread forever and stop all future heartbeats.
            requests.post(url, json={"status": "RUNNING"}, timeout=10)
        except Exception:
            pass
        time.sleep(10)
def asr(
    audio_file: FileStorage,
    language: Optional[str],
    progressCallbackUrl: str,
    taskId: str,
):
    """TODO: read audio_file, call the ASR service, and stream recognition
    results back to progressCallbackUrl in real time.

    The body below is a hard-coded demo payload showing the expected callback
    protocol, not a real implementation.
    """
    # ignore BEGIN
    # Used only for local leaderboard testing.
    if os.getenv("LOCAL_TEST"):
        return local_test(progressCallbackUrl, taskId)
    # ignore END
    language = "de"
    # One incremental recognition callback.
    requests.post(
        progressCallbackUrl,
        json={
            "taskId": taskId,
            "status": "RUNNING",
            # Incremental result; do NOT send this field when status is
            # FINISHED or ERROR.
            "recognition_results": {
                "text": "最先启动的还是",
                "final_result": True,
                "para_seq": 0,
                "language": language,
                "start_time": 6300,
                "end_time": 6421,
                "words": [
                    {
                        "text": "",
                        "start_time": 6300,
                        "end_time": 6321,
                    },
                    {
                        "text": "",
                        "start_time": 6321,
                        "end_time": 6345,
                    },
                    {
                        "text": "",
                        "start_time": 6345,
                        "end_time": 6350,
                    },
                    {
                        "text": "",
                        "start_time": 6350,
                        "end_time": 6370,
                    },
                    {
                        "text": "",
                        "start_time": 6370,
                        "end_time": 6386,
                    },
                    {
                        "text": "",
                        "start_time": 6386,
                        "end_time": 6421,
                    },
                    {
                        "text": "",
                        "start_time": 6421,
                        "end_time": 6435,
                    },
                ],
            },
        },
    )
    # ... all recognition results have been returned.
    # Signal end of recognition.
    requests.post(
        progressCallbackUrl,
        json={
            "taskId": taskId,
            "status": "FINISHED",
        },
    )
@app.post("/predict")
def predict():
    """Accept an audio file plus callback URLs; start the heartbeat and the
    async recognition worker, then return immediately with {"status": "OK"}."""
    body = flask.request.form
    language = body.get("language")
    if language is None:
        "自行判断语种"
        # NOTE(review): the bare string above is a no-op statement — a
        # placeholder meaning "detect the language yourself".
    taskId = body["taskId"]
    progressCallbackUrl = body["progressCallbackUrl"]
    heartbeatUrl = body["heartbeatUrl"]
    # Daemon heartbeat thread (only the first one actually loops).
    threading.Thread(
        target=heartbeat, args=(heartbeatUrl,), daemon=True
    ).start()
    audio_file = flask.request.files["file"]
    # audio_file.stream  # read the file stream
    # audio_file.save("audio.mp3")  # save to disk
    # Run recognition off the request thread so the HTTP call returns fast.
    threading.Thread(
        target=asr,
        args=(audio_file, language, progressCallbackUrl, taskId),
        daemon=True,
    ).start()
    return flask.jsonify({"status": "OK"})
# ignore BEGIN
def local_test(progressCallbackUrl: str, taskId: str):
    """Ignore this method; it is only used for local leaderboard debugging.

    It replays the first query of a local data.yaml as fake streaming ASR
    callbacks, with randomized sentence merging, word dropping and partial
    (final_result=False) sends to simulate a real recognizer.
    """
    import random
    import re

    import yaml

    def callback(content):
        # Best-effort POST; content=None signals the end of recognition.
        try:
            if content is None:
                requests.post(
                    progressCallbackUrl,
                    json={"taskId": taskId, "status": "FINISHED"},
                )
            else:
                requests.post(
                    progressCallbackUrl,
                    json={
                        "taskId": taskId,
                        "status": "RUNNING",
                        "recognition_results": content,
                    },
                )
        except Exception:
            pass

    with open(
        os.getenv("LOCAL_TEST_DATA_PATH", "../dataset/out/data.yaml")
    ) as f:
        data = yaml.full_load(f)
    voices = data["query_data"][0]["voice"]
    # Initial delay before the first send.
    first_send_time = random.randint(3, 5)
    send_interval = random.random() * 0
    log.info("首次发送%ss 发送间隔%ss" % (first_send_time, send_interval))
    time.sleep(first_send_time)
    # Randomly merge consecutive sentences into single sentences.
    if random.random() < 0.3:
        log.info("将部分句子合并成单句 每次合并的句子不超过3句")
        rand_idx = 0
        rand_sep = [0, len(voices) - 1]
        while rand_sep[rand_idx] + 1 <= rand_sep[rand_idx + 1] - 1:
            rand_cursep = random.randint(
                rand_sep[rand_idx] + 1,
                min(rand_sep[rand_idx + 1] - 1, rand_sep[rand_idx] + 1 + 3),
            )
            rand_sep.insert(rand_idx + 1, rand_cursep)
            rand_idx += 1
        merged_voices = []
        for i, cur_sep in enumerate(rand_sep[:-1]):
            voice = voices[cur_sep]
            for j in range(cur_sep + 1, rand_sep[i + 1]):
                voice["answer"] += voices[j]["answer"]
                voice["end"] = voices[j]["end"]
            merged_voices.append(voice)
        merged_voices.append(voices[rand_sep[-1]])
        voices = merged_voices

    def split_and_keep(text, delimiters):
        # Build a regex matching runs of non-delimiters or single delimiters,
        # so the delimiters are kept as their own tokens.
        pattern = "|".join(re.escape(delimiter) for delimiter in delimiters)
        pattern = f"(?:[^{pattern}]+|[{pattern}])"
        return re.findall(pattern, text)

    puncs = [",", ".", "?", "!", ";", ":"]
    para_seq = 0
    for voice in voices:
        answer: str = voice["answer"]
        start_time: float = voice["start"]
        end_time: float = voice["end"]
        words = split_and_keep(answer, puncs)
        temp_words = []
        for i, word in enumerate(words):
            # Randomly drop interior words to simulate recognition errors.
            if i > 0 and i < len(words) - 1 and random.random() < 0.15:
                log.info("随机删除word")
                continue
            temp_words.extend(word.split(" "))
        if len(temp_words) == 0:
            temp_words = words[0].split(" ")
        words = temp_words
        answer = " ".join(words)
        words = list(map(lambda x: x.strip(), words))
        words = list(filter(lambda x: len(x) > 0, words))
        # Spread the sentence's time evenly over its words (in milliseconds).
        words_withtime = []
        word_unittime = (end_time - start_time) / len(words)
        for i, word in enumerate(words):
            word_start = start_time + word_unittime * i
            word_end = word_start + word_unittime
            words_withtime.append(
                {
                    "text": word,
                    "start_time": word_start * 1000,
                    "end_time": word_end * 1000,
                }
            )
        # Collapse leading/trailing punctuation time onto the edge words —
        # punctuation tokens are treated as instantaneous.
        punc_at = 0
        while punc_at < len(words) and words[punc_at] in puncs:
            punc_at += 1
        if punc_at < len(words):
            words_withtime[punc_at]["start_time"] = words_withtime[0][
                "start_time"
            ]
            for i in range(0, punc_at):
                words_withtime[i]["start_time"] = words_withtime[0]["start_time"]
                words_withtime[i]["end_time"] = words_withtime[0]["start_time"]
        punc_at = len(words) - 1
        while punc_at >= 0 and words[punc_at] in puncs:
            punc_at -= 1
        if punc_at >= 0:
            words_withtime[punc_at]["end_time"] = words_withtime[-1]["end_time"]
            for i in range(punc_at + 1, len(words)):
                words_withtime[i]["start_time"] = (
                    words_withtime[-1]["end_time"] + 0.1
                )
                words_withtime[i]["end_time"] = words_withtime[-1]["end_time"] + 0.1
        # Occasionally emit a partial (final_result=False) prefix first.
        if random.random() < 0.4 and len(words_withtime) > 1:
            log.info("发送一次final_result=False")
            rand_idx = random.randint(1, len(words_withtime) - 1)
            recognition_result = {
                "text": " ".join(
                    map(lambda x: x["text"], words_withtime[:rand_idx])
                ),
                "final_result": False,
                "para_seq": para_seq,
                "language": "de",
                "start_time": start_time * 1000,
                "end_time": end_time * 1000,
                "words": words_withtime[:rand_idx],
            }
            callback(recognition_result)
        recognition_result = {
            "text": answer,
            "final_result": True,
            "para_seq": para_seq,
            "language": "de",
            "start_time": start_time * 1000,
            "end_time": end_time * 1000,
            "words": words_withtime,
        }
        callback(recognition_result)
        para_seq += 1
        log.info("send %s" % para_seq)
        time.sleep(send_interval)
    callback(None)
if __name__ == "__main__":
    # Serve the prediction API on all interfaces, port 80.
    app.run(host="0.0.0.0", port=80)

View File

@@ -0,0 +1,3 @@
flask
requests
pyyaml

View File

@@ -0,0 +1,16 @@
import json
from schemas.dataset import QueryData
from schemas.stream import StreamDataModel
from utils.evaluator_plus import evaluate_editops
with open("out/detail_cases.json") as f:
detail_cases = json.load(f)
detail_case = detail_cases[0]
preds = []
for pred in detail_case["preds"]:
preds.append(StreamDataModel.model_validate(pred))
label = QueryData.model_validate(detail_case["label"])
print(evaluate_editops(label, preds))

93
tests/test_cer.py Normal file
View File

@@ -0,0 +1,93 @@
"""
f(a, b) 计算 a -> b 的编辑距离使用的方法是之前asr榜单的方法
g(a, b) 计算 a -> b 的编辑距离,使用的是原始的编辑距离计算方法
test() 是对拍程序
"""
import random
import string
from copy import deepcopy
from typing import List, Tuple
import Levenshtein
def mapping(gt: str, dt: str):
    """Split the ground-truth and decoded strings into per-character token lists."""
    return list(gt), list(dt)
def token_mapping(
    tokens_gt: List[str], tokens_dt: List[str]
) -> Tuple[List[str], List[str]]:
    """Align two token sequences by padding `None` at edit positions.

    After alignment both lists have equal length: a None on the GT side marks
    an insertion, a None on the DT side marks a deletion.
    """
    aligned_gt = deepcopy(tokens_gt)
    aligned_dt = deepcopy(tokens_dt)
    # Apply the edit script backwards so earlier indices remain valid.
    for op_name, src_pos, dst_pos in reversed(Levenshtein.editops(aligned_gt, aligned_dt)):
        if op_name == "insert":
            aligned_gt.insert(src_pos, None)
        elif op_name == "delete":
            aligned_dt.insert(dst_pos, None)
    return aligned_gt, aligned_dt
def cer(tokens_gt_mapping: List[str], tokens_dt_mapping: List[str]):
    """Count edit operations from two None-padded aligned token sequences.

    A None on the GT side is an insertion, a None on the DT side a deletion;
    aligned positions that are neither inserted nor equal are replacements.
    Returns (replace, delete, insert).
    """
    insert = tokens_gt_mapping.count(None)
    delete = tokens_dt_mapping.count(None)
    equal = 0
    for gt_tok, dt_tok in zip(tokens_gt_mapping, tokens_dt_mapping):
        if gt_tok == dt_tok:
            equal += 1
    replace = len(tokens_gt_mapping) - insert - equal  # - delete
    return replace, delete, insert
def f(a, b):
    """Edit counts for a -> b via the leaderboard alignment-mapping method."""
    gt_tokens, dt_tokens = mapping(a, b)
    aligned_gt, aligned_dt = token_mapping(gt_tokens, dt_tokens)
    return cer(aligned_gt, aligned_dt)
def raw(tokens_gt, tokens_dt):
    """Count (replace, delete, insert) straight from the Levenshtein edit script."""
    op_kinds = [
        op[0]
        for op in Levenshtein.editops(deepcopy(tokens_gt), deepcopy(tokens_dt))
    ]
    return (
        op_kinds.count("replace"),
        op_kinds.count("delete"),
        op_kinds.count("insert"),
    )
def g(a, b):
    """Edit counts for a -> b using the plain edit-distance method."""
    gt_tokens, dt_tokens = mapping(a, b)
    return raw(gt_tokens, dt_tokens)
def check(a, b):
    """Compare both edit-count implementations on one pair; print mismatches."""
    lhs = f(a, b)
    rhs = g(a, b)
    if lhs != rhs:
        print(lhs, rhs)
    return lhs == rhs
def random_string(length):
    """Return `length` random lowercase ASCII letters as one string."""
    return "".join(random.choice(string.ascii_lowercase) for _ in range(length))
def test():
    """Fuzz-compare f against g on 10000 random 30-character string pairs."""
    for _ in range(10000):
        left = random_string(30)
        right = random_string(30)
        if not check(left, right):
            print(left, right)
            break
test()

0
utils/__init__.py Normal file
View File

57
utils/asr_ter.py Normal file
View File

@@ -0,0 +1,57 @@
# copy from
# https://gitlab.4pd.io/scene_lab/leaderboard/judge_flows/foundamental_capability/blob/master/utils/asr_ter.py
def calc_ter_speechio(pred, ref, language="zh"):
    """Token error rate via the SpeechIO toolkit (Chinese only).

    Returns a dict with: ter (0..1), err_token_cnt (S+D+I) and
    ref_all_token_cnt (S+D+C). A None/empty `pred` is scored as "".
    Raises AssertionError for non-"zh" languages or an empty reference.
    """
    assert language == "zh", "Unsupported language %s" % language
    assert ref is not None and ref != "", "Reference script cannot be empty"
    if language == "zh":
        from .speechio import error_rate_zh as error_rate
        from .speechio import textnorm_zh as textnorm

        # SpeechIO normalization: full-width -> half-width, uppercase,
        # drop fillers and erhua before scoring.
        normalizer = textnorm.TextNorm(
            to_banjiao=True,
            to_upper=True,
            to_lower=False,
            remove_fillers=True,
            remove_erhua=True,
            check_chars=False,
            remove_space=False,
            cc_mode="",
        )
        norm_pred = normalizer(pred if pred is not None else "")
        norm_ref = normalizer(ref)
        tokenizer = "char"
        alignment, score = error_rate.EditDistance(
            error_rate.tokenize_text(norm_ref, tokenizer),
            error_rate.tokenize_text(norm_pred, tokenizer),
        )
        # c=correct, s=substitutions, i=insertions, d=deletions
        c, s, i, d = error_rate.CountEdits(alignment)
        ter = error_rate.ComputeTokenErrorRate(c, s, i, d) / 100.0
        return {"ter": ter, "err_token_cnt": s + d + i, "ref_all_token_cnt": s + d + c}
    # Defensive: the asserts above guarantee we never fall through.
    assert False, "Bug, not reachable"
def calc_ter_wjs(pred, ref, language="zh"):
    """Token error rate via the wjs_asr_wer calculator (Chinese only).

    Returns a dict with: ter, err_token_cnt (I+S+D) and ref_all_token_cnt.
    When the calculator reports zero reference tokens, ter falls back to 1.0.
    """
    assert language == "zh", "Unsupported language %s" % language
    assert ref is not None and ref != "", "Reference script cannot be empty"
    from . import wjs_asr_wer

    # Default calculator settings: no ignored words, case-insensitive, no split.
    ignore_words = set()
    case_sensitive = False
    split = None
    calculator = wjs_asr_wer.Calculator()
    norm_pred = wjs_asr_wer.normalize(
        wjs_asr_wer.characterize(pred if pred is not None else ""),
        ignore_words,
        case_sensitive,
        split,
    )
    norm_ref = wjs_asr_wer.normalize(wjs_asr_wer.characterize(ref), ignore_words, case_sensitive, split)
    result = calculator.calculate(norm_pred, norm_ref)
    ter = ((result["ins"] + result["sub"] + result["del"]) * 1.0 / result["all"]) if result["all"] != 0 else 1.0
    return {
        "ter": ter,
        "err_token_cnt": result["ins"] + result["sub"] + result["del"],
        "ref_all_token_cnt": result["all"],
    }

224
utils/client.py Normal file
View File

@@ -0,0 +1,224 @@
import json
import os
import threading
import time
import traceback
from copy import deepcopy
from typing import Any, List
import websocket
from pydantic_core import ValidationError
from websocket import create_connection
from schemas.context import ASRContext
from schemas.stream import StreamDataModel, StreamResultModel
from utils.logger import logger
IN_TEST = os.getenv("SUBMIT_CONFIG_FILEPATH", None) is None
class Client:
    """Synchronous websocket client: streams one WAV file to the ASR service
    under test, records send/recv timestamps in a background thread, and
    returns the parsed streaming recognition results."""

    def __init__(self, sut_url: str, context: ASRContext) -> None:
        # base_url = "ws://127.0.0.1:5003"
        self.base_url = sut_url + "/recognition"
        logger.info(f"{self.base_url}")
        # Deep-copy so the shared context is never mutated by this client.
        self.context: ASRContext = deepcopy(context)
        # if not os.getenv("DATASET_FILEPATH", ""):
        # self.base_url = "wss://speech.4paradigm.com/aibuds/api/v1/recognition"
        # self.base_url = "ws://localhost:5003/recognition"
        self.connect_num = 0
        self.exception = False
        # Sentinel "never close"; _send replaces it with a real deadline.
        self.close_time = 10**50
        self.send_time: List[float] = []
        self.recv_time: List[float] = []
        self.predict_data: List[Any] = []
        self.success = True

    def action(self):
        """Run the full cycle: connect (up to 5 attempts), start the receiver
        thread, stream the audio, wait for close, and return _gen_result()."""
        # Give up entirely if all 5 initialisation attempts fail.
        connect_success = False
        for i in range(5):
            try:
                self._connect_init()
                connect_success = True
                break
            except Exception as e:
                logger.error(f"{i+1} 次连接失败,原因:{e}")
                time.sleep(int(os.getenv("connect_sleep", 10)))
        if not connect_success:
            exit(-1)
        self.trecv = threading.Thread(target=self._recv)
        self.trecv.start()
        self._send()
        self._close()
        return self._gen_result()

    def _connect_init(self):
        """Open the websocket, send init parameters and wait for a
        {"success": true} acknowledgement within the configured window."""
        end_time = time.time() + float(os.getenv("end_time", 2))
        success = False
        try:
            self.ws = create_connection(self.base_url)
            self.ws.send(json.dumps(self._gen_init_data()))
            while time.time() < end_time and not success:
                data = self.ws.recv()
                logger.info(f"data {data}")
                if len(data) == 0:
                    time.sleep(1)
                    continue
                if isinstance(data, str):
                    try:
                        data = json.loads(data)
                    except Exception:
                        raise Exception("初始化阶段,数据不是 json 字符串格式,终止流程")
                if isinstance(data, dict):
                    success = data.get("success", False)
                    if not success:
                        logger.error(f"初始化失败,返回的结果为 {data},终止流程")
                    else:
                        break
                # Any non-success path terminates the whole process.
                logger.error("初始化阶段,数据不是 json 字符串格式,终止流程")
                exit(-1)
        except websocket.WebSocketConnectionClosedException or TimeoutError:
            # NOTE(review): `except A or B` evaluates to `except A`, so
            # TimeoutError is NOT caught here — a tuple `except (A, B)` was
            # probably intended; confirm before relying on timeout handling.
            raise Exception("初始化阶段连接中断,终止流程")
            # exit(-1)
        except ConnectionRefusedError:
            raise Exception("初始化阶段,连接失败,等待 10s 后重试,最多重试 5 次")
            # logger.error("初始化阶段,连接失败,等待 10s 后重试,最多重试 5 次")
            # self.connect_num += 1
            # if self.connect_num <= 4:
            #     time.sleep(int(os.getenv("connect_sleep", 10)))
            #     self._connect_init()
            #     success = True
            # else:
            #     logger.error("初始化阶段连接失败多次")
            #     exit(-1)
        if not success:
            # logger.error("初始化阶段 60s 没有返回数据,时间太长,终止流程")
            raise Exception("初始化阶段 60s 没有返回数据,时间太长,终止流程")
        else:
            logger.info("建立连接成功")
            self.connect_num = 0

    def _send(self):
        """Stream the WAV payload in chunk_size pieces paced by wait_time,
        then send the {"end": true} marker and schedule the close deadline."""
        send_ts = float(os.getenv("send_interval", 60))
        if not self.success:
            return
        with open(self.context.file_path, "rb") as fp:
            wav_data = fp.read()
            # Skip everything up to and including the "data" sub-chunk header.
            meta_length = wav_data.index(b"data") + 8
        try:
            with open(self.context.file_path, "rb") as fp:
                # Discard the wav header bytes.
                fp.read(meta_length)
                # Send time of the previous chunk (-1 before the first send).
                last_send_time = -1
                # Payload body.
                while True:
                    now_time = time.perf_counter()
                    if last_send_time == -1:
                        chunk = fp.read(int(self.context.chunk_size))
                    else:
                        # Catch up if we fell behind by sending multiple
                        # intervals' worth of audio at once.
                        interval_cnt = max(
                            int((now_time - last_send_time) / self.context.wait_time),
                            1,
                        )
                        chunk = fp.read(int(self.context.chunk_size * interval_cnt))
                    if not chunk:
                        break
                    send_time_start = time.perf_counter()
                    self.ws.send(chunk, websocket.ABNF.OPCODE_BINARY)
                    self.send_time.append(send_time_start)
                    last_send_time = send_time_start
                    send_time_end = time.perf_counter()
                    if send_time_end - send_time_start > send_ts:
                        logger.error(f"发送延迟已经超过 {send_ts}s, 终止当前音频发送")
                        break
                    if (sleep_time := self.context.wait_time + now_time - send_time_end) > 0:
                        time.sleep(sleep_time)
            logger.info("当条语音数据发送完成")
            self.ws.send(json.dumps({"end": True}))
            logger.info("2s 后关闭双向连接.")
        except BrokenPipeError:
            logger.error("发送数据出错,被测服务出现故障")
        except Exception as e:
            logger.error(f"Exception: {e}")
            logger.error(f"{traceback.print_exc()}")
            logger.error("发送数据失败")
            self.success = False
        # self.close_time = time.perf_counter() + int(os.getenv("api_timeout", 2))
        self.close_time = time.perf_counter() + 20 * 60

    def _recv(self):
        """Receiver thread: collect raw result frames and stop when the
        service sends the final empty merged result."""
        try:
            while self.ws.connected and self.success:
                recv_data = self.ws.recv()
                if isinstance(recv_data, str):
                    if recv_data := str(recv_data):
                        self.recv_time.append(time.perf_counter())
                        # Close once the final merged (all-zero) result arrives.
                        # NOTE(review): `.recognition_results` does not match the
                        # `asr_results` field in schemas/stream.py — confirm
                        # which schema version this client runs against.
                        recognition_results = StreamResultModel(**json.loads(recv_data)).recognition_results
                        if (
                            recognition_results.final_result
                            and recognition_results.start_time == 0
                            and recognition_results.end_time == 0
                            and recognition_results.para_seq == 0
                        ):
                            self.success = False
                        else:
                            self.predict_data.append(recv_data)
                        # if recv_data.recognition_results.final_result and (IN_TEST or os.getenv('test')):
                        #     logger.info(f"recv_data {recv_data}")
                else:
                    self.success = False
                    raise Exception("返回的结果不是字符串形式")
        except websocket.WebSocketConnectionClosedException:
            logger.error("WebSocketConnectionClosedException")
        except ValidationError as e:
            logger.error("返回的结果不符合格式")
            logger.error(f"Exception is {e}")
            os._exit(1)
        except OSError:
            pass
        except Exception:
            logger.error(f"{traceback.print_exc()}")
            logger.error("处理被测服务返回数据时出错")
            self.success = False

    def _close(self):
        """Block until the close deadline passes or the receiver flags done,
        then close the websocket (errors ignored)."""
        while time.perf_counter() < self.close_time and self.success:
            # while not self.success:
            time.sleep(1)
        try:
            self.ws.close()
        except Exception as e:
            print(e)
            pass

    def _gen_result(self) -> dict:
        """Parse the collected raw frames and package timings + predictions."""
        if not self.predict_data:
            logger.error("没有任何数据返回")
        self.predict_data = [StreamResultModel(**json.loads(data)).recognition_results for data in self.predict_data]
        # for item in self.predict_data:
        #     if item.final_result and (IN_TEST or os.getenv('test')):
        #         logger.info(f"recv_data {item}")
        return {
            "fail": not self.predict_data,
            "send_time": self.send_time,
            "recv_time": self.recv_time,
            "predict_data": self.predict_data,
        }

    def _gen_init_data(self) -> dict:
        """Build the initialization message sent right after connecting."""
        return {
            "parameter": {
                "lang": self.context.lang,
                "sample_rate": self.context.sample_rate,
                "channel": self.context.channel,
                "format": self.context.audio_format,
                "bits": self.context.bits,
                "enable_words": self.context.enable_words,
            }
        }

277
utils/client_async.py Normal file
View File

@@ -0,0 +1,277 @@
import asyncio
import json
import os
import time
import traceback
from copy import deepcopy
from enum import Enum
from typing import Any, List
import websockets
from pydantic_core import ValidationError
from schemas.context import ASRContext
from schemas.stream import StreamResultModel, StreamWordsModel
from utils.logger import logger
IN_TEST = os.getenv("SUBMIT_CONFIG_FILEPATH", None) is None
class STATUS_DATA(str, Enum):
    """Handshake states exchanged between the sender and receiver coroutines."""

    WAITING_FIRST_INIT = "waiting_first_init"
    # NOTE(review): member values below are inconsistently named
    # (FIRST_FAIL="fail", SECOND_INIT="second_fail", THIRD_INIT="third_fail") —
    # looks like copy/paste leftovers; confirm before comparing by value.
    FIRST_FAIL = "fail"
    WAITING_SECOND_INIT = "waiting_second_init"
    SECOND_INIT = "second_fail"
    WAITING_THIRD_INIT = "waiting_third_init"
    THIRD_INIT = "third_fail"
    SUCCESS = "success"
    CLOSED = "closed"
class ClientAsync:
def __init__(self, sut_url: str, context: ASRContext, idx: int) -> None:
# base_url = "ws://127.0.0.1:5003"
self.base_url = sut_url + "/recognition"
self.context: ASRContext = deepcopy(context)
self.idx = idx
# if not os.getenv("DATASET_FILEPATH", ""):
# self.base_url = "wss://speech.4paradigm.com/aibuds/api/v1/recognition"
# self.base_url = "ws://localhost:5003/recognition"
self.fail_count = 0
self.close_time = 10**50
self.send_time: List[float] = []
self.recv_time: List[float] = []
self.predict_data: List[Any] = []
async def _sender(
self, websocket: websockets.WebSocketClientProtocol, send_queue: asyncio.Queue, recv_queue: asyncio.Queue
):
# 设置 websocket 缓冲区大小
websocket.transport.set_write_buffer_limits(1024 * 1024 * 1024)
# 发送初始化数据
await websocket.send(json.dumps(self._gen_init_data()))
await send_queue.put(STATUS_DATA.WAITING_FIRST_INIT)
connect_status = await recv_queue.get()
if connect_status == STATUS_DATA.FIRST_FAIL:
return
# 开始发送音频
with open(self.context.file_path, "rb") as fp:
wav_data = fp.read()
meta_length = wav_data.index(b"data") + 8
try:
with open(self.context.file_path, "rb") as fp:
# 去掉 wav 文件的头信息
fp.read(meta_length)
wav_time = 0.0
label_id = 0
char_contains_rate_checktime = []
char_contains_rate_checktime_id = 0
while True:
now_time = time.perf_counter()
chunk = fp.read(int(self.context.chunk_size))
if not chunk:
break
wav_time += self.context.wait_time
try:
self.send_time.append(time.perf_counter())
await asyncio.wait_for(websocket.send(chunk), timeout=0.08)
except asyncio.exceptions.TimeoutError:
pass
while label_id < len(self.context.labels) and wav_time >= self.context.labels[label_id].start:
char_contains_rate_checktime.append(now_time + 3.0)
label_id += 1
predict_text_len = sum(map(lambda x: len(x.text), self.predict_data))
while char_contains_rate_checktime_id < len(char_contains_rate_checktime) and \
char_contains_rate_checktime[char_contains_rate_checktime_id] <= now_time:
label_text_len = sum(
map(lambda x: len(x.answer),
self.context.labels[:char_contains_rate_checktime_id+1]))
if predict_text_len / self.context.char_contains_rate < label_text_len:
self.context.fail_char_contains_rate_num += 1
char_contains_rate_checktime_id += 1
await asyncio.sleep(max(0, self.context.wait_time - (time.perf_counter() - now_time)))
await websocket.send(json.dumps({"end": True}))
logger.info(f"{self.idx} 条数据,当条语音数据发送完成")
logger.info(f"{self.idx} 条数据3s 后关闭双向连接.")
self.close_time = time.perf_counter() + 3
except websockets.exceptions.ConnectionClosedError:
logger.error(f"{self.idx} 条数据发送过程中,连接断开")
except Exception:
logger.error(f"{traceback.print_exc()}")
logger.error(f"{self.idx} 条数据,发送数据失败")
async def _recv(
self, websocket: websockets.WebSocketClientProtocol, send_queue: asyncio.Queue, recv_queue: asyncio.Queue
):
await recv_queue.get()
try:
await asyncio.wait_for(websocket.recv(), timeout=2)
except asyncio.exceptions.TimeoutError:
await send_queue.put(STATUS_DATA.FIRST_FAIL)
logger.info(f"{self.idx} 条数据,初始化阶段, 2s 没收到 success 返回,超时了")
self.fail_count += 1
return
except Exception as e:
await send_queue.put(STATUS_DATA.FIRST_FAIL)
logger.error(f"{self.idx} 条数据,初始化阶段, 收到异常:{e}")
self.fail_count += 1
return
else:
await send_queue.put(STATUS_DATA.SUCCESS)
# 开始接收语音识别结果
try:
while websocket.open:
# 接收数据
recv_data = await websocket.recv()
if isinstance(recv_data, str):
self.recv_time.append(time.perf_counter())
recv_data = str(recv_data)
recv_data = json.loads(recv_data)
result = StreamResultModel(**recv_data)
recognition_results = result.asr_results
if (
recognition_results.final_result
and not recognition_results.language
and recognition_results.start_time == 0
and recognition_results.end_time == 0
and recognition_results.para_seq == 0
):
pass
else:
self.predict_data.append(recognition_results)
else:
raise Exception("返回的结果不是字符串形式")
except websockets.exceptions.ConnectionClosedOK:
pass
except websockets.exceptions.ConnectionClosedError:
pass
except ValidationError as e:
logger.error(f"{self.idx} 条数据,返回的结果不符合格式")
logger.error(f"Exception is {e}")
os._exit(1)
except OSError:
pass
except Exception:
logger.error(f"{traceback.print_exc()}")
logger.error(f"{self.idx} 条数据,处理被测服务返回数据时出错")
async def _action(self):
    """Run one full test round for this audio item, retrying on init failure.

    Opens a websocket to the service under test, runs the sender and receiver
    coroutines, and retries up to 3 times if initialization never succeeded
    (detected by ``self.send_time`` staying empty).
    """
    logger.info(f"第 {self.idx} 条数据开始测试")
    while self.fail_count < 3:
        send_queue = asyncio.Queue()
        recv_queue = asyncio.Queue()
        self.send_time: List[float] = []
        self.recv_time: List[float] = []
        self.predict_data: List[Any] = []
        async with websockets.connect(self.base_url) as websocket:
            send_task = asyncio.create_task(self._sender(websocket, send_queue, recv_queue))
            recv_task = asyncio.create_task(self._recv(websocket, recv_queue, send_queue))
            # Wait for the sender to finish, give the server 3s to flush the
            # remaining results, then drain the receiver.
            await asyncio.gather(send_task)
            await asyncio.sleep(3)
            await asyncio.gather(recv_task)
        if self.send_time:
            break
        else:
            self.fail_count += 1
            logger.info(f"第 {self.idx} 条数据,初始化阶段, 第 {self.fail_count} 次失败, 1s 后重试")
            # BUG FIX: time.sleep(1) here would block the running event loop
            # (and any other concurrent test coroutines); use asyncio.sleep.
            await asyncio.sleep(1)
def action(self):
    """Synchronous entry point: drive the async test flow to completion and
    return the aggregated result context."""
    asyncio.run(self._action())
    outcome = self._gen_result()
    return outcome
def _gen_result(self) -> ASRContext:
    """Fold the collected predictions into the context and score punctuation.

    For every labelled sentence boundary, searches the timestamps of predicted
    punctuation words for one falling inside the allowed window; each match
    increments ``self.context.pred_punctuation_num``.
    """
    import bisect  # stdlib replacement for the previous hand-rolled binary searches

    if not self.predict_data:
        logger.error(f"第 {self.idx} 条数据,没有任何数据返回")
    self.context.append_preds(self.predict_data, self.send_time, self.recv_time)
    self.context.fail = not self.predict_data
    # Collect predicted words that are punctuation marks for the detected language.
    punctuation_words: List[StreamWordsModel] = []
    for pred in self.predict_data:
        punctuations = [",", ".", "!", "?"]
        # NOTE(review): the language-specific lists below contain empty strings
        # in this revision — they look like punctuation characters stripped by
        # an encoding problem; restore the originals from version control.
        if pred.language == "zh":
            punctuations = ["", "", "", ""]
        elif pred.language == "ja":
            punctuations = ["", "", "", ""]
        elif pred.language in ("ar", "fa"):
            punctuations = ["،", ".", "!", "؟"]
        elif pred.language == "el":
            punctuations = [",", ".", "", ""]
        elif pred.language == "ti":
            punctuations = [""]
        for word in pred.words:
            if word.text in punctuations:
                punctuation_words.append(word)
    # Sorted timestamp arrays for binary search (hoisted out of the label loop).
    start_times = sorted(w.start_time for w in punctuation_words)
    end_times = sorted(w.end_time for w in punctuation_words)
    self.context.punctuation_num = len(self.context.labels)
    label_n = len(self.context.labels)
    for i, label in enumerate(self.context.labels):
        # Match window: +/-0.7s around the label end, or — for non-final
        # labels — the gap between this label's end and the next one's start.
        label_left = (label.end - 0.7)
        label_right = (label.end + 0.7)
        if i < label_n - 1:
            label_left = label.end
            label_right = self.context.labels[i + 1].start
        exist = False
        # First predicted punctuation whose start_time >= label_left
        # (bisect_left == the original upper_bound helper).
        li = bisect.bisect_left(start_times, label_left)
        if li < len(start_times) and start_times[li] <= label_right:
            exist = True
        # Last predicted punctuation whose end_time <= label_right
        # (bisect_right - 1 == the original lower_bound helper).
        ri = bisect.bisect_right(end_times, label_right) - 1
        if ri >= 0 and end_times[ri] >= label_left:
            exist = True
        if exist:
            self.context.pred_punctuation_num += 1
    return self.context
def _gen_init_data(self) -> dict:
return {
"parameter": {
"lang": None,
"sample_rate": self.context.sample_rate,
"channel": self.context.channel,
"format": self.context.audio_format,
"bits": self.context.bits,
"enable_words": self.context.enable_words,
}
}

409
utils/client_callback.py Normal file
View File

@@ -0,0 +1,409 @@
import logging
import os
import threading
import time
from typing import Dict, List, Optional
import requests
from flask import Flask, abort, request
from pydantic import BaseModel, Field, ValidationError, field_validator
from schemas.dataset import QueryData
from schemas.stream import StreamDataModel
from utils.evaluator_plus import evaluate_editops, evaluate_punctuation
from .logger import log
# Pod IP injected by the orchestrator (e.g. Kubernetes downward API); used to
# build the callback URLs handed to the service under test. Raises KeyError at
# import time if unset — intentional fail-fast.
MY_POD_IP = os.environ["MY_POD_IP"]


class StopException(Exception):
    """Raised to abort the current evaluation flow."""
    ...
class EvaluateResult(BaseModel):
    """Aggregated evaluation metrics for one audio query, plus raw preds/labels."""

    lang: str
    cer: float
    # Histogram: sentence-start alignment delta (ms) -> number of aligned sentences.
    align_start: Dict[int, int] = Field(
        description="句首字对齐时间差值(ms) -> 对齐数"
    )
    # Histogram: sentence-end alignment delta (ms) -> number of aligned sentences.
    align_end: Dict[int, int] = Field(
        description="句尾字对齐时间差值(ms) -> 对齐数"
    )
    # Summed distance (seconds) between predicted and labelled first words.
    first_word_distance_sum: float = Field(description="句首字距离总和(s)")
    # Summed distance (seconds) between predicted and labelled last words.
    last_word_distance_sum: float = Field(description="句尾字距离总和(s)")
    rtf: float = Field(description="翻译速度")
    first_receive_delay: float = Field(description="首包接收延迟(s)")
    query_count: int = Field(description="音频数")
    voice_count: int = Field(description="句子数")
    pred_punctuation_num: int = Field(description="预测标点数")
    label_punctuation_num: int = Field(description="标注标点数")
    pred_sentence_punctuation_num: int = Field(description="预测句子标点数")
    # NOTE(review): "setence" is a typo for "sentence"; kept — renaming would
    # break serialized payloads and downstream readers.
    label_setence_punctuation_num: int = Field(description="标注句子标点数")
    preds: List[StreamDataModel] = Field(description="预测结果")
    label: QueryData = Field(description="标注结果")
class ResultModel(BaseModel):
    """One callback payload posted by the ASR service under test."""

    taskId: str
    status: str
    message: str = Field("")
    recognition_results: Optional[StreamDataModel] = Field(None)

    @field_validator("recognition_results", mode="after")
    def convert_to_seconds(cls, v: Optional[StreamDataModel], values):
        # Convert all timestamps from milliseconds to seconds in place.
        if v is None:
            return v
        v.end_time = v.end_time / 1000
        v.start_time = v.start_time / 1000
        for word in v.words:
            word.start_time /= 1000
            word.end_time /= 1000
        return v
class ClientCallback:
    """Callback-style client used to evaluate a batch ASR service.

    A Flask app running in a daemon thread receives per-task recognition
    callbacks and heartbeat reports from the service under test. ``predict``
    submits one audio file at a time; ``dead_line_check`` and
    ``heartbeat_check`` watchdog threads enforce progress/heartbeat deadlines;
    ``evaluate`` turns the collected results into an ``EvaluateResult``.
    """

    def __init__(self, sut_url: str, port: int):
        self.sut_url = sut_url  # URL of the ASR service under test, e.g. http://asr-service:8080
        self.port = port  # local port this client listens on for callbacks
        # Create the Flask app and register the two callback routes:
        #   1. /api/asr/batch-callback/<taskId> — recognition-result callbacks
        #      (handled by self.asr_callback; taskId identifies the task).
        #   2. /api/asr-runner/report — heartbeat reports from the runner
        #      (handled by self.heartbeat).
        self.app = Flask(__name__)
        self.app.add_url_rule(
            "/api/asr/batch-callback/<taskId>",
            view_func=self.asr_callback,
            methods=["POST"],
        )
        self.app.add_url_rule(
            "/api/asr-runner/report",
            view_func=self.heartbeat,
            methods=["POST"],
        )
        logging.getLogger("werkzeug").disabled = True
        # Serve the callback endpoints in a daemon thread so predict() can block.
        threading.Thread(
            target=self.app.run, args=("0.0.0.0", port), daemon=True
        ).start()
        self.mutex = threading.Lock()
        self.finished = threading.Event()
        # NOTE(review): attribute name is a typo for "product_available";
        # kept because external code reads this exact name.
        self.product_avaiable = True
        self.reset()

    def reset(self):
        """Reset all per-audio state before a new predict() round."""
        self.begin_time = None
        self.end_time = None
        self.first_receive_time = None
        self.last_heartbeat_time = None
        self.app_on = False
        self.para_seq = 0
        self.finished.clear()
        self.error: Optional[str] = None
        self.last_recognition_result: Optional[StreamDataModel] = None
        self.recognition_results: List[StreamDataModel] = []

    def asr_callback(self, taskId: str):
        """Flask view: receive and validate one recognition callback for taskId.

        Enforces the streaming contract (para_seq monotonic, start_time not
        before the previous end_time); any violation records the error, stops
        the run, and returns HTTP 404.
        """
        if self.app_on is False:
            abort(400)
        body = request.get_json(silent=True)  # silent parse: returns None on invalid JSON
        if body is None:
            abort(404)
        try:
            # Validate the callback payload structure.
            result = ResultModel.model_validate(body)
        except ValidationError as e:
            log.error("asr_callback: 结果格式错误: %s", e)
            abort(404)
        # Task finished: stop the run.
        if result.status == "FINISHED":
            with self.mutex:
                self.stop()
            return "ok"
        # Any status other than RUNNING is an error.
        if result.status != "RUNNING":
            log.error(
                "asr_callback: 结果状态错误: %s, message: %s",
                result.status,
                result.message,
            )
            abort(404)
        recognition_result = result.recognition_results
        if recognition_result is None:
            log.error("asr_callback: 结果中没有recognition_results字段")
            abort(404)
        with self.mutex:
            if not self.app_on:
                log.error("asr_callback: 应用已结束")
                abort(400)
            # para_seq must never go backwards.
            if recognition_result.para_seq < self.para_seq:
                error = "asr_callback: 结果中para_seq小于上一次的: %d < %d" % (
                    recognition_result.para_seq,
                    self.para_seq,
                )
                log.error(error)
                if self.error is None:
                    self.error = error
                self.stop()
                abort(404)
            # para_seq must not skip ahead: every segment needs a final_result=True ack.
            if recognition_result.para_seq > self.para_seq + 1:
                error = (
                    "asr_callback: 结果中para_seq大于上一次的+1 \
说明存在para_seq = %d没有final_result为True确认"
                    % (self.para_seq + 1,)
                )
                log.error(error)
                if self.error is None:
                    self.error = error
                self.stop()
                abort(404)
            # Timestamps must be non-overlapping and increasing.
            if (
                self.last_recognition_result is not None
                and recognition_result.start_time
                < self.last_recognition_result.end_time
            ):
                error = "asr_callback: 结果中start_time小于上一次的end_time: %s < %s" % (
                    recognition_result.start_time,
                    self.last_recognition_result.end_time,
                )
                log.error(error)
                if self.error is None:
                    self.error = error
                self.stop()
                abort(404)
            self.recognition_results.append(recognition_result)
            if recognition_result.final_result is True:
                self.para_seq = recognition_result.para_seq
            if self.last_recognition_result is None:
                self.first_receive_time = time.time()
            self.last_recognition_result = recognition_result
        return "ok"

    def heartbeat(self):
        """Flask view: record runner heartbeats (updates last_heartbeat_time).

        BUG FIX: this method had been commented out (wrapped in a string)
        while __init__ still registered it as a view function, so constructing
        ClientCallback raised AttributeError. Restored.
        """
        if self.app_on is False:
            abort(400)
        body = request.get_json(silent=True)
        if body is None:
            abort(404)
        status = body.get("status")
        if status != "RUNNING":
            message = body.get("message", "")
            if message:
                message = ", message: " + message
            log.error("heartbeat: 状态错误: %s%s", status, message)
            return "ok"
        with self.mutex:
            self.last_heartbeat_time = time.time()
        return "ok"

    def predict(
        self,
        language: Optional[str],
        audio_file: str,
        audio_duration: float,
        task_id: str,
    ):
        """Submit one audio file to the service's /predict endpoint.

        Raises StopException if a previous audio is still in flight, the HTTP
        call fails, or the service does not answer status "OK". On success,
        spawns the deadline and heartbeat watchdog threads.
        """
        # The mutex guarantees only one audio is in flight at a time.
        with self.mutex:
            if self.app_on:
                log.error("上一音频尚未完成处理,流程出现异常")
                raise StopException()
            self.reset()
            self.app_on = True
        # BUG FIX: the upload file handle was opened inline and never closed
        # (leaked one fd per audio); use a context manager.
        with open(audio_file, "rb") as audio_fp:
            resp = requests.post(
                self.sut_url + "/predict",
                data={
                    "language": language,
                    "taskId": task_id,
                    "progressCallbackUrl": "http://%s:%d/api/asr/batch-callback/%s"
                    % (MY_POD_IP, self.port, task_id),
                    "heartbeatUrl": "http://%s:%d/api/asr-runner/report" % (MY_POD_IP, self.port),
                },
                files={"file": (audio_file, audio_fp)},
                timeout=60,
            )
        # Response handling.
        if resp.status_code != 200:
            log.error("/predict接口返回http code %s", resp.status_code)
            raise StopException()
        resp.raise_for_status()
        status = resp.json().get("status")
        if status != "OK":
            log.error("/predict接口返回非OK状态: %s", status)
            raise StopException()
        # Watchdog threads: progress deadline and heartbeat timeout.
        threading.Thread(
            target=self.dead_line_check, args=(audio_duration,), daemon=True
        ).start()
        threading.Thread(target=self.heartbeat_check, daemon=True).start()

    def dead_line_check(self, audio_duration: float):
        """Watch recognition progress against a moving "death line".

        Expects the service to keep up with 5.4x real time after a 30s grace
        period; marks the product unavailable if it falls behind, and fails
        the run outright past the hard deadline ``ddl``.
        """
        begin_time = time.time()
        self.begin_time = begin_time
        # First-packet check: allow 10s for the first result.
        self.sleep_to(begin_time + 10)
        with self.mutex:
            if self.last_recognition_result is None:
                error = "首包延迟内未收到返回"
                log.error(error)
                if self.error is None:
                    self.error = error
                self.stop()
                return
        # First progress check happens at begin_time + 30.
        next_checktime = begin_time + 30
        ddl = begin_time + max((audio_duration / 6) + 10, 30)
        while time.time() < ddl:
            self.sleep_to(next_checktime)
            with self.mutex:
                if self.finished.is_set():
                    return
                if self.last_recognition_result is None:
                    error = "检测追赶线过程中获取最后一次识别结果异常"
                    log.error(error)
                    if self.error is None:
                        self.error = error
                    self.stop()
                    return
                last_end_time = self.last_recognition_result.end_time
                # Audio position expected at 5.4x real time after the 30s grace.
                expect_end_time = (next_checktime - begin_time - 30) * 5.4
                if last_end_time < expect_end_time:
                    log.warning(
                        "识别时间位置 %s 被死亡追赶线 %s 已追上,将置为产品不可用",
                        last_end_time,
                        expect_end_time,
                    )
                    self.product_avaiable = False
                    # NOTE(review): sleeping to ddl while holding self.mutex
                    # blocks asr_callback/heartbeat until then — confirm intended.
                    self.sleep_to(ddl)
                    break
                next_checktime = last_end_time / 5.4 + begin_time + 30 + 1
                next_checktime = min(next_checktime, ddl)
        with self.mutex:
            if self.finished.is_set():
                return
        log.warning("识别速度rtf低于1/6, 将置为产品不可用")
        self.product_avaiable = False
        # Extended grace: wait up to the softer 1/3-real-time deadline.
        self.sleep_to(begin_time + max((audio_duration / 3) + 10, 30))
        with self.mutex:
            if self.finished.is_set():
                return
            error = "处理时间超过ddl %s " % (ddl - begin_time)
            log.error(error)
            if self.error is None:
                self.error = error
            self.stop()
            return

    def heartbeat_check(self):
        """Fail the run if no heartbeat arrives for more than 30 seconds."""
        self.last_heartbeat_time = time.time()
        while True:
            with self.mutex:
                if self.finished.is_set():
                    return
                if time.time() - self.last_heartbeat_time > 30:
                    error = "asr_runner 心跳超时 %s" % (
                        time.time() - self.last_heartbeat_time
                    )
                    log.error(error)
                    if self.error is None:
                        self.error = error
                    self.stop()
                    return
            time.sleep(5)

    def sleep_to(self, to: float):
        """Sleep until wall-clock time ``to``; no-op if that time has passed."""
        seconds = to - time.time()
        if seconds <= 0:
            return
        time.sleep(seconds)

    def stop(self):
        """Mark the current audio as finished (callers hold self.mutex)."""
        self.end_time = time.time()
        self.finished.set()
        self.app_on = False

    def evaluate(self, query_data: QueryData):
        """Score the collected recognition results against ``query_data``.

        Raises StopException if timing bookkeeping is incomplete; may flip
        ``product_avaiable`` when the 300ms first-word alignment rate < 0.8.
        """
        log.info("开始评估")
        if (
            self.begin_time is None
            or self.end_time is None
            or self.first_receive_time is None
        ):
            if self.begin_time is None:
                log.error("评估流程异常 无开始时间")
            if self.end_time is None:
                log.error("评估流程异常 无结束时间")
            if self.first_receive_time is None:
                log.error("评估流程异常 无首次接收时间")
            raise StopException()
        # RTF: wall time minus the 10s first-packet grace, over audio duration.
        rtf = max(self.end_time - self.begin_time - 10, 0) / query_data.duration
        first_receive_delay = max(self.first_receive_time - self.begin_time, 0)
        query_count = 1
        voice_count = len(query_data.voice)
        preds = self.recognition_results
        # Keep only final results for scoring.
        self.recognition_results = list(
            filter(lambda x: x.final_result, self.recognition_results)
        )
        (
            pred_punctuation_num,
            label_punctuation_num,
            pred_sentence_punctuation_num,
            label_setence_punctuation_num,
        ) = evaluate_punctuation(query_data, self.recognition_results)
        (
            cer,
            _,
            align_start,
            align_end,
            first_word_distance_sum,
            last_word_distance_sum,
        ) = evaluate_editops(query_data, self.recognition_results)
        # NOTE(review): assumes align_start always contains a 300ms bucket and
        # voice_count > 0 — confirm evaluate_editops/dataset guarantee this.
        if align_start[300] / voice_count < 0.8:
            log.warning(
                "评估结果首字300ms对齐率 %s < 0.8, 将置为产品不可用",
                align_start[300] / voice_count,
            )
            self.product_avaiable = False
        return EvaluateResult(
            lang=query_data.lang,
            cer=cer,
            align_start=align_start,
            align_end=align_end,
            first_word_distance_sum=first_word_distance_sum,
            last_word_distance_sum=last_word_distance_sum,
            rtf=rtf,
            first_receive_delay=first_receive_delay,
            query_count=query_count,
            voice_count=voice_count,
            pred_punctuation_num=pred_punctuation_num,
            label_punctuation_num=label_punctuation_num,
            pred_sentence_punctuation_num=pred_sentence_punctuation_num,
            label_setence_punctuation_num=label_setence_punctuation_num,
            preds=preds,
            label=query_data,
        )

445
utils/evaluate.py Normal file
View File

@@ -0,0 +1,445 @@
import os
import subprocess
from collections import defaultdict
from typing import Dict, List
from utils import asr_ter
from utils.logger import logger
log_mid_result = int(os.getenv("log", 0)) == 1
class AsrEvaluator:
    """Base evaluator: accumulates CER and sentence-cut metrics across queries."""

    def __init__(self) -> None:
        self.query_count = 0  # number of queries (audio clips)
        self.voice_count = 0  # voice segments carrying start/end times (used for RTF)
        # Punctuation used for sentence splitting; split order follows list
        # order, so e.g. "..." must come before ".".
        self.cut_punc = []
        # CER accumulators
        self.one_minus_cer = 0  # sum over queries of (1 - cer)
        self.token_count = 0  # sum over queries of char/word counts
        # sentence-cut accumulators
        self.miss_count = 0  # sum over queries of missed cut points
        self.more_count = 0  # sum over queries of spurious cut points
        self.cut_count = 0  # sum over queries of labelled cut points
        self.rate = 0  # sum over queries of cut-rate
        # per-query detail cases
        self.result = []

    def evaluate(self, eval_result):
        # Implemented by language-specific subclasses.
        pass

    def post_evaluate(self):
        pass

    def gen_result(self) -> Dict:
        """Aggregate accumulated metrics into (output_result, detail_case).

        NOTE(review): reads self.rtf / self.rtf_end, which are never set in
        __init__ — AttributeError unless assigned elsewhere. Divisions also
        assume query_count and voice_count are non-zero; confirm callers.
        """
        output_result = dict()
        output_result["query_count"] = self.query_count
        output_result["voice_count"] = self.voice_count
        output_result["token_cnt"] = self.token_count
        output_result["one_minus_cer"] = self.one_minus_cer
        output_result["one_minus_cer_metrics"] = self.one_minus_cer / self.query_count
        output_result["miss_count"] = self.miss_count
        output_result["more_count"] = self.more_count
        output_result["cut_count"] = self.cut_count
        output_result["cut_rate"] = self.rate
        output_result["cut_rate_metrics"] = self.rate / self.query_count
        output_result["rtf"] = self.rtf
        output_result["rtf_end"] = self.rtf_end
        output_result["rtf_metrics"] = self.rtf / self.voice_count
        output_result["rtf_end_metrics"] = self.rtf_end / self.voice_count
        detail_case = self.result
        return output_result, detail_case

    def _get_predict_final_sentences(self, predict_data: List[Dict]) -> List[str]:
        """Return the text of every prediction marked final_result=True.

        NOTE(review): "recoginition_results" is a misspelled key kept for
        compatibility with the producer's payload format.
        """
        return [
            item["recoginition_results"]["text"]
            for item in predict_data
            if item["recoginition_results"]["final_result"]
        ]

    def _sentence_final_index(self, sentences: List[str], tokens: List[str], tokenizer="word") -> List[int]:
        """Map each sentence's tokens to their indices in the aligned token list.

        Appends one index per matched token; the last appended index for a
        sentence is the position of that sentence's final token.

        NOTE(review): `Tokenizer` is not imported in this module — this method
        raises NameError as written; confirm the missing import.
        """
        token_index_list = []
        token_idx = 0
        for sentence in sentences:
            for token in Tokenizer.tokenize(sentence, tokenizer):
                if token not in tokens:
                    continue
                while tokens[token_idx] != token:
                    token_idx += 1
                token_index_list.append(token_idx)
        return token_index_list

    def _voice_to_cut_sentence(self, voice_sentences: List[str]) -> Dict:
        """Split dataset voice segments into minimal sentence units.

        Splits each segment with every punc in cut_punc (in order), then drops
        empty pieces.
        Example:
            ["你好,你好呀", "你好,我在写抽象的代码逻辑"]
            ->
            cut_sentences: ["你好", "你好呀", "你好", "我在写抽象的代码逻辑"]
            cut_sentence_index_list: [1, 3]
            (index of the sentence unit containing each voice unit's last char)
        """
        voice_sentences_result = defaultdict(list)
        for voice_sentence in voice_sentences:
            sentence_list = [voice_sentence]
            sentence_tmp_list = []
            for punc in self.cut_punc:
                for sentence in sentence_list:
                    sentence_tmp_list.extend(sentence.split(punc))
                sentence_list, sentence_tmp_list = sentence_tmp_list, []
            sentence_list = [item for item in sentence_list if item]
            # Sentence units produced by the split.
            voice_sentences_result["cut_sentences"].extend(sentence_list)
            # Index of the sentence unit holding this voice unit's last char.
            voice_sentences_result["cut_sentence_index_list"].append(len(voice_sentences_result["cut_sentences"]) - 1)
        return voice_sentences_result

    def _voice_bytes_index(self, timestamp, sample_rate=16000, bit_depth=16, channels=1):
        """Convert a timestamp (seconds) to a byte offset in raw PCM audio."""
        bytes_per_sample = bit_depth // 8
        return timestamp * sample_rate * bytes_per_sample * channels
class AsrZhEvaluator(AsrEvaluator):
    """Chinese-language evaluation: char-level CER plus sentence-cut rate."""

    def __init__(self):
        super().__init__()
        # NOTE(review): several entries below are empty strings in this
        # revision — they appear to be CJK punctuation stripped by an encoding
        # problem; restore the original characters from version control.
        self.cut_zh_punc = ["······", "......", "", "", "", "", "", ""]
        self.cut_en_punc = ["...", ".", ",", "?", "!", ";", ":"]
        self.cut_punc = self.cut_zh_punc + self.cut_en_punc

    def evaluate(self, eval_result) -> Dict:
        """Score one query: accumulate CER and cut-rate from labels vs predictions."""
        self.query_count += 1
        self.voice_count += len(eval_result["voice"])
        # Label voice units (not yet split into sentence units).
        label_voice_sentences = [item["answer"] for item in eval_result["voice"]]
        # print("label_voice_sentences", label_voice_sentences)
        # Convert voice units -> sentence units.
        voice_to_cut_info = self._voice_to_cut_sentence(label_voice_sentences)
        # print("voice_to_cut_info", voice_to_cut_info)
        # Label sentence units.
        label_sentences = voice_to_cut_info["cut_sentences"]
        # Mapping: index of the sentence unit holding each voice unit's last char.
        cut_sentence_index_list = voice_to_cut_info["cut_sentence_index_list"]
        # Normalise label sentence units.
        label_sentences = [self._sentence_norm(sentence) for sentence in label_sentences]
        if log_mid_result:
            logger.info(f"label_sentences {label_sentences}")
        # print("label_sentences", label_sentences)
        # Predicted sentence units (final results only).
        predict_sentences_raw = self._get_predict_final_sentences(eval_result["predict_data"])
        # print("predict_sentences_raw", predict_sentences_raw)
        # Normalise predicted sentence units.
        predict_sentences = [self._sentence_norm(sentence) for sentence in predict_sentences_raw]
        if log_mid_result:
            logger.info(f"predict_sentences {predict_sentences}")
        # print("predict_sentences", predict_sentences)
        # Token-level alignment via minimum edit distance.
        label_tokens, predict_tokens = self._sentence_transfer("".join(label_sentences), "".join(predict_sentences))
        # CER
        cer_info = self.cer(label_sentences, predict_sentences)
        if log_mid_result:
            logger.info(f"cer_info {cer_info}")
        # print("cer_info", cer_info)
        self.one_minus_cer += cer_info["one_minus_cer"]
        self.token_count += cer_info["token_count"]
        # Sentence-cut precision/recall.
        cut_info = self.cut_rate(label_sentences, predict_sentences, label_tokens, predict_tokens)
        if log_mid_result:
            logger.info(f"{cut_info['miss_count']}, {cut_info['more_count']}, {cut_info['rate']}")
        # print("cut_info", cut_info)
        # print(cut_info["miss_count"], cut_info["more_count"], cut_info["rate"])
        self.miss_count += cut_info["miss_count"]
        self.more_count += cut_info["more_count"]
        self.cut_count += cut_info["cut_count"]
        self.rate += cut_info["rate"]
        self.result.append(
            {
                "label_tokens": label_tokens,
                "predict_tokens": predict_tokens,
                "one_minus_cer": cer_info["one_minus_cer"],
                # NOTE(review): "token_count" is filled with one_minus_cer —
                # looks like a copy/paste slip; confirm intended value.
                "token_count": cer_info["one_minus_cer"],
                "miss_count": cut_info["miss_count"],
                "more_count": cut_info["more_count"],
                "cut_count": cut_info["cut_count"],
                "rate": cut_info["rate"],
            }
        )

    def cer(self, label_sentences, predict_sentences):
        """Char error rate via speechio TER; returns 1-cer and ref token count."""
        pred_str = ''.join(predict_sentences) if predict_sentences is not None else ''
        label_str = ''.join(label_sentences)
        r = asr_ter.calc_ter_speechio(pred_str, label_str)
        one_minus_cer = max(1.0 - r['ter'], 0)
        token_count = r['ref_all_token_cnt']
        return {"one_minus_cer": one_minus_cer, "token_count": token_count}

    def cut_rate(self, label_sentences, predict_sentences, label_tokens, predict_tokens):
        """Sentence-cut score: rate = max(1 - (miss + 2*more)/label_cuts, 0).

        NOTE(review): ZeroDivisionError if there are no labelled cut points.
        """
        label_final_index_list = set(self._sentence_final_index(label_sentences, label_tokens))
        pred_final_index_list = set(self._sentence_final_index(predict_sentences, predict_tokens))
        label_sentence_count = len(label_final_index_list)
        miss_count = len(label_final_index_list - pred_final_index_list)
        more_count = len(pred_final_index_list - label_final_index_list)
        rate = max(1 - (miss_count + more_count * 2) / label_sentence_count, 0)
        return {
            "miss_count": miss_count,
            "more_count": more_count,
            "cut_count": label_sentence_count,
            "rate": rate,
            "label_final_index_list": label_final_index_list,
            "pred_final_index_list": pred_final_index_list,
        }

    def _sentence_norm(self, sentence, tokenizer="word"):
        """Normalise one Chinese sentence with speechio textnorm.

        NOTE(review): implicitly returns None for unsupported tokenizers.
        """
        from utils.speechio import textnorm_zh as textnorm
        if tokenizer == "word":
            normalizer = textnorm.TextNorm(
                to_banjiao=True,
                to_upper=True,
                to_lower=False,
                remove_fillers=True,
                remove_erhua=False,  # unlike batch recognition, this is False here
                check_chars=False,
                remove_space=False,
                cc_mode="",
            )
            return normalizer(sentence)
        else:
            logger.error("tokenizer error, not support.")

    def _sentence_transfer(self, label_sentence: str, predict_sentence: str, tokenizer="char"):
        """Align label and prediction char-by-char via minimum edit distance.

        args:
            label: "今天的通话质量不错呀昨天的呢"
            predict: "今天的通话质量不错昨天呢星期"
            tokenizer: tokenisation mode
        return:
            (label_tokens, pred_tokens): equal-length aligned lists where a
            None marks an insertion/deletion position on that side.
        """
        from utils.speechio import error_rate_zh as error_rate
        if tokenizer == "char":
            alignment, score = error_rate.EditDistance(
                error_rate.tokenize_text(label_sentence, tokenizer),
                error_rate.tokenize_text(predict_sentence, tokenizer),
            )
            label_tokens, pred_tokens = [], []
            for align in alignment:
                # print(align.__dict__)
                label_tokens.append(align.ref)
                pred_tokens.append(align.hyp)
            return (label_tokens, pred_tokens)
        else:
            logger.error("tokenizer 出错了,暂时不支持其它的")

    def _pred_data_transfer(self, predict_data, recv_time):
        """Group streaming predictions into per-sentence lists with receive times.

        predict_data = [
            {"recoginition_results": {"text": "1", "final_result": False, "para_seq": 0}},
            {"recoginition_results": {"text": "12", "final_result": False, "para_seq": 0}},
            {"recoginition_results": {"text": "123", "final_result": True, "para_seq": 0}},
            {"recoginition_results": {"text": "4", "final_result": False, "para_seq": 0}},
            {"recoginition_results": {"text": "45", "final_result": False, "para_seq": 0}},
            {"recoginition_results": {"text": "456", "final_result": True, "para_seq": 0}},
        ]
        recv_time = [1, 3, 5, 6, 7, 8]
        ->
        [
            [{'text': '1', 'time': 1}, {'text': '12', 'time': 3}, {'text': '123', 'time': 5}],
            [{'text': '4', 'time': 6}, {'text': '45', 'time': 7}, {'text': '456', 'time': 8}],
        ]
        """
        pred_sentence_info = []
        pred_sentence_index = 0
        for predict_item, recv_time_item in zip(predict_data, recv_time):
            if len(pred_sentence_info) == pred_sentence_index:
                pred_sentence_info.append([])
            pred_sentence_info[pred_sentence_index].append(
                {
                    "text": predict_item["recoginition_results"]["text"],
                    "time": recv_time_item,
                }
            )
            # A final result closes the current sentence group.
            if predict_item["recoginition_results"]["final_result"]:
                pred_sentence_index += 1
        return pred_sentence_info
class AsrEnEvaluator(AsrEvaluator):
    """English-language evaluation: word-level WER plus sentence-cut rate."""

    def evaluate(self, eval_result) -> Dict:
        """Score one query: accumulate WER and cut-rate from labels vs predictions."""
        self.query_count += 1
        self.voice_count += len(eval_result["voice"])
        # Label voice units (not yet split into sentence units).
        label_voice_sentences = [item["answer"] for item in eval_result["voice"]]
        # print("label_voice_sentences", label_voice_sentences)
        # Convert voice units -> sentence units.
        voice_to_cut_info = self._voice_to_cut_sentence(label_voice_sentences)
        # print("voice_to_cut_info", voice_to_cut_info)
        # Label sentence units.
        label_sentences = voice_to_cut_info["cut_sentences"]
        # Mapping: index of the sentence unit holding each voice unit's last word.
        cut_sentence_index_list = voice_to_cut_info["cut_sentence_index_list"]
        # Normalise label sentence units (batch textnorm via subprocess).
        label_sentences = self._sentence_list_norm(label_sentences)
        # [self._sentence_norm(sentence) for sentence in label_sentences]
        # print("label_sentences", label_sentences)
        if log_mid_result:
            logger.info(f"label_sentences {label_sentences}")
        # Predicted sentence units (final results only).
        predict_sentences_raw = self._get_predict_final_sentences(eval_result["predict_data"])
        # print("predict_sentences_raw", predict_sentences_raw)
        # Normalise predicted sentence units.
        predict_sentences = self._sentence_list_norm(predict_sentences_raw)
        # [self._sentence_norm(sentence) for sentence in predict_sentences_raw]
        # print("predict_sentences", predict_sentences)
        if log_mid_result:
            logger.info(f"predict_sentences {predict_sentences}")
        # Word-level alignment via minimum edit distance.
        label_tokens, predict_tokens = self._sentence_transfer(" ".join(label_sentences), " ".join(predict_sentences))
        # print(label_tokens)
        # print(predict_tokens)
        # WER from the aligned token lists.
        cer_info = self.cer(label_tokens, predict_tokens)
        # print("cer_info", cer_info)
        if log_mid_result:
            logger.info(f"cer_info {cer_info}")
        self.one_minus_cer += cer_info["one_minus_cer"]
        self.token_count += cer_info["token_count"]
        # Sentence-cut precision/recall.
        cut_info = self.cut_rate(label_sentences, predict_sentences, label_tokens, predict_tokens)
        # print(cut_info["miss_count"], cut_info["more_count"], cut_info["rate"])
        # print("cut_info", cut_info)
        if log_mid_result:
            logger.info(f"{cut_info['miss_count']}, {cut_info['more_count']}, {cut_info['rate']}")
        self.miss_count += cut_info["miss_count"]
        self.more_count += cut_info["more_count"]
        self.cut_count += cut_info["cut_count"]
        self.rate += cut_info["rate"]
        self.result.append(
            {
                "label_tokens": label_tokens,
                "predict_tokens": predict_tokens,
                "one_minus_cer": cer_info["one_minus_cer"],
                # NOTE(review): "token_count" is filled with one_minus_cer —
                # looks like a copy/paste slip; confirm intended value.
                "token_count": cer_info["one_minus_cer"],
                "miss_count": cut_info["miss_count"],
                "more_count": cut_info["more_count"],
                "cut_count": cut_info["cut_count"],
                "rate": cut_info["rate"],
            }
        )

    def cer(self, label_tokens, predict_tokens):
        """WER from aligned token lists: (S + D + I) / (S + D + C).

        None on the prediction side counts as a deletion, None on the label
        side as an insertion.
        NOTE(review): ZeroDivisionError if the lists are empty (s+d+c == 0).
        """
        s, d, i, c = 0, 0, 0, 0
        for label_token, predict_token in zip(label_tokens, predict_tokens):
            if label_token == predict_token:
                c += 1
            elif predict_token is None:
                d += 1
            elif label_token is None:
                i += 1
            else:
                s += 1
        cer = (s + d + i) / (s + d + c)
        one_minus_cer = max(1.0 - cer, 0)
        token_count = s + d + c
        return {"one_minus_cer": one_minus_cer, "token_count": token_count}

    def cut_rate(self, label_sentences, predict_sentences, label_tokens, predict_tokens):
        """Sentence-cut score using whitespace tokenisation.

        rate = max(1 - (miss + 2*more)/label_cuts, 0).
        NOTE(review): ZeroDivisionError if there are no labelled cut points.
        """
        label_final_index_list = set(self._sentence_final_index(label_sentences, label_tokens, "whitespace"))
        pred_final_index_list = set(self._sentence_final_index(predict_sentences, predict_tokens, "whitespace"))
        label_sentence_count = len(label_final_index_list)
        miss_count = len(label_final_index_list - pred_final_index_list)
        more_count = len(pred_final_index_list - label_final_index_list)
        rate = max(1 - (miss_count + more_count * 2) / label_sentence_count, 0)
        return {
            "miss_count": miss_count,
            "more_count": more_count,
            "cut_count": label_sentence_count,
            "rate": rate,
            "label_final_index_list": label_final_index_list,
            "pred_final_index_list": pred_final_index_list,
        }

    def _sentence_list_norm(self, sentence_list, tokenizer="whitespace"):
        """Normalise a sentence list by round-tripping through textnorm_en.py.

        Writes ./predict.txt, shells out to the speechio normaliser, and reads
        ./predict_norm.txt back.
        NOTE(review): uses cwd-relative temp files and shell=True — not safe
        for concurrent evaluators sharing a working directory; confirm usage
        is single-process.
        """
        pwd = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
        with open('./predict.txt', 'w', encoding='utf-8') as fp:
            for idx, sentence in enumerate(sentence_list):
                fp.write('%s\t%s\n' % (idx, sentence))
        subprocess.run(
            f'PYTHONPATH={pwd}/utils/speechio python {pwd}/utils/speechio/textnorm_en.py --has_key --to_upper ./predict.txt ./predict_norm.txt',
            shell=True,
            check=True,
        )
        sentence_norm = []
        with open('./predict_norm.txt', 'r', encoding='utf-8') as fp:
            for line in fp.readlines():
                line_split_result = line.strip().split('\t', 1)
                if len(line_split_result) >= 2:
                    sentence_norm.append(line_split_result[1])
        # A sentence may normalise to nothing and be dropped entirely.
        return sentence_norm

    def _sentence_transfer(self, label_sentence: str, predict_sentence: str, tokenizer="whitespace"):
        """Align label and prediction word-by-word via minimum edit distance.

        args:
            label: "HELLO WORLD ARE U OK YEP"
            predict: "HELLO WORLD U ARE U OK YEP"
            tokenizer: tokenisation mode
        return:
            label: ["HELLO", "WORLD", None, "ARE", "U", "OK", "YEP"]
            predict: ["HELLO", "WORLD", "U", "ARE", "U", "OK", "YEP"]
        """
        from utils.speechio import error_rate_zh as error_rate
        if tokenizer == "whitespace":
            alignment, score = error_rate.EditDistance(
                error_rate.tokenize_text(label_sentence, tokenizer),
                error_rate.tokenize_text(predict_sentence, tokenizer),
            )
            label_tokens, pred_tokens = [], []
            for align in alignment:
                label_tokens.append(align.ref)
                pred_tokens.append(align.hyp)
            return (label_tokens, pred_tokens)
        else:
            logger.error("tokenizer 出错了,暂时不支持其它的")

    def post_evaluate(self) -> Dict:
        pass

195
utils/evaluator.py Normal file
View File

@@ -0,0 +1,195 @@
# coding: utf-8
import os
from collections import Counter, defaultdict
from itertools import chain
from typing import List
from schemas.context import ASRContext
from utils.logger import logger
from utils.metrics import cer, cut_rate, cut_sentence, first_delay
from utils.metrics import mean_on_counter, patch_unique_token_count
from utils.metrics import revision_delay, text_align, token_mapping
from utils.metrics import var_on_counter
from utils.tokenizer import TOKENIZER_MAPPING, Tokenizer
from utils.update_submit import change_product_available
IN_TEST = os.getenv("SUBMIT_CONFIG_FILEPATH", 1) is None
class BaseEvaluator:
def __init__(self) -> None:
self.query_count = 0 # query 数目(语音数目)
self.voice_count = 0
self.fail_count = 0 # 失败数目
# 首字延迟
self.first_delay_sum = 0
self.first_delay_cnt = 0
# 修正延迟
self.revision_delay_sum = 0
self.revision_delay_cnt = 0
# patch token 信息
self.patch_unique_cnt_counter = Counter()
# text align count
self.start_time_align_count = 0
self.end_time_align_count = 0
self.start_end_count = 0
# 1-cer
self.one_minus_cer = 0
self.token_count = 0
# 1-cer language
self.one_minus_cer_lang = defaultdict(int)
self.query_count_lang = defaultdict(int)
# sentence-cut
self.miss_count = 0
self.more_count = 0
self.sentence_count = 0
self.cut_rate = 0
# detail-case
self.context = ASRContext()
# 时延
self.send_interval = []
self.last_recv_interval = []
# 字含量不达标数
self.fail_char_contains_rate_num = 0
# 标点符号
self.punctuation_num = 0
self.pred_punctuation_num = 0
def evaluate(self, context: ASRContext):
self.query_count += 1
self.query_count_lang[context.lang] += 1
voice_count = len(context.labels)
self.voice_count += voice_count
self.punctuation_num += context.punctuation_num
self.pred_punctuation_num += context.pred_punctuation_num
if not context.fail:
# 首字延迟
first_delay_sum, first_delay_cnt = first_delay(context)
self.first_delay_sum += first_delay_sum
self.first_delay_cnt += first_delay_cnt
# 修正延迟
revision_delay_sum, revision_delay_cnt = revision_delay(context)
self.revision_delay_sum += revision_delay_sum
self.revision_delay_cnt += revision_delay_cnt
# patch token 信息
counter = patch_unique_token_count(context)
self.patch_unique_cnt_counter += counter
else:
self.fail_count += 1
self.fail_char_contains_rate_num += context.fail_char_contains_rate_num
# text align count
start_time_align_count, end_time_align_count, start_end_count = text_align(context)
self.start_time_align_count += start_time_align_count
self.end_time_align_count += end_time_align_count
self.start_end_count += start_end_count
# cer, wer
sentences_gt: List[str] = [item.answer for item in context.labels]
sentences_dt: List[str] = [
item.recognition_results.text for item in context.preds if item.recognition_results.final_result
]
if IN_TEST:
print(sentences_gt)
print(sentences_dt)
sentences_gt: List[str] = cut_sentence(sentences_gt, TOKENIZER_MAPPING.get(context.lang))
sentences_dt: List[str] = cut_sentence(sentences_dt, TOKENIZER_MAPPING.get(context.lang))
if IN_TEST:
print(sentences_gt)
print(sentences_dt)
# norm & tokenize
tokens_gt: List[List[str]] = Tokenizer.norm_and_tokenize(sentences_gt, context.lang)
tokens_dt: List[List[str]] = Tokenizer.norm_and_tokenize(sentences_dt, context.lang)
if IN_TEST:
print(tokens_gt)
print(tokens_dt)
# cer
tokens_gt_mapping, tokens_dt_mapping = token_mapping(list(chain(*tokens_gt)), list(chain(*tokens_dt)))
one_minue_cer, token_count = cer(tokens_gt_mapping, tokens_dt_mapping)
self.one_minus_cer += one_minue_cer
self.token_count += token_count
self.one_minus_cer_lang[context.lang] += one_minue_cer
# cut-rate
rate, sentence_cnt, miss_cnt, more_cnt = cut_rate(tokens_gt, tokens_dt, tokens_gt_mapping, tokens_dt_mapping)
self.cut_rate += rate
self.sentence_count += sentence_cnt
self.miss_count += miss_cnt
self.more_count += more_cnt
# detail-case
self.context = context
# 时延
if self.context.send_time_start_end and self.context.recv_time_start_end:
send_interval = self.context.send_time_start_end[1] - self.context.send_time_start_end[0]
recv_interval = self.context.recv_time_start_end[1] - self.context.send_time_start_end[0]
self.send_interval.append(send_interval)
self.last_recv_interval.append(recv_interval)
logger.info(
f"""第一次发送时间{self.context.send_time_start_end[0]}, \
最后一次发送时间{self.context.send_time_start_end[-1]}, \
发送间隔 {send_interval},
最后一次接收时间{self.context.recv_time_start_end[-1]}, \
接收间隔 {recv_interval}
"""
)
    def post_evaluate(self):
        """Hook called once after all queries have been evaluated; no final
        per-evaluator work is needed here, aggregation happens in gen_result."""
        pass
    def gen_result(self):
        """Aggregate the accumulated per-query counters into the final metrics dict.

        Side effect: calls change_product_available() (marks the product
        unavailable) when the mean first-token delay exceeds the
        FIRST_DELAY_THRESHOLD env var (default 5s) or more than 10% of voices
        contain garbled characters.

        NOTE(review): several divisions below (sentence_count, query_count,
        punctuation_num, voice_count, query_count_lang) are unguarded and will
        raise ZeroDivisionError when no data was evaluated — confirm callers
        guarantee non-empty input.
        """
        result = {
            "query_count": self.query_count,
            "voice_count": self.voice_count,
            "pred_voice_count": self.first_delay_cnt,
            # Delay means fall back to 10 (seconds) when nothing was predicted.
            "first_delay_mean": self.first_delay_sum / self.first_delay_cnt if self.first_delay_cnt > 0 else 10,
            "revision_delay_mean": (
                self.revision_delay_sum / self.revision_delay_cnt if self.revision_delay_cnt > 0 else 10
            ),
            "patch_token_mean": mean_on_counter(self.patch_unique_cnt_counter),
            "patch_token_var": var_on_counter(self.patch_unique_cnt_counter),
            "start_time_align_count": self.start_time_align_count,
            "end_time_align_count": self.end_time_align_count,
            "start_time_align_rate": self.start_time_align_count / self.sentence_count,
            "end_time_align_rate": self.end_time_align_count / self.sentence_count,
            "start_end_count": self.start_end_count,
            "one_minus_cer": self.one_minus_cer / self.query_count,
            "token_count": self.token_count,
            "miss_count": self.miss_count,
            "more_count": self.more_count,
            "sentence_count": self.sentence_count,
            "cut_rate": self.cut_rate / self.query_count,
            "fail_count": self.fail_count,
            "send_interval": self.send_interval,
            "last_recv_interval": self.last_recv_interval,
            "fail_char_contains_rate_num": self.fail_char_contains_rate_num,
            "punctuation_rate": self.pred_punctuation_num / self.punctuation_num,
        }
        # Per-language 1-CER breakdown, averaged over that language's query count.
        for lang in self.one_minus_cer_lang:
            result["one_minus_cer_" + lang] = \
                self.one_minus_cer_lang[lang] / self.query_count_lang[lang]
        if (
            result["first_delay_mean"]
            > float(os.getenv("FIRST_DELAY_THRESHOLD", "5"))
            or
            self.fail_char_contains_rate_num / self.voice_count > 0.1
            # or
            # result["punctuation_rate"] < 0.8
        ):
            change_product_available()
        return result
    def gen_detail_case(self):
        """Return the last evaluated context (stored by evaluate) as the detail case."""
        return self.context

293
utils/evaluator_plus.py Normal file
View File

@@ -0,0 +1,293 @@
from collections import defaultdict
from copy import deepcopy
from itertools import chain
from typing import Dict, List, Tuple
import Levenshtein
from schemas.dataset import QueryData
from schemas.stream import StreamDataModel, StreamWordsModel
from utils.metrics import Tokenizer
from utils.metrics_plus import replace_general_punc
from utils.tokenizer import TOKENIZER_MAPPING
def evaluate_editops(
    query_data: QueryData, recognition_results: List[StreamDataModel]
) -> Tuple[float, int, Dict[int, int], Dict[int, int], float, float]:
    """Score one query's final streaming results against its labels.

    Returns a tuple of:
        cer: token-level error rate over the whole query,
        sentence count (len of label voices),
        align_start: {threshold_ms -> #sentences whose first word starts within it},
        align_end: {threshold_ms -> #sentences whose last word ends within it},
        sum of first-word absolute time distances,
        sum of last-word absolute time distances.
    """
    recognition_results = deepcopy(recognition_results)
    lang = query_data.lang
    voices = query_data.voice
    sentences_pred = [
        recognition_result.text for recognition_result in recognition_results
    ]
    sentences_label = [item.answer for item in voices]
    tokenizer_type = TOKENIZER_MAPPING[lang]
    sentences_pred = replace_general_punc(sentences_pred, tokenizer_type)
    sentences_label = replace_general_punc(sentences_label, tokenizer_type)
    # norm & tokenize
    tokens_pred = Tokenizer.norm_and_tokenize(sentences_pred, lang)
    tokens_label = Tokenizer.norm_and_tokenize(sentences_label, lang)
    normed_words = []
    for recognition_result in recognition_results:
        words = list(map(lambda x: x.text, recognition_result.words))
        normed_words.extend(words)
    normed_words = replace_general_punc(normed_words, tokenizer_type)
    normed_words = Tokenizer.norm(normed_words, lang)
    # Apply the same norm + tokenize pipeline to the per-word predictions so
    # their tokens line up with the sentence-level tokenization above.
    normed_word_index = 0
    for recognition_result in recognition_results:
        next_index = normed_word_index + len(recognition_result.words)
        tokens_words = Tokenizer.tokenize(
            normed_words[normed_word_index:next_index], lang
        )
        normed_word_index = next_index
        stream_words: List[StreamWordsModel] = []
        # After norm/tokenize, each produced token inherits the time span of
        # the raw word it came from.
        for raw_stream_word, tokens_word in zip(
            recognition_result.words, tokens_words
        ):
            for word in tokens_word:
                stream_words.append(
                    StreamWordsModel(
                        text=word,
                        start_time=raw_stream_word.start_time,
                        end_time=raw_stream_word.end_time,
                    )
                )
        recognition_result.words = stream_words
    # Match the timed words onto the tokenized sentence so every token gets a
    # start/end time (tokens with no matching word borrow the neighbor's time).
    pred_word_time: List[StreamWordsModel] = []
    for token_pred, recognition_result in zip(tokens_pred, recognition_results):
        word_index = 0
        for word in recognition_result.words:
            try:
                token_index = token_pred.index(word.text, word_index)
                for i in range(word_index, token_index + 1):
                    pred_word_time.append(
                        StreamWordsModel(
                            text=token_pred[i],
                            start_time=word.start_time,
                            end_time=word.end_time,
                        )
                    )
                word_index = token_index + 1
            except ValueError:
                # Token not found past word_index — skip this word.
                pass
        # Remaining tokens take the last word's (or the sentence's) time span.
        if len(recognition_result.words) > 0:
            word = recognition_result.words[-1]
            start_time = word.start_time
            end_time = word.end_time
        else:
            start_time = recognition_result.start_time
            end_time = recognition_result.end_time
        for i in range(word_index, len(token_pred)):
            pred_word_time.append(
                StreamWordsModel(
                    text=token_pred[i],
                    start_time=start_time,
                    end_time=end_time,
                )
            )
    # Record, per label sentence, the flat-list index of its first/last token.
    index = 0
    label_firstword_index: List[int] = []
    label_lastword_index: List[int] = []
    for token_label in tokens_label:
        label_firstword_index.append(index)
        index += len(token_label)
        label_lastword_index.append(index - 1)
    # cer
    flat_tokens_pred = list(chain(*tokens_pred))
    flat_tokens_label = list(chain(*tokens_label))
    ops = Levenshtein.editops(flat_tokens_pred, flat_tokens_label)
    insert = len(list(filter(lambda x: x[0] == "insert", ops)))
    delete = len(list(filter(lambda x: x[0] == "delete", ops)))
    replace = len(list(filter(lambda x: x[0] == "replace", ops)))
    cer = (insert + delete + replace) / len(flat_tokens_label)
    # Compute each token's index in the post-edit (aligned) sequence, so
    # label tokens can be located among prediction tokens.
    pred_offset = [0] * (len(flat_tokens_pred) + 1)
    label_offset = [0] * (len(flat_tokens_label) + 1)
    for op in ops:
        if op[0] == "insert":
            pred_offset[op[1]] += 1
        elif op[0] == "delete":
            label_offset[op[2]] += 1
    pred_indexs = [pred_offset[0]]
    for i in range(1, len(flat_tokens_pred)):
        pred_indexs.append(pred_indexs[i - 1] + pred_offset[i] + 1)
    label_indexs = [label_offset[0]]
    for i in range(1, len(flat_tokens_label)):
        label_indexs.append(label_indexs[i - 1] + label_offset[i] + 1)
    # For each label sentence, compare its first/last word's time against the
    # aligned predicted token's time, bucketed by ms thresholds.
    align_start = {100: 0, 200: 0, 300: 0, 500: 0}
    align_end = {100: 0, 200: 0, 300: 0, 500: 0}
    first_word_distance_sum = 0.0
    last_word_distance_sum = 0.0
    for firstword_index, lastword_index, voice in zip(
        label_firstword_index, label_lastword_index, voices
    ):
        label_index = label_indexs[firstword_index]
        label_in_pred_index = upper_bound(label_index, pred_indexs)
        if label_in_pred_index != -1:
            distance = abs(
                voice.start - pred_word_time[label_in_pred_index].start_time
            )
            # Also consider the previous token: take whichever is closer.
            if label_in_pred_index > 0:
                distance = min(
                    distance,
                    abs(
                        voice.start
                        - pred_word_time[label_in_pred_index - 1].start_time
                    ),
                )
        else:
            distance = abs(voice.start - pred_word_time[-1].start_time)
        for limit in align_start.keys():
            if distance <= limit / 1000:
                align_start[limit] += 1
        first_word_distance_sum += distance
        label_index = label_indexs[lastword_index]
        label_in_pred_index = lower_bound(label_index, pred_indexs)
        if label_in_pred_index != -1:
            distance = abs(
                voice.end - pred_word_time[label_in_pred_index].end_time
            )
            # Also consider the next token: take whichever is closer.
            if label_in_pred_index < len(pred_word_time) - 1:
                distance = min(
                    distance,
                    abs(
                        voice.end
                        - pred_word_time[label_in_pred_index + 1].end_time
                    ),
                )
        else:
            distance = abs(voice.end - pred_word_time[0].end_time)
        for limit in align_end.keys():
            if distance <= limit / 1000:
                align_end[limit] += 1
        last_word_distance_sum += distance
    return (
        cer,
        len(voices),
        align_start,
        align_end,
        first_word_distance_sum,
        last_word_distance_sum,
    )
def evaluate_punctuation(
    query_data: QueryData, recognition_results: List[StreamDataModel]
) -> Tuple[int, int, int, int]:
    """Evaluate punctuation metrics.

    Returns (predicted punctuation count, label punctuation count,
    predicted sentence-boundary punctuation count, label sentence count).

    NOTE(review): several punctuation literals below (zh/ja/el/ti entries)
    appear as empty strings in this view — they look garbled by text
    extraction; verify against the original file before relying on them.
    """
    # Per-language sentence punctuation sets; unknown languages fall back to
    # the Latin set via the defaultdict factory.
    punctuation_mapping = defaultdict(lambda: [",", ".", "!", "?"])
    punctuation_mapping.update(
        {
            "zh": ["", "", "", ""],
            "ja": ["", "", "", ""],
            "ar": ["،", ".", "!", "؟"],
            "fa": ["،", ".", "!", "؟"],
            "el": [",", ".", "", ""],
            "ti": [""],
            "th": [" ", ",", ".", "!", "?"],
        }
    )
    # Collect every predicted word that contains at least one punctuation char.
    punctuation_words: List[StreamWordsModel] = []
    for recognition_result in recognition_results:
        punctuations = punctuation_mapping[query_data.lang]
        for word in recognition_result.words:
            for char in word.text:
                if char in punctuations:
                    punctuation_words.append(word)
                    break
    punctuation_start_times = list(
        map(lambda x: x.start_time, punctuation_words)
    )
    punctuation_start_times = sorted(punctuation_start_times)
    punctuation_end_times = list(map(lambda x: x.end_time, punctuation_words))
    punctuation_end_times = sorted(punctuation_end_times)
    voices = query_data.voice
    label_len = len(voices)
    pred_punctuation_num = len(punctuation_words)
    # Count punctuation characters in the reference answers.
    label_punctuation_num = 0
    for label_voice in voices:
        punctuations = punctuation_mapping[query_data.lang]
        for char in label_voice.answer:
            if char in punctuations:
                label_punctuation_num += 1
    # For each label sentence boundary, check whether some predicted
    # punctuation word overlaps the gap between this sentence and the next
    # (last sentence uses a +/-0.7s window around its end).
    pred_sentence_punctuation_num = 0
    label_setence_punctuation_num = label_len
    for i, label_voice in enumerate(voices):
        if i < label_len - 1:
            label_left = label_voice.end
            label_right = voices[i + 1].start
        else:
            label_left = label_voice.end - 0.7
            label_right = label_voice.end + 0.7
        left_in_pred = upper_bound(label_left, punctuation_start_times)
        exist = False
        if (
            left_in_pred != -1
            and punctuation_start_times[left_in_pred] <= label_right
        ):
            exist = True
        right_in_pred = lower_bound(label_right, punctuation_end_times)
        if (
            right_in_pred != -1
            and punctuation_end_times[right_in_pred] >= label_left
        ):
            exist = True
        if exist:
            pred_sentence_punctuation_num += 1
    return (
        pred_punctuation_num,
        label_punctuation_num,
        pred_sentence_punctuation_num,
        label_setence_punctuation_num,
    )
def upper_bound(x: float, lst: List[float]) -> int:
    """Return the index of the first element >= x, or -1 when none exists."""
    lo, hi = 0, len(lst)
    while lo < hi:
        mid = (lo + hi) // 2
        if lst[mid] >= x:
            hi = mid
        else:
            lo = mid + 1
    return lo if lo < len(lst) else -1
def lower_bound(x: float, lst: List[float]) -> int:
    """Return the index of the last element <= x, or -1 when none exists."""
    lo, hi = 0, len(lst)
    while lo < hi:
        mid = (lo + hi) // 2
        if lst[mid] <= x:
            lo = mid + 1
        else:
            hi = mid
    # lo is the count of elements <= x, so lo - 1 is the last such index.
    return lo - 1

151
utils/file.py Normal file
View File

@@ -0,0 +1,151 @@
import json
import os
import shutil
import tarfile
import tempfile
import zipfile
from typing import Any
import yaml
def load_json(path: str, raise_for_invalid: bool = False) -> Any:
    """Parse the JSON file at ``path`` into a Python object.

    When ``raise_for_invalid`` is True, non-standard constants such as
    NaN/Infinity raise ``ValueError`` instead of being accepted.
    """
    with open(path, "r", encoding="utf-8") as fp:
        if not raise_for_invalid:
            return json.load(fp)

        def reject_constant(s: str):
            raise ValueError("非法json字符: %s" % s)

        return json.load(fp, parse_constant=reject_constant)
def dump_json(path: str, obj: Any):
    """Serialize ``obj`` as pretty-printed, UTF-8, non-ASCII-escaped JSON into ``path``."""
    text = json.dumps(obj, ensure_ascii=False, indent=4)
    with open(path, "w", encoding="utf-8") as fp:
        fp.write(text)
def load_yaml(path: str) -> Any:
    """Parse the YAML file at ``path`` into a Python object."""
    with open(path, "r", encoding="utf-8") as fp:
        content = fp.read()
    return yaml.full_load(content)
def dump_yaml(path: str, obj: Any):
    """Write ``obj`` to ``path`` as YAML (unicode kept, key order preserved)."""
    text = yaml.dump(obj, indent=2, allow_unicode=True, sort_keys=False, line_break="\n")
    with open(path, "w", encoding="utf-8") as fp:
        fp.write(text)
def dumps_yaml(obj: Any) -> str:
    """Render ``obj`` as a YAML string (unicode kept, key order preserved)."""
    options = dict(indent=2, allow_unicode=True, sort_keys=False, line_break="\n")
    return yaml.dump(obj, **options)
def read_file(path: str) -> str:
    """Return the entire text content of ``path``."""
    with open(path, "r") as fp:
        content = fp.read()
    return content
def write_bfile(path: str, data: bytes):
    """Write the raw bytes ``data`` to ``path``, replacing any existing content."""
    with open(path, mode="wb") as sink:
        sink.write(data)
def write_file(path: str, data: str):
    """Write the text ``data`` to ``path``, replacing any existing content."""
    with open(path, mode="w") as sink:
        sink.write(data)
def tail_file(path: str, tail: int) -> str:
    """Return the last ``tail`` lines of ``path``.

    Reads exponentially growing chunks from the end of the file until enough
    lines are found or the whole file has been consumed.
    """
    chunk = 1024
    with open(path, "rb") as fp:
        fp.seek(0, 2)
        size = fp.tell()
        while True:
            chunk = min(chunk, size)
            fp.seek(size - chunk, 0)
            raw_lines = fp.readlines()
            if len(raw_lines) > tail or size <= chunk:
                return "".join(raw.decode() for raw in raw_lines[-tail:])
            chunk *= 2
def zip_dir(zip_path: str, dirname: str):
    """Pack every file under ``dirname`` into the zip archive ``zip_path``,
    storing paths relative to ``dirname``."""
    with zipfile.ZipFile(zip_path, "w") as archive:
        for root, _, filenames in os.walk(dirname):
            rel_root = root.removeprefix(dirname)
            for filename in filenames:
                archive.write(
                    os.path.join(root, filename),
                    os.path.join(rel_root, filename),
                    zipfile.ZIP_DEFLATED,
                )
def zip_files(name: str, zipfile_paths: list):
    """Pack ``zipfile_paths`` = list[(archive name, source path)] into archive ``name``."""
    with zipfile.ZipFile(name, "w") as archive:
        for arc_name, src_path in zipfile_paths:
            archive.write(src_path, arc_name, zipfile.ZIP_DEFLATED)
def zip_strs(name: str, zipfile_strs: list):
    """Pack ``zipfile_strs`` = list[(filename, content)] into archive ``name``."""
    with zipfile.ZipFile(name, "w") as archive:
        for member_name, payload in zipfile_strs:
            archive.writestr(member_name, payload)
def zip_zipers(name: str, ziper_paths: list):
    """Build archive ``name`` from ``ziper_paths`` = list[(name inside archive,
    path to an archive/file/directory)].

    Each source is materialized into a temp directory first: zip archives are
    extracted, plain files copied, directories copied recursively; missing
    sources are silently skipped. The temp directory is zipped and removed.
    """
    temp_dirname = tempfile.mkdtemp(prefix=name, dir=os.path.dirname(name))
    os.makedirs(temp_dirname, exist_ok=True)
    for subname, ziper_path in ziper_paths:
        sub_dirname = os.path.join(temp_dirname, subname)
        if not os.path.exists(ziper_path):
            continue
        if zipfile.is_zipfile(ziper_path):
            # Archive: extract it into its own sub-directory.
            os.makedirs(sub_dirname, exist_ok=True)
            unzip_dir(ziper_path, sub_dirname)
        elif os.path.isfile(ziper_path):
            # Plain file: copied under the given sub-name.
            shutil.copyfile(ziper_path, sub_dirname)
        else:
            # Directory: copied recursively.
            shutil.copytree(ziper_path, sub_dirname)
    zip_dir(name, temp_dirname)
    shutil.rmtree(temp_dirname)
def unzip_dir(zip_path: str, dirname: str, catch_exc: bool = True):
    """Extract ``zip_path`` into ``dirname``.

    On extraction failure: when ``catch_exc`` is True, record the error in
    ``unzip_error.log`` inside ``dirname`` and copy the raw archive there;
    otherwise re-raise the exception.
    """
    with zipfile.ZipFile(zip_path, "r") as archive:
        try:
            archive.extractall(dirname)
        except Exception as e:
            if not catch_exc:
                raise e
            write_file(os.path.join(dirname, "unzip_error.log"), "%r" % e)
            shutil.copyfile(zip_path, os.path.join(dirname, os.path.basename(zip_path)))
def tar_dir(zip_path: str, dirname: str):
    """Pack every file under ``dirname`` into gzip tarball ``zip_path``,
    storing paths relative to ``dirname``."""
    with tarfile.open(zip_path, "w:gz") as archive:
        for root, _, filenames in os.walk(dirname):
            rel_root = root.removeprefix(dirname)
            for filename in filenames:
                archive.add(os.path.join(root, filename), os.path.join(rel_root, filename))
def untar_dir(zip_path: str, dirname: str):
    """Extract the tarball ``zip_path`` into ``dirname``.

    NOTE(review): ``extractall`` without a ``filter`` trusts member paths, so
    a crafted tarball can write outside ``dirname`` (path traversal). If the
    archives are not fully trusted, use ``extractall(dirname, filter="data")``
    (Python >= 3.12, backported to recent 3.8+ patch releases).
    """
    with tarfile.open(zip_path) as ziper:
        ziper.extractall(dirname)

331
utils/helm.py Normal file
View File

@@ -0,0 +1,331 @@
# -*- coding: utf-8 -*-
import copy
import io
import json
import os
import re
import tarfile
import time
from collections import defaultdict
from typing import Any, Dict, Optional
import requests
from ruamel.yaml import YAML
from utils.logger import logger
# Root of the bundled SUT helm chart shipped next to this package.
sut_chart_root = os.path.join(os.path.dirname(os.path.dirname(os.path.abspath(__file__))), "helm-chart", "sut")
# Authorization header for leaderboard API calls (None when no token is configured).
headers = (
    {'Authorization': 'Bearer ' + os.getenv("LEADERBOARD_API_TOKEN")} if os.getenv("LEADERBOARD_API_TOKEN") else None
)
# Per-pod count of observed image-pull failures; unseen pods default to 0.
# Bug fix: ``defaultdict()`` without a factory raises KeyError on ``+= 1``.
pull_num: defaultdict = defaultdict(int)
JOB_ID = int(os.getenv("JOB_ID", "-1"))
LOAD_SUT_URL = os.getenv("LOAD_SUT_URL")
GET_JOB_SUT_INFO_URL = os.getenv("GET_JOB_SUT_INFO_URL")
def apply_env_to_values(values, envs):
    """Merge ``envs`` (name -> value) into the helm ``values['env']`` list.

    Existing entries (matched on first occurrence of the name) are updated in
    place; unknown names are appended. Returns the mutated ``values``.
    """
    env_list = values.setdefault("env", [])
    position = {}
    for i, entry in enumerate(env_list):
        position.setdefault(entry["name"], i)
    for name, value in envs.items():
        if name in position:
            env_list[position[name]]["value"] = value
        else:
            env_list.append({"name": name, "value": value})
    return values
def merge_values(base_value, incr_value):
    """Recursively merge ``incr_value`` into ``base_value``.

    Dicts merge key-by-key, lists concatenate, and any other combination is
    replaced by ``incr_value``. Returns the merged value (dicts/lists are
    mutated in place).
    """
    if isinstance(base_value, dict) and isinstance(incr_value, dict):
        for key, incr in incr_value.items():
            if key in base_value:
                base_value[key] = merge_values(base_value[key], incr)
            else:
                base_value[key] = incr
        return base_value
    if isinstance(base_value, list) and isinstance(incr_value, list):
        base_value.extend(incr_value)
        return base_value
    return incr_value
def gen_chart_tarball(docker_image):
    """Resolve ``docker_image`` (optionally to a digest) and build the helm chart tarball.

    Args:
        docker_image: docker image reference (``repo:tag``).

    Returns:
        tuple[BytesIO, dict]: (helm chart tarball file object, rendered values dict).

    Raises:
        AssertionError: when the digest-resolution API fails.
        RuntimeError: when the image reference is malformed.
    """
    # Load the values template shipped with the chart.
    with open(os.path.join(sut_chart_root, "values.yaml.tmpl")) as fp:
        yaml = YAML(typ="rt")
        values = yaml.load(fp)
    get_image_hash_url = os.getenv("GET_IMAGE_HASH_URL", None)
    logger.info(f"get_image_hash_url: {get_image_hash_url}")
    if get_image_hash_url is not None:
        # Convert the tag to a content digest so deployments are reproducible.
        # Bug fix: removed a debug leftover that overwrote ``docker_image``
        # with a hard-coded image, which made this function ignore its argument.
        resp = requests.get(get_image_hash_url, headers=headers, params={"image": docker_image}, timeout=600)
        logger.info(f"resp.text: {resp.text}")
        assert resp.status_code == 200, "Convert tag to hash for docker image failed, API retcode %d" % resp.status_code
        resp = resp.json()
        assert resp["success"], "Convert tag to hash for docker image failed, response: %s" % str(resp)
        # Digest form is ``repo:tag@sha256:...`` -> rsplit yields 3 parts.
        token = resp["data"]["image"].rsplit(":", 2)
        assert len(token) == 3, "Invalid docker image %s" % resp["data"]["image"]
        values["image"]["repository"] = token[0]
        values["image"]["tag"] = ":".join(token[1:])
    else:
        token = docker_image.rsplit(":", 1)
        if len(token) != 2:
            raise RuntimeError("Invalid docker image %s" % docker_image)
        values["image"]["repository"] = token[0]
        values["image"]["tag"] = token[1]
    # Render values.yaml next to the template so it is included in the tarball.
    with open(os.path.join(sut_chart_root, "values.yaml"), "w") as fp:
        yaml = YAML(typ="rt")
        yaml.dump(values, fp)
    # Package the whole chart directory as an in-memory gzip tarball.
    tarfp = io.BytesIO()
    with tarfile.open(fileobj=tarfp, mode="w:gz") as tar:
        tar.add(sut_chart_root, arcname=os.path.basename(sut_chart_root), recursive=True)
    tarfp.seek(0)
    logger.debug(f"Generated chart using values: {values}")
    return tarfp, values
def deploy_chart(
    name_suffix,
    readiness_timeout,
    chart_str=None,
    chart_fileobj=None,
    extra_values=None,
    restart_count_limit=3,
    pullimage_count_limit=3,
):
    """Deploy the SUT and block until it is ready; raises on failure.

    Args:
        name_suffix (str): distinguishes multiple SUTs within one job.
        readiness_timeout (int): readiness timeout in seconds.
        chart_str (str, optional): chart url; takes precedence over chart_fileobj.
        chart_fileobj (BytesIO, optional): helm chart tarball, used when chart_str is None.
        extra_values (dict, optional): extra helm values.
        restart_count_limit (int, optional): max SUT restarts before aborting.
        pullimage_count_limit (int, optional): max image-pull failures before aborting.

    Returns:
        tuple[str, str]: (k8s service name for reaching the SUT, name for unload_sut).
    """
    logger.info(f"Deploying SUT application for JOB {JOB_ID}, name_suffix {name_suffix}, extra_values {extra_values}")
    # deploy
    payload = {
        "job_id": JOB_ID,
        "resource_name": name_suffix,
        "priorityclassname": os.environ.get("priorityclassname"),
    }
    extra_values = {} if not extra_values else extra_values
    payload["values"] = json.dumps(extra_values, ensure_ascii=False)
    if chart_str is not None:
        payload["helm_chart"] = chart_str
        resp = requests.post(LOAD_SUT_URL, data=payload, headers=headers, timeout=600)
    else:
        assert chart_fileobj is not None, "Either chart_str or chart_fileobj should be set"
        logger.info(f"LOAD_SUT_URL: {LOAD_SUT_URL}")
        logger.info(f"payload: {payload}")
        logger.info(f"headers: {headers}")
        resp = requests.post(
            LOAD_SUT_URL,
            data=payload,
            headers=headers,
            files=[("helm_chart_file", (name_suffix + ".tgz", chart_fileobj))],
            timeout=600,
        )
    if resp.status_code != 200:
        raise RuntimeError("Failed to deploy application status_code %d %s" % (resp.status_code, resp.text))
    resp = resp.json()
    if not resp["success"]:
        # Bug fix: previously this only logged the failure and then fell
        # through to resp["data"], crashing with an opaque KeyError.
        logger.error("Failed to deploy application response %r", resp)
        raise RuntimeError("Failed to deploy application response %r" % resp)
    service_name = resp["data"]["service_name"]
    sut_name = resp["data"]["sut_name"]
    logger.info(f"SUT application deployed with service_name {service_name}")
    # waiting for appliation ready
    running_at = None
    retry_count = 0
    while True:
        retry_interval = 10
        # Emit a progress line every 20 polls to avoid flooding the log.
        if retry_count % 20 == 19:
            logger.info(f"Waiting {retry_interval} seconds to check whether SUT application {service_name} is ready...")
            logger.info("20 retrys log this message again.")
        # Bug fix: the counter was only incremented inside the ``if`` above,
        # so it stayed at 0 forever and the periodic log never fired.
        retry_count += 1
        time.sleep(retry_interval)
        check_result, running_at = check_sut_ready_from_resp(
            service_name,
            running_at,
            readiness_timeout,
            restart_count_limit,
            pullimage_count_limit,
        )
        if check_result:
            break
    logger.info(f"SUT application for JOB {JOB_ID} name_suffix {name_suffix} is ready, service_name {service_name}")
    return service_name, sut_name
def check_sut_ready_from_resp(
    service_name,
    running_at,
    readiness_timeout,
    restart_count_limit,
    pullimage_count_limit,
):
    """Poll the job's SUT status once and decide whether it is ready.

    Args:
        service_name: SUT service name (used in log/error messages).
        running_at: timestamp when all pods first reached Running, or None.
        readiness_timeout: max seconds a Running pod may stay not-Ready.
        restart_count_limit: max container restarts before aborting.
        pullimage_count_limit: max image-pull failures per pod before aborting.

    Returns:
        tuple[bool, float | None]: (ready?, possibly-updated running_at).

    Raises:
        RuntimeError: when pods terminated, restart/pull limits are exceeded,
            or readiness_timeout elapses.
    """
    try:
        resp = requests.get(
            f"{GET_JOB_SUT_INFO_URL}/{JOB_ID}",
            headers=headers,
            params={"with_detail": True},
            timeout=600,
        )
    except Exception as e:
        # Bug fix: the exception was passed as a stray positional arg to
        # logger.warning and never rendered; embed it in the message instead.
        logger.warning(f"Exception occured while getting SUT application {service_name} status: {e!r}")
        return False, running_at
    if resp.status_code != 200:
        logger.warning(f"Get SUT application {service_name} status failed with status_code {resp.status_code}")
        return False, running_at
    resp = resp.json()
    if not resp["success"]:
        logger.warning(f"Get SUT application {service_name} status failed with response {resp}")
        return False, running_at
    if len(resp["data"]["sut"]) == 0:
        logger.warning("Empty SUT application status")
        return False, running_at
    # Log a copy with the verbose "detail" field stripped, to keep it readable.
    resp_data_sut = copy.deepcopy(resp["data"]["sut"])
    for status in resp_data_sut:
        del status["detail"]
    logger.info(f"Got SUT application status: {resp_data_sut}")
    for status in resp["data"]["sut"]:
        if status["phase"] in ["Succeeded", "Failed"]:
            raise RuntimeError(f"Some pods of SUT application {service_name} terminated with status {status}")
        elif status["phase"] in ["Pending", "Unknown"]:
            return False, running_at
        elif status["phase"] != "Running":
            raise RuntimeError(f"Unexcepted pod status {status} of SUT application {service_name}")
        if running_at is None:
            running_at = time.time()
        for ct in status["detail"]["status"]["container_statuses"]:
            if ct["restart_count"] > 0:
                logger.info(f"pod {status['pod_name']} restart count = {ct['restart_count']}")
                if ct["restart_count"] > restart_count_limit:
                    raise RuntimeError(f"pod {status['pod_name']} restart too many times(over {restart_count_limit})")
            if (
                ct["state"]["waiting"] is not None
                and "reason" in ct["state"]["waiting"]
                and ct["state"]["waiting"]["reason"] in ["ImagePullBackOff", "ErrImagePull"]
            ):
                # Robust whether pull_num is a defaultdict or a plain dict.
                pull_num[status["pod_name"]] = pull_num.get(status["pod_name"], 0) + 1
                # Bug fix: the old message mixed f-string braces into a
                # %-format string, so the count printed as a literal
                # "{pull_num[...]}" placeholder.
                logger.info(
                    "pod %s has %d times inspect pulling image info: %s"
                    % (status["pod_name"], pull_num[status["pod_name"]], ct["state"]["waiting"])
                )
                if pull_num[status["pod_name"]] > pullimage_count_limit:
                    raise RuntimeError(f"pod {status['pod_name']} cannot pull image")
        if not status["conditions"]["Ready"]:
            if running_at is not None and time.time() - running_at > readiness_timeout:
                raise RuntimeError(f"SUT Application readiness has exceeded readiness_timeout:{readiness_timeout}s")
            return False, running_at
    return True, running_at
def parse_resource(resource):
    """Convert a k8s resource quantity string (e.g. "500m", "2Gi") to a float in base units.

    The sentinel -1 (meaning "no limit") is passed through unchanged.

    Raises:
        ValueError: when the string cannot be parsed or uses an unknown unit.
    """
    if resource == -1:
        return -1
    match = re.match(r"([\d\.]+)([mKMGTPENi]*)", resource)
    if match is None:
        # Bug fix: previously fell through to an opaque AttributeError on
        # ``match.groups()`` for unparseable input.
        raise ValueError(f"Cannot parse resource quantity: {resource!r}")
    value, unit = match.groups()
    value = float(value)
    # Decimal (K, M, ...) and binary (Ki, Mi, ...) suffixes; "m" is milli.
    unit_mapping = {
        "": 1,
        "m": 1e-3,
        "K": 1e3,
        "M": 1e6,
        "G": 1e9,
        "T": 1e12,
        "P": 1e15,
        "E": 1e18,
        "Ki": 2**10,
        "Mi": 2**20,
        "Gi": 2**30,
        "Ti": 2**40,
        "Pi": 2**50,
        "Ei": 2**60,
    }
    if unit not in unit_mapping:
        raise ValueError(f"Unknown resources unit: {unit}")
    return value * unit_mapping[unit]
def limit_resources(resource):
    """Clamp ``resource['limits']`` to at most 30 CPU cores and 100Gi memory.

    Mutates ``resource`` in place and returns it.
    """
    if "limits" not in resource:
        return resource
    if "cpu" in resource["limits"]:
        cpu_limit = parse_resource(resource["limits"]["cpu"])
        if cpu_limit > 30:
            logger.error("CPU limit exceeded. Adjusting to 30 cores.")
            resource["limits"]["cpu"] = "30"
    if "memory" in resource["limits"]:
        memory_limit = parse_resource(resource["limits"]["memory"])
        if memory_limit > 100 * 2**30:
            logger.error("Memory limit exceeded, adjusting to 100Gi")
            resource["limits"]["memory"] = "100Gi"
    # Bug fix: this path previously returned None while the early exit above
    # returned ``resource``; make the return value consistent for callers.
    return resource
def consistent_resources(resource):
    """Make ``requests`` mirror ``limits`` (or vice versa) so both keys are set.

    Limits win when both exist. A dict with neither key is returned untouched.
    Returns the mutated ``resource``.
    """
    has_limits = "limits" in resource
    has_requests = "requests" in resource
    if not has_limits and not has_requests:
        return resource
    if has_limits:
        resource["requests"] = resource["limits"]
    else:
        resource["limits"] = resource["requests"]
    return resource
def resource_check(values: Dict[str, Any]):
    """Validate and normalize GPU-related helm values.

    When a GPU is requested: pins vGPU memory/cores, mirrors limits into
    requests, forces the A100 vGPU node selector and adds the matching
    toleration. Returns the mutated ``values``.

    Raises:
        RuntimeError: when the GPU type is not A100-SXM4-80GBvgpu or the GPU
            count is not exactly 1.
    """
    resources = values.get("resources", {}).get("limits", {})
    if "nvidia.com/gpu" in resources and int(resources["nvidia.com/gpu"]) > 0:
        values["resources"]["limits"]["nvidia.com/gpumem"] = 8192
        values["resources"]["limits"]["nvidia.com/gpucores"] = 10
        values["resources"]["requests"] = values["resources"].get("requests", {})
        # Mirror cpu/memory limits into requests when not explicitly set.
        if "cpu" not in values["resources"]["requests"] and "cpu" in values["resources"]["limits"]:
            values["resources"]["requests"]["cpu"] = values["resources"]["limits"]["cpu"]
        if "memory" not in values["resources"]["requests"] and "memory" in values["resources"]["limits"]:
            values["resources"]["requests"]["memory"] = values["resources"]["limits"]["memory"]
        values["resources"]["requests"]["nvidia.com/gpu"] = values["resources"]["limits"]["nvidia.com/gpu"]
        values["resources"]["requests"]["nvidia.com/gpumem"] = 8192
        values["resources"]["requests"]["nvidia.com/gpucores"] = 10
        values["nodeSelector"] = values.get("nodeSelector", {})
        if "contest.4pd.io/accelerator" not in values["nodeSelector"]:
            values["nodeSelector"]["contest.4pd.io/accelerator"] = "A100-SXM4-80GBvgpu"
        gpu_type = values["nodeSelector"]["contest.4pd.io/accelerator"]
        # Bug fix: yaml may deliver the count as a string ("1"); compare as
        # int so a valid single-GPU request is not rejected.
        gpu_num = int(resources["nvidia.com/gpu"])
        if gpu_type != "A100-SXM4-80GBvgpu":
            raise RuntimeError("GPU类型只能为A100-SXM4-80GBvgpu")
        if gpu_num != 1:
            raise RuntimeError("GPU个数只能为1")
        values["tolerations"] = values.get("tolerations", [])
        values["tolerations"].append(
            {
                "key": "hosttype",
                "operator": "Equal",
                "value": "vgpu",
                "effect": "NoSchedule",
            }
        )
    return values

38
utils/leaderboard.py Normal file
View File

@@ -0,0 +1,38 @@
from utils.request import requests_retry_session
import os
import json
import traceback
from utils.logger import logger
# Base headers for leaderboard API calls; a bearer token is attached when the
# LEADERBOARD_API_TOKEN environment variable is configured.
lb_headers = {"Content-Type":"application/json"}
if os.getenv("LEADERBOARD_API_TOKEN"):
    lb_headers['Authorization'] = 'Bearer ' + os.getenv("LEADERBOARD_API_TOKEN")
def change_product_unavailable() -> None:
    """Report the current submit as product-unavailable to the leaderboard.

    Best effort: network failures are logged and swallowed.
    """
    logger.info("更改为产品不可用...")
    submit_id = str(os.getenv("SUBMIT_ID", -1))
    url = os.getenv("UPDATE_SUBMIT_URL", "http://contest.4pd.io:8080/submit/update")
    payload = json.dumps({submit_id: {"product_avaliable": 0}})
    try:
        requests_retry_session().post(url, data=payload, headers=lb_headers)
    except Exception as e:
        logger.error(traceback.format_exc())
        logger.error(f"change product avaliable error, {e}")
def mark_evaluating(task_id) -> None:
    """Report EVALUATING status for ``task_id`` to the leaderboard.

    Best effort: network failures are logged and swallowed.
    """
    logger.info("上报EVALUATING状态...")
    job_id = os.getenv('JOB_ID') or "-1"
    url = os.getenv("REGISTER_MARK_TASK_URL", "http://contest.4pd.io:8080/job/register_mark_task") + "/" + job_id
    payload = json.dumps({"task_id": task_id})
    try:
        requests_retry_session().post(url, data=payload, headers=lb_headers)
    except Exception as e:
        logger.error(traceback.format_exc())
        logger.error(f"mark evaluating error, {e}")

30
utils/logger.py Normal file
View File

@@ -0,0 +1,30 @@
# -*- coding: utf-8 -*-
import logging
import os
# Configure the root logging format once at import time; the LOGLEVEL env var
# overrides the default INFO level.
logging.basicConfig(
    format="%(asctime)s %(name)-12s %(levelname)-4s %(message)s",
    datefmt="%Y-%m-%d %H:%M:%S",
    level=os.environ.get("LOGLEVEL", "INFO"),
)
# Module-level logger shared across the project.
logger = logging.getLogger(__file__)
# A second, more detailed logger that prints file/line information and does
# not propagate to the root handler (avoids duplicate log lines).
log = logging.getLogger("detailed_logger")
log.propagate = False
level = logging.INFO
log.setLevel(level)
formatter = logging.Formatter(
    "[%(asctime)s] %(levelname)s : %(pathname)s:%(lineno)d - %(message)s",
    "%Y-%m-%d %H:%M:%S",
)
streamHandler = logging.StreamHandler()
streamHandler.setLevel(level)
streamHandler.setFormatter(formatter)
log.addHandler(streamHandler)

320
utils/metrics.py Normal file
View File

@@ -0,0 +1,320 @@
# coding: utf-8
import os
from collections import Counter
from copy import deepcopy
from typing import List, Tuple
import Levenshtein
import numpy as np
from schemas.context import ASRContext
from utils.logger import logger
from utils.tokenizer import Tokenizer, TokenizerType
from utils.update_submit import change_product_available
IN_TEST = os.getenv("SUBMIT_CONFIG_FILEPATH", None) is None
def text_align(context: ASRContext) -> Tuple:
    """Check sentence start/end time alignment between labels and predictions.

    Returns (start_time_align_count, end_time_align_count, start_end_count),
    where start_end_count counts ordering violations of the predicted
    start/end sequence.
    """
    start_end_count = 0
    label_start_time_list = []
    label_end_time_list = []
    for label_item in context.labels:
        label_start_time_list.append(label_item.start)
        label_end_time_list.append(label_item.end)
    # A new sentence starts right after a final result; collect each
    # sentence's first start_time and its final end_time.
    pred_start_time_list = []
    pred_end_time_list = []
    sentence_start = True
    for pred_item in context.preds:
        if sentence_start:
            pred_start_time_list.append(pred_item.recognition_results.start_time)
        if pred_item.recognition_results.final_result:
            pred_end_time_list.append(pred_item.recognition_results.end_time)
        sentence_start = pred_item.recognition_results.final_result
    # check start0 < end0 < start1 < end1 < start2 < end2 - ...
    if IN_TEST:
        print(pred_start_time_list)
        print(pred_end_time_list)
    # Interleave starts and ends into one monotone sequence to validate.
    pred_time_list = []
    i, j = 0, 0
    while i < len(pred_start_time_list) and j < len(pred_end_time_list):
        pred_time_list.append(pred_start_time_list[i])
        pred_time_list.append(pred_end_time_list[j])
        i += 1
        j += 1
    if i < len(pred_start_time_list):
        pred_time_list.append(pred_start_time_list[-1])
    for i in range(1, len(pred_time_list)):
        # Allow a 600ms grace window before flagging an ordering violation.
        if pred_time_list[i] < pred_time_list[i - 1] - 0.6:
            logger.error("识别的 start、end 不符合 start0 < end0 < start1 < end1 < start2 < end2 ...")
            logger.error(
                f"当前识别的每个句子开始和结束时间分别为: \
                开始时间:{pred_start_time_list}, \
                结束时间:{pred_end_time_list}"
            )
            start_end_count += 1
            # change_product_available()
    # Count label boundaries matched by some prediction within +/-300ms.
    start_time_align_count = 0
    end_time_align_count = 0
    for label_start_time in label_start_time_list:
        for pred_start_time in pred_start_time_list:
            if pred_start_time <= label_start_time + 0.3 and pred_start_time >= label_start_time - 0.3:
                start_time_align_count += 1
                break
    for label_end_time in label_end_time_list:
        for pred_end_time in pred_end_time_list:
            if pred_end_time <= label_end_time + 0.3 and pred_end_time >= label_end_time - 0.3:
                end_time_align_count += 1
                break
    logger.info(
        f"start-time 对齐个数 {start_time_align_count}, \
        end-time 对齐个数 {end_time_align_count}\
        数据集中句子总数 {len(label_start_time_list)}"
    )
    return start_time_align_count, end_time_align_count, start_end_count
def first_delay(context: ASRContext) -> Tuple:
    """Return (sum, count) of first-token delays, one entry per recognized sentence.

    Delay = time the first partial result for a sentence was received, minus
    the first audio send time, minus the sentence's audio start offset.
    """
    first_send_time = context.preds[0].send_time
    first_delay_list = []
    sentence_start = True
    for pred_context in context.preds:
        if sentence_start:
            sentence_begin_time = pred_context.recognition_results.start_time
            first_delay_time = pred_context.recv_time - first_send_time - sentence_begin_time
            first_delay_list.append(first_delay_time)
        # A new sentence begins after each final result.
        sentence_start = pred_context.recognition_results.final_result
    if IN_TEST:
        print(f"当前音频的首字延迟为{first_delay_list}")
    logger.info(f"当前音频的首字延迟均值为 {np.mean(first_delay_list)}s")
    return np.sum(first_delay_list), len(first_delay_list)
def revision_delay(context: ASRContext):
    """Return (sum, count) of revision delays, one entry per final result.

    Delay = time the final result was received, minus the first audio send
    time, minus the sentence's audio end offset.
    """
    first_send_time = context.preds[0].send_time
    revision_delay_list = []
    for pred_context in context.preds:
        if pred_context.recognition_results.final_result:
            sentence_end_time = pred_context.recognition_results.end_time
            revision_delay_time = pred_context.recv_time - first_send_time - sentence_end_time
            revision_delay_list.append(revision_delay_time)
    if IN_TEST:
        print(revision_delay_list)
    logger.info(f"当前音频的修正延迟均值为 {np.mean(revision_delay_list)}s")
    return np.sum(revision_delay_list), len(revision_delay_list)
def patch_unique_token_count(context: ASRContext):
    """Check the 3-second token-immutability rule and count patch sizes.

    Tokens whose emitting partial result is older than 3s must not change in
    later partials of the same sentence; violations are logged. Returns a
    Counter mapping distinct per-patch token counts to how often they occur.
    """
    # print(context.__dict__)
    # Tokenize every returned partial result.
    pred_text_list = [pred_context.recognition_results.text for pred_context in context.preds]
    pred_text_tokenized_list = Tokenizer.norm_and_tokenize(pred_text_list, lang=context.lang)
    # print(pred_text_list)
    # print(pred_text_tokenized_list)
    # Detect modification of tokens older than 3 seconds.
    ## Receive time of the current sentence's first partial.
    first_recv_time = None
    ## Number of tokens that may no longer be modified.
    unmodified_token_cnt = 0
    ## Index of the oldest partial still within the 3s window.
    time_token_idx = 0
    ## Whether the current partial starts a new sentence.
    final_sentence = True
    ## Whether an immutable token was modified.
    is_unmodified_token = False
    for idx, (now_tokens, pred_context) in enumerate(zip(pred_text_tokenized_list, context.preds)):
        ## First partial of a sentence: reset the window.
        if final_sentence:
            first_recv_time = pred_context.recv_time
            unmodified_token_cnt = 0
            time_token_idx = idx
            final_sentence = pred_context.recognition_results.final_result
            continue
        final_sentence = pred_context.recognition_results.final_result
        ## Receive time of the current partial.
        pred_recv_time = pred_context.recv_time
        ## The first 3 seconds of a sentence are exempt.
        if pred_recv_time - first_recv_time < 3:
            continue
        ## From history, derive the longest immutable prefix length.
        while time_token_idx < idx:
            context_pred_tmp = context.preds[time_token_idx]
            context_pred_tmp_recv_time = context_pred_tmp.recv_time
            tmp_tokens = pred_text_tokenized_list[time_token_idx]
            if pred_recv_time - context_pred_tmp_recv_time >= 3:
                unmodified_token_cnt = max(unmodified_token_cnt, len(tmp_tokens))
                time_token_idx += 1
            else:
                break
        ## Compare against the previous partial over the immutable prefix.
        last_tokens = pred_text_tokenized_list[idx - 1]
        if context.lang in ['ar', 'he']:
            tokens_check_pre, tokens_check_now = last_tokens[::-1], now_tokens[::-1]
            # NOTE(review): this ``continue`` skips the comparison entirely
            # for RTL languages, making the reversed assignment above dead
            # code — confirm whether ar/he are meant to be exempt or this is
            # a leftover bug.
            continue
        else:
            tokens_check_pre, tokens_check_now = last_tokens, now_tokens
        for token_a, token_b in zip(tokens_check_pre[:unmodified_token_cnt], tokens_check_now[:unmodified_token_cnt]):
            if token_a != token_b:
                is_unmodified_token = True
                break
        if is_unmodified_token and int(os.getenv('test', 0)):
            logger.error(
                f"{idx}-{unmodified_token_cnt}-{last_tokens[:unmodified_token_cnt]}-{now_tokens[:unmodified_token_cnt]}"
            )
        if is_unmodified_token:
            break
    if is_unmodified_token:
        logger.error("修改了不可修改的文字范围")
        # change_product_available()
    # In test mode, dump per-sentence (tokens, relative recv time) traces.
    if int(os.getenv('test', 0)):
        final_result = True
        result_list = []
        for tokens, pred in zip(pred_text_tokenized_list, context.preds):
            if final_result:
                result_list.append([])
            result_list[-1].append((tokens, pred.recv_time - context.preds[0].recv_time))
            final_result = pred.recognition_results.final_result
        for item in result_list:
            logger.info(str(item))
    # Count the distinct token counts seen within each sentence's patches.
    patch_unique_cnt_counter = Counter()
    patch_unique_cnt_in_one_sentence = set()
    for pred_text_tokenized, pred_context in zip(pred_text_tokenized_list, context.preds):
        token_cnt = len(pred_text_tokenized)
        patch_unique_cnt_in_one_sentence.add(token_cnt)
        if pred_context.recognition_results.final_result:
            for unique_cnt in patch_unique_cnt_in_one_sentence:
                patch_unique_cnt_counter[unique_cnt] += 1
            patch_unique_cnt_in_one_sentence.clear()
    # Flush the trailing, unfinished sentence (no final result yet).
    if context.preds and not context.preds[-1].recognition_results.final_result:
        for unique_cnt in patch_unique_cnt_in_one_sentence:
            patch_unique_cnt_counter[unique_cnt] += 1
    # print(patch_unique_cnt_counter)
    logger.info(
        f"当前音频的 patch token 均值为 {mean_on_counter(patch_unique_cnt_counter)}, \
        当前音频的 patch token 方差为 {var_on_counter(patch_unique_cnt_counter)}"
    )
    return patch_unique_cnt_counter
def mean_on_counter(counter: Counter) -> float:
    """Weighted mean of a Counter: sum(key * count) / sum(count).

    Fix: an empty counter previously raised ZeroDivisionError; it now
    returns 0.0, the natural value for "no observations".
    """
    total_count = sum(counter.values())
    if total_count == 0:
        return 0.0
    total_sum = sum(key * count for key, count in counter.items())
    return total_sum / total_count
def var_on_counter(counter: Counter) -> float:
    """Population variance of a Counter's keys weighted by their counts.

    Fix: an empty counter previously raised ZeroDivisionError; it now
    returns 0.0 (no observations -> no spread), mirroring mean_on_counter.
    """
    total_count = sum(counter.values())
    if total_count == 0:
        return 0.0
    mean = sum(key * count for key, count in counter.items()) / total_count
    return sum((key - mean) ** 2 * count for key, count in counter.items()) / total_count
def edit_distance(arr1: List, arr2: List):
    """Edit statistics between two sequences via Levenshtein editops.

    Returns a (substitutions, deletions, insertions, correct) tuple, where
    correct = len(arr1) - substitutions - deletions.
    """
    op_kinds = [op[0] for op in Levenshtein.editops(arr1, arr2)]
    s = op_kinds.count("replace")
    d = op_kinds.count("delete")
    i = op_kinds.count("insert")
    return s, d, i, len(arr1) - s - d
def cer(tokens_gt_mapping: List[str], tokens_dt_mapping: List[str]):
    """Given two None-padded (edit-distance aligned) token sequences, return (1 - cer, token_count)."""
    # None in the GT side marks an insertion; None in the DT side marks a deletion.
    insert = sum(1 for token in tokens_gt_mapping if token is None)
    delete = sum(1 for token in tokens_dt_mapping if token is None)
    equal = sum(1 for gt, dt in zip(tokens_gt_mapping, tokens_dt_mapping) if gt == dt)
    replace = len(tokens_gt_mapping) - insert - equal
    token_count = replace + equal + delete
    cer_value = (replace + delete + insert) / token_count
    logger.info(f"当前音频的 cer/wer 值为 {cer_value}, token 个数为 {token_count}")
    return 1 - cer_value, token_count
def cut_rate(
    tokens_gt: List[List[str]],
    tokens_dt: List[List[str]],
    tokens_gt_mapping: List[str],
    tokens_dt_mapping: List[str],
):
    """Score sentence-boundary agreement between GT and DT segmentations.

    Missed boundaries cost 1, spurious boundaries cost 2; the rate is
    clamped to [0, ...]. Returns (rate, gt_sentence_count, miss, more).
    """
    gt_ends = set(sentence_final_token_index(tokens_gt, tokens_gt_mapping))
    dt_ends = set(sentence_final_token_index(tokens_dt, tokens_dt_mapping))
    sentence_count_gt = len(gt_ends)
    miss_count = len(gt_ends - dt_ends)
    more_count = len(dt_ends - gt_ends)
    rate = max(1 - (miss_count + more_count * 2) / sentence_count_gt, 0)
    return rate, sentence_count_gt, miss_count, more_count
def token_mapping(tokens_gt: List[str], tokens_dt: List[str]) -> Tuple[List[str], List[str]]:
    """Align two token lists by padding None at insert/delete positions.

    Walks the Levenshtein editops in reverse so earlier indices stay valid
    while inserting placeholders.
    """
    aligned_gt = deepcopy(tokens_gt)
    aligned_dt = deepcopy(tokens_dt)
    for op_kind, pos_gt, pos_dt in reversed(Levenshtein.editops(aligned_gt, aligned_dt)):
        if op_kind == "insert":
            aligned_gt.insert(pos_gt, None)
        elif op_kind == "delete":
            aligned_dt.insert(pos_dt, None)
    return aligned_gt, aligned_dt
def sentence_final_token_index(tokens: List[List[str]], tokens_mapping: List[str]) -> List[int]:
    """For each sentence in `tokens`, return the index (into the None-padded
    `tokens_mapping`) of that sentence's final token, skipping None pads."""
    final_indexes = []
    cursor = 0
    for sentence_tokens in tokens:
        for _ in sentence_tokens:
            # Advance past insertion placeholders before consuming a real token.
            while cursor < len(tokens_mapping) and tokens_mapping[cursor] is None:
                cursor += 1
            cursor += 1
        final_indexes.append(cursor - 1)
    return final_indexes
def cut_sentence(sentences: List[str], tokenizerType: TokenizerType) -> List[str]:
    """Strip a fixed punctuation set from each sentence by repeated splitting,
    then re-join the remaining fragments (space-joined for whitespace
    tokenizers, concatenated otherwise). Returns one string per input sentence."""
    sentence_cut_list = []
    for sentence in sentences:
        sentence_list = [sentence]
        sentence_tmp_list = []
        # NOTE(review): several entries below render as empty strings "" — they
        # look like CJK punctuation (。,?!;: …) lost in an encoding pass.
        # str.split("") raises ValueError, so confirm these literals against
        # the original file before trusting this code path.
        for punc in [
            "······",
            "......",
            "",
            "",
            "",
            "",
            "",
            "",
            "...",
            ".",
            ",",
            "?",
            "!",
            ";",
            ":",
        ]:
            # Split every current fragment on this punctuation mark, then swap
            # buffers so the next mark operates on the refined fragments.
            for sentence in sentence_list:
                sentence_tmp_list.extend(sentence.split(punc))
            sentence_list, sentence_tmp_list = sentence_tmp_list, []
        # Drop empty fragments produced by adjacent punctuation.
        sentence_list = [item for item in sentence_list if item]
        if tokenizerType == TokenizerType.whitespace:
            sentence_cut_list.append(" ".join(sentence_list))
        else:
            sentence_cut_list.append("".join(sentence_list))
    return sentence_cut_list

50
utils/metrics_plus.py Normal file
View File

@@ -0,0 +1,50 @@
from typing import List
from utils.tokenizer import TokenizerType
def replace_general_punc(
    sentences: List[str], tokenizer: TokenizerType
) -> List[str]:
    """Replacement for utils.metrics.cut_sentence: maps every character of the
    punctuation set to a replacer via str.translate, then strips and lowercases.
    For whitespace tokenizers punctuation becomes a space; otherwise it is removed."""
    # NOTE(review): several entries render as empty strings "" — likely CJK
    # punctuation lost in an encoding pass (harmless to maketrans, but verify
    # against the original file that the intended characters are covered).
    general_puncs = [
        "······",
        "......",
        "",
        "",
        "",
        "",
        "",
        "",
        "...",
        ".",
        ",",
        "?",
        "!",
        ";",
        ":",
    ]
    if tokenizer == TokenizerType.whitespace:
        replacer = " "
    else:
        replacer = ""
    # Build one translation table mapping every individual punctuation char.
    trans = str.maketrans(dict.fromkeys("".join(general_puncs), replacer))
    ret_sentences = [""] * len(sentences)
    for i, sentence in enumerate(sentences):
        sentence = sentence.translate(trans)
        sentence = sentence.strip()
        sentence = sentence.lower()
        ret_sentences[i] = sentence
    return ret_sentences
def distance_point_line(
    point: float, line_start: float, line_end: float
) -> float:
    """Distance from a point to the closed interval [line_start, line_end];
    0 when the point lies inside it."""
    if point < line_start:
        return abs(point - line_start)
    if line_start <= point <= line_end:
        return 0
    return abs(point - line_end)

93
utils/pynini/Dockerfile Normal file
View File

@@ -0,0 +1,93 @@
# Dockerfile
# Pierre-André Noël, May 12th 2020
# Copyright © Element AI Inc. All rights reserved.
# Apache License, Version 2.0
#
# This builds `manylinux_2_28_x86_64` Python wheels for `pynini`, wrapping
# all its dependencies.
#
# This Dockerfile uses multi-stage builds; for more information, see:
# https://docs.docker.com/develop/develop-images/multistage-build/
#
# The recommended installation method for Pynini is through Conda-Forge. This gives Linux
# x86-64 users another option: installing a precompiled module from PyPI.
#
#
# To build wheels and run Pynini's tests, run:
#
# docker build --target=run-tests -t build-pynini-wheels .
#
# To extract the resulting wheels from the Docker image, run:
#
# docker run --rm -v `pwd`:/io build-pynini-wheels cp -r /wheelhouse /io
#
# Notice that this also generates Cython wheels.
#
# Then, `twine` (https://twine.readthedocs.io/en/latest/) can be used to
# publish the resulting Pynini wheels.
# ******************************************************
# *** All the following images are based on this one ***
# ******************************************************
#from quay.io/pypa/manylinux_2_28_x86_64 AS common
# ***********************************************************************
# *** Image providing all the requirements for building Pynini wheels ***
# ***********************************************************************
FROM harbor.4pd.io/inf/base-python3.8-ubuntu:1.1.0
# The versions we want in the wheels.
ENV FST_VERSION="1.8.3"
ENV PYNINI_VERSION="2.1.6"
# Location of OpenFst and Pynini.
ENV FST_DOWNLOAD_PREFIX="https://www.openfst.org/twiki/pub/FST/FstDownload"
ENV PYNINI_DOWNLOAD_PREFIX="https://www.opengrm.org/twiki/pub/GRM/PyniniDownload"
# Note that our certificates are not known to the version of wget available in this image.
# Build toolchain. Fixes: use apt-get (stable CLI) instead of apt, skip
# recommended packages, and remove the apt lists in the same layer so the
# cache never persists in the image.
RUN apt-get update && apt-get install -y --no-install-recommends \
        g++-9 \
        gcc-9 \
        make \
        wget \
    && ln -s $(which gcc-9) /usr/bin/gcc \
    && ln -s $(which g++-9) /usr/bin/g++ \
    && rm -rf /var/lib/apt/lists/*
# Gets and unpacks OpenFst source.
RUN cd /tmp \
    && wget -q --no-check-certificate "${FST_DOWNLOAD_PREFIX}/openfst-${FST_VERSION}.tar.gz" \
    && tar -xzf "openfst-${FST_VERSION}.tar.gz" \
    && rm "openfst-${FST_VERSION}.tar.gz"
# Compiles OpenFst (with the grm extensions pynini links against).
RUN cd "/tmp/openfst-${FST_VERSION}" \
    && ./configure --enable-grm \
    && make --jobs 4 install \
    && rm -rd "/tmp/openfst-${FST_VERSION}"
# Gets and unpacks Pynini source.
RUN mkdir -p /src && cd /src \
    && wget -q --no-check-certificate "${PYNINI_DOWNLOAD_PREFIX}/pynini-${PYNINI_VERSION}.tar.gz" \
    && tar -xzf "pynini-${PYNINI_VERSION}.tar.gz" \
    && rm "pynini-${PYNINI_VERSION}.tar.gz"
# Installs pynini's build requirements. Fixes: drop the useless "|| exit"
# (RUN already fails the build on a non-zero status) and avoid persisting
# the pip cache in the layer.
RUN pip install --no-cache-dir -i https://nexus.4pd.io/repository/pypi-all/simple -r "/src/pynini-${PYNINI_VERSION}/requirements.txt"
# Compiles the wheels to a temporary directory.
RUN pip wheel -i https://nexus.4pd.io/repository/pypi-all/simple -v "/src/pynini-${PYNINI_VERSION}" -w /tmp/wheelhouse/
# patchelf is required by auditwheel to rewrite rpaths in the repaired wheels.
RUN wget ftp://ftp.4pd.io/pub/pico/temp/patchelf-0.18.0-x86_64.tar.gz \
    && tar xzf patchelf-0.18.0-x86_64.tar.gz \
    && rm -f patchelf-0.18.0-x86_64.tar.gz
RUN pip install --no-cache-dir -i https://nexus.4pd.io/repository/pypi-all/simple auditwheel
# Bundles external shared libraries into the wheels.
# See https://github.com/pypa/manylinux/tree/manylinux2014
RUN for WHL in /tmp/wheelhouse/pynini*.whl; do \
        PATH=$(pwd)/bin:$PATH auditwheel repair --plat manylinux_2_31_x86_64 "${WHL}" -w /wheelhouse/ || exit; \
    done
# Removes the non-repaired wheels.
RUN rm -rd /tmp/wheelhouse

17
utils/pynini/README.md Normal file
View File

@@ -0,0 +1,17 @@
# pynini
## 背景
SpeechIO对英文ASR的评估工具依赖第三方库pynini(https://github.com/kylebgorman/pynini)。该库强绑定OS和gcc版本,需要在运行环境中编译生成wheel包。本文说明编译pynini生成wheel包的方法。
## 编译
```shell
docker build -t build-pynini-wheels .
```
## 获取wheel包
```shell
docker run --rm -v `pwd`:/io build-pynini-wheels cp -r /wheelhouse /io
```

40
utils/request.py Normal file
View File

@@ -0,0 +1,40 @@
import requests
from requests.adapters import HTTPAdapter
from requests.packages.urllib3.util.retry import Retry
DEFAULT_TIMEOUT = 2 * 60 # seconds
class TimeoutHTTPAdapter(HTTPAdapter):
    """HTTPAdapter that supplies a default timeout to every request lacking one."""

    def __init__(self, *args, **kwargs):
        # pop() replaces the in/del pair: take the caller's timeout if given,
        # otherwise fall back to the module-wide DEFAULT_TIMEOUT.
        self.timeout = kwargs.pop("timeout", DEFAULT_TIMEOUT)
        super().__init__(*args, **kwargs)

    def send(self, request, **kwargs):
        # Only fill in the timeout when the caller left it unset (None).
        if kwargs.get("timeout") is None:
            kwargs["timeout"] = self.timeout
        return super().send(request, **kwargs)
def requests_retry_session(
    retries=3,
    backoff_factor=1,
    status_forcelist=(500, 502, 504, 404, 403),
    session=None,
):
    """Build (or augment) a requests.Session with retry + default-timeout adapters.

    Parameters:
        retries: total/read/connect retry budget passed to urllib3 Retry.
        backoff_factor: exponential backoff multiplier between retries.
        status_forcelist: HTTP statuses that trigger a retry. Fix: the default
            was a mutable list (shared mutable default-argument pitfall); it is
            now a tuple, which Retry accepts equally.
        session: optional existing session to mount the adapters onto.

    Returns:
        The session with TimeoutHTTPAdapter mounted for http:// and https://.
    """
    session = session or requests.Session()
    retry = Retry(
        total=retries,
        read=retries,
        connect=retries,
        backoff_factor=backoff_factor,
        status_forcelist=status_forcelist,
    )
    adapter = TimeoutHTTPAdapter(max_retries=retry)
    session.mount('http://', adapter)
    session.mount('https://', adapter)
    return session

65
utils/service.py Normal file
View File

@@ -0,0 +1,65 @@
# -*- coding: utf-8 -*-
import os
import sys
from utils.helm import deploy_chart, gen_chart_tarball
from utils.logger import logger
def register_sut(st_config, resource_name, **kwargs):
    """Deploy the SUT helm chart and return its websocket endpoint.

    Parameters:
        st_config: leaderboard config dict; only its "values" entry is used here.
        resource_name: name handed to deploy_chart for the k8s resources.
        **kwargs: accepted for interface compatibility; currently unused.

    Returns:
        "ws://<service>:<port>". The port comes from
        st_config["values"]["service"]["port"] when present, otherwise from
        the chart's own default values.

    Fixes: removed the duplicated docker_image assignment and the unused
    job_id local (both leftovers of the disabled config-driven image path).
    """
    # NOTE(review): image is intentionally hard-coded for now; config-driven
    # selection via st_config["docker_image"] is disabled upstream.
    docker_image = "10.255.143.18:5000/speaker_identification:wo_model_v0"
    st_config_values = st_config.get("values", {})
    chart_tar_fp, chart_values = gen_chart_tarball(docker_image)
    sut_service_name, _ = deploy_chart(
        resource_name,
        int(os.getenv("readiness_timeout", 60 * 3)),
        chart_fileobj=chart_tar_fp,
        extra_values=st_config_values,
        restart_count_limit=int(os.getenv('restart_count', 3)),
    )
    chart_tar_fp.close()
    if st_config_values is not None and "service" in st_config_values and "port" in st_config_values["service"]:
        sut_service_port = str(st_config_values["service"]["port"])
    else:
        sut_service_port = str(chart_values["service"]["port"])
    return "ws://{}:{}".format(sut_service_name, sut_service_port)
"""
elif "chart_repo" in st_config:
logger.info(f"正在使用 helm-chart 配置,内容为 {st_config}")
chart_repo = st_config.get("chart_repo", None)
chart_name = st_config.get("chart_name", None)
chart_version = st_config.get("chart_version", None)
if chart_repo is None or chart_name is None or chart_version is None:
logger.error("chart_repo, chart_name, chart_version cant be none")
logger.info(f"{chart_repo} {chart_name} {chart_version}")
chart_str = os.path.join(chart_repo, chart_name) + ':' + chart_version
st_cfg_values = st_config.get('values', {})
st_config["values"] = st_cfg_values
sut_service_name, _ = deploy_chart(
resource_name,
600,
chart_str=chart_str,
extra_values=st_cfg_values,
)
sut_service_name = f"asr-{job_id}"
if st_cfg_values is not None and 'service' in st_cfg_values and 'port' in st_cfg_values['service']:
sut_service_port = str(st_cfg_values['service']['port'])
else:
sut_service_port = '80'
return 'ws://%s:%s' % (sut_service_name, sut_service_port)
else:
logger.error("配置信息错误,缺少 docker_image 属性")
#sys.exit(-1)
"""

View File

@@ -0,0 +1,3 @@
'''
reference: https://github.com/SpeechColab/Leaderboard/tree/f287a992dc359d1c021bfc6ce810e5e36608e057/utils
'''

View File

@@ -0,0 +1,551 @@
#!/usr/bin/env python3
# coding=utf8
# Copyright 2022 Zhenxiang MA, Jiayu DU (SpeechColab)
import argparse
import csv
import json
import logging
import os
import sys
from typing import Iterable
logging.basicConfig(stream=sys.stderr, level=logging.ERROR, format='[%(levelname)s] %(message)s')
import pynini
from pynini.lib import pynutil
# reference: https://github.com/kylebgorman/pynini/blob/master/pynini/lib/edit_transducer.py
# to import original lib:
# from pynini.lib.edit_transducer import EditTransducer
class EditTransducer:
    """Factored edit-distance transducer over a token vocabulary (pynini FSTs).

    Built as two composable factors (left `_e_i`, right `_e_o`) so the lattice
    `iexpr @ _e_i @ _e_o @ oexpr` scores an alignment; based on pynini's
    edit_transducer.py, extended with 0-cost raw-form <-> 'TOKEN#' matching
    for GLM-expanded hypotheses.
    """

    # Auxiliary edit-operation symbols inserted into the factored machines.
    DELETE = "<delete>"
    INSERT = "<insert>"
    SUBSTITUTE = "<substitute>"
    def __init__(
        self,
        symbol_table,
        vocab: Iterable[str],
        insert_cost: float = 1.0,
        delete_cost: float = 1.0,
        substitute_cost: float = 1.0,
        bound: int = 0,
    ):
        # Left factor; note that we divide the edit costs by two because they also
        # will be incurred when traversing the right factor.
        sigma = pynini.union(
            *[pynini.accep(token, token_type=symbol_table) for token in vocab],
        ).optimize()
        insert = pynutil.insert(f"[{self.INSERT}]", weight=insert_cost / 2)
        delete = pynini.cross(sigma, pynini.accep(f"[{self.DELETE}]", weight=delete_cost / 2))
        substitute = pynini.cross(sigma, pynini.accep(f"[{self.SUBSTITUTE}]", weight=substitute_cost / 2))
        edit = pynini.union(insert, delete, substitute).optimize()
        if bound:
            # Bounded variant: at most `bound` edits, interleaved with matches.
            sigma_star = pynini.closure(sigma)
            self._e_i = sigma_star.copy()
            for _ in range(bound):
                self._e_i.concat(edit.ques).concat(sigma_star)
        else:
            # Unbounded variant: any number of edits or matches.
            self._e_i = edit.union(sigma).closure()
        self._e_i.optimize()
        right_factor_std = EditTransducer._right_factor(self._e_i)
        # right_factor_ext allows 0-cost matching between token's raw form & auxiliary form
        # e.g.: 'I' -> 'I#', 'AM' -> 'AM#'
        right_factor_ext = (
            pynini.union(
                *[
                    pynini.cross(
                        pynini.accep(x, token_type=symbol_table),
                        pynini.accep(x + '#', token_type=symbol_table),
                    )
                    for x in vocab
                ]
            )
            .optimize()
            .closure()
        )
        self._e_o = pynini.union(right_factor_std, right_factor_ext).closure().optimize()
    @staticmethod
    def _right_factor(ifst: pynini.Fst) -> pynini.Fst:
        """Derive the right factor: invert the left factor and swap INSERT/DELETE labels."""
        ofst = pynini.invert(ifst)
        syms = pynini.generated_symbols()
        insert_label = syms.find(EditTransducer.INSERT)
        delete_label = syms.find(EditTransducer.DELETE)
        pairs = [(insert_label, delete_label), (delete_label, insert_label)]
        right_factor = ofst.relabel_pairs(ipairs=pairs)
        return right_factor
    def create_lattice(self, iexpr: pynini.FstLike, oexpr: pynini.FstLike) -> pynini.Fst:
        """Compose input/output expressions with both factors into an alignment lattice."""
        lattice = (iexpr @ self._e_i) @ (self._e_o @ oexpr)
        EditTransducer.check_wellformed_lattice(lattice)
        return lattice
    @staticmethod
    def check_wellformed_lattice(lattice: pynini.Fst) -> None:
        # An empty lattice means the two expressions share no path at all.
        if lattice.start() == pynini.NO_STATE_ID:
            raise RuntimeError("Edit distance composition lattice is empty.")
    def compute_distance(self, iexpr: pynini.FstLike, oexpr: pynini.FstLike) -> float:
        """Edit distance = cost of the shortest path through the lattice."""
        lattice = self.create_lattice(iexpr, oexpr)
        # The shortest cost from all final states to the start state is
        # equivalent to the cost of the shortest path.
        start = lattice.start()
        return float(pynini.shortestdistance(lattice, reverse=True)[start])
    def compute_alignment(self, iexpr: pynini.FstLike, oexpr: pynini.FstLike) -> pynini.FstLike:
        """Return the single best alignment path as an optimized FST."""
        print(iexpr)
        print(oexpr)
        lattice = self.create_lattice(iexpr, oexpr)
        alignment = pynini.shortestpath(lattice, nshortest=1, unique=True)
        return alignment.optimize()
class ErrorStats:
    """Corpus-level accumulator for token/sentence error statistics, with
    JSON / Kaldi / human-readable report formats."""
    def __init__(self):
        self.num_ref_utts = 0
        self.num_hyp_utts = 0
        self.num_eval_utts = 0  # in both ref & hyp
        self.num_hyp_without_ref = 0
        # Token-level edit counts: Correct / Substitution / Insertion / Deletion.
        self.C = 0
        self.S = 0
        self.I = 0
        self.D = 0
        self.token_error_rate = 0.0
        # Normalized by max(ref_len, hyp_len) instead of ref_len alone.
        self.modified_token_error_rate = 0.0
        self.num_utts_with_error = 0
        self.sentence_error_rate = 0.0
    def to_json(self):
        """Compact single-line JSON dump of all counters and rates."""
        # return json.dumps(self.__dict__, indent=4)
        return json.dumps(self.__dict__)
    def to_kaldi(self):
        """Kaldi-style %WER / %SER summary lines."""
        info = (
            F'%WER {self.token_error_rate:.2f} [ {self.S + self.D + self.I} / {self.C + self.S + self.D}, {self.I} ins, {self.D} del, {self.S} sub ]\n'
            F'%SER {self.sentence_error_rate:.2f} [ {self.num_utts_with_error} / {self.num_eval_utts} ]\n'
        )
        return info
    def to_summary(self):
        """Human-readable multi-line report of the overall statistics."""
        summary = (
            '==================== Overall Statistics ====================\n'
            F'num_ref_utts: {self.num_ref_utts}\n'
            F'num_hyp_utts: {self.num_hyp_utts}\n'
            F'num_hyp_without_ref: {self.num_hyp_without_ref}\n'
            F'num_eval_utts: {self.num_eval_utts}\n'
            F'sentence_error_rate: {self.sentence_error_rate:.2f}%\n'
            F'token_error_rate: {self.token_error_rate:.2f}%\n'
            F'modified_token_error_rate: {self.modified_token_error_rate:.2f}%\n'
            F'token_stats:\n'
            F' - tokens:{self.C + self.S + self.D:>7}\n'
            F' - edits: {self.S + self.I + self.D:>7}\n'
            F' - cor: {self.C:>7}\n'
            F' - sub: {self.S:>7}\n'
            F' - ins: {self.I:>7}\n'
            F' - del: {self.D:>7}\n'
            '============================================================\n'
        )
        return summary
class Utterance:
    """Lightweight record pairing an utterance id with its transcript text."""

    def __init__(self, uid, text):
        self.uid = uid
        self.text = text
def LoadKaldiArc(filepath):
    """Load a Kaldi 'text'-style archive: one utterance per line, "uid [text]".

    Blank lines are skipped; a line with only a uid yields empty text.
    Returns a dict uid -> Utterance. Raises RuntimeError on a duplicate uid.

    Fix: replaced the `utts.get(uid) != None` comparison (PEP 8 anti-idiom,
    and an extra lookup) with a direct membership test — values are always
    Utterance objects, so the behavior is identical.
    """
    utts = {}
    with open(filepath, 'r', encoding='utf8') as f:
        for line in f:
            line = line.strip()
            if not line:
                continue
            cols = line.split(maxsplit=1)
            assert len(cols) == 2 or len(cols) == 1
            uid = cols[0]
            text = cols[1] if len(cols) == 2 else ''
            if uid in utts:
                raise RuntimeError(F'Found duplicated utterence id {uid}')
            utts[uid] = Utterance(uid, text)
    return utts
def BreakHyphen(token: str):
    """Expand a hyphenated token into its parts plus the hyphen-free join,
    e.g. 'T-SHIRT' -> ['T', 'SHIRT', 'TSHIRT']."""
    assert '-' in token
    parts = token.split('-')
    return parts + [token.replace('-', '')]
def LoadGLM(rel_path):
    '''Load a GLM (global language mapping) CSV into {rule_tag: [phrases]}.

    glm.csv:
        I'VE,I HAVE
        T-SHIRT,T SHIRT,TSHIRT
    becomes:
        {'<RULE_000000>': ["I'VE", 'I HAVE'],
         '<RULE_000001>': ['T-SHIRT', 'T SHIRT', 'TSHIRT']}

    The path is resolved relative to this module's directory.
    Fix: the file handle was opened inline and never closed; it is now
    managed with a `with` block.
    '''
    logging.info(f'Loading GLM from {rel_path} ...')
    abs_path = os.path.dirname(os.path.abspath(__file__)) + '/' + rel_path
    glm = {}
    with open(abs_path, encoding="utf-8") as fin:
        for k, rule in enumerate(csv.reader(fin, delimiter=',')):
            rule_name = f'<RULE_{k:06d}>'
            glm[rule_name] = [phrase.strip() for phrase in rule]
    logging.info(f' #rule: {len(glm)}')
    return glm
def SymbolEQ(symbol_table, i1, i2):
    """True when two symbol ids denote the same token once the auxiliary
    '#' markers (GLM alternative forms) are stripped."""
    sym1 = symbol_table.find(i1).strip('#')
    sym2 = symbol_table.find(i2).strip('#')
    return sym1 == sym2
def PrintSymbolTable(symbol_table: pynini.SymbolTable):
    """Dump every (id, symbol) pair, asserting the table's reverse lookup
    round-trips (find() is bidirectional: id <-> symbol)."""
    print('SYMBOL_TABLE:')
    for idx in range(symbol_table.num_symbols()):
        sym = symbol_table.find(idx)
        assert symbol_table.find(sym) == idx
        print(idx, sym)
    print()
def BuildSymbolTable(vocab) -> pynini.SymbolTable:
    """Create a SymbolTable with <epsilon> at id 0 followed by each vocab token."""
    logging.info('Building symbol table ...')
    table = pynini.SymbolTable()
    # Id 0 is conventionally reserved for epsilon.
    table.add_symbol('<epsilon>')
    for token in vocab:
        table.add_symbol(token)
    logging.info(f' #symbols: {table.num_symbols()}')
    # PrintSymbolTable(table)
    # table.write_text('symbol_table.txt')
    return table
def BuildGLMTagger(glm, symbol_table) -> pynini.Fst:
    """Build a context-dependent rewrite FST that wraps each GLM phrase
    occurrence in its rule tag: "... phrase ..." -> "... <RULE_x> phrase <RULE_x> ..."."""
    logging.info('Building GLM tagger ...')
    rule_taggers = []
    for rule_tag, rule in glm.items():
        for phrase in rule:
            # Insert the rule tag on both sides of the accepted phrase.
            rule_taggers.append(
                (
                    pynutil.insert(pynini.accep(rule_tag, token_type=symbol_table))
                    + pynini.accep(phrase, token_type=symbol_table)
                    + pynutil.insert(pynini.accep(rule_tag, token_type=symbol_table))
                )
            )
    # Closure over every non-epsilon symbol forms the rewrite context sigma*.
    alphabet = pynini.union(
        *[pynini.accep(sym, token_type=symbol_table) for k, sym in symbol_table if k != 0]  # non-epsilon
    ).optimize()
    tagger = pynini.cdrewrite(
        pynini.union(*rule_taggers).optimize(), '', '', alphabet.closure()
    ).optimize()  # could be slow with large vocabulary
    return tagger
def TokenWidth(token: str):
    """Display width of a token: CJK chars count as two columns, others as one."""
    return sum(2 if '\u4e00' <= c <= '\u9fa5' else 1 for c in token)
def PrintPrettyAlignment(raw_hyp, edit_ali, ref_ali, hyp_ali, stream=sys.stderr):
    """Print the raw hypothesis plus column-aligned HYP#/REF/EDIT rows for one
    utterance, padding by display width (CJK-aware via TokenWidth)."""
    assert len(edit_ali) == len(ref_ali) and len(ref_ali) == len(hyp_ali)
    hyp_row = ' HYP# : '
    ref_row = ' REF : '
    edit_row = ' EDIT : '
    for tag, ref_tok, hyp_tok in zip(edit_ali, ref_ali, hyp_ali):
        # Correct matches get a blank edit tag to reduce visual noise.
        tag = '' if tag == 'C' else tag
        col = max(TokenWidth(ref_tok), TokenWidth(hyp_tok), TokenWidth(tag)) + 1
        hyp_row += hyp_tok + ' ' * (col - TokenWidth(hyp_tok))
        ref_row += ref_tok + ' ' * (col - TokenWidth(ref_tok))
        edit_row += tag + ' ' * (col - TokenWidth(tag))
    print(F' HYP : {raw_hyp}', file=stream)
    print(hyp_row, file=stream)
    print(ref_row, file=stream)
    print(edit_row, file=stream)
def ComputeTokenErrorRate(c, s, i, d):
    """Return (TER, modified TER) percentages; the modified rate normalizes
    by max(ref_len, hyp_len) instead of ref_len alone."""
    assert (s + d + c) != 0
    num_edits = s + d + i
    ref_len = c + s + d
    hyp_len = c + s + i
    ter = 100.0 * num_edits / ref_len
    modified_ter = 100.0 * num_edits / max(ref_len, hyp_len)
    return ter, modified_ter
def ComputeSentenceErrorRate(num_err_utts, num_utts):
    """Percentage of evaluated utterances that contain at least one edit."""
    assert num_utts != 0
    return num_err_utts * 100.0 / num_utts
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('--logk', type=int, default=500, help='logging interval')
parser.add_argument(
'--tokenizer', choices=['whitespace', 'char'], default='whitespace', help='whitespace for WER, char for CER'
)
parser.add_argument('--glm', type=str, default='glm_en.csv', help='glm')
parser.add_argument('--ref', type=str, required=True, help='reference kaldi arc file')
parser.add_argument('--hyp', type=str, required=True, help='hypothesis kaldi arc file')
parser.add_argument('result_file', type=str)
args = parser.parse_args()
logging.info(args)
stats = ErrorStats()
logging.info('Generating tokenizer ...')
if args.tokenizer == 'whitespace':
def word_tokenizer(text):
return text.strip().split()
tokenizer = word_tokenizer
elif args.tokenizer == 'char':
def char_tokenizer(text):
return [c for c in text.strip().replace(' ', '')]
tokenizer = char_tokenizer
else:
tokenizer = None
assert tokenizer
logging.info('Loading REF & HYP ...')
ref_utts = LoadKaldiArc(args.ref)
hyp_utts = LoadKaldiArc(args.hyp)
# check valid utterances in hyp that have matched non-empty reference
uids = []
for uid in sorted(hyp_utts.keys()):
if uid in ref_utts.keys():
if ref_utts[uid].text.strip(): # non-empty reference
uids.append(uid)
else:
logging.warning(F'Found {uid} with empty reference, skipping...')
else:
logging.warning(F'Found {uid} without reference, skipping...')
stats.num_hyp_without_ref += 1
stats.num_hyp_utts = len(hyp_utts)
stats.num_ref_utts = len(ref_utts)
stats.num_eval_utts = len(uids)
logging.info(f' #hyp:{stats.num_hyp_utts}, #ref:{stats.num_ref_utts}, #utts_to_evaluate:{stats.num_eval_utts}')
print(f' #hyp:{stats.num_hyp_utts}, #ref:{stats.num_ref_utts}, #utts_to_evaluate:{stats.num_eval_utts}')
tokens = []
for uid in uids:
ref_tokens = tokenizer(ref_utts[uid].text)
hyp_tokens = tokenizer(hyp_utts[uid].text)
for t in ref_tokens + hyp_tokens:
tokens.append(t)
if '-' in t:
tokens.extend(BreakHyphen(t))
vocab_from_utts = list(set(tokens))
logging.info(f' HYP&REF vocab size: {len(vocab_from_utts)}')
print(f' HYP&REF vocab size: {len(vocab_from_utts)}')
assert args.glm
glm = LoadGLM(args.glm)
tokens = []
for rule in glm.values():
for phrase in rule:
for t in tokenizer(phrase):
tokens.append(t)
if '-' in t:
tokens.extend(BreakHyphen(t))
vocab_from_glm = list(set(tokens))
logging.info(f' GLM vocab size: {len(vocab_from_glm)}')
print(f' GLM vocab size: {len(vocab_from_glm)}')
vocab = list(set(vocab_from_utts + vocab_from_glm))
logging.info(f'Global vocab size: {len(vocab)}')
print(f'Global vocab size: {len(vocab)}')
symtab = BuildSymbolTable(
# Normal evaluation vocab + auxiliary form for alternative paths + GLM tags
vocab
+ [x + '#' for x in vocab]
+ [x for x in glm.keys()]
)
glm_tagger = BuildGLMTagger(glm, symtab)
edit_transducer = EditTransducer(symbol_table=symtab, vocab=vocab)
print(edit_transducer)
logging.info('Evaluating error rate ...')
print('Evaluating error rate ...')
fo = open(args.result_file, 'w+', encoding='utf8')
ndone = 0
for uid in uids:
ref = ref_utts[uid].text
raw_hyp = hyp_utts[uid].text
ref_fst = pynini.accep(' '.join(tokenizer(ref)), token_type=symtab)
print(ref_fst)
# print(ref_fst.string(token_type = symtab))
raw_hyp_fst = pynini.accep(' '.join(tokenizer(raw_hyp)), token_type=symtab)
# print(raw_hyp_fst.string(token_type = symtab))
# Say, we have:
# RULE_001: "I'M" <-> "I AM"
# REF: HEY I AM HERE
# HYP: HEY I'M HERE
#
# We want to expand HYP with GLM rules(marked with auxiliary #)
# HYP#: HEY {I'M | I# AM#} HERE
# REF is honored to keep its original form.
#
# This could be considered as a flexible on-the-fly TN towards HYP.
# 1. GLM rule tagging:
# HEY I'M HERE
# ->
# HEY <RULE_001> I'M <RULE_001> HERE
lattice = (raw_hyp_fst @ glm_tagger).optimize()
tagged_ir = pynini.shortestpath(lattice, nshortest=1, unique=True).string(token_type=symtab)
# print(hyp_tagged)
# 2. GLM rule expansion:
# HEY <RULE_001> I'M <RULE_001> HERE
# ->
# sausage-like fst: HEY {I'M | I# AM#} HERE
tokens = tagged_ir.split()
sausage = pynini.accep('', token_type=symtab)
i = 0
while i < len(tokens): # invariant: tokens[0, i) has been built into fst
forms = []
if tokens[i].startswith('<RULE_') and tokens[i].endswith('>'): # rule segment
rule_name = tokens[i]
rule = glm[rule_name]
# pre-condition: i -> ltag
raw_form = ''
for j in range(i + 1, len(tokens)):
if tokens[j] == rule_name:
raw_form = ' '.join(tokens[i + 1 : j])
break
assert raw_form
# post-condition: i -> ltag, j -> rtag
forms.append(raw_form)
for phrase in rule:
if phrase != raw_form:
forms.append(' '.join([x + '#' for x in phrase.split()]))
i = j + 1
else: # normal token segment
token = tokens[i]
forms.append(token)
if "-" in token: # token with hyphen yields extra forms
forms.append(' '.join([x + '#' for x in token.split('-')])) # 'T-SHIRT' -> 'T# SHIRT#'
forms.append(token.replace('-', '') + '#') # 'T-SHIRT' -> 'TSHIRT#'
i += 1
sausage_segment = pynini.union(*[pynini.accep(x, token_type=symtab) for x in forms]).optimize()
sausage += sausage_segment
hyp_fst = sausage.optimize()
print(hyp_fst)
# Utterance-Level error rate evaluation
alignment = edit_transducer.compute_alignment(ref_fst, hyp_fst)
print("alignment", alignment)
distance = 0.0
C, S, I, D = 0, 0, 0, 0 # Cor, Sub, Ins, Del
edit_ali, ref_ali, hyp_ali = [], [], []
for state in alignment.states():
for arc in alignment.arcs(state):
i, o = arc.ilabel, arc.olabel
if i != 0 and o != 0 and SymbolEQ(symtab, i, o):
e = 'C'
r, h = symtab.find(i), symtab.find(o)
C += 1
distance += 0.0
elif i != 0 and o != 0 and not SymbolEQ(symtab, i, o):
e = 'S'
r, h = symtab.find(i), symtab.find(o)
S += 1
distance += 1.0
elif i == 0 and o != 0:
e = 'I'
r, h = '*', symtab.find(o)
I += 1
distance += 1.0
elif i != 0 and o == 0:
e = 'D'
r, h = symtab.find(i), '*'
D += 1
distance += 1.0
else:
raise RuntimeError
edit_ali.append(e)
ref_ali.append(r)
hyp_ali.append(h)
# assert(distance == edit_transducer.compute_distance(ref_fst, sausage))
utt_ter, utt_mter = ComputeTokenErrorRate(C, S, I, D)
# print(F'{{"uid":{uid}, "score":{-distance}, "TER":{utt_ter:.2f}, "mTER":{utt_mter:.2f}, "cor":{C}, "sub":{S}, "ins":{I}, "del":{D}}}', file=fo)
# PrintPrettyAlignment(raw_hyp, edit_ali, ref_ali, hyp_ali, fo)
if utt_ter > 0:
stats.num_utts_with_error += 1
stats.C += C
stats.S += S
stats.I += I
stats.D += D
ndone += 1
if ndone % args.logk == 0:
logging.info(f'{ndone} utts evaluated.')
logging.info(f'{ndone} utts evaluated in total.')
# Corpus-Level evaluation
stats.token_error_rate, stats.modified_token_error_rate = ComputeTokenErrorRate(stats.C, stats.S, stats.I, stats.D)
stats.sentence_error_rate = ComputeSentenceErrorRate(stats.num_utts_with_error, stats.num_eval_utts)
print(stats.to_json(), file=fo)
# print(stats.to_kaldi())
# print(stats.to_summary(), file=fo)
fo.close()

View File

@@ -0,0 +1,370 @@
#!/usr/bin/env python3
# coding=utf8
# Copyright 2021 Jiayu DU
import sys
import argparse
import json
import logging
logging.basicConfig(stream=sys.stderr, level=logging.INFO, format='[%(levelname)s] %(message)s')
DEBUG = None
def GetEditType(ref_token, hyp_token):
    """Classify one alignment pair: 'I' insertion (ref gap), 'D' deletion
    (hyp gap), 'C' correct match, 'S' substitution.

    Fixes: `== None` / `!= None` comparisons replaced with identity tests
    (PEP 8), and the unreachable trailing `raise RuntimeError` removed —
    the C/S pair exhaustively covers every remaining case, including
    (None, None) which compares equal and yields 'C', as before.
    """
    if ref_token is None and hyp_token is not None:
        return 'I'
    if ref_token is not None and hyp_token is None:
        return 'D'
    if ref_token == hyp_token:
        return 'C'
    return 'S'
class AlignmentArc:
    """One arc on the best alignment path of the DP search grid.

    src/dst are (r, h) state coordinates along the REF/HYP axes; ref/hyp are
    the aligned tokens (None marks a gap); edit_type is derived via GetEditType.
    """
    def __init__(self, src, dst, ref, hyp):
        self.src = src
        self.dst = dst
        self.ref = ref
        self.hyp = hyp
        self.edit_type = GetEditType(ref, hyp)
def similarity_score_function(ref_token, hyp_token):
    """Diagonal-move score: 0 for a match, -1.0 for a substitution."""
    if ref_token == hyp_token:
        return 0
    return -1.0

def insertion_score_function(token):
    """Uniform insertion penalty (token ignored)."""
    return -1.0

def deletion_score_function(token):
    """Uniform deletion penalty (token ignored)."""
    return -1.0
def EditDistance(
    ref,
    hyp,
    similarity_score_function = similarity_score_function,
    insertion_score_function = insertion_score_function,
    deletion_score_function = deletion_score_function):
    """Levenshtein-style alignment via dynamic programming over a (R x H) grid.

    Returns (best_path, best_score): best_path is a list of AlignmentArc
    covering the optimal alignment (in forward order); best_score is its total
    score — 0 or negative with the default cost functions, higher is better.
    Ties between substitution/deletion/insertion moves resolve in favor of the
    later-checked move because `>=` overwrites. `ref` must be non-empty.
    """
    assert(len(ref) != 0)
    class DPState:
        # One cell of the DP grid: best score reaching (r, h) plus a backpointer.
        def __init__(self):
            self.score = -float('inf')
            # backpointer
            self.prev_r = None
            self.prev_h = None
    def print_search_grid(S, R, H, fstream):
        # Debug dump of the full grid (scores + backpointers).
        print(file=fstream)
        for r in range(R):
            for h in range(H):
                print(F'[{r},{h}]:{S[r][h].score:4.3f}:({S[r][h].prev_r},{S[r][h].prev_h}) ', end='', file=fstream)
            print(file=fstream)
    R = len(ref) + 1
    H = len(hyp) + 1
    # Construct DP search space, a (R x H) grid
    S = [ [] for r in range(R) ]
    for r in range(R):
        S[r] = [ DPState() for x in range(H) ]
    # initialize DP search grid origin, S(r = 0, h = 0)
    S[0][0].score = 0.0
    S[0][0].prev_r = None
    S[0][0].prev_h = None
    # initialize REF axis
    for r in range(1, R):
        S[r][0].score = S[r-1][0].score + deletion_score_function(ref[r-1])
        S[r][0].prev_r = r-1
        S[r][0].prev_h = 0
    # initialize HYP axis
    for h in range(1, H):
        S[0][h].score = S[0][h-1].score + insertion_score_function(hyp[h-1])
        S[0][h].prev_r = 0
        S[0][h].prev_h = h-1
    best_score = S[0][0].score
    best_state = (0, 0)
    for r in range(1, R):
        for h in range(1, H):
            # Candidate 1: substitution or correct (diagonal move).
            sub_or_cor_score = similarity_score_function(ref[r-1], hyp[h-1])
            new_score = S[r-1][h-1].score + sub_or_cor_score
            if new_score >= S[r][h].score:
                S[r][h].score = new_score
                S[r][h].prev_r = r-1
                S[r][h].prev_h = h-1
            # Candidate 2: deletion of a reference token (move along REF axis).
            del_score = deletion_score_function(ref[r-1])
            new_score = S[r-1][h].score + del_score
            if new_score >= S[r][h].score:
                S[r][h].score = new_score
                S[r][h].prev_r = r - 1
                S[r][h].prev_h = h
            # Candidate 3: insertion of a hypothesis token (move along HYP axis).
            ins_score = insertion_score_function(hyp[h-1])
            new_score = S[r][h-1].score + ins_score
            if new_score >= S[r][h].score:
                S[r][h].score = new_score
                S[r][h].prev_r = r
                S[r][h].prev_h = h-1
    # Full alignment always ends at the bottom-right corner of the grid.
    best_score = S[R-1][H-1].score
    best_state = (R-1, H-1)
    if DEBUG:
        print_search_grid(S, R, H, sys.stderr)
    # Backtracing best alignment path, i.e. a list of arcs
    #   arc = (src, dst, ref, hyp, edit_type)
    #   src/dst = (r, h), where r/h refers to search grid state-id along Ref/Hyp axis
    best_path = []
    r, h = best_state[0], best_state[1]
    prev_r, prev_h = S[r][h].prev_r, S[r][h].prev_h
    score = S[r][h].score
    # loop invariant:
    #   1. (prev_r, prev_h) -> (r, h) is a "forward arc" on best alignment path
    #   2. score is the value of point(r, h) on DP search grid
    while prev_r != None or prev_h != None:
        src = (prev_r, prev_h)
        dst = (r, h)
        if (r == prev_r + 1 and h == prev_h + 1): # Substitution or correct
            arc = AlignmentArc(src, dst, ref[prev_r], hyp[prev_h])
        elif (r == prev_r + 1 and h == prev_h): # Deletion
            arc = AlignmentArc(src, dst, ref[prev_r], None)
        elif (r == prev_r and h == prev_h + 1): # Insertion
            arc = AlignmentArc(src, dst, None, hyp[prev_h])
        else:
            raise RuntimeError
        best_path.append(arc)
        r, h = prev_r, prev_h
        prev_r, prev_h = S[r][h].prev_r, S[r][h].prev_h
        score = S[r][h].score
    best_path.reverse()
    return (best_path, best_score)
def PrettyPrintAlignment(alignment, stream = sys.stderr):
    """Print column-aligned REF/HYP/EDIT rows for an alignment path, padding
    by display width so double-width CJK characters stay aligned."""
    def token_text(token):
        # None marks an insertion/deletion gap; render it as '*'.
        return '*' if token is None else token
    def display_width(token_str):
        # Chinese chars occupy two terminal columns; everything else one.
        # TODO: support other double-width-char language such as Japanese, Korean
        return sum(2 if '\u4e00' <= ch <= '\u9fa5' else 1 for ch in token_str)
    ref_row = ' REF : '
    hyp_row = ' HYP : '
    edit_row = ' EDIT : '
    for arc in alignment:
        ref_tok = token_text(arc.ref)
        hyp_tok = token_text(arc.hyp)
        tag = arc.edit_type if arc.edit_type != 'C' else ''
        col = max(display_width(ref_tok), display_width(hyp_tok), display_width(tag)) + 1
        ref_row += ref_tok + ' ' * (col - display_width(ref_tok))
        hyp_row += hyp_tok + ' ' * (col - display_width(hyp_tok))
        edit_row += tag + ' ' * (col - display_width(tag))
    print(ref_row, file=stream)
    print(hyp_row, file=stream)
    print(edit_row, file=stream)
def CountEdits(alignment):
    """Tally edit operations along an alignment path.

    Returns a (correct, substitutions, insertions, deletions) tuple.
    Raises RuntimeError if an arc carries an unknown edit type.
    """
    counts = {'C': 0, 'S': 0, 'I': 0, 'D': 0}
    for arc in alignment:
        if arc.edit_type not in counts:
            raise RuntimeError
        counts[arc.edit_type] += 1
    return (counts['C'], counts['S'], counts['I'], counts['D'])
def ComputeTokenErrorRate(c, s, i, d):
    """Token error rate in percent: 100 * (S + D + I) / (S + D + C).

    The denominator (correct + substituted + deleted) equals the number
    of reference tokens.  Raises ZeroDivisionError for an empty reference.
    """
    num_edits = s + d + i
    num_ref_tokens = s + d + c
    return 100.0 * num_edits / num_ref_tokens
def ComputeSentenceErrorRate(num_err_utts, num_utts):
    """Percentage of evaluated utterances containing at least one edit."""
    # Caller must not pass zero evaluated utterances.
    assert num_utts != 0
    return 100.0 * num_err_utts / num_utts
class EvaluationResult:
    """Aggregated corpus-level ASR evaluation statistics (token and sentence error rates)."""
    def __init__(self):
        # Utterance bookkeeping
        self.num_ref_utts = 0
        self.num_hyp_utts = 0
        self.num_eval_utts = 0 # seen in both ref & hyp
        self.num_hyp_without_ref = 0
        # Token-level edit counts: Correct, Substitution, Insertion, Deletion
        self.C = 0
        self.S = 0
        self.I = 0
        self.D = 0
        # Corpus-level rates, in percent (filled in by the caller)
        self.token_error_rate = 0.0
        self.num_utts_with_error = 0
        self.sentence_error_rate = 0.0
    def to_json(self):
        # Serialize every statistic as one flat JSON object.
        return json.dumps(self.__dict__)
    def to_kaldi(self):
        # Kaldi-style "%WER ..." / "%SER ..." summary lines.
        info = (
            F'%WER {self.token_error_rate:.2f} [ {self.S + self.D + self.I} / {self.C + self.S + self.D}, {self.I} ins, {self.D} del, {self.S} sub ]\n'
            F'%SER {self.sentence_error_rate:.2f} [ {self.num_utts_with_error} / {self.num_eval_utts} ]\n'
        )
        return info
    def to_sclite(self):
        # Placeholder: sclite-compatible output not implemented yet.
        return "TODO"
    def to_espnet(self):
        # Placeholder: ESPnet-compatible output not implemented yet.
        return "TODO"
    def to_summary(self):
        """Human-readable multi-line summary of all statistics."""
        #return json.dumps(self.__dict__, indent=4)
        summary = (
            '==================== Overall Statistics ====================\n'
            F'num_ref_utts: {self.num_ref_utts}\n'
            F'num_hyp_utts: {self.num_hyp_utts}\n'
            F'num_hyp_without_ref: {self.num_hyp_without_ref}\n'
            F'num_eval_utts: {self.num_eval_utts}\n'
            F'sentence_error_rate: {self.sentence_error_rate:.2f}%\n'
            F'token_error_rate: {self.token_error_rate:.2f}%\n'
            F'token_stats:\n'
            F'  - tokens:{self.C + self.S + self.D:>7}\n'
            F'  - edits: {self.S + self.I + self.D:>7}\n'
            F'  - cor:   {self.C:>7}\n'
            F'  - sub:   {self.S:>7}\n'
            F'  - ins:   {self.I:>7}\n'
            F'  - del:   {self.D:>7}\n'
            '============================================================\n'
        )
        return summary
class Utterance:
    """One utterance: a unique identifier paired with its transcript text."""
    def __init__(self, uid, text):
        # text may be empty for utterances with no transcript.
        self.uid, self.text = uid, text
def LoadUtterances(filepath, format):
    """Load a transcript file into a {uid: Utterance} dict.

    Args:
        filepath: path to the transcript file (UTF-8).
        format: only 'text' is supported — each non-empty line is
            "utt_id word1 word2 ..." (the text part may be empty).

    Returns:
        dict mapping utterance id -> Utterance.

    Raises:
        RuntimeError: on an unsupported format or a duplicated utterance id.
    """
    utts = {}
    if format == 'text': # utt_id word1 word2 ...
        with open(filepath, 'r', encoding='utf8') as f:
            for line in f:
                line = line.strip()
                if line:
                    cols = line.split(maxsplit=1)
                    assert(len(cols) == 2 or len(cols) == 1)
                    uid = cols[0]
                    text = cols[1] if len(cols) == 2 else ''
                    # Idiomatic membership test (was: utts.get(uid) != None).
                    if uid in utts:
                        raise RuntimeError(F'Found duplicated utterance id {uid}')
                    utts[uid] = Utterance(uid, text)
    else:
        raise RuntimeError(F'Unsupported text format {format}')
    return utts
def tokenize_text(text, tokenizer):
    """Split text into evaluation tokens.

    'whitespace' yields whitespace-separated words (WER); 'char' yields
    individual characters with all whitespace removed (CER).  Any other
    tokenizer name raises RuntimeError.
    """
    if tokenizer == 'whitespace':
        return text.split()
    if tokenizer == 'char':
        return list(''.join(text.split()))
    raise RuntimeError(F'ERROR: Unsupported tokenizer {tokenizer}')
if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    # optional
    parser.add_argument('--tokenizer', choices=['whitespace', 'char'], default='whitespace', help='whitespace for WER, char for CER')
    parser.add_argument('--ref-format', choices=['text'], default='text', help='reference format, first col is utt_id, the rest is text')
    parser.add_argument('--hyp-format', choices=['text'], default='text', help='hypothesis format, first col is utt_id, the rest is text')
    # required
    parser.add_argument('--ref', type=str, required=True, help='input reference file')
    parser.add_argument('--hyp', type=str, required=True, help='input hypothesis file')
    parser.add_argument('result_file', type=str)
    args = parser.parse_args()
    logging.info(args)
    ref_utts = LoadUtterances(args.ref, args.ref_format)
    hyp_utts = LoadUtterances(args.hyp, args.hyp_format)
    r = EvaluationResult()
    # Keep only hypothesis utterances that have a matched, non-empty reference.
    eval_utts = []
    r.num_hyp_without_ref = 0
    for uid in sorted(hyp_utts.keys()):
        if uid in ref_utts:  # O(1) dict membership (was: `in ref_utts.keys()`)
            if ref_utts[uid].text.strip(): # non-empty reference
                eval_utts.append(uid)
            else:
                # logging.warn is deprecated; use logging.warning
                logging.warning(F'Found {uid} with empty reference, skipping...')
        else:
            logging.warning(F'Found {uid} without reference, skipping...')
            r.num_hyp_without_ref += 1
    r.num_hyp_utts = len(hyp_utts)
    r.num_ref_utts = len(ref_utts)
    r.num_eval_utts = len(eval_utts)
    with open(args.result_file, 'w+', encoding='utf8') as fo:
        for uid in eval_utts:
            ref = ref_utts[uid]
            hyp = hyp_utts[uid]
            alignment, score = EditDistance(
                tokenize_text(ref.text, args.tokenizer),
                tokenize_text(hyp.text, args.tokenizer)
            )
            c, s, i, d = CountEdits(alignment)
            utt_ter = ComputeTokenErrorRate(c, s, i, d)
            # utt-level evaluation result; uid is quoted so the line is valid JSON
            print(F'{{"uid":"{uid}", "score":{score}, "ter":{utt_ter:.2f}, "cor":{c}, "sub":{s}, "ins":{i}, "del":{d}}}', file=fo)
            PrettyPrintAlignment(alignment, fo)
            r.C += c
            r.S += s
            r.I += i
            r.D += d
            if utt_ter > 0:
                r.num_utts_with_error += 1
        # corpus level evaluation result
        r.sentence_error_rate = ComputeSentenceErrorRate(r.num_utts_with_error, r.num_eval_utts)
        r.token_error_rate = ComputeTokenErrorRate(r.C, r.S, r.I, r.D)
        print(r.to_summary(), file=fo)
    print(r.to_json())
    print(r.to_kaldi())

744
utils/speechio/glm_en.csv Normal file
View File

@@ -0,0 +1,744 @@
I'M,I AM
I'LL,I WILL
I'D,I HAD
I'VE,I HAVE
I WOULD'VE,I'D HAVE
YOU'RE,YOU ARE
YOU'LL,YOU WILL
YOU'D,YOU WOULD
YOU'VE,YOU HAVE
HE'S,HE IS,HE WAS
HE'LL,HE WILL
HE'D,HE HAD
SHE'S,SHE IS,SHE WAS
SHE'LL,SHE WILL
SHE'D,SHE HAD
IT'S,IT IS,IT WAS
IT'LL,IT WILL
WE'RE,WE ARE,WE WERE
WE'LL,WE WILL
WE'D,WE WOULD
WE'VE,WE HAVE
WHO'LL,WHO WILL
THEY'RE,THEY ARE
THEY'LL,THEY WILL
THAT'S,THAT IS,THAT WAS
THAT'LL,THAT WILL
HERE'S,HERE IS,HERE WAS
THERE'S,THERE IS,THERE WAS
WHERE'S,WHERE IS,WHERE WAS
WHAT'S,WHAT IS,WHAT WAS
LET'S,LET US
WHO'S,WHO IS
ONE'S,ONE IS
THERE'LL,THERE WILL
SOMEBODY'S,SOMEBODY IS
EVERYBODY'S,EVERYBODY IS
WOULD'VE,WOULD HAVE
CAN'T,CANNOT,CAN NOT
HADN'T,HAD NOT
HASN'T,HAS NOT
HAVEN'T,HAVE NOT
ISN'T,IS NOT
AREN'T,ARE NOT
WON'T,WILL NOT
WOULDN'T,WOULD NOT
SHOULDN'T,SHOULD NOT
DON'T,DO NOT
DIDN'T,DID NOT
GOTTA,GOT TO
GONNA,GOING TO
WANNA,WANT TO
LEMME,LET ME
GIMME,GIVE ME
DUNNO,DON'T KNOW
GOTCHA,GOT YOU
KINDA,KIND OF
MYSELF,MY SELF
YOURSELF,YOUR SELF
HIMSELF,HIM SELF
HERSELF,HER SELF
ITSELF,IT SELF
OURSELVES,OUR SELVES
OKAY,OK,O K
Y'ALL,YALL,YOU ALL
'CAUSE,'COS,CUZ,BECAUSE
FUCKIN',FUCKING
KILLING,KILLIN'
EVERYDAY,EVERY DAY
DOCTOR,DR,DR.
MRS,MISSES,MISSUS
MR,MR.,MISTER
SR,SR.,SENIOR
JR,JR.,JUNIOR
ST,ST.,SAINT
VOL,VOL.,VOLUME
CM,CENTIMETER,CENTIMETRE
MM,MILLIMETER,MILLIMETRE
KM,KILOMETER,KILOMETRE
KB,KILOBYTES,KILO BYTES,K B
MB,MEGABYTES,MEGA BYTES
GB,GIGABYTES,GIGA BYTES,G B
THOUSAND,THOUSAND AND
HUNDRED,HUNDRED AND
A HUNDRED,ONE HUNDRED
TWO THOUSAND AND,TWENTY,TWO THOUSAND
STORYTELLER,STORY TELLER
TSHIRT,T SHIRT
TSHIRTS,T SHIRTS
LEUKAEMIA,LEUKEMIA
OESTROGEN,ESTROGEN
ACKNOWLEDGMENT,ACKNOWLEDGEMENT
JUDGMENT,JUDGEMENT
MAMMA,MAMA
DINING,DINNING
FLACK,FLAK
LEARNT,LEARNED
BLONDE,BLOND
JUMPSTART,JUMP START
RIGHTNOW,RIGHT NOW
EVERYONE,EVERY ONE
NAME'S,NAME IS
FAMILY'S,FAMILY IS
COMPANY'S,COMPANY HAS
GRANDKID,GRAND KID
GRANDKIDS,GRAND KIDS
MEALTIMES,MEAL TIMES
ALRIGHT,ALL RIGHT
GROWNUP,GROWN UP
GROWNUPS,GROWN UPS
SCHOOLDAYS,SCHOOL DAYS
SCHOOLCHILDREN,SCHOOL CHILDREN
CASEBOOK,CASE BOOK
HUNGOVER,HUNG OVER
HANDCLAPS,HAND CLAPS
HANDCLAP,HAND CLAP
HEATWAVE,HEAT WAVE
ADDON,ADD ON
ONTO,ON TO
INTO,IN TO
GOTO,GO TO
GUNSHOT,GUN SHOT
MOTHERFUCKER,MOTHER FUCKER
OFTENTIMES,OFTEN TIMES
SARTRE'S,SARTRE IS
NONSTARTER,NON STARTER
NONSTARTERS,NON STARTERS
LONGTIME,LONG TIME
POLICYMAKERS,POLICY MAKERS
ANYMORE,ANY MORE
CANADA'S,CANADA IS
CELLPHONE,CELL PHONE
WORKPLACE,WORK PLACE
UNDERESTIMATING,UNDER ESTIMATING
CYBERSECURITY,CYBER SECURITY
NORTHEAST,NORTH EAST
ANYTIME,ANY TIME
LIVESTREAM,LIVE STREAM
LIVESTREAMS,LIVE STREAMS
WEBCAM,WEB CAM
EMAIL,E MAIL
ECAM,E CAM
VMIX,V MIX
SETUP,SET UP
SMARTPHONE,SMART PHONE
MULTICASTING,MULTI CASTING
CHITCHAT,CHIT CHAT
SEMIFINAL,SEMI FINAL
SEMIFINALS,SEMI FINALS
BBQ,BARBECUE
STORYLINE,STORY LINE
STORYLINES,STORY LINES
BRO,BROTHER
BROS,BROTHERS
OVERPROTECTIIVE,OVER PROTECTIVE
TIMEOUT,TIME OUT
ADVISOR,ADVISER
TIMBERWOLVES,TIMBER WOLVES
WEBPAGE,WEB PAGE
NEWCOMER,NEW COMER
DELMAR,DEL MAR
NETPLAY,NET PLAY
STREETSIDE,STREET SIDE
COLOURED,COLORED
COLOURFUL,COLORFUL
O,ZERO
ETCETERA,ET CETERA
FUNDRAISING,FUND RAISING
RAINFOREST,RAIN FOREST
BREATHTAKING,BREATH TAKING
WIKIPAGE,WIKI PAGE
OVERTIME,OVER TIME
TRAIN'S,TRAIN IS
ANYONE,ANY ONE
PHYSIOTHERAPY,PHYSIO THERAPY
ANYBODY,ANY BODY
BOTTLECAPS,BOTTLE CAPS
BOTTLECAP,BOTTLE CAP
STEPFATHER'S,STEP FATHER'S
STEPFATHER,STEP FATHER
WARTIME,WAR TIME
SCREENSHOT,SCREEN SHOT
TIMELINE,TIME LINE
CITY'S,CITY IS
NONPROFIT,NON PROFIT
KPOP,K POP
HOMEBASE,HOME BASE
LIFELONG,LIFE LONG
LAWSUITS,LAW SUITS
MULTIBILLION,MULTI BILLION
ROADMAP,ROAD MAP
GUY'S,GUY IS
CHECKOUT,CHECK OUT
SQUARESPACE,SQUARE SPACE
REDLINING,RED LINING
BASE'S,BASE IS
TAKEAWAY,TAKE AWAY
CANDYLAND,CANDY LAND
ANTISOCIAL,ANTI SOCIAL
CASEWORK,CASE WORK
RIGOR,RIGOUR
ORGANIZATIONS,ORGANISATIONS
ORGANIZATION,ORGANISATION
SIGNPOST,SIGN POST
WWII,WORLD WAR TWO
WINDOWPANE,WINDOW PANE
SUREFIRE,SURE FIRE
MOUNTAINTOP,MOUNTAIN TOP
SALESPERSON,SALES PERSON
NETWORK,NET WORK
MINISERIES,MINI SERIES
EDWARDS'S,EDWARDS IS
INTERSUBJECTIVITY,INTER SUBJECTIVITY
LIBERALISM'S,LIBERALISM IS
TAGLINE,TAG LINE
SHINETHEORY,SHINE THEORY
CALLYOURGIRLFRIEND,CALL YOUR GIRLFRIEND
STARTUP,START UP
BREAKUP,BREAK UP
RADIOTOPIA,RADIO TOPIA
HEARTBREAKING,HEART BREAKING
AUTOIMMUNE,AUTO IMMUNE
SINISE'S,SINISE IS
KICKBACK,KICK BACK
FOGHORN,FOG HORN
BADASS,BAD ASS
POWERAMERICAFORWARD,POWER AMERICA FORWARD
GOOGLE'S,GOOGLE IS
ROLEPLAY,ROLE PLAY
PRICE'S,PRICE IS
STANDOFF,STAND OFF
FOREVER,FOR EVER
GENERAL'S,GENERAL IS
DOG'S,DOG IS
AUDIOBOOK,AUDIO BOOK
ANYWAY,ANY WAY
PIGEONHOLE,PIGEON HOLE
EGGSHELLS,EGG SHELLS
VACCINE'S,VACCINE IS
WORKOUT,WORK OUT
ADMINISTRATOR'S,ADMINISTRATOR IS
FUCKUP,FUCK UP
RUNOFFS,RUN OFFS
COLORWAY,COLOR WAY
WAITLIST,WAIT LIST
HEALTHCARE,HEALTH CARE
TEXTBOOK,TEXT BOOK
CALLBACK,CALL BACK
PARTYGOERS,PARTY GOERS
SOMEDAY,SOME DAY
NIGHTGOWN,NIGHT GOWN
STANDALONG,STAND ALONG
BUSSINESSWOMAN,BUSSINESS WOMAN
STORYTELLING,STORY TELLING
MARKETPLACE,MARKET PLACE
CRATEJOY,CRATE JOY
OUTPERFORMED,OUT PERFORMED
TRUEBOTANICALS,TRUE BOTANICALS
NONFICTION,NON FICTION
SPINOFF,SPIN OFF
MOTHERFUCKING,MOTHER FUCKING
TRACKLIST,TRACK LIST
GODDAMN,GOD DAMN
PORNHUB,PORN HUB
UNDERAGE,UNDER AGE
GOODBYE,GOOD BYE
HARDCORE,HARD CORE
TRUCK'S,TRUCK IS
COUNTERSTEERING,COUNTER STEERING
BUZZWORD,BUZZ WORD
SUBCOMPONENTS,SUB COMPONENTS
MOREOVER,MORE OVER
PICKUP,PICK UP
NEWSLETTER,NEWS LETTER
KEYWORD,KEY WORD
LOGIN,LOG IN
TOOLBOX,TOOL BOX
LINK'S,LINK IS
PRIMIALVIDEO,PRIMAL VIDEO
DOTNET,DOT NET
AIRSTRIKE,AIR STRIKE
HAIRSTYLE,HAIR STYLE
TOWNSFOLK,TOWNS FOLK
GOLDFISH,GOLD FISH
TOM'S,TOM IS
HOMETOWN,HOME TOWN
CORONAVIRUS,CORONA VIRUS
PLAYSTATION,PLAY STATION
TOMORROW,TO MORROW
TIMECONSUMING,TIME CONSUMING
POSTWAR,POST WAR
HANDSON,HANDS ON
SHAKEUP,SHAKE UP
ECOMERS,E COMERS
COFOUNDER,CO FOUNDER
HIGHEND,HIGH END
INPERSON,IN PERSON
GROWNUP,GROWN UP
SELFREGULATION,SELF REGULATION
INDEPTH,IN DEPTH
ALLTIME,ALL TIME
LONGTERM,LONG TERM
SOCALLED,SO CALLED
SELFCONFIDENCE,SELF CONFIDENCE
STANDUP,STAND UP
MINDBOGGLING,MIND BOGGLING
BEINGFOROTHERS,BEING FOR OTHERS
COWROTE,CO WROTE
COSTARRED,CO STARRED
EDITORINCHIEF,EDITOR IN CHIEF
HIGHSPEED,HIGH SPEED
DECISIONMAKING,DECISION MAKING
WELLBEING,WELL BEING
NONTRIVIAL,NON TRIVIAL
PREEXISTING,PRE EXISTING
STATEOWNED,STATE OWNED
PLUGIN,PLUG IN
PROVERSION,PRO VERSION
OPTIN,OPT IN
FOLLOWUP,FOLLOW UP
FOLLOWUPS,FOLLOW UPS
WIFI,WI FI
THIRDPARTY,THIRD PARTY
PROFESSIONALLOOKING,PROFESSIONAL LOOKING
FULLSCREEN,FULL SCREEN
BUILTIN,BUILT IN
MULTISTREAM,MULTI STREAM
LOWCOST,LOW COST
RESTREAM,RE STREAM
GAMECHANGER,GAME CHANGER
WELLDEVELOPED,WELL DEVELOPED
QUARTERINCH,QUARTER INCH
FASTFASHION,FAST FASHION
ECOMMERCE,E COMMERCE
PRIZEWINNING,PRIZE WINNING
NEVERENDING,NEVER ENDING
MINDBLOWING,MIND BLOWING
REALLIFE,REAL LIFE
REOPEN,RE OPEN
ONDEMAND,ON DEMAND
PROBLEMSOLVING,PROBLEM SOLVING
HEAVYHANDED,HEAVY HANDED
OPENENDED,OPEN ENDED
SELFCONTROL,SELF CONTROL
WELLMEANING,WELL MEANING
COHOST,CO HOST
RIGHTSBASED,RIGHTS BASED
HALFBROTHER,HALF BROTHER
FATHERINLAW,FATHER IN LAW
COAUTHOR,CO AUTHOR
REELECTION,RE ELECTION
SELFHELP,SELF HELP
PROLIFE,PRO LIFE
ANTIDUKE,ANTI DUKE
POSTSTRUCTURALIST,POST STRUCTURALIST
COFOUNDED,CO FOUNDED
XRAY,X RAY
ALLAROUND,ALL AROUND
HIGHTECH,HIGH TECH
TMOBILE,T MOBILE
INHOUSE,IN HOUSE
POSTMORTEM,POST MORTEM
LITTLEKNOWN,LITTLE KNOWN
FALSEPOSITIVE,FALSE POSITIVE
ANTIVAXXER,ANTI VAXXER
EMAILS,E MAILS
DRIVETHROUGH,DRIVE THROUGH
DAYTODAY,DAY TO DAY
COSTAR,CO STAR
EBAY,E BAY
KOOLAID,KOOL AID
ANTIDEMOCRATIC,ANTI DEMOCRATIC
MIDDLEAGED,MIDDLE AGED
SHORTLIVED,SHORT LIVED
BESTSELLING,BEST SELLING
TICTACS,TIC TACS
UHHUH,UH HUH
MULTITANK,MULTI TANK
JAWDROPPING,JAW DROPPING
LIVESTREAMING,LIVE STREAMING
HARDWORKING,HARD WORKING
BOTTOMDWELLING,BOTTOM DWELLING
PRESHOW,PRE SHOW
HANDSFREE,HANDS FREE
TRICKORTREATING,TRICK OR TREATING
PRERECORDED,PRE RECORDED
DOGOODERS,DO GOODERS
WIDERANGING,WIDE RANGING
LIFESAVING,LIFE SAVING
SKIREPORT,SKI REPORT
SNOWBASE,SNOW BASE
JAYZ,JAY Z
SPIDERMAN,SPIDER MAN
FREEKICK,FREE KICK
EDWARDSHELAIRE,EDWARDS HELAIRE
SHORTTERM,SHORT TERM
HAVENOTS,HAVE NOTS
SELFINTEREST,SELF INTEREST
SELFINTERESTED,SELF INTERESTED
SELFCOMPASSION,SELF COMPASSION
MACHINELEARNING,MACHINE LEARNING
COAUTHORED,CO AUTHORED
NONGOVERNMENT,NON GOVERNMENT
SUBSAHARAN,SUB SAHARAN
COCHAIR,CO CHAIR
LARGESCALE,LARGE SCALE
VIDEOONDEMAND,VIDEO ON DEMAND
FIRSTCLASS,FIRST CLASS
COFOUNDERS,CO FOUNDERS
COOP,CO OP
PREORDERS,PRE ORDERS
DOUBLEENTRY,DOUBLE ENTRY
SELFCONFIDENT,SELF CONFIDENT
SELFPORTRAIT,SELF PORTRAIT
NONWHITE,NON WHITE
ONBOARD,ON BOARD
HALFLIFE,HALF LIFE
ONCOURT,ON COURT
SCIFI,SCI FI
XMEN,X MEN
DAYLEWIS,DAY LEWIS
LALALAND,LA LA LAND
AWARDWINNING,AWARD WINNING
BOXOFFICE,BOX OFFICE
TRIDACTYLS,TRI DACTYLS
TRIDACTYL,TRI DACTYL
MEDIUMSIZED,MEDIUM SIZED
POSTSECONDARY,POST SECONDARY
FULLTIME,FULL TIME
GOKART,GO KART
OPENAIR,OPEN AIR
WELLKNOWN,WELL KNOWN
ICECREAM,ICE CREAM
EARTHMOON,EARTH MOON
STATEOFTHEART,STATE OF THE ART
BSIDE,B SIDE
EASTWEST,EAST WEST
ALLSTAR,ALL STAR
RUNNERUP,RUNNER UP
HORSEDRAWN,HORSE DRAWN
OPENSOURCE,OPEN SOURCE
PURPOSEBUILT,PURPOSE BUILT
SQUAREFREE,SQUARE FREE
PRESENTDAY,PRESENT DAY
CANADAUNITED,CANADA UNITED
HOTCHPOTCH,HOTCH POTCH
LOWLYING,LOW LYING
RIGHTHANDED,RIGHT HANDED
PEARSHAPED,PEAR SHAPED
BESTKNOWN,BEST KNOWN
FULLLENGTH,FULL LENGTH
YEARROUND,YEAR ROUND
PREELECTION,PRE ELECTION
RERECORD,RE RECORD
MINIALBUM,MINI ALBUM
LONGESTRUNNING,LONGEST RUNNING
ALLIRELAND,ALL IRELAND
NORTHWESTERN,NORTH WESTERN
PARTTIME,PART TIME
NONGOVERNMENTAL,NON GOVERNMENTAL
ONLINE,ON LINE
ONAIR,ON AIR
NORTHSOUTH,NORTH SOUTH
RERELEASED,RE RELEASED
LEFTHANDED,LEFT HANDED
BSIDES,B SIDES
ANGLOSAXON,ANGLO SAXON
SOUTHSOUTHEAST,SOUTH SOUTHEAST
CROSSCOUNTRY,CROSS COUNTRY
REBUILT,RE BUILT
FREEFORM,FREE FORM
SCOOBYDOO,SCOOBY DOO
ATLARGE,AT LARGE
COUNCILMANAGER,COUNCIL MANAGER
LONGRUNNING,LONG RUNNING
PREWAR,PRE WAR
REELECTED,RE ELECTED
HIGHSCHOOL,HIGH SCHOOL
RUNNERSUP,RUNNERS UP
NORTHWEST,NORTH WEST
WEBBASED,WEB BASED
HIGHQUALITY,HIGH QUALITY
RIGHTWING,RIGHT WING
LANEFOX,LANE FOX
PAYPERVIEW,PAY PER VIEW
COPRODUCTION,CO PRODUCTION
NONPARTISAN,NON PARTISAN
FIRSTPERSON,FIRST PERSON
WORLDRENOWNED,WORLD RENOWNED
VICEPRESIDENT,VICE PRESIDENT
PROROMAN,PRO ROMAN
COPRODUCED,CO PRODUCED
LOWPOWER,LOW POWER
SELFESTEEM,SELF ESTEEM
SEMITRANSPARENT,SEMI TRANSPARENT
SECONDINCOMMAND,SECOND IN COMMAND
HIGHRISE,HIGH RISE
COHOSTED,CO HOSTED
AFRICANAMERICAN,AFRICAN AMERICAN
SOUTHWEST,SOUTH WEST
WELLPRESERVED,WELL PRESERVED
FEATURELENGTH,FEATURE LENGTH
HIPHOP,HIP HOP
ALLBIG,ALL BIG
SOUTHEAST,SOUTH EAST
COUNTERATTACK,COUNTER ATTACK
QUARTERFINALS,QUARTER FINALS
STABLEDOOR,STABLE DOOR
DARKEYED,DARK EYED
ALLAMERICAN,ALL AMERICAN
THIRDPERSON,THIRD PERSON
LOWLEVEL,LOW LEVEL
NTERMINAL,N TERMINAL
DRIEDUP,DRIED UP
AFRICANAMERICANS,AFRICAN AMERICANS
ANTIAPARTHEID,ANTI APARTHEID
STOKEONTRENT,STOKE ON TRENT
NORTHNORTHEAST,NORTH NORTHEAST
BRANDNEW,BRAND NEW
RIGHTANGLED,RIGHT ANGLED
GOVERNMENTOWNED,GOVERNMENT OWNED
SONINLAW,SON IN LAW
SUBJECTOBJECTVERB,SUBJECT OBJECT VERB
LEFTARM,LEFT ARM
LONGLIVED,LONG LIVED
REDEYE,RED EYE
TPOSE,T POSE
NIGHTVISION,NIGHT VISION
SOUTHEASTERN,SOUTH EASTERN
WELLRECEIVED,WELL RECEIVED
ALFAYOUM,AL FAYOUM
TIMEBASED,TIME BASED
KETTLEDRUMS,KETTLE DRUMS
BRIGHTEYED,BRIGHT EYED
REDBROWN,RED BROWN
SAMESEX,SAME SEX
PORTDEPAIX,PORT DE PAIX
CLEANUP,CLEAN UP
PERCENT,PERCENT SIGN
TAKEOUT,TAKE OUT
KNOWHOW,KNOW HOW
FISHBONE,FISH BONE
FISHSTICKS,FISH STICKS
PAPERWORK,PAPER WORK
NICKNACKS,NICK NACKS
STREETTALKING,STREET TALKING
NONACADEMIC,NON ACADEMIC
SHELLY,SHELLEY
SHELLY'S,SHELLEY'S
JIMMY,JIMMIE
JIMMY'S,JIMMIE'S
DRUGSTORE,DRUG STORE
THRU,THROUGH
PLAYDATE,PLAY DATE
MICROLIFE,MICRO LIFE
SKILLSET,SKILL SET
SKILLSETS,SKILL SETS
TRADEOFF,TRADE OFF
TRADEOFFS,TRADE OFFS
ONSCREEN,ON SCREEN
PLAYBACK,PLAY BACK
ARTWORK,ART WORK
COWORKER,CO WORKER
COWORKERS,CO WORKERS
SOMETIME,SOME TIME
SOMETIMES,SOME TIMES
CROWDFUNDING,CROWD FUNDING
AM,A.M.,A M
PM,P.M.,P M
TV,T V
MBA,M B A
USA,U S A
US,U S
UK,U K
CEO,C E O
CFO,C F O
COO,C O O
CIO,C I O
FM,F M
GMC,G M C
FSC,F S C
NPD,N P D
APM,A P M
NGO,N G O
TD,T D
LOL,L O L
IPO,I P O
CNBC,C N B C
IPOS,I P OS
CNBC's,C N B C'S
JT,J T
NPR,N P R
NPR'S,N P R'S
MP,M P
IOI,I O I
DW,D W
CNN,C N N
WSM,W S M
ET,E T
IT,I T
RJ,R J
DVD,D V D
DVD'S,D V D'S
HBO,H B O
LA,L A
XC,X C
SUV,S U V
NBA,N B A
NBA'S,N B A'S
ESPN,E S P N
ESPN'S,E S P N'S
ADT,A D T
HD,H D
VIP,V I P
TMZ,T M Z
CBC,C B C
NPO,N P O
BBC,B B C
LA'S,L A'S
TMZ'S,T M Z'S
HIV,H I V
FTC,F T C
EU,E U
PHD,P H D
AI,A I
FHI,F H I
ICML,I C M L
ICLR,I C L R
BMW,B M W
EV,E V
CR,C R
API,A P I
ICO,I C O
LTE,L T E
OBS,O B S
PC,P C
IO,I O
CRM,C R M
RTMP,R T M P
ASMR,A S M R
GG,G G
WWW,W W W
PEI,P E I
JJ,J J
PT,P T
DJ,D J
SD,S D
POW,P.O.W.,P O W
FYI,F Y I
DC,D C,D.C
ABC,A B C
TJ,T J
WMDT,W M D T
WDTN,W D T N
TY,T Y
EJ,E J
CJ,C J
ACL,A C L
UK'S,U K'S
GTV,G T V
MDMA,M D M A
DFW,D F W
WTF,W T F
AJ,A J
MD,M D
PH,P H
ID,I D
SEO,S E O
UTM'S,U T M'S
EC,E C
UFC,U F C
RV,R V
UTM,U T M
CSV,C S V
SMS,S M S
GRB,G R B
GT,G T
LEM,L E M
XR,X R
EDU,E D U
NBC,N B C
EMS,E M S
CDC,C D C
MLK,M L K
IE,I E
OC,O C
HR,H R
MA,M A
DEE,D E E
AP,A P
UFO,U F O
DE,D E
LGBTQ,L G B T Q
PTA,P T A
NHS,N H S
CMA,C M A
MGM,M G M
AKA,A K A
HW,H W
GOP,G O P
GOP'S,G O P'S
FBI,F B I
PRX,P R X
CTO,C T O
URL,U R L
EIN,E I N
MLS,M L S
CSI,C S I
AOC,A O C
CND,C N D
CP,C P
PP,P P
CLI,C L I
PB,P B
FDA,F D A
MRNA,M R N A
PR,P R
VP,V P
DNC,D N C
MSNBC,M S N B C
GQ,G Q
UT,U T
XXI,X X I
HRV,H R V
WHO,W H O
CRO,C R O
DPA,D P A
PPE,P P E
EVA,E V A
BP,B P
GPS,G P S
AR,A R
PJ,P J
MLM,M L M
OLED,O L E D
BO,B O
VE,V E
UN,U N
SLS,S L S
DM,D M
DM'S,D M'S
ASAP,A S A P
ETA,E T A
DOB,D O B
BMW,B M W
1 I'M,I AM
2 I'LL,I WILL
3 I'D,I HAD
4 I'VE,I HAVE
5 I WOULD'VE,I'D HAVE
6 YOU'RE,YOU ARE
7 YOU'LL,YOU WILL
8 YOU'D,YOU WOULD
9 YOU'VE,YOU HAVE
10 HE'S,HE IS,HE WAS
11 HE'LL,HE WILL
12 HE'D,HE HAD
13 SHE'S,SHE IS,SHE WAS
14 SHE'LL,SHE WILL
15 SHE'D,SHE HAD
16 IT'S,IT IS,IT WAS
17 IT'LL,IT WILL
18 WE'RE,WE ARE,WE WERE
19 WE'LL,WE WILL
20 WE'D,WE WOULD
21 WE'VE,WE HAVE
22 WHO'LL,WHO WILL
23 THEY'RE,THEY ARE
24 THEY'LL,THEY WILL
25 THAT'S,THAT IS,THAT WAS
26 THAT'LL,THAT WILL
27 HERE'S,HERE IS,HERE WAS
28 THERE'S,THERE IS,THERE WAS
29 WHERE'S,WHERE IS,WHERE WAS
30 WHAT'S,WHAT IS,WHAT WAS
31 LET'S,LET US
32 WHO'S,WHO IS
33 ONE'S,ONE IS
34 THERE'LL,THERE WILL
35 SOMEBODY'S,SOMEBODY IS
36 EVERYBODY'S,EVERYBODY IS
37 WOULD'VE,WOULD HAVE
38 CAN'T,CANNOT,CAN NOT
39 HADN'T,HAD NOT
40 HASN'T,HAS NOT
41 HAVEN'T,HAVE NOT
42 ISN'T,IS NOT
43 AREN'T,ARE NOT
44 WON'T,WILL NOT
45 WOULDN'T,WOULD NOT
46 SHOULDN'T,SHOULD NOT
47 DON'T,DO NOT
48 DIDN'T,DID NOT
49 GOTTA,GOT TO
50 GONNA,GOING TO
51 WANNA,WANT TO
52 LEMME,LET ME
53 GIMME,GIVE ME
54 DUNNO,DON'T KNOW
55 GOTCHA,GOT YOU
56 KINDA,KIND OF
57 MYSELF,MY SELF
58 YOURSELF,YOUR SELF
59 HIMSELF,HIM SELF
60 HERSELF,HER SELF
61 ITSELF,IT SELF
62 OURSELVES,OUR SELVES
63 OKAY,OK,O K
64 Y'ALL,YALL,YOU ALL
65 'CAUSE,'COS,CUZ,BECAUSE
66 FUCKIN',FUCKING
67 KILLING,KILLIN'
68 EVERYDAY,EVERY DAY
69 DOCTOR,DR,DR.
70 MRS,MISSES,MISSUS
71 MR,MR.,MISTER
72 SR,SR.,SENIOR
73 JR,JR.,JUNIOR
74 ST,ST.,SAINT
75 VOL,VOL.,VOLUME
76 CM,CENTIMETER,CENTIMETRE
77 MM,MILLIMETER,MILLIMETRE
78 KM,KILOMETER,KILOMETRE
79 KB,KILOBYTES,KILO BYTES,K B
80 MB,MEGABYTES,MEGA BYTES
81 GB,GIGABYTES,GIGA BYTES,G B
82 THOUSAND,THOUSAND AND
83 HUNDRED,HUNDRED AND
84 A HUNDRED,ONE HUNDRED
85 TWO THOUSAND AND,TWENTY,TWO THOUSAND
86 STORYTELLER,STORY TELLER
87 TSHIRT,T SHIRT
88 TSHIRTS,T SHIRTS
89 LEUKAEMIA,LEUKEMIA
90 OESTROGEN,ESTROGEN
91 ACKNOWLEDGMENT,ACKNOWLEDGEMENT
92 JUDGMENT,JUDGEMENT
93 MAMMA,MAMA
94 DINING,DINNING
95 FLACK,FLAK
96 LEARNT,LEARNED
97 BLONDE,BLOND
98 JUMPSTART,JUMP START
99 RIGHTNOW,RIGHT NOW
100 EVERYONE,EVERY ONE
101 NAME'S,NAME IS
102 FAMILY'S,FAMILY IS
103 COMPANY'S,COMPANY HAS
104 GRANDKID,GRAND KID
105 GRANDKIDS,GRAND KIDS
106 MEALTIMES,MEAL TIMES
107 ALRIGHT,ALL RIGHT
108 GROWNUP,GROWN UP
109 GROWNUPS,GROWN UPS
110 SCHOOLDAYS,SCHOOL DAYS
111 SCHOOLCHILDREN,SCHOOL CHILDREN
112 CASEBOOK,CASE BOOK
113 HUNGOVER,HUNG OVER
114 HANDCLAPS,HAND CLAPS
115 HANDCLAP,HAND CLAP
116 HEATWAVE,HEAT WAVE
117 ADDON,ADD ON
118 ONTO,ON TO
119 INTO,IN TO
120 GOTO,GO TO
121 GUNSHOT,GUN SHOT
122 MOTHERFUCKER,MOTHER FUCKER
123 OFTENTIMES,OFTEN TIMES
124 SARTRE'S,SARTRE IS
125 NONSTARTER,NON STARTER
126 NONSTARTERS,NON STARTERS
127 LONGTIME,LONG TIME
128 POLICYMAKERS,POLICY MAKERS
129 ANYMORE,ANY MORE
130 CANADA'S,CANADA IS
131 CELLPHONE,CELL PHONE
132 WORKPLACE,WORK PLACE
133 UNDERESTIMATING,UNDER ESTIMATING
134 CYBERSECURITY,CYBER SECURITY
135 NORTHEAST,NORTH EAST
136 ANYTIME,ANY TIME
137 LIVESTREAM,LIVE STREAM
138 LIVESTREAMS,LIVE STREAMS
139 WEBCAM,WEB CAM
140 EMAIL,E MAIL
141 ECAM,E CAM
142 VMIX,V MIX
143 SETUP,SET UP
144 SMARTPHONE,SMART PHONE
145 MULTICASTING,MULTI CASTING
146 CHITCHAT,CHIT CHAT
147 SEMIFINAL,SEMI FINAL
148 SEMIFINALS,SEMI FINALS
149 BBQ,BARBECUE
150 STORYLINE,STORY LINE
151 STORYLINES,STORY LINES
152 BRO,BROTHER
153 BROS,BROTHERS
154 OVERPROTECTIIVE,OVER PROTECTIVE
155 TIMEOUT,TIME OUT
156 ADVISOR,ADVISER
157 TIMBERWOLVES,TIMBER WOLVES
158 WEBPAGE,WEB PAGE
159 NEWCOMER,NEW COMER
160 DELMAR,DEL MAR
161 NETPLAY,NET PLAY
162 STREETSIDE,STREET SIDE
163 COLOURED,COLORED
164 COLOURFUL,COLORFUL
165 O,ZERO
166 ETCETERA,ET CETERA
167 FUNDRAISING,FUND RAISING
168 RAINFOREST,RAIN FOREST
169 BREATHTAKING,BREATH TAKING
170 WIKIPAGE,WIKI PAGE
171 OVERTIME,OVER TIME
172 TRAIN'S,TRAIN IS
173 ANYONE,ANY ONE
174 PHYSIOTHERAPY,PHYSIO THERAPY
175 ANYBODY,ANY BODY
176 BOTTLECAPS,BOTTLE CAPS
177 BOTTLECAP,BOTTLE CAP
178 STEPFATHER'S,STEP FATHER'S
179 STEPFATHER,STEP FATHER
180 WARTIME,WAR TIME
181 SCREENSHOT,SCREEN SHOT
182 TIMELINE,TIME LINE
183 CITY'S,CITY IS
184 NONPROFIT,NON PROFIT
185 KPOP,K POP
186 HOMEBASE,HOME BASE
187 LIFELONG,LIFE LONG
188 LAWSUITS,LAW SUITS
189 MULTIBILLION,MULTI BILLION
190 ROADMAP,ROAD MAP
191 GUY'S,GUY IS
192 CHECKOUT,CHECK OUT
193 SQUARESPACE,SQUARE SPACE
194 REDLINING,RED LINING
195 BASE'S,BASE IS
196 TAKEAWAY,TAKE AWAY
197 CANDYLAND,CANDY LAND
198 ANTISOCIAL,ANTI SOCIAL
199 CASEWORK,CASE WORK
200 RIGOR,RIGOUR
201 ORGANIZATIONS,ORGANISATIONS
202 ORGANIZATION,ORGANISATION
203 SIGNPOST,SIGN POST
204 WWII,WORLD WAR TWO
205 WINDOWPANE,WINDOW PANE
206 SUREFIRE,SURE FIRE
207 MOUNTAINTOP,MOUNTAIN TOP
208 SALESPERSON,SALES PERSON
209 NETWORK,NET WORK
210 MINISERIES,MINI SERIES
211 EDWARDS'S,EDWARDS IS
212 INTERSUBJECTIVITY,INTER SUBJECTIVITY
213 LIBERALISM'S,LIBERALISM IS
214 TAGLINE,TAG LINE
215 SHINETHEORY,SHINE THEORY
216 CALLYOURGIRLFRIEND,CALL YOUR GIRLFRIEND
217 STARTUP,START UP
218 BREAKUP,BREAK UP
219 RADIOTOPIA,RADIO TOPIA
220 HEARTBREAKING,HEART BREAKING
221 AUTOIMMUNE,AUTO IMMUNE
222 SINISE'S,SINISE IS
223 KICKBACK,KICK BACK
224 FOGHORN,FOG HORN
225 BADASS,BAD ASS
226 POWERAMERICAFORWARD,POWER AMERICA FORWARD
227 GOOGLE'S,GOOGLE IS
228 ROLEPLAY,ROLE PLAY
229 PRICE'S,PRICE IS
230 STANDOFF,STAND OFF
231 FOREVER,FOR EVER
232 GENERAL'S,GENERAL IS
233 DOG'S,DOG IS
234 AUDIOBOOK,AUDIO BOOK
235 ANYWAY,ANY WAY
236 PIGEONHOLE,PIGEON HOLE
237 EGGSHELLS,EGG SHELLS
238 VACCINE'S,VACCINE IS
239 WORKOUT,WORK OUT
240 ADMINISTRATOR'S,ADMINISTRATOR IS
241 FUCKUP,FUCK UP
242 RUNOFFS,RUN OFFS
243 COLORWAY,COLOR WAY
244 WAITLIST,WAIT LIST
245 HEALTHCARE,HEALTH CARE
246 TEXTBOOK,TEXT BOOK
247 CALLBACK,CALL BACK
248 PARTYGOERS,PARTY GOERS
249 SOMEDAY,SOME DAY
250 NIGHTGOWN,NIGHT GOWN
251 STANDALONG,STAND ALONG
252 BUSSINESSWOMAN,BUSSINESS WOMAN
253 STORYTELLING,STORY TELLING
254 MARKETPLACE,MARKET PLACE
255 CRATEJOY,CRATE JOY
256 OUTPERFORMED,OUT PERFORMED
257 TRUEBOTANICALS,TRUE BOTANICALS
258 NONFICTION,NON FICTION
259 SPINOFF,SPIN OFF
260 MOTHERFUCKING,MOTHER FUCKING
261 TRACKLIST,TRACK LIST
262 GODDAMN,GOD DAMN
263 PORNHUB,PORN HUB
264 UNDERAGE,UNDER AGE
265 GOODBYE,GOOD BYE
266 HARDCORE,HARD CORE
267 TRUCK'S,TRUCK IS
268 COUNTERSTEERING,COUNTER STEERING
269 BUZZWORD,BUZZ WORD
270 SUBCOMPONENTS,SUB COMPONENTS
271 MOREOVER,MORE OVER
272 PICKUP,PICK UP
273 NEWSLETTER,NEWS LETTER
274 KEYWORD,KEY WORD
275 LOGIN,LOG IN
276 TOOLBOX,TOOL BOX
277 LINK'S,LINK IS
278 PRIMIALVIDEO,PRIMAL VIDEO
279 DOTNET,DOT NET
280 AIRSTRIKE,AIR STRIKE
281 HAIRSTYLE,HAIR STYLE
282 TOWNSFOLK,TOWNS FOLK
283 GOLDFISH,GOLD FISH
284 TOM'S,TOM IS
285 HOMETOWN,HOME TOWN
286 CORONAVIRUS,CORONA VIRUS
287 PLAYSTATION,PLAY STATION
288 TOMORROW,TO MORROW
289 TIMECONSUMING,TIME CONSUMING
290 POSTWAR,POST WAR
291 HANDSON,HANDS ON
292 SHAKEUP,SHAKE UP
293 ECOMERS,E COMERS
294 COFOUNDER,CO FOUNDER
295 HIGHEND,HIGH END
296 INPERSON,IN PERSON
297 GROWNUP,GROWN UP
298 SELFREGULATION,SELF REGULATION
299 INDEPTH,IN DEPTH
300 ALLTIME,ALL TIME
301 LONGTERM,LONG TERM
302 SOCALLED,SO CALLED
303 SELFCONFIDENCE,SELF CONFIDENCE
304 STANDUP,STAND UP
305 MINDBOGGLING,MIND BOGGLING
306 BEINGFOROTHERS,BEING FOR OTHERS
307 COWROTE,CO WROTE
308 COSTARRED,CO STARRED
309 EDITORINCHIEF,EDITOR IN CHIEF
310 HIGHSPEED,HIGH SPEED
311 DECISIONMAKING,DECISION MAKING
312 WELLBEING,WELL BEING
313 NONTRIVIAL,NON TRIVIAL
314 PREEXISTING,PRE EXISTING
315 STATEOWNED,STATE OWNED
316 PLUGIN,PLUG IN
317 PROVERSION,PRO VERSION
318 OPTIN,OPT IN
319 FOLLOWUP,FOLLOW UP
320 FOLLOWUPS,FOLLOW UPS
321 WIFI,WI FI
322 THIRDPARTY,THIRD PARTY
323 PROFESSIONALLOOKING,PROFESSIONAL LOOKING
324 FULLSCREEN,FULL SCREEN
325 BUILTIN,BUILT IN
326 MULTISTREAM,MULTI STREAM
327 LOWCOST,LOW COST
328 RESTREAM,RE STREAM
329 GAMECHANGER,GAME CHANGER
330 WELLDEVELOPED,WELL DEVELOPED
331 QUARTERINCH,QUARTER INCH
332 FASTFASHION,FAST FASHION
333 ECOMMERCE,E COMMERCE
334 PRIZEWINNING,PRIZE WINNING
335 NEVERENDING,NEVER ENDING
336 MINDBLOWING,MIND BLOWING
337 REALLIFE,REAL LIFE
338 REOPEN,RE OPEN
339 ONDEMAND,ON DEMAND
340 PROBLEMSOLVING,PROBLEM SOLVING
341 HEAVYHANDED,HEAVY HANDED
342 OPENENDED,OPEN ENDED
343 SELFCONTROL,SELF CONTROL
344 WELLMEANING,WELL MEANING
345 COHOST,CO HOST
346 RIGHTSBASED,RIGHTS BASED
347 HALFBROTHER,HALF BROTHER
348 FATHERINLAW,FATHER IN LAW
349 COAUTHOR,CO AUTHOR
350 REELECTION,RE ELECTION
351 SELFHELP,SELF HELP
352 PROLIFE,PRO LIFE
353 ANTIDUKE,ANTI DUKE
354 POSTSTRUCTURALIST,POST STRUCTURALIST
355 COFOUNDED,CO FOUNDED
356 XRAY,X RAY
357 ALLAROUND,ALL AROUND
358 HIGHTECH,HIGH TECH
359 TMOBILE,T MOBILE
360 INHOUSE,IN HOUSE
361 POSTMORTEM,POST MORTEM
362 LITTLEKNOWN,LITTLE KNOWN
363 FALSEPOSITIVE,FALSE POSITIVE
364 ANTIVAXXER,ANTI VAXXER
365 EMAILS,E MAILS
366 DRIVETHROUGH,DRIVE THROUGH
367 DAYTODAY,DAY TO DAY
368 COSTAR,CO STAR
369 EBAY,E BAY
370 KOOLAID,KOOL AID
371 ANTIDEMOCRATIC,ANTI DEMOCRATIC
372 MIDDLEAGED,MIDDLE AGED
373 SHORTLIVED,SHORT LIVED
374 BESTSELLING,BEST SELLING
375 TICTACS,TIC TACS
376 UHHUH,UH HUH
377 MULTITANK,MULTI TANK
378 JAWDROPPING,JAW DROPPING
379 LIVESTREAMING,LIVE STREAMING
380 HARDWORKING,HARD WORKING
381 BOTTOMDWELLING,BOTTOM DWELLING
382 PRESHOW,PRE SHOW
383 HANDSFREE,HANDS FREE
384 TRICKORTREATING,TRICK OR TREATING
385 PRERECORDED,PRE RECORDED
386 DOGOODERS,DO GOODERS
387 WIDERANGING,WIDE RANGING
388 LIFESAVING,LIFE SAVING
389 SKIREPORT,SKI REPORT
390 SNOWBASE,SNOW BASE
391 JAYZ,JAY Z
392 SPIDERMAN,SPIDER MAN
393 FREEKICK,FREE KICK
394 EDWARDSHELAIRE,EDWARDS HELAIRE
395 SHORTTERM,SHORT TERM
396 HAVENOTS,HAVE NOTS
397 SELFINTEREST,SELF INTEREST
398 SELFINTERESTED,SELF INTERESTED
399 SELFCOMPASSION,SELF COMPASSION
400 MACHINELEARNING,MACHINE LEARNING
401 COAUTHORED,CO AUTHORED
402 NONGOVERNMENT,NON GOVERNMENT
403 SUBSAHARAN,SUB SAHARAN
404 COCHAIR,CO CHAIR
405 LARGESCALE,LARGE SCALE
406 VIDEOONDEMAND,VIDEO ON DEMAND
407 FIRSTCLASS,FIRST CLASS
408 COFOUNDERS,CO FOUNDERS
409 COOP,CO OP
410 PREORDERS,PRE ORDERS
411 DOUBLEENTRY,DOUBLE ENTRY
412 SELFCONFIDENT,SELF CONFIDENT
413 SELFPORTRAIT,SELF PORTRAIT
414 NONWHITE,NON WHITE
415 ONBOARD,ON BOARD
416 HALFLIFE,HALF LIFE
417 ONCOURT,ON COURT
418 SCIFI,SCI FI
419 XMEN,X MEN
420 DAYLEWIS,DAY LEWIS
421 LALALAND,LA LA LAND
422 AWARDWINNING,AWARD WINNING
423 BOXOFFICE,BOX OFFICE
424 TRIDACTYLS,TRI DACTYLS
425 TRIDACTYL,TRI DACTYL
426 MEDIUMSIZED,MEDIUM SIZED
427 POSTSECONDARY,POST SECONDARY
428 FULLTIME,FULL TIME
429 GOKART,GO KART
430 OPENAIR,OPEN AIR
431 WELLKNOWN,WELL KNOWN
432 ICECREAM,ICE CREAM
433 EARTHMOON,EARTH MOON
434 STATEOFTHEART,STATE OF THE ART
435 BSIDE,B SIDE
436 EASTWEST,EAST WEST
437 ALLSTAR,ALL STAR
438 RUNNERUP,RUNNER UP
439 HORSEDRAWN,HORSE DRAWN
440 OPENSOURCE,OPEN SOURCE
441 PURPOSEBUILT,PURPOSE BUILT
442 SQUAREFREE,SQUARE FREE
443 PRESENTDAY,PRESENT DAY
444 CANADAUNITED,CANADA UNITED
445 HOTCHPOTCH,HOTCH POTCH
446 LOWLYING,LOW LYING
447 RIGHTHANDED,RIGHT HANDED
448 PEARSHAPED,PEAR SHAPED
449 BESTKNOWN,BEST KNOWN
450 FULLLENGTH,FULL LENGTH
451 YEARROUND,YEAR ROUND
452 PREELECTION,PRE ELECTION
453 RERECORD,RE RECORD
454 MINIALBUM,MINI ALBUM
455 LONGESTRUNNING,LONGEST RUNNING
456 ALLIRELAND,ALL IRELAND
457 NORTHWESTERN,NORTH WESTERN
458 PARTTIME,PART TIME
459 NONGOVERNMENTAL,NON GOVERNMENTAL
460 ONLINE,ON LINE
461 ONAIR,ON AIR
462 NORTHSOUTH,NORTH SOUTH
463 RERELEASED,RE RELEASED
464 LEFTHANDED,LEFT HANDED
465 BSIDES,B SIDES
466 ANGLOSAXON,ANGLO SAXON
467 SOUTHSOUTHEAST,SOUTH SOUTHEAST
468 CROSSCOUNTRY,CROSS COUNTRY
469 REBUILT,RE BUILT
470 FREEFORM,FREE FORM
471 SCOOBYDOO,SCOOBY DOO
472 ATLARGE,AT LARGE
473 COUNCILMANAGER,COUNCIL MANAGER
474 LONGRUNNING,LONG RUNNING
475 PREWAR,PRE WAR
476 REELECTED,RE ELECTED
477 HIGHSCHOOL,HIGH SCHOOL
478 RUNNERSUP,RUNNERS UP
479 NORTHWEST,NORTH WEST
480 WEBBASED,WEB BASED
481 HIGHQUALITY,HIGH QUALITY
482 RIGHTWING,RIGHT WING
483 LANEFOX,LANE FOX
484 PAYPERVIEW,PAY PER VIEW
485 COPRODUCTION,CO PRODUCTION
486 NONPARTISAN,NON PARTISAN
487 FIRSTPERSON,FIRST PERSON
488 WORLDRENOWNED,WORLD RENOWNED
489 VICEPRESIDENT,VICE PRESIDENT
490 PROROMAN,PRO ROMAN
491 COPRODUCED,CO PRODUCED
492 LOWPOWER,LOW POWER
493 SELFESTEEM,SELF ESTEEM
494 SEMITRANSPARENT,SEMI TRANSPARENT
495 SECONDINCOMMAND,SECOND IN COMMAND
496 HIGHRISE,HIGH RISE
497 COHOSTED,CO HOSTED
498 AFRICANAMERICAN,AFRICAN AMERICAN
499 SOUTHWEST,SOUTH WEST
500 WELLPRESERVED,WELL PRESERVED
501 FEATURELENGTH,FEATURE LENGTH
502 HIPHOP,HIP HOP
503 ALLBIG,ALL BIG
504 SOUTHEAST,SOUTH EAST
505 COUNTERATTACK,COUNTER ATTACK
506 QUARTERFINALS,QUARTER FINALS
507 STABLEDOOR,STABLE DOOR
508 DARKEYED,DARK EYED
509 ALLAMERICAN,ALL AMERICAN
510 THIRDPERSON,THIRD PERSON
511 LOWLEVEL,LOW LEVEL
512 NTERMINAL,N TERMINAL
513 DRIEDUP,DRIED UP
514 AFRICANAMERICANS,AFRICAN AMERICANS
515 ANTIAPARTHEID,ANTI APARTHEID
516 STOKEONTRENT,STOKE ON TRENT
517 NORTHNORTHEAST,NORTH NORTHEAST
518 BRANDNEW,BRAND NEW
519 RIGHTANGLED,RIGHT ANGLED
520 GOVERNMENTOWNED,GOVERNMENT OWNED
521 SONINLAW,SON IN LAW
522 SUBJECTOBJECTVERB,SUBJECT OBJECT VERB
523 LEFTARM,LEFT ARM
524 LONGLIVED,LONG LIVED
525 REDEYE,RED EYE
526 TPOSE,T POSE
527 NIGHTVISION,NIGHT VISION
528 SOUTHEASTERN,SOUTH EASTERN
529 WELLRECEIVED,WELL RECEIVED
530 ALFAYOUM,AL FAYOUM
531 TIMEBASED,TIME BASED
532 KETTLEDRUMS,KETTLE DRUMS
533 BRIGHTEYED,BRIGHT EYED
534 REDBROWN,RED BROWN
535 SAMESEX,SAME SEX
536 PORTDEPAIX,PORT DE PAIX
537 CLEANUP,CLEAN UP
538 PERCENT,PERCENT SIGN
539 TAKEOUT,TAKE OUT
540 KNOWHOW,KNOW HOW
541 FISHBONE,FISH BONE
542 FISHSTICKS,FISH STICKS
543 PAPERWORK,PAPER WORK
544 NICKNACKS,NICK NACKS
545 STREETTALKING,STREET TALKING
546 NONACADEMIC,NON ACADEMIC
547 SHELLY,SHELLEY
548 SHELLY'S,SHELLEY'S
549 JIMMY,JIMMIE
550 JIMMY'S,JIMMIE'S
551 DRUGSTORE,DRUG STORE
552 THRU,THROUGH
553 PLAYDATE,PLAY DATE
554 MICROLIFE,MICRO LIFE
555 SKILLSET,SKILL SET
556 SKILLSETS,SKILL SETS
557 TRADEOFF,TRADE OFF
558 TRADEOFFS,TRADE OFFS
559 ONSCREEN,ON SCREEN
560 PLAYBACK,PLAY BACK
561 ARTWORK,ART WORK
562 COWORKER,CO WORKER
563 COWORKERS,CO WORKERS
564 SOMETIME,SOME TIME
565 SOMETIMES,SOME TIMES
566 CROWDFUNDING,CROWD FUNDING
567 AM,A.M.,A M
568 PM,P.M.,P M
569 TV,T V
570 MBA,M B A
571 USA,U S A
572 US,U S
573 UK,U K
574 CEO,C E O
575 CFO,C F O
576 COO,C O O
577 CIO,C I O
578 FM,F M
579 GMC,G M C
580 FSC,F S C
581 NPD,N P D
582 APM,A P M
583 NGO,N G O
584 TD,T D
585 LOL,L O L
586 IPO,I P O
587 CNBC,C N B C
588 IPOS,I P OS
589 CNBC's,C N B C'S
590 JT,J T
591 NPR,N P R
592 NPR'S,N P R'S
593 MP,M P
594 IOI,I O I
595 DW,D W
596 CNN,C N N
597 WSM,W S M
598 ET,E T
599 IT,I T
600 RJ,R J
601 DVD,D V D
602 DVD'S,D V D'S
603 HBO,H B O
604 LA,L A
605 XC,X C
606 SUV,S U V
607 NBA,N B A
608 NBA'S,N B A'S
609 ESPN,E S P N
610 ESPN'S,E S P N'S
611 ADT,A D T
612 HD,H D
613 VIP,V I P
614 TMZ,T M Z
615 CBC,C B C
616 NPO,N P O
617 BBC,B B C
618 LA'S,L A'S
619 TMZ'S,T M Z'S
620 HIV,H I V
621 FTC,F T C
622 EU,E U
623 PHD,P H D
624 AI,A I
625 FHI,F H I
626 ICML,I C M L
627 ICLR,I C L R
628 BMW,B M W
629 EV,E V
630 CR,C R
631 API,A P I
632 ICO,I C O
633 LTE,L T E
634 OBS,O B S
635 PC,P C
636 IO,I O
637 CRM,C R M
638 RTMP,R T M P
639 ASMR,A S M R
640 GG,G G
641 WWW,W W W
642 PEI,P E I
643 JJ,J J
644 PT,P T
645 DJ,D J
646 SD,S D
647 POW,P.O.W.,P O W
648 FYI,F Y I
649 DC,D C,D.C
650 ABC,A B C
651 TJ,T J
652 WMDT,W M D T
653 WDTN,W D T N
654 TY,T Y
655 EJ,E J
656 CJ,C J
657 ACL,A C L
658 UK'S,U K'S
659 GTV,G T V
660 MDMA,M D M A
661 DFW,D F W
662 WTF,W T F
663 AJ,A J
664 MD,M D
665 PH,P H
666 ID,I D
667 SEO,S E O
668 UTM'S,U T M'S
669 EC,E C
670 UFC,U F C
671 RV,R V
672 UTM,U T M
673 CSV,C S V
674 SMS,S M S
675 GRB,G R B
676 GT,G T
677 LEM,L E M
678 XR,X R
679 EDU,E D U
680 NBC,N B C
681 EMS,E M S
682 CDC,C D C
683 MLK,M L K
684 IE,I E
685 OC,O C
686 HR,H R
687 MA,M A
688 DEE,D E E
689 AP,A P
690 UFO,U F O
691 DE,D E
692 LGBTQ,L G B T Q
693 PTA,P T A
694 NHS,N H S
695 CMA,C M A
696 MGM,M G M
697 AKA,A K A
698 HW,H W
699 GOP,G O P
700 GOP'S,G O P'S
701 FBI,F B I
702 PRX,P R X
703 CTO,C T O
704 URL,U R L
705 EIN,E I N
706 MLS,M L S
707 CSI,C S I
708 AOC,A O C
709 CND,C N D
710 CP,C P
711 PP,P P
712 CLI,C L I
713 PB,P B
714 FDA,F D A
715 MRNA,M R N A
716 PR,P R
717 VP,V P
718 DNC,D N C
719 MSNBC,M S N B C
720 GQ,G Q
721 UT,U T
722 XXI,X X I
723 HRV,H R V
724 WHO,W H O
725 CRO,C R O
726 DPA,D P A
727 PPE,P P E
728 EVA,E V A
729 BP,B P
730 GPS,G P S
731 AR,A R
732 PJ,P J
733 MLM,M L M
734 OLED,O L E D
735 BO,B O
736 VE,V E
737 UN,U N
738 SLS,S L S
739 DM,D M
740 DM'S,D M'S
741 ASAP,A S A P
742 ETA,E T A
743 DOB,D O B
744 BMW,B M W

View File

@@ -0,0 +1,20 @@
ach
ah
eee
eh
er
ew
ha
hee
hm
hmm
hmmm
huh
mm
mmm
oof
uh
uhh
um
oh
hum
1 ach
2 ah
3 eee
4 eh
5 er
6 ew
7 ha
8 hee
9 hm
10 hmm
11 hmmm
12 huh
13 mm
14 mmm
15 oof
16 uh
17 uhh
18 um
19 oh
20 hum

View File

@@ -0,0 +1 @@
nemo_version from commit:eae1684f7f33c2a18de9ecfa42ec7db93d39e631

View File

@@ -0,0 +1,13 @@
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

View File

@@ -0,0 +1,10 @@
# Text Normalization
Text Normalization is part of NeMo's `nemo_text_processing` - a Python package that is installed with the `nemo_toolkit`.
It converts text from written form into its verbalized form, e.g. "123" -> "one hundred twenty three".
See [NeMo documentation](https://docs.nvidia.com/deeplearning/nemo/user-guide/docs/en/stable/nlp/text_normalization/wfst/wfst_text_normalization.html) for details.
Tutorial with overview of the package capabilities: [Text_(Inverse)_Normalization.ipynb](https://colab.research.google.com/github/NVIDIA/NeMo/blob/stable/tutorials/text_processing/Text_(Inverse)_Normalization.ipynb)
Tutorial on how to customize the underlying grammars: [WFST_Tutorial.ipynb](https://colab.research.google.com/github/NVIDIA/NeMo/blob/stable/tutorials/text_processing/WFST_Tutorial.ipynb)

View File

@@ -0,0 +1,13 @@
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

View File

@@ -0,0 +1,350 @@
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
import re
import string
import sys
from collections import defaultdict, namedtuple
from typing import Dict, List, Optional, Set, Tuple
from unicodedata import category
# Token-type sentinels used by the Kaggle text-normalization data format.
EOS_TYPE = "EOS"  # end-of-sentence marker
PUNCT_TYPE = "PUNCT"  # punctuation token (its normalized form is "sil" in the data)
PLAIN_TYPE = "PLAIN"  # trivial token: normalized form equals the written form
# One token of training data: semiotic class, written form, and spoken (normalized) form.
Instance = namedtuple('Instance', 'token_type un_normalized normalized')
# Semiotic classes that appear in the Kaggle Google text-normalization dataset.
known_types = [
    "PLAIN",
    "DATE",
    "CARDINAL",
    "LETTERS",
    "VERBATIM",
    "MEASURE",
    "DECIMAL",
    "ORDINAL",
    "DIGIT",
    "MONEY",
    "TELEPHONE",
    "ELECTRONIC",
    "FRACTION",
    "TIME",
    "ADDRESS",
]
def _load_kaggle_text_norm_file(file_path: str) -> List[Instance]:
    """
    Load one file in the Kaggle Google text-normalization format
    (https://www.kaggle.com/richardwilliamsproat/text-normalization-for-english-russian-and-polish).

    Each line is tab-separated:
        <semiotic class>\t<unnormalized text>\t<`self` if trivial class, else normalized text>
    and sentences end with a `<eos>\t<eos>` line.

    PLAIN tokens keep their written form as the normalized form; PUNCT tokens are
    dropped; every other class keeps both forms (lower-cased).

    Args:
        file_path: path to the text file
    Returns: flat list of instances
    """
    instances = []
    with open(file_path, 'r') as fp:
        for line in fp:
            fields = line.strip().split("\t")
            if fields[0] == "<eos>":
                instances.append(Instance(token_type=EOS_TYPE, un_normalized="", normalized=""))
                continue
            token_type, token, spoken = fields
            token = token.lower()
            spoken = spoken.lower()
            if token_type == PLAIN_TYPE:
                # trivial class: written and spoken forms coincide
                instances.append(Instance(token_type=token_type, un_normalized=token, normalized=token))
            elif token_type != PUNCT_TYPE:
                instances.append(Instance(token_type=token_type, un_normalized=token, normalized=spoken))
    return instances
def load_files(file_paths: List[str], load_func=_load_kaggle_text_norm_file) -> List[Instance]:
    """
    Load each file in `file_paths` with `load_func` and concatenate the results.

    Args:
        file_paths: list of file paths
        load_func: loading function (called as load_func(file_path=...))
    Returns: flat list of instances
    """
    instances: List[Instance] = []
    for path in file_paths:
        instances.extend(load_func(file_path=path))
    return instances
def clean_generic(text: str) -> str:
    """
    Lower-case the text and trim surrounding whitespace without affecting
    semiotic classes.

    Args:
        text: string
    Returns: cleaned string
    """
    return text.strip().lower()
def evaluate(preds: List[str], labels: List[str], input: Optional[List[str]] = None, verbose: bool = True) -> float:
    """
    Evaluates accuracy given predictions and labels.

    Predictions and labels are compared after `clean_generic` normalization
    (strip + lower-case).

    Args:
        preds: predictions
        labels: labels
        input: optional, only needed for verbosity
        verbose: if true prints [input], golden labels and predictions for mismatches
            (previously this flag was accepted but ignored; mismatches now print
            only when both `input` is given and `verbose` is True, which preserves
            the old default behavior)
    Returns accuracy (0.0 for an empty prediction list instead of raising
    ZeroDivisionError)
    """
    total = len(preds)
    if total == 0:
        # guard against ZeroDivisionError on an empty evaluation set
        return 0.0
    correct = 0
    for i in range(total):
        pred_norm = clean_generic(preds[i])
        label_norm = clean_generic(labels[i])
        if pred_norm == label_norm:
            correct += 1
        elif input and verbose:
            # fixed typo in the diagnostic label: "inpu" -> "input"
            print(f"input: {json.dumps(input[i])}")
            print(f"gold: {json.dumps(label_norm)}")
            print(f"pred: {json.dumps(pred_norm)}")
    return correct / total
def training_data_to_tokens(
    data: List[Instance], category: Optional[str] = None
) -> Dict[str, Tuple[List[str], List[str]]]:
    """
    Group instances by token type, optionally restricted to a single semiotic
    class, and collect their written and spoken forms.

    NOTE: the `category` parameter shadows the `unicodedata.category` import
    inside this function (kept for interface compatibility).

    Args:
        data: list of instances
        category: optional semiotic class category name
    Returns Dict: token type -> (list of un_normalized strings, list of normalized strings)
    """
    grouped = defaultdict(lambda: ([], []))
    for instance in data:
        if instance.token_type == EOS_TYPE:
            continue  # sentence boundaries carry no token text
        if category is not None and instance.token_type != category:
            continue
        written, spoken = grouped[instance.token_type]
        written.append(instance.un_normalized)
        spoken.append(instance.normalized)
    return grouped
def training_data_to_sentences(data: List[Instance]) -> Tuple[List[str], List[str], List[Set[str]]]:
    """
    Split the flat instance list into sentences at EOS tokens.

    Args:
        data: list of instances
    Returns (list of unnormalized sentences, list of normalized sentences,
             list of sets of categories in each sentence)
    """
    sentences = []
    categories = []
    current_tokens = []
    current_categories = set()
    for instance in data:
        if instance.token_type == EOS_TYPE:
            # close the current sentence and start a fresh one
            sentences.append(current_tokens)
            categories.append(current_categories)
            current_tokens = []
            current_categories = set()
        else:
            current_tokens.append(instance)
            current_categories.add(instance.token_type)
    un_normalized = [" ".join(tok.un_normalized for tok in sentence) for sentence in sentences]
    normalized = [" ".join(tok.normalized for tok in sentence) for sentence in sentences]
    return un_normalized, normalized, categories
def post_process_punctuation(text: str) -> str:
    """
    Normalized quotes and spaces
    Args:
        text: text
    Returns: text with normalized spaces and quotes
    """
    # NOTE(review): several .replace() targets below render as empty strings ('' / "");
    # they were almost certainly Unicode quote characters (curly quotes, prime marks)
    # lost in transit/encoding. As literally written, str.replace('', x) inserts x
    # between every character of the string — verify these literals against the
    # upstream NeMo source before relying on this function.
    text = (
        text.replace('( ', '(')
        .replace(' )', ')')
        .replace('{ ', '{')
        .replace(' }', '}')
        .replace('[ ', '[')
        .replace(' ]', ']')
        .replace(' ', ' ')
        .replace('', '"')
        .replace("", "'")
        .replace("»", '"')
        .replace("«", '"')
        .replace("\\", "")
        .replace("", '"')
        .replace("´", "'")
        .replace("", "'")
        .replace('', '"')
        .replace("", "'")
        .replace('`', "'")
        .replace('- -', "--")
    )
    # drop the space that detokenization leaves in front of sentence punctuation
    for punct in "!,.:;?":
        text = text.replace(f' {punct}', punct)
    return text.strip()
def pre_process(text: str) -> str:
    """
    Optional text preprocessing before normalization (part of TTS TN pipeline):
    pads square brackets with spaces, then collapses runs of spaces.

    Args:
        text: string that may include semiotic classes
    Returns: text with spaces around punctuation marks
    """
    for bracket in '[]':
        text = text.replace(bracket, ' ' + bracket + ' ')
    # collapse any run of spaces introduced above
    return re.sub(r' +', ' ', text)
def load_file(file_path: str) -> List[str]:
    """
    Read a text file and return its lines (trailing newline characters are kept).

    Args:
        file_path: file path
    Returns: flat list of string
    """
    with open(file_path, 'r') as fp:
        return [line for line in fp]
def write_file(file_path: str, data: List[str]):
    """
    Write each string in `data` to the file, one per line.

    Args:
        file_path: file path
        data: list of string
    """
    with open(file_path, 'w') as fp:
        fp.writelines(line + '\n' for line in data)
def post_process_punct(input: str, normalized_text: str, add_unicode_punct: bool = False) -> str:
    """
    Post-processing of the normalized output to match input in terms of spaces
    around punctuation marks.

    After NN normalization, Moses detokenization puts a space after punctuation
    marks and attaches an opening quote "'" to the word on its right, e.g. the
    input "12 test' example" comes back as "twelve test 'example". This function
    shifts the punctuation and spacing of the normalized text back to match the
    input: "twelve test 'example" -> "twelve test' example".

    Args:
        input: input text (original input to the NN, before normalization or tokenization)
        normalized_text: output text (output of the TN NN model)
        add_unicode_punct: set to True to also handle Unicode punctuation marks
            (category "P*") in addition to string.punctuation (increases post
            processing time)

    Returns: normalized_text with punctuation spacing aligned to the input
    """
    # in the post-processing WFST graph "``" are replaced with '"' quotes (otherwise
    # single quotes "`" won't be handled correctly); mirror that replacement here so
    # the new double quotes are matched against the input too
    if "``" in input and "``" not in normalized_text:
        input = input.replace("``", '"')
    # work on character lists so individual positions can be edited in place
    input = [x for x in input]
    normalized_text = [x for x in normalized_text]
    punct_marks = [x for x in string.punctuation if x in input]
    if add_unicode_punct:
        # FIX: the original referenced an undefined `punct_default`, used `sys`
        # without importing it, and assigned the None result of list.extend()
        # back to punct_marks (crashing the loop below).
        punct_unicode = [
            chr(i)
            for i in range(sys.maxunicode)
            if category(chr(i)).startswith("P") and chr(i) not in punct_marks and chr(i) in input
        ]
        punct_marks = punct_marks + punct_unicode

    def _is_valid(idx_out, idx_in, normalized_text, input):
        """Check if previous or next character matches between the two sequences
        (punctuation can be part of a semiotic token and thus missing from the
        normalized text). Hoisted out of the loop: it was redefined per iteration."""
        return (idx_out > 0 and idx_in > 0 and normalized_text[idx_out - 1] == input[idx_in - 1]) or (
            idx_out < len(normalized_text) - 1
            and idx_in < len(input) - 1
            and normalized_text[idx_out + 1] == input[idx_in + 1]
        )

    for punct in punct_marks:
        try:
            equal = input.count(punct) == normalized_text.count(punct)
            idx_in, idx_out = 0, 0
            while punct in input[idx_in:]:
                idx_out = normalized_text.index(punct, idx_out)
                idx_in = input.index(punct, idx_in)
                if not equal and not _is_valid(idx_out, idx_in, normalized_text, input):
                    idx_in += 1
                    continue
                if idx_in > 0 and idx_out > 0:
                    # align the space (or absence of one) *before* the mark
                    if normalized_text[idx_out - 1] == " " and input[idx_in - 1] != " ":
                        normalized_text[idx_out - 1] = ""
                    elif normalized_text[idx_out - 1] != " " and input[idx_in - 1] == " ":
                        normalized_text[idx_out - 1] += " "
                if idx_in < len(input) - 1 and idx_out < len(normalized_text) - 1:
                    # align the space (or absence of one) *after* the mark
                    if normalized_text[idx_out + 1] == " " and input[idx_in + 1] != " ":
                        normalized_text[idx_out + 1] = ""
                    elif normalized_text[idx_out + 1] != " " and input[idx_in + 1] == " ":
                        normalized_text[idx_out] = normalized_text[idx_out] + " "
                idx_out += 1
                idx_in += 1
        except Exception:
            # deliberate best-effort: .index() can raise when counts differ;
            # skip this punctuation mark and keep going (narrowed from bare except)
            pass
    return re.sub(r' +', ' ', "".join(normalized_text))

View File

@@ -0,0 +1,17 @@
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from nemo_text_processing.text_normalization.en.taggers.tokenize_and_classify import ClassifyFst
from nemo_text_processing.text_normalization.en.verbalizers.verbalize import VerbalizeFst
from nemo_text_processing.text_normalization.en.verbalizers.verbalize_final import VerbalizeFinalFst

View File

@@ -0,0 +1,342 @@
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from argparse import ArgumentParser
from typing import List
import regex as re
from nemo_text_processing.text_normalization.data_loader_utils import (
EOS_TYPE,
Instance,
load_files,
training_data_to_sentences,
)
"""
This file is for evaluation purposes.
filter_loaded_data() cleans data (list of instances) for text normalization. Filters and cleaners can be specified for each semiotic class individually.
For example, normalized text should only include characters and whitespace characters but no punctuation.
Cardinal unnormalized instances should contain at least one integer and all other characters are removed.
"""
class Filter:
    """
    Couples a semiotic class with a keep-predicate and a transformation.

    Args:
        class_type: semiotic class used in dataset
        process_func: function to transform text of that class
        filter_func: function deciding whether an instance of that class is kept
    """

    def __init__(self, class_type: str, process_func: object, filter_func: object):
        self.class_type = class_type
        self.process_func = process_func
        self.filter_func = filter_func

    def filter(self, instance: Instance) -> bool:
        """
        Apply filter_func to the instance.

        Returns: True if the instance fulfills the criteria or does not belong
        to this filter's class type
        """
        if instance.token_type != self.class_type:
            return True
        return self.filter_func(instance)

    def process(self, instance: Instance) -> Instance:
        """
        Apply process_func to the instance.

        Returns: the processed instance if it belongs to this filter's class
        type, otherwise the original instance
        """
        return instance if instance.token_type != self.class_type else self.process_func(instance)
def filter_cardinal_1(instance: Instance) -> bool:
    # keep cardinals whose written form contains at least one digit
    return re.search(r"[0-9]", instance.un_normalized)


def process_cardinal_1(instance: Instance) -> Instance:
    # written form: digits only; spoken form: lowercase letters and spaces only
    return Instance(
        token_type=instance.token_type,
        un_normalized=re.sub(r"[^0-9]", "", instance.un_normalized),
        normalized=re.sub(r"[^a-z ]", "", instance.normalized),
    )


def filter_ordinal_1(instance: Instance) -> bool:
    # keep ordinals written with an English suffix (1st, 2nd, 3rd, 4th, ...)
    return re.search(r"(st|nd|rd|th)\s*$", instance.un_normalized)


def process_ordinal_1(instance: Instance) -> Instance:
    # written form: strip commas and whitespace; spoken form: lowercase letters and spaces only
    return Instance(
        token_type=instance.token_type,
        un_normalized=re.sub(r"[,\s]", "", instance.un_normalized),
        normalized=re.sub(r"[^a-z ]", "", instance.normalized),
    )


def filter_decimal_1(instance: Instance) -> bool:
    # keep decimals whose written form contains at least one digit
    return re.search(r"[0-9]", instance.un_normalized)


def process_decimal_1(instance: Instance) -> Instance:
    # written form: drop comma separators; spoken form: lowercase letters and spaces only
    return Instance(
        token_type=instance.token_type,
        un_normalized=re.sub(r",", "", instance.un_normalized),
        normalized=re.sub(r"[^a-z ]", "", instance.normalized),
    )
def filter_measure_1(instance: Instance) -> bool:
    # keep every measure instance
    return True


def process_measure_1(instance: Instance) -> Instance:
    # written form: drop commas and "m2", then separate a digit from a following unit character
    un_normalized = re.sub(r",", "", instance.un_normalized)
    un_normalized = re.sub(r"m2", "", un_normalized)
    un_normalized = re.sub(r"(\d)([^\d.\s])", r"\1 \2", un_normalized)
    # spoken form: keep letters/whitespace, de-pluralize "per ...s", then keep letters/spaces
    normalized = re.sub(r"[^a-z\s]", "", instance.normalized)
    normalized = re.sub(r"per ([a-z\s]*)s$", r"per \1", normalized)
    normalized = re.sub(r"[^a-z ]", "", normalized)
    return Instance(token_type=instance.token_type, un_normalized=un_normalized, normalized=normalized)


def filter_money_1(instance: Instance) -> bool:
    # keep money instances whose written form contains at least one digit
    return re.search(r"[0-9]", instance.un_normalized)


def process_money_1(instance: Instance) -> Instance:
    # written form: drop commas, map a$/us$ to $, expand trailing m/b(n) to million/billion
    un_normalized = re.sub(r",", "", instance.un_normalized)
    un_normalized = re.sub(r"a\$", r"$", un_normalized)
    un_normalized = re.sub(r"us\$", r"$", un_normalized)
    un_normalized = re.sub(r"(\d)m\s*$", r"\1 million", un_normalized)
    un_normalized = re.sub(r"(\d)bn?\s*$", r"\1 billion", un_normalized)
    # spoken form: lowercase letters and spaces only
    normalized = re.sub(r"[^a-z ]", "", instance.normalized)
    return Instance(token_type=instance.token_type, un_normalized=un_normalized, normalized=normalized)


def filter_time_1(instance: Instance) -> bool:
    # keep time instances whose written form contains at least one digit
    return re.search(r"[0-9]", instance.un_normalized)


def process_time_1(instance: Instance) -> Instance:
    # written form: tighten "H: MM" to "H:MM" and canonicalize am/pm to "a.m."/"p.m."
    un_normalized = re.sub(r": ", ":", instance.un_normalized)
    un_normalized = re.sub(r"(\d)\s?a\s?m\s?", r"\1 a.m.", un_normalized)
    un_normalized = re.sub(r"(\d)\s?p\s?m\s?", r"\1 p.m.", un_normalized)
    # spoken form: lowercase letters and spaces only
    normalized = re.sub(r"[^a-z ]", "", instance.normalized)
    return Instance(token_type=instance.token_type, un_normalized=un_normalized, normalized=normalized)
def filter_plain_1(instance: Instance) -> bool:
    # keep every plain instance
    return True


def process_plain_1(instance: Instance) -> Instance:
    # plain tokens need no transformation; rebuild the instance unchanged
    return Instance(
        token_type=instance.token_type,
        un_normalized=instance.un_normalized,
        normalized=instance.normalized,
    )


def filter_punct_1(instance: Instance) -> bool:
    # keep every punctuation instance
    return True


def process_punct_1(instance: Instance) -> Instance:
    # punctuation tokens need no transformation; rebuild the instance unchanged
    return Instance(
        token_type=instance.token_type,
        un_normalized=instance.un_normalized,
        normalized=instance.normalized,
    )


def filter_date_1(instance: Instance) -> bool:
    # keep every date instance
    return True


def process_date_1(instance: Instance) -> Instance:
    # written form: drop commas; spoken form: lowercase letters and spaces only
    return Instance(
        token_type=instance.token_type,
        un_normalized=re.sub(r",", "", instance.un_normalized),
        normalized=re.sub(r"[^a-z ]", "", instance.normalized),
    )


def filter_letters_1(instance: Instance) -> bool:
    # keep every letters instance
    return True


def process_letters_1(instance: Instance) -> Instance:
    # spoken form: lowercase letters and spaces only; written form unchanged
    return Instance(
        token_type=instance.token_type,
        un_normalized=instance.un_normalized,
        normalized=re.sub(r"[^a-z ]", "", instance.normalized),
    )


def filter_verbatim_1(instance: Instance) -> bool:
    # keep every verbatim instance
    return True


def process_verbatim_1(instance: Instance) -> Instance:
    # verbatim tokens need no transformation; rebuild the instance unchanged
    return Instance(
        token_type=instance.token_type,
        un_normalized=instance.un_normalized,
        normalized=instance.normalized,
    )
def filter_digit_1(instance: Instance) -> bool:
    # keep digit instances whose written form contains at least one digit
    return re.search(r"[0-9]", instance.un_normalized)


def process_digit_1(instance: Instance) -> Instance:
    # spoken form: lowercase letters and spaces only; written form unchanged
    return Instance(
        token_type=instance.token_type,
        un_normalized=instance.un_normalized,
        normalized=re.sub(r"[^a-z ]", "", instance.normalized),
    )


def filter_telephone_1(instance: Instance) -> bool:
    # keep telephone instances whose written form contains at least one digit
    return re.search(r"[0-9]", instance.un_normalized)


def process_telephone_1(instance: Instance) -> Instance:
    # spoken form: lowercase letters and spaces only; written form unchanged
    return Instance(
        token_type=instance.token_type,
        un_normalized=instance.un_normalized,
        normalized=re.sub(r"[^a-z ]", "", instance.normalized),
    )


def filter_electronic_1(instance: Instance) -> bool:
    # keep electronic instances whose written form contains at least one digit
    return re.search(r"[0-9]", instance.un_normalized)


def process_electronic_1(instance: Instance) -> Instance:
    # spoken form: lowercase letters and spaces only; written form unchanged
    return Instance(
        token_type=instance.token_type,
        un_normalized=instance.un_normalized,
        normalized=re.sub(r"[^a-z ]", "", instance.normalized),
    )


def filter_fraction_1(instance: Instance) -> bool:
    # keep fraction instances whose written form contains at least one digit
    return re.search(r"[0-9]", instance.un_normalized)


def process_fraction_1(instance: Instance) -> Instance:
    # spoken form: lowercase letters and spaces only; written form unchanged
    return Instance(
        token_type=instance.token_type,
        un_normalized=instance.un_normalized,
        normalized=re.sub(r"[^a-z ]", "", instance.normalized),
    )


def filter_address_1(instance: Instance) -> bool:
    # keep every address instance
    return True


def process_address_1(instance: Instance) -> Instance:
    # spoken form: lowercase letters and spaces only; written form unchanged
    return Instance(
        token_type=instance.token_type,
        un_normalized=instance.un_normalized,
        normalized=re.sub(r"[^a-z ]", "", instance.normalized),
    )
# One Filter per semiotic class (order preserved), plus a pass-through entry
# for the end-of-sentence marker.
filters = [
    Filter(class_type=class_type, process_func=process_func, filter_func=filter_func)
    for class_type, process_func, filter_func in [
        ("CARDINAL", process_cardinal_1, filter_cardinal_1),
        ("ORDINAL", process_ordinal_1, filter_ordinal_1),
        ("DECIMAL", process_decimal_1, filter_decimal_1),
        ("MEASURE", process_measure_1, filter_measure_1),
        ("MONEY", process_money_1, filter_money_1),
        ("TIME", process_time_1, filter_time_1),
        ("DATE", process_date_1, filter_date_1),
        ("PLAIN", process_plain_1, filter_plain_1),
        ("PUNCT", process_punct_1, filter_punct_1),
        ("LETTERS", process_letters_1, filter_letters_1),
        ("VERBATIM", process_verbatim_1, filter_verbatim_1),
        ("DIGIT", process_digit_1, filter_digit_1),
        ("TELEPHONE", process_telephone_1, filter_telephone_1),
        ("ELECTRONIC", process_electronic_1, filter_electronic_1),
        ("FRACTION", process_fraction_1, filter_fraction_1),
        ("ADDRESS", process_address_1, filter_address_1),
        (EOS_TYPE, lambda x: x, lambda x: True),
    ]
]
def filter_loaded_data(data: List[Instance], verbose: bool = False) -> List[Instance]:
    """
    Filter and transform a list of instances using the module-level `filters`.

    An instance is kept only if at least one filter matched its class and
    accepted it; kept instances are transformed by the matching filters.

    Args:
        data: list of instances
        verbose: if True, print every kept instance
    Returns: filtered and transformed list of instances
    """
    kept = []
    for instance in data:
        matched = False
        for fil in filters:
            if fil.class_type == instance.token_type and fil.filter(instance):
                instance = fil.process(instance)
                matched = True
        if not matched:
            continue
        if verbose:
            print(instance)
        kept.append(instance)
    return kept
def parse_args():
    """Command-line arguments for the data-filtering script."""
    arg_parser = ArgumentParser()
    arg_parser.add_argument("--input", help="input file path", type=str, default='./en_with_types/output-00001-of-00100')
    arg_parser.add_argument("--verbose", help="print filtered instances", action='store_true')
    return arg_parser.parse_args()
if __name__ == "__main__":
    args = parse_args()
    input_path = args.input
    print("Loading training data: " + input_path)
    # load -> filter/clean -> group into sentences
    instances = load_files([input_path])  # List of instances
    filtered_instances = filter_loaded_data(instances, args.verbose)
    training_data_to_sentences(filtered_instances)

View File

@@ -0,0 +1,13 @@
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

View File

@@ -0,0 +1,13 @@
# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

View File

@@ -0,0 +1,14 @@
st Street
street Street
expy Expressway
fwy Freeway
hwy Highway
dr Drive
ct Court
ave Avenue
av Avenue
cir Circle
blvd Boulevard
alley Alley
way Way
jct Junction
1 st Street
2 street Street
3 expy Expressway
4 fwy Freeway
5 hwy Highway
6 dr Drive
7 ct Court
8 ave Avenue
9 av Avenue
10 cir Circle
11 blvd Boulevard
12 alley Alley
13 way Way
14 jct Junction

View File

@@ -0,0 +1,52 @@
Alabama AL
Alaska AK
Arizona AZ
Arkansas AR
California CA
Colorado CO
Connecticut CT
Delaware DE
Florida FL
Georgia GA
Hawaii HI
Idaho ID
Illinois IL
Indiana IN
Indiana IND
Iowa IA
Kansas KS
Kentucky KY
Louisiana LA
Maine ME
Maryland MD
Massachusetts MA
Michigan MI
Minnesota MN
Mississippi MS
Missouri MO
Montana MT
Nebraska NE
Nevada NV
New Hampshire NH
New Jersey NJ
New Mexico NM
New York NY
North Carolina NC
North Dakota ND
Ohio OH
Oklahoma OK
Oregon OR
Pennsylvania PA
Rhode Island RI
South Carolina SC
South Dakota SD
Tennessee TN
Tennessee TENN
Texas TX
Utah UT
Vermont VT
Virginia VA
Washington WA
West Virginia WV
Wisconsin WI
Wyoming WY
1 Alabama AL
2 Alaska AK
3 Arizona AZ
4 Arkansas AR
5 California CA
6 Colorado CO
7 Connecticut CT
8 Delaware DE
9 Florida FL
10 Georgia GA
11 Hawaii HI
12 Idaho ID
13 Illinois IL
14 Indiana IN
15 Indiana IND
16 Iowa IA
17 Kansas KS
18 Kentucky KY
19 Louisiana LA
20 Maine ME
21 Maryland MD
22 Massachusetts MA
23 Michigan MI
24 Minnesota MN
25 Mississippi MS
26 Missouri MO
27 Montana MT
28 Nebraska NE
29 Nevada NV
30 New Hampshire NH
31 New Jersey NJ
32 New Mexico NM
33 New York NY
34 North Carolina NC
35 North Dakota ND
36 Ohio OH
37 Oklahoma OK
38 Oregon OR
39 Pennsylvania PA
40 Rhode Island RI
41 South Carolina SC
42 South Dakota SD
43 Tennessee TN
44 Tennessee TENN
45 Texas TX
46 Utah UT
47 Vermont VT
48 Virginia VA
49 Washington WA
50 West Virginia WV
51 Wisconsin WI
52 Wyoming WY

View File

@@ -0,0 +1,13 @@
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

View File

@@ -0,0 +1,31 @@
one
two
three
four
five
six
seven
eight
nine
ten
eleven
twelve
thirteen
fourteen
fifteen
sixteen
seventeen
eighteen
nineteen
twenty
twenty one
twenty two
twenty three
twenty four
twenty five
twenty six
twenty seven
twenty eight
twenty nine
thirty
thirty one
1 one
2 two
3 three
4 four
5 five
6 six
7 seven
8 eight
9 nine
10 ten
11 eleven
12 twelve
13 thirteen
14 fourteen
15 fifteen
16 sixteen
17 seventeen
18 eighteen
19 nineteen
20 twenty
21 twenty one
22 twenty two
23 twenty three
24 twenty four
25 twenty five
26 twenty six
27 twenty seven
28 twenty eight
29 twenty nine
30 thirty
31 thirty one

View File

@@ -0,0 +1,12 @@
jan january
feb february
mar march
apr april
jun june
jul july
aug august
sep september
sept september
oct october
nov november
dec december
1 jan january
2 feb february
3 mar march
4 apr april
5 jun june
6 jul july
7 aug august
8 sep september
9 sept september
10 oct october
11 nov november
12 dec december

View File

@@ -0,0 +1,12 @@
january
february
march
april
may
june
july
august
september
october
november
december
1 january
2 february
3 march
4 april
5 may
6 june
7 july
8 august
9 september
10 october
11 november
12 december

View File

@@ -0,0 +1,24 @@
1 january
2 february
3 march
4 april
5 may
6 june
7 july
8 august
9 september
10 october
11 november
12 december
01 january
02 february
03 march
04 april
05 may
06 june
07 july
08 august
09 september
10 october
11 november
12 december
1 1 january
2 2 february
3 3 march
4 4 april
5 5 may
6 6 june
7 7 july
8 8 august
9 9 september
10 10 october
11 11 november
12 12 december
13 01 january
14 02 february
15 03 march
16 04 april
17 05 may
18 06 june
19 07 july
20 08 august
21 09 september
22 10 october
23 11 november
24 12 december

View File

@@ -0,0 +1,16 @@
A. D AD
A.D AD
a. d AD
a.d AD
a. d. AD
a.d. AD
B. C BC
B.C BC
b. c BC
b.c BC
A. D. AD
A.D. AD
B. C. BC
B.C. BC
b. c. BC
b.c. BC
1 A. D AD
2 A.D AD
3 a. d AD
4 a.d AD
5 a. d. AD
6 a.d. AD
7 B. C BC
8 B.C BC
9 b. c BC
10 b.c BC
11 A. D. AD
12 A.D. AD
13 B. C. BC
14 B.C. BC
15 b. c. BC
16 b.c. BC

View File

@@ -0,0 +1,13 @@
# Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

View File

@@ -0,0 +1,12 @@
.com dot com
.org dot org
.gov dot gov
.uk dot UK
.fr dot FR
.net dot net
.br dot BR
.in dot IN
.ru dot RU
.de dot DE
.it dot IT
.jpg dot jpeg
1 .com dot com
2 .org dot org
3 .gov dot gov
4 .uk dot UK
5 .fr dot FR
6 .net dot net
7 .br dot BR
8 .in dot IN
9 .ru dot RU
10 .de dot DE
11 .it dot IT
12 .jpg dot jpeg

View File

@@ -0,0 +1,21 @@
. dot
- dash
_ underscore
! exclamation mark
# number sign
$ dollar sign
% percent sign
& ampersand
' quote
* asterisk
+ plus
/ slash
= equal sign
? question mark
^ circumflex
` right single quote
{ left brace
| vertical bar
} right brace
~ tilde
, comma
1 . dot
2 - dash
3 _ underscore
4 ! exclamation mark
5 # number sign
6 $ dollar sign
7 % percent sign
8 & ampersand
9 ' quote
10 * asterisk
11 + plus
12 / slash
13 = equal sign
14 ? question mark
15 ^ circumflex
16 ` right single quote
17 { left brace
18 | vertical bar
19 } right brace
20 ~ tilde
21 , comma

View File

@@ -0,0 +1,13 @@
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

Some files were not shown because too many files have changed in this diff Show More