diff --git a/.github/workflows/pre-commit.yml b/.github/workflows/pre-commit.yml
new file mode 100644
index 0000000000..55dd8360e1
--- /dev/null
+++ b/.github/workflows/pre-commit.yml
@@ -0,0 +1,21 @@
+name: pre-commit
+
+on:
+ pull_request:
+ push:
+ branches: ['main', 'release/*']
+
+jobs:
+ pre-commit:
+ runs-on: ubuntu-latest
+ steps:
+ - uses: actions/checkout@v3
+ - uses: actions/setup-python@v3
+ with:
+ python-version: '3.10'
+ # Install Dependencies for Python
+ - name: Install Dependencies for Python
+ run: |
+ python -m pip install --upgrade pip
+ pip install "clang-format==13.0.0"
+ - uses: pre-commit/action@v3.0.1
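
The workflow above assumes a `.pre-commit-config.yaml` at the repository root, which is what pre-commit/action runs; the explicit `clang-format==13.0.0` pip pin suggests the config invokes the pip-installed binary through a local hook. A minimal sketch of such a config (hook ids, revs, and the file pattern are assumptions, not taken from this PR):

    # .pre-commit-config.yaml -- hypothetical sketch, not part of this diff
    repos:
      - repo: https://github.com/psf/black
        rev: 23.3.0                # assumed pin; the PR does not state one
        hooks:
          - id: black
      - repo: local
        hooks:
          - id: clang-format       # uses the clang-format==13.0.0 installed above
            name: clang-format
            entry: clang-format -i
            language: system
            files: \.(c|cc|cpp|cu|h|hpp)$

Running `pre-commit run --all-files` locally reproduces the CI check.
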
diff --git a/PPOCRLabel/PPOCRLabel.py b/PPOCRLabel/PPOCRLabel.py
index a5c74a24d5..62991a68ff 100644
--- a/PPOCRLabel/PPOCRLabel.py
+++ b/PPOCRLabel/PPOCRLabel.py
@@ -26,17 +26,47 @@
import xlrd
from functools import partial
-from PyQt5.QtCore import QSize, Qt, QPoint, QByteArray, QTimer, QFileInfo, QPointF, QProcess
+from PyQt5.QtCore import (
+ QSize,
+ Qt,
+ QPoint,
+ QByteArray,
+ QTimer,
+ QFileInfo,
+ QPointF,
+ QProcess,
+)
from PyQt5.QtGui import QImage, QCursor, QPixmap, QImageReader
-from PyQt5.QtWidgets import QMainWindow, QListWidget, QVBoxLayout, QToolButton, QHBoxLayout, QDockWidget, QWidget, \
- QSlider, QGraphicsOpacityEffect, QMessageBox, QListView, QScrollArea, QWidgetAction, QApplication, QLabel, QGridLayout, \
- QFileDialog, QListWidgetItem, QComboBox, QDialog, QAbstractItemView, QSizePolicy
+from PyQt5.QtWidgets import (
+ QMainWindow,
+ QListWidget,
+ QVBoxLayout,
+ QToolButton,
+ QHBoxLayout,
+ QDockWidget,
+ QWidget,
+ QSlider,
+ QGraphicsOpacityEffect,
+ QMessageBox,
+ QListView,
+ QScrollArea,
+ QWidgetAction,
+ QApplication,
+ QLabel,
+ QGridLayout,
+ QFileDialog,
+ QListWidgetItem,
+ QComboBox,
+ QDialog,
+ QAbstractItemView,
+ QSizePolicy,
+)
__dir__ = os.path.dirname(os.path.abspath(__file__))
sys.path.append(__dir__)
-sys.path.append(os.path.abspath(os.path.join(__dir__, '../..')))
-sys.path.append(os.path.abspath(os.path.join(__dir__, '../PaddleOCR')))
+sys.path.append(os.path.abspath(os.path.join(__dir__, "../..")))
+sys.path.append(os.path.abspath(os.path.join(__dir__, "../PaddleOCR")))
sys.path.append("..")
from paddleocr import PaddleOCR, PPStructure
@@ -57,7 +87,7 @@
from libs.unique_label_qlist_widget import UniqueLabelQListWidget
from libs.keyDialog import KeyDialog
-__appname__ = 'PPOCRLabel'
+__appname__ = "PPOCRLabel"
LABEL_COLORMAP = label_colormap()
@@ -65,13 +95,15 @@
class MainWindow(QMainWindow):
FIT_WINDOW, FIT_WIDTH, MANUAL_ZOOM = list(range(3))
- def __init__(self,
- lang="ch",
- gpu=False,
- kie_mode=False,
- default_filename=None,
- default_predefined_class_file=None,
- default_save_dir=None):
+ def __init__(
+ self,
+ lang="ch",
+ gpu=False,
+ kie_mode=False,
+ default_filename=None,
+ default_predefined_class_file=None,
+ default_save_dir=None,
+ ):
super(MainWindow, self).__init__()
self.setWindowTitle(__appname__)
self.setWindowState(Qt.WindowMaximized) # set window max
@@ -84,34 +116,38 @@ def __init__(self,
self.lang = lang
# Load string bundle for i18n
- if lang not in ['ch', 'en']:
- lang = 'en'
- self.stringBundle = StringBundle.getBundle(localeStr='zh-CN' if lang == 'ch' else 'en') # 'en'
+ if lang not in ["ch", "en"]:
+ lang = "en"
+ self.stringBundle = StringBundle.getBundle(
+ localeStr="zh-CN" if lang == "ch" else "en"
+ ) # 'en'
getStr = lambda strId: self.stringBundle.getString(strId)
# KIE setting
self.kie_mode = kie_mode
self.key_previous_text = ""
self.existed_key_cls_set = set()
- self.key_dialog_tip = getStr('keyDialogTip')
+ self.key_dialog_tip = getStr("keyDialogTip")
self.defaultSaveDir = default_save_dir
- self.ocr = PaddleOCR(use_pdserving=False,
- use_angle_cls=True,
- det=True,
- cls=True,
- use_gpu=gpu,
- lang=lang,
- show_log=False)
- self.table_ocr = PPStructure(use_pdserving=False,
- use_gpu=gpu,
- lang=lang,
- layout=False,
- show_log=False)
-
- if os.path.exists('./data/paddle.png'):
- result = self.ocr.ocr('./data/paddle.png', cls=True, det=True)
- result = self.table_ocr('./data/paddle.png', return_ocr_result_in_table=True)
+ self.ocr = PaddleOCR(
+ use_pdserving=False,
+ use_angle_cls=True,
+ det=True,
+ cls=True,
+ use_gpu=gpu,
+ lang=lang,
+ show_log=False,
+ )
+ self.table_ocr = PPStructure(
+ use_pdserving=False, use_gpu=gpu, lang=lang, layout=False, show_log=False
+ )
+
+ if os.path.exists("./data/paddle.png"):
+ result = self.ocr.ocr("./data/paddle.png", cls=True, det=True)
+ result = self.table_ocr(
+ "./data/paddle.png", return_ocr_result_in_table=True
+ )
# For loading all image under a directory
self.mImgList = []
@@ -145,9 +181,9 @@ def __init__(self,
self.shapesToItems = {}
self.itemsToShapesbox = {}
self.shapesToItemsbox = {}
- self.prevLabelText = getStr('tempLabel')
- self.noLabelText = getStr('nullLabel')
- self.model = 'paddle'
+ self.prevLabelText = getStr("tempLabel")
+ self.noLabelText = getStr("nullLabel")
+ self.model = "paddle"
self.PPreader = None
self.autoSaveNum = 5
@@ -163,9 +199,9 @@ def __init__(self,
fileListContainer = QWidget()
fileListContainer.setLayout(filelistLayout)
- self.fileListName = getStr('fileList')
+ self.fileListName = getStr("fileList")
self.fileDock = QDockWidget(self.fileListName, self)
- self.fileDock.setObjectName(getStr('files'))
+ self.fileDock.setObjectName(getStr("files"))
self.fileDock.setWidget(fileListContainer)
self.addDockWidget(Qt.LeftDockWidgetArea, self.fileDock)
@@ -179,7 +215,7 @@ def __init__(self,
key_list_height = 50
self.keyList.setMaximumHeight(key_list_height)
- self.keyListDockName = getStr('keyListTitle')
+ self.keyListDockName = getStr("keyListTitle")
self.keyListDock = QDockWidget(self.keyListDockName, self)
self.keyListDock.setWidget(self.keyList)
self.keyListDock.setFeatures(QDockWidget.NoDockWidgetFeatures)
@@ -187,7 +223,7 @@ def __init__(self,
self.AutoRecognition = QToolButton()
self.AutoRecognition.setToolButtonStyle(Qt.ToolButtonTextBesideIcon)
- self.AutoRecognition.setIcon(newIcon('Auto'))
+ self.AutoRecognition.setIcon(newIcon("Auto"))
autoRecLayout = QHBoxLayout()
autoRecLayout.setContentsMargins(0, 0, 0, 0)
autoRecLayout.addWidget(self.AutoRecognition)
@@ -202,7 +238,7 @@ def __init__(self,
# Buttons
self.editButton = QToolButton()
self.reRecogButton = QToolButton()
- self.reRecogButton.setIcon(newIcon('reRec', 30))
+ self.reRecogButton.setIcon(newIcon("reRec", 30))
self.reRecogButton.setToolButtonStyle(Qt.ToolButtonTextBesideIcon)
self.tableRecButton = QToolButton()
@@ -233,17 +269,19 @@ def __init__(self,
# Create and add a widget for showing current label item index
self.indexList = QListWidget()
- self.indexList.setMaximumSize(30, 16777215) # limit max width
- self.indexList.setEditTriggers(QAbstractItemView.NoEditTriggers) # no editable
+ self.indexList.setMaximumSize(30, 16777215) # limit max width
+ self.indexList.setEditTriggers(QAbstractItemView.NoEditTriggers) # not editable
self.indexList.itemSelectionChanged.connect(self.indexSelectionChanged)
- self.indexList.setVerticalScrollBarPolicy(Qt.ScrollBarAlwaysOff) # no scroll Bar
- self.indexListDock = QDockWidget('No.', self)
+ self.indexList.setVerticalScrollBarPolicy(
+ Qt.ScrollBarAlwaysOff
+ ) # no scroll bar
+ self.indexListDock = QDockWidget("No.", self)
self.indexListDock.setWidget(self.indexList)
self.indexListDock.setFeatures(QDockWidget.NoDockWidgetFeatures)
labelIndexListlBox.addWidget(self.indexListDock, 1)
# no margin between two boxes
labelIndexListlBox.setSpacing(0)
-
+
# Create and add a widget for showing current label items
self.labelList = EditInList()
labelListContainer = QWidget()
@@ -253,14 +291,16 @@ def __init__(self,
# Connect to itemChanged to detect checkbox changes.
self.labelList.itemChanged.connect(self.labelItemChanged)
- self.labelListDockName = getStr('recognitionResult')
+ self.labelListDockName = getStr("recognitionResult")
self.labelListDock = QDockWidget(self.labelListDockName, self)
self.labelListDock.setWidget(self.labelList)
self.labelListDock.setFeatures(QDockWidget.NoDockWidgetFeatures)
- labelIndexListlBox.addWidget(self.labelListDock, 10) # label list is wider than index list
-
+ labelIndexListlBox.addWidget(
+ self.labelListDock, 10
+ ) # label list is wider than index list
+
# enable labelList drag_drop to adjust bbox order
- # Set the selection mode to single selection
+ # Set the selection mode to single selection
self.labelList.setSelectionMode(QAbstractItemView.SingleSelection)
# Enable dragging
self.labelList.setDragEnabled(True)
@@ -269,7 +309,7 @@ def __init__(self,
# Show the indicator for where the dragged item will be dropped
self.labelList.setDropIndicatorShown(True)
# Set the drag-and-drop mode to move items; the default is to copy them
- self.labelList.setDragDropMode(QAbstractItemView.InternalMove)
+ self.labelList.setDragDropMode(QAbstractItemView.InternalMove)
# Triggered when an item is dropped
self.labelList.model().rowsMoved.connect(self.drag_drop_happened)
@@ -292,7 +332,7 @@ def __init__(self,
self.BoxList.itemDoubleClicked.connect(self.editBox)
# Connect to itemChanged to detect checkbox changes.
self.BoxList.itemChanged.connect(self.boxItemChanged)
- self.BoxListDockName = getStr('detectionBoxposition')
+ self.BoxListDockName = getStr("detectionBoxposition")
self.BoxListDock = QDockWidget(self.BoxListDockName, self)
self.BoxListDock.setWidget(self.BoxList)
self.BoxListDock.setFeatures(QDockWidget.NoDockWidgetFeatures)
@@ -306,8 +346,8 @@ def __init__(self,
leftbtmtoolboxcontainer.setLayout(leftbtmtoolbox)
listLayout.addWidget(leftbtmtoolboxcontainer)
- self.dock = QDockWidget(getStr('boxLabelText'), self)
- self.dock.setObjectName(getStr('labels'))
+ self.dock = QDockWidget(getStr("boxLabelText"), self)
+ self.dock.setObjectName(getStr("labels"))
self.dock.setWidget(labelListContainer)
# ================== Zoom Bar ==================
@@ -324,8 +364,8 @@ def __init__(self,
self.imageSlider.setGraphicsEffect(op)
self.imageSlider.setStyleSheet("background-color:transparent")
- self.imageSliderDock = QDockWidget(getStr('ImageResize'), self)
- self.imageSliderDock.setObjectName(getStr('IR'))
+ self.imageSliderDock = QDockWidget(getStr("ImageResize"), self)
+ self.imageSliderDock.setObjectName(getStr("IR"))
self.imageSliderDock.setWidget(self.imageSlider)
self.imageSliderDock.setFeatures(QDockWidget.DockWidgetFloatable)
self.imageSliderDock.setAttribute(Qt.WA_TranslucentBackground)
@@ -346,8 +386,8 @@ def __init__(self,
self.preButton.setIcon(newIcon("prev", 40))
self.preButton.setIconSize(QSize(40, 100))
self.preButton.clicked.connect(self.openPrevImg)
- self.preButton.setStyleSheet('border: none;')
- self.preButton.setShortcut('a')
+ self.preButton.setStyleSheet("border: none;")
+ self.preButton.setShortcut("a")
self.iconlist = QListWidget()
self.iconlist.setViewMode(QListView.IconMode)
self.iconlist.setFlow(QListView.TopToBottom)
@@ -356,14 +396,16 @@ def __init__(self,
self.iconlist.setMovement(QListView.Static)
self.iconlist.setResizeMode(QListView.Adjust)
self.iconlist.itemClicked.connect(self.iconitemDoubleClicked)
- self.iconlist.setStyleSheet("QListWidget{ background-color:transparent; border: none;}")
+ self.iconlist.setStyleSheet(
+ "QListWidget{ background-color:transparent; border: none;}"
+ )
self.iconlist.setHorizontalScrollBarPolicy(Qt.ScrollBarAlwaysOff)
self.nextButton = QToolButton()
self.nextButton.setIcon(newIcon("next", 40))
self.nextButton.setIconSize(QSize(40, 100))
- self.nextButton.setStyleSheet('border: none;')
+ self.nextButton.setStyleSheet("border: none;")
self.nextButton.clicked.connect(self.openNextImg)
- self.nextButton.setShortcut('d')
+ self.nextButton.setShortcut("d")
hlayout.addWidget(self.preButton)
hlayout.addWidget(self.iconlist)
@@ -383,7 +425,7 @@ def __init__(self,
scroll.setWidgetResizable(True)
self.scrollBars = {
Qt.Vertical: scroll.verticalScrollBar(),
- Qt.Horizontal: scroll.horizontalScrollBar()
+ Qt.Horizontal: scroll.horizontalScrollBar(),
}
self.scrollArea = scroll
self.canvas.scrollRequest.connect(self.scrollRequest)
@@ -403,84 +445,201 @@ def __init__(self,
self.setCentralWidget(centerContainer)
self.addDockWidget(Qt.RightDockWidgetArea, self.dock)
- self.dock.setFeatures(QDockWidget.DockWidgetClosable | QDockWidget.DockWidgetFloatable)
+ self.dock.setFeatures(
+ QDockWidget.DockWidgetClosable | QDockWidget.DockWidgetFloatable
+ )
self.fileDock.setFeatures(QDockWidget.NoDockWidgetFeatures)
# ================== Actions ==================
action = partial(newAction, self)
- quit = action(getStr('quit'), self.close,
- 'Ctrl+Q', 'quit', getStr('quitApp'))
-
- opendir = action(getStr('openDir'), self.openDirDialog,
- 'Ctrl+u', 'open', getStr('openDir'))
-
- open_dataset_dir = action(getStr('openDatasetDir'), self.openDatasetDirDialog,
- 'Ctrl+p', 'open', getStr('openDatasetDir'), enabled=False)
-
- save = action(getStr('save'), self.saveFile,
- 'Ctrl+V', 'verify', getStr('saveDetail'), enabled=False)
-
- alcm = action(getStr('choosemodel'), self.autolcm,
- 'Ctrl+M', 'next', getStr('tipchoosemodel'))
-
- deleteImg = action(getStr('deleteImg'), self.deleteImg, 'Ctrl+Shift+D', 'close', getStr('deleteImgDetail'),
- enabled=True)
-
- resetAll = action(getStr('resetAll'), self.resetAll, None, 'resetall', getStr('resetAllDetail'))
-
- color1 = action(getStr('boxLineColor'), self.chooseColor,
- 'Ctrl+L', 'color_line', getStr('boxLineColorDetail'))
-
- createMode = action(getStr('crtBox'), self.setCreateMode,
- 'w', 'new', getStr('crtBoxDetail'), enabled=False)
- editMode = action('&Edit\nRectBox', self.setEditMode,
- 'Ctrl+J', 'edit', u'Move and edit Boxs', enabled=False)
-
- create = action(getStr('crtBox'), self.createShape,
- 'w', 'objects', getStr('crtBoxDetail'), enabled=False)
-
- delete = action(getStr('delBox'), self.deleteSelectedShape,
- 'backspace', 'delete', getStr('delBoxDetail'), enabled=False)
-
- copy = action(getStr('dupBox'), self.copySelectedShape,
- 'Ctrl+C', 'copy', getStr('dupBoxDetail'),
- enabled=False)
-
- hideAll = action(getStr('hideBox'), partial(self.togglePolygons, False),
- 'Ctrl+H', 'hide', getStr('hideAllBoxDetail'),
- enabled=False)
- showAll = action(getStr('showBox'), partial(self.togglePolygons, True),
- 'Ctrl+A', 'hide', getStr('showAllBoxDetail'),
- enabled=False)
-
- help = action(getStr('tutorial'), self.showTutorialDialog, None, 'help', getStr('tutorialDetail'))
- showInfo = action(getStr('info'), self.showInfoDialog, None, 'help', getStr('info'))
- showSteps = action(getStr('steps'), self.showStepsDialog, None, 'help', getStr('steps'))
- showKeys = action(getStr('keys'), self.showKeysDialog, None, 'help', getStr('keys'))
+ quit = action(getStr("quit"), self.close, "Ctrl+Q", "quit", getStr("quitApp"))
+
+ opendir = action(
+ getStr("openDir"), self.openDirDialog, "Ctrl+u", "open", getStr("openDir")
+ )
+
+ open_dataset_dir = action(
+ getStr("openDatasetDir"),
+ self.openDatasetDirDialog,
+ "Ctrl+p",
+ "open",
+ getStr("openDatasetDir"),
+ enabled=False,
+ )
+
+ save = action(
+ getStr("save"),
+ self.saveFile,
+ "Ctrl+V",
+ "verify",
+ getStr("saveDetail"),
+ enabled=False,
+ )
+
+ alcm = action(
+ getStr("choosemodel"),
+ self.autolcm,
+ "Ctrl+M",
+ "next",
+ getStr("tipchoosemodel"),
+ )
+
+ deleteImg = action(
+ getStr("deleteImg"),
+ self.deleteImg,
+ "Ctrl+Shift+D",
+ "close",
+ getStr("deleteImgDetail"),
+ enabled=True,
+ )
+
+ resetAll = action(
+ getStr("resetAll"),
+ self.resetAll,
+ None,
+ "resetall",
+ getStr("resetAllDetail"),
+ )
+
+ color1 = action(
+ getStr("boxLineColor"),
+ self.chooseColor,
+ "Ctrl+L",
+ "color_line",
+ getStr("boxLineColorDetail"),
+ )
+
+ createMode = action(
+ getStr("crtBox"),
+ self.setCreateMode,
+ "w",
+ "new",
+ getStr("crtBoxDetail"),
+ enabled=False,
+ )
+ editMode = action(
+ "&Edit\nRectBox",
+ self.setEditMode,
+ "Ctrl+J",
+ "edit",
+ "Move and edit Boxs",
+ enabled=False,
+ )
+
+ create = action(
+ getStr("crtBox"),
+ self.createShape,
+ "w",
+ "objects",
+ getStr("crtBoxDetail"),
+ enabled=False,
+ )
+
+ delete = action(
+ getStr("delBox"),
+ self.deleteSelectedShape,
+ "backspace",
+ "delete",
+ getStr("delBoxDetail"),
+ enabled=False,
+ )
+
+ copy = action(
+ getStr("dupBox"),
+ self.copySelectedShape,
+ "Ctrl+C",
+ "copy",
+ getStr("dupBoxDetail"),
+ enabled=False,
+ )
+
+ hideAll = action(
+ getStr("hideBox"),
+ partial(self.togglePolygons, False),
+ "Ctrl+H",
+ "hide",
+ getStr("hideAllBoxDetail"),
+ enabled=False,
+ )
+ showAll = action(
+ getStr("showBox"),
+ partial(self.togglePolygons, True),
+ "Ctrl+A",
+ "hide",
+ getStr("showAllBoxDetail"),
+ enabled=False,
+ )
+
+ help = action(
+ getStr("tutorial"),
+ self.showTutorialDialog,
+ None,
+ "help",
+ getStr("tutorialDetail"),
+ )
+ showInfo = action(
+ getStr("info"), self.showInfoDialog, None, "help", getStr("info")
+ )
+ showSteps = action(
+ getStr("steps"), self.showStepsDialog, None, "help", getStr("steps")
+ )
+ showKeys = action(
+ getStr("keys"), self.showKeysDialog, None, "help", getStr("keys")
+ )
zoom = QWidgetAction(self)
zoom.setDefaultWidget(self.zoomWidget)
self.zoomWidget.setWhatsThis(
- u"Zoom in or out of the image. Also accessible with"
- " %s and %s from the canvas." % (fmtShortcut("Ctrl+[-+]"),
- fmtShortcut("Ctrl+Wheel")))
+ "Zoom in or out of the image. Also accessible with"
+ " %s and %s from the canvas."
+ % (fmtShortcut("Ctrl+[-+]"), fmtShortcut("Ctrl+Wheel"))
+ )
self.zoomWidget.setEnabled(False)
- zoomIn = action(getStr('zoomin'), partial(self.addZoom, 10),
- 'Ctrl++', 'zoom-in', getStr('zoominDetail'), enabled=False)
- zoomOut = action(getStr('zoomout'), partial(self.addZoom, -10),
- 'Ctrl+-', 'zoom-out', getStr('zoomoutDetail'), enabled=False)
- zoomOrg = action(getStr('originalsize'), partial(self.setZoom, 100),
- 'Ctrl+=', 'zoom', getStr('originalsizeDetail'), enabled=False)
- fitWindow = action(getStr('fitWin'), self.setFitWindow,
- 'Ctrl+F', 'fit-window', getStr('fitWinDetail'),
- checkable=True, enabled=False)
- fitWidth = action(getStr('fitWidth'), self.setFitWidth,
- 'Ctrl+Shift+F', 'fit-width', getStr('fitWidthDetail'),
- checkable=True, enabled=False)
+ zoomIn = action(
+ getStr("zoomin"),
+ partial(self.addZoom, 10),
+ "Ctrl++",
+ "zoom-in",
+ getStr("zoominDetail"),
+ enabled=False,
+ )
+ zoomOut = action(
+ getStr("zoomout"),
+ partial(self.addZoom, -10),
+ "Ctrl+-",
+ "zoom-out",
+ getStr("zoomoutDetail"),
+ enabled=False,
+ )
+ zoomOrg = action(
+ getStr("originalsize"),
+ partial(self.setZoom, 100),
+ "Ctrl+=",
+ "zoom",
+ getStr("originalsizeDetail"),
+ enabled=False,
+ )
+ fitWindow = action(
+ getStr("fitWin"),
+ self.setFitWindow,
+ "Ctrl+F",
+ "fit-window",
+ getStr("fitWinDetail"),
+ checkable=True,
+ enabled=False,
+ )
+ fitWidth = action(
+ getStr("fitWidth"),
+ self.setFitWidth,
+ "Ctrl+Shift+F",
+ "fit-width",
+ getStr("fitWidthDetail"),
+ checkable=True,
+ enabled=False,
+ )
# Group zoom controls into a list for easier toggling.
- zoomActions = (self.zoomWidget, zoomIn, zoomOut,
- zoomOrg, fitWindow, fitWidth)
+ zoomActions = (self.zoomWidget, zoomIn, zoomOut, zoomOrg, fitWindow, fitWidth)
self.zoomMode = self.MANUAL_ZOOM
self.scalers = {
self.FIT_WINDOW: self.scaleFitWindow,
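
The dozens of positional `action(...)` calls in this hunk all go through `partial(newAction, self)`; they read more easily once the helper's parameter order is known. A sketch of a labelImg-style `newAction` (an assumption about `libs.utils`, not code from this diff):

    from PyQt5.QtWidgets import QAction

    def newAction(parent, text, slot=None, shortcut=None, icon=None,
                  tip=None, checkable=False, enabled=True):
        """Create a QAction from text, callback, shortcut, icon name, and tooltip."""
        a = QAction(text, parent)
        if icon is not None:
            a.setIcon(newIcon(icon))  # newIcon (from libs.utils) resolves a name to a QIcon
        if shortcut is not None:
            a.setShortcut(shortcut)
        if tip is not None:
            a.setToolTip(tip)
            a.setStatusTip(tip)
        if slot is not None:
            a.triggered.connect(slot)
        a.setCheckable(checkable)
        a.setEnabled(enabled)
        return a

Under that reading, `action(getStr("quit"), self.close, "Ctrl+Q", "quit", getStr("quitApp"))` maps to text, slot, shortcut, icon, tip.
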
@@ -491,55 +650,157 @@ def __init__(self,
# ================== New Actions ==================
- edit = action(getStr('editLabel'), self.editLabel,
- 'Ctrl+E', 'edit', getStr('editLabelDetail'), enabled=False)
-
- AutoRec = action(getStr('autoRecognition'), self.autoRecognition,
- '', 'Auto', getStr('autoRecognition'), enabled=False)
-
- reRec = action(getStr('reRecognition'), self.reRecognition,
- 'Ctrl+Shift+R', 'reRec', getStr('reRecognition'), enabled=False)
-
- singleRere = action(getStr('singleRe'), self.singleRerecognition,
- 'Ctrl+R', 'reRec', getStr('singleRe'), enabled=False)
-
- createpoly = action(getStr('creatPolygon'), self.createPolygon,
- 'q', 'new', getStr('creatPolygon'), enabled=False)
-
- tableRec = action(getStr('TableRecognition'), self.TableRecognition,
- '', 'Auto', getStr('TableRecognition'), enabled=False)
-
- cellreRec = action(getStr('cellreRecognition'), self.cellreRecognition,
- '', 'reRec', getStr('cellreRecognition'), enabled=False)
-
- saveRec = action(getStr('saveRec'), self.saveRecResult,
- '', 'save', getStr('saveRec'), enabled=False)
-
- saveLabel = action(getStr('saveLabel'), self.saveLabelFile, #
- 'Ctrl+S', 'save', getStr('saveLabel'), enabled=False)
-
- exportJSON = action(getStr('exportJSON'), self.exportJSON,
- '', 'save', getStr('exportJSON'), enabled=False)
-
- undoLastPoint = action(getStr("undoLastPoint"), self.canvas.undoLastPoint,
- 'Ctrl+Z', "undo", getStr("undoLastPoint"), enabled=False)
-
- rotateLeft = action(getStr("rotateLeft"), partial(self.rotateImgAction, 1),
- 'Ctrl+Alt+L', "rotateLeft", getStr("rotateLeft"), enabled=False)
-
- rotateRight = action(getStr("rotateRight"), partial(self.rotateImgAction, -1),
- 'Ctrl+Alt+R', "rotateRight", getStr("rotateRight"), enabled=False)
-
- undo = action(getStr("undo"), self.undoShapeEdit,
- 'Ctrl+Z', "undo", getStr("undo"), enabled=False)
-
- change_cls = action(getStr("keyChange"), self.change_box_key,
- 'Ctrl+X', "edit", getStr("keyChange"), enabled=False)
-
- lock = action(getStr("lockBox"), self.lockSelectedShape,
- None, "lock", getStr("lockBoxDetail"), enabled=False)
- expand = action(getStr("expandBox"), self.expandSelectedShape,
- "Ctrl+K", "expand", getStr("expandBoxDetail"), enabled=False)
+ edit = action(
+ getStr("editLabel"),
+ self.editLabel,
+ "Ctrl+E",
+ "edit",
+ getStr("editLabelDetail"),
+ enabled=False,
+ )
+
+ AutoRec = action(
+ getStr("autoRecognition"),
+ self.autoRecognition,
+ "",
+ "Auto",
+ getStr("autoRecognition"),
+ enabled=False,
+ )
+
+ reRec = action(
+ getStr("reRecognition"),
+ self.reRecognition,
+ "Ctrl+Shift+R",
+ "reRec",
+ getStr("reRecognition"),
+ enabled=False,
+ )
+
+ singleRere = action(
+ getStr("singleRe"),
+ self.singleRerecognition,
+ "Ctrl+R",
+ "reRec",
+ getStr("singleRe"),
+ enabled=False,
+ )
+
+ createpoly = action(
+ getStr("creatPolygon"),
+ self.createPolygon,
+ "q",
+ "new",
+ getStr("creatPolygon"),
+ enabled=False,
+ )
+
+ tableRec = action(
+ getStr("TableRecognition"),
+ self.TableRecognition,
+ "",
+ "Auto",
+ getStr("TableRecognition"),
+ enabled=False,
+ )
+
+ cellreRec = action(
+ getStr("cellreRecognition"),
+ self.cellreRecognition,
+ "",
+ "reRec",
+ getStr("cellreRecognition"),
+ enabled=False,
+ )
+
+ saveRec = action(
+ getStr("saveRec"),
+ self.saveRecResult,
+ "",
+ "save",
+ getStr("saveRec"),
+ enabled=False,
+ )
+
+ saveLabel = action(
+ getStr("saveLabel"),
+ self.saveLabelFile, #
+ "Ctrl+S",
+ "save",
+ getStr("saveLabel"),
+ enabled=False,
+ )
+
+ exportJSON = action(
+ getStr("exportJSON"),
+ self.exportJSON,
+ "",
+ "save",
+ getStr("exportJSON"),
+ enabled=False,
+ )
+
+ undoLastPoint = action(
+ getStr("undoLastPoint"),
+ self.canvas.undoLastPoint,
+ "Ctrl+Z",
+ "undo",
+ getStr("undoLastPoint"),
+ enabled=False,
+ )
+
+ rotateLeft = action(
+ getStr("rotateLeft"),
+ partial(self.rotateImgAction, 1),
+ "Ctrl+Alt+L",
+ "rotateLeft",
+ getStr("rotateLeft"),
+ enabled=False,
+ )
+
+ rotateRight = action(
+ getStr("rotateRight"),
+ partial(self.rotateImgAction, -1),
+ "Ctrl+Alt+R",
+ "rotateRight",
+ getStr("rotateRight"),
+ enabled=False,
+ )
+
+ undo = action(
+ getStr("undo"),
+ self.undoShapeEdit,
+ "Ctrl+Z",
+ "undo",
+ getStr("undo"),
+ enabled=False,
+ )
+
+ change_cls = action(
+ getStr("keyChange"),
+ self.change_box_key,
+ "Ctrl+X",
+ "edit",
+ getStr("keyChange"),
+ enabled=False,
+ )
+
+ lock = action(
+ getStr("lockBox"),
+ self.lockSelectedShape,
+ None,
+ "lock",
+ getStr("lockBoxDetail"),
+ enabled=False,
+ )
+ expand = action(
+ getStr("expandBox"),
+ self.expandSelectedShape,
+ "Ctrl+K",
+ "expand",
+ getStr("expandBoxDetail"),
+ enabled=False,
+ )
self.editButton.setDefaultAction(edit)
self.newButton.setDefaultAction(create)
@@ -572,12 +833,20 @@ def __init__(self,
zoomContainer.setLayout(zoomLayout)
zoomContainer.setGeometry(0, 0, 30, 150)
- shapeLineColor = action(getStr('shapeLineColor'), self.chshapeLineColor,
- icon='color_line', tip=getStr('shapeLineColorDetail'),
- enabled=False)
- shapeFillColor = action(getStr('shapeFillColor'), self.chshapeFillColor,
- icon='color', tip=getStr('shapeFillColorDetail'),
- enabled=False)
+ shapeLineColor = action(
+ getStr("shapeLineColor"),
+ self.chshapeLineColor,
+ icon="color_line",
+ tip=getStr("shapeLineColorDetail"),
+ enabled=False,
+ )
+ shapeFillColor = action(
+ getStr("shapeFillColor"),
+ self.chshapeFillColor,
+ icon="color",
+ tip=getStr("shapeFillColorDetail"),
+ enabled=False,
+ )
# Label list context menu.
labelMenu = QMenu()
@@ -587,82 +856,181 @@ def __init__(self,
self.labelList.customContextMenuRequested.connect(self.popLabelListMenu)
# Draw squares/rectangles
- self.drawSquaresOption = QAction(getStr('drawSquares'), self)
+ self.drawSquaresOption = QAction(getStr("drawSquares"), self)
self.drawSquaresOption.setCheckable(True)
self.drawSquaresOption.setChecked(settings.get(SETTING_DRAW_SQUARE, False))
self.drawSquaresOption.triggered.connect(self.toogleDrawSquare)
# Store actions for further handling.
- self.actions = struct(save=save, resetAll=resetAll, deleteImg=deleteImg,
- lineColor=color1, create=create, createpoly=createpoly, tableRec=tableRec, delete=delete, edit=edit, copy=copy,
- saveRec=saveRec, singleRere=singleRere, AutoRec=AutoRec, reRec=reRec, cellreRec=cellreRec,
- createMode=createMode, editMode=editMode,
- shapeLineColor=shapeLineColor, shapeFillColor=shapeFillColor,
- zoom=zoom, zoomIn=zoomIn, zoomOut=zoomOut, zoomOrg=zoomOrg,
- fitWindow=fitWindow, fitWidth=fitWidth,
- zoomActions=zoomActions, saveLabel=saveLabel, change_cls=change_cls,
- undo=undo, undoLastPoint=undoLastPoint, open_dataset_dir=open_dataset_dir,
- rotateLeft=rotateLeft, rotateRight=rotateRight, lock=lock, exportJSON=exportJSON,expand=expand,
- fileMenuActions=(opendir, open_dataset_dir, saveLabel, exportJSON, resetAll, quit),
- beginner=(), advanced=(),
- editMenu=(createpoly, edit, copy, delete, singleRere, cellreRec, None, undo, undoLastPoint,
- None, rotateLeft, rotateRight, None, color1, self.drawSquaresOption, lock,expand,
- None, change_cls),
- beginnerContext=(
- create, createpoly, edit, copy, delete, singleRere, cellreRec, rotateLeft, rotateRight, lock,expand,change_cls),
- advancedContext=(createMode, editMode, edit, copy,
- delete, shapeLineColor, shapeFillColor),
- onLoadActive=(create, createpoly, createMode, editMode),
- onShapesPresent=(hideAll, showAll))
+ self.actions = struct(
+ save=save,
+ resetAll=resetAll,
+ deleteImg=deleteImg,
+ lineColor=color1,
+ create=create,
+ createpoly=createpoly,
+ tableRec=tableRec,
+ delete=delete,
+ edit=edit,
+ copy=copy,
+ saveRec=saveRec,
+ singleRere=singleRere,
+ AutoRec=AutoRec,
+ reRec=reRec,
+ cellreRec=cellreRec,
+ createMode=createMode,
+ editMode=editMode,
+ shapeLineColor=shapeLineColor,
+ shapeFillColor=shapeFillColor,
+ zoom=zoom,
+ zoomIn=zoomIn,
+ zoomOut=zoomOut,
+ zoomOrg=zoomOrg,
+ fitWindow=fitWindow,
+ fitWidth=fitWidth,
+ zoomActions=zoomActions,
+ saveLabel=saveLabel,
+ change_cls=change_cls,
+ undo=undo,
+ undoLastPoint=undoLastPoint,
+ open_dataset_dir=open_dataset_dir,
+ rotateLeft=rotateLeft,
+ rotateRight=rotateRight,
+ lock=lock,
+ exportJSON=exportJSON,
+ expand=expand,
+ fileMenuActions=(
+ opendir,
+ open_dataset_dir,
+ saveLabel,
+ exportJSON,
+ resetAll,
+ quit,
+ ),
+ beginner=(),
+ advanced=(),
+ editMenu=(
+ createpoly,
+ edit,
+ copy,
+ delete,
+ singleRere,
+ cellreRec,
+ None,
+ undo,
+ undoLastPoint,
+ None,
+ rotateLeft,
+ rotateRight,
+ None,
+ color1,
+ self.drawSquaresOption,
+ lock,
+ expand,
+ None,
+ change_cls,
+ ),
+ beginnerContext=(
+ create,
+ createpoly,
+ edit,
+ copy,
+ delete,
+ singleRere,
+ cellreRec,
+ rotateLeft,
+ rotateRight,
+ lock,
+ expand,
+ change_cls,
+ ),
+ advancedContext=(
+ createMode,
+ editMode,
+ edit,
+ copy,
+ delete,
+ shapeLineColor,
+ shapeFillColor,
+ ),
+ onLoadActive=(create, createpoly, createMode, editMode),
+ onShapesPresent=(hideAll, showAll),
+ )
# menus
self.menus = struct(
- file=self.menu('&' + getStr('mfile')),
- edit=self.menu('&' + getStr('medit')),
- view=self.menu('&' + getStr('mview')),
- autolabel=self.menu('&PaddleOCR'),
- help=self.menu('&' + getStr('mhelp')),
- recentFiles=QMenu('Open &Recent'),
- labelList=labelMenu)
+ file=self.menu("&" + getStr("mfile")),
+ edit=self.menu("&" + getStr("medit")),
+ view=self.menu("&" + getStr("mview")),
+ autolabel=self.menu("&PaddleOCR"),
+ help=self.menu("&" + getStr("mhelp")),
+ recentFiles=QMenu("Open &Recent"),
+ labelList=labelMenu,
+ )
self.lastLabel = None
# Add option to enable/disable labels being displayed at the top of bounding boxes
- self.displayLabelOption = QAction(getStr('displayLabel'), self)
+ self.displayLabelOption = QAction(getStr("displayLabel"), self)
self.displayLabelOption.setShortcut("Ctrl+Shift+P")
self.displayLabelOption.setCheckable(True)
self.displayLabelOption.setChecked(settings.get(SETTING_PAINT_LABEL, False))
self.displayLabelOption.triggered.connect(self.togglePaintLabelsOption)
# Add option to enable/disable box index being displayed at the top of bounding boxes
- self.displayIndexOption = QAction(getStr('displayIndex'), self)
+ self.displayIndexOption = QAction(getStr("displayIndex"), self)
self.displayIndexOption.setCheckable(True)
self.displayIndexOption.setChecked(settings.get(SETTING_PAINT_INDEX, False))
self.displayIndexOption.triggered.connect(self.togglePaintIndexOption)
- self.labelDialogOption = QAction(getStr('labelDialogOption'), self)
+ self.labelDialogOption = QAction(getStr("labelDialogOption"), self)
self.labelDialogOption.setShortcut("Ctrl+Shift+L")
self.labelDialogOption.setCheckable(True)
self.labelDialogOption.setChecked(settings.get(SETTING_PAINT_LABEL, False))
self.displayIndexOption.setChecked(settings.get(SETTING_PAINT_INDEX, False))
self.labelDialogOption.triggered.connect(self.speedChoose)
- self.autoSaveOption = QAction(getStr('autoSaveMode'), self)
+ self.autoSaveOption = QAction(getStr("autoSaveMode"), self)
self.autoSaveOption.setCheckable(True)
self.autoSaveOption.setChecked(settings.get(SETTING_PAINT_LABEL, False))
self.displayIndexOption.setChecked(settings.get(SETTING_PAINT_INDEX, False))
self.autoSaveOption.triggered.connect(self.autoSaveFunc)
- addActions(self.menus.file,
- (opendir, open_dataset_dir, None, saveLabel, saveRec, exportJSON, self.autoSaveOption, None, resetAll, deleteImg,
- quit))
+ addActions(
+ self.menus.file,
+ (
+ opendir,
+ open_dataset_dir,
+ None,
+ saveLabel,
+ saveRec,
+ exportJSON,
+ self.autoSaveOption,
+ None,
+ resetAll,
+ deleteImg,
+ quit,
+ ),
+ )
addActions(self.menus.help, (showKeys, showSteps, showInfo))
- addActions(self.menus.view, (
- self.displayLabelOption, self.displayIndexOption, self.labelDialogOption,
- None,
- hideAll, showAll, None,
- zoomIn, zoomOut, zoomOrg, None,
- fitWindow, fitWidth))
+ addActions(
+ self.menus.view,
+ (
+ self.displayLabelOption,
+ self.displayIndexOption,
+ self.labelDialogOption,
+ None,
+ hideAll,
+ showAll,
+ None,
+ zoomIn,
+ zoomOut,
+ zoomOrg,
+ None,
+ fitWindow,
+ fitWidth,
+ ),
+ )
addActions(self.menus.autolabel, (AutoRec, reRec, cellreRec, alcm, None, help))
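
The `None` entries in the menu tuples above are separator markers; a sketch of the `addActions` helper that interprets them (assumed to match the labelImg-style `libs.utils.addActions`):

    from PyQt5.QtWidgets import QMenu

    def addActions(widget, actions):
        """Add actions to a menu or toolbar; None inserts a separator, QMenu a submenu."""
        for action in actions:
            if action is None:
                widget.addSeparator()
            elif isinstance(action, QMenu):
                widget.addMenu(action)
            else:
                widget.addAction(action)
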
@@ -671,7 +1039,7 @@ def __init__(self,
# Custom context menu for the canvas widget:
addActions(self.canvas.menus[0], self.actions.beginnerContext)
- self.statusBar().showMessage('%s started.' % __appname__)
+ self.statusBar().showMessage("%s started." % __appname__)
self.statusBar().show()
# Application state.
@@ -693,7 +1061,9 @@ def __init__(self,
recentFileQStringList = settings.get(SETTING_RECENT_FILES)
self.recentFiles = [ustr(i) for i in recentFileQStringList]
else:
- self.recentFiles = recentFileQStringList = settings.get(SETTING_RECENT_FILES)
+ self.recentFiles = recentFileQStringList = settings.get(
+ SETTING_RECENT_FILES
+ )
size = settings.get(SETTING_WIN_SIZE, QSize(1200, 800))
@@ -710,8 +1080,12 @@ def __init__(self,
self.lastOpenDir = ustr(settings.get(SETTING_LAST_OPEN_DIR, None))
self.restoreState(settings.get(SETTING_WIN_STATE, QByteArray()))
- Shape.line_color = self.lineColor = QColor(settings.get(SETTING_LINE_COLOR, DEFAULT_LINE_COLOR))
- Shape.fill_color = self.fillColor = QColor(settings.get(SETTING_FILL_COLOR, DEFAULT_FILL_COLOR))
+ Shape.line_color = self.lineColor = QColor(
+ settings.get(SETTING_LINE_COLOR, DEFAULT_LINE_COLOR)
+ )
+ Shape.fill_color = self.fillColor = QColor(
+ settings.get(SETTING_FILL_COLOR, DEFAULT_FILL_COLOR)
+ )
self.canvas.setDrawingColor(self.lineColor)
# Add chris
Shape.difficult = self.difficult
@@ -734,7 +1108,7 @@ def __init__(self,
self.populateModeActions()
# Display cursor coordinates at the right of status bar
- self.labelCoordinates = QLabel('')
+ self.labelCoordinates = QLabel("")
self.statusBar().addPermanentWidget(self.labelCoordinates)
# Open Dir if default file
@@ -763,7 +1137,9 @@ def populateModeActions(self):
self.canvas.menus[0].clear()
addActions(self.canvas.menus[0], self.actions.beginnerContext)
self.menus.edit.clear()
- actions = (self.actions.create,) # if self.beginner() else (self.actions.createMode, self.actions.editMode)
+ actions = (
+ self.actions.create,
+ ) # if self.beginner() else (self.actions.createMode, self.actions.editMode)
addActions(self.menus.edit, actions + self.actions.editMenu)
def setDirty(self):
@@ -833,12 +1209,12 @@ def advanced(self):
def getAvailableScreencastViewer(self):
osName = platform.system()
- if osName == 'Windows':
- return ['C:\\Program Files\\Internet Explorer\\iexplore.exe']
- elif osName == 'Linux':
- return ['xdg-open']
- elif osName == 'Darwin':
- return ['open']
+ if osName == "Windows":
+ return ["C:\\Program Files\\Internet Explorer\\iexplore.exe"]
+ elif osName == "Linux":
+ return ["xdg-open"]
+ elif osName == "Darwin":
+ return ["open"]
## Callbacks ##
def showTutorialDialog(self):
@@ -846,16 +1222,19 @@ def showTutorialDialog(self):
def showInfoDialog(self):
from libs.__init__ import __version__
- msg = u'Name:{0} \nApp Version:{1} \n{2} '.format(__appname__, __version__, sys.version_info)
- QMessageBox.information(self, u'Information', msg)
+
+ msg = "Name:{0} \nApp Version:{1} \n{2} ".format(
+ __appname__, __version__, sys.version_info
+ )
+ QMessageBox.information(self, "Information", msg)
def showStepsDialog(self):
msg = stepsInfo(self.lang)
- QMessageBox.information(self, u'Information', msg)
+ QMessageBox.information(self, "Information", msg)
def showKeysDialog(self):
msg = keysInfo(self.lang)
- QMessageBox.information(self, u'Information', msg)
+ QMessageBox.information(self, "Information", msg)
def createShape(self):
assert self.beginner()
@@ -881,15 +1260,18 @@ def rotateImg(self, filename, k, _value):
self.loadFile(filename)
def rotateImgWarn(self):
- if self.lang == 'ch':
+ if self.lang == "ch":
self.msgBox.warning(self, "提示", "\n 该图片已经有标注框,旋转操作会打乱标注,建议清除标注框后旋转。")
else:
- self.msgBox.warning(self, "Warn", "\n The picture already has a label box, "
- "and rotation will disrupt the label. "
- "It is recommended to clear the label box and rotate it.")
+ self.msgBox.warning(
+ self,
+ "Warn",
+ "\n The picture already has a label box, "
+ "and rotation will disrupt the label. "
+ "It is recommended to clear the label box and rotate it.",
+ )
def rotateImgAction(self, k=1, _value=False):
-
filename = self.mImgList[self.currIndex]
if os.path.exists(filename):
@@ -909,7 +1291,7 @@ def toggleDrawingSensitive(self, drawing=True):
self.actions.editMode.setEnabled(not drawing)
if not drawing and self.beginner():
# Cancel creation.
- print('Cancel creation.')
+ print("Cancel creation.")
self.canvas.setEditing(True)
self.canvas.restoreCursor()
self.actions.create.setEnabled(True)
@@ -937,12 +1319,10 @@ def exists(filename):
menu = self.menus.recentFiles
menu.clear()
- files = [f for f in self.recentFiles if f !=
- currFilePath and exists(f)]
+ files = [f for f in self.recentFiles if f != currFilePath and exists(f)]
for i, f in enumerate(files):
- icon = newIcon('labels')
- action = QAction(
- icon, '&%d %s' % (i + 1, QFileInfo(f).fileName()), self)
+ icon = newIcon("labels")
+ action = QAction(icon, "&%d %s" % (i + 1, QFileInfo(f).fileName()), self)
action.triggered.connect(partial(self.loadRecent, f))
menu.addAction(action)
@@ -992,16 +1372,24 @@ def editBox(self): # ADD
try:
text_list = eval(text)
except:
- msg_box = QMessageBox(QMessageBox.Warning, 'Warning', 'Please enter the correct format')
+ msg_box = QMessageBox(
+ QMessageBox.Warning, "Warning", "Please enter the correct format"
+ )
msg_box.exec_()
return
if len(text_list) < 4:
- msg_box = QMessageBox(QMessageBox.Warning, 'Warning', 'Please enter the coordinates of 4 points')
+ msg_box = QMessageBox(
+ QMessageBox.Warning,
+ "Warning",
+ "Please enter the coordinates of 4 points",
+ )
msg_box.exec_()
return
for box in text_list:
if box[0] > width or box[0] < 0 or box[1] > height or box[1] < 0:
- msg_box = QMessageBox(QMessageBox.Warning, 'Warning', 'Out of picture size')
+ msg_box = QMessageBox(
+ QMessageBox.Warning, "Warning", "Out of picture size"
+ )
msg_box.exec_()
return
@@ -1013,7 +1401,9 @@ def editBox(self): # ADD
def updateBoxlist(self):
self.canvas.selectedShapes_hShape = []
if self.canvas.hShape != None:
- self.canvas.selectedShapes_hShape = self.canvas.selectedShapes + [self.canvas.hShape]
+ self.canvas.selectedShapes_hShape = self.canvas.selectedShapes + [
+ self.canvas.hShape
+ ]
else:
self.canvas.selectedShapes_hShape = self.canvas.selectedShapes
for shape in self.canvas.selectedShapes_hShape:
@@ -1030,11 +1420,13 @@ def indexTo5Files(self, currIndex):
elif currIndex > len(self.mImgList) - 3:
return self.mImgList[-5:]
else:
- return self.mImgList[currIndex - 2: currIndex + 3]
+ return self.mImgList[currIndex - 2 : currIndex + 3]
# Tzutalin 20160906 : Add file list and dock to move faster
def fileitemDoubleClicked(self, item=None):
- self.currIndex = self.mImgList.index(ustr(os.path.join(os.path.abspath(self.dirname), item.text())))
+ self.currIndex = self.mImgList.index(
+ ustr(os.path.join(os.path.abspath(self.dirname), item.text()))
+ )
filename = self.mImgList[self.currIndex]
if filename:
self.mImgList5 = self.indexTo5Files(self.currIndex)
@@ -1067,24 +1459,31 @@ def shapeSelectionChanged(self, selected_shapes):
index = self.labelList.indexFromItem(self.shapesToItems[shape]).row()
self.indexList.item(index).setSelected(True)
- self.labelList.scrollToItem(self.currentItem()) # QAbstractItemView.EnsureVisible
+ self.labelList.scrollToItem(
+ self.currentItem()
+ ) # QAbstractItemView.EnsureVisible
# map current label item to index item and select it
index = self.labelList.indexFromItem(self.currentItem()).row()
- self.indexList.scrollToItem(self.indexList.item(index))
+ self.indexList.scrollToItem(self.indexList.item(index))
self.BoxList.scrollToItem(self.currentBox())
if self.kie_mode:
if len(self.canvas.selectedShapes) == 1 and self.keyList.count() > 0:
- selected_key_item_row = self.keyList.findItemsByLabel(self.canvas.selectedShapes[0].key_cls,
- get_row=True)
- if isinstance(selected_key_item_row, list) and len(selected_key_item_row) == 0:
+ selected_key_item_row = self.keyList.findItemsByLabel(
+ self.canvas.selectedShapes[0].key_cls, get_row=True
+ )
+ if (
+ isinstance(selected_key_item_row, list)
+ and len(selected_key_item_row) == 0
+ ):
key_text = self.canvas.selectedShapes[0].key_cls
item = self.keyList.createItemFromLabel(key_text)
self.keyList.addItem(item)
rgb = self._get_rgb_by_label(key_text, self.kie_mode)
self.keyList.setItemLabel(item, key_text, rgb)
- selected_key_item_row = self.keyList.findItemsByLabel(self.canvas.selectedShapes[0].key_cls,
- get_row=True)
+ selected_key_item_row = self.keyList.findItemsByLabel(
+ self.canvas.selectedShapes[0].key_cls, get_row=True
+ )
self.keyList.setCurrentRow(selected_key_item_row)
@@ -1120,7 +1519,9 @@ def addLabel(self, shape):
# print('item in add label is ',[(p.x(), p.y()) for p in shape.points], shape.label)
# ADD for box
- item = HashableQListWidgetItem(str([(int(p.x()), int(p.y())) for p in shape.points]))
+ item = HashableQListWidgetItem(
+ str([(int(p.x()), int(p.y())) for p in shape.points])
+ )
self.itemsToShapesbox[item] = shape
self.shapesToItemsbox[shape] = item
self.BoxList.addItem(item)
@@ -1129,8 +1530,12 @@ def addLabel(self, shape):
self.updateComboBox()
# update show counting
- self.BoxListDock.setWindowTitle(self.BoxListDockName + f" ({self.BoxList.count()})")
- self.labelListDock.setWindowTitle(self.labelListDockName + f" ({self.labelList.count()})")
+ self.BoxListDock.setWindowTitle(
+ self.BoxListDockName + f" ({self.BoxList.count()})"
+ )
+ self.labelListDock.setWindowTitle(
+ self.labelListDockName + f" ({self.labelList.count()})"
+ )
def remLabels(self, shapes):
if shapes is None:
@@ -1157,7 +1562,6 @@ def loadLabels(self, shapes):
for label, points, line_color, key_cls, difficult in shapes:
shape = Shape(label=label, line_color=line_color, key_cls=key_cls)
for x, y in points:
-
# Ensure the labels are within the bounds of the image. If not, fix them.
x, y, snapped = self.canvas.snapPointToCanvas(x, y)
if snapped:
@@ -1192,7 +1596,9 @@ def singleLabel(self, shape):
def updateComboBox(self):
# Get the unique labels and add them to the Combobox.
- itemsTextList = [str(self.labelList.item(i).text()) for i in range(self.labelList.count())]
+ itemsTextList = [
+ str(self.labelList.item(i).text()) for i in range(self.labelList.count())
+ ]
uniqueTextList = list(set(itemsTextList))
# Add a null row for showing all the labels
@@ -1208,23 +1614,29 @@ def updateIndexList(self):
string.setTextAlignment(Qt.AlignHCenter)
self.indexList.addItem(string)
- def saveLabels(self, annotationFilePath, mode='Auto'):
+ def saveLabels(self, annotationFilePath, mode="Auto"):
# In Auto mode, labels are loaded entirely from self.result_dic, the output of the OCR model
annotationFilePath = ustr(annotationFilePath)
def format_shape(s):
# print('s in saveLabels is ',s)
- return dict(label=s.label, # str
- line_color=s.line_color.getRgb(),
- fill_color=s.fill_color.getRgb(),
- points=[(int(p.x()), int(p.y())) for p in s.points], # QPonitF
- difficult=s.difficult,
- key_cls=s.key_cls) # bool
-
- if mode == 'Auto':
+ return dict(
+ label=s.label, # str
+ line_color=s.line_color.getRgb(),
+ fill_color=s.fill_color.getRgb(),
+ points=[(int(p.x()), int(p.y())) for p in s.points], # QPointF
+ difficult=s.difficult, # bool
+ key_cls=s.key_cls,
+ )
+
+ if mode == "Auto":
shapes = []
else:
- shapes = [format_shape(shape) for shape in self.canvas.shapes if shape.line_color != DEFAULT_LOCK_COLOR]
+ shapes = [
+ format_shape(shape)
+ for shape in self.canvas.shapes
+ if shape.line_color != DEFAULT_LOCK_COLOR
+ ]
# Can add different annotation formats here
for box in self.result_dic:
trans_dic = {"label": box[1][0], "points": box[0], "difficult": False}
@@ -1233,19 +1645,23 @@ def format_shape(s):
trans_dic.update({"key_cls": box[2]})
else:
trans_dic.update({"key_cls": "None"})
- if trans_dic["label"] == "" and mode == 'Auto':
+ if trans_dic["label"] == "" and mode == "Auto":
continue
shapes.append(trans_dic)
try:
trans_dic = []
for box in shapes:
- trans_dict = {"transcription": box['label'], "points": box['points'], "difficult": box['difficult']}
+ trans_dict = {
+ "transcription": box["label"],
+ "points": box["points"],
+ "difficult": box["difficult"],
+ }
if self.kie_mode:
- trans_dict.update({"key_cls": box['key_cls']})
+ trans_dict.update({"key_cls": box["key_cls"]})
trans_dic.append(trans_dict)
self.PPlabel[annotationFilePath] = trans_dic
- if mode == 'Auto':
+ if mode == "Auto":
self.Cachelabel[annotationFilePath] = trans_dic
# else:
@@ -1254,7 +1670,7 @@ def format_shape(s):
# print('Image:{0} -> Annotation:{1}'.format(self.filePath, annotationFilePath))
return True
except:
- self.errorMessage(u'Error saving label data', u'Error saving label data')
+ self.errorMessage("Error saving label data", "Error saving label data")
return False
def copySelectedShape(self):
@@ -1321,15 +1737,19 @@ def labelItemChanged(self, item):
shape.difficult = True if item.checkState() == Qt.Unchecked else False
self.setDirty()
else: # User probably changed item visibility
- self.canvas.setShapeVisible(shape, True) # item.checkState() == Qt.Checked
+ self.canvas.setShapeVisible(
+ shape, True
+ ) # item.checkState() == Qt.Checked
# self.actions.save.setEnabled(True)
else:
- print('enter labelItemChanged slot with unhashable item: ', item, item.text())
-
+ print(
+ "enter labelItemChanged slot with unhashable item: ", item, item.text()
+ )
+
def drag_drop_happened(self):
- '''
+ """
label list drag drop signal slot
- '''
+ """
# print('___________________drag_drop_happened_______________')
# should only select single item
for item in self.labelList.selectedItems():
@@ -1339,14 +1759,14 @@ def drag_drop_happened(self):
assert len(self.canvas.selectedShapes) > 0
for shape in self.canvas.selectedShapes:
selectedShapeIndex = shape.idx
-
+
if newIndex == selectedShapeIndex:
return
# move corresponding item in shape list
shape = self.canvas.shapes.pop(selectedShapeIndex)
self.canvas.shapes.insert(newIndex, shape)
-
+
# update bbox index
self.canvas.updateShapeIndex()
@@ -1373,13 +1793,17 @@ def newShape(self, value=True):
text = self.prevLabelText
if text is not None:
- self.prevLabelText = self.stringBundle.getString('tempLabel')
+ self.prevLabelText = self.stringBundle.getString("tempLabel")
- shape = self.canvas.setLastLabel(text, None, None, None) # generate_color, generate_color
+ shape = self.canvas.setLastLabel(
+ text, None, None, None
+ ) # generate_color, generate_color
if self.kie_mode:
key_text, _ = self.keyDialog.popUp(self.key_previous_text)
if key_text is not None:
- shape = self.canvas.setLastLabel(text, None, None, key_text) # generate_color, generate_color
+ shape = self.canvas.setLastLabel(
+ text, None, None, key_text
+ ) # generate_color, generate_color
self.key_previous_text = key_text
if not self.keyList.findItemsByLabel(key_text):
item = self.keyList.createItemFromLabel(key_text)
@@ -1425,7 +1849,7 @@ def _get_rgb_by_label(self, label, kie_mode):
return (0, 255, 0)
def scrollRequest(self, delta, orientation):
- units = - delta / (8 * 15)
+ units = -delta / (8 * 15)
bar = self.scrollBars[orientation]
bar.setValue(int(bar.value() + bar.singleStep() * units))
@@ -1437,7 +1861,9 @@ def setZoom(self, value):
def addZoom(self, increment=10):
self.setZoom(int(self.zoomWidget.value() + increment))
- self.imageSlider.setValue(int(self.zoomWidget.value() + increment)) # set zoom slider value
+ self.imageSlider.setValue(
+ int(self.zoomWidget.value() + increment)
+ ) # set zoom slider value
def zoomRequest(self, delta):
# get the current scrollbar positions
@@ -1528,7 +1954,7 @@ def loadFile(self, filePath=None):
if unicodeFilePath in self.mImgList:
index = self.mImgList.index(unicodeFilePath)
fileWidgetItem = self.fileListWidget.item(index)
- print('unicodeFilePath is', unicodeFilePath)
+ print("unicodeFilePath is", unicodeFilePath)
fileWidgetItem.setSelected(True)
self.iconlist.clear()
self.additems5(None)
@@ -1554,11 +1980,15 @@ def loadFile(self, filePath=None):
cvimg = cv2.imdecode(np.fromfile(unicodeFilePath, dtype=np.uint8), 1)
height, width, depth = cvimg.shape
cvimg = cv2.cvtColor(cvimg, cv2.COLOR_BGR2RGB)
- image = QImage(cvimg.data, width, height, width * depth, QImage.Format_RGB888)
+ image = QImage(
+ cvimg.data, width, height, width * depth, QImage.Format_RGB888
+ )
if image.isNull():
- self.errorMessage(u'Error opening file',
- u"<p>Make sure <i>%s</i> is a valid image file." % unicodeFilePath)
+ self.errorMessage(
+ "Error opening file",
+ "<p>Make sure <i>%s</i> is a valid image file." % unicodeFilePath,
+ )
self.status("Error reading %s" % unicodeFilePath)
return False
self.status("Loaded %s" % os.path.basename(unicodeFilePath))
@@ -1582,22 +2012,30 @@ def loadFile(self, filePath=None):
self.showBoundingBoxFromPPlabel(filePath)
- self.setWindowTitle(__appname__ + ' ' + filePath)
+ self.setWindowTitle(__appname__ + " " + filePath)
# Default : select last item if there is at least one item
if self.labelList.count():
- self.labelList.setCurrentItem(self.labelList.item(self.labelList.count() - 1))
+ self.labelList.setCurrentItem(
+ self.labelList.item(self.labelList.count() - 1)
+ )
self.labelList.item(self.labelList.count() - 1).setSelected(True)
self.indexList.item(self.labelList.count() - 1).setSelected(True)
# show file list image count
select_indexes = self.fileListWidget.selectedIndexes()
if len(select_indexes) > 0:
- self.fileDock.setWindowTitle(self.fileListName + f" ({select_indexes[0].row() + 1}"
- f"/{self.fileListWidget.count()})")
+ self.fileDock.setWindowTitle(
+ self.fileListName + f" ({select_indexes[0].row() + 1}"
+ f"/{self.fileListWidget.count()})"
+ )
# update show counting
- self.BoxListDock.setWindowTitle(self.BoxListDockName + f" ({self.BoxList.count()})")
- self.labelListDock.setWindowTitle(self.labelListDockName + f" ({self.labelList.count()})")
+ self.BoxListDock.setWindowTitle(
+ self.BoxListDockName + f" ({self.BoxList.count()})"
+ )
+ self.labelListDock.setWindowTitle(
+ self.labelListDockName + f" ({self.labelList.count()})"
+ )
self.canvas.setFocus(True)
return True
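
The `np.fromfile` + `cv2.imdecode` pair in this hunk is the usual workaround for `cv2.imread` failing on non-ASCII paths (e.g. Chinese filenames on Windows): the file is read as raw bytes first, then decoded in memory. A standalone sketch of the pattern:

    import cv2
    import numpy as np

    def imread_unicode(path):
        """Read an image from a path that may contain non-ASCII characters."""
        data = np.fromfile(path, dtype=np.uint8)     # byte-level read works for any path
        return cv2.imdecode(data, cv2.IMREAD_COLOR)  # decode the bytes to a BGR image
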
@@ -1610,17 +2048,39 @@ def showBoundingBoxFromPPlabel(self, filePath):
# box['ratio'] of the shapes saved in lockedShapes contains the ratio of the
# four corner coordinates of the shapes to the height and width of the image
for box in self.canvas.lockedShapes:
- key_cls = 'None' if not self.kie_mode else box['key_cls']
+ key_cls = "None" if not self.kie_mode else box["key_cls"]
if self.canvas.isInTheSameImage:
- shapes.append((box['transcription'], [[s[0] * width, s[1] * height] for s in box['ratio']],
- DEFAULT_LOCK_COLOR, key_cls, box['difficult']))
+ shapes.append(
+ (
+ box["transcription"],
+ [[s[0] * width, s[1] * height] for s in box["ratio"]],
+ DEFAULT_LOCK_COLOR,
+ key_cls,
+ box["difficult"],
+ )
+ )
else:
- shapes.append(('锁定框:待检测', [[s[0] * width, s[1] * height] for s in box['ratio']],
- DEFAULT_LOCK_COLOR, key_cls, box['difficult']))
+ shapes.append(
+ (
+ "锁定框:待检测",
+ [[s[0] * width, s[1] * height] for s in box["ratio"]],
+ DEFAULT_LOCK_COLOR,
+ key_cls,
+ box["difficult"],
+ )
+ )
if imgidx in self.PPlabel.keys():
for box in self.PPlabel[imgidx]:
- key_cls = 'None' if not self.kie_mode else box.get('key_cls', 'None')
- shapes.append((box['transcription'], box['points'], None, key_cls, box.get('difficult', False)))
+ key_cls = "None" if not self.kie_mode else box.get("key_cls", "None")
+ shapes.append(
+ (
+ box["transcription"],
+ box["points"],
+ None,
+ key_cls,
+ box.get("difficult", False),
+ )
+ )
if shapes != []:
self.loadLabels(shapes)
@@ -1635,8 +2095,11 @@ def validFilestate(self, filePath):
return False
def resizeEvent(self, event):
- if self.canvas and not self.image.isNull() \
- and self.zoomMode != self.MANUAL_ZOOM:
+ if (
+ self.canvas
+ and not self.image.isNull()
+ and self.zoomMode != self.MANUAL_ZOOM
+ ):
self.adjustScale()
super(MainWindow, self).resizeEvent(event)
@@ -1675,9 +2138,9 @@ def closeEvent(self, event):
settings = self.settings
# If it loads images from dir, don't load it at the beginning
if self.dirname is None:
- settings[SETTING_FILENAME] = self.filePath if self.filePath else ''
+ settings[SETTING_FILENAME] = self.filePath if self.filePath else ""
else:
- settings[SETTING_FILENAME] = ''
+ settings[SETTING_FILENAME] = ""
settings[SETTING_WIN_SIZE] = self.size()
settings[SETTING_WIN_POSE] = self.pos()
@@ -1689,12 +2152,12 @@ def closeEvent(self, event):
if self.defaultSaveDir and os.path.exists(self.defaultSaveDir):
settings[SETTING_SAVE_DIR] = ustr(self.defaultSaveDir)
else:
- settings[SETTING_SAVE_DIR] = ''
+ settings[SETTING_SAVE_DIR] = ""
if self.lastOpenDir and os.path.exists(self.lastOpenDir):
settings[SETTING_LAST_OPEN_DIR] = self.lastOpenDir
else:
- settings[SETTING_LAST_OPEN_DIR] = ''
+ settings[SETTING_LAST_OPEN_DIR] = ""
settings[SETTING_PAINT_LABEL] = self.displayLabelOption.isChecked()
settings[SETTING_PAINT_INDEX] = self.displayIndexOption.isChecked()
@@ -1711,7 +2174,10 @@ def loadRecent(self, filename):
self.loadFile(filename)
def scanAllImages(self, folderPath):
- extensions = ['.%s' % fmt.data().decode("ascii").lower() for fmt in QImageReader.supportedImageFormats()]
+ extensions = [
+ ".%s" % fmt.data().decode("ascii").lower()
+ for fmt in QImageReader.supportedImageFormats()
+ ]
images = []
for file in os.listdir(folderPath):
@@ -1726,16 +2192,22 @@ def openDirDialog(self, _value=False, dirpath=None, silent=False):
if not self.mayContinue():
return
- defaultOpenDirPath = dirpath if dirpath else '.'
+ defaultOpenDirPath = dirpath if dirpath else "."
if self.lastOpenDir and os.path.exists(self.lastOpenDir):
defaultOpenDirPath = self.lastOpenDir
else:
- defaultOpenDirPath = os.path.dirname(self.filePath) if self.filePath else '.'
+ defaultOpenDirPath = (
+ os.path.dirname(self.filePath) if self.filePath else "."
+ )
if silent != True:
- targetDirPath = ustr(QFileDialog.getExistingDirectory(self,
- '%s - Open Directory' % __appname__,
- defaultOpenDirPath,
- QFileDialog.ShowDirsOnly | QFileDialog.DontResolveSymlinks))
+ targetDirPath = ustr(
+ QFileDialog.getExistingDirectory(
+ self,
+ "%s - Open Directory" % __appname__,
+ defaultOpenDirPath,
+ QFileDialog.ShowDirsOnly | QFileDialog.DontResolveSymlinks,
+ )
+ )
else:
targetDirPath = ustr(defaultOpenDirPath)
self.lastOpenDir = targetDirPath
@@ -1743,21 +2215,26 @@ def openDirDialog(self, _value=False, dirpath=None, silent=False):
def openDatasetDirDialog(self):
if self.lastOpenDir and os.path.exists(self.lastOpenDir):
- if platform.system() == 'Windows':
+ if platform.system() == "Windows":
os.startfile(self.lastOpenDir)
else:
- os.system('open ' + os.path.normpath(self.lastOpenDir))
+ os.system("open " + os.path.normpath(self.lastOpenDir))
defaultOpenDirPath = self.lastOpenDir
else:
- if self.lang == 'ch':
+ if self.lang == "ch":
self.msgBox.warning(self, "提示", "\n 原文件夹已不存在,请从新选择数据集路径!")
else:
- self.msgBox.warning(self, "Warn",
- "\n The original folder no longer exists, please choose the data set path again!")
+ self.msgBox.warning(
+ self,
+ "Warn",
+ "\n The original folder no longer exists, please choose the data set path again!",
+ )
self.actions.open_dataset_dir.setEnabled(False)
- defaultOpenDirPath = os.path.dirname(self.filePath) if self.filePath else '.'
+ defaultOpenDirPath = (
+ os.path.dirname(self.filePath) if self.filePath else "."
+ )
def init_key_list(self, label_dict):
if not self.kie_mode:
@@ -1785,8 +2262,8 @@ def init_key_list(self, label_dict):
sort_labels=True,
show_text_field=True,
completion="startswith",
- fit_to_content={'column': True, 'row': False},
- flags=None
+ fit_to_content={"column": True, "row": False},
+ flags=None,
)
def importDirImages(self, dirpath, isDelete=False):
@@ -1797,9 +2274,9 @@ def importDirImages(self, dirpath, isDelete=False):
if not isDelete:
self.loadFilestate(dirpath)
- self.PPlabelpath = dirpath + '/Label.txt'
+ self.PPlabelpath = dirpath + "/Label.txt"
self.PPlabel = self.loadLabelFile(self.PPlabelpath)
- self.Cachelabelpath = dirpath + '/Cache.cach'
+ self.Cachelabelpath = dirpath + "/Cache.cach"
self.Cachelabel = self.loadLabelFile(self.Cachelabelpath)
if self.Cachelabel:
self.PPlabel = dict(self.Cachelabel, **self.PPlabel)
@@ -1810,8 +2287,10 @@ def importDirImages(self, dirpath, isDelete=False):
self.dirname = dirpath
self.defaultSaveDir = dirpath
- self.statusBar().showMessage('%s started. Annotation will be saved to %s' %
- (__appname__, self.defaultSaveDir))
+ self.statusBar().showMessage(
+ "%s started. Annotation will be saved to %s"
+ % (__appname__, self.defaultSaveDir)
+ )
self.statusBar().show()
self.filePath = None
@@ -1819,8 +2298,8 @@ def importDirImages(self, dirpath, isDelete=False):
self.mImgList = self.scanAllImages(dirpath)
self.mImgList5 = self.mImgList[:5]
self.openNextImg()
- doneicon = newIcon('done')
- closeicon = newIcon('close')
+ doneicon = newIcon("done")
+ closeicon = newIcon("close")
for imgPath in self.mImgList:
filename = os.path.basename(imgPath)
if self.validFilestate(imgPath) is True:
@@ -1829,7 +2308,7 @@ def importDirImages(self, dirpath, isDelete=False):
item = QListWidgetItem(closeicon, filename)
self.fileListWidget.addItem(item)
- print('DirPath in importDirImages is', dirpath)
+ print("DirPath in importDirImages is", dirpath)
self.iconlist.clear()
self.additems5(dirpath)
self.changeFileFolder = True
@@ -1845,7 +2324,9 @@ def importDirImages(self, dirpath, isDelete=False):
self.actions.rotateRight.setEnabled(True)
self.fileListWidget.setCurrentRow(0) # set list index to first
- self.fileDock.setWindowTitle(self.fileListName + f" (1/{self.fileListWidget.count()})") # show image count
+ self.fileDock.setWindowTitle(
+ self.fileListName + f" (1/{self.fileListWidget.count()})"
+ ) # show image count
def openPrevImg(self, _value=False):
if len(self.mImgList) <= 0:
@@ -1881,13 +2362,13 @@ def openNextImg(self, _value=False):
else:
self.mImgList5 = self.indexTo5Files(currIndex)
if filename:
- print('file name in openNext is ', filename)
+ print("file name in openNext is ", filename)
self.loadFile(filename)
def updateFileListIcon(self, filename):
pass
- def saveFile(self, _value=False, mode='Manual'):
+ def saveFile(self, _value=False, mode="Manual"):
        # Manual mode is used when users click "Save" manually, which will change the state of the image
if self.filePath:
imgidx = self.getImglabelidx(self.filePath)
@@ -1905,46 +2386,46 @@ def saveLockedShapes(self):
self.canvas.selectedShapes.remove(s)
self.canvas.shapes.remove(s)
- def _saveFile(self, annotationFilePath, mode='Manual'):
+ def _saveFile(self, annotationFilePath, mode="Manual"):
if len(self.canvas.lockedShapes) != 0:
self.saveLockedShapes()
- if mode == 'Manual':
+ if mode == "Manual":
self.result_dic_locked = []
img = cv2.imread(self.filePath)
width, height = self.image.width(), self.image.height()
for shape in self.canvas.lockedShapes:
- box = [[int(p[0] * width), int(p[1] * height)] for p in shape['ratio']]
+ box = [[int(p[0] * width), int(p[1] * height)] for p in shape["ratio"]]
# assert len(box) == 4
- result = [(shape['transcription'], 1)]
+ result = [(shape["transcription"], 1)]
result.insert(0, box)
self.result_dic_locked.append(result)
self.result_dic += self.result_dic_locked
self.result_dic_locked = []
if annotationFilePath and self.saveLabels(annotationFilePath, mode=mode):
self.setClean()
- self.statusBar().showMessage('Saved to %s' % annotationFilePath)
+ self.statusBar().showMessage("Saved to %s" % annotationFilePath)
self.statusBar().show()
currIndex = self.mImgList.index(self.filePath)
item = self.fileListWidget.item(currIndex)
- item.setIcon(newIcon('done'))
+ item.setIcon(newIcon("done"))
self.fileStatedict[self.filePath] = 1
if len(self.fileStatedict) % self.autoSaveNum == 0:
self.saveFilestate()
- self.savePPlabel(mode='Auto')
+ self.savePPlabel(mode="Auto")
self.fileListWidget.insertItem(int(currIndex), item)
if not self.canvas.isInTheSameImage:
self.openNextImg()
self.actions.saveRec.setEnabled(True)
self.actions.saveLabel.setEnabled(True)
- self.actions.exportJSON.setEnabled(True)
+ self.actions.exportJSON.setEnabled(True)
- elif mode == 'Auto':
+ elif mode == "Auto":
if annotationFilePath and self.saveLabels(annotationFilePath, mode=mode):
self.setClean()
- self.statusBar().showMessage('Saved to %s' % annotationFilePath)
+ self.statusBar().showMessage("Saved to %s" % annotationFilePath)
self.statusBar().show()
def closeFile(self, _value=False):
@@ -1961,24 +2442,34 @@ def deleteImg(self):
if deletePath is not None:
deleteInfo = self.deleteImgDialog()
if deleteInfo == QMessageBox.Yes:
- if platform.system() == 'Windows':
+ if platform.system() == "Windows":
# from win32com import shell, shellcon
# shell.SHFileOperation((0, shellcon.FO_DELETE, deletePath, None,
# shellcon.FOF_SILENT | shellcon.FOF_ALLOWUNDO | shellcon.FOF_NOCONFIRMATION,
# None, None))
os.remove(deletePath)
# linux
- elif platform.system() == 'Linux':
- cmd = 'trash ' + deletePath
+ elif platform.system() == "Linux":
+ cmd = "trash " + deletePath
os.system(cmd)
# macOS
- elif platform.system() == 'Darwin':
+ elif platform.system() == "Darwin":
import subprocess
- absPath = os.path.abspath(deletePath).replace('\\', '\\\\').replace('"', '\\"')
- cmd = ['osascript', '-e',
- 'tell app "Finder" to move {the POSIX file "' + absPath + '"} to trash']
+
+ absPath = (
+ os.path.abspath(deletePath)
+ .replace("\\", "\\\\")
+ .replace('"', '\\"')
+ )
+ cmd = [
+ "osascript",
+ "-e",
+ 'tell app "Finder" to move {the POSIX file "'
+ + absPath
+ + '"} to trash',
+ ]
print(cmd)
- subprocess.call(cmd, stdout=open(os.devnull, 'w'))
+ subprocess.call(cmd, stdout=open(os.devnull, "w"))
if self.filePath in self.fileStatedict.keys():
self.fileStatedict.pop(self.filePath)
@@ -1990,8 +2481,8 @@ def deleteImg(self):
def deleteImgDialog(self):
yes, cancel = QMessageBox.Yes, QMessageBox.Cancel
- msg = u'The image will be deleted to the recycle bin'
- return QMessageBox.warning(self, u'Attention', msg, yes | cancel)
+ msg = "The image will be deleted to the recycle bin"
+ return QMessageBox.warning(self, "Attention", msg, yes | cancel)
def resetAll(self):
self.settings.reset()
@@ -2016,22 +2507,24 @@ def mayContinue(self): #
def discardChangesDialog(self):
yes, no, cancel = QMessageBox.Yes, QMessageBox.No, QMessageBox.Cancel
- if self.lang == 'ch':
- msg = u'您有未保存的变更, 您想保存再继续吗?\n点击 "No" 丢弃所有未保存的变更.'
+ if self.lang == "ch":
+ msg = '您有未保存的变更, 您想保存再继续吗?\n点击 "No" 丢弃所有未保存的变更.'
else:
- msg = u'You have unsaved changes, would you like to save them and proceed?\nClick "No" to undo all changes.'
- return QMessageBox.warning(self, u'Attention', msg, yes | no | cancel)
+ msg = 'You have unsaved changes, would you like to save them and proceed?\nClick "No" to undo all changes.'
+ return QMessageBox.warning(self, "Attention", msg, yes | no | cancel)
def errorMessage(self, title, message):
-        return QMessageBox.critical(self, title,
-                                    '<p><b>%s</b></p>%s' % (title, message))
+        return QMessageBox.critical(
+            self, title, "<p><b>%s</b></p>%s" % (title, message)
+        )
def currentPath(self):
- return os.path.dirname(self.filePath) if self.filePath else '.'
+ return os.path.dirname(self.filePath) if self.filePath else "."
def chooseColor(self):
- color = self.colorDialog.getColor(self.lineColor, u'Choose line color',
- default=DEFAULT_LINE_COLOR)
+ color = self.colorDialog.getColor(
+ self.lineColor, "Choose line color", default=DEFAULT_LINE_COLOR
+ )
if color:
self.lineColor = color
Shape.line_color = color
@@ -2046,22 +2539,30 @@ def deleteSelectedShape(self):
if self.noShapes():
for action in self.actions.onShapesPresent:
action.setEnabled(False)
- self.BoxListDock.setWindowTitle(self.BoxListDockName + f" ({self.BoxList.count()})")
- self.labelListDock.setWindowTitle(self.labelListDockName + f" ({self.labelList.count()})")
+ self.BoxListDock.setWindowTitle(
+ self.BoxListDockName + f" ({self.BoxList.count()})"
+ )
+ self.labelListDock.setWindowTitle(
+ self.labelListDockName + f" ({self.labelList.count()})"
+ )
def chshapeLineColor(self):
- color = self.colorDialog.getColor(self.lineColor, u'Choose line color',
- default=DEFAULT_LINE_COLOR)
+ color = self.colorDialog.getColor(
+ self.lineColor, "Choose line color", default=DEFAULT_LINE_COLOR
+ )
if color:
- for shape in self.canvas.selectedShapes: shape.line_color = color
+ for shape in self.canvas.selectedShapes:
+ shape.line_color = color
self.canvas.update()
self.setDirty()
def chshapeFillColor(self):
- color = self.colorDialog.getColor(self.fillColor, u'Choose fill color',
- default=DEFAULT_FILL_COLOR)
+ color = self.colorDialog.getColor(
+ self.fillColor, "Choose fill color", default=DEFAULT_FILL_COLOR
+ )
if color:
- for shape in self.canvas.selectedShapes: shape.fill_color = color
+ for shape in self.canvas.selectedShapes:
+ shape.fill_color = color
self.canvas.update()
self.setDirty()
@@ -2076,7 +2577,7 @@ def moveShape(self):
def loadPredefinedClasses(self, predefClassesFile):
if os.path.exists(predefClassesFile) is True:
- with codecs.open(predefClassesFile, 'r', 'utf8') as f:
+ with codecs.open(predefClassesFile, "r", "utf8") as f:
for line in f:
line = line.strip()
if self.labelHist is None:
@@ -2106,8 +2607,12 @@ def additems(self, dirpath):
pix = QPixmap(file)
_, filename = os.path.split(file)
filename, _ = os.path.splitext(filename)
- item = QListWidgetItem(QIcon(pix.scaled(100, 100, Qt.IgnoreAspectRatio, Qt.FastTransformation)),
- filename[:10])
+ item = QListWidgetItem(
+ QIcon(
+ pix.scaled(100, 100, Qt.IgnoreAspectRatio, Qt.FastTransformation)
+ ),
+ filename[:10],
+ )
item.setToolTip(file)
self.iconlist.addItem(item)
@@ -2122,7 +2627,12 @@ def additems5(self, dirpath):
prelen = lentoken // 2
bfilename = prelen * " " + pfilename + (lentoken - prelen) * " "
# item = QListWidgetItem(QIcon(pix.scaled(100, 100, Qt.KeepAspectRatio, Qt.SmoothTransformation)),filename[:10])
- item = QListWidgetItem(QIcon(pix.scaled(100, 100, Qt.IgnoreAspectRatio, Qt.FastTransformation)), pfilename)
+ item = QListWidgetItem(
+ QIcon(
+ pix.scaled(100, 100, Qt.IgnoreAspectRatio, Qt.FastTransformation)
+ ),
+ pfilename,
+ )
# item.setForeground(QBrush(Qt.white))
item.setToolTip(file)
self.iconlist.addItem(item)
@@ -2139,17 +2649,20 @@ def gen_quad_from_poly(self, poly):
"""
point_num = poly.shape[0]
min_area_quad = np.zeros((4, 2), dtype=np.float32)
- rect = cv2.minAreaRect(poly.astype(
- np.int32)) # (center (x,y), (width, height), angle of rotation)
+ rect = cv2.minAreaRect(
+ poly.astype(np.int32)
+ ) # (center (x,y), (width, height), angle of rotation)
box = np.array(cv2.boxPoints(rect))
first_point_idx = 0
min_dist = 1e4
for i in range(4):
- dist = np.linalg.norm(box[(i + 0) % 4] - poly[0]) + \
- np.linalg.norm(box[(i + 1) % 4] - poly[point_num // 2 - 1]) + \
- np.linalg.norm(box[(i + 2) % 4] - poly[point_num // 2]) + \
- np.linalg.norm(box[(i + 3) % 4] - poly[-1])
+ dist = (
+ np.linalg.norm(box[(i + 0) % 4] - poly[0])
+ + np.linalg.norm(box[(i + 1) % 4] - poly[point_num // 2 - 1])
+ + np.linalg.norm(box[(i + 2) % 4] - poly[point_num // 2])
+ + np.linalg.norm(box[(i + 3) % 4] - poly[-1])
+ )
if dist < min_dist:
min_dist = dist
first_point_idx = i
@@ -2166,19 +2679,21 @@ def gen_quad_from_poly(self, poly):
return bbox
def getImglabelidx(self, filePath):
- if platform.system() == 'Windows':
- spliter = '\\'
+ if platform.system() == "Windows":
+ spliter = "\\"
else:
- spliter = '/'
+ spliter = "/"
filepathsplit = filePath.split(spliter)[-2:]
- return filepathsplit[0] + '/' + filepathsplit[1]
+ return filepathsplit[0] + "/" + filepathsplit[1]
def autoRecognition(self):
assert self.mImgList is not None
- print('Using model from ', self.model)
+ print("Using model from ", self.model)
uncheckedList = [i for i in self.mImgList if i not in self.fileStatedict.keys()]
- self.autoDialog = AutoDialog(parent=self, ocr=self.ocr, mImgList=uncheckedList, lenbar=len(uncheckedList))
+ self.autoDialog = AutoDialog(
+ parent=self, ocr=self.ocr, mImgList=uncheckedList, lenbar=len(uncheckedList)
+ )
self.autoDialog.popUp()
self.currIndex = len(self.mImgList) - 1
self.loadFile(self.filePath) # ADD
@@ -2191,11 +2706,13 @@ def autoRecognition(self):
self.init_key_list(self.Cachelabel)
def reRecognition(self):
- img = cv2.imdecode(np.fromfile(self.filePath,dtype=np.uint8),1)
+ img = cv2.imdecode(np.fromfile(self.filePath, dtype=np.uint8), 1)
# org_box = [dic['points'] for dic in self.PPlabel[self.getImglabelidx(self.filePath)]]
if self.canvas.shapes:
self.result_dic = []
- self.result_dic_locked = [] # result_dic_locked stores the ocr result of self.canvas.lockedShapes
+ self.result_dic_locked = (
+ []
+ ) # result_dic_locked stores the ocr result of self.canvas.lockedShapes
rec_flag = 0
for shape in self.canvas.shapes:
box = [[int(p.x()), int(p.y())] for p in shape.points]
@@ -2207,11 +2724,15 @@ def reRecognition(self):
img_crop = get_rotate_crop_image(img, np.array(box, np.float32))
if img_crop is None:
- msg = 'Can not recognise the detection box in ' + self.filePath + '. Please change manually'
+ msg = (
+ "Can not recognise the detection box in "
+ + self.filePath
+ + ". Please change manually"
+ )
QMessageBox.information(self, "Information", msg)
return
result = self.ocr.ocr(img_crop, cls=True, det=False)[0]
- if result[0][0] != '':
+ if result[0][0] != "":
if shape.line_color == DEFAULT_LOCK_COLOR:
shape.label = result[0][0]
result.insert(0, box)
@@ -2224,43 +2745,49 @@ def reRecognition(self):
result.append(kie_cls)
self.result_dic.append(result)
else:
- print('Can not recognise the box')
+ print("Can not recognise the box")
if shape.line_color == DEFAULT_LOCK_COLOR:
shape.label = result[0][0]
if self.kie_mode:
- self.result_dic_locked.append([box, (self.noLabelText, 0), kie_cls])
+ self.result_dic_locked.append(
+ [box, (self.noLabelText, 0), kie_cls]
+ )
else:
self.result_dic_locked.append([box, (self.noLabelText, 0)])
else:
if self.kie_mode:
- self.result_dic.append([box, (self.noLabelText, 0), kie_cls])
+ self.result_dic.append(
+ [box, (self.noLabelText, 0), kie_cls]
+ )
else:
self.result_dic.append([box, (self.noLabelText, 0)])
try:
if self.noLabelText == shape.label or result[1][0] == shape.label:
- print('label no change')
+ print("label no change")
else:
rec_flag += 1
except IndexError as e:
- print('Can not recognise the box')
+ print("Can not recognise the box")
if (len(self.result_dic) > 0 and rec_flag > 0) or self.canvas.lockedShapes:
self.canvas.isInTheSameImage = True
- self.saveFile(mode='Auto')
+ self.saveFile(mode="Auto")
self.loadFile(self.filePath)
self.canvas.isInTheSameImage = False
self.setDirty()
elif len(self.result_dic) == len(self.canvas.shapes) and rec_flag == 0:
- if self.lang == 'ch':
+ if self.lang == "ch":
QMessageBox.information(self, "Information", "识别结果保持一致!")
else:
- QMessageBox.information(self, "Information", "The recognition result remains unchanged!")
+ QMessageBox.information(
+ self, "Information", "The recognition result remains unchanged!"
+ )
else:
- print('Can not recgonise in ', self.filePath)
+ print("Can not recgonise in ", self.filePath)
else:
QMessageBox.information(self, "Information", "Draw a box!")
def singleRerecognition(self):
- img = cv2.imdecode(np.fromfile(self.filePath,dtype=np.uint8),1)
+ img = cv2.imdecode(np.fromfile(self.filePath, dtype=np.uint8), 1)
for shape in self.canvas.selectedShapes:
box = [[int(p.x()), int(p.y())] for p in shape.points]
if len(box) > 4:
@@ -2268,30 +2795,34 @@ def singleRerecognition(self):
assert len(box) == 4
img_crop = get_rotate_crop_image(img, np.array(box, np.float32))
if img_crop is None:
- msg = 'Can not recognise the detection box in ' + self.filePath + '. Please change manually'
+ msg = (
+ "Can not recognise the detection box in "
+ + self.filePath
+ + ". Please change manually"
+ )
QMessageBox.information(self, "Information", msg)
return
result = self.ocr.ocr(img_crop, cls=True, det=False)[0]
- if result[0][0] != '':
+ if result[0][0] != "":
result.insert(0, box)
- print('result in reRec is ', result)
+ print("result in reRec is ", result)
if result[1][0] == shape.label:
- print('label no change')
+ print("label no change")
else:
shape.label = result[1][0]
else:
- print('Can not recognise the box')
+ print("Can not recognise the box")
if self.noLabelText == shape.label:
- print('label no change')
+ print("label no change")
else:
shape.label = self.noLabelText
self.singleLabel(shape)
self.setDirty()
def TableRecognition(self):
- '''
- Table Recegnition
- '''
+ """
+        Table Recognition
+ """
from paddleocr import to_excel
import time
@@ -2300,27 +2831,35 @@ def TableRecognition(self):
img = cv2.imread(self.filePath)
res = self.table_ocr(img, return_ocr_result_in_table=True)
- TableRec_excel_dir = self.lastOpenDir + '/tableRec_excel_output/'
+ TableRec_excel_dir = self.lastOpenDir + "/tableRec_excel_output/"
os.makedirs(TableRec_excel_dir, exist_ok=True)
filename, _ = os.path.splitext(os.path.basename(self.filePath))
- excel_path = TableRec_excel_dir + '{}.xlsx'.format(filename)
-
+ excel_path = TableRec_excel_dir + "{}.xlsx".format(filename)
+
if res is None:
- msg = 'Can not recognise the table in ' + self.filePath + '. Please change manually'
+ msg = (
+ "Can not recognise the table in "
+ + self.filePath
+ + ". Please change manually"
+ )
QMessageBox.information(self, "Information", msg)
- to_excel('', excel_path) # create an empty excel
+ to_excel("", excel_path) # create an empty excel
return
-
+
# save res
# ONLY SUPPORT ONE TABLE in one image
hasTable = False
for region in res:
- if region['type'] == 'table':
- if region['res']['boxes'] is None:
- msg = 'Can not recognise the detection box in ' + self.filePath + '. Please change manually'
+ if region["type"] == "table":
+ if region["res"]["boxes"] is None:
+ msg = (
+ "Can not recognise the detection box in "
+ + self.filePath
+ + ". Please change manually"
+ )
QMessageBox.information(self, "Information", msg)
- to_excel('', excel_path) # create an empty excel
+ to_excel("", excel_path) # create an empty excel
return
hasTable = True
# save table ocr result on PPOCRLabel
@@ -2336,19 +2875,26 @@ def TableRecognition(self):
self.result_dic_locked = []
shapes = []
- result_len = len(region['res']['boxes'])
+ result_len = len(region["res"]["boxes"])
order_index = 0
for i in range(result_len):
- bbox = np.array(region['res']['boxes'][i])
- rec_text = region['res']['rec_res'][i][0]
+ bbox = np.array(region["res"]["boxes"][i])
+ rec_text = region["res"]["rec_res"][i][0]
- rext_bbox = [[bbox[0], bbox[1]], [bbox[2], bbox[1]], [bbox[2], bbox[3]], [bbox[0], bbox[3]]]
+ rext_bbox = [
+ [bbox[0], bbox[1]],
+ [bbox[2], bbox[1]],
+ [bbox[2], bbox[3]],
+ [bbox[0], bbox[3]],
+ ]
# save bbox to shape
- shape = Shape(label=rec_text, line_color=DEFAULT_LINE_COLOR, key_cls=None)
+ shape = Shape(
+ label=rec_text, line_color=DEFAULT_LINE_COLOR, key_cls=None
+ )
for point in rext_bbox:
x, y = point
- # Ensure the labels are within the bounds of the image.
+ # Ensure the labels are within the bounds of the image.
# If not, fix them.
x, y, snapped = self.canvas.snapPointToCanvas(x, y)
shape.addPoint(QPointF(x, y))
@@ -2361,27 +2907,35 @@ def TableRecognition(self):
shapes.append(shape)
self.setDirty()
self.canvas.loadShapes(shapes)
-
+
# save HTML result to excel
try:
- to_excel(region['res']['html'], excel_path)
+ to_excel(region["res"]["html"], excel_path)
except:
- print('Can not save excel file, maybe Permission denied (.xlsx is being occupied)')
+ print(
+ "Can not save excel file, maybe Permission denied (.xlsx is being occupied)"
+ )
break
-
+
if not hasTable:
- msg = 'Can not recognise the table in ' + self.filePath + '. Please change manually'
+ msg = (
+ "Can not recognise the table in "
+ + self.filePath
+ + ". Please change manually"
+ )
QMessageBox.information(self, "Information", msg)
- to_excel('', excel_path) # create an empty excel
+ to_excel("", excel_path) # create an empty excel
return
# automatically open excel annotation file
- if platform.system() == 'Windows':
+ if platform.system() == "Windows":
try:
import win32com.client
except:
- print("CANNOT OPEN .xlsx. It could be one of the following reasons: " \
- "Only support Windows | No python win32com")
+ print(
+ "CANNOT OPEN .xlsx. It could be one of the following reasons: "
+ "Only support Windows | No python win32com"
+ )
try:
xl = win32com.client.Dispatch("Excel.Application")
@@ -2392,17 +2946,19 @@ def TableRecognition(self):
# os.startfile(excel_path)
except:
- print("CANNOT OPEN .xlsx. It could be the following reasons: " \
- ".xlsx is not existed")
+ print(
+ "CANNOT OPEN .xlsx. It could be the following reasons: "
+ ".xlsx is not existed"
+ )
else:
- os.system('open ' + os.path.normpath(excel_path))
-
- print('time cost: ', time.time() - start)
+ os.system("open " + os.path.normpath(excel_path))
+
+ print("time cost: ", time.time() - start)
def cellreRecognition(self):
- '''
- re-recognise text in a cell
- '''
+ """
+ re-recognise text in a cell
+ """
img = cv2.imread(self.filePath)
for shape in self.canvas.selectedShapes:
box = [[int(p.x()), int(p.y())] for p in shape.points]
@@ -2415,90 +2971,95 @@ def cellreRecognition(self):
_box = boxPad(box, img.shape, 6)
img_crop = get_rotate_crop_image(img, np.array(_box, np.float32))
if img_crop is None:
- msg = 'Can not recognise the detection box in ' + self.filePath + '. Please change manually'
+ msg = (
+ "Can not recognise the detection box in "
+ + self.filePath
+ + ". Please change manually"
+ )
QMessageBox.information(self, "Information", msg)
return
# merge the text result in the cell
- texts = ''
- probs = 0. # the probability of the cell is avgerage prob of every text box in the cell
+ texts = ""
+        probs = 0.0  # the probability of the cell is the average prob of every text box in the cell
bboxes = self.ocr.ocr(img_crop, det=True, rec=False, cls=False)[0]
if len(bboxes) > 0:
- bboxes.reverse() # top row text at first
+ bboxes.reverse() # top row text at first
for _bbox in bboxes:
patch = get_rotate_crop_image(img_crop, np.array(_bbox, np.float32))
rec_res = self.ocr.ocr(patch, det=False, rec=True, cls=False)[0]
text = rec_res[0][0]
- if text != '':
- texts += text + ('' if text[0].isalpha() else ' ') # add space between english word
+ if text != "":
+ texts += text + (
+ "" if text[0].isalpha() else " "
+                    )  # add a space between English words
probs += rec_res[0][1]
probs = probs / len(bboxes)
result = [(texts.strip(), probs)]
- if result[0][0] != '':
+ if result[0][0] != "":
result.insert(0, box)
- print('result in reRec is ', result)
+ print("result in reRec is ", result)
if result[1][0] == shape.label:
- print('label no change')
+ print("label no change")
else:
shape.label = result[1][0]
else:
- print('Can not recognise the box')
+ print("Can not recognise the box")
if self.noLabelText == shape.label:
- print('label no change')
+ print("label no change")
else:
shape.label = self.noLabelText
self.singleLabel(shape)
self.setDirty()
def exportJSON(self):
- '''
- export PPLabel and CSV to JSON (PubTabNet)
- '''
+ """
+ export PPLabel and CSV to JSON (PubTabNet)
+ """
import pandas as pd
# automatically save annotations
self.saveFilestate()
- self.savePPlabel(mode='auto')
+ self.savePPlabel(mode="auto")
# load box annotations
labeldict = {}
if not os.path.exists(self.PPlabelpath):
- msg = 'ERROR, Can not find Label.txt'
+ msg = "ERROR, Can not find Label.txt"
QMessageBox.information(self, "Information", msg)
return
else:
- with open(self.PPlabelpath, 'r', encoding='utf-8') as f:
+ with open(self.PPlabelpath, "r", encoding="utf-8") as f:
data = f.readlines()
for each in data:
- file, label = each.split('\t')
+ file, label = each.split("\t")
if label:
- label = label.replace('false', 'False')
- label = label.replace('true', 'True')
+ label = label.replace("false", "False")
+ label = label.replace("true", "True")
labeldict[file] = eval(label)
else:
labeldict[file] = []
-
+
# read table recognition output
- TableRec_excel_dir = os.path.join(
- self.lastOpenDir, 'tableRec_excel_output')
+ TableRec_excel_dir = os.path.join(self.lastOpenDir, "tableRec_excel_output")
# save txt
- fid = open(
- "{}/gt.txt".format(self.lastOpenDir), "w", encoding='utf-8')
+ fid = open("{}/gt.txt".format(self.lastOpenDir), "w", encoding="utf-8")
for image_path in labeldict.keys():
# load csv annotations
filename, _ = os.path.splitext(os.path.basename(image_path))
- csv_path = os.path.join(
- TableRec_excel_dir, filename + '.xlsx')
+ csv_path = os.path.join(TableRec_excel_dir, filename + ".xlsx")
if not os.path.exists(csv_path):
continue
excel = xlrd.open_workbook(csv_path)
sheet0 = excel.sheet_by_index(0) # only sheet 0
- merged_cells = sheet0.merged_cells # (0,1,1,3) start row, end row, start col, end col
+ merged_cells = (
+ sheet0.merged_cells
+ ) # (0,1,1,3) start row, end row, start col, end col
- html_list = [['td'] * sheet0.ncols for i in range(sheet0.nrows)]
+ html_list = [["td"] * sheet0.ncols for i in range(sheet0.nrows)]
for merged in merged_cells:
html_list = expand_list(merged, html_list)
@@ -2508,53 +3069,42 @@ def exportJSON(self):
# load box annotations
cells = []
for anno in labeldict[image_path]:
- tokens = list(anno['transcription'])
- cells.append({
- 'tokens': tokens,
- 'bbox': anno['points']
- })
+ tokens = list(anno["transcription"])
+ cells.append({"tokens": tokens, "bbox": anno["points"]})
            # build the annotation info
- html = {
- 'structure': {
- 'tokens': token_list
- },
- 'cells': cells
- }
- d = {
- 'filename': os.path.basename(image_path),
- 'html': html
- }
+ html = {"structure": {"tokens": token_list}, "cells": cells}
+ d = {"filename": os.path.basename(image_path), "html": html}
            # rebuild the HTML
- d['gt'] = rebuild_html_from_ppstructure_label(d)
- fid.write('{}\n'.format(
- json.dumps(
- d, ensure_ascii=False)))
-
+ d["gt"] = rebuild_html_from_ppstructure_label(d)
+ fid.write("{}\n".format(json.dumps(d, ensure_ascii=False)))
+
# convert to PP-Structure label format
fid.close()
- msg = 'JSON sucessfully saved in {}/gt.txt'.format(self.lastOpenDir)
+ msg = "JSON sucessfully saved in {}/gt.txt".format(self.lastOpenDir)
QMessageBox.information(self, "Information", msg)
def autolcm(self):
vbox = QVBoxLayout()
hbox = QHBoxLayout()
self.panel = QLabel()
- self.panel.setText(self.stringBundle.getString('choseModelLg'))
+ self.panel.setText(self.stringBundle.getString("choseModelLg"))
self.panel.setAlignment(Qt.AlignLeft)
self.comboBox = QComboBox()
self.comboBox.setObjectName("comboBox")
- self.comboBox.addItems(['Chinese & English', 'English', 'French', 'German', 'Korean', 'Japanese'])
+ self.comboBox.addItems(
+ ["Chinese & English", "English", "French", "German", "Korean", "Japanese"]
+ )
vbox.addWidget(self.panel)
vbox.addWidget(self.comboBox)
self.dialog = QDialog()
self.dialog.resize(300, 100)
- self.okBtn = QPushButton(self.stringBundle.getString('ok'))
- self.cancelBtn = QPushButton(self.stringBundle.getString('cancel'))
+ self.okBtn = QPushButton(self.stringBundle.getString("ok"))
+ self.cancelBtn = QPushButton(self.stringBundle.getString("cancel"))
self.okBtn.clicked.connect(self.modelChoose)
self.cancelBtn.clicked.connect(self.cancel)
- self.dialog.setWindowTitle(self.stringBundle.getString('choseModelLg'))
+ self.dialog.setWindowTitle(self.stringBundle.getString("choseModelLg"))
hbox.addWidget(self.okBtn)
hbox.addWidget(self.cancelBtn)
@@ -2570,81 +3120,95 @@ def autolcm(self):
def modelChoose(self):
print(self.comboBox.currentText())
- lg_idx = {'Chinese & English': 'ch', 'English': 'en', 'French': 'french', 'German': 'german',
- 'Korean': 'korean', 'Japanese': 'japan'}
+ lg_idx = {
+ "Chinese & English": "ch",
+ "English": "en",
+ "French": "french",
+ "German": "german",
+ "Korean": "korean",
+ "Japanese": "japan",
+ }
del self.ocr
- self.ocr = PaddleOCR(use_pdserving=False, use_angle_cls=True, det=True, cls=True, use_gpu=False,
- lang=lg_idx[self.comboBox.currentText()])
+ self.ocr = PaddleOCR(
+ use_pdserving=False,
+ use_angle_cls=True,
+ det=True,
+ cls=True,
+ use_gpu=False,
+ lang=lg_idx[self.comboBox.currentText()],
+ )
del self.table_ocr
- self.table_ocr = PPStructure(use_pdserving=False,
- use_gpu=False,
- lang=lg_idx[self.comboBox.currentText()],
- layout=False,
- show_log=False)
+ self.table_ocr = PPStructure(
+ use_pdserving=False,
+ use_gpu=False,
+ lang=lg_idx[self.comboBox.currentText()],
+ layout=False,
+ show_log=False,
+ )
self.dialog.close()
def cancel(self):
self.dialog.close()
def loadFilestate(self, saveDir):
- self.fileStatepath = saveDir + '/fileState.txt'
+ self.fileStatepath = saveDir + "/fileState.txt"
self.fileStatedict = {}
if not os.path.exists(self.fileStatepath):
- f = open(self.fileStatepath, 'w', encoding='utf-8')
+ f = open(self.fileStatepath, "w", encoding="utf-8")
else:
- with open(self.fileStatepath, 'r', encoding='utf-8') as f:
+ with open(self.fileStatepath, "r", encoding="utf-8") as f:
states = f.readlines()
for each in states:
- file, state = each.split('\t')
+ file, state = each.split("\t")
self.fileStatedict[file] = 1
self.actions.saveLabel.setEnabled(True)
self.actions.saveRec.setEnabled(True)
self.actions.exportJSON.setEnabled(True)
def saveFilestate(self):
- with open(self.fileStatepath, 'w', encoding='utf-8') as f:
+ with open(self.fileStatepath, "w", encoding="utf-8") as f:
for key in self.fileStatedict:
- f.write(key + '\t')
- f.write(str(self.fileStatedict[key]) + '\n')
+ f.write(key + "\t")
+ f.write(str(self.fileStatedict[key]) + "\n")
def loadLabelFile(self, labelpath):
labeldict = {}
if not os.path.exists(labelpath):
- f = open(labelpath, 'w', encoding='utf-8')
+ f = open(labelpath, "w", encoding="utf-8")
else:
- with open(labelpath, 'r', encoding='utf-8') as f:
+ with open(labelpath, "r", encoding="utf-8") as f:
data = f.readlines()
for each in data:
- file, label = each.split('\t')
+ file, label = each.split("\t")
if label:
- label = label.replace('false', 'False')
- label = label.replace('true', 'True')
+ label = label.replace("false", "False")
+ label = label.replace("true", "True")
labeldict[file] = eval(label)
else:
labeldict[file] = []
return labeldict
- def savePPlabel(self, mode='Manual'):
+ def savePPlabel(self, mode="Manual"):
savedfile = [self.getImglabelidx(i) for i in self.fileStatedict.keys()]
- with open(self.PPlabelpath, 'w', encoding='utf-8') as f:
+ with open(self.PPlabelpath, "w", encoding="utf-8") as f:
for key in self.PPlabel:
if key in savedfile and self.PPlabel[key] != []:
- f.write(key + '\t')
- f.write(json.dumps(self.PPlabel[key], ensure_ascii=False) + '\n')
+ f.write(key + "\t")
+ f.write(json.dumps(self.PPlabel[key], ensure_ascii=False) + "\n")
- if mode == 'Manual':
- if self.lang == 'ch':
- msg = '已将检查过的图片标签保存在 ' + self.PPlabelpath + " 文件中"
+ if mode == "Manual":
+ if self.lang == "ch":
+ msg = "已将检查过的图片标签保存在 " + self.PPlabelpath + " 文件中"
else:
- msg = 'Images that have been checked are saved in ' + self.PPlabelpath
+ msg = "Images that have been checked are saved in " + self.PPlabelpath
QMessageBox.information(self, "Information", msg)
def saveCacheLabel(self):
- with open(self.Cachelabelpath, 'w', encoding='utf-8') as f:
+ with open(self.Cachelabelpath, "w", encoding="utf-8") as f:
for key in self.Cachelabel:
- f.write(key + '\t')
- f.write(json.dumps(self.Cachelabel[key], ensure_ascii=False) + '\n')
+ f.write(key + "\t")
+ f.write(json.dumps(self.Cachelabel[key], ensure_ascii=False) + "\n")
def saveLabelFile(self):
self.saveFilestate()
@@ -2655,36 +3219,51 @@ def saveRecResult(self):
QMessageBox.information(self, "Information", "Check the image first")
return
- rec_gt_dir = os.path.dirname(self.PPlabelpath) + '/rec_gt.txt'
- crop_img_dir = os.path.dirname(self.PPlabelpath) + '/crop_img/'
+ rec_gt_dir = os.path.dirname(self.PPlabelpath) + "/rec_gt.txt"
+ crop_img_dir = os.path.dirname(self.PPlabelpath) + "/crop_img/"
ques_img = []
if not os.path.exists(crop_img_dir):
os.mkdir(crop_img_dir)
- with open(rec_gt_dir, 'w', encoding='utf-8') as f:
+ with open(rec_gt_dir, "w", encoding="utf-8") as f:
for key in self.fileStatedict:
idx = self.getImglabelidx(key)
try:
img = cv2.imdecode(np.fromfile(key, dtype=np.uint8), -1)
for i, label in enumerate(self.PPlabel[idx]):
- if label['difficult']:
+ if label["difficult"]:
continue
- img_crop = get_rotate_crop_image(img, np.array(label['points'], np.float32))
- img_name = os.path.splitext(os.path.basename(idx))[0] + '_crop_' + str(i) + '.jpg'
- cv2.imencode(".jpg",img_crop)[1].tofile(crop_img_dir + img_name)
- f.write('crop_img/' + img_name + '\t')
- f.write(label['transcription'] + '\n')
+ img_crop = get_rotate_crop_image(
+ img, np.array(label["points"], np.float32)
+ )
+ img_name = (
+ os.path.splitext(os.path.basename(idx))[0]
+ + "_crop_"
+ + str(i)
+ + ".jpg"
+ )
+ cv2.imencode(".jpg", img_crop)[1].tofile(
+ crop_img_dir + img_name
+ )
+ f.write("crop_img/" + img_name + "\t")
+ f.write(label["transcription"] + "\n")
except KeyError as e:
pass
except Exception as e:
ques_img.append(key)
traceback.print_exc()
if ques_img:
- QMessageBox.information(self,
- "Information",
- "The following images can not be saved, please check the image path and labels.\n"
- + "".join(str(i) + '\n' for i in ques_img))
- QMessageBox.information(self, "Information", "Cropped images have been saved in " + str(crop_img_dir))
+ QMessageBox.information(
+ self,
+ "Information",
+ "The following images can not be saved, please check the image path and labels.\n"
+ + "".join(str(i) + "\n" for i in ques_img),
+ )
+ QMessageBox.information(
+ self,
+ "Information",
+ "Cropped images have been saved in " + str(crop_img_dir),
+ )
def speedChoose(self):
if self.labelDialogOption.isChecked():
@@ -2702,10 +3281,12 @@ def autoSaveFunc(self):
self.saveLabelFile()
except:
pass
- print('The program will automatically save once after confirming an image')
+ print("The program will automatically save once after confirming an image")
else:
self.autoSaveNum = 5 # Used for backup
- print('The program will automatically save once after confirming 5 images (default)')
+ print(
+ "The program will automatically save once after confirming 5 images (default)"
+ )
def change_box_key(self):
if not self.kie_mode:
@@ -2724,7 +3305,7 @@ def change_box_key(self):
self._update_shape_color(shape)
self.keyDialog.addLabelHistory(key_text)
-
+
# save changed shape
self.setDirty()
@@ -2749,20 +3330,23 @@ def loadShapes(self, shapes, replace=True):
def lockSelectedShape(self):
"""lock the selected shapes.
- Add self.selectedShapes to lock self.canvas.lockedShapes,
+ Add self.selectedShapes to lock self.canvas.lockedShapes,
which holds the ratio of the four coordinates of the locked shapes
to the width and height of the image
"""
width, height = self.image.width(), self.image.height()
def format_shape(s):
- return dict(label=s.label, # str
- line_color=s.line_color.getRgb(),
- fill_color=s.fill_color.getRgb(),
- ratio=[[int(p.x()) / width, int(p.y()) / height] for p in s.points], # QPonitF
- difficult=s.difficult, # bool
- key_cls=s.key_cls, # bool
- )
+ return dict(
+ label=s.label, # str
+ line_color=s.line_color.getRgb(),
+ fill_color=s.fill_color.getRgb(),
+ ratio=[
+ [int(p.x()) / width, int(p.y()) / height] for p in s.points
+            ],  # QPointF
+ difficult=s.difficult, # bool
+ key_cls=s.key_cls, # bool
+ )
# lock
if len(self.canvas.lockedShapes) == 0:
@@ -2772,7 +3356,11 @@ def format_shape(s):
shapes = [format_shape(shape) for shape in self.canvas.selectedShapes]
trans_dic = []
for box in shapes:
- trans_dict = {"transcription": box['label'], "ratio": box['ratio'], "difficult": box['difficult']}
+ trans_dict = {
+ "transcription": box["label"],
+ "ratio": box["ratio"],
+ "difficult": box["difficult"],
+ }
if self.kie_mode:
trans_dict.update({"key_cls": box["key_cls"]})
trans_dic.append(trans_dict)
@@ -2796,23 +3384,24 @@ def expandSelectedShape(self):
box = self.gen_quad_from_poly(np.array(box))
assert len(box) == 4
box = boxPad(box, img.shape, 3)
- shape.points = [QPointF(box[0][0], box[0][1]),
- QPointF(box[1][0], box[1][1]),
- QPointF(box[2][0], box[2][1]),
- QPointF(box[3][0], box[3][1])]
+ shape.points = [
+ QPointF(box[0][0], box[0][1]),
+ QPointF(box[1][0], box[1][1]),
+ QPointF(box[2][0], box[2][1]),
+ QPointF(box[3][0], box[3][1]),
+ ]
print(shape.points)
self.updateBoxlist()
self.setDirty()
-
def inverted(color):
return QColor(*[255 - v for v in color.getRgb()])
def read(filename, default=None):
try:
- with open(filename, 'rb') as f:
+ with open(filename, "rb") as f:
return f.read()
except:
return default
@@ -2832,18 +3421,24 @@ def get_main_app(argv=[]):
app.setWindowIcon(newIcon("app"))
# Tzutalin 201705+: Accept extra arguments to change predefined class file
arg_parser = argparse.ArgumentParser()
- arg_parser.add_argument("--lang", type=str, default='en', nargs="?")
+ arg_parser.add_argument("--lang", type=str, default="en", nargs="?")
arg_parser.add_argument("--gpu", type=str2bool, default=True, nargs="?")
arg_parser.add_argument("--kie", type=str2bool, default=False, nargs="?")
- arg_parser.add_argument("--predefined_classes_file",
- default=os.path.join(os.path.dirname(__file__), "data", "predefined_classes.txt"),
- nargs="?")
+ arg_parser.add_argument(
+ "--predefined_classes_file",
+ default=os.path.join(
+ os.path.dirname(__file__), "data", "predefined_classes.txt"
+ ),
+ nargs="?",
+ )
args = arg_parser.parse_args(argv[1:])
- win = MainWindow(lang=args.lang,
- gpu=args.gpu,
- kie_mode=args.kie,
- default_predefined_class_file=args.predefined_classes_file)
+ win = MainWindow(
+ lang=args.lang,
+ gpu=args.gpu,
+ kie_mode=args.kie,
+ default_predefined_class_file=args.predefined_classes_file,
+ )
win.show()
return app, win
@@ -2854,12 +3449,13 @@ def main():
return app.exec_()
-if __name__ == '__main__':
-
- resource_file = './libs/resources.py'
+if __name__ == "__main__":
+ resource_file = "./libs/resources.py"
if not os.path.exists(resource_file):
- output = os.system('pyrcc5 -o libs/resources.py resources.qrc')
- assert output == 0, "operate the cmd have some problems ,please check whether there is a in the lib " \
- "directory resources.py "
+ output = os.system("pyrcc5 -o libs/resources.py resources.qrc")
+ assert output == 0, (
+ "operate the cmd have some problems ,please check whether there is a in the lib "
+ "directory resources.py "
+ )
sys.exit(main())
diff --git a/PPOCRLabel/gen_ocr_train_val_test.py b/PPOCRLabel/gen_ocr_train_val_test.py
index a0e9294e27..5e7d599a6a 100644
--- a/PPOCRLabel/gen_ocr_train_val_test.py
+++ b/PPOCRLabel/gen_ocr_train_val_test.py
@@ -17,8 +17,16 @@ def isCreateOrDeleteFolder(path, flag):
return flagAbsPath
-def splitTrainVal(root, abs_train_root_path, abs_val_root_path, abs_test_root_path, train_txt, val_txt, test_txt, flag):
-
+def splitTrainVal(
+ root,
+ abs_train_root_path,
+ abs_val_root_path,
+ abs_test_root_path,
+ train_txt,
+ val_txt,
+ test_txt,
+ flag,
+):
data_abs_path = os.path.abspath(root)
label_file_name = args.detLabelFileName if flag == "det" else args.recLabelFileName
label_file_path = os.path.join(data_abs_path, label_file_name)
@@ -29,13 +37,15 @@ def splitTrainVal(root, abs_train_root_path, abs_val_root_path, abs_test_root_pa
label_record_len = len(label_file_content)
for index, label_record_info in enumerate(label_file_content):
- image_relative_path, image_label = label_record_info.split('\t')
+ image_relative_path, image_label = label_record_info.split("\t")
image_name = os.path.basename(image_relative_path)
if flag == "det":
image_path = os.path.join(data_abs_path, image_name)
elif flag == "rec":
- image_path = os.path.join(data_abs_path, args.recImageDirName, image_name)
+ image_path = os.path.join(
+ data_abs_path, args.recImageDirName, image_name
+ )
train_val_test_ratio = args.trainValTestRatio.split(":")
train_ratio = eval(train_val_test_ratio[0]) / 10
@@ -77,27 +87,46 @@ def genDetRecTrainVal(args):
removeFile(os.path.join(args.recRootPath, "val.txt"))
removeFile(os.path.join(args.recRootPath, "test.txt"))
- detTrainTxt = open(os.path.join(args.detRootPath, "train.txt"), "a", encoding="UTF-8")
+ detTrainTxt = open(
+ os.path.join(args.detRootPath, "train.txt"), "a", encoding="UTF-8"
+ )
detValTxt = open(os.path.join(args.detRootPath, "val.txt"), "a", encoding="UTF-8")
detTestTxt = open(os.path.join(args.detRootPath, "test.txt"), "a", encoding="UTF-8")
- recTrainTxt = open(os.path.join(args.recRootPath, "train.txt"), "a", encoding="UTF-8")
+ recTrainTxt = open(
+ os.path.join(args.recRootPath, "train.txt"), "a", encoding="UTF-8"
+ )
recValTxt = open(os.path.join(args.recRootPath, "val.txt"), "a", encoding="UTF-8")
recTestTxt = open(os.path.join(args.recRootPath, "test.txt"), "a", encoding="UTF-8")
- splitTrainVal(args.datasetRootPath, detAbsTrainRootPath, detAbsValRootPath, detAbsTestRootPath, detTrainTxt, detValTxt,
- detTestTxt, "det")
+ splitTrainVal(
+ args.datasetRootPath,
+ detAbsTrainRootPath,
+ detAbsValRootPath,
+ detAbsTestRootPath,
+ detTrainTxt,
+ detValTxt,
+ detTestTxt,
+ "det",
+ )
for root, dirs, files in os.walk(args.datasetRootPath):
for dir in dirs:
- if dir == 'crop_img':
- splitTrainVal(root, recAbsTrainRootPath, recAbsValRootPath, recAbsTestRootPath, recTrainTxt, recValTxt,
- recTestTxt, "rec")
+ if dir == "crop_img":
+ splitTrainVal(
+ root,
+ recAbsTrainRootPath,
+ recAbsValRootPath,
+ recAbsTestRootPath,
+ recTrainTxt,
+ recValTxt,
+ recTestTxt,
+ "rec",
+ )
else:
continue
break
-
if __name__ == "__main__":
    # Function: split the training, validation, and test sets for detection and recognition respectively
    # Note: adjust the parameters to your own paths and needs. Image data is often annotated in batches by several people; each batch is placed in one folder and annotated with PPOCRLabel,
@@ -107,40 +136,43 @@ def genDetRecTrainVal(args):
"--trainValTestRatio",
type=str,
default="6:2:2",
- help="ratio of trainset:valset:testset")
+ help="ratio of trainset:valset:testset",
+ )
parser.add_argument(
"--datasetRootPath",
type=str,
default="../train_data/",
- help="path to the dataset marked by ppocrlabel, E.g, dataset folder named 1,2,3..."
+ help="path to the dataset marked by ppocrlabel, E.g, dataset folder named 1,2,3...",
)
parser.add_argument(
"--detRootPath",
type=str,
default="../train_data/det",
- help="the path where the divided detection dataset is placed")
+ help="the path where the divided detection dataset is placed",
+ )
parser.add_argument(
"--recRootPath",
type=str,
default="../train_data/rec",
- help="the path where the divided recognition dataset is placed"
+ help="the path where the divided recognition dataset is placed",
)
parser.add_argument(
"--detLabelFileName",
type=str,
default="Label.txt",
- help="the name of the detection annotation file")
+ help="the name of the detection annotation file",
+ )
parser.add_argument(
"--recLabelFileName",
type=str,
default="rec_gt.txt",
- help="the name of the recognition annotation file"
+ help="the name of the recognition annotation file",
)
parser.add_argument(
"--recImageDirName",
type=str,
default="crop_img",
- help="the name of the folder where the cropped recognition dataset is located"
+ help="the name of the folder where the cropped recognition dataset is located",
)
args = parser.parse_args()
- genDetRecTrainVal(args)
\ No newline at end of file
+ genDetRecTrainVal(args)
diff --git a/PPOCRLabel/libs/__init__.py b/PPOCRLabel/libs/__init__.py
index 0df0cc8460..57d092cfed 100644
--- a/PPOCRLabel/libs/__init__.py
+++ b/PPOCRLabel/libs/__init__.py
@@ -1,2 +1,2 @@
-__version_info__ = ('1', '0', '0')
-__version__ = '.'.join(__version_info__)
+__version_info__ = ("1", "0", "0")
+__version__ = ".".join(__version_info__)
diff --git a/PPOCRLabel/libs/autoDialog.py b/PPOCRLabel/libs/autoDialog.py
index 55636eec0f..f1a27ba535 100644
--- a/PPOCRLabel/libs/autoDialog.py
+++ b/PPOCRLabel/libs/autoDialog.py
@@ -29,7 +29,7 @@ def __init__(self, ocr, mImgList, mainThread, model):
self.mImgList = mImgList
self.mainThread = mainThread
self.model = model
- self.setStackSize(1024*1024)
+ self.setStackSize(1024 * 1024)
def run(self):
try:
@@ -37,32 +37,45 @@ def run(self):
for Imgpath in self.mImgList:
if self.handle == 0:
self.listValue.emit(Imgpath)
- if self.model == 'paddle':
- h, w, _ = cv2.imdecode(np.fromfile(Imgpath, dtype=np.uint8), 1).shape
+ if self.model == "paddle":
+ h, w, _ = cv2.imdecode(
+ np.fromfile(Imgpath, dtype=np.uint8), 1
+ ).shape
if h > 32 and w > 32:
- self.result_dic = self.ocr.ocr(Imgpath, cls=True, det=True)[0]
+ self.result_dic = self.ocr.ocr(Imgpath, cls=True, det=True)[
+ 0
+ ]
else:
- print('The size of', Imgpath, 'is too small to be recognised')
+ print(
+ "The size of", Imgpath, "is too small to be recognised"
+ )
self.result_dic = None
                # save the results
if self.result_dic is None or len(self.result_dic) == 0:
- print('Can not recognise file', Imgpath)
+ print("Can not recognise file", Imgpath)
pass
else:
- strs = ''
+ strs = ""
for res in self.result_dic:
chars = res[1][0]
cond = res[1][1]
posi = res[0]
- strs += "Transcription: " + chars + " Probability: " + str(cond) + \
- " Location: " + json.dumps(posi) +'\n'
+ strs += (
+ "Transcription: "
+ + chars
+ + " Probability: "
+ + str(cond)
+ + " Location: "
+ + json.dumps(posi)
+ + "\n"
+ )
# Sending large amounts of data repeatedly through pyqtSignal may affect the program efficiency
self.listValue.emit(strs)
self.mainThread.result_dic = self.result_dic
self.mainThread.filePath = Imgpath
                    # save
- self.mainThread.saveFile(mode='Auto')
+ self.mainThread.saveFile(mode="Auto")
findex += 1
self.progressBarValue.emit(findex)
else:
@@ -75,8 +88,9 @@ def run(self):
class AutoDialog(QDialog):
-
- def __init__(self, text="Enter object label", parent=None, ocr=None, mImgList=None, lenbar=0):
+ def __init__(
+ self, text="Enter object label", parent=None, ocr=None, mImgList=None, lenbar=0
+ ):
super(AutoDialog, self).__init__(parent)
self.setFixedWidth(1000)
self.parent = parent
@@ -89,13 +103,13 @@ def __init__(self, text="Enter object label", parent=None, ocr=None, mImgList=No
layout = QVBoxLayout()
layout.addWidget(self.pb)
- self.model = 'paddle'
+ self.model = "paddle"
self.listWidget = QListWidget(self)
layout.addWidget(self.listWidget)
self.buttonBox = bb = BB(BB.Ok | BB.Cancel, Qt.Horizontal, self)
- bb.button(BB.Ok).setIcon(newIcon('done'))
- bb.button(BB.Cancel).setIcon(newIcon('undo'))
+ bb.button(BB.Ok).setIcon(newIcon("done"))
+ bb.button(BB.Cancel).setIcon(newIcon("undo"))
bb.accepted.connect(self.validate)
bb.rejected.connect(self.reject)
layout.addWidget(bb)
@@ -107,7 +121,7 @@ def __init__(self, text="Enter object label", parent=None, ocr=None, mImgList=No
# self.setWindowFlags(Qt.WindowCloseButtonHint)
- self.thread_1 = Worker(self.ocr, self.mImgList, self.parent, 'paddle')
+ self.thread_1 = Worker(self.ocr, self.mImgList, self.parent, "paddle")
self.thread_1.progressBarValue.connect(self.handleProgressBarSingal)
self.thread_1.listValue.connect(self.handleListWidgetSingal)
self.thread_1.endsignal.connect(self.handleEndsignalSignal)
@@ -117,8 +131,14 @@ def handleProgressBarSingal(self, i):
self.pb.setValue(i)
# calculate time left of auto labeling
- avg_time = (time.time() - self.time_start) / i # Use average time to prevent time fluctuations
- time_left = str(datetime.timedelta(seconds=avg_time * (self.lender - i))).split(".")[0] # Remove microseconds
+ avg_time = (
+ time.time() - self.time_start
+ ) / i # Use average time to prevent time fluctuations
+ time_left = str(datetime.timedelta(seconds=avg_time * (self.lender - i))).split(
+ "."
+ )[
+ 0
+ ] # Remove microseconds
self.setWindowTitle("PPOCRLabel -- " + f"Time Left: {time_left}") # show
def handleListWidgetSingal(self, i):
diff --git a/PPOCRLabel/libs/canvas.py b/PPOCRLabel/libs/canvas.py
index b1f41f0c31..a0517d2f0c 100644
--- a/PPOCRLabel/libs/canvas.py
+++ b/PPOCRLabel/libs/canvas.py
@@ -36,7 +36,7 @@ class Canvas(QWidget):
drawingPolygon = pyqtSignal(bool)
CREATE, EDIT = list(range(2))
- _fill_drawing = False # draw shadows
+ _fill_drawing = False # draw shadows
epsilon = 5.0
@@ -71,15 +71,15 @@ def __init__(self, *args, **kwargs):
self.setFocusPolicy(Qt.WheelFocus)
self.verified = False
self.drawSquare = False
- self.fourpoint = True # ADD
+ self.fourpoint = True # ADD
self.pointnum = 0
self.movingShape = False
self.selectCountShape = False
- #initialisation for panning
+ # initialisation for panning
self.pan_initial_pos = QPoint()
- #lockedshapes related
+ # lockedshapes related
self.lockedShapes = []
self.isInTheSameImage = False
@@ -129,17 +129,20 @@ def mouseMoveEvent(self, ev):
window = self.parent().window()
if window.filePath is not None:
self.parent().window().labelCoordinates.setText(
- 'X: %d; Y: %d' % (pos.x(), pos.y()))
+ "X: %d; Y: %d" % (pos.x(), pos.y())
+ )
# Polygon drawing.
if self.drawing():
- self.overrideCursor(CURSOR_DRAW) # ?
+ self.overrideCursor(CURSOR_DRAW) # ?
if self.current:
# Display annotation width and height while drawing
currentWidth = abs(self.current[0].x() - pos.x())
currentHeight = abs(self.current[0].y() - pos.y())
self.parent().window().labelCoordinates.setText(
- 'Width: %d, Height: %d / X: %d; Y: %d' % (currentWidth, currentHeight, pos.x(), pos.y()))
+ "Width: %d, Height: %d / X: %d; Y: %d"
+ % (currentWidth, currentHeight, pos.x(), pos.y())
+ )
color = self.drawingLineColor
if self.outOfPixmap(pos):
@@ -168,10 +171,10 @@ def mouseMoveEvent(self, ev):
self.line[1] = pos
else:
- self.line[1] = pos # pos is the mouse's current position
+ self.line[1] = pos # pos is the mouse's current position
self.line.line_color = color
- self.prevPoint = QPointF() # ?
+ self.prevPoint = QPointF() # ?
self.current.highlightClear()
else:
self.prevPoint = pos
@@ -185,9 +188,7 @@ def mouseMoveEvent(self, ev):
self.boundedMoveShape(self.selectedShapesCopy, pos)
self.repaint()
elif self.selectedShapes:
- self.selectedShapesCopy = [
- s.copy() for s in self.selectedShapes
- ]
+ self.selectedShapesCopy = [s.copy() for s in self.selectedShapes]
self.repaint()
return
@@ -205,7 +206,7 @@ def mouseMoveEvent(self, ev):
self.repaint()
self.movingShape = True
else:
- #pan
+ # pan
delta_x = pos.x() - self.pan_initial_pos.x()
delta_y = pos.y() - self.pan_initial_pos.y()
self.scrollRequest.emit(delta_x, Qt.Horizontal)
@@ -238,8 +239,7 @@ def mouseMoveEvent(self, ev):
if self.selectedVertex():
self.hShape.highlightClear()
self.hVertex, self.hShape = None, shape
- self.setToolTip(
- "Click & drag to move shape '%s'" % shape.label)
+ self.setToolTip("Click & drag to move shape '%s'" % shape.label)
self.setStatusTip(self.toolTip())
self.overrideCursor(CURSOR_GRAB)
self.update()
@@ -257,7 +257,7 @@ def mousePressEvent(self, ev):
if self.drawing():
# self.handleDrawing(pos) # OLD
if self.current:
- if self.fourpoint: # ADD IF
+ if self.fourpoint: # ADD IF
# Add point to existing shape.
# print('Adding points in mousePressEvent is ', self.line[1])
self.current.addPoint(self.line[1])
@@ -294,8 +294,7 @@ def mouseReleaseEvent(self, ev):
if ev.button() == Qt.RightButton:
menu = self.menus[bool(self.selectedShapesCopy)]
self.restoreCursor()
- if not menu.exec_(self.mapToGlobal(ev.pos()))\
- and self.selectedShapesCopy:
+ if not menu.exec_(self.mapToGlobal(ev.pos())) and self.selectedShapesCopy:
# Cancel the move by deleting the shadow copy.
# self.selectedShapeCopy = None
self.selectedShapesCopy = []
@@ -312,18 +311,15 @@ def mouseReleaseEvent(self, ev):
if self.drawing():
self.handleDrawing(pos)
else:
- #pan
- QApplication.restoreOverrideCursor() # ?
+ # pan
+ QApplication.restoreOverrideCursor() # ?
if self.movingShape and self.hShape:
if self.hShape in self.shapes:
index = self.shapes.index(self.hShape)
- if (
- self.shapesBackups[-1][index].points
- != self.shapes[index].points
- ):
+ if self.shapesBackups[-1][index].points != self.shapes[index].points:
self.storeShapes()
- self.shapeMoved.emit() # connect to updateBoxlist in PPOCRLabel.py
+ self.shapeMoved.emit() # connect to updateBoxlist in PPOCRLabel.py
self.movingShape = False
@@ -332,7 +328,7 @@ def endMove(self, copy=False):
assert len(self.selectedShapesCopy) == len(self.selectedShapes)
if copy:
for i, shape in enumerate(self.selectedShapesCopy):
- shape.idx = len(self.shapes) # add current box index
+ shape.idx = len(self.shapes) # add current box index
self.shapes.append(shape)
self.selectedShapes[i].selected = False
self.selectedShapes[i] = shape
@@ -357,14 +353,14 @@ def handleDrawing(self, pos):
if self.fourpoint:
targetPos = self.line[self.pointnum]
self.current.addPoint(targetPos)
- print('current points in handleDrawing is ', self.line[self.pointnum])
+ print("current points in handleDrawing is ", self.line[self.pointnum])
self.update()
if self.pointnum == 3:
self.finalise()
else:
initPos = self.current[0]
- print('initPos', self.current[0])
+ print("initPos", self.current[0])
minX = initPos.x()
minY = initPos.y()
targetPos = self.line[1]
@@ -376,7 +372,7 @@ def handleDrawing(self, pos):
self.finalise()
elif not self.outOfPixmap(pos):
- print('release')
+ print("release")
self.current = Shape()
self.current.addPoint(pos)
self.line.points = [pos, pos]
@@ -399,7 +395,8 @@ def mouseDoubleClickEvent(self, ev):
self.finalise()
def selectShapes(self, shapes):
- for s in shapes: s.seleted = True
+ for s in shapes:
+            s.selected = True
self.setHiding()
self.selectionChanged.emit(shapes)
self.update()
@@ -416,10 +413,8 @@ def selectShapePoint(self, point, multiple_selection_mode):
self.calculateOffsets(shape, point)
self.setHiding()
if multiple_selection_mode:
- if shape not in self.selectedShapes: # list
- self.selectionChanged.emit(
- self.selectedShapes + [shape]
- )
+ if shape not in self.selectedShapes: # list
+ self.selectionChanged.emit(self.selectedShapes + [shape])
else:
self.selectionChanged.emit([shape])
return
@@ -460,16 +455,24 @@ def boundedMoveVertex(self, pos):
opposite_point_index = (index + 2) % 4
opposite_point = shape[opposite_point_index]
- min_size = min(abs(pos.x() - opposite_point.x()), abs(pos.y() - opposite_point.y()))
+ min_size = min(
+ abs(pos.x() - opposite_point.x()), abs(pos.y() - opposite_point.y())
+ )
directionX = -1 if pos.x() - opposite_point.x() < 0 else 1
directionY = -1 if pos.y() - opposite_point.y() < 0 else 1
- shiftPos = QPointF(opposite_point.x() + directionX * min_size - point.x(),
- opposite_point.y() + directionY * min_size - point.y())
+ shiftPos = QPointF(
+ opposite_point.x() + directionX * min_size - point.x(),
+ opposite_point.y() + directionY * min_size - point.y(),
+ )
else:
shiftPos = pos - point
- if [shape[0].x(), shape[0].y(), shape[2].x(), shape[2].y()] \
- == [shape[3].x(),shape[1].y(),shape[1].x(),shape[3].y()]:
+ if [shape[0].x(), shape[0].y(), shape[2].x(), shape[2].y()] == [
+ shape[3].x(),
+ shape[1].y(),
+ shape[1].x(),
+ shape[3].y(),
+ ]:
shape.moveVertexBy(index, shiftPos)
lindex = (index + 1) % 4
rindex = (index + 3) % 4
@@ -488,7 +491,8 @@ def boundedMoveVertex(self, pos):
shape.moveVertexBy(index, shiftPos)
def boundedMoveShape(self, shapes, pos):
- if type(shapes).__name__ != 'list': shapes = [shapes]
+ if type(shapes).__name__ != "list":
+ shapes = [shapes]
if self.outOfPixmap(pos):
return False # No need to move
o1 = pos + self.offsets[0]
@@ -496,13 +500,15 @@ def boundedMoveShape(self, shapes, pos):
pos -= QPointF(min(0, o1.x()), min(0, o1.y()))
o2 = pos + self.offsets[1]
if self.outOfPixmap(o2):
- pos += QPointF(min(0, self.pixmap.width() - o2.x()),
- min(0, self.pixmap.height() - o2.y()))
+ pos += QPointF(
+ min(0, self.pixmap.width() - o2.x()),
+ min(0, self.pixmap.height() - o2.y()),
+ )
# The next line tracks the new position of the cursor
# relative to the shape, but also results in making it
# a bit "shaky" when nearing the border and allows it to
# go outside of the shape's area for some reason. XXX
- #self.calculateOffsets(self.selectedShape, pos)
+ # self.calculateOffsets(self.selectedShape, pos)
dp = pos - self.prevPoint
if dp:
for shape in shapes:
@@ -514,7 +520,8 @@ def boundedMoveShape(self, shapes, pos):
def deSelectShape(self):
if self.selectedShapes:
- for shape in self.selectedShapes: shape.selected=False
+ for shape in self.selectedShapes:
+ shape.selected = False
self.setHiding(False)
self.selectionChanged.emit([])
self.update()
@@ -595,26 +602,41 @@ def paintEvent(self, event):
p.setPen(self.drawingRectColor)
brush = QBrush(Qt.BDiagPattern)
p.setBrush(brush)
- p.drawRect(int(leftTop.x()), int(leftTop.y()), int(rectWidth), int(rectHeight))
-
+ p.drawRect(
+ int(leftTop.x()), int(leftTop.y()), int(rectWidth), int(rectHeight)
+ )
# ADD:
if (
- self.fillDrawing()
- and self.fourpoint
- and self.current is not None
- and len(self.current.points) >= 2
+ self.fillDrawing()
+ and self.fourpoint
+ and self.current is not None
+ and len(self.current.points) >= 2
):
- print('paint event')
+ print("paint event")
drawing_shape = self.current.copy()
drawing_shape.addPoint(self.line[1])
drawing_shape.fill = True
drawing_shape.paint(p)
- if self.drawing() and not self.prevPoint.isNull() and not self.outOfPixmap(self.prevPoint):
+ if (
+ self.drawing()
+ and not self.prevPoint.isNull()
+ and not self.outOfPixmap(self.prevPoint)
+ ):
p.setPen(QColor(0, 0, 0))
- p.drawLine(int(self.prevPoint.x()), 0, int(self.prevPoint.x()), int(self.pixmap.height()))
- p.drawLine(0, int(self.prevPoint.y()), int(self.pixmap.width()), int(self.prevPoint.y()))
+ p.drawLine(
+ int(self.prevPoint.x()),
+ 0,
+ int(self.prevPoint.x()),
+ int(self.pixmap.height()),
+ )
+ p.drawLine(
+ 0,
+ int(self.prevPoint.y()),
+ int(self.pixmap.width()),
+ int(self.prevPoint.y()),
+ )
self.setAutoFillBackground(True)
if self.verified:
@@ -632,7 +654,7 @@ def paintEvent(self, event):
                fontsize = int(max(h, w) / 48)
                for s in self.shapes:
                    s.fontsize = fontsize
-
+
p.end()
def fillDrawing(self):
@@ -665,16 +687,16 @@ def finalise(self):
return
self.current.close()
- self.current.idx = len(self.shapes) # add current box index
- self.shapes.append(self.current)
+ self.current.idx = len(self.shapes) # add current box index
+ self.shapes.append(self.current)
self.current = None
self.setHiding(False)
self.newShape.emit()
self.update()
def closeEnough(self, p1, p2):
- #d = distance(p1 - p2)
- #m = (p1-p2).manhattanLength()
+ # d = distance(p1 - p2)
+ # m = (p1-p2).manhattanLength()
# print "d %.2f, m %d, %.2f" % (d, m, d - m)
return distance(p1 - p2) < self.epsilon
@@ -718,20 +740,20 @@ def keyPressEvent(self, ev):
self.shapesBackups.pop()
self.shapesBackups.append(shapesBackup)
if key == Qt.Key_Escape and self.current:
- print('ESC press')
+ print("ESC press")
self.current = None
self.drawingPolygon.emit(False)
self.update()
elif key == Qt.Key_Return and self.canCloseShape():
self.finalise()
elif key == Qt.Key_Left and self.selectedShapes:
- self.moveOnePixel('Left')
+ self.moveOnePixel("Left")
elif key == Qt.Key_Right and self.selectedShapes:
- self.moveOnePixel('Right')
+ self.moveOnePixel("Right")
elif key == Qt.Key_Up and self.selectedShapes:
- self.moveOnePixel('Up')
+ self.moveOnePixel("Up")
elif key == Qt.Key_Down and self.selectedShapes:
- self.moveOnePixel('Down')
+ self.moveOnePixel("Down")
elif key == Qt.Key_X and self.selectedShapes:
for i in range(len(self.selectedShapes)):
self.selectedShape = self.selectedShapes[i]
@@ -764,25 +786,25 @@ def moveOnePixel(self, direction):
self.selectCountShape = True
for i in range(len(self.selectedShapes)):
self.selectedShape = self.selectedShapes[i]
- if direction == 'Left' and not self.moveOutOfBound(QPointF(-1.0, 0)):
+ if direction == "Left" and not self.moveOutOfBound(QPointF(-1.0, 0)):
# print("move Left one pixel")
self.selectedShape.points[0] += QPointF(-1.0, 0)
self.selectedShape.points[1] += QPointF(-1.0, 0)
self.selectedShape.points[2] += QPointF(-1.0, 0)
self.selectedShape.points[3] += QPointF(-1.0, 0)
- elif direction == 'Right' and not self.moveOutOfBound(QPointF(1.0, 0)):
+ elif direction == "Right" and not self.moveOutOfBound(QPointF(1.0, 0)):
# print("move Right one pixel")
self.selectedShape.points[0] += QPointF(1.0, 0)
self.selectedShape.points[1] += QPointF(1.0, 0)
self.selectedShape.points[2] += QPointF(1.0, 0)
self.selectedShape.points[3] += QPointF(1.0, 0)
- elif direction == 'Up' and not self.moveOutOfBound(QPointF(0, -1.0)):
+ elif direction == "Up" and not self.moveOutOfBound(QPointF(0, -1.0)):
# print("move Up one pixel")
self.selectedShape.points[0] += QPointF(0, -1.0)
self.selectedShape.points[1] += QPointF(0, -1.0)
self.selectedShape.points[2] += QPointF(0, -1.0)
self.selectedShape.points[3] += QPointF(0, -1.0)
- elif direction == 'Down' and not self.moveOutOfBound(QPointF(0, 1.0)):
+ elif direction == "Down" and not self.moveOutOfBound(QPointF(0, 1.0)):
# print("move Down one pixel")
self.selectedShape.points[0] += QPointF(0, 1.0)
self.selectedShape.points[1] += QPointF(0, 1.0)
@@ -795,7 +817,7 @@ def moveOnePixel(self, direction):
self.repaint()
def moveOutOfBound(self, step):
- points = [p1+p2 for p1, p2 in zip(self.selectedShape.points, [step]*4)]
+ points = [p1 + p2 for p1, p2 in zip(self.selectedShape.points, [step] * 4)]
return True in map(self.outOfPixmap, points)
def setLastLabel(self, text, line_color=None, fill_color=None, key_cls=None):
@@ -901,7 +923,7 @@ def restoreShape(self):
shape.selected = False
self.updateShapeIndex()
self.repaint()
-
+
@property
def isShapeRestorable(self):
if len(self.shapesBackups) < 2:
diff --git a/PPOCRLabel/libs/colorDialog.py b/PPOCRLabel/libs/colorDialog.py
index c70ff76549..431600b308 100644
--- a/PPOCRLabel/libs/colorDialog.py
+++ b/PPOCRLabel/libs/colorDialog.py
@@ -22,7 +22,6 @@
class ColorDialog(QColorDialog):
-
def __init__(self, parent=None):
super(ColorDialog, self).__init__(parent)
self.setOption(QColorDialog.ShowAlphaChannel)
diff --git a/PPOCRLabel/libs/constants.py b/PPOCRLabel/libs/constants.py
index f075f4a539..c3483cdb21 100644
--- a/PPOCRLabel/libs/constants.py
+++ b/PPOCRLabel/libs/constants.py
@@ -10,23 +10,23 @@
# SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF
# CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
# THE SOFTWARE.
-SETTING_FILENAME = 'filename'
-SETTING_RECENT_FILES = 'recentFiles'
-SETTING_WIN_SIZE = 'window/size'
-SETTING_WIN_POSE = 'window/position'
-SETTING_WIN_GEOMETRY = 'window/geometry'
-SETTING_LINE_COLOR = 'line/color'
-SETTING_FILL_COLOR = 'fill/color'
-SETTING_ADVANCE_MODE = 'advanced'
-SETTING_WIN_STATE = 'window/state'
-SETTING_SAVE_DIR = 'savedir'
-SETTING_PAINT_LABEL = 'paintlabel'
-SETTING_PAINT_INDEX = 'paintindex'
-SETTING_LAST_OPEN_DIR = 'lastOpenDir'
-SETTING_AUTO_SAVE = 'autosave'
-SETTING_SINGLE_CLASS = 'singleclass'
-FORMAT_PASCALVOC='PascalVOC'
-FORMAT_YOLO='YOLO'
-SETTING_DRAW_SQUARE = 'draw/square'
-SETTING_LABEL_FILE_FORMAT= 'labelFileFormat'
-DEFAULT_ENCODING = 'utf-8'
+SETTING_FILENAME = "filename"
+SETTING_RECENT_FILES = "recentFiles"
+SETTING_WIN_SIZE = "window/size"
+SETTING_WIN_POSE = "window/position"
+SETTING_WIN_GEOMETRY = "window/geometry"
+SETTING_LINE_COLOR = "line/color"
+SETTING_FILL_COLOR = "fill/color"
+SETTING_ADVANCE_MODE = "advanced"
+SETTING_WIN_STATE = "window/state"
+SETTING_SAVE_DIR = "savedir"
+SETTING_PAINT_LABEL = "paintlabel"
+SETTING_PAINT_INDEX = "paintindex"
+SETTING_LAST_OPEN_DIR = "lastOpenDir"
+SETTING_AUTO_SAVE = "autosave"
+SETTING_SINGLE_CLASS = "singleclass"
+FORMAT_PASCALVOC = "PascalVOC"
+FORMAT_YOLO = "YOLO"
+SETTING_DRAW_SQUARE = "draw/square"
+SETTING_LABEL_FILE_FORMAT = "labelFileFormat"
+DEFAULT_ENCODING = "utf-8"
diff --git a/PPOCRLabel/libs/create_ml_io.py b/PPOCRLabel/libs/create_ml_io.py
index a2123b265b..e5ce30929a 100644
--- a/PPOCRLabel/libs/create_ml_io.py
+++ b/PPOCRLabel/libs/create_ml_io.py
@@ -18,12 +18,21 @@
from libs.constants import DEFAULT_ENCODING
import os
-JSON_EXT = '.json'
+JSON_EXT = ".json"
ENCODE_METHOD = DEFAULT_ENCODING
class CreateMLWriter:
- def __init__(self, foldername, filename, imgsize, shapes, outputfile, databasesrc='Unknown', localimgpath=None):
+ def __init__(
+ self,
+ foldername,
+ filename,
+ imgsize,
+ shapes,
+ outputfile,
+ databasesrc="Unknown",
+ localimgpath=None,
+ ):
self.foldername = foldername
self.filename = filename
self.databasesrc = databasesrc
@@ -42,10 +51,7 @@ def write(self):
else:
outputdict = []
- outputimagedict = {
- "image": self.filename,
- "annotations": []
- }
+ outputimagedict = {"image": self.filename, "annotations": []}
for shape in self.shapes:
points = shape["points"]
@@ -59,12 +65,7 @@ def write(self):
shapedict = {
"label": shape["label"],
- "coordinates": {
- "x": x,
- "y": y,
- "width": width,
- "height": height
- }
+ "coordinates": {"x": x, "y": y, "width": width, "height": height},
}
outputimagedict["annotations"].append(shapedict)
diff --git a/PPOCRLabel/libs/editinlist.py b/PPOCRLabel/libs/editinlist.py
index 605482ffc9..38920bb8a2 100644
--- a/PPOCRLabel/libs/editinlist.py
+++ b/PPOCRLabel/libs/editinlist.py
@@ -30,4 +30,4 @@ def keyPressEvent(self, event) -> None:
# close edit
if event.key() in [16777220, 16777221]:
for i in range(self.count()):
- self.closePersistentEditor(self.item(i))
\ No newline at end of file
+ self.closePersistentEditor(self.item(i))
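
Reviewer note: the magic numbers 16777220 and 16777221 kept in this file are
Qt.Key_Return and Qt.Key_Enter. A self-documenting equivalent, should anyone
want a follow-up cleanup (sketch only, not part of this diff):

    from PyQt5.QtCore import Qt

    def keyPressEvent(self, event) -> None:
        # close the persistent editors when Return/Enter is pressed
        if event.key() in (Qt.Key_Return, Qt.Key_Enter):
            for i in range(self.count()):
                self.closePersistentEditor(self.item(i))
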
diff --git a/PPOCRLabel/libs/hashableQListWidgetItem.py b/PPOCRLabel/libs/hashableQListWidgetItem.py
index 95c496ff56..138ef0e023 100644
--- a/PPOCRLabel/libs/hashableQListWidgetItem.py
+++ b/PPOCRLabel/libs/hashableQListWidgetItem.py
@@ -13,6 +13,7 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
import sys
+
try:
from PyQt5.QtGui import *
from PyQt5.QtCore import *
@@ -24,7 +25,8 @@
# http://stackoverflow.com/questions/21217399/pyqt4-qtcore-qvariant-object-instead-of-a-string
if sys.version_info.major >= 3:
import sip
- sip.setapi('QVariant', 2)
+
+ sip.setapi("QVariant", 2)
from PyQt4.QtGui import *
from PyQt4.QtCore import *
@@ -32,7 +34,6 @@
class HashableQListWidgetItem(QListWidgetItem):
-
def __init__(self, *args):
super(HashableQListWidgetItem, self).__init__(*args)
diff --git a/PPOCRLabel/libs/keyDialog.py b/PPOCRLabel/libs/keyDialog.py
index 1ec8d97147..53ea90ed81 100644
--- a/PPOCRLabel/libs/keyDialog.py
+++ b/PPOCRLabel/libs/keyDialog.py
@@ -6,7 +6,7 @@
from PyQt5.Qt import QT_VERSION_STR
from libs.utils import newIcon, labelValidator
-QT5 = QT_VERSION_STR[0] == '5'
+QT5 = QT_VERSION_STR[0] == "5"
# TODO(unknown):
@@ -26,15 +26,15 @@ def keyPressEvent(self, e):
class KeyDialog(QtWidgets.QDialog):
def __init__(
- self,
- text="Enter object label",
- parent=None,
- labels=None,
- sort_labels=True,
- show_text_field=True,
- completion="startswith",
- fit_to_content=None,
- flags=None,
+ self,
+ text="Enter object label",
+ parent=None,
+ labels=None,
+ sort_labels=True,
+ show_text_field=True,
+ completion="startswith",
+ fit_to_content=None,
+ flags=None,
):
if fit_to_content is None:
fit_to_content = {"row": False, "column": True}
@@ -67,22 +67,16 @@ def __init__(
# label_list
self.labelList = QtWidgets.QListWidget()
if self._fit_to_content["row"]:
- self.labelList.setHorizontalScrollBarPolicy(
- QtCore.Qt.ScrollBarAlwaysOff
- )
+ self.labelList.setHorizontalScrollBarPolicy(QtCore.Qt.ScrollBarAlwaysOff)
if self._fit_to_content["column"]:
- self.labelList.setVerticalScrollBarPolicy(
- QtCore.Qt.ScrollBarAlwaysOff
- )
+ self.labelList.setVerticalScrollBarPolicy(QtCore.Qt.ScrollBarAlwaysOff)
self._sort_labels = sort_labels
if labels:
self.labelList.addItems(labels)
if self._sort_labels:
self.labelList.sortItems()
else:
- self.labelList.setDragDropMode(
- QtWidgets.QAbstractItemView.InternalMove
- )
+ self.labelList.setDragDropMode(QtWidgets.QAbstractItemView.InternalMove)
self.labelList.currentItemChanged.connect(self.labelSelected)
self.labelList.itemDoubleClicked.connect(self.labelDoubleClicked)
self.edit.setListWidget(self.labelList)
@@ -188,9 +182,7 @@ def popUp(self, text=None, move=True, flags=None):
self.labelList.sizeHintForRow(0) * self.labelList.count() + 2
)
if self._fit_to_content["column"]:
- self.labelList.setMinimumWidth(
- self.labelList.sizeHintForColumn(0) + 2
- )
+ self.labelList.setMinimumWidth(self.labelList.sizeHintForColumn(0) + 2)
# if text is None, the previous label in self.edit is kept
if text is None:
text = self.edit.text()
diff --git a/PPOCRLabel/libs/labelDialog.py b/PPOCRLabel/libs/labelDialog.py
index 57071d77b5..9a17d06905 100644
--- a/PPOCRLabel/libs/labelDialog.py
+++ b/PPOCRLabel/libs/labelDialog.py
@@ -24,7 +24,6 @@
class LabelDialog(QDialog):
-
def __init__(self, text="Enter object label", parent=None, listItem=None):
super(LabelDialog, self).__init__(parent)
@@ -43,8 +42,8 @@ def __init__(self, text="Enter object label", parent=None, listItem=None):
layout = QVBoxLayout()
layout.addWidget(self.edit)
self.buttonBox = bb = BB(BB.Ok | BB.Cancel, Qt.Horizontal, self)
- bb.button(BB.Ok).setIcon(newIcon('done'))
- bb.button(BB.Cancel).setIcon(newIcon('undo'))
+ bb.button(BB.Ok).setIcon(newIcon("done"))
+ bb.button(BB.Cancel).setIcon(newIcon("undo"))
bb.accepted.connect(self.validate)
bb.rejected.connect(self.reject)
layout.addWidget(bb)
@@ -77,15 +76,23 @@ def postProcess(self):
self.edit.setText(self.edit.text())
print(self.edit.text())
- def popUp(self, text='', move=True):
+ def popUp(self, text="", move=True):
self.edit.setText(text)
self.edit.setSelection(0, len(text))
self.edit.setFocus(Qt.PopupFocusReason)
if move:
cursor_pos = QCursor.pos()
parent_bottomRight = self.parentWidget().geometry()
- max_x = parent_bottomRight.x() + parent_bottomRight.width() - self.sizeHint().width()
- max_y = parent_bottomRight.y() + parent_bottomRight.height() - self.sizeHint().height()
+ max_x = (
+ parent_bottomRight.x()
+ + parent_bottomRight.width()
+ - self.sizeHint().width()
+ )
+ max_y = (
+ parent_bottomRight.y()
+ + parent_bottomRight.height()
+ - self.sizeHint().height()
+ )
max_global = self.parentWidget().mapToGlobal(QPoint(max_x, max_y))
if cursor_pos.x() > max_global.x():
cursor_pos.setX(max_global.x())
diff --git a/PPOCRLabel/libs/resources.py b/PPOCRLabel/libs/resources.py
index e2dc0383a6..c2f18db888 100644
--- a/PPOCRLabel/libs/resources.py
+++ b/PPOCRLabel/libs/resources.py
@@ -11698,7 +11698,7 @@
\x00\x00\x01\x7a\x50\x6d\xd4\x88\
"
-qt_version = [int(v) for v in QtCore.qVersion().split('.')]
+qt_version = [int(v) for v in QtCore.qVersion().split(".")]
if qt_version < [5, 8, 0]:
rcc_version = 1
qt_resource_struct = qt_resource_struct_v1
@@ -11706,10 +11706,17 @@
rcc_version = 2
qt_resource_struct = qt_resource_struct_v2
+
def qInitResources():
- QtCore.qRegisterResourceData(rcc_version, qt_resource_struct, qt_resource_name, qt_resource_data)
+ QtCore.qRegisterResourceData(
+ rcc_version, qt_resource_struct, qt_resource_name, qt_resource_data
+ )
+
def qCleanupResources():
- QtCore.qUnregisterResourceData(rcc_version, qt_resource_struct, qt_resource_name, qt_resource_data)
+ QtCore.qUnregisterResourceData(
+ rcc_version, qt_resource_struct, qt_resource_name, qt_resource_data
+ )
+
qInitResources()
diff --git a/PPOCRLabel/libs/settings.py b/PPOCRLabel/libs/settings.py
index 7f9a4fdda6..58c96332e0 100644
--- a/PPOCRLabel/libs/settings.py
+++ b/PPOCRLabel/libs/settings.py
@@ -22,7 +22,7 @@ def __init__(self):
home = os.path.expanduser("~")
self.data = {}
# self.path = os.path.join(home, '.labelImgSettings.pkl')
- self.path = os.path.join(home, '.autoOCRSettings.pkl')
+ self.path = os.path.join(home, ".autoOCRSettings.pkl")
def __setitem__(self, key, value):
self.data[key] = value
@@ -37,7 +37,7 @@ def get(self, key, default=None):
def save(self):
if self.path:
- with open(self.path, 'wb') as f:
+ with open(self.path, "wb") as f:
pickle.dump(self.data, f, pickle.HIGHEST_PROTOCOL)
return True
return False
@@ -45,16 +45,16 @@ def save(self):
def load(self):
try:
if os.path.exists(self.path):
- with open(self.path, 'rb') as f:
+ with open(self.path, "rb") as f:
self.data = pickle.load(f)
return True
except:
- print('Loading setting failed')
+ print("Loading setting failed")
return False
def reset(self):
if os.path.exists(self.path):
os.remove(self.path)
- print('Remove setting pkl file ${0}'.format(self.path))
+ print("Remove setting pkl file ${0}".format(self.path))
self.data = {}
self.path = None
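
Reviewer note: Settings behaves as a small pickle-backed dict, usually keyed
by the string constants from libs/constants.py. A typical round-trip (hedged
sketch; the key and directory are illustrative):

    from libs.settings import Settings
    from libs.constants import SETTING_SAVE_DIR

    settings = Settings()
    settings.load()                                  # True if the pickle was read
    save_dir = settings.get(SETTING_SAVE_DIR, ".")   # default when the key is absent
    settings[SETTING_SAVE_DIR] = "/data/labels"
    settings.save()                                  # pickled to ~/.autoOCRSettings.pkl
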
diff --git a/PPOCRLabel/libs/shape.py b/PPOCRLabel/libs/shape.py
index 3611dfa04e..bd07853b9e 100644
--- a/PPOCRLabel/libs/shape.py
+++ b/PPOCRLabel/libs/shape.py
@@ -47,9 +47,17 @@ class Shape(object):
point_size = 8
scale = 1.0
- def __init__(self, label=None, line_color=None, difficult=False, key_cls="None", paintLabel=False, paintIdx=False):
+ def __init__(
+ self,
+ label=None,
+ line_color=None,
+ difficult=False,
+ key_cls="None",
+ paintLabel=False,
+ paintIdx=False,
+ ):
self.label = label
- self.idx = None # bbox order, only for table annotation
+ self.idx = None # bbox order, only for table annotation
self.points = []
self.fill = False
self.selected = False
@@ -88,20 +96,20 @@ def rotatePoint(self, p, theta):
cosTheta = math.cos(theta)
sinTheta = math.sin(theta)
pResx = cosTheta * order.x() + sinTheta * order.y()
- pResy = - sinTheta * order.x() + cosTheta * order.y()
+ pResy = -sinTheta * order.x() + cosTheta * order.y()
pRes = QPointF(self.center.x() + pResx, self.center.y() + pResy)
return pRes
def close(self):
try:
- self.center = QPointF((self.points[0].x() + self.points[2].x()) / 2,
- (self.points[0].y() + self.points[2].y()) / 2)
+ self.center = QPointF(
+ (self.points[0].x() + self.points[2].x()) / 2,
+ (self.points[0].y() + self.points[2].y()) / 2,
+ )
except:
self.center = None
logger = get_logger()
- logger.warning(
- 'The XY coordinates of QPointF are not detectable!'
- )
+ logger.warning("The XY coordinates of QPointF are not detectable!")
self._closed = True
def reachMaxPoints(self):
@@ -186,7 +194,7 @@ def paint(self, painter):
font.setPointSize(self.fontsize)
font.setBold(True)
painter.setFont(font)
- text = ''
+ text = ""
if self.idx != None:
text = str(self.idx)
if min_y < MIN_Y_LABEL:
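
Reviewer note: rotatePoint applies the standard 2-D rotation about
self.center; the minus sign on pResy is the usual flip for screen coordinates
(y grows downward). A NumPy equivalent for sanity-checking the math
(hypothetical helper, not in this diff):

    import numpy as np

    def rotate_about(p, center, theta):
        c, s = np.cos(theta), np.sin(theta)
        d = np.asarray(p, dtype=float) - center
        # matches rotatePoint: x' = c*x + s*y, y' = -s*x + c*y
        return center + np.array([c * d[0] + s * d[1], -s * d[0] + c * d[1]])
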
diff --git a/PPOCRLabel/libs/stringBundle.py b/PPOCRLabel/libs/stringBundle.py
index e50090405a..b0b6d889cd 100644
--- a/PPOCRLabel/libs/stringBundle.py
+++ b/PPOCRLabel/libs/stringBundle.py
@@ -19,24 +19,26 @@
import locale
from libs.ustr import ustr
-__dir__ = os.path.dirname(os.path.abspath(__file__)) # 获取本程序文件路径
-__dirpath__ = os.path.abspath(os.path.join(__dir__, '../resources/strings'))
+__dir__ = os.path.dirname(os.path.abspath(__file__)) # 获取本程序文件路径
+__dirpath__ = os.path.abspath(os.path.join(__dir__, "../resources/strings"))
try:
from PyQt5.QtCore import *
except ImportError:
if sys.version_info.major >= 3:
import sip
- sip.setapi('QVariant', 2)
+
+ sip.setapi("QVariant", 2)
from PyQt4.QtCore import *
class StringBundle:
-
__create_key = object()
def __init__(self, create_key, localeStr):
- assert(create_key == StringBundle.__create_key), "StringBundle must be created using StringBundle.getBundle"
+ assert (
+ create_key == StringBundle.__create_key
+ ), "StringBundle must be created using StringBundle.getBundle"
self.idToMessage = {}
paths = self.__createLookupFallbackList(localeStr)
for path in paths:
@@ -46,34 +48,37 @@ def __init__(self, create_key, localeStr):
def getBundle(cls, localeStr=None):
if localeStr is None:
try:
- localeStr = locale.getlocale()[0] if locale.getlocale() and len(
- locale.getlocale()) > 0 else os.getenv('LANG')
+ localeStr = (
+ locale.getlocale()[0]
+ if locale.getlocale() and len(locale.getlocale()) > 0
+ else os.getenv("LANG")
+ )
except:
- print('Invalid locale')
- localeStr = 'en'
+ print("Invalid locale")
+ localeStr = "en"
return StringBundle(cls.__create_key, localeStr)
def getString(self, stringId):
- assert(stringId in self.idToMessage), "Missing string id : " + stringId
+ assert stringId in self.idToMessage, "Missing string id : " + stringId
return self.idToMessage[stringId]
def __createLookupFallbackList(self, localeStr):
resultPaths = []
- basePath = "\strings" if os.name == 'nt' else "/strings"
+ basePath = "\strings" if os.name == "nt" else "/strings"
resultPaths.append(basePath)
if localeStr is not None:
# Don't follow standard BCP47. Simple fallback
- tags = re.split('[^a-zA-Z]', localeStr)
+ tags = re.split("[^a-zA-Z]", localeStr)
for tag in tags:
lastPath = resultPaths[-1]
- resultPaths.append(lastPath + '-' + tag)
+ resultPaths.append(lastPath + "-" + tag)
resultPaths[-1] = __dirpath__ + resultPaths[-1] + ".properties"
return resultPaths
def __loadBundle(self, path):
- PROP_SEPERATOR = '='
+ PROP_SEPERATOR = "="
f = QFile(path)
if f.exists():
if f.open(QIODevice.ReadOnly | QFile.Text):
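
Reviewer note: __createLookupFallbackList builds progressively more specific
resource names from the locale, and only the last (most specific) entry is
expanded into a real .properties path; the less specific entries simply fail
the QFile existence check. A quick stand-alone trace of the same logic:

    import re

    def fallback_paths(locale_str, base="/strings"):
        paths = [base]
        for tag in re.split("[^a-zA-Z]", locale_str):
            paths.append(paths[-1] + "-" + tag)
        return paths

    print(fallback_paths("zh-CN"))  # ['/strings', '/strings-zh', '/strings-zh-CN']
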
diff --git a/PPOCRLabel/libs/toolBar.py b/PPOCRLabel/libs/toolBar.py
index 9a63929ae9..ea0cb56722 100644
--- a/PPOCRLabel/libs/toolBar.py
+++ b/PPOCRLabel/libs/toolBar.py
@@ -20,7 +20,6 @@
class ToolBar(QToolBar):
-
def __init__(self, title):
super(ToolBar, self).__init__(title)
layout = self.layout()
@@ -41,6 +40,7 @@ def addAction(self, action):
class ToolButton(QToolButton):
"""ToolBar companion class which ensures all buttons have the same size."""
+
minSize = (60, 60)
def minimumSizeHint(self):
diff --git a/PPOCRLabel/libs/unique_label_qlist_widget.py b/PPOCRLabel/libs/unique_label_qlist_widget.py
index 07ae05fe67..82e358197e 100644
--- a/PPOCRLabel/libs/unique_label_qlist_widget.py
+++ b/PPOCRLabel/libs/unique_label_qlist_widget.py
@@ -37,7 +37,9 @@ def setItemLabel(self, item, label, color=None):
if color is None:
qlabel.setText(f"{label}")
else:
-            qlabel.setText('<font color="#{:02x}{:02x}{:02x}">●</font> {} '.format(*color, label))
+            qlabel.setText(
+                '<font color="#{:02x}{:02x}{:02x}">●</font> {} '.format(*color, label)
+            )
qlabel.setAlignment(Qt.AlignBottom)
# item.setSizeHint(qlabel.sizeHint())
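
Reviewer note: the rich-text bullet above colors the legend dot per label
class. For example, color = (255, 0, 0) with label = "header" yields:

    s = '<font color="#{:02x}{:02x}{:02x}">●</font> {} '.format(255, 0, 0, "header")
    # s == '<font color="#ff0000">●</font> header '
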
diff --git a/PPOCRLabel/libs/ustr.py b/PPOCRLabel/libs/ustr.py
index b35cf5d3ff..51d8ce4e6d 100644
--- a/PPOCRLabel/libs/ustr.py
+++ b/PPOCRLabel/libs/ustr.py
@@ -13,17 +13,19 @@
import sys
from libs.constants import DEFAULT_ENCODING
+
def ustr(x):
- '''py2/py3 unicode helper'''
+ """py2/py3 unicode helper"""
if sys.version_info < (3, 0, 0):
from PyQt4.QtCore import QString
+
if type(x) == str:
return x.decode(DEFAULT_ENCODING)
if type(x) == QString:
- #https://blog.csdn.net/friendan/article/details/51088476
- #https://blog.csdn.net/xxm524/article/details/74937308
- return unicode(x.toUtf8(), DEFAULT_ENCODING, 'ignore')
+ # https://blog.csdn.net/friendan/article/details/51088476
+ # https://blog.csdn.net/xxm524/article/details/74937308
+ return unicode(x.toUtf8(), DEFAULT_ENCODING, "ignore")
return x
else:
return x
diff --git a/PPOCRLabel/libs/utils.py b/PPOCRLabel/libs/utils.py
index 1bd46ab4da..8bc4f199c9 100644
--- a/PPOCRLabel/libs/utils.py
+++ b/PPOCRLabel/libs/utils.py
@@ -24,12 +24,14 @@
from libs.ustr import ustr
__dir__ = os.path.dirname(os.path.abspath(__file__)) # 获取本程序文件路径
-__iconpath__ = os.path.abspath(os.path.join(__dir__, '../resources/icons'))
+__iconpath__ = os.path.abspath(os.path.join(__dir__, "../resources/icons"))
def newIcon(icon, iconSize=None):
if iconSize is not None:
- return QIcon(QIcon(__iconpath__ + "/" + icon + ".png").pixmap(iconSize, iconSize))
+ return QIcon(
+ QIcon(__iconpath__ + "/" + icon + ".png").pixmap(iconSize, iconSize)
+ )
else:
return QIcon(__iconpath__ + "/" + icon + ".png")
@@ -43,8 +45,17 @@ def newButton(text, icon=None, slot=None):
return b
-def newAction(parent, text, slot=None, shortcut=None, icon=None,
- tip=None, checkable=False, enabled=True, iconSize=None):
+def newAction(
+ parent,
+ text,
+ slot=None,
+ shortcut=None,
+ icon=None,
+ tip=None,
+ checkable=False,
+ enabled=True,
+ iconSize=None,
+):
"""Create a new action and assign callbacks, shortcuts, etc."""
a = QAction(text, parent)
if icon is not None:
@@ -79,11 +90,10 @@ def addActions(widget, actions):
def labelValidator():
- return QRegExpValidator(QRegExp(r'^[^ \t].+'), None)
+ return QRegExpValidator(QRegExp(r"^[^ \t].+"), None)
class struct(object):
-
def __init__(self, **kwargs):
self.__dict__.update(kwargs)
@@ -93,13 +103,13 @@ def distance(p):
def fmtShortcut(text):
- mod, key = text.split('+', 1)
- return '%s+%s' % (mod, key)
+ mod, key = text.split("+", 1)
+ return "%s+%s" % (mod, key)
def generateColorByText(text):
s = ustr(text)
- hashCode = int(hashlib.sha256(s.encode('utf-8')).hexdigest(), 16)
+ hashCode = int(hashlib.sha256(s.encode("utf-8")).hexdigest(), 16)
r = int((hashCode / 255) % 255)
g = int((hashCode / 65025) % 255)
b = int((hashCode / 16581375) % 255)
@@ -107,8 +117,8 @@ def generateColorByText(text):
def have_qstring():
- '''p3/qt5 get rid of QString wrapper as py3 has native unicode str type'''
- return not (sys.version_info.major >= 3 or QT_VERSION_STR.startswith('5.'))
+ """p3/qt5 get rid of QString wrapper as py3 has native unicode str type"""
+ return not (sys.version_info.major >= 3 or QT_VERSION_STR.startswith("5."))
def natural_sort(list, key=lambda s: s):
@@ -118,7 +128,7 @@ def natural_sort(list, key=lambda s: s):
def get_alphanum_key_func(key):
convert = lambda text: int(text) if text.isdigit() else text
- return lambda s: [convert(c) for c in re.split('([0-9]+)', key(s))]
+ return lambda s: [convert(c) for c in re.split("([0-9]+)", key(s))]
sort_key = get_alphanum_key_func(key)
list.sort(key=sort_key)
@@ -129,8 +139,11 @@ def get_rotate_crop_image(img, points):
# author: biyanhua
d = 0.0
for index in range(-1, 3):
- d += -0.5 * (points[index + 1][1] + points[index][1]) * (
- points[index + 1][0] - points[index][0])
+ d += (
+ -0.5
+ * (points[index + 1][1] + points[index][1])
+ * (points[index + 1][0] - points[index][0])
+ )
if d < 0: # counterclockwise
tmp = np.array(points)
points[1], points[3] = tmp[3], tmp[1]
@@ -139,20 +152,31 @@ def get_rotate_crop_image(img, points):
img_crop_width = int(
max(
np.linalg.norm(points[0] - points[1]),
- np.linalg.norm(points[2] - points[3])))
+ np.linalg.norm(points[2] - points[3]),
+ )
+ )
img_crop_height = int(
max(
np.linalg.norm(points[0] - points[3]),
- np.linalg.norm(points[1] - points[2])))
- pts_std = np.float32([[0, 0], [img_crop_width, 0],
- [img_crop_width, img_crop_height],
- [0, img_crop_height]])
+ np.linalg.norm(points[1] - points[2]),
+ )
+ )
+ pts_std = np.float32(
+ [
+ [0, 0],
+ [img_crop_width, 0],
+ [img_crop_width, img_crop_height],
+ [0, img_crop_height],
+ ]
+ )
M = cv2.getPerspectiveTransform(points, pts_std)
dst_img = cv2.warpPerspective(
img,
- M, (img_crop_width, img_crop_height),
+ M,
+ (img_crop_width, img_crop_height),
borderMode=cv2.BORDER_REPLICATE,
- flags=cv2.INTER_CUBIC)
+ flags=cv2.INTER_CUBIC,
+ )
dst_img_height, dst_img_width = dst_img.shape[0:2]
if dst_img_height * 1.0 / dst_img_width >= 1.5:
dst_img = np.rot90(dst_img)
@@ -161,7 +185,7 @@ def get_rotate_crop_image(img, points):
print(e)
-def boxPad(box, imgShape, pad : int) -> np.array:
+def boxPad(box, imgShape, pad: int) -> np.array:
"""
Pad a box with [pad] pixels on each side.
"""
@@ -171,20 +195,20 @@ def boxPad(box, imgShape, pad : int) -> np.array:
box[2][0], box[2][1] = box[2][0] + pad, box[2][1] + pad
box[3][0], box[3][1] = box[3][0] - pad, box[3][1] + pad
h, w, _ = imgShape
- box[:,0] = np.clip(box[:,0], 0, w)
- box[:,1] = np.clip(box[:,1], 0, h)
+ box[:, 0] = np.clip(box[:, 0], 0, w)
+ box[:, 1] = np.clip(box[:, 1], 0, h)
return box
def expand_list(merged, html_list):
- '''
+ """
Fill blanks according to merged cells
- '''
+ """
sr, er, sc, ec = merged
for i in range(sr, er):
for j in range(sc, ec):
html_list[i][j] = None
- html_list[sr][sc] = ''
+ html_list[sr][sc] = ""
if ec - sc > 1:
html_list[sr][sc] += " colspan={}".format(ec - sc)
if er - sr > 1:
@@ -193,9 +217,9 @@ def expand_list(merged, html_list):
def convert_token(html_list):
- '''
+ """
Convert raw html to label format
- '''
+ """
token_list = [""]
# final html list:
for row in html_list:
@@ -203,16 +227,16 @@ def convert_token(html_list):
for col in row:
if col == None:
continue
- elif col == 'td':
+ elif col == "td":
token_list.extend(["", " | "])
else:
token_list.append("", " | "])
token_list.append("")
token_list.append("")
@@ -221,106 +245,111 @@ def convert_token(html_list):
def rebuild_html_from_ppstructure_label(label_info):
- from html import escape
- html_code = label_info['html']['structure']['tokens'].copy()
- to_insert = [
-        i for i, tag in enumerate(html_code) if tag in ('<td></td>', '>')
- ]
- for i, cell in zip(to_insert[::-1], label_info['html']['cells'][::-1]):
- if cell['tokens']:
- cell = [
- escape(token) if len(token) == 1 else token
- for token in cell['tokens']
- ]
- cell = ''.join(cell)
- html_code.insert(i + 1, cell)
- html_code = ''.join(html_code)
-    html_code = '<html><body><table>{}</table></body></html>'.format(
-        html_code)
- return html_code
-
-
-def stepsInfo(lang='en'):
- if lang == 'ch':
- msg = "1. 安装与运行:使用上述命令安装与运行程序。\n" \
- "2. 打开文件夹:在菜单栏点击 “文件” - 打开目录 选择待标记图片的文件夹.\n" \
- "3. 自动标注:点击 ”自动标注“,使用PPOCR超轻量模型对图片文件名前图片状态为 “X” 的图片进行自动标注。\n" \
- "4. 手动标注:点击 “矩形标注”(推荐直接在英文模式下点击键盘中的 “W”),用户可对当前图片中模型未检出的部分进行手动" \
- "绘制标记框。点击键盘P,则使用四点标注模式(或点击“编辑” - “四点标注”),用户依次点击4个点后,双击左键表示标注完成。\n" \
- "5. 标记框绘制完成后,用户点击 “确认”,检测框会先被预分配一个 “待识别” 标签。\n" \
- "6. 重新识别:将图片中的所有检测画绘制/调整完成后,点击 “重新识别”,PPOCR模型会对当前图片中的**所有检测框**重新识别。\n" \
- "7. 内容更改:双击识别结果,对不准确的识别结果进行手动更改。\n" \
- "8. 保存:点击 “保存”,图片状态切换为 “√”,跳转至下一张。\n" \
- "9. 删除:点击 “删除图像”,图片将会被删除至回收站。\n" \
- "10. 标注结果:关闭应用程序或切换文件路径后,手动保存过的标签将会被存放在所打开图片文件夹下的" \
- "*Label.txt*中。在菜单栏点击 “PaddleOCR” - 保存识别结果后,会将此类图片的识别训练数据保存在*crop_img*文件夹下," \
- "识别标签保存在*rec_gt.txt*中。\n"
+ from html import escape
+
+ html_code = label_info["html"]["structure"]["tokens"].copy()
+    to_insert = [i for i, tag in enumerate(html_code) if tag in ("<td></td>", ">")]
+ for i, cell in zip(to_insert[::-1], label_info["html"]["cells"][::-1]):
+ if cell["tokens"]:
+ cell = [
+ escape(token) if len(token) == 1 else token for token in cell["tokens"]
+ ]
+ cell = "".join(cell)
+ html_code.insert(i + 1, cell)
+ html_code = "".join(html_code)
+ html_code = "".format(html_code)
+ return html_code
+
+
+def stepsInfo(lang="en"):
+ if lang == "ch":
+ msg = (
+ "1. 安装与运行:使用上述命令安装与运行程序。\n"
+ "2. 打开文件夹:在菜单栏点击 “文件” - 打开目录 选择待标记图片的文件夹.\n"
+ "3. 自动标注:点击 ”自动标注“,使用PPOCR超轻量模型对图片文件名前图片状态为 “X” 的图片进行自动标注。\n"
+ "4. 手动标注:点击 “矩形标注”(推荐直接在英文模式下点击键盘中的 “W”),用户可对当前图片中模型未检出的部分进行手动"
+ "绘制标记框。点击键盘P,则使用四点标注模式(或点击“编辑” - “四点标注”),用户依次点击4个点后,双击左键表示标注完成。\n"
+ "5. 标记框绘制完成后,用户点击 “确认”,检测框会先被预分配一个 “待识别” 标签。\n"
+ "6. 重新识别:将图片中的所有检测画绘制/调整完成后,点击 “重新识别”,PPOCR模型会对当前图片中的**所有检测框**重新识别。\n"
+ "7. 内容更改:双击识别结果,对不准确的识别结果进行手动更改。\n"
+ "8. 保存:点击 “保存”,图片状态切换为 “√”,跳转至下一张。\n"
+ "9. 删除:点击 “删除图像”,图片将会被删除至回收站。\n"
+ "10. 标注结果:关闭应用程序或切换文件路径后,手动保存过的标签将会被存放在所打开图片文件夹下的"
+ "*Label.txt*中。在菜单栏点击 “PaddleOCR” - 保存识别结果后,会将此类图片的识别训练数据保存在*crop_img*文件夹下,"
+ "识别标签保存在*rec_gt.txt*中。\n"
+ )
else:
- msg = "1. Build and launch using the instructions above.\n" \
- "2. Click 'Open Dir' in Menu/File to select the folder of the picture.\n" \
- "3. Click 'Auto recognition', use PPOCR model to automatically annotate images which marked with 'X' before the file name." \
- "4. Create Box:\n" \
- "4.1 Click 'Create RectBox' or press 'W' in English keyboard mode to draw a new rectangle detection box. Click and release left mouse to select a region to annotate the text area.\n" \
- "4.2 Press 'P' to enter four-point labeling mode which enables you to create any four-point shape by clicking four points with the left mouse button in succession and DOUBLE CLICK the left mouse as the signal of labeling completion.\n" \
- "5. After the marking frame is drawn, the user clicks 'OK', and the detection frame will be pre-assigned a TEMPORARY label.\n" \
- "6. Click re-Recognition, model will rewrite ALL recognition results in ALL detection box.\n" \
- "7. Double click the result in 'recognition result' list to manually change inaccurate recognition results.\n" \
- "8. Click 'Save', the image status will switch to '√',then the program automatically jump to the next.\n" \
- "9. Click 'Delete Image' and the image will be deleted to the recycle bin.\n" \
- "10. Labeling result: After closing the application or switching the file path, the manually saved label will be stored in *Label.txt* under the opened picture folder.\n" \
- " Click PaddleOCR-Save Recognition Results in the menu bar, the recognition training data of such pictures will be saved in the *crop_img* folder, and the recognition label will be saved in *rec_gt.txt*.\n"
+ msg = (
+ "1. Build and launch using the instructions above.\n"
+ "2. Click 'Open Dir' in Menu/File to select the folder of the picture.\n"
+ "3. Click 'Auto recognition', use PPOCR model to automatically annotate images which marked with 'X' before the file name."
+ "4. Create Box:\n"
+ "4.1 Click 'Create RectBox' or press 'W' in English keyboard mode to draw a new rectangle detection box. Click and release left mouse to select a region to annotate the text area.\n"
+ "4.2 Press 'P' to enter four-point labeling mode which enables you to create any four-point shape by clicking four points with the left mouse button in succession and DOUBLE CLICK the left mouse as the signal of labeling completion.\n"
+ "5. After the marking frame is drawn, the user clicks 'OK', and the detection frame will be pre-assigned a TEMPORARY label.\n"
+ "6. Click re-Recognition, model will rewrite ALL recognition results in ALL detection box.\n"
+ "7. Double click the result in 'recognition result' list to manually change inaccurate recognition results.\n"
+ "8. Click 'Save', the image status will switch to '√',then the program automatically jump to the next.\n"
+ "9. Click 'Delete Image' and the image will be deleted to the recycle bin.\n"
+ "10. Labeling result: After closing the application or switching the file path, the manually saved label will be stored in *Label.txt* under the opened picture folder.\n"
+ " Click PaddleOCR-Save Recognition Results in the menu bar, the recognition training data of such pictures will be saved in the *crop_img* folder, and the recognition label will be saved in *rec_gt.txt*.\n"
+ )
return msg
-def keysInfo(lang='en'):
- if lang == 'ch':
- msg = "快捷键\t\t\t说明\n" \
- "———————————————————————\n" \
- "Ctrl + shift + R\t\t对当前图片的所有标记重新识别\n" \
- "W\t\t\t新建矩形框\n" \
- "Q\t\t\t新建四点框\n" \
- "Ctrl + E\t\t编辑所选框标签\n" \
- "Ctrl + R\t\t重新识别所选标记\n" \
- "Ctrl + C\t\t复制并粘贴选中的标记框\n" \
- "Ctrl + 鼠标左键\t\t多选标记框\n" \
- "Backspace\t\t删除所选框\n" \
- "Ctrl + V\t\t确认本张图片标记\n" \
- "Ctrl + Shift + d\t删除本张图片\n" \
- "D\t\t\t下一张图片\n" \
- "A\t\t\t上一张图片\n" \
- "Ctrl++\t\t\t缩小\n" \
- "Ctrl--\t\t\t放大\n" \
- "↑→↓←\t\t\t移动标记框\n" \
- "———————————————————————\n" \
- "注:Mac用户Command键替换上述Ctrl键"
+def keysInfo(lang="en"):
+ if lang == "ch":
+ msg = (
+ "快捷键\t\t\t说明\n"
+ "———————————————————————\n"
+ "Ctrl + shift + R\t\t对当前图片的所有标记重新识别\n"
+ "W\t\t\t新建矩形框\n"
+ "Q\t\t\t新建四点框\n"
+ "Ctrl + E\t\t编辑所选框标签\n"
+ "Ctrl + R\t\t重新识别所选标记\n"
+ "Ctrl + C\t\t复制并粘贴选中的标记框\n"
+ "Ctrl + 鼠标左键\t\t多选标记框\n"
+ "Backspace\t\t删除所选框\n"
+ "Ctrl + V\t\t确认本张图片标记\n"
+ "Ctrl + Shift + d\t删除本张图片\n"
+ "D\t\t\t下一张图片\n"
+ "A\t\t\t上一张图片\n"
+ "Ctrl++\t\t\t缩小\n"
+ "Ctrl--\t\t\t放大\n"
+ "↑→↓←\t\t\t移动标记框\n"
+ "———————————————————————\n"
+ "注:Mac用户Command键替换上述Ctrl键"
+ )
else:
- msg = "Shortcut Keys\t\tDescription\n" \
- "———————————————————————\n" \
- "Ctrl + shift + R\t\tRe-recognize all the labels\n" \
- "\t\t\tof the current image\n" \
- "\n" \
- "W\t\t\tCreate a rect box\n" \
- "Q\t\t\tCreate a four-points box\n" \
- "Ctrl + E\t\tEdit label of the selected box\n" \
- "Ctrl + R\t\tRe-recognize the selected box\n" \
- "Ctrl + C\t\tCopy and paste the selected\n" \
- "\t\t\tbox\n" \
- "\n" \
- "Ctrl + Left Mouse\tMulti select the label\n" \
- "Button\t\t\tbox\n" \
- "\n" \
- "Backspace\t\tDelete the selected box\n" \
- "Ctrl + V\t\tCheck image\n" \
- "Ctrl + Shift + d\tDelete image\n" \
- "D\t\t\tNext image\n" \
- "A\t\t\tPrevious image\n" \
- "Ctrl++\t\t\tZoom in\n" \
- "Ctrl--\t\t\tZoom out\n" \
- "↑→↓←\t\t\tMove selected box" \
- "———————————————————————\n" \
- "Notice:For Mac users, use the 'Command' key instead of the 'Ctrl' key"
+ msg = (
+ "Shortcut Keys\t\tDescription\n"
+ "———————————————————————\n"
+ "Ctrl + shift + R\t\tRe-recognize all the labels\n"
+ "\t\t\tof the current image\n"
+ "\n"
+ "W\t\t\tCreate a rect box\n"
+ "Q\t\t\tCreate a four-points box\n"
+ "Ctrl + E\t\tEdit label of the selected box\n"
+ "Ctrl + R\t\tRe-recognize the selected box\n"
+ "Ctrl + C\t\tCopy and paste the selected\n"
+ "\t\t\tbox\n"
+ "\n"
+ "Ctrl + Left Mouse\tMulti select the label\n"
+ "Button\t\t\tbox\n"
+ "\n"
+ "Backspace\t\tDelete the selected box\n"
+ "Ctrl + V\t\tCheck image\n"
+ "Ctrl + Shift + d\tDelete image\n"
+ "D\t\t\tNext image\n"
+ "A\t\t\tPrevious image\n"
+ "Ctrl++\t\t\tZoom in\n"
+ "Ctrl--\t\t\tZoom out\n"
+ "↑→↓←\t\t\tMove selected box"
+ "———————————————————————\n"
+ "Notice:For Mac users, use the 'Command' key instead of the 'Ctrl' key"
+ )
return msg
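
Reviewer note: get_rotate_crop_image first enforces clockwise point order
(the signed-area accumulator d is negative for counterclockwise quads, which
triggers the points[1]/points[3] swap), then warps the quad to an
axis-aligned crop and rotates crops taller than 1.5x their width. A hedged
usage sketch (file names and coordinates are placeholders):

    import cv2
    import numpy as np
    from libs.utils import get_rotate_crop_image

    img = cv2.imread("page.jpg")  # placeholder image
    quad = np.float32([[120, 40], [380, 52], [376, 96], [116, 84]])  # 4 corners
    crop = get_rotate_crop_image(img, quad)  # perspective-corrected text crop
    cv2.imwrite("crop_0.jpg", crop)
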
diff --git a/PPOCRLabel/libs/zoomWidget.py b/PPOCRLabel/libs/zoomWidget.py
index 4512d2ab4f..d0fbc38f7e 100644
--- a/PPOCRLabel/libs/zoomWidget.py
+++ b/PPOCRLabel/libs/zoomWidget.py
@@ -20,14 +20,13 @@
class ZoomWidget(QSpinBox):
-
def __init__(self, value=100):
super(ZoomWidget, self).__init__()
self.setButtonSymbols(QAbstractSpinBox.NoButtons)
self.setRange(1, 500)
- self.setSuffix(' %')
+ self.setSuffix(" %")
self.setValue(value)
- self.setToolTip(u'Zoom Level')
+ self.setToolTip("Zoom Level")
self.setStatusTip(self.toolTip())
self.setAlignment(Qt.AlignCenter)
diff --git a/PPOCRLabel/setup.py b/PPOCRLabel/setup.py
index 9770b632bd..93a6c95fab 100644
--- a/PPOCRLabel/setup.py
+++ b/PPOCRLabel/setup.py
@@ -15,38 +15,41 @@
from setuptools import setup
from io import open
-with open('requirements.txt', encoding="utf-8-sig") as f:
+with open("requirements.txt", encoding="utf-8-sig") as f:
requirements = f.readlines()
- requirements.append('tqdm')
+ requirements.append("tqdm")
def readme():
- with open('README.md', encoding="utf-8-sig") as f:
+ with open("README.md", encoding="utf-8-sig") as f:
README = f.read()
return README
setup(
- name='PPOCRLabel',
- packages=['PPOCRLabel'],
- package_data = {'PPOCRLabel': ['libs/*','resources/strings/*','resources/icons/*']},
- package_dir={'PPOCRLabel': ''},
+ name="PPOCRLabel",
+ packages=["PPOCRLabel"],
+ package_data={"PPOCRLabel": ["libs/*", "resources/strings/*", "resources/icons/*"]},
+ package_dir={"PPOCRLabel": ""},
include_package_data=True,
entry_points={"console_scripts": ["PPOCRLabel= PPOCRLabel.PPOCRLabel:main"]},
- version='2.1.3',
+ version="2.1.3",
install_requires=requirements,
- license='Apache License 2.0',
- description='PPOCRLabelv2 is a semi-automatic graphic annotation tool suitable for OCR field, with built-in PP-OCR model to automatically detect and re-recognize data. It is written in Python3 and PyQT5, supporting rectangular box, table, irregular text and key information annotation modes. Annotations can be directly used for the training of PP-OCR detection and recognition models.',
+ license="Apache License 2.0",
+ description="PPOCRLabelv2 is a semi-automatic graphic annotation tool suitable for OCR field, with built-in PP-OCR model to automatically detect and re-recognize data. It is written in Python3 and PyQT5, supporting rectangular box, table, irregular text and key information annotation modes. Annotations can be directly used for the training of PP-OCR detection and recognition models.",
long_description=readme(),
- long_description_content_type='text/markdown',
- url='https://github.com/PaddlePaddle/PaddleOCR',
- download_url='https://github.com/PaddlePaddle/PaddleOCR.git',
+ long_description_content_type="text/markdown",
+ url="https://github.com/PaddlePaddle/PaddleOCR",
+ download_url="https://github.com/PaddlePaddle/PaddleOCR.git",
keywords=[
- 'ocr textdetection textrecognition paddleocr crnn east star-net rosetta ocrlite db chineseocr chinesetextdetection chinesetextrecognition'
+ "ocr textdetection textrecognition paddleocr crnn east star-net rosetta ocrlite db chineseocr chinesetextdetection chinesetextrecognition"
],
classifiers=[
- 'Intended Audience :: Developers', 'Operating System :: OS Independent',
- 'Natural Language :: English',
- 'Programming Language :: Python :: 3.6',
- 'Programming Language :: Python :: 3.7', 'Topic :: Utilities'
- ], )
\ No newline at end of file
+ "Intended Audience :: Developers",
+ "Operating System :: OS Independent",
+ "Natural Language :: English",
+ "Programming Language :: Python :: 3.6",
+ "Programming Language :: Python :: 3.7",
+ "Topic :: Utilities",
+ ],
+)
diff --git a/README.md b/README.md
index 54af81f186..78df1c2fb1 100755
--- a/README.md
+++ b/README.md
@@ -62,7 +62,7 @@ PaddleOCR旨在打造一套丰富、领先、且实用的OCR工具库,助力
- 在线免费体验:
- PP-OCRv4 在线体验地址:https://aistudio.baidu.com/application/detail/7658
- PP-ChatOCRv2 在线体验地址:https://aistudio.baidu.com/application/detail/10368
-
+
- 一行命令快速使用:[快速开始(中英文/多语言/文档分析)](./doc/doc_ch/quickstart.md)
- 移动端demo体验:[安装包DEMO下载地址](https://ai.baidu.com/easyedge/app/openSource?from=paddlelite)(基于EasyEdge和Paddle-Lite, 支持iOS和Android系统)
diff --git a/README_en.md b/README_en.md
index fa2789e889..747b233b60 100644
--- a/README_en.md
+++ b/README_en.md
@@ -74,7 +74,7 @@ PaddleOCR support a variety of cutting-edge algorithms related to OCR, and devel
- One line of code quick use: [Quick Start(Chinese/English/Multilingual/Document Analysis](./doc/doc_en/quickstart_en.md)
- Full-process experience of training, inference, and high-performance deployment in the Paddle AI suite (PaddleX):
- PP-OCRv4:https://aistudio.baidu.com/aistudio/modelsdetail?modelId=286
- - PP-ChatOCR:https://aistudio.baidu.com/aistudio/modelsdetail?modelId=332
+ - PP-ChatOCR:https://aistudio.baidu.com/aistudio/modelsdetail?modelId=332
- Mobile demo experience:[Installation DEMO](https://ai.baidu.com/easyedge/app/openSource?from=paddlelite)(Based on EasyEdge and Paddle-Lite, support iOS and Android systems)
diff --git a/StyleText/arch/base_module.py b/StyleText/arch/base_module.py
index da2b6b834c..e7b44c02ce 100644
--- a/StyleText/arch/base_module.py
+++ b/StyleText/arch/base_module.py
@@ -18,19 +18,21 @@
class CBN(nn.Layer):
- def __init__(self,
- name,
- in_channels,
- out_channels,
- kernel_size,
- stride=1,
- padding=0,
- dilation=1,
- groups=1,
- use_bias=False,
- norm_layer=None,
- act=None,
- act_attr=None):
+ def __init__(
+ self,
+ name,
+ in_channels,
+ out_channels,
+ kernel_size,
+ stride=1,
+ padding=0,
+ dilation=1,
+ groups=1,
+ use_bias=False,
+ norm_layer=None,
+ act=None,
+ act_attr=None,
+ ):
super(CBN, self).__init__()
if use_bias:
bias_attr = paddle.ParamAttr(name=name + "_bias")
@@ -45,16 +47,17 @@ def __init__(self,
dilation=dilation,
groups=groups,
weight_attr=paddle.ParamAttr(name=name + "_weights"),
- bias_attr=bias_attr)
+ bias_attr=bias_attr,
+ )
if norm_layer:
self._norm_layer = getattr(paddle.nn, norm_layer)(
- num_features=out_channels, name=name + "_bn")
+ num_features=out_channels, name=name + "_bn"
+ )
else:
self._norm_layer = None
if act:
if act_attr:
- self._act = getattr(paddle.nn, act)(**act_attr,
- name=name + "_" + act)
+ self._act = getattr(paddle.nn, act)(**act_attr, name=name + "_" + act)
else:
self._act = getattr(paddle.nn, act)(name=name + "_" + act)
else:
@@ -70,19 +73,21 @@ def forward(self, x):
class SNConv(nn.Layer):
- def __init__(self,
- name,
- in_channels,
- out_channels,
- kernel_size,
- stride=1,
- padding=0,
- dilation=1,
- groups=1,
- use_bias=False,
- norm_layer=None,
- act=None,
- act_attr=None):
+ def __init__(
+ self,
+ name,
+ in_channels,
+ out_channels,
+ kernel_size,
+ stride=1,
+ padding=0,
+ dilation=1,
+ groups=1,
+ use_bias=False,
+ norm_layer=None,
+ act=None,
+ act_attr=None,
+ ):
super(SNConv, self).__init__()
if use_bias:
bias_attr = paddle.ParamAttr(name=name + "_bias")
@@ -98,16 +103,18 @@ def __init__(self,
dilation=dilation,
groups=groups,
weight_attr=paddle.ParamAttr(name=name + "_weights"),
- bias_attr=bias_attr))
+ bias_attr=bias_attr,
+ )
+ )
if norm_layer:
self._norm_layer = getattr(paddle.nn, norm_layer)(
- num_features=out_channels, name=name + "_bn")
+ num_features=out_channels, name=name + "_bn"
+ )
else:
self._norm_layer = None
if act:
if act_attr:
- self._act = getattr(paddle.nn, act)(**act_attr,
- name=name + "_" + act)
+ self._act = getattr(paddle.nn, act)(**act_attr, name=name + "_" + act)
else:
self._act = getattr(paddle.nn, act)(name=name + "_" + act)
else:
@@ -123,20 +130,22 @@ def forward(self, x):
class SNConvTranspose(nn.Layer):
- def __init__(self,
- name,
- in_channels,
- out_channels,
- kernel_size,
- stride=1,
- padding=0,
- output_padding=0,
- dilation=1,
- groups=1,
- use_bias=False,
- norm_layer=None,
- act=None,
- act_attr=None):
+ def __init__(
+ self,
+ name,
+ in_channels,
+ out_channels,
+ kernel_size,
+ stride=1,
+ padding=0,
+ output_padding=0,
+ dilation=1,
+ groups=1,
+ use_bias=False,
+ norm_layer=None,
+ act=None,
+ act_attr=None,
+ ):
super(SNConvTranspose, self).__init__()
if use_bias:
bias_attr = paddle.ParamAttr(name=name + "_bias")
@@ -153,16 +162,18 @@ def __init__(self,
dilation=dilation,
groups=groups,
weight_attr=paddle.ParamAttr(name=name + "_weights"),
- bias_attr=bias_attr))
+ bias_attr=bias_attr,
+ )
+ )
if norm_layer:
self._norm_layer = getattr(paddle.nn, norm_layer)(
- num_features=out_channels, name=name + "_bn")
+ num_features=out_channels, name=name + "_bn"
+ )
else:
self._norm_layer = None
if act:
if act_attr:
- self._act = getattr(paddle.nn, act)(**act_attr,
- name=name + "_" + act)
+ self._act = getattr(paddle.nn, act)(**act_attr, name=name + "_" + act)
else:
self._act = getattr(paddle.nn, act)(name=name + "_" + act)
else:
@@ -178,8 +189,7 @@ def forward(self, x):
class MiddleNet(nn.Layer):
- def __init__(self, name, in_channels, mid_channels, out_channels,
- use_bias):
+ def __init__(self, name, in_channels, mid_channels, out_channels, use_bias):
super(MiddleNet, self).__init__()
self._sn_conv1 = SNConv(
name=name + "_sn_conv1",
@@ -188,23 +198,25 @@ def __init__(self, name, in_channels, mid_channels, out_channels,
kernel_size=1,
use_bias=use_bias,
norm_layer=None,
- act=None)
+ act=None,
+ )
self._pad2d = nn.Pad2D(padding=[1, 1, 1, 1], mode="replicate")
self._sn_conv2 = SNConv(
name=name + "_sn_conv2",
in_channels=mid_channels,
out_channels=mid_channels,
kernel_size=3,
- use_bias=use_bias)
+ use_bias=use_bias,
+ )
self._sn_conv3 = SNConv(
name=name + "_sn_conv3",
in_channels=mid_channels,
out_channels=out_channels,
kernel_size=1,
- use_bias=use_bias)
+ use_bias=use_bias,
+ )
def forward(self, x):
-
sn_conv1 = self._sn_conv1.forward(x)
pad_2d = self._pad2d.forward(sn_conv1)
sn_conv2 = self._sn_conv2.forward(pad_2d)
@@ -213,8 +225,7 @@ def forward(self, x):
class ResBlock(nn.Layer):
- def __init__(self, name, channels, norm_layer, use_dropout, use_dilation,
- use_bias):
+ def __init__(self, name, channels, norm_layer, use_dropout, use_dilation, use_bias):
super(ResBlock, self).__init__()
if use_dilation:
padding_mat = [1, 1, 1, 1]
@@ -231,7 +242,8 @@ def __init__(self, name, channels, norm_layer, use_dropout, use_dilation,
norm_layer=norm_layer,
use_bias=use_bias,
act="ReLU",
- act_attr=None)
+ act_attr=None,
+ )
if use_dropout:
self._dropout = nn.Dropout(0.5)
else:
@@ -245,7 +257,8 @@ def __init__(self, name, channels, norm_layer, use_dropout, use_dilation,
norm_layer=norm_layer,
use_bias=use_bias,
act="ReLU",
- act_attr=None)
+ act_attr=None,
+ )
def forward(self, x):
pad1 = self._pad1.forward(x)
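
Reviewer note: CBN, SNConv and SNConvTranspose all share the same
conv -> optional norm -> optional activation recipe, resolving norm_layer and
act by name through getattr(paddle.nn, ...). A minimal instantiation sketch
(argument values are illustrative, not taken from a StyleText config):

    import paddle
    from arch.base_module import CBN

    block = CBN(
        name="demo_cbn",
        in_channels=3,
        out_channels=32,
        kernel_size=3,
        stride=2,
        padding=1,
        norm_layer="BatchNorm2D",  # looked up on paddle.nn, gets num_features
        act="ReLU",
    )
    y = block(paddle.randn([1, 3, 64, 64]))  # -> shape [1, 32, 32, 32]
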
diff --git a/StyleText/arch/decoder.py b/StyleText/arch/decoder.py
index 36f07c5998..5d613265ed 100644
--- a/StyleText/arch/decoder.py
+++ b/StyleText/arch/decoder.py
@@ -18,9 +18,21 @@
class Decoder(nn.Layer):
- def __init__(self, name, encode_dim, out_channels, use_bias, norm_layer,
- act, act_attr, conv_block_dropout, conv_block_num,
- conv_block_dilation, out_conv_act, out_conv_act_attr):
+ def __init__(
+ self,
+ name,
+ encode_dim,
+ out_channels,
+ use_bias,
+ norm_layer,
+ act,
+ act_attr,
+ conv_block_dropout,
+ conv_block_num,
+ conv_block_dilation,
+ out_conv_act,
+ out_conv_act_attr,
+ ):
super(Decoder, self).__init__()
conv_blocks = []
for i in range(conv_block_num):
@@ -31,7 +43,9 @@ def __init__(self, name, encode_dim, out_channels, use_bias, norm_layer,
norm_layer=norm_layer,
use_dropout=conv_block_dropout,
use_dilation=conv_block_dilation,
- use_bias=use_bias))
+ use_bias=use_bias,
+ )
+ )
self.conv_blocks = nn.Sequential(*conv_blocks)
self._up1 = SNConvTranspose(
name=name + "_up1",
@@ -44,7 +58,8 @@ def __init__(self, name, encode_dim, out_channels, use_bias, norm_layer,
use_bias=use_bias,
norm_layer=norm_layer,
act=act,
- act_attr=act_attr)
+ act_attr=act_attr,
+ )
self._up2 = SNConvTranspose(
name=name + "_up2",
in_channels=encode_dim * 4,
@@ -56,7 +71,8 @@ def __init__(self, name, encode_dim, out_channels, use_bias, norm_layer,
use_bias=use_bias,
norm_layer=norm_layer,
act=act,
- act_attr=act_attr)
+ act_attr=act_attr,
+ )
self._up3 = SNConvTranspose(
name=name + "_up3",
in_channels=encode_dim * 2,
@@ -68,7 +84,8 @@ def __init__(self, name, encode_dim, out_channels, use_bias, norm_layer,
use_bias=use_bias,
norm_layer=norm_layer,
act=act,
- act_attr=act_attr)
+ act_attr=act_attr,
+ )
self._pad2d = paddle.nn.Pad2D([1, 1, 1, 1], mode="replicate")
self._out_conv = SNConv(
name=name + "_out_conv",
@@ -78,7 +95,8 @@ def __init__(self, name, encode_dim, out_channels, use_bias, norm_layer,
use_bias=use_bias,
norm_layer=None,
act=out_conv_act,
- act_attr=out_conv_act_attr)
+ act_attr=out_conv_act_attr,
+ )
def forward(self, x):
if isinstance(x, (list, tuple)):
@@ -94,9 +112,21 @@ def forward(self, x):
class DecoderUnet(nn.Layer):
- def __init__(self, name, encode_dim, out_channels, use_bias, norm_layer,
- act, act_attr, conv_block_dropout, conv_block_num,
- conv_block_dilation, out_conv_act, out_conv_act_attr):
+ def __init__(
+ self,
+ name,
+ encode_dim,
+ out_channels,
+ use_bias,
+ norm_layer,
+ act,
+ act_attr,
+ conv_block_dropout,
+ conv_block_num,
+ conv_block_dilation,
+ out_conv_act,
+ out_conv_act_attr,
+ ):
super(DecoderUnet, self).__init__()
conv_blocks = []
for i in range(conv_block_num):
@@ -107,7 +137,9 @@ def __init__(self, name, encode_dim, out_channels, use_bias, norm_layer,
norm_layer=norm_layer,
use_dropout=conv_block_dropout,
use_dilation=conv_block_dilation,
- use_bias=use_bias))
+ use_bias=use_bias,
+ )
+ )
self._conv_blocks = nn.Sequential(*conv_blocks)
self._up1 = SNConvTranspose(
name=name + "_up1",
@@ -120,7 +152,8 @@ def __init__(self, name, encode_dim, out_channels, use_bias, norm_layer,
use_bias=use_bias,
norm_layer=norm_layer,
act=act,
- act_attr=act_attr)
+ act_attr=act_attr,
+ )
self._up2 = SNConvTranspose(
name=name + "_up2",
in_channels=encode_dim * 8,
@@ -132,7 +165,8 @@ def __init__(self, name, encode_dim, out_channels, use_bias, norm_layer,
use_bias=use_bias,
norm_layer=norm_layer,
act=act,
- act_attr=act_attr)
+ act_attr=act_attr,
+ )
self._up3 = SNConvTranspose(
name=name + "_up3",
in_channels=encode_dim * 4,
@@ -144,7 +178,8 @@ def __init__(self, name, encode_dim, out_channels, use_bias, norm_layer,
use_bias=use_bias,
norm_layer=norm_layer,
act=act,
- act_attr=act_attr)
+ act_attr=act_attr,
+ )
self._pad2d = paddle.nn.Pad2D([1, 1, 1, 1], mode="replicate")
self._out_conv = SNConv(
name=name + "_out_conv",
@@ -154,29 +189,40 @@ def __init__(self, name, encode_dim, out_channels, use_bias, norm_layer,
use_bias=use_bias,
norm_layer=None,
act=out_conv_act,
- act_attr=out_conv_act_attr)
+ act_attr=out_conv_act_attr,
+ )
def forward(self, x, y, feature2, feature1):
output_dict = dict()
- output_dict["conv_blocks"] = self._conv_blocks(
- paddle.concat(
- (x, y), axis=1))
+ output_dict["conv_blocks"] = self._conv_blocks(paddle.concat((x, y), axis=1))
output_dict["up1"] = self._up1.forward(output_dict["conv_blocks"])
output_dict["up2"] = self._up2.forward(
- paddle.concat(
- (output_dict["up1"], feature2), axis=1))
+ paddle.concat((output_dict["up1"], feature2), axis=1)
+ )
output_dict["up3"] = self._up3.forward(
- paddle.concat(
- (output_dict["up2"], feature1), axis=1))
+ paddle.concat((output_dict["up2"], feature1), axis=1)
+ )
output_dict["pad2d"] = self._pad2d.forward(output_dict["up3"])
output_dict["out_conv"] = self._out_conv.forward(output_dict["pad2d"])
return output_dict
class SingleDecoder(nn.Layer):
- def __init__(self, name, encode_dim, out_channels, use_bias, norm_layer,
- act, act_attr, conv_block_dropout, conv_block_num,
- conv_block_dilation, out_conv_act, out_conv_act_attr):
+ def __init__(
+ self,
+ name,
+ encode_dim,
+ out_channels,
+ use_bias,
+ norm_layer,
+ act,
+ act_attr,
+ conv_block_dropout,
+ conv_block_num,
+ conv_block_dilation,
+ out_conv_act,
+ out_conv_act_attr,
+ ):
super(SingleDecoder, self).__init__()
conv_blocks = []
for i in range(conv_block_num):
@@ -187,7 +233,9 @@ def __init__(self, name, encode_dim, out_channels, use_bias, norm_layer,
norm_layer=norm_layer,
use_dropout=conv_block_dropout,
use_dilation=conv_block_dilation,
- use_bias=use_bias))
+ use_bias=use_bias,
+ )
+ )
self._conv_blocks = nn.Sequential(*conv_blocks)
self._up1 = SNConvTranspose(
name=name + "_up1",
@@ -200,7 +248,8 @@ def __init__(self, name, encode_dim, out_channels, use_bias, norm_layer,
use_bias=use_bias,
norm_layer=norm_layer,
act=act,
- act_attr=act_attr)
+ act_attr=act_attr,
+ )
self._up2 = SNConvTranspose(
name=name + "_up2",
in_channels=encode_dim * 8,
@@ -212,7 +261,8 @@ def __init__(self, name, encode_dim, out_channels, use_bias, norm_layer,
use_bias=use_bias,
norm_layer=norm_layer,
act=act,
- act_attr=act_attr)
+ act_attr=act_attr,
+ )
self._up3 = SNConvTranspose(
name=name + "_up3",
in_channels=encode_dim * 4,
@@ -224,7 +274,8 @@ def __init__(self, name, encode_dim, out_channels, use_bias, norm_layer,
use_bias=use_bias,
norm_layer=norm_layer,
act=act,
- act_attr=act_attr)
+ act_attr=act_attr,
+ )
self._pad2d = paddle.nn.Pad2D([1, 1, 1, 1], mode="replicate")
self._out_conv = SNConv(
name=name + "_out_conv",
@@ -234,18 +285,19 @@ def __init__(self, name, encode_dim, out_channels, use_bias, norm_layer,
use_bias=use_bias,
norm_layer=None,
act=out_conv_act,
- act_attr=out_conv_act_attr)
+ act_attr=out_conv_act_attr,
+ )
def forward(self, x, feature2, feature1):
output_dict = dict()
output_dict["conv_blocks"] = self._conv_blocks.forward(x)
output_dict["up1"] = self._up1.forward(output_dict["conv_blocks"])
output_dict["up2"] = self._up2.forward(
- paddle.concat(
- (output_dict["up1"], feature2), axis=1))
+ paddle.concat((output_dict["up1"], feature2), axis=1)
+ )
output_dict["up3"] = self._up3.forward(
- paddle.concat(
- (output_dict["up2"], feature1), axis=1))
+ paddle.concat((output_dict["up2"], feature1), axis=1)
+ )
output_dict["pad2d"] = self._pad2d.forward(output_dict["up3"])
output_dict["out_conv"] = self._out_conv.forward(output_dict["pad2d"])
return output_dict
diff --git a/StyleText/arch/encoder.py b/StyleText/arch/encoder.py
index b884cda293..c545b1c124 100644
--- a/StyleText/arch/encoder.py
+++ b/StyleText/arch/encoder.py
@@ -18,9 +18,19 @@
class Encoder(nn.Layer):
- def __init__(self, name, in_channels, encode_dim, use_bias, norm_layer,
- act, act_attr, conv_block_dropout, conv_block_num,
- conv_block_dilation):
+ def __init__(
+ self,
+ name,
+ in_channels,
+ encode_dim,
+ use_bias,
+ norm_layer,
+ act,
+ act_attr,
+ conv_block_dropout,
+ conv_block_num,
+ conv_block_dilation,
+ ):
super(Encoder, self).__init__()
self._pad2d = paddle.nn.Pad2D([3, 3, 3, 3], mode="replicate")
self._in_conv = SNConv(
@@ -31,7 +41,8 @@ def __init__(self, name, in_channels, encode_dim, use_bias, norm_layer,
use_bias=use_bias,
norm_layer=norm_layer,
act=act,
- act_attr=act_attr)
+ act_attr=act_attr,
+ )
self._down1 = SNConv(
name=name + "_down1",
in_channels=encode_dim,
@@ -42,7 +53,8 @@ def __init__(self, name, in_channels, encode_dim, use_bias, norm_layer,
use_bias=use_bias,
norm_layer=norm_layer,
act=act,
- act_attr=act_attr)
+ act_attr=act_attr,
+ )
self._down2 = SNConv(
name=name + "_down2",
in_channels=encode_dim * 2,
@@ -53,7 +65,8 @@ def __init__(self, name, in_channels, encode_dim, use_bias, norm_layer,
use_bias=use_bias,
norm_layer=norm_layer,
act=act,
- act_attr=act_attr)
+ act_attr=act_attr,
+ )
self._down3 = SNConv(
name=name + "_down3",
in_channels=encode_dim * 4,
@@ -64,7 +77,8 @@ def __init__(self, name, in_channels, encode_dim, use_bias, norm_layer,
use_bias=use_bias,
norm_layer=norm_layer,
act=act,
- act_attr=act_attr)
+ act_attr=act_attr,
+ )
conv_blocks = []
for i in range(conv_block_num):
conv_blocks.append(
@@ -74,7 +88,9 @@ def __init__(self, name, in_channels, encode_dim, use_bias, norm_layer,
norm_layer=norm_layer,
use_dropout=conv_block_dropout,
use_dilation=conv_block_dilation,
- use_bias=use_bias))
+ use_bias=use_bias,
+ )
+ )
self._conv_blocks = nn.Sequential(*conv_blocks)
def forward(self, x):
@@ -89,8 +105,9 @@ def forward(self, x):
class EncoderUnet(nn.Layer):
- def __init__(self, name, in_channels, encode_dim, use_bias, norm_layer,
- act, act_attr):
+ def __init__(
+ self, name, in_channels, encode_dim, use_bias, norm_layer, act, act_attr
+ ):
super(EncoderUnet, self).__init__()
self._pad2d = paddle.nn.Pad2D([3, 3, 3, 3], mode="replicate")
self._in_conv = SNConv(
@@ -101,7 +118,8 @@ def __init__(self, name, in_channels, encode_dim, use_bias, norm_layer,
use_bias=use_bias,
norm_layer=norm_layer,
act=act,
- act_attr=act_attr)
+ act_attr=act_attr,
+ )
self._down1 = SNConv(
name=name + "_down1",
in_channels=encode_dim,
@@ -112,7 +130,8 @@ def __init__(self, name, in_channels, encode_dim, use_bias, norm_layer,
use_bias=use_bias,
norm_layer=norm_layer,
act=act,
- act_attr=act_attr)
+ act_attr=act_attr,
+ )
self._down2 = SNConv(
name=name + "_down2",
in_channels=encode_dim * 2,
@@ -123,7 +142,8 @@ def __init__(self, name, in_channels, encode_dim, use_bias, norm_layer,
use_bias=use_bias,
norm_layer=norm_layer,
act=act,
- act_attr=act_attr)
+ act_attr=act_attr,
+ )
self._down3 = SNConv(
name=name + "_down3",
in_channels=encode_dim * 2,
@@ -134,7 +154,8 @@ def __init__(self, name, in_channels, encode_dim, use_bias, norm_layer,
use_bias=use_bias,
norm_layer=norm_layer,
act=act,
- act_attr=act_attr)
+ act_attr=act_attr,
+ )
self._down4 = SNConv(
name=name + "_down4",
in_channels=encode_dim * 2,
@@ -145,7 +166,8 @@ def __init__(self, name, in_channels, encode_dim, use_bias, norm_layer,
use_bias=use_bias,
norm_layer=norm_layer,
act=act,
- act_attr=act_attr)
+ act_attr=act_attr,
+ )
self._up1 = SNConvTranspose(
name=name + "_up1",
in_channels=encode_dim * 2,
@@ -156,7 +178,8 @@ def __init__(self, name, in_channels, encode_dim, use_bias, norm_layer,
use_bias=use_bias,
norm_layer=norm_layer,
act=act,
- act_attr=act_attr)
+ act_attr=act_attr,
+ )
self._up2 = SNConvTranspose(
name=name + "_up2",
in_channels=encode_dim * 4,
@@ -167,20 +190,22 @@ def __init__(self, name, in_channels, encode_dim, use_bias, norm_layer,
use_bias=use_bias,
norm_layer=norm_layer,
act=act,
- act_attr=act_attr)
+ act_attr=act_attr,
+ )
def forward(self, x):
output_dict = dict()
x = self._pad2d(x)
- output_dict['in_conv'] = self._in_conv.forward(x)
- output_dict['down1'] = self._down1.forward(output_dict['in_conv'])
- output_dict['down2'] = self._down2.forward(output_dict['down1'])
- output_dict['down3'] = self._down3.forward(output_dict['down2'])
- output_dict['down4'] = self._down4.forward(output_dict['down3'])
- output_dict['up1'] = self._up1.forward(output_dict['down4'])
- output_dict['up2'] = self._up2.forward(
- paddle.concat(
- (output_dict['down3'], output_dict['up1']), axis=1))
- output_dict['concat'] = paddle.concat(
- (output_dict['down2'], output_dict['up2']), axis=1)
+ output_dict["in_conv"] = self._in_conv.forward(x)
+ output_dict["down1"] = self._down1.forward(output_dict["in_conv"])
+ output_dict["down2"] = self._down2.forward(output_dict["down1"])
+ output_dict["down3"] = self._down3.forward(output_dict["down2"])
+ output_dict["down4"] = self._down4.forward(output_dict["down3"])
+ output_dict["up1"] = self._up1.forward(output_dict["down4"])
+ output_dict["up2"] = self._up2.forward(
+ paddle.concat((output_dict["down3"], output_dict["up1"]), axis=1)
+ )
+ output_dict["concat"] = paddle.concat(
+ (output_dict["down2"], output_dict["up2"]), axis=1
+ )
return output_dict
diff --git a/StyleText/arch/spectral_norm.py b/StyleText/arch/spectral_norm.py
index 21d0afc8d4..a19f8cdef1 100644
--- a/StyleText/arch/spectral_norm.py
+++ b/StyleText/arch/spectral_norm.py
@@ -16,20 +16,21 @@
import paddle.nn.functional as F
-def normal_(x, mean=0., std=1.):
+def normal_(x, mean=0.0, std=1.0):
temp_value = paddle.normal(mean, std, shape=x.shape)
x.set_value(temp_value)
return x
class SpectralNorm(object):
- def __init__(self, name='weight', n_power_iterations=1, dim=0, eps=1e-12):
+ def __init__(self, name="weight", n_power_iterations=1, dim=0, eps=1e-12):
self.name = name
self.dim = dim
if n_power_iterations <= 0:
- raise ValueError('Expected n_power_iterations to be positive, but '
- 'got n_power_iterations={}'.format(
- n_power_iterations))
+ raise ValueError(
+ "Expected n_power_iterations to be positive, but "
+ "got n_power_iterations={}".format(n_power_iterations)
+ )
self.n_power_iterations = n_power_iterations
self.eps = eps
@@ -37,19 +38,18 @@ def reshape_weight_to_matrix(self, weight):
weight_mat = weight
if self.dim != 0:
# transpose dim to front
- weight_mat = weight_mat.transpose([
- self.dim,
- * [d for d in range(weight_mat.dim()) if d != self.dim]
- ])
+ weight_mat = weight_mat.transpose(
+ [self.dim, *[d for d in range(weight_mat.dim()) if d != self.dim]]
+ )
height = weight_mat.shape[0]
return weight_mat.reshape([height, -1])
def compute_weight(self, module, do_power_iteration):
- weight = getattr(module, self.name + '_orig')
- u = getattr(module, self.name + '_u')
- v = getattr(module, self.name + '_v')
+ weight = getattr(module, self.name + "_orig")
+ u = getattr(module, self.name + "_u")
+ v = getattr(module, self.name + "_v")
weight_mat = self.reshape_weight_to_matrix(weight)
if do_power_iteration:
@@ -58,18 +58,20 @@ def compute_weight(self, module, do_power_iteration):
v.set_value(
F.normalize(
paddle.matmul(
- weight_mat,
- u,
- transpose_x=True,
- transpose_y=False),
+ weight_mat, u, transpose_x=True, transpose_y=False
+ ),
axis=0,
- epsilon=self.eps, ))
+ epsilon=self.eps,
+ )
+ )
u.set_value(
F.normalize(
paddle.matmul(weight_mat, v),
axis=0,
- epsilon=self.eps, ))
+ epsilon=self.eps,
+ )
+ )
if self.n_power_iterations > 0:
u = u.clone()
v = v.clone()
@@ -82,9 +84,9 @@ def remove(self, module):
with paddle.no_grad():
weight = self.compute_weight(module, do_power_iteration=False)
delattr(module, self.name)
- delattr(module, self.name + '_u')
- delattr(module, self.name + '_v')
- delattr(module, self.name + '_orig')
+ delattr(module, self.name + "_u")
+ delattr(module, self.name + "_v")
+ delattr(module, self.name + "_orig")
module.add_parameter(self.name, weight.detach())
@@ -92,8 +94,8 @@ def __call__(self, module, inputs):
setattr(
module,
self.name,
- self.compute_weight(
- module, do_power_iteration=module.training))
+ self.compute_weight(module, do_power_iteration=module.training),
+ )
@staticmethod
def apply(module, name, n_power_iterations, dim, eps):
@@ -101,7 +103,8 @@ def apply(module, name, n_power_iterations, dim, eps):
if isinstance(hook, SpectralNorm) and hook.name == name:
raise RuntimeError(
"Cannot register two spectral_norm hooks on "
- "the same parameter {}".format(name))
+ "the same parameter {}".format(name)
+ )
fn = SpectralNorm(name, n_power_iterations, dim, eps)
weight = module._parameters[name]
@@ -112,9 +115,9 @@ def apply(module, name, n_power_iterations, dim, eps):
# randomly initialize u and v
u = module.create_parameter([h])
- u = normal_(u, 0., 1.)
+ u = normal_(u, 0.0, 1.0)
v = module.create_parameter([w])
- v = normal_(v, 0., 1.)
+ v = normal_(v, 0.0, 1.0)
u = F.normalize(u, axis=0, epsilon=fn.eps)
v = F.normalize(v, axis=0, epsilon=fn.eps)
@@ -134,15 +137,12 @@ def apply(module, name, n_power_iterations, dim, eps):
return fn
-def spectral_norm(module,
- name='weight',
- n_power_iterations=1,
- eps=1e-12,
- dim=None):
-
+def spectral_norm(module, name="weight", n_power_iterations=1, eps=1e-12, dim=None):
if dim is None:
- if isinstance(module, (nn.Conv1DTranspose, nn.Conv2DTranspose,
- nn.Conv3DTranspose, nn.Linear)):
+ if isinstance(
+ module,
+ (nn.Conv1DTranspose, nn.Conv2DTranspose, nn.Conv3DTranspose, nn.Linear),
+ ):
dim = 1
else:
dim = 0
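
Note on spectral_norm.py: the hook constrains a layer's weight by its largest singular value, which compute_weight estimates with the u/v power iteration restyled in the hunks above. A NumPy-only sketch of that estimate (a hypothetical helper, not part of the repo), checked against a full SVD:

    import numpy as np

    def top_singular_value(w, n_iters=30, eps=1e-12):
        # power iteration over W and W^T, mirroring the u/v updates in
        # SpectralNorm.compute_weight; the estimate is sigma ~= u^T W v
        u = np.random.randn(w.shape[0])
        u /= np.linalg.norm(u) + eps
        for _ in range(n_iters):
            v = w.T @ u
            v /= np.linalg.norm(v) + eps
            u = w @ v
            u /= np.linalg.norm(u) + eps
        return u @ w @ v

    m = np.random.default_rng(0).standard_normal((16, 8))
    print(top_singular_value(m))                  # power-iteration estimate
    print(np.linalg.svd(m, compute_uv=False)[0])  # exact largest singular value

The normalized weight used on each forward pass is then weight / sigma, which keeps the layer's Lipschitz constant near 1.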
diff --git a/StyleText/arch/style_text_rec.py b/StyleText/arch/style_text_rec.py
index 599927ce3e..d8b38238e9 100644
--- a/StyleText/arch/style_text_rec.py
+++ b/StyleText/arch/style_text_rec.py
@@ -25,32 +25,32 @@ class StyleTextRec(nn.Layer):
def __init__(self, config):
super(StyleTextRec, self).__init__()
self.logger = get_logger()
- self.text_generator = TextGenerator(config["Predictor"][
- "text_generator"])
- self.bg_generator = BgGeneratorWithMask(config["Predictor"][
- "bg_generator"])
- self.fusion_generator = FusionGeneratorSimple(config["Predictor"][
- "fusion_generator"])
+ self.text_generator = TextGenerator(config["Predictor"]["text_generator"])
+ self.bg_generator = BgGeneratorWithMask(config["Predictor"]["bg_generator"])
+ self.fusion_generator = FusionGeneratorSimple(
+ config["Predictor"]["fusion_generator"]
+ )
bg_generator_pretrain = config["Predictor"]["bg_generator"]["pretrain"]
- text_generator_pretrain = config["Predictor"]["text_generator"][
- "pretrain"]
- fusion_generator_pretrain = config["Predictor"]["fusion_generator"][
- "pretrain"]
+ text_generator_pretrain = config["Predictor"]["text_generator"]["pretrain"]
+ fusion_generator_pretrain = config["Predictor"]["fusion_generator"]["pretrain"]
load_dygraph_pretrain(
self.bg_generator,
self.logger,
path=bg_generator_pretrain,
- load_static_weights=False)
+ load_static_weights=False,
+ )
load_dygraph_pretrain(
self.text_generator,
self.logger,
path=text_generator_pretrain,
- load_static_weights=False)
+ load_static_weights=False,
+ )
load_dygraph_pretrain(
self.fusion_generator,
self.logger,
path=fusion_generator_pretrain,
- load_static_weights=False)
+ load_static_weights=False,
+ )
def forward(self, style_input, text_input):
text_gen_output = self.text_generator.forward(style_input, text_input)
@@ -95,7 +95,8 @@ def __init__(self, config):
act_attr=None,
conv_block_dropout=conv_block_dropout,
conv_block_num=conv_block_num,
- conv_block_dilation=conv_block_dilation)
+ conv_block_dilation=conv_block_dilation,
+ )
self.encoder_style = Encoder(
name=name + "_encoder_style",
in_channels=3,
@@ -106,7 +107,8 @@ def __init__(self, config):
act_attr=None,
conv_block_dropout=conv_block_dropout,
conv_block_num=conv_block_num,
- conv_block_dilation=conv_block_dilation)
+ conv_block_dilation=conv_block_dilation,
+ )
self.decoder_text = Decoder(
name=name + "_decoder_text",
encode_dim=encode_dim,
@@ -119,7 +121,8 @@ def __init__(self, config):
conv_block_num=conv_block_num,
conv_block_dilation=conv_block_dilation,
out_conv_act="Tanh",
- out_conv_act_attr=None)
+ out_conv_act_attr=None,
+ )
self.decoder_sk = Decoder(
name=name + "_decoder_sk",
encode_dim=encode_dim,
@@ -132,22 +135,24 @@ def __init__(self, config):
conv_block_num=conv_block_num,
conv_block_dilation=conv_block_dilation,
out_conv_act="Sigmoid",
- out_conv_act_attr=None)
+ out_conv_act_attr=None,
+ )
self.middle = MiddleNet(
name=name + "_middle_net",
in_channels=int(encode_dim / 2) + 1,
mid_channels=encode_dim,
out_channels=3,
- use_bias=use_bias)
+ use_bias=use_bias,
+ )
def forward(self, style_input, text_input):
style_feature = self.encoder_style.forward(style_input)["res_blocks"]
text_feature = self.encoder_text.forward(text_input)["res_blocks"]
- fake_c_temp = self.decoder_text.forward([text_feature,
- style_feature])["out_conv"]
- fake_sk = self.decoder_sk.forward([text_feature,
- style_feature])["out_conv"]
+ fake_c_temp = self.decoder_text.forward([text_feature, style_feature])[
+ "out_conv"
+ ]
+ fake_sk = self.decoder_sk.forward([text_feature, style_feature])["out_conv"]
fake_text = self.middle(paddle.concat((fake_c_temp, fake_sk), axis=1))
return {"fake_sk": fake_sk, "fake_text": fake_text}
@@ -178,7 +183,8 @@ def __init__(self, config):
act_attr=None,
conv_block_dropout=conv_block_dropout,
conv_block_num=conv_block_num,
- conv_block_dilation=conv_block_dilation)
+ conv_block_dilation=conv_block_dilation,
+ )
self.decoder_bg = SingleDecoder(
name=name + "_decoder_bg",
@@ -192,7 +198,8 @@ def __init__(self, config):
conv_block_num=conv_block_num,
conv_block_dilation=conv_block_dilation,
out_conv_act="Tanh",
- out_conv_act_attr=None)
+ out_conv_act_attr=None,
+ )
self.decoder_mask = Decoder(
name=name + "_decoder_mask",
@@ -206,27 +213,30 @@ def __init__(self, config):
conv_block_num=conv_block_num,
conv_block_dilation=conv_block_dilation,
out_conv_act="Sigmoid",
- out_conv_act_attr=None)
+ out_conv_act_attr=None,
+ )
self.middle = MiddleNet(
name=name + "_middle_net",
in_channels=3 + 1,
mid_channels=encode_dim,
out_channels=3,
- use_bias=use_bias)
+ use_bias=use_bias,
+ )
def forward(self, style_input):
encode_bg_output = self.encoder_bg(style_input)
- decode_bg_output = self.decoder_bg(encode_bg_output["res_blocks"],
- encode_bg_output["down2"],
- encode_bg_output["down1"])
+ decode_bg_output = self.decoder_bg(
+ encode_bg_output["res_blocks"],
+ encode_bg_output["down2"],
+ encode_bg_output["down1"],
+ )
fake_c_temp = decode_bg_output["out_conv"]
- fake_bg_mask = self.decoder_mask.forward(encode_bg_output[
- "res_blocks"])["out_conv"]
- fake_bg = self.middle(
- paddle.concat(
- (fake_c_temp, fake_bg_mask), axis=1))
+ fake_bg_mask = self.decoder_mask.forward(encode_bg_output["res_blocks"])[
+ "out_conv"
+ ]
+ fake_bg = self.middle(paddle.concat((fake_c_temp, fake_bg_mask), axis=1))
return {
"bg_encode_feature": encode_bg_output["res_blocks"],
"bg_decode_feature1": decode_bg_output["up1"],
@@ -257,7 +267,8 @@ def __init__(self, config):
padding=1,
groups=1,
weight_attr=paddle.ParamAttr(name=name + "_conv_weights"),
- bias_attr=False)
+ bias_attr=False,
+ )
self._res_block = ResBlock(
name="{}_conv_block".format(name),
@@ -265,7 +276,8 @@ def __init__(self, config):
norm_layer=norm_layer,
use_dropout=conv_block_dropout,
use_dilation=conv_block_dilation,
- use_bias=use_bias)
+ use_bias=use_bias,
+ )
self._reduce_conv = nn.Conv2D(
in_channels=encode_dim,
@@ -275,7 +287,8 @@ def __init__(self, config):
padding=1,
groups=1,
weight_attr=paddle.ParamAttr(name=name + "_reduce_conv_weights"),
- bias_attr=False)
+ bias_attr=False,
+ )
def forward(self, fake_text, fake_bg):
fake_concat = paddle.concat((fake_text, fake_bg), axis=1)
diff --git a/StyleText/engine/corpus_generators.py b/StyleText/engine/corpus_generators.py
index 186d15f36d..65c917c2a6 100644
--- a/StyleText/engine/corpus_generators.py
+++ b/StyleText/engine/corpus_generators.py
@@ -21,11 +21,13 @@ def __init__(self, config):
self.logger = get_logger()
self.logger.info("using FileCorpus")
- self.char_list = " 0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz"
+ self.char_list = (
+ " 0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz"
+ )
corpus_file = config["CorpusGenerator"]["corpus_file"]
self.language = config["CorpusGenerator"]["language"]
- with open(corpus_file, 'r') as f:
+ with open(corpus_file, "r") as f:
corpus_raw = f.read()
self.corpus_list = corpus_raw.split("\n")[:-1]
assert len(self.corpus_list) > 0
diff --git a/StyleText/engine/predictors.py b/StyleText/engine/predictors.py
index ca9ab9ce6f..43b79d39d6 100644
--- a/StyleText/engine/predictors.py
+++ b/StyleText/engine/predictors.py
@@ -23,12 +23,13 @@
class StyleTextRecPredictor(object):
def __init__(self, config):
- algorithm = config['Predictor']['algorithm']
- assert algorithm in ["StyleTextRec"
- ], "Generator {} not supported.".format(algorithm)
- use_gpu = config["Global"]['use_gpu']
+ algorithm = config["Predictor"]["algorithm"]
+ assert algorithm in ["StyleTextRec"], "Generator {} not supported.".format(
+ algorithm
+ )
+ use_gpu = config["Global"]["use_gpu"]
check_gpu(use_gpu)
- paddle.set_device('gpu' if use_gpu else 'cpu')
+ paddle.set_device("gpu" if use_gpu else "cpu")
self.logger = get_logger()
self.generator = getattr(style_text_rec, algorithm)(config)
self.height = config["Global"]["image_height"]
@@ -41,8 +42,7 @@ def __init__(self, config):
def reshape_to_same_height(self, img_list):
h = img_list[0].shape[0]
for idx in range(1, len(img_list)):
- new_w = round(1.0 * img_list[idx].shape[1] /
- img_list[idx].shape[0] * h)
+ new_w = round(1.0 * img_list[idx].shape[1] / img_list[idx].shape[0] * h)
img_list[idx] = cv2.resize(img_list[idx], (new_w, h))
return img_list
@@ -50,8 +50,9 @@ def predict_single_image(self, style_input, text_input):
style_input = self.rep_style_input(style_input, text_input)
tensor_style_input = self.preprocess(style_input)
tensor_text_input = self.preprocess(text_input)
- style_text_result = self.generator.forward(tensor_style_input,
- tensor_text_input)
+ style_text_result = self.generator.forward(
+ tensor_style_input, tensor_text_input
+ )
fake_fusion = self.postprocess(style_text_result["fake_fusion"])
fake_text = self.postprocess(style_text_result["fake_text"])
fake_sk = self.postprocess(style_text_result["fake_sk"])
@@ -88,7 +89,7 @@ def predict(self, style_input, text_input_list):
return synth_result
def preprocess(self, img):
- img = (img.astype('float32') * self.scale - self.mean) / self.std
+ img = (img.astype("float32") * self.scale - self.mean) / self.std
img_height, img_width, channel = img.shape
assert channel == 3, "Please use an rgb image."
ratio = img_width / float(img_height)
@@ -98,7 +99,7 @@ def preprocess(self, img):
resized_w = int(math.ceil(self.height * ratio))
img = cv2.resize(img, (resized_w, self.height))
- new_img = np.zeros([self.height, self.width, 3]).astype('float32')
+ new_img = np.zeros([self.height, self.width, 3]).astype("float32")
new_img[:, 0:resized_w, :] = img
img = new_img.transpose((2, 0, 1))
img = img[np.newaxis, :, :, :]
@@ -110,12 +111,18 @@ def postprocess(self, tensor):
img = (img * self.std + self.mean) / self.scale
img = np.maximum(img, 0.0)
img = np.minimum(img, 255.0)
- img = img.astype('uint8')
+ img = img.astype("uint8")
return img
def rep_style_input(self, style_input, text_input):
- rep_num = int(1.2 * (text_input.shape[1] / text_input.shape[0]) /
- (style_input.shape[1] / style_input.shape[0])) + 1
+ rep_num = (
+ int(
+ 1.2
+ * (text_input.shape[1] / text_input.shape[0])
+ / (style_input.shape[1] / style_input.shape[0])
+ )
+ + 1
+ )
style_input = np.tile(style_input, reps=[1, rep_num, 1])
max_width = int(self.width / self.height * style_input.shape[0])
style_input = style_input[:, :max_width, :]
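
Note on rep_style_input, whose arithmetic is re-wrapped above: it tiles the style crop horizontally until its aspect ratio covers the rendered text strip, with a 20% safety margin, then clips to max_width. A small sketch of that count with hypothetical shapes:

    import numpy as np

    def rep_count(style_hw, text_hw, margin=1.2):
        # copies needed so the tiled style is at least as wide,
        # relative to its height, as the text input
        style_ratio = style_hw[1] / style_hw[0]  # width / height
        text_ratio = text_hw[1] / text_hw[0]
        return int(margin * text_ratio / style_ratio) + 1

    style = np.zeros((32, 64, 3))                # (h, w, c) style crop
    n = rep_count(style.shape, (32, 300))        # text strip is 32 x 300
    print(n)                                     # 6 copies
    print(np.tile(style, reps=[1, n, 1]).shape)  # (32, 384, 3)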
diff --git a/StyleText/engine/style_samplers.py b/StyleText/engine/style_samplers.py
index e171d58db7..5821de06ef 100644
--- a/StyleText/engine/style_samplers.py
+++ b/StyleText/engine/style_samplers.py
@@ -35,7 +35,7 @@ def sample(self):
self.index = 0
if self.dataset_with_label:
path_label = self.path_label_list[self.index]
- rel_image_path, label = path_label.split('\t')
+ rel_image_path, label = path_label.split("\t")
else:
rel_image_path = self.path_label_list[self.index]
label = None
diff --git a/StyleText/engine/synthesisers.py b/StyleText/engine/synthesisers.py
index 6461d9e363..1bb4c933c5 100644
--- a/StyleText/engine/synthesisers.py
+++ b/StyleText/engine/synthesisers.py
@@ -28,8 +28,7 @@ def __init__(self):
self.output_dir = self.config["Global"]["output_dir"]
if not os.path.exists(self.output_dir):
os.mkdir(self.output_dir)
- self.logger = get_logger(
- log_file='{}/predict.log'.format(self.output_dir))
+ self.logger = get_logger(log_file="{}/predict.log".format(self.output_dir))
self.text_drawer = text_drawers.StdTextDrawer(self.config)
@@ -39,7 +38,8 @@ def __init__(self):
def synth_image(self, corpus, style_input, language="en"):
corpus_list, text_input_list = self.text_drawer.draw_text(
- corpus, language, style_input_width=style_input.shape[1])
+ corpus, language, style_input_width=style_input.shape[1]
+ )
synth_result = self.predictor.predict(style_input, text_input_list)
return synth_result
@@ -50,8 +50,9 @@ def __init__(self):
self.tag = self.FLAGS.tag
self.output_num = self.config["Global"]["output_num"]
corpus_generator_method = self.config["CorpusGenerator"]["method"]
- self.corpus_generator = getattr(corpus_generators,
- corpus_generator_method)(self.config)
+ self.corpus_generator = getattr(corpus_generators, corpus_generator_method)(
+ self.config
+ )
style_sampler_method = self.config["StyleSampler"]["method"]
assert style_sampler_method is not None
@@ -66,7 +67,8 @@ def synth_dataset(self):
text_input_label_list, text_input_list = self.text_drawer.draw_text(
text_input_label,
corpus_language,
- style_input_width=style_input.shape[1])
+ style_input_width=style_input.shape[1],
+ )
text_input_label = "".join(text_input_label_list)
diff --git a/StyleText/engine/text_drawers.py b/StyleText/engine/text_drawers.py
index 2eb73b38ac..a5128a03b6 100644
--- a/StyleText/engine/text_drawers.py
+++ b/StyleText/engine/text_drawers.py
@@ -8,7 +8,9 @@ class StdTextDrawer(object):
def __init__(self, config):
self.logger = get_logger()
self.max_width = config["Global"]["image_width"]
- self.char_list = " 0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz"
+ self.char_list = (
+ " 0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz"
+ )
self.height = config["Global"]["image_height"]
self.font_dict = {}
self.load_fonts(config["TextDrawer"]["fonts"])
@@ -28,16 +30,13 @@ def get_valid_height(self, font_path):
if font_height <= self.height - 4:
return self.height - 4
else:
- return int((self.height - 4)**2 / font_height)
+ return int((self.height - 4) ** 2 / font_height)
- def draw_text(self,
- corpus,
- language="en",
- crop=True,
- style_input_width=None):
+ def draw_text(self, corpus, language="en", crop=True, style_input_width=None):
if language not in self.support_languages:
self.logger.warning(
- "language {} not supported, use en instead.".format(language))
+                "language {} not supported, using en instead.".format(language)
+ )
language = "en"
if crop:
width = min(self.max_width, len(corpus) * self.height) + 4
diff --git a/StyleText/engine/writers.py b/StyleText/engine/writers.py
index 0df75e7234..d692c28cf3 100644
--- a/StyleText/engine/writers.py
+++ b/StyleText/engine/writers.py
@@ -52,16 +52,14 @@ def save_label(self):
for image_path in self.label_dict:
label = self.label_dict[image_path]
label_raw += "{}\t{}\n".format(image_path, label)
- label_file_path = os.path.join(label_home,
- "{}_label.txt".format(self.tag))
+ label_file_path = os.path.join(label_home, "{}_label.txt".format(self.tag))
with open(label_file_path, "w") as f:
f.write(label_raw)
self.label_file_index += 1
def merge_label(self):
label_raw = ""
- label_file_regex = os.path.join(self.output_dir, "label",
- "*_label.txt")
+ label_file_regex = os.path.join(self.output_dir, "label", "*_label.txt")
label_file_list = glob.glob(label_file_regex)
for label_file_i in label_file_list:
with open(label_file_i, "r") as f:
diff --git a/StyleText/tools/synth_dataset.py b/StyleText/tools/synth_dataset.py
index a75f7f393b..d225710e42 100644
--- a/StyleText/tools/synth_dataset.py
+++ b/StyleText/tools/synth_dataset.py
@@ -17,7 +17,7 @@
__dir__ = os.path.dirname(os.path.abspath(__file__))
sys.path.append(__dir__)
-sys.path.append(os.path.abspath(os.path.join(__dir__, '..')))
+sys.path.append(os.path.abspath(os.path.join(__dir__, "..")))
from engine.synthesisers import DatasetSynthesiser
@@ -27,5 +27,5 @@ def synth_dataset():
dataset_synthesiser.synth_dataset()
-if __name__ == '__main__':
+if __name__ == "__main__":
synth_dataset()
diff --git a/StyleText/tools/synth_image.py b/StyleText/tools/synth_image.py
index cbc3118675..fde9f334a7 100644
--- a/StyleText/tools/synth_image.py
+++ b/StyleText/tools/synth_image.py
@@ -18,7 +18,7 @@
__dir__ = os.path.dirname(os.path.abspath(__file__))
sys.path.append(__dir__)
-sys.path.append(os.path.abspath(os.path.join(__dir__, '..')))
+sys.path.append(os.path.abspath(os.path.join(__dir__, "..")))
from utils.config import ArgsParser
from engine.synthesisers import ImageSynthesiser
@@ -77,6 +77,6 @@ def batch_synth_images():
print(cno, corpus_num, sno, style_img_num)
-if __name__ == '__main__':
+if __name__ == "__main__":
# batch_synth_images()
synth_image()
diff --git a/StyleText/utils/config.py b/StyleText/utils/config.py
index b2f8a618a8..926a7cf589 100644
--- a/StyleText/utils/config.py
+++ b/StyleText/utils/config.py
@@ -32,25 +32,25 @@ def str2num(v):
except Exception:
return v
- assert isinstance(dl, (list, dict)), ("{} should be a list or a dict")
- assert len(ks) > 0, ('lenght of keys should larger than 0')
+ assert isinstance(dl, (list, dict)), "{} should be a list or a dict"
+    assert len(ks) > 0, "length of keys should be larger than 0"
if isinstance(dl, list):
k = str2num(ks[0])
if len(ks) == 1:
- assert k < len(dl), ('index({}) out of range({})'.format(k, dl))
+ assert k < len(dl), "index({}) out of range({})".format(k, dl)
dl[k] = str2num(v)
else:
override(dl[k], ks[1:], v)
else:
if len(ks) == 1:
- #assert ks[0] in dl, ('{} is not exist in {}'.format(ks[0], dl))
+ # assert ks[0] in dl, ('{} is not exist in {}'.format(ks[0], dl))
if not ks[0] in dl:
- logger.warning('A new filed ({}) detected!'.format(ks[0], dl))
+                logger.warning("A new field ({}) detected!".format(ks[0]))
dl[ks[0]] = str2num(v)
else:
- assert ks[0] in dl, (
- '({}) doesn\'t exist in {}, a new dict field is invalid'.
- format(ks[0], dl))
+ assert (
+ ks[0] in dl
+ ), "({}) doesn't exist in {}, a new dict field is invalid".format(ks[0], dl)
override(dl[ks[0]], ks[1:], v)
@@ -71,15 +71,15 @@ def override_config(config, options=None):
"""
if options is not None:
for opt in options:
- assert isinstance(opt, str), (
- "option({}) should be a str".format(opt))
+ assert isinstance(opt, str), "option({}) should be a str".format(opt)
assert "=" in opt, (
"option({}) should contain a ="
- "to distinguish between key and value".format(opt))
- pair = opt.split('=')
- assert len(pair) == 2, ("there can be only a = in the option")
+ "to distinguish between key and value".format(opt)
+ )
+ pair = opt.split("=")
+        assert len(pair) == 2, "there can only be one = in the option"
key, value = pair
- keys = key.split('.')
+ keys = key.split(".")
override(config, keys, value)
return config
@@ -87,28 +87,29 @@ def override_config(config, options=None):
class ArgsParser(ArgumentParser):
def __init__(self):
- super(ArgsParser, self).__init__(
- formatter_class=RawDescriptionHelpFormatter)
+ super(ArgsParser, self).__init__(formatter_class=RawDescriptionHelpFormatter)
self.add_argument("-c", "--config", help="configuration file to use")
+ self.add_argument("-t", "--tag", default="0", help="tag for marking worker")
self.add_argument(
- "-t", "--tag", default="0", help="tag for marking worker")
- self.add_argument(
- '-o',
- '--override',
- action='append',
+ "-o",
+ "--override",
+ action="append",
default=[],
- help='config options to be overridden')
- self.add_argument(
- "--style_image", default="examples/style_images/1.jpg", help="tag for marking worker")
+ help="config options to be overridden",
+ )
self.add_argument(
- "--text_corpus", default="PaddleOCR", help="tag for marking worker")
+ "--style_image",
+ default="examples/style_images/1.jpg",
+            help="style image used for synthesis",
+ )
self.add_argument(
- "--language", default="en", help="tag for marking worker")
+            "--text_corpus", default="PaddleOCR", help="text corpus used for synthesis"
+ )
+        self.add_argument("--language", default="en", help="language of the corpus")
def parse_args(self, argv=None):
args = super(ArgsParser, self).parse_args(argv)
- assert args.config is not None, \
- "Please specify --config=configure_file_path."
+ assert args.config is not None, "Please specify --config=configure_file_path."
return args
@@ -120,8 +121,8 @@ def load_config(file_path):
Returns: config
"""
ext = os.path.splitext(file_path)[1]
- assert ext in ['.yml', '.yaml'], "only support yaml files for now"
- with open(file_path, 'rb') as f:
+ assert ext in [".yml", ".yaml"], "only support yaml files for now"
+ with open(file_path, "rb") as f:
config = yaml.load(f, Loader=yaml.Loader)
return config
@@ -141,7 +142,7 @@ def gen_config():
"use_visualdl": False,
"save_epoch_step": 10,
"vgg_pretrain": "./pretrained/VGG19_pretrained",
- "vgg_load_static_pretrain": True
+ "vgg_load_static_pretrain": True,
},
"Architecture": {
"model_type": "data_aug",
@@ -153,7 +154,7 @@ def gen_config():
"use_dropout": False,
"init_type": "xavier",
"init_gain": 0.02,
- "use_dilation": 1
+ "use_dilation": 1,
},
# input_nc, ndf, netD,
# n_layers_D=3, norm='instance', use_sigmoid=False, init_type='normal', init_gain=0.02, gpu_id='cuda:0'
@@ -172,21 +173,12 @@ def gen_config():
"netD": "basic",
"norm": "none",
"init_type": "xavier",
- }
- },
- "Loss": {
- "lamb": 10,
- "perceptual_lamb": 1,
- "muvar_lamb": 50,
- "style_lamb": 500
+ },
},
+ "Loss": {"lamb": 10, "perceptual_lamb": 1, "muvar_lamb": 50, "style_lamb": 500},
"Optimizer": {
"name": "Adam",
- "learning_rate": {
- "name": "lambda",
- "lr": 0.0002,
- "lr_decay_iters": 50
- },
+ "learning_rate": {"name": "lambda", "lr": 0.0002, "lr_decay_iters": 50},
"beta1": 0.5,
"beta2": 0.999,
},
@@ -197,28 +189,30 @@ def gen_config():
"delimiter": "\t",
"data_dir": "/",
"label_file": "tmp/label.txt",
- "transforms": [{
- "DecodeImage": {
- "to_rgb": True,
- "to_np": False,
- "channel_first": False
- }
- }, {
- "NormalizeImage": {
- "scale": 1. / 255.,
- "mean": [0.485, 0.456, 0.406],
- "std": [0.229, 0.224, 0.225],
- "order": None
- }
- }, {
- "ToCHWImage": None
- }]
- }
- }
+ "transforms": [
+ {
+ "DecodeImage": {
+ "to_rgb": True,
+ "to_np": False,
+ "channel_first": False,
+ }
+ },
+ {
+ "NormalizeImage": {
+ "scale": 1.0 / 255.0,
+ "mean": [0.485, 0.456, 0.406],
+ "std": [0.229, 0.224, 0.225],
+ "order": None,
+ }
+ },
+ {"ToCHWImage": None},
+ ],
+ },
+ },
}
with open("config.yml", "w") as f:
yaml.dump(base_config, f)
-if __name__ == '__main__':
+if __name__ == "__main__":
gen_config()
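
Note on utils/config.py: override/override_config implement the -o Key.Sub=value command-line overrides; each option is split on "=", the key is split on ".", and the nested config dict is walked down to the leaf. A stripped-down sketch of that walk (no list indexing and no str2num coercion, so values stay strings here):

    def override(cfg, keys, value):
        # recurse along dot-separated keys and assign at the leaf
        if len(keys) == 1:
            cfg[keys[0]] = value
        else:
            override(cfg[keys[0]], keys[1:], value)

    config = {"Global": {"use_gpu": True, "image_height": 32}}
    for opt in ["Global.use_gpu=False", "Global.image_height=64"]:
        key, value = opt.split("=")
        override(config, key.split("."), value)
    print(config)  # {'Global': {'use_gpu': 'False', 'image_height': '64'}}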
diff --git a/StyleText/utils/load_params.py b/StyleText/utils/load_params.py
index be0561363e..4606f73468 100644
--- a/StyleText/utils/load_params.py
+++ b/StyleText/utils/load_params.py
@@ -14,14 +14,13 @@
import os
import paddle
-__all__ = ['load_dygraph_pretrain']
+__all__ = ["load_dygraph_pretrain"]
def load_dygraph_pretrain(model, logger, path=None, load_static_weights=False):
- if not os.path.exists(path + '.pdparams'):
- raise ValueError("Model pretrain path {} does not "
- "exists.".format(path))
- param_state_dict = paddle.load(path + '.pdparams')
+ if not os.path.exists(path + ".pdparams"):
+        raise ValueError("Model pretrain path {} does not exist.".format(path))
+ param_state_dict = paddle.load(path + ".pdparams")
model.set_state_dict(param_state_dict)
logger.info("load pretrained model from {}".format(path))
return
diff --git a/StyleText/utils/logging.py b/StyleText/utils/logging.py
index f700fe26bc..e3de9c7e51 100644
--- a/StyleText/utils/logging.py
+++ b/StyleText/utils/logging.py
@@ -21,7 +21,7 @@
@functools.lru_cache()
-def get_logger(name='srnet', log_file=None, log_level=logging.INFO):
+def get_logger(name="srnet", log_file=None, log_level=logging.INFO):
"""Initialize and get a logger by name.
If the logger has not been initialized, this method will initialize the
logger by adding one or two handlers, otherwise the initialized logger will
@@ -45,8 +45,8 @@ def get_logger(name='srnet', log_file=None, log_level=logging.INFO):
return logger
formatter = logging.Formatter(
- '[%(asctime)s] %(name)s %(levelname)s: %(message)s',
- datefmt="%Y/%m/%d %H:%M:%S")
+ "[%(asctime)s] %(name)s %(levelname)s: %(message)s", datefmt="%Y/%m/%d %H:%M:%S"
+ )
stream_handler = logging.StreamHandler(stream=sys.stdout)
stream_handler.setFormatter(formatter)
@@ -54,7 +54,7 @@ def get_logger(name='srnet', log_file=None, log_level=logging.INFO):
if log_file is not None and dist.get_rank() == 0:
log_file_folder = os.path.split(log_file)[0]
os.makedirs(log_file_folder, exist_ok=True)
- file_handler = logging.FileHandler(log_file, 'a')
+ file_handler = logging.FileHandler(log_file, "a")
file_handler.setFormatter(formatter)
logger.addHandler(file_handler)
if dist.get_rank() == 0:
diff --git a/StyleText/utils/math_functions.py b/StyleText/utils/math_functions.py
index 3dc8d9160f..440ae784bc 100644
--- a/StyleText/utils/math_functions.py
+++ b/StyleText/utils/math_functions.py
@@ -39,7 +39,10 @@ def compute_mean_covariance(img):
def dice_coefficient(y_true_cls, y_pred_cls, training_mask):
eps = 1e-5
intersection = paddle.sum(y_true_cls * y_pred_cls * training_mask)
- union = paddle.sum(y_true_cls * training_mask) + paddle.sum(
- y_pred_cls * training_mask) + eps
- loss = 1. - (2 * intersection / union)
+ union = (
+ paddle.sum(y_true_cls * training_mask)
+ + paddle.sum(y_pred_cls * training_mask)
+ + eps
+ )
+ loss = 1.0 - (2 * intersection / union)
return loss
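
Note on math_functions.py: dice_coefficient computes the standard Dice loss, 1 - 2|X ∩ Y| / (|X| + |Y| + eps), restricted to the training mask. A NumPy sanity check of the two extremes, with hypothetical inputs:

    import numpy as np

    def dice_loss(y_true, y_pred, mask, eps=1e-5):
        inter = np.sum(y_true * y_pred * mask)
        union = np.sum(y_true * mask) + np.sum(y_pred * mask) + eps
        return 1.0 - 2.0 * inter / union

    y = np.array([1.0, 1.0, 0.0, 0.0])
    ones = np.ones_like(y)
    print(dice_loss(y, y, ones))        # ~0.0: perfect overlap
    print(dice_loss(y, 1.0 - y, ones))  # 1.0: disjoint prediction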
diff --git a/StyleText/utils/sys_funcs.py b/StyleText/utils/sys_funcs.py
index 203d91d836..ea395ef670 100644
--- a/StyleText/utils/sys_funcs.py
+++ b/StyleText/utils/sys_funcs.py
@@ -19,15 +19,20 @@
def get_check_global_params(mode):
check_params = [
- 'use_gpu', 'max_text_length', 'image_shape', 'image_shape',
- 'character_type', 'loss_type'
+ "use_gpu",
+ "max_text_length",
+ "image_shape",
+ "character_type",
+ "loss_type",
]
if mode == "train_eval":
check_params = check_params + [
- 'train_batch_size_per_card', 'test_batch_size_per_card'
+ "train_batch_size_per_card",
+ "test_batch_size_per_card",
]
elif mode == "test":
- check_params = check_params + ['test_batch_size_per_card']
+ check_params = check_params + ["test_batch_size_per_card"]
return check_params
@@ -36,11 +41,13 @@ def check_gpu(use_gpu):
    Log an error and exit when use_gpu is set to true in the
    paddlepaddle cpu version.
"""
- err = "Config use_gpu cannot be set as true while you are " \
- "using paddlepaddle cpu version ! \nPlease try: \n" \
- "\t1. Install paddlepaddle-gpu to run model on GPU \n" \
- "\t2. Set use_gpu as false in config file to run " \
- "model on CPU"
+ err = (
+ "Config use_gpu cannot be set as true while you are "
+        "using the paddlepaddle cpu version! \nPlease try: \n"
+ "\t1. Install paddlepaddle-gpu to run model on GPU \n"
+ "\t2. Set use_gpu as false in config file to run "
+ "model on CPU"
+ )
if use_gpu:
try:
if not paddle.is_compiled_with_cuda():
@@ -61,7 +68,7 @@ def _mkdir_if_not_exist(path, logger):
except OSError as e:
if e.errno == errno.EEXIST and os.path.isdir(path):
logger.warning(
- 'be happy if some process has already created {}'.format(
- path))
+            "{} already exists; another process may have created it".format(path)
+ )
else:
- raise OSError('Failed to mkdir {}'.format(path))
+ raise OSError("Failed to mkdir {}".format(path))
diff --git a/__init__.py b/__init__.py
index a7c32e9629..b085589dea 100644
--- a/__init__.py
+++ b/__init__.py
@@ -15,7 +15,13 @@
__version__ = paddleocr.VERSION
__all__ = [
- 'PaddleOCR', 'PPStructure', 'draw_ocr', 'draw_structure_result',
- 'save_structure_res', 'download_with_progressbar', 'sorted_layout_boxes',
- 'convert_info_docx', 'to_excel'
+ "PaddleOCR",
+ "PPStructure",
+ "draw_ocr",
+ "draw_structure_result",
+ "save_structure_res",
+ "download_with_progressbar",
+ "sorted_layout_boxes",
+ "convert_info_docx",
+ "to_excel",
]
diff --git "a/applications/PCB\345\255\227\347\254\246\350\257\206\345\210\253/gen_data/gen.py" "b/applications/PCB\345\255\227\347\254\246\350\257\206\345\210\253/gen_data/gen.py"
index 0eb00cd1ef..74da8f50d9 100644
--- "a/applications/PCB\345\255\227\347\254\246\350\257\206\345\210\253/gen_data/gen.py"
+++ "b/applications/PCB\345\255\227\347\254\246\350\257\206\345\210\253/gen_data/gen.py"
@@ -30,7 +30,7 @@ def get_char_lines(txt_root_path):
txt_files = os.listdir(txt_root_path)
char_lines = []
for txt in txt_files:
- f = open(os.path.join(txt_root_path, txt), mode='r', encoding='utf-8')
+ f = open(os.path.join(txt_root_path, txt), mode="r", encoding="utf-8")
lines = f.readlines()
f.close()
for line in lines:
@@ -43,8 +43,8 @@ def get_horizontal_text_picture(image_file, chars, fonts_list, cf):
desc:gen horizontal text picture
"""
img = Image.open(image_file)
- if img.mode != 'RGB':
- img = img.convert('RGB')
+ if img.mode != "RGB":
+ img = img.convert("RGB")
img_w, img_h = img.size
# random choice font
@@ -56,7 +56,7 @@ def get_horizontal_text_picture(image_file, chars, fonts_list, cf):
ch_w = []
ch_h = []
for ch in chars:
- if int(PIL.__version__.split('.')[0]) < 10:
+ if int(PIL.__version__.split(".")[0]) < 10:
wt, ht = font.getsize(ch)
else:
left, top, right, bottom = font.getbbox(ch)
@@ -68,7 +68,7 @@ def get_horizontal_text_picture(image_file, chars, fonts_list, cf):
# add space
char_space_width = max(ch_w)
- f_w += (char_space_width * (len(chars) - 1))
+ f_w += char_space_width * (len(chars) - 1)
x1 = random.randint(0, img_w - f_w)
y1 = random.randint(0, img_h - f_h)
@@ -84,7 +84,7 @@ def get_horizontal_text_picture(image_file, chars, fonts_list, cf):
draw = ImageDraw.Draw(img)
for i, ch in enumerate(chars):
draw.text((x1, y1), ch, best_color, font=font)
- x1 += (ch_w[i] + char_space_width)
+ x1 += ch_w[i] + char_space_width
crop_img = img.crop((crop_x1, crop_y1, crop_x2, crop_y2))
return crop_img, chars
@@ -94,8 +94,8 @@ def get_vertical_text_picture(image_file, chars, fonts_list, cf):
desc:gen vertical text picture
"""
img = Image.open(image_file)
- if img.mode != 'RGB':
- img = img.convert('RGB')
+ if img.mode != "RGB":
+ img = img.convert("RGB")
img_w, img_h = img.size
# random choice font
font_path = random.choice(fonts_list)
@@ -106,7 +106,7 @@ def get_vertical_text_picture(image_file, chars, fonts_list, cf):
ch_w = []
ch_h = []
for ch in chars:
- if int(PIL.__version__.split('.')[0]) < 10:
+ if int(PIL.__version__.split(".")[0]) < 10:
wt, ht = font.getsize(ch)
else:
left, top, right, bottom = font.getbbox(ch)
@@ -143,28 +143,52 @@ def get_fonts(fonts_path):
desc: get all fonts
"""
font_files = os.listdir(fonts_path)
- fonts_list=[]
+ fonts_list = []
for font_file in font_files:
- font_path=os.path.join(fonts_path, font_file)
+ font_path = os.path.join(fonts_path, font_file)
fonts_list.append(font_path)
return fonts_list
-if __name__ == '__main__':
- parser = argparse.ArgumentParser()
- parser.add_argument('--num_img', type=int, default=30, help="Number of images to generate")
- parser.add_argument('--font_min_size', type=int, default=11)
- parser.add_argument('--font_max_size', type=int, default=12,
- help="Help adjust the size of the generated text and the size of the picture")
- parser.add_argument('--bg_path', type=str, default='./background',
- help='The generated text pictures will be pasted onto the pictures of this folder')
- parser.add_argument('--det_bg_path', type=str, default='./det_background',
- help='The generated text pictures will use the pictures of this folder as the background')
- parser.add_argument('--fonts_path', type=str, default='../../StyleText/fonts',
- help='The font used to generate the picture')
- parser.add_argument('--corpus_path', type=str, default='./corpus',
- help='The corpus used to generate the text picture')
- parser.add_argument('--output_dir', type=str, default='./output/', help='Images save dir')
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser()
+ parser.add_argument(
+ "--num_img", type=int, default=30, help="Number of images to generate"
+ )
+ parser.add_argument("--font_min_size", type=int, default=11)
+ parser.add_argument(
+ "--font_max_size",
+ type=int,
+ default=12,
+        help="Maximum font size; adjusts the size of the generated text and the picture",
+ )
+ parser.add_argument(
+ "--bg_path",
+ type=str,
+ default="./background",
+ help="The generated text pictures will be pasted onto the pictures of this folder",
+ )
+ parser.add_argument(
+ "--det_bg_path",
+ type=str,
+ default="./det_background",
+ help="The generated text pictures will use the pictures of this folder as the background",
+ )
+ parser.add_argument(
+ "--fonts_path",
+ type=str,
+ default="../../StyleText/fonts",
+ help="The font used to generate the picture",
+ )
+ parser.add_argument(
+ "--corpus_path",
+ type=str,
+ default="./corpus",
+ help="The corpus used to generate the text picture",
+ )
+ parser.add_argument(
+ "--output_dir", type=str, default="./output/", help="Images save dir"
+ )
cf = parser.parse_args()
# save path
@@ -181,38 +205,37 @@ def get_fonts(fonts_path):
# rec bg
img_root_path = cf.bg_path
- imnames=os.listdir(img_root_path)
-
+ imnames = os.listdir(img_root_path)
+
# det bg
det_bg_path = cf.det_bg_path
bg_pics = os.listdir(det_bg_path)
# OCR det files
- det_val_file = open(cf.output_dir + 'det_gt_val.txt', 'w', encoding='utf-8')
- det_train_file = open(cf.output_dir + 'det_gt_train.txt', 'w', encoding='utf-8')
+ det_val_file = open(cf.output_dir + "det_gt_val.txt", "w", encoding="utf-8")
+ det_train_file = open(cf.output_dir + "det_gt_train.txt", "w", encoding="utf-8")
# det imgs
- det_save_dir = 'imgs/'
+ det_save_dir = "imgs/"
if not os.path.exists(cf.output_dir + det_save_dir):
os.mkdir(cf.output_dir + det_save_dir)
- det_val_save_dir = 'imgs_val/'
+ det_val_save_dir = "imgs_val/"
if not os.path.exists(cf.output_dir + det_val_save_dir):
os.mkdir(cf.output_dir + det_val_save_dir)
# OCR rec files
- rec_val_file = open(cf.output_dir + 'rec_gt_val.txt', 'w', encoding='utf-8')
- rec_train_file = open(cf.output_dir + 'rec_gt_train.txt', 'w', encoding='utf-8')
+ rec_val_file = open(cf.output_dir + "rec_gt_val.txt", "w", encoding="utf-8")
+ rec_train_file = open(cf.output_dir + "rec_gt_train.txt", "w", encoding="utf-8")
# rec imgs
- rec_save_dir = 'rec_imgs/'
+ rec_save_dir = "rec_imgs/"
if not os.path.exists(cf.output_dir + rec_save_dir):
os.mkdir(cf.output_dir + rec_save_dir)
- rec_val_save_dir = 'rec_imgs_val/'
+ rec_val_save_dir = "rec_imgs_val/"
if not os.path.exists(cf.output_dir + rec_val_save_dir):
os.mkdir(cf.output_dir + rec_val_save_dir)
-
val_ratio = cf.num_img * 0.2 # val dataset ratio
- print('start generating...')
+ print("start generating...")
for i in range(0, cf.num_img):
imname = random.choice(imnames)
img_path = os.path.join(img_root_path, imname)
@@ -220,35 +243,39 @@ def get_fonts(fonts_path):
rnd = random.random()
# gen horizontal text picture
if rnd < 0.5:
- gen_img, chars = get_horizontal_text_picture(img_path, char_lines[i], fonts_list, cf)
+ gen_img, chars = get_horizontal_text_picture(
+ img_path, char_lines[i], fonts_list, cf
+ )
ori_w, ori_h = gen_img.size
gen_img = gen_img.crop((0, 3, ori_w, ori_h))
# gen vertical text picture
else:
- gen_img, chars = get_vertical_text_picture(img_path, char_lines[i], fonts_list, cf)
+ gen_img, chars = get_vertical_text_picture(
+ img_path, char_lines[i], fonts_list, cf
+ )
ori_w, ori_h = gen_img.size
gen_img = gen_img.crop((3, 0, ori_w, ori_h))
ori_w, ori_h = gen_img.size
# rec imgs
- save_img_name = str(i).zfill(4) + '.jpg'
+ save_img_name = str(i).zfill(4) + ".jpg"
if i < val_ratio:
save_dir = os.path.join(rec_val_save_dir, save_img_name)
- line = save_dir + '\t' + char_lines[i] + '\n'
+ line = save_dir + "\t" + char_lines[i] + "\n"
rec_val_file.write(line)
else:
save_dir = os.path.join(rec_save_dir, save_img_name)
- line = save_dir + '\t' + char_lines[i] + '\n'
+ line = save_dir + "\t" + char_lines[i] + "\n"
rec_train_file.write(line)
- gen_img.save(cf.output_dir + save_dir, quality = 95, subsampling=0)
+ gen_img.save(cf.output_dir + save_dir, quality=95, subsampling=0)
# det img
# random choice bg
bg_pic = random.sample(bg_pics, 1)[0]
det_img = Image.open(os.path.join(det_bg_path, bg_pic))
# the PCB position is fixed, modify it according to your own scenario
- if bg_pic == '1.png':
+ if bg_pic == "1.png":
x1 = 38
y1 = 3
else:
@@ -257,14 +284,21 @@ def get_fonts(fonts_path):
det_img.paste(gen_img, (x1, y1))
# text pos
- chars_pos = [[x1, y1], [x1 + ori_w, y1], [x1 + ori_w, y1 + ori_h], [x1, y1 + ori_h]]
- label = [{"transcription":char_lines[i], "points":chars_pos}]
+ chars_pos = [
+ [x1, y1],
+ [x1 + ori_w, y1],
+ [x1 + ori_w, y1 + ori_h],
+ [x1, y1 + ori_h],
+ ]
+ label = [{"transcription": char_lines[i], "points": chars_pos}]
if i < val_ratio:
save_dir = os.path.join(det_val_save_dir, save_img_name)
- det_val_file.write(save_dir + '\t' + json.dumps(
- label, ensure_ascii=False) + '\n')
+ det_val_file.write(
+ save_dir + "\t" + json.dumps(label, ensure_ascii=False) + "\n"
+ )
else:
save_dir = os.path.join(det_save_dir, save_img_name)
- det_train_file.write(save_dir + '\t' + json.dumps(
- label, ensure_ascii=False) + '\n')
- det_img.save(cf.output_dir + save_dir, quality = 95, subsampling=0)
+ det_train_file.write(
+ save_dir + "\t" + json.dumps(label, ensure_ascii=False) + "\n"
+ )
+ det_img.save(cf.output_dir + save_dir, quality=95, subsampling=0)
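
Note on the PIL.__version__ branches reformatted above: they exist because Pillow 10 removed ImageFont.getsize, and newer versions expose getbbox instead. A small compatibility sketch of the same check (char_size is a hypothetical helper name):

    import PIL
    from PIL import ImageFont

    def char_size(font, ch):
        # getsize() was removed in Pillow 10; derive w/h from getbbox() there
        if int(PIL.__version__.split(".")[0]) < 10:
            return font.getsize(ch)
        left, top, right, bottom = font.getbbox(ch)
        return right - left, bottom - top

    print(char_size(ImageFont.load_default(), "A"))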
diff --git a/applications/README_en.md b/applications/README_en.md
index df18465e55..0719a0933a 100644
--- a/applications/README_en.md
+++ b/applications/README_en.md
@@ -1,4 +1,4 @@
-English| [简体中文](README.md)
+English | [简体中文](README.md)
# Application
diff --git a/benchmark/PaddleOCR_DBNet/base/__init__.py b/benchmark/PaddleOCR_DBNet/base/__init__.py
index 223e9e02d7..5d7b417d88 100644
--- a/benchmark/PaddleOCR_DBNet/base/__init__.py
+++ b/benchmark/PaddleOCR_DBNet/base/__init__.py
@@ -1,2 +1,2 @@
from .base_trainer import BaseTrainer
-from .base_dataset import BaseDataSet
\ No newline at end of file
+from .base_dataset import BaseDataSet
diff --git a/benchmark/PaddleOCR_DBNet/base/base_dataset.py b/benchmark/PaddleOCR_DBNet/base/base_dataset.py
index 4a839a8ffb..e9f6c4de87 100644
--- a/benchmark/PaddleOCR_DBNet/base/base_dataset.py
+++ b/benchmark/PaddleOCR_DBNet/base/base_dataset.py
@@ -7,24 +7,24 @@
class BaseDataSet(Dataset):
- def __init__(self,
- data_path: str,
- img_mode,
- pre_processes,
- filter_keys,
- ignore_tags,
- transform=None,
- target_transform=None):
- assert img_mode in ['RGB', 'BRG', 'GRAY']
+ def __init__(
+ self,
+ data_path: str,
+ img_mode,
+ pre_processes,
+ filter_keys,
+ ignore_tags,
+ transform=None,
+ target_transform=None,
+ ):
+        assert img_mode in ["RGB", "BGR", "GRAY"]
self.ignore_tags = ignore_tags
self.data_list = self.load_data(data_path)
- item_keys = [
- 'img_path', 'img_name', 'text_polys', 'texts', 'ignore_tags'
- ]
+ item_keys = ["img_path", "img_name", "text_polys", "texts", "ignore_tags"]
for item in item_keys:
- assert item in self.data_list[
- 0], 'data_list from load_data must contains {}'.format(
- item_keys)
+ assert (
+ item in self.data_list[0]
+            ), "data_list from load_data must contain {}".format(item_keys)
self.img_mode = img_mode
self.filter_keys = filter_keys
self.transform = transform
@@ -35,14 +35,14 @@ def _init_pre_processes(self, pre_processes):
self.aug = []
if pre_processes is not None:
for aug in pre_processes:
- if 'args' not in aug:
+ if "args" not in aug:
args = {}
else:
- args = aug['args']
+ args = aug["args"]
if isinstance(args, dict):
- cls = eval(aug['type'])(**args)
+ cls = eval(aug["type"])(**args)
else:
- cls = eval(aug['type'])(args)
+ cls = eval(aug["type"])(args)
self.aug.append(cls)
def load_data(self, data_path: str) -> list:
@@ -61,17 +61,16 @@ def apply_pre_processes(self, data):
def __getitem__(self, index):
try:
data = copy.deepcopy(self.data_list[index])
- im = cv2.imread(data['img_path'], 1
- if self.img_mode != 'GRAY' else 0)
- if self.img_mode == 'RGB':
+ im = cv2.imread(data["img_path"], 1 if self.img_mode != "GRAY" else 0)
+ if self.img_mode == "RGB":
im = cv2.cvtColor(im, cv2.COLOR_BGR2RGB)
- data['img'] = im
- data['shape'] = [im.shape[0], im.shape[1]]
+ data["img"] = im
+ data["shape"] = [im.shape[0], im.shape[1]]
data = self.apply_pre_processes(data)
if self.transform:
- data['img'] = self.transform(data['img'])
- data['text_polys'] = data['text_polys'].tolist()
+ data["img"] = self.transform(data["img"])
+ data["text_polys"] = data["text_polys"].tolist()
if len(self.filter_keys):
data_dict = {}
for k, v in data.items():
diff --git a/benchmark/PaddleOCR_DBNet/base/base_trainer.py b/benchmark/PaddleOCR_DBNet/base/base_trainer.py
index 82c308d361..f0d7f74c80 100644
--- a/benchmark/PaddleOCR_DBNet/base/base_trainer.py
+++ b/benchmark/PaddleOCR_DBNet/base/base_trainer.py
@@ -18,20 +18,23 @@
class BaseTrainer:
- def __init__(self,
- config,
- model,
- criterion,
- train_loader,
- validate_loader,
- metric_cls,
- post_process=None):
- config['trainer']['output_dir'] = os.path.join(
+ def __init__(
+ self,
+ config,
+ model,
+ criterion,
+ train_loader,
+ validate_loader,
+ metric_cls,
+ post_process=None,
+ ):
+ config["trainer"]["output_dir"] = os.path.join(
str(pathlib.Path(os.path.abspath(__name__)).parent),
- config['trainer']['output_dir'])
- config['name'] = config['name'] + '_' + model.name
- self.save_dir = config['trainer']['output_dir']
- self.checkpoint_dir = os.path.join(self.save_dir, 'checkpoint')
+ config["trainer"]["output_dir"],
+ )
+ config["name"] = config["name"] + "_" + model.name
+ self.save_dir = config["trainer"]["output_dir"]
+ self.checkpoint_dir = os.path.join(self.save_dir, "checkpoint")
os.makedirs(self.checkpoint_dir, exist_ok=True)
@@ -40,33 +43,35 @@ def __init__(self,
self.config = config
self.criterion = criterion
# logger and tensorboard
- self.visualdl_enable = self.config['trainer'].get('visual_dl', False)
- self.epochs = self.config['trainer']['epochs']
- self.log_iter = self.config['trainer']['log_iter']
+ self.visualdl_enable = self.config["trainer"].get("visual_dl", False)
+ self.epochs = self.config["trainer"]["epochs"]
+ self.log_iter = self.config["trainer"]["log_iter"]
if paddle.distributed.get_rank() == 0:
- anyconfig.dump(config, os.path.join(self.save_dir, 'config.yaml'))
- self.logger = setup_logger(os.path.join(self.save_dir, 'train.log'))
+ anyconfig.dump(config, os.path.join(self.save_dir, "config.yaml"))
+ self.logger = setup_logger(os.path.join(self.save_dir, "train.log"))
self.logger_info(pformat(self.config))
self.model = self.apply_to_static(model)
# device
- if paddle.device.cuda.device_count(
- ) > 0 and paddle.device.is_compiled_with_cuda():
+ if (
+ paddle.device.cuda.device_count() > 0
+ and paddle.device.is_compiled_with_cuda()
+ ):
self.with_cuda = True
- random.seed(self.config['trainer']['seed'])
- np.random.seed(self.config['trainer']['seed'])
- paddle.seed(self.config['trainer']['seed'])
+ random.seed(self.config["trainer"]["seed"])
+ np.random.seed(self.config["trainer"]["seed"])
+ paddle.seed(self.config["trainer"]["seed"])
else:
self.with_cuda = False
- self.logger_info('train with and paddle {}'.format(paddle.__version__))
+            self.logger_info("train with paddle {}".format(paddle.__version__))
# metrics
self.metrics = {
- 'recall': 0,
- 'precision': 0,
- 'hmean': 0,
- 'train_loss': float('inf'),
- 'best_model_epoch': 0
+ "recall": 0,
+ "precision": 0,
+ "hmean": 0,
+ "train_loss": float("inf"),
+ "best_model_epoch": 0,
}
self.train_loader = train_loader
@@ -79,67 +84,75 @@ def __init__(self,
if self.validate_loader is not None:
self.logger_info(
- 'train dataset has {} samples,{} in dataloader, validate dataset has {} samples,{} in dataloader'.
- format(
- len(self.train_loader.dataset), self.train_loader_len,
+                "train dataset has {} samples, {} in dataloader; validate dataset has {} samples, {} in dataloader".format(
+ len(self.train_loader.dataset),
+ self.train_loader_len,
len(self.validate_loader.dataset),
- len(self.validate_loader)))
+ len(self.validate_loader),
+ )
+ )
else:
self.logger_info(
- 'train dataset has {} samples,{} in dataloader'.format(
- len(self.train_loader.dataset), self.train_loader_len))
+                "train dataset has {} samples, {} in dataloader".format(
+ len(self.train_loader.dataset), self.train_loader_len
+ )
+ )
self._initialize_scheduler()
self._initialize_optimizer()
# resume or finetune
- if self.config['trainer']['resume_checkpoint'] != '':
+ if self.config["trainer"]["resume_checkpoint"] != "":
self._load_checkpoint(
- self.config['trainer']['resume_checkpoint'], resume=True)
- elif self.config['trainer']['finetune_checkpoint'] != '':
+ self.config["trainer"]["resume_checkpoint"], resume=True
+ )
+ elif self.config["trainer"]["finetune_checkpoint"] != "":
self._load_checkpoint(
- self.config['trainer']['finetune_checkpoint'], resume=False)
+ self.config["trainer"]["finetune_checkpoint"], resume=False
+ )
if self.visualdl_enable and paddle.distributed.get_rank() == 0:
from visualdl import LogWriter
+
self.writer = LogWriter(self.save_dir)
        # mixed-precision (AMP) training
- self.amp = self.config.get('amp', None)
- if self.amp == 'None':
+ self.amp = self.config.get("amp", None)
+ if self.amp == "None":
self.amp = None
if self.amp:
- self.amp['scaler'] = paddle.amp.GradScaler(
+ self.amp["scaler"] = paddle.amp.GradScaler(
init_loss_scaling=self.amp.get("scale_loss", 1024),
- use_dynamic_loss_scaling=self.amp.get(
- 'use_dynamic_loss_scaling', True))
+ use_dynamic_loss_scaling=self.amp.get("use_dynamic_loss_scaling", True),
+ )
self.model, self.optimizer = paddle.amp.decorate(
models=self.model,
optimizers=self.optimizer,
- level=self.amp.get('amp_level', 'O2'))
+ level=self.amp.get("amp_level", "O2"),
+ )
        # distributed training
if paddle.device.cuda.device_count() > 1:
self.model = paddle.DataParallel(self.model)
# make inverse Normalize
self.UN_Normalize = False
- for t in self.config['dataset']['train']['dataset']['args'][
- 'transforms']:
- if t['type'] == 'Normalize':
- self.normalize_mean = t['args']['mean']
- self.normalize_std = t['args']['std']
+ for t in self.config["dataset"]["train"]["dataset"]["args"]["transforms"]:
+ if t["type"] == "Normalize":
+ self.normalize_mean = t["args"]["mean"]
+ self.normalize_std = t["args"]["std"]
self.UN_Normalize = True
def apply_to_static(self, model):
- support_to_static = self.config['trainer'].get('to_static', False)
+ support_to_static = self.config["trainer"].get("to_static", False)
if support_to_static:
specs = None
- print('static')
+ print("static")
specs = [InputSpec([None, 3, -1, -1])]
model = to_static(model, input_spec=specs)
self.logger_info(
- "Successfully to apply @to_static with specs: {}".format(specs))
+ "Successfully to apply @to_static with specs: {}".format(specs)
+ )
return model
def train(self):
@@ -185,12 +198,12 @@ def _save_checkpoint(self, epoch, file_name):
"""
state_dict = self.model.state_dict()
state = {
- 'epoch': epoch,
- 'global_step': self.global_step,
- 'state_dict': state_dict,
- 'optimizer': self.optimizer.state_dict(),
- 'config': self.config,
- 'metrics': self.metrics
+ "epoch": epoch,
+ "global_step": self.global_step,
+ "state_dict": state_dict,
+ "optimizer": self.optimizer.state_dict(),
+ "config": self.config,
+ "metrics": self.metrics,
}
filename = os.path.join(self.checkpoint_dir, file_name)
paddle.save(state, filename)
@@ -202,48 +215,54 @@ def _load_checkpoint(self, checkpoint_path, resume):
"""
self.logger_info("Loading checkpoint: {} ...".format(checkpoint_path))
checkpoint = paddle.load(checkpoint_path)
- self.model.set_state_dict(checkpoint['state_dict'])
+ self.model.set_state_dict(checkpoint["state_dict"])
if resume:
- self.global_step = checkpoint['global_step']
- self.start_epoch = checkpoint['epoch']
- self.config['lr_scheduler']['args']['last_epoch'] = self.start_epoch
+ self.global_step = checkpoint["global_step"]
+ self.start_epoch = checkpoint["epoch"]
+ self.config["lr_scheduler"]["args"]["last_epoch"] = self.start_epoch
# self.scheduler.load_state_dict(checkpoint['scheduler'])
- self.optimizer.set_state_dict(checkpoint['optimizer'])
- if 'metrics' in checkpoint:
- self.metrics = checkpoint['metrics']
- self.logger_info("resume from checkpoint {} (epoch {})".format(
- checkpoint_path, self.start_epoch))
+ self.optimizer.set_state_dict(checkpoint["optimizer"])
+ if "metrics" in checkpoint:
+ self.metrics = checkpoint["metrics"]
+ self.logger_info(
+ "resume from checkpoint {} (epoch {})".format(
+ checkpoint_path, self.start_epoch
+ )
+ )
else:
- self.logger_info("finetune from checkpoint {}".format(
- checkpoint_path))
+ self.logger_info("finetune from checkpoint {}".format(checkpoint_path))
def _initialize(self, name, module, *args, **kwargs):
- module_name = self.config[name]['type']
- module_args = self.config[name].get('args', {})
- assert all([k not in module_args for k in kwargs
- ]), 'Overwriting kwargs given in config file is not allowed'
+ module_name = self.config[name]["type"]
+ module_args = self.config[name].get("args", {})
+ assert all(
+ [k not in module_args for k in kwargs]
+ ), "Overwriting kwargs given in config file is not allowed"
module_args.update(kwargs)
return getattr(module, module_name)(*args, **module_args)
def _initialize_scheduler(self):
- self.lr_scheduler = self._initialize('lr_scheduler',
- paddle.optimizer.lr)
+ self.lr_scheduler = self._initialize("lr_scheduler", paddle.optimizer.lr)
def _initialize_optimizer(self):
self.optimizer = self._initialize(
- 'optimizer',
+ "optimizer",
paddle.optimizer,
parameters=self.model.parameters(),
- learning_rate=self.lr_scheduler)
+ learning_rate=self.lr_scheduler,
+ )
def inverse_normalize(self, batch_img):
if self.UN_Normalize:
- batch_img[:, 0, :, :] = batch_img[:, 0, :, :] * self.normalize_std[
- 0] + self.normalize_mean[0]
- batch_img[:, 1, :, :] = batch_img[:, 1, :, :] * self.normalize_std[
- 1] + self.normalize_mean[1]
- batch_img[:, 2, :, :] = batch_img[:, 2, :, :] * self.normalize_std[
- 2] + self.normalize_mean[2]
+ batch_img[:, 0, :, :] = (
+ batch_img[:, 0, :, :] * self.normalize_std[0] + self.normalize_mean[0]
+ )
+ batch_img[:, 1, :, :] = (
+ batch_img[:, 1, :, :] * self.normalize_std[1] + self.normalize_mean[1]
+ )
+ batch_img[:, 2, :, :] = (
+ batch_img[:, 2, :, :] * self.normalize_std[2] + self.normalize_mean[2]
+ )
def logger_info(self, s):
if paddle.distributed.get_rank() == 0:
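
Note on inverse_normalize above: it undoes the per-channel Normalize transform (x * std + mean, channel by channel on an NCHW batch) so training images can be logged with their original pixel statistics. A NumPy sketch using the ImageNet mean/std that appears elsewhere in this config:

    import numpy as np

    def inverse_normalize(batch, mean, std):
        # invert Normalize on an NCHW batch, one channel at a time
        out = batch.copy()
        for c in range(3):
            out[:, c] = out[:, c] * std[c] + mean[c]
        return out

    mean, std = [0.485, 0.456, 0.406], [0.229, 0.224, 0.225]
    x = np.zeros((1, 3, 2, 2), dtype=np.float32)        # normalized zeros
    print(inverse_normalize(x, mean, std)[0, :, 0, 0])  # -> per-channel mean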
diff --git a/benchmark/PaddleOCR_DBNet/data_loader/__init__.py b/benchmark/PaddleOCR_DBNet/data_loader/__init__.py
index afc6e56b51..e8a5f65f46 100644
--- a/benchmark/PaddleOCR_DBNet/data_loader/__init__.py
+++ b/benchmark/PaddleOCR_DBNet/data_loader/__init__.py
@@ -21,20 +21,21 @@ def get_dataset(data_path, module_name, transform, dataset_args):
    :return: the corresponding ConcatDataset object if the data_path list is not empty, otherwise None
"""
from . import dataset
- s_dataset = getattr(dataset, module_name)(transform=transform,
- data_path=data_path,
- **dataset_args)
+
+ s_dataset = getattr(dataset, module_name)(
+ transform=transform, data_path=data_path, **dataset_args
+ )
return s_dataset
def get_transforms(transforms_config):
tr_list = []
for item in transforms_config:
- if 'args' not in item:
+ if "args" not in item:
args = {}
else:
- args = item['args']
- cls = getattr(transforms, item['type'])(**args)
+ args = item["args"]
+ cls = getattr(transforms, item["type"])(**args)
tr_list.append(cls)
tr_list = transforms.Compose(tr_list)
return tr_list
@@ -64,43 +65,50 @@ def get_dataloader(module_config, distributed=False):
if module_config is None:
return None
config = copy.deepcopy(module_config)
- dataset_args = config['dataset']['args']
- if 'transforms' in dataset_args:
- img_transfroms = get_transforms(dataset_args.pop('transforms'))
+ dataset_args = config["dataset"]["args"]
+ if "transforms" in dataset_args:
+ img_transfroms = get_transforms(dataset_args.pop("transforms"))
else:
img_transfroms = None
    # create the dataset
- dataset_name = config['dataset']['type']
- data_path = dataset_args.pop('data_path')
+ dataset_name = config["dataset"]["type"]
+ data_path = dataset_args.pop("data_path")
if data_path == None:
return None
data_path = [x for x in data_path if x is not None]
if len(data_path) == 0:
return None
- if 'collate_fn' not in config['loader'] or config['loader'][
- 'collate_fn'] is None or len(config['loader']['collate_fn']) == 0:
- config['loader']['collate_fn'] = None
+ if (
+ "collate_fn" not in config["loader"]
+ or config["loader"]["collate_fn"] is None
+ or len(config["loader"]["collate_fn"]) == 0
+ ):
+ config["loader"]["collate_fn"] = None
else:
- config['loader']['collate_fn'] = eval(config['loader']['collate_fn'])()
+ config["loader"]["collate_fn"] = eval(config["loader"]["collate_fn"])()
_dataset = get_dataset(
data_path=data_path,
module_name=dataset_name,
transform=img_transfroms,
- dataset_args=dataset_args)
+ dataset_args=dataset_args,
+ )
sampler = None
if distributed:
        # 3) use DistributedSampler
batch_sampler = DistributedBatchSampler(
dataset=_dataset,
- batch_size=config['loader'].pop('batch_size'),
- shuffle=config['loader'].pop('shuffle'))
+ batch_size=config["loader"].pop("batch_size"),
+ shuffle=config["loader"].pop("shuffle"),
+ )
else:
batch_sampler = BatchSampler(
dataset=_dataset,
- batch_size=config['loader'].pop('batch_size'),
- shuffle=config['loader'].pop('shuffle'))
+ batch_size=config["loader"].pop("batch_size"),
+ shuffle=config["loader"].pop("shuffle"),
+ )
loader = DataLoader(
- dataset=_dataset, batch_sampler=batch_sampler, **config['loader'])
+ dataset=_dataset, batch_sampler=batch_sampler, **config["loader"]
+ )
return loader
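
Note on data_loader/__init__.py: get_transforms and get_dataset build the pipeline reflectively; each config entry names a class ("type") plus keyword "args", and the class is looked up by name and instantiated. A self-contained sketch of the pattern using getattr on paddle.vision.transforms (the repo's version uses eval against its own transforms module):

    import paddle.vision.transforms as transforms

    def build_transforms(transforms_config):
        # look each transform class up by name, instantiate with its args
        tr_list = []
        for item in transforms_config:
            args = item.get("args", {})
            tr_list.append(getattr(transforms, item["type"])(**args))
        return transforms.Compose(tr_list)

    pipeline = build_transforms(
        [
            {"type": "Resize", "args": {"size": (640, 640)}},
            {"type": "ToTensor"},
        ]
    )
    print(pipeline)

Looking classes up with getattr keeps the config declarative while avoiding eval on arbitrary strings.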
diff --git a/benchmark/PaddleOCR_DBNet/data_loader/dataset.py b/benchmark/PaddleOCR_DBNet/data_loader/dataset.py
index 29d3954fe6..b37f018399 100644
--- a/benchmark/PaddleOCR_DBNet/data_loader/dataset.py
+++ b/benchmark/PaddleOCR_DBNet/data_loader/dataset.py
@@ -13,72 +13,75 @@
class ICDAR2015Dataset(BaseDataSet):
- def __init__(self,
- data_path: str,
- img_mode,
- pre_processes,
- filter_keys,
- ignore_tags,
- transform=None,
- **kwargs):
- super().__init__(data_path, img_mode, pre_processes, filter_keys,
- ignore_tags, transform)
+ def __init__(
+ self,
+ data_path: str,
+ img_mode,
+ pre_processes,
+ filter_keys,
+ ignore_tags,
+ transform=None,
+ **kwargs
+ ):
+ super().__init__(
+ data_path, img_mode, pre_processes, filter_keys, ignore_tags, transform
+ )
def load_data(self, data_path: str) -> list:
data_list = get_datalist(data_path)
t_data_list = []
for img_path, label_path in data_list:
data = self._get_annotation(label_path)
- if len(data['text_polys']) > 0:
- item = {
- 'img_path': img_path,
- 'img_name': pathlib.Path(img_path).stem
- }
+ if len(data["text_polys"]) > 0:
+ item = {"img_path": img_path, "img_name": pathlib.Path(img_path).stem}
item.update(data)
t_data_list.append(item)
else:
- print('there is no suit bbox in {}'.format(label_path))
+ print("there is no suit bbox in {}".format(label_path))
return t_data_list
def _get_annotation(self, label_path: str) -> dict:
boxes = []
texts = []
ignores = []
- with open(label_path, encoding='utf-8', mode='r') as f:
+ with open(label_path, encoding="utf-8", mode="r") as f:
for line in f.readlines():
- params = line.strip().strip('\ufeff').strip(
- '\xef\xbb\xbf').split(',')
+ params = line.strip().strip("\ufeff").strip("\xef\xbb\xbf").split(",")
try:
box = order_points_clockwise(
- np.array(list(map(float, params[:8]))).reshape(-1, 2))
+ np.array(list(map(float, params[:8]))).reshape(-1, 2)
+ )
if cv2.contourArea(box) > 0:
boxes.append(box)
label = params[8]
texts.append(label)
ignores.append(label in self.ignore_tags)
except:
- print('load label failed on {}'.format(label_path))
+ print("load label failed on {}".format(label_path))
data = {
- 'text_polys': np.array(boxes),
- 'texts': texts,
- 'ignore_tags': ignores,
+ "text_polys": np.array(boxes),
+ "texts": texts,
+ "ignore_tags": ignores,
}
return data
class DetDataset(BaseDataSet):
- def __init__(self,
- data_path: str,
- img_mode,
- pre_processes,
- filter_keys,
- ignore_tags,
- transform=None,
- **kwargs):
- self.load_char_annotation = kwargs['load_char_annotation']
- self.expand_one_char = kwargs['expand_one_char']
- super().__init__(data_path, img_mode, pre_processes, filter_keys,
- ignore_tags, transform)
+ def __init__(
+ self,
+ data_path: str,
+ img_mode,
+ pre_processes,
+ filter_keys,
+ ignore_tags,
+ transform=None,
+ **kwargs
+ ):
+ self.load_char_annotation = kwargs["load_char_annotation"]
+ self.expand_one_char = kwargs["expand_one_char"]
+ super().__init__(
+ data_path, img_mode, pre_processes, filter_keys, ignore_tags, transform
+ )
def load_data(self, data_path: str) -> list:
"""
@@ -89,93 +92,99 @@ def load_data(self, data_path: str) -> list:
data_list = []
for path in data_path:
content = load(path)
- for gt in tqdm(
- content['data_list'], desc='read file {}'.format(path)):
- img_path = os.path.join(content['data_root'], gt['img_name'])
+ for gt in tqdm(content["data_list"], desc="read file {}".format(path)):
+ img_path = os.path.join(content["data_root"], gt["img_name"])
polygons = []
texts = []
illegibility_list = []
language_list = []
- for annotation in gt['annotations']:
- if len(annotation['polygon']) == 0 or len(annotation[
- 'text']) == 0:
+ for annotation in gt["annotations"]:
+ if len(annotation["polygon"]) == 0 or len(annotation["text"]) == 0:
continue
- if len(annotation['text']) > 1 and self.expand_one_char:
- annotation['polygon'] = expand_polygon(annotation[
- 'polygon'])
- polygons.append(annotation['polygon'])
- texts.append(annotation['text'])
- illegibility_list.append(annotation['illegibility'])
- language_list.append(annotation['language'])
+ if len(annotation["text"]) > 1 and self.expand_one_char:
+ annotation["polygon"] = expand_polygon(annotation["polygon"])
+ polygons.append(annotation["polygon"])
+ texts.append(annotation["text"])
+ illegibility_list.append(annotation["illegibility"])
+ language_list.append(annotation["language"])
if self.load_char_annotation:
- for char_annotation in annotation['chars']:
- if len(char_annotation['polygon']) == 0 or len(
- char_annotation['char']) == 0:
+ for char_annotation in annotation["chars"]:
+ if (
+ len(char_annotation["polygon"]) == 0
+ or len(char_annotation["char"]) == 0
+ ):
continue
- polygons.append(char_annotation['polygon'])
- texts.append(char_annotation['char'])
- illegibility_list.append(char_annotation[
- 'illegibility'])
- language_list.append(char_annotation['language'])
- data_list.append({
- 'img_path': img_path,
- 'img_name': gt['img_name'],
- 'text_polys': np.array(polygons),
- 'texts': texts,
- 'ignore_tags': illegibility_list
- })
+ polygons.append(char_annotation["polygon"])
+ texts.append(char_annotation["char"])
+ illegibility_list.append(char_annotation["illegibility"])
+ language_list.append(char_annotation["language"])
+ data_list.append(
+ {
+ "img_path": img_path,
+ "img_name": gt["img_name"],
+ "text_polys": np.array(polygons),
+ "texts": texts,
+ "ignore_tags": illegibility_list,
+ }
+ )
return data_list
class SynthTextDataset(BaseDataSet):
- def __init__(self,
- data_path: str,
- img_mode,
- pre_processes,
- filter_keys,
- transform=None,
- **kwargs):
+ def __init__(
+ self,
+ data_path: str,
+ img_mode,
+ pre_processes,
+ filter_keys,
+ transform=None,
+ **kwargs
+ ):
self.transform = transform
self.dataRoot = pathlib.Path(data_path)
if not self.dataRoot.exists():
- raise FileNotFoundError('Dataset folder is not exist.')
+ raise FileNotFoundError("Dataset folder is not exist.")
- self.targetFilePath = self.dataRoot / 'gt.mat'
+ self.targetFilePath = self.dataRoot / "gt.mat"
if not self.targetFilePath.exists():
- raise FileExistsError('Target file is not exist.')
+ raise FileExistsError("Target file is not exist.")
targets = {}
sio.loadmat(
self.targetFilePath,
targets,
squeeze_me=True,
struct_as_record=False,
- variable_names=['imnames', 'wordBB', 'txt'])
+ variable_names=["imnames", "wordBB", "txt"],
+ )
- self.imageNames = targets['imnames']
- self.wordBBoxes = targets['wordBB']
- self.transcripts = targets['txt']
- super().__init__(data_path, img_mode, pre_processes, filter_keys,
- transform)
+ self.imageNames = targets["imnames"]
+ self.wordBBoxes = targets["wordBB"]
+ self.transcripts = targets["txt"]
+ super().__init__(data_path, img_mode, pre_processes, filter_keys, transform)
def load_data(self, data_path: str) -> list:
t_data_list = []
for imageName, wordBBoxes, texts in zip(
- self.imageNames, self.wordBBoxes, self.transcripts):
+ self.imageNames, self.wordBBoxes, self.transcripts
+ ):
item = {}
- wordBBoxes = np.expand_dims(
- wordBBoxes, axis=2) if (wordBBoxes.ndim == 2) else wordBBoxes
+ wordBBoxes = (
+ np.expand_dims(wordBBoxes, axis=2)
+ if (wordBBoxes.ndim == 2)
+ else wordBBoxes
+ )
_, _, numOfWords = wordBBoxes.shape
text_polys = wordBBoxes.reshape(
- [8, numOfWords], order='F').T # num_words * 8
- text_polys = text_polys.reshape(numOfWords, 4,
- 2) # num_of_words * 4 * 2
+ [8, numOfWords], order="F"
+ ).T # num_words * 8
+ text_polys = text_polys.reshape(numOfWords, 4, 2) # num_of_words * 4 * 2
transcripts = [word for line in texts for word in line.split()]
if numOfWords != len(transcripts):
continue
- item['img_path'] = str(self.dataRoot / imageName)
- item['img_name'] = (self.dataRoot / imageName).stem
- item['text_polys'] = text_polys
- item['texts'] = transcripts
- item['ignore_tags'] = [x in self.ignore_tags for x in transcripts]
+ item["img_path"] = str(self.dataRoot / imageName)
+ item["img_name"] = (self.dataRoot / imageName).stem
+ item["text_polys"] = text_polys
+ item["texts"] = transcripts
+ item["ignore_tags"] = [x in self.ignore_tags for x in transcripts]
t_data_list.append(item)
return t_data_list
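
For context on `_get_annotation` above: ICDAR2015 ground-truth lines carry eight corner coordinates followed by a transcription, sometimes preceded by a BOM. A minimal sketch with a made-up label line (note that `params[8]` would truncate a transcription that itself contains commas):

```python
import numpy as np

# Illustrative ICDAR2015 gt line: x1,y1,...,x4,y4,transcription (with a BOM).
line = "\ufeff377,117,463,117,465,130,378,130,Genaxis Theatre"
params = line.strip().strip("\ufeff").strip("\xef\xbb\xbf").split(",")
box = np.array(list(map(float, params[:8]))).reshape(-1, 2)  # four corner points
label = params[8]  # transcription; "###" marks an illegible, ignored region
print(box.shape, label, label in {"*", "###"})  # -> (4, 2) Genaxis Theatre False
```
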
diff --git a/benchmark/PaddleOCR_DBNet/data_loader/modules/augment.py b/benchmark/PaddleOCR_DBNet/data_loader/modules/augment.py
index e81bc123d9..d2edd93d7a 100644
--- a/benchmark/PaddleOCR_DBNet/data_loader/modules/augment.py
+++ b/benchmark/PaddleOCR_DBNet/data_loader/modules/augment.py
@@ -23,8 +23,9 @@ def __call__(self, data: dict):
"""
if random.random() > self.random_rate:
return data
- data['img'] = (random_noise(
- data['img'], mode='gaussian', clip=True) * 255).astype(im.dtype)
+ data["img"] = (
+ random_noise(data["img"], mode="gaussian", clip=True) * 255
+ ).astype(im.dtype)
return data
@@ -46,16 +47,16 @@ def __call__(self, data: dict) -> dict:
"""
if random.random() > self.random_rate:
return data
- im = data['img']
- text_polys = data['text_polys']
+ im = data["img"]
+ text_polys = data["text_polys"]
tmp_text_polys = text_polys.copy()
rd_scale = float(np.random.choice(self.scales))
im = cv2.resize(im, dsize=None, fx=rd_scale, fy=rd_scale)
tmp_text_polys *= rd_scale
- data['img'] = im
- data['text_polys'] = tmp_text_polys
+ data["img"] = im
+ data["text_polys"] = tmp_text_polys
return data
@@ -69,18 +70,18 @@ def __init__(self, degrees, random_rate, same_size=False):
"""
if isinstance(degrees, numbers.Number):
if degrees < 0:
- raise ValueError(
- "If degrees is a single number, it must be positive.")
+ raise ValueError("If degrees is a single number, it must be positive.")
degrees = (-degrees, degrees)
- elif isinstance(degrees, list) or isinstance(
- degrees, tuple) or isinstance(degrees, np.ndarray):
+ elif (
+ isinstance(degrees, list)
+ or isinstance(degrees, tuple)
+ or isinstance(degrees, np.ndarray)
+ ):
if len(degrees) != 2:
- raise ValueError(
- "If degrees is a sequence, it must be of len 2.")
+ raise ValueError("If degrees is a sequence, it must be of len 2.")
degrees = degrees
else:
- raise Exception(
- 'degrees must in Number or list or tuple or np.ndarray')
+ raise Exception("degrees must in Number or list or tuple or np.ndarray")
self.degrees = degrees
self.same_size = same_size
self.random_rate = random_rate
@@ -93,8 +94,8 @@ def __call__(self, data: dict) -> dict:
"""
if random.random() > self.random_rate:
return data
- im = data['img']
- text_polys = data['text_polys']
+ im = data["img"]
+ text_polys = data["text_polys"]
 # ---------------------- rotate the image ----------------------
w = im.shape[1]
@@ -108,21 +109,22 @@ def __call__(self, data: dict) -> dict:
 # convert degrees to radians
rangle = np.deg2rad(angle)
 # compute the w, h of the rotated image
- nw = (abs(np.sin(rangle) * h) + abs(np.cos(rangle) * w))
- nh = (abs(np.cos(rangle) * h) + abs(np.sin(rangle) * w))
+ nw = abs(np.sin(rangle) * h) + abs(np.cos(rangle) * w)
+ nh = abs(np.cos(rangle) * h) + abs(np.sin(rangle) * w)
 # build the affine matrix
rot_mat = cv2.getRotationMatrix2D((nw * 0.5, nh * 0.5), angle, 1)
 # compute the offset from the original image centre to the new image centre
- rot_move = np.dot(rot_mat,
- np.array([(nw - w) * 0.5, (nh - h) * 0.5, 0]))
+ rot_move = np.dot(rot_mat, np.array([(nw - w) * 0.5, (nh - h) * 0.5, 0]))
 # update the affine matrix
rot_mat[0, 2] += rot_move[0]
rot_mat[1, 2] += rot_move[1]
 # apply the affine transform
rot_img = cv2.warpAffine(
im,
- rot_mat, (int(math.ceil(nw)), int(math.ceil(nh))),
- flags=cv2.INTER_LANCZOS4)
+ rot_mat,
+ (int(math.ceil(nw)), int(math.ceil(nh))),
+ flags=cv2.INTER_LANCZOS4,
+ )
 # ---------------------- correct the bbox coordinates ----------------------
 # rot_mat is the final rotation matrix
@@ -134,8 +136,8 @@ def __call__(self, data: dict) -> dict:
point3 = np.dot(rot_mat, np.array([bbox[2, 0], bbox[2, 1], 1]))
point4 = np.dot(rot_mat, np.array([bbox[3, 0], bbox[3, 1], 1]))
rot_text_polys.append([point1, point2, point3, point4])
- data['img'] = rot_img
- data['text_polys'] = np.array(rot_text_polys)
+ data["img"] = rot_img
+ data["text_polys"] = np.array(rot_text_polys)
return data
@@ -150,17 +152,19 @@ def __init__(self, size, random_rate, keep_ratio=False):
if isinstance(size, numbers.Number):
if size < 0:
raise ValueError(
- "If input_size is a single number, it must be positive.")
+ "If input_size is a single number, it must be positive."
+ )
size = (size, size)
- elif isinstance(size, list) or isinstance(size, tuple) or isinstance(
- size, np.ndarray):
+ elif (
+ isinstance(size, list)
+ or isinstance(size, tuple)
+ or isinstance(size, np.ndarray)
+ ):
if len(size) != 2:
- raise ValueError(
- "If input_size is a sequence, it must be of len 2.")
+ raise ValueError("If input_size is a sequence, it must be of len 2.")
size = (size[0], size[1])
else:
- raise Exception(
- 'input_size must in Number or list or tuple or np.ndarray')
+ raise Exception("input_size must in Number or list or tuple or np.ndarray")
self.size = size
self.keep_ratio = keep_ratio
self.random_rate = random_rate
@@ -173,8 +177,8 @@ def __call__(self, data: dict) -> dict:
"""
if random.random() > self.random_rate:
return data
- im = data['img']
- text_polys = data['text_polys']
+ im = data["img"]
+ text_polys = data["text_polys"]
if self.keep_ratio:
 # pad the short side of the image to match the long side
@@ -192,8 +196,8 @@ def __call__(self, data: dict) -> dict:
text_polys[:, :, 0] *= w_scale
text_polys[:, :, 1] *= h_scale
- data['img'] = im
- data['text_polys'] = text_polys
+ data["img"] = im
+ data["text_polys"] = text_polys
return data
@@ -226,8 +230,8 @@ def __call__(self, data: dict) -> dict:
:param data: {'img':,'text_polys':,'texts':,'ignore_tags':}
:return:
"""
- im = data['img']
- text_polys = data['text_polys']
+ im = data["img"]
+ text_polys = data["text_polys"]
h, w, _ = im.shape
short_edge = min(h, w)
@@ -242,8 +246,8 @@ def __call__(self, data: dict) -> dict:
text_polys[:, 0] *= scale[0]
text_polys[:, 1] *= scale[1]
- data['img'] = im
- data['text_polys'] = text_polys
+ data["img"] = im
+ data["text_polys"] = text_polys
return data
@@ -263,16 +267,16 @@ def __call__(self, data: dict) -> dict:
"""
if random.random() > self.random_rate:
return data
- im = data['img']
- text_polys = data['text_polys']
+ im = data["img"]
+ text_polys = data["text_polys"]
flip_text_polys = text_polys.copy()
flip_im = cv2.flip(im, 1)
h, w, _ = flip_im.shape
flip_text_polys[:, :, 0] = w - flip_text_polys[:, :, 0]
- data['img'] = flip_im
- data['text_polys'] = flip_text_polys
+ data["img"] = flip_im
+ data["text_polys"] = flip_text_polys
return data
@@ -292,13 +296,13 @@ def __call__(self, data: dict) -> dict:
"""
if random.random() > self.random_rate:
return data
- im = data['img']
- text_polys = data['text_polys']
+ im = data["img"]
+ text_polys = data["text_polys"]
flip_text_polys = text_polys.copy()
flip_im = cv2.flip(im, 0)
h, w, _ = flip_im.shape
flip_text_polys[:, :, 1] = h - flip_text_polys[:, :, 1]
- data['img'] = flip_im
- data['text_polys'] = flip_text_polys
+ data["img"] = flip_im
+ data["text_polys"] = flip_text_polys
return data
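
The rotation hunk above grows the canvas to fit the rotated image, then shifts the affine matrix so the old centre lands on the new centre, and finally carries each polygon vertex through the same matrix. A standalone sketch of that bookkeeping (sizes and angle are illustrative):

```python
import math
import cv2
import numpy as np

h, w, angle = 480, 640, 30.0
rangle = np.deg2rad(angle)
# bounding box of the rotated image
nw = abs(np.sin(rangle) * h) + abs(np.cos(rangle) * w)
nh = abs(np.cos(rangle) * h) + abs(np.sin(rangle) * w)
# rotate about the new centre, then shift so the old centre maps onto it
rot_mat = cv2.getRotationMatrix2D((nw * 0.5, nh * 0.5), angle, 1)
rot_move = np.dot(rot_mat, np.array([(nw - w) * 0.5, (nh - h) * 0.5, 0]))
rot_mat[0, 2] += rot_move[0]
rot_mat[1, 2] += rot_move[1]
# a polygon vertex rides along with the same 2x3 matrix
point = np.dot(rot_mat, np.array([100.0, 200.0, 1]))
print(int(math.ceil(nw)), int(math.ceil(nh)), point)
```
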
diff --git a/benchmark/PaddleOCR_DBNet/data_loader/modules/iaa_augment.py b/benchmark/PaddleOCR_DBNet/data_loader/modules/iaa_augment.py
index 1cf891bbd6..1f0fa19027 100644
--- a/benchmark/PaddleOCR_DBNet/data_loader/modules/iaa_augment.py
+++ b/benchmark/PaddleOCR_DBNet/data_loader/modules/iaa_augment.py
@@ -18,17 +18,14 @@ def build(self, args, root=True):
sequence = [self.build(value, root=False) for value in args]
return iaa.Sequential(sequence)
else:
- return getattr(
- iaa,
- args[0])(* [self.to_tuple_if_list(a) for a in args[1:]])
+ return getattr(iaa, args[0])(
+ *[self.to_tuple_if_list(a) for a in args[1:]]
+ )
elif isinstance(args, dict):
- cls = getattr(iaa, args['type'])
- return cls(**{
- k: self.to_tuple_if_list(v)
- for k, v in args['args'].items()
- })
+ cls = getattr(iaa, args["type"])
+ return cls(**{k: self.to_tuple_if_list(v) for k, v in args["args"].items()})
else:
- raise RuntimeError('unknown augmenter arg: ' + str(args))
+ raise RuntimeError("unknown augmenter arg: " + str(args))
def to_tuple_if_list(self, obj):
if isinstance(obj, list):
@@ -36,18 +33,18 @@ def to_tuple_if_list(self, obj):
return obj
-class IaaAugment():
+class IaaAugment:
def __init__(self, augmenter_args):
self.augmenter_args = augmenter_args
self.augmenter = AugmenterBuilder().build(self.augmenter_args)
def __call__(self, data):
- image = data['img']
+ image = data["img"]
shape = image.shape
if self.augmenter:
aug = self.augmenter.to_deterministic()
- data['img'] = aug.augment_image(image)
+ data["img"] = aug.augment_image(image)
data = self.may_augment_annotation(aug, data, shape)
return data
@@ -56,16 +53,16 @@ def may_augment_annotation(self, aug, data, shape):
return data
line_polys = []
- for poly in data['text_polys']:
+ for poly in data["text_polys"]:
new_poly = self.may_augment_poly(aug, shape, poly)
line_polys.append(new_poly)
- data['text_polys'] = np.array(line_polys)
+ data["text_polys"] = np.array(line_polys)
return data
def may_augment_poly(self, aug, img_shape, poly):
keypoints = [imgaug.Keypoint(p[0], p[1]) for p in poly]
keypoints = aug.augment_keypoints(
- [imgaug.KeypointsOnImage(
- keypoints, shape=img_shape)])[0].keypoints
+ [imgaug.KeypointsOnImage(keypoints, shape=img_shape)]
+ )[0].keypoints
poly = [(p.x, p.y) for p in keypoints]
return poly
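
For the `AugmenterBuilder` above: a hypothetical `augmenter_args` in the two shapes `build` accepts, next to the hand-built pipeline it should produce (assuming imgaug's stock `Fliplr`/`Affine`/`Resize` augmenters):

```python
import imgaug.augmenters as iaa

# Either a [ClassName, *positional_args] list or a {"type": ..., "args": {...}}
# dict; plain lists inside either form are converted to tuples.
augmenter_args = [
    ["Fliplr", 0.5],                                    # -> iaa.Fliplr(0.5)
    {"type": "Affine", "args": {"rotate": [-10, 10]}},  # -> iaa.Affine(rotate=(-10, 10))
    ["Resize", [0.5, 3.0]],                             # -> iaa.Resize((0.5, 3.0))
]
# The hand-built equivalent, to make the mapping concrete:
seq = iaa.Sequential(
    [iaa.Fliplr(0.5), iaa.Affine(rotate=(-10, 10)), iaa.Resize((0.5, 3.0))]
)
print(seq)
```
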
diff --git a/benchmark/PaddleOCR_DBNet/data_loader/modules/make_border_map.py b/benchmark/PaddleOCR_DBNet/data_loader/modules/make_border_map.py
index 2985f3c8a0..28b9ac9795 100644
--- a/benchmark/PaddleOCR_DBNet/data_loader/modules/make_border_map.py
+++ b/benchmark/PaddleOCR_DBNet/data_loader/modules/make_border_map.py
@@ -1,11 +1,12 @@
import cv2
import numpy as np
-np.seterr(divide='ignore', invalid='ignore')
+
+np.seterr(divide="ignore", invalid="ignore")
import pyclipper
from shapely.geometry import Polygon
-class MakeBorderMap():
+class MakeBorderMap:
def __init__(self, shrink_ratio=0.4, thresh_min=0.3, thresh_max=0.7):
self.shrink_ratio = shrink_ratio
self.thresh_min = thresh_min
@@ -17,9 +18,9 @@ def __call__(self, data: dict) -> dict:
:param data: {'img':,'text_polys':,'texts':,'ignore_tags':}
:return:
"""
- im = data['img']
- text_polys = data['text_polys']
- ignore_tags = data['ignore_tags']
+ im = data["img"]
+ text_polys = data["text_polys"]
+ ignore_tags = data["ignore_tags"]
canvas = np.zeros(im.shape[:2], dtype=np.float32)
mask = np.zeros(im.shape[:2], dtype=np.float32)
@@ -30,8 +31,8 @@ def __call__(self, data: dict) -> dict:
self.draw_border_map(text_polys[i], canvas, mask=mask)
canvas = canvas * (self.thresh_max - self.thresh_min) + self.thresh_min
- data['threshold_map'] = canvas
- data['threshold_mask'] = mask
+ data["threshold_map"] = canvas
+ data["threshold_mask"] = mask
return data
def draw_border_map(self, polygon, canvas, mask):
@@ -42,8 +43,11 @@ def draw_border_map(self, polygon, canvas, mask):
polygon_shape = Polygon(polygon)
if polygon_shape.area <= 0:
return
- distance = polygon_shape.area * (
- 1 - np.power(self.shrink_ratio, 2)) / polygon_shape.length
+ distance = (
+ polygon_shape.area
+ * (1 - np.power(self.shrink_ratio, 2))
+ / polygon_shape.length
+ )
subject = [tuple(l) for l in polygon]
padding = pyclipper.PyclipperOffset()
padding.AddPath(subject, pyclipper.JT_ROUND, pyclipper.ET_CLOSEDPOLYGON)
@@ -62,14 +66,13 @@ def draw_border_map(self, polygon, canvas, mask):
polygon[:, 1] = polygon[:, 1] - ymin
xs = np.broadcast_to(
- np.linspace(
- 0, width - 1, num=width).reshape(1, width), (height, width))
+ np.linspace(0, width - 1, num=width).reshape(1, width), (height, width)
+ )
ys = np.broadcast_to(
- np.linspace(
- 0, height - 1, num=height).reshape(height, 1), (height, width))
+ np.linspace(0, height - 1, num=height).reshape(height, 1), (height, width)
+ )
- distance_map = np.zeros(
- (polygon.shape[0], height, width), dtype=np.float32)
+ distance_map = np.zeros((polygon.shape[0], height, width), dtype=np.float32)
for i in range(polygon.shape[0]):
j = (i + 1) % polygon.shape[0]
absolute_distance = self.distance(xs, ys, polygon[i], polygon[j])
@@ -80,45 +83,53 @@ def draw_border_map(self, polygon, canvas, mask):
xmax_valid = min(max(0, xmax), canvas.shape[1] - 1)
ymin_valid = min(max(0, ymin), canvas.shape[0] - 1)
ymax_valid = min(max(0, ymax), canvas.shape[0] - 1)
- canvas[ymin_valid:ymax_valid + 1, xmin_valid:xmax_valid + 1] = np.fmax(
- 1 - distance_map[ymin_valid - ymin:ymax_valid - ymax + height,
- xmin_valid - xmin:xmax_valid - xmax + width],
- canvas[ymin_valid:ymax_valid + 1, xmin_valid:xmax_valid + 1])
+ canvas[ymin_valid : ymax_valid + 1, xmin_valid : xmax_valid + 1] = np.fmax(
+ 1
+ - distance_map[
+ ymin_valid - ymin : ymax_valid - ymax + height,
+ xmin_valid - xmin : xmax_valid - xmax + width,
+ ],
+ canvas[ymin_valid : ymax_valid + 1, xmin_valid : xmax_valid + 1],
+ )
def distance(self, xs, ys, point_1, point_2):
- '''
+ """
compute the distance from point to a line
ys: coordinates in the first axis
xs: coordinates in the second axis
 point_1, point_2: (x, y), the endpoints of the line
- '''
+ """
height, width = xs.shape[:2]
- square_distance_1 = np.square(xs - point_1[0]) + np.square(ys - point_1[
- 1])
- square_distance_2 = np.square(xs - point_2[0]) + np.square(ys - point_2[
- 1])
+ square_distance_1 = np.square(xs - point_1[0]) + np.square(ys - point_1[1])
+ square_distance_2 = np.square(xs - point_2[0]) + np.square(ys - point_2[1])
square_distance = np.square(point_1[0] - point_2[0]) + np.square(
- point_1[1] - point_2[1])
+ point_1[1] - point_2[1]
+ )
cosin = (square_distance - square_distance_1 - square_distance_2) / (
- 2 * np.sqrt(square_distance_1 * square_distance_2))
+ 2 * np.sqrt(square_distance_1 * square_distance_2)
+ )
square_sin = 1 - np.square(cosin)
square_sin = np.nan_to_num(square_sin)
- result = np.sqrt(square_distance_1 * square_distance_2 * square_sin /
- square_distance)
- result[cosin <
- 0] = np.sqrt(np.fmin(square_distance_1, square_distance_2))[cosin
- < 0]
+ result = np.sqrt(
+ square_distance_1 * square_distance_2 * square_sin / square_distance
+ )
+ result[cosin < 0] = np.sqrt(np.fmin(square_distance_1, square_distance_2))[
+ cosin < 0
+ ]
# self.extend_line(point_1, point_2, result)
return result
def extend_line(self, point_1, point_2, result):
- ex_point_1 = (int(
- round(point_1[0] + (point_1[0] - point_2[0]) * (
- 1 + self.shrink_ratio))), int(
- round(point_1[1] + (point_1[1] - point_2[1]) * (
- 1 + self.shrink_ratio))))
+ ex_point_1 = (
+ int(
+ round(point_1[0] + (point_1[0] - point_2[0]) * (1 + self.shrink_ratio))
+ ),
+ int(
+ round(point_1[1] + (point_1[1] - point_2[1]) * (1 + self.shrink_ratio))
+ ),
+ )
cv2.line(
result,
tuple(ex_point_1),
@@ -126,12 +137,16 @@ def extend_line(self, point_1, point_2, result):
4096.0,
1,
lineType=cv2.LINE_AA,
- shift=0)
- ex_point_2 = (int(
- round(point_2[0] + (point_2[0] - point_1[0]) * (
- 1 + self.shrink_ratio))), int(
- round(point_2[1] + (point_2[1] - point_1[1]) * (
- 1 + self.shrink_ratio))))
+ shift=0,
+ )
+ ex_point_2 = (
+ int(
+ round(point_2[0] + (point_2[0] - point_1[0]) * (1 + self.shrink_ratio))
+ ),
+ int(
+ round(point_2[1] + (point_2[1] - point_1[1]) * (1 + self.shrink_ratio))
+ ),
+ )
cv2.line(
result,
tuple(ex_point_2),
@@ -139,5 +154,6 @@ def extend_line(self, point_1, point_2, result):
4096.0,
1,
lineType=cv2.LINE_AA,
- shift=0)
+ shift=0,
+ )
return ex_point_1, ex_point_2
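
The `distance` method above recovers the cosine of the angle the segment subtends at each query point via the law of cosines, then uses the triangle-area identity height = d1 · d2 · |sin θ| / |p1 − p2| to get the point-to-line distance; points where `cosin` comes out negative in the code's sign convention fall back to the nearer endpoint instead. A numeric sanity check:

```python
import numpy as np

p1, p2 = np.array([0.0, 0.0]), np.array([4.0, 0.0])
q = np.array([2.0, 1.0])  # illustrative point; true distance to the line is 1
d1_sq = np.sum((q - p1) ** 2)
d2_sq = np.sum((q - p2) ** 2)
d12_sq = np.sum((p1 - p2) ** 2)
# same sign convention as the hunk above (negated law-of-cosines cosine)
cosin = (d12_sq - d1_sq - d2_sq) / (2 * np.sqrt(d1_sq * d2_sq))
square_sin = 1 - np.square(cosin)
print(np.sqrt(d1_sq * d2_sq * square_sin / d12_sq))  # -> 1.0
```
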
diff --git a/benchmark/PaddleOCR_DBNet/data_loader/modules/make_shrink_map.py b/benchmark/PaddleOCR_DBNet/data_loader/modules/make_shrink_map.py
index 3f268b9dea..b3fea40f39 100644
--- a/benchmark/PaddleOCR_DBNet/data_loader/modules/make_shrink_map.py
+++ b/benchmark/PaddleOCR_DBNet/data_loader/modules/make_shrink_map.py
@@ -16,9 +16,11 @@ def shrink_polygon_py(polygon, shrink_ratio):
def shrink_polygon_pyclipper(polygon, shrink_ratio):
from shapely.geometry import Polygon
import pyclipper
+
polygon_shape = Polygon(polygon)
- distance = polygon_shape.area * (
- 1 - np.power(shrink_ratio, 2)) / polygon_shape.length
+ distance = (
+ polygon_shape.area * (1 - np.power(shrink_ratio, 2)) / polygon_shape.length
+ )
subject = [tuple(l) for l in polygon]
padding = pyclipper.PyclipperOffset()
padding.AddPath(subject, pyclipper.JT_ROUND, pyclipper.ET_CLOSEDPOLYGON)
@@ -30,19 +32,16 @@ def shrink_polygon_pyclipper(polygon, shrink_ratio):
return shrinked
-class MakeShrinkMap():
- r'''
+class MakeShrinkMap:
+ r"""
Making binary mask from detection data with ICDAR format.
Typically following the process of class `MakeICDARData`.
- '''
+ """
- def __init__(self,
- min_text_size=8,
- shrink_ratio=0.4,
- shrink_type='pyclipper'):
+ def __init__(self, min_text_size=8, shrink_ratio=0.4, shrink_type="pyclipper"):
shrink_func_dict = {
- 'py': shrink_polygon_py,
- 'pyclipper': shrink_polygon_pyclipper
+ "py": shrink_polygon_py,
+ "pyclipper": shrink_polygon_pyclipper,
}
self.shrink_func = shrink_func_dict[shrink_type]
self.min_text_size = min_text_size
@@ -54,13 +53,12 @@ def __call__(self, data: dict) -> dict:
:param data: {'img':,'text_polys':,'texts':,'ignore_tags':}
:return:
"""
- image = data['img']
- text_polys = data['text_polys']
- ignore_tags = data['ignore_tags']
+ image = data["img"]
+ text_polys = data["text_polys"]
+ ignore_tags = data["ignore_tags"]
h, w = image.shape[:2]
- text_polys, ignore_tags = self.validate_polygons(text_polys,
- ignore_tags, h, w)
+ text_polys, ignore_tags = self.validate_polygons(text_polys, ignore_tags, h, w)
gt = np.zeros((h, w), dtype=np.float32)
mask = np.ones((h, w), dtype=np.float32)
for i in range(len(text_polys)):
@@ -68,26 +66,24 @@ def __call__(self, data: dict) -> dict:
height = max(polygon[:, 1]) - min(polygon[:, 1])
width = max(polygon[:, 0]) - min(polygon[:, 0])
if ignore_tags[i] or min(height, width) < self.min_text_size:
- cv2.fillPoly(mask,
- polygon.astype(np.int32)[np.newaxis, :, :], 0)
+ cv2.fillPoly(mask, polygon.astype(np.int32)[np.newaxis, :, :], 0)
ignore_tags[i] = True
else:
shrinked = self.shrink_func(polygon, self.shrink_ratio)
if shrinked.size == 0:
- cv2.fillPoly(mask,
- polygon.astype(np.int32)[np.newaxis, :, :], 0)
+ cv2.fillPoly(mask, polygon.astype(np.int32)[np.newaxis, :, :], 0)
ignore_tags[i] = True
continue
cv2.fillPoly(gt, [shrinked.astype(np.int32)], 1)
- data['shrink_map'] = gt
- data['shrink_mask'] = mask
+ data["shrink_map"] = gt
+ data["shrink_mask"] = mask
return data
def validate_polygons(self, polygons, ignore_tags, h, w):
- '''
+ """
polygons (numpy.array, required): of shape (num_instances, num_points, 2)
- '''
+ """
if len(polygons) == 0:
return polygons, ignore_tags
assert len(polygons) == len(ignore_tags)
@@ -113,7 +109,7 @@ def polygon_area(self, polygon):
# return edge / 2.
-if __name__ == '__main__':
+if __name__ == "__main__":
from shapely.geometry import Polygon
import pyclipper
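
The shrink offset above is the DB paper's D = A · (1 − r²) / L for shrink ratio r. A minimal standalone run under the same pyclipper calls (the polygon and ratio are illustrative):

```python
import numpy as np
import pyclipper
from shapely.geometry import Polygon

polygon = np.array([[0, 0], [100, 0], [100, 40], [0, 40]])  # 100x40 rectangle
shrink_ratio = 0.4
poly = Polygon(polygon)
# D = A * (1 - r^2) / L -> 4000 * 0.84 / 280 = 12.0 here
distance = poly.area * (1 - np.power(shrink_ratio, 2)) / poly.length
offset = pyclipper.PyclipperOffset()
offset.AddPath(
    [tuple(p) for p in polygon], pyclipper.JT_ROUND, pyclipper.ET_CLOSEDPOLYGON
)
shrinked = np.array(offset.Execute(-distance)[0])  # negative delta = shrink inward
print(distance, shrinked.shape)
```
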
diff --git a/benchmark/PaddleOCR_DBNet/data_loader/modules/random_crop_data.py b/benchmark/PaddleOCR_DBNet/data_loader/modules/random_crop_data.py
index fac2e4c07c..8c6f656452 100644
--- a/benchmark/PaddleOCR_DBNet/data_loader/modules/random_crop_data.py
+++ b/benchmark/PaddleOCR_DBNet/data_loader/modules/random_crop_data.py
@@ -5,13 +5,15 @@
# random crop algorithm similar to https://github.com/argman/EAST
-class EastRandomCropData():
- def __init__(self,
- size=(640, 640),
- max_tries=50,
- min_crop_side_ratio=0.1,
- require_original_image=False,
- keep_ratio=True):
+class EastRandomCropData:
+ def __init__(
+ self,
+ size=(640, 640),
+ max_tries=50,
+ min_crop_side_ratio=0.1,
+ require_original_image=False,
+ keep_ratio=True,
+ ):
self.size = size
self.max_tries = max_tries
self.min_crop_side_ratio = min_crop_side_ratio
@@ -24,13 +26,11 @@ def __call__(self, data: dict) -> dict:
:param data: {'img':,'text_polys':,'texts':,'ignore_tags':}
:return:
"""
- im = data['img']
- text_polys = data['text_polys']
- ignore_tags = data['ignore_tags']
- texts = data['texts']
- all_care_polys = [
- text_polys[i] for i, tag in enumerate(ignore_tags) if not tag
- ]
+ im = data["img"]
+ text_polys = data["text_polys"]
+ ignore_tags = data["ignore_tags"]
+ texts = data["texts"]
+ all_care_polys = [text_polys[i] for i, tag in enumerate(ignore_tags) if not tag]
 # compute the crop region
crop_x, crop_y, crop_w, crop_h = self.crop_area(im, all_care_polys)
 # crop the image, padding to keep the aspect ratio
@@ -41,16 +41,17 @@ def __call__(self, data: dict) -> dict:
w = int(crop_w * scale)
if self.keep_ratio:
if len(im.shape) == 3:
- padimg = np.zeros((self.size[1], self.size[0], im.shape[2]),
- im.dtype)
+ padimg = np.zeros((self.size[1], self.size[0], im.shape[2]), im.dtype)
else:
padimg = np.zeros((self.size[1], self.size[0]), im.dtype)
padimg[:h, :w] = cv2.resize(
- im[crop_y:crop_y + crop_h, crop_x:crop_x + crop_w], (w, h))
+ im[crop_y : crop_y + crop_h, crop_x : crop_x + crop_w], (w, h)
+ )
img = padimg
else:
- img = cv2.resize(im[crop_y:crop_y + crop_h, crop_x:crop_x + crop_w],
- tuple(self.size))
+ img = cv2.resize(
+ im[crop_y : crop_y + crop_h, crop_x : crop_x + crop_w], tuple(self.size)
+ )
 # crop the text boxes
text_polys_crop = []
ignore_tags_crop = []
@@ -61,10 +62,10 @@ def __call__(self, data: dict) -> dict:
text_polys_crop.append(poly)
ignore_tags_crop.append(tag)
texts_crop.append(text)
- data['img'] = img
- data['text_polys'] = np.float32(text_polys_crop)
- data['ignore_tags'] = ignore_tags_crop
- data['texts'] = texts_crop
+ data["img"] = img
+ data["text_polys"] = np.float32(text_polys_crop)
+ data["ignore_tags"] = ignore_tags_crop
+ data["texts"] = texts_crop
return data
def is_poly_in_rect(self, poly, x, y, w, h):
@@ -144,13 +145,17 @@ def crop_area(self, im, text_polys):
else:
ymin, ymax = self.random_select(h_axis, h)
- if xmax - xmin < self.min_crop_side_ratio * w or ymax - ymin < self.min_crop_side_ratio * h:
+ if (
+ xmax - xmin < self.min_crop_side_ratio * w
+ or ymax - ymin < self.min_crop_side_ratio * h
+ ):
# area too small
continue
num_poly_in_rect = 0
for poly in text_polys:
- if not self.is_poly_outside_rect(poly, xmin, ymin, xmax - xmin,
- ymax - ymin):
+ if not self.is_poly_outside_rect(
+ poly, xmin, ymin, xmax - xmin, ymax - ymin
+ ):
num_poly_in_rect += 1
break
@@ -160,12 +165,12 @@ def crop_area(self, im, text_polys):
return 0, 0, w, h
-class PSERandomCrop():
+class PSERandomCrop:
def __init__(self, size):
self.size = size
def __call__(self, data):
- imgs = data['imgs']
+ imgs = data["imgs"]
h, w = imgs[0].shape[0:2]
th, tw = self.size
@@ -188,7 +193,7 @@ def __call__(self, data):
i = random.randint(tl[0], br[0])
j = random.randint(tl[1], br[1])
 # make sure shrink_label_map still contains text
- if imgs[1][i:i + th, j:j + tw].sum() <= 0:
+ if imgs[1][i : i + th, j : j + tw].sum() <= 0:
continue
else:
break
@@ -199,8 +204,8 @@ def __call__(self, data):
# return i, j, th, tw
for idx in range(len(imgs)):
if len(imgs[idx].shape) == 3:
- imgs[idx] = imgs[idx][i:i + th, j:j + tw, :]
+ imgs[idx] = imgs[idx][i : i + th, j : j + tw, :]
else:
- imgs[idx] = imgs[idx][i:i + th, j:j + tw]
- data['imgs'] = imgs
+ imgs[idx] = imgs[idx][i : i + th, j : j + tw]
+ data["imgs"] = imgs
return data
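
Sketch of the `keep_ratio` branch above: the crop is scaled by the limiting factor and pasted into a zero-padded canvas of the target size (the sizes here are illustrative, not from the config):

```python
import cv2
import numpy as np

size = (640, 640)  # (w, h) target, matching self.size
crop = np.zeros((300, 500, 3), dtype=np.uint8)  # a pretend crop, (h, w, c)
scale = min(size[0] / crop.shape[1], size[1] / crop.shape[0])
h, w = int(crop.shape[0] * scale), int(crop.shape[1] * scale)
padimg = np.zeros((size[1], size[0], crop.shape[2]), crop.dtype)
padimg[:h, :w] = cv2.resize(crop, (w, h))  # the rest of the canvas stays zero
print(scale, padimg.shape)
```
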
diff --git a/benchmark/PaddleOCR_DBNet/models/__init__.py b/benchmark/PaddleOCR_DBNet/models/__init__.py
index 26ff73ff69..a2669bfee7 100644
--- a/benchmark/PaddleOCR_DBNet/models/__init__.py
+++ b/benchmark/PaddleOCR_DBNet/models/__init__.py
@@ -5,8 +5,8 @@
from .model import Model
from .losses import build_loss
-__all__ = ['build_loss', 'build_model']
-support_model = ['Model']
+__all__ = ["build_loss", "build_model"]
+support_model = ["Model"]
def build_model(config):
@@ -14,7 +14,9 @@ def build_model(config):
get architecture model class
"""
copy_config = copy.deepcopy(config)
- arch_type = copy_config.pop('type')
- assert arch_type in support_model, f'{arch_type} is not developed yet!, only {support_model} are support now'
+ arch_type = copy_config.pop("type")
+ assert (
+ arch_type in support_model
+ ), f"{arch_type} is not developed yet!, only {support_model} are support now"
arch_model = eval(arch_type)(copy_config)
return arch_model
diff --git a/benchmark/PaddleOCR_DBNet/models/backbone/__init__.py b/benchmark/PaddleOCR_DBNet/models/backbone/__init__.py
index 740c8d5ff0..5d0d8a2c60 100644
--- a/benchmark/PaddleOCR_DBNet/models/backbone/__init__.py
+++ b/benchmark/PaddleOCR_DBNet/models/backbone/__init__.py
@@ -4,15 +4,22 @@
from .resnet import *
-__all__ = ['build_backbone']
+__all__ = ["build_backbone"]
support_backbone = [
- 'resnet18', 'deformable_resnet18', 'deformable_resnet50', 'resnet50',
- 'resnet34', 'resnet101', 'resnet152'
+ "resnet18",
+ "deformable_resnet18",
+ "deformable_resnet50",
+ "resnet50",
+ "resnet34",
+ "resnet101",
+ "resnet152",
]
def build_backbone(backbone_name, **kwargs):
- assert backbone_name in support_backbone, f'all support backbone is {support_backbone}'
+ assert (
+ backbone_name in support_backbone
+ ), f"all support backbone is {support_backbone}"
backbone = eval(backbone_name)(**kwargs)
return backbone
diff --git a/benchmark/PaddleOCR_DBNet/models/backbone/resnet.py b/benchmark/PaddleOCR_DBNet/models/backbone/resnet.py
index 9b30b382d9..cb4ea809f9 100644
--- a/benchmark/PaddleOCR_DBNet/models/backbone/resnet.py
+++ b/benchmark/PaddleOCR_DBNet/models/backbone/resnet.py
@@ -5,40 +5,44 @@
BatchNorm2d = nn.BatchNorm2D
__all__ = [
- 'ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101',
- 'deformable_resnet18', 'deformable_resnet50', 'resnet152'
+ "ResNet",
+ "resnet18",
+ "resnet34",
+ "resnet50",
+ "resnet101",
+ "deformable_resnet18",
+ "deformable_resnet50",
+ "resnet152",
]
model_urls = {
- 'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth',
- 'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth',
- 'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth',
- 'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth',
- 'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth',
+ "resnet18": "https://download.pytorch.org/models/resnet18-5c106cde.pth",
+ "resnet34": "https://download.pytorch.org/models/resnet34-333f7ec4.pth",
+ "resnet50": "https://download.pytorch.org/models/resnet50-19c8e357.pth",
+ "resnet101": "https://download.pytorch.org/models/resnet101-5d3b4d8f.pth",
+ "resnet152": "https://download.pytorch.org/models/resnet152-b121ed2d.pth",
}
def constant_init(module, constant, bias=0):
module.weight = paddle.create_parameter(
shape=module.weight.shape,
- dtype='float32',
- default_initializer=paddle.nn.initializer.Constant(constant))
- if hasattr(module, 'bias'):
+ dtype="float32",
+ default_initializer=paddle.nn.initializer.Constant(constant),
+ )
+ if hasattr(module, "bias"):
module.bias = paddle.create_parameter(
shape=module.bias.shape,
- dtype='float32',
- default_initializer=paddle.nn.initializer.Constant(bias))
+ dtype="float32",
+ default_initializer=paddle.nn.initializer.Constant(bias),
+ )
def conv3x3(in_planes, out_planes, stride=1):
"""3x3 convolution with padding"""
return nn.Conv2D(
- in_planes,
- out_planes,
- kernel_size=3,
- stride=stride,
- padding=1,
- bias_attr=False)
+ in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias_attr=False
+ )
class BasicBlock(nn.Layer):
@@ -53,18 +57,19 @@ def __init__(self, inplanes, planes, stride=1, downsample=None, dcn=None):
self.with_modulated_dcn = False
if not self.with_dcn:
self.conv2 = nn.Conv2D(
- planes, planes, kernel_size=3, padding=1, bias_attr=False)
+ planes, planes, kernel_size=3, padding=1, bias_attr=False
+ )
else:
 from paddle.vision.ops import DeformConv2D
- deformable_groups = dcn.get('deformable_groups', 1)
+
+ deformable_groups = dcn.get("deformable_groups", 1)
offset_channels = 18
self.conv2_offset = nn.Conv2D(
- planes,
- deformable_groups * offset_channels,
- kernel_size=3,
- padding=1)
+ planes, deformable_groups * offset_channels, kernel_size=3, padding=1
+ )
self.conv2 = DeformConv2D(
- planes, planes, kernel_size=3, padding=1, bias_attr=False)
+ planes, planes, kernel_size=3, padding=1, bias_attr=False
+ )
self.bn2 = BatchNorm2d(planes, momentum=0.1)
self.downsample = downsample
self.stride = stride
@@ -104,32 +109,25 @@ def __init__(self, inplanes, planes, stride=1, downsample=None, dcn=None):
self.with_modulated_dcn = False
if not self.with_dcn:
self.conv2 = nn.Conv2D(
- planes,
- planes,
- kernel_size=3,
- stride=stride,
- padding=1,
- bias_attr=False)
+ planes, planes, kernel_size=3, stride=stride, padding=1, bias_attr=False
+ )
else:
- deformable_groups = dcn.get('deformable_groups', 1)
+ deformable_groups = dcn.get("deformable_groups", 1)
from paddle.vision.ops import DeformConv2D
+
offset_channels = 18
self.conv2_offset = nn.Conv2D(
planes,
deformable_groups * offset_channels,
stride=stride,
kernel_size=3,
- padding=1)
- self.conv2 = DeformConv2D(
- planes,
- planes,
- kernel_size=3,
padding=1,
- stride=stride,
- bias_attr=False)
+ )
+ self.conv2 = DeformConv2D(
+ planes, planes, kernel_size=3, padding=1, stride=stride, bias_attr=False
+ )
self.bn2 = BatchNorm2d(planes, momentum=0.1)
- self.conv3 = nn.Conv2D(
- planes, planes * 4, kernel_size=1, bias_attr=False)
+ self.conv3 = nn.Conv2D(planes, planes * 4, kernel_size=1, bias_attr=False)
self.bn3 = BatchNorm2d(planes * 4, momentum=0.1)
self.relu = nn.ReLU()
self.downsample = downsample
@@ -172,12 +170,8 @@ def __init__(self, block, layers, in_channels=3, dcn=None):
super(ResNet, self).__init__()
self.out_channels = []
self.conv1 = nn.Conv2D(
- in_channels,
- 64,
- kernel_size=7,
- stride=2,
- padding=3,
- bias_attr=False)
+ in_channels, 64, kernel_size=7, stride=2, padding=3, bias_attr=False
+ )
self.bn1 = BatchNorm2d(64, momentum=0.1)
self.relu = nn.ReLU()
self.maxpool = nn.MaxPool2D(kernel_size=3, stride=2, padding=1)
@@ -189,7 +183,7 @@ def __init__(self, block, layers, in_channels=3, dcn=None):
if self.dcn is not None:
for m in self.modules():
if isinstance(m, Bottleneck) or isinstance(m, BasicBlock):
- if hasattr(m, 'conv2_offset'):
+ if hasattr(m, "conv2_offset"):
constant_init(m.conv2_offset, 0)
def _make_layer(self, block, planes, blocks, stride=1, dcn=None):
@@ -201,9 +195,10 @@ def _make_layer(self, block, planes, blocks, stride=1, dcn=None):
planes * block.expansion,
kernel_size=1,
stride=stride,
- bias_attr=False),
- BatchNorm2d(
- planes * block.expansion, momentum=0.1), )
+ bias_attr=False,
+ ),
+ BatchNorm2d(planes * block.expansion, momentum=0.1),
+ )
layers = []
layers.append(block(self.inplanes, planes, stride, downsample, dcn=dcn))
@@ -230,12 +225,15 @@ def forward(self, x):
def load_torch_params(paddle_model, torch_patams):
paddle_params = paddle_model.state_dict()
- fc_names = ['classifier']
+ fc_names = ["classifier"]
for key, torch_value in torch_patams.items():
- if 'num_batches_tracked' in key:
+ if "num_batches_tracked" in key:
continue
- key = key.replace("running_var", "_variance").replace(
- "running_mean", "_mean").replace("module.", "")
+ key = (
+ key.replace("running_var", "_variance")
+ .replace("running_mean", "_mean")
+ .replace("module.", "")
+ )
torch_value = torch_value.detach().cpu().numpy()
if key in paddle_params:
flag = [i in key for i in fc_names]
@@ -247,12 +245,13 @@ def load_torch_params(paddle_model, torch_patams):
torch_value = torch_value.transpose(new_shape)
paddle_params[key] = torch_value
else:
- print(f'{key} not in paddle')
+ print(f"{key} not in paddle")
paddle_model.set_state_dict(paddle_params)
def load_models(model, model_name):
import torch.utils.model_zoo as model_zoo
+
torch_patams = model_zoo.load_url(model_urls[model_name])
load_torch_params(model, torch_patams)
@@ -264,11 +263,11 @@ def resnet18(pretrained=True, **kwargs):
"""
model = ResNet(BasicBlock, [2, 2, 2, 2], **kwargs)
if pretrained:
- assert kwargs.get(
- 'in_channels',
- 3) == 3, 'in_channels must be 3 whem pretrained is True'
- print('load from imagenet')
- load_models(model, 'resnet18')
+ assert (
+ kwargs.get("in_channels", 3) == 3
+ ), "in_channels must be 3 whem pretrained is True"
+ print("load from imagenet")
+ load_models(model, "resnet18")
return model
@@ -277,15 +276,13 @@ def deformable_resnet18(pretrained=True, **kwargs):
Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet
"""
- model = ResNet(
- BasicBlock, [2, 2, 2, 2], dcn=dict(deformable_groups=1), **kwargs)
+ model = ResNet(BasicBlock, [2, 2, 2, 2], dcn=dict(deformable_groups=1), **kwargs)
if pretrained:
- assert kwargs.get(
- 'in_channels',
- 3) == 3, 'in_channels must be 3 whem pretrained is True'
- print('load from imagenet')
- model.load_state_dict(
- model_zoo.load_url(model_urls['resnet18']), strict=False)
+ assert (
+ kwargs.get("in_channels", 3) == 3
+ ), "in_channels must be 3 whem pretrained is True"
+ print("load from imagenet")
+ model.load_state_dict(model_zoo.load_url(model_urls["resnet18"]), strict=False)
return model
@@ -296,11 +293,10 @@ def resnet34(pretrained=True, **kwargs):
"""
model = ResNet(BasicBlock, [3, 4, 6, 3], **kwargs)
if pretrained:
- assert kwargs.get(
- 'in_channels',
- 3) == 3, 'in_channels must be 3 whem pretrained is True'
- model.load_state_dict(
- model_zoo.load_url(model_urls['resnet34']), strict=False)
+ assert (
+ kwargs.get("in_channels", 3) == 3
+ ), "in_channels must be 3 whem pretrained is True"
+ model.load_state_dict(model_zoo.load_url(model_urls["resnet34"]), strict=False)
return model
@@ -311,10 +307,10 @@ def resnet50(pretrained=True, **kwargs):
"""
model = ResNet(Bottleneck, [3, 4, 6, 3], **kwargs)
if pretrained:
- assert kwargs.get(
- 'in_channels',
- 3) == 3, 'in_channels must be 3 whem pretrained is True'
- load_models(model, 'resnet50')
+ assert (
+ kwargs.get("in_channels", 3) == 3
+ ), "in_channels must be 3 whem pretrained is True"
+ load_models(model, "resnet50")
return model
@@ -323,14 +319,12 @@ def deformable_resnet50(pretrained=True, **kwargs):
Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet
"""
- model = ResNet(
- Bottleneck, [3, 4, 6, 3], dcn=dict(deformable_groups=1), **kwargs)
+ model = ResNet(Bottleneck, [3, 4, 6, 3], dcn=dict(deformable_groups=1), **kwargs)
if pretrained:
- assert kwargs.get(
- 'in_channels',
- 3) == 3, 'in_channels must be 3 whem pretrained is True'
- model.load_state_dict(
- model_zoo.load_url(model_urls['resnet50']), strict=False)
+ assert (
+ kwargs.get("in_channels", 3) == 3
+ ), "in_channels must be 3 whem pretrained is True"
+ model.load_state_dict(model_zoo.load_url(model_urls["resnet50"]), strict=False)
return model
@@ -341,11 +335,10 @@ def resnet101(pretrained=True, **kwargs):
"""
model = ResNet(Bottleneck, [3, 4, 23, 3], **kwargs)
if pretrained:
- assert kwargs.get(
- 'in_channels',
- 3) == 3, 'in_channels must be 3 whem pretrained is True'
- model.load_state_dict(
- model_zoo.load_url(model_urls['resnet101']), strict=False)
+ assert (
+ kwargs.get("in_channels", 3) == 3
+ ), "in_channels must be 3 whem pretrained is True"
+ model.load_state_dict(model_zoo.load_url(model_urls["resnet101"]), strict=False)
return model
@@ -356,16 +349,14 @@ def resnet152(pretrained=True, **kwargs):
"""
model = ResNet(Bottleneck, [3, 8, 36, 3], **kwargs)
if pretrained:
- assert kwargs.get(
- 'in_channels',
- 3) == 3, 'in_channels must be 3 whem pretrained is True'
- model.load_state_dict(
- model_zoo.load_url(model_urls['resnet152']), strict=False)
+ assert (
+ kwargs.get("in_channels", 3) == 3
+ ), "in_channels must be 3 whem pretrained is True"
+ model.load_state_dict(model_zoo.load_url(model_urls["resnet152"]), strict=False)
return model
-if __name__ == '__main__':
-
+if __name__ == "__main__":
x = paddle.zeros([2, 3, 640, 640])
net = resnet50(pretrained=True)
y = net(x)
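
`load_torch_params` above renames torch BatchNorm buffers to their Paddle equivalents, drops `num_batches_tracked`, and would transpose fully connected weights (these resnet backbones have none). The key translation in isolation, with illustrative keys:

```python
torch_keys = [
    "module.conv1.weight",
    "bn1.running_mean",
    "bn1.running_var",
    "bn1.num_batches_tracked",  # skipped: Paddle keeps no such buffer
]
for key in torch_keys:
    if "num_batches_tracked" in key:
        continue
    paddle_key = (
        key.replace("running_var", "_variance")
        .replace("running_mean", "_mean")
        .replace("module.", "")  # strip DataParallel prefix
    )
    print(key, "->", paddle_key)
```
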
diff --git a/benchmark/PaddleOCR_DBNet/models/basic.py b/benchmark/PaddleOCR_DBNet/models/basic.py
index f661878df7..49d6a6a901 100644
--- a/benchmark/PaddleOCR_DBNet/models/basic.py
+++ b/benchmark/PaddleOCR_DBNet/models/basic.py
@@ -5,17 +5,19 @@
class ConvBnRelu(nn.Layer):
- def __init__(self,
- in_channels,
- out_channels,
- kernel_size,
- stride=1,
- padding=0,
- dilation=1,
- groups=1,
- bias=True,
- padding_mode='zeros',
- inplace=True):
+ def __init__(
+ self,
+ in_channels,
+ out_channels,
+ kernel_size,
+ stride=1,
+ padding=0,
+ dilation=1,
+ groups=1,
+ bias=True,
+ padding_mode="zeros",
+ inplace=True,
+ ):
super().__init__()
self.conv = nn.Conv2D(
in_channels=in_channels,
@@ -26,7 +28,8 @@ def __init__(self,
dilation=dilation,
groups=groups,
bias_attr=bias,
- padding_mode=padding_mode)
+ padding_mode=padding_mode,
+ )
self.bn = nn.BatchNorm2D(out_channels)
self.relu = nn.ReLU()
diff --git a/benchmark/PaddleOCR_DBNet/models/head/DBHead.py b/benchmark/PaddleOCR_DBNet/models/head/DBHead.py
index 29277cec9d..3a57914bbd 100644
--- a/benchmark/PaddleOCR_DBNet/models/head/DBHead.py
+++ b/benchmark/PaddleOCR_DBNet/models/head/DBHead.py
@@ -15,32 +15,32 @@ def __init__(self, in_channels, out_channels, k=50):
in_channels // 4,
3,
padding=1,
- weight_attr=ParamAttr(
- initializer=nn.initializer.KaimingNormal())),
+ weight_attr=ParamAttr(initializer=nn.initializer.KaimingNormal()),
+ ),
nn.BatchNorm2D(
in_channels // 4,
weight_attr=ParamAttr(initializer=nn.initializer.Constant(1)),
- bias_attr=ParamAttr(initializer=nn.initializer.Constant(1e-4))),
+ bias_attr=ParamAttr(initializer=nn.initializer.Constant(1e-4)),
+ ),
nn.ReLU(),
nn.Conv2DTranspose(
in_channels // 4,
in_channels // 4,
2,
2,
- weight_attr=ParamAttr(
- initializer=nn.initializer.KaimingNormal())),
+ weight_attr=ParamAttr(initializer=nn.initializer.KaimingNormal()),
+ ),
nn.BatchNorm2D(
in_channels // 4,
weight_attr=ParamAttr(initializer=nn.initializer.Constant(1)),
- bias_attr=ParamAttr(initializer=nn.initializer.Constant(1e-4))),
+ bias_attr=ParamAttr(initializer=nn.initializer.Constant(1e-4)),
+ ),
nn.ReLU(),
nn.Conv2DTranspose(
- in_channels // 4,
- 1,
- 2,
- 2,
- weight_attr=nn.initializer.KaimingNormal()),
- nn.Sigmoid())
+ in_channels // 4, 1, 2, 2, weight_attr=nn.initializer.KaimingNormal()
+ ),
+ nn.Sigmoid(),
+ )
self.thresh = self._init_thresh(in_channels)
@@ -49,17 +49,12 @@ def forward(self, x):
threshold_maps = self.thresh(x)
if self.training:
binary_maps = self.step_function(shrink_maps, threshold_maps)
- y = paddle.concat(
- (shrink_maps, threshold_maps, binary_maps), axis=1)
+ y = paddle.concat((shrink_maps, threshold_maps, binary_maps), axis=1)
else:
y = paddle.concat((shrink_maps, threshold_maps), axis=1)
return y
- def _init_thresh(self,
- inner_channels,
- serial=False,
- smooth=False,
- bias=False):
+ def _init_thresh(self, inner_channels, serial=False, smooth=False, bias=False):
in_channels = inner_channels
if serial:
in_channels += 1
@@ -70,48 +65,44 @@ def _init_thresh(self,
3,
padding=1,
bias_attr=bias,
- weight_attr=ParamAttr(
- initializer=nn.initializer.KaimingNormal())),
+ weight_attr=ParamAttr(initializer=nn.initializer.KaimingNormal()),
+ ),
nn.BatchNorm2D(
inner_channels // 4,
weight_attr=ParamAttr(initializer=nn.initializer.Constant(1)),
- bias_attr=ParamAttr(initializer=nn.initializer.Constant(1e-4))),
+ bias_attr=ParamAttr(initializer=nn.initializer.Constant(1e-4)),
+ ),
nn.ReLU(),
self._init_upsample(
- inner_channels // 4,
- inner_channels // 4,
- smooth=smooth,
- bias=bias),
+ inner_channels // 4, inner_channels // 4, smooth=smooth, bias=bias
+ ),
nn.BatchNorm2D(
inner_channels // 4,
weight_attr=ParamAttr(initializer=nn.initializer.Constant(1)),
- bias_attr=ParamAttr(initializer=nn.initializer.Constant(1e-4))),
+ bias_attr=ParamAttr(initializer=nn.initializer.Constant(1e-4)),
+ ),
nn.ReLU(),
- self._init_upsample(
- inner_channels // 4, 1, smooth=smooth, bias=bias),
- nn.Sigmoid())
+ self._init_upsample(inner_channels // 4, 1, smooth=smooth, bias=bias),
+ nn.Sigmoid(),
+ )
return self.thresh
- def _init_upsample(self,
- in_channels,
- out_channels,
- smooth=False,
- bias=False):
+ def _init_upsample(self, in_channels, out_channels, smooth=False, bias=False):
if smooth:
inter_out_channels = out_channels
if out_channels == 1:
inter_out_channels = in_channels
module_list = [
- nn.Upsample(
- scale_factor=2, mode='nearest'), nn.Conv2D(
- in_channels,
- inter_out_channels,
- 3,
- 1,
- 1,
- bias_attr=bias,
- weight_attr=ParamAttr(
- initializer=nn.initializer.KaimingNormal()))
+ nn.Upsample(scale_factor=2, mode="nearest"),
+ nn.Conv2D(
+ in_channels,
+ inter_out_channels,
+ 3,
+ 1,
+ 1,
+ bias_attr=bias,
+ weight_attr=ParamAttr(initializer=nn.initializer.KaimingNormal()),
+ ),
]
if out_channels == 1:
module_list.append(
@@ -123,7 +114,10 @@ def _init_upsample(self,
padding=1,
bias_attr=True,
weight_attr=ParamAttr(
- initializer=nn.initializer.KaimingNormal())))
+ initializer=nn.initializer.KaimingNormal()
+ ),
+ )
+ )
return nn.Sequential(module_list)
else:
return nn.Conv2DTranspose(
@@ -131,8 +125,8 @@ def _init_upsample(self,
out_channels,
2,
2,
- weight_attr=ParamAttr(
- initializer=nn.initializer.KaimingNormal()))
+ weight_attr=ParamAttr(initializer=nn.initializer.KaimingNormal()),
+ )
def step_function(self, x, y):
return paddle.reciprocal(1 + paddle.exp(-self.k * (x - y)))
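
`step_function` at the end of the hunk is a steep sigmoid, 1 / (1 + exp(−k (x − y))) with k = 50 — the differentiable stand-in for hard binarization of the shrink map x at the threshold map y during training. A quick numeric check:

```python
import numpy as np

def step_function(x, y, k=50):
    # reciprocal(1 + exp(-k (x - y))) == sigmoid(k (x - y))
    return 1.0 / (1.0 + np.exp(-k * (x - y)))

print(step_function(np.array([0.2, 0.5, 0.8]), 0.5))  # ~[0, 0.5, 1]
```
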
diff --git a/benchmark/PaddleOCR_DBNet/models/head/__init__.py b/benchmark/PaddleOCR_DBNet/models/head/__init__.py
index 5610c69754..708ea9afd8 100644
--- a/benchmark/PaddleOCR_DBNet/models/head/__init__.py
+++ b/benchmark/PaddleOCR_DBNet/models/head/__init__.py
@@ -3,11 +3,11 @@
# @Author : zhoujun
from .DBHead import DBHead
-__all__ = ['build_head']
-support_head = ['DBHead']
+__all__ = ["build_head"]
+support_head = ["DBHead"]
def build_head(head_name, **kwargs):
- assert head_name in support_head, f'all support head is {support_head}'
+ assert head_name in support_head, f"all support head is {support_head}"
head = eval(head_name)(**kwargs)
- return head
\ No newline at end of file
+ return head
diff --git a/benchmark/PaddleOCR_DBNet/models/losses/DB_loss.py b/benchmark/PaddleOCR_DBNet/models/losses/DB_loss.py
index 74d240c17b..bad05b697a 100644
--- a/benchmark/PaddleOCR_DBNet/models/losses/DB_loss.py
+++ b/benchmark/PaddleOCR_DBNet/models/losses/DB_loss.py
@@ -3,12 +3,7 @@
class DBLoss(paddle.nn.Layer):
- def __init__(self,
- alpha=1.0,
- beta=10,
- ohem_ratio=3,
- reduction='mean',
- eps=1e-06):
+ def __init__(self, alpha=1.0, beta=10, ohem_ratio=3, reduction="mean", eps=1e-06):
"""
 Implement DB loss.
 :param alpha: coefficient in front of the binary_map loss
@@ -17,7 +12,7 @@ def __init__(self,
 :param reduction: 'mean' or 'sum', take the mean or the sum of the loss over the batch
"""
super().__init__()
- assert reduction in ['mean', 'sum'], " reduction must in ['mean','sum']"
+ assert reduction in ["mean", "sum"], " reduction must in ['mean','sum']"
self.alpha = alpha
self.beta = beta
self.bce_loss = BalanceCrossEntropyLoss(negative_ratio=ohem_ratio)
@@ -30,20 +25,26 @@ def forward(self, pred, batch):
shrink_maps = pred[:, 0, :, :]
threshold_maps = pred[:, 1, :, :]
binary_maps = pred[:, 2, :, :]
- loss_shrink_maps = self.bce_loss(shrink_maps, batch['shrink_map'],
- batch['shrink_mask'])
+ loss_shrink_maps = self.bce_loss(
+ shrink_maps, batch["shrink_map"], batch["shrink_mask"]
+ )
loss_threshold_maps = self.l1_loss(
- threshold_maps, batch['threshold_map'], batch['threshold_mask'])
+ threshold_maps, batch["threshold_map"], batch["threshold_mask"]
+ )
metrics = dict(
- loss_shrink_maps=loss_shrink_maps,
- loss_threshold_maps=loss_threshold_maps)
+ loss_shrink_maps=loss_shrink_maps, loss_threshold_maps=loss_threshold_maps
+ )
if pred.shape[1] > 2:
- loss_binary_maps = self.dice_loss(binary_maps, batch['shrink_map'],
- batch['shrink_mask'])
- metrics['loss_binary_maps'] = loss_binary_maps
- loss_all = (self.alpha * loss_shrink_maps + self.beta *
- loss_threshold_maps + loss_binary_maps)
- metrics['loss'] = loss_all
+ loss_binary_maps = self.dice_loss(
+ binary_maps, batch["shrink_map"], batch["shrink_mask"]
+ )
+ metrics["loss_binary_maps"] = loss_binary_maps
+ loss_all = (
+ self.alpha * loss_shrink_maps
+ + self.beta * loss_threshold_maps
+ + loss_binary_maps
+ )
+ metrics["loss"] = loss_all
else:
- metrics['loss'] = loss_shrink_maps
+ metrics["loss"] = loss_shrink_maps
return metrics
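
The branch above combines the three maps as loss = α · L_shrink + β · L_thresh + L_binary, with the defaults α = 1.0 and β = 10 from `__init__`. With illustrative per-map values:

```python
alpha, beta = 1.0, 10
loss_shrink, loss_thresh, loss_binary = 0.7, 0.05, 0.4  # illustrative values
loss_all = alpha * loss_shrink + beta * loss_thresh + loss_binary
print(loss_all)  # 0.7 + 0.5 + 0.4 = 1.6
```
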
diff --git a/benchmark/PaddleOCR_DBNet/models/losses/__init__.py b/benchmark/PaddleOCR_DBNet/models/losses/__init__.py
index 9dc0f1033b..e783f7ce1e 100644
--- a/benchmark/PaddleOCR_DBNet/models/losses/__init__.py
+++ b/benchmark/PaddleOCR_DBNet/models/losses/__init__.py
@@ -4,13 +4,13 @@
import copy
from .DB_loss import DBLoss
-__all__ = ['build_loss']
-support_loss = ['DBLoss']
+__all__ = ["build_loss"]
+support_loss = ["DBLoss"]
def build_loss(config):
copy_config = copy.deepcopy(config)
- loss_type = copy_config.pop('type')
- assert loss_type in support_loss, f'all support loss is {support_loss}'
+ loss_type = copy_config.pop("type")
+ assert loss_type in support_loss, f"supported losses are {support_loss}"
criterion = eval(loss_type)(**copy_config)
return criterion
diff --git a/benchmark/PaddleOCR_DBNet/models/losses/basic_loss.py b/benchmark/PaddleOCR_DBNet/models/losses/basic_loss.py
index 8e68cb172a..27ac520093 100644
--- a/benchmark/PaddleOCR_DBNet/models/losses/basic_loss.py
+++ b/benchmark/PaddleOCR_DBNet/models/losses/basic_loss.py
@@ -6,7 +6,7 @@
class BalanceCrossEntropyLoss(nn.Layer):
- '''
+ """
Balanced cross entropy loss.
Shape:
- Input: :math:`(N, 1, H, W)`
@@ -14,36 +14,40 @@ class BalanceCrossEntropyLoss(nn.Layer):
- Mask: :math:`(N, H, W)`, same spatial shape as the input
- Output: scalar.
- '''
+ """
def __init__(self, negative_ratio=3.0, eps=1e-6):
super(BalanceCrossEntropyLoss, self).__init__()
self.negative_ratio = negative_ratio
self.eps = eps
- def forward(self,
- pred: paddle.Tensor,
- gt: paddle.Tensor,
- mask: paddle.Tensor,
- return_origin=False):
- '''
+ def forward(
+ self,
+ pred: paddle.Tensor,
+ gt: paddle.Tensor,
+ mask: paddle.Tensor,
+ return_origin=False,
+ ):
+ """
Args:
pred: shape :math:`(N, 1, H, W)`, the prediction of network
gt: shape :math:`(N, 1, H, W)`, the target
mask: shape :math:`(N, H, W)`, the mask indicates positive regions
- '''
- positive = (gt * mask)
- negative = ((1 - gt) * mask)
+ """
+ positive = gt * mask
+ negative = (1 - gt) * mask
positive_count = int(positive.sum())
negative_count = min(
- int(negative.sum()), int(positive_count * self.negative_ratio))
- loss = nn.functional.binary_cross_entropy(pred, gt, reduction='none')
+ int(negative.sum()), int(positive_count * self.negative_ratio)
+ )
+ loss = nn.functional.binary_cross_entropy(pred, gt, reduction="none")
positive_loss = loss * positive
negative_loss = loss * negative
negative_loss, _ = negative_loss.reshape([-1]).topk(negative_count)
balance_loss = (positive_loss.sum() + negative_loss.sum()) / (
- positive_count + negative_count + self.eps)
+ positive_count + negative_count + self.eps
+ )
if return_origin:
return balance_loss, loss
@@ -51,23 +55,23 @@ def forward(self,
class DiceLoss(nn.Layer):
- '''
+ """
Loss function from https://arxiv.org/abs/1707.03237,
 where the IoU computation is introduced in a heatmap manner to measure the
 diversity between two heatmaps.
- '''
+ """
def __init__(self, eps=1e-6):
super(DiceLoss, self).__init__()
self.eps = eps
def forward(self, pred: paddle.Tensor, gt, mask, weights=None):
- '''
+ """
pred: one or two heatmaps of shape (N, 1, H, W),
 the losses of the two heatmaps are added together.
gt: (N, 1, H, W)
mask: (N, H, W)
- '''
+ """
return self._compute(pred, gt, mask, weights)
def _compute(self, pred, gt, mask, weights):
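
The balanced BCE above caps negatives at `negative_ratio` times the positive count and keeps only the hardest negatives via `topk` before averaging. A minimal runnable trace (the tensors are illustrative):

```python
import paddle
import paddle.nn as nn

negative_ratio, eps = 3.0, 1e-6
pred = paddle.to_tensor([[0.9, 0.4, 0.2, 0.8]])
gt = paddle.to_tensor([[1.0, 0.0, 0.0, 0.0]])
mask = paddle.ones_like(gt)
positive = gt * mask
negative = (1 - gt) * mask
positive_count = int(positive.sum())
# negatives capped at negative_ratio x positives: hard-negative mining
negative_count = min(int(negative.sum()), int(positive_count * negative_ratio))
loss = nn.functional.binary_cross_entropy(pred, gt, reduction="none")
negative_loss, _ = (loss * negative).reshape([-1]).topk(negative_count)
balance_loss = (float((loss * positive).sum()) + float(negative_loss.sum())) / (
    positive_count + negative_count + eps
)
print(balance_loss)
```
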
diff --git a/benchmark/PaddleOCR_DBNet/models/model.py b/benchmark/PaddleOCR_DBNet/models/model.py
index ee24ff5b3d..ed36f9072e 100644
--- a/benchmark/PaddleOCR_DBNet/models/model.py
+++ b/benchmark/PaddleOCR_DBNet/models/model.py
@@ -18,22 +18,22 @@ def __init__(self, model_config: dict):
"""
super().__init__()
model_config = Dict(model_config)
- backbone_type = model_config.backbone.pop('type')
- neck_type = model_config.neck.pop('type')
- head_type = model_config.head.pop('type')
+ backbone_type = model_config.backbone.pop("type")
+ neck_type = model_config.neck.pop("type")
+ head_type = model_config.head.pop("type")
self.backbone = build_backbone(backbone_type, **model_config.backbone)
self.neck = build_neck(
- neck_type,
- in_channels=self.backbone.out_channels,
- **model_config.neck)
+ neck_type, in_channels=self.backbone.out_channels, **model_config.neck
+ )
self.head = build_head(
- head_type, in_channels=self.neck.out_channels, **model_config.head)
- self.name = f'{backbone_type}_{neck_type}_{head_type}'
+ head_type, in_channels=self.neck.out_channels, **model_config.head
+ )
+ self.name = f"{backbone_type}_{neck_type}_{head_type}"
def forward(self, x):
_, _, H, W = x.shape
backbone_out = self.backbone(x)
neck_out = self.neck(backbone_out)
y = self.head(neck_out)
- y = F.interpolate(y, size=(H, W), mode='bilinear', align_corners=True)
+ y = F.interpolate(y, size=(H, W), mode="bilinear", align_corners=True)
return y
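
`Model.__init__` above pops `type` from each sub-config and forwards the rest as keyword arguments, so a config would be shaped like this hypothetical dict (the type names appear elsewhere in this diff; the argument values are guesses):

```python
model_config = {
    "backbone": {"type": "resnet18", "pretrained": True, "in_channels": 3},
    "neck": {"type": "FPN", "inner_channels": 256},
    "head": {"type": "DBHead", "out_channels": 2, "k": 50},
}
for part in ("backbone", "neck", "head"):
    cfg = dict(model_config[part])
    part_type = cfg.pop("type")  # what build_backbone/build_neck/build_head receive
    print(part_type, cfg)        # remaining keys become **kwargs
```
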
diff --git a/benchmark/PaddleOCR_DBNet/models/neck/FPN.py b/benchmark/PaddleOCR_DBNet/models/neck/FPN.py
index 53a3fa4b80..3c49adf679 100644
--- a/benchmark/PaddleOCR_DBNet/models/neck/FPN.py
+++ b/benchmark/PaddleOCR_DBNet/models/neck/FPN.py
@@ -20,42 +20,33 @@ def __init__(self, in_channels, inner_channels=256, **kwargs):
inner_channels = inner_channels // 4
# reduce layers
self.reduce_conv_c2 = ConvBnRelu(
- in_channels[0], inner_channels, kernel_size=1, inplace=inplace)
+ in_channels[0], inner_channels, kernel_size=1, inplace=inplace
+ )
self.reduce_conv_c3 = ConvBnRelu(
- in_channels[1], inner_channels, kernel_size=1, inplace=inplace)
+ in_channels[1], inner_channels, kernel_size=1, inplace=inplace
+ )
self.reduce_conv_c4 = ConvBnRelu(
- in_channels[2], inner_channels, kernel_size=1, inplace=inplace)
+ in_channels[2], inner_channels, kernel_size=1, inplace=inplace
+ )
self.reduce_conv_c5 = ConvBnRelu(
- in_channels[3], inner_channels, kernel_size=1, inplace=inplace)
+ in_channels[3], inner_channels, kernel_size=1, inplace=inplace
+ )
# Smooth layers
self.smooth_p4 = ConvBnRelu(
- inner_channels,
- inner_channels,
- kernel_size=3,
- padding=1,
- inplace=inplace)
+ inner_channels, inner_channels, kernel_size=3, padding=1, inplace=inplace
+ )
self.smooth_p3 = ConvBnRelu(
- inner_channels,
- inner_channels,
- kernel_size=3,
- padding=1,
- inplace=inplace)
+ inner_channels, inner_channels, kernel_size=3, padding=1, inplace=inplace
+ )
self.smooth_p2 = ConvBnRelu(
- inner_channels,
- inner_channels,
- kernel_size=3,
- padding=1,
- inplace=inplace)
+ inner_channels, inner_channels, kernel_size=3, padding=1, inplace=inplace
+ )
self.conv = nn.Sequential(
- nn.Conv2D(
- self.conv_out,
- self.conv_out,
- kernel_size=3,
- padding=1,
- stride=1),
+ nn.Conv2D(self.conv_out, self.conv_out, kernel_size=3, padding=1, stride=1),
nn.BatchNorm2D(self.conv_out),
- nn.ReLU())
+ nn.ReLU(),
+ )
self.out_channels = self.conv_out
def forward(self, x):
diff --git a/benchmark/PaddleOCR_DBNet/models/neck/__init__.py b/benchmark/PaddleOCR_DBNet/models/neck/__init__.py
index 7655341378..d63b8b5186 100644
--- a/benchmark/PaddleOCR_DBNet/models/neck/__init__.py
+++ b/benchmark/PaddleOCR_DBNet/models/neck/__init__.py
@@ -3,11 +3,11 @@
# @Author : zhoujun
from .FPN import FPN
-__all__ = ['build_neck']
-support_neck = ['FPN']
+__all__ = ["build_neck"]
+support_neck = ["FPN"]
def build_neck(neck_name, **kwargs):
- assert neck_name in support_neck, f'all support neck is {support_neck}'
+    assert neck_name in support_neck, f"supported necks are {support_neck}"
neck = eval(neck_name)(**kwargs)
return neck
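[Editor's aside: build_neck resolves the class name with eval(), which is safe here only because support_neck whitelists the input. An explicit registry dict is a common alternative; a sketch under that design (FPN imported at the top of this module as above):

    neck_registry = {"FPN": FPN}

    def build_neck(neck_name, **kwargs):
        assert neck_name in neck_registry, f"supported necks are {list(neck_registry)}"
        return neck_registry[neck_name](**kwargs)
]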
diff --git a/benchmark/PaddleOCR_DBNet/post_processing/__init__.py b/benchmark/PaddleOCR_DBNet/post_processing/__init__.py
index 2f8e43223d..7d3a448ced 100644
--- a/benchmark/PaddleOCR_DBNet/post_processing/__init__.py
+++ b/benchmark/PaddleOCR_DBNet/post_processing/__init__.py
@@ -7,7 +7,7 @@
def get_post_processing(config):
try:
- cls = eval(config['type'])(**config['args'])
+ cls = eval(config["type"])(**config["args"])
return cls
except:
- return None
\ No newline at end of file
+ return None
diff --git a/benchmark/PaddleOCR_DBNet/post_processing/seg_detector_representer.py b/benchmark/PaddleOCR_DBNet/post_processing/seg_detector_representer.py
index f1273dcfcc..2ee3411e0b 100644
--- a/benchmark/PaddleOCR_DBNet/post_processing/seg_detector_representer.py
+++ b/benchmark/PaddleOCR_DBNet/post_processing/seg_detector_representer.py
@@ -5,12 +5,10 @@
from shapely.geometry import Polygon
-class SegDetectorRepresenter():
- def __init__(self,
- thresh=0.3,
- box_thresh=0.7,
- max_candidates=1000,
- unclip_ratio=1.5):
+class SegDetectorRepresenter:
+ def __init__(
+ self, thresh=0.3, box_thresh=0.7, max_candidates=1000, unclip_ratio=1.5
+ ):
self.min_size = 3
self.thresh = thresh
self.box_thresh = box_thresh
@@ -18,7 +16,7 @@ def __init__(self,
self.unclip_ratio = unclip_ratio
def __call__(self, batch, pred, is_output_polygon=False):
- '''
+ """
        batch: (image, polygons, ignore_tags)
batch: a dict produced by dataloaders.
image: tensor of shape (N, C, H, W).
@@ -30,7 +28,7 @@ def __call__(self, batch, pred, is_output_polygon=False):
binary: text region segmentation map, with shape (N, H, W)
            thresh: [if exists] threshold prediction with shape (N, H, W)
            thresh_binary: [if exists] binarized with threshold, (N, H, W)
- '''
+ """
if isinstance(pred, paddle.Tensor):
pred = pred.numpy()
pred = pred[:, 0, :, :]
@@ -38,13 +36,15 @@ def __call__(self, batch, pred, is_output_polygon=False):
boxes_batch = []
scores_batch = []
for batch_index in range(pred.shape[0]):
- height, width = batch['shape'][batch_index]
+ height, width = batch["shape"][batch_index]
if is_output_polygon:
boxes, scores = self.polygons_from_bitmap(
- pred[batch_index], segmentation[batch_index], width, height)
+ pred[batch_index], segmentation[batch_index], width, height
+ )
else:
boxes, scores = self.boxes_from_bitmap(
- pred[batch_index], segmentation[batch_index], width, height)
+ pred[batch_index], segmentation[batch_index], width, height
+ )
boxes_batch.append(boxes)
scores_batch.append(scores)
return boxes_batch, scores_batch
@@ -53,10 +53,10 @@ def binarize(self, pred):
return pred > self.thresh
def polygons_from_bitmap(self, pred, _bitmap, dest_width, dest_height):
- '''
+ """
_bitmap: single map with shape (H, W),
whose values are binarized as {0, 1}
- '''
+ """
assert len(_bitmap.shape) == 2
bitmap = _bitmap # The first channel
@@ -64,10 +64,11 @@ def polygons_from_bitmap(self, pred, _bitmap, dest_width, dest_height):
boxes = []
scores = []
- contours, _ = cv2.findContours((bitmap * 255).astype(np.uint8),
- cv2.RETR_LIST, cv2.CHAIN_APPROX_SIMPLE)
+ contours, _ = cv2.findContours(
+ (bitmap * 255).astype(np.uint8), cv2.RETR_LIST, cv2.CHAIN_APPROX_SIMPLE
+ )
- for contour in contours[:self.max_candidates]:
+ for contour in contours[: self.max_candidates]:
epsilon = 0.005 * cv2.arcLength(contour, True)
approx = cv2.approxPolyDP(contour, epsilon, True)
points = approx.reshape((-1, 2))
@@ -95,28 +96,29 @@ def polygons_from_bitmap(self, pred, _bitmap, dest_width, dest_height):
dest_width = dest_width.item()
dest_height = dest_height.item()
- box[:, 0] = np.clip(
- np.round(box[:, 0] / width * dest_width), 0, dest_width)
+ box[:, 0] = np.clip(np.round(box[:, 0] / width * dest_width), 0, dest_width)
box[:, 1] = np.clip(
- np.round(box[:, 1] / height * dest_height), 0, dest_height)
+ np.round(box[:, 1] / height * dest_height), 0, dest_height
+ )
boxes.append(box)
scores.append(score)
return boxes, scores
def boxes_from_bitmap(self, pred, _bitmap, dest_width, dest_height):
- '''
+ """
_bitmap: single map with shape (H, W),
whose values are binarized as {0, 1}
- '''
+ """
assert len(_bitmap.shape) == 2
bitmap = _bitmap # The first channel
height, width = bitmap.shape
- contours, _ = cv2.findContours((bitmap * 255).astype(np.uint8),
- cv2.RETR_LIST, cv2.CHAIN_APPROX_SIMPLE)
+ contours, _ = cv2.findContours(
+ (bitmap * 255).astype(np.uint8), cv2.RETR_LIST, cv2.CHAIN_APPROX_SIMPLE
+ )
num_contours = min(len(contours), self.max_candidates)
boxes = np.zeros((num_contours, 4, 2), dtype=np.int16)
- scores = np.zeros((num_contours, ), dtype=np.float32)
+ scores = np.zeros((num_contours,), dtype=np.float32)
for index in range(num_contours):
contour = contours[index].squeeze(1)
@@ -128,8 +130,7 @@ def boxes_from_bitmap(self, pred, _bitmap, dest_width, dest_height):
if self.box_thresh > score:
continue
- box = self.unclip(
- points, unclip_ratio=self.unclip_ratio).reshape(-1, 1, 2)
+ box = self.unclip(points, unclip_ratio=self.unclip_ratio).reshape(-1, 1, 2)
box, sside = self.get_mini_boxes(box)
if sside < self.min_size + 2:
continue
@@ -138,10 +139,10 @@ def boxes_from_bitmap(self, pred, _bitmap, dest_width, dest_height):
dest_width = dest_width.item()
dest_height = dest_height.item()
- box[:, 0] = np.clip(
- np.round(box[:, 0] / width * dest_width), 0, dest_width)
+ box[:, 0] = np.clip(np.round(box[:, 0] / width * dest_width), 0, dest_width)
box[:, 1] = np.clip(
- np.round(box[:, 1] / height * dest_height), 0, dest_height)
+ np.round(box[:, 1] / height * dest_height), 0, dest_height
+ )
boxes[index, :, :] = box.astype(np.int16)
scores[index] = score
return boxes, scores
@@ -172,9 +173,7 @@ def get_mini_boxes(self, contour):
index_2 = 3
index_3 = 2
- box = [
- points[index_1], points[index_2], points[index_3], points[index_4]
- ]
+ box = [points[index_1], points[index_2], points[index_3], points[index_4]]
return box, min(bounding_box[1])
def box_score_fast(self, bitmap, _box):
@@ -189,4 +188,4 @@ def box_score_fast(self, bitmap, _box):
box[:, 0] = box[:, 0] - xmin
box[:, 1] = box[:, 1] - ymin
cv2.fillPoly(mask, box.reshape(1, -1, 2).astype(np.int32), 1)
- return cv2.mean(bitmap[ymin:ymax + 1, xmin:xmax + 1], mask)[0]
+ return cv2.mean(bitmap[ymin : ymax + 1, xmin : xmax + 1], mask)[0]
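[Editor's aside: unclip() is called above but defined outside these hunks. DBNet-style implementations typically dilate the shrunk contour with pyclipper, choosing the offset distance from the polygon's area/perimeter ratio; a sketch under that assumption (box taken as an (N, 2) array of points):

    import numpy as np
    import pyclipper
    from shapely.geometry import Polygon

    def unclip(box, unclip_ratio=1.5):
        poly = Polygon(box)
        distance = poly.area * unclip_ratio / poly.length  # length == perimeter
        offset = pyclipper.PyclipperOffset()
        offset.AddPath(box.tolist(), pyclipper.JT_ROUND, pyclipper.ET_CLOSEDPOLYGON)
        expanded = offset.Execute(distance)
        return np.array(expanded[0]) if expanded else np.zeros((0, 2))
]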
diff --git a/benchmark/PaddleOCR_DBNet/tools/__init__.py b/benchmark/PaddleOCR_DBNet/tools/__init__.py
index 7cbf835d7e..aa5fe6c395 100644
--- a/benchmark/PaddleOCR_DBNet/tools/__init__.py
+++ b/benchmark/PaddleOCR_DBNet/tools/__init__.py
@@ -1,3 +1,3 @@
# -*- coding: utf-8 -*-
# @Time : 2019/12/8 13:14
-# @Author : zhoujun
\ No newline at end of file
+# @Author : zhoujun
diff --git a/benchmark/PaddleOCR_DBNet/tools/eval.py b/benchmark/PaddleOCR_DBNet/tools/eval.py
index fe514ddc0d..21a5a4ad74 100644
--- a/benchmark/PaddleOCR_DBNet/tools/eval.py
+++ b/benchmark/PaddleOCR_DBNet/tools/eval.py
@@ -4,6 +4,7 @@
import os
import sys
import pathlib
+
__dir__ = pathlib.Path(os.path.abspath(__file__))
sys.path.append(str(__dir__))
sys.path.append(str(__dir__.parent.parent))
@@ -14,30 +15,35 @@
from tqdm.auto import tqdm
-class EVAL():
+class EVAL:
def __init__(self, model_path, gpu_id=0):
from models import build_model
from data_loader import get_dataloader
from post_processing import get_post_processing
from utils import get_metric
+
self.gpu_id = gpu_id
- if self.gpu_id is not None and isinstance(
- self.gpu_id, int) and paddle.device.is_compiled_with_cuda():
+ if (
+ self.gpu_id is not None
+ and isinstance(self.gpu_id, int)
+ and paddle.device.is_compiled_with_cuda()
+ ):
paddle.device.set_device("gpu:{}".format(self.gpu_id))
else:
paddle.device.set_device("cpu")
checkpoint = paddle.load(model_path)
- config = checkpoint['config']
- config['arch']['backbone']['pretrained'] = False
+ config = checkpoint["config"]
+ config["arch"]["backbone"]["pretrained"] = False
- self.validate_loader = get_dataloader(config['dataset']['validate'],
- config['distributed'])
+ self.validate_loader = get_dataloader(
+ config["dataset"]["validate"], config["distributed"]
+ )
- self.model = build_model(config['arch'])
- self.model.set_state_dict(checkpoint['state_dict'])
+ self.model = build_model(config["arch"])
+ self.model.set_state_dict(checkpoint["state_dict"])
- self.post_process = get_post_processing(config['post_processing'])
- self.metric_cls = get_metric(config['metric'])
+ self.post_process = get_post_processing(config["post_processing"])
+ self.metric_cls = get_metric(config["metric"])
def eval(self):
self.model.eval()
@@ -45,42 +51,42 @@ def eval(self):
total_frame = 0.0
total_time = 0.0
for i, batch in tqdm(
- enumerate(self.validate_loader),
- total=len(self.validate_loader),
- desc='test model'):
+ enumerate(self.validate_loader),
+ total=len(self.validate_loader),
+ desc="test model",
+ ):
with paddle.no_grad():
start = time.time()
- preds = self.model(batch['img'])
+ preds = self.model(batch["img"])
boxes, scores = self.post_process(
- batch,
- preds,
- is_output_polygon=self.metric_cls.is_output_polygon)
- total_frame += batch['img'].shape[0]
+ batch, preds, is_output_polygon=self.metric_cls.is_output_polygon
+ )
+ total_frame += batch["img"].shape[0]
total_time += time.time() - start
- raw_metric = self.metric_cls.validate_measure(batch,
- (boxes, scores))
+ raw_metric = self.metric_cls.validate_measure(batch, (boxes, scores))
raw_metrics.append(raw_metric)
metrics = self.metric_cls.gather_measure(raw_metrics)
- print('FPS:{}'.format(total_frame / total_time))
+ print("FPS:{}".format(total_frame / total_time))
return {
- 'recall': metrics['recall'].avg,
- 'precision': metrics['precision'].avg,
- 'fmeasure': metrics['fmeasure'].avg
+ "recall": metrics["recall"].avg,
+ "precision": metrics["precision"].avg,
+ "fmeasure": metrics["fmeasure"].avg,
}
def init_args():
- parser = argparse.ArgumentParser(description='DBNet.paddle')
+ parser = argparse.ArgumentParser(description="DBNet.paddle")
parser.add_argument(
- '--model_path',
+ "--model_path",
required=False,
- default='output/DBNet_resnet18_FPN_DBHead/checkpoint/1.pth',
- type=str)
+ default="output/DBNet_resnet18_FPN_DBHead/checkpoint/1.pth",
+ type=str,
+ )
args = parser.parse_args()
return args
-if __name__ == '__main__':
+if __name__ == "__main__":
args = init_args()
eval = EVAL(args.model_path)
result = eval.eval()
diff --git a/benchmark/PaddleOCR_DBNet/tools/export_model.py b/benchmark/PaddleOCR_DBNet/tools/export_model.py
index 59a318a196..71a33d0936 100644
--- a/benchmark/PaddleOCR_DBNet/tools/export_model.py
+++ b/benchmark/PaddleOCR_DBNet/tools/export_model.py
@@ -26,13 +26,13 @@ def load_checkpoint(model, checkpoint_path):
:param checkpoint_path: Checkpoint path to be loaded
"""
checkpoint = paddle.load(checkpoint_path)
- model.set_state_dict(checkpoint['state_dict'])
- print('load checkpoint from {}'.format(checkpoint_path))
+ model.set_state_dict(checkpoint["state_dict"])
+ print("load checkpoint from {}".format(checkpoint_path))
def main(config):
- model = build_model(config['arch'])
- load_checkpoint(model, config['trainer']['resume_checkpoint'])
+ model = build_model(config["arch"])
+ load_checkpoint(model, config["trainer"]["resume_checkpoint"])
model.eval()
save_path = config["trainer"]["output_dir"]
@@ -41,9 +41,9 @@ def main(config):
model = to_static(
model,
input_spec=[
- paddle.static.InputSpec(
- shape=[None] + infer_shape, dtype="float32")
- ])
+ paddle.static.InputSpec(shape=[None] + infer_shape, dtype="float32")
+ ],
+ )
paddle.jit.save(model, save_path)
print("inference model is saved to {}".format(save_path))
diff --git a/benchmark/PaddleOCR_DBNet/tools/infer.py b/benchmark/PaddleOCR_DBNet/tools/infer.py
index 24e919c33f..5ed4b8e948 100644
--- a/benchmark/PaddleOCR_DBNet/tools/infer.py
+++ b/benchmark/PaddleOCR_DBNet/tools/infer.py
@@ -15,6 +15,7 @@
import os
import sys
import pathlib
+
__dir__ = pathlib.Path(os.path.abspath(__file__))
sys.path.append(str(__dir__))
sys.path.append(str(__dir__.parent.parent))
@@ -33,7 +34,7 @@
class InferenceEngine(object):
"""InferenceEngine
-
+
    Inference engine class that contains the preprocess, run, and postprocess steps
"""
@@ -47,35 +48,48 @@ def __init__(self, args):
self.args = args
# init inference engine
- self.predictor, self.config, self.input_tensor, self.output_tensor = self.load_predictor(
+ (
+ self.predictor,
+ self.config,
+ self.input_tensor,
+ self.output_tensor,
+ ) = self.load_predictor(
os.path.join(args.model_dir, "inference.pdmodel"),
- os.path.join(args.model_dir, "inference.pdiparams"))
+ os.path.join(args.model_dir, "inference.pdiparams"),
+ )
# build transforms
- self.transforms = transforms.Compose([
- transforms.ToTensor(), transforms.Normalize(
- mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
- ])
+ self.transforms = transforms.Compose(
+ [
+ transforms.ToTensor(),
+ transforms.Normalize(
+ mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
+ ),
+ ]
+ )
        # warmup
if self.args.warmup > 0:
for idx in range(args.warmup):
print(idx)
- x = np.random.rand(1, 3, self.args.crop_size,
- self.args.crop_size).astype("float32")
+ x = np.random.rand(
+ 1, 3, self.args.crop_size, self.args.crop_size
+ ).astype("float32")
self.input_tensor.copy_from_cpu(x)
self.predictor.run()
self.output_tensor.copy_to_cpu()
- self.post_process = get_post_processing({
- 'type': 'SegDetectorRepresenter',
- 'args': {
- 'thresh': 0.3,
- 'box_thresh': 0.7,
- 'max_candidates': 1000,
- 'unclip_ratio': 1.5
+ self.post_process = get_post_processing(
+ {
+ "type": "SegDetectorRepresenter",
+ "args": {
+ "thresh": 0.3,
+ "box_thresh": 0.7,
+ "max_candidates": 1000,
+ "unclip_ratio": 1.5,
+ },
}
- })
+ )
def load_predictor(self, model_file_path, params_file_path):
"""load_predictor
@@ -98,20 +112,18 @@ def load_predictor(self, model_file_path, params_file_path):
workspace_size=1 << 30,
precision_mode=precision,
max_batch_size=args.max_batch_size,
- min_subgraph_size=args.
-                    min_subgraph_size,  # skip the minimum TRT subgraph
- use_calib_mode=False)
+                    min_subgraph_size=args.min_subgraph_size,  # skip the minimum TRT subgraph
+ use_calib_mode=False,
+ )
# collect shape
trt_shape_f = os.path.join(model_dir, "_trt_dynamic_shape.txt")
if not os.path.exists(trt_shape_f):
config.collect_shape_range_info(trt_shape_f)
- logger.info(
- f"collect dynamic shape info into : {trt_shape_f}")
+ logger.info(f"collect dynamic shape info into : {trt_shape_f}")
try:
- config.enable_tuned_tensorrt_dynamic_shape(trt_shape_f,
- True)
+ config.enable_tuned_tensorrt_dynamic_shape(trt_shape_f, True)
except Exception as E:
logger.info(E)
logger.info("Please keep your paddlepaddle-gpu >= 2.3.0!")
@@ -162,7 +174,7 @@ def preprocess(self, img_path, short_size):
img = resize_image(img, short_size)
img = self.transforms(img)
img = np.expand_dims(img, axis=0)
- shape_info = {'shape': [(h, w)]}
+ shape_info = {"shape": [(h, w)]}
return img, shape_info
def postprocess(self, x, shape_info, is_output_polygon):
@@ -173,7 +185,8 @@ def postprocess(self, x, shape_info, is_output_polygon):
Returns: Output data after argmax.
"""
box_list, score_list = self.post_process(
- shape_info, x, is_output_polygon=is_output_polygon)
+ shape_info, x, is_output_polygon=is_output_polygon
+ )
box_list, score_list = box_list[0], score_list[0]
if len(box_list) > 0:
if is_output_polygon:
@@ -181,8 +194,7 @@ def postprocess(self, x, shape_info, is_output_polygon):
box_list = [box_list[i] for i, v in enumerate(idx) if v]
score_list = [score_list[i] for i, v in enumerate(idx) if v]
else:
- idx = box_list.reshape(box_list.shape[0], -1).sum(
-                    axis=1) > 0  # remove all-zero boxes
+                idx = box_list.reshape(box_list.shape[0], -1).sum(axis=1) > 0  # remove all-zero boxes
box_list, score_list = box_list[idx], score_list[idx]
else:
box_list, score_list = [], []
@@ -211,19 +223,17 @@ def str2bool(v):
return v.lower() in ("true", "t", "1")
parser = argparse.ArgumentParser(
- description="PaddlePaddle Classification Training", add_help=add_help)
+ description="PaddlePaddle Classification Training", add_help=add_help
+ )
parser.add_argument("--model_dir", default=None, help="inference model dir")
parser.add_argument("--batch_size", type=int, default=1)
- parser.add_argument(
- "--short_size", default=1024, type=int, help="short size")
+ parser.add_argument("--short_size", default=1024, type=int, help="short size")
parser.add_argument("--img_path", default="./images/demo.jpg")
- parser.add_argument(
- "--benchmark", default=False, type=str2bool, help="benchmark")
+ parser.add_argument("--benchmark", default=False, type=str2bool, help="benchmark")
parser.add_argument("--warmup", default=0, type=int, help="warmup iter")
- parser.add_argument(
- '--polygon', action='store_true', help='output polygon or box')
+ parser.add_argument("--polygon", action="store_true", help="output polygon or box")
parser.add_argument("--use_gpu", type=str2bool, default=True)
parser.add_argument("--use_tensorrt", type=str2bool, default=False)
@@ -251,19 +261,20 @@ def main(args):
# init benchmark
if args.benchmark:
import auto_log
+
autolog = auto_log.AutoLogger(
model_name="db",
batch_size=args.batch_size,
inference_config=inference_engine.config,
- gpu_ids="auto" if args.use_gpu else None)
+ gpu_ids="auto" if args.use_gpu else None,
+ )
# enable benchmark
if args.benchmark:
autolog.times.start()
# preprocess
- img, shape_info = inference_engine.preprocess(args.img_path,
- args.short_size)
+ img, shape_info = inference_engine.preprocess(args.img_path, args.short_size)
if args.benchmark:
autolog.times.stamp()
@@ -274,8 +285,9 @@ def main(args):
autolog.times.stamp()
# postprocess
- box_list, score_list = inference_engine.postprocess(output, shape_info,
- args.polygon)
+ box_list, score_list = inference_engine.postprocess(
+ output, shape_info, args.polygon
+ )
if args.benchmark:
autolog.times.stamp()
@@ -284,13 +296,16 @@ def main(args):
img = draw_bbox(cv2.imread(args.img_path)[:, :, ::-1], box_list)
    # save results to the output path
- os.makedirs('output', exist_ok=True)
+ os.makedirs("output", exist_ok=True)
img_path = pathlib.Path(args.img_path)
- output_path = os.path.join('output', img_path.stem + '_infer_result.jpg')
+ output_path = os.path.join("output", img_path.stem + "_infer_result.jpg")
cv2.imwrite(output_path, img[:, :, ::-1])
save_result(
- output_path.replace('_infer_result.jpg', '.txt'), box_list, score_list,
- args.polygon)
+ output_path.replace("_infer_result.jpg", ".txt"),
+ box_list,
+ score_list,
+ args.polygon,
+ )
if __name__ == "__main__":
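[Editor's aside: for reference, a typical invocation of this script using only flags defined in the hunks above (paths illustrative): python tools/infer.py --model_dir ./inference_model --img_path ./images/demo.jpg --use_gpu True --polygon]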
diff --git a/benchmark/PaddleOCR_DBNet/tools/predict.py b/benchmark/PaddleOCR_DBNet/tools/predict.py
index 51beffd170..56312a27f3 100644
--- a/benchmark/PaddleOCR_DBNet/tools/predict.py
+++ b/benchmark/PaddleOCR_DBNet/tools/predict.py
@@ -5,6 +5,7 @@
import os
import sys
import pathlib
+
__dir__ = pathlib.Path(os.path.abspath(__file__))
sys.path.append(str(__dir__))
sys.path.append(str(__dir__.parent.parent))
@@ -34,49 +35,48 @@ def resize_image(img, short_size):
class PaddleModel:
def __init__(self, model_path, post_p_thre=0.7, gpu_id=None):
- '''
+ """
        Initialize the model.
        :param model_path: model path (either the model parameters alone, or parameters saved together with the computation graph)
        :param gpu_id: index of the GPU to run on
- '''
+ """
self.gpu_id = gpu_id
- if self.gpu_id is not None and isinstance(
- self.gpu_id, int) and paddle.device.is_compiled_with_cuda():
+ if (
+ self.gpu_id is not None
+ and isinstance(self.gpu_id, int)
+ and paddle.device.is_compiled_with_cuda()
+ ):
paddle.device.set_device("gpu:{}".format(self.gpu_id))
else:
paddle.device.set_device("cpu")
checkpoint = paddle.load(model_path)
- config = checkpoint['config']
- config['arch']['backbone']['pretrained'] = False
- self.model = build_model(config['arch'])
- self.post_process = get_post_processing(config['post_processing'])
+ config = checkpoint["config"]
+ config["arch"]["backbone"]["pretrained"] = False
+ self.model = build_model(config["arch"])
+ self.post_process = get_post_processing(config["post_processing"])
self.post_process.box_thresh = post_p_thre
- self.img_mode = config['dataset']['train']['dataset']['args'][
- 'img_mode']
- self.model.set_state_dict(checkpoint['state_dict'])
+ self.img_mode = config["dataset"]["train"]["dataset"]["args"]["img_mode"]
+ self.model.set_state_dict(checkpoint["state_dict"])
self.model.eval()
self.transform = []
- for t in config['dataset']['train']['dataset']['args']['transforms']:
- if t['type'] in ['ToTensor', 'Normalize']:
+ for t in config["dataset"]["train"]["dataset"]["args"]["transforms"]:
+ if t["type"] in ["ToTensor", "Normalize"]:
self.transform.append(t)
self.transform = get_transforms(self.transform)
- def predict(self,
- img_path: str,
- is_output_polygon=False,
- short_size: int=1024):
- '''
+ def predict(self, img_path: str, is_output_polygon=False, short_size: int = 1024):
+ """
        Predict on the given image; takes an image path and reads the image with OpenCV, which is relatively slow.
        :param img_path: image path
:param is_numpy:
:return:
- '''
- assert os.path.exists(img_path), 'file is not exists'
- img = cv2.imread(img_path, 1 if self.img_mode != 'GRAY' else 0)
- if self.img_mode == 'RGB':
+ """
+        assert os.path.exists(img_path), "file does not exist"
+ img = cv2.imread(img_path, 1 if self.img_mode != "GRAY" else 0)
+ if self.img_mode == "RGB":
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
h, w = img.shape[:2]
img = resize_image(img, short_size)
@@ -84,12 +84,13 @@ def predict(self,
tensor = self.transform(img)
tensor = tensor.unsqueeze_(0)
- batch = {'shape': [(h, w)]}
+ batch = {"shape": [(h, w)]}
with paddle.no_grad():
start = time.time()
preds = self.model(tensor)
box_list, score_list = self.post_process(
- batch, preds, is_output_polygon=is_output_polygon)
+ batch, preds, is_output_polygon=is_output_polygon
+ )
box_list, score_list = box_list[0], score_list[0]
if len(box_list) > 0:
if is_output_polygon:
@@ -97,8 +98,9 @@ def predict(self,
box_list = [box_list[i] for i, v in enumerate(idx) if v]
score_list = [score_list[i] for i, v in enumerate(idx) if v]
else:
- idx = box_list.reshape(box_list.shape[0], -1).sum(
-                    axis=1) > 0  # remove all-zero boxes
+ idx = (
+ box_list.reshape(box_list.shape[0], -1).sum(axis=1) > 0
+                )  # remove all-zero boxes
box_list, score_list = box_list[idx], score_list[idx]
else:
box_list, score_list = [], []
@@ -107,10 +109,7 @@ def predict(self,
def save_depoly(net, input, save_path):
- input_spec = [
- paddle.static.InputSpec(
- shape=[None, 3, None, None], dtype="float32")
- ]
+ input_spec = [paddle.static.InputSpec(shape=[None, 3, None, None], dtype="float32")]
net = paddle.jit.to_static(net, input_spec=input_spec)
# save static model for inference directly
@@ -119,33 +118,29 @@ def save_depoly(net, input, save_path):
def init_args():
import argparse
- parser = argparse.ArgumentParser(description='DBNet.paddle')
- parser.add_argument('--model_path', default=r'model_best.pth', type=str)
- parser.add_argument(
- '--input_folder',
- default='./test/input',
- type=str,
- help='img path for predict')
+
+ parser = argparse.ArgumentParser(description="DBNet.paddle")
+ parser.add_argument("--model_path", default=r"model_best.pth", type=str)
parser.add_argument(
- '--output_folder',
- default='./test/output',
- type=str,
- help='img path for output')
- parser.add_argument('--gpu', default=0, type=int, help='gpu for inference')
+ "--input_folder", default="./test/input", type=str, help="img path for predict"
+ )
parser.add_argument(
- '--thre', default=0.3, type=float, help='the thresh of post_processing')
+ "--output_folder", default="./test/output", type=str, help="img path for output"
+ )
+ parser.add_argument("--gpu", default=0, type=int, help="gpu for inference")
parser.add_argument(
- '--polygon', action='store_true', help='output polygon or box')
- parser.add_argument('--show', action='store_true', help='show result')
+ "--thre", default=0.3, type=float, help="the thresh of post_processing"
+ )
+ parser.add_argument("--polygon", action="store_true", help="output polygon or box")
+ parser.add_argument("--show", action="store_true", help="show result")
parser.add_argument(
- '--save_result',
- action='store_true',
- help='save box and score to txt file')
+ "--save_result", action="store_true", help="save box and score to txt file"
+ )
args = parser.parse_args()
return args
-if __name__ == '__main__':
+if __name__ == "__main__":
import pathlib
from tqdm import tqdm
import matplotlib.pyplot as plt
@@ -158,7 +153,8 @@ def init_args():
img_folder = pathlib.Path(args.input_folder)
for img_path in tqdm(get_image_file_list(args.input_folder)):
preds, boxes_list, score_list, t = model.predict(
- img_path, is_output_polygon=args.polygon)
+ img_path, is_output_polygon=args.polygon
+ )
img = draw_bbox(cv2.imread(img_path)[:, :, ::-1], boxes_list)
if args.show:
show_img(preds)
@@ -167,12 +163,13 @@ def init_args():
        # save results to the output path
os.makedirs(args.output_folder, exist_ok=True)
img_path = pathlib.Path(img_path)
- output_path = os.path.join(args.output_folder,
- img_path.stem + '_result.jpg')
- pred_path = os.path.join(args.output_folder,
- img_path.stem + '_pred.jpg')
+ output_path = os.path.join(args.output_folder, img_path.stem + "_result.jpg")
+ pred_path = os.path.join(args.output_folder, img_path.stem + "_pred.jpg")
cv2.imwrite(output_path, img[:, :, ::-1])
cv2.imwrite(pred_path, preds * 255)
save_result(
- output_path.replace('_result.jpg', '.txt'), boxes_list, score_list,
- args.polygon)
+ output_path.replace("_result.jpg", ".txt"),
+ boxes_list,
+ score_list,
+ args.polygon,
+ )
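[Editor's aside: resize_image, declared in the hunk header near the top of this file, falls outside the shown hunks. DBNet-style predictors usually scale the short side to short_size and round both sides to a multiple of 32 so the FPN strides divide evenly; a sketch under that assumption:

    import cv2

    def resize_image(img, short_size=1024):
        h, w = img.shape[:2]
        scale = short_size / min(h, w)
        # round both sides to the nearest multiple of 32
        new_h = max(32, int(round(h * scale / 32)) * 32)
        new_w = max(32, int(round(w * scale / 32)) * 32)
        return cv2.resize(img, (new_w, new_h))
]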
diff --git a/benchmark/PaddleOCR_DBNet/tools/train.py b/benchmark/PaddleOCR_DBNet/tools/train.py
index 403d6185fc..170eebd81b 100644
--- a/benchmark/PaddleOCR_DBNet/tools/train.py
+++ b/benchmark/PaddleOCR_DBNet/tools/train.py
@@ -1,6 +1,7 @@
import os
import sys
import pathlib
+
__dir__ = pathlib.Path(os.path.abspath(__file__))
sys.path.append(str(__dir__))
sys.path.append(str(__dir__.parent.parent))
@@ -22,25 +23,26 @@ def main(config, profiler_options):
from trainer import Trainer
from post_processing import get_post_processing
from utils import get_metric
+
if paddle.device.cuda.device_count() > 1:
dist.init_parallel_env()
- config['distributed'] = True
+ config["distributed"] = True
else:
- config['distributed'] = False
- train_loader = get_dataloader(config['dataset']['train'],
- config['distributed'])
+ config["distributed"] = False
+ train_loader = get_dataloader(config["dataset"]["train"], config["distributed"])
assert train_loader is not None
- if 'validate' in config['dataset']:
- validate_loader = get_dataloader(config['dataset']['validate'], False)
+ if "validate" in config["dataset"]:
+ validate_loader = get_dataloader(config["dataset"]["validate"], False)
else:
validate_loader = None
- criterion = build_loss(config['loss'])
- config['arch']['backbone']['in_channels'] = 3 if config['dataset']['train'][
- 'dataset']['args']['img_mode'] != 'GRAY' else 1
- model = build_model(config['arch'])
+ criterion = build_loss(config["loss"])
+ config["arch"]["backbone"]["in_channels"] = (
+ 3 if config["dataset"]["train"]["dataset"]["args"]["img_mode"] != "GRAY" else 1
+ )
+ model = build_model(config["arch"])
# set @to_static for benchmark, skip this by default.
- post_p = get_post_processing(config['post_processing'])
- metric = get_metric(config['metric'])
+ post_p = get_post_processing(config["post_processing"])
+ metric = get_metric(config["metric"])
trainer = Trainer(
config=config,
model=model,
@@ -49,11 +51,12 @@ def main(config, profiler_options):
post_process=post_p,
metric_cls=metric,
validate_loader=validate_loader,
- profiler_options=profiler_options)
+ profiler_options=profiler_options,
+ )
trainer.train()
-if __name__ == '__main__':
+if __name__ == "__main__":
args = init_args()
assert os.path.exists(args.config_file)
config = Config(args.config_file)
diff --git a/benchmark/PaddleOCR_DBNet/trainer/__init__.py b/benchmark/PaddleOCR_DBNet/trainer/__init__.py
index 76c7392d14..e5b22345e7 100644
--- a/benchmark/PaddleOCR_DBNet/trainer/__init__.py
+++ b/benchmark/PaddleOCR_DBNet/trainer/__init__.py
@@ -1,4 +1,4 @@
# -*- coding: utf-8 -*-
# @Time : 2019/8/23 21:58
# @Author : zhoujun
-from .trainer import Trainer
\ No newline at end of file
+from .trainer import Trainer
diff --git a/benchmark/PaddleOCR_DBNet/trainer/trainer.py b/benchmark/PaddleOCR_DBNet/trainer/trainer.py
index 34b259f3d1..077d39ca58 100644
--- a/benchmark/PaddleOCR_DBNet/trainer/trainer.py
+++ b/benchmark/PaddleOCR_DBNet/trainer/trainer.py
@@ -11,19 +11,28 @@
class Trainer(BaseTrainer):
- def __init__(self,
- config,
- model,
- criterion,
- train_loader,
- validate_loader,
- metric_cls,
- post_process=None,
- profiler_options=None):
- super(Trainer, self).__init__(config, model, criterion, train_loader,
- validate_loader, metric_cls, post_process)
+ def __init__(
+ self,
+ config,
+ model,
+ criterion,
+ train_loader,
+ validate_loader,
+ metric_cls,
+ post_process=None,
+ profiler_options=None,
+ ):
+ super(Trainer, self).__init__(
+ config,
+ model,
+ criterion,
+ train_loader,
+ validate_loader,
+ metric_cls,
+ post_process,
+ )
self.profiler_options = profiler_options
- self.enable_eval = config['trainer'].get('enable_eval', True)
+ self.enable_eval = config["trainer"].get("enable_eval", True)
def _train_epoch(self, epoch):
self.model.train()
@@ -32,7 +41,7 @@ def _train_epoch(self, epoch):
train_batch_cost = 0.0
reader_start = time.time()
epoch_start = time.time()
- train_loss = 0.
+ train_loss = 0.0
running_metric_text = runningScore(2)
for i, batch in enumerate(self.train_loader):
@@ -42,25 +51,26 @@ def _train_epoch(self, epoch):
self.global_step += 1
lr = self.optimizer.get_lr()
- cur_batch_size = batch['img'].shape[0]
+ cur_batch_size = batch["img"].shape[0]
train_reader_cost += time.time() - reader_start
if self.amp:
with paddle.amp.auto_cast(
- enable='gpu' in paddle.device.get_device(),
- custom_white_list=self.amp.get('custom_white_list', []),
- custom_black_list=self.amp.get('custom_black_list', []),
- level=self.amp.get('level', 'O2')):
- preds = self.model(batch['img'])
+ enable="gpu" in paddle.device.get_device(),
+ custom_white_list=self.amp.get("custom_white_list", []),
+ custom_black_list=self.amp.get("custom_black_list", []),
+ level=self.amp.get("level", "O2"),
+ ):
+ preds = self.model(batch["img"])
loss_dict = self.criterion(preds.astype(paddle.float32), batch)
- scaled_loss = self.amp['scaler'].scale(loss_dict['loss'])
+ scaled_loss = self.amp["scaler"].scale(loss_dict["loss"])
scaled_loss.backward()
- self.amp['scaler'].minimize(self.optimizer, scaled_loss)
+ self.amp["scaler"].minimize(self.optimizer, scaled_loss)
else:
- preds = self.model(batch['img'])
+ preds = self.model(batch["img"])
loss_dict = self.criterion(preds, batch)
# backward
- loss_dict['loss'].backward()
+ loss_dict["loss"].backward()
self.optimizer.step()
self.lr_scheduler.step()
self.optimizer.clear_grad()
@@ -72,33 +82,45 @@ def _train_epoch(self, epoch):
# acc iou
score_shrink_map = cal_text_score(
preds[:, 0, :, :],
- batch['shrink_map'],
- batch['shrink_mask'],
+ batch["shrink_map"],
+ batch["shrink_mask"],
running_metric_text,
- thred=self.config['post_processing']['args']['thresh'])
+ thred=self.config["post_processing"]["args"]["thresh"],
+ )
            # log loss and acc
- loss_str = 'loss: {:.4f}, '.format(loss_dict['loss'].item())
+ loss_str = "loss: {:.4f}, ".format(loss_dict["loss"].item())
for idx, (key, value) in enumerate(loss_dict.items()):
loss_dict[key] = value.item()
- if key == 'loss':
+ if key == "loss":
continue
- loss_str += '{}: {:.4f}'.format(key, loss_dict[key])
+ loss_str += "{}: {:.4f}".format(key, loss_dict[key])
if idx < len(loss_dict) - 1:
- loss_str += ', '
+ loss_str += ", "
- train_loss += loss_dict['loss']
- acc = score_shrink_map['Mean Acc']
- iou_shrink_map = score_shrink_map['Mean IoU']
+ train_loss += loss_dict["loss"]
+ acc = score_shrink_map["Mean Acc"]
+ iou_shrink_map = score_shrink_map["Mean IoU"]
if self.global_step % self.log_iter == 0:
self.logger_info(
- '[{}/{}], [{}/{}], global_step: {}, ips: {:.1f} samples/sec, avg_reader_cost: {:.5f} s, avg_batch_cost: {:.5f} s, avg_samples: {}, acc: {:.4f}, iou_shrink_map: {:.4f}, {}lr:{:.6}, time:{:.2f}'.
- format(epoch, self.epochs, i + 1, self.train_loader_len,
- self.global_step, total_samples / train_batch_cost,
- train_reader_cost / self.log_iter, train_batch_cost /
- self.log_iter, total_samples / self.log_iter, acc,
- iou_shrink_map, loss_str, lr, train_batch_cost))
+ "[{}/{}], [{}/{}], global_step: {}, ips: {:.1f} samples/sec, avg_reader_cost: {:.5f} s, avg_batch_cost: {:.5f} s, avg_samples: {}, acc: {:.4f}, iou_shrink_map: {:.4f}, {}lr:{:.6}, time:{:.2f}".format(
+ epoch,
+ self.epochs,
+ i + 1,
+ self.train_loader_len,
+ self.global_step,
+ total_samples / train_batch_cost,
+ train_reader_cost / self.log_iter,
+ train_batch_cost / self.log_iter,
+ total_samples / self.log_iter,
+ acc,
+ iou_shrink_map,
+ loss_str,
+ lr,
+ train_batch_cost,
+ )
+ )
total_samples = 0
train_reader_cost = 0.0
train_batch_cost = 0.0
@@ -106,19 +128,20 @@ def _train_epoch(self, epoch):
if self.visualdl_enable and paddle.distributed.get_rank() == 0:
# write tensorboard
for key, value in loss_dict.items():
- self.writer.add_scalar('TRAIN/LOSS/{}'.format(key), value,
- self.global_step)
- self.writer.add_scalar('TRAIN/ACC_IOU/acc', acc,
- self.global_step)
- self.writer.add_scalar('TRAIN/ACC_IOU/iou_shrink_map',
- iou_shrink_map, self.global_step)
- self.writer.add_scalar('TRAIN/lr', lr, self.global_step)
+ self.writer.add_scalar(
+ "TRAIN/LOSS/{}".format(key), value, self.global_step
+ )
+ self.writer.add_scalar("TRAIN/ACC_IOU/acc", acc, self.global_step)
+ self.writer.add_scalar(
+ "TRAIN/ACC_IOU/iou_shrink_map", iou_shrink_map, self.global_step
+ )
+ self.writer.add_scalar("TRAIN/lr", lr, self.global_step)
reader_start = time.time()
return {
- 'train_loss': train_loss / self.train_loader_len,
- 'lr': lr,
- 'time': time.time() - epoch_start,
- 'epoch': epoch
+ "train_loss": train_loss / self.train_loader_len,
+ "lr": lr,
+ "time": time.time() - epoch_start,
+ "epoch": epoch,
}
def _eval(self, epoch):
@@ -127,104 +150,107 @@ def _eval(self, epoch):
total_frame = 0.0
total_time = 0.0
for i, batch in tqdm(
- enumerate(self.validate_loader),
- total=len(self.validate_loader),
- desc='test model'):
+ enumerate(self.validate_loader),
+ total=len(self.validate_loader),
+ desc="test model",
+ ):
with paddle.no_grad():
start = time.time()
if self.amp:
with paddle.amp.auto_cast(
- enable='gpu' in paddle.device.get_device(),
- custom_white_list=self.amp.get('custom_white_list',
- []),
- custom_black_list=self.amp.get('custom_black_list',
- []),
- level=self.amp.get('level', 'O2')):
- preds = self.model(batch['img'])
+ enable="gpu" in paddle.device.get_device(),
+ custom_white_list=self.amp.get("custom_white_list", []),
+ custom_black_list=self.amp.get("custom_black_list", []),
+ level=self.amp.get("level", "O2"),
+ ):
+ preds = self.model(batch["img"])
preds = preds.astype(paddle.float32)
else:
- preds = self.model(batch['img'])
+ preds = self.model(batch["img"])
boxes, scores = self.post_process(
- batch,
- preds,
- is_output_polygon=self.metric_cls.is_output_polygon)
- total_frame += batch['img'].shape[0]
+ batch, preds, is_output_polygon=self.metric_cls.is_output_polygon
+ )
+ total_frame += batch["img"].shape[0]
total_time += time.time() - start
- raw_metric = self.metric_cls.validate_measure(batch,
- (boxes, scores))
+ raw_metric = self.metric_cls.validate_measure(batch, (boxes, scores))
raw_metrics.append(raw_metric)
metrics = self.metric_cls.gather_measure(raw_metrics)
- self.logger_info('FPS:{}'.format(total_frame / total_time))
- return metrics['recall'].avg, metrics['precision'].avg, metrics[
- 'fmeasure'].avg
+ self.logger_info("FPS:{}".format(total_frame / total_time))
+ return metrics["recall"].avg, metrics["precision"].avg, metrics["fmeasure"].avg
def _on_epoch_finish(self):
- self.logger_info('[{}/{}], train_loss: {:.4f}, time: {:.4f}, lr: {}'.
- format(self.epoch_result['epoch'], self.epochs, self.
- epoch_result['train_loss'], self.epoch_result[
- 'time'], self.epoch_result['lr']))
- net_save_path = '{}/model_latest.pth'.format(self.checkpoint_dir)
- net_save_path_best = '{}/model_best.pth'.format(self.checkpoint_dir)
+ self.logger_info(
+ "[{}/{}], train_loss: {:.4f}, time: {:.4f}, lr: {}".format(
+ self.epoch_result["epoch"],
+ self.epochs,
+ self.epoch_result["train_loss"],
+ self.epoch_result["time"],
+ self.epoch_result["lr"],
+ )
+ )
+ net_save_path = "{}/model_latest.pth".format(self.checkpoint_dir)
+ net_save_path_best = "{}/model_best.pth".format(self.checkpoint_dir)
if paddle.distributed.get_rank() == 0:
- self._save_checkpoint(self.epoch_result['epoch'], net_save_path)
+ self._save_checkpoint(self.epoch_result["epoch"], net_save_path)
save_best = False
-        if self.validate_loader is not None and self.metric_cls is not None and self.enable_eval:  # use F1 as the best-model metric
- recall, precision, hmean = self._eval(self.epoch_result[
- 'epoch'])
+ if (
+ self.validate_loader is not None
+ and self.metric_cls is not None
+ and self.enable_eval
+            ):  # use F1 as the best-model metric
+ recall, precision, hmean = self._eval(self.epoch_result["epoch"])
if self.visualdl_enable:
- self.writer.add_scalar('EVAL/recall', recall,
- self.global_step)
- self.writer.add_scalar('EVAL/precision', precision,
- self.global_step)
- self.writer.add_scalar('EVAL/hmean', hmean,
- self.global_step)
+ self.writer.add_scalar("EVAL/recall", recall, self.global_step)
+ self.writer.add_scalar(
+ "EVAL/precision", precision, self.global_step
+ )
+ self.writer.add_scalar("EVAL/hmean", hmean, self.global_step)
self.logger_info(
- 'test: recall: {:.6f}, precision: {:.6f}, hmean: {:.6f}'.
- format(recall, precision, hmean))
+ "test: recall: {:.6f}, precision: {:.6f}, hmean: {:.6f}".format(
+ recall, precision, hmean
+ )
+ )
- if hmean >= self.metrics['hmean']:
+ if hmean >= self.metrics["hmean"]:
save_best = True
- self.metrics['train_loss'] = self.epoch_result['train_loss']
- self.metrics['hmean'] = hmean
- self.metrics['precision'] = precision
- self.metrics['recall'] = recall
- self.metrics['best_model_epoch'] = self.epoch_result[
- 'epoch']
+ self.metrics["train_loss"] = self.epoch_result["train_loss"]
+ self.metrics["hmean"] = hmean
+ self.metrics["precision"] = precision
+ self.metrics["recall"] = recall
+ self.metrics["best_model_epoch"] = self.epoch_result["epoch"]
else:
- if self.epoch_result['train_loss'] <= self.metrics[
- 'train_loss']:
+ if self.epoch_result["train_loss"] <= self.metrics["train_loss"]:
save_best = True
- self.metrics['train_loss'] = self.epoch_result['train_loss']
- self.metrics['best_model_epoch'] = self.epoch_result[
- 'epoch']
- best_str = 'current best, '
+ self.metrics["train_loss"] = self.epoch_result["train_loss"]
+ self.metrics["best_model_epoch"] = self.epoch_result["epoch"]
+ best_str = "current best, "
for k, v in self.metrics.items():
- best_str += '{}: {:.6f}, '.format(k, v)
+ best_str += "{}: {:.6f}, ".format(k, v)
self.logger_info(best_str)
if save_best:
import shutil
+
shutil.copy(net_save_path, net_save_path_best)
- self.logger_info("Saving current best: {}".format(
- net_save_path_best))
+ self.logger_info("Saving current best: {}".format(net_save_path_best))
else:
self.logger_info("Saving checkpoint: {}".format(net_save_path))
def _on_train_finish(self):
if self.enable_eval:
for k, v in self.metrics.items():
- self.logger_info('{}:{}'.format(k, v))
- self.logger_info('finish train')
+ self.logger_info("{}:{}".format(k, v))
+ self.logger_info("finish train")
def _initialize_scheduler(self):
- if self.config['lr_scheduler']['type'] == 'Polynomial':
- self.config['lr_scheduler']['args']['epochs'] = self.config[
- 'trainer']['epochs']
- self.config['lr_scheduler']['args']['step_each_epoch'] = len(
- self.train_loader)
- self.lr_scheduler = Polynomial(
- **self.config['lr_scheduler']['args'])()
+ if self.config["lr_scheduler"]["type"] == "Polynomial":
+ self.config["lr_scheduler"]["args"]["epochs"] = self.config["trainer"][
+ "epochs"
+ ]
+ self.config["lr_scheduler"]["args"]["step_each_epoch"] = len(
+ self.train_loader
+ )
+ self.lr_scheduler = Polynomial(**self.config["lr_scheduler"]["args"])()
else:
- self.lr_scheduler = self._initialize('lr_scheduler',
- paddle.optimizer.lr)
+ self.lr_scheduler = self._initialize("lr_scheduler", paddle.optimizer.lr)
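[Editor's aside: the AMP branch in _train_epoch above follows Paddle's GradScaler pattern. Stripped of the trainer plumbing, a self-contained sketch (toy model, not the patched trainer) looks like this:

    import paddle

    model = paddle.nn.Linear(10, 1)
    optimizer = paddle.optimizer.SGD(learning_rate=0.01, parameters=model.parameters())
    scaler = paddle.amp.GradScaler(init_loss_scaling=1024)

    x = paddle.rand([4, 10])
    with paddle.amp.auto_cast(enable="gpu" in paddle.device.get_device()):
        loss = model(x).mean()
    scaled = scaler.scale(loss)         # scale the loss to avoid fp16 underflow
    scaled.backward()
    scaler.minimize(optimizer, scaled)  # optimizer step + loss-scale update
    optimizer.clear_grad()
]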
diff --git a/benchmark/PaddleOCR_DBNet/utils/cal_recall/__init__.py b/benchmark/PaddleOCR_DBNet/utils/cal_recall/__init__.py
index 0db38a8a37..4b64268ad5 100644
--- a/benchmark/PaddleOCR_DBNet/utils/cal_recall/__init__.py
+++ b/benchmark/PaddleOCR_DBNet/utils/cal_recall/__init__.py
@@ -2,4 +2,5 @@
# @Time : 1/16/19 6:40 AM
# @Author : zhoujun
from .script import cal_recall_precison_f1
-__all__ = ['cal_recall_precison_f1']
+
+__all__ = ["cal_recall_precison_f1"]
diff --git a/benchmark/PaddleOCR_DBNet/utils/cal_recall/rrc_evaluation_funcs.py b/benchmark/PaddleOCR_DBNet/utils/cal_recall/rrc_evaluation_funcs.py
index 4e12ee66a0..8f5040ae69 100644
--- a/benchmark/PaddleOCR_DBNet/utils/cal_recall/rrc_evaluation_funcs.py
+++ b/benchmark/PaddleOCR_DBNet/utils/cal_recall/rrc_evaluation_funcs.py
@@ -1,8 +1,9 @@
#!/usr/bin/env python2
-#encoding: UTF-8
+# encoding: UTF-8
import json
import sys
-sys.path.append('./')
+
+sys.path.append("./")
import zipfile
import re
import sys
@@ -15,20 +16,21 @@
def print_help():
sys.stdout.write(
-        'Usage: python %s.py -g=<gtFile> -s=<submFile> [-o=<outputFolder> -p=<jsonParams>]'
- % sys.argv[0])
+ "Usage: python %s.py -g= -s= [-o= -p=]"
+ % sys.argv[0]
+ )
sys.exit(2)
-def load_zip_file_keys(file, fileNameRegExp=''):
+def load_zip_file_keys(file, fileNameRegExp=""):
"""
    Returns an array with the entries of the ZIP file that match the regular expression.
    The keys are the names of the file or of the capturing group defined in the fileNameRegExp
"""
try:
- archive = zipfile.ZipFile(file, mode='r', allowZip64=True)
+ archive = zipfile.ZipFile(file, mode="r", allowZip64=True)
except:
- raise Exception('Error loading the ZIP archive.')
+ raise Exception("Error loading the ZIP archive.")
pairs = []
@@ -49,16 +51,16 @@ def load_zip_file_keys(file, fileNameRegExp=''):
return pairs
-def load_zip_file(file, fileNameRegExp='', allEntries=False):
+def load_zip_file(file, fileNameRegExp="", allEntries=False):
"""
Returns an array with the contents (filtered by fileNameRegExp) of a ZIP file.
    The keys are the names of the file or of the capturing group defined in the fileNameRegExp
allEntries validates that all entries in the ZIP file pass the fileNameRegExp
"""
try:
- archive = zipfile.ZipFile(file, mode='r', allowZip64=True)
+ archive = zipfile.ZipFile(file, mode="r", allowZip64=True)
except:
- raise Exception('Error loading the ZIP archive')
+ raise Exception("Error loading the ZIP archive")
pairs = []
for name in archive.namelist():
@@ -76,12 +78,12 @@ def load_zip_file(file, fileNameRegExp='', allEntries=False):
pairs.append([keyName, archive.read(name)])
else:
if allEntries:
- raise Exception('ZIP entry not valid: %s' % name)
+ raise Exception("ZIP entry not valid: %s" % name)
return dict(pairs)
-def load_folder_file(file, fileNameRegExp='', allEntries=False):
+def load_folder_file(file, fileNameRegExp="", allEntries=False):
"""
    Returns an array with the contents (filtered by fileNameRegExp) of a folder.
    The keys are the names of the file or of the capturing group defined in the fileNameRegExp
@@ -103,7 +105,7 @@ def load_folder_file(file, fileNameRegExp='', allEntries=False):
pairs.append([keyName, open(os.path.join(file, name)).read()])
else:
if allEntries:
- raise Exception('ZIP entry not valid: %s' % name)
+ raise Exception("ZIP entry not valid: %s" % name)
return dict(pairs)
@@ -113,73 +115,77 @@ def decode_utf8(raw):
Returns a Unicode object on success, or None on failure
"""
try:
- raw = codecs.decode(raw, 'utf-8', 'replace')
- #extracts BOM if exists
- raw = raw.encode('utf8')
+ raw = codecs.decode(raw, "utf-8", "replace")
+ # extracts BOM if exists
+ raw = raw.encode("utf8")
if raw.startswith(codecs.BOM_UTF8):
- raw = raw.replace(codecs.BOM_UTF8, '', 1)
- return raw.decode('utf-8')
+ raw = raw.replace(codecs.BOM_UTF8, "", 1)
+ return raw.decode("utf-8")
except:
return None
-def validate_lines_in_file(fileName,
- file_contents,
- CRLF=True,
- LTRB=True,
- withTranscription=False,
- withConfidence=False,
- imWidth=0,
- imHeight=0):
+def validate_lines_in_file(
+ fileName,
+ file_contents,
+ CRLF=True,
+ LTRB=True,
+ withTranscription=False,
+ withConfidence=False,
+ imWidth=0,
+ imHeight=0,
+):
"""
    This function validates all lines of the file, calling the line validation function for each line
"""
utf8File = decode_utf8(file_contents)
- if (utf8File is None):
+ if utf8File is None:
raise Exception("The file %s is not UTF-8" % fileName)
lines = utf8File.split("\r\n" if CRLF else "\n")
for line in lines:
line = line.replace("\r", "").replace("\n", "")
- if (line != ""):
+ if line != "":
try:
- validate_tl_line(line, LTRB, withTranscription, withConfidence,
- imWidth, imHeight)
+ validate_tl_line(
+ line, LTRB, withTranscription, withConfidence, imWidth, imHeight
+ )
except Exception as e:
raise Exception(
- ("Line in sample not valid. Sample: %s Line: %s Error: %s" %
- (fileName, line, str(e))).encode('utf-8', 'replace'))
+ (
+ "Line in sample not valid. Sample: %s Line: %s Error: %s"
+ % (fileName, line, str(e))
+ ).encode("utf-8", "replace")
+ )
-def validate_tl_line(line,
- LTRB=True,
- withTranscription=True,
- withConfidence=True,
- imWidth=0,
- imHeight=0):
+def validate_tl_line(
+ line, LTRB=True, withTranscription=True, withConfidence=True, imWidth=0, imHeight=0
+):
"""
    Validate the format of the line. If the line is not valid, an exception will be raised.
    If maxWidth and maxHeight are specified, all points must be inside the image bounds.
    Possible values are:
- LTRB=True: xmin,ymin,xmax,ymax[,confidence][,transcription]
- LTRB=False: x1,y1,x2,y2,x3,y3,x4,y4[,confidence][,transcription]
+ LTRB=True: xmin,ymin,xmax,ymax[,confidence][,transcription]
+ LTRB=False: x1,y1,x2,y2,x3,y3,x4,y4[,confidence][,transcription]
"""
- get_tl_line_values(line, LTRB, withTranscription, withConfidence, imWidth,
- imHeight)
+ get_tl_line_values(line, LTRB, withTranscription, withConfidence, imWidth, imHeight)
-def get_tl_line_values(line,
- LTRB=True,
- withTranscription=False,
- withConfidence=False,
- imWidth=0,
- imHeight=0):
+def get_tl_line_values(
+ line,
+ LTRB=True,
+ withTranscription=False,
+ withConfidence=False,
+ imWidth=0,
+ imHeight=0,
+):
"""
    Validate the format of the line. If the line is not valid, an exception will be raised.
    If maxWidth and maxHeight are specified, all points must be inside the image bounds.
    Possible values are:
- LTRB=True: xmin,ymin,xmax,ymax[,confidence][,transcription]
- LTRB=False: x1,y1,x2,y2,x3,y3,x4,y4[,confidence][,transcription]
+ LTRB=True: xmin,ymin,xmax,ymax[,confidence][,transcription]
+ LTRB=False: x1,y1,x2,y2,x3,y3,x4,y4[,confidence][,transcription]
    Returns values from a text line: points, [confidences], [transcriptions]
"""
confidence = 0.0
@@ -189,111 +195,110 @@ def get_tl_line_values(line,
numPoints = 4
if LTRB:
-
numPoints = 4
if withTranscription and withConfidence:
m = re.match(
- r'^\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*([0-9]+)\s*,\s*([0-9]+)\s*,\s*([0-1].?[0-9]*)\s*,(.*)$',
- line)
+ r"^\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*([0-9]+)\s*,\s*([0-9]+)\s*,\s*([0-1].?[0-9]*)\s*,(.*)$",
+ line,
+ )
if m == None:
m = re.match(
- r'^\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*([0-9]+)\s*,\s*([0-9]+)\s*,\s*([0-1].?[0-9]*)\s*,(.*)$',
- line)
+ r"^\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*([0-9]+)\s*,\s*([0-9]+)\s*,\s*([0-1].?[0-9]*)\s*,(.*)$",
+ line,
+ )
raise Exception(
"Format incorrect. Should be: xmin,ymin,xmax,ymax,confidence,transcription"
)
elif withConfidence:
m = re.match(
- r'^\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*([0-9]+)\s*,\s*([0-9]+)\s*,\s*([0-1].?[0-9]*)\s*$',
- line)
+ r"^\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*([0-9]+)\s*,\s*([0-9]+)\s*,\s*([0-1].?[0-9]*)\s*$",
+ line,
+ )
if m == None:
raise Exception(
"Format incorrect. Should be: xmin,ymin,xmax,ymax,confidence"
)
elif withTranscription:
m = re.match(
- r'^\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*([0-9]+)\s*,\s*([0-9]+)\s*,(.*)$',
- line)
+ r"^\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*([0-9]+)\s*,\s*([0-9]+)\s*,(.*)$",
+ line,
+ )
if m == None:
raise Exception(
"Format incorrect. Should be: xmin,ymin,xmax,ymax,transcription"
)
else:
m = re.match(
- r'^\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*([0-9]+)\s*,\s*([0-9]+)\s*,?\s*$',
- line)
+ r"^\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*([0-9]+)\s*,\s*([0-9]+)\s*,?\s*$",
+ line,
+ )
if m == None:
- raise Exception(
- "Format incorrect. Should be: xmin,ymin,xmax,ymax")
+ raise Exception("Format incorrect. Should be: xmin,ymin,xmax,ymax")
xmin = int(m.group(1))
ymin = int(m.group(2))
xmax = int(m.group(3))
ymax = int(m.group(4))
- if (xmax < xmin):
+ if xmax < xmin:
raise Exception("Xmax value (%s) not valid (Xmax < Xmin)." % (xmax))
- if (ymax < ymin):
- raise Exception("Ymax value (%s) not valid (Ymax < Ymin)." %
- (ymax))
+ if ymax < ymin:
+ raise Exception("Ymax value (%s) not valid (Ymax < Ymin)." % (ymax))
points = [float(m.group(i)) for i in range(1, (numPoints + 1))]
- if (imWidth > 0 and imHeight > 0):
+ if imWidth > 0 and imHeight > 0:
validate_point_inside_bounds(xmin, ymin, imWidth, imHeight)
validate_point_inside_bounds(xmax, ymax, imWidth, imHeight)
else:
-
numPoints = 8
if withTranscription and withConfidence:
m = re.match(
- r'^\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*([0-1].?[0-9]*)\s*,(.*)$',
- line)
+ r"^\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*([0-1].?[0-9]*)\s*,(.*)$",
+ line,
+ )
if m == None:
raise Exception(
"Format incorrect. Should be: x1,y1,x2,y2,x3,y3,x4,y4,confidence,transcription"
)
elif withConfidence:
m = re.match(
- r'^\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*([0-1].?[0-9]*)\s*$',
- line)
+ r"^\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*([0-1].?[0-9]*)\s*$",
+ line,
+ )
if m == None:
raise Exception(
"Format incorrect. Should be: x1,y1,x2,y2,x3,y3,x4,y4,confidence"
)
elif withTranscription:
m = re.match(
- r'^\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,(.*)$',
- line)
+ r"^\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,(.*)$",
+ line,
+ )
if m == None:
raise Exception(
"Format incorrect. Should be: x1,y1,x2,y2,x3,y3,x4,y4,transcription"
)
else:
m = re.match(
- r'^\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*$',
- line)
+ r"^\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*$",
+ line,
+ )
if m == None:
- raise Exception(
- "Format incorrect. Should be: x1,y1,x2,y2,x3,y3,x4,y4")
+ raise Exception("Format incorrect. Should be: x1,y1,x2,y2,x3,y3,x4,y4")
points = [float(m.group(i)) for i in range(1, (numPoints + 1))]
- points = order_points_clockwise(np.array(points).reshape(-1,
- 2)).reshape(-1)
+ points = order_points_clockwise(np.array(points).reshape(-1, 2)).reshape(-1)
validate_clockwise_points(points)
- if (imWidth > 0 and imHeight > 0):
- validate_point_inside_bounds(points[0], points[1], imWidth,
- imHeight)
- validate_point_inside_bounds(points[2], points[3], imWidth,
- imHeight)
- validate_point_inside_bounds(points[4], points[5], imWidth,
- imHeight)
- validate_point_inside_bounds(points[6], points[7], imWidth,
- imHeight)
+ if imWidth > 0 and imHeight > 0:
+ validate_point_inside_bounds(points[0], points[1], imWidth, imHeight)
+ validate_point_inside_bounds(points[2], points[3], imWidth, imHeight)
+ validate_point_inside_bounds(points[4], points[5], imWidth, imHeight)
+ validate_point_inside_bounds(points[6], points[7], imWidth, imHeight)
if withConfidence:
try:
@@ -304,22 +309,26 @@ def get_tl_line_values(line,
if withTranscription:
posTranscription = numPoints + (2 if withConfidence else 1)
transcription = m.group(posTranscription)
- m2 = re.match(r'^\s*\"(.*)\"\s*$', transcription)
- if m2 != None: #Transcription with double quotes, we extract the value and replace escaped characters
- transcription = m2.group(1).replace("\\\\", "\\").replace("\\\"",
- "\"")
+ m2 = re.match(r"^\s*\"(.*)\"\s*$", transcription)
+ if (
+ m2 != None
+ ): # Transcription with double quotes, we extract the value and replace escaped characters
+ transcription = m2.group(1).replace("\\\\", "\\").replace('\\"', '"')
return points, confidence, transcription
def validate_point_inside_bounds(x, y, imWidth, imHeight):
- if (x < 0 or x > imWidth):
- raise Exception("X value (%s) not valid. Image dimensions: (%s,%s)" %
- (xmin, imWidth, imHeight))
- if (y < 0 or y > imHeight):
+ if x < 0 or x > imWidth:
+ raise Exception(
+ "X value (%s) not valid. Image dimensions: (%s,%s)"
+            % (x, imWidth, imHeight)
+ )
+ if y < 0 or y > imHeight:
raise Exception(
"Y value (%s) not valid. Image dimensions: (%s,%s) Sample: %s Line:%s"
- % (ymin, imWidth, imHeight))
+ % (ymin, imWidth, imHeight)
+ )
def validate_clockwise_points(points):
@@ -330,14 +339,18 @@ def validate_clockwise_points(points):
if len(points) != 8:
raise Exception("Points list not valid." + str(len(points)))
- point = [[int(points[0]), int(points[1])],
- [int(points[2]), int(points[3])],
- [int(points[4]), int(points[5])],
- [int(points[6]), int(points[7])]]
- edge = [(point[1][0] - point[0][0]) * (point[1][1] + point[0][1]),
- (point[2][0] - point[1][0]) * (point[2][1] + point[1][1]),
- (point[3][0] - point[2][0]) * (point[3][1] + point[2][1]),
- (point[0][0] - point[3][0]) * (point[0][1] + point[3][1])]
+ point = [
+ [int(points[0]), int(points[1])],
+ [int(points[2]), int(points[3])],
+ [int(points[4]), int(points[5])],
+ [int(points[6]), int(points[7])],
+ ]
+ edge = [
+ (point[1][0] - point[0][0]) * (point[1][1] + point[0][1]),
+ (point[2][0] - point[1][0]) * (point[2][1] + point[1][1]),
+ (point[3][0] - point[2][0]) * (point[3][1] + point[2][1]),
+ (point[0][0] - point[3][0]) * (point[0][1] + point[3][1]),
+ ]
summatory = edge[0] + edge[1] + edge[2] + edge[3]
if summatory > 0:
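[Editor's aside: the edge terms above are the shoelace sign test: summing (x2 - x1) * (y2 + y1) over the closed quadrilateral gives a negative value for clockwise order in image coordinates (y pointing down), so a positive summatory flags counter-clockwise input. A quick standalone check:

    pts = [(0, 0), (1, 0), (1, 1), (0, 1)]  # clockwise when y grows downward
    edges = [
        (x2 - x1) * (y2 + y1)
        for (x1, y1), (x2, y2) in zip(pts, pts[1:] + pts[:1])
    ]
    assert sum(edges) < 0  # passes the `summatory > 0` rejection above
]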
@@ -346,14 +359,16 @@ def validate_clockwise_points(points):
)
-def get_tl_line_values_from_file_contents(content,
- CRLF=True,
- LTRB=True,
- withTranscription=False,
- withConfidence=False,
- imWidth=0,
- imHeight=0,
- sort_by_confidences=True):
+def get_tl_line_values_from_file_contents(
+ content,
+ CRLF=True,
+ LTRB=True,
+ withTranscription=False,
+ withConfidence=False,
+ imWidth=0,
+ imHeight=0,
+ sort_by_confidences=True,
+):
"""
    Returns all points, confidences and transcriptions of a file in lists. Valid line formats:
xmin,ymin,xmax,ymax,[confidence],[transcription]
@@ -366,16 +381,17 @@ def get_tl_line_values_from_file_contents(content,
lines = content.split("\r\n" if CRLF else "\n")
for line in lines:
line = line.replace("\r", "").replace("\n", "")
- if (line != ""):
+ if line != "":
points, confidence, transcription = get_tl_line_values(
- line, LTRB, withTranscription, withConfidence, imWidth,
- imHeight)
+ line, LTRB, withTranscription, withConfidence, imWidth, imHeight
+ )
pointsList.append(points)
transcriptionsList.append(transcription)
confidencesList.append(confidence)
if withConfidence and len(confidencesList) > 0 and sort_by_confidences:
import numpy as np
+
sorted_ind = np.argsort(-np.array(confidencesList))
confidencesList = [confidencesList[i] for i in sorted_ind]
pointsList = [pointsList[i] for i in sorted_ind]
@@ -384,12 +400,14 @@ def get_tl_line_values_from_file_contents(content,
return pointsList, confidencesList, transcriptionsList
-def main_evaluation(p,
- default_evaluation_params_fn,
- validate_data_fn,
- evaluate_method_fn,
- show_result=True,
- per_sample=True):
+def main_evaluation(
+ p,
+ default_evaluation_params_fn,
+ validate_data_fn,
+ evaluate_method_fn,
+ show_result=True,
+ per_sample=True,
+):
"""
This process validates a method, evaluates it and, if it succeeds, generates a ZIP file with a JSON entry for each sample.
Params:
@@ -399,60 +417,56 @@ def main_evaluation(p,
evaluate_method_fn: points to a function that evaluates the submission and returns a dictionary with the results
"""
evalParams = default_evaluation_params_fn()
- if 'p' in p.keys():
- evalParams.update(p['p'] if isinstance(p['p'], dict) else json.loads(p[
- 'p'][1:-1]))
-
- resDict = {
- 'calculated': True,
- 'Message': '',
- 'method': '{}',
- 'per_sample': '{}'
- }
+ if "p" in p.keys():
+ evalParams.update(
+ p["p"] if isinstance(p["p"], dict) else json.loads(p["p"][1:-1])
+ )
+
+ resDict = {"calculated": True, "Message": "", "method": "{}", "per_sample": "{}"}
try:
# validate_data_fn(p['g'], p['s'], evalParams)
- evalData = evaluate_method_fn(p['g'], p['s'], evalParams)
+ evalData = evaluate_method_fn(p["g"], p["s"], evalParams)
resDict.update(evalData)
except Exception as e:
traceback.print_exc()
- resDict['Message'] = str(e)
- resDict['calculated'] = False
+ resDict["Message"] = str(e)
+ resDict["calculated"] = False
- if 'o' in p:
- if not os.path.exists(p['o']):
- os.makedirs(p['o'])
+ if "o" in p:
+ if not os.path.exists(p["o"]):
+ os.makedirs(p["o"])
- resultsOutputname = p['o'] + '/results.zip'
- outZip = zipfile.ZipFile(resultsOutputname, mode='w', allowZip64=True)
+ resultsOutputname = p["o"] + "/results.zip"
+ outZip = zipfile.ZipFile(resultsOutputname, mode="w", allowZip64=True)
- del resDict['per_sample']
- if 'output_items' in resDict.keys():
- del resDict['output_items']
+ del resDict["per_sample"]
+ if "output_items" in resDict.keys():
+ del resDict["output_items"]
- outZip.writestr('method.json', json.dumps(resDict))
+ outZip.writestr("method.json", json.dumps(resDict))
- if not resDict['calculated']:
+ if not resDict["calculated"]:
if show_result:
- sys.stderr.write('Error!\n' + resDict['Message'] + '\n\n')
- if 'o' in p:
+ sys.stderr.write("Error!\n" + resDict["Message"] + "\n\n")
+ if "o" in p:
outZip.close()
return resDict
- if 'o' in p:
+ if "o" in p:
if per_sample == True:
- for k, v in evalData['per_sample'].iteritems():
- outZip.writestr(k + '.json', json.dumps(v))
+ for k, v in evalData["per_sample"].iteritems():
+ outZip.writestr(k + ".json", json.dumps(v))
- if 'output_items' in evalData.keys():
- for k, v in evalData['output_items'].iteritems():
+ if "output_items" in evalData.keys():
+ for k, v in evalData["output_items"].iteritems():
outZip.writestr(k, v)
outZip.close()
if show_result:
sys.stdout.write("Calculated!")
- sys.stdout.write(json.dumps(resDict['method']))
+ sys.stdout.write(json.dumps(resDict["method"]))
return resDict
@@ -465,14 +479,15 @@ def main_validation(default_evaluation_params_fn, validate_data_fn):
validate_data_fn: points to a method that validates the correct format of the submission
"""
try:
- p = dict([s[1:].split('=') for s in sys.argv[1:]])
+ p = dict([s[1:].split("=") for s in sys.argv[1:]])
evalParams = default_evaluation_params_fn()
- if 'p' in p.keys():
- evalParams.update(p['p'] if isinstance(p['p'], dict) else
- json.loads(p['p'][1:-1]))
+ if "p" in p.keys():
+ evalParams.update(
+ p["p"] if isinstance(p["p"], dict) else json.loads(p["p"][1:-1])
+ )
- validate_data_fn(p['g'], p['s'], evalParams)
- print('SUCCESS')
+ validate_data_fn(p["g"], p["s"], evalParams)
+ print("SUCCESS")
sys.exit(0)
except Exception as e:
print(str(e))
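Note that main_validation reads its parameters as key=value tokens whose first
character (the leading dash) is stripped, matching the ICDAR convention of
flags such as -g=gt.zip -s=subm.zip. A small sketch of that parsing (the
argument values are hypothetical):

    argv = ["-g=gt.zip", "-s=subm.zip", "-o=./out"]
    # s[1:] drops the leading '-', then each token splits on '='.
    p = dict(s[1:].split("=") for s in argv)
    assert p == {"g": "gt.zip", "s": "subm.zip", "o": "./out"}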
diff --git a/benchmark/PaddleOCR_DBNet/utils/cal_recall/script.py b/benchmark/PaddleOCR_DBNet/utils/cal_recall/script.py
index 3b2f3916f6..eb4753aec9 100644
--- a/benchmark/PaddleOCR_DBNet/utils/cal_recall/script.py
+++ b/benchmark/PaddleOCR_DBNet/utils/cal_recall/script.py
@@ -11,17 +11,14 @@ def default_evaluation_params():
default_evaluation_params: Default parameters to use for the validation and evaluation.
"""
return {
- 'IOU_CONSTRAINT': 0.5,
- 'AREA_PRECISION_CONSTRAINT': 0.5,
- 'GT_SAMPLE_NAME_2_ID': 'gt_img_([0-9]+).txt',
- 'DET_SAMPLE_NAME_2_ID': 'res_img_([0-9]+).txt',
- 'LTRB':
- False, # LTRB:2points(left,top,right,bottom) or 4 points(x1,y1,x2,y2,x3,y3,x4,y4)
- 'CRLF': False, # Lines are delimited by Windows CRLF format
- 'CONFIDENCES':
- False, # Detections must include confidence value. AP will be calculated
- 'PER_SAMPLE_RESULTS':
- True # Generate per sample results and produce data for visualization
+ "IOU_CONSTRAINT": 0.5,
+ "AREA_PRECISION_CONSTRAINT": 0.5,
+ "GT_SAMPLE_NAME_2_ID": "gt_img_([0-9]+).txt",
+ "DET_SAMPLE_NAME_2_ID": "res_img_([0-9]+).txt",
+ "LTRB": False, # LTRB:2points(left,top,right,bottom) or 4 points(x1,y1,x2,y2,x3,y3,x4,y4)
+ "CRLF": False, # Lines are delimited by Windows CRLF format
+ "CONFIDENCES": False, # Detections must include confidence value. AP will be calculated
+ "PER_SAMPLE_RESULTS": True, # Generate per sample results and produce data for visualization
}
@@ -32,15 +29,18 @@ def validate_data(gtFilePath, submFilePath, evaluationParams):
If some error detected, the method raises the error
"""
gt = rrc_evaluation_funcs.load_folder_file(
- gtFilePath, evaluationParams['GT_SAMPLE_NAME_2_ID'])
+ gtFilePath, evaluationParams["GT_SAMPLE_NAME_2_ID"]
+ )
subm = rrc_evaluation_funcs.load_folder_file(
- submFilePath, evaluationParams['DET_SAMPLE_NAME_2_ID'], True)
+ submFilePath, evaluationParams["DET_SAMPLE_NAME_2_ID"], True
+ )
# Validate format of GroundTruth
for k in gt:
rrc_evaluation_funcs.validate_lines_in_file(
- k, gt[k], evaluationParams['CRLF'], evaluationParams['LTRB'], True)
+ k, gt[k], evaluationParams["CRLF"], evaluationParams["LTRB"], True
+ )
# Validate format of results
for k in subm:
@@ -48,8 +48,13 @@ def validate_data(gtFilePath, submFilePath, evaluationParams):
raise Exception("The sample %s not present in GT" % k)
rrc_evaluation_funcs.validate_lines_in_file(
- k, subm[k], evaluationParams['CRLF'], evaluationParams['LTRB'],
- False, evaluationParams['CONFIDENCES'])
+ k,
+ subm[k],
+ evaluationParams["CRLF"],
+ evaluationParams["LTRB"],
+ False,
+ evaluationParams["CONFIDENCES"],
+ )
def evaluate_method(gtFilePath, submFilePath, evaluationParams):
@@ -64,7 +69,7 @@ def polygon_from_points(points):
"""
Returns a Polygon object to use with the Polygon2 class from a list of 8 points: x1,y1,x2,y2,x3,y3,x4,y4
"""
- resBoxes = np.empty([1, 8], dtype='int32')
+ resBoxes = np.empty([1, 8], dtype="int32")
resBoxes[0, 0] = int(points[0])
resBoxes[0, 4] = int(points[1])
resBoxes[0, 1] = int(points[2])
@@ -77,7 +82,7 @@ def polygon_from_points(points):
return plg.Polygon(pointMat)
def rectangle_to_polygon(rect):
- resBoxes = np.empty([1, 8], dtype='int32')
+ resBoxes = np.empty([1, 8], dtype="int32")
resBoxes[0, 0] = int(rect.xmin)
resBoxes[0, 4] = int(rect.ymax)
resBoxes[0, 1] = int(rect.xmin)
@@ -93,8 +98,14 @@ def rectangle_to_polygon(rect):
def rectangle_to_points(rect):
points = [
- int(rect.xmin), int(rect.ymax), int(rect.xmax), int(rect.ymax),
- int(rect.xmax), int(rect.ymin), int(rect.xmin), int(rect.ymin)
+ int(rect.xmin),
+ int(rect.ymax),
+ int(rect.xmax),
+ int(rect.ymax),
+ int(rect.xmax),
+ int(rect.ymin),
+ int(rect.xmin),
+ int(rect.ymin),
]
return points
@@ -139,12 +150,14 @@ def compute_ap(confList, matchList, numGtCare):
matchedSum = 0
- Rectangle = namedtuple('Rectangle', 'xmin ymin xmax ymax')
+ Rectangle = namedtuple("Rectangle", "xmin ymin xmax ymax")
gt = rrc_evaluation_funcs.load_folder_file(
- gtFilePath, evaluationParams['GT_SAMPLE_NAME_2_ID'])
+ gtFilePath, evaluationParams["GT_SAMPLE_NAME_2_ID"]
+ )
subm = rrc_evaluation_funcs.load_folder_file(
- submFilePath, evaluationParams['DET_SAMPLE_NAME_2_ID'], True)
+ submFilePath, evaluationParams["DET_SAMPLE_NAME_2_ID"], True
+ )
numGlobalCareGt = 0
numGlobalCareDet = 0
@@ -153,7 +166,6 @@ def compute_ap(confList, matchList, numGtCare):
arrGlobalMatches = []
for resFile in gt:
-
gtFile = gt[resFile] # rrc_evaluation_funcs.decode_utf8(gt[resFile])
recall = 0
precision = 0
@@ -183,14 +195,18 @@ def compute_ap(confList, matchList, numGtCare):
evaluationLog = ""
- pointsList, _, transcriptionsList = rrc_evaluation_funcs.get_tl_line_values_from_file_contents(
- gtFile, evaluationParams['CRLF'], evaluationParams['LTRB'], True,
- False)
+ (
+ pointsList,
+ _,
+ transcriptionsList,
+ ) = rrc_evaluation_funcs.get_tl_line_values_from_file_contents(
+ gtFile, evaluationParams["CRLF"], evaluationParams["LTRB"], True, False
+ )
for n in range(len(pointsList)):
points = pointsList[n]
transcription = transcriptionsList[n]
dontCare = transcription == "###"
- if evaluationParams['LTRB']:
+ if evaluationParams["LTRB"]:
gtRect = Rectangle(*points)
gtPol = rectangle_to_polygon(gtRect)
else:
@@ -200,22 +216,34 @@ def compute_ap(confList, matchList, numGtCare):
if dontCare:
gtDontCarePolsNum.append(len(gtPols) - 1)
- evaluationLog += "GT polygons: " + str(len(gtPols)) + (
- " (" + str(len(gtDontCarePolsNum)) + " don't care)\n"
- if len(gtDontCarePolsNum) > 0 else "\n")
+ evaluationLog += (
+ "GT polygons: "
+ + str(len(gtPols))
+ + (
+ " (" + str(len(gtDontCarePolsNum)) + " don't care)\n"
+ if len(gtDontCarePolsNum) > 0
+ else "\n"
+ )
+ )
if resFile in subm:
-
- detFile = subm[
- resFile] # rrc_evaluation_funcs.decode_utf8(subm[resFile])
-
- pointsList, confidencesList, _ = rrc_evaluation_funcs.get_tl_line_values_from_file_contents(
- detFile, evaluationParams['CRLF'], evaluationParams['LTRB'],
- False, evaluationParams['CONFIDENCES'])
+ detFile = subm[resFile] # rrc_evaluation_funcs.decode_utf8(subm[resFile])
+
+ (
+ pointsList,
+ confidencesList,
+ _,
+ ) = rrc_evaluation_funcs.get_tl_line_values_from_file_contents(
+ detFile,
+ evaluationParams["CRLF"],
+ evaluationParams["LTRB"],
+ False,
+ evaluationParams["CONFIDENCES"],
+ )
for n in range(len(pointsList)):
points = pointsList[n]
- if evaluationParams['LTRB']:
+ if evaluationParams["LTRB"]:
detRect = Rectangle(*points)
detPol = rectangle_to_polygon(detRect)
else:
@@ -227,15 +255,22 @@ def compute_ap(confList, matchList, numGtCare):
dontCarePol = gtPols[dontCarePol]
intersected_area = get_intersection(dontCarePol, detPol)
pdDimensions = detPol.area()
- precision = 0 if pdDimensions == 0 else intersected_area / pdDimensions
- if (precision >
- evaluationParams['AREA_PRECISION_CONSTRAINT']):
+ precision = (
+ 0 if pdDimensions == 0 else intersected_area / pdDimensions
+ )
+ if precision > evaluationParams["AREA_PRECISION_CONSTRAINT"]:
detDontCarePolsNum.append(len(detPols) - 1)
break
- evaluationLog += "DET polygons: " + str(len(detPols)) + (
- " (" + str(len(detDontCarePolsNum)) + " don't care)\n"
- if len(detDontCarePolsNum) > 0 else "\n")
+ evaluationLog += (
+ "DET polygons: "
+ + str(len(detPols))
+ + (
+ " (" + str(len(detDontCarePolsNum)) + " don't care)\n"
+ if len(detDontCarePolsNum) > 0
+ else "\n"
+ )
+ )
if len(gtPols) > 0 and len(detPols) > 0:
# Calculate IoU and precision matrices
@@ -247,24 +282,34 @@ def compute_ap(confList, matchList, numGtCare):
for detNum in range(len(detPols)):
pG = gtPols[gtNum]
pD = detPols[detNum]
- iouMat[gtNum, detNum] = get_intersection_over_union(pD,
- pG)
+ iouMat[gtNum, detNum] = get_intersection_over_union(pD, pG)
for gtNum in range(len(gtPols)):
for detNum in range(len(detPols)):
- if gtRectMat[gtNum] == 0 and detRectMat[
- detNum] == 0 and gtNum not in gtDontCarePolsNum and detNum not in detDontCarePolsNum:
- if iouMat[gtNum, detNum] > evaluationParams[
- 'IOU_CONSTRAINT']:
+ if (
+ gtRectMat[gtNum] == 0
+ and detRectMat[detNum] == 0
+ and gtNum not in gtDontCarePolsNum
+ and detNum not in detDontCarePolsNum
+ ):
+ if (
+ iouMat[gtNum, detNum]
+ > evaluationParams["IOU_CONSTRAINT"]
+ ):
gtRectMat[gtNum] = 1
detRectMat[detNum] = 1
detMatched += 1
- pairs.append({'gt': gtNum, 'det': detNum})
+ pairs.append({"gt": gtNum, "det": detNum})
detMatchedNums.append(detNum)
- evaluationLog += "Match GT #" + str(
- gtNum) + " with Det #" + str(detNum) + "\n"
-
- if evaluationParams['CONFIDENCES']:
+ evaluationLog += (
+ "Match GT #"
+ + str(gtNum)
+ + " with Det #"
+ + str(detNum)
+ + "\n"
+ )
+
+ if evaluationParams["CONFIDENCES"]:
for detNum in range(len(detPols)):
if detNum not in detDontCarePolsNum:
# we exclude the don't care detections
@@ -276,8 +321,8 @@ def compute_ap(confList, matchList, numGtCare):
arrGlobalConfidences.append(confidencesList[detNum])
arrGlobalMatches.append(match)
- numGtCare = (len(gtPols) - len(gtDontCarePolsNum))
- numDetCare = (len(detPols) - len(detDontCarePolsNum))
+ numGtCare = len(gtPols) - len(gtDontCarePolsNum)
+ numDetCare = len(detPols) - len(detDontCarePolsNum)
if numGtCare == 0:
recall = float(1)
precision = float(0) if numDetCare > 0 else float(1)
@@ -285,66 +330,73 @@ def compute_ap(confList, matchList, numGtCare):
else:
recall = float(detMatched) / numGtCare
precision = 0 if numDetCare == 0 else float(detMatched) / numDetCare
- if evaluationParams['CONFIDENCES'] and evaluationParams[
- 'PER_SAMPLE_RESULTS']:
- sampleAP = compute_ap(arrSampleConfidences, arrSampleMatch,
- numGtCare)
-
- hmean = 0 if (precision + recall) == 0 else 2.0 * precision * recall / (
- precision + recall)
+ if (
+ evaluationParams["CONFIDENCES"]
+ and evaluationParams["PER_SAMPLE_RESULTS"]
+ ):
+ sampleAP = compute_ap(arrSampleConfidences, arrSampleMatch, numGtCare)
+
+ hmean = (
+ 0
+ if (precision + recall) == 0
+ else 2.0 * precision * recall / (precision + recall)
+ )
matchedSum += detMatched
numGlobalCareGt += numGtCare
numGlobalCareDet += numDetCare
- if evaluationParams['PER_SAMPLE_RESULTS']:
+ if evaluationParams["PER_SAMPLE_RESULTS"]:
perSampleMetrics[resFile] = {
- 'precision': precision,
- 'recall': recall,
- 'hmean': hmean,
- 'pairs': pairs,
- 'AP': sampleAP,
- 'iouMat': [] if len(detPols) > 100 else iouMat.tolist(),
- 'gtPolPoints': gtPolPoints,
- 'detPolPoints': detPolPoints,
- 'gtDontCare': gtDontCarePolsNum,
- 'detDontCare': detDontCarePolsNum,
- 'evaluationParams': evaluationParams,
- 'evaluationLog': evaluationLog
+ "precision": precision,
+ "recall": recall,
+ "hmean": hmean,
+ "pairs": pairs,
+ "AP": sampleAP,
+ "iouMat": [] if len(detPols) > 100 else iouMat.tolist(),
+ "gtPolPoints": gtPolPoints,
+ "detPolPoints": detPolPoints,
+ "gtDontCare": gtDontCarePolsNum,
+ "detDontCare": detDontCarePolsNum,
+ "evaluationParams": evaluationParams,
+ "evaluationLog": evaluationLog,
}
# Compute MAP and MAR
AP = 0
- if evaluationParams['CONFIDENCES']:
+ if evaluationParams["CONFIDENCES"]:
AP = compute_ap(arrGlobalConfidences, arrGlobalMatches, numGlobalCareGt)
- methodRecall = 0 if numGlobalCareGt == 0 else float(
- matchedSum) / numGlobalCareGt
- methodPrecision = 0 if numGlobalCareDet == 0 else float(
- matchedSum) / numGlobalCareDet
- methodHmean = 0 if methodRecall + methodPrecision == 0 else 2 * methodRecall * methodPrecision / (
- methodRecall + methodPrecision)
+ methodRecall = 0 if numGlobalCareGt == 0 else float(matchedSum) / numGlobalCareGt
+ methodPrecision = (
+ 0 if numGlobalCareDet == 0 else float(matchedSum) / numGlobalCareDet
+ )
+ methodHmean = (
+ 0
+ if methodRecall + methodPrecision == 0
+ else 2 * methodRecall * methodPrecision / (methodRecall + methodPrecision)
+ )
methodMetrics = {
- 'precision': methodPrecision,
- 'recall': methodRecall,
- 'hmean': methodHmean,
- 'AP': AP
+ "precision": methodPrecision,
+ "recall": methodRecall,
+ "hmean": methodHmean,
+ "AP": AP,
}
resDict = {
- 'calculated': True,
- 'Message': '',
- 'method': methodMetrics,
- 'per_sample': perSampleMetrics
+ "calculated": True,
+ "Message": "",
+ "method": methodMetrics,
+ "per_sample": perSampleMetrics,
}
return resDict
def cal_recall_precison_f1(gt_path, result_path, show_result=False):
- p = {'g': gt_path, 's': result_path}
- result = rrc_evaluation_funcs.main_evaluation(p, default_evaluation_params,
- validate_data,
- evaluate_method, show_result)
- return result['method']
+ p = {"g": gt_path, "s": result_path}
+ result = rrc_evaluation_funcs.main_evaluation(
+ p, default_evaluation_params, validate_data, evaluate_method, show_result
+ )
+ return result["method"]
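Usage of the wrapper above, for orientation: cal_recall_precison_f1 feeds the
ground-truth and detection folders through main_evaluation and returns only
the method-level metrics. A sketch (the folder names are hypothetical; the
files inside must follow the gt_img_N.txt / res_img_N.txt naming from
default_evaluation_params):

    # from utils.cal_recall.script import cal_recall_precison_f1
    metrics = cal_recall_precison_f1("./gt", "./results", show_result=False)
    print(metrics)  # {'precision': ..., 'recall': ..., 'hmean': ..., 'AP': ...}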
diff --git a/benchmark/PaddleOCR_DBNet/utils/compute_mean_std.py b/benchmark/PaddleOCR_DBNet/utils/compute_mean_std.py
index 5d0ab5cd23..c76015eccc 100644
--- a/benchmark/PaddleOCR_DBNet/utils/compute_mean_std.py
+++ b/benchmark/PaddleOCR_DBNet/utils/compute_mean_std.py
@@ -7,8 +7,9 @@
import os
import random
from tqdm import tqdm
+
# calculate means and std
-train_txt_path = './train_val_list.txt'
+train_txt_path = "./train_val_list.txt"
CNum = 10000  # number of images sampled for the statistics
@@ -16,12 +17,12 @@
imgs = np.zeros([img_w, img_h, 3, 1])
means, stdevs = [], []
-with open(train_txt_path, 'r') as f:
+with open(train_txt_path, "r") as f:
lines = f.readlines()
random.shuffle(lines)  # shuffle and pick images at random
for i in tqdm(range(CNum)):
- img_path = lines[i].split('\t')[0]
+ img_path = lines[i].split("\t")[0]
img = cv2.imread(img_path)
img = cv2.resize(img, (img_h, img_w))
@@ -30,7 +31,7 @@
imgs = np.concatenate((imgs, img), axis=3)
# print(i)
-imgs = imgs.astype(np.float32) / 255.
+imgs = imgs.astype(np.float32) / 255.0
for i in tqdm(range(3)):
pixels = imgs[:, :, i, :].ravel()  # flatten to one row
@@ -43,4 +44,4 @@
print("normMean = {}".format(means))
print("normStd = {}".format(stdevs))
-print('transforms.Normalize(normMean = {}, normStd = {})'.format(means, stdevs))
\ No newline at end of file
+print("transforms.Normalize(normMean = {}, normStd = {})".format(means, stdevs))
diff --git a/benchmark/PaddleOCR_DBNet/utils/make_trainfile.py b/benchmark/PaddleOCR_DBNet/utils/make_trainfile.py
index 9b7ae70ff7..3232b011e0 100644
--- a/benchmark/PaddleOCR_DBNet/utils/make_trainfile.py
+++ b/benchmark/PaddleOCR_DBNet/utils/make_trainfile.py
@@ -5,17 +5,17 @@
import glob
import pathlib
-data_path = r'test'
+data_path = r"test"
# data_path/img holds the images
# data_path/gt holds the label files
-f_w = open(os.path.join(data_path, 'test.txt'), 'w', encoding='utf8')
-for img_path in glob.glob(data_path + '/img/*.jpg', recursive=True):
+f_w = open(os.path.join(data_path, "test.txt"), "w", encoding="utf8")
+for img_path in glob.glob(data_path + "/img/*.jpg", recursive=True):
d = pathlib.Path(img_path)
- label_path = os.path.join(data_path, 'gt', ('gt_' + str(d.stem) + '.txt'))
+ label_path = os.path.join(data_path, "gt", ("gt_" + str(d.stem) + ".txt"))
if os.path.exists(img_path) and os.path.exists(label_path):
print(img_path, label_path)
else:
- print('不存在', img_path, label_path)
- f_w.write('{}\t{}\n'.format(img_path, label_path))
-f_w.close()
\ No newline at end of file
+ print("不存在", img_path, label_path)
+ f_w.write("{}\t{}\n".format(img_path, label_path))
+f_w.close()
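make_trainfile.py emits one image_path<TAB>label_path pair per line; a
consumer splits each line on the tab. A sketch (the index path is the one the
script writes under the hypothetical data_path "test"):

    with open("test/test.txt", encoding="utf8") as f:
        pairs = [line.rstrip("\n").split("\t") for line in f if line.strip()]
    for img_path, label_path in pairs:
        print(img_path, label_path)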
diff --git a/benchmark/PaddleOCR_DBNet/utils/metrics.py b/benchmark/PaddleOCR_DBNet/utils/metrics.py
index e9c54b8d2e..81aa9bb4fd 100644
--- a/benchmark/PaddleOCR_DBNet/utils/metrics.py
+++ b/benchmark/PaddleOCR_DBNet/utils/metrics.py
@@ -16,42 +16,44 @@ def _fast_hist(self, label_true, label_pred, n_class):
print(label_pred[label_pred < 0])
hist = np.bincount(
n_class * label_true[mask].astype(int) + label_pred[mask],
- minlength=n_class**2).reshape(n_class, n_class)
+ minlength=n_class**2,
+ ).reshape(n_class, n_class)
return hist
def update(self, label_trues, label_preds):
# print label_trues.dtype, label_preds.dtype
for lt, lp in zip(label_trues, label_preds):
try:
- self.confusion_matrix += self._fast_hist(lt.flatten(),
- lp.flatten(),
- self.n_classes)
+ self.confusion_matrix += self._fast_hist(
+ lt.flatten(), lp.flatten(), self.n_classes
+ )
except:
pass
def get_scores(self):
"""Returns accuracy score evaluation result.
- - overall accuracy
- - mean accuracy
- - mean IU
- - fwavacc
+ - overall accuracy
+ - mean accuracy
+ - mean IU
+ - fwavacc
"""
hist = self.confusion_matrix
acc = np.diag(hist).sum() / (hist.sum() + 0.0001)
acc_cls = np.diag(hist) / (hist.sum(axis=1) + 0.0001)
acc_cls = np.nanmean(acc_cls)
iu = np.diag(hist) / (
- hist.sum(axis=1) + hist.sum(axis=0) - np.diag(hist) + 0.0001)
+ hist.sum(axis=1) + hist.sum(axis=0) - np.diag(hist) + 0.0001
+ )
mean_iu = np.nanmean(iu)
freq = hist.sum(axis=1) / (hist.sum() + 0.0001)
fwavacc = (freq[freq > 0] * iu[freq > 0]).sum()
cls_iu = dict(zip(range(self.n_classes), iu))
return {
- 'Overall Acc': acc,
- 'Mean Acc': acc_cls,
- 'FreqW Acc': fwavacc,
- 'Mean IoU': mean_iu,
+ "Overall Acc": acc,
+ "Mean Acc": acc_cls,
+ "FreqW Acc": fwavacc,
+ "Mean IoU": mean_iu,
}, cls_iu
def reset(self):
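The bincount trick in _fast_hist builds the n_class x n_class confusion matrix
in one pass: each (true, pred) pair is encoded as true * n_class + pred, the
flat bins are counted, and the result is reshaped. A standalone check of that
encoding:

    import numpy as np

    n_class = 3
    label_true = np.array([0, 1, 2, 2])
    label_pred = np.array([0, 1, 1, 2])

    # Rows index true labels, columns index predictions.
    hist = np.bincount(
        n_class * label_true + label_pred, minlength=n_class**2
    ).reshape(n_class, n_class)
    assert hist[2, 1] == 1 and hist.trace() == 3  # one confusion, three correct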
diff --git a/benchmark/PaddleOCR_DBNet/utils/ocr_metric/__init__.py b/benchmark/PaddleOCR_DBNet/utils/ocr_metric/__init__.py
index 3e7c51cf06..005f39cb59 100644
--- a/benchmark/PaddleOCR_DBNet/utils/ocr_metric/__init__.py
+++ b/benchmark/PaddleOCR_DBNet/utils/ocr_metric/__init__.py
@@ -6,14 +6,14 @@
def get_metric(config):
try:
- if 'args' not in config:
+ if "args" not in config:
args = {}
else:
- args = config['args']
+ args = config["args"]
if isinstance(args, dict):
- cls = eval(config['type'])(**args)
+ cls = eval(config["type"])(**args)
else:
- cls = eval(config['type'])(args)
+ cls = eval(config["type"])(args)
return cls
except:
- return None
\ No newline at end of file
+ return None
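get_metric instantiates the class named by config["type"] via eval, expanding
config["args"] as keyword arguments when it is a dict. A self-contained sketch
of that dispatch (the stand-in class and its constructor argument are
assumptions for illustration):

    class QuadMetric:  # stand-in for the real class re-exported below
        def __init__(self, is_output_polygon=False):
            self.is_output_polygon = is_output_polygon

    config = {"type": "QuadMetric", "args": {"is_output_polygon": False}}
    args = config.get("args", {})
    metric = eval(config["type"])(**args)
    assert isinstance(metric, QuadMetric)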
diff --git a/benchmark/PaddleOCR_DBNet/utils/ocr_metric/icdar2015/__init__.py b/benchmark/PaddleOCR_DBNet/utils/ocr_metric/icdar2015/__init__.py
index 375ae557e9..5fd0e454d8 100644
--- a/benchmark/PaddleOCR_DBNet/utils/ocr_metric/icdar2015/__init__.py
+++ b/benchmark/PaddleOCR_DBNet/utils/ocr_metric/icdar2015/__init__.py
@@ -2,4 +2,4 @@
# @Time : 2019/12/5 15:36
# @Author : zhoujun
-from .quad_metric import QuadMetric
\ No newline at end of file
+from .quad_metric import QuadMetric
diff --git a/benchmark/PaddleOCR_DBNet/utils/ocr_metric/icdar2015/detection/deteval.py b/benchmark/PaddleOCR_DBNet/utils/ocr_metric/icdar2015/detection/deteval.py
index c5dcfc4b96..a23839e3d6 100644
--- a/benchmark/PaddleOCR_DBNet/utils/ocr_metric/icdar2015/detection/deteval.py
+++ b/benchmark/PaddleOCR_DBNet/utils/ocr_metric/icdar2015/detection/deteval.py
@@ -7,14 +7,15 @@
class DetectionDetEvalEvaluator(object):
- def __init__(self,
- area_recall_constraint=0.8,
- area_precision_constraint=0.4,
- ev_param_ind_center_diff_thr=1,
- mtype_oo_o=1.0,
- mtype_om_o=0.8,
- mtype_om_m=1.0):
-
+ def __init__(
+ self,
+ area_recall_constraint=0.8,
+ area_precision_constraint=0.4,
+ ev_param_ind_center_diff_thr=1,
+ mtype_oo_o=1.0,
+ mtype_om_o=0.8,
+ mtype_om_m=1.0,
+ ):
self.area_recall_constraint = area_recall_constraint
self.area_precision_constraint = area_precision_constraint
self.ev_param_ind_center_diff_thr = ev_param_ind_center_diff_thr
@@ -35,24 +36,27 @@ def get_intersection(pD, pG):
def one_to_one_match(row, col):
cont = 0
for j in range(len(recallMat[0])):
- if recallMat[row,
- j] >= self.area_recall_constraint and precisionMat[
- row, j] >= self.area_precision_constraint:
+ if (
+ recallMat[row, j] >= self.area_recall_constraint
+ and precisionMat[row, j] >= self.area_precision_constraint
+ ):
cont = cont + 1
- if (cont != 1):
+ if cont != 1:
return False
cont = 0
for i in range(len(recallMat)):
- if recallMat[
- i, col] >= self.area_recall_constraint and precisionMat[
- i, col] >= self.area_precision_constraint:
+ if (
+ recallMat[i, col] >= self.area_recall_constraint
+ and precisionMat[i, col] >= self.area_precision_constraint
+ ):
cont = cont + 1
- if (cont != 1):
+ if cont != 1:
return False
- if recallMat[row,
- col] >= self.area_recall_constraint and precisionMat[
- row, col] >= self.area_precision_constraint:
+ if (
+ recallMat[row, col] >= self.area_recall_constraint
+ and precisionMat[row, col] >= self.area_precision_constraint
+ ):
return True
return False
@@ -82,10 +86,12 @@ def one_to_many_match(gtNum):
many_sum = 0
detRects = []
for detNum in range(len(recallMat[0])):
- if gtRectMat[gtNum] == 0 and detRectMat[
- detNum] == 0 and detNum not in detDontCareRectsNum:
- if precisionMat[gtNum,
- detNum] >= self.area_precision_constraint:
+ if (
+ gtRectMat[gtNum] == 0
+ and detRectMat[detNum] == 0
+ and detNum not in detDontCareRectsNum
+ ):
+ if precisionMat[gtNum, detNum] >= self.area_precision_constraint:
many_sum += recallMat[gtNum, detNum]
detRects.append(detNum)
if round(many_sum, 4) >= self.area_recall_constraint:
@@ -97,8 +103,11 @@ def many_to_one_match(detNum):
many_sum = 0
gtRects = []
for gtNum in range(len(recallMat)):
- if gtRectMat[gtNum] == 0 and detRectMat[
- detNum] == 0 and gtNum not in gtDontCareRectsNum:
+ if (
+ gtRectMat[gtNum] == 0
+ and detRectMat[detNum] == 0
+ and gtNum not in gtDontCareRectsNum
+ ):
if recallMat[gtNum, detNum] >= self.area_recall_constraint:
many_sum += precisionMat[gtNum, detNum]
gtRects.append(gtNum)
@@ -108,28 +117,32 @@ def many_to_one_match(detNum):
return False, []
def center_distance(r1, r2):
- return ((np.mean(r1, axis=0) - np.mean(r2, axis=0))**2).sum()**0.5
+ return ((np.mean(r1, axis=0) - np.mean(r2, axis=0)) ** 2).sum() ** 0.5
def diag(r):
r = np.array(r)
- return ((r[:, 0].max() - r[:, 0].min())**2 +
- (r[:, 1].max() - r[:, 1].min())**2)**0.5
+ return (
+ (r[:, 0].max() - r[:, 0].min()) ** 2
+ + (r[:, 1].max() - r[:, 1].min()) ** 2
+ ) ** 0.5
perSampleMetrics = {}
recall = 0
precision = 0
hmean = 0
- recallAccum = 0.
- precisionAccum = 0.
+ recallAccum = 0.0
+ precisionAccum = 0.0
gtRects = []
detRects = []
gtPolPoints = []
detPolPoints = []
- gtDontCareRectsNum = [
- ] #Array of Ground Truth Rectangles' keys marked as don't Care
- detDontCareRectsNum = [
- ] #Array of Detected Rectangles' matched with a don't Care GT
+ gtDontCareRectsNum = (
+ []
+ ) # Array of Ground Truth Rectangles' keys marked as don't Care
+ detDontCareRectsNum = (
+ []
+ ) # Array of Detected Rectangles' matched with a don't Care GT
pairs = []
evaluationLog = ""
@@ -137,9 +150,9 @@ def diag(r):
precisionMat = np.empty([1, 1])
for n in range(len(gt)):
- points = gt[n]['points']
+ points = gt[n]["points"]
# transcription = gt[n]['text']
- dontCare = gt[n]['ignore']
+ dontCare = gt[n]["ignore"]
if not Polygon(points).is_valid or not Polygon(points).is_simple:
continue
@@ -149,12 +162,18 @@ def diag(r):
if dontCare:
gtDontCareRectsNum.append(len(gtRects) - 1)
- evaluationLog += "GT rectangles: " + str(len(gtRects)) + (
- " (" + str(len(gtDontCareRectsNum)) + " don't care)\n"
- if len(gtDontCareRectsNum) > 0 else "\n")
+ evaluationLog += (
+ "GT rectangles: "
+ + str(len(gtRects))
+ + (
+ " (" + str(len(gtDontCareRectsNum)) + " don't care)\n"
+ if len(gtDontCareRectsNum) > 0
+ else "\n"
+ )
+ )
for n in range(len(pred)):
- points = pred[n]['points']
+ points = pred[n]["points"]
if not Polygon(points).is_valid or not Polygon(points).is_simple:
continue
@@ -167,24 +186,30 @@ def diag(r):
dontCareRect = gtRects[dontCareRectNum]
intersected_area = get_intersection(dontCareRect, detRect)
rdDimensions = Polygon(detRect).area
- if (rdDimensions == 0):
+ if rdDimensions == 0:
precision = 0
else:
precision = intersected_area / rdDimensions
- if (precision > self.area_precision_constraint):
+ if precision > self.area_precision_constraint:
detDontCareRectsNum.append(len(detRects) - 1)
break
- evaluationLog += "DET rectangles: " + str(len(detRects)) + (
- " (" + str(len(detDontCareRectsNum)) + " don't care)\n"
- if len(detDontCareRectsNum) > 0 else "\n")
+ evaluationLog += (
+ "DET rectangles: "
+ + str(len(detRects))
+ + (
+ " (" + str(len(detDontCareRectsNum)) + " don't care)\n"
+ if len(detDontCareRectsNum) > 0
+ else "\n"
+ )
+ )
if len(gtRects) == 0:
recall = 1
precision = 0 if len(detRects) > 0 else 1
if len(detRects) > 0:
- #Calculate recall and precision matrixs
+ # Calculate recall and precision matrices
outputShape = [len(gtRects), len(detRects)]
recallMat = np.empty(outputShape)
precisionMat = np.empty(outputShape)
@@ -197,22 +222,26 @@ def diag(r):
intersected_area = get_intersection(rG, rD)
rgDimensions = Polygon(rG).area
rdDimensions = Polygon(rD).area
- recallMat[
- gtNum,
- detNum] = 0 if rgDimensions == 0 else intersected_area / rgDimensions
- precisionMat[
- gtNum,
- detNum] = 0 if rdDimensions == 0 else intersected_area / rdDimensions
+ recallMat[gtNum, detNum] = (
+ 0 if rgDimensions == 0 else intersected_area / rgDimensions
+ )
+ precisionMat[gtNum, detNum] = (
+ 0 if rdDimensions == 0 else intersected_area / rdDimensions
+ )
# Find one-to-one matches
evaluationLog += "Find one-to-one matches\n"
for gtNum in range(len(gtRects)):
for detNum in range(len(detRects)):
- if gtRectMat[gtNum] == 0 and detRectMat[
- detNum] == 0 and gtNum not in gtDontCareRectsNum and detNum not in detDontCareRectsNum:
+ if (
+ gtRectMat[gtNum] == 0
+ and detRectMat[detNum] == 0
+ and gtNum not in gtDontCareRectsNum
+ and detNum not in detDontCareRectsNum
+ ):
match = one_to_one_match(gtNum, detNum)
if match is True:
- #in deteval we have to make other validation before mark as one-to-one
+ # in deteval we must run extra validation before marking a pair as one-to-one
if is_single_overlap(gtNum, detNum) is True:
rG = gtRects[gtNum]
rD = detRects[detNum]
@@ -224,23 +253,34 @@ def diag(r):
detRectMat[detNum] = 1
recallAccum += self.mtype_oo_o
precisionAccum += self.mtype_oo_o
- pairs.append({
- 'gt': gtNum,
- 'det': detNum,
- 'type': 'OO'
- })
- evaluationLog += "Match GT #" + str(
- gtNum) + " with Det #" + str(
- detNum) + "\n"
+ pairs.append(
+ {"gt": gtNum, "det": detNum, "type": "OO"}
+ )
+ evaluationLog += (
+ "Match GT #"
+ + str(gtNum)
+ + " with Det #"
+ + str(detNum)
+ + "\n"
+ )
else:
- evaluationLog += "Match Discarded GT #" + str(
- gtNum) + " with Det #" + str(
- detNum) + " normDist: " + str(
- normDist) + " \n"
+ evaluationLog += (
+ "Match Discarded GT #"
+ + str(gtNum)
+ + " with Det #"
+ + str(detNum)
+ + " normDist: "
+ + str(normDist)
+ + " \n"
+ )
else:
- evaluationLog += "Match Discarded GT #" + str(
- gtNum) + " with Det #" + str(
- detNum) + " not single overlap\n"
+ evaluationLog += (
+ "Match Discarded GT #"
+ + str(gtNum)
+ + " with Det #"
+ + str(detNum)
+ + " not single overlap\n"
+ )
# Find one-to-many matches
evaluationLog += "Find one-to-many matches\n"
for gtNum in range(len(gtRects)):
@@ -248,30 +288,45 @@ def diag(r):
match, matchesDet = one_to_many_match(gtNum)
if match is True:
evaluationLog += "num_overlaps_gt=" + str(
- num_overlaps_gt(gtNum))
- #in deteval we have to make other validation before mark as one-to-one
+ num_overlaps_gt(gtNum)
+ )
+ # in deteval we must run extra validation before marking a pair as one-to-one
if num_overlaps_gt(gtNum) >= 2:
gtRectMat[gtNum] = 1
- recallAccum += (self.mtype_oo_o
- if len(matchesDet) == 1 else
- self.mtype_om_o)
- precisionAccum += (self.mtype_oo_o
- if len(matchesDet) == 1 else
- self.mtype_om_o *
- len(matchesDet))
- pairs.append({
- 'gt': gtNum,
- 'det': matchesDet,
- 'type': 'OO' if len(matchesDet) == 1 else 'OM'
- })
+ recallAccum += (
+ self.mtype_oo_o
+ if len(matchesDet) == 1
+ else self.mtype_om_o
+ )
+ precisionAccum += (
+ self.mtype_oo_o
+ if len(matchesDet) == 1
+ else self.mtype_om_o * len(matchesDet)
+ )
+ pairs.append(
+ {
+ "gt": gtNum,
+ "det": matchesDet,
+ "type": "OO" if len(matchesDet) == 1 else "OM",
+ }
+ )
for detNum in matchesDet:
detRectMat[detNum] = 1
- evaluationLog += "Match GT #" + str(
- gtNum) + " with Det #" + str(matchesDet) + "\n"
+ evaluationLog += (
+ "Match GT #"
+ + str(gtNum)
+ + " with Det #"
+ + str(matchesDet)
+ + "\n"
+ )
else:
- evaluationLog += "Match Discarded GT #" + str(
- gtNum) + " with Det #" + str(
- matchesDet) + " not single overlap\n"
+ evaluationLog += (
+ "Match Discarded GT #"
+ + str(gtNum)
+ + " with Det #"
+ + str(matchesDet)
+ + " not single overlap\n"
+ )
# Find many-to-one matches
evaluationLog += "Find many-to-one matches\n"
@@ -279,63 +334,81 @@ def diag(r):
if detNum not in detDontCareRectsNum:
match, matchesGt = many_to_one_match(detNum)
if match is True:
- #in deteval we have to make other validation before mark as one-to-one
+ # in deteval we must run extra validation before marking a pair as one-to-one
if num_overlaps_det(detNum) >= 2:
detRectMat[detNum] = 1
- recallAccum += (self.mtype_oo_o
- if len(matchesGt) == 1 else
- self.mtype_om_m * len(matchesGt))
- precisionAccum += (self.mtype_oo_o
- if len(matchesGt) == 1 else
- self.mtype_om_m)
- pairs.append({
- 'gt': matchesGt,
- 'det': detNum,
- 'type': 'OO' if len(matchesGt) == 1 else 'MO'
- })
+ recallAccum += (
+ self.mtype_oo_o
+ if len(matchesGt) == 1
+ else self.mtype_om_m * len(matchesGt)
+ )
+ precisionAccum += (
+ self.mtype_oo_o
+ if len(matchesGt) == 1
+ else self.mtype_om_m
+ )
+ pairs.append(
+ {
+ "gt": matchesGt,
+ "det": detNum,
+ "type": "OO" if len(matchesGt) == 1 else "MO",
+ }
+ )
for gtNum in matchesGt:
gtRectMat[gtNum] = 1
- evaluationLog += "Match GT #" + str(
- matchesGt) + " with Det #" + str(detNum) + "\n"
+ evaluationLog += (
+ "Match GT #"
+ + str(matchesGt)
+ + " with Det #"
+ + str(detNum)
+ + "\n"
+ )
else:
- evaluationLog += "Match Discarded GT #" + str(
- matchesGt) + " with Det #" + str(
- detNum) + " not single overlap\n"
-
- numGtCare = (len(gtRects) - len(gtDontCareRectsNum))
+ evaluationLog += (
+ "Match Discarded GT #"
+ + str(matchesGt)
+ + " with Det #"
+ + str(detNum)
+ + " not single overlap\n"
+ )
+
+ numGtCare = len(gtRects) - len(gtDontCareRectsNum)
if numGtCare == 0:
recall = float(1)
precision = float(0) if len(detRects) > 0 else float(1)
else:
recall = float(recallAccum) / numGtCare
- precision = float(0) if (
- len(detRects) - len(detDontCareRectsNum)
- ) == 0 else float(precisionAccum) / (
- len(detRects) - len(detDontCareRectsNum))
- hmean = 0 if (precision + recall
- ) == 0 else 2.0 * precision * recall / (
- precision + recall)
+ precision = (
+ float(0)
+ if (len(detRects) - len(detDontCareRectsNum)) == 0
+ else float(precisionAccum)
+ / (len(detRects) - len(detDontCareRectsNum))
+ )
+ hmean = (
+ 0
+ if (precision + recall) == 0
+ else 2.0 * precision * recall / (precision + recall)
+ )
numGtCare = len(gtRects) - len(gtDontCareRectsNum)
numDetCare = len(detRects) - len(detDontCareRectsNum)
perSampleMetrics = {
- 'precision': precision,
- 'recall': recall,
- 'hmean': hmean,
- 'pairs': pairs,
- 'recallMat': [] if len(detRects) > 100 else recallMat.tolist(),
- 'precisionMat': []
- if len(detRects) > 100 else precisionMat.tolist(),
- 'gtPolPoints': gtPolPoints,
- 'detPolPoints': detPolPoints,
- 'gtCare': numGtCare,
- 'detCare': numDetCare,
- 'gtDontCare': gtDontCareRectsNum,
- 'detDontCare': detDontCareRectsNum,
- 'recallAccum': recallAccum,
- 'precisionAccum': precisionAccum,
- 'evaluationLog': evaluationLog
+ "precision": precision,
+ "recall": recall,
+ "hmean": hmean,
+ "pairs": pairs,
+ "recallMat": [] if len(detRects) > 100 else recallMat.tolist(),
+ "precisionMat": [] if len(detRects) > 100 else precisionMat.tolist(),
+ "gtPolPoints": gtPolPoints,
+ "detPolPoints": detPolPoints,
+ "gtCare": numGtCare,
+ "detCare": numDetCare,
+ "gtDontCare": gtDontCareRectsNum,
+ "detDontCare": detDontCareRectsNum,
+ "recallAccum": recallAccum,
+ "precisionAccum": precisionAccum,
+ "evaluationLog": evaluationLog,
}
return perSampleMetrics
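combine_results below re-derives the method-level scores from the accumulated
per-sample sums and applies the standard harmonic mean. A worked check of the
hmean formula used throughout these evaluators:

    precision, recall = 0.8, 0.6
    hmean = (
        0
        if (precision + recall) == 0
        else 2.0 * precision * recall / (precision + recall)
    )
    assert abs(hmean - 0.96 / 1.4) < 1e-12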
@@ -347,41 +420,53 @@ def combine_results(self, results):
methodPrecisionSum = 0
for result in results:
- numGt += result['gtCare']
- numDet += result['detCare']
- methodRecallSum += result['recallAccum']
- methodPrecisionSum += result['precisionAccum']
+ numGt += result["gtCare"]
+ numDet += result["detCare"]
+ methodRecallSum += result["recallAccum"]
+ methodPrecisionSum += result["precisionAccum"]
methodRecall = 0 if numGt == 0 else methodRecallSum / numGt
methodPrecision = 0 if numDet == 0 else methodPrecisionSum / numDet
- methodHmean = 0 if methodRecall + methodPrecision == 0 else 2 * methodRecall * methodPrecision / (
- methodRecall + methodPrecision)
+ methodHmean = (
+ 0
+ if methodRecall + methodPrecision == 0
+ else 2 * methodRecall * methodPrecision / (methodRecall + methodPrecision)
+ )
methodMetrics = {
- 'precision': methodPrecision,
- 'recall': methodRecall,
- 'hmean': methodHmean
+ "precision": methodPrecision,
+ "recall": methodRecall,
+ "hmean": methodHmean,
}
return methodMetrics
-if __name__ == '__main__':
+if __name__ == "__main__":
evaluator = DetectionDetEvalEvaluator()
- gts = [[{
- 'points': [(0, 0), (1, 0), (1, 1), (0, 1)],
- 'text': 1234,
- 'ignore': False,
- }, {
- 'points': [(2, 2), (3, 2), (3, 3), (2, 3)],
- 'text': 5678,
- 'ignore': True,
- }]]
- preds = [[{
- 'points': [(0.1, 0.1), (1, 0), (1, 1), (0, 1)],
- 'text': 123,
- 'ignore': False,
- }]]
+ gts = [
+ [
+ {
+ "points": [(0, 0), (1, 0), (1, 1), (0, 1)],
+ "text": 1234,
+ "ignore": False,
+ },
+ {
+ "points": [(2, 2), (3, 2), (3, 3), (2, 3)],
+ "text": 5678,
+ "ignore": True,
+ },
+ ]
+ ]
+ preds = [
+ [
+ {
+ "points": [(0.1, 0.1), (1, 0), (1, 1), (0, 1)],
+ "text": 123,
+ "ignore": False,
+ }
+ ]
+ ]
results = []
for gt, pred in zip(gts, preds):
results.append(evaluator.evaluate_image(gt, pred))
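The __main__ block above doubles as documentation of the per-image input
schema: each box is a dict with "points" (four (x, y) vertices), "text", and
"ignore". The same driver pattern fits every evaluator in this package (a
sketch; the import path is an assumption based on this repo's layout):

    # from utils.ocr_metric.icdar2015.detection.deteval import (
    #     DetectionDetEvalEvaluator,
    # )
    box = {"points": [(0, 0), (1, 0), (1, 1), (0, 1)], "text": "abc", "ignore": False}
    gts, preds = [[box]], [[dict(box)]]
    # evaluator = DetectionDetEvalEvaluator()
    # results = [evaluator.evaluate_image(g, p) for g, p in zip(gts, preds)]
    # print(evaluator.combine_results(results))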
diff --git a/benchmark/PaddleOCR_DBNet/utils/ocr_metric/icdar2015/detection/icdar2013.py b/benchmark/PaddleOCR_DBNet/utils/ocr_metric/icdar2015/detection/icdar2013.py
index 7e8c86aae3..d29c2c0c0e 100644
--- a/benchmark/PaddleOCR_DBNet/utils/ocr_metric/icdar2015/detection/icdar2013.py
+++ b/benchmark/PaddleOCR_DBNet/utils/ocr_metric/icdar2015/detection/icdar2013.py
@@ -7,14 +7,15 @@
class DetectionICDAR2013Evaluator(object):
- def __init__(self,
- area_recall_constraint=0.8,
- area_precision_constraint=0.4,
- ev_param_ind_center_diff_thr=1,
- mtype_oo_o=1.0,
- mtype_om_o=0.8,
- mtype_om_m=1.0):
-
+ def __init__(
+ self,
+ area_recall_constraint=0.8,
+ area_precision_constraint=0.4,
+ ev_param_ind_center_diff_thr=1,
+ mtype_oo_o=1.0,
+ mtype_om_o=0.8,
+ mtype_om_m=1.0,
+ ):
self.area_recall_constraint = area_recall_constraint
self.area_precision_constraint = area_precision_constraint
self.ev_param_ind_center_diff_thr = ev_param_ind_center_diff_thr
@@ -35,24 +36,27 @@ def get_intersection(pD, pG):
def one_to_one_match(row, col):
cont = 0
for j in range(len(recallMat[0])):
- if recallMat[row,
- j] >= self.area_recall_constraint and precisionMat[
- row, j] >= self.area_precision_constraint:
+ if (
+ recallMat[row, j] >= self.area_recall_constraint
+ and precisionMat[row, j] >= self.area_precision_constraint
+ ):
cont = cont + 1
- if (cont != 1):
+ if cont != 1:
return False
cont = 0
for i in range(len(recallMat)):
- if recallMat[
- i, col] >= self.area_recall_constraint and precisionMat[
- i, col] >= self.area_precision_constraint:
+ if (
+ recallMat[i, col] >= self.area_recall_constraint
+ and precisionMat[i, col] >= self.area_precision_constraint
+ ):
cont = cont + 1
- if (cont != 1):
+ if cont != 1:
return False
- if recallMat[row,
- col] >= self.area_recall_constraint and precisionMat[
- row, col] >= self.area_precision_constraint:
+ if (
+ recallMat[row, col] >= self.area_recall_constraint
+ and precisionMat[row, col] >= self.area_precision_constraint
+ ):
return True
return False
@@ -60,10 +64,12 @@ def one_to_many_match(gtNum):
many_sum = 0
detRects = []
for detNum in range(len(recallMat[0])):
- if gtRectMat[gtNum] == 0 and detRectMat[
- detNum] == 0 and detNum not in detDontCareRectsNum:
- if precisionMat[gtNum,
- detNum] >= self.area_precision_constraint:
+ if (
+ gtRectMat[gtNum] == 0
+ and detRectMat[detNum] == 0
+ and detNum not in detDontCareRectsNum
+ ):
+ if precisionMat[gtNum, detNum] >= self.area_precision_constraint:
many_sum += recallMat[gtNum, detNum]
detRects.append(detNum)
if round(many_sum, 4) >= self.area_recall_constraint:
@@ -75,8 +81,11 @@ def many_to_one_match(detNum):
many_sum = 0
gtRects = []
for gtNum in range(len(recallMat)):
- if gtRectMat[gtNum] == 0 and detRectMat[
- detNum] == 0 and gtNum not in gtDontCareRectsNum:
+ if (
+ gtRectMat[gtNum] == 0
+ and detRectMat[detNum] == 0
+ and gtNum not in gtDontCareRectsNum
+ ):
if recallMat[gtNum, detNum] >= self.area_recall_constraint:
many_sum += precisionMat[gtNum, detNum]
gtRects.append(gtNum)
@@ -86,28 +95,32 @@ def many_to_one_match(detNum):
return False, []
def center_distance(r1, r2):
- return ((np.mean(r1, axis=0) - np.mean(r2, axis=0))**2).sum()**0.5
+ return ((np.mean(r1, axis=0) - np.mean(r2, axis=0)) ** 2).sum() ** 0.5
def diag(r):
r = np.array(r)
- return ((r[:, 0].max() - r[:, 0].min())**2 +
- (r[:, 1].max() - r[:, 1].min())**2)**0.5
+ return (
+ (r[:, 0].max() - r[:, 0].min()) ** 2
+ + (r[:, 1].max() - r[:, 1].min()) ** 2
+ ) ** 0.5
perSampleMetrics = {}
recall = 0
precision = 0
hmean = 0
- recallAccum = 0.
- precisionAccum = 0.
+ recallAccum = 0.0
+ precisionAccum = 0.0
gtRects = []
detRects = []
gtPolPoints = []
detPolPoints = []
- gtDontCareRectsNum = [
- ] #Array of Ground Truth Rectangles' keys marked as don't Care
- detDontCareRectsNum = [
- ] #Array of Detected Rectangles' matched with a don't Care GT
+ gtDontCareRectsNum = (
+ []
+ ) # Array of Ground Truth Rectangles' keys marked as don't Care
+ detDontCareRectsNum = (
+ []
+ ) # Array of Detected Rectangles' matched with a don't Care GT
pairs = []
evaluationLog = ""
@@ -115,9 +128,9 @@ def diag(r):
precisionMat = np.empty([1, 1])
for n in range(len(gt)):
- points = gt[n]['points']
+ points = gt[n]["points"]
# transcription = gt[n]['text']
- dontCare = gt[n]['ignore']
+ dontCare = gt[n]["ignore"]
if not Polygon(points).is_valid or not Polygon(points).is_simple:
continue
@@ -127,12 +140,18 @@ def diag(r):
if dontCare:
gtDontCareRectsNum.append(len(gtRects) - 1)
- evaluationLog += "GT rectangles: " + str(len(gtRects)) + (
- " (" + str(len(gtDontCareRectsNum)) + " don't care)\n"
- if len(gtDontCareRectsNum) > 0 else "\n")
+ evaluationLog += (
+ "GT rectangles: "
+ + str(len(gtRects))
+ + (
+ " (" + str(len(gtDontCareRectsNum)) + " don't care)\n"
+ if len(gtDontCareRectsNum) > 0
+ else "\n"
+ )
+ )
for n in range(len(pred)):
- points = pred[n]['points']
+ points = pred[n]["points"]
if not Polygon(points).is_valid or not Polygon(points).is_simple:
continue
@@ -145,24 +164,30 @@ def diag(r):
dontCareRect = gtRects[dontCareRectNum]
intersected_area = get_intersection(dontCareRect, detRect)
rdDimensions = Polygon(detRect).area
- if (rdDimensions == 0):
+ if rdDimensions == 0:
precision = 0
else:
precision = intersected_area / rdDimensions
- if (precision > self.area_precision_constraint):
+ if precision > self.area_precision_constraint:
detDontCareRectsNum.append(len(detRects) - 1)
break
- evaluationLog += "DET rectangles: " + str(len(detRects)) + (
- " (" + str(len(detDontCareRectsNum)) + " don't care)\n"
- if len(detDontCareRectsNum) > 0 else "\n")
+ evaluationLog += (
+ "DET rectangles: "
+ + str(len(detRects))
+ + (
+ " (" + str(len(detDontCareRectsNum)) + " don't care)\n"
+ if len(detDontCareRectsNum) > 0
+ else "\n"
+ )
+ )
if len(gtRects) == 0:
recall = 1
precision = 0 if len(detRects) > 0 else 1
if len(detRects) > 0:
- #Calculate recall and precision matrixs
+ # Calculate recall and precision matrices
outputShape = [len(gtRects), len(detRects)]
recallMat = np.empty(outputShape)
precisionMat = np.empty(outputShape)
@@ -175,22 +200,26 @@ def diag(r):
intersected_area = get_intersection(rG, rD)
rgDimensions = Polygon(rG).area
rdDimensions = Polygon(rD).area
- recallMat[
- gtNum,
- detNum] = 0 if rgDimensions == 0 else intersected_area / rgDimensions
- precisionMat[
- gtNum,
- detNum] = 0 if rdDimensions == 0 else intersected_area / rdDimensions
+ recallMat[gtNum, detNum] = (
+ 0 if rgDimensions == 0 else intersected_area / rgDimensions
+ )
+ precisionMat[gtNum, detNum] = (
+ 0 if rdDimensions == 0 else intersected_area / rdDimensions
+ )
# Find one-to-one matches
evaluationLog += "Find one-to-one matches\n"
for gtNum in range(len(gtRects)):
for detNum in range(len(detRects)):
- if gtRectMat[gtNum] == 0 and detRectMat[
- detNum] == 0 and gtNum not in gtDontCareRectsNum and detNum not in detDontCareRectsNum:
+ if (
+ gtRectMat[gtNum] == 0
+ and detRectMat[detNum] == 0
+ and gtNum not in gtDontCareRectsNum
+ and detNum not in detDontCareRectsNum
+ ):
match = one_to_one_match(gtNum, detNum)
if match is True:
- #in deteval we have to make other validation before mark as one-to-one
+ # in deteval we must run extra validation before marking a pair as one-to-one
rG = gtRects[gtNum]
rD = detRects[detNum]
normDist = center_distance(rG, rD)
@@ -201,18 +230,24 @@ def diag(r):
detRectMat[detNum] = 1
recallAccum += self.mtype_oo_o
precisionAccum += self.mtype_oo_o
- pairs.append({
- 'gt': gtNum,
- 'det': detNum,
- 'type': 'OO'
- })
- evaluationLog += "Match GT #" + str(
- gtNum) + " with Det #" + str(detNum) + "\n"
+ pairs.append({"gt": gtNum, "det": detNum, "type": "OO"})
+ evaluationLog += (
+ "Match GT #"
+ + str(gtNum)
+ + " with Det #"
+ + str(detNum)
+ + "\n"
+ )
else:
- evaluationLog += "Match Discarded GT #" + str(
- gtNum) + " with Det #" + str(
- detNum) + " normDist: " + str(
- normDist) + " \n"
+ evaluationLog += (
+ "Match Discarded GT #"
+ + str(gtNum)
+ + " with Det #"
+ + str(detNum)
+ + " normDist: "
+ + str(normDist)
+ + " \n"
+ )
# Find one-to-many matches
evaluationLog += "Find one-to-many matches\n"
for gtNum in range(len(gtRects)):
@@ -220,22 +255,33 @@ def diag(r):
match, matchesDet = one_to_many_match(gtNum)
if match is True:
evaluationLog += "num_overlaps_gt=" + str(
- num_overlaps_gt(gtNum))
+ num_overlaps_gt(gtNum)
+ )
gtRectMat[gtNum] = 1
- recallAccum += (self.mtype_oo_o if len(matchesDet) == 1
- else self.mtype_om_o)
- precisionAccum += (self.mtype_oo_o
- if len(matchesDet) == 1 else
- self.mtype_om_o * len(matchesDet))
- pairs.append({
- 'gt': gtNum,
- 'det': matchesDet,
- 'type': 'OO' if len(matchesDet) == 1 else 'OM'
- })
+ recallAccum += (
+ self.mtype_oo_o if len(matchesDet) == 1 else self.mtype_om_o
+ )
+ precisionAccum += (
+ self.mtype_oo_o
+ if len(matchesDet) == 1
+ else self.mtype_om_o * len(matchesDet)
+ )
+ pairs.append(
+ {
+ "gt": gtNum,
+ "det": matchesDet,
+ "type": "OO" if len(matchesDet) == 1 else "OM",
+ }
+ )
for detNum in matchesDet:
detRectMat[detNum] = 1
- evaluationLog += "Match GT #" + str(
- gtNum) + " with Det #" + str(matchesDet) + "\n"
+ evaluationLog += (
+ "Match GT #"
+ + str(gtNum)
+ + " with Det #"
+ + str(matchesDet)
+ + "\n"
+ )
# Find many-to-one matches
evaluationLog += "Find many-to-one matches\n"
@@ -244,55 +290,68 @@ def diag(r):
match, matchesGt = many_to_one_match(detNum)
if match is True:
detRectMat[detNum] = 1
- recallAccum += (self.mtype_oo_o if len(matchesGt) == 1
- else self.mtype_om_m * len(matchesGt))
- precisionAccum += (self.mtype_oo_o
- if len(matchesGt) == 1 else
- self.mtype_om_m)
- pairs.append({
- 'gt': matchesGt,
- 'det': detNum,
- 'type': 'OO' if len(matchesGt) == 1 else 'MO'
- })
+ recallAccum += (
+ self.mtype_oo_o
+ if len(matchesGt) == 1
+ else self.mtype_om_m * len(matchesGt)
+ )
+ precisionAccum += (
+ self.mtype_oo_o if len(matchesGt) == 1 else self.mtype_om_m
+ )
+ pairs.append(
+ {
+ "gt": matchesGt,
+ "det": detNum,
+ "type": "OO" if len(matchesGt) == 1 else "MO",
+ }
+ )
for gtNum in matchesGt:
gtRectMat[gtNum] = 1
- evaluationLog += "Match GT #" + str(
- matchesGt) + " with Det #" + str(detNum) + "\n"
-
- numGtCare = (len(gtRects) - len(gtDontCareRectsNum))
+ evaluationLog += (
+ "Match GT #"
+ + str(matchesGt)
+ + " with Det #"
+ + str(detNum)
+ + "\n"
+ )
+
+ numGtCare = len(gtRects) - len(gtDontCareRectsNum)
if numGtCare == 0:
recall = float(1)
precision = float(0) if len(detRects) > 0 else float(1)
else:
recall = float(recallAccum) / numGtCare
- precision = float(0) if (
- len(detRects) - len(detDontCareRectsNum)
- ) == 0 else float(precisionAccum) / (
- len(detRects) - len(detDontCareRectsNum))
- hmean = 0 if (precision + recall
- ) == 0 else 2.0 * precision * recall / (
- precision + recall)
+ precision = (
+ float(0)
+ if (len(detRects) - len(detDontCareRectsNum)) == 0
+ else float(precisionAccum)
+ / (len(detRects) - len(detDontCareRectsNum))
+ )
+ hmean = (
+ 0
+ if (precision + recall) == 0
+ else 2.0 * precision * recall / (precision + recall)
+ )
numGtCare = len(gtRects) - len(gtDontCareRectsNum)
numDetCare = len(detRects) - len(detDontCareRectsNum)
perSampleMetrics = {
- 'precision': precision,
- 'recall': recall,
- 'hmean': hmean,
- 'pairs': pairs,
- 'recallMat': [] if len(detRects) > 100 else recallMat.tolist(),
- 'precisionMat': []
- if len(detRects) > 100 else precisionMat.tolist(),
- 'gtPolPoints': gtPolPoints,
- 'detPolPoints': detPolPoints,
- 'gtCare': numGtCare,
- 'detCare': numDetCare,
- 'gtDontCare': gtDontCareRectsNum,
- 'detDontCare': detDontCareRectsNum,
- 'recallAccum': recallAccum,
- 'precisionAccum': precisionAccum,
- 'evaluationLog': evaluationLog
+ "precision": precision,
+ "recall": recall,
+ "hmean": hmean,
+ "pairs": pairs,
+ "recallMat": [] if len(detRects) > 100 else recallMat.tolist(),
+ "precisionMat": [] if len(detRects) > 100 else precisionMat.tolist(),
+ "gtPolPoints": gtPolPoints,
+ "detPolPoints": detPolPoints,
+ "gtCare": numGtCare,
+ "detCare": numDetCare,
+ "gtDontCare": gtDontCareRectsNum,
+ "detDontCare": detDontCareRectsNum,
+ "recallAccum": recallAccum,
+ "precisionAccum": precisionAccum,
+ "evaluationLog": evaluationLog,
}
return perSampleMetrics
@@ -304,41 +363,53 @@ def combine_results(self, results):
methodPrecisionSum = 0
for result in results:
- numGt += result['gtCare']
- numDet += result['detCare']
- methodRecallSum += result['recallAccum']
- methodPrecisionSum += result['precisionAccum']
+ numGt += result["gtCare"]
+ numDet += result["detCare"]
+ methodRecallSum += result["recallAccum"]
+ methodPrecisionSum += result["precisionAccum"]
methodRecall = 0 if numGt == 0 else methodRecallSum / numGt
methodPrecision = 0 if numDet == 0 else methodPrecisionSum / numDet
- methodHmean = 0 if methodRecall + methodPrecision == 0 else 2 * methodRecall * methodPrecision / (
- methodRecall + methodPrecision)
+ methodHmean = (
+ 0
+ if methodRecall + methodPrecision == 0
+ else 2 * methodRecall * methodPrecision / (methodRecall + methodPrecision)
+ )
methodMetrics = {
- 'precision': methodPrecision,
- 'recall': methodRecall,
- 'hmean': methodHmean
+ "precision": methodPrecision,
+ "recall": methodRecall,
+ "hmean": methodHmean,
}
return methodMetrics
-if __name__ == '__main__':
+if __name__ == "__main__":
evaluator = DetectionICDAR2013Evaluator()
- gts = [[{
- 'points': [(0, 0), (1, 0), (1, 1), (0, 1)],
- 'text': 1234,
- 'ignore': False,
- }, {
- 'points': [(2, 2), (3, 2), (3, 3), (2, 3)],
- 'text': 5678,
- 'ignore': True,
- }]]
- preds = [[{
- 'points': [(0.1, 0.1), (1, 0), (1, 1), (0, 1)],
- 'text': 123,
- 'ignore': False,
- }]]
+ gts = [
+ [
+ {
+ "points": [(0, 0), (1, 0), (1, 1), (0, 1)],
+ "text": 1234,
+ "ignore": False,
+ },
+ {
+ "points": [(2, 2), (3, 2), (3, 3), (2, 3)],
+ "text": 5678,
+ "ignore": True,
+ },
+ ]
+ ]
+ preds = [
+ [
+ {
+ "points": [(0.1, 0.1), (1, 0), (1, 1), (0, 1)],
+ "text": 123,
+ "ignore": False,
+ }
+ ]
+ ]
results = []
for gt, pred in zip(gts, preds):
results.append(evaluator.evaluate_image(gt, pred))
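one_to_one_match accepts a (row, col) pair only when it is the unique pair in
its row and in its column clearing both the area-recall and area-precision
constraints. A standalone sketch of that rule on a toy pair of matrices:

    import numpy as np

    recallMat = np.array([[0.9, 0.1], [0.2, 0.85]])
    precisionMat = np.array([[0.88, 0.05], [0.1, 0.9]])
    r_thr, p_thr = 0.8, 0.4  # the constraints passed to __init__

    def one_to_one_match(row, col):
        ok = (recallMat >= r_thr) & (precisionMat >= p_thr)
        # Unique qualifying hit in its row and its column, and the pair itself.
        return ok[row].sum() == 1 and ok[:, col].sum() == 1 and bool(ok[row, col])

    assert one_to_one_match(0, 0) and one_to_one_match(1, 1)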
diff --git a/benchmark/PaddleOCR_DBNet/utils/ocr_metric/icdar2015/detection/iou.py b/benchmark/PaddleOCR_DBNet/utils/ocr_metric/icdar2015/detection/iou.py
index 5f9533b3c3..85fde783cb 100644
--- a/benchmark/PaddleOCR_DBNet/utils/ocr_metric/icdar2015/detection/iou.py
+++ b/benchmark/PaddleOCR_DBNet/utils/ocr_metric/icdar2015/detection/iou.py
@@ -6,7 +6,7 @@
import cv2
-def iou_rotate(box_a, box_b, method='union'):
+def iou_rotate(box_a, box_b, method="union"):
rect_a = cv2.minAreaRect(box_a)
rect_b = cv2.minAreaRect(box_b)
r1 = cv2.rotatedRectangleIntersection(rect_a, rect_b)
@@ -19,9 +19,9 @@ def iou_rotate(box_a, box_b, method='union'):
union_area = area_a + area_b - inter_area
if union_area == 0 or inter_area == 0:
return 0
- if method == 'union':
+ if method == "union":
iou = inter_area / union_area
- elif method == 'intersection':
+ elif method == "intersection":
iou = inter_area / min(area_a, area_b)
else:
raise NotImplementedError
@@ -29,10 +29,9 @@ def iou_rotate(box_a, box_b, method='union'):
class DetectionIoUEvaluator(object):
- def __init__(self,
- is_output_polygon=False,
- iou_constraint=0.5,
- area_precision_constraint=0.5):
+ def __init__(
+ self, is_output_polygon=False, iou_constraint=0.5, area_precision_constraint=0.5
+ ):
self.is_output_polygon = is_output_polygon
self.iou_constraint = iou_constraint
self.area_precision_constraint = area_precision_constraint
@@ -71,7 +70,7 @@ def compute_ap(confList, matchList, numGtCare):
matchedSum = 0
- Rectangle = namedtuple('Rectangle', 'xmin ymin xmax ymax')
+ Rectangle = namedtuple("Rectangle", "xmin ymin xmax ymax")
numGlobalCareGt = 0
numGlobalCareDet = 0
@@ -107,9 +106,9 @@ def compute_ap(confList, matchList, numGtCare):
evaluationLog = ""
for n in range(len(gt)):
- points = gt[n]['points']
+ points = gt[n]["points"]
# transcription = gt[n]['text']
- dontCare = gt[n]['ignore']
+ dontCare = gt[n]["ignore"]
if not Polygon(points).is_valid or not Polygon(points).is_simple:
continue
@@ -120,12 +119,18 @@ def compute_ap(confList, matchList, numGtCare):
if dontCare:
gtDontCarePolsNum.append(len(gtPols) - 1)
- evaluationLog += "GT polygons: " + str(len(gtPols)) + (
- " (" + str(len(gtDontCarePolsNum)) + " don't care)\n"
- if len(gtDontCarePolsNum) > 0 else "\n")
+ evaluationLog += (
+ "GT polygons: "
+ + str(len(gtPols))
+ + (
+ " (" + str(len(gtDontCarePolsNum)) + " don't care)\n"
+ if len(gtDontCarePolsNum) > 0
+ else "\n"
+ )
+ )
for n in range(len(pred)):
- points = pred[n]['points']
+ points = pred[n]["points"]
if not Polygon(points).is_valid or not Polygon(points).is_simple:
continue
@@ -137,14 +142,22 @@ def compute_ap(confList, matchList, numGtCare):
dontCarePol = gtPols[dontCarePol]
intersected_area = get_intersection(dontCarePol, detPol)
pdDimensions = Polygon(detPol).area
- precision = 0 if pdDimensions == 0 else intersected_area / pdDimensions
- if (precision > self.area_precision_constraint):
+ precision = (
+ 0 if pdDimensions == 0 else intersected_area / pdDimensions
+ )
+ if precision > self.area_precision_constraint:
detDontCarePolsNum.append(len(detPols) - 1)
break
- evaluationLog += "DET polygons: " + str(len(detPols)) + (
- " (" + str(len(detDontCarePolsNum)) + " don't care)\n"
- if len(detDontCarePolsNum) > 0 else "\n")
+ evaluationLog += (
+ "DET polygons: "
+ + str(len(detPols))
+ + (
+ " (" + str(len(detDontCarePolsNum)) + " don't care)\n"
+ if len(detDontCarePolsNum) > 0
+ else "\n"
+ )
+ )
if len(gtPols) > 0 and len(detPols) > 0:
# Calculate IoU and precision matrices
@@ -157,8 +170,7 @@ def compute_ap(confList, matchList, numGtCare):
for detNum in range(len(detPols)):
pG = gtPols[gtNum]
pD = detPols[detNum]
- iouMat[gtNum, detNum] = get_intersection_over_union(pD,
- pG)
+ iouMat[gtNum, detNum] = get_intersection_over_union(pD, pG)
else:
# gtPols = np.float32(gtPols)
# detPols = np.float32(detPols)
@@ -169,19 +181,28 @@ def compute_ap(confList, matchList, numGtCare):
iouMat[gtNum, detNum] = iou_rotate(pD, pG)
for gtNum in range(len(gtPols)):
for detNum in range(len(detPols)):
- if gtRectMat[gtNum] == 0 and detRectMat[
- detNum] == 0 and gtNum not in gtDontCarePolsNum and detNum not in detDontCarePolsNum:
+ if (
+ gtRectMat[gtNum] == 0
+ and detRectMat[detNum] == 0
+ and gtNum not in gtDontCarePolsNum
+ and detNum not in detDontCarePolsNum
+ ):
if iouMat[gtNum, detNum] > self.iou_constraint:
gtRectMat[gtNum] = 1
detRectMat[detNum] = 1
detMatched += 1
- pairs.append({'gt': gtNum, 'det': detNum})
+ pairs.append({"gt": gtNum, "det": detNum})
detMatchedNums.append(detNum)
- evaluationLog += "Match GT #" + \
- str(gtNum) + " with Det #" + str(detNum) + "\n"
-
- numGtCare = (len(gtPols) - len(gtDontCarePolsNum))
- numDetCare = (len(detPols) - len(detDontCarePolsNum))
+ evaluationLog += (
+ "Match GT #"
+ + str(gtNum)
+ + " with Det #"
+ + str(detNum)
+ + "\n"
+ )
+
+ numGtCare = len(gtPols) - len(gtDontCarePolsNum)
+ numDetCare = len(detPols) - len(detDontCarePolsNum)
if numGtCare == 0:
recall = float(1)
precision = float(0) if numDetCare > 0 else float(1)
@@ -189,27 +210,30 @@ def compute_ap(confList, matchList, numGtCare):
recall = float(detMatched) / numGtCare
precision = 0 if numDetCare == 0 else float(detMatched) / numDetCare
- hmean = 0 if (precision + recall) == 0 else 2.0 * \
- precision * recall / (precision + recall)
+ hmean = (
+ 0
+ if (precision + recall) == 0
+ else 2.0 * precision * recall / (precision + recall)
+ )
matchedSum += detMatched
numGlobalCareGt += numGtCare
numGlobalCareDet += numDetCare
perSampleMetrics = {
- 'precision': precision,
- 'recall': recall,
- 'hmean': hmean,
- 'pairs': pairs,
- 'iouMat': [] if len(detPols) > 100 else iouMat.tolist(),
- 'gtPolPoints': gtPolPoints,
- 'detPolPoints': detPolPoints,
- 'gtCare': numGtCare,
- 'detCare': numDetCare,
- 'gtDontCare': gtDontCarePolsNum,
- 'detDontCare': detDontCarePolsNum,
- 'detMatched': detMatched,
- 'evaluationLog': evaluationLog
+ "precision": precision,
+ "recall": recall,
+ "hmean": hmean,
+ "pairs": pairs,
+ "iouMat": [] if len(detPols) > 100 else iouMat.tolist(),
+ "gtPolPoints": gtPolPoints,
+ "detPolPoints": detPolPoints,
+ "gtCare": numGtCare,
+ "detCare": numDetCare,
+ "gtDontCare": gtDontCarePolsNum,
+ "detDontCare": detDontCarePolsNum,
+ "detMatched": detMatched,
+ "evaluationLog": evaluationLog,
}
return perSampleMetrics
@@ -219,43 +243,56 @@ def combine_results(self, results):
numGlobalCareDet = 0
matchedSum = 0
for result in results:
- numGlobalCareGt += result['gtCare']
- numGlobalCareDet += result['detCare']
- matchedSum += result['detMatched']
-
- methodRecall = 0 if numGlobalCareGt == 0 else float(
- matchedSum) / numGlobalCareGt
- methodPrecision = 0 if numGlobalCareDet == 0 else float(
- matchedSum) / numGlobalCareDet
- methodHmean = 0 if methodRecall + methodPrecision == 0 else 2 * \
- methodRecall * methodPrecision / (
- methodRecall + methodPrecision)
+ numGlobalCareGt += result["gtCare"]
+ numGlobalCareDet += result["detCare"]
+ matchedSum += result["detMatched"]
+
+ methodRecall = (
+ 0 if numGlobalCareGt == 0 else float(matchedSum) / numGlobalCareGt
+ )
+ methodPrecision = (
+ 0 if numGlobalCareDet == 0 else float(matchedSum) / numGlobalCareDet
+ )
+ methodHmean = (
+ 0
+ if methodRecall + methodPrecision == 0
+ else 2 * methodRecall * methodPrecision / (methodRecall + methodPrecision)
+ )
methodMetrics = {
- 'precision': methodPrecision,
- 'recall': methodRecall,
- 'hmean': methodHmean
+ "precision": methodPrecision,
+ "recall": methodRecall,
+ "hmean": methodHmean,
}
return methodMetrics
-if __name__ == '__main__':
+if __name__ == "__main__":
evaluator = DetectionIoUEvaluator()
- preds = [[{
- 'points': [(0.1, 0.1), (0.5, 0), (0.5, 1), (0, 1)],
- 'text': 1234,
- 'ignore': False,
- }, {
- 'points': [(0.5, 0.1), (1, 0), (1, 1), (0.5, 1)],
- 'text': 5678,
- 'ignore': False,
- }]]
- gts = [[{
- 'points': [(0.1, 0.1), (1, 0), (1, 1), (0, 1)],
- 'text': 123,
- 'ignore': False,
- }]]
+ preds = [
+ [
+ {
+ "points": [(0.1, 0.1), (0.5, 0), (0.5, 1), (0, 1)],
+ "text": 1234,
+ "ignore": False,
+ },
+ {
+ "points": [(0.5, 0.1), (1, 0), (1, 1), (0.5, 1)],
+ "text": 5678,
+ "ignore": False,
+ },
+ ]
+ ]
+ gts = [
+ [
+ {
+ "points": [(0.1, 0.1), (1, 0), (1, 1), (0, 1)],
+ "text": 123,
+ "ignore": False,
+ }
+ ]
+ ]
results = []
for gt, pred in zip(gts, preds):
results.append(evaluator.evaluate_image(gt, pred))
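
For readers following the metric math: combine_results micro-averages the per-image counts (gtCare, detCare, detMatched) before taking the harmonic mean. A minimal standalone sketch of that reduction, with hypothetical per-image counts rather than values from any real run:

# Micro-averaging as done in combine_results (hypothetical counts).
results = [
    {"gtCare": 3, "detCare": 4, "detMatched": 2},
    {"gtCare": 5, "detCare": 5, "detMatched": 4},
]
num_gt = sum(r["gtCare"] for r in results)       # 8
num_det = sum(r["detCare"] for r in results)     # 9
matched = sum(r["detMatched"] for r in results)  # 6

recall = 0 if num_gt == 0 else matched / num_gt          # 0.75
precision = 0 if num_det == 0 else matched / num_det     # ~0.667
hmean = (
    0
    if precision + recall == 0
    else 2 * precision * recall / (precision + recall)   # ~0.706
)
print(recall, precision, hmean)
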
diff --git a/benchmark/PaddleOCR_DBNet/utils/ocr_metric/icdar2015/detection/mtwi2018.py b/benchmark/PaddleOCR_DBNet/utils/ocr_metric/icdar2015/detection/mtwi2018.py
index 8e319aacf5..51eccdcef2 100644
--- a/benchmark/PaddleOCR_DBNet/utils/ocr_metric/icdar2015/detection/mtwi2018.py
+++ b/benchmark/PaddleOCR_DBNet/utils/ocr_metric/icdar2015/detection/mtwi2018.py
@@ -8,11 +8,11 @@
class DetectionMTWI2018Evaluator(object):
def __init__(
- self,
- area_recall_constraint=0.7,
- area_precision_constraint=0.7,
- ev_param_ind_center_diff_thr=1, ):
-
+ self,
+ area_recall_constraint=0.7,
+ area_precision_constraint=0.7,
+ ev_param_ind_center_diff_thr=1,
+ ):
self.area_recall_constraint = area_recall_constraint
self.area_precision_constraint = area_precision_constraint
self.ev_param_ind_center_diff_thr = ev_param_ind_center_diff_thr
@@ -30,24 +30,27 @@ def get_intersection(pD, pG):
def one_to_one_match(row, col):
cont = 0
for j in range(len(recallMat[0])):
- if recallMat[row,
- j] >= self.area_recall_constraint and precisionMat[
- row, j] >= self.area_precision_constraint:
+ if (
+ recallMat[row, j] >= self.area_recall_constraint
+ and precisionMat[row, j] >= self.area_precision_constraint
+ ):
cont = cont + 1
- if (cont != 1):
+ if cont != 1:
return False
cont = 0
for i in range(len(recallMat)):
- if recallMat[
- i, col] >= self.area_recall_constraint and precisionMat[
- i, col] >= self.area_precision_constraint:
+ if (
+ recallMat[i, col] >= self.area_recall_constraint
+ and precisionMat[i, col] >= self.area_precision_constraint
+ ):
cont = cont + 1
- if (cont != 1):
+ if cont != 1:
return False
- if recallMat[row,
- col] >= self.area_recall_constraint and precisionMat[
- row, col] >= self.area_precision_constraint:
+ if (
+ recallMat[row, col] >= self.area_recall_constraint
+ and precisionMat[row, col] >= self.area_precision_constraint
+ ):
return True
return False
@@ -55,10 +58,12 @@ def one_to_many_match(gtNum):
many_sum = 0
detRects = []
for detNum in range(len(recallMat[0])):
- if gtRectMat[gtNum] == 0 and detRectMat[
- detNum] == 0 and detNum not in detDontCareRectsNum:
- if precisionMat[gtNum,
- detNum] >= self.area_precision_constraint:
+ if (
+ gtRectMat[gtNum] == 0
+ and detRectMat[detNum] == 0
+ and detNum not in detDontCareRectsNum
+ ):
+ if precisionMat[gtNum, detNum] >= self.area_precision_constraint:
many_sum += recallMat[gtNum, detNum]
detRects.append(detNum)
if round(many_sum, 4) >= self.area_recall_constraint:
@@ -70,8 +75,11 @@ def many_to_one_match(detNum):
many_sum = 0
gtRects = []
for gtNum in range(len(recallMat)):
- if gtRectMat[gtNum] == 0 and detRectMat[
- detNum] == 0 and gtNum not in gtDontCareRectsNum:
+ if (
+ gtRectMat[gtNum] == 0
+ and detRectMat[detNum] == 0
+ and gtNum not in gtDontCareRectsNum
+ ):
if recallMat[gtNum, detNum] >= self.area_recall_constraint:
many_sum += precisionMat[gtNum, detNum]
gtRects.append(gtNum)
@@ -81,28 +89,32 @@ def many_to_one_match(detNum):
return False, []
def center_distance(r1, r2):
- return ((np.mean(r1, axis=0) - np.mean(r2, axis=0))**2).sum()**0.5
+ return ((np.mean(r1, axis=0) - np.mean(r2, axis=0)) ** 2).sum() ** 0.5
def diag(r):
r = np.array(r)
- return ((r[:, 0].max() - r[:, 0].min())**2 +
- (r[:, 1].max() - r[:, 1].min())**2)**0.5
+ return (
+ (r[:, 0].max() - r[:, 0].min()) ** 2
+ + (r[:, 1].max() - r[:, 1].min()) ** 2
+ ) ** 0.5
perSampleMetrics = {}
recall = 0
precision = 0
hmean = 0
- recallAccum = 0.
- precisionAccum = 0.
+ recallAccum = 0.0
+ precisionAccum = 0.0
gtRects = []
detRects = []
gtPolPoints = []
detPolPoints = []
- gtDontCareRectsNum = [
- ] #Array of Ground Truth Rectangles' keys marked as don't Care
- detDontCareRectsNum = [
- ] #Array of Detected Rectangles' matched with a don't Care GT
+ gtDontCareRectsNum = (
+ []
+ ) # Array of Ground Truth Rectangles' keys marked as don't Care
+ detDontCareRectsNum = (
+ []
+ ) # Array of Detected Rectangles' matched with a don't Care GT
pairs = []
evaluationLog = ""
@@ -110,9 +122,9 @@ def diag(r):
precisionMat = np.empty([1, 1])
for n in range(len(gt)):
- points = gt[n]['points']
+ points = gt[n]["points"]
# transcription = gt[n]['text']
- dontCare = gt[n]['ignore']
+ dontCare = gt[n]["ignore"]
if not Polygon(points).is_valid or not Polygon(points).is_simple:
continue
@@ -122,12 +134,18 @@ def diag(r):
if dontCare:
gtDontCareRectsNum.append(len(gtRects) - 1)
- evaluationLog += "GT rectangles: " + str(len(gtRects)) + (
- " (" + str(len(gtDontCareRectsNum)) + " don't care)\n"
- if len(gtDontCareRectsNum) > 0 else "\n")
+ evaluationLog += (
+ "GT rectangles: "
+ + str(len(gtRects))
+ + (
+ " (" + str(len(gtDontCareRectsNum)) + " don't care)\n"
+ if len(gtDontCareRectsNum) > 0
+ else "\n"
+ )
+ )
for n in range(len(pred)):
- points = pred[n]['points']
+ points = pred[n]["points"]
if not Polygon(points).is_valid or not Polygon(points).is_simple:
continue
@@ -140,24 +158,30 @@ def diag(r):
dontCareRect = gtRects[dontCareRectNum]
intersected_area = get_intersection(dontCareRect, detRect)
rdDimensions = Polygon(detRect).area
- if (rdDimensions == 0):
+ if rdDimensions == 0:
precision = 0
else:
precision = intersected_area / rdDimensions
- if (precision > 0.5):
+ if precision > 0.5:
detDontCareRectsNum.append(len(detRects) - 1)
break
- evaluationLog += "DET rectangles: " + str(len(detRects)) + (
- " (" + str(len(detDontCareRectsNum)) + " don't care)\n"
- if len(detDontCareRectsNum) > 0 else "\n")
+ evaluationLog += (
+ "DET rectangles: "
+ + str(len(detRects))
+ + (
+ " (" + str(len(detDontCareRectsNum)) + " don't care)\n"
+ if len(detDontCareRectsNum) > 0
+ else "\n"
+ )
+ )
if len(gtRects) == 0:
recall = 1
precision = 0 if len(detRects) > 0 else 1
if len(detRects) > 0:
- #Calculate recall and precision matrixs
+            # Calculate recall and precision matrices
outputShape = [len(gtRects), len(detRects)]
recallMat = np.empty(outputShape)
precisionMat = np.empty(outputShape)
@@ -170,22 +194,26 @@ def diag(r):
intersected_area = get_intersection(rG, rD)
rgDimensions = Polygon(rG).area
rdDimensions = Polygon(rD).area
- recallMat[
- gtNum,
- detNum] = 0 if rgDimensions == 0 else intersected_area / rgDimensions
- precisionMat[
- gtNum,
- detNum] = 0 if rdDimensions == 0 else intersected_area / rdDimensions
+ recallMat[gtNum, detNum] = (
+ 0 if rgDimensions == 0 else intersected_area / rgDimensions
+ )
+ precisionMat[gtNum, detNum] = (
+ 0 if rdDimensions == 0 else intersected_area / rdDimensions
+ )
# Find one-to-one matches
evaluationLog += "Find one-to-one matches\n"
for gtNum in range(len(gtRects)):
for detNum in range(len(detRects)):
- if gtRectMat[gtNum] == 0 and detRectMat[
- detNum] == 0 and gtNum not in gtDontCareRectsNum and detNum not in detDontCareRectsNum:
+ if (
+ gtRectMat[gtNum] == 0
+ and detRectMat[detNum] == 0
+ and gtNum not in gtDontCareRectsNum
+ and detNum not in detDontCareRectsNum
+ ):
match = one_to_one_match(gtNum, detNum)
if match is True:
- #in deteval we have to make other validation before mark as one-to-one
+                        # in DetEval we must run extra validation before marking a pair as one-to-one
rG = gtRects[gtNum]
rD = detRects[detNum]
normDist = center_distance(rG, rD)
@@ -196,18 +224,24 @@ def diag(r):
detRectMat[detNum] = 1
recallAccum += 1.0
precisionAccum += 1.0
- pairs.append({
- 'gt': gtNum,
- 'det': detNum,
- 'type': 'OO'
- })
- evaluationLog += "Match GT #" + str(
- gtNum) + " with Det #" + str(detNum) + "\n"
+ pairs.append({"gt": gtNum, "det": detNum, "type": "OO"})
+ evaluationLog += (
+ "Match GT #"
+ + str(gtNum)
+ + " with Det #"
+ + str(detNum)
+ + "\n"
+ )
else:
- evaluationLog += "Match Discarded GT #" + str(
- gtNum) + " with Det #" + str(
- detNum) + " normDist: " + str(
- normDist) + " \n"
+ evaluationLog += (
+ "Match Discarded GT #"
+ + str(gtNum)
+ + " with Det #"
+ + str(detNum)
+ + " normDist: "
+ + str(normDist)
+ + " \n"
+ )
# Find one-to-many matches
evaluationLog += "Find one-to-many matches\n"
for gtNum in range(len(gtRects)):
@@ -217,16 +251,24 @@ def diag(r):
gtRectMat[gtNum] = 1
recallAccum += 1.0
precisionAccum += len(matchesDet) / (
- 1 + math.log(len(matchesDet)))
- pairs.append({
- 'gt': gtNum,
- 'det': matchesDet,
- 'type': 'OO' if len(matchesDet) == 1 else 'OM'
- })
+ 1 + math.log(len(matchesDet))
+ )
+ pairs.append(
+ {
+ "gt": gtNum,
+ "det": matchesDet,
+ "type": "OO" if len(matchesDet) == 1 else "OM",
+ }
+ )
for detNum in matchesDet:
detRectMat[detNum] = 1
- evaluationLog += "Match GT #" + str(
- gtNum) + " with Det #" + str(matchesDet) + "\n"
+ evaluationLog += (
+ "Match GT #"
+ + str(gtNum)
+ + " with Det #"
+ + str(matchesDet)
+ + "\n"
+ )
# Find many-to-one matches
evaluationLog += "Find many-to-one matches\n"
@@ -235,53 +277,62 @@ def diag(r):
match, matchesGt = many_to_one_match(detNum)
if match is True:
detRectMat[detNum] = 1
- recallAccum += len(matchesGt) / (
- 1 + math.log(len(matchesGt)))
+ recallAccum += len(matchesGt) / (1 + math.log(len(matchesGt)))
precisionAccum += 1.0
- pairs.append({
- 'gt': matchesGt,
- 'det': detNum,
- 'type': 'OO' if len(matchesGt) == 1 else 'MO'
- })
+ pairs.append(
+ {
+ "gt": matchesGt,
+ "det": detNum,
+ "type": "OO" if len(matchesGt) == 1 else "MO",
+ }
+ )
for gtNum in matchesGt:
gtRectMat[gtNum] = 1
- evaluationLog += "Match GT #" + str(
- matchesGt) + " with Det #" + str(detNum) + "\n"
-
- numGtCare = (len(gtRects) - len(gtDontCareRectsNum))
+ evaluationLog += (
+ "Match GT #"
+ + str(matchesGt)
+ + " with Det #"
+ + str(detNum)
+ + "\n"
+ )
+
+ numGtCare = len(gtRects) - len(gtDontCareRectsNum)
if numGtCare == 0:
recall = float(1)
precision = float(0) if len(detRects) > 0 else float(1)
else:
recall = float(recallAccum) / numGtCare
- precision = float(0) if (
- len(detRects) - len(detDontCareRectsNum)
- ) == 0 else float(precisionAccum) / (
- len(detRects) - len(detDontCareRectsNum))
- hmean = 0 if (precision + recall
- ) == 0 else 2.0 * precision * recall / (
- precision + recall)
+ precision = (
+ float(0)
+ if (len(detRects) - len(detDontCareRectsNum)) == 0
+ else float(precisionAccum)
+ / (len(detRects) - len(detDontCareRectsNum))
+ )
+ hmean = (
+ 0
+ if (precision + recall) == 0
+ else 2.0 * precision * recall / (precision + recall)
+ )
numGtCare = len(gtRects) - len(gtDontCareRectsNum)
numDetCare = len(detRects) - len(detDontCareRectsNum)
perSampleMetrics = {
- 'precision': precision,
- 'recall': recall,
- 'hmean': hmean,
- 'pairs': pairs,
- 'recallMat': [] if len(detRects) > 100 else recallMat.tolist(),
- 'precisionMat': []
- if len(detRects) > 100 else precisionMat.tolist(),
- 'gtPolPoints': gtPolPoints,
- 'detPolPoints': detPolPoints,
- 'gtCare': numGtCare,
- 'detCare': numDetCare,
- 'gtDontCare': gtDontCareRectsNum,
- 'detDontCare': detDontCareRectsNum,
- 'recallAccum': recallAccum,
- 'precisionAccum': precisionAccum,
- 'evaluationLog': evaluationLog
+ "precision": precision,
+ "recall": recall,
+ "hmean": hmean,
+ "pairs": pairs,
+ "recallMat": [] if len(detRects) > 100 else recallMat.tolist(),
+ "precisionMat": [] if len(detRects) > 100 else precisionMat.tolist(),
+ "gtPolPoints": gtPolPoints,
+ "detPolPoints": detPolPoints,
+ "gtCare": numGtCare,
+ "detCare": numDetCare,
+ "gtDontCare": gtDontCareRectsNum,
+ "detDontCare": detDontCareRectsNum,
+ "recallAccum": recallAccum,
+ "precisionAccum": precisionAccum,
+ "evaluationLog": evaluationLog,
}
return perSampleMetrics
@@ -293,41 +344,53 @@ def combine_results(self, results):
methodPrecisionSum = 0
for result in results:
- numGt += result['gtCare']
- numDet += result['detCare']
- methodRecallSum += result['recallAccum']
- methodPrecisionSum += result['precisionAccum']
+ numGt += result["gtCare"]
+ numDet += result["detCare"]
+ methodRecallSum += result["recallAccum"]
+ methodPrecisionSum += result["precisionAccum"]
methodRecall = 0 if numGt == 0 else methodRecallSum / numGt
methodPrecision = 0 if numDet == 0 else methodPrecisionSum / numDet
- methodHmean = 0 if methodRecall + methodPrecision == 0 else 2 * methodRecall * methodPrecision / (
- methodRecall + methodPrecision)
+ methodHmean = (
+ 0
+ if methodRecall + methodPrecision == 0
+ else 2 * methodRecall * methodPrecision / (methodRecall + methodPrecision)
+ )
methodMetrics = {
- 'precision': methodPrecision,
- 'recall': methodRecall,
- 'hmean': methodHmean
+ "precision": methodPrecision,
+ "recall": methodRecall,
+ "hmean": methodHmean,
}
return methodMetrics
-if __name__ == '__main__':
+if __name__ == "__main__":
    evaluator = DetectionMTWI2018Evaluator()
- gts = [[{
- 'points': [(0, 0), (1, 0), (1, 1), (0, 1)],
- 'text': 1234,
- 'ignore': False,
- }, {
- 'points': [(2, 2), (3, 2), (3, 3), (2, 3)],
- 'text': 5678,
- 'ignore': True,
- }]]
- preds = [[{
- 'points': [(0.1, 0.1), (1, 0), (1, 1), (0, 1)],
- 'text': 123,
- 'ignore': False,
- }]]
+ gts = [
+ [
+ {
+ "points": [(0, 0), (1, 0), (1, 1), (0, 1)],
+ "text": 1234,
+ "ignore": False,
+ },
+ {
+ "points": [(2, 2), (3, 2), (3, 3), (2, 3)],
+ "text": 5678,
+ "ignore": True,
+ },
+ ]
+ ]
+ preds = [
+ [
+ {
+ "points": [(0.1, 0.1), (1, 0), (1, 1), (0, 1)],
+ "text": 123,
+ "ignore": False,
+ }
+ ]
+ ]
results = []
for gt, pred in zip(gts, preds):
results.append(evaluator.evaluate_image(gt, pred))
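
The one-to-many and many-to-one branches above credit a match group with len(matches) / (1 + log(len(matches))), so a grouped match earns more than one ordinary match but less than len(matches). A small sketch of that weighting; the values fed in are illustrative only:

import math

# DetEval-style weight for a one-to-many (or many-to-one) match group.
def multi_match_weight(n_matches: int) -> float:
    return n_matches / (1 + math.log(n_matches))

for n in (1, 2, 4, 8):
    # n=1 degenerates to 1.0, i.e. an ordinary one-to-one match.
    print(n, round(multi_match_weight(n), 3))
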
diff --git a/benchmark/PaddleOCR_DBNet/utils/ocr_metric/icdar2015/quad_metric.py b/benchmark/PaddleOCR_DBNet/utils/ocr_metric/icdar2015/quad_metric.py
index e7e403a31c..4a4a2b81cd 100644
--- a/benchmark/PaddleOCR_DBNet/utils/ocr_metric/icdar2015/quad_metric.py
+++ b/benchmark/PaddleOCR_DBNet/utils/ocr_metric/icdar2015/quad_metric.py
@@ -23,14 +23,13 @@ def update(self, val, n=1):
return self
-class QuadMetric():
+class QuadMetric:
def __init__(self, is_output_polygon=False):
self.is_output_polygon = is_output_polygon
- self.evaluator = DetectionIoUEvaluator(
- is_output_polygon=is_output_polygon)
+ self.evaluator = DetectionIoUEvaluator(is_output_polygon=is_output_polygon)
def measure(self, batch, output, box_thresh=0.6):
- '''
+ """
        batch: (image, polygons, ignore_tags)
batch: a dict produced by dataloaders.
image: tensor of shape (N, C, H, W).
@@ -39,24 +38,22 @@ def measure(self, batch, output, box_thresh=0.6):
shape: the original shape of images.
filename: the original filenames of images.
output: (polygons, ...)
- '''
+ """
results = []
- gt_polyons_batch = batch['text_polys']
- ignore_tags_batch = batch['ignore_tags']
+ gt_polyons_batch = batch["text_polys"]
+ ignore_tags_batch = batch["ignore_tags"]
pred_polygons_batch = np.array(output[0])
pred_scores_batch = np.array(output[1])
for polygons, pred_polygons, pred_scores, ignore_tags in zip(
- gt_polyons_batch, pred_polygons_batch, pred_scores_batch,
- ignore_tags_batch):
+ gt_polyons_batch, pred_polygons_batch, pred_scores_batch, ignore_tags_batch
+ ):
gt = [
- dict(
- points=np.int64(polygons[i]), ignore=ignore_tags[i])
+ dict(points=np.int64(polygons[i]), ignore=ignore_tags[i])
for i in range(len(polygons))
]
if self.is_output_polygon:
pred = [
- dict(points=pred_polygons[i])
- for i in range(len(pred_polygons))
+ dict(points=pred_polygons[i]) for i in range(len(pred_polygons))
]
else:
pred = []
@@ -64,8 +61,7 @@ def measure(self, batch, output, box_thresh=0.6):
for i in range(pred_polygons.shape[0]):
if pred_scores[i] >= box_thresh:
# print(pred_polygons[i,:,:].tolist())
- pred.append(
- dict(points=pred_polygons[i, :, :].astype(np.int)))
+                        pred.append(dict(points=pred_polygons[i, :, :].astype(np.int64)))
# pred = [dict(points=pred_polygons[i,:,:].tolist()) if pred_scores[i] >= box_thresh for i in range(pred_polygons.shape[0])]
results.append(self.evaluator.evaluate_image(gt, pred))
return results
@@ -74,13 +70,16 @@ def validate_measure(self, batch, output, box_thresh=0.6):
return self.measure(batch, output, box_thresh)
def evaluate_measure(self, batch, output):
- return self.measure(batch, output), np.linspace(
- 0, batch['image'].shape[0]).tolist()
+ return (
+ self.measure(batch, output),
+ np.linspace(0, batch["image"].shape[0]).tolist(),
+ )
def gather_measure(self, raw_metrics):
raw_metrics = [
image_metrics
- for batch_metrics in raw_metrics for image_metrics in batch_metrics
+ for batch_metrics in raw_metrics
+ for image_metrics in batch_metrics
]
result = self.evaluator.combine_results(raw_metrics)
@@ -89,10 +88,11 @@ def gather_measure(self, raw_metrics):
recall = AverageMeter()
fmeasure = AverageMeter()
- precision.update(result['precision'], n=len(raw_metrics))
- recall.update(result['recall'], n=len(raw_metrics))
- fmeasure_score = 2 * precision.val * recall.val / (
- precision.val + recall.val + 1e-8)
+ precision.update(result["precision"], n=len(raw_metrics))
+ recall.update(result["recall"], n=len(raw_metrics))
+ fmeasure_score = (
+ 2 * precision.val * recall.val / (precision.val + recall.val + 1e-8)
+ )
fmeasure.update(fmeasure_score)
- return {'precision': precision, 'recall': recall, 'fmeasure': fmeasure}
+ return {"precision": precision, "recall": recall, "fmeasure": fmeasure}
diff --git a/benchmark/PaddleOCR_DBNet/utils/profiler.py b/benchmark/PaddleOCR_DBNet/utils/profiler.py
index e64afd6a0d..34fcea8e60 100644
--- a/benchmark/PaddleOCR_DBNet/utils/profiler.py
+++ b/benchmark/PaddleOCR_DBNet/utils/profiler.py
@@ -24,7 +24,7 @@
class ProfilerOptions(object):
- '''
+ """
Use a string to initialize a ProfilerOptions.
The string should be in the format: "key1=value1;key2=value;key3=value3".
For example:
@@ -33,7 +33,7 @@ class ProfilerOptions(object):
"batch_range=[50, 60]; tracer_option=OpDetail; profile_path=model.profile"
ProfilerOptions supports following key-value pair:
    batch_range - an integer list, e.g. [100, 110].
- state - a string, the optional values are 'CPU', 'GPU' or 'All'.
+ state - a string, the optional values are 'CPU', 'GPU' or 'All'.
sorted_key - a string, the optional values are 'calls', 'total',
            'max', 'min' or 'ave'.
tracer_option - a string, the optional values are 'Default', 'OpDetail',
@@ -41,54 +41,54 @@ class ProfilerOptions(object):
profile_path - a string, the path to save the serialized profile data,
which can be used to generate a timeline.
exit_on_finished - a boolean.
- '''
+ """
def __init__(self, options_str):
assert isinstance(options_str, str)
self._options = {
- 'batch_range': [10, 20],
- 'state': 'All',
- 'sorted_key': 'total',
- 'tracer_option': 'Default',
- 'profile_path': '/tmp/profile',
- 'exit_on_finished': True
+ "batch_range": [10, 20],
+ "state": "All",
+ "sorted_key": "total",
+ "tracer_option": "Default",
+ "profile_path": "/tmp/profile",
+ "exit_on_finished": True,
}
self._parse_from_string(options_str)
def _parse_from_string(self, options_str):
- for kv in options_str.replace(' ', '').split(';'):
- key, value = kv.split('=')
- if key == 'batch_range':
- value_list = value.replace('[', '').replace(']', '').split(',')
+ for kv in options_str.replace(" ", "").split(";"):
+ key, value = kv.split("=")
+ if key == "batch_range":
+ value_list = value.replace("[", "").replace("]", "").split(",")
value_list = list(map(int, value_list))
- if len(value_list) >= 2 and value_list[0] >= 0 and value_list[
- 1] > value_list[0]:
+ if (
+ len(value_list) >= 2
+ and value_list[0] >= 0
+ and value_list[1] > value_list[0]
+ ):
self._options[key] = value_list
- elif key == 'exit_on_finished':
+ elif key == "exit_on_finished":
self._options[key] = value.lower() in ("yes", "true", "t", "1")
- elif key in [
- 'state', 'sorted_key', 'tracer_option', 'profile_path'
- ]:
+ elif key in ["state", "sorted_key", "tracer_option", "profile_path"]:
self._options[key] = value
def __getitem__(self, name):
if self._options.get(name, None) is None:
- raise ValueError(
- "ProfilerOptions does not have an option named %s." % name)
+ raise ValueError("ProfilerOptions does not have an option named %s." % name)
return self._options[name]
def add_profiler_step(options_str=None):
- '''
+ """
Enable the operator-level timing using PaddlePaddle's profiler.
    The profiler uses an independent variable to count the profiler steps.
One call of this function is treated as a profiler step.
-
+
Args:
profiler_options - a string to initialize the ProfilerOptions.
Default is None, and the profiler is disabled.
- '''
+ """
if options_str is None:
return
@@ -98,13 +98,15 @@ def add_profiler_step(options_str=None):
if _profiler_options is None:
_profiler_options = ProfilerOptions(options_str)
- if _profiler_step_id == _profiler_options['batch_range'][0]:
- paddle.utils.profiler.start_profiler(_profiler_options['state'],
- _profiler_options['tracer_option'])
- elif _profiler_step_id == _profiler_options['batch_range'][1]:
- paddle.utils.profiler.stop_profiler(_profiler_options['sorted_key'],
- _profiler_options['profile_path'])
- if _profiler_options['exit_on_finished']:
+ if _profiler_step_id == _profiler_options["batch_range"][0]:
+ paddle.utils.profiler.start_profiler(
+ _profiler_options["state"], _profiler_options["tracer_option"]
+ )
+ elif _profiler_step_id == _profiler_options["batch_range"][1]:
+ paddle.utils.profiler.stop_profiler(
+ _profiler_options["sorted_key"], _profiler_options["profile_path"]
+ )
+ if _profiler_options["exit_on_finished"]:
sys.exit(0)
_profiler_step_id += 1
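
For reference, the "key1=value1;key2=value2" grammar consumed by ProfilerOptions reduces to a few string operations. A minimal standalone parser sketch, with illustrative defaults only, not the class itself:

# Minimal sketch of the "key1=value1;key2=value2" grammar parsed above.
def parse_profiler_options(options_str: str) -> dict:
    options = {"batch_range": [10, 20], "state": "All", "exit_on_finished": True}
    for kv in options_str.replace(" ", "").split(";"):
        key, value = kv.split("=")
        if key == "batch_range":
            lo, hi = (int(v) for v in value.strip("[]").split(","))
            if 0 <= lo < hi:
                options[key] = [lo, hi]
        elif key == "exit_on_finished":
            options[key] = value.lower() in ("yes", "true", "t", "1")
        else:
            options[key] = value
    return options

print(parse_profiler_options("batch_range=[50,60];state=GPU;exit_on_finished=false"))
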
diff --git a/benchmark/PaddleOCR_DBNet/utils/schedulers.py b/benchmark/PaddleOCR_DBNet/utils/schedulers.py
index 1b6fb7d285..e038ddcb58 100644
--- a/benchmark/PaddleOCR_DBNet/utils/schedulers.py
+++ b/benchmark/PaddleOCR_DBNet/utils/schedulers.py
@@ -1,6 +1,7 @@
from paddle.optimizer import lr
import logging
-__all__ = ['Polynomial']
+
+__all__ = ["Polynomial"]
class Polynomial(object):
@@ -18,20 +19,22 @@ class Polynomial(object):
        by_epoch: Whether the parameters are specified per epoch or per iter; when set to true, epochs and warmup_epoch will be automatically multiplied by step_each_epoch. Default: True
"""
- def __init__(self,
- learning_rate,
- epochs,
- step_each_epoch,
- end_lr=0.0,
- power=1.0,
- warmup_epoch=0,
- warmup_start_lr=0.0,
- last_epoch=-1,
- by_epoch=True,
- **kwargs):
+ def __init__(
+ self,
+ learning_rate,
+ epochs,
+ step_each_epoch,
+ end_lr=0.0,
+ power=1.0,
+ warmup_epoch=0,
+ warmup_start_lr=0.0,
+ last_epoch=-1,
+ by_epoch=True,
+ **kwargs,
+ ):
super().__init__()
if warmup_epoch >= epochs:
- msg = f"When using warm up, the value of \"epochs\" must be greater than value of \"Optimizer.lr.warmup_epoch\". The value of \"Optimizer.lr.warmup_epoch\" has been set to {epochs}."
+ msg = f'When using warm up, the value of "epochs" must be greater than value of "Optimizer.lr.warmup_epoch". The value of "Optimizer.lr.warmup_epoch" has been set to {epochs}.'
logging.warning(msg)
warmup_epoch = epochs
self.learning_rate = learning_rate
@@ -47,18 +50,23 @@ def __init__(self,
self.warmup_epoch = int(self.warmup_epoch * step_each_epoch)
def __call__(self):
- learning_rate = lr.PolynomialDecay(
- learning_rate=self.learning_rate,
- decay_steps=self.epochs,
- end_lr=self.end_lr,
- power=self.power,
- last_epoch=self.
- last_epoch) if self.epochs > 0 else self.learning_rate
+ learning_rate = (
+ lr.PolynomialDecay(
+ learning_rate=self.learning_rate,
+ decay_steps=self.epochs,
+ end_lr=self.end_lr,
+ power=self.power,
+ last_epoch=self.last_epoch,
+ )
+ if self.epochs > 0
+ else self.learning_rate
+ )
if self.warmup_epoch > 0:
learning_rate = lr.LinearWarmup(
learning_rate=learning_rate,
warmup_steps=self.warmup_epoch,
start_lr=self.warmup_start_lr,
end_lr=self.learning_rate,
- last_epoch=self.last_epoch)
+ last_epoch=self.last_epoch,
+ )
return learning_rate
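
The resulting schedule is a linear warmup followed by polynomial decay. Below is a pure-Python approximation of that curve under assumed settings (base_lr=0.001, power=1.0, 100 warmup steps, 1000 decay steps); the real paddle lr.LinearWarmup/lr.PolynomialDecay objects handle the step bookkeeping themselves, so treat this only as a sketch of the shape:

# Assumed numbers: linear warmup to base_lr, then linear (power=1.0) decay.
base_lr, end_lr, warmup_start_lr = 0.001, 0.0, 0.0
decay_steps, warmup_steps = 1000, 100

def lr_at(step: int) -> float:
    if step < warmup_steps:
        # Linear ramp from warmup_start_lr up to base_lr.
        return warmup_start_lr + (base_lr - warmup_start_lr) * step / warmup_steps
    # Polynomial decay (power=1.0 here) toward end_lr, clamped at decay_steps.
    progress = min(step - warmup_steps, decay_steps) / decay_steps
    return (base_lr - end_lr) * (1 - progress) + end_lr

for s in (0, 50, 100, 600, 1100):
    print(s, lr_at(s))
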
diff --git a/benchmark/PaddleOCR_DBNet/utils/util.py b/benchmark/PaddleOCR_DBNet/utils/util.py
index 39bae76409..ad36cabb83 100644
--- a/benchmark/PaddleOCR_DBNet/utils/util.py
+++ b/benchmark/PaddleOCR_DBNet/utils/util.py
@@ -16,7 +16,7 @@
def _check_image_file(path):
- img_end = {'jpg', 'bmp', 'png', 'jpeg', 'rgb', 'tif', 'tiff', 'gif', 'pdf'}
+ img_end = {"jpg", "bmp", "png", "jpeg", "rgb", "tif", "tiff", "gif", "pdf"}
return any([path.lower().endswith(e) for e in img_end])
@@ -25,7 +25,7 @@ def get_image_file_list(img_file):
if img_file is None or not os.path.exists(img_file):
raise Exception("not found any img file in {}".format(img_file))
- img_end = {'jpg', 'bmp', 'png', 'jpeg', 'rgb', 'tif', 'tiff', 'gif', 'pdf'}
+ img_end = {"jpg", "bmp", "png", "jpeg", "rgb", "tif", "tiff", "gif", "pdf"}
if os.path.isfile(img_file) and _check_image_file(img_file):
imgs_lists.append(img_file)
elif os.path.isdir(img_file):
@@ -39,12 +39,12 @@ def get_image_file_list(img_file):
return imgs_lists
-def setup_logger(log_file_path: str=None):
+def setup_logger(log_file_path: str = None):
import logging
+
logging._warn_preinit_stderr = 0
- logger = logging.getLogger('DBNet.paddle')
- formatter = logging.Formatter(
- '%(asctime)s %(name)s %(levelname)s: %(message)s')
+ logger = logging.getLogger("DBNet.paddle")
+ formatter = logging.Formatter("%(asctime)s %(name)s %(levelname)s: %(message)s")
ch = logging.StreamHandler()
ch.setFormatter(formatter)
logger.addHandler(ch)
@@ -69,29 +69,28 @@ def newFunc(*args, **args2):
def load(file_path: str):
file_path = pathlib.Path(file_path)
- func_dict = {'.txt': _load_txt, '.json': _load_json, '.list': _load_txt}
+ func_dict = {".txt": _load_txt, ".json": _load_json, ".list": _load_txt}
assert file_path.suffix in func_dict
return func_dict[file_path.suffix](file_path)
def _load_txt(file_path: str):
- with open(file_path, 'r', encoding='utf8') as f:
+ with open(file_path, "r", encoding="utf8") as f:
content = [
- x.strip().strip('\ufeff').strip('\xef\xbb\xbf')
- for x in f.readlines()
+ x.strip().strip("\ufeff").strip("\xef\xbb\xbf") for x in f.readlines()
]
return content
def _load_json(file_path: str):
- with open(file_path, 'r', encoding='utf8') as f:
+ with open(file_path, "r", encoding="utf8") as f:
content = json.load(f)
return content
def save(data, file_path):
file_path = pathlib.Path(file_path)
- func_dict = {'.txt': _save_txt, '.json': _save_json}
+ func_dict = {".txt": _save_txt, ".json": _save_json}
assert file_path.suffix in func_dict
return func_dict[file_path.suffix](data, file_path)
@@ -105,22 +104,22 @@ def _save_txt(data, file_path):
"""
if not isinstance(data, list):
data = [data]
- with open(file_path, mode='w', encoding='utf8') as f:
- f.write('\n'.join(data))
+ with open(file_path, mode="w", encoding="utf8") as f:
+ f.write("\n".join(data))
def _save_json(data, file_path):
- with open(file_path, 'w', encoding='utf-8') as json_file:
+ with open(file_path, "w", encoding="utf-8") as json_file:
json.dump(data, json_file, ensure_ascii=False, indent=4)
-def show_img(imgs: np.ndarray, title='img'):
- color = (len(imgs.shape) == 3 and imgs.shape[-1] == 3)
+def show_img(imgs: np.ndarray, title="img"):
+ color = len(imgs.shape) == 3 and imgs.shape[-1] == 3
imgs = np.expand_dims(imgs, axis=0)
for i, img in enumerate(imgs):
plt.figure()
- plt.title('{}_{}'.format(title, i))
- plt.imshow(img, cmap=None if color else 'gray')
+ plt.title("{}_{}".format(title, i))
+ plt.imshow(img, cmap=None if color else "gray")
plt.show()
@@ -135,11 +134,7 @@ def draw_bbox(img_path, result, color=(255, 0, 0), thickness=2):
return img_path
-def cal_text_score(texts,
- gt_texts,
- training_masks,
- running_metric_text,
- thred=0.5):
+def cal_text_score(texts, gt_texts, training_masks, running_metric_text, thred=0.5):
training_masks = training_masks.numpy()
pred_text = texts.numpy() * training_masks
pred_text[pred_text <= thred] = 0
@@ -180,34 +175,37 @@ def get_datalist(train_data_path):
"""
train_data = []
for p in train_data_path:
- with open(p, 'r', encoding='utf-8') as f:
+ with open(p, "r", encoding="utf-8") as f:
for line in f.readlines():
- line = line.strip('\n').replace('.jpg ', '.jpg\t').split('\t')
+ line = line.strip("\n").replace(".jpg ", ".jpg\t").split("\t")
if len(line) > 1:
- img_path = pathlib.Path(line[0].strip(' '))
- label_path = pathlib.Path(line[1].strip(' '))
- if img_path.exists() and img_path.stat(
- ).st_size > 0 and label_path.exists() and label_path.stat(
- ).st_size > 0:
+ img_path = pathlib.Path(line[0].strip(" "))
+ label_path = pathlib.Path(line[1].strip(" "))
+ if (
+ img_path.exists()
+ and img_path.stat().st_size > 0
+ and label_path.exists()
+ and label_path.stat().st_size > 0
+ ):
train_data.append((str(img_path), str(label_path)))
return train_data
def save_result(result_path, box_list, score_list, is_output_polygon):
if is_output_polygon:
- with open(result_path, 'wt') as res:
+ with open(result_path, "wt") as res:
for i, box in enumerate(box_list):
box = box.reshape(-1).tolist()
result = ",".join([str(int(x)) for x in box])
score = score_list[i]
- res.write(result + ',' + str(score) + "\n")
+ res.write(result + "," + str(score) + "\n")
else:
- with open(result_path, 'wt') as res:
+ with open(result_path, "wt") as res:
for i, box in enumerate(box_list):
score = score_list[i]
box = box.reshape(-1).tolist()
result = ",".join([str(int(x)) for x in box])
- res.write(result + ',' + str(score) + "\n")
+ res.write(result + "," + str(score) + "\n")
def expand_polygon(polygon):
@@ -225,7 +223,7 @@ def expand_polygon(polygon):
def _merge_dict(config, merge_dct):
- """ Recursive dict merge. Inspired by :meth:``dict.update()``, instead of
+ """Recursive dict merge. Inspired by :meth:``dict.update()``, instead of
updating only top-level keys, dict_merge recurses down into dicts nested
to an arbitrary depth, updating keys. The ``merge_dct`` is merged into
    ``config``.
@@ -235,12 +233,15 @@ def _merge_dict(config, merge_dct):
    Returns: config
"""
for key, value in merge_dct.items():
- sub_keys = key.split('.')
+ sub_keys = key.split(".")
key = sub_keys[0]
if key in config and len(sub_keys) > 1:
- _merge_dict(config[key], {'.'.join(sub_keys[1:]): value})
- elif key in config and isinstance(config[key], dict) and isinstance(
- value, Mapping):
+ _merge_dict(config[key], {".".join(sub_keys[1:]): value})
+ elif (
+ key in config
+ and isinstance(config[key], dict)
+ and isinstance(value, Mapping)
+ ):
_merge_dict(config[key], value)
else:
config[key] = value
@@ -265,19 +266,19 @@ def print_dict(cfg, print_func=print, delimiter=0):
class Config(object):
- def __init__(self, config_path, BASE_KEY='base'):
+ def __init__(self, config_path, BASE_KEY="base"):
self.BASE_KEY = BASE_KEY
self.cfg = self._load_config_with_base(config_path)
def _load_config_with_base(self, file_path):
"""
- Load config from file.
- Args:
- file_path (str): Path of the config file to be loaded.
- Returns: global config
- """
+ Load config from file.
+ Args:
+ file_path (str): Path of the config file to be loaded.
+ Returns: global config
+ """
_, ext = os.path.splitext(file_path)
- assert ext in ['.yml', '.yaml'], "only support yaml files for now"
+ assert ext in [".yml", ".yaml"], "only support yaml files for now"
with open(file_path) as f:
file_cfg = yaml.load(f, Loader=yaml.Loader)
@@ -293,7 +294,7 @@ def _load_config_with_base(self, file_path):
del file_cfg[self.BASE_KEY]
file_cfg = _merge_dict(all_base_cfg, file_cfg)
- file_cfg['filename'] = os.path.splitext(os.path.split(file_path)[-1])[0]
+ file_cfg["filename"] = os.path.splitext(os.path.split(file_path)[-1])[0]
return file_cfg
def merge_dict(self, args):
@@ -304,37 +305,34 @@ def print_cfg(self, print_func=print):
        Recursively visualize a dict,
        indenting according to the relationship of keys.
"""
- print_func('----------- Config -----------')
+ print_func("----------- Config -----------")
print_dict(self.cfg, print_func)
- print_func('---------------------------------------------')
+ print_func("---------------------------------------------")
def save(self, p):
- with open(p, 'w') as f:
- yaml.dump(
- dict(self.cfg), f, default_flow_style=False, sort_keys=False)
+ with open(p, "w") as f:
+ yaml.dump(dict(self.cfg), f, default_flow_style=False, sort_keys=False)
class ArgsParser(ArgumentParser):
def __init__(self):
- super(ArgsParser, self).__init__(
- formatter_class=RawDescriptionHelpFormatter)
- self.add_argument(
- "-c", "--config_file", help="configuration file to use")
- self.add_argument(
- "-o", "--opt", nargs='*', help="set configuration options")
+ super(ArgsParser, self).__init__(formatter_class=RawDescriptionHelpFormatter)
+ self.add_argument("-c", "--config_file", help="configuration file to use")
+ self.add_argument("-o", "--opt", nargs="*", help="set configuration options")
self.add_argument(
- '-p',
- '--profiler_options',
+ "-p",
+ "--profiler_options",
type=str,
default=None,
- help='The option of profiler, which should be in format ' \
- '\"key1=value1;key2=value2;key3=value3\".'
+ help="The option of profiler, which should be in format "
+ '"key1=value1;key2=value2;key3=value3".',
)
def parse_args(self, argv=None):
args = super(ArgsParser, self).parse_args(argv)
- assert args.config_file is not None, \
- "Please specify --config_file=configure_file_path."
+ assert (
+ args.config_file is not None
+ ), "Please specify --config_file=configure_file_path."
args.opt = self._parse_opt(args.opt)
return args
@@ -344,11 +342,11 @@ def _parse_opt(self, opts):
return config
for s in opts:
s = s.strip()
- k, v = s.split('=', 1)
- if '.' not in k:
+ k, v = s.split("=", 1)
+ if "." not in k:
config[k] = yaml.load(v, Loader=yaml.Loader)
else:
- keys = k.split('.')
+ keys = k.split(".")
if keys[0] not in config:
config[keys[0]] = {}
cur = config[keys[0]]
@@ -361,7 +359,7 @@ def _parse_opt(self, opts):
return config
-if __name__ == '__main__':
+if __name__ == "__main__":
img = np.zeros((1, 3, 640, 640))
show_img(img[0][0])
plt.show()
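
_merge_dict's dotted-key convention lets a flat override such as "train.optimizer.name" reach into the nested config. An illustrative re-statement with a toy config (names are made up for the example):

from collections.abc import Mapping

# Illustrative re-statement of the recursive dotted-key merge above.
def merge_dict(config: dict, merge_dct: Mapping) -> dict:
    for key, value in merge_dct.items():
        sub_keys = key.split(".")
        key = sub_keys[0]
        if key in config and len(sub_keys) > 1:
            # Peel off one dotted level and recurse into the sub-dict.
            merge_dict(config[key], {".".join(sub_keys[1:]): value})
        elif key in config and isinstance(config[key], dict) and isinstance(value, Mapping):
            merge_dict(config[key], value)
        else:
            config[key] = value
    return config

cfg = {"train": {"lr": 0.01, "optimizer": {"name": "Adam"}}}
merge_dict(cfg, {"train.optimizer.name": "SGD", "train.batch_size": 8})
print(cfg)  # {'train': {'lr': 0.01, 'optimizer': {'name': 'SGD'}, 'batch_size': 8}}
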
diff --git a/benchmark/analysis.py b/benchmark/analysis.py
index 7322f00ace..bf6233fd53 100644
--- a/benchmark/analysis.py
+++ b/benchmark/analysis.py
@@ -24,68 +24,64 @@
def parse_args():
parser = argparse.ArgumentParser(description=__doc__)
parser.add_argument(
- "--filename", type=str, help="The name of log which need to analysis.")
+ "--filename", type=str, help="The name of log which need to analysis."
+ )
parser.add_argument(
- "--log_with_profiler",
- type=str,
- help="The path of train log with profiler")
- parser.add_argument(
- "--profiler_path", type=str, help="The path of profiler timeline log.")
+ "--log_with_profiler", type=str, help="The path of train log with profiler"
+ )
parser.add_argument(
- "--keyword", type=str, help="Keyword to specify analysis data")
+ "--profiler_path", type=str, help="The path of profiler timeline log."
+ )
+ parser.add_argument("--keyword", type=str, help="Keyword to specify analysis data")
parser.add_argument(
"--separator",
type=str,
default=None,
- help="Separator of different field in log")
+ help="Separator of different field in log",
+ )
parser.add_argument(
- '--position', type=int, default=None, help='The position of data field')
+ "--position", type=int, default=None, help="The position of data field"
+ )
parser.add_argument(
- '--range',
- type=str,
- default="",
- help='The range of data field to intercept')
+ "--range", type=str, default="", help="The range of data field to intercept"
+ )
+ parser.add_argument("--base_batch_size", type=int, help="base_batch size on gpu")
parser.add_argument(
- '--base_batch_size', type=int, help='base_batch size on gpu')
+ "--skip_steps", type=int, default=0, help="The number of steps to be skipped"
+ )
parser.add_argument(
- '--skip_steps',
- type=int,
- default=0,
- help='The number of steps to be skipped')
+ "--model_mode", type=int, default=-1, help="Analysis mode, default value is -1"
+ )
+ parser.add_argument("--ips_unit", type=str, default=None, help="IPS unit")
parser.add_argument(
- '--model_mode',
- type=int,
- default=-1,
- help='Analysis mode, default value is -1')
- parser.add_argument('--ips_unit', type=str, default=None, help='IPS unit')
- parser.add_argument(
- '--model_name',
+ "--model_name",
type=str,
default=0,
- help='training model_name, transformer_base')
+ help="training model_name, transformer_base",
+ )
parser.add_argument(
- '--mission_name', type=str, default=0, help='training mission name')
+ "--mission_name", type=str, default=0, help="training mission name"
+ )
parser.add_argument(
- '--direction_id', type=int, default=0, help='training direction_id')
+ "--direction_id", type=int, default=0, help="training direction_id"
+ )
parser.add_argument(
- '--run_mode',
- type=str,
- default="sp",
- help='multi process or single process')
+ "--run_mode", type=str, default="sp", help="multi process or single process"
+ )
parser.add_argument(
- '--index',
+ "--index",
type=int,
default=1,
- help='{1: speed, 2:mem, 3:profiler, 6:max_batch_size}')
- parser.add_argument(
- '--gpu_num', type=int, default=1, help='nums of training gpus')
+ help="{1: speed, 2:mem, 3:profiler, 6:max_batch_size}",
+ )
+ parser.add_argument("--gpu_num", type=int, default=1, help="nums of training gpus")
args = parser.parse_args()
args.separator = None if args.separator == "None" else args.separator
return args
def _is_number(num):
- pattern = re.compile(r'^[-+]?[-0-9]\d*\.\d*|[-+]?\.?[0-9]\d*$')
+ pattern = re.compile(r"^[-+]?[-0-9]\d*\.\d*|[-+]?\.?[0-9]\d*$")
result = pattern.match(num)
if result:
return True
@@ -94,12 +90,9 @@ def _is_number(num):
class TimeAnalyzer(object):
- def __init__(self,
- filename,
- keyword=None,
- separator=None,
- position=None,
- range="-1"):
+ def __init__(
+ self, filename, keyword=None, separator=None, position=None, range="-1"
+ ):
if filename is None:
raise Exception("Please specify the filename!")
@@ -126,8 +119,9 @@ def _distil(self):
# Distil the string from a line.
line = line.strip()
- line_words = line.split(
- self.separator) if self.separator else line.split()
+ line_words = (
+ line.split(self.separator) if self.separator else line.split()
+ )
if args.position:
result = line_words[self.position]
else:
@@ -141,31 +135,34 @@ def _distil(self):
if not self.range:
result = result[0:]
elif _is_number(self.range):
- result = result[0:int(self.range)]
+ result = result[0 : int(self.range)]
else:
- result = result[int(self.range.split(":")[0]):int(
- self.range.split(":")[1])]
+ result = result[
+ int(self.range.split(":")[0]) : int(
+ self.range.split(":")[1]
+ )
+ ]
self.records.append(float(result))
except Exception as exc:
- print("line is: {}; separator={}; position={}".format(
- line, self.separator, self.position))
+ print(
+ "line is: {}; separator={}; position={}".format(
+ line, self.separator, self.position
+ )
+ )
- print("Extract {} records: separator={}; position={}".format(
- len(self.records), self.separator, self.position))
+ print(
+ "Extract {} records: separator={}; position={}".format(
+ len(self.records), self.separator, self.position
+ )
+ )
- def _get_fps(self,
- mode,
- batch_size,
- gpu_num,
- avg_of_records,
- run_mode,
- unit=None):
- if mode == -1 and run_mode == 'sp':
+ def _get_fps(self, mode, batch_size, gpu_num, avg_of_records, run_mode, unit=None):
+ if mode == -1 and run_mode == "sp":
assert unit, "Please set the unit when mode is -1."
fps = gpu_num * avg_of_records
- elif mode == -1 and run_mode == 'mp':
+ elif mode == -1 and run_mode == "mp":
assert unit, "Please set the unit when mode is -1."
- fps = gpu_num * avg_of_records #temporarily, not used now
+ fps = gpu_num * avg_of_records # temporarily, not used now
print("------------this is mp")
elif mode == 0:
# s/step -> samples/s
@@ -192,22 +189,18 @@ def _get_fps(self,
return fps, unit
- def analysis(self,
- batch_size,
- gpu_num=1,
- skip_steps=0,
- mode=-1,
- run_mode='sp',
- unit=None):
+ def analysis(
+ self, batch_size, gpu_num=1, skip_steps=0, mode=-1, run_mode="sp", unit=None
+ ):
if batch_size <= 0:
print("base_batch_size should larger than 0.")
- return 0, ''
+ return 0, ""
- if len(
- self.records
- ) <= skip_steps: # to address the condition which item of log equals to skip_steps
+ if (
+ len(self.records) <= skip_steps
+ ): # to address the condition which item of log equals to skip_steps
print("no records")
- return 0, ''
+ return 0, ""
sum_of_records = 0
sum_of_records_skipped = 0
@@ -225,20 +218,20 @@ def analysis(self,
skip_max = self.records[i]
avg_of_records = sum_of_records / float(count)
- avg_of_records_skipped = sum_of_records_skipped / float(count -
- skip_steps)
+ avg_of_records_skipped = sum_of_records_skipped / float(count - skip_steps)
- fps, fps_unit = self._get_fps(mode, batch_size, gpu_num, avg_of_records,
- run_mode, unit)
- fps_skipped, _ = self._get_fps(mode, batch_size, gpu_num,
- avg_of_records_skipped, run_mode, unit)
+ fps, fps_unit = self._get_fps(
+ mode, batch_size, gpu_num, avg_of_records, run_mode, unit
+ )
+ fps_skipped, _ = self._get_fps(
+ mode, batch_size, gpu_num, avg_of_records_skipped, run_mode, unit
+ )
if mode == -1:
print("average ips of %d steps, skip 0 step:" % count)
print("\tAvg: %.3f %s" % (avg_of_records, fps_unit))
print("\tFPS: %.3f %s" % (fps, fps_unit))
if skip_steps > 0:
- print("average ips of %d steps, skip %d steps:" %
- (count, skip_steps))
+ print("average ips of %d steps, skip %d steps:" % (count, skip_steps))
print("\tAvg: %.3f %s" % (avg_of_records_skipped, fps_unit))
print("\tMin: %.3f %s" % (skip_min, fps_unit))
print("\tMax: %.3f %s" % (skip_max, fps_unit))
@@ -248,8 +241,9 @@ def analysis(self,
print("\tAvg: %.3f steps/s" % avg_of_records)
print("\tFPS: %.3f %s" % (fps, fps_unit))
if skip_steps > 0:
- print("average latency of %d steps, skip %d steps:" %
- (count, skip_steps))
+ print(
+ "average latency of %d steps, skip %d steps:" % (count, skip_steps)
+ )
print("\tAvg: %.3f steps/s" % avg_of_records_skipped)
print("\tMin: %.3f steps/s" % skip_min)
print("\tMax: %.3f steps/s" % skip_max)
@@ -259,8 +253,9 @@ def analysis(self,
print("\tAvg: %.3f s/step" % avg_of_records)
print("\tFPS: %.3f %s" % (fps, fps_unit))
if skip_steps > 0:
- print("average latency of %d steps, skip %d steps:" %
- (count, skip_steps))
+ print(
+ "average latency of %d steps, skip %d steps:" % (count, skip_steps)
+ )
print("\tAvg: %.3f s/step" % avg_of_records_skipped)
print("\tMin: %.3f s/step" % skip_min)
print("\tMax: %.3f s/step" % skip_max)
@@ -287,60 +282,73 @@ def analysis(self,
if args.gpu_num == 1:
run_info["log_with_profiler"] = args.log_with_profiler
run_info["profiler_path"] = args.profiler_path
- analyzer = TimeAnalyzer(args.filename, args.keyword, args.separator,
- args.position, args.range)
+ analyzer = TimeAnalyzer(
+ args.filename, args.keyword, args.separator, args.position, args.range
+ )
run_info["FINAL_RESULT"], run_info["UNIT"] = analyzer.analysis(
batch_size=args.base_batch_size,
gpu_num=args.gpu_num,
skip_steps=args.skip_steps,
mode=args.model_mode,
run_mode=args.run_mode,
- unit=args.ips_unit)
+ unit=args.ips_unit,
+ )
try:
- if int(os.getenv('job_fail_flag')) == 1 or int(run_info[
- "FINAL_RESULT"]) == 0:
+ if (
+ int(os.getenv("job_fail_flag")) == 1
+ or int(run_info["FINAL_RESULT"]) == 0
+ ):
run_info["JOB_FAIL_FLAG"] = 1
except:
pass
elif args.index == 3:
run_info["FINAL_RESULT"] = {}
- records_fo_total = TimeAnalyzer(args.filename, 'Framework overhead',
- None, 3, '').records
- records_fo_ratio = TimeAnalyzer(args.filename, 'Framework overhead',
- None, 5).records
- records_ct_total = TimeAnalyzer(args.filename, 'Computation time',
- None, 3, '').records
- records_gm_total = TimeAnalyzer(args.filename,
- 'GpuMemcpy Calls',
- None, 4, '').records
- records_gm_ratio = TimeAnalyzer(args.filename,
- 'GpuMemcpy Calls',
- None, 6).records
- records_gmas_total = TimeAnalyzer(args.filename,
- 'GpuMemcpyAsync Calls',
- None, 4, '').records
- records_gms_total = TimeAnalyzer(args.filename,
- 'GpuMemcpySync Calls',
- None, 4, '').records
- run_info["FINAL_RESULT"]["Framework_Total"] = records_fo_total[
- 0] if records_fo_total else 0
- run_info["FINAL_RESULT"]["Framework_Ratio"] = records_fo_ratio[
- 0] if records_fo_ratio else 0
- run_info["FINAL_RESULT"][
- "ComputationTime_Total"] = records_ct_total[
- 0] if records_ct_total else 0
- run_info["FINAL_RESULT"]["GpuMemcpy_Total"] = records_gm_total[
- 0] if records_gm_total else 0
- run_info["FINAL_RESULT"]["GpuMemcpy_Ratio"] = records_gm_ratio[
- 0] if records_gm_ratio else 0
- run_info["FINAL_RESULT"][
- "GpuMemcpyAsync_Total"] = records_gmas_total[
- 0] if records_gmas_total else 0
- run_info["FINAL_RESULT"]["GpuMemcpySync_Total"] = records_gms_total[
- 0] if records_gms_total else 0
+ records_fo_total = TimeAnalyzer(
+ args.filename, "Framework overhead", None, 3, ""
+ ).records
+ records_fo_ratio = TimeAnalyzer(
+ args.filename, "Framework overhead", None, 5
+ ).records
+ records_ct_total = TimeAnalyzer(
+ args.filename, "Computation time", None, 3, ""
+ ).records
+ records_gm_total = TimeAnalyzer(
+ args.filename, "GpuMemcpy Calls", None, 4, ""
+ ).records
+ records_gm_ratio = TimeAnalyzer(
+ args.filename, "GpuMemcpy Calls", None, 6
+ ).records
+ records_gmas_total = TimeAnalyzer(
+ args.filename, "GpuMemcpyAsync Calls", None, 4, ""
+ ).records
+ records_gms_total = TimeAnalyzer(
+ args.filename, "GpuMemcpySync Calls", None, 4, ""
+ ).records
+ run_info["FINAL_RESULT"]["Framework_Total"] = (
+ records_fo_total[0] if records_fo_total else 0
+ )
+ run_info["FINAL_RESULT"]["Framework_Ratio"] = (
+ records_fo_ratio[0] if records_fo_ratio else 0
+ )
+ run_info["FINAL_RESULT"]["ComputationTime_Total"] = (
+ records_ct_total[0] if records_ct_total else 0
+ )
+ run_info["FINAL_RESULT"]["GpuMemcpy_Total"] = (
+ records_gm_total[0] if records_gm_total else 0
+ )
+ run_info["FINAL_RESULT"]["GpuMemcpy_Ratio"] = (
+ records_gm_ratio[0] if records_gm_ratio else 0
+ )
+ run_info["FINAL_RESULT"]["GpuMemcpyAsync_Total"] = (
+ records_gmas_total[0] if records_gmas_total else 0
+ )
+ run_info["FINAL_RESULT"]["GpuMemcpySync_Total"] = (
+ records_gms_total[0] if records_gms_total else 0
+ )
else:
print("Not support!")
except Exception:
traceback.print_exc()
- print("{}".format(json.dumps(run_info))
- ) # it's required, for the log file path insert to the database
+ print(
+ "{}".format(json.dumps(run_info))
+ ) # it's required, for the log file path insert to the database
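
The speed analysis hinges on skip-steps averaging: the first skip_steps records are treated as warmup and excluded from the reported average, min, and max. A standalone sketch with made-up per-step timings:

# Made-up per-step timings (seconds); the first records are warmup.
records = [2.0, 1.2, 1.0, 1.0, 1.1, 0.9]
skip_steps = 2

avg_all = sum(records) / len(records)       # 1.2, inflated by warmup
skipped = records[skip_steps:]
avg_skipped = sum(skipped) / len(skipped)   # 1.0, the steady-state figure

print("Avg (skip 0):", round(avg_all, 3))
print("Avg (skip %d):" % skip_steps, round(avg_skipped, 3))
print("Min/Max after skip:", min(skipped), max(skipped))
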
diff --git a/configs/rec/multi_language/generate_multi_language_configs.py b/configs/rec/multi_language/generate_multi_language_configs.py
index 6759ca2a46..e41be8f776 100644
--- a/configs/rec/multi_language/generate_multi_language_configs.py
+++ b/configs/rec/multi_language/generate_multi_language_configs.py
@@ -16,89 +16,155 @@
from argparse import ArgumentParser, RawDescriptionHelpFormatter
import os.path
import logging
+
logging.basicConfig(level=logging.INFO)
support_list = {
- 'it': 'italian',
- 'xi': 'spanish',
- 'pu': 'portuguese',
- 'ru': 'russian',
- 'ar': 'arabic',
- 'ta': 'tamil',
- 'ug': 'uyghur',
- 'fa': 'persian',
- 'ur': 'urdu',
- 'rs': 'serbian latin',
- 'oc': 'occitan',
- 'rsc': 'serbian cyrillic',
- 'bg': 'bulgarian',
- 'uk': 'ukranian',
- 'be': 'belarusian',
- 'te': 'telugu',
- 'ka': 'kannada',
- 'chinese_cht': 'chinese tradition',
- 'hi': 'hindi',
- 'mr': 'marathi',
- 'ne': 'nepali',
+ "it": "italian",
+ "xi": "spanish",
+ "pu": "portuguese",
+ "ru": "russian",
+ "ar": "arabic",
+ "ta": "tamil",
+ "ug": "uyghur",
+ "fa": "persian",
+ "ur": "urdu",
+ "rs": "serbian latin",
+ "oc": "occitan",
+ "rsc": "serbian cyrillic",
+ "bg": "bulgarian",
+ "uk": "ukranian",
+ "be": "belarusian",
+ "te": "telugu",
+ "ka": "kannada",
+ "chinese_cht": "chinese tradition",
+ "hi": "hindi",
+ "mr": "marathi",
+ "ne": "nepali",
}
latin_lang = [
- 'af', 'az', 'bs', 'cs', 'cy', 'da', 'de', 'es', 'et', 'fr', 'ga', 'hr',
- 'hu', 'id', 'is', 'it', 'ku', 'la', 'lt', 'lv', 'mi', 'ms', 'mt', 'nl',
- 'no', 'oc', 'pi', 'pl', 'pt', 'ro', 'rs_latin', 'sk', 'sl', 'sq', 'sv',
- 'sw', 'tl', 'tr', 'uz', 'vi', 'latin'
+ "af",
+ "az",
+ "bs",
+ "cs",
+ "cy",
+ "da",
+ "de",
+ "es",
+ "et",
+ "fr",
+ "ga",
+ "hr",
+ "hu",
+ "id",
+ "is",
+ "it",
+ "ku",
+ "la",
+ "lt",
+ "lv",
+ "mi",
+ "ms",
+ "mt",
+ "nl",
+ "no",
+ "oc",
+ "pi",
+ "pl",
+ "pt",
+ "ro",
+ "rs_latin",
+ "sk",
+ "sl",
+ "sq",
+ "sv",
+ "sw",
+ "tl",
+ "tr",
+ "uz",
+ "vi",
+ "latin",
]
-arabic_lang = ['ar', 'fa', 'ug', 'ur']
+arabic_lang = ["ar", "fa", "ug", "ur"]
cyrillic_lang = [
- 'ru', 'rs_cyrillic', 'be', 'bg', 'uk', 'mn', 'abq', 'ady', 'kbd', 'ava',
- 'dar', 'inh', 'che', 'lbe', 'lez', 'tab', 'cyrillic'
+ "ru",
+ "rs_cyrillic",
+ "be",
+ "bg",
+ "uk",
+ "mn",
+ "abq",
+ "ady",
+ "kbd",
+ "ava",
+ "dar",
+ "inh",
+ "che",
+ "lbe",
+ "lez",
+ "tab",
+ "cyrillic",
]
devanagari_lang = [
- 'hi', 'mr', 'ne', 'bh', 'mai', 'ang', 'bho', 'mah', 'sck', 'new', 'gom',
- 'sa', 'bgc', 'devanagari'
+ "hi",
+ "mr",
+ "ne",
+ "bh",
+ "mai",
+ "ang",
+ "bho",
+ "mah",
+ "sck",
+ "new",
+ "gom",
+ "sa",
+ "bgc",
+ "devanagari",
]
multi_lang = latin_lang + arabic_lang + cyrillic_lang + devanagari_lang
-assert (os.path.isfile("./rec_multi_language_lite_train.yml")
- ), "Loss basic configuration file rec_multi_language_lite_train.yml.\
+assert os.path.isfile(
+ "./rec_multi_language_lite_train.yml"
+), "Loss basic configuration file rec_multi_language_lite_train.yml.\
You can download it from \
https://github.com/PaddlePaddle/PaddleOCR/tree/dygraph/configs/rec/multi_language/"
global_config = yaml.load(
- open("./rec_multi_language_lite_train.yml", 'rb'), Loader=yaml.Loader)
+ open("./rec_multi_language_lite_train.yml", "rb"), Loader=yaml.Loader
+)
project_path = os.path.abspath(os.path.join(os.getcwd(), "../../../"))
class ArgsParser(ArgumentParser):
def __init__(self):
- super(ArgsParser, self).__init__(
- formatter_class=RawDescriptionHelpFormatter)
- self.add_argument(
- "-o", "--opt", nargs='+', help="set configuration options")
+ super(ArgsParser, self).__init__(formatter_class=RawDescriptionHelpFormatter)
+ self.add_argument("-o", "--opt", nargs="+", help="set configuration options")
self.add_argument(
"-l",
"--language",
- nargs='+',
- help="set language type, support {}".format(support_list))
+ nargs="+",
+ help="set language type, support {}".format(support_list),
+ )
self.add_argument(
"--train",
type=str,
- help="you can use this command to change the train dataset default path"
+ help="you can use this command to change the train dataset default path",
)
self.add_argument(
"--val",
type=str,
- help="you can use this command to change the eval dataset default path"
+ help="you can use this command to change the eval dataset default path",
)
self.add_argument(
"--dict",
type=str,
- help="you can use this command to change the dictionary default path"
+ help="you can use this command to change the dictionary default path",
)
self.add_argument(
"--data_dir",
type=str,
- help="you can use this command to change the dataset default root path"
+ help="you can use this command to change the dataset default root path",
)
def parse_args(self, argv=None):
@@ -113,17 +179,17 @@ def _parse_opt(self, opts):
return config
for s in opts:
s = s.strip()
- k, v = s.split('=')
+ k, v = s.split("=")
config[k] = yaml.load(v, Loader=yaml.Loader)
return config
def _set_language(self, type):
lang = type[0]
- assert (type), "please use -l or --language to choose language type"
- assert(
- lang in support_list.keys() or lang in multi_lang
- ),"the sub_keys(-l or --language) can only be one of support list: \n{},\nbut get: {}, " \
- "please check your running command".format(multi_lang, type)
+ assert type, "please use -l or --language to choose language type"
+ assert lang in support_list.keys() or lang in multi_lang, (
+ "the sub_keys(-l or --language) can only be one of support list: \n{},\nbut get: {}, "
+ "please check your running command".format(multi_lang, type)
+ )
if lang in latin_lang:
lang = "latin"
elif lang in arabic_lang:
@@ -132,22 +198,23 @@ def _set_language(self, type):
lang = "cyrillic"
elif lang in devanagari_lang:
lang = "devanagari"
- global_config['Global'][
- 'character_dict_path'] = 'ppocr/utils/dict/{}_dict.txt'.format(lang)
- global_config['Global'][
- 'save_model_dir'] = './output/rec_{}_lite'.format(lang)
- global_config['Train']['dataset'][
- 'label_file_list'] = ["train_data/{}_train.txt".format(lang)]
- global_config['Eval']['dataset'][
- 'label_file_list'] = ["train_data/{}_val.txt".format(lang)]
- global_config['Global']['character_type'] = lang
- assert (
- os.path.isfile(
- os.path.join(project_path, global_config['Global'][
- 'character_dict_path']))
+ global_config["Global"][
+ "character_dict_path"
+ ] = "ppocr/utils/dict/{}_dict.txt".format(lang)
+ global_config["Global"]["save_model_dir"] = "./output/rec_{}_lite".format(lang)
+ global_config["Train"]["dataset"]["label_file_list"] = [
+ "train_data/{}_train.txt".format(lang)
+ ]
+ global_config["Eval"]["dataset"]["label_file_list"] = [
+ "train_data/{}_val.txt".format(lang)
+ ]
+ global_config["Global"]["character_type"] = lang
+ assert os.path.isfile(
+ os.path.join(project_path, global_config["Global"]["character_dict_path"])
), "Loss default dictionary file {}_dict.txt.You can download it from \
https://github.com/PaddlePaddle/PaddleOCR/tree/dygraph/ppocr/utils/dict/".format(
- lang)
+ lang
+ )
return lang
@@ -165,11 +232,12 @@ def merge_config(config):
else:
global_config[key] = value
else:
- sub_keys = key.split('.')
+ sub_keys = key.split(".")
assert (
sub_keys[0] in global_config
), "the sub_keys can only be one of global_config: {}, but get: {}, please check your running command".format(
- global_config.keys(), sub_keys[0])
+ global_config.keys(), sub_keys[0]
+ )
cur = global_config[sub_keys[0]]
for idx, sub_key in enumerate(sub_keys[1:]):
if idx == len(sub_keys) - 2:
@@ -179,48 +247,61 @@ def merge_config(config):
def loss_file(path):
- assert (
- os.path.exists(path)
+ assert os.path.exists(
+ path
), "There is no such file:{},Please do not forget to put in the specified file".format(
- path)
+ path
+ )
-if __name__ == '__main__':
+if __name__ == "__main__":
FLAGS = ArgsParser().parse_args()
merge_config(FLAGS.opt)
- save_file_path = 'rec_{}_lite_train.yml'.format(FLAGS.language)
+ save_file_path = "rec_{}_lite_train.yml".format(FLAGS.language)
if os.path.isfile(save_file_path):
os.remove(save_file_path)
if FLAGS.train:
- global_config['Train']['dataset']['label_file_list'] = [FLAGS.train]
+ global_config["Train"]["dataset"]["label_file_list"] = [FLAGS.train]
train_label_path = os.path.join(project_path, FLAGS.train)
loss_file(train_label_path)
if FLAGS.val:
- global_config['Eval']['dataset']['label_file_list'] = [FLAGS.val]
+ global_config["Eval"]["dataset"]["label_file_list"] = [FLAGS.val]
eval_label_path = os.path.join(project_path, FLAGS.val)
loss_file(eval_label_path)
if FLAGS.dict:
- global_config['Global']['character_dict_path'] = FLAGS.dict
+ global_config["Global"]["character_dict_path"] = FLAGS.dict
dict_path = os.path.join(project_path, FLAGS.dict)
loss_file(dict_path)
if FLAGS.data_dir:
- global_config['Eval']['dataset']['data_dir'] = FLAGS.data_dir
- global_config['Train']['dataset']['data_dir'] = FLAGS.data_dir
+ global_config["Eval"]["dataset"]["data_dir"] = FLAGS.data_dir
+ global_config["Train"]["dataset"]["data_dir"] = FLAGS.data_dir
data_dir = os.path.join(project_path, FLAGS.data_dir)
loss_file(data_dir)
- with open(save_file_path, 'w') as f:
- yaml.dump(
- dict(global_config), f, default_flow_style=False, sort_keys=False)
+ with open(save_file_path, "w") as f:
+ yaml.dump(dict(global_config), f, default_flow_style=False, sort_keys=False)
logging.info("Project path is :{}".format(project_path))
- logging.info("Train list path set to :{}".format(global_config['Train'][
- 'dataset']['label_file_list'][0]))
- logging.info("Eval list path set to :{}".format(global_config['Eval'][
- 'dataset']['label_file_list'][0]))
- logging.info("Dataset root path set to :{}".format(global_config['Eval'][
- 'dataset']['data_dir']))
- logging.info("Dict path set to :{}".format(global_config['Global'][
- 'character_dict_path']))
- logging.info("Config file set to :configs/rec/multi_language/{}".
- format(save_file_path))
+ logging.info(
+ "Train list path set to :{}".format(
+ global_config["Train"]["dataset"]["label_file_list"][0]
+ )
+ )
+ logging.info(
+ "Eval list path set to :{}".format(
+ global_config["Eval"]["dataset"]["label_file_list"][0]
+ )
+ )
+ logging.info(
+ "Dataset root path set to :{}".format(
+ global_config["Eval"]["dataset"]["data_dir"]
+ )
+ )
+ logging.info(
+ "Dict path set to :{}".format(
+ global_config["Global"]["character_dict_path"]
+ )
+ )
+ logging.info(
+ "Config file set to :configs/rec/multi_language/{}".format(save_file_path)
+ )
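Editor's note: the override logic above walks a dotted `-o key.sub=value` path into the nested config dict before dumping it back to YAML. A minimal sketch of that mechanism, with illustrative names rather than the script's actual globals:

```python
# Hypothetical stand-in for merge_config's dotted-key handling; `cfg` is any
# nested dict shaped like the Global/Train/Eval sections above.
def set_by_dotted_key(cfg, key, value):
    sub_keys = key.split(".")
    cur = cfg
    for sub_key in sub_keys[:-1]:
        cur = cur[sub_key]  # descend one level per dotted component
    cur[sub_keys[-1]] = value  # assign at the leaf

cfg = {"Train": {"dataset": {"label_file_list": None}}}
set_by_dotted_key(cfg, "Train.dataset.label_file_list", ["train_data/it_train.txt"])
assert cfg["Train"]["dataset"]["label_file_list"] == ["train_data/it_train.txt"]
```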
diff --git a/deploy/README.md b/deploy/README.md
index 0cfb793f92..d0ec44d176 100644
--- a/deploy/README.md
+++ b/deploy/README.md
@@ -8,7 +8,7 @@ English | [简体中文](README_ch.md)
## Paddle Deployment Introduction
-Paddle provides a variety of deployment schemes to meet the deployment requirements of different scenarios. Please choose according to the actual situation:
+Paddle provides a variety of deployment schemes to meet the deployment requirements of different scenarios. Please choose according to the actual situation:
diff --git a/deploy/README_ch.md b/deploy/README_ch.md
index 1773aedc2c..057d387ba8 100644
--- a/deploy/README_ch.md
+++ b/deploy/README_ch.md
@@ -28,4 +28,4 @@ PP-OCR模型已打通多种场景部署方案,点击链接获取具体的使
- [Jetson 推理](https://github.com/PaddlePaddle/PaddleOCR/blob/dygraph/deploy/Jetson/readme_ch.md)
- [Paddle2ONNX 推理](./paddle2onnx/readme_ch.md)
-需要PP-OCR以外的学术算法模型的推理部署,请直接进入相应算法主页面,[入口](../doc/doc_ch/algorithm_overview.md)。
\ No newline at end of file
+需要PP-OCR以外的学术算法模型的推理部署,请直接进入相应算法主页面,[入口](../doc/doc_ch/algorithm_overview.md)。
diff --git a/deploy/android_demo/app/src/main/cpp/native.cpp b/deploy/android_demo/app/src/main/cpp/native.cpp
index 4961e5ecf1..5674f64143 100644
--- a/deploy/android_demo/app/src/main/cpp/native.cpp
+++ b/deploy/android_demo/app/src/main/cpp/native.cpp
@@ -13,8 +13,8 @@ static paddle::lite_api::PowerMode str_to_cpu_mode(const std::string &cpu_mode);
extern "C" JNIEXPORT jlong JNICALL
Java_com_baidu_paddle_lite_demo_ocr_OCRPredictorNative_init(
JNIEnv *env, jobject thiz, jstring j_det_model_path,
- jstring j_rec_model_path, jstring j_cls_model_path, jint j_use_opencl, jint j_thread_num,
- jstring j_cpu_mode) {
+ jstring j_rec_model_path, jstring j_cls_model_path, jint j_use_opencl,
+ jint j_thread_num, jstring j_cpu_mode) {
std::string det_model_path = jstring_to_cpp_string(env, j_det_model_path);
std::string rec_model_path = jstring_to_cpp_string(env, j_rec_model_path);
std::string cls_model_path = jstring_to_cpp_string(env, j_cls_model_path);
@@ -58,7 +58,8 @@ str_to_cpu_mode(const std::string &cpu_mode) {
extern "C" JNIEXPORT jfloatArray JNICALL
Java_com_baidu_paddle_lite_demo_ocr_OCRPredictorNative_forward(
- JNIEnv *env, jobject thiz, jlong java_pointer, jobject original_image,jint j_max_size_len, jint j_run_det, jint j_run_cls, jint j_run_rec) {
+ JNIEnv *env, jobject thiz, jlong java_pointer, jobject original_image,
+ jint j_max_size_len, jint j_run_det, jint j_run_cls, jint j_run_rec) {
LOGI("begin to run native forward");
if (java_pointer == 0) {
LOGE("JAVA pointer is NULL");
diff --git a/deploy/android_demo/app/src/main/cpp/native.h b/deploy/android_demo/app/src/main/cpp/native.h
index 9b8e4e40e3..3e159ad56c 100644
--- a/deploy/android_demo/app/src/main/cpp/native.h
+++ b/deploy/android_demo/app/src/main/cpp/native.h
@@ -47,8 +47,8 @@ inline jstring cpp_string_to_jstring(JNIEnv *env, std::string str) {
reinterpret_cast<const jbyte *>(data));
jstring encoding = env->NewStringUTF("UTF-8");
- jstring res = (jstring)(
- env->NewObject(strClass, strClassInitMethodID, bytes, encoding));
+ jstring res = (jstring)(env->NewObject(strClass, strClassInitMethodID, bytes,
+ encoding));
env->DeleteLocalRef(strClass);
env->DeleteLocalRef(encoding);
diff --git a/deploy/android_demo/app/src/main/cpp/ocr_clipper.cpp b/deploy/android_demo/app/src/main/cpp/ocr_clipper.cpp
index 4a531fcf4f..f81efa85e5 100644
--- a/deploy/android_demo/app/src/main/cpp/ocr_clipper.cpp
+++ b/deploy/android_demo/app/src/main/cpp/ocr_clipper.cpp
@@ -1,42 +1,38 @@
/*******************************************************************************
-* *
-* Author : Angus Johnson *
-* Version : 6.4.2 *
-* Date : 27 February 2017 *
-* Website : http://www.angusj.com *
-* Copyright : Angus Johnson 2010-2017 *
-* *
-* License: *
-* Use, modification & distribution is subject to Boost Software License Ver 1. *
-* http://www.boost.org/LICENSE_1_0.txt *
-* *
-* Attributions: *
-* The code in this library is an extension of Bala Vatti's clipping algorithm: *
-* "A generic solution to polygon clipping" *
-* Communications of the ACM, Vol 35, Issue 7 (July 1992) pp 56-63. *
-* http://portal.acm.org/citation.cfm?id=129906 *
-* *
-* Computer graphics and geometric modeling: implementation and algorithms *
-* By Max K. Agoston *
-* Springer; 1 edition (January 4, 2005) *
-* http://books.google.com/books?q=vatti+clipping+agoston *
-* *
-* See also: *
-* "Polygon Offsetting by Computing Winding Numbers" *
-* Paper no. DETC2005-85513 pp. 565-575 *
-* ASME 2005 International Design Engineering Technical Conferences *
-* and Computers and Information in Engineering Conference (IDETC/CIE2005) *
-* September 24-28, 2005 , Long Beach, California, USA *
-* http://www.me.berkeley.edu/~mcmains/pubs/DAC05OffsetPolygon.pdf *
-* *
-*******************************************************************************/
+ * Author    : Angus Johnson
+ * Version   : 6.4.2
+ * Date      : 27 February 2017
+ * Website   : http://www.angusj.com
+ * Copyright : Angus Johnson 2010-2017
+ *
+ * License:
+ * Use, modification & distribution is subject to Boost Software License Ver 1.
+ * http://www.boost.org/LICENSE_1_0.txt
+ *
+ * Attributions:
+ * The code in this library is an extension of Bala Vatti's clipping algorithm:
+ * "A generic solution to polygon clipping"
+ * Communications of the ACM, Vol 35, Issue 7 (July 1992) pp 56-63.
+ * http://portal.acm.org/citation.cfm?id=129906
+ *
+ * Computer graphics and geometric modeling: implementation and algorithms
+ * By Max K. Agoston
+ * Springer; 1 edition (January 4, 2005)
+ * http://books.google.com/books?q=vatti+clipping+agoston
+ *
+ * See also:
+ * "Polygon Offsetting by Computing Winding Numbers"
+ * Paper no. DETC2005-85513 pp. 565-575
+ * ASME 2005 International Design Engineering Technical Conferences
+ * and Computers and Information in Engineering Conference (IDETC/CIE2005)
+ * September 24-28, 2005 , Long Beach, California, USA
+ * http://www.me.berkeley.edu/~mcmains/pubs/DAC05OffsetPolygon.pdf
+ *
+ *******************************************************************************/
/*******************************************************************************
-* *
-* This is a translation of the Delphi Clipper library and the naming style *
-* used has retained a Delphi flavour. *
-* *
-*******************************************************************************/
+ * *
+ * This is a translation of the Delphi Clipper library and the naming style *
+ * used has retained a Delphi flavour. *
+ * *
+ *******************************************************************************/
#include "ocr_clipper.hpp"
#include
@@ -1045,8 +1041,9 @@ bool ClipperBase::AddPath(const Path &pg, PolyType PolyTyp, bool Closed) {
}
if (E->Prev == E->Next)
break; // only two vertices
- else if (Closed && SlopesEqual(E->Prev->Curr, E->Curr, E->Next->Curr,
- m_UseFullRange) &&
+ else if (Closed &&
+ SlopesEqual(E->Prev->Curr, E->Curr, E->Next->Curr,
+ m_UseFullRange) &&
(!m_PreserveCollinear ||
!Pt2IsBetweenPt1AndPt3(E->Prev->Curr, E->Curr, E->Next->Curr))) {
// Collinear edges are allowed for open paths but in closed paths
@@ -2518,14 +2515,14 @@ void GetHorzDirection(TEdge &HorzEdge, Direction &Dir, cInt &Left,
//------------------------------------------------------------------------
/*******************************************************************************
-* Notes: Horizontal edges (HEs) at scanline intersections (ie at the Top or *
-* Bottom of a scanbeam) are processed as if layered. The order in which HEs *
-* are processed doesn't matter. HEs intersect with other HE Bot.Xs only [#] *
-* (or they could intersect with Top.Xs only, ie EITHER Bot.Xs OR Top.Xs), *
-* and with other non-horizontal edges [*]. Once these intersections are *
-* processed, intermediate HEs then 'promote' the Edge above (NextInLML) into *
-* the AEL. These 'promoted' edges may in turn intersect [%] with other HEs. *
-*******************************************************************************/
+ * Notes: Horizontal edges (HEs) at scanline intersections (ie at the Top or
+ * Bottom of a scanbeam) are processed as if layered. The order in which HEs
+ * are processed doesn't matter. HEs intersect with other HE Bot.Xs only [#]
+ * (or they could intersect with Top.Xs only, ie EITHER Bot.Xs OR Top.Xs),
+ * and with other non-horizontal edges [*]. Once these intersections are
+ * processed, intermediate HEs then 'promote' the Edge above (NextInLML) into
+ * the AEL. These 'promoted' edges may in turn intersect [%] with other HEs.
+ *******************************************************************************/
void Clipper::ProcessHorizontal(TEdge *horzEdge) {
Direction dir;
@@ -4377,4 +4374,4 @@ std::ostream &operator<<(std::ostream &s, const Paths &p) {
}
//------------------------------------------------------------------------------
-} // ClipperLib namespace
+} // namespace ClipperLib
diff --git a/deploy/android_demo/app/src/main/cpp/ocr_clipper.hpp b/deploy/android_demo/app/src/main/cpp/ocr_clipper.hpp
index 60af2bb733..54f0c4e36e 100644
--- a/deploy/android_demo/app/src/main/cpp/ocr_clipper.hpp
+++ b/deploy/android_demo/app/src/main/cpp/ocr_clipper.hpp
@@ -1,35 +1,31 @@
/*******************************************************************************
-* *
-* Author : Angus Johnson *
-* Version : 6.4.2 *
-* Date : 27 February 2017 *
-* Website : http://www.angusj.com *
-* Copyright : Angus Johnson 2010-2017 *
-* *
-* License: *
-* Use, modification & distribution is subject to Boost Software License Ver 1. *
-* http://www.boost.org/LICENSE_1_0.txt *
-* *
-* Attributions: *
-* The code in this library is an extension of Bala Vatti's clipping algorithm: *
-* "A generic solution to polygon clipping" *
-* Communications of the ACM, Vol 35, Issue 7 (July 1992) pp 56-63. *
-* http://portal.acm.org/citation.cfm?id=129906 *
-* *
-* Computer graphics and geometric modeling: implementation and algorithms *
-* By Max K. Agoston *
-* Springer; 1 edition (January 4, 2005) *
-* http://books.google.com/books?q=vatti+clipping+agoston *
-* *
-* See also: *
-* "Polygon Offsetting by Computing Winding Numbers" *
-* Paper no. DETC2005-85513 pp. 565-575 *
-* ASME 2005 International Design Engineering Technical Conferences *
-* and Computers and Information in Engineering Conference (IDETC/CIE2005) *
-* September 24-28, 2005 , Long Beach, California, USA *
-* http://www.me.berkeley.edu/~mcmains/pubs/DAC05OffsetPolygon.pdf *
-* *
-*******************************************************************************/
+ * Author    : Angus Johnson
+ * Version   : 6.4.2
+ * Date      : 27 February 2017
+ * Website   : http://www.angusj.com
+ * Copyright : Angus Johnson 2010-2017
+ *
+ * License:
+ * Use, modification & distribution is subject to Boost Software License Ver 1.
+ * http://www.boost.org/LICENSE_1_0.txt
+ *
+ * Attributions:
+ * The code in this library is an extension of Bala Vatti's clipping algorithm:
+ * "A generic solution to polygon clipping"
+ * Communications of the ACM, Vol 35, Issue 7 (July 1992) pp 56-63.
+ * http://portal.acm.org/citation.cfm?id=129906
+ *
+ * Computer graphics and geometric modeling: implementation and algorithms
+ * By Max K. Agoston
+ * Springer; 1 edition (January 4, 2005)
+ * http://books.google.com/books?q=vatti+clipping+agoston
+ *
+ * See also:
+ * "Polygon Offsetting by Computing Winding Numbers"
+ * Paper no. DETC2005-85513 pp. 565-575
+ * ASME 2005 International Design Engineering Technical Conferences
+ * and Computers and Information in Engineering Conference (IDETC/CIE2005)
+ * September 24-28, 2005 , Long Beach, California, USA
+ * http://www.me.berkeley.edu/~mcmains/pubs/DAC05OffsetPolygon.pdf
+ *
+ *******************************************************************************/
#ifndef clipper_hpp
#define clipper_hpp
@@ -539,6 +535,6 @@ class clipperException : public std::exception {
};
//------------------------------------------------------------------------------
-} // ClipperLib namespace
+} // namespace ClipperLib
#endif // clipper_hpp
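Editor's note, for context on why this Clipper code ships with the demos: PaddleOCR's DB-style postprocessing offsets each detected polygon outward ("unclip") to recover the full text region. A rough sketch of that use, assuming the `pyclipper` binding of this same library is installed; the area/perimeter heuristic is the common DB formulation, not necessarily this repo's exact code:

```python
import numpy as np
import pyclipper  # Python binding of the Clipper library vendored above

def unclip(box, unclip_ratio=1.5):
    pts = np.array(box, dtype=np.float64)
    x, y = pts[:, 0], pts[:, 1]
    # shoelace area and polygon perimeter drive the offset distance
    area = 0.5 * abs(np.dot(x, np.roll(y, 1)) - np.dot(y, np.roll(x, 1)))
    perimeter = np.linalg.norm(pts - np.roll(pts, 1, axis=0), axis=1).sum()
    distance = area * unclip_ratio / perimeter
    offset = pyclipper.PyclipperOffset()
    offset.AddPath(box, pyclipper.JT_ROUND, pyclipper.ET_CLOSEDPOLYGON)
    return offset.Execute(distance)  # list of expanded integer polygons

print(unclip([(0, 0), (100, 0), (100, 40), (0, 40)]))
```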
diff --git a/deploy/android_demo/app/src/main/cpp/ocr_cls_process.cpp b/deploy/android_demo/app/src/main/cpp/ocr_cls_process.cpp
index e7de9b0b1c..141b5157a4 100644
--- a/deploy/android_demo/app/src/main/cpp/ocr_cls_process.cpp
+++ b/deploy/android_demo/app/src/main/cpp/ocr_cls_process.cpp
@@ -17,7 +17,6 @@
#include
#include
#include
-#include
#include
const std::vector<int> CLS_IMAGE_SHAPE = {3, 48, 192};
diff --git a/deploy/android_demo/app/src/main/cpp/ocr_crnn_process.cpp b/deploy/android_demo/app/src/main/cpp/ocr_crnn_process.cpp
index 44c34a2800..7e61a33e07 100644
--- a/deploy/android_demo/app/src/main/cpp/ocr_crnn_process.cpp
+++ b/deploy/android_demo/app/src/main/cpp/ocr_crnn_process.cpp
@@ -17,7 +17,6 @@
#include
#include
#include
-#include
#include
const std::string CHARACTER_TYPE = "ch";
diff --git a/deploy/android_demo/app/src/main/cpp/ocr_ppredictor.cpp b/deploy/android_demo/app/src/main/cpp/ocr_ppredictor.cpp
index 1bd989c9da..277ec80f1a 100644
--- a/deploy/android_demo/app/src/main/cpp/ocr_ppredictor.cpp
+++ b/deploy/android_demo/app/src/main/cpp/ocr_ppredictor.cpp
@@ -16,16 +16,16 @@ OCR_PPredictor::OCR_PPredictor(const OCR_Config &config) : _config(config) {}
int OCR_PPredictor::init(const std::string &det_model_content,
const std::string &rec_model_content,
const std::string &cls_model_content) {
- _det_predictor = std::unique_ptr<PPredictor>(
- new PPredictor{_config.use_opencl,_config.thread_num, NET_OCR, _config.mode});
+ _det_predictor = std::unique_ptr<PPredictor>(new PPredictor{
+ _config.use_opencl, _config.thread_num, NET_OCR, _config.mode});
 _det_predictor->init_nb(det_model_content);
- _rec_predictor = std::unique_ptr<PPredictor>(
- new PPredictor{_config.use_opencl,_config.thread_num, NET_OCR_INTERNAL, _config.mode});
+ _rec_predictor = std::unique_ptr<PPredictor>(new PPredictor{
+ _config.use_opencl, _config.thread_num, NET_OCR_INTERNAL, _config.mode});
 _rec_predictor->init_nb(rec_model_content);
- _cls_predictor = std::unique_ptr<PPredictor>(
- new PPredictor{_config.use_opencl,_config.thread_num, NET_OCR_INTERNAL, _config.mode});
+ _cls_predictor = std::unique_ptr<PPredictor>(new PPredictor{
+ _config.use_opencl, _config.thread_num, NET_OCR_INTERNAL, _config.mode});
 _cls_predictor->init_nb(cls_model_content);
return RETURN_OK;
}
@@ -33,17 +33,16 @@ int OCR_PPredictor::init(const std::string &det_model_content,
int OCR_PPredictor::init_from_file(const std::string &det_model_path,
const std::string &rec_model_path,
const std::string &cls_model_path) {
- _det_predictor = std::unique_ptr<PPredictor>(
- new PPredictor{_config.use_opencl, _config.thread_num, NET_OCR, _config.mode});
+ _det_predictor = std::unique_ptr<PPredictor>(new PPredictor{
+ _config.use_opencl, _config.thread_num, NET_OCR, _config.mode});
 _det_predictor->init_from_file(det_model_path);
-
- _rec_predictor = std::unique_ptr<PPredictor>(
- new PPredictor{_config.use_opencl,_config.thread_num, NET_OCR_INTERNAL, _config.mode});
+ _rec_predictor = std::unique_ptr<PPredictor>(new PPredictor{
+ _config.use_opencl, _config.thread_num, NET_OCR_INTERNAL, _config.mode});
 _rec_predictor->init_from_file(rec_model_path);
- _cls_predictor = std::unique_ptr<PPredictor>(
- new PPredictor{_config.use_opencl,_config.thread_num, NET_OCR_INTERNAL, _config.mode});
+ _cls_predictor = std::unique_ptr<PPredictor>(new PPredictor{
+ _config.use_opencl, _config.thread_num, NET_OCR_INTERNAL, _config.mode});
_cls_predictor->init_from_file(cls_model_path);
return RETURN_OK;
}
@@ -78,22 +77,23 @@ visual_img(const std::vector<std::vector<std::vector<cv::Point>>> &filter_boxes,
}
std::vector<OCRPredictResult>
-OCR_PPredictor::infer_ocr(cv::Mat &origin,int max_size_len, int run_det, int run_cls, int run_rec) {
+OCR_PPredictor::infer_ocr(cv::Mat &origin, int max_size_len, int run_det,
+ int run_cls, int run_rec) {
LOGI("ocr cpp start *****************");
LOGI("ocr cpp det: %d, cls: %d, rec: %d", run_det, run_cls, run_rec);
std::vector<OCRPredictResult> ocr_results;
- if(run_det){
+ if (run_det) {
infer_det(origin, max_size_len, ocr_results);
}
- if(run_rec){
- if(ocr_results.size()==0){
+ if (run_rec) {
+ if (ocr_results.size() == 0) {
OCRPredictResult res;
ocr_results.emplace_back(std::move(res));
}
- for(int i = 0; i < ocr_results.size();i++) {
+ for (int i = 0; i < ocr_results.size(); i++) {
infer_rec(origin, run_cls, ocr_results[i]);
}
- }else if(run_cls){
+ } else if (run_cls) {
ClsPredictResult cls_res = infer_cls(origin);
OCRPredictResult res;
res.cls_score = cls_res.cls_score;
@@ -144,7 +144,8 @@ cv::Mat DetResizeImg(const cv::Mat img, int max_size_len,
return resize_img;
}
-void OCR_PPredictor::infer_det(cv::Mat &origin, int max_size_len, std::vector<OCRPredictResult> &ocr_results) {
+void OCR_PPredictor::infer_det(cv::Mat &origin, int max_size_len,
+ std::vector<OCRPredictResult> &ocr_results) {
 std::vector<float> mean = {0.485f, 0.456f, 0.406f};
 std::vector<float> scale = {1 / 0.229f, 1 / 0.224f, 1 / 0.225f};
@@ -160,22 +161,27 @@ void OCR_PPredictor::infer_det(cv::Mat &origin, int max_size_len, std::vector<OCRPredictResult> &ocr_results) {
 std::vector<PredictorOutput> results = _det_predictor->infer();
PredictorOutput &res = results.at(0);
- std::vector<std::vector<std::vector<cv::Point>>> filtered_box = calc_filtered_boxes(
- res.get_float_data(), res.get_size(), input_image.rows, input_image.cols, origin);
+ std::vector<std::vector<std::vector<cv::Point>>> filtered_box =
+ calc_filtered_boxes(res.get_float_data(), res.get_size(),
+ input_image.rows, input_image.cols, origin);
LOGI("ocr cpp det Filter_box size %ld", filtered_box.size());
- for (int i = 0; i < filtered_box.size(); i++) {
 std::vector<float> mean = {0.5f, 0.5f, 0.5f};
 std::vector<float> scale = {1 / 0.5f, 1 / 0.5f, 1 / 0.5f};
 std::vector<int64_t> dims = {1, 3, 0, 0};
@@ -184,21 +190,19 @@ void OCR_PPredictor::infer_rec(const cv::Mat &origin_img, int run_cls, OCRPredictResult &ocr_result) {
const std::vector<std::vector<int>> &box = ocr_result.points;
cv::Mat crop_img;
- if(box.size()>0){
+ if (box.size() > 0) {
crop_img = get_rotate_crop_image(origin_img, box);
- }
- else{
+ } else {
crop_img = origin_img;
}
- if(run_cls){
+ if (run_cls) {
ClsPredictResult cls_res = infer_cls(crop_img);
crop_img = cls_res.img;
ocr_result.cls_score = cls_res.cls_score;
ocr_result.cls_label = cls_res.cls_label;
}
-
float wh_ratio = float(crop_img.cols) / float(crop_img.rows);
cv::Mat input_image = crnn_resize_img(crop_img, wh_ratio);
input_image.convertTo(input_image, CV_32FC3, 1 / 255.0f);
@@ -347,4 +351,4 @@ float OCR_PPredictor::postprocess_rec_score(const PredictorOutput &res) {
}
NET_TYPE OCR_PPredictor::get_net_flag() const { return NET_OCR; }
-}
\ No newline at end of file
+} // namespace ppredictor
\ No newline at end of file
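Editor's note: the `infer_ocr` reformatting above makes the pipeline's branching easier to follow: detection fills the result list, recognition runs per box (optionally after the direction classifier), and cls-only mode classifies the whole frame. A Python control-flow sketch of that dispatch; the three callables and the crop helper are stand-ins for the C++ members, not a real API:

```python
def crop_box(image, points):
    # stand-in for get_rotate_crop_image; a real version rectifies the quad
    return image

def infer_ocr(image, run_det, run_cls, run_rec, det, cls, rec):
    results = []
    if run_det:
        results = det(image)  # one dict per detected text box
    if run_rec:
        if not results:
            results = [{"points": []}]  # treat the whole image as one region
        for r in results:
            crop = crop_box(image, r["points"]) if r["points"] else image
            if run_cls:
                crop, r["cls_label"], r["cls_score"] = cls(crop)
            r["text"], r["score"] = rec(crop)
    elif run_cls:
        _, label, score = cls(image)
        results.append({"cls_label": label, "cls_score": score})
    return results
```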
diff --git a/deploy/android_demo/app/src/main/cpp/ocr_ppredictor.h b/deploy/android_demo/app/src/main/cpp/ocr_ppredictor.h
index f0bff93f1f..7bd0e7a4bb 100644
--- a/deploy/android_demo/app/src/main/cpp/ocr_ppredictor.h
+++ b/deploy/android_demo/app/src/main/cpp/ocr_ppredictor.h
@@ -15,8 +15,8 @@ namespace ppredictor {
* Config
*/
struct OCR_Config {
- int use_opencl = 0;
- int thread_num = 4; // Thread num
+ int use_opencl = 0;
+ int thread_num = 4; // Thread num
paddle::lite_api::PowerMode mode =
paddle::lite_api::LITE_POWER_HIGH; // PaddleLite Mode
};
@@ -29,13 +29,13 @@ struct OCRPredictResult {
std::vector<std::vector<int>> points;
float score;
float cls_score;
- int cls_label=-1;
+ int cls_label = -1;
};
struct ClsPredictResult {
- float cls_score;
- int cls_label=-1;
- cv::Mat img;
+ float cls_score;
+ int cls_label = -1;
+ cv::Mat img;
};
/**
* OCR there are 2 models
@@ -69,8 +69,9 @@ class OCR_PPredictor : public PPredictor_Interface {
* @param origin
* @return
*/
- virtual std::vector<OCRPredictResult>
- infer_ocr(cv::Mat &origin, int max_size_len, int run_det, int run_cls, int run_rec);
+ virtual std::vector<OCRPredictResult> infer_ocr(cv::Mat &origin,
+ int max_size_len, int run_det,
+ int run_cls, int run_rec);
virtual NET_TYPE get_net_flag() const;
@@ -87,8 +88,8 @@ class OCR_PPredictor : public PPredictor_Interface {
calc_filtered_boxes(const float *pred, int pred_size, int output_height,
int output_width, const cv::Mat &origin);
- void
- infer_det(cv::Mat &origin, int max_side_len, std::vector<OCRPredictResult>& ocr_results);
+ void infer_det(cv::Mat &origin, int max_side_len,
+ std::vector<OCRPredictResult> &ocr_results);
/**
* infer for rec model
*
@@ -96,16 +97,16 @@ class OCR_PPredictor : public PPredictor_Interface {
* @param origin
* @return
*/
- void
- infer_rec(const cv::Mat &origin, int run_cls, OCRPredictResult& ocr_result);
+ void infer_rec(const cv::Mat &origin, int run_cls,
+ OCRPredictResult &ocr_result);
- /**
- * infer for cls model
- *
- * @param boxes
- * @param origin
- * @return
- */
+ /**
+ * infer for cls model
+ *
+ * @param boxes
+ * @param origin
+ * @return
+ */
ClsPredictResult infer_cls(const cv::Mat &origin, float thresh = 0.9);
/**
@@ -127,4 +128,4 @@ class OCR_PPredictor : public PPredictor_Interface {
std::unique_ptr<PPredictor> _cls_predictor;
OCR_Config _config;
};
-}
+} // namespace ppredictor
diff --git a/deploy/android_demo/app/src/main/cpp/ppredictor.cpp b/deploy/android_demo/app/src/main/cpp/ppredictor.cpp
index a40fe5e1b2..a4725017bf 100644
--- a/deploy/android_demo/app/src/main/cpp/ppredictor.cpp
+++ b/deploy/android_demo/app/src/main/cpp/ppredictor.cpp
@@ -4,7 +4,8 @@
namespace ppredictor {
PPredictor::PPredictor(int use_opencl, int thread_num, int net_flag,
paddle::lite_api::PowerMode mode)
- : _use_opencl(use_opencl), _thread_num(thread_num), _net_flag(net_flag), _mode(mode) {}
+ : _use_opencl(use_opencl), _thread_num(thread_num), _net_flag(net_flag),
+ _mode(mode) {}
int PPredictor::init_nb(const std::string &model_content) {
paddle::lite_api::MobileConfig config;
@@ -19,7 +20,8 @@ int PPredictor::init_from_file(const std::string &model_content) {
}
template <typename ConfigT> int PPredictor::_init(ConfigT &config) {
- bool is_opencl_backend_valid = paddle::lite_api::IsOpenCLBackendValid(/*check_fp16_valid = false*/);
+ bool is_opencl_backend_valid =
+ paddle::lite_api::IsOpenCLBackendValid(/*check_fp16_valid = false*/);
if (is_opencl_backend_valid) {
if (_use_opencl != 0) {
// Make sure you have write permission of the binary path.
@@ -35,7 +37,8 @@ template <typename ConfigT> int PPredictor::_init(ConfigT &config) {
// CL_TUNE_EXHAUSTIVE: 3
const std::string tuned_path = "/data/local/tmp/";
const std::string tuned_name = "lite_opencl_tuned.bin";
- config.set_opencl_tune(paddle::lite_api::CL_TUNE_NORMAL, tuned_path, tuned_name);
+ config.set_opencl_tune(paddle::lite_api::CL_TUNE_NORMAL, tuned_path,
+ tuned_name);
// opencl precision option
// CL_PRECISION_AUTO: 0, first fp16 if valid, default
@@ -84,7 +87,8 @@ std::vector<PredictorOutput> PPredictor::infer() {
for (int i = 0; i < _predictor->GetOutputNames().size(); i++) {
std::unique_ptr<const paddle::lite_api::Tensor> output_tensor =
_predictor->GetOutput(i);
- LOGI("ocr cpp output tensor[%d] size %ld", i, product(output_tensor->shape()));
+ LOGI("ocr cpp output tensor[%d] size %ld", i,
+ product(output_tensor->shape()));
PredictorOutput result{std::move(output_tensor), i, _net_flag};
results.emplace_back(std::move(result));
}
@@ -92,4 +96,4 @@ std::vector<PredictorOutput> PPredictor::infer() {
}
NET_TYPE PPredictor::get_net_flag() const { return (NET_TYPE)_net_flag; }
-}
\ No newline at end of file
+} // namespace ppredictor
\ No newline at end of file
diff --git a/deploy/android_demo/app/src/main/cpp/ppredictor.h b/deploy/android_demo/app/src/main/cpp/ppredictor.h
index 40250764f8..230a84df43 100644
--- a/deploy/android_demo/app/src/main/cpp/ppredictor.h
+++ b/deploy/android_demo/app/src/main/cpp/ppredictor.h
@@ -22,7 +22,7 @@ class PPredictor_Interface {
class PPredictor : public PPredictor_Interface {
public:
PPredictor(
- int use_opencl, int thread_num, int net_flag = 0,
+ int use_opencl, int thread_num, int net_flag = 0,
paddle::lite_api::PowerMode mode = paddle::lite_api::LITE_POWER_HIGH);
virtual ~PPredictor() {}
@@ -54,11 +54,11 @@ class PPredictor : public PPredictor_Interface {
template <typename ConfigT> int _init(ConfigT &config);
private:
- int _use_opencl;
+ int _use_opencl;
int _thread_num;
paddle::lite_api::PowerMode _mode;
std::shared_ptr<paddle::lite_api::PaddlePredictor> _predictor;
bool _is_input_get = false;
int _net_flag;
};
-}
+} // namespace ppredictor
diff --git a/deploy/android_demo/app/src/main/cpp/predictor_input.cpp b/deploy/android_demo/app/src/main/cpp/predictor_input.cpp
index f0b4bf8a82..035be764da 100644
--- a/deploy/android_demo/app/src/main/cpp/predictor_input.cpp
+++ b/deploy/android_demo/app/src/main/cpp/predictor_input.cpp
@@ -25,4 +25,4 @@ void PredictorInput::set_data(const float *input_data, int input_float_len) {
float *input_raw_data = get_mutable_float_data();
memcpy(input_raw_data, input_data, input_float_len * sizeof(float));
}
-}
\ No newline at end of file
+} // namespace ppredictor
\ No newline at end of file
diff --git a/deploy/android_demo/app/src/main/cpp/predictor_input.h b/deploy/android_demo/app/src/main/cpp/predictor_input.h
index f3fd6cfe47..185f08144b 100644
--- a/deploy/android_demo/app/src/main/cpp/predictor_input.h
+++ b/deploy/android_demo/app/src/main/cpp/predictor_input.h
@@ -23,4 +23,4 @@ class PredictorInput {
int _index;
int _net_flag;
};
-}
+} // namespace ppredictor
diff --git a/deploy/android_demo/app/src/main/cpp/predictor_output.cpp b/deploy/android_demo/app/src/main/cpp/predictor_output.cpp
index e9cfdbc319..43ef68931c 100644
--- a/deploy/android_demo/app/src/main/cpp/predictor_output.cpp
+++ b/deploy/android_demo/app/src/main/cpp/predictor_output.cpp
@@ -23,4 +23,4 @@ int64_t PredictorOutput::get_size() const {
const std::vector<int64_t> PredictorOutput::get_shape() const {
return _tensor->shape();
}
-}
\ No newline at end of file
+} // namespace ppredictor
\ No newline at end of file
diff --git a/deploy/android_demo/app/src/main/cpp/predictor_output.h b/deploy/android_demo/app/src/main/cpp/predictor_output.h
index 8e8c9ba014..ce00af7ce7 100644
--- a/deploy/android_demo/app/src/main/cpp/predictor_output.h
+++ b/deploy/android_demo/app/src/main/cpp/predictor_output.h
@@ -28,4 +28,4 @@ class PredictorOutput {
int _index;
int _net_flag;
};
-}
+} // namespace ppredictor
diff --git a/deploy/avh/README.md b/deploy/avh/README.md
index 0087103146..b40933ecfd 100644
--- a/deploy/avh/README.md
+++ b/deploy/avh/README.md
@@ -31,7 +31,7 @@ You can refer to this [guide](https://arm-software.github.io/AVH/main/examples/h
Case 2: If the demo is run in the [ci_cpu Docker container](https://github.com/apache/tvm/blob/main/docker/Dockerfile.ci_cpu) provided with [TVM](https://github.com/apache/tvm), then the following software will already be installed.
Case 3: If the demo is not run in the ci_cpu Docker container, then you will need the following:
-- Software required to build and run the demo (These can all be installed by running
+- Software required to build and run the demo (These can all be installed by running
tvm/docker/install/ubuntu_install_ethosu_driver_stack.sh.)
- [Fixed Virtual Platform (FVP) based on Arm(R) Corstone(TM)-300 software](https://developer.arm.com/tools-and-software/open-source-software/arm-platforms-software/arm-ecosystem-fvps)
- [cmake 3.19.5](https://github.com/Kitware/CMake/releases/)
@@ -45,7 +45,7 @@ Case 3: If the demo is not run in the ci_cpu Docker container, then you will nee
```
In case2 and case3:
-
+
You will need to update your PATH environment variable to include the path to cmake 3.19.5 and the FVP.
For example if you've installed these in ```/opt/arm``` , then you would do the following:
```bash
@@ -112,5 +112,3 @@ PP-OCRv3 is the third version of the PP-OCR series model. This series of models
- PP-OCRv3: ultra-lightweight OCR system: detection (3.6M) + direction classifier (1.4M) + recognition (12M) = 17.0M
- Support more than 80 kinds of multi-language recognition models, including English, Chinese, French, German, Arabic, Korean, Japanese and so on. For details
- Support vertical text recognition, and long text recognition
-
-
diff --git a/deploy/avh/README_ch.md b/deploy/avh/README_ch.md
index 35cc9f2b7c..0d05887b1a 100644
--- a/deploy/avh/README_ch.md
+++ b/deploy/avh/README_ch.md
@@ -26,7 +26,7 @@
本demo运行在TVM提供的docker环境上,在该环境中已经安装好的必须的软件
-在非docker环境中,需要手动安装如下依赖项:
+在非docker环境中,需要手动安装如下依赖项:
- 软件可通过[安装脚本](https://github.com/apache/tvm/blob/main/docker/install/ubuntu_install_ethosu_driver_stack.sh)一键安装
- [Fixed Virtual Platform (FVP) based on Arm(R) Corstone(TM)-300 software](https://developer.arm.com/tools-and-software/open-source-software/arm-platforms-software/arm-ecosystem-fvps)
@@ -90,5 +90,5 @@ export PATH=/opt/arm/FVP_Corstone_SSE-300/models/Linux64_GCC-6.4:/opt/arm/cmake/
PP-OCRv3是[PaddleOCR](https://github.com/PaddlePaddle/PaddleOCR)发布的PP-OCR系列模型的第三个版本,该系列模型具有以下特点:
- 超轻量级OCR系统:检测(3.6M)+方向分类器(1.4M)+识别(12M)=17.0M。
- - 支持80多种多语言识别模型,包括英文、中文、法文、德文、阿拉伯文、韩文、日文等。
+ - 支持80多种多语言识别模型,包括英文、中文、法文、德文、阿拉伯文、韩文、日文等。
- 支持竖排文本识别,长文本识别。
diff --git a/deploy/avh/convert_image.py b/deploy/avh/convert_image.py
index 7c6dbd7fd8..7a78faccbe 100755
--- a/deploy/avh/convert_image.py
+++ b/deploy/avh/convert_image.py
@@ -30,8 +30,7 @@ def resize_norm_img(img, image_shape, padding=True):
h = img.shape[0]
w = img.shape[1]
if not padding:
- resized_image = cv2.resize(
- img, (imgW, imgH), interpolation=cv2.INTER_LINEAR)
+ resized_image = cv2.resize(img, (imgW, imgH), interpolation=cv2.INTER_LINEAR)
resized_w = imgW
else:
ratio = w / float(h)
@@ -40,7 +39,7 @@ def resize_norm_img(img, image_shape, padding=True):
else:
resized_w = int(math.ceil(imgH * ratio))
resized_image = cv2.resize(img, (resized_w, imgH))
- resized_image = resized_image.astype('float32')
+ resized_image = resized_image.astype("float32")
if image_shape[0] == 1:
resized_image = resized_image / 255
resized_image = resized_image[np.newaxis, :]
@@ -62,8 +61,9 @@ def create_header_file(name, tensor_name, tensor_data, output_path):
raw_path = file_path.with_suffix(".h").resolve()
with open(raw_path, "w") as header_file:
header_file.write(
- "\n" + f"const size_t {tensor_name}_len = {tensor_data.size};\n" +
- f'__attribute__((section(".data.tvm"), aligned(16))) float {tensor_name}[] = '
+ "\n"
+ + f"const size_t {tensor_name}_len = {tensor_data.size};\n"
+ + f'__attribute__((section(".data.tvm"), aligned(16))) float {tensor_name}[] = '
)
header_file.write("{")
@@ -94,7 +94,8 @@ def create_headers(image_name):
"outputs",
"output",
output_data,
- "./include", )
+ "./include",
+ )
if __name__ == "__main__":
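Editor's note: `resize_norm_img` above keeps the aspect ratio when `padding=True`: height is fixed to `imgH`, width is capped at `imgW`, and the remainder is zero-padded. A compact sketch of that behaviour (cv2/numpy assumed, three-channel input):

```python
import math
import cv2
import numpy as np

def keep_ratio_resize(img, imgH=32, imgW=320):
    h, w = img.shape[:2]
    ratio = w / float(h)
    resized_w = imgW if math.ceil(imgH * ratio) > imgW else int(math.ceil(imgH * ratio))
    resized = cv2.resize(img, (resized_w, imgH)).astype("float32")
    padded = np.zeros((imgH, imgW, img.shape[2]), dtype="float32")
    padded[:, :resized_w, :] = resized  # zero-pad the right-hand columns
    return padded

print(keep_ratio_resize(np.zeros((64, 200, 3), dtype="uint8")).shape)  # (32, 320, 3)
```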
diff --git a/deploy/avh/include/crt_config.h b/deploy/avh/include/crt_config.h
index 4b9ccca02b..2fd0ead606 100644
--- a/deploy/avh/include/crt_config.h
+++ b/deploy/avh/include/crt_config.h
@@ -23,4 +23,4 @@
/*! Log level of the CRT runtime */
#define TVM_CRT_LOG_LEVEL TVM_CRT_LOG_LEVEL_DEBUG
-#endif // TVM_RUNTIME_CRT_CONFIG_H_
+#endif // TVM_RUNTIME_CRT_CONFIG_H_
diff --git a/deploy/avh/include/tvm_runtime.h b/deploy/avh/include/tvm_runtime.h
index 2b59d93470..0978d7adfa 100644
--- a/deploy/avh/include/tvm_runtime.h
+++ b/deploy/avh/include/tvm_runtime.h
@@ -33,22 +33,26 @@ void __attribute__((noreturn)) TVMPlatformAbort(tvm_crt_error_t error_code) {
exit(-1);
}
-tvm_crt_error_t TVMPlatformMemoryAllocate(size_t num_bytes, DLDevice dev, void** out_ptr) {
+tvm_crt_error_t TVMPlatformMemoryAllocate(size_t num_bytes, DLDevice dev,
+ void **out_ptr) {
return kTvmErrorFunctionCallNotImplemented;
}
-tvm_crt_error_t TVMPlatformMemoryFree(void* ptr, DLDevice dev) {
+tvm_crt_error_t TVMPlatformMemoryFree(void *ptr, DLDevice dev) {
return kTvmErrorFunctionCallNotImplemented;
}
-void TVMLogf(const char* msg, ...) {
+void TVMLogf(const char *msg, ...) {
va_list args;
va_start(args, msg);
vfprintf(stdout, msg, args);
va_end(args);
}
-TVM_DLL int TVMFuncRegisterGlobal(const char* name, TVMFunctionHandle f, int override) { return 0; }
+TVM_DLL int TVMFuncRegisterGlobal(const char *name, TVMFunctionHandle f,
+ int override) {
+ return 0;
+}
#ifdef __cplusplus
}
diff --git a/deploy/avh/src/demo_bare_metal.c b/deploy/avh/src/demo_bare_metal.c
index 3f5f1bc4b0..c90e17fa95 100644
--- a/deploy/avh/src/demo_bare_metal.c
+++ b/deploy/avh/src/demo_bare_metal.c
@@ -27,9 +27,9 @@
#include "inputs.h"
#include "outputs.h"
-
-int main(int argc, char** argv) {
- char dict[]={"#0123456789:;<=>?@ABCDEFGHIJKLMNOPQRSTUVWXYZ[\\]^_`abcdefghijklmnopqrstuvwxyz{|}~!\"#$%&'()*+,-./ "};
+int main(int argc, char **argv) {
+ char dict[] = {"#0123456789:;<=>?@ABCDEFGHIJKLMNOPQRSTUVWXYZ[\\]^_`"
+ "abcdefghijklmnopqrstuvwxyz{|}~!\"#$%&'()*+,-./ "};
int char_dict_nums = 97;
uart_init();
printf("Starting ocr rec inference\n");
@@ -44,17 +44,17 @@ int main(int argc, char** argv) {
// post process
int char_nums = output_len / char_dict_nums;
-
+
int last_index = 0;
float score = 0.f;
int count = 0;
-
+
printf("text: ");
for (int i = 0; i < char_nums; i++) {
int argmax_idx = 0;
float max_value = 0.0f;
- for (int j = 0; j < char_dict_nums; j++){
- if (output[i * char_dict_nums + j] > max_value){
+ for (int j = 0; j < char_dict_nums; j++) {
+ if (output[i * char_dict_nums + j] > max_value) {
max_value = output[i * char_dict_nums + j];
argmax_idx = j;
}
@@ -69,7 +69,7 @@ int main(int argc, char** argv) {
}
score /= count;
printf(", score: %f\n", score);
-
+
// The FVP will shut down when it receives "EXITTHESIM" on the UART
printf("EXITTHESIM\n");
while (1 == 1)
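Editor's note: the decoding loop above is greedy CTC: per-timestep argmax over the 97-entry dictionary, with index 0 acting as the blank that also collapses repeats, and the kept probabilities averaged into a confidence score. The same logic in a few lines of Python, with numpy standing in for the C arrays:

```python
import numpy as np

def ctc_greedy_decode(output, charset, blank=0):
    text, score, count, last_index = [], 0.0, 0, blank
    for probs in output:  # one probability row per timestep
        idx = int(np.argmax(probs))
        if idx != blank and idx != last_index:  # skip blanks and repeats
            text.append(charset[idx])
            score += float(probs[idx])
            count += 1
        last_index = idx
    return "".join(text), score / max(count, 1)

charset = "#01"  # '#' plays the blank role, as in the demo's dict[]
logits = np.array([[0.1, 0.8, 0.1], [0.1, 0.7, 0.2], [0.9, 0.05, 0.05]])
print(ctc_greedy_decode(logits, charset))  # ('1', 0.8): repeat collapsed
```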
diff --git a/deploy/cpp_infer/include/clipper.h b/deploy/cpp_infer/include/clipper.h
index 522f81c8c4..d19e95ca2c 100644
--- a/deploy/cpp_infer/include/clipper.h
+++ b/deploy/cpp_infer/include/clipper.h
@@ -1,35 +1,31 @@
/*******************************************************************************
-* *
-* Author : Angus Johnson *
-* Version : 6.4.2 *
-* Date : 27 February 2017 *
-* Website : http://www.angusj.com *
-* Copyright : Angus Johnson 2010-2017 *
-* *
-* License: *
-* Use, modification & distribution is subject to Boost Software License Ver 1. *
-* http://www.boost.org/LICENSE_1_0.txt *
-* *
-* Attributions: *
-* The code in this library is an extension of Bala Vatti's clipping algorithm: *
-* "A generic solution to polygon clipping" *
-* Communications of the ACM, Vol 35, Issue 7 (July 1992) pp 56-63. *
-* http://portal.acm.org/citation.cfm?id=129906 *
-* *
-* Computer graphics and geometric modeling: implementation and algorithms *
-* By Max K. Agoston *
-* Springer; 1 edition (January 4, 2005) *
-* http://books.google.com/books?q=vatti+clipping+agoston *
-* *
-* See also: *
-* "Polygon Offsetting by Computing Winding Numbers" *
-* Paper no. DETC2005-85513 pp. 565-575 *
-* ASME 2005 International Design Engineering Technical Conferences *
-* and Computers and Information in Engineering Conference (IDETC/CIE2005) *
-* September 24-28, 2005 , Long Beach, California, USA *
-* http://www.me.berkeley.edu/~mcmains/pubs/DAC05OffsetPolygon.pdf *
-* *
-*******************************************************************************/
+ * Author    : Angus Johnson
+ * Version   : 6.4.2
+ * Date      : 27 February 2017
+ * Website   : http://www.angusj.com
+ * Copyright : Angus Johnson 2010-2017
+ *
+ * License:
+ * Use, modification & distribution is subject to Boost Software License Ver 1.
+ * http://www.boost.org/LICENSE_1_0.txt
+ *
+ * Attributions:
+ * The code in this library is an extension of Bala Vatti's clipping algorithm:
+ * "A generic solution to polygon clipping"
+ * Communications of the ACM, Vol 35, Issue 7 (July 1992) pp 56-63.
+ * http://portal.acm.org/citation.cfm?id=129906
+ *
+ * Computer graphics and geometric modeling: implementation and algorithms
+ * By Max K. Agoston
+ * Springer; 1 edition (January 4, 2005)
+ * http://books.google.com/books?q=vatti+clipping+agoston
+ *
+ * See also:
+ * "Polygon Offsetting by Computing Winding Numbers"
+ * Paper no. DETC2005-85513 pp. 565-575
+ * ASME 2005 International Design Engineering Technical Conferences
+ * and Computers and Information in Engineering Conference (IDETC/CIE2005)
+ * September 24-28, 2005 , Long Beach, California, USA
+ * http://www.me.berkeley.edu/~mcmains/pubs/DAC05OffsetPolygon.pdf
+ *
+ *******************************************************************************/
#pragma once
@@ -420,6 +416,6 @@ class clipperException : public std::exception {
};
//------------------------------------------------------------------------------
-} // ClipperLib namespace
+} // namespace ClipperLib
#endif // clipper_hpp
diff --git a/deploy/cpp_infer/src/clipper.cpp b/deploy/cpp_infer/src/clipper.cpp
index 5f5d221676..75f21c9422 100644
--- a/deploy/cpp_infer/src/clipper.cpp
+++ b/deploy/cpp_infer/src/clipper.cpp
@@ -1,42 +1,38 @@
/*******************************************************************************
-* *
-* Author : Angus Johnson *
-* Version : 6.4.2 *
-* Date : 27 February 2017 *
-* Website : http://www.angusj.com *
-* Copyright : Angus Johnson 2010-2017 *
-* *
-* License: *
-* Use, modification & distribution is subject to Boost Software License Ver 1. *
-* http://www.boost.org/LICENSE_1_0.txt *
-* *
-* Attributions: *
-* The code in this library is an extension of Bala Vatti's clipping algorithm: *
-* "A generic solution to polygon clipping" *
-* Communications of the ACM, Vol 35, Issue 7 (July 1992) pp 56-63. *
-* http://portal.acm.org/citation.cfm?id=129906 *
-* *
-* Computer graphics and geometric modeling: implementation and algorithms *
-* By Max K. Agoston *
-* Springer; 1 edition (January 4, 2005) *
-* http://books.google.com/books?q=vatti+clipping+agoston *
-* *
-* See also: *
-* "Polygon Offsetting by Computing Winding Numbers" *
-* Paper no. DETC2005-85513 pp. 565-575 *
-* ASME 2005 International Design Engineering Technical Conferences *
-* and Computers and Information in Engineering Conference (IDETC/CIE2005) *
-* September 24-28, 2005 , Long Beach, California, USA *
-* http://www.me.berkeley.edu/~mcmains/pubs/DAC05OffsetPolygon.pdf *
-* *
-*******************************************************************************/
+ * Author    : Angus Johnson
+ * Version   : 6.4.2
+ * Date      : 27 February 2017
+ * Website   : http://www.angusj.com
+ * Copyright : Angus Johnson 2010-2017
+ *
+ * License:
+ * Use, modification & distribution is subject to Boost Software License Ver 1.
+ * http://www.boost.org/LICENSE_1_0.txt
+ *
+ * Attributions:
+ * The code in this library is an extension of Bala Vatti's clipping algorithm:
+ * "A generic solution to polygon clipping"
+ * Communications of the ACM, Vol 35, Issue 7 (July 1992) pp 56-63.
+ * http://portal.acm.org/citation.cfm?id=129906
+ *
+ * Computer graphics and geometric modeling: implementation and algorithms
+ * By Max K. Agoston
+ * Springer; 1 edition (January 4, 2005)
+ * http://books.google.com/books?q=vatti+clipping+agoston
+ *
+ * See also:
+ * "Polygon Offsetting by Computing Winding Numbers"
+ * Paper no. DETC2005-85513 pp. 565-575
+ * ASME 2005 International Design Engineering Technical Conferences
+ * and Computers and Information in Engineering Conference (IDETC/CIE2005)
+ * September 24-28, 2005 , Long Beach, California, USA
+ * http://www.me.berkeley.edu/~mcmains/pubs/DAC05OffsetPolygon.pdf
+ *
+ *******************************************************************************/
/*******************************************************************************
-* *
-* This is a translation of the Delphi Clipper library and the naming style *
-* used has retained a Delphi flavour. *
-* *
-*******************************************************************************/
+ * *
+ * This is a translation of the Delphi Clipper library and the naming style *
+ * used has retained a Delphi flavour. *
+ * *
+ *******************************************************************************/
#include
#include
#include
@@ -1045,8 +1041,9 @@ bool ClipperBase::AddPath(const Path &pg, PolyType PolyTyp, bool Closed) {
}
if (E->Prev == E->Next)
break; // only two vertices
- else if (Closed && SlopesEqual(E->Prev->Curr, E->Curr, E->Next->Curr,
- m_UseFullRange) &&
+ else if (Closed &&
+ SlopesEqual(E->Prev->Curr, E->Curr, E->Next->Curr,
+ m_UseFullRange) &&
(!m_PreserveCollinear ||
!Pt2IsBetweenPt1AndPt3(E->Prev->Curr, E->Curr, E->Next->Curr))) {
// Collinear edges are allowed for open paths but in closed paths
@@ -2518,14 +2515,14 @@ void GetHorzDirection(TEdge &HorzEdge, Direction &Dir, cInt &Left,
//------------------------------------------------------------------------
/*******************************************************************************
-* Notes: Horizontal edges (HEs) at scanline intersections (ie at the Top or *
-* Bottom of a scanbeam) are processed as if layered. The order in which HEs *
-* are processed doesn't matter. HEs intersect with other HE Bot.Xs only [#] *
-* (or they could intersect with Top.Xs only, ie EITHER Bot.Xs OR Top.Xs), *
-* and with other non-horizontal edges [*]. Once these intersections are *
-* processed, intermediate HEs then 'promote' the Edge above (NextInLML) into *
-* the AEL. These 'promoted' edges may in turn intersect [%] with other HEs. *
-*******************************************************************************/
+ * Notes: Horizontal edges (HEs) at scanline intersections (ie at the Top or
+ * Bottom of a scanbeam) are processed as if layered. The order in which HEs
+ * are processed doesn't matter. HEs intersect with other HE Bot.Xs only [#]
+ * (or they could intersect with Top.Xs only, ie EITHER Bot.Xs OR Top.Xs),
+ * and with other non-horizontal edges [*]. Once these intersections are
+ * processed, intermediate HEs then 'promote' the Edge above (NextInLML) into
+ * the AEL. These 'promoted' edges may in turn intersect [%] with other HEs.
+ *******************************************************************************/
void Clipper::ProcessHorizontal(TEdge *horzEdge) {
Direction dir;
@@ -4377,4 +4374,4 @@ std::ostream &operator<<(std::ostream &s, const Paths &p) {
}
//------------------------------------------------------------------------------
-} // ClipperLib namespace
+} // namespace ClipperLib
diff --git a/deploy/cpp_infer/src/ocr_cls.cpp b/deploy/cpp_infer/src/ocr_cls.cpp
index 6f2b5509e6..9130c18073 100644
--- a/deploy/cpp_infer/src/ocr_cls.cpp
+++ b/deploy/cpp_infer/src/ocr_cls.cpp
@@ -20,9 +20,12 @@ void Classifier::Run(std::vector<cv::Mat> img_list,
 std::vector<int> &cls_labels,
 std::vector<float> &cls_scores,
 std::vector<double> &times) {
- std::chrono::duration<float> preprocess_diff = std::chrono::duration<float>::zero();
- std::chrono::duration<float> inference_diff = std::chrono::duration<float>::zero();
- std::chrono::duration<float> postprocess_diff = std::chrono::duration<float>::zero();
+ std::chrono::duration<float> preprocess_diff =
+ std::chrono::duration<float>::zero();
+ std::chrono::duration<float> inference_diff =
+ std::chrono::duration<float>::zero();
+ std::chrono::duration<float> postprocess_diff =
+ std::chrono::duration<float>::zero();
int img_num = img_list.size();
std::vector<int> cls_image_shape = {3, 48, 192};
diff --git a/deploy/cpp_infer/src/ocr_rec.cpp b/deploy/cpp_infer/src/ocr_rec.cpp
index cf3e58d42a..5605736d1d 100644
--- a/deploy/cpp_infer/src/ocr_rec.cpp
+++ b/deploy/cpp_infer/src/ocr_rec.cpp
@@ -20,9 +20,12 @@ void CRNNRecognizer::Run(std::vector<cv::Mat> img_list,
 std::vector<std::string> &rec_texts,
 std::vector<float> &rec_text_scores,
 std::vector<double> &times) {
- std::chrono::duration<float> preprocess_diff = std::chrono::duration<float>::zero();
- std::chrono::duration<float> inference_diff = std::chrono::duration<float>::zero();
- std::chrono::duration<float> postprocess_diff = std::chrono::duration<float>::zero();
+ std::chrono::duration<float> preprocess_diff =
+ std::chrono::duration<float>::zero();
+ std::chrono::duration<float> inference_diff =
+ std::chrono::duration<float>::zero();
+ std::chrono::duration<float> postprocess_diff =
+ std::chrono::duration<float>::zero();
int img_num = img_list.size();
std::vector<float> width_list;
diff --git a/deploy/fastdeploy/ascend/python/infer.py b/deploy/fastdeploy/ascend/python/infer.py
index ceb28e0f7f..cdaa167f2f 100755
--- a/deploy/fastdeploy/ascend/python/infer.py
+++ b/deploy/fastdeploy/ascend/python/infer.py
@@ -20,28 +20,27 @@
def parse_arguments():
import argparse
import ast
+
parser = argparse.ArgumentParser()
parser.add_argument(
- "--det_model", required=True, help="Path of Detection model of PPOCR.")
+ "--det_model", required=True, help="Path of Detection model of PPOCR."
+ )
parser.add_argument(
- "--cls_model",
- required=True,
- help="Path of Classification model of PPOCR.")
+ "--cls_model", required=True, help="Path of Classification model of PPOCR."
+ )
parser.add_argument(
- "--rec_model",
- required=True,
- help="Path of Recognization model of PPOCR.")
+ "--rec_model", required=True, help="Path of Recognization model of PPOCR."
+ )
parser.add_argument(
- "--rec_label_file",
- required=True,
- help="Path of Recognization model of PPOCR.")
+ "--rec_label_file", required=True, help="Path of Recognization model of PPOCR."
+ )
parser.add_argument(
- "--image", type=str, required=True, help="Path of test image file.")
+ "--image", type=str, required=True, help="Path of test image file."
+ )
return parser.parse_args()
def build_option(args):
-
det_option = fd.RuntimeOption()
cls_option = fd.RuntimeOption()
rec_option = fd.RuntimeOption()
@@ -68,13 +67,16 @@ def build_option(args):
det_option, cls_option, rec_option = build_option(args)
det_model = fd.vision.ocr.DBDetector(
- det_model_file, det_params_file, runtime_option=det_option)
+ det_model_file, det_params_file, runtime_option=det_option
+)
cls_model = fd.vision.ocr.Classifier(
- cls_model_file, cls_params_file, runtime_option=cls_option)
+ cls_model_file, cls_params_file, runtime_option=cls_option
+)
rec_model = fd.vision.ocr.Recognizer(
- rec_model_file, rec_params_file, rec_label_file, runtime_option=rec_option)
+ rec_model_file, rec_params_file, rec_label_file, runtime_option=rec_option
+)
# Rec model enable static shape infer.
# When deploy on Ascend, it must be true.
@@ -83,7 +85,8 @@ def build_option(args):
# Create PP-OCRv3, if cls_model is not needed,
# just set cls_model=None .
ppocr_v3 = fd.vision.ocr.PPOCRv3(
- det_model=det_model, cls_model=cls_model, rec_model=rec_model)
+ det_model=det_model, cls_model=cls_model, rec_model=rec_model
+)
# The batch size must be set to 1, when enable static shape infer.
ppocr_v3.cls_batch_size = 1
diff --git a/deploy/fastdeploy/cpu-gpu/python/infer.py b/deploy/fastdeploy/cpu-gpu/python/infer.py
index 8eac845998..67f847e04b 100755
--- a/deploy/fastdeploy/cpu-gpu/python/infer.py
+++ b/deploy/fastdeploy/cpu-gpu/python/infer.py
@@ -20,55 +20,55 @@
def parse_arguments():
import argparse
import ast
+
parser = argparse.ArgumentParser()
parser.add_argument(
- "--det_model", required=True, help="Path of Detection model of PPOCR.")
+ "--det_model", required=True, help="Path of Detection model of PPOCR."
+ )
parser.add_argument(
- "--cls_model",
- required=True,
- help="Path of Classification model of PPOCR.")
+ "--cls_model", required=True, help="Path of Classification model of PPOCR."
+ )
parser.add_argument(
- "--rec_model",
- required=True,
- help="Path of Recognization model of PPOCR.")
+ "--rec_model", required=True, help="Path of Recognization model of PPOCR."
+ )
parser.add_argument(
- "--rec_label_file",
- required=True,
- help="Path of Recognization model of PPOCR.")
+ "--rec_label_file", required=True, help="Path of Recognization model of PPOCR."
+ )
parser.add_argument(
- "--image", type=str, required=True, help="Path of test image file.")
+ "--image", type=str, required=True, help="Path of test image file."
+ )
parser.add_argument(
"--device",
type=str,
- default='cpu',
- help="Type of inference device, support 'cpu' or 'gpu'.")
+ default="cpu",
+ help="Type of inference device, support 'cpu' or 'gpu'.",
+ )
parser.add_argument(
"--device_id",
type=int,
default=0,
- help="Define which GPU card used to run model.")
+ help="Define which GPU card used to run model.",
+ )
parser.add_argument(
"--cls_bs",
type=int,
default=1,
- help="Classification model inference batch size.")
+ help="Classification model inference batch size.",
+ )
parser.add_argument(
- "--rec_bs",
- type=int,
- default=6,
- help="Recognition model inference batch size")
+ "--rec_bs", type=int, default=6, help="Recognition model inference batch size"
+ )
parser.add_argument(
"--backend",
type=str,
default="default",
- help="Type of inference backend, support ort/trt/paddle/openvino, default 'openvino' for cpu, 'tensorrt' for gpu"
+ help="Type of inference backend, support ort/trt/paddle/openvino, default 'openvino' for cpu, 'tensorrt' for gpu",
)
return parser.parse_args()
def build_option(args):
-
det_option = fd.RuntimeOption()
cls_option = fd.RuntimeOption()
rec_option = fd.RuntimeOption()
@@ -79,8 +79,9 @@ def build_option(args):
rec_option.use_gpu(args.device_id)
if args.backend.lower() == "trt":
- assert args.device.lower(
- ) == "gpu", "TensorRT backend require inference on device GPU."
+ assert (
+ args.device.lower() == "gpu"
+ ), "TensorRT backend require inference on device GPU."
det_option.use_trt_backend()
cls_option.use_trt_backend()
rec_option.use_trt_backend()
@@ -88,14 +89,15 @@ def build_option(args):
# If use TRT backend, the dynamic shape will be set as follow.
# We recommend that users set the length and height of the detection model to a multiple of 32.
# We also recommend that users set the Trt input shape as follow.
- det_option.set_trt_input_shape("x", [1, 3, 64, 64], [1, 3, 640, 640],
- [1, 3, 960, 960])
- cls_option.set_trt_input_shape("x", [1, 3, 48, 10],
- [args.cls_bs, 3, 48, 320],
- [args.cls_bs, 3, 48, 1024])
- rec_option.set_trt_input_shape("x", [1, 3, 48, 10],
- [args.rec_bs, 3, 48, 320],
- [args.rec_bs, 3, 48, 2304])
+ det_option.set_trt_input_shape(
+ "x", [1, 3, 64, 64], [1, 3, 640, 640], [1, 3, 960, 960]
+ )
+ cls_option.set_trt_input_shape(
+ "x", [1, 3, 48, 10], [args.cls_bs, 3, 48, 320], [args.cls_bs, 3, 48, 1024]
+ )
+ rec_option.set_trt_input_shape(
+ "x", [1, 3, 48, 10], [args.rec_bs, 3, 48, 320], [args.rec_bs, 3, 48, 2304]
+ )
# Users could save TRT cache file to disk as follow.
det_option.set_trt_cache_file(args.det_model + "/det_trt_cache.trt")
@@ -103,8 +105,9 @@ def build_option(args):
rec_option.set_trt_cache_file(args.rec_model + "/rec_trt_cache.trt")
elif args.backend.lower() == "pptrt":
- assert args.device.lower(
- ) == "gpu", "Paddle-TensorRT backend require inference on device GPU."
+ assert (
+ args.device.lower() == "gpu"
+ ), "Paddle-TensorRT backend require inference on device GPU."
det_option.use_paddle_infer_backend()
det_option.paddle_infer_option.collect_trt_shape = True
det_option.paddle_infer_option.enable_trt = True
@@ -120,14 +123,15 @@ def build_option(args):
# If use TRT backend, the dynamic shape will be set as follow.
# We recommend that users set the length and height of the detection model to a multiple of 32.
# We also recommend that users set the Trt input shape as follow.
- det_option.set_trt_input_shape("x", [1, 3, 64, 64], [1, 3, 640, 640],
- [1, 3, 960, 960])
- cls_option.set_trt_input_shape("x", [1, 3, 48, 10],
- [args.cls_bs, 3, 48, 320],
- [args.cls_bs, 3, 48, 1024])
- rec_option.set_trt_input_shape("x", [1, 3, 48, 10],
- [args.rec_bs, 3, 48, 320],
- [args.rec_bs, 3, 48, 2304])
+ det_option.set_trt_input_shape(
+ "x", [1, 3, 64, 64], [1, 3, 640, 640], [1, 3, 960, 960]
+ )
+ cls_option.set_trt_input_shape(
+ "x", [1, 3, 48, 10], [args.cls_bs, 3, 48, 320], [args.cls_bs, 3, 48, 1024]
+ )
+ rec_option.set_trt_input_shape(
+ "x", [1, 3, 48, 10], [args.rec_bs, 3, 48, 320], [args.rec_bs, 3, 48, 2304]
+ )
# Users could save TRT cache file to disk as follow.
det_option.set_trt_cache_file(args.det_model)
@@ -145,15 +149,17 @@ def build_option(args):
rec_option.use_paddle_infer_backend()
elif args.backend.lower() == "openvino":
- assert args.device.lower(
- ) == "cpu", "OpenVINO backend require inference on device CPU."
+ assert (
+ args.device.lower() == "cpu"
+ ), "OpenVINO backend require inference on device CPU."
det_option.use_openvino_backend()
cls_option.use_openvino_backend()
rec_option.use_openvino_backend()
elif args.backend.lower() == "pplite":
- assert args.device.lower(
- ) == "cpu", "Paddle Lite backend require inference on device CPU."
+ assert (
+ args.device.lower() == "cpu"
+ ), "Paddle Lite backend require inference on device CPU."
det_option.use_lite_backend()
cls_option.use_lite_backend()
rec_option.use_lite_backend()
@@ -176,13 +182,16 @@ def build_option(args):
det_option, cls_option, rec_option = build_option(args)
det_model = fd.vision.ocr.DBDetector(
- det_model_file, det_params_file, runtime_option=det_option)
+ det_model_file, det_params_file, runtime_option=det_option
+)
cls_model = fd.vision.ocr.Classifier(
- cls_model_file, cls_params_file, runtime_option=cls_option)
+ cls_model_file, cls_params_file, runtime_option=cls_option
+)
rec_model = fd.vision.ocr.Recognizer(
- rec_model_file, rec_params_file, rec_label_file, runtime_option=rec_option)
+ rec_model_file, rec_params_file, rec_label_file, runtime_option=rec_option
+)
# Parameters settings for pre and post processing of Det/Cls/Rec Models.
# All parameters are set to default values.
@@ -196,7 +205,8 @@ def build_option(args):
# Create PP-OCRv3, if cls_model is not needed, just set cls_model=None .
ppocr_v3 = fd.vision.ocr.PPOCRv3(
- det_model=det_model, cls_model=cls_model, rec_model=rec_model)
+ det_model=det_model, cls_model=cls_model, rec_model=rec_model
+)
# Set inference batch size for cls model and rec model, the value could be -1 and 1 to positive infinity.
# When inference batch size is set to -1, it means that the inference batch size
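Editor's note: the TRT branches above pin each input to a [min, opt, max] shape range, and the comments ask for detector dims that are multiples of 32. A tiny illustrative helper (not part of FastDeploy) for snapping an arbitrary side length into that constraint before calling `set_trt_input_shape`:

```python
def snap_side(value, base=32, lo=64, hi=960):
    # round to the nearest multiple of `base`, clamped to the [lo, hi] range
    return max(lo, min(hi, int(round(value / base)) * base))

print(snap_side(700))   # 704
print(snap_side(1500))  # 960 (clamped to the max profile dimension)
```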
diff --git a/deploy/fastdeploy/cpu-gpu/python/infer_cls.py b/deploy/fastdeploy/cpu-gpu/python/infer_cls.py
index b34868daef..a52bd3c52c 100755
--- a/deploy/fastdeploy/cpu-gpu/python/infer_cls.py
+++ b/deploy/fastdeploy/cpu-gpu/python/infer_cls.py
@@ -20,28 +20,30 @@
def parse_arguments():
import argparse
import ast
+
parser = argparse.ArgumentParser()
parser.add_argument(
- "--cls_model",
- required=True,
- help="Path of Classification model of PPOCR.")
+ "--cls_model", required=True, help="Path of Classification model of PPOCR."
+ )
parser.add_argument(
- "--image", type=str, required=True, help="Path of test image file.")
+ "--image", type=str, required=True, help="Path of test image file."
+ )
parser.add_argument(
"--device",
type=str,
- default='cpu',
- help="Type of inference device, support 'cpu', 'kunlunxin' or 'gpu'.")
+ default="cpu",
+ help="Type of inference device, support 'cpu', 'kunlunxin' or 'gpu'.",
+ )
parser.add_argument(
"--device_id",
type=int,
default=0,
- help="Define which GPU card used to run model.")
+ help="Define which GPU card used to run model.",
+ )
return parser.parse_args()
def build_option(args):
-
cls_option = fd.RuntimeOption()
if args.device.lower() == "gpu":
@@ -60,7 +62,8 @@ def build_option(args):
# Create the cls_model
cls_model = fd.vision.ocr.Classifier(
- cls_model_file, cls_params_file, runtime_option=cls_option)
+ cls_model_file, cls_params_file, runtime_option=cls_option
+)
# Set the postprocessing parameters
cls_model.postprocessor.cls_thresh = 0.9
diff --git a/deploy/fastdeploy/cpu-gpu/python/infer_det.py b/deploy/fastdeploy/cpu-gpu/python/infer_det.py
index 7a7f5a07b7..0b9f1bf4e9 100755
--- a/deploy/fastdeploy/cpu-gpu/python/infer_det.py
+++ b/deploy/fastdeploy/cpu-gpu/python/infer_det.py
@@ -20,26 +20,30 @@
def parse_arguments():
import argparse
import ast
+
parser = argparse.ArgumentParser()
parser.add_argument(
- "--det_model", required=True, help="Path of Detection model of PPOCR.")
+ "--det_model", required=True, help="Path of Detection model of PPOCR."
+ )
parser.add_argument(
- "--image", type=str, required=True, help="Path of test image file.")
+ "--image", type=str, required=True, help="Path of test image file."
+ )
parser.add_argument(
"--device",
type=str,
- default='cpu',
- help="Type of inference device, support 'cpu', 'kunlunxin' or 'gpu'.")
+ default="cpu",
+ help="Type of inference device, support 'cpu', 'kunlunxin' or 'gpu'.",
+ )
parser.add_argument(
"--device_id",
type=int,
default=0,
- help="Define which GPU card used to run model.")
+ help="Define which GPU card used to run model.",
+ )
return parser.parse_args()
def build_option(args):
-
det_option = fd.RuntimeOption()
if args.device.lower() == "gpu":
@@ -58,7 +62,8 @@ def build_option(args):
# Create the det_model
det_model = fd.vision.ocr.DBDetector(
- det_model_file, det_params_file, runtime_option=det_option)
+ det_model_file, det_params_file, runtime_option=det_option
+)
# Set the preprocessing parameters
det_model.preprocessor.max_side_len = 960
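The detector exposes a few more knobs than the one shown above. A hedged sketch follows; the `det_db_*` postprocessor attribute names mirror the DB parameters that appear later in this patch (in the hubserving params.py files) and are assumptions here.

```python
# A minimal sketch, assuming a local det model; the det_db_* postprocessor
# attributes are assumptions mirroring the DB params used elsewhere.
import cv2
import fastdeploy as fd

det_model = fd.vision.ocr.DBDetector(
    "det/inference.pdmodel", "det/inference.pdiparams"
)

# Long sides above this limit are scaled down before inference.
det_model.preprocessor.max_side_len = 960
# DB postprocessing thresholds (cf. det_db_thresh / det_db_box_thresh below).
det_model.postprocessor.det_db_thresh = 0.3
det_model.postprocessor.det_db_box_thresh = 0.6

im = cv2.imread("test.jpg")  # placeholder image
print(det_model.predict(im))  # prints the detected text boxes
```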
diff --git a/deploy/fastdeploy/cpu-gpu/python/infer_rec.py b/deploy/fastdeploy/cpu-gpu/python/infer_rec.py
index 6f9e03b20e..2b219e98a9 100755
--- a/deploy/fastdeploy/cpu-gpu/python/infer_rec.py
+++ b/deploy/fastdeploy/cpu-gpu/python/infer_rec.py
@@ -20,32 +20,33 @@
def parse_arguments():
import argparse
import ast
+
parser = argparse.ArgumentParser()
parser.add_argument(
- "--rec_model",
- required=True,
- help="Path of Recognization model of PPOCR.")
+ "--rec_model", required=True, help="Path of Recognization model of PPOCR."
+ )
parser.add_argument(
- "--rec_label_file",
- required=True,
- help="Path of Recognization model of PPOCR.")
+ "--rec_label_file", required=True, help="Path of Recognization model of PPOCR."
+ )
parser.add_argument(
- "--image", type=str, required=True, help="Path of test image file.")
+ "--image", type=str, required=True, help="Path of test image file."
+ )
parser.add_argument(
"--device",
type=str,
- default='cpu',
- help="Type of inference device, support 'cpu', 'kunlunxin' or 'gpu'.")
+ default="cpu",
+ help="Type of inference device, support 'cpu', 'kunlunxin' or 'gpu'.",
+ )
parser.add_argument(
"--device_id",
type=int,
default=0,
- help="Define which GPU card used to run model.")
+ help="Define which GPU card used to run model.",
+ )
return parser.parse_args()
def build_option(args):
-
rec_option = fd.RuntimeOption()
if args.device.lower() == "gpu":
@@ -65,7 +66,8 @@ def build_option(args):
# Create the rec_model
rec_model = fd.vision.ocr.Recognizer(
- rec_model_file, rec_params_file, rec_label_file, runtime_option=rec_option)
+ rec_model_file, rec_params_file, rec_label_file, runtime_option=rec_option
+)
# Read the image
im = cv2.imread(args.image)
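The remainder of this script presumably runs the recognizer on the image just read; a hedged end-to-end sketch with placeholder paths:

```python
# A minimal sketch of a standalone recognizer run; paths are placeholders.
import cv2
import fastdeploy as fd

rec_model = fd.vision.ocr.Recognizer(
    "rec/inference.pdmodel", "rec/inference.pdiparams", "ppocr_keys_v1.txt"
)

im = cv2.imread("word_1.jpg")  # placeholder: a cropped single-line image
print(rec_model.predict(im))  # prints the decoded text and its confidence
```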
diff --git a/deploy/fastdeploy/kunlunxin/python/infer.py b/deploy/fastdeploy/kunlunxin/python/infer.py
index 4780df832c..bf95174017 100755
--- a/deploy/fastdeploy/kunlunxin/python/infer.py
+++ b/deploy/fastdeploy/kunlunxin/python/infer.py
@@ -20,38 +20,36 @@
def parse_arguments():
import argparse
import ast
+
parser = argparse.ArgumentParser()
parser.add_argument(
- "--det_model", required=True, help="Path of Detection model of PPOCR.")
+ "--det_model", required=True, help="Path of Detection model of PPOCR."
+ )
parser.add_argument(
- "--cls_model",
- required=True,
- help="Path of Classification model of PPOCR.")
+ "--cls_model", required=True, help="Path of Classification model of PPOCR."
+ )
parser.add_argument(
- "--rec_model",
- required=True,
- help="Path of Recognization model of PPOCR.")
+ "--rec_model", required=True, help="Path of Recognization model of PPOCR."
+ )
parser.add_argument(
- "--rec_label_file",
- required=True,
- help="Path of Recognization model of PPOCR.")
+ "--rec_label_file", required=True, help="Path of Recognization model of PPOCR."
+ )
parser.add_argument(
- "--image", type=str, required=True, help="Path of test image file.")
+ "--image", type=str, required=True, help="Path of test image file."
+ )
parser.add_argument(
"--cls_bs",
type=int,
default=1,
- help="Classification model inference batch size.")
+ help="Classification model inference batch size.",
+ )
parser.add_argument(
- "--rec_bs",
- type=int,
- default=6,
- help="Recognition model inference batch size")
+ "--rec_bs", type=int, default=6, help="Recognition model inference batch size"
+ )
return parser.parse_args()
def build_option(args):
-
det_option = fd.RuntimeOption()
cls_option = fd.RuntimeOption()
rec_option = fd.RuntimeOption()
@@ -78,18 +76,22 @@ def build_option(args):
det_option, cls_option, rec_option = build_option(args)
det_model = fd.vision.ocr.DBDetector(
- det_model_file, det_params_file, runtime_option=det_option)
+ det_model_file, det_params_file, runtime_option=det_option
+)
cls_model = fd.vision.ocr.Classifier(
- cls_model_file, cls_params_file, runtime_option=cls_option)
+ cls_model_file, cls_params_file, runtime_option=cls_option
+)
rec_model = fd.vision.ocr.Recognizer(
- rec_model_file, rec_params_file, rec_label_file, runtime_option=rec_option)
+ rec_model_file, rec_params_file, rec_label_file, runtime_option=rec_option
+)
# Create PP-OCRv3; if cls_model is not needed,
# just set cls_model=None.
ppocr_v3 = fd.vision.ocr.PPOCRv3(
- det_model=det_model, cls_model=cls_model, rec_model=rec_model)
+ det_model=det_model, cls_model=cls_model, rec_model=rec_model
+)
# Set inference batch size for cls model and rec model; the value can be -1 or any positive integer.
# When inference batch size is set to -1, it means that the inference batch size
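The body of build_option is elided from this hunk; for the Kunlunxin target it presumably routes all three models to the XPU. A hedged reconstruction, where use_kunlunxin() is assumed from FastDeploy's RuntimeOption device API rather than taken from this patch:

```python
# A hedged sketch of the elided build_option; use_kunlunxin() is an
# assumption about FastDeploy's device API, not code from this patch.
import fastdeploy as fd


def build_option(args):
    det_option = fd.RuntimeOption()
    cls_option = fd.RuntimeOption()
    rec_option = fd.RuntimeOption()
    for option in (det_option, cls_option, rec_option):
        option.use_kunlunxin()  # assumption: run inference on the XPU
    return det_option, cls_option, rec_option
```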
diff --git a/deploy/fastdeploy/rockchip/python/infer.py b/deploy/fastdeploy/rockchip/python/infer.py
index 7aa1382179..090353f938 100755
--- a/deploy/fastdeploy/rockchip/python/infer.py
+++ b/deploy/fastdeploy/rockchip/python/infer.py
@@ -20,38 +20,39 @@
def parse_arguments():
import argparse
import ast
+
parser = argparse.ArgumentParser()
parser.add_argument(
- "--det_model", required=True, help="Path of Detection model of PPOCR.")
+ "--det_model", required=True, help="Path of Detection model of PPOCR."
+ )
parser.add_argument(
- "--cls_model",
- required=True,
- help="Path of Classification model of PPOCR.")
+ "--cls_model", required=True, help="Path of Classification model of PPOCR."
+ )
parser.add_argument(
- "--rec_model",
- required=True,
- help="Path of Recognization model of PPOCR.")
+ "--rec_model", required=True, help="Path of Recognization model of PPOCR."
+ )
parser.add_argument(
- "--rec_label_file",
- required=True,
- help="Path of Recognization model of PPOCR.")
+ "--rec_label_file", required=True, help="Path of Recognization model of PPOCR."
+ )
parser.add_argument(
- "--image", type=str, required=True, help="Path of test image file.")
+ "--image", type=str, required=True, help="Path of test image file."
+ )
parser.add_argument(
"--device",
type=str,
- default='cpu',
- help="Type of inference device, support 'cpu', 'kunlunxin' or 'gpu'.")
+ default="cpu",
+ help="Type of inference device, support 'cpu', 'kunlunxin' or 'gpu'.",
+ )
parser.add_argument(
"--cpu_thread_num",
type=int,
default=9,
- help="Number of threads while inference on CPU.")
+ help="Number of threads while inference on CPU.",
+ )
return parser.parse_args()
def build_option(args):
-
det_option = fd.RuntimeOption()
cls_option = fd.RuntimeOption()
rec_option = fd.RuntimeOption()
@@ -92,23 +93,20 @@ def build_format(args):
det_format, cls_format, rec_format = build_format(args)
det_model = fd.vision.ocr.DBDetector(
- det_model_file,
- det_params_file,
- runtime_option=det_option,
- model_format=det_format)
+ det_model_file, det_params_file, runtime_option=det_option, model_format=det_format
+)
cls_model = fd.vision.ocr.Classifier(
- cls_model_file,
- cls_params_file,
- runtime_option=cls_option,
- model_format=cls_format)
+ cls_model_file, cls_params_file, runtime_option=cls_option, model_format=cls_format
+)
rec_model = fd.vision.ocr.Recognizer(
rec_model_file,
rec_params_file,
rec_label_file,
runtime_option=rec_option,
- model_format=rec_format)
+ model_format=rec_format,
+)
# Enable static shape inference for the Det and Rec models
det_model.preprocessor.static_shape_infer = True
@@ -124,7 +122,8 @@ def build_format(args):
# Create PP-OCR, chaining the 3 models; cls_model is optional and can be set to None if not needed
ppocr_v3 = fd.vision.ocr.PPOCRv3(
- det_model=det_model, cls_model=cls_model, rec_model=rec_model)
+ det_model=det_model, cls_model=cls_model, rec_model=rec_model
+)
# The batch size of the Cls and Rec models must be set to 1, with static shape inference enabled
ppocr_v3.cls_batch_size = 1
@@ -133,7 +132,7 @@ def build_format(args):
# Prepare the image for prediction
im = cv2.imread(args.image)
-#Predict and print results
+# Predict and print results
result = ppocr_v3.predict(im)
print(result)
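build_format is also elided here; judging from the model_format= arguments above, it returns one fd.ModelFormat per model. A hypothetical reconstruction, with fd.ModelFormat.RKNN assumed by analogy with the fd.ModelFormat.SOPHGO used later in this patch:

```python
# Hypothetical reconstruction of the elided build_format; the RKNN format
# constant and the "npu" device value are assumptions.
import fastdeploy as fd


def build_format(args):
    fmt = fd.ModelFormat.ONNX
    if args.device.lower() == "npu":  # assumption: NPU means RKNN models
        fmt = fd.ModelFormat.RKNN
    return fmt, fmt, fmt
```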
diff --git a/deploy/fastdeploy/rockchip/rknpu2_tools/export.py b/deploy/fastdeploy/rockchip/rknpu2_tools/export.py
index a94b348859..0d17d8bb11 100644
--- a/deploy/fastdeploy/rockchip/rknpu2_tools/export.py
+++ b/deploy/fastdeploy/rockchip/rknpu2_tools/export.py
@@ -40,21 +40,22 @@ def get_config():
model.config(
mean_values=mean_values,
std_values=std_values,
- target_platform=config.target_platform)
+ target_platform=config.target_platform,
+ )
# Load ONNX model
if yaml_config["outputs_nodes"] is None:
ret = model.load_onnx(model=yaml_config["model_path"])
else:
ret = model.load_onnx(
- model=yaml_config["model_path"],
- outputs=yaml_config["outputs_nodes"])
+ model=yaml_config["model_path"], outputs=yaml_config["outputs_nodes"]
+ )
assert ret == 0, "Load model failed!"
# Build model
ret = model.build(
- do_quantization=yaml_config["do_quantization"],
- dataset=yaml_config["dataset"])
+ do_quantization=yaml_config["do_quantization"], dataset=yaml_config["dataset"]
+ )
assert ret == 0, "Build model failed!"
# Init Runtime
@@ -71,10 +72,13 @@ def get_config():
model_base_name += name
model_device_name = config.target_platform.lower()
if yaml_config["do_quantization"]:
- model_save_name = model_base_name + "_" + model_device_name + "_quantized" + ".rknn"
+ model_save_name = (
+ model_base_name + "_" + model_device_name + "_quantized" + ".rknn"
+ )
else:
- model_save_name = model_base_name + "_" + model_device_name + "_unquantized" + ".rknn"
- ret = model.export_rknn(
- os.path.join(yaml_config["output_folder"], model_save_name))
+ model_save_name = (
+ model_base_name + "_" + model_device_name + "_unquantized" + ".rknn"
+ )
+ ret = model.export_rknn(os.path.join(yaml_config["output_folder"], model_save_name))
assert ret == 0, "Export rknn model failed!"
print("Export OK!")
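The yaml_config keys read by this script imply a config file of roughly the following shape; every value below is illustrative, not taken from the repo.

```python
# Illustrative only: a config dict with the fields export.py reads above.
import yaml

example_config = {
    "mean_values": [[127.5, 127.5, 127.5]],  # per-channel means
    "std_values": [[127.5, 127.5, 127.5]],   # per-channel stds
    "model_path": "./model.onnx",            # placeholder ONNX path
    "outputs_nodes": None,                   # or a list of output node names
    "do_quantization": False,                # picks the *_unquantized.rknn name
    "dataset": None,                         # calibration set when quantizing
    "output_folder": "./output",
}

with open("export_config.yaml", "w") as f:
    yaml.safe_dump(example_config, f)
```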
diff --git a/deploy/fastdeploy/serving/fastdeploy_serving/client.py b/deploy/fastdeploy/serving/fastdeploy_serving/client.py
index 6b758c5e39..6f0b22d2ab 100755
--- a/deploy/fastdeploy/serving/fastdeploy_serving/client.py
+++ b/deploy/fastdeploy/serving/fastdeploy_serving/client.py
@@ -6,7 +6,13 @@
import json
from tritonclient import utils as client_utils
-from tritonclient.grpc import InferenceServerClient, InferInput, InferRequestedOutput, service_pb2_grpc, service_pb2
+from tritonclient.grpc import (
+ InferenceServerClient,
+ InferInput,
+ InferRequestedOutput,
+ service_pb2_grpc,
+ service_pb2,
+)
LOGGER = logging.getLogger("run_inference_on_triton")
@@ -15,34 +21,38 @@ class SyncGRPCTritonRunner:
DEFAULT_MAX_RESP_WAIT_S = 120
def __init__(
- self,
- server_url: str,
- model_name: str,
- model_version: str,
- *,
- verbose=False,
- resp_wait_s: Optional[float]=None, ):
+ self,
+ server_url: str,
+ model_name: str,
+ model_version: str,
+ *,
+ verbose=False,
+ resp_wait_s: Optional[float] = None,
+ ):
self._server_url = server_url
self._model_name = model_name
self._model_version = model_version
self._verbose = verbose
- self._response_wait_t = self.DEFAULT_MAX_RESP_WAIT_S if resp_wait_s is None else resp_wait_s
+ self._response_wait_t = (
+ self.DEFAULT_MAX_RESP_WAIT_S if resp_wait_s is None else resp_wait_s
+ )
- self._client = InferenceServerClient(
- self._server_url, verbose=self._verbose)
+ self._client = InferenceServerClient(self._server_url, verbose=self._verbose)
error = self._verify_triton_state(self._client)
if error:
- raise RuntimeError(
- f"Could not communicate to Triton Server: {error}")
+ raise RuntimeError(f"Could not communicate to Triton Server: {error}")
LOGGER.debug(
f"Triton server {self._server_url} and model {self._model_name}:{self._model_version} "
- f"are up and ready!")
+ f"are up and ready!"
+ )
- model_config = self._client.get_model_config(self._model_name,
- self._model_version)
- model_metadata = self._client.get_model_metadata(self._model_name,
- self._model_version)
+ model_config = self._client.get_model_config(
+ self._model_name, self._model_version
+ )
+ model_metadata = self._client.get_model_metadata(
+ self._model_name, self._model_version
+ )
LOGGER.info(f"Model config {model_config}")
LOGGER.info(f"Model metadata {model_metadata}")
@@ -50,9 +60,7 @@ def __init__(
self._input_names = list(self._inputs)
self._outputs = {tm.name: tm for tm in model_metadata.outputs}
self._output_names = list(self._outputs)
- self._outputs_req = [
- InferRequestedOutput(name) for name in self._outputs
- ]
+ self._outputs_req = [InferRequestedOutput(name) for name in self._outputs]
def Run(self, inputs):
"""
@@ -63,8 +71,7 @@ def Run(self, inputs):
"""
infer_inputs = []
for idx, data in enumerate(inputs):
- infer_input = InferInput(self._input_names[idx], data.shape,
- "UINT8")
+ infer_input = InferInput(self._input_names[idx], data.shape, "UINT8")
infer_input.set_data_from_numpy(data)
infer_inputs.append(infer_input)
@@ -73,7 +80,8 @@ def Run(self, inputs):
model_version=self._model_version,
inputs=infer_inputs,
outputs=self._outputs_req,
- client_timeout=self._response_wait_t, )
+ client_timeout=self._response_wait_t,
+ )
results = {name: results.as_numpy(name) for name in self._output_names}
return results
@@ -82,8 +90,7 @@ def _verify_triton_state(self, triton_client):
return f"Triton server {self._server_url} is not live"
elif not triton_client.is_server_ready():
return f"Triton server {self._server_url} is not ready"
- elif not triton_client.is_model_ready(self._model_name,
- self._model_version):
+ elif not triton_client.is_model_ready(self._model_name, self._model_version):
return f"Model {self._model_name}:{self._model_version} is not ready"
return None
@@ -94,16 +101,30 @@ def _verify_triton_state(self, triton_client):
url = "localhost:8001"
runner = SyncGRPCTritonRunner(url, model_name, model_version)
im = cv2.imread("12.jpg")
- im = np.array([im, ])
+ im = np.array(
+ [
+ im,
+ ]
+ )
for i in range(1):
- result = runner.Run([im, ])
- batch_texts = result['rec_texts']
- batch_scores = result['rec_scores']
- batch_bboxes = result['det_bboxes']
+ result = runner.Run(
+ [
+ im,
+ ]
+ )
+ batch_texts = result["rec_texts"]
+ batch_scores = result["rec_scores"]
+ batch_bboxes = result["det_bboxes"]
for i_batch in range(len(batch_texts)):
texts = batch_texts[i_batch]
scores = batch_scores[i_batch]
bboxes = batch_bboxes[i_batch]
for i_box in range(len(texts)):
- print('text=', texts[i_box].decode('utf-8'), ' score=',
- scores[i_box], ' bbox=', bboxes[i_box])
+ print(
+ "text=",
+ texts[i_box].decode("utf-8"),
+ " score=",
+ scores[i_box],
+ " bbox=",
+ bboxes[i_box],
+ )
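Since Run() ships each input as a UINT8 tensor, several images can be sent in one request by stacking them along axis 0, provided they share one shape. A hedged sketch; the file names, model name, and version below are placeholders:

```python
# A minimal sketch, assuming the runner class defined above and images
# resized to one common shape; names below are placeholders.
import cv2
import numpy as np

runner = SyncGRPCTritonRunner("localhost:8001", "pp_ocr", "1")

im1 = cv2.resize(cv2.imread("a.jpg"), (960, 960))
im2 = cv2.resize(cv2.imread("b.jpg"), (960, 960))
batch = np.stack([im1, im2])  # shape (2, 960, 960, 3), dtype uint8

result = runner.Run([batch])
for texts in result["rec_texts"]:
    print([t.decode("utf-8") for t in texts])
```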
diff --git a/deploy/fastdeploy/serving/fastdeploy_serving/models/cls_postprocess/1/model.py b/deploy/fastdeploy/serving/fastdeploy_serving/models/cls_postprocess/1/model.py
index 891db5f24b..fbafcaef49 100644
--- a/deploy/fastdeploy/serving/fastdeploy_serving/models/cls_postprocess/1/model.py
+++ b/deploy/fastdeploy/serving/fastdeploy_serving/models/cls_postprocess/1/model.py
@@ -46,7 +46,7 @@ def initialize(self, args):
* model_name: Model name
"""
# You must parse model_config. JSON string is not parsed here
- self.model_config = json.loads(args['model_config'])
+ self.model_config = json.loads(args["model_config"])
print("model_config:", self.model_config)
self.input_names = []
@@ -85,15 +85,15 @@ def execute(self, requests):
responses = []
for request in requests:
infer_outputs = pb_utils.get_input_tensor_by_name(
- request, self.input_names[0])
+ request, self.input_names[0]
+ )
infer_outputs = infer_outputs.as_numpy()
results = self.postprocessor.run([infer_outputs])
- out_tensor_0 = pb_utils.Tensor(self.output_names[0],
- np.array(results[0]))
- out_tensor_1 = pb_utils.Tensor(self.output_names[1],
- np.array(results[1]))
+ out_tensor_0 = pb_utils.Tensor(self.output_names[0], np.array(results[0]))
+ out_tensor_1 = pb_utils.Tensor(self.output_names[1], np.array(results[1]))
inference_response = pb_utils.InferenceResponse(
- output_tensors=[out_tensor_0, out_tensor_1])
+ output_tensors=[out_tensor_0, out_tensor_1]
+ )
responses.append(inference_response)
return responses
@@ -102,4 +102,4 @@ def finalize(self):
Implementing `finalize` function is optional. This function allows
the model to perform any necessary clean ups before exit.
"""
- print('Cleaning up...')
+ print("Cleaning up...")
diff --git a/deploy/fastdeploy/serving/fastdeploy_serving/models/det_postprocess/1/model.py b/deploy/fastdeploy/serving/fastdeploy_serving/models/det_postprocess/1/model.py
index 87115c2d94..9d43832ca5 100644
--- a/deploy/fastdeploy/serving/fastdeploy_serving/models/det_postprocess/1/model.py
+++ b/deploy/fastdeploy/serving/fastdeploy_serving/models/det_postprocess/1/model.py
@@ -27,7 +27,7 @@
def get_rotate_crop_image(img, box):
- '''
+ """
img_height, img_width = img.shape[0:2]
left = int(np.min(points[:, 0]))
right = int(np.max(points[:, 0]))
@@ -36,7 +36,7 @@ def get_rotate_crop_image(img, box):
img_crop = img[top:bottom, left:right, :].copy()
points[:, 0] = points[:, 0] - left
points[:, 1] = points[:, 1] - top
- '''
+ """
points = []
for i in range(4):
points.append([box[2 * i], box[2 * i + 1]])
@@ -45,21 +45,30 @@ def get_rotate_crop_image(img, box):
assert len(points) == 4, "shape of points must be 4*2"
img_crop_width = int(
max(
- np.linalg.norm(points[0] - points[1]),
- np.linalg.norm(points[2] - points[3])))
+ np.linalg.norm(points[0] - points[1]), np.linalg.norm(points[2] - points[3])
+ )
+ )
img_crop_height = int(
max(
- np.linalg.norm(points[0] - points[3]),
- np.linalg.norm(points[1] - points[2])))
- pts_std = np.float32([[0, 0], [img_crop_width, 0],
- [img_crop_width, img_crop_height],
- [0, img_crop_height]])
+ np.linalg.norm(points[0] - points[3]), np.linalg.norm(points[1] - points[2])
+ )
+ )
+ pts_std = np.float32(
+ [
+ [0, 0],
+ [img_crop_width, 0],
+ [img_crop_width, img_crop_height],
+ [0, img_crop_height],
+ ]
+ )
M = cv2.getPerspectiveTransform(points, pts_std)
dst_img = cv2.warpPerspective(
img,
- M, (img_crop_width, img_crop_height),
+ M,
+ (img_crop_width, img_crop_height),
borderMode=cv2.BORDER_REPLICATE,
- flags=cv2.INTER_CUBIC)
+ flags=cv2.INTER_CUBIC,
+ )
dst_img_height, dst_img_width = dst_img.shape[0:2]
if dst_img_height * 1.0 / dst_img_width >= 1.5:
dst_img = np.rot90(dst_img)
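To make the crop geometry above concrete, here is a self-contained run of the same math on a synthetic quadrilateral; the dummy image and box coordinates are made up for the example.

```python
# Self-contained illustration of the perspective crop above.
import cv2
import numpy as np

img = np.zeros((200, 300, 3), dtype=np.uint8)
box = [40, 50, 240, 60, 230, 150, 30, 140]  # x1,y1 ... x4,y4, clockwise

points = np.array(
    [[box[2 * i], box[2 * i + 1]] for i in range(4)], dtype=np.float32
)
w = int(
    max(np.linalg.norm(points[0] - points[1]), np.linalg.norm(points[2] - points[3]))
)
h = int(
    max(np.linalg.norm(points[0] - points[3]), np.linalg.norm(points[1] - points[2]))
)
pts_std = np.float32([[0, 0], [w, 0], [w, h], [0, h]])

M = cv2.getPerspectiveTransform(points, pts_std)
crop = cv2.warpPerspective(
    img, M, (w, h), borderMode=cv2.BORDER_REPLICATE, flags=cv2.INTER_CUBIC
)
print(crop.shape)  # (h, w, 3): the rectified text region
```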
@@ -87,7 +96,7 @@ def initialize(self, args):
* model_name: Model name
"""
# You must parse model_config. JSON string is not parsed here
- self.model_config = json.loads(args['model_config'])
+ self.model_config = json.loads(args["model_config"])
print("model_config:", self.model_config)
self.input_names = []
@@ -129,11 +138,10 @@ def execute(self, requests):
responses = []
for request in requests:
infer_outputs = pb_utils.get_input_tensor_by_name(
- request, self.input_names[0])
- im_infos = pb_utils.get_input_tensor_by_name(request,
- self.input_names[1])
- ori_imgs = pb_utils.get_input_tensor_by_name(request,
- self.input_names[2])
+ request, self.input_names[0]
+ )
+ im_infos = pb_utils.get_input_tensor_by_name(request, self.input_names[1])
+ ori_imgs = pb_utils.get_input_tensor_by_name(request, self.input_names[2])
infer_outputs = infer_outputs.as_numpy()
im_infos = im_infos.as_numpy()
@@ -144,7 +152,6 @@ def execute(self, requests):
batch_rec_scores = []
batch_box_list = []
for i_batch in range(len(results)):
-
cls_labels = []
cls_scores = []
rec_texts = []
@@ -163,70 +170,81 @@ def execute(self, requests):
cls_pre_tensors = self.cls_preprocessor.run(image_list)
cls_dlpack_tensor = cls_pre_tensors[0].to_dlpack()
- cls_input_tensor = pb_utils.Tensor.from_dlpack(
- "x", cls_dlpack_tensor)
+ cls_input_tensor = pb_utils.Tensor.from_dlpack("x", cls_dlpack_tensor)
inference_request = pb_utils.InferenceRequest(
- model_name='cls_pp',
- requested_output_names=['cls_labels', 'cls_scores'],
- inputs=[cls_input_tensor])
+ model_name="cls_pp",
+ requested_output_names=["cls_labels", "cls_scores"],
+ inputs=[cls_input_tensor],
+ )
inference_response = inference_request.exec()
if inference_response.has_error():
raise pb_utils.TritonModelException(
- inference_response.error().message())
+ inference_response.error().message()
+ )
else:
# Extract the output tensors from the inference response.
cls_labels = pb_utils.get_output_tensor_by_name(
- inference_response, 'cls_labels')
+ inference_response, "cls_labels"
+ )
cls_labels = cls_labels.as_numpy()
cls_scores = pb_utils.get_output_tensor_by_name(
- inference_response, 'cls_scores')
+ inference_response, "cls_scores"
+ )
cls_scores = cls_scores.as_numpy()
for index in range(len(image_list)):
- if cls_labels[index] == 1 and cls_scores[
- index] > self.cls_threshold:
+ if (
+ cls_labels[index] == 1
+ and cls_scores[index] > self.cls_threshold
+ ):
image_list[index] = cv2.rotate(
- image_list[index].astype(np.float32), 1)
+ image_list[index].astype(np.float32), 1
+ )
image_list[index] = image_list[index].astype(np.uint8)
rec_pre_tensors = self.rec_preprocessor.run(image_list)
rec_dlpack_tensor = rec_pre_tensors[0].to_dlpack()
- rec_input_tensor = pb_utils.Tensor.from_dlpack(
- "x", rec_dlpack_tensor)
+ rec_input_tensor = pb_utils.Tensor.from_dlpack("x", rec_dlpack_tensor)
inference_request = pb_utils.InferenceRequest(
- model_name='rec_pp',
- requested_output_names=['rec_texts', 'rec_scores'],
- inputs=[rec_input_tensor])
+ model_name="rec_pp",
+ requested_output_names=["rec_texts", "rec_scores"],
+ inputs=[rec_input_tensor],
+ )
inference_response = inference_request.exec()
if inference_response.has_error():
raise pb_utils.TritonModelException(
- inference_response.error().message())
+ inference_response.error().message()
+ )
else:
# Extract the output tensors from the inference response.
rec_texts = pb_utils.get_output_tensor_by_name(
- inference_response, 'rec_texts')
+ inference_response, "rec_texts"
+ )
rec_texts = rec_texts.as_numpy()
rec_scores = pb_utils.get_output_tensor_by_name(
- inference_response, 'rec_scores')
+ inference_response, "rec_scores"
+ )
rec_scores = rec_scores.as_numpy()
batch_rec_texts.append(rec_texts)
batch_rec_scores.append(rec_scores)
out_tensor_0 = pb_utils.Tensor(
- self.output_names[0],
- np.array(
- batch_rec_texts, dtype=np.object_))
- out_tensor_1 = pb_utils.Tensor(self.output_names[1],
- np.array(batch_rec_scores))
- out_tensor_2 = pb_utils.Tensor(self.output_names[2],
- np.array(batch_box_list))
+ self.output_names[0], np.array(batch_rec_texts, dtype=np.object_)
+ )
+ out_tensor_1 = pb_utils.Tensor(
+ self.output_names[1], np.array(batch_rec_scores)
+ )
+ out_tensor_2 = pb_utils.Tensor(
+ self.output_names[2], np.array(batch_box_list)
+ )
inference_response = pb_utils.InferenceResponse(
- output_tensors=[out_tensor_0, out_tensor_1, out_tensor_2])
+ output_tensors=[out_tensor_0, out_tensor_1, out_tensor_2]
+ )
responses.append(inference_response)
return responses
@@ -235,4 +253,4 @@ def finalize(self):
Implementing `finalize` function is optional. This function allows
the model to perform any necessary clean ups before exit.
"""
- print('Cleaning up...')
+ print("Cleaning up...")
diff --git a/deploy/fastdeploy/serving/fastdeploy_serving/models/det_preprocess/1/model.py b/deploy/fastdeploy/serving/fastdeploy_serving/models/det_preprocess/1/model.py
index 28e838da5b..070cb45e62 100644
--- a/deploy/fastdeploy/serving/fastdeploy_serving/models/det_preprocess/1/model.py
+++ b/deploy/fastdeploy/serving/fastdeploy_serving/models/det_preprocess/1/model.py
@@ -46,7 +46,7 @@ def initialize(self, args):
* model_name: Model name
"""
# You must parse model_config. JSON string is not parsed here
- self.model_config = json.loads(args['model_config'])
+ self.model_config = json.loads(args["model_config"])
print("model_config:", self.model_config)
self.input_names = []
@@ -84,18 +84,19 @@ def execute(self, requests):
"""
responses = []
for request in requests:
- data = pb_utils.get_input_tensor_by_name(request,
- self.input_names[0])
+ data = pb_utils.get_input_tensor_by_name(request, self.input_names[0])
data = data.as_numpy()
outputs, im_infos = self.preprocessor.run(data)
dlpack_tensor = outputs[0].to_dlpack()
- output_tensor_0 = pb_utils.Tensor.from_dlpack(self.output_names[0],
- dlpack_tensor)
+ output_tensor_0 = pb_utils.Tensor.from_dlpack(
+ self.output_names[0], dlpack_tensor
+ )
output_tensor_1 = pb_utils.Tensor(
- self.output_names[1], np.array(
- im_infos, dtype=np.int32))
+ self.output_names[1], np.array(im_infos, dtype=np.int32)
+ )
inference_response = pb_utils.InferenceResponse(
- output_tensors=[output_tensor_0, output_tensor_1])
+ output_tensors=[output_tensor_0, output_tensor_1]
+ )
responses.append(inference_response)
return responses
@@ -104,4 +105,4 @@ def finalize(self):
Implementing `finalize` function is optional. This function allows
the model to perform any necessary clean ups before exit.
"""
- print('Cleaning up...')
+ print("Cleaning up...")
diff --git a/deploy/fastdeploy/serving/fastdeploy_serving/models/rec_postprocess/1/model.py b/deploy/fastdeploy/serving/fastdeploy_serving/models/rec_postprocess/1/model.py
index c046cd929b..dc4c6555ea 100755
--- a/deploy/fastdeploy/serving/fastdeploy_serving/models/rec_postprocess/1/model.py
+++ b/deploy/fastdeploy/serving/fastdeploy_serving/models/rec_postprocess/1/model.py
@@ -48,7 +48,7 @@ def initialize(self, args):
* model_name: Model name
"""
# You must parse model_config. JSON string is not parsed here
- self.model_config = json.loads(args['model_config'])
+ self.model_config = json.loads(args["model_config"])
print("model_config:", self.model_config)
self.input_names = []
@@ -66,7 +66,7 @@ def initialize(self, args):
dir_name = os.path.dirname(os.path.realpath(__file__)) + "/"
file_name = dir_name + "ppocr_keys_v1.txt"
- #self.label_list = load_dict()
+ # self.label_list = load_dict()
self.postprocessor = fd.vision.ocr.RecognizerPostprocessor(file_name)
def execute(self, requests):
@@ -91,16 +91,17 @@ def execute(self, requests):
responses = []
for request in requests:
infer_outputs = pb_utils.get_input_tensor_by_name(
- request, self.input_names[0])
+ request, self.input_names[0]
+ )
infer_outputs = infer_outputs.as_numpy()
results = self.postprocessor.run([infer_outputs])
out_tensor_0 = pb_utils.Tensor(
- self.output_names[0], np.array(
- results[0], dtype=np.object_))
- out_tensor_1 = pb_utils.Tensor(self.output_names[1],
- np.array(results[1]))
+ self.output_names[0], np.array(results[0], dtype=np.object_)
+ )
+ out_tensor_1 = pb_utils.Tensor(self.output_names[1], np.array(results[1]))
inference_response = pb_utils.InferenceResponse(
- output_tensors=[out_tensor_0, out_tensor_1])
+ output_tensors=[out_tensor_0, out_tensor_1]
+ )
responses.append(inference_response)
return responses
@@ -109,4 +110,4 @@ def finalize(self):
Implementing `finalize` function is optional. This function allows
the model to perform any necessary clean ups before exit.
"""
- print('Cleaning up...')
+ print("Cleaning up...")
diff --git a/deploy/fastdeploy/serving/simple_serving/client.py b/deploy/fastdeploy/serving/simple_serving/client.py
index 6849c22046..cbe9ead7c9 100644
--- a/deploy/fastdeploy/serving/simple_serving/client.py
+++ b/deploy/fastdeploy/serving/simple_serving/client.py
@@ -4,7 +4,7 @@
import fastdeploy as fd
from fastdeploy.serving.utils import cv2_to_base64
-if __name__ == '__main__':
+if __name__ == "__main__":
url = "http://127.0.0.1:8000/fd/ppocrv3"
headers = {"Content-Type": "application/json"}
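The request body is elided from this hunk. A hedged sketch of the round trip follows; the payload keys and the json_to_ocr decoding helper are assumptions based on FastDeploy's simple-serving conventions, and 12.jpg is a placeholder image:

```python
# A hedged sketch of the POST; payload keys and fd.vision.utils.json_to_ocr
# are assumptions, and the image path is a placeholder.
import json

import cv2
import requests
import fastdeploy as fd
from fastdeploy.serving.utils import cv2_to_base64

url = "http://127.0.0.1:8000/fd/ppocrv3"
headers = {"Content-Type": "application/json"}

im = cv2.imread("12.jpg")
payload = {"data": {"image": cv2_to_base64(im)}, "parameters": {}}

resp = requests.post(url=url, headers=headers, data=json.dumps(payload))
if resp.status_code == 200:
    r_json = json.loads(resp.json()["result"])
    print(fd.vision.utils.json_to_ocr(r_json))  # assumed decoding helper
else:
    print("request failed:", resp.status_code, resp.text)
```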
diff --git a/deploy/fastdeploy/serving/simple_serving/server.py b/deploy/fastdeploy/serving/simple_serving/server.py
index 0078b7112f..3a6f88727a 100644
--- a/deploy/fastdeploy/serving/simple_serving/server.py
+++ b/deploy/fastdeploy/serving/simple_serving/server.py
@@ -6,14 +6,14 @@
logging.getLogger().setLevel(logging.INFO)
# Configurations
-det_model_dir = 'ch_PP-OCRv3_det_infer'
-cls_model_dir = 'ch_ppocr_mobile_v2.0_cls_infer'
-rec_model_dir = 'ch_PP-OCRv3_rec_infer'
-rec_label_file = 'ppocr_keys_v1.txt'
-device = 'cpu'
+det_model_dir = "ch_PP-OCRv3_det_infer"
+cls_model_dir = "ch_ppocr_mobile_v2.0_cls_infer"
+rec_model_dir = "ch_PP-OCRv3_rec_infer"
+rec_label_file = "ppocr_keys_v1.txt"
+device = "cpu"
# backend: ['paddle', 'trt'], you can also use other backends, but need to modify
# the runtime option below
-backend = 'paddle'
+backend = "paddle"
# Prepare models
# Detection model
@@ -28,46 +28,49 @@
# Setup runtime option to select hardware, backend, etc.
option = fd.RuntimeOption()
-if device.lower() == 'gpu':
+if device.lower() == "gpu":
option.use_gpu()
-if backend == 'trt':
+if backend == "trt":
option.use_trt_backend()
else:
option.use_paddle_infer_backend()
det_option = option
-det_option.set_trt_input_shape("x", [1, 3, 64, 64], [1, 3, 640, 640],
- [1, 3, 960, 960])
+det_option.set_trt_input_shape("x", [1, 3, 64, 64], [1, 3, 640, 640], [1, 3, 960, 960])
# det_option.set_trt_cache_file("det_trt_cache.trt")
print(det_model_file, det_params_file)
det_model = fd.vision.ocr.DBDetector(
- det_model_file, det_params_file, runtime_option=det_option)
+ det_model_file, det_params_file, runtime_option=det_option
+)
cls_batch_size = 1
rec_batch_size = 6
cls_option = option
-cls_option.set_trt_input_shape("x", [1, 3, 48, 10],
- [cls_batch_size, 3, 48, 320],
- [cls_batch_size, 3, 48, 1024])
+cls_option.set_trt_input_shape(
+ "x", [1, 3, 48, 10], [cls_batch_size, 3, 48, 320], [cls_batch_size, 3, 48, 1024]
+)
# cls_option.set_trt_cache_file("cls_trt_cache.trt")
cls_model = fd.vision.ocr.Classifier(
- cls_model_file, cls_params_file, runtime_option=cls_option)
+ cls_model_file, cls_params_file, runtime_option=cls_option
+)
rec_option = option
-rec_option.set_trt_input_shape("x", [1, 3, 48, 10],
- [rec_batch_size, 3, 48, 320],
- [rec_batch_size, 3, 48, 2304])
+rec_option.set_trt_input_shape(
+ "x", [1, 3, 48, 10], [rec_batch_size, 3, 48, 320], [rec_batch_size, 3, 48, 2304]
+)
# rec_option.set_trt_cache_file("rec_trt_cache.trt")
rec_model = fd.vision.ocr.Recognizer(
- rec_model_file, rec_params_file, rec_label_file, runtime_option=rec_option)
+ rec_model_file, rec_params_file, rec_label_file, runtime_option=rec_option
+)
# Create PPOCRv3 pipeline
ppocr_v3 = fd.vision.ocr.PPOCRv3(
- det_model=det_model, cls_model=cls_model, rec_model=rec_model)
+ det_model=det_model, cls_model=cls_model, rec_model=rec_model
+)
ppocr_v3.cls_batch_size = cls_batch_size
ppocr_v3.rec_batch_size = rec_batch_size
@@ -77,4 +80,5 @@
app.register(
task_name="fd/ppocrv3",
model_handler=fd.serving.handler.VisionModelHandler,
- predictor=ppocr_v3)
+ predictor=ppocr_v3,
+)
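The three positional shape arguments to set_trt_input_shape are the minimum, optimum, and maximum shapes TensorRT should plan for. A short sketch restating the rec configuration above with the roles spelled out:

```python
# A minimal sketch of set_trt_input_shape's min/opt/max arguments,
# repeating the rec model configuration above.
import fastdeploy as fd

rec_batch_size = 6
option = fd.RuntimeOption()
option.use_gpu()
option.use_trt_backend()
option.set_trt_input_shape(
    "x",                            # input tensor name
    [1, 3, 48, 10],                 # min shape the engine must accept
    [rec_batch_size, 3, 48, 320],   # opt shape the engine is tuned for
    [rec_batch_size, 3, 48, 2304],  # max shape before a rebuild is needed
)
```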
diff --git a/deploy/fastdeploy/sophgo/python/infer.py b/deploy/fastdeploy/sophgo/python/infer.py
index 356317099f..c92ee7ad11 100644
--- a/deploy/fastdeploy/sophgo/python/infer.py
+++ b/deploy/fastdeploy/sophgo/python/infer.py
@@ -6,23 +6,23 @@
def parse_arguments():
import argparse
import ast
+
parser = argparse.ArgumentParser()
parser.add_argument(
- "--det_model", required=True, help="Path of Detection model of PPOCR.")
+ "--det_model", required=True, help="Path of Detection model of PPOCR."
+ )
parser.add_argument(
- "--cls_model",
- required=True,
- help="Path of Classification model of PPOCR.")
+ "--cls_model", required=True, help="Path of Classification model of PPOCR."
+ )
parser.add_argument(
- "--rec_model",
- required=True,
- help="Path of Recognization model of PPOCR.")
+ "--rec_model", required=True, help="Path of Recognization model of PPOCR."
+ )
parser.add_argument(
- "--rec_label_file",
- required=True,
- help="Path of Recognization label of PPOCR.")
+ "--rec_label_file", required=True, help="Path of Recognization label of PPOCR."
+ )
parser.add_argument(
- "--image", type=str, required=True, help="Path of test image file.")
+ "--image", type=str, required=True, help="Path of test image file."
+ )
return parser.parse_args()
@@ -53,32 +53,33 @@ def parse_arguments():
# Note: set the classification model's dynamic inputs and create it only after the detection model has been created; the same applies to the recognition model.
# If you want to change the detection model's input shape yourself, we recommend setting its height and width to multiples of 32.
det_option = runtime_option
-det_option.set_trt_input_shape("x", [1, 3, 64, 64], [1, 3, 640, 640],
- [1, 3, 960, 960])
+det_option.set_trt_input_shape("x", [1, 3, 64, 64], [1, 3, 640, 640], [1, 3, 960, 960])
# The TRT engine file can be saved locally
# det_option.set_trt_cache_file(args.det_model + "/det_trt_cache.trt")
det_model = fd.vision.ocr.DBDetector(
det_model_file,
det_params_file,
runtime_option=det_option,
- model_format=fd.ModelFormat.SOPHGO)
+ model_format=fd.ModelFormat.SOPHGO,
+)
cls_option = runtime_option
-cls_option.set_trt_input_shape("x", [1, 3, 48, 10],
- [cls_batch_size, 3, 48, 320],
- [cls_batch_size, 3, 48, 1024])
+cls_option.set_trt_input_shape(
+ "x", [1, 3, 48, 10], [cls_batch_size, 3, 48, 320], [cls_batch_size, 3, 48, 1024]
+)
# The TRT engine file can be saved locally
# cls_option.set_trt_cache_file(args.cls_model + "/cls_trt_cache.trt")
cls_model = fd.vision.ocr.Classifier(
cls_model_file,
cls_params_file,
runtime_option=cls_option,
- model_format=fd.ModelFormat.SOPHGO)
+ model_format=fd.ModelFormat.SOPHGO,
+)
rec_option = runtime_option
-rec_option.set_trt_input_shape("x", [1, 3, 48, 10],
- [rec_batch_size, 3, 48, 320],
- [rec_batch_size, 3, 48, 2304])
+rec_option.set_trt_input_shape(
+ "x", [1, 3, 48, 10], [rec_batch_size, 3, 48, 320], [rec_batch_size, 3, 48, 2304]
+)
# The TRT engine file can be saved locally
# rec_option.set_trt_cache_file(args.rec_model + "/rec_trt_cache.trt")
rec_model = fd.vision.ocr.Recognizer(
@@ -86,11 +87,13 @@ def parse_arguments():
rec_params_file,
rec_label_file,
runtime_option=rec_option,
- model_format=fd.ModelFormat.SOPHGO)
+ model_format=fd.ModelFormat.SOPHGO,
+)
# Create PP-OCR, chaining the 3 models; cls_model is optional and can be set to None if not needed
ppocr_v3 = fd.vision.ocr.PPOCRv3(
- det_model=det_model, cls_model=cls_model, rec_model=rec_model)
+ det_model=det_model, cls_model=cls_model, rec_model=rec_model
+)
# The following line is required to enable static shape inference for the rec model; here its static input is [3, 48, 584]
rec_model.preprocessor.static_shape_infer = True
@@ -105,7 +108,7 @@ def parse_arguments():
# Prepare the image for prediction
im = cv2.imread(args.image)
-#Predict and print results
+# Predict and print results
result = ppocr_v3.predict(im)
print(result)
diff --git a/deploy/hubserving/kie_ser/__init__.py b/deploy/hubserving/kie_ser/__init__.py
index c747d3e7ae..97043fd7ba 100644
--- a/deploy/hubserving/kie_ser/__init__.py
+++ b/deploy/hubserving/kie_ser/__init__.py
@@ -10,4 +10,4 @@
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
-# limitations under the License.
\ No newline at end of file
+# limitations under the License.
diff --git a/deploy/hubserving/kie_ser/module.py b/deploy/hubserving/kie_ser/module.py
index f0ef3585d8..2c046be07b 100644
--- a/deploy/hubserving/kie_ser/module.py
+++ b/deploy/hubserving/kie_ser/module.py
@@ -18,6 +18,7 @@
import os
import sys
+
sys.path.insert(0, ".")
import copy
@@ -42,7 +43,8 @@
summary="kie ser service",
author="paddle-dev",
author_email="paddle-dev@baidu.com",
- type="cv/KIE_SER")
+ type="cv/KIE_SER",
+)
class KIESer(hub.Module):
def _initialize(self, use_gpu=False, enable_mkldnn=False):
"""
@@ -67,7 +69,9 @@ def _initialize(self, use_gpu=False, enable_mkldnn=False):
self.ser_predictor = SerPredictor(cfg)
- def merge_configs(self, ):
+ def merge_configs(
+ self,
+ ):
# default cfg
backup_argv = copy.deepcopy(sys.argv)
sys.argv = sys.argv[:1]
@@ -84,8 +88,9 @@ def merge_configs(self, ):
def read_images(self, paths=[]):
images = []
for img_path in paths:
- assert os.path.isfile(
- img_path), "The {} isn't a valid file.".format(img_path)
+ assert os.path.isfile(img_path), "The {} isn't a valid file.".format(
+ img_path
+ )
img = cv2.imread(img_path)
if img is None:
logger.info("error in loading image:{}".format(img_path))
@@ -109,7 +114,9 @@ def predict(self, images=[], paths=[]):
else:
raise TypeError("The input data is inconsistent with expectations.")
- assert predicted_data != [], "There is not any image to be predicted. Please check the input data."
+ assert (
+ predicted_data != []
+ ), "There is not any image to be predicted. Please check the input data."
all_results = []
for img in predicted_data:
@@ -134,12 +141,12 @@ def serving_method(self, images, **kwargs):
return results
-if __name__ == '__main__':
+if __name__ == "__main__":
ocr = KIESer()
ocr._initialize()
image_path = [
- './doc/imgs/11.jpg',
- './doc/imgs/12.jpg',
+ "./doc/imgs/11.jpg",
+ "./doc/imgs/12.jpg",
]
res = ocr.predict(paths=image_path)
print(res)
diff --git a/deploy/hubserving/kie_ser_re/__init__.py b/deploy/hubserving/kie_ser_re/__init__.py
index c747d3e7ae..97043fd7ba 100644
--- a/deploy/hubserving/kie_ser_re/__init__.py
+++ b/deploy/hubserving/kie_ser_re/__init__.py
@@ -10,4 +10,4 @@
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
-# limitations under the License.
\ No newline at end of file
+# limitations under the License.
diff --git a/deploy/hubserving/kie_ser_re/module.py b/deploy/hubserving/kie_ser_re/module.py
index 5a63a8a1f1..4f2bc4479c 100644
--- a/deploy/hubserving/kie_ser_re/module.py
+++ b/deploy/hubserving/kie_ser_re/module.py
@@ -18,6 +18,7 @@
import os
import sys
+
sys.path.insert(0, ".")
import copy
@@ -42,7 +43,8 @@
summary="kie ser re service",
author="paddle-dev",
author_email="paddle-dev@baidu.com",
- type="cv/KIE_SER_RE")
+ type="cv/KIE_SER_RE",
+)
class KIESerRE(hub.Module):
def _initialize(self, use_gpu=False, enable_mkldnn=False):
"""
@@ -67,7 +69,9 @@ def _initialize(self, use_gpu=False, enable_mkldnn=False):
self.ser_re_predictor = SerRePredictor(cfg)
- def merge_configs(self, ):
+ def merge_configs(
+ self,
+ ):
# default cfg
backup_argv = copy.deepcopy(sys.argv)
sys.argv = sys.argv[:1]
@@ -84,8 +88,9 @@ def merge_configs(self, ):
def read_images(self, paths=[]):
images = []
for img_path in paths:
- assert os.path.isfile(
- img_path), "The {} isn't a valid file.".format(img_path)
+ assert os.path.isfile(img_path), "The {} isn't a valid file.".format(
+ img_path
+ )
img = cv2.imread(img_path)
if img is None:
logger.info("error in loading image:{}".format(img_path))
@@ -109,7 +114,9 @@ def predict(self, images=[], paths=[]):
else:
raise TypeError("The input data is inconsistent with expectations.")
- assert predicted_data != [], "There is not any image to be predicted. Please check the input data."
+ assert (
+ predicted_data != []
+ ), "There is not any image to be predicted. Please check the input data."
all_results = []
for img in predicted_data:
@@ -136,12 +143,12 @@ def serving_method(self, images, **kwargs):
return results
-if __name__ == '__main__':
+if __name__ == "__main__":
ocr = KIESerRE()
ocr._initialize()
image_path = [
- './doc/imgs/11.jpg',
- './doc/imgs/12.jpg',
+ "./doc/imgs/11.jpg",
+ "./doc/imgs/12.jpg",
]
res = ocr.predict(paths=image_path)
print(res)
diff --git a/deploy/hubserving/ocr_cls/__init__.py b/deploy/hubserving/ocr_cls/__init__.py
index c747d3e7ae..97043fd7ba 100644
--- a/deploy/hubserving/ocr_cls/__init__.py
+++ b/deploy/hubserving/ocr_cls/__init__.py
@@ -10,4 +10,4 @@
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
-# limitations under the License.
\ No newline at end of file
+# limitations under the License.
diff --git a/deploy/hubserving/ocr_cls/module.py b/deploy/hubserving/ocr_cls/module.py
index 8b70f0376e..a201acbebe 100644
--- a/deploy/hubserving/ocr_cls/module.py
+++ b/deploy/hubserving/ocr_cls/module.py
@@ -18,6 +18,7 @@
import os
import sys
+
sys.path.insert(0, ".")
import copy
import paddlehub
@@ -38,7 +39,8 @@
summary="ocr angle cls service",
author="paddle-dev",
author_email="paddle-dev@baidu.com",
- type="cv/text_angle_cls")
+ type="cv/text_angle_cls",
+)
class OCRCls(hub.Module):
def _initialize(self, use_gpu=False, enable_mkldnn=False):
"""
@@ -63,7 +65,9 @@ def _initialize(self, use_gpu=False, enable_mkldnn=False):
self.text_classifier = TextClassifier(cfg)
- def merge_configs(self, ):
+ def merge_configs(
+ self,
+ ):
# default cfg
backup_argv = copy.deepcopy(sys.argv)
sys.argv = sys.argv[:1]
@@ -80,8 +84,9 @@ def merge_configs(self, ):
def read_images(self, paths=[]):
images = []
for img_path in paths:
- assert os.path.isfile(
- img_path), "The {} isn't a valid file.".format(img_path)
+ assert os.path.isfile(img_path), "The {} isn't a valid file.".format(
+ img_path
+ )
img = cv2.imread(img_path)
if img is None:
logger.info("error in loading image:{}".format(img_path))
@@ -106,7 +111,9 @@ def predict(self, images=[], paths=[]):
else:
raise TypeError("The input data is inconsistent with expectations.")
- assert predicted_data != [], "There is not any image to be predicted. Please check the input data."
+ assert (
+ predicted_data != []
+ ), "There is not any image to be predicted. Please check the input data."
img_list = []
for img in predicted_data:
@@ -119,10 +126,12 @@ def predict(self, images=[], paths=[]):
img_list, cls_res, predict_time = self.text_classifier(img_list)
for dno in range(len(cls_res)):
angle, score = cls_res[dno]
- rec_res_final.append({
- 'angle': angle,
- 'confidence': float(score),
- })
+ rec_res_final.append(
+ {
+ "angle": angle,
+ "confidence": float(score),
+ }
+ )
except Exception as e:
print(e)
return [[]]
@@ -139,13 +148,13 @@ def serving_method(self, images, **kwargs):
return results
-if __name__ == '__main__':
+if __name__ == "__main__":
ocr = OCRCls()
ocr._initialize()
image_path = [
- './doc/imgs_words/ch/word_1.jpg',
- './doc/imgs_words/ch/word_2.jpg',
- './doc/imgs_words/ch/word_3.jpg',
+ "./doc/imgs_words/ch/word_1.jpg",
+ "./doc/imgs_words/ch/word_2.jpg",
+ "./doc/imgs_words/ch/word_3.jpg",
]
res = ocr.predict(paths=image_path)
print(res)
diff --git a/deploy/hubserving/ocr_cls/params.py b/deploy/hubserving/ocr_cls/params.py
index fe4e84843a..589fd33c79 100755
--- a/deploy/hubserving/ocr_cls/params.py
+++ b/deploy/hubserving/ocr_cls/params.py
@@ -24,10 +24,10 @@ class Config(object):
def read_params():
cfg = Config()
- #params for text classifier
+ # params for text classifier
cfg.cls_model_dir = "./inference/ch_ppocr_mobile_v2.0_cls_infer/"
cfg.cls_image_shape = "3, 48, 192"
- cfg.label_list = ['0', '180']
+ cfg.label_list = ["0", "180"]
cfg.cls_batch_num = 30
cfg.cls_thresh = 0.9
diff --git a/deploy/hubserving/ocr_det/__init__.py b/deploy/hubserving/ocr_det/__init__.py
index c747d3e7ae..97043fd7ba 100644
--- a/deploy/hubserving/ocr_det/__init__.py
+++ b/deploy/hubserving/ocr_det/__init__.py
@@ -10,4 +10,4 @@
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
-# limitations under the License.
\ No newline at end of file
+# limitations under the License.
diff --git a/deploy/hubserving/ocr_det/module.py b/deploy/hubserving/ocr_det/module.py
index 3dbaf161cd..c14b7e2a1a 100644
--- a/deploy/hubserving/ocr_det/module.py
+++ b/deploy/hubserving/ocr_det/module.py
@@ -18,6 +18,7 @@
import os
import sys
+
sys.path.insert(0, ".")
import copy
@@ -40,7 +41,8 @@
summary="ocr detection service",
author="paddle-dev",
author_email="paddle-dev@baidu.com",
- type="cv/text_detection")
+ type="cv/text_detection",
+)
class OCRDet(hub.Module):
def _initialize(self, use_gpu=False, enable_mkldnn=False):
"""
@@ -65,7 +67,9 @@ def _initialize(self, use_gpu=False, enable_mkldnn=False):
self.text_detector = TextDetector(cfg)
- def merge_configs(self, ):
+ def merge_configs(
+ self,
+ ):
# default cfg
backup_argv = copy.deepcopy(sys.argv)
sys.argv = sys.argv[:1]
@@ -82,8 +86,9 @@ def merge_configs(self, ):
def read_images(self, paths=[]):
images = []
for img_path in paths:
- assert os.path.isfile(
- img_path), "The {} isn't a valid file.".format(img_path)
+ assert os.path.isfile(img_path), "The {} isn't a valid file.".format(
+ img_path
+ )
img = cv2.imread(img_path)
if img is None:
logger.info("error in loading image:{}".format(img_path))
@@ -108,7 +113,9 @@ def predict(self, images=[], paths=[]):
else:
raise TypeError("The input data is inconsistent with expectations.")
- assert predicted_data != [], "There is not any image to be predicted. Please check the input data."
+ assert (
+ predicted_data != []
+ ), "There is not any image to be predicted. Please check the input data."
all_results = []
for img in predicted_data:
@@ -121,9 +128,9 @@ def predict(self, images=[], paths=[]):
rec_res_final = []
for dno in range(len(dt_boxes)):
- rec_res_final.append({
- 'text_region': dt_boxes[dno].astype(np.int32).tolist()
- })
+ rec_res_final.append(
+ {"text_region": dt_boxes[dno].astype(np.int32).tolist()}
+ )
all_results.append(rec_res_final)
return all_results
@@ -137,12 +144,12 @@ def serving_method(self, images, **kwargs):
return results
-if __name__ == '__main__':
+if __name__ == "__main__":
ocr = OCRDet()
ocr._initialize()
image_path = [
- './doc/imgs/11.jpg',
- './doc/imgs/12.jpg',
+ "./doc/imgs/11.jpg",
+ "./doc/imgs/12.jpg",
]
res = ocr.predict(paths=image_path)
print(res)
diff --git a/deploy/hubserving/ocr_det/params.py b/deploy/hubserving/ocr_det/params.py
index c4aa13817f..2f83e4e1dd 100755
--- a/deploy/hubserving/ocr_det/params.py
+++ b/deploy/hubserving/ocr_det/params.py
@@ -24,13 +24,13 @@ class Config(object):
def read_params():
cfg = Config()
- #params for text detector
+ # params for text detector
cfg.det_algorithm = "DB"
cfg.det_model_dir = "./inference/ch_PP-OCRv3_det_infer/"
cfg.det_limit_side_len = 960
- cfg.det_limit_type = 'max'
+ cfg.det_limit_type = "max"
- #DB parmas
+ # DB params
cfg.det_db_thresh = 0.3
cfg.det_db_box_thresh = 0.6
cfg.det_db_unclip_ratio = 1.5
diff --git a/deploy/hubserving/ocr_rec/__init__.py b/deploy/hubserving/ocr_rec/__init__.py
index c747d3e7ae..97043fd7ba 100644
--- a/deploy/hubserving/ocr_rec/__init__.py
+++ b/deploy/hubserving/ocr_rec/__init__.py
@@ -10,4 +10,4 @@
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
-# limitations under the License.
\ No newline at end of file
+# limitations under the License.
diff --git a/deploy/hubserving/ocr_rec/module.py b/deploy/hubserving/ocr_rec/module.py
index 9fae54e2a3..22b7ef5681 100644
--- a/deploy/hubserving/ocr_rec/module.py
+++ b/deploy/hubserving/ocr_rec/module.py
@@ -18,6 +18,7 @@
import os
import sys
+
sys.path.insert(0, ".")
import copy
import paddlehub
@@ -38,7 +39,8 @@
summary="ocr recognition service",
author="paddle-dev",
author_email="paddle-dev@baidu.com",
- type="cv/text_recognition")
+ type="cv/text_recognition",
+)
class OCRRec(hub.Module):
def _initialize(self, use_gpu=False, enable_mkldnn=False):
"""
@@ -63,7 +65,9 @@ def _initialize(self, use_gpu=False, enable_mkldnn=False):
self.text_recognizer = TextRecognizer(cfg)
- def merge_configs(self, ):
+ def merge_configs(
+ self,
+ ):
# default cfg
backup_argv = copy.deepcopy(sys.argv)
sys.argv = sys.argv[:1]
@@ -80,8 +84,9 @@ def merge_configs(self, ):
def read_images(self, paths=[]):
images = []
for img_path in paths:
- assert os.path.isfile(
- img_path), "The {} isn't a valid file.".format(img_path)
+ assert os.path.isfile(img_path), "The {} isn't a valid file.".format(
+ img_path
+ )
img = cv2.imread(img_path)
if img is None:
logger.info("error in loading image:{}".format(img_path))
@@ -106,7 +111,9 @@ def predict(self, images=[], paths=[]):
else:
raise TypeError("The input data is inconsistent with expectations.")
- assert predicted_data != [], "There is not any image to be predicted. Please check the input data."
+ assert (
+ predicted_data != []
+ ), "There is not any image to be predicted. Please check the input data."
img_list = []
for img in predicted_data:
@@ -119,10 +126,12 @@ def predict(self, images=[], paths=[]):
rec_res, predict_time = self.text_recognizer(img_list)
for dno in range(len(rec_res)):
text, score = rec_res[dno]
- rec_res_final.append({
- 'text': text,
- 'confidence': float(score),
- })
+ rec_res_final.append(
+ {
+ "text": text,
+ "confidence": float(score),
+ }
+ )
except Exception as e:
print(e)
return [[]]
@@ -139,13 +148,13 @@ def serving_method(self, images, **kwargs):
return results
-if __name__ == '__main__':
+if __name__ == "__main__":
ocr = OCRRec()
ocr._initialize()
image_path = [
- './doc/imgs_words/ch/word_1.jpg',
- './doc/imgs_words/ch/word_2.jpg',
- './doc/imgs_words/ch/word_3.jpg',
+ "./doc/imgs_words/ch/word_1.jpg",
+ "./doc/imgs_words/ch/word_2.jpg",
+ "./doc/imgs_words/ch/word_3.jpg",
]
res = ocr.predict(paths=image_path)
print(res)
diff --git a/deploy/hubserving/ocr_rec/params.py b/deploy/hubserving/ocr_rec/params.py
index b8854c8c07..0496b4743d 100644
--- a/deploy/hubserving/ocr_rec/params.py
+++ b/deploy/hubserving/ocr_rec/params.py
@@ -24,7 +24,7 @@ class Config(object):
def read_params():
cfg = Config()
- #params for text recognizer
+ # params for text recognizer
cfg.rec_algorithm = "CRNN"
cfg.rec_model_dir = "./inference/ch_PP-OCRv3_rec_infer/"
diff --git a/deploy/hubserving/ocr_system/__init__.py b/deploy/hubserving/ocr_system/__init__.py
index c747d3e7ae..97043fd7ba 100644
--- a/deploy/hubserving/ocr_system/__init__.py
+++ b/deploy/hubserving/ocr_system/__init__.py
@@ -10,4 +10,4 @@
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
-# limitations under the License.
\ No newline at end of file
+# limitations under the License.
diff --git a/deploy/hubserving/ocr_system/module.py b/deploy/hubserving/ocr_system/module.py
index 192fff9650..fdcbe12e35 100644
--- a/deploy/hubserving/ocr_system/module.py
+++ b/deploy/hubserving/ocr_system/module.py
@@ -18,6 +18,7 @@
import os
import sys
+
sys.path.insert(0, ".")
import copy
@@ -41,7 +42,8 @@
summary="ocr system service",
author="paddle-dev",
author_email="paddle-dev@baidu.com",
- type="cv/PP-OCR_system")
+ type="cv/PP-OCR_system",
+)
class OCRSystem(hub.Module):
def _initialize(self, use_gpu=False, enable_mkldnn=False):
"""
@@ -66,7 +68,9 @@ def _initialize(self, use_gpu=False, enable_mkldnn=False):
self.text_sys = TextSystem(cfg)
- def merge_configs(self, ):
+ def merge_configs(
+ self,
+ ):
# default cfg
backup_argv = copy.deepcopy(sys.argv)
sys.argv = sys.argv[:1]
@@ -83,8 +87,9 @@ def merge_configs(self, ):
def read_images(self, paths=[]):
images = []
for img_path in paths:
- assert os.path.isfile(
- img_path), "The {} isn't a valid file.".format(img_path)
+ assert os.path.isfile(img_path), "The {} isn't a valid file.".format(
+ img_path
+ )
img = cv2.imread(img_path)
if img is None:
logger.info("error in loading image:{}".format(img_path))
@@ -109,7 +114,9 @@ def predict(self, images=[], paths=[]):
else:
raise TypeError("The input data is inconsistent with expectations.")
- assert predicted_data != [], "There is not any image to be predicted. Please check the input data."
+ assert (
+ predicted_data != []
+ ), "There is not any image to be predicted. Please check the input data."
all_results = []
for img in predicted_data:
@@ -127,11 +134,13 @@ def predict(self, images=[], paths=[]):
for dno in range(dt_num):
text, score = rec_res[dno]
- rec_res_final.append({
- 'text': text,
- 'confidence': float(score),
- 'text_region': dt_boxes[dno].astype(np.int32).tolist()
- })
+ rec_res_final.append(
+ {
+ "text": text,
+ "confidence": float(score),
+ "text_region": dt_boxes[dno].astype(np.int32).tolist(),
+ }
+ )
all_results.append(rec_res_final)
return all_results
@@ -145,12 +154,12 @@ def serving_method(self, images, **kwargs):
return results
-if __name__ == '__main__':
+if __name__ == "__main__":
ocr = OCRSystem()
ocr._initialize()
image_path = [
- './doc/imgs/11.jpg',
- './doc/imgs/12.jpg',
+ "./doc/imgs/11.jpg",
+ "./doc/imgs/12.jpg",
]
res = ocr.predict(paths=image_path)
print(res)
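Beyond the local __main__ smoke test, the @serving method above is meant to be reached over HTTP once the module runs under PaddleHub. A hedged sketch of such a request; the port and route follow PaddleHub's documented defaults and are assumptions here:

```python
# A hedged sketch of querying the hubserving endpoint; port 8866 and the
# /predict/ocr_system route are assumed PaddleHub defaults.
import base64
import json

import requests


def file_to_base64(path):
    # send raw image bytes; the server's base64_to_cv2 decodes them back
    with open(path, "rb") as f:
        return base64.b64encode(f.read()).decode("utf8")


data = {"images": [file_to_base64("./doc/imgs/11.jpg")]}
headers = {"Content-Type": "application/json"}
url = "http://127.0.0.1:8866/predict/ocr_system"

r = requests.post(url=url, headers=headers, data=json.dumps(data))
print(r.json()["results"])
```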
diff --git a/deploy/hubserving/ocr_system/params.py b/deploy/hubserving/ocr_system/params.py
index 4df1979a10..12b6283d74 100755
--- a/deploy/hubserving/ocr_system/params.py
+++ b/deploy/hubserving/ocr_system/params.py
@@ -24,25 +24,25 @@ class Config(object):
def read_params():
cfg = Config()
- #params for text detector
+ # params for text detector
cfg.det_algorithm = "DB"
cfg.det_model_dir = "./inference/ch_PP-OCRv3_det_infer/"
cfg.det_limit_side_len = 960
- cfg.det_limit_type = 'max'
+ cfg.det_limit_type = "max"
- #DB parmas
+ # DB params
cfg.det_db_thresh = 0.3
cfg.det_db_box_thresh = 0.5
cfg.det_db_unclip_ratio = 1.6
cfg.use_dilation = False
cfg.det_db_score_mode = "fast"
- #EAST parmas
+ # EAST params
cfg.det_east_score_thresh = 0.8
cfg.det_east_cover_thresh = 0.1
cfg.det_east_nms_thresh = 0.2
- #params for text recognizer
+ # params for text recognizer
cfg.rec_algorithm = "CRNN"
cfg.rec_model_dir = "./inference/ch_PP-OCRv3_rec_infer/"
@@ -53,11 +53,11 @@ def read_params():
cfg.rec_char_dict_path = "./ppocr/utils/ppocr_keys_v1.txt"
cfg.use_space_char = True
- #params for text classifier
+ # params for text classifier
cfg.use_angle_cls = True
cfg.cls_model_dir = "./inference/ch_ppocr_mobile_v2.0_cls_infer/"
cfg.cls_image_shape = "3, 48, 192"
- cfg.label_list = ['0', '180']
+ cfg.label_list = ["0", "180"]
cfg.cls_batch_num = 30
cfg.cls_thresh = 0.9
diff --git a/deploy/hubserving/structure_layout/__init__.py b/deploy/hubserving/structure_layout/__init__.py
index c747d3e7ae..97043fd7ba 100644
--- a/deploy/hubserving/structure_layout/__init__.py
+++ b/deploy/hubserving/structure_layout/__init__.py
@@ -10,4 +10,4 @@
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
-# limitations under the License.
\ No newline at end of file
+# limitations under the License.
diff --git a/deploy/hubserving/structure_layout/module.py b/deploy/hubserving/structure_layout/module.py
index 7091f123fc..4962bee8cb 100644
--- a/deploy/hubserving/structure_layout/module.py
+++ b/deploy/hubserving/structure_layout/module.py
@@ -18,6 +18,7 @@
import os
import sys
+
sys.path.insert(0, ".")
import copy
@@ -40,7 +41,8 @@
summary="PP-Structure layout service",
author="paddle-dev",
author_email="paddle-dev@baidu.com",
- type="cv/structure_layout")
+ type="cv/structure_layout",
+)
class LayoutPredictor(hub.Module):
def _initialize(self, use_gpu=False, enable_mkldnn=False):
"""
@@ -81,8 +83,9 @@ def merge_configs(self):
def read_images(self, paths=[]):
images = []
for img_path in paths:
- assert os.path.isfile(
- img_path), "The {} isn't a valid file.".format(img_path)
+ assert os.path.isfile(img_path), "The {} isn't a valid file.".format(
+ img_path
+ )
img = cv2.imread(img_path)
if img is None:
logger.info("error in loading image:{}".format(img_path))
@@ -107,7 +110,9 @@ def predict(self, images=[], paths=[]):
else:
raise TypeError("The input data is inconsistent with expectations.")
- assert predicted_data != [], "There is not any image to be predicted. Please check the input data."
+ assert (
+ predicted_data != []
+ ), "There is not any image to be predicted. Please check the input data."
all_results = []
for img in predicted_data:
@@ -121,8 +126,8 @@ def predict(self, images=[], paths=[]):
logger.info("Predict time: {}".format(elapse))
for item in res:
- item['bbox'] = item['bbox'].tolist()
- all_results.append({'layout': res})
+ item["bbox"] = item["bbox"].tolist()
+ all_results.append({"layout": res})
return all_results
@serving
@@ -135,9 +140,9 @@ def serving_method(self, images, **kwargs):
return results
-if __name__ == '__main__':
+if __name__ == "__main__":
layout = LayoutPredictor()
layout._initialize()
- image_path = ['./ppstructure/docs/table/1.png']
+ image_path = ["./ppstructure/docs/table/1.png"]
res = layout.predict(paths=image_path)
print(res)
diff --git a/deploy/hubserving/structure_layout/params.py b/deploy/hubserving/structure_layout/params.py
index 448b66ac42..5c0b2e8a37 100755
--- a/deploy/hubserving/structure_layout/params.py
+++ b/deploy/hubserving/structure_layout/params.py
@@ -25,8 +25,8 @@ def read_params():
cfg = Config()
# params for layout analysis
- cfg.layout_model_dir = './inference/picodet_lcnet_x1_0_fgd_layout_infer/'
- cfg.layout_dict_path = './ppocr/utils/dict/layout_dict/layout_publaynet_dict.txt'
+ cfg.layout_model_dir = "./inference/picodet_lcnet_x1_0_fgd_layout_infer/"
+ cfg.layout_dict_path = "./ppocr/utils/dict/layout_dict/layout_publaynet_dict.txt"
cfg.layout_score_threshold = 0.5
cfg.layout_nms_threshold = 0.5
return cfg
diff --git a/deploy/hubserving/structure_system/__init__.py b/deploy/hubserving/structure_system/__init__.py
index c747d3e7ae..97043fd7ba 100644
--- a/deploy/hubserving/structure_system/__init__.py
+++ b/deploy/hubserving/structure_system/__init__.py
@@ -10,4 +10,4 @@
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
-# limitations under the License.
\ No newline at end of file
+# limitations under the License.
diff --git a/deploy/hubserving/structure_system/module.py b/deploy/hubserving/structure_system/module.py
index 61c93bb146..35084683ce 100644
--- a/deploy/hubserving/structure_system/module.py
+++ b/deploy/hubserving/structure_system/module.py
@@ -18,6 +18,7 @@
import os
import sys
+
sys.path.insert(0, ".")
import copy
@@ -42,7 +43,8 @@
summary="PP-Structure system service",
author="paddle-dev",
author_email="paddle-dev@baidu.com",
- type="cv/structure_system")
+ type="cv/structure_system",
+)
class StructureSystem(hub.Module):
def _initialize(self, use_gpu=False, enable_mkldnn=False):
"""
@@ -84,8 +86,9 @@ def merge_configs(self):
def read_images(self, paths=[]):
images = []
for img_path in paths:
- assert os.path.isfile(
- img_path), "The {} isn't a valid file.".format(img_path)
+ assert os.path.isfile(img_path), "The {} isn't a valid file.".format(
+ img_path
+ )
img = cv2.imread(img_path)
if img is None:
logger.info("error in loading image:{}".format(img_path))
@@ -110,7 +113,9 @@ def predict(self, images=[], paths=[]):
else:
raise TypeError("The input data is inconsistent with expectations.")
- assert predicted_data != [], "There is not any image to be predicted. Please check the input data."
+ assert (
+ predicted_data != []
+ ), "There is not any image to be predicted. Please check the input data."
all_results = []
for img in predicted_data:
@@ -126,9 +131,9 @@ def predict(self, images=[], paths=[]):
# parse result
res_final = []
for region in res:
- region.pop('img')
+ region.pop("img")
res_final.append(region)
- all_results.append({'regions': res_final})
+ all_results.append({"regions": res_final})
return all_results
@serving
@@ -141,9 +146,9 @@ def serving_method(self, images, **kwargs):
return results
-if __name__ == '__main__':
+if __name__ == "__main__":
structure_system = StructureSystem()
structure_system._initialize()
- image_path = ['./ppstructure/docs/table/1.png']
+ image_path = ["./ppstructure/docs/table/1.png"]
res = structure_system.predict(paths=image_path)
print(res)
diff --git a/deploy/hubserving/structure_system/params.py b/deploy/hubserving/structure_system/params.py
index fe691fbc2d..c5d28de1ea 100755
--- a/deploy/hubserving/structure_system/params.py
+++ b/deploy/hubserving/structure_system/params.py
@@ -23,11 +23,11 @@ def read_params():
cfg = table_read_params()
# params for layout parser model
- cfg.layout_model_dir = ''
- cfg.layout_dict_path = './ppocr/utils/dict/layout_publaynet_dict.txt'
+ cfg.layout_model_dir = ""
+ cfg.layout_dict_path = "./ppocr/utils/dict/layout_publaynet_dict.txt"
cfg.layout_score_threshold = 0.5
cfg.layout_nms_threshold = 0.5
- cfg.mode = 'structure'
- cfg.output = './output'
+ cfg.mode = "structure"
+ cfg.output = "./output"
return cfg
diff --git a/deploy/hubserving/structure_table/__init__.py b/deploy/hubserving/structure_table/__init__.py
index c747d3e7ae..97043fd7ba 100644
--- a/deploy/hubserving/structure_table/__init__.py
+++ b/deploy/hubserving/structure_table/__init__.py
@@ -10,4 +10,4 @@
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
-# limitations under the License.
\ No newline at end of file
+# limitations under the License.
diff --git a/deploy/hubserving/structure_table/module.py b/deploy/hubserving/structure_table/module.py
index b4432b2d7b..230b4043e2 100644
--- a/deploy/hubserving/structure_table/module.py
+++ b/deploy/hubserving/structure_table/module.py
@@ -18,6 +18,7 @@
import os
import sys
+
sys.path.insert(0, ".")
import copy
@@ -42,7 +43,8 @@
summary="PP-Structure table service",
author="paddle-dev",
author_email="paddle-dev@baidu.com",
- type="cv/structure_table")
+ type="cv/structure_table",
+)
class TableSystem(hub.Module):
def _initialize(self, use_gpu=False, enable_mkldnn=False):
"""
@@ -83,8 +85,9 @@ def merge_configs(self):
def read_images(self, paths=[]):
images = []
for img_path in paths:
- assert os.path.isfile(
- img_path), "The {} isn't a valid file.".format(img_path)
+ assert os.path.isfile(img_path), "The {} isn't a valid file.".format(
+ img_path
+ )
img = cv2.imread(img_path)
if img is None:
logger.info("error in loading image:{}".format(img_path))
@@ -109,7 +112,9 @@ def predict(self, images=[], paths=[]):
else:
raise TypeError("The input data is inconsistent with expectations.")
- assert predicted_data != [], "There is not any image to be predicted. Please check the input data."
+ assert (
+ predicted_data != []
+ ), "There is not any image to be predicted. Please check the input data."
all_results = []
for img in predicted_data:
@@ -122,7 +127,7 @@ def predict(self, images=[], paths=[]):
elapse = time.time() - starttime
logger.info("Predict time: {}".format(elapse))
- all_results.append({'html': res['html']})
+ all_results.append({"html": res["html"]})
return all_results
@serving
@@ -135,9 +140,9 @@ def serving_method(self, images, **kwargs):
return results
-if __name__ == '__main__':
+if __name__ == "__main__":
table_system = TableSystem()
table_system._initialize()
- image_path = ['./ppstructure/docs/table/table.jpg']
+ image_path = ["./ppstructure/docs/table/table.jpg"]
res = table_system.predict(paths=image_path)
print(res)
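These hub modules receive images over HTTP as base64 strings and decode them back into OpenCV arrays before calling `predict`. A sketch of that decode step; the helper name `base64_to_cv2` mirrors the convention used by the modules, but the body here is illustrative:

```python
import base64

import cv2
import numpy as np


def base64_to_cv2(b64str):
    """Decode a base64-encoded image string into a BGR OpenCV array."""
    data = base64.b64decode(b64str.encode("utf8"))
    buf = np.frombuffer(data, dtype=np.uint8)
    return cv2.imdecode(buf, cv2.IMREAD_COLOR)
```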
diff --git a/deploy/hubserving/structure_table/params.py b/deploy/hubserving/structure_table/params.py
index 9632c2f70b..1b0bf0ca89 100755
--- a/deploy/hubserving/structure_table/params.py
+++ b/deploy/hubserving/structure_table/params.py
@@ -24,7 +24,7 @@ def read_params():
# params for table structure model
cfg.table_max_len = 488
- cfg.table_model_dir = './inference/en_ppocr_mobile_v2.0_table_structure_infer/'
- cfg.table_char_dict_path = './ppocr/utils/dict/table_structure_dict.txt'
+ cfg.table_model_dir = "./inference/en_ppocr_mobile_v2.0_table_structure_infer/"
+ cfg.table_char_dict_path = "./ppocr/utils/dict/table_structure_dict.txt"
cfg.show_log = False
return cfg
diff --git a/deploy/lite/ocr_db_crnn.cc b/deploy/lite/ocr_db_crnn.cc
index fde0d07d6c..46ff22dff8 100644
--- a/deploy/lite/ocr_db_crnn.cc
+++ b/deploy/lite/ocr_db_crnn.cc
@@ -12,14 +12,14 @@
// See the License for the specific language governing permissions and
// limitations under the License.
-#include <chrono>
#include "paddle_api.h" // NOLINT
#include "paddle_place.h"
+#include <chrono>
+#include "AutoLog/auto_log/lite_autolog.h"
#include "cls_process.h"
#include "crnn_process.h"
#include "db_post_process.h"
-#include "AutoLog/auto_log/lite_autolog.h"
using namespace paddle::lite_api; // NOLINT
using namespace std;
@@ -161,8 +161,7 @@ void RunRecModel(std::vector<std::vector<std::vector<int>>> boxes, cv::Mat img,
 std::vector<float> &rec_text_score,
 std::vector<std::string> charactor_dict,
 std::shared_ptr<PaddlePredictor> predictor_cls,
- int use_direction_classify,
- std::vector<double> *times,
+ int use_direction_classify, std::vector<double> *times,
 int rec_image_height) {
 std::vector<float> mean = {0.5f, 0.5f, 0.5f};
 std::vector<float> scale = {1 / 0.5f, 1 / 0.5f, 1 / 0.5f};
@@ -234,18 +233,20 @@ void RunRecModel(std::vector>> boxes, cv::Mat img,
rec_text_score.push_back(score);
auto postprocess_end = std::chrono::steady_clock::now();
- std::chrono::duration<float> preprocess_diff = preprocess_end - preprocess_start;
+ std::chrono::duration<float> preprocess_diff =
+ preprocess_end - preprocess_start;
 time_info[0] += double(preprocess_diff.count() * 1000);
- std::chrono::duration<float> inference_diff = inference_end - inference_start;
+ std::chrono::duration<float> inference_diff =
+ inference_end - inference_start;
 time_info[1] += double(inference_diff.count() * 1000);
- std::chrono::duration<float> postprocess_diff = postprocess_end - postprocess_start;
+ std::chrono::duration<float> postprocess_diff =
+ postprocess_end - postprocess_start;
 time_info[2] += double(postprocess_diff.count() * 1000);
-
}
-times->push_back(time_info[0]);
-times->push_back(time_info[1]);
-times->push_back(time_info[2]);
+ times->push_back(time_info[0]);
+ times->push_back(time_info[1]);
+ times->push_back(time_info[2]);
}
std::vector<std::vector<std::vector<int>>>
@@ -257,7 +258,7 @@ RunDetModel(std::shared_ptr<PaddlePredictor> predictor, cv::Mat img,
cv::Mat srcimg;
img.copyTo(srcimg);
-
+
auto preprocess_start = std::chrono::steady_clock::now();
std::vector<float> ratio_hw;
img = DetResizeImg(img, max_side_len, ratio_hw);
@@ -318,17 +319,20 @@ RunDetModel(std::shared_ptr predictor, cv::Mat img,
FilterTagDetRes(boxes, ratio_hw[0], ratio_hw[1], srcimg);
auto postprocess_end = std::chrono::steady_clock::now();
- std::chrono::duration<float> preprocess_diff = preprocess_end - preprocess_start;
+ std::chrono::duration<float> preprocess_diff =
+ preprocess_end - preprocess_start;
 times->push_back(double(preprocess_diff.count() * 1000));
 std::chrono::duration<float> inference_diff = inference_end - inference_start;
 times->push_back(double(inference_diff.count() * 1000));
- std::chrono::duration<float> postprocess_diff = postprocess_end - postprocess_start;
+ std::chrono::duration<float> postprocess_diff =
+ postprocess_end - postprocess_start;
 times->push_back(double(postprocess_diff.count() * 1000));
return filter_boxes;
}
-std::shared_ptr<PaddlePredictor> loadModel(std::string model_file, int num_threads) {
+std::shared_ptr<PaddlePredictor> loadModel(std::string model_file,
+ int num_threads) {
MobileConfig config;
config.set_model_from_file(model_file);
@@ -393,36 +397,45 @@ std::map<std::string, double> LoadConfigTxt(std::string config_path) {
}
void check_params(int argc, char **argv) {
- if (argc<=1 || (strcmp(argv[1], "det")!=0 && strcmp(argv[1], "rec")!=0 && strcmp(argv[1], "system")!=0)) {
+ if (argc <= 1 ||
+ (strcmp(argv[1], "det") != 0 && strcmp(argv[1], "rec") != 0 &&
+ strcmp(argv[1], "system") != 0)) {
std::cerr << "Please choose one mode of [det, rec, system] !" << std::endl;
exit(1);
}
if (strcmp(argv[1], "det") == 0) {
- if (argc < 9){
- std::cerr << "[ERROR] usage:" << argv[0]
- << " det det_model runtime_device num_threads batchsize img_dir det_config lite_benchmark_value" << std::endl;
- exit(1);
- }
+ if (argc < 9) {
+ std::cerr << "[ERROR] usage:" << argv[0]
+ << " det det_model runtime_device num_threads batchsize "
+ "img_dir det_config lite_benchmark_value"
+ << std::endl;
+ exit(1);
+ }
}
if (strcmp(argv[1], "rec") == 0) {
- if (argc < 9){
- std::cerr << "[ERROR] usage:" << argv[0]
- << " rec rec_model runtime_device num_threads batchsize img_dir key_txt lite_benchmark_value" << std::endl;
- exit(1);
- }
+ if (argc < 9) {
+ std::cerr << "[ERROR] usage:" << argv[0]
+ << " rec rec_model runtime_device num_threads batchsize "
+ "img_dir key_txt lite_benchmark_value"
+ << std::endl;
+ exit(1);
+ }
}
if (strcmp(argv[1], "system") == 0) {
- if (argc < 12){
- std::cerr << "[ERROR] usage:" << argv[0]
- << " system det_model rec_model clas_model runtime_device num_threads batchsize img_dir det_config key_txt lite_benchmark_value" << std::endl;
- exit(1);
- }
+ if (argc < 12) {
+ std::cerr << "[ERROR] usage:" << argv[0]
+ << " system det_model rec_model clas_model runtime_device "
+ "num_threads batchsize img_dir det_config key_txt "
+ "lite_benchmark_value"
+ << std::endl;
+ exit(1);
+ }
}
}
-void system(char **argv){
+void system(char **argv) {
std::string det_model_file = argv[2];
std::string rec_model_file = argv[3];
std::string cls_model_file = argv[4];
@@ -435,8 +448,8 @@ void system(char **argv){
std::string dict_path = argv[11];
if (strcmp(argv[6], "FP32") != 0 && strcmp(argv[6], "INT8") != 0) {
- std::cerr << "Only support FP32 or INT8." << std::endl;
- exit(1);
+ std::cerr << "Only support FP32 or INT8." << std::endl;
+ exit(1);
}
std::vector<cv::String> cv_all_img_names;
@@ -462,28 +475,29 @@ void system(char **argv){
cv::Mat srcimg = cv::imread(cv_all_img_names[i], cv::IMREAD_COLOR);
if (!srcimg.data) {
- std::cerr << "[ERROR] image read failed! image path: " << cv_all_img_names[i] << std::endl;
+ std::cerr << "[ERROR] image read failed! image path: "
+ << cv_all_img_names[i] << std::endl;
exit(1);
}
std::vector<double> det_times;
auto boxes = RunDetModel(det_predictor, srcimg, Config, &det_times);
-
+
 std::vector<std::string> rec_text;
 std::vector<float> rec_text_score;
-
+
 std::vector<double> rec_times;
RunRecModel(boxes, srcimg, rec_predictor, rec_text, rec_text_score,
- charactor_dict, cls_predictor, use_direction_classify, &rec_times, rec_image_height);
-
+ charactor_dict, cls_predictor, use_direction_classify,
+ &rec_times, rec_image_height);
+
//// visualization
auto img_vis = Visualization(srcimg, boxes);
-
+
//// print recognized text
for (int i = 0; i < rec_text.size(); i++) {
std::cout << i << "\t" << rec_text[i] << "\t" << rec_text_score[i]
- << std::endl;
-
+ << std::endl;
}
det_time_info[0] += det_times[0];
@@ -494,22 +508,14 @@ void system(char **argv){
rec_time_info[2] += rec_times[2];
}
if (strcmp(argv[12], "True") == 0) {
- AutoLogger autolog_det(det_model_file,
- runtime_device,
- std::stoi(num_threads),
- std::stoi(batchsize),
- "dynamic",
- precision,
- det_time_info,
- cv_all_img_names.size());
- AutoLogger autolog_rec(rec_model_file,
- runtime_device,
- std::stoi(num_threads),
- std::stoi(batchsize),
- "dynamic",
- precision,
- rec_time_info,
- cv_all_img_names.size());
+ AutoLogger autolog_det(det_model_file, runtime_device,
+ std::stoi(num_threads), std::stoi(batchsize),
+ "dynamic", precision, det_time_info,
+ cv_all_img_names.size());
+ AutoLogger autolog_rec(rec_model_file, runtime_device,
+ std::stoi(num_threads), std::stoi(batchsize),
+ "dynamic", precision, rec_time_info,
+ cv_all_img_names.size());
autolog_det.report();
std::cout << std::endl;
@@ -527,8 +533,8 @@ void det(int argc, char **argv) {
std::string det_config_path = argv[8];
if (strcmp(argv[4], "FP32") != 0 && strcmp(argv[4], "INT8") != 0) {
- std::cerr << "Only support FP32 or INT8." << std::endl;
- exit(1);
+ std::cerr << "Only support FP32 or INT8." << std::endl;
+ exit(1);
}
std::vector<cv::String> cv_all_img_names;
@@ -545,7 +551,8 @@ void det(int argc, char **argv) {
cv::Mat srcimg = cv::imread(cv_all_img_names[i], cv::IMREAD_COLOR);
if (!srcimg.data) {
- std::cerr << "[ERROR] image read failed! image path: " << cv_all_img_names[i] << std::endl;
+ std::cerr << "[ERROR] image read failed! image path: "
+ << cv_all_img_names[i] << std::endl;
exit(1);
}
@@ -556,10 +563,10 @@ void det(int argc, char **argv) {
auto img_vis = Visualization(srcimg, boxes);
std::cout << boxes.size() << " bboxes have detected:" << std::endl;
- for (int i = 0; i < boxes.size(); i++) {
 std::vector<int> upper_left = {0, 0};
 std::vector<int> upper_right = {width, 0};
 std::vector<int> lower_right = {width, height};
- std::vector<int> lower_left = {0, height};
- std::vector<std::vector<int>> box = {upper_left, upper_right, lower_right, lower_left};
+ std::vector<int> lower_left = {0, height};
+ std::vector<std::vector<int>> box = {upper_left, upper_right, lower_right,
+ lower_left};
 std::vector<std::vector<std::vector<int>>> boxes = {box};
std::vector rec_text;
@@ -636,7 +640,7 @@ void rec(int argc, char **argv) {
std::vector<double> times;
RunRecModel(boxes, srcimg, rec_predictor, rec_text, rec_text_score,
charactor_dict, cls_predictor, 0, &times, rec_image_height);
-
+
//// print recognized text
for (int i = 0; i < rec_text.size(); i++) {
std::cout << i << "\t" << rec_text[i] << "\t" << rec_text_score[i]
@@ -648,13 +652,8 @@ void rec(int argc, char **argv) {
}
// TODO: support autolog
if (strcmp(argv[9], "True") == 0) {
- AutoLogger autolog(rec_model_file,
- runtime_device,
- std::stoi(num_threads),
- std::stoi(batchsize),
- "dynamic",
- precision,
- time_info,
+ AutoLogger autolog(rec_model_file, runtime_device, std::stoi(num_threads),
+ std::stoi(batchsize), "dynamic", precision, time_info,
cv_all_img_names.size());
autolog.report();
}
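Most of the C++ churn above is clang-format rewrapping of the per-stage timing code: preprocess, inference, and postprocess durations are accumulated into a three-slot `time_info` vector that later feeds `AutoLogger`. The same bookkeeping expressed in Python, as an illustrative sketch:

```python
import time


def timed_stages(preprocess, infer, postprocess, data, time_info):
    """Run three pipeline stages, accumulating per-stage latency in ms."""
    t0 = time.time()
    x = preprocess(data)
    t1 = time.time()
    y = infer(x)
    t2 = time.time()
    out = postprocess(y)
    t3 = time.time()
    time_info[0] += (t1 - t0) * 1000  # preprocess ms
    time_info[1] += (t2 - t1) * 1000  # inference ms
    time_info[2] += (t3 - t2) * 1000  # postprocess ms
    return out
```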
diff --git a/deploy/paddlecloud/README.md b/deploy/paddlecloud/README.md
index f96e50edbd..1ff49c7a02 100644
--- a/deploy/paddlecloud/README.md
+++ b/deploy/paddlecloud/README.md
@@ -1,6 +1,6 @@
# PaddleCloud Deployment Tool
-[PaddleCloud (云上飞桨)](https://github.com/PaddlePaddle/PaddleCloud) is a deployment tool for the PaddlePaddle framework and its model suites,
+[PaddleCloud (云上飞桨)](https://github.com/PaddlePaddle/PaddleCloud) is a deployment tool for the PaddlePaddle framework and its model suites,
offering both Dockerized deployment of model suites and Kubernetes cluster deployment to cover different scenarios and environments.
In this section, we use the standard OCR image and the cloud-native components provided by PaddleCloud to train and deploy a PP-OCRv3 recognition model.
@@ -146,7 +146,7 @@ $ tar xf /home/PaddleOCR/pre_train/ch_PP-OCRv3_det_distill_train.tar -C /home/Pa
Start training. The trained model is saved under the `output` directory by default; here we load the PP-OCRv3 detection pretrained model.
-```bash
+```bash
# GPU training is shown here; to train on CPU instead, set Global.use_gpu=false
python3 tools/train.py -c configs/det/ch_PP-OCRv3/ch_PP-OCRv3_det_cml.yml -o Global.save_model_dir=./output/ Global.pretrained_model=./pre_train/ch_PP-OCRv3_det_distill_train/best_accuracy
```
@@ -163,10 +163,10 @@ python3 -m paddle.distributed.launch --log_dir=./debug/ --gpus '0,1,2,3' tools/t
The models saved during training live in the `output` directory and include the following files:
```
-best_accuracy.states
+best_accuracy.states
best_accuracy.pdparams # model parameters with the best accuracy (saved by default)
best_accuracy.pdopt # optimizer state for the best-accuracy model (saved by default)
-latest.states
+latest.states
latest.pdparams # the most recent model parameters (saved by default)
latest.pdopt # optimizer state for the latest model (saved by default)
```
@@ -336,4 +336,4 @@ $ kubectl logs -f ppocrv3-worker-0 -n paddlecloud
## More Resources
Follow the [PaddleCloud project](https://github.com/PaddlePaddle/PaddleCloud): it provides standard images for PaddlePaddle model suites and a full stack of cloud-native deployment components. If you have any questions about deploying the model suites, please contact us.
-If you find any problems in PaddleCloud or have suggestions, feel free to open an issue via [GitHub Issues](https://github.com/PaddlePaddle/PaddleCloud/issues).
\ No newline at end of file
+If you find any problems in PaddleCloud or have suggestions, feel free to open an issue via [GitHub Issues](https://github.com/PaddlePaddle/PaddleCloud/issues).
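Since the README distinguishes `.pdparams` (model weights) from `.pdopt` (optimizer state), here is a brief sketch of how the two checkpoint files listed above are typically consumed; the exact resume flow in the tutorial may differ:

```python
import paddle

# Weights alone are enough for inference or fine-tuning from best_accuracy.
state_dict = paddle.load("output/best_accuracy.pdparams")

# Resuming training additionally restores the optimizer state.
opt_state = paddle.load("output/latest.pdopt")
```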
diff --git a/deploy/pdserving/general_detection_op.cpp b/deploy/pdserving/general_detection_op.cpp
index 7d9182950b..131a5e18cb 100644
--- a/deploy/pdserving/general_detection_op.cpp
+++ b/deploy/pdserving/general_detection_op.cpp
@@ -32,12 +32,12 @@ namespace paddle_serving {
namespace serving {
using baidu::paddle_serving::Timer;
-using baidu::paddle_serving::predictor::MempoolWrapper;
-using baidu::paddle_serving::predictor::general_model::Tensor;
-using baidu::paddle_serving::predictor::general_model::Response;
-using baidu::paddle_serving::predictor::general_model::Request;
using baidu::paddle_serving::predictor::InferManager;
+using baidu::paddle_serving::predictor::MempoolWrapper;
using baidu::paddle_serving::predictor::PaddleGeneralModelConfig;
+using baidu::paddle_serving::predictor::general_model::Request;
+using baidu::paddle_serving::predictor::general_model::Response;
+using baidu::paddle_serving::predictor::general_model::Tensor;
int GeneralDetectionOp::inference() {
VLOG(2) << "Going to run inference";
diff --git a/deploy/pdserving/ocr_cpp_client.py b/deploy/pdserving/ocr_cpp_client.py
index 3aaf031559..3f58c9274d 100755
--- a/deploy/pdserving/ocr_cpp_client.py
+++ b/deploy/pdserving/ocr_cpp_client.py
@@ -31,18 +31,18 @@
client.connect(["127.0.0.1:8181"])
import paddle
+
test_img_dir = "../../doc/imgs/1.jpg"
ocr_reader = OCRReader(char_dict_path="../../ppocr/utils/ppocr_keys_v1.txt")
def cv2_to_base64(image):
- return base64.b64encode(image).decode(
- 'utf8') #data.tostring()).decode('utf8')
+ return base64.b64encode(image).decode("utf8") # data.tostring()).decode('utf8')
def _check_image_file(path):
- img_end = {'jpg', 'bmp', 'png', 'jpeg', 'rgb', 'tif', 'tiff', 'gif'}
+ img_end = {"jpg", "bmp", "png", "jpeg", "rgb", "tif", "tiff", "gif"}
return any([path.lower().endswith(e) for e in img_end])
@@ -58,26 +58,25 @@ def _check_image_file(path):
raise Exception("not found any img file in {}".format(test_img_dir))
for img_file in test_img_list:
- with open(img_file, 'rb') as file:
+ with open(img_file, "rb") as file:
image_data = file.read()
image = cv2_to_base64(image_data)
res_list = []
fetch_map = client.predict(feed={"x": image}, fetch=[], batch=True)
if fetch_map is None:
- print('no results')
+ print("no results")
else:
if "text" in fetch_map:
for x in fetch_map["text"]:
x = codecs.encode(x)
- words = base64.b64decode(x).decode('utf-8')
+ words = base64.b64decode(x).decode("utf-8")
res_list.append(words)
else:
try:
- one_batch_res = ocr_reader.postprocess(
- fetch_map, with_score=True)
+ one_batch_res = ocr_reader.postprocess(fetch_map, with_score=True)
for res in one_batch_res:
res_list.append(res[0])
except:
- print('no results')
+ print("no results")
res = {"res": str(res_list)}
print(res)
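The client script base64-encodes raw image bytes before handing them to the serving client. A compact sketch of that round trip; the feed key `"x"` matches the script above, and the image path is just an example:

```python
import base64


def cv2_to_base64(image_bytes):
    """Encode raw image file bytes as a UTF-8 base64 string."""
    return base64.b64encode(image_bytes).decode("utf8")


with open("../../doc/imgs/1.jpg", "rb") as f:
    image = cv2_to_base64(f.read())
# The serving call then looks like:
# fetch_map = client.predict(feed={"x": image}, fetch=[], batch=True)
```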
diff --git a/deploy/pdserving/ocr_reader.py b/deploy/pdserving/ocr_reader.py
index d488cc0920..81d90b755e 100644
--- a/deploy/pdserving/ocr_reader.py
+++ b/deploy/pdserving/ocr_reader.py
@@ -27,18 +27,18 @@ class DetResizeForTest(object):
def __init__(self, **kwargs):
super(DetResizeForTest, self).__init__()
self.resize_type = 0
- if 'image_shape' in kwargs:
- self.image_shape = kwargs['image_shape']
+ if "image_shape" in kwargs:
+ self.image_shape = kwargs["image_shape"]
self.resize_type = 1
- elif 'limit_side_len' in kwargs:
- self.limit_side_len = kwargs['limit_side_len']
- self.limit_type = kwargs.get('limit_type', 'min')
- elif 'resize_short' in kwargs:
+ elif "limit_side_len" in kwargs:
+ self.limit_side_len = kwargs["limit_side_len"]
+ self.limit_type = kwargs.get("limit_type", "min")
+ elif "resize_short" in kwargs:
self.limit_side_len = 736
- self.limit_type = 'min'
+ self.limit_type = "min"
else:
self.resize_type = 2
- self.resize_long = kwargs.get('resize_long', 960)
+ self.resize_long = kwargs.get("resize_long", 960)
def __call__(self, data):
img = deepcopy(data)
@@ -73,14 +73,14 @@ def resize_image_type0(self, img):
h, w, _ = img.shape
# limit the max side
- if self.limit_type == 'max':
+ if self.limit_type == "max":
if max(h, w) > limit_side_len:
if h > w:
ratio = float(limit_side_len) / h
else:
ratio = float(limit_side_len) / w
else:
- ratio = 1.
+ ratio = 1.0
else:
if min(h, w) < limit_side_len:
if h < w:
@@ -88,7 +88,7 @@ def resize_image_type0(self, img):
else:
ratio = float(limit_side_len) / w
else:
- ratio = 1.
+ ratio = 1.0
resize_h = int(h * ratio)
resize_w = int(w * ratio)
@@ -133,20 +133,48 @@ def resize_image_type2(self, img):
class BaseRecLabelDecode(object):
- """ Convert between text-label and text-index """
+ """Convert between text-label and text-index"""
def __init__(self, config):
support_character_type = [
- 'ch', 'en', 'EN_symbol', 'french', 'german', 'japan', 'korean',
- 'it', 'xi', 'pu', 'ru', 'ar', 'ta', 'ug', 'fa', 'ur', 'rs', 'oc',
- 'rsc', 'bg', 'uk', 'be', 'te', 'ka', 'chinese_cht', 'hi', 'mr',
- 'ne', 'EN'
+ "ch",
+ "en",
+ "EN_symbol",
+ "french",
+ "german",
+ "japan",
+ "korean",
+ "it",
+ "xi",
+ "pu",
+ "ru",
+ "ar",
+ "ta",
+ "ug",
+ "fa",
+ "ur",
+ "rs",
+ "oc",
+ "rsc",
+ "bg",
+ "uk",
+ "be",
+ "te",
+ "ka",
+ "chinese_cht",
+ "hi",
+ "mr",
+ "ne",
+ "EN",
]
- character_type = config['character_type']
- character_dict_path = config['character_dict_path']
+ character_type = config["character_type"]
+ character_dict_path = config["character_dict_path"]
use_space_char = True
- assert character_type in support_character_type, "Only {} are supported now but get {}".format(
- support_character_type, character_type)
+ assert (
+ character_type in support_character_type
+ ), "Only {} are supported now but get {}".format(
+ support_character_type, character_type
+ )
self.beg_str = "sos"
self.end_str = "eos"
@@ -160,12 +188,15 @@ def __init__(self, config):
dict_character = list(self.character_str)
elif character_type in support_character_type:
self.character_str = ""
- assert character_dict_path is not None, "character_dict_path should not be None when character_type is {}".format(
- character_type)
+ assert (
+ character_dict_path is not None
+ ), "character_dict_path should not be None when character_type is {}".format(
+ character_type
+ )
with open(character_dict_path, "rb") as fin:
lines = fin.readlines()
for line in lines:
- line = line.decode('utf-8').strip("\n").strip("\r\n")
+ line = line.decode("utf-8").strip("\n").strip("\r\n")
self.character_str += line
if use_space_char:
self.character_str += " "
@@ -184,7 +215,7 @@ def add_special_char(self, dict_character):
return dict_character
def decode(self, text_index, text_prob=None, is_remove_duplicate=False):
- """ convert text-index into text-label. """
+ """convert text-index into text-label."""
result_list = []
ignored_tokens = self.get_ignored_tokens()
batch_size = len(text_index)
@@ -196,16 +227,17 @@ def decode(self, text_index, text_prob=None, is_remove_duplicate=False):
continue
if is_remove_duplicate:
# only for predict
- if idx > 0 and text_index[batch_idx][idx - 1] == text_index[
- batch_idx][idx]:
+ if (
+ idx > 0
+ and text_index[batch_idx][idx - 1] == text_index[batch_idx][idx]
+ ):
continue
- char_list.append(self.character[int(text_index[batch_idx][
- idx])])
+ char_list.append(self.character[int(text_index[batch_idx][idx])])
if text_prob is not None:
conf_list.append(text_prob[batch_idx][idx])
else:
conf_list.append(1)
- text = ''.join(char_list)
+ text = "".join(char_list)
result_list.append((text, np.mean(conf_list)))
return result_list
@@ -214,15 +246,16 @@ def get_ignored_tokens(self):
class CTCLabelDecode(BaseRecLabelDecode):
- """ Convert between text-label and text-index """
+ """Convert between text-label and text-index"""
def __init__(
- self,
- config,
- #character_dict_path=None,
- #character_type='ch',
- #use_space_char=False,
- **kwargs):
+ self,
+ config,
+ # character_dict_path=None,
+ # character_type='ch',
+ # use_space_char=False,
+ **kwargs
+ ):
super(CTCLabelDecode, self).__init__(config)
def __call__(self, preds, label=None, *args, **kwargs):
@@ -235,26 +268,26 @@ def __call__(self, preds, label=None, *args, **kwargs):
return text, label
def add_special_char(self, dict_character):
- dict_character = ['blank'] + dict_character
+ dict_character = ["blank"] + dict_character
return dict_character
class CharacterOps(object):
- """ Convert between text-label and text-index """
+ """Convert between text-label and text-index"""
def __init__(self, config):
- self.character_type = config['character_type']
- self.loss_type = config['loss_type']
+ self.character_type = config["character_type"]
+ self.loss_type = config["loss_type"]
if self.character_type == "en":
self.character_str = "0123456789abcdefghijklmnopqrstuvwxyz"
dict_character = list(self.character_str)
elif self.character_type == "ch":
- character_dict_path = config['character_dict_path']
+ character_dict_path = config["character_dict_path"]
self.character_str = ""
with open(character_dict_path, "rb") as fin:
lines = fin.readlines()
for line in lines:
- line = line.decode('utf-8').strip("\n").strip("\r\n")
+ line = line.decode("utf-8").strip("\n").strip("\r\n")
self.character_str += line
dict_character = list(self.character_str)
elif self.character_type == "en_sensitive":
@@ -263,8 +296,9 @@ def __init__(self, config):
dict_character = list(self.character_str)
else:
self.character_str = None
- assert self.character_str is not None, \
- "Nonsupport type of the character: {}".format(self.character_str)
+ assert (
+ self.character_str is not None
+ ), "Nonsupport type of the character: {}".format(self.character_str)
self.beg_str = "sos"
self.end_str = "eos"
if self.loss_type == "attention":
@@ -296,7 +330,7 @@ def encode(self, text):
return text
def decode(self, text_index, is_remove_duplicate=False):
- """ convert text-index into text-label. """
+ """convert text-index into text-label."""
char_list = []
char_num = self.get_char_num()
@@ -314,7 +348,7 @@ def decode(self, text_index, is_remove_duplicate=False):
if idx > 0 and text_index[idx - 1] == text_index[idx]:
continue
char_list.append(self.character[text_index[idx]])
- text = ''.join(char_list)
+ text = "".join(char_list)
return text
def get_char_num(self):
@@ -327,29 +361,31 @@ def get_beg_end_flag_idx(self, beg_or_end):
elif beg_or_end == "end":
idx = np.array(self.dict[self.end_str])
else:
- assert False, "Unsupport type %s in get_beg_end_flag_idx"\
- % beg_or_end
+ assert False, "Unsupport type %s in get_beg_end_flag_idx" % beg_or_end
return idx
else:
- err = "error in get_beg_end_flag_idx when using the loss %s"\
- % (self.loss_type)
+ err = "error in get_beg_end_flag_idx when using the loss %s" % (
+ self.loss_type
+ )
assert False, err
class OCRReader(object):
- def __init__(self,
- algorithm="CRNN",
- image_shape=[3, 48, 320],
- char_type="ch",
- batch_num=1,
- char_dict_path="./ppocr_keys_v1.txt"):
+ def __init__(
+ self,
+ algorithm="CRNN",
+ image_shape=[3, 48, 320],
+ char_type="ch",
+ batch_num=1,
+ char_dict_path="./ppocr_keys_v1.txt",
+ ):
self.rec_image_shape = image_shape
self.character_type = char_type
self.rec_batch_num = batch_num
char_ops_params = {}
char_ops_params["character_type"] = char_type
char_ops_params["character_dict_path"] = char_dict_path
- char_ops_params['loss_type'] = 'ctc'
+ char_ops_params["loss_type"] = "ctc"
self.char_ops = CharacterOps(char_ops_params)
self.label_ops = CTCLabelDecode(char_ops_params)
@@ -365,7 +401,7 @@ def resize_norm_img(self, img, max_wh_ratio):
else:
resized_w = int(math.ceil(imgH * ratio))
resized_image = cv2.resize(img, (resized_w, imgH))
- resized_image = resized_image.astype('float32')
+ resized_image = resized_image.astype("float32")
resized_image = resized_image.transpose((2, 0, 1)) / 255
resized_image -= 0.5
resized_image /= 0.5
@@ -377,7 +413,7 @@ def resize_norm_img(self, img, max_wh_ratio):
def preprocess(self, img_list):
img_num = len(img_list)
norm_img_batch = []
- max_wh_ratio = 320/48.
+ max_wh_ratio = 320 / 48.0
for ino in range(img_num):
h, w = img_list[ino].shape[0:2]
wh_ratio = w * 1.0 / h
@@ -400,8 +436,7 @@ def postprocess(self, outputs, with_score=False):
pass
preds_idx = preds.argmax(axis=2)
preds_prob = preds.max(axis=2)
- text = self.label_ops.decode(
- preds_idx, preds_prob, is_remove_duplicate=True)
+ text = self.label_ops.decode(preds_idx, preds_prob, is_remove_duplicate=True)
return text
@@ -411,16 +446,13 @@ def postprocess(self, outputs, with_score=False):
class ArgsParser(ArgumentParser):
def __init__(self):
- super(ArgsParser, self).__init__(
- formatter_class=RawDescriptionHelpFormatter)
+ super(ArgsParser, self).__init__(formatter_class=RawDescriptionHelpFormatter)
self.add_argument("-c", "--config", help="configuration file to use")
- self.add_argument(
- "-o", "--opt", nargs='+', help="set configuration options")
+ self.add_argument("-o", "--opt", nargs="+", help="set configuration options")
def parse_args(self, argv=None):
args = super(ArgsParser, self).parse_args(argv)
- assert args.config is not None, \
- "Please specify --config=configure_file_path."
+ assert args.config is not None, "Please specify --config=configure_file_path."
args.conf_dict = self._parse_opt(args.opt, args.config)
print("args config:", args.conf_dict)
return args
@@ -432,7 +464,7 @@ def _parse_helper(self, v):
else:
v = int(v)
elif v == "True" or v == "False":
- v = (v == "True")
+ v = v == "True"
return v
def _parse_opt(self, opts, conf_path):
@@ -442,7 +474,7 @@ def _parse_opt(self, opts, conf_path):
return config
for s in opts:
s = s.strip()
- k, v = s.split('=')
+ k, v = s.split("=")
v = self._parse_helper(v)
print(k, v, type(v))
cur = config
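The `decode` method rewrapped above is greedy CTC decoding: argmax per timestep, collapse consecutive repeats, and drop the blank token that `add_special_char` prepends at index 0. A standalone sketch of the same idea (the character table is supplied by the caller):

```python
import numpy as np


def ctc_greedy_decode(preds, character, blank=0):
    """Greedy CTC decode: argmax per step, collapse repeats, drop blanks.

    preds: (T, num_classes) array of per-timestep class probabilities.
    character: index-to-character table with the blank at index 0.
    """
    indices = preds.argmax(axis=1)
    chars, confs = [], []
    prev = None
    for t, idx in enumerate(indices):
        if idx != blank and idx != prev:
            chars.append(character[int(idx)])
            confs.append(preds[t, idx])
        prev = idx
    return "".join(chars), float(np.mean(confs)) if confs else 0.0
```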
diff --git a/deploy/pdserving/pipeline_http_client.py b/deploy/pdserving/pipeline_http_client.py
index 0a86a63987..4ab4a46e4f 100644
--- a/deploy/pdserving/pipeline_http_client.py
+++ b/deploy/pdserving/pipeline_http_client.py
@@ -33,11 +33,11 @@ def str2bool(v):
def cv2_to_base64(image):
- return base64.b64encode(image).decode('utf8')
+ return base64.b64encode(image).decode("utf8")
def _check_image_file(path):
- img_end = {'jpg', 'bmp', 'png', 'jpeg', 'rgb', 'tif', 'tiff', 'gif'}
+ img_end = {"jpg", "bmp", "png", "jpeg", "rgb", "tif", "tiff", "gif"}
return any([path.lower().endswith(e) for e in img_end])
@@ -56,10 +56,10 @@ def _check_image_file(path):
raise Exception("not found any img file in {}".format(test_img_dir))
for idx, img_file in enumerate(test_img_list):
- with open(img_file, 'rb') as file:
+ with open(img_file, "rb") as file:
image_data1 = file.read()
# print file name
- print('{}{}{}'.format('*' * 10, img_file, '*' * 10))
+ print("{}{}{}".format("*" * 10, img_file, "*" * 10))
image = cv2_to_base64(image_data1)
@@ -83,7 +83,5 @@ def _check_image_file(path):
continue
else:
- print(
- "For details about error message, see PipelineServingLogs/pipeline.log"
- )
+ print("For details about error message, see PipelineServingLogs/pipeline.log")
print("==> total number of test imgs: ", len(test_img_list))
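For reference, the pipeline HTTP client wraps the base64 string in a JSON body and posts it to the pipeline service. A hedged sketch of that request; the endpoint and the `key`/`value` layout follow the usual PaddleServing pipeline convention and are assumptions here:

```python
import base64
import json

import requests

url = "http://127.0.0.1:9998/ocr/prediction"  # assumed pipeline endpoint

with open("../../doc/imgs/1.jpg", "rb") as f:
    image = base64.b64encode(f.read()).decode("utf8")

# Pipeline services take parallel key/value lists in the request body.
data = {"key": ["image"], "value": [image]}
r = requests.post(url=url, data=json.dumps(data))
print(r.json())
```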
diff --git a/deploy/pdserving/pipeline_rpc_client.py b/deploy/pdserving/pipeline_rpc_client.py
index 3d2a90f443..14f7a1c997 100644
--- a/deploy/pdserving/pipeline_rpc_client.py
+++ b/deploy/pdserving/pipeline_rpc_client.py
@@ -23,21 +23,22 @@
import os
client = PipelineClient()
-client.connect(['127.0.0.1:18091'])
+client.connect(["127.0.0.1:18091"])
def cv2_to_base64(image):
- return base64.b64encode(image).decode('utf8')
+ return base64.b64encode(image).decode("utf8")
import argparse
+
parser = argparse.ArgumentParser(description="args for paddleserving")
parser.add_argument("--image_dir", type=str, default="../../doc/imgs/")
args = parser.parse_args()
test_img_dir = args.image_dir
for img_file in os.listdir(test_img_dir):
- with open(os.path.join(test_img_dir, img_file), 'rb') as file:
+ with open(os.path.join(test_img_dir, img_file), "rb") as file:
image_data = file.read()
image = cv2_to_base64(image_data)
diff --git a/deploy/pdserving/web_service.py b/deploy/pdserving/web_service.py
index b6fadb91d5..27c70b072f 100644
--- a/deploy/pdserving/web_service.py
+++ b/deploy/pdserving/web_service.py
@@ -18,34 +18,45 @@
import copy
import cv2
import base64
+
# from paddle_serving_app.reader import OCRReader
from ocr_reader import OCRReader, DetResizeForTest, ArgsParser
from paddle_serving_app.reader import Sequential, ResizeByFactor
from paddle_serving_app.reader import Div, Normalize, Transpose
-from paddle_serving_app.reader import DBPostProcess, FilterBoxes, GetRotateCropImage, SortedBoxes
+from paddle_serving_app.reader import (
+ DBPostProcess,
+ FilterBoxes,
+ GetRotateCropImage,
+ SortedBoxes,
+)
_LOGGER = logging.getLogger()
class DetOp(Op):
def init_op(self):
- self.det_preprocess = Sequential([
- DetResizeForTest(), Div(255),
- Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]), Transpose(
- (2, 0, 1))
- ])
+ self.det_preprocess = Sequential(
+ [
+ DetResizeForTest(),
+ Div(255),
+ Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
+ Transpose((2, 0, 1)),
+ ]
+ )
self.filter_func = FilterBoxes(10, 10)
- self.post_func = DBPostProcess({
- "thresh": 0.3,
- "box_thresh": 0.6,
- "max_candidates": 1000,
- "unclip_ratio": 1.5,
- "min_size": 3
- })
+ self.post_func = DBPostProcess(
+ {
+ "thresh": 0.3,
+ "box_thresh": 0.6,
+ "max_candidates": 1000,
+ "unclip_ratio": 1.5,
+ "min_size": 3,
+ }
+ )
def preprocess(self, input_dicts, data_id, log_id):
- (_, input_dict), = input_dicts.items()
- data = base64.b64decode(input_dict["image"].encode('utf8'))
+ ((_, input_dict),) = input_dicts.items()
+ data = base64.b64decode(input_dict["image"].encode("utf8"))
self.raw_im = data
data = np.fromstring(data, np.uint8)
# Note: class variables(self.var) can only be used in process op mode
@@ -57,9 +68,7 @@ def preprocess(self, input_dicts, data_id, log_id):
def postprocess(self, input_dicts, fetch_dict, data_id, log_id):
det_out = list(fetch_dict.values())[0]
- ratio_list = [
- float(self.new_h) / self.ori_h, float(self.new_w) / self.ori_w
- ]
+ ratio_list = [float(self.new_h) / self.ori_h, float(self.new_w) / self.ori_w]
dt_boxes_list = self.post_func(det_out, [ratio_list])
dt_boxes = self.filter_func(dt_boxes_list[0], [self.ori_h, self.ori_w])
out_dict = {"dt_boxes": dt_boxes, "image": self.raw_im}
@@ -69,13 +78,14 @@ def postprocess(self, input_dicts, fetch_dict, data_id, log_id):
class RecOp(Op):
def init_op(self):
self.ocr_reader = OCRReader(
- char_dict_path="../../ppocr/utils/ppocr_keys_v1.txt")
+ char_dict_path="../../ppocr/utils/ppocr_keys_v1.txt"
+ )
self.get_rotate_crop_image = GetRotateCropImage()
self.sorted_boxes = SortedBoxes()
def preprocess(self, input_dicts, data_id, log_id):
- (_, input_dict), = input_dicts.items()
+ ((_, input_dict),) = input_dicts.items()
raw_im = input_dict["image"]
data = np.frombuffer(raw_im, np.uint8)
im = cv2.imdecode(data, cv2.IMREAD_COLOR)
@@ -85,7 +95,7 @@ def preprocess(self, input_dicts, data_id, log_id):
dt_boxes = copy.deepcopy(self.dt_list)
feed_list = []
img_list = []
- max_wh_ratio = 320 / 48.
+ max_wh_ratio = 320 / 48.0
## Many mini-batchs, the type of feed_data is list.
max_batch_size = 6 # len(dt_boxes)
@@ -106,8 +116,11 @@ def preprocess(self, input_dicts, data_id, log_id):
elif bt_idx < batch_size:
boxes_num_in_one_batch = max_batch_size
else:
- _LOGGER.error("batch_size error, bt_idx={}, batch_size={}".
- format(bt_idx, batch_size))
+ _LOGGER.error(
+ "batch_size error, bt_idx={}, batch_size={}".format(
+ bt_idx, batch_size
+ )
+ )
break
start = bt_idx * max_batch_size
@@ -119,10 +132,9 @@ def preprocess(self, input_dicts, data_id, log_id):
h, w = boximg.shape[0:2]
wh_ratio = w * 1.0 / h
max_wh_ratio = max(max_wh_ratio, wh_ratio)
- _, w, h = self.ocr_reader.resize_norm_img(img_list[0],
- max_wh_ratio).shape
+ _, w, h = self.ocr_reader.resize_norm_img(img_list[0], max_wh_ratio).shape
- imgs = np.zeros((boxes_num_in_one_batch, 3, w, h)).astype('float32')
+ imgs = np.zeros((boxes_num_in_one_batch, 3, w, h)).astype("float32")
for id, img in enumerate(img_list):
norm_img = self.ocr_reader.resize_norm_img(img, max_wh_ratio)
imgs[id] = norm_img
@@ -135,14 +147,12 @@ def postprocess(self, input_dicts, fetch_data, data_id, log_id):
dt_num = len(self.dt_list)
if isinstance(fetch_data, dict):
if len(fetch_data) > 0:
- rec_batch_res = self.ocr_reader.postprocess(
- fetch_data, with_score=True)
+ rec_batch_res = self.ocr_reader.postprocess(fetch_data, with_score=True)
for res in rec_batch_res:
rec_list.append(res)
elif isinstance(fetch_data, list):
for one_batch in fetch_data:
- one_batch_res = self.ocr_reader.postprocess(
- one_batch, with_score=True)
+ one_batch_res = self.ocr_reader.postprocess(one_batch, with_score=True)
for res in one_batch_res:
rec_list.append(res)
result_list = []
diff --git a/deploy/pdserving/web_service_det.py b/deploy/pdserving/web_service_det.py
index 4a62ab861d..26ed8f3a81 100644
--- a/deploy/pdserving/web_service_det.py
+++ b/deploy/pdserving/web_service_det.py
@@ -17,34 +17,45 @@
import numpy as np
import cv2
import base64
+
# from paddle_serving_app.reader import OCRReader
from ocr_reader import OCRReader, DetResizeForTest, ArgsParser
from paddle_serving_app.reader import Sequential, ResizeByFactor
from paddle_serving_app.reader import Div, Normalize, Transpose
-from paddle_serving_app.reader import DBPostProcess, FilterBoxes, GetRotateCropImage, SortedBoxes
+from paddle_serving_app.reader import (
+ DBPostProcess,
+ FilterBoxes,
+ GetRotateCropImage,
+ SortedBoxes,
+)
_LOGGER = logging.getLogger()
class DetOp(Op):
def init_op(self):
- self.det_preprocess = Sequential([
- DetResizeForTest(), Div(255),
- Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]), Transpose(
- (2, 0, 1))
- ])
+ self.det_preprocess = Sequential(
+ [
+ DetResizeForTest(),
+ Div(255),
+ Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
+ Transpose((2, 0, 1)),
+ ]
+ )
self.filter_func = FilterBoxes(10, 10)
- self.post_func = DBPostProcess({
- "thresh": 0.3,
- "box_thresh": 0.5,
- "max_candidates": 1000,
- "unclip_ratio": 1.5,
- "min_size": 3
- })
+ self.post_func = DBPostProcess(
+ {
+ "thresh": 0.3,
+ "box_thresh": 0.5,
+ "max_candidates": 1000,
+ "unclip_ratio": 1.5,
+ "min_size": 3,
+ }
+ )
def preprocess(self, input_dicts, data_id, log_id):
- (_, input_dict), = input_dicts.items()
- data = base64.b64decode(input_dict["image"].encode('utf8'))
+ ((_, input_dict),) = input_dicts.items()
+ data = base64.b64decode(input_dict["image"].encode("utf8"))
self.raw_im = data
data = np.fromstring(data, np.uint8)
# Note: class variables(self.var) can only be used in process op mode
@@ -56,9 +67,7 @@ def preprocess(self, input_dicts, data_id, log_id):
def postprocess(self, input_dicts, fetch_dict, data_id, log_id):
det_out = list(fetch_dict.values())[0]
- ratio_list = [
- float(self.new_h) / self.ori_h, float(self.new_w) / self.ori_w
- ]
+ ratio_list = [float(self.new_h) / self.ori_h, float(self.new_w) / self.ori_w]
dt_boxes_list = self.post_func(det_out, [ratio_list])
dt_boxes = self.filter_func(dt_boxes_list[0], [self.ori_h, self.ori_w])
out_dict = {"dt_boxes": str(dt_boxes)}
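The `DBPostProcess` settings rewrapped in this file control how the detection probability map is turned into boxes. The same configuration with editorial comments on what each knob does:

```python
db_postprocess_params = {
    "thresh": 0.3,           # binarization threshold on the probability map
    "box_thresh": 0.5,       # minimum mean score for a box to be kept
    "max_candidates": 1000,  # cap on candidate contours considered
    "unclip_ratio": 1.5,     # how far boxes are expanded from shrunk text cores
    "min_size": 3,           # discard boxes smaller than this many pixels
}
```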
diff --git a/deploy/pdserving/web_service_rec.py b/deploy/pdserving/web_service_rec.py
index c4720d0818..a748295ba9 100644
--- a/deploy/pdserving/web_service_rec.py
+++ b/deploy/pdserving/web_service_rec.py
@@ -17,6 +17,7 @@
import numpy as np
import cv2
import base64
+
# from paddle_serving_app.reader import OCRReader
from ocr_reader import OCRReader, DetResizeForTest, ArgsParser
from paddle_serving_app.reader import Sequential, ResizeByFactor
@@ -28,11 +29,12 @@
class RecOp(Op):
def init_op(self):
self.ocr_reader = OCRReader(
- char_dict_path="../../ppocr/utils/ppocr_keys_v1.txt")
+ char_dict_path="../../ppocr/utils/ppocr_keys_v1.txt"
+ )
def preprocess(self, input_dicts, data_id, log_id):
- (_, input_dict), = input_dicts.items()
- raw_im = base64.b64decode(input_dict["image"].encode('utf8'))
+ ((_, input_dict),) = input_dicts.items()
+ raw_im = base64.b64decode(input_dict["image"].encode("utf8"))
data = np.fromstring(raw_im, np.uint8)
im = cv2.imdecode(data, cv2.IMREAD_COLOR)
feed_list = []
@@ -60,14 +62,12 @@ def postprocess(self, input_dicts, fetch_data, data_id, log_id):
res_list = []
if isinstance(fetch_data, dict):
if len(fetch_data) > 0:
- rec_batch_res = self.ocr_reader.postprocess(
- fetch_data, with_score=True)
+ rec_batch_res = self.ocr_reader.postprocess(fetch_data, with_score=True)
for res in rec_batch_res:
res_list.append(res[0])
elif isinstance(fetch_data, list):
for one_batch in fetch_data:
- one_batch_res = self.ocr_reader.postprocess(
- one_batch, with_score=True)
+ one_batch_res = self.ocr_reader.postprocess(one_batch, with_score=True)
for res in one_batch_res:
res_list.append(res[0])
diff --git a/deploy/pdserving/win/ocr_reader.py b/deploy/pdserving/win/ocr_reader.py
index 18b9385aa0..6b2a4f8078 100644
--- a/deploy/pdserving/win/ocr_reader.py
+++ b/deploy/pdserving/win/ocr_reader.py
@@ -27,18 +27,18 @@ class DetResizeForTest(object):
def __init__(self, **kwargs):
super(DetResizeForTest, self).__init__()
self.resize_type = 0
- if 'image_shape' in kwargs:
- self.image_shape = kwargs['image_shape']
+ if "image_shape" in kwargs:
+ self.image_shape = kwargs["image_shape"]
self.resize_type = 1
- elif 'limit_side_len' in kwargs:
- self.limit_side_len = kwargs['limit_side_len']
- self.limit_type = kwargs.get('limit_type', 'min')
- elif 'resize_short' in kwargs:
+ elif "limit_side_len" in kwargs:
+ self.limit_side_len = kwargs["limit_side_len"]
+ self.limit_type = kwargs.get("limit_type", "min")
+ elif "resize_short" in kwargs:
self.limit_side_len = 736
- self.limit_type = 'min'
+ self.limit_type = "min"
else:
self.resize_type = 2
- self.resize_long = kwargs.get('resize_long', 960)
+ self.resize_long = kwargs.get("resize_long", 960)
def __call__(self, data):
img = deepcopy(data)
@@ -73,14 +73,14 @@ def resize_image_type0(self, img):
h, w, _ = img.shape
# limit the max side
- if self.limit_type == 'max':
+ if self.limit_type == "max":
if max(h, w) > limit_side_len:
if h > w:
ratio = float(limit_side_len) / h
else:
ratio = float(limit_side_len) / w
else:
- ratio = 1.
+ ratio = 1.0
else:
if min(h, w) < limit_side_len:
if h < w:
@@ -88,7 +88,7 @@ def resize_image_type0(self, img):
else:
ratio = float(limit_side_len) / w
else:
- ratio = 1.
+ ratio = 1.0
resize_h = int(h * ratio)
resize_w = int(w * ratio)
@@ -133,20 +133,48 @@ def resize_image_type2(self, img):
class BaseRecLabelDecode(object):
- """ Convert between text-label and text-index """
+ """Convert between text-label and text-index"""
def __init__(self, config):
support_character_type = [
- 'ch', 'en', 'EN_symbol', 'french', 'german', 'japan', 'korean',
- 'it', 'xi', 'pu', 'ru', 'ar', 'ta', 'ug', 'fa', 'ur', 'rs', 'oc',
- 'rsc', 'bg', 'uk', 'be', 'te', 'ka', 'chinese_cht', 'hi', 'mr',
- 'ne', 'EN'
+ "ch",
+ "en",
+ "EN_symbol",
+ "french",
+ "german",
+ "japan",
+ "korean",
+ "it",
+ "xi",
+ "pu",
+ "ru",
+ "ar",
+ "ta",
+ "ug",
+ "fa",
+ "ur",
+ "rs",
+ "oc",
+ "rsc",
+ "bg",
+ "uk",
+ "be",
+ "te",
+ "ka",
+ "chinese_cht",
+ "hi",
+ "mr",
+ "ne",
+ "EN",
]
- character_type = config['character_type']
- character_dict_path = config['character_dict_path']
+ character_type = config["character_type"]
+ character_dict_path = config["character_dict_path"]
use_space_char = True
- assert character_type in support_character_type, "Only {} are supported now but get {}".format(
- support_character_type, character_type)
+ assert (
+ character_type in support_character_type
+ ), "Only {} are supported now but get {}".format(
+ support_character_type, character_type
+ )
self.beg_str = "sos"
self.end_str = "eos"
@@ -160,12 +188,15 @@ def __init__(self, config):
dict_character = list(self.character_str)
elif character_type in support_character_type:
self.character_str = ""
- assert character_dict_path is not None, "character_dict_path should not be None when character_type is {}".format(
- character_type)
+ assert (
+ character_dict_path is not None
+ ), "character_dict_path should not be None when character_type is {}".format(
+ character_type
+ )
with open(character_dict_path, "rb") as fin:
lines = fin.readlines()
for line in lines:
- line = line.decode('utf-8').strip("\n").strip("\r\n")
+ line = line.decode("utf-8").strip("\n").strip("\r\n")
self.character_str += line
if use_space_char:
self.character_str += " "
@@ -184,7 +215,7 @@ def add_special_char(self, dict_character):
return dict_character
def decode(self, text_index, text_prob=None, is_remove_duplicate=False):
- """ convert text-index into text-label. """
+ """convert text-index into text-label."""
result_list = []
ignored_tokens = self.get_ignored_tokens()
batch_size = len(text_index)
@@ -196,16 +227,17 @@ def decode(self, text_index, text_prob=None, is_remove_duplicate=False):
continue
if is_remove_duplicate:
# only for predict
- if idx > 0 and text_index[batch_idx][idx - 1] == text_index[
- batch_idx][idx]:
+ if (
+ idx > 0
+ and text_index[batch_idx][idx - 1] == text_index[batch_idx][idx]
+ ):
continue
- char_list.append(self.character[int(text_index[batch_idx][
- idx])])
+ char_list.append(self.character[int(text_index[batch_idx][idx])])
if text_prob is not None:
conf_list.append(text_prob[batch_idx][idx])
else:
conf_list.append(1)
- text = ''.join(char_list)
+ text = "".join(char_list)
result_list.append((text, np.mean(conf_list)))
return result_list
@@ -214,15 +246,16 @@ def get_ignored_tokens(self):
class CTCLabelDecode(BaseRecLabelDecode):
- """ Convert between text-label and text-index """
+ """Convert between text-label and text-index"""
def __init__(
- self,
- config,
- #character_dict_path=None,
- #character_type='ch',
- #use_space_char=False,
- **kwargs):
+ self,
+ config,
+ # character_dict_path=None,
+ # character_type='ch',
+ # use_space_char=False,
+ **kwargs
+ ):
super(CTCLabelDecode, self).__init__(config)
def __call__(self, preds, label=None, *args, **kwargs):
@@ -235,26 +268,26 @@ def __call__(self, preds, label=None, *args, **kwargs):
return text, label
def add_special_char(self, dict_character):
- dict_character = ['blank'] + dict_character
+ dict_character = ["blank"] + dict_character
return dict_character
class CharacterOps(object):
- """ Convert between text-label and text-index """
+ """Convert between text-label and text-index"""
def __init__(self, config):
- self.character_type = config['character_type']
- self.loss_type = config['loss_type']
+ self.character_type = config["character_type"]
+ self.loss_type = config["loss_type"]
if self.character_type == "en":
self.character_str = "0123456789abcdefghijklmnopqrstuvwxyz"
dict_character = list(self.character_str)
elif self.character_type == "ch":
- character_dict_path = config['character_dict_path']
+ character_dict_path = config["character_dict_path"]
self.character_str = ""
with open(character_dict_path, "rb") as fin:
lines = fin.readlines()
for line in lines:
- line = line.decode('utf-8').strip("\n").strip("\r\n")
+ line = line.decode("utf-8").strip("\n").strip("\r\n")
self.character_str += line
dict_character = list(self.character_str)
elif self.character_type == "en_sensitive":
@@ -263,8 +296,9 @@ def __init__(self, config):
dict_character = list(self.character_str)
else:
self.character_str = None
- assert self.character_str is not None, \
- "Nonsupport type of the character: {}".format(self.character_str)
+ assert (
+ self.character_str is not None
+ ), "Nonsupport type of the character: {}".format(self.character_str)
self.beg_str = "sos"
self.end_str = "eos"
if self.loss_type == "attention":
@@ -296,7 +330,7 @@ def encode(self, text):
return text
def decode(self, text_index, is_remove_duplicate=False):
- """ convert text-index into text-label. """
+ """convert text-index into text-label."""
char_list = []
char_num = self.get_char_num()
@@ -314,7 +348,7 @@ def decode(self, text_index, is_remove_duplicate=False):
if idx > 0 and text_index[idx - 1] == text_index[idx]:
continue
char_list.append(self.character[text_index[idx]])
- text = ''.join(char_list)
+ text = "".join(char_list)
return text
def get_char_num(self):
@@ -327,29 +361,31 @@ def get_beg_end_flag_idx(self, beg_or_end):
elif beg_or_end == "end":
idx = np.array(self.dict[self.end_str])
else:
- assert False, "Unsupport type %s in get_beg_end_flag_idx"\
- % beg_or_end
+ assert False, "Unsupport type %s in get_beg_end_flag_idx" % beg_or_end
return idx
else:
- err = "error in get_beg_end_flag_idx when using the loss %s"\
- % (self.loss_type)
+ err = "error in get_beg_end_flag_idx when using the loss %s" % (
+ self.loss_type
+ )
assert False, err
class OCRReader(object):
- def __init__(self,
- algorithm="CRNN",
- image_shape=[3, 32, 320],
- char_type="ch",
- batch_num=1,
- char_dict_path="./ppocr_keys_v1.txt"):
+ def __init__(
+ self,
+ algorithm="CRNN",
+ image_shape=[3, 32, 320],
+ char_type="ch",
+ batch_num=1,
+ char_dict_path="./ppocr_keys_v1.txt",
+ ):
self.rec_image_shape = image_shape
self.character_type = char_type
self.rec_batch_num = batch_num
char_ops_params = {}
char_ops_params["character_type"] = char_type
char_ops_params["character_dict_path"] = char_dict_path
- char_ops_params['loss_type'] = 'ctc'
+ char_ops_params["loss_type"] = "ctc"
self.char_ops = CharacterOps(char_ops_params)
self.label_ops = CTCLabelDecode(char_ops_params)
@@ -365,7 +401,7 @@ def resize_norm_img(self, img, max_wh_ratio):
else:
resized_w = int(math.ceil(imgH * ratio))
resized_image = cv2.resize(img, (resized_w, imgH))
- resized_image = resized_image.astype('float32')
+ resized_image = resized_image.astype("float32")
resized_image = resized_image.transpose((2, 0, 1)) / 255
resized_image -= 0.5
resized_image /= 0.5
@@ -400,6 +436,5 @@ def postprocess(self, outputs, with_score=False):
pass
preds_idx = preds.argmax(axis=2)
preds_prob = preds.max(axis=2)
- text = self.label_ops.decode(
- preds_idx, preds_prob, is_remove_duplicate=True)
+ text = self.label_ops.decode(preds_idx, preds_prob, is_remove_duplicate=True)
return text
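The `resize_norm_img` method touched above resizes a text crop to a fixed height while preserving aspect ratio (capped by `max_wh_ratio`), then normalizes pixels to [-1, 1] in CHW layout and pads the width. A self-contained sketch of that transform:

```python
import math

import cv2
import numpy as np


def resize_norm_img(img, max_wh_ratio, imgC=3, imgH=32):
    """Resize to fixed height, pad width, and normalize to [-1, 1] CHW."""
    imgW = int(imgH * max_wh_ratio)  # target width from the batch's max ratio
    h, w = img.shape[:2]
    ratio = w / float(h)
    resized_w = min(int(math.ceil(imgH * ratio)), imgW)
    resized = cv2.resize(img, (resized_w, imgH)).astype("float32")
    resized = resized.transpose((2, 0, 1)) / 255  # HWC -> CHW, scale to [0, 1]
    resized = (resized - 0.5) / 0.5               # shift to [-1, 1]
    padded = np.zeros((imgC, imgH, imgW), dtype=np.float32)
    padded[:, :, :resized_w] = resized            # right-pad narrow crops
    return padded
```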
diff --git a/deploy/pdserving/win/ocr_web_client.py b/deploy/pdserving/win/ocr_web_client.py
index a288529316..0332dc7a0d 100644
--- a/deploy/pdserving/win/ocr_web_client.py
+++ b/deploy/pdserving/win/ocr_web_client.py
@@ -22,9 +22,8 @@
def cv2_to_base64(image):
- #data = cv2.imencode('.jpg', image)[1]
- return base64.b64encode(image).decode(
- 'utf8') #data.tostring()).decode('utf8')
+ # data = cv2.imencode('.jpg', image)[1]
+ return base64.b64encode(image).decode("utf8") # data.tostring()).decode('utf8')
headers = {"Content-type": "application/json"}
@@ -32,7 +31,7 @@ def cv2_to_base64(image):
test_img_dir = "../../../doc/imgs/"
for idx, img_file in enumerate(os.listdir(test_img_dir)):
- with open(os.path.join(test_img_dir, img_file), 'rb') as file:
+ with open(os.path.join(test_img_dir, img_file), "rb") as file:
image_data1 = file.read()
image = cv2_to_base64(image_data1)
diff --git a/deploy/pdserving/win/ocr_web_server.py b/deploy/pdserving/win/ocr_web_server.py
index 9fc9490129..a89390afac 100644
--- a/deploy/pdserving/win/ocr_web_server.py
+++ b/deploy/pdserving/win/ocr_web_server.py
@@ -20,8 +20,14 @@
from paddle_serving_client import Client
from paddle_serving_app.reader import Sequential, URL2Image, ResizeByFactor
from paddle_serving_app.reader import Div, Normalize, Transpose
-from paddle_serving_app.reader import DBPostProcess, FilterBoxes, GetRotateCropImage, SortedBoxes
+from paddle_serving_app.reader import (
+ DBPostProcess,
+ FilterBoxes,
+ GetRotateCropImage,
+ SortedBoxes,
+)
from ocr_reader import OCRReader
+
try:
from paddle_serving_server_gpu.web_service import WebService
except ImportError:
@@ -34,22 +40,25 @@
class OCRService(WebService):
def init_det_debugger(self, det_model_config):
- self.det_preprocess = Sequential([
- ResizeByFactor(32, 960), Div(255),
- Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]), Transpose(
- (2, 0, 1))
- ])
+ self.det_preprocess = Sequential(
+ [
+ ResizeByFactor(32, 960),
+ Div(255),
+ Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
+ Transpose((2, 0, 1)),
+ ]
+ )
self.det_client = LocalPredictor()
- if sys.argv[1] == 'gpu':
- self.det_client.load_model_config(
- det_model_config, use_gpu=True, gpu_id=0)
- elif sys.argv[1] == 'cpu':
+ if sys.argv[1] == "gpu":
+ self.det_client.load_model_config(det_model_config, use_gpu=True, gpu_id=0)
+ elif sys.argv[1] == "cpu":
self.det_client.load_model_config(det_model_config)
self.ocr_reader = OCRReader(
- char_dict_path="../../../ppocr/utils/ppocr_keys_v1.txt")
+ char_dict_path="../../../ppocr/utils/ppocr_keys_v1.txt"
+ )
def preprocess(self, feed=[], fetch=[]):
- data = base64.b64decode(feed[0]["image"].encode('utf8'))
+ data = base64.b64decode(feed[0]["image"].encode("utf8"))
data = np.fromstring(data, np.uint8)
im = cv2.imdecode(data, cv2.IMREAD_COLOR)
ori_h, ori_w, _ = im.shape
@@ -58,18 +67,23 @@ def preprocess(self, feed=[], fetch=[]):
det_img = det_img[np.newaxis, :]
det_img = det_img.copy()
det_out = self.det_client.predict(
- feed={"x": det_img}, fetch=["save_infer_model/scale_0.tmp_1"], batch=True)
+ feed={"x": det_img}, fetch=["save_infer_model/scale_0.tmp_1"], batch=True
+ )
filter_func = FilterBoxes(10, 10)
- post_func = DBPostProcess({
- "thresh": 0.3,
- "box_thresh": 0.5,
- "max_candidates": 1000,
- "unclip_ratio": 1.5,
- "min_size": 3
- })
+ post_func = DBPostProcess(
+ {
+ "thresh": 0.3,
+ "box_thresh": 0.5,
+ "max_candidates": 1000,
+ "unclip_ratio": 1.5,
+ "min_size": 3,
+ }
+ )
sorted_boxes = SortedBoxes()
ratio_list = [float(new_h) / ori_h, float(new_w) / ori_w]
- dt_boxes_list = post_func(det_out["save_infer_model/scale_0.tmp_1"], [ratio_list])
+ dt_boxes_list = post_func(
+ det_out["save_infer_model/scale_0.tmp_1"], [ratio_list]
+ )
dt_boxes = filter_func(dt_boxes_list[0], [ori_h, ori_w])
dt_boxes = sorted_boxes(dt_boxes)
get_rotate_crop_image = GetRotateCropImage()
@@ -83,9 +97,8 @@ def preprocess(self, feed=[], fetch=[]):
max_wh_ratio = max(max_wh_ratio, wh_ratio)
if len(img_list) == 0:
return [], []
- _, w, h = self.ocr_reader.resize_norm_img(img_list[0],
- max_wh_ratio).shape
- imgs = np.zeros((len(img_list), 3, w, h)).astype('float32')
+ _, w, h = self.ocr_reader.resize_norm_img(img_list[0], max_wh_ratio).shape
+ imgs = np.zeros((len(img_list), 3, w, h)).astype("float32")
for id, img in enumerate(img_list):
norm_img = self.ocr_reader.resize_norm_img(img, max_wh_ratio)
imgs[id] = norm_img
@@ -106,9 +119,9 @@ def postprocess(self, feed={}, fetch=[], fetch_map=None):
ocr_service.load_model_config("../ppocr_rec_mobile_2.0_serving")
ocr_service.prepare_server(workdir="workdir", port=9292)
ocr_service.init_det_debugger(det_model_config="../ppocr_det_mobile_2.0_serving")
-if sys.argv[1] == 'gpu':
+if sys.argv[1] == "gpu":
ocr_service.set_gpus("0")
ocr_service.run_debugger_service(gpu=True)
-elif sys.argv[1] == 'cpu':
+elif sys.argv[1] == "cpu":
ocr_service.run_debugger_service()
ocr_service.run_web_service()
diff --git a/deploy/slim/auto_compression/README.md b/deploy/slim/auto_compression/README.md
index e6021415d9..408256e10e 100644
--- a/deploy/slim/auto_compression/README.md
+++ b/deploy/slim/auto_compression/README.md
@@ -91,7 +91,7 @@ pip install scikit-image imgaug
```shell
git clone -b release/2.7 https://github.com/PaddlePaddle/PaddleOCR.git
cd PaddleOCR/
-pip install -r requirements.txt
+pip install -r requirements.txt
```
### 3.2 Prepare the dataset
@@ -184,7 +184,7 @@ TensorRT预测环境配置:
##### 4.1.1 Batch testing with the test scripts:
We provide two scripts for testing the effect of automatic model compression, [test_ocr_det.sh](./test_ocr_det.sh) and [test_ocr_rec.sh](./test_ocr_rec.sh). Both accept a `model_type` argument that selects whether the mobile or the server model is tested; valid values are `mobile` and `server`. Example usage:
-
+
```shell
# test the mobile model
bash test_ocr_det.sh mobile
@@ -273,7 +273,7 @@ Eval:
name: SimpleDataSet
data_dir: datasets/v4_4_test_dataset
label_file_list:
- - datasets/v4_4_test_dataset/label.txt
+ - datasets/v4_4_test_dataset/label.txt
```
### 5.2 Why is accuracy so different across hardware when the software environment is identical?
@@ -291,7 +291,7 @@ if args.precision == 'int8' and "ppocrv4_det_server_qat_dist.yaml" in args.confi
use_static=True,
use_calib_mode=False, )
pred_cfg.exp_disable_tensorrt_ops(["elementwise_add"])
-else:
+else:
pred_cfg.enable_tensorrt_engine(
workspace_size=1 << 30,
max_batch_size=1,
diff --git a/deploy/slim/auto_compression/ppocrv4_det_server_dataset_process.py b/deploy/slim/auto_compression/ppocrv4_det_server_dataset_process.py
index 4186e25c98..43a70f35a9 100644
--- a/deploy/slim/auto_compression/ppocrv4_det_server_dataset_process.py
+++ b/deploy/slim/auto_compression/ppocrv4_det_server_dataset_process.py
@@ -1,33 +1,33 @@
-import os
-import cv2
-
-dataset_path = 'datasets/v4_4_test_dataset'
-annotation_file = 'datasets/v4_4_test_dataset/label.txt'
-
-small_images_path = 'datasets/v4_4_test_dataset_small'
-new_annotation_file = 'datasets/v4_4_test_dataset_small/label.txt'
-
-os.makedirs(small_images_path, exist_ok=True)
-
-with open(annotation_file, 'r') as f:
- lines = f.readlines()
-
-for i, line in enumerate(lines):
+import os
+import cv2
+
+dataset_path = "datasets/v4_4_test_dataset"
+annotation_file = "datasets/v4_4_test_dataset/label.txt"
+
+small_images_path = "datasets/v4_4_test_dataset_small"
+new_annotation_file = "datasets/v4_4_test_dataset_small/label.txt"
+
+os.makedirs(small_images_path, exist_ok=True)
+
+with open(annotation_file, "r") as f:
+ lines = f.readlines()
+
+for i, line in enumerate(lines):
image_name = line.split(" ")[0]
-
- image_path = os.path.join(dataset_path, image_name)
-
+
+ image_path = os.path.join(dataset_path, image_name)
+
try:
- image = cv2.imread(image_path)
- height, width, _ = image.shape
-
-        # If both width and height are under 2000 px and the aspect ratio is under 2, copy the image to the new folder and keep its annotation
- if height < 2000 and width < 2000:
- if max(height, width)/min(height,width) < 2:
+ image = cv2.imread(image_path)
+ height, width, _ = image.shape
+
+        # If both width and height are under 2000 px and the aspect ratio is under 2, copy the image to the new folder and keep its annotation
+ if height < 2000 and width < 2000:
+ if max(height, width) / min(height, width) < 2:
print(i, height, width, image_path)
- small_image_path = os.path.join(small_images_path, image_name)
+ small_image_path = os.path.join(small_images_path, image_name)
cv2.imwrite(small_image_path, image)
- with open(new_annotation_file, 'a') as f:
- f.write(f'{line}')
+ with open(new_annotation_file, "a") as f:
+ f.write(f"{line}")
except:
- continue
\ No newline at end of file
+ continue
diff --git a/deploy/slim/auto_compression/run.py b/deploy/slim/auto_compression/run.py
index 8042555786..32bdededb9 100644
--- a/deploy/slim/auto_compression/run.py
+++ b/deploy/slim/auto_compression/run.py
@@ -23,7 +23,8 @@
from paddleslim.common.dataloader import get_feed_vars
import sys
-sys.path.append('../../../')
+
+sys.path.append("../../../")
from ppocr.data import build_dataloader
from ppocr.postprocess import build_post_process
from ppocr.metrics import build_metric
@@ -34,21 +35,21 @@
def argsparser():
parser = argparse.ArgumentParser(description=__doc__)
parser.add_argument(
- '--config_path',
+ "--config_path",
type=str,
default=None,
help="path of compression strategy config.",
- required=True)
+ required=True,
+ )
parser.add_argument(
- '--save_dir',
+ "--save_dir",
type=str,
- default='output',
- help="directory to save compressed model.")
+ default="output",
+ help="directory to save compressed model.",
+ )
parser.add_argument(
- '--devices',
- type=str,
- default='gpu',
- help="which device used to compress.")
+ "--devices", type=str, default="gpu", help="which device used to compress."
+ )
return parser
@@ -56,7 +57,7 @@ def reader_wrapper(reader, input_name):
if isinstance(input_name, list) and len(input_name) == 1:
input_name = input_name[0]
-    def gen(): # build one dict-shaped feed per batch
+    def gen():  # build one dict-shaped feed per batch
for i, batch in enumerate(reader()):
yield {input_name: batch[0]}
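As a toy illustration of what the wrapper produces (hypothetical reader; the real dataloader yields batches whose first element is the image tensor), it turns positional batches into the dict feeds that AutoCompression consumes:

```python
def reader_wrapper(reader, input_name):
    if isinstance(input_name, list) and len(input_name) == 1:
        input_name = input_name[0]
    def gen():  # build one dict-shaped feed per batch
        for _, batch in enumerate(reader()):
            yield {input_name: batch[0]}
    return gen

toy_reader = lambda: iter([([1, 2], "label_a"), ([3, 4], "label_b")])
for feed in reader_wrapper(toy_reader, ["x"])():
    print(feed)  # {'x': [1, 2]} then {'x': [3, 4]}
```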
@@ -64,100 +65,106 @@ def gen(): # 形成一个字典输入
def eval_function(exe, compiled_test_program, test_feed_names, test_fetch_list):
- post_process_class = build_post_process(all_config['PostProcess'],
- global_config)
- eval_class = build_metric(all_config['Metric'])
- model_type = global_config['model_type']
+ post_process_class = build_post_process(all_config["PostProcess"], global_config)
+ eval_class = build_metric(all_config["Metric"])
+ model_type = global_config["model_type"]
with tqdm(
- total=len(val_loader),
- bar_format='Evaluation stage, Run batch:|{bar}| {n_fmt}/{total_fmt}',
- ncols=80) as t:
+ total=len(val_loader),
+ bar_format="Evaluation stage, Run batch:|{bar}| {n_fmt}/{total_fmt}",
+ ncols=80,
+ ) as t:
for batch_id, batch in enumerate(val_loader):
images = batch[0]
-
+
try:
- preds, = exe.run(compiled_test_program,
- feed={test_feed_names[0]: images},
- fetch_list=test_fetch_list)
+ (preds,) = exe.run(
+ compiled_test_program,
+ feed={test_feed_names[0]: images},
+ fetch_list=test_fetch_list,
+ )
except:
- preds, _ = exe.run(compiled_test_program,
- feed={test_feed_names[0]: images},
- fetch_list=test_fetch_list)
+ preds, _ = exe.run(
+ compiled_test_program,
+ feed={test_feed_names[0]: images},
+ fetch_list=test_fetch_list,
+ )
batch_numpy = []
for item in batch:
batch_numpy.append(np.array(item))
- if model_type == 'det':
- preds_map = {'maps': preds}
+ if model_type == "det":
+ preds_map = {"maps": preds}
post_result = post_process_class(preds_map, batch_numpy[1])
eval_class(post_result, batch_numpy)
- elif model_type == 'rec':
+ elif model_type == "rec":
post_result = post_process_class(preds, batch_numpy[1])
eval_class(post_result, batch_numpy)
t.update()
metric = eval_class.get_metric()
- logger.info('metric eval ***************')
+ logger.info("metric eval ***************")
for k, v in metric.items():
- logger.info('{}:{}'.format(k, v))
+ logger.info("{}:{}".format(k, v))
- if model_type == 'det':
- return metric['hmean']
- elif model_type == 'rec':
- return metric['acc']
+ if model_type == "det":
+ return metric["hmean"]
+ elif model_type == "rec":
+ return metric["acc"]
return metric
def main():
rank_id = paddle.distributed.get_rank()
- if args.devices == 'gpu':
+ if args.devices == "gpu":
place = paddle.CUDAPlace(rank_id)
- paddle.set_device('gpu')
+ paddle.set_device("gpu")
else:
place = paddle.CPUPlace()
- paddle.set_device('cpu')
+ paddle.set_device("cpu")
global all_config, global_config
all_config = load_slim_config(args.config_path)
- if "Global" not in all_config:
+ if "Global" not in all_config:
raise KeyError(f"Key 'Global' not found in config file. \n{all_config}")
global_config = all_config["Global"]
gpu_num = paddle.distributed.get_world_size()
- train_dataloader = build_dataloader(all_config, 'Train', args.devices,
- logger)
+ train_dataloader = build_dataloader(all_config, "Train", args.devices, logger)
global val_loader
- val_loader = build_dataloader(all_config, 'Eval', args.devices, logger)
-
- if isinstance(all_config['TrainConfig']['learning_rate'],
- dict) and all_config['TrainConfig']['learning_rate'][
- 'type'] == 'CosineAnnealingDecay':
- steps = len(train_dataloader) * all_config['TrainConfig']['epochs']
- all_config['TrainConfig']['learning_rate']['T_max'] = steps
- print('total training steps:', steps)
-
- global_config['input_name'] = get_feed_vars(
- global_config['model_dir'], global_config['model_filename'],
- global_config['params_filename'])
-
+ val_loader = build_dataloader(all_config, "Eval", args.devices, logger)
+
+ if (
+ isinstance(all_config["TrainConfig"]["learning_rate"], dict)
+ and all_config["TrainConfig"]["learning_rate"]["type"] == "CosineAnnealingDecay"
+ ):
+ steps = len(train_dataloader) * all_config["TrainConfig"]["epochs"]
+ all_config["TrainConfig"]["learning_rate"]["T_max"] = steps
+ print("total training steps:", steps)
+
+ global_config["input_name"] = get_feed_vars(
+ global_config["model_dir"],
+ global_config["model_filename"],
+ global_config["params_filename"],
+ )
+
ac = AutoCompression(
- model_dir=global_config['model_dir'],
- model_filename=global_config['model_filename'],
- params_filename=global_config['params_filename'],
+ model_dir=global_config["model_dir"],
+ model_filename=global_config["model_filename"],
+ params_filename=global_config["params_filename"],
save_dir=args.save_dir,
config=all_config,
- train_dataloader=reader_wrapper(train_dataloader,
- global_config['input_name']),
+ train_dataloader=reader_wrapper(train_dataloader, global_config["input_name"]),
eval_callback=eval_function if rank_id == 0 else None,
- eval_dataloader=reader_wrapper(val_loader, global_config['input_name']))
+ eval_dataloader=reader_wrapper(val_loader, global_config["input_name"]),
+ )
ac.compress()
-if __name__ == '__main__':
+if __name__ == "__main__":
paddle.enable_static()
parser = argsparser()
args = parser.parse_args()
diff --git a/deploy/slim/auto_compression/test_ocr.py b/deploy/slim/auto_compression/test_ocr.py
index da31b8a704..65868f6828 100644
--- a/deploy/slim/auto_compression/test_ocr.py
+++ b/deploy/slim/auto_compression/test_ocr.py
@@ -28,7 +28,8 @@
from paddleslim.common import get_logger
import sys
-sys.path.append('../../../')
+
+sys.path.append("../../../")
from ppocr.data import build_dataloader
from ppocr.postprocess import build_post_process
from ppocr.metrics import build_metric
@@ -39,17 +40,16 @@
logger = get_logger(__name__, level=logging.INFO)
-
def find_images_with_bounding_size(dataset: paddle.io.Dataset):
max_length_index = -1
max_width_index = -1
min_length_index = -1
min_width_index = -1
- max_length = float('-inf')
- max_width = float('-inf')
- min_length = float('inf')
- min_width = float('inf')
+ max_length = float("-inf")
+ max_width = float("-inf")
+ min_length = float("inf")
+ min_width = float("inf")
for idx, data in enumerate(dataset):
image = np.array(data[0])
h, w = image.shape[-2:]
@@ -69,8 +69,10 @@ def find_images_with_bounding_size(dataset: paddle.io.Dataset):
print(f"Found max image width: {max_width}, index: {max_width_index}")
print(f"Found min image length: {min_length}, index: {min_length_index}")
print(f"Found min image width: {min_width}, index: {min_width_index}")
- return paddle.io.Subset(dataset, [max_width_index,max_length_index,
- min_width_index, min_length_index])
+ return paddle.io.Subset(
+ dataset, [max_width_index, max_length_index, min_width_index, min_length_index]
+ )
+
def load_predictor(args):
"""
@@ -91,8 +93,8 @@ def load_predictor(args):
pred_cfg.enable_mkldnn()
if args.precision == "int8":
pred_cfg.enable_mkldnn_int8({"conv2d"})
-
- if global_config['model_type']=="rec":
+
+ if global_config["model_type"] == "rec":
# delete pass which influence the accuracy, please refer to https://github.com/PaddlePaddle/Paddle/issues/55290
pred_cfg.delete_pass("fc_mkldnn_pass")
pred_cfg.delete_pass("fc_act_mkldnn_fuse_pass")
@@ -101,15 +103,17 @@ def load_predictor(args):
# To collect the dynamic shapes of inputs for TensorRT engine
dynamic_shape_file = os.path.join(args.model_path, "dynamic_shape.txt")
if os.path.exists(dynamic_shape_file):
- pred_cfg.enable_tuned_tensorrt_dynamic_shape(dynamic_shape_file,
- True)
+ pred_cfg.enable_tuned_tensorrt_dynamic_shape(dynamic_shape_file, True)
print("trt set dynamic shape done!")
precision_map = {
"fp16": PrecisionType.Half,
"fp32": PrecisionType.Float32,
- "int8": PrecisionType.Int8
+ "int8": PrecisionType.Int8,
}
- if args.precision == 'int8' and "ppocrv4_det_server_qat_dist.yaml" in args.config_path:
+ if (
+ args.precision == "int8"
+ and "ppocrv4_det_server_qat_dist.yaml" in args.config_path
+ ):
# Use the following settings only when the hardware is a Tesla V100. If you are using
# a RTX 3090, use the settings in the else branch.
pred_cfg.enable_tensorrt_engine(
@@ -118,16 +122,18 @@ def load_predictor(args):
min_subgraph_size=30,
precision_mode=precision_map[args.precision],
use_static=True,
- use_calib_mode=False, )
+ use_calib_mode=False,
+ )
pred_cfg.exp_disable_tensorrt_ops(["elementwise_add"])
- else:
+ else:
pred_cfg.enable_tensorrt_engine(
- workspace_size=1 << 30,
- max_batch_size=1,
- min_subgraph_size=4,
- precision_mode=precision_map[args.precision],
- use_static=True,
- use_calib_mode=False, )
+ workspace_size=1 << 30,
+ max_batch_size=1,
+ min_subgraph_size=4,
+ precision_mode=precision_map[args.precision],
+ use_static=True,
+ use_calib_mode=False,
+ )
else:
# pred_cfg.disable_gpu()
# pred_cfg.set_cpu_math_library_num_threads(24)
@@ -135,7 +141,6 @@ def load_predictor(args):
print("Start collect dynamic shape...")
rerun_flag = True
-
predictor = create_predictor(pred_cfg)
return predictor, rerun_flag
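The `rerun_flag` above encodes a two-pass flow for TensorRT dynamic shapes: collect shape ranges first, then rebuild the predictor against them. A hedged, self-contained sketch of the idea using the paddle.inference API (model paths are illustrative):

```python
import os
from paddle.inference import Config, PrecisionType, create_predictor

shape_file = "dynamic_shape.txt"
cfg = Config("inference.pdmodel", "inference.pdiparams")  # illustrative paths
cfg.enable_use_gpu(100, 0)

if not os.path.exists(shape_file):
    # pass 1: run a few extreme-sized inputs and record min/max/opt shapes
    cfg.collect_shape_range_info(shape_file)
else:
    # pass 2: build the TensorRT engine against the tuned shape ranges
    cfg.enable_tensorrt_engine(
        workspace_size=1 << 30,
        max_batch_size=1,
        min_subgraph_size=4,
        precision_mode=PrecisionType.Half,
        use_static=True,
        use_calib_mode=False,
    )
    cfg.enable_tuned_tensorrt_dynamic_shape(shape_file, True)

predictor = create_predictor(cfg)
```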
@@ -146,25 +151,23 @@ def eval(args):
"""
# DataLoader need run on cpu
paddle.set_device("cpu")
- devices = paddle.device.get_device().split(':')[0]
+ devices = paddle.device.get_device().split(":")[0]
- val_loader = build_dataloader(all_config, 'Eval', devices, logger)
- post_process_class = build_post_process(all_config['PostProcess'],
- global_config)
- eval_class = build_metric(all_config['Metric'])
- model_type = global_config['model_type']
+ val_loader = build_dataloader(all_config, "Eval", devices, logger)
+ post_process_class = build_post_process(all_config["PostProcess"], global_config)
+ eval_class = build_metric(all_config["Metric"])
+ model_type = global_config["model_type"]
predictor, rerun_flag = load_predictor(args)
if rerun_flag:
eval_dataset = find_images_with_bounding_size(val_loader.dataset)
batch_sampler = paddle.io.BatchSampler(
- eval_dataset, batch_size=1, shuffle=False, drop_last=False)
+ eval_dataset, batch_size=1, shuffle=False, drop_last=False
+ )
val_loader = paddle.io.DataLoader(
- eval_dataset,
- batch_sampler=batch_sampler,
- num_workers=4,
- return_list=True)
+ eval_dataset, batch_sampler=batch_sampler, num_workers=4, return_list=True
+ )
input_names = predictor.get_input_names()
input_handle = predictor.get_input_handle(input_names[0])
@@ -177,7 +180,6 @@ def eval(args):
print("Start evaluating ( total_iters: {}).".format(sample_nums))
for batch_id, batch in enumerate(val_loader):
-
images = np.array(batch[0])
batch_numpy = []
@@ -198,11 +200,11 @@ def eval(args):
time_max = max(time_max, timed)
predict_time += timed
- if model_type == 'det':
- preds_map = {'maps': preds}
+ if model_type == "det":
+ preds_map = {"maps": preds}
post_result = post_process_class(preds_map, batch_numpy[1])
eval_class(post_result, batch_numpy)
- elif model_type == 'rec':
+ elif model_type == "rec":
post_result = post_process_class(preds, batch_numpy[1])
eval_class(post_result, batch_numpy)
@@ -216,21 +218,21 @@ def eval(args):
print("Eval iter:", batch_id)
sys.stdout.flush()
-
metric = eval_class.get_metric()
-
+
time_avg = predict_time / sample_nums
print(
- "[Benchmark] Inference time(ms): min={}, max={}, avg={}".
- format(
- round(time_min * 1000, 2),
- round(time_max * 1000, 1), round(time_avg * 1000, 1)))
+ "[Benchmark] Inference time(ms): min={}, max={}, avg={}".format(
+ round(time_min * 1000, 2),
+ round(time_max * 1000, 1),
+ round(time_avg * 1000, 1),
+ )
+ )
for k, v in metric.items():
- print('{}:{}'.format(k, v))
+ print("{}:{}".format(k, v))
sys.stdout.flush()
-
def main():
global all_config, global_config
all_config = load_slim_config(args.config_path)
@@ -241,23 +243,25 @@ def main():
if __name__ == "__main__":
paddle.enable_static()
parser = argparse.ArgumentParser()
+ parser.add_argument("--model_path", type=str, help="inference model filepath")
parser.add_argument(
- "--model_path", type=str, help="inference model filepath")
- parser.add_argument(
- "--config_path",
+ "--config_path",
type=str,
- default='./configs/ppocrv3_det_qat_dist.yaml',
- help="path of compression strategy config.")
+ default="./configs/ppocrv3_det_qat_dist.yaml",
+ help="path of compression strategy config.",
+ )
parser.add_argument(
"--model_filename",
type=str,
default="inference.pdmodel",
- help="model file name")
+ help="model file name",
+ )
parser.add_argument(
"--params_filename",
type=str,
default="inference.pdiparams",
- help="params file name")
+ help="params file name",
+ )
parser.add_argument(
"--device",
type=str,
@@ -276,13 +280,13 @@ def main():
"--use_trt",
type=bool,
default=False,
- help="Whether to use tensorrt engine or not.")
+ help="Whether to use tensorrt engine or not.",
+ )
parser.add_argument(
- "--use_mkldnn",
- type=bool,
- default=False,
- help="Whether use mkldnn or not.")
+ "--use_mkldnn", type=bool, default=False, help="Whether use mkldnn or not."
+ )
parser.add_argument(
- "--cpu_threads", type=int, default=10, help="Num of cpu threads.")
+ "--cpu_threads", type=int, default=10, help="Num of cpu threads."
+ )
args = parser.parse_args()
- main()
\ No newline at end of file
+ main()
diff --git a/deploy/slim/prune/README_en.md b/deploy/slim/prune/README_en.md
index a2ee2cba48..9a0ed52911 100644
--- a/deploy/slim/prune/README_en.md
+++ b/deploy/slim/prune/README_en.md
@@ -38,7 +38,7 @@ PaddleOCR also provides a series of [models](../../../doc/doc_en/models_list_en.
After the pre-trained model is loaded, sensitivity analysis is performed on each network layer of the model to measure its redundancy, and the results are saved to a sensitivity file named sen.pickle. The user can then load the sensitivity file via the [methods provided by PaddleSlim](https://github.com/PaddlePaddle/PaddleSlim/blob/develop/paddleslim/prune/sensitive.py#L221) and determine the pruning ratio of each network layer automatically. For specific details of sensitivity analysis, see: [Sensitivity analysis](https://github.com/PaddlePaddle/PaddleSlim/blob/develop/docs/en/tutorials/image_classification_sensitivity_analysis_tutorial_en.md)
The data format of the sensitivity file:
-```
+```
sen.pickle(Dict){
'layer_weight_name_0': sens_of_each_ratio(Dict){'pruning_ratio_0': acc_loss, 'pruning_ratio_1': acc_loss}
'layer_weight_name_1': sens_of_each_ratio(Dict){'pruning_ratio_0': acc_loss, 'pruning_ratio_1': acc_loss}
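A hedged sketch of inspecting that file directly (plain pickle, with keys and values shaped as described above):

```python
import pickle

with open("sen.pickle", "rb") as f:
    sens = pickle.load(f)

# each layer maps the tried pruning ratios to the accuracy loss they caused
for layer_name, ratio_to_loss in sens.items():
    print(layer_name, ratio_to_loss)  # e.g. {0.1: 0.002, 0.2: 0.015, ...}
```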
diff --git a/deploy/slim/prune/export_prune_model.py b/deploy/slim/prune/export_prune_model.py
index b64b1d4c1e..6d191941c3 100644
--- a/deploy/slim/prune/export_prune_model.py
+++ b/deploy/slim/prune/export_prune_model.py
@@ -21,8 +21,8 @@
__dir__ = os.path.dirname(__file__)
sys.path.append(__dir__)
-sys.path.append(os.path.join(__dir__, '..', '..', '..'))
-sys.path.append(os.path.join(__dir__, '..', '..', '..', 'tools'))
+sys.path.append(os.path.join(__dir__, "..", "..", ".."))
+sys.path.append(os.path.join(__dir__, "..", "..", "..", "tools"))
import paddle
from ppocr.data import build_dataloader, set_signal_handlers
@@ -35,56 +35,52 @@
def main(config, device, logger, vdl_writer):
-
- global_config = config['Global']
+ global_config = config["Global"]
# build dataloader
set_signal_handlers()
- valid_dataloader = build_dataloader(config, 'Eval', device, logger)
+ valid_dataloader = build_dataloader(config, "Eval", device, logger)
# build post process
- post_process_class = build_post_process(config['PostProcess'],
- global_config)
+ post_process_class = build_post_process(config["PostProcess"], global_config)
# build model
# for rec algorithm
- if hasattr(post_process_class, 'character'):
- char_num = len(getattr(post_process_class, 'character'))
- config['Architecture']["Head"]['out_channels'] = char_num
- model = build_model(config['Architecture'])
+ if hasattr(post_process_class, "character"):
+ char_num = len(getattr(post_process_class, "character"))
+ config["Architecture"]["Head"]["out_channels"] = char_num
+ model = build_model(config["Architecture"])
- if config['Architecture']['model_type'] == 'det':
+ if config["Architecture"]["model_type"] == "det":
input_shape = [1, 3, 640, 640]
- elif config['Architecture']['model_type'] == 'rec':
+ elif config["Architecture"]["model_type"] == "rec":
input_shape = [1, 3, 32, 320]
flops = paddle.flops(model, input_shape)
logger.info("FLOPs before pruning: {}".format(flops))
from paddleslim.dygraph import FPGMFilterPruner
+
model.train()
pruner = FPGMFilterPruner(model, input_shape)
# build metric
- eval_class = build_metric(config['Metric'])
+ eval_class = build_metric(config["Metric"])
def eval_fn():
- metric = program.eval(model, valid_dataloader, post_process_class,
- eval_class)
- if config['Architecture']['model_type'] == 'det':
- main_indicator = 'hmean'
+ metric = program.eval(model, valid_dataloader, post_process_class, eval_class)
+ if config["Architecture"]["model_type"] == "det":
+ main_indicator = "hmean"
else:
- main_indicator = 'acc'
- logger.info("metric[{}]: {}".format(main_indicator, metric[
- main_indicator]))
+ main_indicator = "acc"
+ logger.info("metric[{}]: {}".format(main_indicator, metric[main_indicator]))
return metric[main_indicator]
params_sensitive = pruner.sensitive(
eval_func=eval_fn,
sen_file="./sen.pickle",
- skip_vars=[
- "conv2d_57.w_0", "conv2d_transpose_2.w_0", "conv2d_transpose_3.w_0"
- ])
+ skip_vars=["conv2d_57.w_0", "conv2d_transpose_2.w_0", "conv2d_transpose_3.w_0"],
+ )
logger.info(
"The sensitivity analysis results of model parameters saved in sen.pickle"
@@ -101,40 +97,41 @@ def eval_fn():
# load pretrain model
load_model(config, model)
- metric = program.eval(model, valid_dataloader, post_process_class,
- eval_class)
- if config['Architecture']['model_type'] == 'det':
- main_indicator = 'hmean'
+ metric = program.eval(model, valid_dataloader, post_process_class, eval_class)
+ if config["Architecture"]["model_type"] == "det":
+ main_indicator = "hmean"
else:
- main_indicator = 'acc'
+ main_indicator = "acc"
logger.info("metric['']: {}".format(main_indicator, metric[main_indicator]))
# start export model
from paddle.jit import to_static
infer_shape = [3, -1, -1]
- if config['Architecture']['model_type'] == "rec":
+ if config["Architecture"]["model_type"] == "rec":
infer_shape = [3, 32, -1] # for rec model, H must be 32
- if 'Transform' in config['Architecture'] and config['Architecture'][
- 'Transform'] is not None and config['Architecture'][
- 'Transform']['name'] == 'TPS':
+ if (
+ "Transform" in config["Architecture"]
+ and config["Architecture"]["Transform"] is not None
+ and config["Architecture"]["Transform"]["name"] == "TPS"
+ ):
logger.info(
- 'When there is tps in the network, variable length input is not supported, and the input size needs to be the same as during training'
+ "When there is tps in the network, variable length input is not supported, and the input size needs to be the same as during training"
)
infer_shape[-1] = 100
model = to_static(
model,
input_spec=[
- paddle.static.InputSpec(
- shape=[None] + infer_shape, dtype='float32')
- ])
+ paddle.static.InputSpec(shape=[None] + infer_shape, dtype="float32")
+ ],
+ )
- save_path = '{}/inference'.format(config['Global']['save_inference_dir'])
+ save_path = "{}/inference".format(config["Global"]["save_inference_dir"])
paddle.jit.save(model, save_path)
- logger.info('inference model is saved to {}'.format(save_path))
+ logger.info("inference model is saved to {}".format(save_path))
-if __name__ == '__main__':
+if __name__ == "__main__":
config, device, logger, vdl_writer = program.preprocess(is_train=True)
main(config, device, logger, vdl_writer)
diff --git a/deploy/slim/prune/sensitivity_anal.py b/deploy/slim/prune/sensitivity_anal.py
index eb97029b0b..4fb2f6a2e0 100644
--- a/deploy/slim/prune/sensitivity_anal.py
+++ b/deploy/slim/prune/sensitivity_anal.py
@@ -21,8 +21,8 @@
__dir__ = os.path.dirname(__file__)
sys.path.append(__dir__)
-sys.path.append(os.path.join(__dir__, '..', '..', '..'))
-sys.path.append(os.path.join(__dir__, '..', '..', '..', 'tools'))
+sys.path.append(os.path.join(__dir__, "..", "..", ".."))
+sys.path.append(os.path.join(__dir__, "..", "..", "..", "tools"))
import paddle
import paddle.distributed as dist
@@ -42,84 +42,95 @@ def get_pruned_params(parameters):
params = []
for param in parameters:
- if len(
- param.shape
- ) == 4 and 'depthwise' not in param.name and 'transpose' not in param.name and "conv2d_57" not in param.name and "conv2d_56" not in param.name:
+ if (
+ len(param.shape) == 4
+ and "depthwise" not in param.name
+ and "transpose" not in param.name
+ and "conv2d_57" not in param.name
+ and "conv2d_56" not in param.name
+ ):
params.append(param.name)
return params
def main(config, device, logger, vdl_writer):
# init dist environment
- if config['Global']['distributed']:
+ if config["Global"]["distributed"]:
dist.init_parallel_env()
- global_config = config['Global']
+ global_config = config["Global"]
# build dataloader
set_signal_handlers()
- train_dataloader = build_dataloader(config, 'Train', device, logger)
- if config['Eval']:
- valid_dataloader = build_dataloader(config, 'Eval', device, logger)
+ train_dataloader = build_dataloader(config, "Train", device, logger)
+ if config["Eval"]:
+ valid_dataloader = build_dataloader(config, "Eval", device, logger)
else:
valid_dataloader = None
# build post process
- post_process_class = build_post_process(config['PostProcess'],
- global_config)
+ post_process_class = build_post_process(config["PostProcess"], global_config)
# build model
# for rec algorithm
- if hasattr(post_process_class, 'character'):
- char_num = len(getattr(post_process_class, 'character'))
- config['Architecture']["Head"]['out_channels'] = char_num
- model = build_model(config['Architecture'])
- if config['Architecture']['model_type'] == 'det':
+ if hasattr(post_process_class, "character"):
+ char_num = len(getattr(post_process_class, "character"))
+ config["Architecture"]["Head"]["out_channels"] = char_num
+ model = build_model(config["Architecture"])
+ if config["Architecture"]["model_type"] == "det":
input_shape = [1, 3, 640, 640]
- elif config['Architecture']['model_type'] == 'rec':
+ elif config["Architecture"]["model_type"] == "rec":
input_shape = [1, 3, 32, 320]
flops = paddle.flops(model, input_shape)
logger.info("FLOPs before pruning: {}".format(flops))
from paddleslim.dygraph import FPGMFilterPruner
+
model.train()
pruner = FPGMFilterPruner(model, input_shape)
# build loss
- loss_class = build_loss(config['Loss'])
+ loss_class = build_loss(config["Loss"])
# build optim
optimizer, lr_scheduler = build_optimizer(
- config['Optimizer'],
- epochs=config['Global']['epoch_num'],
+ config["Optimizer"],
+ epochs=config["Global"]["epoch_num"],
step_each_epoch=len(train_dataloader),
- model=model)
+ model=model,
+ )
# build metric
- eval_class = build_metric(config['Metric'])
+ eval_class = build_metric(config["Metric"])
# load pretrain model
pre_best_model_dict = load_model(config, model, optimizer)
- logger.info('train dataloader has {} iters, valid dataloader has {} iters'.
- format(len(train_dataloader), len(valid_dataloader)))
+ logger.info(
+ "train dataloader has {} iters, valid dataloader has {} iters".format(
+ len(train_dataloader), len(valid_dataloader)
+ )
+ )
# build metric
- eval_class = build_metric(config['Metric'])
+ eval_class = build_metric(config["Metric"])
- logger.info('train dataloader has {} iters, valid dataloader has {} iters'.
- format(len(train_dataloader), len(valid_dataloader)))
+ logger.info(
+ "train dataloader has {} iters, valid dataloader has {} iters".format(
+ len(train_dataloader), len(valid_dataloader)
+ )
+ )
def eval_fn():
- metric = program.eval(model, valid_dataloader, post_process_class,
- eval_class, False)
- if config['Architecture']['model_type'] == 'det':
- main_indicator = 'hmean'
+ metric = program.eval(
+ model, valid_dataloader, post_process_class, eval_class, False
+ )
+ if config["Architecture"]["model_type"] == "det":
+ main_indicator = "hmean"
else:
- main_indicator = 'acc'
+ main_indicator = "acc"
- logger.info("metric[{}]: {}".format(main_indicator, metric[
- main_indicator]))
+ logger.info("metric[{}]: {}".format(main_indicator, metric[main_indicator]))
return metric[main_indicator]
run_sensitive_analysis = False
@@ -141,21 +152,22 @@ def eval_fn():
eval_func=eval_fn,
sen_file="./deploy/slim/prune/sen.pickle",
skip_vars=[
- "conv2d_57.w_0", "conv2d_transpose_2.w_0",
- "conv2d_transpose_3.w_0"
- ])
+ "conv2d_57.w_0",
+ "conv2d_transpose_2.w_0",
+ "conv2d_transpose_3.w_0",
+ ],
+ )
logger.info(
"The sensitivity analysis results of model parameters saved in sen.pickle"
)
# calculate pruned params's ratio
- params_sensitive = pruner._get_ratios_by_loss(
- params_sensitive, loss=0.02)
+ params_sensitive = pruner._get_ratios_by_loss(params_sensitive, loss=0.02)
for key in params_sensitive.keys():
logger.info("{}, {}".format(key, params_sensitive[key]))
else:
params_sensitive = {}
for param in model.parameters():
- if 'transpose' not in param.name and 'linear' not in param.name:
+ if "transpose" not in param.name and "linear" not in param.name:
# set prune ratio as 10%. The larger the value, the more convolution weights will be cropped
params_sensitive[param.name] = 0.1
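Either way, the ratio dict feeds the FPGM pruner. A hedged sketch of the follow-on step on a toy network, assuming the paddleslim dygraph API used elsewhere in this diff (`prune_vars` with axis 0, i.e. output channels):

```python
import paddle
from paddleslim.dygraph import FPGMFilterPruner

net = paddle.nn.Sequential(
    paddle.nn.Conv2D(3, 8, 3, padding=1),
    paddle.nn.Conv2D(8, 16, 3, padding=1),
)
pruner = FPGMFilterPruner(net, [1, 3, 32, 32])
# prune 10% of the output channels of every 4-D (conv) parameter
ratios = {p.name: 0.1 for p in net.parameters() if len(p.shape) == 4}
plan = pruner.prune_vars(ratios, [0])
print(paddle.flops(net, [1, 3, 32, 32]))  # FLOPs after pruning
```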
@@ -166,11 +178,23 @@ def eval_fn():
# start train
- program.train(config, train_dataloader, valid_dataloader, device, model,
- loss_class, optimizer, lr_scheduler, post_process_class,
- eval_class, pre_best_model_dict, logger, vdl_writer)
-
-
-if __name__ == '__main__':
+ program.train(
+ config,
+ train_dataloader,
+ valid_dataloader,
+ device,
+ model,
+ loss_class,
+ optimizer,
+ lr_scheduler,
+ post_process_class,
+ eval_class,
+ pre_best_model_dict,
+ logger,
+ vdl_writer,
+ )
+
+
+if __name__ == "__main__":
config, device, logger, vdl_writer = program.preprocess(is_train=True)
main(config, device, logger, vdl_writer)
diff --git a/deploy/slim/quantization/export_model.py b/deploy/slim/quantization/export_model.py
index 30696f3e36..12339534f8 100755
--- a/deploy/slim/quantization/export_model.py
+++ b/deploy/slim/quantization/export_model.py
@@ -17,9 +17,8 @@
__dir__ = os.path.dirname(os.path.abspath(__file__))
sys.path.append(__dir__)
-sys.path.insert(0, os.path.abspath(os.path.join(__dir__, '..', '..', '..')))
-sys.path.insert(
- 0, os.path.abspath(os.path.join(__dir__, '..', '..', '..', 'tools')))
+sys.path.insert(0, os.path.abspath(os.path.join(__dir__, "..", "..", "..")))
+sys.path.insert(0, os.path.abspath(os.path.join(__dir__, "..", "..", "..", "tools")))
import argparse
@@ -43,26 +42,26 @@ def main():
# 1. quantization configs
############################################################################################################
quant_config = {
- # weight preprocess type, default is None and no preprocessing is performed.
- 'weight_preprocess_type': None,
+ # weight preprocess type, default is None and no preprocessing is performed.
+ "weight_preprocess_type": None,
# activation preprocess type, default is None and no preprocessing is performed.
- 'activation_preprocess_type': None,
+ "activation_preprocess_type": None,
# weight quantize type, default is 'channel_wise_abs_max'
- 'weight_quantize_type': 'channel_wise_abs_max',
+ "weight_quantize_type": "channel_wise_abs_max",
# activation quantize type, default is 'moving_average_abs_max'
- 'activation_quantize_type': 'moving_average_abs_max',
+ "activation_quantize_type": "moving_average_abs_max",
# weight quantize bit num, default is 8
- 'weight_bits': 8,
+ "weight_bits": 8,
# activation quantize bit num, default is 8
- 'activation_bits': 8,
+ "activation_bits": 8,
# data type after quantization, such as 'uint8', 'int8', etc. default is 'int8'
- 'dtype': 'int8',
+ "dtype": "int8",
# window size for 'range_abs_max' quantization. default is 10000
- 'window_size': 10000,
+ "window_size": 10000,
# The decay coefficient of moving average, default is 0.9
- 'moving_rate': 0.9,
+ "moving_rate": 0.9,
# for dygraph quantization, layers of type in quantizable_layer_type will be quantized
- 'quantizable_layer_type': ['Conv2D', 'Linear'],
+ "quantizable_layer_type": ["Conv2D", "Linear"],
}
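For context, a brief sketch of how such a config is consumed downstream (the QAT class from paddleslim.dygraph.quant, as this script uses; the save path and input shape here are illustrative assumptions):

```python
import paddle
from paddleslim.dygraph.quant import QAT

quanter = QAT(config=quant_config)  # the dict defined above
quanter.quantize(model)             # wraps Conv2D/Linear layers with fake-quant ops in place
# ... evaluate or fine-tune the quantized model ...
quanter.save_quantized_model(
    model,
    "output/inference",  # illustrative path
    input_spec=[paddle.static.InputSpec(shape=[None, 3, 32, 320], dtype="float32")],
)
```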
FLAGS = ArgsParser().parse_args()
config = load_config(FLAGS.config)
@@ -70,59 +69,62 @@ def main():
logger = get_logger()
# build post process
- post_process_class = build_post_process(config['PostProcess'],
- config['Global'])
+ post_process_class = build_post_process(config["PostProcess"], config["Global"])
# build model
- if hasattr(post_process_class, 'character'):
- char_num = len(getattr(post_process_class, 'character'))
- if config['Architecture']["algorithm"] in ["Distillation",
- ]: # distillation model
- for key in config['Architecture']["Models"]:
- if config['Architecture']['Models'][key]['Head'][
- 'name'] == 'MultiHead': # for multi head
- if config['PostProcess'][
- 'name'] == 'DistillationSARLabelDecode':
+ if hasattr(post_process_class, "character"):
+ char_num = len(getattr(post_process_class, "character"))
+ if config["Architecture"]["algorithm"] in [
+ "Distillation",
+ ]: # distillation model
+ for key in config["Architecture"]["Models"]:
+ if (
+ config["Architecture"]["Models"][key]["Head"]["name"] == "MultiHead"
+ ): # for multi head
+ if config["PostProcess"]["name"] == "DistillationSARLabelDecode":
char_num = char_num - 2
# update SARLoss params
- assert list(config['Loss']['loss_config_list'][-1].keys())[
- 0] == 'DistillationSARLoss'
- config['Loss']['loss_config_list'][-1][
- 'DistillationSARLoss']['ignore_index'] = char_num + 1
+ assert (
+ list(config["Loss"]["loss_config_list"][-1].keys())[0]
+ == "DistillationSARLoss"
+ )
+ config["Loss"]["loss_config_list"][-1]["DistillationSARLoss"][
+ "ignore_index"
+ ] = (char_num + 1)
out_channels_list = {}
- out_channels_list['CTCLabelDecode'] = char_num
- out_channels_list['SARLabelDecode'] = char_num + 2
- config['Architecture']['Models'][key]['Head'][
- 'out_channels_list'] = out_channels_list
+ out_channels_list["CTCLabelDecode"] = char_num
+ out_channels_list["SARLabelDecode"] = char_num + 2
+ config["Architecture"]["Models"][key]["Head"][
+ "out_channels_list"
+ ] = out_channels_list
else:
- config['Architecture']["Models"][key]["Head"][
- 'out_channels'] = char_num
- elif config['Architecture']['Head'][
- 'name'] == 'MultiHead': # for multi head
- if config['PostProcess']['name'] == 'SARLabelDecode':
+ config["Architecture"]["Models"][key]["Head"][
+ "out_channels"
+ ] = char_num
+ elif config["Architecture"]["Head"]["name"] == "MultiHead": # for multi head
+ if config["PostProcess"]["name"] == "SARLabelDecode":
char_num = char_num - 2
# update SARLoss params
- assert list(config['Loss']['loss_config_list'][1].keys())[
- 0] == 'SARLoss'
- if config['Loss']['loss_config_list'][1]['SARLoss'] is None:
- config['Loss']['loss_config_list'][1]['SARLoss'] = {
- 'ignore_index': char_num + 1
+ assert list(config["Loss"]["loss_config_list"][1].keys())[0] == "SARLoss"
+ if config["Loss"]["loss_config_list"][1]["SARLoss"] is None:
+ config["Loss"]["loss_config_list"][1]["SARLoss"] = {
+ "ignore_index": char_num + 1
}
else:
- config['Loss']['loss_config_list'][1]['SARLoss'][
- 'ignore_index'] = char_num + 1
+ config["Loss"]["loss_config_list"][1]["SARLoss"]["ignore_index"] = (
+ char_num + 1
+ )
out_channels_list = {}
- out_channels_list['CTCLabelDecode'] = char_num
- out_channels_list['SARLabelDecode'] = char_num + 2
- config['Architecture']['Head'][
- 'out_channels_list'] = out_channels_list
+ out_channels_list["CTCLabelDecode"] = char_num
+ out_channels_list["SARLabelDecode"] = char_num + 2
+ config["Architecture"]["Head"]["out_channels_list"] = out_channels_list
else: # base rec model
- config['Architecture']["Head"]['out_channels'] = char_num
+ config["Architecture"]["Head"]["out_channels"] = char_num
- if config['PostProcess']['name'] == 'SARLabelDecode': # for SAR model
- config['Loss']['ignore_index'] = char_num - 1
+ if config["PostProcess"]["name"] == "SARLabelDecode": # for SAR model
+ config["Loss"]["ignore_index"] = char_num - 1
- model = build_model(config['Architecture'])
+ model = build_model(config["Architecture"])
# get QAT model
quanter = QAT(config=quant_config)
@@ -131,45 +133,55 @@ def main():
load_model(config, model)
# build metric
- eval_class = build_metric(config['Metric'])
+ eval_class = build_metric(config["Metric"])
# build dataloader
set_signal_handlers()
- valid_dataloader = build_dataloader(config, 'Eval', device, logger)
+ valid_dataloader = build_dataloader(config, "Eval", device, logger)
- use_srn = config['Architecture']['algorithm'] == "SRN"
- model_type = config['Architecture'].get('model_type', None)
+ use_srn = config["Architecture"]["algorithm"] == "SRN"
+ model_type = config["Architecture"].get("model_type", None)
# start eval
- metric = program.eval(model, valid_dataloader, post_process_class,
- eval_class, model_type, use_srn)
+ metric = program.eval(
+ model, valid_dataloader, post_process_class, eval_class, model_type, use_srn
+ )
model.eval()
- logger.info('metric eval ***************')
+ logger.info("metric eval ***************")
for k, v in metric.items():
- logger.info('{}:{}'.format(k, v))
+ logger.info("{}:{}".format(k, v))
save_path = config["Global"]["save_inference_dir"]
arch_config = config["Architecture"]
- if arch_config["algorithm"] == "SVTR" and arch_config["Head"][
- "name"] != 'MultiHead':
- input_shape = config["Eval"]["dataset"]["transforms"][-2][
- 'SVTRRecResizeImg']['image_shape']
+ if (
+ arch_config["algorithm"] == "SVTR"
+ and arch_config["Head"]["name"] != "MultiHead"
+ ):
+ input_shape = config["Eval"]["dataset"]["transforms"][-2]["SVTRRecResizeImg"][
+ "image_shape"
+ ]
else:
input_shape = None
- if arch_config["algorithm"] in ["Distillation", ]: # distillation model
+ if arch_config["algorithm"] in [
+ "Distillation",
+ ]: # distillation model
archs = list(arch_config["Models"].values())
for idx, name in enumerate(model.model_name_list):
sub_model_save_path = os.path.join(save_path, name, "inference")
- export_single_model(model.model_list[idx], archs[idx],
- sub_model_save_path, logger, input_shape,
- quanter)
+ export_single_model(
+ model.model_list[idx],
+ archs[idx],
+ sub_model_save_path,
+ logger,
+ input_shape,
+ quanter,
+ )
else:
save_path = os.path.join(save_path, "inference")
- export_single_model(model, arch_config, save_path, logger, input_shape,
- quanter)
+ export_single_model(model, arch_config, save_path, logger, input_shape, quanter)
if __name__ == "__main__":
diff --git a/deploy/slim/quantization/quant.py b/deploy/slim/quantization/quant.py
index a580ce4346..8bcc553e0c 100755
--- a/deploy/slim/quantization/quant.py
+++ b/deploy/slim/quantization/quant.py
@@ -21,9 +21,8 @@
__dir__ = os.path.dirname(os.path.abspath(__file__))
sys.path.append(__dir__)
-sys.path.append(os.path.abspath(os.path.join(__dir__, '..', '..', '..')))
-sys.path.append(
- os.path.abspath(os.path.join(__dir__, '..', '..', '..', 'tools')))
+sys.path.append(os.path.abspath(os.path.join(__dir__, "..", "..", "..")))
+sys.path.append(os.path.abspath(os.path.join(__dir__, "..", "..", "..", "tools")))
import yaml
import paddle
@@ -51,10 +50,10 @@ def __init__(self):
name=self.full_name() + ".pact",
initializer=paddle.nn.initializer.Constant(value=20),
learning_rate=1.0,
- regularizer=paddle.regularizer.L2Decay(2e-5))
+ regularizer=paddle.regularizer.L2Decay(2e-5),
+ )
- self.alpha = self.create_parameter(
- shape=[1], attr=alpha_attr, dtype='float32')
+ self.alpha = self.create_parameter(shape=[1], attr=alpha_attr, dtype="float32")
def forward(self, x):
out_left = paddle.nn.functional.relu(x - self.alpha)
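The hunk above shows only the first line of `forward`. For readability, here is a sketch of the complete PACT forward consistent with it; PACT clips activations into a learnable band [-alpha, alpha] so the quantizer sees a tight range:

```python
def forward(self, x):
    # amount by which x exceeds +alpha (zero elsewhere)
    out_left = paddle.nn.functional.relu(x - self.alpha)
    # amount by which x falls below -alpha (zero elsewhere)
    out_right = paddle.nn.functional.relu(-self.alpha - x)
    # subtracting the overshoots clips x into [-alpha, alpha]
    return x - out_left + out_right
```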
@@ -64,140 +63,164 @@ def forward(self, x):
quant_config = {
- # weight preprocess type, default is None and no preprocessing is performed.
- 'weight_preprocess_type': None,
+ # weight preprocess type, default is None and no preprocessing is performed.
+ "weight_preprocess_type": None,
# activation preprocess type, default is None and no preprocessing is performed.
- 'activation_preprocess_type': None,
+ "activation_preprocess_type": None,
# weight quantize type, default is 'channel_wise_abs_max'
- 'weight_quantize_type': 'channel_wise_abs_max',
+ "weight_quantize_type": "channel_wise_abs_max",
# activation quantize type, default is 'moving_average_abs_max'
- 'activation_quantize_type': 'moving_average_abs_max',
+ "activation_quantize_type": "moving_average_abs_max",
# weight quantize bit num, default is 8
- 'weight_bits': 8,
+ "weight_bits": 8,
# activation quantize bit num, default is 8
- 'activation_bits': 8,
+ "activation_bits": 8,
# data type after quantization, such as 'uint8', 'int8', etc. default is 'int8'
- 'dtype': 'int8',
+ "dtype": "int8",
# window size for 'range_abs_max' quantization. default is 10000
- 'window_size': 10000,
+ "window_size": 10000,
# The decay coefficient of moving average, default is 0.9
- 'moving_rate': 0.9,
+ "moving_rate": 0.9,
# for dygraph quantization, layers of type in quantizable_layer_type will be quantized
- 'quantizable_layer_type': ['Conv2D', 'Linear'],
+ "quantizable_layer_type": ["Conv2D", "Linear"],
}
def main(config, device, logger, vdl_writer):
# init dist environment
- if config['Global']['distributed']:
+ if config["Global"]["distributed"]:
dist.init_parallel_env()
- global_config = config['Global']
+ global_config = config["Global"]
# build dataloader
set_signal_handlers()
- train_dataloader = build_dataloader(config, 'Train', device, logger)
- if config['Eval']:
- valid_dataloader = build_dataloader(config, 'Eval', device, logger)
+ train_dataloader = build_dataloader(config, "Train", device, logger)
+ if config["Eval"]:
+ valid_dataloader = build_dataloader(config, "Eval", device, logger)
else:
valid_dataloader = None
# build post process
- post_process_class = build_post_process(config['PostProcess'],
- global_config)
+ post_process_class = build_post_process(config["PostProcess"], global_config)
# build model
# for rec algorithm
- if hasattr(post_process_class, 'character'):
- char_num = len(getattr(post_process_class, 'character'))
- if config['Architecture']["algorithm"] in ["Distillation",
- ]: # distillation model
- for key in config['Architecture']["Models"]:
- if config['Architecture']['Models'][key]['Head'][
- 'name'] == 'MultiHead': # for multi head
- if config['PostProcess'][
- 'name'] == 'DistillationSARLabelDecode':
+ if hasattr(post_process_class, "character"):
+ char_num = len(getattr(post_process_class, "character"))
+ if config["Architecture"]["algorithm"] in [
+ "Distillation",
+ ]: # distillation model
+ for key in config["Architecture"]["Models"]:
+ if (
+ config["Architecture"]["Models"][key]["Head"]["name"] == "MultiHead"
+ ): # for multi head
+ if config["PostProcess"]["name"] == "DistillationSARLabelDecode":
char_num = char_num - 2
# update SARLoss params
- assert list(config['Loss']['loss_config_list'][-1].keys())[
- 0] == 'DistillationSARLoss'
- config['Loss']['loss_config_list'][-1][
- 'DistillationSARLoss']['ignore_index'] = char_num + 1
+ assert (
+ list(config["Loss"]["loss_config_list"][-1].keys())[0]
+ == "DistillationSARLoss"
+ )
+ config["Loss"]["loss_config_list"][-1]["DistillationSARLoss"][
+ "ignore_index"
+ ] = (char_num + 1)
out_channels_list = {}
- out_channels_list['CTCLabelDecode'] = char_num
- out_channels_list['SARLabelDecode'] = char_num + 2
- config['Architecture']['Models'][key]['Head'][
- 'out_channels_list'] = out_channels_list
+ out_channels_list["CTCLabelDecode"] = char_num
+ out_channels_list["SARLabelDecode"] = char_num + 2
+ config["Architecture"]["Models"][key]["Head"][
+ "out_channels_list"
+ ] = out_channels_list
else:
- config['Architecture']["Models"][key]["Head"][
- 'out_channels'] = char_num
- elif config['Architecture']['Head'][
- 'name'] == 'MultiHead': # for multi head
- if config['PostProcess']['name'] == 'SARLabelDecode':
+ config["Architecture"]["Models"][key]["Head"][
+ "out_channels"
+ ] = char_num
+ elif config["Architecture"]["Head"]["name"] == "MultiHead": # for multi head
+ if config["PostProcess"]["name"] == "SARLabelDecode":
char_num = char_num - 2
# update SARLoss params
- assert list(config['Loss']['loss_config_list'][1].keys())[
- 0] == 'SARLoss'
- if config['Loss']['loss_config_list'][1]['SARLoss'] is None:
- config['Loss']['loss_config_list'][1]['SARLoss'] = {
- 'ignore_index': char_num + 1
+ assert list(config["Loss"]["loss_config_list"][1].keys())[0] == "SARLoss"
+ if config["Loss"]["loss_config_list"][1]["SARLoss"] is None:
+ config["Loss"]["loss_config_list"][1]["SARLoss"] = {
+ "ignore_index": char_num + 1
}
else:
- config['Loss']['loss_config_list'][1]['SARLoss'][
- 'ignore_index'] = char_num + 1
+ config["Loss"]["loss_config_list"][1]["SARLoss"]["ignore_index"] = (
+ char_num + 1
+ )
out_channels_list = {}
- out_channels_list['CTCLabelDecode'] = char_num
- out_channels_list['SARLabelDecode'] = char_num + 2
- config['Architecture']['Head'][
- 'out_channels_list'] = out_channels_list
+ out_channels_list["CTCLabelDecode"] = char_num
+ out_channels_list["SARLabelDecode"] = char_num + 2
+ config["Architecture"]["Head"]["out_channels_list"] = out_channels_list
else: # base rec model
- config['Architecture']["Head"]['out_channels'] = char_num
+ config["Architecture"]["Head"]["out_channels"] = char_num
- if config['PostProcess']['name'] == 'SARLabelDecode': # for SAR model
- config['Loss']['ignore_index'] = char_num - 1
- model = build_model(config['Architecture'])
+ if config["PostProcess"]["name"] == "SARLabelDecode": # for SAR model
+ config["Loss"]["ignore_index"] = char_num - 1
+ model = build_model(config["Architecture"])
pre_best_model_dict = dict()
# load fp32 model to begin quantization
- pre_best_model_dict = load_model(config, model, None, config['Architecture']["model_type"])
+ pre_best_model_dict = load_model(
+ config, model, None, config["Architecture"]["model_type"]
+ )
freeze_params = False
- if config['Architecture']["algorithm"] in ["Distillation"]:
- for key in config['Architecture']["Models"]:
- freeze_params = freeze_params or config['Architecture']['Models'][
- key].get('freeze_params', False)
+ if config["Architecture"]["algorithm"] in ["Distillation"]:
+ for key in config["Architecture"]["Models"]:
+ freeze_params = freeze_params or config["Architecture"]["Models"][key].get(
+ "freeze_params", False
+ )
act = None if freeze_params else PACT
quanter = QAT(config=quant_config, act_preprocess=act)
quanter.quantize(model)
- if config['Global']['distributed']:
+ if config["Global"]["distributed"]:
model = paddle.DataParallel(model)
# build loss
- loss_class = build_loss(config['Loss'])
+ loss_class = build_loss(config["Loss"])
# build optim
optimizer, lr_scheduler = build_optimizer(
- config['Optimizer'],
- epochs=config['Global']['epoch_num'],
+ config["Optimizer"],
+ epochs=config["Global"]["epoch_num"],
step_each_epoch=len(train_dataloader),
- model=model)
+ model=model,
+ )
# resume PACT training process
- pre_best_model_dict = load_model(config, model, optimizer, config['Architecture']["model_type"])
+ pre_best_model_dict = load_model(
+ config, model, optimizer, config["Architecture"]["model_type"]
+ )
# build metric
- eval_class = build_metric(config['Metric'])
+ eval_class = build_metric(config["Metric"])
- logger.info('train dataloader has {} iters, valid dataloader has {} iters'.
- format(len(train_dataloader), len(valid_dataloader)))
+ logger.info(
+ "train dataloader has {} iters, valid dataloader has {} iters".format(
+ len(train_dataloader), len(valid_dataloader)
+ )
+ )
# start train
- program.train(config, train_dataloader, valid_dataloader, device, model,
- loss_class, optimizer, lr_scheduler, post_process_class,
- eval_class, pre_best_model_dict, logger, vdl_writer)
-
-
-if __name__ == '__main__':
+ program.train(
+ config,
+ train_dataloader,
+ valid_dataloader,
+ device,
+ model,
+ loss_class,
+ optimizer,
+ lr_scheduler,
+ post_process_class,
+ eval_class,
+ pre_best_model_dict,
+ logger,
+ vdl_writer,
+ )
+
+
+if __name__ == "__main__":
config, device, logger, vdl_writer = program.preprocess(is_train=True)
main(config, device, logger, vdl_writer)
diff --git a/deploy/slim/quantization/quant_kl.py b/deploy/slim/quantization/quant_kl.py
index 71bf6bbd4c..f367203c57 100755
--- a/deploy/slim/quantization/quant_kl.py
+++ b/deploy/slim/quantization/quant_kl.py
@@ -21,9 +21,8 @@
__dir__ = os.path.dirname(os.path.abspath(__file__))
sys.path.append(__dir__)
-sys.path.append(os.path.abspath(os.path.join(__dir__, '..', '..', '..')))
-sys.path.append(
- os.path.abspath(os.path.join(__dir__, '..', '..', '..', 'tools')))
+sys.path.append(os.path.abspath(os.path.join(__dir__, "..", "..", "..")))
+sys.path.append(os.path.abspath(os.path.join(__dir__, "..", "..", "..", "tools")))
import yaml
import paddle
@@ -53,10 +52,10 @@ def __init__(self):
name=self.full_name() + ".pact",
initializer=paddle.nn.initializer.Constant(value=20),
learning_rate=1.0,
- regularizer=paddle.regularizer.L2Decay(2e-5))
+ regularizer=paddle.regularizer.L2Decay(2e-5),
+ )
- self.alpha = self.create_parameter(
- shape=[1], attr=alpha_attr, dtype='float32')
+ self.alpha = self.create_parameter(shape=[1], attr=alpha_attr, dtype="float32")
def forward(self, x):
out_left = paddle.nn.functional.relu(x - self.alpha)
@@ -66,26 +65,26 @@ def forward(self, x):
quant_config = {
- # weight preprocess type, default is None and no preprocessing is performed.
- 'weight_preprocess_type': None,
+ # weight preprocess type, default is None and no preprocessing is performed.
+ "weight_preprocess_type": None,
# activation preprocess type, default is None and no preprocessing is performed.
- 'activation_preprocess_type': None,
+ "activation_preprocess_type": None,
# weight quantize type, default is 'channel_wise_abs_max'
- 'weight_quantize_type': 'channel_wise_abs_max',
+ "weight_quantize_type": "channel_wise_abs_max",
# activation quantize type, default is 'moving_average_abs_max'
- 'activation_quantize_type': 'moving_average_abs_max',
+ "activation_quantize_type": "moving_average_abs_max",
# weight quantize bit num, default is 8
- 'weight_bits': 8,
+ "weight_bits": 8,
# activation quantize bit num, default is 8
- 'activation_bits': 8,
+ "activation_bits": 8,
# data type after quantization, such as 'uint8', 'int8', etc. default is 'int8'
- 'dtype': 'int8',
+ "dtype": "int8",
# window size for 'range_abs_max' quantization. default is 10000
- 'window_size': 10000,
+ "window_size": 10000,
# The decay coefficient of moving average, default is 0.9
- 'moving_rate': 0.9,
+ "moving_rate": 0.9,
# for dygraph quantization, layers of type in quantizable_layer_type will be quantized
- 'quantizable_layer_type': ['Conv2D', 'Linear'],
+ "quantizable_layer_type": ["Conv2D", "Linear"],
}
@@ -97,6 +96,7 @@ def __reader__():
return __reader__
+
def sample_generator_layoutxlm_ser(loader):
def __reader__():
for indx, data in enumerate(loader):
@@ -109,21 +109,25 @@ def __reader__():
return __reader__
+
def main(config, device, logger, vdl_writer):
# init dist environment
- if config['Global']['distributed']:
+ if config["Global"]["distributed"]:
dist.init_parallel_env()
- global_config = config['Global']
+ global_config = config["Global"]
# build dataloader
set_signal_handlers()
- config['Train']['loader']['num_workers'] = 0
- is_layoutxlm_ser = config['Architecture']['model_type'] =='kie' and config['Architecture']['Backbone']['name'] == 'LayoutXLMForSer'
- train_dataloader = build_dataloader(config, 'Train', device, logger)
- if config['Eval']:
- config['Eval']['loader']['num_workers'] = 0
- valid_dataloader = build_dataloader(config, 'Eval', device, logger)
+ config["Train"]["loader"]["num_workers"] = 0
+ is_layoutxlm_ser = (
+ config["Architecture"]["model_type"] == "kie"
+ and config["Architecture"]["Backbone"]["name"] == "LayoutXLMForSer"
+ )
+ train_dataloader = build_dataloader(config, "Train", device, logger)
+ if config["Eval"]:
+ config["Eval"]["loader"]["num_workers"] = 0
+ valid_dataloader = build_dataloader(config, "Eval", device, logger)
if is_layoutxlm_ser:
train_dataloader = valid_dataloader
else:
@@ -132,16 +136,18 @@ def main(config, device, logger, vdl_writer):
paddle.enable_static()
exe = paddle.static.Executor(device)
- if 'inference_model' in global_config.keys(): # , 'inference_model'):
- inference_model_dir = global_config['inference_model']
+ if "inference_model" in global_config.keys(): # , 'inference_model'):
+ inference_model_dir = global_config["inference_model"]
else:
- inference_model_dir = os.path.dirname(global_config['pretrained_model'])
- if not (os.path.exists(os.path.join(inference_model_dir, "inference.pdmodel")) and \
- os.path.exists(os.path.join(inference_model_dir, "inference.pdiparams")) ):
+ inference_model_dir = os.path.dirname(global_config["pretrained_model"])
+ if not (
+ os.path.exists(os.path.join(inference_model_dir, "inference.pdmodel"))
+ and os.path.exists(os.path.join(inference_model_dir, "inference.pdiparams"))
+ ):
raise ValueError(
"Please set inference model dir in Global.inference_model or Global.pretrained_model for post-quantization"
)
-
+
if is_layoutxlm_ser:
generator = sample_generator_layoutxlm_ser(train_dataloader)
else:
@@ -150,16 +156,17 @@ def main(config, device, logger, vdl_writer):
paddleslim.quant.quant_post_static(
executor=exe,
model_dir=inference_model_dir,
- model_filename='inference.pdmodel',
- params_filename='inference.pdiparams',
- quantize_model_path=global_config['save_inference_dir'],
+ model_filename="inference.pdmodel",
+ params_filename="inference.pdiparams",
+ quantize_model_path=global_config["save_inference_dir"],
sample_generator=generator,
- save_model_filename='inference.pdmodel',
- save_params_filename='inference.pdiparams',
+ save_model_filename="inference.pdmodel",
+ save_params_filename="inference.pdiparams",
batch_size=1,
- batch_nums=None)
+ batch_nums=None,
+ )
-if __name__ == '__main__':
+if __name__ == "__main__":
config, device, logger, vdl_writer = program.preprocess(is_train=True)
main(config, device, logger, vdl_writer)
diff --git a/doc/doc_ch/PP-OCRv4_introduction.md b/doc/doc_ch/PP-OCRv4_introduction.md
index b3dbe76e8b..a6694af6d9 100644
--- a/doc/doc_ch/PP-OCRv4_introduction.md
+++ b/doc/doc_ch/PP-OCRv4_introduction.md
@@ -176,4 +176,3 @@ GTC(Guided Training of CTC),是PP-OCRv3识别模型的最有效的策略
| PP-OCR_mul | 69.60% | 40.50% | 38.50% | 55.40% |
| PP-OCRv3_mul | 71.57%| 72.90% | 45.85% | 77.23% |
| PP-OCRv4_mul | 80.00%| 75.48% | 56.50% | 83.25% |
-
diff --git a/doc/doc_ch/algorithm_rec-satrn.md b/doc/doc_ch/algorithm_rec-satrn.md
index ec55af03de..f59b61c2d4 100644
--- a/doc/doc_ch/algorithm_rec-satrn.md
+++ b/doc/doc_ch/algorithm_rec-satrn.md
@@ -102,11 +102,11 @@ python3 tools/infer/predict_rec.py --image_dir="./doc/imgs_words/en/word_1.png"
```bibtex
@article{lee2019recognizing,
- title={On Recognizing Texts of Arbitrary Shapes with 2D Self-Attention},
+ title={On Recognizing Texts of Arbitrary Shapes with 2D Self-Attention},
author={Junyeop Lee and Sungrae Park and Jeonghun Baek and Seong Joon Oh and Seonghyeon Kim and Hwalsuk Lee},
year={2019},
eprint={1910.04396},
archivePrefix={arXiv},
primaryClass={cs.CV}
}
-```
\ No newline at end of file
+```
diff --git a/doc/doc_ch/clone.md b/doc/doc_ch/clone.md
index f2ec15fd26..98a63d19a2 100644
--- a/doc/doc_ch/clone.md
+++ b/doc/doc_ch/clone.md
@@ -20,4 +20,3 @@ git clone https://gitee.com/paddlepaddle/PaddleOCR
cd PaddleOCR
pip3 install -r requirements.txt
```
-
diff --git a/doc/doc_ch/code_and_doc.md b/doc/doc_ch/code_and_doc.md
index 43e28d7ab9..ce11015795 100644
--- a/doc/doc_ch/code_and_doc.md
+++ b/doc/doc_ch/code_and_doc.md
@@ -14,14 +14,14 @@
PaddleOCR's Python code follows the [PEP 8 style guide](https://www.python.org/dev/peps/pep-0008/); some of the points to pay particular attention to are listed below
-- Spaces
+- Spaces
- Put spaces after commas, semicolons, and colons, never before them
```python
# Correct:
print(x, y)
-
+
# Wrong:
print(x , y)
```
@@ -53,27 +53,27 @@ PaddleOCR的Python代码遵循 [PEP8规范](https://www.python.org/dev/peps/pep-
```python
def fetch_bigtable_rows(big_table, keys, other_silly_variable=None):
"""Fetches rows from a Bigtable.
-
+
Retrieves rows pertaining to the given keys from the Table instance
represented by big_table. Silly things may happen if
other_silly_variable is not None.
-
+
Args:
big_table: An open Bigtable Table instance.
keys: A sequence of strings representing the key of each table row
to fetch.
other_silly_variable: Another optional variable, that has a much
longer name than the other args, and which does nothing.
-
+
Returns:
A dict mapping keys to the corresponding table row data
fetched. Each row is represented as a tuple of strings. For
example:
-
+
{'Serak': ('Rigel VII', 'Preparer'),
'Zim': ('Irk', 'Invader'),
'Lrrr': ('Omicron Persei 8', 'Emperor')}
-
+
If a key from the keys argument is missing from the dictionary,
then that row was not found in the table.
"""
@@ -92,7 +92,7 @@ PaddleOCR的Python代码遵循 [PEP8规范](https://www.python.org/dev/peps/pep-
- 新增Markdown文档格式:目录 - 正文 - FAQ
- > 目录生成方法可以使用 [此网站](https://ecotrust-canada.github.io/markdown-toc/) 将md内容复制之后自动提取目录,然后在md文件的每个标题前添加 ``
+ > 目录生成方法可以使用 [此网站](https://ecotrust-canada.github.io/markdown-toc/) 将md内容复制之后自动提取目录,然后在md文件的每个标题前添加 ``
- 中英双语:任何对文档的改动或新增都需要分别在中文和英文文档上进行。
@@ -211,7 +211,7 @@ git checkout -b new_branch upstream/dygraph
> ```
> # 基于用户远程仓库(origin)的dygraph创建new_branch分支
> git checkout -b new_branch origin/dygraph
->
+>
> # 基于用户远程仓库(origin)的默认分支创建new_branch分支
> git checkout -b new_branch
> ```
@@ -257,12 +257,12 @@ pre-commit
提交修改,并写明修改内容("your commit info")
```
-git commit -m "your commit info"
+git commit -m "your commit info"
```
#### 3.2.6 Push到远程仓库
-使用push命令将修改的commit提交到 `远程仓库`
+使用push命令将修改的commit提交到 `远程仓库`
```
git push origin new_branch
@@ -299,7 +299,7 @@ git push origin new_branch
```
# 切换到dygraph分支,否则无法删除当前分支
git checkout dygraph
-
+
# 删除new_branch分支
git branch -D new_branch
```
diff --git a/doc/doc_ch/data_synthesis.md b/doc/doc_ch/data_synthesis.md
index 8c43ac2e3b..0b2b95af16 100644
--- a/doc/doc_ch/data_synthesis.md
+++ b/doc/doc_ch/data_synthesis.md
@@ -6,4 +6,4 @@
- [TextRecognitionDataGenerator](https://github.com/Belval/TextRecognitionDataGenerator)
- [SynthText3D](https://github.com/MhLiao/SynthText3D)
- [UnrealText](https://github.com/Jyouhou/UnrealText/)
-- [SynthTIGER](https://github.com/clovaai/synthtiger)
\ No newline at end of file
+- [SynthTIGER](https://github.com/clovaai/synthtiger)
diff --git a/doc/doc_ch/enhanced_ctc_loss.md b/doc/doc_ch/enhanced_ctc_loss.md
index 8c0856a7a7..c85883747e 100644
--- a/doc/doc_ch/enhanced_ctc_loss.md
+++ b/doc/doc_ch/enhanced_ctc_loss.md
@@ -5,53 +5,53 @@
## 1. Focal-CTC Loss
Focal Loss 出自论文《Focal Loss for Dense Object Detection》, 该loss最先提出的时候主要是为了解决one-stage目标检测中正负样本比例严重失衡的问题。该损失函数降低了大量简单负样本在训练中所占的权重,也可理解为一种困难样本挖掘。
其损失函数形式如下:
-
-
+
+
-
-其中, y' 是经过激活函数的输出,取值在0-1之间。其在原始的交叉熵损失的基础上加了一个调制系数(1 – y’)^ γ和平衡因子α。 当α = 1,y=1时,其损失函数与交叉熵损失的对比如下图所示:
-
-
+
+其中, y' 是经过激活函数的输出,取值在0-1之间。其在原始的交叉熵损失的基础上加了一个调制系数(1 – y’)^ γ和平衡因子α。 当α = 1,y=1时,其损失函数与交叉熵损失的对比如下图所示:
+
+
从上图可以看到, 当γ> 0时,调整系数(1-y’)^γ 赋予易分类样本损失一个更小的权重,使得网络更关注于困难的、错分的样本。 调整因子γ用于调节简单样本权重降低的速率,当γ为0时即为交叉熵损失函数,当γ增加时,调整因子的影响也会随之增大。实验发现γ为2是最优。平衡因子α用来平衡正负样本本身的比例不均,文中α取0.25。
对于经典的CTC算法,假设某个特征序列(f 1, f 2, ......f t), 经过CTC解码之后结果等于label的概率为y’, 则CTC解码结果不为label的概率即为(1-y’);不难发现, CTCLoss值和y’有如下关系:
-
-
+
+
结合Focal Loss的思想,赋予困难样本较大的权重,简单样本较小的权重,可以使网络更加聚焦于对困难样本的挖掘,进一步提升识别的准确率,由此我们提出了Focal-CTC Loss; 其定义如下所示:
-
-
+
+
实验中,γ取值为2, α= 1, 具体实现见: [rec_ctc_loss.py](../../ppocr/losses/rec_ctc_loss.py)
## 2. A-CTC Loss
-A-CTC Loss是CTC Loss + ACE Loss的简称。 其中ACE Loss出自论文< Aggregation Cross-Entropy for Sequence Recognition>. ACE Loss相比于CTCLoss,主要有如下两点优势:
+A-CTC Loss是CTC Loss + ACE Loss的简称。 其中ACE Loss出自论文< Aggregation Cross-Entropy for Sequence Recognition>. ACE Loss相比于CTCLoss,主要有如下两点优势:
+ ACE Loss能够解决2-D文本的识别问题; CTCLoss只能够处理1-D文本
+ ACE Loss 在时间复杂度和空间复杂度上优于CTC loss
前人总结的OCR识别算法的优劣如下图所示:
-
+
-
+
虽然ACELoss确实如上图所说,可以处理2D预测,在内存占用及推理速度方面具备优势,但在实践过程中,我们发现单独使用ACE Loss, 识别效果并不如CTCLoss. 因此,我们尝试将CTCLoss和ACELoss进行结合,同时以CTCLoss为主,将ACELoss 定位为一个辅助监督loss。 这一尝试收到了效果,在我们内部的实验数据集上,相比单独使用CTCLoss,识别准确率可以提升1%左右。
A_CTC Loss定义如下:
-
+
实验中,λ = 0.1. ACE loss实现代码见: [ace_loss.py](../../ppocr/losses/ace_loss.py)
## 3. C-CTC Loss
-C-CTC Loss是CTC Loss + Center Loss的简称。 其中Center Loss出自论文 < A Discriminative Feature Learning Approach for Deep Face Recognition>. 最早用于人脸识别任务,用于增大类间距离,减小类内距离, 是Metric Learning领域一种较早的、也比较常用的一种算法。
+C-CTC Loss是CTC Loss + Center Loss的简称。 其中Center Loss出自论文 < A Discriminative Feature Learning Approach for Deep Face Recognition>. 最早用于人脸识别任务,用于增大类间距离,减小类内距离, 是Metric Learning领域一种较早的、也比较常用的一种算法。
在中文OCR识别任务中,通过对badcase分析, 我们发现中文识别的一大难点是相似字符多,容易误识。 由此我们想到是否可以借鉴Metric Learing的想法, 增大相似字符的类间距,从而提高识别准确率。然而,MetricLearning主要用于图像识别领域,训练数据的标签为一个固定的值;而对于OCR识别来说,其本质上是一个序列识别任务,特征和label之间并不具有显式的对齐关系,因此两者如何结合依然是一个值得探索的方向。
通过尝试Arcmargin, Cosmargin等方法, 我们最终发现Centerloss 有助于进一步提升识别的准确率。C_CTC Loss定义如下:
-
+
实验中,我们设置λ=0.25. center_loss实现代码见: [center_loss.py](../../ppocr/losses/center_loss.py)
@@ -60,7 +60,7 @@ C-CTC Loss是CTC Loss + Center Loss的简称。 其中Center Loss出自论文 <
+ 基于原始的CTCLoss, 训练得到一个网络N
+ 挑选出训练集中,识别完全正确的部分, 组成集合G
+ 将G中的每个样本送入网络,进行前向计算, 提取最后一个FC层的输入(即feature)及其经过argmax计算的结果(即index)之间的对应关系
-+ 将相同index的feature进行聚合,计算平均值,得到各自字符的初始center.
++ 将相同index的feature进行聚合,计算平均值,得到各自字符的初始center.
以配置文件`configs/rec/ch_PP-OCRv2/ch_PP-OCRv2_rec.yml`为例, center提取命令如下所示:
```
@@ -72,7 +72,7 @@ python tools/export_center.py -c configs/rec/ch_PP-OCRv2/ch_PP-OCRv2_rec.yml -o
对于上述的三种方案,我们基于百度内部数据集进行了训练、评测,实验情况如下表所示:
|algorithm| Focal_CTC | A_CTC | C-CTC |
|:------| :------| ------: | :------: |
-|gain| +0.3% | +0.7% | +1.7% |
+|gain| +0.3% | +0.7% | +1.7% |
基于上述实验结论,我们在PP-OCRv2中,采用了C-CTC的策略。 值得一提的是,由于PP-OCRv2 处理的是6625个中文字符的识别任务,字符集比较大,形似字较多,所以在该任务上C-CTC 方案带来的提升较大。 但如果换做其他OCR识别任务,结论可能会有所不同。大家可以尝试Focal-CTC,A-CTC, C-CTC以及组合方案EnhancedCTC,相信会带来不同程度的提升效果。
统一的融合方案见如下文件: [rec_enhanced_ctc_loss.py](../../ppocr/losses/rec_enhanced_ctc_loss.py)
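The formula images referenced in this file do not survive the patch, but the relations the surrounding text describes can be reconstructed: Focal Loss is $FL(y') = -\alpha\,(1-y')^{\gamma}\log(y')$, the standard CTC relation gives $y' = e^{-L_{CTC}}$ (the probability that decoding matches the label, since $L_{CTC} = -\ln(y')$), and Focal-CTC is therefore $L_{Focal\text{-}CTC} = \alpha\,(1-y')^{\gamma}\,L_{CTC}$. A minimal Paddle sketch with the stated settings γ=2, α=1 follows; it illustrates the idea and is not the verbatim code of rec_ctc_loss.py:

```python
import paddle

def focal_ctc(ctc_loss, alpha=1.0, gamma=2.0):
    """Re-weight a per-sample CTC loss as described above (sketch).

    y' = exp(-L_ctc) is the probability that decoding equals the label,
    so (1 - y')**gamma down-weights easy, already-correct samples and
    focuses training on hard, misrecognized ones.
    """
    y_prime = paddle.exp(-ctc_loss)
    weight = alpha * (1.0 - y_prime) ** gamma
    return paddle.mean(weight * ctc_loss)
```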
diff --git a/doc/doc_ch/models.md b/doc/doc_ch/models.md
index bd798e05b3..ffe646c167 100644
--- a/doc/doc_ch/models.md
+++ b/doc/doc_ch/models.md
@@ -44,4 +44,3 @@ OCR识别算法的输入数据一般是文本行,背景信息不多,文字
PaddleOCR 中集成了很多OCR算法,文本检测算法有DB、EAST、SAST等等,文本识别算法有CRNN、RARE、StarNet、Rosetta、SRN等算法。
其中PaddleOCR针对中英文自然场景通用OCR,推出了PP-OCR系列模型,PP-OCR模型由DB+CRNN算法组成,利用海量中文数据训练加上模型调优方法,在中文场景上具备较高的文本检测识别能力。并且PaddleOCR推出了高精度超轻量PP-OCRv2模型,检测模型仅3M,识别模型仅8.5M,利用[PaddleSlim](https://github.com/PaddlePaddle/PaddleSlim)的模型量化方法,可以在保持精度不降低的情况下,将检测模型压缩到0.8M,识别压缩到3M,更加适用于移动端部署场景。
-
diff --git a/doc/doc_ch/models_list.md b/doc/doc_ch/models_list.md
index 9b1dc97114..730dcb5f56 100644
--- a/doc/doc_ch/models_list.md
+++ b/doc/doc_ch/models_list.md
@@ -153,4 +153,3 @@ Paddle-Lite 是一个高性能、轻量级、灵活性强且易于扩展的深
|PP-OCRv2(slim)|蒸馏版超轻量中文OCR移动端模型|4.9M|[下载地址](https://paddleocr.bj.bcebos.com/PP-OCRv2/chinese/ch_PP-OCRv2_det_slim_opt.nb)|[下载地址](https://paddleocr.bj.bcebos.com/dygraph_v2.0/lite/ch_ppocr_mobile_v2.0_cls_slim_opt.nb)|[下载地址](https://paddleocr.bj.bcebos.com/PP-OCRv2/chinese/ch_PP-OCRv2_rec_slim_opt.nb)|v2.9|
|V2.0|ppocr_v2.0超轻量中文OCR移动端模型|7.8M|[下载地址](https://paddleocr.bj.bcebos.com/dygraph_v2.0/lite/ch_ppocr_mobile_v2.0_det_opt.nb)|[下载地址](https://paddleocr.bj.bcebos.com/dygraph_v2.0/lite/ch_ppocr_mobile_v2.0_cls_opt.nb)|[下载地址](https://paddleocr.bj.bcebos.com/dygraph_v2.0/lite/ch_ppocr_mobile_v2.0_rec_opt.nb)|v2.9|
|V2.0(slim)|ppocr_v2.0超轻量中文OCR移动端模型|3.3M|[下载地址](https://paddleocr.bj.bcebos.com/dygraph_v2.0/lite/ch_ppocr_mobile_v2.0_det_slim_opt.nb)|[下载地址](https://paddleocr.bj.bcebos.com/dygraph_v2.0/lite/ch_ppocr_mobile_v2.0_cls_slim_opt.nb)|[下载地址](https://paddleocr.bj.bcebos.com/dygraph_v2.0/lite/ch_ppocr_mobile_v2.0_rec_slim_opt.nb)|v2.9|
-
diff --git a/doc/doc_ch/recognition.md b/doc/doc_ch/recognition.md
index acf09f7bde..815d14995c 100644
--- a/doc/doc_ch/recognition.md
+++ b/doc/doc_ch/recognition.md
@@ -229,8 +229,8 @@ python3 -m paddle.distributed.launch --gpus '0,1,2,3' tools/train.py -c configs
log 中自动打印如下信息:
-| 字段 | 含义 |
-| :----: | :------: |
+| 字段 | 含义 |
+| :----: | :------: |
| epoch | 当前迭代轮次 |
| iter | 当前迭代次数 |
| lr | 当前学习率 |
diff --git a/doc/doc_en/algorithm_rec_parseq_en.md b/doc/doc_en/algorithm_rec_parseq_en.md
index a2f8948e5b..618095ec75 100644
--- a/doc/doc_en/algorithm_rec_parseq_en.md
+++ b/doc/doc_en/algorithm_rec_parseq_en.md
@@ -21,7 +21,7 @@ Paper:
> Darwin Bautista, Rowel Atienza
> ECCV, 2021
-Using real datasets (real) and synthetic datsets (synth) for training respectively,and evaluating on IIIT, SVT, IC03, IC13, IC15, SVTP, CUTE datasets.
+Using real datasets (real) and synthetic datasets (synth) for training, respectively, and evaluating on the IIIT, SVT, IC03, IC13, IC15, SVTP, and CUTE datasets.
- The real datasets include COCO-Text, RCTW17, Uber-Text, ArT, LSVT, MLT19, ReCTS, TextOCR and OpenVINO datasets.
- The synthetic datasets include the MJSynth and SynthText datasets.
diff --git a/doc/doc_en/algorithm_rec_satrn_en.md b/doc/doc_en/algorithm_rec_satrn_en.md
index b369608638..acc95ac035 100644
--- a/doc/doc_en/algorithm_rec_satrn_en.md
+++ b/doc/doc_en/algorithm_rec_satrn_en.md
@@ -101,11 +101,11 @@ Not supported
```bibtex
@article{lee2019recognizing,
- title={On Recognizing Texts of Arbitrary Shapes with 2D Self-Attention},
+ title={On Recognizing Texts of Arbitrary Shapes with 2D Self-Attention},
author={Junyeop Lee and Sungrae Park and Jeonghun Baek and Seong Joon Oh and Seonghyeon Kim and Hwalsuk Lee},
year={2019},
eprint={1910.04396},
archivePrefix={arXiv},
primaryClass={cs.CV}
}
-```
\ No newline at end of file
+```
diff --git a/doc/doc_en/algorithm_rec_starnet.md b/doc/doc_en/algorithm_rec_starnet.md
index dbb53a9c73..45802c8939 100644
--- a/doc/doc_en/algorithm_rec_starnet.md
+++ b/doc/doc_en/algorithm_rec_starnet.md
@@ -135,5 +135,3 @@ The STAR-Net model also supports the following inference deployment methods:
year={2016}
}
```
-
-
diff --git a/doc/doc_en/benchmark_en.md b/doc/doc_en/benchmark_en.md
index 70b33aebd9..62ac0258a9 100755
--- a/doc/doc_en/benchmark_en.md
+++ b/doc/doc_en/benchmark_en.md
@@ -1,8 +1,8 @@
-# Benchmark
+# Benchmark
This document gives the performance of the series models for Chinese and English recognition.
-## Test Data
+## Test Data
We collected 300 images for different real application scenarios to evaluate the overall OCR system, including contract samples, license plates, nameplates, train tickets, test sheets, forms, certificates, street view images, business cards, digital meter, etc. The following figure shows some images of the test set.
diff --git a/doc/doc_en/code_and_doc.md b/doc/doc_en/code_and_doc.md
index f3ee769e7d..d4a7cf4835 100644
--- a/doc/doc_en/code_and_doc.md
+++ b/doc/doc_en/code_and_doc.md
@@ -14,14 +14,14 @@
The Python code of PaddleOCR follows [PEP8 Specification]( https://www.python.org/dev/peps/pep-0008/ ), some of the key concerns include the following
- - Space
+ - Space
- Spaces should be added after commas, semicolons, colons, not before them
```python
# true:
print(x, y)
-
+
# false:
print(x , y)
```
@@ -53,27 +53,27 @@
```python
def fetch_bigtable_rows(big_table, keys, other_silly_variable=None):
"""Fetches rows from a Bigtable.
-
+
Retrieves rows pertaining to the given keys from the Table instance
represented by big_table. Silly things may happen if
other_silly_variable is not None.
-
+
Args:
big_table: An open Bigtable Table instance.
keys: A sequence of strings representing the key of each table row
to fetch.
other_silly_variable: Another optional variable, that has a much
longer name than the other args, and which does nothing.
-
+
Returns:
A dict mapping keys to the corresponding table row data
fetched. Each row is represented as a tuple of strings. For
example:
-
+
{'Serak': ('Rigel VII', 'Preparer'),
'Zim': ('Irk', 'Invader'),
'Lrrr': ('Omicron Persei 8', 'Emperor')}
-
+
If a key from the keys argument is missing from the dictionary,
then that row was not found in the table.
"""
@@ -166,23 +166,23 @@
```
Only the information of the clone `remote repo`, i.e. the PaddleOCR under your username, is available. Due to the change in Github's login method, you need to reconfigure the `remote repo` address by means of a Token. The token is generated as follows:
-
+
1. Find Personal Access Tokens: Click on your avatar in the upper right corner of the Github page and choose Settings --> Developer settings --> Personal access tokens,
-
+
2. Click Generate new token: Fill in the token name in Note, such as 'paddle'. In Select scopes, select repo (required), admin:repo_hook, delete_repo, etc. You can check them according to your needs. Then click Generate token to generate the token, and finally copy the generated token.
Delete the original origin configuration
-
+
```
git remote rm origin
```
-
+
Change the remote branch to `https://oauth2:{token}@github.com/{your_name}/PaddleOCR.git`. For example, if the token value is 12345 and your user name is PPOCR, run the following command
-
+
```
git remote add origin https://oauth2:12345@github.com/PPOCR/PaddleOCR.git
```
-
+
This establishes a connection to our own `remote repo`. Next we create a remote host of the original PaddleOCR repo, named upstream.
```
@@ -203,19 +203,19 @@
#### 3.2.3 Create Local Branch
First get the latest code of upstream, then create a new_branch branch based on the dygraph of the upstream repo (upstream).
-
+
```
git fetch upstream
git checkout -b new_branch upstream/dygraph
```
-
+
> If for a newly forked PaddleOCR project, the user's remote repo (origin) has the same branch updates as the upstream repository (upstream), you can also create a new local branch based on the default branch of the origin repo or a specified branch with the following command
>
> ```
> # Create new_branch branch on user remote repo (origin) based on develop branch
> git checkout -b new_branch origin/develop
> # Create new_branch branch based on upstream remote repo develop branch
- > # If you need to create a new branch from upstream,
+ > # If you need to create a new branch from upstream,
> # you need to first use git fetch upstream to get upstream code
> git checkout -b new_branch upstream/develop
> ```
@@ -226,9 +226,9 @@
Branch new_branch set up to track remote branch develop from upstream.
Switched to a new branch 'new_branch'
```
-
+
After switching branches, file changes can be made on this branch
-
+
#### 3.2.4 Use Pre-Commit Hook
Paddle developers use the pre-commit tool to manage Git pre-submit hooks. It helps us format the source code (C++, Python) and automatically check for basic things (such as having only one EOL per file, not adding large files to Git) before committing it.
@@ -310,7 +310,7 @@
```
# Switch to the development branch, otherwise the current branch cannot be deleted
git checkout develop
-
+
# Delete the new_branch branch
git branch -D new_branch
```
@@ -322,28 +322,28 @@
In order for official maintainers to better focus on the code itself when reviewing it, please follow the following conventions each time you submit your code:
1)Please ensure that the unit tests in Travis-CI pass smoothly. If not, indicate that there is a problem with the submitted code, and the official maintainer generally does not review it.
-
+
2)Before submitting a Pull Request.
-
+
- Note the number of commits.
Reason: If you only modify one file and submit more than a dozen commits, each commit will only make a few modifications, which can be very confusing to the reviewer. The reviewer needs to look at each commit individually to see what changes have been made, and does not exclude the fact that changes between commits overlap each other.
-
+
Suggestion: Keep as few commits as possible each time you submit, and supplement your last commit with git commit --amend. For multiple commits that have been Push to a remote warehouse, you can refer to [squash commits after push](https://stackoverflow.com/questions/5667884/how-to-squash-commits-in-git-after-they-have-been-pushed ).
- Note the name of each commit: it should reflect the content of the current commit, not be too arbitrary.
3) If you have solved a problem, add in the first comment box of the Pull Request:fix #issue_number,This will automatically close the corresponding Issue when the Pull Request is merged. Key words include:close, closes, closed, fix, fixes, fixed, resolve, resolves, resolved,please choose the right vocabulary. Detailed reference [Closing issues via commit messages](https://help.github.com/articles/closing-issues-via-commit-messages).
-
+
In addition, in response to the reviewer's comments, you are requested to abide by the following conventions:
-
+
1) Each review comment from an official maintainer would like a response, which would better enhance the contribution of the open source community.
-
+
- If you agree to the review opinion and modify it accordingly, give a simple Done.
- If you disagree with the review, please give your own reasons for refuting.
-
+
2)If there are many reviews:
-
+
- Please give an overview of the changes.
   - Please reply with `start a review`, not directly. The reason is that each reply sends an e-mail, which can quickly flood reviewers' inboxes.
diff --git a/doc/doc_en/community_contribution_en.md b/doc/doc_en/community_contribution_en.md
index 43ce20c6d2..e2d24f6144 100644
--- a/doc/doc_en/community_contribution_en.md
+++ b/doc/doc_en/community_contribution_en.md
@@ -12,7 +12,7 @@ PaddleOCR wants to help any developer with a dream realize their vision and enjo
> The picture above shows PaddleOCR's current Contributor, updated regularly
-## 1. COMMUNITY CONTRIBUTION
+## 1. COMMUNITY CONTRIBUTION
### 1.1 PaddleOCR BASED COMMUNITY PROJECT
@@ -60,7 +60,7 @@ PaddleOCR welcomes community contributions to various services, deployment examp
- Project form: the project code certified by the official community shall have good specifications and structure, and shall be equipped with a detailed README.md, which describes how to use the project. Through add a line 'paddleocr' to the requirements.txt, which can be automatically included in the usedby of paddleocr.
-- Integration method: if it is an update to the existing PaddleOCR tool, it will be integrated into the main repo. If a new function is expanded for paddleocr, please contact the official personnel first to confirm whether the project is integrated into the master repo, *even if the new function is not integrated into the master repo, we will also increase the exposure of your personal project in the way of community contribution.*
+- Integration method: if it is an update to an existing PaddleOCR tool, it will be integrated into the main repo. If it adds new functionality to PaddleOCR, please contact the official team first to confirm whether the project will be integrated into the main repo. *Even if the new functionality is not integrated into the main repo, we will still increase the exposure of your personal project as a community contribution.*
### 2.2 CODE OPTIMIZATION
diff --git a/doc/doc_en/config_en.md b/doc/doc_en/config_en.md
index d467a7f918..b7c828f071 100644
--- a/doc/doc_en/config_en.md
+++ b/doc/doc_en/config_en.md
@@ -135,7 +135,7 @@ In PaddleOCR, the network is divided into four stages: Transform, Backbone, Neck
| Parameter | Use | Defaults | Note |
| :---------------------: | :---------------------: | :--------------: | :--------------------: |
| project | Project to which the run is to be logged | uncategorized | \
-| name | Alias/Name of the run | Randomly generated by wandb | \
+| name | Alias/Name of the run | Randomly generated by wandb | \
| id | ID of the run | Randomly generated by wandb | \
| entity | User or team to which the run is being logged | The logged in user | \
| save_dir | local directory in which all the models and other data is saved | wandb | \
@@ -245,4 +245,4 @@ For more supported languages, please refer to : [Multi-language model](https://g
The multi-language model training method is the same as the Chinese model. The training data set is 100w synthetic data. A small amount of fonts and test data can be downloaded using the following two methods.
* [Baidu Netdisk](https://pan.baidu.com/s/1bS_u207Rm7YbY33wOECKDA),Extraction code:frgi.
-* [Google drive](https://drive.google.com/file/d/18cSWX7wXSy4G0tbKJ0d9PuIaiwRLHpjA/view)
\ No newline at end of file
+* [Google drive](https://drive.google.com/file/d/18cSWX7wXSy4G0tbKJ0d9PuIaiwRLHpjA/view)
diff --git a/doc/doc_en/data_synthesis_en.md b/doc/doc_en/data_synthesis_en.md
index ee58fc680d..f05a793cf5 100644
--- a/doc/doc_en/data_synthesis_en.md
+++ b/doc/doc_en/data_synthesis_en.md
@@ -9,4 +9,4 @@ There are the commonly used data synthesis tools, which will be continuously upd
* [TextRecognitionDataGenerator](https://github.com/Belval/TextRecognitionDataGenerator)
* [SynthText3D](https://github.com/MhLiao/SynthText3D)
* [UnrealText](https://github.com/Jyouhou/UnrealText/)
-* [SynthTIGER](https://github.com/clovaai/synthtiger)
\ No newline at end of file
+* [SynthTIGER](https://github.com/clovaai/synthtiger)
diff --git a/doc/doc_en/enhanced_ctc_loss_en.md b/doc/doc_en/enhanced_ctc_loss_en.md
index 908f79e412..72cada714e 100644
--- a/doc/doc_en/enhanced_ctc_loss_en.md
+++ b/doc/doc_en/enhanced_ctc_loss_en.md
@@ -7,14 +7,14 @@ In OCR recognition, CRNN is a text recognition algorithm widely applied in the i
Focal Loss was proposed by the paper, "[Focal Loss for Dense Object Detection](https://arxiv.org/abs/1708.02002)". When the loss was first proposed, it was mainly to solve the problem of a serious imbalance in the ratio of positive and negative samples in one-stage target detection. This loss function reduces the weight of a large number of simple negative samples in training and also can be understood as a kind of difficult sample mining.
The form of the loss function is as follows:
-
-
+
+
Among them, y' is the output of the activation function, and the value is between 0-1. It adds a modulation factor (1-y’)^γ and a balance factor α on the basis of the original cross-entropy loss. When α = 1, y = 1, the comparison between the loss function and the cross-entropy loss is shown in the following figure:
-
-
+
+
@@ -23,16 +23,16 @@ As can be seen from the above figure, when γ > 0, the adjustment coeffici
For the classic CTC algorithm, suppose a certain feature sequence (f 1, f 2, ......f t), after CTC decoding, the probability that the result is equal to label is y', then the probability that the CTC decoding result is not equal to label is (1-y'); it is not difficult to find that the CTCLoss value and y' have the following relationship:
-
-
+
+
Combining the idea of Focal Loss, assigning larger weights to difficult samples and smaller weights to simple samples can make the network focus more on the mining of difficult samples and further improve the accuracy of recognition. Therefore, we propose Focal-CTC Loss. Its definition is as follows:
-
-
+
+
@@ -50,7 +50,7 @@ A-CTC Loss is short for CTC Loss + ACE Loss. Among them, ACE Loss was proposed b
The advantages and disadvantages of the OCR recognition algorithm summarized by the predecessors are shown in the following figure:
-
+
@@ -58,7 +58,7 @@ Although ACELoss does handle 2D predictions, as shown in the figure above, and h
A_CTC Loss is defined as follows:
-
+
@@ -76,7 +76,7 @@ In the task of Chinese OCR recognition, through the analysis of bad cases, we fo
By trying Arcmargin, Cosmargin and other methods, we finally found that Centerloss can help further improve the accuracy of recognition. C_CTC Loss is defined as follows:
-
+
In the experiment, we set λ=0.25. See the center_loss implementation code: [center_loss.py](../../ppocr/losses/center_loss.py)
@@ -107,4 +107,4 @@ For the above three solutions, we conducted training and evaluation based on Bai
Based on the above experimental conclusions, we adopted the C-CTC strategy in PP-OCRv2. It is worth mentioning that, because PP-OCRv2 deals with the recognition task of 6625 Chinese characters, the character set is relatively large and there are many similar characters, so the C-CTC solution brings a significant improvement on this task. But if you switch to other OCR recognition tasks, the conclusion may be different. You can try Focal-CTC, A-CTC, C-CTC, and the combined solution EnhancedCTC. We believe it will bring different degrees of improvement.
-The unified combined plan is shown in the following file: [rec_enhanced_ctc_loss.py](../../ppocr/losses/rec_enhanced_ctc_loss.py)
\ No newline at end of file
+The unified combined plan is shown in the following file: [rec_enhanced_ctc_loss.py](../../ppocr/losses/rec_enhanced_ctc_loss.py)
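To make the three schemes above concrete, the weighted sums reduce to one-liners with the λ values quoted in this document (λ=0.1 for A-CTC, λ=0.25 for C-CTC). In this sketch `ctc`, `ace`, and `center` are assumed to be scalar losses already computed by the corresponding modules under `ppocr/losses`; the function names are illustrative:

```python
# Sketch only: the combined losses described above, with the lambda
# values stated in this document.
def a_ctc_loss(ctc, ace, lam=0.1):
    """A-CTC: CTC as the main loss, ACE as auxiliary supervision."""
    return ctc + lam * ace

def c_ctc_loss(ctc, center, lam=0.25):
    """C-CTC: CTC plus a Center-loss term that pushes apart features
    of visually similar characters."""
    return ctc + lam * center
```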
diff --git a/doc/doc_en/environment_en.md b/doc/doc_en/environment_en.md
index 6521d3c414..453287385f 100644
--- a/doc/doc_en/environment_en.md
+++ b/doc/doc_en/environment_en.md
@@ -1,6 +1,6 @@
# Environment Preparation
-Windows and Mac users are recommended to use Anaconda to build a Python environment, and Linux users are recommended to use docker to build a Python environment.
+Windows and Mac users are recommended to use Anaconda to build a Python environment, and Linux users are recommended to use docker to build a Python environment.
Recommended working environment:
- PaddlePaddle >= 2.1.2
diff --git a/doc/doc_en/logging_en.md b/doc/doc_en/logging_en.md
index d00ab8bd56..10c400c084 100644
--- a/doc/doc_en/logging_en.md
+++ b/doc/doc_en/logging_en.md
@@ -1,6 +1,6 @@
-## Logging metrics and models
+## Logging metrics and models
-PaddleOCR comes with two metric logging tools integrated directly into the training API: [VisualDL](https://readthedocs.org/projects/visualdl/) and [Weights & Biases](https://docs.wandb.ai/).
+PaddleOCR comes with two metric logging tools integrated directly into the training API: [VisualDL](https://readthedocs.org/projects/visualdl/) and [Weights & Biases](https://docs.wandb.ai/).
### VisualDL
VisualDL is a visualization analysis tool of PaddlePaddle. The integration allows all training metrics to be logged to a VisualDL dashboard. To use it, add the following line to the `Global` section of the config yaml file -
@@ -35,7 +35,7 @@ Global:
use_wandb: True
```
-To add more arguments to the `WandbLogger` listed [here](./config_en.md) add the header `wandb` to the yaml file and add the arguments under it -
+To add more of the `WandbLogger` arguments listed [here](./config_en.md), add the header `wandb` to the yaml file and place the arguments under it -
```
wandb:
@@ -58,4 +58,4 @@ For more advanced usage to log images, audios, videos or any other form of data,
To view the dashboard, the link to the dashboard is printed to the console at the beginning and end of every training job and you can also access it by logging into your W&B account on your browser.
### Using Multiple Loggers
-Both VisualDL and W&B can also be used simultaneously by just setting both the aforementioned flags to True.
\ No newline at end of file
+Both VisualDL and W&B can also be used simultaneously by just setting both the aforementioned flags to True.
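Putting the fragments above together, a config that enables W&B and passes extra `WandbLogger` arguments could look like the following; the `project`, `entity`, and `save_dir` values are placeholders, and the keys come from the parameter table in config_en.md:

```yaml
Global:
  use_wandb: True

wandb:
  project: my_ocr_project   # placeholder: W&B project name
  entity: my_team           # placeholder: user or team to log runs under
  save_dir: ./wandb_logs    # local directory for models and run data
```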
diff --git a/doc/doc_en/ocr_book_en.md b/doc/doc_en/ocr_book_en.md
index 63162be566..ec2b65529e 100644
--- a/doc/doc_en/ocr_book_en.md
+++ b/doc/doc_en/ocr_book_en.md
@@ -27,4 +27,3 @@
-
diff --git a/doc/doc_en/reference_en.md b/doc/doc_en/reference_en.md
index 55c56f5617..066a86a16c 100644
--- a/doc/doc_en/reference_en.md
+++ b/doc/doc_en/reference_en.md
@@ -1,55 +1,55 @@
-# REFERENCE
-
-```
-1. EAST:
-@inproceedings{zhou2017east,
- title={EAST: an efficient and accurate scene text detector},
- author={Zhou, Xinyu and Yao, Cong and Wen, He and Wang, Yuzhi and Zhou, Shuchang and He, Weiran and Liang, Jiajun},
- booktitle={Proceedings of the IEEE conference on Computer Vision and Pattern Recognition},
- pages={5551--5560},
- year={2017}
-}
-
-2. DB:
-@article{liao2019real,
- title={Real-time Scene Text Detection with Differentiable Binarization},
- author={Liao, Minghui and Wan, Zhaoyi and Yao, Cong and Chen, Kai and Bai, Xiang},
- journal={arXiv preprint arXiv:1911.08947},
- year={2019}
-}
-
-3. DTRB:
-@inproceedings{baek2019wrong,
- title={What is wrong with scene text recognition model comparisons? dataset and model analysis},
- author={Baek, Jeonghun and Kim, Geewook and Lee, Junyeop and Park, Sungrae and Han, Dongyoon and Yun, Sangdoo and Oh, Seong Joon and Lee, Hwalsuk},
- booktitle={Proceedings of the IEEE International Conference on Computer Vision},
- pages={4715--4723},
- year={2019}
-}
-
-4. SAST:
-@inproceedings{wang2019single,
- title={A Single-Shot Arbitrarily-Shaped Text Detector based on Context Attended Multi-Task Learning},
- author={Wang, Pengfei and Zhang, Chengquan and Qi, Fei and Huang, Zuming and En, Mengyi and Han, Junyu and Liu, Jingtuo and Ding, Errui and Shi, Guangming},
- booktitle={Proceedings of the 27th ACM International Conference on Multimedia},
- pages={1277--1285},
- year={2019}
-}
-
-5. SRN:
-@article{yu2020towards,
- title={Towards Accurate Scene Text Recognition with Semantic Reasoning Networks},
- author={Yu, Deli and Li, Xuan and Zhang, Chengquan and Han, Junyu and Liu, Jingtuo and Ding, Errui},
- journal={arXiv preprint arXiv:2003.12294},
- year={2020}
-}
-
-6. end2end-psl:
-@inproceedings{sun2019chinese,
- title={Chinese Street View Text: Large-scale Chinese Text Reading with Partially Supervised Learning},
- author={Sun, Yipeng and Liu, Jiaming and Liu, Wei and Han, Junyu and Ding, Errui and Liu, Jingtuo},
- booktitle={Proceedings of the IEEE International Conference on Computer Vision},
- pages={9086--9095},
- year={2019}
-}
-```
\ No newline at end of file
+# REFERENCE
+
+```
+1. EAST:
+@inproceedings{zhou2017east,
+ title={EAST: an efficient and accurate scene text detector},
+ author={Zhou, Xinyu and Yao, Cong and Wen, He and Wang, Yuzhi and Zhou, Shuchang and He, Weiran and Liang, Jiajun},
+ booktitle={Proceedings of the IEEE conference on Computer Vision and Pattern Recognition},
+ pages={5551--5560},
+ year={2017}
+}
+
+2. DB:
+@article{liao2019real,
+ title={Real-time Scene Text Detection with Differentiable Binarization},
+ author={Liao, Minghui and Wan, Zhaoyi and Yao, Cong and Chen, Kai and Bai, Xiang},
+ journal={arXiv preprint arXiv:1911.08947},
+ year={2019}
+}
+
+3. DTRB:
+@inproceedings{baek2019wrong,
+ title={What is wrong with scene text recognition model comparisons? dataset and model analysis},
+ author={Baek, Jeonghun and Kim, Geewook and Lee, Junyeop and Park, Sungrae and Han, Dongyoon and Yun, Sangdoo and Oh, Seong Joon and Lee, Hwalsuk},
+ booktitle={Proceedings of the IEEE International Conference on Computer Vision},
+ pages={4715--4723},
+ year={2019}
+}
+
+4. SAST:
+@inproceedings{wang2019single,
+ title={A Single-Shot Arbitrarily-Shaped Text Detector based on Context Attended Multi-Task Learning},
+ author={Wang, Pengfei and Zhang, Chengquan and Qi, Fei and Huang, Zuming and En, Mengyi and Han, Junyu and Liu, Jingtuo and Ding, Errui and Shi, Guangming},
+ booktitle={Proceedings of the 27th ACM International Conference on Multimedia},
+ pages={1277--1285},
+ year={2019}
+}
+
+5. SRN:
+@article{yu2020towards,
+ title={Towards Accurate Scene Text Recognition with Semantic Reasoning Networks},
+ author={Yu, Deli and Li, Xuan and Zhang, Chengquan and Han, Junyu and Liu, Jingtuo and Ding, Errui},
+ journal={arXiv preprint arXiv:2003.12294},
+ year={2020}
+}
+
+6. end2end-psl:
+@inproceedings{sun2019chinese,
+ title={Chinese Street View Text: Large-scale Chinese Text Reading with Partially Supervised Learning},
+ author={Sun, Yipeng and Liu, Jiaming and Liu, Wei and Han, Junyu and Ding, Errui and Liu, Jingtuo},
+ booktitle={Proceedings of the IEEE International Conference on Computer Vision},
+ pages={9086--9095},
+ year={2019}
+}
+```
diff --git "a/doc/doc_i18n/README_\320\240\321\203\314\201\321\201\321\201\320\272\320\270\320\271_\321\217\320\267\321\213\314\201\320\272.md" "b/doc/doc_i18n/README_\320\240\321\203\314\201\321\201\321\201\320\272\320\270\320\271_\321\217\320\267\321\213\314\201\320\272.md"
index 0b3a59c5c7..a829b3af23 100644
--- "a/doc/doc_i18n/README_\320\240\321\203\314\201\321\201\321\201\320\272\320\270\320\271_\321\217\320\267\321\213\314\201\320\272.md"
+++ "b/doc/doc_i18n/README_\320\240\321\203\314\201\321\201\321\201\320\272\320\270\320\271_\321\217\320\267\321\213\314\201\320\272.md"
@@ -43,7 +43,7 @@ PaddleOCR стремится создавать многоязычные, пот
- **🔥2022.7 Выпуск [Коллекция приложений сцены OCR](../../applications/README_en.md)**
- Выпуск **9 вертикальных моделей**, таких как цифровая трубка, ЖК-экран, номерной знак, модель распознавания рукописного ввода, высокоточная модель SVTR и т. д., охватывающих основные вертикальные приложения OCR в целом, производственной, финансовой и транспортной отраслях.
- **🔥2022.5.9 Выпуск PaddleOCR [Выпуск /2.5](https://github.com/PaddlePaddle/PaddleOCR/tree/release/2.5)**
-- Выпускать [PP-OCRv3](../doc_en/ppocr_introduction_en.md#pp-ocrv3): При сопоставимой скорости эффект китайской сцены улучшен на 5% по сравнению с ПП-OCRRv2, эффект английской сцены улучшен на 11%, а средняя точность распознавания 80 языковых многоязычных моделей улучшена более чем на 5%.
+- Выпускать [PP-OCRv3](../doc_en/ppocr_introduction_en.md#pp-ocrv3): При сопоставимой скорости эффект китайской сцены улучшен на 5% по сравнению с PP-OCRv2, эффект английской сцены улучшен на 11%, а средняя точность распознавания 80 языковых многоязычных моделей улучшена более чем на 5%.
- Выпускать [PPOCRLabelv2](./PPOCRLabel): Добавьте функцию аннотации для задачи распознавания таблиц, задачи извлечения ключевой информации и изображения неправильного текста.
- Выпустить интерактивную электронную книгу [*"Погружение в OCR"*](../doc_en/ocr_book_en.md), охватывает передовую теорию и практику кодирования технологии полного стека OCR.
- [подробнее](../doc_en/update_en.md)
@@ -77,7 +77,7 @@ paddleocr --image_dir /your/test/image.jpg --lang=ru
-## 👫 Сообщество
+## 👫 Сообщество
Что касается международных разработчиков, мы рассматриваем [Обсуждения PaddleOCR](https://github.com/PaddlePaddle/PaddleOCR/discussions) как нашу платформу для международного сообщества. Все идеи и вопросы можно обсудить здесь на английском языке.
@@ -97,12 +97,12 @@ paddleocr --image_dir /your/test/image.jpg --lang=ru
-## 📖 Учебники
+## 📖 Учебники
- [Подготовка окружающей среды](../doc_en/environment_en.md)
- [PP-OCR 🔥](../doc_en/ppocr_introduction_en.md)
-
- - [Быстрый старт](../doc_en/quickstart_en.md)
+
+ - [Быстрый старт](../doc_en/quickstart_en.md)
 - [Модель Zoo](../doc_en/models_en.md)
- [Модель тренировки](../doc_en/training_en.md)
- [Обнаружение текста](../doc_en/detection_en.md)
@@ -136,7 +136,7 @@ paddleocr --image_dir /your/test/image.jpg --lang=ru
- [Распознавание текста](../doc_en/algorithm_overview_en.md)
- [Непрерывной цепью OCR](../doc_en/algorithm_overview_en.md)
- [Распознавание таблиц](../doc_en/algorithm_overview_en.md)
- - [Извлечение ключевой информации](../doc_en/algorithm_overview_en.md)
+ - [Извлечение ключевой информации](../doc_en/algorithm_overview_en.md)
- [Добавьте новые алгоритмы в PaddleOCR](../doc_en/add_new_algorithm_en.md)
- Аннотации и синтез данных
   - [Полуавтоматический инструмент аннотации данных: PPOCRLabel](./PPOCRLabel/README.md)
@@ -170,7 +170,7 @@ paddleocr --image_dir /your/test/image.jpg --lang=ru
-## 👀 Визуализация [больше](../doc_en/visualization_en.md)
+## 👀 Визуализация [больше](../doc_en/visualization_en.md)
PP-OCRv3 Многоязычная модель
@@ -215,7 +215,7 @@ paddleocr --image_dir /your/test/image.jpg --lang=ru
3. RE (Извлечение отношений)
-
+
@@ -226,5 +226,5 @@ paddleocr --image_dir /your/test/image.jpg --lang=ru
-## 📄 Лицензия
+## 📄 Лицензия
Этот проект выпущен под Apache 2.0 license
diff --git "a/doc/doc_i18n/README_\340\244\271\340\244\277\340\244\250\340\245\215\340\244\246.md" "b/doc/doc_i18n/README_\340\244\271\340\244\277\340\244\250\340\245\215\340\244\246.md"
index ef8a22f289..a93212cd6a 100644
--- "a/doc/doc_i18n/README_\340\244\271\340\244\277\340\244\250\340\245\215\340\244\246.md"
+++ "b/doc/doc_i18n/README_\340\244\271\340\244\277\340\244\250\340\245\215\340\244\246.md"
@@ -49,7 +49,7 @@
- [और अधिक](../doc_en/update_en.md)
-## 🌟 विशेषताएँ
+## 🌟 विशेषताएँ
Paddleओसीआर से संबंधित विभिन्न प्रकार के अत्याधुनिक एल्गोरिथ्म को सपोर्ट करता है, और विकसित औद्योगिक विशेष रुप से प्रदर्शित मॉडल/समाधान [PP- OCR](../doc_en/ppocr_introduction_en.md) और [PP-Structure](../../ppstructure/README.md) इस आधार पर और डेटा प्रोडक्शन की पूरी प्रोसेस के माध्यम से प्राप्त करें, मॉडल ट्रेनिंग, दबाव, अनुमान और तैनाती।
@@ -57,7 +57,7 @@ Paddleओसीआर से संबंधित विभिन्न प्
-## ⚡ शीघ्र अनुभव
+## ⚡ शीघ्र अनुभव
```bash
pip3 install paddlepaddle # for gpu user please install paddlepaddle-gpu
@@ -105,7 +105,7 @@ paddleocr --image_dir /your/test/image.jpg --lang=hi
- [टेक्सट डिटेक्शन](../doc_en/detection_en.md)
- [टेक्सट रिकोगनीशन](../doc_en/recognition_en.md)
- [टेक्सट डायरेक्शन क्लासिफिकेशन](../doc_en/angle_class_en.md)
- - मॉडल कम्प्रेशन
+ - मॉडल कम्प्रेशन
- [मॉडल परिमाणीकरण](./deploy/slim/quantization/README_en.md)
- [मॉडल प्रूनिंग](./deploy/slim/prune/README_en.md)
- [ज्ञान आसवन](../doc_en/knowledge_distillation_en.md)
@@ -133,7 +133,7 @@ paddleocr --image_dir /your/test/image.jpg --lang=hi
- [टेक्स्ट रिकोगनाइजेशन](../doc_en/algorithm_overview_en.md)
- [एंड-टू-एंड ओसीआर](../doc_en/algorithm_overview_en.md)
- [टेबल रिकोगनाइजेशन](../doc_en/algorithm_overview_en.md)
- - [की इंफॉर्मेशन एक्स्ट्रेक्शन](../doc_en/algorithm_overview_en.md)
+ - [की इंफॉर्मेशन एक्स्ट्रेक्शन](../doc_en/algorithm_overview_en.md)
- [पैडलओसीआर में नए एल्गोरिदम जोड़ें](../doc_en/add_new_algorithm_en.md)
- डेटा एनोटेशन और सिंथेसिस
- [सेमी-ऑटोमैटिक एनोटेशन टूल: PPओसीआरलेबल](./PPOCRLabel/README.md)
@@ -143,7 +143,7 @@ paddleocr --image_dir /your/test/image.jpg --lang=hi
- डेटा सेट
- [सामान्य ओसीआर डेटासेट (चीनी/अंग्रेज़ी)](../doc_en/dataset/datasets_en.md)
- [हस्तलिखित_ओसीआर_डेटासेट (चीनी)](../doc_en/dataset/handwritten_datasets_en.md)
- - [विभिन्न ओसीआर
+ - [विभिन्न ओसीआर
डेटासेट (बहुभाषी)](../doc_en/dataset/vertical_and_multilingual_datasets_en.md)
- [लेआउट एनालाइस](../doc_en/dataset/layout_datasets_en.md)
- [टेबल रिकोगनाइजेशन](../doc_en/dataset/table_datasets_en.md)
@@ -213,7 +213,7 @@ paddleocr --image_dir /your/test/image.jpg --lang=hi
3. RE (रिलेशन एक्सट्रैक्शन)
-
+
diff --git "a/doc/doc_i18n/README_\346\227\245\346\234\254\350\252\236.md" "b/doc/doc_i18n/README_\346\227\245\346\234\254\350\252\236.md"
index a75003ecfe..db2b7f9420 100644
--- "a/doc/doc_i18n/README_\346\227\245\346\234\254\350\252\236.md"
+++ "b/doc/doc_i18n/README_\346\227\245\346\234\254\350\252\236.md"
@@ -133,7 +133,7 @@ paddleocr --image_dir /your/test/image.jpg --lang=japan # change for i18n abbr
- [テキスト認識](../doc_en/algorithm_overview_en.md)
- [エンド・ツー・エンド OCR](../doc_en/algorithm_overview_en.md)
- [表認識](../doc_en/algorithm_overview_en.md)
- - [キー情報抽出](../doc_en/algorithm_overview_en.md)
+ - [キー情報抽出](../doc_en/algorithm_overview_en.md)
- [PaddleOCR に新しいアルゴリズムを追加する](../doc_en/add_new_algorithm_en.md)
- データの注釈と合成
- [半自動注釈ツール: PPOCRLabel](./PPOCRLabel/README.md)
@@ -212,7 +212,7 @@ paddleocr --image_dir /your/test/image.jpg --lang=japan # change for i18n abbr
3. RE (関係抽出)
-
+
diff --git "a/doc/doc_i18n/README_\355\225\234\352\265\255\354\226\264.md" "b/doc/doc_i18n/README_\355\225\234\352\265\255\354\226\264.md"
index 7ec3072404..46c8a398f6 100644
--- "a/doc/doc_i18n/README_\355\225\234\352\265\255\354\226\264.md"
+++ "b/doc/doc_i18n/README_\355\225\234\352\265\255\354\226\264.md"
@@ -15,7 +15,7 @@
## 소개
-PaddleOCR은 사용자들이 보다 나은 모델을 훈련하여 실전에 투입하는데 도움을 주는 다중 언어로 된 엄청나게 멋지고 주도적이며 실용적인 OCR 툴을 만드는데 목표를 두고 있습니다.
+PaddleOCR은 사용자들이 보다 나은 모델을 훈련하여 실전에 투입하는데 도움을 주는 다중 언어로 된 엄청나게 멋지고 주도적이며 실용적인 OCR 툴을 만드는데 목표를 두고 있습니다.
@@ -35,12 +35,12 @@ PaddleOCR은 사용자들이 보다 나은 모델을 훈련하여 실전에 투
 - [레이아웃 분석](../../ppstructure/layout) 최적화: 95% 감소된 모델 저장, 반면 속도는 11배 증가하고, 평균 CPU 시간 비용은 41ms에 불과함;
- [표 인식](../../ppstructure/table) 최적화: 3 최적화 전략이 디자인되고 모델 정확도는 비교 가능한 시간 소비 하에 6% 개선됨;
- [핵심 정보 추출](../../ppstructure/kie) 최적화: 시각에 의존하지 않는 모델 구조가 디자인되고, 의미체 인식 정확도가 2.8% 증가되며 관계 추출 정확도는 9.1% 증가됨.
-
+
- **🔥2022년 7월 출시[OCR 씬 애플리케이션 컬렉션](../../applications/README_en.md)**
디지털 튜브, LCD 스크린, 라이선스 플레이트, 수기 인식 모델, 고정밀 SVTR 모델 등등과 같은 “9수직 모델” 출시로, 일반적으로 주된 OCR 수직 애플리케이션, 제조, 금융 및 수송 산업 커버
- **🔥2022년 5월 9일에 패들 OCR 출시 [출시/2.5](https://github.com/PaddlePaddle/PaddleOCR/tree/release/2.5)**
- - [PP-OCRv3](../doc_en/ppocr_introduction_en.md#pp-ocrv3)출시: 5%.비교 가능한 속도로, 차이니즈 씬의 효과는 PP-OCRv2와 비교해 볼 때 추가로 5% 정도 더 개선되고, 잉글리쉬 씬 효과는 11% 개선되었으며, 80개 언어 다중 언어 모델 평균 인식 정확도는 5% 이상 개선됨.
+ - [PP-OCRv3](../doc_en/ppocr_introduction_en.md#pp-ocrv3)출시: 비교 가능한 속도로, 차이니즈 씬의 효과는 PP-OCRv2와 비교해 볼 때 추가로 5% 정도 더 개선되고, 잉글리쉬 씬 효과는 11% 개선되었으며, 80개 언어 다중 언어 모델 평균 인식 정확도는 5% 이상 개선됨.
- [PPOCRLabelv2](./PPOCRLabel)출시: 표 인식 업무, 핵심 정보 추출 업무 및 불규칙한 텍스트 이미지주석 기능 추가.
- 쌍방향e-북 출시 [*"OCR 뛰어들기"*](../doc_en/ocr_book_en.md), 첨단 이론 및 OCR 정식 스택 기술 코드 연습 포함.
@@ -131,9 +131,9 @@ paddleocr --image_dir /your/test/image.jpg --lang=korean
- [텍스트 인식](../doc_en/algorithm_overview_en.md)
- [종단종OCR](../doc_en/algorithm_overview_en.md)
- [표 인식](../doc_en/algorithm_overview_en.md)
- - [핵심 정보 추출](../doc_en/algorithm_overview_en.md)
+ - [핵심 정보 추출](../doc_en/algorithm_overview_en.md)
- [PaddleOCR에 신규 알고리즘 추가](../doc_en/add_new_algorithm_en.md)
-- 데이터 주석 및 합성
+- 데이터 주석 및 합성
- [반-자동 주석 툴: PPOCRLabel](./PPOCRLabel/README.md)
- [데이터 합성 툴: 스타일-텍스트](./StyleText/README.md)
- [기타 데이터 주석 툴](../doc_en/data_annotation_en.md)
@@ -155,10 +155,10 @@ paddleocr --image_dir /your/test/image.jpg --lang=korean
-## 신규 언어 요청에 대한 유엔 가이드라인
+## 신규 언어 요청에 대한 유엔 가이드라인
만일 신규 언어 모델을 요청하고자 한다면**, [다중 언어 모델 업그레이드 투표하기](https://github.com/PaddlePaddle/PaddleOCR/discussions/7253)에서 투표하기 바람. 우리는 결과에 따라 규칙적으로 모델을 업그레이드 시킬 것임**함께 투표하고자 당신의 친구들을 초대할 것!**
-만일 당신이 시나리오 기반 “신규 언어 모델”을 훈련하고자 한다면, [다중 언어 모델 훈련 프로젝트](https://github.com/PaddlePaddle/PaddleOCR/discussions/7252) 를 통해 당신의 데이터세트를 작성하는데 도움이 되고 단계별로 전체 절차를 보여줄 것입니다.
+만일 당신이 시나리오 기반 “신규 언어 모델”을 훈련하고자 한다면, [다중 언어 모델 훈련 프로젝트](https://github.com/PaddlePaddle/PaddleOCR/discussions/7252) 를 통해 당신의 데이터세트를 작성하는데 도움이 되고 단계별로 전체 절차를 보여줄 것입니다.
원본[다중 언어 OCR 개발 계획](https://github.com/PaddlePaddle/PaddleOCR/issues/1048)은 여전히 수많은 유용한 말뭉치와 사전을 보여줍니다.
@@ -210,7 +210,7 @@ paddleocr --image_dir /your/test/image.jpg --lang=korean
3. RE (관계 추출)
-
+
diff --git a/paddleocr.py b/paddleocr.py
index 9f3ecda131..95e316fb1e 100644
--- a/paddleocr.py
+++ b/paddleocr.py
@@ -20,7 +20,7 @@
import paddle
-sys.path.append(os.path.join(__dir__, ''))
+sys.path.append(os.path.join(__dir__, ""))
import cv2
import logging
@@ -42,397 +42,356 @@ def _import_file(module_name, file_path, make_importable=False):
tools = _import_file(
- 'tools', os.path.join(__dir__, 'tools/__init__.py'), make_importable=True)
-ppocr = importlib.import_module('ppocr', 'paddleocr')
-ppstructure = importlib.import_module('ppstructure', 'paddleocr')
+ "tools", os.path.join(__dir__, "tools/__init__.py"), make_importable=True
+)
+ppocr = importlib.import_module("ppocr", "paddleocr")
+ppstructure = importlib.import_module("ppstructure", "paddleocr")
from ppocr.utils.logging import get_logger
logger = get_logger()
-from ppocr.utils.utility import check_and_read, get_image_file_list, alpha_to_color, binarize_img
-from ppocr.utils.network import maybe_download, download_with_progressbar, is_link, confirm_model_dir_url
+from ppocr.utils.utility import (
+ check_and_read,
+ get_image_file_list,
+ alpha_to_color,
+ binarize_img,
+)
+from ppocr.utils.network import (
+ maybe_download,
+ download_with_progressbar,
+ is_link,
+ confirm_model_dir_url,
+)
from tools.infer.utility import draw_ocr, str2bool, check_gpu
from ppstructure.utility import init_args, draw_structure_result
from ppstructure.predict_system import StructureSystem, save_structure_res, to_excel
logger = get_logger()
__all__ = [
- 'PaddleOCR', 'PPStructure', 'draw_ocr', 'draw_structure_result',
- 'save_structure_res', 'download_with_progressbar', 'to_excel'
+ "PaddleOCR",
+ "PPStructure",
+ "draw_ocr",
+ "draw_structure_result",
+ "save_structure_res",
+ "download_with_progressbar",
+ "to_excel",
]
-SUPPORT_DET_MODEL = ['DB']
-VERSION = '2.8.0'
-SUPPORT_REC_MODEL = ['CRNN', 'SVTR_LCNet']
+SUPPORT_DET_MODEL = ["DB"]
+VERSION = "2.8.0"
+SUPPORT_REC_MODEL = ["CRNN", "SVTR_LCNet"]
BASE_DIR = os.path.expanduser("~/.paddleocr/")
-DEFAULT_OCR_MODEL_VERSION = 'PP-OCRv4'
-SUPPORT_OCR_MODEL_VERSION = ['PP-OCR', 'PP-OCRv2', 'PP-OCRv3', 'PP-OCRv4']
-DEFAULT_STRUCTURE_MODEL_VERSION = 'PP-StructureV2'
-SUPPORT_STRUCTURE_MODEL_VERSION = ['PP-Structure', 'PP-StructureV2']
+DEFAULT_OCR_MODEL_VERSION = "PP-OCRv4"
+SUPPORT_OCR_MODEL_VERSION = ["PP-OCR", "PP-OCRv2", "PP-OCRv3", "PP-OCRv4"]
+DEFAULT_STRUCTURE_MODEL_VERSION = "PP-StructureV2"
+SUPPORT_STRUCTURE_MODEL_VERSION = ["PP-Structure", "PP-StructureV2"]
MODEL_URLS = {
- 'OCR': {
- 'PP-OCRv4': {
- 'det': {
- 'ch': {
- 'url':
- 'https://paddleocr.bj.bcebos.com/PP-OCRv4/chinese/ch_PP-OCRv4_det_infer.tar',
- },
- 'en': {
- 'url':
- 'https://paddleocr.bj.bcebos.com/PP-OCRv3/english/en_PP-OCRv3_det_infer.tar',
- },
- 'ml': {
- 'url':
- 'https://paddleocr.bj.bcebos.com/PP-OCRv3/multilingual/Multilingual_PP-OCRv3_det_infer.tar'
- }
+ "OCR": {
+ "PP-OCRv4": {
+ "det": {
+ "ch": {
+ "url": "https://paddleocr.bj.bcebos.com/PP-OCRv4/chinese/ch_PP-OCRv4_det_infer.tar",
+ },
+ "en": {
+ "url": "https://paddleocr.bj.bcebos.com/PP-OCRv3/english/en_PP-OCRv3_det_infer.tar",
+ },
+ "ml": {
+ "url": "https://paddleocr.bj.bcebos.com/PP-OCRv3/multilingual/Multilingual_PP-OCRv3_det_infer.tar"
+ },
},
- 'rec': {
- 'ch': {
- 'url':
- 'https://paddleocr.bj.bcebos.com/PP-OCRv4/chinese/ch_PP-OCRv4_rec_infer.tar',
- 'dict_path': './ppocr/utils/ppocr_keys_v1.txt'
- },
- 'en': {
- 'url':
- 'https://paddleocr.bj.bcebos.com/PP-OCRv4/english/en_PP-OCRv4_rec_infer.tar',
- 'dict_path': './ppocr/utils/en_dict.txt'
- },
- 'korean': {
- 'url':
- 'https://paddleocr.bj.bcebos.com/PP-OCRv4/multilingual/korean_PP-OCRv4_rec_infer.tar',
- 'dict_path': './ppocr/utils/dict/korean_dict.txt'
- },
- 'japan': {
- 'url':
- 'https://paddleocr.bj.bcebos.com/PP-OCRv4/multilingual/japan_PP-OCRv4_rec_infer.tar',
- 'dict_path': './ppocr/utils/dict/japan_dict.txt'
- },
- 'chinese_cht': {
- 'url':
- 'https://paddleocr.bj.bcebos.com/PP-OCRv3/multilingual/chinese_cht_PP-OCRv3_rec_infer.tar',
- 'dict_path': './ppocr/utils/dict/chinese_cht_dict.txt'
- },
- 'ta': {
- 'url':
- 'https://paddleocr.bj.bcebos.com/PP-OCRv4/multilingual/ta_PP-OCRv4_rec_infer.tar',
- 'dict_path': './ppocr/utils/dict/ta_dict.txt'
- },
- 'te': {
- 'url':
- 'https://paddleocr.bj.bcebos.com/PP-OCRv4/multilingual/te_PP-OCRv4_rec_infer.tar',
- 'dict_path': './ppocr/utils/dict/te_dict.txt'
- },
- 'ka': {
- 'url':
- 'https://paddleocr.bj.bcebos.com/PP-OCRv4/multilingual/ka_PP-OCRv4_rec_infer.tar',
- 'dict_path': './ppocr/utils/dict/ka_dict.txt'
- },
- 'latin': {
- 'url':
- 'https://paddleocr.bj.bcebos.com/PP-OCRv3/multilingual/latin_PP-OCRv3_rec_infer.tar',
- 'dict_path': './ppocr/utils/dict/latin_dict.txt'
- },
- 'arabic': {
- 'url':
- 'https://paddleocr.bj.bcebos.com/PP-OCRv4/multilingual/arabic_PP-OCRv4_rec_infer.tar',
- 'dict_path': './ppocr/utils/dict/arabic_dict.txt'
- },
- 'cyrillic': {
- 'url':
- 'https://paddleocr.bj.bcebos.com/PP-OCRv3/multilingual/cyrillic_PP-OCRv3_rec_infer.tar',
- 'dict_path': './ppocr/utils/dict/cyrillic_dict.txt'
- },
- 'devanagari': {
- 'url':
- 'https://paddleocr.bj.bcebos.com/PP-OCRv4/multilingual/devanagari_PP-OCRv4_rec_infer.tar',
- 'dict_path': './ppocr/utils/dict/devanagari_dict.txt'
+ "rec": {
+ "ch": {
+ "url": "https://paddleocr.bj.bcebos.com/PP-OCRv4/chinese/ch_PP-OCRv4_rec_infer.tar",
+ "dict_path": "./ppocr/utils/ppocr_keys_v1.txt",
+ },
+ "en": {
+ "url": "https://paddleocr.bj.bcebos.com/PP-OCRv4/english/en_PP-OCRv4_rec_infer.tar",
+ "dict_path": "./ppocr/utils/en_dict.txt",
+ },
+ "korean": {
+ "url": "https://paddleocr.bj.bcebos.com/PP-OCRv4/multilingual/korean_PP-OCRv4_rec_infer.tar",
+ "dict_path": "./ppocr/utils/dict/korean_dict.txt",
+ },
+ "japan": {
+ "url": "https://paddleocr.bj.bcebos.com/PP-OCRv4/multilingual/japan_PP-OCRv4_rec_infer.tar",
+ "dict_path": "./ppocr/utils/dict/japan_dict.txt",
+ },
+ "chinese_cht": {
+ "url": "https://paddleocr.bj.bcebos.com/PP-OCRv3/multilingual/chinese_cht_PP-OCRv3_rec_infer.tar",
+ "dict_path": "./ppocr/utils/dict/chinese_cht_dict.txt",
+ },
+ "ta": {
+ "url": "https://paddleocr.bj.bcebos.com/PP-OCRv4/multilingual/ta_PP-OCRv4_rec_infer.tar",
+ "dict_path": "./ppocr/utils/dict/ta_dict.txt",
+ },
+ "te": {
+ "url": "https://paddleocr.bj.bcebos.com/PP-OCRv4/multilingual/te_PP-OCRv4_rec_infer.tar",
+ "dict_path": "./ppocr/utils/dict/te_dict.txt",
+ },
+ "ka": {
+ "url": "https://paddleocr.bj.bcebos.com/PP-OCRv4/multilingual/ka_PP-OCRv4_rec_infer.tar",
+ "dict_path": "./ppocr/utils/dict/ka_dict.txt",
+ },
+ "latin": {
+ "url": "https://paddleocr.bj.bcebos.com/PP-OCRv3/multilingual/latin_PP-OCRv3_rec_infer.tar",
+ "dict_path": "./ppocr/utils/dict/latin_dict.txt",
+ },
+ "arabic": {
+ "url": "https://paddleocr.bj.bcebos.com/PP-OCRv4/multilingual/arabic_PP-OCRv4_rec_infer.tar",
+ "dict_path": "./ppocr/utils/dict/arabic_dict.txt",
+ },
+ "cyrillic": {
+ "url": "https://paddleocr.bj.bcebos.com/PP-OCRv3/multilingual/cyrillic_PP-OCRv3_rec_infer.tar",
+ "dict_path": "./ppocr/utils/dict/cyrillic_dict.txt",
+ },
+ "devanagari": {
+ "url": "https://paddleocr.bj.bcebos.com/PP-OCRv4/multilingual/devanagari_PP-OCRv4_rec_infer.tar",
+ "dict_path": "./ppocr/utils/dict/devanagari_dict.txt",
},
},
- 'cls': {
- 'ch': {
- 'url':
- 'https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_mobile_v2.0_cls_infer.tar',
+ "cls": {
+ "ch": {
+ "url": "https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_mobile_v2.0_cls_infer.tar",
}
},
},
- 'PP-OCRv3': {
- 'det': {
- 'ch': {
- 'url':
- 'https://paddleocr.bj.bcebos.com/PP-OCRv3/chinese/ch_PP-OCRv3_det_infer.tar',
- },
- 'en': {
- 'url':
- 'https://paddleocr.bj.bcebos.com/PP-OCRv3/english/en_PP-OCRv3_det_infer.tar',
- },
- 'ml': {
- 'url':
- 'https://paddleocr.bj.bcebos.com/PP-OCRv3/multilingual/Multilingual_PP-OCRv3_det_infer.tar'
- }
+ "PP-OCRv3": {
+ "det": {
+ "ch": {
+ "url": "https://paddleocr.bj.bcebos.com/PP-OCRv3/chinese/ch_PP-OCRv3_det_infer.tar",
+ },
+ "en": {
+ "url": "https://paddleocr.bj.bcebos.com/PP-OCRv3/english/en_PP-OCRv3_det_infer.tar",
+ },
+ "ml": {
+ "url": "https://paddleocr.bj.bcebos.com/PP-OCRv3/multilingual/Multilingual_PP-OCRv3_det_infer.tar"
+ },
},
- 'rec': {
- 'ch': {
- 'url':
- 'https://paddleocr.bj.bcebos.com/PP-OCRv3/chinese/ch_PP-OCRv3_rec_infer.tar',
- 'dict_path': './ppocr/utils/ppocr_keys_v1.txt'
- },
- 'en': {
- 'url':
- 'https://paddleocr.bj.bcebos.com/PP-OCRv3/english/en_PP-OCRv3_rec_infer.tar',
- 'dict_path': './ppocr/utils/en_dict.txt'
- },
- 'korean': {
- 'url':
- 'https://paddleocr.bj.bcebos.com/PP-OCRv3/multilingual/korean_PP-OCRv3_rec_infer.tar',
- 'dict_path': './ppocr/utils/dict/korean_dict.txt'
- },
- 'japan': {
- 'url':
- 'https://paddleocr.bj.bcebos.com/PP-OCRv3/multilingual/japan_PP-OCRv3_rec_infer.tar',
- 'dict_path': './ppocr/utils/dict/japan_dict.txt'
- },
- 'chinese_cht': {
- 'url':
- 'https://paddleocr.bj.bcebos.com/PP-OCRv3/multilingual/chinese_cht_PP-OCRv3_rec_infer.tar',
- 'dict_path': './ppocr/utils/dict/chinese_cht_dict.txt'
- },
- 'ta': {
- 'url':
- 'https://paddleocr.bj.bcebos.com/PP-OCRv3/multilingual/ta_PP-OCRv3_rec_infer.tar',
- 'dict_path': './ppocr/utils/dict/ta_dict.txt'
- },
- 'te': {
- 'url':
- 'https://paddleocr.bj.bcebos.com/PP-OCRv3/multilingual/te_PP-OCRv3_rec_infer.tar',
- 'dict_path': './ppocr/utils/dict/te_dict.txt'
- },
- 'ka': {
- 'url':
- 'https://paddleocr.bj.bcebos.com/PP-OCRv3/multilingual/ka_PP-OCRv3_rec_infer.tar',
- 'dict_path': './ppocr/utils/dict/ka_dict.txt'
- },
- 'latin': {
- 'url':
- 'https://paddleocr.bj.bcebos.com/PP-OCRv3/multilingual/latin_PP-OCRv3_rec_infer.tar',
- 'dict_path': './ppocr/utils/dict/latin_dict.txt'
- },
- 'arabic': {
- 'url':
- 'https://paddleocr.bj.bcebos.com/PP-OCRv3/multilingual/arabic_PP-OCRv3_rec_infer.tar',
- 'dict_path': './ppocr/utils/dict/arabic_dict.txt'
- },
- 'cyrillic': {
- 'url':
- 'https://paddleocr.bj.bcebos.com/PP-OCRv3/multilingual/cyrillic_PP-OCRv3_rec_infer.tar',
- 'dict_path': './ppocr/utils/dict/cyrillic_dict.txt'
- },
- 'devanagari': {
- 'url':
- 'https://paddleocr.bj.bcebos.com/PP-OCRv3/multilingual/devanagari_PP-OCRv3_rec_infer.tar',
- 'dict_path': './ppocr/utils/dict/devanagari_dict.txt'
+ "rec": {
+ "ch": {
+ "url": "https://paddleocr.bj.bcebos.com/PP-OCRv3/chinese/ch_PP-OCRv3_rec_infer.tar",
+ "dict_path": "./ppocr/utils/ppocr_keys_v1.txt",
+ },
+ "en": {
+ "url": "https://paddleocr.bj.bcebos.com/PP-OCRv3/english/en_PP-OCRv3_rec_infer.tar",
+ "dict_path": "./ppocr/utils/en_dict.txt",
+ },
+ "korean": {
+ "url": "https://paddleocr.bj.bcebos.com/PP-OCRv3/multilingual/korean_PP-OCRv3_rec_infer.tar",
+ "dict_path": "./ppocr/utils/dict/korean_dict.txt",
+ },
+ "japan": {
+ "url": "https://paddleocr.bj.bcebos.com/PP-OCRv3/multilingual/japan_PP-OCRv3_rec_infer.tar",
+ "dict_path": "./ppocr/utils/dict/japan_dict.txt",
+ },
+ "chinese_cht": {
+ "url": "https://paddleocr.bj.bcebos.com/PP-OCRv3/multilingual/chinese_cht_PP-OCRv3_rec_infer.tar",
+ "dict_path": "./ppocr/utils/dict/chinese_cht_dict.txt",
+ },
+ "ta": {
+ "url": "https://paddleocr.bj.bcebos.com/PP-OCRv3/multilingual/ta_PP-OCRv3_rec_infer.tar",
+ "dict_path": "./ppocr/utils/dict/ta_dict.txt",
+ },
+ "te": {
+ "url": "https://paddleocr.bj.bcebos.com/PP-OCRv3/multilingual/te_PP-OCRv3_rec_infer.tar",
+ "dict_path": "./ppocr/utils/dict/te_dict.txt",
+ },
+ "ka": {
+ "url": "https://paddleocr.bj.bcebos.com/PP-OCRv3/multilingual/ka_PP-OCRv3_rec_infer.tar",
+ "dict_path": "./ppocr/utils/dict/ka_dict.txt",
+ },
+ "latin": {
+ "url": "https://paddleocr.bj.bcebos.com/PP-OCRv3/multilingual/latin_PP-OCRv3_rec_infer.tar",
+ "dict_path": "./ppocr/utils/dict/latin_dict.txt",
+ },
+ "arabic": {
+ "url": "https://paddleocr.bj.bcebos.com/PP-OCRv3/multilingual/arabic_PP-OCRv3_rec_infer.tar",
+ "dict_path": "./ppocr/utils/dict/arabic_dict.txt",
+ },
+ "cyrillic": {
+ "url": "https://paddleocr.bj.bcebos.com/PP-OCRv3/multilingual/cyrillic_PP-OCRv3_rec_infer.tar",
+ "dict_path": "./ppocr/utils/dict/cyrillic_dict.txt",
+ },
+ "devanagari": {
+ "url": "https://paddleocr.bj.bcebos.com/PP-OCRv3/multilingual/devanagari_PP-OCRv3_rec_infer.tar",
+ "dict_path": "./ppocr/utils/dict/devanagari_dict.txt",
},
},
- 'cls': {
- 'ch': {
- 'url':
- 'https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_mobile_v2.0_cls_infer.tar',
+ "cls": {
+ "ch": {
+ "url": "https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_mobile_v2.0_cls_infer.tar",
}
},
},
- 'PP-OCRv2': {
- 'det': {
- 'ch': {
- 'url':
- 'https://paddleocr.bj.bcebos.com/PP-OCRv2/chinese/ch_PP-OCRv2_det_infer.tar',
+ "PP-OCRv2": {
+ "det": {
+ "ch": {
+ "url": "https://paddleocr.bj.bcebos.com/PP-OCRv2/chinese/ch_PP-OCRv2_det_infer.tar",
},
},
- 'rec': {
- 'ch': {
- 'url':
- 'https://paddleocr.bj.bcebos.com/PP-OCRv2/chinese/ch_PP-OCRv2_rec_infer.tar',
- 'dict_path': './ppocr/utils/ppocr_keys_v1.txt'
+ "rec": {
+ "ch": {
+ "url": "https://paddleocr.bj.bcebos.com/PP-OCRv2/chinese/ch_PP-OCRv2_rec_infer.tar",
+ "dict_path": "./ppocr/utils/ppocr_keys_v1.txt",
}
},
- 'cls': {
- 'ch': {
- 'url':
- 'https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_mobile_v2.0_cls_infer.tar',
+ "cls": {
+ "ch": {
+ "url": "https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_mobile_v2.0_cls_infer.tar",
}
},
},
- 'PP-OCR': {
- 'det': {
- 'ch': {
- 'url':
- 'https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_mobile_v2.0_det_infer.tar',
- },
- 'en': {
- 'url':
- 'https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/en_ppocr_mobile_v2.0_det_infer.tar',
- },
- 'structure': {
- 'url':
- 'https://paddleocr.bj.bcebos.com/dygraph_v2.0/table/en_ppocr_mobile_v2.0_table_det_infer.tar'
- }
+ "PP-OCR": {
+ "det": {
+ "ch": {
+ "url": "https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_mobile_v2.0_det_infer.tar",
+ },
+ "en": {
+ "url": "https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/en_ppocr_mobile_v2.0_det_infer.tar",
+ },
+ "structure": {
+ "url": "https://paddleocr.bj.bcebos.com/dygraph_v2.0/table/en_ppocr_mobile_v2.0_table_det_infer.tar"
+ },
},
- 'rec': {
- 'ch': {
- 'url':
- 'https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_mobile_v2.0_rec_infer.tar',
- 'dict_path': './ppocr/utils/ppocr_keys_v1.txt'
- },
- 'en': {
- 'url':
- 'https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/en_number_mobile_v2.0_rec_infer.tar',
- 'dict_path': './ppocr/utils/en_dict.txt'
- },
- 'french': {
- 'url':
- 'https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/french_mobile_v2.0_rec_infer.tar',
- 'dict_path': './ppocr/utils/dict/french_dict.txt'
- },
- 'german': {
- 'url':
- 'https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/german_mobile_v2.0_rec_infer.tar',
- 'dict_path': './ppocr/utils/dict/german_dict.txt'
- },
- 'korean': {
- 'url':
- 'https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/korean_mobile_v2.0_rec_infer.tar',
- 'dict_path': './ppocr/utils/dict/korean_dict.txt'
- },
- 'japan': {
- 'url':
- 'https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/japan_mobile_v2.0_rec_infer.tar',
- 'dict_path': './ppocr/utils/dict/japan_dict.txt'
- },
- 'chinese_cht': {
- 'url':
- 'https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/chinese_cht_mobile_v2.0_rec_infer.tar',
- 'dict_path': './ppocr/utils/dict/chinese_cht_dict.txt'
- },
- 'ta': {
- 'url':
- 'https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/ta_mobile_v2.0_rec_infer.tar',
- 'dict_path': './ppocr/utils/dict/ta_dict.txt'
- },
- 'te': {
- 'url':
- 'https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/te_mobile_v2.0_rec_infer.tar',
- 'dict_path': './ppocr/utils/dict/te_dict.txt'
- },
- 'ka': {
- 'url':
- 'https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/ka_mobile_v2.0_rec_infer.tar',
- 'dict_path': './ppocr/utils/dict/ka_dict.txt'
- },
- 'latin': {
- 'url':
- 'https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/latin_ppocr_mobile_v2.0_rec_infer.tar',
- 'dict_path': './ppocr/utils/dict/latin_dict.txt'
- },
- 'arabic': {
- 'url':
- 'https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/arabic_ppocr_mobile_v2.0_rec_infer.tar',
- 'dict_path': './ppocr/utils/dict/arabic_dict.txt'
- },
- 'cyrillic': {
- 'url':
- 'https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/cyrillic_ppocr_mobile_v2.0_rec_infer.tar',
- 'dict_path': './ppocr/utils/dict/cyrillic_dict.txt'
- },
- 'devanagari': {
- 'url':
- 'https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/devanagari_ppocr_mobile_v2.0_rec_infer.tar',
- 'dict_path': './ppocr/utils/dict/devanagari_dict.txt'
- },
- 'structure': {
- 'url':
- 'https://paddleocr.bj.bcebos.com/dygraph_v2.0/table/en_ppocr_mobile_v2.0_table_rec_infer.tar',
- 'dict_path': 'ppocr/utils/dict/table_dict.txt'
- }
+ "rec": {
+ "ch": {
+ "url": "https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_mobile_v2.0_rec_infer.tar",
+ "dict_path": "./ppocr/utils/ppocr_keys_v1.txt",
+ },
+ "en": {
+ "url": "https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/en_number_mobile_v2.0_rec_infer.tar",
+ "dict_path": "./ppocr/utils/en_dict.txt",
+ },
+ "french": {
+ "url": "https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/french_mobile_v2.0_rec_infer.tar",
+ "dict_path": "./ppocr/utils/dict/french_dict.txt",
+ },
+ "german": {
+ "url": "https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/german_mobile_v2.0_rec_infer.tar",
+ "dict_path": "./ppocr/utils/dict/german_dict.txt",
+ },
+ "korean": {
+ "url": "https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/korean_mobile_v2.0_rec_infer.tar",
+ "dict_path": "./ppocr/utils/dict/korean_dict.txt",
+ },
+ "japan": {
+ "url": "https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/japan_mobile_v2.0_rec_infer.tar",
+ "dict_path": "./ppocr/utils/dict/japan_dict.txt",
+ },
+ "chinese_cht": {
+ "url": "https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/chinese_cht_mobile_v2.0_rec_infer.tar",
+ "dict_path": "./ppocr/utils/dict/chinese_cht_dict.txt",
+ },
+ "ta": {
+ "url": "https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/ta_mobile_v2.0_rec_infer.tar",
+ "dict_path": "./ppocr/utils/dict/ta_dict.txt",
+ },
+ "te": {
+ "url": "https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/te_mobile_v2.0_rec_infer.tar",
+ "dict_path": "./ppocr/utils/dict/te_dict.txt",
+ },
+ "ka": {
+ "url": "https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/ka_mobile_v2.0_rec_infer.tar",
+ "dict_path": "./ppocr/utils/dict/ka_dict.txt",
+ },
+ "latin": {
+ "url": "https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/latin_ppocr_mobile_v2.0_rec_infer.tar",
+ "dict_path": "./ppocr/utils/dict/latin_dict.txt",
+ },
+ "arabic": {
+ "url": "https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/arabic_ppocr_mobile_v2.0_rec_infer.tar",
+ "dict_path": "./ppocr/utils/dict/arabic_dict.txt",
+ },
+ "cyrillic": {
+ "url": "https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/cyrillic_ppocr_mobile_v2.0_rec_infer.tar",
+ "dict_path": "./ppocr/utils/dict/cyrillic_dict.txt",
+ },
+ "devanagari": {
+ "url": "https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/devanagari_ppocr_mobile_v2.0_rec_infer.tar",
+ "dict_path": "./ppocr/utils/dict/devanagari_dict.txt",
+ },
+ "structure": {
+ "url": "https://paddleocr.bj.bcebos.com/dygraph_v2.0/table/en_ppocr_mobile_v2.0_table_rec_infer.tar",
+ "dict_path": "ppocr/utils/dict/table_dict.txt",
+ },
},
- 'cls': {
- 'ch': {
- 'url':
- 'https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_mobile_v2.0_cls_infer.tar',
+ "cls": {
+ "ch": {
+ "url": "https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_mobile_v2.0_cls_infer.tar",
}
},
- }
+ },
},
- 'STRUCTURE': {
- 'PP-Structure': {
- 'table': {
- 'en': {
- 'url':
- 'https://paddleocr.bj.bcebos.com/dygraph_v2.0/table/en_ppocr_mobile_v2.0_table_structure_infer.tar',
- 'dict_path': 'ppocr/utils/dict/table_structure_dict.txt'
+ "STRUCTURE": {
+ "PP-Structure": {
+ "table": {
+ "en": {
+ "url": "https://paddleocr.bj.bcebos.com/dygraph_v2.0/table/en_ppocr_mobile_v2.0_table_structure_infer.tar",
+ "dict_path": "ppocr/utils/dict/table_structure_dict.txt",
}
}
},
- 'PP-StructureV2': {
- 'table': {
- 'en': {
- 'url':
- 'https://paddleocr.bj.bcebos.com/ppstructure/models/slanet/en_ppstructure_mobile_v2.0_SLANet_infer.tar',
- 'dict_path': 'ppocr/utils/dict/table_structure_dict.txt'
- },
- 'ch': {
- 'url':
- 'https://paddleocr.bj.bcebos.com/ppstructure/models/slanet/ch_ppstructure_mobile_v2.0_SLANet_infer.tar',
- 'dict_path': 'ppocr/utils/dict/table_structure_dict_ch.txt'
- }
+ "PP-StructureV2": {
+ "table": {
+ "en": {
+ "url": "https://paddleocr.bj.bcebos.com/ppstructure/models/slanet/en_ppstructure_mobile_v2.0_SLANet_infer.tar",
+ "dict_path": "ppocr/utils/dict/table_structure_dict.txt",
+ },
+ "ch": {
+ "url": "https://paddleocr.bj.bcebos.com/ppstructure/models/slanet/ch_ppstructure_mobile_v2.0_SLANet_infer.tar",
+ "dict_path": "ppocr/utils/dict/table_structure_dict_ch.txt",
+ },
},
- 'layout': {
- 'en': {
- 'url':
- 'https://paddleocr.bj.bcebos.com/ppstructure/models/layout/picodet_lcnet_x1_0_fgd_layout_infer.tar',
- 'dict_path':
- 'ppocr/utils/dict/layout_dict/layout_publaynet_dict.txt'
- },
- 'ch': {
- 'url':
- 'https://paddleocr.bj.bcebos.com/ppstructure/models/layout/picodet_lcnet_x1_0_fgd_layout_cdla_infer.tar',
- 'dict_path':
- 'ppocr/utils/dict/layout_dict/layout_cdla_dict.txt'
- }
- }
- }
- }
+ "layout": {
+ "en": {
+ "url": "https://paddleocr.bj.bcebos.com/ppstructure/models/layout/picodet_lcnet_x1_0_fgd_layout_infer.tar",
+ "dict_path": "ppocr/utils/dict/layout_dict/layout_publaynet_dict.txt",
+ },
+ "ch": {
+ "url": "https://paddleocr.bj.bcebos.com/ppstructure/models/layout/picodet_lcnet_x1_0_fgd_layout_cdla_infer.tar",
+ "dict_path": "ppocr/utils/dict/layout_dict/layout_cdla_dict.txt",
+ },
+ },
+ },
+ },
}
def parse_args(mMain=True):
import argparse
+
parser = init_args()
parser.add_help = mMain
- parser.add_argument("--lang", type=str, default='ch')
+ parser.add_argument("--lang", type=str, default="ch")
parser.add_argument("--det", type=str2bool, default=True)
parser.add_argument("--rec", type=str2bool, default=True)
- parser.add_argument("--type", type=str, default='ocr')
+ parser.add_argument("--type", type=str, default="ocr")
parser.add_argument("--savefile", type=str2bool, default=False)
parser.add_argument(
"--ocr_version",
type=str,
choices=SUPPORT_OCR_MODEL_VERSION,
- default='PP-OCRv4',
- help='OCR Model version, the current model support list is as follows: '
- '1. PP-OCRv4/v3 Support Chinese and English detection and recognition model, and direction classifier model'
- '2. PP-OCRv2 Support Chinese detection and recognition model. '
- '3. PP-OCR support Chinese detection, recognition and direction classifier and multilingual recognition model.'
+ default="PP-OCRv4",
+ help="OCR Model version, the current model support list is as follows: "
+ "1. PP-OCRv4/v3 Support Chinese and English detection and recognition model, and direction classifier model"
+ "2. PP-OCRv2 Support Chinese detection and recognition model. "
+ "3. PP-OCR support Chinese detection, recognition and direction classifier and multilingual recognition model.",
)
parser.add_argument(
"--structure_version",
type=str,
choices=SUPPORT_STRUCTURE_MODEL_VERSION,
- default='PP-StructureV2',
- help='Model version, the current model support list is as follows:'
- ' 1. PP-Structure Support en table structure model.'
- ' 2. PP-StructureV2 Support ch and en table structure model.')
+ default="PP-StructureV2",
+ help="Model version, the current model support list is as follows:"
+ " 1. PP-Structure Support en table structure model."
+ " 2. PP-StructureV2 Support ch and en table structure model.",
+ )
for action in parser._actions:
if action.dest in [
- 'rec_char_dict_path', 'table_char_dict_path', 'layout_dict_path'
+ "rec_char_dict_path",
+ "table_char_dict_path",
+ "layout_dict_path",
]:
action.default = None
if mMain:
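Note: the MODEL_URLS table above is keyed type -> version -> model_type -> lang, so a concrete model resolves with a plain nested lookup. A minimal sketch, with keys taken straight from the table:

    v2_rec_ch = MODEL_URLS["OCR"]["PP-OCRv2"]["rec"]["ch"]
    print(v2_rec_ch["url"])        # .../ch_PP-OCRv2_rec_infer.tar
    print(v2_rec_ch["dict_path"])  # ./ppocr/utils/ppocr_keys_v1.txt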
@@ -446,19 +405,82 @@ def parse_args(mMain=True):
def parse_lang(lang):
latin_lang = [
- 'af', 'az', 'bs', 'cs', 'cy', 'da', 'de', 'es', 'et', 'fr', 'ga', 'hr',
- 'hu', 'id', 'is', 'it', 'ku', 'la', 'lt', 'lv', 'mi', 'ms', 'mt', 'nl',
- 'no', 'oc', 'pi', 'pl', 'pt', 'ro', 'rs_latin', 'sk', 'sl', 'sq', 'sv',
- 'sw', 'tl', 'tr', 'uz', 'vi', 'french', 'german'
+ "af",
+ "az",
+ "bs",
+ "cs",
+ "cy",
+ "da",
+ "de",
+ "es",
+ "et",
+ "fr",
+ "ga",
+ "hr",
+ "hu",
+ "id",
+ "is",
+ "it",
+ "ku",
+ "la",
+ "lt",
+ "lv",
+ "mi",
+ "ms",
+ "mt",
+ "nl",
+ "no",
+ "oc",
+ "pi",
+ "pl",
+ "pt",
+ "ro",
+ "rs_latin",
+ "sk",
+ "sl",
+ "sq",
+ "sv",
+ "sw",
+ "tl",
+ "tr",
+ "uz",
+ "vi",
+ "french",
+ "german",
]
- arabic_lang = ['ar', 'fa', 'ug', 'ur']
+ arabic_lang = ["ar", "fa", "ug", "ur"]
cyrillic_lang = [
- 'ru', 'rs_cyrillic', 'be', 'bg', 'uk', 'mn', 'abq', 'ady', 'kbd', 'ava',
- 'dar', 'inh', 'che', 'lbe', 'lez', 'tab'
+ "ru",
+ "rs_cyrillic",
+ "be",
+ "bg",
+ "uk",
+ "mn",
+ "abq",
+ "ady",
+ "kbd",
+ "ava",
+ "dar",
+ "inh",
+ "che",
+ "lbe",
+ "lez",
+ "tab",
]
devanagari_lang = [
- 'hi', 'mr', 'ne', 'bh', 'mai', 'ang', 'bho', 'mah', 'sck', 'new', 'gom',
- 'sa', 'bgc'
+ "hi",
+ "mr",
+ "ne",
+ "bh",
+ "mai",
+ "ang",
+ "bho",
+ "mah",
+ "sck",
+ "new",
+ "gom",
+ "sa",
+ "bgc",
]
if lang in latin_lang:
lang = "latin"
@@ -468,13 +490,15 @@ def parse_lang(lang):
lang = "cyrillic"
elif lang in devanagari_lang:
lang = "devanagari"
- assert lang in MODEL_URLS['OCR'][DEFAULT_OCR_MODEL_VERSION][
- 'rec'], 'param lang must in {}, but got {}'.format(
- MODEL_URLS['OCR'][DEFAULT_OCR_MODEL_VERSION]['rec'].keys(), lang)
+ assert (
+ lang in MODEL_URLS["OCR"][DEFAULT_OCR_MODEL_VERSION]["rec"]
+ ), "param lang must in {}, but got {}".format(
+ MODEL_URLS["OCR"][DEFAULT_OCR_MODEL_VERSION]["rec"].keys(), lang
+ )
if lang == "ch":
det_lang = "ch"
- elif lang == 'structure':
- det_lang = 'structure'
+ elif lang == "structure":
+ det_lang = "structure"
elif lang in ["en", "latin"]:
det_lang = "en"
else:
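Note: parse_lang first folds individual language codes into script families (latin, arabic, cyrillic, devanagari), then derives the detection language. A quick sanity check; the import path is an assumption:

    from paddleocr.paddleocr import parse_lang

    lang, det_lang = parse_lang("fr")           # "fr" is in latin_lang above
    assert (lang, det_lang) == ("latin", "en")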
@@ -483,9 +507,9 @@ def parse_lang(lang):
def get_model_config(type, version, model_type, lang):
- if type == 'OCR':
+ if type == "OCR":
DEFAULT_MODEL_VERSION = DEFAULT_OCR_MODEL_VERSION
- elif type == 'STRUCTURE':
+ elif type == "STRUCTURE":
DEFAULT_MODEL_VERSION = DEFAULT_STRUCTURE_MODEL_VERSION
else:
raise NotImplementedError
@@ -497,8 +521,11 @@ def get_model_config(type, version, model_type, lang):
if model_type in model_urls[DEFAULT_MODEL_VERSION]:
version = DEFAULT_MODEL_VERSION
else:
- logger.error('{} models is not support, we only support {}'.format(
- model_type, model_urls[DEFAULT_MODEL_VERSION].keys()))
+ logger.error(
+ "{} models is not support, we only support {}".format(
+ model_type, model_urls[DEFAULT_MODEL_VERSION].keys()
+ )
+ )
sys.exit(-1)
if lang not in model_urls[version][model_type]:
@@ -506,9 +533,12 @@ def get_model_config(type, version, model_type, lang):
version = DEFAULT_MODEL_VERSION
else:
logger.error(
- 'lang {} is not support, we only support {} for {} models'.
- format(lang, model_urls[DEFAULT_MODEL_VERSION][model_type].keys(
- ), model_type))
+ "lang {} is not support, we only support {} for {} models".format(
+ lang,
+ model_urls[DEFAULT_MODEL_VERSION][model_type].keys(),
+ model_type,
+ )
+ )
sys.exit(-1)
return model_urls[version][model_type][lang]
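Note: get_model_config falls back to the default model version when the requested version has no (model_type, lang) entry; multilingual recognition, for instance, only exists under PP-OCR in the table above. A sketch:

    cfg = get_model_config("OCR", "PP-OCR", "rec", "french")
    print(cfg["url"])        # .../french_mobile_v2.0_rec_infer.tar
    print(cfg["dict_path"])  # ./ppocr/utils/dict/french_dict.txt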
@@ -536,12 +566,12 @@ def check_img(img, alpha_color=(255, 255, 255)):
if isinstance(img, str):
# download net image
if is_link(img):
- download_with_progressbar(img, 'tmp.jpg')
- img = 'tmp.jpg'
+ download_with_progressbar(img, "tmp.jpg")
+ img = "tmp.jpg"
image_file = img
img, flag_gif, flag_pdf = check_and_read(image_file)
if not flag_gif and not flag_pdf:
- with open(image_file, 'rb') as f:
+ with open(image_file, "rb") as f:
img_str = f.read()
img = img_decode(img_str)
if img is None:
@@ -549,12 +579,11 @@ def check_img(img, alpha_color=(255, 255, 255)):
buf = BytesIO()
image = BytesIO(img_str)
im = Image.open(image)
- rgb = im.convert('RGB')
- rgb.save(buf, 'jpeg')
+ rgb = im.convert("RGB")
+ rgb.save(buf, "jpeg")
buf.seek(0)
image_bytes = buf.read()
- data_base64 = str(base64.b64encode(image_bytes),
- encoding="utf-8")
+ data_base64 = str(base64.b64encode(image_bytes), encoding="utf-8")
image_decode = base64.b64decode(data_base64)
img_array = np.frombuffer(image_decode, np.uint8)
img = cv2.imdecode(img_array, cv2.IMREAD_COLOR)
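Note: check_img normalizes every accepted input to a BGR ndarray: links are downloaded to tmp.jpg, paths are read and decoded, and images that fail the first decode go through the RGB/base64 round-trip above. A usage sketch; both inputs are placeholders:

    img = check_img("https://example.com/receipt.jpg")  # link: fetched to tmp.jpg first
    img = check_img("./doc/imgs/receipt.jpg")           # path: read and decoded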
@@ -582,8 +611,11 @@ def __init__(self, **kwargs):
"""
params = parse_args(mMain=False)
params.__dict__.update(**kwargs)
- assert params.ocr_version in SUPPORT_OCR_MODEL_VERSION, "ocr_version must in {}, but get {}".format(
- SUPPORT_OCR_MODEL_VERSION, params.ocr_version)
+ assert (
+ params.ocr_version in SUPPORT_OCR_MODEL_VERSION
+ ), "ocr_version must in {}, but get {}".format(
+ SUPPORT_OCR_MODEL_VERSION, params.ocr_version
+ )
params.use_gpu = check_gpu(params.use_gpu)
if not params.show_log:
@@ -592,23 +624,25 @@ def __init__(self, **kwargs):
lang, det_lang = parse_lang(params.lang)
# init model dir
- det_model_config = get_model_config('OCR', params.ocr_version, 'det',
- det_lang)
+ det_model_config = get_model_config("OCR", params.ocr_version, "det", det_lang)
params.det_model_dir, det_url = confirm_model_dir_url(
params.det_model_dir,
- os.path.join(BASE_DIR, 'whl', 'det', det_lang),
- det_model_config['url'])
- rec_model_config = get_model_config('OCR', params.ocr_version, 'rec',
- lang)
+ os.path.join(BASE_DIR, "whl", "det", det_lang),
+ det_model_config["url"],
+ )
+ rec_model_config = get_model_config("OCR", params.ocr_version, "rec", lang)
params.rec_model_dir, rec_url = confirm_model_dir_url(
params.rec_model_dir,
- os.path.join(BASE_DIR, 'whl', 'rec', lang), rec_model_config['url'])
- cls_model_config = get_model_config('OCR', params.ocr_version, 'cls',
- 'ch')
+ os.path.join(BASE_DIR, "whl", "rec", lang),
+ rec_model_config["url"],
+ )
+ cls_model_config = get_model_config("OCR", params.ocr_version, "cls", "ch")
params.cls_model_dir, cls_url = confirm_model_dir_url(
params.cls_model_dir,
- os.path.join(BASE_DIR, 'whl', 'cls'), cls_model_config['url'])
- if params.ocr_version in ['PP-OCRv3', 'PP-OCRv4']:
+ os.path.join(BASE_DIR, "whl", "cls"),
+ cls_model_config["url"],
+ )
+ if params.ocr_version in ["PP-OCRv3", "PP-OCRv4"]:
params.rec_image_shape = "3, 48, 320"
else:
params.rec_image_shape = "3, 32, 320"
@@ -619,25 +653,35 @@ def __init__(self, **kwargs):
maybe_download(params.cls_model_dir, cls_url)
if params.det_algorithm not in SUPPORT_DET_MODEL:
- logger.error('det_algorithm must in {}'.format(SUPPORT_DET_MODEL))
+ logger.error("det_algorithm must in {}".format(SUPPORT_DET_MODEL))
sys.exit(0)
if params.rec_algorithm not in SUPPORT_REC_MODEL:
- logger.error('rec_algorithm must in {}'.format(SUPPORT_REC_MODEL))
+ logger.error("rec_algorithm must in {}".format(SUPPORT_REC_MODEL))
sys.exit(0)
if params.rec_char_dict_path is None:
params.rec_char_dict_path = str(
- Path(__file__).parent / rec_model_config['dict_path'])
+ Path(__file__).parent / rec_model_config["dict_path"]
+ )
logger.debug(params)
# init det_model and rec_model
super().__init__(params)
self.page_num = params.page_num
- def ocr(self, img, det=True, rec=True, cls=True, bin=False, inv=False, alpha_color=(255, 255, 255)):
+ def ocr(
+ self,
+ img,
+ det=True,
+ rec=True,
+ cls=True,
+ bin=False,
+ inv=False,
+ alpha_color=(255, 255, 255),
+ ):
"""
OCR with PaddleOCR
-
+
args:
img: img for OCR, support ndarray, img_path and list or ndarray
det: use text detection or not. If False, only rec will be exec. Default is True
@@ -649,11 +693,11 @@ def ocr(self, img, det=True, rec=True, cls=True, bin=False, inv=False, alpha_col
"""
assert isinstance(img, (np.ndarray, list, str, bytes))
if isinstance(img, list) and det == True:
- logger.error('When input a list of images, det must be false')
+ logger.error("When input a list of images, det must be false")
exit(0)
if cls == True and self.use_angle_cls == False:
logger.warning(
- 'Since the angle classifier is not initialized, it will not be used during the forward process'
+ "Since the angle classifier is not initialized, it will not be used during the forward process"
)
img = check_img(img, alpha_color)
@@ -662,7 +706,7 @@ def ocr(self, img, det=True, rec=True, cls=True, bin=False, inv=False, alpha_col
if self.page_num > len(img) or self.page_num == 0:
imgs = img
else:
- imgs = img[:self.page_num]
+ imgs = img[: self.page_num]
else:
imgs = [img]
@@ -682,8 +726,7 @@ def preprocess_image(_image):
if not dt_boxes and not rec_res:
ocr_res.append(None)
continue
- tmp_res = [[box.tolist(), res]
- for box, res in zip(dt_boxes, rec_res)]
+ tmp_res = [[box.tolist(), res] for box, res in zip(dt_boxes, rec_res)]
ocr_res.append(tmp_res)
return ocr_res
elif det and not rec:
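Note: end to end, the reshaped ocr() signature is exercised like this; a minimal sketch assuming a local test image. Each page result is a list of [box, (text, score)] pairs, as built by tmp_res above:

    from paddleocr import PaddleOCR

    engine = PaddleOCR(lang="en", use_angle_cls=True, show_log=False)
    result = engine.ocr("doc.jpg", cls=True)    # "doc.jpg" is a placeholder path
    for page in result:
        for box, (text, score) in page:
            print(text, round(score, 3))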
@@ -719,43 +762,53 @@ class PPStructure(StructureSystem):
def __init__(self, **kwargs):
params = parse_args(mMain=False)
params.__dict__.update(**kwargs)
- assert params.structure_version in SUPPORT_STRUCTURE_MODEL_VERSION, "structure_version must in {}, but get {}".format(
- SUPPORT_STRUCTURE_MODEL_VERSION, params.structure_version)
+ assert (
+ params.structure_version in SUPPORT_STRUCTURE_MODEL_VERSION
+ ), "structure_version must in {}, but get {}".format(
+ SUPPORT_STRUCTURE_MODEL_VERSION, params.structure_version
+ )
params.use_gpu = check_gpu(params.use_gpu)
- params.mode = 'structure'
+ params.mode = "structure"
if not params.show_log:
logger.setLevel(logging.INFO)
lang, det_lang = parse_lang(params.lang)
- if lang == 'ch':
- table_lang = 'ch'
+ if lang == "ch":
+ table_lang = "ch"
else:
- table_lang = 'en'
- if params.structure_version == 'PP-Structure':
+ table_lang = "en"
+ if params.structure_version == "PP-Structure":
params.merge_no_span_structure = False
# init model dir
- det_model_config = get_model_config('OCR', params.ocr_version, 'det',
- det_lang)
+ det_model_config = get_model_config("OCR", params.ocr_version, "det", det_lang)
params.det_model_dir, det_url = confirm_model_dir_url(
params.det_model_dir,
- os.path.join(BASE_DIR, 'whl', 'det', det_lang),
- det_model_config['url'])
- rec_model_config = get_model_config('OCR', params.ocr_version, 'rec',
- lang)
+ os.path.join(BASE_DIR, "whl", "det", det_lang),
+ det_model_config["url"],
+ )
+ rec_model_config = get_model_config("OCR", params.ocr_version, "rec", lang)
params.rec_model_dir, rec_url = confirm_model_dir_url(
params.rec_model_dir,
- os.path.join(BASE_DIR, 'whl', 'rec', lang), rec_model_config['url'])
+ os.path.join(BASE_DIR, "whl", "rec", lang),
+ rec_model_config["url"],
+ )
table_model_config = get_model_config(
- 'STRUCTURE', params.structure_version, 'table', table_lang)
+ "STRUCTURE", params.structure_version, "table", table_lang
+ )
params.table_model_dir, table_url = confirm_model_dir_url(
params.table_model_dir,
- os.path.join(BASE_DIR, 'whl', 'table'), table_model_config['url'])
+ os.path.join(BASE_DIR, "whl", "table"),
+ table_model_config["url"],
+ )
layout_model_config = get_model_config(
- 'STRUCTURE', params.structure_version, 'layout', lang)
+ "STRUCTURE", params.structure_version, "layout", lang
+ )
params.layout_model_dir, layout_url = confirm_model_dir_url(
params.layout_model_dir,
- os.path.join(BASE_DIR, 'whl', 'layout'), layout_model_config['url'])
+ os.path.join(BASE_DIR, "whl", "layout"),
+ layout_model_config["url"],
+ )
# download model
maybe_download(params.det_model_dir, det_url)
maybe_download(params.rec_model_dir, rec_url)
@@ -764,20 +817,28 @@ def __init__(self, **kwargs):
if params.rec_char_dict_path is None:
params.rec_char_dict_path = str(
- Path(__file__).parent / rec_model_config['dict_path'])
+ Path(__file__).parent / rec_model_config["dict_path"]
+ )
if params.table_char_dict_path is None:
params.table_char_dict_path = str(
- Path(__file__).parent / table_model_config['dict_path'])
+ Path(__file__).parent / table_model_config["dict_path"]
+ )
if params.layout_dict_path is None:
params.layout_dict_path = str(
- Path(__file__).parent / layout_model_config['dict_path'])
+ Path(__file__).parent / layout_model_config["dict_path"]
+ )
logger.debug(params)
super().__init__(params)
- def __call__(self, img, return_ocr_result_in_table=False, img_idx=0, alpha_color=(255, 255, 255)):
+ def __call__(
+ self,
+ img,
+ return_ocr_result_in_table=False,
+ img_idx=0,
+ alpha_color=(255, 255, 255),
+ ):
img = check_img(img, alpha_color)
- res, _ = super().__call__(
- img, return_ocr_result_in_table, img_idx=img_idx)
+ res, _ = super().__call__(img, return_ocr_result_in_table, img_idx=img_idx)
return res
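Note: the reflowed __call__ keeps the same external contract. A minimal usage sketch assuming a local image; each region dict carries at least type, bbox, img and res, which main() below relies on when popping img/res:

    import cv2
    from paddleocr import PPStructure

    engine = PPStructure(show_log=False)
    img = cv2.imread("table_page.jpg")  # placeholder path
    for region in engine(img):
        print(region["type"], region["bbox"])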
@@ -786,24 +847,24 @@ def main():
args = parse_args(mMain=True)
image_dir = args.image_dir
if is_link(image_dir):
- download_with_progressbar(image_dir, 'tmp.jpg')
- image_file_list = ['tmp.jpg']
+ download_with_progressbar(image_dir, "tmp.jpg")
+ image_file_list = ["tmp.jpg"]
else:
image_file_list = get_image_file_list(args.image_dir)
if len(image_file_list) == 0:
- logger.error('no images find in {}'.format(args.image_dir))
+ logger.error("no images find in {}".format(args.image_dir))
return
- if args.type == 'ocr':
+ if args.type == "ocr":
engine = PaddleOCR(**(args.__dict__))
- elif args.type == 'structure':
+ elif args.type == "structure":
engine = PPStructure(**(args.__dict__))
else:
raise NotImplementedError
for img_path in image_file_list:
- img_name = os.path.basename(img_path).split('.')[0]
- logger.info('{}{}{}'.format('*' * 10, img_path, '*' * 10))
- if args.type == 'ocr':
+ img_name = os.path.basename(img_path).split(".")[0]
+ logger.info("{}{}{}".format("*" * 10, img_path, "*" * 10))
+ if args.type == "ocr":
result = engine.ocr(
img_path,
det=args.det,
@@ -811,7 +872,7 @@ def main():
cls=args.use_angle_cls,
bin=args.binarize,
inv=args.invert,
- alpha_color=args.alphacolor
+ alpha_color=args.alphacolor,
)
if result is not None:
lines = []
@@ -819,33 +880,33 @@ def main():
res = result[idx]
for line in res:
logger.info(line)
- val = '['
+ val = "["
for box in line[0]:
- val += str(box[0]) + ',' + str(box[1]) + ','
+ val += str(box[0]) + "," + str(box[1]) + ","
val = val[:-1]
- val += '],' + line[1][0] + ',' + str(line[1][1]) + '\n'
+ val += "]," + line[1][0] + "," + str(line[1][1]) + "\n"
lines.append(val)
if args.savefile:
if os.path.exists(args.output) is False:
os.mkdir(args.output)
- outfile = args.output + '/' + img_name + '.txt'
- with open(outfile,'w',encoding='utf-8') as f:
+ outfile = args.output + "/" + img_name + ".txt"
+ with open(outfile, "w", encoding="utf-8") as f:
f.writelines(lines)
-
- elif args.type == 'structure':
+
+ elif args.type == "structure":
img, flag_gif, flag_pdf = check_and_read(img_path)
if not flag_gif and not flag_pdf:
img = cv2.imread(img_path)
if args.recovery and args.use_pdf2docx_api and flag_pdf:
from pdf2docx.converter import Converter
- docx_file = os.path.join(args.output,
- '{}.docx'.format(img_name))
+
+ docx_file = os.path.join(args.output, "{}.docx".format(img_name))
cv = Converter(img_path)
cv.convert(docx_file)
cv.close()
- logger.info('docx save to {}'.format(docx_file))
+ logger.info("docx save to {}".format(docx_file))
continue
if not flag_pdf:
@@ -856,25 +917,24 @@ def main():
else:
img_paths = []
for index, pdf_img in enumerate(img):
- os.makedirs(
- os.path.join(args.output, img_name), exist_ok=True)
+ os.makedirs(os.path.join(args.output, img_name), exist_ok=True)
pdf_img_path = os.path.join(
- args.output, img_name,
- img_name + '_' + str(index) + '.jpg')
+ args.output, img_name, img_name + "_" + str(index) + ".jpg"
+ )
cv2.imwrite(pdf_img_path, pdf_img)
img_paths.append([pdf_img_path, pdf_img])
all_res = []
for index, (new_img_path, img) in enumerate(img_paths):
- logger.info('processing {}/{} page:'.format(index + 1,
- len(img_paths)))
- new_img_name = os.path.basename(new_img_path).split('.')[0]
+ logger.info("processing {}/{} page:".format(index + 1, len(img_paths)))
+ new_img_name = os.path.basename(new_img_path).split(".")[0]
result = engine(img, img_idx=index)
save_structure_res(result, args.output, img_name, index)
if args.recovery and result != []:
from copy import deepcopy
from ppstructure.recovery.recovery_to_doc import sorted_layout_boxes
+
h, w, _ = img.shape
result_cp = deepcopy(result)
result_sorted = sorted_layout_boxes(result_cp, w)
@@ -883,15 +943,18 @@ def main():
if args.recovery and all_res != []:
try:
from ppstructure.recovery.recovery_to_doc import convert_info_docx
+
convert_info_docx(img, all_res, args.output, img_name)
except Exception as ex:
logger.error(
"error in layout recovery image:{}, err msg: {}".format(
- img_name, ex))
+ img_name, ex
+ )
+ )
continue
for item in all_res:
- item.pop('img')
- item.pop('res')
+ item.pop("img")
+ item.pop("res")
logger.info(item)
- logger.info('result save to {}'.format(args.output))
+ logger.info("result save to {}".format(args.output))
diff --git a/ppocr/data/__init__.py b/ppocr/data/__init__.py
index 48cd8ad8c5..27d74c89d8 100644
--- a/ppocr/data/__init__.py
+++ b/ppocr/data/__init__.py
@@ -26,7 +26,7 @@
import random
__dir__ = os.path.dirname(os.path.abspath(__file__))
-sys.path.append(os.path.abspath(os.path.join(__dir__, '../..')))
+sys.path.append(os.path.abspath(os.path.join(__dir__, "../..")))
import copy
from paddle.io import Dataset, DataLoader, BatchSampler, DistributedBatchSampler
@@ -46,14 +46,11 @@
PubTabTableRecDataset = PubTabDataSet
KieDataset = SimpleDataSet
-__all__ = [
- 'build_dataloader', 'transform', 'create_operators', 'set_signal_handlers'
-]
+__all__ = ["build_dataloader", "transform", "create_operators", "set_signal_handlers"]
def term_mp(sig_num, frame):
- """ kill all child processes
- """
+ """kill all child processes"""
pid = os.getpid()
pgid = os.getpgid(os.getpid())
print("main proc {} exit, kill process group " "{}".format(pid, pgid))
@@ -69,11 +66,11 @@ def set_signal_handlers():
# because we cannot do safe cleanup.
pass
else:
- # XXX: `term_mp` kills all processes in the process group, which in
- # some cases includes the parent process of current process and may
- # cause unexpected results. To solve this problem, we set signal
- # handlers only when current process is the group leader. In the
- # future, it would be better to consider killing only descendants of
+ # XXX: `term_mp` kills all processes in the process group, which in
+ # some cases includes the parent process of current process and may
+ # cause unexpected results. To solve this problem, we set signal
+ # handlers only when current process is the group leader. In the
+ # future, it would be better to consider killing only descendants of
# the current process.
if pid == pgid:
# support exit using ctrl+c
@@ -85,40 +82,40 @@ def build_dataloader(config, mode, device, logger, seed=None):
config = copy.deepcopy(config)
support_dict = [
- 'SimpleDataSet',
- 'LMDBDataSet',
- 'PGDataSet',
- 'PubTabDataSet',
- 'LMDBDataSetSR',
- 'LMDBDataSetTableMaster',
- 'MultiScaleDataSet',
- 'TextDetDataset',
- 'TextRecDataset',
- 'MSTextRecDataset',
- 'PubTabTableRecDataset',
- 'KieDataset',
+ "SimpleDataSet",
+ "LMDBDataSet",
+ "PGDataSet",
+ "PubTabDataSet",
+ "LMDBDataSetSR",
+ "LMDBDataSetTableMaster",
+ "MultiScaleDataSet",
+ "TextDetDataset",
+ "TextRecDataset",
+ "MSTextRecDataset",
+ "PubTabTableRecDataset",
+ "KieDataset",
]
- module_name = config[mode]['dataset']['name']
+ module_name = config[mode]["dataset"]["name"]
assert module_name in support_dict, Exception(
- 'DataSet only support {}'.format(support_dict))
- assert mode in ['Train', 'Eval', 'Test'
- ], "Mode should be Train, Eval or Test."
+ "DataSet only support {}".format(support_dict)
+ )
+ assert mode in ["Train", "Eval", "Test"], "Mode should be Train, Eval or Test."
dataset = eval(module_name)(config, mode, logger, seed)
- loader_config = config[mode]['loader']
- batch_size = loader_config['batch_size_per_card']
- drop_last = loader_config['drop_last']
- shuffle = loader_config['shuffle']
- num_workers = loader_config['num_workers']
- if 'use_shared_memory' in loader_config.keys():
- use_shared_memory = loader_config['use_shared_memory']
+ loader_config = config[mode]["loader"]
+ batch_size = loader_config["batch_size_per_card"]
+ drop_last = loader_config["drop_last"]
+ shuffle = loader_config["shuffle"]
+ num_workers = loader_config["num_workers"]
+ if "use_shared_memory" in loader_config.keys():
+ use_shared_memory = loader_config["use_shared_memory"]
else:
use_shared_memory = True
if mode == "Train":
# Distribute data to multiple cards
- if 'sampler' in config[mode]:
- config_sampler = config[mode]['sampler']
+ if "sampler" in config[mode]:
+ config_sampler = config[mode]["sampler"]
sampler_name = config_sampler.pop("name")
batch_sampler = eval(sampler_name)(dataset, **config_sampler)
else:
@@ -126,18 +123,18 @@ def build_dataloader(config, mode, device, logger, seed=None):
dataset=dataset,
batch_size=batch_size,
shuffle=shuffle,
- drop_last=drop_last)
+ drop_last=drop_last,
+ )
else:
# Distribute data to single card
batch_sampler = BatchSampler(
- dataset=dataset,
- batch_size=batch_size,
- shuffle=shuffle,
- drop_last=drop_last)
+ dataset=dataset, batch_size=batch_size, shuffle=shuffle, drop_last=drop_last
+ )
- if 'collate_fn' in loader_config:
+ if "collate_fn" in loader_config:
from . import collate_fn
- collate_fn = getattr(collate_fn, loader_config['collate_fn'])()
+
+ collate_fn = getattr(collate_fn, loader_config["collate_fn"])()
else:
collate_fn = None
data_loader = DataLoader(
@@ -147,6 +144,7 @@ def build_dataloader(config, mode, device, logger, seed=None):
num_workers=num_workers,
return_list=True,
use_shared_memory=use_shared_memory,
- collate_fn=collate_fn)
+ collate_fn=collate_fn,
+ )
return data_loader
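Note: the keys consumed above map directly onto the YAML recipes; a minimal illustrative config dict, with placeholder dataset paths rather than a working recipe:

    config = {
        "Train": {
            "dataset": {
                "name": "SimpleDataSet",                        # must be in support_dict
                "data_dir": "./train_data/",                    # placeholder
                "label_file_list": ["./train_data/label.txt"],  # placeholder
            },
            "loader": {
                "batch_size_per_card": 8,
                "drop_last": True,
                "shuffle": True,
                "num_workers": 4,
                "use_shared_memory": False,
            },
        },
    }
    # loader = build_dataloader(config, "Train", device, logger)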
diff --git a/ppocr/data/collate_fn.py b/ppocr/data/collate_fn.py
index 067b2158ac..f1f317510b 100644
--- a/ppocr/data/collate_fn.py
+++ b/ppocr/data/collate_fn.py
@@ -24,7 +24,7 @@ class DictCollator(object):
"""
def __call__(self, batch):
- # todo:support batch operators
+ # TODO: support batch operators
data_dict = defaultdict(list)
to_tensor_keys = []
for sample in batch:
@@ -44,7 +44,7 @@ class ListCollator(object):
"""
def __call__(self, batch):
- # todo:support batch operators
+ # TODO: support batch operators
data_dict = defaultdict(list)
to_tensor_idxs = []
for sample in batch:
@@ -88,24 +88,24 @@ def __call__(self, batch):
bs, channel = len(batch), batch[0][0].shape[0]
proper_items = []
for item in batch:
- if item[0].shape[1] * max_width > 1600 * 320 or item[0].shape[
- 2] * max_height > 1600 * 320:
+ if (
+ item[0].shape[1] * max_width > 1600 * 320
+ or item[0].shape[2] * max_height > 1600 * 320
+ ):
continue
- max_height = item[0].shape[1] if item[0].shape[
- 1] > max_height else max_height
- max_width = item[0].shape[2] if item[0].shape[
- 2] > max_width else max_width
- max_length = len(item[1]) if len(item[
- 1]) > max_length else max_length
+ max_height = (
+ item[0].shape[1] if item[0].shape[1] > max_height else max_height
+ )
+ max_width = item[0].shape[2] if item[0].shape[2] > max_width else max_width
+ max_length = len(item[1]) if len(item[1]) > max_length else max_length
proper_items.append(item)
images, image_masks = np.zeros(
- (len(proper_items), channel, max_height, max_width),
- dtype='float32'), np.zeros(
- (len(proper_items), 1, max_height, max_width), dtype='float32')
+ (len(proper_items), channel, max_height, max_width), dtype="float32"
+ ), np.zeros((len(proper_items), 1, max_height, max_width), dtype="float32")
labels, label_masks = np.zeros(
- (len(proper_items), max_length), dtype='int64'), np.zeros(
- (len(proper_items), max_length), dtype='int64')
+ (len(proper_items), max_length), dtype="int64"
+ ), np.zeros((len(proper_items), max_length), dtype="int64")
for i in range(len(proper_items)):
_, h, w = proper_items[i][0].shape
diff --git a/ppocr/data/imaug/ColorJitter.py b/ppocr/data/imaug/ColorJitter.py
index 4b542abc8f..46c1955e4f 100644
--- a/ppocr/data/imaug/ColorJitter.py
+++ b/ppocr/data/imaug/ColorJitter.py
@@ -13,14 +13,15 @@
# limitations under the License.
from paddle.vision.transforms import ColorJitter as pp_ColorJitter
-__all__ = ['ColorJitter']
+__all__ = ["ColorJitter"]
+
class ColorJitter(object):
- def __init__(self, brightness=0, contrast=0, saturation=0, hue=0,**kwargs):
+ def __init__(self, brightness=0, contrast=0, saturation=0, hue=0, **kwargs):
self.aug = pp_ColorJitter(brightness, contrast, saturation, hue)
def __call__(self, data):
- image = data['image']
+ image = data["image"]
image = self.aug(image)
- data['image'] = image
+ data["image"] = image
return data
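Note: a quick check of the op against the pipeline's data dict; the sketch assumes a uint8 HWC image, which paddle.vision's ColorJitter accepts:

    import numpy as np
    from ppocr.data.imaug.ColorJitter import ColorJitter

    op = ColorJitter(brightness=0.3, contrast=0.3, saturation=0.3, hue=0.1)
    img = (np.random.rand(48, 160, 3) * 255).astype(np.uint8)
    out = op({"image": img})["image"]  # jittered image written back under "image"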
diff --git a/ppocr/data/imaug/__init__.py b/ppocr/data/imaug/__init__.py
index 1eb611f6c0..350887933b 100644
--- a/ppocr/data/imaug/__init__.py
+++ b/ppocr/data/imaug/__init__.py
@@ -23,11 +23,26 @@
from .make_pse_gt import MakePseGt
-
-from .rec_img_aug import BaseDataAugmentation, RecAug, RecConAug, RecResizeImg, ClsResizeImg, \
- SRNRecResizeImg, GrayRecResizeImg, SARRecResizeImg, PRENResizeImg, \
- ABINetRecResizeImg, SVTRRecResizeImg, ABINetRecAug, VLRecResizeImg, SPINRecResizeImg, RobustScannerRecResizeImg, \
- RFLRecResizeImg, SVTRRecAug, ParseQRecAug
+from .rec_img_aug import (
+ BaseDataAugmentation,
+ RecAug,
+ RecConAug,
+ RecResizeImg,
+ ClsResizeImg,
+ SRNRecResizeImg,
+ GrayRecResizeImg,
+ SARRecResizeImg,
+ PRENResizeImg,
+ ABINetRecResizeImg,
+ SVTRRecResizeImg,
+ ABINetRecAug,
+ VLRecResizeImg,
+ SPINRecResizeImg,
+ RobustScannerRecResizeImg,
+ RFLRecResizeImg,
+ SVTRRecAug,
+ ParseQRecAug,
+)
from .ssl_img_aug import SSLRotateResize
from .randaugment import RandAugment
from .copy_paste import CopyPaste
@@ -49,7 +64,7 @@
def transform(data, ops=None):
- """ transform """
+ """transform"""
if ops is None:
ops = []
for op in ops:
@@ -66,11 +81,10 @@ def create_operators(op_param_list, global_config=None):
Args:
params(list): a dict list, used to create some operators
"""
- assert isinstance(op_param_list, list), ('operator config should be a list')
+ assert isinstance(op_param_list, list), "operator config should be a list"
ops = []
for operator in op_param_list:
- assert isinstance(operator,
- dict) and len(operator) == 1, "yaml format error"
+ assert isinstance(operator, dict) and len(operator) == 1, "yaml format error"
op_name = list(operator)[0]
param = {} if operator[op_name] is None else operator[op_name]
if global_config is not None:
diff --git a/ppocr/data/imaug/abinet_aug.py b/ppocr/data/imaug/abinet_aug.py
index 9e1b6a6ce9..3df255ee93 100644
--- a/ppocr/data/imaug/abinet_aug.py
+++ b/ppocr/data/imaug/abinet_aug.py
@@ -36,31 +36,28 @@ def sample_uniform(low, high, size=None):
return np.random.uniform(low, high, size=size)
-def get_interpolation(type='random'):
- if type == 'random':
- choice = [
- cv2.INTER_NEAREST, cv2.INTER_LINEAR, cv2.INTER_CUBIC, cv2.INTER_AREA
- ]
+def get_interpolation(type="random"):
+ if type == "random":
+ choice = [cv2.INTER_NEAREST, cv2.INTER_LINEAR, cv2.INTER_CUBIC, cv2.INTER_AREA]
interpolation = choice[random.randint(0, len(choice) - 1)]
- elif type == 'nearest':
+ elif type == "nearest":
interpolation = cv2.INTER_NEAREST
- elif type == 'linear':
+ elif type == "linear":
interpolation = cv2.INTER_LINEAR
- elif type == 'cubic':
+ elif type == "cubic":
interpolation = cv2.INTER_CUBIC
- elif type == 'area':
+ elif type == "area":
interpolation = cv2.INTER_AREA
else:
raise TypeError(
- 'Interpolation types only nearest, linear, cubic, area are supported!'
+ "Interpolation types only nearest, linear, cubic, area are supported!"
)
return interpolation
class CVRandomRotation(object):
def __init__(self, degrees=15):
- assert isinstance(degrees,
- numbers.Number), "degree should be a single number."
+ assert isinstance(degrees, numbers.Number), "degree should be a single number."
assert degrees >= 0, "degree must be positive."
self.degrees = degrees
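Note: these transforms operate directly on OpenCV ndarrays; a quick sketch of the rotation op, import path assumed:

    import numpy as np
    from ppocr.data.imaug.abinet_aug import CVRandomRotation

    rot = CVRandomRotation(degrees=15)
    crop = (np.random.rand(32, 128, 3) * 255).astype(np.uint8)
    out = rot(crop)  # canvas grows to (dst_w, dst_h) so the rotated crop is not clipped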
@@ -72,7 +69,8 @@ def __call__(self, img):
angle = self.get_params(self.degrees)
src_h, src_w = img.shape[:2]
M = cv2.getRotationMatrix2D(
- center=(src_w / 2, src_h / 2), angle=angle, scale=1.0)
+ center=(src_w / 2, src_h / 2), angle=angle, scale=1.0
+ )
abs_cos, abs_sin = abs(M[0, 0]), abs(M[0, 1])
dst_w = int(src_h * abs_sin + src_w * abs_cos)
dst_h = int(src_h * abs_cos + src_w * abs_sin)
@@ -81,31 +79,29 @@ def __call__(self, img):
flags = get_interpolation()
return cv2.warpAffine(
- img,
- M, (dst_w, dst_h),
- flags=flags,
- borderMode=cv2.BORDER_REPLICATE)
+ img, M, (dst_w, dst_h), flags=flags, borderMode=cv2.BORDER_REPLICATE
+ )
class CVRandomAffine(object):
def __init__(self, degrees, translate=None, scale=None, shear=None):
- assert isinstance(degrees,
- numbers.Number), "degree should be a single number."
+ assert isinstance(degrees, numbers.Number), "degree should be a single number."
assert degrees >= 0, "degree must be positive."
self.degrees = degrees
if translate is not None:
- assert isinstance(translate, (tuple, list)) and len(translate) == 2, \
- "translate should be a list or tuple and it must be of length 2."
+ assert (
+ isinstance(translate, (tuple, list)) and len(translate) == 2
+ ), "translate should be a list or tuple and it must be of length 2."
for t in translate:
if not (0.0 <= t <= 1.0):
- raise ValueError(
- "translation values should be between 0 and 1")
+ raise ValueError("translation values should be between 0 and 1")
self.translate = translate
if scale is not None:
- assert isinstance(scale, (tuple, list)) and len(scale) == 2, \
- "scale should be a list or tuple and it must be of length 2."
+ assert (
+ isinstance(scale, (tuple, list)) and len(scale) == 2
+ ), "scale should be a list or tuple and it must be of length 2."
for s in scale:
if s <= 0:
raise ValueError("scale values should be positive")
@@ -115,17 +111,18 @@ def __init__(self, degrees, translate=None, scale=None, shear=None):
if isinstance(shear, numbers.Number):
if shear < 0:
raise ValueError(
- "If shear is a single number, it must be positive.")
+ "If shear is a single number, it must be positive."
+ )
self.shear = [shear]
else:
- assert isinstance(shear, (tuple, list)) and (len(shear) == 2), \
- "shear should be a list or tuple and it must be of length 2."
+ assert isinstance(shear, (tuple, list)) and (
+ len(shear) == 2
+ ), "shear should be a list or tuple and it must be of length 2."
self.shear = shear
else:
self.shear = shear
- def _get_inverse_affine_matrix(self, center, angle, translate, scale,
- shear):
+ def _get_inverse_affine_matrix(self, center, angle, translate, scale, shear):
# https://github.com/pytorch/vision/blob/v0.4.0/torchvision/transforms/functional.py#L717
from numpy import sin, cos, tan
@@ -134,8 +131,9 @@ def _get_inverse_affine_matrix(self, center, angle, translate, scale,
if not isinstance(shear, (tuple, list)) and len(shear) == 2:
raise ValueError(
- "Shear should be a single value or a tuple/list containing " +
- "two values. Got {}".format(shear))
+ "Shear should be a single value or a tuple/list containing "
+ + "two values. Got {}".format(shear)
+ )
rot = math.radians(angle)
sx, sy = [math.radians(s) for s in shear]
@@ -169,8 +167,7 @@ def get_params(degrees, translate, scale_ranges, shears, height):
if translate is not None:
max_dx = translate[0] * height
max_dy = translate[1] * height
- translations = (np.round(sample_sym(max_dx)),
- np.round(sample_sym(max_dy)))
+ translations = (np.round(sample_sym(max_dx)), np.round(sample_sym(max_dy)))
else:
translations = (0, 0)
@@ -181,7 +178,7 @@ def get_params(degrees, translate, scale_ranges, shears, height):
if shears is not None:
if len(shears) == 1:
- shear = [sample_sym(shears[0]), 0.]
+ shear = [sample_sym(shears[0]), 0.0]
elif len(shears) == 2:
shear = [sample_sym(shears[0]), sample_sym(shears[1])]
else:
@@ -192,17 +189,19 @@ def get_params(degrees, translate, scale_ranges, shears, height):
def __call__(self, img):
src_h, src_w = img.shape[:2]
angle, translate, scale, shear = self.get_params(
- self.degrees, self.translate, self.scale, self.shear, src_h)
+ self.degrees, self.translate, self.scale, self.shear, src_h
+ )
- M = self._get_inverse_affine_matrix((src_w / 2, src_h / 2), angle,
- (0, 0), scale, shear)
+ M = self._get_inverse_affine_matrix(
+ (src_w / 2, src_h / 2), angle, (0, 0), scale, shear
+ )
M = np.array(M).reshape(2, 3)
- startpoints = [(0, 0), (src_w - 1, 0), (src_w - 1, src_h - 1),
- (0, src_h - 1)]
+ startpoints = [(0, 0), (src_w - 1, 0), (src_w - 1, src_h - 1), (0, src_h - 1)]
project = lambda x, y, a, b, c: int(a * x + b * y + c)
- endpoints = [(project(x, y, *M[0]), project(x, y, *M[1]))
- for x, y in startpoints]
+ endpoints = [
+ (project(x, y, *M[0]), project(x, y, *M[1])) for x, y in startpoints
+ ]
rect = cv2.minAreaRect(np.array(endpoints))
bbox = cv2.boxPoints(rect).astype(dtype=np.int32)
@@ -217,15 +216,15 @@ def __call__(self, img):
# add translate
dst_w += int(abs(translate[0]))
dst_h += int(abs(translate[1]))
- if translate[0] < 0: M[0, 2] += abs(translate[0])
- if translate[1] < 0: M[1, 2] += abs(translate[1])
+ if translate[0] < 0:
+ M[0, 2] += abs(translate[0])
+ if translate[1] < 0:
+ M[1, 2] += abs(translate[1])
flags = get_interpolation()
return cv2.warpAffine(
- img,
- M, (dst_w, dst_h),
- flags=flags,
- borderMode=cv2.BORDER_REPLICATE)
+ img, M, (dst_w, dst_h), flags=flags, borderMode=cv2.BORDER_REPLICATE
+ )
class CVRandomPerspective(object):
@@ -233,21 +232,18 @@ def __init__(self, distortion=0.5):
self.distortion = distortion
def get_params(self, width, height, distortion):
- offset_h = sample_asym(
- distortion * height / 2, size=4).astype(dtype=np.int32)
- offset_w = sample_asym(
- distortion * width / 2, size=4).astype(dtype=np.int32)
+ offset_h = sample_asym(distortion * height / 2, size=4).astype(dtype=np.int32)
+ offset_w = sample_asym(distortion * width / 2, size=4).astype(dtype=np.int32)
topleft = (offset_w[0], offset_h[0])
topright = (width - 1 - offset_w[1], offset_h[1])
botright = (width - 1 - offset_w[2], height - 1 - offset_h[2])
botleft = (offset_w[3], height - 1 - offset_h[3])
- startpoints = [(0, 0), (width - 1, 0), (width - 1, height - 1),
- (0, height - 1)]
+ startpoints = [(0, 0), (width - 1, 0), (width - 1, height - 1), (0, height - 1)]
endpoints = [topleft, topright, botright, botleft]
- return np.array(
- startpoints, dtype=np.float32), np.array(
- endpoints, dtype=np.float32)
+ return np.array(startpoints, dtype=np.float32), np.array(
+ endpoints, dtype=np.float32
+ )
def __call__(self, img):
height, width = img.shape[:2]
@@ -263,18 +259,16 @@ def __call__(self, img):
flags = get_interpolation()
img = cv2.warpPerspective(
- img,
- M, (max_x, max_y),
- flags=flags,
- borderMode=cv2.BORDER_REPLICATE)
+ img, M, (max_x, max_y), flags=flags, borderMode=cv2.BORDER_REPLICATE
+ )
img = img[min_y:, min_x:]
return img
class CVRescale(object):
def __init__(self, factor=4, base_size=(128, 512)):
- """ Define image scales using gaussian pyramid and rescale image to target scale.
-
+ """Define image scales using gaussian pyramid and rescale image to target scale.
+
Args:
factor: the decayed factor from base size, factor=4 keeps target scale by default.
base_size: base size the build the bottom layer of pyramid
@@ -284,20 +278,21 @@ def __init__(self, factor=4, base_size=(128, 512)):
elif isinstance(factor, (tuple, list)) and len(factor) == 2:
self.factor = round(sample_uniform(factor[0], factor[1]))
else:
- raise Exception('factor must be number or list with length 2')
+ raise Exception("factor must be number or list with length 2")
# assert factor is valid
self.base_h, self.base_w = base_size[:2]
def __call__(self, img):
- if self.factor == 0: return img
+ if self.factor == 0:
+ return img
src_h, src_w = img.shape[:2]
cur_w, cur_h = self.base_w, self.base_h
- scale_img = cv2.resize(
- img, (cur_w, cur_h), interpolation=get_interpolation())
+ scale_img = cv2.resize(img, (cur_w, cur_h), interpolation=get_interpolation())
for _ in range(self.factor):
scale_img = cv2.pyrDown(scale_img)
scale_img = cv2.resize(
- scale_img, (src_w, src_h), interpolation=get_interpolation())
+ scale_img, (src_w, src_h), interpolation=get_interpolation()
+ )
return scale_img
@@ -309,13 +304,14 @@ def __init__(self, mean=0, var=20):
elif isinstance(var, (tuple, list)) and len(var) == 2:
self.var = int(sample_uniform(var[0], var[1]))
else:
- raise Exception('degree must be number or list with length 2')
+ raise Exception("degree must be number or list with length 2")
def __call__(self, img):
noise = np.random.normal(self.mean, self.var**0.5, img.shape)
img = np.clip(img + noise, 0, 255).astype(np.uint8)
return img
+
class CVPossionNoise(object):
def __init__(self, lam=20):
self.lam = lam
@@ -324,13 +320,14 @@ def __init__(self, lam=20):
elif isinstance(lam, (tuple, list)) and len(lam) == 2:
self.lam = int(sample_uniform(lam[0], lam[1]))
else:
- raise Exception('lam must be number or list with length 2')
+ raise Exception("lam must be number or list with length 2")
def __call__(self, img):
noise = np.random.poisson(lam=self.lam, size=img.shape)
img = np.clip(img + noise, 0, 255).astype(np.uint8)
return img
+
class CVGaussionBlur(object):
def __init__(self, radius):
self.radius = radius
@@ -339,13 +336,14 @@ def __init__(self, radius):
elif isinstance(radius, (tuple, list)) and len(radius) == 2:
self.radius = int(sample_uniform(radius[0], radius[1]))
else:
- raise Exception('radius must be number or list with length 2')
+ raise Exception("radius must be number or list with length 2")
def __call__(self, img):
fil = cv2.getGaussianKernel(ksize=self.radius, sigma=1, ktype=cv2.CV_32F)
img = cv2.sepFilter2D(img, -1, fil, fil)
return img
+
class CVMotionBlur(object):
def __init__(self, degrees=12, angle=90):
if isinstance(degrees, numbers.Number):
@@ -353,16 +351,16 @@ def __init__(self, degrees=12, angle=90):
elif isinstance(degrees, (tuple, list)) and len(degrees) == 2:
self.degree = int(sample_uniform(degrees[0], degrees[1]))
else:
- raise Exception('degree must be number or list with length 2')
+ raise Exception("degree must be number or list with length 2")
self.angle = sample_uniform(-angle, angle)
def __call__(self, img):
- M = cv2.getRotationMatrix2D((self.degree // 2, self.degree // 2),
- self.angle, 1)
+ M = cv2.getRotationMatrix2D((self.degree // 2, self.degree // 2), self.angle, 1)
motion_blur_kernel = np.zeros((self.degree, self.degree))
motion_blur_kernel[self.degree // 2, :] = 1
- motion_blur_kernel = cv2.warpAffine(motion_blur_kernel, M,
- (self.degree, self.degree))
+ motion_blur_kernel = cv2.warpAffine(
+ motion_blur_kernel, M, (self.degree, self.degree)
+ )
motion_blur_kernel = motion_blur_kernel / self.degree
img = cv2.filter2D(img, -1, motion_blur_kernel)
img = np.clip(img, 0, 255).astype(np.uint8)
@@ -370,20 +368,23 @@ def __call__(self, img):
class CVGeometry(object):
- def __init__(self,
- degrees=15,
- translate=(0.3, 0.3),
- scale=(0.5, 2.),
- shear=(45, 15),
- distortion=0.5,
- p=0.5):
+ def __init__(
+ self,
+ degrees=15,
+ translate=(0.3, 0.3),
+ scale=(0.5, 2.0),
+ shear=(45, 15),
+ distortion=0.5,
+ p=0.5,
+ ):
self.p = p
type_p = random.random()
if type_p < 0.33:
self.transforms = CVRandomRotation(degrees=degrees)
elif type_p < 0.66:
self.transforms = CVRandomAffine(
- degrees=degrees, translate=translate, scale=scale, shear=shear)
+ degrees=degrees, translate=translate, scale=scale, shear=shear
+ )
else:
self.transforms = CVRandomPerspective(distortion=distortion)
@@ -411,29 +412,23 @@ def __init__(self, var, degrees, factor, p=0.5):
def __call__(self, img):
if random.random() < self.p:
-
return self.transforms(img)
else:
return img
class CVColorJitter(object):
- def __init__(self,
- brightness=0.5,
- contrast=0.5,
- saturation=0.5,
- hue=0.1,
- p=0.5):
+ def __init__(self, brightness=0.5, contrast=0.5, saturation=0.5, hue=0.1, p=0.5):
self.p = p
self.transforms = ColorJitter(
- brightness=brightness,
- contrast=contrast,
- saturation=saturation,
- hue=hue)
+ brightness=brightness, contrast=contrast, saturation=saturation, hue=hue
+ )
def __call__(self, img):
- if random.random() < self.p: return self.transforms(img)
- else: return img
+ if random.random() < self.p:
+ return self.transforms(img)
+ else:
+ return img
class SVTRDeterioration(object):
@@ -456,6 +451,7 @@ def __call__(self, img):
else:
return img
+
class ParseQDeterioration(object):
def __init__(self, var, degrees, lam, radius, factor, p=0.5):
self.p = p
@@ -480,29 +476,34 @@ def __call__(self, img):
else:
return img
+
class SVTRGeometry(object):
- def __init__(self,
- aug_type=0,
- degrees=15,
- translate=(0.3, 0.3),
- scale=(0.5, 2.),
- shear=(45, 15),
- distortion=0.5,
- p=0.5):
+ def __init__(
+ self,
+ aug_type=0,
+ degrees=15,
+ translate=(0.3, 0.3),
+ scale=(0.5, 2.0),
+ shear=(45, 15),
+ distortion=0.5,
+ p=0.5,
+ ):
self.aug_type = aug_type
self.p = p
self.transforms = []
self.transforms.append(CVRandomRotation(degrees=degrees))
self.transforms.append(
CVRandomAffine(
- degrees=degrees, translate=translate, scale=scale, shear=shear))
+ degrees=degrees, translate=translate, scale=scale, shear=shear
+ )
+ )
self.transforms.append(CVRandomPerspective(distortion=distortion))
def __call__(self, img):
if random.random() < self.p:
if self.aug_type:
random.shuffle(self.transforms)
- transforms = Compose(self.transforms[:random.randint(1, 3)])
+ transforms = Compose(self.transforms[: random.randint(1, 3)])
img = transforms(img)
else:
img = self.transforms[random.randint(0, 2)](img)
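Note: a typical SVTR-style augmentation chain built from the constructors above; the probabilities are illustrative:

    import numpy as np
    from ppocr.data.imaug.abinet_aug import SVTRGeometry, SVTRDeterioration

    geometry = SVTRGeometry(aug_type=1, degrees=15, p=0.5)  # shuffles, composes 1-3 ops
    deterioration = SVTRDeterioration(var=20, degrees=6, factor=4, p=0.25)
    img = (np.random.rand(48, 320, 3) * 255).astype(np.uint8)
    img = deterioration(geometry(img))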
diff --git a/ppocr/data/imaug/copy_paste.py b/ppocr/data/imaug/copy_paste.py
index 79343da60f..3634fbe077 100644
--- a/ppocr/data/imaug/copy_paste.py
+++ b/ppocr/data/imaug/copy_paste.py
@@ -28,24 +28,23 @@ def __init__(self, objects_paste_ratio=0.2, limit_paste=True, **kwargs):
self.ext_data_num = 1
self.objects_paste_ratio = objects_paste_ratio
self.limit_paste = limit_paste
- augmenter_args = [{'type': 'Resize', 'args': {'size': [0.5, 3]}}]
+ augmenter_args = [{"type": "Resize", "args": {"size": [0.5, 3]}}]
self.aug = IaaAugment(augmenter_args)
def __call__(self, data):
- point_num = data['polys'].shape[1]
- src_img = data['image']
- src_polys = data['polys'].tolist()
- src_texts = data['texts']
- src_ignores = data['ignore_tags'].tolist()
- ext_data = data['ext_data'][0]
- ext_image = ext_data['image']
- ext_polys = ext_data['polys']
- ext_texts = ext_data['texts']
- ext_ignores = ext_data['ignore_tags']
+ point_num = data["polys"].shape[1]
+ src_img = data["image"]
+ src_polys = data["polys"].tolist()
+ src_texts = data["texts"]
+ src_ignores = data["ignore_tags"].tolist()
+ ext_data = data["ext_data"][0]
+ ext_image = ext_data["image"]
+ ext_polys = ext_data["polys"]
+ ext_texts = ext_data["texts"]
+ ext_ignores = ext_data["ignore_tags"]
indexs = [i for i in range(len(ext_ignores)) if not ext_ignores[i]]
- select_num = max(
- 1, min(int(self.objects_paste_ratio * len(ext_polys)), 30))
+ select_num = max(1, min(int(self.objects_paste_ratio * len(ext_polys)), 30))
random.shuffle(indexs)
select_idxs = indexs[:select_num]
@@ -54,13 +53,13 @@ def __call__(self, data):
src_img = cv2.cvtColor(src_img, cv2.COLOR_BGR2RGB)
ext_image = cv2.cvtColor(ext_image, cv2.COLOR_BGR2RGB)
- src_img = Image.fromarray(src_img).convert('RGBA')
+ src_img = Image.fromarray(src_img).convert("RGBA")
for idx, poly, tag in zip(select_idxs, select_polys, select_ignores):
box_img = get_rotate_crop_image(ext_image, poly)
src_img, box = self.paste_img(src_img, box_img, src_polys)
if box is not None:
- box = box.tolist()
+ box = box.tolist()
for _ in range(len(box), point_num):
box.append(box[-1])
src_polys.append(box)
@@ -71,14 +70,14 @@ def __call__(self, data):
src_polys = np.array(src_polys)
src_polys[:, :, 0] = np.clip(src_polys[:, :, 0], 0, w)
src_polys[:, :, 1] = np.clip(src_polys[:, :, 1], 0, h)
- data['image'] = src_img
- data['polys'] = src_polys
- data['texts'] = src_texts
- data['ignore_tags'] = np.array(src_ignores)
+ data["image"] = src_img
+ data["polys"] = src_polys
+ data["texts"] = src_texts
+ data["ignore_tags"] = np.array(src_ignores)
return data
def paste_img(self, src_img, box_img, src_polys):
- box_img_pil = Image.fromarray(box_img).convert('RGBA')
+ box_img_pil = Image.fromarray(box_img).convert("RGBA")
src_w, src_h = src_img.size
box_w, box_h = box_img_pil.size
@@ -90,8 +89,9 @@ def paste_img(self, src_img, box_img, src_polys):
if src_w - box_w < 0 or src_h - box_h < 0:
return src_img, None
- paste_x, paste_y = self.select_coord(src_polys, box, src_w - box_w,
- src_h - box_h)
+ paste_x, paste_y = self.select_coord(
+ src_polys, box, src_w - box_w, src_h - box_h
+ )
if paste_x is None:
return src_img, None
box[:, 0] += paste_x
@@ -103,8 +103,12 @@ def paste_img(self, src_img, box_img, src_polys):
def select_coord(self, src_polys, box, endx, endy):
if self.limit_paste:
- xmin, ymin, xmax, ymax = box[:, 0].min(), box[:, 1].min(
- ), box[:, 0].max(), box[:, 1].max()
+ xmin, ymin, xmax, ymax = (
+ box[:, 0].min(),
+ box[:, 1].min(),
+ box[:, 0].max(),
+ box[:, 1].max(),
+ )
for _ in range(50):
paste_x = random.randint(0, endx)
paste_y = random.randint(0, endy)
@@ -115,8 +119,9 @@ def select_coord(self, src_polys, box, endx, endy):
num_poly_in_rect = 0
for poly in src_polys:
- if not is_poly_outside_rect(poly, xmin1, ymin1,
- xmax1 - xmin1, ymax1 - ymin1):
+ if not is_poly_outside_rect(
+ poly, xmin1, ymin1, xmax1 - xmin1, ymax1 - ymin1
+ ):
num_poly_in_rect += 1
break
if num_poly_in_rect == 0:
@@ -156,8 +161,8 @@ def rotate_bbox(img, text_polys, angle, scale=1):
h = img.shape[0]
rangle = np.deg2rad(angle)
- nw = (abs(np.sin(rangle) * h) + abs(np.cos(rangle) * w))
- nh = (abs(np.cos(rangle) * h) + abs(np.sin(rangle) * w))
+ nw = abs(np.sin(rangle) * h) + abs(np.cos(rangle) * w)
+ nh = abs(np.cos(rangle) * h) + abs(np.sin(rangle) * w)
rot_mat = cv2.getRotationMatrix2D((nw * 0.5, nh * 0.5), angle, scale)
rot_move = np.dot(rot_mat, np.array([(nw - w) * 0.5, (nh - h) * 0.5, 0]))
rot_mat[0, 2] += rot_move[0]
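Note: the nw/nh expressions above are the rotated bounding-canvas formula nw = |sin(theta)|*h + |cos(theta)|*w and nh = |cos(theta)|*h + |sin(theta)|*w; a quick numeric check:

    import numpy as np

    theta = np.deg2rad(90.0)
    w, h = 200, 100
    nw = abs(np.sin(theta) * h) + abs(np.cos(theta) * w)  # ~100.0: axes swap at 90 deg
    nh = abs(np.cos(theta) * h) + abs(np.sin(theta) * w)  # ~200.0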
diff --git a/ppocr/data/imaug/ct_process.py b/ppocr/data/imaug/ct_process.py
index 2434c91609..26e111dd85 100644
--- a/ppocr/data/imaug/ct_process.py
+++ b/ppocr/data/imaug/ct_process.py
@@ -25,7 +25,7 @@
from ppocr.utils.utility import check_install
-class RandomScale():
+class RandomScale:
def __init__(self, short_size=640, **kwargs):
self.short_size = short_size
@@ -43,19 +43,19 @@ def scale_aligned(self, img, scale):
return img, factor_h, factor_w
def __call__(self, data):
- img = data['image']
+ img = data["image"]
h, w = img.shape[0:2]
random_scale = np.array([0.7, 0.8, 0.9, 1.0, 1.1, 1.2, 1.3])
scale = (np.random.choice(random_scale) * self.short_size) / min(h, w)
img, factor_h, factor_w = self.scale_aligned(img, scale)
- data['scale_factor'] = (factor_w, factor_h)
- data['image'] = img
+ data["scale_factor"] = (factor_w, factor_h)
+ data["image"] = img
return data
-class MakeShrink():
+class MakeShrink:
def __init__(self, kernel_scale=0.7, **kwargs):
self.kernel_scale = kernel_scale
@@ -69,8 +69,9 @@ def perimeter(self, bbox):
return peri
def shrink(self, bboxes, rate, max_shr=20):
- check_install('Polygon', 'Polygon3')
+ check_install("Polygon", "Polygon3")
import Polygon as plg
+
rate = rate * rate
shrinked_bboxes = []
for bbox in bboxes:
@@ -79,10 +80,8 @@ def shrink(self, bboxes, rate, max_shr=20):
try:
pco = pyclipper.PyclipperOffset()
- pco.AddPath(bbox, pyclipper.JT_ROUND,
- pyclipper.ET_CLOSEDPOLYGON)
- offset = min(
- int(area * (1 - rate) / (peri + 0.001) + 0.5), max_shr)
+ pco.AddPath(bbox, pyclipper.JT_ROUND, pyclipper.ET_CLOSEDPOLYGON)
+ offset = min(int(area * (1 - rate) / (peri + 0.001) + 0.5), max_shr)
shrinked_bbox = pco.Execute(-offset)
if len(shrinked_bbox) == 0:
@@ -101,40 +100,41 @@ def shrink(self, bboxes, rate, max_shr=20):
return shrinked_bboxes
def __call__(self, data):
- img = data['image']
- bboxes = data['polys']
- words = data['texts']
- scale_factor = data['scale_factor']
+ img = data["image"]
+ bboxes = data["polys"]
+ words = data["texts"]
+ scale_factor = data["scale_factor"]
- gt_instance = np.zeros(img.shape[0:2], dtype='uint8') # h,w
- training_mask = np.ones(img.shape[0:2], dtype='uint8')
- training_mask_distance = np.ones(img.shape[0:2], dtype='uint8')
+ gt_instance = np.zeros(img.shape[0:2], dtype="uint8") # h,w
+ training_mask = np.ones(img.shape[0:2], dtype="uint8")
+ training_mask_distance = np.ones(img.shape[0:2], dtype="uint8")
for i in range(len(bboxes)):
- bboxes[i] = np.reshape(bboxes[i] * (
- [scale_factor[0], scale_factor[1]] * (bboxes[i].shape[0] // 2)),
- (bboxes[i].shape[0] // 2, 2)).astype('int32')
+ bboxes[i] = np.reshape(
+ bboxes[i]
+ * ([scale_factor[0], scale_factor[1]] * (bboxes[i].shape[0] // 2)),
+ (bboxes[i].shape[0] // 2, 2),
+ ).astype("int32")
for i in range(len(bboxes)):
- #different value for different bbox
+ # different value for different bbox
cv2.drawContours(gt_instance, [bboxes[i]], -1, i + 1, -1)
# set training mask to 0
cv2.drawContours(training_mask, [bboxes[i]], -1, 0, -1)
# for inaccurate annotations, use training_mask_distance
- if words[i] == '###' or words[i] == '???':
+ if words[i] == "###" or words[i] == "???":
cv2.drawContours(training_mask_distance, [bboxes[i]], -1, 0, -1)
# make shrink
- gt_kernel_instance = np.zeros(img.shape[0:2], dtype='uint8')
+ gt_kernel_instance = np.zeros(img.shape[0:2], dtype="uint8")
kernel_bboxes = self.shrink(bboxes, self.kernel_scale)
for i in range(len(bboxes)):
- cv2.drawContours(gt_kernel_instance, [kernel_bboxes[i]], -1, i + 1,
- -1)
+ cv2.drawContours(gt_kernel_instance, [kernel_bboxes[i]], -1, i + 1, -1)
# for training mask: kernel and background = 1, box region = 0
- if words[i] != '###' and words[i] != '???':
+ if words[i] != "###" and words[i] != "???":
cv2.drawContours(training_mask, [kernel_bboxes[i]], -1, 1, -1)
gt_kernel = gt_kernel_instance.copy()
@@ -158,33 +158,38 @@ def __call__(self, data):
# gt_kernel_inner: text kernel reference
# training_mask_distance: word without anno = 0, else 1
- data['image'] = [
- img, gt_instance, training_mask, gt_kernel_instance, gt_kernel,
- gt_kernel_inner, training_mask_distance
+ data["image"] = [
+ img,
+ gt_instance,
+ training_mask,
+ gt_kernel_instance,
+ gt_kernel,
+ gt_kernel_inner,
+ training_mask_distance,
]
return data
-class GroupRandomHorizontalFlip():
+class GroupRandomHorizontalFlip:
def __init__(self, p=0.5, **kwargs):
self.p = p
def __call__(self, data):
- imgs = data['image']
+ imgs = data["image"]
if random.random() < self.p:
for i in range(len(imgs)):
imgs[i] = np.flip(imgs[i], axis=1).copy()
- data['image'] = imgs
+ data["image"] = imgs
return data
-class GroupRandomRotate():
+class GroupRandomRotate:
def __init__(self, **kwargs):
pass
def __call__(self, data):
- imgs = data['image']
+ imgs = data["image"]
max_angle = 10
angle = random.random() * 2 * max_angle - max_angle
@@ -193,19 +198,20 @@ def __call__(self, data):
w, h = img.shape[:2]
rotation_matrix = cv2.getRotationMatrix2D((h / 2, w / 2), angle, 1)
img_rotation = cv2.warpAffine(
- img, rotation_matrix, (h, w), flags=cv2.INTER_NEAREST)
+ img, rotation_matrix, (h, w), flags=cv2.INTER_NEAREST
+ )
imgs[i] = img_rotation
- data['image'] = imgs
+ data["image"] = imgs
return data
-class GroupRandomCropPadding():
+class GroupRandomCropPadding:
def __init__(self, target_size=(640, 640), **kwargs):
self.target_size = target_size
def __call__(self, data):
- imgs = data['image']
+ imgs = data["image"]
h, w = imgs[0].shape[0:2]
t_w, t_h = self.target_size
@@ -235,7 +241,7 @@ def __call__(self, data):
for idx in range(len(imgs)):
if len(imgs[idx].shape) == 3:
s3_length = int(imgs[idx].shape[-1])
- img = imgs[idx][i:i + t_h, j:j + t_w, :]
+ img = imgs[idx][i : i + t_h, j : j + t_w, :]
img_p = cv2.copyMakeBorder(
img,
0,
@@ -243,9 +249,10 @@ def __call__(self, data):
0,
p_w - t_w,
borderType=cv2.BORDER_CONSTANT,
- value=tuple(0 for i in range(s3_length)))
+ value=tuple(0 for i in range(s3_length)),
+ )
else:
- img = imgs[idx][i:i + t_h, j:j + t_w]
+ img = imgs[idx][i : i + t_h, j : j + t_w]
img_p = cv2.copyMakeBorder(
img,
0,
@@ -253,14 +260,15 @@ def __call__(self, data):
0,
p_w - t_w,
borderType=cv2.BORDER_CONSTANT,
- value=(0, ))
+ value=(0,),
+ )
n_imgs.append(img_p)
- data['image'] = n_imgs
+ data["image"] = n_imgs
return data
-class MakeCentripetalShift():
+class MakeCentripetalShift:
def __init__(self, **kwargs):
pass
@@ -269,20 +277,32 @@ def jaccard(self, As, Bs):
B = Bs.shape[0] # large
dis = np.sqrt(
- np.sum((As[:, np.newaxis, :].repeat(
- B, axis=1) - Bs[np.newaxis, :, :].repeat(
- A, axis=0))**2,
- axis=-1))
+ np.sum(
+ (
+ As[:, np.newaxis, :].repeat(B, axis=1)
+ - Bs[np.newaxis, :, :].repeat(A, axis=0)
+ )
+ ** 2,
+ axis=-1,
+ )
+ )
ind = np.argmin(dis, axis=-1)
return ind
def __call__(self, data):
- imgs = data['image']
-
- img, gt_instance, training_mask, gt_kernel_instance, gt_kernel, gt_kernel_inner, training_mask_distance = \
- imgs[0], imgs[1], imgs[2], imgs[3], imgs[4], imgs[5], imgs[6]
+ imgs = data["image"]
+
+ (
+ img,
+ gt_instance,
+ training_mask,
+ gt_kernel_instance,
+ gt_kernel,
+ gt_kernel_inner,
+ training_mask_distance,
+ ) = (imgs[0], imgs[1], imgs[2], imgs[3], imgs[4], imgs[5], imgs[6])
max_instance = np.max(gt_instance) # num bbox
@@ -290,23 +310,23 @@ def __call__(self, data):
gt_distance = np.zeros((2, *img.shape[0:2]), dtype=np.float32)
for i in range(1, max_instance + 1):
# kernel_reference
- ind = (gt_kernel_inner == i)
+ ind = gt_kernel_inner == i
if np.sum(ind) == 0:
training_mask[gt_instance == i] = 0
training_mask_distance[gt_instance == i] = 0
continue
- kpoints = np.array(np.where(ind)).transpose(
- (1, 0))[:, ::-1].astype('float32')
+ kpoints = (
+ np.array(np.where(ind)).transpose((1, 0))[:, ::-1].astype("float32")
+ )
ind = (gt_instance == i) * (gt_kernel_instance == 0)
if np.sum(ind) == 0:
continue
pixels = np.where(ind)
- points = np.array(pixels).transpose(
- (1, 0))[:, ::-1].astype('float32')
+ points = np.array(pixels).transpose((1, 0))[:, ::-1].astype("float32")
bbox_ind = self.jaccard(points, kpoints)
@@ -315,7 +335,7 @@ def __call__(self, data):
gt_distance[:, pixels[0], pixels[1]] = offset_gt.T * 0.1
img = Image.fromarray(img)
- img = img.convert('RGB')
+ img = img.convert("RGB")
data["image"] = img
data["gt_kernel"] = gt_kernel.astype("int64")
@@ -328,12 +348,12 @@ def __call__(self, data):
return data
-class ScaleAlignedShort():
+class ScaleAlignedShort:
def __init__(self, short_size=640, **kwargs):
self.short_size = short_size
def __call__(self, data):
- img = data['image']
+ img = data["image"]
org_img_shape = img.shape
@@ -350,7 +370,7 @@ def __call__(self, data):
new_img_shape = img.shape
img_shape = np.array(org_img_shape + new_img_shape)
- data['shape'] = img_shape
- data['image'] = img
+ data["shape"] = img_shape
+ data["image"] = img
- return data
\ No newline at end of file
+ return data
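
For readers tracing MakeShrink.shrink in the file above: the loop body reduces to one pyclipper inward offset whose magnitude comes from the polygon's area-to-perimeter ratio. A hedged, self-contained sketch of that step, assuming the pyclipper and Polygon3 packages that check_install verifies; the box and helper name are synthetic:

import numpy as np
import pyclipper
import Polygon as plg  # provided by the Polygon3 package

def shrink_one(bbox, rate=0.7, max_shr=20):
    rate = rate * rate
    area = plg.Polygon(bbox).area()
    peri = sum(
        np.linalg.norm(bbox[i] - bbox[(i + 1) % len(bbox)]) for i in range(len(bbox))
    )
    pco = pyclipper.PyclipperOffset()
    pco.AddPath(bbox.tolist(), pyclipper.JT_ROUND, pyclipper.ET_CLOSEDPOLYGON)
    # A negative delta offsets every edge inward; cap the shrink at max_shr px.
    offset = min(int(area * (1 - rate) / (peri + 0.001) + 0.5), max_shr)
    shrinked = pco.Execute(-offset)
    return np.array(shrinked[0]) if len(shrinked) > 0 else bbox

box = np.array([[0, 0], [100, 0], [100, 40], [0, 40]])
print(shrink_one(box))  # a smaller box, inset by the computed offset
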
diff --git a/ppocr/data/imaug/drrg_targets.py b/ppocr/data/imaug/drrg_targets.py
index 7fdfd09681..5c2ba7caa9 100644
--- a/ppocr/data/imaug/drrg_targets.py
+++ b/ppocr/data/imaug/drrg_targets.py
@@ -23,22 +23,23 @@
class DRRGTargets(object):
- def __init__(self,
- orientation_thr=2.0,
- resample_step=8.0,
- num_min_comps=9,
- num_max_comps=600,
- min_width=8.0,
- max_width=24.0,
- center_region_shrink_ratio=0.3,
- comp_shrink_ratio=1.0,
- comp_w_h_ratio=0.3,
- text_comp_nms_thr=0.25,
- min_rand_half_height=8.0,
- max_rand_half_height=24.0,
- jitter_level=0.2,
- **kwargs):
-
+ def __init__(
+ self,
+ orientation_thr=2.0,
+ resample_step=8.0,
+ num_min_comps=9,
+ num_max_comps=600,
+ min_width=8.0,
+ max_width=24.0,
+ center_region_shrink_ratio=0.3,
+ comp_shrink_ratio=1.0,
+ comp_w_h_ratio=0.3,
+ text_comp_nms_thr=0.25,
+ min_rand_half_height=8.0,
+ max_rand_half_height=24.0,
+ jitter_level=0.2,
+ **kwargs
+ ):
super().__init__()
self.orientation_thr = orientation_thr
self.resample_step = resample_step
@@ -64,9 +65,7 @@ def vector_angle(self, vec1, vec2):
unit_vec2 = vec2 / (norm(vec2, axis=-1) + self.eps).reshape((-1, 1))
else:
unit_vec2 = vec2 / (norm(vec2, axis=-1) + self.eps)
- return np.arccos(
- np.clip(
- np.sum(unit_vec1 * unit_vec2, axis=-1), -1.0, 1.0))
+ return np.arccos(np.clip(np.sum(unit_vec1 * unit_vec2, axis=-1), -1.0, 1.0))
def vector_slope(self, vec):
assert len(vec) == 2
@@ -81,7 +80,6 @@ def vector_cos(self, vec):
return vec[0] / (norm(vec) + self.eps)
def find_head_tail(self, points, orientation_thr):
-
assert points.ndim == 2
assert points.shape[0] >= 4
assert points.shape[1] == 2
@@ -96,20 +94,19 @@ def find_head_tail(self, points, orientation_thr):
for i, edge_vec1 in enumerate(edge_vec):
adjacent_ind = [x % len(edge_vec) for x in [i - 1, i + 1]]
adjacent_edge_vec = edge_vec[adjacent_ind]
- temp_theta_sum = np.sum(
- self.vector_angle(edge_vec1, adjacent_edge_vec))
- temp_adjacent_theta = self.vector_angle(adjacent_edge_vec[0],
- adjacent_edge_vec[1])
+ temp_theta_sum = np.sum(self.vector_angle(edge_vec1, adjacent_edge_vec))
+ temp_adjacent_theta = self.vector_angle(
+ adjacent_edge_vec[0], adjacent_edge_vec[1]
+ )
theta_sum.append(temp_theta_sum)
adjacent_vec_theta.append(temp_adjacent_theta)
theta_sum_score = np.array(theta_sum) / np.pi
adjacent_theta_score = np.array(adjacent_vec_theta) / np.pi
poly_center = np.mean(points, axis=0)
edge_dist = np.maximum(
- norm(
- pad_points[1:] - poly_center, axis=-1),
- norm(
- pad_points[:-1] - poly_center, axis=-1))
+ norm(pad_points[1:] - poly_center, axis=-1),
+ norm(pad_points[:-1] - poly_center, axis=-1),
+ )
dist_score = edge_dist / (np.max(edge_dist) + self.eps)
position_score = np.zeros(len(edge_vec))
score = 0.5 * theta_sum_score + 0.15 * adjacent_theta_score
@@ -121,15 +118,21 @@ def find_head_tail(self, points, orientation_thr):
pad_score = np.concatenate([score, score])
score_matrix = np.zeros((len(score), len(score) - 3))
x = np.arange(len(score) - 3) / float(len(score) - 4)
- gaussian = 1. / (np.sqrt(2. * np.pi) * 0.5) * np.exp(-np.power(
- (x - 0.5) / 0.5, 2.) / 2)
+ gaussian = (
+ 1.0
+ / (np.sqrt(2.0 * np.pi) * 0.5)
+ * np.exp(-np.power((x - 0.5) / 0.5, 2.0) / 2)
+ )
gaussian = gaussian / np.max(gaussian)
for i in range(len(score)):
- score_matrix[i, :] = score[i] + pad_score[(i + 2):(i + len(
- score) - 1)] * gaussian * 0.3
-
- head_start, tail_increment = np.unravel_index(score_matrix.argmax(),
- score_matrix.shape)
+ score_matrix[i, :] = (
+ score[i]
+ + pad_score[(i + 2) : (i + len(score) - 1)] * gaussian * 0.3
+ )
+
+ head_start, tail_increment = np.unravel_index(
+ score_matrix.argmax(), score_matrix.shape
+ )
tail_start = (head_start + tail_increment + 2) % len(points)
head_end = (head_start + 1) % len(points)
tail_end = (tail_start + 1) % len(points)
@@ -141,22 +144,26 @@ def find_head_tail(self, points, orientation_thr):
tail_inds = [tail_start, tail_end]
else:
if self.vector_slope(points[1] - points[0]) + self.vector_slope(
- points[3] - points[2]) < self.vector_slope(points[
- 2] - points[1]) + self.vector_slope(points[0] - points[
- 3]):
+ points[3] - points[2]
+ ) < self.vector_slope(points[2] - points[1]) + self.vector_slope(
+ points[0] - points[3]
+ ):
horizontal_edge_inds = [[0, 1], [2, 3]]
vertical_edge_inds = [[3, 0], [1, 2]]
else:
horizontal_edge_inds = [[3, 0], [1, 2]]
vertical_edge_inds = [[0, 1], [2, 3]]
- vertical_len_sum = norm(points[vertical_edge_inds[0][0]] - points[
- vertical_edge_inds[0][1]]) + norm(points[vertical_edge_inds[1][
- 0]] - points[vertical_edge_inds[1][1]])
- horizontal_len_sum = norm(points[horizontal_edge_inds[0][
- 0]] - points[horizontal_edge_inds[0][1]]) + norm(points[
- horizontal_edge_inds[1][0]] - points[horizontal_edge_inds[1]
- [1]])
+ vertical_len_sum = norm(
+ points[vertical_edge_inds[0][0]] - points[vertical_edge_inds[0][1]]
+ ) + norm(
+ points[vertical_edge_inds[1][0]] - points[vertical_edge_inds[1][1]]
+ )
+ horizontal_len_sum = norm(
+ points[horizontal_edge_inds[0][0]] - points[horizontal_edge_inds[0][1]]
+ ) + norm(
+ points[horizontal_edge_inds[1][0]] - points[horizontal_edge_inds[1][1]]
+ )
if vertical_len_sum > horizontal_len_sum * orientation_thr:
head_inds = horizontal_edge_inds[0]
@@ -168,7 +175,6 @@ def find_head_tail(self, points, orientation_thr):
return head_inds, tail_inds
def reorder_poly_edge(self, points):
-
assert points.ndim == 2
assert points.shape[0] >= 4
assert points.shape[1] == 2
@@ -179,11 +185,9 @@ def reorder_poly_edge(self, points):
pad_points = np.vstack([points, points])
if tail_inds[1] < 1:
tail_inds[1] = len(points)
- sideline1 = pad_points[head_inds[1]:tail_inds[1]]
- sideline2 = pad_points[tail_inds[1]:(head_inds[1] + len(points))]
- sideline_mean_shift = np.mean(
- sideline1, axis=0) - np.mean(
- sideline2, axis=0)
+ sideline1 = pad_points[head_inds[1] : tail_inds[1]]
+ sideline2 = pad_points[tail_inds[1] : (head_inds[1] + len(points))]
+ sideline_mean_shift = np.mean(sideline1, axis=0) - np.mean(sideline2, axis=0)
if sideline_mean_shift[1] > 0:
top_sideline, bot_sideline = sideline2, sideline1
@@ -193,17 +197,16 @@ def reorder_poly_edge(self, points):
return head_edge, tail_edge, top_sideline, bot_sideline
def cal_curve_length(self, line):
-
assert line.ndim == 2
assert len(line) >= 2
- edges_length = np.sqrt((line[1:, 0] - line[:-1, 0])**2 + (line[
- 1:, 1] - line[:-1, 1])**2)
+ edges_length = np.sqrt(
+ (line[1:, 0] - line[:-1, 0]) ** 2 + (line[1:, 1] - line[:-1, 1]) ** 2
+ )
total_length = np.sum(edges_length)
return edges_length, total_length
def resample_line(self, line, n):
-
assert line.ndim == 2
assert line.shape[0] >= 2
assert line.shape[1] == 2
@@ -220,8 +223,9 @@ def resample_line(self, line, n):
while edge_ind < len(edges_length) - 1 and t > t_org[edge_ind + 1]:
edge_ind += 1
t_l, t_r = t_org[edge_ind], t_org[edge_ind + 1]
- weight = np.array(
- [t_r - t, t - t_l], dtype=np.float32) / (t_r - t_l + self.eps)
+ weight = np.array([t_r - t, t - t_l], dtype=np.float32) / (
+ t_r - t_l + self.eps
+ )
p_coords = np.dot(weight, line[[edge_ind, edge_ind + 1]])
points.append(p_coords)
points.append(line[-1])
@@ -230,7 +234,6 @@ def resample_line(self, line, n):
return resampled_line
def resample_sidelines(self, sideline1, sideline2, resample_step):
-
assert sideline1.ndim == sideline2.ndim == 2
assert sideline1.shape[1] == sideline2.shape[1] == 2
assert sideline1.shape[0] >= 2
@@ -249,54 +252,65 @@ def resample_sidelines(self, sideline1, sideline2, resample_step):
return resampled_line1, resampled_line2
def dist_point2line(self, point, line):
-
assert isinstance(line, tuple)
point1, point2 = line
d = abs(np.cross(point2 - point1, point - point1)) / (
- norm(point2 - point1) + 1e-8)
+ norm(point2 - point1) + 1e-8
+ )
return d
- def draw_center_region_maps(self, top_line, bot_line, center_line,
- center_region_mask, top_height_map,
- bot_height_map, sin_map, cos_map,
- region_shrink_ratio):
-
+ def draw_center_region_maps(
+ self,
+ top_line,
+ bot_line,
+ center_line,
+ center_region_mask,
+ top_height_map,
+ bot_height_map,
+ sin_map,
+ cos_map,
+ region_shrink_ratio,
+ ):
assert top_line.shape == bot_line.shape == center_line.shape
- assert (center_region_mask.shape == top_height_map.shape ==
- bot_height_map.shape == sin_map.shape == cos_map.shape)
+ assert (
+ center_region_mask.shape
+ == top_height_map.shape
+ == bot_height_map.shape
+ == sin_map.shape
+ == cos_map.shape
+ )
assert isinstance(region_shrink_ratio, float)
h, w = center_region_mask.shape
for i in range(0, len(center_line) - 1):
-
top_mid_point = (top_line[i] + top_line[i + 1]) / 2
bot_mid_point = (bot_line[i] + bot_line[i + 1]) / 2
sin_theta = self.vector_sin(top_mid_point - bot_mid_point)
cos_theta = self.vector_cos(top_mid_point - bot_mid_point)
- tl = center_line[i] + (top_line[i] - center_line[i]
- ) * region_shrink_ratio
- tr = center_line[i + 1] + (top_line[i + 1] - center_line[i + 1]
- ) * region_shrink_ratio
- br = center_line[i + 1] + (bot_line[i + 1] - center_line[i + 1]
- ) * region_shrink_ratio
- bl = center_line[i] + (bot_line[i] - center_line[i]
- ) * region_shrink_ratio
+ tl = center_line[i] + (top_line[i] - center_line[i]) * region_shrink_ratio
+ tr = (
+ center_line[i + 1]
+ + (top_line[i + 1] - center_line[i + 1]) * region_shrink_ratio
+ )
+ br = (
+ center_line[i + 1]
+ + (bot_line[i + 1] - center_line[i + 1]) * region_shrink_ratio
+ )
+ bl = center_line[i] + (bot_line[i] - center_line[i]) * region_shrink_ratio
current_center_box = np.vstack([tl, tr, br, bl]).astype(np.int32)
cv2.fillPoly(center_region_mask, [current_center_box], color=1)
cv2.fillPoly(sin_map, [current_center_box], color=sin_theta)
cv2.fillPoly(cos_map, [current_center_box], color=cos_theta)
- current_center_box[:, 0] = np.clip(current_center_box[:, 0], 0,
- w - 1)
- current_center_box[:, 1] = np.clip(current_center_box[:, 1], 0,
- h - 1)
+ current_center_box[:, 0] = np.clip(current_center_box[:, 0], 0, w - 1)
+ current_center_box[:, 1] = np.clip(current_center_box[:, 1], 0, h - 1)
min_coord = np.min(current_center_box, axis=0).astype(np.int32)
max_coord = np.max(current_center_box, axis=0).astype(np.int32)
current_center_box = current_center_box - min_coord
- box_sz = (max_coord - min_coord + 1)
+ box_sz = max_coord - min_coord + 1
center_box_mask = np.zeros((box_sz[1], box_sz[0]), dtype=np.uint8)
cv2.fillPoly(center_box_mask, [current_center_box], color=1)
@@ -305,12 +319,13 @@ def draw_center_region_maps(self, top_line, bot_line, center_line,
inds = inds + (min_coord[1], min_coord[0])
inds_xy = np.fliplr(inds)
top_height_map[(inds[:, 0], inds[:, 1])] = self.dist_point2line(
- inds_xy, (top_line[i], top_line[i + 1]))
+ inds_xy, (top_line[i], top_line[i + 1])
+ )
bot_height_map[(inds[:, 0], inds[:, 1])] = self.dist_point2line(
- inds_xy, (bot_line[i], bot_line[i + 1]))
+ inds_xy, (bot_line[i], bot_line[i + 1])
+ )
def generate_center_mask_attrib_maps(self, img_size, text_polys):
-
assert isinstance(img_size, tuple)
h, w = img_size
@@ -326,7 +341,8 @@ def generate_center_mask_attrib_maps(self, img_size, text_polys):
polygon_points = poly
_, _, top_line, bot_line = self.reorder_poly_edge(polygon_points)
resampled_top_line, resampled_bot_line = self.resample_sidelines(
- top_line, bot_line, self.resample_step)
+ top_line, bot_line, self.resample_step
+ )
resampled_bot_line = resampled_bot_line[::-1]
center_line = (resampled_top_line + resampled_bot_line) / 2
@@ -341,33 +357,58 @@ def generate_center_mask_attrib_maps(self, img_size, text_polys):
resampled_top_line = resampled_top_line[::-1]
resampled_bot_line = resampled_bot_line[::-1]
- line_head_shrink_len = np.clip(
- (norm(top_line[0] - bot_line[0]) * self.comp_w_h_ratio),
- self.min_width, self.max_width) / 2
- line_tail_shrink_len = np.clip(
- (norm(top_line[-1] - bot_line[-1]) * self.comp_w_h_ratio),
- self.min_width, self.max_width) / 2
+ line_head_shrink_len = (
+ np.clip(
+ (norm(top_line[0] - bot_line[0]) * self.comp_w_h_ratio),
+ self.min_width,
+ self.max_width,
+ )
+ / 2
+ )
+ line_tail_shrink_len = (
+ np.clip(
+ (norm(top_line[-1] - bot_line[-1]) * self.comp_w_h_ratio),
+ self.min_width,
+ self.max_width,
+ )
+ / 2
+ )
num_head_shrink = int(line_head_shrink_len // self.resample_step)
num_tail_shrink = int(line_tail_shrink_len // self.resample_step)
if len(center_line) > num_head_shrink + num_tail_shrink + 2:
- center_line = center_line[num_head_shrink:len(center_line) -
- num_tail_shrink]
- resampled_top_line = resampled_top_line[num_head_shrink:len(
- resampled_top_line) - num_tail_shrink]
- resampled_bot_line = resampled_bot_line[num_head_shrink:len(
- resampled_bot_line) - num_tail_shrink]
+ center_line = center_line[
+ num_head_shrink : len(center_line) - num_tail_shrink
+ ]
+ resampled_top_line = resampled_top_line[
+ num_head_shrink : len(resampled_top_line) - num_tail_shrink
+ ]
+ resampled_bot_line = resampled_bot_line[
+ num_head_shrink : len(resampled_bot_line) - num_tail_shrink
+ ]
center_lines.append(center_line.astype(np.int32))
self.draw_center_region_maps(
- resampled_top_line, resampled_bot_line, center_line,
- center_region_mask, top_height_map, bot_height_map, sin_map,
- cos_map, self.center_region_shrink_ratio)
-
- return (center_lines, center_region_mask, top_height_map,
- bot_height_map, sin_map, cos_map)
+ resampled_top_line,
+ resampled_bot_line,
+ center_line,
+ center_region_mask,
+ top_height_map,
+ bot_height_map,
+ sin_map,
+ cos_map,
+ self.center_region_shrink_ratio,
+ )
+
+ return (
+ center_lines,
+ center_region_mask,
+ top_height_map,
+ bot_height_map,
+ sin_map,
+ cos_map,
+ )
def generate_rand_comp_attribs(self, num_rand_comps, center_sample_mask):
-
assert isinstance(num_rand_comps, int)
assert num_rand_comps > 0
assert center_sample_mask.ndim == 2
@@ -377,31 +418,34 @@ def generate_rand_comp_attribs(self, num_rand_comps, center_sample_mask):
max_rand_half_height = self.max_rand_half_height
min_rand_half_height = self.min_rand_half_height
max_rand_height = max_rand_half_height * 2
- max_rand_width = np.clip(max_rand_height * self.comp_w_h_ratio,
- self.min_width, self.max_width)
- margin = int(
- np.sqrt((max_rand_height / 2)**2 + (max_rand_width / 2)**2)) + 1
+ max_rand_width = np.clip(
+ max_rand_height * self.comp_w_h_ratio, self.min_width, self.max_width
+ )
+ margin = (
+ int(np.sqrt((max_rand_height / 2) ** 2 + (max_rand_width / 2) ** 2)) + 1
+ )
if 2 * margin + 1 > min(h, w):
-
assert min(h, w) > (np.sqrt(2) * (self.min_width + 1))
max_rand_half_height = max(min(h, w) / 4, self.min_width / 2 + 1)
- min_rand_half_height = max(max_rand_half_height / 4,
- self.min_width / 2)
+ min_rand_half_height = max(max_rand_half_height / 4, self.min_width / 2)
max_rand_height = max_rand_half_height * 2
- max_rand_width = np.clip(max_rand_height * self.comp_w_h_ratio,
- self.min_width, self.max_width)
- margin = int(
- np.sqrt((max_rand_height / 2)**2 + (max_rand_width / 2)**2)) + 1
+ max_rand_width = np.clip(
+ max_rand_height * self.comp_w_h_ratio, self.min_width, self.max_width
+ )
+ margin = (
+ int(np.sqrt((max_rand_height / 2) ** 2 + (max_rand_width / 2) ** 2)) + 1
+ )
inner_center_sample_mask = np.zeros_like(center_sample_mask)
- inner_center_sample_mask[margin:h - margin, margin:w - margin] = \
- center_sample_mask[margin:h - margin, margin:w - margin]
+ inner_center_sample_mask[
+ margin : h - margin, margin : w - margin
+ ] = center_sample_mask[margin : h - margin, margin : w - margin]
kernel_size = int(np.clip(max_rand_half_height, 7, 21))
inner_center_sample_mask = cv2.erode(
- inner_center_sample_mask,
- np.ones((kernel_size, kernel_size), np.uint8))
+ inner_center_sample_mask, np.ones((kernel_size, kernel_size), np.uint8)
+ )
center_candidates = np.argwhere(inner_center_sample_mask > 0)
num_center_candidates = len(center_candidates)
@@ -409,13 +453,11 @@ def generate_rand_comp_attribs(self, num_rand_comps, center_sample_mask):
rand_centers = center_candidates[sample_inds]
rand_top_height = np.random.randint(
- min_rand_half_height,
- max_rand_half_height,
- size=(len(rand_centers), 1))
+ min_rand_half_height, max_rand_half_height, size=(len(rand_centers), 1)
+ )
rand_bot_height = np.random.randint(
- min_rand_half_height,
- max_rand_half_height,
- size=(len(rand_centers), 1))
+ min_rand_half_height, max_rand_half_height, size=(len(rand_centers), 1)
+ )
rand_cos = 2 * np.random.random(size=(len(rand_centers), 1)) - 1
rand_sin = 2 * np.random.random(size=(len(rand_centers), 1)) - 1
@@ -423,14 +465,19 @@ def generate_rand_comp_attribs(self, num_rand_comps, center_sample_mask):
rand_cos = rand_cos * scale
rand_sin = rand_sin * scale
- height = (rand_top_height + rand_bot_height)
- width = np.clip(height * self.comp_w_h_ratio, self.min_width,
- self.max_width)
+ height = rand_top_height + rand_bot_height
+ width = np.clip(height * self.comp_w_h_ratio, self.min_width, self.max_width)
- rand_comp_attribs = np.hstack([
- rand_centers[:, ::-1], height, width, rand_cos, rand_sin,
- np.zeros_like(rand_sin)
- ]).astype(np.float32)
+ rand_comp_attribs = np.hstack(
+ [
+ rand_centers[:, ::-1],
+ height,
+ width,
+ rand_cos,
+ rand_sin,
+ np.zeros_like(rand_sin),
+ ]
+ ).astype(np.float32)
return rand_comp_attribs
@@ -459,20 +506,22 @@ def jitter_comp_attribs(self, comp_attribs, jitter_level):
sin = comp_attribs[:, 5].reshape((-1, 1))
comp_labels = comp_attribs[:, 6].reshape((-1, 1))
- x += (np.random.random(size=(len(comp_attribs), 1)) - 0.5) * (
- h * np.abs(cos) + w * np.abs(sin)) * jitter_level
- y += (np.random.random(size=(len(comp_attribs), 1)) - 0.5) * (
- h * np.abs(sin) + w * np.abs(cos)) * jitter_level
+ x += (
+ (np.random.random(size=(len(comp_attribs), 1)) - 0.5)
+ * (h * np.abs(cos) + w * np.abs(sin))
+ * jitter_level
+ )
+ y += (
+ (np.random.random(size=(len(comp_attribs), 1)) - 0.5)
+ * (h * np.abs(sin) + w * np.abs(cos))
+ * jitter_level
+ )
- h += (np.random.random(size=(len(comp_attribs), 1)) - 0.5
- ) * h * jitter_level
- w += (np.random.random(size=(len(comp_attribs), 1)) - 0.5
- ) * w * jitter_level
+ h += (np.random.random(size=(len(comp_attribs), 1)) - 0.5) * h * jitter_level
+ w += (np.random.random(size=(len(comp_attribs), 1)) - 0.5) * w * jitter_level
- cos += (np.random.random(size=(len(comp_attribs), 1)) - 0.5
- ) * 2 * jitter_level
- sin += (np.random.random(size=(len(comp_attribs), 1)) - 0.5
- ) * 2 * jitter_level
+ cos += (np.random.random(size=(len(comp_attribs), 1)) - 0.5) * 2 * jitter_level
+ sin += (np.random.random(size=(len(comp_attribs), 1)) - 0.5) * 2 * jitter_level
scale = np.sqrt(1.0 / (cos**2 + sin**2 + 1e-8))
cos = cos * scale
@@ -482,8 +531,16 @@ def jitter_comp_attribs(self, comp_attribs, jitter_level):
return jittered_comp_attribs
- def generate_comp_attribs(self, center_lines, text_mask, center_region_mask,
- top_height_map, bot_height_map, sin_map, cos_map):
+ def generate_comp_attribs(
+ self,
+ center_lines,
+ text_mask,
+ center_region_mask,
+ top_height_map,
+ bot_height_map,
+ sin_map,
+ cos_map,
+ ):
"""Generate text component attributes.
Args:
@@ -508,8 +565,13 @@ def generate_comp_attribs(self, center_lines, text_mask, center_region_mask,
assert isinstance(center_lines, list)
assert (
- text_mask.shape == center_region_mask.shape == top_height_map.shape
- == bot_height_map.shape == sin_map.shape == cos_map.shape)
+ text_mask.shape
+ == center_region_mask.shape
+ == top_height_map.shape
+ == bot_height_map.shape
+ == sin_map.shape
+ == cos_map.shape
+ )
center_lines_mask = np.zeros_like(center_region_mask)
cv2.polylines(center_lines_mask, center_lines, 0, 1, 1)
@@ -519,17 +581,13 @@ def generate_comp_attribs(self, center_lines, text_mask, center_region_mask,
y = comp_centers[:, 0]
x = comp_centers[:, 1]
- top_height = top_height_map[y, x].reshape(
- (-1, 1)) * self.comp_shrink_ratio
- bot_height = bot_height_map[y, x].reshape(
- (-1, 1)) * self.comp_shrink_ratio
+ top_height = top_height_map[y, x].reshape((-1, 1)) * self.comp_shrink_ratio
+ bot_height = bot_height_map[y, x].reshape((-1, 1)) * self.comp_shrink_ratio
sin = sin_map[y, x].reshape((-1, 1))
cos = cos_map[y, x].reshape((-1, 1))
- top_mid_points = comp_centers + np.hstack(
- [top_height * sin, top_height * cos])
- bot_mid_points = comp_centers - np.hstack(
- [bot_height * sin, bot_height * cos])
+ top_mid_points = comp_centers + np.hstack([top_height * sin, top_height * cos])
+ bot_mid_points = comp_centers - np.hstack([bot_height * sin, bot_height * cos])
width = (top_height + bot_height) * self.comp_w_h_ratio
width = np.clip(width, self.min_width, self.max_width)
@@ -543,8 +601,9 @@ def generate_comp_attribs(self, center_lines, text_mask, center_region_mask,
score = np.ones((text_comps.shape[0], 1), dtype=np.float32)
text_comps = np.hstack([text_comps, score])
- check_install('lanms', 'lanms-neo')
+ check_install("lanms", "lanms-neo")
from lanms import merge_quadrangle_n9 as la_nms
+
text_comps = la_nms(text_comps, self.text_comp_nms_thr)
if text_comps.shape[0] >= 1:
@@ -553,51 +612,54 @@ def generate_comp_attribs(self, center_lines, text_mask, center_region_mask,
text_comps[:, 1:8:2] = np.clip(text_comps[:, 1:8:2], 0, img_h - 1)
comp_centers = np.mean(
- text_comps[:, 0:8].reshape((-1, 4, 2)), axis=1).astype(np.int32)
+ text_comps[:, 0:8].reshape((-1, 4, 2)), axis=1
+ ).astype(np.int32)
x = comp_centers[:, 0]
y = comp_centers[:, 1]
- height = (top_height_map[y, x] + bot_height_map[y, x]).reshape(
- (-1, 1))
- width = np.clip(height * self.comp_w_h_ratio, self.min_width,
- self.max_width)
+ height = (top_height_map[y, x] + bot_height_map[y, x]).reshape((-1, 1))
+ width = np.clip(
+ height * self.comp_w_h_ratio, self.min_width, self.max_width
+ )
cos = cos_map[y, x].reshape((-1, 1))
sin = sin_map[y, x].reshape((-1, 1))
_, comp_label_mask = cv2.connectedComponents(
- center_region_mask, connectivity=8)
- comp_labels = comp_label_mask[y, x].reshape(
- (-1, 1)).astype(np.float32)
+ center_region_mask, connectivity=8
+ )
+ comp_labels = comp_label_mask[y, x].reshape((-1, 1)).astype(np.float32)
x = x.reshape((-1, 1)).astype(np.float32)
y = y.reshape((-1, 1)).astype(np.float32)
- comp_attribs = np.hstack(
- [x, y, height, width, cos, sin, comp_labels])
- comp_attribs = self.jitter_comp_attribs(comp_attribs,
- self.jitter_level)
+ comp_attribs = np.hstack([x, y, height, width, cos, sin, comp_labels])
+ comp_attribs = self.jitter_comp_attribs(comp_attribs, self.jitter_level)
if comp_attribs.shape[0] < self.num_min_comps:
num_rand_comps = self.num_min_comps - comp_attribs.shape[0]
rand_comp_attribs = self.generate_rand_comp_attribs(
- num_rand_comps, 1 - text_mask)
+ num_rand_comps, 1 - text_mask
+ )
comp_attribs = np.vstack([comp_attribs, rand_comp_attribs])
else:
- comp_attribs = self.generate_rand_comp_attribs(self.num_min_comps,
- 1 - text_mask)
-
- num_comps = (np.ones(
- (comp_attribs.shape[0], 1),
- dtype=np.float32) * comp_attribs.shape[0])
+ comp_attribs = self.generate_rand_comp_attribs(
+ self.num_min_comps, 1 - text_mask
+ )
+
+ num_comps = (
+ np.ones((comp_attribs.shape[0], 1), dtype=np.float32)
+ * comp_attribs.shape[0]
+ )
comp_attribs = np.hstack([num_comps, comp_attribs])
if comp_attribs.shape[0] > self.num_max_comps:
- comp_attribs = comp_attribs[:self.num_max_comps, :]
+ comp_attribs = comp_attribs[: self.num_max_comps, :]
comp_attribs[:, 0] = self.num_max_comps
pad_comp_attribs = np.zeros(
- (self.num_max_comps, comp_attribs.shape[1]), dtype=np.float32)
- pad_comp_attribs[:comp_attribs.shape[0], :] = comp_attribs
+ (self.num_max_comps, comp_attribs.shape[1]), dtype=np.float32
+ )
+ pad_comp_attribs[: comp_attribs.shape[0], :] = comp_attribs
return pad_comp_attribs
@@ -655,9 +717,9 @@ def generate_targets(self, data):
assert isinstance(data, dict)
- image = data['image']
- polygons = data['polys']
- ignore_tags = data['ignore_tags']
+ image = data["image"]
+ polygons = data["polys"]
+ ignore_tags = data["ignore_tags"]
h, w, _ = image.shape
polygon_masks = []
@@ -670,27 +732,37 @@ def generate_targets(self, data):
gt_text_mask = self.generate_text_region_mask((h, w), polygon_masks)
gt_mask = self.generate_effective_mask((h, w), polygon_masks_ignore)
- (center_lines, gt_center_region_mask, gt_top_height_map,
- gt_bot_height_map, gt_sin_map,
- gt_cos_map) = self.generate_center_mask_attrib_maps((h, w),
- polygon_masks)
+ (
+ center_lines,
+ gt_center_region_mask,
+ gt_top_height_map,
+ gt_bot_height_map,
+ gt_sin_map,
+ gt_cos_map,
+ ) = self.generate_center_mask_attrib_maps((h, w), polygon_masks)
gt_comp_attribs = self.generate_comp_attribs(
- center_lines, gt_text_mask, gt_center_region_mask,
- gt_top_height_map, gt_bot_height_map, gt_sin_map, gt_cos_map)
+ center_lines,
+ gt_text_mask,
+ gt_center_region_mask,
+ gt_top_height_map,
+ gt_bot_height_map,
+ gt_sin_map,
+ gt_cos_map,
+ )
mapping = {
- 'gt_text_mask': gt_text_mask,
- 'gt_center_region_mask': gt_center_region_mask,
- 'gt_mask': gt_mask,
- 'gt_top_height_map': gt_top_height_map,
- 'gt_bot_height_map': gt_bot_height_map,
- 'gt_sin_map': gt_sin_map,
- 'gt_cos_map': gt_cos_map
+ "gt_text_mask": gt_text_mask,
+ "gt_center_region_mask": gt_center_region_mask,
+ "gt_mask": gt_mask,
+ "gt_top_height_map": gt_top_height_map,
+ "gt_bot_height_map": gt_bot_height_map,
+ "gt_sin_map": gt_sin_map,
+ "gt_cos_map": gt_cos_map,
}
data.update(mapping)
- data['gt_comp_attribs'] = gt_comp_attribs
+ data["gt_comp_attribs"] = gt_comp_attribs
return data
def __call__(self, data):
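
The reflowed dist_point2line above is small enough to sanity-check in isolation; the following sketch reproduces it verbatim with a synthetic point and segment (only numpy is assumed):

import numpy as np
from numpy.linalg import norm

def dist_point2line(point, line):
    # Perpendicular distance from `point` to the infinite line through
    # line = (point1, point2): |cross| / |base|, with eps for degenerate lines.
    point1, point2 = line
    return abs(np.cross(point2 - point1, point - point1)) / (
        norm(point2 - point1) + 1e-8
    )

p = np.array([1.0, 2.0])
print(dist_point2line(p, (np.array([0.0, 0.0]), np.array([4.0, 0.0]))))  # ~2.0
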
diff --git a/ppocr/data/imaug/east_process.py b/ppocr/data/imaug/east_process.py
index df08adfa15..c3fbf88243 100644
--- a/ppocr/data/imaug/east_process.py
+++ b/ppocr/data/imaug/east_process.py
@@ -1,16 +1,16 @@
-#copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
+# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
-#Licensed under the Apache License, Version 2.0 (the "License");
-#you may not use this file except in compliance with the License.
-#You may obtain a copy of the License at
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
-#Unless required by applicable law or agreed to in writing, software
-#distributed under the License is distributed on an "AS IS" BASIS,
-#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-#See the License for the specific language governing permissions and
-#limitations under the License.
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
"""
This code is adapted from:
https://github.com/songdejia/EAST/blob/master/data_utils.py
@@ -22,16 +22,18 @@
import sys
import os
-__all__ = ['EASTProcessTrain']
+__all__ = ["EASTProcessTrain"]
class EASTProcessTrain(object):
- def __init__(self,
- image_shape=[512, 512],
- background_ratio=0.125,
- min_crop_side_ratio=0.1,
- min_text_size=10,
- **kwargs):
+ def __init__(
+ self,
+ image_shape=[512, 512],
+ background_ratio=0.125,
+ min_crop_side_ratio=0.1,
+ min_text_size=10,
+ **kwargs
+ ):
self.input_size = image_shape[1]
self.random_scale = np.array([0.5, 1, 2.0, 3.0])
self.background_ratio = background_ratio
@@ -83,10 +85,16 @@ def rotate_im_poly(self, im, text_polys):
poly = []
for j in range(4):
sx, sy = wordBB[j][0], wordBB[j][1]
- dx = math.cos(rot_angle) * (sx - cx)\
- - math.sin(rot_angle) * (sy - cy) + ncx
- dy = math.sin(rot_angle) * (sx - cx)\
- + math.cos(rot_angle) * (sy - cy) + ncy
+ dx = (
+ math.cos(rot_angle) * (sx - cx)
+ - math.sin(rot_angle) * (sy - cy)
+ + ncx
+ )
+ dy = (
+ math.sin(rot_angle) * (sx - cx)
+ + math.cos(rot_angle) * (sy - cy)
+ + ncy
+ )
poly.append([dx, dy])
dst_polys.append(poly)
dst_polys = np.array(dst_polys, dtype=np.float32)
@@ -98,11 +106,13 @@ def polygon_area(self, poly):
:param poly:
:return:
"""
- edge = [(poly[1][0] - poly[0][0]) * (poly[1][1] + poly[0][1]),
- (poly[2][0] - poly[1][0]) * (poly[2][1] + poly[1][1]),
- (poly[3][0] - poly[2][0]) * (poly[3][1] + poly[2][1]),
- (poly[0][0] - poly[3][0]) * (poly[0][1] + poly[3][1])]
- return np.sum(edge) / 2.
+ edge = [
+ (poly[1][0] - poly[0][0]) * (poly[1][1] + poly[0][1]),
+ (poly[2][0] - poly[1][0]) * (poly[2][1] + poly[1][1]),
+ (poly[3][0] - poly[2][0]) * (poly[3][1] + poly[2][1]),
+ (poly[0][0] - poly[3][0]) * (poly[0][1] + poly[3][1]),
+ ]
+ return np.sum(edge) / 2.0
def check_and_validate_polys(self, polys, tags, img_height, img_width):
"""
@@ -122,13 +132,13 @@ def check_and_validate_polys(self, polys, tags, img_height, img_width):
validated_tags = []
for poly, tag in zip(polys, tags):
p_area = self.polygon_area(poly)
- #invalid poly
+ # invalid poly
if abs(p_area) < 1:
continue
if p_area > 0:
#'poly in wrong direction'
if not tag:
- tag = True #reversed cases should be ignore
+ tag = True # reversed cases should be ignored
poly = poly[(0, 3, 2, 1), :]
validated_polys.append(poly)
validated_tags.append(tag)
@@ -148,6 +158,7 @@ def draw_img_polys(self, img, polys):
box = box.astype(np.int32).reshape((-1, 1, 2))
cv2.polylines(img, [box], True, color=(255, 255, 0), thickness=2)
import random
+
ino = random.randint(0, 100)
cv2.imwrite("tmp_%d.jpg" % ino, img)
return
@@ -170,29 +181,25 @@ def shrink_poly(self, poly, r):
if dist0 + dist1 > dist2 + dist3:
# first move (p0, p1), (p2, p3), then (p0, p3), (p1, p2)
## p0, p1
- theta = np.arctan2((poly[1][1] - poly[0][1]),
- (poly[1][0] - poly[0][0]))
+ theta = np.arctan2((poly[1][1] - poly[0][1]), (poly[1][0] - poly[0][0]))
poly[0][0] += R * r[0] * np.cos(theta)
poly[0][1] += R * r[0] * np.sin(theta)
poly[1][0] -= R * r[1] * np.cos(theta)
poly[1][1] -= R * r[1] * np.sin(theta)
## p2, p3
- theta = np.arctan2((poly[2][1] - poly[3][1]),
- (poly[2][0] - poly[3][0]))
+ theta = np.arctan2((poly[2][1] - poly[3][1]), (poly[2][0] - poly[3][0]))
poly[3][0] += R * r[3] * np.cos(theta)
poly[3][1] += R * r[3] * np.sin(theta)
poly[2][0] -= R * r[2] * np.cos(theta)
poly[2][1] -= R * r[2] * np.sin(theta)
## p0, p3
- theta = np.arctan2((poly[3][0] - poly[0][0]),
- (poly[3][1] - poly[0][1]))
+ theta = np.arctan2((poly[3][0] - poly[0][0]), (poly[3][1] - poly[0][1]))
poly[0][0] += R * r[0] * np.sin(theta)
poly[0][1] += R * r[0] * np.cos(theta)
poly[3][0] -= R * r[3] * np.sin(theta)
poly[3][1] -= R * r[3] * np.cos(theta)
## p1, p2
- theta = np.arctan2((poly[2][0] - poly[1][0]),
- (poly[2][1] - poly[1][1]))
+ theta = np.arctan2((poly[2][0] - poly[1][0]), (poly[2][1] - poly[1][1]))
poly[1][0] += R * r[1] * np.sin(theta)
poly[1][1] += R * r[1] * np.cos(theta)
poly[2][0] -= R * r[2] * np.sin(theta)
@@ -200,29 +207,25 @@ def shrink_poly(self, poly, r):
else:
## p0, p3
# print poly
- theta = np.arctan2((poly[3][0] - poly[0][0]),
- (poly[3][1] - poly[0][1]))
+ theta = np.arctan2((poly[3][0] - poly[0][0]), (poly[3][1] - poly[0][1]))
poly[0][0] += R * r[0] * np.sin(theta)
poly[0][1] += R * r[0] * np.cos(theta)
poly[3][0] -= R * r[3] * np.sin(theta)
poly[3][1] -= R * r[3] * np.cos(theta)
## p1, p2
- theta = np.arctan2((poly[2][0] - poly[1][0]),
- (poly[2][1] - poly[1][1]))
+ theta = np.arctan2((poly[2][0] - poly[1][0]), (poly[2][1] - poly[1][1]))
poly[1][0] += R * r[1] * np.sin(theta)
poly[1][1] += R * r[1] * np.cos(theta)
poly[2][0] -= R * r[2] * np.sin(theta)
poly[2][1] -= R * r[2] * np.cos(theta)
## p0, p1
- theta = np.arctan2((poly[1][1] - poly[0][1]),
- (poly[1][0] - poly[0][0]))
+ theta = np.arctan2((poly[1][1] - poly[0][1]), (poly[1][0] - poly[0][0]))
poly[0][0] += R * r[0] * np.cos(theta)
poly[0][1] += R * r[0] * np.sin(theta)
poly[1][0] -= R * r[1] * np.cos(theta)
poly[1][1] -= R * r[1] * np.sin(theta)
## p2, p3
- theta = np.arctan2((poly[2][1] - poly[3][1]),
- (poly[2][0] - poly[3][0]))
+ theta = np.arctan2((poly[2][1] - poly[3][1]), (poly[2][0] - poly[3][0]))
poly[3][0] += R * r[3] * np.cos(theta)
poly[3][1] += R * r[3] * np.sin(theta)
poly[2][0] -= R * r[2] * np.cos(theta)
@@ -250,24 +253,23 @@ def generate_quad(self, im_size, polys, tags):
dist2 = np.linalg.norm(poly[i] - poly[(i - 1) % 4])
r[i] = min(dist1, dist2)
# score map
- shrinked_poly = self.shrink_poly(
- poly.copy(), r).astype(np.int32)[np.newaxis, :, :]
+ shrinked_poly = self.shrink_poly(poly.copy(), r).astype(np.int32)[
+ np.newaxis, :, :
+ ]
cv2.fillPoly(score_map, shrinked_poly, 1)
cv2.fillPoly(poly_mask, shrinked_poly, poly_idx + 1)
# if the poly is too small, then ignore it during training
poly_h = min(
- np.linalg.norm(poly[0] - poly[3]),
- np.linalg.norm(poly[1] - poly[2]))
+ np.linalg.norm(poly[0] - poly[3]), np.linalg.norm(poly[1] - poly[2])
+ )
poly_w = min(
- np.linalg.norm(poly[0] - poly[1]),
- np.linalg.norm(poly[2] - poly[3]))
+ np.linalg.norm(poly[0] - poly[1]), np.linalg.norm(poly[2] - poly[3])
+ )
if min(poly_h, poly_w) < self.min_text_size:
- cv2.fillPoly(training_mask,
- poly.astype(np.int32)[np.newaxis, :, :], 0)
+ cv2.fillPoly(training_mask, poly.astype(np.int32)[np.newaxis, :, :], 0)
if tag:
- cv2.fillPoly(training_mask,
- poly.astype(np.int32)[np.newaxis, :, :], 0)
+ cv2.fillPoly(training_mask, poly.astype(np.int32)[np.newaxis, :, :], 0)
xy_in_poly = np.argwhere(poly_mask == (poly_idx + 1))
# geo map.
@@ -277,12 +279,13 @@ def generate_quad(self, im_size, polys, tags):
poly[:, 1] = np.minimum(np.maximum(poly[:, 1], 0), h)
for pno in range(4):
geo_channel_beg = pno * 2
- geo_map[y_in_poly, x_in_poly, geo_channel_beg] =\
+ geo_map[y_in_poly, x_in_poly, geo_channel_beg] = (
x_in_poly - poly[pno, 0]
- geo_map[y_in_poly, x_in_poly, geo_channel_beg+1] =\
+ )
+ geo_map[y_in_poly, x_in_poly, geo_channel_beg + 1] = (
y_in_poly - poly[pno, 1]
- geo_map[y_in_poly, x_in_poly, 8] = \
- 1.0 / max(min(poly_h, poly_w), 1.0)
+ )
+ geo_map[y_in_poly, x_in_poly, 8] = 1.0 / max(min(poly_h, poly_w), 1.0)
return score_map, geo_map, training_mask
def crop_area(self, im, polys, tags, crop_background=False, max_tries=50):
@@ -304,10 +307,10 @@ def crop_area(self, im, polys, tags, crop_background=False, max_tries=50):
poly = np.round(poly, decimals=0).astype(np.int32)
minx = np.min(poly[:, 0])
maxx = np.max(poly[:, 0])
- w_array[minx + pad_w:maxx + pad_w] = 1
+ w_array[minx + pad_w : maxx + pad_w] = 1
miny = np.min(poly[:, 1])
maxy = np.max(poly[:, 1])
- h_array[miny + pad_h:maxy + pad_h] = 1
+ h_array[miny + pad_h : maxy + pad_h] = 1
# ensure the cropped area does not cut across a text region
h_axis = np.where(h_array == 0)[0]
w_axis = np.where(w_array == 0)[0]
@@ -325,31 +328,34 @@ def crop_area(self, im, polys, tags, crop_background=False, max_tries=50):
ymax = np.max(yy) - pad_h
ymin = np.clip(ymin, 0, h - 1)
ymax = np.clip(ymax, 0, h - 1)
- if xmax - xmin < self.min_crop_side_ratio * w or \
- ymax - ymin < self.min_crop_side_ratio * h:
+ if (
+ xmax - xmin < self.min_crop_side_ratio * w
+ or ymax - ymin < self.min_crop_side_ratio * h
+ ):
# area too small
continue
if polys.shape[0] != 0:
- poly_axis_in_area = (polys[:, :, 0] >= xmin)\
- & (polys[:, :, 0] <= xmax)\
- & (polys[:, :, 1] >= ymin)\
+ poly_axis_in_area = (
+ (polys[:, :, 0] >= xmin)
+ & (polys[:, :, 0] <= xmax)
+ & (polys[:, :, 1] >= ymin)
& (polys[:, :, 1] <= ymax)
- selected_polys = np.where(
- np.sum(poly_axis_in_area, axis=1) == 4)[0]
+ )
+ selected_polys = np.where(np.sum(poly_axis_in_area, axis=1) == 4)[0]
else:
selected_polys = []
if len(selected_polys) == 0:
# no text in this area
if crop_background:
- im = im[ymin:ymax + 1, xmin:xmax + 1, :]
+ im = im[ymin : ymax + 1, xmin : xmax + 1, :]
polys = []
tags = []
return im, polys, tags
else:
continue
- im = im[ymin:ymax + 1, xmin:xmax + 1, :]
+ im = im[ymin : ymax + 1, xmin : xmax + 1, :]
polys = polys[selected_polys]
tags = tags[selected_polys]
polys[:, :, 0] -= xmin
@@ -359,7 +365,8 @@ def crop_area(self, im, polys, tags, crop_background=False, max_tries=50):
def crop_background_infor(self, im, text_polys, text_tags):
im, text_polys, text_tags = self.crop_area(
- im, text_polys, text_tags, crop_background=True)
+ im, text_polys, text_tags, crop_background=True
+ )
if len(text_polys) > 0:
return None
@@ -373,11 +380,12 @@ def crop_background_infor(self, im, text_polys, text_tags):
def crop_foreground_infor(self, im, text_polys, text_tags):
im, text_polys, text_tags = self.crop_area(
- im, text_polys, text_tags, crop_background=False)
+ im, text_polys, text_tags, crop_background=False
+ )
if text_polys.shape[0] == 0:
return None
- #continue for all ignore case
+ # skip the sample if all text boxes are ignored
if np.sum((text_tags * 1.0)) >= text_tags.size:
return None
# pad and resize image
@@ -389,24 +397,26 @@ def crop_foreground_infor(self, im, text_polys, text_tags):
# print(im.shape)
# self.draw_img_polys(im, text_polys)
score_map, geo_map, training_mask = self.generate_quad(
- (new_h, new_w), text_polys, text_tags)
+ (new_h, new_w), text_polys, text_tags
+ )
return im, score_map, geo_map, training_mask
def __call__(self, data):
- im = data['image']
- text_polys = data['polys']
- text_tags = data['ignore_tags']
+ im = data["image"]
+ text_polys = data["polys"]
+ text_tags = data["ignore_tags"]
if im is None:
return None
if text_polys.shape[0] == 0:
return None
- #add rotate cases
+ # add rotate cases
if np.random.rand() < 0.5:
im, text_polys = self.rotate_im_poly(im, text_polys)
h, w, _ = im.shape
- text_polys, text_tags = self.check_and_validate_polys(text_polys,
- text_tags, h, w)
+ text_polys, text_tags = self.check_and_validate_polys(
+ text_polys, text_tags, h, w
+ )
if text_polys.shape[0] == 0:
return None
@@ -429,8 +439,8 @@ def __call__(self, data):
training_mask = training_mask[np.newaxis, ::4, ::4]
training_mask = training_mask.astype(np.float32)
- data['image'] = im[0]
- data['score_map'] = score_map
- data['geo_map'] = geo_map
- data['training_mask'] = training_mask
+ data["image"] = im[0]
+ data["score_map"] = score_map
+ data["geo_map"] = geo_map
+ data["training_mask"] = training_mask
return data
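
polygon_area above is the edge-sum form of the shoelace formula; its sign encodes vertex order, which check_and_validate_polys uses to detect and repair reversed quads. A minimal check with a synthetic rectangle (the variable names are illustrative):

import numpy as np

def polygon_area(poly):
    # Signed area; the sign flips with vertex order (shoelace formula).
    edge = [
        (poly[1][0] - poly[0][0]) * (poly[1][1] + poly[0][1]),
        (poly[2][0] - poly[1][0]) * (poly[2][1] + poly[1][1]),
        (poly[3][0] - poly[2][0]) * (poly[3][1] + poly[2][1]),
        (poly[0][0] - poly[3][0]) * (poly[0][1] + poly[3][1]),
    ]
    return np.sum(edge) / 2.0

quad = np.array([[0, 0], [10, 0], [10, 5], [0, 5]], dtype=np.float32)
print(polygon_area(quad))        # -50.0: the expected orientation
print(polygon_area(quad[::-1]))  # +50.0: would trigger the direction fix-up
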
diff --git a/ppocr/data/imaug/fce_aug.py b/ppocr/data/imaug/fce_aug.py
index baaaa33555..39657a0243 100644
--- a/ppocr/data/imaug/fce_aug.py
+++ b/ppocr/data/imaug/fce_aug.py
@@ -24,7 +24,7 @@
class RandomScaling:
- def __init__(self, size=800, scale=(3. / 4, 5. / 2), **kwargs):
+ def __init__(self, size=800, scale=(3.0 / 4, 5.0 / 2), **kwargs):
"""Random scale the image while keeping aspect.
Args:
@@ -34,12 +34,11 @@ def __init__(self, size=800, scale=(3. / 4, 5. / 2), **kwargs):
assert isinstance(size, int)
assert isinstance(scale, float) or isinstance(scale, tuple)
self.size = size
- self.scale = scale if isinstance(scale, tuple) \
- else (1 - scale, 1 + scale)
+ self.scale = scale if isinstance(scale, tuple) else (1 - scale, 1 + scale)
def __call__(self, data):
- image = data['image']
- text_polys = data['polys']
+ image = data["image"]
+ text_polys = data["polys"]
h, w, _ = image.shape
aspect_ratio = np.random.uniform(min(self.scale), max(self.scale))
@@ -48,21 +47,18 @@ def __call__(self, data):
out_size = (int(h * scales[1]), int(w * scales[0]))
image = cv2.resize(image, out_size[::-1])
- data['image'] = image
+ data["image"] = image
text_polys[:, :, 0::2] = text_polys[:, :, 0::2] * scales[1]
text_polys[:, :, 1::2] = text_polys[:, :, 1::2] * scales[0]
- data['polys'] = text_polys
+ data["polys"] = text_polys
return data
class RandomCropFlip:
- def __init__(self,
- pad_ratio=0.1,
- crop_ratio=0.5,
- iter_num=1,
- min_area_ratio=0.2,
- **kwargs):
+ def __init__(
+ self, pad_ratio=0.1, crop_ratio=0.5, iter_num=1, min_area_ratio=0.2, **kwargs
+ ):
"""Random crop and flip a patch of the image.
Args:
@@ -88,9 +84,9 @@ def __call__(self, results):
return results
def random_crop_flip(self, results):
- image = results['image']
- polygons = results['polys']
- ignore_tags = results['ignore_tags']
+ image = results["image"]
+ polygons = results["polys"]
+ ignore_tags = results["ignore_tags"]
if len(polygons) == 0:
return results
@@ -101,8 +97,7 @@ def random_crop_flip(self, results):
area = h * w
pad_h = int(h * self.pad_ratio)
pad_w = int(w * self.pad_ratio)
- h_axis, w_axis = self.generate_crop_target(image, polygons, pad_h,
- pad_w)
+ h_axis, w_axis = self.generate_crop_target(image, polygons, pad_h, pad_w)
if len(h_axis) == 0 or len(w_axis) == 0:
return results
@@ -127,15 +122,18 @@ def random_crop_flip(self, results):
# area too small
continue
- pts = np.stack([[xmin, xmax, xmax, xmin],
- [ymin, ymin, ymax, ymax]]).T.astype(np.int32)
+ pts = np.stack(
+ [[xmin, xmax, xmax, xmin], [ymin, ymin, ymax, ymax]]
+ ).T.astype(np.int32)
pp = Polygon(pts)
fail_flag = False
for polygon, ignore_tag in zip(polygons, ignore_tags):
ppi = Polygon(polygon.reshape(-1, 2))
ppiou, _ = poly_intersection(ppi, pp, buffer=0)
- if np.abs(ppiou - float(ppi.area)) > self.epsilon and \
- np.abs(ppiou) > self.epsilon:
+ if (
+ np.abs(ppiou - float(ppi.area)) > self.epsilon
+ and np.abs(ppiou) > self.epsilon
+ ):
fail_flag = True
break
elif np.abs(ppiou - float(ppi.area)) < self.epsilon:
@@ -159,7 +157,7 @@ def random_crop_flip(self, results):
else:
img = np.ascontiguousarray(cropped[::-1, ::-1])
image[ymin:ymax, xmin:xmax, :] = img
- results['img'] = image
+ results["img"] = image
if len(polys_new) != 0:
height, width, _ = cropped.shape
@@ -181,8 +179,8 @@ def random_crop_flip(self, results):
polys_new[idx] = poly
polygons = polys_keep + polys_new
ignore_tags = ignore_tags_keep + ignore_tags_new
- results['polys'] = np.array(polygons)
- results['ignore_tags'] = ignore_tags
+ results["polys"] = np.array(polygons)
+ results["ignore_tags"] = ignore_tags
return results
@@ -216,10 +214,10 @@ def generate_crop_target(self, image, all_polys, pad_h, pad_w):
poly = np.round(poly, decimals=0).astype(np.int32)
minx = np.min(poly[:, 0])
maxx = np.max(poly[:, 0])
- w_array[minx + pad_w:maxx + pad_w] = 1
+ w_array[minx + pad_w : maxx + pad_w] = 1
miny = np.min(poly[:, 1])
maxy = np.max(poly[:, 1])
- h_array[miny + pad_h:maxy + pad_h] = 1
+ h_array[miny + pad_h : maxy + pad_h] = 1
h_axis = np.where(h_array == 0)[0]
w_axis = np.where(w_array == 0)[0]
@@ -236,7 +234,6 @@ def __init__(self, crop_ratio=5.0 / 8.0, min_side_ratio=0.4, **kwargs):
self.min_side_ratio = min_side_ratio
def sample_valid_start_end(self, valid_array, min_len, max_start, min_end):
-
assert isinstance(min_len, int)
assert len(valid_array) > min_len
@@ -248,8 +245,7 @@ def sample_valid_start_end(self, valid_array, min_len, max_start, min_end):
region_starts = np.where(diff_array < 0)[0]
region_ends = np.where(diff_array > 0)[0]
region_ind = np.random.randint(0, len(region_starts))
- start = np.random.randint(region_starts[region_ind],
- region_ends[region_ind])
+ start = np.random.randint(region_starts[region_ind], region_ends[region_ind])
end_array = valid_array.copy()
min_end = max(start + min_len, min_end)
@@ -259,8 +255,7 @@ def sample_valid_start_end(self, valid_array, min_len, max_start, min_end):
region_starts = np.where(diff_array < 0)[0]
region_ends = np.where(diff_array > 0)[0]
region_ind = np.random.randint(0, len(region_starts))
- end = np.random.randint(region_starts[region_ind],
- region_ends[region_ind])
+ end = np.random.randint(region_starts[region_ind], region_ends[region_ind])
return start, end
def sample_crop_box(self, img_size, results):
@@ -274,7 +269,7 @@ def sample_crop_box(self, img_size, results):
assert isinstance(img_size, tuple)
h, w = img_size[:2]
- key_masks = results['polys']
+ key_masks = results["polys"]
x_valid_array = np.ones(w, dtype=np.int32)
y_valid_array = np.ones(h, dtype=np.int32)
@@ -293,16 +288,18 @@ def sample_crop_box(self, img_size, results):
min_x, max_x = np.min(clip_x), np.max(clip_x)
min_y, max_y = np.min(clip_y), np.max(clip_y)
- x_valid_array[min_x - 2:max_x + 3] = 0
- y_valid_array[min_y - 2:max_y + 3] = 0
+ x_valid_array[min_x - 2 : max_x + 3] = 0
+ y_valid_array[min_y - 2 : max_y + 3] = 0
min_w = int(w * self.min_side_ratio)
min_h = int(h * self.min_side_ratio)
- x1, x2 = self.sample_valid_start_end(x_valid_array, min_w, max_x_start,
- min_x_end)
- y1, y2 = self.sample_valid_start_end(y_valid_array, min_h, max_y_start,
- min_y_end)
+ x1, x2 = self.sample_valid_start_end(
+ x_valid_array, min_w, max_x_start, min_x_end
+ )
+ y1, y2 = self.sample_valid_start_end(
+ y_valid_array, min_h, max_y_start, min_y_end
+ )
return np.array([x1, y1, x2, y2])
@@ -311,20 +308,19 @@ def crop_img(self, img, bbox):
h, w, _ = img.shape
assert 0 <= bbox[1] < bbox[3] <= h
assert 0 <= bbox[0] < bbox[2] <= w
- return img[bbox[1]:bbox[3], bbox[0]:bbox[2]]
+ return img[bbox[1] : bbox[3], bbox[0] : bbox[2]]
def __call__(self, results):
- image = results['image']
- polygons = results['polys']
- ignore_tags = results['ignore_tags']
+ image = results["image"]
+ polygons = results["polys"]
+ ignore_tags = results["ignore_tags"]
if len(polygons) < 1:
return results
if np.random.random_sample() < self.crop_ratio:
-
crop_box = self.sample_crop_box(image.shape, results)
img = self.crop_img(image, crop_box)
- results['image'] = img
+ results["image"] = img
# crop and filter masks
x1, y1, x2, y2 = crop_box
w = max(x2 - x1, 1)
@@ -335,17 +331,19 @@ def __call__(self, results):
valid_masks_list = []
valid_tags_list = []
for ind, polygon in enumerate(polygons):
- if (polygon[:, ::2] > -4).all() and (
- polygon[:, ::2] < w + 4).all() and (
- polygon[:, 1::2] > -4).all() and (
- polygon[:, 1::2] < h + 4).all():
+ if (
+ (polygon[:, ::2] > -4).all()
+ and (polygon[:, ::2] < w + 4).all()
+ and (polygon[:, 1::2] > -4).all()
+ and (polygon[:, 1::2] < h + 4).all()
+ ):
polygon[:, ::2] = np.clip(polygon[:, ::2], 0, w)
polygon[:, 1::2] = np.clip(polygon[:, 1::2], 0, h)
valid_masks_list.append(polygon)
valid_tags_list.append(ignore_tags[ind])
- results['polys'] = np.array(valid_masks_list)
- results['ignore_tags'] = valid_tags_list
+ results["polys"] = np.array(valid_masks_list)
+ results["ignore_tags"] = valid_tags_list
return results
@@ -355,12 +353,14 @@ def __repr__(self):
class RandomRotatePolyInstances:
- def __init__(self,
- rotate_ratio=0.5,
- max_angle=10,
- pad_with_fixed_color=False,
- pad_value=(0, 0, 0),
- **kwargs):
+ def __init__(
+ self,
+ rotate_ratio=0.5,
+ max_angle=10,
+ pad_with_fixed_color=False,
+ pad_value=(0, 0, 0),
+ **kwargs
+ ):
"""Randomly rotate images and polygon masks.
Args:
@@ -387,8 +387,8 @@ def rotate(self, center, points, theta, center_shift=(0, 0)):
cos = math.cos(theta)
sin = math.sin(theta)
- x = (x - center_x)
- y = (y - center_y)
+ x = x - center_x
+ y = y - center_y
_x = center_x + x * cos - y * sin + center_shift[0]
_y = -(center_y + x * sin + y * cos) + center_shift[1]
@@ -422,47 +422,56 @@ def rotate_img(self, img, angle, canvas_size):
if self.pad_with_fixed_color:
target_img = cv2.warpAffine(
img,
- rotation_matrix, (canvas_size[1], canvas_size[0]),
+ rotation_matrix,
+ (canvas_size[1], canvas_size[0]),
flags=cv2.INTER_NEAREST,
- borderValue=self.pad_value)
+ borderValue=self.pad_value,
+ )
else:
mask = np.zeros_like(img)
- (h_ind, w_ind) = (np.random.randint(0, h * 7 // 8),
- np.random.randint(0, w * 7 // 8))
- img_cut = img[h_ind:(h_ind + h // 9), w_ind:(w_ind + w // 9)]
+ (h_ind, w_ind) = (
+ np.random.randint(0, h * 7 // 8),
+ np.random.randint(0, w * 7 // 8),
+ )
+ img_cut = img[h_ind : (h_ind + h // 9), w_ind : (w_ind + w // 9)]
img_cut = cv2.resize(img_cut, (canvas_size[1], canvas_size[0]))
mask = cv2.warpAffine(
mask,
- rotation_matrix, (canvas_size[1], canvas_size[0]),
- borderValue=[1, 1, 1])
+ rotation_matrix,
+ (canvas_size[1], canvas_size[0]),
+ borderValue=[1, 1, 1],
+ )
target_img = cv2.warpAffine(
img,
- rotation_matrix, (canvas_size[1], canvas_size[0]),
- borderValue=[0, 0, 0])
+ rotation_matrix,
+ (canvas_size[1], canvas_size[0]),
+ borderValue=[0, 0, 0],
+ )
target_img = target_img + img_cut * mask
return target_img
def __call__(self, results):
if np.random.random_sample() < self.rotate_ratio:
- image = results['image']
- polygons = results['polys']
+ image = results["image"]
+ polygons = results["polys"]
h, w = image.shape[:2]
angle = self.sample_angle(self.max_angle)
canvas_size = self.cal_canvas_size((h, w), angle)
- center_shift = (int((canvas_size[1] - w) / 2), int(
- (canvas_size[0] - h) / 2))
+ center_shift = (
+ int((canvas_size[1] - w) / 2),
+ int((canvas_size[0] - h) / 2),
+ )
image = self.rotate_img(image, angle, canvas_size)
- results['image'] = image
+ results["image"] = image
# rotate polygons
rotated_masks = []
for mask in polygons:
- rotated_mask = self.rotate((w / 2, h / 2), mask, angle,
- center_shift)
+ rotated_mask = self.rotate((w / 2, h / 2), mask, angle, center_shift)
rotated_masks.append(rotated_mask)
- results['polys'] = np.array(rotated_masks)
+ results["polys"] = np.array(rotated_masks)
return results
@@ -472,12 +481,14 @@ def __repr__(self):
class SquareResizePad:
- def __init__(self,
- target_size,
- pad_ratio=0.6,
- pad_with_fixed_color=False,
- pad_value=(0, 0, 0),
- **kwargs):
+ def __init__(
+ self,
+ target_size,
+ pad_ratio=0.6,
+ pad_with_fixed_color=False,
+ pad_value=(0, 0, 0),
+ **kwargs
+ ):
"""Resize or pad images to be square shape.
Args:
@@ -516,15 +527,17 @@ def square_pad(self, img):
expand_img = np.ones((pad_size, pad_size, 3), dtype=np.uint8)
expand_img[:] = self.pad_value
else:
- (h_ind, w_ind) = (np.random.randint(0, h * 7 // 8),
- np.random.randint(0, w * 7 // 8))
- img_cut = img[h_ind:(h_ind + h // 9), w_ind:(w_ind + w // 9)]
+ (h_ind, w_ind) = (
+ np.random.randint(0, h * 7 // 8),
+ np.random.randint(0, w * 7 // 8),
+ )
+ img_cut = img[h_ind : (h_ind + h // 9), w_ind : (w_ind + w // 9)]
expand_img = cv2.resize(img_cut, (pad_size, pad_size))
if h > w:
y0, x0 = 0, (h - w) // 2
else:
y0, x0 = (w - h) // 2, 0
- expand_img[y0:y0 + h, x0:x0 + w] = img
+ expand_img[y0 : y0 + h, x0 : x0 + w] = img
offset = (x0, y0)
return expand_img, offset
@@ -537,8 +550,8 @@ def square_pad_mask(self, points, offset):
return pad_points
def __call__(self, results):
- image = results['image']
- polygons = results['polys']
+ image = results["image"]
+ polygons = results["polys"]
h, w = image.shape[:2]
if np.random.random_sample() < self.pad_ratio:
@@ -547,15 +560,13 @@ def __call__(self, results):
else:
image, out_size = self.resize_img(image, keep_ratio=False)
offset = (0, 0)
- results['image'] = image
+ results["image"] = image
try:
- polygons[:, :, 0::2] = polygons[:, :, 0::2] * out_size[
- 1] / w + offset[0]
- polygons[:, :, 1::2] = polygons[:, :, 1::2] * out_size[
- 0] / h + offset[1]
+ polygons[:, :, 0::2] = polygons[:, :, 0::2] * out_size[1] / w + offset[0]
+ polygons[:, :, 1::2] = polygons[:, :, 1::2] * out_size[0] / h + offset[1]
except:
pass
- results['polys'] = polygons
+ results["polys"] = polygons
return results
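
generate_crop_target above (and the matching logic in crop_area earlier) builds per-axis occupancy arrays so crop edges can only be sampled in text-free gaps. A compact sketch of that trick with one synthetic box; the sizes and padding values are illustrative, not from the patch:

import numpy as np

h, w, pad_h, pad_w = 60, 100, 6, 10  # synthetic image size and padding
polys = [np.array([[20, 10], [40, 10], [40, 30], [20, 30]])]  # one text box

# Mark every padded column/row covered by a polygon's bounding box.
w_array = np.zeros(w + pad_w * 2, dtype=np.int32)
h_array = np.zeros(h + pad_h * 2, dtype=np.int32)
for poly in polys:
    poly = np.round(poly, decimals=0).astype(np.int32)
    minx, maxx = np.min(poly[:, 0]), np.max(poly[:, 0])
    w_array[minx + pad_w : maxx + pad_w] = 1
    miny, maxy = np.min(poly[:, 1]), np.max(poly[:, 1])
    h_array[miny + pad_h : maxy + pad_h] = 1

# Crop boundaries are drawn only from the zero runs, so no box is cut through.
h_axis = np.where(h_array == 0)[0]
w_axis = np.where(w_array == 0)[0]
print(len(h_axis), len(w_axis))  # counts of valid row/column positions
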
diff --git a/ppocr/data/imaug/fce_targets.py b/ppocr/data/imaug/fce_targets.py
index 054631cb2d..acceff30d6 100644
--- a/ppocr/data/imaug/fce_targets.py
+++ b/ppocr/data/imaug/fce_targets.py
@@ -45,15 +45,16 @@ class FCENetTargets:
assigned to each level.
"""
- def __init__(self,
- fourier_degree=5,
- resample_step=4.0,
- center_region_shrink_ratio=0.3,
- level_size_divisors=(8, 16, 32),
- level_proportion_range=((0, 0.25), (0.2, 0.65), (0.55, 1.0)),
- orientation_thr=2.0,
- **kwargs):
-
+ def __init__(
+ self,
+ fourier_degree=5,
+ resample_step=4.0,
+ center_region_shrink_ratio=0.3,
+ level_size_divisors=(8, 16, 32),
+ level_proportion_range=((0, 0.25), (0.2, 0.65), (0.55, 1.0)),
+ orientation_thr=2.0,
+ **kwargs
+ ):
super().__init__()
assert isinstance(level_size_divisors, tuple)
assert isinstance(level_proportion_range, tuple)
@@ -75,9 +76,7 @@ def vector_angle(self, vec1, vec2):
unit_vec2 = vec2 / (norm(vec2, axis=-1) + 1e-8).reshape((-1, 1))
else:
unit_vec2 = vec2 / (norm(vec2, axis=-1) + 1e-8)
- return np.arccos(
- np.clip(
- np.sum(unit_vec1 * unit_vec2, axis=-1), -1.0, 1.0))
+ return np.arccos(np.clip(np.sum(unit_vec1 * unit_vec2, axis=-1), -1.0, 1.0))
def resample_line(self, line, n):
"""Resample n points on a line.
@@ -96,9 +95,7 @@ def resample_line(self, line, n):
assert isinstance(n, int)
assert n > 0
- length_list = [
- norm(line[i + 1] - line[i]) for i in range(len(line) - 1)
- ]
+ length_list = [norm(line[i + 1] - line[i]) for i in range(len(line) - 1)]
total_length = sum(length_list)
length_cumsum = np.cumsum([0.0] + length_list)
delta_length = total_length / (float(n) + 1e-8)
@@ -109,21 +106,22 @@ def resample_line(self, line, n):
for i in range(1, n):
current_line_len = i * delta_length
- while current_edge_ind + 1 < len(
- length_cumsum) and current_line_len >= length_cumsum[
- current_edge_ind + 1]:
+ while (
+ current_edge_ind + 1 < len(length_cumsum)
+ and current_line_len >= length_cumsum[current_edge_ind + 1]
+ ):
current_edge_ind += 1
- current_edge_end_shift = current_line_len - length_cumsum[
- current_edge_ind]
+ current_edge_end_shift = current_line_len - length_cumsum[current_edge_ind]
if current_edge_ind >= len(length_list):
break
- end_shift_ratio = current_edge_end_shift / length_list[
- current_edge_ind]
- current_point = line[current_edge_ind] + (line[current_edge_ind + 1]
- - line[current_edge_ind]
- ) * end_shift_ratio
+ end_shift_ratio = current_edge_end_shift / length_list[current_edge_ind]
+ current_point = (
+ line[current_edge_ind]
+ + (line[current_edge_ind + 1] - line[current_edge_ind])
+ * end_shift_ratio
+ )
resampled_line.append(current_point)
resampled_line.append(line[-1])
resampled_line = np.array(resampled_line)
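
Note: the reflowed loop above walks the cumulative edge lengths by hand; the same equal-arc-length resampling can be written compactly with np.interp (illustrative sketch, not a drop-in replacement):

    import numpy as np

    def resample_line_interp(line, n):
        # n + 1 points in total: both endpoints plus n - 1 interior points
        # at equal arc-length spacing, matching the hand-rolled loop.
        line = np.asarray(line, dtype=np.float64)
        seg_len = np.linalg.norm(np.diff(line, axis=0), axis=1)
        cum = np.concatenate([[0.0], np.cumsum(seg_len)])
        t = np.linspace(0.0, cum[-1], n + 1)[1:-1]
        interior = np.stack(
            [np.interp(t, cum, line[:, 0]), np.interp(t, cum, line[:, 1])], axis=1
        )
        return np.vstack([line[:1], interior, line[-1:]])

    print(resample_line_interp([(0, 0), (4, 0), (4, 4)], n=4))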
@@ -158,11 +156,9 @@ def reorder_poly_edge(self, points):
pad_points = np.vstack([points, points])
if tail_inds[1] < 1:
tail_inds[1] = len(points)
- sideline1 = pad_points[head_inds[1]:tail_inds[1]]
- sideline2 = pad_points[tail_inds[1]:(head_inds[1] + len(points))]
- sideline_mean_shift = np.mean(
- sideline1, axis=0) - np.mean(
- sideline2, axis=0)
+ sideline1 = pad_points[head_inds[1] : tail_inds[1]]
+ sideline2 = pad_points[tail_inds[1] : (head_inds[1] + len(points))]
+ sideline_mean_shift = np.mean(sideline1, axis=0) - np.mean(sideline2, axis=0)
if sideline_mean_shift[1] > 0:
top_sideline, bot_sideline = sideline2, sideline1
@@ -199,20 +195,19 @@ def find_head_tail(self, points, orientation_thr):
for i, edge_vec1 in enumerate(edge_vec):
adjacent_ind = [x % len(edge_vec) for x in [i - 1, i + 1]]
adjacent_edge_vec = edge_vec[adjacent_ind]
- temp_theta_sum = np.sum(
- self.vector_angle(edge_vec1, adjacent_edge_vec))
- temp_adjacent_theta = self.vector_angle(adjacent_edge_vec[0],
- adjacent_edge_vec[1])
+ temp_theta_sum = np.sum(self.vector_angle(edge_vec1, adjacent_edge_vec))
+ temp_adjacent_theta = self.vector_angle(
+ adjacent_edge_vec[0], adjacent_edge_vec[1]
+ )
theta_sum.append(temp_theta_sum)
adjacent_vec_theta.append(temp_adjacent_theta)
theta_sum_score = np.array(theta_sum) / np.pi
adjacent_theta_score = np.array(adjacent_vec_theta) / np.pi
poly_center = np.mean(points, axis=0)
edge_dist = np.maximum(
- norm(
- pad_points[1:] - poly_center, axis=-1),
- norm(
- pad_points[:-1] - poly_center, axis=-1))
+ norm(pad_points[1:] - poly_center, axis=-1),
+ norm(pad_points[:-1] - poly_center, axis=-1),
+ )
dist_score = edge_dist / np.max(edge_dist)
position_score = np.zeros(len(edge_vec))
score = 0.5 * theta_sum_score + 0.15 * adjacent_theta_score
@@ -224,15 +219,21 @@ def find_head_tail(self, points, orientation_thr):
pad_score = np.concatenate([score, score])
score_matrix = np.zeros((len(score), len(score) - 3))
x = np.arange(len(score) - 3) / float(len(score) - 4)
- gaussian = 1. / (np.sqrt(2. * np.pi) * 0.5) * np.exp(-np.power(
- (x - 0.5) / 0.5, 2.) / 2)
+ gaussian = (
+ 1.0
+ / (np.sqrt(2.0 * np.pi) * 0.5)
+ * np.exp(-np.power((x - 0.5) / 0.5, 2.0) / 2)
+ )
gaussian = gaussian / np.max(gaussian)
for i in range(len(score)):
- score_matrix[i, :] = score[i] + pad_score[(i + 2):(i + len(
- score) - 1)] * gaussian * 0.3
-
- head_start, tail_increment = np.unravel_index(score_matrix.argmax(),
- score_matrix.shape)
+ score_matrix[i, :] = (
+ score[i]
+ + pad_score[(i + 2) : (i + len(score) - 1)] * gaussian * 0.3
+ )
+
+ head_start, tail_increment = np.unravel_index(
+ score_matrix.argmax(), score_matrix.shape
+ )
tail_start = (head_start + tail_increment + 2) % len(points)
head_end = (head_start + 1) % len(points)
tail_end = (tail_start + 1) % len(points)
@@ -243,22 +244,27 @@ def find_head_tail(self, points, orientation_thr):
head_inds = [head_start, head_end]
tail_inds = [tail_start, tail_end]
else:
- if vector_slope(points[1] - points[0]) + vector_slope(points[
- 3] - points[2]) < vector_slope(points[2] - points[
- 1]) + vector_slope(points[0] - points[3]):
+ if vector_slope(points[1] - points[0]) + vector_slope(
+ points[3] - points[2]
+ ) < vector_slope(points[2] - points[1]) + vector_slope(
+ points[0] - points[3]
+ ):
horizontal_edge_inds = [[0, 1], [2, 3]]
vertical_edge_inds = [[3, 0], [1, 2]]
else:
horizontal_edge_inds = [[3, 0], [1, 2]]
vertical_edge_inds = [[0, 1], [2, 3]]
- vertical_len_sum = norm(points[vertical_edge_inds[0][0]] - points[
- vertical_edge_inds[0][1]]) + norm(points[vertical_edge_inds[1][
- 0]] - points[vertical_edge_inds[1][1]])
- horizontal_len_sum = norm(points[horizontal_edge_inds[0][
- 0]] - points[horizontal_edge_inds[0][1]]) + norm(points[
- horizontal_edge_inds[1][0]] - points[horizontal_edge_inds[1]
- [1]])
+ vertical_len_sum = norm(
+ points[vertical_edge_inds[0][0]] - points[vertical_edge_inds[0][1]]
+ ) + norm(
+ points[vertical_edge_inds[1][0]] - points[vertical_edge_inds[1][1]]
+ )
+ horizontal_len_sum = norm(
+ points[horizontal_edge_inds[0][0]] - points[horizontal_edge_inds[0][1]]
+ ) + norm(
+ points[horizontal_edge_inds[1][0]] - points[horizontal_edge_inds[1][1]]
+ )
if vertical_len_sum > horizontal_len_sum * orientation_thr:
head_inds = horizontal_edge_inds[0]
@@ -291,14 +297,12 @@ def resample_sidelines(self, sideline1, sideline2, resample_step):
assert sideline2.shape[0] >= 2
assert isinstance(resample_step, float)
- length1 = sum([
- norm(sideline1[i + 1] - sideline1[i])
- for i in range(len(sideline1) - 1)
- ])
- length2 = sum([
- norm(sideline2[i + 1] - sideline2[i])
- for i in range(len(sideline2) - 1)
- ])
+ length1 = sum(
+ [norm(sideline1[i + 1] - sideline1[i]) for i in range(len(sideline1) - 1)]
+ )
+ length2 = sum(
+ [norm(sideline2[i + 1] - sideline2[i]) for i in range(len(sideline2) - 1)]
+ )
total_length = (length1 + length2) / 2
resample_point_num = max(int(float(total_length) / resample_step), 1)
@@ -332,39 +336,54 @@ def generate_center_region_mask(self, img_size, text_polys):
polygon_points = poly.reshape(-1, 2)
_, _, top_line, bot_line = self.reorder_poly_edge(polygon_points)
resampled_top_line, resampled_bot_line = self.resample_sidelines(
- top_line, bot_line, self.resample_step)
+ top_line, bot_line, self.resample_step
+ )
resampled_bot_line = resampled_bot_line[::-1]
if len(resampled_top_line) != len(resampled_bot_line):
continue
center_line = (resampled_top_line + resampled_bot_line) / 2
- line_head_shrink_len = norm(resampled_top_line[0] -
- resampled_bot_line[0]) / 4.0
- line_tail_shrink_len = norm(resampled_top_line[-1] -
- resampled_bot_line[-1]) / 4.0
+ line_head_shrink_len = (
+ norm(resampled_top_line[0] - resampled_bot_line[0]) / 4.0
+ )
+ line_tail_shrink_len = (
+ norm(resampled_top_line[-1] - resampled_bot_line[-1]) / 4.0
+ )
head_shrink_num = int(line_head_shrink_len // self.resample_step)
tail_shrink_num = int(line_tail_shrink_len // self.resample_step)
if len(center_line) > head_shrink_num + tail_shrink_num + 2:
- center_line = center_line[head_shrink_num:len(center_line) -
- tail_shrink_num]
- resampled_top_line = resampled_top_line[head_shrink_num:len(
- resampled_top_line) - tail_shrink_num]
- resampled_bot_line = resampled_bot_line[head_shrink_num:len(
- resampled_bot_line) - tail_shrink_num]
+ center_line = center_line[
+ head_shrink_num : len(center_line) - tail_shrink_num
+ ]
+ resampled_top_line = resampled_top_line[
+ head_shrink_num : len(resampled_top_line) - tail_shrink_num
+ ]
+ resampled_bot_line = resampled_bot_line[
+ head_shrink_num : len(resampled_bot_line) - tail_shrink_num
+ ]
for i in range(0, len(center_line) - 1):
- tl = center_line[i] + (resampled_top_line[i] - center_line[i]
- ) * self.center_region_shrink_ratio
- tr = center_line[i + 1] + (resampled_top_line[i + 1] -
- center_line[i + 1]
- ) * self.center_region_shrink_ratio
- br = center_line[i + 1] + (resampled_bot_line[i + 1] -
- center_line[i + 1]
- ) * self.center_region_shrink_ratio
- bl = center_line[i] + (resampled_bot_line[i] - center_line[i]
- ) * self.center_region_shrink_ratio
- current_center_box = np.vstack([tl, tr, br,
- bl]).astype(np.int32)
+ tl = (
+ center_line[i]
+ + (resampled_top_line[i] - center_line[i])
+ * self.center_region_shrink_ratio
+ )
+ tr = (
+ center_line[i + 1]
+ + (resampled_top_line[i + 1] - center_line[i + 1])
+ * self.center_region_shrink_ratio
+ )
+ br = (
+ center_line[i + 1]
+ + (resampled_bot_line[i + 1] - center_line[i + 1])
+ * self.center_region_shrink_ratio
+ )
+ bl = (
+ center_line[i]
+ + (resampled_bot_line[i] - center_line[i])
+ * self.center_region_shrink_ratio
+ )
+ current_center_box = np.vstack([tl, tr, br, bl]).astype(np.int32)
center_region_boxes.append(current_center_box)
cv2.fillPoly(center_region_mask, center_region_boxes, 1)
@@ -387,7 +406,7 @@ def resample_polygon(self, polygon, n=400):
p2 = polygon[0]
else:
p2 = polygon[i + 1]
- length.append(((p1[0] - p2[0])**2 + (p1[1] - p2[1])**2)**0.5)
+ length.append(((p1[0] - p2[0]) ** 2 + (p1[1] - p2[1]) ** 2) ** 0.5)
total_length = sum(length)
n_on_each_line = (np.array(length) / (total_length + 1e-8)) * n
@@ -441,7 +460,7 @@ def poly2fourier(self, polygon, fourier_degree):
"""
points = polygon[:, 0] + polygon[:, 1] * 1j
c_fft = fft(points) / len(points)
- c = np.hstack((c_fft[-fourier_degree:], c_fft[:fourier_degree + 1]))
+ c = np.hstack((c_fft[-fourier_degree:], c_fft[: fourier_degree + 1]))
return c
def clockwise(self, c, fourier_degree):
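
Note: the one-liner in poly2fourier packs FFT coefficients in the order -k .. -1, 0, 1 .. k. A standalone round-trip sketch (numpy only) showing that these 2k + 1 coefficients approximately reconstruct the contour:

    import numpy as np

    def fourier_signature(polygon, k):
        pts = polygon[:, 0] + 1j * polygon[:, 1]  # (x, y) as complex samples
        c = np.fft.fft(pts) / len(pts)
        return np.hstack((c[-k:], c[: k + 1]))    # frequencies -k..-1, 0, 1..k

    def reconstruct(c, k, n=64):
        # Evaluate sum_f c_f * exp(2j*pi*f*t) at n values of t in [0, 1).
        t = np.arange(n) / n
        freqs = np.arange(-k, k + 1)
        z = (c * np.exp(2j * np.pi * np.outer(t, freqs))).sum(axis=1)
        return np.stack([z.real, z.imag], axis=1)

    square = np.array([(0, 0), (1, 0), (1, 1), (0, 1)], dtype=np.float64)
    approx = reconstruct(fourier_signature(square, k=1), k=1)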
@@ -512,10 +531,14 @@ def generate_fourier_maps(self, img_size, text_polys):
fourier_coeff = self.cal_fourier_signature(polygon[0], k)
for i in range(-k, k + 1):
if i != 0:
- real_map[i + k, :, :] = mask * fourier_coeff[i + k, 0] + (
- 1 - mask) * real_map[i + k, :, :]
- imag_map[i + k, :, :] = mask * fourier_coeff[i + k, 1] + (
- 1 - mask) * imag_map[i + k, :, :]
+ real_map[i + k, :, :] = (
+ mask * fourier_coeff[i + k, 0]
+ + (1 - mask) * real_map[i + k, :, :]
+ )
+ imag_map[i + k, :, :] = (
+ mask * fourier_coeff[i + k, 1]
+ + (1 - mask) * imag_map[i + k, :, :]
+ )
else:
yx = np.argwhere(mask > 0.5)
k_ind = np.ones((len(yx)), dtype=np.int64) * k
@@ -607,19 +630,23 @@ def generate_level_targets(self, img_size, text_polys, ignore_polys):
level_img_size = (h // size_divisor, w // size_divisor)
text_region = self.generate_text_region_mask(
- level_img_size, lv_text_polys[ind])[None]
+ level_img_size, lv_text_polys[ind]
+ )[None]
current_level_maps.append(text_region)
center_region = self.generate_center_region_mask(
- level_img_size, lv_text_polys[ind])[None]
+ level_img_size, lv_text_polys[ind]
+ )[None]
current_level_maps.append(center_region)
effective_mask = self.generate_effective_mask(
- level_img_size, lv_ignore_polys[ind])[None]
+ level_img_size, lv_ignore_polys[ind]
+ )[None]
current_level_maps.append(effective_mask)
fourier_real_map, fourier_image_maps = self.generate_fourier_maps(
- level_img_size, lv_text_polys[ind])
+ level_img_size, lv_text_polys[ind]
+ )
current_level_maps.append(fourier_real_map)
current_level_maps.append(fourier_image_maps)
@@ -638,9 +665,9 @@ def generate_targets(self, results):
"""
assert isinstance(results, dict)
- image = results['image']
- polygons = results['polys']
- ignore_tags = results['ignore_tags']
+ image = results["image"]
+ polygons = results["polys"]
+ ignore_tags = results["ignore_tags"]
h, w, _ = image.shape
polygon_masks = []
@@ -651,13 +678,14 @@ def generate_targets(self, results):
else:
polygon_masks.append(polygon)
- level_maps = self.generate_level_targets((h, w), polygon_masks,
- polygon_masks_ignore)
+ level_maps = self.generate_level_targets(
+ (h, w), polygon_masks, polygon_masks_ignore
+ )
mapping = {
- 'p3_maps': level_maps[0],
- 'p4_maps': level_maps[1],
- 'p5_maps': level_maps[2]
+ "p3_maps": level_maps[0],
+ "p4_maps": level_maps[1],
+ "p5_maps": level_maps[2],
}
for key, value in mapping.items():
results[key] = value
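
Note: the stacked per-level maps above have fixed shapes once the defaults are known. Illustrative bookkeeping only, assuming fourier_degree k = 5, divisors 8/16/32, and a 640x640 input:

    k = 5            # fourier_degree
    h = w = 640
    for divisor in (8, 16, 32):   # -> p3_maps, p4_maps, p5_maps
        level_h, level_w = h // divisor, w // divisor
        # text region + center region + effective mask + real + imag maps
        channels = 1 + 1 + 1 + (2 * k + 1) + (2 * k + 1)
        print((channels, level_h, level_w))   # 25 channels at 80/40/20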
diff --git a/ppocr/data/imaug/iaa_augment.py b/ppocr/data/imaug/iaa_augment.py
index 0aac7877c2..396e31444a 100644
--- a/ppocr/data/imaug/iaa_augment.py
+++ b/ppocr/data/imaug/iaa_augment.py
@@ -38,15 +38,13 @@ def build(self, args, root=True):
return iaa.Sequential(sequence)
else:
return getattr(iaa, args[0])(
- *[self.to_tuple_if_list(a) for a in args[1:]])
+ *[self.to_tuple_if_list(a) for a in args[1:]]
+ )
elif isinstance(args, dict):
- cls = getattr(iaa, args['type'])
- return cls(**{
- k: self.to_tuple_if_list(v)
- for k, v in args['args'].items()
- })
+ cls = getattr(iaa, args["type"])
+ return cls(**{k: self.to_tuple_if_list(v) for k, v in args["args"].items()})
else:
- raise RuntimeError('unknown augmenter arg: ' + str(args))
+ raise RuntimeError("unknown augmenter arg: " + str(args))
def to_tuple_if_list(self, obj):
if isinstance(obj, list):
@@ -54,34 +52,23 @@ def to_tuple_if_list(self, obj):
return obj
-class IaaAugment():
+class IaaAugment:
def __init__(self, augmenter_args=None, **kwargs):
if augmenter_args is None:
- augmenter_args = [{
- 'type': 'Fliplr',
- 'args': {
- 'p': 0.5
- }
- }, {
- 'type': 'Affine',
- 'args': {
- 'rotate': [-10, 10]
- }
- }, {
- 'type': 'Resize',
- 'args': {
- 'size': [0.5, 3]
- }
- }]
+ augmenter_args = [
+ {"type": "Fliplr", "args": {"p": 0.5}},
+ {"type": "Affine", "args": {"rotate": [-10, 10]}},
+ {"type": "Resize", "args": {"size": [0.5, 3]}},
+ ]
self.augmenter = AugmenterBuilder().build(augmenter_args)
def __call__(self, data):
- image = data['image']
+ image = data["image"]
shape = image.shape
if self.augmenter:
aug = self.augmenter.to_deterministic()
- data['image'] = aug.augment_image(image)
+ data["image"] = aug.augment_image(image)
data = self.may_augment_annotation(aug, data, shape)
return data
@@ -90,16 +77,16 @@ def may_augment_annotation(self, aug, data, shape):
return data
line_polys = []
- for poly in data['polys']:
+ for poly in data["polys"]:
new_poly = self.may_augment_poly(aug, shape, poly)
line_polys.append(new_poly)
- data['polys'] = np.array(line_polys)
+ data["polys"] = np.array(line_polys)
return data
def may_augment_poly(self, aug, img_shape, poly):
keypoints = [imgaug.Keypoint(p[0], p[1]) for p in poly]
keypoints = aug.augment_keypoints(
- [imgaug.KeypointsOnImage(
- keypoints, shape=img_shape)])[0].keypoints
+ [imgaug.KeypointsOnImage(keypoints, shape=img_shape)]
+ )[0].keypoints
poly = [(p.x, p.y) for p in keypoints]
return poly
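
Note: the default augmenter_args above correspond to the following imgaug pipeline, spelled out directly (equivalent in intent; to_tuple_if_list turns the list arguments into tuple ranges):

    import imgaug.augmenters as iaa

    default_aug = iaa.Sequential(
        [
            iaa.Fliplr(p=0.5),             # horizontal flip half the time
            iaa.Affine(rotate=(-10, 10)),  # uniform rotation, in degrees
            iaa.Resize(size=(0.5, 3)),     # uniform scale factor
        ]
    )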
diff --git a/ppocr/data/imaug/label_ops.py b/ppocr/data/imaug/label_ops.py
index a3986a0c04..ffb753f799 100644
--- a/ppocr/data/imaug/label_ops.py
+++ b/ppocr/data/imaug/label_ops.py
@@ -35,11 +35,11 @@ def __init__(self, label_list, **kwargs):
self.label_list = label_list
def __call__(self, data):
- label = data['label']
+ label = data["label"]
if label not in self.label_list:
return None
label = self.label_list.index(label)
- data['label'] = label
+ data["label"] = label
return data
@@ -48,16 +48,16 @@ def __init__(self, **kwargs):
pass
def __call__(self, data):
- label = data['label']
+ label = data["label"]
label = json.loads(label)
nBox = len(label)
boxes, txts, txt_tags = [], [], []
for bno in range(0, nBox):
- box = label[bno]['points']
- txt = label[bno]['transcription']
+ box = label[bno]["points"]
+ txt = label[bno]["transcription"]
boxes.append(box)
txts.append(txt)
- if txt in ['*', '###']:
+ if txt in ["*", "###"]:
txt_tags.append(True)
else:
txt_tags.append(False)
@@ -67,9 +67,9 @@ def __call__(self, data):
boxes = np.array(boxes, dtype=np.float32)
txt_tags = np.array(txt_tags, dtype=np.bool_)
- data['polys'] = boxes
- data['texts'] = txts
- data['ignore_tags'] = txt_tags
+ data["polys"] = boxes
+ data["texts"] = txts
+ data["ignore_tags"] = txt_tags
return data
def order_points_clockwise(self, pts):
@@ -96,14 +96,15 @@ def expand_points_num(self, boxes):
class BaseRecLabelEncode(object):
- """ Convert between text-label and text-index """
-
- def __init__(self,
- max_text_length,
- character_dict_path=None,
- use_space_char=False,
- lower=False):
-
+ """Convert between text-label and text-index"""
+
+ def __init__(
+ self,
+ max_text_length,
+ character_dict_path=None,
+ use_space_char=False,
+ lower=False,
+ ):
self.max_text_len = max_text_length
self.beg_str = "sos"
self.end_str = "eos"
@@ -122,7 +123,7 @@ def __init__(self,
with open(character_dict_path, "rb") as fin:
lines = fin.readlines()
for line in lines:
- line = line.decode('utf-8').strip("\n").strip("\r\n")
+ line = line.decode("utf-8").strip("\n").strip("\r\n")
self.character_str.append(line)
if use_space_char:
self.character_str.append(" ")
@@ -163,75 +164,73 @@ def encode(self, text):
class CTCLabelEncode(BaseRecLabelEncode):
- """ Convert between text-label and text-index """
+ """Convert between text-label and text-index"""
- def __init__(self,
- max_text_length,
- character_dict_path=None,
- use_space_char=False,
- **kwargs):
+ def __init__(
+ self, max_text_length, character_dict_path=None, use_space_char=False, **kwargs
+ ):
super(CTCLabelEncode, self).__init__(
- max_text_length, character_dict_path, use_space_char)
+ max_text_length, character_dict_path, use_space_char
+ )
def __call__(self, data):
- text = data['label']
+ text = data["label"]
text = self.encode(text)
if text is None:
return None
- data['length'] = np.array(len(text))
+ data["length"] = np.array(len(text))
text = text + [0] * (self.max_text_len - len(text))
- data['label'] = np.array(text)
+ data["label"] = np.array(text)
label = [0] * len(self.character)
for x in text:
label[x] += 1
- data['label_ace'] = np.array(label)
+ data["label_ace"] = np.array(label)
return data
def add_special_char(self, dict_character):
- dict_character = ['blank'] + dict_character
+ dict_character = ["blank"] + dict_character
return dict_character
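
Note: a toy walk-through of the CTC encoding path above, with a hypothetical four-character dictionary (index 0 is the reserved 'blank'):

    import numpy as np

    character = ["blank", "a", "b", "c"]   # after add_special_char
    char2idx = {c: i for i, c in enumerate(character)}
    max_text_len = 6

    text = [char2idx[ch] for ch in "cab"]    # [3, 1, 2]
    length = np.array(len(text))             # 3
    label = np.array(text + [0] * (max_text_len - len(text)))  # [3 1 2 0 0 0]

    # label_ace: per-class counts over the padded sequence
    label_ace = np.zeros(len(character), dtype=np.int64)
    for x in label:
        label_ace[x] += 1                    # 'blank' absorbs the padding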
class E2ELabelEncodeTest(BaseRecLabelEncode):
- def __init__(self,
- max_text_length,
- character_dict_path=None,
- use_space_char=False,
- **kwargs):
+ def __init__(
+ self, max_text_length, character_dict_path=None, use_space_char=False, **kwargs
+ ):
super(E2ELabelEncodeTest, self).__init__(
- max_text_length, character_dict_path, use_space_char)
+ max_text_length, character_dict_path, use_space_char
+ )
def __call__(self, data):
import json
+
padnum = len(self.dict)
- label = data['label']
+ label = data["label"]
label = json.loads(label)
nBox = len(label)
boxes, txts, txt_tags = [], [], []
for bno in range(0, nBox):
- box = label[bno]['points']
- txt = label[bno]['transcription']
+ box = label[bno]["points"]
+ txt = label[bno]["transcription"]
boxes.append(box)
txts.append(txt)
- if txt in ['*', '###']:
+ if txt in ["*", "###"]:
txt_tags.append(True)
else:
txt_tags.append(False)
boxes = np.array(boxes, dtype=np.float32)
txt_tags = np.array(txt_tags, dtype=np.bool_)
- data['polys'] = boxes
- data['ignore_tags'] = txt_tags
+ data["polys"] = boxes
+ data["ignore_tags"] = txt_tags
temp_texts = []
for text in txts:
text = text.lower()
text = self.encode(text)
if text is None:
return None
- text = text + [padnum] * (self.max_text_len - len(text)
- ) # use 36 to pad
+ text = text + [padnum] * (self.max_text_len - len(text)) # use 36 to pad
temp_texts.append(text)
- data['texts'] = np.array(temp_texts)
+ data["texts"] = np.array(temp_texts)
return data
@@ -241,39 +240,37 @@ def __init__(self, **kwargs):
def __call__(self, data):
import json
- label = data['label']
+
+ label = data["label"]
label = json.loads(label)
nBox = len(label)
boxes, txts, txt_tags = [], [], []
for bno in range(0, nBox):
- box = label[bno]['points']
- txt = label[bno]['transcription']
+ box = label[bno]["points"]
+ txt = label[bno]["transcription"]
boxes.append(box)
txts.append(txt)
- if txt in ['*', '###']:
+ if txt in ["*", "###"]:
txt_tags.append(True)
else:
txt_tags.append(False)
boxes = np.array(boxes, dtype=np.float32)
txt_tags = np.array(txt_tags, dtype=np.bool_)
- data['polys'] = boxes
- data['texts'] = txts
- data['ignore_tags'] = txt_tags
+ data["polys"] = boxes
+ data["texts"] = txts
+ data["ignore_tags"] = txt_tags
return data
class KieLabelEncode(object):
- def __init__(self,
- character_dict_path,
- class_path,
- norm=10,
- directed=False,
- **kwargs):
+ def __init__(
+ self, character_dict_path, class_path, norm=10, directed=False, **kwargs
+ ):
super(KieLabelEncode, self).__init__()
- self.dict = dict({'': 0})
+ self.dict = dict({"": 0})
self.label2classid_map = dict()
- with open(character_dict_path, 'r', encoding='utf-8') as fr:
+ with open(character_dict_path, "r", encoding="utf-8") as fr:
idx = 1
for line in fr:
char = line.strip()
@@ -306,19 +303,19 @@ def pad_text_indices(self, text_inds):
recoder_len = max([len(text_ind) for text_ind in text_inds])
padded_text_inds = -np.ones((len(text_inds), max_len), np.int32)
for idx, text_ind in enumerate(text_inds):
- padded_text_inds[idx, :len(text_ind)] = np.array(text_ind)
+ padded_text_inds[idx, : len(text_ind)] = np.array(text_ind)
return padded_text_inds, recoder_len
def list_to_numpy(self, ann_infos):
"""Convert bboxes, relations, texts and labels to ndarray."""
- boxes, text_inds = ann_infos['points'], ann_infos['text_inds']
+ boxes, text_inds = ann_infos["points"], ann_infos["text_inds"]
boxes = np.array(boxes, np.int32)
relations, bboxes = self.compute_relation(boxes)
- labels = ann_infos.get('labels', None)
+ labels = ann_infos.get("labels", None)
if labels is not None:
labels = np.array(labels, np.int32)
- edges = ann_infos.get('edges', None)
+ edges = ann_infos.get("edges", None)
if edges is not None:
labels = labels[:, None]
edges = np.array(edges)
@@ -340,19 +337,19 @@ def list_to_numpy(self, ann_infos):
temp_padded_text_inds[:h, :] = padded_text_inds
temp_labels = np.zeros([max_num, max_num])
- temp_labels[:h, :h + 1] = labels
+ temp_labels[:h, : h + 1] = labels
tag = np.array([h, recoder_len])
return dict(
- image=ann_infos['image'],
+ image=ann_infos["image"],
points=temp_bboxes,
relations=temp_relations,
texts=temp_padded_text_inds,
labels=temp_labels,
- tag=tag)
+ tag=tag,
+ )
def convert_canonical(self, points_x, points_y):
-
assert len(points_x) == 4
assert len(points_y) == 4
@@ -382,7 +379,6 @@ def convert_canonical(self, points_x, points_y):
return sorted_points_x, sorted_points_y
def sort_vertex(self, points_x, points_y):
-
assert len(points_x) == 4
assert len(points_y) == 4
@@ -406,11 +402,12 @@ def sort_vertex(self, points_x, points_y):
def __call__(self, data):
import json
- label = data['label']
+
+ label = data["label"]
annotations = json.loads(label)
boxes, texts, text_inds, labels, edges = [], [], [], [], []
for ann in annotations:
- box = ann['points']
+ box = ann["points"]
x_list = [box[i][0] for i in range(4)]
y_list = [box[i][1] for i in range(4)]
sorted_x_list, sorted_y_list = self.sort_vertex(x_list, y_list)
@@ -419,40 +416,40 @@ def __call__(self, data):
sorted_box.append(x)
sorted_box.append(y)
boxes.append(sorted_box)
- text = ann['transcription']
- texts.append(ann['transcription'])
+ text = ann["transcription"]
+ texts.append(ann["transcription"])
text_ind = [self.dict[c] for c in text if c in self.dict]
text_inds.append(text_ind)
- if 'label' in ann.keys():
- labels.append(self.label2classid_map[ann['label']])
- elif 'key_cls' in ann.keys():
- labels.append(ann['key_cls'])
+ if "label" in ann.keys():
+ labels.append(self.label2classid_map[ann["label"]])
+ elif "key_cls" in ann.keys():
+ labels.append(ann["key_cls"])
else:
raise ValueError(
"Cannot found 'key_cls' in ann.keys(), please check your training annotation."
)
- edges.append(ann.get('edge', 0))
+ edges.append(ann.get("edge", 0))
ann_infos = dict(
- image=data['image'],
+ image=data["image"],
points=boxes,
texts=texts,
text_inds=text_inds,
edges=edges,
- labels=labels)
+ labels=labels,
+ )
return self.list_to_numpy(ann_infos)
class AttnLabelEncode(BaseRecLabelEncode):
- """ Convert between text-label and text-index """
+ """Convert between text-label and text-index"""
- def __init__(self,
- max_text_length,
- character_dict_path=None,
- use_space_char=False,
- **kwargs):
+ def __init__(
+ self, max_text_length, character_dict_path=None, use_space_char=False, **kwargs
+ ):
super(AttnLabelEncode, self).__init__(
- max_text_length, character_dict_path, use_space_char)
+ max_text_length, character_dict_path, use_space_char
+ )
def add_special_char(self, dict_character):
self.beg_str = "sos"
@@ -461,16 +458,20 @@ def add_special_char(self, dict_character):
return dict_character
def __call__(self, data):
- text = data['label']
+ text = data["label"]
text = self.encode(text)
if text is None:
return None
if len(text) >= self.max_text_len:
return None
- data['length'] = np.array(len(text))
- text = [0] + text + [len(self.character) - 1] + [0] * (self.max_text_len
- - len(text) - 2)
- data['label'] = np.array(text)
+ data["length"] = np.array(len(text))
+ text = (
+ [0]
+ + text
+ + [len(self.character) - 1]
+ + [0] * (self.max_text_len - len(text) - 2)
+ )
+ data["label"] = np.array(text)
return data
def get_ignored_tokens(self):
@@ -484,21 +485,19 @@ def get_beg_end_flag_idx(self, beg_or_end):
elif beg_or_end == "end":
idx = np.array(self.dict[self.end_str])
else:
- assert False, "Unsupport type %s in get_beg_end_flag_idx" \
- % beg_or_end
+ assert False, "Unsupport type %s in get_beg_end_flag_idx" % beg_or_end
return idx
class RFLLabelEncode(BaseRecLabelEncode):
- """ Convert between text-label and text-index """
+ """Convert between text-label and text-index"""
- def __init__(self,
- max_text_length,
- character_dict_path=None,
- use_space_char=False,
- **kwargs):
+ def __init__(
+ self, max_text_length, character_dict_path=None, use_space_char=False, **kwargs
+ ):
super(RFLLabelEncode, self).__init__(
- max_text_length, character_dict_path, use_space_char)
+ max_text_length, character_dict_path, use_space_char
+ )
def add_special_char(self, dict_character):
self.beg_str = "sos"
@@ -513,20 +512,24 @@ def encode_cnt(self, text):
return np.array(cnt_label)
def __call__(self, data):
- text = data['label']
+ text = data["label"]
text = self.encode(text)
if text is None:
return None
if len(text) >= self.max_text_len:
return None
cnt_label = self.encode_cnt(text)
- data['length'] = np.array(len(text))
- text = [0] + text + [len(self.character) - 1] + [0] * (self.max_text_len
- - len(text) - 2)
+ data["length"] = np.array(len(text))
+ text = (
+ [0]
+ + text
+ + [len(self.character) - 1]
+ + [0] * (self.max_text_len - len(text) - 2)
+ )
if len(text) != self.max_text_len:
return None
- data['label'] = np.array(text)
- data['cnt_label'] = cnt_label
+ data["label"] = np.array(text)
+ data["cnt_label"] = cnt_label
return data
def get_ignored_tokens(self):
@@ -540,71 +543,73 @@ def get_beg_end_flag_idx(self, beg_or_end):
elif beg_or_end == "end":
idx = np.array(self.dict[self.end_str])
else:
- assert False, "Unsupport type %s in get_beg_end_flag_idx" \
- % beg_or_end
+ assert False, "Unsupport type %s in get_beg_end_flag_idx" % beg_or_end
return idx
class SEEDLabelEncode(BaseRecLabelEncode):
- """ Convert between text-label and text-index """
+ """Convert between text-label and text-index"""
- def __init__(self,
- max_text_length,
- character_dict_path=None,
- use_space_char=False,
- **kwargs):
+ def __init__(
+ self, max_text_length, character_dict_path=None, use_space_char=False, **kwargs
+ ):
super(SEEDLabelEncode, self).__init__(
- max_text_length, character_dict_path, use_space_char)
+ max_text_length, character_dict_path, use_space_char
+ )
def add_special_char(self, dict_character):
self.padding = "padding"
self.end_str = "eos"
self.unknown = "unknown"
- dict_character = dict_character + [
- self.end_str, self.padding, self.unknown
- ]
+ dict_character = dict_character + [self.end_str, self.padding, self.unknown]
return dict_character
def __call__(self, data):
- text = data['label']
+ text = data["label"]
text = self.encode(text)
if text is None:
return None
if len(text) >= self.max_text_len:
return None
- data['length'] = np.array(len(text)) + 1 # include eos
- text = text + [len(self.character) - 3] + [len(self.character) - 2] * (
- self.max_text_len - len(text) - 1)
- data['label'] = np.array(text)
+ data["length"] = np.array(len(text)) + 1 # conclude eos
+ text = (
+ text
+ + [len(self.character) - 3]
+ + [len(self.character) - 2] * (self.max_text_len - len(text) - 1)
+ )
+ data["label"] = np.array(text)
return data
class SRNLabelEncode(BaseRecLabelEncode):
- """ Convert between text-label and text-index """
-
- def __init__(self,
- max_text_length=25,
- character_dict_path=None,
- use_space_char=False,
- **kwargs):
+ """Convert between text-label and text-index"""
+
+ def __init__(
+ self,
+ max_text_length=25,
+ character_dict_path=None,
+ use_space_char=False,
+ **kwargs
+ ):
super(SRNLabelEncode, self).__init__(
- max_text_length, character_dict_path, use_space_char)
+ max_text_length, character_dict_path, use_space_char
+ )
def add_special_char(self, dict_character):
dict_character = dict_character + [self.beg_str, self.end_str]
return dict_character
def __call__(self, data):
- text = data['label']
+ text = data["label"]
text = self.encode(text)
char_num = len(self.character)
if text is None:
return None
if len(text) > self.max_text_len:
return None
- data['length'] = np.array(len(text))
+ data["length"] = np.array(len(text))
text = text + [char_num - 1] * (self.max_text_len - len(text))
- data['label'] = np.array(text)
+ data["label"] = np.array(text)
return data
def get_ignored_tokens(self):
@@ -618,22 +623,23 @@ def get_beg_end_flag_idx(self, beg_or_end):
elif beg_or_end == "end":
idx = np.array(self.dict[self.end_str])
else:
- assert False, "Unsupport type %s in get_beg_end_flag_idx" \
- % beg_or_end
+ assert False, "Unsupport type %s in get_beg_end_flag_idx" % beg_or_end
return idx
class TableLabelEncode(AttnLabelEncode):
- """ Convert between text-label and text-index """
-
- def __init__(self,
- max_text_length,
- character_dict_path,
- replace_empty_cell_token=False,
- merge_no_span_structure=False,
- learn_empty_box=False,
- loc_reg_num=4,
- **kwargs):
+ """Convert between text-label and text-index"""
+
+ def __init__(
+ self,
+ max_text_length,
+ character_dict_path,
+ replace_empty_cell_token=False,
+ merge_no_span_structure=False,
+ learn_empty_box=False,
+ loc_reg_num=4,
+ **kwargs
+ ):
self.max_text_len = max_text_length
self.lower = False
self.learn_empty_box = learn_empty_box
@@ -644,7 +650,7 @@ def __init__(self,
with open(character_dict_path, "rb") as fin:
lines = fin.readlines()
for line in lines:
- line = line.decode('utf-8').strip("\n").strip("\r\n")
+ line = line.decode("utf-8").strip("\n").strip("\r\n")
dict_character.append(line)
if self.merge_no_span_structure:
@@ -665,20 +671,19 @@ def __init__(self,
self.start_idx = self.dict[self.beg_str]
self.end_idx = self.dict[self.end_str]
- self.td_token = ['<td>', '<td', '<td></td>']
+ self.td_token = ["<td>", "<td", "<td></td>"]
self.empty_bbox_token_dict = {
- "[]": ' ',
- "[' ']": ' ',
- "[' ', ' ', '']": ' ',
- "['\\u2028', '\\u2028']": ' ',
- "[' ', ' ', '']": ' ',
- "[' ', '']": ' ',
- "[' ', ' ', '']": ' ',
- "[' ', '', '', '']": ' ',
- "[' ', '', ' ', '', '']": ' ',
- "[' ', '']": ' ',
- "[' ', ' ', '\\u2028', ' ', '\\u2028', ' ', '']":
- ' ',
+ "[]": " ",
+ "[' ']": " ",
+ "[' ', ' ', '']": " ",
+ "['\\u2028', '\\u2028']": " ",
+ "[' ', ' ', '']": " ",
+ "[' ', '']": " ",
+ "[' ', ' ', '']": " ",
+ "[' ', '', '', '']": " ",
+ "[' ', '', ' ', '', '']": " ",
+ "[' ', '']": " ",
+ "[' ', ' ', '\\u2028', ' ', '\\u2028', ' ', '']": " ",
}
@property
@@ -686,8 +691,8 @@ def _max_text_len(self):
return self.max_text_len + 2
def __call__(self, data):
- cells = data['cells']
- structure = data['structure']
+ cells = data["cells"]
+ structure = data["structure"]
if self.merge_no_span_structure:
structure = self._merge_no_span_structure(structure)
if self.replace_empty_cell_token:
@@ -695,45 +700,43 @@ def __call__(self, data):
# remove empty token and add " " to span token
new_structure = []
for token in structure:
- if token != '':
- if 'span' in token and token[0] != ' ':
- token = ' ' + token
+ if token != "":
+ if "span" in token and token[0] != " ":
+ token = " " + token
new_structure.append(token)
# encode structure
structure = self.encode(new_structure)
if structure is None:
return None
- structure = [self.start_idx] + structure + [self.end_idx
- ] # add sos and eos
- structure = structure + [self.pad_idx] * (self._max_text_len -
- len(structure)) # pad
+ structure = [self.start_idx] + structure + [self.end_idx] # add sos and eos
+ structure = structure + [self.pad_idx] * (
+ self._max_text_len - len(structure)
+ ) # pad
structure = np.array(structure)
- data['structure'] = structure
+ data["structure"] = structure
if len(structure) > self._max_text_len:
return None
# encode box
- bboxes = np.zeros(
- (self._max_text_len, self.loc_reg_num), dtype=np.float32)
+ bboxes = np.zeros((self._max_text_len, self.loc_reg_num), dtype=np.float32)
bbox_masks = np.zeros((self._max_text_len, 1), dtype=np.float32)
bbox_idx = 0
for i, token in enumerate(structure):
if self.idx2char[token] in self.td_token:
- if 'bbox' in cells[bbox_idx] and len(cells[bbox_idx][
- 'tokens']) > 0:
- bbox = cells[bbox_idx]['bbox'].copy()
+ if "bbox" in cells[bbox_idx] and len(cells[bbox_idx]["tokens"]) > 0:
+ bbox = cells[bbox_idx]["bbox"].copy()
bbox = np.array(bbox, dtype=np.float32).reshape(-1)
bboxes[i] = bbox
bbox_masks[i] = 1.0
if self.learn_empty_box:
bbox_masks[i] = 1.0
bbox_idx += 1
- data['bboxes'] = bboxes
- data['bbox_masks'] = bbox_masks
+ data["bboxes"] = bboxes
+ data["bbox_masks"] = bbox_masks
return data
def _merge_no_span_structure(self, structure):
@@ -745,8 +748,8 @@ def _merge_no_span_structure(self, structure):
i = 0
while i < len(structure):
token = structure[i]
- if token == '<td>':
- token = '<td></td>'
+ if token == "<td>":
+ token = "<td></td>"
i += 1
new_structure.append(token)
i += 1
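
Note: a toy trace of the merge loop above; each '<td>' token and its paired '</td>' collapse into the single '<td></td>' structure token, while span cells ('<td' plus attribute tokens) pass through unchanged:

    structure = ["<tr>", "<td>", "</td>", "<td", ' colspan="2"', ">", "</td>", "</tr>"]
    new_structure, i = [], 0
    while i < len(structure):
        token = structure[i]
        if token == "<td>":
            token = "<td></td>"
            i += 1                  # skip the paired '</td>'
        new_structure.append(token)
        i += 1
    # ['<tr>', '<td></td>', '<td', ' colspan="2"', '>', '</td>', '</tr>']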
@@ -761,9 +764,9 @@ def _replace_empty_cell_token(self, token_list, cells):
bbox_idx = 0
add_empty_bbox_token_list = []
for token in token_list:
- if token in ['<td></td>', '<td']:
- if 'bbox' not in cells[bbox_idx].keys():
- content = str(cells[bbox_idx]['tokens'])
+ if token in ["<td></td>", "<td"]:
+ if "bbox" not in cells[bbox_idx].keys():
+ content = str(cells[bbox_idx]["tokens"])
token = self.empty_bbox_token_dict[content]
add_empty_bbox_token_list.append(token)
bbox_idx += 1
@@ -773,19 +776,27 @@ def _replace_empty_cell_token(self, token_list, cells):
class TableMasterLabelEncode(TableLabelEncode):
- """ Convert between text-label and text-index """
-
- def __init__(self,
- max_text_length,
- character_dict_path,
- replace_empty_cell_token=False,
- merge_no_span_structure=False,
- learn_empty_box=False,
- loc_reg_num=4,
- **kwargs):
+ """Convert between text-label and text-index"""
+
+ def __init__(
+ self,
+ max_text_length,
+ character_dict_path,
+ replace_empty_cell_token=False,
+ merge_no_span_structure=False,
+ learn_empty_box=False,
+ loc_reg_num=4,
+ **kwargs
+ ):
super(TableMasterLabelEncode, self).__init__(
- max_text_length, character_dict_path, replace_empty_cell_token,
- merge_no_span_structure, learn_empty_box, loc_reg_num, **kwargs)
+ max_text_length,
+ character_dict_path,
+ replace_empty_cell_token,
+ merge_no_span_structure,
+ learn_empty_box,
+ loc_reg_num,
+ **kwargs
+ )
self.pad_idx = self.dict[self.pad_str]
self.unknown_idx = self.dict[self.unknown_str]
@@ -794,36 +805,39 @@ def _max_text_len(self):
return self.max_text_len
def add_special_char(self, dict_character):
- self.beg_str = '<SOS>'
- self.end_str = '<EOS>'
- self.unknown_str = '<UKN>'
- self.pad_str = '<PAD>'
+ self.beg_str = "<SOS>"
+ self.end_str = "<EOS>"
+ self.unknown_str = "<UKN>"
+ self.pad_str = "<PAD>"
dict_character = dict_character
dict_character = dict_character + [
- self.unknown_str, self.beg_str, self.end_str, self.pad_str
+ self.unknown_str,
+ self.beg_str,
+ self.end_str,
+ self.pad_str,
]
return dict_character
class TableBoxEncode(object):
- def __init__(self, in_box_format='xyxy', out_box_format='xyxy', **kwargs):
- assert out_box_format in ['xywh', 'xyxy', 'xyxyxyxy']
+ def __init__(self, in_box_format="xyxy", out_box_format="xyxy", **kwargs):
+ assert out_box_format in ["xywh", "xyxy", "xyxyxyxy"]
self.in_box_format = in_box_format
self.out_box_format = out_box_format
def __call__(self, data):
- img_height, img_width = data['image'].shape[:2]
- bboxes = data['bboxes']
+ img_height, img_width = data["image"].shape[:2]
+ bboxes = data["bboxes"]
if self.in_box_format != self.out_box_format:
- if self.out_box_format == 'xywh':
- if self.in_box_format == 'xyxyxyxy':
+ if self.out_box_format == "xywh":
+ if self.in_box_format == "xyxyxyxy":
bboxes = self.xyxyxyxy2xywh(bboxes)
- elif self.in_box_format == 'xyxy':
+ elif self.in_box_format == "xyxy":
bboxes = self.xyxy2xywh(bboxes)
bboxes[:, 0::2] /= img_width
bboxes[:, 1::2] /= img_height
- data['bboxes'] = bboxes
+ data["bboxes"] = bboxes
return data
def xyxyxyxy2xywh(self, boxes):
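
Note: for TableBoxEncode, the usual corner-to-center conversion behind the 'xywh' branch, together with the normalisation applied right after it (sketch; the in-file helper may differ in detail):

    import numpy as np

    def xyxy2xywh(bboxes):
        # (x1, y1, x2, y2) -> (cx, cy, w, h)
        out = np.empty_like(bboxes)
        out[:, 0] = (bboxes[:, 0] + bboxes[:, 2]) / 2
        out[:, 1] = (bboxes[:, 1] + bboxes[:, 3]) / 2
        out[:, 2] = bboxes[:, 2] - bboxes[:, 0]
        out[:, 3] = bboxes[:, 3] - bboxes[:, 1]
        return out

    boxes = np.array([[10.0, 20.0, 110.0, 70.0]])
    img_w, img_h = 200, 100
    xywh = xyxy2xywh(boxes)
    xywh[:, 0::2] /= img_w   # x-channels normalised by width
    xywh[:, 1::2] /= img_h   # y-channels normalised by height
    # [[0.3, 0.45, 0.5, 0.5]]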
@@ -844,15 +858,14 @@ def xyxy2xywh(self, bboxes):
class SARLabelEncode(BaseRecLabelEncode):
- """ Convert between text-label and text-index """
+ """Convert between text-label and text-index"""
- def __init__(self,
- max_text_length,
- character_dict_path=None,
- use_space_char=False,
- **kwargs):
+ def __init__(
+ self, max_text_length, character_dict_path=None, use_space_char=False, **kwargs
+ ):
super(SARLabelEncode, self).__init__(
- max_text_length, character_dict_path, use_space_char)
+ max_text_length, character_dict_path, use_space_char
+ )
def add_special_char(self, dict_character):
beg_end_str = "<BOS/EOS>"
@@ -869,18 +882,18 @@ def add_special_char(self, dict_character):
return dict_character
def __call__(self, data):
- text = data['label']
+ text = data["label"]
text = self.encode(text)
if text is None:
return None
if len(text) >= self.max_text_len - 1:
return None
- data['length'] = np.array(len(text))
+ data["length"] = np.array(len(text))
target = [self.start_idx] + text + [self.end_idx]
padded_text = [self.padding_idx for _ in range(self.max_text_len)]
- padded_text[:len(target)] = target
- data['label'] = np.array(padded_text)
+ padded_text[: len(target)] = target
+ data["label"] = np.array(padded_text)
return data
def get_ignored_tokens(self):
@@ -888,16 +901,19 @@ def get_ignored_tokens(self):
class SATRNLabelEncode(BaseRecLabelEncode):
- """ Convert between text-label and text-index """
-
- def __init__(self,
- max_text_length,
- character_dict_path=None,
- use_space_char=False,
- lower=False,
- **kwargs):
+ """Convert between text-label and text-index"""
+
+ def __init__(
+ self,
+ max_text_length,
+ character_dict_path=None,
+ use_space_char=False,
+ lower=False,
+ **kwargs
+ ):
super(SATRNLabelEncode, self).__init__(
- max_text_length, character_dict_path, use_space_char)
+ max_text_length, character_dict_path, use_space_char
+ )
self.lower = lower
def add_special_char(self, dict_character):
@@ -925,18 +941,18 @@ def encode(self, text):
return text_list
def __call__(self, data):
- text = data['label']
+ text = data["label"]
text = self.encode(text)
if text is None:
return None
- data['length'] = np.array(len(text))
+ data["length"] = np.array(len(text))
target = [self.start_idx] + text + [self.end_idx]
padded_text = [self.padding_idx for _ in range(self.max_text_len)]
if len(target) > self.max_text_len:
- padded_text = target[:self.max_text_len]
+ padded_text = target[: self.max_text_len]
else:
- padded_text[:len(target)] = target
- data['label'] = np.array(padded_text)
+ padded_text[: len(target)] = target
+ data["label"] = np.array(padded_text)
return data
def get_ignored_tokens(self):
@@ -944,18 +960,17 @@ def get_ignored_tokens(self):
class PRENLabelEncode(BaseRecLabelEncode):
- def __init__(self,
- max_text_length,
- character_dict_path,
- use_space_char=False,
- **kwargs):
+ def __init__(
+ self, max_text_length, character_dict_path, use_space_char=False, **kwargs
+ ):
super(PRENLabelEncode, self).__init__(
- max_text_length, character_dict_path, use_space_char)
+ max_text_length, character_dict_path, use_space_char
+ )
def add_special_char(self, dict_character):
- padding_str = '<PAD>' # 0
- end_str = '<EOS>' # 1
- unknown_str = '<UNK>' # 2
+ padding_str = "<PAD>" # 0
+ end_str = "<EOS>" # 1
+ unknown_str = "<UNK>" # 2
dict_character = [padding_str, end_str, unknown_str] + dict_character
self.padding_idx = 0
@@ -977,16 +992,15 @@ def encode(self, text):
text_list.append(self.dict[char])
text_list.append(self.end_idx)
if len(text_list) < self.max_text_len:
- text_list += [self.padding_idx] * (
- self.max_text_len - len(text_list))
+ text_list += [self.padding_idx] * (self.max_text_len - len(text_list))
return text_list
def __call__(self, data):
- text = data['label']
+ text = data["label"]
encoded_text = self.encode(text)
if encoded_text is None:
return None
- data['label'] = np.array(encoded_text)
+ data["label"] = np.array(encoded_text)
return data
@@ -995,37 +1009,45 @@ class VQATokenLabelEncode(object):
Label encode for NLP VQA methods
"""
- def __init__(self,
- class_path,
- contains_re=False,
- add_special_ids=False,
- algorithm='LayoutXLM',
- use_textline_bbox_info=True,
- order_method=None,
- infer_mode=False,
- ocr_engine=None,
- **kwargs):
+ def __init__(
+ self,
+ class_path,
+ contains_re=False,
+ add_special_ids=False,
+ algorithm="LayoutXLM",
+ use_textline_bbox_info=True,
+ order_method=None,
+ infer_mode=False,
+ ocr_engine=None,
+ **kwargs
+ ):
super(VQATokenLabelEncode, self).__init__()
- from paddlenlp.transformers import LayoutXLMTokenizer, LayoutLMTokenizer, LayoutLMv2Tokenizer
+ from paddlenlp.transformers import (
+ LayoutXLMTokenizer,
+ LayoutLMTokenizer,
+ LayoutLMv2Tokenizer,
+ )
from ppocr.utils.utility import load_vqa_bio_label_maps
+
tokenizer_dict = {
- 'LayoutXLM': {
- 'class': LayoutXLMTokenizer,
- 'pretrained_model': 'layoutxlm-base-uncased'
+ "LayoutXLM": {
+ "class": LayoutXLMTokenizer,
+ "pretrained_model": "layoutxlm-base-uncased",
+ },
+ "LayoutLM": {
+ "class": LayoutLMTokenizer,
+ "pretrained_model": "layoutlm-base-uncased",
},
- 'LayoutLM': {
- 'class': LayoutLMTokenizer,
- 'pretrained_model': 'layoutlm-base-uncased'
+ "LayoutLMv2": {
+ "class": LayoutLMv2Tokenizer,
+ "pretrained_model": "layoutlmv2-base-uncased",
},
- 'LayoutLMv2': {
- 'class': LayoutLMv2Tokenizer,
- 'pretrained_model': 'layoutlmv2-base-uncased'
- }
}
self.contains_re = contains_re
tokenizer_config = tokenizer_dict[algorithm]
- self.tokenizer = tokenizer_config['class'].from_pretrained(
- tokenizer_config['pretrained_model'])
+ self.tokenizer = tokenizer_config["class"].from_pretrained(
+ tokenizer_config["pretrained_model"]
+ )
self.label2id_map, id2label_map = load_vqa_bio_label_maps(class_path)
self.add_special_ids = add_special_ids
self.infer_mode = infer_mode
@@ -1074,8 +1096,7 @@ def __call__(self, data):
for idx in range(len(ocr_info)):
if "bbox" not in ocr_info[idx]:
- ocr_info[idx]["bbox"] = self.trans_poly_to_bbox(ocr_info[idx][
- "points"])
+ ocr_info[idx]["bbox"] = self.trans_poly_to_bbox(ocr_info[idx]["points"])
if self.order_method == "tb-yx":
ocr_info = order_by_tbyx(ocr_info)
@@ -1085,7 +1106,7 @@ def __call__(self, data):
if train_re:
ocr_info = self.filter_empty_contents(ocr_info)
- height, width, _ = data['image'].shape
+ height, width, _ = data["image"].shape
words_list = []
bbox_list = []
@@ -1102,7 +1123,7 @@ def __call__(self, data):
entity_id_to_index_map = {}
empty_entity = set()
- data['ocr_info'] = copy.deepcopy(ocr_info)
+ data["ocr_info"] = copy.deepcopy(ocr_info)
for info in ocr_info:
text = info["transcription"]
@@ -1122,21 +1143,21 @@ def __call__(self, data):
text,
pad_to_max_seq_len=False,
return_attention_mask=True,
- return_token_type_ids=True)
+ return_token_type_ids=True,
+ )
if not self.add_special_ids:
# TODO: use tok.all_special_ids to remove
encode_res["input_ids"] = encode_res["input_ids"][1:-1]
- encode_res["token_type_ids"] = encode_res["token_type_ids"][1:
- -1]
- encode_res["attention_mask"] = encode_res["attention_mask"][1:
- -1]
+ encode_res["token_type_ids"] = encode_res["token_type_ids"][1:-1]
+ encode_res["attention_mask"] = encode_res["attention_mask"][1:-1]
if self.use_textline_bbox_info:
bbox = [info["bbox"]] * len(encode_res["input_ids"])
else:
- bbox = self.split_bbox(info["bbox"], info["transcription"],
- self.tokenizer)
+ bbox = self.split_bbox(
+ info["bbox"], info["transcription"], self.tokenizer
+ )
if len(bbox) <= 0:
continue
bbox = self._smooth_box(bbox, height, width)
@@ -1146,7 +1167,7 @@ def __call__(self, data):
# parse label
if not self.infer_mode:
- label = info['label']
+ label = info["label"]
gt_label = self._parse_label(label, encode_res)
# construct entities for re
@@ -1154,18 +1175,21 @@ def __call__(self, data):
if gt_label[0] != self.label2id_map["O"]:
entity_id_to_index_map[info["id"]] = len(entities)
label = label.upper()
- entities.append({
- "start": len(input_ids_list),
- "end":
- len(input_ids_list) + len(encode_res["input_ids"]),
- "label": label.upper(),
- })
+ entities.append(
+ {
+ "start": len(input_ids_list),
+ "end": len(input_ids_list) + len(encode_res["input_ids"]),
+ "label": label.upper(),
+ }
+ )
else:
- entities.append({
- "start": len(input_ids_list),
- "end": len(input_ids_list) + len(encode_res["input_ids"]),
- "label": 'O',
- })
+ entities.append(
+ {
+ "start": len(input_ids_list),
+ "end": len(input_ids_list) + len(encode_res["input_ids"]),
+ "label": "O",
+ }
+ )
input_ids_list.extend(encode_res["input_ids"])
token_type_ids_list.extend(encode_res["token_type_ids"])
bbox_list.extend(bbox)
@@ -1174,23 +1198,24 @@ def __call__(self, data):
if not self.infer_mode:
gt_label_list.extend(gt_label)
- data['input_ids'] = input_ids_list
- data['token_type_ids'] = token_type_ids_list
- data['bbox'] = bbox_list
- data['attention_mask'] = [1] * len(input_ids_list)
- data['labels'] = gt_label_list
- data['segment_offset_id'] = segment_offset_id
- data['tokenizer_params'] = dict(
+ data["input_ids"] = input_ids_list
+ data["token_type_ids"] = token_type_ids_list
+ data["bbox"] = bbox_list
+ data["attention_mask"] = [1] * len(input_ids_list)
+ data["labels"] = gt_label_list
+ data["segment_offset_id"] = segment_offset_id
+ data["tokenizer_params"] = dict(
padding_side=self.tokenizer.padding_side,
pad_token_type_id=self.tokenizer.pad_token_type_id,
- pad_token_id=self.tokenizer.pad_token_id)
- data['entities'] = entities
+ pad_token_id=self.tokenizer.pad_token_id,
+ )
+ data["entities"] = entities
if train_re:
- data['relations'] = relations
- data['id2label'] = id2label
- data['empty_entity'] = empty_entity
- data['entity_id_to_index_map'] = entity_id_to_index_map
+ data["relations"] = relations
+ data["id2label"] = id2label
+ data["empty_entity"] = empty_entity
+ data["entity_id_to_index_map"] = entity_id_to_index_map
return data
def trans_poly_to_bbox(self, poly):
@@ -1202,17 +1227,19 @@ def trans_poly_to_bbox(self, poly):
def _load_ocr_info(self, data):
if self.infer_mode:
- ocr_result = self.ocr_engine.ocr(data['image'], cls=False)[0]
+ ocr_result = self.ocr_engine.ocr(data["image"], cls=False)[0]
ocr_info = []
for res in ocr_result:
- ocr_info.append({
- "transcription": res[1][0],
- "bbox": self.trans_poly_to_bbox(res[0]),
- "points": res[0],
- })
+ ocr_info.append(
+ {
+ "transcription": res[1][0],
+ "bbox": self.trans_poly_to_bbox(res[0]),
+ "points": res[0],
+ }
+ )
return ocr_info
else:
- info = data['label']
+ info = data["label"]
# read text info
info_dict = json.loads(info)
return info_dict
@@ -1232,107 +1259,112 @@ def _parse_label(self, label, encode_res):
gt_label.extend([0] * len(encode_res["input_ids"]))
else:
gt_label.append(self.label2id_map[("b-" + label).upper()])
- gt_label.extend([self.label2id_map[("i-" + label).upper()]] *
- (len(encode_res["input_ids"]) - 1))
+ gt_label.extend(
+ [self.label2id_map[("i-" + label).upper()]]
+ * (len(encode_res["input_ids"]) - 1)
+ )
return gt_label
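
Note: a toy expansion of _parse_label above; the first sub-token of an entity takes the B- tag and the remaining sub-tokens take I- (the label map here is a stand-in):

    label2id_map = {"O": 0, "B-QUESTION": 1, "I-QUESTION": 2}
    label, num_subtokens = "question", 4

    gt_label = [label2id_map[("b-" + label).upper()]]
    gt_label += [label2id_map[("i-" + label).upper()]] * (num_subtokens - 1)
    # gt_label == [1, 2, 2, 2]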
class MultiLabelEncode(BaseRecLabelEncode):
- def __init__(self,
- max_text_length,
- character_dict_path=None,
- use_space_char=False,
- gtc_encode=None,
- **kwargs):
+ def __init__(
+ self,
+ max_text_length,
+ character_dict_path=None,
+ use_space_char=False,
+ gtc_encode=None,
+ **kwargs
+ ):
super(MultiLabelEncode, self).__init__(
- max_text_length, character_dict_path, use_space_char)
+ max_text_length, character_dict_path, use_space_char
+ )
- self.ctc_encode = CTCLabelEncode(max_text_length, character_dict_path,
- use_space_char, **kwargs)
+ self.ctc_encode = CTCLabelEncode(
+ max_text_length, character_dict_path, use_space_char, **kwargs
+ )
self.gtc_encode_type = gtc_encode
if gtc_encode is None:
self.gtc_encode = SARLabelEncode(
- max_text_length, character_dict_path, use_space_char, **kwargs)
+ max_text_length, character_dict_path, use_space_char, **kwargs
+ )
else:
self.gtc_encode = eval(gtc_encode)(
- max_text_length, character_dict_path, use_space_char, **kwargs)
+ max_text_length, character_dict_path, use_space_char, **kwargs
+ )
def __call__(self, data):
data_ctc = copy.deepcopy(data)
data_gtc = copy.deepcopy(data)
data_out = dict()
- data_out['img_path'] = data.get('img_path', None)
- data_out['image'] = data['image']
+ data_out["img_path"] = data.get("img_path", None)
+ data_out["image"] = data["image"]
ctc = self.ctc_encode.__call__(data_ctc)
gtc = self.gtc_encode.__call__(data_gtc)
if ctc is None or gtc is None:
return None
- data_out['label_ctc'] = ctc['label']
+ data_out["label_ctc"] = ctc["label"]
if self.gtc_encode_type is not None:
- data_out['label_gtc'] = gtc['label']
+ data_out["label_gtc"] = gtc["label"]
else:
- data_out['label_sar'] = gtc['label']
- data_out['length'] = ctc['length']
+ data_out["label_sar"] = gtc["label"]
+ data_out["length"] = ctc["length"]
return data_out
class NRTRLabelEncode(BaseRecLabelEncode):
- """ Convert between text-label and text-index """
-
- def __init__(self,
- max_text_length,
- character_dict_path=None,
- use_space_char=False,
- **kwargs):
+ """Convert between text-label and text-index"""
+ def __init__(
+ self, max_text_length, character_dict_path=None, use_space_char=False, **kwargs
+ ):
super(NRTRLabelEncode, self).__init__(
- max_text_length, character_dict_path, use_space_char)
+ max_text_length, character_dict_path, use_space_char
+ )
def __call__(self, data):
- text = data['label']
+ text = data["label"]
text = self.encode(text)
if text is None:
return None
if len(text) >= self.max_text_len - 1:
return None
- data['length'] = np.array(len(text))
+ data["length"] = np.array(len(text))
text.insert(0, 2)
text.append(3)
text = text + [0] * (self.max_text_len - len(text))
- data['label'] = np.array(text)
+ data["label"] = np.array(text)
return data
def add_special_char(self, dict_character):
- dict_character = ['blank', '<unk>', '<s>', '</s>'] + dict_character
+ dict_character = ["blank", "<unk>", "<s>", "</s>"] + dict_character
return dict_character
class ParseQLabelEncode(BaseRecLabelEncode):
- """ Convert between text-label and text-index """
- BOS = '[B]'
- EOS = '[E]'
- PAD = '[P]'
+ """Convert between text-label and text-index"""
- def __init__(self,
- max_text_length,
- character_dict_path=None,
- use_space_char=False,
- **kwargs):
+ BOS = "[B]"
+ EOS = "[E]"
+ PAD = "[P]"
+ def __init__(
+ self, max_text_length, character_dict_path=None, use_space_char=False, **kwargs
+ ):
super(ParseQLabelEncode, self).__init__(
- max_text_length, character_dict_path, use_space_char)
+ max_text_length, character_dict_path, use_space_char
+ )
def __call__(self, data):
- text = data['label']
+ text = data["label"]
text = self.encode(text)
if text is None:
return None
if len(text) >= self.max_text_len - 2:
return None
- data['length'] = np.array(len(text))
+ data["length"] = np.array(len(text))
text = [self.dict[self.BOS]] + text + [self.dict[self.EOS]]
text = text + [self.dict[self.PAD]] * (self.max_text_len - len(text))
- data['label'] = np.array(text)
+ data["label"] = np.array(text)
return data
def add_special_char(self, dict_character):
@@ -1341,97 +1373,100 @@ def add_special_char(self, dict_character):
class ViTSTRLabelEncode(BaseRecLabelEncode):
- """ Convert between text-label and text-index """
-
- def __init__(self,
- max_text_length,
- character_dict_path=None,
- use_space_char=False,
- ignore_index=0,
- **kwargs):
-
+ """Convert between text-label and text-index"""
+
+ def __init__(
+ self,
+ max_text_length,
+ character_dict_path=None,
+ use_space_char=False,
+ ignore_index=0,
+ **kwargs
+ ):
super(ViTSTRLabelEncode, self).__init__(
- max_text_length, character_dict_path, use_space_char)
+ max_text_length, character_dict_path, use_space_char
+ )
self.ignore_index = ignore_index
def __call__(self, data):
- text = data['label']
+ text = data["label"]
text = self.encode(text)
if text is None:
return None
if len(text) >= self.max_text_len:
return None
- data['length'] = np.array(len(text))
+ data["length"] = np.array(len(text))
text.insert(0, self.ignore_index)
text.append(1)
text = text + [self.ignore_index] * (self.max_text_len + 2 - len(text))
- data['label'] = np.array(text)
+ data["label"] = np.array(text)
return data
def add_special_char(self, dict_character):
- dict_character = ['<s>', '</s>'] + dict_character
+ dict_character = ["<s>", "</s>"] + dict_character
return dict_character
class ABINetLabelEncode(BaseRecLabelEncode):
- """ Convert between text-label and text-index """
-
- def __init__(self,
- max_text_length,
- character_dict_path=None,
- use_space_char=False,
- ignore_index=100,
- **kwargs):
-
+ """Convert between text-label and text-index"""
+
+ def __init__(
+ self,
+ max_text_length,
+ character_dict_path=None,
+ use_space_char=False,
+ ignore_index=100,
+ **kwargs
+ ):
super(ABINetLabelEncode, self).__init__(
- max_text_length, character_dict_path, use_space_char)
+ max_text_length, character_dict_path, use_space_char
+ )
self.ignore_index = ignore_index
def __call__(self, data):
- text = data['label']
+ text = data["label"]
text = self.encode(text)
if text is None:
return None
if len(text) >= self.max_text_len:
return None
- data['length'] = np.array(len(text))
+ data["length"] = np.array(len(text))
text.append(0)
text = text + [self.ignore_index] * (self.max_text_len + 1 - len(text))
- data['label'] = np.array(text)
+ data["label"] = np.array(text)
return data
def add_special_char(self, dict_character):
- dict_character = ['</s>'] + dict_character
+ dict_character = ["</s>"] + dict_character
return dict_character
class SRLabelEncode(BaseRecLabelEncode):
- def __init__(self,
- max_text_length,
- character_dict_path=None,
- use_space_char=False,
- **kwargs):
- super(SRLabelEncode, self).__init__(max_text_length,
- character_dict_path, use_space_char)
+ def __init__(
+ self, max_text_length, character_dict_path=None, use_space_char=False, **kwargs
+ ):
+ super(SRLabelEncode, self).__init__(
+ max_text_length, character_dict_path, use_space_char
+ )
self.dic = {}
- with open(character_dict_path, 'r') as fin:
+ with open(character_dict_path, "r") as fin:
for line in fin.readlines():
line = line.strip()
character, sequence = line.split()
self.dic[character] = sequence
- english_stroke_alphabet = '0123456789'
+ english_stroke_alphabet = "0123456789"
self.english_stroke_dict = {}
for index in range(len(english_stroke_alphabet)):
self.english_stroke_dict[english_stroke_alphabet[index]] = index
def encode(self, label):
- stroke_sequence = ''
+ stroke_sequence = ""
for character in label:
if character not in self.dic:
continue
else:
stroke_sequence += self.dic[character]
- stroke_sequence += '0'
+ stroke_sequence += "0"
label = stroke_sequence
length = len(label)
@@ -1443,7 +1478,7 @@ def encode(self, label):
return length, input_tensor
def __call__(self, data):
- text = data['label']
+ text = data["label"]
length, input_tensor = self.encode(text)
data["length"] = length
@@ -1454,16 +1489,19 @@ def __call__(self, data):
class SPINLabelEncode(AttnLabelEncode):
- """ Convert between text-label and text-index """
-
- def __init__(self,
- max_text_length,
- character_dict_path=None,
- use_space_char=False,
- lower=True,
- **kwargs):
+ """Convert between text-label and text-index"""
+
+ def __init__(
+ self,
+ max_text_length,
+ character_dict_path=None,
+ use_space_char=False,
+ lower=True,
+ **kwargs
+ ):
super(SPINLabelEncode, self).__init__(
- max_text_length, character_dict_path, use_space_char)
+ max_text_length, character_dict_path, use_space_char
+ )
self.lower = lower
def add_special_char(self, dict_character):
@@ -1473,37 +1511,36 @@ def add_special_char(self, dict_character):
return dict_character
def __call__(self, data):
- text = data['label']
+ text = data["label"]
text = self.encode(text)
if text is None:
return None
if len(text) > self.max_text_len:
return None
- data['length'] = np.array(len(text))
+ data["length"] = np.array(len(text))
target = [0] + text + [1]
padded_text = [0 for _ in range(self.max_text_len + 2)]
- padded_text[:len(target)] = target
- data['label'] = np.array(padded_text)
+ padded_text[: len(target)] = target
+ data["label"] = np.array(padded_text)
return data
class VLLabelEncode(BaseRecLabelEncode):
- """ Convert between text-label and text-index """
-
- def __init__(self,
- max_text_length,
- character_dict_path=None,
- use_space_char=False,
- **kwargs):
- super(VLLabelEncode, self).__init__(max_text_length,
- character_dict_path, use_space_char)
+ """Convert between text-label and text-index"""
+
+ def __init__(
+ self, max_text_length, character_dict_path=None, use_space_char=False, **kwargs
+ ):
+ super(VLLabelEncode, self).__init__(
+ max_text_length, character_dict_path, use_space_char
+ )
self.dict = {}
for i, char in enumerate(self.character):
self.dict[char] = i
def __call__(self, data):
- text = data['label'] # original string
+ text = data["label"] # original string
# generate occluded text
len_str = len(text)
if len_str <= 0:
@@ -1517,19 +1554,19 @@ def __call__(self, data):
elif change_id == 0:
label_res = text[1:]
else:
- label_res = text[:change_id] + text[change_id + 1:]
+ label_res = text[:change_id] + text[change_id + 1 :]
- data['label_res'] = label_res # remaining string
- data['label_sub'] = label_sub # occluded character
- data['label_id'] = change_id # character index
+ data["label_res"] = label_res # remaining string
+ data["label_sub"] = label_sub # occluded character
+ data["label_id"] = change_id # character index
# encode label
text = self.encode(text)
if text is None:
return None
text = [i + 1 for i in text]
- data['length'] = np.array(len(text))
+ data["length"] = np.array(len(text))
text = text + [0] * (self.max_text_len - len(text))
- data['label'] = np.array(text)
+ data["label"] = np.array(text)
label_res = self.encode(label_res)
label_sub = self.encode(label_sub)
if label_res is None:
@@ -1540,12 +1577,12 @@ def __call__(self, data):
label_sub = []
else:
label_sub = [i + 1 for i in label_sub]
- data['length_res'] = np.array(len(label_res))
- data['length_sub'] = np.array(len(label_sub))
+ data["length_res"] = np.array(len(label_res))
+ data["length_sub"] = np.array(len(label_sub))
label_res = label_res + [0] * (self.max_text_len - len(label_res))
label_sub = label_sub + [0] * (self.max_text_len - len(label_sub))
- data['label_res'] = np.array(label_res)
- data['label_sub'] = np.array(label_sub)
+ data["label_res"] = np.array(label_res)
+ data["label_sub"] = np.array(label_sub)
return data
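
The occlusion step in `VLLabelEncode` is self-contained enough to sketch on its own. A minimal sketch follows; the index handling is simplified, since the selection logic for `change_id` is elided from the hunk above:

```python
import random

# Minimal sketch of VLLabelEncode's occlusion step: pick one character
# position, record the character (label_sub) and the remaining string
# (label_res). The index choice here is a simplification.
text = "paddle"
change_id = random.randrange(len(text))
label_sub = text[change_id]
label_res = text[:change_id] + text[change_id + 1:]
print(change_id, label_sub, label_res)
```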
@@ -1554,36 +1591,39 @@ def __init__(self, **kwargs):
pass
def __call__(self, data):
- label = data['label']
+ label = data["label"]
label = json.loads(label)
nBox = len(label)
boxes, txts = [], []
for bno in range(0, nBox):
- box = label[bno]['points']
+ box = label[bno]["points"]
box = np.array(box)
boxes.append(box)
- txt = label[bno]['transcription']
+ txt = label[bno]["transcription"]
txts.append(txt)
if len(boxes) == 0:
return None
- data['polys'] = boxes
- data['texts'] = txts
+ data["polys"] = boxes
+ data["texts"] = txts
return data
class CANLabelEncode(BaseRecLabelEncode):
- def __init__(self,
- character_dict_path,
- max_text_length=100,
- use_space_char=False,
- lower=True,
- **kwargs):
+ def __init__(
+ self,
+ character_dict_path,
+ max_text_length=100,
+ use_space_char=False,
+ lower=True,
+ **kwargs
+ ):
super(CANLabelEncode, self).__init__(
- max_text_length, character_dict_path, use_space_char, lower)
+ max_text_length, character_dict_path, use_space_char, lower
+ )
def encode(self, text_seq):
text_seq_encoded = []
@@ -1596,49 +1636,52 @@ def encode(self, text_seq):
return text_seq_encoded
def __call__(self, data):
- label = data['label']
+ label = data["label"]
if isinstance(label, str):
label = label.strip().split()
label.append(self.end_str)
- data['label'] = self.encode(label)
+ data["label"] = self.encode(label)
return data
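
A quick sketch of the `CANLabelEncode` flow above, with an invented vocabulary; skipping unknown tokens is an assumption here, since `encode`'s body is elided from this hunk:

```python
# Illustration of CANLabelEncode's __call__: formula labels arrive as
# whitespace-separated tokens, get an end token appended, then map to
# ids. The vocabulary and end token below are invented for the example.
vocab = {tok: i for i, tok in enumerate(["x", "+", "1", "=", "0", "<eos>"])}
end_str = "<eos>"

label = "x + 1 = 0"
tokens = label.strip().split()
tokens.append(end_str)
encoded = [vocab[t] for t in tokens if t in vocab]
print(encoded)  # [0, 1, 2, 3, 4, 5]
```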
class CPPDLabelEncode(BaseRecLabelEncode):
- """ Convert between text-label and text-index """
-
- def __init__(self,
- max_text_length,
- character_dict_path=None,
- use_space_char=False,
- ch=False,
- ignore_index=100,
- **kwargs):
+ """Convert between text-label and text-index"""
+
+ def __init__(
+ self,
+ max_text_length,
+ character_dict_path=None,
+ use_space_char=False,
+ ch=False,
+ ignore_index=100,
+ **kwargs
+ ):
super(CPPDLabelEncode, self).__init__(
- max_text_length, character_dict_path, use_space_char)
+ max_text_length, character_dict_path, use_space_char
+ )
self.ch = ch
self.ignore_index = ignore_index
def __call__(self, data):
- text = data['label']
+ text = data["label"]
if self.ch:
text, text_node_index, text_node_num = self.encodech(text)
if text is None:
return None
if len(text) > self.max_text_len:
return None
- data['length'] = np.array(len(text))
+ data["length"] = np.array(len(text))
- text_pos_node = [1] * (len(text) + 1) + [0] * (self.max_text_len -
- len(text))
+ text_pos_node = [1] * (len(text) + 1) + [0] * (
+ self.max_text_len - len(text)
+ )
text.append(0) # eos
- text = text + [self.ignore_index] * (self.max_text_len + 1 -
- len(text))
+ text = text + [self.ignore_index] * (self.max_text_len + 1 - len(text))
- data['label'] = np.array(text)
- data['label_node'] = np.array(text_node_num + text_pos_node)
- data['label_index'] = np.array(text_node_index)
+ data["label"] = np.array(text)
+ data["label_node"] = np.array(text_node_num + text_pos_node)
+ data["label_index"] = np.array(text_node_index)
return data
else:
text, text_char_node, ch_order = self.encode(text)
@@ -1646,29 +1689,28 @@ def __call__(self, data):
return None
if len(text) >= self.max_text_len:
return None
- data['length'] = np.array(len(text))
+ data["length"] = np.array(len(text))
- text_pos_node = [1] * (len(text) + 1) + [0] * (self.max_text_len -
- len(text))
+ text_pos_node = [1] * (len(text) + 1) + [0] * (
+ self.max_text_len - len(text)
+ )
text.append(0) # eos
- text = text + [self.ignore_index] * (self.max_text_len + 1 -
- len(text))
- data['label'] = np.array(text)
- data['label_node'] = np.array(text_char_node + text_pos_node)
- data['label_order'] = np.array(ch_order)
+ text = text + [self.ignore_index] * (self.max_text_len + 1 - len(text))
+ data["label"] = np.array(text)
+ data["label_node"] = np.array(text_char_node + text_pos_node)
+ data["label_order"] = np.array(ch_order)
return data
def add_special_char(self, dict_character):
- dict_character = [''] + dict_character
+ dict_character = [""] + dict_character
self.num_character = len(dict_character)
return dict_character
def encode(self, text):
- """
- """
+ """ """
if len(text) == 0 or len(text) > self.max_text_len:
return None, None, None
if self.lower:
@@ -1683,8 +1725,7 @@ def encode(self, text):
continue
text_list.append(self.dict[char])
text_node[self.dict[char]] += 1
- ch_order.append(
- [self.dict[char], text_node[self.dict[char]], order])
+ ch_order.append([self.dict[char], text_node[self.dict[char]], order])
order += 1
no_ch_order = []
@@ -1693,15 +1734,14 @@ def encode(self, text):
no_ch_order.append([self.dict[char], 1, 0])
random.shuffle(no_ch_order)
ch_order = ch_order + no_ch_order
- ch_order = ch_order[:self.max_text_len + 1]
+ ch_order = ch_order[: self.max_text_len + 1]
if len(text_list) == 0:
return None, None, None
return text_list, text_node, ch_order.sort()
def encodech(self, text):
- """
- """
+ """ """
if len(text) == 0 or len(text) > self.max_text_len:
return None, None, None
if self.lower:
@@ -1721,8 +1761,7 @@ def encodech(self, text):
text_node_dict.update({i_c: 1})
for ic in list(text_node_dict.keys()):
character_index.remove(ic)
- none_char_index = sample(character_index,
- 37 - len(list(text_node_dict.keys())))
+ none_char_index = sample(character_index, 37 - len(list(text_node_dict.keys())))
for ic in none_char_index:
text_node_dict[ic] = 0
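
The encoders in this file share one bracket-and-pad scheme. A minimal standalone sketch of the ViTSTR-style variant follows; the helper name and toy dictionary are inventions for illustration, not ppocr API:

```python
import numpy as np

def encode_label(text, char_to_id, max_text_len, ignore_index=0):
    # Hypothetical mirror of the ViTSTR-style encoders above: map chars
    # to ids, bracket with BOS/EOS, pad with ignore_index.
    ids = [char_to_id[c] for c in text if c in char_to_id]
    if not ids or len(ids) >= max_text_len:
        return None  # the encoders drop empty or over-long labels
    length = np.array(len(ids))
    ids.insert(0, ignore_index)  # BOS slot, reusing the ignore id
    ids.append(1)                # EOS id
    ids += [ignore_index] * (max_text_len + 2 - len(ids))
    return length, np.array(ids)

char_to_id = {c: i + 2 for i, c in enumerate("abcdefghijklmnopqrstuvwxyz")}
print(encode_label("ocr", char_to_id, max_text_len=8))
```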
diff --git a/ppocr/data/imaug/make_border_map.py b/ppocr/data/imaug/make_border_map.py
index 03b7817cfb..9d253196fc 100644
--- a/ppocr/data/imaug/make_border_map.py
+++ b/ppocr/data/imaug/make_border_map.py
@@ -24,7 +24,7 @@
import numpy as np
import cv2
-np.seterr(divide='ignore', invalid='ignore')
+np.seterr(divide="ignore", invalid="ignore")
import pyclipper
from shapely.geometry import Polygon
import sys
@@ -32,28 +32,23 @@
warnings.simplefilter("ignore")
-__all__ = ['MakeBorderMap']
+__all__ = ["MakeBorderMap"]
class MakeBorderMap(object):
- def __init__(self,
- shrink_ratio=0.4,
- thresh_min=0.3,
- thresh_max=0.7,
- **kwargs):
+ def __init__(self, shrink_ratio=0.4, thresh_min=0.3, thresh_max=0.7, **kwargs):
self.shrink_ratio = shrink_ratio
self.thresh_min = thresh_min
self.thresh_max = thresh_max
- if 'total_epoch' in kwargs and 'epoch' in kwargs and kwargs[
- 'epoch'] != "None":
- self.shrink_ratio = self.shrink_ratio + 0.2 * kwargs[
- 'epoch'] / float(kwargs['total_epoch'])
+ if "total_epoch" in kwargs and "epoch" in kwargs and kwargs["epoch"] != "None":
+ self.shrink_ratio = self.shrink_ratio + 0.2 * kwargs["epoch"] / float(
+ kwargs["total_epoch"]
+ )
def __call__(self, data):
-
- img = data['image']
- text_polys = data['polys']
- ignore_tags = data['ignore_tags']
+ img = data["image"]
+ text_polys = data["polys"]
+ ignore_tags = data["ignore_tags"]
canvas = np.zeros(img.shape[:2], dtype=np.float32)
mask = np.zeros(img.shape[:2], dtype=np.float32)
@@ -64,8 +59,8 @@ def __call__(self, data):
self.draw_border_map(text_polys[i], canvas, mask=mask)
canvas = canvas * (self.thresh_max - self.thresh_min) + self.thresh_min
- data['threshold_map'] = canvas
- data['threshold_mask'] = mask
+ data["threshold_map"] = canvas
+ data["threshold_mask"] = mask
return data
def draw_border_map(self, polygon, canvas, mask):
@@ -76,8 +71,11 @@ def draw_border_map(self, polygon, canvas, mask):
polygon_shape = Polygon(polygon)
if polygon_shape.area <= 0:
return
- distance = polygon_shape.area * (
- 1 - np.power(self.shrink_ratio, 2)) / polygon_shape.length
+ distance = (
+ polygon_shape.area
+ * (1 - np.power(self.shrink_ratio, 2))
+ / polygon_shape.length
+ )
subject = [tuple(l) for l in polygon]
padding = pyclipper.PyclipperOffset()
padding.AddPath(subject, pyclipper.JT_ROUND, pyclipper.ET_CLOSEDPOLYGON)
@@ -96,14 +94,13 @@ def draw_border_map(self, polygon, canvas, mask):
polygon[:, 1] = polygon[:, 1] - ymin
xs = np.broadcast_to(
- np.linspace(
- 0, width - 1, num=width).reshape(1, width), (height, width))
+ np.linspace(0, width - 1, num=width).reshape(1, width), (height, width)
+ )
ys = np.broadcast_to(
- np.linspace(
- 0, height - 1, num=height).reshape(height, 1), (height, width))
+ np.linspace(0, height - 1, num=height).reshape(height, 1), (height, width)
+ )
- distance_map = np.zeros(
- (polygon.shape[0], height, width), dtype=np.float32)
+ distance_map = np.zeros((polygon.shape[0], height, width), dtype=np.float32)
for i in range(polygon.shape[0]):
j = (i + 1) % polygon.shape[0]
absolute_distance = self._distance(xs, ys, polygon[i], polygon[j])
@@ -114,45 +111,49 @@ def draw_border_map(self, polygon, canvas, mask):
xmax_valid = min(max(0, xmax), canvas.shape[1] - 1)
ymin_valid = min(max(0, ymin), canvas.shape[0] - 1)
ymax_valid = min(max(0, ymax), canvas.shape[0] - 1)
- canvas[ymin_valid:ymax_valid + 1, xmin_valid:xmax_valid + 1] = np.fmax(
- 1 - distance_map[ymin_valid - ymin:ymax_valid - ymax + height,
- xmin_valid - xmin:xmax_valid - xmax + width],
- canvas[ymin_valid:ymax_valid + 1, xmin_valid:xmax_valid + 1])
+ canvas[ymin_valid : ymax_valid + 1, xmin_valid : xmax_valid + 1] = np.fmax(
+ 1
+ - distance_map[
+ ymin_valid - ymin : ymax_valid - ymax + height,
+ xmin_valid - xmin : xmax_valid - xmax + width,
+ ],
+ canvas[ymin_valid : ymax_valid + 1, xmin_valid : xmax_valid + 1],
+ )
def _distance(self, xs, ys, point_1, point_2):
- '''
+ """
compute the distance from point to a line
ys: coordinates in the first axis
xs: coordinates in the second axis
point_1, point_2: (x, y), the end of the line
- '''
+ """
height, width = xs.shape[:2]
- square_distance_1 = np.square(xs - point_1[0]) + np.square(ys - point_1[
- 1])
- square_distance_2 = np.square(xs - point_2[0]) + np.square(ys - point_2[
- 1])
+ square_distance_1 = np.square(xs - point_1[0]) + np.square(ys - point_1[1])
+ square_distance_2 = np.square(xs - point_2[0]) + np.square(ys - point_2[1])
square_distance = np.square(point_1[0] - point_2[0]) + np.square(
- point_1[1] - point_2[1])
+ point_1[1] - point_2[1]
+ )
cosin = (square_distance - square_distance_1 - square_distance_2) / (
- 2 * np.sqrt(square_distance_1 * square_distance_2))
+ 2 * np.sqrt(square_distance_1 * square_distance_2)
+ )
square_sin = 1 - np.square(cosin)
square_sin = np.nan_to_num(square_sin)
- result = np.sqrt(square_distance_1 * square_distance_2 * square_sin /
- square_distance)
+ result = np.sqrt(
+ square_distance_1 * square_distance_2 * square_sin / square_distance
+ )
- result[cosin <
- 0] = np.sqrt(np.fmin(square_distance_1, square_distance_2))[cosin
- < 0]
+ result[cosin < 0] = np.sqrt(np.fmin(square_distance_1, square_distance_2))[
+ cosin < 0
+ ]
# self.extend_line(point_1, point_2, result)
return result
def extend_line(self, point_1, point_2, result, shrink_ratio):
- ex_point_1 = (int(
- round(point_1[0] + (point_1[0] - point_2[0]) * (1 + shrink_ratio))),
- int(
- round(point_1[1] + (point_1[1] - point_2[1]) * (
- 1 + shrink_ratio))))
+ ex_point_1 = (
+ int(round(point_1[0] + (point_1[0] - point_2[0]) * (1 + shrink_ratio))),
+ int(round(point_1[1] + (point_1[1] - point_2[1]) * (1 + shrink_ratio))),
+ )
cv2.line(
result,
tuple(ex_point_1),
@@ -160,12 +161,12 @@ def extend_line(self, point_1, point_2, result, shrink_ratio):
4096.0,
1,
lineType=cv2.LINE_AA,
- shift=0)
- ex_point_2 = (int(
- round(point_2[0] + (point_2[0] - point_1[0]) * (1 + shrink_ratio))),
- int(
- round(point_2[1] + (point_2[1] - point_1[1]) * (
- 1 + shrink_ratio))))
+ shift=0,
+ )
+ ex_point_2 = (
+ int(round(point_2[0] + (point_2[0] - point_1[0]) * (1 + shrink_ratio))),
+ int(round(point_2[1] + (point_2[1] - point_1[1]) * (1 + shrink_ratio))),
+ )
cv2.line(
result,
tuple(ex_point_2),
@@ -173,5 +174,6 @@ def extend_line(self, point_1, point_2, result, shrink_ratio):
4096.0,
1,
lineType=cv2.LINE_AA,
- shift=0)
+ shift=0,
+ )
return ex_point_1, ex_point_2
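
The reformatted distance expression in `draw_border_map` follows D = A(1 - r^2) / L for polygon area A, perimeter L, and shrink ratio r. A standalone sketch of the dilation step, with an invented quad:

```python
import numpy as np
import pyclipper
from shapely.geometry import Polygon

# Sketch of the dilation in draw_border_map: compute the offset distance
# D = A * (1 - r**2) / L, then apply a round-joint polygon offset.
polygon = np.array([[0, 0], [100, 0], [100, 30], [0, 30]])
shrink_ratio = 0.4

shape = Polygon(polygon)
distance = shape.area * (1 - np.power(shrink_ratio, 2)) / shape.length

padding = pyclipper.PyclipperOffset()
padding.AddPath([tuple(p) for p in polygon], pyclipper.JT_ROUND,
                pyclipper.ET_CLOSEDPOLYGON)
padded_polygon = np.array(padding.Execute(distance)[0])
print(distance, padded_polygon.shape)
```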
diff --git a/ppocr/data/imaug/make_pse_gt.py b/ppocr/data/imaug/make_pse_gt.py
index 255d076bde..2b8c78b713 100644
--- a/ppocr/data/imaug/make_pse_gt.py
+++ b/ppocr/data/imaug/make_pse_gt.py
@@ -22,7 +22,7 @@
import pyclipper
from shapely.geometry import Polygon
-__all__ = ['MakePseGt']
+__all__ = ["MakePseGt"]
class MakePseGt(object):
@@ -32,10 +32,9 @@ def __init__(self, kernel_num=7, size=640, min_shrink_ratio=0.4, **kwargs):
self.size = size
def __call__(self, data):
-
- image = data['image']
- text_polys = data['polys']
- ignore_tags = data['ignore_tags']
+ image = data["image"]
+ text_polys = data["polys"]
+ ignore_tags = data["ignore_tags"]
h, w, _ = image.shape
short_edge = min(h, w)
@@ -48,34 +47,30 @@ def __call__(self, data):
gt_kernels = []
for i in range(1, self.kernel_num + 1):
# s1->sn, from big to small
- rate = 1.0 - (1.0 - self.min_shrink_ratio) / (self.kernel_num - 1
- ) * i
+ rate = 1.0 - (1.0 - self.min_shrink_ratio) / (self.kernel_num - 1) * i
text_kernel, ignore_tags = self.generate_kernel(
- image.shape[0:2], rate, text_polys, ignore_tags)
+ image.shape[0:2], rate, text_polys, ignore_tags
+ )
gt_kernels.append(text_kernel)
- training_mask = np.ones(image.shape[0:2], dtype='uint8')
+ training_mask = np.ones(image.shape[0:2], dtype="uint8")
for i in range(text_polys.shape[0]):
if ignore_tags[i]:
- cv2.fillPoly(training_mask,
- text_polys[i].astype(np.int32)[np.newaxis, :, :],
- 0)
+ cv2.fillPoly(
+ training_mask, text_polys[i].astype(np.int32)[np.newaxis, :, :], 0
+ )
gt_kernels = np.array(gt_kernels)
gt_kernels[gt_kernels > 0] = 1
- data['image'] = image
- data['polys'] = text_polys
- data['gt_kernels'] = gt_kernels[0:]
- data['gt_text'] = gt_kernels[0]
- data['mask'] = training_mask.astype('float32')
+ data["image"] = image
+ data["polys"] = text_polys
+ data["gt_kernels"] = gt_kernels[0:]
+ data["gt_text"] = gt_kernels[0]
+ data["mask"] = training_mask.astype("float32")
return data
- def generate_kernel(self,
- img_size,
- shrink_ratio,
- text_polys,
- ignore_tags=None):
+ def generate_kernel(self, img_size, shrink_ratio, text_polys, ignore_tags=None):
"""
Refer to part of the code:
https://github.com/open-mmlab/mmocr/blob/main/mmocr/datasets/pipelines/textdet_targets/base_textdet_targets.py
@@ -85,8 +80,11 @@ def generate_kernel(self,
text_kernel = np.zeros((h, w), dtype=np.float32)
for i, poly in enumerate(text_polys):
polygon = Polygon(poly)
- distance = polygon.area * (1 - shrink_ratio * shrink_ratio) / (
- polygon.length + 1e-6)
+ distance = (
+ polygon.area
+ * (1 - shrink_ratio * shrink_ratio)
+ / (polygon.length + 1e-6)
+ )
subject = [tuple(l) for l in poly]
pco = pyclipper.PyclipperOffset()
pco.AddPath(subject, pyclipper.JT_ROUND, pyclipper.ET_CLOSEDPOLYGON)
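
With the default arguments (`kernel_num=7`, `min_shrink_ratio=0.4`), the reformatted rate expression in `MakePseGt.__call__` steps down linearly from large to small kernels; a one-liner to verify:

```python
# Shrink-rate schedule from MakePseGt.__call__ with default arguments:
# one rate per kernel, decreasing linearly.
kernel_num, min_shrink_ratio = 7, 0.4
rates = [1.0 - (1.0 - min_shrink_ratio) / (kernel_num - 1) * i
         for i in range(1, kernel_num + 1)]
print([round(r, 2) for r in rates])  # [0.9, 0.8, 0.7, 0.6, 0.5, 0.4, 0.3]
```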
diff --git a/ppocr/data/imaug/make_shrink_map.py b/ppocr/data/imaug/make_shrink_map.py
index d0317b61fe..d57585ef3b 100644
--- a/ppocr/data/imaug/make_shrink_map.py
+++ b/ppocr/data/imaug/make_shrink_map.py
@@ -26,31 +26,30 @@
from shapely.geometry import Polygon
import pyclipper
-__all__ = ['MakeShrinkMap']
+__all__ = ["MakeShrinkMap"]
class MakeShrinkMap(object):
- r'''
+ r"""
Making binary mask from detection data with ICDAR format.
Typically following the process of class `MakeICDARData`.
- '''
+ """
def __init__(self, min_text_size=8, shrink_ratio=0.4, **kwargs):
self.min_text_size = min_text_size
self.shrink_ratio = shrink_ratio
- if 'total_epoch' in kwargs and 'epoch' in kwargs and kwargs[
- 'epoch'] != "None":
- self.shrink_ratio = self.shrink_ratio + 0.2 * kwargs[
- 'epoch'] / float(kwargs['total_epoch'])
+ if "total_epoch" in kwargs and "epoch" in kwargs and kwargs["epoch"] != "None":
+ self.shrink_ratio = self.shrink_ratio + 0.2 * kwargs["epoch"] / float(
+ kwargs["total_epoch"]
+ )
def __call__(self, data):
- image = data['image']
- text_polys = data['polys']
- ignore_tags = data['ignore_tags']
+ image = data["image"]
+ text_polys = data["polys"]
+ ignore_tags = data["ignore_tags"]
h, w = image.shape[:2]
- text_polys, ignore_tags = self.validate_polygons(text_polys,
- ignore_tags, h, w)
+ text_polys, ignore_tags = self.validate_polygons(text_polys, ignore_tags, h, w)
gt = np.zeros((h, w), dtype=np.float32)
mask = np.ones((h, w), dtype=np.float32)
for i in range(len(text_polys)):
@@ -58,33 +57,32 @@ def __call__(self, data):
height = max(polygon[:, 1]) - min(polygon[:, 1])
width = max(polygon[:, 0]) - min(polygon[:, 0])
if ignore_tags[i] or min(height, width) < self.min_text_size:
- cv2.fillPoly(mask,
- polygon.astype(np.int32)[np.newaxis, :, :], 0)
+ cv2.fillPoly(mask, polygon.astype(np.int32)[np.newaxis, :, :], 0)
ignore_tags[i] = True
else:
polygon_shape = Polygon(polygon)
subject = [tuple(l) for l in polygon]
padding = pyclipper.PyclipperOffset()
- padding.AddPath(subject, pyclipper.JT_ROUND,
- pyclipper.ET_CLOSEDPOLYGON)
+ padding.AddPath(subject, pyclipper.JT_ROUND, pyclipper.ET_CLOSEDPOLYGON)
shrinked = []
# Increase the shrink ratio every time we get multiple polygon returned back
- possible_ratios = np.arange(self.shrink_ratio, 1,
- self.shrink_ratio)
+ possible_ratios = np.arange(self.shrink_ratio, 1, self.shrink_ratio)
np.append(possible_ratios, 1)
# print(possible_ratios)
for ratio in possible_ratios:
# print(f"Change shrink ratio to {ratio}")
- distance = polygon_shape.area * (
- 1 - np.power(ratio, 2)) / polygon_shape.length
+ distance = (
+ polygon_shape.area
+ * (1 - np.power(ratio, 2))
+ / polygon_shape.length
+ )
shrinked = padding.Execute(-distance)
if len(shrinked) == 1:
break
if shrinked == []:
- cv2.fillPoly(mask,
- polygon.astype(np.int32)[np.newaxis, :, :], 0)
+ cv2.fillPoly(mask, polygon.astype(np.int32)[np.newaxis, :, :], 0)
ignore_tags[i] = True
continue
@@ -92,14 +90,14 @@ def __call__(self, data):
shirnk = np.array(each_shirnk).reshape(-1, 2)
cv2.fillPoly(gt, [shirnk.astype(np.int32)], 1)
- data['shrink_map'] = gt
- data['shrink_mask'] = mask
+ data["shrink_map"] = gt
+ data["shrink_mask"] = mask
return data
def validate_polygons(self, polygons, ignore_tags, h, w):
- '''
+ """
polygons (numpy.array, required): of shape (num_instances, num_points, 2)
- '''
+ """
if len(polygons) == 0:
return polygons, ignore_tags
assert len(polygons) == len(ignore_tags)
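
The retry loop in `MakeShrinkMap.__call__` escalates the shrink ratio until pyclipper returns exactly one polygon; a standalone sketch with invented geometry:

```python
import numpy as np
import pyclipper
from shapely.geometry import Polygon

# Standalone sketch of MakeShrinkMap's retry loop: widen the shrink
# ratio until the negative offset yields a single polygon.
polygon = [(0, 0), (60, 0), (60, 20), (0, 20)]
shrink_ratio = 0.4

padding = pyclipper.PyclipperOffset()
padding.AddPath(polygon, pyclipper.JT_ROUND, pyclipper.ET_CLOSEDPOLYGON)

shape = Polygon(polygon)
for ratio in np.append(np.arange(shrink_ratio, 1, shrink_ratio), 1):
    distance = shape.area * (1 - np.power(ratio, 2)) / shape.length
    shrinked = padding.Execute(-distance)
    if len(shrinked) == 1:
        break
print(ratio, np.array(shrinked[0]))
```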
diff --git a/ppocr/data/imaug/operators.py b/ppocr/data/imaug/operators.py
index 4ff2d29ed3..781e30d7ca 100644
--- a/ppocr/data/imaug/operators.py
+++ b/ppocr/data/imaug/operators.py
@@ -28,98 +28,96 @@
class DecodeImage(object):
- """ decode image """
+ """decode image"""
- def __init__(self,
- img_mode='RGB',
- channel_first=False,
- ignore_orientation=False,
- **kwargs):
+ def __init__(
+ self, img_mode="RGB", channel_first=False, ignore_orientation=False, **kwargs
+ ):
self.img_mode = img_mode
self.channel_first = channel_first
self.ignore_orientation = ignore_orientation
def __call__(self, data):
- img = data['image']
+ img = data["image"]
if six.PY2:
- assert type(img) is str and len(
- img) > 0, "invalid input 'img' in DecodeImage"
+ assert (
+ type(img) is str and len(img) > 0
+ ), "invalid input 'img' in DecodeImage"
else:
- assert type(img) is bytes and len(
- img) > 0, "invalid input 'img' in DecodeImage"
- img = np.frombuffer(img, dtype='uint8')
+ assert (
+ type(img) is bytes and len(img) > 0
+ ), "invalid input 'img' in DecodeImage"
+ img = np.frombuffer(img, dtype="uint8")
if self.ignore_orientation:
- img = cv2.imdecode(img, cv2.IMREAD_IGNORE_ORIENTATION |
- cv2.IMREAD_COLOR)
+ img = cv2.imdecode(img, cv2.IMREAD_IGNORE_ORIENTATION | cv2.IMREAD_COLOR)
else:
img = cv2.imdecode(img, 1)
if img is None:
return None
- if self.img_mode == 'GRAY':
+ if self.img_mode == "GRAY":
img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
- elif self.img_mode == 'RGB':
- assert img.shape[2] == 3, 'invalid shape of image[%s]' % (img.shape)
+ elif self.img_mode == "RGB":
+ assert img.shape[2] == 3, "invalid shape of image[%s]" % (img.shape)
img = img[:, :, ::-1]
if self.channel_first:
img = img.transpose((2, 0, 1))
- data['image'] = img
+ data["image"] = img
return data
class NormalizeImage(object):
- """ normalize image such as substract mean, divide std
- """
+ """normalize image such as substract mean, divide std"""
- def __init__(self, scale=None, mean=None, std=None, order='chw', **kwargs):
+ def __init__(self, scale=None, mean=None, std=None, order="chw", **kwargs):
if isinstance(scale, str):
scale = eval(scale)
self.scale = np.float32(scale if scale is not None else 1.0 / 255.0)
mean = mean if mean is not None else [0.485, 0.456, 0.406]
std = std if std is not None else [0.229, 0.224, 0.225]
- shape = (3, 1, 1) if order == 'chw' else (1, 1, 3)
- self.mean = np.array(mean).reshape(shape).astype('float32')
- self.std = np.array(std).reshape(shape).astype('float32')
+ shape = (3, 1, 1) if order == "chw" else (1, 1, 3)
+ self.mean = np.array(mean).reshape(shape).astype("float32")
+ self.std = np.array(std).reshape(shape).astype("float32")
def __call__(self, data):
- img = data['image']
+ img = data["image"]
from PIL import Image
+
if isinstance(img, Image.Image):
img = np.array(img)
- assert isinstance(img,
- np.ndarray), "invalid input 'img' in NormalizeImage"
- data['image'] = (
- img.astype('float32') * self.scale - self.mean) / self.std
+ assert isinstance(img, np.ndarray), "invalid input 'img' in NormalizeImage"
+ data["image"] = (img.astype("float32") * self.scale - self.mean) / self.std
return data
class ToCHWImage(object):
- """ convert hwc image to chw image
- """
+ """convert hwc image to chw image"""
def __init__(self, **kwargs):
pass
def __call__(self, data):
- img = data['image']
+ img = data["image"]
from PIL import Image
+
if isinstance(img, Image.Image):
img = np.array(img)
- data['image'] = img.transpose((2, 0, 1))
+ data["image"] = img.transpose((2, 0, 1))
return data
class Fasttext(object):
def __init__(self, path="None", **kwargs):
import fasttext
+
self.fast_model = fasttext.load_model(path)
def __call__(self, data):
- label = data['label']
+ label = data["label"]
fast_label = self.fast_model[label]
- data['fast_label'] = fast_label
+ data["fast_label"] = fast_label
return data
@@ -137,29 +135,31 @@ def __call__(self, data):
class Pad(object):
def __init__(self, size=None, size_div=32, **kwargs):
if size is not None and not isinstance(size, (int, list, tuple)):
- raise TypeError("Type of target_size is invalid. Now is {}".format(
- type(size)))
+ raise TypeError(
+ "Type of target_size is invalid. Now is {}".format(type(size))
+ )
if isinstance(size, int):
size = [size, size]
self.size = size
self.size_div = size_div
def __call__(self, data):
-
- img = data['image']
+ img = data["image"]
img_h, img_w = img.shape[0], img.shape[1]
if self.size:
resize_h2, resize_w2 = self.size
assert (
img_h < resize_h2 and img_w < resize_w2
- ), '(h, w) of target size should be greater than (img_h, img_w)'
+ ), "(h, w) of target size should be greater than (img_h, img_w)"
else:
resize_h2 = max(
int(math.ceil(img.shape[0] / self.size_div) * self.size_div),
- self.size_div)
+ self.size_div,
+ )
resize_w2 = max(
int(math.ceil(img.shape[1] / self.size_div) * self.size_div),
- self.size_div)
+ self.size_div,
+ )
img = cv2.copyMakeBorder(
img,
0,
@@ -167,8 +167,9 @@ def __call__(self, data):
0,
resize_w2 - img_w,
cv2.BORDER_CONSTANT,
- value=0)
- data['image'] = img
+ value=0,
+ )
+ data["image"] = img
return data
@@ -185,20 +186,20 @@ def resize_image(self, img):
return img, [ratio_h, ratio_w]
def __call__(self, data):
- img = data['image']
- if 'polys' in data:
- text_polys = data['polys']
+ img = data["image"]
+ if "polys" in data:
+ text_polys = data["polys"]
img_resize, [ratio_h, ratio_w] = self.resize_image(img)
- if 'polys' in data:
+ if "polys" in data:
new_boxes = []
for box in text_polys:
new_box = []
for cord in box:
new_box.append([cord[0] * ratio_w, cord[1] * ratio_h])
new_boxes.append(new_box)
- data['polys'] = np.array(new_boxes, dtype=np.float32)
- data['image'] = img_resize
+ data["polys"] = np.array(new_boxes, dtype=np.float32)
+ data["image"] = img_resize
return data
@@ -207,23 +208,23 @@ def __init__(self, **kwargs):
super(DetResizeForTest, self).__init__()
self.resize_type = 0
self.keep_ratio = False
- if 'image_shape' in kwargs:
- self.image_shape = kwargs['image_shape']
+ if "image_shape" in kwargs:
+ self.image_shape = kwargs["image_shape"]
self.resize_type = 1
- if 'keep_ratio' in kwargs:
- self.keep_ratio = kwargs['keep_ratio']
- elif 'limit_side_len' in kwargs:
- self.limit_side_len = kwargs['limit_side_len']
- self.limit_type = kwargs.get('limit_type', 'min')
- elif 'resize_long' in kwargs:
+ if "keep_ratio" in kwargs:
+ self.keep_ratio = kwargs["keep_ratio"]
+ elif "limit_side_len" in kwargs:
+ self.limit_side_len = kwargs["limit_side_len"]
+ self.limit_type = kwargs.get("limit_type", "min")
+ elif "resize_long" in kwargs:
self.resize_type = 2
- self.resize_long = kwargs.get('resize_long', 960)
+ self.resize_long = kwargs.get("resize_long", 960)
else:
self.limit_side_len = 736
- self.limit_type = 'min'
+ self.limit_type = "min"
def __call__(self, data):
- img = data['image']
+ img = data["image"]
src_h, src_w, _ = img.shape
if sum([src_h, src_w]) < 64:
img = self.image_padding(img)
@@ -236,8 +237,8 @@ def __call__(self, data):
else:
# img, shape = self.resize_image_type1(img)
img, [ratio_h, ratio_w] = self.resize_image_type1(img)
- data['image'] = img
- data['shape'] = np.array([src_h, src_w, ratio_h, ratio_w])
+ data["image"] = img
+ data["shape"] = np.array([src_h, src_w, ratio_h, ratio_w])
return data
def image_padding(self, im, value=0):
@@ -271,26 +272,26 @@ def resize_image_type0(self, img):
h, w, c = img.shape
# limit the max side
- if self.limit_type == 'max':
+ if self.limit_type == "max":
if max(h, w) > limit_side_len:
if h > w:
ratio = float(limit_side_len) / h
else:
ratio = float(limit_side_len) / w
else:
- ratio = 1.
- elif self.limit_type == 'min':
+ ratio = 1.0
+ elif self.limit_type == "min":
if min(h, w) < limit_side_len:
if h < w:
ratio = float(limit_side_len) / h
else:
ratio = float(limit_side_len) / w
else:
- ratio = 1.
- elif self.limit_type == 'resize_long':
+ ratio = 1.0
+ elif self.limit_type == "resize_long":
ratio = float(limit_side_len) / max(h, w)
else:
- raise Exception('not support limit type, image ')
+ raise Exception("not support limit type, image ")
resize_h = int(h * ratio)
resize_w = int(w * ratio)
@@ -335,24 +336,25 @@ def resize_image_type2(self, img):
class E2EResizeForTest(object):
def __init__(self, **kwargs):
super(E2EResizeForTest, self).__init__()
- self.max_side_len = kwargs['max_side_len']
- self.valid_set = kwargs['valid_set']
+ self.max_side_len = kwargs["max_side_len"]
+ self.valid_set = kwargs["valid_set"]
def __call__(self, data):
- img = data['image']
+ img = data["image"]
src_h, src_w, _ = img.shape
- if self.valid_set == 'totaltext':
+ if self.valid_set == "totaltext":
im_resized, [ratio_h, ratio_w] = self.resize_image_for_totaltext(
- img, max_side_len=self.max_side_len)
+ img, max_side_len=self.max_side_len
+ )
else:
im_resized, (ratio_h, ratio_w) = self.resize_image(
- img, max_side_len=self.max_side_len)
- data['image'] = im_resized
- data['shape'] = np.array([src_h, src_w, ratio_h, ratio_w])
+ img, max_side_len=self.max_side_len
+ )
+ data["image"] = im_resized
+ data["shape"] = np.array([src_h, src_w, ratio_h, ratio_w])
return data
def resize_image_for_totaltext(self, im, max_side_len=512):
-
h, w, _ = im.shape
resize_w = w
resize_h = h
@@ -404,33 +406,36 @@ def resize_image(self, im, max_side_len=512):
class KieResize(object):
def __init__(self, **kwargs):
super(KieResize, self).__init__()
- self.max_side, self.min_side = kwargs['img_scale'][0], kwargs[
- 'img_scale'][1]
+ self.max_side, self.min_side = kwargs["img_scale"][0], kwargs["img_scale"][1]
def __call__(self, data):
- img = data['image']
- points = data['points']
+ img = data["image"]
+ points = data["points"]
src_h, src_w, _ = img.shape
- im_resized, scale_factor, [ratio_h, ratio_w
- ], [new_h, new_w] = self.resize_image(img)
+ (
+ im_resized,
+ scale_factor,
+ [ratio_h, ratio_w],
+ [new_h, new_w],
+ ) = self.resize_image(img)
resize_points = self.resize_boxes(img, points, scale_factor)
- data['ori_image'] = img
- data['ori_boxes'] = points
- data['points'] = resize_points
- data['image'] = im_resized
- data['shape'] = np.array([new_h, new_w])
+ data["ori_image"] = img
+ data["ori_boxes"] = points
+ data["points"] = resize_points
+ data["image"] = im_resized
+ data["shape"] = np.array([new_h, new_w])
return data
def resize_image(self, img):
- norm_img = np.zeros([1024, 1024, 3], dtype='float32')
+ norm_img = np.zeros([1024, 1024, 3], dtype="float32")
scale = [512, 1024]
h, w = img.shape[:2]
max_long_edge = max(scale)
max_short_edge = min(scale)
- scale_factor = min(max_long_edge / max(h, w),
- max_short_edge / min(h, w))
- resize_w, resize_h = int(w * float(scale_factor) + 0.5), int(h * float(
- scale_factor) + 0.5)
+ scale_factor = min(max_long_edge / max(h, w), max_short_edge / min(h, w))
+ resize_w, resize_h = int(w * float(scale_factor) + 0.5), int(
+ h * float(scale_factor) + 0.5
+ )
max_stride = 32
resize_h = (resize_h + max_stride - 1) // max_stride * max_stride
resize_w = (resize_w + max_stride - 1) // max_stride * max_stride
@@ -438,8 +443,7 @@ def resize_image(self, img):
new_h, new_w = im.shape[:2]
w_scale = new_w / w
h_scale = new_h / h
- scale_factor = np.array(
- [w_scale, h_scale, w_scale, h_scale], dtype=np.float32)
+ scale_factor = np.array([w_scale, h_scale, w_scale, h_scale], dtype=np.float32)
norm_img[:new_h, :new_w, :] = im
return norm_img, scale_factor, [h_scale, w_scale], [new_h, new_w]
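
The scale and stride arithmetic in `resize_image` is easy to verify by hand; a sketch with invented dimensions:

```python
# Sketch of KieResize.resize_image's arithmetic: fit the image into the
# [512, 1024] scale window, then round each side up to a multiple of 32.
h, w, max_stride = 700, 500, 32
scale_factor = min(1024 / max(h, w), 512 / min(h, w))
resize_w = int(w * scale_factor + 0.5)
resize_h = int(h * scale_factor + 0.5)
resize_h = (resize_h + max_stride - 1) // max_stride * max_stride
resize_w = (resize_w + max_stride - 1) // max_stride * max_stride
print(resize_h, resize_w)  # 736 512
```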
@@ -452,15 +456,17 @@ def resize_boxes(self, im, points, scale_factor):
class SRResize(object):
- def __init__(self,
- imgH=32,
- imgW=128,
- down_sample_scale=4,
- keep_ratio=False,
- min_ratio=1,
- mask=False,
- infer_mode=False,
- **kwargs):
+ def __init__(
+ self,
+ imgH=32,
+ imgW=128,
+ down_sample_scale=4,
+ keep_ratio=False,
+ min_ratio=1,
+ mask=False,
+ infer_mode=False,
+ **kwargs
+ ):
self.imgH = imgH
self.imgW = imgW
self.keep_ratio = keep_ratio
@@ -474,7 +480,8 @@ def __call__(self, data):
imgW = self.imgW
images_lr = data["image_lr"]
transform2 = ResizeNormalize(
- (imgW // self.down_sample_scale, imgH // self.down_sample_scale))
+ (imgW // self.down_sample_scale, imgH // self.down_sample_scale)
+ )
images_lr = transform2(images_lr)
data["img_lr"] = images_lr
if self.infer_mode:
@@ -504,21 +511,21 @@ class GrayImageChannelFormat(object):
"""
format gray scale image's channel: (3,h,w) -> (1,h,w)
Args:
- inverse: inverse gray image
+ inverse: whether to invert the gray image
"""
def __init__(self, inverse=False, **kwargs):
self.inverse = inverse
def __call__(self, data):
- img = data['image']
+ img = data["image"]
img_single_channel = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
img_expanded = np.expand_dims(img_single_channel, 0)
if self.inverse:
- data['image'] = np.abs(img_expanded - 1)
+ data["image"] = np.abs(img_expanded - 1)
else:
- data['image'] = img_expanded
+ data["image"] = img_expanded
- data['src_image'] = img
- return data
\ No newline at end of file
+ data["src_image"] = img
+ return data
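
Taken together, these operators form a dict-in/dict-out pipeline. A self-contained mimic of the three core steps, using a synthetic image rather than the actual ppocr classes:

```python
import cv2
import numpy as np

# Mimic of DecodeImage -> NormalizeImage -> ToCHWImage on a synthetic
# PNG buffer; each op reads and writes the shared data dict.
data = {"image": cv2.imencode(".png", np.zeros((32, 100, 3), np.uint8))[1].tobytes()}

img = cv2.imdecode(np.frombuffer(data["image"], dtype="uint8"), 1)  # DecodeImage
img = img[:, :, ::-1]                                               # BGR -> RGB

mean = np.array([0.485, 0.456, 0.406]).reshape(1, 1, 3).astype("float32")
std = np.array([0.229, 0.224, 0.225]).reshape(1, 1, 3).astype("float32")
img = (img.astype("float32") / 255.0 - mean) / std                  # NormalizeImage

data["image"] = img.transpose((2, 0, 1))                            # ToCHWImage
print(data["image"].shape)  # (3, 32, 100)
```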
diff --git a/ppocr/data/imaug/pg_process.py b/ppocr/data/imaug/pg_process.py
index f1e5f912b7..08f4dad465 100644
--- a/ppocr/data/imaug/pg_process.py
+++ b/ppocr/data/imaug/pg_process.py
@@ -16,25 +16,29 @@
import cv2
import numpy as np
from skimage.morphology._skeletonize import thin
-from ppocr.utils.e2e_utils.extract_textpoint_fast import sort_and_expand_with_direction_v2
+from ppocr.utils.e2e_utils.extract_textpoint_fast import (
+ sort_and_expand_with_direction_v2,
+)
-__all__ = ['PGProcessTrain']
+__all__ = ["PGProcessTrain"]
class PGProcessTrain(object):
- def __init__(self,
- character_dict_path,
- max_text_length,
- max_text_nums,
- tcl_len,
- batch_size=14,
- use_resize=True,
- use_random_crop=False,
- min_crop_size=24,
- min_text_size=4,
- max_text_size=512,
- point_gather_mode=None,
- **kwargs):
+ def __init__(
+ self,
+ character_dict_path,
+ max_text_length,
+ max_text_nums,
+ tcl_len,
+ batch_size=14,
+ use_resize=True,
+ use_random_crop=False,
+ min_crop_size=24,
+ min_text_size=4,
+ max_text_size=512,
+ point_gather_mode=None,
+ **kwargs
+ ):
self.tcl_len = tcl_len
self.max_text_length = max_text_length
self.max_text_nums = max_text_nums
@@ -55,7 +59,7 @@ def get_dict(self, character_dict_path):
with open(character_dict_path, "rb") as fin:
lines = fin.readlines()
for line in lines:
- line = line.decode('utf-8').strip("\n").strip("\r\n")
+ line = line.decode("utf-8").strip("\n").strip("\r\n")
character_str += line
dict_character = list(character_str)
return dict_character
@@ -66,11 +70,13 @@ def quad_area(self, poly):
:param poly:
:return:
"""
- edge = [(poly[1][0] - poly[0][0]) * (poly[1][1] + poly[0][1]),
- (poly[2][0] - poly[1][0]) * (poly[2][1] + poly[1][1]),
- (poly[3][0] - poly[2][0]) * (poly[3][1] + poly[2][1]),
- (poly[0][0] - poly[3][0]) * (poly[0][1] + poly[3][1])]
- return np.sum(edge) / 2.
+ edge = [
+ (poly[1][0] - poly[0][0]) * (poly[1][1] + poly[0][1]),
+ (poly[2][0] - poly[1][0]) * (poly[2][1] + poly[1][1]),
+ (poly[3][0] - poly[2][0]) * (poly[3][1] + poly[2][1]),
+ (poly[0][0] - poly[3][0]) * (poly[0][1] + poly[3][1]),
+ ]
+ return np.sum(edge) / 2.0
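
`quad_area` is a shoelace-style signed sum whose sign flags polygon orientation (a positive result marks a reversed quad upstream); a quick check with an invented quad:

```python
import numpy as np

# Shoelace-style signed area, matching quad_area's edge sum.
poly = np.array([[0, 0], [10, 0], [10, 5], [0, 5]], dtype=np.float32)
edge = [(poly[(i + 1) % 4][0] - poly[i][0]) * (poly[(i + 1) % 4][1] + poly[i][1])
        for i in range(4)]
print(np.sum(edge) / 2.0)  # -50.0 for this vertex order
```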
def gen_quad_from_poly(self, poly):
"""
@@ -78,17 +84,20 @@ def gen_quad_from_poly(self, poly):
"""
point_num = poly.shape[0]
min_area_quad = np.zeros((4, 2), dtype=np.float32)
- rect = cv2.minAreaRect(poly.astype(
- np.int32)) # (center (x,y), (width, height), angle of rotation)
+ rect = cv2.minAreaRect(
+ poly.astype(np.int32)
+ ) # (center (x,y), (width, height), angle of rotation)
box = np.array(cv2.boxPoints(rect))
first_point_idx = 0
min_dist = 1e4
for i in range(4):
- dist = np.linalg.norm(box[(i + 0) % 4] - poly[0]) + \
- np.linalg.norm(box[(i + 1) % 4] - poly[point_num // 2 - 1]) + \
- np.linalg.norm(box[(i + 2) % 4] - poly[point_num // 2]) + \
- np.linalg.norm(box[(i + 3) % 4] - poly[-1])
+ dist = (
+ np.linalg.norm(box[(i + 0) % 4] - poly[0])
+ + np.linalg.norm(box[(i + 1) % 4] - poly[point_num // 2 - 1])
+ + np.linalg.norm(box[(i + 2) % 4] - poly[point_num // 2])
+ + np.linalg.norm(box[(i + 3) % 4] - poly[-1])
+ )
if dist < min_dist:
min_dist = dist
first_point_idx = i
@@ -118,20 +127,21 @@ def check_and_validate_polys(self, polys, tags, im_size):
quad = self.gen_quad_from_poly(poly)
p_area = self.quad_area(quad)
if abs(p_area) < 1:
- print('invalid poly')
+ print("invalid poly")
continue
if p_area > 0:
if tag == False:
- print('poly in wrong direction')
+ print("poly in wrong direction")
tag = True # reversed cases should be ignore
- poly = poly[(0, 15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2,
- 1), :]
+ poly = poly[(0, 15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1), :]
quad = quad[(0, 3, 2, 1), :]
- len_w = np.linalg.norm(quad[0] - quad[1]) + np.linalg.norm(quad[3] -
- quad[2])
- len_h = np.linalg.norm(quad[0] - quad[3]) + np.linalg.norm(quad[1] -
- quad[2])
+ len_w = np.linalg.norm(quad[0] - quad[1]) + np.linalg.norm(
+ quad[3] - quad[2]
+ )
+ len_h = np.linalg.norm(quad[0] - quad[3]) + np.linalg.norm(
+ quad[1] - quad[2]
+ )
hv_tag = 1
if len_w * 2.0 < len_h:
@@ -140,17 +150,11 @@ def check_and_validate_polys(self, polys, tags, im_size):
validated_polys.append(poly)
validated_tags.append(tag)
hv_tags.append(hv_tag)
- return np.array(validated_polys), np.array(validated_tags), np.array(
- hv_tags)
-
- def crop_area(self,
- im,
- polys,
- tags,
- hv_tags,
- txts,
- crop_background=False,
- max_tries=25):
+ return np.array(validated_polys), np.array(validated_tags), np.array(hv_tags)
+
+ def crop_area(
+ self, im, polys, tags, hv_tags, txts, crop_background=False, max_tries=25
+ ):
"""
make random crop from the input image
:param im:
@@ -169,10 +173,10 @@ def crop_area(self,
poly = np.round(poly, decimals=0).astype(np.int32)
minx = np.min(poly[:, 0])
maxx = np.max(poly[:, 0])
- w_array[minx + pad_w:maxx + pad_w] = 1
+ w_array[minx + pad_w : maxx + pad_w] = 1
miny = np.min(poly[:, 1])
maxy = np.max(poly[:, 1])
- h_array[miny + pad_h:maxy + pad_h] = 1
+ h_array[miny + pad_h : maxy + pad_h] = 1
# ensure the cropped area not across a text
h_axis = np.where(h_array == 0)[0]
w_axis = np.where(w_array == 0)[0]
@@ -189,14 +193,16 @@ def crop_area(self,
ymax = np.max(yy) - pad_h
ymin = np.clip(ymin, 0, h - 1)
ymax = np.clip(ymax, 0, h - 1)
- if xmax - xmin < self.min_crop_size or \
- ymax - ymin < self.min_crop_size:
+ if xmax - xmin < self.min_crop_size or ymax - ymin < self.min_crop_size:
continue
if polys.shape[0] != 0:
- poly_axis_in_area = (polys[:, :, 0] >= xmin) & (polys[:, :, 0] <= xmax) \
- & (polys[:, :, 1] >= ymin) & (polys[:, :, 1] <= ymax)
- selected_polys = np.where(
- np.sum(poly_axis_in_area, axis=1) == 4)[0]
+ poly_axis_in_area = (
+ (polys[:, :, 0] >= xmin)
+ & (polys[:, :, 0] <= xmax)
+ & (polys[:, :, 1] >= ymin)
+ & (polys[:, :, 1] <= ymax)
+ )
+ selected_polys = np.where(np.sum(poly_axis_in_area, axis=1) == 4)[0]
else:
selected_polys = []
if len(selected_polys) == 0:
@@ -206,11 +212,16 @@ def crop_area(self,
for selected_poly in selected_polys:
txts_tmp.append(txts[selected_poly])
txts = txts_tmp
- return im[ymin: ymax + 1, xmin: xmax + 1, :], \
- polys[selected_polys], tags[selected_polys], hv_tags[selected_polys], txts
+ return (
+ im[ymin : ymax + 1, xmin : xmax + 1, :],
+ polys[selected_polys],
+ tags[selected_polys],
+ hv_tags[selected_polys],
+ txts,
+ )
else:
continue
- im = im[ymin:ymax + 1, xmin:xmax + 1, :]
+ im = im[ymin : ymax + 1, xmin : xmax + 1, :]
polys = polys[selected_polys]
tags = tags[selected_polys]
hv_tags = hv_tags[selected_polys]
@@ -224,14 +235,16 @@ def crop_area(self,
return im, polys, tags, hv_tags, txts
- def fit_and_gather_tcl_points_v2(self,
- min_area_quad,
- poly,
- max_h,
- max_w,
- fixed_point_num=64,
- img_id=0,
- reference_height=3):
+ def fit_and_gather_tcl_points_v2(
+ self,
+ min_area_quad,
+ poly,
+ max_h,
+ max_w,
+ fixed_point_num=64,
+ img_id=0,
+ reference_height=3,
+ ):
"""
Find the center point of poly as key_points, then fit and gather.
"""
@@ -244,22 +257,21 @@ def fit_and_gather_tcl_points_v2(self,
tmp_image = np.zeros(
shape=(
max_h,
- max_w, ), dtype='float32')
- cv2.polylines(tmp_image, [np.array(key_point_xys).astype('int32')],
- False, 1.0)
+ max_w,
+ ),
+ dtype="float32",
+ )
+ cv2.polylines(tmp_image, [np.array(key_point_xys).astype("int32")], False, 1.0)
ys, xs = np.where(tmp_image > 0)
- xy_text = np.array(list(zip(xs, ys)), dtype='float32')
+ xy_text = np.array(list(zip(xs, ys)), dtype="float32")
- left_center_pt = (
- (min_area_quad[0] - min_area_quad[1]) / 2.0).reshape(1, 2)
- right_center_pt = (
- (min_area_quad[1] - min_area_quad[2]) / 2.0).reshape(1, 2)
+ left_center_pt = ((min_area_quad[0] - min_area_quad[1]) / 2.0).reshape(1, 2)
+ right_center_pt = ((min_area_quad[1] - min_area_quad[2]) / 2.0).reshape(1, 2)
proj_unit_vec = (right_center_pt - left_center_pt) / (
- np.linalg.norm(right_center_pt - left_center_pt) + 1e-6)
- proj_unit_vec_tile = np.tile(proj_unit_vec,
- (xy_text.shape[0], 1)) # (n, 2)
- left_center_pt_tile = np.tile(left_center_pt,
- (xy_text.shape[0], 1)) # (n, 2)
+ np.linalg.norm(right_center_pt - left_center_pt) + 1e-6
+ )
+ proj_unit_vec_tile = np.tile(proj_unit_vec, (xy_text.shape[0], 1)) # (n, 2)
+ left_center_pt_tile = np.tile(left_center_pt, (xy_text.shape[0], 1)) # (n, 2)
xy_text_to_left_center = xy_text - left_center_pt_tile
proj_value = np.sum(xy_text_to_left_center * proj_unit_vec_tile, axis=1)
xy_text = xy_text[np.argsort(proj_value)]
@@ -277,77 +289,86 @@ def fit_and_gather_tcl_points_v2(self,
keep = int(min(len(pos_info), fixed_point_num))
if np.random.rand() < 0.2 and reference_height >= 3:
dl = (np.random.rand(keep) - 0.5) * reference_height * 0.3
- random_float = np.array([1, 0]).reshape([1, 2]) * dl.reshape(
- [keep, 1])
+ random_float = np.array([1, 0]).reshape([1, 2]) * dl.reshape([keep, 1])
pos_info += random_float
pos_info[:, 0] = np.clip(pos_info[:, 0], 0, max_h - 1)
pos_info[:, 1] = np.clip(pos_info[:, 1], 0, max_w - 1)
# padding to fixed length
pos_l = np.zeros((self.tcl_len, 3), dtype=np.int32)
- pos_l[:, 0] = np.ones((self.tcl_len, )) * img_id
+ pos_l[:, 0] = np.ones((self.tcl_len,)) * img_id
pos_m = np.zeros((self.tcl_len, 1), dtype=np.float32)
pos_l[:keep, 1:] = np.round(pos_info).astype(np.int32)
pos_m[:keep] = 1.0
return pos_l, pos_m
- def fit_and_gather_tcl_points_v3(self,
- min_area_quad,
- poly,
- max_h,
- max_w,
- fixed_point_num=64,
- img_id=0,
- reference_height=3):
+ def fit_and_gather_tcl_points_v3(
+ self,
+ min_area_quad,
+ poly,
+ max_h,
+ max_w,
+ fixed_point_num=64,
+ img_id=0,
+ reference_height=3,
+ ):
"""
Find the center point of poly as key_points, then fit and gather.
"""
- det_mask = np.zeros((int(max_h / self.ds_ratio),
- int(max_w / self.ds_ratio))).astype(np.float32)
+ det_mask = np.zeros(
+ (int(max_h / self.ds_ratio), int(max_w / self.ds_ratio))
+ ).astype(np.float32)
# score_big_map
- cv2.fillPoly(det_mask,
- np.round(poly / self.ds_ratio).astype(np.int32), 1.0)
- det_mask = cv2.resize(
- det_mask, dsize=None, fx=self.ds_ratio, fy=self.ds_ratio)
- det_mask = np.array(det_mask > 1e-3, dtype='float32')
+ cv2.fillPoly(det_mask, np.round(poly / self.ds_ratio).astype(np.int32), 1.0)
+ det_mask = cv2.resize(det_mask, dsize=None, fx=self.ds_ratio, fy=self.ds_ratio)
+ det_mask = np.array(det_mask > 1e-3, dtype="float32")
f_direction = self.f_direction
skeleton_map = thin(det_mask.astype(np.uint8))
instance_count, instance_label_map = cv2.connectedComponents(
- skeleton_map.astype(np.uint8), connectivity=8)
+ skeleton_map.astype(np.uint8), connectivity=8
+ )
ys, xs = np.where(instance_label_map == 1)
pos_list = list(zip(ys, xs))
if len(pos_list) < 3:
return None
pos_list_sorted = sort_and_expand_with_direction_v2(
- pos_list, f_direction, det_mask)
+ pos_list, f_direction, det_mask
+ )
pos_list_sorted = np.array(pos_list_sorted)
length = len(pos_list_sorted) - 1
insert_num = 0
for index in range(length):
- stride_y = np.abs(pos_list_sorted[index + insert_num][0] -
- pos_list_sorted[index + 1 + insert_num][0])
- stride_x = np.abs(pos_list_sorted[index + insert_num][1] -
- pos_list_sorted[index + 1 + insert_num][1])
+ stride_y = np.abs(
+ pos_list_sorted[index + insert_num][0]
+ - pos_list_sorted[index + 1 + insert_num][0]
+ )
+ stride_x = np.abs(
+ pos_list_sorted[index + insert_num][1]
+ - pos_list_sorted[index + 1 + insert_num][1]
+ )
max_points = int(max(stride_x, stride_y))
- stride = (pos_list_sorted[index + insert_num] -
- pos_list_sorted[index + 1 + insert_num]) / (max_points)
+ stride = (
+ pos_list_sorted[index + insert_num]
+ - pos_list_sorted[index + 1 + insert_num]
+ ) / (max_points)
insert_num_temp = max_points - 1
for i in range(int(insert_num_temp)):
- insert_value = pos_list_sorted[index + insert_num] - (i + 1
- ) * stride
+ insert_value = pos_list_sorted[index + insert_num] - (i + 1) * stride
insert_index = index + i + 1 + insert_num
pos_list_sorted = np.insert(
- pos_list_sorted, insert_index, insert_value, axis=0)
+ pos_list_sorted, insert_index, insert_value, axis=0
+ )
insert_num += insert_num_temp
- pos_info = np.array(pos_list_sorted).reshape(-1, 2).astype(
- np.float32) # xy-> yx
+ pos_info = (
+ np.array(pos_list_sorted).reshape(-1, 2).astype(np.float32)
+ ) # xy-> yx
point_num = len(pos_info)
if point_num > fixed_point_num:
@@ -358,14 +379,15 @@ def fit_and_gather_tcl_points_v3(self,
pos_info = pos_info[keep_ids, :]
keep = int(min(len(pos_info), fixed_point_num))
- reference_width = (np.abs(poly[0, 0, 0] - poly[-1, 1, 0]) +
- np.abs(poly[0, 3, 0] - poly[-1, 2, 0])) // 2
+ reference_width = (
+ np.abs(poly[0, 0, 0] - poly[-1, 1, 0])
+ + np.abs(poly[0, 3, 0] - poly[-1, 2, 0])
+ ) // 2
if np.random.rand() < 1:
dh = (np.random.rand(keep) - 0.5) * reference_height
offset = np.random.rand() - 0.5
dw = np.array([[0, offset * reference_width * 0.2]])
- random_float_h = np.array([1, 0]).reshape([1, 2]) * dh.reshape(
- [keep, 1])
+ random_float_h = np.array([1, 0]).reshape([1, 2]) * dh.reshape([keep, 1])
random_float_w = dw.repeat(keep, axis=0)
pos_info += random_float_h
pos_info += random_float_w
@@ -374,61 +396,68 @@ def fit_and_gather_tcl_points_v3(self,
# padding to fixed length
pos_l = np.zeros((self.tcl_len, 3), dtype=np.int32)
- pos_l[:, 0] = np.ones((self.tcl_len, )) * img_id
+ pos_l[:, 0] = np.ones((self.tcl_len,)) * img_id
pos_m = np.zeros((self.tcl_len, 1), dtype=np.float32)
pos_l[:keep, 1:] = np.round(pos_info).astype(np.int32)
pos_m[:keep] = 1.0
return pos_l, pos_m
def generate_direction_map(self, poly_quads, n_char, direction_map):
- """
- """
+ """ """
width_list = []
height_list = []
for quad in poly_quads:
- quad_w = (np.linalg.norm(quad[0] - quad[1]) +
- np.linalg.norm(quad[2] - quad[3])) / 2.0
- quad_h = (np.linalg.norm(quad[0] - quad[3]) +
- np.linalg.norm(quad[2] - quad[1])) / 2.0
+ quad_w = (
+ np.linalg.norm(quad[0] - quad[1]) + np.linalg.norm(quad[2] - quad[3])
+ ) / 2.0
+ quad_h = (
+ np.linalg.norm(quad[0] - quad[3]) + np.linalg.norm(quad[2] - quad[1])
+ ) / 2.0
width_list.append(quad_w)
height_list.append(quad_h)
norm_width = max(sum(width_list) / n_char, 1.0)
average_height = max(sum(height_list) / len(height_list), 1.0)
k = 1
for quad in poly_quads:
- direct_vector_full = (
- (quad[1] + quad[2]) - (quad[0] + quad[3])) / 2.0
- direct_vector = direct_vector_full / (
- np.linalg.norm(direct_vector_full) + 1e-6) * norm_width
+ direct_vector_full = ((quad[1] + quad[2]) - (quad[0] + quad[3])) / 2.0
+ direct_vector = (
+ direct_vector_full
+ / (np.linalg.norm(direct_vector_full) + 1e-6)
+ * norm_width
+ )
direction_label = tuple(
- map(float,
- [direct_vector[0], direct_vector[1], 1.0 / average_height]))
- cv2.fillPoly(direction_map,
- quad.round().astype(np.int32)[np.newaxis, :, :],
- direction_label)
+ map(float, [direct_vector[0], direct_vector[1], 1.0 / average_height])
+ )
+ cv2.fillPoly(
+ direction_map,
+ quad.round().astype(np.int32)[np.newaxis, :, :],
+ direction_label,
+ )
k += 1
return direction_map
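
The per-quad direction label above is the head-to-tail vector, normalized and scaled to the average character width; a sketch with an invented quad:

```python
import numpy as np

# Sketch of the direction vector in generate_direction_map: midpoint of
# the right edge minus midpoint of the left edge, unit-normalized, then
# scaled to norm_width. Values invented for illustration.
quad = np.array([[0, 0], [40, 0], [40, 10], [0, 10]], dtype=np.float32)
norm_width = 8.0
direct = ((quad[1] + quad[2]) - (quad[0] + quad[3])) / 2.0
direct = direct / (np.linalg.norm(direct) + 1e-6) * norm_width
print(direct)  # ~[8., 0.] for this horizontal quad
```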
def calculate_average_height(self, poly_quads):
- """
- """
+ """ """
height_list = []
for quad in poly_quads:
- quad_h = (np.linalg.norm(quad[0] - quad[3]) +
- np.linalg.norm(quad[2] - quad[1])) / 2.0
+ quad_h = (
+ np.linalg.norm(quad[0] - quad[3]) + np.linalg.norm(quad[2] - quad[1])
+ ) / 2.0
height_list.append(quad_h)
average_height = max(sum(height_list) / len(height_list), 1.0)
return average_height
- def generate_tcl_ctc_label(self,
- h,
- w,
- polys,
- tags,
- text_strs,
- ds_ratio,
- tcl_ratio=0.3,
- shrink_ratio_of_width=0.15):
+ def generate_tcl_ctc_label(
+ self,
+ h,
+ w,
+ polys,
+ tags,
+ text_strs,
+ ds_ratio,
+ tcl_ratio=0.3,
+ shrink_ratio_of_width=0.15,
+ ):
"""
Generate polygon.
"""
@@ -436,25 +465,38 @@ def generate_tcl_ctc_label(self,
score_map_big = np.zeros(
(
h,
- w, ), dtype=np.float32)
+ w,
+ ),
+ dtype=np.float32,
+ )
h, w = int(h * ds_ratio), int(w * ds_ratio)
polys = polys * ds_ratio
score_map = np.zeros(
(
h,
- w, ), dtype=np.float32)
+ w,
+ ),
+ dtype=np.float32,
+ )
score_label_map = np.zeros(
(
h,
- w, ), dtype=np.float32)
+ w,
+ ),
+ dtype=np.float32,
+ )
tbo_map = np.zeros((h, w, 5), dtype=np.float32)
training_mask = np.ones(
(
h,
- w, ), dtype=np.float32)
+ w,
+ ),
+ dtype=np.float32,
+ )
direction_map = np.ones((h, w, 3)) * np.array([0, 0, 1]).reshape(
- [1, 1, 3]).astype(np.float32)
+ [1, 1, 3]
+ ).astype(np.float32)
label_idx = 0
score_label_map_text_label_list = []
@@ -466,26 +508,32 @@ def generate_tcl_ctc_label(self,
# generate min_area_quad
min_area_quad, center_point = self.gen_min_area_quad_from_poly(poly)
min_area_quad_h = 0.5 * (
- np.linalg.norm(min_area_quad[0] - min_area_quad[3]) +
- np.linalg.norm(min_area_quad[1] - min_area_quad[2]))
+ np.linalg.norm(min_area_quad[0] - min_area_quad[3])
+ + np.linalg.norm(min_area_quad[1] - min_area_quad[2])
+ )
min_area_quad_w = 0.5 * (
- np.linalg.norm(min_area_quad[0] - min_area_quad[1]) +
- np.linalg.norm(min_area_quad[2] - min_area_quad[3]))
-
- if min(min_area_quad_h, min_area_quad_w) < self.min_text_size * ds_ratio \
- or min(min_area_quad_h, min_area_quad_w) > self.max_text_size * ds_ratio:
+ np.linalg.norm(min_area_quad[0] - min_area_quad[1])
+ + np.linalg.norm(min_area_quad[2] - min_area_quad[3])
+ )
+
+ if (
+ min(min_area_quad_h, min_area_quad_w) < self.min_text_size * ds_ratio
+ or min(min_area_quad_h, min_area_quad_w) > self.max_text_size * ds_ratio
+ ):
continue
if tag:
- cv2.fillPoly(training_mask,
- poly.astype(np.int32)[np.newaxis, :, :], 0.15)
+ cv2.fillPoly(
+ training_mask, poly.astype(np.int32)[np.newaxis, :, :], 0.15
+ )
else:
text_label = text_strs[poly_idx]
- text_label = self.prepare_text_label(text_label,
- self.Lexicon_Table)
- text_label_index_list = [[self.Lexicon_Table.index(c_)]
- for c_ in text_label
- if c_ in self.Lexicon_Table]
+ text_label = self.prepare_text_label(text_label, self.Lexicon_Table)
+ text_label_index_list = [
+ [self.Lexicon_Table.index(c_)]
+ for c_ in text_label
+ if c_ in self.Lexicon_Table
+ ]
if len(text_label_index_list) < 1:
continue
@@ -496,42 +544,48 @@ def generate_tcl_ctc_label(self,
stcl_quads, quad_index = self.shrink_poly_along_width(
tcl_quads,
shrink_ratio_of_width=shrink_ratio_of_width,
- expand_height_ratio=1.0 / tcl_ratio)
+ expand_height_ratio=1.0 / tcl_ratio,
+ )
- cv2.fillPoly(score_map,
- np.round(stcl_quads).astype(np.int32), 1.0)
- cv2.fillPoly(score_map_big,
- np.round(stcl_quads / ds_ratio).astype(np.int32),
- 1.0)
+ cv2.fillPoly(score_map, np.round(stcl_quads).astype(np.int32), 1.0)
+ cv2.fillPoly(
+ score_map_big, np.round(stcl_quads / ds_ratio).astype(np.int32), 1.0
+ )
for idx, quad in enumerate(stcl_quads):
quad_mask = np.zeros((h, w), dtype=np.float32)
quad_mask = cv2.fillPoly(
quad_mask,
- np.round(quad[np.newaxis, :, :]).astype(np.int32), 1.0)
- tbo_map = self.gen_quad_tbo(poly_quads[quad_index[idx]],
- quad_mask, tbo_map)
+ np.round(quad[np.newaxis, :, :]).astype(np.int32),
+ 1.0,
+ )
+ tbo_map = self.gen_quad_tbo(
+ poly_quads[quad_index[idx]], quad_mask, tbo_map
+ )
# score label map and score_label_map_text_label_list for refine
if label_idx == 0:
- text_pos_list_ = [[len(self.Lexicon_Table)], ]
+ text_pos_list_ = [
+ [len(self.Lexicon_Table)],
+ ]
score_label_map_text_label_list.append(text_pos_list_)
label_idx += 1
- cv2.fillPoly(score_label_map,
- np.round(poly_quads).astype(np.int32), label_idx)
+ cv2.fillPoly(
+ score_label_map, np.round(poly_quads).astype(np.int32), label_idx
+ )
score_label_map_text_label_list.append(text_label_index_list)
# direction info, fix-me
n_char = len(text_label_index_list)
- direction_map = self.generate_direction_map(poly_quads, n_char,
- direction_map)
+ direction_map = self.generate_direction_map(
+ poly_quads, n_char, direction_map
+ )
# pos info
- average_shrink_height = self.calculate_average_height(
- stcl_quads)
+ average_shrink_height = self.calculate_average_height(stcl_quads)
- if self.point_gather_mode == 'align':
+ if self.point_gather_mode == "align":
self.f_direction = direction_map[:, :, :-1].copy()
pos_res = self.fit_and_gather_tcl_points_v3(
min_area_quad,
@@ -540,7 +594,8 @@ def generate_tcl_ctc_label(self,
max_w=w,
fixed_point_num=64,
img_id=self.img_id,
- reference_height=average_shrink_height)
+ reference_height=average_shrink_height,
+ )
if pos_res is None:
continue
pos_l, pos_m = pos_res[0], pos_res[1]
@@ -553,7 +608,8 @@ def generate_tcl_ctc_label(self,
max_w=w,
fixed_point_num=64,
img_id=self.img_id,
- reference_height=average_shrink_height)
+ reference_height=average_shrink_height,
+ )
label_l = text_label_index_list
if len(text_label_index_list) < 2:
@@ -565,11 +621,21 @@ def generate_tcl_ctc_label(self,
# use big score_map for smooth tcl lines
score_map_big_resized = cv2.resize(
- score_map_big, dsize=None, fx=ds_ratio, fy=ds_ratio)
- score_map = np.array(score_map_big_resized > 1e-3, dtype='float32')
-
- return score_map, score_label_map, tbo_map, direction_map, training_mask, \
- pos_list, pos_mask, label_list, score_label_map_text_label_list
+ score_map_big, dsize=None, fx=ds_ratio, fy=ds_ratio
+ )
+ score_map = np.array(score_map_big_resized > 1e-3, dtype="float32")
+
+ return (
+ score_map,
+ score_label_map,
+ tbo_map,
+ direction_map,
+ training_mask,
+ pos_list,
+ pos_mask,
+ label_list,
+ score_label_map_text_label_list,
+ )
def adjust_point(self, poly):
"""
@@ -589,7 +655,8 @@ def adjust_point(self, poly):
vector_1 = poly[0] - poly[1]
vector_2 = poly[1] - poly[2]
cos_theta = np.dot(vector_1, vector_2) / (
- np.linalg.norm(vector_1) * np.linalg.norm(vector_2) + 1e-6)
+ np.linalg.norm(vector_1) * np.linalg.norm(vector_2) + 1e-6
+ )
theta = np.arccos(np.round(cos_theta, decimals=4))
if abs(theta) > (70 / 180 * math.pi):
@@ -607,18 +674,21 @@ def gen_min_area_quad_from_poly(self, poly):
min_area_quad = poly
center_point = np.sum(poly, axis=0) / 4
else:
- rect = cv2.minAreaRect(poly.astype(
- np.int32)) # (center (x,y), (width, height), angle of rotation)
+ rect = cv2.minAreaRect(
+ poly.astype(np.int32)
+ ) # (center (x,y), (width, height), angle of rotation)
center_point = rect[0]
box = np.array(cv2.boxPoints(rect))
first_point_idx = 0
min_dist = 1e4
for i in range(4):
- dist = np.linalg.norm(box[(i + 0) % 4] - poly[0]) + \
- np.linalg.norm(box[(i + 1) % 4] - poly[point_num // 2 - 1]) + \
- np.linalg.norm(box[(i + 2) % 4] - poly[point_num // 2]) + \
- np.linalg.norm(box[(i + 3) % 4] - poly[-1])
+ dist = (
+ np.linalg.norm(box[(i + 0) % 4] - poly[0])
+ + np.linalg.norm(box[(i + 1) % 4] - poly[point_num // 2 - 1])
+ + np.linalg.norm(box[(i + 2) % 4] - poly[point_num // 2])
+ + np.linalg.norm(box[(i + 3) % 4] - poly[-1])
+ )
if dist < min_dist:
min_dist = dist
first_point_idx = i
@@ -628,23 +698,20 @@ def gen_min_area_quad_from_poly(self, poly):
return min_area_quad, center_point
- def shrink_quad_along_width(self,
- quad,
- begin_width_ratio=0.,
- end_width_ratio=1.):
+ def shrink_quad_along_width(self, quad, begin_width_ratio=0.0, end_width_ratio=1.0):
"""
Generate shrink_quad_along_width.
"""
ratio_pair = np.array(
- [[begin_width_ratio], [end_width_ratio]], dtype=np.float32)
+ [[begin_width_ratio], [end_width_ratio]], dtype=np.float32
+ )
p0_1 = quad[0] + (quad[1] - quad[0]) * ratio_pair
p3_2 = quad[3] + (quad[2] - quad[3]) * ratio_pair
return np.array([p0_1[0], p0_1[1], p3_2[1], p3_2[0]])
- def shrink_poly_along_width(self,
- quads,
- shrink_ratio_of_width,
- expand_height_ratio=1.0):
+ def shrink_poly_along_width(
+ self, quads, shrink_ratio_of_width, expand_height_ratio=1.0
+ ):
"""
shrink poly with given length.
"""
@@ -662,28 +729,30 @@ def get_cut_info(edge_len_list, cut_len):
upper_edge_list.append(upper_edge_len)
# length of left edge and right edge.
- left_length = np.linalg.norm(quads[0][0] - quads[0][
- 3]) * expand_height_ratio
- right_length = np.linalg.norm(quads[-1][1] - quads[-1][
- 2]) * expand_height_ratio
+ left_length = np.linalg.norm(quads[0][0] - quads[0][3]) * expand_height_ratio
+ right_length = np.linalg.norm(quads[-1][1] - quads[-1][2]) * expand_height_ratio
- shrink_length = min(left_length, right_length,
- sum(upper_edge_list)) * shrink_ratio_of_width
+ shrink_length = (
+ min(left_length, right_length, sum(upper_edge_list)) * shrink_ratio_of_width
+ )
# shrinking length
upper_len_left = shrink_length
upper_len_right = sum(upper_edge_list) - shrink_length
left_idx, left_ratio = get_cut_info(upper_edge_list, upper_len_left)
left_quad = self.shrink_quad_along_width(
- quads[left_idx], begin_width_ratio=left_ratio, end_width_ratio=1)
+ quads[left_idx], begin_width_ratio=left_ratio, end_width_ratio=1
+ )
right_idx, right_ratio = get_cut_info(upper_edge_list, upper_len_right)
right_quad = self.shrink_quad_along_width(
- quads[right_idx], begin_width_ratio=0, end_width_ratio=right_ratio)
+ quads[right_idx], begin_width_ratio=0, end_width_ratio=right_ratio
+ )
out_quad_list = []
if left_idx == right_idx:
out_quad_list.append(
- [left_quad[0], right_quad[1], right_quad[2], left_quad[3]])
+ [left_quad[0], right_quad[1], right_quad[2], left_quad[3]]
+ )
else:
out_quad_list.append(left_quad)
for idx in range(left_idx + 1, right_idx):
@@ -742,7 +811,7 @@ def line_cross_point(self, line1, line2):
d = a1 * b2 - a2 * b1
if d == 0:
- print('Cross point does not exist')
+ print("Cross point does not exist")
return np.array([0, 0], dtype=np.float32)
else:
x = (b1 * c2 - b2 * c1) / d
@@ -754,8 +823,7 @@ def quad2tcl(self, poly, ratio):
"""
Generate center line by poly clock-wise point. (4, 2)
"""
- ratio_pair = np.array(
- [[0.5 - ratio / 2], [0.5 + ratio / 2]], dtype=np.float32)
+ ratio_pair = np.array([[0.5 - ratio / 2], [0.5 + ratio / 2]], dtype=np.float32)
p0_3 = poly[0] + (poly[3] - poly[0]) * ratio_pair
p1_2 = poly[1] + (poly[2] - poly[1]) * ratio_pair
return np.array([p0_3[0], p1_2[0], p1_2[1], p0_3[1]])
@@ -764,14 +832,14 @@ def poly2tcl(self, poly, ratio):
"""
Generate center line by poly clock-wise point.
"""
- ratio_pair = np.array(
- [[0.5 - ratio / 2], [0.5 + ratio / 2]], dtype=np.float32)
+ ratio_pair = np.array([[0.5 - ratio / 2], [0.5 + ratio / 2]], dtype=np.float32)
tcl_poly = np.zeros_like(poly)
point_num = poly.shape[0]
for idx in range(point_num // 2):
- point_pair = poly[idx] + (poly[point_num - 1 - idx] - poly[idx]
- ) * ratio_pair
+ point_pair = (
+ poly[idx] + (poly[point_num - 1 - idx] - poly[idx]) * ratio_pair
+ )
tcl_poly[idx] = point_pair[0]
tcl_poly[point_num - 1 - idx] = point_pair[1]
return tcl_poly
@@ -784,10 +852,12 @@ def gen_quad_tbo(self, quad, tcl_mask, tbo_map):
up_line = self.line_cross_two_point(quad[0], quad[1])
lower_line = self.line_cross_two_point(quad[3], quad[2])
- quad_h = 0.5 * (np.linalg.norm(quad[0] - quad[3]) +
- np.linalg.norm(quad[1] - quad[2]))
- quad_w = 0.5 * (np.linalg.norm(quad[0] - quad[1]) +
- np.linalg.norm(quad[2] - quad[3]))
+ quad_h = 0.5 * (
+ np.linalg.norm(quad[0] - quad[3]) + np.linalg.norm(quad[1] - quad[2])
+ )
+ quad_w = 0.5 * (
+ np.linalg.norm(quad[0] - quad[1]) + np.linalg.norm(quad[2] - quad[3])
+ )
# average angle of left and right line.
angle = self.average_angle(quad)
@@ -824,8 +894,9 @@ def poly2quads(self, poly):
quad_num = point_num // 2 - 1
for idx in range(quad_num):
# reshape and adjust to clock-wise
- quad_list.append((np.array(point_pair_list)[[idx, idx + 1]]
- ).reshape(4, 2)[[0, 2, 3, 1]])
+ quad_list.append(
+ (np.array(point_pair_list)[[idx, idx + 1]]).reshape(4, 2)[[0, 2, 3, 1]]
+ )
return np.array(quad_list)
@@ -852,23 +923,30 @@ def rotate_im_poly(self, im, text_polys):
poly = []
for j in range(4): # 16->4
sx, sy = wordBB[j][0], wordBB[j][1]
- dx = math.cos(rot_angle) * (sx - cx) - math.sin(rot_angle) * (
- sy - cy) + ncx
- dy = math.sin(rot_angle) * (sx - cx) + math.cos(rot_angle) * (
- sy - cy) + ncy
+ dx = (
+ math.cos(rot_angle) * (sx - cx)
+ - math.sin(rot_angle) * (sy - cy)
+ + ncx
+ )
+ dy = (
+ math.sin(rot_angle) * (sx - cx)
+ + math.cos(rot_angle) * (sy - cy)
+ + ncy
+ )
poly.append([dx, dy])
dst_polys.append(poly)
return dst_im, np.array(dst_polys, dtype=np.float32)
def __call__(self, data):
input_size = 512
- im = data['image']
- text_polys = data['polys']
- text_tags = data['ignore_tags']
- text_strs = data['texts']
+ im = data["image"]
+ text_polys = data["polys"]
+ text_tags = data["ignore_tags"]
+ text_strs = data["texts"]
h, w, _ = im.shape
text_polys, text_tags, hv_tags = self.check_and_validate_polys(
- text_polys, text_tags, (h, w))
+ text_polys, text_tags, (h, w)
+ )
if text_polys.shape[0] <= 0:
return None
# set aspect ratio and keep area fix
@@ -909,12 +987,8 @@ def __call__(self, data):
# no background
im, text_polys, text_tags, hv_tags, text_strs = self.crop_area(
- im,
- text_polys,
- text_tags,
- hv_tags,
- text_strs,
- crop_background=False)
+ im, text_polys, text_tags, hv_tags, text_strs, crop_background=False
+ )
if text_polys.shape[0] == 0:
return None
@@ -927,7 +1001,8 @@ def __call__(self, data):
# resize image
std_ratio = float(input_size) / max(new_w, new_h)
rand_scales = np.array(
- [0.25, 0.375, 0.5, 0.625, 0.75, 0.875, 1.0, 1.0, 1.0, 1.0, 1.0])
+ [0.25, 0.375, 0.5, 0.625, 0.75, 0.875, 1.0, 1.0, 1.0, 1.0, 1.0]
+ )
rz_scale = std_ratio * np.random.choice(rand_scales)
im = cv2.resize(im, dsize=None, fx=rz_scale, fy=rz_scale)
text_polys[:, :, 0] *= rz_scale
@@ -966,16 +1041,23 @@ def __call__(self, data):
sw = int(np.random.rand() * del_w)
# Padding
- im_padded[sh:sh + new_h, sw:sw + new_w, :] = im.copy()
+ im_padded[sh : sh + new_h, sw : sw + new_w, :] = im.copy()
text_polys[:, :, 0] += sw
text_polys[:, :, 1] += sh
- score_map, score_label_map, border_map, direction_map, training_mask, \
- pos_list, pos_mask, label_list, score_label_map_text_label = self.generate_tcl_ctc_label(input_size,
- input_size,
- text_polys,
- text_tags,
- text_strs, 0.25)
+ (
+ score_map,
+ score_label_map,
+ border_map,
+ direction_map,
+ training_mask,
+ pos_list,
+ pos_mask,
+ label_list,
+ score_label_map_text_label,
+ ) = self.generate_tcl_ctc_label(
+ input_size, input_size, text_polys, text_tags, text_strs, 0.25
+ )
if len(label_list) <= 0: # eliminate negative samples
return None
pos_list_temp = np.zeros([64, 3])
@@ -985,7 +1067,7 @@ def __call__(self, data):
for i, label in enumerate(label_list):
n = len(label)
if n > self.max_text_length:
- label_list[i] = label[:self.max_text_length]
+ label_list[i] = label[: self.max_text_length]
continue
while n < self.max_text_length:
label.append([self.pad_num])
@@ -1009,9 +1091,9 @@ def __call__(self, data):
im_padded[:, :, 2] -= 0.485 * 255
im_padded[:, :, 1] -= 0.456 * 255
im_padded[:, :, 0] -= 0.406 * 255
- im_padded[:, :, 2] /= (255.0 * 0.229)
- im_padded[:, :, 1] /= (255.0 * 0.224)
- im_padded[:, :, 0] /= (255.0 * 0.225)
+ im_padded[:, :, 2] /= 255.0 * 0.229
+ im_padded[:, :, 1] /= 255.0 * 0.224
+ im_padded[:, :, 0] /= 255.0 * 0.225
im_padded = im_padded.transpose((2, 0, 1))
images = im_padded[::-1, :, :]
tcl_maps = score_map[np.newaxis, :, :]
@@ -1022,13 +1104,13 @@ def __call__(self, data):
pos_list = np.array(pos_list)
pos_mask = np.array(pos_mask)
label_list = np.array(label_list)
- data['images'] = images
- data['tcl_maps'] = tcl_maps
- data['tcl_label_maps'] = tcl_label_maps
- data['border_maps'] = border_maps
- data['direction_maps'] = direction_maps
- data['training_masks'] = training_masks
- data['label_list'] = label_list
- data['pos_list'] = pos_list
- data['pos_mask'] = pos_mask
+ data["images"] = images
+ data["tcl_maps"] = tcl_maps
+ data["tcl_label_maps"] = tcl_label_maps
+ data["border_maps"] = border_maps
+ data["direction_maps"] = direction_maps
+ data["training_masks"] = training_masks
+ data["label_list"] = label_list
+ data["pos_list"] = pos_list
+ data["pos_mask"] = pos_mask
return data
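For context on the hunk above: the per-channel arithmetic on `im_padded` is standard ImageNet normalization applied in BGR channel order, followed by an HWC-to-CHW transpose and a channel flip to RGB. A minimal standalone sketch (the dummy input shape is an assumption, not code from the diff):

```python
import numpy as np

# Dummy HWC float32 BGR image standing in for `im_padded` (assumption).
im = np.random.rand(512, 512, 3).astype("float32") * 255
mean = np.array([0.406, 0.456, 0.485]) * 255  # B, G, R channel means
std = np.array([0.225, 0.224, 0.229]) * 255   # B, G, R channel stds
im = (im - mean) / std        # same effect as the per-channel -= / /= lines
im = im.transpose((2, 0, 1))  # HWC -> CHW
images = im[::-1, :, :]       # reverse channel axis: BGR -> RGB
```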
diff --git a/ppocr/data/imaug/randaugment.py b/ppocr/data/imaug/randaugment.py
index 56f114d2f6..5f29f325d5 100644
--- a/ppocr/data/imaug/randaugment.py
+++ b/ppocr/data/imaug/randaugment.py
@@ -24,11 +24,7 @@
class RawRandAugment(object):
- def __init__(self,
- num_layers=2,
- magnitude=5,
- fillcolor=(128, 128, 128),
- **kwargs):
+ def __init__(self, num_layers=2, magnitude=5, fillcolor=(128, 128, 128), **kwargs):
self.num_layers = num_layers
self.magnitude = magnitude
self.max_level = 10
@@ -48,16 +44,16 @@ def __init__(self,
"brightness": 0.9 * abso_level,
"autocontrast": 0,
"equalize": 0,
- "invert": 0
+ "invert": 0,
}
# from https://stackoverflow.com/questions/5252170/
# specify-image-filling-color-when-rotating-in-python-with-pil-and-setting-expand
def rotate_with_fill(img, magnitude):
rot = img.convert("RGBA").rotate(magnitude)
- return Image.composite(rot,
- Image.new("RGBA", rot.size, (128, ) * 4),
- rot).convert(img.mode)
+ return Image.composite(
+ rot, Image.new("RGBA", rot.size, (128,) * 4), rot
+ ).convert(img.mode)
rnd_ch_op = random.choice
@@ -67,43 +63,45 @@ def rotate_with_fill(img, magnitude):
Image.AFFINE,
(1, magnitude * rnd_ch_op([-1, 1]), 0, 0, 1, 0),
Image.BICUBIC,
- fillcolor=fillcolor),
+ fillcolor=fillcolor,
+ ),
"shearY": lambda img, magnitude: img.transform(
img.size,
Image.AFFINE,
(1, 0, 0, magnitude * rnd_ch_op([-1, 1]), 1, 0),
Image.BICUBIC,
- fillcolor=fillcolor),
+ fillcolor=fillcolor,
+ ),
"translateX": lambda img, magnitude: img.transform(
img.size,
Image.AFFINE,
(1, 0, magnitude * img.size[0] * rnd_ch_op([-1, 1]), 0, 1, 0),
- fillcolor=fillcolor),
+ fillcolor=fillcolor,
+ ),
"translateY": lambda img, magnitude: img.transform(
img.size,
Image.AFFINE,
(1, 0, 0, 0, 1, magnitude * img.size[1] * rnd_ch_op([-1, 1])),
- fillcolor=fillcolor),
+ fillcolor=fillcolor,
+ ),
"rotate": lambda img, magnitude: rotate_with_fill(img, magnitude),
"color": lambda img, magnitude: ImageEnhance.Color(img).enhance(
- 1 + magnitude * rnd_ch_op([-1, 1])),
- "posterize": lambda img, magnitude:
- ImageOps.posterize(img, magnitude),
- "solarize": lambda img, magnitude:
- ImageOps.solarize(img, magnitude),
- "contrast": lambda img, magnitude:
- ImageEnhance.Contrast(img).enhance(
- 1 + magnitude * rnd_ch_op([-1, 1])),
- "sharpness": lambda img, magnitude:
- ImageEnhance.Sharpness(img).enhance(
- 1 + magnitude * rnd_ch_op([-1, 1])),
- "brightness": lambda img, magnitude:
- ImageEnhance.Brightness(img).enhance(
- 1 + magnitude * rnd_ch_op([-1, 1])),
- "autocontrast": lambda img, magnitude:
- ImageOps.autocontrast(img),
+ 1 + magnitude * rnd_ch_op([-1, 1])
+ ),
+ "posterize": lambda img, magnitude: ImageOps.posterize(img, magnitude),
+ "solarize": lambda img, magnitude: ImageOps.solarize(img, magnitude),
+ "contrast": lambda img, magnitude: ImageEnhance.Contrast(img).enhance(
+ 1 + magnitude * rnd_ch_op([-1, 1])
+ ),
+ "sharpness": lambda img, magnitude: ImageEnhance.Sharpness(img).enhance(
+ 1 + magnitude * rnd_ch_op([-1, 1])
+ ),
+ "brightness": lambda img, magnitude: ImageEnhance.Brightness(img).enhance(
+ 1 + magnitude * rnd_ch_op([-1, 1])
+ ),
+ "autocontrast": lambda img, magnitude: ImageOps.autocontrast(img),
"equalize": lambda img, magnitude: ImageOps.equalize(img),
- "invert": lambda img, magnitude: ImageOps.invert(img)
+ "invert": lambda img, magnitude: ImageOps.invert(img),
}
def __call__(self, img):
@@ -115,7 +113,7 @@ def __call__(self, img):
class RandAugment(RawRandAugment):
- """ RandAugment wrapper to auto fit different img types """
+ """RandAugment wrapper to auto fit different img types"""
def __init__(self, prob=0.5, *args, **kwargs):
self.prob = prob
@@ -127,7 +125,7 @@ def __init__(self, prob=0.5, *args, **kwargs):
def __call__(self, data):
if np.random.rand() > self.prob:
return data
- img = data['image']
+ img = data["image"]
if not isinstance(img, Image.Image):
img = np.ascontiguousarray(img)
img = Image.fromarray(img)
@@ -139,5 +137,5 @@ def __call__(self, data):
if isinstance(img, Image.Image):
img = np.asarray(img)
- data['image'] = img
+ data["image"] = img
return data
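A minimal usage sketch of the reformatted wrapper above (assumptions: the repository root is on `PYTHONPATH`, and the input is an HWC uint8 array as the training pipeline provides). The wrapper converts numpy to PIL, applies `RawRandAugment` with probability `prob`, and converts the result back:

```python
import numpy as np
from ppocr.data.imaug.randaugment import RandAugment

# num_layers and magnitude are forwarded to RawRandAugment via **kwargs.
aug = RandAugment(prob=0.5, num_layers=2, magnitude=5)
data = {"image": np.random.randint(0, 256, (48, 192, 3), dtype=np.uint8)}
data = aug(data)  # data["image"] comes back as a numpy array
```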
diff --git a/ppocr/data/imaug/random_crop_data.py b/ppocr/data/imaug/random_crop_data.py
index 64aa110de4..f3f2e3eeac 100644
--- a/ppocr/data/imaug/random_crop_data.py
+++ b/ppocr/data/imaug/random_crop_data.py
@@ -108,13 +108,15 @@ def crop_area(im, text_polys, min_crop_side_ratio, max_tries):
else:
ymin, ymax = random_select(h_axis, h)
- if xmax - xmin < min_crop_side_ratio * w or ymax - ymin < min_crop_side_ratio * h:
+ if (
+ xmax - xmin < min_crop_side_ratio * w
+ or ymax - ymin < min_crop_side_ratio * h
+ ):
# area too small
continue
num_poly_in_rect = 0
for poly in text_polys:
- if not is_poly_outside_rect(poly, xmin, ymin, xmax - xmin,
- ymax - ymin):
+ if not is_poly_outside_rect(poly, xmin, ymin, xmax - xmin, ymax - ymin):
num_poly_in_rect += 1
break
@@ -125,28 +127,29 @@ def crop_area(im, text_polys, min_crop_side_ratio, max_tries):
class EastRandomCropData(object):
- def __init__(self,
- size=(640, 640),
- max_tries=10,
- min_crop_side_ratio=0.1,
- keep_ratio=True,
- **kwargs):
+ def __init__(
+ self,
+ size=(640, 640),
+ max_tries=10,
+ min_crop_side_ratio=0.1,
+ keep_ratio=True,
+ **kwargs
+ ):
self.size = size
self.max_tries = max_tries
self.min_crop_side_ratio = min_crop_side_ratio
self.keep_ratio = keep_ratio
def __call__(self, data):
- img = data['image']
- text_polys = data['polys']
- ignore_tags = data['ignore_tags']
- texts = data['texts']
- all_care_polys = [
- text_polys[i] for i, tag in enumerate(ignore_tags) if not tag
- ]
+ img = data["image"]
+ text_polys = data["polys"]
+ ignore_tags = data["ignore_tags"]
+ texts = data["texts"]
+ all_care_polys = [text_polys[i] for i, tag in enumerate(ignore_tags) if not tag]
         # compute the crop region
crop_x, crop_y, crop_w, crop_h = crop_area(
- img, all_care_polys, self.min_crop_side_ratio, self.max_tries)
+ img, all_care_polys, self.min_crop_side_ratio, self.max_tries
+ )
         # crop the image, padding to keep the aspect ratio
scale_w = self.size[0] / crop_w
scale_h = self.size[1] / crop_h
@@ -154,15 +157,16 @@ def __call__(self, data):
h = int(crop_h * scale)
w = int(crop_w * scale)
if self.keep_ratio:
- padimg = np.zeros((self.size[1], self.size[0], img.shape[2]),
- img.dtype)
+ padimg = np.zeros((self.size[1], self.size[0], img.shape[2]), img.dtype)
padimg[:h, :w] = cv2.resize(
- img[crop_y:crop_y + crop_h, crop_x:crop_x + crop_w], (w, h))
+ img[crop_y : crop_y + crop_h, crop_x : crop_x + crop_w], (w, h)
+ )
img = padimg
else:
img = cv2.resize(
- img[crop_y:crop_y + crop_h, crop_x:crop_x + crop_w],
- tuple(self.size))
+ img[crop_y : crop_y + crop_h, crop_x : crop_x + crop_w],
+ tuple(self.size),
+ )
         # crop the text boxes
text_polys_crop = []
ignore_tags_crop = []
@@ -173,10 +177,10 @@ def __call__(self, data):
text_polys_crop.append(poly)
ignore_tags_crop.append(tag)
texts_crop.append(text)
- data['image'] = img
- data['polys'] = np.array(text_polys_crop)
- data['ignore_tags'] = ignore_tags_crop
- data['texts'] = texts_crop
+ data["image"] = img
+ data["polys"] = np.array(text_polys_crop)
+ data["ignore_tags"] = ignore_tags_crop
+ data["texts"] = texts_crop
return data
@@ -188,7 +192,7 @@ def __init__(self, size, main_key, crop_keys, p=3 / 8, **kwargs):
self.p = p
def __call__(self, data):
- image = data['image']
+ image = data["image"]
h, w = image.shape[0:2]
th, tw = self.size
@@ -217,17 +221,17 @@ def __call__(self, data):
if k in self.crop_keys:
if len(data[k].shape) == 3:
if np.argmin(data[k].shape) == 0:
- img = data[k][:, i:i + th, j:j + tw]
+ img = data[k][:, i : i + th, j : j + tw]
if img.shape[1] != img.shape[2]:
a = 1
elif np.argmin(data[k].shape) == 2:
- img = data[k][i:i + th, j:j + tw, :]
+ img = data[k][i : i + th, j : j + tw, :]
if img.shape[1] != img.shape[0]:
a = 1
else:
img = data[k]
else:
- img = data[k][i:i + th, j:j + tw]
+ img = data[k][i : i + th, j : j + tw]
if img.shape[0] != img.shape[1]:
a = 1
data[k] = img
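The `keep_ratio` branch reformatted above scales the crop by its limiting side and pastes it into a zero canvas of the target size. A standalone sketch of that logic (the helper name and shapes are illustrative, not part of the diff):

```python
import cv2
import numpy as np

def pad_resize_crop(img, crop_x, crop_y, crop_w, crop_h, size=(640, 640)):
    # scale chosen so the resized crop fits inside the target canvas
    scale = min(size[0] / crop_w, size[1] / crop_h)
    h, w = int(crop_h * scale), int(crop_w * scale)
    canvas = np.zeros((size[1], size[0], img.shape[2]), img.dtype)
    canvas[:h, :w] = cv2.resize(
        img[crop_y : crop_y + crop_h, crop_x : crop_x + crop_w], (w, h)
    )
    return canvas
```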
diff --git a/ppocr/data/imaug/rec_img_aug.py b/ppocr/data/imaug/rec_img_aug.py
index 0bf15114d5..a79436021e 100644
--- a/ppocr/data/imaug/rec_img_aug.py
+++ b/ppocr/data/imaug/rec_img_aug.py
@@ -20,26 +20,36 @@
from PIL import Image
import PIL
from .text_image_aug import tia_perspective, tia_stretch, tia_distort
-from .abinet_aug import CVGeometry, CVDeterioration, CVColorJitter, SVTRGeometry, SVTRDeterioration, ParseQDeterioration
+from .abinet_aug import (
+ CVGeometry,
+ CVDeterioration,
+ CVColorJitter,
+ SVTRGeometry,
+ SVTRDeterioration,
+ ParseQDeterioration,
+)
from paddle.vision.transforms import Compose
class RecAug(object):
- def __init__(self,
- tia_prob=0.4,
- crop_prob=0.4,
- reverse_prob=0.4,
- noise_prob=0.4,
- jitter_prob=0.4,
- blur_prob=0.4,
- hsv_aug_prob=0.4,
- **kwargs):
+ def __init__(
+ self,
+ tia_prob=0.4,
+ crop_prob=0.4,
+ reverse_prob=0.4,
+ noise_prob=0.4,
+ jitter_prob=0.4,
+ blur_prob=0.4,
+ hsv_aug_prob=0.4,
+ **kwargs
+ ):
self.tia_prob = tia_prob
- self.bda = BaseDataAugmentation(crop_prob, reverse_prob, noise_prob,
- jitter_prob, blur_prob, hsv_aug_prob)
+ self.bda = BaseDataAugmentation(
+ crop_prob, reverse_prob, noise_prob, jitter_prob, blur_prob, hsv_aug_prob
+ )
def __call__(self, data):
- img = data['image']
+ img = data["image"]
h, w, _ = img.shape
# tia
@@ -50,20 +60,22 @@ def __call__(self, data):
img = tia_perspective(img)
# bda
- data['image'] = img
+ data["image"] = img
data = self.bda(data)
return data
class BaseDataAugmentation(object):
- def __init__(self,
- crop_prob=0.4,
- reverse_prob=0.4,
- noise_prob=0.4,
- jitter_prob=0.4,
- blur_prob=0.4,
- hsv_aug_prob=0.4,
- **kwargs):
+ def __init__(
+ self,
+ crop_prob=0.4,
+ reverse_prob=0.4,
+ noise_prob=0.4,
+ jitter_prob=0.4,
+ blur_prob=0.4,
+ hsv_aug_prob=0.4,
+ **kwargs
+ ):
self.crop_prob = crop_prob
self.reverse_prob = reverse_prob
self.noise_prob = noise_prob
@@ -74,7 +86,7 @@ def __init__(self,
self.fil = cv2.getGaussianKernel(ksize=5, sigma=1, ktype=cv2.CV_32F)
def __call__(self, data):
- img = data['image']
+ img = data["image"]
h, w, _ = img.shape
if random.random() <= self.crop_prob and h >= 20 and w >= 20:
@@ -96,47 +108,51 @@ def __call__(self, data):
if random.random() <= self.reverse_prob:
img = 255 - img
- data['image'] = img
+ data["image"] = img
return data
class ABINetRecAug(object):
- def __init__(self,
- geometry_p=0.5,
- deterioration_p=0.25,
- colorjitter_p=0.25,
- **kwargs):
- self.transforms = Compose([
- CVGeometry(
- degrees=45,
- translate=(0.0, 0.0),
- scale=(0.5, 2.),
- shear=(45, 15),
- distortion=0.5,
- p=geometry_p), CVDeterioration(
- var=20, degrees=6, factor=4, p=deterioration_p),
- CVColorJitter(
- brightness=0.5,
- contrast=0.5,
- saturation=0.5,
- hue=0.1,
- p=colorjitter_p)
- ])
+ def __init__(
+ self, geometry_p=0.5, deterioration_p=0.25, colorjitter_p=0.25, **kwargs
+ ):
+ self.transforms = Compose(
+ [
+ CVGeometry(
+ degrees=45,
+ translate=(0.0, 0.0),
+ scale=(0.5, 2.0),
+ shear=(45, 15),
+ distortion=0.5,
+ p=geometry_p,
+ ),
+ CVDeterioration(var=20, degrees=6, factor=4, p=deterioration_p),
+ CVColorJitter(
+ brightness=0.5,
+ contrast=0.5,
+ saturation=0.5,
+ hue=0.1,
+ p=colorjitter_p,
+ ),
+ ]
+ )
def __call__(self, data):
- img = data['image']
+ img = data["image"]
img = self.transforms(img)
- data['image'] = img
+ data["image"] = img
return data
class RecConAug(object):
- def __init__(self,
- prob=0.5,
- image_shape=(32, 320, 3),
- max_text_length=25,
- ext_data_num=1,
- **kwargs):
+ def __init__(
+ self,
+ prob=0.5,
+ image_shape=(32, 320, 3),
+ max_text_length=25,
+ ext_data_num=1,
+ **kwargs
+ ):
self.ext_data_num = ext_data_num
self.prob = prob
self.max_text_length = max_text_length
@@ -144,15 +160,17 @@ def __init__(self,
self.max_wh_ratio = self.image_shape[1] / self.image_shape[0]
def merge_ext_data(self, data, ext_data):
- ori_w = round(data['image'].shape[1] / data['image'].shape[0] *
- self.image_shape[0])
- ext_w = round(ext_data['image'].shape[1] / ext_data['image'].shape[0] *
- self.image_shape[0])
- data['image'] = cv2.resize(data['image'], (ori_w, self.image_shape[0]))
- ext_data['image'] = cv2.resize(ext_data['image'],
- (ext_w, self.image_shape[0]))
- data['image'] = np.concatenate(
- [data['image'], ext_data['image']], axis=1)
+ ori_w = round(
+ data["image"].shape[1] / data["image"].shape[0] * self.image_shape[0]
+ )
+ ext_w = round(
+ ext_data["image"].shape[1]
+ / ext_data["image"].shape[0]
+ * self.image_shape[0]
+ )
+ data["image"] = cv2.resize(data["image"], (ori_w, self.image_shape[0]))
+ ext_data["image"] = cv2.resize(ext_data["image"], (ext_w, self.image_shape[0]))
+ data["image"] = np.concatenate([data["image"], ext_data["image"]], axis=1)
data["label"] += ext_data["label"]
return data
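The width arithmetic in `merge_ext_data` above rescales both crops to the target height while preserving aspect ratio, then concatenates along the width axis. A toy sketch with hypothetical shapes:

```python
import cv2
import numpy as np

image_h = 32
a = np.zeros((40, 120, 3), np.uint8)  # dummy crops (assumption)
b = np.zeros((20, 90, 3), np.uint8)
wa = round(a.shape[1] / a.shape[0] * image_h)  # 96
wb = round(b.shape[1] / b.shape[0] * image_h)  # 144
merged = np.concatenate(
    [cv2.resize(a, (wa, image_h)), cv2.resize(b, (wb, image_h))], axis=1
)
assert merged.shape == (32, 240, 3)
```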
@@ -161,11 +179,12 @@ def __call__(self, data):
if rnd_num > self.prob:
return data
for idx, ext_data in enumerate(data["ext_data"]):
- if len(data["label"]) + len(ext_data[
- "label"]) > self.max_text_length:
+ if len(data["label"]) + len(ext_data["label"]) > self.max_text_length:
break
- concat_ratio = data['image'].shape[1] / data['image'].shape[
- 0] + ext_data['image'].shape[1] / ext_data['image'].shape[0]
+ concat_ratio = (
+ data["image"].shape[1] / data["image"].shape[0]
+ + ext_data["image"].shape[1] / ext_data["image"].shape[0]
+ )
if concat_ratio > self.max_wh_ratio:
break
data = self.merge_ext_data(data, ext_data)
@@ -174,86 +193,104 @@ def __call__(self, data):
class SVTRRecAug(object):
- def __init__(self,
- aug_type=0,
- geometry_p=0.5,
- deterioration_p=0.25,
- colorjitter_p=0.25,
- **kwargs):
- self.transforms = Compose([
- SVTRGeometry(
- aug_type=aug_type,
- degrees=45,
- translate=(0.0, 0.0),
- scale=(0.5, 2.),
- shear=(45, 15),
- distortion=0.5,
- p=geometry_p), SVTRDeterioration(
- var=20, degrees=6, factor=4, p=deterioration_p),
- CVColorJitter(
- brightness=0.5,
- contrast=0.5,
- saturation=0.5,
- hue=0.1,
- p=colorjitter_p)
- ])
+ def __init__(
+ self,
+ aug_type=0,
+ geometry_p=0.5,
+ deterioration_p=0.25,
+ colorjitter_p=0.25,
+ **kwargs
+ ):
+ self.transforms = Compose(
+ [
+ SVTRGeometry(
+ aug_type=aug_type,
+ degrees=45,
+ translate=(0.0, 0.0),
+ scale=(0.5, 2.0),
+ shear=(45, 15),
+ distortion=0.5,
+ p=geometry_p,
+ ),
+ SVTRDeterioration(var=20, degrees=6, factor=4, p=deterioration_p),
+ CVColorJitter(
+ brightness=0.5,
+ contrast=0.5,
+ saturation=0.5,
+ hue=0.1,
+ p=colorjitter_p,
+ ),
+ ]
+ )
def __call__(self, data):
- img = data['image']
+ img = data["image"]
img = self.transforms(img)
- data['image'] = img
+ data["image"] = img
return data
+
class ParseQRecAug(object):
- def __init__(self,
- aug_type=0,
- geometry_p=0.5,
- deterioration_p=0.25,
- colorjitter_p=0.25,
- **kwargs):
- self.transforms = Compose([
- SVTRGeometry(
- aug_type=aug_type,
- degrees=45,
- translate=(0.0, 0.0),
- scale=(0.5, 2.),
- shear=(45, 15),
- distortion=0.5,
- p=geometry_p), ParseQDeterioration(
- var=20, degrees=6, lam=20, radius=2.0, factor=4, p=deterioration_p),
- CVColorJitter(
- brightness=0.5,
- contrast=0.5,
- saturation=0.5,
- hue=0.1,
- p=colorjitter_p)
- ])
+ def __init__(
+ self,
+ aug_type=0,
+ geometry_p=0.5,
+ deterioration_p=0.25,
+ colorjitter_p=0.25,
+ **kwargs
+ ):
+ self.transforms = Compose(
+ [
+ SVTRGeometry(
+ aug_type=aug_type,
+ degrees=45,
+ translate=(0.0, 0.0),
+ scale=(0.5, 2.0),
+ shear=(45, 15),
+ distortion=0.5,
+ p=geometry_p,
+ ),
+ ParseQDeterioration(
+ var=20, degrees=6, lam=20, radius=2.0, factor=4, p=deterioration_p
+ ),
+ CVColorJitter(
+ brightness=0.5,
+ contrast=0.5,
+ saturation=0.5,
+ hue=0.1,
+ p=colorjitter_p,
+ ),
+ ]
+ )
def __call__(self, data):
- img = data['image']
+ img = data["image"]
img = self.transforms(img)
- data['image'] = img
+ data["image"] = img
return data
+
class ClsResizeImg(object):
def __init__(self, image_shape, **kwargs):
self.image_shape = image_shape
def __call__(self, data):
- img = data['image']
+ img = data["image"]
norm_img, _ = resize_norm_img(img, self.image_shape)
- data['image'] = norm_img
+ data["image"] = norm_img
return data
class RecResizeImg(object):
- def __init__(self,
- image_shape,
- infer_mode=False,
- eval_mode=False,
- character_dict_path='./ppocr/utils/ppocr_keys_v1.txt',
- padding=True,
- **kwargs):
+ def __init__(
+ self,
+ image_shape,
+ infer_mode=False,
+ eval_mode=False,
+ character_dict_path="./ppocr/utils/ppocr_keys_v1.txt",
+ padding=True,
+ **kwargs
+ ):
self.image_shape = image_shape
self.infer_mode = infer_mode
self.eval_mode = eval_mode
@@ -261,39 +298,37 @@ def __init__(self,
self.padding = padding
def __call__(self, data):
- img = data['image']
- if self.eval_mode or (self.infer_mode and
- self.character_dict_path is not None):
- norm_img, valid_ratio = resize_norm_img_chinese(img,
- self.image_shape)
+ img = data["image"]
+ if self.eval_mode or (self.infer_mode and self.character_dict_path is not None):
+ norm_img, valid_ratio = resize_norm_img_chinese(img, self.image_shape)
else:
- norm_img, valid_ratio = resize_norm_img(img, self.image_shape,
- self.padding)
- data['image'] = norm_img
- data['valid_ratio'] = valid_ratio
+ norm_img, valid_ratio = resize_norm_img(img, self.image_shape, self.padding)
+ data["image"] = norm_img
+ data["valid_ratio"] = valid_ratio
return data
class VLRecResizeImg(object):
- def __init__(self,
- image_shape,
- infer_mode=False,
- character_dict_path='./ppocr/utils/ppocr_keys_v1.txt',
- padding=True,
- **kwargs):
+ def __init__(
+ self,
+ image_shape,
+ infer_mode=False,
+ character_dict_path="./ppocr/utils/ppocr_keys_v1.txt",
+ padding=True,
+ **kwargs
+ ):
self.image_shape = image_shape
self.infer_mode = infer_mode
self.character_dict_path = character_dict_path
self.padding = padding
def __call__(self, data):
- img = data['image']
+ img = data["image"]
imgC, imgH, imgW = self.image_shape
- resized_image = cv2.resize(
- img, (imgW, imgH), interpolation=cv2.INTER_LINEAR)
+ resized_image = cv2.resize(img, (imgW, imgH), interpolation=cv2.INTER_LINEAR)
resized_w = imgW
- resized_image = resized_image.astype('float32')
+ resized_image = resized_image.astype("float32")
if self.image_shape[0] == 1:
resized_image = resized_image / 255
norm_img = resized_image[np.newaxis, :]
@@ -301,8 +336,8 @@ def __call__(self, data):
norm_img = resized_image.transpose((2, 0, 1)) / 255
valid_ratio = min(1.0, float(resized_w / imgW))
- data['image'] = norm_img
- data['valid_ratio'] = valid_ratio
+ data["image"] = norm_img
+ data["valid_ratio"] = valid_ratio
return data
@@ -324,12 +359,13 @@ def __init__(self, image_shape, padding=True, interpolation=1, **kwargs):
raise Exception("Unsupported interpolation type !!!")
def __call__(self, data):
- img = data['image']
+ img = data["image"]
img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
norm_img, valid_ratio = resize_norm_img(
- img, self.image_shape, self.padding, self.interpolation)
- data['image'] = norm_img
- data['valid_ratio'] = valid_ratio
+ img, self.image_shape, self.padding, self.interpolation
+ )
+ data["image"] = norm_img
+ data["valid_ratio"] = valid_ratio
return data
@@ -340,16 +376,20 @@ def __init__(self, image_shape, num_heads, max_text_length, **kwargs):
self.max_text_length = max_text_length
def __call__(self, data):
- img = data['image']
+ img = data["image"]
norm_img = resize_norm_img_srn(img, self.image_shape)
- data['image'] = norm_img
- [encoder_word_pos, gsrm_word_pos, gsrm_slf_attn_bias1, gsrm_slf_attn_bias2] = \
- srn_other_inputs(self.image_shape, self.num_heads, self.max_text_length)
-
- data['encoder_word_pos'] = encoder_word_pos
- data['gsrm_word_pos'] = gsrm_word_pos
- data['gsrm_slf_attn_bias1'] = gsrm_slf_attn_bias1
- data['gsrm_slf_attn_bias2'] = gsrm_slf_attn_bias2
+ data["image"] = norm_img
+ [
+ encoder_word_pos,
+ gsrm_word_pos,
+ gsrm_slf_attn_bias1,
+ gsrm_slf_attn_bias2,
+ ] = srn_other_inputs(self.image_shape, self.num_heads, self.max_text_length)
+
+ data["encoder_word_pos"] = encoder_word_pos
+ data["gsrm_word_pos"] = gsrm_word_pos
+ data["gsrm_slf_attn_bias1"] = gsrm_slf_attn_bias1
+ data["gsrm_slf_attn_bias2"] = gsrm_slf_attn_bias2
return data
@@ -359,42 +399,46 @@ def __init__(self, image_shape, width_downsample_ratio=0.25, **kwargs):
self.width_downsample_ratio = width_downsample_ratio
def __call__(self, data):
- img = data['image']
+ img = data["image"]
norm_img, resize_shape, pad_shape, valid_ratio = resize_norm_img_sar(
- img, self.image_shape, self.width_downsample_ratio)
- data['image'] = norm_img
- data['resized_shape'] = resize_shape
- data['pad_shape'] = pad_shape
- data['valid_ratio'] = valid_ratio
+ img, self.image_shape, self.width_downsample_ratio
+ )
+ data["image"] = norm_img
+ data["resized_shape"] = resize_shape
+ data["pad_shape"] = pad_shape
+ data["valid_ratio"] = valid_ratio
return data
class PRENResizeImg(object):
def __init__(self, image_shape, **kwargs):
"""
- Accroding to original paper's realization, it's a hard resize method here.
+        According to the original paper's implementation, this is a hard resize method here.
So maybe you should optimize it to fit for your task better.
"""
self.dst_h, self.dst_w = image_shape
def __call__(self, data):
- img = data['image']
+ img = data["image"]
resized_img = cv2.resize(
- img, (self.dst_w, self.dst_h), interpolation=cv2.INTER_LINEAR)
+ img, (self.dst_w, self.dst_h), interpolation=cv2.INTER_LINEAR
+ )
resized_img = resized_img.transpose((2, 0, 1)) / 255
resized_img -= 0.5
resized_img /= 0.5
- data['image'] = resized_img.astype(np.float32)
+ data["image"] = resized_img.astype(np.float32)
return data
class SPINRecResizeImg(object):
- def __init__(self,
- image_shape,
- interpolation=2,
- mean=(127.5, 127.5, 127.5),
- std=(127.5, 127.5, 127.5),
- **kwargs):
+ def __init__(
+ self,
+ image_shape,
+ interpolation=2,
+ mean=(127.5, 127.5, 127.5),
+ std=(127.5, 127.5, 127.5),
+ **kwargs
+ ):
self.image_shape = image_shape
self.mean = np.array(mean, dtype=np.float32)
@@ -402,7 +446,7 @@ def __init__(self,
self.interpolation = interpolation
def __call__(self, data):
- img = data['image']
+ img = data["image"]
img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
# different interpolation type corresponding the OpenCV
if self.interpolation == 0:
@@ -429,18 +473,20 @@ def __call__(self, data):
stdinv = 1 / np.float64(self.std.reshape(1, -1))
img -= mean
img *= stdinv
- data['image'] = img
+ data["image"] = img
return data
class GrayRecResizeImg(object):
- def __init__(self,
- image_shape,
- resize_type,
- inter_type="Image.Resampling.LANCZOS",
- scale=True,
- padding=False,
- **kwargs):
+ def __init__(
+ self,
+ image_shape,
+ resize_type,
+ inter_type="Image.Resampling.LANCZOS",
+ scale=True,
+ padding=False,
+ **kwargs
+ ):
self.image_shape = image_shape
self.resize_type = resize_type
self.padding = padding
@@ -448,7 +494,7 @@ def __init__(self,
self.scale = scale
def __call__(self, data):
- img = data['image']
+ img = data["image"]
img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
image_shape = self.image_shape
if self.padding:
@@ -464,23 +510,23 @@ def __call__(self, data):
resized_image = cv2.resize(img, (resized_w, imgH))
norm_img = np.expand_dims(resized_image, -1)
norm_img = norm_img.transpose((2, 0, 1))
- resized_image = norm_img.astype(np.float32) / 128. - 1.
+ resized_image = norm_img.astype(np.float32) / 128.0 - 1.0
padding_im = np.zeros((imgC, imgH, imgW), dtype=np.float32)
padding_im[:, :, 0:resized_w] = resized_image
- data['image'] = padding_im
+ data["image"] = padding_im
return data
- if self.resize_type == 'PIL':
+ if self.resize_type == "PIL":
image_pil = Image.fromarray(np.uint8(img))
img = image_pil.resize(self.image_shape, self.inter_type)
img = np.array(img)
- if self.resize_type == 'OpenCV':
+ if self.resize_type == "OpenCV":
img = cv2.resize(img, self.image_shape)
norm_img = np.expand_dims(img, -1)
norm_img = norm_img.transpose((2, 0, 1))
if self.scale:
- data['image'] = norm_img.astype(np.float32) / 128. - 1.
+ data["image"] = norm_img.astype(np.float32) / 128.0 - 1.0
else:
- data['image'] = norm_img.astype(np.float32) / 255.
+ data["image"] = norm_img.astype(np.float32) / 255.0
return data
@@ -489,10 +535,10 @@ def __init__(self, image_shape, **kwargs):
self.image_shape = image_shape
def __call__(self, data):
- img = data['image']
+ img = data["image"]
norm_img, valid_ratio = resize_norm_img_abinet(img, self.image_shape)
- data['image'] = norm_img
- data['valid_ratio'] = valid_ratio
+ data["image"] = norm_img
+ data["valid_ratio"] = valid_ratio
return data
@@ -502,35 +548,33 @@ def __init__(self, image_shape, padding=True, **kwargs):
self.padding = padding
def __call__(self, data):
- img = data['image']
+ img = data["image"]
- norm_img, valid_ratio = resize_norm_img(img, self.image_shape,
- self.padding)
- data['image'] = norm_img
- data['valid_ratio'] = valid_ratio
+ norm_img, valid_ratio = resize_norm_img(img, self.image_shape, self.padding)
+ data["image"] = norm_img
+ data["valid_ratio"] = valid_ratio
return data
class RobustScannerRecResizeImg(object):
- def __init__(self,
- image_shape,
- max_text_length,
- width_downsample_ratio=0.25,
- **kwargs):
+ def __init__(
+ self, image_shape, max_text_length, width_downsample_ratio=0.25, **kwargs
+ ):
self.image_shape = image_shape
self.width_downsample_ratio = width_downsample_ratio
self.max_text_length = max_text_length
def __call__(self, data):
- img = data['image']
+ img = data["image"]
norm_img, resize_shape, pad_shape, valid_ratio = resize_norm_img_sar(
- img, self.image_shape, self.width_downsample_ratio)
- word_positons = np.array(range(0, self.max_text_length)).astype('int64')
- data['image'] = norm_img
- data['resized_shape'] = resize_shape
- data['pad_shape'] = pad_shape
- data['valid_ratio'] = valid_ratio
- data['word_positons'] = word_positons
+ img, self.image_shape, self.width_downsample_ratio
+ )
+ word_positons = np.array(range(0, self.max_text_length)).astype("int64")
+ data["image"] = norm_img
+ data["resized_shape"] = resize_shape
+ data["pad_shape"] = pad_shape
+ data["valid_ratio"] = valid_ratio
+ data["word_positons"] = word_positons
return data
@@ -552,8 +596,8 @@ def resize_norm_img_sar(img, image_shape, width_downsample_ratio=0.25):
valid_ratio = min(1.0, 1.0 * resize_w / imgW_max)
resize_w = min(imgW_max, resize_w)
resized_image = cv2.resize(img, (resize_w, imgH))
- resized_image = resized_image.astype('float32')
- # norm
+ resized_image = resized_image.astype("float32")
+ # norm
if image_shape[0] == 1:
resized_image = resized_image / 255
resized_image = resized_image[np.newaxis, :]
@@ -569,16 +613,12 @@ def resize_norm_img_sar(img, image_shape, width_downsample_ratio=0.25):
return padding_im, resize_shape, pad_shape, valid_ratio
-def resize_norm_img(img,
- image_shape,
- padding=True,
- interpolation=cv2.INTER_LINEAR):
+def resize_norm_img(img, image_shape, padding=True, interpolation=cv2.INTER_LINEAR):
imgC, imgH, imgW = image_shape
h = img.shape[0]
w = img.shape[1]
if not padding:
- resized_image = cv2.resize(
- img, (imgW, imgH), interpolation=interpolation)
+ resized_image = cv2.resize(img, (imgW, imgH), interpolation=interpolation)
resized_w = imgW
else:
ratio = w / float(h)
@@ -587,7 +627,7 @@ def resize_norm_img(img,
else:
resized_w = int(math.ceil(imgH * ratio))
resized_image = cv2.resize(img, (resized_w, imgH))
- resized_image = resized_image.astype('float32')
+ resized_image = resized_image.astype("float32")
if image_shape[0] == 1:
resized_image = resized_image / 255
resized_image = resized_image[np.newaxis, :]
@@ -614,7 +654,7 @@ def resize_norm_img_chinese(img, image_shape):
else:
resized_w = int(math.ceil(imgH * ratio))
resized_image = cv2.resize(img, (resized_w, imgH))
- resized_image = resized_image.astype('float32')
+ resized_image = resized_image.astype("float32")
if image_shape[0] == 1:
resized_image = resized_image / 255
resized_image = resized_image[np.newaxis, :]
@@ -646,7 +686,7 @@ def resize_norm_img_srn(img, image_shape):
img_np = np.asarray(img_new)
img_np = cv2.cvtColor(img_np, cv2.COLOR_BGR2GRAY)
- img_black[:, 0:img_np.shape[1]] = img_np
+ img_black[:, 0 : img_np.shape[1]] = img_np
img_black = img_black[:, :, np.newaxis]
row, col, c = img_black.shape
@@ -658,48 +698,46 @@ def resize_norm_img_srn(img, image_shape):
def resize_norm_img_abinet(img, image_shape):
imgC, imgH, imgW = image_shape
- resized_image = cv2.resize(
- img, (imgW, imgH), interpolation=cv2.INTER_LINEAR)
+ resized_image = cv2.resize(img, (imgW, imgH), interpolation=cv2.INTER_LINEAR)
resized_w = imgW
- resized_image = resized_image.astype('float32')
- resized_image = resized_image / 255.
+ resized_image = resized_image.astype("float32")
+ resized_image = resized_image / 255.0
mean = np.array([0.485, 0.456, 0.406])
std = np.array([0.229, 0.224, 0.225])
- resized_image = (
- resized_image - mean[None, None, ...]) / std[None, None, ...]
+ resized_image = (resized_image - mean[None, None, ...]) / std[None, None, ...]
resized_image = resized_image.transpose((2, 0, 1))
- resized_image = resized_image.astype('float32')
+ resized_image = resized_image.astype("float32")
valid_ratio = min(1.0, float(resized_w / imgW))
return resized_image, valid_ratio
def srn_other_inputs(image_shape, num_heads, max_text_length):
-
imgC, imgH, imgW = image_shape
feature_dim = int((imgH / 8) * (imgW / 8))
- encoder_word_pos = np.array(range(0, feature_dim)).reshape(
- (feature_dim, 1)).astype('int64')
- gsrm_word_pos = np.array(range(0, max_text_length)).reshape(
- (max_text_length, 1)).astype('int64')
+ encoder_word_pos = (
+ np.array(range(0, feature_dim)).reshape((feature_dim, 1)).astype("int64")
+ )
+ gsrm_word_pos = (
+ np.array(range(0, max_text_length))
+ .reshape((max_text_length, 1))
+ .astype("int64")
+ )
gsrm_attn_bias_data = np.ones((1, max_text_length, max_text_length))
gsrm_slf_attn_bias1 = np.triu(gsrm_attn_bias_data, 1).reshape(
- [1, max_text_length, max_text_length])
- gsrm_slf_attn_bias1 = np.tile(gsrm_slf_attn_bias1,
- [num_heads, 1, 1]) * [-1e9]
+ [1, max_text_length, max_text_length]
+ )
+ gsrm_slf_attn_bias1 = np.tile(gsrm_slf_attn_bias1, [num_heads, 1, 1]) * [-1e9]
gsrm_slf_attn_bias2 = np.tril(gsrm_attn_bias_data, -1).reshape(
- [1, max_text_length, max_text_length])
- gsrm_slf_attn_bias2 = np.tile(gsrm_slf_attn_bias2,
- [num_heads, 1, 1]) * [-1e9]
+ [1, max_text_length, max_text_length]
+ )
+ gsrm_slf_attn_bias2 = np.tile(gsrm_slf_attn_bias2, [num_heads, 1, 1]) * [-1e9]
- return [
- encoder_word_pos, gsrm_word_pos, gsrm_slf_attn_bias1,
- gsrm_slf_attn_bias2
- ]
+ return [encoder_word_pos, gsrm_word_pos, gsrm_slf_attn_bias1, gsrm_slf_attn_bias2]
def flag():
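The bias construction in `srn_other_inputs` above builds additive attention masks: an upper-triangular mask (offset 1) scaled by -1e9 hides future positions, and a lower-triangular mask (offset -1) hides past positions. A minimal sketch with toy sizes (`max_text_length=4`, `num_heads=2` are illustrative only):

```python
import numpy as np

max_text_length, num_heads = 4, 2
bias = np.ones((1, max_text_length, max_text_length))
# adding -1e9 to masked logits drives their softmax weight to ~0
bias1 = np.tile(np.triu(bias, 1), [num_heads, 1, 1]) * [-1e9]   # hide future
bias2 = np.tile(np.tril(bias, -1), [num_heads, 1, 1]) * [-1e9]  # hide past
assert bias1.shape == (num_heads, max_text_length, max_text_length)
```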
@@ -741,7 +779,7 @@ def jitter(img):
s = int(random.random() * thres * 0.01)
src_img = img.copy()
for i in range(s):
- img[i:, i:, :] = src_img[:w - i, :h - i, :]
+ img[i:, i:, :] = src_img[: w - i, : h - i, :]
return img
else:
return img
@@ -773,7 +811,7 @@ def get_crop(image):
if ratio:
crop_img = crop_img[top_crop:h, :, :]
else:
- crop_img = crop_img[0:h - top_crop, :, :]
+ crop_img = crop_img[0 : h - top_crop, :, :]
return crop_img
@@ -788,30 +826,57 @@ def get_warpR(config):
"""
get_warpR
"""
- anglex, angley, anglez, fov, w, h, r = \
- config.anglex, config.angley, config.anglez, config.fov, config.w, config.h, config.r
+ anglex, angley, anglez, fov, w, h, r = (
+ config.anglex,
+ config.angley,
+ config.anglez,
+ config.fov,
+ config.w,
+ config.h,
+ config.r,
+ )
if w > 69 and w < 112:
anglex = anglex * 1.5
z = np.sqrt(w**2 + h**2) / 2 / np.tan(rad(fov / 2))
# Homogeneous coordinate transformation matrix
- rx = np.array([[1, 0, 0, 0],
- [0, np.cos(rad(anglex)), -np.sin(rad(anglex)), 0], [
- 0,
- -np.sin(rad(anglex)),
- np.cos(rad(anglex)),
- 0,
- ], [0, 0, 0, 1]], np.float32)
- ry = np.array([[np.cos(rad(angley)), 0, np.sin(rad(angley)), 0],
- [0, 1, 0, 0], [
- -np.sin(rad(angley)),
- 0,
- np.cos(rad(angley)),
- 0,
- ], [0, 0, 0, 1]], np.float32)
- rz = np.array([[np.cos(rad(anglez)), np.sin(rad(anglez)), 0, 0],
- [-np.sin(rad(anglez)), np.cos(rad(anglez)), 0, 0],
- [0, 0, 1, 0], [0, 0, 0, 1]], np.float32)
+ rx = np.array(
+ [
+ [1, 0, 0, 0],
+ [0, np.cos(rad(anglex)), -np.sin(rad(anglex)), 0],
+ [
+ 0,
+ -np.sin(rad(anglex)),
+ np.cos(rad(anglex)),
+ 0,
+ ],
+ [0, 0, 0, 1],
+ ],
+ np.float32,
+ )
+ ry = np.array(
+ [
+ [np.cos(rad(angley)), 0, np.sin(rad(angley)), 0],
+ [0, 1, 0, 0],
+ [
+ -np.sin(rad(angley)),
+ 0,
+ np.cos(rad(angley)),
+ 0,
+ ],
+ [0, 0, 0, 1],
+ ],
+ np.float32,
+ )
+ rz = np.array(
+ [
+ [np.cos(rad(anglez)), np.sin(rad(anglez)), 0, 0],
+ [-np.sin(rad(anglez)), np.cos(rad(anglez)), 0, 0],
+ [0, 0, 1, 0],
+ [0, 0, 0, 1],
+ ],
+ np.float32,
+ )
r = rx.dot(ry).dot(rz)
# generate 4 points
pcenter = np.array([h / 2, w / 2, 0, 0], np.float32)
@@ -843,11 +908,11 @@ def get_warpR(config):
dx = -c1
dy = -r1
- T1 = np.float32([[1., 0, dx], [0, 1., dy], [0, 0, 1.0 / ratio]])
+ T1 = np.float32([[1.0, 0, dx], [0, 1.0, dy], [0, 0, 1.0 / ratio]])
ret = T1.dot(warpR)
except:
ratio = 1.0
- T1 = np.float32([[1., 0, 0], [0, 1., 0], [0, 0, 1.]])
+ T1 = np.float32([[1.0, 0, 0], [0, 1.0, 0], [0, 0, 1.0]])
ret = T1
return ret, (-r1, -c1), ratio, dst
@@ -857,6 +922,11 @@ def get_warpAffine(config):
get_warpAffine
"""
anglez = config.anglez
- rz = np.array([[np.cos(rad(anglez)), np.sin(rad(anglez)), 0],
- [-np.sin(rad(anglez)), np.cos(rad(anglez)), 0]], np.float32)
+ rz = np.array(
+ [
+ [np.cos(rad(anglez)), np.sin(rad(anglez)), 0],
+ [-np.sin(rad(anglez)), np.cos(rad(anglez)), 0],
+ ],
+ np.float32,
+ )
return rz
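The 2x3 matrix returned by `get_warpAffine` above is a pure rotation about the origin and can be passed directly to `cv2.warpAffine`. A minimal sketch (the `rad` helper mirrors the degree-to-radian conversion used elsewhere in this file; the image is a dummy):

```python
import cv2
import numpy as np

def rad(x):
    return x * np.pi / 180

anglez = 15
rz = np.array(
    [
        [np.cos(rad(anglez)), np.sin(rad(anglez)), 0],
        [-np.sin(rad(anglez)), np.cos(rad(anglez)), 0],
    ],
    np.float32,
)
img = np.zeros((64, 256, 3), np.uint8)
rotated = cv2.warpAffine(img, rz, (img.shape[1], img.shape[0]))
```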
diff --git a/ppocr/data/imaug/sast_process.py b/ppocr/data/imaug/sast_process.py
index 08d03b194d..81e13930e6 100644
--- a/ppocr/data/imaug/sast_process.py
+++ b/ppocr/data/imaug/sast_process.py
@@ -1,16 +1,16 @@
-#copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
+# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
-#Licensed under the Apache License, Version 2.0 (the "License");
-#you may not use this file except in compliance with the License.
-#You may obtain a copy of the License at
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
-#Unless required by applicable law or agreed to in writing, software
-#distributed under the License is distributed on an "AS IS" BASIS,
-#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-#See the License for the specific language governing permissions and
-#limitations under the License.
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
"""
This part code is refered from:
https://github.com/songdejia/EAST/blob/master/data_utils.py
@@ -22,17 +22,19 @@
import sys
import os
-__all__ = ['SASTProcessTrain']
+__all__ = ["SASTProcessTrain"]
class SASTProcessTrain(object):
- def __init__(self,
- image_shape=[512, 512],
- min_crop_size=24,
- min_crop_side_ratio=0.3,
- min_text_size=10,
- max_text_size=512,
- **kwargs):
+ def __init__(
+ self,
+ image_shape=[512, 512],
+ min_crop_size=24,
+ min_crop_side_ratio=0.3,
+ min_text_size=10,
+ max_text_size=512,
+ **kwargs
+ ):
self.input_size = image_shape[1]
self.min_crop_size = min_crop_size
self.min_crop_side_ratio = min_crop_side_ratio
@@ -45,11 +47,13 @@ def quad_area(self, poly):
:param poly:
:return:
"""
- edge = [(poly[1][0] - poly[0][0]) * (poly[1][1] + poly[0][1]),
- (poly[2][0] - poly[1][0]) * (poly[2][1] + poly[1][1]),
- (poly[3][0] - poly[2][0]) * (poly[3][1] + poly[2][1]),
- (poly[0][0] - poly[3][0]) * (poly[0][1] + poly[3][1])]
- return np.sum(edge) / 2.
+ edge = [
+ (poly[1][0] - poly[0][0]) * (poly[1][1] + poly[0][1]),
+ (poly[2][0] - poly[1][0]) * (poly[2][1] + poly[1][1]),
+ (poly[3][0] - poly[2][0]) * (poly[3][1] + poly[2][1]),
+ (poly[0][0] - poly[3][0]) * (poly[0][1] + poly[3][1]),
+ ]
+ return np.sum(edge) / 2.0
def gen_quad_from_poly(self, poly):
"""
@@ -58,18 +62,21 @@ def gen_quad_from_poly(self, poly):
point_num = poly.shape[0]
min_area_quad = np.zeros((4, 2), dtype=np.float32)
if True:
- rect = cv2.minAreaRect(poly.astype(
- np.int32)) # (center (x,y), (width, height), angle of rotation)
+ rect = cv2.minAreaRect(
+ poly.astype(np.int32)
+ ) # (center (x,y), (width, height), angle of rotation)
center_point = rect[0]
box = np.array(cv2.boxPoints(rect))
first_point_idx = 0
min_dist = 1e4
for i in range(4):
- dist = np.linalg.norm(box[(i + 0) % 4] - poly[0]) + \
- np.linalg.norm(box[(i + 1) % 4] - poly[point_num // 2 - 1]) + \
- np.linalg.norm(box[(i + 2) % 4] - poly[point_num // 2]) + \
- np.linalg.norm(box[(i + 3) % 4] - poly[-1])
+ dist = (
+ np.linalg.norm(box[(i + 0) % 4] - poly[0])
+ + np.linalg.norm(box[(i + 1) % 4] - poly[point_num // 2 - 1])
+ + np.linalg.norm(box[(i + 2) % 4] - poly[point_num // 2])
+ + np.linalg.norm(box[(i + 3) % 4] - poly[-1])
+ )
if dist < min_dist:
min_dist = dist
first_point_idx = i
@@ -99,20 +106,21 @@ def check_and_validate_polys(self, polys, tags, xxx_todo_changeme):
quad = self.gen_quad_from_poly(poly)
p_area = self.quad_area(quad)
if abs(p_area) < 1:
- print('invalid poly')
+ print("invalid poly")
continue
if p_area > 0:
if tag == False:
- print('poly in wrong direction')
+ print("poly in wrong direction")
tag = True # reversed cases should be ignore
- poly = poly[(0, 15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2,
- 1), :]
+ poly = poly[(0, 15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1), :]
quad = quad[(0, 3, 2, 1), :]
- len_w = np.linalg.norm(quad[0] - quad[1]) + np.linalg.norm(quad[3] -
- quad[2])
- len_h = np.linalg.norm(quad[0] - quad[3]) + np.linalg.norm(quad[1] -
- quad[2])
+ len_w = np.linalg.norm(quad[0] - quad[1]) + np.linalg.norm(
+ quad[3] - quad[2]
+ )
+ len_h = np.linalg.norm(quad[0] - quad[3]) + np.linalg.norm(
+ quad[1] - quad[2]
+ )
hv_tag = 1
if len_w * 2.0 < len_h:
@@ -121,16 +129,9 @@ def check_and_validate_polys(self, polys, tags, xxx_todo_changeme):
validated_polys.append(poly)
validated_tags.append(tag)
hv_tags.append(hv_tag)
- return np.array(validated_polys), np.array(validated_tags), np.array(
- hv_tags)
+ return np.array(validated_polys), np.array(validated_tags), np.array(hv_tags)
- def crop_area(self,
- im,
- polys,
- tags,
- hv_tags,
- crop_background=False,
- max_tries=25):
+ def crop_area(self, im, polys, tags, hv_tags, crop_background=False, max_tries=25):
"""
make random crop from the input image
:param im:
@@ -149,10 +150,10 @@ def crop_area(self,
poly = np.round(poly, decimals=0).astype(np.int32)
minx = np.min(poly[:, 0])
maxx = np.max(poly[:, 0])
- w_array[minx + pad_w:maxx + pad_w] = 1
+ w_array[minx + pad_w : maxx + pad_w] = 1
miny = np.min(poly[:, 1])
maxy = np.max(poly[:, 1])
- h_array[miny + pad_h:maxy + pad_h] = 1
+ h_array[miny + pad_h : maxy + pad_h] = 1
# ensure the cropped area not across a text
h_axis = np.where(h_array == 0)[0]
w_axis = np.where(w_array == 0)[0]
@@ -171,25 +172,31 @@ def crop_area(self,
ymax = np.clip(ymax, 0, h - 1)
# if xmax - xmin < ARGS.min_crop_side_ratio * w or \
# ymax - ymin < ARGS.min_crop_side_ratio * h:
- if xmax - xmin < self.min_crop_size or \
- ymax - ymin < self.min_crop_size:
+ if xmax - xmin < self.min_crop_size or ymax - ymin < self.min_crop_size:
# area too small
continue
if polys.shape[0] != 0:
- poly_axis_in_area = (polys[:, :, 0] >= xmin) & (polys[:, :, 0] <= xmax) \
- & (polys[:, :, 1] >= ymin) & (polys[:, :, 1] <= ymax)
- selected_polys = np.where(
- np.sum(poly_axis_in_area, axis=1) == 4)[0]
+ poly_axis_in_area = (
+ (polys[:, :, 0] >= xmin)
+ & (polys[:, :, 0] <= xmax)
+ & (polys[:, :, 1] >= ymin)
+ & (polys[:, :, 1] <= ymax)
+ )
+ selected_polys = np.where(np.sum(poly_axis_in_area, axis=1) == 4)[0]
else:
selected_polys = []
if len(selected_polys) == 0:
# no text in this area
if crop_background:
- return im[ymin : ymax + 1, xmin : xmax + 1, :], \
- polys[selected_polys], tags[selected_polys], hv_tags[selected_polys]
+ return (
+ im[ymin : ymax + 1, xmin : xmax + 1, :],
+ polys[selected_polys],
+ tags[selected_polys],
+ hv_tags[selected_polys],
+ )
else:
continue
- im = im[ymin:ymax + 1, xmin:xmax + 1, :]
+ im = im[ymin : ymax + 1, xmin : xmax + 1, :]
polys = polys[selected_polys]
tags = tags[selected_polys]
hv_tags = hv_tags[selected_polys]
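The mask arrays in `crop_area` above mark every column and row covered by a polygon, so crop edges are sampled only from text-free positions. A toy 1-D illustration along x (array sizes are hypothetical):

```python
import numpy as np

w, pad_w = 100, 10
w_array = np.zeros(w + pad_w * 2, dtype=np.int32)
minx, maxx = 30, 55                     # x-extent of one polygon
w_array[minx + pad_w : maxx + pad_w] = 1
w_axis = np.where(w_array == 0)[0]      # valid x positions for crop edges
```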
@@ -200,53 +207,55 @@ def crop_area(self,
return im, polys, tags, hv_tags
def generate_direction_map(self, poly_quads, direction_map):
- """
- """
+ """ """
width_list = []
height_list = []
for quad in poly_quads:
- quad_w = (np.linalg.norm(quad[0] - quad[1]) +
- np.linalg.norm(quad[2] - quad[3])) / 2.0
- quad_h = (np.linalg.norm(quad[0] - quad[3]) +
- np.linalg.norm(quad[2] - quad[1])) / 2.0
+ quad_w = (
+ np.linalg.norm(quad[0] - quad[1]) + np.linalg.norm(quad[2] - quad[3])
+ ) / 2.0
+ quad_h = (
+ np.linalg.norm(quad[0] - quad[3]) + np.linalg.norm(quad[2] - quad[1])
+ ) / 2.0
width_list.append(quad_w)
height_list.append(quad_h)
norm_width = max(sum(width_list) / (len(width_list) + 1e-6), 1.0)
average_height = max(sum(height_list) / (len(height_list) + 1e-6), 1.0)
for quad in poly_quads:
- direct_vector_full = (
- (quad[1] + quad[2]) - (quad[0] + quad[3])) / 2.0
- direct_vector = direct_vector_full / (
- np.linalg.norm(direct_vector_full) + 1e-6) * norm_width
+ direct_vector_full = ((quad[1] + quad[2]) - (quad[0] + quad[3])) / 2.0
+ direct_vector = (
+ direct_vector_full
+ / (np.linalg.norm(direct_vector_full) + 1e-6)
+ * norm_width
+ )
direction_label = tuple(
- map(float, [
- direct_vector[0], direct_vector[1], 1.0 / (average_height +
- 1e-6)
- ]))
- cv2.fillPoly(direction_map,
- quad.round().astype(np.int32)[np.newaxis, :, :],
- direction_label)
+ map(
+ float,
+ [direct_vector[0], direct_vector[1], 1.0 / (average_height + 1e-6)],
+ )
+ )
+ cv2.fillPoly(
+ direction_map,
+ quad.round().astype(np.int32)[np.newaxis, :, :],
+ direction_label,
+ )
return direction_map
def calculate_average_height(self, poly_quads):
- """
- """
+ """ """
height_list = []
for quad in poly_quads:
- quad_h = (np.linalg.norm(quad[0] - quad[3]) +
- np.linalg.norm(quad[2] - quad[1])) / 2.0
+ quad_h = (
+ np.linalg.norm(quad[0] - quad[3]) + np.linalg.norm(quad[2] - quad[1])
+ ) / 2.0
height_list.append(quad_h)
average_height = max(sum(height_list) / len(height_list), 1.0)
return average_height
- def generate_tcl_label(self,
- hw,
- polys,
- tags,
- ds_ratio,
- tcl_ratio=0.3,
- shrink_ratio_of_width=0.15):
+ def generate_tcl_label(
+ self, hw, polys, tags, ds_ratio, tcl_ratio=0.3, shrink_ratio_of_width=0.15
+ ):
"""
Generate polygon.
"""
@@ -257,14 +266,21 @@ def generate_tcl_label(self,
score_map = np.zeros(
(
h,
- w, ), dtype=np.float32)
+ w,
+ ),
+ dtype=np.float32,
+ )
tbo_map = np.zeros((h, w, 5), dtype=np.float32)
training_mask = np.ones(
(
h,
- w, ), dtype=np.float32)
+ w,
+ ),
+ dtype=np.float32,
+ )
direction_map = np.ones((h, w, 3)) * np.array([0, 0, 1]).reshape(
- [1, 1, 3]).astype(np.float32)
+ [1, 1, 3]
+ ).astype(np.float32)
for poly_idx, poly_tag in enumerate(zip(polys, tags)):
poly = poly_tag[0]
@@ -273,20 +289,25 @@ def generate_tcl_label(self,
# generate min_area_quad
min_area_quad, center_point = self.gen_min_area_quad_from_poly(poly)
min_area_quad_h = 0.5 * (
- np.linalg.norm(min_area_quad[0] - min_area_quad[3]) +
- np.linalg.norm(min_area_quad[1] - min_area_quad[2]))
+ np.linalg.norm(min_area_quad[0] - min_area_quad[3])
+ + np.linalg.norm(min_area_quad[1] - min_area_quad[2])
+ )
min_area_quad_w = 0.5 * (
- np.linalg.norm(min_area_quad[0] - min_area_quad[1]) +
- np.linalg.norm(min_area_quad[2] - min_area_quad[3]))
-
- if min(min_area_quad_h, min_area_quad_w) < self.min_text_size * ds_ratio \
- or min(min_area_quad_h, min_area_quad_w) > self.max_text_size * ds_ratio:
+ np.linalg.norm(min_area_quad[0] - min_area_quad[1])
+ + np.linalg.norm(min_area_quad[2] - min_area_quad[3])
+ )
+
+ if (
+ min(min_area_quad_h, min_area_quad_w) < self.min_text_size * ds_ratio
+ or min(min_area_quad_h, min_area_quad_w) > self.max_text_size * ds_ratio
+ ):
continue
if tag:
# continue
- cv2.fillPoly(training_mask,
- poly.astype(np.int32)[np.newaxis, :, :], 0.15)
+ cv2.fillPoly(
+ training_mask, poly.astype(np.int32)[np.newaxis, :, :], 0.15
+ )
else:
tcl_poly = self.poly2tcl(poly, tcl_ratio)
tcl_quads = self.poly2quads(tcl_poly)
@@ -295,27 +316,25 @@ def generate_tcl_label(self,
stcl_quads, quad_index = self.shrink_poly_along_width(
tcl_quads,
shrink_ratio_of_width=shrink_ratio_of_width,
- expand_height_ratio=1.0 / tcl_ratio)
+ expand_height_ratio=1.0 / tcl_ratio,
+ )
# generate tcl map
- cv2.fillPoly(score_map,
- np.round(stcl_quads).astype(np.int32), 1.0)
+ cv2.fillPoly(score_map, np.round(stcl_quads).astype(np.int32), 1.0)
# generate tbo map
for idx, quad in enumerate(stcl_quads):
quad_mask = np.zeros((h, w), dtype=np.float32)
quad_mask = cv2.fillPoly(
quad_mask,
- np.round(quad[np.newaxis, :, :]).astype(np.int32), 1.0)
- tbo_map = self.gen_quad_tbo(poly_quads[quad_index[idx]],
- quad_mask, tbo_map)
+ np.round(quad[np.newaxis, :, :]).astype(np.int32),
+ 1.0,
+ )
+ tbo_map = self.gen_quad_tbo(
+ poly_quads[quad_index[idx]], quad_mask, tbo_map
+ )
return score_map, tbo_map, training_mask
- def generate_tvo_and_tco(self,
- hw,
- polys,
- tags,
- tcl_ratio=0.3,
- ds_ratio=0.25):
+ def generate_tvo_and_tco(self, hw, polys, tags, tcl_ratio=0.3, ds_ratio=0.25):
"""
Generate tcl map, tvo map and tbo map.
"""
@@ -338,7 +357,6 @@ def generate_tvo_and_tco(self,
poly_short_edge_map = np.ones((h, w), dtype=np.float32)
for poly, poly_tag in zip(polys, tags):
-
if poly_tag == True:
continue
@@ -348,11 +366,13 @@ def generate_tvo_and_tco(self,
# generate min_area_quad
min_area_quad, center_point = self.gen_min_area_quad_from_poly(poly)
min_area_quad_h = 0.5 * (
- np.linalg.norm(min_area_quad[0] - min_area_quad[3]) +
- np.linalg.norm(min_area_quad[1] - min_area_quad[2]))
+ np.linalg.norm(min_area_quad[0] - min_area_quad[3])
+ + np.linalg.norm(min_area_quad[1] - min_area_quad[2])
+ )
min_area_quad_w = 0.5 * (
- np.linalg.norm(min_area_quad[0] - min_area_quad[1]) +
- np.linalg.norm(min_area_quad[2] - min_area_quad[3]))
+ np.linalg.norm(min_area_quad[0] - min_area_quad[1])
+ + np.linalg.norm(min_area_quad[2] - min_area_quad[3])
+ )
# generate tcl map and text, 128 * 128
tcl_poly = self.poly2tcl(poly, tcl_ratio)
@@ -362,29 +382,33 @@ def generate_tvo_and_tco(self,
cv2.fillPoly(
poly_tv_xy_map[2 * idx],
np.round(tcl_poly[np.newaxis, :, :]).astype(np.int32),
- float(min(max(min_area_quad[idx, 0], 0), w)))
+ float(min(max(min_area_quad[idx, 0], 0), w)),
+ )
cv2.fillPoly(
poly_tv_xy_map[2 * idx + 1],
np.round(tcl_poly[np.newaxis, :, :]).astype(np.int32),
- float(min(max(min_area_quad[idx, 1], 0), h)))
+ float(min(max(min_area_quad[idx, 1], 0), h)),
+ )
# generate poly_tc_xy_map
for idx in range(2):
cv2.fillPoly(
poly_tc_xy_map[idx],
np.round(tcl_poly[np.newaxis, :, :]).astype(np.int32),
- float(center_point[idx]))
+ float(center_point[idx]),
+ )
# generate poly_short_edge_map
cv2.fillPoly(
poly_short_edge_map,
np.round(tcl_poly[np.newaxis, :, :]).astype(np.int32),
- float(max(min(min_area_quad_h, min_area_quad_w), 1.0)))
+ float(max(min(min_area_quad_h, min_area_quad_w), 1.0)),
+ )
# generate poly_mask and training_mask
- cv2.fillPoly(poly_mask,
- np.round(tcl_poly[np.newaxis, :, :]).astype(np.int32),
- 1)
+ cv2.fillPoly(
+ poly_mask, np.round(tcl_poly[np.newaxis, :, :]).astype(np.int32), 1
+ )
tvo_map *= poly_mask
tvo_map[:8] -= poly_tv_xy_map
@@ -416,7 +440,8 @@ def adjust_point(self, poly):
vector_1 = poly[0] - poly[1]
vector_2 = poly[1] - poly[2]
cos_theta = np.dot(vector_1, vector_2) / (
- np.linalg.norm(vector_1) * np.linalg.norm(vector_2) + 1e-6)
+ np.linalg.norm(vector_1) * np.linalg.norm(vector_2) + 1e-6
+ )
theta = np.arccos(np.round(cos_theta, decimals=4))
if abs(theta) > (70 / 180 * math.pi):
@@ -434,18 +459,21 @@ def gen_min_area_quad_from_poly(self, poly):
min_area_quad = poly
center_point = np.sum(poly, axis=0) / 4
else:
- rect = cv2.minAreaRect(poly.astype(
- np.int32)) # (center (x,y), (width, height), angle of rotation)
+ rect = cv2.minAreaRect(
+ poly.astype(np.int32)
+ ) # (center (x,y), (width, height), angle of rotation)
center_point = rect[0]
box = np.array(cv2.boxPoints(rect))
first_point_idx = 0
min_dist = 1e4
for i in range(4):
- dist = np.linalg.norm(box[(i + 0) % 4] - poly[0]) + \
- np.linalg.norm(box[(i + 1) % 4] - poly[point_num // 2 - 1]) + \
- np.linalg.norm(box[(i + 2) % 4] - poly[point_num // 2]) + \
- np.linalg.norm(box[(i + 3) % 4] - poly[-1])
+ dist = (
+ np.linalg.norm(box[(i + 0) % 4] - poly[0])
+ + np.linalg.norm(box[(i + 1) % 4] - poly[point_num // 2 - 1])
+ + np.linalg.norm(box[(i + 2) % 4] - poly[point_num // 2])
+ + np.linalg.norm(box[(i + 3) % 4] - poly[-1])
+ )
if dist < min_dist:
min_dist = dist
first_point_idx = i
@@ -455,23 +483,20 @@ def gen_min_area_quad_from_poly(self, poly):
return min_area_quad, center_point
- def shrink_quad_along_width(self,
- quad,
- begin_width_ratio=0.,
- end_width_ratio=1.):
+ def shrink_quad_along_width(self, quad, begin_width_ratio=0.0, end_width_ratio=1.0):
"""
Generate shrink_quad_along_width.
"""
ratio_pair = np.array(
- [[begin_width_ratio], [end_width_ratio]], dtype=np.float32)
+ [[begin_width_ratio], [end_width_ratio]], dtype=np.float32
+ )
p0_1 = quad[0] + (quad[1] - quad[0]) * ratio_pair
p3_2 = quad[3] + (quad[2] - quad[3]) * ratio_pair
return np.array([p0_1[0], p0_1[1], p3_2[1], p3_2[0]])
- def shrink_poly_along_width(self,
- quads,
- shrink_ratio_of_width,
- expand_height_ratio=1.0):
+ def shrink_poly_along_width(
+ self, quads, shrink_ratio_of_width, expand_height_ratio=1.0
+ ):
"""
shrink poly with given length.
"""
@@ -489,28 +514,30 @@ def get_cut_info(edge_len_list, cut_len):
upper_edge_list.append(upper_edge_len)
# length of left edge and right edge.
- left_length = np.linalg.norm(quads[0][0] - quads[0][
- 3]) * expand_height_ratio
- right_length = np.linalg.norm(quads[-1][1] - quads[-1][
- 2]) * expand_height_ratio
+ left_length = np.linalg.norm(quads[0][0] - quads[0][3]) * expand_height_ratio
+ right_length = np.linalg.norm(quads[-1][1] - quads[-1][2]) * expand_height_ratio
- shrink_length = min(left_length, right_length,
- sum(upper_edge_list)) * shrink_ratio_of_width
+ shrink_length = (
+ min(left_length, right_length, sum(upper_edge_list)) * shrink_ratio_of_width
+ )
# shrinking length
upper_len_left = shrink_length
upper_len_right = sum(upper_edge_list) - shrink_length
left_idx, left_ratio = get_cut_info(upper_edge_list, upper_len_left)
left_quad = self.shrink_quad_along_width(
- quads[left_idx], begin_width_ratio=left_ratio, end_width_ratio=1)
+ quads[left_idx], begin_width_ratio=left_ratio, end_width_ratio=1
+ )
right_idx, right_ratio = get_cut_info(upper_edge_list, upper_len_right)
right_quad = self.shrink_quad_along_width(
- quads[right_idx], begin_width_ratio=0, end_width_ratio=right_ratio)
+ quads[right_idx], begin_width_ratio=0, end_width_ratio=right_ratio
+ )
out_quad_list = []
if left_idx == right_idx:
out_quad_list.append(
- [left_quad[0], right_quad[1], right_quad[2], left_quad[3]])
+ [left_quad[0], right_quad[1], right_quad[2], left_quad[3]]
+ )
else:
out_quad_list.append(left_quad)
for idx in range(left_idx + 1, right_idx):
@@ -560,9 +587,9 @@ def line_cross_point(self, line1, line2):
d = a1 * b2 - a2 * b1
if d == 0:
- #print("line1", line1)
- #print("line2", line2)
- print('Cross point does not exist')
+ # print("line1", line1)
+ # print("line2", line2)
+ print("Cross point does not exist")
return np.array([0, 0], dtype=np.float32)
else:
x = (b1 * c2 - b2 * c1) / d
@@ -574,8 +601,7 @@ def quad2tcl(self, poly, ratio):
"""
Generate center line by poly clock-wise point. (4, 2)
"""
- ratio_pair = np.array(
- [[0.5 - ratio / 2], [0.5 + ratio / 2]], dtype=np.float32)
+ ratio_pair = np.array([[0.5 - ratio / 2], [0.5 + ratio / 2]], dtype=np.float32)
p0_3 = poly[0] + (poly[3] - poly[0]) * ratio_pair
p1_2 = poly[1] + (poly[2] - poly[1]) * ratio_pair
return np.array([p0_3[0], p1_2[0], p1_2[1], p0_3[1]])
@@ -584,14 +610,14 @@ def poly2tcl(self, poly, ratio):
"""
Generate center line by poly clock-wise point.
"""
- ratio_pair = np.array(
- [[0.5 - ratio / 2], [0.5 + ratio / 2]], dtype=np.float32)
+ ratio_pair = np.array([[0.5 - ratio / 2], [0.5 + ratio / 2]], dtype=np.float32)
tcl_poly = np.zeros_like(poly)
point_num = poly.shape[0]
for idx in range(point_num // 2):
- point_pair = poly[idx] + (poly[point_num - 1 - idx] - poly[idx]
- ) * ratio_pair
+ point_pair = (
+ poly[idx] + (poly[point_num - 1 - idx] - poly[idx]) * ratio_pair
+ )
tcl_poly[idx] = point_pair[0]
tcl_poly[point_num - 1 - idx] = point_pair[1]
return tcl_poly
@@ -604,10 +630,12 @@ def gen_quad_tbo(self, quad, tcl_mask, tbo_map):
up_line = self.line_cross_two_point(quad[0], quad[1])
lower_line = self.line_cross_two_point(quad[3], quad[2])
- quad_h = 0.5 * (np.linalg.norm(quad[0] - quad[3]) +
- np.linalg.norm(quad[1] - quad[2]))
- quad_w = 0.5 * (np.linalg.norm(quad[0] - quad[1]) +
- np.linalg.norm(quad[2] - quad[3]))
+ quad_h = 0.5 * (
+ np.linalg.norm(quad[0] - quad[3]) + np.linalg.norm(quad[1] - quad[2])
+ )
+ quad_w = 0.5 * (
+ np.linalg.norm(quad[0] - quad[1]) + np.linalg.norm(quad[2] - quad[3])
+ )
# average angle of left and right line.
angle = self.average_angle(quad)
@@ -644,15 +672,16 @@ def poly2quads(self, poly):
quad_num = point_num // 2 - 1
for idx in range(quad_num):
# reshape and adjust to clock-wise
- quad_list.append((np.array(point_pair_list)[[idx, idx + 1]]
- ).reshape(4, 2)[[0, 2, 3, 1]])
+ quad_list.append(
+ (np.array(point_pair_list)[[idx, idx + 1]]).reshape(4, 2)[[0, 2, 3, 1]]
+ )
return np.array(quad_list)
def __call__(self, data):
- im = data['image']
- text_polys = data['polys']
- text_tags = data['ignore_tags']
+ im = data["image"]
+ text_polys = data["polys"]
+ text_tags = data["ignore_tags"]
if im is None:
return None
if text_polys.shape[0] == 0:
@@ -660,12 +689,13 @@ def __call__(self, data):
h, w, _ = im.shape
text_polys, text_tags, hv_tags = self.check_and_validate_polys(
- text_polys, text_tags, (h, w))
+ text_polys, text_tags, (h, w)
+ )
if text_polys.shape[0] == 0:
return None
- #set aspect ratio and keep area fix
+ # set aspect ratio and keep area fix
asp_scales = np.arange(1.0, 1.55, 0.1)
asp_scale = np.random.choice(asp_scales)
@@ -688,37 +718,39 @@ def __call__(self, data):
if min(h, w) < 16:
return None
- #no background
- im, text_polys, text_tags, hv_tags = self.crop_area(im, \
- text_polys, text_tags, hv_tags, crop_background=False)
+ # no background
+ im, text_polys, text_tags, hv_tags = self.crop_area(
+ im, text_polys, text_tags, hv_tags, crop_background=False
+ )
if text_polys.shape[0] == 0:
return None
- #continue for all ignore case
+ # continue for all ignore case
if np.sum((text_tags * 1.0)) >= text_tags.size:
return None
new_h, new_w, _ = im.shape
if (new_h is None) or (new_w is None):
return None
- #resize image
+ # resize image
std_ratio = float(self.input_size) / max(new_w, new_h)
rand_scales = np.array(
- [0.25, 0.375, 0.5, 0.625, 0.75, 0.875, 1.0, 1.0, 1.0, 1.0, 1.0])
+ [0.25, 0.375, 0.5, 0.625, 0.75, 0.875, 1.0, 1.0, 1.0, 1.0, 1.0]
+ )
rz_scale = std_ratio * np.random.choice(rand_scales)
im = cv2.resize(im, dsize=None, fx=rz_scale, fy=rz_scale)
text_polys[:, :, 0] *= rz_scale
text_polys[:, :, 1] *= rz_scale
- #add gaussian blur
+ # add gaussian blur
if np.random.rand() < 0.1 * 0.5:
ks = np.random.permutation(5)[0] + 1
ks = int(ks / 2) * 2 + 1
im = cv2.GaussianBlur(im, ksize=(ks, ks), sigmaX=0, sigmaY=0)
- #add brighter
+ # add brighter
if np.random.rand() < 0.1 * 0.5:
im = im * (1.0 + np.random.rand() * 0.5)
im = np.clip(im, 0.0, 255.0)
- #add darker
+ # add darker
if np.random.rand() < 0.1 * 0.5:
im = im * (1.0 - np.random.rand() * 0.5)
im = np.clip(im, 0.0, 255.0)
@@ -728,8 +760,7 @@ def __call__(self, data):
if min(new_w, new_h) < self.input_size * 0.5:
return None
- im_padded = np.ones(
- (self.input_size, self.input_size, 3), dtype=np.float32)
+ im_padded = np.ones((self.input_size, self.input_size, 3), dtype=np.float32)
im_padded[:, :, 2] = 0.485 * 255
im_padded[:, :, 1] = 0.456 * 255
im_padded[:, :, 0] = 0.406 * 255
@@ -744,12 +775,13 @@ def __call__(self, data):
sw = int(np.random.rand() * del_w)
# Padding
- im_padded[sh:sh + new_h, sw:sw + new_w, :] = im.copy()
+ im_padded[sh : sh + new_h, sw : sw + new_w, :] = im.copy()
text_polys[:, :, 0] += sw
text_polys[:, :, 1] += sh
score_map, border_map, training_mask = self.generate_tcl_label(
- (self.input_size, self.input_size), text_polys, text_tags, 0.25)
+ (self.input_size, self.input_size), text_polys, text_tags, 0.25
+ )
# SAST head
tvo_map, tco_map = self.generate_tvo_and_tco(
@@ -757,21 +789,22 @@ def __call__(self, data):
text_polys,
text_tags,
tcl_ratio=0.3,
- ds_ratio=0.25)
+ ds_ratio=0.25,
+ )
# print("test--------tvo_map shape:", tvo_map.shape)
im_padded[:, :, 2] -= 0.485 * 255
im_padded[:, :, 1] -= 0.456 * 255
im_padded[:, :, 0] -= 0.406 * 255
- im_padded[:, :, 2] /= (255.0 * 0.229)
- im_padded[:, :, 1] /= (255.0 * 0.224)
- im_padded[:, :, 0] /= (255.0 * 0.225)
+ im_padded[:, :, 2] /= 255.0 * 0.229
+ im_padded[:, :, 1] /= 255.0 * 0.224
+ im_padded[:, :, 0] /= 255.0 * 0.225
im_padded = im_padded.transpose((2, 0, 1))
- data['image'] = im_padded[::-1, :, :]
- data['score_map'] = score_map[np.newaxis, :, :]
- data['border_map'] = border_map.transpose((2, 0, 1))
- data['training_mask'] = training_mask[np.newaxis, :, :]
- data['tvo_map'] = tvo_map.transpose((2, 0, 1))
- data['tco_map'] = tco_map.transpose((2, 0, 1))
+ data["image"] = im_padded[::-1, :, :]
+ data["score_map"] = score_map[np.newaxis, :, :]
+ data["border_map"] = border_map.transpose((2, 0, 1))
+ data["training_mask"] = training_mask[np.newaxis, :, :]
+ data["tvo_map"] = tvo_map.transpose((2, 0, 1))
+ data["tco_map"] = tco_map.transpose((2, 0, 1))
return data
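
As a quick check of the normalization the reformatted block above performs: the padded image is mean-shifted and scaled with ImageNet statistics in BGR channel order (hence channel 2 pairing with 0.485), transposed to CHW, and flipped to RGB via [::-1]. A minimal standalone sketch, assuming a float32 HWC BGR image in [0, 255]:

import numpy as np

def normalize_sast(im_padded):
    # Per-channel ImageNet mean/std, listed in B, G, R order to match the
    # channel indexing used in the source above.
    mean = np.array([0.406, 0.456, 0.485], dtype=np.float32)
    std = np.array([0.225, 0.224, 0.229], dtype=np.float32)
    im = (im_padded - mean * 255.0) / (std * 255.0)
    im = im.transpose((2, 0, 1))  # HWC -> CHW
    return im[::-1, :, :]         # reverse channel axis: BGR -> RGB

out = normalize_sast(np.full((512, 512, 3), 128, dtype=np.float32))
print(out.shape)  # (3, 512, 512)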
diff --git a/ppocr/data/imaug/ssl_img_aug.py b/ppocr/data/imaug/ssl_img_aug.py
index f9ed6ac3e2..8162087a47 100644
--- a/ppocr/data/imaug/ssl_img_aug.py
+++ b/ppocr/data/imaug/ssl_img_aug.py
@@ -22,12 +22,9 @@
class SSLRotateResize(object):
- def __init__(self,
- image_shape,
- padding=False,
- select_all=True,
- mode="train",
- **kwargs):
+ def __init__(
+ self, image_shape, padding=False, select_all=True, mode="train", **kwargs
+ ):
self.image_shape = image_shape
self.padding = padding
self.select_all = select_all
@@ -37,18 +34,16 @@ def __call__(self, data):
img = data["image"]
data["image_r90"] = cv2.rotate(img, cv2.ROTATE_90_CLOCKWISE)
- data["image_r180"] = cv2.rotate(data["image_r90"],
- cv2.ROTATE_90_CLOCKWISE)
- data["image_r270"] = cv2.rotate(data["image_r180"],
- cv2.ROTATE_90_CLOCKWISE)
+ data["image_r180"] = cv2.rotate(data["image_r90"], cv2.ROTATE_90_CLOCKWISE)
+ data["image_r270"] = cv2.rotate(data["image_r180"], cv2.ROTATE_90_CLOCKWISE)
images = []
for key in ["image", "image_r90", "image_r180", "image_r270"]:
images.append(
resize_norm_img(
- data.pop(key),
- image_shape=self.image_shape,
- padding=self.padding)[0])
+ data.pop(key), image_shape=self.image_shape, padding=self.padding
+ )[0]
+ )
data["image"] = np.stack(images, axis=0)
data["label"] = np.array(list(range(4)))
if not self.select_all:
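
For context on the hunk above: SSLRotateResize builds the standard rotation pretext task, four copies of the image rotated by 0/90/180/270 degrees with orientation labels 0..3 (each copy is then resized and normalized in the source). A minimal sketch of just the rotation step:

import cv2
import numpy as np

def rotate4(img):
    # img: HWC uint8 image; three successive 90-degree clockwise rotations
    r90 = cv2.rotate(img, cv2.ROTATE_90_CLOCKWISE)
    r180 = cv2.rotate(r90, cv2.ROTATE_90_CLOCKWISE)
    r270 = cv2.rotate(r180, cv2.ROTATE_90_CLOCKWISE)
    return [img, r90, r180, r270], np.array(list(range(4)))

imgs, labels = rotate4(np.zeros((32, 100, 3), dtype=np.uint8))
print([im.shape for im in imgs], labels)  # 90/270 rotations swap H and W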
diff --git a/ppocr/data/imaug/table_ops.py b/ppocr/data/imaug/table_ops.py
index c2c2fb2be6..ac21c30774 100644
--- a/ppocr/data/imaug/table_ops.py
+++ b/ppocr/data/imaug/table_ops.py
@@ -26,7 +26,7 @@
class GenTableMask(object):
- """ gen table mask """
+ """gen table mask"""
def __init__(self, shrink_h_max, shrink_w_max, mask_type=0, **kwargs):
self.shrink_h_max = 5
@@ -48,12 +48,10 @@ def projection(self, erosion, h, w, spilt_threshold=0):
in_text = False # whether we are currently inside a text region
box_list = []
for i in range(len(project_val_array)):
- if in_text == False and project_val_array[
- i] > spilt_threshold: # 进入字符区了
+ if in_text == False and project_val_array[i] > spilt_threshold: # entered a text region
in_text = True
start_idx = i
- elif project_val_array[
- i] <= spilt_threshold and in_text == True: # 进入空白区了
+ elif project_val_array[i] <= spilt_threshold and in_text == True: # entered a blank region
end_idx = i
in_text = False
if end_idx - start_idx <= 2:
@@ -72,8 +70,7 @@ def projection_cx(self, box_img):
box_gray_img = cv2.cvtColor(box_img, cv2.COLOR_BGR2GRAY)
h, w = box_gray_img.shape
# binarize the grayscale image
- ret, thresh1 = cv2.threshold(box_gray_img, 200, 255,
- cv2.THRESH_BINARY_INV)
+ ret, thresh1 = cv2.threshold(box_gray_img, 200, 255, cv2.THRESH_BINARY_INV)
# vertical erosion
if h < w:
kernel = np.ones((2, 1), np.uint8)
@@ -98,12 +95,10 @@ def projection_cx(self, box_img):
box_list = []
spilt_threshold = 0
for i in range(len(project_val_array)):
- if in_text == False and project_val_array[
- i] > spilt_threshold: # 进入字符区了
+ if in_text == False and project_val_array[i] > spilt_threshold: # entered a text region
in_text = True
start_idx = i
- elif project_val_array[
- i] <= spilt_threshold and in_text == True: # 进入空白区了
+ elif project_val_array[i] <= spilt_threshold and in_text == True: # entered a blank region
end_idx = i
in_text = False
if end_idx - start_idx <= 2:
@@ -123,15 +118,16 @@ def projection_cx(self, box_img):
h_start = 0
if i == len(box_list):
h_end = h
- word_img = erosion[h_start:h_end + 1, :]
+ word_img = erosion[h_start : h_end + 1, :]
word_h, word_w = word_img.shape
- w_split_list, w_projection_map = self.projection(word_img.T,
- word_w, word_h)
+ w_split_list, w_projection_map = self.projection(
+ word_img.T, word_w, word_h
+ )
w_start, w_end = w_split_list[0][0], w_split_list[-1][1]
if h_start > 0:
h_start -= 1
h_end += 1
- word_img = box_img[h_start:h_end + 1:, w_start:w_end + 1, :]
+ word_img = box_img[h_start : h_end + 1 :, w_start : w_end + 1, :]
split_bbox_list.append([w_start, h_start, w_end, h_end])
else:
split_bbox_list.append([0, 0, w, h])
@@ -154,8 +150,8 @@ def shrink_bbox(self, bbox):
return [left_new, top_new, right_new, bottom_new]
def __call__(self, data):
- img = data['image']
- cells = data['cells']
+ img = data["image"]
+ cells = data["cells"]
height, width = img.shape[0:2]
if self.mask_type == 1:
mask_img = np.zeros((height, width), dtype=np.float32)
@@ -164,7 +160,7 @@ def __call__(self, data):
cell_num = len(cells)
for cno in range(cell_num):
if "bbox" in cells[cno]:
- bbox = cells[cno]['bbox']
+ bbox = cells[cno]["bbox"]
left, top, right, bottom = bbox
box_img = img[top:bottom, left:right, :].copy()
split_bbox_list = self.projection_cx(box_img)
@@ -177,37 +173,37 @@ def __call__(self, data):
for sno in range(len(split_bbox_list)):
left, top, right, bottom = split_bbox_list[sno]
left, top, right, bottom = self.shrink_bbox(
- [left, top, right, bottom])
+ [left, top, right, bottom]
+ )
if self.mask_type == 1:
mask_img[top:bottom, left:right] = 1.0
- data['mask_img'] = mask_img
+ data["mask_img"] = mask_img
else:
mask_img[top:bottom, left:right, :] = (255, 255, 255)
- data['image'] = mask_img
+ data["image"] = mask_img
return data
class ResizeTableImage(object):
- def __init__(self, max_len, resize_bboxes=False, infer_mode=False,
- **kwargs):
+ def __init__(self, max_len, resize_bboxes=False, infer_mode=False, **kwargs):
super(ResizeTableImage, self).__init__()
self.max_len = max_len
self.resize_bboxes = resize_bboxes
self.infer_mode = infer_mode
def __call__(self, data):
- img = data['image']
+ img = data["image"]
height, width = img.shape[0:2]
ratio = self.max_len / (max(height, width) * 1.0)
resize_h = int(height * ratio)
resize_w = int(width * ratio)
resize_img = cv2.resize(img, (resize_w, resize_h))
if self.resize_bboxes and not self.infer_mode:
- data['bboxes'] = data['bboxes'] * ratio
- data['image'] = resize_img
- data['src_img'] = img
- data['shape'] = np.array([height, width, ratio, ratio])
- data['max_len'] = self.max_len
+ data["bboxes"] = data["bboxes"] * ratio
+ data["image"] = resize_img
+ data["src_img"] = img
+ data["shape"] = np.array([height, width, ratio, ratio])
+ data["max_len"] = self.max_len
return data
@@ -217,13 +213,13 @@ def __init__(self, size, **kwargs):
self.size = size
def __call__(self, data):
- img = data['image']
+ img = data["image"]
pad_h, pad_w = self.size
padding_img = np.zeros((pad_h, pad_w, 3), dtype=np.float32)
height, width = img.shape[0:2]
padding_img[0:height, 0:width, :] = img.copy()
- data['image'] = padding_img
- shape = data['shape'].tolist()
+ data["image"] = padding_img
+ shape = data["shape"].tolist()
shape.extend([pad_h, pad_w])
- data['shape'] = np.array(shape)
+ data["shape"] = np.array(shape)
return data
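
The projection helpers reformatted above share one idea: sum a binarized cell image along rows (or columns for the transposed call), then walk the profile and keep runs above a threshold, dropping runs of height 2 or less. A standalone sketch under those assumptions, with illustrative names:

import numpy as np

def projection_runs(binary, threshold=0):
    profile = binary.sum(axis=1)  # per-row ink count
    runs, in_text, start = [], False, 0
    for i, v in enumerate(profile):
        if not in_text and v > threshold:   # entered a text region
            in_text, start = True, i
        elif in_text and v <= threshold:    # entered a blank region
            if i - start > 2:               # drop very thin runs, as the source does
                runs.append((start, i))
            in_text = False
    if in_text and len(profile) - start > 2:
        runs.append((start, len(profile)))
    return runs

page = np.zeros((12, 8), dtype=np.uint8)
page[2:5] = 1   # a 3-row text band
page[6:9] = 1   # another 3-row band
print(projection_runs(page))  # [(2, 5), (6, 9)]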
diff --git a/ppocr/data/imaug/text_image_aug/__init__.py b/ppocr/data/imaug/text_image_aug/__init__.py
index bca262638e..16f179ff92 100644
--- a/ppocr/data/imaug/text_image_aug/__init__.py
+++ b/ppocr/data/imaug/text_image_aug/__init__.py
@@ -14,4 +14,4 @@
from .augment import tia_perspective, tia_distort, tia_stretch
-__all__ = ['tia_distort', 'tia_stretch', 'tia_perspective']
+__all__ = ["tia_distort", "tia_stretch", "tia_perspective"]
diff --git a/ppocr/data/imaug/text_image_aug/augment.py b/ppocr/data/imaug/text_image_aug/augment.py
index 2d15dd5f35..1044abda6c 100644
--- a/ppocr/data/imaug/text_image_aug/augment.py
+++ b/ppocr/data/imaug/text_image_aug/augment.py
@@ -35,26 +35,29 @@ def tia_distort(src, segment=4):
src_pts.append([0, img_h])
dst_pts.append([np.random.randint(thresh), np.random.randint(thresh)])
+ dst_pts.append([img_w - np.random.randint(thresh), np.random.randint(thresh)])
dst_pts.append(
- [img_w - np.random.randint(thresh), np.random.randint(thresh)])
- dst_pts.append(
- [img_w - np.random.randint(thresh), img_h - np.random.randint(thresh)])
- dst_pts.append(
- [np.random.randint(thresh), img_h - np.random.randint(thresh)])
+ [img_w - np.random.randint(thresh), img_h - np.random.randint(thresh)]
+ )
+ dst_pts.append([np.random.randint(thresh), img_h - np.random.randint(thresh)])
half_thresh = thresh * 0.5
for cut_idx in np.arange(1, segment, 1):
src_pts.append([cut * cut_idx, 0])
src_pts.append([cut * cut_idx, img_h])
- dst_pts.append([
- cut * cut_idx + np.random.randint(thresh) - half_thresh,
- np.random.randint(thresh) - half_thresh
- ])
- dst_pts.append([
- cut * cut_idx + np.random.randint(thresh) - half_thresh,
- img_h + np.random.randint(thresh) - half_thresh
- ])
+ dst_pts.append(
+ [
+ cut * cut_idx + np.random.randint(thresh) - half_thresh,
+ np.random.randint(thresh) - half_thresh,
+ ]
+ )
+ dst_pts.append(
+ [
+ cut * cut_idx + np.random.randint(thresh) - half_thresh,
+ img_h + np.random.randint(thresh) - half_thresh,
+ ]
+ )
trans = WarpMLS(src, src_pts, dst_pts, img_w, img_h)
dst = trans.generate()
@@ -117,4 +120,4 @@ def tia_perspective(src):
trans = WarpMLS(src, src_pts, dst_pts, img_w, img_h)
dst = trans.generate()
- return dst
\ No newline at end of file
+ return dst
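
Hypothetical usage of the three TIA warps touched above, chained the way recognition training typically applies them; tia_distort's signature is visible in this hunk, the other two are assumed to follow the same pattern, and word.jpg is an assumed sample crop of a text line:

import cv2
from ppocr.data.imaug.text_image_aug import tia_distort, tia_stretch, tia_perspective

img = cv2.imread("word.jpg")       # assumed input: HWC uint8 text-line crop
aug = tia_distort(img, segment=4)  # jitter corner and cut-point control points
aug = tia_stretch(aug, segment=4)  # horizontal rubber-band stretch
aug = tia_perspective(aug)         # mild perspective warp
print(aug.shape)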
diff --git a/ppocr/data/imaug/text_image_aug/warp_mls.py b/ppocr/data/imaug/text_image_aug/warp_mls.py
index 75de11115c..2d349a628f 100644
--- a/ppocr/data/imaug/text_image_aug/warp_mls.py
+++ b/ppocr/data/imaug/text_image_aug/warp_mls.py
@@ -20,7 +20,7 @@
class WarpMLS:
- def __init__(self, src, src_pts, dst_pts, dst_w, dst_h, trans_ratio=1.):
+ def __init__(self, src, src_pts, dst_pts, dst_w, dst_h, trans_ratio=1.0):
self.src = src
self.src_pts = src_pts
self.dst_pts = dst_pts
@@ -34,8 +34,7 @@ def __init__(self, src, src_pts, dst_pts, dst_w, dst_h, trans_ratio=1.):
@staticmethod
def __bilinear_interp(x, y, v11, v12, v21, v22):
- return (v11 * (1 - y) + v12 * y) * (1 - x) + (v21 *
- (1 - y) + v22 * y) * x
+ return (v11 * (1 - y) + v12 * y) * (1 - x) + (v21 * (1 - y) + v22 * y) * x
def generate(self):
self.calc_delta()
@@ -72,9 +71,10 @@ def calc_delta(self):
if i == self.dst_pts[k][0] and j == self.dst_pts[k][1]:
break
- w[k] = 1. / (
- (i - self.dst_pts[k][0]) * (i - self.dst_pts[k][0]) +
- (j - self.dst_pts[k][1]) * (j - self.dst_pts[k][1]))
+ w[k] = 1.0 / (
+ (i - self.dst_pts[k][0]) * (i - self.dst_pts[k][0])
+ + (j - self.dst_pts[k][1]) * (j - self.dst_pts[k][1])
+ )
sw += w[k]
swp = swp + w[k] * np.array(self.dst_pts[k])
@@ -102,11 +102,15 @@ def calc_delta(self):
pt_j = np.array([-pt_i[1], pt_i[0]])
tmp_pt = np.zeros(2, dtype=np.float32)
- tmp_pt[0] = np.sum(pt_i * cur_pt) * self.src_pts[k][0] - \
- np.sum(pt_j * cur_pt) * self.src_pts[k][1]
- tmp_pt[1] = -np.sum(pt_i * cur_pt_j) * self.src_pts[k][0] + \
- np.sum(pt_j * cur_pt_j) * self.src_pts[k][1]
- tmp_pt *= (w[k] / miu_s)
+ tmp_pt[0] = (
+ np.sum(pt_i * cur_pt) * self.src_pts[k][0]
+ - np.sum(pt_j * cur_pt) * self.src_pts[k][1]
+ )
+ tmp_pt[1] = (
+ -np.sum(pt_i * cur_pt_j) * self.src_pts[k][0]
+ + np.sum(pt_j * cur_pt_j) * self.src_pts[k][1]
+ )
+ tmp_pt *= w[k] / miu_s
new_pt += tmp_pt
new_pt += qstar
@@ -138,11 +142,21 @@ def gen_img(self):
di = np.reshape(np.arange(h), (-1, 1))
dj = np.reshape(np.arange(w), (1, -1))
delta_x = self.__bilinear_interp(
- di / h, dj / w, self.rdx[i, j], self.rdx[i, nj],
- self.rdx[ni, j], self.rdx[ni, nj])
+ di / h,
+ dj / w,
+ self.rdx[i, j],
+ self.rdx[i, nj],
+ self.rdx[ni, j],
+ self.rdx[ni, nj],
+ )
delta_y = self.__bilinear_interp(
- di / h, dj / w, self.rdy[i, j], self.rdy[i, nj],
- self.rdy[ni, j], self.rdy[ni, nj])
+ di / h,
+ dj / w,
+ self.rdy[i, j],
+ self.rdy[i, nj],
+ self.rdy[ni, j],
+ self.rdy[ni, nj],
+ )
nx = j + dj + delta_x * self.trans_ratio
ny = i + di + delta_y * self.trans_ratio
nx = np.clip(nx, 0, src_w - 1)
@@ -158,9 +172,14 @@ def gen_img(self):
else:
x = ny - nyi
y = nx - nxi
- dst[i:i + h, j:j + w] = self.__bilinear_interp(
- x, y, self.src[nyi, nxi], self.src[nyi, nxi1],
- self.src[nyi1, nxi], self.src[nyi1, nxi1])
+ dst[i : i + h, j : j + w] = self.__bilinear_interp(
+ x,
+ y,
+ self.src[nyi, nxi],
+ self.src[nyi, nxi1],
+ self.src[nyi1, nxi],
+ self.src[nyi1, nxi1],
+ )
dst = np.clip(dst, 0, 255)
dst = np.array(dst, dtype=np.uint8)
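
The __bilinear_interp helper reflowed above is plain bilinear blending: interpolate between (v11, v12) and (v21, v22) along y, then between the two results along x. A quick standalone check at the corners and the center:

def bilinear(x, y, v11, v12, v21, v22):
    # blend along y first, then along x
    return (v11 * (1 - y) + v12 * y) * (1 - x) + (v21 * (1 - y) + v22 * y) * x

print(bilinear(0, 0, 1.0, 2.0, 3.0, 4.0))      # 1.0 -> v11
print(bilinear(0, 1, 1.0, 2.0, 3.0, 4.0))      # 2.0 -> v12
print(bilinear(1, 0, 1.0, 2.0, 3.0, 4.0))      # 3.0 -> v21
print(bilinear(0.5, 0.5, 1.0, 2.0, 3.0, 4.0))  # 2.5 -> average of all four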
diff --git a/ppocr/data/imaug/vqa/__init__.py b/ppocr/data/imaug/vqa/__init__.py
index 73f7dcdf71..812f20b541 100644
--- a/ppocr/data/imaug/vqa/__init__.py
+++ b/ppocr/data/imaug/vqa/__init__.py
@@ -12,9 +12,18 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from .token import VQATokenPad, VQASerTokenChunk, VQAReTokenChunk, VQAReTokenRelation, TensorizeEntitiesRelations
+from .token import (
+ VQATokenPad,
+ VQASerTokenChunk,
+ VQAReTokenChunk,
+ VQAReTokenRelation,
+ TensorizeEntitiesRelations,
+)
__all__ = [
- 'VQATokenPad', 'VQASerTokenChunk', 'VQAReTokenChunk', 'VQAReTokenRelation',
- 'TensorizeEntitiesRelations'
+ "VQATokenPad",
+ "VQASerTokenChunk",
+ "VQAReTokenChunk",
+ "VQAReTokenRelation",
+ "TensorizeEntitiesRelations",
]
diff --git a/ppocr/data/imaug/vqa/augment.py b/ppocr/data/imaug/vqa/augment.py
index b95fcdf0f0..d4f4cf2b12 100644
--- a/ppocr/data/imaug/vqa/augment.py
+++ b/ppocr/data/imaug/vqa/augment.py
@@ -23,8 +23,9 @@ def order_by_tbyx(ocr_info):
res = sorted(ocr_info, key=lambda r: (r["bbox"][1], r["bbox"][0]))
for i in range(len(res) - 1):
for j in range(i, 0, -1):
- if abs(res[j + 1]["bbox"][1] - res[j]["bbox"][1]) < 20 and \
- (res[j + 1]["bbox"][0] < res[j]["bbox"][0]):
+ if abs(res[j + 1]["bbox"][1] - res[j]["bbox"][1]) < 20 and (
+ res[j + 1]["bbox"][0] < res[j]["bbox"][0]
+ ):
tmp = deepcopy(res[j])
res[j] = deepcopy(res[j + 1])
res[j + 1] = deepcopy(tmp)
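
To see what the rewrapped condition above does: order_by_tbyx first sorts boxes by (y1, x1), then bubbles a box leftward past a neighbor when the two share a visual line (y1 within 20 px) but are out of x order. A standalone copy (same loop structure as the source) with illustrative [x1, y1, x2, y2] boxes:

from copy import deepcopy

def order_by_tbyx(ocr_info):
    res = sorted(ocr_info, key=lambda r: (r["bbox"][1], r["bbox"][0]))
    for i in range(len(res) - 1):
        for j in range(i, 0, -1):
            if abs(res[j + 1]["bbox"][1] - res[j]["bbox"][1]) < 20 and (
                res[j + 1]["bbox"][0] < res[j]["bbox"][0]
            ):
                res[j], res[j + 1] = deepcopy(res[j + 1]), deepcopy(res[j])
    return res

boxes = [
    {"bbox": [10, 5, 90, 22]},     # header line
    {"bbox": [300, 40, 380, 58]},  # second line, right box
    {"bbox": [10, 43, 90, 60]},    # second line, left box (slightly lower y1)
]
print([b["bbox"][:2] for b in order_by_tbyx(boxes)])
# [[10, 5], [10, 43], [300, 40]]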
diff --git a/ppocr/data/imaug/vqa/token/__init__.py b/ppocr/data/imaug/vqa/token/__init__.py
index 5fbaa43db9..e349dd7a03 100644
--- a/ppocr/data/imaug/vqa/token/__init__.py
+++ b/ppocr/data/imaug/vqa/token/__init__.py
@@ -15,4 +15,4 @@
from .vqa_token_chunk import VQASerTokenChunk, VQAReTokenChunk
from .vqa_token_pad import VQATokenPad
from .vqa_token_relation import VQAReTokenRelation
-from .vqa_re_convert import TensorizeEntitiesRelations
\ No newline at end of file
+from .vqa_re_convert import TensorizeEntitiesRelations
diff --git a/ppocr/data/imaug/vqa/token/vqa_re_convert.py b/ppocr/data/imaug/vqa/token/vqa_re_convert.py
index 86962f2590..fa149156c2 100644
--- a/ppocr/data/imaug/vqa/token/vqa_re_convert.py
+++ b/ppocr/data/imaug/vqa/token/vqa_re_convert.py
@@ -21,31 +21,29 @@ def __init__(self, max_seq_len=512, infer_mode=False, **kwargs):
self.infer_mode = infer_mode
def __call__(self, data):
- entities = data['entities']
- relations = data['relations']
+ entities = data["entities"]
+ relations = data["relations"]
entities_new = np.full(
- shape=[self.max_seq_len + 1, 3], fill_value=-1, dtype='int64')
- entities_new[0, 0] = len(entities['start'])
- entities_new[0, 1] = len(entities['end'])
- entities_new[0, 2] = len(entities['label'])
- entities_new[1:len(entities['start']) + 1, 0] = np.array(entities[
- 'start'])
- entities_new[1:len(entities['end']) + 1, 1] = np.array(entities['end'])
- entities_new[1:len(entities['label']) + 1, 2] = np.array(entities[
- 'label'])
+ shape=[self.max_seq_len + 1, 3], fill_value=-1, dtype="int64"
+ )
+ entities_new[0, 0] = len(entities["start"])
+ entities_new[0, 1] = len(entities["end"])
+ entities_new[0, 2] = len(entities["label"])
+ entities_new[1 : len(entities["start"]) + 1, 0] = np.array(entities["start"])
+ entities_new[1 : len(entities["end"]) + 1, 1] = np.array(entities["end"])
+ entities_new[1 : len(entities["label"]) + 1, 2] = np.array(entities["label"])
relations_new = np.full(
shape=[self.max_seq_len * self.max_seq_len + 1, 2],
fill_value=-1,
- dtype='int64')
- relations_new[0, 0] = len(relations['head'])
- relations_new[0, 1] = len(relations['tail'])
- relations_new[1:len(relations['head']) + 1, 0] = np.array(relations[
- 'head'])
- relations_new[1:len(relations['tail']) + 1, 1] = np.array(relations[
- 'tail'])
+ dtype="int64",
+ )
+ relations_new[0, 0] = len(relations["head"])
+ relations_new[0, 1] = len(relations["tail"])
+ relations_new[1 : len(relations["head"]) + 1, 0] = np.array(relations["head"])
+ relations_new[1 : len(relations["tail"]) + 1, 1] = np.array(relations["tail"])
- data['entities'] = entities_new
- data['relations'] = relations_new
+ data["entities"] = entities_new
+ data["relations"] = relations_new
return data
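
The tensorization reformatted above packs variable-length entity lists into one fixed-shape int64 array so samples batch cleanly: row 0 holds the counts, rows 1..n the values, and -1 pads the remainder. A toy illustration with an assumed small max_seq_len (the source defaults to 512):

import numpy as np

max_seq_len = 6
entities = {"start": [1, 4], "end": [3, 6], "label": [0, 2]}

entities_new = np.full((max_seq_len + 1, 3), fill_value=-1, dtype="int64")
entities_new[0] = [len(entities["start"]), len(entities["end"]), len(entities["label"])]
entities_new[1 : len(entities["start"]) + 1, 0] = entities["start"]
entities_new[1 : len(entities["end"]) + 1, 1] = entities["end"]
entities_new[1 : len(entities["label"]) + 1, 2] = entities["label"]
print(entities_new[:4])
# [[ 2  2  2]
#  [ 1  3  0]
#  [ 4  6  2]
#  [-1 -1 -1]]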
diff --git a/ppocr/data/imaug/vqa/token/vqa_token_chunk.py b/ppocr/data/imaug/vqa/token/vqa_token_chunk.py
index 1fa949e688..d46a47d26c 100644
--- a/ppocr/data/imaug/vqa/token/vqa_token_chunk.py
+++ b/ppocr/data/imaug/vqa/token/vqa_token_chunk.py
@@ -22,21 +22,24 @@ def __init__(self, max_seq_len=512, infer_mode=False, **kwargs):
def __call__(self, data):
encoded_inputs_all = []
- seq_len = len(data['input_ids'])
+ seq_len = len(data["input_ids"])
for index in range(0, seq_len, self.max_seq_len):
chunk_beg = index
chunk_end = min(index + self.max_seq_len, seq_len)
encoded_inputs_example = {}
for key in data:
if key in [
- 'label', 'input_ids', 'labels', 'token_type_ids',
- 'bbox', 'attention_mask'
+ "label",
+ "input_ids",
+ "labels",
+ "token_type_ids",
+ "bbox",
+ "attention_mask",
]:
- if self.infer_mode and key == 'labels':
+ if self.infer_mode and key == "labels":
encoded_inputs_example[key] = data[key]
else:
- encoded_inputs_example[key] = data[key][chunk_beg:
- chunk_end]
+ encoded_inputs_example[key] = data[key][chunk_beg:chunk_end]
else:
encoded_inputs_example[key] = data[key]
@@ -47,43 +50,47 @@ def __call__(self, data):
class VQAReTokenChunk(object):
- def __init__(self,
- max_seq_len=512,
- entities_labels=None,
- infer_mode=False,
- **kwargs):
+ def __init__(
+ self, max_seq_len=512, entities_labels=None, infer_mode=False, **kwargs
+ ):
self.max_seq_len = max_seq_len
- self.entities_labels = {
- 'HEADER': 0,
- 'QUESTION': 1,
- 'ANSWER': 2
- } if entities_labels is None else entities_labels
+ self.entities_labels = (
+ {"HEADER": 0, "QUESTION": 1, "ANSWER": 2}
+ if entities_labels is None
+ else entities_labels
+ )
self.infer_mode = infer_mode
def __call__(self, data):
# prepare data
- entities = data.pop('entities')
- relations = data.pop('relations')
+ entities = data.pop("entities")
+ relations = data.pop("relations")
encoded_inputs_all = []
for index in range(0, len(data["input_ids"]), self.max_seq_len):
item = {}
for key in data:
if key in [
- 'label', 'input_ids', 'labels', 'token_type_ids',
- 'bbox', 'attention_mask'
+ "label",
+ "input_ids",
+ "labels",
+ "token_type_ids",
+ "bbox",
+ "attention_mask",
]:
- if self.infer_mode and key == 'labels':
+ if self.infer_mode and key == "labels":
item[key] = data[key]
else:
- item[key] = data[key][index:index + self.max_seq_len]
+ item[key] = data[key][index : index + self.max_seq_len]
else:
item[key] = data[key]
# select entity in current chunk
entities_in_this_span = []
global_to_local_map = {} #
for entity_id, entity in enumerate(entities):
- if (index <= entity["start"] < index + self.max_seq_len and
- index <= entity["end"] < index + self.max_seq_len):
+ if (
+ index <= entity["start"] < index + self.max_seq_len
+ and index <= entity["end"] < index + self.max_seq_len
+ ):
entity["start"] = entity["start"] - index
entity["end"] = entity["end"] - index
global_to_local_map[entity_id] = len(entities_in_this_span)
@@ -92,22 +99,27 @@ def __call__(self, data):
# select relations in current chunk
relations_in_this_span = []
for relation in relations:
- if (index <= relation["start_index"] < index + self.max_seq_len
- and index <= relation["end_index"] <
- index + self.max_seq_len):
- relations_in_this_span.append({
- "head": global_to_local_map[relation["head"]],
- "tail": global_to_local_map[relation["tail"]],
- "start_index": relation["start_index"] - index,
- "end_index": relation["end_index"] - index,
- })
- item.update({
- "entities": self.reformat(entities_in_this_span),
- "relations": self.reformat(relations_in_this_span),
- })
- if len(item['entities']) > 0:
- item['entities']['label'] = [
- self.entities_labels[x] for x in item['entities']['label']
+ if (
+ index <= relation["start_index"] < index + self.max_seq_len
+ and index <= relation["end_index"] < index + self.max_seq_len
+ ):
+ relations_in_this_span.append(
+ {
+ "head": global_to_local_map[relation["head"]],
+ "tail": global_to_local_map[relation["tail"]],
+ "start_index": relation["start_index"] - index,
+ "end_index": relation["end_index"] - index,
+ }
+ )
+ item.update(
+ {
+ "entities": self.reformat(entities_in_this_span),
+ "relations": self.reformat(relations_in_this_span),
+ }
+ )
+ if len(item["entities"]) > 0:
+ item["entities"]["label"] = [
+ self.entities_labels[x] for x in item["entities"]["label"]
]
encoded_inputs_all.append(item)
if len(encoded_inputs_all) == 0:
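
Both chunkers above implement the same windowing: sequence-like keys are sliced into max_seq_len-sized pieces while scalar keys are copied into every chunk. A minimal sketch of that scheme, with an abbreviated key set:

def chunk(data, max_seq_len=4, seq_keys=("input_ids", "labels", "bbox")):
    # slice sequence-like keys into windows; copy everything else per chunk
    seq_len = len(data["input_ids"])
    chunks = []
    for beg in range(0, seq_len, max_seq_len):
        end = min(beg + max_seq_len, seq_len)
        chunks.append(
            {k: (v[beg:end] if k in seq_keys else v) for k, v in data.items()}
        )
    return chunks

sample = {"input_ids": list(range(10)), "labels": list(range(10)), "image_id": 7}
for c in chunk(sample):
    print(c["input_ids"], c["image_id"])
# [0, 1, 2, 3] 7 / [4, 5, 6, 7] 7 / [8, 9] 7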
diff --git a/ppocr/data/imaug/vqa/token/vqa_token_pad.py b/ppocr/data/imaug/vqa/token/vqa_token_pad.py
index 8e5a20f95f..5983cb75e8 100644
--- a/ppocr/data/imaug/vqa/token/vqa_token_pad.py
+++ b/ppocr/data/imaug/vqa/token/vqa_token_pad.py
@@ -16,16 +16,18 @@
class VQATokenPad(object):
- def __init__(self,
- max_seq_len=512,
- pad_to_max_seq_len=True,
- return_attention_mask=True,
- return_token_type_ids=True,
- truncation_strategy="longest_first",
- return_overflowing_tokens=False,
- return_special_tokens_mask=False,
- infer_mode=False,
- **kwargs):
+ def __init__(
+ self,
+ max_seq_len=512,
+ pad_to_max_seq_len=True,
+ return_attention_mask=True,
+ return_token_type_ids=True,
+ truncation_strategy="longest_first",
+ return_overflowing_tokens=False,
+ return_special_tokens_mask=False,
+ infer_mode=False,
+ **kwargs
+ ):
self.max_seq_len = max_seq_len
self.pad_to_max_seq_len = max_seq_len
self.return_attention_mask = return_attention_mask
@@ -37,53 +39,61 @@ def __init__(self,
self.infer_mode = infer_mode
def __call__(self, data):
- needs_to_be_padded = self.pad_to_max_seq_len and len(data[
- "input_ids"]) < self.max_seq_len
+ needs_to_be_padded = (
+ self.pad_to_max_seq_len and len(data["input_ids"]) < self.max_seq_len
+ )
if needs_to_be_padded:
- if 'tokenizer_params' in data:
- tokenizer_params = data.pop('tokenizer_params')
+ if "tokenizer_params" in data:
+ tokenizer_params = data.pop("tokenizer_params")
else:
tokenizer_params = dict(
- padding_side='right', pad_token_type_id=0, pad_token_id=1)
+ padding_side="right", pad_token_type_id=0, pad_token_id=1
+ )
difference = self.max_seq_len - len(data["input_ids"])
- if tokenizer_params['padding_side'] == 'right':
+ if tokenizer_params["padding_side"] == "right":
if self.return_attention_mask:
- data["attention_mask"] = [1] * len(data[
- "input_ids"]) + [0] * difference
+ data["attention_mask"] = [1] * len(data["input_ids"]) + [
+ 0
+ ] * difference
if self.return_token_type_ids:
data["token_type_ids"] = (
- data["token_type_ids"] +
- [tokenizer_params['pad_token_type_id']] * difference)
+ data["token_type_ids"]
+ + [tokenizer_params["pad_token_type_id"]] * difference
+ )
if self.return_special_tokens_mask:
- data["special_tokens_mask"] = data[
- "special_tokens_mask"] + [1] * difference
- data["input_ids"] = data["input_ids"] + [
- tokenizer_params['pad_token_id']
- ] * difference
+ data["special_tokens_mask"] = (
+ data["special_tokens_mask"] + [1] * difference
+ )
+ data["input_ids"] = (
+ data["input_ids"] + [tokenizer_params["pad_token_id"]] * difference
+ )
if not self.infer_mode:
- data["labels"] = data[
- "labels"] + [self.pad_token_label_id] * difference
+ data["labels"] = (
+ data["labels"] + [self.pad_token_label_id] * difference
+ )
data["bbox"] = data["bbox"] + [[0, 0, 0, 0]] * difference
- elif tokenizer_params['padding_side'] == 'left':
+ elif tokenizer_params["padding_side"] == "left":
if self.return_attention_mask:
- data["attention_mask"] = [0] * difference + [
- 1
- ] * len(data["input_ids"])
+ data["attention_mask"] = [0] * difference + [1] * len(
+ data["input_ids"]
+ )
if self.return_token_type_ids:
- data["token_type_ids"] = (
- [tokenizer_params['pad_token_type_id']] * difference +
- data["token_type_ids"])
+ data["token_type_ids"] = [
+ tokenizer_params["pad_token_type_id"]
+ ] * difference + data["token_type_ids"]
if self.return_special_tokens_mask:
- data["special_tokens_mask"] = [
- 1
- ] * difference + data["special_tokens_mask"]
- data["input_ids"] = [tokenizer_params['pad_token_id']
- ] * difference + data["input_ids"]
+ data["special_tokens_mask"] = [1] * difference + data[
+ "special_tokens_mask"
+ ]
+ data["input_ids"] = [
+ tokenizer_params["pad_token_id"]
+ ] * difference + data["input_ids"]
if not self.infer_mode:
- data["labels"] = [self.pad_token_label_id
- ] * difference + data["labels"]
+ data["labels"] = [self.pad_token_label_id] * difference + data[
+ "labels"
+ ]
data["bbox"] = [[0, 0, 0, 0]] * difference + data["bbox"]
else:
if self.return_attention_mask:
@@ -91,14 +101,17 @@ def __call__(self, data):
for key in data:
if key in [
- 'input_ids', 'labels', 'token_type_ids', 'bbox',
- 'attention_mask'
+ "input_ids",
+ "labels",
+ "token_type_ids",
+ "bbox",
+ "attention_mask",
]:
if self.infer_mode:
- if key != 'labels':
+ if key != "labels":
length = min(len(data[key]), self.max_seq_len)
data[key] = data[key][:length]
else:
continue
- data[key] = np.array(data[key], dtype='int64')
+ data[key] = np.array(data[key], dtype="int64")
return data
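
Reduced to its essentials, the right-side branch above grows input_ids with the pad token and attention_mask with zeros until max_seq_len is reached; pad_token_id=1 mirrors the default tokenizer_params in this file. A small sketch:

def pad_right(input_ids, max_seq_len=8, pad_token_id=1):
    # extend ids with the pad token and the mask with zeros up to max_seq_len
    difference = max_seq_len - len(input_ids)
    attention_mask = [1] * len(input_ids) + [0] * difference
    input_ids = input_ids + [pad_token_id] * difference
    return input_ids, attention_mask

ids, mask = pad_right([101, 7, 9, 102])
print(ids)   # [101, 7, 9, 102, 1, 1, 1, 1]
print(mask)  # [1, 1, 1, 1, 0, 0, 0, 0]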
diff --git a/ppocr/data/imaug/vqa/token/vqa_token_relation.py b/ppocr/data/imaug/vqa/token/vqa_token_relation.py
index 293988ff85..1946c58305 100644
--- a/ppocr/data/imaug/vqa/token/vqa_token_relation.py
+++ b/ppocr/data/imaug/vqa/token/vqa_token_relation.py
@@ -21,42 +21,51 @@ def __call__(self, data):
"""
build relations
"""
- entities = data['entities']
- relations = data['relations']
- id2label = data.pop('id2label')
- empty_entity = data.pop('empty_entity')
- entity_id_to_index_map = data.pop('entity_id_to_index_map')
+ entities = data["entities"]
+ relations = data["relations"]
+ id2label = data.pop("id2label")
+ empty_entity = data.pop("empty_entity")
+ entity_id_to_index_map = data.pop("entity_id_to_index_map")
relations = list(set(relations))
relations = [
- rel for rel in relations
+ rel
+ for rel in relations
if rel[0] not in empty_entity and rel[1] not in empty_entity
]
kv_relations = []
for rel in relations:
pair = [id2label[rel[0]], id2label[rel[1]]]
if pair == ["question", "answer"]:
- kv_relations.append({
- "head": entity_id_to_index_map[rel[0]],
- "tail": entity_id_to_index_map[rel[1]]
- })
+ kv_relations.append(
+ {
+ "head": entity_id_to_index_map[rel[0]],
+ "tail": entity_id_to_index_map[rel[1]],
+ }
+ )
elif pair == ["answer", "question"]:
- kv_relations.append({
- "head": entity_id_to_index_map[rel[1]],
- "tail": entity_id_to_index_map[rel[0]]
- })
+ kv_relations.append(
+ {
+ "head": entity_id_to_index_map[rel[1]],
+ "tail": entity_id_to_index_map[rel[0]],
+ }
+ )
else:
continue
relations = sorted(
- [{
- "head": rel["head"],
- "tail": rel["tail"],
- "start_index": self.get_relation_span(rel, entities)[0],
- "end_index": self.get_relation_span(rel, entities)[1],
- } for rel in kv_relations],
- key=lambda x: x["head"], )
+ [
+ {
+ "head": rel["head"],
+ "tail": rel["tail"],
+ "start_index": self.get_relation_span(rel, entities)[0],
+ "end_index": self.get_relation_span(rel, entities)[1],
+ }
+ for rel in kv_relations
+ ],
+ key=lambda x: x["head"],
+ )
- data['relations'] = relations
+ data["relations"] = relations
return data
def get_relation_span(self, rel, entities):
diff --git a/ppocr/data/lmdb_dataset.py b/ppocr/data/lmdb_dataset.py
index f3efb60428..778adaa945 100644
--- a/ppocr/data/lmdb_dataset.py
+++ b/ppocr/data/lmdb_dataset.py
@@ -28,21 +28,20 @@ class LMDBDataSet(Dataset):
def __init__(self, config, mode, logger, seed=None):
super(LMDBDataSet, self).__init__()
- global_config = config['Global']
- dataset_config = config[mode]['dataset']
- loader_config = config[mode]['loader']
- batch_size = loader_config['batch_size_per_card']
- data_dir = dataset_config['data_dir']
- self.do_shuffle = loader_config['shuffle']
+ global_config = config["Global"]
+ dataset_config = config[mode]["dataset"]
+ loader_config = config[mode]["loader"]
+ batch_size = loader_config["batch_size_per_card"]
+ data_dir = dataset_config["data_dir"]
+ self.do_shuffle = loader_config["shuffle"]
self.lmdb_sets = self.load_hierarchical_lmdb_dataset(data_dir)
logger.info("Initialize indexs of datasets:%s" % data_dir)
self.data_idx_order_list = self.dataset_traversal()
if self.do_shuffle:
np.random.shuffle(self.data_idx_order_list)
- self.ops = create_operators(dataset_config['transforms'], global_config)
- self.ext_op_transform_idx = dataset_config.get("ext_op_transform_idx",
- 1)
+ self.ops = create_operators(dataset_config["transforms"], global_config)
+ self.ext_op_transform_idx = dataset_config.get("ext_op_transform_idx", 1)
ratio_list = dataset_config.get("ratio_list", [1.0])
self.need_reset = True in [x < 1 for x in ratio_list]
@@ -50,7 +49,7 @@ def __init__(self, config, mode, logger, seed=None):
def load_hierarchical_lmdb_dataset(self, data_dir):
lmdb_sets = {}
dataset_idx = 0
- for dirpath, dirnames, filenames in os.walk(data_dir + '/'):
+ for dirpath, dirnames, filenames in os.walk(data_dir + "/"):
if not dirnames:
env = lmdb.open(
dirpath,
@@ -58,11 +57,16 @@ def load_hierarchical_lmdb_dataset(self, data_dir):
readonly=True,
lock=False,
readahead=False,
- meminit=False)
+ meminit=False,
+ )
txn = env.begin(write=False)
- num_samples = int(txn.get('num-samples'.encode()))
- lmdb_sets[dataset_idx] = {"dirpath":dirpath, "env":env, \
- "txn":txn, "num_samples":num_samples}
+ num_samples = int(txn.get("num-samples".encode()))
+ lmdb_sets[dataset_idx] = {
+ "dirpath": dirpath,
+ "env": env,
+ "txn": txn,
+ "num_samples": num_samples,
+ }
dataset_idx += 1
return lmdb_sets
@@ -70,15 +74,14 @@ def dataset_traversal(self):
lmdb_num = len(self.lmdb_sets)
total_sample_num = 0
for lno in range(lmdb_num):
- total_sample_num += self.lmdb_sets[lno]['num_samples']
+ total_sample_num += self.lmdb_sets[lno]["num_samples"]
data_idx_order_list = np.zeros((total_sample_num, 2))
beg_idx = 0
for lno in range(lmdb_num):
- tmp_sample_num = self.lmdb_sets[lno]['num_samples']
+ tmp_sample_num = self.lmdb_sets[lno]["num_samples"]
end_idx = beg_idx + tmp_sample_num
data_idx_order_list[beg_idx:end_idx, 0] = lno
- data_idx_order_list[beg_idx:end_idx, 1] \
- = list(range(tmp_sample_num))
+ data_idx_order_list[beg_idx:end_idx, 1] = list(range(tmp_sample_num))
data_idx_order_list[beg_idx:end_idx, 1] += 1
beg_idx = beg_idx + tmp_sample_num
return data_idx_order_list
@@ -87,7 +90,7 @@ def get_img_data(self, value):
"""get_img_data"""
if not value:
return None
- imgdata = np.frombuffer(value, dtype='uint8')
+ imgdata = np.frombuffer(value, dtype="uint8")
if imgdata is None:
return None
imgori = cv2.imdecode(imgdata, 1)
@@ -98,23 +101,23 @@ def get_img_data(self, value):
def get_ext_data(self):
ext_data_num = 0
for op in self.ops:
- if hasattr(op, 'ext_data_num'):
- ext_data_num = getattr(op, 'ext_data_num')
+ if hasattr(op, "ext_data_num"):
+ ext_data_num = getattr(op, "ext_data_num")
break
- load_data_ops = self.ops[:self.ext_op_transform_idx]
+ load_data_ops = self.ops[: self.ext_op_transform_idx]
ext_data = []
while len(ext_data) < ext_data_num:
- lmdb_idx, file_idx = self.data_idx_order_list[np.random.randint(
- len(self))]
+ lmdb_idx, file_idx = self.data_idx_order_list[np.random.randint(len(self))]
lmdb_idx = int(lmdb_idx)
file_idx = int(file_idx)
sample_info = self.get_lmdb_sample_info(
- self.lmdb_sets[lmdb_idx]['txn'], file_idx)
+ self.lmdb_sets[lmdb_idx]["txn"], file_idx
+ )
if sample_info is None:
continue
img, label = sample_info
- data = {'image': img, 'label': label}
+ data = {"image": img, "label": label}
data = transform(data, load_data_ops)
if data is None:
continue
@@ -122,12 +125,12 @@ def get_ext_data(self):
return ext_data
def get_lmdb_sample_info(self, txn, index):
- label_key = 'label-%09d'.encode() % index
+ label_key = "label-%09d".encode() % index
label = txn.get(label_key)
if label is None:
return None
- label = label.decode('utf-8')
- img_key = 'image-%09d'.encode() % index
+ label = label.decode("utf-8")
+ img_key = "image-%09d".encode() % index
imgbuf = txn.get(img_key)
return imgbuf, label
@@ -135,13 +138,14 @@ def __getitem__(self, idx):
lmdb_idx, file_idx = self.data_idx_order_list[idx]
lmdb_idx = int(lmdb_idx)
file_idx = int(file_idx)
- sample_info = self.get_lmdb_sample_info(self.lmdb_sets[lmdb_idx]['txn'],
- file_idx)
+ sample_info = self.get_lmdb_sample_info(
+ self.lmdb_sets[lmdb_idx]["txn"], file_idx
+ )
if sample_info is None:
return self.__getitem__(np.random.randint(self.__len__()))
img, label = sample_info
- data = {'image': img, 'label': label}
- data['ext_data'] = self.get_ext_data()
+ data = {"image": img, "label": label}
+ data["ext_data"] = self.get_ext_data()
outs = transform(data, self.ops)
if outs is None:
return self.__getitem__(np.random.randint(self.__len__()))
@@ -152,7 +156,7 @@ def __len__(self):
class LMDBDataSetSR(LMDBDataSet):
- def buf2PIL(self, txn, key, type='RGB'):
+ def buf2PIL(self, txn, key, type="RGB"):
imgbuf = txn.get(key)
buf = six.BytesIO()
buf.write(imgbuf)
@@ -162,29 +166,29 @@ def buf2PIL(self, txn, key, type='RGB'):
def str_filt(self, str_, voc_type):
alpha_dict = {
- 'digit': string.digits,
- 'lower': string.digits + string.ascii_lowercase,
- 'upper': string.digits + string.ascii_letters,
- 'all': string.digits + string.ascii_letters + string.punctuation
+ "digit": string.digits,
+ "lower": string.digits + string.ascii_lowercase,
+ "upper": string.digits + string.ascii_letters,
+ "all": string.digits + string.ascii_letters + string.punctuation,
}
- if voc_type == 'lower':
+ if voc_type == "lower":
str_ = str_.lower()
for char in str_:
if char not in alpha_dict[voc_type]:
- str_ = str_.replace(char, '')
+ str_ = str_.replace(char, "")
return str_
def get_lmdb_sample_info(self, txn, index):
- self.voc_type = 'upper'
+ self.voc_type = "upper"
self.max_len = 100
self.test = False
- label_key = b'label-%09d' % index
+ label_key = b"label-%09d" % index
word = str(txn.get(label_key).decode())
- img_HR_key = b'image_hr-%09d' % index # 128*32
- img_lr_key = b'image_lr-%09d' % index # 64*16
+ img_HR_key = b"image_hr-%09d" % index # 128*32
+ img_lr_key = b"image_lr-%09d" % index # 64*16
try:
- img_HR = self.buf2PIL(txn, img_HR_key, 'RGB')
- img_lr = self.buf2PIL(txn, img_lr_key, 'RGB')
+ img_HR = self.buf2PIL(txn, img_HR_key, "RGB")
+ img_lr = self.buf2PIL(txn, img_lr_key, "RGB")
except IOError or len(word) > self.max_len:
return self[index + 1]
label_str = self.str_filt(word, self.voc_type)
@@ -194,12 +198,13 @@ def __getitem__(self, idx):
lmdb_idx, file_idx = self.data_idx_order_list[idx]
lmdb_idx = int(lmdb_idx)
file_idx = int(file_idx)
- sample_info = self.get_lmdb_sample_info(self.lmdb_sets[lmdb_idx]['txn'],
- file_idx)
+ sample_info = self.get_lmdb_sample_info(
+ self.lmdb_sets[lmdb_idx]["txn"], file_idx
+ )
if sample_info is None:
return self.__getitem__(np.random.randint(self.__len__()))
img_HR, img_lr, label_str = sample_info
- data = {'image_hr': img_HR, 'image_lr': img_lr, 'label': label_str}
+ data = {"image_hr": img_HR, "image_lr": img_lr, "label": label_str}
outs = transform(data, self.ops)
if outs is None:
return self.__getitem__(np.random.randint(self.__len__()))
@@ -216,18 +221,23 @@ def load_hierarchical_lmdb_dataset(self, data_dir):
readonly=True,
lock=False,
readahead=False,
- meminit=False)
+ meminit=False,
+ )
txn = env.begin(write=False)
num_samples = int(pickle.loads(txn.get(b"__len__")))
- lmdb_sets[dataset_idx] = {"dirpath":data_dir, "env":env, \
- "txn":txn, "num_samples":num_samples}
+ lmdb_sets[dataset_idx] = {
+ "dirpath": data_dir,
+ "env": env,
+ "txn": txn,
+ "num_samples": num_samples,
+ }
return lmdb_sets
def get_img_data(self, value):
"""get_img_data"""
if not value:
return None
- imgdata = np.frombuffer(value, dtype='uint8')
+ imgdata = np.frombuffer(value, dtype="uint8")
if imgdata is None:
return None
imgori = cv2.imdecode(imgdata, 1)
@@ -243,7 +253,7 @@ def convert_bbox(bbox_str_list):
return bbox_list
try:
- data = pickle.loads(txn.get(str(index).encode('utf8')))
+ data = pickle.loads(txn.get(str(index).encode("utf8")))
except:
return None
@@ -252,33 +262,34 @@ def convert_bbox(bbox_str_list):
bytes = data[1]
info_lines = data[2] # raw data from TableMASTER annotation file.
# parse info_lines
- raw_data = info_lines.strip().split('\n')
- raw_name, text = raw_data[0], raw_data[
- 1] # don't filter the samples's length over max_seq_len.
- text = text.split(',')
+ raw_data = info_lines.strip().split("\n")
+ raw_name, text = (
+ raw_data[0],
+ raw_data[1],
+ ) # don't filter samples whose length exceeds max_seq_len.
+ text = text.split(",")
bbox_str_list = raw_data[2:]
- bbox_split = ','
- bboxes = [{
- 'bbox': convert_bbox(bsl.strip().split(bbox_split)),
- 'tokens': ['1', '2']
- } for bsl in bbox_str_list]
+ bbox_split = ","
+ bboxes = [
+ {"bbox": convert_bbox(bsl.strip().split(bbox_split)), "tokens": ["1", "2"]}
+ for bsl in bbox_str_list
+ ]
# advance parse bbox
# import pdb;pdb.set_trace()
line_info = {}
- line_info['file_name'] = file_name
- line_info['structure'] = text
- line_info['cells'] = bboxes
- line_info['image'] = bytes
+ line_info["file_name"] = file_name
+ line_info["structure"] = text
+ line_info["cells"] = bboxes
+ line_info["image"] = bytes
return line_info
def __getitem__(self, idx):
lmdb_idx, file_idx = self.data_idx_order_list[idx]
lmdb_idx = int(lmdb_idx)
file_idx = int(file_idx)
- data = self.get_lmdb_sample_info(self.lmdb_sets[lmdb_idx]['txn'],
- file_idx)
+ data = self.get_lmdb_sample_info(self.lmdb_sets[lmdb_idx]["txn"], file_idx)
if data is None:
return self.__getitem__(np.random.randint(self.__len__()))
outs = transform(data, self.ops)
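
For reference, a hedged sketch of the record layout LMDBDataSet expects: one LMDB environment per leaf directory with num-samples, label-%09d, and image-%09d keys, indices starting at 1. The path below is illustrative:

import lmdb

env = lmdb.open("train_data/rec_lmdb", readonly=True, lock=False,
                readahead=False, meminit=False)
with env.begin(write=False) as txn:
    num_samples = int(txn.get("num-samples".encode()))
    label = txn.get(b"label-%09d" % 1).decode("utf-8")
    imgbuf = txn.get(b"image-%09d" % 1)  # raw encoded image bytes
    print(num_samples, label, len(imgbuf))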
diff --git a/ppocr/data/multi_scale_sampler.py b/ppocr/data/multi_scale_sampler.py
index 45793e2ba1..4ab38fc4e6 100644
--- a/ppocr/data/multi_scale_sampler.py
+++ b/ppocr/data/multi_scale_sampler.py
@@ -7,24 +7,26 @@
class MultiScaleSampler(Sampler):
- def __init__(self,
- data_source,
- scales,
- first_bs=128,
- fix_bs=True,
- divided_factor=[8, 16],
- is_training=True,
- ratio_wh=0.8,
- max_w=480.,
- seed=None):
+ def __init__(
+ self,
+ data_source,
+ scales,
+ first_bs=128,
+ fix_bs=True,
+ divided_factor=[8, 16],
+ is_training=True,
+ ratio_wh=0.8,
+ max_w=480.0,
+ seed=None,
+ ):
"""
- multi scale samper
- Args:
- data_source(dataset)
- scales(list): several scales for image resolution
- first_bs(int): batch size for the first scale in scales
- divided_factor(list[w, h]): ImageNet models down-sample images by a factor, ensure that width and height dimensions are multiples are multiple of devided_factor.
- is_training(boolean): mode
+ multi scale sampler
+ Args:
+ data_source(dataset)
+ scales(list): several scales for image resolution
+ first_bs(int): batch size for the first scale in scales
+ divided_factor(list[w, h]): ImageNet models down-sample images by a factor; ensure that width and height dimensions are multiples of divided_factor.
+ is_training(boolean): mode
"""
# min. and max. spatial dimensions
self.data_source = data_source
@@ -62,17 +64,15 @@ def __init__(self,
# ImageNet models down-sample images by a factor of 32.
# Ensure that width and height dimensions are multiples are multiple of 32.
width_dims = [
- int((w // divided_factor[0]) * divided_factor[0])
- for w in width_dims
+ int((w // divided_factor[0]) * divided_factor[0]) for w in width_dims
]
height_dims = [
- int((h // divided_factor[1]) * divided_factor[1])
- for h in height_dims
+ int((h // divided_factor[1]) * divided_factor[1]) for h in height_dims
]
img_batch_pairs = list()
base_elements = base_im_w * base_im_h * base_batch_size
- for (h, w) in zip(height_dims, width_dims):
+ for h, w in zip(height_dims, width_dims):
if fix_bs:
batch_size = base_batch_size
else:
@@ -92,16 +92,14 @@ def __init__(self,
self.batch_list = []
self.current = 0
last_index = num_samples_per_replica * num_replicas
- indices_rank_i = self.img_indices[self.rank:last_index:
- self.num_replicas]
+ indices_rank_i = self.img_indices[self.rank : last_index : self.num_replicas]
while self.current < self.n_samples_per_replica:
for curr_w, curr_h, curr_bsz in self.img_batch_pairs:
- end_index = min(self.current + curr_bsz,
- self.n_samples_per_replica)
- batch_ids = indices_rank_i[self.current:end_index]
+ end_index = min(self.current + curr_bsz, self.n_samples_per_replica)
+ batch_ids = indices_rank_i[self.current : end_index]
n_batch_samples = len(batch_ids)
if n_batch_samples != curr_bsz:
- batch_ids += indices_rank_i[:(curr_bsz - n_batch_samples)]
+ batch_ids += indices_rank_i[: (curr_bsz - n_batch_samples)]
self.current += curr_bsz
if len(batch_ids) > 0:
@@ -110,9 +108,7 @@ def __init__(self,
random.shuffle(self.batch_list)
self.length = len(self.batch_list)
self.batchs_in_one_epoch = self.iter()
- self.batchs_in_one_epoch_id = [
- i for i in range(len(self.batchs_in_one_epoch))
- ]
+ self.batchs_in_one_epoch_id = [i for i in range(len(self.batchs_in_one_epoch))]
def __iter__(self):
if self.seed is None:
@@ -133,11 +129,13 @@ def iter(self):
if not self.ds_width:
random.shuffle(self.img_indices)
random.shuffle(self.img_batch_pairs)
- indices_rank_i = self.img_indices[self.rank:len(self.img_indices):
- self.num_replicas]
+ indices_rank_i = self.img_indices[
+ self.rank : len(self.img_indices) : self.num_replicas
+ ]
else:
- indices_rank_i = self.img_indices[self.rank:len(self.img_indices):
- self.num_replicas]
+ indices_rank_i = self.img_indices[
+ self.rank : len(self.img_indices) : self.num_replicas
+ ]
start_index = 0
batchs_in_one_epoch = []
@@ -147,19 +145,21 @@ def iter(self):
batch_ids = indices_rank_i[start_index:end_index]
n_batch_samples = len(batch_ids)
if n_batch_samples != curr_bsz:
- batch_ids += indices_rank_i[:(curr_bsz - n_batch_samples)]
+ batch_ids += indices_rank_i[: (curr_bsz - n_batch_samples)]
start_index += curr_bsz
if len(batch_ids) > 0:
if self.ds_width:
- wh_ratio_current = self.wh_ratio[self.wh_ratio_sort[
- batch_ids]]
+ wh_ratio_current = self.wh_ratio[self.wh_ratio_sort[batch_ids]]
ratio_current = wh_ratio_current.mean()
- ratio_current = ratio_current if ratio_current * curr_h < self.max_w else self.max_w / curr_h
+ ratio_current = (
+ ratio_current
+ if ratio_current * curr_h < self.max_w
+ else self.max_w / curr_h
+ )
else:
ratio_current = None
- batch = [(curr_w, curr_h, b_id, ratio_current)
- for b_id in batch_ids]
+ batch = [(curr_w, curr_h, b_id, ratio_current) for b_id in batch_ids]
# yield batch
batchs_in_one_epoch.append(batch)
return batchs_in_one_epoch
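
The width/height candidates above are snapped down to multiples of divided_factor so every sampled scale stays stride-friendly for the backbone. The rounding in isolation:

def snap_down(dim, factor):
    # largest multiple of factor that is <= dim
    return int((dim // factor) * factor)

print([snap_down(w, 8) for w in (100, 127, 128, 250)])  # [96, 120, 128, 248]
print([snap_down(h, 16) for h in (30, 32, 47)])         # [16, 32, 32]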
diff --git a/ppocr/data/pgnet_dataset.py b/ppocr/data/pgnet_dataset.py
index 6f80179c4e..839ac72254 100644
--- a/ppocr/data/pgnet_dataset.py
+++ b/ppocr/data/pgnet_dataset.py
@@ -25,21 +25,21 @@ def __init__(self, config, mode, logger, seed=None):
self.logger = logger
self.seed = seed
self.mode = mode
- global_config = config['Global']
- dataset_config = config[mode]['dataset']
- loader_config = config[mode]['loader']
+ global_config = config["Global"]
+ dataset_config = config[mode]["dataset"]
+ loader_config = config[mode]["loader"]
- self.delimiter = dataset_config.get('delimiter', '\t')
- label_file_list = dataset_config.pop('label_file_list')
+ self.delimiter = dataset_config.get("delimiter", "\t")
+ label_file_list = dataset_config.pop("label_file_list")
data_source_num = len(label_file_list)
ratio_list = dataset_config.get("ratio_list", [1.0])
if isinstance(ratio_list, (float, int)):
ratio_list = [float(ratio_list)] * int(data_source_num)
- assert len(
- ratio_list
- ) == data_source_num, "The length of ratio_list should be the same as the file_list."
- self.data_dir = dataset_config['data_dir']
- self.do_shuffle = loader_config['shuffle']
+ assert (
+ len(ratio_list) == data_source_num
+ ), "The length of ratio_list should be the same as the file_list."
+ self.data_dir = dataset_config["data_dir"]
+ self.do_shuffle = loader_config["shuffle"]
logger.info("Initialize indexs of datasets:%s" % label_file_list)
self.data_lines = self.get_image_info_list(label_file_list, ratio_list)
@@ -47,7 +47,7 @@ def __init__(self, config, mode, logger, seed=None):
if mode.lower() == "train":
self.shuffle_data_random()
- self.ops = create_operators(dataset_config['transforms'], global_config)
+ self.ops = create_operators(dataset_config["transforms"], global_config)
self.need_reset = True in [x < 1 for x in ratio_list]
@@ -66,8 +66,7 @@ def get_image_info_list(self, file_list, ratio_list):
lines = f.readlines()
if self.mode == "train" or ratio_list[idx] < 1.0:
random.seed(self.seed)
- lines = random.sample(lines,
- round(len(lines) * ratio_list[idx]))
+ lines = random.sample(lines, round(len(lines) * ratio_list[idx]))
data_lines.extend(lines)
return data_lines
@@ -76,27 +75,29 @@ def __getitem__(self, idx):
data_line = self.data_lines[file_idx]
img_id = 0
try:
- data_line = data_line.decode('utf-8')
+ data_line = data_line.decode("utf-8")
substr = data_line.strip("\n").split(self.delimiter)
file_name = substr[0]
label = substr[1]
img_path = os.path.join(self.data_dir, file_name)
- if self.mode.lower() == 'eval':
+ if self.mode.lower() == "eval":
try:
img_id = int(data_line.split(".")[0][7:])
except:
img_id = 0
- data = {'img_path': img_path, 'label': label, 'img_id': img_id}
+ data = {"img_path": img_path, "label": label, "img_id": img_id}
if not os.path.exists(img_path):
raise Exception("{} does not exist!".format(img_path))
- with open(data['img_path'], 'rb') as f:
+ with open(data["img_path"], "rb") as f:
img = f.read()
- data['image'] = img
+ data["image"] = img
outs = transform(data, self.ops)
except Exception as e:
self.logger.error(
"When parsing line {}, error happened with msg: {}".format(
- self.data_idx_order_list[idx], e))
+ self.data_idx_order_list[idx], e
+ )
+ )
outs = None
if outs is None:
return self.__getitem__(np.random.randint(self.__len__()))
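
The ratio_list handling shared by these dataset classes subsamples each label file reproducibly: when training or when the ratio is below 1.0, keep round(len * ratio) randomly chosen lines under a fixed seed. A minimal sketch with an assumed seed value:

import random

def sample_lines(lines, ratio, seed=1024):
    # keep a reproducible random subset of round(len * ratio) lines
    random.seed(seed)
    return random.sample(lines, round(len(lines) * ratio))

print(len(sample_lines(list(range(100)), 0.25)))  # 25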
diff --git a/ppocr/data/pubtab_dataset.py b/ppocr/data/pubtab_dataset.py
index 642d3eb196..2365b1318f 100644
--- a/ppocr/data/pubtab_dataset.py
+++ b/ppocr/data/pubtab_dataset.py
@@ -26,22 +26,22 @@ def __init__(self, config, mode, logger, seed=None):
super(PubTabDataSet, self).__init__()
self.logger = logger
- global_config = config['Global']
- dataset_config = config[mode]['dataset']
- loader_config = config[mode]['loader']
+ global_config = config["Global"]
+ dataset_config = config[mode]["dataset"]
+ loader_config = config[mode]["loader"]
- label_file_list = dataset_config.pop('label_file_list')
+ label_file_list = dataset_config.pop("label_file_list")
data_source_num = len(label_file_list)
ratio_list = dataset_config.get("ratio_list", [1.0])
if isinstance(ratio_list, (float, int)):
ratio_list = [float(ratio_list)] * int(data_source_num)
- assert len(
- ratio_list
- ) == data_source_num, "The length of ratio_list should be the same as the file_list."
+ assert (
+ len(ratio_list) == data_source_num
+ ), "The length of ratio_list should be the same as the file_list."
- self.data_dir = dataset_config['data_dir']
- self.do_shuffle = loader_config['shuffle']
+ self.data_dir = dataset_config["data_dir"]
+ self.do_shuffle = loader_config["shuffle"]
self.seed = seed
self.mode = mode.lower()
@@ -51,7 +51,7 @@ def __init__(self, config, mode, logger, seed=None):
if mode.lower() == "train" and self.do_shuffle:
self.shuffle_data_random()
- self.ops = create_operators(dataset_config['transforms'], global_config)
+ self.ops = create_operators(dataset_config["transforms"], global_config)
self.need_reset = True in [x < 1 for x in ratio_list]
def get_image_info_list(self, file_list, ratio_list):
@@ -63,19 +63,18 @@ def get_image_info_list(self, file_list, ratio_list):
lines = f.readlines()
if self.mode == "train" or ratio_list[idx] < 1.0:
random.seed(self.seed)
- lines = random.sample(lines,
- round(len(lines) * ratio_list[idx]))
+ lines = random.sample(lines, round(len(lines) * ratio_list[idx]))
data_lines.extend(lines)
return data_lines
def check(self, max_text_length):
data_lines = []
for line in self.data_lines:
- data_line = line.decode('utf-8').strip("\n")
+ data_line = line.decode("utf-8").strip("\n")
info = json.loads(data_line)
- file_name = info['filename']
- cells = info['html']['cells'].copy()
- structure = info['html']['structure']['tokens'].copy()
+ file_name = info["filename"]
+ cells = info["html"]["cells"].copy()
+ structure = info["html"]["structure"]["tokens"].copy()
img_path = os.path.join(self.data_dir, file_name)
if not os.path.exists(img_path):
@@ -96,36 +95,42 @@ def shuffle_data_random(self):
def __getitem__(self, idx):
try:
data_line = self.data_lines[idx]
- data_line = data_line.decode('utf-8').strip("\n")
+ data_line = data_line.decode("utf-8").strip("\n")
info = json.loads(data_line)
- file_name = info['filename']
- cells = info['html']['cells'].copy()
- structure = info['html']['structure']['tokens'].copy()
+ file_name = info["filename"]
+ cells = info["html"]["cells"].copy()
+ structure = info["html"]["structure"]["tokens"].copy()
img_path = os.path.join(self.data_dir, file_name)
if not os.path.exists(img_path):
raise Exception("{} does not exist!".format(img_path))
data = {
- 'img_path': img_path,
- 'cells': cells,
- 'structure': structure,
- 'file_name': file_name
+ "img_path": img_path,
+ "cells": cells,
+ "structure": structure,
+ "file_name": file_name,
}
- with open(data['img_path'], 'rb') as f:
+ with open(data["img_path"], "rb") as f:
img = f.read()
- data['image'] = img
+ data["image"] = img
outs = transform(data, self.ops)
except:
import traceback
+
err = traceback.format_exc()
self.logger.error(
"When parsing line {}, error happened with msg: {}".format(
- data_line, err))
+ data_line, err
+ )
+ )
outs = None
if outs is None:
- rnd_idx = np.random.randint(self.__len__(
- )) if self.mode == "train" else (idx + 1) % self.__len__()
+ rnd_idx = (
+ np.random.randint(self.__len__())
+ if self.mode == "train"
+ else (idx + 1) % self.__len__()
+ )
return self.__getitem__(rnd_idx)
return outs
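The `ratio_list` logic repeated in each dataset lets several label files be mixed with a per-source sampling fraction; reseeding `random` with the shared `seed` before `random.sample` makes every worker draw the same subset, and `need_reset` is flagged whenever any ratio is below 1 so the subset can be redrawn between epochs. A small standalone sketch of the sampling step (assumed helper, not the repo's API):

    import random

    def subsample(lines, ratio, seed=1024):
        """Keep round(len(lines) * ratio) lines, reproducibly."""
        if ratio >= 1.0:
            return list(lines)
        random.seed(seed)  # identical seed on each worker -> identical subset
        return random.sample(lines, round(len(lines) * ratio))

    src_a = ["a_%d.jpg\tlabel" % i for i in range(100)]
    src_b = ["b_%d.jpg\tlabel" % i for i in range(100)]
    data_lines = subsample(src_a, 1.0) + subsample(src_b, 0.3)
    print(len(data_lines))  # 130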
diff --git a/ppocr/data/simple_dataset.py b/ppocr/data/simple_dataset.py
index f7c4c8f1a2..0a56fab2eb 100644
--- a/ppocr/data/simple_dataset.py
+++ b/ppocr/data/simple_dataset.py
@@ -28,22 +28,22 @@ def __init__(self, config, mode, logger, seed=None):
self.logger = logger
self.mode = mode.lower()
- global_config = config['Global']
- dataset_config = config[mode]['dataset']
- loader_config = config[mode]['loader']
+ global_config = config["Global"]
+ dataset_config = config[mode]["dataset"]
+ loader_config = config[mode]["loader"]
- self.delimiter = dataset_config.get('delimiter', '\t')
- label_file_list = dataset_config.pop('label_file_list')
+ self.delimiter = dataset_config.get("delimiter", "\t")
+ label_file_list = dataset_config.pop("label_file_list")
data_source_num = len(label_file_list)
ratio_list = dataset_config.get("ratio_list", 1.0)
if isinstance(ratio_list, (float, int)):
ratio_list = [float(ratio_list)] * int(data_source_num)
- assert len(
- ratio_list
- ) == data_source_num, "The length of ratio_list should be the same as the file_list."
- self.data_dir = dataset_config['data_dir']
- self.do_shuffle = loader_config['shuffle']
+ assert (
+ len(ratio_list) == data_source_num
+ ), "The length of ratio_list should be the same as the file_list."
+ self.data_dir = dataset_config["data_dir"]
+ self.do_shuffle = loader_config["shuffle"]
self.seed = seed
logger.info("Initialize indexes of datasets:%s" % label_file_list)
self.data_lines = self.get_image_info_list(label_file_list, ratio_list)
@@ -53,24 +53,29 @@ def __init__(self, config, mode, logger, seed=None):
self.set_epoch_as_seed(self.seed, dataset_config)
- self.ops = create_operators(dataset_config['transforms'], global_config)
- self.ext_op_transform_idx = dataset_config.get("ext_op_transform_idx",
- 2)
+ self.ops = create_operators(dataset_config["transforms"], global_config)
+ self.ext_op_transform_idx = dataset_config.get("ext_op_transform_idx", 2)
self.need_reset = True in [x < 1 for x in ratio_list]
def set_epoch_as_seed(self, seed, dataset_config):
- if self.mode == 'train':
+ if self.mode == "train":
try:
- border_map_id = [index
- for index, dictionary in enumerate(dataset_config['transforms'])
- if 'MakeBorderMap' in dictionary][0]
- shrink_map_id = [index
- for index, dictionary in enumerate(dataset_config['transforms'])
- if 'MakeShrinkMap' in dictionary][0]
- dataset_config['transforms'][border_map_id]['MakeBorderMap'][
- 'epoch'] = seed if seed is not None else 0
- dataset_config['transforms'][shrink_map_id]['MakeShrinkMap'][
- 'epoch'] = seed if seed is not None else 0
+ border_map_id = [
+ index
+ for index, dictionary in enumerate(dataset_config["transforms"])
+ if "MakeBorderMap" in dictionary
+ ][0]
+ shrink_map_id = [
+ index
+ for index, dictionary in enumerate(dataset_config["transforms"])
+ if "MakeShrinkMap" in dictionary
+ ][0]
+ dataset_config["transforms"][border_map_id]["MakeBorderMap"][
+ "epoch"
+ ] = (seed if seed is not None else 0)
+ dataset_config["transforms"][shrink_map_id]["MakeShrinkMap"][
+ "epoch"
+ ] = (seed if seed is not None else 0)
except Exception as E:
print(E)
return
@@ -84,8 +89,7 @@ def get_image_info_list(self, file_list, ratio_list):
lines = f.readlines()
if self.mode == "train" or ratio_list[idx] < 1.0:
random.seed(self.seed)
- lines = random.sample(lines,
- round(len(lines) * ratio_list[idx]))
+ lines = random.sample(lines, round(len(lines) * ratio_list[idx]))
data_lines.extend(lines)
return data_lines
@@ -107,34 +111,33 @@ def _try_parse_filename_list(self, file_name):
def get_ext_data(self):
ext_data_num = 0
for op in self.ops:
- if hasattr(op, 'ext_data_num'):
- ext_data_num = getattr(op, 'ext_data_num')
+ if hasattr(op, "ext_data_num"):
+ ext_data_num = getattr(op, "ext_data_num")
break
- load_data_ops = self.ops[:self.ext_op_transform_idx]
+ load_data_ops = self.ops[: self.ext_op_transform_idx]
ext_data = []
while len(ext_data) < ext_data_num:
- file_idx = self.data_idx_order_list[np.random.randint(self.__len__(
- ))]
+ file_idx = self.data_idx_order_list[np.random.randint(self.__len__())]
data_line = self.data_lines[file_idx]
- data_line = data_line.decode('utf-8')
+ data_line = data_line.decode("utf-8")
substr = data_line.strip("\n").split(self.delimiter)
file_name = substr[0]
file_name = self._try_parse_filename_list(file_name)
label = substr[1]
img_path = os.path.join(self.data_dir, file_name)
- data = {'img_path': img_path, 'label': label}
+ data = {"img_path": img_path, "label": label}
if not os.path.exists(img_path):
continue
- with open(data['img_path'], 'rb') as f:
+ with open(data["img_path"], "rb") as f:
img = f.read()
- data['image'] = img
+ data["image"] = img
data = transform(data, load_data_ops)
if data is None:
continue
- if 'polys' in data.keys():
- if data['polys'].shape[1] != 4:
+ if "polys" in data.keys():
+ if data["polys"].shape[1] != 4:
continue
ext_data.append(data)
return ext_data
@@ -143,29 +146,34 @@ def __getitem__(self, idx):
file_idx = self.data_idx_order_list[idx]
data_line = self.data_lines[file_idx]
try:
- data_line = data_line.decode('utf-8')
+ data_line = data_line.decode("utf-8")
substr = data_line.strip("\n").split(self.delimiter)
file_name = substr[0]
file_name = self._try_parse_filename_list(file_name)
label = substr[1]
img_path = os.path.join(self.data_dir, file_name)
- data = {'img_path': img_path, 'label': label}
+ data = {"img_path": img_path, "label": label}
if not os.path.exists(img_path):
raise Exception("{} does not exist!".format(img_path))
- with open(data['img_path'], 'rb') as f:
+ with open(data["img_path"], "rb") as f:
img = f.read()
- data['image'] = img
- data['ext_data'] = self.get_ext_data()
+ data["image"] = img
+ data["ext_data"] = self.get_ext_data()
outs = transform(data, self.ops)
except:
self.logger.error(
"When parsing line {}, error happened with msg: {}".format(
- data_line, traceback.format_exc()))
+ data_line, traceback.format_exc()
+ )
+ )
outs = None
if outs is None:
# during evaluation, we should fix the idx to get same results for many times of evaluation.
- rnd_idx = np.random.randint(self.__len__(
- )) if self.mode == "train" else (idx + 1) % self.__len__()
+ rnd_idx = (
+ np.random.randint(self.__len__())
+ if self.mode == "train"
+ else (idx + 1) % self.__len__()
+ )
return self.__getitem__(rnd_idx)
return outs
@@ -176,7 +184,7 @@ def __len__(self):
class MultiScaleDataSet(SimpleDataSet):
def __init__(self, config, mode, logger, seed=None):
super(MultiScaleDataSet, self).__init__(config, mode, logger, seed)
- self.ds_width = config[mode]['dataset'].get('ds_width', False)
+ self.ds_width = config[mode]["dataset"].get("ds_width", False)
if self.ds_width:
self.wh_aware()
@@ -185,7 +193,7 @@ def wh_aware(self):
wh_ratio = []
for lins in self.data_lines:
data_line_new.append(lins)
- lins = lins.decode('utf-8')
+ lins = lins.decode("utf-8")
name, label, w, h = lins.strip("\n").split(self.delimiter)
wh_ratio.append(float(w) / float(h))
@@ -195,12 +203,13 @@ def wh_aware(self):
self.data_idx_order_list = list(range(len(self.data_lines)))
def resize_norm_img(self, data, imgW, imgH, padding=True):
- img = data['image']
+ img = data["image"]
h = img.shape[0]
w = img.shape[1]
if not padding:
resized_image = cv2.resize(
- img, (imgW, imgH), interpolation=cv2.INTER_LINEAR)
+ img, (imgW, imgH), interpolation=cv2.INTER_LINEAR
+ )
resized_w = imgW
else:
ratio = w / float(h)
@@ -209,7 +218,7 @@ def resize_norm_img(self, data, imgW, imgH, padding=True):
else:
resized_w = int(math.ceil(imgH * ratio))
resized_image = cv2.resize(img, (resized_w, imgH))
- resized_image = resized_image.astype('float32')
+ resized_image = resized_image.astype("float32")
resized_image = resized_image.transpose((2, 0, 1)) / 255
resized_image -= 0.5
@@ -217,8 +226,8 @@ def resize_norm_img(self, data, imgW, imgH, padding=True):
padding_im = np.zeros((3, imgH, imgW), dtype=np.float32)
padding_im[:, :, :resized_w] = resized_image
valid_ratio = min(1.0, float(resized_w / imgW))
- data['image'] = padding_im
- data['valid_ratio'] = valid_ratio
+ data["image"] = padding_im
+ data["valid_ratio"] = valid_ratio
return data
def __getitem__(self, properties):
@@ -227,8 +236,9 @@ def __getitem__(self, properties):
idx = properties[2]
if self.ds_width and properties[3] is not None:
wh_ratio = properties[3]
- img_width = img_height * (1 if int(round(wh_ratio)) == 0 else
- int(round(wh_ratio)))
+ img_width = img_height * (
+ 1 if int(round(wh_ratio)) == 0 else int(round(wh_ratio))
+ )
file_idx = self.wh_ratio_sort[idx]
else:
file_idx = self.data_idx_order_list[idx]
@@ -237,19 +247,19 @@ def __getitem__(self, properties):
data_line = self.data_lines[file_idx]
try:
- data_line = data_line.decode('utf-8')
+ data_line = data_line.decode("utf-8")
substr = data_line.strip("\n").split(self.delimiter)
file_name = substr[0]
file_name = self._try_parse_filename_list(file_name)
label = substr[1]
img_path = os.path.join(self.data_dir, file_name)
- data = {'img_path': img_path, 'label': label}
+ data = {"img_path": img_path, "label": label}
if not os.path.exists(img_path):
raise Exception("{} does not exist!".format(img_path))
- with open(data['img_path'], 'rb') as f:
+ with open(data["img_path"], "rb") as f:
img = f.read()
- data['image'] = img
- data['ext_data'] = self.get_ext_data()
+ data["image"] = img
+ data["ext_data"] = self.get_ext_data()
outs = transform(data, self.ops[:-1])
if outs is not None:
outs = self.resize_norm_img(outs, img_width, img_height)
@@ -257,7 +267,9 @@ def __getitem__(self, properties):
except:
self.logger.error(
"When parsing line {}, error happened with msg: {}".format(
- data_line, traceback.format_exc()))
+ data_line, traceback.format_exc()
+ )
+ )
outs = None
if outs is None:
# during evaluation, we should fix the idx to get same results for many times of evaluation.
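`MultiScaleDataSet.resize_norm_img` above resizes with the aspect ratio preserved (unless `padding=False`), right-pads to the target width, zero-centers the pixels, and records `valid_ratio`, the fraction of the padded width that carries real content. A NumPy sketch of the same steps (assumption: nearest-neighbour indexing stands in for `cv2.resize`):

    import math
    import numpy as np

    def resize_norm(img, imgW, imgH):
        """Aspect-preserving resize + right-pad; img is HxWx3 uint8."""
        h, w = img.shape[:2]
        resized_w = min(imgW, int(math.ceil(imgH * (w / float(h)))))
        # Nearest-neighbour row/column indexing as a stand-in for cv2.resize.
        ys = (np.arange(imgH) * h / imgH).astype(int)
        xs = (np.arange(resized_w) * w / resized_w).astype(int)
        resized = img[ys][:, xs].astype("float32")
        resized = resized.transpose((2, 0, 1)) / 255.0 - 0.5  # zero-centered
        padded = np.zeros((3, imgH, imgW), dtype=np.float32)
        padded[:, :, :resized_w] = resized
        return padded, min(1.0, resized_w / imgW)  # image, valid_ratio

    img = np.random.randint(0, 255, (48, 100, 3), dtype=np.uint8)
    padded, valid = resize_norm(img, imgW=320, imgH=32)
    print(padded.shape, round(valid, 3))  # (3, 32, 320) 0.209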
diff --git a/ppocr/ext_op/roi_align_rotated/roi_align_rotated.cu b/ppocr/ext_op/roi_align_rotated/roi_align_rotated.cu
index 17bd47dc08..b04766807c 100644
--- a/ppocr/ext_op/roi_align_rotated/roi_align_rotated.cu
+++ b/ppocr/ext_op/roi_align_rotated/roi_align_rotated.cu
@@ -340,12 +340,12 @@ RoIAlignRotatedCUDAForward(const paddle::Tensor &input,
PD_DISPATCH_FLOATING_TYPES(
input.type(), "roi_align_rotated_cuda_forward_kernel", ([&] {
- roi_align_rotated_cuda_forward_kernel<
- data_t><<<grid, block, 0, input.stream()>>>(
- output_size, input.data<data_t>(), rois.data<data_t>(),
- static_cast<data_t>(spatial_scale), sampling_ratio, aligned,
- clockwise, channels, height, width, aligned_height, aligned_width,
- output.data<data_t>());
+ roi_align_rotated_cuda_forward_kernel<data_t>
+ <<<grid, block, 0, input.stream()>>>(
+ output_size, input.data<data_t>(), rois.data<data_t>(),
+ static_cast<data_t>(spatial_scale), sampling_ratio, aligned,
+ clockwise, channels, height, width, aligned_height,
+ aligned_width, output.data<data_t>());
}));
return {output};
@@ -370,11 +370,12 @@ std::vector<paddle::Tensor> RoIAlignRotatedCUDABackward(
PD_DISPATCH_FLOATING_TYPES(
grad_output.type(), "roi_align_rotated_backward_cuda_kernel", ([&] {
- roi_align_rotated_backward_cuda_kernel<
- data_t><<<grid, block, 0, grad_output.stream()>>>(
- output_size, grad_output.data<data_t>(), rois.data<data_t>(),
- spatial_scale, sampling_ratio, aligned, clockwise, channels, height,
- width, aligned_height, aligned_width, grad_input.data<data_t>());
+ roi_align_rotated_backward_cuda_kernel<data_t>
+ <<<grid, block, 0, grad_output.stream()>>>(
+ output_size, grad_output.data<data_t>(), rois.data<data_t>(),
+ spatial_scale, sampling_ratio, aligned, clockwise, channels,
+ height, width, aligned_height, aligned_width,
+ grad_input.data<data_t>());
}));
return {grad_input};
}
\ No newline at end of file
diff --git a/ppocr/ext_op/roi_align_rotated/roi_align_rotated.py b/ppocr/ext_op/roi_align_rotated/roi_align_rotated.py
index dcca285c75..4cc1bf6a31 100644
--- a/ppocr/ext_op/roi_align_rotated/roi_align_rotated.py
+++ b/ppocr/ext_op/roi_align_rotated/roi_align_rotated.py
@@ -19,27 +19,24 @@
import paddle
import paddle.nn as nn
from paddle.utils.cpp_extension import load
+
custom_ops = load(
name="custom_jit_ops",
sources=[
"ppocr/ext_op/roi_align_rotated/roi_align_rotated.cc",
- "ppocr/ext_op/roi_align_rotated/roi_align_rotated.cu"
- ])
+ "ppocr/ext_op/roi_align_rotated/roi_align_rotated.cu",
+ ],
+)
roi_align_rotated = custom_ops.roi_align_rotated
class RoIAlignRotated(nn.Layer):
- """RoI align pooling layer for rotated proposals.
-
- """
+ """RoI align pooling layer for rotated proposals."""
- def __init__(self,
- out_size,
- spatial_scale,
- sample_num=0,
- aligned=True,
- clockwise=False):
+ def __init__(
+ self, out_size, spatial_scale, sample_num=0, aligned=True, clockwise=False
+ ):
super(RoIAlignRotated, self).__init__()
if isinstance(out_size, int):
@@ -51,8 +48,7 @@ def __init__(self,
assert isinstance(out_size[1], int)
self.out_h, self.out_w = out_size
else:
- raise TypeError(
- '"out_size" must be an integer or tuple of integers')
+ raise TypeError('"out_size" must be an integer or tuple of integers')
self.spatial_scale = float(spatial_scale)
self.sample_num = int(sample_num)
@@ -60,7 +56,14 @@ def __init__(self,
self.clockwise = clockwise
def forward(self, feats, rois):
- output = roi_align_rotated(feats, rois, self.out_h, self.out_w,
- self.spatial_scale, self.sample_num,
- self.aligned, self.clockwise)
+ output = roi_align_rotated(
+ feats,
+ rois,
+ self.out_h,
+ self.out_w,
+ self.spatial_scale,
+ self.sample_num,
+ self.aligned,
+ self.clockwise,
+ )
return output
diff --git a/ppocr/losses/__init__.py b/ppocr/losses/__init__.py
index 02f711a459..ed66e9837a 100644
--- a/ppocr/losses/__init__.py
+++ b/ppocr/losses/__init__.py
@@ -62,6 +62,7 @@
# table loss
from .table_att_loss import TableAttentionLoss, SLALoss
from .table_master_loss import TableMasterLoss
+
# vqa token loss
from .vqa_token_layoutlm_loss import VQASerTokenLayoutLMLoss
@@ -72,17 +73,45 @@
def build_loss(config):
support_dict = [
- 'DBLoss', 'PSELoss', 'EASTLoss', 'SASTLoss', 'FCELoss', 'CTCLoss',
- 'ClsLoss', 'AttentionLoss', 'SRNLoss', 'PGLoss', 'CombinedLoss',
- 'CELoss', 'TableAttentionLoss', 'SARLoss', 'AsterLoss', 'SDMGRLoss',
- 'VQASerTokenLayoutLMLoss', 'LossFromOutput', 'PRENLoss', 'MultiLoss',
- 'TableMasterLoss', 'SPINAttentionLoss', 'VLLoss', 'StrokeFocusLoss',
- 'SLALoss', 'CTLoss', 'RFLLoss', 'DRRGLoss', 'CANLoss', 'TelescopeLoss',
- 'SATRNLoss', 'NRTRLoss', 'ParseQLoss', 'CPPDLoss'
+ "DBLoss",
+ "PSELoss",
+ "EASTLoss",
+ "SASTLoss",
+ "FCELoss",
+ "CTCLoss",
+ "ClsLoss",
+ "AttentionLoss",
+ "SRNLoss",
+ "PGLoss",
+ "CombinedLoss",
+ "CELoss",
+ "TableAttentionLoss",
+ "SARLoss",
+ "AsterLoss",
+ "SDMGRLoss",
+ "VQASerTokenLayoutLMLoss",
+ "LossFromOutput",
+ "PRENLoss",
+ "MultiLoss",
+ "TableMasterLoss",
+ "SPINAttentionLoss",
+ "VLLoss",
+ "StrokeFocusLoss",
+ "SLALoss",
+ "CTLoss",
+ "RFLLoss",
+ "DRRGLoss",
+ "CANLoss",
+ "TelescopeLoss",
+ "SATRNLoss",
+ "NRTRLoss",
+ "ParseQLoss",
+ "CPPDLoss",
]
config = copy.deepcopy(config)
- module_name = config.pop('name')
- assert module_name in support_dict, Exception('loss only support {}'.format(
- support_dict))
+ module_name = config.pop("name")
+ assert module_name in support_dict, Exception(
+ "loss only support {}".format(support_dict)
+ )
module_class = eval(module_name)(**config)
return module_class
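`build_loss` keeps a whitelist of class names, pops `name` out of the deep-copied config, and instantiates the class with the remaining keys as keyword arguments; the repo resolves the name with `eval`, which is why every supported class must be imported into this module. A dict-registry sketch of the same contract (the stub `CTCLoss` here is hypothetical):

    import copy

    class CTCLoss:
        def __init__(self, use_focal_loss=False):
            self.use_focal_loss = use_focal_loss

    SUPPORT = {"CTCLoss": CTCLoss}  # name -> class registry

    def build_loss(config):
        config = copy.deepcopy(config)  # don't mutate the caller's YAML dict
        name = config.pop("name")
        assert name in SUPPORT, "loss only supports {}".format(list(SUPPORT))
        return SUPPORT[name](**config)  # remaining keys become kwargs

    loss = build_loss({"name": "CTCLoss", "use_focal_loss": True})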
diff --git a/ppocr/losses/ace_loss.py b/ppocr/losses/ace_loss.py
index 915b99e6ec..961528d7ab 100644
--- a/ppocr/losses/ace_loss.py
+++ b/ppocr/losses/ace_loss.py
@@ -26,18 +26,15 @@ class ACELoss(nn.Layer):
def __init__(self, **kwargs):
super().__init__()
self.loss_func = nn.CrossEntropyLoss(
- weight=None,
- ignore_index=0,
- reduction='none',
- soft_label=True,
- axis=-1)
+ weight=None, ignore_index=0, reduction="none", soft_label=True, axis=-1
+ )
def __call__(self, predicts, batch):
if isinstance(predicts, (list, tuple)):
predicts = predicts[-1]
B, N = predicts.shape[:2]
- div = paddle.to_tensor([N]).astype('float32')
+ div = paddle.to_tensor([N]).astype("float32")
predicts = nn.functional.softmax(predicts, axis=-1)
aggregation_preds = paddle.sum(predicts, axis=1)
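ACE (Aggregation Cross-Entropy) ignores character order: the softmax outputs are summed over the N timesteps and divided by N, producing predicted class frequencies that are matched against the label's character histogram with a soft-label cross-entropy. A NumPy sketch of that aggregation (toy batch; the 1e-10 smoothing is an assumption, not the repo's constant):

    import numpy as np

    def ace_loss(probs, label_counts):
        """probs: (B, N, C) softmax outputs; label_counts: (B, C) char counts."""
        _, N, _ = probs.shape
        agg = probs.sum(axis=1) / N        # predicted per-class frequency
        target = label_counts / N          # ground-truth per-class frequency
        return -(target * np.log(agg + 1e-10)).sum(axis=1).mean()

    probs = np.full((1, 4, 3), 1.0 / 3)    # uniform predictions over 3 classes
    counts = np.array([[2.0, 1.0, 1.0]])   # label histogram; sums to N = 4
    print(ace_loss(probs, counts))         # == -log(1/3), about 1.0986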
diff --git a/ppocr/losses/basic_loss.py b/ppocr/losses/basic_loss.py
index 9ad854cd12..8ecfd20af0 100644
--- a/ppocr/losses/basic_loss.py
+++ b/ppocr/losses/basic_loss.py
@@ -1,16 +1,16 @@
-#copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve.
+# copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve.
#
-#Licensed under the Apache License, Version 2.0 (the "License");
-#you may not use this file except in compliance with the License.
-#You may obtain a copy of the License at
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
-#Unless required by applicable law or agreed to in writing, software
-#distributed under the License is distributed on an "AS IS" BASIS,
-#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-#See the License for the specific language governing permissions and
-#limitations under the License.
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
import paddle
import paddle.nn as nn
@@ -55,28 +55,32 @@ def forward(self, x, label):
class KLJSLoss(object):
- def __init__(self, mode='kl'):
- assert mode in ['kl', 'js', 'KL', 'JS'
- ], "mode can only be one of ['kl', 'KL', 'js', 'JS']"
+ def __init__(self, mode="kl"):
+ assert mode in [
+ "kl",
+ "js",
+ "KL",
+ "JS",
+ ], "mode can only be one of ['kl', 'KL', 'js', 'JS']"
self.mode = mode
def __call__(self, p1, p2, reduction="mean", eps=1e-5):
-
- if self.mode.lower() == 'kl':
- loss = paddle.multiply(p2,
- paddle.log((p2 + eps) / (p1 + eps) + eps))
- loss += paddle.multiply(p1,
- paddle.log((p1 + eps) / (p2 + eps) + eps))
+ if self.mode.lower() == "kl":
+ loss = paddle.multiply(p2, paddle.log((p2 + eps) / (p1 + eps) + eps))
+ loss += paddle.multiply(p1, paddle.log((p1 + eps) / (p2 + eps) + eps))
loss *= 0.5
elif self.mode.lower() == "js":
loss = paddle.multiply(
- p2, paddle.log((2 * p2 + eps) / (p1 + p2 + eps) + eps))
+ p2, paddle.log((2 * p2 + eps) / (p1 + p2 + eps) + eps)
+ )
loss += paddle.multiply(
- p1, paddle.log((2 * p1 + eps) / (p1 + p2 + eps) + eps))
+ p1, paddle.log((2 * p1 + eps) / (p1 + p2 + eps) + eps)
+ )
loss *= 0.5
else:
raise ValueError(
- "The mode.lower() if KLJSLoss should be one of ['kl', 'js']")
+ "The mode.lower() if KLJSLoss should be one of ['kl', 'js']"
+ )
if reduction == "mean":
loss = paddle.mean(loss, axis=[1, 2])
@@ -122,8 +126,7 @@ def forward(self, out1, out2):
# for recognition distillation, log is needed for feature map
log_out1 = paddle.log(out1)
log_out2 = paddle.log(out2)
- loss = (
- self._kldiv(log_out1, out2) + self._kldiv(log_out2, out1)) / 2.0
+ loss = (self._kldiv(log_out1, out2) + self._kldiv(log_out2, out1)) / 2.0
else:
# for detection distillation log is not needed
loss = self.jskl_loss(out1, out2)
@@ -151,7 +154,7 @@ def forward(self, x, y):
class LossFromOutput(nn.Layer):
- def __init__(self, key='loss', reduction='none'):
+ def __init__(self, key="loss", reduction="none"):
super().__init__()
self.key = key
self.reduction = reduction
@@ -160,11 +163,11 @@ def forward(self, predicts, batch):
loss = predicts
if self.key is not None and isinstance(predicts, dict):
loss = loss[self.key]
- if self.reduction == 'mean':
+ if self.reduction == "mean":
loss = paddle.mean(loss)
- elif self.reduction == 'sum':
+ elif self.reduction == "sum":
loss = paddle.sum(loss)
- return {'loss': loss}
+ return {"loss": loss}
class KLDivLoss(nn.Layer):
@@ -219,8 +222,7 @@ def _kl_div(self, x, label, mask=None):
return y
def forward(self, logits_student, logits_teacher, target, mask=None):
- gt_mask = F.one_hot(
- target.reshape([-1]), num_classes=logits_student.shape[-1])
+ gt_mask = F.one_hot(target.reshape([-1]), num_classes=logits_student.shape[-1])
other_mask = 1 - gt_mask
logits_student = logits_student.flatten(0, 1)
logits_teacher = logits_teacher.flatten(0, 1)
@@ -229,14 +231,18 @@ def forward(self, logits_student, logits_teacher, target, mask=None):
pred_student = self._cat_mask(pred_student, gt_mask, other_mask)
pred_teacher = self._cat_mask(pred_teacher, gt_mask, other_mask)
log_pred_student = paddle.log(pred_student)
- tckd_loss = self._kl_div(log_pred_student,
- pred_teacher) * (self.temperature**2)
+ tckd_loss = self._kl_div(log_pred_student, pred_teacher) * (
+ self.temperature**2
+ )
pred_teacher_part2 = F.softmax(
- logits_teacher / self.temperature - 1000.0 * gt_mask, axis=1)
+ logits_teacher / self.temperature - 1000.0 * gt_mask, axis=1
+ )
log_pred_student_part2 = F.log_softmax(
- logits_student / self.temperature - 1000.0 * gt_mask, axis=1)
- nckd_loss = self._kl_div(log_pred_student_part2,
- pred_teacher_part2) * (self.temperature**2)
+ logits_student / self.temperature - 1000.0 * gt_mask, axis=1
+ )
+ nckd_loss = self._kl_div(log_pred_student_part2, pred_teacher_part2) * (
+ self.temperature**2
+ )
loss = self.alpha * tckd_loss + self.beta * nckd_loss
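Two of the reformatted losses above are worth restating compactly. `KLJSLoss` is a symmetrized divergence: in 'kl' mode it averages the eps-smoothed KL(p2||p1) and KL(p1||p2); the DKD hunk then splits distillation into a target-class term (TCKD) and a non-target-class term (NCKD), each a temperature-squared-scaled KL, mixed as `alpha * tckd + beta * nckd`. A NumPy check of the KLJS formula (mean over all elements here; the repo reduces over the spatial axes):

    import numpy as np

    def kljs(p1, p2, mode="kl", eps=1e-5):
        """Symmetrized KL / JS between two probability maps, as in KLJSLoss."""
        if mode == "kl":
            loss = p2 * np.log((p2 + eps) / (p1 + eps) + eps)
            loss += p1 * np.log((p1 + eps) / (p2 + eps) + eps)
        else:  # "js"
            loss = p2 * np.log((2 * p2 + eps) / (p1 + p2 + eps) + eps)
            loss += p1 * np.log((2 * p1 + eps) / (p1 + p2 + eps) + eps)
        return 0.5 * loss.mean()

    p = np.array([0.7, 0.3]); q = np.array([0.6, 0.4])
    print(kljs(p, q), kljs(p, p))  # second value is ~0: identical distributions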
diff --git a/ppocr/losses/center_loss.py b/ppocr/losses/center_loss.py
index f62b8af373..87d4d44fb7 100644
--- a/ppocr/losses/center_loss.py
+++ b/ppocr/losses/center_loss.py
@@ -1,16 +1,16 @@
-#copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve.
+# copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve.
#
-#Licensed under the Apache License, Version 2.0 (the "License");
-#you may not use this file except in compliance with the License.
-#You may obtain a copy of the License at
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
-#Unless required by applicable law or agreed to in writing, software
-#distributed under the License is distributed on an "AS IS" BASIS,
-#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-#See the License for the specific language governing permissions and
-#limitations under the License.
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
# This code is refer from: https://github.com/KaiyangZhou/pytorch-center-loss
@@ -34,14 +34,15 @@ def __init__(self, num_classes=6625, feat_dim=96, center_file_path=None):
super().__init__()
self.num_classes = num_classes
self.feat_dim = feat_dim
- self.centers = paddle.randn(
- shape=[self.num_classes, self.feat_dim]).astype("float64")
+ self.centers = paddle.randn(shape=[self.num_classes, self.feat_dim]).astype(
+ "float64"
+ )
if center_file_path is not None:
assert os.path.exists(
center_file_path
), f"center path({center_file_path}) must exist when it is not None."
- with open(center_file_path, 'rb') as f:
+ with open(center_file_path, "rb") as f:
char_dict = pickle.load(f)
for key in char_dict.keys():
self.centers[key] = paddle.to_tensor(char_dict[key])
@@ -50,39 +51,39 @@ def __call__(self, predicts, batch):
assert isinstance(predicts, (list, tuple))
features, predicts = predicts
- feats_reshape = paddle.reshape(
- features, [-1, features.shape[-1]]).astype("float64")
+ feats_reshape = paddle.reshape(features, [-1, features.shape[-1]]).astype(
+ "float64"
+ )
label = paddle.argmax(predicts, axis=2)
label = paddle.reshape(label, [label.shape[0] * label.shape[1]])
batch_size = feats_reshape.shape[0]
- #calc l2 distance between feats and centers
- square_feat = paddle.sum(paddle.square(feats_reshape),
- axis=1,
- keepdim=True)
+ # calc l2 distance between feats and centers
+ square_feat = paddle.sum(paddle.square(feats_reshape), axis=1, keepdim=True)
square_feat = paddle.expand(square_feat, [batch_size, self.num_classes])
- square_center = paddle.sum(paddle.square(self.centers),
- axis=1,
- keepdim=True)
+ square_center = paddle.sum(paddle.square(self.centers), axis=1, keepdim=True)
square_center = paddle.expand(
- square_center, [self.num_classes, batch_size]).astype("float64")
+ square_center, [self.num_classes, batch_size]
+ ).astype("float64")
square_center = paddle.transpose(square_center, [1, 0])
distmat = paddle.add(square_feat, square_center)
- feat_dot_center = paddle.matmul(feats_reshape,
- paddle.transpose(self.centers, [1, 0]))
+ feat_dot_center = paddle.matmul(
+ feats_reshape, paddle.transpose(self.centers, [1, 0])
+ )
distmat = distmat - 2.0 * feat_dot_center
- #generate the mask
+ # generate the mask
classes = paddle.arange(self.num_classes).astype("int64")
label = paddle.expand(
- paddle.unsqueeze(label, 1), (batch_size, self.num_classes))
+ paddle.unsqueeze(label, 1), (batch_size, self.num_classes)
+ )
mask = paddle.equal(
- paddle.expand(classes, [batch_size, self.num_classes]),
- label).astype("float64")
+ paddle.expand(classes, [batch_size, self.num_classes]), label
+ ).astype("float64")
dist = paddle.multiply(distmat, mask)
- loss = paddle.sum(paddle.clip(dist, min=1e-12, max=1e+12)) / batch_size
- return {'loss_center': loss}
+ loss = paddle.sum(paddle.clip(dist, min=1e-12, max=1e12)) / batch_size
+ return {"loss_center": loss}
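The center-loss hunk computes pairwise squared distances between features and class centers through the standard expansion ||x||^2 + ||c||^2 - 2 * x . c (the two `paddle.sum(paddle.square(...))` terms plus the `matmul`), masks everything but each sample's argmax class, and clips before averaging. A NumPy check of the expansion:

    import numpy as np

    def center_distmat(feats, centers):
        """Pairwise squared L2 distances, shape (B, num_classes)."""
        sq_f = (feats ** 2).sum(axis=1, keepdims=True)   # (B, 1)
        sq_c = (centers ** 2).sum(axis=1)                # (C,)
        return sq_f + sq_c - 2.0 * feats @ centers.T     # broadcast add

    feats = np.random.randn(4, 96)
    centers = np.random.randn(10, 96)
    d = center_distmat(feats, centers)
    # sanity check against the naive formula for one entry:
    assert np.isclose(d[0, 0], ((feats[0] - centers[0]) ** 2).sum())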
diff --git a/ppocr/losses/cls_loss.py b/ppocr/losses/cls_loss.py
index abc5e5b72c..eb3b17e668 100755
--- a/ppocr/losses/cls_loss.py
+++ b/ppocr/losses/cls_loss.py
@@ -22,9 +22,9 @@
class ClsLoss(nn.Layer):
def __init__(self, **kwargs):
super(ClsLoss, self).__init__()
- self.loss_func = nn.CrossEntropyLoss(reduction='mean')
+ self.loss_func = nn.CrossEntropyLoss(reduction="mean")
def forward(self, predicts, batch):
label = batch[1].astype("int64")
loss = self.loss_func(input=predicts, label=label)
- return {'loss': loss}
+ return {"loss": loss}
diff --git a/ppocr/losses/combined_loss.py b/ppocr/losses/combined_loss.py
index a520f10ffb..10cdb013cd 100644
--- a/ppocr/losses/combined_loss.py
+++ b/ppocr/losses/combined_loss.py
@@ -22,9 +22,20 @@
from .distillation_loss import DistillationCTCLoss, DistillCTCLogits
from .distillation_loss import DistillationSARLoss, DistillationNRTRLoss
-from .distillation_loss import DistillationDMLLoss, DistillationKLDivLoss, DistillationDKDLoss
-from .distillation_loss import DistillationDistanceLoss, DistillationDBLoss, DistillationDilaDBLoss
-from .distillation_loss import DistillationVQASerTokenLayoutLMLoss, DistillationSERDMLLoss
+from .distillation_loss import (
+ DistillationDMLLoss,
+ DistillationKLDivLoss,
+ DistillationDKDLoss,
+)
+from .distillation_loss import (
+ DistillationDistanceLoss,
+ DistillationDBLoss,
+ DistillationDilaDBLoss,
+)
+from .distillation_loss import (
+ DistillationVQASerTokenLayoutLMLoss,
+ DistillationSERDMLLoss,
+)
from .distillation_loss import DistillationLossFromOutput
from .distillation_loss import DistillationVQADistanceLoss
@@ -39,21 +50,22 @@ def __init__(self, loss_config_list=None):
super().__init__()
self.loss_func = []
self.loss_weight = []
- assert isinstance(loss_config_list, list), (
- 'operator config should be a list')
+ assert isinstance(loss_config_list, list), "operator config should be a list"
for config in loss_config_list:
- assert isinstance(config,
- dict) and len(config) == 1, "yaml format error"
+ assert isinstance(config, dict) and len(config) == 1, "yaml format error"
name = list(config)[0]
param = config[name]
- assert "weight" in param, "weight must be in param, but param just contains {}".format(
- param.keys())
+ assert (
+ "weight" in param
+ ), "weight must be in param, but param just contains {}".format(
+ param.keys()
+ )
self.loss_weight.append(param.pop("weight"))
self.loss_func.append(eval(name)(**param))
def forward(self, input, batch, **kargs):
loss_dict = {}
- loss_all = 0.
+ loss_all = 0.0
for idx, loss_func in enumerate(self.loss_func):
loss = loss_func(input, batch, **kargs)
if isinstance(loss, paddle.Tensor):
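`CombinedLoss` consumes a YAML list of single-key dicts; for each entry it pops the mandatory `weight`, instantiates the named loss with the leftover params, and accumulates `weight * loss` (unwrapping dict-returning losses) into `loss_all`. A registry-based sketch with stub losses (names hypothetical; the repo resolves names via `eval`):

    class L1Stub:
        def __call__(self, x, y):
            return {"loss": abs(x - y)}

    class L2Stub:
        def __call__(self, x, y):
            return {"loss": (x - y) ** 2}

    REGISTRY = {"L1Stub": L1Stub, "L2Stub": L2Stub}  # stand-in for eval(name)

    def combined(loss_config_list, x, y):
        total = 0.0
        for config in loss_config_list:
            assert isinstance(config, dict) and len(config) == 1, "yaml format error"
            name = list(config)[0]
            param = dict(config[name])
            weight = param.pop("weight")
            total += weight * REGISTRY[name](**param)(x, y)["loss"]
        return total

    print(combined([{"L1Stub": {"weight": 1.0}}, {"L2Stub": {"weight": 0.5}}], 2.0, 0.0))
    # -> 1.0 * 2 + 0.5 * 4 = 4.0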
diff --git a/ppocr/losses/det_basic_loss.py b/ppocr/losses/det_basic_loss.py
index 61ea579b41..248c3bad53 100644
--- a/ppocr/losses/det_basic_loss.py
+++ b/ppocr/losses/det_basic_loss.py
@@ -27,23 +27,25 @@
class BalanceLoss(nn.Layer):
- def __init__(self,
- balance_loss=True,
- main_loss_type='DiceLoss',
- negative_ratio=3,
- return_origin=False,
- eps=1e-6,
- **kwargs):
+ def __init__(
+ self,
+ balance_loss=True,
+ main_loss_type="DiceLoss",
+ negative_ratio=3,
+ return_origin=False,
+ eps=1e-6,
+ **kwargs
+ ):
+ """
+ The BalanceLoss for Differentiable Binarization text detection
+ args:
+ balance_loss (bool): whether to balance the loss, default is True
+ main_loss_type (str): can only be one of ['CrossEntropy','DiceLoss',
+ 'Euclidean','BCELoss', 'MaskL1Loss'], default is 'DiceLoss'.
+ negative_ratio (int|float): float, default is 3.
+ return_origin (bool): whether to also return the unbalanced loss, default is False.
+ eps (float): default is 1e-6.
"""
- The BalanceLoss for Differentiable Binarization text detection
- args:
- balance_loss (bool): whether balance loss or not, default is True
- main_loss_type (str): can only be one of ['CrossEntropy','DiceLoss',
- 'Euclidean','BCELoss', 'MaskL1Loss'], default is 'DiceLoss'.
- negative_ratio (int|float): float, default is 3.
- return_origin (bool): whether return unbalanced loss or not, default is False.
- eps (float): default is 1e-6.
- """
super(BalanceLoss, self).__init__()
self.balance_loss = balance_loss
self.main_loss_type = main_loss_type
@@ -58,16 +60,22 @@ def __init__(self,
elif self.main_loss_type == "DiceLoss":
self.loss = DiceLoss(self.eps)
elif self.main_loss_type == "BCELoss":
- self.loss = BCELoss(reduction='none')
+ self.loss = BCELoss(reduction="none")
elif self.main_loss_type == "MaskL1Loss":
self.loss = MaskL1Loss(self.eps)
else:
loss_type = [
- 'CrossEntropy', 'DiceLoss', 'Euclidean', 'BCELoss', 'MaskL1Loss'
+ "CrossEntropy",
+ "DiceLoss",
+ "Euclidean",
+ "BCELoss",
+ "MaskL1Loss",
]
raise Exception(
"main_loss_type in BalanceLoss() can only be one of {}".format(
- loss_type))
+ loss_type
+ )
+ )
def forward(self, pred, gt, mask=None):
"""
@@ -82,8 +90,7 @@ def forward(self, pred, gt, mask=None):
negative = (1 - gt) * mask
positive_count = int(positive.sum())
- negative_count = int(
- min(negative.sum(), positive_count * self.negative_ratio))
+ negative_count = int(min(negative.sum(), positive_count * self.negative_ratio))
loss = self.loss(pred, gt, mask=mask)
if not self.balance_loss:
@@ -97,7 +104,8 @@ def forward(self, pred, gt, mask=None):
negative_loss = sort_loss[:negative_count]
# negative_loss, _ = paddle.topk(negative_loss, k=negative_count_int)
balance_loss = (positive_loss.sum() + negative_loss.sum()) / (
- positive_count + negative_count + self.eps)
+ positive_count + negative_count + self.eps
+ )
else:
balance_loss = positive_loss.sum() / (positive_count + self.eps)
if self.return_origin:
@@ -144,7 +152,7 @@ def forward(self, pred, gt, mask):
class BCELoss(nn.Layer):
- def __init__(self, reduction='mean'):
+ def __init__(self, reduction="mean"):
super(BCELoss, self).__init__()
self.reduction = reduction
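`BalanceLoss` implements OHEM over the pixel map: all positive-pixel losses are kept, negatives are capped at `negative_ratio` times the positive count, and only the largest (hardest) negatives survive the descending sort before the balanced average. A NumPy sketch:

    import numpy as np

    def balance_bce(loss, gt, mask, negative_ratio=3, eps=1e-6):
        """loss/gt/mask are same-shape float arrays; gt and mask are 0-1 maps."""
        positive = gt * mask
        negative = (1 - gt) * mask
        pos_count = int(positive.sum())
        neg_count = int(min(negative.sum(), pos_count * negative_ratio))
        pos_loss = loss * positive
        neg_loss = np.sort((loss * negative).ravel())[::-1][:neg_count]  # hardest
        return (pos_loss.sum() + neg_loss.sum()) / (pos_count + neg_count + eps)

    rng = np.random.default_rng(0)
    loss = rng.random((8, 8)); gt = (rng.random((8, 8)) > 0.8).astype(float)
    print(balance_bce(loss, gt, np.ones_like(gt)))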
diff --git a/ppocr/losses/det_ct_loss.py b/ppocr/losses/det_ct_loss.py
index f48c95be4f..4655eff30c 100755
--- a/ppocr/losses/det_ct_loss.py
+++ b/ppocr/losses/det_ct_loss.py
@@ -30,14 +30,16 @@ def ohem_single(score, gt_text, training_mask):
# online hard example mining
pos_num = int(paddle.sum(gt_text > 0.5)) - int(
- paddle.sum((gt_text > 0.5) & (training_mask <= 0.5)))
+ paddle.sum((gt_text > 0.5) & (training_mask <= 0.5))
+ )
if pos_num == 0:
# selected_mask = gt_text.copy() * 0 # may be not good
selected_mask = training_mask
selected_mask = paddle.cast(
- selected_mask.reshape(
- (1, selected_mask.shape[0], selected_mask.shape[1])), "float32")
+ selected_mask.reshape((1, selected_mask.shape[0], selected_mask.shape[1])),
+ "float32",
+ )
return selected_mask
neg_num = int(paddle.sum((gt_text <= 0.5) & (training_mask > 0.5)))
@@ -46,8 +48,9 @@ def ohem_single(score, gt_text, training_mask):
if neg_num == 0:
selected_mask = training_mask
selected_mask = paddle.cast(
- selected_mask.reshape(
- (1, selected_mask.shape[0], selected_mask.shape[1])), "float32")
+ selected_mask.reshape((1, selected_mask.shape[0], selected_mask.shape[1])),
+ "float32",
+ )
return selected_mask
# hard example
@@ -55,11 +58,11 @@ def ohem_single(score, gt_text, training_mask):
neg_score_sorted = paddle.sort(-neg_score)
threshold = -neg_score_sorted[neg_num - 1]
- selected_mask = ((score >= threshold) |
- (gt_text > 0.5)) & (training_mask > 0.5)
+ selected_mask = ((score >= threshold) | (gt_text > 0.5)) & (training_mask > 0.5)
selected_mask = paddle.cast(
- selected_mask.reshape(
- (1, selected_mask.shape[0], selected_mask.shape[1])), "float32")
+ selected_mask.reshape((1, selected_mask.shape[0], selected_mask.shape[1])),
+ "float32",
+ )
return selected_mask
@@ -67,8 +70,8 @@ def ohem_batch(scores, gt_texts, training_masks):
selected_masks = []
for i in range(scores.shape[0]):
selected_masks.append(
- ohem_single(scores[i, :, :], gt_texts[i, :, :], training_masks[
- i, :, :]))
+ ohem_single(scores[i, :, :], gt_texts[i, :, :], training_masks[i, :, :])
+ )
selected_masks = paddle.cast(paddle.concat(selected_masks, 0), "float32")
return selected_masks
@@ -98,7 +101,7 @@ def iou(a, b, mask, n_class=2, reduce=True):
b = b.reshape((batch_size, -1))
mask = mask.reshape((batch_size, -1))
- iou = paddle.zeros((batch_size, ), dtype="float32")
+ iou = paddle.zeros((batch_size,), dtype="float32")
for i in range(batch_size):
iou[i] = iou_single(a[i], b[i], mask[i], n_class)
@@ -153,15 +156,15 @@ def __init__(self, beta=1.0, loss_weight=1.0):
self.coord = self.create_parameter(
shape=[640 * 640, 2],
dtype="int32", # NOTE: not support "int64" before paddle 2.3.1
- default_initializer=nn.initializer.Assign(value=np_coord))
+ default_initializer=nn.initializer.Assign(value=np_coord),
+ )
self.coord.stop_gradient = True
def forward_single(self, input, target, mask, beta=1.0, eps=1e-6):
batch_size = input.shape[0]
diff = paddle.abs(input - target) * mask.unsqueeze(1)
- loss = paddle.where(diff < beta, 0.5 * diff * diff / beta,
- diff - 0.5 * beta)
+ loss = paddle.where(diff < beta, 0.5 * diff * diff / beta, diff - 0.5 * beta)
loss = paddle.cast(loss.reshape((batch_size, -1)), "float32")
mask = paddle.cast(mask.reshape((batch_size, -1)), "float32")
loss = paddle.sum(loss, axis=-1)
@@ -169,9 +172,7 @@ def forward_single(self, input, target, mask, beta=1.0, eps=1e-6):
return loss
- def select_single(self, distance, gt_instance, gt_kernel_instance,
- training_mask):
-
+ def select_single(self, distance, gt_instance, gt_kernel_instance, training_mask):
with paddle.no_grad():
# paddle 2.3.1, paddle.slice not support:
# distance[:, self.coord[:, 1], self.coord[:, 0]]
@@ -183,47 +184,56 @@ def select_single(self, distance, gt_instance, gt_kernel_instance,
select_distance = paddle.concat(select_distance_list, axis=0)
off_points = paddle.cast(
- self.coord, "float32") + 10 * select_distance.transpose((1, 0))
+ self.coord, "float32"
+ ) + 10 * select_distance.transpose((1, 0))
off_points = paddle.cast(off_points, "int64")
off_points = paddle.clip(off_points, 0, distance.shape[-1] - 1)
selected_mask = (
- gt_instance[self.coord[:, 1], self.coord[:, 0]] !=
- gt_kernel_instance[off_points[:, 1], off_points[:, 0]])
+ gt_instance[self.coord[:, 1], self.coord[:, 0]]
+ != gt_kernel_instance[off_points[:, 1], off_points[:, 0]]
+ )
selected_mask = paddle.cast(
- selected_mask.reshape((1, -1, distance.shape[-1])), "int64")
+ selected_mask.reshape((1, -1, distance.shape[-1])), "int64"
+ )
selected_training_mask = selected_mask * training_mask
return selected_training_mask
- def forward(self,
- distances,
- gt_instances,
- gt_kernel_instances,
- training_masks,
- gt_distances,
- reduce=True):
-
+ def forward(
+ self,
+ distances,
+ gt_instances,
+ gt_kernel_instances,
+ training_masks,
+ gt_distances,
+ reduce=True,
+ ):
selected_training_masks = []
for i in range(distances.shape[0]):
selected_training_masks.append(
- self.select_single(distances[i, :, :, :], gt_instances[i, :, :],
- gt_kernel_instances[i, :, :], training_masks[
- i, :, :]))
+ self.select_single(
+ distances[i, :, :, :],
+ gt_instances[i, :, :],
+ gt_kernel_instances[i, :, :],
+ training_masks[i, :, :],
+ )
+ )
selected_training_masks = paddle.cast(
- paddle.concat(selected_training_masks, 0), "float32")
+ paddle.concat(selected_training_masks, 0), "float32"
+ )
- loss = self.forward_single(distances, gt_distances,
- selected_training_masks, self.beta)
+ loss = self.forward_single(
+ distances, gt_distances, selected_training_masks, self.beta
+ )
loss = self.loss_weight * loss
with paddle.no_grad():
batch_size = distances.shape[0]
false_num = selected_training_masks.reshape((batch_size, -1))
false_num = false_num.sum(axis=-1)
- total_num = paddle.cast(
- training_masks.reshape((batch_size, -1)), "float32")
+ total_num = paddle.cast(training_masks.reshape((batch_size, -1)), "float32")
total_num = total_num.sum(axis=-1)
iou_text = (total_num - false_num) / (total_num + 1e-6)
@@ -241,9 +251,15 @@ def __init__(self):
def forward(self, preds, batch):
imgs = batch[0]
- out = preds['maps']
- gt_kernels, training_masks, gt_instances, gt_kernel_instances, training_mask_distances, gt_distances = batch[
- 1:]
+ out = preds["maps"]
+ (
+ gt_kernels,
+ training_masks,
+ gt_instances,
+ gt_kernel_instances,
+ training_mask_distances,
+ gt_distances,
+ ) = batch[1:]
kernels = out[:, 0, :, :]
distances = out[:, 1:, :, :]
@@ -252,13 +268,18 @@ def forward(self, preds, batch):
selected_masks = ohem_batch(kernels, gt_kernels, training_masks)
loss_kernel = self.kernel_loss(
- kernels, gt_kernels, selected_masks, reduce=False)
-
- iou_kernel = iou(paddle.cast((kernels > 0), "int64"),
- gt_kernels,
- training_masks,
- reduce=False)
- losses = dict(loss_kernels=loss_kernel, )
+ kernels, gt_kernels, selected_masks, reduce=False
+ )
+
+ iou_kernel = iou(
+ paddle.cast((kernels > 0), "int64"),
+ gt_kernels,
+ training_masks,
+ reduce=False,
+ )
+ losses = dict(
+ loss_kernels=loss_kernel,
+ )
# loc loss
loss_loc, iou_text = self.loc_loss(
@@ -267,10 +288,15 @@ def forward(self, preds, batch):
gt_kernel_instances,
training_mask_distances,
gt_distances,
- reduce=False)
- losses.update(dict(loss_loc=loss_loc, ))
+ reduce=False,
+ )
+ losses.update(
+ dict(
+ loss_loc=loss_loc,
+ )
+ )
loss_all = loss_kernel + loss_loc
- losses = {'loss': loss_all}
+ losses = {"loss": loss_all}
return losses
diff --git a/ppocr/losses/det_db_loss.py b/ppocr/losses/det_db_loss.py
index ce31ef1245..67201e0462 100755
--- a/ppocr/losses/det_db_loss.py
+++ b/ppocr/losses/det_db_loss.py
@@ -33,14 +33,16 @@ class DBLoss(nn.Layer):
param (dict): the hyper-parameters for DB Loss
"""
- def __init__(self,
- balance_loss=True,
- main_loss_type='DiceLoss',
- alpha=5,
- beta=10,
- ohem_ratio=3,
- eps=1e-6,
- **kwargs):
+ def __init__(
+ self,
+ balance_loss=True,
+ main_loss_type="DiceLoss",
+ alpha=5,
+ beta=10,
+ ohem_ratio=3,
+ eps=1e-6,
+ **kwargs
+ ):
super(DBLoss, self).__init__()
self.alpha = alpha
self.beta = beta
@@ -49,39 +51,49 @@ def __init__(self,
self.bce_loss = BalanceLoss(
balance_loss=balance_loss,
main_loss_type=main_loss_type,
- negative_ratio=ohem_ratio)
+ negative_ratio=ohem_ratio,
+ )
def forward(self, predicts, labels):
- predict_maps = predicts['maps']
- label_threshold_map, label_threshold_mask, label_shrink_map, label_shrink_mask = labels[
- 1:]
+ predict_maps = predicts["maps"]
+ (
+ label_threshold_map,
+ label_threshold_mask,
+ label_shrink_map,
+ label_shrink_mask,
+ ) = labels[1:]
shrink_maps = predict_maps[:, 0, :, :]
threshold_maps = predict_maps[:, 1, :, :]
binary_maps = predict_maps[:, 2, :, :]
- loss_shrink_maps = self.bce_loss(shrink_maps, label_shrink_map,
- label_shrink_mask)
- loss_threshold_maps = self.l1_loss(threshold_maps, label_threshold_map,
- label_threshold_mask)
- loss_binary_maps = self.dice_loss(binary_maps, label_shrink_map,
- label_shrink_mask)
+ loss_shrink_maps = self.bce_loss(
+ shrink_maps, label_shrink_map, label_shrink_mask
+ )
+ loss_threshold_maps = self.l1_loss(
+ threshold_maps, label_threshold_map, label_threshold_mask
+ )
+ loss_binary_maps = self.dice_loss(
+ binary_maps, label_shrink_map, label_shrink_mask
+ )
loss_shrink_maps = self.alpha * loss_shrink_maps
loss_threshold_maps = self.beta * loss_threshold_maps
# CBN loss
- if 'distance_maps' in predicts.keys():
- distance_maps = predicts['distance_maps']
- cbn_maps = predicts['cbn_maps']
- cbn_loss = self.bce_loss(cbn_maps[:, 0, :, :], label_shrink_map,
- label_shrink_mask)
+ if "distance_maps" in predicts.keys():
+ distance_maps = predicts["distance_maps"]
+ cbn_maps = predicts["cbn_maps"]
+ cbn_loss = self.bce_loss(
+ cbn_maps[:, 0, :, :], label_shrink_map, label_shrink_mask
+ )
else:
- dis_loss = paddle.to_tensor([0.])
- cbn_loss = paddle.to_tensor([0.])
+ dis_loss = paddle.to_tensor([0.0])
+ cbn_loss = paddle.to_tensor([0.0])
- loss_all = loss_shrink_maps + loss_threshold_maps \
- + loss_binary_maps
- losses = {'loss': loss_all+ cbn_loss, \
- "loss_shrink_maps": loss_shrink_maps, \
- "loss_threshold_maps": loss_threshold_maps, \
- "loss_binary_maps": loss_binary_maps, \
- "loss_cbn": cbn_loss}
+ loss_all = loss_shrink_maps + loss_threshold_maps + loss_binary_maps
+ losses = {
+ "loss": loss_all + cbn_loss,
+ "loss_shrink_maps": loss_shrink_maps,
+ "loss_threshold_maps": loss_threshold_maps,
+ "loss_binary_maps": loss_binary_maps,
+ "loss_cbn": cbn_loss,
+ }
return losses
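After reformatting, the DB objective is easier to read: three map losses (balanced BCE on the shrink map, masked L1 on the threshold map, Dice on the binary map) combined as `alpha * shrink + beta * threshold + binary`, plus the CBN term when distance maps are predicted. The weighting in isolation (toy scalar values):

    def db_total(loss_shrink, loss_thresh, loss_binary, cbn=0.0, alpha=5, beta=10):
        """DB total objective as combined above (CBN term optional)."""
        return alpha * loss_shrink + beta * loss_thresh + loss_binary + cbn

    print(db_total(0.2, 0.05, 0.1))  # 5*0.2 + 10*0.05 + 0.1 = 1.6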
diff --git a/ppocr/losses/det_drrg_loss.py b/ppocr/losses/det_drrg_loss.py
index 89d4b521c7..6e990cebff 100644
--- a/ppocr/losses/det_drrg_loss.py
+++ b/ppocr/losses/det_drrg_loss.py
@@ -46,21 +46,22 @@ def balance_bce_loss(self, pred, gt, mask):
positive_count = int(positive.sum())
if positive_count > 0:
- loss = F.binary_cross_entropy(pred, gt, reduction='none')
+ loss = F.binary_cross_entropy(pred, gt, reduction="none")
positive_loss = paddle.sum(loss * positive)
negative_loss = loss * negative
negative_count = min(
- int(negative.sum()), int(positive_count * self.ohem_ratio))
+ int(negative.sum()), int(positive_count * self.ohem_ratio)
+ )
else:
positive_loss = paddle.to_tensor(0.0)
- loss = F.binary_cross_entropy(pred, gt, reduction='none')
+ loss = F.binary_cross_entropy(pred, gt, reduction="none")
negative_loss = loss * negative
negative_count = 100
- negative_loss, _ = paddle.topk(
- negative_loss.reshape([-1]), negative_count)
+ negative_loss, _ = paddle.topk(negative_loss.reshape([-1]), negative_count)
balance_loss = (positive_loss + paddle.sum(negative_loss)) / (
- float(positive_count + negative_count) + 1e-5)
+ float(positive_count + negative_count) + 1e-5
+ )
return balance_loss
@@ -105,7 +106,7 @@ def bitmasks2tensor(self, bitmasks, target_sz):
mask_sz = mask.shape
# left, right, top, bottom
pad = [0, target_sz[1] - mask_sz[1], 0, target_sz[0] - mask_sz[0]]
- mask = F.pad(mask, pad, mode='constant', value=0)
+ mask = F.pad(mask, pad, mode="constant", value=0)
kernel.append(mask)
kernel = paddle.stack(kernel)
results.append(kernel)
@@ -113,12 +114,18 @@ def bitmasks2tensor(self, bitmasks, target_sz):
return results
def forward(self, preds, labels):
- """Compute Drrg loss.
- """
+ """Compute Drrg loss."""
assert isinstance(preds, tuple)
- gt_text_mask, gt_center_region_mask, gt_mask, gt_top_height_map, gt_bot_height_map, gt_sin_map, gt_cos_map = labels[
- 1:8]
+ (
+ gt_text_mask,
+ gt_center_region_mask,
+ gt_mask,
+ gt_top_height_map,
+ gt_bot_height_map,
+ gt_sin_map,
+ gt_cos_map,
+ ) = labels[1:8]
downsample_ratio = self.downsample_ratio
@@ -133,14 +140,13 @@ def forward(self, preds, labels):
# bitmask 2 tensor
mapping = {
- 'gt_text_mask': paddle.cast(gt_text_mask, 'float32'),
- 'gt_center_region_mask':
- paddle.cast(gt_center_region_mask, 'float32'),
- 'gt_mask': paddle.cast(gt_mask, 'float32'),
- 'gt_top_height_map': paddle.cast(gt_top_height_map, 'float32'),
- 'gt_bot_height_map': paddle.cast(gt_bot_height_map, 'float32'),
- 'gt_sin_map': paddle.cast(gt_sin_map, 'float32'),
- 'gt_cos_map': paddle.cast(gt_cos_map, 'float32')
+ "gt_text_mask": paddle.cast(gt_text_mask, "float32"),
+ "gt_center_region_mask": paddle.cast(gt_center_region_mask, "float32"),
+ "gt_mask": paddle.cast(gt_mask, "float32"),
+ "gt_top_height_map": paddle.cast(gt_top_height_map, "float32"),
+ "gt_bot_height_map": paddle.cast(gt_bot_height_map, "float32"),
+ "gt_sin_map": paddle.cast(gt_sin_map, "float32"),
+ "gt_cos_map": paddle.cast(gt_cos_map, "float32"),
}
gt = {}
for key, value in mapping.items():
@@ -150,7 +156,7 @@ def forward(self, preds, labels):
else:
gt[key] = [item.rescale(downsample_ratio) for item in gt[key]]
gt[key] = self.bitmasks2tensor(gt[key], feature_sz[2:])
- if key in ['gt_top_height_map', 'gt_bot_height_map']:
+ if key in ["gt_top_height_map", "gt_bot_height_map"]:
gt[key] = [item * downsample_ratio for item in gt[key]]
gt[key] = [item for item in gt[key]]
@@ -159,51 +165,54 @@ def forward(self, preds, labels):
pred_cos_map = pred_cos_map * scale
loss_text = self.balance_bce_loss(
- F.sigmoid(pred_text_region), gt['gt_text_mask'][0],
- gt['gt_mask'][0])
+ F.sigmoid(pred_text_region), gt["gt_text_mask"][0], gt["gt_mask"][0]
+ )
- text_mask = (gt['gt_text_mask'][0] * gt['gt_mask'][0])
- negative_text_mask = ((1 - gt['gt_text_mask'][0]) * gt['gt_mask'][0])
+ text_mask = gt["gt_text_mask"][0] * gt["gt_mask"][0]
+ negative_text_mask = (1 - gt["gt_text_mask"][0]) * gt["gt_mask"][0]
loss_center_map = F.binary_cross_entropy(
F.sigmoid(pred_center_region),
- gt['gt_center_region_mask'][0],
- reduction='none')
+ gt["gt_center_region_mask"][0],
+ reduction="none",
+ )
if int(text_mask.sum()) > 0:
- loss_center_positive = paddle.sum(loss_center_map *
- text_mask) / paddle.sum(text_mask)
+ loss_center_positive = paddle.sum(loss_center_map * text_mask) / paddle.sum(
+ text_mask
+ )
else:
loss_center_positive = paddle.to_tensor(0.0)
loss_center_negative = paddle.sum(
- loss_center_map *
- negative_text_mask) / paddle.sum(negative_text_mask)
+ loss_center_map * negative_text_mask
+ ) / paddle.sum(negative_text_mask)
loss_center = loss_center_positive + 0.5 * loss_center_negative
- center_mask = (gt['gt_center_region_mask'][0] * gt['gt_mask'][0])
+ center_mask = gt["gt_center_region_mask"][0] * gt["gt_mask"][0]
if int(center_mask.sum()) > 0:
map_sz = pred_top_height_map.shape
- ones = paddle.ones(map_sz, dtype='float32')
+ ones = paddle.ones(map_sz, dtype="float32")
loss_top = F.smooth_l1_loss(
- pred_top_height_map / (gt['gt_top_height_map'][0] + 1e-2),
+ pred_top_height_map / (gt["gt_top_height_map"][0] + 1e-2),
ones,
- reduction='none')
+ reduction="none",
+ )
loss_bot = F.smooth_l1_loss(
- pred_bot_height_map / (gt['gt_bot_height_map'][0] + 1e-2),
+ pred_bot_height_map / (gt["gt_bot_height_map"][0] + 1e-2),
ones,
- reduction='none')
- gt_height = (
- gt['gt_top_height_map'][0] + gt['gt_bot_height_map'][0])
+ reduction="none",
+ )
+ gt_height = gt["gt_top_height_map"][0] + gt["gt_bot_height_map"][0]
loss_height = paddle.sum(
- (paddle.log(gt_height + 1) *
- (loss_top + loss_bot)) * center_mask) / paddle.sum(center_mask)
+ (paddle.log(gt_height + 1) * (loss_top + loss_bot)) * center_mask
+ ) / paddle.sum(center_mask)
loss_sin = paddle.sum(
- F.smooth_l1_loss(
- pred_sin_map, gt['gt_sin_map'][0],
- reduction='none') * center_mask) / paddle.sum(center_mask)
+ F.smooth_l1_loss(pred_sin_map, gt["gt_sin_map"][0], reduction="none")
+ * center_mask
+ ) / paddle.sum(center_mask)
loss_cos = paddle.sum(
- F.smooth_l1_loss(
- pred_cos_map, gt['gt_cos_map'][0],
- reduction='none') * center_mask) / paddle.sum(center_mask)
+ F.smooth_l1_loss(pred_cos_map, gt["gt_cos_map"][0], reduction="none")
+ * center_mask
+ ) / paddle.sum(center_mask)
else:
loss_height = paddle.to_tensor(0.0)
loss_sin = paddle.to_tensor(0.0)
@@ -219,6 +228,7 @@ def forward(self, preds, labels):
loss_height=loss_height,
loss_sin=loss_sin,
loss_cos=loss_cos,
- loss_gcn=loss_gcn)
+ loss_gcn=loss_gcn,
+ )
return results
diff --git a/ppocr/losses/det_east_loss.py b/ppocr/losses/det_east_loss.py
index bcf5372b72..e5725d0d49 100644
--- a/ppocr/losses/det_east_loss.py
+++ b/ppocr/losses/det_east_loss.py
@@ -22,42 +22,41 @@
class EASTLoss(nn.Layer):
- """
- """
+ """ """
- def __init__(self,
- eps=1e-6,
- **kwargs):
+ def __init__(self, eps=1e-6, **kwargs):
super(EASTLoss, self).__init__()
self.dice_loss = DiceLoss(eps=eps)
def forward(self, predicts, labels):
l_score, l_geo, l_mask = labels[1:]
- f_score = predicts['f_score']
- f_geo = predicts['f_geo']
+ f_score = predicts["f_score"]
+ f_geo = predicts["f_geo"]
dice_loss = self.dice_loss(f_score, l_score, l_mask)
- #smoooth_l1_loss
+ # smooth_l1_loss
channels = 8
- l_geo_split = paddle.split(
- l_geo, num_or_sections=channels + 1, axis=1)
+ l_geo_split = paddle.split(l_geo, num_or_sections=channels + 1, axis=1)
f_geo_split = paddle.split(f_geo, num_or_sections=channels, axis=1)
smooth_l1 = 0
for i in range(0, channels):
geo_diff = l_geo_split[i] - f_geo_split[i]
abs_geo_diff = paddle.abs(geo_diff)
smooth_l1_sign = paddle.less_than(abs_geo_diff, l_score)
- smooth_l1_sign = paddle.cast(smooth_l1_sign, dtype='float32')
- in_loss = abs_geo_diff * abs_geo_diff * smooth_l1_sign + \
- (abs_geo_diff - 0.5) * (1.0 - smooth_l1_sign)
+ smooth_l1_sign = paddle.cast(smooth_l1_sign, dtype="float32")
+ in_loss = abs_geo_diff * abs_geo_diff * smooth_l1_sign + (
+ abs_geo_diff - 0.5
+ ) * (1.0 - smooth_l1_sign)
out_loss = l_geo_split[-1] / channels * in_loss * l_score
smooth_l1 += out_loss
smooth_l1_loss = paddle.mean(smooth_l1 * l_score)
dice_loss = dice_loss * 0.01
total_loss = dice_loss + smooth_l1_loss
- losses = {"loss":total_loss, \
- "dice_loss":dice_loss,\
- "smooth_l1_loss":smooth_l1_loss}
+ losses = {
+ "loss": total_loss,
+ "dice_loss": dice_loss,
+ "smooth_l1_loss": smooth_l1_loss,
+ }
return losses
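The EAST loss pairs a Dice score-map term (scaled by 0.01) with a per-channel smooth-L1 over the 8 geometry channels, where the score map serves both as the quadratic/linear switch point and as the pixel weight, and the 9th ground-truth channel acts as a per-pixel normalizer. A NumPy sketch of the geometry term:

    import numpy as np

    def east_smooth_l1(l_geo, f_geo, l_score):
        """l_geo: (9,H,W) gt (last channel = normalizer); f_geo: (8,H,W) pred."""
        smooth_l1 = 0.0
        for i in range(8):
            diff = np.abs(l_geo[i] - f_geo[i])
            sign = (diff < l_score).astype("float32")   # quadratic vs linear
            in_loss = diff * diff * sign + (diff - 0.5) * (1.0 - sign)
            smooth_l1 += l_geo[8] / 8.0 * in_loss * l_score
        return (smooth_l1 * l_score).mean()

    H, W = 4, 4
    print(east_smooth_l1(np.ones((9, H, W)), np.zeros((8, H, W)), np.ones((H, W))))
    # each channel contributes (1 - 0.5) / 8 -> total 0.5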
diff --git a/ppocr/losses/det_fce_loss.py b/ppocr/losses/det_fce_loss.py
index d7dfb5aa6c..86b82a8b24 100644
--- a/ppocr/losses/det_fce_loss.py
+++ b/ppocr/losses/det_fce_loss.py
@@ -43,7 +43,7 @@ class FCELoss(nn.Layer):
ohem_ratio (float): the negative/positive ratio in OHEM.
"""
- def __init__(self, fourier_degree, num_sample, ohem_ratio=3.):
+ def __init__(self, fourier_degree, num_sample, ohem_ratio=3.0):
super().__init__()
self.fourier_degree = fourier_degree
self.num_sample = num_sample
@@ -51,11 +51,12 @@ def __init__(self, fourier_degree, num_sample, ohem_ratio=3.):
def forward(self, preds, labels):
assert isinstance(preds, dict)
- preds = preds['levels']
+ preds = preds["levels"]
p3_maps, p4_maps, p5_maps = labels[1:]
- assert p3_maps[0].shape[0] == 4 * self.fourier_degree + 5,\
- 'fourier degree not equal in FCEhead and FCEtarget'
+ assert (
+ p3_maps[0].shape[0] == 4 * self.fourier_degree + 5
+ ), "fourier degree not equal in FCEhead and FCEtarget"
# to tensor
gts = [p3_maps, p4_maps, p5_maps]
@@ -64,11 +65,11 @@ def forward(self, preds, labels):
losses = multi_apply(self.forward_single, preds, gts)
- loss_tr = paddle.to_tensor(0.).astype('float32')
- loss_tcl = paddle.to_tensor(0.).astype('float32')
- loss_reg_x = paddle.to_tensor(0.).astype('float32')
- loss_reg_y = paddle.to_tensor(0.).astype('float32')
- loss_all = paddle.to_tensor(0.).astype('float32')
+ loss_tr = paddle.to_tensor(0.0).astype("float32")
+ loss_tcl = paddle.to_tensor(0.0).astype("float32")
+ loss_reg_x = paddle.to_tensor(0.0).astype("float32")
+ loss_reg_y = paddle.to_tensor(0.0).astype("float32")
+ loss_all = paddle.to_tensor(0.0).astype("float32")
for idx, loss in enumerate(losses):
loss_all += sum(loss)
@@ -86,7 +87,8 @@ def forward(self, preds, labels):
loss_text=loss_tr,
loss_center=loss_tcl,
loss_reg_x=loss_reg_x,
- loss_reg_y=loss_reg_y, )
+ loss_reg_y=loss_reg_y,
+ )
return results
def forward_single(self, pred, gt):
@@ -98,40 +100,45 @@ def forward_single(self, pred, gt):
tr_pred = paddle.reshape(cls_pred[:, :, :, :2], (-1, 2))
tcl_pred = paddle.reshape(cls_pred[:, :, :, 2:], (-1, 2))
x_pred = paddle.reshape(reg_pred[:, :, :, 0:k], (-1, k))
- y_pred = paddle.reshape(reg_pred[:, :, :, k:2 * k], (-1, k))
+ y_pred = paddle.reshape(reg_pred[:, :, :, k : 2 * k], (-1, k))
tr_mask = gt[:, :, :, :1].reshape([-1])
tcl_mask = gt[:, :, :, 1:2].reshape([-1])
train_mask = gt[:, :, :, 2:3].reshape([-1])
- x_map = paddle.reshape(gt[:, :, :, 3:3 + k], (-1, k))
- y_map = paddle.reshape(gt[:, :, :, 3 + k:], (-1, k))
+ x_map = paddle.reshape(gt[:, :, :, 3 : 3 + k], (-1, k))
+ y_map = paddle.reshape(gt[:, :, :, 3 + k :], (-1, k))
- tr_train_mask = (train_mask * tr_mask).astype('bool')
+ tr_train_mask = (train_mask * tr_mask).astype("bool")
tr_train_mask2 = paddle.concat(
- [tr_train_mask.unsqueeze(1), tr_train_mask.unsqueeze(1)], axis=1)
+ [tr_train_mask.unsqueeze(1), tr_train_mask.unsqueeze(1)], axis=1
+ )
# tr loss
loss_tr = self.ohem(tr_pred, tr_mask, train_mask)
# tcl loss
- loss_tcl = paddle.to_tensor(0.).astype('float32')
+ loss_tcl = paddle.to_tensor(0.0).astype("float32")
tr_neg_mask = tr_train_mask.logical_not()
tr_neg_mask2 = paddle.concat(
- [tr_neg_mask.unsqueeze(1), tr_neg_mask.unsqueeze(1)], axis=1)
+ [tr_neg_mask.unsqueeze(1), tr_neg_mask.unsqueeze(1)], axis=1
+ )
if tr_train_mask.sum().item() > 0:
loss_tcl_pos = F.cross_entropy(
tcl_pred.masked_select(tr_train_mask2).reshape([-1, 2]),
- tcl_mask.masked_select(tr_train_mask).astype('int64'))
+ tcl_mask.masked_select(tr_train_mask).astype("int64"),
+ )
loss_tcl_neg = F.cross_entropy(
tcl_pred.masked_select(tr_neg_mask2).reshape([-1, 2]),
- tcl_mask.masked_select(tr_neg_mask).astype('int64'))
+ tcl_mask.masked_select(tr_neg_mask).astype("int64"),
+ )
loss_tcl = loss_tcl_pos + 0.5 * loss_tcl_neg
# regression loss
- loss_reg_x = paddle.to_tensor(0.).astype('float32')
- loss_reg_y = paddle.to_tensor(0.).astype('float32')
+ loss_reg_x = paddle.to_tensor(0.0).astype("float32")
+ loss_reg_y = paddle.to_tensor(0.0).astype("float32")
if tr_train_mask.sum().item() > 0:
- weight = (tr_mask.masked_select(tr_train_mask.astype('bool'))
- .astype('float32') + tcl_mask.masked_select(
- tr_train_mask.astype('bool')).astype('float32')) / 2
+ weight = (
+ tr_mask.masked_select(tr_train_mask.astype("bool")).astype("float32")
+ + tcl_mask.masked_select(tr_train_mask.astype("bool")).astype("float32")
+ ) / 2
weight = weight.reshape([-1, 1])
ft_x, ft_y = self.fourier2poly(x_map, y_map)
@@ -140,52 +147,64 @@ def forward_single(self, pred, gt):
dim = ft_x.shape[1]
tr_train_mask3 = paddle.concat(
- [tr_train_mask.unsqueeze(1) for i in range(dim)], axis=1)
-
- loss_reg_x = paddle.mean(weight * F.smooth_l1_loss(
- ft_x_pre.masked_select(tr_train_mask3).reshape([-1, dim]),
- ft_x.masked_select(tr_train_mask3).reshape([-1, dim]),
- reduction='none'))
- loss_reg_y = paddle.mean(weight * F.smooth_l1_loss(
- ft_y_pre.masked_select(tr_train_mask3).reshape([-1, dim]),
- ft_y.masked_select(tr_train_mask3).reshape([-1, dim]),
- reduction='none'))
+ [tr_train_mask.unsqueeze(1) for i in range(dim)], axis=1
+ )
+
+ loss_reg_x = paddle.mean(
+ weight
+ * F.smooth_l1_loss(
+ ft_x_pre.masked_select(tr_train_mask3).reshape([-1, dim]),
+ ft_x.masked_select(tr_train_mask3).reshape([-1, dim]),
+ reduction="none",
+ )
+ )
+ loss_reg_y = paddle.mean(
+ weight
+ * F.smooth_l1_loss(
+ ft_y_pre.masked_select(tr_train_mask3).reshape([-1, dim]),
+ ft_y.masked_select(tr_train_mask3).reshape([-1, dim]),
+ reduction="none",
+ )
+ )
return loss_tr, loss_tcl, loss_reg_x, loss_reg_y
def ohem(self, predict, target, train_mask):
-
- pos = (target * train_mask).astype('bool')
- neg = ((1 - target) * train_mask).astype('bool')
+ pos = (target * train_mask).astype("bool")
+ neg = ((1 - target) * train_mask).astype("bool")
pos2 = paddle.concat([pos.unsqueeze(1), pos.unsqueeze(1)], axis=1)
neg2 = paddle.concat([neg.unsqueeze(1), neg.unsqueeze(1)], axis=1)
- n_pos = pos.astype('float32').sum()
+ n_pos = pos.astype("float32").sum()
if n_pos.item() > 0:
loss_pos = F.cross_entropy(
predict.masked_select(pos2).reshape([-1, 2]),
- target.masked_select(pos).astype('int64'),
- reduction='sum')
+ target.masked_select(pos).astype("int64"),
+ reduction="sum",
+ )
loss_neg = F.cross_entropy(
predict.masked_select(neg2).reshape([-1, 2]),
- target.masked_select(neg).astype('int64'),
- reduction='none')
+ target.masked_select(neg).astype("int64"),
+ reduction="none",
+ )
n_neg = min(
- int(neg.astype('float32').sum().item()),
- int(self.ohem_ratio * n_pos.astype('float32')))
+ int(neg.astype("float32").sum().item()),
+ int(self.ohem_ratio * n_pos.astype("float32")),
+ )
else:
- loss_pos = paddle.to_tensor(0.)
+ loss_pos = paddle.to_tensor(0.0)
loss_neg = F.cross_entropy(
predict.masked_select(neg2).reshape([-1, 2]),
- target.masked_select(neg).astype('int64'),
- reduction='none')
+ target.masked_select(neg).astype("int64"),
+ reduction="none",
+ )
n_neg = 100
if len(loss_neg) > n_neg:
loss_neg, _ = paddle.topk(loss_neg, n_neg)
- return (loss_pos + loss_neg.sum()) / (n_pos + n_neg).astype('float32')
+ return (loss_pos + loss_neg.sum()) / (n_pos + n_neg).astype("float32")
def fourier2poly(self, real_maps, imag_maps):
"""Transform Fourier coefficient maps to polygon maps.
@@ -204,22 +223,16 @@ def fourier2poly(self, real_maps, imag_maps):
"""
k_vect = paddle.arange(
- -self.fourier_degree, self.fourier_degree + 1,
- dtype='float32').reshape([-1, 1])
- i_vect = paddle.arange(
- 0, self.num_sample, dtype='float32').reshape([1, -1])
-
- transform_matrix = 2 * np.pi / self.num_sample * paddle.matmul(k_vect,
- i_vect)
-
- x1 = paddle.einsum('ak, kn-> an', real_maps,
- paddle.cos(transform_matrix))
- x2 = paddle.einsum('ak, kn-> an', imag_maps,
- paddle.sin(transform_matrix))
- y1 = paddle.einsum('ak, kn-> an', real_maps,
- paddle.sin(transform_matrix))
- y2 = paddle.einsum('ak, kn-> an', imag_maps,
- paddle.cos(transform_matrix))
+ -self.fourier_degree, self.fourier_degree + 1, dtype="float32"
+ ).reshape([-1, 1])
+ i_vect = paddle.arange(0, self.num_sample, dtype="float32").reshape([1, -1])
+
+ transform_matrix = 2 * np.pi / self.num_sample * paddle.matmul(k_vect, i_vect)
+
+ x1 = paddle.einsum("ak, kn-> an", real_maps, paddle.cos(transform_matrix))
+ x2 = paddle.einsum("ak, kn-> an", imag_maps, paddle.sin(transform_matrix))
+ y1 = paddle.einsum("ak, kn-> an", real_maps, paddle.sin(transform_matrix))
+ y2 = paddle.einsum("ak, kn-> an", imag_maps, paddle.cos(transform_matrix))
x_maps = x1 - x2
y_maps = y1 + y2
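
Note on the fourier2poly hunk above: it evaluates an inverse Fourier series, recovering num_sample contour points from the 2k+1 complex coefficients via x_n + i*y_n = sum_k c_k * exp(2*pi*i*k*n/N); the einsum calls are plain matrix products against cos/sin phase tables. A minimal NumPy sketch of the same math (fourier2poly_np and the unit-circle check are illustrative, not part of the repo):

    import numpy as np

    def fourier2poly_np(real_maps, imag_maps, fourier_degree, num_sample):
        # real_maps / imag_maps: (A, 2*fourier_degree + 1) Fourier coefficients
        k = np.arange(-fourier_degree, fourier_degree + 1, dtype=np.float32).reshape(-1, 1)
        n = np.arange(0, num_sample, dtype=np.float32).reshape(1, -1)
        t = 2 * np.pi / num_sample * (k @ n)  # phase table, shape (2d+1, N)
        x = real_maps @ np.cos(t) - imag_maps @ np.sin(t)
        y = real_maps @ np.sin(t) + imag_maps @ np.cos(t)
        return x, y  # each (A, num_sample)

    # sanity check: a lone k = 1 coefficient traces the unit circle
    d, N = 5, 50
    real = np.zeros((1, 2 * d + 1), dtype=np.float32)
    real[0, d + 1] = 1.0  # index d + 1 corresponds to k = 1
    imag = np.zeros_like(real)
    x, y = fourier2poly_np(real, imag, d, N)
    assert np.allclose(x**2 + y**2, 1.0, atol=1e-5)
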
diff --git a/ppocr/losses/det_pse_loss.py b/ppocr/losses/det_pse_loss.py
index 6b31343ed4..8c4b2b4244 100644
--- a/ppocr/losses/det_pse_loss.py
+++ b/ppocr/losses/det_pse_loss.py
@@ -24,17 +24,18 @@
class PSELoss(nn.Layer):
- def __init__(self,
- alpha,
- ohem_ratio=3,
- kernel_sample_mask='pred',
- reduction='sum',
- eps=1e-6,
- **kwargs):
- """Implement PSE Loss.
- """
+ def __init__(
+ self,
+ alpha,
+ ohem_ratio=3,
+ kernel_sample_mask="pred",
+ reduction="sum",
+ eps=1e-6,
+ **kwargs
+ ):
+ """Implement PSE Loss."""
super(PSELoss, self).__init__()
- assert reduction in ['sum', 'mean', 'none']
+ assert reduction in ["sum", "mean", "none"]
self.alpha = alpha
self.ohem_ratio = ohem_ratio
self.kernel_sample_mask = kernel_sample_mask
@@ -42,7 +43,7 @@ def __init__(self,
self.eps = eps
def forward(self, outputs, labels):
- predicts = outputs['maps']
+ predicts = outputs["maps"]
predicts = F.interpolate(predicts, scale_factor=4)
texts = predicts[:, 0, :, :]
@@ -53,37 +54,36 @@ def forward(self, outputs, labels):
selected_masks = self.ohem_batch(texts, gt_texts, training_masks)
loss_text = self.dice_loss(texts, gt_texts, selected_masks)
- iou_text = iou((texts > 0).astype('int64'),
- gt_texts,
- training_masks,
- reduce=False)
+ iou_text = iou(
+ (texts > 0).astype("int64"), gt_texts, training_masks, reduce=False
+ )
losses = dict(loss_text=loss_text, iou_text=iou_text)
# kernel loss
loss_kernels = []
- if self.kernel_sample_mask == 'gt':
+ if self.kernel_sample_mask == "gt":
selected_masks = gt_texts * training_masks
- elif self.kernel_sample_mask == 'pred':
- selected_masks = (
- F.sigmoid(texts) > 0.5).astype('float32') * training_masks
+ elif self.kernel_sample_mask == "pred":
+ selected_masks = (F.sigmoid(texts) > 0.5).astype("float32") * training_masks
for i in range(kernels.shape[1]):
kernel_i = kernels[:, i, :, :]
gt_kernel_i = gt_kernels[:, i, :, :]
- loss_kernel_i = self.dice_loss(kernel_i, gt_kernel_i,
- selected_masks)
+ loss_kernel_i = self.dice_loss(kernel_i, gt_kernel_i, selected_masks)
loss_kernels.append(loss_kernel_i)
loss_kernels = paddle.mean(paddle.stack(loss_kernels, axis=1), axis=1)
- iou_kernel = iou((kernels[:, -1, :, :] > 0).astype('int64'),
- gt_kernels[:, -1, :, :],
- training_masks * gt_texts,
- reduce=False)
+ iou_kernel = iou(
+ (kernels[:, -1, :, :] > 0).astype("int64"),
+ gt_kernels[:, -1, :, :],
+ training_masks * gt_texts,
+ reduce=False,
+ )
losses.update(dict(loss_kernels=loss_kernels, iou_kernel=iou_kernel))
loss = self.alpha * loss_text + (1 - self.alpha) * loss_kernels
- losses['loss'] = loss
- if self.reduction == 'sum':
+ losses["loss"] = loss
+ if self.reduction == "sum":
losses = {x: paddle.sum(v) for x, v in losses.items()}
- elif self.reduction == 'mean':
+ elif self.reduction == "mean":
losses = {x: paddle.mean(v) for x, v in losses.items()}
return losses
@@ -104,26 +104,29 @@ def dice_loss(self, input, target, mask):
return 1 - d
def ohem_single(self, score, gt_text, training_mask, ohem_ratio=3):
- pos_num = int(paddle.sum((gt_text > 0.5).astype('float32'))) - int(
+ pos_num = int(paddle.sum((gt_text > 0.5).astype("float32"))) - int(
paddle.sum(
- paddle.logical_and((gt_text > 0.5), (training_mask <= 0.5))
- .astype('float32')))
+ paddle.logical_and((gt_text > 0.5), (training_mask <= 0.5)).astype(
+ "float32"
+ )
+ )
+ )
if pos_num == 0:
selected_mask = training_mask
selected_mask = selected_mask.reshape(
- [1, selected_mask.shape[0], selected_mask.shape[1]]).astype(
- 'float32')
+ [1, selected_mask.shape[0], selected_mask.shape[1]]
+ ).astype("float32")
return selected_mask
- neg_num = int(paddle.sum((gt_text <= 0.5).astype('float32')))
+ neg_num = int(paddle.sum((gt_text <= 0.5).astype("float32")))
neg_num = int(min(pos_num * ohem_ratio, neg_num))
if neg_num == 0:
selected_mask = training_mask
selected_mask = selected_mask.reshape(
- [1, selected_mask.shape[0], selected_mask.shape[1]]).astype(
- 'float32')
+ [1, selected_mask.shape[0], selected_mask.shape[1]]
+ ).astype("float32")
return selected_mask
neg_score = paddle.masked_select(score, gt_text <= 0.5)
@@ -132,18 +135,24 @@ def ohem_single(self, score, gt_text, training_mask, ohem_ratio=3):
selected_mask = paddle.logical_and(
paddle.logical_or((score >= threshold), (gt_text > 0.5)),
- (training_mask > 0.5))
+ (training_mask > 0.5),
+ )
selected_mask = selected_mask.reshape(
- [1, selected_mask.shape[0], selected_mask.shape[1]]).astype(
- 'float32')
+ [1, selected_mask.shape[0], selected_mask.shape[1]]
+ ).astype("float32")
return selected_mask
def ohem_batch(self, scores, gt_texts, training_masks, ohem_ratio=3):
selected_masks = []
for i in range(scores.shape[0]):
selected_masks.append(
- self.ohem_single(scores[i, :, :], gt_texts[i, :, :],
- training_masks[i, :, :], ohem_ratio))
-
- selected_masks = paddle.concat(selected_masks, 0).astype('float32')
+ self.ohem_single(
+ scores[i, :, :],
+ gt_texts[i, :, :],
+ training_masks[i, :, :],
+ ohem_ratio,
+ )
+ )
+
+ selected_masks = paddle.concat(selected_masks, 0).astype("float32")
return selected_masks
diff --git a/ppocr/losses/det_sast_loss.py b/ppocr/losses/det_sast_loss.py
index 2e0c756bd4..160eb06f4f 100644
--- a/ppocr/losses/det_sast_loss.py
+++ b/ppocr/losses/det_sast_loss.py
@@ -23,8 +23,7 @@
class SASTLoss(nn.Layer):
- """
- """
+ """ """
def __init__(self, eps=1e-6, **kwargs):
super(SASTLoss, self).__init__()
@@ -37,42 +36,43 @@ def forward(self, predicts, labels):
tcl_label: N x X list or LoDTensor
"""
- f_score = predicts['f_score']
- f_border = predicts['f_border']
- f_tvo = predicts['f_tvo']
- f_tco = predicts['f_tco']
+ f_score = predicts["f_score"]
+ f_border = predicts["f_border"]
+ f_tvo = predicts["f_tvo"]
+ f_tco = predicts["f_tco"]
l_score, l_border, l_mask, l_tvo, l_tco = labels[1:]
- #score_loss
+ # score_loss
intersection = paddle.sum(f_score * l_score * l_mask)
union = paddle.sum(f_score * l_mask) + paddle.sum(l_score * l_mask)
score_loss = 1.0 - 2 * intersection / (union + 1e-5)
- #border loss
+ # border loss
l_border_split, l_border_norm = paddle.split(
- l_border, num_or_sections=[4, 1], axis=1)
+ l_border, num_or_sections=[4, 1], axis=1
+ )
f_border_split = f_border
border_ex_shape = l_border_norm.shape * np.array([1, 4, 1, 1])
- l_border_norm_split = paddle.expand(
- x=l_border_norm, shape=border_ex_shape)
+ l_border_norm_split = paddle.expand(x=l_border_norm, shape=border_ex_shape)
l_border_score = paddle.expand(x=l_score, shape=border_ex_shape)
l_border_mask = paddle.expand(x=l_mask, shape=border_ex_shape)
border_diff = l_border_split - f_border_split
abs_border_diff = paddle.abs(border_diff)
border_sign = abs_border_diff < 1.0
- border_sign = paddle.cast(border_sign, dtype='float32')
+ border_sign = paddle.cast(border_sign, dtype="float32")
border_sign.stop_gradient = True
- border_in_loss = 0.5 * abs_border_diff * abs_border_diff * border_sign + \
- (abs_border_diff - 0.5) * (1.0 - border_sign)
+ border_in_loss = 0.5 * abs_border_diff * abs_border_diff * border_sign + (
+ abs_border_diff - 0.5
+ ) * (1.0 - border_sign)
border_out_loss = l_border_norm_split * border_in_loss
- border_loss = paddle.sum(border_out_loss * l_border_score * l_border_mask) / \
- (paddle.sum(l_border_score * l_border_mask) + 1e-5)
+ border_loss = paddle.sum(border_out_loss * l_border_score * l_border_mask) / (
+ paddle.sum(l_border_score * l_border_mask) + 1e-5
+ )
- #tvo_loss
- l_tvo_split, l_tvo_norm = paddle.split(
- l_tvo, num_or_sections=[8, 1], axis=1)
+ # tvo_loss
+ l_tvo_split, l_tvo_norm = paddle.split(l_tvo, num_or_sections=[8, 1], axis=1)
f_tvo_split = f_tvo
tvo_ex_shape = l_tvo_norm.shape * np.array([1, 8, 1, 1])
l_tvo_norm_split = paddle.expand(x=l_tvo_norm, shape=tvo_ex_shape)
@@ -82,17 +82,18 @@ def forward(self, predicts, labels):
tvo_geo_diff = l_tvo_split - f_tvo_split
abs_tvo_geo_diff = paddle.abs(tvo_geo_diff)
tvo_sign = abs_tvo_geo_diff < 1.0
- tvo_sign = paddle.cast(tvo_sign, dtype='float32')
+ tvo_sign = paddle.cast(tvo_sign, dtype="float32")
tvo_sign.stop_gradient = True
- tvo_in_loss = 0.5 * abs_tvo_geo_diff * abs_tvo_geo_diff * tvo_sign + \
- (abs_tvo_geo_diff - 0.5) * (1.0 - tvo_sign)
+ tvo_in_loss = 0.5 * abs_tvo_geo_diff * abs_tvo_geo_diff * tvo_sign + (
+ abs_tvo_geo_diff - 0.5
+ ) * (1.0 - tvo_sign)
tvo_out_loss = l_tvo_norm_split * tvo_in_loss
- tvo_loss = paddle.sum(tvo_out_loss * l_tvo_score * l_tvo_mask) / \
- (paddle.sum(l_tvo_score * l_tvo_mask) + 1e-5)
+ tvo_loss = paddle.sum(tvo_out_loss * l_tvo_score * l_tvo_mask) / (
+ paddle.sum(l_tvo_score * l_tvo_mask) + 1e-5
+ )
- #tco_loss
- l_tco_split, l_tco_norm = paddle.split(
- l_tco, num_or_sections=[2, 1], axis=1)
+ # tco_loss
+ l_tco_split, l_tco_norm = paddle.split(l_tco, num_or_sections=[2, 1], axis=1)
f_tco_split = f_tco
tco_ex_shape = l_tco_norm.shape * np.array([1, 2, 1, 1])
l_tco_norm_split = paddle.expand(x=l_tco_norm, shape=tco_ex_shape)
@@ -102,20 +103,31 @@ def forward(self, predicts, labels):
tco_geo_diff = l_tco_split - f_tco_split
abs_tco_geo_diff = paddle.abs(tco_geo_diff)
tco_sign = abs_tco_geo_diff < 1.0
- tco_sign = paddle.cast(tco_sign, dtype='float32')
+ tco_sign = paddle.cast(tco_sign, dtype="float32")
tco_sign.stop_gradient = True
- tco_in_loss = 0.5 * abs_tco_geo_diff * abs_tco_geo_diff * tco_sign + \
- (abs_tco_geo_diff - 0.5) * (1.0 - tco_sign)
+ tco_in_loss = 0.5 * abs_tco_geo_diff * abs_tco_geo_diff * tco_sign + (
+ abs_tco_geo_diff - 0.5
+ ) * (1.0 - tco_sign)
tco_out_loss = l_tco_norm_split * tco_in_loss
- tco_loss = paddle.sum(tco_out_loss * l_tco_score * l_tco_mask) / \
- (paddle.sum(l_tco_score * l_tco_mask) + 1e-5)
+ tco_loss = paddle.sum(tco_out_loss * l_tco_score * l_tco_mask) / (
+ paddle.sum(l_tco_score * l_tco_mask) + 1e-5
+ )
# total loss
tvo_lw, tco_lw = 1.5, 1.5
score_lw, border_lw = 1.0, 1.0
- total_loss = score_loss * score_lw + border_loss * border_lw + \
- tvo_loss * tvo_lw + tco_loss * tco_lw
-
- losses = {'loss':total_loss, "score_loss":score_loss,\
- "border_loss":border_loss, 'tvo_loss':tvo_loss, 'tco_loss':tco_loss}
+ total_loss = (
+ score_loss * score_lw
+ + border_loss * border_lw
+ + tvo_loss * tvo_lw
+ + tco_loss * tco_lw
+ )
+
+ losses = {
+ "loss": total_loss,
+ "score_loss": score_loss,
+ "border_loss": border_loss,
+ "tvo_loss": tvo_loss,
+ "tco_loss": tco_loss,
+ }
return losses
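
All three SAST regression terms above (border, tvo, tco) use the same branchless smooth-L1 construction: a 0/1 "sign" mask selects 0.5 * d^2 inside |d| < 1 and |d| - 0.5 outside, avoiding data-dependent control flow on the GPU. A quick NumPy check of the identity (illustrative only):

    import numpy as np

    d = np.linspace(-3.0, 3.0, 13, dtype=np.float32)
    sign = (np.abs(d) < 1.0).astype(np.float32)
    branchless = 0.5 * np.abs(d) * np.abs(d) * sign + (np.abs(d) - 0.5) * (1.0 - sign)
    reference = np.where(np.abs(d) < 1.0, 0.5 * d * d, np.abs(d) - 0.5)
    assert np.allclose(branchless, reference)
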
diff --git a/ppocr/losses/distillation_loss.py b/ppocr/losses/distillation_loss.py
index 97a3373300..4d0f751c5f 100644
--- a/ppocr/losses/distillation_loss.py
+++ b/ppocr/losses/distillation_loss.py
@@ -1,16 +1,16 @@
-#copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve.
+# copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
-#Licensed under the Apache License, Version 2.0 (the "License");
-#you may not use this file except in compliance with the License.
-#You may obtain a copy of the License at
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
-#Unless required by applicable law or agreed to in writing, software
-#distributed under the License is distributed on an "AS IS" BASIS,
-#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-#See the License for the specific language governing permissions and
-#limitations under the License.
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
import paddle
import paddle.nn as nn
@@ -33,7 +33,7 @@ def _sum_loss(loss_dict):
if "loss" in loss_dict.keys():
return loss_dict
else:
- loss_dict["loss"] = 0.
+ loss_dict["loss"] = 0.0
for k, value in loss_dict.items():
if k == "loss":
continue
@@ -43,18 +43,19 @@ def _sum_loss(loss_dict):
class DistillationDMLLoss(DMLLoss):
- """
- """
-
- def __init__(self,
- model_name_pairs=[],
- act=None,
- use_log=False,
- key=None,
- multi_head=False,
- dis_head='ctc',
- maps_name=None,
- name="dml"):
+ """ """
+
+ def __init__(
+ self,
+ model_name_pairs=[],
+ act=None,
+ use_log=False,
+ key=None,
+ multi_head=False,
+ dis_head="ctc",
+ maps_name=None,
+ name="dml",
+ ):
super().__init__(act=act, use_log=use_log)
assert isinstance(model_name_pairs, list)
self.key = key
@@ -68,7 +69,8 @@ def _check_model_name_pairs(self, model_name_pairs):
if not isinstance(model_name_pairs, list):
return []
elif isinstance(model_name_pairs[0], list) and isinstance(
- model_name_pairs[0][0], str):
+ model_name_pairs[0][0], str
+ ):
return model_name_pairs
else:
return [model_name_pairs]
@@ -106,14 +108,14 @@ def forward(self, predicts, batch):
out2 = out2[self.key]
if self.maps_name is None:
if self.multi_head:
- loss = super().forward(out1[self.dis_head],
- out2[self.dis_head])
+ loss = super().forward(out1[self.dis_head], out2[self.dis_head])
else:
loss = super().forward(out1, out2)
if isinstance(loss, dict):
for key in loss:
- loss_dict["{}_{}_{}_{}".format(key, pair[0], pair[1],
- idx)] = loss[key]
+ loss_dict[
+ "{}_{}_{}_{}".format(key, pair[0], pair[1], idx)
+ ] = loss[key]
else:
loss_dict["{}_{}".format(self.name, idx)] = loss
else:
@@ -123,11 +125,15 @@ def forward(self, predicts, batch):
loss = super().forward(outs1[k], outs2[k])
if isinstance(loss, dict):
for key in loss:
- loss_dict["{}_{}_{}_{}_{}".format(key, pair[
- 0], pair[1], self.maps_name, idx)] = loss[key]
+ loss_dict[
+ "{}_{}_{}_{}_{}".format(
+ key, pair[0], pair[1], self.maps_name, idx
+ )
+ ] = loss[key]
else:
- loss_dict["{}_{}_{}".format(self.name, self.maps_name[
- _c], idx)] = loss
+ loss_dict[
+ "{}_{}_{}".format(self.name, self.maps_name[_c], idx)
+ ] = loss
loss_dict = _sum_loss(loss_dict)
@@ -135,16 +141,17 @@ def forward(self, predicts, batch):
class DistillationKLDivLoss(KLDivLoss):
- """
- """
-
- def __init__(self,
- model_name_pairs=[],
- key=None,
- multi_head=False,
- dis_head='ctc',
- maps_name=None,
- name="kl_div"):
+ """ """
+
+ def __init__(
+ self,
+ model_name_pairs=[],
+ key=None,
+ multi_head=False,
+ dis_head="ctc",
+ maps_name=None,
+ name="kl_div",
+ ):
super().__init__()
assert isinstance(model_name_pairs, list)
self.key = key
@@ -158,7 +165,8 @@ def _check_model_name_pairs(self, model_name_pairs):
if not isinstance(model_name_pairs, list):
return []
elif isinstance(model_name_pairs[0], list) and isinstance(
- model_name_pairs[0][0], str):
+ model_name_pairs[0][0], str
+ ):
return model_name_pairs
else:
return [model_name_pairs]
@@ -198,19 +206,21 @@ def forward(self, predicts, batch):
if self.multi_head:
# for nrtr dml loss
max_len = batch[3].max()
- tgt = batch[2][:, 1:2 + max_len]
+ tgt = batch[2][:, 1 : 2 + max_len]
tgt = tgt.reshape([-1])
non_pad_mask = paddle.not_equal(
- tgt, paddle.zeros(
- tgt.shape, dtype=tgt.dtype))
- loss = super().forward(out1[self.dis_head],
- out2[self.dis_head], non_pad_mask)
+ tgt, paddle.zeros(tgt.shape, dtype=tgt.dtype)
+ )
+ loss = super().forward(
+ out1[self.dis_head], out2[self.dis_head], non_pad_mask
+ )
else:
loss = super().forward(out1, out2)
if isinstance(loss, dict):
for key in loss:
- loss_dict["{}_{}_{}_{}".format(key, pair[0], pair[1],
- idx)] = loss[key]
+ loss_dict[
+ "{}_{}_{}_{}".format(key, pair[0], pair[1], idx)
+ ] = loss[key]
else:
loss_dict["{}_{}".format(self.name, idx)] = loss
else:
@@ -220,11 +230,15 @@ def forward(self, predicts, batch):
loss = super().forward(outs1[k], outs2[k])
if isinstance(loss, dict):
for key in loss:
- loss_dict["{}_{}_{}_{}_{}".format(key, pair[
- 0], pair[1], self.maps_name, idx)] = loss[key]
+ loss_dict[
+ "{}_{}_{}_{}_{}".format(
+ key, pair[0], pair[1], self.maps_name, idx
+ )
+ ] = loss[key]
else:
- loss_dict["{}_{}_{}".format(self.name, self.maps_name[
- _c], idx)] = loss
+ loss_dict[
+ "{}_{}_{}".format(self.name, self.maps_name[_c], idx)
+ ] = loss
loss_dict = _sum_loss(loss_dict)
@@ -232,19 +246,20 @@ def forward(self, predicts, batch):
class DistillationDKDLoss(DKDLoss):
- """
- """
-
- def __init__(self,
- model_name_pairs=[],
- key=None,
- multi_head=False,
- dis_head='ctc',
- maps_name=None,
- name="dkd",
- temperature=1.0,
- alpha=1.0,
- beta=1.0):
+ """ """
+
+ def __init__(
+ self,
+ model_name_pairs=[],
+ key=None,
+ multi_head=False,
+ dis_head="ctc",
+ maps_name=None,
+ name="dkd",
+ temperature=1.0,
+ alpha=1.0,
+ beta=1.0,
+ ):
super().__init__(temperature, alpha, beta)
assert isinstance(model_name_pairs, list)
self.key = key
@@ -258,7 +273,8 @@ def _check_model_name_pairs(self, model_name_pairs):
if not isinstance(model_name_pairs, list):
return []
elif isinstance(model_name_pairs[0], list) and isinstance(
- model_name_pairs[0][0], str):
+ model_name_pairs[0][0], str
+ ):
return model_name_pairs
else:
return [model_name_pairs]
@@ -299,24 +315,23 @@ def forward(self, predicts, batch):
if self.multi_head:
# for nrtr dml loss
max_len = batch[3].max()
- tgt = batch[2][:, 1:2 +
- max_len] # [batch_size, max_len + 1]
+ tgt = batch[2][:, 1 : 2 + max_len] # [batch_size, max_len + 1]
tgt = tgt.reshape([-1]) # batch_size * (max_len + 1)
non_pad_mask = paddle.not_equal(
- tgt, paddle.zeros(
- tgt.shape,
- dtype=tgt.dtype)) # batch_size * (max_len + 1)
+ tgt, paddle.zeros(tgt.shape, dtype=tgt.dtype)
+ ) # batch_size * (max_len + 1)
loss = super().forward(
- out1[self.dis_head], out2[self.dis_head], tgt,
- non_pad_mask) # [batch_size, max_len + 1, num_char]
+ out1[self.dis_head], out2[self.dis_head], tgt, non_pad_mask
+ ) # [batch_size, max_len + 1, num_char]
else:
loss = super().forward(out1, out2)
if isinstance(loss, dict):
for key in loss:
- loss_dict["{}_{}_{}_{}".format(key, pair[0], pair[1],
- idx)] = loss[key]
+ loss_dict[
+ "{}_{}_{}_{}".format(key, pair[0], pair[1], idx)
+ ] = loss[key]
else:
loss_dict["{}_{}".format(self.name, idx)] = loss
else:
@@ -326,11 +341,15 @@ def forward(self, predicts, batch):
loss = super().forward(outs1[k], outs2[k])
if isinstance(loss, dict):
for key in loss:
- loss_dict["{}_{}_{}_{}_{}".format(key, pair[
- 0], pair[1], self.maps_name, idx)] = loss[key]
+ loss_dict[
+ "{}_{}_{}_{}_{}".format(
+ key, pair[0], pair[1], self.maps_name, idx
+ )
+ ] = loss[key]
else:
- loss_dict["{}_{}_{}".format(self.name, self.maps_name[
- _c], idx)] = loss
+ loss_dict[
+ "{}_{}_{}".format(self.name, self.maps_name[_c], idx)
+ ] = loss
loss_dict = _sum_loss(loss_dict)
@@ -338,8 +357,7 @@ def forward(self, predicts, batch):
class DistillationNRTRDMLLoss(DistillationDMLLoss):
- """
- """
+ """ """
def forward(self, predicts, batch):
loss_dict = dict()
@@ -353,19 +371,21 @@ def forward(self, predicts, batch):
if self.multi_head:
# for nrtr dml loss
max_len = batch[3].max()
- tgt = batch[2][:, 1:2 + max_len]
+ tgt = batch[2][:, 1 : 2 + max_len]
tgt = tgt.reshape([-1])
non_pad_mask = paddle.not_equal(
- tgt, paddle.zeros(
- tgt.shape, dtype=tgt.dtype))
- loss = super().forward(out1[self.dis_head], out2[self.dis_head],
- non_pad_mask)
+ tgt, paddle.zeros(tgt.shape, dtype=tgt.dtype)
+ )
+ loss = super().forward(
+ out1[self.dis_head], out2[self.dis_head], non_pad_mask
+ )
else:
loss = super().forward(out1, out2)
if isinstance(loss, dict):
for key in loss:
- loss_dict["{}_{}_{}_{}".format(key, pair[0], pair[1],
- idx)] = loss[key]
+ loss_dict["{}_{}_{}_{}".format(key, pair[0], pair[1], idx)] = loss[
+ key
+ ]
else:
loss_dict["{}_{}".format(self.name, idx)] = loss
@@ -375,16 +395,17 @@ def forward(self, predicts, batch):
class DistillationKLDivLoss(KLDivLoss):
- """
- """
-
- def __init__(self,
- model_name_pairs=[],
- key=None,
- multi_head=False,
- dis_head='ctc',
- maps_name=None,
- name="kl_div"):
+ """ """
+
+ def __init__(
+ self,
+ model_name_pairs=[],
+ key=None,
+ multi_head=False,
+ dis_head="ctc",
+ maps_name=None,
+ name="kl_div",
+ ):
super().__init__()
assert isinstance(model_name_pairs, list)
self.key = key
@@ -398,7 +419,8 @@ def _check_model_name_pairs(self, model_name_pairs):
if not isinstance(model_name_pairs, list):
return []
elif isinstance(model_name_pairs[0], list) and isinstance(
- model_name_pairs[0][0], str):
+ model_name_pairs[0][0], str
+ ):
return model_name_pairs
else:
return [model_name_pairs]
@@ -438,19 +460,21 @@ def forward(self, predicts, batch):
if self.multi_head:
# for nrtr dml loss
max_len = batch[3].max()
- tgt = batch[2][:, 1:2 + max_len]
+ tgt = batch[2][:, 1 : 2 + max_len]
tgt = tgt.reshape([-1])
non_pad_mask = paddle.not_equal(
- tgt, paddle.zeros(
- tgt.shape, dtype=tgt.dtype))
- loss = super().forward(out1[self.dis_head],
- out2[self.dis_head], non_pad_mask)
+ tgt, paddle.zeros(tgt.shape, dtype=tgt.dtype)
+ )
+ loss = super().forward(
+ out1[self.dis_head], out2[self.dis_head], non_pad_mask
+ )
else:
loss = super().forward(out1, out2)
if isinstance(loss, dict):
for key in loss:
- loss_dict["{}_{}_{}_{}".format(key, pair[0], pair[1],
- idx)] = loss[key]
+ loss_dict[
+ "{}_{}_{}_{}".format(key, pair[0], pair[1], idx)
+ ] = loss[key]
else:
loss_dict["{}_{}".format(self.name, idx)] = loss
else:
@@ -460,11 +484,15 @@ def forward(self, predicts, batch):
loss = super().forward(outs1[k], outs2[k])
if isinstance(loss, dict):
for key in loss:
- loss_dict["{}_{}_{}_{}_{}".format(key, pair[
- 0], pair[1], self.maps_name, idx)] = loss[key]
+ loss_dict[
+ "{}_{}_{}_{}_{}".format(
+ key, pair[0], pair[1], self.maps_name, idx
+ )
+ ] = loss[key]
else:
- loss_dict["{}_{}_{}".format(self.name, self.maps_name[
- _c], idx)] = loss
+ loss_dict[
+ "{}_{}_{}".format(self.name, self.maps_name[_c], idx)
+ ] = loss
loss_dict = _sum_loss(loss_dict)
@@ -472,19 +500,20 @@ def forward(self, predicts, batch):
class DistillationDKDLoss(DKDLoss):
- """
- """
-
- def __init__(self,
- model_name_pairs=[],
- key=None,
- multi_head=False,
- dis_head='ctc',
- maps_name=None,
- name="dkd",
- temperature=1.0,
- alpha=1.0,
- beta=1.0):
+ """ """
+
+ def __init__(
+ self,
+ model_name_pairs=[],
+ key=None,
+ multi_head=False,
+ dis_head="ctc",
+ maps_name=None,
+ name="dkd",
+ temperature=1.0,
+ alpha=1.0,
+ beta=1.0,
+ ):
super().__init__(temperature, alpha, beta)
assert isinstance(model_name_pairs, list)
self.key = key
@@ -498,7 +527,8 @@ def _check_model_name_pairs(self, model_name_pairs):
if not isinstance(model_name_pairs, list):
return []
elif isinstance(model_name_pairs[0], list) and isinstance(
- model_name_pairs[0][0], str):
+ model_name_pairs[0][0], str
+ ):
return model_name_pairs
else:
return [model_name_pairs]
@@ -539,24 +569,23 @@ def forward(self, predicts, batch):
if self.multi_head:
# for nrtr dml loss
max_len = batch[3].max()
- tgt = batch[2][:, 1:2 +
- max_len] # [batch_size, max_len + 1]
+ tgt = batch[2][:, 1 : 2 + max_len] # [batch_size, max_len + 1]
tgt = tgt.reshape([-1]) # batch_size * (max_len + 1)
non_pad_mask = paddle.not_equal(
- tgt, paddle.zeros(
- tgt.shape,
- dtype=tgt.dtype)) # batch_size * (max_len + 1)
+ tgt, paddle.zeros(tgt.shape, dtype=tgt.dtype)
+ ) # batch_size * (max_len + 1)
loss = super().forward(
- out1[self.dis_head], out2[self.dis_head], tgt,
- non_pad_mask) # [batch_size, max_len + 1, num_char]
+ out1[self.dis_head], out2[self.dis_head], tgt, non_pad_mask
+ ) # [batch_size, max_len + 1, num_char]
else:
loss = super().forward(out1, out2)
if isinstance(loss, dict):
for key in loss:
- loss_dict["{}_{}_{}_{}".format(key, pair[0], pair[1],
- idx)] = loss[key]
+ loss_dict[
+ "{}_{}_{}_{}".format(key, pair[0], pair[1], idx)
+ ] = loss[key]
else:
loss_dict["{}_{}".format(self.name, idx)] = loss
else:
@@ -566,11 +595,15 @@ def forward(self, predicts, batch):
loss = super().forward(outs1[k], outs2[k])
if isinstance(loss, dict):
for key in loss:
- loss_dict["{}_{}_{}_{}_{}".format(key, pair[
- 0], pair[1], self.maps_name, idx)] = loss[key]
+ loss_dict[
+ "{}_{}_{}_{}_{}".format(
+ key, pair[0], pair[1], self.maps_name, idx
+ )
+ ] = loss[key]
else:
- loss_dict["{}_{}_{}".format(self.name, self.maps_name[
- _c], idx)] = loss
+ loss_dict[
+ "{}_{}_{}".format(self.name, self.maps_name[_c], idx)
+ ] = loss
loss_dict = _sum_loss(loss_dict)
@@ -578,11 +611,7 @@ def forward(self, predicts, batch):
class DistillationCTCLoss(CTCLoss):
- def __init__(self,
- model_name_list=[],
- key=None,
- multi_head=False,
- name="loss_ctc"):
+ def __init__(self, model_name_list=[], key=None, multi_head=False, name="loss_ctc"):
super().__init__()
self.model_name_list = model_name_list
self.key = key
@@ -596,27 +625,23 @@ def forward(self, predicts, batch):
if self.key is not None:
out = out[self.key]
if self.multi_head:
- assert 'ctc' in out, 'multi head has multi out'
- loss = super().forward(out['ctc'], batch[:2] + batch[3:])
+ assert "ctc" in out, "multi head has multi out"
+ loss = super().forward(out["ctc"], batch[:2] + batch[3:])
else:
loss = super().forward(out, batch)
if isinstance(loss, dict):
for key in loss:
- loss_dict["{}_{}_{}".format(self.name, model_name,
- idx)] = loss[key]
+ loss_dict["{}_{}_{}".format(self.name, model_name, idx)] = loss[key]
else:
loss_dict["{}_{}".format(self.name, model_name)] = loss
return loss_dict
class DistillationSARLoss(SARLoss):
- def __init__(self,
- model_name_list=[],
- key=None,
- multi_head=False,
- name="loss_sar",
- **kwargs):
- ignore_index = kwargs.get('ignore_index', 92)
+ def __init__(
+ self, model_name_list=[], key=None, multi_head=False, name="loss_sar", **kwargs
+ ):
+ ignore_index = kwargs.get("ignore_index", 92)
super().__init__(ignore_index=ignore_index)
self.model_name_list = model_name_list
self.key = key
@@ -630,27 +655,28 @@ def forward(self, predicts, batch):
if self.key is not None:
out = out[self.key]
if self.multi_head:
- assert 'sar' in out, 'multi head has multi out'
- loss = super().forward(out['sar'], batch[:1] + batch[2:])
+ assert "sar" in out, "multi head has multi out"
+ loss = super().forward(out["sar"], batch[:1] + batch[2:])
else:
loss = super().forward(out, batch)
if isinstance(loss, dict):
for key in loss:
- loss_dict["{}_{}_{}".format(self.name, model_name,
- idx)] = loss[key]
+ loss_dict["{}_{}_{}".format(self.name, model_name, idx)] = loss[key]
else:
loss_dict["{}_{}".format(self.name, model_name)] = loss
return loss_dict
class DistillationNRTRLoss(CELoss):
- def __init__(self,
- model_name_list=[],
- key=None,
- multi_head=False,
- smoothing=True,
- name="loss_nrtr",
- **kwargs):
+ def __init__(
+ self,
+ model_name_list=[],
+ key=None,
+ multi_head=False,
+ smoothing=True,
+ name="loss_nrtr",
+ **kwargs
+ ):
super().__init__(smoothing=smoothing)
self.model_name_list = model_name_list
self.key = key
@@ -664,30 +690,31 @@ def forward(self, predicts, batch):
if self.key is not None:
out = out[self.key]
if self.multi_head:
- assert 'gtc' in out, 'multi head has multi out'
- loss = super().forward(out['gtc'], batch[:1] + batch[2:])
+ assert "gtc" in out, "multi head has multi out"
+ loss = super().forward(out["gtc"], batch[:1] + batch[2:])
else:
loss = super().forward(out, batch)
if isinstance(loss, dict):
for key in loss:
- loss_dict["{}_{}_{}".format(self.name, model_name,
- idx)] = loss[key]
+ loss_dict["{}_{}_{}".format(self.name, model_name, idx)] = loss[key]
else:
loss_dict["{}_{}".format(self.name, model_name)] = loss
return loss_dict
class DistillationDBLoss(DBLoss):
- def __init__(self,
- model_name_list=[],
- balance_loss=True,
- main_loss_type='DiceLoss',
- alpha=5,
- beta=10,
- ohem_ratio=3,
- eps=1e-6,
- name="db",
- **kwargs):
+ def __init__(
+ self,
+ model_name_list=[],
+ balance_loss=True,
+ main_loss_type="DiceLoss",
+ alpha=5,
+ beta=10,
+ ohem_ratio=3,
+ eps=1e-6,
+ name="db",
+ **kwargs
+ ):
super().__init__()
self.model_name_list = model_name_list
self.name = name
@@ -715,16 +742,18 @@ def forward(self, predicts, batch):
class DistillationDilaDBLoss(DBLoss):
- def __init__(self,
- model_name_pairs=[],
- key=None,
- balance_loss=True,
- main_loss_type='DiceLoss',
- alpha=5,
- beta=10,
- ohem_ratio=3,
- eps=1e-6,
- name="dila_dbloss"):
+ def __init__(
+ self,
+ model_name_pairs=[],
+ key=None,
+ balance_loss=True,
+ main_loss_type="DiceLoss",
+ alpha=5,
+ beta=10,
+ ohem_ratio=3,
+ eps=1e-6,
+ name="dila_dbloss",
+ ):
super().__init__()
self.model_name_pairs = model_name_pairs
self.name = name
@@ -747,21 +776,28 @@ def forward(self, predicts, batch):
th_shrink_maps = tch_preds[:, 0, :, :]
if hasattr(paddle.Tensor, "contiguous"):
th_shrink_maps = th_shrink_maps.contiguous()
- th_shrink_maps = th_shrink_maps.numpy() > 0.3 # thresh = 0.3
+ th_shrink_maps = th_shrink_maps.numpy() > 0.3 # thresh = 0.3
dilate_maps = np.zeros_like(th_shrink_maps).astype(np.float32)
for i in range(th_shrink_maps.shape[0]):
dilate_maps[i] = cv2.dilate(
- th_shrink_maps[i, :, :].astype(np.uint8), dilation_w)
+ th_shrink_maps[i, :, :].astype(np.uint8), dilation_w
+ )
th_shrink_maps = paddle.to_tensor(dilate_maps)
- label_threshold_map, label_threshold_mask, label_shrink_map, label_shrink_mask = batch[
- 1:]
+ (
+ label_threshold_map,
+ label_threshold_mask,
+ label_shrink_map,
+ label_shrink_mask,
+ ) = batch[1:]
# calculate the shrink map loss
bce_loss = self.alpha * self.bce_loss(
- stu_shrink_maps, th_shrink_maps, label_shrink_mask)
- loss_binary_maps = self.dice_loss(stu_binary_maps, th_shrink_maps,
- label_shrink_mask)
+ stu_shrink_maps, th_shrink_maps, label_shrink_mask
+ )
+ loss_binary_maps = self.dice_loss(
+ stu_binary_maps, th_shrink_maps, label_shrink_mask
+ )
# k = f"{self.name}_{pair[0]}_{pair[1]}"
k = "{}_{}_{}".format(self.name, pair[0], pair[1])
@@ -772,15 +808,11 @@ def forward(self, predicts, batch):
class DistillationDistanceLoss(DistanceLoss):
- """
- """
+ """ """
- def __init__(self,
- mode="l2",
- model_name_pairs=[],
- key=None,
- name="loss_distance",
- **kargs):
+ def __init__(
+ self, mode="l2", model_name_pairs=[], key=None, name="loss_distance", **kargs
+ ):
super().__init__(mode=mode, **kargs)
assert isinstance(model_name_pairs, list)
self.key = key
@@ -798,20 +830,14 @@ def forward(self, predicts, batch):
loss = super().forward(out1, out2)
if isinstance(loss, dict):
for key in loss:
- loss_dict["{}_{}_{}".format(self.name, key, idx)] = loss[
- key]
+ loss_dict["{}_{}_{}".format(self.name, key, idx)] = loss[key]
else:
- loss_dict["{}_{}_{}_{}".format(self.name, pair[0], pair[1],
- idx)] = loss
+ loss_dict["{}_{}_{}_{}".format(self.name, pair[0], pair[1], idx)] = loss
return loss_dict
class DistillationVQASerTokenLayoutLMLoss(VQASerTokenLayoutLMLoss):
- def __init__(self,
- num_classes,
- model_name_list=[],
- key=None,
- name="loss_ser"):
+ def __init__(self, num_classes, model_name_list=[], key=None, name="loss_ser"):
super().__init__(num_classes=num_classes)
self.model_name_list = model_name_list
self.key = key
@@ -829,12 +855,14 @@ def forward(self, predicts, batch):
class DistillationLossFromOutput(LossFromOutput):
- def __init__(self,
- reduction="none",
- model_name_list=[],
- dist_key=None,
- key="loss",
- name="loss_re"):
+ def __init__(
+ self,
+ reduction="none",
+ model_name_list=[],
+ dist_key=None,
+ key="loss",
+ name="loss_re",
+ ):
super().__init__(key=key, reduction=reduction)
self.model_name_list = model_name_list
self.name = name
@@ -852,16 +880,17 @@ def forward(self, predicts, batch):
class DistillationSERDMLLoss(DMLLoss):
- """
- """
-
- def __init__(self,
- act="softmax",
- use_log=True,
- num_classes=7,
- model_name_pairs=[],
- key=None,
- name="loss_dml_ser"):
+ """ """
+
+ def __init__(
+ self,
+ act="softmax",
+ use_log=True,
+ num_classes=7,
+ model_name_pairs=[],
+ key=None,
+ name="loss_dml_ser",
+ ):
super().__init__(act=act, use_log=use_log)
assert isinstance(model_name_pairs, list)
self.key = key
@@ -882,24 +911,32 @@ def forward(self, predicts, batch):
attention_mask = batch[2]
if attention_mask is not None:
- active_output = attention_mask.reshape([-1, ]) == 1
+ active_output = (
+ attention_mask.reshape(
+ [
+ -1,
+ ]
+ )
+ == 1
+ )
out1 = out1[active_output]
out2 = out2[active_output]
- loss_dict["{}_{}".format(self.name, idx)] = super().forward(out1,
- out2)
+ loss_dict["{}_{}".format(self.name, idx)] = super().forward(out1, out2)
return loss_dict
class DistillationVQADistanceLoss(DistanceLoss):
- def __init__(self,
- mode="l2",
- model_name_pairs=[],
- key=None,
- index=None,
- name="loss_distance",
- **kargs):
+ def __init__(
+ self,
+ mode="l2",
+ model_name_pairs=[],
+ key=None,
+ index=None,
+ name="loss_distance",
+ **kargs
+ ):
super().__init__(mode=mode, **kargs)
assert isinstance(model_name_pairs, list)
self.key = key
@@ -926,18 +963,23 @@ def forward(self, predicts, batch):
out1 = out1.reshape([-1, out1.shape[-1]])
out2 = out2.reshape([-1, out2.shape[-1]])
if attention_mask is not None:
- active_output = attention_mask.reshape([-1, ]) == 1
+ active_output = (
+ attention_mask.reshape(
+ [
+ -1,
+ ]
+ )
+ == 1
+ )
out1 = out1[active_output]
out2 = out2[active_output]
loss = super().forward(out1, out2)
if isinstance(loss, dict):
for key in loss:
- loss_dict["{}_{}nohu_{}".format(self.name, key,
- idx)] = loss[key]
+ loss_dict["{}_{}nohu_{}".format(self.name, key, idx)] = loss[key]
else:
- loss_dict["{}_{}_{}_{}".format(self.name, pair[0], pair[1],
- idx)] = loss
+ loss_dict["{}_{}_{}_{}".format(self.name, pair[0], pair[1], idx)] = loss
return loss_dict
@@ -958,7 +1000,8 @@ def __init__(self, temperature=0.5, alpha=1.0, beta=1.0):
def kl_loss(self, p1, p2): # predict, label
loss = paddle.multiply(
- p2, paddle.log((p2 + self.eps) / (p1 + self.eps) + self.eps))
+ p2, paddle.log((p2 + self.eps) / (p1 + self.eps) + self.eps)
+ )
bs = loss.shape[0]
loss = paddle.sum(loss) / bs
return loss
@@ -970,7 +1013,6 @@ def _cat_mask(self, t, mask1, mask2):
return rt
def multi_label_mask(self, targets):
-
targets = targets.astype("int32")
res = F.one_hot(targets, num_classes=11465)
mask = paddle.clip(paddle.sum(res, axis=1), 0, 1)
@@ -978,7 +1020,6 @@ def multi_label_mask(self, targets):
return mask
def forward(self, logits_student, logits_teacher, targets, mask=None):
-
gt_mask = self.multi_label_mask(targets)
other_mask = paddle.ones_like(gt_mask) - gt_mask
@@ -997,9 +1038,11 @@ def forward(self, logits_student, logits_teacher, targets, mask=None):
gt_mask_ex = paddle.expand_as(gt_mask.unsqueeze(axis=1), logits_teacher)
pred_teacher_part2 = F.softmax(
- logits_teacher / self.temperature - 1000.0 * gt_mask_ex, axis=-1)
+ logits_teacher / self.temperature - 1000.0 * gt_mask_ex, axis=-1
+ )
pred_student_part2 = F.softmax(
- logits_student / self.temperature - 1000.0 * gt_mask_ex, axis=-1)
+ logits_student / self.temperature - 1000.0 * gt_mask_ex, axis=-1
+ )
# differents with dkd
pred_teacher_part2 = paddle.mean(pred_teacher_part2, axis=1)
pred_student_part2 = paddle.mean(pred_student_part2, axis=1)
@@ -1011,7 +1054,7 @@ def forward(self, logits_student, logits_teacher, targets, mask=None):
class KLCTCLogits(nn.Layer):
- def __init__(self, weight=1.0, reduction='mean', mode="mean"):
+ def __init__(self, weight=1.0, reduction="mean", mode="mean"):
super().__init__()
self.weight = weight
self.reduction = reduction
@@ -1024,13 +1067,13 @@ def __init__(self, weight=1.0, reduction='mean', mode="mean"):
def kl_loss(self, p1, p2): # predict, label
loss = paddle.multiply(
- p2, paddle.log((p2 + self.eps) / (p1 + self.eps) + self.eps))
+ p2, paddle.log((p2 + self.eps) / (p1 + self.eps) + self.eps)
+ )
bs = loss.shape[0]
loss = paddle.sum(loss) / bs
return loss
def forward_meanmax(self, stu_out, tea_out):
-
stu_out = paddle.mean(F.softmax(stu_out / self.t, axis=-1), axis=1)
tea_out = paddle.mean(F.softmax(tea_out / self.t, axis=-1), axis=1)
loss = self.kl_loss(stu_out, tea_out)
@@ -1103,18 +1146,15 @@ def forward_log(self, out1, out2):
# for recognition distillation, log is needed for feature map
log_out1 = paddle.log(out1)
log_out2 = paddle.log(out2)
- loss = (
- self._kldiv(log_out1, out2) + self._kldiv(log_out2, out1)) / 2.0
+ loss = (self._kldiv(log_out1, out2) + self._kldiv(log_out2, out1)) / 2.0
return loss
class DistillCTCLogits(KLCTCLogits):
- def __init__(self,
- model_name_pairs=[],
- key=None,
- name="ctc_logits",
- reduction="mean"):
+ def __init__(
+ self, model_name_pairs=[], key=None, name="ctc_logits", reduction="mean"
+ ):
super().__init__(reduction=reduction)
self.model_name_pairs = self._check_model_name_pairs(model_name_pairs)
self.key = key
@@ -1124,7 +1164,8 @@ def _check_model_name_pairs(self, model_name_pairs):
if not isinstance(model_name_pairs, list):
return []
elif isinstance(model_name_pairs[0], list) and isinstance(
- model_name_pairs[0][0], str):
+ model_name_pairs[0][0], str
+ ):
return model_name_pairs
else:
return [model_name_pairs]
@@ -1136,15 +1177,14 @@ def forward(self, predicts, batch):
out2 = predicts[pair[1]]
if self.key is not None:
- out1 = out1[self.key]['ctc']
- out2 = out2[self.key]['ctc']
+ out1 = out1[self.key]["ctc"]
+ out2 = out2[self.key]["ctc"]
ctc_label = batch[1]
loss = super().forward(out1, out2, ctc_label)
if isinstance(loss, dict):
for key in loss:
- loss_dict["{}_{}_{}".format(self.name, model_name,
- idx)] = loss[key]
+ loss_dict["{}_{}_{}".format(self.name, model_name, idx)] = loss[key]
else:
loss_dict["{}_{}".format(self.name, idx)] = loss
return loss_dict
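
Most of the distillation_loss.py hunks above are formatting-only; the underlying DML objective is the symmetric KL divergence between student and teacher distributions, averaged over both directions exactly as in DMLLoss.forward_log. A minimal NumPy sketch of that core (dml_np is a hypothetical name; the repo's DMLLoss additionally handles activations, log mode, and dict outputs):

    import numpy as np

    def dml_np(p1, p2, eps=1e-12):
        # symmetric KL: (KL(p2 || p1) + KL(p1 || p2)) / 2, averaged over the batch
        kl_21 = np.sum(p2 * (np.log(p2 + eps) - np.log(p1 + eps)), axis=-1)
        kl_12 = np.sum(p1 * (np.log(p1 + eps) - np.log(p2 + eps)), axis=-1)
        return float(np.mean((kl_21 + kl_12) / 2.0))

    p1 = np.array([[0.7, 0.2, 0.1]])
    p2 = np.array([[0.6, 0.3, 0.1]])
    print(dml_np(p1, p2))  # small positive value; exactly 0 when p1 == p2
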
diff --git a/ppocr/losses/e2e_pg_loss.py b/ppocr/losses/e2e_pg_loss.py
index aff67b7ce3..8b667f0b2a 100644
--- a/ppocr/losses/e2e_pg_loss.py
+++ b/ppocr/losses/e2e_pg_loss.py
@@ -24,13 +24,9 @@
class PGLoss(nn.Layer):
- def __init__(self,
- tcl_bs,
- max_text_length,
- max_text_nums,
- pad_num,
- eps=1e-6,
- **kwargs):
+ def __init__(
+ self, tcl_bs, max_text_length, max_text_nums, pad_num, eps=1e-6, **kwargs
+ ):
super(PGLoss, self).__init__()
self.tcl_bs = tcl_bs
self.max_text_nums = max_text_nums
@@ -40,11 +36,11 @@ def __init__(self,
def border_loss(self, f_border, l_border, l_score, l_mask):
l_border_split, l_border_norm = paddle.tensor.split(
- l_border, num_or_sections=[4, 1], axis=1)
+ l_border, num_or_sections=[4, 1], axis=1
+ )
f_border_split = f_border
b, c, h, w = l_border_norm.shape
- l_border_norm_split = paddle.expand(
- x=l_border_norm, shape=[b, 4 * c, h, w])
+ l_border_norm_split = paddle.expand(x=l_border_norm, shape=[b, 4 * c, h, w])
b, c, h, w = l_score.shape
l_border_score = paddle.expand(x=l_score, shape=[b, 4 * c, h, w])
b, c, h, w = l_mask.shape
@@ -52,22 +48,26 @@ def border_loss(self, f_border, l_border, l_score, l_mask):
border_diff = l_border_split - f_border_split
abs_border_diff = paddle.abs(border_diff)
border_sign = abs_border_diff < 1.0
- border_sign = paddle.cast(border_sign, dtype='float32')
+ border_sign = paddle.cast(border_sign, dtype="float32")
border_sign.stop_gradient = True
- border_in_loss = 0.5 * abs_border_diff * abs_border_diff * border_sign + \
- (abs_border_diff - 0.5) * (1.0 - border_sign)
+ border_in_loss = 0.5 * abs_border_diff * abs_border_diff * border_sign + (
+ abs_border_diff - 0.5
+ ) * (1.0 - border_sign)
border_out_loss = l_border_norm_split * border_in_loss
- border_loss = paddle.sum(border_out_loss * l_border_score * l_border_mask) / \
- (paddle.sum(l_border_score * l_border_mask) + 1e-5)
+ border_loss = paddle.sum(border_out_loss * l_border_score * l_border_mask) / (
+ paddle.sum(l_border_score * l_border_mask) + 1e-5
+ )
return border_loss
def direction_loss(self, f_direction, l_direction, l_score, l_mask):
l_direction_split, l_direction_norm = paddle.tensor.split(
- l_direction, num_or_sections=[2, 1], axis=1)
+ l_direction, num_or_sections=[2, 1], axis=1
+ )
f_direction_split = f_direction
b, c, h, w = l_direction_norm.shape
l_direction_norm_split = paddle.expand(
- x=l_direction_norm, shape=[b, 2 * c, h, w])
+ x=l_direction_norm, shape=[b, 2 * c, h, w]
+ )
b, c, h, w = l_score.shape
l_direction_score = paddle.expand(x=l_score, shape=[b, 2 * c, h, w])
b, c, h, w = l_mask.shape
@@ -75,13 +75,16 @@ def direction_loss(self, f_direction, l_direction, l_score, l_mask):
direction_diff = l_direction_split - f_direction_split
abs_direction_diff = paddle.abs(direction_diff)
direction_sign = abs_direction_diff < 1.0
- direction_sign = paddle.cast(direction_sign, dtype='float32')
+ direction_sign = paddle.cast(direction_sign, dtype="float32")
direction_sign.stop_gradient = True
- direction_in_loss = 0.5 * abs_direction_diff * abs_direction_diff * direction_sign + \
- (abs_direction_diff - 0.5) * (1.0 - direction_sign)
+ direction_in_loss = (
+ 0.5 * abs_direction_diff * abs_direction_diff * direction_sign
+ + (abs_direction_diff - 0.5) * (1.0 - direction_sign)
+ )
direction_out_loss = l_direction_norm_split * direction_in_loss
- direction_loss = paddle.sum(direction_out_loss * l_direction_score * l_direction_mask) / \
- (paddle.sum(l_direction_score * l_direction_mask) + 1e-5)
+ direction_loss = paddle.sum(
+ direction_out_loss * l_direction_score * l_direction_mask
+ ) / (paddle.sum(l_direction_score * l_direction_mask) + 1e-5)
return direction_loss
def ctcloss(self, f_char, tcl_pos, tcl_mask, tcl_label, label_t):
@@ -90,52 +93,73 @@ def ctcloss(self, f_char, tcl_pos, tcl_mask, tcl_label, label_t):
tcl_pos = paddle.cast(tcl_pos, dtype=int)
f_tcl_char = paddle.gather_nd(f_char, tcl_pos)
f_tcl_char = paddle.reshape(
- f_tcl_char, [-1, 64, self.pad_num + 1]) # len(Lexicon_Table)+1
+ f_tcl_char, [-1, 64, self.pad_num + 1]
+ ) # len(Lexicon_Table)+1
f_tcl_char_fg, f_tcl_char_bg = paddle.split(
- f_tcl_char, [self.pad_num, 1], axis=2)
+ f_tcl_char, [self.pad_num, 1], axis=2
+ )
f_tcl_char_bg = f_tcl_char_bg * tcl_mask + (1.0 - tcl_mask) * 20.0
b, c, l = tcl_mask.shape
tcl_mask_fg = paddle.expand(x=tcl_mask, shape=[b, c, self.pad_num * l])
tcl_mask_fg.stop_gradient = True
- f_tcl_char_fg = f_tcl_char_fg * tcl_mask_fg + (1.0 - tcl_mask_fg) * (
- -20.0)
+ f_tcl_char_fg = f_tcl_char_fg * tcl_mask_fg + (1.0 - tcl_mask_fg) * (-20.0)
f_tcl_char_mask = paddle.concat([f_tcl_char_fg, f_tcl_char_bg], axis=2)
f_tcl_char_ld = paddle.transpose(f_tcl_char_mask, (1, 0, 2))
N, B, _ = f_tcl_char_ld.shape
- input_lengths = paddle.to_tensor([N] * B, dtype='int64')
+ input_lengths = paddle.to_tensor([N] * B, dtype="int64")
cost = paddle.nn.functional.ctc_loss(
log_probs=f_tcl_char_ld,
labels=tcl_label,
input_lengths=input_lengths,
label_lengths=label_t,
blank=self.pad_num,
- reduction='none')
+ reduction="none",
+ )
cost = cost.mean()
return cost
def forward(self, predicts, labels):
- images, tcl_maps, tcl_label_maps, border_maps \
- , direction_maps, training_masks, label_list, pos_list, pos_mask = labels
+ (
+ images,
+ tcl_maps,
+ tcl_label_maps,
+ border_maps,
+ direction_maps,
+ training_masks,
+ label_list,
+ pos_list,
+ pos_mask,
+ ) = labels
# for all the batch_size
pos_list, pos_mask, label_list, label_t = pre_process(
- label_list, pos_list, pos_mask, self.max_text_length,
- self.max_text_nums, self.pad_num, self.tcl_bs)
+ label_list,
+ pos_list,
+ pos_mask,
+ self.max_text_length,
+ self.max_text_nums,
+ self.pad_num,
+ self.tcl_bs,
+ )
- f_score, f_border, f_direction, f_char = predicts['f_score'], predicts['f_border'], predicts['f_direction'], \
- predicts['f_char']
+ f_score, f_border, f_direction, f_char = (
+ predicts["f_score"],
+ predicts["f_border"],
+ predicts["f_direction"],
+ predicts["f_char"],
+ )
score_loss = self.dice_loss(f_score, tcl_maps, training_masks)
- border_loss = self.border_loss(f_border, border_maps, tcl_maps,
- training_masks)
- direction_loss = self.direction_loss(f_direction, direction_maps,
- tcl_maps, training_masks)
+ border_loss = self.border_loss(f_border, border_maps, tcl_maps, training_masks)
+ direction_loss = self.direction_loss(
+ f_direction, direction_maps, tcl_maps, training_masks
+ )
ctc_loss = self.ctcloss(f_char, pos_list, pos_mask, label_list, label_t)
loss_all = score_loss + border_loss + direction_loss + 5 * ctc_loss
losses = {
- 'loss': loss_all,
+ "loss": loss_all,
"score_loss": score_loss,
"border_loss": border_loss,
"direction_loss": direction_loss,
- "ctc_loss": ctc_loss
+ "ctc_loss": ctc_loss,
}
return losses
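
PGLoss.ctcloss above gathers per-point character logits into a time-major (T, B, C) tensor, pushes foreground/background logits toward -20/+20 through the TCL mask, and hands the result to Paddle's CTC loss with the padding class as the blank. A minimal sketch of that ctc_loss shape convention (sizes are illustrative and assume a recent Paddle build):

    import paddle
    import paddle.nn.functional as F

    T, B, C = 64, 2, 37  # illustrative sizes; pad_num would be C - 1
    logits = paddle.randn([T, B, C])  # time-major, like f_tcl_char_ld above
    labels = paddle.randint(0, C - 1, [B, 10], dtype="int32")
    input_lengths = paddle.to_tensor([T] * B, dtype="int64")
    label_lengths = paddle.to_tensor([10] * B, dtype="int64")
    cost = F.ctc_loss(logits, labels, input_lengths, label_lengths,
                      blank=C - 1, reduction="none")
    print(cost.mean())  # PGLoss likewise averages the per-sample costs
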
diff --git a/ppocr/losses/kie_sdmgr_loss.py b/ppocr/losses/kie_sdmgr_loss.py
index 745671f58d..685a1412cb 100644
--- a/ppocr/losses/kie_sdmgr_loss.py
+++ b/ppocr/losses/kie_sdmgr_loss.py
@@ -37,9 +37,7 @@ def pre_process(self, gts, tag):
batch = len(tag)
for i in range(batch):
num, recoder_len = tag[i][0], tag[i][1]
- temp_gts.append(
- paddle.to_tensor(
- gts[i, :num, :num + 1], dtype='int64'))
+ temp_gts.append(paddle.to_tensor(gts[i, :num, : num + 1], dtype="int64"))
return temp_gts
def accuracy(self, pred, target, topk=1, thresh=None):
@@ -63,28 +61,28 @@ def accuracy(self, pred, target, topk=1, thresh=None):
"""
assert isinstance(topk, (int, tuple))
if isinstance(topk, int):
- topk = (topk, )
+ topk = (topk,)
return_single = True
else:
return_single = False
maxk = max(topk)
if pred.shape[0] == 0:
- accu = [pred.new_tensor(0.) for i in range(len(topk))]
+ accu = [pred.new_tensor(0.0) for i in range(len(topk))]
return accu[0] if return_single else accu
pred_value, pred_label = paddle.topk(pred, maxk, axis=1)
- pred_label = pred_label.transpose(
- [1, 0]) # transpose to shape (maxk, N)
- correct = paddle.equal(pred_label,
- (target.reshape([1, -1]).expand_as(pred_label)))
+ pred_label = pred_label.transpose([1, 0]) # transpose to shape (maxk, N)
+ correct = paddle.equal(
+ pred_label, (target.reshape([1, -1]).expand_as(pred_label))
+ )
res = []
for k in topk:
- correct_k = paddle.sum(correct[:k].reshape([-1]).astype('float32'),
- axis=0,
- keepdim=True)
+ correct_k = paddle.sum(
+ correct[:k].reshape([-1]).astype("float32"), axis=0, keepdim=True
+ )
res.append(
- paddle.multiply(correct_k,
- paddle.to_tensor(100.0 / pred.shape[0])))
+ paddle.multiply(correct_k, paddle.to_tensor(100.0 / pred.shape[0]))
+ )
return res[0] if return_single else res
def forward(self, pred, batch):
@@ -109,7 +107,10 @@ def forward(self, pred, batch):
loss_edge=loss_edge,
acc_node=self.accuracy(
paddle.gather(node_preds, node_valids),
- paddle.gather(node_gts, node_valids)),
+ paddle.gather(node_gts, node_valids),
+ ),
acc_edge=self.accuracy(
paddle.gather(edge_preds, edge_valids),
- paddle.gather(edge_gts, edge_valids)))
+ paddle.gather(edge_gts, edge_valids),
+ ),
+ )
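
SDMGRLoss.accuracy above is a standard top-k accuracy: take the k highest-scoring classes per node or edge and count a hit whenever the ground truth is among them. A NumPy sketch (topk_accuracy_np is a hypothetical name):

    import numpy as np

    def topk_accuracy_np(pred, target, topk=(1,)):
        maxk = max(topk)
        top_idx = np.argsort(-pred, axis=1)[:, :maxk]   # best classes first
        hits = top_idx == target.reshape(-1, 1)         # (N, maxk) membership
        return [100.0 * hits[:, :k].any(axis=1).mean() for k in topk]

    pred = np.array([[0.1, 0.8, 0.1], [0.5, 0.2, 0.3]])
    target = np.array([1, 2])
    print(topk_accuracy_np(pred, target, topk=(1, 2)))  # [50.0, 100.0]
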
diff --git a/ppocr/losses/rec_aster_loss.py b/ppocr/losses/rec_aster_loss.py
index 9b0a34eeac..eff53d2906 100644
--- a/ppocr/losses/rec_aster_loss.py
+++ b/ppocr/losses/rec_aster_loss.py
@@ -21,34 +21,37 @@
class CosineEmbeddingLoss(nn.Layer):
- def __init__(self, margin=0.):
+ def __init__(self, margin=0.0):
super(CosineEmbeddingLoss, self).__init__()
self.margin = margin
self.epsilon = 1e-12
def forward(self, x1, x2, target):
- similarity = paddle.sum(
- x1 * x2, axis=-1) / (paddle.norm(
- x1, axis=-1) * paddle.norm(
- x2, axis=-1) + self.epsilon)
+ similarity = paddle.sum(x1 * x2, axis=-1) / (
+ paddle.norm(x1, axis=-1) * paddle.norm(x2, axis=-1) + self.epsilon
+ )
one_list = paddle.full_like(target, fill_value=1)
out = paddle.mean(
paddle.where(
- paddle.equal(target, one_list), 1. - similarity,
- paddle.maximum(
- paddle.zeros_like(similarity), similarity - self.margin)))
+ paddle.equal(target, one_list),
+ 1.0 - similarity,
+ paddle.maximum(paddle.zeros_like(similarity), similarity - self.margin),
+ )
+ )
return out
class AsterLoss(nn.Layer):
- def __init__(self,
- weight=None,
- size_average=True,
- ignore_index=-100,
- sequence_normalize=False,
- sample_normalize=True,
- **kwargs):
+ def __init__(
+ self,
+ weight=None,
+ size_average=True,
+ ignore_index=-100,
+ sequence_normalize=False,
+ sample_normalize=True,
+ **kwargs
+ ):
super(AsterLoss, self).__init__()
self.weight = weight
self.size_average = size_average
@@ -57,28 +60,29 @@ def __init__(self,
self.sample_normalize = sample_normalize
self.loss_sem = CosineEmbeddingLoss()
self.is_cosin_loss = True
- self.loss_func_rec = nn.CrossEntropyLoss(weight=None, reduction='none')
+ self.loss_func_rec = nn.CrossEntropyLoss(weight=None, reduction="none")
def forward(self, predicts, batch):
targets = batch[1].astype("int64")
- label_lengths = batch[2].astype('int64')
- sem_target = batch[3].astype('float32')
- embedding_vectors = predicts['embedding_vectors']
- rec_pred = predicts['rec_pred']
+ label_lengths = batch[2].astype("int64")
+ sem_target = batch[3].astype("float32")
+ embedding_vectors = predicts["embedding_vectors"]
+ rec_pred = predicts["rec_pred"]
if not self.is_cosin_loss:
sem_loss = paddle.sum(self.loss_sem(embedding_vectors, sem_target))
else:
label_target = paddle.ones([embedding_vectors.shape[0]])
sem_loss = paddle.sum(
- self.loss_sem(embedding_vectors, sem_target, label_target))
+ self.loss_sem(embedding_vectors, sem_target, label_target)
+ )
# rec loss
batch_size, def_max_length = targets.shape[0], targets.shape[1]
mask = paddle.zeros([batch_size, def_max_length])
for i in range(batch_size):
- mask[i, :label_lengths[i]] = 1
+ mask[i, : label_lengths[i]] = 1
mask = paddle.cast(mask, "float32")
max_length = max(label_lengths)
assert max_length == rec_pred.shape[1]
@@ -96,4 +100,4 @@ def forward(self, predicts, batch):
output = output / batch_size
loss = output + sem_loss * 0.1
- return {'loss': loss}
+ return {"loss": loss}
diff --git a/ppocr/losses/rec_att_loss.py b/ppocr/losses/rec_att_loss.py
index 6e2f67483c..e0f65d94e2 100644
--- a/ppocr/losses/rec_att_loss.py
+++ b/ppocr/losses/rec_att_loss.py
@@ -23,17 +23,21 @@
class AttentionLoss(nn.Layer):
def __init__(self, **kwargs):
super(AttentionLoss, self).__init__()
- self.loss_func = nn.CrossEntropyLoss(weight=None, reduction='none')
+ self.loss_func = nn.CrossEntropyLoss(weight=None, reduction="none")
def forward(self, predicts, batch):
targets = batch[1].astype("int64")
- label_lengths = batch[2].astype('int64')
- batch_size, num_steps, num_classes = predicts.shape[0], predicts.shape[
- 1], predicts.shape[2]
- assert len(targets.shape) == len(list(predicts.shape)) - 1, \
- "The target's shape and inputs's shape is [N, d] and [N, num_steps]"
+ label_lengths = batch[2].astype("int64")
+ batch_size, num_steps, num_classes = (
+ predicts.shape[0],
+ predicts.shape[1],
+ predicts.shape[2],
+ )
+ assert (
+ len(targets.shape) == len(list(predicts.shape)) - 1
+ ), "The target's shape and inputs's shape is [N, d] and [N, num_steps]"
inputs = paddle.reshape(predicts, [-1, predicts.shape[-1]])
targets = paddle.reshape(targets, [-1])
- return {'loss': paddle.sum(self.loss_func(inputs, targets))}
+ return {"loss": paddle.sum(self.loss_func(inputs, targets))}
diff --git a/ppocr/losses/rec_can_loss.py b/ppocr/losses/rec_can_loss.py
index 227e17f5e1..6ec0b794b3 100644
--- a/ppocr/losses/rec_can_loss.py
+++ b/ppocr/losses/rec_can_loss.py
@@ -22,20 +22,23 @@
class CANLoss(nn.Layer):
- '''
+ """
CANLoss consists of two parts:
word_average_loss: average accuracy of the symbol
counting_loss: counting loss of every symbol
- '''
+ """
def __init__(self):
super(CANLoss, self).__init__()
self.use_label_mask = False
self.out_channel = 111
- self.cross = nn.CrossEntropyLoss(
- reduction='none') if self.use_label_mask else nn.CrossEntropyLoss()
- self.counting_loss = nn.SmoothL1Loss(reduction='mean')
+ self.cross = (
+ nn.CrossEntropyLoss(reduction="none")
+ if self.use_label_mask
+ else nn.CrossEntropyLoss()
+ )
+ self.counting_loss = nn.SmoothL1Loss(reduction="mean")
self.ratio = 16
def forward(self, preds, batch):
@@ -46,18 +49,24 @@ def forward(self, preds, batch):
labels = batch[2]
labels_mask = batch[3]
counting_labels = gen_counting_label(labels, self.out_channel, True)
- counting_loss = self.counting_loss(counting_preds1, counting_labels) + self.counting_loss(counting_preds2, counting_labels) \
- + self.counting_loss(counting_preds, counting_labels)
+ counting_loss = (
+ self.counting_loss(counting_preds1, counting_labels)
+ + self.counting_loss(counting_preds2, counting_labels)
+ + self.counting_loss(counting_preds, counting_labels)
+ )
word_loss = self.cross(
paddle.reshape(word_probs, [-1, word_probs.shape[-1]]),
- paddle.reshape(labels, [-1]))
- word_average_loss = paddle.sum(
- paddle.reshape(word_loss * labels_mask, [-1])) / (
- paddle.sum(labels_mask) + 1e-10
- ) if self.use_label_mask else word_loss
+ paddle.reshape(labels, [-1]),
+ )
+ word_average_loss = (
+ paddle.sum(paddle.reshape(word_loss * labels_mask, [-1]))
+ / (paddle.sum(labels_mask) + 1e-10)
+ if self.use_label_mask
+ else word_loss
+ )
loss = word_average_loss + counting_loss
- return {'loss': loss}
+ return {"loss": loss}
def gen_counting_label(labels, channel, tag):
@@ -75,5 +84,5 @@ def gen_counting_label(labels, channel, tag):
continue
else:
counting_labels[i][k] += 1
- counting_labels = paddle.to_tensor(counting_labels, dtype='float32')
+ counting_labels = paddle.to_tensor(counting_labels, dtype="float32")
return counting_labels
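
gen_counting_label above builds the counting target for CAN: a per-sample histogram over the symbol vocabulary, skipping ids that should not be counted. A NumPy sketch (gen_counting_label_np and the ignore set are illustrative; the repo skips a specific id range when the tag flag is set):

    import numpy as np

    def gen_counting_label_np(labels, channel, ignore=(0,)):
        # labels: (batch, seq_len) symbol ids; returns (batch, channel) counts
        counts = np.zeros((labels.shape[0], channel), dtype=np.float32)
        for i, row in enumerate(labels):
            for k in row:
                if k not in ignore:
                    counts[i, k] += 1
        return counts

    labels = np.array([[3, 3, 5, 0], [7, 0, 0, 0]])
    print(gen_counting_label_np(labels, channel=111)[0, 3])  # 2.0: id 3 occurs twice
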
diff --git a/ppocr/losses/rec_ce_loss.py b/ppocr/losses/rec_ce_loss.py
index 614384de86..45906fd195 100644
--- a/ppocr/losses/rec_ce_loss.py
+++ b/ppocr/losses/rec_ce_loss.py
@@ -4,22 +4,18 @@
class CELoss(nn.Layer):
- def __init__(self,
- smoothing=False,
- with_all=False,
- ignore_index=-1,
- **kwargs):
+ def __init__(self, smoothing=False, with_all=False, ignore_index=-1, **kwargs):
super(CELoss, self).__init__()
if ignore_index >= 0:
self.loss_func = nn.CrossEntropyLoss(
- reduction='mean', ignore_index=ignore_index)
+ reduction="mean", ignore_index=ignore_index
+ )
else:
- self.loss_func = nn.CrossEntropyLoss(reduction='mean')
+ self.loss_func = nn.CrossEntropyLoss(reduction="mean")
self.smoothing = smoothing
self.with_all = with_all
def forward(self, pred, batch):
-
if isinstance(pred, dict): # for ABINet
loss = {}
loss_sum = []
@@ -33,9 +29,9 @@ def forward(self, pred, batch):
else:
flt_logtis = logits.reshape([-1, logits.shape[2]])
flt_tgt = batch[1].reshape([-1])
- loss[name + '_loss'] = self.loss_func(flt_logtis, flt_tgt)
- loss_sum.append(loss[name + '_loss'])
- loss['loss'] = sum(loss_sum)
+ loss[name + "_loss"] = self.loss_func(flt_logtis, flt_tgt)
+ loss_sum.append(loss[name + "_loss"])
+ loss["loss"] = sum(loss_sum)
return loss
else:
if self.with_all: # for ViTSTR
@@ -43,24 +39,23 @@ def forward(self, pred, batch):
pred = pred.reshape([-1, pred.shape[2]])
tgt = tgt.reshape([-1])
loss = self.loss_func(pred, tgt)
- return {'loss': loss}
+ return {"loss": loss}
else: # for NRTR
max_len = batch[2].max()
- tgt = batch[1][:, 1:2 + max_len]
+ tgt = batch[1][:, 1 : 2 + max_len]
pred = pred.reshape([-1, pred.shape[2]])
tgt = tgt.reshape([-1])
if self.smoothing:
eps = 0.1
n_class = pred.shape[1]
one_hot = F.one_hot(tgt, pred.shape[1])
- one_hot = one_hot * (1 - eps) + (1 - one_hot) * eps / (
- n_class - 1)
+ one_hot = one_hot * (1 - eps) + (1 - one_hot) * eps / (n_class - 1)
log_prb = F.log_softmax(pred, axis=1)
non_pad_mask = paddle.not_equal(
- tgt, paddle.zeros(
- tgt.shape, dtype=tgt.dtype))
+ tgt, paddle.zeros(tgt.shape, dtype=tgt.dtype)
+ )
loss = -(one_hot * log_prb).sum(axis=1)
loss = loss.masked_select(non_pad_mask).mean()
else:
loss = self.loss_func(pred, tgt)
- return {'loss': loss}
+ return {"loss": loss}
diff --git a/ppocr/losses/rec_cppd_loss.py b/ppocr/losses/rec_cppd_loss.py
index 35abfd48d6..1e884ce1b3 100755
--- a/ppocr/losses/rec_cppd_loss.py
+++ b/ppocr/losses/rec_cppd_loss.py
@@ -18,32 +18,28 @@
class CPPDLoss(nn.Layer):
- def __init__(self,
- smoothing=False,
- ignore_index=100,
- sideloss_weight=1.0,
- **kwargs):
+ def __init__(
+ self, smoothing=False, ignore_index=100, sideloss_weight=1.0, **kwargs
+ ):
super(CPPDLoss, self).__init__()
- self.edge_ce = nn.CrossEntropyLoss(
- reduction='mean', ignore_index=ignore_index)
- self.char_node_ce = nn.CrossEntropyLoss(reduction='mean')
- self.pos_node_ce = nn.BCEWithLogitsLoss(reduction='mean')
+ self.edge_ce = nn.CrossEntropyLoss(reduction="mean", ignore_index=ignore_index)
+ self.char_node_ce = nn.CrossEntropyLoss(reduction="mean")
+ self.pos_node_ce = nn.BCEWithLogitsLoss(reduction="mean")
self.smoothing = smoothing
self.ignore_index = ignore_index
self.sideloss_weight = sideloss_weight
def label_smoothing_ce(self, preds, targets):
-
non_pad_mask = paddle.not_equal(
targets,
- paddle.zeros(
- targets.shape, dtype=targets.dtype) + self.ignore_index)
+ paddle.zeros(targets.shape, dtype=targets.dtype) + self.ignore_index,
+ )
tgts = paddle.where(
- targets == (paddle.zeros(
- targets.shape, dtype=targets.dtype) + self.ignore_index),
- paddle.zeros(
- targets.shape, dtype=targets.dtype),
- targets)
+ targets
+ == (paddle.zeros(targets.shape, dtype=targets.dtype) + self.ignore_index),
+ paddle.zeros(targets.shape, dtype=targets.dtype),
+ targets,
+ )
eps = 0.1
n_class = preds.shape[1]
one_hot = F.one_hot(tgts, preds.shape[1])
@@ -58,10 +54,12 @@ def forward(self, pred, batch):
node_tgt = batch[2]
char_tgt = batch[1]
- loss_char_node = self.char_node_ce(node_feats[0].flatten(0, 1),
- node_tgt[:, :-26].flatten(0, 1))
- loss_pos_node = self.pos_node_ce(node_feats[1].flatten(
- 0, 1), node_tgt[:, -26:].flatten(0, 1).cast('float32'))
+ loss_char_node = self.char_node_ce(
+ node_feats[0].flatten(0, 1), node_tgt[:, :-26].flatten(0, 1)
+ )
+ loss_pos_node = self.pos_node_ce(
+ node_feats[1].flatten(0, 1), node_tgt[:, -26:].flatten(0, 1).cast("float32")
+ )
loss_node = loss_char_node + loss_pos_node
edge_feats = edge_feats.flatten(0, 1)
@@ -72,7 +70,7 @@ def forward(self, pred, batch):
loss_edge = self.edge_ce(edge_feats, char_tgt)
return {
- 'loss': self.sideloss_weight * loss_node + loss_edge,
- 'loss_node': self.sideloss_weight * loss_node,
- 'loss_edge': loss_edge
+ "loss": self.sideloss_weight * loss_node + loss_edge,
+ "loss_node": self.sideloss_weight * loss_node,
+ "loss_edge": loss_edge,
}
diff --git a/ppocr/losses/rec_ctc_loss.py b/ppocr/losses/rec_ctc_loss.py
index 502fc8c522..c701fef298 100755
--- a/ppocr/losses/rec_ctc_loss.py
+++ b/ppocr/losses/rec_ctc_loss.py
@@ -23,7 +23,7 @@
class CTCLoss(nn.Layer):
def __init__(self, use_focal_loss=False, **kwargs):
super(CTCLoss, self).__init__()
- self.loss_func = nn.CTCLoss(blank=0, reduction='none')
+ self.loss_func = nn.CTCLoss(blank=0, reduction="none")
self.use_focal_loss = use_focal_loss
def forward(self, predicts, batch):
@@ -32,9 +32,10 @@ def forward(self, predicts, batch):
predicts = predicts.transpose((1, 0, 2))
N, B, _ = predicts.shape
preds_lengths = paddle.to_tensor(
- [N] * B, dtype='int64', place=paddle.CPUPlace())
+ [N] * B, dtype="int64", place=paddle.CPUPlace()
+ )
labels = batch[1].astype("int32")
- label_lengths = batch[2].astype('int64')
+ label_lengths = batch[2].astype("int64")
loss = self.loss_func(predicts, labels, preds_lengths, label_lengths)
if self.use_focal_loss:
weight = paddle.exp(-loss)
@@ -42,4 +43,4 @@ def forward(self, predicts, batch):
weight = paddle.square(weight)
loss = paddle.multiply(loss, weight)
loss = loss.mean()
- return {'loss': loss}
+ return {"loss": loss}
diff --git a/ppocr/losses/rec_enhanced_ctc_loss.py b/ppocr/losses/rec_enhanced_ctc_loss.py
index b57be6468e..ef88270392 100644
--- a/ppocr/losses/rec_enhanced_ctc_loss.py
+++ b/ppocr/losses/rec_enhanced_ctc_loss.py
@@ -24,17 +24,19 @@
class EnhancedCTCLoss(nn.Layer):
- def __init__(self,
- use_focal_loss=False,
- use_ace_loss=False,
- ace_loss_weight=0.1,
- use_center_loss=False,
- center_loss_weight=0.05,
- num_classes=6625,
- feat_dim=96,
- init_center=False,
- center_file_path=None,
- **kwargs):
+ def __init__(
+ self,
+ use_focal_loss=False,
+ use_ace_loss=False,
+ ace_loss_weight=0.1,
+ use_center_loss=False,
+ center_loss_weight=0.05,
+ num_classes=6625,
+ feat_dim=96,
+ init_center=False,
+ center_file_path=None,
+ **kwargs
+ ):
super(EnhancedCTCLoss, self).__init__()
self.ctc_loss_func = CTCLoss(use_focal_loss=use_focal_loss)
@@ -51,20 +53,24 @@ def __init__(self,
num_classes=num_classes,
feat_dim=feat_dim,
init_center=init_center,
- center_file_path=center_file_path)
+ center_file_path=center_file_path,
+ )
self.center_loss_weight = center_loss_weight
def __call__(self, predicts, batch):
loss = self.ctc_loss_func(predicts, batch)["loss"]
if self.use_center_loss:
- center_loss = self.center_loss_func(
- predicts, batch)["loss_center"] * self.center_loss_weight
+ center_loss = (
+ self.center_loss_func(predicts, batch)["loss_center"]
+ * self.center_loss_weight
+ )
loss = loss + center_loss
if self.use_ace_loss:
- ace_loss = self.ace_loss_func(
- predicts, batch)["loss_ace"] * self.ace_loss_weight
+ ace_loss = (
+ self.ace_loss_func(predicts, batch)["loss_ace"] * self.ace_loss_weight
+ )
loss = loss + ace_loss
- return {'enhanced_ctc_loss': loss}
+ return {"enhanced_ctc_loss": loss}
diff --git a/ppocr/losses/rec_multi_loss.py b/ppocr/losses/rec_multi_loss.py
index 4f9365750b..c19febe535 100644
--- a/ppocr/losses/rec_multi_loss.py
+++ b/ppocr/losses/rec_multi_loss.py
@@ -28,9 +28,9 @@ class MultiLoss(nn.Layer):
def __init__(self, **kwargs):
super().__init__()
self.loss_funcs = {}
- self.loss_list = kwargs.pop('loss_config_list')
- self.weight_1 = kwargs.get('weight_1', 1.0)
- self.weight_2 = kwargs.get('weight_2', 1.0)
+ self.loss_list = kwargs.pop("loss_config_list")
+ self.weight_1 = kwargs.get("weight_1", 1.0)
+ self.weight_2 = kwargs.get("weight_2", 1.0)
for loss_info in self.loss_list:
for name, param in loss_info.items():
if param is not None:
@@ -43,19 +43,26 @@ def forward(self, predicts, batch):
total_loss = 0.0
# batch [image, label_ctc, label_sar, length, valid_ratio]
for name, loss_func in self.loss_funcs.items():
- if name == 'CTCLoss':
- loss = loss_func(predicts['ctc'],
- batch[:2] + batch[3:])['loss'] * self.weight_1
- elif name == 'SARLoss':
- loss = loss_func(predicts['sar'],
- batch[:1] + batch[2:])['loss'] * self.weight_2
- elif name == 'NRTRLoss':
- loss = loss_func(predicts['nrtr'],
- batch[:1] + batch[2:])['loss'] * self.weight_2
+ if name == "CTCLoss":
+ loss = (
+ loss_func(predicts["ctc"], batch[:2] + batch[3:])["loss"]
+ * self.weight_1
+ )
+ elif name == "SARLoss":
+ loss = (
+ loss_func(predicts["sar"], batch[:1] + batch[2:])["loss"]
+ * self.weight_2
+ )
+ elif name == "NRTRLoss":
+ loss = (
+ loss_func(predicts["nrtr"], batch[:1] + batch[2:])["loss"]
+ * self.weight_2
+ )
else:
raise NotImplementedError(
- '{} is not supported in MultiLoss yet'.format(name))
+ "{} is not supported in MultiLoss yet".format(name)
+ )
self.total_loss[name] = loss
total_loss += loss
- self.total_loss['loss'] = total_loss
+ self.total_loss["loss"] = total_loss
return self.total_loss
diff --git a/ppocr/losses/rec_nrtr_loss.py b/ppocr/losses/rec_nrtr_loss.py
index fbd397fbf0..a4f56ded79 100644
--- a/ppocr/losses/rec_nrtr_loss.py
+++ b/ppocr/losses/rec_nrtr_loss.py
@@ -8,12 +8,13 @@ def __init__(self, smoothing=True, ignore_index=0, **kwargs):
super(NRTRLoss, self).__init__()
if ignore_index >= 0 and not smoothing:
self.loss_func = nn.CrossEntropyLoss(
- reduction='mean', ignore_index=ignore_index)
+ reduction="mean", ignore_index=ignore_index
+ )
self.smoothing = smoothing
def forward(self, pred, batch):
max_len = batch[2].max()
- tgt = batch[1][:, 1:2 + max_len]
+ tgt = batch[1][:, 1 : 2 + max_len]
pred = pred.reshape([-1, pred.shape[2]])
tgt = tgt.reshape([-1])
if self.smoothing:
@@ -23,10 +24,10 @@ def forward(self, pred, batch):
one_hot = one_hot * (1 - eps) + (1 - one_hot) * eps / (n_class - 1)
log_prb = F.log_softmax(pred, axis=1)
non_pad_mask = paddle.not_equal(
- tgt, paddle.zeros(
- tgt.shape, dtype=tgt.dtype))
+ tgt, paddle.zeros(tgt.shape, dtype=tgt.dtype)
+ )
loss = -(one_hot * log_prb).sum(axis=1)
loss = loss.masked_select(non_pad_mask).mean()
else:
loss = self.loss_func(pred, tgt)
- return {'loss': loss}
+ return {"loss": loss}
diff --git a/ppocr/losses/rec_parseq_loss.py b/ppocr/losses/rec_parseq_loss.py
index c2468b091a..a1731e8d02 100644
--- a/ppocr/losses/rec_parseq_loss.py
+++ b/ppocr/losses/rec_parseq_loss.py
@@ -30,9 +30,9 @@ def forward(self, predicts, targets):
max_step = paddle.max(label_len).cpu().numpy()[0] + 2
tgt = label[:, :max_step]
- logits_list = predicts['logits_list']
- pad_id = predicts['pad_id']
- eos_id = predicts['eos_id']
+ logits_list = predicts["logits_list"]
+ pad_id = predicts["pad_id"]
+ eos_id = predicts["eos_id"]
tgt_out = tgt[:, 1:]
loss = 0
@@ -40,11 +40,13 @@ def forward(self, predicts, targets):
n = (tgt_out != pad_id).sum().item()
for i, logits in enumerate(logits_list):
- loss += n * paddle.nn.functional.cross_entropy(input=logits, label=tgt_out.flatten(), ignore_index=pad_id)
+ loss += n * paddle.nn.functional.cross_entropy(
+ input=logits, label=tgt_out.flatten(), ignore_index=pad_id
+ )
loss_numel += n
if i == 1:
tgt_out = paddle.where(condition=tgt_out == eos_id, x=pad_id, y=tgt_out)
n = (tgt_out != pad_id).sum().item()
loss /= loss_numel
- return {'loss': loss}
+ return {"loss": loss}
diff --git a/ppocr/losses/rec_pren_loss.py b/ppocr/losses/rec_pren_loss.py
index 7bc53d29b2..d08b27a982 100644
--- a/ppocr/losses/rec_pren_loss.py
+++ b/ppocr/losses/rec_pren_loss.py
@@ -23,8 +23,8 @@ class PRENLoss(nn.Layer):
def __init__(self, **kwargs):
super(PRENLoss, self).__init__()
# note: 0 is padding idx
- self.loss_func = nn.CrossEntropyLoss(reduction='mean', ignore_index=0)
+ self.loss_func = nn.CrossEntropyLoss(reduction="mean", ignore_index=0)
def forward(self, predicts, batch):
- loss = self.loss_func(predicts, batch[1].astype('int64'))
- return {'loss': loss}
+ loss = self.loss_func(predicts, batch[1].astype("int64"))
+ return {"loss": loss}
diff --git a/ppocr/losses/rec_rfl_loss.py b/ppocr/losses/rec_rfl_loss.py
index be0f06d903..3af677f35b 100644
--- a/ppocr/losses/rec_rfl_loss.py
+++ b/ppocr/losses/rec_rfl_loss.py
@@ -33,7 +33,6 @@ def __init__(self, ignore_index=-100, **kwargs):
self.seq_loss = nn.CrossEntropyLoss(ignore_index=ignore_index)
def forward(self, predicts, batch):
-
self.total_loss = {}
total_loss = 0.0
if isinstance(predicts, tuple) or isinstance(predicts, list):
@@ -42,18 +41,21 @@ def forward(self, predicts, batch):
cnt_outputs, seq_outputs = predicts, None
# batch [image, label, length, cnt_label]
if cnt_outputs is not None:
- cnt_loss = self.cnt_loss(cnt_outputs,
- paddle.cast(batch[3], paddle.float32))
- self.total_loss['cnt_loss'] = cnt_loss
+ cnt_loss = self.cnt_loss(cnt_outputs, paddle.cast(batch[3], paddle.float32))
+ self.total_loss["cnt_loss"] = cnt_loss
total_loss += cnt_loss
if seq_outputs is not None:
targets = batch[1].astype("int64")
- label_lengths = batch[2].astype('int64')
- batch_size, num_steps, num_classes = seq_outputs.shape[
- 0], seq_outputs.shape[1], seq_outputs.shape[2]
- assert len(targets.shape) == len(list(seq_outputs.shape)) - 1, \
- "The target's shape and inputs's shape is [N, d] and [N, num_steps]"
+ label_lengths = batch[2].astype("int64")
+ batch_size, num_steps, num_classes = (
+ seq_outputs.shape[0],
+ seq_outputs.shape[1],
+ seq_outputs.shape[2],
+ )
+ assert (
+ len(targets.shape) == len(list(seq_outputs.shape)) - 1
+ ), "The target's shape and inputs's shape is [N, d] and [N, num_steps]"
inputs = seq_outputs[:, :-1, :]
targets = targets[:, 1:]
@@ -61,8 +63,8 @@ def forward(self, predicts, batch):
inputs = paddle.reshape(inputs, [-1, inputs.shape[-1]])
targets = paddle.reshape(targets, [-1])
seq_loss = self.seq_loss(inputs, targets)
- self.total_loss['seq_loss'] = seq_loss
+ self.total_loss["seq_loss"] = seq_loss
total_loss += seq_loss
- self.total_loss['loss'] = total_loss
+ self.total_loss["loss"] = total_loss
return self.total_loss
diff --git a/ppocr/losses/rec_sar_loss.py b/ppocr/losses/rec_sar_loss.py
index a4f83f03c0..a2cba690ff 100644
--- a/ppocr/losses/rec_sar_loss.py
+++ b/ppocr/losses/rec_sar_loss.py
@@ -9,21 +9,28 @@
class SARLoss(nn.Layer):
def __init__(self, **kwargs):
super(SARLoss, self).__init__()
- ignore_index = kwargs.get('ignore_index', 92) # 6626
+ ignore_index = kwargs.get("ignore_index", 92) # 6626
self.loss_func = paddle.nn.loss.CrossEntropyLoss(
- reduction="mean", ignore_index=ignore_index)
+ reduction="mean", ignore_index=ignore_index
+ )
def forward(self, predicts, batch):
- predict = predicts[:, :
- -1, :] # ignore last index of outputs to be in same seq_len with targets
- label = batch[1].astype(
- "int64")[:, 1:] # ignore first index of target in loss calculation
- batch_size, num_steps, num_classes = predict.shape[0], predict.shape[
- 1], predict.shape[2]
- assert len(label.shape) == len(list(predict.shape)) - 1, \
- "The target's shape and inputs's shape is [N, d] and [N, num_steps]"
+ predict = predicts[
+ :, :-1, :
+ ] # ignore last index of outputs to be in same seq_len with targets
+ label = batch[1].astype("int64")[
+ :, 1:
+ ] # ignore first index of target in loss calculation
+ batch_size, num_steps, num_classes = (
+ predict.shape[0],
+ predict.shape[1],
+ predict.shape[2],
+ )
+ assert (
+ len(label.shape) == len(list(predict.shape)) - 1
+ ), "The target's shape and inputs's shape is [N, d] and [N, num_steps]"
inputs = paddle.reshape(predict, [-1, num_classes])
targets = paddle.reshape(label, [-1])
loss = self.loss_func(inputs, targets)
- return {'loss': loss}
+ return {"loss": loss}
diff --git a/ppocr/losses/rec_satrn_loss.py b/ppocr/losses/rec_satrn_loss.py
index fc7b517878..b198693a90 100644
--- a/ppocr/losses/rec_satrn_loss.py
+++ b/ppocr/losses/rec_satrn_loss.py
@@ -26,21 +26,28 @@
class SATRNLoss(nn.Layer):
def __init__(self, **kwargs):
super(SATRNLoss, self).__init__()
- ignore_index = kwargs.get('ignore_index', 92) # 6626
+ ignore_index = kwargs.get("ignore_index", 92) # 6626
self.loss_func = paddle.nn.loss.CrossEntropyLoss(
- reduction="none", ignore_index=ignore_index)
+ reduction="none", ignore_index=ignore_index
+ )
def forward(self, predicts, batch):
- predict = predicts[:, :
- -1, :] # ignore last index of outputs to be in same seq_len with targets
- label = batch[1].astype(
- "int64")[:, 1:] # ignore first index of target in loss calculation
- batch_size, num_steps, num_classes = predict.shape[0], predict.shape[
- 1], predict.shape[2]
- assert len(label.shape) == len(list(predict.shape)) - 1, \
- "The target's shape and inputs's shape is [N, d] and [N, num_steps]"
+ predict = predicts[
+ :, :-1, :
+ ] # ignore last index of outputs to be in same seq_len with targets
+ label = batch[1].astype("int64")[
+ :, 1:
+ ] # ignore first index of target in loss calculation
+ batch_size, num_steps, num_classes = (
+ predict.shape[0],
+ predict.shape[1],
+ predict.shape[2],
+ )
+ assert (
+ len(label.shape) == len(list(predict.shape)) - 1
+ ), "The target's shape and inputs's shape is [N, d] and [N, num_steps]"
inputs = paddle.reshape(predict, [-1, num_classes])
targets = paddle.reshape(label, [-1])
loss = self.loss_func(inputs, targets)
- return {'loss': loss.mean()}
+ return {"loss": loss.mean()}
diff --git a/ppocr/losses/rec_spin_att_loss.py b/ppocr/losses/rec_spin_att_loss.py
index 195780c7bf..591ae0853c 100644
--- a/ppocr/losses/rec_spin_att_loss.py
+++ b/ppocr/losses/rec_spin_att_loss.py
@@ -20,26 +20,33 @@
import paddle
from paddle import nn
-'''This code is refer from:
+"""This code is refer from:
https://github.com/hikopensource/DAVAR-Lab-OCR
-'''
+"""
+
class SPINAttentionLoss(nn.Layer):
- def __init__(self, reduction='mean', ignore_index=-100, **kwargs):
+ def __init__(self, reduction="mean", ignore_index=-100, **kwargs):
super(SPINAttentionLoss, self).__init__()
- self.loss_func = nn.CrossEntropyLoss(weight=None, reduction=reduction, ignore_index=ignore_index)
+ self.loss_func = nn.CrossEntropyLoss(
+ weight=None, reduction=reduction, ignore_index=ignore_index
+ )
def forward(self, predicts, batch):
targets = batch[1].astype("int64")
- targets = targets[:, 1:] # remove [eos] in label
-
- label_lengths = batch[2].astype('int64')
- batch_size, num_steps, num_classes = predicts.shape[0], predicts.shape[
- 1], predicts.shape[2]
- assert len(targets.shape) == len(list(predicts.shape)) - 1, \
- "The target's shape and inputs's shape is [N, d] and [N, num_steps]"
+ targets = targets[:, 1:] # remove [eos] in label
+
+ label_lengths = batch[2].astype("int64")
+ batch_size, num_steps, num_classes = (
+ predicts.shape[0],
+ predicts.shape[1],
+ predicts.shape[2],
+ )
+ assert (
+ len(targets.shape) == len(list(predicts.shape)) - 1
+ ), "The target's shape and inputs's shape is [N, d] and [N, num_steps]"
inputs = paddle.reshape(predicts, [-1, predicts.shape[-1]])
targets = paddle.reshape(targets, [-1])
- return {'loss': self.loss_func(inputs, targets)}
+ return {"loss": self.loss_func(inputs, targets)}
diff --git a/ppocr/losses/rec_srn_loss.py b/ppocr/losses/rec_srn_loss.py
index 7d5b65ebaf..cb034f3ae8 100644
--- a/ppocr/losses/rec_srn_loss.py
+++ b/ppocr/losses/rec_srn_loss.py
@@ -26,12 +26,12 @@ def __init__(self, **kwargs):
self.loss_func = paddle.nn.loss.CrossEntropyLoss(reduction="sum")
def forward(self, predicts, batch):
- predict = predicts['predict']
- word_predict = predicts['word_out']
- gsrm_predict = predicts['gsrm_out']
+ predict = predicts["predict"]
+ word_predict = predicts["word_out"]
+ gsrm_predict = predicts["gsrm_out"]
label = batch[1]
- casted_label = paddle.cast(x=label, dtype='int64')
+ casted_label = paddle.cast(x=label, dtype="int64")
casted_label = paddle.reshape(x=casted_label, shape=[-1, 1])
cost_word = self.loss_func(word_predict, label=casted_label)
@@ -44,4 +44,4 @@ def forward(self, predicts, batch):
sum_cost = cost_word * 3.0 + cost_vsfd + cost_gsrm * 0.15
- return {'loss': sum_cost, 'word_loss': cost_word, 'img_loss': cost_vsfd}
+ return {"loss": sum_cost, "word_loss": cost_word, "img_loss": cost_vsfd}
diff --git a/ppocr/losses/rec_vl_loss.py b/ppocr/losses/rec_vl_loss.py
index 5cd87c709b..34c470e37f 100644
--- a/ppocr/losses/rec_vl_loss.py
+++ b/ppocr/losses/rec_vl_loss.py
@@ -25,10 +25,10 @@
class VLLoss(nn.Layer):
- def __init__(self, mode='LF_1', weight_res=0.5, weight_mas=0.5, **kwargs):
+ def __init__(self, mode="LF_1", weight_res=0.5, weight_mas=0.5, **kwargs):
super(VLLoss, self).__init__()
self.loss_func = paddle.nn.loss.CrossEntropyLoss(reduction="mean")
- assert mode in ['LF_1', 'LF_2', 'LA']
+ assert mode in ["LF_1", "LF_2", "LA"]
self.mode = mode
self.weight_res = weight_res
self.weight_mas = weight_mas
@@ -38,10 +38,10 @@ def flatten_label(self, target):
label_length = []
for i in range(0, target.shape[0]):
cur_label = target[i].tolist()
- label_flatten += cur_label[:cur_label.index(0) + 1]
+ label_flatten += cur_label[: cur_label.index(0) + 1]
label_length.append(cur_label.index(0) + 1)
- label_flatten = paddle.to_tensor(label_flatten, dtype='int64')
- label_length = paddle.to_tensor(label_length, dtype='int32')
+ label_flatten = paddle.to_tensor(label_flatten, dtype="int64")
+ label_length = paddle.to_tensor(label_length, dtype="int32")
return (label_flatten, label_length)
def _flatten(self, sources, lengths):
@@ -49,16 +49,16 @@ def _flatten(self, sources, lengths):
def forward(self, predicts, batch):
text_pre = predicts[0]
- target = batch[1].astype('int64')
+ target = batch[1].astype("int64")
label_flatten, length = self.flatten_label(target)
text_pre = self._flatten(text_pre, length)
- if self.mode == 'LF_1':
+ if self.mode == "LF_1":
loss = self.loss_func(text_pre, label_flatten)
else:
text_rem = predicts[1]
text_mas = predicts[2]
- target_res = batch[2].astype('int64')
- target_sub = batch[3].astype('int64')
+ target_res = batch[2].astype("int64")
+ target_sub = batch[3].astype("int64")
label_flatten_res, length_res = self.flatten_label(target_res)
label_flatten_sub, length_sub = self.flatten_label(target_sub)
text_rem = self._flatten(text_rem, length_res)
@@ -67,4 +67,4 @@ def forward(self, predicts, batch):
loss_res = self.loss_func(text_rem, label_flatten_res)
loss_mas = self.loss_func(text_mas, label_flatten_sub)
loss = loss_ori + loss_res * self.weight_res + loss_mas * self.weight_mas
- return {'loss': loss}
+ return {"loss": loss}
diff --git a/ppocr/losses/stroke_focus_loss.py b/ppocr/losses/stroke_focus_loss.py
index 002bbc3477..9b7850c6ba 100644
--- a/ppocr/losses/stroke_focus_loss.py
+++ b/ppocr/losses/stroke_focus_loss.py
@@ -31,13 +31,12 @@ def __init__(self, character_dict_path=None, **kwargs):
self.mse_loss = nn.MSELoss()
self.ce_loss = nn.CrossEntropyLoss()
self.l1_loss = nn.L1Loss()
- self.english_stroke_alphabet = '0123456789'
+ self.english_stroke_alphabet = "0123456789"
self.english_stroke_dict = {}
for index in range(len(self.english_stroke_alphabet)):
- self.english_stroke_dict[self.english_stroke_alphabet[
- index]] = index
+ self.english_stroke_dict[self.english_stroke_alphabet[index]] = index
- stroke_decompose_lines = open(character_dict_path, 'r').readlines()
+ stroke_decompose_lines = open(character_dict_path, "r").readlines()
self.dic = {}
for line in stroke_decompose_lines:
line = line.strip()
@@ -45,7 +44,6 @@ def __init__(self, character_dict_path=None, **kwargs):
self.dic[character] = sequence
def forward(self, pred, data):
-
sr_img = pred["sr_img"]
hr_img = pred["hr_img"]
@@ -56,13 +54,10 @@ def forward(self, pred, data):
hr_pred = pred["hr_pred"]
sr_pred = pred["sr_pred"]
- attention_loss = paddle.nn.functional.l1_loss(word_attention_map_gt,
- word_attention_map_pred)
+ attention_loss = paddle.nn.functional.l1_loss(
+ word_attention_map_gt, word_attention_map_pred
+ )
loss = (mse_loss + attention_loss * 50) * 100
- return {
- "mse_loss": mse_loss,
- "attention_loss": attention_loss,
- "loss": loss
- }
+ return {"mse_loss": mse_loss, "attention_loss": attention_loss, "loss": loss}
diff --git a/ppocr/losses/table_att_loss.py b/ppocr/losses/table_att_loss.py
index f1771847b4..51160a25db 100644
--- a/ppocr/losses/table_att_loss.py
+++ b/ppocr/losses/table_att_loss.py
@@ -24,48 +24,50 @@
class TableAttentionLoss(nn.Layer):
def __init__(self, structure_weight, loc_weight, **kwargs):
super(TableAttentionLoss, self).__init__()
- self.loss_func = nn.CrossEntropyLoss(weight=None, reduction='none')
+ self.loss_func = nn.CrossEntropyLoss(weight=None, reduction="none")
self.structure_weight = structure_weight
self.loc_weight = loc_weight
def forward(self, predicts, batch):
- structure_probs = predicts['structure_probs']
+ structure_probs = predicts["structure_probs"]
structure_targets = batch[1].astype("int64")
structure_targets = structure_targets[:, 1:]
- structure_probs = paddle.reshape(structure_probs,
- [-1, structure_probs.shape[-1]])
+ structure_probs = paddle.reshape(
+ structure_probs, [-1, structure_probs.shape[-1]]
+ )
structure_targets = paddle.reshape(structure_targets, [-1])
structure_loss = self.loss_func(structure_probs, structure_targets)
structure_loss = paddle.mean(structure_loss) * self.structure_weight
- loc_preds = predicts['loc_preds']
+ loc_preds = predicts["loc_preds"]
loc_targets = batch[2].astype("float32")
loc_targets_mask = batch[3].astype("float32")
loc_targets = loc_targets[:, 1:, :]
loc_targets_mask = loc_targets_mask[:, 1:, :]
- loc_loss = F.mse_loss(loc_preds * loc_targets_mask,
- loc_targets) * self.loc_weight
+ loc_loss = (
+ F.mse_loss(loc_preds * loc_targets_mask, loc_targets) * self.loc_weight
+ )
total_loss = structure_loss + loc_loss
return {
- 'loss': total_loss,
+ "loss": total_loss,
"structure_loss": structure_loss,
- "loc_loss": loc_loss
+ "loc_loss": loc_loss,
}
class SLALoss(nn.Layer):
- def __init__(self, structure_weight, loc_weight, loc_loss='mse', **kwargs):
+ def __init__(self, structure_weight, loc_weight, loc_loss="mse", **kwargs):
super(SLALoss, self).__init__()
- self.loss_func = nn.CrossEntropyLoss(weight=None, reduction='mean')
+ self.loss_func = nn.CrossEntropyLoss(weight=None, reduction="mean")
self.structure_weight = structure_weight
self.loc_weight = loc_weight
self.loc_loss = loc_loss
self.eps = 1e-12
def forward(self, predicts, batch):
- structure_probs = predicts['structure_probs']
+ structure_probs = predicts["structure_probs"]
structure_targets = batch[1].astype("int64")
structure_targets = structure_targets[:, 1:]
@@ -73,21 +75,25 @@ def forward(self, predicts, batch):
structure_loss = paddle.mean(structure_loss) * self.structure_weight
- loc_preds = predicts['loc_preds']
+ loc_preds = predicts["loc_preds"]
loc_targets = batch[2].astype("float32")
loc_targets_mask = batch[3].astype("float32")
loc_targets = loc_targets[:, 1:, :]
loc_targets_mask = loc_targets_mask[:, 1:, :]
- loc_loss = F.smooth_l1_loss(
- loc_preds * loc_targets_mask,
- loc_targets * loc_targets_mask,
- reduction='sum') * self.loc_weight
+ loc_loss = (
+ F.smooth_l1_loss(
+ loc_preds * loc_targets_mask,
+ loc_targets * loc_targets_mask,
+ reduction="sum",
+ )
+ * self.loc_weight
+ )
loc_loss = loc_loss / (loc_targets_mask.sum() + self.eps)
total_loss = structure_loss + loc_loss
return {
- 'loss': total_loss,
+ "loss": total_loss,
"structure_loss": structure_loss,
- "loc_loss": loc_loss
+ "loc_loss": loc_loss,
}
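
Note: SLALoss normalizes the summed smooth-L1 location error by the number of valid mask entries rather than by batch size, so padded structure steps do not dilute the gradient. The pattern, isolated:

import paddle.nn.functional as F

def masked_loc_loss(preds, targets, mask, weight, eps=1e-12):
    # Sum the error over real coordinates only, then average by the mask count.
    loss = F.smooth_l1_loss(preds * mask, targets * mask, reduction="sum") * weight
    return loss / (mask.sum() + eps)
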
diff --git a/ppocr/losses/table_master_loss.py b/ppocr/losses/table_master_loss.py
index dca982dbd4..08b1c29c44 100644
--- a/ppocr/losses/table_master_loss.py
+++ b/ppocr/losses/table_master_loss.py
@@ -24,17 +24,17 @@ class TableMasterLoss(nn.Layer):
def __init__(self, ignore_index=-1):
super(TableMasterLoss, self).__init__()
self.structure_loss = nn.CrossEntropyLoss(
- ignore_index=ignore_index, reduction='mean')
- self.box_loss = nn.L1Loss(reduction='sum')
+ ignore_index=ignore_index, reduction="mean"
+ )
+ self.box_loss = nn.L1Loss(reduction="sum")
self.eps = 1e-12
def forward(self, predicts, batch):
# structure_loss
- structure_probs = predicts['structure_probs']
+ structure_probs = predicts["structure_probs"]
structure_targets = batch[1]
structure_targets = structure_targets[:, 1:]
- structure_probs = structure_probs.reshape(
- [-1, structure_probs.shape[-1]])
+ structure_probs = structure_probs.reshape([-1, structure_probs.shape[-1]])
structure_targets = structure_targets.reshape([-1])
structure_loss = self.structure_loss(structure_probs, structure_targets)
@@ -42,7 +42,7 @@ def forward(self, predicts, batch):
losses = dict(structure_loss=structure_loss)
# box loss
- bboxes_preds = predicts['loc_preds']
+ bboxes_preds = predicts["loc_preds"]
bboxes_targets = batch[2][:, 1:, :]
bbox_masks = batch[3][:, 1:]
# mask empty-bbox or non-bbox structure token's bbox.
@@ -51,20 +51,24 @@ def forward(self, predicts, batch):
masked_bboxes_targets = bboxes_targets * bbox_masks
# horizon loss (x and width)
- horizon_sum_loss = self.box_loss(masked_bboxes_preds[:, :, 0::2],
- masked_bboxes_targets[:, :, 0::2])
+ horizon_sum_loss = self.box_loss(
+ masked_bboxes_preds[:, :, 0::2], masked_bboxes_targets[:, :, 0::2]
+ )
horizon_loss = horizon_sum_loss / (bbox_masks.sum() + self.eps)
# vertical loss (y and height)
- vertical_sum_loss = self.box_loss(masked_bboxes_preds[:, :, 1::2],
- masked_bboxes_targets[:, :, 1::2])
+ vertical_sum_loss = self.box_loss(
+ masked_bboxes_preds[:, :, 1::2], masked_bboxes_targets[:, :, 1::2]
+ )
vertical_loss = vertical_sum_loss / (bbox_masks.sum() + self.eps)
horizon_loss = horizon_loss.mean()
vertical_loss = vertical_loss.mean()
all_loss = structure_loss + horizon_loss + vertical_loss
- losses.update({
- 'loss': all_loss,
- 'horizon_bbox_loss': horizon_loss,
- 'vertical_bbox_loss': vertical_loss
- })
+ losses.update(
+ {
+ "loss": all_loss,
+ "horizon_bbox_loss": horizon_loss,
+ "vertical_bbox_loss": vertical_loss,
+ }
+ )
return losses
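
Note: the box branch splits coordinates by index parity -- even channels (0::2) carry x and width, odd channels (1::2) carry y and height -- and sums an L1 term for each, normalized by the bbox mask count. Isolated:

import paddle.nn as nn

def parity_box_loss(masked_preds, masked_targets, bbox_masks, eps=1e-12):
    box_loss = nn.L1Loss(reduction="sum")
    horizon = box_loss(masked_preds[:, :, 0::2], masked_targets[:, :, 0::2])
    vertical = box_loss(masked_preds[:, :, 1::2], masked_targets[:, :, 1::2])
    denom = bbox_masks.sum() + eps
    return horizon / denom, vertical / denom
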
diff --git a/ppocr/losses/text_focus_loss.py b/ppocr/losses/text_focus_loss.py
index b50628405b..310140b25c 100644
--- a/ppocr/losses/text_focus_loss.py
+++ b/ppocr/losses/text_focus_loss.py
@@ -21,14 +21,14 @@
import numpy as np
import pickle as pkl
-standard_alphebet = '-0123456789abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ'
+standard_alphebet = "-0123456789abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ"
standard_dict = {}
for index in range(len(standard_alphebet)):
standard_dict[standard_alphebet[index]] = index
def load_confuse_matrix(confuse_dict_path):
- f = open(confuse_dict_path, 'rb')
+ f = open(confuse_dict_path, "rb")
data = pkl.load(f)
f.close()
number = data[:10]
@@ -42,12 +42,14 @@ def load_confuse_matrix(confuse_dict_path):
rearrange_data[rearrange_data == np.inf] = 1
rearrange_data = paddle.to_tensor(rearrange_data)
- lower_alpha = 'abcdefghijklmnopqrstuvwxyz'
+ lower_alpha = "abcdefghijklmnopqrstuvwxyz"
# upper_alpha = 'ABCDEFGHIJKLMNOPQRSTUVWXYZ'
for i in range(63):
for j in range(63):
if i != j and standard_alphebet[j] in lower_alpha:
- rearrange_data[i][j] = max(rearrange_data[i][j], rearrange_data[i][j + 26])
+ rearrange_data[i][j] = max(
+ rearrange_data[i][j], rearrange_data[i][j + 26]
+ )
rearrange_data = rearrange_data[:37, :37]
return rearrange_data
@@ -60,7 +62,9 @@ def weight_cross_entropy(pred, gt, weight_table):
pred_exp_weight = weight * pred_exp
loss = 0
for i in range(len(gt)):
- loss -= paddle.log(pred_exp_weight[i][gt[i]] / paddle.sum(pred_exp_weight, 1)[i])
+ loss -= paddle.log(
+ pred_exp_weight[i][gt[i]] / paddle.sum(pred_exp_weight, 1)[i]
+ )
return loss / batch
@@ -84,8 +88,4 @@ def forward(self, pred, data):
attention_loss = self.l1_loss(word_attention_map_gt, word_attention_map_pred)
recognition_loss = weight_cross_entropy(sr_pred, text_gt, self.weight_table)
loss = mse_loss + attention_loss * 10 + recognition_loss * 0.0005
- return {
- "mse_loss": mse_loss,
- "attention_loss": attention_loss,
- "loss": loss
- }
+ return {"mse_loss": mse_loss, "attention_loss": attention_loss, "loss": loss}
diff --git a/ppocr/losses/vqa_token_layoutlm_loss.py b/ppocr/losses/vqa_token_layoutlm_loss.py
index 5d564c0e26..d01c091ac6 100755
--- a/ppocr/losses/vqa_token_layoutlm_loss.py
+++ b/ppocr/losses/vqa_token_layoutlm_loss.py
@@ -34,13 +34,28 @@ def forward(self, predicts, batch):
labels = batch[5]
attention_mask = batch[2]
if attention_mask is not None:
- active_loss = attention_mask.reshape([-1, ]) == 1
- active_output = predicts.reshape(
- [-1, self.num_classes])[active_loss]
- active_label = labels.reshape([-1, ])[active_loss]
+ active_loss = (
+ attention_mask.reshape(
+ [
+ -1,
+ ]
+ )
+ == 1
+ )
+ active_output = predicts.reshape([-1, self.num_classes])[active_loss]
+ active_label = labels.reshape(
+ [
+ -1,
+ ]
+ )[active_loss]
loss = self.loss_class(active_output, active_label)
else:
loss = self.loss_class(
predicts.reshape([-1, self.num_classes]),
- labels.reshape([-1, ]))
- return {'loss': loss}
\ No newline at end of file
+ labels.reshape(
+ [
+ -1,
+ ]
+ ),
+ )
+ return {"loss": loss}
diff --git a/ppocr/metrics/__init__.py b/ppocr/metrics/__init__.py
index 5e840a194a..9ab515fcb7 100644
--- a/ppocr/metrics/__init__.py
+++ b/ppocr/metrics/__init__.py
@@ -36,14 +36,26 @@
def build_metric(config):
support_dict = [
- "DetMetric", "DetFCEMetric", "RecMetric", "ClsMetric", "E2EMetric",
- "DistillationMetric", "TableMetric", 'KIEMetric', 'VQASerTokenMetric',
- 'VQAReTokenMetric', 'SRMetric', 'CTMetric', 'CNTMetric', 'CANMetric'
+ "DetMetric",
+ "DetFCEMetric",
+ "RecMetric",
+ "ClsMetric",
+ "E2EMetric",
+ "DistillationMetric",
+ "TableMetric",
+ "KIEMetric",
+ "VQASerTokenMetric",
+ "VQAReTokenMetric",
+ "SRMetric",
+ "CTMetric",
+ "CNTMetric",
+ "CANMetric",
]
config = copy.deepcopy(config)
module_name = config.pop("name")
assert module_name in support_dict, Exception(
- "metric only support {}".format(support_dict))
+ "metric only support {}".format(support_dict)
+ )
module_class = eval(module_name)(**config)
return module_class
diff --git a/ppocr/metrics/cls_metric.py b/ppocr/metrics/cls_metric.py
index 6c077518ce..820fe7574f 100644
--- a/ppocr/metrics/cls_metric.py
+++ b/ppocr/metrics/cls_metric.py
@@ -14,7 +14,7 @@
class ClsMetric(object):
- def __init__(self, main_indicator='acc', **kwargs):
+ def __init__(self, main_indicator="acc", **kwargs):
self.main_indicator = main_indicator
self.eps = 1e-5
self.reset()
@@ -29,7 +29,9 @@ def __call__(self, pred_label, *args, **kwargs):
all_num += 1
self.correct_num += correct_num
self.all_num += all_num
- return {'acc': correct_num / (all_num + self.eps), }
+ return {
+ "acc": correct_num / (all_num + self.eps),
+ }
def get_metric(self):
"""
@@ -39,7 +41,7 @@ def get_metric(self):
"""
acc = self.correct_num / (self.all_num + self.eps)
self.reset()
- return {'acc': acc}
+ return {"acc": acc}
def reset(self):
self.correct_num = 0
diff --git a/ppocr/metrics/ct_metric.py b/ppocr/metrics/ct_metric.py
index a7634230a2..5ec82bc406 100644
--- a/ppocr/metrics/ct_metric.py
+++ b/ppocr/metrics/ct_metric.py
@@ -24,7 +24,7 @@
class CTMetric(object):
- def __init__(self, main_indicator, delimiter='\t', **kwargs):
+ def __init__(self, main_indicator, delimiter="\t", **kwargs):
self.delimiter = delimiter
self.main_indicator = main_indicator
self.reset()
@@ -33,12 +33,11 @@ def reset(self):
self.results = [] # clear results
def __call__(self, preds, batch, **kwargs):
- # NOTE: only support bs=1 now, as the label length of different sample is Unequal
- assert len(
- preds) == 1, "CentripetalText test now only suuport batch_size=1."
+ # NOTE: only batch_size=1 is supported for now, as label lengths differ across samples
+ assert len(preds) == 1, "CentripetalText test currently only supports batch_size=1."
label = batch[2]
text = batch[3]
- pred = preds[0]['points']
+ pred = preds[0]["points"]
result = get_score_C(label, text, pred)
self.results.append(result)
diff --git a/ppocr/metrics/det_metric.py b/ppocr/metrics/det_metric.py
index dca94c0927..be95ec34dc 100644
--- a/ppocr/metrics/det_metric.py
+++ b/ppocr/metrics/det_metric.py
@@ -16,42 +16,41 @@
from __future__ import division
from __future__ import print_function
-__all__ = ['DetMetric', 'DetFCEMetric']
+__all__ = ["DetMetric", "DetFCEMetric"]
from .eval_det_iou import DetectionIoUEvaluator
class DetMetric(object):
- def __init__(self, main_indicator='hmean', **kwargs):
+ def __init__(self, main_indicator="hmean", **kwargs):
self.evaluator = DetectionIoUEvaluator()
self.main_indicator = main_indicator
self.reset()
def __call__(self, preds, batch, **kwargs):
- '''
- batch: a list produced by dataloaders.
- image: np.ndarray of shape (N, C, H, W).
- ratio_list: np.ndarray of shape(N,2)
- polygons: np.ndarray of shape (N, K, 4, 2), the polygons of objective regions.
- ignore_tags: np.ndarray of shape (N, K), indicates whether a region is ignorable or not.
- preds: a list of dict produced by post process
- points: np.ndarray of shape (N, K, 4, 2), the polygons of objective regions.
- '''
+ """
+ batch: a list produced by dataloaders.
+ image: np.ndarray of shape (N, C, H, W).
+ ratio_list: np.ndarray of shape (N, 2)
+ polygons: np.ndarray of shape (N, K, 4, 2), the polygons of objective regions.
+ ignore_tags: np.ndarray of shape (N, K), indicates whether a region is ignorable or not.
+ preds: a list of dict produced by post process
+ points: np.ndarray of shape (N, K, 4, 2), the polygons of objective regions.
+ """
gt_polyons_batch = batch[2]
ignore_tags_batch = batch[3]
- for pred, gt_polyons, ignore_tags in zip(preds, gt_polyons_batch,
- ignore_tags_batch):
+ for pred, gt_polyons, ignore_tags in zip(
+ preds, gt_polyons_batch, ignore_tags_batch
+ ):
# prepare gt
- gt_info_list = [{
- 'points': gt_polyon,
- 'text': '',
- 'ignore': ignore_tag
- } for gt_polyon, ignore_tag in zip(gt_polyons, ignore_tags)]
+ gt_info_list = [
+ {"points": gt_polyon, "text": "", "ignore": ignore_tag}
+ for gt_polyon, ignore_tag in zip(gt_polyons, ignore_tags)
+ ]
# prepare det
- det_info_list = [{
- 'points': det_polyon,
- 'text': ''
- } for det_polyon in pred['points']]
+ det_info_list = [
+ {"points": det_polyon, "text": ""} for det_polyon in pred["points"]
+ ]
result = self.evaluator.evaluate_image(gt_info_list, det_info_list)
self.results.append(result)
@@ -73,46 +72,45 @@ def reset(self):
class DetFCEMetric(object):
- def __init__(self, main_indicator='hmean', **kwargs):
+ def __init__(self, main_indicator="hmean", **kwargs):
self.evaluator = DetectionIoUEvaluator()
self.main_indicator = main_indicator
self.reset()
def __call__(self, preds, batch, **kwargs):
- '''
- batch: a list produced by dataloaders.
- image: np.ndarray of shape (N, C, H, W).
- ratio_list: np.ndarray of shape(N,2)
- polygons: np.ndarray of shape (N, K, 4, 2), the polygons of objective regions.
- ignore_tags: np.ndarray of shape (N, K), indicates whether a region is ignorable or not.
- preds: a list of dict produced by post process
- points: np.ndarray of shape (N, K, 4, 2), the polygons of objective regions.
- '''
+ """
+ batch: a list produced by dataloaders.
+ image: np.ndarray of shape (N, C, H, W).
+ ratio_list: np.ndarray of shape (N, 2)
+ polygons: np.ndarray of shape (N, K, 4, 2), the polygons of objective regions.
+ ignore_tags: np.ndarray of shape (N, K), indicates whether a region is ignorable or not.
+ preds: a list of dict produced by post process
+ points: np.ndarray of shape (N, K, 4, 2), the polygons of objective regions.
+ """
gt_polyons_batch = batch[2]
ignore_tags_batch = batch[3]
- for pred, gt_polyons, ignore_tags in zip(preds, gt_polyons_batch,
- ignore_tags_batch):
+ for pred, gt_polyons, ignore_tags in zip(
+ preds, gt_polyons_batch, ignore_tags_batch
+ ):
# prepare gt
- gt_info_list = [{
- 'points': gt_polyon,
- 'text': '',
- 'ignore': ignore_tag
- } for gt_polyon, ignore_tag in zip(gt_polyons, ignore_tags)]
+ gt_info_list = [
+ {"points": gt_polyon, "text": "", "ignore": ignore_tag}
+ for gt_polyon, ignore_tag in zip(gt_polyons, ignore_tags)
+ ]
# prepare det
- det_info_list = [{
- 'points': det_polyon,
- 'text': '',
- 'score': score
- } for det_polyon, score in zip(pred['points'], pred['scores'])]
+ det_info_list = [
+ {"points": det_polyon, "text": "", "score": score}
+ for det_polyon, score in zip(pred["points"], pred["scores"])
+ ]
for score_thr in self.results.keys():
det_info_list_thr = [
- det_info for det_info in det_info_list
- if det_info['score'] >= score_thr
+ det_info
+ for det_info in det_info_list
+ if det_info["score"] >= score_thr
]
- result = self.evaluator.evaluate_image(gt_info_list,
- det_info_list_thr)
+ result = self.evaluator.evaluate_image(gt_info_list, det_info_list_thr)
self.results[score_thr].append(result)
def get_metric(self):
@@ -133,11 +131,12 @@ def get_metric(self):
metric = self.evaluator.combine_results(self.results[score_thr])
# for key, value in metric.items():
# metrics['{}_{}'.format(key, score_thr)] = value
- metric_str = 'precision:{:.5f} recall:{:.5f} hmean:{:.5f}'.format(
- metric['precision'], metric['recall'], metric['hmean'])
- metrics['thr {}'.format(score_thr)] = metric_str
- hmean = max(hmean, metric['hmean'])
- metrics['hmean'] = hmean
+ metric_str = "precision:{:.5f} recall:{:.5f} hmean:{:.5f}".format(
+ metric["precision"], metric["recall"], metric["hmean"]
+ )
+ metrics["thr {}".format(score_thr)] = metric_str
+ hmean = max(hmean, metric["hmean"])
+ metrics["hmean"] = hmean
self.reset()
return metrics
@@ -150,5 +149,5 @@ def reset(self):
0.6: [],
0.7: [],
0.8: [],
- 0.9: []
+ 0.9: [],
} # clear results
diff --git a/ppocr/metrics/distillation_metric.py b/ppocr/metrics/distillation_metric.py
index e2cbc4dc07..8e0bcf1e9d 100644
--- a/ppocr/metrics/distillation_metric.py
+++ b/ppocr/metrics/distillation_metric.py
@@ -24,11 +24,7 @@
class DistillationMetric(object):
- def __init__(self,
- key=None,
- base_metric_name=None,
- main_indicator=None,
- **kwargs):
+ def __init__(self, key=None, base_metric_name=None, main_indicator=None, **kwargs):
self.main_indicator = main_indicator
self.key = key
self.main_indicator = main_indicator
@@ -41,7 +37,8 @@ def _init_metrcis(self, preds):
mod = importlib.import_module(__name__)
for key in preds:
self.metrics[key] = getattr(mod, self.base_metric_name)(
- main_indicator=self.main_indicator, **self.kwargs)
+ main_indicator=self.main_indicator, **self.kwargs
+ )
self.metrics[key].reset()
def __call__(self, preds, batch, **kwargs):
diff --git a/ppocr/metrics/e2e_metric.py b/ppocr/metrics/e2e_metric.py
index 2f8ba3b222..2cf563f13e 100644
--- a/ppocr/metrics/e2e_metric.py
+++ b/ppocr/metrics/e2e_metric.py
@@ -16,19 +16,21 @@
from __future__ import division
from __future__ import print_function
-__all__ = ['E2EMetric']
+__all__ = ["E2EMetric"]
from ppocr.utils.e2e_metric.Deteval import get_socre_A, get_socre_B, combine_results
from ppocr.utils.e2e_utils.extract_textpoint_slow import get_dict
class E2EMetric(object):
- def __init__(self,
- mode,
- gt_mat_dir,
- character_dict_path,
- main_indicator='f_score_e2e',
- **kwargs):
+ def __init__(
+ self,
+ mode,
+ gt_mat_dir,
+ character_dict_path,
+ main_indicator="f_score_e2e",
+ **kwargs
+ ):
self.mode = mode
self.gt_mat_dir = gt_mat_dir
self.label_list = get_dict(character_dict_path)
@@ -37,7 +39,7 @@ def __init__(self,
self.reset()
def __call__(self, preds, batch, **kwargs):
- if self.mode == 'A':
+ if self.mode == "A":
gt_polyons_batch = batch[2]
temp_gt_strs_batch = batch[3][0]
ignore_tags_batch = batch[4]
@@ -51,29 +53,29 @@ def __call__(self, preds, batch, **kwargs):
gt_strs_batch.append(t)
for pred, gt_polyons, gt_strs, ignore_tags in zip(
- [preds], gt_polyons_batch, [gt_strs_batch], ignore_tags_batch):
+ [preds], gt_polyons_batch, [gt_strs_batch], ignore_tags_batch
+ ):
# prepare gt
- gt_info_list = [{
- 'points': gt_polyon,
- 'text': gt_str,
- 'ignore': ignore_tag
- } for gt_polyon, gt_str, ignore_tag in
- zip(gt_polyons, gt_strs, ignore_tags)]
+ gt_info_list = [
+ {"points": gt_polyon, "text": gt_str, "ignore": ignore_tag}
+ for gt_polyon, gt_str, ignore_tag in zip(
+ gt_polyons, gt_strs, ignore_tags
+ )
+ ]
# prepare det
- e2e_info_list = [{
- 'points': det_polyon,
- 'texts': pred_str
- } for det_polyon, pred_str in
- zip(pred['points'], pred['texts'])]
+ e2e_info_list = [
+ {"points": det_polyon, "texts": pred_str}
+ for det_polyon, pred_str in zip(pred["points"], pred["texts"])
+ ]
result = get_socre_A(gt_info_list, e2e_info_list)
self.results.append(result)
else:
img_id = batch[5][0]
- e2e_info_list = [{
- 'points': det_polyon,
- 'texts': pred_str
- } for det_polyon, pred_str in zip(preds['points'], preds['texts'])]
+ e2e_info_list = [
+ {"points": det_polyon, "texts": pred_str}
+ for det_polyon, pred_str in zip(preds["points"], preds["texts"])
+ ]
result = get_socre_B(self.gt_mat_dir, img_id, e2e_info_list)
self.results.append(result)
diff --git a/ppocr/metrics/eval_det_iou.py b/ppocr/metrics/eval_det_iou.py
index c144886b3f..21d5c18eb9 100644
--- a/ppocr/metrics/eval_det_iou.py
+++ b/ppocr/metrics/eval_det_iou.py
@@ -3,6 +3,7 @@
from collections import namedtuple
import numpy as np
from shapely.geometry import Polygon
+
"""
reference from :
https://github.com/MhLiao/DB/blob/3c32b808d4412680310d3d28eeb6a2d5bf1566c5/concern/icdar2015_eval/detection/iou.py#L8
@@ -48,7 +49,7 @@ def compute_ap(confList, matchList, numGtCare):
matchedSum = 0
- Rectangle = namedtuple('Rectangle', 'xmin ymin xmax ymax')
+ Rectangle = namedtuple("Rectangle", "xmin ymin xmax ymax")
numGlobalCareGt = 0
numGlobalCareDet = 0
@@ -84,8 +85,8 @@ def compute_ap(confList, matchList, numGtCare):
evaluationLog = ""
for n in range(len(gt)):
- points = gt[n]['points']
- dontCare = gt[n]['ignore']
+ points = gt[n]["points"]
+ dontCare = gt[n]["ignore"]
if not Polygon(points).is_valid:
continue
@@ -95,12 +96,18 @@ def compute_ap(confList, matchList, numGtCare):
if dontCare:
gtDontCarePolsNum.append(len(gtPols) - 1)
- evaluationLog += "GT polygons: " + str(len(gtPols)) + (
- " (" + str(len(gtDontCarePolsNum)) + " don't care)\n"
- if len(gtDontCarePolsNum) > 0 else "\n")
+ evaluationLog += (
+ "GT polygons: "
+ + str(len(gtPols))
+ + (
+ " (" + str(len(gtDontCarePolsNum)) + " don't care)\n"
+ if len(gtDontCarePolsNum) > 0
+ else "\n"
+ )
+ )
for n in range(len(pred)):
- points = pred[n]['points']
+ points = pred[n]["points"]
if not Polygon(points).is_valid:
continue
@@ -112,14 +119,22 @@ def compute_ap(confList, matchList, numGtCare):
dontCarePol = gtPols[dontCarePol]
intersected_area = get_intersection(dontCarePol, detPol)
pdDimensions = Polygon(detPol).area
- precision = 0 if pdDimensions == 0 else intersected_area / pdDimensions
- if (precision > self.area_precision_constraint):
+ precision = (
+ 0 if pdDimensions == 0 else intersected_area / pdDimensions
+ )
+ if precision > self.area_precision_constraint:
detDontCarePolsNum.append(len(detPols) - 1)
break
- evaluationLog += "DET polygons: " + str(len(detPols)) + (
- " (" + str(len(detDontCarePolsNum)) + " don't care)\n"
- if len(detDontCarePolsNum) > 0 else "\n")
+ evaluationLog += (
+ "DET polygons: "
+ + str(len(detPols))
+ + (
+ " (" + str(len(detDontCarePolsNum)) + " don't care)\n"
+ if len(detDontCarePolsNum) > 0
+ else "\n"
+ )
+ )
if len(gtPols) > 0 and len(detPols) > 0:
# Calculate IoU and precision matrices
@@ -135,19 +150,28 @@ def compute_ap(confList, matchList, numGtCare):
for gtNum in range(len(gtPols)):
for detNum in range(len(detPols)):
- if gtRectMat[gtNum] == 0 and detRectMat[
- detNum] == 0 and gtNum not in gtDontCarePolsNum and detNum not in detDontCarePolsNum:
+ if (
+ gtRectMat[gtNum] == 0
+ and detRectMat[detNum] == 0
+ and gtNum not in gtDontCarePolsNum
+ and detNum not in detDontCarePolsNum
+ ):
if iouMat[gtNum, detNum] > self.iou_constraint:
gtRectMat[gtNum] = 1
detRectMat[detNum] = 1
detMatched += 1
- pairs.append({'gt': gtNum, 'det': detNum})
+ pairs.append({"gt": gtNum, "det": detNum})
detMatchedNums.append(detNum)
- evaluationLog += "Match GT #" + \
- str(gtNum) + " with Det #" + str(detNum) + "\n"
-
- numGtCare = (len(gtPols) - len(gtDontCarePolsNum))
- numDetCare = (len(detPols) - len(detDontCarePolsNum))
+ evaluationLog += (
+ "Match GT #"
+ + str(gtNum)
+ + " with Det #"
+ + str(detNum)
+ + "\n"
+ )
+
+ numGtCare = len(gtPols) - len(gtDontCarePolsNum)
+ numDetCare = len(detPols) - len(detDontCarePolsNum)
if numGtCare == 0:
recall = float(1)
precision = float(0) if numDetCare > 0 else float(1)
@@ -155,17 +179,20 @@ def compute_ap(confList, matchList, numGtCare):
recall = float(detMatched) / numGtCare
precision = 0 if numDetCare == 0 else float(detMatched) / numDetCare
- hmean = 0 if (precision + recall) == 0 else 2.0 * \
- precision * recall / (precision + recall)
+ hmean = (
+ 0
+ if (precision + recall) == 0
+ else 2.0 * precision * recall / (precision + recall)
+ )
matchedSum += detMatched
numGlobalCareGt += numGtCare
numGlobalCareDet += numDetCare
perSampleMetrics = {
- 'gtCare': numGtCare,
- 'detCare': numDetCare,
- 'detMatched': detMatched,
+ "gtCare": numGtCare,
+ "detCare": numDetCare,
+ "detMatched": detMatched,
}
return perSampleMetrics
@@ -174,42 +201,55 @@ def combine_results(self, results):
numGlobalCareDet = 0
matchedSum = 0
for result in results:
- numGlobalCareGt += result['gtCare']
- numGlobalCareDet += result['detCare']
- matchedSum += result['detMatched']
-
- methodRecall = 0 if numGlobalCareGt == 0 else float(
- matchedSum) / numGlobalCareGt
- methodPrecision = 0 if numGlobalCareDet == 0 else float(
- matchedSum) / numGlobalCareDet
- methodHmean = 0 if methodRecall + methodPrecision == 0 else 2 * \
- methodRecall * methodPrecision / (
- methodRecall + methodPrecision)
+ numGlobalCareGt += result["gtCare"]
+ numGlobalCareDet += result["detCare"]
+ matchedSum += result["detMatched"]
+
+ methodRecall = (
+ 0 if numGlobalCareGt == 0 else float(matchedSum) / numGlobalCareGt
+ )
+ methodPrecision = (
+ 0 if numGlobalCareDet == 0 else float(matchedSum) / numGlobalCareDet
+ )
+ methodHmean = (
+ 0
+ if methodRecall + methodPrecision == 0
+ else 2 * methodRecall * methodPrecision / (methodRecall + methodPrecision)
+ )
methodMetrics = {
- 'precision': methodPrecision,
- 'recall': methodRecall,
- 'hmean': methodHmean
+ "precision": methodPrecision,
+ "recall": methodRecall,
+ "hmean": methodHmean,
}
return methodMetrics
-if __name__ == '__main__':
+if __name__ == "__main__":
evaluator = DetectionIoUEvaluator()
- gts = [[{
- 'points': [(0, 0), (1, 0), (1, 1), (0, 1)],
- 'text': 1234,
- 'ignore': False,
- }, {
- 'points': [(2, 2), (3, 2), (3, 3), (2, 3)],
- 'text': 5678,
- 'ignore': False,
- }]]
- preds = [[{
- 'points': [(0.1, 0.1), (1, 0), (1, 1), (0, 1)],
- 'text': 123,
- 'ignore': False,
- }]]
+ gts = [
+ [
+ {
+ "points": [(0, 0), (1, 0), (1, 1), (0, 1)],
+ "text": 1234,
+ "ignore": False,
+ },
+ {
+ "points": [(2, 2), (3, 2), (3, 3), (2, 3)],
+ "text": 5678,
+ "ignore": False,
+ },
+ ]
+ ]
+ preds = [
+ [
+ {
+ "points": [(0.1, 0.1), (1, 0), (1, 1), (0, 1)],
+ "text": 123,
+ "ignore": False,
+ }
+ ]
+ ]
results = []
for gt, pred in zip(gts, preds):
results.append(evaluator.evaluate_image(gt, pred))
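
Note: combine_results micro-averages over the whole dataset: matched/care counts are summed across images first, and precision, recall and hmean are derived once from the totals. The arithmetic, isolated:

def micro_prf(matched, num_gt_care, num_det_care):
    recall = 0 if num_gt_care == 0 else matched / num_gt_care
    precision = 0 if num_det_care == 0 else matched / num_det_care
    hmean = (
        0
        if precision + recall == 0
        else 2 * precision * recall / (precision + recall)
    )
    return precision, recall, hmean
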
diff --git a/ppocr/metrics/kie_metric.py b/ppocr/metrics/kie_metric.py
index 28ab22b807..0c83756f8c 100644
--- a/ppocr/metrics/kie_metric.py
+++ b/ppocr/metrics/kie_metric.py
@@ -20,11 +20,11 @@
import numpy as np
import paddle
-__all__ = ['KIEMetric']
+__all__ = ["KIEMetric"]
class KIEMetric(object):
- def __init__(self, main_indicator='hmean', **kwargs):
+ def __init__(self, main_indicator="hmean", **kwargs):
self.main_indicator = main_indicator
self.reset()
self.node = []
@@ -33,7 +33,7 @@ def __init__(self, main_indicator='hmean', **kwargs):
def __call__(self, preds, batch, **kwargs):
nodes, _ = preds
gts, tag = batch[4].squeeze(0), batch[5].tolist()[0]
- gts = gts[:tag[0], :1].reshape([-1])
+ gts = gts[: tag[0], :1].reshape([-1])
self.node.append(nodes.numpy())
self.gt.append(gts)
# result = self.compute_f1_score(nodes, gts)
@@ -43,9 +43,11 @@ def compute_f1_score(self, preds, gts):
ignores = [0, 2, 4, 6, 8, 10, 12, 14, 16, 18, 20, 22, 24, 25]
C = preds.shape[1]
classes = np.array(sorted(set(range(C)) - set(ignores)))
- hist = np.bincount(
- (gts * C).astype('int64') + preds.argmax(1), minlength=C
- **2).reshape([C, C]).astype('float32')
+ hist = (
+ np.bincount((gts * C).astype("int64") + preds.argmax(1), minlength=C**2)
+ .reshape([C, C])
+ .astype("float32")
+ )
diag = np.diag(hist)
recalls = diag / hist.sum(1).clip(min=1)
precisions = diag / hist.sum(0).clip(min=1)
@@ -56,11 +58,10 @@ def combine_results(self, results):
node = np.concatenate(self.node, 0)
gts = np.concatenate(self.gt, 0)
results = self.compute_f1_score(node, gts)
- data = {'hmean': results.mean()}
+ data = {"hmean": results.mean()}
return data
def get_metric(self):
-
metrics = self.combine_results(self.results)
self.reset()
return metrics
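
Note: compute_f1_score builds the C x C confusion matrix in a single pass by encoding each (gt, pred) pair as gt * C + pred and running np.bincount with minlength=C**2. With predicted class ids already taken (the hunk applies argmax to logits first):

import numpy as np

def confusion_matrix(gts, preds, C):
    # gts, preds: flat int arrays of equal length with values in [0, C)
    flat = (gts * C + preds).astype("int64")
    return np.bincount(flat, minlength=C**2).reshape([C, C]).astype("float32")
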
diff --git a/ppocr/metrics/rec_metric.py b/ppocr/metrics/rec_metric.py
index 305b913c72..e41dd36e09 100644
--- a/ppocr/metrics/rec_metric.py
+++ b/ppocr/metrics/rec_metric.py
@@ -20,11 +20,9 @@
class RecMetric(object):
- def __init__(self,
- main_indicator='acc',
- is_filter=False,
- ignore_space=True,
- **kwargs):
+ def __init__(
+ self, main_indicator="acc", is_filter=False, ignore_space=True, **kwargs
+ ):
self.main_indicator = main_indicator
self.is_filter = is_filter
self.ignore_space = ignore_space
@@ -32,8 +30,9 @@ def __init__(self,
self.reset()
def _normalize_text(self, text):
- text = ''.join(
- filter(lambda x: x in (string.digits + string.ascii_letters), text))
+ text = "".join(
+ filter(lambda x: x in (string.digits + string.ascii_letters), text)
+ )
return text.lower()
def __call__(self, pred_label, *args, **kwargs):
@@ -56,8 +55,8 @@ def __call__(self, pred_label, *args, **kwargs):
self.all_num += all_num
self.norm_edit_dis += norm_edit_dis
return {
- 'acc': correct_num / (all_num + self.eps),
- 'norm_edit_dis': 1 - norm_edit_dis / (all_num + self.eps)
+ "acc": correct_num / (all_num + self.eps),
+ "norm_edit_dis": 1 - norm_edit_dis / (all_num + self.eps),
}
def get_metric(self):
@@ -70,7 +69,7 @@ def get_metric(self):
acc = 1.0 * self.correct_num / (self.all_num + self.eps)
norm_edit_dis = 1 - self.norm_edit_dis / (self.all_num + self.eps)
self.reset()
- return {'acc': acc, 'norm_edit_dis': norm_edit_dis}
+ return {"acc": acc, "norm_edit_dis": norm_edit_dis}
def reset(self):
self.correct_num = 0
@@ -79,7 +78,7 @@ def reset(self):
class CNTMetric(object):
- def __init__(self, main_indicator='acc', **kwargs):
+ def __init__(self, main_indicator="acc", **kwargs):
self.main_indicator = main_indicator
self.eps = 1e-5
self.reset()
@@ -94,7 +93,9 @@ def __call__(self, pred_label, *args, **kwargs):
all_num += 1
self.correct_num += correct_num
self.all_num += all_num
- return {'acc': correct_num / (all_num + self.eps), }
+ return {
+ "acc": correct_num / (all_num + self.eps),
+ }
def get_metric(self):
"""
@@ -104,7 +105,7 @@ def get_metric(self):
"""
acc = 1.0 * self.correct_num / (self.all_num + self.eps)
self.reset()
- return {'acc': acc}
+ return {"acc": acc}
def reset(self):
self.correct_num = 0
@@ -112,7 +113,7 @@ def reset(self):
class CANMetric(object):
- def __init__(self, main_indicator='exp_rate', **kwargs):
+ def __init__(self, main_indicator="exp_rate", **kwargs):
self.main_indicator = main_indicator
self.word_right = []
self.exp_right = []
@@ -136,20 +137,19 @@ def __call__(self, preds, batch, **kwargs):
word_pred = word_pred.cpu().detach().numpy()
word_scores = [
SequenceMatcher(
- None,
- s1[:int(np.sum(s3))],
- s2[:int(np.sum(s3))],
- autojunk=False).ratio() * (
- len(s1[:int(np.sum(s3))]) + len(s2[:int(np.sum(s3))])) /
- len(s1[:int(np.sum(s3))]) / 2
+ None, s1[: int(np.sum(s3))], s2[: int(np.sum(s3))], autojunk=False
+ ).ratio()
+ * (len(s1[: int(np.sum(s3))]) + len(s2[: int(np.sum(s3))]))
+ / len(s1[: int(np.sum(s3))])
+ / 2
for s1, s2, s3 in zip(word_label, word_pred, word_label_mask)
]
batch_size = len(word_scores)
for i in range(batch_size):
if word_scores[i] == 1:
line_right += 1
- self.word_rate = np.mean(word_scores) #float
- self.exp_rate = line_right / batch_size #float
+ self.word_rate = np.mean(word_scores) # float
+ self.exp_rate = line_right / batch_size # float
exp_length, word_length = word_label.shape[:2]
self.word_right.append(self.word_rate * word_length)
self.exp_right.append(self.exp_rate * exp_length)
@@ -166,7 +166,7 @@ def get_metric(self):
cur_word_rate = sum(self.word_right) / self.word_total_length
cur_exp_rate = sum(self.exp_right) / self.exp_total_num
self.reset()
- return {'word_rate': cur_word_rate, "exp_rate": cur_exp_rate}
+ return {"word_rate": cur_word_rate, "exp_rate": cur_exp_rate}
def reset(self):
self.word_rate = 0
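
Note: RecMetric reports 1 - mean(normalized edit distance) next to accuracy; the per-pair distance helper itself lies outside these hunks. For reference, a standard normalized Levenshtein (an assumption about that helper, not a quote of it):

def norm_edit_distance(a, b):
    # Rolling-array Levenshtein distance, normalized by the longer string.
    m, n = len(a), len(b)
    d = list(range(n + 1))
    for i in range(1, m + 1):
        prev, d[0] = d[0], i
        for j in range(1, n + 1):
            prev, d[j] = d[j], min(
                d[j] + 1, d[j - 1] + 1, prev + (a[i - 1] != b[j - 1])
            )
    return d[n] / max(m, n, 1)
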
diff --git a/ppocr/metrics/sr_metric.py b/ppocr/metrics/sr_metric.py
index 51c3ad6656..ef9ef96939 100644
--- a/ppocr/metrics/sr_metric.py
+++ b/ppocr/metrics/sr_metric.py
@@ -32,10 +32,12 @@ def __init__(self, window_size=11, size_average=True):
self.window = self.create_window(window_size, self.channel)
def gaussian(self, window_size, sigma):
- gauss = paddle.to_tensor([
- exp(-(x - window_size // 2)**2 / float(2 * sigma**2))
- for x in range(window_size)
- ])
+ gauss = paddle.to_tensor(
+ [
+ exp(-((x - window_size // 2) ** 2) / float(2 * sigma**2))
+ for x in range(window_size)
+ ]
+ )
return gauss / gauss.sum()
def create_window(self, window_size, channel):
@@ -44,8 +46,7 @@ def create_window(self, window_size, channel):
window = _2D_window.expand([channel, 1, window_size, window_size])
return window
- def _ssim(self, img1, img2, window, window_size, channel,
- size_average=True):
+ def _ssim(self, img1, img2, window, window_size, channel, size_average=True):
mu1 = F.conv2d(img1, window, padding=window_size // 2, groups=channel)
mu2 = F.conv2d(img2, window, padding=window_size // 2, groups=channel)
@@ -53,21 +54,25 @@ def _ssim(self, img1, img2, window, window_size, channel,
mu2_sq = mu2.pow(2)
mu1_mu2 = mu1 * mu2
- sigma1_sq = F.conv2d(
- img1 * img1, window, padding=window_size // 2,
- groups=channel) - mu1_sq
- sigma2_sq = F.conv2d(
- img2 * img2, window, padding=window_size // 2,
- groups=channel) - mu2_sq
- sigma12 = F.conv2d(
- img1 * img2, window, padding=window_size // 2,
- groups=channel) - mu1_mu2
+ sigma1_sq = (
+ F.conv2d(img1 * img1, window, padding=window_size // 2, groups=channel)
+ - mu1_sq
+ )
+ sigma2_sq = (
+ F.conv2d(img2 * img2, window, padding=window_size // 2, groups=channel)
+ - mu2_sq
+ )
+ sigma12 = (
+ F.conv2d(img1 * img2, window, padding=window_size // 2, groups=channel)
+ - mu1_mu2
+ )
C1 = 0.01**2
C2 = 0.03**2
ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / (
- (mu1_sq + mu2_sq + C1) * (sigma1_sq + sigma2_sq + C2))
+ (mu1_sq + mu2_sq + C1) * (sigma1_sq + sigma2_sq + C2)
+ )
if size_average:
return ssim_map.mean()
@@ -78,8 +83,7 @@ def ssim(self, img1, img2, window_size=11, size_average=True):
(_, channel, _, _) = img1.shape
window = self.create_window(window_size, channel)
- return self._ssim(img1, img2, window, window_size, channel,
- size_average)
+ return self._ssim(img1, img2, window, window_size, channel, size_average)
def forward(self, img1, img2):
(_, channel, _, _) = img1.shape
@@ -92,12 +96,13 @@ def forward(self, img1, img2):
self.window = window
self.channel = channel
- return self._ssim(img1, img2, window, self.window_size, channel,
- self.size_average)
+ return self._ssim(
+ img1, img2, window, self.window_size, channel, self.size_average
+ )
class SRMetric(object):
- def __init__(self, main_indicator='all', **kwargs):
+ def __init__(self, main_indicator="all", **kwargs):
self.main_indicator = main_indicator
self.eps = 1e-5
self.psnr_result = []
@@ -114,14 +119,15 @@ def reset(self):
def calculate_psnr(self, img1, img2):
# img1 and img2 have range [0, 1]
- mse = ((img1 * 255 - img2 * 255)**2).mean()
+ mse = ((img1 * 255 - img2 * 255) ** 2).mean()
if mse == 0:
- return float('inf')
+ return float("inf")
return 20 * paddle.log10(255.0 / paddle.sqrt(mse))
def _normalize_text(self, text):
- text = ''.join(
- filter(lambda x: x in (string.digits + string.ascii_letters), text))
+ text = "".join(
+ filter(lambda x: x in (string.digits + string.ascii_letters), text)
+ )
return text.lower()
def __call__(self, pred_label, *args, **kwargs):
@@ -149,7 +155,7 @@ def get_metric(self):
self.reset()
return {
- 'psnr_avg': self.psnr_avg,
+ "psnr_avg": self.psnr_avg,
"ssim_avg": self.ssim_avg,
- "all": self.all_avg
+ "all": self.all_avg,
}
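
As a quick illustration of `calculate_psnr` above: the inputs live in [0, 1], the error is measured on the 255-level scale, and PSNR = 20 * log10(255 / sqrt(MSE)), diverging to infinity for identical images. A NumPy sketch under those assumptions:

import numpy as np

def psnr(img1, img2):
    # img1, img2: float arrays with values in [0, 1]
    mse = ((img1 * 255 - img2 * 255) ** 2).mean()
    if mse == 0:
        return float("inf")
    return 20 * np.log10(255.0 / np.sqrt(mse))

a = np.zeros((8, 8))
b = a.copy()
b[0, 0] = 10 / 255                   # one pixel off by 10 gray levels
print(psnr(a, a))                    # inf
print(round(float(psnr(a, b)), 2))   # ~46.19 dB
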
diff --git a/ppocr/metrics/table_metric.py b/ppocr/metrics/table_metric.py
index c0b247efa6..3f319c8ca2 100644
--- a/ppocr/metrics/table_metric.py
+++ b/ppocr/metrics/table_metric.py
@@ -16,11 +16,7 @@
class TableStructureMetric(object):
- def __init__(self,
- main_indicator='acc',
- eps=1e-6,
- del_thead_tbody=False,
- **kwargs):
+ def __init__(self, main_indicator="acc", eps=1e-6, del_thead_tbody=False, **kwargs):
self.main_indicator = main_indicator
self.eps = eps
self.del_thead_tbody = del_thead_tbody
@@ -28,21 +24,28 @@ def __init__(self,
def __call__(self, pred_label, batch=None, *args, **kwargs):
preds, labels = pred_label
- pred_structure_batch_list = preds['structure_batch_list']
- gt_structure_batch_list = labels['structure_batch_list']
+ pred_structure_batch_list = preds["structure_batch_list"]
+ gt_structure_batch_list = labels["structure_batch_list"]
correct_num = 0
all_num = 0
- for (pred, pred_conf), target in zip(pred_structure_batch_list,
- gt_structure_batch_list):
- pred_str = ''.join(pred)
- target_str = ''.join(target)
+ for (pred, pred_conf), target in zip(
+ pred_structure_batch_list, gt_structure_batch_list
+ ):
+ pred_str = "".join(pred)
+ target_str = "".join(target)
if self.del_thead_tbody:
- pred_str = pred_str.replace('<thead>', '').replace(
- '</thead>', '').replace('<tbody>', '').replace('</tbody>',
- '')
- target_str = target_str.replace('<thead>', '').replace(
- '</thead>', '').replace('<tbody>', '').replace('</tbody>',
- '')
+ pred_str = (
+ pred_str.replace("<thead>", "")
+ .replace("</thead>", "")
+ .replace("<tbody>", "")
+ .replace("</tbody>", "")
+ )
+ target_str = (
+ target_str.replace("<thead>", "")
+ .replace("</thead>", "")
+ .replace("<tbody>", "")
+ .replace("</tbody>", "")
+ )
if pred_str == target_str:
correct_num += 1
all_num += 1
@@ -57,7 +60,7 @@ def get_metric(self):
"""
acc = 1.0 * self.correct_num / (self.all_num + self.eps)
self.reset()
- return {'acc': acc}
+ return {"acc": acc}
def reset(self):
self.correct_num = 0
@@ -68,20 +71,21 @@ def reset(self):
class TableMetric(object):
- def __init__(self,
- main_indicator='acc',
- compute_bbox_metric=False,
- box_format='xyxy',
- del_thead_tbody=False,
- **kwargs):
+ def __init__(
+ self,
+ main_indicator="acc",
+ compute_bbox_metric=False,
+ box_format="xyxy",
+ del_thead_tbody=False,
+ **kwargs
+ ):
"""
@param sub_metrics: configs of sub_metric
@param main_metric: main metric used to save the best model
@param kwargs:
"""
- self.structure_metric = TableStructureMetric(
- del_thead_tbody=del_thead_tbody)
+ self.structure_metric = TableStructureMetric(del_thead_tbody=del_thead_tbody)
self.bbox_metric = DetMetric() if compute_bbox_metric else None
self.main_indicator = main_indicator
self.box_format = box_format
@@ -98,19 +102,19 @@ def prepare_bbox_metric_input(self, pred_label):
gt_bbox_batch_list = []
preds, labels = pred_label
- batch_num = len(preds['bbox_batch_list'])
+ batch_num = len(preds["bbox_batch_list"])
for batch_idx in range(batch_num):
# pred
pred_bbox_list = [
self.format_box(pred_box)
- for pred_box in preds['bbox_batch_list'][batch_idx]
+ for pred_box in preds["bbox_batch_list"][batch_idx]
]
- pred_bbox_batch_list.append({'points': pred_bbox_list})
+ pred_bbox_batch_list.append({"points": pred_bbox_list})
# gt
gt_bbox_list = []
gt_ignore_tags_list = []
- for gt_box in labels['bbox_batch_list'][batch_idx]:
+ for gt_box in labels["bbox_batch_list"][batch_idx]:
gt_bbox_list.append(self.format_box(gt_box))
gt_ignore_tags_list.append(0)
gt_bbox_batch_list.append(gt_bbox_list)
@@ -118,7 +122,7 @@ def prepare_bbox_metric_input(self, pred_label):
return [
pred_bbox_batch_list,
- [0, 0, gt_bbox_batch_list, gt_ignore_tags_batch_list]
+ [0, 0, gt_bbox_batch_list, gt_ignore_tags_batch_list],
]
def get_metric(self):
@@ -129,8 +133,9 @@ def get_metric(self):
if self.main_indicator == self.bbox_metric.main_indicator:
output = bbox_metric
for sub_key in structure_metric:
- output["structure_metric_{}".format(
- sub_key)] = structure_metric[sub_key]
+ output["structure_metric_{}".format(sub_key)] = structure_metric[
+ sub_key
+ ]
else:
output = structure_metric
for sub_key in bbox_metric:
@@ -143,14 +148,14 @@ def reset(self):
self.bbox_metric.reset()
def format_box(self, box):
- if self.box_format == 'xyxy':
+ if self.box_format == "xyxy":
x1, y1, x2, y2 = box
box = [[x1, y1], [x2, y1], [x2, y2], [x1, y2]]
- elif self.box_format == 'xywh':
+ elif self.box_format == "xywh":
x, y, w, h = box
x1, y1, x2, y2 = x - w // 2, y - h // 2, x + w // 2, y + h // 2
box = [[x1, y1], [x2, y1], [x2, y2], [x1, y2]]
- elif self.box_format == 'xyxyxyxy':
+ elif self.box_format == "xyxyxyxy":
x1, y1, x2, y2, x3, y3, x4, y4 = box
box = [[x1, y1], [x2, y2], [x3, y3], [x4, y4]]
return box
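
The `format_box` branches above normalize three common box encodings to a 4-point polygon; note the `xywh` branch treats (x, y) as the box center and uses integer halving. A standalone sketch of the same conversions:

def to_polygon(box, box_format="xyxy"):
    if box_format == "xyxy":          # two opposite corners
        x1, y1, x2, y2 = box
    elif box_format == "xywh":        # center point plus width/height
        x, y, w, h = box
        x1, y1, x2, y2 = x - w // 2, y - h // 2, x + w // 2, y + h // 2
    elif box_format == "xyxyxyxy":    # already four explicit points
        x1, y1, x2, y2, x3, y3, x4, y4 = box
        return [[x1, y1], [x2, y2], [x3, y3], [x4, y4]]
    return [[x1, y1], [x2, y1], [x2, y2], [x1, y2]]

print(to_polygon([10, 20, 30, 40]))          # [[10, 20], [30, 20], [30, 40], [10, 40]]
print(to_polygon([20, 30, 10, 10], "xywh"))  # [[15, 25], [25, 25], [25, 35], [15, 35]]
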
diff --git a/ppocr/metrics/vqa_token_re_metric.py b/ppocr/metrics/vqa_token_re_metric.py
index 0509984f7e..d39917000f 100644
--- a/ppocr/metrics/vqa_token_re_metric.py
+++ b/ppocr/metrics/vqa_token_re_metric.py
@@ -19,11 +19,11 @@
import numpy as np
import paddle
-__all__ = ['KIEMetric']
+__all__ = ["KIEMetric"]
class VQAReTokenMetric(object):
- def __init__(self, main_indicator='hmean', **kwargs):
+ def __init__(self, main_indicator="hmean", **kwargs):
self.main_indicator = main_indicator
self.reset()
@@ -41,27 +41,28 @@ def get_metric(self):
entitie_list = self.entities_list[b]
head_len = relation_list[0, 0]
if head_len > 0:
- entitie_start_list = entitie_list[1:entitie_list[0, 0] + 1, 0]
- entitie_end_list = entitie_list[1:entitie_list[0, 1] + 1, 1]
- entitie_label_list = entitie_list[1:entitie_list[0, 2] + 1, 2]
- for head, tail in zip(relation_list[1:head_len + 1, 0],
- relation_list[1:head_len + 1, 1]):
+ entitie_start_list = entitie_list[1 : entitie_list[0, 0] + 1, 0]
+ entitie_end_list = entitie_list[1 : entitie_list[0, 1] + 1, 1]
+ entitie_label_list = entitie_list[1 : entitie_list[0, 2] + 1, 2]
+ for head, tail in zip(
+ relation_list[1 : head_len + 1, 0],
+ relation_list[1 : head_len + 1, 1],
+ ):
rel = {}
rel["head_id"] = head
- rel["head"] = (entitie_start_list[head],
- entitie_end_list[head])
+ rel["head"] = (entitie_start_list[head], entitie_end_list[head])
rel["head_type"] = entitie_label_list[head]
rel["tail_id"] = tail
- rel["tail"] = (entitie_start_list[tail],
- entitie_end_list[tail])
+ rel["tail"] = (entitie_start_list[tail], entitie_end_list[tail])
rel["tail_type"] = entitie_label_list[tail]
rel["type"] = 1
rel_sent.append(rel)
gt_relations.append(rel_sent)
re_metrics = self.re_score(
- self.pred_relations_list, gt_relations, mode="boundaries")
+ self.pred_relations_list, gt_relations, mode="boundaries"
+ )
metrics = {
"precision": re_metrics["ALL"]["p"],
"recall": re_metrics["ALL"]["r"],
@@ -94,14 +95,7 @@ def re_score(self, pred_relations, gt_relations, mode="strict"):
assert mode in ["strict", "boundaries"]
relation_types = [v for v in [0, 1] if not v == 0]
- scores = {
- rel: {
- "tp": 0,
- "fp": 0,
- "fn": 0
- }
- for rel in relation_types + ["ALL"]
- }
+ scores = {rel: {"tp": 0, "fp": 0, "fn": 0} for rel in relation_types + ["ALL"]}
# Count GT relations and Predicted relations
n_sents = len(gt_relations)
@@ -113,21 +107,29 @@ def re_score(self, pred_relations, gt_relations, mode="strict"):
for rel_type in relation_types:
# strict mode takes argument types into account
if mode == "strict":
- pred_rels = {(rel["head"], rel["head_type"], rel["tail"],
- rel["tail_type"])
- for rel in pred_sent
- if rel["type"] == rel_type}
- gt_rels = {(rel["head"], rel["head_type"], rel["tail"],
- rel["tail_type"])
- for rel in gt_sent if rel["type"] == rel_type}
+ pred_rels = {
+ (rel["head"], rel["head_type"], rel["tail"], rel["tail_type"])
+ for rel in pred_sent
+ if rel["type"] == rel_type
+ }
+ gt_rels = {
+ (rel["head"], rel["head_type"], rel["tail"], rel["tail_type"])
+ for rel in gt_sent
+ if rel["type"] == rel_type
+ }
# boundaries mode only takes argument spans into account
elif mode == "boundaries":
- pred_rels = {(rel["head"], rel["tail"])
- for rel in pred_sent
- if rel["type"] == rel_type}
- gt_rels = {(rel["head"], rel["tail"])
- for rel in gt_sent if rel["type"] == rel_type}
+ pred_rels = {
+ (rel["head"], rel["tail"])
+ for rel in pred_sent
+ if rel["type"] == rel_type
+ }
+ gt_rels = {
+ (rel["head"], rel["tail"])
+ for rel in gt_sent
+ if rel["type"] == rel_type
+ }
scores[rel_type]["tp"] += len(pred_rels & gt_rels)
scores[rel_type]["fp"] += len(pred_rels - gt_rels)
@@ -137,16 +139,21 @@ def re_score(self, pred_relations, gt_relations, mode="strict"):
for rel_type in scores.keys():
if scores[rel_type]["tp"]:
scores[rel_type]["p"] = scores[rel_type]["tp"] / (
- scores[rel_type]["fp"] + scores[rel_type]["tp"])
+ scores[rel_type]["fp"] + scores[rel_type]["tp"]
+ )
scores[rel_type]["r"] = scores[rel_type]["tp"] / (
- scores[rel_type]["fn"] + scores[rel_type]["tp"])
+ scores[rel_type]["fn"] + scores[rel_type]["tp"]
+ )
else:
scores[rel_type]["p"], scores[rel_type]["r"] = 0, 0
if not scores[rel_type]["p"] + scores[rel_type]["r"] == 0:
scores[rel_type]["f1"] = (
- 2 * scores[rel_type]["p"] * scores[rel_type]["r"] /
- (scores[rel_type]["p"] + scores[rel_type]["r"]))
+ 2
+ * scores[rel_type]["p"]
+ * scores[rel_type]["r"]
+ / (scores[rel_type]["p"] + scores[rel_type]["r"])
+ )
else:
scores[rel_type]["f1"] = 0
@@ -172,10 +179,13 @@ def re_score(self, pred_relations, gt_relations, mode="strict"):
# Compute Macro F1 Scores
scores["ALL"]["Macro_f1"] = np.mean(
- [scores[ent_type]["f1"] for ent_type in relation_types])
+ [scores[ent_type]["f1"] for ent_type in relation_types]
+ )
scores["ALL"]["Macro_p"] = np.mean(
- [scores[ent_type]["p"] for ent_type in relation_types])
+ [scores[ent_type]["p"] for ent_type in relation_types]
+ )
scores["ALL"]["Macro_r"] = np.mean(
- [scores[ent_type]["r"] for ent_type in relation_types])
+ [scores[ent_type]["r"] for ent_type in relation_types]
+ )
return scores
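
The counting logic in `re_score` reduces to set algebra per sentence: true positives are the intersection of predicted and gold relation tuples, false positives are predictions outside the gold set, and false negatives the reverse. A minimal sketch of the "boundaries" mode, where a relation is just a (head_span, tail_span) pair:

def re_score_boundaries(pred_sents, gt_sents, eps=1e-12):
    tp = fp = fn = 0
    for pred_sent, gt_sent in zip(pred_sents, gt_sents):
        pred = {(r["head"], r["tail"]) for r in pred_sent}
        gt = {(r["head"], r["tail"]) for r in gt_sent}
        tp += len(pred & gt)
        fp += len(pred - gt)
        fn += len(gt - pred)
    p = tp / (tp + fp + eps)
    r = tp / (tp + fn + eps)
    f1 = 2 * p * r / (p + r + eps)
    return {"precision": p, "recall": r, "hmean": f1}

gt = [[{"head": (0, 2), "tail": (5, 7)}, {"head": (3, 4), "tail": (8, 9)}]]
pred = [[{"head": (0, 2), "tail": (5, 7)}, {"head": (0, 2), "tail": (8, 9)}]]
print(re_score_boundaries(pred, gt))  # precision 0.5, recall 0.5, hmean 0.5
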
diff --git a/ppocr/metrics/vqa_token_ser_metric.py b/ppocr/metrics/vqa_token_ser_metric.py
index 286d8addaf..b6033c3ae5 100644
--- a/ppocr/metrics/vqa_token_ser_metric.py
+++ b/ppocr/metrics/vqa_token_ser_metric.py
@@ -19,11 +19,11 @@
import numpy as np
import paddle
-__all__ = ['KIEMetric']
+__all__ = ["KIEMetric"]
class VQASerTokenMetric(object):
- def __init__(self, main_indicator='hmean', **kwargs):
+ def __init__(self, main_indicator="hmean", **kwargs):
self.main_indicator = main_indicator
self.reset()
@@ -34,6 +34,7 @@ def __call__(self, preds, batch, **kwargs):
def get_metric(self):
from seqeval.metrics import f1_score, precision_score, recall_score
+
metrics = {
"precision": precision_score(self.gt_list, self.pred_list),
"recall": recall_score(self.gt_list, self.pred_list),
diff --git a/ppocr/modeling/architectures/__init__.py b/ppocr/modeling/architectures/__init__.py
index 00220d28de..560b6440aa 100755
--- a/ppocr/modeling/architectures/__init__.py
+++ b/ppocr/modeling/architectures/__init__.py
@@ -38,81 +38,77 @@ def build_model(config):
def apply_to_static(model, config, logger):
if config["Global"].get("to_static", False) is not True:
return model
- assert "d2s_train_image_shape" in config[
- "Global"], "d2s_train_image_shape must be assigned for static training mode..."
- supported_list = [
- "DB", "SVTR_LCNet", "TableMaster", "LayoutXLM", "SLANet", "SVTR"
- ]
+ assert (
+ "d2s_train_image_shape" in config["Global"]
+ ), "d2s_train_image_shape must be assigned for static training mode..."
+ supported_list = ["DB", "SVTR_LCNet", "TableMaster", "LayoutXLM", "SLANet", "SVTR"]
if config["Architecture"]["algorithm"] in ["Distillation"]:
algo = list(config["Architecture"]["Models"].values())[0]["algorithm"]
else:
algo = config["Architecture"]["algorithm"]
- assert algo in supported_list, f"algorithms that supports static training must in in {supported_list} but got {algo}"
+ assert (
+ algo in supported_list
+ ), f"algorithms that supports static training must in in {supported_list} but got {algo}"
specs = [
- InputSpec(
- [None] + config["Global"]["d2s_train_image_shape"], dtype='float32')
+ InputSpec([None] + config["Global"]["d2s_train_image_shape"], dtype="float32")
]
if algo == "SVTR_LCNet":
- specs.append([
- InputSpec(
- [None, config["Global"]["max_text_length"]],
- dtype='int64'), InputSpec(
- [None, config["Global"]["max_text_length"]], dtype='int64'),
- InputSpec(
- [None], dtype='int64'), InputSpec(
- [None], dtype='float64')
- ])
+ specs.append(
+ [
+ InputSpec([None, config["Global"]["max_text_length"]], dtype="int64"),
+ InputSpec([None, config["Global"]["max_text_length"]], dtype="int64"),
+ InputSpec([None], dtype="int64"),
+ InputSpec([None], dtype="float64"),
+ ]
+ )
elif algo == "TableMaster":
specs.append(
[
+ InputSpec([None, config["Global"]["max_text_length"]], dtype="int64"),
InputSpec(
- [None, config["Global"]["max_text_length"]], dtype='int64'),
- InputSpec(
- [None, config["Global"]["max_text_length"], 4],
- dtype='float32'),
+ [None, config["Global"]["max_text_length"], 4], dtype="float32"
+ ),
InputSpec(
- [None, config["Global"]["max_text_length"], 1],
- dtype='float32'),
- InputSpec(
- [None, 6], dtype='float32'),
- ])
+ [None, config["Global"]["max_text_length"], 1], dtype="float32"
+ ),
+ InputSpec([None, 6], dtype="float32"),
+ ]
+ )
elif algo == "LayoutXLM":
- specs = [[
- InputSpec(
- shape=[None, 512], dtype="int64"), # input_ids
- InputSpec(
- shape=[None, 512, 4], dtype="int64"), # bbox
- InputSpec(
- shape=[None, 512], dtype="int64"), # attention_mask
- InputSpec(
- shape=[None, 512], dtype="int64"), # token_type_ids
- InputSpec(
- shape=[None, 3, 224, 224], dtype="float32"), # image
- InputSpec(
- shape=[None, 512], dtype="int64"), # label
- ]]
+ specs = [
+ [
+ InputSpec(shape=[None, 512], dtype="int64"), # input_ids
+ InputSpec(shape=[None, 512, 4], dtype="int64"), # bbox
+ InputSpec(shape=[None, 512], dtype="int64"), # attention_mask
+ InputSpec(shape=[None, 512], dtype="int64"), # token_type_ids
+ InputSpec(shape=[None, 3, 224, 224], dtype="float32"), # image
+ InputSpec(shape=[None, 512], dtype="int64"), # label
+ ]
+ ]
elif algo == "SLANet":
- specs.append([
- InputSpec(
- [None, config["Global"]["max_text_length"] + 2], dtype='int64'),
- InputSpec(
- [None, config["Global"]["max_text_length"] + 2, 4],
- dtype='float32'),
- InputSpec(
- [None, config["Global"]["max_text_length"] + 2, 1],
- dtype='float32'),
- InputSpec(
- [None, 6], dtype='float64'),
- ])
+ specs.append(
+ [
+ InputSpec(
+ [None, config["Global"]["max_text_length"] + 2], dtype="int64"
+ ),
+ InputSpec(
+ [None, config["Global"]["max_text_length"] + 2, 4], dtype="float32"
+ ),
+ InputSpec(
+ [None, config["Global"]["max_text_length"] + 2, 1], dtype="float32"
+ ),
+ InputSpec([None, 6], dtype="float64"),
+ ]
+ )
elif algo == "SVTR":
- specs.append([
- InputSpec(
- [None, config["Global"]["max_text_length"]], dtype='int64'),
- InputSpec(
- [None], dtype='int64')
- ])
+ specs.append(
+ [
+ InputSpec([None, config["Global"]["max_text_length"]], dtype="int64"),
+ InputSpec([None], dtype="int64"),
+ ]
+ )
model = to_static(model, input_spec=specs)
logger.info("Successfully to apply @to_static with specs: {}".format(specs))
return model
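
For readers new to the dynamic-to-static conversion above: `paddle.static.InputSpec` declares symbolic shapes and dtypes (None for the batch axis), and `paddle.jit.to_static` traces the layer against them, which is why every supported algorithm must declare its extra label inputs explicitly. A minimal sketch with a toy layer (assumes PaddlePaddle is installed; the imports mirror the ones this file uses):

import paddle
import paddle.nn as nn
from paddle.jit import to_static
from paddle.static import InputSpec

class TinyNet(nn.Layer):
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2D(3, 8, 3, padding=1)

    def forward(self, x):
        return self.conv(x)

model = TinyNet()
# batch axis left as None so any batch size is accepted after conversion
specs = [InputSpec([None, 3, 64, 64], dtype="float32")]
static_model = to_static(model, input_spec=specs)
out = static_model(paddle.rand([2, 3, 64, 64]))
print(out.shape)  # [2, 8, 64, 64]
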
diff --git a/ppocr/modeling/architectures/base_model.py b/ppocr/modeling/architectures/base_model.py
index 5612d366ea..841312a813 100644
--- a/ppocr/modeling/architectures/base_model.py
+++ b/ppocr/modeling/architectures/base_model.py
@@ -21,7 +21,7 @@
from ppocr.modeling.necks import build_neck
from ppocr.modeling.heads import build_head
-__all__ = ['BaseModel']
+__all__ = ["BaseModel"]
class BaseModel(nn.Layer):
@@ -32,26 +32,26 @@ def __init__(self, config):
config (dict): the super parameters for module.
"""
super(BaseModel, self).__init__()
- in_channels = config.get('in_channels', 3)
- model_type = config['model_type']
+ in_channels = config.get("in_channels", 3)
+ model_type = config["model_type"]
# build transform,
# for rec, transform can be TPS or None
# for det and cls, transform should be None,
# if you design the model differently, you can use transform in det and cls too
- if 'Transform' not in config or config['Transform'] is None:
+ if "Transform" not in config or config["Transform"] is None:
self.use_transform = False
else:
self.use_transform = True
- config['Transform']['in_channels'] = in_channels
- self.transform = build_transform(config['Transform'])
+ config["Transform"]["in_channels"] = in_channels
+ self.transform = build_transform(config["Transform"])
in_channels = self.transform.out_channels
# build backbone, backbone is needed for det, rec and cls
- if 'Backbone' not in config or config['Backbone'] is None:
+ if "Backbone" not in config or config["Backbone"] is None:
self.use_backbone = False
else:
self.use_backbone = True
- config["Backbone"]['in_channels'] = in_channels
+ config["Backbone"]["in_channels"] = in_channels
self.backbone = build_backbone(config["Backbone"], model_type)
in_channels = self.backbone.out_channels
@@ -59,26 +59,25 @@ def __init__(self, config):
# for rec, neck can be cnn,rnn or reshape(None)
# for det, neck can be FPN, BIFPN and so on.
# for cls, neck should be none
- if 'Neck' not in config or config['Neck'] is None:
+ if "Neck" not in config or config["Neck"] is None:
self.use_neck = False
else:
self.use_neck = True
- config['Neck']['in_channels'] = in_channels
- self.neck = build_neck(config['Neck'])
+ config["Neck"]["in_channels"] = in_channels
+ self.neck = build_neck(config["Neck"])
in_channels = self.neck.out_channels
# build head, head is needed for det, rec and cls
- if 'Head' not in config or config['Head'] is None:
+ if "Head" not in config or config["Head"] is None:
self.use_head = False
else:
self.use_head = True
- config["Head"]['in_channels'] = in_channels
+ config["Head"]["in_channels"] = in_channels
self.head = build_head(config["Head"])
self.return_all_feats = config.get("return_all_feats", False)
def forward(self, x, data=None):
-
y = dict()
if self.use_transform:
x = self.transform(x)
@@ -99,7 +98,7 @@ def forward(self, x, data=None):
if self.use_head:
x = self.head(x, targets=data)
# for multi head, save ctc neck out for udml
- if isinstance(x, dict) and 'ctc_neck' in x.keys():
+ if isinstance(x, dict) and "ctc_neck" in x.keys():
y["neck_out"] = x["ctc_neck"]
y["head_out"] = x
elif isinstance(x, dict):
@@ -115,4 +114,4 @@ def forward(self, x, data=None):
else:
return {final_name: x}
else:
- return x
\ No newline at end of file
+ return x
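
The pattern worth noting in `BaseModel` above is how `in_channels` is threaded through the config: each optional stage (transform, backbone, neck, head) reads the previous stage's `out_channels` before being built. A framework-free sketch of that threading, using a hypothetical stub module in place of the real builders:

class Stub:
    # hypothetical stand-in for a stage: doubles the channel count
    def __init__(self, in_channels):
        self.out_channels = in_channels * 2

def build_model(config):
    in_channels = config.get("in_channels", 3)
    stages = []
    for key in ("Transform", "Backbone", "Neck", "Head"):
        if config.get(key) is None:
            continue  # stage is optional; skip and keep the channel count
        config[key]["in_channels"] = in_channels
        stage = Stub(**config[key])
        in_channels = stage.out_channels  # thread channels to the next stage
        stages.append((key, stage))
    return stages

model = build_model({"Transform": None, "Backbone": {}, "Neck": {}, "Head": {}})
print([(k, s.out_channels) for k, s in model])
# [('Backbone', 6), ('Neck', 12), ('Head', 24)]
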
diff --git a/ppocr/modeling/architectures/distillation_model.py b/ppocr/modeling/architectures/distillation_model.py
index cce8fd311d..2309d2eac5 100644
--- a/ppocr/modeling/architectures/distillation_model.py
+++ b/ppocr/modeling/architectures/distillation_model.py
@@ -23,7 +23,7 @@
from .base_model import BaseModel
from ppocr.utils.save_load import load_pretrained_params
-__all__ = ['DistillationModel']
+__all__ = ["DistillationModel"]
class DistillationModel(nn.Layer):
diff --git a/ppocr/modeling/backbones/__init__.py b/ppocr/modeling/backbones/__init__.py
index 10839b82b7..ce80afd109 100755
--- a/ppocr/modeling/backbones/__init__.py
+++ b/ppocr/modeling/backbones/__init__.py
@@ -25,13 +25,20 @@ def build_backbone(config, model_type):
from .rec_lcnetv3 import PPLCNetV3
from .rec_hgnet import PPHGNet_small
from .rec_vit import ViT
+
support_dict = [
- "MobileNetV3", "ResNet", "ResNet_vd", "ResNet_SAST", "PPLCNet",
- "PPLCNetV3", "PPHGNet_small"
+ "MobileNetV3",
+ "ResNet",
+ "ResNet_vd",
+ "ResNet_SAST",
+ "PPLCNet",
+ "PPLCNetV3",
+ "PPHGNet_small",
]
if model_type == "table":
from .table_master_resnet import TableResNetExtra
- support_dict.append('TableResNetExtra')
+
+ support_dict.append("TableResNetExtra")
elif model_type == "rec" or model_type == "cls":
from .rec_mobilenet_v3 import MobileNetV3
from .rec_resnet_vd import ResNet
@@ -52,32 +59,64 @@ def build_backbone(config, model_type):
from .rec_lcnetv3 import PPLCNetV3
from .rec_hgnet import PPHGNet_small
from .rec_vit_parseq import ViTParseQ
+
support_dict = [
- 'MobileNetV1Enhance', 'MobileNetV3', 'ResNet', 'ResNetFPN', 'MTB',
- 'ResNet31', 'ResNet45', 'ResNet_ASTER', 'MicroNet',
- 'EfficientNetb3_PREN', 'SVTRNet', 'ViTSTR', 'ResNet32', 'ResNetRFL',
- 'DenseNet', 'ShallowCNN', 'PPLCNetV3', 'PPHGNet_small', 'ViTParseQ', 'ViT'
+ "MobileNetV1Enhance",
+ "MobileNetV3",
+ "ResNet",
+ "ResNetFPN",
+ "MTB",
+ "ResNet31",
+ "ResNet45",
+ "ResNet_ASTER",
+ "MicroNet",
+ "EfficientNetb3_PREN",
+ "SVTRNet",
+ "ViTSTR",
+ "ResNet32",
+ "ResNetRFL",
+ "DenseNet",
+ "ShallowCNN",
+ "PPLCNetV3",
+ "PPHGNet_small",
+ "ViTParseQ",
+ "ViT",
]
- elif model_type == 'e2e':
+ elif model_type == "e2e":
from .e2e_resnet_vd_pg import ResNet
- support_dict = ['ResNet']
- elif model_type == 'kie':
+
+ support_dict = ["ResNet"]
+ elif model_type == "kie":
from .kie_unet_sdmgr import Kie_backbone
- from .vqa_layoutlm import LayoutLMForSer, LayoutLMv2ForSer, LayoutLMv2ForRe, LayoutXLMForSer, LayoutXLMForRe
+ from .vqa_layoutlm import (
+ LayoutLMForSer,
+ LayoutLMv2ForSer,
+ LayoutLMv2ForRe,
+ LayoutXLMForSer,
+ LayoutXLMForRe,
+ )
+
support_dict = [
- 'Kie_backbone', 'LayoutLMForSer', 'LayoutLMv2ForSer',
- 'LayoutLMv2ForRe', 'LayoutXLMForSer', 'LayoutXLMForRe'
+ "Kie_backbone",
+ "LayoutLMForSer",
+ "LayoutLMv2ForSer",
+ "LayoutLMv2ForRe",
+ "LayoutXLMForSer",
+ "LayoutXLMForRe",
]
- elif model_type == 'table':
+ elif model_type == "table":
from .table_resnet_vd import ResNet
from .table_mobilenet_v3 import MobileNetV3
- support_dict = ['ResNet', 'MobileNetV3']
+
+ support_dict = ["ResNet", "MobileNetV3"]
else:
raise NotImplementedError
- module_name = config.pop('name')
+ module_name = config.pop("name")
assert module_name in support_dict, Exception(
- "when model typs is {}, backbone only support {}".format(model_type,
- support_dict))
+ "when model typs is {}, backbone only support {}".format(
+ model_type, support_dict
+ )
+ )
module_class = eval(module_name)(**config)
return module_class
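
`eval(module_name)(**config)` above resolves the backbone class from the YAML `name` string. An explicit name-to-class registry performs the same dispatch without evaluating arbitrary strings, which is the usual safer alternative; a sketch with hypothetical placeholder classes:

class MobileNetV3:
    def __init__(self, in_channels=3, **kwargs):
        self.in_channels = in_channels

class ResNet:
    def __init__(self, in_channels=3, layers=50, **kwargs):
        self.in_channels, self.layers = in_channels, layers

BACKBONES = {"MobileNetV3": MobileNetV3, "ResNet": ResNet}

def build_backbone(config):
    name = config.pop("name")
    if name not in BACKBONES:
        raise ValueError(
            "backbone only supports {}, got {}".format(list(BACKBONES), name)
        )
    return BACKBONES[name](**config)

backbone = build_backbone({"name": "ResNet", "layers": 18})
print(type(backbone).__name__, backbone.layers)  # ResNet 18
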
diff --git a/ppocr/modeling/backbones/det_mobilenet_v3.py b/ppocr/modeling/backbones/det_mobilenet_v3.py
index 05113ea841..98db44b691 100755
--- a/ppocr/modeling/backbones/det_mobilenet_v3.py
+++ b/ppocr/modeling/backbones/det_mobilenet_v3.py
@@ -21,7 +21,7 @@
import paddle.nn.functional as F
from paddle import ParamAttr
-__all__ = ['MobileNetV3']
+__all__ = ["MobileNetV3"]
def make_divisible(v, divisor=8, min_value=None):
@@ -34,12 +34,9 @@ def make_divisible(v, divisor=8, min_value=None):
class MobileNetV3(nn.Layer):
- def __init__(self,
- in_channels=3,
- model_name='large',
- scale=0.5,
- disable_se=False,
- **kwargs):
+ def __init__(
+ self, in_channels=3, model_name="large", scale=0.5, disable_se=False, **kwargs
+ ):
"""
the MobilenetV3 backbone network for detection module.
Args:
@@ -52,46 +49,48 @@ def __init__(self,
if model_name == "large":
cfg = [
# k, exp, c, se, nl, s,
- [3, 16, 16, False, 'relu', 1],
- [3, 64, 24, False, 'relu', 2],
- [3, 72, 24, False, 'relu', 1],
- [5, 72, 40, True, 'relu', 2],
- [5, 120, 40, True, 'relu', 1],
- [5, 120, 40, True, 'relu', 1],
- [3, 240, 80, False, 'hardswish', 2],
- [3, 200, 80, False, 'hardswish', 1],
- [3, 184, 80, False, 'hardswish', 1],
- [3, 184, 80, False, 'hardswish', 1],
- [3, 480, 112, True, 'hardswish', 1],
- [3, 672, 112, True, 'hardswish', 1],
- [5, 672, 160, True, 'hardswish', 2],
- [5, 960, 160, True, 'hardswish', 1],
- [5, 960, 160, True, 'hardswish', 1],
+ [3, 16, 16, False, "relu", 1],
+ [3, 64, 24, False, "relu", 2],
+ [3, 72, 24, False, "relu", 1],
+ [5, 72, 40, True, "relu", 2],
+ [5, 120, 40, True, "relu", 1],
+ [5, 120, 40, True, "relu", 1],
+ [3, 240, 80, False, "hardswish", 2],
+ [3, 200, 80, False, "hardswish", 1],
+ [3, 184, 80, False, "hardswish", 1],
+ [3, 184, 80, False, "hardswish", 1],
+ [3, 480, 112, True, "hardswish", 1],
+ [3, 672, 112, True, "hardswish", 1],
+ [5, 672, 160, True, "hardswish", 2],
+ [5, 960, 160, True, "hardswish", 1],
+ [5, 960, 160, True, "hardswish", 1],
]
cls_ch_squeeze = 960
elif model_name == "small":
cfg = [
# k, exp, c, se, nl, s,
- [3, 16, 16, True, 'relu', 2],
- [3, 72, 24, False, 'relu', 2],
- [3, 88, 24, False, 'relu', 1],
- [5, 96, 40, True, 'hardswish', 2],
- [5, 240, 40, True, 'hardswish', 1],
- [5, 240, 40, True, 'hardswish', 1],
- [5, 120, 48, True, 'hardswish', 1],
- [5, 144, 48, True, 'hardswish', 1],
- [5, 288, 96, True, 'hardswish', 2],
- [5, 576, 96, True, 'hardswish', 1],
- [5, 576, 96, True, 'hardswish', 1],
+ [3, 16, 16, True, "relu", 2],
+ [3, 72, 24, False, "relu", 2],
+ [3, 88, 24, False, "relu", 1],
+ [5, 96, 40, True, "hardswish", 2],
+ [5, 240, 40, True, "hardswish", 1],
+ [5, 240, 40, True, "hardswish", 1],
+ [5, 120, 48, True, "hardswish", 1],
+ [5, 144, 48, True, "hardswish", 1],
+ [5, 288, 96, True, "hardswish", 2],
+ [5, 576, 96, True, "hardswish", 1],
+ [5, 576, 96, True, "hardswish", 1],
]
cls_ch_squeeze = 576
else:
- raise NotImplementedError("mode[" + model_name +
- "_model] is not implemented!")
+ raise NotImplementedError(
+ "mode[" + model_name + "_model] is not implemented!"
+ )
supported_scale = [0.35, 0.5, 0.75, 1.0, 1.25]
- assert scale in supported_scale, \
- "supported scale are {} but input scale is {}".format(supported_scale, scale)
+ assert (
+ scale in supported_scale
+ ), "supported scale are {} but input scale is {}".format(supported_scale, scale)
inplanes = 16
# conv1
self.conv = ConvBNLayer(
@@ -102,16 +101,17 @@ def __init__(self,
padding=1,
groups=1,
if_act=True,
- act='hardswish')
+ act="hardswish",
+ )
self.stages = []
self.out_channels = []
block_list = []
i = 0
inplanes = make_divisible(inplanes * scale)
- for (k, exp, c, se, nl, s) in cfg:
+ for k, exp, c, se, nl, s in cfg:
se = se and not self.disable_se
- start_idx = 2 if model_name == 'large' else 0
+ start_idx = 2 if model_name == "large" else 0
if s == 2 and i > start_idx:
self.out_channels.append(inplanes)
self.stages.append(nn.Sequential(*block_list))
@@ -124,7 +124,9 @@ def __init__(self,
kernel_size=k,
stride=s,
use_se=se,
- act=nl))
+ act=nl,
+ )
+ )
inplanes = make_divisible(scale * c)
i += 1
block_list.append(
@@ -136,7 +138,9 @@ def __init__(self,
padding=0,
groups=1,
if_act=True,
- act='hardswish'))
+ act="hardswish",
+ )
+ )
self.stages.append(nn.Sequential(*block_list))
self.out_channels.append(make_divisible(scale * cls_ch_squeeze))
for i, stage in enumerate(self.stages):
@@ -152,15 +156,17 @@ def forward(self, x):
class ConvBNLayer(nn.Layer):
- def __init__(self,
- in_channels,
- out_channels,
- kernel_size,
- stride,
- padding,
- groups=1,
- if_act=True,
- act=None):
+ def __init__(
+ self,
+ in_channels,
+ out_channels,
+ kernel_size,
+ stride,
+ padding,
+ groups=1,
+ if_act=True,
+ act=None,
+ ):
super(ConvBNLayer, self).__init__()
self.if_act = if_act
self.act = act
@@ -171,7 +177,8 @@ def __init__(self,
stride=stride,
padding=padding,
groups=groups,
- bias_attr=False)
+ bias_attr=False,
+ )
self.bn = nn.BatchNorm(num_channels=out_channels, act=None)
@@ -184,21 +191,26 @@ def forward(self, x):
elif self.act == "hardswish":
x = F.hardswish(x)
else:
- print("The activation function({}) is selected incorrectly.".
- format(self.act))
+ print(
+ "The activation function({}) is selected incorrectly.".format(
+ self.act
+ )
+ )
exit()
return x
class ResidualUnit(nn.Layer):
- def __init__(self,
- in_channels,
- mid_channels,
- out_channels,
- kernel_size,
- stride,
- use_se,
- act=None):
+ def __init__(
+ self,
+ in_channels,
+ mid_channels,
+ out_channels,
+ kernel_size,
+ stride,
+ use_se,
+ act=None,
+ ):
super(ResidualUnit, self).__init__()
self.if_shortcut = stride == 1 and in_channels == out_channels
self.if_se = use_se
@@ -210,7 +222,8 @@ def __init__(self,
stride=1,
padding=0,
if_act=True,
- act=act)
+ act=act,
+ )
self.bottleneck_conv = ConvBNLayer(
in_channels=mid_channels,
out_channels=mid_channels,
@@ -219,7 +232,8 @@ def __init__(self,
padding=int((kernel_size - 1) // 2),
groups=mid_channels,
if_act=True,
- act=act)
+ act=act,
+ )
if self.if_se:
self.mid_se = SEModule(mid_channels)
self.linear_conv = ConvBNLayer(
@@ -229,7 +243,8 @@ def __init__(self,
stride=1,
padding=0,
if_act=False,
- act=None)
+ act=None,
+ )
def forward(self, inputs):
x = self.expand_conv(inputs)
@@ -251,13 +266,15 @@ def __init__(self, in_channels, reduction=4):
out_channels=in_channels // reduction,
kernel_size=1,
stride=1,
- padding=0)
+ padding=0,
+ )
self.conv2 = nn.Conv2D(
in_channels=in_channels // reduction,
out_channels=in_channels,
kernel_size=1,
stride=1,
- padding=0)
+ padding=0,
+ )
def forward(self, inputs):
outputs = self.avg_pool(inputs)
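
A note on `make_divisible`, used throughout the scaled channel counts above: MobileNet-style networks round every scaled width to a multiple of the divisor (8 here) so tensors stay hardware-friendly, never dropping more than about 10% below the requested value. The body is elided in this diff, so treat the following as the conventional definition rather than a verbatim copy:

def make_divisible(v, divisor=8, min_value=None):
    if min_value is None:
        min_value = divisor
    # round to the nearest multiple of `divisor`, with a floor of `min_value`
    new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)
    # never round down by more than ~10%
    if new_v < 0.9 * v:
        new_v += divisor
    return new_v

for c in [16, 24, 40, 80, 112, 160]:
    print(c, "->", make_divisible(c * 0.5))
# 16 -> 8, 24 -> 16, 40 -> 24, 80 -> 40, 112 -> 56, 160 -> 80
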
diff --git a/ppocr/modeling/backbones/det_pp_lcnet.py b/ppocr/modeling/backbones/det_pp_lcnet.py
index 3f719e92bc..bf557a480d 100644
--- a/ppocr/modeling/backbones/det_pp_lcnet.py
+++ b/ppocr/modeling/backbones/det_pp_lcnet.py
@@ -24,22 +24,14 @@
from paddle.utils.download import get_path_from_url
MODEL_URLS = {
- "PPLCNet_x0.25":
- "https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/legendary_models/PPLCNet_x0_25_pretrained.pdparams",
- "PPLCNet_x0.35":
- "https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/legendary_models/PPLCNet_x0_35_pretrained.pdparams",
- "PPLCNet_x0.5":
- "https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/legendary_models/PPLCNet_x0_5_pretrained.pdparams",
- "PPLCNet_x0.75":
- "https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/legendary_models/PPLCNet_x0_75_pretrained.pdparams",
- "PPLCNet_x1.0":
- "https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/legendary_models/PPLCNet_x1_0_pretrained.pdparams",
- "PPLCNet_x1.5":
- "https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/legendary_models/PPLCNet_x1_5_pretrained.pdparams",
- "PPLCNet_x2.0":
- "https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/legendary_models/PPLCNet_x2_0_pretrained.pdparams",
- "PPLCNet_x2.5":
- "https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/legendary_models/PPLCNet_x2_5_pretrained.pdparams"
+ "PPLCNet_x0.25": "https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/legendary_models/PPLCNet_x0_25_pretrained.pdparams",
+ "PPLCNet_x0.35": "https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/legendary_models/PPLCNet_x0_35_pretrained.pdparams",
+ "PPLCNet_x0.5": "https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/legendary_models/PPLCNet_x0_5_pretrained.pdparams",
+ "PPLCNet_x0.75": "https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/legendary_models/PPLCNet_x0_75_pretrained.pdparams",
+ "PPLCNet_x1.0": "https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/legendary_models/PPLCNet_x1_0_pretrained.pdparams",
+ "PPLCNet_x1.5": "https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/legendary_models/PPLCNet_x1_5_pretrained.pdparams",
+ "PPLCNet_x2.0": "https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/legendary_models/PPLCNet_x2_0_pretrained.pdparams",
+ "PPLCNet_x2.5": "https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/legendary_models/PPLCNet_x2_5_pretrained.pdparams",
}
MODEL_STAGES_PATTERN = {
@@ -61,10 +53,15 @@
[[3, 16, 32, 1, False]],
"blocks3": [[3, 32, 64, 2, False], [3, 64, 64, 1, False]],
"blocks4": [[3, 64, 128, 2, False], [3, 128, 128, 1, False]],
- "blocks5":
- [[3, 128, 256, 2, False], [5, 256, 256, 1, False], [5, 256, 256, 1, False],
- [5, 256, 256, 1, False], [5, 256, 256, 1, False], [5, 256, 256, 1, False]],
- "blocks6": [[5, 256, 512, 2, True], [5, 512, 512, 1, True]]
+ "blocks5": [
+ [3, 128, 256, 2, False],
+ [5, 256, 256, 1, False],
+ [5, 256, 256, 1, False],
+ [5, 256, 256, 1, False],
+ [5, 256, 256, 1, False],
+ [5, 256, 256, 1, False],
+ ],
+ "blocks6": [[5, 256, 512, 2, True], [5, 512, 512, 1, True]],
}
@@ -78,12 +75,7 @@ def make_divisible(v, divisor=8, min_value=None):
class ConvBNLayer(nn.Layer):
- def __init__(self,
- num_channels,
- filter_size,
- num_filters,
- stride,
- num_groups=1):
+ def __init__(self, num_channels, filter_size, num_filters, stride, num_groups=1):
super().__init__()
self.conv = Conv2D(
@@ -94,12 +86,14 @@ def __init__(self,
padding=(filter_size - 1) // 2,
groups=num_groups,
weight_attr=ParamAttr(initializer=KaimingNormal()),
- bias_attr=False)
+ bias_attr=False,
+ )
self.bn = BatchNorm(
num_filters,
param_attr=ParamAttr(regularizer=L2Decay(0.0)),
- bias_attr=ParamAttr(regularizer=L2Decay(0.0)))
+ bias_attr=ParamAttr(regularizer=L2Decay(0.0)),
+ )
self.hardswish = nn.Hardswish()
def forward(self, x):
@@ -110,12 +104,7 @@ def forward(self, x):
class DepthwiseSeparable(nn.Layer):
- def __init__(self,
- num_channels,
- num_filters,
- stride,
- dw_size=3,
- use_se=False):
+ def __init__(self, num_channels, num_filters, stride, dw_size=3, use_se=False):
super().__init__()
self.use_se = use_se
self.dw_conv = ConvBNLayer(
@@ -123,14 +112,13 @@ def __init__(self,
num_filters=num_channels,
filter_size=dw_size,
stride=stride,
- num_groups=num_channels)
+ num_groups=num_channels,
+ )
if use_se:
self.se = SEModule(num_channels)
self.pw_conv = ConvBNLayer(
- num_channels=num_channels,
- filter_size=1,
- num_filters=num_filters,
- stride=1)
+ num_channels=num_channels, filter_size=1, num_filters=num_filters, stride=1
+ )
def forward(self, x):
x = self.dw_conv(x)
@@ -149,14 +137,16 @@ def __init__(self, channel, reduction=4):
out_channels=channel // reduction,
kernel_size=1,
stride=1,
- padding=0)
+ padding=0,
+ )
self.relu = nn.ReLU()
self.conv2 = Conv2D(
in_channels=channel // reduction,
out_channels=channel,
kernel_size=1,
stride=1,
- padding=0)
+ padding=0,
+ )
self.hardsigmoid = nn.Hardsigmoid()
def forward(self, x):
@@ -171,17 +161,13 @@ def forward(self, x):
class PPLCNet(nn.Layer):
- def __init__(self,
- in_channels=3,
- scale=1.0,
- pretrained=False,
- use_ssld=False):
+ def __init__(self, in_channels=3, scale=1.0, pretrained=False, use_ssld=False):
super().__init__()
self.out_channels = [
int(NET_CONFIG["blocks3"][-1][2] * scale),
int(NET_CONFIG["blocks4"][-1][2] * scale),
int(NET_CONFIG["blocks5"][-1][2] * scale),
- int(NET_CONFIG["blocks6"][-1][2] * scale)
+ int(NET_CONFIG["blocks6"][-1][2] * scale),
]
self.scale = scale
@@ -189,61 +175,78 @@ def __init__(self,
num_channels=in_channels,
filter_size=3,
num_filters=make_divisible(16 * scale),
- stride=2)
+ stride=2,
+ )
- self.blocks2 = nn.Sequential(* [
- DepthwiseSeparable(
- num_channels=make_divisible(in_c * scale),
- num_filters=make_divisible(out_c * scale),
- dw_size=k,
- stride=s,
- use_se=se)
- for i, (k, in_c, out_c, s, se) in enumerate(NET_CONFIG["blocks2"])
- ])
+ self.blocks2 = nn.Sequential(
+ *[
+ DepthwiseSeparable(
+ num_channels=make_divisible(in_c * scale),
+ num_filters=make_divisible(out_c * scale),
+ dw_size=k,
+ stride=s,
+ use_se=se,
+ )
+ for i, (k, in_c, out_c, s, se) in enumerate(NET_CONFIG["blocks2"])
+ ]
+ )
- self.blocks3 = nn.Sequential(* [
- DepthwiseSeparable(
- num_channels=make_divisible(in_c * scale),
- num_filters=make_divisible(out_c * scale),
- dw_size=k,
- stride=s,
- use_se=se)
- for i, (k, in_c, out_c, s, se) in enumerate(NET_CONFIG["blocks3"])
- ])
+ self.blocks3 = nn.Sequential(
+ *[
+ DepthwiseSeparable(
+ num_channels=make_divisible(in_c * scale),
+ num_filters=make_divisible(out_c * scale),
+ dw_size=k,
+ stride=s,
+ use_se=se,
+ )
+ for i, (k, in_c, out_c, s, se) in enumerate(NET_CONFIG["blocks3"])
+ ]
+ )
- self.blocks4 = nn.Sequential(* [
- DepthwiseSeparable(
- num_channels=make_divisible(in_c * scale),
- num_filters=make_divisible(out_c * scale),
- dw_size=k,
- stride=s,
- use_se=se)
- for i, (k, in_c, out_c, s, se) in enumerate(NET_CONFIG["blocks4"])
- ])
+ self.blocks4 = nn.Sequential(
+ *[
+ DepthwiseSeparable(
+ num_channels=make_divisible(in_c * scale),
+ num_filters=make_divisible(out_c * scale),
+ dw_size=k,
+ stride=s,
+ use_se=se,
+ )
+ for i, (k, in_c, out_c, s, se) in enumerate(NET_CONFIG["blocks4"])
+ ]
+ )
- self.blocks5 = nn.Sequential(* [
- DepthwiseSeparable(
- num_channels=make_divisible(in_c * scale),
- num_filters=make_divisible(out_c * scale),
- dw_size=k,
- stride=s,
- use_se=se)
- for i, (k, in_c, out_c, s, se) in enumerate(NET_CONFIG["blocks5"])
- ])
+ self.blocks5 = nn.Sequential(
+ *[
+ DepthwiseSeparable(
+ num_channels=make_divisible(in_c * scale),
+ num_filters=make_divisible(out_c * scale),
+ dw_size=k,
+ stride=s,
+ use_se=se,
+ )
+ for i, (k, in_c, out_c, s, se) in enumerate(NET_CONFIG["blocks5"])
+ ]
+ )
- self.blocks6 = nn.Sequential(* [
- DepthwiseSeparable(
- num_channels=make_divisible(in_c * scale),
- num_filters=make_divisible(out_c * scale),
- dw_size=k,
- stride=s,
- use_se=se)
- for i, (k, in_c, out_c, s, se) in enumerate(NET_CONFIG["blocks6"])
- ])
+ self.blocks6 = nn.Sequential(
+ *[
+ DepthwiseSeparable(
+ num_channels=make_divisible(in_c * scale),
+ num_filters=make_divisible(out_c * scale),
+ dw_size=k,
+ stride=s,
+ use_se=se,
+ )
+ for i, (k, in_c, out_c, s, se) in enumerate(NET_CONFIG["blocks6"])
+ ]
+ )
if pretrained:
self._load_pretrained(
- MODEL_URLS['PPLCNet_x{}'.format(scale)], use_ssld=use_ssld)
+ MODEL_URLS["PPLCNet_x{}".format(scale)], use_ssld=use_ssld
+ )
def forward(self, x):
outs = []
@@ -261,11 +264,11 @@ def forward(self, x):
def _load_pretrained(self, pretrained_url, use_ssld=False):
if use_ssld:
- pretrained_url = pretrained_url.replace("_pretrained",
- "_ssld_pretrained")
+ pretrained_url = pretrained_url.replace("_pretrained", "_ssld_pretrained")
print(pretrained_url)
local_weight_path = get_path_from_url(
- pretrained_url, os.path.expanduser("~/.paddleclas/weights"))
+ pretrained_url, os.path.expanduser("~/.paddleclas/weights")
+ )
param_state_dict = paddle.load(local_weight_path)
self.set_dict(param_state_dict)
return
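
The `DepthwiseSeparable` block above is PP-LCNet's basic unit: a per-channel k x k depthwise conv followed by a 1 x 1 pointwise conv, which replaces a dense k x k conv at a fraction of the parameter cost. A quick count of that saving:

def conv_params(k, c_in, c_out):
    return k * k * c_in * c_out          # dense conv, ignoring bias

def dw_separable_params(k, c_in, c_out):
    depthwise = k * k * c_in             # one k x k filter per input channel
    pointwise = 1 * 1 * c_in * c_out     # channel mixing
    return depthwise + pointwise

k, c_in, c_out = 3, 128, 256
dense = conv_params(k, c_in, c_out)
sep = dw_separable_params(k, c_in, c_out)
print(dense, sep, round(dense / sep, 1))  # 294912 33920 8.7 (x fewer parameters)
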
diff --git a/ppocr/modeling/backbones/det_resnet.py b/ppocr/modeling/backbones/det_resnet.py
index 87eef11cf0..ff059610cd 100644
--- a/ppocr/modeling/backbones/det_resnet.py
+++ b/ppocr/modeling/backbones/det_resnet.py
@@ -34,19 +34,15 @@
class BottleneckBlock(nn.Layer):
- def __init__(self,
- num_channels,
- num_filters,
- stride,
- shortcut=True,
- is_dcn=False):
+ def __init__(self, num_channels, num_filters, stride, shortcut=True, is_dcn=False):
super(BottleneckBlock, self).__init__()
self.conv0 = ConvBNLayer(
in_channels=num_channels,
out_channels=num_filters,
kernel_size=1,
- act="relu", )
+ act="relu",
+ )
self.conv1 = ConvBNLayer(
in_channels=num_filters,
out_channels=num_filters,
@@ -54,19 +50,22 @@ def __init__(self,
stride=stride,
act="relu",
is_dcn=is_dcn,
- dcn_groups=1, )
+ dcn_groups=1,
+ )
self.conv2 = ConvBNLayer(
in_channels=num_filters,
out_channels=num_filters * 4,
kernel_size=1,
- act=None, )
+ act=None,
+ )
if not shortcut:
self.short = ConvBNLayer(
in_channels=num_channels,
out_channels=num_filters * 4,
kernel_size=1,
- stride=stride, )
+ stride=stride,
+ )
self.shortcut = shortcut
@@ -88,12 +87,7 @@ def forward(self, inputs):
class BasicBlock(nn.Layer):
- def __init__(self,
- num_channels,
- num_filters,
- stride,
- shortcut=True,
- name=None):
+ def __init__(self, num_channels, num_filters, stride, shortcut=True, name=None):
super(BasicBlock, self).__init__()
self.stride = stride
self.conv0 = ConvBNLayer(
@@ -101,19 +95,19 @@ def __init__(self,
out_channels=num_filters,
kernel_size=3,
stride=stride,
- act="relu")
+ act="relu",
+ )
self.conv1 = ConvBNLayer(
- in_channels=num_filters,
- out_channels=num_filters,
- kernel_size=3,
- act=None)
+ in_channels=num_filters, out_channels=num_filters, kernel_size=3, act=None
+ )
if not shortcut:
self.short = ConvBNLayer(
in_channels=num_channels,
out_channels=num_filters,
kernel_size=1,
- stride=stride)
+ stride=stride,
+ )
self.shortcut = shortcut
@@ -131,20 +125,18 @@ def forward(self, inputs):
class ResNet(nn.Layer):
- def __init__(self,
- in_channels=3,
- layers=50,
- out_indices=None,
- dcn_stage=None):
+ def __init__(self, in_channels=3, layers=50, out_indices=None, dcn_stage=None):
super(ResNet, self).__init__()
self.layers = layers
self.input_image_channel = in_channels
supported_layers = [18, 34, 50, 101, 152]
- assert layers in supported_layers, \
- "supported layers are {} but input layer is {}".format(
- supported_layers, layers)
+ assert (
+ layers in supported_layers
+ ), "supported layers are {} but input layer is {}".format(
+ supported_layers, layers
+ )
if layers == 18:
depth = [2, 2, 2, 2]
@@ -154,27 +146,26 @@ def __init__(self,
depth = [3, 4, 23, 3]
elif layers == 152:
depth = [3, 8, 36, 3]
- num_channels = [64, 256, 512,
- 1024] if layers >= 50 else [64, 64, 128, 256]
+ num_channels = [64, 256, 512, 1024] if layers >= 50 else [64, 64, 128, 256]
num_filters = [64, 128, 256, 512]
- self.dcn_stage = dcn_stage if dcn_stage is not None else [
- False, False, False, False
- ]
- self.out_indices = out_indices if out_indices is not None else [
- 0, 1, 2, 3
- ]
+ self.dcn_stage = (
+ dcn_stage if dcn_stage is not None else [False, False, False, False]
+ )
+ self.out_indices = out_indices if out_indices is not None else [0, 1, 2, 3]
self.conv = ConvBNLayer(
in_channels=self.input_image_channel,
out_channels=64,
kernel_size=7,
stride=2,
- act="relu", )
+ act="relu",
+ )
self.pool2d_max = MaxPool2D(
kernel_size=3,
stride=2,
- padding=1, )
+ padding=1,
+ )
self.stages = []
self.out_channels = []
@@ -195,11 +186,14 @@ def __init__(self,
conv_name,
BottleneckBlock(
num_channels=num_channels[block]
- if i == 0 else num_filters[block] * 4,
+ if i == 0
+ else num_filters[block] * 4,
num_filters=num_filters[block],
stride=2 if i == 0 and block != 0 else 1,
shortcut=shortcut,
- is_dcn=is_dcn))
+ is_dcn=is_dcn,
+ ),
+ )
block_list.append(bottleneck_block)
shortcut = True
if block in self.out_indices:
@@ -215,10 +209,13 @@ def __init__(self,
conv_name,
BasicBlock(
num_channels=num_channels[block]
- if i == 0 else num_filters[block],
+ if i == 0
+ else num_filters[block],
num_filters=num_filters[block],
stride=2 if i == 0 and block != 0 else 1,
- shortcut=shortcut))
+ shortcut=shortcut,
+ ),
+ )
block_list.append(basic_block)
shortcut = True
if block in self.out_indices:
diff --git a/ppocr/modeling/backbones/det_resnet_vd.py b/ppocr/modeling/backbones/det_resnet_vd.py
index a421da0ab4..1d26d5789d 100644
--- a/ppocr/modeling/backbones/det_resnet_vd.py
+++ b/ppocr/modeling/backbones/det_resnet_vd.py
@@ -29,21 +29,23 @@
class DeformableConvV2(nn.Layer):
- def __init__(self,
- in_channels,
- out_channels,
- kernel_size,
- stride=1,
- padding=0,
- dilation=1,
- groups=1,
- weight_attr=None,
- bias_attr=None,
- lr_scale=1,
- regularizer=None,
- skip_quant=False,
- dcn_bias_regularizer=L2Decay(0.),
- dcn_bias_lr_scale=2.):
+ def __init__(
+ self,
+ in_channels,
+ out_channels,
+ kernel_size,
+ stride=1,
+ padding=0,
+ dilation=1,
+ groups=1,
+ weight_attr=None,
+ bias_attr=None,
+ lr_scale=1,
+ regularizer=None,
+ skip_quant=False,
+ dcn_bias_regularizer=L2Decay(0.0),
+ dcn_bias_lr_scale=2.0,
+ ):
super(DeformableConvV2, self).__init__()
self.offset_channel = 2 * kernel_size**2 * groups
self.mask_channel = kernel_size**2 * groups
@@ -53,7 +55,8 @@ def __init__(self,
dcn_bias_attr = ParamAttr(
initializer=Constant(value=0),
regularizer=dcn_bias_regularizer,
- learning_rate=dcn_bias_lr_scale)
+ learning_rate=dcn_bias_lr_scale,
+ )
else:
# in ResNet backbone, do not need bias
dcn_bias_attr = False
@@ -66,15 +69,17 @@ def __init__(self,
dilation=dilation,
deformable_groups=groups,
weight_attr=weight_attr,
- bias_attr=dcn_bias_attr)
+ bias_attr=dcn_bias_attr,
+ )
if lr_scale == 1 and regularizer is None:
- offset_bias_attr = ParamAttr(initializer=Constant(0.))
+ offset_bias_attr = ParamAttr(initializer=Constant(0.0))
else:
offset_bias_attr = ParamAttr(
- initializer=Constant(0.),
+ initializer=Constant(0.0),
learning_rate=lr_scale,
- regularizer=regularizer)
+ regularizer=regularizer,
+ )
self.conv_offset = nn.Conv2D(
in_channels,
groups * 3 * kernel_size**2,
@@ -82,7 +87,8 @@ def __init__(self,
stride=stride,
padding=(kernel_size - 1) // 2,
weight_attr=ParamAttr(initializer=Constant(0.0)),
- bias_attr=offset_bias_attr)
+ bias_attr=offset_bias_attr,
+ )
if skip_quant:
self.conv_offset.skip_quant = True
@@ -91,28 +97,32 @@ def forward(self, x):
offset, mask = paddle.split(
offset_mask,
num_or_sections=[self.offset_channel, self.mask_channel],
- axis=1)
+ axis=1,
+ )
mask = F.sigmoid(mask)
y = self.conv_dcn(x, offset, mask=mask)
return y
class ConvBNLayer(nn.Layer):
- def __init__(self,
- in_channels,
- out_channels,
- kernel_size,
- stride=1,
- groups=1,
- dcn_groups=1,
- is_vd_mode=False,
- act=None,
- is_dcn=False):
+ def __init__(
+ self,
+ in_channels,
+ out_channels,
+ kernel_size,
+ stride=1,
+ groups=1,
+ dcn_groups=1,
+ is_vd_mode=False,
+ act=None,
+ is_dcn=False,
+ ):
super(ConvBNLayer, self).__init__()
self.is_vd_mode = is_vd_mode
self._pool2d_avg = nn.AvgPool2D(
- kernel_size=2, stride=2, padding=0, ceil_mode=True)
+ kernel_size=2, stride=2, padding=0, ceil_mode=True
+ )
if not is_dcn:
self._conv = nn.Conv2D(
in_channels=in_channels,
@@ -121,7 +131,8 @@ def __init__(self,
stride=stride,
padding=(kernel_size - 1) // 2,
groups=groups,
- bias_attr=False)
+ bias_attr=False,
+ )
else:
self._conv = DeformableConvV2(
in_channels=in_channels,
@@ -129,8 +140,9 @@ def __init__(self,
kernel_size=kernel_size,
stride=stride,
padding=(kernel_size - 1) // 2,
- groups=dcn_groups, #groups,
- bias_attr=False)
+ groups=dcn_groups, # groups,
+ bias_attr=False,
+ )
self._batch_norm = nn.BatchNorm(out_channels, act=act)
def forward(self, inputs):
@@ -143,33 +155,37 @@ def forward(self, inputs):
class BottleneckBlock(nn.Layer):
def __init__(
- self,
- in_channels,
- out_channels,
- stride,
- shortcut=True,
- if_first=False,
- is_dcn=False, ):
+ self,
+ in_channels,
+ out_channels,
+ stride,
+ shortcut=True,
+ if_first=False,
+ is_dcn=False,
+ ):
super(BottleneckBlock, self).__init__()
self.conv0 = ConvBNLayer(
in_channels=in_channels,
out_channels=out_channels,
kernel_size=1,
- act='relu')
+ act="relu",
+ )
self.conv1 = ConvBNLayer(
in_channels=out_channels,
out_channels=out_channels,
kernel_size=3,
stride=stride,
- act='relu',
+ act="relu",
is_dcn=is_dcn,
- dcn_groups=2)
+ dcn_groups=2,
+ )
self.conv2 = ConvBNLayer(
in_channels=out_channels,
out_channels=out_channels * 4,
kernel_size=1,
- act=None)
+ act=None,
+ )
if not shortcut:
self.short = ConvBNLayer(
@@ -177,7 +193,8 @@ def __init__(
out_channels=out_channels * 4,
kernel_size=1,
stride=1,
- is_vd_mode=False if if_first else True)
+ is_vd_mode=False if if_first else True,
+ )
self.shortcut = shortcut
@@ -197,12 +214,13 @@ def forward(self, inputs):
class BasicBlock(nn.Layer):
def __init__(
- self,
- in_channels,
- out_channels,
- stride,
- shortcut=True,
- if_first=False, ):
+ self,
+ in_channels,
+ out_channels,
+ stride,
+ shortcut=True,
+ if_first=False,
+ ):
super(BasicBlock, self).__init__()
self.stride = stride
self.conv0 = ConvBNLayer(
@@ -210,12 +228,11 @@ def __init__(
out_channels=out_channels,
kernel_size=3,
stride=stride,
- act='relu')
+ act="relu",
+ )
self.conv1 = ConvBNLayer(
- in_channels=out_channels,
- out_channels=out_channels,
- kernel_size=3,
- act=None)
+ in_channels=out_channels, out_channels=out_channels, kernel_size=3, act=None
+ )
if not shortcut:
self.short = ConvBNLayer(
@@ -223,7 +240,8 @@ def __init__(
out_channels=out_channels,
kernel_size=1,
stride=1,
- is_vd_mode=False if if_first else True)
+ is_vd_mode=False if if_first else True,
+ )
self.shortcut = shortcut
@@ -241,19 +259,18 @@ def forward(self, inputs):
class ResNet_vd(nn.Layer):
- def __init__(self,
- in_channels=3,
- layers=50,
- dcn_stage=None,
- out_indices=None,
- **kwargs):
+ def __init__(
+ self, in_channels=3, layers=50, dcn_stage=None, out_indices=None, **kwargs
+ ):
super(ResNet_vd, self).__init__()
self.layers = layers
supported_layers = [18, 34, 50, 101, 152, 200]
- assert layers in supported_layers, \
- "supported layers are {} but input layer is {}".format(
- supported_layers, layers)
+ assert (
+ layers in supported_layers
+ ), "supported layers are {} but input layer is {}".format(
+ supported_layers, layers
+ )
if layers == 18:
depth = [2, 2, 2, 2]
@@ -265,35 +282,27 @@ def __init__(self,
depth = [3, 8, 36, 3]
elif layers == 200:
depth = [3, 12, 48, 3]
- num_channels = [64, 256, 512,
- 1024] if layers >= 50 else [64, 64, 128, 256]
+ num_channels = [64, 256, 512, 1024] if layers >= 50 else [64, 64, 128, 256]
num_filters = [64, 128, 256, 512]
- self.dcn_stage = dcn_stage if dcn_stage is not None else [
- False, False, False, False
- ]
- self.out_indices = out_indices if out_indices is not None else [
- 0, 1, 2, 3
- ]
+ self.dcn_stage = (
+ dcn_stage if dcn_stage is not None else [False, False, False, False]
+ )
+ self.out_indices = out_indices if out_indices is not None else [0, 1, 2, 3]
self.conv1_1 = ConvBNLayer(
in_channels=in_channels,
out_channels=32,
kernel_size=3,
stride=2,
- act='relu')
+ act="relu",
+ )
self.conv1_2 = ConvBNLayer(
- in_channels=32,
- out_channels=32,
- kernel_size=3,
- stride=1,
- act='relu')
+ in_channels=32, out_channels=32, kernel_size=3, stride=1, act="relu"
+ )
self.conv1_3 = ConvBNLayer(
- in_channels=32,
- out_channels=64,
- kernel_size=3,
- stride=1,
- act='relu')
+ in_channels=32, out_channels=64, kernel_size=3, stride=1, act="relu"
+ )
self.pool2d_max = nn.MaxPool2D(kernel_size=3, stride=2, padding=1)
self.stages = []
@@ -305,15 +314,18 @@ def __init__(self,
is_dcn = self.dcn_stage[block]
for i in range(depth[block]):
bottleneck_block = self.add_sublayer(
- 'bb_%d_%d' % (block, i),
+ "bb_%d_%d" % (block, i),
BottleneckBlock(
in_channels=num_channels[block]
- if i == 0 else num_filters[block] * 4,
+ if i == 0
+ else num_filters[block] * 4,
out_channels=num_filters[block],
stride=2 if i == 0 and block != 0 else 1,
shortcut=shortcut,
if_first=block == i == 0,
- is_dcn=is_dcn))
+ is_dcn=is_dcn,
+ ),
+ )
shortcut = True
block_list.append(bottleneck_block)
if block in self.out_indices:
@@ -325,14 +337,17 @@ def __init__(self,
shortcut = False
for i in range(depth[block]):
basic_block = self.add_sublayer(
- 'bb_%d_%d' % (block, i),
+ "bb_%d_%d" % (block, i),
BasicBlock(
in_channels=num_channels[block]
- if i == 0 else num_filters[block],
+ if i == 0
+ else num_filters[block],
out_channels=num_filters[block],
stride=2 if i == 0 and block != 0 else 1,
shortcut=shortcut,
- if_first=block == i == 0))
+ if_first=block == i == 0,
+ ),
+ )
shortcut = True
block_list.append(basic_block)
if block in self.out_indices:
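
The channel arithmetic in `DeformableConvV2` above is worth spelling out: `conv_offset` predicts 3 * k^2 * groups channels per location, split into 2 * k^2 * groups (x, y) offsets and k^2 * groups modulation masks, with the masks squashed through a sigmoid. A numeric sketch of that split for the defaults used here:

import numpy as np

kernel_size, groups = 3, 1
offset_channel = 2 * kernel_size**2 * groups   # x and y offset per sampling point
mask_channel = kernel_size**2 * groups         # one modulation scalar per point
total = 3 * kernel_size**2 * groups            # what conv_offset emits
print(offset_channel, mask_channel, total)     # 18 9 27

# fake conv_offset output for a 1x1 feature map, then the same split as forward()
offset_mask = np.random.randn(1, total, 1, 1)
offset = offset_mask[:, :offset_channel]
mask = 1 / (1 + np.exp(-offset_mask[:, offset_channel:]))  # sigmoid to [0, 1]
print(offset.shape, mask.shape)                # (1, 18, 1, 1) (1, 9, 1, 1)
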
diff --git a/ppocr/modeling/backbones/det_resnet_vd_sast.py b/ppocr/modeling/backbones/det_resnet_vd_sast.py
index c9376a8d56..7cb349afc7 100644
--- a/ppocr/modeling/backbones/det_resnet_vd_sast.py
+++ b/ppocr/modeling/backbones/det_resnet_vd_sast.py
@@ -26,20 +26,22 @@
class ConvBNLayer(nn.Layer):
def __init__(
- self,
- in_channels,
- out_channels,
- kernel_size,
- stride=1,
- groups=1,
- is_vd_mode=False,
- act=None,
- name=None, ):
+ self,
+ in_channels,
+ out_channels,
+ kernel_size,
+ stride=1,
+ groups=1,
+ is_vd_mode=False,
+ act=None,
+ name=None,
+ ):
super(ConvBNLayer, self).__init__()
self.is_vd_mode = is_vd_mode
self._pool2d_avg = nn.AvgPool2D(
- kernel_size=2, stride=2, padding=0, ceil_mode=True)
+ kernel_size=2, stride=2, padding=0, ceil_mode=True
+ )
self._conv = nn.Conv2D(
in_channels=in_channels,
out_channels=out_channels,
@@ -48,7 +50,8 @@ def __init__(
padding=(kernel_size - 1) // 2,
groups=groups,
weight_attr=ParamAttr(name=name + "_weights"),
- bias_attr=False)
+ bias_attr=False,
+ )
if name == "conv1":
bn_name = "bn_" + name
else:
@@ -56,10 +59,11 @@ def __init__(
self._batch_norm = nn.BatchNorm(
out_channels,
act=act,
- param_attr=ParamAttr(name=bn_name + '_scale'),
- bias_attr=ParamAttr(bn_name + '_offset'),
- moving_mean_name=bn_name + '_mean',
- moving_variance_name=bn_name + '_variance')
+ param_attr=ParamAttr(name=bn_name + "_scale"),
+ bias_attr=ParamAttr(bn_name + "_offset"),
+ moving_mean_name=bn_name + "_mean",
+ moving_variance_name=bn_name + "_variance",
+ )
def forward(self, inputs):
if self.is_vd_mode:
@@ -70,34 +74,39 @@ def forward(self, inputs):
class BottleneckBlock(nn.Layer):
- def __init__(self,
- in_channels,
- out_channels,
- stride,
- shortcut=True,
- if_first=False,
- name=None):
+ def __init__(
+ self,
+ in_channels,
+ out_channels,
+ stride,
+ shortcut=True,
+ if_first=False,
+ name=None,
+ ):
super(BottleneckBlock, self).__init__()
self.conv0 = ConvBNLayer(
in_channels=in_channels,
out_channels=out_channels,
kernel_size=1,
- act='relu',
- name=name + "_branch2a")
+ act="relu",
+ name=name + "_branch2a",
+ )
self.conv1 = ConvBNLayer(
in_channels=out_channels,
out_channels=out_channels,
kernel_size=3,
stride=stride,
- act='relu',
- name=name + "_branch2b")
+ act="relu",
+ name=name + "_branch2b",
+ )
self.conv2 = ConvBNLayer(
in_channels=out_channels,
out_channels=out_channels * 4,
kernel_size=1,
act=None,
- name=name + "_branch2c")
+ name=name + "_branch2c",
+ )
if not shortcut:
self.short = ConvBNLayer(
@@ -106,7 +115,8 @@ def __init__(self,
kernel_size=1,
stride=1,
is_vd_mode=False if if_first else True,
- name=name + "_branch1")
+ name=name + "_branch1",
+ )
self.shortcut = shortcut
@@ -125,13 +135,15 @@ def forward(self, inputs):
class BasicBlock(nn.Layer):
- def __init__(self,
- in_channels,
- out_channels,
- stride,
- shortcut=True,
- if_first=False,
- name=None):
+ def __init__(
+ self,
+ in_channels,
+ out_channels,
+ stride,
+ shortcut=True,
+ if_first=False,
+ name=None,
+ ):
super(BasicBlock, self).__init__()
self.stride = stride
self.conv0 = ConvBNLayer(
@@ -139,14 +151,16 @@ def __init__(self,
out_channels=out_channels,
kernel_size=3,
stride=stride,
- act='relu',
- name=name + "_branch2a")
+ act="relu",
+ name=name + "_branch2a",
+ )
self.conv1 = ConvBNLayer(
in_channels=out_channels,
out_channels=out_channels,
kernel_size=3,
act=None,
- name=name + "_branch2b")
+ name=name + "_branch2b",
+ )
if not shortcut:
self.short = ConvBNLayer(
@@ -155,7 +169,8 @@ def __init__(self,
kernel_size=1,
stride=1,
is_vd_mode=False if if_first else True,
- name=name + "_branch1")
+ name=name + "_branch1",
+ )
self.shortcut = shortcut
@@ -178,9 +193,11 @@ def __init__(self, in_channels=3, layers=50, **kwargs):
self.layers = layers
supported_layers = [18, 34, 50, 101, 152, 200]
- assert layers in supported_layers, \
- "supported layers are {} but input layer is {}".format(
- supported_layers, layers)
+ assert (
+ layers in supported_layers
+ ), "supported layers are {} but input layer is {}".format(
+ supported_layers, layers
+ )
if layers == 18:
depth = [2, 2, 2, 2]
@@ -196,8 +213,9 @@ def __init__(self, in_channels=3, layers=50, **kwargs):
# num_channels = [64, 256, 512,
# 1024] if layers >= 50 else [64, 64, 128, 256]
# num_filters = [64, 128, 256, 512]
- num_channels = [64, 256, 512,
- 1024, 2048] if layers >= 50 else [64, 64, 128, 256]
+ num_channels = (
+ [64, 256, 512, 1024, 2048] if layers >= 50 else [64, 64, 128, 256]
+ )
num_filters = [64, 128, 256, 512, 512]
self.conv1_1 = ConvBNLayer(
@@ -205,22 +223,25 @@ def __init__(self, in_channels=3, layers=50, **kwargs):
out_channels=32,
kernel_size=3,
stride=2,
- act='relu',
- name="conv1_1")
+ act="relu",
+ name="conv1_1",
+ )
self.conv1_2 = ConvBNLayer(
in_channels=32,
out_channels=32,
kernel_size=3,
stride=1,
- act='relu',
- name="conv1_2")
+ act="relu",
+ name="conv1_2",
+ )
self.conv1_3 = ConvBNLayer(
in_channels=32,
out_channels=64,
kernel_size=3,
stride=1,
- act='relu',
- name="conv1_3")
+ act="relu",
+ name="conv1_3",
+ )
self.pool2d_max = nn.MaxPool2D(kernel_size=3, stride=2, padding=1)
self.stages = []
@@ -238,15 +259,18 @@ def __init__(self, in_channels=3, layers=50, **kwargs):
else:
conv_name = "res" + str(block + 2) + chr(97 + i)
bottleneck_block = self.add_sublayer(
- 'bb_%d_%d' % (block, i),
+ "bb_%d_%d" % (block, i),
BottleneckBlock(
in_channels=num_channels[block]
- if i == 0 else num_filters[block] * 4,
+ if i == 0
+ else num_filters[block] * 4,
out_channels=num_filters[block],
stride=2 if i == 0 and block != 0 else 1,
shortcut=shortcut,
if_first=block == i == 0,
- name=conv_name))
+ name=conv_name,
+ ),
+ )
shortcut = True
block_list.append(bottleneck_block)
self.out_channels.append(num_filters[block] * 4)
@@ -258,15 +282,18 @@ def __init__(self, in_channels=3, layers=50, **kwargs):
for i in range(depth[block]):
conv_name = "res" + str(block + 2) + chr(97 + i)
basic_block = self.add_sublayer(
- 'bb_%d_%d' % (block, i),
+ "bb_%d_%d" % (block, i),
BasicBlock(
in_channels=num_channels[block]
- if i == 0 else num_filters[block],
+ if i == 0
+ else num_filters[block],
out_channels=num_filters[block],
stride=2 if i == 0 and block != 0 else 1,
shortcut=shortcut,
if_first=block == i == 0,
- name=conv_name))
+ name=conv_name,
+ ),
+ )
shortcut = True
block_list.append(basic_block)
self.out_channels.append(num_filters[block])
@@ -282,4 +309,4 @@ def forward(self, inputs):
for block in self.stages:
y = block(y)
out.append(y)
- return out
\ No newline at end of file
+ return out
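The hunks above for det_resnet_vd_sast.py are mechanical style changes only: single quotes become double quotes, hand-wrapped calls are exploded one argument per line, and the missing newline at end of file is restored. A minimal before/after sketch of the call convention, assuming Black's default 88-column limit and reusing the file's own ConvBNLayer for illustration:

    # Before: closing paren hugs the last argument, single-quoted strings.
    conv = ConvBNLayer(
        in_channels=3,
        out_channels=32,
        kernel_size=3,
        act='relu',
        name="conv1_1")

    # After: double quotes, a trailing comma, and the closing paren dedented
    # to the call's own indent. The trailing comma keeps the call exploded
    # on future reformats even if it would later fit on one line.
    conv = ConvBNLayer(
        in_channels=3,
        out_channels=32,
        kernel_size=3,
        act="relu",
        name="conv1_1",
    )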
diff --git a/ppocr/modeling/backbones/e2e_resnet_vd_pg.py b/ppocr/modeling/backbones/e2e_resnet_vd_pg.py
index 97afd3460d..16defc7719 100644
--- a/ppocr/modeling/backbones/e2e_resnet_vd_pg.py
+++ b/ppocr/modeling/backbones/e2e_resnet_vd_pg.py
@@ -26,20 +26,22 @@
class ConvBNLayer(nn.Layer):
def __init__(
- self,
- in_channels,
- out_channels,
- kernel_size,
- stride=1,
- groups=1,
- is_vd_mode=False,
- act=None,
- name=None, ):
+ self,
+ in_channels,
+ out_channels,
+ kernel_size,
+ stride=1,
+ groups=1,
+ is_vd_mode=False,
+ act=None,
+ name=None,
+ ):
super(ConvBNLayer, self).__init__()
self.is_vd_mode = is_vd_mode
self._pool2d_avg = nn.AvgPool2D(
- kernel_size=2, stride=2, padding=0, ceil_mode=True)
+ kernel_size=2, stride=2, padding=0, ceil_mode=True
+ )
self._conv = nn.Conv2D(
in_channels=in_channels,
out_channels=out_channels,
@@ -48,7 +50,8 @@ def __init__(
padding=(kernel_size - 1) // 2,
groups=groups,
weight_attr=ParamAttr(name=name + "_weights"),
- bias_attr=False)
+ bias_attr=False,
+ )
if name == "conv1":
bn_name = "bn_" + name
else:
@@ -56,10 +59,11 @@ def __init__(
self._batch_norm = nn.BatchNorm(
out_channels,
act=act,
- param_attr=ParamAttr(name=bn_name + '_scale'),
- bias_attr=ParamAttr(bn_name + '_offset'),
- moving_mean_name=bn_name + '_mean',
- moving_variance_name=bn_name + '_variance')
+ param_attr=ParamAttr(name=bn_name + "_scale"),
+ bias_attr=ParamAttr(bn_name + "_offset"),
+ moving_mean_name=bn_name + "_mean",
+ moving_variance_name=bn_name + "_variance",
+ )
def forward(self, inputs):
y = self._conv(inputs)
@@ -68,34 +72,39 @@ def forward(self, inputs):
class BottleneckBlock(nn.Layer):
- def __init__(self,
- in_channels,
- out_channels,
- stride,
- shortcut=True,
- if_first=False,
- name=None):
+ def __init__(
+ self,
+ in_channels,
+ out_channels,
+ stride,
+ shortcut=True,
+ if_first=False,
+ name=None,
+ ):
super(BottleneckBlock, self).__init__()
self.conv0 = ConvBNLayer(
in_channels=in_channels,
out_channels=out_channels,
kernel_size=1,
- act='relu',
- name=name + "_branch2a")
+ act="relu",
+ name=name + "_branch2a",
+ )
self.conv1 = ConvBNLayer(
in_channels=out_channels,
out_channels=out_channels,
kernel_size=3,
stride=stride,
- act='relu',
- name=name + "_branch2b")
+ act="relu",
+ name=name + "_branch2b",
+ )
self.conv2 = ConvBNLayer(
in_channels=out_channels,
out_channels=out_channels * 4,
kernel_size=1,
act=None,
- name=name + "_branch2c")
+ name=name + "_branch2c",
+ )
if not shortcut:
self.short = ConvBNLayer(
@@ -104,7 +113,8 @@ def __init__(self,
kernel_size=1,
stride=stride,
is_vd_mode=False if if_first else True,
- name=name + "_branch1")
+ name=name + "_branch1",
+ )
self.shortcut = shortcut
@@ -123,13 +133,15 @@ def forward(self, inputs):
class BasicBlock(nn.Layer):
- def __init__(self,
- in_channels,
- out_channels,
- stride,
- shortcut=True,
- if_first=False,
- name=None):
+ def __init__(
+ self,
+ in_channels,
+ out_channels,
+ stride,
+ shortcut=True,
+ if_first=False,
+ name=None,
+ ):
super(BasicBlock, self).__init__()
self.stride = stride
self.conv0 = ConvBNLayer(
@@ -137,14 +149,16 @@ def __init__(self,
out_channels=out_channels,
kernel_size=3,
stride=stride,
- act='relu',
- name=name + "_branch2a")
+ act="relu",
+ name=name + "_branch2a",
+ )
self.conv1 = ConvBNLayer(
in_channels=out_channels,
out_channels=out_channels,
kernel_size=3,
act=None,
- name=name + "_branch2b")
+ name=name + "_branch2b",
+ )
if not shortcut:
self.short = ConvBNLayer(
@@ -153,7 +167,8 @@ def __init__(self,
kernel_size=1,
stride=1,
is_vd_mode=False if if_first else True,
- name=name + "_branch1")
+ name=name + "_branch1",
+ )
self.shortcut = shortcut
@@ -176,9 +191,11 @@ def __init__(self, in_channels=3, layers=50, **kwargs):
self.layers = layers
supported_layers = [18, 34, 50, 101, 152, 200]
- assert layers in supported_layers, \
- "supported layers are {} but input layer is {}".format(
- supported_layers, layers)
+ assert (
+ layers in supported_layers
+ ), "supported layers are {} but input layer is {}".format(
+ supported_layers, layers
+ )
if layers == 18:
depth = [2, 2, 2, 2]
@@ -191,8 +208,9 @@ def __init__(self, in_channels=3, layers=50, **kwargs):
depth = [3, 8, 36, 3]
elif layers == 200:
depth = [3, 12, 48, 3]
- num_channels = [64, 256, 512, 1024,
- 2048] if layers >= 50 else [64, 64, 128, 256]
+ num_channels = (
+ [64, 256, 512, 1024, 2048] if layers >= 50 else [64, 64, 128, 256]
+ )
num_filters = [64, 128, 256, 512, 512]
self.conv1_1 = ConvBNLayer(
@@ -200,8 +218,9 @@ def __init__(self, in_channels=3, layers=50, **kwargs):
out_channels=64,
kernel_size=7,
stride=2,
- act='relu',
- name="conv1_1")
+ act="relu",
+ name="conv1_1",
+ )
self.pool2d_max = nn.MaxPool2D(kernel_size=3, stride=2, padding=1)
self.stages = []
@@ -220,15 +239,18 @@ def __init__(self, in_channels=3, layers=50, **kwargs):
else:
conv_name = "res" + str(block + 2) + chr(97 + i)
bottleneck_block = self.add_sublayer(
- 'bb_%d_%d' % (block, i),
+ "bb_%d_%d" % (block, i),
BottleneckBlock(
in_channels=num_channels[block]
- if i == 0 else num_filters[block] * 4,
+ if i == 0
+ else num_filters[block] * 4,
out_channels=num_filters[block],
stride=2 if i == 0 and block != 0 else 1,
shortcut=shortcut,
if_first=block == i == 0,
- name=conv_name))
+ name=conv_name,
+ ),
+ )
shortcut = True
block_list.append(bottleneck_block)
self.out_channels.append(num_filters[block] * 4)
@@ -240,15 +262,18 @@ def __init__(self, in_channels=3, layers=50, **kwargs):
for i in range(depth[block]):
conv_name = "res" + str(block + 2) + chr(97 + i)
basic_block = self.add_sublayer(
- 'bb_%d_%d' % (block, i),
+ "bb_%d_%d" % (block, i),
BasicBlock(
in_channels=num_channels[block]
- if i == 0 else num_filters[block],
+ if i == 0
+ else num_filters[block],
out_channels=num_filters[block],
stride=2 if i == 0 and block != 0 else 1,
shortcut=shortcut,
if_first=block == i == 0,
- name=conv_name))
+ name=conv_name,
+ ),
+ )
shortcut = True
block_list.append(basic_block)
self.out_channels.append(num_filters[block])
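The assert rewrite repeated in these ResNet backbones is behavior-preserving: parenthesizing the condition lets it wrap without a backslash continuation. A sketch with illustrative values:

    supported_layers = [18, 34, 50, 101, 152, 200]
    layers = 50

    # Before: backslash continuation.
    assert layers in supported_layers, \
        "supported layers are {} but input layer is {}".format(supported_layers, layers)

    # After: only the condition is parenthesized; the message follows the
    # closing paren. Writing `assert (condition, message)` instead would
    # assert a two-element tuple, which is always truthy, so the paren
    # placement matters.
    assert (
        layers in supported_layers
    ), "supported layers are {} but input layer is {}".format(supported_layers, layers)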
diff --git a/ppocr/modeling/backbones/kie_unet_sdmgr.py b/ppocr/modeling/backbones/kie_unet_sdmgr.py
index 4b1bd80300..7b0c2b5bb7 100644
--- a/ppocr/modeling/backbones/kie_unet_sdmgr.py
+++ b/ppocr/modeling/backbones/kie_unet_sdmgr.py
@@ -33,8 +33,9 @@ def __init__(self, num_channels, num_filters):
kernel_size=3,
stride=1,
padding=1,
- bias_attr=False)
- self.bn1 = nn.BatchNorm(num_filters, act='relu')
+ bias_attr=False,
+ )
+ self.bn1 = nn.BatchNorm(num_filters, act="relu")
self.conv2 = nn.Conv2D(
num_filters,
@@ -42,8 +43,9 @@ def __init__(self, num_channels, num_filters):
kernel_size=3,
stride=1,
padding=1,
- bias_attr=False)
- self.bn2 = nn.BatchNorm(num_filters, act='relu')
+ bias_attr=False,
+ )
+ self.bn2 = nn.BatchNorm(num_filters, act="relu")
self.pool = nn.MaxPool2D(kernel_size=3, stride=2, padding=1)
@@ -66,8 +68,9 @@ def __init__(self, num_channels, num_filters):
kernel_size=3,
stride=1,
padding=1,
- bias_attr=False)
- self.bn1 = nn.BatchNorm(num_filters, act='relu')
+ bias_attr=False,
+ )
+ self.bn1 = nn.BatchNorm(num_filters, act="relu")
self.conv2 = nn.Conv2D(
num_filters,
@@ -75,8 +78,9 @@ def __init__(self, num_channels, num_filters):
kernel_size=3,
stride=1,
padding=1,
- bias_attr=False)
- self.bn2 = nn.BatchNorm(num_filters, act='relu')
+ bias_attr=False,
+ )
+ self.bn2 = nn.BatchNorm(num_filters, act="relu")
self.conv0 = nn.Conv2D(
num_channels,
@@ -84,14 +88,16 @@ def __init__(self, num_channels, num_filters):
kernel_size=1,
stride=1,
padding=0,
- bias_attr=False)
- self.bn0 = nn.BatchNorm(num_filters, act='relu')
+ bias_attr=False,
+ )
+ self.bn0 = nn.BatchNorm(num_filters, act="relu")
def forward(self, inputs_prev, inputs):
x = self.conv0(inputs)
x = self.bn0(x)
x = paddle.nn.functional.interpolate(
- x, scale_factor=2, mode='bilinear', align_corners=False)
+ x, scale_factor=2, mode="bilinear", align_corners=False
+ )
x = paddle.concat([inputs_prev, x], axis=1)
x = self.conv1(x)
x = self.bn1(x)
@@ -143,13 +149,18 @@ def bbox2roi(self, bbox_list):
rois_num.append(bboxes.shape[0])
rois_list.append(bboxes)
rois = paddle.concat(rois_list, 0)
- rois_num = paddle.to_tensor(rois_num, dtype='int32')
+ rois_num = paddle.to_tensor(rois_num, dtype="int32")
return rois, rois_num
def pre_process(self, img, relations, texts, gt_bboxes, tag, img_size):
- img, relations, texts, gt_bboxes, tag, img_size = img.numpy(
- ), relations.numpy(), texts.numpy(), gt_bboxes.numpy(), tag.numpy(
- ).tolist(), img_size.numpy()
+ img, relations, texts, gt_bboxes, tag, img_size = (
+ img.numpy(),
+ relations.numpy(),
+ texts.numpy(),
+ gt_bboxes.numpy(),
+ tag.numpy().tolist(),
+ img_size.numpy(),
+ )
temp_relations, temp_texts, temp_gt_bboxes = [], [], []
h, w = int(np.max(img_size[:, 0])), int(np.max(img_size[:, 1]))
img = paddle.to_tensor(img[:, :, :h, :w])
@@ -157,25 +168,32 @@ def pre_process(self, img, relations, texts, gt_bboxes, tag, img_size):
for i in range(batch):
num, recoder_len = tag[i][0], tag[i][1]
temp_relations.append(
- paddle.to_tensor(
- relations[i, :num, :num, :], dtype='float32'))
+ paddle.to_tensor(relations[i, :num, :num, :], dtype="float32")
+ )
temp_texts.append(
- paddle.to_tensor(
- texts[i, :num, :recoder_len], dtype='float32'))
+ paddle.to_tensor(texts[i, :num, :recoder_len], dtype="float32")
+ )
temp_gt_bboxes.append(
- paddle.to_tensor(
- gt_bboxes[i, :num, ...], dtype='float32'))
+ paddle.to_tensor(gt_bboxes[i, :num, ...], dtype="float32")
+ )
return img, temp_relations, temp_texts, temp_gt_bboxes
def forward(self, inputs):
img = inputs[0]
- relations, texts, gt_bboxes, tag, img_size = inputs[1], inputs[
- 2], inputs[3], inputs[5], inputs[-1]
+ relations, texts, gt_bboxes, tag, img_size = (
+ inputs[1],
+ inputs[2],
+ inputs[3],
+ inputs[5],
+ inputs[-1],
+ )
img, relations, texts, gt_bboxes = self.pre_process(
- img, relations, texts, gt_bboxes, tag, img_size)
+ img, relations, texts, gt_bboxes, tag, img_size
+ )
x = self.img_feat(img)
boxes, rois_num = self.bbox2roi(gt_bboxes)
feats = paddle.vision.ops.roi_align(
- x, boxes, spatial_scale=1.0, output_size=7, boxes_num=rois_num)
+ x, boxes, spatial_scale=1.0, output_size=7, boxes_num=rois_num
+ )
feats = self.maxpool(feats).squeeze(-1).squeeze(-1)
return [relations, texts, feats]
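The pre_process and forward rewrites in kie_unet_sdmgr.py show how overlong multiple-assignments are wrapped: the right-hand side becomes a parenthesized tuple, one element per line, instead of breaking mid-subscript. A short sketch with a hypothetical `inputs` list:

    inputs = list(range(7))  # hypothetical stand-in for the batch tuple

    # Before (hard to scan: the line breaks fall inside `inputs[...]`):
    # relations, texts, gt_bboxes, tag, img_size = inputs[1], inputs[
    #     2], inputs[3], inputs[5], inputs[-1]

    # After: each element on its own line.
    relations, texts, gt_bboxes, tag, img_size = (
        inputs[1],
        inputs[2],
        inputs[3],
        inputs[5],
        inputs[-1],
    )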
diff --git a/ppocr/modeling/backbones/rec_densenet.py b/ppocr/modeling/backbones/rec_densenet.py
index 65c5fa4f24..ad6e9e5c67 100644
--- a/ppocr/modeling/backbones/rec_densenet.py
+++ b/ppocr/modeling/backbones/rec_densenet.py
@@ -33,12 +33,12 @@ def __init__(self, nChannels, growthRate, use_dropout):
interChannels = 4 * growthRate
self.bn1 = nn.BatchNorm2D(interChannels)
self.conv1 = nn.Conv2D(
- nChannels, interChannels, kernel_size=1,
- bias_attr=None) # Xavier initialization
+ nChannels, interChannels, kernel_size=1, bias_attr=None
+ ) # Xavier initialization
self.bn2 = nn.BatchNorm2D(growthRate)
self.conv2 = nn.Conv2D(
- interChannels, growthRate, kernel_size=3, padding=1,
- bias_attr=None) # Xavier initialization
+ interChannels, growthRate, kernel_size=3, padding=1, bias_attr=None
+ ) # Xavier initialization
self.use_dropout = use_dropout
self.dropout = nn.Dropout(p=0.2)
@@ -58,7 +58,8 @@ def __init__(self, nChannels, growthRate, use_dropout):
super(SingleLayer, self).__init__()
self.bn1 = nn.BatchNorm2D(nChannels)
self.conv1 = nn.Conv2D(
- nChannels, growthRate, kernel_size=3, padding=1, bias_attr=False)
+ nChannels, growthRate, kernel_size=3, padding=1, bias_attr=False
+ )
self.use_dropout = use_dropout
self.dropout = nn.Dropout(p=0.2)
@@ -76,8 +77,7 @@ class Transition(nn.Layer):
def __init__(self, nChannels, out_channels, use_dropout):
super(Transition, self).__init__()
self.bn1 = nn.BatchNorm2D(out_channels)
- self.conv1 = nn.Conv2D(
- nChannels, out_channels, kernel_size=1, bias_attr=False)
+ self.conv1 = nn.Conv2D(nChannels, out_channels, kernel_size=1, bias_attr=False)
self.use_dropout = use_dropout
self.dropout = nn.Dropout(p=0.2)
@@ -90,8 +90,9 @@ def forward(self, x):
class DenseNet(nn.Layer):
- def __init__(self, growthRate, reduction, bottleneck, use_dropout,
- input_channel, **kwargs):
+ def __init__(
+ self, growthRate, reduction, bottleneck, use_dropout, input_channel, **kwargs
+ ):
super(DenseNet, self).__init__()
nDenseBlocks = 16
@@ -103,27 +104,30 @@ def __init__(self, growthRate, reduction, bottleneck, use_dropout,
kernel_size=7,
padding=3,
stride=2,
- bias_attr=False)
- self.dense1 = self._make_dense(nChannels, growthRate, nDenseBlocks,
- bottleneck, use_dropout)
+ bias_attr=False,
+ )
+ self.dense1 = self._make_dense(
+ nChannels, growthRate, nDenseBlocks, bottleneck, use_dropout
+ )
nChannels += nDenseBlocks * growthRate
out_channels = int(math.floor(nChannels * reduction))
self.trans1 = Transition(nChannels, out_channels, use_dropout)
nChannels = out_channels
- self.dense2 = self._make_dense(nChannels, growthRate, nDenseBlocks,
- bottleneck, use_dropout)
+ self.dense2 = self._make_dense(
+ nChannels, growthRate, nDenseBlocks, bottleneck, use_dropout
+ )
nChannels += nDenseBlocks * growthRate
out_channels = int(math.floor(nChannels * reduction))
self.trans2 = Transition(nChannels, out_channels, use_dropout)
nChannels = out_channels
- self.dense3 = self._make_dense(nChannels, growthRate, nDenseBlocks,
- bottleneck, use_dropout)
+ self.dense3 = self._make_dense(
+ nChannels, growthRate, nDenseBlocks, bottleneck, use_dropout
+ )
self.out_channels = out_channels
- def _make_dense(self, nChannels, growthRate, nDenseBlocks, bottleneck,
- use_dropout):
+ def _make_dense(self, nChannels, growthRate, nDenseBlocks, bottleneck, use_dropout):
layers = []
for i in range(int(nDenseBlocks)):
if bottleneck:
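Reformatting also goes the other way: in rec_densenet.py several hand-wrapped calls are joined back onto a single line because they fit within the limit, as in this pair taken directly from the hunk above:

    # Before
    self.conv1 = nn.Conv2D(
        nChannels, out_channels, kernel_size=1, bias_attr=False)

    # After: the whole call fits in 88 columns, so it collapses.
    self.conv1 = nn.Conv2D(nChannels, out_channels, kernel_size=1, bias_attr=False)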
diff --git a/ppocr/modeling/backbones/rec_efficientb3_pren.py b/ppocr/modeling/backbones/rec_efficientb3_pren.py
index 701e436c1e..d153ad6d87 100644
--- a/ppocr/modeling/backbones/rec_efficientb3_pren.py
+++ b/ppocr/modeling/backbones/rec_efficientb3_pren.py
@@ -27,18 +27,37 @@
import paddle.nn as nn
import paddle.nn.functional as F
-__all__ = ['EfficientNetb3']
-
-GlobalParams = collections.namedtuple('GlobalParams', [
- 'batch_norm_momentum', 'batch_norm_epsilon', 'dropout_rate', 'num_classes',
- 'width_coefficient', 'depth_coefficient', 'depth_divisor', 'min_depth',
- 'drop_connect_rate', 'image_size'
-])
-
-BlockArgs = collections.namedtuple('BlockArgs', [
- 'kernel_size', 'num_repeat', 'input_filters', 'output_filters',
- 'expand_ratio', 'id_skip', 'stride', 'se_ratio'
-])
+__all__ = ["EfficientNetb3"]
+
+GlobalParams = collections.namedtuple(
+ "GlobalParams",
+ [
+ "batch_norm_momentum",
+ "batch_norm_epsilon",
+ "dropout_rate",
+ "num_classes",
+ "width_coefficient",
+ "depth_coefficient",
+ "depth_divisor",
+ "min_depth",
+ "drop_connect_rate",
+ "image_size",
+ ],
+)
+
+BlockArgs = collections.namedtuple(
+ "BlockArgs",
+ [
+ "kernel_size",
+ "num_repeat",
+ "input_filters",
+ "output_filters",
+ "expand_ratio",
+ "id_skip",
+ "stride",
+ "se_ratio",
+ ],
+)
class BlockDecoder:
@@ -46,26 +65,28 @@ class BlockDecoder:
def _decode_block_string(block_string):
assert isinstance(block_string, str)
- ops = block_string.split('_')
+ ops = block_string.split("_")
options = {}
for op in ops:
- splits = re.split(r'(\d.*)', op)
+ splits = re.split(r"(\d.*)", op)
if len(splits) >= 2:
key, value = splits[:2]
options[key] = value
- assert (('s' in options and len(options['s']) == 1) or
- (len(options['s']) == 2 and options['s'][0] == options['s'][1]))
+ assert ("s" in options and len(options["s"]) == 1) or (
+ len(options["s"]) == 2 and options["s"][0] == options["s"][1]
+ )
return BlockArgs(
- kernel_size=int(options['k']),
- num_repeat=int(options['r']),
- input_filters=int(options['i']),
- output_filters=int(options['o']),
- expand_ratio=int(options['e']),
- id_skip=('noskip' not in block_string),
- se_ratio=float(options['se']) if 'se' in options else None,
- stride=[int(options['s'][0])])
+ kernel_size=int(options["k"]),
+ num_repeat=int(options["r"]),
+ input_filters=int(options["i"]),
+ output_filters=int(options["o"]),
+ expand_ratio=int(options["e"]),
+ id_skip=("noskip" not in block_string),
+ se_ratio=float(options["se"]) if "se" in options else None,
+ stride=[int(options["s"][0])],
+ )
@staticmethod
def decode(string_list):
@@ -76,20 +97,22 @@ def decode(string_list):
return blocks_args
-def efficientnet(width_coefficient=None,
- depth_coefficient=None,
- dropout_rate=0.2,
- drop_connect_rate=0.2,
- image_size=None,
- num_classes=1000):
+def efficientnet(
+ width_coefficient=None,
+ depth_coefficient=None,
+ dropout_rate=0.2,
+ drop_connect_rate=0.2,
+ image_size=None,
+ num_classes=1000,
+):
blocks_args = [
- 'r1_k3_s11_e1_i32_o16_se0.25',
- 'r2_k3_s22_e6_i16_o24_se0.25',
- 'r2_k5_s22_e6_i24_o40_se0.25',
- 'r3_k3_s22_e6_i40_o80_se0.25',
- 'r3_k5_s11_e6_i80_o112_se0.25',
- 'r4_k5_s22_e6_i112_o192_se0.25',
- 'r1_k3_s11_e6_i192_o320_se0.25',
+ "r1_k3_s11_e1_i32_o16_se0.25",
+ "r2_k3_s22_e6_i16_o24_se0.25",
+ "r2_k5_s22_e6_i24_o40_se0.25",
+ "r3_k3_s22_e6_i40_o80_se0.25",
+ "r3_k5_s11_e6_i80_o112_se0.25",
+ "r4_k5_s22_e6_i112_o192_se0.25",
+ "r1_k3_s11_e6_i192_o320_se0.25",
]
blocks_args = BlockDecoder.decode(blocks_args)
@@ -103,14 +126,15 @@ def efficientnet(width_coefficient=None,
depth_coefficient=depth_coefficient,
depth_divisor=8,
min_depth=None,
- image_size=image_size, )
+ image_size=image_size,
+ )
return blocks_args, global_params
class EffUtils:
@staticmethod
def round_filters(filters, global_params):
- """ Calculate and round number of filters based on depth multiplier. """
+ """Calculate and round number of filters based on depth multiplier."""
multiplier = global_params.width_coefficient
if not multiplier:
return filters
@@ -118,15 +142,14 @@ def round_filters(filters, global_params):
min_depth = global_params.min_depth
filters *= multiplier
min_depth = min_depth or divisor
- new_filters = max(min_depth,
- int(filters + divisor / 2) // divisor * divisor)
+ new_filters = max(min_depth, int(filters + divisor / 2) // divisor * divisor)
if new_filters < 0.9 * filters:
new_filters += divisor
return int(new_filters)
@staticmethod
def round_repeats(repeats, global_params):
- """ Round number of filters based on depth multiplier. """
+ """Round number of filters based on depth multiplier."""
multiplier = global_params.depth_coefficient
if not multiplier:
return repeats
@@ -137,8 +160,9 @@ class MbConvBlock(nn.Layer):
def __init__(self, block_args):
super(MbConvBlock, self).__init__()
self._block_args = block_args
- self.has_se = (self._block_args.se_ratio is not None) and \
- (0 < self._block_args.se_ratio <= 1)
+ self.has_se = (self._block_args.se_ratio is not None) and (
+ 0 < self._block_args.se_ratio <= 1
+ )
self.id_skip = block_args.id_skip
# expansion phase
@@ -159,15 +183,16 @@ def __init__(self, block_args):
groups=oup,
kernel_size=k,
stride=s,
- padding='same',
- bias_attr=False)
+ padding="same",
+ bias_attr=False,
+ )
self._bn1 = nn.BatchNorm(oup)
# squeeze and excitation layer, if desired
if self.has_se:
- num_squeezed_channels = max(1,
- int(self._block_args.input_filters *
- self._block_args.se_ratio))
+ num_squeezed_channels = max(
+ 1, int(self._block_args.input_filters * self._block_args.se_ratio)
+ )
self._se_reduce = nn.Conv2D(oup, num_squeezed_channels, 1)
self._se_expand = nn.Conv2D(num_squeezed_channels, oup, 1)
@@ -199,17 +224,14 @@ def forward(self, inputs, drop_connect_rate=None):
# squeeze and excitation
if self.has_se:
x_squeezed = F.adaptive_avg_pool2d(x, 1)
- x_squeezed = self._se_expand(
- self._swish(self._se_reduce(x_squeezed)))
+ x_squeezed = self._se_expand(self._swish(self._se_reduce(x_squeezed)))
x = F.sigmoid(x_squeezed) * x
x = self._bn2(self._project_conv(x))
# skip connection and drop connect
- if self.id_skip and self._block_args.stride == 1 and \
- self.inp == self.final_oup:
+ if self.id_skip and self._block_args.stride == 1 and self.inp == self.final_oup:
if drop_connect_rate:
- x = self._drop_connect(
- x, p=drop_connect_rate, training=self.training)
+ x = self._drop_connect(x, p=drop_connect_rate, training=self.training)
x = x + inputs
return x
@@ -225,15 +247,14 @@ def __init__(self, in_channels):
"""
w, d, s, p = 1.2, 1.4, 64, 0.3
self._blocks_args, self._global_params = efficientnet(
- width_coefficient=w,
- depth_coefficient=d,
- dropout_rate=p,
- image_size=s)
+ width_coefficient=w, depth_coefficient=d, dropout_rate=p, image_size=s
+ )
self.out_channels = []
# stem
out_channels = EffUtils.round_filters(32, self._global_params)
self._conv_stem = nn.Conv2D(
- in_channels, out_channels, 3, 2, padding='same', bias_attr=False)
+ in_channels, out_channels, 3, 2, padding="same", bias_attr=False
+ )
self._bn0 = nn.BatchNorm(out_channels)
# build blocks
@@ -243,23 +264,28 @@ def __init__(self, in_channels):
_concerned_idx = 0
for i, block_args in enumerate(self._blocks_args):
block_args = block_args._replace(
- input_filters=EffUtils.round_filters(block_args.input_filters,
- self._global_params),
- output_filters=EffUtils.round_filters(block_args.output_filters,
- self._global_params),
- num_repeat=EffUtils.round_repeats(block_args.num_repeat,
- self._global_params))
- self._blocks.append(
- self.add_sublayer(f"{i}-0", MbConvBlock(block_args)))
+ input_filters=EffUtils.round_filters(
+ block_args.input_filters, self._global_params
+ ),
+ output_filters=EffUtils.round_filters(
+ block_args.output_filters, self._global_params
+ ),
+ num_repeat=EffUtils.round_repeats(
+ block_args.num_repeat, self._global_params
+ ),
+ )
+ self._blocks.append(self.add_sublayer(f"{i}-0", MbConvBlock(block_args)))
_concerned_idx += 1
if _concerned_idx in self._concerned_block_idxes:
self.out_channels.append(block_args.output_filters)
if block_args.num_repeat > 1:
block_args = block_args._replace(
- input_filters=block_args.output_filters, stride=1)
+ input_filters=block_args.output_filters, stride=1
+ )
for j in range(block_args.num_repeat - 1):
self._blocks.append(
- self.add_sublayer(f'{i}-{j+1}', MbConvBlock(block_args)))
+ self.add_sublayer(f"{i}-{j+1}", MbConvBlock(block_args))
+ )
_concerned_idx += 1
if _concerned_idx in self._concerned_block_idxes:
self.out_channels.append(block_args.output_filters)
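Two further conventions appear in rec_efficientb3_pren.py: one-line docstrings lose their inner padding spaces, and the long namedtuple field lists are exploded one name per line with a trailing comma. Docstring sketch:

    def round_repeats(repeats, global_params):
        """ Round number of filters based on depth multiplier. """  # before
        ...

    def round_repeats(repeats, global_params):
        """Round number of filters based on depth multiplier."""  # after
        ...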
diff --git a/ppocr/modeling/backbones/rec_hgnet.py b/ppocr/modeling/backbones/rec_hgnet.py
index d990453308..1d59f2501c 100644
--- a/ppocr/modeling/backbones/rec_hgnet.py
+++ b/ppocr/modeling/backbones/rec_hgnet.py
@@ -21,18 +21,14 @@
from paddle import ParamAttr
kaiming_normal_ = KaimingNormal()
-zeros_ = Constant(value=0.)
-ones_ = Constant(value=1.)
+zeros_ = Constant(value=0.0)
+ones_ = Constant(value=1.0)
class ConvBNAct(nn.Layer):
- def __init__(self,
- in_channels,
- out_channels,
- kernel_size,
- stride,
- groups=1,
- use_act=True):
+ def __init__(
+ self, in_channels, out_channels, kernel_size, stride, groups=1, use_act=True
+ ):
super().__init__()
self.use_act = use_act
self.conv = Conv2D(
@@ -42,11 +38,13 @@ def __init__(self,
stride,
padding=(kernel_size - 1) // 2,
groups=groups,
- bias_attr=False)
+ bias_attr=False,
+ )
self.bn = BatchNorm2D(
out_channels,
weight_attr=ParamAttr(regularizer=L2Decay(0.0)),
- bias_attr=ParamAttr(regularizer=L2Decay(0.0)))
+ bias_attr=ParamAttr(regularizer=L2Decay(0.0)),
+ )
if self.use_act:
self.act = ReLU()
@@ -67,7 +65,8 @@ def __init__(self, channels):
out_channels=channels,
kernel_size=1,
stride=1,
- padding=0)
+ padding=0,
+ )
self.sigmoid = nn.Sigmoid()
def forward(self, x):
@@ -80,12 +79,13 @@ def forward(self, x):
class HG_Block(nn.Layer):
def __init__(
- self,
- in_channels,
- mid_channels,
- out_channels,
- layer_num,
- identity=False, ):
+ self,
+ in_channels,
+ mid_channels,
+ out_channels,
+ layer_num,
+ identity=False,
+ ):
super().__init__()
self.identity = identity
@@ -95,14 +95,18 @@ def __init__(
in_channels=in_channels,
out_channels=mid_channels,
kernel_size=3,
- stride=1))
+ stride=1,
+ )
+ )
for _ in range(layer_num - 1):
self.layers.append(
ConvBNAct(
in_channels=mid_channels,
out_channels=mid_channels,
kernel_size=3,
- stride=1))
+ stride=1,
+ )
+ )
# feature aggregation
total_channels = in_channels + layer_num * mid_channels
@@ -110,7 +114,8 @@ def __init__(
in_channels=total_channels,
out_channels=out_channels,
kernel_size=1,
- stride=1)
+ stride=1,
+ )
self.att = ESEModule(out_channels)
def forward(self, x):
@@ -129,14 +134,16 @@ def forward(self, x):
class HG_Stage(nn.Layer):
- def __init__(self,
- in_channels,
- mid_channels,
- out_channels,
- block_num,
- layer_num,
- downsample=True,
- stride=[2, 1]):
+ def __init__(
+ self,
+ in_channels,
+ mid_channels,
+ out_channels,
+ block_num,
+ layer_num,
+ downsample=True,
+ stride=[2, 1],
+ ):
super().__init__()
self.downsample = downsample
if downsample:
@@ -146,24 +153,19 @@ def __init__(self,
kernel_size=3,
stride=stride,
groups=in_channels,
- use_act=False)
+ use_act=False,
+ )
blocks_list = []
blocks_list.append(
- HG_Block(
- in_channels,
- mid_channels,
- out_channels,
- layer_num,
- identity=False))
+ HG_Block(in_channels, mid_channels, out_channels, layer_num, identity=False)
+ )
for _ in range(block_num - 1):
blocks_list.append(
HG_Block(
- out_channels,
- mid_channels,
- out_channels,
- layer_num,
- identity=True))
+ out_channels, mid_channels, out_channels, layer_num, identity=True
+ )
+ )
self.blocks = nn.Sequential(*blocks_list)
def forward(self, x):
@@ -189,29 +191,31 @@ class PPHGNet(nn.Layer):
"""
def __init__(
- self,
- stem_channels,
- stage_config,
- layer_num,
- in_channels=3,
- det=False,
- out_indices=None, ):
+ self,
+ stem_channels,
+ stage_config,
+ layer_num,
+ in_channels=3,
+ det=False,
+ out_indices=None,
+ ):
super().__init__()
self.det = det
- self.out_indices = out_indices if out_indices is not None else [
- 0, 1, 2, 3
- ]
+ self.out_indices = out_indices if out_indices is not None else [0, 1, 2, 3]
# stem
stem_channels.insert(0, in_channels)
- self.stem = nn.Sequential(* [
- ConvBNAct(
- in_channels=stem_channels[i],
- out_channels=stem_channels[i + 1],
- kernel_size=3,
- stride=2 if i == 0 else 1) for i in range(
- len(stem_channels) - 1)
- ])
+ self.stem = nn.Sequential(
+ *[
+ ConvBNAct(
+ in_channels=stem_channels[i],
+ out_channels=stem_channels[i + 1],
+ kernel_size=3,
+ stride=2 if i == 0 else 1,
+ )
+ for i in range(len(stem_channels) - 1)
+ ]
+ )
if self.det:
self.pool = nn.MaxPool2D(kernel_size=3, stride=2, padding=1)
@@ -219,11 +223,25 @@ def __init__(
self.stages = nn.LayerList()
self.out_channels = []
for block_id, k in enumerate(stage_config):
- in_channels, mid_channels, out_channels, block_num, downsample, stride = stage_config[
- k]
+ (
+ in_channels,
+ mid_channels,
+ out_channels,
+ block_num,
+ downsample,
+ stride,
+ ) = stage_config[k]
self.stages.append(
- HG_Stage(in_channels, mid_channels, out_channels, block_num,
- layer_num, downsample, stride))
+ HG_Stage(
+ in_channels,
+ mid_channels,
+ out_channels,
+ block_num,
+ layer_num,
+ downsample,
+ stride,
+ )
+ )
if block_id in self.out_indices:
self.out_channels.append(out_channels)
@@ -281,10 +299,8 @@ def PPHGNet_tiny(pretrained=False, use_ssld=False, **kwargs):
}
model = PPHGNet(
- stem_channels=[48, 48, 96],
- stage_config=stage_config,
- layer_num=5,
- **kwargs)
+ stem_channels=[48, 48, 96], stage_config=stage_config, layer_num=5, **kwargs
+ )
return model
@@ -319,7 +335,8 @@ def PPHGNet_small(pretrained=False, use_ssld=False, det=False, **kwargs):
stage_config=stage_config_det if det else stage_config_rec,
layer_num=6,
det=det,
- **kwargs)
+ **kwargs
+ )
return model
@@ -346,5 +363,6 @@ def PPHGNet_base(pretrained=False, use_ssld=True, **kwargs):
stage_config=stage_config,
layer_num=7,
dropout_prob=0.2,
- **kwargs)
+ **kwargs
+ )
return model
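rec_hgnet.py adds numeric-literal normalization: a float written with a bare trailing dot gains an explicit zero, so `0.` becomes `0.0` (a leading dot such as `.5` would likewise become `0.5`). The value is the same float either way:

    zeros = 0.
    assert zeros == 0.0  # identical value; only the spelling changed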
diff --git a/ppocr/modeling/backbones/rec_lcnetv3.py b/ppocr/modeling/backbones/rec_lcnetv3.py
index ab0951761d..48a43f8bb5 100644
--- a/ppocr/modeling/backbones/rec_lcnetv3.py
+++ b/ppocr/modeling/backbones/rec_lcnetv3.py
@@ -21,33 +21,59 @@
import paddle.nn.functional as F
from paddle import ParamAttr
from paddle.nn.initializer import Constant, KaimingNormal
-from paddle.nn import AdaptiveAvgPool2D, BatchNorm2D, Conv2D, Dropout, Hardsigmoid, Hardswish, Identity, Linear, ReLU
+from paddle.nn import (
+ AdaptiveAvgPool2D,
+ BatchNorm2D,
+ Conv2D,
+ Dropout,
+ Hardsigmoid,
+ Hardswish,
+ Identity,
+ Linear,
+ ReLU,
+)
from paddle.regularizer import L2Decay
NET_CONFIG_det = {
"blocks2":
- #k, in_c, out_c, s, use_se
+ # k, in_c, out_c, s, use_se
[[3, 16, 32, 1, False]],
"blocks3": [[3, 32, 64, 2, False], [3, 64, 64, 1, False]],
"blocks4": [[3, 64, 128, 2, False], [3, 128, 128, 1, False]],
- "blocks5":
- [[3, 128, 256, 2, False], [5, 256, 256, 1, False], [5, 256, 256, 1, False],
- [5, 256, 256, 1, False], [5, 256, 256, 1, False]],
- "blocks6": [[5, 256, 512, 2, True], [5, 512, 512, 1, True],
- [5, 512, 512, 1, False], [5, 512, 512, 1, False]]
+ "blocks5": [
+ [3, 128, 256, 2, False],
+ [5, 256, 256, 1, False],
+ [5, 256, 256, 1, False],
+ [5, 256, 256, 1, False],
+ [5, 256, 256, 1, False],
+ ],
+ "blocks6": [
+ [5, 256, 512, 2, True],
+ [5, 512, 512, 1, True],
+ [5, 512, 512, 1, False],
+ [5, 512, 512, 1, False],
+ ],
}
NET_CONFIG_rec = {
"blocks2":
- #k, in_c, out_c, s, use_se
+ # k, in_c, out_c, s, use_se
[[3, 16, 32, 1, False]],
"blocks3": [[3, 32, 64, 1, False], [3, 64, 64, 1, False]],
"blocks4": [[3, 64, 128, (2, 1), False], [3, 128, 128, 1, False]],
- "blocks5":
- [[3, 128, 256, (1, 2), False], [5, 256, 256, 1, False],
- [5, 256, 256, 1, False], [5, 256, 256, 1, False], [5, 256, 256, 1, False]],
- "blocks6": [[5, 256, 512, (2, 1), True], [5, 512, 512, 1, True],
- [5, 512, 512, (2, 1), False], [5, 512, 512, 1, False]]
+ "blocks5": [
+ [3, 128, 256, (1, 2), False],
+ [5, 256, 256, 1, False],
+ [5, 256, 256, 1, False],
+ [5, 256, 256, 1, False],
+ [5, 256, 256, 1, False],
+ ],
+ "blocks6": [
+ [5, 256, 512, (2, 1), True],
+ [5, 512, 512, 1, True],
+ [5, 512, 512, (2, 1), False],
+ [5, 512, 512, 1, False],
+ ],
}
@@ -61,18 +87,23 @@ def make_divisible(v, divisor=16, min_value=None):
class LearnableAffineBlock(nn.Layer):
- def __init__(self, scale_value=1.0, bias_value=0.0, lr_mult=1.0,
- lab_lr=0.1):
+ def __init__(self, scale_value=1.0, bias_value=0.0, lr_mult=1.0, lab_lr=0.1):
super().__init__()
self.scale = self.create_parameter(
- shape=[1, ],
+ shape=[
+ 1,
+ ],
default_initializer=Constant(value=scale_value),
- attr=ParamAttr(learning_rate=lr_mult * lab_lr))
+ attr=ParamAttr(learning_rate=lr_mult * lab_lr),
+ )
self.add_parameter("scale", self.scale)
self.bias = self.create_parameter(
- shape=[1, ],
+ shape=[
+ 1,
+ ],
default_initializer=Constant(value=bias_value),
- attr=ParamAttr(learning_rate=lr_mult * lab_lr))
+ attr=ParamAttr(learning_rate=lr_mult * lab_lr),
+ )
self.add_parameter("bias", self.bias)
def forward(self, x):
@@ -80,13 +111,9 @@ def forward(self, x):
class ConvBNLayer(nn.Layer):
- def __init__(self,
- in_channels,
- out_channels,
- kernel_size,
- stride,
- groups=1,
- lr_mult=1.0):
+ def __init__(
+ self, in_channels, out_channels, kernel_size, stride, groups=1, lr_mult=1.0
+ ):
super().__init__()
self.conv = Conv2D(
in_channels=in_channels,
@@ -95,16 +122,15 @@ def __init__(self,
stride=stride,
padding=(kernel_size - 1) // 2,
groups=groups,
- weight_attr=ParamAttr(
- initializer=KaimingNormal(), learning_rate=lr_mult),
- bias_attr=False)
+ weight_attr=ParamAttr(initializer=KaimingNormal(), learning_rate=lr_mult),
+ bias_attr=False,
+ )
self.bn = BatchNorm2D(
out_channels,
- weight_attr=ParamAttr(
- regularizer=L2Decay(0.0), learning_rate=lr_mult),
- bias_attr=ParamAttr(
- regularizer=L2Decay(0.0), learning_rate=lr_mult))
+ weight_attr=ParamAttr(regularizer=L2Decay(0.0), learning_rate=lr_mult),
+ bias_attr=ParamAttr(regularizer=L2Decay(0.0), learning_rate=lr_mult),
+ )
def forward(self, x):
x = self.conv(x)
@@ -127,15 +153,17 @@ def forward(self, x):
class LearnableRepLayer(nn.Layer):
- def __init__(self,
- in_channels,
- out_channels,
- kernel_size,
- stride=1,
- groups=1,
- num_conv_branches=1,
- lr_mult=1.0,
- lab_lr=0.1):
+ def __init__(
+ self,
+ in_channels,
+ out_channels,
+ kernel_size,
+ stride=1,
+ groups=1,
+ num_conv_branches=1,
+ lr_mult=1.0,
+ lab_lr=0.1,
+ ):
super().__init__()
self.is_repped = False
self.groups = groups
@@ -146,29 +174,37 @@ def __init__(self,
self.num_conv_branches = num_conv_branches
self.padding = (kernel_size - 1) // 2
- self.identity = BatchNorm2D(
- num_features=in_channels,
- weight_attr=ParamAttr(learning_rate=lr_mult),
- bias_attr=ParamAttr(learning_rate=lr_mult)
- ) if out_channels == in_channels and stride == 1 else None
+ self.identity = (
+ BatchNorm2D(
+ num_features=in_channels,
+ weight_attr=ParamAttr(learning_rate=lr_mult),
+ bias_attr=ParamAttr(learning_rate=lr_mult),
+ )
+ if out_channels == in_channels and stride == 1
+ else None
+ )
+
+ self.conv_kxk = nn.LayerList(
+ [
+ ConvBNLayer(
+ in_channels,
+ out_channels,
+ kernel_size,
+ stride,
+ groups=groups,
+ lr_mult=lr_mult,
+ )
+ for _ in range(self.num_conv_branches)
+ ]
+ )
- self.conv_kxk = nn.LayerList([
+ self.conv_1x1 = (
ConvBNLayer(
- in_channels,
- out_channels,
- kernel_size,
- stride,
- groups=groups,
- lr_mult=lr_mult) for _ in range(self.num_conv_branches)
- ])
-
- self.conv_1x1 = ConvBNLayer(
- in_channels,
- out_channels,
- 1,
- stride,
- groups=groups,
- lr_mult=lr_mult) if kernel_size > 1 else None
+ in_channels, out_channels, 1, stride, groups=groups, lr_mult=lr_mult
+ )
+ if kernel_size > 1
+ else None
+ )
self.lab = LearnableAffineBlock(lr_mult=lr_mult, lab_lr=lab_lr)
self.act = Act(lr_mult=lr_mult, lab_lr=lab_lr)
@@ -206,7 +242,8 @@ def rep(self):
kernel_size=self.kernel_size,
stride=self.stride,
padding=self.padding,
- groups=self.groups)
+ groups=self.groups,
+ )
self.reparam_conv.weight.set_value(kernel)
self.reparam_conv.bias.set_value(bias)
self.is_repped = True
@@ -219,8 +256,9 @@ def _pad_kernel_1x1_to_kxk(self, kernel1x1, pad):
def _get_kernel_bias(self):
kernel_conv_1x1, bias_conv_1x1 = self._fuse_bn_tensor(self.conv_1x1)
- kernel_conv_1x1 = self._pad_kernel_1x1_to_kxk(kernel_conv_1x1,
- self.kernel_size // 2)
+ kernel_conv_1x1 = self._pad_kernel_1x1_to_kxk(
+ kernel_conv_1x1, self.kernel_size // 2
+ )
kernel_identity, bias_identity = self._fuse_bn_tensor(self.identity)
@@ -247,15 +285,16 @@ def _fuse_bn_tensor(self, branch):
eps = branch.bn._epsilon
else:
assert isinstance(branch, BatchNorm2D)
- if not hasattr(self, 'id_tensor'):
+ if not hasattr(self, "id_tensor"):
input_dim = self.in_channels // self.groups
kernel_value = paddle.zeros(
- (self.in_channels, input_dim, self.kernel_size,
- self.kernel_size),
- dtype=branch.weight.dtype)
+ (self.in_channels, input_dim, self.kernel_size, self.kernel_size),
+ dtype=branch.weight.dtype,
+ )
for i in range(self.in_channels):
- kernel_value[i, i % input_dim, self.kernel_size // 2,
- self.kernel_size // 2] = 1
+ kernel_value[
+ i, i % input_dim, self.kernel_size // 2, self.kernel_size // 2
+ ] = 1
self.id_tensor = kernel_value
kernel = self.id_tensor
running_mean = branch._mean
@@ -279,7 +318,8 @@ def __init__(self, channel, reduction=4, lr_mult=1.0):
stride=1,
padding=0,
weight_attr=ParamAttr(learning_rate=lr_mult),
- bias_attr=ParamAttr(learning_rate=lr_mult))
+ bias_attr=ParamAttr(learning_rate=lr_mult),
+ )
self.relu = ReLU()
self.conv2 = Conv2D(
in_channels=channel // reduction,
@@ -288,7 +328,8 @@ def __init__(self, channel, reduction=4, lr_mult=1.0):
stride=1,
padding=0,
weight_attr=ParamAttr(learning_rate=lr_mult),
- bias_attr=ParamAttr(learning_rate=lr_mult))
+ bias_attr=ParamAttr(learning_rate=lr_mult),
+ )
self.hardsigmoid = Hardsigmoid()
def forward(self, x):
@@ -303,15 +344,17 @@ def forward(self, x):
class LCNetV3Block(nn.Layer):
- def __init__(self,
- in_channels,
- out_channels,
- stride,
- dw_size,
- use_se=False,
- conv_kxk_num=4,
- lr_mult=1.0,
- lab_lr=0.1):
+ def __init__(
+ self,
+ in_channels,
+ out_channels,
+ stride,
+ dw_size,
+ use_se=False,
+ conv_kxk_num=4,
+ lr_mult=1.0,
+ lab_lr=0.1,
+ ):
super().__init__()
self.use_se = use_se
self.dw_conv = LearnableRepLayer(
@@ -322,7 +365,8 @@ def __init__(self,
groups=in_channels,
num_conv_branches=conv_kxk_num,
lr_mult=lr_mult,
- lab_lr=lab_lr)
+ lab_lr=lab_lr,
+ )
if use_se:
self.se = SELayer(in_channels, lr_mult=lr_mult)
self.pw_conv = LearnableRepLayer(
@@ -332,7 +376,8 @@ def __init__(self,
stride=1,
num_conv_branches=conv_kxk_num,
lr_mult=lr_mult,
- lab_lr=lab_lr)
+ lab_lr=lab_lr,
+ )
def forward(self, x):
x = self.dw_conv(x)
@@ -343,13 +388,15 @@ def forward(self, x):
class PPLCNetV3(nn.Layer):
- def __init__(self,
- scale=1.0,
- conv_kxk_num=4,
- lr_mult_list=[1.0, 1.0, 1.0, 1.0, 1.0, 1.0],
- lab_lr=0.1,
- det=False,
- **kwargs):
+ def __init__(
+ self,
+ scale=1.0,
+ conv_kxk_num=4,
+ lr_mult_list=[1.0, 1.0, 1.0, 1.0, 1.0, 1.0],
+ lab_lr=0.1,
+ det=False,
+ **kwargs
+ ):
super().__init__()
self.scale = scale
self.lr_mult_list = lr_mult_list
@@ -357,90 +404,102 @@ def __init__(self,
self.net_config = NET_CONFIG_det if self.det else NET_CONFIG_rec
- assert isinstance(self.lr_mult_list, (
- list, tuple
- )), "lr_mult_list should be in (list, tuple) but got {}".format(
- type(self.lr_mult_list))
- assert len(self.lr_mult_list
- ) == 6, "lr_mult_list length should be 6 but got {}".format(
- len(self.lr_mult_list))
+ assert isinstance(
+ self.lr_mult_list, (list, tuple)
+ ), "lr_mult_list should be in (list, tuple) but got {}".format(
+ type(self.lr_mult_list)
+ )
+ assert (
+ len(self.lr_mult_list) == 6
+ ), "lr_mult_list length should be 6 but got {}".format(len(self.lr_mult_list))
self.conv1 = ConvBNLayer(
in_channels=3,
out_channels=make_divisible(16 * scale),
kernel_size=3,
stride=2,
- lr_mult=self.lr_mult_list[0])
-
- self.blocks2 = nn.Sequential(*[
- LCNetV3Block(
- in_channels=make_divisible(in_c * scale),
- out_channels=make_divisible(out_c * scale),
- dw_size=k,
- stride=s,
- use_se=se,
- conv_kxk_num=conv_kxk_num,
- lr_mult=self.lr_mult_list[1],
- lab_lr=lab_lr)
- for i, (k, in_c, out_c, s, se
- ) in enumerate(self.net_config["blocks2"])
- ])
-
- self.blocks3 = nn.Sequential(*[
- LCNetV3Block(
- in_channels=make_divisible(in_c * scale),
- out_channels=make_divisible(out_c * scale),
- dw_size=k,
- stride=s,
- use_se=se,
- conv_kxk_num=conv_kxk_num,
- lr_mult=self.lr_mult_list[2],
- lab_lr=lab_lr)
- for i, (k, in_c, out_c, s, se
- ) in enumerate(self.net_config["blocks3"])
- ])
-
- self.blocks4 = nn.Sequential(*[
- LCNetV3Block(
- in_channels=make_divisible(in_c * scale),
- out_channels=make_divisible(out_c * scale),
- dw_size=k,
- stride=s,
- use_se=se,
- conv_kxk_num=conv_kxk_num,
- lr_mult=self.lr_mult_list[3],
- lab_lr=lab_lr)
- for i, (k, in_c, out_c, s, se
- ) in enumerate(self.net_config["blocks4"])
- ])
-
- self.blocks5 = nn.Sequential(*[
- LCNetV3Block(
- in_channels=make_divisible(in_c * scale),
- out_channels=make_divisible(out_c * scale),
- dw_size=k,
- stride=s,
- use_se=se,
- conv_kxk_num=conv_kxk_num,
- lr_mult=self.lr_mult_list[4],
- lab_lr=lab_lr)
- for i, (k, in_c, out_c, s, se
- ) in enumerate(self.net_config["blocks5"])
- ])
-
- self.blocks6 = nn.Sequential(*[
- LCNetV3Block(
- in_channels=make_divisible(in_c * scale),
- out_channels=make_divisible(out_c * scale),
- dw_size=k,
- stride=s,
- use_se=se,
- conv_kxk_num=conv_kxk_num,
- lr_mult=self.lr_mult_list[5],
- lab_lr=lab_lr)
- for i, (k, in_c, out_c, s, se
- ) in enumerate(self.net_config["blocks6"])
- ])
+ lr_mult=self.lr_mult_list[0],
+ )
+
+ self.blocks2 = nn.Sequential(
+ *[
+ LCNetV3Block(
+ in_channels=make_divisible(in_c * scale),
+ out_channels=make_divisible(out_c * scale),
+ dw_size=k,
+ stride=s,
+ use_se=se,
+ conv_kxk_num=conv_kxk_num,
+ lr_mult=self.lr_mult_list[1],
+ lab_lr=lab_lr,
+ )
+ for i, (k, in_c, out_c, s, se) in enumerate(self.net_config["blocks2"])
+ ]
+ )
+
+ self.blocks3 = nn.Sequential(
+ *[
+ LCNetV3Block(
+ in_channels=make_divisible(in_c * scale),
+ out_channels=make_divisible(out_c * scale),
+ dw_size=k,
+ stride=s,
+ use_se=se,
+ conv_kxk_num=conv_kxk_num,
+ lr_mult=self.lr_mult_list[2],
+ lab_lr=lab_lr,
+ )
+ for i, (k, in_c, out_c, s, se) in enumerate(self.net_config["blocks3"])
+ ]
+ )
+
+ self.blocks4 = nn.Sequential(
+ *[
+ LCNetV3Block(
+ in_channels=make_divisible(in_c * scale),
+ out_channels=make_divisible(out_c * scale),
+ dw_size=k,
+ stride=s,
+ use_se=se,
+ conv_kxk_num=conv_kxk_num,
+ lr_mult=self.lr_mult_list[3],
+ lab_lr=lab_lr,
+ )
+ for i, (k, in_c, out_c, s, se) in enumerate(self.net_config["blocks4"])
+ ]
+ )
+
+ self.blocks5 = nn.Sequential(
+ *[
+ LCNetV3Block(
+ in_channels=make_divisible(in_c * scale),
+ out_channels=make_divisible(out_c * scale),
+ dw_size=k,
+ stride=s,
+ use_se=se,
+ conv_kxk_num=conv_kxk_num,
+ lr_mult=self.lr_mult_list[4],
+ lab_lr=lab_lr,
+ )
+ for i, (k, in_c, out_c, s, se) in enumerate(self.net_config["blocks5"])
+ ]
+ )
+
+ self.blocks6 = nn.Sequential(
+ *[
+ LCNetV3Block(
+ in_channels=make_divisible(in_c * scale),
+ out_channels=make_divisible(out_c * scale),
+ dw_size=k,
+ stride=s,
+ use_se=se,
+ conv_kxk_num=conv_kxk_num,
+ lr_mult=self.lr_mult_list[5],
+ lab_lr=lab_lr,
+ )
+ for i, (k, in_c, out_c, s, se) in enumerate(self.net_config["blocks6"])
+ ]
+ )
self.out_channels = make_divisible(512 * scale)
if self.det:
@@ -452,15 +511,19 @@ def __init__(self,
make_divisible(self.net_config["blocks6"][-1][2] * scale),
]
- self.layer_list = nn.LayerList([
- nn.Conv2D(self.out_channels[0], int(mv_c[0] * scale), 1, 1, 0),
- nn.Conv2D(self.out_channels[1], int(mv_c[1] * scale), 1, 1, 0),
- nn.Conv2D(self.out_channels[2], int(mv_c[2] * scale), 1, 1, 0),
- nn.Conv2D(self.out_channels[3], int(mv_c[3] * scale), 1, 1, 0)
- ])
+ self.layer_list = nn.LayerList(
+ [
+ nn.Conv2D(self.out_channels[0], int(mv_c[0] * scale), 1, 1, 0),
+ nn.Conv2D(self.out_channels[1], int(mv_c[1] * scale), 1, 1, 0),
+ nn.Conv2D(self.out_channels[2], int(mv_c[2] * scale), 1, 1, 0),
+ nn.Conv2D(self.out_channels[3], int(mv_c[3] * scale), 1, 1, 0),
+ ]
+ )
self.out_channels = [
- int(mv_c[0] * scale), int(mv_c[1] * scale),
- int(mv_c[2] * scale), int(mv_c[3] * scale)
+ int(mv_c[0] * scale),
+ int(mv_c[1] * scale),
+ int(mv_c[2] * scale),
+ int(mv_c[3] * scale),
]
def forward(self, x):
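The rec_lcnetv3.py hunks illustrate two wrapping rules. A conditional expression too long for one line is parenthesized with the `if` and `else` branches on their own lines, and a pre-existing trailing comma is treated as "magic", pinning a collection open even when it would fit on one line; that is why `shape=[1, ]` explodes to three lines instead of collapsing to `shape=[1]`. A sketch with hypothetical values:

    kernel_size, stride = 3, 1

    # Wrapped conditional expression: value, `if`, and `else` each on a line.
    conv_1x1 = (
        "ConvBNLayer placeholder"  # hypothetical stand-in for the 1x1 branch
        if kernel_size > 1
        else None
    )

    # Magic trailing comma: the comma after 1 keeps the list exploded.
    shape = [
        1,
    ]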
diff --git a/ppocr/modeling/backbones/rec_micronet.py b/ppocr/modeling/backbones/rec_micronet.py
index b0ae5a14c3..6550c92018 100644
--- a/ppocr/modeling/backbones/rec_micronet.py
+++ b/ppocr/modeling/backbones/rec_micronet.py
@@ -76,7 +76,7 @@
def get_micronet_config(mode):
- return eval(mode + '_cfgs')
+ return eval(mode + "_cfgs")
class MaxGroupPooling(nn.Layer):
@@ -104,18 +104,26 @@ def __init__(self, inp, oups, kernel_size, stride):
self.conv = nn.Sequential(
nn.Conv2D(
inp,
- oup1, (kernel_size, 1), (stride, 1), (kernel_size // 2, 0),
+ oup1,
+ (kernel_size, 1),
+ (stride, 1),
+ (kernel_size // 2, 0),
bias_attr=False,
- groups=1),
+ groups=1,
+ ),
nn.BatchNorm2D(oup1),
nn.Conv2D(
oup1,
- oup1 * oup2, (1, kernel_size), (1, stride),
+ oup1 * oup2,
+ (1, kernel_size),
+ (1, stride),
(0, kernel_size // 2),
bias_attr=False,
- groups=oup1),
+ groups=oup1,
+ ),
nn.BatchNorm2D(oup1 * oup2),
- ChannelShuffle(oup1), )
+ ChannelShuffle(oup1),
+ )
def forward(self, x):
out = self.conv(x)
@@ -148,7 +156,8 @@ def __init__(self, inp, oup, stride, groups=(4, 4)):
g1, g2 = groups
self.stem = nn.Sequential(
SpatialSepConvSF(inp, groups, 3, stride),
- MaxGroupPooling(2) if g1 * g2 == 2 * oup else nn.ReLU6())
+ MaxGroupPooling(2) if g1 * g2 == 2 * oup else nn.ReLU6(),
+ )
def forward(self, x):
out = self.stem(x)
@@ -167,18 +176,25 @@ def __init__(self, inp, expand, kernel_size, stride):
self.conv = nn.Sequential(
nn.Conv2D(
inp,
- inp * exp1, (kernel_size, 1), (stride, 1),
+ inp * exp1,
+ (kernel_size, 1),
+ (stride, 1),
(kernel_size // 2, 0),
bias_attr=False,
- groups=inp),
+ groups=inp,
+ ),
nn.BatchNorm2D(inp * exp1),
nn.Conv2D(
hidden_dim,
- oup, (1, kernel_size),
- 1, (0, kernel_size // 2),
+ oup,
+ (1, kernel_size),
+ 1,
+ (0, kernel_size // 2),
bias_attr=False,
- groups=hidden_dim),
- nn.BatchNorm2D(oup))
+ groups=hidden_dim,
+ ),
+ nn.BatchNorm2D(oup),
+ )
def forward(self, x):
x = self.conv(x)
@@ -192,9 +208,9 @@ def __init__(self, inp, oup, groups=2):
self.oup = oup
self.groups = groups
self.conv = nn.Sequential(
- nn.Conv2D(
- inp, oup, 1, 1, 0, bias_attr=False, groups=self.groups[0]),
- nn.BatchNorm2D(oup))
+ nn.Conv2D(inp, oup, 1, 1, 0, bias_attr=False, groups=self.groups[0]),
+ nn.BatchNorm2D(oup),
+ )
def forward(self, x):
x = self.conv(x)
@@ -212,8 +228,10 @@ def __init__(self, inp, oup, kernel_size, stride):
stride,
kernel_size // 2,
bias_attr=False,
- groups=inp),
- nn.BatchNorm2D(oup))
+ groups=inp,
+ ),
+ nn.BatchNorm2D(oup),
+ )
def forward(self, x):
out = self.conv(x)
@@ -221,23 +239,27 @@ def forward(self, x):
class DYShiftMax(nn.Layer):
- def __init__(self,
- inp,
- oup,
- reduction=4,
- act_max=1.0,
- act_relu=True,
- init_a=[0.0, 0.0],
- init_b=[0.0, 0.0],
- relu_before_pool=False,
- g=None,
- expansion=False):
+ def __init__(
+ self,
+ inp,
+ oup,
+ reduction=4,
+ act_max=1.0,
+ act_relu=True,
+ init_a=[0.0, 0.0],
+ init_b=[0.0, 0.0],
+ relu_before_pool=False,
+ g=None,
+ expansion=False,
+ ):
super(DYShiftMax, self).__init__()
self.oup = oup
self.act_max = act_max * 2
self.act_relu = act_relu
- self.avg_pool = nn.Sequential(nn.ReLU() if relu_before_pool == True else
- nn.Sequential(), nn.AdaptiveAvgPool2D(1))
+ self.avg_pool = nn.Sequential(
+ nn.ReLU() if relu_before_pool == True else nn.Sequential(),
+ nn.AdaptiveAvgPool2D(1),
+ )
self.exp = 4 if act_relu else 2
self.init_a = init_a
@@ -250,7 +272,10 @@ def __init__(self,
self.fc = nn.Sequential(
nn.Linear(inp, squeeze),
- nn.ReLU(), nn.Linear(squeeze, oup * self.exp), nn.Hardsigmoid())
+ nn.ReLU(),
+ nn.Linear(squeeze, oup * self.exp),
+ nn.Hardsigmoid(),
+ )
if g is None:
g = 1
@@ -309,25 +334,27 @@ def forward(self, x):
class DYMicroBlock(nn.Layer):
- def __init__(self,
- inp,
- oup,
- kernel_size=3,
- stride=1,
- ch_exp=(2, 2),
- ch_per_group=4,
- groups_1x1=(1, 1),
- depthsep=True,
- shuffle=False,
- activation_cfg=None):
+ def __init__(
+ self,
+ inp,
+ oup,
+ kernel_size=3,
+ stride=1,
+ ch_exp=(2, 2),
+ ch_per_group=4,
+ groups_1x1=(1, 1),
+ depthsep=True,
+ shuffle=False,
+ activation_cfg=None,
+ ):
super(DYMicroBlock, self).__init__()
self.identity = stride == 1 and inp == oup
- y1, y2, y3 = activation_cfg['dy']
- act_reduction = 8 * activation_cfg['ratio']
- init_a = activation_cfg['init_a']
- init_b = activation_cfg['init_b']
+ y1, y2, y3 = activation_cfg["dy"]
+ act_reduction = 8 * activation_cfg["ratio"]
+ init_a = activation_cfg["init_a"]
+ init_b = activation_cfg["init_b"]
t1 = ch_exp
gs1 = ch_per_group
@@ -346,10 +373,14 @@ def __init__(self,
reduction=act_reduction,
init_b=init_b,
g=gs1,
- expansion=False) if y2 > 0 else nn.ReLU6(),
+ expansion=False,
+ )
+ if y2 > 0
+ else nn.ReLU6(),
ChannelShuffle(gs1[1]) if shuffle else nn.Sequential(),
ChannelShuffle(hidden_dim2 // 2)
- if shuffle and y2 != 0 else nn.Sequential(),
+ if shuffle and y2 != 0
+ else nn.Sequential(),
GroupConv(hidden_dim2, oup, (g1, g2)),
DYShiftMax(
oup,
@@ -360,10 +391,15 @@ def __init__(self,
reduction=act_reduction // 2,
init_b=[0.0, 0.0],
g=(g1, g2),
- expansion=False) if y3 > 0 else nn.Sequential(),
+ expansion=False,
+ )
+ if y3 > 0
+ else nn.Sequential(),
ChannelShuffle(g2) if shuffle else nn.Sequential(),
ChannelShuffle(oup // 2)
- if shuffle and oup % 2 == 0 and y3 != 0 else nn.Sequential(), )
+ if shuffle and oup % 2 == 0 and y3 != 0
+ else nn.Sequential(),
+ )
elif g2 == 0:
self.layers = nn.Sequential(
GroupConv(inp, hidden_dim2, gs1),
@@ -376,7 +412,11 @@ def __init__(self,
reduction=act_reduction,
init_b=[0.0, 0.0],
g=gs1,
- expansion=False) if y3 > 0 else nn.Sequential(), )
+ expansion=False,
+ )
+ if y3 > 0
+ else nn.Sequential(),
+ )
else:
self.layers = nn.Sequential(
GroupConv(inp, hidden_dim2, gs1),
@@ -389,11 +429,14 @@ def __init__(self,
reduction=act_reduction,
init_b=init_b,
g=gs1,
- expansion=False) if y1 > 0 else nn.ReLU6(),
+ expansion=False,
+ )
+ if y1 > 0
+ else nn.ReLU6(),
ChannelShuffle(gs1[1]) if shuffle else nn.Sequential(),
DepthSpatialSepConv(hidden_dim2, (1, 1), kernel_size, stride)
- if depthsep else
- DepthConv(hidden_dim2, hidden_dim2, kernel_size, stride),
+ if depthsep
+ else DepthConv(hidden_dim2, hidden_dim2, kernel_size, stride),
nn.Sequential(),
DYShiftMax(
hidden_dim2,
@@ -404,10 +447,15 @@ def __init__(self,
reduction=act_reduction,
init_b=init_b,
g=gs1,
- expansion=True) if y2 > 0 else nn.ReLU6(),
+ expansion=True,
+ )
+ if y2 > 0
+ else nn.ReLU6(),
ChannelShuffle(hidden_dim2 // 4)
- if shuffle and y1 != 0 and y2 != 0 else nn.Sequential()
- if y1 == 0 and y2 == 0 else ChannelShuffle(hidden_dim2 // 2),
+ if shuffle and y1 != 0 and y2 != 0
+ else nn.Sequential()
+ if y1 == 0 and y2 == 0
+ else ChannelShuffle(hidden_dim2 // 2),
GroupConv(hidden_dim2, oup, (g1, g2)),
DYShiftMax(
oup,
@@ -416,13 +464,17 @@ def __init__(self,
act_relu=False,
init_a=[1.0, 0.0],
reduction=act_reduction // 2
- if oup < hidden_dim2 else act_reduction,
+ if oup < hidden_dim2
+ else act_reduction,
init_b=[0.0, 0.0],
g=(g1, g2),
- expansion=False) if y3 > 0 else nn.Sequential(),
+ expansion=False,
+ )
+ if y3 > 0
+ else nn.Sequential(),
ChannelShuffle(g2) if shuffle else nn.Sequential(),
- ChannelShuffle(oup // 2)
- if shuffle and y3 != 0 else nn.Sequential(), )
+ ChannelShuffle(oup // 2) if shuffle and y3 != 0 else nn.Sequential(),
+ )
def forward(self, x):
identity = x
@@ -436,46 +488,45 @@ def forward(self, x):
class MicroNet(nn.Layer):
"""
- the MicroNet backbone network for recognition module.
- Args:
- mode(str): {'M0', 'M1', 'M2', 'M3'}
- Four models are proposed based on four different computational costs (4M, 6M, 12M, 21M MAdds)
- Default: 'M3'.
+ the MicroNet backbone network for recognition module.
+ Args:
+ mode(str): {'M0', 'M1', 'M2', 'M3'}
+ Four models are proposed based on four different computational costs (4M, 6M, 12M, 21M MAdds)
+ Default: 'M3'.
"""
- def __init__(self, mode='M3', **kwargs):
+ def __init__(self, mode="M3", **kwargs):
super(MicroNet, self).__init__()
self.cfgs = get_micronet_config(mode)
activation_cfg = {}
- if mode == 'M0':
+ if mode == "M0":
input_channel = 4
stem_groups = 2, 2
out_ch = 384
- activation_cfg['init_a'] = 1.0, 1.0
- activation_cfg['init_b'] = 0.0, 0.0
- elif mode == 'M1':
+ activation_cfg["init_a"] = 1.0, 1.0
+ activation_cfg["init_b"] = 0.0, 0.0
+ elif mode == "M1":
input_channel = 6
stem_groups = 3, 2
out_ch = 576
- activation_cfg['init_a'] = 1.0, 1.0
- activation_cfg['init_b'] = 0.0, 0.0
- elif mode == 'M2':
+ activation_cfg["init_a"] = 1.0, 1.0
+ activation_cfg["init_b"] = 0.0, 0.0
+ elif mode == "M2":
input_channel = 8
stem_groups = 4, 2
out_ch = 768
- activation_cfg['init_a'] = 1.0, 1.0
- activation_cfg['init_b'] = 0.0, 0.0
- elif mode == 'M3':
+ activation_cfg["init_a"] = 1.0, 1.0
+ activation_cfg["init_b"] = 0.0, 0.0
+ elif mode == "M3":
input_channel = 12
stem_groups = 4, 3
out_ch = 432
- activation_cfg['init_a'] = 1.0, 0.5
- activation_cfg['init_b'] = 0.0, 0.5
+ activation_cfg["init_a"] = 1.0, 0.5
+ activation_cfg["init_b"] = 0.0, 0.5
else:
- raise NotImplementedError("mode[" + mode +
- "_model] is not implemented!")
+ raise NotImplementedError("mode[" + mode + "_model] is not implemented!")
layers = [StemLayer(3, input_channel, stride=2, groups=stem_groups)]
@@ -485,8 +536,8 @@ def __init__(self, mode='M3', **kwargs):
t1 = (c1, c2)
gs1 = (g1, g2)
gs2 = (c3, g3, g4)
- activation_cfg['dy'] = [y1, y2, y3]
- activation_cfg['ratio'] = r
+ activation_cfg["dy"] = [y1, y2, y3]
+ activation_cfg["ratio"] = r
output_channel = c
layers.append(
@@ -500,7 +551,9 @@ def __init__(self, mode='M3', **kwargs):
groups_1x1=gs2,
depthsep=True,
shuffle=True,
- activation_cfg=activation_cfg, ))
+ activation_cfg=activation_cfg,
+ )
+ )
input_channel = output_channel
for i in range(1, n):
layers.append(
@@ -514,7 +567,9 @@ def __init__(self, mode='M3', **kwargs):
groups_1x1=gs2,
depthsep=True,
shuffle=True,
- activation_cfg=activation_cfg, ))
+ activation_cfg=activation_cfg,
+ )
+ )
input_channel = output_channel
self.features = nn.Sequential(*layers)
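The rec_micronet.py hunks are dominated by chained conditional expressions inside nn.Sequential. The formatter only re-indents them, placing each `if`/`else` on its own line; evaluation is unchanged because conditional expressions group to the right. A self-contained sketch of that grouping:

    a, b, c = "a", "b", "c"
    cond1, cond2 = False, True

    # x if cond1 else y if cond2 else z  parses as
    # x if cond1 else (y if cond2 else z)
    chained = (
        a
        if cond1
        else b
        if cond2
        else c
    )
    assert chained == "b"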
diff --git a/ppocr/modeling/backbones/rec_mobilenet_v3.py b/ppocr/modeling/backbones/rec_mobilenet_v3.py
index 917e000d94..12831dcad8 100644
--- a/ppocr/modeling/backbones/rec_mobilenet_v3.py
+++ b/ppocr/modeling/backbones/rec_mobilenet_v3.py
@@ -14,20 +14,26 @@
from paddle import nn
-from ppocr.modeling.backbones.det_mobilenet_v3 import ResidualUnit, ConvBNLayer, make_divisible
+from ppocr.modeling.backbones.det_mobilenet_v3 import (
+ ResidualUnit,
+ ConvBNLayer,
+ make_divisible,
+)
-__all__ = ['MobileNetV3']
+__all__ = ["MobileNetV3"]
class MobileNetV3(nn.Layer):
- def __init__(self,
- in_channels=3,
- model_name='small',
- scale=0.5,
- large_stride=None,
- small_stride=None,
- disable_se=False,
- **kwargs):
+ def __init__(
+ self,
+ in_channels=3,
+ model_name="small",
+ scale=0.5,
+ large_stride=None,
+ small_stride=None,
+ disable_se=False,
+ **kwargs
+ ):
super(MobileNetV3, self).__init__()
self.disable_se = disable_se
if small_stride is None:
@@ -35,58 +41,66 @@ def __init__(self,
if large_stride is None:
large_stride = [1, 2, 2, 2]
- assert isinstance(large_stride, list), "large_stride type must " \
- "be list but got {}".format(type(large_stride))
- assert isinstance(small_stride, list), "small_stride type must " \
- "be list but got {}".format(type(small_stride))
- assert len(large_stride) == 4, "large_stride length must be " \
- "4 but got {}".format(len(large_stride))
- assert len(small_stride) == 4, "small_stride length must be " \
- "4 but got {}".format(len(small_stride))
+ assert isinstance(
+ large_stride, list
+ ), "large_stride type must " "be list but got {}".format(type(large_stride))
+ assert isinstance(
+ small_stride, list
+ ), "small_stride type must " "be list but got {}".format(type(small_stride))
+ assert (
+ len(large_stride) == 4
+ ), "large_stride length must be " "4 but got {}".format(len(large_stride))
+ assert (
+ len(small_stride) == 4
+ ), "small_stride length must be " "4 but got {}".format(len(small_stride))
if model_name == "large":
cfg = [
# k, exp, c, se, nl, s,
- [3, 16, 16, False, 'relu', large_stride[0]],
- [3, 64, 24, False, 'relu', (large_stride[1], 1)],
- [3, 72, 24, False, 'relu', 1],
- [5, 72, 40, True, 'relu', (large_stride[2], 1)],
- [5, 120, 40, True, 'relu', 1],
- [5, 120, 40, True, 'relu', 1],
- [3, 240, 80, False, 'hardswish', 1],
- [3, 200, 80, False, 'hardswish', 1],
- [3, 184, 80, False, 'hardswish', 1],
- [3, 184, 80, False, 'hardswish', 1],
- [3, 480, 112, True, 'hardswish', 1],
- [3, 672, 112, True, 'hardswish', 1],
- [5, 672, 160, True, 'hardswish', (large_stride[3], 1)],
- [5, 960, 160, True, 'hardswish', 1],
- [5, 960, 160, True, 'hardswish', 1],
+ [3, 16, 16, False, "relu", large_stride[0]],
+ [3, 64, 24, False, "relu", (large_stride[1], 1)],
+ [3, 72, 24, False, "relu", 1],
+ [5, 72, 40, True, "relu", (large_stride[2], 1)],
+ [5, 120, 40, True, "relu", 1],
+ [5, 120, 40, True, "relu", 1],
+ [3, 240, 80, False, "hardswish", 1],
+ [3, 200, 80, False, "hardswish", 1],
+ [3, 184, 80, False, "hardswish", 1],
+ [3, 184, 80, False, "hardswish", 1],
+ [3, 480, 112, True, "hardswish", 1],
+ [3, 672, 112, True, "hardswish", 1],
+ [5, 672, 160, True, "hardswish", (large_stride[3], 1)],
+ [5, 960, 160, True, "hardswish", 1],
+ [5, 960, 160, True, "hardswish", 1],
]
cls_ch_squeeze = 960
elif model_name == "small":
cfg = [
# k, exp, c, se, nl, s,
- [3, 16, 16, True, 'relu', (small_stride[0], 1)],
- [3, 72, 24, False, 'relu', (small_stride[1], 1)],
- [3, 88, 24, False, 'relu', 1],
- [5, 96, 40, True, 'hardswish', (small_stride[2], 1)],
- [5, 240, 40, True, 'hardswish', 1],
- [5, 240, 40, True, 'hardswish', 1],
- [5, 120, 48, True, 'hardswish', 1],
- [5, 144, 48, True, 'hardswish', 1],
- [5, 288, 96, True, 'hardswish', (small_stride[3], 1)],
- [5, 576, 96, True, 'hardswish', 1],
- [5, 576, 96, True, 'hardswish', 1],
+ [3, 16, 16, True, "relu", (small_stride[0], 1)],
+ [3, 72, 24, False, "relu", (small_stride[1], 1)],
+ [3, 88, 24, False, "relu", 1],
+ [5, 96, 40, True, "hardswish", (small_stride[2], 1)],
+ [5, 240, 40, True, "hardswish", 1],
+ [5, 240, 40, True, "hardswish", 1],
+ [5, 120, 48, True, "hardswish", 1],
+ [5, 144, 48, True, "hardswish", 1],
+ [5, 288, 96, True, "hardswish", (small_stride[3], 1)],
+ [5, 576, 96, True, "hardswish", 1],
+ [5, 576, 96, True, "hardswish", 1],
]
cls_ch_squeeze = 576
else:
- raise NotImplementedError("mode[" + model_name +
- "_model] is not implemented!")
+ raise NotImplementedError(
+ "mode[" + model_name + "_model] is not implemented!"
+ )
supported_scale = [0.35, 0.5, 0.75, 1.0, 1.25]
- assert scale in supported_scale, \
- "supported scales are {} but input scale is {}".format(supported_scale, scale)
+ assert (
+ scale in supported_scale
+ ), "supported scales are {} but input scale is {}".format(
+ supported_scale, scale
+ )
inplanes = 16
# conv1
@@ -98,11 +112,12 @@ def __init__(self,
padding=1,
groups=1,
if_act=True,
- act='hardswish')
+ act="hardswish",
+ )
i = 0
block_list = []
inplanes = make_divisible(inplanes * scale)
- for (k, exp, c, se, nl, s) in cfg:
+ for k, exp, c, se, nl, s in cfg:
se = se and not self.disable_se
block_list.append(
ResidualUnit(
@@ -112,7 +127,9 @@ def __init__(self,
kernel_size=k,
stride=s,
use_se=se,
- act=nl))
+ act=nl,
+ )
+ )
inplanes = make_divisible(scale * c)
i += 1
self.blocks = nn.Sequential(*block_list)
@@ -125,7 +142,8 @@ def __init__(self,
padding=0,
groups=1,
if_act=True,
- act='hardswish')
+ act="hardswish",
+ )
self.pool = nn.MaxPool2D(kernel_size=2, stride=2, padding=0)
self.out_channels = make_divisible(scale * cls_ch_squeeze)
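
Note: the reshaped asserts above are the same formatter applied to `assert cond, msg`: the condition is parenthesized so it can wrap, and the adjacent string literals are left as implicit concatenation, so the runtime message is unchanged. A self-contained sketch:

```python
large_stride = [1, 2, 2, 2]

# Parenthesized condition, exactly as in the diff; behavior is identical to
# the backslash-continued assert it replaced.
assert isinstance(
    large_stride, list
), "large_stride type must " "be list but got {}".format(type(large_stride))

# Adjacent string literals fuse at compile time:
msg = "large_stride type must " "be list but got {}"
assert msg == "large_stride type must be list but got {}"
```
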
diff --git a/ppocr/modeling/backbones/rec_mv1_enhance.py b/ppocr/modeling/backbones/rec_mv1_enhance.py
index 2d4efe7209..13c9de8698 100644
--- a/ppocr/modeling/backbones/rec_mv1_enhance.py
+++ b/ppocr/modeling/backbones/rec_mv1_enhance.py
@@ -32,15 +32,17 @@
class ConvBNLayer(nn.Layer):
- def __init__(self,
- num_channels,
- filter_size,
- num_filters,
- stride,
- padding,
- channels=None,
- num_groups=1,
- act='hard_swish'):
+ def __init__(
+ self,
+ num_channels,
+ filter_size,
+ num_filters,
+ stride,
+ padding,
+ channels=None,
+ num_groups=1,
+ act="hard_swish",
+ ):
super(ConvBNLayer, self).__init__()
self._conv = Conv2D(
@@ -51,13 +53,15 @@ def __init__(self,
padding=padding,
groups=num_groups,
weight_attr=ParamAttr(initializer=KaimingNormal()),
- bias_attr=False)
+ bias_attr=False,
+ )
self._batch_norm = BatchNorm(
num_filters,
act=act,
param_attr=ParamAttr(regularizer=L2Decay(0.0)),
- bias_attr=ParamAttr(regularizer=L2Decay(0.0)))
+ bias_attr=ParamAttr(regularizer=L2Decay(0.0)),
+ )
def forward(self, inputs):
y = self._conv(inputs)
@@ -66,16 +70,18 @@ def forward(self, inputs):
class DepthwiseSeparable(nn.Layer):
- def __init__(self,
- num_channels,
- num_filters1,
- num_filters2,
- num_groups,
- stride,
- scale,
- dw_size=3,
- padding=1,
- use_se=False):
+ def __init__(
+ self,
+ num_channels,
+ num_filters1,
+ num_filters2,
+ num_groups,
+ stride,
+ scale,
+ dw_size=3,
+ padding=1,
+ use_se=False,
+ ):
super(DepthwiseSeparable, self).__init__()
self.use_se = use_se
self._depthwise_conv = ConvBNLayer(
@@ -84,7 +90,8 @@ def __init__(self,
filter_size=dw_size,
stride=stride,
padding=padding,
- num_groups=int(num_groups * scale))
+ num_groups=int(num_groups * scale),
+ )
if use_se:
self._se = SEModule(int(num_filters1 * scale))
self._pointwise_conv = ConvBNLayer(
@@ -92,7 +99,8 @@ def __init__(self,
filter_size=1,
num_filters=int(num_filters2 * scale),
stride=1,
- padding=0)
+ padding=0,
+ )
def forward(self, inputs):
y = self._depthwise_conv(inputs)
@@ -103,13 +111,15 @@ def forward(self, inputs):
class MobileNetV1Enhance(nn.Layer):
- def __init__(self,
- in_channels=3,
- scale=0.5,
- last_conv_stride=1,
- last_pool_type='max',
- last_pool_kernel_size=[3, 2],
- **kwargs):
+ def __init__(
+ self,
+ in_channels=3,
+ scale=0.5,
+ last_conv_stride=1,
+ last_pool_type="max",
+ last_pool_kernel_size=[3, 2],
+ **kwargs
+ ):
super().__init__()
self.scale = scale
self.block_list = []
@@ -120,7 +130,8 @@ def __init__(self,
channels=3,
num_filters=int(32 * scale),
stride=2,
- padding=1)
+ padding=1,
+ )
conv2_1 = DepthwiseSeparable(
num_channels=int(32 * scale),
@@ -128,7 +139,8 @@ def __init__(self,
num_filters2=64,
num_groups=32,
stride=1,
- scale=scale)
+ scale=scale,
+ )
self.block_list.append(conv2_1)
conv2_2 = DepthwiseSeparable(
@@ -137,7 +149,8 @@ def __init__(self,
num_filters2=128,
num_groups=64,
stride=1,
- scale=scale)
+ scale=scale,
+ )
self.block_list.append(conv2_2)
conv3_1 = DepthwiseSeparable(
@@ -146,7 +159,8 @@ def __init__(self,
num_filters2=128,
num_groups=128,
stride=1,
- scale=scale)
+ scale=scale,
+ )
self.block_list.append(conv3_1)
conv3_2 = DepthwiseSeparable(
@@ -155,7 +169,8 @@ def __init__(self,
num_filters2=256,
num_groups=128,
stride=(2, 1),
- scale=scale)
+ scale=scale,
+ )
self.block_list.append(conv3_2)
conv4_1 = DepthwiseSeparable(
@@ -164,7 +179,8 @@ def __init__(self,
num_filters2=256,
num_groups=256,
stride=1,
- scale=scale)
+ scale=scale,
+ )
self.block_list.append(conv4_1)
conv4_2 = DepthwiseSeparable(
@@ -173,7 +189,8 @@ def __init__(self,
num_filters2=512,
num_groups=256,
stride=(2, 1),
- scale=scale)
+ scale=scale,
+ )
self.block_list.append(conv4_2)
for _ in range(5):
@@ -186,7 +203,8 @@ def __init__(self,
dw_size=5,
padding=2,
scale=scale,
- use_se=False)
+ use_se=False,
+ )
self.block_list.append(conv5)
conv5_6 = DepthwiseSeparable(
@@ -198,7 +216,8 @@ def __init__(self,
dw_size=5,
padding=2,
scale=scale,
- use_se=True)
+ use_se=True,
+ )
self.block_list.append(conv5_6)
conv6 = DepthwiseSeparable(
@@ -210,15 +229,17 @@ def __init__(self,
dw_size=5,
padding=2,
use_se=True,
- scale=scale)
+ scale=scale,
+ )
self.block_list.append(conv6)
self.block_list = nn.Sequential(*self.block_list)
- if last_pool_type == 'avg':
+ if last_pool_type == "avg":
self.pool = nn.AvgPool2D(
kernel_size=last_pool_kernel_size,
stride=last_pool_kernel_size,
- padding=0)
+ padding=0,
+ )
else:
self.pool = nn.MaxPool2D(kernel_size=2, stride=2, padding=0)
self.out_channels = int(1024 * scale)
@@ -241,7 +262,8 @@ def __init__(self, channel, reduction=4):
stride=1,
padding=0,
weight_attr=ParamAttr(),
- bias_attr=ParamAttr())
+ bias_attr=ParamAttr(),
+ )
self.conv2 = Conv2D(
in_channels=channel // reduction,
out_channels=channel,
@@ -249,7 +271,8 @@ def __init__(self, channel, reduction=4):
stride=1,
padding=0,
weight_attr=ParamAttr(),
- bias_attr=ParamAttr())
+ bias_attr=ParamAttr(),
+ )
def forward(self, inputs):
outputs = self.avg_pool(inputs)
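
Note: a useful sanity check for a pure-formatting diff like this one is AST equivalence: if the old and new sources parse to the same tree, no behavior changed. Black runs an equivalent safety check by default; a hand-rolled sketch using only the stdlib:

```python
import ast

old = "x = Conv2D(3, 32, kernel_size=3, stride=2, bias_attr=False)"
new = (
    "x = Conv2D(\n"
    "    3,\n"
    "    32,\n"
    "    kernel_size=3,\n"
    "    stride=2,\n"
    "    bias_attr=False,\n"
    ")"
)
# ast.dump omits line/column info by default, so layout-only changes compare
# equal; the trailing comma does not appear in the tree at all.
assert ast.dump(ast.parse(old)) == ast.dump(ast.parse(new))
```
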
diff --git a/ppocr/modeling/backbones/rec_nrtr_mtb.py b/ppocr/modeling/backbones/rec_nrtr_mtb.py
index 9315a9456d..c486838492 100644
--- a/ppocr/modeling/backbones/rec_nrtr_mtb.py
+++ b/ppocr/modeling/backbones/rec_nrtr_mtb.py
@@ -25,17 +25,19 @@ def __init__(self, cnn_num, in_channels):
if self.cnn_num == 2:
for i in range(self.cnn_num):
self.block.add_sublayer(
- 'conv_{}'.format(i),
+ "conv_{}".format(i),
nn.Conv2D(
- in_channels=in_channels
- if i == 0 else 32 * (2**(i - 1)),
+ in_channels=in_channels if i == 0 else 32 * (2 ** (i - 1)),
out_channels=32 * (2**i),
kernel_size=3,
stride=2,
- padding=1))
- self.block.add_sublayer('relu_{}'.format(i), nn.ReLU())
- self.block.add_sublayer('bn_{}'.format(i),
- nn.BatchNorm2D(32 * (2**i)))
+ padding=1,
+ ),
+ )
+ self.block.add_sublayer("relu_{}".format(i), nn.ReLU())
+ self.block.add_sublayer(
+ "bn_{}".format(i), nn.BatchNorm2D(32 * (2**i))
+ )
def forward(self, images):
x = self.block(images)
@@ -43,6 +45,5 @@ def forward(self, images):
# (b, w, h, c)
x = paddle.transpose(x, [0, 3, 2, 1])
x_shape = x.shape
- x = paddle.reshape(
- x, [x_shape[0], x_shape[1], x_shape[2] * x_shape[3]])
+ x = paddle.reshape(x, [x_shape[0], x_shape[1], x_shape[2] * x_shape[3]])
return x
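
Note: in the hunk above, `32 * (2**i)` keeps a hugged `**` while `32 * (2 **(i - 1))` gains spaces. That is the formatter's power-operator rule: `**` stays tight only when both operands are simple (names, literals, attribute access). Sketch, again assuming `black` is installed:

```python
import black

src = "y = 32 * (2**i)\nz = 32 * (2 **(i - 1))\n"
print(black.format_str(src, mode=black.Mode()))
# y = 32 * (2**i)          # simple operands: stays hugged
# z = 32 * (2 ** (i - 1))  # complex right operand: spaced
```
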
diff --git a/ppocr/modeling/backbones/rec_resnet_31.py b/ppocr/modeling/backbones/rec_resnet_31.py
index 46dc374008..2db2549eea 100644
--- a/ppocr/modeling/backbones/rec_resnet_31.py
+++ b/ppocr/modeling/backbones/rec_resnet_31.py
@@ -29,6 +29,7 @@
__all__ = ["ResNet31"]
+
def conv3x3(in_channel, out_channel, stride=1, conv_weight_attr=None):
return nn.Conv2D(
in_channel,
@@ -37,20 +38,29 @@ def conv3x3(in_channel, out_channel, stride=1, conv_weight_attr=None):
stride=stride,
padding=1,
weight_attr=conv_weight_attr,
- bias_attr=False)
+ bias_attr=False,
+ )
class BasicBlock(nn.Layer):
expansion = 1
- def __init__(self, in_channels, channels, stride=1, downsample=False, conv_weight_attr=None, bn_weight_attr=None):
+ def __init__(
+ self,
+ in_channels,
+ channels,
+ stride=1,
+ downsample=False,
+ conv_weight_attr=None,
+ bn_weight_attr=None,
+ ):
super().__init__()
- self.conv1 = conv3x3(in_channels, channels, stride,
- conv_weight_attr=conv_weight_attr)
+ self.conv1 = conv3x3(
+ in_channels, channels, stride, conv_weight_attr=conv_weight_attr
+ )
self.bn1 = nn.BatchNorm2D(channels, weight_attr=bn_weight_attr)
self.relu = nn.ReLU()
- self.conv2 = conv3x3(channels, channels,
- conv_weight_attr=conv_weight_attr)
+ self.conv2 = conv3x3(channels, channels, conv_weight_attr=conv_weight_attr)
self.bn2 = nn.BatchNorm2D(channels, weight_attr=bn_weight_attr)
self.downsample = downsample
if downsample:
@@ -61,8 +71,10 @@ def __init__(self, in_channels, channels, stride=1, downsample=False, conv_weigh
1,
stride,
weight_attr=conv_weight_attr,
- bias_attr=False),
- nn.BatchNorm2D(channels * self.expansion, weight_attr=bn_weight_attr))
+ bias_attr=False,
+ ),
+ nn.BatchNorm2D(channels * self.expansion, weight_attr=bn_weight_attr),
+ )
else:
self.downsample = nn.Sequential()
self.stride = stride
@@ -87,7 +99,7 @@ def forward(self, x):
class ResNet31(nn.Layer):
- '''
+ """
Args:
in_channels (int): Number of channels of input image tensor.
layers (list[int]): List of BasicBlock number for each stage.
@@ -95,15 +107,17 @@ class ResNet31(nn.Layer):
out_indices (None | Sequence[int]): Indices of output stages.
last_stage_pool (bool): If True, add `MaxPool2d` layer to last stage.
init_type (None | str): the config to control the initialization.
- '''
-
- def __init__(self,
- in_channels=3,
- layers=[1, 2, 5, 3],
- channels=[64, 128, 256, 256, 512, 512, 512],
- out_indices=None,
- last_stage_pool=False,
- init_type=None):
+ """
+
+ def __init__(
+ self,
+ in_channels=3,
+ layers=[1, 2, 5, 3],
+ channels=[64, 128, 256, 256, 512, 512, 512],
+ out_indices=None,
+ last_stage_pool=False,
+ init_type=None,
+ ):
super(ResNet31, self).__init__()
assert isinstance(in_channels, int)
assert isinstance(last_stage_pool, bool)
@@ -113,52 +127,99 @@ def __init__(self,
conv_weight_attr = None
bn_weight_attr = None
-
+
if init_type is not None:
- support_dict = ['KaimingNormal']
+ support_dict = ["KaimingNormal"]
assert init_type in support_dict, Exception(
- "resnet31 only support {}".format(support_dict))
+ "resnet31 only support {}".format(support_dict)
+ )
conv_weight_attr = nn.initializer.KaimingNormal()
- bn_weight_attr = ParamAttr(initializer=nn.initializer.Uniform(), learning_rate=1)
+ bn_weight_attr = ParamAttr(
+ initializer=nn.initializer.Uniform(), learning_rate=1
+ )
# conv 1 (Conv Conv)
self.conv1_1 = nn.Conv2D(
- in_channels, channels[0], kernel_size=3, stride=1, padding=1, weight_attr=conv_weight_attr)
+ in_channels,
+ channels[0],
+ kernel_size=3,
+ stride=1,
+ padding=1,
+ weight_attr=conv_weight_attr,
+ )
self.bn1_1 = nn.BatchNorm2D(channels[0], weight_attr=bn_weight_attr)
self.relu1_1 = nn.ReLU()
self.conv1_2 = nn.Conv2D(
- channels[0], channels[1], kernel_size=3, stride=1, padding=1, weight_attr=conv_weight_attr)
+ channels[0],
+ channels[1],
+ kernel_size=3,
+ stride=1,
+ padding=1,
+ weight_attr=conv_weight_attr,
+ )
self.bn1_2 = nn.BatchNorm2D(channels[1], weight_attr=bn_weight_attr)
self.relu1_2 = nn.ReLU()
# conv 2 (Max-pooling, Residual block, Conv)
- self.pool2 = nn.MaxPool2D(
- kernel_size=2, stride=2, padding=0, ceil_mode=True)
- self.block2 = self._make_layer(channels[1], channels[2], layers[0],
- conv_weight_attr=conv_weight_attr, bn_weight_attr=bn_weight_attr)
+ self.pool2 = nn.MaxPool2D(kernel_size=2, stride=2, padding=0, ceil_mode=True)
+ self.block2 = self._make_layer(
+ channels[1],
+ channels[2],
+ layers[0],
+ conv_weight_attr=conv_weight_attr,
+ bn_weight_attr=bn_weight_attr,
+ )
self.conv2 = nn.Conv2D(
- channels[2], channels[2], kernel_size=3, stride=1, padding=1, weight_attr=conv_weight_attr)
+ channels[2],
+ channels[2],
+ kernel_size=3,
+ stride=1,
+ padding=1,
+ weight_attr=conv_weight_attr,
+ )
self.bn2 = nn.BatchNorm2D(channels[2], weight_attr=bn_weight_attr)
self.relu2 = nn.ReLU()
# conv 3 (Max-pooling, Residual block, Conv)
- self.pool3 = nn.MaxPool2D(
- kernel_size=2, stride=2, padding=0, ceil_mode=True)
- self.block3 = self._make_layer(channels[2], channels[3], layers[1],
- conv_weight_attr=conv_weight_attr, bn_weight_attr=bn_weight_attr)
+ self.pool3 = nn.MaxPool2D(kernel_size=2, stride=2, padding=0, ceil_mode=True)
+ self.block3 = self._make_layer(
+ channels[2],
+ channels[3],
+ layers[1],
+ conv_weight_attr=conv_weight_attr,
+ bn_weight_attr=bn_weight_attr,
+ )
self.conv3 = nn.Conv2D(
- channels[3], channels[3], kernel_size=3, stride=1, padding=1, weight_attr=conv_weight_attr)
+ channels[3],
+ channels[3],
+ kernel_size=3,
+ stride=1,
+ padding=1,
+ weight_attr=conv_weight_attr,
+ )
self.bn3 = nn.BatchNorm2D(channels[3], weight_attr=bn_weight_attr)
self.relu3 = nn.ReLU()
# conv 4 (Max-pooling, Residual block, Conv)
self.pool4 = nn.MaxPool2D(
- kernel_size=(2, 1), stride=(2, 1), padding=0, ceil_mode=True)
- self.block4 = self._make_layer(channels[3], channels[4], layers[2],
- conv_weight_attr=conv_weight_attr, bn_weight_attr=bn_weight_attr)
+ kernel_size=(2, 1), stride=(2, 1), padding=0, ceil_mode=True
+ )
+ self.block4 = self._make_layer(
+ channels[3],
+ channels[4],
+ layers[2],
+ conv_weight_attr=conv_weight_attr,
+ bn_weight_attr=bn_weight_attr,
+ )
self.conv4 = nn.Conv2D(
- channels[4], channels[4], kernel_size=3, stride=1, padding=1, weight_attr=conv_weight_attr)
+ channels[4],
+ channels[4],
+ kernel_size=3,
+ stride=1,
+ padding=1,
+ weight_attr=conv_weight_attr,
+ )
self.bn4 = nn.BatchNorm2D(channels[4], weight_attr=bn_weight_attr)
self.relu4 = nn.ReLU()
@@ -166,17 +227,36 @@ def __init__(self,
self.pool5 = None
if self.last_stage_pool:
self.pool5 = nn.MaxPool2D(
- kernel_size=2, stride=2, padding=0, ceil_mode=True)
- self.block5 = self._make_layer(channels[4], channels[5], layers[3],
- conv_weight_attr=conv_weight_attr, bn_weight_attr=bn_weight_attr)
+ kernel_size=2, stride=2, padding=0, ceil_mode=True
+ )
+ self.block5 = self._make_layer(
+ channels[4],
+ channels[5],
+ layers[3],
+ conv_weight_attr=conv_weight_attr,
+ bn_weight_attr=bn_weight_attr,
+ )
self.conv5 = nn.Conv2D(
- channels[5], channels[5], kernel_size=3, stride=1, padding=1, weight_attr=conv_weight_attr)
+ channels[5],
+ channels[5],
+ kernel_size=3,
+ stride=1,
+ padding=1,
+ weight_attr=conv_weight_attr,
+ )
self.bn5 = nn.BatchNorm2D(channels[5], weight_attr=bn_weight_attr)
self.relu5 = nn.ReLU()
self.out_channels = channels[-1]
- def _make_layer(self, input_channels, output_channels, blocks, conv_weight_attr=None, bn_weight_attr=None):
+ def _make_layer(
+ self,
+ input_channels,
+ output_channels,
+ blocks,
+ conv_weight_attr=None,
+ bn_weight_attr=None,
+ ):
layers = []
for _ in range(blocks):
downsample = None
@@ -188,13 +268,20 @@ def _make_layer(self, input_channels, output_channels, blocks, conv_weight_attr=
kernel_size=1,
stride=1,
weight_attr=conv_weight_attr,
- bias_attr=False),
- nn.BatchNorm2D(output_channels, weight_attr=bn_weight_attr))
+ bias_attr=False,
+ ),
+ nn.BatchNorm2D(output_channels, weight_attr=bn_weight_attr),
+ )
layers.append(
BasicBlock(
- input_channels, output_channels, downsample=downsample,
- conv_weight_attr=conv_weight_attr, bn_weight_attr=bn_weight_attr))
+ input_channels,
+ output_channels,
+ downsample=downsample,
+ conv_weight_attr=conv_weight_attr,
+ bn_weight_attr=bn_weight_attr,
+ )
+ )
input_channels = output_channels
return nn.Sequential(*layers)
@@ -210,11 +297,11 @@ def forward(self, x):
outs = []
for i in range(4):
layer_index = i + 2
- pool_layer = getattr(self, f'pool{layer_index}')
- block_layer = getattr(self, f'block{layer_index}')
- conv_layer = getattr(self, f'conv{layer_index}')
- bn_layer = getattr(self, f'bn{layer_index}')
- relu_layer = getattr(self, f'relu{layer_index}')
+ pool_layer = getattr(self, f"pool{layer_index}")
+ block_layer = getattr(self, f"block{layer_index}")
+ conv_layer = getattr(self, f"conv{layer_index}")
+ bn_layer = getattr(self, f"bn{layer_index}")
+ relu_layer = getattr(self, f"relu{layer_index}")
if pool_layer is not None:
x = pool_layer(x)
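
Note: the `f'pool{...}'` to `f"pool{...}"` rewrites above are quote normalization: double quotes are preferred unless switching would introduce new escapes. Sketch:

```python
import black

src = "s = 'pool'\nt = 'she said \"hi\"'\n"
print(black.format_str(src, mode=black.Mode()))
# s = "pool"           # plain single quotes become double
# t = 'she said "hi"'  # kept single: switching would require escaping
```
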
diff --git a/ppocr/modeling/backbones/rec_resnet_32.py b/ppocr/modeling/backbones/rec_resnet_32.py
index cbd19251a3..51059ef12d 100644
--- a/ppocr/modeling/backbones/rec_resnet_32.py
+++ b/ppocr/modeling/backbones/rec_resnet_32.py
@@ -26,6 +26,7 @@
conv_weight_attr = nn.initializer.KaimingNormal()
+
class ResNet32(nn.Layer):
"""
Feature Extractor is proposed in FAN Ref [1]
@@ -55,13 +56,15 @@ def forward(self, inputs):
"""
return self.ConvNet(inputs)
+
class BasicBlock(nn.Layer):
"""Res-net Basic Block"""
+
expansion = 1
- def __init__(self, inplanes, planes,
- stride=1, downsample=None,
- norm_type='BN', **kwargs):
+ def __init__(
+ self, inplanes, planes, stride=1, downsample=None, norm_type="BN", **kwargs
+ ):
"""
Args:
inplanes (int): input channel
@@ -92,10 +95,15 @@ def _conv3x3(self, in_planes, out_planes, stride=1):
"""
- return nn.Conv2D(in_planes, out_planes,
- kernel_size=3, stride=stride,
- padding=1, weight_attr=conv_weight_attr,
- bias_attr=False)
+ return nn.Conv2D(
+ in_planes,
+ out_planes,
+ kernel_size=3,
+ stride=stride,
+ padding=1,
+ weight_attr=conv_weight_attr,
+ bias_attr=False,
+ )
def forward(self, x):
residual = x
@@ -114,10 +122,11 @@ def forward(self, x):
return out
+
class ResNet(nn.Layer):
"""Res-Net network structure"""
- def __init__(self, input_channel,
- output_channel, block, layers):
+
+ def __init__(self, input_channel, output_channel, block, layers):
"""
Args:
@@ -128,78 +137,101 @@ def __init__(self, input_channel,
"""
super(ResNet, self).__init__()
- self.output_channel_block = [int(output_channel / 4),
- int(output_channel / 2),
- output_channel,
- output_channel]
+ self.output_channel_block = [
+ int(output_channel / 4),
+ int(output_channel / 2),
+ output_channel,
+ output_channel,
+ ]
self.inplanes = int(output_channel / 8)
- self.conv0_1 = nn.Conv2D(input_channel, int(output_channel / 16),
- kernel_size=3, stride=1,
- padding=1,
- weight_attr=conv_weight_attr,
- bias_attr=False)
+ self.conv0_1 = nn.Conv2D(
+ input_channel,
+ int(output_channel / 16),
+ kernel_size=3,
+ stride=1,
+ padding=1,
+ weight_attr=conv_weight_attr,
+ bias_attr=False,
+ )
self.bn0_1 = nn.BatchNorm2D(int(output_channel / 16))
- self.conv0_2 = nn.Conv2D(int(output_channel / 16), self.inplanes,
- kernel_size=3, stride=1,
- padding=1,
- weight_attr=conv_weight_attr,
- bias_attr=False)
+ self.conv0_2 = nn.Conv2D(
+ int(output_channel / 16),
+ self.inplanes,
+ kernel_size=3,
+ stride=1,
+ padding=1,
+ weight_attr=conv_weight_attr,
+ bias_attr=False,
+ )
self.bn0_2 = nn.BatchNorm2D(self.inplanes)
self.relu = nn.ReLU()
self.maxpool1 = nn.MaxPool2D(kernel_size=2, stride=2, padding=0)
- self.layer1 = self._make_layer(block,
- self.output_channel_block[0],
- layers[0])
- self.conv1 = nn.Conv2D(self.output_channel_block[0],
- self.output_channel_block[0],
- kernel_size=3, stride=1,
- padding=1,
- weight_attr=conv_weight_attr,
- bias_attr=False)
+ self.layer1 = self._make_layer(block, self.output_channel_block[0], layers[0])
+ self.conv1 = nn.Conv2D(
+ self.output_channel_block[0],
+ self.output_channel_block[0],
+ kernel_size=3,
+ stride=1,
+ padding=1,
+ weight_attr=conv_weight_attr,
+ bias_attr=False,
+ )
self.bn1 = nn.BatchNorm2D(self.output_channel_block[0])
self.maxpool2 = nn.MaxPool2D(kernel_size=2, stride=2, padding=0)
- self.layer2 = self._make_layer(block,
- self.output_channel_block[1],
- layers[1], stride=1)
- self.conv2 = nn.Conv2D(self.output_channel_block[1],
- self.output_channel_block[1],
- kernel_size=3, stride=1,
- padding=1,
- weight_attr=conv_weight_attr,
- bias_attr=False,)
+ self.layer2 = self._make_layer(
+ block, self.output_channel_block[1], layers[1], stride=1
+ )
+ self.conv2 = nn.Conv2D(
+ self.output_channel_block[1],
+ self.output_channel_block[1],
+ kernel_size=3,
+ stride=1,
+ padding=1,
+ weight_attr=conv_weight_attr,
+ bias_attr=False,
+ )
self.bn2 = nn.BatchNorm2D(self.output_channel_block[1])
- self.maxpool3 = nn.MaxPool2D(kernel_size=2,
- stride=(2, 1),
- padding=(0, 1))
- self.layer3 = self._make_layer(block, self.output_channel_block[2],
- layers[2], stride=1)
- self.conv3 = nn.Conv2D(self.output_channel_block[2],
- self.output_channel_block[2],
- kernel_size=3, stride=1,
- padding=1,
- weight_attr=conv_weight_attr,
- bias_attr=False)
+ self.maxpool3 = nn.MaxPool2D(kernel_size=2, stride=(2, 1), padding=(0, 1))
+ self.layer3 = self._make_layer(
+ block, self.output_channel_block[2], layers[2], stride=1
+ )
+ self.conv3 = nn.Conv2D(
+ self.output_channel_block[2],
+ self.output_channel_block[2],
+ kernel_size=3,
+ stride=1,
+ padding=1,
+ weight_attr=conv_weight_attr,
+ bias_attr=False,
+ )
self.bn3 = nn.BatchNorm2D(self.output_channel_block[2])
- self.layer4 = self._make_layer(block, self.output_channel_block[3],
- layers[3], stride=1)
- self.conv4_1 = nn.Conv2D(self.output_channel_block[3],
- self.output_channel_block[3],
- kernel_size=2, stride=(2, 1),
- padding=(0, 1),
- weight_attr=conv_weight_attr,
- bias_attr=False)
+ self.layer4 = self._make_layer(
+ block, self.output_channel_block[3], layers[3], stride=1
+ )
+ self.conv4_1 = nn.Conv2D(
+ self.output_channel_block[3],
+ self.output_channel_block[3],
+ kernel_size=2,
+ stride=(2, 1),
+ padding=(0, 1),
+ weight_attr=conv_weight_attr,
+ bias_attr=False,
+ )
self.bn4_1 = nn.BatchNorm2D(self.output_channel_block[3])
- self.conv4_2 = nn.Conv2D(self.output_channel_block[3],
- self.output_channel_block[3],
- kernel_size=2, stride=1,
- padding=0,
- weight_attr=conv_weight_attr,
- bias_attr=False)
+ self.conv4_2 = nn.Conv2D(
+ self.output_channel_block[3],
+ self.output_channel_block[3],
+ kernel_size=2,
+ stride=1,
+ padding=0,
+ weight_attr=conv_weight_attr,
+ bias_attr=False,
+ )
self.bn4_2 = nn.BatchNorm2D(self.output_channel_block[3])
def _make_layer(self, block, planes, blocks, stride=1):
@@ -218,10 +250,14 @@ def _make_layer(self, block, planes, blocks, stride=1):
downsample = None
if stride != 1 or self.inplanes != planes * block.expansion:
downsample = nn.Sequential(
- nn.Conv2D(self.inplanes, planes * block.expansion,
- kernel_size=1, stride=stride,
- weight_attr=conv_weight_attr,
- bias_attr=False),
+ nn.Conv2D(
+ self.inplanes,
+ planes * block.expansion,
+ kernel_size=1,
+ stride=stride,
+ weight_attr=conv_weight_attr,
+ bias_attr=False,
+ ),
nn.BatchNorm2D(planes * block.expansion),
)
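
Note: the lone `+` lines in this file's hunks add vertical whitespace only: two blank lines before each top-level class or `def`, and a blank line between a class docstring and the first statement that follows it (the latter matches Black's 2024 stable style, as the `BasicBlock` hunk shows). Sketch, assuming a recent `black`:

```python
import black

src = "class BasicBlock:\n    '''Res-net Basic Block'''\n    expansion = 1\n"
print(black.format_str(src, mode=black.Mode()))
# class BasicBlock:
#     """Res-net Basic Block"""
#
#     expansion = 1
```
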
diff --git a/ppocr/modeling/backbones/rec_resnet_45.py b/ppocr/modeling/backbones/rec_resnet_45.py
index 083eb7f488..634c9a248c 100644
--- a/ppocr/modeling/backbones/rec_resnet_45.py
+++ b/ppocr/modeling/backbones/rec_resnet_45.py
@@ -38,7 +38,8 @@ def conv1x1(in_planes, out_planes, stride=1):
kernel_size=1,
stride=1,
weight_attr=ParamAttr(initializer=KaimingNormal()),
- bias_attr=False)
+ bias_attr=False,
+ )
def conv3x3(in_channel, out_channel, stride=1):
@@ -49,7 +50,8 @@ def conv3x3(in_channel, out_channel, stride=1):
stride=stride,
padding=1,
weight_attr=ParamAttr(initializer=KaimingNormal()),
- bias_attr=False)
+ bias_attr=False,
+ )
class BasicBlock(nn.Layer):
@@ -84,11 +86,13 @@ def forward(self, x):
class ResNet45(nn.Layer):
- def __init__(self,
- in_channels=3,
- block=BasicBlock,
- layers=[3, 4, 6, 6, 3],
- strides=[2, 1, 2, 1, 1]):
+ def __init__(
+ self,
+ in_channels=3,
+ block=BasicBlock,
+ layers=[3, 4, 6, 6, 3],
+ strides=[2, 1, 2, 1, 1],
+ ):
self.inplanes = 32
super(ResNet45, self).__init__()
self.conv1 = nn.Conv2D(
@@ -98,7 +102,8 @@ def __init__(self,
stride=1,
padding=1,
weight_attr=ParamAttr(initializer=KaimingNormal()),
- bias_attr=False)
+ bias_attr=False,
+ )
self.bn1 = nn.BatchNorm2D(32)
self.relu = nn.ReLU()
@@ -120,8 +125,10 @@ def _make_layer(self, block, planes, blocks, stride=1):
kernel_size=1,
stride=stride,
weight_attr=ParamAttr(initializer=KaimingNormal()),
- bias_attr=False),
- nn.BatchNorm2D(planes * block.expansion), )
+ bias_attr=False,
+ ),
+ nn.BatchNorm2D(planes * block.expansion),
+ )
layers = []
layers.append(block(self.inplanes, planes, stride, downsample))
@@ -132,7 +139,6 @@ def _make_layer(self, block, planes, blocks, stride=1):
return nn.Sequential(*layers)
def forward(self, x):
-
x = self.conv1(x)
x = self.bn1(x)
x = self.relu(x)
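
Note: a hypothetical smoke test for the reformatted backbone. It assumes `paddlepaddle` is installed and the module is importable under the path in the diff header; the expected shape follows from the stride schedule `[2, 1, 2, 1, 1]` (overall 4x downsampling) but is worth verifying locally.

```python
import paddle
from ppocr.modeling.backbones.rec_resnet_45 import ResNet45

net = ResNet45(in_channels=3)
x = paddle.randn([1, 3, 32, 100])  # NCHW text-line input
y = net(x)
print(y.shape)  # expected [1, 512, 8, 25] given strides [2, 1, 2, 1, 1]
```
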
diff --git a/ppocr/modeling/backbones/rec_resnet_aster.py b/ppocr/modeling/backbones/rec_resnet_aster.py
index 782dc393ea..9b5a15ec67 100644
--- a/ppocr/modeling/backbones/rec_resnet_aster.py
+++ b/ppocr/modeling/backbones/rec_resnet_aster.py
@@ -25,18 +25,15 @@
def conv3x3(in_planes, out_planes, stride=1):
"""3x3 convolution with padding"""
return nn.Conv2D(
- in_planes,
- out_planes,
- kernel_size=3,
- stride=stride,
- padding=1,
- bias_attr=False)
+ in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias_attr=False
+ )
def conv1x1(in_planes, out_planes, stride=1):
"""1x1 convolution"""
return nn.Conv2D(
- in_planes, out_planes, kernel_size=1, stride=stride, bias_attr=False)
+ in_planes, out_planes, kernel_size=1, stride=stride, bias_attr=False
+ )
def get_sinusoid_encoding(n_position, feat_dim, wave_length=10000):
@@ -46,9 +43,7 @@ def get_sinusoid_encoding(n_position, feat_dim, wave_length=10000):
dim_range = paddle.arange(0, feat_dim)
dim_range = paddle.pow(wave_length, 2 * (dim_range // 2) / feat_dim)
# [n_position, feat_dim]
- angles = paddle.unsqueeze(
- positions, axis=1) / paddle.unsqueeze(
- dim_range, axis=0)
+ angles = paddle.unsqueeze(positions, axis=1) / paddle.unsqueeze(dim_range, axis=0)
angles = paddle.cast(angles, "float32")
angles[:, 0::2] = paddle.sin(angles[:, 0::2])
angles[:, 1::2] = paddle.cos(angles[:, 1::2])
@@ -96,9 +91,11 @@ def __init__(self, with_lstm=True, n_group=1, in_channels=3):
kernel_size=(3, 3),
stride=1,
padding=1,
- bias_attr=False),
+ bias_attr=False,
+ ),
nn.BatchNorm2D(32),
- nn.ReLU())
+ nn.ReLU(),
+ )
self.inplanes = 32
self.layer1 = self._make_layer(32, 3, [2, 2]) # [16, 50]
@@ -117,7 +114,8 @@ def _make_layer(self, planes, blocks, stride):
downsample = None
if stride != [1, 1] or self.inplanes != planes:
downsample = nn.Sequential(
- conv1x1(self.inplanes, planes, stride), nn.BatchNorm2D(planes))
+ conv1x1(self.inplanes, planes, stride), nn.BatchNorm2D(planes)
+ )
layers = []
layers.append(AsterBlock(self.inplanes, planes, stride, downsample))
@@ -140,4 +138,4 @@ def forward(self, x):
rnn_feat, _ = self.rnn(cnn_feat)
return rnn_feat
else:
- return cnn_feat
\ No newline at end of file
+ return cnn_feat
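
Note: the `\ No newline at end of file` marker is git's way of saying the old file ended without a trailing newline; the final hunk's only effect is to add one, which formatters require. A stdlib sketch to find any remaining offenders (the path is illustrative):

```python
import pathlib

# Report any Python file in the tree that still lacks a final newline.
for p in pathlib.Path("ppocr").rglob("*.py"):
    data = p.read_bytes()
    if data and not data.endswith(b"\n"):
        print("missing final newline:", p)
```
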
diff --git a/ppocr/modeling/backbones/rec_resnet_fpn.py b/ppocr/modeling/backbones/rec_resnet_fpn.py
index 79efd6e41e..d259f1d7ea 100644
--- a/ppocr/modeling/backbones/rec_resnet_fpn.py
+++ b/ppocr/modeling/backbones/rec_resnet_fpn.py
@@ -1,16 +1,16 @@
-#copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
+# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
-#Licensed under the Apache License, Version 2.0 (the "License");
-#you may not use this file except in compliance with the License.
-#You may obtain a copy of the License at
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
-#Unless required by applicable law or agreed to in writing, software
-#distributed under the License is distributed on an "AS IS" BASIS,
-#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-#See the License for the specific language governing permissions and
-#limitations under the License.
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
from __future__ import absolute_import
from __future__ import division
@@ -28,30 +28,15 @@ class ResNetFPN(nn.Layer):
def __init__(self, in_channels=1, layers=50, **kwargs):
super(ResNetFPN, self).__init__()
supported_layers = {
- 18: {
- 'depth': [2, 2, 2, 2],
- 'block_class': BasicBlock
- },
- 34: {
- 'depth': [3, 4, 6, 3],
- 'block_class': BasicBlock
- },
- 50: {
- 'depth': [3, 4, 6, 3],
- 'block_class': BottleneckBlock
- },
- 101: {
- 'depth': [3, 4, 23, 3],
- 'block_class': BottleneckBlock
- },
- 152: {
- 'depth': [3, 8, 36, 3],
- 'block_class': BottleneckBlock
- }
+ 18: {"depth": [2, 2, 2, 2], "block_class": BasicBlock},
+ 34: {"depth": [3, 4, 6, 3], "block_class": BasicBlock},
+ 50: {"depth": [3, 4, 6, 3], "block_class": BottleneckBlock},
+ 101: {"depth": [3, 4, 23, 3], "block_class": BottleneckBlock},
+ 152: {"depth": [3, 8, 36, 3], "block_class": BottleneckBlock},
}
stride_list = [(2, 2), (2, 2), (1, 1), (1, 1)]
num_filters = [64, 128, 256, 512]
- self.depth = supported_layers[layers]['depth']
+ self.depth = supported_layers[layers]["depth"]
self.F = []
self.conv = ConvBNLayer(
in_channels=in_channels,
@@ -59,7 +44,8 @@ def __init__(self, in_channels=1, layers=50, **kwargs):
kernel_size=7,
stride=2,
act="relu",
- name="conv1")
+ name="conv1",
+ )
self.block_list = []
in_ch = 64
if layers >= 50:
@@ -78,7 +64,9 @@ def __init__(self, in_channels=1, layers=50, **kwargs):
in_channels=in_ch,
out_channels=num_filters[block],
stride=stride_list[block] if i == 0 else 1,
- name=conv_name))
+ name=conv_name,
+ ),
+ )
in_ch = num_filters[block] * 4
self.block_list.append(block_list)
self.F.append(block_list)
@@ -97,7 +85,9 @@ def __init__(self, in_channels=1, layers=50, **kwargs):
out_channels=num_filters[block],
stride=stride_list[block] if i == 0 else 1,
is_first=block == i == 0,
- name=conv_name))
+ name=conv_name,
+ ),
+ )
in_ch = basic_block.out_channels
self.block_list.append(basic_block)
out_ch_list = [in_ch // 4, in_ch // 2, in_ch]
@@ -115,7 +105,10 @@ def __init__(self, in_channels=1, layers=50, **kwargs):
out_channels=out_ch_list[i],
kernel_size=1,
weight_attr=ParamAttr(trainable=True),
- bias_attr=ParamAttr(trainable=True))))
+ bias_attr=ParamAttr(trainable=True),
+ ),
+ )
+ )
self.base_block.append(
self.add_sublayer(
"F_{}_base_block_1".format(i),
@@ -125,7 +118,10 @@ def __init__(self, in_channels=1, layers=50, **kwargs):
kernel_size=3,
padding=1,
weight_attr=ParamAttr(trainable=True),
- bias_attr=ParamAttr(trainable=True))))
+ bias_attr=ParamAttr(trainable=True),
+ ),
+ )
+ )
self.base_block.append(
self.add_sublayer(
"F_{}_base_block_2".format(i),
@@ -133,7 +129,10 @@ def __init__(self, in_channels=1, layers=50, **kwargs):
num_channels=out_ch_list[i],
act="relu",
param_attr=ParamAttr(trainable=True),
- bias_attr=ParamAttr(trainable=True))))
+ bias_attr=ParamAttr(trainable=True),
+ ),
+ )
+ )
self.base_block.append(
self.add_sublayer(
"F_{}_base_block_3".format(i),
@@ -142,7 +141,10 @@ def __init__(self, in_channels=1, layers=50, **kwargs):
out_channels=512,
kernel_size=1,
bias_attr=ParamAttr(trainable=True),
- weight_attr=ParamAttr(trainable=True))))
+ weight_attr=ParamAttr(trainable=True),
+ ),
+ )
+ )
self.out_channels = 512
def __call__(self, x):
@@ -150,7 +152,7 @@ def __call__(self, x):
fpn_list = []
F = []
for i in range(len(self.depth)):
- fpn_list.append(np.sum(self.depth[:i + 1]))
+ fpn_list.append(np.sum(self.depth[: i + 1]))
for i, block in enumerate(self.block_list):
x = block(x)
@@ -175,14 +177,16 @@ def __call__(self, x):
class ConvBNLayer(nn.Layer):
- def __init__(self,
- in_channels,
- out_channels,
- kernel_size,
- stride=1,
- groups=1,
- act=None,
- name=None):
+ def __init__(
+ self,
+ in_channels,
+ out_channels,
+ kernel_size,
+ stride=1,
+ groups=1,
+ act=None,
+ name=None,
+ ):
super(ConvBNLayer, self).__init__()
self.conv = nn.Conv2D(
in_channels=in_channels,
@@ -192,8 +196,9 @@ def __init__(self,
stride=stride,
padding=(kernel_size - 1) // 2,
groups=groups,
- weight_attr=ParamAttr(name=name + '.conv2d.output.1.w_0'),
- bias_attr=False, )
+ weight_attr=ParamAttr(name=name + ".conv2d.output.1.w_0"),
+ bias_attr=False,
+ )
if name == "conv1":
bn_name = "bn_" + name
@@ -202,10 +207,11 @@ def __init__(self,
self.bn = nn.BatchNorm(
num_channels=out_channels,
act=act,
- param_attr=ParamAttr(name=name + '.output.1.w_0'),
- bias_attr=ParamAttr(name=name + '.output.1.b_0'),
+ param_attr=ParamAttr(name=name + ".output.1.w_0"),
+ bias_attr=ParamAttr(name=name + ".output.1.b_0"),
moving_mean_name=bn_name + "_mean",
- moving_variance_name=bn_name + "_variance")
+ moving_variance_name=bn_name + "_variance",
+ )
def __call__(self, x):
x = self.conv(x)
@@ -220,11 +226,9 @@ def __init__(self, in_channels, out_channels, stride, name, is_first=False):
if in_channels != out_channels or stride != 1 or is_first == True:
if stride == (1, 1):
- self.conv = ConvBNLayer(
- in_channels, out_channels, 1, 1, name=name)
+ self.conv = ConvBNLayer(in_channels, out_channels, 1, 1, name=name)
else: # stride==(2,2)
- self.conv = ConvBNLayer(
- in_channels, out_channels, 1, stride, name=name)
+ self.conv = ConvBNLayer(in_channels, out_channels, 1, stride, name=name)
else:
self.use_conv = False
@@ -241,29 +245,33 @@ def __init__(self, in_channels, out_channels, stride, name):
in_channels=in_channels,
out_channels=out_channels,
kernel_size=1,
- act='relu',
- name=name + "_branch2a")
+ act="relu",
+ name=name + "_branch2a",
+ )
self.conv1 = ConvBNLayer(
in_channels=out_channels,
out_channels=out_channels,
kernel_size=3,
stride=stride,
- act='relu',
- name=name + "_branch2b")
+ act="relu",
+ name=name + "_branch2b",
+ )
self.conv2 = ConvBNLayer(
in_channels=out_channels,
out_channels=out_channels * 4,
kernel_size=1,
act=None,
- name=name + "_branch2c")
+ name=name + "_branch2c",
+ )
self.short = ShortCut(
in_channels=in_channels,
out_channels=out_channels * 4,
stride=stride,
is_first=False,
- name=name + "_branch1")
+ name=name + "_branch1",
+ )
self.out_channels = out_channels * 4
def forward(self, x):
@@ -282,21 +290,24 @@ def __init__(self, in_channels, out_channels, stride, name, is_first):
in_channels=in_channels,
out_channels=out_channels,
kernel_size=3,
- act='relu',
+ act="relu",
stride=stride,
- name=name + "_branch2a")
+ name=name + "_branch2a",
+ )
self.conv1 = ConvBNLayer(
in_channels=out_channels,
out_channels=out_channels,
kernel_size=3,
act=None,
- name=name + "_branch2b")
+ name=name + "_branch2b",
+ )
self.short = ShortCut(
in_channels=in_channels,
out_channels=out_channels,
stride=stride,
is_first=is_first,
- name=name + "_branch1")
+ name=name + "_branch1",
+ )
self.out_channels = out_channels
def forward(self, x):
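
Note: `self.depth[:i + 1]` becoming `self.depth[: i + 1]` follows the PEP 8 slice rule the formatter enforces: with a complex bound such as `i + 1`, the colon is treated like a binary operator and gets symmetric spacing, while an omitted bound drops the space on its own side. Sketch:

```python
import black

src = "a = d[:i + 1]\nb = d[i:j]\n"
print(black.format_str(src, mode=black.Mode()))
# a = d[: i + 1]   # complex bound: spaced colon
# b = d[i:j]       # simple bounds: unchanged
```
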
diff --git a/ppocr/modeling/backbones/rec_resnet_rfl.py b/ppocr/modeling/backbones/rec_resnet_rfl.py
index fd317c6ea6..2b4e5e0b3c 100644
--- a/ppocr/modeling/backbones/rec_resnet_rfl.py
+++ b/ppocr/modeling/backbones/rec_resnet_rfl.py
@@ -26,21 +26,18 @@
from paddle.nn.initializer import TruncatedNormal, Constant, Normal, KaimingNormal
kaiming_init_ = KaimingNormal()
-zeros_ = Constant(value=0.)
-ones_ = Constant(value=1.)
+zeros_ = Constant(value=0.0)
+ones_ = Constant(value=1.0)
class BasicBlock(nn.Layer):
"""Res-net Basic Block"""
+
expansion = 1
- def __init__(self,
- inplanes,
- planes,
- stride=1,
- downsample=None,
- norm_type='BN',
- **kwargs):
+ def __init__(
+ self, inplanes, planes, stride=1, downsample=None, norm_type="BN", **kwargs
+ ):
"""
Args:
inplanes (int): input channel
@@ -60,14 +57,14 @@ def __init__(self,
self.stride = stride
def _conv3x3(self, in_planes, out_planes, stride=1):
-
return nn.Conv2D(
in_planes,
out_planes,
kernel_size=3,
stride=stride,
padding=1,
- bias_attr=False)
+ bias_attr=False,
+ )
def forward(self, x):
residual = x
@@ -88,11 +85,7 @@ def forward(self, x):
class ResNetRFL(nn.Layer):
- def __init__(self,
- in_channels,
- out_channels=512,
- use_cnt=True,
- use_seq=True):
+ def __init__(self, in_channels, out_channels=512, use_cnt=True, use_seq=True):
"""
Args:
@@ -106,8 +99,10 @@ def __init__(self,
self.out_channels = out_channels
self.out_channels_block = [
- int(self.out_channels / 4), int(self.out_channels / 2),
- self.out_channels, self.out_channels
+ int(self.out_channels / 4),
+ int(self.out_channels / 2),
+ self.out_channels,
+ self.out_channels,
]
block = BasicBlock
layers = [1, 2, 5, 3]
@@ -115,28 +110,31 @@ def __init__(self,
self.relu = nn.ReLU()
if self.use_seq:
- self.maxpool3 = nn.MaxPool2D(
- kernel_size=2, stride=(2, 1), padding=(0, 1))
+ self.maxpool3 = nn.MaxPool2D(kernel_size=2, stride=(2, 1), padding=(0, 1))
self.layer3 = self._make_layer(
- block, self.out_channels_block[2], layers[2], stride=1)
+ block, self.out_channels_block[2], layers[2], stride=1
+ )
self.conv3 = nn.Conv2D(
self.out_channels_block[2],
self.out_channels_block[2],
kernel_size=3,
stride=1,
padding=1,
- bias_attr=False)
+ bias_attr=False,
+ )
self.bn3 = nn.BatchNorm(self.out_channels_block[2])
self.layer4 = self._make_layer(
- block, self.out_channels_block[3], layers[3], stride=1)
+ block, self.out_channels_block[3], layers[3], stride=1
+ )
self.conv4_1 = nn.Conv2D(
self.out_channels_block[3],
self.out_channels_block[3],
kernel_size=2,
stride=(2, 1),
padding=(0, 1),
- bias_attr=False)
+ bias_attr=False,
+ )
self.bn4_1 = nn.BatchNorm(self.out_channels_block[3])
self.conv4_2 = nn.Conv2D(
self.out_channels_block[3],
@@ -144,33 +142,37 @@ def __init__(self,
kernel_size=2,
stride=1,
padding=0,
- bias_attr=False)
+ bias_attr=False,
+ )
self.bn4_2 = nn.BatchNorm(self.out_channels_block[3])
if self.use_cnt:
self.inplanes = int(self.out_channels // 2)
- self.v_maxpool3 = nn.MaxPool2D(
- kernel_size=2, stride=(2, 1), padding=(0, 1))
+ self.v_maxpool3 = nn.MaxPool2D(kernel_size=2, stride=(2, 1), padding=(0, 1))
self.v_layer3 = self._make_layer(
- block, self.out_channels_block[2], layers[2], stride=1)
+ block, self.out_channels_block[2], layers[2], stride=1
+ )
self.v_conv3 = nn.Conv2D(
self.out_channels_block[2],
self.out_channels_block[2],
kernel_size=3,
stride=1,
padding=1,
- bias_attr=False)
+ bias_attr=False,
+ )
self.v_bn3 = nn.BatchNorm(self.out_channels_block[2])
self.v_layer4 = self._make_layer(
- block, self.out_channels_block[3], layers[3], stride=1)
+ block, self.out_channels_block[3], layers[3], stride=1
+ )
self.v_conv4_1 = nn.Conv2D(
self.out_channels_block[3],
self.out_channels_block[3],
kernel_size=2,
stride=(2, 1),
padding=(0, 1),
- bias_attr=False)
+ bias_attr=False,
+ )
self.v_bn4_1 = nn.BatchNorm(self.out_channels_block[3])
self.v_conv4_2 = nn.Conv2D(
self.out_channels_block[3],
@@ -178,11 +180,11 @@ def __init__(self,
kernel_size=2,
stride=1,
padding=0,
- bias_attr=False)
+ bias_attr=False,
+ )
self.v_bn4_2 = nn.BatchNorm(self.out_channels_block[3])
def _make_layer(self, block, planes, blocks, stride=1):
-
downsample = None
if stride != 1 or self.inplanes != planes * block.expansion:
downsample = nn.Sequential(
@@ -191,8 +193,10 @@ def _make_layer(self, block, planes, blocks, stride=1):
planes * block.expansion,
kernel_size=1,
stride=stride,
- bias_attr=False),
- nn.BatchNorm(planes * block.expansion), )
+ bias_attr=False,
+ ),
+ nn.BatchNorm(planes * block.expansion),
+ )
layers = list()
layers.append(block(self.inplanes, planes, stride, downsample))
@@ -246,8 +250,10 @@ def __init__(self, in_channels, out_channels, block, layers):
super(ResNetBase, self).__init__()
self.out_channels_block = [
- int(out_channels / 4), int(out_channels / 2), out_channels,
- out_channels
+ int(out_channels / 4),
+ int(out_channels / 2),
+ out_channels,
+ out_channels,
]
self.inplanes = int(out_channels / 8)
@@ -257,7 +263,8 @@ def __init__(self, in_channels, out_channels, block, layers):
kernel_size=3,
stride=1,
padding=1,
- bias_attr=False)
+ bias_attr=False,
+ )
self.bn0_1 = nn.BatchNorm(int(out_channels / 16))
self.conv0_2 = nn.Conv2D(
int(out_channels / 16),
@@ -265,32 +272,35 @@ def __init__(self, in_channels, out_channels, block, layers):
kernel_size=3,
stride=1,
padding=1,
- bias_attr=False)
+ bias_attr=False,
+ )
self.bn0_2 = nn.BatchNorm(self.inplanes)
self.relu = nn.ReLU()
self.maxpool1 = nn.MaxPool2D(kernel_size=2, stride=2, padding=0)
- self.layer1 = self._make_layer(block, self.out_channels_block[0],
- layers[0])
+ self.layer1 = self._make_layer(block, self.out_channels_block[0], layers[0])
self.conv1 = nn.Conv2D(
self.out_channels_block[0],
self.out_channels_block[0],
kernel_size=3,
stride=1,
padding=1,
- bias_attr=False)
+ bias_attr=False,
+ )
self.bn1 = nn.BatchNorm(self.out_channels_block[0])
self.maxpool2 = nn.MaxPool2D(kernel_size=2, stride=2, padding=0)
self.layer2 = self._make_layer(
- block, self.out_channels_block[1], layers[1], stride=1)
+ block, self.out_channels_block[1], layers[1], stride=1
+ )
self.conv2 = nn.Conv2D(
self.out_channels_block[1],
self.out_channels_block[1],
kernel_size=3,
stride=1,
padding=1,
- bias_attr=False)
+ bias_attr=False,
+ )
self.bn2 = nn.BatchNorm(self.out_channels_block[1])
def _make_layer(self, block, planes, blocks, stride=1):
@@ -302,8 +312,10 @@ def _make_layer(self, block, planes, blocks, stride=1):
planes * block.expansion,
kernel_size=1,
stride=stride,
- bias_attr=False),
- nn.BatchNorm(planes * block.expansion), )
+ bias_attr=False,
+ ),
+ nn.BatchNorm(planes * block.expansion),
+ )
layers = list()
layers.append(block(self.inplanes, planes, stride, downsample))
@@ -337,12 +349,11 @@ def forward(self, x):
class RFLBase(nn.Layer):
- """ Reciprocal feature learning share backbone network"""
+ """Reciprocal feature learning share backbone network"""
def __init__(self, in_channels, out_channels=512):
super(RFLBase, self).__init__()
- self.ConvNet = ResNetBase(in_channels, out_channels, BasicBlock,
- [1, 2, 5, 3])
+ self.ConvNet = ResNetBase(in_channels, out_channels, BasicBlock, [1, 2, 5, 3])
def forward(self, inputs):
return self.ConvNet(inputs)
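
Note: `Constant(value=0.)` to `Constant(value=0.0)` above is numeric-literal normalization; the float is bit-for-bit identical, only the spelling changes (exponent letters are also lowercased). Sketch:

```python
import black

src = "x = 0.\ny = 10E3\n"
print(black.format_str(src, mode=black.Mode()))
# x = 0.0
# y = 10e3
assert 0. == 0.0  # same value, different spelling
```
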
diff --git a/ppocr/modeling/backbones/rec_resnet_vd.py b/ppocr/modeling/backbones/rec_resnet_vd.py
index 0187deb96f..3dad51fba9 100644
--- a/ppocr/modeling/backbones/rec_resnet_vd.py
+++ b/ppocr/modeling/backbones/rec_resnet_vd.py
@@ -26,20 +26,22 @@
class ConvBNLayer(nn.Layer):
def __init__(
- self,
- in_channels,
- out_channels,
- kernel_size,
- stride=1,
- groups=1,
- is_vd_mode=False,
- act=None,
- name=None, ):
+ self,
+ in_channels,
+ out_channels,
+ kernel_size,
+ stride=1,
+ groups=1,
+ is_vd_mode=False,
+ act=None,
+ name=None,
+ ):
super(ConvBNLayer, self).__init__()
self.is_vd_mode = is_vd_mode
self._pool2d_avg = nn.AvgPool2D(
- kernel_size=stride, stride=stride, padding=0, ceil_mode=True)
+ kernel_size=stride, stride=stride, padding=0, ceil_mode=True
+ )
self._conv = nn.Conv2D(
in_channels=in_channels,
out_channels=out_channels,
@@ -48,7 +50,8 @@ def __init__(
padding=(kernel_size - 1) // 2,
groups=groups,
weight_attr=ParamAttr(name=name + "_weights"),
- bias_attr=False)
+ bias_attr=False,
+ )
if name == "conv1":
bn_name = "bn_" + name
else:
@@ -56,10 +59,11 @@ def __init__(
self._batch_norm = nn.BatchNorm(
out_channels,
act=act,
- param_attr=ParamAttr(name=bn_name + '_scale'),
- bias_attr=ParamAttr(bn_name + '_offset'),
- moving_mean_name=bn_name + '_mean',
- moving_variance_name=bn_name + '_variance')
+ param_attr=ParamAttr(name=bn_name + "_scale"),
+ bias_attr=ParamAttr(bn_name + "_offset"),
+ moving_mean_name=bn_name + "_mean",
+ moving_variance_name=bn_name + "_variance",
+ )
def forward(self, inputs):
if self.is_vd_mode:
@@ -70,34 +74,39 @@ def forward(self, inputs):
class BottleneckBlock(nn.Layer):
- def __init__(self,
- in_channels,
- out_channels,
- stride,
- shortcut=True,
- if_first=False,
- name=None):
+ def __init__(
+ self,
+ in_channels,
+ out_channels,
+ stride,
+ shortcut=True,
+ if_first=False,
+ name=None,
+ ):
super(BottleneckBlock, self).__init__()
self.conv0 = ConvBNLayer(
in_channels=in_channels,
out_channels=out_channels,
kernel_size=1,
- act='relu',
- name=name + "_branch2a")
+ act="relu",
+ name=name + "_branch2a",
+ )
self.conv1 = ConvBNLayer(
in_channels=out_channels,
out_channels=out_channels,
kernel_size=3,
stride=stride,
- act='relu',
- name=name + "_branch2b")
+ act="relu",
+ name=name + "_branch2b",
+ )
self.conv2 = ConvBNLayer(
in_channels=out_channels,
out_channels=out_channels * 4,
kernel_size=1,
act=None,
- name=name + "_branch2c")
+ name=name + "_branch2c",
+ )
if not shortcut:
self.short = ConvBNLayer(
@@ -106,7 +115,8 @@ def __init__(self,
kernel_size=1,
stride=stride,
is_vd_mode=not if_first and stride[0] != 1,
- name=name + "_branch1")
+ name=name + "_branch1",
+ )
self.shortcut = shortcut
@@ -126,13 +136,15 @@ def forward(self, inputs):
class BasicBlock(nn.Layer):
- def __init__(self,
- in_channels,
- out_channels,
- stride,
- shortcut=True,
- if_first=False,
- name=None):
+ def __init__(
+ self,
+ in_channels,
+ out_channels,
+ stride,
+ shortcut=True,
+ if_first=False,
+ name=None,
+ ):
super(BasicBlock, self).__init__()
self.stride = stride
self.conv0 = ConvBNLayer(
@@ -140,14 +152,16 @@ def __init__(self,
out_channels=out_channels,
kernel_size=3,
stride=stride,
- act='relu',
- name=name + "_branch2a")
+ act="relu",
+ name=name + "_branch2a",
+ )
self.conv1 = ConvBNLayer(
in_channels=out_channels,
out_channels=out_channels,
kernel_size=3,
act=None,
- name=name + "_branch2b")
+ name=name + "_branch2b",
+ )
if not shortcut:
self.short = ConvBNLayer(
@@ -156,7 +170,8 @@ def __init__(self,
kernel_size=1,
stride=stride,
is_vd_mode=not if_first and stride[0] != 1,
- name=name + "_branch1")
+ name=name + "_branch1",
+ )
self.shortcut = shortcut
@@ -179,9 +194,11 @@ def __init__(self, in_channels=3, layers=50, **kwargs):
self.layers = layers
supported_layers = [18, 34, 50, 101, 152, 200]
- assert layers in supported_layers, \
- "supported layers are {} but input layer is {}".format(
- supported_layers, layers)
+ assert (
+ layers in supported_layers
+ ), "supported layers are {} but input layer is {}".format(
+ supported_layers, layers
+ )
if layers == 18:
depth = [2, 2, 2, 2]
@@ -193,8 +210,7 @@ def __init__(self, in_channels=3, layers=50, **kwargs):
depth = [3, 8, 36, 3]
elif layers == 200:
depth = [3, 12, 48, 3]
- num_channels = [64, 256, 512,
- 1024] if layers >= 50 else [64, 64, 128, 256]
+ num_channels = [64, 256, 512, 1024] if layers >= 50 else [64, 64, 128, 256]
num_filters = [64, 128, 256, 512]
self.conv1_1 = ConvBNLayer(
@@ -202,22 +218,25 @@ def __init__(self, in_channels=3, layers=50, **kwargs):
out_channels=32,
kernel_size=3,
stride=1,
- act='relu',
- name="conv1_1")
+ act="relu",
+ name="conv1_1",
+ )
self.conv1_2 = ConvBNLayer(
in_channels=32,
out_channels=32,
kernel_size=3,
stride=1,
- act='relu',
- name="conv1_2")
+ act="relu",
+ name="conv1_2",
+ )
self.conv1_3 = ConvBNLayer(
in_channels=32,
out_channels=64,
kernel_size=3,
stride=1,
- act='relu',
- name="conv1_3")
+ act="relu",
+ name="conv1_3",
+ )
self.pool2d_max = nn.MaxPool2D(kernel_size=3, stride=2, padding=1)
self.block_list = []
@@ -238,15 +257,18 @@ def __init__(self, in_channels=3, layers=50, **kwargs):
else:
stride = (1, 1)
bottleneck_block = self.add_sublayer(
- 'bb_%d_%d' % (block, i),
+ "bb_%d_%d" % (block, i),
BottleneckBlock(
in_channels=num_channels[block]
- if i == 0 else num_filters[block] * 4,
+ if i == 0
+ else num_filters[block] * 4,
out_channels=num_filters[block],
stride=stride,
shortcut=shortcut,
if_first=block == i == 0,
- name=conv_name))
+ name=conv_name,
+ ),
+ )
shortcut = True
self.block_list.append(bottleneck_block)
self.out_channels = num_filters[block] * 4
@@ -261,15 +283,18 @@ def __init__(self, in_channels=3, layers=50, **kwargs):
stride = (1, 1)
basic_block = self.add_sublayer(
- 'bb_%d_%d' % (block, i),
+ "bb_%d_%d" % (block, i),
BasicBlock(
in_channels=num_channels[block]
- if i == 0 else num_filters[block],
+ if i == 0
+ else num_filters[block],
out_channels=num_filters[block],
stride=stride,
shortcut=shortcut,
if_first=block == i == 0,
- name=conv_name))
+ name=conv_name,
+ ),
+ )
shortcut = True
self.block_list.append(basic_block)
self.out_channels = num_filters[block]
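
Note: beyond formatting, the hunks above show the pattern this file relies on: blocks are registered via `add_sublayer` under generated `"bb_%d_%d"` names so their parameters are tracked, then chained through a plain Python list. A minimal sketch of the pattern (assumes `paddlepaddle`; `Stack` and its sizes are illustrative, not part of the repo):

```python
import paddle
import paddle.nn as nn


class Stack(nn.Layer):
    def __init__(self, depth=3, width=8):
        super().__init__()
        self.block_list = []
        for i in range(depth):
            # add_sublayer registers the layer's parameters under the given
            # name and returns the layer itself, as in the diff above.
            block = self.add_sublayer("bb_%d" % i, nn.Linear(width, width))
            self.block_list.append(block)

    def forward(self, x):
        for block in self.block_list:
            x = block(x)
        return x


net = Stack()
print(net(paddle.randn([2, 8])).shape)  # [2, 8]
```
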
diff --git a/ppocr/modeling/backbones/rec_shallow_cnn.py b/ppocr/modeling/backbones/rec_shallow_cnn.py
index 544f108d26..85c043d1f5 100644
--- a/ppocr/modeling/backbones/rec_shallow_cnn.py
+++ b/ppocr/modeling/backbones/rec_shallow_cnn.py
@@ -31,13 +31,9 @@
class ConvBNLayer(nn.Layer):
- def __init__(self,
- num_channels,
- filter_size,
- num_filters,
- stride,
- padding,
- num_groups=1):
+ def __init__(
+ self, num_channels, filter_size, num_filters, stride, padding, num_groups=1
+ ):
super(ConvBNLayer, self).__init__()
self.conv = nn.Conv2D(
@@ -48,12 +44,14 @@ def __init__(self,
padding=padding,
groups=num_groups,
weight_attr=ParamAttr(initializer=KaimingNormal()),
- bias_attr=False)
+ bias_attr=False,
+ )
self.bn = nn.BatchNorm2D(
num_filters,
weight_attr=ParamAttr(initializer=Uniform(0, 1)),
- bias_attr=ParamAttr(initializer=Constant(0)))
+ bias_attr=ParamAttr(initializer=Constant(0)),
+ )
self.relu = nn.ReLU()
def forward(self, inputs):
@@ -69,15 +67,12 @@ def __init__(self, in_channels=1, hidden_dim=512):
assert isinstance(in_channels, int)
assert isinstance(hidden_dim, int)
- self.conv1 = ConvBNLayer(
- in_channels, 3, hidden_dim // 2, stride=1, padding=1)
- self.conv2 = ConvBNLayer(
- hidden_dim // 2, 3, hidden_dim, stride=1, padding=1)
+ self.conv1 = ConvBNLayer(in_channels, 3, hidden_dim // 2, stride=1, padding=1)
+ self.conv2 = ConvBNLayer(hidden_dim // 2, 3, hidden_dim, stride=1, padding=1)
self.pool = nn.MaxPool2D(kernel_size=2, stride=2, padding=0)
self.out_channels = hidden_dim
def forward(self, x):
-
x = self.conv1(x)
x = self.pool(x)
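
Note: the `conv1`/`conv2` lines above show the inverse operation: without a magic trailing comma, a wrapped call that now fits within the 88-column limit is collapsed back onto a single line. Sketch:

```python
import black

src = (
    "self.conv1 = ConvBNLayer(\n"
    "    in_channels, 3, hidden_dim // 2, stride=1, padding=1)\n"
)
print(black.format_str(src, mode=black.Mode()), end="")
# self.conv1 = ConvBNLayer(in_channels, 3, hidden_dim // 2, stride=1, padding=1)
```
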
diff --git a/ppocr/modeling/backbones/rec_svtrnet.py b/ppocr/modeling/backbones/rec_svtrnet.py
index daddfeac43..427c87b324 100644
--- a/ppocr/modeling/backbones/rec_svtrnet.py
+++ b/ppocr/modeling/backbones/rec_svtrnet.py
@@ -19,21 +19,21 @@
import paddle.nn as nn
from paddle.nn.initializer import TruncatedNormal, Constant, Normal
-trunc_normal_ = TruncatedNormal(std=.02)
+trunc_normal_ = TruncatedNormal(std=0.02)
normal_ = Normal
-zeros_ = Constant(value=0.)
-ones_ = Constant(value=1.)
+zeros_ = Constant(value=0.0)
+ones_ = Constant(value=1.0)
-def drop_path(x, drop_prob=0., training=False):
+def drop_path(x, drop_prob=0.0, training=False):
"""Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...
See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ...
"""
- if drop_prob == 0. or not training:
+ if drop_prob == 0.0 or not training:
return x
keep_prob = paddle.to_tensor(1 - drop_prob, dtype=x.dtype)
- shape = (x.shape[0], ) + (1, ) * (x.ndim - 1)
+ shape = (x.shape[0],) + (1,) * (x.ndim - 1)
random_tensor = keep_prob + paddle.rand(shape, dtype=x.dtype)
random_tensor = paddle.floor(random_tensor) # binarize
output = x.divide(keep_prob) * random_tensor
@@ -41,15 +41,17 @@ def drop_path(x, drop_prob=0., training=False):
class ConvBNLayer(nn.Layer):
- def __init__(self,
- in_channels,
- out_channels,
- kernel_size=3,
- stride=1,
- padding=0,
- bias_attr=False,
- groups=1,
- act=nn.GELU):
+ def __init__(
+ self,
+ in_channels,
+ out_channels,
+ kernel_size=3,
+ stride=1,
+ padding=0,
+ bias_attr=False,
+ groups=1,
+ act=nn.GELU,
+ ):
super().__init__()
self.conv = nn.Conv2D(
in_channels=in_channels,
@@ -58,9 +60,9 @@ def __init__(self,
stride=stride,
padding=padding,
groups=groups,
- weight_attr=paddle.ParamAttr(
- initializer=nn.initializer.KaimingUniform()),
- bias_attr=bias_attr)
+ weight_attr=paddle.ParamAttr(initializer=nn.initializer.KaimingUniform()),
+ bias_attr=bias_attr,
+ )
self.norm = nn.BatchNorm2D(out_channels)
self.act = act()
@@ -72,8 +74,7 @@ def forward(self, inputs):
class DropPath(nn.Layer):
- """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
- """
+ """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks)."""
def __init__(self, drop_prob=None):
super(DropPath, self).__init__()
@@ -92,12 +93,14 @@ def forward(self, input):
class Mlp(nn.Layer):
- def __init__(self,
- in_features,
- hidden_features=None,
- out_features=None,
- act_layer=nn.GELU,
- drop=0.):
+ def __init__(
+ self,
+ in_features,
+ hidden_features=None,
+ out_features=None,
+ act_layer=nn.GELU,
+ drop=0.0,
+ ):
super().__init__()
out_features = out_features or in_features
hidden_features = hidden_features or in_features
@@ -117,11 +120,12 @@ def forward(self, x):
class ConvMixer(nn.Layer):
def __init__(
- self,
- dim,
- num_heads=8,
- HW=[8, 25],
- local_k=[3, 3], ):
+ self,
+ dim,
+ num_heads=8,
+ HW=[8, 25],
+ local_k=[3, 3],
+ ):
super().__init__()
self.HW = HW
self.dim = dim
@@ -129,9 +133,11 @@ def __init__(
dim,
dim,
local_k,
- 1, [local_k[0] // 2, local_k[1] // 2],
+ 1,
+ [local_k[0] // 2, local_k[1] // 2],
groups=num_heads,
- weight_attr=ParamAttr(initializer=KaimingNormal()))
+ weight_attr=ParamAttr(initializer=KaimingNormal()),
+ )
def forward(self, x):
h = self.HW[0]
@@ -143,16 +149,18 @@ def forward(self, x):
class Attention(nn.Layer):
- def __init__(self,
- dim,
- num_heads=8,
- mixer='Global',
- HW=None,
- local_k=[7, 11],
- qkv_bias=False,
- qk_scale=None,
- attn_drop=0.,
- proj_drop=0.):
+ def __init__(
+ self,
+ dim,
+ num_heads=8,
+ mixer="Global",
+ HW=None,
+ local_k=[7, 11],
+ qkv_bias=False,
+ qk_scale=None,
+ attn_drop=0.0,
+ proj_drop=0.0,
+ ):
super().__init__()
self.num_heads = num_heads
self.dim = dim
@@ -169,28 +177,31 @@ def __init__(self,
W = HW[1]
self.N = H * W
self.C = dim
- if mixer == 'Local' and HW is not None:
+ if mixer == "Local" and HW is not None:
hk = local_k[0]
wk = local_k[1]
- mask = paddle.ones([H * W, H + hk - 1, W + wk - 1], dtype='float32')
+ mask = paddle.ones([H * W, H + hk - 1, W + wk - 1], dtype="float32")
for h in range(0, H):
for w in range(0, W):
- mask[h * W + w, h:h + hk, w:w + wk] = 0.
- mask_paddle = mask[:, hk // 2:H + hk // 2, wk // 2:W + wk //
- 2].flatten(1)
- mask_inf = paddle.full([H * W, H * W], '-inf', dtype='float32')
+ mask[h * W + w, h : h + hk, w : w + wk] = 0.0
+ mask_paddle = mask[:, hk // 2 : H + hk // 2, wk // 2 : W + wk // 2].flatten(
+ 1
+ )
+ mask_inf = paddle.full([H * W, H * W], "-inf", dtype="float32")
mask = paddle.where(mask_paddle < 1, mask_paddle, mask_inf)
self.mask = mask.unsqueeze([0, 1])
self.mixer = mixer
def forward(self, x):
- qkv = self.qkv(x).reshape(
- (0, -1, 3, self.num_heads, self.head_dim)).transpose(
- (2, 0, 3, 1, 4))
+ qkv = (
+ self.qkv(x)
+ .reshape((0, -1, 3, self.num_heads, self.head_dim))
+ .transpose((2, 0, 3, 1, 4))
+ )
q, k, v = qkv[0] * self.scale, qkv[1], qkv[2]
- attn = (q.matmul(k.transpose((0, 1, 3, 2))))
- if self.mixer == 'Local':
+ attn = q.matmul(k.transpose((0, 1, 3, 2)))
+ if self.mixer == "Local":
attn += self.mask
attn = nn.functional.softmax(attn, axis=-1)
attn = self.attn_drop(attn)
@@ -202,28 +213,30 @@ def forward(self, x):
class Block(nn.Layer):
- def __init__(self,
- dim,
- num_heads,
- mixer='Global',
- local_mixer=[7, 11],
- HW=None,
- mlp_ratio=4.,
- qkv_bias=False,
- qk_scale=None,
- drop=0.,
- attn_drop=0.,
- drop_path=0.,
- act_layer=nn.GELU,
- norm_layer='nn.LayerNorm',
- epsilon=1e-6,
- prenorm=True):
+ def __init__(
+ self,
+ dim,
+ num_heads,
+ mixer="Global",
+ local_mixer=[7, 11],
+ HW=None,
+ mlp_ratio=4.0,
+ qkv_bias=False,
+ qk_scale=None,
+ drop=0.0,
+ attn_drop=0.0,
+ drop_path=0.0,
+ act_layer=nn.GELU,
+ norm_layer="nn.LayerNorm",
+ epsilon=1e-6,
+ prenorm=True,
+ ):
super().__init__()
if isinstance(norm_layer, str):
self.norm1 = eval(norm_layer)(dim, epsilon=epsilon)
else:
self.norm1 = norm_layer(dim)
- if mixer == 'Global' or mixer == 'Local':
+ if mixer == "Global" or mixer == "Local":
self.mixer = Attention(
dim,
num_heads=num_heads,
@@ -233,24 +246,26 @@ def __init__(self,
qkv_bias=qkv_bias,
qk_scale=qk_scale,
attn_drop=attn_drop,
- proj_drop=drop)
- elif mixer == 'Conv':
- self.mixer = ConvMixer(
- dim, num_heads=num_heads, HW=HW, local_k=local_mixer)
+ proj_drop=drop,
+ )
+ elif mixer == "Conv":
+ self.mixer = ConvMixer(dim, num_heads=num_heads, HW=HW, local_k=local_mixer)
else:
raise TypeError("The mixer must be one of [Global, Local, Conv]")
- self.drop_path = DropPath(drop_path) if drop_path > 0. else Identity()
+ self.drop_path = DropPath(drop_path) if drop_path > 0.0 else Identity()
if isinstance(norm_layer, str):
self.norm2 = eval(norm_layer)(dim, epsilon=epsilon)
else:
self.norm2 = norm_layer(dim)
mlp_hidden_dim = int(dim * mlp_ratio)
self.mlp_ratio = mlp_ratio
- self.mlp = Mlp(in_features=dim,
- hidden_features=mlp_hidden_dim,
- act_layer=act_layer,
- drop=drop)
+ self.mlp = Mlp(
+ in_features=dim,
+ hidden_features=mlp_hidden_dim,
+ act_layer=act_layer,
+ drop=drop,
+ )
self.prenorm = prenorm
def forward(self, x):
@@ -264,24 +279,24 @@ def forward(self, x):
class PatchEmbed(nn.Layer):
- """ Image to Patch Embedding
- """
+ """Image to Patch Embedding"""
- def __init__(self,
- img_size=[32, 100],
- in_channels=3,
- embed_dim=768,
- sub_num=2,
- patch_size=[4, 4],
- mode='pope'):
+ def __init__(
+ self,
+ img_size=[32, 100],
+ in_channels=3,
+ embed_dim=768,
+ sub_num=2,
+ patch_size=[4, 4],
+ mode="pope",
+ ):
super().__init__()
- num_patches = (img_size[1] // (2 ** sub_num)) * \
- (img_size[0] // (2 ** sub_num))
+ num_patches = (img_size[1] // (2**sub_num)) * (img_size[0] // (2**sub_num))
self.img_size = img_size
self.num_patches = num_patches
self.embed_dim = embed_dim
self.norm = None
- if mode == 'pope':
+ if mode == "pope":
if sub_num == 2:
self.proj = nn.Sequential(
ConvBNLayer(
@@ -291,7 +306,8 @@ def __init__(self,
stride=2,
padding=1,
act=nn.GELU,
- bias_attr=None),
+ bias_attr=None,
+ ),
ConvBNLayer(
in_channels=embed_dim // 2,
out_channels=embed_dim,
@@ -299,7 +315,9 @@ def __init__(self,
stride=2,
padding=1,
act=nn.GELU,
- bias_attr=None))
+ bias_attr=None,
+ ),
+ )
if sub_num == 3:
self.proj = nn.Sequential(
ConvBNLayer(
@@ -309,7 +327,8 @@ def __init__(self,
stride=2,
padding=1,
act=nn.GELU,
- bias_attr=None),
+ bias_attr=None,
+ ),
ConvBNLayer(
in_channels=embed_dim // 4,
out_channels=embed_dim // 2,
@@ -317,7 +336,8 @@ def __init__(self,
stride=2,
padding=1,
act=nn.GELU,
- bias_attr=None),
+ bias_attr=None,
+ ),
ConvBNLayer(
in_channels=embed_dim // 2,
out_channels=embed_dim,
@@ -325,36 +345,45 @@ def __init__(self,
stride=2,
padding=1,
act=nn.GELU,
- bias_attr=None))
- elif mode == 'linear':
+ bias_attr=None,
+ ),
+ )
+ elif mode == "linear":
self.proj = nn.Conv2D(
- 1, embed_dim, kernel_size=patch_size, stride=patch_size)
- self.num_patches = img_size[0] // patch_size[0] * img_size[
- 1] // patch_size[1]
+ 1, embed_dim, kernel_size=patch_size, stride=patch_size
+ )
+ self.num_patches = (
+ img_size[0] // patch_size[0] * img_size[1] // patch_size[1]
+ )
def forward(self, x):
B, C, H, W = x.shape
- assert H == self.img_size[0] and W == self.img_size[1], \
- f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
+ assert (
+ H == self.img_size[0] and W == self.img_size[1]
+ ), f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
x = self.proj(x).flatten(2).transpose((0, 2, 1))
return x
class SubSample(nn.Layer):
- def __init__(self,
- in_channels,
- out_channels,
- types='Pool',
- stride=[2, 1],
- sub_norm='nn.LayerNorm',
- act=None):
+ def __init__(
+ self,
+ in_channels,
+ out_channels,
+ types="Pool",
+ stride=[2, 1],
+ sub_norm="nn.LayerNorm",
+ act=None,
+ ):
super().__init__()
self.types = types
- if types == 'Pool':
+ if types == "Pool":
self.avgpool = nn.AvgPool2D(
- kernel_size=[3, 5], stride=stride, padding=[1, 2])
+ kernel_size=[3, 5], stride=stride, padding=[1, 2]
+ )
self.maxpool = nn.MaxPool2D(
- kernel_size=[3, 5], stride=stride, padding=[1, 2])
+ kernel_size=[3, 5], stride=stride, padding=[1, 2]
+ )
self.proj = nn.Linear(in_channels, out_channels)
else:
self.conv = nn.Conv2D(
@@ -363,7 +392,8 @@ def __init__(self,
kernel_size=3,
stride=stride,
padding=1,
- weight_attr=ParamAttr(initializer=KaimingNormal()))
+ weight_attr=ParamAttr(initializer=KaimingNormal()),
+ )
self.norm = eval(sub_norm)(out_channels)
if act is not None:
self.act = act()
@@ -371,8 +401,7 @@ def __init__(self,
self.act = None
def forward(self, x):
-
- if self.types == 'Pool':
+ if self.types == "Pool":
x1 = self.avgpool(x)
x2 = self.maxpool(x)
x = (x1 + x2) * 0.5
@@ -389,130 +418,150 @@ def forward(self, x):
class SVTRNet(nn.Layer):
def __init__(
- self,
- img_size=[32, 100],
- in_channels=3,
- embed_dim=[64, 128, 256],
- depth=[3, 6, 3],
- num_heads=[2, 4, 8],
- mixer=['Local'] * 6 + ['Global'] *
- 6, # Local atten, Global atten, Conv
- local_mixer=[[7, 11], [7, 11], [7, 11]],
- patch_merging='Conv', # Conv, Pool, None
- mlp_ratio=4,
- qkv_bias=True,
- qk_scale=None,
- drop_rate=0.,
- last_drop=0.1,
- attn_drop_rate=0.,
- drop_path_rate=0.1,
- norm_layer='nn.LayerNorm',
- sub_norm='nn.LayerNorm',
- epsilon=1e-6,
- out_channels=192,
- out_char_num=25,
- block_unit='Block',
- act='nn.GELU',
- last_stage=True,
- sub_num=2,
- prenorm=True,
- use_lenhead=False,
- **kwargs):
+ self,
+ img_size=[32, 100],
+ in_channels=3,
+ embed_dim=[64, 128, 256],
+ depth=[3, 6, 3],
+ num_heads=[2, 4, 8],
+ mixer=["Local"] * 6 + ["Global"] * 6, # Local atten, Global atten, Conv
+ local_mixer=[[7, 11], [7, 11], [7, 11]],
+ patch_merging="Conv", # Conv, Pool, None
+ mlp_ratio=4,
+ qkv_bias=True,
+ qk_scale=None,
+ drop_rate=0.0,
+ last_drop=0.1,
+ attn_drop_rate=0.0,
+ drop_path_rate=0.1,
+ norm_layer="nn.LayerNorm",
+ sub_norm="nn.LayerNorm",
+ epsilon=1e-6,
+ out_channels=192,
+ out_char_num=25,
+ block_unit="Block",
+ act="nn.GELU",
+ last_stage=True,
+ sub_num=2,
+ prenorm=True,
+ use_lenhead=False,
+ **kwargs,
+ ):
super().__init__()
self.img_size = img_size
self.embed_dim = embed_dim
self.out_channels = out_channels
self.prenorm = prenorm
- patch_merging = None if patch_merging != 'Conv' and patch_merging != 'Pool' else patch_merging
+ patch_merging = (
+ None
+ if patch_merging != "Conv" and patch_merging != "Pool"
+ else patch_merging
+ )
self.patch_embed = PatchEmbed(
img_size=img_size,
in_channels=in_channels,
embed_dim=embed_dim[0],
- sub_num=sub_num)
+ sub_num=sub_num,
+ )
num_patches = self.patch_embed.num_patches
self.HW = [img_size[0] // (2**sub_num), img_size[1] // (2**sub_num)]
self.pos_embed = self.create_parameter(
- shape=[1, num_patches, embed_dim[0]], default_initializer=zeros_)
+ shape=[1, num_patches, embed_dim[0]], default_initializer=zeros_
+ )
self.add_parameter("pos_embed", self.pos_embed)
self.pos_drop = nn.Dropout(p=drop_rate)
Block_unit = eval(block_unit)
dpr = np.linspace(0, drop_path_rate, sum(depth))
- self.blocks1 = nn.LayerList([
- Block_unit(
- dim=embed_dim[0],
- num_heads=num_heads[0],
- mixer=mixer[0:depth[0]][i],
- HW=self.HW,
- local_mixer=local_mixer[0],
- mlp_ratio=mlp_ratio,
- qkv_bias=qkv_bias,
- qk_scale=qk_scale,
- drop=drop_rate,
- act_layer=eval(act),
- attn_drop=attn_drop_rate,
- drop_path=dpr[0:depth[0]][i],
- norm_layer=norm_layer,
- epsilon=epsilon,
- prenorm=prenorm) for i in range(depth[0])
- ])
+ self.blocks1 = nn.LayerList(
+ [
+ Block_unit(
+ dim=embed_dim[0],
+ num_heads=num_heads[0],
+ mixer=mixer[0 : depth[0]][i],
+ HW=self.HW,
+ local_mixer=local_mixer[0],
+ mlp_ratio=mlp_ratio,
+ qkv_bias=qkv_bias,
+ qk_scale=qk_scale,
+ drop=drop_rate,
+ act_layer=eval(act),
+ attn_drop=attn_drop_rate,
+ drop_path=dpr[0 : depth[0]][i],
+ norm_layer=norm_layer,
+ epsilon=epsilon,
+ prenorm=prenorm,
+ )
+ for i in range(depth[0])
+ ]
+ )
if patch_merging is not None:
self.sub_sample1 = SubSample(
embed_dim[0],
embed_dim[1],
sub_norm=sub_norm,
stride=[2, 1],
- types=patch_merging)
+ types=patch_merging,
+ )
HW = [self.HW[0] // 2, self.HW[1]]
else:
HW = self.HW
self.patch_merging = patch_merging
- self.blocks2 = nn.LayerList([
- Block_unit(
- dim=embed_dim[1],
- num_heads=num_heads[1],
- mixer=mixer[depth[0]:depth[0] + depth[1]][i],
- HW=HW,
- local_mixer=local_mixer[1],
- mlp_ratio=mlp_ratio,
- qkv_bias=qkv_bias,
- qk_scale=qk_scale,
- drop=drop_rate,
- act_layer=eval(act),
- attn_drop=attn_drop_rate,
- drop_path=dpr[depth[0]:depth[0] + depth[1]][i],
- norm_layer=norm_layer,
- epsilon=epsilon,
- prenorm=prenorm) for i in range(depth[1])
- ])
+ self.blocks2 = nn.LayerList(
+ [
+ Block_unit(
+ dim=embed_dim[1],
+ num_heads=num_heads[1],
+ mixer=mixer[depth[0] : depth[0] + depth[1]][i],
+ HW=HW,
+ local_mixer=local_mixer[1],
+ mlp_ratio=mlp_ratio,
+ qkv_bias=qkv_bias,
+ qk_scale=qk_scale,
+ drop=drop_rate,
+ act_layer=eval(act),
+ attn_drop=attn_drop_rate,
+ drop_path=dpr[depth[0] : depth[0] + depth[1]][i],
+ norm_layer=norm_layer,
+ epsilon=epsilon,
+ prenorm=prenorm,
+ )
+ for i in range(depth[1])
+ ]
+ )
if patch_merging is not None:
self.sub_sample2 = SubSample(
embed_dim[1],
embed_dim[2],
sub_norm=sub_norm,
stride=[2, 1],
- types=patch_merging)
+ types=patch_merging,
+ )
HW = [self.HW[0] // 4, self.HW[1]]
else:
HW = self.HW
- self.blocks3 = nn.LayerList([
- Block_unit(
- dim=embed_dim[2],
- num_heads=num_heads[2],
- mixer=mixer[depth[0] + depth[1]:][i],
- HW=HW,
- local_mixer=local_mixer[2],
- mlp_ratio=mlp_ratio,
- qkv_bias=qkv_bias,
- qk_scale=qk_scale,
- drop=drop_rate,
- act_layer=eval(act),
- attn_drop=attn_drop_rate,
- drop_path=dpr[depth[0] + depth[1]:][i],
- norm_layer=norm_layer,
- epsilon=epsilon,
- prenorm=prenorm) for i in range(depth[2])
- ])
+ self.blocks3 = nn.LayerList(
+ [
+ Block_unit(
+ dim=embed_dim[2],
+ num_heads=num_heads[2],
+ mixer=mixer[depth[0] + depth[1] :][i],
+ HW=HW,
+ local_mixer=local_mixer[2],
+ mlp_ratio=mlp_ratio,
+ qkv_bias=qkv_bias,
+ qk_scale=qk_scale,
+ drop=drop_rate,
+ act_layer=eval(act),
+ attn_drop=attn_drop_rate,
+ drop_path=dpr[depth[0] + depth[1] :][i],
+ norm_layer=norm_layer,
+ epsilon=epsilon,
+ prenorm=prenorm,
+ )
+ for i in range(depth[2])
+ ]
+ )
self.last_stage = last_stage
if last_stage:
self.avg_pool = nn.AdaptiveAvgPool2D([1, out_char_num])
@@ -522,7 +571,8 @@ def __init__(
kernel_size=1,
stride=1,
padding=0,
- bias_attr=False)
+ bias_attr=False,
+ )
self.hardswish = nn.Hardswish()
self.dropout = nn.Dropout(p=last_drop, mode="downscale_in_infer")
if not prenorm:
@@ -531,8 +581,7 @@ def __init__(
if use_lenhead:
self.len_conv = nn.Linear(embed_dim[2], self.out_channels)
self.hardswish_len = nn.Hardswish()
- self.dropout_len = nn.Dropout(
- p=last_drop, mode="downscale_in_infer")
+ self.dropout_len = nn.Dropout(p=last_drop, mode="downscale_in_infer")
trunc_normal_(self.pos_embed)
self.apply(self._init_weights)
@@ -555,13 +604,17 @@ def forward_features(self, x):
if self.patch_merging is not None:
x = self.sub_sample1(
x.transpose([0, 2, 1]).reshape(
- [0, self.embed_dim[0], self.HW[0], self.HW[1]]))
+ [0, self.embed_dim[0], self.HW[0], self.HW[1]]
+ )
+ )
for blk in self.blocks2:
x = blk(x)
if self.patch_merging is not None:
x = self.sub_sample2(
x.transpose([0, 2, 1]).reshape(
- [0, self.embed_dim[1], self.HW[0] // 2, self.HW[1]]))
+ [0, self.embed_dim[1], self.HW[0] // 2, self.HW[1]]
+ )
+ )
for blk in self.blocks3:
x = blk(x)
if not self.prenorm:
@@ -579,8 +632,8 @@ def forward(self, x):
else:
h = self.HW[0]
x = self.avg_pool(
- x.transpose([0, 2, 1]).reshape(
- [0, self.embed_dim[2], h, self.HW[1]]))
+ x.transpose([0, 2, 1]).reshape([0, self.embed_dim[2], h, self.HW[1]])
+ )
x = self.last_conv(x)
x = self.hardswish(x)
x = self.dropout(x)
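
For reference, the Local-mixer branch of Attention.__init__ above builds an additive attention bias once at construction time. The following is a minimal NumPy sketch of that construction, assuming only the masking logic matters here; the function name and the example sizes (the SVTR default HW=[8, 25], local_k=[7, 11]) are illustrative and not part of this patch:

import numpy as np

def local_attention_mask(H, W, hk, wk):
    # Pad by the window size, then mark each position's hk x wk window as allowed (0).
    mask = np.ones((H * W, H + hk - 1, W + wk - 1), dtype="float32")
    for h in range(H):
        for w in range(W):
            mask[h * W + w, h : h + hk, w : w + wk] = 0.0
    # Crop the padding back off and flatten to an (H*W, H*W) attention bias.
    mask = mask[:, hk // 2 : H + hk // 2, wk // 2 : W + wk // 2].reshape(H * W, -1)
    # Outside the window: -inf, so softmax assigns those positions zero weight.
    return np.where(mask < 1, 0.0, -np.inf).astype("float32")

print(local_attention_mask(8, 25, 7, 11).shape)  # (200, 200)

Adding this bias before the softmax (as forward does with self.mask) is equivalent to restricting each token's attention to its local window.
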
diff --git a/ppocr/modeling/backbones/rec_vit.py b/ppocr/modeling/backbones/rec_vit.py
index ff31377478..d492eea426 100644
--- a/ppocr/modeling/backbones/rec_vit.py
+++ b/ppocr/modeling/backbones/rec_vit.py
@@ -19,21 +19,21 @@
import paddle.nn as nn
from paddle.nn.initializer import TruncatedNormal, Constant, Normal
-trunc_normal_ = TruncatedNormal(std=.02)
+trunc_normal_ = TruncatedNormal(std=0.02)
normal_ = Normal
-zeros_ = Constant(value=0.)
-ones_ = Constant(value=1.)
+zeros_ = Constant(value=0.0)
+ones_ = Constant(value=1.0)
-def drop_path(x, drop_prob=0., training=False):
+def drop_path(x, drop_prob=0.0, training=False):
"""Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...
See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ...
"""
- if drop_prob == 0. or not training:
+ if drop_prob == 0.0 or not training:
return x
keep_prob = paddle.to_tensor(1 - drop_prob)
- shape = (x.shape[0], ) + (1, ) * (x.ndim - 1)
+ shape = (x.shape[0],) + (1,) * (x.ndim - 1)
random_tensor = keep_prob + paddle.rand(shape, dtype=x.dtype)
random_tensor = paddle.floor(random_tensor) # binarize
output = x.divide(keep_prob) * random_tensor
@@ -41,8 +41,7 @@ def drop_path(x, drop_prob=0., training=False):
class DropPath(nn.Layer):
- """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
- """
+ """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks)."""
def __init__(self, drop_prob=None):
super(DropPath, self).__init__()
@@ -61,12 +60,14 @@ def forward(self, input):
class Mlp(nn.Layer):
- def __init__(self,
- in_features,
- hidden_features=None,
- out_features=None,
- act_layer=nn.GELU,
- drop=0.):
+ def __init__(
+ self,
+ in_features,
+ hidden_features=None,
+ out_features=None,
+ act_layer=nn.GELU,
+ drop=0.0,
+ ):
super().__init__()
out_features = out_features or in_features
hidden_features = hidden_features or in_features
@@ -85,13 +86,15 @@ def forward(self, x):
class Attention(nn.Layer):
- def __init__(self,
- dim,
- num_heads=8,
- qkv_bias=False,
- qk_scale=None,
- attn_drop=0.,
- proj_drop=0.):
+ def __init__(
+ self,
+ dim,
+ num_heads=8,
+ qkv_bias=False,
+ qk_scale=None,
+ attn_drop=0.0,
+ proj_drop=0.0,
+ ):
super().__init__()
self.num_heads = num_heads
self.dim = dim
@@ -102,15 +105,14 @@ def __init__(self,
self.attn_drop = nn.Dropout(attn_drop)
self.proj = nn.Linear(dim, dim)
self.proj_drop = nn.Dropout(proj_drop)
-
def forward(self, x):
-
- qkv = paddle.reshape(self.qkv(x), (0, -1, 3, self.num_heads, self.dim //
- self.num_heads)).transpose((2, 0, 3, 1, 4))
+ qkv = paddle.reshape(
+ self.qkv(x), (0, -1, 3, self.num_heads, self.dim // self.num_heads)
+ ).transpose((2, 0, 3, 1, 4))
q, k, v = qkv[0] * self.scale, qkv[1], qkv[2]
- attn = (q.matmul(k.transpose((0, 1, 3, 2))))
+ attn = q.matmul(k.transpose((0, 1, 3, 2)))
attn = nn.functional.softmax(attn, axis=-1)
attn = self.attn_drop(attn)
@@ -121,43 +123,48 @@ def forward(self, x):
class Block(nn.Layer):
- def __init__(self,
- dim,
- num_heads,
- mlp_ratio=4.,
- qkv_bias=False,
- qk_scale=None,
- drop=0.,
- attn_drop=0.,
- drop_path=0.,
- act_layer=nn.GELU,
- norm_layer='nn.LayerNorm',
- epsilon=1e-6,
- prenorm=True):
+ def __init__(
+ self,
+ dim,
+ num_heads,
+ mlp_ratio=4.0,
+ qkv_bias=False,
+ qk_scale=None,
+ drop=0.0,
+ attn_drop=0.0,
+ drop_path=0.0,
+ act_layer=nn.GELU,
+ norm_layer="nn.LayerNorm",
+ epsilon=1e-6,
+ prenorm=True,
+ ):
super().__init__()
if isinstance(norm_layer, str):
self.norm1 = eval(norm_layer)(dim, epsilon=epsilon)
else:
self.norm1 = norm_layer(dim)
self.mixer = Attention(
- dim,
- num_heads=num_heads,
- qkv_bias=qkv_bias,
- qk_scale=qk_scale,
- attn_drop=attn_drop,
- proj_drop=drop)
+ dim,
+ num_heads=num_heads,
+ qkv_bias=qkv_bias,
+ qk_scale=qk_scale,
+ attn_drop=attn_drop,
+ proj_drop=drop,
+ )
- self.drop_path = DropPath(drop_path) if drop_path > 0. else Identity()
+ self.drop_path = DropPath(drop_path) if drop_path > 0.0 else Identity()
if isinstance(norm_layer, str):
self.norm2 = eval(norm_layer)(dim, epsilon=epsilon)
else:
self.norm2 = norm_layer(dim)
mlp_hidden_dim = int(dim * mlp_ratio)
self.mlp_ratio = mlp_ratio
- self.mlp = Mlp(in_features=dim,
- hidden_features=mlp_hidden_dim,
- act_layer=act_layer,
- drop=drop)
+ self.mlp = Mlp(
+ in_features=dim,
+ hidden_features=mlp_hidden_dim,
+ act_layer=act_layer,
+ drop=drop,
+ )
self.prenorm = prenorm
def forward(self, x):
@@ -172,60 +179,69 @@ def forward(self, x):
class ViT(nn.Layer):
def __init__(
- self,
- img_size=[32, 128],
- patch_size=[4,4],
- in_channels=3,
- embed_dim=384,
- depth=12,
- num_heads=6,
- mlp_ratio=4,
- qkv_bias=False,
- qk_scale=None,
- drop_rate=0.,
- attn_drop_rate=0.,
- drop_path_rate=0.1,
- norm_layer='nn.LayerNorm',
- epsilon=1e-6,
- act='nn.GELU',
- prenorm=False,
- **kwargs):
+ self,
+ img_size=[32, 128],
+ patch_size=[4, 4],
+ in_channels=3,
+ embed_dim=384,
+ depth=12,
+ num_heads=6,
+ mlp_ratio=4,
+ qkv_bias=False,
+ qk_scale=None,
+ drop_rate=0.0,
+ attn_drop_rate=0.0,
+ drop_path_rate=0.1,
+ norm_layer="nn.LayerNorm",
+ epsilon=1e-6,
+ act="nn.GELU",
+ prenorm=False,
+ **kwargs,
+ ):
super().__init__()
self.embed_dim = embed_dim
self.out_channels = embed_dim
self.prenorm = prenorm
- self.patch_embed = nn.Conv2D(in_channels, embed_dim, patch_size, patch_size, padding=(0, 0))
+ self.patch_embed = nn.Conv2D(
+ in_channels, embed_dim, patch_size, patch_size, padding=(0, 0)
+ )
self.pos_embed = self.create_parameter(
- shape=[1, 257, embed_dim], default_initializer=zeros_)
+ shape=[1, 257, embed_dim], default_initializer=zeros_
+ )
self.add_parameter("pos_embed", self.pos_embed)
self.pos_drop = nn.Dropout(p=drop_rate)
dpr = np.linspace(0, drop_path_rate, depth)
- self.blocks1 = nn.LayerList([
- Block(
- dim=embed_dim,
- num_heads=num_heads,
- mlp_ratio=mlp_ratio,
- qkv_bias=qkv_bias,
- qk_scale=qk_scale,
- drop=drop_rate,
- act_layer=eval(act),
- attn_drop=attn_drop_rate,
- drop_path=dpr[i],
- norm_layer=norm_layer,
- epsilon=epsilon,
- prenorm=prenorm) for i in range(depth)
- ])
+ self.blocks1 = nn.LayerList(
+ [
+ Block(
+ dim=embed_dim,
+ num_heads=num_heads,
+ mlp_ratio=mlp_ratio,
+ qkv_bias=qkv_bias,
+ qk_scale=qk_scale,
+ drop=drop_rate,
+ act_layer=eval(act),
+ attn_drop=attn_drop_rate,
+ drop_path=dpr[i],
+ norm_layer=norm_layer,
+ epsilon=epsilon,
+ prenorm=prenorm,
+ )
+ for i in range(depth)
+ ]
+ )
if not prenorm:
self.norm = eval(norm_layer)(embed_dim, epsilon=epsilon)
-
+
self.avg_pool = nn.AdaptiveAvgPool2D([1, 25])
self.last_conv = nn.Conv2D(
- in_channels=embed_dim,
- out_channels=self.out_channels,
- kernel_size=1,
- stride=1,
- padding=0,
- bias_attr=False)
+ in_channels=embed_dim,
+ out_channels=self.out_channels,
+ kernel_size=1,
+ stride=1,
+ padding=0,
+ bias_attr=False,
+ )
self.hardswish = nn.Hardswish()
self.dropout = nn.Dropout(p=0.1, mode="downscale_in_infer")
@@ -243,15 +259,14 @@ def _init_weights(self, m):
def forward(self, x):
x = self.patch_embed(x).flatten(2).transpose((0, 2, 1))
- x = x + self.pos_embed[:, 1:, :] #[:, :x.shape[1], :]
+ x = x + self.pos_embed[:, 1:, :] # [:, :x.shape[1], :]
x = self.pos_drop(x)
for blk in self.blocks1:
x = blk(x)
if not self.prenorm:
x = self.norm(x)
-
- x = self.avg_pool(x.transpose([0, 2, 1]).reshape(
- [0, self.embed_dim, -1, 25]))
+
+ x = self.avg_pool(x.transpose([0, 2, 1]).reshape([0, self.embed_dim, -1, 25]))
x = self.last_conv(x)
x = self.hardswish(x)
x = self.dropout(x)
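
The drop_path function reformatted above implements stochastic depth: one Bernoulli draw per sample gates the entire residual branch, and survivors are rescaled by 1/keep_prob so the expected output matches the identity path. A self-contained NumPy sketch of the same behavior (illustrative, not the Paddle implementation):

import numpy as np

def drop_path_np(x, drop_prob=0.0, training=False, rng=None):
    if drop_prob == 0.0 or not training:
        return x
    rng = rng or np.random.default_rng(0)
    keep_prob = 1.0 - drop_prob
    # One draw per sample, broadcast over the remaining dims; floor binarizes.
    shape = (x.shape[0],) + (1,) * (x.ndim - 1)
    gate = np.floor(keep_prob + rng.random(shape))
    # Rescale survivors so the expectation matches the undropped path.
    return x / keep_prob * gate

x = np.ones((4, 2), dtype="float32")
print(drop_path_np(x, drop_prob=0.5, training=True))
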
diff --git a/ppocr/modeling/backbones/rec_vit_parseq.py b/ppocr/modeling/backbones/rec_vit_parseq.py
index 2bb7592a7c..1ede97b179 100644
--- a/ppocr/modeling/backbones/rec_vit_parseq.py
+++ b/ppocr/modeling/backbones/rec_vit_parseq.py
@@ -25,25 +25,25 @@
from paddle.nn.initializer import TruncatedNormal, Constant, Normal
-trunc_normal_ = TruncatedNormal(std=.02)
+trunc_normal_ = TruncatedNormal(std=0.02)
normal_ = Normal
-zeros_ = Constant(value=0.)
-ones_ = Constant(value=1.)
+zeros_ = Constant(value=0.0)
+ones_ = Constant(value=1.0)
def to_2tuple(x):
return tuple([x] * 2)
-def drop_path(x, drop_prob=0., training=False):
+def drop_path(x, drop_prob=0.0, training=False):
"""Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...
See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ...
"""
- if drop_prob == 0. or not training:
+ if drop_prob == 0.0 or not training:
return x
keep_prob = paddle.to_tensor(1 - drop_prob, dtype=x.dtype)
- shape = (x.shape[0], ) + (1, ) * (x.ndim - 1)
+ shape = (x.shape[0],) + (1,) * (x.ndim - 1)
random_tensor = keep_prob + paddle.rand(shape).astype(x.dtype)
random_tensor = paddle.floor(random_tensor) # binarize
output = x.divide(keep_prob) * random_tensor
@@ -51,8 +51,7 @@ def drop_path(x, drop_prob=0., training=False):
class DropPath(nn.Layer):
- """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
- """
+ """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks)."""
def __init__(self, drop_prob=None):
super(DropPath, self).__init__()
@@ -71,12 +70,14 @@ def forward(self, input):
class Mlp(nn.Layer):
- def __init__(self,
- in_features,
- hidden_features=None,
- out_features=None,
- act_layer=nn.GELU,
- drop=0.):
+ def __init__(
+ self,
+ in_features,
+ hidden_features=None,
+ out_features=None,
+ act_layer=nn.GELU,
+ drop=0.0,
+ ):
super().__init__()
out_features = out_features or in_features
hidden_features = hidden_features or in_features
@@ -95,13 +96,15 @@ def forward(self, x):
class Attention(nn.Layer):
- def __init__(self,
- dim,
- num_heads=8,
- qkv_bias=False,
- qk_scale=None,
- attn_drop=0.,
- proj_drop=0.):
+ def __init__(
+ self,
+ dim,
+ num_heads=8,
+ qkv_bias=False,
+ qk_scale=None,
+ attn_drop=0.0,
+ proj_drop=0.0,
+ ):
super().__init__()
self.num_heads = num_heads
head_dim = dim // num_heads
@@ -115,8 +118,11 @@ def __init__(self,
def forward(self, x):
# B= x.shape[0]
N, C = x.shape[1:]
- qkv = self.qkv(x).reshape((-1, N, 3, self.num_heads, C //
- self.num_heads)).transpose((2, 0, 3, 1, 4))
+ qkv = (
+ self.qkv(x)
+ .reshape((-1, N, 3, self.num_heads, C // self.num_heads))
+ .transpose((2, 0, 3, 1, 4))
+ )
q, k, v = qkv[0], qkv[1], qkv[2]
attn = (q.matmul(k.transpose((0, 1, 3, 2)))) * self.scale
@@ -128,48 +134,52 @@ def forward(self, x):
x = self.proj_drop(x)
return x
+
class Block(nn.Layer):
- def __init__(self,
- dim,
- num_heads,
- mlp_ratio=4.,
- qkv_bias=False,
- qk_scale=None,
- drop=0.,
- attn_drop=0.,
- drop_path=0.,
- act_layer=nn.GELU,
- norm_layer='nn.LayerNorm',
- epsilon=1e-5):
+ def __init__(
+ self,
+ dim,
+ num_heads,
+ mlp_ratio=4.0,
+ qkv_bias=False,
+ qk_scale=None,
+ drop=0.0,
+ attn_drop=0.0,
+ drop_path=0.0,
+ act_layer=nn.GELU,
+ norm_layer="nn.LayerNorm",
+ epsilon=1e-5,
+ ):
super().__init__()
if isinstance(norm_layer, str):
self.norm1 = eval(norm_layer)(dim, epsilon=epsilon)
elif isinstance(norm_layer, Callable):
self.norm1 = norm_layer(dim)
else:
- raise TypeError(
- "The norm_layer must be str or paddle.nn.layer.Layer class")
+ raise TypeError("The norm_layer must be str or paddle.nn.layer.Layer class")
self.attn = Attention(
dim,
num_heads=num_heads,
qkv_bias=qkv_bias,
qk_scale=qk_scale,
attn_drop=attn_drop,
- proj_drop=drop)
+ proj_drop=drop,
+ )
# NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
- self.drop_path = DropPath(drop_path) if drop_path > 0. else Identity()
+ self.drop_path = DropPath(drop_path) if drop_path > 0.0 else Identity()
if isinstance(norm_layer, str):
self.norm2 = eval(norm_layer)(dim, epsilon=epsilon)
elif isinstance(norm_layer, Callable):
self.norm2 = norm_layer(dim)
else:
- raise TypeError(
- "The norm_layer must be str or paddle.nn.layer.Layer class")
+ raise TypeError("The norm_layer must be str or paddle.nn.layer.Layer class")
mlp_hidden_dim = int(dim * mlp_ratio)
- self.mlp = Mlp(in_features=dim,
- hidden_features=mlp_hidden_dim,
- act_layer=act_layer,
- drop=drop)
+ self.mlp = Mlp(
+ in_features=dim,
+ hidden_features=mlp_hidden_dim,
+ act_layer=act_layer,
+ drop=drop,
+ )
def forward(self, x):
x = x + self.drop_path(self.attn(self.norm1(x)))
@@ -178,8 +188,7 @@ def forward(self, x):
class PatchEmbed(nn.Layer):
- """ Image to Patch Embedding
- """
+ """Image to Patch Embedding"""
def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768):
super().__init__()
@@ -193,38 +202,41 @@ def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768):
self.num_patches = num_patches
self.proj = nn.Conv2D(
- in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
+ in_chans, embed_dim, kernel_size=patch_size, stride=patch_size
+ )
def forward(self, x):
B, C, H, W = x.shape
- assert H == self.img_size[0] and W == self.img_size[1], \
- f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
+ assert (
+ H == self.img_size[0] and W == self.img_size[1]
+ ), f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
x = self.proj(x).flatten(2).transpose((0, 2, 1))
return x
class VisionTransformer(nn.Layer):
- """ Vision Transformer with support for patch input
- """
-
- def __init__(self,
- img_size=224,
- patch_size=16,
- in_channels=3,
- class_num=1000,
- embed_dim=768,
- depth=12,
- num_heads=12,
- mlp_ratio=4,
- qkv_bias=False,
- qk_scale=None,
- drop_rate=0.,
- attn_drop_rate=0.,
- drop_path_rate=0.,
- norm_layer='nn.LayerNorm',
- epsilon=1e-5,
- **kwargs):
+ """Vision Transformer with support for patch input"""
+
+ def __init__(
+ self,
+ img_size=224,
+ patch_size=16,
+ in_channels=3,
+ class_num=1000,
+ embed_dim=768,
+ depth=12,
+ num_heads=12,
+ mlp_ratio=4,
+ qkv_bias=False,
+ qk_scale=None,
+ drop_rate=0.0,
+ attn_drop_rate=0.0,
+ drop_path_rate=0.0,
+ norm_layer="nn.LayerNorm",
+ epsilon=1e-5,
+ **kwargs,
+ ):
super().__init__()
self.class_num = class_num
@@ -234,37 +246,44 @@ def __init__(self,
img_size=img_size,
patch_size=patch_size,
in_chans=in_channels,
- embed_dim=embed_dim)
+ embed_dim=embed_dim,
+ )
num_patches = self.patch_embed.num_patches
- self.pos_embed = self.create_parameter(shape=(1, num_patches, embed_dim), default_initializer=zeros_)
+ self.pos_embed = self.create_parameter(
+ shape=(1, num_patches, embed_dim), default_initializer=zeros_
+ )
self.add_parameter("pos_embed", self.pos_embed)
self.cls_token = self.create_parameter(
- shape=(1, 1, embed_dim), default_initializer=zeros_)
+ shape=(1, 1, embed_dim), default_initializer=zeros_
+ )
self.add_parameter("cls_token", self.cls_token)
self.pos_drop = nn.Dropout(p=drop_rate)
dpr = np.linspace(0, drop_path_rate, depth)
- self.blocks = nn.LayerList([
- Block(
- dim=embed_dim,
- num_heads=num_heads,
- mlp_ratio=mlp_ratio,
- qkv_bias=qkv_bias,
- qk_scale=qk_scale,
- drop=drop_rate,
- attn_drop=attn_drop_rate,
- drop_path=dpr[i],
- norm_layer=norm_layer,
- epsilon=epsilon) for i in range(depth)
- ])
+ self.blocks = nn.LayerList(
+ [
+ Block(
+ dim=embed_dim,
+ num_heads=num_heads,
+ mlp_ratio=mlp_ratio,
+ qkv_bias=qkv_bias,
+ qk_scale=qk_scale,
+ drop=drop_rate,
+ attn_drop=attn_drop_rate,
+ drop_path=dpr[i],
+ norm_layer=norm_layer,
+ epsilon=epsilon,
+ )
+ for i in range(depth)
+ ]
+ )
self.norm = eval(norm_layer)(embed_dim, epsilon=epsilon)
# Classifier head
- self.head = nn.Linear(embed_dim,
- class_num) if class_num > 0 else Identity()
+ self.head = nn.Linear(embed_dim, class_num) if class_num > 0 else Identity()
trunc_normal_(self.pos_embed)
self.out_channels = embed_dim
@@ -296,9 +315,34 @@ def forward(self, x):
class ViTParseQ(VisionTransformer):
- def __init__(self, img_size=[224, 224], patch_size=[16, 16], in_channels=3, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4.0, qkv_bias=True, drop_rate=0.0, attn_drop_rate=0.0, drop_path_rate=0.0):
- super().__init__(img_size, patch_size, in_channels, embed_dim=embed_dim, depth=depth, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, drop_rate=drop_rate,
- attn_drop_rate=attn_drop_rate, drop_path_rate=drop_path_rate, class_num=0)
+ def __init__(
+ self,
+ img_size=[224, 224],
+ patch_size=[16, 16],
+ in_channels=3,
+ embed_dim=768,
+ depth=12,
+ num_heads=12,
+ mlp_ratio=4.0,
+ qkv_bias=True,
+ drop_rate=0.0,
+ attn_drop_rate=0.0,
+ drop_path_rate=0.0,
+ ):
+ super().__init__(
+ img_size,
+ patch_size,
+ in_channels,
+ embed_dim=embed_dim,
+ depth=depth,
+ num_heads=num_heads,
+ mlp_ratio=mlp_ratio,
+ qkv_bias=qkv_bias,
+ drop_rate=drop_rate,
+ attn_drop_rate=attn_drop_rate,
+ drop_path_rate=drop_path_rate,
+ class_num=0,
+ )
def forward(self, x):
return self.forward_features(x)
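
A quick shape trace for the PatchEmbed above, assuming the usual conv-as-patchifier arithmetic (the 224/16 sizes are the defaults visible in the diff; the trace itself is illustrative):

import numpy as np

# A 16x16-stride conv turns a 224x224 image into a 14x14 grid of patch tokens.
img, patch, embed_dim = (224, 224), (16, 16), 768
grid = (img[0] // patch[0], img[1] // patch[1])
num_patches = grid[0] * grid[1]

# proj(x).flatten(2).transpose((0, 2, 1)) maps (B, C, H', W') -> (B, N, C).
B = 2
feat = np.zeros((B, embed_dim, *grid), dtype="float32")
tokens = feat.reshape(B, embed_dim, -1).transpose(0, 2, 1)
print(num_patches, tokens.shape)  # 196 (2, 196, 768)
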
diff --git a/ppocr/modeling/backbones/rec_vitstr.py b/ppocr/modeling/backbones/rec_vitstr.py
index d5d7d5148a..564e579d26 100644
--- a/ppocr/modeling/backbones/rec_vitstr.py
+++ b/ppocr/modeling/backbones/rec_vitstr.py
@@ -19,72 +19,85 @@
import numpy as np
import paddle
import paddle.nn as nn
-from ppocr.modeling.backbones.rec_svtrnet import Block, PatchEmbed, zeros_, trunc_normal_, ones_
+from ppocr.modeling.backbones.rec_svtrnet import (
+ Block,
+ PatchEmbed,
+ zeros_,
+ trunc_normal_,
+ ones_,
+)
-scale_dim_heads = {'tiny': [192, 3], 'small': [384, 6], 'base': [768, 12]}
+scale_dim_heads = {"tiny": [192, 3], "small": [384, 6], "base": [768, 12]}
class ViTSTR(nn.Layer):
- def __init__(self,
- img_size=[224, 224],
- in_channels=1,
- scale='tiny',
- seqlen=27,
- patch_size=[16, 16],
- embed_dim=None,
- depth=12,
- num_heads=None,
- mlp_ratio=4,
- qkv_bias=True,
- qk_scale=None,
- drop_path_rate=0.,
- drop_rate=0.,
- attn_drop_rate=0.,
- norm_layer='nn.LayerNorm',
- act_layer='nn.GELU',
- epsilon=1e-6,
- out_channels=None,
- **kwargs):
+ def __init__(
+ self,
+ img_size=[224, 224],
+ in_channels=1,
+ scale="tiny",
+ seqlen=27,
+ patch_size=[16, 16],
+ embed_dim=None,
+ depth=12,
+ num_heads=None,
+ mlp_ratio=4,
+ qkv_bias=True,
+ qk_scale=None,
+ drop_path_rate=0.0,
+ drop_rate=0.0,
+ attn_drop_rate=0.0,
+ norm_layer="nn.LayerNorm",
+ act_layer="nn.GELU",
+ epsilon=1e-6,
+ out_channels=None,
+ **kwargs,
+ ):
super().__init__()
self.seqlen = seqlen
- embed_dim = embed_dim if embed_dim is not None else scale_dim_heads[
- scale][0]
- num_heads = num_heads if num_heads is not None else scale_dim_heads[
- scale][1]
+ embed_dim = embed_dim if embed_dim is not None else scale_dim_heads[scale][0]
+ num_heads = num_heads if num_heads is not None else scale_dim_heads[scale][1]
out_channels = out_channels if out_channels is not None else embed_dim
self.patch_embed = PatchEmbed(
img_size=img_size,
in_channels=in_channels,
embed_dim=embed_dim,
patch_size=patch_size,
- mode='linear')
+ mode="linear",
+ )
num_patches = self.patch_embed.num_patches
self.pos_embed = self.create_parameter(
- shape=[1, num_patches + 1, embed_dim], default_initializer=zeros_)
+ shape=[1, num_patches + 1, embed_dim], default_initializer=zeros_
+ )
self.add_parameter("pos_embed", self.pos_embed)
self.cls_token = self.create_parameter(
- shape=[1, 1, embed_dim], default_initializer=zeros_)
+ shape=[1, 1, embed_dim], default_initializer=zeros_
+ )
self.add_parameter("cls_token", self.cls_token)
self.pos_drop = nn.Dropout(p=drop_rate)
dpr = np.linspace(0, drop_path_rate, depth)
- self.blocks = nn.LayerList([
- Block(
- dim=embed_dim,
- num_heads=num_heads,
- mlp_ratio=mlp_ratio,
- qkv_bias=qkv_bias,
- qk_scale=qk_scale,
- drop=drop_rate,
- attn_drop=attn_drop_rate,
- drop_path=dpr[i],
- norm_layer=norm_layer,
- act_layer=eval(act_layer),
- epsilon=epsilon,
- prenorm=False) for i in range(depth)
- ])
+ self.blocks = nn.LayerList(
+ [
+ Block(
+ dim=embed_dim,
+ num_heads=num_heads,
+ mlp_ratio=mlp_ratio,
+ qkv_bias=qkv_bias,
+ qk_scale=qk_scale,
+ drop=drop_rate,
+ attn_drop=attn_drop_rate,
+ drop_path=dpr[i],
+ norm_layer=norm_layer,
+ act_layer=eval(act_layer),
+ epsilon=epsilon,
+ prenorm=False,
+ )
+ for i in range(depth)
+ ]
+ )
self.norm = eval(norm_layer)(embed_dim, epsilon=epsilon)
self.out_channels = out_channels
@@ -116,5 +129,5 @@ def forward_features(self, x):
def forward(self, x):
x = self.forward_features(x)
- x = x[:, :self.seqlen]
+ x = x[:, : self.seqlen]
return x.transpose([0, 2, 1]).unsqueeze(2)
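
The final slice-and-reshape in ViTSTR.forward above, traced with NumPy under the tiny-scale defaults visible in the diff (embed_dim 192, seqlen 27; the token count 197 assumes a 14x14 grid plus the cls token and is illustrative):

import numpy as np

B, N, C, seqlen = 2, 197, 192, 27
x = np.zeros((B, N, C), dtype="float32")
x = x[:, :seqlen]                        # keep the first seqlen tokens
x = x.transpose(0, 2, 1)[:, :, None, :]  # (B, C, seqlen) -> (B, C, 1, seqlen)
print(x.shape)  # (2, 192, 1, 27)
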
diff --git a/ppocr/modeling/backbones/table_master_resnet.py b/ppocr/modeling/backbones/table_master_resnet.py
index dacf5ed26e..6880ae317e 100644
--- a/ppocr/modeling/backbones/table_master_resnet.py
+++ b/ppocr/modeling/backbones/table_master_resnet.py
@@ -24,40 +24,33 @@
class BasicBlock(nn.Layer):
expansion = 1
- def __init__(self,
- inplanes,
- planes,
- stride=1,
- downsample=None,
- gcb_config=None):
+ def __init__(self, inplanes, planes, stride=1, downsample=None, gcb_config=None):
super(BasicBlock, self).__init__()
self.conv1 = nn.Conv2D(
- inplanes,
- planes,
- kernel_size=3,
- stride=stride,
- padding=1,
- bias_attr=False)
+ inplanes, planes, kernel_size=3, stride=stride, padding=1, bias_attr=False
+ )
self.bn1 = nn.BatchNorm2D(planes, momentum=0.9)
self.relu = nn.ReLU()
self.conv2 = nn.Conv2D(
- planes, planes, kernel_size=3, stride=1, padding=1, bias_attr=False)
+ planes, planes, kernel_size=3, stride=1, padding=1, bias_attr=False
+ )
self.bn2 = nn.BatchNorm2D(planes, momentum=0.9)
self.downsample = downsample
self.stride = stride
self.gcb_config = gcb_config
if self.gcb_config is not None:
- gcb_ratio = gcb_config['ratio']
- gcb_headers = gcb_config['headers']
- att_scale = gcb_config['att_scale']
- fusion_type = gcb_config['fusion_type']
+ gcb_ratio = gcb_config["ratio"]
+ gcb_headers = gcb_config["headers"]
+ att_scale = gcb_config["att_scale"]
+ fusion_type = gcb_config["fusion_type"]
self.context_block = MultiAspectGCAttention(
inplanes=planes,
ratio=gcb_ratio,
headers=gcb_headers,
att_scale=att_scale,
- fusion_type=fusion_type)
+ fusion_type=fusion_type,
+ )
def forward(self, x):
residual = x
@@ -82,7 +75,7 @@ def forward(self, x):
def get_gcb_config(gcb_config, layer):
- if gcb_config is None or not gcb_config['layers'][layer]:
+ if gcb_config is None or not gcb_config["layers"][layer]:
return None
else:
return gcb_config
@@ -95,17 +88,14 @@ def __init__(self, layers, in_channels=3, gcb_config=None):
super(TableResNetExtra, self).__init__()
self.inplanes = 128
self.conv1 = nn.Conv2D(
- in_channels,
- 64,
- kernel_size=3,
- stride=1,
- padding=1,
- bias_attr=False)
+ in_channels, 64, kernel_size=3, stride=1, padding=1, bias_attr=False
+ )
self.bn1 = nn.BatchNorm2D(64)
self.relu1 = nn.ReLU()
self.conv2 = nn.Conv2D(
- 64, 128, kernel_size=3, stride=1, padding=1, bias_attr=False)
+ 64, 128, kernel_size=3, stride=1, padding=1, bias_attr=False
+ )
self.bn2 = nn.BatchNorm2D(128)
self.relu2 = nn.ReLU()
@@ -116,10 +106,12 @@ def __init__(self, layers, in_channels=3, gcb_config=None):
256,
layers[0],
stride=1,
- gcb_config=get_gcb_config(gcb_config, 0))
+ gcb_config=get_gcb_config(gcb_config, 0),
+ )
self.conv3 = nn.Conv2D(
- 256, 256, kernel_size=3, stride=1, padding=1, bias_attr=False)
+ 256, 256, kernel_size=3, stride=1, padding=1, bias_attr=False
+ )
self.bn3 = nn.BatchNorm2D(256)
self.relu3 = nn.ReLU()
@@ -130,10 +122,12 @@ def __init__(self, layers, in_channels=3, gcb_config=None):
256,
layers[1],
stride=1,
- gcb_config=get_gcb_config(gcb_config, 1))
+ gcb_config=get_gcb_config(gcb_config, 1),
+ )
self.conv4 = nn.Conv2D(
- 256, 256, kernel_size=3, stride=1, padding=1, bias_attr=False)
+ 256, 256, kernel_size=3, stride=1, padding=1, bias_attr=False
+ )
self.bn4 = nn.BatchNorm2D(256)
self.relu4 = nn.ReLU()
@@ -144,10 +138,12 @@ def __init__(self, layers, in_channels=3, gcb_config=None):
512,
layers[2],
stride=1,
- gcb_config=get_gcb_config(gcb_config, 2))
+ gcb_config=get_gcb_config(gcb_config, 2),
+ )
self.conv5 = nn.Conv2D(
- 512, 512, kernel_size=3, stride=1, padding=1, bias_attr=False)
+ 512, 512, kernel_size=3, stride=1, padding=1, bias_attr=False
+ )
self.bn5 = nn.BatchNorm2D(512)
self.relu5 = nn.ReLU()
@@ -156,10 +152,12 @@ def __init__(self, layers, in_channels=3, gcb_config=None):
512,
layers[3],
stride=1,
- gcb_config=get_gcb_config(gcb_config, 3))
+ gcb_config=get_gcb_config(gcb_config, 3),
+ )
self.conv6 = nn.Conv2D(
- 512, 512, kernel_size=3, stride=1, padding=1, bias_attr=False)
+ 512, 512, kernel_size=3, stride=1, padding=1, bias_attr=False
+ )
self.bn6 = nn.BatchNorm2D(512)
self.relu6 = nn.ReLU()
@@ -174,17 +172,15 @@ def _make_layer(self, block, planes, blocks, stride=1, gcb_config=None):
planes * block.expansion,
kernel_size=1,
stride=stride,
- bias_attr=False),
- nn.BatchNorm2D(planes * block.expansion), )
+ bias_attr=False,
+ ),
+ nn.BatchNorm2D(planes * block.expansion),
+ )
layers = []
layers.append(
- block(
- self.inplanes,
- planes,
- stride,
- downsample,
- gcb_config=gcb_config))
+ block(self.inplanes, planes, stride, downsample, gcb_config=gcb_config)
+ )
self.inplanes = planes * block.expansion
for _ in range(1, blocks):
layers.append(block(self.inplanes, planes))
@@ -234,18 +230,22 @@ def forward(self, x):
class MultiAspectGCAttention(nn.Layer):
- def __init__(self,
- inplanes,
- ratio,
- headers,
- pooling_type='att',
- att_scale=False,
- fusion_type='channel_add'):
+ def __init__(
+ self,
+ inplanes,
+ ratio,
+ headers,
+ pooling_type="att",
+ att_scale=False,
+ fusion_type="channel_add",
+ ):
super(MultiAspectGCAttention, self).__init__()
- assert pooling_type in ['avg', 'att']
+ assert pooling_type in ["avg", "att"]
- assert fusion_type in ['channel_add', 'channel_mul', 'channel_concat']
- assert inplanes % headers == 0 and inplanes >= 8 # inplanes must be divided by headers evenly
+ assert fusion_type in ["channel_add", "channel_mul", "channel_concat"]
+ assert (
+ inplanes % headers == 0 and inplanes >= 8
+ ) # inplanes must be evenly divisible by headers
self.headers = headers
self.inplanes = inplanes
@@ -257,56 +257,50 @@ def __init__(self,
self.single_header_inplanes = int(inplanes / headers)
- if pooling_type == 'att':
- self.conv_mask = nn.Conv2D(
- self.single_header_inplanes, 1, kernel_size=1)
+ if pooling_type == "att":
+ self.conv_mask = nn.Conv2D(self.single_header_inplanes, 1, kernel_size=1)
self.softmax = nn.Softmax(axis=2)
else:
self.avg_pool = nn.AdaptiveAvgPool2D(1)
- if fusion_type == 'channel_add':
+ if fusion_type == "channel_add":
self.channel_add_conv = nn.Sequential(
- nn.Conv2D(
- self.inplanes, self.planes, kernel_size=1),
+ nn.Conv2D(self.inplanes, self.planes, kernel_size=1),
nn.LayerNorm([self.planes, 1, 1]),
nn.ReLU(),
- nn.Conv2D(
- self.planes, self.inplanes, kernel_size=1))
- elif fusion_type == 'channel_concat':
+ nn.Conv2D(self.planes, self.inplanes, kernel_size=1),
+ )
+ elif fusion_type == "channel_concat":
self.channel_concat_conv = nn.Sequential(
- nn.Conv2D(
- self.inplanes, self.planes, kernel_size=1),
+ nn.Conv2D(self.inplanes, self.planes, kernel_size=1),
nn.LayerNorm([self.planes, 1, 1]),
nn.ReLU(),
- nn.Conv2D(
- self.planes, self.inplanes, kernel_size=1))
+ nn.Conv2D(self.planes, self.inplanes, kernel_size=1),
+ )
# for concat
- self.cat_conv = nn.Conv2D(
- 2 * self.inplanes, self.inplanes, kernel_size=1)
- elif fusion_type == 'channel_mul':
+ self.cat_conv = nn.Conv2D(2 * self.inplanes, self.inplanes, kernel_size=1)
+ elif fusion_type == "channel_mul":
self.channel_mul_conv = nn.Sequential(
- nn.Conv2D(
- self.inplanes, self.planes, kernel_size=1),
+ nn.Conv2D(self.inplanes, self.planes, kernel_size=1),
nn.LayerNorm([self.planes, 1, 1]),
nn.ReLU(),
- nn.Conv2D(
- self.planes, self.inplanes, kernel_size=1))
+ nn.Conv2D(self.planes, self.inplanes, kernel_size=1),
+ )
def spatial_pool(self, x):
batch, channel, height, width = x.shape
- if self.pooling_type == 'att':
+ if self.pooling_type == "att":
# [N*headers, C', H , W] C = headers * C'
- x = x.reshape([
- batch * self.headers, self.single_header_inplanes, height, width
- ])
+ x = x.reshape(
+ [batch * self.headers, self.single_header_inplanes, height, width]
+ )
input_x = x
# [N*headers, C', H * W] C = headers * C'
# input_x = input_x.view(batch, channel, height * width)
- input_x = input_x.reshape([
- batch * self.headers, self.single_header_inplanes,
- height * width
- ])
+ input_x = input_x.reshape(
+ [batch * self.headers, self.single_header_inplanes, height * width]
+ )
# [N*headers, 1, C', H * W]
input_x = input_x.unsqueeze(1)
@@ -314,12 +308,12 @@ def spatial_pool(self, x):
context_mask = self.conv_mask(x)
# [N*headers, 1, H * W]
context_mask = context_mask.reshape(
- [batch * self.headers, 1, height * width])
+ [batch * self.headers, 1, height * width]
+ )
# scale variance
if self.att_scale and self.headers > 1:
- context_mask = context_mask / paddle.sqrt(
- self.single_header_inplanes)
+ context_mask = context_mask / paddle.sqrt(self.single_header_inplanes)
# [N*headers, 1, H * W]
context_mask = self.softmax(context_mask)
@@ -331,7 +325,8 @@ def spatial_pool(self, x):
# [N, headers * C', 1, 1]
context = context.reshape(
- [batch, self.headers * self.single_header_inplanes, 1, 1])
+ [batch, self.headers * self.single_header_inplanes, 1, 1]
+ )
else:
# [N, C, 1, 1]
context = self.avg_pool(x)
@@ -344,11 +339,11 @@ def forward(self, x):
out = x
- if self.fusion_type == 'channel_mul':
+ if self.fusion_type == "channel_mul":
# [N, C, 1, 1]
channel_mul_term = F.sigmoid(self.channel_mul_conv(context))
out = out * channel_mul_term
- elif self.fusion_type == 'channel_add':
+ elif self.fusion_type == "channel_add":
# [N, C, 1, 1]
channel_add_term = self.channel_add_conv(context)
out = out + channel_add_term
@@ -361,7 +356,8 @@ def forward(self, x):
N, C2, H, W = out.shape
out = paddle.concat(
- [out, channel_concat_term.expand([-1, -1, H, W])], axis=1)
+ [out, channel_concat_term.expand([-1, -1, H, W])], axis=1
+ )
out = self.cat_conv(out)
out = F.layer_norm(out, [self.inplanes, H, W])
out = F.relu(out)
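
The reshapes in MultiAspectGCAttention.spatial_pool above reduce to one batched matmul per header. A NumPy sketch of the shape flow for pooling_type="att", with illustrative sizes (the names mirror the diff but nothing here is part of the patch):

import numpy as np

N, C, H, W, headers = 2, 64, 12, 12, 4
Cp = C // headers                                              # single_header_inplanes
x = np.random.rand(N * headers, Cp, H * W).astype("float32")
logits = np.random.rand(N * headers, 1, H * W).astype("float32")
attn = np.exp(logits) / np.exp(logits).sum(-1, keepdims=True)  # softmax over H*W
# [N*h, 1, C', HW] @ [N*h, 1, HW, 1] -> [N*h, 1, C', 1], then regroup the headers.
context = np.matmul(x[:, None], attn[..., None])
context = context.reshape(N, headers * Cp, 1, 1)
print(context.shape)  # (2, 64, 1, 1)
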
diff --git a/ppocr/modeling/backbones/vqa_layoutlm.py b/ppocr/modeling/backbones/vqa_layoutlm.py
index 4357b56645..cbcf804a71 100644
--- a/ppocr/modeling/backbones/vqa_layoutlm.py
+++ b/ppocr/modeling/backbones/vqa_layoutlm.py
@@ -19,9 +19,17 @@
import os
from paddle import nn
-from paddlenlp.transformers import LayoutXLMModel, LayoutXLMForTokenClassification, LayoutXLMForRelationExtraction
+from paddlenlp.transformers import (
+ LayoutXLMModel,
+ LayoutXLMForTokenClassification,
+ LayoutXLMForRelationExtraction,
+)
from paddlenlp.transformers import LayoutLMModel, LayoutLMForTokenClassification
-from paddlenlp.transformers import LayoutLMv2Model, LayoutLMv2ForTokenClassification, LayoutLMv2ForRelationExtraction
+from paddlenlp.transformers import (
+ LayoutLMv2Model,
+ LayoutLMv2ForTokenClassification,
+ LayoutLMv2ForRelationExtraction,
+)
from paddlenlp.transformers import AutoModel
__all__ = ["LayoutXLMForSer", "LayoutLMForSer"]
@@ -42,34 +50,37 @@
class NLPBaseModel(nn.Layer):
- def __init__(self,
- base_model_class,
- model_class,
- mode="base",
- type="ser",
- pretrained=True,
- checkpoints=None,
- **kwargs):
+ def __init__(
+ self,
+ base_model_class,
+ model_class,
+ mode="base",
+ type="ser",
+ pretrained=True,
+ checkpoints=None,
+ **kwargs,
+ ):
super(NLPBaseModel, self).__init__()
if checkpoints is not None: # load the trained model
self.model = model_class.from_pretrained(checkpoints)
else: # load the pretrained-model
pretrained_model_name = pretrained_model_dict[base_model_class][mode]
if type == "ser":
- self.model = model_class.from_pretrained(pretrained_model_name, num_classes=kwargs["num_classes"], dropout=0)
+ self.model = model_class.from_pretrained(
+ pretrained_model_name, num_classes=kwargs["num_classes"], dropout=0
+ )
else:
- self.model = model_class.from_pretrained(pretrained_model_name, dropout=0)
+ self.model = model_class.from_pretrained(
+ pretrained_model_name, dropout=0
+ )
self.out_channels = 1
self.use_visual_backbone = True
class LayoutLMForSer(NLPBaseModel):
- def __init__(self,
- num_classes,
- pretrained=True,
- checkpoints=None,
- mode="base",
- **kwargs):
+ def __init__(
+ self, num_classes, pretrained=True, checkpoints=None, mode="base", **kwargs
+ ):
super(LayoutLMForSer, self).__init__(
LayoutLMModel,
LayoutLMForTokenClassification,
@@ -77,7 +88,8 @@ def __init__(self,
"ser",
pretrained,
checkpoints,
- num_classes=num_classes, )
+ num_classes=num_classes,
+ )
self.use_visual_backbone = False
def forward(self, x):
@@ -87,17 +99,15 @@ def forward(self, x):
attention_mask=x[2],
token_type_ids=x[3],
position_ids=None,
- output_hidden_states=False)
+ output_hidden_states=False,
+ )
return x
class LayoutLMv2ForSer(NLPBaseModel):
- def __init__(self,
- num_classes,
- pretrained=True,
- checkpoints=None,
- mode="base",
- **kwargs):
+ def __init__(
+ self, num_classes, pretrained=True, checkpoints=None, mode="base", **kwargs
+ ):
super(LayoutLMv2ForSer, self).__init__(
LayoutLMv2Model,
LayoutLMv2ForTokenClassification,
@@ -105,9 +115,12 @@ def __init__(self,
"ser",
pretrained,
checkpoints,
- num_classes=num_classes)
- if hasattr(self.model.layoutlmv2, "use_visual_backbone"
- ) and self.model.layoutlmv2.use_visual_backbone is False:
+ num_classes=num_classes,
+ )
+ if (
+ hasattr(self.model.layoutlmv2, "use_visual_backbone")
+ and self.model.layoutlmv2.use_visual_backbone is False
+ ):
self.use_visual_backbone = False
def forward(self, x):
@@ -123,7 +136,8 @@ def forward(self, x):
image=image,
position_ids=None,
head_mask=None,
- labels=None)
+ labels=None,
+ )
if self.training:
res = {"backbone_out": x[0]}
res.update(x[1])
@@ -133,12 +147,9 @@ def forward(self, x):
class LayoutXLMForSer(NLPBaseModel):
- def __init__(self,
- num_classes,
- pretrained=True,
- checkpoints=None,
- mode="base",
- **kwargs):
+ def __init__(
+ self, num_classes, pretrained=True, checkpoints=None, mode="base", **kwargs
+ ):
super(LayoutXLMForSer, self).__init__(
LayoutXLMModel,
LayoutXLMForTokenClassification,
@@ -146,9 +157,12 @@ def __init__(self,
"ser",
pretrained,
checkpoints,
- num_classes=num_classes)
- if hasattr(self.model.layoutxlm, "use_visual_backbone"
- ) and self.model.layoutxlm.use_visual_backbone is False:
+ num_classes=num_classes,
+ )
+ if (
+ hasattr(self.model.layoutxlm, "use_visual_backbone")
+ and self.model.layoutxlm.use_visual_backbone is False
+ ):
self.use_visual_backbone = False
def forward(self, x):
@@ -164,7 +178,8 @@ def forward(self, x):
image=image,
position_ids=None,
head_mask=None,
- labels=None)
+ labels=None,
+ )
if self.training:
res = {"backbone_out": x[0]}
res.update(x[1])
@@ -174,13 +189,19 @@ def forward(self, x):
class LayoutLMv2ForRe(NLPBaseModel):
- def __init__(self, pretrained=True, checkpoints=None, mode="base",
- **kwargs):
+ def __init__(self, pretrained=True, checkpoints=None, mode="base", **kwargs):
super(LayoutLMv2ForRe, self).__init__(
- LayoutLMv2Model, LayoutLMv2ForRelationExtraction, mode, "re",
- pretrained, checkpoints)
- if hasattr(self.model.layoutlmv2, "use_visual_backbone"
- ) and self.model.layoutlmv2.use_visual_backbone is False:
+ LayoutLMv2Model,
+ LayoutLMv2ForRelationExtraction,
+ mode,
+ "re",
+ pretrained,
+ checkpoints,
+ )
+ if (
+ hasattr(self.model.layoutlmv2, "use_visual_backbone")
+ and self.model.layoutlmv2.use_visual_backbone is False
+ ):
self.use_visual_backbone = False
def forward(self, x):
@@ -194,18 +215,25 @@ def forward(self, x):
head_mask=None,
labels=None,
entities=x[5],
- relations=x[6])
+ relations=x[6],
+ )
return x
class LayoutXLMForRe(NLPBaseModel):
- def __init__(self, pretrained=True, checkpoints=None, mode="base",
- **kwargs):
+ def __init__(self, pretrained=True, checkpoints=None, mode="base", **kwargs):
super(LayoutXLMForRe, self).__init__(
- LayoutXLMModel, LayoutXLMForRelationExtraction, mode, "re",
- pretrained, checkpoints)
- if hasattr(self.model.layoutxlm, "use_visual_backbone"
- ) and self.model.layoutxlm.use_visual_backbone is False:
+ LayoutXLMModel,
+ LayoutXLMForRelationExtraction,
+ mode,
+ "re",
+ pretrained,
+ checkpoints,
+ )
+ if (
+ hasattr(self.model.layoutxlm, "use_visual_backbone")
+ and self.model.layoutxlm.use_visual_backbone is False
+ ):
self.use_visual_backbone = False
def forward(self, x):
@@ -227,5 +255,6 @@ def forward(self, x):
head_mask=None,
labels=None,
entities=entities,
- relations=relations)
+ relations=relations,
+ )
return x
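
Every Ser/Re wrapper above repeats the same defensive toggle: default use_visual_backbone to True and flip it only when the wrapped model exposes the attribute and it is exactly False. A minimal sketch of that pattern in isolation (InnerModel and Wrapper are placeholders, not PaddleOCR classes):

class InnerModel:
    use_visual_backbone = False

class Wrapper:
    def __init__(self, inner):
        self.use_visual_backbone = True
        if (
            hasattr(inner, "use_visual_backbone")
            and inner.use_visual_backbone is False
        ):
            self.use_visual_backbone = False

print(Wrapper(InnerModel()).use_visual_backbone)  # False
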
diff --git a/ppocr/modeling/heads/__init__.py b/ppocr/modeling/heads/__init__.py
index 1ff6040bf5..f9a9528eb0 100755
--- a/ppocr/modeling/heads/__init__.py
+++ b/ppocr/modeling/heads/__init__.py
@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-__all__ = ['build_head']
+__all__ = ["build_head"]
def build_head(config):
@@ -24,6 +24,7 @@ def build_head(config):
from .det_fce_head import FCEHead
from .e2e_pg_head import PGHead
from .det_ct_head import CT_Head
+
# rec head
from .rec_ctc_head import CTCHead
from .rec_att_head import AttentionHead
@@ -46,30 +47,56 @@ def build_head(config):
# cls head
from .cls_head import ClsHead
- #kie head
+ # kie head
from .kie_sdmgr_head import SDMGRHead
from .table_att_head import TableAttentionHead, SLAHead
from .table_master_head import TableMasterHead
support_dict = [
- 'DBHead', 'PSEHead', 'FCEHead', 'EASTHead', 'SASTHead', 'CTCHead',
- 'ClsHead', 'AttentionHead', 'SRNHead', 'PGHead', 'Transformer',
- 'TableAttentionHead', 'SARHead', 'AsterHead', 'SDMGRHead', 'PRENHead',
- 'MultiHead', 'ABINetHead', 'TableMasterHead', 'SPINAttentionHead',
- 'VLHead', 'SLAHead', 'RobustScannerHead', 'CT_Head', 'RFLHead',
- 'DRRGHead', 'CANHead', 'SATRNHead', 'PFHeadLocal', 'ParseQHead',
- 'CPPDHead'
+ "DBHead",
+ "PSEHead",
+ "FCEHead",
+ "EASTHead",
+ "SASTHead",
+ "CTCHead",
+ "ClsHead",
+ "AttentionHead",
+ "SRNHead",
+ "PGHead",
+ "Transformer",
+ "TableAttentionHead",
+ "SARHead",
+ "AsterHead",
+ "SDMGRHead",
+ "PRENHead",
+ "MultiHead",
+ "ABINetHead",
+ "TableMasterHead",
+ "SPINAttentionHead",
+ "VLHead",
+ "SLAHead",
+ "RobustScannerHead",
+ "CT_Head",
+ "RFLHead",
+ "DRRGHead",
+ "CANHead",
+ "SATRNHead",
+ "PFHeadLocal",
+ "ParseQHead",
+ "CPPDHead",
]
- if config['name'] == 'DRRGHead':
+ if config["name"] == "DRRGHead":
from .det_drrg_head import DRRGHead
- support_dict.append('DRRGHead')
- #table head
+ support_dict.append("DRRGHead")
+
+ # table head
- module_name = config.pop('name')
- assert module_name in support_dict, Exception('head only support {}'.format(
- support_dict))
+ module_name = config.pop("name")
+ assert module_name in support_dict, Exception(
+ "head only support {}".format(support_dict)
+ )
module_class = eval(module_name)(**config)
return module_class
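
build_head above resolves config["name"] with eval over locally imported classes. A sketch of the same contract using an explicit registry instead, under the assumption that only the dispatch pattern matters (DemoHead is a placeholder, not a PaddleOCR head):

def build(config, registry):
    name = config.pop("name")
    assert name in registry, "head only supports {}".format(sorted(registry))
    # Remaining config keys become constructor kwargs, as in build_head.
    return registry[name](**config)

class DemoHead:
    def __init__(self, in_channels):
        self.in_channels = in_channels

head = build({"name": "DemoHead", "in_channels": 96}, {"DemoHead": DemoHead})
print(head.in_channels)  # 96
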
diff --git a/ppocr/modeling/heads/cls_head.py b/ppocr/modeling/heads/cls_head.py
index 91bfa615a8..867e960182 100644
--- a/ppocr/modeling/heads/cls_head.py
+++ b/ppocr/modeling/heads/cls_head.py
@@ -39,9 +39,10 @@ def __init__(self, in_channels, class_dim, **kwargs):
in_channels,
class_dim,
weight_attr=ParamAttr(
- name="fc_0.w_0",
- initializer=nn.initializer.Uniform(-stdv, stdv)),
- bias_attr=ParamAttr(name="fc_0.b_0"), )
+ name="fc_0.w_0", initializer=nn.initializer.Uniform(-stdv, stdv)
+ ),
+ bias_attr=ParamAttr(name="fc_0.b_0"),
+ )
def forward(self, x, targets=None):
x = self.pool(x)
diff --git a/ppocr/modeling/heads/det_ct_head.py b/ppocr/modeling/heads/det_ct_head.py
index 08e6719e8f..cd050fc601 100644
--- a/ppocr/modeling/heads/det_ct_head.py
+++ b/ppocr/modeling/heads/det_ct_head.py
@@ -24,37 +24,37 @@
import math
from paddle.nn.initializer import TruncatedNormal, Constant, Normal
-ones_ = Constant(value=1.)
-zeros_ = Constant(value=0.)
+
+ones_ = Constant(value=1.0)
+zeros_ = Constant(value=0.0)
class CT_Head(nn.Layer):
- def __init__(self,
- in_channels,
- hidden_dim,
- num_classes,
- loss_kernel=None,
- loss_loc=None):
+ def __init__(
+ self, in_channels, hidden_dim, num_classes, loss_kernel=None, loss_loc=None
+ ):
super(CT_Head, self).__init__()
self.conv1 = nn.Conv2D(
- in_channels, hidden_dim, kernel_size=3, stride=1, padding=1)
+ in_channels, hidden_dim, kernel_size=3, stride=1, padding=1
+ )
self.bn1 = nn.BatchNorm2D(hidden_dim)
self.relu1 = nn.ReLU()
self.conv2 = nn.Conv2D(
- hidden_dim, num_classes, kernel_size=1, stride=1, padding=0)
+ hidden_dim, num_classes, kernel_size=1, stride=1, padding=0
+ )
for m in self.sublayers():
if isinstance(m, nn.Conv2D):
n = m._kernel_size[0] * m._kernel_size[1] * m._out_channels
- normal_ = Normal(mean=0.0, std=math.sqrt(2. / n))
+ normal_ = Normal(mean=0.0, std=math.sqrt(2.0 / n))
normal_(m.weight)
elif isinstance(m, nn.BatchNorm2D):
zeros_(m.bias)
ones_(m.weight)
def _upsample(self, x, scale=1):
- return F.upsample(x, scale_factor=scale, mode='bilinear')
+ return F.upsample(x, scale_factor=scale, mode="bilinear")
def forward(self, f, targets=None):
out = self.conv1(f)
@@ -63,7 +63,7 @@ def forward(self, f, targets=None):
if self.training:
out = self._upsample(out, scale=4)
- return {'maps': out}
+ return {"maps": out}
else:
score = F.sigmoid(out[:, 0, :, :])
- return {'maps': out, 'score': score}
+ return {"maps": out, "score": score}
diff --git a/ppocr/modeling/heads/det_db_head.py b/ppocr/modeling/heads/det_db_head.py
index 8db14d7f6f..8f41a25b01 100644
--- a/ppocr/modeling/heads/det_db_head.py
+++ b/ppocr/modeling/heads/det_db_head.py
@@ -41,38 +41,37 @@ def __init__(self, in_channels, kernel_list=[3, 2, 2], **kwargs):
kernel_size=kernel_list[0],
padding=int(kernel_list[0] // 2),
weight_attr=ParamAttr(),
- bias_attr=False)
+ bias_attr=False,
+ )
self.conv_bn1 = nn.BatchNorm(
num_channels=in_channels // 4,
- param_attr=ParamAttr(
- initializer=paddle.nn.initializer.Constant(value=1.0)),
- bias_attr=ParamAttr(
- initializer=paddle.nn.initializer.Constant(value=1e-4)),
- act='relu')
+ param_attr=ParamAttr(initializer=paddle.nn.initializer.Constant(value=1.0)),
+ bias_attr=ParamAttr(initializer=paddle.nn.initializer.Constant(value=1e-4)),
+ act="relu",
+ )
self.conv2 = nn.Conv2DTranspose(
in_channels=in_channels // 4,
out_channels=in_channels // 4,
kernel_size=kernel_list[1],
stride=2,
- weight_attr=ParamAttr(
- initializer=paddle.nn.initializer.KaimingUniform()),
- bias_attr=get_bias_attr(in_channels // 4))
+ weight_attr=ParamAttr(initializer=paddle.nn.initializer.KaimingUniform()),
+ bias_attr=get_bias_attr(in_channels // 4),
+ )
self.conv_bn2 = nn.BatchNorm(
num_channels=in_channels // 4,
- param_attr=ParamAttr(
- initializer=paddle.nn.initializer.Constant(value=1.0)),
- bias_attr=ParamAttr(
- initializer=paddle.nn.initializer.Constant(value=1e-4)),
- act="relu")
+ param_attr=ParamAttr(initializer=paddle.nn.initializer.Constant(value=1.0)),
+ bias_attr=ParamAttr(initializer=paddle.nn.initializer.Constant(value=1e-4)),
+ act="relu",
+ )
self.conv3 = nn.Conv2DTranspose(
in_channels=in_channels // 4,
out_channels=1,
kernel_size=kernel_list[2],
stride=2,
- weight_attr=ParamAttr(
- initializer=paddle.nn.initializer.KaimingUniform()),
- bias_attr=get_bias_attr(in_channels // 4), )
+ weight_attr=ParamAttr(initializer=paddle.nn.initializer.KaimingUniform()),
+ bias_attr=get_bias_attr(in_channels // 4),
+ )
def forward(self, x, return_f=False):
x = self.conv1(x)
@@ -108,18 +107,18 @@ def step_function(self, x, y):
def forward(self, x, targets=None):
shrink_maps = self.binarize(x)
if not self.training:
- return {'maps': shrink_maps}
+ return {"maps": shrink_maps}
threshold_maps = self.thresh(x)
binary_maps = self.step_function(shrink_maps, threshold_maps)
y = paddle.concat([shrink_maps, threshold_maps, binary_maps], axis=1)
- return {'maps': y}
+ return {"maps": y}
class LocalModule(nn.Layer):
def __init__(self, in_c, mid_c, use_distance=True):
super(self.__class__, self).__init__()
- self.last_3 = ConvBNLayer(in_c + 1, mid_c, 3, 1, 1, act='relu')
+ self.last_3 = ConvBNLayer(in_c + 1, mid_c, 3, 1, 1, act="relu")
self.last_1 = nn.Conv2D(mid_c, 1, 1, 1, 0)
def forward(self, x, init_map, distance_map):
@@ -130,14 +129,14 @@ def forward(self, x, init_map, distance_map):
class PFHeadLocal(DBHead):
- def __init__(self, in_channels, k=50, mode='small', **kwargs):
+ def __init__(self, in_channels, k=50, mode="small", **kwargs):
super(PFHeadLocal, self).__init__(in_channels, k, **kwargs)
self.mode = mode
self.up_conv = nn.Upsample(scale_factor=2, mode="nearest", align_mode=1)
- if self.mode == 'large':
+ if self.mode == "large":
self.cbn_layer = LocalModule(in_channels // 4, in_channels // 4)
- elif self.mode == 'small':
+ elif self.mode == "small":
self.cbn_layer = LocalModule(in_channels // 4, in_channels // 8)
def forward(self, x, targets=None):
@@ -146,9 +145,9 @@ def forward(self, x, targets=None):
cbn_maps = self.cbn_layer(self.up_conv(f), shrink_maps, None)
cbn_maps = F.sigmoid(cbn_maps)
if not self.training:
- return {'maps': 0.5 * (base_maps + cbn_maps), 'cbn_maps': cbn_maps}
+ return {"maps": 0.5 * (base_maps + cbn_maps), "cbn_maps": cbn_maps}
threshold_maps = self.thresh(x)
binary_maps = self.step_function(shrink_maps, threshold_maps)
y = paddle.concat([cbn_maps, threshold_maps, binary_maps], axis=1)
- return {'maps': y, 'distance_maps': cbn_maps, 'cbn_maps': binary_maps}
+ return {"maps": y, "distance_maps": cbn_maps, "cbn_maps": binary_maps}
diff --git a/ppocr/modeling/heads/det_drrg_head.py b/ppocr/modeling/heads/det_drrg_head.py
index 3aee1f8cb7..24dc12fef5 100644
--- a/ppocr/modeling/heads/det_drrg_head.py
+++ b/ppocr/modeling/heads/det_drrg_head.py
@@ -32,24 +32,26 @@
class DRRGHead(nn.Layer):
- def __init__(self,
- in_channels,
- k_at_hops=(8, 4),
- num_adjacent_linkages=3,
- node_geo_feat_len=120,
- pooling_scale=1.0,
- pooling_output_size=(4, 3),
- nms_thr=0.3,
- min_width=8.0,
- max_width=24.0,
- comp_shrink_ratio=1.03,
- comp_ratio=0.4,
- comp_score_thr=0.3,
- text_region_thr=0.2,
- center_region_thr=0.2,
- center_region_area_thr=50,
- local_graph_thr=0.7,
- **kwargs):
+ def __init__(
+ self,
+ in_channels,
+ k_at_hops=(8, 4),
+ num_adjacent_linkages=3,
+ node_geo_feat_len=120,
+ pooling_scale=1.0,
+ pooling_output_size=(4, 3),
+ nms_thr=0.3,
+ min_width=8.0,
+ max_width=24.0,
+ comp_shrink_ratio=1.03,
+ comp_ratio=0.4,
+ comp_score_thr=0.3,
+ text_region_thr=0.2,
+ center_region_thr=0.2,
+ center_region_area_thr=50,
+ local_graph_thr=0.7,
+ **kwargs
+ ):
super().__init__()
assert isinstance(in_channels, int)
@@ -93,22 +95,39 @@ def __init__(self,
out_channels=self.out_channels,
kernel_size=1,
stride=1,
- padding=0)
+ padding=0,
+ )
self.graph_train = LocalGraphs(
- self.k_at_hops, self.num_adjacent_linkages, self.node_geo_feat_len,
- self.pooling_scale, self.pooling_output_size, self.local_graph_thr)
+ self.k_at_hops,
+ self.num_adjacent_linkages,
+ self.node_geo_feat_len,
+ self.pooling_scale,
+ self.pooling_output_size,
+ self.local_graph_thr,
+ )
self.graph_test = ProposalLocalGraphs(
- self.k_at_hops, self.num_adjacent_linkages, self.node_geo_feat_len,
- self.pooling_scale, self.pooling_output_size, self.nms_thr,
- self.min_width, self.max_width, self.comp_shrink_ratio,
- self.comp_ratio, self.comp_score_thr, self.text_region_thr,
- self.center_region_thr, self.center_region_area_thr)
+ self.k_at_hops,
+ self.num_adjacent_linkages,
+ self.node_geo_feat_len,
+ self.pooling_scale,
+ self.pooling_output_size,
+ self.nms_thr,
+ self.min_width,
+ self.max_width,
+ self.comp_shrink_ratio,
+ self.comp_ratio,
+ self.comp_score_thr,
+ self.text_region_thr,
+ self.center_region_thr,
+ self.center_region_area_thr,
+ )
pool_w, pool_h = self.pooling_output_size
node_feat_len = (pool_w * pool_h) * (
- self.in_channels + self.out_channels) + self.node_geo_feat_len
+ self.in_channels + self.out_channels
+ ) + self.node_geo_feat_len
self.gcn = GCN(node_feat_len)
def forward(self, inputs, targets=None):
@@ -134,7 +153,8 @@ def forward(self, inputs, targets=None):
pred_maps = self.out_conv(inputs)
feat_maps = paddle.concat([inputs, pred_maps], axis=1)
node_feats, adjacent_matrices, knn_inds, gt_labels = self.graph_train(
- feat_maps, np.stack(gt_comp_attribs))
+ feat_maps, np.stack(gt_comp_attribs)
+ )
gcn_pred = self.gcn(node_feats, adjacent_matrices, knn_inds)
@@ -164,13 +184,17 @@ def single_test(self, feat_maps):
none_flag, graph_data = self.graph_test(pred_maps, feat_maps)
- (local_graphs_node_feat, adjacent_matrices, pivots_knn_inds,
- pivot_local_graphs, text_comps) = graph_data
+ (
+ local_graphs_node_feat,
+ adjacent_matrices,
+ pivots_knn_inds,
+ pivot_local_graphs,
+ text_comps,
+ ) = graph_data
if none_flag:
return None, None, None
- gcn_pred = self.gcn(local_graphs_node_feat, adjacent_matrices,
- pivots_knn_inds)
+ gcn_pred = self.gcn(local_graphs_node_feat, adjacent_matrices, pivots_knn_inds)
pred_labels = F.softmax(gcn_pred, axis=1)
edges = []
@@ -182,8 +206,9 @@ def single_test(self, feat_maps):
for k_ind, neighbor_ind in enumerate(pivots_knn_inds[pivot_ind]):
neighbor = pivot_local_graph[neighbor_ind.item()]
edges.append([pivot, neighbor])
- scores.append(pred_labels[pivot_ind * pivots_knn_inds.shape[1] +
- k_ind, 1].item())
+ scores.append(
+ pred_labels[pivot_ind * pivots_knn_inds.shape[1] + k_ind, 1].item()
+ )
edges = np.asarray(edges)
scores = np.asarray(scores)
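
The reflowed single_test above flattens (pivot, k-NN) pairs into one prediction axis, so pred_labels[pivot_ind * k + k_ind, 1] is the link probability for the k-th neighbor of a pivot. A numpy mirror of that bookkeeping (array shapes are assumptions for the sketch):

import numpy as np

def collect_edge_scores(pred_labels, pivots_knn_inds, pivot_local_graphs):
    """Mirror of the single_test edge collection above.

    pred_labels: (num_pivots * k, 2) softmax scores; column 1 = link prob.
    pivots_knn_inds: (num_pivots, k) indices into each pivot's local graph.
    pivot_local_graphs: list of per-pivot component id lists, pivot first.
    """
    k = pivots_knn_inds.shape[1]
    edges, scores = [], []
    for pivot_ind, pivot_local_graph in enumerate(pivot_local_graphs):
        pivot = pivot_local_graph[0]
        for k_ind, neighbor_ind in enumerate(pivots_knn_inds[pivot_ind]):
            edges.append([pivot, pivot_local_graph[neighbor_ind]])
            scores.append(pred_labels[pivot_ind * k + k_ind, 1])
    return np.asarray(edges), np.asarray(scores)
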
diff --git a/ppocr/modeling/heads/det_east_head.py b/ppocr/modeling/heads/det_east_head.py
index 004eb5d7bb..c0ad6e8c3d 100644
--- a/ppocr/modeling/heads/det_east_head.py
+++ b/ppocr/modeling/heads/det_east_head.py
@@ -24,16 +24,18 @@
class ConvBNLayer(nn.Layer):
- def __init__(self,
- in_channels,
- out_channels,
- kernel_size,
- stride,
- padding,
- groups=1,
- if_act=True,
- act=None,
- name=None):
+ def __init__(
+ self,
+ in_channels,
+ out_channels,
+ kernel_size,
+ stride,
+ padding,
+ groups=1,
+ if_act=True,
+ act=None,
+ name=None,
+ ):
super(ConvBNLayer, self).__init__()
self.if_act = if_act
self.act = act
@@ -44,8 +46,9 @@ def __init__(self,
stride=stride,
padding=padding,
groups=groups,
- weight_attr=ParamAttr(name=name + '_weights'),
- bias_attr=False)
+ weight_attr=ParamAttr(name=name + "_weights"),
+ bias_attr=False,
+ )
self.bn = nn.BatchNorm(
num_channels=out_channels,
@@ -53,7 +56,8 @@ def __init__(self,
param_attr=ParamAttr(name="bn_" + name + "_scale"),
bias_attr=ParamAttr(name="bn_" + name + "_offset"),
moving_mean_name="bn_" + name + "_mean",
- moving_variance_name="bn_" + name + "_variance")
+ moving_variance_name="bn_" + name + "_variance",
+ )
def forward(self, x):
x = self.conv(x)
@@ -62,8 +66,8 @@ def forward(self, x):
class EASTHead(nn.Layer):
- """
- """
+ """ """
+
def __init__(self, in_channels, model_name, **kwargs):
super(EASTHead, self).__init__()
self.model_name = model_name
@@ -79,8 +83,9 @@ def __init__(self, in_channels, model_name, **kwargs):
stride=1,
padding=1,
if_act=True,
- act='relu',
- name="det_head1")
+ act="relu",
+ name="det_head1",
+ )
self.det_conv2 = ConvBNLayer(
in_channels=num_outputs[0],
out_channels=num_outputs[1],
@@ -88,8 +93,9 @@ def __init__(self, in_channels, model_name, **kwargs):
stride=1,
padding=1,
if_act=True,
- act='relu',
- name="det_head2")
+ act="relu",
+ name="det_head2",
+ )
self.score_conv = ConvBNLayer(
in_channels=num_outputs[1],
out_channels=num_outputs[2],
@@ -98,7 +104,8 @@ def __init__(self, in_channels, model_name, **kwargs):
padding=0,
if_act=False,
act=None,
- name="f_score")
+ name="f_score",
+ )
self.geo_conv = ConvBNLayer(
in_channels=num_outputs[1],
out_channels=num_outputs[3],
@@ -107,7 +114,8 @@ def __init__(self, in_channels, model_name, **kwargs):
padding=0,
if_act=False,
act=None,
- name="f_geo")
+ name="f_geo",
+ )
def forward(self, x, targets=None):
f_det = self.det_conv1(x)
@@ -117,5 +125,5 @@ def forward(self, x, targets=None):
f_geo = self.geo_conv(f_det)
f_geo = (F.sigmoid(f_geo) - 0.5) * 2 * 800
- pred = {'f_score': f_score, 'f_geo': f_geo}
+ pred = {"f_score": f_score, "f_geo": f_geo}
return pred
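
The one numeric constant worth noting in this head is the geometry activation: (F.sigmoid(f_geo) - 0.5) * 2 * 800 rescales unbounded logits into the 800-pixel offset range EAST assumes. The same mapping in plain numpy:

import numpy as np

def east_geo_activation(logits, max_offset=800):
    """Squash raw logits into pixel offsets in [-max_offset, max_offset]."""
    return (1.0 / (1.0 + np.exp(-logits)) - 0.5) * 2 * max_offset

print(east_geo_activation(np.array([-10.0, 0.0, 10.0])))  # ~[-800, 0, 800]
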
diff --git a/ppocr/modeling/heads/det_fce_head.py b/ppocr/modeling/heads/det_fce_head.py
index 9503989f58..1a90a9a6f5 100644
--- a/ppocr/modeling/heads/det_fce_head.py
+++ b/ppocr/modeling/heads/det_fce_head.py
@@ -61,10 +61,10 @@ def __init__(self, in_channels, fourier_degree=5):
padding=1,
groups=1,
weight_attr=ParamAttr(
- name='cls_weights',
- initializer=Normal(
- mean=0., std=0.01)),
- bias_attr=True)
+ name="cls_weights", initializer=Normal(mean=0.0, std=0.01)
+ ),
+ bias_attr=True,
+ )
self.out_conv_reg = nn.Conv2D(
in_channels=self.in_channels,
out_channels=self.out_channels_reg,
@@ -73,10 +73,10 @@ def __init__(self, in_channels, fourier_degree=5):
padding=1,
groups=1,
weight_attr=ParamAttr(
- name='reg_weights',
- initializer=Normal(
- mean=0., std=0.01)),
- bias_attr=True)
+ name="reg_weights", initializer=Normal(mean=0.0, std=0.01)
+ ),
+ bias_attr=True,
+ )
def forward(self, feats, targets=None):
cls_res, reg_res = multi_apply(self.forward_single, feats)
@@ -86,11 +86,12 @@ def forward(self, feats, targets=None):
for i in range(level_num):
tr_pred = F.softmax(cls_res[i][:, 0:2, :, :], axis=1)
tcl_pred = F.softmax(cls_res[i][:, 2:, :, :], axis=1)
- outs['level_{}'.format(i)] = paddle.concat(
- [tr_pred, tcl_pred, reg_res[i]], axis=1)
+ outs["level_{}".format(i)] = paddle.concat(
+ [tr_pred, tcl_pred, reg_res[i]], axis=1
+ )
else:
preds = [[cls_res[i], reg_res[i]] for i in range(level_num)]
- outs['levels'] = preds
+ outs["levels"] = preds
return outs
def forward_single(self, x):
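
FCEHead's regression branch predicts 2 * (2 * fourier_degree + 1) real values per location, the real and imaginary parts of Fourier coefficients of a closed text contour. A hedged numpy sketch of the standard FCENet reconstruction z(t) = sum_k c_k * exp(2*pi*i*k*t); the coefficient layout is an assumption here, and the repo's actual decoding happens in post-processing, not in this head:

import numpy as np

def fourier_to_contour(coeffs, fourier_degree=5, num_points=50):
    """Sample a closed contour from complex Fourier coefficients, k = -K..K.

    coeffs: complex array of length 2 * fourier_degree + 1 (assumed layout).
    Returns (num_points, 2) xy points along the contour.
    """
    t = np.arange(num_points) / num_points
    k = np.arange(-fourier_degree, fourier_degree + 1)
    z = (coeffs[None, :] * np.exp(2j * np.pi * k[None, :] * t[:, None])).sum(-1)
    return np.stack([z.real, z.imag], axis=-1)

# a pure k = 1 coefficient traces the unit circle
circle = fourier_to_contour(np.eye(11)[6] + 0j, fourier_degree=5)
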
diff --git a/ppocr/modeling/heads/det_pse_head.py b/ppocr/modeling/heads/det_pse_head.py
index 32a5b48e19..2f51621d20 100644
--- a/ppocr/modeling/heads/det_pse_head.py
+++ b/ppocr/modeling/heads/det_pse_head.py
@@ -23,15 +23,17 @@ class PSEHead(nn.Layer):
def __init__(self, in_channels, hidden_dim=256, out_channels=7, **kwargs):
super(PSEHead, self).__init__()
self.conv1 = nn.Conv2D(
- in_channels, hidden_dim, kernel_size=3, stride=1, padding=1)
+ in_channels, hidden_dim, kernel_size=3, stride=1, padding=1
+ )
self.bn1 = nn.BatchNorm2D(hidden_dim)
self.relu1 = nn.ReLU()
self.conv2 = nn.Conv2D(
- hidden_dim, out_channels, kernel_size=1, stride=1, padding=0)
+ hidden_dim, out_channels, kernel_size=1, stride=1, padding=0
+ )
def forward(self, x, **kwargs):
out = self.conv1(x)
out = self.relu1(self.bn1(out))
out = self.conv2(out)
- return {'maps': out}
+ return {"maps": out}
diff --git a/ppocr/modeling/heads/det_sast_head.py b/ppocr/modeling/heads/det_sast_head.py
index 7a88a2db6c..9246355a8b 100644
--- a/ppocr/modeling/heads/det_sast_head.py
+++ b/ppocr/modeling/heads/det_sast_head.py
@@ -24,15 +24,17 @@
class ConvBNLayer(nn.Layer):
- def __init__(self,
- in_channels,
- out_channels,
- kernel_size,
- stride,
- groups=1,
- if_act=True,
- act=None,
- name=None):
+ def __init__(
+ self,
+ in_channels,
+ out_channels,
+ kernel_size,
+ stride,
+ groups=1,
+ if_act=True,
+ act=None,
+ name=None,
+ ):
super(ConvBNLayer, self).__init__()
self.if_act = if_act
self.act = act
@@ -43,8 +45,9 @@ def __init__(self,
stride=stride,
padding=(kernel_size - 1) // 2,
groups=groups,
- weight_attr=ParamAttr(name=name + '_weights'),
- bias_attr=False)
+ weight_attr=ParamAttr(name=name + "_weights"),
+ bias_attr=False,
+ )
self.bn = nn.BatchNorm(
num_channels=out_channels,
@@ -52,7 +55,8 @@ def __init__(self,
param_attr=ParamAttr(name="bn_" + name + "_scale"),
bias_attr=ParamAttr(name="bn_" + name + "_offset"),
moving_mean_name="bn_" + name + "_mean",
- moving_variance_name="bn_" + name + "_variance")
+ moving_variance_name="bn_" + name + "_variance",
+ )
def forward(self, x):
x = self.conv(x)
@@ -65,16 +69,28 @@ def __init__(self, in_channels, **kwargs):
super(SAST_Header1, self).__init__()
out_channels = [64, 64, 128]
self.score_conv = nn.Sequential(
- ConvBNLayer(in_channels, out_channels[0], 1, 1, act='relu', name='f_score1'),
- ConvBNLayer(out_channels[0], out_channels[1], 3, 1, act='relu', name='f_score2'),
- ConvBNLayer(out_channels[1], out_channels[2], 1, 1, act='relu', name='f_score3'),
- ConvBNLayer(out_channels[2], 1, 3, 1, act=None, name='f_score4')
+ ConvBNLayer(
+ in_channels, out_channels[0], 1, 1, act="relu", name="f_score1"
+ ),
+ ConvBNLayer(
+ out_channels[0], out_channels[1], 3, 1, act="relu", name="f_score2"
+ ),
+ ConvBNLayer(
+ out_channels[1], out_channels[2], 1, 1, act="relu", name="f_score3"
+ ),
+ ConvBNLayer(out_channels[2], 1, 3, 1, act=None, name="f_score4"),
)
self.border_conv = nn.Sequential(
- ConvBNLayer(in_channels, out_channels[0], 1, 1, act='relu', name='f_border1'),
- ConvBNLayer(out_channels[0], out_channels[1], 3, 1, act='relu', name='f_border2'),
- ConvBNLayer(out_channels[1], out_channels[2], 1, 1, act='relu', name='f_border3'),
- ConvBNLayer(out_channels[2], 4, 3, 1, act=None, name='f_border4')
+ ConvBNLayer(
+ in_channels, out_channels[0], 1, 1, act="relu", name="f_border1"
+ ),
+ ConvBNLayer(
+ out_channels[0], out_channels[1], 3, 1, act="relu", name="f_border2"
+ ),
+ ConvBNLayer(
+ out_channels[1], out_channels[2], 1, 1, act="relu", name="f_border3"
+ ),
+ ConvBNLayer(out_channels[2], 4, 3, 1, act=None, name="f_border4"),
)
def forward(self, x):
@@ -89,16 +105,24 @@ def __init__(self, in_channels, **kwargs):
super(SAST_Header2, self).__init__()
out_channels = [64, 64, 128]
self.tvo_conv = nn.Sequential(
- ConvBNLayer(in_channels, out_channels[0], 1, 1, act='relu', name='f_tvo1'),
- ConvBNLayer(out_channels[0], out_channels[1], 3, 1, act='relu', name='f_tvo2'),
- ConvBNLayer(out_channels[1], out_channels[2], 1, 1, act='relu', name='f_tvo3'),
- ConvBNLayer(out_channels[2], 8, 3, 1, act=None, name='f_tvo4')
+ ConvBNLayer(in_channels, out_channels[0], 1, 1, act="relu", name="f_tvo1"),
+ ConvBNLayer(
+ out_channels[0], out_channels[1], 3, 1, act="relu", name="f_tvo2"
+ ),
+ ConvBNLayer(
+ out_channels[1], out_channels[2], 1, 1, act="relu", name="f_tvo3"
+ ),
+ ConvBNLayer(out_channels[2], 8, 3, 1, act=None, name="f_tvo4"),
)
self.tco_conv = nn.Sequential(
- ConvBNLayer(in_channels, out_channels[0], 1, 1, act='relu', name='f_tco1'),
- ConvBNLayer(out_channels[0], out_channels[1], 3, 1, act='relu', name='f_tco2'),
- ConvBNLayer(out_channels[1], out_channels[2], 1, 1, act='relu', name='f_tco3'),
- ConvBNLayer(out_channels[2], 2, 3, 1, act=None, name='f_tco4')
+ ConvBNLayer(in_channels, out_channels[0], 1, 1, act="relu", name="f_tco1"),
+ ConvBNLayer(
+ out_channels[0], out_channels[1], 3, 1, act="relu", name="f_tco2"
+ ),
+ ConvBNLayer(
+ out_channels[1], out_channels[2], 1, 1, act="relu", name="f_tco3"
+ ),
+ ConvBNLayer(out_channels[2], 2, 3, 1, act=None, name="f_tco4"),
)
def forward(self, x):
@@ -108,8 +132,8 @@ def forward(self, x):
class SASTHead(nn.Layer):
- """
- """
+ """ """
+
def __init__(self, in_channels, **kwargs):
super(SASTHead, self).__init__()
@@ -121,8 +145,8 @@ def forward(self, x, targets=None):
f_tvo, f_tco = self.head2(x)
predicts = {}
- predicts['f_score'] = f_score
- predicts['f_border'] = f_border
- predicts['f_tvo'] = f_tvo
- predicts['f_tco'] = f_tco
- return predicts
\ No newline at end of file
+ predicts["f_score"] = f_score
+ predicts["f_border"] = f_border
+ predicts["f_tvo"] = f_tvo
+ predicts["f_tco"] = f_tco
+ return predicts
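
Besides restoring the trailing newline, this hunk leaves the SAST output layout intact: one score map, four border channels, eight TVO channels, and two TCO channels. A sanity sketch of that layout (the 128-channel input is an assumption):

import paddle
from ppocr.modeling.heads.det_sast_head import SASTHead

head = SASTHead(in_channels=128)
preds = head(paddle.randn([1, 128, 128, 128]))
for key in ("f_score", "f_border", "f_tvo", "f_tco"):
    print(key, preds[key].shape)  # 1, 4, 8 and 2 channels respectively
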
diff --git a/ppocr/modeling/heads/e2e_pg_head.py b/ppocr/modeling/heads/e2e_pg_head.py
index 514962ef97..27b04727f8 100644
--- a/ppocr/modeling/heads/e2e_pg_head.py
+++ b/ppocr/modeling/heads/e2e_pg_head.py
@@ -24,16 +24,18 @@
class ConvBNLayer(nn.Layer):
- def __init__(self,
- in_channels,
- out_channels,
- kernel_size,
- stride,
- padding,
- groups=1,
- if_act=True,
- act=None,
- name=None):
+ def __init__(
+ self,
+ in_channels,
+ out_channels,
+ kernel_size,
+ stride,
+ padding,
+ groups=1,
+ if_act=True,
+ act=None,
+ name=None,
+ ):
super(ConvBNLayer, self).__init__()
self.if_act = if_act
self.act = act
@@ -44,8 +46,9 @@ def __init__(self,
stride=stride,
padding=padding,
groups=groups,
- weight_attr=ParamAttr(name=name + '_weights'),
- bias_attr=False)
+ weight_attr=ParamAttr(name=name + "_weights"),
+ bias_attr=False,
+ )
self.bn = nn.BatchNorm(
num_channels=out_channels,
@@ -54,7 +57,8 @@ def __init__(self,
bias_attr=ParamAttr(name="bn_" + name + "_offset"),
moving_mean_name="bn_" + name + "_mean",
moving_variance_name="bn_" + name + "_variance",
- use_global_stats=False)
+ use_global_stats=False,
+ )
def forward(self, x):
x = self.conv(x)
@@ -63,13 +67,11 @@ def forward(self, x):
class PGHead(nn.Layer):
- """
- """
+ """ """
- def __init__(self,
- in_channels,
- character_dict_path='ppocr/utils/ic15_dict.txt',
- **kwargs):
+ def __init__(
+ self, in_channels, character_dict_path="ppocr/utils/ic15_dict.txt", **kwargs
+ ):
super(PGHead, self).__init__()
# get character_length
@@ -83,24 +85,27 @@ def __init__(self,
kernel_size=1,
stride=1,
padding=0,
- act='relu',
- name="conv_f_score{}".format(1))
+ act="relu",
+ name="conv_f_score{}".format(1),
+ )
self.conv_f_score2 = ConvBNLayer(
in_channels=64,
out_channels=64,
kernel_size=3,
stride=1,
padding=1,
- act='relu',
- name="conv_f_score{}".format(2))
+ act="relu",
+ name="conv_f_score{}".format(2),
+ )
self.conv_f_score3 = ConvBNLayer(
in_channels=64,
out_channels=128,
kernel_size=1,
stride=1,
padding=0,
- act='relu',
- name="conv_f_score{}".format(3))
+ act="relu",
+ name="conv_f_score{}".format(3),
+ )
self.conv1 = nn.Conv2D(
in_channels=128,
@@ -110,7 +115,8 @@ def __init__(self,
padding=1,
groups=1,
weight_attr=ParamAttr(name="conv_f_score{}".format(4)),
- bias_attr=False)
+ bias_attr=False,
+ )
self.conv_f_boder1 = ConvBNLayer(
in_channels=in_channels,
@@ -118,24 +124,27 @@ def __init__(self,
kernel_size=1,
stride=1,
padding=0,
- act='relu',
- name="conv_f_boder{}".format(1))
+ act="relu",
+ name="conv_f_boder{}".format(1),
+ )
self.conv_f_boder2 = ConvBNLayer(
in_channels=64,
out_channels=64,
kernel_size=3,
stride=1,
padding=1,
- act='relu',
- name="conv_f_boder{}".format(2))
+ act="relu",
+ name="conv_f_boder{}".format(2),
+ )
self.conv_f_boder3 = ConvBNLayer(
in_channels=64,
out_channels=128,
kernel_size=1,
stride=1,
padding=0,
- act='relu',
- name="conv_f_boder{}".format(3))
+ act="relu",
+ name="conv_f_boder{}".format(3),
+ )
self.conv2 = nn.Conv2D(
in_channels=128,
out_channels=4,
@@ -144,47 +153,53 @@ def __init__(self,
padding=1,
groups=1,
weight_attr=ParamAttr(name="conv_f_boder{}".format(4)),
- bias_attr=False)
+ bias_attr=False,
+ )
self.conv_f_char1 = ConvBNLayer(
in_channels=in_channels,
out_channels=128,
kernel_size=1,
stride=1,
padding=0,
- act='relu',
- name="conv_f_char{}".format(1))
+ act="relu",
+ name="conv_f_char{}".format(1),
+ )
self.conv_f_char2 = ConvBNLayer(
in_channels=128,
out_channels=128,
kernel_size=3,
stride=1,
padding=1,
- act='relu',
- name="conv_f_char{}".format(2))
+ act="relu",
+ name="conv_f_char{}".format(2),
+ )
self.conv_f_char3 = ConvBNLayer(
in_channels=128,
out_channels=256,
kernel_size=1,
stride=1,
padding=0,
- act='relu',
- name="conv_f_char{}".format(3))
+ act="relu",
+ name="conv_f_char{}".format(3),
+ )
self.conv_f_char4 = ConvBNLayer(
in_channels=256,
out_channels=256,
kernel_size=3,
stride=1,
padding=1,
- act='relu',
- name="conv_f_char{}".format(4))
+ act="relu",
+ name="conv_f_char{}".format(4),
+ )
self.conv_f_char5 = ConvBNLayer(
in_channels=256,
out_channels=256,
kernel_size=1,
stride=1,
padding=0,
- act='relu',
- name="conv_f_char{}".format(5))
+ act="relu",
+ name="conv_f_char{}".format(5),
+ )
self.conv3 = nn.Conv2D(
in_channels=256,
out_channels=character_length,
@@ -193,7 +208,8 @@ def __init__(self,
padding=1,
groups=1,
weight_attr=ParamAttr(name="conv_f_char{}".format(6)),
- bias_attr=False)
+ bias_attr=False,
+ )
self.conv_f_direc1 = ConvBNLayer(
in_channels=in_channels,
@@ -201,24 +217,27 @@ def __init__(self,
kernel_size=1,
stride=1,
padding=0,
- act='relu',
- name="conv_f_direc{}".format(1))
+ act="relu",
+ name="conv_f_direc{}".format(1),
+ )
self.conv_f_direc2 = ConvBNLayer(
in_channels=64,
out_channels=64,
kernel_size=3,
stride=1,
padding=1,
- act='relu',
- name="conv_f_direc{}".format(2))
+ act="relu",
+ name="conv_f_direc{}".format(2),
+ )
self.conv_f_direc3 = ConvBNLayer(
in_channels=64,
out_channels=128,
kernel_size=1,
stride=1,
padding=0,
- act='relu',
- name="conv_f_direc{}".format(3))
+ act="relu",
+ name="conv_f_direc{}".format(3),
+ )
self.conv4 = nn.Conv2D(
in_channels=128,
out_channels=2,
@@ -227,7 +246,8 @@ def __init__(self,
padding=1,
groups=1,
weight_attr=ParamAttr(name="conv_f_direc{}".format(4)),
- bias_attr=False)
+ bias_attr=False,
+ )
def forward(self, x, targets=None):
f_score = self.conv_f_score1(x)
@@ -255,8 +275,8 @@ def forward(self, x, targets=None):
f_direction = self.conv4(f_direction)
predicts = {}
- predicts['f_score'] = f_score
- predicts['f_border'] = f_border
- predicts['f_char'] = f_char
- predicts['f_direction'] = f_direction
+ predicts["f_score"] = f_score
+ predicts["f_border"] = f_border
+ predicts["f_char"] = f_char
+ predicts["f_direction"] = f_direction
return predicts
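
PGHead mirrors the SAST map families and adds a character-classification branch whose channel count is derived from character_dict_path. A usage sketch, assuming it is run from the repo root so the default dict path resolves (input shape illustrative):

import paddle
from ppocr.modeling.heads.e2e_pg_head import PGHead

head = PGHead(in_channels=128)  # default character_dict_path
preds = head(paddle.randn([1, 128, 128, 128]))
for key in ("f_score", "f_border", "f_char", "f_direction"):
    print(key, preds[key].shape)  # f_char width depends on the dict size
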
diff --git a/ppocr/modeling/heads/gcn.py b/ppocr/modeling/heads/gcn.py
index d123f067cb..6e6e2ea5f8 100644
--- a/ppocr/modeling/heads/gcn.py
+++ b/ppocr/modeling/heads/gcn.py
@@ -26,12 +26,14 @@
class BatchNorm1D(nn.BatchNorm1D):
- def __init__(self,
- num_features,
- eps=1e-05,
- momentum=0.1,
- affine=True,
- track_running_stats=True):
+ def __init__(
+ self,
+ num_features,
+ eps=1e-05,
+ momentum=0.1,
+ affine=True,
+ track_running_stats=True,
+ ):
momentum = 1 - momentum
weight_attr = None
bias_attr = None
@@ -44,7 +46,8 @@ def __init__(self,
epsilon=eps,
weight_attr=weight_attr,
bias_attr=bias_attr,
- use_global_stats=track_running_stats)
+ use_global_stats=track_running_stats,
+ )
class MeanAggregator(nn.Layer):
@@ -59,12 +62,13 @@ def __init__(self, in_dim, out_dim):
self.in_dim = in_dim
self.out_dim = out_dim
self.weight = self.create_parameter(
- [in_dim * 2, out_dim],
- default_initializer=nn.initializer.XavierUniform())
+ [in_dim * 2, out_dim], default_initializer=nn.initializer.XavierUniform()
+ )
self.bias = self.create_parameter(
[out_dim],
is_bias=True,
- default_initializer=nn.initializer.Assign([0] * out_dim))
+ default_initializer=nn.initializer.Assign([0] * out_dim),
+ )
self.aggregator = MeanAggregator()
@@ -73,7 +77,7 @@ def forward(self, features, A):
assert d == self.in_dim
agg_feats = self.aggregator(features, A)
cat_feats = paddle.concat([features, agg_feats], axis=2)
- out = paddle.einsum('bnd,df->bnf', cat_feats, self.weight)
+ out = paddle.einsum("bnd,df->bnf", cat_feats, self.weight)
out = F.relu(out + self.bias)
return out
@@ -87,10 +91,10 @@ def __init__(self, feat_len):
self.conv3 = GraphConv(256, 128)
self.conv4 = GraphConv(128, 64)
self.classifier = nn.Sequential(
- nn.Linear(64, 32), nn.PReLU(32), nn.Linear(32, 2))
+ nn.Linear(64, 32), nn.PReLU(32), nn.Linear(32, 2)
+ )
def forward(self, x, A, knn_inds):
-
num_local_graphs, num_max_nodes, feat_len = x.shape
x = x.reshape([-1, feat_len])
@@ -105,8 +109,9 @@ def forward(self, x, A, knn_inds):
mid_feat_len = x.shape[-1]
edge_feat = paddle.zeros([num_local_graphs, k, mid_feat_len])
for graph_ind in range(num_local_graphs):
- edge_feat[graph_ind, :, :] = x[graph_ind][paddle.to_tensor(knn_inds[
- graph_ind])]
+ edge_feat[graph_ind, :, :] = x[graph_ind][
+ paddle.to_tensor(knn_inds[graph_ind])
+ ]
edge_feat = edge_feat.reshape([-1, mid_feat_len])
pred = self.classifier(edge_feat)
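
GraphConv above concatenates each node's features with an aggregation of its neighbors, projects through a learned weight via einsum, and applies ReLU. A numpy sketch, assuming MeanAggregator is a batched matmul A @ X (the usual DRRG formulation; its body is not shown in this hunk):

import numpy as np

def graph_conv(features, adj, weight, bias):
    """One GraphConv step: aggregate, concat, project, ReLU.

    features: (B, N, D), adj: (B, N, N), weight: (2D, F), bias: (F,)
    """
    agg = adj @ features                           # assumed neighbor aggregation
    cat = np.concatenate([features, agg], axis=2)  # (B, N, 2D)
    out = np.einsum("bnd,df->bnf", cat, weight) + bias
    return np.maximum(out, 0.0)                    # ReLU
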
diff --git a/ppocr/modeling/heads/kie_sdmgr_head.py b/ppocr/modeling/heads/kie_sdmgr_head.py
index ac5f73fa7e..bc019ec802 100644
--- a/ppocr/modeling/heads/kie_sdmgr_head.py
+++ b/ppocr/modeling/heads/kie_sdmgr_head.py
@@ -25,28 +25,30 @@
class SDMGRHead(nn.Layer):
- def __init__(self,
- in_channels,
- num_chars=92,
- visual_dim=16,
- fusion_dim=1024,
- node_input=32,
- node_embed=256,
- edge_input=5,
- edge_embed=256,
- num_gnn=2,
- num_classes=26,
- bidirectional=False):
+ def __init__(
+ self,
+ in_channels,
+ num_chars=92,
+ visual_dim=16,
+ fusion_dim=1024,
+ node_input=32,
+ node_embed=256,
+ edge_input=5,
+ edge_embed=256,
+ num_gnn=2,
+ num_classes=26,
+ bidirectional=False,
+ ):
super().__init__()
self.fusion = Block([visual_dim, node_embed], node_embed, fusion_dim)
self.node_embed = nn.Embedding(num_chars, node_input, 0)
hidden = node_embed // 2 if bidirectional else node_embed
- self.rnn = nn.LSTM(
- input_size=node_input, hidden_size=hidden, num_layers=1)
+ self.rnn = nn.LSTM(input_size=node_input, hidden_size=hidden, num_layers=1)
self.edge_embed = nn.Linear(edge_input, edge_embed)
self.gnn_layers = nn.LayerList(
- [GNNLayer(node_embed, edge_embed) for _ in range(num_gnn)])
+ [GNNLayer(node_embed, edge_embed) for _ in range(num_gnn)]
+ )
self.node_cls = nn.Linear(node_embed, num_classes)
self.edge_cls = nn.Linear(edge_embed, 2)
@@ -58,12 +60,14 @@ def forward(self, input, targets):
char_nums.append(paddle.sum((text > -1).astype(int), axis=-1))
max_num = max([char_num.max() for char_num in char_nums])
- all_nodes = paddle.concat([
- paddle.concat(
- [text, paddle.zeros(
- (text.shape[0], max_num - text.shape[1]))], -1)
- for text in texts
- ])
+ all_nodes = paddle.concat(
+ [
+ paddle.concat(
+ [text, paddle.zeros((text.shape[0], max_num - text.shape[1]))], -1
+ )
+ for text in texts
+ ]
+ )
temp = paddle.clip(all_nodes, min=0).astype(int)
embed_nodes = self.node_embed(temp)
rnn_nodes, _ = self.rnn(embed_nodes)
@@ -72,17 +76,17 @@ def forward(self, input, targets):
nodes = paddle.zeros([b, w])
all_nums = paddle.concat(char_nums)
valid = paddle.nonzero((all_nums > 0).astype(int))
- temp_all_nums = (
- paddle.gather(all_nums, valid) - 1).unsqueeze(-1).unsqueeze(-1)
- temp_all_nums = paddle.expand(temp_all_nums, [
- temp_all_nums.shape[0], temp_all_nums.shape[1], rnn_nodes.shape[-1]
- ])
+ temp_all_nums = (paddle.gather(all_nums, valid) - 1).unsqueeze(-1).unsqueeze(-1)
+ temp_all_nums = paddle.expand(
+ temp_all_nums,
+ [temp_all_nums.shape[0], temp_all_nums.shape[1], rnn_nodes.shape[-1]],
+ )
temp_all_nodes = paddle.gather(rnn_nodes, valid)
N, C, A = temp_all_nodes.shape
- one_hot = F.one_hot(
- temp_all_nums[:, 0, :], num_classes=C).transpose([0, 2, 1])
- one_hot = paddle.multiply(
- temp_all_nodes, one_hot.astype("float32")).sum(axis=1, keepdim=True)
+ one_hot = F.one_hot(temp_all_nums[:, 0, :], num_classes=C).transpose([0, 2, 1])
+ one_hot = paddle.multiply(temp_all_nodes, one_hot.astype("float32")).sum(
+ axis=1, keepdim=True
+ )
t = one_hot.expand([N, 1, A]).squeeze(1)
nodes = paddle.scatter(nodes, valid.squeeze(1), t)
@@ -90,8 +94,9 @@ def forward(self, input, targets):
nodes = self.fusion([x, nodes])
all_edges = paddle.concat(
- [rel.reshape([-1, rel.shape[-1]]) for rel in relations])
- embed_edges = self.edge_embed(all_edges.astype('float32'))
+ [rel.reshape([-1, rel.shape[-1]]) for rel in relations]
+ )
+ embed_edges = self.edge_embed(all_edges.astype("float32"))
embed_edges = F.normalize(embed_edges)
for gnn_layer in self.gnn_layers:
@@ -112,12 +117,16 @@ def __init__(self, node_dim=256, edge_dim=256):
def forward(self, nodes, edges, nums):
start, cat_nodes = 0, []
for num in nums:
- sample_nodes = nodes[start:start + num]
+ sample_nodes = nodes[start : start + num]
cat_nodes.append(
- paddle.concat([
- paddle.expand(sample_nodes.unsqueeze(1), [-1, num, -1]),
- paddle.expand(sample_nodes.unsqueeze(0), [num, -1, -1])
- ], -1).reshape([num**2, -1]))
+ paddle.concat(
+ [
+ paddle.expand(sample_nodes.unsqueeze(1), [-1, num, -1]),
+ paddle.expand(sample_nodes.unsqueeze(0), [num, -1, -1]),
+ ],
+ -1,
+ ).reshape([num**2, -1])
+ )
start += num
cat_nodes = paddle.concat([paddle.concat(cat_nodes), edges], -1)
cat_nodes = self.relu(self.in_fc(cat_nodes))
@@ -126,10 +135,16 @@ def forward(self, nodes, edges, nums):
start, residuals = 0, []
for num in nums:
residual = F.softmax(
- -paddle.eye(num).unsqueeze(-1) * 1e9 +
- coefs[start:start + num**2].reshape([num, num, -1]), 1)
- residuals.append((residual * cat_nodes[start:start + num**2]
- .reshape([num, num, -1])).sum(1))
+ -paddle.eye(num).unsqueeze(-1) * 1e9
+ + coefs[start : start + num**2].reshape([num, num, -1]),
+ 1,
+ )
+ residuals.append(
+ (
+ residual
+ * cat_nodes[start : start + num**2].reshape([num, num, -1])
+ ).sum(1)
+ )
start += num**2
nodes += self.relu(self.out_fc(paddle.concat(residuals)))
@@ -137,28 +152,29 @@ def forward(self, nodes, edges, nums):
class Block(nn.Layer):
- def __init__(self,
- input_dims,
- output_dim,
- mm_dim=1600,
- chunks=20,
- rank=15,
- shared=False,
- dropout_input=0.,
- dropout_pre_lin=0.,
- dropout_output=0.,
- pos_norm='before_cat'):
+ def __init__(
+ self,
+ input_dims,
+ output_dim,
+ mm_dim=1600,
+ chunks=20,
+ rank=15,
+ shared=False,
+ dropout_input=0.0,
+ dropout_pre_lin=0.0,
+ dropout_output=0.0,
+ pos_norm="before_cat",
+ ):
super().__init__()
self.rank = rank
self.dropout_input = dropout_input
self.dropout_pre_lin = dropout_pre_lin
self.dropout_output = dropout_output
- assert (pos_norm in ['before_cat', 'after_cat'])
+ assert pos_norm in ["before_cat", "after_cat"]
self.pos_norm = pos_norm
# Modules
self.linear0 = nn.Linear(input_dims[0], mm_dim)
- self.linear1 = (self.linear0
- if shared else nn.Linear(input_dims[1], mm_dim))
+ self.linear1 = self.linear0 if shared else nn.Linear(input_dims[1], mm_dim)
self.merge_linears0 = nn.LayerList()
self.merge_linears1 = nn.LayerList()
self.chunks = self.chunk_sizes(mm_dim, chunks)
@@ -179,17 +195,18 @@ def forward(self, x):
x0_chunks = paddle.split(x0, self.chunks, -1)
x1_chunks = paddle.split(x1, self.chunks, -1)
zs = []
- for x0_c, x1_c, m0, m1 in zip(x0_chunks, x1_chunks, self.merge_linears0,
- self.merge_linears1):
+ for x0_c, x1_c, m0, m1 in zip(
+ x0_chunks, x1_chunks, self.merge_linears0, self.merge_linears1
+ ):
m = m0(x0_c) * m1(x1_c) # bs x split_size*rank
m = m.reshape([bs, self.rank, -1])
z = paddle.sum(m, 1)
- if self.pos_norm == 'before_cat':
+ if self.pos_norm == "before_cat":
z = paddle.sqrt(F.relu(z)) - paddle.sqrt(F.relu(-z))
z = F.normalize(z)
zs.append(z)
z = paddle.concat(zs, 1)
- if self.pos_norm == 'after_cat':
+ if self.pos_norm == "after_cat":
z = paddle.sqrt(F.relu(z)) - paddle.sqrt(F.relu(-z))
z = F.normalize(z)
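
Block.forward above is a chunked low-rank bilinear fusion: per chunk, the two projections are multiplied elementwise, reshaped to expose rank factors, summed over them, then signed-sqrt and L2 normalized when pos_norm="before_cat". One chunk in plain numpy (weight shapes are assumptions for the sketch):

import numpy as np

def low_rank_fusion_chunk(x0, x1, w0, w1, rank):
    """One Block chunk: project, multiply, sum over rank, power-normalize.

    x0: (bs, d0), x1: (bs, d1); w0: (d0, s * rank), w1: (d1, s * rank).
    """
    bs = x0.shape[0]
    m = (x0 @ w0) * (x1 @ w1)            # (bs, s * rank)
    z = m.reshape(bs, rank, -1).sum(1)   # sum over rank factors -> (bs, s)
    z = np.sqrt(np.maximum(z, 0)) - np.sqrt(np.maximum(-z, 0))  # signed sqrt
    return z / (np.linalg.norm(z, axis=-1, keepdims=True) + 1e-12)
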
diff --git a/ppocr/modeling/heads/local_graph.py b/ppocr/modeling/heads/local_graph.py
index 50fe6d7223..f65c9aa02b 100644
--- a/ppocr/modeling/heads/local_graph.py
+++ b/ppocr/modeling/heads/local_graph.py
@@ -89,45 +89,61 @@ def feature_embedding(input_feats, out_feat_len):
residue_dim = out_feat_len % feat_dim
if residue_dim > 0:
- embed_wave = np.array([
- np.power(1000, 2.0 * (j // 2) / feat_repeat_times + 1)
- for j in range(feat_repeat_times + 1)
- ]).reshape((feat_repeat_times + 1, 1, 1))
+ embed_wave = np.array(
+ [
+ np.power(1000, 2.0 * (j // 2) / feat_repeat_times + 1)
+ for j in range(feat_repeat_times + 1)
+ ]
+ ).reshape((feat_repeat_times + 1, 1, 1))
repeat_feats = np.repeat(
- np.expand_dims(
- input_feats, axis=0), feat_repeat_times, axis=0)
- residue_feats = np.hstack([
- input_feats[:, 0:residue_dim], np.zeros(
- (num_nodes, feat_dim - residue_dim))
- ])
+ np.expand_dims(input_feats, axis=0), feat_repeat_times, axis=0
+ )
+ residue_feats = np.hstack(
+ [
+ input_feats[:, 0:residue_dim],
+ np.zeros((num_nodes, feat_dim - residue_dim)),
+ ]
+ )
residue_feats = np.expand_dims(residue_feats, axis=0)
repeat_feats = np.concatenate([repeat_feats, residue_feats], axis=0)
embedded_feats = repeat_feats / embed_wave
embedded_feats[:, 0::2] = np.sin(embedded_feats[:, 0::2])
embedded_feats[:, 1::2] = np.cos(embedded_feats[:, 1::2])
embedded_feats = np.transpose(embedded_feats, (1, 0, 2)).reshape(
- (num_nodes, -1))[:, 0:out_feat_len]
+ (num_nodes, -1)
+ )[:, 0:out_feat_len]
else:
- embed_wave = np.array([
- np.power(1000, 2.0 * (j // 2) / feat_repeat_times)
- for j in range(feat_repeat_times)
- ]).reshape((feat_repeat_times, 1, 1))
+ embed_wave = np.array(
+ [
+ np.power(1000, 2.0 * (j // 2) / feat_repeat_times)
+ for j in range(feat_repeat_times)
+ ]
+ ).reshape((feat_repeat_times, 1, 1))
repeat_feats = np.repeat(
- np.expand_dims(
- input_feats, axis=0), feat_repeat_times, axis=0)
+ np.expand_dims(input_feats, axis=0), feat_repeat_times, axis=0
+ )
embedded_feats = repeat_feats / embed_wave
embedded_feats[:, 0::2] = np.sin(embedded_feats[:, 0::2])
embedded_feats[:, 1::2] = np.cos(embedded_feats[:, 1::2])
- embedded_feats = np.transpose(embedded_feats, (1, 0, 2)).reshape(
- (num_nodes, -1)).astype(np.float32)
+ embedded_feats = (
+ np.transpose(embedded_feats, (1, 0, 2))
+ .reshape((num_nodes, -1))
+ .astype(np.float32)
+ )
return embedded_feats
class LocalGraphs:
- def __init__(self, k_at_hops, num_adjacent_linkages, node_geo_feat_len,
- pooling_scale, pooling_output_size, local_graph_thr):
-
+ def __init__(
+ self,
+ k_at_hops,
+ num_adjacent_linkages,
+ node_geo_feat_len,
+ pooling_scale,
+ pooling_output_size,
+ local_graph_thr,
+ ):
assert len(k_at_hops) == 2
assert all(isinstance(n, int) for n in k_at_hops)
assert isinstance(num_adjacent_linkages, int)
@@ -160,20 +176,22 @@ def generate_local_graphs(self, sorted_dist_inds, gt_comp_labels):
"""
assert sorted_dist_inds.ndim == 2
- assert (sorted_dist_inds.shape[0] == sorted_dist_inds.shape[1] ==
- gt_comp_labels.shape[0])
+ assert (
+ sorted_dist_inds.shape[0]
+ == sorted_dist_inds.shape[1]
+ == gt_comp_labels.shape[0]
+ )
- knn_graph = sorted_dist_inds[:, 1:self.k_at_hops[0] + 1]
+ knn_graph = sorted_dist_inds[:, 1 : self.k_at_hops[0] + 1]
pivot_local_graphs = []
pivot_knns = []
for pivot_ind, knn in enumerate(knn_graph):
-
local_graph_neighbors = set(knn)
for neighbor_ind in knn:
local_graph_neighbors.update(
- set(sorted_dist_inds[neighbor_ind, 1:self.k_at_hops[1] +
- 1]))
+ set(sorted_dist_inds[neighbor_ind, 1 : self.k_at_hops[1] + 1])
+ )
local_graph_neighbors.discard(pivot_ind)
pivot_local_graph = list(local_graph_neighbors)
@@ -190,18 +208,21 @@ def generate_local_graphs(self, sorted_dist_inds, gt_comp_labels):
added_local_graph = pivot_local_graphs[graph_ind]
union = len(
- set(pivot_local_graph[1:]).union(
- set(added_local_graph[1:])))
+ set(pivot_local_graph[1:]).union(set(added_local_graph[1:]))
+ )
intersect = len(
set(pivot_local_graph[1:]).intersection(
- set(added_local_graph[1:])))
+ set(added_local_graph[1:])
+ )
+ )
local_graph_iou = intersect / (union + 1e-8)
- if (local_graph_iou > self.local_graph_thr and
- pivot_ind in added_knn and
- gt_comp_labels[added_pivot_ind] ==
- gt_comp_labels[pivot_ind] and
- gt_comp_labels[pivot_ind] != 0):
+ if (
+ local_graph_iou > self.local_graph_thr
+ and pivot_ind in added_knn
+ and gt_comp_labels[added_pivot_ind] == gt_comp_labels[pivot_ind]
+ and gt_comp_labels[pivot_ind] != 0
+ ):
add_flag = False
break
if add_flag:
@@ -210,8 +231,14 @@ def generate_local_graphs(self, sorted_dist_inds, gt_comp_labels):
return pivot_local_graphs, pivot_knns
- def generate_gcn_input(self, node_feat_batch, node_label_batch,
- local_graph_batch, knn_batch, sorted_dist_ind_batch):
+ def generate_gcn_input(
+ self,
+ node_feat_batch,
+ node_label_batch,
+ local_graph_batch,
+ knn_batch,
+ sorted_dist_ind_batch,
+ ):
"""Generate graph convolution network input data.
Args:
@@ -239,11 +266,13 @@ def generate_gcn_input(self, node_feat_batch, node_label_batch,
assert isinstance(knn_batch, list)
assert isinstance(sorted_dist_ind_batch, list)
- num_max_nodes = max([
- len(pivot_local_graph)
- for pivot_local_graphs in local_graph_batch
- for pivot_local_graph in pivot_local_graphs
- ])
+ num_max_nodes = max(
+ [
+ len(pivot_local_graph)
+ for pivot_local_graphs in local_graph_batch
+ for pivot_local_graph in pivot_local_graphs
+ ]
+ )
local_graphs_node_feat = []
adjacent_matrices = []
@@ -262,42 +291,47 @@ def generate_gcn_input(self, node_feat_batch, node_label_batch,
pivot_ind = pivot_local_graph[0]
node2ind_map = {j: i for i, j in enumerate(pivot_local_graph)}
- knn_inds = paddle.to_tensor(
- [node2ind_map[i] for i in pivot_knn[1:]])
+ knn_inds = paddle.to_tensor([node2ind_map[i] for i in pivot_knn[1:]])
pivot_feats = node_feats[pivot_ind]
- normalized_feats = node_feats[paddle.to_tensor(
- pivot_local_graph)] - pivot_feats
+ normalized_feats = (
+ node_feats[paddle.to_tensor(pivot_local_graph)] - pivot_feats
+ )
- adjacent_matrix = np.zeros(
- (num_nodes, num_nodes), dtype=np.float32)
+ adjacent_matrix = np.zeros((num_nodes, num_nodes), dtype=np.float32)
for node in pivot_local_graph:
- neighbors = sorted_dist_inds[node, 1:
- self.num_adjacent_linkages + 1]
+ neighbors = sorted_dist_inds[
+ node, 1 : self.num_adjacent_linkages + 1
+ ]
for neighbor in neighbors:
if neighbor in pivot_local_graph:
-
- adjacent_matrix[node2ind_map[node], node2ind_map[
- neighbor]] = 1
- adjacent_matrix[node2ind_map[neighbor],
- node2ind_map[node]] = 1
+ adjacent_matrix[
+ node2ind_map[node], node2ind_map[neighbor]
+ ] = 1
+ adjacent_matrix[
+ node2ind_map[neighbor], node2ind_map[node]
+ ] = 1
adjacent_matrix = normalize_adjacent_matrix(adjacent_matrix)
- pad_adjacent_matrix = paddle.zeros(
- (num_max_nodes, num_max_nodes))
+ pad_adjacent_matrix = paddle.zeros((num_max_nodes, num_max_nodes))
pad_adjacent_matrix[:num_nodes, :num_nodes] = paddle.cast(
- paddle.to_tensor(adjacent_matrix), 'float32')
+ paddle.to_tensor(adjacent_matrix), "float32"
+ )
pad_normalized_feats = paddle.concat(
[
- normalized_feats, paddle.zeros(
- (num_max_nodes - num_nodes,
- normalized_feats.shape[1]))
+ normalized_feats,
+ paddle.zeros(
+ (num_max_nodes - num_nodes, normalized_feats.shape[1])
+ ),
],
- axis=0)
+ axis=0,
+ )
local_graph_labels = node_labels[pivot_local_graph]
knn_labels = local_graph_labels[knn_inds.numpy()]
- link_labels = ((node_labels[pivot_ind] == knn_labels) &
- (node_labels[pivot_ind] > 0)).astype(np.int64)
+ link_labels = (
+ (node_labels[pivot_ind] == knn_labels)
+ & (node_labels[pivot_ind] > 0)
+ ).astype(np.int64)
link_labels = paddle.to_tensor(link_labels)
local_graphs_node_feat.append(pad_normalized_feats)
@@ -310,8 +344,12 @@ def generate_gcn_input(self, node_feat_batch, node_label_batch,
pivots_knn_inds = paddle.stack(pivots_knn_inds, 0)
pivots_gt_linkage = paddle.stack(pivots_gt_linkage, 0)
- return (local_graphs_node_feat, adjacent_matrices, pivots_knn_inds,
- pivots_gt_linkage)
+ return (
+ local_graphs_node_feat,
+ adjacent_matrices,
+ pivots_knn_inds,
+ pivots_gt_linkage,
+ )
def __call__(self, feat_maps, comp_attribs):
"""Generate local graphs as GCN input.
@@ -343,34 +381,32 @@ def __call__(self, feat_maps, comp_attribs):
for batch_ind in range(comp_attribs.shape[0]):
num_comps = int(comp_attribs[batch_ind, 0, 0])
comp_geo_attribs = comp_attribs[batch_ind, :num_comps, 1:7]
- node_labels = comp_attribs[batch_ind, :num_comps, 7].astype(
- np.int32)
+ node_labels = comp_attribs[batch_ind, :num_comps, 7].astype(np.int32)
comp_centers = comp_geo_attribs[:, 0:2]
- distance_matrix = euclidean_distance_matrix(comp_centers,
- comp_centers)
+ distance_matrix = euclidean_distance_matrix(comp_centers, comp_centers)
- batch_id = np.zeros(
- (comp_geo_attribs.shape[0], 1), dtype=np.float32) * batch_ind
+ batch_id = (
+ np.zeros((comp_geo_attribs.shape[0], 1), dtype=np.float32) * batch_ind
+ )
comp_geo_attribs[:, -2] = np.clip(comp_geo_attribs[:, -2], -1, 1)
angle = np.arccos(comp_geo_attribs[:, -2]) * np.sign(
- comp_geo_attribs[:, -1])
+ comp_geo_attribs[:, -1]
+ )
angle = angle.reshape((-1, 1))
- rotated_rois = np.hstack(
- [batch_id, comp_geo_attribs[:, :-2], angle])
+ rotated_rois = np.hstack([batch_id, comp_geo_attribs[:, :-2], angle])
rois = paddle.to_tensor(rotated_rois)
- content_feats = self.pooling(feat_maps[batch_ind].unsqueeze(0),
- rois)
+ content_feats = self.pooling(feat_maps[batch_ind].unsqueeze(0), rois)
content_feats = content_feats.reshape([content_feats.shape[0], -1])
- geo_feats = feature_embedding(comp_geo_attribs,
- self.node_geo_feat_dim)
+ geo_feats = feature_embedding(comp_geo_attribs, self.node_geo_feat_dim)
geo_feats = paddle.to_tensor(geo_feats)
node_feats = paddle.concat([content_feats, geo_feats], axis=-1)
sorted_dist_inds = np.argsort(distance_matrix, axis=1)
pivot_local_graphs, pivot_knns = self.generate_local_graphs(
- sorted_dist_inds, node_labels)
+ sorted_dist_inds, node_labels
+ )
node_feat_batch.append(node_feats)
node_label_batch.append(node_labels)
@@ -378,11 +414,12 @@ def __call__(self, feat_maps, comp_attribs):
knn_batch.append(pivot_knns)
sorted_dist_inds_batch.append(sorted_dist_inds)
- (node_feats, adjacent_matrices, knn_inds, gt_linkage) = \
- self.generate_gcn_input(node_feat_batch,
- node_label_batch,
- local_graph_batch,
- knn_batch,
- sorted_dist_inds_batch)
+ (node_feats, adjacent_matrices, knn_inds, gt_linkage) = self.generate_gcn_input(
+ node_feat_batch,
+ node_label_batch,
+ local_graph_batch,
+ knn_batch,
+ sorted_dist_inds_batch,
+ )
return node_feats, adjacent_matrices, knn_inds, gt_linkage
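
feature_embedding turns raw geometric attributes into sinusoidal codes for the GCN. A numpy sketch of the residue-free branch (out_feat_len divisible by the attribute dimension), mirroring the source's slicing exactly, including the sin/cos split along the second axis:

import numpy as np

def sinusoidal_embedding(feats, out_feat_len):
    # residue-free branch of feature_embedding: out_feat_len % D == 0
    num_nodes, feat_dim = feats.shape
    repeats = out_feat_len // feat_dim
    wave = np.array(
        [np.power(1000, 2.0 * (j // 2) / repeats) for j in range(repeats)]
    ).reshape((repeats, 1, 1))
    embedded = np.repeat(feats[None], repeats, axis=0) / wave
    # sin on even slices, cos on odd slices, exactly as in the source
    embedded[:, 0::2] = np.sin(embedded[:, 0::2])
    embedded[:, 1::2] = np.cos(embedded[:, 1::2])
    out = np.transpose(embedded, (1, 0, 2)).reshape((num_nodes, -1))
    return out.astype(np.float32)

print(sinusoidal_embedding(np.random.rand(4, 6), 120).shape)  # (4, 120)
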
diff --git a/ppocr/modeling/heads/proposal_local_graph.py b/ppocr/modeling/heads/proposal_local_graph.py
index a48656135b..7bf0765dda 100644
--- a/ppocr/modeling/heads/proposal_local_graph.py
+++ b/ppocr/modeling/heads/proposal_local_graph.py
@@ -28,29 +28,44 @@
from lanms import merge_quadrangle_n9 as la_nms
from ppocr.ext_op import RoIAlignRotated
-from .local_graph import (euclidean_distance_matrix, feature_embedding,
- normalize_adjacent_matrix)
+from .local_graph import (
+ euclidean_distance_matrix,
+ feature_embedding,
+ normalize_adjacent_matrix,
+)
def fill_hole(input_mask):
h, w = input_mask.shape
canvas = np.zeros((h + 2, w + 2), np.uint8)
- canvas[1:h + 1, 1:w + 1] = input_mask.copy()
+ canvas[1 : h + 1, 1 : w + 1] = input_mask.copy()
mask = np.zeros((h + 4, w + 4), np.uint8)
cv2.floodFill(canvas, mask, (0, 0), 1)
- canvas = canvas[1:h + 1, 1:w + 1].astype(np.bool_)
+ canvas = canvas[1 : h + 1, 1 : w + 1].astype(np.bool_)
return ~canvas | input_mask
class ProposalLocalGraphs:
- def __init__(self, k_at_hops, num_adjacent_linkages, node_geo_feat_len,
- pooling_scale, pooling_output_size, nms_thr, min_width,
- max_width, comp_shrink_ratio, comp_w_h_ratio, comp_score_thr,
- text_region_thr, center_region_thr, center_region_area_thr):
-
+ def __init__(
+ self,
+ k_at_hops,
+ num_adjacent_linkages,
+ node_geo_feat_len,
+ pooling_scale,
+ pooling_output_size,
+ nms_thr,
+ min_width,
+ max_width,
+ comp_shrink_ratio,
+ comp_w_h_ratio,
+ comp_score_thr,
+ text_region_thr,
+ center_region_thr,
+ center_region_area_thr,
+ ):
assert len(k_at_hops) == 2
assert isinstance(k_at_hops, tuple)
assert isinstance(num_adjacent_linkages, int)
@@ -82,9 +97,19 @@ def __init__(self, k_at_hops, num_adjacent_linkages, node_geo_feat_len,
self.center_region_thr = center_region_thr
self.center_region_area_thr = center_region_area_thr
- def propose_comps(self, score_map, top_height_map, bot_height_map, sin_map,
- cos_map, comp_score_thr, min_width, max_width,
- comp_shrink_ratio, comp_w_h_ratio):
+ def propose_comps(
+ self,
+ score_map,
+ top_height_map,
+ bot_height_map,
+ sin_map,
+ cos_map,
+ comp_score_thr,
+ min_width,
+ max_width,
+ comp_shrink_ratio,
+ comp_w_h_ratio,
+ ):
"""Propose text components.
Args:
@@ -116,10 +141,8 @@ def propose_comps(self, score_map, top_height_map, bot_height_map, sin_map,
sin = sin_map[y, x].reshape((-1, 1))
cos = cos_map[y, x].reshape((-1, 1))
- top_mid_pts = comp_centers + np.hstack(
- [top_height * sin, top_height * cos])
- bot_mid_pts = comp_centers - np.hstack(
- [bot_height * sin, bot_height * cos])
+ top_mid_pts = comp_centers + np.hstack([top_height * sin, top_height * cos])
+ bot_mid_pts = comp_centers - np.hstack([bot_height * sin, bot_height * cos])
width = (top_height + bot_height) * comp_w_h_ratio
width = np.clip(width, min_width, max_width)
@@ -136,9 +159,15 @@ def propose_comps(self, score_map, top_height_map, bot_height_map, sin_map,
return text_comps
- def propose_comps_and_attribs(self, text_region_map, center_region_map,
- top_height_map, bot_height_map, sin_map,
- cos_map):
+ def propose_comps_and_attribs(
+ self,
+ text_region_map,
+ center_region_map,
+ top_height_map,
+ bot_height_map,
+ sin_map,
+ cos_map,
+ ):
"""Generate text components and attributes.
Args:
@@ -158,20 +187,24 @@ def propose_comps_and_attribs(self, text_region_map, center_region_map,
text_comps (ndarray): The text components.
"""
- assert (text_region_map.shape == center_region_map.shape ==
- top_height_map.shape == bot_height_map.shape == sin_map.shape ==
- cos_map.shape)
+ assert (
+ text_region_map.shape
+ == center_region_map.shape
+ == top_height_map.shape
+ == bot_height_map.shape
+ == sin_map.shape
+ == cos_map.shape
+ )
text_mask = text_region_map > self.text_region_thr
- center_region_mask = (
- center_region_map > self.center_region_thr) * text_mask
+ center_region_mask = (center_region_map > self.center_region_thr) * text_mask
scale = np.sqrt(1.0 / (sin_map**2 + cos_map**2 + 1e-8))
sin_map, cos_map = sin_map * scale, cos_map * scale
center_region_mask = fill_hole(center_region_mask)
center_region_contours, _ = cv2.findContours(
- center_region_mask.astype(np.uint8), cv2.RETR_TREE,
- cv2.CHAIN_APPROX_SIMPLE)
+ center_region_mask.astype(np.uint8), cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE
+ )
mask_sz = center_region_map.shape
comp_list = []
@@ -183,14 +216,21 @@ def propose_comps_and_attribs(self, text_region_map, center_region_map,
score_map = text_region_map * current_center_mask
text_comps = self.propose_comps(
- score_map, top_height_map, bot_height_map, sin_map, cos_map,
- self.comp_score_thr, self.min_width, self.max_width,
- self.comp_shrink_ratio, self.comp_w_h_ratio)
+ score_map,
+ top_height_map,
+ bot_height_map,
+ sin_map,
+ cos_map,
+ self.comp_score_thr,
+ self.min_width,
+ self.max_width,
+ self.comp_shrink_ratio,
+ self.comp_w_h_ratio,
+ )
text_comps = la_nms(text_comps, self.nms_thr)
text_comp_mask = np.zeros(mask_sz)
- text_comp_boxes = text_comps[:, :8].reshape(
- (-1, 4, 2)).astype(np.int32)
+ text_comp_boxes = text_comps[:, :8].reshape((-1, 4, 2)).astype(np.int32)
cv2.drawContours(text_comp_mask, text_comp_boxes, -1, 1, -1)
if (text_comp_mask * text_mask).sum() < text_comp_mask.sum() * 0.5:
@@ -209,25 +249,25 @@ def propose_comps_and_attribs(self, text_region_map, center_region_map,
scores = []
for text_comp_box in text_comp_boxes:
- text_comp_box[:, 0] = np.clip(text_comp_box[:, 0], 0,
- mask_sz[1] - 1)
- text_comp_box[:, 1] = np.clip(text_comp_box[:, 1], 0,
- mask_sz[0] - 1)
+ text_comp_box[:, 0] = np.clip(text_comp_box[:, 0], 0, mask_sz[1] - 1)
+ text_comp_box[:, 1] = np.clip(text_comp_box[:, 1], 0, mask_sz[0] - 1)
min_coord = np.min(text_comp_box, axis=0).astype(np.int32)
max_coord = np.max(text_comp_box, axis=0).astype(np.int32)
text_comp_box = text_comp_box - min_coord
- box_sz = (max_coord - min_coord + 1)
+ box_sz = max_coord - min_coord + 1
temp_comp_mask = np.zeros((box_sz[1], box_sz[0]), dtype=np.uint8)
cv2.fillPoly(temp_comp_mask, [text_comp_box.astype(np.int32)], 1)
- temp_region_patch = text_region_map[min_coord[1]:(max_coord[1] + 1),
- min_coord[0]:(max_coord[0] + 1)]
+ temp_region_patch = text_region_map[
+ min_coord[1] : (max_coord[1] + 1), min_coord[0] : (max_coord[0] + 1)
+ ]
score = cv2.mean(temp_region_patch, temp_comp_mask)[0]
scores.append(score)
scores = np.array(scores).reshape((-1, 1))
text_comps = np.hstack([text_comps[:, :-1], scores])
- h = top_height_map[y, x].reshape(
- (-1, 1)) + bot_height_map[y, x].reshape((-1, 1))
+ h = top_height_map[y, x].reshape((-1, 1)) + bot_height_map[y, x].reshape(
+ (-1, 1)
+ )
w = np.clip(h * self.comp_w_h_ratio, self.min_width, self.max_width)
sin = sin_map[y, x].reshape((-1, 1))
cos = cos_map[y, x].reshape((-1, 1))
@@ -257,21 +297,23 @@ def generate_local_graphs(self, sorted_dist_inds, node_feats):
"""
assert sorted_dist_inds.ndim == 2
- assert (sorted_dist_inds.shape[0] == sorted_dist_inds.shape[1] ==
- node_feats.shape[0])
+ assert (
+ sorted_dist_inds.shape[0]
+ == sorted_dist_inds.shape[1]
+ == node_feats.shape[0]
+ )
- knn_graph = sorted_dist_inds[:, 1:self.k_at_hops[0] + 1]
+ knn_graph = sorted_dist_inds[:, 1 : self.k_at_hops[0] + 1]
pivot_local_graphs = []
pivot_knns = []
for pivot_ind, knn in enumerate(knn_graph):
-
local_graph_neighbors = set(knn)
for neighbor_ind in knn:
local_graph_neighbors.update(
- set(sorted_dist_inds[neighbor_ind, 1:self.k_at_hops[1] +
- 1]))
+ set(sorted_dist_inds[neighbor_ind, 1 : self.k_at_hops[1] + 1])
+ )
local_graph_neighbors.discard(pivot_ind)
pivot_local_graph = list(local_graph_neighbors)
@@ -281,9 +323,9 @@ def generate_local_graphs(self, sorted_dist_inds, node_feats):
pivot_local_graphs.append(pivot_local_graph)
pivot_knns.append(pivot_knn)
- num_max_nodes = max([
- len(pivot_local_graph) for pivot_local_graph in pivot_local_graphs
- ])
+ num_max_nodes = max(
+ [len(pivot_local_graph) for pivot_local_graph in pivot_local_graphs]
+ )
local_graphs_node_feat = []
adjacent_matrices = []
@@ -297,42 +339,47 @@ def generate_local_graphs(self, sorted_dist_inds, node_feats):
node2ind_map = {j: i for i, j in enumerate(pivot_local_graph)}
knn_inds = paddle.cast(
- paddle.to_tensor([node2ind_map[i]
- for i in pivot_knn[1:]]), 'int64')
+ paddle.to_tensor([node2ind_map[i] for i in pivot_knn[1:]]), "int64"
+ )
pivot_feats = node_feats[pivot_ind]
- normalized_feats = node_feats[paddle.to_tensor(
- pivot_local_graph)] - pivot_feats
+ normalized_feats = (
+ node_feats[paddle.to_tensor(pivot_local_graph)] - pivot_feats
+ )
adjacent_matrix = np.zeros((num_nodes, num_nodes), dtype=np.float32)
for node in pivot_local_graph:
- neighbors = sorted_dist_inds[node, 1:self.active_connection + 1]
+ neighbors = sorted_dist_inds[node, 1 : self.active_connection + 1]
for neighbor in neighbors:
if neighbor in pivot_local_graph:
- adjacent_matrix[node2ind_map[node], node2ind_map[
- neighbor]] = 1
- adjacent_matrix[node2ind_map[neighbor], node2ind_map[
- node]] = 1
+ adjacent_matrix[node2ind_map[node], node2ind_map[neighbor]] = 1
+ adjacent_matrix[node2ind_map[neighbor], node2ind_map[node]] = 1
adjacent_matrix = normalize_adjacent_matrix(adjacent_matrix)
- pad_adjacent_matrix = paddle.zeros((num_max_nodes, num_max_nodes), )
+ pad_adjacent_matrix = paddle.zeros(
+ (num_max_nodes, num_max_nodes),
+ )
pad_adjacent_matrix[:num_nodes, :num_nodes] = paddle.cast(
- paddle.to_tensor(adjacent_matrix), 'float32')
+ paddle.to_tensor(adjacent_matrix), "float32"
+ )
pad_normalized_feats = paddle.concat(
[
- normalized_feats, paddle.zeros(
+ normalized_feats,
+ paddle.zeros(
(num_max_nodes - num_nodes, normalized_feats.shape[1]),
- )
+ ),
],
- axis=0)
+ axis=0,
+ )
local_graph_nodes = paddle.to_tensor(pivot_local_graph)
local_graph_nodes = paddle.concat(
[
- local_graph_nodes, paddle.zeros(
- [num_max_nodes - num_nodes], dtype='int64')
+ local_graph_nodes,
+ paddle.zeros([num_max_nodes - num_nodes], dtype="int64"),
],
- axis=-1)
+ axis=-1,
+ )
local_graphs_node_feat.append(pad_normalized_feats)
adjacent_matrices.append(pad_adjacent_matrix)
@@ -344,8 +391,12 @@ def generate_local_graphs(self, sorted_dist_inds, node_feats):
pivots_knn_inds = paddle.stack(pivots_knn_inds, 0)
pivots_local_graphs = paddle.stack(pivots_local_graphs, 0)
- return (local_graphs_node_feat, adjacent_matrices, pivots_knn_inds,
- pivots_local_graphs)
+ return (
+ local_graphs_node_feat,
+ adjacent_matrices,
+ pivots_knn_inds,
+ pivots_local_graphs,
+ )
def __call__(self, preds, feat_maps):
"""Generate local graphs and graph convolutional network input data.
@@ -378,8 +429,13 @@ def __call__(self, preds, feat_maps):
pred_bot_height_map = preds[5].numpy()
comp_attribs, text_comps = self.propose_comps_and_attribs(
- pred_text_region, pred_center_region, pred_top_height_map,
- pred_bot_height_map, pred_sin_map, pred_cos_map)
+ pred_text_region,
+ pred_center_region,
+ pred_top_height_map,
+ pred_bot_height_map,
+ pred_sin_map,
+ pred_cos_map,
+ )
if comp_attribs is None or len(comp_attribs) < 2:
none_flag = True
@@ -403,10 +459,18 @@ def __call__(self, preds, feat_maps):
node_feats = paddle.concat([content_feats, geo_feats], axis=-1)
sorted_dist_inds = np.argsort(distance_matrix, axis=1)
- (local_graphs_node_feat, adjacent_matrices, pivots_knn_inds,
- pivots_local_graphs) = self.generate_local_graphs(sorted_dist_inds,
- node_feats)
+ (
+ local_graphs_node_feat,
+ adjacent_matrices,
+ pivots_knn_inds,
+ pivots_local_graphs,
+ ) = self.generate_local_graphs(sorted_dist_inds, node_feats)
none_flag = False
- return none_flag, (local_graphs_node_feat, adjacent_matrices,
- pivots_knn_inds, pivots_local_graphs, text_comps)
+ return none_flag, (
+ local_graphs_node_feat,
+ adjacent_matrices,
+ pivots_knn_inds,
+ pivots_local_graphs,
+ text_comps,
+ )
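
fill_hole above flood-fills the background from the border and keeps whatever the flood cannot reach, closing enclosed holes before contour extraction. A self-contained demo (the hollow-square input is illustrative):

import cv2
import numpy as np

def fill_hole(input_mask):
    # identical to the reformatted helper above
    h, w = input_mask.shape
    canvas = np.zeros((h + 2, w + 2), np.uint8)
    canvas[1 : h + 1, 1 : w + 1] = input_mask.copy()
    mask = np.zeros((h + 4, w + 4), np.uint8)
    cv2.floodFill(canvas, mask, (0, 0), 1)
    canvas = canvas[1 : h + 1, 1 : w + 1].astype(np.bool_)
    return ~canvas | input_mask

ring = np.zeros((7, 7), np.uint8)
cv2.rectangle(ring, (1, 1), (5, 5), 1)     # hollow 5x5 outline
print(fill_hole(ring.astype(bool)).sum())  # 25: the interior gets filled
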
diff --git a/ppocr/modeling/heads/rec_abinet_head.py b/ppocr/modeling/heads/rec_abinet_head.py
index a95f2f1164..d666846c1c 100644
--- a/ppocr/modeling/heads/rec_abinet_head.py
+++ b/ppocr/modeling/heads/rec_abinet_head.py
@@ -25,15 +25,17 @@
class BCNLanguage(nn.Layer):
- def __init__(self,
- d_model=512,
- nhead=8,
- num_layers=4,
- dim_feedforward=2048,
- dropout=0.,
- max_length=25,
- detach=True,
- num_classes=37):
+ def __init__(
+ self,
+ d_model=512,
+ nhead=8,
+ num_layers=4,
+ dim_feedforward=2048,
+ dropout=0.0,
+ max_length=25,
+ detach=True,
+ num_classes=37,
+ ):
super().__init__()
self.d_model = d_model
@@ -41,20 +43,26 @@ def __init__(self,
self.max_length = max_length + 1 # additional stop token
self.proj = nn.Linear(num_classes, d_model, bias_attr=False)
self.token_encoder = PositionalEncoding(
- dropout=0.1, dim=d_model, max_len=self.max_length)
+ dropout=0.1, dim=d_model, max_len=self.max_length
+ )
self.pos_encoder = PositionalEncoding(
- dropout=0, dim=d_model, max_len=self.max_length)
+ dropout=0, dim=d_model, max_len=self.max_length
+ )
- self.decoder = nn.LayerList([
- TransformerBlock(
- d_model=d_model,
- nhead=nhead,
- dim_feedforward=dim_feedforward,
- attention_dropout_rate=dropout,
- residual_dropout_rate=dropout,
- with_self_attn=False,
- with_cross_attn=True) for i in range(num_layers)
- ])
+ self.decoder = nn.LayerList(
+ [
+ TransformerBlock(
+ d_model=d_model,
+ nhead=nhead,
+ dim_feedforward=dim_feedforward,
+ attention_dropout_rate=dropout,
+ residual_dropout_rate=dropout,
+ with_self_attn=False,
+ with_cross_attn=True,
+ )
+ for i in range(num_layers)
+ ]
+ )
self.cls = nn.Linear(d_model, num_classes)
@@ -64,7 +72,8 @@ def forward(self, tokens, lengths):
tokens: (B, N, C) where N is length, B is batch size and C is classes number
lengths: (B,)
"""
- if self.detach: tokens = tokens.detach()
+ if self.detach:
+ tokens = tokens.detach()
embed = self.proj(tokens) # (B, N, C)
embed = self.token_encoder(embed) # (B, N, C)
padding_mask = _get_mask(lengths, self.max_length)
@@ -81,61 +90,53 @@ def forward(self, tokens, lengths):
def encoder_layer(in_c, out_c, k=3, s=2, p=1):
return nn.Sequential(
- nn.Conv2D(in_c, out_c, k, s, p), nn.BatchNorm2D(out_c), nn.ReLU())
+ nn.Conv2D(in_c, out_c, k, s, p), nn.BatchNorm2D(out_c), nn.ReLU()
+ )
-def decoder_layer(in_c,
- out_c,
- k=3,
- s=1,
- p=1,
- mode='nearest',
- scale_factor=None,
- size=None):
- align_corners = False if mode == 'nearest' else True
+def decoder_layer(
+ in_c, out_c, k=3, s=1, p=1, mode="nearest", scale_factor=None, size=None
+):
+ align_corners = False if mode == "nearest" else True
return nn.Sequential(
nn.Upsample(
- size=size,
- scale_factor=scale_factor,
- mode=mode,
- align_corners=align_corners),
+ size=size, scale_factor=scale_factor, mode=mode, align_corners=align_corners
+ ),
nn.Conv2D(in_c, out_c, k, s, p),
nn.BatchNorm2D(out_c),
- nn.ReLU())
+ nn.ReLU(),
+ )
class PositionAttention(nn.Layer):
- def __init__(self,
- max_length,
- in_channels=512,
- num_channels=64,
- h=8,
- w=32,
- mode='nearest',
- **kwargs):
+ def __init__(
+ self,
+ max_length,
+ in_channels=512,
+ num_channels=64,
+ h=8,
+ w=32,
+ mode="nearest",
+ **kwargs
+ ):
super().__init__()
self.max_length = max_length
self.k_encoder = nn.Sequential(
- encoder_layer(
- in_channels, num_channels, s=(1, 2)),
- encoder_layer(
- num_channels, num_channels, s=(2, 2)),
- encoder_layer(
- num_channels, num_channels, s=(2, 2)),
- encoder_layer(
- num_channels, num_channels, s=(2, 2)))
+ encoder_layer(in_channels, num_channels, s=(1, 2)),
+ encoder_layer(num_channels, num_channels, s=(2, 2)),
+ encoder_layer(num_channels, num_channels, s=(2, 2)),
+ encoder_layer(num_channels, num_channels, s=(2, 2)),
+ )
self.k_decoder = nn.Sequential(
- decoder_layer(
- num_channels, num_channels, scale_factor=2, mode=mode),
- decoder_layer(
- num_channels, num_channels, scale_factor=2, mode=mode),
- decoder_layer(
- num_channels, num_channels, scale_factor=2, mode=mode),
- decoder_layer(
- num_channels, in_channels, size=(h, w), mode=mode))
+ decoder_layer(num_channels, num_channels, scale_factor=2, mode=mode),
+ decoder_layer(num_channels, num_channels, scale_factor=2, mode=mode),
+ decoder_layer(num_channels, num_channels, scale_factor=2, mode=mode),
+ decoder_layer(num_channels, in_channels, size=(h, w), mode=mode),
+ )
self.pos_encoder = PositionalEncoding(
- dropout=0, dim=in_channels, max_len=max_length)
+ dropout=0, dim=in_channels, max_len=max_length
+ )
self.project = nn.Linear(in_channels, in_channels)
def forward(self, x):
@@ -155,53 +156,57 @@ def forward(self, x):
# calculate query vector
# TODO q=f(q,k)
- zeros = paddle.zeros(
- (B, self.max_length, C), dtype=x.dtype) # (B, N, C)
+ zeros = paddle.zeros((B, self.max_length, C), dtype=x.dtype) # (B, N, C)
q = self.pos_encoder(zeros) # (B, N, C)
q = self.project(q) # (B, N, C)
# calculate attention
- attn_scores = q @k.flatten(2) # (B, N, (H*W))
+ attn_scores = q @ k.flatten(2) # (B, N, (H*W))
attn_scores = attn_scores / (C**0.5)
attn_scores = F.softmax(attn_scores, axis=-1)
v = v.flatten(2).transpose([0, 2, 1]) # (B, (H*W), C)
- attn_vecs = attn_scores @v # (B, N, C)
+ attn_vecs = attn_scores @ v # (B, N, C)
return attn_vecs, attn_scores.reshape([0, self.max_length, H, W])
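+
+# Attention arithmetic, for reference: q is (B, N, C) and k.flatten(2) is
+# (B, C, H*W), so q @ k.flatten(2) scores every character position against
+# every spatial location; after the C**0.5 scaling and a softmax over H*W,
+# attn_scores @ v (with v of shape (B, H*W, C)) collapses the feature map
+# into one C-dim vector per character position.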
class ABINetHead(nn.Layer):
- def __init__(self,
- in_channels,
- out_channels,
- d_model=512,
- nhead=8,
- num_layers=3,
- dim_feedforward=2048,
- dropout=0.1,
- max_length=25,
- use_lang=False,
- iter_size=1,
- image_size=(32, 128)):
+ def __init__(
+ self,
+ in_channels,
+ out_channels,
+ d_model=512,
+ nhead=8,
+ num_layers=3,
+ dim_feedforward=2048,
+ dropout=0.1,
+ max_length=25,
+ use_lang=False,
+ iter_size=1,
+ image_size=(32, 128),
+ ):
super().__init__()
self.max_length = max_length + 1
h, w = image_size[0] // 4, image_size[1] // 4
- self.pos_encoder = PositionalEncoding(
- dropout=0.1, dim=d_model, max_len=h * w)
- self.encoder = nn.LayerList([
- TransformerBlock(
- d_model=d_model,
- nhead=nhead,
- dim_feedforward=dim_feedforward,
- attention_dropout_rate=dropout,
- residual_dropout_rate=dropout,
- with_self_attn=True,
- with_cross_attn=False) for i in range(num_layers)
- ])
+ self.pos_encoder = PositionalEncoding(dropout=0.1, dim=d_model, max_len=h * w)
+ self.encoder = nn.LayerList(
+ [
+ TransformerBlock(
+ d_model=d_model,
+ nhead=nhead,
+ dim_feedforward=dim_feedforward,
+ attention_dropout_rate=dropout,
+ residual_dropout_rate=dropout,
+ with_self_attn=True,
+ with_cross_attn=False,
+ )
+ for i in range(num_layers)
+ ]
+ )
self.decoder = PositionAttention(
- max_length=max_length + 1, # additional stop token
- mode='nearest', h=h, w=w)
+ max_length=max_length + 1, mode="nearest", h=h, w=w # additional stop token
+ )
self.out_channels = out_channels
self.cls = nn.Linear(d_model, self.out_channels)
self.use_lang = use_lang
@@ -214,7 +219,8 @@ def __init__(self,
dim_feedforward=dim_feedforward,
dropout=dropout,
max_length=max_length,
- num_classes=self.out_channels)
+ num_classes=self.out_channels,
+ )
# alignment
self.w_att_align = nn.Linear(2 * d_model, d_model)
self.cls_align = nn.Linear(d_model, self.out_channels)
@@ -227,8 +233,7 @@ def forward(self, x, targets=None):
for encoder_layer in self.encoder:
feature = encoder_layer(feature)
feature = feature.reshape([0, H, W, C]).transpose([0, 3, 1, 2])
- v_feature, attn_scores = self.decoder(
- feature) # (B, N, C), (B, C, H, W)
+ v_feature, attn_scores = self.decoder(feature) # (B, N, C), (B, C, H, W)
vis_logits = self.cls(v_feature) # (B, N, C)
logits = vis_logits
vis_lengths = _get_length(vis_logits)
@@ -240,7 +245,8 @@ def forward(self, x, targets=None):
tokens = F.softmax(align_logits, axis=-1)
lengths = align_lengths
lengths = paddle.clip(
- lengths, 2, self.max_length) # TODO:move to langauge model
+ lengths, 2, self.max_length
+ ) # TODO: move to language model
l_feature, l_logits = self.language(tokens, lengths)
# alignment
@@ -253,11 +259,7 @@ def forward(self, x, targets=None):
align_lengths = _get_length(align_logits)
all_a_res.append(align_logits)
if self.training:
- return {
- 'align': all_a_res,
- 'lang': all_l_res,
- 'vision': vis_logits
- }
+ return {"align": all_a_res, "lang": all_l_res, "vision": vis_logits}
else:
logits = align_logits
if self.training:
@@ -267,12 +269,12 @@ def forward(self, x, targets=None):
def _get_length(logit):
- """ Greed decoder to obtain length from logit"""
- out = (logit.argmax(-1) == 0)
+ """Greed decoder to obtain length from logit"""
+ out = logit.argmax(-1) == 0
abn = out.any(-1)
- out_int = out.cast('int32')
+ out_int = out.cast("int32")
out = (out_int.cumsum(-1) == 1) & out
- out = out.cast('int32')
+ out = out.cast("int32")
out = out.argmax(-1)
out = out + 1
len_seq = paddle.zeros_like(out) + logit.shape[1]
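+ # Worked example with hypothetical logits: if argmax(-1) over a sequence is
+ # [5, 9, 0, 7, 0], then out flags EOS hits as [0, 0, 1, 0, 1], the
+ # cumsum == 1 trick keeps only the first hit [0, 0, 1, 0, 0], argmax gives
+ # index 2, and +1 yields length 3 (the characters plus the EOS itself).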
@@ -282,18 +284,16 @@ def _get_length(logit):
def _get_mask(length, max_length):
"""Generate a square mask for the sequence. The masked positions are filled with float('-inf').
- Unmasked positions are filled with float(0.0).
+ Unmasked positions are filled with float(0.0).
"""
length = length.unsqueeze(-1)
B = length.shape[0]
grid = paddle.arange(0, max_length).unsqueeze(0).tile([B, 1])
- zero_mask = paddle.zeros([B, max_length], dtype='float32')
- inf_mask = paddle.full([B, max_length], '-inf', dtype='float32')
+ zero_mask = paddle.zeros([B, max_length], dtype="float32")
+ inf_mask = paddle.full([B, max_length], "-inf", dtype="float32")
diag_mask = paddle.diag(
- paddle.full(
- [max_length], '-inf', dtype=paddle.float32),
- offset=0,
- name=None)
+ paddle.full([max_length], "-inf", dtype=paddle.float32), offset=0, name=None
+ )
mask = paddle.where(grid >= length, inf_mask, zero_mask)
mask = mask.unsqueeze(1) + diag_mask
return mask.unsqueeze(1)
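+
+# Worked example (hypothetical): for length=3 and max_length=5, where() fills
+# a row as [0, 0, 0, -inf, -inf] (padding positions are blocked), and adding
+# diag_mask also puts -inf at position i of row i, so token i can attend to
+# every other unpadded token but never to itself.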
diff --git a/ppocr/modeling/heads/rec_aster_head.py b/ppocr/modeling/heads/rec_aster_head.py
index 4e36e26ef3..ba0acaeebe 100644
--- a/ppocr/modeling/heads/rec_aster_head.py
+++ b/ppocr/modeling/heads/rec_aster_head.py
@@ -27,23 +27,26 @@
class AsterHead(nn.Layer):
- def __init__(self,
- in_channels,
- out_channels,
- sDim,
- attDim,
- max_len_labels,
- time_step=25,
- beam_width=5,
- **kwargs):
+ def __init__(
+ self,
+ in_channels,
+ out_channels,
+ sDim,
+ attDim,
+ max_len_labels,
+ time_step=25,
+ beam_width=5,
+ **kwargs
+ ):
super(AsterHead, self).__init__()
self.num_classes = out_channels
self.in_planes = in_channels
self.sDim = sDim
self.attDim = attDim
self.max_len_labels = max_len_labels
- self.decoder = AttentionRecognitionHead(in_channels, out_channels, sDim,
- attDim, max_len_labels)
+ self.decoder = AttentionRecognitionHead(
+ in_channels, out_channels, sDim, attDim, max_len_labels
+ )
self.time_step = time_step
self.embeder = Embedding(self.time_step, in_channels)
self.beam_width = beam_width
@@ -55,16 +58,16 @@ def forward(self, x, targets=None, embed=None):
if self.training:
rec_targets, rec_lengths, _ = targets
- rec_pred = self.decoder([x, rec_targets, rec_lengths],
- embedding_vectors)
- return_dict['rec_pred'] = rec_pred
- return_dict['embedding_vectors'] = embedding_vectors
+ rec_pred = self.decoder([x, rec_targets, rec_lengths], embedding_vectors)
+ return_dict["rec_pred"] = rec_pred
+ return_dict["embedding_vectors"] = embedding_vectors
else:
rec_pred, rec_pred_scores = self.decoder.beam_search(
- x, self.beam_width, self.eos, embedding_vectors)
- return_dict['rec_pred'] = rec_pred
- return_dict['rec_pred_scores'] = rec_pred_scores
- return_dict['embedding_vectors'] = embedding_vectors
+ x, self.beam_width, self.eos, embedding_vectors
+ )
+ return_dict["rec_pred"] = rec_pred
+ return_dict["rec_pred_scores"] = rec_pred_scores
+ return_dict["embedding_vectors"] = embedding_vectors
return return_dict
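+
+ # Contract of the branches above: in training the decoder is teacher-forced
+ # with rec_targets and "rec_pred" holds raw logits; in eval beam_search runs
+ # with the configured beam_width and "rec_pred" holds decoded id sequences,
+ # with per-sequence scores under "rec_pred_scores".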
@@ -77,8 +80,8 @@ def __init__(self, in_timestep, in_planes, mid_dim=4096, embed_dim=300):
self.embed_dim = embed_dim
self.mid_dim = mid_dim
self.eEmbed = nn.Linear(
- in_timestep * in_planes,
- self.embed_dim) # Embed encoder output to a word-embedding like
+ in_timestep * in_planes, self.embed_dim
+ ) # Embed encoder output to a word-embedding-like vector
def forward(self, x):
x = paddle.reshape(x, [x.shape[0], -1])
@@ -88,20 +91,23 @@ def forward(self, x):
class AttentionRecognitionHead(nn.Layer):
"""
- input: [b x 16 x 64 x in_planes]
- output: probability sequence: [b x T x num_classes]
- """
+ input: [b x 16 x 64 x in_planes]
+ output: probability sequence: [b x T x num_classes]
+ """
def __init__(self, in_channels, out_channels, sDim, attDim, max_len_labels):
super(AttentionRecognitionHead, self).__init__()
- self.num_classes = out_channels # this is the output classes. So it includes the .
+ self.num_classes = (
+ out_channels # this is the output classes. So it includes the <EOS>.
+ )
self.in_planes = in_channels
self.sDim = sDim
self.attDim = attDim
self.max_len_labels = max_len_labels
self.decoder = DecoderUnit(
- sDim=sDim, xDim=in_channels, yDim=self.num_classes, attDim=attDim)
+ sDim=sDim, xDim=in_channels, yDim=self.num_classes, attDim=attDim
+ )
def forward(self, x, embed):
x, targets, lengths = x
@@ -111,8 +117,7 @@ def forward(self, x, embed):
outputs = []
for i in range(max(lengths)):
if i == 0:
- y_prev = paddle.full(
- shape=[batch_size], fill_value=self.num_classes)
+ y_prev = paddle.full(shape=[batch_size], fill_value=self.num_classes)
else:
y_prev = targets[:, i - 1]
output, state = self.decoder(x, state, y_prev)
@@ -130,8 +135,7 @@ def sample(self, x):
predicted_ids, predicted_scores = [], []
for i in range(self.max_len_labels):
if i == 0:
- y_prev = paddle.full(
- shape=[batch_size], fill_value=self.num_classes)
+ y_prev = paddle.full(shape=[batch_size], fill_value=self.num_classes)
else:
y_prev = predicted
@@ -155,27 +159,30 @@ def _inflate(tensor, times, dim):
# https://github.com/IBM/pytorch-seq2seq/blob/fede87655ddce6c94b38886089e05321dc9802af/seq2seq/models/TopKDecoder.py
batch_size, l, d = x.shape
x = paddle.tile(
- paddle.transpose(
- x.unsqueeze(1), perm=[1, 0, 2, 3]), [beam_width, 1, 1, 1])
+ paddle.transpose(x.unsqueeze(1), perm=[1, 0, 2, 3]), [beam_width, 1, 1, 1]
+ )
inflated_encoder_feats = paddle.reshape(
- paddle.transpose(
- x, perm=[1, 0, 2, 3]), [-1, l, d])
+ paddle.transpose(x, perm=[1, 0, 2, 3]), [-1, l, d]
+ )
# Initialize the decoder
state = self.decoder.get_initial_state(embed, tile_times=beam_width)
pos_index = paddle.reshape(
- paddle.arange(batch_size) * beam_width, shape=[-1, 1])
+ paddle.arange(batch_size) * beam_width, shape=[-1, 1]
+ )
# Initialize the scores
sequence_scores = paddle.full(
- shape=[batch_size * beam_width, 1], fill_value=-float('Inf'))
+ shape=[batch_size * beam_width, 1], fill_value=-float("Inf")
+ )
index = [i * beam_width for i in range(0, batch_size)]
sequence_scores[index] = 0.0
# Initialize the input vector
y_prev = paddle.full(
- shape=[batch_size * beam_width], fill_value=self.num_classes)
+ shape=[batch_size * beam_width], fill_value=self.num_classes
+ )
# Store decisions for backtracking
stored_scores = list()
@@ -185,30 +192,29 @@ def _inflate(tensor, times, dim):
for i in range(self.max_len_labels):
output, state = self.decoder(inflated_encoder_feats, state, y_prev)
state = paddle.unsqueeze(state, axis=0)
- log_softmax_output = paddle.nn.functional.log_softmax(
- output, axis=1)
+ log_softmax_output = paddle.nn.functional.log_softmax(output, axis=1)
sequence_scores = _inflate(sequence_scores, self.num_classes, 1)
sequence_scores += log_softmax_output
scores, candidates = paddle.topk(
- paddle.reshape(sequence_scores, [batch_size, -1]),
- beam_width,
- axis=1)
+ paddle.reshape(sequence_scores, [batch_size, -1]), beam_width, axis=1
+ )
# Reshape input = (bk, 1) and sequence_scores = (bk, 1)
y_prev = paddle.reshape(
- candidates % self.num_classes, shape=[batch_size * beam_width])
- sequence_scores = paddle.reshape(
- scores, shape=[batch_size * beam_width, 1])
+ candidates % self.num_classes, shape=[batch_size * beam_width]
+ )
+ sequence_scores = paddle.reshape(scores, shape=[batch_size * beam_width, 1])
# Update fields for next timestep
pos_index = paddle.expand_as(pos_index, candidates)
predecessors = paddle.cast(
- candidates / self.num_classes + pos_index, dtype='int64')
+ candidates / self.num_classes + pos_index, dtype="int64"
+ )
predecessors = paddle.reshape(
- predecessors, shape=[batch_size * beam_width, 1])
- state = paddle.index_select(
- state, index=predecessors.squeeze(), axis=1)
+ predecessors, shape=[batch_size * beam_width, 1]
+ )
+ state = paddle.index_select(state, index=predecessors.squeeze(), axis=1)
# Update sequence scores and erase scores for the end-of-sequence symbol so that they aren't expanded
stored_scores.append(sequence_scores.clone())
@@ -219,7 +225,7 @@ def _inflate(tensor, times, dim):
if mask.dim() > 0:
sequence_scores = sequence_scores.numpy()
mask = mask.numpy()
- sequence_scores[mask] = -float('inf')
+ sequence_scores[mask] = -float("inf")
sequence_scores = paddle.to_tensor(sequence_scores)
# Cache results for backtracking
@@ -228,18 +234,19 @@ def _inflate(tensor, times, dim):
stored_emitted_symbols.append(y_prev)
# Do backtracking to return the optimal values
- #====== backtrak ======#
+ # ====== backtrack ======#
# Initialize return variables given different types
p = list()
- l = [[self.max_len_labels] * beam_width for _ in range(batch_size)
- ] # Placeholder for lengths of top-k sequences
+ l = [
+ [self.max_len_labels] * beam_width for _ in range(batch_size)
+ ] # Placeholder for lengths of top-k sequences
# the last step output of the beams is not sorted
# thus they are sorted here
sorted_score, sorted_idx = paddle.topk(
- paddle.reshape(
- stored_scores[-1], shape=[batch_size, beam_width]),
- beam_width)
+ paddle.reshape(stored_scores[-1], shape=[batch_size, beam_width]),
+ beam_width,
+ )
# initialize the sequence scores with the sorted last step beam scores
s = sorted_score.clone()
@@ -251,13 +258,16 @@ def _inflate(tensor, times, dim):
# add pos_index for indexing variable with b*k as the first dimension.
t_predecessors = paddle.reshape(
sorted_idx + pos_index.expand_as(sorted_idx),
- shape=[batch_size * beam_width])
+ shape=[batch_size * beam_width],
+ )
while t >= 0:
# Re-order the variables with the back pointer
current_symbol = paddle.index_select(
- stored_emitted_symbols[t], index=t_predecessors, axis=0)
+ stored_emitted_symbols[t], index=t_predecessors, axis=0
+ )
t_predecessors = paddle.index_select(
- stored_predecessors[t].squeeze(), index=t_predecessors, axis=0)
+ stored_predecessors[t].squeeze(), index=t_predecessors, axis=0
+ )
eos_indices = stored_emitted_symbols[t] == eos
eos_indices = paddle.nonzero(eos_indices)
@@ -270,8 +280,7 @@ def _inflate(tensor, times, dim):
b_idx = int(idx[0] / beam_width)
# The indices of the replacing position
# according to the replacement strategy noted above
- res_k_idx = beam_width - (batch_eos_found[b_idx] %
- beam_width) - 1
+ res_k_idx = beam_width - (batch_eos_found[b_idx] % beam_width) - 1
batch_eos_found[b_idx] += 1
res_idx = b_idx * beam_width + res_k_idx
@@ -290,20 +299,21 @@ def _inflate(tensor, times, dim):
# the order (very unlikely)
s, re_sorted_idx = s.topk(beam_width)
for b_idx in range(batch_size):
- l[b_idx] = [
- l[b_idx][k_idx.item()] for k_idx in re_sorted_idx[b_idx, :]
- ]
+ l[b_idx] = [l[b_idx][k_idx.item()] for k_idx in re_sorted_idx[b_idx, :]]
re_sorted_idx = paddle.reshape(
re_sorted_idx + pos_index.expand_as(re_sorted_idx),
- [batch_size * beam_width])
+ [batch_size * beam_width],
+ )
# Reverse the sequences and re-order at the same time
# It is reversed because the backtracking happens in reverse time order
p = [
paddle.reshape(
paddle.index_select(step, re_sorted_idx, 0),
- shape=[batch_size, beam_width, -1]) for step in reversed(p)
+ shape=[batch_size, beam_width, -1],
+ )
+ for step in reversed(p)
]
p = paddle.concat(p, -1)[:, 0, :]
return p, paddle.ones_like(p)
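+
+ # Beam bookkeeping, in brief: sequence_scores starts at -inf except the
+ # first slot of each instance's beam_width beams, so step one expands only a
+ # single beam per image; candidates index a flattened (beam_width *
+ # num_classes) table, so candidates % num_classes recovers the emitted token
+ # and candidates / num_classes (plus pos_index, cast to int64) the parent
+ # beam; the stored predecessors/symbols are then replayed in reverse to
+ # rebuild the best sequences.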
@@ -330,8 +340,7 @@ def forward(self, x, sPrev):
sPrev = sPrev.squeeze(0)
sProj = self.sEmbed(sPrev) # [b x attDim]
sProj = paddle.unsqueeze(sProj, 1) # [b x 1 x attDim]
- sProj = paddle.expand(sProj,
- [batch_size, T, self.attDim]) # [b x T x attDim]
+ sProj = paddle.expand(sProj, [batch_size, T, self.attDim]) # [b x T x attDim]
sumTanh = paddle.tanh(sProj + xProj)
sumTanh = paddle.reshape(sumTanh, [-1, self.attDim])
@@ -339,7 +348,8 @@ def forward(self, x, sPrev):
vProj = self.wEmbed(sumTanh) # [(b x T) x 1]
vProj = paddle.reshape(vProj, [batch_size, T])
alpha = F.softmax(
- vProj, axis=1) # attention weights for each sample in the minibatch
+ vProj, axis=1
+ ) # attention weights for each sample in the minibatch
return alpha
@@ -354,14 +364,15 @@ def __init__(self, sDim, xDim, yDim, attDim):
self.attention_unit = AttentionUnit(sDim, xDim, attDim)
self.tgt_embedding = nn.Embedding(
- yDim + 1, self.emdDim, weight_attr=nn.initializer.Normal(
- std=0.01)) # the last is used for
+ yDim + 1, self.emdDim, weight_attr=nn.initializer.Normal(std=0.01)
+ ) # the last is used for <BOS>
self.gru = nn.GRUCell(input_size=xDim + self.emdDim, hidden_size=sDim)
self.fc = nn.Linear(
sDim,
yDim,
weight_attr=nn.initializer.Normal(std=0.01),
- bias_attr=nn.initializer.Constant(value=0))
+ bias_attr=nn.initializer.Constant(value=0),
+ )
self.embed_fc = nn.Linear(300, self.sDim)
def get_initial_state(self, embed, tile_times=1):
@@ -390,4 +401,4 @@ def forward(self, x, sPrev, yPrev):
output, state = self.gru(concat_context, sPrev)
output = paddle.squeeze(output, axis=1)
output = self.fc(output)
- return output, state
\ No newline at end of file
+ return output, state
diff --git a/ppocr/modeling/heads/rec_att_head.py b/ppocr/modeling/heads/rec_att_head.py
index 7bd5bcc039..2c952cea49 100644
--- a/ppocr/modeling/heads/rec_att_head.py
+++ b/ppocr/modeling/heads/rec_att_head.py
@@ -30,7 +30,8 @@ def __init__(self, in_channels, out_channels, hidden_size, **kwargs):
self.num_classes = out_channels
self.attention_cell = AttentionGRUCell(
- in_channels, hidden_size, out_channels, use_gru=False)
+ in_channels, hidden_size, out_channels, use_gru=False
+ )
self.generator = nn.Linear(hidden_size, out_channels)
def _char_to_onehot(self, input_char, onehot_dim):
@@ -47,9 +48,11 @@ def forward(self, inputs, targets=None, batch_max_length=25):
if targets is not None:
for i in range(num_steps):
char_onehots = self._char_to_onehot(
- targets[:, i], onehot_dim=self.num_classes)
- (outputs, hidden), alpha = self.attention_cell(hidden, inputs,
- char_onehots)
+ targets[:, i], onehot_dim=self.num_classes
+ )
+ (outputs, hidden), alpha = self.attention_cell(
+ hidden, inputs, char_onehots
+ )
output_hiddens.append(paddle.unsqueeze(outputs, axis=1))
output = paddle.concat(output_hiddens, axis=1)
probs = self.generator(output)
@@ -62,16 +65,18 @@ def forward(self, inputs, targets=None, batch_max_length=25):
for i in range(num_steps):
char_onehots = self._char_to_onehot(
- targets, onehot_dim=self.num_classes)
- (outputs, hidden), alpha = self.attention_cell(hidden, inputs,
- char_onehots)
+ targets, onehot_dim=self.num_classes
+ )
+ (outputs, hidden), alpha = self.attention_cell(
+ hidden, inputs, char_onehots
+ )
probs_step = self.generator(outputs)
if probs is None:
probs = paddle.unsqueeze(probs_step, axis=1)
else:
probs = paddle.concat(
- [probs, paddle.unsqueeze(
- probs_step, axis=1)], axis=1)
+ [probs, paddle.unsqueeze(probs_step, axis=1)], axis=1
+ )
next_input = probs_step.argmax(axis=1)
targets = next_input
if not self.training:
@@ -87,12 +92,12 @@ def __init__(self, input_size, hidden_size, num_embeddings, use_gru=False):
self.score = nn.Linear(hidden_size, 1, bias_attr=False)
self.rnn = nn.GRUCell(
- input_size=input_size + num_embeddings, hidden_size=hidden_size)
+ input_size=input_size + num_embeddings, hidden_size=hidden_size
+ )
self.hidden_size = hidden_size
def forward(self, prev_hidden, batch_H, char_onehots):
-
batch_H_proj = self.i2h(batch_H)
prev_hidden_proj = paddle.unsqueeze(self.h2h(prev_hidden), axis=1)
@@ -118,7 +123,8 @@ def __init__(self, in_channels, out_channels, hidden_size, **kwargs):
self.num_classes = out_channels
self.attention_cell = AttentionLSTMCell(
- in_channels, hidden_size, out_channels, use_gru=False)
+ in_channels, hidden_size, out_channels, use_gru=False
+ )
self.generator = nn.Linear(hidden_size, out_channels)
def _char_to_onehot(self, input_char, onehot_dim):
@@ -129,17 +135,19 @@ def forward(self, inputs, targets=None, batch_max_length=25):
batch_size = inputs.shape[0]
num_steps = batch_max_length
- hidden = (paddle.zeros((batch_size, self.hidden_size)), paddle.zeros(
- (batch_size, self.hidden_size)))
+ hidden = (
+ paddle.zeros((batch_size, self.hidden_size)),
+ paddle.zeros((batch_size, self.hidden_size)),
+ )
output_hiddens = []
if targets is not None:
for i in range(num_steps):
# one-hot vectors for the i-th char
char_onehots = self._char_to_onehot(
- targets[:, i], onehot_dim=self.num_classes)
- hidden, alpha = self.attention_cell(hidden, inputs,
- char_onehots)
+ targets[:, i], onehot_dim=self.num_classes
+ )
+ hidden, alpha = self.attention_cell(hidden, inputs, char_onehots)
hidden = (hidden[1][0], hidden[1][1])
output_hiddens.append(paddle.unsqueeze(hidden[0], axis=1))
@@ -154,17 +162,17 @@ def forward(self, inputs, targets=None, batch_max_length=25):
for i in range(num_steps):
char_onehots = self._char_to_onehot(
- targets, onehot_dim=self.num_classes)
- hidden, alpha = self.attention_cell(hidden, inputs,
- char_onehots)
+ targets, onehot_dim=self.num_classes
+ )
+ hidden, alpha = self.attention_cell(hidden, inputs, char_onehots)
probs_step = self.generator(hidden[0])
hidden = (hidden[1][0], hidden[1][1])
if probs is None:
probs = paddle.unsqueeze(probs_step, axis=1)
else:
probs = paddle.concat(
- [probs, paddle.unsqueeze(
- probs_step, axis=1)], axis=1)
+ [probs, paddle.unsqueeze(probs_step, axis=1)], axis=1
+ )
next_input = probs_step.argmax(axis=1)
@@ -182,10 +190,12 @@ def __init__(self, input_size, hidden_size, num_embeddings, use_gru=False):
self.score = nn.Linear(hidden_size, 1, bias_attr=False)
if not use_gru:
self.rnn = nn.LSTMCell(
- input_size=input_size + num_embeddings, hidden_size=hidden_size)
+ input_size=input_size + num_embeddings, hidden_size=hidden_size
+ )
else:
self.rnn = nn.GRUCell(
- input_size=input_size + num_embeddings, hidden_size=hidden_size)
+ input_size=input_size + num_embeddings, hidden_size=hidden_size
+ )
self.hidden_size = hidden_size
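+
+
+# Both cells score encoder steps with an additive (Bahdanau-style) attention,
+# roughly e_t = score(tanh(i2h(H_t) + h2h(s_prev))); the softmaxed e_t weight
+# batch_H into a context vector, and [context, one_hot(prev_char)] feeds the
+# recurrent cell. The GRU and LSTM variants differ only in the recurrence.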
diff --git a/ppocr/modeling/heads/rec_can_head.py b/ppocr/modeling/heads/rec_can_head.py
index 732dbfe2db..921b8e4a8b 100644
--- a/ppocr/modeling/heads/rec_can_head.py
+++ b/ppocr/modeling/heads/rec_can_head.py
@@ -27,9 +27,10 @@
import paddle.nn as nn
import paddle
import math
-'''
+
+"""
Counting Module
-'''
+"""
class ChannelAtt(nn.Layer):
@@ -39,7 +40,10 @@ def __init__(self, channel, reduction):
self.fc = nn.Sequential(
nn.Linear(channel, channel // reduction),
- nn.ReLU(), nn.Linear(channel // reduction, channel), nn.Sigmoid())
+ nn.ReLU(),
+ nn.Linear(channel // reduction, channel),
+ nn.Sigmoid(),
+ )
def forward(self, x):
b, c, _, _ = x.shape
@@ -60,15 +64,17 @@ def __init__(self, in_channel, out_channel, kernel_size):
512,
kernel_size=kernel_size,
padding=kernel_size // 2,
- bias_attr=False),
- nn.BatchNorm2D(512))
+ bias_attr=False,
+ ),
+ nn.BatchNorm2D(512),
+ )
self.channel_att = ChannelAtt(512, 16)
self.pred_layer = nn.Sequential(
- nn.Conv2D(
- 512, self.out_channel, kernel_size=1, bias_attr=False),
- nn.Sigmoid())
+ nn.Conv2D(512, self.out_channel, kernel_size=1, bias_attr=False),
+ nn.Sigmoid(),
+ )
def forward(self, x, mask):
b, _, h, w = x.shape
@@ -84,17 +90,15 @@ def forward(self, x, mask):
return x1, paddle.reshape(x, [b, self.out_channel, h, w])
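+
+# In brief: CountingDecoder turns the masked feature map into per-class symbol
+# counts; the 1x1 conv plus sigmoid yields a per-pixel class map, and summing
+# it over the spatial positions gives x1, an estimate of how many times each
+# of the out_channel symbols occurs in the image.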
-'''
+"""
Attention Decoder
-'''
+"""
class PositionEmbeddingSine(nn.Layer):
- def __init__(self,
- num_pos_feats=64,
- temperature=10000,
- normalize=False,
- scale=None):
+ def __init__(
+ self, num_pos_feats=64, temperature=10000, normalize=False, scale=None
+ ):
super().__init__()
self.num_pos_feats = num_pos_feats
self.temperature = temperature
@@ -106,54 +110,61 @@ def __init__(self,
self.scale = scale
def forward(self, x, mask):
- y_embed = paddle.cumsum(mask, 1, dtype='float32')
- x_embed = paddle.cumsum(mask, 2, dtype='float32')
+ y_embed = paddle.cumsum(mask, 1, dtype="float32")
+ x_embed = paddle.cumsum(mask, 2, dtype="float32")
if self.normalize:
eps = 1e-6
y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale
x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale
- dim_t = paddle.arange(self.num_pos_feats, dtype='float32')
+ dim_t = paddle.arange(self.num_pos_feats, dtype="float32")
dim_d = paddle.expand(paddle.to_tensor(2), dim_t.shape)
- dim_t = self.temperature**(2 * (dim_t / dim_d).astype('int64') /
- self.num_pos_feats)
+ dim_t = self.temperature ** (
+ 2 * (dim_t / dim_d).astype("int64") / self.num_pos_feats
+ )
pos_x = paddle.unsqueeze(x_embed, [3]) / dim_t
pos_y = paddle.unsqueeze(y_embed, [3]) / dim_t
pos_x = paddle.flatten(
paddle.stack(
- [
- paddle.sin(pos_x[:, :, :, 0::2]),
- paddle.cos(pos_x[:, :, :, 1::2])
- ],
- axis=4),
- 3)
+ [paddle.sin(pos_x[:, :, :, 0::2]), paddle.cos(pos_x[:, :, :, 1::2])],
+ axis=4,
+ ),
+ 3,
+ )
pos_y = paddle.flatten(
paddle.stack(
- [
- paddle.sin(pos_y[:, :, :, 0::2]),
- paddle.cos(pos_y[:, :, :, 1::2])
- ],
- axis=4),
- 3)
+ [paddle.sin(pos_y[:, :, :, 0::2]), paddle.cos(pos_y[:, :, :, 1::2])],
+ axis=4,
+ ),
+ 3,
+ )
- pos = paddle.transpose(
- paddle.concat(
- [pos_y, pos_x], axis=3), [0, 3, 1, 2])
+ pos = paddle.transpose(paddle.concat([pos_y, pos_x], axis=3), [0, 3, 1, 2])
return pos
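+
+# This follows the standard transformer recipe per axis: with T = temperature
+# and d = num_pos_feats, PE(pos, 2i) = sin(pos / T**(2i / d)) and
+# PE(pos, 2i+1) = cos(pos / T**(2i / d)), computed from the row (y_embed) and
+# column (x_embed) cumulative sums of the mask and concatenated channelwise.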
class AttDecoder(nn.Layer):
- def __init__(self, ratio, is_train, input_size, hidden_size,
- encoder_out_channel, dropout, dropout_ratio, word_num,
- counting_decoder_out_channel, attention):
+ def __init__(
+ self,
+ ratio,
+ is_train,
+ input_size,
+ hidden_size,
+ encoder_out_channel,
+ dropout,
+ dropout_ratio,
+ word_num,
+ counting_decoder_out_channel,
+ attention,
+ ):
super(AttDecoder, self).__init__()
self.input_size = input_size
self.hidden_size = hidden_size
self.out_channel = encoder_out_channel
- self.attention_dim = attention['attention_dim']
+ self.attention_dim = attention["attention_dim"]
self.dropout_prob = dropout
self.ratio = ratio
self.word_num = word_num
@@ -164,20 +175,19 @@ def __init__(self, ratio, is_train, input_size, hidden_size,
self.init_weight = nn.Linear(self.out_channel, self.hidden_size)
self.embedding = nn.Embedding(self.word_num, self.input_size)
self.word_input_gru = nn.GRUCell(self.input_size, self.hidden_size)
- self.word_attention = Attention(hidden_size, attention['attention_dim'])
+ self.word_attention = Attention(hidden_size, attention["attention_dim"])
self.encoder_feature_conv = nn.Conv2D(
self.out_channel,
self.attention_dim,
- kernel_size=attention['word_conv_kernel'],
- padding=attention['word_conv_kernel'] // 2)
+ kernel_size=attention["word_conv_kernel"],
+ padding=attention["word_conv_kernel"] // 2,
+ )
self.word_state_weight = nn.Linear(self.hidden_size, self.hidden_size)
- self.word_embedding_weight = nn.Linear(self.input_size,
- self.hidden_size)
+ self.word_embedding_weight = nn.Linear(self.input_size, self.hidden_size)
self.word_context_weight = nn.Linear(self.out_channel, self.hidden_size)
- self.counting_context_weight = nn.Linear(self.counting_num,
- self.hidden_size)
+ self.counting_context_weight = nn.Linear(self.counting_num, self.hidden_size)
self.word_convert = nn.Linear(self.hidden_size, self.word_num)
if dropout:
@@ -190,7 +200,7 @@ def forward(self, cnn_features, labels, counting_preds, images_mask):
num_steps = 36
batch_size, _, height, width = cnn_features.shape
- images_mask = images_mask[:, :, ::self.ratio, ::self.ratio]
+ images_mask = images_mask[:, :, :: self.ratio, :: self.ratio]
word_probs = paddle.zeros((batch_size, num_steps, self.word_num))
word_alpha_sum = paddle.zeros((batch_size, 1, height, width))
@@ -204,14 +214,14 @@ def forward(self, cnn_features, labels, counting_preds, images_mask):
cnn_features_trans = cnn_features_trans + pos
- word = paddle.ones([batch_size, 1], dtype='int64') # init word as sos
+ word = paddle.ones([batch_size, 1], dtype="int64") # init word as sos
word = word.squeeze(axis=1)
for i in range(num_steps):
word_embedding = self.embedding(word)
_, hidden = self.word_input_gru(word_embedding, hidden)
word_context_vec, _, word_alpha_sum = self.word_attention(
- cnn_features, cnn_features_trans, hidden, word_alpha_sum,
- images_mask)
+ cnn_features, cnn_features_trans, hidden, word_alpha_sum, images_mask
+ )
current_state = self.word_state_weight(hidden)
word_weighted_embedding = self.word_embedding_weight(word_embedding)
@@ -219,10 +229,18 @@ def forward(self, cnn_features, labels, counting_preds, images_mask):
if self.dropout_prob:
word_out_state = self.dropout(
- current_state + word_weighted_embedding +
- word_context_weighted + counting_context_weighted)
+ current_state
+ + word_weighted_embedding
+ + word_context_weighted
+ + counting_context_weighted
+ )
else:
- word_out_state = current_state + word_weighted_embedding + word_context_weighted + counting_context_weighted
+ word_out_state = (
+ current_state
+ + word_weighted_embedding
+ + word_context_weighted
+ + counting_context_weighted
+ )
word_prob = self.word_convert(word_out_state)
word_probs[:, i] = word_prob
@@ -238,16 +256,16 @@ def forward(self, cnn_features, labels, counting_preds, images_mask):
return word_probs
def init_hidden(self, features, feature_mask):
- average = paddle.sum(paddle.sum(features * feature_mask, axis=-1),
- axis=-1) / paddle.sum(
- (paddle.sum(feature_mask, axis=-1)), axis=-1)
+ average = paddle.sum(
+ paddle.sum(features * feature_mask, axis=-1), axis=-1
+ ) / paddle.sum((paddle.sum(feature_mask, axis=-1)), axis=-1)
average = self.init_weight(average)
return paddle.tanh(average)
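+
+ # i.e. the initial GRU state is tanh(W @ masked_mean(features)), the usual
+ # "average visual feature" initialization.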
-'''
+"""
Attention Module
-'''
+"""
class Attention(nn.Layer):
@@ -257,35 +275,37 @@ def __init__(self, hidden_size, attention_dim):
self.attention_dim = attention_dim
self.hidden_weight = nn.Linear(self.hidden, self.attention_dim)
self.attention_conv = nn.Conv2D(
- 1, 512, kernel_size=11, padding=5, bias_attr=False)
- self.attention_weight = nn.Linear(
- 512, self.attention_dim, bias_attr=False)
+ 1, 512, kernel_size=11, padding=5, bias_attr=False
+ )
+ self.attention_weight = nn.Linear(512, self.attention_dim, bias_attr=False)
self.alpha_convert = nn.Linear(self.attention_dim, 1)
- def forward(self,
- cnn_features,
- cnn_features_trans,
- hidden,
- alpha_sum,
- image_mask=None):
+ def forward(
+ self, cnn_features, cnn_features_trans, hidden, alpha_sum, image_mask=None
+ ):
query = self.hidden_weight(hidden)
alpha_sum_trans = self.attention_conv(alpha_sum)
coverage_alpha = self.attention_weight(
- paddle.transpose(alpha_sum_trans, [0, 2, 3, 1]))
+ paddle.transpose(alpha_sum_trans, [0, 2, 3, 1])
+ )
alpha_score = paddle.tanh(
- paddle.unsqueeze(query, [1, 2]) + coverage_alpha + paddle.transpose(
- cnn_features_trans, [0, 2, 3, 1]))
+ paddle.unsqueeze(query, [1, 2])
+ + coverage_alpha
+ + paddle.transpose(cnn_features_trans, [0, 2, 3, 1])
+ )
energy = self.alpha_convert(alpha_score)
energy = energy - energy.max()
energy_exp = paddle.exp(paddle.squeeze(energy, -1))
if image_mask is not None:
energy_exp = energy_exp * paddle.squeeze(image_mask, 1)
- alpha = energy_exp / (paddle.unsqueeze(
- paddle.sum(paddle.sum(energy_exp, -1), -1), [1, 2]) + 1e-10)
+ alpha = energy_exp / (
+ paddle.unsqueeze(paddle.sum(paddle.sum(energy_exp, -1), -1), [1, 2]) + 1e-10
+ )
alpha_sum = paddle.unsqueeze(alpha, 1) + alpha_sum
context_vector = paddle.sum(
- paddle.sum((paddle.unsqueeze(alpha, 1) * cnn_features), -1), -1)
+ paddle.sum((paddle.unsqueeze(alpha, 1) * cnn_features), -1), -1
+ )
return context_vector, alpha, alpha_sum
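+
+# This is coverage attention: alpha_sum accumulates every past attention map,
+# the 11x11 attention_conv summarizes that history, and its projection
+# (coverage_alpha) enters the energy term, letting the model track which
+# positions it has already attended to (the standard fix for over- and
+# under-attention in handwritten expression recognition).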
@@ -297,10 +317,10 @@ def __init__(self, in_channel, out_channel, ratio, attdecoder, **kwargs):
self.in_channel = in_channel
self.out_channel = out_channel
- self.counting_decoder1 = CountingDecoder(self.in_channel,
- self.out_channel, 3) # mscm
- self.counting_decoder2 = CountingDecoder(self.in_channel,
- self.out_channel, 5)
+ self.counting_decoder1 = CountingDecoder(
+ self.in_channel, self.out_channel, 3
+ ) # mscm
+ self.counting_decoder2 = CountingDecoder(self.in_channel, self.out_channel, 5)
self.decoder = AttDecoder(ratio, **attdecoder)
@@ -309,11 +329,10 @@ def __init__(self, in_channel, out_channel, ratio, attdecoder, **kwargs):
def forward(self, inputs, targets=None):
cnn_features, images_mask, labels = inputs
- counting_mask = images_mask[:, :, ::self.ratio, ::self.ratio]
+ counting_mask = images_mask[:, :, :: self.ratio, :: self.ratio]
counting_preds1, _ = self.counting_decoder1(cnn_features, counting_mask)
counting_preds2, _ = self.counting_decoder2(cnn_features, counting_mask)
counting_preds = (counting_preds1 + counting_preds2) / 2
- word_probs = self.decoder(cnn_features, labels, counting_preds,
- images_mask)
+ word_probs = self.decoder(cnn_features, labels, counting_preds, images_mask)
return word_probs, counting_preds, counting_preds1, counting_preds2
diff --git a/ppocr/modeling/heads/rec_cppd_head.py b/ppocr/modeling/heads/rec_cppd_head.py
index dc3ba4e12a..87c4037798 100644
--- a/ppocr/modeling/heads/rec_cppd_head.py
+++ b/ppocr/modeling/heads/rec_cppd_head.py
@@ -26,17 +26,26 @@
from paddle import nn
from paddle.nn import functional as F
from ppocr.modeling.heads.rec_nrtr_head import Embeddings
-from ppocr.modeling.backbones.rec_svtrnet import DropPath, Identity, trunc_normal_, zeros_, ones_, Mlp
+from ppocr.modeling.backbones.rec_svtrnet import (
+ DropPath,
+ Identity,
+ trunc_normal_,
+ zeros_,
+ ones_,
+ Mlp,
+)
class Attention(nn.Layer):
- def __init__(self,
- dim,
- num_heads=8,
- qkv_bias=False,
- qk_scale=None,
- attn_drop=0.,
- proj_drop=0.):
+ def __init__(
+ self,
+ dim,
+ num_heads=8,
+ qkv_bias=False,
+ qk_scale=None,
+ attn_drop=0.0,
+ proj_drop=0.0,
+ ):
super().__init__()
self.num_heads = num_heads
head_dim = dim // num_heads
@@ -51,12 +60,16 @@ def __init__(self,
def forward(self, q, kv):
N, C = kv.shape[1:]
QN = q.shape[1]
- q = self.q(q).reshape(
- [-1, QN, self.num_heads, C // self.num_heads]).transpose(
- [0, 2, 1, 3])
- k, v = self.kv(kv).reshape(
- [-1, N, 2, self.num_heads, C // self.num_heads]).transpose(
- (2, 0, 3, 1, 4))
+ q = (
+ self.q(q)
+ .reshape([-1, QN, self.num_heads, C // self.num_heads])
+ .transpose([0, 2, 1, 3])
+ )
+ k, v = (
+ self.kv(kv)
+ .reshape([-1, N, 2, self.num_heads, C // self.num_heads])
+ .transpose((2, 0, 3, 1, 4))
+ )
attn = q.matmul(k.transpose((0, 1, 3, 2))) * self.scale
attn = F.softmax(attn, axis=-1)
attn = self.attn_drop(attn)
@@ -67,26 +80,27 @@ def forward(self, q, kv):
class EdgeDecoderLayer(nn.Layer):
- def __init__(self,
- dim,
- num_heads,
- mlp_ratio=4.,
- qkv_bias=False,
- qk_scale=None,
- drop=0.,
- attn_drop=0.,
- drop_path=[0., 0.],
- act_layer=nn.GELU,
- norm_layer='nn.LayerNorm',
- epsilon=1e-6):
+ def __init__(
+ self,
+ dim,
+ num_heads,
+ mlp_ratio=4.0,
+ qkv_bias=False,
+ qk_scale=None,
+ drop=0.0,
+ attn_drop=0.0,
+ drop_path=[0.0, 0.0],
+ act_layer=nn.GELU,
+ norm_layer="nn.LayerNorm",
+ epsilon=1e-6,
+ ):
super().__init__()
self.head_dim = dim // num_heads
self.scale = qk_scale or self.head_dim**-0.5
# NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
- self.drop_path1 = DropPath(drop_path[0]) if drop_path[
- 0] > 0. else Identity()
+ self.drop_path1 = DropPath(drop_path[0]) if drop_path[0] > 0.0 else Identity()
self.norm1 = eval(norm_layer)(dim, epsilon=epsilon)
self.norm2 = eval(norm_layer)(dim, epsilon=epsilon)
@@ -99,29 +113,36 @@ def __init__(self,
self.p_proj = nn.Linear(dim, dim)
mlp_hidden_dim = int(dim * mlp_ratio)
self.mlp_ratio = mlp_ratio
- self.mlp = Mlp(in_features=dim,
- hidden_features=mlp_hidden_dim,
- act_layer=act_layer,
- drop=drop)
+ self.mlp = Mlp(
+ in_features=dim,
+ hidden_features=mlp_hidden_dim,
+ act_layer=act_layer,
+ drop=drop,
+ )
def forward(self, p, cv, pv):
-
pN = p.shape[1]
vN = cv.shape[1]
p_shortcut = p
- p1 = self.p(p).reshape(
- [-1, pN, self.num_heads, self.dim // self.num_heads]).transpose(
- [0, 2, 1, 3])
- cv1 = self.cv(cv).reshape(
- [-1, vN, self.num_heads, self.dim // self.num_heads]).transpose(
- [0, 2, 1, 3])
- pv1 = self.pv(pv).reshape(
- [-1, vN, self.num_heads, self.dim // self.num_heads]).transpose(
- [0, 2, 1, 3])
+ p1 = (
+ self.p(p)
+ .reshape([-1, pN, self.num_heads, self.dim // self.num_heads])
+ .transpose([0, 2, 1, 3])
+ )
+ cv1 = (
+ self.cv(cv)
+ .reshape([-1, vN, self.num_heads, self.dim // self.num_heads])
+ .transpose([0, 2, 1, 3])
+ )
+ pv1 = (
+ self.pv(pv)
+ .reshape([-1, vN, self.num_heads, self.dim // self.num_heads])
+ .transpose([0, 2, 1, 3])
+ )
edge = F.softmax(p1.matmul(pv1.transpose((0, 1, 3, 2))), -1) # B h N N
- p_c = (edge @cv1).transpose((0, 2, 1, 3)).reshape((-1, pN, self.dim))
+ p_c = (edge @ cv1).transpose((0, 2, 1, 3)).reshape((-1, pN, self.dim))
x1 = self.norm1(p_shortcut + self.drop_path1(self.p_proj(p_c)))
@@ -130,18 +151,20 @@ def forward(self, p, cv, pv):
class DecoderLayer(nn.Layer):
- def __init__(self,
- dim,
- num_heads,
- mlp_ratio=4.,
- qkv_bias=False,
- qk_scale=None,
- drop=0.,
- attn_drop=0.,
- drop_path=0.,
- act_layer=nn.GELU,
- norm_layer='nn.LayerNorm',
- epsilon=1e-6):
+ def __init__(
+ self,
+ dim,
+ num_heads,
+ mlp_ratio=4.0,
+ qkv_bias=False,
+ qk_scale=None,
+ drop=0.0,
+ attn_drop=0.0,
+ drop_path=0.0,
+ act_layer=nn.GELU,
+ norm_layer="nn.LayerNorm",
+ epsilon=1e-6,
+ ):
super().__init__()
if isinstance(norm_layer, str):
self.norm1 = eval(norm_layer)(dim, epsilon=epsilon)
@@ -150,31 +173,32 @@ def __init__(self,
self.norm1 = norm_layer(dim)
self.normkv = norm_layer(dim)
else:
- raise TypeError(
- "The norm_layer must be str or paddle.nn.LayerNorm class")
+ raise TypeError("The norm_layer must be str or paddle.nn.LayerNorm class")
self.mixer = Attention(
dim,
num_heads=num_heads,
qkv_bias=qkv_bias,
qk_scale=qk_scale,
attn_drop=attn_drop,
- proj_drop=drop)
+ proj_drop=drop,
+ )
# NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
- self.drop_path = DropPath(drop_path) if drop_path > 0. else Identity()
+ self.drop_path = DropPath(drop_path) if drop_path > 0.0 else Identity()
if isinstance(norm_layer, str):
self.norm2 = eval(norm_layer)(dim, epsilon=epsilon)
elif isinstance(norm_layer, Callable):
self.norm2 = norm_layer(dim)
else:
- raise TypeError(
- "The norm_layer must be str or paddle.nn.layer.Layer class")
+ raise TypeError("The norm_layer must be str or paddle.nn.layer.Layer class")
mlp_hidden_dim = int(dim * mlp_ratio)
self.mlp_ratio = mlp_ratio
- self.mlp = Mlp(in_features=dim,
- hidden_features=mlp_hidden_dim,
- act_layer=act_layer,
- drop=drop)
+ self.mlp = Mlp(
+ in_features=dim,
+ hidden_features=mlp_hidden_dim,
+ act_layer=act_layer,
+ drop=drop,
+ )
def forward(self, q, kv):
x1 = self.norm1(q + self.drop_path(self.mixer(q, kv)))
@@ -183,16 +207,18 @@ def forward(self, q, kv):
class CPPDHead(nn.Layer):
- def __init__(self,
- in_channels,
- dim,
- out_channels,
- num_layer=2,
- drop_path_rate=0.1,
- max_len=25,
- vis_seq=50,
- ch=False,
- **kwargs):
+ def __init__(
+ self,
+ in_channels,
+ dim,
+ out_channels,
+ num_layer=2,
+ drop_path_rate=0.1,
+ max_len=25,
+ vis_seq=50,
+ ch=False,
+ **kwargs
+ ):
super(CPPDHead, self).__init__()
self.out_channels = out_channels # none + 26 + 10
@@ -200,40 +226,53 @@ def __init__(self,
self.ch = ch
self.max_len = max_len + 1 # max_len + eos
self.char_node_embed = Embeddings(
- d_model=dim, vocab=self.out_channels, scale_embedding=True)
+ d_model=dim, vocab=self.out_channels, scale_embedding=True
+ )
self.pos_node_embed = Embeddings(
- d_model=dim, vocab=self.max_len, scale_embedding=True)
+ d_model=dim, vocab=self.max_len, scale_embedding=True
+ )
dpr = np.linspace(0, drop_path_rate, num_layer + 1)
- self.char_node_decoder = nn.LayerList([
- DecoderLayer(
- dim=dim,
- num_heads=dim // 32,
- mlp_ratio=4.0,
- qkv_bias=True,
- drop_path=dpr[i]) for i in range(num_layer)
- ])
- self.pos_node_decoder = nn.LayerList([
- DecoderLayer(
- dim=dim,
- num_heads=dim // 32,
- mlp_ratio=4.0,
- qkv_bias=True,
- drop_path=dpr[i]) for i in range(num_layer)
- ])
+ self.char_node_decoder = nn.LayerList(
+ [
+ DecoderLayer(
+ dim=dim,
+ num_heads=dim // 32,
+ mlp_ratio=4.0,
+ qkv_bias=True,
+ drop_path=dpr[i],
+ )
+ for i in range(num_layer)
+ ]
+ )
+ self.pos_node_decoder = nn.LayerList(
+ [
+ DecoderLayer(
+ dim=dim,
+ num_heads=dim // 32,
+ mlp_ratio=4.0,
+ qkv_bias=True,
+ drop_path=dpr[i],
+ )
+ for i in range(num_layer)
+ ]
+ )
self.edge_decoder = EdgeDecoderLayer(
dim=dim,
num_heads=dim // 32,
mlp_ratio=4.0,
qkv_bias=True,
- drop_path=dpr[num_layer:num_layer + 1])
+ drop_path=dpr[num_layer : num_layer + 1],
+ )
self.char_pos_embed = self.create_parameter(
- shape=[1, self.max_len, dim], default_initializer=zeros_)
+ shape=[1, self.max_len, dim], default_initializer=zeros_
+ )
self.add_parameter("char_pos_embed", self.char_pos_embed)
self.vis_pos_embed = self.create_parameter(
- shape=[1, vis_seq, dim], default_initializer=zeros_)
+ shape=[1, vis_seq, dim], default_initializer=zeros_
+ )
self.add_parameter("vis_pos_embed", self.vis_pos_embed)
self.char_node_fc1 = nn.Linear(dim, max_len)
@@ -262,23 +301,29 @@ def forward(self, x, targets=None, epoch=0):
def forward_test(self, x):
visual_feats = x + self.vis_pos_embed
bs = visual_feats.shape[0]
- pos_node_embed = self.pos_node_embed(paddle.arange(
- self.max_len)).unsqueeze(0) + self.char_pos_embed
+ pos_node_embed = (
+ self.pos_node_embed(paddle.arange(self.max_len)).unsqueeze(0)
+ + self.char_pos_embed
+ )
pos_node_embed = paddle.tile(pos_node_embed, [bs, 1, 1])
char_vis_node_query = visual_feats
pos_vis_node_query = paddle.concat([pos_node_embed, visual_feats], 1)
- for char_decoder_layer, pos_decoder_layer in zip(self.char_node_decoder,
- self.pos_node_decoder):
- char_vis_node_query = char_decoder_layer(char_vis_node_query,
- char_vis_node_query)
+ for char_decoder_layer, pos_decoder_layer in zip(
+ self.char_node_decoder, self.pos_node_decoder
+ ):
+ char_vis_node_query = char_decoder_layer(
+ char_vis_node_query, char_vis_node_query
+ )
pos_vis_node_query = pos_decoder_layer(
- pos_vis_node_query, pos_vis_node_query[:, self.max_len:, :])
- pos_node_query = pos_vis_node_query[:, :self.max_len, :]
+ pos_vis_node_query, pos_vis_node_query[:, self.max_len :, :]
+ )
+ pos_node_query = pos_vis_node_query[:, : self.max_len, :]
char_vis_feats = char_vis_node_query
- pos_node_feats = self.edge_decoder(pos_node_query, char_vis_feats,
- char_vis_feats) # B, 26, dim
+ pos_node_feats = self.edge_decoder(
+ pos_node_query, char_vis_feats, char_vis_feats
+ ) # B, 26, dim
edge_feats = self.edge_fc(pos_node_feats) # B, 26, 37
edge_logits = F.softmax(edge_feats, -1)
@@ -292,11 +337,14 @@ def forward_train(self, x, targets=None, epoch=0):
char_node_embed = self.char_node_embed(targets[-2])
else:
char_node_embed = self.char_node_embed(
- paddle.arange(self.out_channels)).unsqueeze(0)
+ paddle.arange(self.out_channels)
+ ).unsqueeze(0)
char_node_embed = paddle.tile(char_node_embed, [bs, 1, 1])
counting_char_num = char_node_embed.shape[1]
- pos_node_embed = self.pos_node_embed(paddle.arange(
- self.max_len)).unsqueeze(0) + self.char_pos_embed
+ pos_node_embed = (
+ self.pos_node_embed(paddle.arange(self.max_len)).unsqueeze(0)
+ + self.char_pos_embed
+ )
pos_node_embed = paddle.tile(pos_node_embed, [bs, 1, 1])
node_feats = []
@@ -304,30 +352,36 @@ def forward_train(self, x, targets=None, epoch=0):
char_vis_node_query = paddle.concat([char_node_embed, visual_feats], 1)
pos_vis_node_query = paddle.concat([pos_node_embed, visual_feats], 1)
- for char_decoder_layer, pos_decoder_layer in zip(self.char_node_decoder,
- self.pos_node_decoder):
+ for char_decoder_layer, pos_decoder_layer in zip(
+ self.char_node_decoder, self.pos_node_decoder
+ ):
char_vis_node_query = char_decoder_layer(
- char_vis_node_query,
- char_vis_node_query[:, counting_char_num:, :])
+ char_vis_node_query, char_vis_node_query[:, counting_char_num:, :]
+ )
pos_vis_node_query = pos_decoder_layer(
- pos_vis_node_query, pos_vis_node_query[:, self.max_len:, :])
+ pos_vis_node_query, pos_vis_node_query[:, self.max_len :, :]
+ )
char_node_query = char_vis_node_query[:, :counting_char_num, :]
- pos_node_query = pos_vis_node_query[:, :self.max_len, :]
+ pos_node_query = pos_vis_node_query[:, : self.max_len, :]
char_vis_feats = char_vis_node_query[:, counting_char_num:, :]
char_node_feats1 = self.char_node_fc1(char_node_query)
pos_node_feats1 = self.pos_node_fc1(pos_node_query)
- diag_mask = paddle.eye(pos_node_feats1.shape[1]).unsqueeze(0).tile(
- [pos_node_feats1.shape[0], 1, 1])
+ diag_mask = (
+ paddle.eye(pos_node_feats1.shape[1])
+ .unsqueeze(0)
+ .tile([pos_node_feats1.shape[0], 1, 1])
+ )
pos_node_feats1 = (pos_node_feats1 * diag_mask).sum(-1)
node_feats.append(char_node_feats1)
node_feats.append(pos_node_feats1)
- pos_node_feats = self.edge_decoder(pos_node_query, char_vis_feats,
- char_vis_feats) # B, 26, dim
+ pos_node_feats = self.edge_decoder(
+ pos_node_query, char_vis_feats, char_vis_feats
+ ) # B, 26, dim
edge_feats = self.edge_fc(pos_node_feats) # B, 26, 37
return node_feats, edge_feats
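+
+ # Rough picture of the two query streams: character-node queries attend over
+ # the visual sequence (char_node_fc1 maps each class node to max_len
+ # counting logits), position-node queries order the same features
+ # (pos_node_fc1 plus the diagonal mask keeps position i's own score), and
+ # edge_decoder + edge_fc fuse both into the per-position class logits used
+ # as the final recognition output.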
diff --git a/ppocr/modeling/heads/rec_ctc_head.py b/ppocr/modeling/heads/rec_ctc_head.py
index 6c1cf06596..6c008bbeda 100755
--- a/ppocr/modeling/heads/rec_ctc_head.py
+++ b/ppocr/modeling/heads/rec_ctc_head.py
@@ -33,38 +33,43 @@ def get_para_bias_attr(l2_decay, k):
class CTCHead(nn.Layer):
- def __init__(self,
- in_channels,
- out_channels,
- fc_decay=0.0004,
- mid_channels=None,
- return_feats=False,
- **kwargs):
+ def __init__(
+ self,
+ in_channels,
+ out_channels,
+ fc_decay=0.0004,
+ mid_channels=None,
+ return_feats=False,
+ **kwargs
+ ):
super(CTCHead, self).__init__()
if mid_channels is None:
weight_attr, bias_attr = get_para_bias_attr(
- l2_decay=fc_decay, k=in_channels)
+ l2_decay=fc_decay, k=in_channels
+ )
self.fc = nn.Linear(
- in_channels,
- out_channels,
- weight_attr=weight_attr,
- bias_attr=bias_attr)
+ in_channels, out_channels, weight_attr=weight_attr, bias_attr=bias_attr
+ )
else:
weight_attr1, bias_attr1 = get_para_bias_attr(
- l2_decay=fc_decay, k=in_channels)
+ l2_decay=fc_decay, k=in_channels
+ )
self.fc1 = nn.Linear(
in_channels,
mid_channels,
weight_attr=weight_attr1,
- bias_attr=bias_attr1)
+ bias_attr=bias_attr1,
+ )
weight_attr2, bias_attr2 = get_para_bias_attr(
- l2_decay=fc_decay, k=mid_channels)
+ l2_decay=fc_decay, k=mid_channels
+ )
self.fc2 = nn.Linear(
mid_channels,
out_channels,
weight_attr=weight_attr2,
- bias_attr=bias_attr2)
+ bias_attr=bias_attr2,
+ )
self.out_channels = out_channels
self.mid_channels = mid_channels
self.return_feats = return_feats
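+
+ # With mid_channels set, the head is a small bottleneck fc1 -> fc2 before
+ # the per-timestep class logits; otherwise a single fc maps encoder features
+ # straight to out_channels. Both paths take L2-regularized parameter
+ # attributes from get_para_bias_attr(l2_decay=fc_decay, ...).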
diff --git a/ppocr/modeling/heads/rec_multi_head.py b/ppocr/modeling/heads/rec_multi_head.py
index bca76511e0..c7005c108e 100644
--- a/ppocr/modeling/heads/rec_multi_head.py
+++ b/ppocr/modeling/heads/rec_multi_head.py
@@ -22,7 +22,15 @@
import paddle.nn as nn
import paddle.nn.functional as F
-from ppocr.modeling.necks.rnn import Im2Seq, EncoderWithRNN, EncoderWithFC, SequenceEncoder, EncoderWithSVTR, trunc_normal_, zeros_
+from ppocr.modeling.necks.rnn import (
+ Im2Seq,
+ EncoderWithRNN,
+ EncoderWithFC,
+ SequenceEncoder,
+ EncoderWithSVTR,
+ trunc_normal_,
+ zeros_,
+)
from .rec_ctc_head import CTCHead
from .rec_sar_head import SARHead
from .rec_nrtr_head import Transformer
@@ -41,49 +49,58 @@ def forward(self, x):
else:
return self.fc(x.transpose([0, 2, 1]))
+
class AddPos(nn.Layer):
- def __init__(self, dim, w):
+ def __init__(self, dim, w):
super().__init__()
self.dec_pos_embed = self.create_parameter(
- shape=[1, w, dim], default_initializer=zeros_)
+ shape=[1, w, dim], default_initializer=zeros_
+ )
self.add_parameter("dec_pos_embed", self.dec_pos_embed)
trunc_normal_(self.dec_pos_embed)
-
- def forward(self,x):
- x = x + self.dec_pos_embed[:, :x.shape[1], :]
+
+ def forward(self, x):
+ x = x + self.dec_pos_embed[:, : x.shape[1], :]
return x
class MultiHead(nn.Layer):
def __init__(self, in_channels, out_channels_list, **kwargs):
super().__init__()
- self.head_list = kwargs.pop('head_list')
- self.use_pool = kwargs.get('use_pool', False)
- self.use_pos = kwargs.get('use_pos', False)
+ self.head_list = kwargs.pop("head_list")
+ self.use_pool = kwargs.get("use_pool", False)
+ self.use_pos = kwargs.get("use_pos", False)
self.in_channels = in_channels
if self.use_pool:
self.pool = nn.AvgPool2D(kernel_size=[3, 2], stride=[3, 2], padding=0)
- self.gtc_head = 'sar'
+ self.gtc_head = "sar"
assert len(self.head_list) >= 2
for idx, head_name in enumerate(self.head_list):
name = list(head_name)[0]
- if name == 'SARHead':
+ if name == "SARHead":
# sar head
sar_args = self.head_list[idx][name]
- self.sar_head = eval(name)(in_channels=in_channels, \
- out_channels=out_channels_list['SARLabelDecode'], **sar_args)
- elif name == 'NRTRHead':
+ self.sar_head = eval(name)(
+ in_channels=in_channels,
+ out_channels=out_channels_list["SARLabelDecode"],
+ **sar_args
+ )
+ elif name == "NRTRHead":
gtc_args = self.head_list[idx][name]
- max_text_length = gtc_args.get('max_text_length', 25)
- nrtr_dim = gtc_args.get('nrtr_dim', 256)
- num_decoder_layers = gtc_args.get('num_decoder_layers', 4)
+ max_text_length = gtc_args.get("max_text_length", 25)
+ nrtr_dim = gtc_args.get("nrtr_dim", 256)
+ num_decoder_layers = gtc_args.get("num_decoder_layers", 4)
if self.use_pos:
self.before_gtc = nn.Sequential(
- nn.Flatten(2), FCTranspose(in_channels, nrtr_dim), AddPos(nrtr_dim, 80))
+ nn.Flatten(2),
+ FCTranspose(in_channels, nrtr_dim),
+ AddPos(nrtr_dim, 80),
+ )
else:
self.before_gtc = nn.Sequential(
- nn.Flatten(2), FCTranspose(in_channels, nrtr_dim))
-
+ nn.Flatten(2), FCTranspose(in_channels, nrtr_dim)
+ )
+
self.gtc_head = Transformer(
d_model=nrtr_dim,
nhead=nrtr_dim // 32,
@@ -92,37 +109,45 @@ def __init__(self, in_channels, out_channels_list, **kwargs):
num_decoder_layers=num_decoder_layers,
max_len=max_text_length,
dim_feedforward=nrtr_dim * 4,
- out_channels=out_channels_list['NRTRLabelDecode'])
- elif name == 'CTCHead':
+ out_channels=out_channels_list["NRTRLabelDecode"],
+ )
+ elif name == "CTCHead":
# ctc neck
self.encoder_reshape = Im2Seq(in_channels)
- neck_args = self.head_list[idx][name]['Neck']
- encoder_type = neck_args.pop('name')
- self.ctc_encoder = SequenceEncoder(in_channels=in_channels, \
- encoder_type=encoder_type, **neck_args)
+ neck_args = self.head_list[idx][name]["Neck"]
+ encoder_type = neck_args.pop("name")
+ self.ctc_encoder = SequenceEncoder(
+ in_channels=in_channels, encoder_type=encoder_type, **neck_args
+ )
# ctc head
- head_args = self.head_list[idx][name]['Head']
- self.ctc_head = eval(name)(in_channels=self.ctc_encoder.out_channels, \
- out_channels=out_channels_list['CTCLabelDecode'], **head_args)
+ head_args = self.head_list[idx][name]["Head"]
+ self.ctc_head = eval(name)(
+ in_channels=self.ctc_encoder.out_channels,
+ out_channels=out_channels_list["CTCLabelDecode"],
+ **head_args
+ )
else:
raise NotImplementedError(
- '{} is not supported in MultiHead yet'.format(name))
+ "{} is not supported in MultiHead yet".format(name)
+ )
def forward(self, x, targets=None):
if self.use_pool:
- x = self.pool(x.reshape([0, 3, -1, self.in_channels]).transpose([0, 3, 1, 2]))
+ x = self.pool(
+ x.reshape([0, 3, -1, self.in_channels]).transpose([0, 3, 1, 2])
+ )
ctc_encoder = self.ctc_encoder(x)
ctc_out = self.ctc_head(ctc_encoder, targets)
head_out = dict()
- head_out['ctc'] = ctc_out
- head_out['ctc_neck'] = ctc_encoder
+ head_out["ctc"] = ctc_out
+ head_out["ctc_neck"] = ctc_encoder
# eval mode
if not self.training:
return ctc_out
- if self.gtc_head == 'sar':
+ if self.gtc_head == "sar":
sar_out = self.sar_head(x, targets[1:])
- head_out['sar'] = sar_out
+ head_out["sar"] = sar_out
else:
gtc_out = self.gtc_head(self.before_gtc(x), targets[1:])
- head_out['nrtr'] = gtc_out
+ head_out["nrtr"] = gtc_out
return head_out
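+
+
+# Summary of the contract above: training returns a dict with the CTC logits
+# ("ctc"), the shared neck features ("ctc_neck"), and one GTC branch ("sar" or
+# "nrtr"); at eval time only the CTC output is returned, so the auxiliary
+# branch adds no inference cost.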
diff --git a/ppocr/modeling/heads/rec_nrtr_head.py b/ppocr/modeling/heads/rec_nrtr_head.py
index 46de11428e..ad01438bee 100644
--- a/ppocr/modeling/heads/rec_nrtr_head.py
+++ b/ppocr/modeling/heads/rec_nrtr_head.py
@@ -40,19 +40,21 @@ class Transformer(nn.Layer):
custom_decoder: custom decoder (default=None).
"""
- def __init__(self,
- d_model=512,
- nhead=8,
- num_encoder_layers=6,
- beam_size=0,
- num_decoder_layers=6,
- max_len=25,
- dim_feedforward=1024,
- attention_dropout_rate=0.0,
- residual_dropout_rate=0.1,
- in_channels=0,
- out_channels=0,
- scale_embedding=True):
+ def __init__(
+ self,
+ d_model=512,
+ nhead=8,
+ num_encoder_layers=6,
+ beam_size=0,
+ num_decoder_layers=6,
+ max_len=25,
+ dim_feedforward=1024,
+ attention_dropout_rate=0.0,
+ residual_dropout_rate=0.1,
+ in_channels=0,
+ out_channels=0,
+ scale_embedding=True,
+ ):
super(Transformer, self).__init__()
self.out_channels = out_channels + 1
self.max_len = max_len
@@ -60,12 +62,32 @@ def __init__(self,
d_model=d_model,
vocab=self.out_channels,
padding_idx=0,
- scale_embedding=scale_embedding)
+ scale_embedding=scale_embedding,
+ )
self.positional_encoding = PositionalEncoding(
- dropout=residual_dropout_rate, dim=d_model)
+ dropout=residual_dropout_rate, dim=d_model
+ )
if num_encoder_layers > 0:
- self.encoder = nn.LayerList([
+ self.encoder = nn.LayerList(
+ [
+ TransformerBlock(
+ d_model,
+ nhead,
+ dim_feedforward,
+ attention_dropout_rate,
+ residual_dropout_rate,
+ with_self_attn=True,
+ with_cross_attn=False,
+ )
+ for i in range(num_encoder_layers)
+ ]
+ )
+ else:
+ self.encoder = None
+
+ self.decoder = nn.LayerList(
+ [
TransformerBlock(
d_model,
nhead,
@@ -73,34 +95,23 @@ def __init__(self,
attention_dropout_rate,
residual_dropout_rate,
with_self_attn=True,
- with_cross_attn=False) for i in range(num_encoder_layers)
- ])
- else:
- self.encoder = None
-
- self.decoder = nn.LayerList([
- TransformerBlock(
- d_model,
- nhead,
- dim_feedforward,
- attention_dropout_rate,
- residual_dropout_rate,
- with_self_attn=True,
- with_cross_attn=True) for i in range(num_decoder_layers)
- ])
+ with_cross_attn=True,
+ )
+ for i in range(num_decoder_layers)
+ ]
+ )
self.beam_size = beam_size
self.d_model = d_model
self.nhead = nhead
- self.tgt_word_prj = nn.Linear(
- d_model, self.out_channels, bias_attr=False)
- w0 = np.random.normal(0.0, d_model**-0.5,
- (d_model, self.out_channels)).astype(np.float32)
+ self.tgt_word_prj = nn.Linear(d_model, self.out_channels, bias_attr=False)
+ w0 = np.random.normal(
+ 0.0, d_model**-0.5, (d_model, self.out_channels)
+ ).astype(np.float32)
self.tgt_word_prj.weight.set_value(w0)
self.apply(self._init_weights)
def _init_weights(self, m):
-
if isinstance(m, nn.Linear):
xavier_normal_(m.weight)
if m.bias is not None:
@@ -140,7 +151,7 @@ def forward(self, src, targets=None):
if self.training:
max_len = targets[1].max()
- tgt = targets[0][:, :2 + max_len]
+ tgt = targets[0][:, : 2 + max_len]
return self.forward_train(src, tgt)
else:
if self.beam_size > 0:
@@ -149,7 +160,6 @@ def forward(self, src, targets=None):
return self.forward_test(src)
def forward_test(self, src):
-
bs = src.shape[0]
if self.encoder is not None:
src = self.positional_encoding(src)
@@ -159,12 +169,11 @@ def forward_test(self, src):
else:
memory = src
dec_seq = paddle.full((bs, 1), 2, dtype=paddle.int64)
- dec_prob = paddle.full((bs, 1), 1., dtype=paddle.float32)
+ dec_prob = paddle.full((bs, 1), 1.0, dtype=paddle.float32)
for len_dec_seq in range(1, paddle.to_tensor(self.max_len)):
dec_seq_embed = self.embedding(dec_seq)
dec_seq_embed = self.positional_encoding(dec_seq_embed)
- tgt_mask = self.generate_square_subsequent_mask(
- dec_seq_embed.shape[1])
+ tgt_mask = self.generate_square_subsequent_mask(dec_seq_embed.shape[1])
tgt = dec_seq_embed
for decoder_layer in self.decoder:
tgt = decoder_layer(tgt, memory, self_mask=tgt_mask)
@@ -173,45 +182,50 @@ def forward_test(self, src):
word_prob = F.softmax(self.tgt_word_prj(dec_output), axis=-1)
preds_idx = paddle.argmax(word_prob, axis=-1)
if paddle.equal_all(
- preds_idx,
- paddle.full(
- preds_idx.shape, 3, dtype='int64')):
+ preds_idx, paddle.full(preds_idx.shape, 3, dtype="int64")
+ ):
break
preds_prob = paddle.max(word_prob, axis=-1)
dec_seq = paddle.concat(
- [dec_seq, paddle.reshape(preds_idx, [-1, 1])], axis=1)
+ [dec_seq, paddle.reshape(preds_idx, [-1, 1])], axis=1
+ )
dec_prob = paddle.concat(
- [dec_prob, paddle.reshape(preds_prob, [-1, 1])], axis=1)
+ [dec_prob, paddle.reshape(preds_prob, [-1, 1])], axis=1
+ )
return [dec_seq, dec_prob]
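+
+ # Greedy decoding, step by step: dec_seq starts as the start token id (2)
+ # for every image, each iteration re-runs the decoder over the full prefix
+ # under a causal mask, appends the argmax token and its probability, and the
+ # loop exits early once every sequence in the batch has emitted the
+ # end token id (3).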
def forward_beam(self, images):
- """ Translation work in one batch """
+ """Translation work in one batch"""
def get_inst_idx_to_tensor_position_map(inst_idx_list):
- """ Indicate the position of an instance in a tensor. """
+ """Indicate the position of an instance in a tensor."""
return {
inst_idx: tensor_position
for tensor_position, inst_idx in enumerate(inst_idx_list)
}
- def collect_active_part(beamed_tensor, curr_active_inst_idx,
- n_prev_active_inst, n_bm):
- """ Collect tensor parts associated to active instances. """
+ def collect_active_part(
+ beamed_tensor, curr_active_inst_idx, n_prev_active_inst, n_bm
+ ):
+ """Collect tensor parts associated to active instances."""
beamed_tensor_shape = beamed_tensor.shape
n_curr_active_inst = len(curr_active_inst_idx)
- new_shape = (n_curr_active_inst * n_bm, beamed_tensor_shape[1],
- beamed_tensor_shape[2])
+ new_shape = (
+ n_curr_active_inst * n_bm,
+ beamed_tensor_shape[1],
+ beamed_tensor_shape[2],
+ )
beamed_tensor = beamed_tensor.reshape([n_prev_active_inst, -1])
- beamed_tensor = beamed_tensor.index_select(
- curr_active_inst_idx, axis=0)
+ beamed_tensor = beamed_tensor.index_select(curr_active_inst_idx, axis=0)
beamed_tensor = beamed_tensor.reshape(new_shape)
return beamed_tensor
- def collate_active_info(src_enc, inst_idx_to_position_map,
- active_inst_idx_list):
+ def collate_active_info(
+ src_enc, inst_idx_to_position_map, active_inst_idx_list
+ ):
# Sentences which are still active are collected,
# so the decoder will not run on completed sentences.
@@ -219,17 +233,19 @@ def collate_active_info(src_enc, inst_idx_to_position_map,
active_inst_idx = [
inst_idx_to_position_map[k] for k in active_inst_idx_list
]
- active_inst_idx = paddle.to_tensor(active_inst_idx, dtype='int64')
+ active_inst_idx = paddle.to_tensor(active_inst_idx, dtype="int64")
active_src_enc = collect_active_part(
- src_enc.transpose([1, 0, 2]), active_inst_idx,
- n_prev_active_inst, n_bm).transpose([1, 0, 2])
+ src_enc.transpose([1, 0, 2]), active_inst_idx, n_prev_active_inst, n_bm
+ ).transpose([1, 0, 2])
active_inst_idx_to_position_map = get_inst_idx_to_tensor_position_map(
- active_inst_idx_list)
+ active_inst_idx_list
+ )
return active_src_enc, active_inst_idx_to_position_map
- def beam_decode_step(inst_dec_beams, len_dec_seq, enc_output,
- inst_idx_to_position_map, n_bm):
- """ Decode and update beam status, and then return active beam idx """
+ def beam_decode_step(
+ inst_dec_beams, len_dec_seq, enc_output, inst_idx_to_position_map, n_bm
+ ):
+ """Decode and update beam status, and then return active beam idx"""
def prepare_beam_dec_seq(inst_dec_beams, len_dec_seq):
dec_partial_seq = [
@@ -242,24 +258,24 @@ def prepare_beam_dec_seq(inst_dec_beams, len_dec_seq):
def predict_word(dec_seq, enc_output, n_active_inst, n_bm):
dec_seq = self.embedding(dec_seq)
dec_seq = self.positional_encoding(dec_seq)
- tgt_mask = self.generate_square_subsequent_mask(
- dec_seq.shape[1])
+ tgt_mask = self.generate_square_subsequent_mask(dec_seq.shape[1])
tgt = dec_seq
for decoder_layer in self.decoder:
tgt = decoder_layer(tgt, enc_output, self_mask=tgt_mask)
dec_output = tgt
- dec_output = dec_output[:,
- -1, :] # Pick the last step: (bh * bm) * d_h
+ dec_output = dec_output[:, -1, :] # Pick the last step: (bh * bm) * d_h
word_prob = F.softmax(self.tgt_word_prj(dec_output), axis=1)
word_prob = paddle.reshape(word_prob, [n_active_inst, n_bm, -1])
return word_prob
- def collect_active_inst_idx_list(inst_beams, word_prob,
- inst_idx_to_position_map):
+ def collect_active_inst_idx_list(
+ inst_beams, word_prob, inst_idx_to_position_map
+ ):
active_inst_idx_list = []
for inst_idx, inst_position in inst_idx_to_position_map.items():
- is_inst_complete = inst_beams[inst_idx].advance(word_prob[
- inst_position])
+ is_inst_complete = inst_beams[inst_idx].advance(
+ word_prob[inst_position]
+ )
if not is_inst_complete:
active_inst_idx_list += [inst_idx]
@@ -270,7 +286,8 @@ def collect_active_inst_idx_list(inst_beams, word_prob,
word_prob = predict_word(dec_seq, enc_output, n_active_inst, n_bm)
# Update the beam with predicted word prob information and collect incomplete instances
active_inst_idx_list = collect_active_inst_idx_list(
- inst_dec_beams, word_prob, inst_idx_to_position_map)
+ inst_dec_beams, word_prob, inst_idx_to_position_map
+ )
return active_inst_idx_list
def collect_hypothesis_and_scores(inst_dec_beams, n_best):
@@ -286,7 +303,7 @@ def collect_hypothesis_and_scores(inst_dec_beams, n_best):
return all_hyp, all_scores
with paddle.no_grad():
- #-- Encode
+ # -- Encode
if self.encoder is not None:
src = self.positional_encoding(images)
src_enc = self.encoder(src)
@@ -300,20 +317,24 @@ def collect_hypothesis_and_scores(inst_dec_beams, n_best):
# Repeat data for beam search
src_enc = paddle.tile(src_enc, [1, n_bm, 1])
inst_idx_to_position_map = get_inst_idx_to_tensor_position_map(
- active_inst_idx_list)
+ active_inst_idx_list
+ )
# Decode
for len_dec_seq in range(1, paddle.to_tensor(self.max_len)):
src_enc_copy = src_enc.clone()
active_inst_idx_list = beam_decode_step(
- inst_dec_beams, len_dec_seq, src_enc_copy,
- inst_idx_to_position_map, n_bm)
+ inst_dec_beams,
+ len_dec_seq,
+ src_enc_copy,
+ inst_idx_to_position_map,
+ n_bm,
+ )
if not active_inst_idx_list:
break # all instances have finished their path to <EOS>
src_enc, inst_idx_to_position_map = collate_active_info(
- src_enc_copy, inst_idx_to_position_map,
- active_inst_idx_list)
- batch_hyp, batch_scores = collect_hypothesis_and_scores(inst_dec_beams,
- 1)
+ src_enc_copy, inst_idx_to_position_map, active_inst_idx_list
+ )
+ batch_hyp, batch_scores = collect_hypothesis_and_scores(inst_dec_beams, 1)
result_hyp = []
hyp_scores = []
for bs_hyp, score in zip(batch_hyp, batch_scores):
@@ -324,20 +345,18 @@ def collect_hypothesis_and_scores(inst_dec_beams, n_best):
hyp_score = [score for _ in range(25)]
hyp_scores.append(hyp_score)
return [
- paddle.to_tensor(
- np.array(result_hyp), dtype=paddle.int64),
- paddle.to_tensor(hyp_scores)
+ paddle.to_tensor(np.array(result_hyp), dtype=paddle.int64),
+ paddle.to_tensor(hyp_scores),
]
def generate_square_subsequent_mask(self, sz):
"""Generate a square mask for the sequence. The masked positions are filled with float('-inf').
- Unmasked positions are filled with float(0.0).
+ Unmasked positions are filled with float(0.0).
"""
- mask = paddle.zeros([sz, sz], dtype='float32')
+ mask = paddle.zeros([sz, sz], dtype="float32")
mask_inf = paddle.triu(
- paddle.full(
- shape=[sz, sz], dtype='float32', fill_value='-inf'),
- diagonal=1)
+ paddle.full(shape=[sz, sz], dtype="float32", fill_value="-inf"), diagonal=1
+ )
mask = mask + mask_inf
return mask.unsqueeze([0, 1])
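
For reference, generate_square_subsequent_mask above builds the standard causal mask. A minimal NumPy sketch (illustrative only, not the PaddleOCR API) producing the same pattern:

import numpy as np

def square_subsequent_mask(sz):
    # -inf above the diagonal blocks attention to future positions;
    # zeros keep the current and past positions visible.
    mask = np.zeros((sz, sz), dtype=np.float32)
    mask[np.triu_indices(sz, k=1)] = -np.inf
    return mask

print(square_subsequent_mask(4))
# [[  0. -inf -inf -inf]
#  [  0.   0. -inf -inf]
#  [  0.   0.   0. -inf]
#  [  0.   0.   0.   0.]]

The Paddle version additionally unsqueezes to [1, 1, sz, sz] so the mask broadcasts over batch and heads.
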
@@ -357,13 +376,15 @@ class MultiheadAttention(nn.Layer):
"""
- def __init__(self, embed_dim, num_heads, dropout=0., self_attn=False):
+ def __init__(self, embed_dim, num_heads, dropout=0.0, self_attn=False):
super(MultiheadAttention, self).__init__()
self.embed_dim = embed_dim
self.num_heads = num_heads
# self.dropout = dropout
self.head_dim = embed_dim // num_heads
- assert self.head_dim * num_heads == self.embed_dim, "embed_dim must be divisible by num_heads"
+ assert (
+ self.head_dim * num_heads == self.embed_dim
+ ), "embed_dim must be divisible by num_heads"
self.scale = self.head_dim**-0.5
self.self_attn = self_attn
if self_attn:
@@ -375,21 +396,27 @@ def __init__(self, embed_dim, num_heads, dropout=0., self_attn=False):
self.out_proj = nn.Linear(embed_dim, embed_dim)
def forward(self, query, key=None, attn_mask=None):
-
qN = query.shape[1]
if self.self_attn:
- qkv = self.qkv(query).reshape(
- (0, qN, 3, self.num_heads, self.head_dim)).transpose(
- (2, 0, 3, 1, 4))
+ qkv = (
+ self.qkv(query)
+ .reshape((0, qN, 3, self.num_heads, self.head_dim))
+ .transpose((2, 0, 3, 1, 4))
+ )
q, k, v = qkv[0], qkv[1], qkv[2]
else:
kN = key.shape[1]
- q = self.q(query).reshape(
- [0, qN, self.num_heads, self.head_dim]).transpose([0, 2, 1, 3])
- kv = self.kv(key).reshape(
- (0, kN, 2, self.num_heads, self.head_dim)).transpose(
- (2, 0, 3, 1, 4))
+ q = (
+ self.q(query)
+ .reshape([0, qN, self.num_heads, self.head_dim])
+ .transpose([0, 2, 1, 3])
+ )
+ kv = (
+ self.kv(key)
+ .reshape((0, kN, 2, self.num_heads, self.head_dim))
+ .transpose((2, 0, 3, 1, 4))
+ )
k, v = kv[0], kv[1]
attn = (q.matmul(k.transpose((0, 1, 3, 2)))) * self.scale
@@ -400,46 +427,48 @@ def forward(self, query, key=None, attn_mask=None):
attn = F.softmax(attn, axis=-1)
attn = self.attn_drop(attn)
- x = (attn.matmul(v)).transpose((0, 2, 1, 3)).reshape(
- (0, qN, self.embed_dim))
+ x = (attn.matmul(v)).transpose((0, 2, 1, 3)).reshape((0, qN, self.embed_dim))
x = self.out_proj(x)
return x
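
The chained reshape/transpose in MultiheadAttention is pure shape bookkeeping. A hedged NumPy sketch under invented sizes (B, N, H, D are assumptions; Paddle's reshape uses 0 to keep a dimension where the sketch spells the batch out):

import numpy as np

B, N, H, D = 2, 5, 4, 8  # batch, tokens, heads, head_dim
qkv = np.random.rand(B, N, 3 * H * D).astype(np.float32)  # fused qkv projection output

qkv = qkv.reshape(B, N, 3, H, D).transpose(2, 0, 3, 1, 4)  # -> (3, B, H, N, D)
q, k, v = qkv[0], qkv[1], qkv[2]

attn = (q @ k.transpose(0, 1, 3, 2)) * D**-0.5  # (B, H, N, N); softmax/dropout omitted
x = (attn @ v).transpose(0, 2, 1, 3).reshape(B, N, H * D)  # back to (B, N, embed_dim)
print(x.shape)  # (2, 5, 32)
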
class TransformerBlock(nn.Layer):
- def __init__(self,
- d_model,
- nhead,
- dim_feedforward=2048,
- attention_dropout_rate=0.0,
- residual_dropout_rate=0.1,
- with_self_attn=True,
- with_cross_attn=False,
- epsilon=1e-5):
+ def __init__(
+ self,
+ d_model,
+ nhead,
+ dim_feedforward=2048,
+ attention_dropout_rate=0.0,
+ residual_dropout_rate=0.1,
+ with_self_attn=True,
+ with_cross_attn=False,
+ epsilon=1e-5,
+ ):
super(TransformerBlock, self).__init__()
self.with_self_attn = with_self_attn
if with_self_attn:
self.self_attn = MultiheadAttention(
- d_model,
- nhead,
- dropout=attention_dropout_rate,
- self_attn=with_self_attn)
+ d_model, nhead, dropout=attention_dropout_rate, self_attn=with_self_attn
+ )
self.norm1 = LayerNorm(d_model, epsilon=epsilon)
self.dropout1 = Dropout(residual_dropout_rate)
self.with_cross_attn = with_cross_attn
if with_cross_attn:
- self.cross_attn = MultiheadAttention( #for self_attn of encoder or cross_attn of decoder
- d_model,
- nhead,
- dropout=attention_dropout_rate)
+ self.cross_attn = (
+ MultiheadAttention( # for self_attn of encoder or cross_attn of decoder
+ d_model, nhead, dropout=attention_dropout_rate
+ )
+ )
self.norm2 = LayerNorm(d_model, epsilon=epsilon)
self.dropout2 = Dropout(residual_dropout_rate)
- self.mlp = Mlp(in_features=d_model,
- hidden_features=dim_feedforward,
- act_layer=nn.ReLU,
- drop=residual_dropout_rate)
+ self.mlp = Mlp(
+ in_features=d_model,
+ hidden_features=dim_feedforward,
+ act_layer=nn.ReLU,
+ drop=residual_dropout_rate,
+ )
self.norm3 = LayerNorm(d_model, epsilon=epsilon)
@@ -481,13 +510,13 @@ def __init__(self, dropout, dim, max_len=5000):
pe = paddle.zeros([max_len, dim])
position = paddle.arange(0, max_len, dtype=paddle.float32).unsqueeze(1)
div_term = paddle.exp(
- paddle.arange(0, dim, 2).astype('float32') *
- (-math.log(10000.0) / dim))
+ paddle.arange(0, dim, 2).astype("float32") * (-math.log(10000.0) / dim)
+ )
pe[:, 0::2] = paddle.sin(position * div_term)
pe[:, 1::2] = paddle.cos(position * div_term)
pe = paddle.unsqueeze(pe, 0)
pe = paddle.transpose(pe, [1, 0, 2])
- self.register_buffer('pe', pe)
+ self.register_buffer("pe", pe)
def forward(self, x):
"""Inputs of forward function
@@ -500,7 +529,7 @@ def forward(self, x):
>>> output = pos_encoder(x)
"""
x = x.transpose([1, 0, 2])
- x = x + self.pe[:x.shape[0], :]
+ x = x + self.pe[: x.shape[0], :]
return self.dropout(x).transpose([1, 0, 2])
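
The div_term expression above computes the usual sinusoidal table in log space: PE[pos, 2i] = sin(pos / 10000^(2i/dim)) and PE[pos, 2i+1] = cos(pos / 10000^(2i/dim)). A self-contained NumPy sketch of the same construction (assumes an even dim, as the module does):

import math
import numpy as np

def sinusoid_table(max_len, dim):
    pe = np.zeros((max_len, dim), dtype=np.float32)
    position = np.arange(max_len, dtype=np.float32)[:, None]
    div_term = np.exp(np.arange(0, dim, 2, dtype=np.float32) * (-math.log(10000.0) / dim))
    pe[:, 0::2] = np.sin(position * div_term)  # even channels
    pe[:, 1::2] = np.cos(position * div_term)  # odd channels
    return pe

pe = sinusoid_table(max_len=5000, dim=512)
print(pe.shape, pe[0, :4])  # (5000, 512) [0. 1. 0. 1.]
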
@@ -528,19 +557,19 @@ def __init__(self, dropout, dim, max_len=5000):
pe = paddle.zeros([max_len, dim])
position = paddle.arange(0, max_len, dtype=paddle.float32).unsqueeze(1)
div_term = paddle.exp(
- paddle.arange(0, dim, 2).astype('float32') *
- (-math.log(10000.0) / dim))
+ paddle.arange(0, dim, 2).astype("float32") * (-math.log(10000.0) / dim)
+ )
pe[:, 0::2] = paddle.sin(position * div_term)
pe[:, 1::2] = paddle.cos(position * div_term)
pe = paddle.transpose(paddle.unsqueeze(pe, 0), [1, 0, 2])
- self.register_buffer('pe', pe)
+ self.register_buffer("pe", pe)
self.avg_pool_1 = nn.AdaptiveAvgPool2D((1, 1))
self.linear1 = nn.Linear(dim, dim)
- self.linear1.weight.data.fill_(1.)
+ self.linear1.weight.data.fill_(1.0)
self.avg_pool_2 = nn.AdaptiveAvgPool2D((1, 1))
self.linear2 = nn.Linear(dim, dim)
- self.linear2.weight.data.fill_(1.)
+ self.linear2.weight.data.fill_(1.0)
def forward(self, x):
"""Inputs of forward function
@@ -552,13 +581,13 @@ def forward(self, x):
Examples:
>>> output = pos_encoder(x)
"""
- w_pe = self.pe[:x.shape[-1], :]
+ w_pe = self.pe[: x.shape[-1], :]
w1 = self.linear1(self.avg_pool_1(x).squeeze()).unsqueeze(0)
w_pe = w_pe * w1
w_pe = paddle.transpose(w_pe, [1, 2, 0])
w_pe = paddle.unsqueeze(w_pe, 2)
- h_pe = self.pe[:x.shape.shape[-2], :]
+ h_pe = self.pe[: x.shape[-2], :]
w2 = self.linear2(self.avg_pool_2(x).squeeze()).unsqueeze(0)
h_pe = h_pe * w2
h_pe = paddle.transpose(h_pe, [1, 2, 0])
@@ -566,9 +595,9 @@ def forward(self, x):
x = x + w_pe + h_pe
x = paddle.transpose(
- paddle.reshape(x,
- [x.shape[0], x.shape[1], x.shape[2] * x.shape[3]]),
- [2, 0, 1])
+ paddle.reshape(x, [x.shape[0], x.shape[1], x.shape[2] * x.shape[3]]),
+ [2, 0, 1],
+ )
return self.dropout(x)
@@ -577,8 +606,7 @@ class Embeddings(nn.Layer):
def __init__(self, d_model, vocab, padding_idx=None, scale_embedding=True):
super(Embeddings, self).__init__()
self.embedding = nn.Embedding(vocab, d_model, padding_idx=padding_idx)
- w0 = np.random.normal(0.0, d_model**-0.5,
- (vocab, d_model)).astype(np.float32)
+ w0 = np.random.normal(0.0, d_model**-0.5, (vocab, d_model)).astype(np.float32)
self.embedding.weight.set_value(w0)
self.d_model = d_model
self.scale_embedding = scale_embedding
@@ -590,20 +618,19 @@ def forward(self, x):
return self.embedding(x)
-class Beam():
- """ Beam search """
+class Beam:
+ """Beam search"""
def __init__(self, size, device=False):
-
self.size = size
self._done = False
# The score for each translation on the beam.
- self.scores = paddle.zeros((size, ), dtype=paddle.float32)
+ self.scores = paddle.zeros((size,), dtype=paddle.float32)
self.all_scores = []
# The backpointers at each time-step.
self.prev_ks = []
# The outputs at each time-step.
- self.next_ys = [paddle.full((size, ), 0, dtype=paddle.int64)]
+ self.next_ys = [paddle.full((size,), 0, dtype=paddle.int64)]
self.next_ys[0][0] = 2
def get_current_state(self):
@@ -629,8 +656,9 @@ def advance(self, word_prob):
beam_lk = word_prob[0]
flat_beam_lk = beam_lk.reshape([-1])
- best_scores, best_scores_id = flat_beam_lk.topk(self.size, 0, True,
- True) # 1st sort
+ best_scores, best_scores_id = flat_beam_lk.topk(
+ self.size, 0, True, True
+ ) # 1st sort
self.all_scores.append(self.scores)
self.scores = best_scores
# bestScoresId is flattened as a (beam x word) array,
@@ -648,7 +676,8 @@ def advance(self, word_prob):
def sort_scores(self):
"Sort the scores."
return self.scores, paddle.to_tensor(
- [i for i in range(int(self.scores.shape[0]))], dtype='int32')
+ [i for i in range(int(self.scores.shape[0]))], dtype="int32"
+ )
def get_the_best_score_and_idx(self):
"Get the score of the best in the beam."
@@ -663,11 +692,11 @@ def get_tentative_hypothesis(self):
_, keys = self.sort_scores()
hyps = [self.get_hypothesis(k) for k in keys]
hyps = [[2] + h for h in hyps]
- dec_seq = paddle.to_tensor(hyps, dtype='int64')
+ dec_seq = paddle.to_tensor(hyps, dtype="int64")
return dec_seq
def get_hypothesis(self, k):
- """ Walk back to construct the full hypothesis. """
+ """Walk back to construct the full hypothesis."""
hyp = []
for j in range(len(self.prev_ks) - 1, -1, -1):
hyp.append(self.next_ys[j + 1][k])
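
The prev_ks/next_ys bookkeeping in Beam resolves to a plain backpointer walk. A self-contained sketch with a hand-made trace (the toy values are invented, not taken from the code):

def walk_back(prev_ks, next_ys, k):
    # prev_ks[j][s] = which beam slot at step j produced slot s at step j+1;
    # next_ys[j+1][s] = the token chosen for that slot.
    hyp = []
    for j in range(len(prev_ks) - 1, -1, -1):
        hyp.append(next_ys[j + 1][k])
        k = prev_ks[j][k]
    return hyp[::-1]

# Beam size 2: step 1 keeps tokens [7, 9]; at step 2 both survivors
# extend slot 0 (the one holding 7), choosing tokens [5, 3].
next_ys = [[2, 2], [7, 9], [5, 3]]  # next_ys[0] is the BOS row
prev_ks = [[0, 1], [0, 0]]
print(walk_back(prev_ks, next_ys, 1))  # [7, 3]
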
diff --git a/ppocr/modeling/heads/rec_parseq_head.py b/ppocr/modeling/heads/rec_parseq_head.py
index a06e1fbb5d..de27d1145b 100644
--- a/ppocr/modeling/heads/rec_parseq_head.py
+++ b/ppocr/modeling/heads/rec_parseq_head.py
@@ -34,55 +34,112 @@
class DecoderLayer(paddle.nn.Layer):
"""A Transformer decoder layer supporting two-stream attention (XLNet)
- This implements a pre-LN decoder, as opposed to the post-LN default in PyTorch."""
-
- def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1, activation='gelu', layer_norm_eps=1e-05):
+ This implements a pre-LN decoder, as opposed to the post-LN default in PyTorch."""
+
+ def __init__(
+ self,
+ d_model,
+ nhead,
+ dim_feedforward=2048,
+ dropout=0.1,
+ activation="gelu",
+ layer_norm_eps=1e-05,
+ ):
super().__init__()
- self.self_attn = paddle.nn.MultiHeadAttention(d_model, nhead, dropout=dropout, need_weights=True) # paddle.nn.MultiHeadAttention defaults to batch_first mode
- self.cross_attn = paddle.nn.MultiHeadAttention(d_model, nhead, dropout=dropout, need_weights=True)
- self.linear1 = paddle.nn.Linear(in_features=d_model, out_features=dim_feedforward)
+ self.self_attn = paddle.nn.MultiHeadAttention(
+ d_model, nhead, dropout=dropout, need_weights=True
+ ) # paddle.nn.MultiHeadAttention defaults to batch_first mode
+ self.cross_attn = paddle.nn.MultiHeadAttention(
+ d_model, nhead, dropout=dropout, need_weights=True
+ )
+ self.linear1 = paddle.nn.Linear(
+ in_features=d_model, out_features=dim_feedforward
+ )
self.dropout = paddle.nn.Dropout(p=dropout)
- self.linear2 = paddle.nn.Linear(in_features=dim_feedforward, out_features=d_model)
- self.norm1 = paddle.nn.LayerNorm(normalized_shape=d_model, epsilon=layer_norm_eps)
- self.norm2 = paddle.nn.LayerNorm(normalized_shape=d_model, epsilon=layer_norm_eps)
- self.norm_q = paddle.nn.LayerNorm(normalized_shape=d_model, epsilon=layer_norm_eps)
- self.norm_c = paddle.nn.LayerNorm(normalized_shape=d_model, epsilon=layer_norm_eps)
+ self.linear2 = paddle.nn.Linear(
+ in_features=dim_feedforward, out_features=d_model
+ )
+ self.norm1 = paddle.nn.LayerNorm(
+ normalized_shape=d_model, epsilon=layer_norm_eps
+ )
+ self.norm2 = paddle.nn.LayerNorm(
+ normalized_shape=d_model, epsilon=layer_norm_eps
+ )
+ self.norm_q = paddle.nn.LayerNorm(
+ normalized_shape=d_model, epsilon=layer_norm_eps
+ )
+ self.norm_c = paddle.nn.LayerNorm(
+ normalized_shape=d_model, epsilon=layer_norm_eps
+ )
self.dropout1 = paddle.nn.Dropout(p=dropout)
self.dropout2 = paddle.nn.Dropout(p=dropout)
self.dropout3 = paddle.nn.Dropout(p=dropout)
- if activation == 'gelu':
+ if activation == "gelu":
self.activation = paddle.nn.GELU()
def __setstate__(self, state):
- if 'activation' not in state:
- state['activation'] = paddle.nn.functional.gelu
+ if "activation" not in state:
+ state["activation"] = paddle.nn.functional.gelu
super().__setstate__(state)
- def forward_stream(self, tgt, tgt_norm, tgt_kv, memory, tgt_mask, tgt_key_padding_mask):
+ def forward_stream(
+ self, tgt, tgt_norm, tgt_kv, memory, tgt_mask, tgt_key_padding_mask
+ ):
"""Forward pass for a single stream (i.e. content or query)
tgt_norm is just a LayerNorm'd tgt. Added as a separate parameter for efficiency.
Both tgt_kv and memory are expected to be LayerNorm'd too.
memory is LayerNorm'd by ViT.
"""
if tgt_key_padding_mask is not None:
- tgt_mask1 = (tgt_mask!=float('-inf'))[None,None,:,:] & (tgt_key_padding_mask[:,None,None,:]==False)
- tgt2, sa_weights = self.self_attn(tgt_norm, tgt_kv, tgt_kv, attn_mask=tgt_mask1)
+ tgt_mask1 = (tgt_mask != float("-inf"))[None, None, :, :] & (
+ tgt_key_padding_mask[:, None, None, :] == False
+ )
+ tgt2, sa_weights = self.self_attn(
+ tgt_norm, tgt_kv, tgt_kv, attn_mask=tgt_mask1
+ )
else:
- tgt2, sa_weights = self.self_attn(tgt_norm, tgt_kv, tgt_kv, attn_mask=tgt_mask)
+ tgt2, sa_weights = self.self_attn(
+ tgt_norm, tgt_kv, tgt_kv, attn_mask=tgt_mask
+ )
tgt = tgt + self.dropout1(tgt2)
tgt2, ca_weights = self.cross_attn(self.norm1(tgt), memory, memory)
tgt = tgt + self.dropout2(tgt2)
- tgt2 = self.linear2(self.dropout(self.activation(self.linear1(self.norm2(tgt)))))
+ tgt2 = self.linear2(
+ self.dropout(self.activation(self.linear1(self.norm2(tgt))))
+ )
tgt = tgt + self.dropout3(tgt2)
return tgt, sa_weights, ca_weights
- def forward(self, query, content, memory, query_mask=None, content_mask=None, content_key_padding_mask=None, update_content=True):
+ def forward(
+ self,
+ query,
+ content,
+ memory,
+ query_mask=None,
+ content_mask=None,
+ content_key_padding_mask=None,
+ update_content=True,
+ ):
query_norm = self.norm_q(query)
content_norm = self.norm_c(content)
- query = self.forward_stream(query, query_norm, content_norm, memory, query_mask, content_key_padding_mask)[0]
+ query = self.forward_stream(
+ query,
+ query_norm,
+ content_norm,
+ memory,
+ query_mask,
+ content_key_padding_mask,
+ )[0]
if update_content:
- content = self.forward_stream(content, content_norm, content_norm, memory, content_mask, content_key_padding_mask)[0]
+ content = self.forward_stream(
+ content,
+ content_norm,
+ content_norm,
+ memory,
+ content_mask,
+ content_key_padding_mask,
+ )[0]
return query, content
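
DecoderLayer.forward runs the same weights twice: once for the query stream (keys/values taken from the content stream) and once for the content stream (self-attention), with the content update skipped on the last layer. A toy structural sketch (stand-in arithmetic, not the real attention):

def decoder_layer(query, content, update_content=True):
    def forward_stream(x, kv):
        return x + kv  # stand-in for attention + MLP residuals
    query = forward_stream(query, content)          # query stream reads content
    if update_content:                              # Decoder passes not last here
        content = forward_stream(content, content)  # content stream self-attends
    return query, content

q, c = 1.0, 2.0
for last in (False, False, True):  # e.g. a 3-layer decoder stack
    q, c = decoder_layer(q, c, update_content=not last)
print(q, c)  # 15.0 8.0
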
@@ -91,7 +148,7 @@ def get_clones(module, N):
class Decoder(paddle.nn.Layer):
- __constants__ = ['norm']
+ __constants__ = ["norm"]
def __init__(self, decoder_layer, num_layers, norm):
super().__init__()
@@ -99,19 +156,36 @@ def __init__(self, decoder_layer, num_layers, norm):
self.num_layers = num_layers
self.norm = norm
- def forward(self, query, content, memory, query_mask: Optional[paddle.Tensor]=None, content_mask: Optional[paddle.Tensor]=None, content_key_padding_mask: Optional[paddle.Tensor]=None):
+ def forward(
+ self,
+ query,
+ content,
+ memory,
+ query_mask: Optional[paddle.Tensor] = None,
+ content_mask: Optional[paddle.Tensor] = None,
+ content_key_padding_mask: Optional[paddle.Tensor] = None,
+ ):
for i, mod in enumerate(self.layers):
last = i == len(self.layers) - 1
- query, content = mod(query, content, memory, query_mask, content_mask, content_key_padding_mask, update_content=not last)
+ query, content = mod(
+ query,
+ content,
+ memory,
+ query_mask,
+ content_mask,
+ content_key_padding_mask,
+ update_content=not last,
+ )
query = self.norm(query)
return query
-
-class TokenEmbedding(paddle.nn.Layer):
+class TokenEmbedding(paddle.nn.Layer):
def __init__(self, charset_size: int, embed_dim: int):
super().__init__()
- self.embedding = paddle.nn.Embedding(num_embeddings=charset_size, embedding_dim=embed_dim)
+ self.embedding = paddle.nn.Embedding(
+ num_embeddings=charset_size, embedding_dim=embed_dim
+ )
self.embed_dim = embed_dim
def forward(self, tokens: paddle.Tensor):
@@ -134,25 +208,54 @@ def kaiming_normal_init(param, **kwargs):
class ParseQHead(nn.Layer):
- def __init__(self, out_channels, max_text_length, embed_dim, dec_num_heads, dec_mlp_ratio, dec_depth, perm_num, perm_forward, perm_mirrored, decode_ar, refine_iters, dropout, **kwargs):
+ def __init__(
+ self,
+ out_channels,
+ max_text_length,
+ embed_dim,
+ dec_num_heads,
+ dec_mlp_ratio,
+ dec_depth,
+ perm_num,
+ perm_forward,
+ perm_mirrored,
+ decode_ar,
+ refine_iters,
+ dropout,
+ **kwargs
+ ):
super().__init__()
self.bos_id = out_channels - 2
self.eos_id = 0
self.pad_id = out_channels - 1
-
+
self.max_label_length = max_text_length
self.decode_ar = decode_ar
self.refine_iters = refine_iters
- decoder_layer = DecoderLayer(embed_dim, dec_num_heads, embed_dim * dec_mlp_ratio, dropout)
- self.decoder = Decoder(decoder_layer, num_layers=dec_depth, norm=paddle.nn.LayerNorm(normalized_shape=embed_dim))
+ decoder_layer = DecoderLayer(
+ embed_dim, dec_num_heads, embed_dim * dec_mlp_ratio, dropout
+ )
+ self.decoder = Decoder(
+ decoder_layer,
+ num_layers=dec_depth,
+ norm=paddle.nn.LayerNorm(normalized_shape=embed_dim),
+ )
self.rng = np.random.default_rng()
self.max_gen_perms = perm_num // 2 if perm_mirrored else perm_num
self.perm_forward = perm_forward
self.perm_mirrored = perm_mirrored
- self.head = paddle.nn.Linear(in_features=embed_dim, out_features=out_channels - 2)
+ self.head = paddle.nn.Linear(
+ in_features=embed_dim, out_features=out_channels - 2
+ )
self.text_embed = TokenEmbedding(out_channels, embed_dim)
- self.pos_queries = paddle.create_parameter(shape=paddle.empty(shape=[1, max_text_length + 1, embed_dim]).shape, dtype=paddle.empty(shape=[1, max_text_length + 1, embed_dim]).numpy().dtype, default_initializer=paddle.nn.initializer.Assign(paddle.empty(shape=[1, max_text_length + 1, embed_dim])))
+ self.pos_queries = paddle.create_parameter(
+ shape=paddle.empty(shape=[1, max_text_length + 1, embed_dim]).shape,
+ dtype=paddle.empty(shape=[1, max_text_length + 1, embed_dim]).numpy().dtype,
+ default_initializer=paddle.nn.initializer.Assign(
+ paddle.empty(shape=[1, max_text_length + 1, embed_dim])
+ ),
+ )
self.pos_queries.stop_gradient = False
self.dropout = paddle.nn.Dropout(p=dropout)
self._device = self.parameters()[0].place
@@ -169,76 +272,120 @@ def _init_weights(self, m):
if m._padding_idx is not None:
m.weight.data[m._padding_idx].zero_()
elif isinstance(m, paddle.nn.Conv2D):
- kaiming_normal_init(m.weight, fan_in=None, nonlinearity='relu')
+ kaiming_normal_init(m.weight, fan_in=None, nonlinearity="relu")
if m.bias is not None:
constant_init(m.bias, value=0.0)
- elif isinstance(m, (paddle.nn.LayerNorm, paddle.nn.BatchNorm2D, paddle.nn.GroupNorm)):
- constant_init(m.weight, value=1.0)
- constant_init(m.bias, value=0.0)
+ elif isinstance(
+ m, (paddle.nn.LayerNorm, paddle.nn.BatchNorm2D, paddle.nn.GroupNorm)
+ ):
+ constant_init(m.weight, value=1.0)
+ constant_init(m.bias, value=0.0)
def no_weight_decay(self):
- param_names = {'text_embed.embedding.weight', 'pos_queries'}
- enc_param_names = {('encoder.' + n) for n in self.encoder.
- no_weight_decay()}
+ param_names = {"text_embed.embedding.weight", "pos_queries"}
+ enc_param_names = {("encoder." + n) for n in self.encoder.no_weight_decay()}
return param_names.union(enc_param_names)
def encode(self, img):
return self.encoder(img)
- def decode(self, tgt, memory, tgt_mask=None, tgt_padding_mask=None, tgt_query=None, tgt_query_mask=None):
+ def decode(
+ self,
+ tgt,
+ memory,
+ tgt_mask=None,
+ tgt_padding_mask=None,
+ tgt_query=None,
+ tgt_query_mask=None,
+ ):
N, L = tgt.shape
null_ctx = self.text_embed(tgt[:, :1])
if L != 1:
- tgt_emb = self.pos_queries[:, :L - 1] + self.text_embed(tgt[:, 1:])
+ tgt_emb = self.pos_queries[:, : L - 1] + self.text_embed(tgt[:, 1:])
tgt_emb = self.dropout(paddle.concat(x=[null_ctx, tgt_emb], axis=1))
else:
tgt_emb = self.dropout(null_ctx)
if tgt_query is None:
tgt_query = self.pos_queries[:, :L].expand(shape=[N, -1, -1])
tgt_query = self.dropout(tgt_query)
- return self.decoder(tgt_query, tgt_emb, memory, tgt_query_mask, tgt_mask, tgt_padding_mask)
+ return self.decoder(
+ tgt_query, tgt_emb, memory, tgt_query_mask, tgt_mask, tgt_padding_mask
+ )
def forward_test(self, memory, max_length=None):
testing = max_length is None
- max_length = self.max_label_length if max_length is None else min(max_length, self.max_label_length)
+ max_length = (
+ self.max_label_length
+ if max_length is None
+ else min(max_length, self.max_label_length)
+ )
bs = memory.shape[0]
num_steps = max_length + 1
pos_queries = self.pos_queries[:, :num_steps].expand(shape=[bs, -1, -1])
- tgt_mask = query_mask = paddle.triu(x=paddle.full(shape=(num_steps, num_steps), fill_value=float('-inf')), diagonal=1)
+ tgt_mask = query_mask = paddle.triu(
+ x=paddle.full(shape=(num_steps, num_steps), fill_value=float("-inf")),
+ diagonal=1,
+ )
if self.decode_ar:
- tgt_in = paddle.full(shape=(bs, num_steps), fill_value=self.pad_id).astype('int64')
+ tgt_in = paddle.full(shape=(bs, num_steps), fill_value=self.pad_id).astype(
+ "int64"
+ )
tgt_in[:, 0] = self.bos_id
logits = []
for i in range(paddle.to_tensor(num_steps)):
j = i + 1
- tgt_out = self.decode(tgt_in[:, :j], memory, tgt_mask[:j, :j], tgt_query=pos_queries[:, i:j], tgt_query_mask=query_mask[i:j, :j])
+ tgt_out = self.decode(
+ tgt_in[:, :j],
+ memory,
+ tgt_mask[:j, :j],
+ tgt_query=pos_queries[:, i:j],
+ tgt_query_mask=query_mask[i:j, :j],
+ )
p_i = self.head(tgt_out)
logits.append(p_i)
if j < num_steps:
tgt_in[:, j] = p_i.squeeze().argmax(axis=-1)
- if testing and (tgt_in == self.eos_id).astype('bool').any(axis=-1).astype('bool').all():
+ if (
+ testing
+ and (tgt_in == self.eos_id)
+ .astype("bool")
+ .any(axis=-1)
+ .astype("bool")
+ .all()
+ ):
break
logits = paddle.concat(x=logits, axis=1)
else:
- tgt_in = paddle.full(shape=(bs, 1), fill_value=self.bos_id).astype('int64')
+ tgt_in = paddle.full(shape=(bs, 1), fill_value=self.bos_id).astype("int64")
tgt_out = self.decode(tgt_in, memory, tgt_query=pos_queries)
logits = self.head(tgt_out)
if self.refine_iters:
- temp = paddle.triu(x=paddle.ones(shape=[num_steps,num_steps], dtype='bool'), diagonal=2)
- posi = np.where(temp.cpu().numpy()==True)
+ temp = paddle.triu(
+ x=paddle.ones(shape=[num_steps, num_steps], dtype="bool"), diagonal=2
+ )
+ posi = np.where(temp.cpu().numpy() == True)
query_mask[posi] = 0
- bos = paddle.full(shape=(bs, 1), fill_value=self.bos_id).astype('int64')
+ bos = paddle.full(shape=(bs, 1), fill_value=self.bos_id).astype("int64")
for i in range(self.refine_iters):
tgt_in = paddle.concat(x=[bos, logits[:, :-1].argmax(axis=-1)], axis=1)
- tgt_padding_mask = (tgt_in == self.eos_id).astype(dtype='int32')
+ tgt_padding_mask = (tgt_in == self.eos_id).astype(dtype="int32")
tgt_padding_mask = tgt_padding_mask.cpu()
tgt_padding_mask = tgt_padding_mask.cumsum(axis=-1) > 0
- tgt_padding_mask = tgt_padding_mask.cuda().astype(dtype='float32')==1.0
- tgt_out = self.decode(tgt_in, memory, tgt_mask, tgt_padding_mask, tgt_query=pos_queries, tgt_query_mask=query_mask[:, :tgt_in.shape[1]])
+ tgt_padding_mask = (
+ tgt_padding_mask.cuda().astype(dtype="float32") == 1.0
+ )
+ tgt_out = self.decode(
+ tgt_in,
+ memory,
+ tgt_mask,
+ tgt_padding_mask,
+ tgt_query=pos_queries,
+ tgt_query_mask=query_mask[:, : tgt_in.shape[1]],
+ )
logits = self.head(tgt_out)
-
+
# convert logits to probabilities
logits = F.softmax(logits, axis=-1)
@@ -248,8 +395,8 @@ def forward_test(self, memory, max_length=None):
def gen_tgt_perms(self, tgt):
"""Generate shared permutations for the whole batch.
- This works because the same attention mask can be used for the shorter sequences
- because of the padding mask.
+ This works because the same attention mask can be used for the shorter sequences
+ because of the padding mask.
"""
max_num_chars = tgt.shape[1] - 2
if max_num_chars == 1:
@@ -264,16 +411,25 @@ def gen_tgt_perms(self, tgt):
selector = [0, 3, 4, 6, 9, 10, 12, 16, 17, 18, 19, 21]
else:
selector = list(range(max_perms))
- perm_pool = paddle.to_tensor(data=list(permutations(range(max_num_chars), max_num_chars)), place=self._device)[selector]
+ perm_pool = paddle.to_tensor(
+ data=list(permutations(range(max_num_chars), max_num_chars)),
+ place=self._device,
+ )[selector]
if self.perm_forward:
perm_pool = perm_pool[1:]
perms = paddle.stack(x=perms)
if len(perm_pool):
- i = self.rng.choice(len(perm_pool), size=num_gen_perms -
- len(perms), replace=False)
+ i = self.rng.choice(
+ len(perm_pool), size=num_gen_perms - len(perms), replace=False
+ )
perms = paddle.concat(x=[perms, perm_pool[i]])
else:
- perms.extend([paddle.randperm(n=max_num_chars) for _ in range(num_gen_perms - len(perms))])
+ perms.extend(
+ [
+ paddle.randperm(n=max_num_chars)
+ for _ in range(num_gen_perms - len(perms))
+ ]
+ )
perms = paddle.stack(x=perms)
if self.perm_mirrored:
comp = perms.flip(axis=-1)
@@ -283,8 +439,9 @@ def gen_tgt_perms(self, tgt):
perm_2[1] = 0
perms = x.transpose(perm=perm_2).reshape((-1, max_num_chars))
bos_idx = paddle.zeros(shape=(len(perms), 1), dtype=perms.dtype)
- eos_idx = paddle.full(shape=(len(perms), 1), fill_value=
- max_num_chars + 1, dtype=perms.dtype)
+ eos_idx = paddle.full(
+ shape=(len(perms), 1), fill_value=max_num_chars + 1, dtype=perms.dtype
+ )
perms = paddle.concat(x=[bos_idx, perms + 1, eos_idx], axis=1)
if len(perms) > 1:
perms[1, 1:] = max_num_chars + 1 - paddle.arange(end=max_num_chars + 1)
@@ -299,12 +456,12 @@ def generate_attn_masks(self, perm):
mask = paddle.zeros(shape=(sz, sz))
for i in range(sz):
query_idx = perm[i].cpu().numpy().tolist()
- masked_keys = perm[i + 1:].cpu().numpy().tolist()
+ masked_keys = perm[i + 1 :].cpu().numpy().tolist()
if len(masked_keys) == 0:
break
- mask[query_idx, masked_keys] = float('-inf')
+ mask[query_idx, masked_keys] = float("-inf")
content_mask = mask[:-1, :-1].clone()
- mask[paddle.eye(num_rows=sz).astype('bool')] = float('-inf')
+ mask[paddle.eye(num_rows=sz).astype("bool")] = float("-inf")
query_mask = mask[1:, :-1]
return content_mask, query_mask
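
generate_attn_masks turns a decoding permutation into the content/query masks of ParseQ's permuted autoregressive training: position perm[i] may not attend to any position that comes later in that permutation, and the query stream additionally may not see itself. A NumPy sketch (illustrative names, not the module's API):

import numpy as np

def attn_masks_from_perm(perm):
    sz = len(perm)
    mask = np.zeros((sz, sz), dtype=np.float32)
    for i in range(sz):
        mask[perm[i], perm[i + 1:]] = -np.inf  # keys "later" in this order
    content_mask = mask[:-1, :-1].copy()
    mask[np.eye(sz, dtype=bool)] = -np.inf     # query stream can't see itself
    query_mask = mask[1:, :-1]
    return content_mask, query_mask

# The identity permutation over [BOS, c1, c2, EOS] recovers ordinary causal masks:
c, q = attn_masks_from_perm(np.array([0, 1, 2, 3]))
print(c)  # strictly lower-triangular visibility (causal) for the content stream
print(q)  # each query position sees only strictly earlier positions
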
@@ -316,16 +473,18 @@ def forward_train(self, memory, tgt):
final_out = {}
for i, perm in enumerate(tgt_perms):
tgt_mask, query_mask = self.generate_attn_masks(perm)
- out = self.decode(tgt_in, memory, tgt_mask, tgt_padding_mask, tgt_query_mask=query_mask)
+ out = self.decode(
+ tgt_in, memory, tgt_mask, tgt_padding_mask, tgt_query_mask=query_mask
+ )
logits = self.head(out)
if i == 0:
- final_out['predict'] = logits
+ final_out["predict"] = logits
logits = logits.flatten(stop_axis=1)
logits_list.append(logits)
- final_out['logits_list'] = logits_list
- final_out['pad_id'] = self.pad_id
- final_out['eos_id'] = self.eos_id
+ final_out["logits_list"] = logits_list
+ final_out["pad_id"] = self.pad_id
+ final_out["eos_id"] = self.eos_id
return final_out
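
One detail of forward_test above worth spelling out: the autoregressive loop stops early once every sequence in the batch has emitted EOS. The chained any/all reduction reads as follows (toy data; eos_id = 0 per the head's convention):

import numpy as np

eos_id = 0
tgt_in = np.array([[38, 5, 0, 39],   # contains EOS
                   [38, 7, 2, 0]])   # contains EOS
done = (tgt_in == eos_id).any(axis=-1).all()
print(done)  # True -> the decode loop would break here
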
diff --git a/ppocr/modeling/heads/rec_rfl_head.py b/ppocr/modeling/heads/rec_rfl_head.py
index 1ded8cde93..9eadfb0dad 100644
--- a/ppocr/modeling/heads/rec_rfl_head.py
+++ b/ppocr/modeling/heads/rec_rfl_head.py
@@ -22,29 +22,24 @@
from .rec_att_head import AttentionLSTM
kaiming_init_ = KaimingNormal()
-zeros_ = Constant(value=0.)
-ones_ = Constant(value=1.)
+zeros_ = Constant(value=0.0)
+ones_ = Constant(value=1.0)
class CNTHead(nn.Layer):
- def __init__(self,
- embed_size=512,
- encode_length=26,
- out_channels=38,
- **kwargs):
+ def __init__(self, embed_size=512, encode_length=26, out_channels=38, **kwargs):
super(CNTHead, self).__init__()
self.out_channels = out_channels
self.Wv_fusion = nn.Linear(embed_size, embed_size, bias_attr=False)
- self.Prediction_visual = nn.Linear(encode_length * embed_size,
- self.out_channels)
+ self.Prediction_visual = nn.Linear(
+ encode_length * embed_size, self.out_channels
+ )
def forward(self, visual_feature):
-
b, c, h, w = visual_feature.shape
- visual_feature = visual_feature.reshape([b, c, h * w]).transpose(
- [0, 2, 1])
+ visual_feature = visual_feature.reshape([b, c, h * w]).transpose([0, 2, 1])
visual_feature_num = self.Wv_fusion(visual_feature) # batch * 26 * 512
b, n, c = visual_feature_num.shape
# use the visual feature directly to predict the text length
@@ -55,15 +50,16 @@ def forward(self, visual_feature):
class RFLHead(nn.Layer):
- def __init__(self,
- in_channels=512,
- hidden_size=256,
- batch_max_legnth=25,
- out_channels=38,
- use_cnt=True,
- use_seq=True,
- **kwargs):
-
+ def __init__(
+ self,
+ in_channels=512,
+ hidden_size=256,
+ batch_max_legnth=25,
+ out_channels=38,
+ use_cnt=True,
+ use_seq=True,
+ **kwargs
+ ):
super(RFLHead, self).__init__()
assert use_cnt or use_seq
self.use_cnt = use_cnt
@@ -73,13 +69,15 @@ def __init__(self,
embed_size=in_channels,
encode_length=batch_max_legnth + 1,
out_channels=out_channels,
- **kwargs)
+ **kwargs
+ )
if self.use_seq:
self.seq_head = AttentionLSTM(
in_channels=in_channels,
out_channels=out_channels,
hidden_size=hidden_size,
- **kwargs)
+ **kwargs
+ )
self.batch_max_legnth = batch_max_legnth
self.num_class = out_channels
self.apply(self.init_weights)
@@ -98,11 +96,11 @@ def forward(self, x, targets=None):
cnt_outputs = None
if self.use_seq:
if self.training:
- seq_outputs = self.seq_head(seq_inputs, targets[0],
- self.batch_max_legnth)
+ seq_outputs = self.seq_head(
+ seq_inputs, targets[0], self.batch_max_legnth
+ )
else:
- seq_outputs = self.seq_head(seq_inputs, None,
- self.batch_max_legnth)
+ seq_outputs = self.seq_head(seq_inputs, None, self.batch_max_legnth)
return cnt_outputs, seq_outputs
else:
return cnt_outputs
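
CNTHead's counting branch flattens the (c, h, w) visual feature into encode_length positions, mixes channels with Wv_fusion, then maps the flattened result to a length distribution. A NumPy shape sketch with random stand-ins for the two Linear layers (all sizes and names here are illustrative):

import numpy as np

b, c, h, w = 2, 512, 1, 26        # encode_length = h * w = 26
visual = np.random.rand(b, c, h, w).astype(np.float32)
Wv = np.random.rand(c, c).astype(np.float32)               # stand-in for Wv_fusion
W_pred = np.random.rand(h * w * c, 38).astype(np.float32)  # stand-in for Prediction_visual

feat = visual.reshape(b, c, h * w).transpose(0, 2, 1)  # (b, 26, 512)
fused = feat @ Wv                                      # (b, 26, 512)
length_logits = fused.reshape(b, -1) @ W_pred          # (b, 38)
print(length_logits.shape)
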
diff --git a/ppocr/modeling/heads/rec_robustscanner_head.py b/ppocr/modeling/heads/rec_robustscanner_head.py
index 6ca60ff109..ba8d07753a 100644
--- a/ppocr/modeling/heads/rec_robustscanner_head.py
+++ b/ppocr/modeling/heads/rec_robustscanner_head.py
@@ -28,6 +28,7 @@
import paddle.nn as nn
import paddle.nn.functional as F
+
class BaseDecoder(nn.Layer):
def __init__(self, **kwargs):
super().__init__()
@@ -38,19 +39,24 @@ def forward_train(self, feat, out_enc, targets, img_metas):
def forward_test(self, feat, out_enc, img_metas):
raise NotImplementedError
- def forward(self,
- feat,
- out_enc,
- label=None,
- valid_ratios=None,
- word_positions=None,
- train_mode=True):
+ def forward(
+ self,
+ feat,
+ out_enc,
+ label=None,
+ valid_ratios=None,
+ word_positions=None,
+ train_mode=True,
+ ):
self.train_mode = train_mode
if train_mode:
- return self.forward_train(feat, out_enc, label, valid_ratios, word_positions)
+ return self.forward_train(
+ feat, out_enc, label, valid_ratios, word_positions
+ )
return self.forward_test(feat, out_enc, valid_ratios, word_positions)
+
class ChannelReductionEncoder(nn.Layer):
"""Change the channel number with a one by one convoluational layer.
@@ -59,14 +65,17 @@ class ChannelReductionEncoder(nn.Layer):
out_channels (int): Number of output channels.
"""
- def __init__(self,
- in_channels,
- out_channels,
- **kwargs):
+ def __init__(self, in_channels, out_channels, **kwargs):
super(ChannelReductionEncoder, self).__init__()
self.layer = nn.Conv2D(
- in_channels, out_channels, kernel_size=1, stride=1, padding=0, weight_attr=nn.initializer.XavierNormal())
+ in_channels,
+ out_channels,
+ kernel_size=1,
+ stride=1,
+ padding=0,
+ weight_attr=nn.initializer.XavierNormal(),
+ )
def forward(self, feat):
"""
@@ -84,12 +93,12 @@ def masked_fill(x, mask, value):
y = paddle.full(x.shape, value, x.dtype)
return paddle.where(mask, y, x)
-class DotProductAttentionLayer(nn.Layer):
+class DotProductAttentionLayer(nn.Layer):
def __init__(self, dim_model=None):
super().__init__()
- self.scale = dim_model**-0.5 if dim_model is not None else 1.
+ self.scale = dim_model**-0.5 if dim_model is not None else 1.0
def forward(self, query, key, value, h, w, valid_ratios=None):
query = paddle.transpose(query, (0, 2, 1))
@@ -103,7 +112,7 @@ def forward(self, query, key, value, h, w, valid_ratios=None):
for i, valid_ratio in enumerate(valid_ratios):
valid_width = min(w, int(w * valid_ratio + 0.5))
if valid_width < w:
- logits[i, :, :, valid_width:] = float('-inf')
+ logits[i, :, :, valid_width:] = float("-inf")
# reshape back to (n, c, t)
logits = paddle.reshape(logits, [n, c, t])
@@ -113,6 +122,7 @@ def forward(self, query, key, value, h, w, valid_ratios=None):
glimpse = paddle.transpose(glimpse, (0, 2, 1))
return glimpse
+
class SequenceAttentionDecoder(BaseDecoder):
"""Sequence attention decoder for RobustScanner.
@@ -143,18 +153,20 @@ class SequenceAttentionDecoder(BaseDecoder):
:obj:`mmocr.models.textrecog.recognizer.EncodeDecodeRecognizer`.
"""
- def __init__(self,
- num_classes=None,
- rnn_layers=2,
- dim_input=512,
- dim_model=128,
- max_seq_len=40,
- start_idx=0,
- mask=True,
- padding_idx=None,
- dropout=0,
- return_feature=False,
- encode_value=False):
+ def __init__(
+ self,
+ num_classes=None,
+ rnn_layers=2,
+ dim_input=512,
+ dim_model=128,
+ max_seq_len=40,
+ start_idx=0,
+ mask=True,
+ padding_idx=None,
+ dropout=0,
+ return_feature=False,
+ encode_value=False,
+ ):
super().__init__()
self.num_classes = num_classes
@@ -167,14 +179,16 @@ def __init__(self,
self.mask = mask
self.embedding = nn.Embedding(
- self.num_classes, self.dim_model, padding_idx=padding_idx)
+ self.num_classes, self.dim_model, padding_idx=padding_idx
+ )
self.sequence_layer = nn.LSTM(
input_size=dim_model,
hidden_size=dim_model,
num_layers=rnn_layers,
time_major=False,
- dropout=dropout)
+ dropout=dropout,
+ )
self.attention_layer = DotProductAttentionLayer()
@@ -182,7 +196,8 @@ def __init__(self,
if not self.return_feature:
pred_num_classes = num_classes - 1
self.prediction = nn.Linear(
- dim_model if encode_value else dim_input, pred_num_classes)
+ dim_model if encode_value else dim_input, pred_num_classes
+ )
def forward_train(self, feat, out_enc, targets, valid_ratios):
"""
@@ -243,12 +258,15 @@ def forward_test(self, feat, out_enc, valid_ratios):
seq_len = self.max_seq_len
batch_size = feat.shape[0]
- decode_sequence = (paddle.ones((batch_size, seq_len), dtype='int64') * self.start_idx)
+ decode_sequence = (
+ paddle.ones((batch_size, seq_len), dtype="int64") * self.start_idx
+ )
outputs = []
for i in range(seq_len):
- step_out = self.forward_test_step(feat, out_enc, decode_sequence,
- i, valid_ratios)
+ step_out = self.forward_test_step(
+ feat, out_enc, decode_sequence, i, valid_ratios
+ )
outputs.append(step_out)
max_idx = paddle.argmax(step_out, axis=1, keepdim=False)
if i < seq_len - 1:
@@ -258,8 +276,9 @@ def forward_test(self, feat, out_enc, valid_ratios):
return outputs
- def forward_test_step(self, feat, out_enc, decode_sequence, current_step,
- valid_ratios):
+ def forward_test_step(
+ self, feat, out_enc, decode_sequence, current_step, valid_ratios
+ ):
"""
Args:
feat (Tensor): Tensor of shape :math:`(N, D_i, H, W)`.
@@ -274,7 +293,7 @@ def forward_test_step(self, feat, out_enc, decode_sequence, current_step,
Tensor: Shape :math:`(N, C-1)`. The logit tensor of predicted
tokens at current time step.
"""
-
+
embed = self.embedding(decode_sequence)
n, c_enc, h, w = out_enc.shape
@@ -306,7 +325,6 @@ def forward_test_step(self, feat, out_enc, decode_sequence, current_step,
class PositionAwareLayer(nn.Layer):
-
def __init__(self, dim_model, rnn_layers=2):
super().__init__()
@@ -316,14 +334,14 @@ def __init__(self, dim_model, rnn_layers=2):
input_size=dim_model,
hidden_size=dim_model,
num_layers=rnn_layers,
- time_major=False)
+ time_major=False,
+ )
self.mixer = nn.Sequential(
- nn.Conv2D(
- dim_model, dim_model, kernel_size=3, stride=1, padding=1),
+ nn.Conv2D(dim_model, dim_model, kernel_size=3, stride=1, padding=1),
nn.ReLU(),
- nn.Conv2D(
- dim_model, dim_model, kernel_size=3, stride=1, padding=1))
+ nn.Conv2D(dim_model, dim_model, kernel_size=3, stride=1, padding=1),
+ )
def forward(self, img_feature):
n, c, h, w = img_feature.shape
@@ -360,18 +378,20 @@ class PositionAttentionDecoder(BaseDecoder):
This decoder will not predict the final class which is assumed to be
`<PAD>`. Therefore, its output size is always :math:`C - 1`. `<PAD>`
is also ignored by loss
-
+
"""
- def __init__(self,
- num_classes=None,
- rnn_layers=2,
- dim_input=512,
- dim_model=128,
- max_seq_len=40,
- mask=True,
- return_feature=False,
- encode_value=False):
+ def __init__(
+ self,
+ num_classes=None,
+ rnn_layers=2,
+ dim_input=512,
+ dim_model=128,
+ max_seq_len=40,
+ mask=True,
+ return_feature=False,
+ encode_value=False,
+ ):
super().__init__()
self.num_classes = num_classes
@@ -384,8 +404,7 @@ def __init__(self,
self.embedding = nn.Embedding(self.max_seq_len + 1, self.dim_model)
- self.position_aware_module = PositionAwareLayer(
- self.dim_model, rnn_layers)
+ self.position_aware_module = PositionAwareLayer(self.dim_model, rnn_layers)
self.attention_layer = DotProductAttentionLayer()
@@ -393,12 +412,13 @@ def __init__(self,
if not self.return_feature:
pred_num_classes = num_classes - 1
self.prediction = nn.Linear(
- dim_model if encode_value else dim_input, pred_num_classes)
+ dim_model if encode_value else dim_input, pred_num_classes
+ )
def _get_position_index(self, length, batch_size):
position_index_list = []
for i in range(batch_size):
- position_index = paddle.arange(0, end=length, step=1, dtype='int64')
+ position_index = paddle.arange(0, end=length, step=1, dtype="int64")
position_index_list.append(position_index)
batch_position_index = paddle.stack(position_index_list, axis=0)
return batch_position_index
@@ -427,16 +447,16 @@ def forward_train(self, feat, out_enc, targets, valid_ratios, position_index):
assert c_feat == self.dim_input
_, len_q = targets.shape
assert len_q <= self.max_seq_len
-
+
position_out_enc = self.position_aware_module(out_enc)
query = self.embedding(position_index)
query = paddle.transpose(query, (0, 2, 1))
key = paddle.reshape(position_out_enc, (n, c_enc, h * w))
if self.encode_value:
- value = paddle.reshape(out_enc,(n, c_enc, h * w))
+ value = paddle.reshape(out_enc, (n, c_enc, h * w))
else:
- value = paddle.reshape(feat,(n, c_feat, h * w))
+ value = paddle.reshape(feat, (n, c_feat, h * w))
attn_out = self.attention_layer(query, key, value, h, w, valid_ratios)
attn_out = paddle.transpose(attn_out, (0, 2, 1)) # [n, len_q, dim_v]
@@ -467,14 +487,14 @@ def forward_test(self, feat, out_enc, valid_ratios, position_index):
assert c_feat == self.dim_input
position_out_enc = self.position_aware_module(out_enc)
-
+
query = self.embedding(position_index)
query = paddle.transpose(query, (0, 2, 1))
key = paddle.reshape(position_out_enc, (n, c_enc, h * w))
if self.encode_value:
- value = paddle.reshape(out_enc,(n, c_enc, h * w))
+ value = paddle.reshape(out_enc, (n, c_enc, h * w))
else:
- value = paddle.reshape(feat,(n, c_feat, h * w))
+ value = paddle.reshape(feat, (n, c_feat, h * w))
attn_out = self.attention_layer(query, key, value, h, w, valid_ratios)
attn_out = paddle.transpose(attn_out, (0, 2, 1)) # [n, len_q, dim_v]
@@ -484,8 +504,8 @@ def forward_test(self, feat, out_enc, valid_ratios, position_index):
return self.prediction(attn_out)
-class RobustScannerFusionLayer(nn.Layer):
+class RobustScannerFusionLayer(nn.Layer):
def __init__(self, dim_model, dim=-1):
super(RobustScannerFusionLayer, self).__init__()
@@ -500,6 +520,7 @@ def forward(self, x0, x1):
output = F.glu(output, self.dim)
return output
+
class RobustScannerDecoder(BaseDecoder):
"""Decoder for RobustScanner.
@@ -527,18 +548,20 @@ class RobustScannerDecoder(BaseDecoder):
:obj:`mmocr.models.textrecog.recognizer.EncodeDecodeRecognizer`.
"""
- def __init__(self,
- num_classes=None,
- dim_input=512,
- dim_model=128,
- hybrid_decoder_rnn_layers=2,
- hybrid_decoder_dropout=0,
- position_decoder_rnn_layers=2,
- max_seq_len=40,
- start_idx=0,
- mask=True,
- padding_idx=None,
- encode_value=False):
+ def __init__(
+ self,
+ num_classes=None,
+ dim_input=512,
+ dim_model=128,
+ hybrid_decoder_rnn_layers=2,
+ hybrid_decoder_dropout=0,
+ position_decoder_rnn_layers=2,
+ max_seq_len=40,
+ start_idx=0,
+ mask=True,
+ padding_idx=None,
+ encode_value=False,
+ ):
super().__init__()
self.num_classes = num_classes
self.dim_input = dim_input
@@ -561,7 +584,7 @@ def __init__(self,
padding_idx=padding_idx,
dropout=hybrid_decoder_dropout,
encode_value=encode_value,
- return_feature=True
+ return_feature=True,
)
# init position decoder
@@ -573,16 +596,17 @@ def __init__(self,
max_seq_len=max_seq_len,
mask=mask,
encode_value=encode_value,
- return_feature=True
+ return_feature=True,
)
-
self.fusion_module = RobustScannerFusionLayer(
- self.dim_model if encode_value else dim_input)
+ self.dim_model if encode_value else dim_input
+ )
pred_num_classes = num_classes - 1
- self.prediction = nn.Linear(dim_model if encode_value else dim_input,
- pred_num_classes)
+ self.prediction = nn.Linear(
+ dim_model if encode_value else dim_input, pred_num_classes
+ )
def forward_train(self, feat, out_enc, target, valid_ratios, word_positions):
"""
@@ -593,16 +617,18 @@ def forward_train(self, feat, out_enc, target, valid_ratios, word_positions):
target (dict): A dict with the key ``padded_targets``, a
tensor of shape :math:`(N, T)`. Each element is the index of a
character.
- valid_ratios (Tensor):
+ valid_ratios (Tensor): Valid width ratio (valid width / padded width) of each image.
word_positions (Tensor): The position of each word.
Returns:
Tensor: A raw logit tensor of shape :math:`(N, T, C-1)`.
"""
hybrid_glimpse = self.hybrid_decoder.forward_train(
- feat, out_enc, target, valid_ratios)
+ feat, out_enc, target, valid_ratios
+ )
position_glimpse = self.position_decoder.forward_train(
- feat, out_enc, target, valid_ratios, word_positions)
+ feat, out_enc, target, valid_ratios, word_positions
+ )
fusion_out = self.fusion_module(hybrid_glimpse, position_glimpse)
@@ -616,7 +642,7 @@ def forward_test(self, feat, out_enc, valid_ratios, word_positions):
feat (Tensor): Tensor of shape :math:`(N, D_i, H, W)`.
out_enc (Tensor): Encoder output of shape
:math:`(N, D_m, H, W)`.
- valid_ratios (Tensor):
+ valid_ratios (Tensor): Valid width ratio (valid width / padded width) of each image.
word_positions (Tensor): The position of each word.
Returns:
Tensor: The output logit sequence tensor of shape
@@ -625,18 +651,23 @@ def forward_test(self, feat, out_enc, valid_ratios, word_positions):
seq_len = self.max_seq_len
batch_size = feat.shape[0]
- decode_sequence = (paddle.ones((batch_size, seq_len), dtype='int64') * self.start_idx)
+ decode_sequence = (
+ paddle.ones((batch_size, seq_len), dtype="int64") * self.start_idx
+ )
position_glimpse = self.position_decoder.forward_test(
- feat, out_enc, valid_ratios, word_positions)
+ feat, out_enc, valid_ratios, word_positions
+ )
outputs = []
for i in range(seq_len):
hybrid_glimpse_step = self.hybrid_decoder.forward_test_step(
- feat, out_enc, decode_sequence, i, valid_ratios)
+ feat, out_enc, decode_sequence, i, valid_ratios
+ )
- fusion_out = self.fusion_module(hybrid_glimpse_step,
- position_glimpse[:, i, :])
+ fusion_out = self.fusion_module(
+ hybrid_glimpse_step, position_glimpse[:, i, :]
+ )
char_out = self.prediction(fusion_out)
char_out = F.softmax(char_out, -1)
@@ -649,28 +680,32 @@ def forward_test(self, feat, out_enc, valid_ratios, word_positions):
return outputs
+
class RobustScannerHead(nn.Layer):
- def __init__(self,
- out_channels, # 90 + unknown + start + padding
- in_channels,
- enc_outchannles=128,
- hybrid_dec_rnn_layers=2,
- hybrid_dec_dropout=0,
- position_dec_rnn_layers=2,
- start_idx=0,
- max_text_length=40,
- mask=True,
- padding_idx=None,
- encode_value=False,
- **kwargs):
+ def __init__(
+ self,
+ out_channels, # 90 + unknown + start + padding
+ in_channels,
+ enc_outchannles=128,
+ hybrid_dec_rnn_layers=2,
+ hybrid_dec_dropout=0,
+ position_dec_rnn_layers=2,
+ start_idx=0,
+ max_text_length=40,
+ mask=True,
+ padding_idx=None,
+ encode_value=False,
+ **kwargs
+ ):
super(RobustScannerHead, self).__init__()
# encoder module
self.encoder = ChannelReductionEncoder(
- in_channels=in_channels, out_channels=enc_outchannles)
+ in_channels=in_channels, out_channels=enc_outchannles
+ )
# decoder module
- self.decoder =RobustScannerDecoder(
+ self.decoder = RobustScannerDecoder(
num_classes=out_channels,
dim_input=in_channels,
dim_model=enc_outchannles,
@@ -681,30 +716,33 @@ def __init__(self,
start_idx=start_idx,
mask=mask,
padding_idx=padding_idx,
- encode_value=encode_value)
+ encode_value=encode_value,
+ )
def forward(self, inputs, targets=None):
- '''
+ """
targets: [label, valid_ratio, word_positions]
- '''
+ """
out_enc = self.encoder(inputs)
valid_ratios = None
word_positions = targets[-1]
if len(targets) > 1:
valid_ratios = targets[-2]
-
+
if self.training:
label = targets[0] # label
- label = paddle.to_tensor(label, dtype='int64')
+ label = paddle.to_tensor(label, dtype="int64")
final_out = self.decoder(
- inputs, out_enc, label, valid_ratios, word_positions)
+ inputs, out_enc, label, valid_ratios, word_positions
+ )
if not self.training:
final_out = self.decoder(
inputs,
out_enc,
label=None,
- valid_ratios=valid_ratios,
+ valid_ratios=valid_ratios,
word_positions=word_positions,
- train_mode=False)
+ train_mode=False,
+ )
return final_out
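
The valid_ratios plumbing running through this file ends in DotProductAttentionLayer, where columns beyond each image's valid width get -inf logits so padded regions receive zero attention after softmax. A NumPy sketch of that masking (shapes simplified to (n, c, h, w); the real code flattens to (n, c, t) first):

import numpy as np

n, c, h, w = 2, 3, 2, 8
logits = np.zeros((n, c, h, w), dtype=np.float32)
valid_ratios = [1.0, 0.5]  # image 1 fills only half the padded width

for i, r in enumerate(valid_ratios):
    valid_width = min(w, int(w * r + 0.5))
    if valid_width < w:
        logits[i, :, :, valid_width:] = -np.inf

weights = np.exp(logits) / np.exp(logits).sum(axis=-1, keepdims=True)
print(weights[1, 0, 0])  # [0.25 0.25 0.25 0.25 0.   0.   0.   0.  ]
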
diff --git a/ppocr/modeling/heads/rec_sar_head.py b/ppocr/modeling/heads/rec_sar_head.py
index b301787adb..906cff0d8c 100644
--- a/ppocr/modeling/heads/rec_sar_head.py
+++ b/ppocr/modeling/heads/rec_sar_head.py
@@ -39,14 +39,16 @@ class SAREncoder(nn.Layer):
mask (bool): If True, mask padding in RNN sequence.
"""
- def __init__(self,
- enc_bi_rnn=False,
- enc_drop_rnn=0.1,
- enc_gru=False,
- d_model=512,
- d_enc=512,
- mask=True,
- **kwargs):
+ def __init__(
+ self,
+ enc_bi_rnn=False,
+ enc_drop_rnn=0.1,
+ enc_gru=False,
+ d_model=512,
+ d_enc=512,
+ mask=True,
+ **kwargs
+ ):
super().__init__()
assert isinstance(enc_bi_rnn, bool)
assert isinstance(enc_drop_rnn, (int, float))
@@ -62,16 +64,17 @@ def __init__(self,
# LSTM Encoder
if enc_bi_rnn:
- direction = 'bidirectional'
+ direction = "bidirectional"
else:
- direction = 'forward'
+ direction = "forward"
kwargs = dict(
input_size=d_model,
hidden_size=d_enc,
num_layers=2,
time_major=False,
dropout=enc_drop_rnn,
- direction=direction)
+ direction=direction,
+ )
if enc_gru:
self.rnn_encoder = nn.GRU(**kwargs)
else:
@@ -90,8 +93,7 @@ def forward(self, feat, img_metas=None):
valid_ratios = img_metas[-1]
h_feat = feat.shape[2] # bsz c h w
- feat_v = F.max_pool2d(
- feat, kernel_size=(h_feat, 1), stride=1, padding=0)
+ feat_v = F.max_pool2d(feat, kernel_size=(h_feat, 1), stride=1, padding=0)
feat_v = feat_v.squeeze(2) # bsz * C * W
feat_v = paddle.transpose(feat_v, perm=[0, 2, 1]) # bsz * W * C
holistic_feat = self.rnn_encoder(feat_v)[0] # bsz * T * C
@@ -100,8 +102,10 @@ def forward(self, feat, img_metas=None):
valid_hf = []
T = paddle.shape(holistic_feat)[1]
for i in range(valid_ratios.shape[0]):
- valid_step = paddle.minimum(
- T, paddle.ceil(valid_ratios[i] * T).astype(T.dtype)) - 1
+ valid_step = (
+ paddle.minimum(T, paddle.ceil(valid_ratios[i] * T).astype(T.dtype))
+ - 1
+ )
valid_hf.append(holistic_feat[i, valid_step, :])
valid_hf = paddle.stack(valid_hf, axis=0)
else:
@@ -121,12 +125,7 @@ def forward_train(self, feat, out_enc, targets, img_metas):
def forward_test(self, feat, out_enc, img_metas):
raise NotImplementedError
- def forward(self,
- feat,
- out_enc,
- label=None,
- img_metas=None,
- train_mode=True):
+ def forward(self, feat, out_enc, label=None, img_metas=None, train_mode=True):
self.train_mode = train_mode
if train_mode:
@@ -155,20 +154,21 @@ class ParallelSARDecoder(BaseDecoder):
"""
def __init__(
- self,
- out_channels, # 90 + unknown + start + padding
- enc_bi_rnn=False,
- dec_bi_rnn=False,
- dec_drop_rnn=0.0,
- dec_gru=False,
- d_model=512,
- d_enc=512,
- d_k=64,
- pred_dropout=0.1,
- max_text_length=30,
- mask=True,
- pred_concat=True,
- **kwargs):
+ self,
+ out_channels, # 90 + unknown + start + padding
+ enc_bi_rnn=False,
+ dec_bi_rnn=False,
+ dec_drop_rnn=0.0,
+ dec_gru=False,
+ d_model=512,
+ d_enc=512,
+ d_k=64,
+ pred_dropout=0.1,
+ max_text_length=30,
+ mask=True,
+ pred_concat=True,
+ **kwargs
+ ):
super().__init__()
self.num_classes = out_channels
@@ -185,15 +185,14 @@ def __init__(
# 2D attention layer
self.conv1x1_1 = nn.Linear(decoder_rnn_out_size, d_k)
- self.conv3x3_1 = nn.Conv2D(
- d_model, d_k, kernel_size=3, stride=1, padding=1)
+ self.conv3x3_1 = nn.Conv2D(d_model, d_k, kernel_size=3, stride=1, padding=1)
self.conv1x1_2 = nn.Linear(d_k, 1)
# Decoder RNN layer
if dec_bi_rnn:
- direction = 'bidirectional'
+ direction = "bidirectional"
else:
- direction = 'forward'
+ direction = "forward"
kwargs = dict(
input_size=encoder_rnn_out_size,
@@ -201,7 +200,8 @@ def __init__(
num_layers=2,
time_major=False,
dropout=dec_drop_rnn,
- direction=direction)
+ direction=direction,
+ )
if dec_gru:
self.rnn_decoder = nn.GRU(**kwargs)
else:
@@ -209,9 +209,8 @@ def __init__(
# Decoder input embedding
self.embedding = nn.Embedding(
- self.num_classes,
- encoder_rnn_out_size,
- padding_idx=self.padding_idx)
+ self.num_classes, encoder_rnn_out_size, padding_idx=self.padding_idx
+ )
# Prediction layer
self.pred_dropout = nn.Dropout(pred_dropout)
@@ -222,12 +221,7 @@ def __init__(
fc_in_channel = d_model
self.prediction = nn.Linear(fc_in_channel, pred_num_classes)
- def _2d_attention(self,
- decoder_input,
- feat,
- holistic_feat,
- valid_ratios=None):
-
+ def _2d_attention(self, decoder_input, feat, holistic_feat, valid_ratios=None):
y = self.rnn_decoder(decoder_input)[0]
# y: bsz * (seq_len + 1) * hidden_size
@@ -255,9 +249,10 @@ def _2d_attention(self,
# cal mask of attention weight
for i in range(valid_ratios.shape[0]):
valid_width = paddle.minimum(
- w, paddle.ceil(valid_ratios[i] * w).astype("int32"))
+ w, paddle.ceil(valid_ratios[i] * w).astype("int32")
+ )
if valid_width < w:
- attn_weight[i, :, :, valid_width:, :] = float('-inf')
+ attn_weight[i, :, :, valid_width:, :] = float("-inf")
attn_weight = paddle.reshape(attn_weight, [bsz, T, -1])
attn_weight = F.softmax(attn_weight, axis=-1)
@@ -266,19 +261,20 @@ def _2d_attention(self,
attn_weight = paddle.transpose(attn_weight, perm=[0, 1, 4, 2, 3])
# attn_weight: bsz * T * c * h * w
# feat: bsz * c * h * w
- attn_feat = paddle.sum(paddle.multiply(feat.unsqueeze(1), attn_weight),
- (3, 4),
- keepdim=False)
+ attn_feat = paddle.sum(
+ paddle.multiply(feat.unsqueeze(1), attn_weight), (3, 4), keepdim=False
+ )
# bsz * (seq_len + 1) * C
# Linear transformation
if self.pred_concat:
hf_c = holistic_feat.shape[-1]
- holistic_feat = paddle.expand(
- holistic_feat, shape=[bsz, seq_len, hf_c])
+ holistic_feat = paddle.expand(holistic_feat, shape=[bsz, seq_len, hf_c])
y = self.prediction(
- paddle.concat((y, attn_feat.astype(y.dtype),
- holistic_feat.astype(y.dtype)), 2))
+ paddle.concat(
+ (y, attn_feat.astype(y.dtype), holistic_feat.astype(y.dtype)), 2
+ )
+ )
else:
y = self.prediction(attn_feat)
# bsz * (seq_len + 1) * num_classes
@@ -288,9 +284,9 @@ def _2d_attention(self,
return y
def forward_train(self, feat, out_enc, label, img_metas):
- '''
+ """
img_metas: [label, valid_ratio]
- '''
+ """
if img_metas is not None:
assert img_metas[0].shape[0] == feat.shape[0]
@@ -304,8 +300,7 @@ def forward_train(self, feat, out_enc, label, img_metas):
# bsz * 1 * emb_dim
in_dec = paddle.concat((out_enc, lab_embedding), axis=1)
# bsz * (seq_len + 1) * C
- out_dec = self._2d_attention(
- in_dec, feat, out_enc, valid_ratios=valid_ratios)
+ out_dec = self._2d_attention(in_dec, feat, out_enc, valid_ratios=valid_ratios)
return out_dec[:, 1:, :] # bsz * seq_len * num_classes
@@ -319,8 +314,7 @@ def forward_test(self, feat, out_enc, img_metas):
seq_len = self.max_seq_len
bsz = feat.shape[0]
- start_token = paddle.full(
- (bsz, ), fill_value=self.start_idx, dtype='int64')
+ start_token = paddle.full((bsz,), fill_value=self.start_idx, dtype="int64")
# bsz
start_token = self.embedding(start_token)
# bsz * emb_dim
@@ -336,7 +330,8 @@ def forward_test(self, feat, out_enc, img_metas):
outputs = []
for i in range(1, seq_len + 1):
decoder_output = self._2d_attention(
- decoder_input, feat, out_enc, valid_ratios=valid_ratios)
+ decoder_input, feat, out_enc, valid_ratios=valid_ratios
+ )
char_output = decoder_output[:, i, :] # bsz * num_classes
char_output = F.softmax(char_output, -1)
outputs.append(char_output)
@@ -351,21 +346,23 @@ def forward_test(self, feat, out_enc, img_metas):
class SARHead(nn.Layer):
- def __init__(self,
- in_channels,
- out_channels,
- enc_dim=512,
- max_text_length=30,
- enc_bi_rnn=False,
- enc_drop_rnn=0.1,
- enc_gru=False,
- dec_bi_rnn=False,
- dec_drop_rnn=0.0,
- dec_gru=False,
- d_k=512,
- pred_dropout=0.1,
- pred_concat=True,
- **kwargs):
+ def __init__(
+ self,
+ in_channels,
+ out_channels,
+ enc_dim=512,
+ max_text_length=30,
+ enc_bi_rnn=False,
+ enc_drop_rnn=0.1,
+ enc_gru=False,
+ dec_bi_rnn=False,
+ dec_drop_rnn=0.0,
+ dec_gru=False,
+ d_k=512,
+ pred_dropout=0.1,
+ pred_concat=True,
+ **kwargs
+ ):
super(SARHead, self).__init__()
# encoder module
@@ -374,7 +371,8 @@ def __init__(self,
enc_drop_rnn=enc_drop_rnn,
enc_gru=enc_gru,
d_model=in_channels,
- d_enc=enc_dim)
+ d_enc=enc_dim,
+ )
# decoder module
self.decoder = ParallelSARDecoder(
@@ -388,25 +386,22 @@ def __init__(self,
d_k=d_k,
pred_dropout=pred_dropout,
max_text_length=max_text_length,
- pred_concat=pred_concat)
+ pred_concat=pred_concat,
+ )
def forward(self, feat, targets=None):
- '''
+ """
img_metas: [label, valid_ratio]
- '''
+ """
holistic_feat = self.encoder(feat, targets) # bsz c
if self.training:
label = targets[0] # label
- final_out = self.decoder(
- feat, holistic_feat, label, img_metas=targets)
+ final_out = self.decoder(feat, holistic_feat, label, img_metas=targets)
else:
final_out = self.decoder(
- feat,
- holistic_feat,
- label=None,
- img_metas=targets,
- train_mode=False)
+ feat, holistic_feat, label=None, img_metas=targets, train_mode=False
+ )
# (bsz, seq_len, num_classes)
return final_out
diff --git a/ppocr/modeling/heads/rec_satrn_head.py b/ppocr/modeling/heads/rec_satrn_head.py
index ce7ae05450..6367fc51a8 100644
--- a/ppocr/modeling/heads/rec_satrn_head.py
+++ b/ppocr/modeling/heads/rec_satrn_head.py
@@ -29,13 +29,9 @@
class ConvBNLayer(nn.Layer):
- def __init__(self,
- num_channels,
- filter_size,
- num_filters,
- stride,
- padding,
- num_groups=1):
+ def __init__(
+ self, num_channels, filter_size, num_filters, stride, padding, num_groups=1
+ ):
super(ConvBNLayer, self).__init__()
self.conv = nn.Conv2D(
@@ -45,12 +41,14 @@ def __init__(self,
stride=stride,
padding=padding,
groups=num_groups,
- bias_attr=False)
+ bias_attr=False,
+ )
self.bn = nn.BatchNorm2D(
num_filters,
weight_attr=ParamAttr(initializer=Constant(1)),
- bias_attr=ParamAttr(initializer=Constant(0)))
+ bias_attr=ParamAttr(initializer=Constant(0)),
+ )
self.relu = nn.ReLU()
def forward(self, inputs):
@@ -61,21 +59,23 @@ def forward(self, inputs):
class SATRNEncoderLayer(nn.Layer):
- def __init__(self,
- d_model=512,
- d_inner=512,
- n_head=8,
- d_k=64,
- d_v=64,
- dropout=0.1,
- qkv_bias=False):
+ def __init__(
+ self,
+ d_model=512,
+ d_inner=512,
+ n_head=8,
+ d_k=64,
+ d_v=64,
+ dropout=0.1,
+ qkv_bias=False,
+ ):
super().__init__()
self.norm1 = nn.LayerNorm(d_model)
self.attn = MultiHeadAttention(
- n_head, d_model, d_k, d_v, qkv_bias=qkv_bias, dropout=dropout)
+ n_head, d_model, d_k, d_v, qkv_bias=qkv_bias, dropout=dropout
+ )
self.norm2 = nn.LayerNorm(d_model)
- self.feed_forward = LocalityAwareFeedforward(
- d_model, d_inner, dropout=dropout)
+ self.feed_forward = LocalityAwareFeedforward(d_model, d_inner, dropout=dropout)
def forward(self, x, h, w, mask=None):
n, hw, c = x.shape
@@ -93,15 +93,17 @@ def forward(self, x, h, w, mask=None):
class LocalityAwareFeedforward(nn.Layer):
def __init__(
- self,
- d_in,
- d_hid,
- dropout=0.1, ):
+ self,
+ d_in,
+ d_hid,
+ dropout=0.1,
+ ):
super().__init__()
self.conv1 = ConvBNLayer(d_in, 1, d_hid, stride=1, padding=0)
self.depthwise_conv = ConvBNLayer(
- d_hid, 3, d_hid, stride=1, padding=1, num_groups=d_hid)
+ d_hid, 3, d_hid, stride=1, padding=1, num_groups=d_hid
+ )
self.conv2 = ConvBNLayer(d_hid, 1, d_in, stride=1, padding=0)
@@ -125,8 +127,8 @@ def __init__(self, d_hid=512, n_height=100, n_width=100, dropout=0.1):
w_position_encoder = w_position_encoder.transpose([1, 0])
w_position_encoder = w_position_encoder.reshape([1, d_hid, 1, n_width])
- self.register_buffer('h_position_encoder', h_position_encoder)
- self.register_buffer('w_position_encoder', w_position_encoder)
+ self.register_buffer("h_position_encoder", h_position_encoder)
+ self.register_buffer("w_position_encoder", w_position_encoder)
self.h_scale = self.scale_factor_generate(d_hid)
self.w_scale = self.scale_factor_generate(d_hid)
@@ -135,13 +137,11 @@ def __init__(self, d_hid=512, n_height=100, n_width=100, dropout=0.1):
def _get_sinusoid_encoding_table(self, n_position, d_hid):
"""Sinusoid position encoding table."""
- denominator = paddle.to_tensor([
- 1.0 / np.power(10000, 2 * (hid_j // 2) / d_hid)
- for hid_j in range(d_hid)
- ])
+ denominator = paddle.to_tensor(
+ [1.0 / np.power(10000, 2 * (hid_j // 2) / d_hid) for hid_j in range(d_hid)]
+ )
denominator = denominator.reshape([1, -1])
- pos_tensor = paddle.cast(
- paddle.arange(n_position).unsqueeze(-1), 'float32')
+ pos_tensor = paddle.cast(paddle.arange(n_position).unsqueeze(-1), "float32")
sinusoid_table = pos_tensor * denominator
sinusoid_table[:, 0::2] = paddle.sin(sinusoid_table[:, 0::2])
sinusoid_table[:, 1::2] = paddle.cos(sinusoid_table[:, 1::2])
@@ -151,7 +151,10 @@ def _get_sinusoid_encoding_table(self, n_position, d_hid):
def scale_factor_generate(self, d_hid):
scale_factor = nn.Sequential(
nn.Conv2D(d_hid, d_hid, 1),
- nn.ReLU(), nn.Conv2D(d_hid, d_hid, 1), nn.Sigmoid())
+ nn.ReLU(),
+ nn.Conv2D(d_hid, d_hid, 1),
+ nn.Sigmoid(),
+ )
return scale_factor
@@ -160,10 +163,8 @@ def forward(self, x):
avg_pool = self.pool(x)
- h_pos_encoding = \
- self.h_scale(avg_pool) * self.h_position_encoder[:, :, :h, :]
- w_pos_encoding = \
- self.w_scale(avg_pool) * self.w_position_encoder[:, :, :, :w]
+ h_pos_encoding = self.h_scale(avg_pool) * self.h_position_encoder[:, :, :h, :]
+ w_pos_encoding = self.w_scale(avg_pool) * self.w_position_encoder[:, :, :, :w]
out = x + h_pos_encoding + w_pos_encoding
@@ -196,13 +197,9 @@ def masked_fill(x, mask, value):
class MultiHeadAttention(nn.Layer):
- def __init__(self,
- n_head=8,
- d_model=512,
- d_k=64,
- d_v=64,
- dropout=0.1,
- qkv_bias=False):
+ def __init__(
+ self, n_head=8, d_model=512, d_k=64, d_v=64, dropout=0.1, qkv_bias=False
+ ):
super().__init__()
self.n_head = n_head
self.d_k = d_k
@@ -228,8 +225,11 @@ def forward(self, q, k, v, mask=None):
k = self.linear_k(k).reshape([batch_size, len_k, self.n_head, self.d_k])
v = self.linear_v(v).reshape([batch_size, len_k, self.n_head, self.d_v])
- q, k, v = q.transpose([0, 2, 1, 3]), k.transpose(
- [0, 2, 1, 3]), v.transpose([0, 2, 1, 3])
+ q, k, v = (
+ q.transpose([0, 2, 1, 3]),
+ k.transpose([0, 2, 1, 3]),
+ v.transpose([0, 2, 1, 3]),
+ )
if mask is not None:
if mask.dim() == 3:
@@ -240,7 +240,8 @@ def forward(self, q, k, v, mask=None):
attn_out, _ = self.attention(q, k, v, mask=mask)
attn_out = attn_out.transpose([0, 2, 1, 3]).reshape(
- [batch_size, len_q, self.dim_v])
+ [batch_size, len_q, self.dim_v]
+ )
attn_out = self.fc(attn_out)
attn_out = self.proj_drop(attn_out)
@@ -249,27 +250,28 @@ def forward(self, q, k, v, mask=None):
class SATRNEncoder(nn.Layer):
- def __init__(self,
- n_layers=12,
- n_head=8,
- d_k=64,
- d_v=64,
- d_model=512,
- n_position=100,
- d_inner=256,
- dropout=0.1):
+ def __init__(
+ self,
+ n_layers=12,
+ n_head=8,
+ d_k=64,
+ d_v=64,
+ d_model=512,
+ n_position=100,
+ d_inner=256,
+ dropout=0.1,
+ ):
super().__init__()
self.d_model = d_model
self.position_enc = Adaptive2DPositionalEncoding(
- d_hid=d_model,
- n_height=n_position,
- n_width=n_position,
- dropout=dropout)
- self.layer_stack = nn.LayerList([
- SATRNEncoderLayer(
- d_model, d_inner, n_head, d_k, d_v, dropout=dropout)
- for _ in range(n_layers)
- ])
+ d_hid=d_model, n_height=n_position, n_width=n_position, dropout=dropout
+ )
+ self.layer_stack = nn.LayerList(
+ [
+ SATRNEncoderLayer(d_model, d_inner, n_head, d_k, d_v, dropout=dropout)
+ for _ in range(n_layers)
+ ]
+ )
self.layer_norm = nn.LayerNorm(d_model)
def forward(self, feat, valid_ratios=None):
@@ -284,7 +286,7 @@ def forward(self, feat, valid_ratios=None):
"""
if valid_ratios is None:
bs = feat.shape[0]
- valid_ratios = paddle.full((bs, 1), 1., dtype=paddle.float32)
+ valid_ratios = paddle.full((bs, 1), 1.0, dtype=paddle.float32)
feat = self.position_enc(feat)
n, c, h, w = feat.shape
@@ -330,18 +332,16 @@ def __init__(self, d_hid=512, n_position=200, dropout=0):
# Not a parameter
# Position table of shape (1, n_position, d_hid)
self.register_buffer(
- 'position_table',
- self._get_sinusoid_encoding_table(n_position, d_hid))
+ "position_table", self._get_sinusoid_encoding_table(n_position, d_hid)
+ )
def _get_sinusoid_encoding_table(self, n_position, d_hid):
"""Sinusoid position encoding table."""
- denominator = paddle.to_tensor([
- 1.0 / np.power(10000, 2 * (hid_j // 2) / d_hid)
- for hid_j in range(d_hid)
- ])
+ denominator = paddle.to_tensor(
+ [1.0 / np.power(10000, 2 * (hid_j // 2) / d_hid) for hid_j in range(d_hid)]
+ )
denominator = denominator.reshape([1, -1])
- pos_tensor = paddle.cast(
- paddle.arange(n_position).unsqueeze(-1), 'float32')
+ pos_tensor = paddle.cast(paddle.arange(n_position).unsqueeze(-1), "float32")
sinusoid_table = pos_tensor * denominator
sinusoid_table[:, 0::2] = paddle.sin(sinusoid_table[:, 0::2])
sinusoid_table[:, 1::2] = paddle.cos(sinusoid_table[:, 1::2])
@@ -349,20 +349,22 @@ def _get_sinusoid_encoding_table(self, n_position, d_hid):
return sinusoid_table.unsqueeze(0)
def forward(self, x):
- x = x + self.position_table[:, :x.shape[1]].clone().detach()
+ x = x + self.position_table[:, : x.shape[1]].clone().detach()
return self.dropout(x)
class TFDecoderLayer(nn.Layer):
- def __init__(self,
- d_model=512,
- d_inner=256,
- n_head=8,
- d_k=64,
- d_v=64,
- dropout=0.1,
- qkv_bias=False,
- operation_order=None):
+ def __init__(
+ self,
+ d_model=512,
+ d_inner=256,
+ n_head=8,
+ d_k=64,
+ d_v=64,
+ dropout=0.1,
+ qkv_bias=False,
+ operation_order=None,
+ ):
super().__init__()
self.norm1 = nn.LayerNorm(d_model)
@@ -370,52 +372,74 @@ def __init__(self,
self.norm3 = nn.LayerNorm(d_model)
self.self_attn = MultiHeadAttention(
- n_head, d_model, d_k, d_v, dropout=dropout, qkv_bias=qkv_bias)
+ n_head, d_model, d_k, d_v, dropout=dropout, qkv_bias=qkv_bias
+ )
self.enc_attn = MultiHeadAttention(
- n_head, d_model, d_k, d_v, dropout=dropout, qkv_bias=qkv_bias)
+ n_head, d_model, d_k, d_v, dropout=dropout, qkv_bias=qkv_bias
+ )
self.mlp = PositionwiseFeedForward(d_model, d_inner, dropout=dropout)
self.operation_order = operation_order
if self.operation_order is None:
- self.operation_order = ('norm', 'self_attn', 'norm', 'enc_dec_attn',
- 'norm', 'ffn')
+ self.operation_order = (
+ "norm",
+ "self_attn",
+ "norm",
+ "enc_dec_attn",
+ "norm",
+ "ffn",
+ )
assert self.operation_order in [
- ('norm', 'self_attn', 'norm', 'enc_dec_attn', 'norm', 'ffn'),
- ('self_attn', 'norm', 'enc_dec_attn', 'norm', 'ffn', 'norm')
+ ("norm", "self_attn", "norm", "enc_dec_attn", "norm", "ffn"),
+ ("self_attn", "norm", "enc_dec_attn", "norm", "ffn", "norm"),
]
- def forward(self,
- dec_input,
- enc_output,
- self_attn_mask=None,
- dec_enc_attn_mask=None):
- if self.operation_order == ('self_attn', 'norm', 'enc_dec_attn', 'norm',
- 'ffn', 'norm'):
- dec_attn_out = self.self_attn(dec_input, dec_input, dec_input,
- self_attn_mask)
+ def forward(
+ self, dec_input, enc_output, self_attn_mask=None, dec_enc_attn_mask=None
+ ):
+ if self.operation_order == (
+ "self_attn",
+ "norm",
+ "enc_dec_attn",
+ "norm",
+ "ffn",
+ "norm",
+ ):
+ dec_attn_out = self.self_attn(
+ dec_input, dec_input, dec_input, self_attn_mask
+ )
dec_attn_out += dec_input
dec_attn_out = self.norm1(dec_attn_out)
- enc_dec_attn_out = self.enc_attn(dec_attn_out, enc_output,
- enc_output, dec_enc_attn_mask)
+ enc_dec_attn_out = self.enc_attn(
+ dec_attn_out, enc_output, enc_output, dec_enc_attn_mask
+ )
enc_dec_attn_out += dec_attn_out
enc_dec_attn_out = self.norm2(enc_dec_attn_out)
mlp_out = self.mlp(enc_dec_attn_out)
mlp_out += enc_dec_attn_out
mlp_out = self.norm3(mlp_out)
- elif self.operation_order == ('norm', 'self_attn', 'norm',
- 'enc_dec_attn', 'norm', 'ffn'):
+ elif self.operation_order == (
+ "norm",
+ "self_attn",
+ "norm",
+ "enc_dec_attn",
+ "norm",
+ "ffn",
+ ):
dec_input_norm = self.norm1(dec_input)
- dec_attn_out = self.self_attn(dec_input_norm, dec_input_norm,
- dec_input_norm, self_attn_mask)
+ dec_attn_out = self.self_attn(
+ dec_input_norm, dec_input_norm, dec_input_norm, self_attn_mask
+ )
dec_attn_out += dec_input
enc_dec_attn_in = self.norm2(dec_attn_out)
- enc_dec_attn_out = self.enc_attn(enc_dec_attn_in, enc_output,
- enc_output, dec_enc_attn_mask)
+ enc_dec_attn_out = self.enc_attn(
+ enc_dec_attn_in, enc_output, enc_output, dec_enc_attn_mask
+ )
enc_dec_attn_out += dec_attn_out
mlp_out = self.mlp(self.norm3(enc_dec_attn_out))
@@ -425,20 +449,22 @@ def forward(self,
class SATRNDecoder(nn.Layer):
- def __init__(self,
- n_layers=6,
- d_embedding=512,
- n_head=8,
- d_k=64,
- d_v=64,
- d_model=512,
- d_inner=256,
- n_position=200,
- dropout=0.1,
- num_classes=93,
- max_seq_len=40,
- start_idx=1,
- padding_idx=92):
+ def __init__(
+ self,
+ n_layers=6,
+ d_embedding=512,
+ n_head=8,
+ d_k=64,
+ d_v=64,
+ d_model=512,
+ d_inner=256,
+ n_position=200,
+ dropout=0.1,
+ num_classes=93,
+ max_seq_len=40,
+ start_idx=1,
+ padding_idx=92,
+ ):
super().__init__()
self.padding_idx = padding_idx
@@ -446,17 +472,18 @@ def __init__(self,
self.max_seq_len = max_seq_len
self.trg_word_emb = nn.Embedding(
- num_classes, d_embedding, padding_idx=padding_idx)
+ num_classes, d_embedding, padding_idx=padding_idx
+ )
- self.position_enc = PositionalEncoding(
- d_embedding, n_position=n_position)
+ self.position_enc = PositionalEncoding(d_embedding, n_position=n_position)
self.dropout = nn.Dropout(p=dropout)
- self.layer_stack = nn.LayerList([
- TFDecoderLayer(
- d_model, d_inner, n_head, d_k, d_v, dropout=dropout)
- for _ in range(n_layers)
- ])
+ self.layer_stack = nn.LayerList(
+ [
+ TFDecoderLayer(d_model, d_inner, n_head, d_k, d_v, dropout=dropout)
+ for _ in range(n_layers)
+ ]
+ )
self.layer_norm = nn.LayerNorm(d_model, epsilon=1e-6)
pred_num_class = num_classes - 1 # ignore padding_idx
@@ -464,16 +491,14 @@ def __init__(self,
@staticmethod
def get_pad_mask(seq, pad_idx):
-
return (seq != pad_idx).unsqueeze(-2)
@staticmethod
def get_subsequent_mask(seq):
"""For masking out the subsequent info."""
len_s = seq.shape[1]
- subsequent_mask = 1 - paddle.triu(
- paddle.ones((len_s, len_s)), diagonal=1)
- subsequent_mask = paddle.cast(subsequent_mask.unsqueeze(0), 'bool')
+ subsequent_mask = 1 - paddle.triu(paddle.ones((len_s, len_s)), diagonal=1)
+ subsequent_mask = paddle.cast(subsequent_mask.unsqueeze(0), "bool")
return subsequent_mask
@@ -483,15 +508,13 @@ def _attention(self, trg_seq, src, src_mask=None):
tgt = self.dropout(trg_pos_encoded)
trg_mask = self.get_pad_mask(
- trg_seq,
- pad_idx=self.padding_idx) & self.get_subsequent_mask(trg_seq)
+ trg_seq, pad_idx=self.padding_idx
+ ) & self.get_subsequent_mask(trg_seq)
output = tgt
for dec_layer in self.layer_stack:
output = dec_layer(
- output,
- src,
- self_attn_mask=trg_mask,
- dec_enc_attn_mask=src_mask)
+ output, src, self_attn_mask=trg_mask, dec_enc_attn_mask=src_mask
+ )
output = self.layer_norm(output)
return output
@@ -518,17 +541,20 @@ def forward_test(self, feat, out_enc, valid_ratio):
src_mask = self._get_mask(out_enc, valid_ratio)
N = out_enc.shape[0]
init_target_seq = paddle.full(
- (N, self.max_seq_len + 1), self.padding_idx, dtype='int64')
+ (N, self.max_seq_len + 1), self.padding_idx, dtype="int64"
+ )
# bsz * seq_len
init_target_seq[:, 0] = self.start_idx
outputs = []
for step in range(0, paddle.to_tensor(self.max_seq_len)):
decoder_output = self._attention(
- init_target_seq, out_enc, src_mask=src_mask)
+ init_target_seq, out_enc, src_mask=src_mask
+ )
# bsz * seq_len * C
step_result = F.softmax(
- self.classifier(decoder_output[:, step, :]), axis=-1)
+ self.classifier(decoder_output[:, step, :]), axis=-1
+ )
# bsz * num_classes
outputs.append(step_result)
step_max_index = paddle.argmax(step_result, axis=-1)
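
Both _get_sinusoid_encoding_table implementations in this file build the standard transformer table PE(pos, 2i) = sin(pos / 10000^(2i/d_hid)) and PE(pos, 2i+1) = cos(pos / 10000^(2i/d_hid)). A standalone NumPy equivalent, included only to make the reformatted comprehension easier to follow:

import numpy as np

def sinusoid_table(n_position, d_hid):
    # 1 / 10000^(2*(i//2)/d_hid), matching the denominator comprehension above
    denom = np.array([1.0 / np.power(10000, 2 * (i // 2) / d_hid) for i in range(d_hid)])
    table = np.arange(n_position, dtype="float32")[:, None] * denom[None, :]
    table[:, 0::2] = np.sin(table[:, 0::2])  # even dims: sin
    table[:, 1::2] = np.cos(table[:, 1::2])  # odd dims: cos
    return table[None, ...]  # (1, n_position, d_hid)

print(sinusoid_table(4, 6).shape)  # (1, 4, 6)
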
diff --git a/ppocr/modeling/heads/rec_spin_att_head.py b/ppocr/modeling/heads/rec_spin_att_head.py
index d4d364aad9..87f3dadde1 100644
--- a/ppocr/modeling/heads/rec_spin_att_head.py
+++ b/ppocr/modeling/heads/rec_spin_att_head.py
@@ -34,7 +34,8 @@ def __init__(self, in_channels, out_channels, hidden_size, **kwargs):
self.num_classes = out_channels
self.attention_cell = AttentionLSTMCell(
- in_channels, hidden_size, out_channels, use_gru=False)
+ in_channels, hidden_size, out_channels, use_gru=False
+ )
self.generator = nn.Linear(hidden_size, out_channels)
def _char_to_onehot(self, input_char, onehot_dim):
@@ -43,21 +44,25 @@ def _char_to_onehot(self, input_char, onehot_dim):
def forward(self, inputs, targets=None, batch_max_length=25):
batch_size = inputs.shape[0]
- num_steps = batch_max_length + 1 # +1 for [sos] at end of sentence
+ num_steps = batch_max_length + 1 # +1 for [sos] at end of sentence
- hidden = (paddle.zeros((batch_size, self.hidden_size)),
- paddle.zeros((batch_size, self.hidden_size)))
+ hidden = (
+ paddle.zeros((batch_size, self.hidden_size)),
+ paddle.zeros((batch_size, self.hidden_size)),
+ )
output_hiddens = []
- if self.training: # for train
+ if self.training: # for train
targets = targets[0]
for i in range(num_steps):
char_onehots = self._char_to_onehot(
- targets[:, i], onehot_dim=self.num_classes)
- (outputs, hidden), alpha = self.attention_cell(hidden, inputs,
- char_onehots)
+ targets[:, i], onehot_dim=self.num_classes
+ )
+ (outputs, hidden), alpha = self.attention_cell(
+ hidden, inputs, char_onehots
+ )
output_hiddens.append(paddle.unsqueeze(outputs, axis=1))
output = paddle.concat(output_hiddens, axis=1)
- probs = self.generator(output)
+ probs = self.generator(output)
else:
targets = paddle.zeros(shape=[batch_size], dtype="int32")
probs = None
@@ -67,16 +72,18 @@ def forward(self, inputs, targets=None, batch_max_length=25):
for i in range(num_steps):
char_onehots = self._char_to_onehot(
- targets, onehot_dim=self.num_classes)
- (outputs, hidden), alpha = self.attention_cell(hidden, inputs,
- char_onehots)
+ targets, onehot_dim=self.num_classes
+ )
+ (outputs, hidden), alpha = self.attention_cell(
+ hidden, inputs, char_onehots
+ )
probs_step = self.generator(outputs)
if probs is None:
probs = paddle.unsqueeze(probs_step, axis=1)
else:
probs = paddle.concat(
- [probs, paddle.unsqueeze(
- probs_step, axis=1)], axis=1)
+ [probs, paddle.unsqueeze(probs_step, axis=1)], axis=1
+ )
next_input = probs_step.argmax(axis=1)
targets = next_input
if not self.training:
@@ -92,10 +99,12 @@ def __init__(self, input_size, hidden_size, num_embeddings, use_gru=False):
self.score = nn.Linear(hidden_size, 1, bias_attr=False)
if not use_gru:
self.rnn = nn.LSTMCell(
- input_size=input_size + num_embeddings, hidden_size=hidden_size)
+ input_size=input_size + num_embeddings, hidden_size=hidden_size
+ )
else:
self.rnn = nn.GRUCell(
- input_size=input_size + num_embeddings, hidden_size=hidden_size)
+ input_size=input_size + num_embeddings, hidden_size=hidden_size
+ )
self.hidden_size = hidden_size
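
At inference time the attention head in this file decodes greedily: the argmax of each step's probabilities becomes the next character input. A self-contained sketch of that loop, where step_fn is a hypothetical stand-in for the attention cell plus generator:

import numpy as np

def greedy_decode(step_fn, batch_size, num_steps):
    # step_fn(targets) -> (batch_size, num_classes) scores for the next char
    targets = np.zeros(batch_size, dtype="int32")  # start token index 0
    steps = []
    for _ in range(num_steps):
        probs_step = step_fn(targets)
        steps.append(probs_step)
        targets = probs_step.argmax(axis=1).astype("int32")  # feed back argmax
    return np.stack(steps, axis=1)  # (batch_size, num_steps, num_classes)

rng = np.random.default_rng(0)
out = greedy_decode(lambda t: rng.normal(size=(2, 5)), batch_size=2, num_steps=3)
print(out.shape)  # (2, 3, 5)
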
diff --git a/ppocr/modeling/heads/rec_srn_head.py b/ppocr/modeling/heads/rec_srn_head.py
index 1070d8cd64..a56038ddf4 100644
--- a/ppocr/modeling/heads/rec_srn_head.py
+++ b/ppocr/modeling/heads/rec_srn_head.py
@@ -27,12 +27,20 @@
from ppocr.modeling.backbones.rec_resnet_fpn import ResNetFPN
from collections import OrderedDict
+
gradient_clip = 10
class PVAM(nn.Layer):
- def __init__(self, in_channels, char_num, max_text_length, num_heads,
- num_encoder_tus, hidden_dims):
+ def __init__(
+ self,
+ in_channels,
+ char_num,
+ max_text_length,
+ num_heads,
+ num_encoder_tus,
+ hidden_dims,
+ ):
super(PVAM, self).__init__()
self.char_num = char_num
self.max_length = max_text_length
@@ -56,18 +64,22 @@ def __init__(self, in_channels, char_num, max_text_length, num_heads,
relu_dropout=0.1,
preprocess_cmd="n",
postprocess_cmd="da",
- weight_sharing=True)
+ weight_sharing=True,
+ )
# PVAM
self.flatten0 = paddle.nn.Flatten(start_axis=0, stop_axis=1)
self.fc0 = paddle.nn.Linear(
in_features=in_channels,
- out_features=in_channels, )
+ out_features=in_channels,
+ )
self.emb = paddle.nn.Embedding(
- num_embeddings=self.max_length, embedding_dim=in_channels)
+ num_embeddings=self.max_length, embedding_dim=in_channels
+ )
self.flatten1 = paddle.nn.Flatten(start_axis=0, stop_axis=2)
self.fc1 = paddle.nn.Linear(
- in_features=in_channels, out_features=1, bias_attr=False)
+ in_features=in_channels, out_features=1, bias_attr=False
+ )
def forward(self, inputs, encoder_word_pos, gsrm_word_pos):
b, c, h, w = inputs.shape
@@ -85,23 +97,34 @@ def forward(self, inputs, encoder_word_pos, gsrm_word_pos):
word_features_ = paddle.reshape(word_features, [-1, 1, t, c])
word_features_ = paddle.tile(word_features_, [1, self.max_length, 1, 1])
word_pos_feature = self.emb(gsrm_word_pos)
- word_pos_feature_ = paddle.reshape(word_pos_feature,
- [-1, self.max_length, 1, c])
+ word_pos_feature_ = paddle.reshape(
+ word_pos_feature, [-1, self.max_length, 1, c]
+ )
word_pos_feature_ = paddle.tile(word_pos_feature_, [1, 1, t, 1])
y = word_pos_feature_ + word_features_
y = F.tanh(y)
attention_weight = self.fc1(y)
attention_weight = paddle.reshape(
- attention_weight, shape=[-1, self.max_length, t])
+ attention_weight, shape=[-1, self.max_length, t]
+ )
attention_weight = F.softmax(attention_weight, axis=-1)
- pvam_features = paddle.matmul(attention_weight,
- word_features) #[b, max_length, c]
+ pvam_features = paddle.matmul(
+ attention_weight, word_features
+ ) # [b, max_length, c]
return pvam_features
class GSRM(nn.Layer):
- def __init__(self, in_channels, char_num, max_text_length, num_heads,
- num_encoder_tus, num_decoder_tus, hidden_dims):
+ def __init__(
+ self,
+ in_channels,
+ char_num,
+ max_text_length,
+ num_heads,
+ num_encoder_tus,
+ num_decoder_tus,
+ hidden_dims,
+ ):
super(GSRM, self).__init__()
self.char_num = char_num
self.max_length = max_text_length
@@ -110,8 +133,7 @@ def __init__(self, in_channels, char_num, max_text_length, num_heads,
self.num_decoder_TUs = num_decoder_tus
self.hidden_dims = hidden_dims
- self.fc0 = paddle.nn.Linear(
- in_features=in_channels, out_features=self.char_num)
+ self.fc0 = paddle.nn.Linear(in_features=in_channels, out_features=self.char_num)
self.wrap_encoder0 = WrapEncoder(
src_vocab_size=self.char_num + 1,
max_length=self.max_length,
@@ -126,7 +148,8 @@ def __init__(self, in_channels, char_num, max_text_length, num_heads,
relu_dropout=0.1,
preprocess_cmd="n",
postprocess_cmd="da",
- weight_sharing=True)
+ weight_sharing=True,
+ )
self.wrap_encoder1 = WrapEncoder(
src_vocab_size=self.char_num + 1,
@@ -142,14 +165,14 @@ def __init__(self, in_channels, char_num, max_text_length, num_heads,
relu_dropout=0.1,
preprocess_cmd="n",
postprocess_cmd="da",
- weight_sharing=True)
+ weight_sharing=True,
+ )
- self.mul = lambda x: paddle.matmul(x=x,
- y=self.wrap_encoder0.prepare_decoder.emb0.weight,
- transpose_y=True)
+ self.mul = lambda x: paddle.matmul(
+ x=x, y=self.wrap_encoder0.prepare_decoder.emb0.weight, transpose_y=True
+ )
- def forward(self, inputs, gsrm_word_pos, gsrm_slf_attn_bias1,
- gsrm_slf_attn_bias2):
+ def forward(self, inputs, gsrm_word_pos, gsrm_slf_attn_bias1, gsrm_slf_attn_bias2):
# ===== GSRM Visual-to-semantic embedding block =====
b, t, c = inputs.shape
pvam_features = paddle.reshape(inputs, [-1, c])
@@ -157,7 +180,7 @@ def forward(self, inputs, gsrm_word_pos, gsrm_slf_attn_bias1,
word_ids = paddle.argmax(F.softmax(word_out), axis=1)
word_ids = paddle.reshape(x=word_ids, shape=[-1, t, 1])
- #===== GSRM Semantic reasoning block =====
+ # ===== GSRM Semantic reasoning block =====
"""
This module is built from a pair of transformers (bi-transformers);
ngram_feature1 is the forward one, ngram_feature2 is the backward one
@@ -176,10 +199,11 @@ def forward(self, inputs, gsrm_word_pos, gsrm_slf_attn_bias1,
gsrm_feature1 = self.wrap_encoder0(enc_inputs_1)
gsrm_feature2 = self.wrap_encoder1(enc_inputs_2)
- gsrm_feature2 = F.pad(gsrm_feature2, [0, 1],
- value=0.,
- data_format="NLC")
- gsrm_feature2 = gsrm_feature2[:, 1:, ]
+ gsrm_feature2 = F.pad(gsrm_feature2, [0, 1], value=0.0, data_format="NLC")
+ gsrm_feature2 = gsrm_feature2[
+ :,
+ 1:,
+ ]
gsrm_features = gsrm_feature1 + gsrm_feature2
gsrm_out = self.mul(gsrm_features)
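
The pad-and-slice step above aligns GSRM's two streams: the forward transformer's feature at position t summarizes the characters before t, the backward one's summarizes the characters after t, so the backward stream is shifted left by one (zero-filled at the end) before the two are summed. A small NumPy sketch of that alignment (the function name is illustrative):

import numpy as np

def fuse_bidirectional(f_forward, f_backward):
    # shift the backward stream left by one so both encode context around t,
    # mirroring the F.pad(..., [0, 1]) followed by [:, 1:] above
    tail = np.zeros_like(f_backward[:, :1, :])
    return f_forward + np.concatenate([f_backward[:, 1:, :], tail], axis=1)

rng = np.random.default_rng(1)
out = fuse_bidirectional(rng.normal(size=(2, 5, 4)), rng.normal(size=(2, 5, 4)))
print(out.shape)  # (2, 5, 4)
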
@@ -194,23 +218,21 @@ class VSFD(nn.Layer):
def __init__(self, in_channels=512, pvam_ch=512, char_num=38):
super(VSFD, self).__init__()
self.char_num = char_num
- self.fc0 = paddle.nn.Linear(
- in_features=in_channels * 2, out_features=pvam_ch)
- self.fc1 = paddle.nn.Linear(
- in_features=pvam_ch, out_features=self.char_num)
+ self.fc0 = paddle.nn.Linear(in_features=in_channels * 2, out_features=pvam_ch)
+ self.fc1 = paddle.nn.Linear(in_features=pvam_ch, out_features=self.char_num)
def forward(self, pvam_feature, gsrm_feature):
b, t, c1 = pvam_feature.shape
b, t, c2 = gsrm_feature.shape
combine_feature_ = paddle.concat([pvam_feature, gsrm_feature], axis=2)
- img_comb_feature_ = paddle.reshape(
- combine_feature_, shape=[-1, c1 + c2])
+ img_comb_feature_ = paddle.reshape(combine_feature_, shape=[-1, c1 + c2])
img_comb_feature_map = self.fc0(img_comb_feature_)
img_comb_feature_map = F.sigmoid(img_comb_feature_map)
- img_comb_feature_map = paddle.reshape(
- img_comb_feature_map, shape=[-1, t, c1])
- combine_feature = img_comb_feature_map * pvam_feature + (
- 1.0 - img_comb_feature_map) * gsrm_feature
+ img_comb_feature_map = paddle.reshape(img_comb_feature_map, shape=[-1, t, c1])
+ combine_feature = (
+ img_comb_feature_map * pvam_feature
+ + (1.0 - img_comb_feature_map) * gsrm_feature
+ )
img_comb_feature = paddle.reshape(combine_feature, shape=[-1, c1])
out = self.fc1(img_comb_feature)
@@ -218,8 +240,17 @@ def forward(self, pvam_feature, gsrm_feature):
class SRNHead(nn.Layer):
- def __init__(self, in_channels, out_channels, max_text_length, num_heads,
- num_encoder_TUs, num_decoder_TUs, hidden_dims, **kwargs):
+ def __init__(
+ self,
+ in_channels,
+ out_channels,
+ max_text_length,
+ num_heads,
+ num_encoder_TUs,
+ num_decoder_TUs,
+ hidden_dims,
+ **kwargs
+ ):
super(SRNHead, self).__init__()
self.char_num = out_channels
self.max_length = max_text_length
@@ -234,7 +265,8 @@ def __init__(self, in_channels, out_channels, max_text_length, num_heads,
max_text_length=self.max_length,
num_heads=self.num_heads,
num_encoder_tus=self.num_encoder_TUs,
- hidden_dims=self.hidden_dims)
+ hidden_dims=self.hidden_dims,
+ )
self.gsrm = GSRM(
in_channels=in_channels,
@@ -243,10 +275,13 @@ def __init__(self, in_channels, out_channels, max_text_length, num_heads,
num_heads=self.num_heads,
num_encoder_tus=self.num_encoder_TUs,
num_decoder_tus=self.num_decoder_TUs,
- hidden_dims=self.hidden_dims)
+ hidden_dims=self.hidden_dims,
+ )
self.vsfd = VSFD(in_channels=in_channels, char_num=self.char_num)
- self.gsrm.wrap_encoder1.prepare_decoder.emb0 = self.gsrm.wrap_encoder0.prepare_decoder.emb0
+ self.gsrm.wrap_encoder1.prepare_decoder.emb0 = (
+ self.gsrm.wrap_encoder0.prepare_decoder.emb0
+ )
def forward(self, inputs, targets=None):
others = targets[-4:]
@@ -258,8 +293,8 @@ def forward(self, inputs, targets=None):
pvam_feature = self.pvam(inputs, encoder_word_pos, gsrm_word_pos)
gsrm_feature, word_out, gsrm_out = self.gsrm(
- pvam_feature, gsrm_word_pos, gsrm_slf_attn_bias1,
- gsrm_slf_attn_bias2)
+ pvam_feature, gsrm_word_pos, gsrm_slf_attn_bias1, gsrm_slf_attn_bias2
+ )
final_out = self.vsfd(pvam_feature, gsrm_feature)
if not self.training:
@@ -267,12 +302,14 @@ def forward(self, inputs, targets=None):
_, decoded_out = paddle.topk(final_out, k=1)
- predicts = OrderedDict([
- ('predict', final_out),
- ('pvam_feature', pvam_feature),
- ('decoded_out', decoded_out),
- ('word_out', word_out),
- ('gsrm_out', gsrm_out),
- ])
+ predicts = OrderedDict(
+ [
+ ("predict", final_out),
+ ("pvam_feature", pvam_feature),
+ ("decoded_out", decoded_out),
+ ("word_out", word_out),
+ ("gsrm_out", gsrm_out),
+ ]
+ )
return predicts
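
VSFD's forward pass is a learned per-element gate: g = sigmoid(fc0([pvam; gsrm])) and output = g * pvam + (1 - g) * gsrm, blending visual and semantic features. A toy NumPy version, with a random matrix standing in for the fc0 projection (illustrative only):

import numpy as np

def vsfd_fuse(pvam, gsrm, w0):
    # pvam, gsrm: (t, c); w0: (2c, c) stand-in for the fc0 weights above
    g = 1.0 / (1.0 + np.exp(-(np.concatenate([pvam, gsrm], axis=-1) @ w0)))
    return g * pvam + (1.0 - g) * gsrm  # gate blends visual vs. semantic

t, c = 4, 8
rng = np.random.default_rng(2)
fused = vsfd_fuse(rng.normal(size=(t, c)), rng.normal(size=(t, c)), rng.normal(size=(2 * c, c)))
print(fused.shape)  # (4, 8)
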
diff --git a/ppocr/modeling/heads/rec_visionlan_head.py b/ppocr/modeling/heads/rec_visionlan_head.py
index 86054d9bbb..21b401721d 100644
--- a/ppocr/modeling/heads/rec_visionlan_head.py
+++ b/ppocr/modeling/heads/rec_visionlan_head.py
@@ -32,10 +32,11 @@ class PositionalEncoding(nn.Layer):
def __init__(self, d_hid, n_position=200):
super(PositionalEncoding, self).__init__()
self.register_buffer(
- 'pos_table', self._get_sinusoid_encoding_table(n_position, d_hid))
+ "pos_table", self._get_sinusoid_encoding_table(n_position, d_hid)
+ )
def _get_sinusoid_encoding_table(self, n_position, d_hid):
- ''' Sinusoid position encoding table '''
+ """Sinusoid position encoding table"""
def get_position_angle_vec(position):
return [
@@ -44,15 +45,16 @@ def get_position_angle_vec(position):
]
sinusoid_table = np.array(
- [get_position_angle_vec(pos_i) for pos_i in range(n_position)])
+ [get_position_angle_vec(pos_i) for pos_i in range(n_position)]
+ )
sinusoid_table[:, 0::2] = np.sin(sinusoid_table[:, 0::2]) # dim 2i
sinusoid_table[:, 1::2] = np.cos(sinusoid_table[:, 1::2]) # dim 2i+1
- sinusoid_table = paddle.to_tensor(sinusoid_table, dtype='float32')
+ sinusoid_table = paddle.to_tensor(sinusoid_table, dtype="float32")
sinusoid_table = paddle.unsqueeze(sinusoid_table, axis=0)
return sinusoid_table
def forward(self, x):
- return x + self.pos_table[:, :x.shape[1]].clone().detach()
+ return x + self.pos_table[:, : x.shape[1]].clone().detach()
class ScaledDotProductAttention(nn.Layer):
@@ -76,7 +78,8 @@ def forward(self, q, k, v, mask=None):
mask = paddle.unsqueeze(mask, axis=1)
mask = paddle.unsqueeze(mask, axis=1)
repeat_times = [
- attn.shape[1] // mask.shape[1], attn.shape[2] // mask.shape[2]
+ attn.shape[1] // mask.shape[1],
+ attn.shape[2] // mask.shape[2],
]
mask = paddle.tile(mask, [1, repeat_times[0], repeat_times[1], 1])
attn[mask == 0] = -1e9
@@ -87,7 +90,7 @@ def forward(self, q, k, v, mask=None):
class MultiHeadAttention(nn.Layer):
- " Multi-Head Attention module"
+ "Multi-Head Attention module"
def __init__(self, n_head, d_model, d_k, d_v, dropout=0.1):
super(MultiHeadAttention, self).__init__()
@@ -97,26 +100,30 @@ def __init__(self, n_head, d_model, d_k, d_v, dropout=0.1):
self.w_qs = nn.Linear(
d_model,
n_head * d_k,
- weight_attr=ParamAttr(initializer=Normal(
- mean=0, std=np.sqrt(2.0 / (d_model + d_k)))))
+ weight_attr=ParamAttr(
+ initializer=Normal(mean=0, std=np.sqrt(2.0 / (d_model + d_k)))
+ ),
+ )
self.w_ks = nn.Linear(
d_model,
n_head * d_k,
- weight_attr=ParamAttr(initializer=Normal(
- mean=0, std=np.sqrt(2.0 / (d_model + d_k)))))
+ weight_attr=ParamAttr(
+ initializer=Normal(mean=0, std=np.sqrt(2.0 / (d_model + d_k)))
+ ),
+ )
self.w_vs = nn.Linear(
d_model,
n_head * d_v,
- weight_attr=ParamAttr(initializer=Normal(
- mean=0, std=np.sqrt(2.0 / (d_model + d_v)))))
+ weight_attr=ParamAttr(
+ initializer=Normal(mean=0, std=np.sqrt(2.0 / (d_model + d_v)))
+ ),
+ )
- self.attention = ScaledDotProductAttention(temperature=np.power(d_k,
- 0.5))
+ self.attention = ScaledDotProductAttention(temperature=np.power(d_k, 0.5))
self.layer_norm = nn.LayerNorm(d_model)
self.fc = nn.Linear(
- n_head * d_v,
- d_model,
- weight_attr=ParamAttr(initializer=XavierNormal()))
+ n_head * d_v, d_model, weight_attr=ParamAttr(initializer=XavierNormal())
+ )
self.dropout = nn.Dropout(dropout)
def forward(self, q, k, v, mask=None):
@@ -127,8 +134,7 @@ def forward(self, q, k, v, mask=None):
residual = q
q = self.w_qs(q)
- q = paddle.reshape(
- q, shape=[-1, len_q, n_head, d_k]) # 4*21*512 ---- 4*21*8*64
+ q = paddle.reshape(q, shape=[-1, len_q, n_head, d_k]) # 4*21*512 ---- 4*21*8*64
k = self.w_ks(k)
k = paddle.reshape(k, shape=[-1, len_k, n_head, d_k])
v = self.w_vs(v)
@@ -141,14 +147,15 @@ def forward(self, q, k, v, mask=None):
v = paddle.transpose(v, perm=[2, 0, 1, 3])
v = paddle.reshape(v, shape=[-1, len_v, d_v]) # (n*b) x lv x dv
- mask = paddle.tile(
- mask,
- [n_head, 1, 1]) if mask is not None else None # (n*b) x .. x ..
+ mask = (
+ paddle.tile(mask, [n_head, 1, 1]) if mask is not None else None
+ ) # (n*b) x .. x ..
output = self.attention(q, k, v, mask=mask)
output = paddle.reshape(output, shape=[n_head, -1, len_q, d_v])
output = paddle.transpose(output, perm=[1, 2, 0, 3])
output = paddle.reshape(
- output, shape=[-1, len_q, n_head * d_v]) # b x lq x (n*dv)
+ output, shape=[-1, len_q, n_head * d_v]
+ ) # b x lq x (n*dv)
output = self.dropout(self.fc(output))
output = self.layer_norm(output + residual)
return output
@@ -173,47 +180,45 @@ def forward(self, x):
class EncoderLayer(nn.Layer):
- ''' Compose with two layers '''
+ """Compose with two layers"""
def __init__(self, d_model, d_inner, n_head, d_k, d_v, dropout=0.1):
super(EncoderLayer, self).__init__()
- self.slf_attn = MultiHeadAttention(
- n_head, d_model, d_k, d_v, dropout=dropout)
- self.pos_ffn = PositionwiseFeedForward(
- d_model, d_inner, dropout=dropout)
+ self.slf_attn = MultiHeadAttention(n_head, d_model, d_k, d_v, dropout=dropout)
+ self.pos_ffn = PositionwiseFeedForward(d_model, d_inner, dropout=dropout)
def forward(self, enc_input, slf_attn_mask=None):
- enc_output = self.slf_attn(
- enc_input, enc_input, enc_input, mask=slf_attn_mask)
+ enc_output = self.slf_attn(enc_input, enc_input, enc_input, mask=slf_attn_mask)
enc_output = self.pos_ffn(enc_output)
return enc_output
class Transformer_Encoder(nn.Layer):
- def __init__(self,
- n_layers=2,
- n_head=8,
- d_word_vec=512,
- d_k=64,
- d_v=64,
- d_model=512,
- d_inner=2048,
- dropout=0.1,
- n_position=256):
+ def __init__(
+ self,
+ n_layers=2,
+ n_head=8,
+ d_word_vec=512,
+ d_k=64,
+ d_v=64,
+ d_model=512,
+ d_inner=2048,
+ dropout=0.1,
+ n_position=256,
+ ):
super(Transformer_Encoder, self).__init__()
- self.position_enc = PositionalEncoding(
- d_word_vec, n_position=n_position)
+ self.position_enc = PositionalEncoding(d_word_vec, n_position=n_position)
self.dropout = nn.Dropout(p=dropout)
- self.layer_stack = nn.LayerList([
- EncoderLayer(
- d_model, d_inner, n_head, d_k, d_v, dropout=dropout)
- for _ in range(n_layers)
- ])
+ self.layer_stack = nn.LayerList(
+ [
+ EncoderLayer(d_model, d_inner, n_head, d_k, d_v, dropout=dropout)
+ for _ in range(n_layers)
+ ]
+ )
self.layer_norm = nn.LayerNorm(d_model, epsilon=1e-6)
def forward(self, enc_output, src_mask, return_attns=False):
- enc_output = self.dropout(
- self.position_enc(enc_output)) # position embeding
+ enc_output = self.dropout(self.position_enc(enc_output)) # position embedding
for enc_layer in self.layer_stack:
enc_output = enc_layer(enc_output, slf_attn_mask=src_mask)
enc_output = self.layer_norm(enc_output)
@@ -222,7 +227,6 @@ def forward(self, enc_output, src_mask, return_attns=False):
class PP_layer(nn.Layer):
def __init__(self, n_dim=512, N_max_character=25, n_position=256):
-
super(PP_layer, self).__init__()
self.character_len = N_max_character
self.f0_embedding = nn.Embedding(N_max_character, n_dim)
@@ -234,17 +238,18 @@ def __init__(self, n_dim=512, N_max_character=25, n_position=256):
def forward(self, enc_output):
# enc_output: b,256,512
- reading_order = paddle.arange(self.character_len, dtype='int64')
+ reading_order = paddle.arange(self.character_len, dtype="int64")
reading_order = reading_order.unsqueeze(0).expand(
- [enc_output.shape[0], self.character_len]) # (S,) -> (B, S)
+ [enc_output.shape[0], self.character_len]
+ ) # (S,) -> (B, S)
reading_order = self.f0_embedding(reading_order) # b,25,512
# calculate attention
reading_order = paddle.transpose(reading_order, perm=[0, 2, 1])
t = self.w0(reading_order) # b,512,256
t = self.active(
- paddle.transpose(
- t, perm=[0, 2, 1]) + self.wv(enc_output)) # b,256,512
+ paddle.transpose(t, perm=[0, 2, 1]) + self.wv(enc_output)
+ ) # b,256,512
t = self.we(t) # b,256,25
t = self.softmax(paddle.transpose(t, perm=[0, 2, 1])) # b,25,256
g_output = paddle.bmm(t, enc_output) # b,25,512
@@ -252,22 +257,19 @@ def forward(self, enc_output):
class Prediction(nn.Layer):
- def __init__(self,
- n_dim=512,
- n_position=256,
- N_max_character=25,
- n_class=37):
+ def __init__(self, n_dim=512, n_position=256, N_max_character=25, n_class=37):
super(Prediction, self).__init__()
self.pp = PP_layer(
- n_dim=n_dim, N_max_character=N_max_character, n_position=n_position)
+ n_dim=n_dim, N_max_character=N_max_character, n_position=n_position
+ )
self.pp_share = PP_layer(
- n_dim=n_dim, N_max_character=N_max_character, n_position=n_position)
+ n_dim=n_dim, N_max_character=N_max_character, n_position=n_position
+ )
self.w_vrm = nn.Linear(n_dim, n_class) # output layer
self.w_share = nn.Linear(n_dim, n_class) # output layer
self.nclass = n_class
- def forward(self, cnn_feature, f_res, f_sub, train_mode=False,
- use_mlm=True):
+ def forward(self, cnn_feature, f_res, f_sub, train_mode=False, use_mlm=True):
if train_mode:
if not use_mlm:
g_output = self.pp(cnn_feature) # b,25,512
@@ -294,9 +296,11 @@ class MLM(nn.Layer):
def __init__(self, n_dim=512, n_position=256, max_text_length=25):
super(MLM, self).__init__()
self.MLM_SequenceModeling_mask = Transformer_Encoder(
- n_layers=2, n_position=n_position)
+ n_layers=2, n_position=n_position
+ )
self.MLM_SequenceModeling_WCL = Transformer_Encoder(
- n_layers=1, n_position=n_position)
+ n_layers=1, n_position=n_position
+ )
self.pos_embedding = nn.Embedding(max_text_length, n_dim)
self.w0_linear = nn.Linear(1, n_position)
self.wv = nn.Linear(n_dim, n_dim)
@@ -308,7 +312,7 @@ def forward(self, x, label_pos):
# transformer unit for generating mask_c
feature_v_seq = self.MLM_SequenceModeling_mask(x, src_mask=None)
# position embedding layer
- label_pos = paddle.to_tensor(label_pos, dtype='int64')
+ label_pos = paddle.to_tensor(label_pos, dtype="int64")
pos_emb = self.pos_embedding(label_pos)
pos_emb = self.w0_linear(paddle.unsqueeze(pos_emb, axis=2))
pos_emb = paddle.transpose(pos_emb, perm=[0, 2, 1])
@@ -351,24 +355,23 @@ class MLM_VRM(nn.Layer):
mask_c_show: visualization of Mask_c
"""
- def __init__(self,
- n_layers=3,
- n_position=256,
- n_dim=512,
- max_text_length=25,
- nclass=37):
+ def __init__(
+ self, n_layers=3, n_position=256, n_dim=512, max_text_length=25, nclass=37
+ ):
super(MLM_VRM, self).__init__()
- self.MLM = MLM(n_dim=n_dim,
- n_position=n_position,
- max_text_length=max_text_length)
+ self.MLM = MLM(
+ n_dim=n_dim, n_position=n_position, max_text_length=max_text_length
+ )
self.SequenceModeling = Transformer_Encoder(
- n_layers=n_layers, n_position=n_position)
+ n_layers=n_layers, n_position=n_position
+ )
self.Prediction = Prediction(
n_dim=n_dim,
n_position=n_position,
- N_max_character=max_text_length +
- 1, # N_max_character = 1 eos + 25 characters
- n_class=nclass)
+ N_max_character=max_text_length
+ + 1, # N_max_character = 1 eos + 25 characters
+ n_class=nclass,
+ )
self.nclass = nclass
self.max_text_length = max_text_length
@@ -379,22 +382,24 @@ def forward(self, x, label_pos, training_step, train_mode=False):
x = paddle.reshape(x, [-1, c, h * w])
x = paddle.transpose(x, perm=[0, 2, 1])
if train_mode:
- if training_step == 'LF_1':
+ if training_step == "LF_1":
f_res = 0
f_sub = 0
x = self.SequenceModeling(x, src_mask=None)
text_pre, test_rem, text_mas = self.Prediction(
- x, f_res, f_sub, train_mode=True, use_mlm=False)
+ x, f_res, f_sub, train_mode=True, use_mlm=False
+ )
return text_pre, text_pre, text_pre, text_pre
- elif training_step == 'LF_2':
+ elif training_step == "LF_2":
# MLM
f_res, f_sub, mask_c = self.MLM(x, label_pos)
x = self.SequenceModeling(x, src_mask=None)
text_pre, test_rem, text_mas = self.Prediction(
- x, f_res, f_sub, train_mode=True)
+ x, f_res, f_sub, train_mode=True
+ )
mask_c_show = trans_1d_2d(mask_c)
return text_pre, test_rem, text_mas, mask_c_show
- elif training_step == 'LA':
+ elif training_step == "LA":
# MLM
f_res, f_sub, mask_c = self.MLM(x, label_pos)
## use the mask_c (1 for occluded character and 0 for remaining characters) to occlude input
@@ -413,7 +418,8 @@ def forward(self, x, label_pos, training_step, train_mode=False):
x = self.SequenceModeling(x, src_mask=None)
## prediction layer for MLM and VSR
text_pre, test_rem, text_mas = self.Prediction(
- x, f_res, f_sub, train_mode=True)
+ x, f_res, f_sub, train_mode=True
+ )
mask_c_show = trans_1d_2d(mask_c)
return text_pre, test_rem, text_mas, mask_c_show
else:
@@ -423,13 +429,9 @@ def forward(self, x, label_pos, training_step, train_mode=False):
f_sub = 0
contextual_feature = self.SequenceModeling(x, src_mask=None)
text_pre = self.Prediction(
- contextual_feature,
- f_res,
- f_sub,
- train_mode=False,
- use_mlm=False)
- text_pre = paddle.transpose(
- text_pre, perm=[1, 0, 2]) # (26, b, 37))
+ contextual_feature, f_res, f_sub, train_mode=False, use_mlm=False
+ )
+ text_pre = paddle.transpose(text_pre, perm=[1, 0, 2]) # (26, b, 37)
return text_pre, x
@@ -438,31 +440,35 @@ class VLHead(nn.Layer):
Architecture of VisionLAN
"""
- def __init__(self,
- in_channels,
- out_channels=36,
- n_layers=3,
- n_position=256,
- n_dim=512,
- max_text_length=25,
- training_step='LA'):
+ def __init__(
+ self,
+ in_channels,
+ out_channels=36,
+ n_layers=3,
+ n_position=256,
+ n_dim=512,
+ max_text_length=25,
+ training_step="LA",
+ ):
super(VLHead, self).__init__()
self.MLM_VRM = MLM_VRM(
n_layers=n_layers,
n_position=n_position,
n_dim=n_dim,
max_text_length=max_text_length,
- nclass=out_channels + 1)
+ nclass=out_channels + 1,
+ )
self.training_step = training_step
def forward(self, feat, targets=None):
-
if self.training:
label_pos = targets[-2]
text_pre, test_rem, text_mas, mask_map = self.MLM_VRM(
- feat, label_pos, self.training_step, train_mode=True)
+ feat, label_pos, self.training_step, train_mode=True
+ )
return text_pre, test_rem, text_mas, mask_map
else:
text_pre, x = self.MLM_VRM(
- feat, targets, self.training_step, train_mode=False)
+ feat, targets, self.training_step, train_mode=False
+ )
return text_pre, x
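
PP_layer reads characters out of the visual sequence: it forms one attention map per character slot over the n_position visual tokens and takes a weighted sum for each slot. A deliberately simplified NumPy sketch of that reading step, with plain dot-product queries standing in for the w0/wv/we projections:

import numpy as np

def softmax(x):
    e = np.exp(x - x.max(axis=-1, keepdims=True))
    return e / e.sum(axis=-1, keepdims=True)

def character_reading(enc_output, slot_queries):
    # enc_output: (n_position, d) visual tokens; slot_queries: (max_char, d)
    attn = softmax(slot_queries @ enc_output.T)  # (max_char, n_position)
    return attn @ enc_output                     # one d-dim vector per slot

rng = np.random.default_rng(3)
chars = character_reading(rng.normal(size=(256, 512)), rng.normal(size=(25, 512)))
print(chars.shape)  # (25, 512)
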
diff --git a/ppocr/modeling/heads/self_attention.py b/ppocr/modeling/heads/self_attention.py
index 85417dd913..59aff21b03 100644
--- a/ppocr/modeling/heads/self_attention.py
+++ b/ppocr/modeling/heads/self_attention.py
@@ -23,26 +23,29 @@
from paddle import nn, ParamAttr
from paddle.nn import functional as F
import numpy as np
+
gradient_clip = 10
class WrapEncoderForFeature(nn.Layer):
- def __init__(self,
- src_vocab_size,
- max_length,
- n_layer,
- n_head,
- d_key,
- d_value,
- d_model,
- d_inner_hid,
- prepostprocess_dropout,
- attention_dropout,
- relu_dropout,
- preprocess_cmd,
- postprocess_cmd,
- weight_sharing,
- bos_idx=0):
+ def __init__(
+ self,
+ src_vocab_size,
+ max_length,
+ n_layer,
+ n_head,
+ d_key,
+ d_value,
+ d_model,
+ d_inner_hid,
+ prepostprocess_dropout,
+ attention_dropout,
+ relu_dropout,
+ preprocess_cmd,
+ postprocess_cmd,
+ weight_sharing,
+ bos_idx=0,
+ ):
super(WrapEncoderForFeature, self).__init__()
self.prepare_encoder = PrepareEncoder(
@@ -51,11 +54,21 @@ def __init__(self,
max_length,
prepostprocess_dropout,
bos_idx=bos_idx,
- word_emb_param_name="src_word_emb_table")
- self.encoder = Encoder(n_layer, n_head, d_key, d_value, d_model,
- d_inner_hid, prepostprocess_dropout,
- attention_dropout, relu_dropout, preprocess_cmd,
- postprocess_cmd)
+ word_emb_param_name="src_word_emb_table",
+ )
+ self.encoder = Encoder(
+ n_layer,
+ n_head,
+ d_key,
+ d_value,
+ d_model,
+ d_inner_hid,
+ prepostprocess_dropout,
+ attention_dropout,
+ relu_dropout,
+ preprocess_cmd,
+ postprocess_cmd,
+ )
def forward(self, enc_inputs):
conv_features, src_pos, src_slf_attn_bias = enc_inputs
@@ -69,34 +82,42 @@ class WrapEncoder(nn.Layer):
embedder + encoder
"""
- def __init__(self,
- src_vocab_size,
- max_length,
- n_layer,
- n_head,
- d_key,
- d_value,
- d_model,
- d_inner_hid,
- prepostprocess_dropout,
- attention_dropout,
- relu_dropout,
- preprocess_cmd,
- postprocess_cmd,
- weight_sharing,
- bos_idx=0):
+ def __init__(
+ self,
+ src_vocab_size,
+ max_length,
+ n_layer,
+ n_head,
+ d_key,
+ d_value,
+ d_model,
+ d_inner_hid,
+ prepostprocess_dropout,
+ attention_dropout,
+ relu_dropout,
+ preprocess_cmd,
+ postprocess_cmd,
+ weight_sharing,
+ bos_idx=0,
+ ):
super(WrapEncoder, self).__init__()
self.prepare_decoder = PrepareDecoder(
- src_vocab_size,
+ src_vocab_size, d_model, max_length, prepostprocess_dropout, bos_idx=bos_idx
+ )
+ self.encoder = Encoder(
+ n_layer,
+ n_head,
+ d_key,
+ d_value,
d_model,
- max_length,
+ d_inner_hid,
prepostprocess_dropout,
- bos_idx=bos_idx)
- self.encoder = Encoder(n_layer, n_head, d_key, d_value, d_model,
- d_inner_hid, prepostprocess_dropout,
- attention_dropout, relu_dropout, preprocess_cmd,
- postprocess_cmd)
+ attention_dropout,
+ relu_dropout,
+ preprocess_cmd,
+ postprocess_cmd,
+ )
def forward(self, enc_inputs):
src_word, src_pos, src_slf_attn_bias = enc_inputs
@@ -110,19 +131,20 @@ class Encoder(nn.Layer):
encoder
"""
- def __init__(self,
- n_layer,
- n_head,
- d_key,
- d_value,
- d_model,
- d_inner_hid,
- prepostprocess_dropout,
- attention_dropout,
- relu_dropout,
- preprocess_cmd="n",
- postprocess_cmd="da"):
-
+ def __init__(
+ self,
+ n_layer,
+ n_head,
+ d_key,
+ d_value,
+ d_model,
+ d_inner_hid,
+ prepostprocess_dropout,
+ attention_dropout,
+ relu_dropout,
+ preprocess_cmd="n",
+ postprocess_cmd="da",
+ ):
super(Encoder, self).__init__()
self.encoder_layers = list()
@@ -130,12 +152,23 @@ def __init__(self,
self.encoder_layers.append(
self.add_sublayer(
"layer_%d" % i,
- EncoderLayer(n_head, d_key, d_value, d_model, d_inner_hid,
- prepostprocess_dropout, attention_dropout,
- relu_dropout, preprocess_cmd,
- postprocess_cmd)))
- self.processer = PrePostProcessLayer(preprocess_cmd, d_model,
- prepostprocess_dropout)
+ EncoderLayer(
+ n_head,
+ d_key,
+ d_value,
+ d_model,
+ d_inner_hid,
+ prepostprocess_dropout,
+ attention_dropout,
+ relu_dropout,
+ preprocess_cmd,
+ postprocess_cmd,
+ ),
+ )
+ )
+ self.processer = PrePostProcessLayer(
+ preprocess_cmd, d_model, prepostprocess_dropout
+ )
def forward(self, enc_input, attn_bias):
for encoder_layer in self.encoder_layers:
@@ -150,35 +183,42 @@ class EncoderLayer(nn.Layer):
EncoderLayer
"""
- def __init__(self,
- n_head,
- d_key,
- d_value,
- d_model,
- d_inner_hid,
- prepostprocess_dropout,
- attention_dropout,
- relu_dropout,
- preprocess_cmd="n",
- postprocess_cmd="da"):
-
+ def __init__(
+ self,
+ n_head,
+ d_key,
+ d_value,
+ d_model,
+ d_inner_hid,
+ prepostprocess_dropout,
+ attention_dropout,
+ relu_dropout,
+ preprocess_cmd="n",
+ postprocess_cmd="da",
+ ):
super(EncoderLayer, self).__init__()
- self.preprocesser1 = PrePostProcessLayer(preprocess_cmd, d_model,
- prepostprocess_dropout)
- self.self_attn = MultiHeadAttention(d_key, d_value, d_model, n_head,
- attention_dropout)
- self.postprocesser1 = PrePostProcessLayer(postprocess_cmd, d_model,
- prepostprocess_dropout)
-
- self.preprocesser2 = PrePostProcessLayer(preprocess_cmd, d_model,
- prepostprocess_dropout)
+ self.preprocesser1 = PrePostProcessLayer(
+ preprocess_cmd, d_model, prepostprocess_dropout
+ )
+ self.self_attn = MultiHeadAttention(
+ d_key, d_value, d_model, n_head, attention_dropout
+ )
+ self.postprocesser1 = PrePostProcessLayer(
+ postprocess_cmd, d_model, prepostprocess_dropout
+ )
+
+ self.preprocesser2 = PrePostProcessLayer(
+ preprocess_cmd, d_model, prepostprocess_dropout
+ )
self.ffn = FFN(d_inner_hid, d_model, relu_dropout)
- self.postprocesser2 = PrePostProcessLayer(postprocess_cmd, d_model,
- prepostprocess_dropout)
+ self.postprocesser2 = PrePostProcessLayer(
+ postprocess_cmd, d_model, prepostprocess_dropout
+ )
def forward(self, enc_input, attn_bias):
attn_output = self.self_attn(
- self.preprocesser1(enc_input), None, None, attn_bias)
+ self.preprocesser1(enc_input), None, None, attn_bias
+ )
attn_output = self.postprocesser1(attn_output, enc_input)
ffn_output = self.ffn(self.preprocesser2(attn_output))
ffn_output = self.postprocesser2(ffn_output, attn_output)
@@ -190,7 +230,7 @@ class MultiHeadAttention(nn.Layer):
Multi-Head Attention
"""
- def __init__(self, d_key, d_value, d_model, n_head=1, dropout_rate=0.):
+ def __init__(self, d_key, d_value, d_model, n_head=1, dropout_rate=0.0):
super(MultiHeadAttention, self).__init__()
self.n_head = n_head
self.d_key = d_key
@@ -198,13 +238,17 @@ def __init__(self, d_key, d_value, d_model, n_head=1, dropout_rate=0.):
self.d_model = d_model
self.dropout_rate = dropout_rate
self.q_fc = paddle.nn.Linear(
- in_features=d_model, out_features=d_key * n_head, bias_attr=False)
+ in_features=d_model, out_features=d_key * n_head, bias_attr=False
+ )
self.k_fc = paddle.nn.Linear(
- in_features=d_model, out_features=d_key * n_head, bias_attr=False)
+ in_features=d_model, out_features=d_key * n_head, bias_attr=False
+ )
self.v_fc = paddle.nn.Linear(
- in_features=d_model, out_features=d_value * n_head, bias_attr=False)
+ in_features=d_model, out_features=d_value * n_head, bias_attr=False
+ )
self.proj_fc = paddle.nn.Linear(
- in_features=d_value * n_head, out_features=d_model, bias_attr=False)
+ in_features=d_value * n_head, out_features=d_model, bias_attr=False
+ )
def _prepare_qkv(self, queries, keys, values, cache=None):
if keys is None: # self-attention
@@ -255,8 +299,7 @@ def forward(self, queries, keys, values, attn_bias, cache=None):
product += attn_bias.astype(product.dtype)
weights = F.softmax(product)
if self.dropout_rate:
- weights = F.dropout(
- weights, p=self.dropout_rate, mode="downscale_in_infer")
+ weights = F.dropout(weights, p=self.dropout_rate, mode="downscale_in_infer")
out = paddle.matmul(weights, v)
# combine heads
@@ -288,13 +331,20 @@ def __init__(self, process_cmd, d_model, dropout_rate):
paddle.nn.LayerNorm(
normalized_shape=d_model,
weight_attr=paddle.ParamAttr(
- initializer=paddle.nn.initializer.Constant(1.)),
+ initializer=paddle.nn.initializer.Constant(1.0)
+ ),
bias_attr=paddle.ParamAttr(
- initializer=paddle.nn.initializer.Constant(0.)))))
+ initializer=paddle.nn.initializer.Constant(0.0)
+ ),
+ ),
+ )
+ )
elif cmd == "d": # add dropout
- self.functors.append(lambda x: F.dropout(
- x, p=dropout_rate, mode="downscale_in_infer")
- if dropout_rate else x)
+ self.functors.append(
+ lambda x: F.dropout(x, p=dropout_rate, mode="downscale_in_infer")
+ if dropout_rate
+ else x
+ )
def forward(self, x, residual=None):
for i, cmd in enumerate(self.process_cmd):
@@ -306,46 +356,50 @@ def forward(self, x, residual=None):
class PrepareEncoder(nn.Layer):
- def __init__(self,
- src_vocab_size,
- src_emb_dim,
- src_max_len,
- dropout_rate=0,
- bos_idx=0,
- word_emb_param_name=None,
- pos_enc_param_name=None):
+ def __init__(
+ self,
+ src_vocab_size,
+ src_emb_dim,
+ src_max_len,
+ dropout_rate=0,
+ bos_idx=0,
+ word_emb_param_name=None,
+ pos_enc_param_name=None,
+ ):
super(PrepareEncoder, self).__init__()
self.src_emb_dim = src_emb_dim
self.src_max_len = src_max_len
self.emb = paddle.nn.Embedding(
- num_embeddings=self.src_max_len, embedding_dim=self.src_emb_dim)
+ num_embeddings=self.src_max_len, embedding_dim=self.src_emb_dim
+ )
self.dropout_rate = dropout_rate
def forward(self, src_word, src_pos):
src_word_emb = src_word
- src_word_emb = paddle.cast(src_word_emb, 'float32')
+ src_word_emb = paddle.cast(src_word_emb, "float32")
src_word_emb = paddle.scale(x=src_word_emb, scale=self.src_emb_dim**0.5)
src_pos = paddle.squeeze(src_pos, axis=-1)
src_pos_enc = self.emb(src_pos)
src_pos_enc.stop_gradient = True
enc_input = src_word_emb + src_pos_enc
if self.dropout_rate:
- out = F.dropout(
- x=enc_input, p=self.dropout_rate, mode="downscale_in_infer")
+ out = F.dropout(x=enc_input, p=self.dropout_rate, mode="downscale_in_infer")
else:
out = enc_input
return out
class PrepareDecoder(nn.Layer):
- def __init__(self,
- src_vocab_size,
- src_emb_dim,
- src_max_len,
- dropout_rate=0,
- bos_idx=0,
- word_emb_param_name=None,
- pos_enc_param_name=None):
+ def __init__(
+ self,
+ src_vocab_size,
+ src_emb_dim,
+ src_max_len,
+ dropout_rate=0,
+ bos_idx=0,
+ word_emb_param_name=None,
+ pos_enc_param_name=None,
+ ):
super(PrepareDecoder, self).__init__()
self.src_emb_dim = src_emb_dim
"""
@@ -358,15 +412,18 @@ def __init__(self,
padding_idx=bos_idx,
weight_attr=paddle.ParamAttr(
name=word_emb_param_name,
- initializer=nn.initializer.Normal(0., src_emb_dim**-0.5)))
+ initializer=nn.initializer.Normal(0.0, src_emb_dim**-0.5),
+ ),
+ )
self.emb1 = paddle.nn.Embedding(
num_embeddings=src_max_len,
embedding_dim=self.src_emb_dim,
- weight_attr=paddle.ParamAttr(name=pos_enc_param_name))
+ weight_attr=paddle.ParamAttr(name=pos_enc_param_name),
+ )
self.dropout_rate = dropout_rate
def forward(self, src_word, src_pos):
- src_word = paddle.cast(src_word, 'int64')
+ src_word = paddle.cast(src_word, "int64")
src_word = paddle.squeeze(src_word, axis=-1)
src_word_emb = self.emb0(src_word)
src_word_emb = paddle.scale(x=src_word_emb, scale=self.src_emb_dim**0.5)
@@ -375,8 +432,7 @@ def forward(self, src_word, src_pos):
src_pos_enc.stop_gradient = True
enc_input = src_word_emb + src_pos_enc
if self.dropout_rate:
- out = F.dropout(
- x=enc_input, p=self.dropout_rate, mode="downscale_in_infer")
+ out = F.dropout(x=enc_input, p=self.dropout_rate, mode="downscale_in_infer")
else:
out = enc_input
return out
@@ -390,16 +446,13 @@ class FFN(nn.Layer):
def __init__(self, d_inner_hid, d_model, dropout_rate):
super(FFN, self).__init__()
self.dropout_rate = dropout_rate
- self.fc1 = paddle.nn.Linear(
- in_features=d_model, out_features=d_inner_hid)
- self.fc2 = paddle.nn.Linear(
- in_features=d_inner_hid, out_features=d_model)
+ self.fc1 = paddle.nn.Linear(in_features=d_model, out_features=d_inner_hid)
+ self.fc2 = paddle.nn.Linear(in_features=d_inner_hid, out_features=d_model)
def forward(self, x):
hidden = self.fc1(x)
hidden = F.relu(hidden)
if self.dropout_rate:
- hidden = F.dropout(
- hidden, p=self.dropout_rate, mode="downscale_in_infer")
+ hidden = F.dropout(hidden, p=self.dropout_rate, mode="downscale_in_infer")
out = self.fc2(hidden)
return out
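
PrePostProcessLayer dispatches on each character of process_cmd: "a" adds the residual, "n" applies LayerNorm, "d" applies dropout; the SRN heads pass "n" as preprocess_cmd and "da" as postprocess_cmd. A plain-Python sketch of that dispatch (identity dropout for brevity):

import numpy as np

def pre_post_process(x, residual=None, cmd="da", dropout=lambda v: v):
    # mirrors the functor list above: "a" add, "n" layer-norm, "d" dropout
    for c in cmd:
        if c == "a" and residual is not None:
            x = x + residual
        elif c == "n":
            x = (x - x.mean(-1, keepdims=True)) / np.sqrt(x.var(-1, keepdims=True) + 1e-6)
        elif c == "d":
            x = dropout(x)
    return x

x = np.ones((2, 4))
print(pre_post_process(x, residual=x, cmd="da"))  # residual added, no norm
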
diff --git a/ppocr/modeling/heads/sr_rensnet_transformer.py b/ppocr/modeling/heads/sr_rensnet_transformer.py
index 8f2705678e..dcb8bfb77f 100644
--- a/ppocr/modeling/heads/sr_rensnet_transformer.py
+++ b/ppocr/modeling/heads/sr_rensnet_transformer.py
@@ -25,13 +25,13 @@
def subsequent_mask(size):
"""Generate a square mask for the sequence. The masked positions are filled with float('-inf').
- Unmasked positions are filled with float(0.0).
+ Unmasked positions are filled with float(0.0).
"""
- mask = paddle.ones([1, size, size], dtype='float32')
+ mask = paddle.ones([1, size, size], dtype="float32")
mask_inf = paddle.triu(
- paddle.full(
- shape=[1, size, size], dtype='float32', fill_value='-inf'),
- diagonal=1)
+ paddle.full(shape=[1, size, size], dtype="float32", fill_value="-inf"),
+ diagonal=1,
+ )
mask = mask + mask_inf
padding_mask = paddle.equal(mask, paddle.to_tensor(1, dtype=mask.dtype))
return padding_mask
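
The reworked subsequent_mask produces a boolean causal mask that is True on and below the diagonal, so position i can attend only to positions at or before i. An equivalent NumPy one-liner, for illustration:

import numpy as np

def subsequent_mask_np(size):
    # True where attention is allowed: each step sees itself and the past
    return np.tril(np.ones((size, size), dtype=bool))[None, ...]

print(subsequent_mask_np(3)[0].astype(int))
# [[1 0 0]
#  [1 1 0]
#  [1 1 1]]
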
@@ -48,11 +48,10 @@ def masked_fill(x, mask, value):
def attention(query, key, value, mask=None, dropout=None, attention_map=None):
d_k = query.shape[-1]
- scores = paddle.matmul(query,
- paddle.transpose(key, [0, 1, 3, 2])) / math.sqrt(d_k)
+ scores = paddle.matmul(query, paddle.transpose(key, [0, 1, 3, 2])) / math.sqrt(d_k)
if mask is not None:
- scores = masked_fill(scores, mask == 0, float('-inf'))
+ scores = masked_fill(scores, mask == 0, float("-inf"))
else:
pass
@@ -80,9 +79,12 @@ def forward(self, query, key, value, mask=None, attention_map=None):
mask = mask.unsqueeze(1)
nbatches = query.shape[0]
- query, key, value = \
- [paddle.transpose(l(x).reshape([nbatches, -1, self.h, self.d_k]), [0,2,1,3])
- for l, x in zip(self.linears, (query, key, value))]
+ query, key, value = [
+ paddle.transpose(
+ l(x).reshape([nbatches, -1, self.h, self.d_k]), [0, 2, 1, 3]
+ )
+ for l, x in zip(self.linears, (query, key, value))
+ ]
x, attention_map = attention(
query,
@@ -90,11 +92,12 @@ def forward(self, query, key, value, mask=None, attention_map=None):
value,
mask=mask,
dropout=self.dropout,
- attention_map=attention_map)
+ attention_map=attention_map,
+ )
x = paddle.reshape(
- paddle.transpose(x, [0, 2, 1, 3]),
- [nbatches, -1, self.h * self.d_k])
+ paddle.transpose(x, [0, 2, 1, 3]), [nbatches, -1, self.h * self.d_k]
+ )
return self.linears[-1](x), attention_map
@@ -137,12 +140,11 @@ def __init__(self, num_in, block, layers):
self.layer4_conv2_relu = nn.ReLU()
def _make_layer(self, block, inplanes, planes, blocks):
-
if inplanes != planes:
downsample = nn.Sequential(
nn.Conv2D(inplanes, planes, 3, 1, 1),
- nn.BatchNorm2D(
- planes, use_global_stats=True), )
+ nn.BatchNorm2D(planes, use_global_stats=True),
+ )
else:
downsample = None
layers = []
@@ -222,15 +224,15 @@ def __init__(self, dropout, dim, max_len=5000):
pe = paddle.zeros([max_len, dim])
position = paddle.arange(0, max_len, dtype=paddle.float32).unsqueeze(1)
div_term = paddle.exp(
- paddle.arange(0, dim, 2).astype('float32') *
- (-math.log(10000.0) / dim))
+ paddle.arange(0, dim, 2).astype("float32") * (-math.log(10000.0) / dim)
+ )
pe[:, 0::2] = paddle.sin(position * div_term)
pe[:, 1::2] = paddle.cos(position * div_term)
pe = paddle.unsqueeze(pe, 0)
- self.register_buffer('pe', pe)
+ self.register_buffer("pe", pe)
def forward(self, x):
- x = x + self.pe[:, :x.shape[1]]
+ x = x + self.pe[:, : x.shape[1]]
return self.dropout(x)
@@ -277,11 +279,11 @@ class LayerNorm(nn.Layer):
def __init__(self, features, eps=1e-6):
super(LayerNorm, self).__init__()
self.a_2 = self.create_parameter(
- shape=[features],
- default_initializer=paddle.nn.initializer.Constant(1.0))
+ shape=[features], default_initializer=paddle.nn.initializer.Constant(1.0)
+ )
self.b_2 = self.create_parameter(
- shape=[features],
- default_initializer=paddle.nn.initializer.Constant(0.0))
+ shape=[features], default_initializer=paddle.nn.initializer.Constant(0.0)
+ )
self.eps = eps
def forward(self, x):
@@ -294,8 +296,7 @@ class Decoder(nn.Layer):
def __init__(self):
super(Decoder, self).__init__()
- self.mask_multihead = MultiHeadedAttention(
- h=16, d_model=1024, dropout=0.1)
+ self.mask_multihead = MultiHeadedAttention(h=16, d_model=1024, dropout=0.1)
self.mul_layernorm1 = LayerNorm(1024)
self.multihead = MultiHeadedAttention(h=16, d_model=1024, dropout=0.1)
@@ -308,17 +309,14 @@ def forward(self, text, conv_feature, attention_map=None):
text_max_length = text.shape[1]
mask = subsequent_mask(text_max_length)
result = text
- result = self.mul_layernorm1(result + self.mask_multihead(
- text, text, text, mask=mask)[0])
+ result = self.mul_layernorm1(
+ result + self.mask_multihead(text, text, text, mask=mask)[0]
+ )
b, c, h, w = conv_feature.shape
- conv_feature = paddle.transpose(
- conv_feature.reshape([b, c, h * w]), [0, 2, 1])
+ conv_feature = paddle.transpose(conv_feature.reshape([b, c, h * w]), [0, 2, 1])
word_image_align, attention_map = self.multihead(
- result,
- conv_feature,
- conv_feature,
- mask=None,
- attention_map=attention_map)
+ result, conv_feature, conv_feature, mask=None, attention_map=attention_map
+ )
result = self.mul_layernorm2(result + word_image_align)
result = self.mul_layernorm3(result + self.pff(result))
@@ -328,12 +326,10 @@ def forward(self, text, conv_feature, attention_map=None):
class BasicBlock(nn.Layer):
def __init__(self, inplanes, planes, downsample):
super(BasicBlock, self).__init__()
- self.conv1 = nn.Conv2D(
- inplanes, planes, kernel_size=3, stride=1, padding=1)
+ self.conv1 = nn.Conv2D(inplanes, planes, kernel_size=3, stride=1, padding=1)
self.bn1 = nn.BatchNorm2D(planes, use_global_stats=True)
self.relu = nn.ReLU()
- self.conv2 = nn.Conv2D(
- planes, planes, kernel_size=3, stride=1, padding=1)
+ self.conv2 = nn.Conv2D(planes, planes, kernel_size=3, stride=1, padding=1)
self.bn2 = nn.BatchNorm2D(planes, use_global_stats=True)
self.downsample = downsample
@@ -367,7 +363,7 @@ def forward(self, input):
class Transformer(nn.Layer):
- def __init__(self, in_channels=1, alphabet='0123456789'):
+ def __init__(self, in_channels=1, alphabet="0123456789"):
super(Transformer, self).__init__()
self.alphabet = alphabet
word_n_class = self.get_alphabet_len()
@@ -397,18 +393,21 @@ def forward(self, image, text_length, text_input, attention_map=None):
text_input = text_input[:, :max_length]
text_embedding = self.embedding_word_with_upperword(
- text_input) # batch, text_max_length, 512
+ text_input
+ ) # batch, text_max_length, 512
postion_embedding = self.pe(
- paddle.zeros(text_embedding.shape)) # batch, text_max_length, 512
- text_input_with_pe = paddle.concat([text_embedding, postion_embedding],
- 2) # batch, text_max_length, 1024
+ paddle.zeros(text_embedding.shape)
+ ) # batch, text_max_length, 512
+ text_input_with_pe = paddle.concat(
+ [text_embedding, postion_embedding], 2
+ ) # batch, text_max_length, 1024
batch, seq_len, _ = text_input_with_pe.shape
text_input_with_pe, word_attention_map = self.decoder(
- text_input_with_pe, conv_feature)
+ text_input_with_pe, conv_feature
+ )
- word_decoder_result = self.generator_word_with_upperword(
- text_input_with_pe)
+ word_decoder_result = self.generator_word_with_upperword(text_input_with_pe)
if self.training:
total_length = paddle.sum(text_length)
@@ -417,8 +416,9 @@ def forward(self, image, text_length, text_input, attention_map=None):
for index, length in enumerate(text_length):
length = int(length.numpy())
- probs_res[start:start + length, :] = word_decoder_result[
- index, 0:0 + length, :]
+ probs_res[start : start + length, :] = word_decoder_result[
+ index, 0 : 0 + length, :
+ ]
start = start + length
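
subsequent_mask builds the decoder's causal mask by adding -inf above the main diagonal and then testing equality with 1, so attention() can later zero out future positions via masked_fill. A minimal standalone sketch of the same construction, assuming paddle 2.x:

# Sketch: the causal mask produced by subsequent_mask for size=3.
import paddle

size = 3
ones = paddle.ones([1, size, size], dtype="float32")
upper_inf = paddle.triu(
    paddle.full(shape=[1, size, size], dtype="float32", fill_value="-inf"),
    diagonal=1,
)
# 1 + (-inf) = -inf above the diagonal; 1 elsewhere, so equality with 1
# marks the positions a step may attend to.
allowed = paddle.equal(ones + upper_inf, paddle.to_tensor(1, dtype="float32"))
print(allowed.numpy())
# [[[ True False False]
#   [ True  True False]
#   [ True  True  True]]]
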
diff --git a/ppocr/modeling/heads/table_att_head.py b/ppocr/modeling/heads/table_att_head.py
index e3fc8436e7..d204d5a433 100644
--- a/ppocr/modeling/heads/table_att_head.py
+++ b/ppocr/modeling/heads/table_att_head.py
@@ -40,14 +40,16 @@ def get_para_bias_attr(l2_decay, k):
class TableAttentionHead(nn.Layer):
- def __init__(self,
- in_channels,
- hidden_size,
- in_max_len=488,
- max_text_length=800,
- out_channels=30,
- loc_reg_num=4,
- **kwargs):
+ def __init__(
+ self,
+ in_channels,
+ hidden_size,
+ in_max_len=488,
+ max_text_length=800,
+ out_channels=30,
+ loc_reg_num=4,
+ **kwargs
+ ):
super(TableAttentionHead, self).__init__()
self.input_size = in_channels[-1]
self.hidden_size = hidden_size
@@ -55,7 +57,8 @@ def __init__(self,
self.max_text_length = max_text_length
self.structure_attention_cell = AttentionGRUCell(
- self.input_size, hidden_size, self.out_channels, use_gru=False)
+ self.input_size, hidden_size, self.out_channels, use_gru=False
+ )
self.structure_generator = nn.Linear(hidden_size, self.out_channels)
self.in_max_len = in_max_len
@@ -65,8 +68,7 @@ def __init__(self,
self.loc_fea_trans = nn.Linear(625, self.max_text_length + 1)
else:
self.loc_fea_trans = nn.Linear(256, self.max_text_length + 1)
- self.loc_generator = nn.Linear(self.input_size + hidden_size,
- loc_reg_num)
+ self.loc_generator = nn.Linear(self.input_size + hidden_size, loc_reg_num)
def _char_to_onehot(self, input_char, onehot_dim):
input_ont_hot = F.one_hot(input_char, onehot_dim)
@@ -83,14 +85,17 @@ def forward(self, inputs, targets=None):
hidden = paddle.zeros((batch_size, self.hidden_size))
output_hiddens = paddle.zeros(
- (batch_size, self.max_text_length + 1, self.hidden_size))
+ (batch_size, self.max_text_length + 1, self.hidden_size)
+ )
if self.training and targets is not None:
structure = targets[0]
for i in range(self.max_text_length + 1):
elem_onehots = self._char_to_onehot(
- structure[:, i], onehot_dim=self.out_channels)
+ structure[:, i], onehot_dim=self.out_channels
+ )
(outputs, hidden), alpha = self.structure_attention_cell(
- hidden, fea, elem_onehots)
+ hidden, fea, elem_onehots
+ )
output_hiddens[:, i, :] = outputs
structure_probs = self.structure_generator(output_hiddens)
loc_fea = fea.transpose([0, 2, 1])
@@ -109,9 +114,11 @@ def forward(self, inputs, targets=None):
max_text_length = paddle.to_tensor(self.max_text_length)
for i in range(max_text_length + 1):
elem_onehots = self._char_to_onehot(
- temp_elem, onehot_dim=self.out_channels)
+ temp_elem, onehot_dim=self.out_channels
+ )
(outputs, hidden), alpha = self.structure_attention_cell(
- hidden, fea, elem_onehots)
+ hidden, fea, elem_onehots
+ )
output_hiddens[:, i, :] = outputs
structure_probs_step = self.structure_generator(outputs)
temp_elem = structure_probs_step.argmax(axis=1, dtype="int32")
@@ -124,18 +131,20 @@ def forward(self, inputs, targets=None):
loc_concat = paddle.concat([output_hiddens, loc_fea], axis=2)
loc_preds = self.loc_generator(loc_concat)
loc_preds = F.sigmoid(loc_preds)
- return {'structure_probs': structure_probs, 'loc_preds': loc_preds}
+ return {"structure_probs": structure_probs, "loc_preds": loc_preds}
class SLAHead(nn.Layer):
- def __init__(self,
- in_channels,
- hidden_size,
- out_channels=30,
- max_text_length=500,
- loc_reg_num=4,
- fc_decay=0.0,
- **kwargs):
+ def __init__(
+ self,
+ in_channels,
+ hidden_size,
+ out_channels=30,
+ max_text_length=500,
+ loc_reg_num=4,
+ fc_decay=0.0,
+ **kwargs
+ ):
"""
@param in_channels: input shape
@param hidden_size: hidden_size for RNN and Embedding
@@ -152,41 +161,48 @@ def __init__(self,
# structure
self.structure_attention_cell = AttentionGRUCell(
- in_channels, hidden_size, self.num_embeddings)
- weight_attr, bias_attr = get_para_bias_attr(
- l2_decay=fc_decay, k=hidden_size)
+ in_channels, hidden_size, self.num_embeddings
+ )
+ weight_attr, bias_attr = get_para_bias_attr(l2_decay=fc_decay, k=hidden_size)
weight_attr1_1, bias_attr1_1 = get_para_bias_attr(
- l2_decay=fc_decay, k=hidden_size)
+ l2_decay=fc_decay, k=hidden_size
+ )
weight_attr1_2, bias_attr1_2 = get_para_bias_attr(
- l2_decay=fc_decay, k=hidden_size)
+ l2_decay=fc_decay, k=hidden_size
+ )
self.structure_generator = nn.Sequential(
nn.Linear(
self.hidden_size,
self.hidden_size,
weight_attr=weight_attr1_2,
- bias_attr=bias_attr1_2),
+ bias_attr=bias_attr1_2,
+ ),
nn.Linear(
- hidden_size,
- out_channels,
- weight_attr=weight_attr,
- bias_attr=bias_attr))
+ hidden_size, out_channels, weight_attr=weight_attr, bias_attr=bias_attr
+ ),
+ )
# loc
weight_attr1, bias_attr1 = get_para_bias_attr(
- l2_decay=fc_decay, k=self.hidden_size)
+ l2_decay=fc_decay, k=self.hidden_size
+ )
weight_attr2, bias_attr2 = get_para_bias_attr(
- l2_decay=fc_decay, k=self.hidden_size)
+ l2_decay=fc_decay, k=self.hidden_size
+ )
self.loc_generator = nn.Sequential(
nn.Linear(
self.hidden_size,
self.hidden_size,
weight_attr=weight_attr1,
- bias_attr=bias_attr1),
+ bias_attr=bias_attr1,
+ ),
nn.Linear(
self.hidden_size,
loc_reg_num,
weight_attr=weight_attr2,
- bias_attr=bias_attr2),
- nn.Sigmoid())
+ bias_attr=bias_attr2,
+ ),
+ nn.Sigmoid(),
+ )
def forward(self, inputs, targets=None):
fea = inputs[-1]
@@ -197,16 +213,19 @@ def forward(self, inputs, targets=None):
hidden = paddle.zeros((batch_size, self.hidden_size))
structure_preds = paddle.zeros(
- (batch_size, self.max_text_length + 1, self.num_embeddings))
+ (batch_size, self.max_text_length + 1, self.num_embeddings)
+ )
loc_preds = paddle.zeros(
- (batch_size, self.max_text_length + 1, self.loc_reg_num))
+ (batch_size, self.max_text_length + 1, self.loc_reg_num)
+ )
structure_preds.stop_gradient = True
loc_preds.stop_gradient = True
if self.training and targets is not None:
structure = targets[0]
for i in range(self.max_text_length + 1):
- hidden, structure_step, loc_step = self._decode(structure[:, i],
- fea, hidden)
+ hidden, structure_step, loc_step = self._decode(
+ structure[:, i], fea, hidden
+ )
structure_preds[:, i, :] = structure_step
loc_preds[:, i, :] = loc_step
else:
@@ -215,14 +234,13 @@ def forward(self, inputs, targets=None):
# for export
loc_step, structure_step = None, None
for i in range(max_text_length + 1):
- hidden, structure_step, loc_step = self._decode(pre_chars, fea,
- hidden)
+ hidden, structure_step, loc_step = self._decode(pre_chars, fea, hidden)
pre_chars = structure_step.argmax(axis=1, dtype="int32")
structure_preds[:, i, :] = structure_step
loc_preds[:, i, :] = loc_step
if not self.training:
structure_preds = F.softmax(structure_preds)
- return {'structure_probs': structure_preds, 'loc_preds': loc_preds}
+ return {"structure_probs": structure_preds, "loc_preds": loc_preds}
def _decode(self, pre_chars, features, hidden):
"""
@@ -235,7 +253,8 @@ def _decode(self, pre_chars, features, hidden):
emb_feature = self.emb(pre_chars)
# output shape is b * self.hidden_size
(output, hidden), alpha = self.structure_attention_cell(
- hidden, features, emb_feature)
+ hidden, features, emb_feature
+ )
# structure
structure_step = self.structure_generator(output)
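
At inference time TableAttentionHead feeds each argmax prediction back as the next input to the attention GRU cell, while training substitutes the ground-truth structure tokens (teacher forcing). A hedged sketch of the greedy loop's shape — cell and generator stand in for structure_attention_cell and structure_generator, and the start token 0 is an illustrative assumption:

# Sketch of greedy structure decoding (names and start token are illustrative).
import paddle
import paddle.nn.functional as F

def greedy_decode(cell, generator, fea, hidden, max_len, num_classes):
    # fea: (B, L, C) flattened visual features; hidden: (B, hidden_size)
    pre_chars = paddle.zeros([fea.shape[0]], dtype="int32")  # assumed start token
    steps = []
    for _ in range(max_len + 1):
        onehots = F.one_hot(pre_chars, num_classes)
        (out, hidden), _alpha = cell(hidden, fea, onehots)
        probs = generator(out)                           # (B, num_classes)
        pre_chars = probs.argmax(axis=1, dtype="int32")  # feed prediction back
        steps.append(probs)
    return paddle.stack(steps, axis=1)                   # (B, max_len + 1, num_classes)
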
diff --git a/ppocr/modeling/heads/table_master_head.py b/ppocr/modeling/heads/table_master_head.py
index 8ae7d52c45..2785a2b0ac 100644
--- a/ppocr/modeling/heads/table_master_head.py
+++ b/ppocr/modeling/heads/table_master_head.py
@@ -30,28 +30,28 @@ class TableMasterHead(nn.Layer):
Bbox_layer is used to regress bbox coord.
"""
- def __init__(self,
- in_channels,
- out_channels=30,
- headers=8,
- d_ff=2048,
- dropout=0,
- max_text_length=500,
- loc_reg_num=4,
- **kwargs):
+ def __init__(
+ self,
+ in_channels,
+ out_channels=30,
+ headers=8,
+ d_ff=2048,
+ dropout=0,
+ max_text_length=500,
+ loc_reg_num=4,
+ **kwargs
+ ):
super(TableMasterHead, self).__init__()
hidden_size = in_channels[-1]
- self.layers = clones(
- DecoderLayer(headers, hidden_size, dropout, d_ff), 2)
- self.cls_layer = clones(
- DecoderLayer(headers, hidden_size, dropout, d_ff), 1)
- self.bbox_layer = clones(
- DecoderLayer(headers, hidden_size, dropout, d_ff), 1)
+ self.layers = clones(DecoderLayer(headers, hidden_size, dropout, d_ff), 2)
+ self.cls_layer = clones(DecoderLayer(headers, hidden_size, dropout, d_ff), 1)
+ self.bbox_layer = clones(DecoderLayer(headers, hidden_size, dropout, d_ff), 1)
self.cls_fc = nn.Linear(hidden_size, out_channels)
self.bbox_fc = nn.Sequential(
# nn.Linear(hidden_size, hidden_size),
nn.Linear(hidden_size, loc_reg_num),
- nn.Sigmoid())
+ nn.Sigmoid(),
+ )
self.norm = nn.LayerNorm(hidden_size)
self.embedding = Embeddings(d_model=hidden_size, vocab=out_channels)
self.positional_encoding = PositionalEncoding(d_model=hidden_size)
@@ -73,11 +73,10 @@ def make_mask(self, tgt):
tgt_len = tgt.shape[1]
trg_sub_mask = paddle.tril(
- paddle.ones(
- ([tgt_len, tgt_len]), dtype=paddle.float32))
+ paddle.ones(([tgt_len, tgt_len]), dtype=paddle.float32)
+ )
- tgt_mask = paddle.logical_and(
- trg_pad_mask.astype(paddle.float32), trg_sub_mask)
+ tgt_mask = paddle.logical_and(trg_pad_mask.astype(paddle.float32), trg_sub_mask)
return tgt_mask.astype(paddle.float32)
def decode(self, input, feature, src_mask, tgt_mask):
@@ -105,18 +104,18 @@ def decode(self, input, feature, src_mask, tgt_mask):
def greedy_forward(self, SOS, feature):
input = SOS
output = paddle.zeros(
- [input.shape[0], self.max_text_length + 1, self.out_channels])
+ [input.shape[0], self.max_text_length + 1, self.out_channels]
+ )
bbox_output = paddle.zeros(
- [input.shape[0], self.max_text_length + 1, self.loc_reg_num])
+ [input.shape[0], self.max_text_length + 1, self.loc_reg_num]
+ )
max_text_length = paddle.to_tensor(self.max_text_length)
for i in range(max_text_length + 1):
target_mask = self.make_mask(input)
- out_step, bbox_output_step = self.decode(input, feature, None,
- target_mask)
+ out_step, bbox_output_step = self.decode(input, feature, None, target_mask)
prob = F.softmax(out_step, axis=-1)
next_word = prob.argmax(axis=2, dtype="int64")
- input = paddle.concat(
- [input, next_word[:, -1].unsqueeze(-1)], axis=1)
+ input = paddle.concat([input, next_word[:, -1].unsqueeze(-1)], axis=1)
if i == self.max_text_length:
output = out_step
bbox_output = bbox_output_step
@@ -129,16 +128,17 @@ def forward_train(self, out_enc, targets):
padded_targets = targets[0]
src_mask = None
tgt_mask = self.make_mask(padded_targets[:, :-1])
- output, bbox_output = self.decode(padded_targets[:, :-1], out_enc,
- src_mask, tgt_mask)
- return {'structure_probs': output, 'loc_preds': bbox_output}
+ output, bbox_output = self.decode(
+ padded_targets[:, :-1], out_enc, src_mask, tgt_mask
+ )
+ return {"structure_probs": output, "loc_preds": bbox_output}
def forward_test(self, out_enc):
batch_size = out_enc.shape[0]
- SOS = paddle.zeros([batch_size, 1], dtype='int64') + self.SOS
+ SOS = paddle.zeros([batch_size, 1], dtype="int64") + self.SOS
output, bbox_output = self.greedy_forward(SOS, out_enc)
output = F.softmax(output)
- return {'structure_probs': output, 'loc_preds': bbox_output}
+ return {"structure_probs": output, "loc_preds": bbox_output}
def forward(self, feat, targets=None):
feat = feat[-1]
@@ -166,8 +166,7 @@ def __init__(self, headers, d_model, dropout, d_ff):
def forward(self, x, feature, src_mask, tgt_mask):
x = self.sublayer[0](x, lambda x: self.self_attn(x, x, x, tgt_mask))
- x = self.sublayer[1](
- x, lambda x: self.src_attn(x, feature, feature, src_mask))
+ x = self.sublayer[1](x, lambda x: self.src_attn(x, feature, feature, src_mask))
return self.sublayer[2](x, self.feed_forward)
@@ -186,12 +185,14 @@ def forward(self, query, key, value, mask=None):
B = query.shape[0]
# 1) Do all the linear projections in batch from d_model => h x d_k
- query, key, value = \
- [l(x).reshape([B, 0, self.headers, self.d_k]).transpose([0, 2, 1, 3])
- for l, x in zip(self.linears, (query, key, value))]
+ query, key, value = [
+ l(x).reshape([B, 0, self.headers, self.d_k]).transpose([0, 2, 1, 3])
+ for l, x in zip(self.linears, (query, key, value))
+ ]
# 2) Apply attention on all the projected vectors in batch
x, self.attn = self_attention(
- query, key, value, mask=mask, dropout=self.dropout)
+ query, key, value, mask=mask, dropout=self.dropout
+ )
x = x.transpose([0, 2, 1, 3]).reshape([B, 0, self.headers * self.d_k])
return self.linears[-1](x)
@@ -246,7 +247,7 @@ def self_attention(query, key, value, mask=None, dropout=None):
def clones(module, N):
- """ Produce N identical layers """
+ """Produce N identical layers"""
return nn.LayerList([copy.deepcopy(module) for _ in range(N)])
@@ -262,22 +263,23 @@ def forward(self, *input):
class PositionalEncoding(nn.Layer):
- """ Implement the PE function. """
+ """Implement the PE function."""
- def __init__(self, d_model, dropout=0., max_len=5000):
+ def __init__(self, d_model, dropout=0.0, max_len=5000):
super(PositionalEncoding, self).__init__()
self.dropout = nn.Dropout(p=dropout)
# Compute the positional encodings once in log space.
pe = paddle.zeros([max_len, d_model])
- position = paddle.arange(0, max_len).unsqueeze(1).astype('float32')
+ position = paddle.arange(0, max_len).unsqueeze(1).astype("float32")
div_term = paddle.exp(
- paddle.arange(0, d_model, 2) * -math.log(10000.0) / d_model)
+ paddle.arange(0, d_model, 2) * -math.log(10000.0) / d_model
+ )
pe[:, 0::2] = paddle.sin(position * div_term)
pe[:, 1::2] = paddle.cos(position * div_term)
pe = pe.unsqueeze(0)
- self.register_buffer('pe', pe)
+ self.register_buffer("pe", pe)
def forward(self, feat, **kwargs):
- feat = feat + self.pe[:, :feat.shape[1]] # pe 1*5000*512
+ feat = feat + self.pe[:, : feat.shape[1]] # pe 1*5000*512
return self.dropout(feat)
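
PositionalEncoding precomputes the standard sinusoidal table, pe[pos, 2i] = sin(pos / 10000^(2i/d_model)) and pe[pos, 2i+1] = cos(pos / 10000^(2i/d_model)), registers it once as a buffer, and slices it to the sequence length in forward. A numpy-only sketch of the same table:

# Sketch: the sinusoidal table built (in log space) by PositionalEncoding.
import math
import numpy as np

def sinusoid_table(max_len, d_model):
    pe = np.zeros((max_len, d_model), dtype="float32")
    position = np.arange(max_len, dtype="float32")[:, None]
    div_term = np.exp(np.arange(0, d_model, 2) * -math.log(10000.0) / d_model)
    pe[:, 0::2] = np.sin(position * div_term)
    pe[:, 1::2] = np.cos(position * div_term)
    return pe[None]  # (1, max_len, d_model), matching the registered buffer

assert sinusoid_table(5000, 512).shape == (1, 5000, 512)
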
diff --git a/ppocr/modeling/necks/__init__.py b/ppocr/modeling/necks/__init__.py
index f5e89a5b80..47e4c7ec58 100644
--- a/ppocr/modeling/necks/__init__.py
+++ b/ppocr/modeling/necks/__init__.py
@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-__all__ = ['build_neck']
+__all__ = ["build_neck"]
def build_neck(config):
@@ -29,15 +29,29 @@ def build_neck(config):
from .ct_fpn import CTFPN
from .fpn_unet import FPN_UNet
from .rf_adaptor import RFAdaptor
+
support_dict = [
- 'FPN', 'FCEFPN', 'LKPAN', 'DBFPN', 'RSEFPN', 'EASTFPN', 'SASTFPN',
- 'SequenceEncoder', 'PGFPN', 'TableFPN', 'PRENFPN', 'CSPPAN', 'CTFPN',
- 'RFAdaptor', 'FPN_UNet'
+ "FPN",
+ "FCEFPN",
+ "LKPAN",
+ "DBFPN",
+ "RSEFPN",
+ "EASTFPN",
+ "SASTFPN",
+ "SequenceEncoder",
+ "PGFPN",
+ "TableFPN",
+ "PRENFPN",
+ "CSPPAN",
+ "CTFPN",
+ "RFAdaptor",
+ "FPN_UNet",
]
- module_name = config.pop('name')
- assert module_name in support_dict, Exception('neck only support {}'.format(
- support_dict))
+ module_name = config.pop("name")
+ assert module_name in support_dict, Exception(
+ "neck only support {}".format(support_dict)
+ )
module_class = eval(module_name)(**config)
return module_class
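
build_neck pops the "name" key and instantiates the matching class with the remaining keys as kwargs, so a typical config is one flat dict; the channel numbers below are illustrative rather than taken from any shipped config:

# Sketch of build_neck usage; values are illustrative.
from ppocr.modeling.necks import build_neck

config = {"name": "DBFPN", "in_channels": [16, 24, 56, 480], "out_channels": 96}
neck = build_neck(config)  # "name" is popped, the rest become DBFPN kwargs
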
diff --git a/ppocr/modeling/necks/csp_pan.py b/ppocr/modeling/necks/csp_pan.py
index 1819619bfb..5e8464d5e4 100755
--- a/ppocr/modeling/necks/csp_pan.py
+++ b/ppocr/modeling/necks/csp_pan.py
@@ -20,21 +20,23 @@
import paddle.nn.functional as F
from paddle import ParamAttr
-__all__ = ['CSPPAN']
+__all__ = ["CSPPAN"]
class ConvBNLayer(nn.Layer):
- def __init__(self,
- in_channel=96,
- out_channel=96,
- kernel_size=3,
- stride=1,
- groups=1,
- act='leaky_relu'):
+ def __init__(
+ self,
+ in_channel=96,
+ out_channel=96,
+ kernel_size=3,
+ stride=1,
+ groups=1,
+ act="leaky_relu",
+ ):
super(ConvBNLayer, self).__init__()
initializer = nn.initializer.KaimingUniform()
self.act = act
- assert self.act in ['leaky_relu', "hard_swish"]
+ assert self.act in ["leaky_relu", "hard_swish"]
self.conv = nn.Conv2D(
in_channels=in_channel,
out_channels=out_channel,
@@ -43,7 +45,8 @@ def __init__(self,
padding=(kernel_size - 1) // 2,
stride=stride,
weight_attr=ParamAttr(initializer=initializer),
- bias_attr=False)
+ bias_attr=False,
+ )
self.bn = nn.BatchNorm2D(out_channel)
def forward(self, x):
@@ -67,12 +70,9 @@ class DPModule(nn.Layer):
Now support `leaky_relu` and `hard_swish`.
"""
- def __init__(self,
- in_channel=96,
- out_channel=96,
- kernel_size=3,
- stride=1,
- act='leaky_relu'):
+ def __init__(
+ self, in_channel=96, out_channel=96, kernel_size=3, stride=1, act="leaky_relu"
+ ):
super(DPModule, self).__init__()
initializer = nn.initializer.KaimingUniform()
self.act = act
@@ -84,7 +84,8 @@ def __init__(self,
padding=(kernel_size - 1) // 2,
stride=stride,
weight_attr=ParamAttr(initializer=initializer),
- bias_attr=False)
+ bias_attr=False,
+ )
self.bn1 = nn.BatchNorm2D(out_channel)
self.pwconv = nn.Conv2D(
in_channels=out_channel,
@@ -93,7 +94,8 @@ def __init__(self,
groups=1,
padding=0,
weight_attr=ParamAttr(initializer=initializer),
- bias_attr=False)
+ bias_attr=False,
+ )
self.bn2 = nn.BatchNorm2D(out_channel)
def act_func(self, x):
@@ -125,30 +127,30 @@ class DarknetBottleneck(nn.Layer):
Default: False
"""
- def __init__(self,
- in_channels,
- out_channels,
- kernel_size=3,
- expansion=0.5,
- add_identity=True,
- use_depthwise=False,
- act="leaky_relu"):
+ def __init__(
+ self,
+ in_channels,
+ out_channels,
+ kernel_size=3,
+ expansion=0.5,
+ add_identity=True,
+ use_depthwise=False,
+ act="leaky_relu",
+ ):
super(DarknetBottleneck, self).__init__()
hidden_channels = int(out_channels * expansion)
conv_func = DPModule if use_depthwise else ConvBNLayer
self.conv1 = ConvBNLayer(
- in_channel=in_channels,
- out_channel=hidden_channels,
- kernel_size=1,
- act=act)
+ in_channel=in_channels, out_channel=hidden_channels, kernel_size=1, act=act
+ )
self.conv2 = conv_func(
in_channel=hidden_channels,
out_channel=out_channels,
kernel_size=kernel_size,
stride=1,
- act=act)
- self.add_identity = \
- add_identity and in_channels == out_channels
+ act=act,
+ )
+ self.add_identity = add_identity and in_channels == out_channels
def forward(self, x):
identity = x
@@ -175,32 +177,37 @@ class CSPLayer(nn.Layer):
blocks. Default: False
"""
- def __init__(self,
- in_channels,
- out_channels,
- kernel_size=3,
- expand_ratio=0.5,
- num_blocks=1,
- add_identity=True,
- use_depthwise=False,
- act="leaky_relu"):
+ def __init__(
+ self,
+ in_channels,
+ out_channels,
+ kernel_size=3,
+ expand_ratio=0.5,
+ num_blocks=1,
+ add_identity=True,
+ use_depthwise=False,
+ act="leaky_relu",
+ ):
super().__init__()
mid_channels = int(out_channels * expand_ratio)
self.main_conv = ConvBNLayer(in_channels, mid_channels, 1, act=act)
self.short_conv = ConvBNLayer(in_channels, mid_channels, 1, act=act)
- self.final_conv = ConvBNLayer(
- 2 * mid_channels, out_channels, 1, act=act)
-
- self.blocks = nn.Sequential(* [
- DarknetBottleneck(
- mid_channels,
- mid_channels,
- kernel_size,
- 1.0,
- add_identity,
- use_depthwise,
- act=act) for _ in range(num_blocks)
- ])
+ self.final_conv = ConvBNLayer(2 * mid_channels, out_channels, 1, act=act)
+
+ self.blocks = nn.Sequential(
+ *[
+ DarknetBottleneck(
+ mid_channels,
+ mid_channels,
+ kernel_size,
+ 1.0,
+ add_identity,
+ use_depthwise,
+ act=act,
+ )
+ for _ in range(num_blocks)
+ ]
+ )
def forward(self, x):
x_short = self.short_conv(x)
@@ -213,16 +220,11 @@ def forward(self, x):
class Channel_T(nn.Layer):
- def __init__(self,
- in_channels=[116, 232, 464],
- out_channels=96,
- act="leaky_relu"):
+ def __init__(self, in_channels=[116, 232, 464], out_channels=96, act="leaky_relu"):
super(Channel_T, self).__init__()
self.convs = nn.LayerList()
for i in range(len(in_channels)):
- self.convs.append(
- ConvBNLayer(
- in_channels[i], out_channels, 1, act=act))
+ self.convs.append(ConvBNLayer(in_channels[i], out_channels, 1, act=act))
def forward(self, x):
outs = [self.convs[i](x[i]) for i in range(len(x))]
@@ -240,13 +242,15 @@ class CSPPAN(nn.Layer):
blocks. Default: True
"""
- def __init__(self,
- in_channels,
- out_channels,
- kernel_size=5,
- num_csp_blocks=1,
- use_depthwise=True,
- act='hard_swish'):
+ def __init__(
+ self,
+ in_channels,
+ out_channels,
+ kernel_size=5,
+ num_csp_blocks=1,
+ use_depthwise=True,
+ act="hard_swish",
+ ):
super(CSPPAN, self).__init__()
self.in_channels = in_channels
self.out_channels = [out_channels] * len(in_channels)
@@ -255,7 +259,7 @@ def __init__(self,
self.conv_t = Channel_T(in_channels, out_channels, act=act)
# build top-down blocks
- self.upsample = nn.Upsample(scale_factor=2, mode='nearest')
+ self.upsample = nn.Upsample(scale_factor=2, mode="nearest")
self.top_down_blocks = nn.LayerList()
for idx in range(len(in_channels) - 1, 0, -1):
self.top_down_blocks.append(
@@ -266,7 +270,9 @@ def __init__(self,
num_blocks=num_csp_blocks,
add_identity=False,
use_depthwise=use_depthwise,
- act=act))
+ act=act,
+ )
+ )
# build bottom-up blocks
self.downsamples = nn.LayerList()
@@ -278,7 +284,9 @@ def __init__(self,
out_channels,
kernel_size=kernel_size,
stride=2,
- act=act))
+ act=act,
+ )
+ )
self.bottom_up_blocks.append(
CSPLayer(
out_channels * 2,
@@ -287,7 +295,9 @@ def __init__(self,
num_blocks=num_csp_blocks,
add_identity=False,
use_depthwise=use_depthwise,
- act=act))
+ act=act,
+ )
+ )
def forward(self, inputs):
"""
@@ -305,10 +315,12 @@ def forward(self, inputs):
feat_heigh = inner_outs[0]
feat_low = inputs[idx - 1]
upsample_feat = F.upsample(
- feat_heigh, size=feat_low.shape[2:4], mode="nearest")
+ feat_heigh, size=feat_low.shape[2:4], mode="nearest"
+ )
inner_out = self.top_down_blocks[len(self.in_channels) - 1 - idx](
- paddle.concat([upsample_feat, feat_low], 1))
+ paddle.concat([upsample_feat, feat_low], 1)
+ )
inner_outs.insert(0, inner_out)
# bottom-up path
@@ -317,8 +329,9 @@ def forward(self, inputs):
feat_low = outs[-1]
feat_height = inner_outs[idx + 1]
downsample_feat = self.downsamples[idx](feat_low)
- out = self.bottom_up_blocks[idx](paddle.concat(
- [downsample_feat, feat_height], 1))
+ out = self.bottom_up_blocks[idx](
+ paddle.concat([downsample_feat, feat_height], 1)
+ )
outs.append(out)
return tuple(outs)
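
CSPLayer routes its input through a "short" 1x1 path and a "main" 1x1 path followed by the DarknetBottleneck stack, concatenates the two halves, and fuses them with a final 1x1 conv. A shape-level sanity check, with illustrative channel and resolution values:

# Sketch: CSPLayer preserves spatial size; out_channels is set explicitly.
import paddle
from ppocr.modeling.necks.csp_pan import CSPLayer

layer = CSPLayer(in_channels=96, out_channels=96, kernel_size=5, num_blocks=1)
y = layer(paddle.randn([1, 96, 40, 40]))  # main + short paths, concat, 1x1 fuse
assert y.shape == [1, 96, 40, 40]
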
diff --git a/ppocr/modeling/necks/ct_fpn.py b/ppocr/modeling/necks/ct_fpn.py
index ee4d25e901..c7d5773146 100644
--- a/ppocr/modeling/necks/ct_fpn.py
+++ b/ppocr/modeling/necks/ct_fpn.py
@@ -25,21 +25,17 @@
import math
from paddle.nn.initializer import TruncatedNormal, Constant, Normal
-ones_ = Constant(value=1.)
-zeros_ = Constant(value=0.)
+
+ones_ = Constant(value=1.0)
+zeros_ = Constant(value=0.0)
__dir__ = os.path.dirname(os.path.abspath(__file__))
sys.path.append(__dir__)
-sys.path.insert(0, os.path.abspath(os.path.join(__dir__, '../../..')))
+sys.path.insert(0, os.path.abspath(os.path.join(__dir__, "../../..")))
class Conv_BN_ReLU(nn.Layer):
- def __init__(self,
- in_planes,
- out_planes,
- kernel_size=1,
- stride=1,
- padding=0):
+ def __init__(self, in_planes, out_planes, kernel_size=1, stride=1, padding=0):
super(Conv_BN_ReLU, self).__init__()
self.conv = nn.Conv2D(
in_planes,
@@ -47,14 +43,15 @@ def __init__(self,
kernel_size=kernel_size,
stride=stride,
padding=padding,
- bias_attr=False)
+ bias_attr=False,
+ )
self.bn = nn.BatchNorm2D(out_planes)
self.relu = nn.ReLU()
for m in self.sublayers():
if isinstance(m, nn.Conv2D):
n = m._kernel_size[0] * m._kernel_size[1] * m._out_channels
- normal_ = Normal(mean=0.0, std=math.sqrt(2. / n))
+ normal_ = Normal(mean=0.0, std=math.sqrt(2.0 / n))
normal_(m.weight)
elif isinstance(m, nn.BatchNorm2D):
zeros_(m.bias)
@@ -75,7 +72,8 @@ def __init__(self, in_channels, out_channels):
stride=1,
padding=1,
groups=planes,
- bias_attr=False)
+ bias_attr=False,
+ )
self.smooth_layer3_1 = Conv_BN_ReLU(planes, planes)
self.dwconv2_1 = nn.Conv2D(
@@ -85,7 +83,8 @@ def __init__(self, in_channels, out_channels):
stride=1,
padding=1,
groups=planes,
- bias_attr=False)
+ bias_attr=False,
+ )
self.smooth_layer2_1 = Conv_BN_ReLU(planes, planes)
self.dwconv1_1 = nn.Conv2D(
@@ -95,7 +94,8 @@ def __init__(self, in_channels, out_channels):
stride=1,
padding=1,
groups=planes,
- bias_attr=False)
+ bias_attr=False,
+ )
self.smooth_layer1_1 = Conv_BN_ReLU(planes, planes)
self.dwconv2_2 = nn.Conv2D(
@@ -105,7 +105,8 @@ def __init__(self, in_channels, out_channels):
stride=2,
padding=1,
groups=planes,
- bias_attr=False)
+ bias_attr=False,
+ )
self.smooth_layer2_2 = Conv_BN_ReLU(planes, planes)
self.dwconv3_2 = nn.Conv2D(
@@ -115,7 +116,8 @@ def __init__(self, in_channels, out_channels):
stride=2,
padding=1,
groups=planes,
- bias_attr=False)
+ bias_attr=False,
+ )
self.smooth_layer3_2 = Conv_BN_ReLU(planes, planes)
self.dwconv4_2 = nn.Conv2D(
@@ -125,11 +127,12 @@ def __init__(self, in_channels, out_channels):
stride=2,
padding=1,
groups=planes,
- bias_attr=False)
+ bias_attr=False,
+ )
self.smooth_layer4_2 = Conv_BN_ReLU(planes, planes)
def _upsample_add(self, x, y):
- return F.upsample(x, scale_factor=2, mode='bilinear') + y
+ return F.upsample(x, scale_factor=2, mode="bilinear") + y
def forward(self, f1, f2, f3, f4):
# up-down
@@ -159,7 +162,7 @@ def __init__(self, in_channels, out_channel=128):
self.fpem2 = FPEM(in_channels=(64, 128, 256, 512), out_channels=128)
def _upsample(self, x, scale=1):
- return F.upsample(x, scale_factor=scale, mode='bilinear')
+ return F.upsample(x, scale_factor=scale, mode="bilinear")
def forward(self, f):
# # reduce channel
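
FPEM's fusion primitive is _upsample_add: a 2x bilinear upsample of the coarser map added elementwise to the finer one, applied top-down and then mirrored bottom-up through the stride-2 depthwise convs. The primitive in isolation:

# Sketch: the _upsample_add primitive used throughout FPEM.
import paddle
import paddle.nn.functional as F

coarse = paddle.randn([1, 128, 20, 20])
fine = paddle.randn([1, 128, 40, 40])
fused = F.upsample(coarse, scale_factor=2, mode="bilinear") + fine
assert fused.shape == [1, 128, 40, 40]
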
diff --git a/ppocr/modeling/necks/db_fpn.py b/ppocr/modeling/necks/db_fpn.py
index 0f5b826bfb..4c74f36bc0 100644
--- a/ppocr/modeling/necks/db_fpn.py
+++ b/ppocr/modeling/necks/db_fpn.py
@@ -26,22 +26,24 @@
__dir__ = os.path.dirname(os.path.abspath(__file__))
sys.path.append(__dir__)
-sys.path.insert(0, os.path.abspath(os.path.join(__dir__, '../../..')))
+sys.path.insert(0, os.path.abspath(os.path.join(__dir__, "../../..")))
from ppocr.modeling.backbones.det_mobilenet_v3 import SEModule
class DSConv(nn.Layer):
- def __init__(self,
- in_channels,
- out_channels,
- kernel_size,
- padding,
- stride=1,
- groups=None,
- if_act=True,
- act="relu",
- **kwargs):
+ def __init__(
+ self,
+ in_channels,
+ out_channels,
+ kernel_size,
+ padding,
+ stride=1,
+ groups=None,
+ if_act=True,
+ act="relu",
+ **kwargs
+ ):
super(DSConv, self).__init__()
if groups == None:
groups = in_channels
@@ -54,7 +56,8 @@ def __init__(self,
stride=stride,
padding=padding,
groups=groups,
- bias_attr=False)
+ bias_attr=False,
+ )
self.bn1 = nn.BatchNorm(num_channels=in_channels, act=None)
@@ -63,7 +66,8 @@ def __init__(self,
out_channels=int(in_channels * 4),
kernel_size=1,
stride=1,
- bias_attr=False)
+ bias_attr=False,
+ )
self.bn2 = nn.BatchNorm(num_channels=int(in_channels * 4), act=None)
@@ -72,7 +76,8 @@ def __init__(self,
out_channels=out_channels,
kernel_size=1,
stride=1,
- bias_attr=False)
+ bias_attr=False,
+ )
self._c = [in_channels, out_channels]
if in_channels != out_channels:
self.conv_end = nn.Conv2D(
@@ -80,10 +85,10 @@ def __init__(self,
out_channels=out_channels,
kernel_size=1,
stride=1,
- bias_attr=False)
+ bias_attr=False,
+ )
def forward(self, inputs):
-
x = self.conv1(inputs)
x = self.bn1(x)
@@ -95,8 +100,11 @@ def forward(self, inputs):
elif self.act == "hardswish":
x = F.hardswish(x)
else:
- print("The activation function({}) is selected incorrectly.".
- format(self.act))
+ print(
+ "The activation function({}) is selected incorrectly.".format(
+ self.act
+ )
+ )
exit()
x = self.conv3(x)
@@ -117,53 +125,61 @@ def __init__(self, in_channels, out_channels, use_asf=False, **kwargs):
out_channels=self.out_channels,
kernel_size=1,
weight_attr=ParamAttr(initializer=weight_attr),
- bias_attr=False)
+ bias_attr=False,
+ )
self.in3_conv = nn.Conv2D(
in_channels=in_channels[1],
out_channels=self.out_channels,
kernel_size=1,
weight_attr=ParamAttr(initializer=weight_attr),
- bias_attr=False)
+ bias_attr=False,
+ )
self.in4_conv = nn.Conv2D(
in_channels=in_channels[2],
out_channels=self.out_channels,
kernel_size=1,
weight_attr=ParamAttr(initializer=weight_attr),
- bias_attr=False)
+ bias_attr=False,
+ )
self.in5_conv = nn.Conv2D(
in_channels=in_channels[3],
out_channels=self.out_channels,
kernel_size=1,
weight_attr=ParamAttr(initializer=weight_attr),
- bias_attr=False)
+ bias_attr=False,
+ )
self.p5_conv = nn.Conv2D(
in_channels=self.out_channels,
out_channels=self.out_channels // 4,
kernel_size=3,
padding=1,
weight_attr=ParamAttr(initializer=weight_attr),
- bias_attr=False)
+ bias_attr=False,
+ )
self.p4_conv = nn.Conv2D(
in_channels=self.out_channels,
out_channels=self.out_channels // 4,
kernel_size=3,
padding=1,
weight_attr=ParamAttr(initializer=weight_attr),
- bias_attr=False)
+ bias_attr=False,
+ )
self.p3_conv = nn.Conv2D(
in_channels=self.out_channels,
out_channels=self.out_channels // 4,
kernel_size=3,
padding=1,
weight_attr=ParamAttr(initializer=weight_attr),
- bias_attr=False)
+ bias_attr=False,
+ )
self.p2_conv = nn.Conv2D(
in_channels=self.out_channels,
out_channels=self.out_channels // 4,
kernel_size=3,
padding=1,
weight_attr=ParamAttr(initializer=weight_attr),
- bias_attr=False)
+ bias_attr=False,
+ )
if self.use_asf is True:
self.asf = ASFBlock(self.out_channels, self.out_channels // 4)
@@ -177,11 +193,14 @@ def forward(self, x):
in2 = self.in2_conv(c2)
out4 = in4 + F.upsample(
- in5, scale_factor=2, mode="nearest", align_mode=1) # 1/16
+ in5, scale_factor=2, mode="nearest", align_mode=1
+ ) # 1/16
out3 = in3 + F.upsample(
- out4, scale_factor=2, mode="nearest", align_mode=1) # 1/8
+ out4, scale_factor=2, mode="nearest", align_mode=1
+ ) # 1/8
out2 = in2 + F.upsample(
- out3, scale_factor=2, mode="nearest", align_mode=1) # 1/4
+ out3, scale_factor=2, mode="nearest", align_mode=1
+ ) # 1/4
p5 = self.p5_conv(in5)
p4 = self.p4_conv(out4)
@@ -210,7 +229,8 @@ def __init__(self, in_channels, out_channels, kernel_size, shortcut=True):
kernel_size=kernel_size,
padding=int(kernel_size // 2),
weight_attr=ParamAttr(initializer=weight_attr),
- bias_attr=False)
+ bias_attr=False,
+ )
self.se_block = SEModule(self.out_channels)
self.shortcut = shortcut
@@ -230,8 +250,8 @@ def __init__(self, in_channels, out_channels, shortcut=True, **kwargs):
self.ins_conv = nn.LayerList()
self.inp_conv = nn.LayerList()
self.intracl = False
- if 'intracl' in kwargs.keys() and kwargs['intracl'] is True:
- self.intracl = kwargs['intracl']
+ if "intracl" in kwargs.keys() and kwargs["intracl"] is True:
+ self.intracl = kwargs["intracl"]
self.incl1 = IntraCLBlock(self.out_channels // 4, reduce_factor=2)
self.incl2 = IntraCLBlock(self.out_channels // 4, reduce_factor=2)
self.incl3 = IntraCLBlock(self.out_channels // 4, reduce_factor=2)
@@ -239,17 +259,13 @@ def __init__(self, in_channels, out_channels, shortcut=True, **kwargs):
for i in range(len(in_channels)):
self.ins_conv.append(
- RSELayer(
- in_channels[i],
- out_channels,
- kernel_size=1,
- shortcut=shortcut))
+ RSELayer(in_channels[i], out_channels, kernel_size=1, shortcut=shortcut)
+ )
self.inp_conv.append(
RSELayer(
- out_channels,
- out_channels // 4,
- kernel_size=3,
- shortcut=shortcut))
+ out_channels, out_channels // 4, kernel_size=3, shortcut=shortcut
+ )
+ )
def forward(self, x):
c2, c3, c4, c5 = x
@@ -260,11 +276,14 @@ def forward(self, x):
in2 = self.ins_conv[0](c2)
out4 = in4 + F.upsample(
- in5, scale_factor=2, mode="nearest", align_mode=1) # 1/16
+ in5, scale_factor=2, mode="nearest", align_mode=1
+ ) # 1/16
out3 = in3 + F.upsample(
- out4, scale_factor=2, mode="nearest", align_mode=1) # 1/8
+ out4, scale_factor=2, mode="nearest", align_mode=1
+ ) # 1/8
out2 = in2 + F.upsample(
- out3, scale_factor=2, mode="nearest", align_mode=1) # 1/4
+ out3, scale_factor=2, mode="nearest", align_mode=1
+ ) # 1/4
p5 = self.inp_conv[3](in5)
p4 = self.inp_conv[2](out4)
@@ -286,7 +305,7 @@ def forward(self, x):
class LKPAN(nn.Layer):
- def __init__(self, in_channels, out_channels, mode='large', **kwargs):
+ def __init__(self, in_channels, out_channels, mode="large", **kwargs):
super(LKPAN, self).__init__()
self.out_channels = out_channels
weight_attr = paddle.nn.initializer.KaimingUniform()
@@ -297,14 +316,16 @@ def __init__(self, in_channels, out_channels, mode='large', **kwargs):
self.pan_head_conv = nn.LayerList()
self.pan_lat_conv = nn.LayerList()
- if mode.lower() == 'lite':
+ if mode.lower() == "lite":
p_layer = DSConv
- elif mode.lower() == 'large':
+ elif mode.lower() == "large":
p_layer = nn.Conv2D
else:
raise ValueError(
- "mode can only be one of ['lite', 'large'], but received {}".
- format(mode))
+ "mode can only be one of ['lite', 'large'], but received {}".format(
+ mode
+ )
+ )
for i in range(len(in_channels)):
self.ins_conv.append(
@@ -313,7 +334,9 @@ def __init__(self, in_channels, out_channels, mode='large', **kwargs):
out_channels=self.out_channels,
kernel_size=1,
weight_attr=ParamAttr(initializer=weight_attr),
- bias_attr=False))
+ bias_attr=False,
+ )
+ )
self.inp_conv.append(
p_layer(
@@ -322,7 +345,9 @@ def __init__(self, in_channels, out_channels, mode='large', **kwargs):
kernel_size=9,
padding=4,
weight_attr=ParamAttr(initializer=weight_attr),
- bias_attr=False))
+ bias_attr=False,
+ )
+ )
if i > 0:
self.pan_head_conv.append(
@@ -333,7 +358,9 @@ def __init__(self, in_channels, out_channels, mode='large', **kwargs):
padding=1,
stride=2,
weight_attr=ParamAttr(initializer=weight_attr),
- bias_attr=False))
+ bias_attr=False,
+ )
+ )
self.pan_lat_conv.append(
p_layer(
in_channels=self.out_channels // 4,
@@ -341,11 +368,13 @@ def __init__(self, in_channels, out_channels, mode='large', **kwargs):
kernel_size=9,
padding=4,
weight_attr=ParamAttr(initializer=weight_attr),
- bias_attr=False))
+ bias_attr=False,
+ )
+ )
self.intracl = False
- if 'intracl' in kwargs.keys() and kwargs['intracl'] is True:
- self.intracl = kwargs['intracl']
+ if "intracl" in kwargs.keys() and kwargs["intracl"] is True:
+ self.intracl = kwargs["intracl"]
self.incl1 = IntraCLBlock(self.out_channels // 4, reduce_factor=2)
self.incl2 = IntraCLBlock(self.out_channels // 4, reduce_factor=2)
self.incl3 = IntraCLBlock(self.out_channels // 4, reduce_factor=2)
@@ -360,11 +389,14 @@ def forward(self, x):
in2 = self.ins_conv[0](c2)
out4 = in4 + F.upsample(
- in5, scale_factor=2, mode="nearest", align_mode=1) # 1/16
+ in5, scale_factor=2, mode="nearest", align_mode=1
+ ) # 1/16
out3 = in3 + F.upsample(
- out4, scale_factor=2, mode="nearest", align_mode=1) # 1/8
+ out4, scale_factor=2, mode="nearest", align_mode=1
+ ) # 1/8
out2 = in2 + F.upsample(
- out3, scale_factor=2, mode="nearest", align_mode=1) # 1/4
+ out3, scale_factor=2, mode="nearest", align_mode=1
+ ) # 1/4
f5 = self.inp_conv[3](in5)
f4 = self.inp_conv[2](out4)
@@ -416,22 +448,25 @@ def __init__(self, in_channels, inter_channels, out_features_num=4):
self.conv = nn.Conv2D(in_channels, inter_channels, 3, padding=1)
self.spatial_scale = nn.Sequential(
- #Nx1xHxW
+ # Nx1xHxW
nn.Conv2D(
in_channels=1,
out_channels=1,
kernel_size=3,
bias_attr=False,
padding=1,
- weight_attr=ParamAttr(initializer=weight_attr)),
+ weight_attr=ParamAttr(initializer=weight_attr),
+ ),
nn.ReLU(),
nn.Conv2D(
in_channels=1,
out_channels=1,
kernel_size=1,
bias_attr=False,
- weight_attr=ParamAttr(initializer=weight_attr)),
- nn.Sigmoid())
+ weight_attr=ParamAttr(initializer=weight_attr),
+ ),
+ nn.Sigmoid(),
+ )
self.channel_scale = nn.Sequential(
nn.Conv2D(
@@ -439,8 +474,10 @@ def __init__(self, in_channels, inter_channels, out_features_num=4):
out_channels=out_features_num,
kernel_size=1,
bias_attr=False,
- weight_attr=ParamAttr(initializer=weight_attr)),
- nn.Sigmoid())
+ weight_attr=ParamAttr(initializer=weight_attr),
+ ),
+ nn.Sigmoid(),
+ )
def forward(self, fuse_features, features_list):
fuse_features = self.conv(fuse_features)
@@ -451,5 +488,5 @@ def forward(self, fuse_features, features_list):
out_list = []
for i in range(self.out_features_num):
- out_list.append(attention_scores[:, i:i + 1] * features_list[i])
- return paddle.concat(out_list, axis=1)
\ No newline at end of file
+ out_list.append(attention_scores[:, i : i + 1] * features_list[i])
+ return paddle.concat(out_list, axis=1)
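
DBFPN reduces each backbone stage to out_channels with 1x1 convs, merges top-down via nearest 2x upsampling (the out4/out3/out2 lines), projects each level to out_channels // 4 with a 3x3 conv, and — in the usual DBFPN tail — upsamples all four levels to 1/4 scale and concatenates them back to out_channels. A hedged shape walkthrough with illustrative MobileNetV3-like stage channels:

# Sketch: DBFPN fuses four stages into one out_channels map at 1/4 scale.
import paddle
from ppocr.modeling.necks.db_fpn import DBFPN

fpn = DBFPN(in_channels=[16, 24, 56, 480], out_channels=96)
c2 = paddle.randn([1, 16, 160, 160])  # 1/4
c3 = paddle.randn([1, 24, 80, 80])    # 1/8
c4 = paddle.randn([1, 56, 40, 40])    # 1/16
c5 = paddle.randn([1, 480, 20, 20])   # 1/32
fuse = fpn([c2, c3, c4, c5])
assert fuse.shape == [1, 96, 160, 160]  # 4 levels x (96 // 4) channels each
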
diff --git a/ppocr/modeling/necks/east_fpn.py b/ppocr/modeling/necks/east_fpn.py
index 120ff156cb..1b7f50db61 100644
--- a/ppocr/modeling/necks/east_fpn.py
+++ b/ppocr/modeling/necks/east_fpn.py
@@ -23,16 +23,18 @@
class ConvBNLayer(nn.Layer):
- def __init__(self,
- in_channels,
- out_channels,
- kernel_size,
- stride,
- padding,
- groups=1,
- if_act=True,
- act=None,
- name=None):
+ def __init__(
+ self,
+ in_channels,
+ out_channels,
+ kernel_size,
+ stride,
+ padding,
+ groups=1,
+ if_act=True,
+ act=None,
+ name=None,
+ ):
super(ConvBNLayer, self).__init__()
self.if_act = if_act
self.act = act
@@ -43,8 +45,9 @@ def __init__(self,
stride=stride,
padding=padding,
groups=groups,
- weight_attr=ParamAttr(name=name + '_weights'),
- bias_attr=False)
+ weight_attr=ParamAttr(name=name + "_weights"),
+ bias_attr=False,
+ )
self.bn = nn.BatchNorm(
num_channels=out_channels,
@@ -52,7 +55,8 @@ def __init__(self,
param_attr=ParamAttr(name="bn_" + name + "_scale"),
bias_attr=ParamAttr(name="bn_" + name + "_offset"),
moving_mean_name="bn_" + name + "_mean",
- moving_variance_name="bn_" + name + "_variance")
+ moving_variance_name="bn_" + name + "_variance",
+ )
def forward(self, x):
x = self.conv(x)
@@ -61,16 +65,18 @@ def forward(self, x):
class DeConvBNLayer(nn.Layer):
- def __init__(self,
- in_channels,
- out_channels,
- kernel_size,
- stride,
- padding,
- groups=1,
- if_act=True,
- act=None,
- name=None):
+ def __init__(
+ self,
+ in_channels,
+ out_channels,
+ kernel_size,
+ stride,
+ padding,
+ groups=1,
+ if_act=True,
+ act=None,
+ name=None,
+ ):
super(DeConvBNLayer, self).__init__()
self.if_act = if_act
self.act = act
@@ -81,15 +87,17 @@ def __init__(self,
stride=stride,
padding=padding,
groups=groups,
- weight_attr=ParamAttr(name=name + '_weights'),
- bias_attr=False)
+ weight_attr=ParamAttr(name=name + "_weights"),
+ bias_attr=False,
+ )
self.bn = nn.BatchNorm(
num_channels=out_channels,
act=act,
param_attr=ParamAttr(name="bn_" + name + "_scale"),
bias_attr=ParamAttr(name="bn_" + name + "_offset"),
moving_mean_name="bn_" + name + "_mean",
- moving_variance_name="bn_" + name + "_variance")
+ moving_variance_name="bn_" + name + "_variance",
+ )
def forward(self, x):
x = self.deconv(x)
@@ -107,32 +115,35 @@ def __init__(self, in_channels, model_name, **kwargs):
self.out_channels = 64
self.in_channels = in_channels[::-1]
self.h1_conv = ConvBNLayer(
- in_channels=self.out_channels+self.in_channels[1],
+ in_channels=self.out_channels + self.in_channels[1],
out_channels=self.out_channels,
kernel_size=3,
stride=1,
padding=1,
if_act=True,
- act='relu',
- name="unet_h_1")
+ act="relu",
+ name="unet_h_1",
+ )
self.h2_conv = ConvBNLayer(
- in_channels=self.out_channels+self.in_channels[2],
+ in_channels=self.out_channels + self.in_channels[2],
out_channels=self.out_channels,
kernel_size=3,
stride=1,
padding=1,
if_act=True,
- act='relu',
- name="unet_h_2")
+ act="relu",
+ name="unet_h_2",
+ )
self.h3_conv = ConvBNLayer(
- in_channels=self.out_channels+self.in_channels[3],
+ in_channels=self.out_channels + self.in_channels[3],
out_channels=self.out_channels,
kernel_size=3,
stride=1,
padding=1,
if_act=True,
- act='relu',
- name="unet_h_3")
+ act="relu",
+ name="unet_h_3",
+ )
self.g0_deconv = DeConvBNLayer(
in_channels=self.in_channels[0],
out_channels=self.out_channels,
@@ -140,8 +151,9 @@ def __init__(self, in_channels, model_name, **kwargs):
stride=2,
padding=1,
if_act=True,
- act='relu',
- name="unet_g_0")
+ act="relu",
+ name="unet_g_0",
+ )
self.g1_deconv = DeConvBNLayer(
in_channels=self.out_channels,
out_channels=self.out_channels,
@@ -149,8 +161,9 @@ def __init__(self, in_channels, model_name, **kwargs):
stride=2,
padding=1,
if_act=True,
- act='relu',
- name="unet_g_1")
+ act="relu",
+ name="unet_g_1",
+ )
self.g2_deconv = DeConvBNLayer(
in_channels=self.out_channels,
out_channels=self.out_channels,
@@ -158,8 +171,9 @@ def __init__(self, in_channels, model_name, **kwargs):
stride=2,
padding=1,
if_act=True,
- act='relu',
- name="unet_g_2")
+ act="relu",
+ name="unet_g_2",
+ )
self.g3_conv = ConvBNLayer(
in_channels=self.out_channels,
out_channels=self.out_channels,
@@ -167,8 +181,9 @@ def __init__(self, in_channels, model_name, **kwargs):
stride=1,
padding=1,
if_act=True,
- act='relu',
- name="unet_g_3")
+ act="relu",
+ name="unet_g_3",
+ )
def forward(self, x):
f = x[::-1]
@@ -185,4 +200,4 @@ def forward(self, x):
h = self.h3_conv(h)
g = self.g3_conv(h)
- return g
\ No newline at end of file
+ return g
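
ConvBNLayer/DeConvBNLayer here name every parameter explicitly through ParamAttr (e.g. "unet_h_1_weights", "bn_unet_h_1_scale"), which keeps state-dict keys stable across refactors. A minimal sketch of the naming pattern; the name string is illustrative:

# Sketch: explicit parameter naming via ParamAttr (name is illustrative).
import paddle.nn as nn
from paddle import ParamAttr

name = "unet_h_1"
conv = nn.Conv2D(
    64, 64, kernel_size=3, padding=1,
    weight_attr=ParamAttr(name=name + "_weights"),
    bias_attr=False,
)
print(conv.weight.name)  # "unet_h_1_weights"
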
diff --git a/ppocr/modeling/necks/fce_fpn.py b/ppocr/modeling/necks/fce_fpn.py
index 954e964e97..a456fd16e6 100644
--- a/ppocr/modeling/necks/fce_fpn.py
+++ b/ppocr/modeling/necks/fce_fpn.py
@@ -23,25 +23,26 @@
from paddle.nn.initializer import Normal
from paddle.regularizer import L2Decay
-__all__ = ['FCEFPN']
+__all__ = ["FCEFPN"]
class ConvNormLayer(nn.Layer):
- def __init__(self,
- ch_in,
- ch_out,
- filter_size,
- stride,
- groups=1,
- norm_type='bn',
- norm_decay=0.,
- norm_groups=32,
- lr_scale=1.,
- freeze_norm=False,
- initializer=Normal(
- mean=0., std=0.01)):
+ def __init__(
+ self,
+ ch_in,
+ ch_out,
+ filter_size,
+ stride,
+ groups=1,
+ norm_type="bn",
+ norm_decay=0.0,
+ norm_groups=32,
+ lr_scale=1.0,
+ freeze_norm=False,
+ initializer=Normal(mean=0.0, std=0.01),
+ ):
super(ConvNormLayer, self).__init__()
- assert norm_type in ['bn', 'sync_bn', 'gn']
+ assert norm_type in ["bn", "sync_bn", "gn"]
bias_attr = False
@@ -52,29 +53,34 @@ def __init__(self,
stride=stride,
padding=(filter_size - 1) // 2,
groups=groups,
- weight_attr=ParamAttr(
- initializer=initializer, learning_rate=1.),
- bias_attr=bias_attr)
+ weight_attr=ParamAttr(initializer=initializer, learning_rate=1.0),
+ bias_attr=bias_attr,
+ )
- norm_lr = 0. if freeze_norm else 1.
+ norm_lr = 0.0 if freeze_norm else 1.0
param_attr = ParamAttr(
learning_rate=norm_lr,
- regularizer=L2Decay(norm_decay) if norm_decay is not None else None)
+ regularizer=L2Decay(norm_decay) if norm_decay is not None else None,
+ )
bias_attr = ParamAttr(
learning_rate=norm_lr,
- regularizer=L2Decay(norm_decay) if norm_decay is not None else None)
- if norm_type == 'bn':
+ regularizer=L2Decay(norm_decay) if norm_decay is not None else None,
+ )
+ if norm_type == "bn":
self.norm = nn.BatchNorm2D(
- ch_out, weight_attr=param_attr, bias_attr=bias_attr)
- elif norm_type == 'sync_bn':
+ ch_out, weight_attr=param_attr, bias_attr=bias_attr
+ )
+ elif norm_type == "sync_bn":
self.norm = nn.SyncBatchNorm(
- ch_out, weight_attr=param_attr, bias_attr=bias_attr)
- elif norm_type == 'gn':
+ ch_out, weight_attr=param_attr, bias_attr=bias_attr
+ )
+ elif norm_type == "gn":
self.norm = nn.GroupNorm(
num_groups=norm_groups,
num_channels=ch_out,
weight_attr=param_attr,
- bias_attr=bias_attr)
+ bias_attr=bias_attr,
+ )
def forward(self, inputs):
out = self.conv(inputs)
@@ -86,45 +92,47 @@ class FCEFPN(nn.Layer):
"""
Feature Pyramid Network, see https://arxiv.org/abs/1612.03144
Args:
- in_channels (list[int]): input channels of each level which can be
+ in_channels (list[int]): input channels of each level which can be
derived from the output shape of backbone by from_config
out_channels (list[int]): output channel of each level
spatial_scales (list[float]): the spatial scales between input feature
- maps and original input image which can be derived from the output
+ maps and original input image which can be derived from the output
shape of backbone by from_config
has_extra_convs (bool): whether to add extra conv to the last level.
default False
extra_stage (int): the number of extra stages added to the last level.
default 1
- use_c5 (bool): Whether to use c5 as the input of extra stage,
+ use_c5 (bool): Whether to use c5 as the input of extra stage,
otherwise p5 is used. default True
- norm_type (string|None): The normalization type in FPN module. If
- norm_type is None, norm will not be used after conv and if
+ norm_type (string|None): The normalization type in FPN module. If
+ norm_type is None, norm will not be used after conv and if
norm_type is string, bn, gn, sync_bn are available. default None
norm_decay (float): weight decay for normalization layer weights.
default 0.
- freeze_norm (bool): whether to freeze normalization layer.
+ freeze_norm (bool): whether to freeze normalization layer.
default False
relu_before_extra_convs (bool): whether to add relu before extra convs.
default False
-
+
"""
- def __init__(self,
- in_channels,
- out_channels,
- spatial_scales=[0.25, 0.125, 0.0625, 0.03125],
- has_extra_convs=False,
- extra_stage=1,
- use_c5=True,
- norm_type=None,
- norm_decay=0.,
- freeze_norm=False,
- relu_before_extra_convs=True):
+ def __init__(
+ self,
+ in_channels,
+ out_channels,
+ spatial_scales=[0.25, 0.125, 0.0625, 0.03125],
+ has_extra_convs=False,
+ extra_stage=1,
+ use_c5=True,
+ norm_type=None,
+ norm_decay=0.0,
+ freeze_norm=False,
+ relu_before_extra_convs=True,
+ ):
super(FCEFPN, self).__init__()
self.out_channels = out_channels
for s in range(extra_stage):
- spatial_scales = spatial_scales + [spatial_scales[-1] / 2.]
+ spatial_scales = spatial_scales + [spatial_scales[-1] / 2.0]
self.spatial_scales = spatial_scales
self.has_extra_convs = has_extra_convs
self.extra_stage = extra_stage
@@ -144,9 +152,9 @@ def __init__(self,
ed_stage = st_stage + len(in_channels) - 1
for i in range(st_stage, ed_stage + 1):
if i == 3:
- lateral_name = 'fpn_inner_res5_sum'
+ lateral_name = "fpn_inner_res5_sum"
else:
- lateral_name = 'fpn_inner_res{}_sum_lateral'.format(i + 2)
+ lateral_name = "fpn_inner_res{}_sum_lateral".format(i + 2)
in_c = in_channels[i - st_stage]
if self.norm_type is not None:
lateral = self.add_sublayer(
@@ -159,7 +167,9 @@ def __init__(self,
norm_type=self.norm_type,
norm_decay=self.norm_decay,
freeze_norm=self.freeze_norm,
- initializer=XavierUniform(fan_out=in_c)))
+ initializer=XavierUniform(fan_out=in_c),
+ ),
+ )
else:
lateral = self.add_sublayer(
lateral_name,
@@ -167,12 +177,13 @@ def __init__(self,
in_channels=in_c,
out_channels=out_channels,
kernel_size=1,
- weight_attr=ParamAttr(
- initializer=XavierUniform(fan_out=in_c))))
+ weight_attr=ParamAttr(initializer=XavierUniform(fan_out=in_c)),
+ ),
+ )
self.lateral_convs.append(lateral)
for i in range(st_stage, ed_stage + 1):
- fpn_name = 'fpn_res{}_sum'.format(i + 2)
+ fpn_name = "fpn_res{}_sum".format(i + 2)
if self.norm_type is not None:
fpn_conv = self.add_sublayer(
fpn_name,
@@ -184,7 +195,9 @@ def __init__(self,
norm_type=self.norm_type,
norm_decay=self.norm_decay,
freeze_norm=self.freeze_norm,
- initializer=XavierUniform(fan_out=fan)))
+ initializer=XavierUniform(fan_out=fan),
+ ),
+ )
else:
fpn_conv = self.add_sublayer(
fpn_name,
@@ -193,8 +206,9 @@ def __init__(self,
out_channels=out_channels,
kernel_size=3,
padding=1,
- weight_attr=ParamAttr(
- initializer=XavierUniform(fan_out=fan))))
+ weight_attr=ParamAttr(initializer=XavierUniform(fan_out=fan)),
+ ),
+ )
self.fpn_convs.append(fpn_conv)
# add extra conv levels for RetinaNet(use_c5)/FCOS(use_p5)
@@ -205,7 +219,7 @@ def __init__(self,
in_c = in_channels[-1]
else:
in_c = out_channels
- extra_fpn_name = 'fpn_{}'.format(lvl + 2)
+ extra_fpn_name = "fpn_{}".format(lvl + 2)
if self.norm_type is not None:
extra_fpn_conv = self.add_sublayer(
extra_fpn_name,
@@ -217,7 +231,9 @@ def __init__(self,
norm_type=self.norm_type,
norm_decay=self.norm_decay,
freeze_norm=self.freeze_norm,
- initializer=XavierUniform(fan_out=fan)))
+ initializer=XavierUniform(fan_out=fan),
+ ),
+ )
else:
extra_fpn_conv = self.add_sublayer(
extra_fpn_name,
@@ -228,14 +244,17 @@ def __init__(self,
stride=2,
padding=1,
weight_attr=ParamAttr(
- initializer=XavierUniform(fan_out=fan))))
+ initializer=XavierUniform(fan_out=fan)
+ ),
+ ),
+ )
self.fpn_convs.append(extra_fpn_conv)
@classmethod
def from_config(cls, cfg, input_shape):
return {
- 'in_channels': [i.channels for i in input_shape],
- 'spatial_scales': [1.0 / i.stride for i in input_shape],
+ "in_channels": [i.channels for i in input_shape],
+ "spatial_scales": [1.0 / i.stride for i in input_shape],
}
def forward(self, body_feats):
@@ -249,8 +268,9 @@ def forward(self, body_feats):
lvl = num_levels - i
upsample = F.interpolate(
laterals[lvl],
- scale_factor=2.,
- mode='nearest', )
+ scale_factor=2.0,
+ mode="nearest",
+ )
laterals[lvl - 1] += upsample
fpn_output = []
@@ -260,7 +280,9 @@ def forward(self, body_feats):
if self.extra_stage > 0:
# use max pool to get more levels on top of outputs (Faster R-CNN, Mask R-CNN)
if not self.has_extra_convs:
- assert self.extra_stage == 1, 'extra_stage should be 1 if FPN has not extra convs'
+ assert (
+ self.extra_stage == 1
+ ), "extra_stage should be 1 if FPN has not extra convs"
fpn_output.append(F.max_pool2d(fpn_output[-1], 1, stride=2))
# add extra conv levels for RetinaNet(use_c5)/FCOS(use_p5)
else:
@@ -272,9 +294,11 @@ def forward(self, body_feats):
for i in range(1, self.extra_stage):
if self.relu_before_extra_convs:
- fpn_output.append(self.fpn_convs[num_levels + i](F.relu(
- fpn_output[-1])))
+ fpn_output.append(
+ self.fpn_convs[num_levels + i](F.relu(fpn_output[-1]))
+ )
else:
- fpn_output.append(self.fpn_convs[num_levels + i](
- fpn_output[-1]))
+ fpn_output.append(
+ self.fpn_convs[num_levels + i](fpn_output[-1])
+ )
return fpn_output
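
FCEFPN.from_config derives in_channels and spatial_scales from the backbone's output shape specs, so the neck carries no hard-coded backbone knowledge. An illustrative derivation, with ShapeSpec as a stand-in namedtuple and ResNet-like values:

# Sketch of the from_config derivation (ShapeSpec is a stand-in here).
from collections import namedtuple

ShapeSpec = namedtuple("ShapeSpec", ["channels", "stride"])
shapes = [ShapeSpec(256, 4), ShapeSpec(512, 8), ShapeSpec(1024, 16), ShapeSpec(2048, 32)]
cfg = {
    "in_channels": [s.channels for s in shapes],         # [256, 512, 1024, 2048]
    "spatial_scales": [1.0 / s.stride for s in shapes],  # [0.25, 0.125, 0.0625, 0.03125]
}
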
diff --git a/ppocr/modeling/necks/fpn.py b/ppocr/modeling/necks/fpn.py
index 48c85b1e53..ea5253c059 100644
--- a/ppocr/modeling/necks/fpn.py
+++ b/ppocr/modeling/necks/fpn.py
@@ -23,12 +23,7 @@
class Conv_BN_ReLU(nn.Layer):
- def __init__(self,
- in_planes,
- out_planes,
- kernel_size=1,
- stride=1,
- padding=0):
+ def __init__(self, in_planes, out_planes, kernel_size=1, stride=1, padding=0):
super(Conv_BN_ReLU, self).__init__()
self.conv = nn.Conv2D(
in_planes,
@@ -36,7 +31,8 @@ def __init__(self,
kernel_size=kernel_size,
stride=stride,
padding=padding,
- bias_attr=False)
+ bias_attr=False,
+ )
self.bn = nn.BatchNorm2D(out_planes, momentum=0.1)
self.relu = nn.ReLU()
@@ -45,18 +41,22 @@ def __init__(self,
n = m._kernel_size[0] * m._kernel_size[1] * m._out_channels
m.weight = paddle.create_parameter(
shape=m.weight.shape,
- dtype='float32',
+ dtype="float32",
default_initializer=paddle.nn.initializer.Normal(
- 0, math.sqrt(2. / n)))
+ 0, math.sqrt(2.0 / n)
+ ),
+ )
elif isinstance(m, nn.BatchNorm2D):
m.weight = paddle.create_parameter(
shape=m.weight.shape,
- dtype='float32',
- default_initializer=paddle.nn.initializer.Constant(1.0))
+ dtype="float32",
+ default_initializer=paddle.nn.initializer.Constant(1.0),
+ )
m.bias = paddle.create_parameter(
shape=m.bias.shape,
- dtype='float32',
- default_initializer=paddle.nn.initializer.Constant(0.0))
+ dtype="float32",
+ default_initializer=paddle.nn.initializer.Constant(0.0),
+ )
def forward(self, x):
return self.relu(self.bn(self.conv(x)))
@@ -68,26 +68,33 @@ def __init__(self, in_channels, out_channels):
# Top layer
self.toplayer_ = Conv_BN_ReLU(
- in_channels[3], out_channels, kernel_size=1, stride=1, padding=0)
+ in_channels[3], out_channels, kernel_size=1, stride=1, padding=0
+ )
# Lateral layers
self.latlayer1_ = Conv_BN_ReLU(
- in_channels[2], out_channels, kernel_size=1, stride=1, padding=0)
+ in_channels[2], out_channels, kernel_size=1, stride=1, padding=0
+ )
self.latlayer2_ = Conv_BN_ReLU(
- in_channels[1], out_channels, kernel_size=1, stride=1, padding=0)
+ in_channels[1], out_channels, kernel_size=1, stride=1, padding=0
+ )
self.latlayer3_ = Conv_BN_ReLU(
- in_channels[0], out_channels, kernel_size=1, stride=1, padding=0)
+ in_channels[0], out_channels, kernel_size=1, stride=1, padding=0
+ )
# Smooth layers
self.smooth1_ = Conv_BN_ReLU(
- out_channels, out_channels, kernel_size=3, stride=1, padding=1)
+ out_channels, out_channels, kernel_size=3, stride=1, padding=1
+ )
self.smooth2_ = Conv_BN_ReLU(
- out_channels, out_channels, kernel_size=3, stride=1, padding=1)
+ out_channels, out_channels, kernel_size=3, stride=1, padding=1
+ )
self.smooth3_ = Conv_BN_ReLU(
- out_channels, out_channels, kernel_size=3, stride=1, padding=1)
+ out_channels, out_channels, kernel_size=3, stride=1, padding=1
+ )
self.out_channels = out_channels * 4
for m in self.sublayers():
@@ -95,24 +102,28 @@ def __init__(self, in_channels, out_channels):
n = m._kernel_size[0] * m._kernel_size[1] * m._out_channels
m.weight = paddle.create_parameter(
shape=m.weight.shape,
- dtype='float32',
+ dtype="float32",
default_initializer=paddle.nn.initializer.Normal(
- 0, math.sqrt(2. / n)))
+ 0, math.sqrt(2.0 / n)
+ ),
+ )
elif isinstance(m, nn.BatchNorm2D):
m.weight = paddle.create_parameter(
shape=m.weight.shape,
- dtype='float32',
- default_initializer=paddle.nn.initializer.Constant(1.0))
+ dtype="float32",
+ default_initializer=paddle.nn.initializer.Constant(1.0),
+ )
m.bias = paddle.create_parameter(
shape=m.bias.shape,
- dtype='float32',
- default_initializer=paddle.nn.initializer.Constant(0.0))
+ dtype="float32",
+ default_initializer=paddle.nn.initializer.Constant(0.0),
+ )
def _upsample(self, x, scale=1):
- return F.upsample(x, scale_factor=scale, mode='bilinear')
+ return F.upsample(x, scale_factor=scale, mode="bilinear")
def _upsample_add(self, x, y, scale=1):
- return F.upsample(x, scale_factor=scale, mode='bilinear') + y
+ return F.upsample(x, scale_factor=scale, mode="bilinear") + y
def forward(self, x):
f2, f3, f4, f5 = x
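The hunks above reformat a PSENet-style FPN whose core fusion step is `_upsample_add`: bilinearly upsample the coarser map and add the lateral feature from the next level. A minimal sketch of that step, with assumed (not actual) tensor sizes:

```python
# Sketch of the _upsample_add fusion reformatted above; shapes are illustrative.
import paddle
import paddle.nn.functional as F

def upsample_add(x, y, scale=2):
    # x: deeper, coarser map (N, C, H, W); y: lateral map (N, C, H*scale, W*scale)
    return F.upsample(x, scale_factor=scale, mode="bilinear") + y

p5 = paddle.randn([1, 128, 8, 8])    # coarsest pyramid level
c4 = paddle.randn([1, 128, 16, 16])  # lateral feature one level up
p4 = upsample_add(p5, c4)            # (1, 128, 16, 16)
```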
diff --git a/ppocr/modeling/necks/fpn_unet.py b/ppocr/modeling/necks/fpn_unet.py
index 34e94a8b50..e560cdb69c 100644
--- a/ppocr/modeling/necks/fpn_unet.py
+++ b/ppocr/modeling/necks/fpn_unet.py
@@ -29,11 +29,14 @@ def __init__(self, in_channels, out_channels):
assert isinstance(out_channels, int)
self.conv1x1 = nn.Conv2D(
- in_channels, in_channels, kernel_size=1, stride=1, padding=0)
+ in_channels, in_channels, kernel_size=1, stride=1, padding=0
+ )
self.conv3x3 = nn.Conv2D(
- in_channels, out_channels, kernel_size=3, stride=1, padding=1)
+ in_channels, out_channels, kernel_size=3, stride=1, padding=1
+ )
self.deconv = nn.Conv2DTranspose(
- out_channels, out_channels, kernel_size=4, stride=2, padding=1)
+ out_channels, out_channels, kernel_size=4, stride=2, padding=1
+ )
def forward(self, x):
x = F.relu(self.conv1x1(x))
@@ -53,16 +56,19 @@ def __init__(self, in_channels, out_channels):
blocks_out_channels = [out_channels] + [
min(out_channels * 2**i, 256) for i in range(4)
]
- blocks_in_channels = [blocks_out_channels[1]] + [
- in_channels[i] + blocks_out_channels[i + 2] for i in range(3)
- ] + [in_channels[3]]
+ blocks_in_channels = (
+ [blocks_out_channels[1]]
+ + [in_channels[i] + blocks_out_channels[i + 2] for i in range(3)]
+ + [in_channels[3]]
+ )
self.up4 = nn.Conv2DTranspose(
blocks_in_channels[4],
blocks_out_channels[4],
kernel_size=4,
stride=2,
- padding=1)
+ padding=1,
+ )
self.up_block3 = UpBlock(blocks_in_channels[3], blocks_out_channels[3])
self.up_block2 = UpBlock(blocks_in_channels[2], blocks_out_channels[2])
self.up_block1 = UpBlock(blocks_in_channels[1], blocks_out_channels[1])
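The channel bookkeeping above is easier to follow as a worked example. With `out_channels=32` and backbone widths `in_channels=[64, 128, 256, 512]` (illustrative values, not taken from any shipped config):

```python
# Worked example of the blocks_in/out_channels arithmetic in FPN_UNet above.
out_channels = 32                   # illustrative
in_channels = [64, 128, 256, 512]   # illustrative backbone widths

blocks_out_channels = [out_channels] + [
    min(out_channels * 2**i, 256) for i in range(4)
]  # -> [32, 32, 64, 128, 256]

blocks_in_channels = (
    [blocks_out_channels[1]]
    + [in_channels[i] + blocks_out_channels[i + 2] for i in range(3)]
    + [in_channels[3]]
)  # -> [32, 128, 256, 512, 512]: each UpBlock sees skip + upsampled channels
```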
diff --git a/ppocr/modeling/necks/intracl.py b/ppocr/modeling/necks/intracl.py
index 205b52e35f..2c4809cb12 100644
--- a/ppocr/modeling/necks/intracl.py
+++ b/ppocr/modeling/necks/intracl.py
@@ -11,55 +11,55 @@ def __init__(self, in_channels=96, reduce_factor=4):
self.rf = reduce_factor
weight_attr = paddle.nn.initializer.KaimingUniform()
self.conv1x1_reduce_channel = nn.Conv2D(
- self.channels,
- self.channels // self.rf,
- kernel_size=1,
- stride=1,
- padding=0)
+ self.channels, self.channels // self.rf, kernel_size=1, stride=1, padding=0
+ )
self.conv1x1_return_channel = nn.Conv2D(
- self.channels // self.rf,
- self.channels,
- kernel_size=1,
- stride=1,
- padding=0)
+ self.channels // self.rf, self.channels, kernel_size=1, stride=1, padding=0
+ )
self.v_layer_7x1 = nn.Conv2D(
self.channels // self.rf,
self.channels // self.rf,
kernel_size=(7, 1),
stride=(1, 1),
- padding=(3, 0))
+ padding=(3, 0),
+ )
self.v_layer_5x1 = nn.Conv2D(
self.channels // self.rf,
self.channels // self.rf,
kernel_size=(5, 1),
stride=(1, 1),
- padding=(2, 0))
+ padding=(2, 0),
+ )
self.v_layer_3x1 = nn.Conv2D(
self.channels // self.rf,
self.channels // self.rf,
kernel_size=(3, 1),
stride=(1, 1),
- padding=(1, 0))
+ padding=(1, 0),
+ )
self.q_layer_1x7 = nn.Conv2D(
self.channels // self.rf,
self.channels // self.rf,
kernel_size=(1, 7),
stride=(1, 1),
- padding=(0, 3))
+ padding=(0, 3),
+ )
self.q_layer_1x5 = nn.Conv2D(
self.channels // self.rf,
self.channels // self.rf,
kernel_size=(1, 5),
stride=(1, 1),
- padding=(0, 2))
+ padding=(0, 2),
+ )
self.q_layer_1x3 = nn.Conv2D(
self.channels // self.rf,
self.channels // self.rf,
kernel_size=(1, 3),
stride=(1, 1),
- padding=(0, 1))
+ padding=(0, 1),
+ )
# base
self.c_layer_7x7 = nn.Conv2D(
@@ -67,19 +67,22 @@ def __init__(self, in_channels=96, reduce_factor=4):
self.channels // self.rf,
kernel_size=(7, 7),
stride=(1, 1),
- padding=(3, 3))
+ padding=(3, 3),
+ )
self.c_layer_5x5 = nn.Conv2D(
self.channels // self.rf,
self.channels // self.rf,
kernel_size=(5, 5),
stride=(1, 1),
- padding=(2, 2))
+ padding=(2, 2),
+ )
self.c_layer_3x3 = nn.Conv2D(
self.channels // self.rf,
self.channels // self.rf,
kernel_size=(3, 3),
stride=(1, 1),
- padding=(1, 1))
+ padding=(1, 1),
+ )
self.bn = nn.BatchNorm2D(self.channels)
self.relu = nn.ReLU()
@@ -115,4 +118,4 @@ def build_intraclblock_list(num_block):
for i in range(num_block):
IntraCLBlock_list.append(IntraCLBlock())
- return IntraCLBlock_list
\ No newline at end of file
+ return IntraCLBlock_list
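IntraCLBlock's asymmetric branches (7x1, 1x7, and so on) all preserve the spatial size: with stride 1, `out = in + 2*pad - k + 1`, so `pad = (k - 1) // 2` per axis is exactly "same" padding. A quick check under assumed sizes:

```python
# Why the (7,1)/(1,7) kernels above need padding (3,0)/(0,3) to preserve H and W.
import paddle
import paddle.nn as nn

x = paddle.randn([1, 24, 32, 32])   # illustrative reduced-channel feature map
conv_7x1 = nn.Conv2D(24, 24, kernel_size=(7, 1), stride=(1, 1), padding=(3, 0))
conv_1x7 = nn.Conv2D(24, 24, kernel_size=(1, 7), stride=(1, 1), padding=(0, 3))
assert conv_7x1(x).shape == x.shape
assert conv_1x7(x).shape == x.shape
```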
diff --git a/ppocr/modeling/necks/pg_fpn.py b/ppocr/modeling/necks/pg_fpn.py
index 3f64539f79..cae4b22408 100644
--- a/ppocr/modeling/necks/pg_fpn.py
+++ b/ppocr/modeling/necks/pg_fpn.py
@@ -23,20 +23,23 @@
class ConvBNLayer(nn.Layer):
- def __init__(self,
- in_channels,
- out_channels,
- kernel_size,
- stride=1,
- groups=1,
- is_vd_mode=False,
- act=None,
- name=None):
+ def __init__(
+ self,
+ in_channels,
+ out_channels,
+ kernel_size,
+ stride=1,
+ groups=1,
+ is_vd_mode=False,
+ act=None,
+ name=None,
+ ):
super(ConvBNLayer, self).__init__()
self.is_vd_mode = is_vd_mode
self._pool2d_avg = nn.AvgPool2D(
- kernel_size=2, stride=2, padding=0, ceil_mode=True)
+ kernel_size=2, stride=2, padding=0, ceil_mode=True
+ )
self._conv = nn.Conv2D(
in_channels=in_channels,
out_channels=out_channels,
@@ -45,7 +48,8 @@ def __init__(self,
padding=(kernel_size - 1) // 2,
groups=groups,
weight_attr=ParamAttr(name=name + "_weights"),
- bias_attr=False)
+ bias_attr=False,
+ )
if name == "conv1":
bn_name = "bn_" + name
else:
@@ -53,11 +57,12 @@ def __init__(self,
self._batch_norm = nn.BatchNorm(
out_channels,
act=act,
- param_attr=ParamAttr(name=bn_name + '_scale'),
- bias_attr=ParamAttr(bn_name + '_offset'),
- moving_mean_name=bn_name + '_mean',
- moving_variance_name=bn_name + '_variance',
- use_global_stats=False)
+ param_attr=ParamAttr(name=bn_name + "_scale"),
+ bias_attr=ParamAttr(bn_name + "_offset"),
+ moving_mean_name=bn_name + "_mean",
+ moving_variance_name=bn_name + "_variance",
+ use_global_stats=False,
+ )
def forward(self, inputs):
y = self._conv(inputs)
@@ -66,16 +71,18 @@ def forward(self, inputs):
class DeConvBNLayer(nn.Layer):
- def __init__(self,
- in_channels,
- out_channels,
- kernel_size=4,
- stride=2,
- padding=1,
- groups=1,
- if_act=True,
- act=None,
- name=None):
+ def __init__(
+ self,
+ in_channels,
+ out_channels,
+ kernel_size=4,
+ stride=2,
+ padding=1,
+ groups=1,
+ if_act=True,
+ act=None,
+ name=None,
+ ):
super(DeConvBNLayer, self).__init__()
self.if_act = if_act
@@ -87,8 +94,9 @@ def __init__(self,
stride=stride,
padding=padding,
groups=groups,
- weight_attr=ParamAttr(name=name + '_weights'),
- bias_attr=False)
+ weight_attr=ParamAttr(name=name + "_weights"),
+ bias_attr=False,
+ )
self.bn = nn.BatchNorm(
num_channels=out_channels,
act=act,
@@ -96,7 +104,8 @@ def __init__(self,
bias_attr=ParamAttr(name="bn_" + name + "_offset"),
moving_mean_name="bn_" + name + "_mean",
moving_variance_name="bn_" + name + "_variance",
- use_global_stats=False)
+ use_global_stats=False,
+ )
def forward(self, x):
x = self.deconv(x)
@@ -116,56 +125,64 @@ def __init__(self, in_channels, **kwargs):
kernel_size=3,
stride=1,
act=None,
- name='FPN_d1')
+ name="FPN_d1",
+ )
self.conv_bn_layer_2 = ConvBNLayer(
in_channels=64,
out_channels=64,
kernel_size=3,
stride=1,
act=None,
- name='FPN_d2')
+ name="FPN_d2",
+ )
self.conv_bn_layer_3 = ConvBNLayer(
in_channels=256,
out_channels=128,
kernel_size=3,
stride=1,
act=None,
- name='FPN_d3')
+ name="FPN_d3",
+ )
self.conv_bn_layer_4 = ConvBNLayer(
in_channels=32,
out_channels=64,
kernel_size=3,
stride=2,
act=None,
- name='FPN_d4')
+ name="FPN_d4",
+ )
self.conv_bn_layer_5 = ConvBNLayer(
in_channels=64,
out_channels=64,
kernel_size=3,
stride=1,
- act='relu',
- name='FPN_d5')
+ act="relu",
+ name="FPN_d5",
+ )
self.conv_bn_layer_6 = ConvBNLayer(
in_channels=64,
out_channels=128,
kernel_size=3,
stride=2,
act=None,
- name='FPN_d6')
+ name="FPN_d6",
+ )
self.conv_bn_layer_7 = ConvBNLayer(
in_channels=128,
out_channels=128,
kernel_size=3,
stride=1,
- act='relu',
- name='FPN_d7')
+ act="relu",
+ name="FPN_d7",
+ )
self.conv_bn_layer_8 = ConvBNLayer(
in_channels=128,
out_channels=128,
kernel_size=1,
stride=1,
act=None,
- name='FPN_d8')
+ name="FPN_d8",
+ )
self.conv_h0 = ConvBNLayer(
in_channels=num_inputs[0],
@@ -173,90 +190,104 @@ def __init__(self, in_channels, **kwargs):
kernel_size=1,
stride=1,
act=None,
- name="conv_h{}".format(0))
+ name="conv_h{}".format(0),
+ )
self.conv_h1 = ConvBNLayer(
in_channels=num_inputs[1],
out_channels=num_outputs[1],
kernel_size=1,
stride=1,
act=None,
- name="conv_h{}".format(1))
+ name="conv_h{}".format(1),
+ )
self.conv_h2 = ConvBNLayer(
in_channels=num_inputs[2],
out_channels=num_outputs[2],
kernel_size=1,
stride=1,
act=None,
- name="conv_h{}".format(2))
+ name="conv_h{}".format(2),
+ )
self.conv_h3 = ConvBNLayer(
in_channels=num_inputs[3],
out_channels=num_outputs[3],
kernel_size=1,
stride=1,
act=None,
- name="conv_h{}".format(3))
+ name="conv_h{}".format(3),
+ )
self.conv_h4 = ConvBNLayer(
in_channels=num_inputs[4],
out_channels=num_outputs[4],
kernel_size=1,
stride=1,
act=None,
- name="conv_h{}".format(4))
+ name="conv_h{}".format(4),
+ )
self.dconv0 = DeConvBNLayer(
in_channels=num_outputs[0],
out_channels=num_outputs[0 + 1],
- name="dconv_{}".format(0))
+ name="dconv_{}".format(0),
+ )
self.dconv1 = DeConvBNLayer(
in_channels=num_outputs[1],
out_channels=num_outputs[1 + 1],
act=None,
- name="dconv_{}".format(1))
+ name="dconv_{}".format(1),
+ )
self.dconv2 = DeConvBNLayer(
in_channels=num_outputs[2],
out_channels=num_outputs[2 + 1],
act=None,
- name="dconv_{}".format(2))
+ name="dconv_{}".format(2),
+ )
self.dconv3 = DeConvBNLayer(
in_channels=num_outputs[3],
out_channels=num_outputs[3 + 1],
act=None,
- name="dconv_{}".format(3))
+ name="dconv_{}".format(3),
+ )
self.conv_g1 = ConvBNLayer(
in_channels=num_outputs[1],
out_channels=num_outputs[1],
kernel_size=3,
stride=1,
- act='relu',
- name="conv_g{}".format(1))
+ act="relu",
+ name="conv_g{}".format(1),
+ )
self.conv_g2 = ConvBNLayer(
in_channels=num_outputs[2],
out_channels=num_outputs[2],
kernel_size=3,
stride=1,
- act='relu',
- name="conv_g{}".format(2))
+ act="relu",
+ name="conv_g{}".format(2),
+ )
self.conv_g3 = ConvBNLayer(
in_channels=num_outputs[3],
out_channels=num_outputs[3],
kernel_size=3,
stride=1,
- act='relu',
- name="conv_g{}".format(3))
+ act="relu",
+ name="conv_g{}".format(3),
+ )
self.conv_g4 = ConvBNLayer(
in_channels=num_outputs[4],
out_channels=num_outputs[4],
kernel_size=3,
stride=1,
- act='relu',
- name="conv_g{}".format(4))
+ act="relu",
+ name="conv_g{}".format(4),
+ )
self.convf = ConvBNLayer(
in_channels=num_outputs[4],
out_channels=num_outputs[4],
kernel_size=1,
stride=1,
act=None,
- name="conv_f{}".format(4))
+ name="conv_f{}".format(4),
+ )
def forward(self, x):
c0, c1, c2, c3, c4, c5, c6 = x
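DeConvBNLayer's defaults (`kernel_size=4, stride=2, padding=1`) double the spatial size, since for a transposed convolution `out = (in - 1) * stride - 2 * padding + kernel_size`. A minimal check:

```python
# The deconv defaults above double H and W: (16-1)*2 - 2*1 + 4 = 32.
import paddle
import paddle.nn as nn

deconv = nn.Conv2DTranspose(64, 64, kernel_size=4, stride=2, padding=1)
x = paddle.randn([1, 64, 16, 16])
assert tuple(deconv(x).shape[2:]) == (32, 32)
```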
diff --git a/ppocr/modeling/necks/pren_fpn.py b/ppocr/modeling/necks/pren_fpn.py
index afbdcea83d..29c98e9817 100644
--- a/ppocr/modeling/necks/pren_fpn.py
+++ b/ppocr/modeling/necks/pren_fpn.py
@@ -46,14 +46,26 @@ def _build_aggs(self):
for i in range(self.n_r):
aggs.append(
self.add_sublayer(
- '{}'.format(i),
+ "{}".format(i),
nn.Sequential(
- ('conv1', nn.Conv2D(
- self.d_in, self.d_middle, 3, 2, 1, bias_attr=False)
- ), ('bn1', nn.BatchNorm(self.d_middle)),
- ('act', self.act), ('conv2', nn.Conv2D(
- self.d_middle, self.d_out, 3, 2, 1, bias_attr=False
- )), ('bn2', nn.BatchNorm(self.d_out)))))
+ (
+ "conv1",
+ nn.Conv2D(
+ self.d_in, self.d_middle, 3, 2, 1, bias_attr=False
+ ),
+ ),
+ ("bn1", nn.BatchNorm(self.d_middle)),
+ ("act", self.act),
+ (
+ "conv2",
+ nn.Conv2D(
+ self.d_middle, self.d_out, 3, 2, 1, bias_attr=False
+ ),
+ ),
+ ("bn2", nn.BatchNorm(self.d_out)),
+ ),
+ )
+ )
return aggs
def forward(self, x):
@@ -80,19 +92,20 @@ def __init__(self, n_r, d_in, d_middle=None, d_out=None):
self.act = nn.Swish()
self.conv_n = nn.Sequential(
- ('conv1', nn.Conv2D(
- d_in, d_in, 3, 1, 1,
- bias_attr=False)), ('bn1', nn.BatchNorm(d_in)),
- ('act1', self.act), ('conv2', nn.Conv2D(
- d_in, n_r, 1, bias_attr=False)), ('bn2', nn.BatchNorm(n_r)),
- ('act2', nn.Sigmoid()))
+ ("conv1", nn.Conv2D(d_in, d_in, 3, 1, 1, bias_attr=False)),
+ ("bn1", nn.BatchNorm(d_in)),
+ ("act1", self.act),
+ ("conv2", nn.Conv2D(d_in, n_r, 1, bias_attr=False)),
+ ("bn2", nn.BatchNorm(n_r)),
+ ("act2", nn.Sigmoid()),
+ )
self.conv_d = nn.Sequential(
- ('conv1', nn.Conv2D(
- d_in, d_middle, 3, 1, 1,
- bias_attr=False)), ('bn1', nn.BatchNorm(d_middle)),
- ('act1', self.act), ('conv2', nn.Conv2D(
- d_middle, d_out, 1,
- bias_attr=False)), ('bn2', nn.BatchNorm(d_out)))
+ ("conv1", nn.Conv2D(d_in, d_middle, 3, 1, 1, bias_attr=False)),
+ ("bn1", nn.BatchNorm(d_middle)),
+ ("act1", self.act),
+ ("conv2", nn.Conv2D(d_middle, d_out, 1, bias_attr=False)),
+ ("bn2", nn.BatchNorm(d_out)),
+ )
def forward(self, x):
b, _, h, w = x.shape
@@ -101,7 +114,8 @@ def forward(self, x):
fmaps = self.conv_d(x)
r = paddle.bmm(
hmaps.reshape((b, self.n_r, h * w)),
- fmaps.reshape((b, self.d_out, h * w)).transpose((0, 2, 1)))
+ fmaps.reshape((b, self.d_out, h * w)).transpose((0, 2, 1)),
+ )
return r
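The `paddle.bmm` at the end of this hunk is a weighted pooling: each of the `n_r` attention maps reduces the `d_out`-channel feature map to a single vector. A shape walk-through with illustrative sizes:

```python
# Shape walk-through of the bmm pooling above (sizes are illustrative).
import paddle

b, n_r, d_out, h, w = 2, 5, 64, 8, 25
hmaps = paddle.randn([b, n_r, h, w])    # attention maps
fmaps = paddle.randn([b, d_out, h, w])  # features

r = paddle.bmm(
    hmaps.reshape((b, n_r, h * w)),                         # (b, n_r, h*w)
    fmaps.reshape((b, d_out, h * w)).transpose((0, 2, 1)),  # (b, h*w, d_out)
)
assert tuple(r.shape) == (b, n_r, d_out)  # one pooled vector per attention map
```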
diff --git a/ppocr/modeling/necks/rf_adaptor.py b/ppocr/modeling/necks/rf_adaptor.py
index 94590127b0..3c30fe3b35 100644
--- a/ppocr/modeling/necks/rf_adaptor.py
+++ b/ppocr/modeling/necks/rf_adaptor.py
@@ -21,12 +21,12 @@
from paddle.nn.initializer import TruncatedNormal, Constant, Normal, KaimingNormal
kaiming_init_ = KaimingNormal()
-zeros_ = Constant(value=0.)
-ones_ = Constant(value=1.)
+zeros_ = Constant(value=0.0)
+ones_ = Constant(value=1.0)
class S2VAdaptor(nn.Layer):
- """ Semantic to Visual adaptation module"""
+ """Semantic to Visual adaptation module"""
def __init__(self, in_channels=512):
super(S2VAdaptor, self).__init__()
@@ -35,7 +35,8 @@ def __init__(self, in_channels=512):
# feature strengthen module, channel attention
self.channel_inter = nn.Linear(
- self.in_channels, self.in_channels, bias_attr=False)
+ self.in_channels, self.in_channels, bias_attr=False
+ )
self.channel_bn = nn.BatchNorm1D(self.in_channels)
self.channel_act = nn.ReLU()
self.apply(self.init_weights)
@@ -53,8 +54,7 @@ def forward(self, semantic):
semantic_source = semantic # batch, channel, height, width
# feature transformation
- semantic = semantic.squeeze(2).transpose(
- [0, 2, 1]) # batch, width, channel
+ semantic = semantic.squeeze(2).transpose([0, 2, 1]) # batch, width, channel
channel_att = self.channel_inter(semantic) # batch, width, channel
channel_att = channel_att.transpose([0, 2, 1]) # batch, channel, width
channel_bn = self.channel_bn(channel_att) # batch, channel, width
@@ -62,13 +62,14 @@ def forward(self, semantic):
# Feature enhancement
channel_output = semantic_source * channel_att.unsqueeze(
- -2) # batch, channel, 1, width
+ -2
+ ) # batch, channel, 1, width
return channel_output
class V2SAdaptor(nn.Layer):
- """ Visual to Semantic adaptation module"""
+ """Visual to Semantic adaptation module"""
def __init__(self, in_channels=512, return_mask=False):
super(V2SAdaptor, self).__init__()
@@ -79,7 +80,8 @@ def __init__(self, in_channels=512, return_mask=False):
# output transformation
self.channel_inter = nn.Linear(
- self.in_channels, self.in_channels, bias_attr=False)
+ self.in_channels, self.in_channels, bias_attr=False
+ )
self.channel_bn = nn.BatchNorm1D(self.in_channels)
self.channel_act = nn.ReLU()
@@ -115,9 +117,15 @@ def __init__(self, in_channels=512, use_v2s=True, use_s2v=True, **kwargs):
def forward(self, x):
visual_feature, rcg_feature = x
if visual_feature is not None:
- batch, source_channels, v_source_height, v_source_width = visual_feature.shape
+ (
+ batch,
+ source_channels,
+ v_source_height,
+ v_source_width,
+ ) = visual_feature.shape
visual_feature = visual_feature.reshape(
- [batch, source_channels, 1, v_source_height * v_source_width])
+ [batch, source_channels, 1, v_source_height * v_source_width]
+ )
if self.neck_v2s is not None:
v_rcg_feature = rcg_feature * self.neck_v2s(visual_feature)
@@ -131,7 +139,8 @@ def forward(self, x):
if v_rcg_feature is not None:
batch, source_channels, source_height, source_width = v_rcg_feature.shape
v_rcg_feature = v_rcg_feature.reshape(
- [batch, source_channels, 1, source_height * source_width])
+ [batch, source_channels, 1, source_height * source_width]
+ )
v_rcg_feature = v_rcg_feature.squeeze(2).transpose([0, 2, 1])
return v_visual_feature, v_rcg_feature
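S2VAdaptor's channel attention is mostly shape gymnastics: squeeze the unit height, move channels last for the Linear, move them back for BatchNorm1D, then broadcast the attention over the source. A sketch with the module's default `in_channels=512`:

```python
# Sketch of the S2VAdaptor shape flow reformatted above (in_channels=512 default).
import paddle
import paddle.nn as nn

semantic = paddle.randn([2, 512, 1, 40])           # batch, channel, 1, width
x = semantic.squeeze(2).transpose([0, 2, 1])       # batch, width, channel
att = nn.Linear(512, 512, bias_attr=False)(x)      # per-position channel mixing
att = att.transpose([0, 2, 1])                     # batch, channel, width
att = nn.ReLU()(nn.BatchNorm1D(512)(att))          # normalize over channels
out = semantic * att.unsqueeze(-2)                 # broadcast to batch, C, 1, W
assert out.shape == semantic.shape
```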
diff --git a/ppocr/modeling/necks/rnn.py b/ppocr/modeling/necks/rnn.py
index a195a6217a..fa7b8a1f1a 100644
--- a/ppocr/modeling/necks/rnn.py
+++ b/ppocr/modeling/necks/rnn.py
@@ -20,7 +20,13 @@
from paddle import nn
from ppocr.modeling.heads.rec_ctc_head import get_para_bias_attr
-from ppocr.modeling.backbones.rec_svtrnet import Block, ConvBNLayer, trunc_normal_, zeros_, ones_
+from ppocr.modeling.backbones.rec_svtrnet import (
+ Block,
+ ConvBNLayer,
+ trunc_normal_,
+ zeros_,
+ ones_,
+)
class Im2Seq(nn.Layer):
@@ -41,7 +47,8 @@ def __init__(self, in_channels, hidden_size):
super(EncoderWithRNN, self).__init__()
self.out_channels = hidden_size * 2
self.lstm = nn.LSTM(
- in_channels, hidden_size, direction='bidirectional', num_layers=2)
+ in_channels, hidden_size, direction="bidirectional", num_layers=2
+ )
def forward(self, x):
x, _ = self.lstm(x)
@@ -49,15 +56,17 @@ def forward(self, x):
class BidirectionalLSTM(nn.Layer):
- def __init__(self,
- input_size,
- hidden_size,
- output_size=None,
- num_layers=1,
- dropout=0,
- direction=False,
- time_major=False,
- with_linear=False):
+ def __init__(
+ self,
+ input_size,
+ hidden_size,
+ output_size=None,
+ num_layers=1,
+ dropout=0,
+ direction=False,
+ time_major=False,
+ with_linear=False,
+ ):
super(BidirectionalLSTM, self).__init__()
self.with_linear = with_linear
self.rnn = nn.LSTM(
@@ -66,7 +75,8 @@ def __init__(self,
num_layers=num_layers,
dropout=dropout,
direction=direction,
- time_major=time_major)
+ time_major=time_major,
+ )
# text recognition with the specified LSTM-with-linear structure
if self.with_linear:
@@ -83,23 +93,24 @@ def forward(self, input_feature):
class EncoderWithCascadeRNN(nn.Layer):
- def __init__(self,
- in_channels,
- hidden_size,
- out_channels,
- num_layers=2,
- with_linear=False):
+ def __init__(
+ self, in_channels, hidden_size, out_channels, num_layers=2, with_linear=False
+ ):
super(EncoderWithCascadeRNN, self).__init__()
self.out_channels = out_channels[-1]
- self.encoder = nn.LayerList([
- BidirectionalLSTM(
- in_channels if i == 0 else out_channels[i - 1],
- hidden_size,
- output_size=out_channels[i],
- num_layers=1,
- direction='bidirectional',
- with_linear=with_linear) for i in range(num_layers)
- ])
+ self.encoder = nn.LayerList(
+ [
+ BidirectionalLSTM(
+ in_channels if i == 0 else out_channels[i - 1],
+ hidden_size,
+ output_size=out_channels[i],
+ num_layers=1,
+ direction="bidirectional",
+ with_linear=with_linear,
+ )
+ for i in range(num_layers)
+ ]
+ )
def forward(self, x):
for i, l in enumerate(self.encoder):
@@ -111,14 +122,14 @@ class EncoderWithFC(nn.Layer):
def __init__(self, in_channels, hidden_size):
super(EncoderWithFC, self).__init__()
self.out_channels = hidden_size
- weight_attr, bias_attr = get_para_bias_attr(
- l2_decay=0.00001, k=in_channels)
+ weight_attr, bias_attr = get_para_bias_attr(l2_decay=0.00001, k=in_channels)
self.fc = nn.Linear(
in_channels,
hidden_size,
weight_attr=weight_attr,
bias_attr=bias_attr,
- name='reduce_encoder_fea')
+ name="reduce_encoder_fea",
+ )
def forward(self, x):
x = self.fc(x)
@@ -127,20 +138,21 @@ def forward(self, x):
class EncoderWithSVTR(nn.Layer):
def __init__(
- self,
- in_channels,
- dims=64, # XS
- depth=2,
- hidden_dims=120,
- use_guide=False,
- num_heads=8,
- qkv_bias=True,
- mlp_ratio=2.0,
- drop_rate=0.1,
- attn_drop_rate=0.1,
- drop_path=0.,
- kernel_size=[3, 3],
- qk_scale=None):
+ self,
+ in_channels,
+ dims=64, # XS
+ depth=2,
+ hidden_dims=120,
+ use_guide=False,
+ num_heads=8,
+ qkv_bias=True,
+ mlp_ratio=2.0,
+ drop_rate=0.1,
+ attn_drop_rate=0.1,
+ drop_path=0.0,
+ kernel_size=[3, 3],
+ qk_scale=None,
+ ):
super(EncoderWithSVTR, self).__init__()
self.depth = depth
self.use_guide = use_guide
@@ -149,40 +161,45 @@ def __init__(
in_channels // 8,
kernel_size=kernel_size,
padding=[kernel_size[0] // 2, kernel_size[1] // 2],
- act=nn.Swish)
+ act=nn.Swish,
+ )
self.conv2 = ConvBNLayer(
- in_channels // 8, hidden_dims, kernel_size=1, act=nn.Swish)
+ in_channels // 8, hidden_dims, kernel_size=1, act=nn.Swish
+ )
- self.svtr_block = nn.LayerList([
- Block(
- dim=hidden_dims,
- num_heads=num_heads,
- mixer='Global',
- HW=None,
- mlp_ratio=mlp_ratio,
- qkv_bias=qkv_bias,
- qk_scale=qk_scale,
- drop=drop_rate,
- act_layer=nn.Swish,
- attn_drop=attn_drop_rate,
- drop_path=drop_path,
- norm_layer='nn.LayerNorm',
- epsilon=1e-05,
- prenorm=False) for i in range(depth)
- ])
+ self.svtr_block = nn.LayerList(
+ [
+ Block(
+ dim=hidden_dims,
+ num_heads=num_heads,
+ mixer="Global",
+ HW=None,
+ mlp_ratio=mlp_ratio,
+ qkv_bias=qkv_bias,
+ qk_scale=qk_scale,
+ drop=drop_rate,
+ act_layer=nn.Swish,
+ attn_drop=attn_drop_rate,
+ drop_path=drop_path,
+ norm_layer="nn.LayerNorm",
+ epsilon=1e-05,
+ prenorm=False,
+ )
+ for i in range(depth)
+ ]
+ )
self.norm = nn.LayerNorm(hidden_dims, epsilon=1e-6)
- self.conv3 = ConvBNLayer(
- hidden_dims, in_channels, kernel_size=1, act=nn.Swish)
+ self.conv3 = ConvBNLayer(hidden_dims, in_channels, kernel_size=1, act=nn.Swish)
# last conv-nxn; its input is the concat of the input tensor and the conv3 output
self.conv4 = ConvBNLayer(
2 * in_channels,
in_channels // 8,
kernel_size=kernel_size,
padding=[kernel_size[0] // 2, kernel_size[1] // 2],
- act=nn.Swish)
+ act=nn.Swish,
+ )
- self.conv1x1 = ConvBNLayer(
- in_channels // 8, dims, kernel_size=1, act=nn.Swish)
+ self.conv1x1 = ConvBNLayer(in_channels // 8, dims, kernel_size=1, act=nn.Swish)
self.out_channels = dims
self.apply(self._init_weights)
@@ -227,32 +244,36 @@ def __init__(self, in_channels, encoder_type, hidden_size=48, **kwargs):
self.encoder_reshape = Im2Seq(in_channels)
self.out_channels = self.encoder_reshape.out_channels
self.encoder_type = encoder_type
- if encoder_type == 'reshape':
+ if encoder_type == "reshape":
self.only_reshape = True
else:
support_encoder_dict = {
- 'reshape': Im2Seq,
- 'fc': EncoderWithFC,
- 'rnn': EncoderWithRNN,
- 'svtr': EncoderWithSVTR,
- 'cascadernn': EncoderWithCascadeRNN
+ "reshape": Im2Seq,
+ "fc": EncoderWithFC,
+ "rnn": EncoderWithRNN,
+ "svtr": EncoderWithSVTR,
+ "cascadernn": EncoderWithCascadeRNN,
}
- assert encoder_type in support_encoder_dict, '{} must in {}'.format(
- encoder_type, support_encoder_dict.keys())
+ assert encoder_type in support_encoder_dict, "{} must be in {}".format(
+ encoder_type, support_encoder_dict.keys()
+ )
if encoder_type == "svtr":
self.encoder = support_encoder_dict[encoder_type](
- self.encoder_reshape.out_channels, **kwargs)
- elif encoder_type == 'cascadernn':
+ self.encoder_reshape.out_channels, **kwargs
+ )
+ elif encoder_type == "cascadernn":
self.encoder = support_encoder_dict[encoder_type](
- self.encoder_reshape.out_channels, hidden_size, **kwargs)
+ self.encoder_reshape.out_channels, hidden_size, **kwargs
+ )
else:
self.encoder = support_encoder_dict[encoder_type](
- self.encoder_reshape.out_channels, hidden_size)
+ self.encoder_reshape.out_channels, hidden_size
+ )
self.out_channels = self.encoder.out_channels
self.only_reshape = False
def forward(self, x):
- if self.encoder_type != 'svtr':
+ if self.encoder_type != "svtr":
x = self.encoder_reshape(x)
if not self.only_reshape:
x = self.encoder(x)
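SequenceEncoder resolves `encoder_type` through a plain dict of classes, so configs stay declarative. A self-contained sketch of that dispatch pattern (the encoder classes here are stand-ins, not the real ones):

```python
# Minimal, self-contained sketch of the dict-based dispatch used above.
class EncoderA:                     # stand-in for e.g. EncoderWithFC
    def __init__(self, in_channels, hidden_size):
        self.out_channels = hidden_size

class EncoderB:                     # stand-in for e.g. EncoderWithRNN
    def __init__(self, in_channels, hidden_size):
        self.out_channels = hidden_size * 2   # bidirectional doubles width

SUPPORT = {"a": EncoderA, "b": EncoderB}

def build_encoder(name, **kwargs):
    assert name in SUPPORT, "{} must be in {}".format(name, list(SUPPORT))
    return SUPPORT[name](**kwargs)

assert build_encoder("b", in_channels=480, hidden_size=48).out_channels == 96
```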
diff --git a/ppocr/modeling/necks/sast_fpn.py b/ppocr/modeling/necks/sast_fpn.py
index d106179708..4804806d6b 100644
--- a/ppocr/modeling/necks/sast_fpn.py
+++ b/ppocr/modeling/necks/sast_fpn.py
@@ -23,15 +23,17 @@
class ConvBNLayer(nn.Layer):
- def __init__(self,
- in_channels,
- out_channels,
- kernel_size,
- stride,
- groups=1,
- if_act=True,
- act=None,
- name=None):
+ def __init__(
+ self,
+ in_channels,
+ out_channels,
+ kernel_size,
+ stride,
+ groups=1,
+ if_act=True,
+ act=None,
+ name=None,
+ ):
super(ConvBNLayer, self).__init__()
self.if_act = if_act
self.act = act
@@ -42,16 +44,18 @@ def __init__(self,
stride=stride,
padding=(kernel_size - 1) // 2,
groups=groups,
- weight_attr=ParamAttr(name=name + '_weights'),
- bias_attr=False)
-
+ weight_attr=ParamAttr(name=name + "_weights"),
+ bias_attr=False,
+ )
+
self.bn = nn.BatchNorm(
num_channels=out_channels,
act=act,
param_attr=ParamAttr(name="bn_" + name + "_scale"),
bias_attr=ParamAttr(name="bn_" + name + "_offset"),
moving_mean_name="bn_" + name + "_mean",
- moving_variance_name="bn_" + name + "_variance")
+ moving_variance_name="bn_" + name + "_variance",
+ )
def forward(self, x):
x = self.conv(x)
@@ -60,15 +64,17 @@ def forward(self, x):
class DeConvBNLayer(nn.Layer):
- def __init__(self,
- in_channels,
- out_channels,
- kernel_size,
- stride,
- groups=1,
- if_act=True,
- act=None,
- name=None):
+ def __init__(
+ self,
+ in_channels,
+ out_channels,
+ kernel_size,
+ stride,
+ groups=1,
+ if_act=True,
+ act=None,
+ name=None,
+ ):
super(DeConvBNLayer, self).__init__()
self.if_act = if_act
self.act = act
@@ -79,15 +85,17 @@ def __init__(self,
stride=stride,
padding=(kernel_size - 1) // 2,
groups=groups,
- weight_attr=ParamAttr(name=name + '_weights'),
- bias_attr=False)
+ weight_attr=ParamAttr(name=name + "_weights"),
+ bias_attr=False,
+ )
self.bn = nn.BatchNorm(
num_channels=out_channels,
act=act,
param_attr=ParamAttr(name="bn_" + name + "_scale"),
bias_attr=ParamAttr(name="bn_" + name + "_offset"),
moving_mean_name="bn_" + name + "_mean",
- moving_variance_name="bn_" + name + "_variance")
+ moving_variance_name="bn_" + name + "_variance",
+ )
def forward(self, x):
x = self.deconv(x)
@@ -100,31 +108,64 @@ def __init__(self, in_channels):
super(FPN_Up_Fusion, self).__init__()
in_channels = in_channels[::-1]
out_channels = [256, 256, 192, 192, 128]
-
- self.h0_conv = ConvBNLayer(in_channels[0], out_channels[0], 1, 1, act=None, name='fpn_up_h0')
- self.h1_conv = ConvBNLayer(in_channels[1], out_channels[1], 1, 1, act=None, name='fpn_up_h1')
- self.h2_conv = ConvBNLayer(in_channels[2], out_channels[2], 1, 1, act=None, name='fpn_up_h2')
- self.h3_conv = ConvBNLayer(in_channels[3], out_channels[3], 1, 1, act=None, name='fpn_up_h3')
- self.h4_conv = ConvBNLayer(in_channels[4], out_channels[4], 1, 1, act=None, name='fpn_up_h4')
- self.g0_conv = DeConvBNLayer(out_channels[0], out_channels[1], 4, 2, act=None, name='fpn_up_g0')
+ self.h0_conv = ConvBNLayer(
+ in_channels[0], out_channels[0], 1, 1, act=None, name="fpn_up_h0"
+ )
+ self.h1_conv = ConvBNLayer(
+ in_channels[1], out_channels[1], 1, 1, act=None, name="fpn_up_h1"
+ )
+ self.h2_conv = ConvBNLayer(
+ in_channels[2], out_channels[2], 1, 1, act=None, name="fpn_up_h2"
+ )
+ self.h3_conv = ConvBNLayer(
+ in_channels[3], out_channels[3], 1, 1, act=None, name="fpn_up_h3"
+ )
+ self.h4_conv = ConvBNLayer(
+ in_channels[4], out_channels[4], 1, 1, act=None, name="fpn_up_h4"
+ )
+
+ self.g0_conv = DeConvBNLayer(
+ out_channels[0], out_channels[1], 4, 2, act=None, name="fpn_up_g0"
+ )
self.g1_conv = nn.Sequential(
- ConvBNLayer(out_channels[1], out_channels[1], 3, 1, act='relu', name='fpn_up_g1_1'),
- DeConvBNLayer(out_channels[1], out_channels[2], 4, 2, act=None, name='fpn_up_g1_2')
+ ConvBNLayer(
+ out_channels[1], out_channels[1], 3, 1, act="relu", name="fpn_up_g1_1"
+ ),
+ DeConvBNLayer(
+ out_channels[1], out_channels[2], 4, 2, act=None, name="fpn_up_g1_2"
+ ),
)
self.g2_conv = nn.Sequential(
- ConvBNLayer(out_channels[2], out_channels[2], 3, 1, act='relu', name='fpn_up_g2_1'),
- DeConvBNLayer(out_channels[2], out_channels[3], 4, 2, act=None, name='fpn_up_g2_2')
+ ConvBNLayer(
+ out_channels[2], out_channels[2], 3, 1, act="relu", name="fpn_up_g2_1"
+ ),
+ DeConvBNLayer(
+ out_channels[2], out_channels[3], 4, 2, act=None, name="fpn_up_g2_2"
+ ),
)
self.g3_conv = nn.Sequential(
- ConvBNLayer(out_channels[3], out_channels[3], 3, 1, act='relu', name='fpn_up_g3_1'),
- DeConvBNLayer(out_channels[3], out_channels[4], 4, 2, act=None, name='fpn_up_g3_2')
+ ConvBNLayer(
+ out_channels[3], out_channels[3], 3, 1, act="relu", name="fpn_up_g3_1"
+ ),
+ DeConvBNLayer(
+ out_channels[3], out_channels[4], 4, 2, act=None, name="fpn_up_g3_2"
+ ),
)
self.g4_conv = nn.Sequential(
- ConvBNLayer(out_channels[4], out_channels[4], 3, 1, act='relu', name='fpn_up_fusion_1'),
- ConvBNLayer(out_channels[4], out_channels[4], 1, 1, act=None, name='fpn_up_fusion_2')
+ ConvBNLayer(
+ out_channels[4],
+ out_channels[4],
+ 3,
+ 1,
+ act="relu",
+ name="fpn_up_fusion_1",
+ ),
+ ConvBNLayer(
+ out_channels[4], out_channels[4], 1, 1, act=None, name="fpn_up_fusion_2"
+ ),
)
def _add_relu(self, x1, x2):
@@ -155,20 +196,46 @@ def __init__(self, in_channels):
super(FPN_Down_Fusion, self).__init__()
out_channels = [32, 64, 128]
- self.h0_conv = ConvBNLayer(in_channels[0], out_channels[0], 3, 1, act=None, name='fpn_down_h0')
- self.h1_conv = ConvBNLayer(in_channels[1], out_channels[1], 3, 1, act=None, name='fpn_down_h1')
- self.h2_conv = ConvBNLayer(in_channels[2], out_channels[2], 3, 1, act=None, name='fpn_down_h2')
+ self.h0_conv = ConvBNLayer(
+ in_channels[0], out_channels[0], 3, 1, act=None, name="fpn_down_h0"
+ )
+ self.h1_conv = ConvBNLayer(
+ in_channels[1], out_channels[1], 3, 1, act=None, name="fpn_down_h1"
+ )
+ self.h2_conv = ConvBNLayer(
+ in_channels[2], out_channels[2], 3, 1, act=None, name="fpn_down_h2"
+ )
- self.g0_conv = ConvBNLayer(out_channels[0], out_channels[1], 3, 2, act=None, name='fpn_down_g0')
+ self.g0_conv = ConvBNLayer(
+ out_channels[0], out_channels[1], 3, 2, act=None, name="fpn_down_g0"
+ )
self.g1_conv = nn.Sequential(
- ConvBNLayer(out_channels[1], out_channels[1], 3, 1, act='relu', name='fpn_down_g1_1'),
- ConvBNLayer(out_channels[1], out_channels[2], 3, 2, act=None, name='fpn_down_g1_2')
+ ConvBNLayer(
+ out_channels[1], out_channels[1], 3, 1, act="relu", name="fpn_down_g1_1"
+ ),
+ ConvBNLayer(
+ out_channels[1], out_channels[2], 3, 2, act=None, name="fpn_down_g1_2"
+ ),
)
self.g2_conv = nn.Sequential(
- ConvBNLayer(out_channels[2], out_channels[2], 3, 1, act='relu', name='fpn_down_fusion_1'),
- ConvBNLayer(out_channels[2], out_channels[2], 1, 1, act=None, name='fpn_down_fusion_2')
+ ConvBNLayer(
+ out_channels[2],
+ out_channels[2],
+ 3,
+ 1,
+ act="relu",
+ name="fpn_down_fusion_1",
+ ),
+ ConvBNLayer(
+ out_channels[2],
+ out_channels[2],
+ 1,
+ 1,
+ act=None,
+ name="fpn_down_fusion_2",
+ ),
)
def forward(self, x):
@@ -189,36 +256,51 @@ def forward(self, x):
class Cross_Attention(nn.Layer):
def __init__(self, in_channels):
super(Cross_Attention, self).__init__()
- self.theta_conv = ConvBNLayer(in_channels, in_channels, 1, 1, act='relu', name='f_theta')
- self.phi_conv = ConvBNLayer(in_channels, in_channels, 1, 1, act='relu', name='f_phi')
- self.g_conv = ConvBNLayer(in_channels, in_channels, 1, 1, act='relu', name='f_g')
+ self.theta_conv = ConvBNLayer(
+ in_channels, in_channels, 1, 1, act="relu", name="f_theta"
+ )
+ self.phi_conv = ConvBNLayer(
+ in_channels, in_channels, 1, 1, act="relu", name="f_phi"
+ )
+ self.g_conv = ConvBNLayer(
+ in_channels, in_channels, 1, 1, act="relu", name="f_g"
+ )
- self.fh_weight_conv = ConvBNLayer(in_channels, in_channels, 1, 1, act=None, name='fh_weight')
- self.fh_sc_conv = ConvBNLayer(in_channels, in_channels, 1, 1, act=None, name='fh_sc')
+ self.fh_weight_conv = ConvBNLayer(
+ in_channels, in_channels, 1, 1, act=None, name="fh_weight"
+ )
+ self.fh_sc_conv = ConvBNLayer(
+ in_channels, in_channels, 1, 1, act=None, name="fh_sc"
+ )
- self.fv_weight_conv = ConvBNLayer(in_channels, in_channels, 1, 1, act=None, name='fv_weight')
- self.fv_sc_conv = ConvBNLayer(in_channels, in_channels, 1, 1, act=None, name='fv_sc')
+ self.fv_weight_conv = ConvBNLayer(
+ in_channels, in_channels, 1, 1, act=None, name="fv_weight"
+ )
+ self.fv_sc_conv = ConvBNLayer(
+ in_channels, in_channels, 1, 1, act=None, name="fv_sc"
+ )
- self.f_attn_conv = ConvBNLayer(in_channels * 2, in_channels, 1, 1, act='relu', name='f_attn')
+ self.f_attn_conv = ConvBNLayer(
+ in_channels * 2, in_channels, 1, 1, act="relu", name="f_attn"
+ )
def _cal_fweight(self, f, shape):
f_theta, f_phi, f_g = f
- #flatten
+ # flatten
f_theta = paddle.transpose(f_theta, [0, 2, 3, 1])
f_theta = paddle.reshape(f_theta, [shape[0] * shape[1], shape[2], 128])
f_phi = paddle.transpose(f_phi, [0, 2, 3, 1])
f_phi = paddle.reshape(f_phi, [shape[0] * shape[1], shape[2], 128])
f_g = paddle.transpose(f_g, [0, 2, 3, 1])
f_g = paddle.reshape(f_g, [shape[0] * shape[1], shape[2], 128])
- #correlation
+ # correlation
f_attn = paddle.matmul(f_theta, paddle.transpose(f_phi, [0, 2, 1]))
- #scale
+ # scale
f_attn = f_attn / (128**0.5)
f_attn = F.softmax(f_attn)
- #weighted sum
+ # weighted sum
f_weight = paddle.matmul(f_attn, f_g)
- f_weight = paddle.reshape(
- f_weight, [shape[0], shape[1], shape[2], 128])
+ f_weight = paddle.reshape(f_weight, [shape[0], shape[1], shape[2], 128])
return f_weight
def forward(self, f_common):
@@ -230,11 +312,12 @@ def forward(self, f_common):
f_g = self.g_conv(f_common)
######## horizon ########
- fh_weight = self._cal_fweight([f_theta, f_phi, f_g],
- [f_shape[0], f_shape[2], f_shape[3]])
+ fh_weight = self._cal_fweight(
+ [f_theta, f_phi, f_g], [f_shape[0], f_shape[2], f_shape[3]]
+ )
fh_weight = paddle.transpose(fh_weight, [0, 3, 1, 2])
fh_weight = self.fh_weight_conv(fh_weight)
- #short cut
+ # short cut
fh_sc = self.fh_sc_conv(f_common)
f_h = F.relu(fh_weight + fh_sc)
@@ -242,11 +325,12 @@ def forward(self, f_common):
fv_theta = paddle.transpose(f_theta, [0, 1, 3, 2])
fv_phi = paddle.transpose(f_phi, [0, 1, 3, 2])
fv_g = paddle.transpose(f_g, [0, 1, 3, 2])
- fv_weight = self._cal_fweight([fv_theta, fv_phi, fv_g],
- [f_shape[0], f_shape[3], f_shape[2]])
+ fv_weight = self._cal_fweight(
+ [fv_theta, fv_phi, fv_g], [f_shape[0], f_shape[3], f_shape[2]]
+ )
fv_weight = paddle.transpose(fv_weight, [0, 3, 2, 1])
fv_weight = self.fv_weight_conv(fv_weight)
- #short cut
+ # short cut
fv_sc = self.fv_sc_conv(f_common)
f_v = F.relu(fv_weight + fv_sc)
@@ -267,13 +351,13 @@ def __init__(self, in_channels, with_cab=False, **kwargs):
self.cross_attention = Cross_Attention(self.out_channels)
def forward(self, x):
- #down fpn
+ # down fpn
f_down = self.FPN_Down_Fusion(x)
- #up fpn
+ # up fpn
f_up = self.FPN_Up_Fusion(x)
- #fusion
+ # fusion
f_common = paddle.add(x=f_down, y=f_up)
f_common = F.relu(f_common)
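Cross_Attention's `_cal_fweight` is a scaled dot-product attention over the 128-channel rows (or columns) of the feature map, which is why the scale is hard-coded as `128**0.5`. A sketch with illustrative sizes:

```python
# Sketch of the scaled dot-product attention in _cal_fweight above.
import paddle
import paddle.nn.functional as F

bh, w, c = 4, 25, 128                # (batch*height, width, channels); c matches
f_theta = paddle.randn([bh, w, c])   # the hard-coded 128**0.5 scale
f_phi = paddle.randn([bh, w, c])
f_g = paddle.randn([bh, w, c])

attn = paddle.matmul(f_theta, paddle.transpose(f_phi, [0, 2, 1]))  # (bh, w, w)
attn = F.softmax(attn / (c**0.5))    # scale, then row-wise softmax
f_weight = paddle.matmul(attn, f_g)  # weighted sum -> (bh, w, c)
assert tuple(f_weight.shape) == (bh, w, c)
```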
diff --git a/ppocr/modeling/necks/table_fpn.py b/ppocr/modeling/necks/table_fpn.py
index 734f15af65..d2739652bc 100644
--- a/ppocr/modeling/necks/table_fpn.py
+++ b/ppocr/modeling/necks/table_fpn.py
@@ -32,60 +32,70 @@ def __init__(self, in_channels, out_channels, **kwargs):
out_channels=self.out_channels,
kernel_size=1,
weight_attr=ParamAttr(initializer=weight_attr),
- bias_attr=False)
+ bias_attr=False,
+ )
self.in3_conv = nn.Conv2D(
in_channels=in_channels[1],
out_channels=self.out_channels,
kernel_size=1,
- stride = 1,
+ stride=1,
weight_attr=ParamAttr(initializer=weight_attr),
- bias_attr=False)
+ bias_attr=False,
+ )
self.in4_conv = nn.Conv2D(
in_channels=in_channels[2],
out_channels=self.out_channels,
kernel_size=1,
weight_attr=ParamAttr(initializer=weight_attr),
- bias_attr=False)
+ bias_attr=False,
+ )
self.in5_conv = nn.Conv2D(
in_channels=in_channels[3],
out_channels=self.out_channels,
kernel_size=1,
weight_attr=ParamAttr(initializer=weight_attr),
- bias_attr=False)
+ bias_attr=False,
+ )
self.p5_conv = nn.Conv2D(
in_channels=self.out_channels,
out_channels=self.out_channels // 4,
kernel_size=3,
padding=1,
weight_attr=ParamAttr(initializer=weight_attr),
- bias_attr=False)
+ bias_attr=False,
+ )
self.p4_conv = nn.Conv2D(
in_channels=self.out_channels,
out_channels=self.out_channels // 4,
kernel_size=3,
padding=1,
weight_attr=ParamAttr(initializer=weight_attr),
- bias_attr=False)
+ bias_attr=False,
+ )
self.p3_conv = nn.Conv2D(
in_channels=self.out_channels,
out_channels=self.out_channels // 4,
kernel_size=3,
padding=1,
weight_attr=ParamAttr(initializer=weight_attr),
- bias_attr=False)
+ bias_attr=False,
+ )
self.p2_conv = nn.Conv2D(
in_channels=self.out_channels,
out_channels=self.out_channels // 4,
kernel_size=3,
padding=1,
weight_attr=ParamAttr(initializer=weight_attr),
- bias_attr=False)
+ bias_attr=False,
+ )
self.fuse_conv = nn.Conv2D(
in_channels=self.out_channels * 4,
out_channels=512,
kernel_size=3,
padding=1,
- weight_attr=ParamAttr(initializer=weight_attr), bias_attr=False)
+ weight_attr=ParamAttr(initializer=weight_attr),
+ bias_attr=False,
+ )
def forward(self, x):
c2, c3, c4, c5 = x
@@ -96,11 +106,14 @@ def forward(self, x):
in2 = self.in2_conv(c2)
out4 = in4 + F.upsample(
- in5, size=in4.shape[2:4], mode="nearest", align_mode=1) # 1/16
+ in5, size=in4.shape[2:4], mode="nearest", align_mode=1
+ ) # 1/16
out3 = in3 + F.upsample(
- out4, size=in3.shape[2:4], mode="nearest", align_mode=1) # 1/8
+ out4, size=in3.shape[2:4], mode="nearest", align_mode=1
+ ) # 1/8
out2 = in2 + F.upsample(
- out3, size=in2.shape[2:4], mode="nearest", align_mode=1) # 1/4
+ out3, size=in2.shape[2:4], mode="nearest", align_mode=1
+ ) # 1/4
p4 = F.upsample(out4, size=in5.shape[2:4], mode="nearest", align_mode=1)
p3 = F.upsample(out3, size=in5.shape[2:4], mode="nearest", align_mode=1)
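The fusion above passes `size=` rather than `scale_factor=` to `F.upsample`, so the element-wise add stays shape-correct even when the input resolution is not an exact multiple of 32. A minimal check under assumed sizes:

```python
# Size-based nearest upsample-and-add, as in the TableFPN forward above.
import paddle
import paddle.nn.functional as F

in5 = paddle.randn([1, 96, 7, 9])    # deeper level (illustrative odd sizes)
in4 = paddle.randn([1, 96, 14, 18])  # one level shallower
out4 = in4 + F.upsample(in5, size=in4.shape[2:4], mode="nearest", align_mode=1)
assert out4.shape == in4.shape
```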
diff --git a/ppocr/modeling/transforms/__init__.py b/ppocr/modeling/transforms/__init__.py
index 022ece60a5..24aeca347a 100755
--- a/ppocr/modeling/transforms/__init__.py
+++ b/ppocr/modeling/transforms/__init__.py
@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-__all__ = ['build_transform']
+__all__ = ["build_transform"]
def build_transform(config):
@@ -22,10 +22,11 @@ def build_transform(config):
from .tbsrn import TBSRN
from .gaspin_transformer import GA_SPIN_Transformer as GA_SPIN
- support_dict = ['TPS', 'STN_ON', 'GA_SPIN', 'TSRN', 'TBSRN']
+ support_dict = ["TPS", "STN_ON", "GA_SPIN", "TSRN", "TBSRN"]
- module_name = config.pop('name')
+ module_name = config.pop("name")
assert module_name in support_dict, Exception(
- 'transform only support {}'.format(support_dict))
+ "transform only support {}".format(support_dict)
+ )
module_class = eval(module_name)(**config)
return module_class
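`build_transform` resolves the class with `eval(module_name)` over names it has already imported. An eval-free equivalent of the same pattern, sketched with placeholder classes:

```python
# Eval-free variant of the build pattern above (placeholder classes, not the
# real transforms; the shipped code resolves names from its imports via eval).
class TPS: ...
class STN_ON: ...

SUPPORT = {"TPS": TPS, "STN_ON": STN_ON}

def build_transform(config):
    config = dict(config)            # avoid mutating the caller's config
    name = config.pop("name")
    assert name in SUPPORT, "transform only supports {}".format(list(SUPPORT))
    return SUPPORT[name](**config)

transform = build_transform({"name": "TPS"})
```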
diff --git a/ppocr/modeling/transforms/gaspin_transformer.py b/ppocr/modeling/transforms/gaspin_transformer.py
index 7afa216093..d162246f0b 100644
--- a/ppocr/modeling/transforms/gaspin_transformer.py
+++ b/ppocr/modeling/transforms/gaspin_transformer.py
@@ -24,9 +24,10 @@
import functools
from .tps import GridGenerator
-'''This code is refer from:
+"""This code is refer from:
https://github.com/hikopensource/DAVAR-Lab-OCR/davarocr/davar_rcg/models/transformations/gaspin_transformation.py
-'''
+"""
+
class SP_TransformerNetwork(nn.Layer):
"""
@@ -35,7 +36,7 @@ class SP_TransformerNetwork(nn.Layer):
"""
def __init__(self, nc=1, default_type=5):
- """ Based on SPIN
+ """Based on SPIN
Args:
nc (int): number of input channels (usually in 1 or 3)
default_type (int): the complexity of transformation intensities (set to 6 by default, as in the paper)
@@ -56,11 +57,14 @@ def cal_K(self, k=5):
"""
from math import log
+
x = []
if k != 0:
- for i in range(1, k+1):
- lower = round(log(1-(0.5/(k+1))*i)/log((0.5/(k+1))*i), 2)
- upper = round(1/lower, 2)
+ for i in range(1, k + 1):
+ lower = round(
+ log(1 - (0.5 / (k + 1)) * i) / log((0.5 / (k + 1)) * i), 2
+ )
+ upper = round(1 / lower, 2)
x.append(lower)
x.append(upper)
x.append(1.00)
@@ -83,7 +87,7 @@ def forward(self, batch_I, weights, offsets, lambda_color=None):
"""
batch_I = (batch_I + 1) * 0.5
if offsets is not None:
- batch_I = batch_I*(1-lambda_color) + offsets*lambda_color
+ batch_I = batch_I * (1 - lambda_color) + offsets * lambda_color
batch_weight_params = paddle.unsqueeze(paddle.unsqueeze(weights, -1), -1)
batch_I_power = paddle.stack([batch_I.pow(p) for p in self.power_list], axis=1)
@@ -93,6 +97,7 @@ def forward(self, batch_I, weights, offsets, lambda_color=None):
batch_weight_sum = batch_weight_sum * 2 - 1
return batch_weight_sum
+
class GA_SPIN_Transformer(nn.Layer):
"""
Geometric-Absorbed SPIN Transformation (GA-SPIN) proposed in Ref. [1]
@@ -101,13 +106,16 @@ class GA_SPIN_Transformer(nn.Layer):
Ref: [1] SPIN: Structure-Preserving Inner Offset Network for Scene Text Recognition. AAAI-2021.
"""
- def __init__(self, in_channels=1,
- I_r_size=(32, 100),
- offsets=False,
- norm_type='BN',
- default_type=6,
- loc_lr=1,
- stn=True):
+ def __init__(
+ self,
+ in_channels=1,
+ I_r_size=(32, 100),
+ offsets=False,
+ norm_type="BN",
+ default_type=6,
+ loc_lr=1,
+ stn=True,
+ ):
"""
Args:
in_channels (int): channel of input features,
@@ -130,78 +138,87 @@ def __init__(self, in_channels=1,
self.stn = stn  # set to True in GA-SPIN, and to False in SPIN
self.I_r_size = I_r_size
self.out_channels = in_channels
- if norm_type == 'BN':
+ if norm_type == "BN":
norm_layer = functools.partial(nn.BatchNorm2D, use_global_stats=True)
- elif norm_type == 'IN':
- norm_layer = functools.partial(nn.InstanceNorm2D, weight_attr=False,
- use_global_stats=False)
+ elif norm_type == "IN":
+ norm_layer = functools.partial(
+ nn.InstanceNorm2D, weight_attr=False, use_global_stats=False
+ )
else:
- raise NotImplementedError('normalization layer [%s] is not found' % norm_type)
+ raise NotImplementedError(
+ "normalization layer [%s] is not found" % norm_type
+ )
if self.spt:
- self.sp_net = SP_TransformerNetwork(in_channels,
- default_type)
+ self.sp_net = SP_TransformerNetwork(in_channels, default_type)
self.spt_convnet = nn.Sequential(
- # 32*100
- nn.Conv2D(in_channels, 32, 3, 1, 1, bias_attr=False),
- norm_layer(32), nn.ReLU(),
- nn.MaxPool2D(kernel_size=2, stride=2),
- # 16*50
- nn.Conv2D(32, 64, 3, 1, 1, bias_attr=False),
- norm_layer(64), nn.ReLU(),
- nn.MaxPool2D(kernel_size=2, stride=2),
- # 8*25
- nn.Conv2D(64, 128, 3, 1, 1, bias_attr=False),
- norm_layer(128), nn.ReLU(),
- nn.MaxPool2D(kernel_size=2, stride=2),
- # 4*12
+ # 32*100
+ nn.Conv2D(in_channels, 32, 3, 1, 1, bias_attr=False),
+ norm_layer(32),
+ nn.ReLU(),
+ nn.MaxPool2D(kernel_size=2, stride=2),
+ # 16*50
+ nn.Conv2D(32, 64, 3, 1, 1, bias_attr=False),
+ norm_layer(64),
+ nn.ReLU(),
+ nn.MaxPool2D(kernel_size=2, stride=2),
+ # 8*25
+ nn.Conv2D(64, 128, 3, 1, 1, bias_attr=False),
+ norm_layer(128),
+ nn.ReLU(),
+ nn.MaxPool2D(kernel_size=2, stride=2),
+ # 4*12
)
self.stucture_fc1 = nn.Sequential(
- nn.Conv2D(128, 256, 3, 1, 1, bias_attr=False),
- norm_layer(256), nn.ReLU(),
- nn.MaxPool2D(kernel_size=2, stride=2),
- nn.Conv2D(256, 256, 3, 1, 1, bias_attr=False),
- norm_layer(256), nn.ReLU(), # 2*6
- nn.MaxPool2D(kernel_size=2, stride=2),
- nn.Conv2D(256, 512, 3, 1, 1, bias_attr=False),
- norm_layer(512), nn.ReLU(), # 1*3
- nn.AdaptiveAvgPool2D(1),
- nn.Flatten(1, -1), # batch_size x 512
- nn.Linear(512, 256, weight_attr=nn.initializer.Normal(0.001)),
- nn.BatchNorm1D(256), nn.ReLU()
- )
- self.out_weight = 2*default_type+1
- self.spt_length = 2*default_type+1
+ nn.Conv2D(128, 256, 3, 1, 1, bias_attr=False),
+ norm_layer(256),
+ nn.ReLU(),
+ nn.MaxPool2D(kernel_size=2, stride=2),
+ nn.Conv2D(256, 256, 3, 1, 1, bias_attr=False),
+ norm_layer(256),
+ nn.ReLU(), # 2*6
+ nn.MaxPool2D(kernel_size=2, stride=2),
+ nn.Conv2D(256, 512, 3, 1, 1, bias_attr=False),
+ norm_layer(512),
+ nn.ReLU(), # 1*3
+ nn.AdaptiveAvgPool2D(1),
+ nn.Flatten(1, -1), # batch_size x 512
+ nn.Linear(512, 256, weight_attr=nn.initializer.Normal(0.001)),
+ nn.BatchNorm1D(256),
+ nn.ReLU(),
+ )
+ self.out_weight = 2 * default_type + 1
+ self.spt_length = 2 * default_type + 1
if offsets:
self.out_weight += 1
if self.stn:
self.F = 20
self.out_weight += self.F * 2
- self.GridGenerator = GridGenerator(self.F*2, self.F)
-
+ self.GridGenerator = GridGenerator(self.F * 2, self.F)
+
# self.out_weight*=nc
# Init structure_fc2 in LocalizationNetwork
- initial_bias = self.init_spin(default_type*2)
+ initial_bias = self.init_spin(default_type * 2)
initial_bias = initial_bias.reshape(-1)
param_attr = ParamAttr(
learning_rate=loc_lr,
- initializer=nn.initializer.Assign(np.zeros([256, self.out_weight])))
+ initializer=nn.initializer.Assign(np.zeros([256, self.out_weight])),
+ )
bias_attr = ParamAttr(
- learning_rate=loc_lr,
- initializer=nn.initializer.Assign(initial_bias))
- self.stucture_fc2 = nn.Linear(256, self.out_weight,
- weight_attr=param_attr,
- bias_attr=bias_attr)
+ learning_rate=loc_lr, initializer=nn.initializer.Assign(initial_bias)
+ )
+ self.stucture_fc2 = nn.Linear(
+ 256, self.out_weight, weight_attr=param_attr, bias_attr=bias_attr
+ )
self.sigmoid = nn.Sigmoid()
if offsets:
- self.offset_fc1 = nn.Sequential(nn.Conv2D(128, 16,
- 3, 1, 1,
- bias_attr=False),
- norm_layer(16),
- nn.ReLU(),)
- self.offset_fc2 = nn.Conv2D(16, in_channels,
- 3, 1, 1)
+ self.offset_fc1 = nn.Sequential(
+ nn.Conv2D(128, 16, 3, 1, 1, bias_attr=False),
+ norm_layer(16),
+ nn.ReLU(),
+ )
+ self.offset_fc2 = nn.Conv2D(16, in_channels, 3, 1, 1)
self.pool = nn.MaxPool2D(2, 2)
def init_spin(self, nz):
@@ -210,7 +227,7 @@ def init_spin(self, nz):
nz (int): number of paired \beta exponents, i.e. the value of K x 2
"""
- init_id = [0.00]*nz+[5.00]
+ init_id = [0.00] * nz + [5.00]
if self.offsets:
init_id += [-5.00]
# init_id *=3
@@ -243,11 +260,15 @@ def forward(self, x, return_weight=False):
feat = self.spt_convnet(x)
fc1 = self.stucture_fc1(feat)
sp_weight_fusion = self.stucture_fc2(fc1)
- sp_weight_fusion = sp_weight_fusion.reshape([x.shape[0], self.out_weight, 1])
+ sp_weight_fusion = sp_weight_fusion.reshape(
+ [x.shape[0], self.out_weight, 1]
+ )
if self.offsets: # SPIN w. AIN
lambda_color = sp_weight_fusion[:, self.spt_length, 0]
- lambda_color = self.sigmoid(lambda_color).unsqueeze(-1).unsqueeze(-1).unsqueeze(-1)
- sp_weight = sp_weight_fusion[:, :self.spt_length, :]
+ lambda_color = (
+ self.sigmoid(lambda_color).unsqueeze(-1).unsqueeze(-1).unsqueeze(-1)
+ )
+ sp_weight = sp_weight_fusion[:, : self.spt_length, :]
offsets = self.pool(self.offset_fc2(self.offset_fc1(feat)))
assert offsets.shape[2] == 2 # 2
@@ -256,27 +277,31 @@ def forward(self, x, return_weight=False):
if return_weight:
return offsets
- offsets = nn.functional.upsample(offsets, size=(x.shape[2], x.shape[3]), mode='bilinear')
+ offsets = nn.functional.upsample(
+ offsets, size=(x.shape[2], x.shape[3]), mode="bilinear"
+ )
if self.stn:
- batch_C_prime = sp_weight_fusion[:, (self.spt_length + 1):, :].reshape([x.shape[0], self.F, 2])
+ batch_C_prime = sp_weight_fusion[
+ :, (self.spt_length + 1) :, :
+ ].reshape([x.shape[0], self.F, 2])
build_P_prime = self.GridGenerator(batch_C_prime, self.I_r_size)
- build_P_prime_reshape = build_P_prime.reshape([build_P_prime.shape[0],
- self.I_r_size[0],
- self.I_r_size[1],
- 2])
+ build_P_prime_reshape = build_P_prime.reshape(
+ [build_P_prime.shape[0], self.I_r_size[0], self.I_r_size[1], 2]
+ )
else: # SPIN w.o. AIN
- sp_weight = sp_weight_fusion[:, :self.spt_length, :]
+ sp_weight = sp_weight_fusion[:, : self.spt_length, :]
lambda_color, offsets = None, None
if self.stn:
- batch_C_prime = sp_weight_fusion[:, self.spt_length:, :].reshape([x.shape[0], self.F, 2])
+ batch_C_prime = sp_weight_fusion[:, self.spt_length :, :].reshape(
+ [x.shape[0], self.F, 2]
+ )
build_P_prime = self.GridGenerator(batch_C_prime, self.I_r_size)
- build_P_prime_reshape = build_P_prime.reshape([build_P_prime.shape[0],
- self.I_r_size[0],
- self.I_r_size[1],
- 2])
+ build_P_prime_reshape = build_P_prime.reshape(
+ [build_P_prime.shape[0], self.I_r_size[0], self.I_r_size[1], 2]
+ )
x = self.sp_net(x, sp_weight, offsets, lambda_color)
if self.stn:
@@ -286,7 +311,9 @@ def forward(self, x, return_weight=False):
x = x.cast(paddle.float32)
build_P_prime_reshape = build_P_prime_reshape.cast(paddle.float32)
is_fp16 = True
- x = F.grid_sample(x=x, grid=build_P_prime_reshape, padding_mode='border')
+ x = F.grid_sample(
+ x=x, grid=build_P_prime_reshape, padding_mode="border"
+ )
if is_fp16:
x = x.cast(data_type)
return x
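The tail of this file keeps a small mixed-precision workaround: `F.grid_sample` runs in float32, and the result is cast back to the caller's dtype. The pattern, isolated as a helper:

```python
# The cast-around-grid_sample pattern reformatted above, as a helper.
import paddle
import paddle.nn.functional as F

def grid_sample_fp32(x, grid):
    data_type = x.dtype
    is_fp16 = data_type == paddle.float16
    if is_fp16:                      # sample in float32, restore dtype after
        x = x.cast(paddle.float32)
        grid = grid.cast(paddle.float32)
    out = F.grid_sample(x=x, grid=grid, padding_mode="border")
    return out.cast(data_type) if is_fp16 else out

y = grid_sample_fp32(paddle.randn([1, 3, 32, 100]),
                     paddle.rand([1, 32, 100, 2]) * 2 - 1)  # normalized grid
```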
diff --git a/ppocr/modeling/transforms/stn.py b/ppocr/modeling/transforms/stn.py
index 6f2bdda050..a721184a45 100644
--- a/ppocr/modeling/transforms/stn.py
+++ b/ppocr/modeling/transforms/stn.py
@@ -30,75 +30,76 @@
def conv3x3_block(in_channels, out_channels, stride=1):
n = 3 * 3 * out_channels
- w = math.sqrt(2. / n)
+ w = math.sqrt(2.0 / n)
conv_layer = nn.Conv2D(
in_channels,
out_channels,
kernel_size=3,
stride=stride,
padding=1,
- weight_attr=nn.initializer.Normal(
- mean=0.0, std=w),
- bias_attr=nn.initializer.Constant(0))
+ weight_attr=nn.initializer.Normal(mean=0.0, std=w),
+ bias_attr=nn.initializer.Constant(0),
+ )
block = nn.Sequential(conv_layer, nn.BatchNorm2D(out_channels), nn.ReLU())
return block
class STN(nn.Layer):
- def __init__(self, in_channels, num_ctrlpoints, activation='none'):
+ def __init__(self, in_channels, num_ctrlpoints, activation="none"):
super(STN, self).__init__()
self.in_channels = in_channels
self.num_ctrlpoints = num_ctrlpoints
self.activation = activation
self.stn_convnet = nn.Sequential(
- conv3x3_block(in_channels, 32), #32x64
- nn.MaxPool2D(
- kernel_size=2, stride=2),
- conv3x3_block(32, 64), #16x32
- nn.MaxPool2D(
- kernel_size=2, stride=2),
+ conv3x3_block(in_channels, 32), # 32x64
+ nn.MaxPool2D(kernel_size=2, stride=2),
+ conv3x3_block(32, 64), # 16x32
+ nn.MaxPool2D(kernel_size=2, stride=2),
conv3x3_block(64, 128), # 8*16
- nn.MaxPool2D(
- kernel_size=2, stride=2),
+ nn.MaxPool2D(kernel_size=2, stride=2),
conv3x3_block(128, 256), # 4*8
- nn.MaxPool2D(
- kernel_size=2, stride=2),
+ nn.MaxPool2D(kernel_size=2, stride=2),
conv3x3_block(256, 256), # 2*4,
- nn.MaxPool2D(
- kernel_size=2, stride=2),
- conv3x3_block(256, 256)) # 1*2
+ nn.MaxPool2D(kernel_size=2, stride=2),
+ conv3x3_block(256, 256),
+ ) # 1*2
self.stn_fc1 = nn.Sequential(
nn.Linear(
2 * 256,
512,
weight_attr=nn.initializer.Normal(0, 0.001),
- bias_attr=nn.initializer.Constant(0)),
+ bias_attr=nn.initializer.Constant(0),
+ ),
nn.BatchNorm1D(512),
- nn.ReLU())
+ nn.ReLU(),
+ )
fc2_bias = self.init_stn()
self.stn_fc2 = nn.Linear(
512,
num_ctrlpoints * 2,
weight_attr=nn.initializer.Constant(0.0),
- bias_attr=nn.initializer.Assign(fc2_bias))
+ bias_attr=nn.initializer.Assign(fc2_bias),
+ )
def init_stn(self):
margin = 0.01
sampling_num_per_side = int(self.num_ctrlpoints / 2)
- ctrl_pts_x = np.linspace(margin, 1. - margin, sampling_num_per_side)
+ ctrl_pts_x = np.linspace(margin, 1.0 - margin, sampling_num_per_side)
ctrl_pts_y_top = np.ones(sampling_num_per_side) * margin
ctrl_pts_y_bottom = np.ones(sampling_num_per_side) * (1 - margin)
ctrl_pts_top = np.stack([ctrl_pts_x, ctrl_pts_y_top], axis=1)
ctrl_pts_bottom = np.stack([ctrl_pts_x, ctrl_pts_y_bottom], axis=1)
- ctrl_points = np.concatenate(
- [ctrl_pts_top, ctrl_pts_bottom], axis=0).astype(np.float32)
- if self.activation == 'none':
+ ctrl_points = np.concatenate([ctrl_pts_top, ctrl_pts_bottom], axis=0).astype(
+ np.float32
+ )
+ if self.activation == "none":
pass
- elif self.activation == 'sigmoid':
- ctrl_points = -np.log(1. / ctrl_points - 1.)
+ elif self.activation == "sigmoid":
+ ctrl_points = -np.log(1.0 / ctrl_points - 1.0)
ctrl_points = paddle.to_tensor(ctrl_points)
fc2_bias = paddle.reshape(
- ctrl_points, shape=[ctrl_points.shape[0] * ctrl_points.shape[1]])
+ ctrl_points, shape=[ctrl_points.shape[0] * ctrl_points.shape[1]]
+ )
return fc2_bias
def forward(self, x):
@@ -107,29 +108,40 @@ def forward(self, x):
x = paddle.reshape(x, shape=(batch_size, -1))
img_feat = self.stn_fc1(x)
x = self.stn_fc2(0.1 * img_feat)
- if self.activation == 'sigmoid':
+ if self.activation == "sigmoid":
x = F.sigmoid(x)
x = paddle.reshape(x, shape=[-1, self.num_ctrlpoints, 2])
return img_feat, x
class STN_ON(nn.Layer):
- def __init__(self, in_channels, tps_inputsize, tps_outputsize,
- num_control_points, tps_margins, stn_activation):
+ def __init__(
+ self,
+ in_channels,
+ tps_inputsize,
+ tps_outputsize,
+ num_control_points,
+ tps_margins,
+ stn_activation,
+ ):
super(STN_ON, self).__init__()
self.tps = TPSSpatialTransformer(
output_image_size=tuple(tps_outputsize),
num_control_points=num_control_points,
- margins=tuple(tps_margins))
- self.stn_head = STN(in_channels=in_channels,
- num_ctrlpoints=num_control_points,
- activation=stn_activation)
+ margins=tuple(tps_margins),
+ )
+ self.stn_head = STN(
+ in_channels=in_channels,
+ num_ctrlpoints=num_control_points,
+ activation=stn_activation,
+ )
self.tps_inputsize = tps_inputsize
self.out_channels = in_channels
def forward(self, image):
stn_input = paddle.nn.functional.interpolate(
- image, self.tps_inputsize, mode="bilinear", align_corners=True)
+ image, self.tps_inputsize, mode="bilinear", align_corners=True
+ )
stn_img_feat, ctrl_points = self.stn_head(stn_input)
x, _ = self.tps(image, ctrl_points)
return x
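`init_stn` seeds the fc2 bias so the initial control points form two near-straight rows just inside the image; in `sigmoid` mode the targets are pushed through the logit, `-log(1/p - 1)`, so that `sigmoid(bias)` reproduces them exactly. A NumPy check with the defaults used here (20 points, margin 0.01):

```python
# NumPy check of the control-point initialization in init_stn above.
import numpy as np

num_ctrlpoints, margin = 20, 0.01
k = num_ctrlpoints // 2
xs = np.linspace(margin, 1.0 - margin, k)
top = np.stack([xs, np.full(k, margin)], axis=1)           # y = margin row
bottom = np.stack([xs, np.full(k, 1.0 - margin)], axis=1)  # y = 1 - margin row
pts = np.concatenate([top, bottom], axis=0).astype(np.float32)

logits = -np.log(1.0 / pts - 1.0)   # inverse sigmoid, for activation="sigmoid"
assert np.allclose(1.0 / (1.0 + np.exp(-logits)), pts, atol=1e-6)
```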
diff --git a/ppocr/modeling/transforms/tbsrn.py b/ppocr/modeling/transforms/tbsrn.py
index a53d0d277e..4d1d373e41 100644
--- a/ppocr/modeling/transforms/tbsrn.py
+++ b/ppocr/modeling/transforms/tbsrn.py
@@ -28,8 +28,12 @@
from .tps_spatial_transformer import TPSSpatialTransformer
from .stn import STN as STNHead
from .tsrn import GruBlock, mish, UpsampleBLock
-from ppocr.modeling.heads.sr_rensnet_transformer import Transformer, LayerNorm, \
- PositionwiseFeedForward, MultiHeadedAttention
+from ppocr.modeling.heads.sr_rensnet_transformer import (
+ Transformer,
+ LayerNorm,
+ PositionwiseFeedForward,
+ MultiHeadedAttention,
+)
def positionalencoding2d(d_model, height, width):
@@ -40,24 +44,31 @@ def positionalencoding2d(d_model, height, width):
:return: d_model*height*width position matrix
"""
if d_model % 4 != 0:
- raise ValueError("Cannot use sin/cos positional encoding with "
- "odd dimension (got dim={:d})".format(d_model))
+ raise ValueError(
+ "Cannot use sin/cos positional encoding with "
+ "odd dimension (got dim={:d})".format(d_model)
+ )
pe = paddle.zeros([d_model, height, width])
# Each dimension uses half of d_model
d_model = int(d_model / 2)
div_term = paddle.exp(
- paddle.arange(0., d_model, 2, dtype='int64') * -(math.log(10000.0) / d_model))
- pos_w = paddle.arange(0., width, dtype='float32').unsqueeze(1)
- pos_h = paddle.arange(0., height, dtype='float32').unsqueeze(1)
-
- pe[0:d_model:2, :, :] = paddle.sin(pos_w * div_term).transpose(
- [1, 0]).unsqueeze(1).tile([1, height, 1])
- pe[1:d_model:2, :, :] = paddle.cos(pos_w * div_term).transpose(
- [1, 0]).unsqueeze(1).tile([1, height, 1])
- pe[d_model::2, :, :] = paddle.sin(pos_h * div_term).transpose(
- [1, 0]).unsqueeze(2).tile([1, 1, width])
- pe[d_model + 1::2, :, :] = paddle.cos(pos_h * div_term).transpose(
- [1, 0]).unsqueeze(2).tile([1, 1, width])
+ paddle.arange(0.0, d_model, 2, dtype="int64") * -(math.log(10000.0) / d_model)
+ )
+ pos_w = paddle.arange(0.0, width, dtype="float32").unsqueeze(1)
+ pos_h = paddle.arange(0.0, height, dtype="float32").unsqueeze(1)
+
+ pe[0:d_model:2, :, :] = (
+ paddle.sin(pos_w * div_term).transpose([1, 0]).unsqueeze(1).tile([1, height, 1])
+ )
+ pe[1:d_model:2, :, :] = (
+ paddle.cos(pos_w * div_term).transpose([1, 0]).unsqueeze(1).tile([1, height, 1])
+ )
+ pe[d_model::2, :, :] = (
+ paddle.sin(pos_h * div_term).transpose([1, 0]).unsqueeze(2).tile([1, 1, width])
+ )
+ pe[d_model + 1 :: 2, :, :] = (
+ paddle.cos(pos_h * div_term).transpose([1, 0]).unsqueeze(2).tile([1, 1, width])
+ )
return pe
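`positionalencoding2d` splits the channel dimension in half: sin/cos of the width position fill the first `d_model/2` channels, sin/cos of the height position fill the rest. A small NumPy rendering of the same layout:

```python
# NumPy rendering of the 2D sinusoidal layout built above (small assumed sizes).
import math
import numpy as np

d_model, height, width = 8, 4, 6
half = d_model // 2
pe = np.zeros((d_model, height, width))
div = np.exp(np.arange(0, half, 2) * -(math.log(10000.0) / half))
pos_w = np.arange(width)[:, None]   # (W, 1)
pos_h = np.arange(height)[:, None]  # (H, 1)

pe[0:half:2] = np.sin(pos_w * div).T[:, None, :].repeat(height, axis=1)
pe[1:half:2] = np.cos(pos_w * div).T[:, None, :].repeat(height, axis=1)
pe[half::2] = np.sin(pos_h * div).T[:, :, None].repeat(width, axis=2)
pe[half + 1::2] = np.cos(pos_h * div).T[:, :, None].repeat(width, axis=2)
assert pe.shape == (d_model, height, width)
```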
@@ -75,21 +86,27 @@ def __init__(self):
self.linear = nn.Linear(128, 64)
def forward(self, conv_feature):
- '''
+ """
text : (batch, seq_len, embedding_size)
global_info: (batch, embedding_size, 1, 1)
conv_feature: (batch, channel, H, W)
- '''
+ """
batch = conv_feature.shape[0]
- position2d = positionalencoding2d(
- 64, 16, 64).cast('float32').unsqueeze(0).reshape([1, 64, 1024])
+ position2d = (
+ positionalencoding2d(64, 16, 64)
+ .cast("float32")
+ .unsqueeze(0)
+ .reshape([1, 64, 1024])
+ )
position2d = position2d.tile([batch, 1, 1])
- conv_feature = paddle.concat([conv_feature, position2d],
- 1) # batch, 128(64+64), 32, 128
+ conv_feature = paddle.concat(
+ [conv_feature, position2d], 1
+ ) # batch, 128(64+64), 32, 128
result = conv_feature.transpose([0, 2, 1])
origin_result = result
- result = self.mul_layernorm1(origin_result + self.multihead(
- result, result, result, mask=None)[0])
+ result = self.mul_layernorm1(
+ origin_result + self.multihead(result, result, result, mask=None)[0]
+ )
origin_result = result
result = self.mul_layernorm3(origin_result + self.pff(result))
result = self.linear(result)
@@ -98,31 +115,33 @@ def forward(self, conv_feature):
def str_filt(str_, voc_type):
alpha_dict = {
- 'digit': string.digits,
- 'lower': string.digits + string.ascii_lowercase,
- 'upper': string.digits + string.ascii_letters,
- 'all': string.digits + string.ascii_letters + string.punctuation
+ "digit": string.digits,
+ "lower": string.digits + string.ascii_lowercase,
+ "upper": string.digits + string.ascii_letters,
+ "all": string.digits + string.ascii_letters + string.punctuation,
}
- if voc_type == 'lower':
+ if voc_type == "lower":
str_ = str_.lower()
for char in str_:
if char not in alpha_dict[voc_type]:
- str_ = str_.replace(char, '')
+ str_ = str_.replace(char, "")
str_ = str_.lower()
return str_
class TBSRN(nn.Layer):
- def __init__(self,
- in_channels=3,
- scale_factor=2,
- width=128,
- height=32,
- STN=True,
- srb_nums=5,
- mask=False,
- hidden_units=32,
- infer_mode=False):
+ def __init__(
+ self,
+ in_channels=3,
+ scale_factor=2,
+ width=128,
+ height=32,
+ STN=True,
+ srb_nums=5,
+ mask=False,
+ hidden_units=32,
+ infer_mode=False,
+ ):
super(TBSRN, self).__init__()
in_planes = 3
if mask:
@@ -130,36 +149,27 @@ def __init__(self,
assert math.log(scale_factor, 2) % 1 == 0
upsample_block_num = int(math.log(scale_factor, 2))
self.block1 = nn.Sequential(
- nn.Conv2D(
- in_planes, 2 * hidden_units, kernel_size=9, padding=4),
+ nn.Conv2D(in_planes, 2 * hidden_units, kernel_size=9, padding=4),
nn.PReLU()
# nn.ReLU()
)
self.srb_nums = srb_nums
for i in range(srb_nums):
- setattr(self, 'block%d' % (i + 2),
- RecurrentResidualBlock(2 * hidden_units))
+ setattr(self, "block%d" % (i + 2), RecurrentResidualBlock(2 * hidden_units))
setattr(
self,
- 'block%d' % (srb_nums + 2),
+ "block%d" % (srb_nums + 2),
nn.Sequential(
- nn.Conv2D(
- 2 * hidden_units,
- 2 * hidden_units,
- kernel_size=3,
- padding=1),
- nn.BatchNorm2D(2 * hidden_units)))
+ nn.Conv2D(2 * hidden_units, 2 * hidden_units, kernel_size=3, padding=1),
+ nn.BatchNorm2D(2 * hidden_units),
+ ),
+ )
# self.non_local = NonLocalBlock2D(64, 64)
- block_ = [
- UpsampleBLock(2 * hidden_units, 2)
- for _ in range(upsample_block_num)
- ]
- block_.append(
- nn.Conv2D(
- 2 * hidden_units, in_planes, kernel_size=9, padding=4))
- setattr(self, 'block%d' % (srb_nums + 3), nn.Sequential(*block_))
+ block_ = [UpsampleBLock(2 * hidden_units, 2) for _ in range(upsample_block_num)]
+ block_.append(nn.Conv2D(2 * hidden_units, in_planes, kernel_size=9, padding=4))
+ setattr(self, "block%d" % (srb_nums + 3), nn.Sequential(*block_))
self.tps_inputsize = [height // scale_factor, width // scale_factor]
tps_outputsize = [height // scale_factor, width // scale_factor]
num_control_points = 20
@@ -170,20 +180,23 @@ def __init__(self,
self.tps = TPSSpatialTransformer(
output_image_size=tuple(tps_outputsize),
num_control_points=num_control_points,
- margins=tuple(tps_margins))
+ margins=tuple(tps_margins),
+ )
self.stn_head = STNHead(
in_channels=in_planes,
num_ctrlpoints=num_control_points,
- activation='none')
+ activation="none",
+ )
self.infer_mode = infer_mode
- self.english_alphabet = '-0123456789abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ'
+ self.english_alphabet = (
+ "-0123456789abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ"
+ )
self.english_dict = {}
for index in range(len(self.english_alphabet)):
self.english_dict[self.english_alphabet[index]] = index
- transformer = Transformer(
- alphabet='-0123456789abcdefghijklmnopqrstuvwxyz')
+ transformer = Transformer(alphabet="-0123456789abcdefghijklmnopqrstuvwxyz")
self.transformer = transformer
for param in self.transformer.parameters():
param.trainable = False
@@ -192,7 +205,7 @@ def label_encoder(self, label):
batch = len(label)
length = [len(i) for i in label]
- length_tensor = paddle.to_tensor(length, dtype='int64')
+ length_tensor = paddle.to_tensor(length, dtype="int64")
max_length = max(length)
input_tensor = np.zeros((batch, max_length))
@@ -204,9 +217,9 @@ def label_encoder(self, label):
for i in label:
for j in i:
text_gt.append(self.english_dict[j])
- text_gt = paddle.to_tensor(text_gt, dtype='int64')
+ text_gt = paddle.to_tensor(text_gt, dtype="int64")
- input_tensor = paddle.to_tensor(input_tensor, dtype='int64')
+ input_tensor = paddle.to_tensor(input_tensor, dtype="int64")
return length_tensor, input_tensor, text_gt
def forward(self, x):
@@ -221,13 +234,13 @@ def forward(self, x):
if self.stn and self.training:
_, ctrl_points_x = self.stn_head(y)
y, _ = self.tps(y, ctrl_points_x)
- block = {'1': self.block1(y)}
+ block = {"1": self.block1(y)}
for i in range(self.srb_nums + 1):
- block[str(i + 2)] = getattr(self,
- 'block%d' % (i + 2))(block[str(i + 1)])
+ block[str(i + 2)] = getattr(self, "block%d" % (i + 2))(block[str(i + 1)])
- block[str(self.srb_nums + 3)] = getattr(self, 'block%d' % (self.srb_nums + 3)) \
- ((block['1'] + block[str(self.srb_nums + 2)]))
+ block[str(self.srb_nums + 3)] = getattr(self, "block%d" % (self.srb_nums + 3))(
+ (block["1"] + block[str(self.srb_nums + 2)])
+ )
sr_img = paddle.tanh(block[str(self.srb_nums + 3)])
output["sr_img"] = sr_img
@@ -236,12 +249,14 @@ def forward(self, x):
hr_img = x[1]
# add transformer
- label = [str_filt(i, 'lower') + '-' for i in x[2]]
+ label = [str_filt(i, "lower") + "-" for i in x[2]]
length_tensor, input_tensor, text_gt = self.label_encoder(label)
hr_pred, word_attention_map_gt, hr_correct_list = self.transformer(
- hr_img, length_tensor, input_tensor)
+ hr_img, length_tensor, input_tensor
+ )
sr_pred, word_attention_map_pred, sr_correct_list = self.transformer(
- sr_img, length_tensor, input_tensor)
+ sr_img, length_tensor, input_tensor
+ )
output["hr_img"] = hr_img
output["hr_pred"] = hr_pred
output["text_gt"] = text_gt
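For reference while reading the hunk above: `positionalencoding2d` builds the standard 2D extension of sinusoidal position embeddings, with half the channels encoding the horizontal position and half the vertical, each half interleaving sine and cosine at geometrically spaced frequencies. A minimal NumPy sketch of the same construction follows; it was written independently of the Paddle code, so treat it as illustrative rather than as the project implementation.

import math
import numpy as np

def positional_encoding_2d(d_model, height, width):
    # Half the channels encode x (width), half encode y (height);
    # within each half, even indices are sine and odd indices cosine.
    if d_model % 4 != 0:
        raise ValueError("d_model must be divisible by 4, got %d" % d_model)
    pe = np.zeros((d_model, height, width), dtype=np.float32)
    half = d_model // 2
    div_term = np.exp(np.arange(0.0, half, 2) * -(math.log(10000.0) / half))
    pos_w = np.arange(0.0, width)[:, None] * div_term   # (W, half/2)
    pos_h = np.arange(0.0, height)[:, None] * div_term  # (H, half/2)
    pe[0:half:2] = np.tile(np.sin(pos_w).T[:, None, :], (1, height, 1))
    pe[1:half:2] = np.tile(np.cos(pos_w).T[:, None, :], (1, height, 1))
    pe[half::2] = np.tile(np.sin(pos_h).T[:, :, None], (1, 1, width))
    pe[half + 1::2] = np.tile(np.cos(pos_h).T[:, :, None], (1, 1, width))
    return pe  # (d_model, height, width)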
diff --git a/ppocr/modeling/transforms/tps.py b/ppocr/modeling/transforms/tps.py
index 10802bc952..d5681db8c1 100644
--- a/ppocr/modeling/transforms/tps.py
+++ b/ppocr/modeling/transforms/tps.py
@@ -28,14 +28,16 @@
class ConvBNLayer(nn.Layer):
- def __init__(self,
- in_channels,
- out_channels,
- kernel_size,
- stride=1,
- groups=1,
- act=None,
- name=None):
+ def __init__(
+ self,
+ in_channels,
+ out_channels,
+ kernel_size,
+ stride=1,
+ groups=1,
+ act=None,
+ name=None,
+ ):
super(ConvBNLayer, self).__init__()
self.conv = nn.Conv2D(
in_channels=in_channels,
@@ -45,15 +47,17 @@ def __init__(self,
padding=(kernel_size - 1) // 2,
groups=groups,
weight_attr=ParamAttr(name=name + "_weights"),
- bias_attr=False)
+ bias_attr=False,
+ )
bn_name = "bn_" + name
self.bn = nn.BatchNorm(
out_channels,
act=act,
- param_attr=ParamAttr(name=bn_name + '_scale'),
- bias_attr=ParamAttr(bn_name + '_offset'),
- moving_mean_name=bn_name + '_mean',
- moving_variance_name=bn_name + '_variance')
+ param_attr=ParamAttr(name=bn_name + "_scale"),
+ bias_attr=ParamAttr(bn_name + "_offset"),
+ moving_mean_name=bn_name + "_mean",
+ moving_variance_name=bn_name + "_variance",
+ )
def forward(self, x):
x = self.conv(x)
@@ -83,8 +87,10 @@ def __init__(self, in_channels, num_fiducial, loc_lr, model_name):
in_channels=in_channels,
out_channels=num_filters,
kernel_size=3,
- act='relu',
- name=name))
+ act="relu",
+ name=name,
+ ),
+ )
self.block_list.append(conv)
if fno == len(num_filters_list) - 1:
pool = nn.AdaptiveAvgPool2D(1)
@@ -100,9 +106,11 @@ def __init__(self, in_channels, num_fiducial, loc_lr, model_name):
weight_attr=ParamAttr(
learning_rate=loc_lr,
name=name + "_w",
- initializer=nn.initializer.Uniform(-stdv, stdv)),
- bias_attr=ParamAttr(name=name + '.b_0'),
- name=name)
+ initializer=nn.initializer.Uniform(-stdv, stdv),
+ ),
+ bias_attr=ParamAttr(name=name + ".b_0"),
+ name=name,
+ )
# Init fc2 in LocalizationNetwork
initial_bias = self.get_initial_fiducials()
@@ -111,26 +119,25 @@ def __init__(self, in_channels, num_fiducial, loc_lr, model_name):
param_attr = ParamAttr(
learning_rate=loc_lr,
initializer=nn.initializer.Assign(np.zeros([fc_dim, F * 2])),
- name=name + "_w")
+ name=name + "_w",
+ )
bias_attr = ParamAttr(
learning_rate=loc_lr,
initializer=nn.initializer.Assign(initial_bias),
- name=name + "_b")
+ name=name + "_b",
+ )
self.fc2 = nn.Linear(
- fc_dim,
- F * 2,
- weight_attr=param_attr,
- bias_attr=bias_attr,
- name=name)
+ fc_dim, F * 2, weight_attr=param_attr, bias_attr=bias_attr, name=name
+ )
self.out_channels = F * 2
def forward(self, x):
"""
- Estimating parameters of geometric transformation
- Args:
- image: input
- Return:
- batch_C_prime: the matrix of the geometric transformation
+ Estimating parameters of geometric transformation
+ Args:
+ image: input
+ Return:
+ batch_C_prime: the matrix of the geometric transformation
"""
B = x.shape[0]
i = 0
@@ -145,7 +152,7 @@ def forward(self, x):
return x
def get_initial_fiducials(self):
- """ see RARE paper Fig. 6 (a) """
+ """see RARE paper Fig. 6 (a)"""
F = self.F
ctrl_pts_x = np.linspace(-1.0, 1.0, int(F / 2))
ctrl_pts_y_top = np.linspace(0.0, -1.0, num=int(F / 2))
@@ -165,15 +172,14 @@ def __init__(self, in_channels, num_fiducial):
name = "ex_fc"
initializer = nn.initializer.Constant(value=0.0)
param_attr = ParamAttr(
- learning_rate=0.0, initializer=initializer, name=name + "_w")
+ learning_rate=0.0, initializer=initializer, name=name + "_w"
+ )
bias_attr = ParamAttr(
- learning_rate=0.0, initializer=initializer, name=name + "_b")
+ learning_rate=0.0, initializer=initializer, name=name + "_b"
+ )
self.fc = nn.Linear(
- in_channels,
- 6,
- weight_attr=param_attr,
- bias_attr=bias_attr,
- name=name)
+ in_channels, 6, weight_attr=param_attr, bias_attr=bias_attr, name=name
+ )
def forward(self, batch_C_prime, I_r_size):
"""
@@ -187,9 +193,8 @@ def forward(self, batch_C_prime, I_r_size):
C = self.build_C_paddle()
P = self.build_P_paddle(I_r_size)
- inv_delta_C_tensor = self.build_inv_delta_C_paddle(C).astype('float32')
- P_hat_tensor = self.build_P_hat_paddle(
- C, paddle.to_tensor(P)).astype('float32')
+ inv_delta_C_tensor = self.build_inv_delta_C_paddle(C).astype("float32")
+ P_hat_tensor = self.build_P_hat_paddle(C, paddle.to_tensor(P)).astype("float32")
inv_delta_C_tensor.stop_gradient = True
P_hat_tensor.stop_gradient = True
@@ -199,17 +204,18 @@ def forward(self, batch_C_prime, I_r_size):
batch_C_ex_part_tensor.stop_gradient = True
batch_C_prime_with_zeros = paddle.concat(
- [batch_C_prime, batch_C_ex_part_tensor], axis=1)
+ [batch_C_prime, batch_C_ex_part_tensor], axis=1
+ )
batch_T = paddle.matmul(inv_delta_C_tensor, batch_C_prime_with_zeros)
batch_P_prime = paddle.matmul(P_hat_tensor, batch_T)
return batch_P_prime
def build_C_paddle(self):
- """ Return coordinates of fiducial points in I_r; C """
+ """Return coordinates of fiducial points in I_r; C"""
F = self.F
- ctrl_pts_x = paddle.linspace(-1.0, 1.0, int(F / 2), dtype='float64')
- ctrl_pts_y_top = -1 * paddle.ones([int(F / 2)], dtype='float64')
- ctrl_pts_y_bottom = paddle.ones([int(F / 2)], dtype='float64')
+ ctrl_pts_x = paddle.linspace(-1.0, 1.0, int(F / 2), dtype="float64")
+ ctrl_pts_y_top = -1 * paddle.ones([int(F / 2)], dtype="float64")
+ ctrl_pts_y_bottom = paddle.ones([int(F / 2)], dtype="float64")
ctrl_pts_top = paddle.stack([ctrl_pts_x, ctrl_pts_y_top], axis=1)
ctrl_pts_bottom = paddle.stack([ctrl_pts_x, ctrl_pts_y_bottom], axis=1)
C = paddle.concat([ctrl_pts_top, ctrl_pts_bottom], axis=0)
@@ -217,13 +223,13 @@ def build_C_paddle(self):
def build_P_paddle(self, I_r_size):
I_r_height, I_r_width = I_r_size
- I_r_grid_x = (paddle.arange(
- -I_r_width, I_r_width, 2, dtype='float64') + 1.0
- ) / paddle.to_tensor(np.array([I_r_width])).astype('float64')
+ I_r_grid_x = (
+ paddle.arange(-I_r_width, I_r_width, 2, dtype="float64") + 1.0
+ ) / paddle.to_tensor(np.array([I_r_width])).astype("float64")
- I_r_grid_y = (paddle.arange(
- -I_r_height, I_r_height, 2, dtype='float64') + 1.0
- ) / paddle.to_tensor(np.array([I_r_height])).astype('float64')
+ I_r_grid_y = (
+ paddle.arange(-I_r_height, I_r_height, 2, dtype="float64") + 1.0
+ ) / paddle.to_tensor(np.array([I_r_height])).astype("float64")
# P: self.I_r_width x self.I_r_height x 2
P = paddle.stack(paddle.meshgrid(I_r_grid_x, I_r_grid_y), axis=2)
@@ -232,33 +238,35 @@ def build_P_paddle(self, I_r_size):
return P.reshape([-1, 2])
def build_inv_delta_C_paddle(self, C):
- """ Return inv_delta_C which is needed to calculate T """
+ """Return inv_delta_C which is needed to calculate T"""
F = self.F
- hat_eye = paddle.eye(F, dtype='float64') # F x F
- hat_C = paddle.norm(
- C.reshape([1, F, 2]) - C.reshape([F, 1, 2]), axis=2) + hat_eye
+ hat_eye = paddle.eye(F, dtype="float64") # F x F
+ hat_C = (
+ paddle.norm(C.reshape([1, F, 2]) - C.reshape([F, 1, 2]), axis=2) + hat_eye
+ )
hat_C = (hat_C**2) * paddle.log(hat_C)
delta_C = paddle.concat( # F+3 x F+3
[
paddle.concat(
- [paddle.ones(
- (F, 1), dtype='float64'), C, hat_C], axis=1), # F x F+3
+ [paddle.ones((F, 1), dtype="float64"), C, hat_C], axis=1
+ ), # F x F+3
paddle.concat(
[
- paddle.zeros(
- (2, 3), dtype='float64'), paddle.transpose(
- C, perm=[1, 0])
+ paddle.zeros((2, 3), dtype="float64"),
+ paddle.transpose(C, perm=[1, 0]),
],
- axis=1), # 2 x F+3
+ axis=1,
+ ), # 2 x F+3
paddle.concat(
[
- paddle.zeros(
- (1, 3), dtype='float64'), paddle.ones(
- (1, F), dtype='float64')
+ paddle.zeros((1, 3), dtype="float64"),
+ paddle.ones((1, F), dtype="float64"),
],
- axis=1) # 1 x F+3
+ axis=1,
+ ), # 1 x F+3
],
- axis=0)
+ axis=0,
+ )
inv_delta_C = paddle.inverse(delta_C)
return inv_delta_C # F+3 x F+3
@@ -274,11 +282,8 @@ def build_P_hat_paddle(self, C, P):
rbf_norm = paddle.norm(P_diff, p=2, axis=2, keepdim=False)
# rbf: n x F
- rbf = paddle.multiply(
- paddle.square(rbf_norm), paddle.log(rbf_norm + eps))
- P_hat = paddle.concat(
- [paddle.ones(
- (n, 1), dtype='float64'), P, rbf], axis=1)
+ rbf = paddle.multiply(paddle.square(rbf_norm), paddle.log(rbf_norm + eps))
+ P_hat = paddle.concat([paddle.ones((n, 1), dtype="float64"), P, rbf], axis=1)
return P_hat # n x F+3
def get_expand_tensor(self, batch_C_prime):
@@ -292,18 +297,17 @@ def get_expand_tensor(self, batch_C_prime):
class TPS(nn.Layer):
def __init__(self, in_channels, num_fiducial, loc_lr, model_name):
super(TPS, self).__init__()
- self.loc_net = LocalizationNetwork(in_channels, num_fiducial, loc_lr,
- model_name)
- self.grid_generator = GridGenerator(self.loc_net.out_channels,
- num_fiducial)
+ self.loc_net = LocalizationNetwork(
+ in_channels, num_fiducial, loc_lr, model_name
+ )
+ self.grid_generator = GridGenerator(self.loc_net.out_channels, num_fiducial)
self.out_channels = in_channels
def forward(self, image):
image.stop_gradient = False
batch_C_prime = self.loc_net(image)
batch_P_prime = self.grid_generator(batch_C_prime, image.shape[2:])
- batch_P_prime = batch_P_prime.reshape(
- [-1, image.shape[2], image.shape[3], 2])
+ batch_P_prime = batch_P_prime.reshape([-1, image.shape[2], image.shape[3], 2])
is_fp16 = False
if batch_P_prime.dtype != paddle.float32:
data_type = batch_P_prime.dtype
@@ -313,5 +317,5 @@ def forward(self, image):
batch_I_r = F.grid_sample(x=image, grid=batch_P_prime)
if is_fp16:
batch_I_r = batch_I_r.cast(data_type)
-
+
return batch_I_r
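The `GridGenerator` reformatted above solves the classical thin-plate-spline system: `delta_C` stacks the radial-basis kernel U(r) = r^2 log r between fiducial points together with affine terms, and its inverse maps target fiducials to transform coefficients. A NumPy sketch of `build_inv_delta_C` with the same matrix layout, offered as an illustration rather than a drop-in replacement:

import numpy as np

def build_inv_delta_C(C):
    # C: (F, 2) fiducial points in the rectified image.
    F = C.shape[0]
    hat_C = np.linalg.norm(C[:, None, :] - C[None, :, :], axis=2) + np.eye(F)
    hat_C = (hat_C ** 2) * np.log(hat_C)  # U(r) = r^2 log r; eye() makes the diagonal log(1) = 0
    delta_C = np.concatenate(
        [
            np.concatenate([np.ones((F, 1)), C, hat_C], axis=1),          # F x (F+3)
            np.concatenate([np.zeros((2, 3)), C.T], axis=1),              # 2 x (F+3)
            np.concatenate([np.zeros((1, 3)), np.ones((1, F))], axis=1),  # 1 x (F+3)
        ],
        axis=0,
    )
    return np.linalg.inv(delta_C)  # (F+3) x (F+3)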
diff --git a/ppocr/modeling/transforms/tps_spatial_transformer.py b/ppocr/modeling/transforms/tps_spatial_transformer.py
index 4c905ee0cd..9c97978700 100644
--- a/ppocr/modeling/transforms/tps_spatial_transformer.py
+++ b/ppocr/modeling/transforms/tps_spatial_transformer.py
@@ -59,14 +59,13 @@ def grid_sample(input, grid, canvas=None):
def compute_partial_repr(input_points, control_points):
N = input_points.shape[0]
M = control_points.shape[0]
- pairwise_diff = paddle.reshape(
- input_points, shape=[N, 1, 2]) - paddle.reshape(
- control_points, shape=[1, M, 2])
+ pairwise_diff = paddle.reshape(input_points, shape=[N, 1, 2]) - paddle.reshape(
+ control_points, shape=[1, M, 2]
+ )
# original implementation, very slow
# pairwise_dist = torch.sum(pairwise_diff ** 2, dim = 2) # square of distance
pairwise_diff_square = pairwise_diff * pairwise_diff
- pairwise_dist = pairwise_diff_square[:, :, 0] + pairwise_diff_square[:, :,
- 1]
+ pairwise_dist = pairwise_diff_square[:, :, 0] + pairwise_diff_square[:, :, 1]
repr_matrix = 0.5 * pairwise_dist * paddle.log(pairwise_dist)
# fix numerical error for 0 * log(0), substitute all nan with 0
mask = np.array(repr_matrix != repr_matrix)
@@ -83,64 +82,62 @@ def build_output_control_points(num_control_points, margins):
ctrl_pts_y_bottom = np.ones(num_ctrl_pts_per_side) * (1.0 - margin_y)
ctrl_pts_top = np.stack([ctrl_pts_x, ctrl_pts_y_top], axis=1)
ctrl_pts_bottom = np.stack([ctrl_pts_x, ctrl_pts_y_bottom], axis=1)
- output_ctrl_pts_arr = np.concatenate(
- [ctrl_pts_top, ctrl_pts_bottom], axis=0)
+ output_ctrl_pts_arr = np.concatenate([ctrl_pts_top, ctrl_pts_bottom], axis=0)
output_ctrl_pts = paddle.to_tensor(output_ctrl_pts_arr)
return output_ctrl_pts
class TPSSpatialTransformer(nn.Layer):
- def __init__(self,
- output_image_size=None,
- num_control_points=None,
- margins=None):
+ def __init__(self, output_image_size=None, num_control_points=None, margins=None):
super(TPSSpatialTransformer, self).__init__()
self.output_image_size = output_image_size
self.num_control_points = num_control_points
self.margins = margins
self.target_height, self.target_width = output_image_size
- target_control_points = build_output_control_points(num_control_points,
- margins)
+ target_control_points = build_output_control_points(num_control_points, margins)
N = num_control_points
# create padded kernel matrix
forward_kernel = paddle.zeros(shape=[N + 3, N + 3])
target_control_partial_repr = compute_partial_repr(
- target_control_points, target_control_points)
- target_control_partial_repr = paddle.cast(target_control_partial_repr,
- forward_kernel.dtype)
+ target_control_points, target_control_points
+ )
+ target_control_partial_repr = paddle.cast(
+ target_control_partial_repr, forward_kernel.dtype
+ )
forward_kernel[:N, :N] = target_control_partial_repr
forward_kernel[:N, -3] = 1
forward_kernel[-3, :N] = 1
- target_control_points = paddle.cast(target_control_points,
- forward_kernel.dtype)
+ target_control_points = paddle.cast(target_control_points, forward_kernel.dtype)
forward_kernel[:N, -2:] = target_control_points
- forward_kernel[-2:, :N] = paddle.transpose(
- target_control_points, perm=[1, 0])
+ forward_kernel[-2:, :N] = paddle.transpose(target_control_points, perm=[1, 0])
# compute inverse matrix
inverse_kernel = paddle.inverse(forward_kernel)
# create target coordinate matrix
HW = self.target_height * self.target_width
target_coordinate = list(
- itertools.product(
- range(self.target_height), range(self.target_width)))
+ itertools.product(range(self.target_height), range(self.target_width))
+ )
target_coordinate = paddle.to_tensor(target_coordinate) # HW x 2
- Y, X = paddle.split(
- target_coordinate, target_coordinate.shape[1], axis=1)
+ Y, X = paddle.split(target_coordinate, target_coordinate.shape[1], axis=1)
Y = Y / (self.target_height - 1)
X = X / (self.target_width - 1)
target_coordinate = paddle.concat(
- [X, Y], axis=1) # convert from (y, x) to (x, y)
+ [X, Y], axis=1
+ ) # convert from (y, x) to (x, y)
target_coordinate_partial_repr = compute_partial_repr(
- target_coordinate, target_control_points)
+ target_coordinate, target_control_points
+ )
target_coordinate_repr = paddle.concat(
[
- target_coordinate_partial_repr, paddle.ones(shape=[HW, 1]),
- target_coordinate
+ target_coordinate_partial_repr,
+ paddle.ones(shape=[HW, 1]),
+ target_coordinate,
],
- axis=1)
+ axis=1,
+ )
# register precomputed matrices
self.inverse_kernel = inverse_kernel
@@ -154,20 +151,19 @@ def forward(self, input, source_control_points):
assert source_control_points.shape[2] == 2
batch_size = source_control_points.shape[0]
- padding_matrix = paddle.expand(
- self.padding_matrix, shape=[batch_size, 3, 2])
- Y = paddle.concat([
- source_control_points.astype(padding_matrix.dtype), padding_matrix
- ], 1)
+ padding_matrix = paddle.expand(self.padding_matrix, shape=[batch_size, 3, 2])
+ Y = paddle.concat(
+ [source_control_points.astype(padding_matrix.dtype), padding_matrix], 1
+ )
mapping_matrix = paddle.matmul(self.inverse_kernel, Y)
- source_coordinate = paddle.matmul(self.target_coordinate_repr,
- mapping_matrix)
+ source_coordinate = paddle.matmul(self.target_coordinate_repr, mapping_matrix)
grid = paddle.reshape(
- source_coordinate,
- shape=[-1, self.target_height, self.target_width, 2])
- grid = paddle.clip(grid, 0,
- 1) # the source_control_points may be out of [0, 1].
+ source_coordinate, shape=[-1, self.target_height, self.target_width, 2]
+ )
+ grid = paddle.clip(
+ grid, 0, 1
+ ) # the source_control_points may be out of [0, 1].
# the input to grid_sample is normalized [-1, 1], but what we get is [0, 1]
grid = 2.0 * grid - 1.0
output_maps = grid_sample(input, grid, canvas=None)
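`compute_partial_repr` above evaluates the TPS kernel in its 0.5 * r^2 * log(r^2) form (equal to r^2 * log r) and zeroes the NaNs from 0 * log(0) via the `repr_matrix != repr_matrix` self-inequality trick. An equivalent NumPy sketch, assuming `np.errstate` plus an `isfinite` mask in place of that trick:

import numpy as np

def compute_partial_repr_np(input_points, control_points):
    # Pairwise TPS kernel between an (N, 2) and an (M, 2) point set.
    diff = input_points[:, None, :] - control_points[None, :, :]  # N x M x 2
    dist_sq = (diff ** 2).sum(axis=2)                             # N x M
    with np.errstate(divide="ignore", invalid="ignore"):
        repr_matrix = 0.5 * dist_sq * np.log(dist_sq)
    repr_matrix[~np.isfinite(repr_matrix)] = 0.0  # 0 * log(0) -> 0
    return repr_matrix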
diff --git a/ppocr/modeling/transforms/tsrn.py b/ppocr/modeling/transforms/tsrn.py
index 31aa90ea4b..9376bc6d94 100644
--- a/ppocr/modeling/transforms/tsrn.py
+++ b/ppocr/modeling/transforms/tsrn.py
@@ -35,17 +35,19 @@
class TSRN(nn.Layer):
- def __init__(self,
- in_channels,
- scale_factor=2,
- width=128,
- height=32,
- STN=False,
- srb_nums=5,
- mask=False,
- hidden_units=32,
- infer_mode=False,
- **kwargs):
+ def __init__(
+ self,
+ in_channels,
+ scale_factor=2,
+ width=128,
+ height=32,
+ STN=False,
+ srb_nums=5,
+ mask=False,
+ hidden_units=32,
+ infer_mode=False,
+ **kwargs
+ ):
super(TSRN, self).__init__()
in_planes = 3
if mask:
@@ -53,33 +55,24 @@ def __init__(self,
assert math.log(scale_factor, 2) % 1 == 0
upsample_block_num = int(math.log(scale_factor, 2))
self.block1 = nn.Sequential(
- nn.Conv2D(
- in_planes, 2 * hidden_units, kernel_size=9, padding=4),
- nn.PReLU())
+ nn.Conv2D(in_planes, 2 * hidden_units, kernel_size=9, padding=4), nn.PReLU()
+ )
self.srb_nums = srb_nums
for i in range(srb_nums):
- setattr(self, 'block%d' % (i + 2),
- RecurrentResidualBlock(2 * hidden_units))
+ setattr(self, "block%d" % (i + 2), RecurrentResidualBlock(2 * hidden_units))
setattr(
self,
- 'block%d' % (srb_nums + 2),
+ "block%d" % (srb_nums + 2),
nn.Sequential(
- nn.Conv2D(
- 2 * hidden_units,
- 2 * hidden_units,
- kernel_size=3,
- padding=1),
- nn.BatchNorm2D(2 * hidden_units)))
-
- block_ = [
- UpsampleBLock(2 * hidden_units, 2)
- for _ in range(upsample_block_num)
- ]
- block_.append(
- nn.Conv2D(
- 2 * hidden_units, in_planes, kernel_size=9, padding=4))
- setattr(self, 'block%d' % (srb_nums + 3), nn.Sequential(*block_))
+ nn.Conv2D(2 * hidden_units, 2 * hidden_units, kernel_size=3, padding=1),
+ nn.BatchNorm2D(2 * hidden_units),
+ ),
+ )
+
+ block_ = [UpsampleBLock(2 * hidden_units, 2) for _ in range(upsample_block_num)]
+ block_.append(nn.Conv2D(2 * hidden_units, in_planes, kernel_size=9, padding=4))
+ setattr(self, "block%d" % (srb_nums + 3), nn.Sequential(*block_))
self.tps_inputsize = [height // scale_factor, width // scale_factor]
tps_outputsize = [height // scale_factor, width // scale_factor]
num_control_points = 20
@@ -89,12 +82,14 @@ def __init__(self,
self.tps = TPSSpatialTransformer(
output_image_size=tuple(tps_outputsize),
num_control_points=num_control_points,
- margins=tuple(tps_margins))
+ margins=tuple(tps_margins),
+ )
self.stn_head = STN_model(
in_channels=in_planes,
num_ctrlpoints=num_control_points,
- activation='none')
+ activation="none",
+ )
self.out_channels = in_channels
self.r34_transformer = Transformer()
@@ -114,13 +109,13 @@ def forward(self, x):
if self.stn and self.training:
_, ctrl_points_x = self.stn_head(y)
y, _ = self.tps(y, ctrl_points_x)
- block = {'1': self.block1(y)}
+ block = {"1": self.block1(y)}
for i in range(self.srb_nums + 1):
- block[str(i + 2)] = getattr(self,
- 'block%d' % (i + 2))(block[str(i + 1)])
+ block[str(i + 2)] = getattr(self, "block%d" % (i + 2))(block[str(i + 1)])
- block[str(self.srb_nums + 3)] = getattr(self, 'block%d' % (self.srb_nums + 3)) \
- ((block['1'] + block[str(self.srb_nums + 2)]))
+ block[str(self.srb_nums + 3)] = getattr(self, "block%d" % (self.srb_nums + 3))(
+ (block["1"] + block[str(self.srb_nums + 2)])
+ )
sr_img = paddle.tanh(block[str(self.srb_nums + 3)])
@@ -131,12 +126,14 @@ def forward(self, x):
length = x[2]
input_tensor = x[3]
- # add transformer
+ # add transformer
sr_pred, word_attention_map_pred, _ = self.r34_transformer(
- sr_img, length, input_tensor)
+ sr_img, length, input_tensor
+ )
hr_pred, word_attention_map_gt, _ = self.r34_transformer(
- hr_img, length, input_tensor)
+ hr_img, length, input_tensor
+ )
output["hr_img"] = hr_img
output["hr_pred"] = hr_pred
@@ -164,8 +161,7 @@ def forward(self, x):
residual = self.prelu(residual)
residual = self.conv2(residual)
residual = self.bn2(residual)
- residual = self.gru1(residual.transpose([0, 1, 3, 2])).transpose(
- [0, 1, 3, 2])
+ residual = self.gru1(residual.transpose([0, 1, 3, 2])).transpose([0, 1, 3, 2])
return self.gru2(x + residual)
@@ -174,7 +170,8 @@ class UpsampleBLock(nn.Layer):
def __init__(self, in_channels, up_scale):
super(UpsampleBLock, self).__init__()
self.conv = nn.Conv2D(
- in_channels, in_channels * up_scale**2, kernel_size=3, padding=1)
+ in_channels, in_channels * up_scale**2, kernel_size=3, padding=1
+ )
self.pixel_shuffle = nn.PixelShuffle(up_scale)
self.prelu = mish()
@@ -187,7 +184,9 @@ def forward(self, x):
class mish(nn.Layer):
- def __init__(self, ):
+ def __init__(
+ self,
+ ):
super(mish, self).__init__()
self.activated = True
@@ -201,18 +200,15 @@ class GruBlock(nn.Layer):
def __init__(self, in_channels, out_channels):
super(GruBlock, self).__init__()
assert out_channels % 2 == 0
- self.conv1 = nn.Conv2D(
- in_channels, out_channels, kernel_size=1, padding=0)
- self.gru = nn.GRU(out_channels,
- out_channels // 2,
- direction='bidirectional')
+ self.conv1 = nn.Conv2D(in_channels, out_channels, kernel_size=1, padding=0)
+ self.gru = nn.GRU(out_channels, out_channels // 2, direction="bidirectional")
def forward(self, x):
# x: b, c, w, h
x = self.conv1(x)
x = x.transpose([0, 2, 3, 1]) # b, w, h, c
batch_size, w, h, c = x.shape
- x = x.reshape([-1, h, c]) # b*w, h, c
+ x = x.reshape([-1, h, c]) # b*w, h, c
x, _ = self.gru(x)
x = x.reshape([-1, w, h, c])
x = x.transpose([0, 3, 1, 2])
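`UpsampleBLock` above performs sub-pixel upsampling: a conv expands channels by `up_scale**2`, `nn.PixelShuffle` rearranges those channels into space, and a mish activation follows. A small NumPy sketch of the two non-conv pieces; the stable-softplus form of mish is an assumption of this sketch:

import numpy as np

def mish(x):
    # mish(x) = x * tanh(softplus(x)); logaddexp(0, x) is a stable softplus.
    return x * np.tanh(np.logaddexp(0.0, x))

def pixel_shuffle(x, r):
    # Depth-to-space: (B, C*r*r, H, W) -> (B, C, H*r, W*r).
    b, c, h, w = x.shape
    oc = c // (r * r)
    x = x.reshape(b, oc, r, r, h, w)
    x = x.transpose(0, 1, 4, 2, 5, 3)  # B, C, H, r, W, r
    return x.reshape(b, oc, h * r, w * r)

x = np.random.rand(1, 8, 4, 4)
assert pixel_shuffle(x, 2).shape == (1, 2, 8, 8)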
diff --git a/ppocr/optimizer/__init__.py b/ppocr/optimizer/__init__.py
index b92954c9cc..a191a4bd0d 100644
--- a/ppocr/optimizer/__init__.py
+++ b/ppocr/optimizer/__init__.py
@@ -19,47 +19,48 @@
import copy
import paddle
-__all__ = ['build_optimizer']
+__all__ = ["build_optimizer"]
def build_lr_scheduler(lr_config, epochs, step_each_epoch):
from . import learning_rate
- lr_config.update({'epochs': epochs, 'step_each_epoch': step_each_epoch})
- lr_name = lr_config.pop('name', 'Const')
+
+ lr_config.update({"epochs": epochs, "step_each_epoch": step_each_epoch})
+ lr_name = lr_config.pop("name", "Const")
lr = getattr(learning_rate, lr_name)(**lr_config)()
return lr
def build_optimizer(config, epochs, step_each_epoch, model):
from . import regularizer, optimizer
+
config = copy.deepcopy(config)
# step1 build lr
- lr = build_lr_scheduler(config.pop('lr'), epochs, step_each_epoch)
+ lr = build_lr_scheduler(config.pop("lr"), epochs, step_each_epoch)
# step2 build regularization
- if 'regularizer' in config and config['regularizer'] is not None:
- reg_config = config.pop('regularizer')
- reg_name = reg_config.pop('name')
+ if "regularizer" in config and config["regularizer"] is not None:
+ reg_config = config.pop("regularizer")
+ reg_name = reg_config.pop("name")
if not hasattr(regularizer, reg_name):
- reg_name += 'Decay'
+ reg_name += "Decay"
reg = getattr(regularizer, reg_name)(**reg_config)()
- elif 'weight_decay' in config:
- reg = config.pop('weight_decay')
+ elif "weight_decay" in config:
+ reg = config.pop("weight_decay")
else:
reg = None
# step3 build optimizer
- optim_name = config.pop('name')
- if 'clip_norm' in config:
- clip_norm = config.pop('clip_norm')
+ optim_name = config.pop("name")
+ if "clip_norm" in config:
+ clip_norm = config.pop("clip_norm")
grad_clip = paddle.nn.ClipGradByNorm(clip_norm=clip_norm)
- elif 'clip_norm_global' in config:
- clip_norm = config.pop('clip_norm_global')
+ elif "clip_norm_global" in config:
+ clip_norm = config.pop("clip_norm_global")
grad_clip = paddle.nn.ClipGradByGlobalNorm(clip_norm=clip_norm)
else:
grad_clip = None
- optim = getattr(optimizer, optim_name)(learning_rate=lr,
- weight_decay=reg,
- grad_clip=grad_clip,
- **config)
+ optim = getattr(optimizer, optim_name)(
+ learning_rate=lr, weight_decay=reg, grad_clip=grad_clip, **config
+ )
return optim(model), lr
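`build_optimizer` above follows a pop-and-dispatch convention: each config dict carries a `name` key naming a class, and the remaining keys become constructor kwargs. A toy sketch of that pattern; the `SGD` class and `REGISTRY` here are hypothetical stand-ins, not PaddleOCR APIs:

import copy

class SGD:
    def __init__(self, learning_rate=0.01, **kwargs):
        self.learning_rate = learning_rate

REGISTRY = {"SGD": SGD}

def build_from_config(config):
    # deepcopy first so pop("name") does not mutate the caller's config.
    config = copy.deepcopy(config)
    name = config.pop("name")
    return REGISTRY[name](**config)

opt = build_from_config({"name": "SGD", "learning_rate": 0.1})
assert opt.learning_rate == 0.1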
diff --git a/ppocr/optimizer/learning_rate.py b/ppocr/optimizer/learning_rate.py
index be52a91845..f0b05ff915 100644
--- a/ppocr/optimizer/learning_rate.py
+++ b/ppocr/optimizer/learning_rate.py
@@ -32,15 +32,17 @@ class Linear(object):
last_epoch (int, optional): The index of last epoch. Can be set to restart training. Default: -1, means initial learning rate.
"""
- def __init__(self,
- learning_rate,
- epochs,
- step_each_epoch,
- end_lr=0.0,
- power=1.0,
- warmup_epoch=0,
- last_epoch=-1,
- **kwargs):
+ def __init__(
+ self,
+ learning_rate,
+ epochs,
+ step_each_epoch,
+ end_lr=0.0,
+ power=1.0,
+ warmup_epoch=0,
+ last_epoch=-1,
+ **kwargs
+ ):
super(Linear, self).__init__()
self.learning_rate = learning_rate
self.epochs = epochs * step_each_epoch
@@ -55,14 +57,16 @@ def __call__(self):
decay_steps=self.epochs,
end_lr=self.end_lr,
power=self.power,
- last_epoch=self.last_epoch)
+ last_epoch=self.last_epoch,
+ )
if self.warmup_epoch > 0:
learning_rate = lr.LinearWarmup(
learning_rate=learning_rate,
warmup_steps=self.warmup_epoch,
start_lr=0.0,
end_lr=self.learning_rate,
- last_epoch=self.last_epoch)
+ last_epoch=self.last_epoch,
+ )
return learning_rate
@@ -77,13 +81,15 @@ class Cosine(object):
last_epoch (int, optional): The index of last epoch. Can be set to restart training. Default: -1, means initial learning rate.
"""
- def __init__(self,
- learning_rate,
- step_each_epoch,
- epochs,
- warmup_epoch=0,
- last_epoch=-1,
- **kwargs):
+ def __init__(
+ self,
+ learning_rate,
+ step_each_epoch,
+ epochs,
+ warmup_epoch=0,
+ last_epoch=-1,
+ **kwargs
+ ):
super(Cosine, self).__init__()
self.learning_rate = learning_rate
self.T_max = step_each_epoch * epochs
@@ -94,14 +100,16 @@ def __call__(self):
learning_rate = lr.CosineAnnealingDecay(
learning_rate=self.learning_rate,
T_max=self.T_max,
- last_epoch=self.last_epoch)
+ last_epoch=self.last_epoch,
+ )
if self.warmup_epoch > 0:
learning_rate = lr.LinearWarmup(
learning_rate=learning_rate,
warmup_steps=self.warmup_epoch,
start_lr=0.0,
end_lr=self.learning_rate,
- last_epoch=self.last_epoch)
+ last_epoch=self.last_epoch,
+ )
return learning_rate
@@ -117,14 +125,16 @@ class Step(object):
last_epoch (int, optional): The index of last epoch. Can be set to restart training. Default: -1, means initial learning rate.
"""
- def __init__(self,
- learning_rate,
- step_size,
- step_each_epoch,
- gamma,
- warmup_epoch=0,
- last_epoch=-1,
- **kwargs):
+ def __init__(
+ self,
+ learning_rate,
+ step_size,
+ step_each_epoch,
+ gamma,
+ warmup_epoch=0,
+ last_epoch=-1,
+ **kwargs
+ ):
super(Step, self).__init__()
self.step_size = step_each_epoch * step_size
self.learning_rate = learning_rate
@@ -137,14 +147,16 @@ def __call__(self):
learning_rate=self.learning_rate,
step_size=self.step_size,
gamma=self.gamma,
- last_epoch=self.last_epoch)
+ last_epoch=self.last_epoch,
+ )
if self.warmup_epoch > 0:
learning_rate = lr.LinearWarmup(
learning_rate=learning_rate,
warmup_steps=self.warmup_epoch,
start_lr=0.0,
end_lr=self.learning_rate,
- last_epoch=self.last_epoch)
+ last_epoch=self.last_epoch,
+ )
return learning_rate
@@ -158,13 +170,15 @@ class Piecewise(object):
last_epoch (int, optional): The index of last epoch. Can be set to restart training. Default: -1, means initial learning rate.
"""
- def __init__(self,
- step_each_epoch,
- decay_epochs,
- values,
- warmup_epoch=0,
- last_epoch=-1,
- **kwargs):
+ def __init__(
+ self,
+ step_each_epoch,
+ decay_epochs,
+ values,
+ warmup_epoch=0,
+ last_epoch=-1,
+ **kwargs
+ ):
super(Piecewise, self).__init__()
self.boundaries = [step_each_epoch * e for e in decay_epochs]
self.values = values
@@ -173,16 +187,16 @@ def __init__(self,
def __call__(self):
learning_rate = lr.PiecewiseDecay(
- boundaries=self.boundaries,
- values=self.values,
- last_epoch=self.last_epoch)
+ boundaries=self.boundaries, values=self.values, last_epoch=self.last_epoch
+ )
if self.warmup_epoch > 0:
learning_rate = lr.LinearWarmup(
learning_rate=learning_rate,
warmup_steps=self.warmup_epoch,
start_lr=0.0,
end_lr=self.values[0],
- last_epoch=self.last_epoch)
+ last_epoch=self.last_epoch,
+ )
return learning_rate
@@ -197,14 +211,16 @@ class CyclicalCosine(object):
last_epoch (int, optional): The index of last epoch. Can be set to restart training. Default: -1, means initial learning rate.
"""
- def __init__(self,
- learning_rate,
- step_each_epoch,
- epochs,
- cycle,
- warmup_epoch=0,
- last_epoch=-1,
- **kwargs):
+ def __init__(
+ self,
+ learning_rate,
+ step_each_epoch,
+ epochs,
+ cycle,
+ warmup_epoch=0,
+ last_epoch=-1,
+ **kwargs
+ ):
super(CyclicalCosine, self).__init__()
self.learning_rate = learning_rate
self.T_max = step_each_epoch * epochs
@@ -217,14 +233,16 @@ def __call__(self):
learning_rate=self.learning_rate,
T_max=self.T_max,
cycle=self.cycle,
- last_epoch=self.last_epoch)
+ last_epoch=self.last_epoch,
+ )
if self.warmup_epoch > 0:
learning_rate = lr.LinearWarmup(
learning_rate=learning_rate,
warmup_steps=self.warmup_epoch,
start_lr=0.0,
end_lr=self.learning_rate,
- last_epoch=self.last_epoch)
+ last_epoch=self.last_epoch,
+ )
return learning_rate
@@ -235,22 +253,24 @@ class OneCycle(object):
max_lr(float): Upper learning rate boundaries
epochs(int): total training epochs
step_each_epoch(int): steps each epoch
- anneal_strategy(str): {‘cos’, ‘linear’} Specifies the annealing strategy: “cos” for cosine annealing, “linear” for linear annealing.
+ anneal_strategy(str): {‘cos’, ‘linear’} Specifies the annealing strategy: “cos” for cosine annealing, “linear” for linear annealing.
Default: ‘cos’
- three_phase(bool): If True, use a third phase of the schedule to annihilate the learning rate according to ‘final_div_factor’
+ three_phase(bool): If True, use a third phase of the schedule to annihilate the learning rate according to ‘final_div_factor’
instead of modifying the second phase (the first two phases will be symmetrical about the step indicated by ‘pct_start’).
last_epoch (int, optional): The index of last epoch. Can be set to restart training. Default: -1, means initial learning rate.
"""
- def __init__(self,
- max_lr,
- epochs,
- step_each_epoch,
- anneal_strategy='cos',
- three_phase=False,
- warmup_epoch=0,
- last_epoch=-1,
- **kwargs):
+ def __init__(
+ self,
+ max_lr,
+ epochs,
+ step_each_epoch,
+ anneal_strategy="cos",
+ three_phase=False,
+ warmup_epoch=0,
+ last_epoch=-1,
+ **kwargs
+ ):
super(OneCycle, self).__init__()
self.max_lr = max_lr
self.epochs = epochs
@@ -267,14 +287,16 @@ def __call__(self):
steps_per_epoch=self.steps_per_epoch,
anneal_strategy=self.anneal_strategy,
three_phase=self.three_phase,
- last_epoch=self.last_epoch)
+ last_epoch=self.last_epoch,
+ )
if self.warmup_epoch > 0:
learning_rate = lr.LinearWarmup(
learning_rate=learning_rate,
warmup_steps=self.warmup_epoch,
start_lr=0.0,
end_lr=self.max_lr,
- last_epoch=self.last_epoch)
+ last_epoch=self.last_epoch,
+ )
return learning_rate
@@ -287,12 +309,9 @@ class Const(object):
last_epoch (int, optional): The index of last epoch. Can be set to restart training. Default: -1, means initial learning rate.
"""
- def __init__(self,
- learning_rate,
- step_each_epoch,
- warmup_epoch=0,
- last_epoch=-1,
- **kwargs):
+ def __init__(
+ self, learning_rate, step_each_epoch, warmup_epoch=0, last_epoch=-1, **kwargs
+ ):
super(Const, self).__init__()
self.learning_rate = learning_rate
self.last_epoch = last_epoch
@@ -306,7 +325,8 @@ def __call__(self):
warmup_steps=self.warmup_epoch,
start_lr=0.0,
end_lr=self.learning_rate,
- last_epoch=self.last_epoch)
+ last_epoch=self.last_epoch,
+ )
return learning_rate
@@ -322,13 +342,9 @@ class DecayLearningRate(object):
end_lr(float): The minimum final learning rate. Default: 0.0.
"""
- def __init__(self,
- learning_rate,
- step_each_epoch,
- epochs,
- factor=0.9,
- end_lr=0,
- **kwargs):
+ def __init__(
+ self, learning_rate, step_each_epoch, epochs, factor=0.9, end_lr=0, **kwargs
+ ):
super(DecayLearningRate, self).__init__()
self.learning_rate = learning_rate
self.epochs = epochs + 1
@@ -341,7 +357,8 @@ def __call__(self):
learning_rate=self.learning_rate,
decay_steps=self.decay_steps,
power=self.factor,
- end_lr=self.end_lr)
+ end_lr=self.end_lr,
+ )
return learning_rate
@@ -357,14 +374,16 @@ class MultiStepDecay(object):
last_epoch (int, optional): The index of last epoch. Can be set to restart training. Default: -1, means initial learning rate.
"""
- def __init__(self,
- learning_rate,
- milestones,
- step_each_epoch,
- gamma,
- warmup_epoch=0,
- last_epoch=-1,
- **kwargs):
+ def __init__(
+ self,
+ learning_rate,
+ milestones,
+ step_each_epoch,
+ gamma,
+ warmup_epoch=0,
+ last_epoch=-1,
+ **kwargs
+ ):
super(MultiStepDecay, self).__init__()
self.milestones = [step_each_epoch * e for e in milestones]
self.learning_rate = learning_rate
@@ -377,14 +396,16 @@ def __call__(self):
learning_rate=self.learning_rate,
milestones=self.milestones,
gamma=self.gamma,
- last_epoch=self.last_epoch)
+ last_epoch=self.last_epoch,
+ )
if self.warmup_epoch > 0:
learning_rate = lr.LinearWarmup(
learning_rate=learning_rate,
warmup_steps=self.warmup_epoch,
start_lr=0.0,
end_lr=self.learning_rate,
- last_epoch=self.last_epoch)
+ last_epoch=self.last_epoch,
+ )
return learning_rate
@@ -399,13 +420,15 @@ class TwoStepCosine(object):
last_epoch (int, optional): The index of last epoch. Can be set to restart training. Default: -1, means initial learning rate.
"""
- def __init__(self,
- learning_rate,
- step_each_epoch,
- epochs,
- warmup_epoch=0,
- last_epoch=-1,
- **kwargs):
+ def __init__(
+ self,
+ learning_rate,
+ step_each_epoch,
+ epochs,
+ warmup_epoch=0,
+ last_epoch=-1,
+ **kwargs
+ ):
super(TwoStepCosine, self).__init__()
self.learning_rate = learning_rate
self.T_max1 = step_each_epoch * 200
@@ -418,12 +441,14 @@ def __call__(self):
learning_rate=self.learning_rate,
T_max1=self.T_max1,
T_max2=self.T_max2,
- last_epoch=self.last_epoch)
+ last_epoch=self.last_epoch,
+ )
if self.warmup_epoch > 0:
learning_rate = lr.LinearWarmup(
learning_rate=learning_rate,
warmup_steps=self.warmup_epoch,
start_lr=0.0,
end_lr=self.learning_rate,
- last_epoch=self.last_epoch)
+ last_epoch=self.last_epoch,
+ )
return learning_rate
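Every wrapper in this module follows the same shape: build the base decay, then, if `warmup_epoch > 0`, wrap it in `lr.LinearWarmup` ramping from 0 up to the peak rate. A self-contained sketch of the resulting cosine-with-warmup curve, assuming for simplicity that warmup is counted in steps:

import math

def cosine_with_warmup(step, base_lr, total_steps, warmup_steps):
    # Linear ramp 0 -> base_lr, then cosine annealing base_lr -> 0.
    if step < warmup_steps:
        return base_lr * step / max(1, warmup_steps)
    t = (step - warmup_steps) / max(1, total_steps - warmup_steps)
    return 0.5 * base_lr * (1.0 + math.cos(math.pi * t))

lrs = [cosine_with_warmup(s, 1e-3, 1000, 100) for s in range(1000)]
assert max(lrs) <= 1e-3 and lrs[100] == 1e-3  # peak right after warmup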
diff --git a/ppocr/optimizer/lr_scheduler.py b/ppocr/optimizer/lr_scheduler.py
index cd09367e2a..4034e148eb 100644
--- a/ppocr/optimizer/lr_scheduler.py
+++ b/ppocr/optimizer/lr_scheduler.py
@@ -17,13 +17,9 @@
class CyclicalCosineDecay(LRScheduler):
- def __init__(self,
- learning_rate,
- T_max,
- cycle=1,
- last_epoch=-1,
- eta_min=0.0,
- verbose=False):
+ def __init__(
+ self, learning_rate, T_max, cycle=1, last_epoch=-1, eta_min=0.0, verbose=False
+ ):
"""
Cyclical cosine learning rate decay
A learning rate schedule described in https://arxiv.org/pdf/2012.12645.pdf
@@ -35,8 +31,7 @@ def __init__(self,
eta_min(float): minimum learning rate during training
verbose(bool): whether to print learning rate for each epoch
"""
- super(CyclicalCosineDecay, self).__init__(learning_rate, last_epoch,
- verbose)
+ super(CyclicalCosineDecay, self).__init__(learning_rate, last_epoch, verbose)
self.cycle = cycle
self.eta_min = eta_min
@@ -44,8 +39,9 @@ def get_lr(self):
if self.last_epoch == 0:
return self.base_lr
reletive_epoch = self.last_epoch % self.cycle
- lr = self.eta_min + 0.5 * (self.base_lr - self.eta_min) * \
- (1 + math.cos(math.pi * reletive_epoch / self.cycle))
+ lr = self.eta_min + 0.5 * (self.base_lr - self.eta_min) * (
+ 1 + math.cos(math.pi * reletive_epoch / self.cycle)
+ )
return lr
@@ -56,26 +52,30 @@ class OneCycleDecay(LRScheduler):
Code referenced from https://pytorch.org/docs/stable/_modules/torch/optim/lr_scheduler.html#OneCycleLR
"""
- def __init__(self,
- max_lr,
- epochs=None,
- steps_per_epoch=None,
- pct_start=0.3,
- anneal_strategy='cos',
- div_factor=25.,
- final_div_factor=1e4,
- three_phase=False,
- last_epoch=-1,
- verbose=False):
-
+ def __init__(
+ self,
+ max_lr,
+ epochs=None,
+ steps_per_epoch=None,
+ pct_start=0.3,
+ anneal_strategy="cos",
+ div_factor=25.0,
+ final_div_factor=1e4,
+ three_phase=False,
+ last_epoch=-1,
+ verbose=False,
+ ):
# Validate total_steps
if epochs <= 0 or not isinstance(epochs, int):
raise ValueError(
- "Expected positive integer epochs, but got {}".format(epochs))
+ "Expected positive integer epochs, but got {}".format(epochs)
+ )
if steps_per_epoch <= 0 or not isinstance(steps_per_epoch, int):
raise ValueError(
"Expected positive integer steps_per_epoch, but got {}".format(
- steps_per_epoch))
+ steps_per_epoch
+ )
+ )
self.total_steps = epochs * steps_per_epoch
self.max_lr = max_lr
@@ -85,49 +85,51 @@ def __init__(self,
if three_phase:
self._schedule_phases = [
{
- 'end_step': float(pct_start * self.total_steps) - 1,
- 'start_lr': self.initial_lr,
- 'end_lr': self.max_lr,
+ "end_step": float(pct_start * self.total_steps) - 1,
+ "start_lr": self.initial_lr,
+ "end_lr": self.max_lr,
},
{
- 'end_step': float(2 * pct_start * self.total_steps) - 2,
- 'start_lr': self.max_lr,
- 'end_lr': self.initial_lr,
+ "end_step": float(2 * pct_start * self.total_steps) - 2,
+ "start_lr": self.max_lr,
+ "end_lr": self.initial_lr,
},
{
- 'end_step': self.total_steps - 1,
- 'start_lr': self.initial_lr,
- 'end_lr': self.min_lr,
+ "end_step": self.total_steps - 1,
+ "start_lr": self.initial_lr,
+ "end_lr": self.min_lr,
},
]
else:
self._schedule_phases = [
{
- 'end_step': float(pct_start * self.total_steps) - 1,
- 'start_lr': self.initial_lr,
- 'end_lr': self.max_lr,
+ "end_step": float(pct_start * self.total_steps) - 1,
+ "start_lr": self.initial_lr,
+ "end_lr": self.max_lr,
},
{
- 'end_step': self.total_steps - 1,
- 'start_lr': self.max_lr,
- 'end_lr': self.min_lr,
+ "end_step": self.total_steps - 1,
+ "start_lr": self.max_lr,
+ "end_lr": self.min_lr,
},
]
# Validate pct_start
if pct_start < 0 or pct_start > 1 or not isinstance(pct_start, float):
raise ValueError(
- "Expected float between 0 and 1 pct_start, but got {}".format(
- pct_start))
+ "Expected float between 0 and 1 pct_start, but got {}".format(pct_start)
+ )
# Validate anneal_strategy
- if anneal_strategy not in ['cos', 'linear']:
+ if anneal_strategy not in ["cos", "linear"]:
raise ValueError(
- "anneal_strategy must by one of 'cos' or 'linear', instead got {}".
- format(anneal_strategy))
- elif anneal_strategy == 'cos':
+ "anneal_strategy must be one of 'cos' or 'linear', instead got {}".format(
+ anneal_strategy
+ )
+ )
+ elif anneal_strategy == "cos":
self.anneal_func = self._annealing_cos
- elif anneal_strategy == 'linear':
+ elif anneal_strategy == "linear":
self.anneal_func = self._annealing_linear
super(OneCycleDecay, self).__init__(max_lr, last_epoch, verbose)
@@ -147,76 +149,92 @@ def get_lr(self):
if step_num > self.total_steps:
raise ValueError(
- "Tried to step {} times. The specified number of total steps is {}"
- .format(step_num + 1, self.total_steps))
+ "Tried to step {} times. The specified number of total steps is {}".format(
+ step_num + 1, self.total_steps
+ )
+ )
start_step = 0
for i, phase in enumerate(self._schedule_phases):
- end_step = phase['end_step']
+ end_step = phase["end_step"]
if step_num <= end_step or i == len(self._schedule_phases) - 1:
pct = (step_num - start_step) / (end_step - start_step)
- computed_lr = self.anneal_func(phase['start_lr'],
- phase['end_lr'], pct)
+ computed_lr = self.anneal_func(phase["start_lr"], phase["end_lr"], pct)
break
- start_step = phase['end_step']
+ start_step = phase["end_step"]
return computed_lr
class TwoStepCosineDecay(LRScheduler):
- def __init__(self,
- learning_rate,
- T_max1,
- T_max2,
- eta_min=0,
- last_epoch=-1,
- verbose=False):
+ def __init__(
+ self, learning_rate, T_max1, T_max2, eta_min=0, last_epoch=-1, verbose=False
+ ):
if not isinstance(T_max1, int):
raise TypeError(
"The type of 'T_max1' in 'CosineAnnealingDecay' must be 'int', but received %s."
- % type(T_max1))
+ % type(T_max1)
+ )
if not isinstance(T_max2, int):
raise TypeError(
"The type of 'T_max2' in 'CosineAnnealingDecay' must be 'int', but received %s."
- % type(T_max2))
+ % type(T_max2)
+ )
if not isinstance(eta_min, (float, int)):
raise TypeError(
"The type of 'eta_min' in 'CosineAnnealingDecay' must be 'float, int', but received %s."
- % type(eta_min))
+ % type(eta_min)
+ )
assert T_max1 > 0 and isinstance(
- T_max1, int), " 'T_max1' must be a positive integer."
+ T_max1, int
+ ), " 'T_max1' must be a positive integer."
assert T_max2 > 0 and isinstance(
- T_max2, int), " 'T_max1' must be a positive integer."
+ T_max2, int
+ ), " 'T_max2' must be a positive integer."
self.T_max1 = T_max1
self.T_max2 = T_max2
self.eta_min = float(eta_min)
- super(TwoStepCosineDecay, self).__init__(learning_rate, last_epoch,
- verbose)
+ super(TwoStepCosineDecay, self).__init__(learning_rate, last_epoch, verbose)
def get_lr(self):
-
if self.last_epoch <= self.T_max1:
if self.last_epoch == 0:
return self.base_lr
elif (self.last_epoch - 1 - self.T_max1) % (2 * self.T_max1) == 0:
- return self.last_lr + (self.base_lr - self.eta_min) * (
- 1 - math.cos(math.pi / self.T_max1)) / 2
+ return (
+ self.last_lr
+ + (self.base_lr - self.eta_min)
+ * (1 - math.cos(math.pi / self.T_max1))
+ / 2
+ )
return (1 + math.cos(math.pi * self.last_epoch / self.T_max1)) / (
- 1 + math.cos(math.pi * (self.last_epoch - 1) / self.T_max1)) * (
- self.last_lr - self.eta_min) + self.eta_min
+ 1 + math.cos(math.pi * (self.last_epoch - 1) / self.T_max1)
+ ) * (self.last_lr - self.eta_min) + self.eta_min
else:
if (self.last_epoch - 1 - self.T_max2) % (2 * self.T_max2) == 0:
- return self.last_lr + (self.base_lr - self.eta_min) * (
- 1 - math.cos(math.pi / self.T_max2)) / 2
+ return (
+ self.last_lr
+ + (self.base_lr - self.eta_min)
+ * (1 - math.cos(math.pi / self.T_max2))
+ / 2
+ )
return (1 + math.cos(math.pi * self.last_epoch / self.T_max2)) / (
- 1 + math.cos(math.pi * (self.last_epoch - 1) / self.T_max2)) * (
- self.last_lr - self.eta_min) + self.eta_min
+ 1 + math.cos(math.pi * (self.last_epoch - 1) / self.T_max2)
+ ) * (self.last_lr - self.eta_min) + self.eta_min
def _get_closed_form_lr(self):
if self.last_epoch <= self.T_max1:
- return self.eta_min + (self.base_lr - self.eta_min) * (1 + math.cos(
- math.pi * self.last_epoch / self.T_max1)) / 2
+ return (
+ self.eta_min
+ + (self.base_lr - self.eta_min)
+ * (1 + math.cos(math.pi * self.last_epoch / self.T_max1))
+ / 2
+ )
else:
- return self.eta_min + (self.base_lr - self.eta_min) * (1 + math.cos(
- math.pi * self.last_epoch / self.T_max2)) / 2
+ return (
+ self.eta_min
+ + (self.base_lr - self.eta_min)
+ * (1 + math.cos(math.pi * self.last_epoch / self.T_max2))
+ / 2
+ )
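`OneCycleDecay` above selects `self._annealing_cos` or `self._annealing_linear` and interpolates between each phase's `start_lr` and `end_lr` by the fraction `pct` of the phase elapsed. The two interpolants as plain functions, reproduced from first principles (they should match the referenced PyTorch OneCycleLR, but verify against that source):

import math

def annealing_cos(start, end, pct):
    # Cosine ease from start (pct=0) to end (pct=1).
    return end + (start - end) / 2.0 * (math.cos(math.pi * pct) + 1.0)

def annealing_linear(start, end, pct):
    return (end - start) * pct + start

assert annealing_cos(1.0, 0.0, 0.0) == 1.0 and annealing_cos(1.0, 0.0, 1.0) == 0.0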
diff --git a/ppocr/optimizer/optimizer.py b/ppocr/optimizer/optimizer.py
index ffe72d7db3..ee236ff5c2 100644
--- a/ppocr/optimizer/optimizer.py
+++ b/ppocr/optimizer/optimizer.py
@@ -30,12 +30,9 @@ class Momentum(object):
regularization (WeightDecayRegularizer, optional) - The strategy of regularization.
"""
- def __init__(self,
- learning_rate,
- momentum,
- weight_decay=None,
- grad_clip=None,
- **args):
+ def __init__(
+ self, learning_rate, momentum, weight_decay=None, grad_clip=None, **args
+ ):
super(Momentum, self).__init__()
self.learning_rate = learning_rate
self.momentum = momentum
@@ -51,22 +48,25 @@ def __call__(self, model):
momentum=self.momentum,
weight_decay=self.weight_decay,
grad_clip=self.grad_clip,
- parameters=train_params)
+ parameters=train_params,
+ )
return opt
class Adam(object):
- def __init__(self,
- learning_rate=0.001,
- beta1=0.9,
- beta2=0.999,
- epsilon=1e-08,
- parameter_list=None,
- weight_decay=None,
- grad_clip=None,
- name=None,
- lazy_mode=False,
- **kwargs):
+ def __init__(
+ self,
+ learning_rate=0.001,
+ beta1=0.9,
+ beta2=0.999,
+ epsilon=1e-08,
+ parameter_list=None,
+ weight_decay=None,
+ grad_clip=None,
+ name=None,
+ lazy_mode=False,
+ **kwargs
+ ):
self.learning_rate = learning_rate
self.beta1 = beta1
self.beta2 = beta2
@@ -77,25 +77,26 @@ def __init__(self,
self.grad_clip = grad_clip
self.name = name
self.lazy_mode = lazy_mode
- self.group_lr = kwargs.get('group_lr', False)
- self.training_step = kwargs.get('training_step', None)
+ self.group_lr = kwargs.get("group_lr", False)
+ self.training_step = kwargs.get("training_step", None)
def __call__(self, model):
if self.group_lr:
- if self.training_step == 'LF_2':
+ if self.training_step == "LF_2":
import paddle
+
if isinstance(model, paddle.DataParallel): # multi gpu
mlm = model._layers.head.MLM_VRM.MLM.parameters()
- pre_mlm_pp = model._layers.head.MLM_VRM.Prediction.pp_share.parameters(
+ pre_mlm_pp = (
+ model._layers.head.MLM_VRM.Prediction.pp_share.parameters()
)
- pre_mlm_w = model._layers.head.MLM_VRM.Prediction.w_share.parameters(
+ pre_mlm_w = (
+ model._layers.head.MLM_VRM.Prediction.w_share.parameters()
)
else: # single gpu
mlm = model.head.MLM_VRM.MLM.parameters()
- pre_mlm_pp = model.head.MLM_VRM.Prediction.pp_share.parameters(
- )
- pre_mlm_w = model.head.MLM_VRM.Prediction.w_share.parameters(
- )
+ pre_mlm_pp = model.head.MLM_VRM.Prediction.pp_share.parameters()
+ pre_mlm_w = model.head.MLM_VRM.Prediction.w_share.parameters()
total = []
for param in mlm:
@@ -109,23 +110,20 @@ def __call__(self, model):
param for param in model.parameters() if id(param) in total
]
group_small_params = [
- param for param in model.parameters()
- if id(param) not in total
+ param for param in model.parameters() if id(param) not in total
+ ]
+ train_params = [
+ {"params": group_base_params},
+ {
+ "params": group_small_params,
+ "learning_rate": self.learning_rate.values[0] * 0.1,
+ },
]
- train_params = [{
- 'params': group_base_params
- }, {
- 'params': group_small_params,
- 'learning_rate': self.learning_rate.values[0] * 0.1
- }]
else:
- print(
- 'group lr currently only support VisionLAN in LF_2 training step'
- )
+ print("group lr currently only supports VisionLAN in the LF_2 training step")
train_params = [
- param for param in model.parameters()
- if param.trainable is True
+ param for param in model.parameters() if param.trainable is True
]
else:
train_params = [
@@ -141,7 +139,8 @@ def __call__(self, model):
grad_clip=self.grad_clip,
name=self.name,
lazy_mode=self.lazy_mode,
- parameters=train_params)
+ parameters=train_params,
+ )
return opt
@@ -157,14 +156,16 @@ class RMSProp(object):
regularization (WeightDecayRegularizer, optional) - The strategy of regularization.
"""
- def __init__(self,
- learning_rate,
- momentum=0.0,
- rho=0.95,
- epsilon=1e-6,
- weight_decay=None,
- grad_clip=None,
- **args):
+ def __init__(
+ self,
+ learning_rate,
+ momentum=0.0,
+ rho=0.95,
+ epsilon=1e-6,
+ weight_decay=None,
+ grad_clip=None,
+ **args
+ ):
super(RMSProp, self).__init__()
self.learning_rate = learning_rate
self.momentum = momentum
@@ -184,20 +185,23 @@ def __call__(self, model):
epsilon=self.epsilon,
weight_decay=self.weight_decay,
grad_clip=self.grad_clip,
- parameters=train_params)
+ parameters=train_params,
+ )
return opt
class Adadelta(object):
- def __init__(self,
- learning_rate=0.001,
- epsilon=1e-08,
- rho=0.95,
- parameter_list=None,
- weight_decay=None,
- grad_clip=None,
- name=None,
- **kwargs):
+ def __init__(
+ self,
+ learning_rate=0.001,
+ epsilon=1e-08,
+ rho=0.95,
+ parameter_list=None,
+ weight_decay=None,
+ grad_clip=None,
+ name=None,
+ **kwargs
+ ):
self.learning_rate = learning_rate
self.epsilon = epsilon
self.rho = rho
@@ -218,24 +222,27 @@ def __call__(self, model):
weight_decay=self.weight_decay,
grad_clip=self.grad_clip,
name=self.name,
- parameters=train_params)
+ parameters=train_params,
+ )
return opt
class AdamW(object):
- def __init__(self,
- learning_rate=0.001,
- beta1=0.9,
- beta2=0.999,
- epsilon=1e-8,
- weight_decay=0.01,
- multi_precision=False,
- grad_clip=None,
- no_weight_decay_name=None,
- one_dim_param_no_weight_decay=False,
- name=None,
- lazy_mode=False,
- **args):
+ def __init__(
+ self,
+ learning_rate=0.001,
+ beta1=0.9,
+ beta2=0.999,
+ epsilon=1e-8,
+ weight_decay=0.01,
+ multi_precision=False,
+ grad_clip=None,
+ no_weight_decay_name=None,
+ one_dim_param_no_weight_decay=False,
+ name=None,
+ lazy_mode=False,
+ **args
+ ):
super().__init__()
self.learning_rate = learning_rate
self.beta1 = beta1
@@ -247,17 +254,17 @@ def __init__(self,
self.name = name
self.lazy_mode = lazy_mode
self.multi_precision = multi_precision
- self.no_weight_decay_name_list = no_weight_decay_name.split(
- ) if no_weight_decay_name else []
+ self.no_weight_decay_name_list = (
+ no_weight_decay_name.split() if no_weight_decay_name else []
+ )
self.one_dim_param_no_weight_decay = one_dim_param_no_weight_decay
def __call__(self, model):
- parameters = [
- param for param in model.parameters() if param.trainable is True
- ]
+ parameters = [param for param in model.parameters() if param.trainable is True]
self.no_weight_decay_param_name_list = [
- p.name for n, p in model.named_parameters()
+ p.name
+ for n, p in model.named_parameters()
if any(nd in n for nd in self.no_weight_decay_name_list)
]
@@ -277,7 +284,8 @@ def __call__(self, model):
grad_clip=self.grad_clip,
name=self.name,
lazy_mode=self.lazy_mode,
- apply_decay_param_fun=self._apply_decay_param_fun)
+ apply_decay_param_fun=self._apply_decay_param_fun,
+ )
return opt
def _apply_decay_param_fun(self, name):
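The truncated `_apply_decay_param_fun` above is the predicate AdamW consults per parameter: return False to exempt a parameter (for example one matched via `no_weight_decay_name`) from weight decay. A sketch of the idea; the closure form and names here are illustrative, not the exact method body:

def make_apply_decay_fun(no_decay_names):
    # True -> apply weight decay; False -> skip (norm scales, biases, ...).
    def apply_decay_param_fun(param_name):
        return param_name not in no_decay_names
    return apply_decay_param_fun

fn = make_apply_decay_fun({"layer_norm_0.w_0"})
assert fn("conv2d_0.w_0") is True and fn("layer_norm_0.w_0") is False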
diff --git a/ppocr/optimizer/regularizer.py b/ppocr/optimizer/regularizer.py
index 2ce68f7139..740ad1cfe1 100644
--- a/ppocr/optimizer/regularizer.py
+++ b/ppocr/optimizer/regularizer.py
@@ -48,4 +48,4 @@ def __init__(self, factor=0.0):
self.coeff = float(factor)
def __call__(self):
- return self.coeff
\ No newline at end of file
+ return self.coeff
diff --git a/ppocr/postprocess/__init__.py b/ppocr/postprocess/__init__.py
index f05316cacd..e0a6a87fd3 100644
--- a/ppocr/postprocess/__init__.py
+++ b/ppocr/postprocess/__init__.py
@@ -19,21 +19,40 @@
import copy
-__all__ = ['build_post_process']
+__all__ = ["build_post_process"]
from .db_postprocess import DBPostProcess, DistillationDBPostProcess
from .east_postprocess import EASTPostProcess
from .sast_postprocess import SASTPostProcess
from .fce_postprocess import FCEPostProcess
-from .rec_postprocess import CTCLabelDecode, AttnLabelDecode, SRNLabelDecode, \
- DistillationCTCLabelDecode, NRTRLabelDecode, SARLabelDecode, \
- SEEDLabelDecode, PRENLabelDecode, ViTSTRLabelDecode, ABINetLabelDecode, \
- SPINLabelDecode, VLLabelDecode, RFLLabelDecode, SATRNLabelDecode, ParseQLabelDecode, \
- CPPDLabelDecode
+from .rec_postprocess import (
+ CTCLabelDecode,
+ AttnLabelDecode,
+ SRNLabelDecode,
+ DistillationCTCLabelDecode,
+ NRTRLabelDecode,
+ SARLabelDecode,
+ SEEDLabelDecode,
+ PRENLabelDecode,
+ ViTSTRLabelDecode,
+ ABINetLabelDecode,
+ SPINLabelDecode,
+ VLLabelDecode,
+ RFLLabelDecode,
+ SATRNLabelDecode,
+ ParseQLabelDecode,
+ CPPDLabelDecode,
+)
from .cls_postprocess import ClsPostProcess
from .pg_postprocess import PGPostProcess
-from .vqa_token_ser_layoutlm_postprocess import VQASerTokenLayoutLMPostProcess, DistillationSerPostProcess
-from .vqa_token_re_layoutlm_postprocess import VQAReTokenLayoutLMPostProcess, DistillationRePostProcess
+from .vqa_token_ser_layoutlm_postprocess import (
+ VQASerTokenLayoutLMPostProcess,
+ DistillationSerPostProcess,
+)
+from .vqa_token_re_layoutlm_postprocess import (
+ VQAReTokenLayoutLMPostProcess,
+ DistillationRePostProcess,
+)
from .table_postprocess import TableMasterLabelDecode, TableLabelDecode
from .picodet_postprocess import PicoDetPostProcess
from .ct_postprocess import CTPostProcess
@@ -43,31 +62,55 @@
def build_post_process(config, global_config=None):
support_dict = [
- 'DBPostProcess', 'EASTPostProcess', 'SASTPostProcess', 'FCEPostProcess',
- 'CTCLabelDecode', 'AttnLabelDecode', 'ClsPostProcess', 'SRNLabelDecode',
- 'PGPostProcess', 'DistillationCTCLabelDecode', 'TableLabelDecode',
- 'DistillationDBPostProcess', 'NRTRLabelDecode', 'SARLabelDecode',
- 'SEEDLabelDecode', 'VQASerTokenLayoutLMPostProcess',
- 'VQAReTokenLayoutLMPostProcess', 'PRENLabelDecode',
- 'DistillationSARLabelDecode', 'ViTSTRLabelDecode', 'ABINetLabelDecode',
- 'TableMasterLabelDecode', 'SPINLabelDecode',
- 'DistillationSerPostProcess', 'DistillationRePostProcess',
- 'VLLabelDecode', 'PicoDetPostProcess', 'CTPostProcess',
- 'RFLLabelDecode', 'DRRGPostprocess', 'CANLabelDecode',
- 'SATRNLabelDecode', 'ParseQLabelDecode', 'CPPDLabelDecode'
+ "DBPostProcess",
+ "EASTPostProcess",
+ "SASTPostProcess",
+ "FCEPostProcess",
+ "CTCLabelDecode",
+ "AttnLabelDecode",
+ "ClsPostProcess",
+ "SRNLabelDecode",
+ "PGPostProcess",
+ "DistillationCTCLabelDecode",
+ "TableLabelDecode",
+ "DistillationDBPostProcess",
+ "NRTRLabelDecode",
+ "SARLabelDecode",
+ "SEEDLabelDecode",
+ "VQASerTokenLayoutLMPostProcess",
+ "VQAReTokenLayoutLMPostProcess",
+ "PRENLabelDecode",
+ "DistillationSARLabelDecode",
+ "ViTSTRLabelDecode",
+ "ABINetLabelDecode",
+ "TableMasterLabelDecode",
+ "SPINLabelDecode",
+ "DistillationSerPostProcess",
+ "DistillationRePostProcess",
+ "VLLabelDecode",
+ "PicoDetPostProcess",
+ "CTPostProcess",
+ "RFLLabelDecode",
+ "DRRGPostprocess",
+ "CANLabelDecode",
+ "SATRNLabelDecode",
+ "ParseQLabelDecode",
+ "CPPDLabelDecode",
]
- if config['name'] == 'PSEPostProcess':
+ if config["name"] == "PSEPostProcess":
from .pse_postprocess import PSEPostProcess
- support_dict.append('PSEPostProcess')
+
+ support_dict.append("PSEPostProcess")
config = copy.deepcopy(config)
- module_name = config.pop('name')
+ module_name = config.pop("name")
if module_name == "None":
return
if global_config is not None:
config.update(global_config)
assert module_name in support_dict, Exception(
- 'post process only support {}'.format(support_dict))
+ "post process only support {}".format(support_dict)
+ )
module_class = eval(module_name)(**config)
return module_class
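
# Usage sketch: build_post_process() pops "name" from the config to select a
# class from support_dict and forwards the remaining keys as keyword
# arguments. The threshold values below are illustrative, not defaults taken
# from a shipped config file.
from ppocr.postprocess import build_post_process

db_config = {
    "name": "DBPostProcess",
    "thresh": 0.3,
    "box_thresh": 0.6,
    "unclip_ratio": 1.5,
}
db_postprocess = build_post_process(db_config)  # -> a DBPostProcess instance
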
diff --git a/ppocr/postprocess/cls_postprocess.py b/ppocr/postprocess/cls_postprocess.py
index 9a27ba0831..06c7693a90 100644
--- a/ppocr/postprocess/cls_postprocess.py
+++ b/ppocr/postprocess/cls_postprocess.py
@@ -15,7 +15,7 @@
class ClsPostProcess(object):
- """ Convert between text-label and text-index """
+ """Convert between text-label and text-index"""
def __init__(self, label_list=None, key=None, **kwargs):
super(ClsPostProcess, self).__init__()
@@ -34,8 +34,9 @@ def __call__(self, preds, label=None, *args, **kwargs):
preds = preds.numpy()
pred_idxs = preds.argmax(axis=1)
- decode_out = [(label_list[idx], preds[i, idx])
- for i, idx in enumerate(pred_idxs)]
+ decode_out = [
+ (label_list[idx], preds[i, idx]) for i, idx in enumerate(pred_idxs)
+ ]
if label is None:
return decode_out
label = [(label_list[idx], 1.0) for idx in label]
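
# Usage sketch: ClsPostProcess pairs each argmax label with its score. The
# two-class label_list below assumes the usual 0/180-degree text-direction
# classifier; it is illustrative, not read from a config.
import numpy as np
from ppocr.postprocess.cls_postprocess import ClsPostProcess

cls_postprocess = ClsPostProcess(label_list=["0", "180"])
preds = np.array([[0.9, 0.1], [0.2, 0.8]], dtype=np.float32)
print(cls_postprocess(preds))  # label/score pairs: ('0', 0.9), ('180', 0.8)
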
diff --git a/ppocr/postprocess/ct_postprocess.py b/ppocr/postprocess/ct_postprocess.py
index 3ab90be24d..ede05e2fad 100755
--- a/ppocr/postprocess/ct_postprocess.py
+++ b/ppocr/postprocess/ct_postprocess.py
@@ -33,7 +33,7 @@ class CTPostProcess(object):
The post process for Centripetal Text (CT).
"""
- def __init__(self, min_score=0.88, min_area=16, box_type='poly', **kwargs):
+ def __init__(self, min_score=0.88, min_area=16, box_type="poly", **kwargs):
self.min_score = min_score
self.min_area = min_area
self.box_type = box_type
@@ -45,8 +45,8 @@ def __init__(self, min_score=0.88, min_area=16, box_type='poly', **kwargs):
self.coord[1, i, j] = i
def __call__(self, preds, batch):
- outs = preds['maps']
- out_scores = preds['score']
+ outs = preds["maps"]
+ out_scores = preds["score"]
if isinstance(outs, paddle.Tensor):
outs = outs.numpy()
@@ -77,13 +77,11 @@ def __call__(self, preds, batch):
kernel = kernel[0].astype(np.uint8)
loc = loc[0].astype(np.float32)
- label_num, label_kernel = cv2.connectedComponents(
- kernel, connectivity=4)
+ label_num, label_kernel = cv2.connectedComponents(kernel, connectivity=4)
for i in range(1, label_num):
- ind = (label_kernel == i)
- if ind.sum(
- ) < 10: # pixel number less than 10, treated as background
+ ind = label_kernel == i
+            if ind.sum() < 10:  # fewer than 10 pixels: treated as background
label_kernel[ind] = 0
label = np.zeros_like(label_kernel)
@@ -91,18 +89,20 @@ def __call__(self, preds, batch):
pixels = self.coord[:, :h, :w].reshape(2, -1)
points = pixels.transpose([1, 0]).astype(np.float32)
- off_points = (points + 10. / 4. * loc[:, pixels[1], pixels[0]].T
- ).astype(np.int32)
+ off_points = (points + 10.0 / 4.0 * loc[:, pixels[1], pixels[0]].T).astype(
+ np.int32
+ )
off_points[:, 0] = np.clip(off_points[:, 0], 0, label.shape[1] - 1)
off_points[:, 1] = np.clip(off_points[:, 1], 0, label.shape[0] - 1)
- label[pixels[1], pixels[0]] = label_kernel[off_points[:, 1],
- off_points[:, 0]]
+ label[pixels[1], pixels[0]] = label_kernel[
+ off_points[:, 1], off_points[:, 0]
+ ]
label[label_kernel > 0] = label_kernel[label_kernel > 0]
score_pocket = [0.0]
for i in range(1, label_num):
- ind = (label_kernel == i)
+ ind = label_kernel == i
if ind.sum() == 0:
score_pocket.append(0.0)
continue
@@ -111,14 +111,16 @@ def __call__(self, preds, batch):
label_num = np.max(label) + 1
label = cv2.resize(
- label, (img_size[1], img_size[0]),
- interpolation=cv2.INTER_NEAREST)
+ label, (img_size[1], img_size[0]), interpolation=cv2.INTER_NEAREST
+ )
- scale = (float(org_img_size[1]) / float(img_size[1]),
- float(org_img_size[0]) / float(img_size[0]))
+ scale = (
+ float(org_img_size[1]) / float(img_size[1]),
+ float(org_img_size[0]) / float(img_size[0]),
+ )
for i in range(1, label_num):
- ind = (label == i)
+ ind = label == i
points = np.array(np.where(ind)).transpose((1, 0))
if points.shape[0] < self.min_area:
@@ -128,27 +130,29 @@ def __call__(self, preds, batch):
if score_i < self.min_score:
continue
- if self.box_type == 'rect':
+ if self.box_type == "rect":
rect = cv2.minAreaRect(points[:, ::-1])
bbox = cv2.boxPoints(rect) * scale
z = bbox.mean(0)
bbox = z + (bbox - z) * 0.85
- elif self.box_type == 'poly':
- binary = np.zeros(label.shape, dtype='uint8')
+ elif self.box_type == "poly":
+ binary = np.zeros(label.shape, dtype="uint8")
binary[ind] = 1
try:
_, contours, _ = cv2.findContours(
- binary, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
+ binary, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE
+ )
except BaseException:
contours, _ = cv2.findContours(
- binary, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
+ binary, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE
+ )
bbox = contours[0] * scale
- bbox = bbox.astype('int32')
+ bbox = bbox.astype("int32")
bboxes.append(bbox.reshape(-1, 2))
scores.append(score_i)
- boxes_batch.append({'points': bboxes})
+ boxes_batch.append({"points": bboxes})
return boxes_batch
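
# Illustration: the kernel-labeling step above keeps only connected
# components of at least 10 pixels and zeroes out smaller ones as background.
# A minimal standalone reproduction of that filter:
import cv2
import numpy as np

kernel = np.zeros((8, 8), np.uint8)
kernel[1:5, 1:5] = 1  # 16-pixel component: kept
kernel[6, 6] = 1      # 1-pixel component: zeroed as background
label_num, label_kernel = cv2.connectedComponents(kernel, connectivity=4)
for i in range(1, label_num):
    if (label_kernel == i).sum() < 10:
        label_kernel[label_kernel == i] = 0
print(np.unique(label_kernel))  # [0 1]
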
diff --git a/ppocr/postprocess/db_postprocess.py b/ppocr/postprocess/db_postprocess.py
index 244825b76a..cdd050d190 100755
--- a/ppocr/postprocess/db_postprocess.py
+++ b/ppocr/postprocess/db_postprocess.py
@@ -31,15 +31,17 @@ class DBPostProcess(object):
The post process for Differentiable Binarization (DB).
"""
- def __init__(self,
- thresh=0.3,
- box_thresh=0.7,
- max_candidates=1000,
- unclip_ratio=2.0,
- use_dilation=False,
- score_mode="fast",
- box_type='quad',
- **kwargs):
+ def __init__(
+ self,
+ thresh=0.3,
+ box_thresh=0.7,
+ max_candidates=1000,
+ unclip_ratio=2.0,
+ use_dilation=False,
+ score_mode="fast",
+ box_type="quad",
+ **kwargs
+ ):
self.thresh = thresh
self.box_thresh = box_thresh
self.max_candidates = max_candidates
@@ -48,17 +50,17 @@ def __init__(self,
self.score_mode = score_mode
self.box_type = box_type
assert score_mode in [
- "slow", "fast"
+ "slow",
+ "fast",
], "Score mode must be in [slow, fast] but got: {}".format(score_mode)
- self.dilation_kernel = None if not use_dilation else np.array(
- [[1, 1], [1, 1]])
+ self.dilation_kernel = None if not use_dilation else np.array([[1, 1], [1, 1]])
def polygons_from_bitmap(self, pred, _bitmap, dest_width, dest_height):
- '''
+ """
_bitmap: single map with shape (1, H, W),
whose values are binarized as {0, 1}
- '''
+ """
bitmap = _bitmap
height, width = bitmap.shape
@@ -66,10 +68,11 @@ def polygons_from_bitmap(self, pred, _bitmap, dest_width, dest_height):
boxes = []
scores = []
- contours, _ = cv2.findContours((bitmap * 255).astype(np.uint8),
- cv2.RETR_LIST, cv2.CHAIN_APPROX_SIMPLE)
+ contours, _ = cv2.findContours(
+ (bitmap * 255).astype(np.uint8), cv2.RETR_LIST, cv2.CHAIN_APPROX_SIMPLE
+ )
- for contour in contours[:self.max_candidates]:
+ for contour in contours[: self.max_candidates]:
epsilon = 0.002 * cv2.arcLength(contour, True)
approx = cv2.approxPolyDP(contour, epsilon, True)
points = approx.reshape((-1, 2))
@@ -93,25 +96,26 @@ def polygons_from_bitmap(self, pred, _bitmap, dest_width, dest_height):
continue
box = np.array(box)
- box[:, 0] = np.clip(
- np.round(box[:, 0] / width * dest_width), 0, dest_width)
+ box[:, 0] = np.clip(np.round(box[:, 0] / width * dest_width), 0, dest_width)
box[:, 1] = np.clip(
- np.round(box[:, 1] / height * dest_height), 0, dest_height)
+ np.round(box[:, 1] / height * dest_height), 0, dest_height
+ )
boxes.append(box.tolist())
scores.append(score)
return boxes, scores
def boxes_from_bitmap(self, pred, _bitmap, dest_width, dest_height):
- '''
+ """
_bitmap: single map with shape (1, H, W),
whose values are binarized as {0, 1}
- '''
+ """
bitmap = _bitmap
height, width = bitmap.shape
- outs = cv2.findContours((bitmap * 255).astype(np.uint8), cv2.RETR_LIST,
- cv2.CHAIN_APPROX_SIMPLE)
+ outs = cv2.findContours(
+ (bitmap * 255).astype(np.uint8), cv2.RETR_LIST, cv2.CHAIN_APPROX_SIMPLE
+ )
if len(outs) == 3:
img, contours, _ = outs[0], outs[1], outs[2]
elif len(outs) == 2:
@@ -140,10 +144,10 @@ def boxes_from_bitmap(self, pred, _bitmap, dest_width, dest_height):
continue
box = np.array(box)
- box[:, 0] = np.clip(
- np.round(box[:, 0] / width * dest_width), 0, dest_width)
+ box[:, 0] = np.clip(np.round(box[:, 0] / width * dest_width), 0, dest_width)
box[:, 1] = np.clip(
- np.round(box[:, 1] / height * dest_height), 0, dest_height)
+ np.round(box[:, 1] / height * dest_height), 0, dest_height
+ )
boxes.append(box.astype("int32"))
scores.append(score)
return np.array(boxes, dtype="int32"), scores
@@ -174,15 +178,13 @@ def get_mini_boxes(self, contour):
index_2 = 3
index_3 = 2
- box = [
- points[index_1], points[index_2], points[index_3], points[index_4]
- ]
+ box = [points[index_1], points[index_2], points[index_3], points[index_4]]
return box, min(bounding_box[1])
def box_score_fast(self, bitmap, _box):
- '''
+ """
box_score_fast: use bbox mean score as the mean score
- '''
+ """
h, w = bitmap.shape[:2]
box = _box.copy()
xmin = np.clip(np.floor(box[:, 0].min()).astype("int32"), 0, w - 1)
@@ -194,12 +196,12 @@ def box_score_fast(self, bitmap, _box):
box[:, 0] = box[:, 0] - xmin
box[:, 1] = box[:, 1] - ymin
cv2.fillPoly(mask, box.reshape(1, -1, 2).astype("int32"), 1)
- return cv2.mean(bitmap[ymin:ymax + 1, xmin:xmax + 1], mask)[0]
+ return cv2.mean(bitmap[ymin : ymax + 1, xmin : xmax + 1], mask)[0]
def box_score_slow(self, bitmap, contour):
- '''
+ """
        box_score_slow: use polygon mean score as the mean score
- '''
+ """
h, w = bitmap.shape[:2]
contour = contour.copy()
contour = np.reshape(contour, (-1, 2))
@@ -215,10 +217,10 @@ def box_score_slow(self, bitmap, contour):
contour[:, 1] = contour[:, 1] - ymin
cv2.fillPoly(mask, contour.reshape(1, -1, 2).astype("int32"), 1)
- return cv2.mean(bitmap[ymin:ymax + 1, xmin:xmax + 1], mask)[0]
+ return cv2.mean(bitmap[ymin : ymax + 1, xmin : xmax + 1], mask)[0]
def __call__(self, outs_dict, shape_list):
- pred = outs_dict['maps']
+ pred = outs_dict["maps"]
if isinstance(pred, paddle.Tensor):
pred = pred.numpy()
pred = pred[:, 0, :, :]
@@ -230,34 +232,39 @@ def __call__(self, outs_dict, shape_list):
if self.dilation_kernel is not None:
mask = cv2.dilate(
np.array(segmentation[batch_index]).astype(np.uint8),
- self.dilation_kernel)
+ self.dilation_kernel,
+ )
else:
mask = segmentation[batch_index]
- if self.box_type == 'poly':
- boxes, scores = self.polygons_from_bitmap(pred[batch_index],
- mask, src_w, src_h)
- elif self.box_type == 'quad':
- boxes, scores = self.boxes_from_bitmap(pred[batch_index], mask,
- src_w, src_h)
+ if self.box_type == "poly":
+ boxes, scores = self.polygons_from_bitmap(
+ pred[batch_index], mask, src_w, src_h
+ )
+ elif self.box_type == "quad":
+ boxes, scores = self.boxes_from_bitmap(
+ pred[batch_index], mask, src_w, src_h
+ )
else:
raise ValueError("box_type can only be one of ['quad', 'poly']")
- boxes_batch.append({'points': boxes})
+ boxes_batch.append({"points": boxes})
return boxes_batch
class DistillationDBPostProcess(object):
- def __init__(self,
- model_name=["student"],
- key=None,
- thresh=0.3,
- box_thresh=0.6,
- max_candidates=1000,
- unclip_ratio=1.5,
- use_dilation=False,
- score_mode="fast",
- box_type='quad',
- **kwargs):
+ def __init__(
+ self,
+ model_name=["student"],
+ key=None,
+ thresh=0.3,
+ box_thresh=0.6,
+ max_candidates=1000,
+ unclip_ratio=1.5,
+ use_dilation=False,
+ score_mode="fast",
+ box_type="quad",
+ **kwargs
+ ):
self.model_name = model_name
self.key = key
self.post_process = DBPostProcess(
@@ -267,7 +274,8 @@ def __init__(self,
unclip_ratio=unclip_ratio,
use_dilation=use_dilation,
score_mode=score_mode,
- box_type=box_type)
+ box_type=box_type,
+ )
def __call__(self, predicts, shape_list):
results = {}
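
# Usage sketch: DBPostProcess reads the probability map from
# outs_dict["maps"] and a shape_list of [src_h, src_w, ratio_h, ratio_w]
# rows, returning one {"points": ...} dict per image. The all-ones map below
# is synthetic input chosen to yield a single full-frame box; it is not real
# model output.
import numpy as np
from ppocr.postprocess.db_postprocess import DBPostProcess

db = DBPostProcess(thresh=0.3, box_thresh=0.6, box_type="quad")
outs_dict = {"maps": np.ones((1, 1, 160, 160), dtype=np.float32)}
shape_list = np.array([[640.0, 640.0, 0.25, 0.25]])
boxes_batch = db(outs_dict, shape_list)
print(boxes_batch[0]["points"].shape)  # (1, 4, 2): one quad of 4 points
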
diff --git a/ppocr/postprocess/drrg_postprocess.py b/ppocr/postprocess/drrg_postprocess.py
index 56fd034f7c..f1241c49d8 100644
--- a/ppocr/postprocess/drrg_postprocess.py
+++ b/ppocr/postprocess/drrg_postprocess.py
@@ -43,7 +43,7 @@ def add_link(self, link_node):
link_node.__links.add(self)
-def graph_propagation(edges, scores, text_comps, edge_len_thr=50.):
+def graph_propagation(edges, scores, text_comps, edge_len_thr=50.0):
assert edges.ndim == 2
assert edges.shape[1] == 2
assert edges.shape[0] == scores.shape[0]
@@ -63,7 +63,8 @@ def graph_propagation(edges, scores, text_comps, edge_len_thr=50.):
scores[i] = 0
if (edge[0], edge[1]) in score_dict:
score_dict[edge[0], edge[1]] = 0.5 * (
- score_dict[edge[0], edge[1]] + scores[i])
+ score_dict[edge[0], edge[1]] + scores[i]
+ )
else:
score_dict[edge[0], edge[1]] = scores[i]
@@ -92,10 +93,13 @@ def connected_components(nodes, score_dict, link_thr):
node_queue = [node]
while node_queue:
node = node_queue.pop(0)
- neighbors = set([
- neighbor for neighbor in node.links if
- score_dict[tuple(sorted([node.ind, neighbor.ind]))] >= link_thr
- ])
+ neighbors = set(
+ [
+ neighbor
+ for neighbor in node.links
+ if score_dict[tuple(sorted([node.ind, neighbor.ind]))] >= link_thr
+ ]
+ )
neighbors.difference_update(cluster)
nodes.difference_update(neighbors)
cluster.update(neighbors)
@@ -107,8 +111,7 @@ def connected_components(nodes, score_dict, link_thr):
def clusters2labels(clusters, num_nodes):
assert isinstance(clusters, list)
assert all([isinstance(cluster, list) for cluster in clusters])
- assert all(
- [isinstance(node, Node) for cluster in clusters for node in cluster])
+ assert all([isinstance(node, Node) for cluster in clusters for node in cluster])
assert isinstance(num_nodes, int)
node_labels = np.zeros(num_nodes)
@@ -125,7 +128,7 @@ def remove_single(text_comps, comp_pred_labels):
single_flags = np.zeros_like(comp_pred_labels)
pred_labels = np.unique(comp_pred_labels)
for label in pred_labels:
- current_label_flag = (comp_pred_labels == label)
+ current_label_flag = comp_pred_labels == label
if np.sum(current_label_flag) == 1:
single_flags[np.where(current_label_flag)[0][0]] = 1
keep_ind = [i for i in range(len(comp_pred_labels)) if not single_flags[i]]
@@ -136,7 +139,7 @@ def remove_single(text_comps, comp_pred_labels):
def norm2(point1, point2):
- return ((point1[0] - point2[0])**2 + (point1[1] - point2[1])**2)**0.5
+ return ((point1[0] - point2[0]) ** 2 + (point1[1] - point2[1]) ** 2) ** 0.5
def min_connect_path(points):
@@ -225,8 +228,9 @@ def comps2boundaries(text_comps, comp_pred_labels):
return boundaries
for cluster_ind in range(0, int(np.max(comp_pred_labels)) + 1):
cluster_comp_inds = np.where(comp_pred_labels == cluster_ind)
- text_comp_boxes = text_comps[cluster_comp_inds, :8].reshape(
- (-1, 4, 2)).astype(np.int32)
+ text_comp_boxes = (
+ text_comps[cluster_comp_inds, :8].reshape((-1, 4, 2)).astype(np.int32)
+ )
score = np.mean(text_comps[cluster_comp_inds, -1])
if text_comp_boxes.shape[0] < 1:
@@ -236,12 +240,15 @@ def comps2boundaries(text_comps, comp_pred_labels):
centers = np.mean(text_comp_boxes, axis=1).astype(np.int32).tolist()
shortest_path = min_connect_path(centers)
text_comp_boxes = text_comp_boxes[shortest_path]
- top_line = np.mean(
- text_comp_boxes[:, 0:2, :], axis=1).astype(np.int32).tolist()
- bot_line = np.mean(
- text_comp_boxes[:, 2:4, :], axis=1).astype(np.int32).tolist()
+ top_line = (
+ np.mean(text_comp_boxes[:, 0:2, :], axis=1).astype(np.int32).tolist()
+ )
+ bot_line = (
+ np.mean(text_comp_boxes[:, 2:4, :], axis=1).astype(np.int32).tolist()
+ )
top_line, bot_line = fix_corner(
- top_line, bot_line, text_comp_boxes[0], text_comp_boxes[-1])
+ top_line, bot_line, text_comp_boxes[0], text_comp_boxes[-1]
+ )
boundary_points = top_line + bot_line[::-1]
else:
@@ -298,7 +305,8 @@ def __call__(self, preds, shape_list):
boundaries = []
boundaries, scores = self.resize_boundary(
- boundaries, (1 / shape_list[0, 2:]).tolist()[::-1])
+ boundaries, (1 / shape_list[0, 2:]).tolist()[::-1]
+ )
boxes_batch = [dict(points=boundaries, scores=scores)]
return boxes_batch
@@ -318,8 +326,13 @@ def resize_boundary(self, boundaries, scale_factor):
for b in boundaries:
sz = len(b)
scores.append(b[-1])
- b = (np.array(b[:sz - 1]) *
- (np.tile(scale_factor[:2], int(
- (sz - 1) / 2)).reshape(1, sz - 1))).flatten().tolist()
+ b = (
+ (
+ np.array(b[: sz - 1])
+ * (np.tile(scale_factor[:2], int((sz - 1) / 2)).reshape(1, sz - 1))
+ )
+ .flatten()
+ .tolist()
+ )
boxes.append(np.array(b).reshape([-1, 2]))
return boxes, scores
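
# Illustration: resize_boundary() treats each boundary as a flat
# [x0, y0, x1, y1, ..., score] list and rescales the coordinates by tiling
# (scale_x, scale_y) across the point pairs, exactly as in the rewrite above.
import numpy as np

b = [10.0, 20.0, 30.0, 40.0, 0.9]  # two points plus a trailing score
scale_factor = [0.5, 2.0]          # (scale_x, scale_y)
sz = len(b)
coords = (
    np.array(b[: sz - 1])
    * np.tile(scale_factor[:2], int((sz - 1) / 2)).reshape(1, sz - 1)
).flatten()
print(coords.reshape(-1, 2))  # [[ 5. 40.] [15. 80.]]
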
diff --git a/ppocr/postprocess/east_postprocess.py b/ppocr/postprocess/east_postprocess.py
index c1af3eccef..0b138adba1 100755
--- a/ppocr/postprocess/east_postprocess.py
+++ b/ppocr/postprocess/east_postprocess.py
@@ -31,12 +31,7 @@ class EASTPostProcess(object):
The post process for EAST.
"""
- def __init__(self,
- score_thresh=0.8,
- cover_thresh=0.1,
- nms_thresh=0.2,
- **kwargs):
-
+ def __init__(self, score_thresh=0.8, cover_thresh=0.1, nms_thresh=0.2, **kwargs):
self.score_thresh = score_thresh
self.cover_thresh = cover_thresh
self.nms_thresh = nms_thresh
@@ -47,17 +42,15 @@ def restore_rectangle_quad(self, origin, geometry):
"""
# quad
origin_concat = np.concatenate(
- (origin, origin, origin, origin), axis=1) # (n, 8)
+ (origin, origin, origin, origin), axis=1
+ ) # (n, 8)
pred_quads = origin_concat - geometry
pred_quads = pred_quads.reshape((-1, 4, 2)) # (n, 4, 2)
return pred_quads
- def detect(self,
- score_map,
- geo_map,
- score_thresh=0.8,
- cover_thresh=0.1,
- nms_thresh=0.2):
+ def detect(
+ self, score_map, geo_map, score_thresh=0.8, cover_thresh=0.1, nms_thresh=0.2
+ ):
"""
restore text boxes from score map and geo map
"""
@@ -71,30 +64,31 @@ def detect(self,
return []
# sort the text boxes via the y axis
xy_text = xy_text[np.argsort(xy_text[:, 0])]
- #restore quad proposals
+ # restore quad proposals
text_box_restored = self.restore_rectangle_quad(
- xy_text[:, ::-1] * 4, geo_map[xy_text[:, 0], xy_text[:, 1], :])
+ xy_text[:, ::-1] * 4, geo_map[xy_text[:, 0], xy_text[:, 1], :]
+ )
boxes = np.zeros((text_box_restored.shape[0], 9), dtype=np.float32)
boxes[:, :8] = text_box_restored.reshape((-1, 8))
boxes[:, 8] = score_map[xy_text[:, 0], xy_text[:, 1]]
try:
- check_install('lanms', 'lanms-nova')
+ check_install("lanms", "lanms-nova")
import lanms
+
boxes = lanms.merge_quadrangle_n9(boxes, nms_thresh)
except:
print(
- 'You should install lanms by pip3 install lanms-nova to speed up nms_locality'
+ "You should install lanms by pip3 install lanms-nova to speed up nms_locality"
)
boxes = nms_locality(boxes.astype(np.float64), nms_thresh)
if boxes.shape[0] == 0:
return []
- # Here we filter some low score boxes by the average score map,
+ # Here we filter some low score boxes by the average score map,
    # this is different from the original paper.
for i, box in enumerate(boxes):
mask = np.zeros_like(score_map, dtype=np.uint8)
- cv2.fillPoly(mask, box[:8].reshape(
- (-1, 4, 2)).astype(np.int32) // 4, 1)
+ cv2.fillPoly(mask, box[:8].reshape((-1, 4, 2)).astype(np.int32) // 4, 1)
boxes[i, 8] = cv2.mean(score_map, mask)[0]
boxes = boxes[boxes[:, 8] > cover_thresh]
return boxes
@@ -104,16 +98,15 @@ def sort_poly(self, p):
Sort polygons.
"""
min_axis = np.argmin(np.sum(p, axis=1))
- p = p[[min_axis, (min_axis + 1) % 4,\
- (min_axis + 2) % 4, (min_axis + 3) % 4]]
+ p = p[[min_axis, (min_axis + 1) % 4, (min_axis + 2) % 4, (min_axis + 3) % 4]]
if abs(p[0, 0] - p[1, 0]) > abs(p[0, 1] - p[1, 1]):
return p
else:
return p[[0, 3, 2, 1]]
def __call__(self, outs_dict, shape_list):
- score_list = outs_dict['f_score']
- geo_list = outs_dict['f_geo']
+ score_list = outs_dict["f_score"]
+ geo_list = outs_dict["f_geo"]
if isinstance(score_list, paddle.Tensor):
score_list = score_list.numpy()
geo_list = geo_list.numpy()
@@ -127,7 +120,8 @@ def __call__(self, outs_dict, shape_list):
geo_map=geo,
score_thresh=self.score_thresh,
cover_thresh=self.cover_thresh,
- nms_thresh=self.nms_thresh)
+ nms_thresh=self.nms_thresh,
+ )
boxes_norm = []
if len(boxes) > 0:
h, w = score.shape[1:]
@@ -137,9 +131,11 @@ def __call__(self, outs_dict, shape_list):
boxes[:, :, 1] /= ratio_h
for i_box, box in enumerate(boxes):
box = self.sort_poly(box.astype(np.int32))
- if np.linalg.norm(box[0] - box[1]) < 5 \
- or np.linalg.norm(box[3] - box[0]) < 5:
+ if (
+ np.linalg.norm(box[0] - box[1]) < 5
+ or np.linalg.norm(box[3] - box[0]) < 5
+ ):
continue
boxes_norm.append(box)
- dt_boxes_list.append({'points': np.array(boxes_norm)})
+ dt_boxes_list.append({"points": np.array(boxes_norm)})
return dt_boxes_list
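
# Illustration: sort_poly() rotates a quad so it starts at the vertex with
# the smallest x + y, then picks the orientation whose first edge is the more
# horizontal one. Self-contained copy of the method body:
import numpy as np

def sort_poly(p):
    min_axis = np.argmin(np.sum(p, axis=1))
    p = p[[min_axis, (min_axis + 1) % 4, (min_axis + 2) % 4, (min_axis + 3) % 4]]
    if abs(p[0, 0] - p[1, 0]) > abs(p[0, 1] - p[1, 1]):
        return p
    return p[[0, 3, 2, 1]]

quad = np.array([[100, 0], [100, 50], [0, 50], [0, 0]])
print(sort_poly(quad))  # starts at (0, 0): [[0 0] [100 0] [100 50] [0 50]]
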
diff --git a/ppocr/postprocess/fce_postprocess.py b/ppocr/postprocess/fce_postprocess.py
index 959f86efa4..ccdedc015d 100755
--- a/ppocr/postprocess/fce_postprocess.py
+++ b/ppocr/postprocess/fce_postprocess.py
@@ -26,38 +26,38 @@
def fill_hole(input_mask):
h, w = input_mask.shape
canvas = np.zeros((h + 2, w + 2), np.uint8)
- canvas[1:h + 1, 1:w + 1] = input_mask.copy()
+ canvas[1 : h + 1, 1 : w + 1] = input_mask.copy()
mask = np.zeros((h + 4, w + 4), np.uint8)
cv2.floodFill(canvas, mask, (0, 0), 1)
- canvas = canvas[1:h + 1, 1:w + 1].astype(np.bool_)
+ canvas = canvas[1 : h + 1, 1 : w + 1].astype(np.bool_)
return ~canvas | input_mask
def fourier2poly(fourier_coeff, num_reconstr_points=50):
- """ Inverse Fourier transform
- Args:
- fourier_coeff (ndarray): Fourier coefficients shaped (n, 2k+1),
- with n and k being candidates number and Fourier degree
- respectively.
- num_reconstr_points (int): Number of reconstructed polygon points.
- Returns:
- Polygons (ndarray): The reconstructed polygons shaped (n, n')
- """
+ """Inverse Fourier transform
+ Args:
+ fourier_coeff (ndarray): Fourier coefficients shaped (n, 2k+1),
+ with n and k being candidates number and Fourier degree
+ respectively.
+ num_reconstr_points (int): Number of reconstructed polygon points.
+ Returns:
+ Polygons (ndarray): The reconstructed polygons shaped (n, n')
+ """
- a = np.zeros((len(fourier_coeff), num_reconstr_points), dtype='complex')
+ a = np.zeros((len(fourier_coeff), num_reconstr_points), dtype="complex")
k = (len(fourier_coeff[0]) - 1) // 2
- a[:, 0:k + 1] = fourier_coeff[:, k:]
+ a[:, 0 : k + 1] = fourier_coeff[:, k:]
a[:, -k:] = fourier_coeff[:, :k]
poly_complex = ifft(a) * num_reconstr_points
polygon = np.zeros((len(fourier_coeff), num_reconstr_points, 2))
polygon[:, :, 0] = poly_complex.real
polygon[:, :, 1] = poly_complex.imag
- return polygon.astype('int32').reshape((len(fourier_coeff), -1))
+ return polygon.astype("int32").reshape((len(fourier_coeff), -1))
class FCEPostProcess(object):
@@ -65,18 +65,19 @@ class FCEPostProcess(object):
The post process for FCENet.
"""
- def __init__(self,
- scales,
- fourier_degree=5,
- num_reconstr_points=50,
- decoding_type='fcenet',
- score_thr=0.3,
- nms_thr=0.1,
- alpha=1.0,
- beta=1.0,
- box_type='poly',
- **kwargs):
-
+ def __init__(
+ self,
+ scales,
+ fourier_degree=5,
+ num_reconstr_points=50,
+ decoding_type="fcenet",
+ score_thr=0.3,
+ nms_thr=0.1,
+ alpha=1.0,
+ beta=1.0,
+ box_type="poly",
+ **kwargs
+ ):
self.scales = scales
self.fourier_degree = fourier_degree
self.num_reconstr_points = num_reconstr_points
@@ -115,9 +116,14 @@ def resize_boundary(self, boundaries, scale_factor):
sz = len(b)
valid_boundary(b, True)
scores.append(b[-1])
- b = (np.array(b[:sz - 1]) *
- (np.tile(scale_factor[:2], int(
- (sz - 1) / 2)).reshape(1, sz - 1))).flatten().tolist()
+ b = (
+ (
+ np.array(b[: sz - 1])
+ * (np.tile(scale_factor[:2], int((sz - 1) / 2)).reshape(1, sz - 1))
+ )
+ .flatten()
+ .tolist()
+ )
boxes.append(np.array(b).reshape([-1, 2]))
return np.array(boxes, dtype=np.float32), scores
@@ -127,13 +133,13 @@ def get_boundary(self, score_maps, shape_list):
boundaries = []
for idx, score_map in enumerate(score_maps):
scale = self.scales[idx]
- boundaries = boundaries + self._get_boundary_single(score_map,
- scale)
+ boundaries = boundaries + self._get_boundary_single(score_map, scale)
# nms
boundaries = poly_nms(boundaries, self.nms_thr)
boundaries, scores = self.resize_boundary(
- boundaries, (1 / shape_list[0, 2:]).tolist()[::-1])
+ boundaries, (1 / shape_list[0, 2:]).tolist()[::-1]
+ )
boxes_batch = [dict(points=boundaries, scores=scores)]
return boxes_batch
@@ -151,18 +157,21 @@ def _get_boundary_single(self, score_map, scale):
beta=self.beta,
box_type=self.box_type,
score_thr=self.score_thr,
- nms_thr=self.nms_thr)
-
- def fcenet_decode(self,
- preds,
- fourier_degree,
- num_reconstr_points,
- scale,
- alpha=1.0,
- beta=2.0,
- box_type='poly',
- score_thr=0.3,
- nms_thr=0.1):
+ nms_thr=self.nms_thr,
+ )
+
+ def fcenet_decode(
+ self,
+ preds,
+ fourier_degree,
+ num_reconstr_points,
+ scale,
+ alpha=1.0,
+ beta=2.0,
+ box_type="poly",
+ score_thr=0.3,
+ nms_thr=0.1,
+ ):
"""Decoding predictions of FCENet to instances.
Args:
@@ -186,23 +195,23 @@ def fcenet_decode(self,
"""
assert isinstance(preds, list)
assert len(preds) == 2
- assert box_type in ['poly', 'quad']
+ assert box_type in ["poly", "quad"]
cls_pred = preds[0][0]
tr_pred = cls_pred[0:2]
tcl_pred = cls_pred[2:]
reg_pred = preds[1][0].transpose([1, 2, 0])
- x_pred = reg_pred[:, :, :2 * fourier_degree + 1]
- y_pred = reg_pred[:, :, 2 * fourier_degree + 1:]
+ x_pred = reg_pred[:, :, : 2 * fourier_degree + 1]
+ y_pred = reg_pred[:, :, 2 * fourier_degree + 1 :]
- score_pred = (tr_pred[1]**alpha) * (tcl_pred[1]**beta)
+ score_pred = (tr_pred[1] ** alpha) * (tcl_pred[1] ** beta)
tr_pred_mask = (score_pred) > score_thr
tr_mask = fill_hole(tr_pred_mask)
tr_contours, _ = cv2.findContours(
- tr_mask.astype(np.uint8), cv2.RETR_TREE,
- cv2.CHAIN_APPROX_SIMPLE) # opencv4
+ tr_mask.astype(np.uint8), cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE
+ ) # opencv4
mask = np.zeros_like(tr_mask)
boundaries = []
@@ -228,7 +237,7 @@ def fcenet_decode(self,
boundaries = poly_nms(boundaries, nms_thr)
- if box_type == 'quad':
+ if box_type == "quad":
new_boundaries = []
for boundary in boundaries:
poly = np.array(boundary[:-1]).reshape(-1, 2).astype(np.float32)
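
# Illustration: fourier2poly() places the k + 1 non-negative and k negative
# Fourier coefficients into an FFT buffer and reconstructs polygon points via
# an inverse FFT. With c_0 = 50+50j and c_1 = 20 the reconstruction traces a
# circle of radius 20 around (50, 50). numpy's ifft stands in here for the
# module's own ifft import.
import numpy as np
from numpy.fft import ifft

def fourier2poly(fourier_coeff, num_reconstr_points=50):
    a = np.zeros((len(fourier_coeff), num_reconstr_points), dtype="complex")
    k = (len(fourier_coeff[0]) - 1) // 2
    a[:, 0 : k + 1] = fourier_coeff[:, k:]
    a[:, -k:] = fourier_coeff[:, :k]
    poly_complex = ifft(a) * num_reconstr_points
    polygon = np.zeros((len(fourier_coeff), num_reconstr_points, 2))
    polygon[:, :, 0] = poly_complex.real
    polygon[:, :, 1] = poly_complex.imag
    return polygon.astype("int32").reshape((len(fourier_coeff), -1))

coeff = np.array([[0 + 0j, 50 + 50j, 20 + 0j]])  # [c_-1, c_0, c_1]
print(fourier2poly(coeff, num_reconstr_points=8).reshape(-1, 2))
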
diff --git a/ppocr/postprocess/locality_aware_nms.py b/ppocr/postprocess/locality_aware_nms.py
index d305ef6818..dbd9814cb8 100644
--- a/ppocr/postprocess/locality_aware_nms.py
+++ b/ppocr/postprocess/locality_aware_nms.py
@@ -34,7 +34,7 @@ def intersection_iog(g, p):
if not g.is_valid or not p.is_valid:
return 0
inter = Polygon(g).intersection(Polygon(p)).area
- #union = g.area + p.area - inter
+ # union = g.area + p.area - inter
union = p.area
if union == 0:
print("p_area is very small")
@@ -48,7 +48,7 @@ def weighted_merge(g, p):
Weighted merge.
"""
g[:8] = (g[8] * g[:8] + p[8] * p[:8]) / (g[8] + p[8])
- g[8] = (g[8] + p[8])
+ g[8] = g[8] + p[8]
return g
@@ -126,21 +126,21 @@ def soft_nms(boxes_in, Nt_thres=0.3, threshold=0.8, sigma=0.5, method=2):
tbox = boxes[i].copy()
ti = inds[i]
pos = i + 1
- #get max box
+ # get max box
while pos < N:
if maxscore < boxes[pos, 8]:
maxscore = boxes[pos, 8]
maxpos = pos
pos = pos + 1
- #add max box as a detection
+ # add max box as a detection
boxes[i, :] = boxes[maxpos, :]
inds[i] = inds[maxpos]
- #swap
+ # swap
boxes[maxpos, :] = tbox
inds[maxpos] = ti
tbox = boxes[i].copy()
pos = i + 1
- #NMS iteration
+ # NMS iteration
while pos < N:
sbox = boxes[pos].copy()
ts_iou_val = intersection(tbox, sbox)
@@ -158,8 +158,8 @@ def soft_nms(boxes_in, Nt_thres=0.3, threshold=0.8, sigma=0.5, method=2):
else:
weight = 1
boxes[pos, 8] = weight * boxes[pos, 8]
- #if box score falls below thresold, discard the box by
- #swaping last box update N
+        # if the box score falls below the threshold, discard the box by
+        # swapping in the last box and updating N
if boxes[pos, 8] < threshold:
boxes[pos, :] = boxes[N - 1, :]
inds[pos] = inds[N - 1]
@@ -193,8 +193,6 @@ def nms_locality(polys, thres=0.3):
return standard_nms(np.array(S), thres)
-if __name__ == '__main__':
+if __name__ == "__main__":
# 343,350,448,135,474,143,369,359
- print(
- Polygon(np.array([[343, 350], [448, 135], [474, 143], [369, 359]]))
- .area)
\ No newline at end of file
+ print(Polygon(np.array([[343, 350], [448, 135], [474, 143], [369, 359]])).area)
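
# Illustration: weighted_merge() averages two quads' eight coordinates
# weighted by their scores (index 8) and sums the scores; this merging of
# nearby boxes is what makes the NMS "locality aware".
import numpy as np

g = np.array([0, 0, 10, 0, 10, 10, 0, 10, 0.8])
p = np.array([2, 2, 12, 2, 12, 12, 2, 12, 0.2])
g[:8] = (g[8] * g[:8] + p[8] * p[:8]) / (g[8] + p[8])
g[8] = g[8] + p[8]
print(g)  # coordinates pulled 20% toward p, merged score 1.0
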
diff --git a/ppocr/postprocess/pg_postprocess.py b/ppocr/postprocess/pg_postprocess.py
index 058cf8b907..4001cca680 100644
--- a/ppocr/postprocess/pg_postprocess.py
+++ b/ppocr/postprocess/pg_postprocess.py
@@ -21,7 +21,7 @@
__dir__ = os.path.dirname(__file__)
sys.path.append(__dir__)
-sys.path.append(os.path.join(__dir__, '..'))
+sys.path.append(os.path.join(__dir__, ".."))
from ppocr.utils.e2e_utils.pgnet_pp_utils import PGNet_PostProcess
@@ -30,13 +30,15 @@ class PGPostProcess(object):
The post process for PGNet.
"""
- def __init__(self,
- character_dict_path,
- valid_set,
- score_thresh,
- mode,
- point_gather_mode=None,
- **kwargs):
+ def __init__(
+ self,
+ character_dict_path,
+ valid_set,
+ score_thresh,
+ mode,
+ point_gather_mode=None,
+ **kwargs
+ ):
self.character_dict_path = character_dict_path
self.valid_set = valid_set
self.score_thresh = score_thresh
@@ -55,8 +57,9 @@ def __call__(self, outs_dict, shape_list):
self.score_thresh,
outs_dict,
shape_list,
- point_gather_mode=self.point_gather_mode)
- if self.mode == 'fast':
+ point_gather_mode=self.point_gather_mode,
+ )
+ if self.mode == "fast":
data = post.pg_postprocess_fast()
else:
data = post.pg_postprocess_slow()
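
# Config sketch: PGPostProcess is driven entirely by constructor arguments;
# mode="fast" selects pg_postprocess_fast() as shown above. The dict path and
# valid_set below are illustrative guesses, not taken from a specific
# shipped config.
pgnet_cfg = {
    "name": "PGPostProcess",
    "character_dict_path": "ppocr/utils/ic15_dict.txt",
    "valid_set": "totaltext",
    "score_thresh": 0.5,
    "mode": "fast",
}
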
diff --git a/ppocr/postprocess/picodet_postprocess.py b/ppocr/postprocess/picodet_postprocess.py
index 4053714d30..9189c8ef5d 100644
--- a/ppocr/postprocess/picodet_postprocess.py
+++ b/ppocr/postprocess/picodet_postprocess.py
@@ -41,8 +41,8 @@ def hard_nms(box_scores, iou_threshold, top_k=-1, candidate_size=200):
rest_boxes = boxes[indexes, :]
iou = iou_of(
rest_boxes,
- np.expand_dims(
- current_box, axis=0), )
+ np.expand_dims(current_box, axis=0),
+ )
indexes = indexes[iou <= iou_threshold]
return box_scores[picked, :]
@@ -105,13 +105,15 @@ class PicoDetPostProcess(object):
enable_mkldnn (bool): whether to open MKLDNN
"""
- def __init__(self,
- layout_dict_path,
- strides=[8, 16, 32, 64],
- score_threshold=0.4,
- nms_threshold=0.5,
- nms_top_k=1000,
- keep_top_k=100):
+ def __init__(
+ self,
+ layout_dict_path,
+ strides=[8, 16, 32, 64],
+ score_threshold=0.4,
+ nms_threshold=0.5,
+ nms_top_k=1000,
+ keep_top_k=100,
+ ):
self.labels = self.load_layout_dict(layout_dict_path)
self.strides = strides
self.score_threshold = score_threshold
@@ -120,27 +122,28 @@ def __init__(self,
self.keep_top_k = keep_top_k
def load_layout_dict(self, layout_dict_path):
- with open(layout_dict_path, 'r', encoding='utf-8') as fp:
+ with open(layout_dict_path, "r", encoding="utf-8") as fp:
labels = fp.readlines()
- return [label.strip('\n') for label in labels]
+ return [label.strip("\n") for label in labels]
def warp_boxes(self, boxes, ori_shape):
- """Apply transform to boxes
- """
+ """Apply transform to boxes"""
width, height = ori_shape[1], ori_shape[0]
n = len(boxes)
if n:
# warp points
xy = np.ones((n * 4, 3))
xy[:, :2] = boxes[:, [0, 1, 2, 3, 0, 3, 2, 1]].reshape(
- n * 4, 2) # x1y1, x2y2, x1y2, x2y1
+ n * 4, 2
+ ) # x1y1, x2y2, x1y2, x2y1
# xy = xy @ M.T # transform
xy = (xy[:, :2] / xy[:, 2:3]).reshape(n, 8) # rescale
# create new boxes
x = xy[:, [0, 2, 4, 6]]
y = xy[:, [1, 3, 5, 7]]
- xy = np.concatenate(
- (x.min(1), y.min(1), x.max(1), y.max(1))).reshape(4, n).T
+ xy = (
+ np.concatenate((x.min(1), y.min(1), x.max(1), y.max(1))).reshape(4, n).T
+ )
# clip boxes
xy[:, [0, 2]] = xy[:, [0, 2]].clip(0, width)
xy[:, [1, 3]] = xy[:, [1, 3]].clip(0, height)
@@ -156,13 +159,13 @@ def img_info(self, ori_img, img):
scale_factor = np.array([im_scale_y, im_scale_x], dtype=np.float32)
img_shape = np.array(img.shape[2:], dtype=np.float32)
- input_shape = np.array(img).astype('float32').shape[2:]
- ori_shape = np.array((img_shape, )).astype('float32')
- scale_factor = np.array((scale_factor, )).astype('float32')
+ input_shape = np.array(img).astype("float32").shape[2:]
+ ori_shape = np.array((img_shape,)).astype("float32")
+ scale_factor = np.array((scale_factor,)).astype("float32")
return ori_shape, input_shape, scale_factor
def __call__(self, ori_img, img, preds):
- scores, raw_boxes = preds['boxes'], preds['boxes_num']
+ scores, raw_boxes = preds["boxes"], preds["boxes_num"]
batch_size = raw_boxes[0].shape[0]
reg_max = int(raw_boxes[0].shape[-1] / 4 - 1)
out_boxes_num = []
@@ -174,8 +177,7 @@ def __call__(self, ori_img, img, preds):
# generate centers
decode_boxes = []
select_scores = []
- for stride, box_distribute, score in zip(self.strides, raw_boxes,
- scores):
+ for stride, box_distribute, score in zip(self.strides, raw_boxes, scores):
box_distribute = box_distribute[batch_id]
score = score[batch_id]
# centers
@@ -198,7 +200,7 @@ def __call__(self, ori_img, img, preds):
# top K candidate
topk_idx = np.argsort(score.max(axis=1))[::-1]
- topk_idx = topk_idx[:self.nms_top_k]
+ topk_idx = topk_idx[: self.nms_top_k]
center = center[topk_idx]
score = score[topk_idx]
box_distance = box_distance[topk_idx]
@@ -221,12 +223,12 @@ def __call__(self, ori_img, img, preds):
if probs.shape[0] == 0:
continue
subset_boxes = bboxes[mask, :]
- box_probs = np.concatenate(
- [subset_boxes, probs.reshape(-1, 1)], axis=1)
+ box_probs = np.concatenate([subset_boxes, probs.reshape(-1, 1)], axis=1)
box_probs = hard_nms(
box_probs,
iou_threshold=self.nms_threshold,
- top_k=self.keep_top_k, )
+ top_k=self.keep_top_k,
+ )
picked_box_probs.append(box_probs)
picked_labels.extend([class_index] * box_probs.shape[0])
@@ -239,22 +241,23 @@ def __call__(self, ori_img, img, preds):
# resize output boxes
picked_box_probs[:, :4] = self.warp_boxes(
- picked_box_probs[:, :4], ori_shape[batch_id])
- im_scale = np.concatenate([
- scale_factor[batch_id][::-1], scale_factor[batch_id][::-1]
- ])
+ picked_box_probs[:, :4], ori_shape[batch_id]
+ )
+ im_scale = np.concatenate(
+ [scale_factor[batch_id][::-1], scale_factor[batch_id][::-1]]
+ )
picked_box_probs[:, :4] /= im_scale
# clas score box
out_boxes_list.append(
np.concatenate(
[
- np.expand_dims(
- np.array(picked_labels),
- axis=-1), np.expand_dims(
- picked_box_probs[:, 4], axis=-1),
- picked_box_probs[:, :4]
+ np.expand_dims(np.array(picked_labels), axis=-1),
+ np.expand_dims(picked_box_probs[:, 4], axis=-1),
+ picked_box_probs[:, :4],
],
- axis=1))
+ axis=1,
+ )
+ )
out_boxes_num.append(len(picked_labels))
out_boxes_list = np.concatenate(out_boxes_list, axis=0)
@@ -263,12 +266,12 @@ def __call__(self, ori_img, img, preds):
for dt in out_boxes_list:
clsid, bbox, score = int(dt[0]), dt[2:], dt[1]
label = self.labels[clsid]
- result = {'bbox': bbox, 'label': label, 'score': score}
+ result = {"bbox": bbox, "label": label, "score": score}
results.append(result)
# Handle conflict where a box is simultaneously recognized as multiple labels.
    # Use IoU to find similar boxes. Prioritize labels as table, text, and others when deduplicating similar boxes.
- bboxes = np.array([x['bbox'] for x in results])
+ bboxes = np.array([x["bbox"] for x in results])
duplicate_idx = list()
for i in range(len(results)):
if i in duplicate_idx:
@@ -276,11 +279,19 @@ def __call__(self, ori_img, img, preds):
containments = calculate_containment(bboxes, bboxes[i, ...])
overlaps = np.where(containments > 0.5)[0]
if len(overlaps) > 1:
- table_box = [x for x in overlaps if results[x]['label'] == 'table']
+ table_box = [x for x in overlaps if results[x]["label"] == "table"]
if len(table_box) > 0:
- keep = sorted([(x, results[x]) for x in table_box], key=lambda x: x[1]['score'], reverse=True)[0][0]
+ keep = sorted(
+ [(x, results[x]) for x in table_box],
+ key=lambda x: x[1]["score"],
+ reverse=True,
+ )[0][0]
else:
- keep = sorted([(x, results[x]) for x in overlaps], key=lambda x: x[1]['score'], reverse=True)[0][0]
+ keep = sorted(
+ [(x, results[x]) for x in overlaps],
+ key=lambda x: x[1]["score"],
+ reverse=True,
+ )[0][0]
duplicate_idx.extend([x for x in overlaps if x != keep])
results = [x for i, x in enumerate(results) if i not in duplicate_idx]
return results
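
# Sketch of the deduplication rule above: among boxes that substantially
# contain one another, keep the highest-scoring "table" box if any, else the
# highest-scoring box overall. calculate_containment itself is not shown in
# this diff, so the overlap set is hard-coded here for illustration.
results = [
    {"bbox": [0, 0, 100, 100], "label": "table", "score": 0.70},
    {"bbox": [1, 1, 99, 99], "label": "text", "score": 0.90},
]
overlaps = [0, 1]  # indices whose pairwise containment exceeded 0.5
table_box = [x for x in overlaps if results[x]["label"] == "table"]
pool = table_box if table_box else overlaps
keep = max(pool, key=lambda x: results[x]["score"])
print(results[keep]["label"])  # 'table' wins despite the lower score
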
diff --git a/ppocr/postprocess/pse_postprocess/__init__.py b/ppocr/postprocess/pse_postprocess/__init__.py
index 680473bf4b..eed166a43f 100644
--- a/ppocr/postprocess/pse_postprocess/__init__.py
+++ b/ppocr/postprocess/pse_postprocess/__init__.py
@@ -12,4 +12,4 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from .pse_postprocess import PSEPostProcess
\ No newline at end of file
+from .pse_postprocess import PSEPostProcess
diff --git a/ppocr/postprocess/pse_postprocess/pse/__init__.py b/ppocr/postprocess/pse_postprocess/pse/__init__.py
index 1903a9149a..60288025e5 100644
--- a/ppocr/postprocess/pse_postprocess/pse/__init__.py
+++ b/ppocr/postprocess/pse_postprocess/pse/__init__.py
@@ -18,12 +18,16 @@
python_path = sys.executable
ori_path = os.getcwd()
-os.chdir('ppocr/postprocess/pse_postprocess/pse')
-if subprocess.call(
- '{} setup.py build_ext --inplace'.format(python_path), shell=True) != 0:
+os.chdir("ppocr/postprocess/pse_postprocess/pse")
+if (
+ subprocess.call("{} setup.py build_ext --inplace".format(python_path), shell=True)
+ != 0
+):
raise RuntimeError(
- 'Cannot compile pse: {}, if your system is windows, you need to install all the default components of `desktop development using C++` in visual studio 2019+'.
- format(os.path.dirname(os.path.realpath(__file__))))
+ "Cannot compile pse: {}, if your system is windows, you need to install all the default components of `desktop development using C++` in visual studio 2019+".format(
+ os.path.dirname(os.path.realpath(__file__))
+ )
+ )
os.chdir(ori_path)
from .pse import pse
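
# Usage sketch: the import-time compile above is equivalent to this manual
# step (run from the repo root; requires a working C++ toolchain):
import subprocess, sys

subprocess.check_call(
    [sys.executable, "setup.py", "build_ext", "--inplace"],
    cwd="ppocr/postprocess/pse_postprocess/pse",
)
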
diff --git a/ppocr/postprocess/pse_postprocess/pse/setup.py b/ppocr/postprocess/pse_postprocess/pse/setup.py
index 03746782af..6f8913e05d 100644
--- a/ppocr/postprocess/pse_postprocess/pse/setup.py
+++ b/ppocr/postprocess/pse_postprocess/pse/setup.py
@@ -2,13 +2,17 @@
from Cython.Build import cythonize
import numpy
-setup(ext_modules=cythonize(Extension(
- 'pse',
- sources=['pse.pyx'],
- language='c++',
- include_dirs=[numpy.get_include()],
- library_dirs=[],
- libraries=[],
- extra_compile_args=['-O3'],
- extra_link_args=[]
-)))
+setup(
+ ext_modules=cythonize(
+ Extension(
+ "pse",
+ sources=["pse.pyx"],
+ language="c++",
+ include_dirs=[numpy.get_include()],
+ library_dirs=[],
+ libraries=[],
+ extra_compile_args=["-O3"],
+ extra_link_args=[],
+ )
+ )
+)
diff --git a/ppocr/postprocess/pse_postprocess/pse_postprocess.py b/ppocr/postprocess/pse_postprocess/pse_postprocess.py
index 962f3efe92..d5cb992e17 100755
--- a/ppocr/postprocess/pse_postprocess/pse_postprocess.py
+++ b/ppocr/postprocess/pse_postprocess/pse_postprocess.py
@@ -33,14 +33,16 @@ class PSEPostProcess(object):
The post process for PSE.
"""
- def __init__(self,
- thresh=0.5,
- box_thresh=0.85,
- min_area=16,
- box_type='quad',
- scale=4,
- **kwargs):
- assert box_type in ['quad', 'poly'], 'Only quad and poly is supported'
+ def __init__(
+ self,
+ thresh=0.5,
+ box_thresh=0.85,
+ min_area=16,
+ box_type="quad",
+ scale=4,
+ **kwargs
+ ):
+ assert box_type in ["quad", "poly"], "Only quad and poly is supported"
self.thresh = thresh
self.box_thresh = box_thresh
self.min_area = min_area
@@ -48,15 +50,14 @@ def __init__(self,
self.scale = scale
def __call__(self, outs_dict, shape_list):
- pred = outs_dict['maps']
+ pred = outs_dict["maps"]
if not isinstance(pred, paddle.Tensor):
pred = paddle.to_tensor(pred)
- pred = F.interpolate(
- pred, scale_factor=4 // self.scale, mode='bilinear')
+ pred = F.interpolate(pred, scale_factor=4 // self.scale, mode="bilinear")
score = F.sigmoid(pred[:, 0, :, :])
- kernels = (pred > self.thresh).astype('float32')
+ kernels = (pred > self.thresh).astype("float32")
text_mask = kernels[:, 0, :, :]
text_mask = paddle.unsqueeze(text_mask, axis=1)
@@ -67,11 +68,11 @@ def __call__(self, outs_dict, shape_list):
boxes_batch = []
for batch_index in range(pred.shape[0]):
- boxes, scores = self.boxes_from_bitmap(score[batch_index],
- kernels[batch_index],
- shape_list[batch_index])
+ boxes, scores = self.boxes_from_bitmap(
+ score[batch_index], kernels[batch_index], shape_list[batch_index]
+ )
- boxes_batch.append({'points': boxes, 'scores': scores})
+ boxes_batch.append({"points": boxes, "scores": scores})
return boxes_batch
def boxes_from_bitmap(self, score, kernels, shape):
@@ -97,18 +98,19 @@ def generate_box(self, score, label, shape):
label[ind] = 0
continue
- if self.box_type == 'quad':
+ if self.box_type == "quad":
rect = cv2.minAreaRect(points)
bbox = cv2.boxPoints(rect)
- elif self.box_type == 'poly':
+ elif self.box_type == "poly":
box_height = np.max(points[:, 1]) + 10
box_width = np.max(points[:, 0]) + 10
mask = np.zeros((box_height, box_width), np.uint8)
mask[points[:, 1], points[:, 0]] = 255
- contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL,
- cv2.CHAIN_APPROX_SIMPLE)
+ contours, _ = cv2.findContours(
+ mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE
+ )
bbox = np.squeeze(contours[0], 1)
else:
raise NotImplementedError
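
# Config sketch: a PSEPostProcess entry as build_post_process() would consume
# it; values are illustrative rather than tuned. Note the class is registered
# lazily: config["name"] == "PSEPostProcess" triggers the import (and the
# Cython build) in ppocr/postprocess/__init__.py.
pse_cfg = {
    "name": "PSEPostProcess",
    "thresh": 0.5,       # kernel binarization threshold
    "box_thresh": 0.85,  # minimum mean score for a kept box
    "min_area": 16,      # minimum pixel area per text instance
    "box_type": "quad",  # "quad" or "poly", as asserted in __init__
}
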
diff --git a/ppocr/postprocess/rec_postprocess.py b/ppocr/postprocess/rec_postprocess.py
index 9db113cf74..ed04ebcedd 100644
--- a/ppocr/postprocess/rec_postprocess.py
+++ b/ppocr/postprocess/rec_postprocess.py
@@ -19,7 +19,7 @@
class BaseRecLabelDecode(object):
- """ Convert between text-label and text-index """
+ """Convert between text-label and text-index"""
def __init__(self, character_dict_path=None, use_space_char=False):
self.beg_str = "sos"
@@ -34,12 +34,12 @@ def __init__(self, character_dict_path=None, use_space_char=False):
with open(character_dict_path, "rb") as fin:
lines = fin.readlines()
for line in lines:
- line = line.decode('utf-8').strip("\n").strip("\r\n")
+ line = line.decode("utf-8").strip("\n").strip("\r\n")
self.character_str.append(line)
if use_space_char:
self.character_str.append(" ")
dict_character = list(self.character_str)
- if 'arabic' in character_dict_path:
+ if "arabic" in character_dict_path:
self.reverse = True
dict_character = self.add_special_char(dict_character)
@@ -50,34 +50,34 @@ def __init__(self, character_dict_path=None, use_space_char=False):
def pred_reverse(self, pred):
pred_re = []
- c_current = ''
+ c_current = ""
for c in pred:
- if not bool(re.search('[a-zA-Z0-9 :*./%+-]', c)):
- if c_current != '':
+ if not bool(re.search("[a-zA-Z0-9 :*./%+-]", c)):
+ if c_current != "":
pred_re.append(c_current)
pred_re.append(c)
- c_current = ''
+ c_current = ""
else:
c_current += c
- if c_current != '':
+ if c_current != "":
pred_re.append(c_current)
- return ''.join(pred_re[::-1])
+ return "".join(pred_re[::-1])
def add_special_char(self, dict_character):
return dict_character
def get_word_info(self, text, selection):
"""
- Group the decoded characters and record the corresponding decoded positions.
+ Group the decoded characters and record the corresponding decoded positions.
Args:
text: the decoded text
- selection: the bool array that identifies which columns of features are decoded as non-separated characters
+ selection: the bool array that identifies which columns of features are decoded as non-separated characters
Returns:
word_list: list of the grouped words
word_col_list: list of decoding positions corresponding to each character in the grouped word
- state_list: list of marker to identify the type of grouping words, including two types of grouping words:
+            state_list: list of markers identifying the type of each grouped word, with two possible types:
                - 'cn': continuous Chinese characters (e.g., 你好啊)
                - 'en&num': continuous English characters (e.g., hello), numbers (e.g., 123, 1.123), or a mix of them connected by '-' (e.g., VGG-16)
The remaining characters in text are treated as separators between groups (e.g., space, '(', ')', etc.).
@@ -88,21 +88,28 @@ def get_word_info(self, text, selection):
word_list = []
word_col_list = []
state_list = []
- valid_col = np.where(selection==True)[0]
+ valid_col = np.where(selection == True)[0]
for c_i, char in enumerate(text):
- if '\u4e00' <= char <= '\u9fff':
- c_state = 'cn'
- elif bool(re.search('[a-zA-Z0-9]', char)):
- c_state = 'en&num'
+ if "\u4e00" <= char <= "\u9fff":
+ c_state = "cn"
+ elif bool(re.search("[a-zA-Z0-9]", char)):
+ c_state = "en&num"
else:
- c_state = 'splitter'
-
- if char == '.' and state == 'en&num' and c_i + 1 < len(text) and bool(re.search('[0-9]', text[c_i+1])): # grouping floting number
- c_state = 'en&num'
- if char == '-' and state == "en&num": # grouping word with '-', such as 'state-of-the-art'
- c_state = 'en&num'
-
+ c_state = "splitter"
+
+ if (
+ char == "."
+ and state == "en&num"
+ and c_i + 1 < len(text)
+ and bool(re.search("[0-9]", text[c_i + 1]))
+            ):  # grouping a floating-point number
+ c_state = "en&num"
+ if (
+ char == "-" and state == "en&num"
+ ): # grouping word with '-', such as 'state-of-the-art'
+ c_state = "en&num"
+
if state == None:
state = c_state
@@ -126,26 +133,26 @@ def get_word_info(self, text, selection):
return word_list, word_col_list, state_list
- def decode(self,
- text_index,
- text_prob=None,
- is_remove_duplicate=False,
- return_word_box=False):
- """ convert text-index into text-label. """
+ def decode(
+ self,
+ text_index,
+ text_prob=None,
+ is_remove_duplicate=False,
+ return_word_box=False,
+ ):
+ """convert text-index into text-label."""
result_list = []
ignored_tokens = self.get_ignored_tokens()
batch_size = len(text_index)
for batch_idx in range(batch_size):
selection = np.ones(len(text_index[batch_idx]), dtype=bool)
if is_remove_duplicate:
- selection[1:] = text_index[batch_idx][1:] != text_index[
- batch_idx][:-1]
+ selection[1:] = text_index[batch_idx][1:] != text_index[batch_idx][:-1]
for ignored_token in ignored_tokens:
selection &= text_index[batch_idx] != ignored_token
char_list = [
- self.character[text_id]
- for text_id in text_index[batch_idx][selection]
+ self.character[text_id] for text_id in text_index[batch_idx][selection]
]
if text_prob is not None:
conf_list = text_prob[batch_idx][selection]
@@ -154,18 +161,27 @@ def decode(self,
if len(conf_list) == 0:
conf_list = [0]
- text = ''.join(char_list)
+ text = "".join(char_list)
if self.reverse: # for arabic rec
text = self.pred_reverse(text)
if return_word_box:
word_list, word_col_list, state_list = self.get_word_info(
- text, selection)
- result_list.append((text, np.mean(conf_list).tolist(), [
- len(text_index[batch_idx]), word_list, word_col_list,
- state_list
- ]))
+ text, selection
+ )
+ result_list.append(
+ (
+ text,
+ np.mean(conf_list).tolist(),
+ [
+ len(text_index[batch_idx]),
+ word_list,
+ word_col_list,
+ state_list,
+ ],
+ )
+ )
else:
result_list.append((text, np.mean(conf_list).tolist()))
return result_list
@@ -175,19 +191,12 @@ def get_ignored_tokens(self):
class CTCLabelDecode(BaseRecLabelDecode):
- """ Convert between text-label and text-index """
-
- def __init__(self, character_dict_path=None, use_space_char=False,
- **kwargs):
- super(CTCLabelDecode, self).__init__(character_dict_path,
- use_space_char)
-
- def __call__(self,
- preds,
- label=None,
- return_word_box=False,
- *args,
- **kwargs):
+ """Convert between text-label and text-index"""
+
+ def __init__(self, character_dict_path=None, use_space_char=False, **kwargs):
+ super(CTCLabelDecode, self).__init__(character_dict_path, use_space_char)
+
+ def __call__(self, preds, label=None, return_word_box=False, *args, **kwargs):
if isinstance(preds, tuple) or isinstance(preds, list):
preds = preds[-1]
if isinstance(preds, paddle.Tensor):
@@ -198,11 +207,12 @@ def __call__(self,
preds_idx,
preds_prob,
is_remove_duplicate=True,
- return_word_box=return_word_box)
+ return_word_box=return_word_box,
+ )
if return_word_box:
for rec_idx, rec in enumerate(text):
- wh_ratio = kwargs['wh_ratio_list'][rec_idx]
- max_wh_ratio = kwargs['max_wh_ratio']
+ wh_ratio = kwargs["wh_ratio_list"][rec_idx]
+ max_wh_ratio = kwargs["max_wh_ratio"]
rec[2][0] = rec[2][0] * (wh_ratio / max_wh_ratio)
if label is None:
return text
@@ -210,25 +220,28 @@ def __call__(self,
return text, label
def add_special_char(self, dict_character):
- dict_character = ['blank'] + dict_character
+ dict_character = ["blank"] + dict_character
return dict_character
class DistillationCTCLabelDecode(CTCLabelDecode):
"""
- Convert
+ Convert
Convert between text-label and text-index
"""
- def __init__(self,
- character_dict_path=None,
- use_space_char=False,
- model_name=["student"],
- key=None,
- multi_head=False,
- **kwargs):
- super(DistillationCTCLabelDecode, self).__init__(character_dict_path,
- use_space_char)
+ def __init__(
+ self,
+ character_dict_path=None,
+ use_space_char=False,
+ model_name=["student"],
+ key=None,
+ multi_head=False,
+ **kwargs
+ ):
+ super(DistillationCTCLabelDecode, self).__init__(
+ character_dict_path, use_space_char
+ )
if not isinstance(model_name, list):
model_name = [model_name]
self.model_name = model_name
@@ -243,18 +256,16 @@ def __call__(self, preds, label=None, *args, **kwargs):
if self.key is not None:
pred = pred[self.key]
if self.multi_head and isinstance(pred, dict):
- pred = pred['ctc']
+ pred = pred["ctc"]
output[name] = super().__call__(pred, label=label, *args, **kwargs)
return output
class AttnLabelDecode(BaseRecLabelDecode):
- """ Convert between text-label and text-index """
+ """Convert between text-label and text-index"""
- def __init__(self, character_dict_path=None, use_space_char=False,
- **kwargs):
- super(AttnLabelDecode, self).__init__(character_dict_path,
- use_space_char)
+ def __init__(self, character_dict_path=None, use_space_char=False, **kwargs):
+ super(AttnLabelDecode, self).__init__(character_dict_path, use_space_char)
def add_special_char(self, dict_character):
self.beg_str = "sos"
@@ -264,7 +275,7 @@ def add_special_char(self, dict_character):
return dict_character
def decode(self, text_index, text_prob=None, is_remove_duplicate=False):
- """ convert text-index into text-label. """
+ """convert text-index into text-label."""
result_list = []
ignored_tokens = self.get_ignored_tokens()
[beg_idx, end_idx] = self.get_ignored_tokens()
@@ -279,16 +290,17 @@ def decode(self, text_index, text_prob=None, is_remove_duplicate=False):
break
if is_remove_duplicate:
# only for predict
- if idx > 0 and text_index[batch_idx][idx - 1] == text_index[
- batch_idx][idx]:
+ if (
+ idx > 0
+ and text_index[batch_idx][idx - 1] == text_index[batch_idx][idx]
+ ):
continue
- char_list.append(self.character[int(text_index[batch_idx][
- idx])])
+ char_list.append(self.character[int(text_index[batch_idx][idx])])
if text_prob is not None:
conf_list.append(text_prob[batch_idx][idx])
else:
conf_list.append(1)
- text = ''.join(char_list)
+ text = "".join(char_list)
result_list.append((text, np.mean(conf_list).tolist()))
return result_list
@@ -323,18 +335,15 @@ def get_beg_end_flag_idx(self, beg_or_end):
elif beg_or_end == "end":
idx = np.array(self.dict[self.end_str])
else:
- assert False, "unsupport type %s in get_beg_end_flag_idx" \
- % beg_or_end
+ assert False, "unsupport type %s in get_beg_end_flag_idx" % beg_or_end
return idx
class RFLLabelDecode(BaseRecLabelDecode):
- """ Convert between text-label and text-index """
+ """Convert between text-label and text-index"""
- def __init__(self, character_dict_path=None, use_space_char=False,
- **kwargs):
- super(RFLLabelDecode, self).__init__(character_dict_path,
- use_space_char)
+ def __init__(self, character_dict_path=None, use_space_char=False, **kwargs):
+ super(RFLLabelDecode, self).__init__(character_dict_path, use_space_char)
def add_special_char(self, dict_character):
self.beg_str = "sos"
@@ -344,7 +353,7 @@ def add_special_char(self, dict_character):
return dict_character
def decode(self, text_index, text_prob=None, is_remove_duplicate=False):
- """ convert text-index into text-label. """
+ """convert text-index into text-label."""
result_list = []
ignored_tokens = self.get_ignored_tokens()
[beg_idx, end_idx] = self.get_ignored_tokens()
@@ -359,16 +368,17 @@ def decode(self, text_index, text_prob=None, is_remove_duplicate=False):
break
if is_remove_duplicate:
# only for predict
- if idx > 0 and text_index[batch_idx][idx - 1] == text_index[
- batch_idx][idx]:
+ if (
+ idx > 0
+ and text_index[batch_idx][idx - 1] == text_index[batch_idx][idx]
+ ):
continue
- char_list.append(self.character[int(text_index[batch_idx][
- idx])])
+ char_list.append(self.character[int(text_index[batch_idx][idx])])
if text_prob is not None:
conf_list.append(text_prob[batch_idx][idx])
else:
conf_list.append(1)
- text = ''.join(char_list)
+ text = "".join(char_list)
result_list.append((text, np.mean(conf_list).tolist()))
return result_list
@@ -412,26 +422,21 @@ def get_beg_end_flag_idx(self, beg_or_end):
elif beg_or_end == "end":
idx = np.array(self.dict[self.end_str])
else:
- assert False, "unsupport type %s in get_beg_end_flag_idx" \
- % beg_or_end
+ assert False, "unsupport type %s in get_beg_end_flag_idx" % beg_or_end
return idx
class SEEDLabelDecode(BaseRecLabelDecode):
- """ Convert between text-label and text-index """
+ """Convert between text-label and text-index"""
- def __init__(self, character_dict_path=None, use_space_char=False,
- **kwargs):
- super(SEEDLabelDecode, self).__init__(character_dict_path,
- use_space_char)
+ def __init__(self, character_dict_path=None, use_space_char=False, **kwargs):
+ super(SEEDLabelDecode, self).__init__(character_dict_path, use_space_char)
def add_special_char(self, dict_character):
self.padding_str = "padding"
self.end_str = "eos"
self.unknown = "unknown"
- dict_character = dict_character + [
- self.end_str, self.padding_str, self.unknown
- ]
+ dict_character = dict_character + [self.end_str, self.padding_str, self.unknown]
return dict_character
def get_ignored_tokens(self):
@@ -448,7 +453,7 @@ def get_beg_end_flag_idx(self, beg_or_end):
return idx
def decode(self, text_index, text_prob=None, is_remove_duplicate=False):
- """ convert text-index into text-label. """
+ """convert text-index into text-label."""
result_list = []
[end_idx] = self.get_ignored_tokens()
batch_size = len(text_index)
@@ -460,16 +465,17 @@ def decode(self, text_index, text_prob=None, is_remove_duplicate=False):
break
if is_remove_duplicate:
# only for predict
- if idx > 0 and text_index[batch_idx][idx - 1] == text_index[
- batch_idx][idx]:
+ if (
+ idx > 0
+ and text_index[batch_idx][idx - 1] == text_index[batch_idx][idx]
+ ):
continue
- char_list.append(self.character[int(text_index[batch_idx][
- idx])])
+ char_list.append(self.character[int(text_index[batch_idx][idx])])
if text_prob is not None:
conf_list.append(text_prob[batch_idx][idx])
else:
conf_list.append(1)
- text = ''.join(char_list)
+ text = "".join(char_list)
result_list.append((text, np.mean(conf_list).tolist()))
return result_list
@@ -499,16 +505,14 @@ def __call__(self, preds, label=None, *args, **kwargs):
class SRNLabelDecode(BaseRecLabelDecode):
- """ Convert between text-label and text-index """
+ """Convert between text-label and text-index"""
- def __init__(self, character_dict_path=None, use_space_char=False,
- **kwargs):
- super(SRNLabelDecode, self).__init__(character_dict_path,
- use_space_char)
- self.max_text_length = kwargs.get('max_text_length', 25)
+ def __init__(self, character_dict_path=None, use_space_char=False, **kwargs):
+ super(SRNLabelDecode, self).__init__(character_dict_path, use_space_char)
+ self.max_text_length = kwargs.get("max_text_length", 25)
def __call__(self, preds, label=None, *args, **kwargs):
- pred = preds['predict']
+ pred = preds["predict"]
char_num = len(self.character_str) + 2
if isinstance(pred, paddle.Tensor):
pred = pred.numpy()
@@ -530,7 +534,7 @@ def __call__(self, preds, label=None, *args, **kwargs):
return text, label
def decode(self, text_index, text_prob=None, is_remove_duplicate=False):
- """ convert text-index into text-label. """
+ """convert text-index into text-label."""
result_list = []
ignored_tokens = self.get_ignored_tokens()
batch_size = len(text_index)
@@ -543,17 +547,18 @@ def decode(self, text_index, text_prob=None, is_remove_duplicate=False):
continue
if is_remove_duplicate:
# only for predict
- if idx > 0 and text_index[batch_idx][idx - 1] == text_index[
- batch_idx][idx]:
+ if (
+ idx > 0
+ and text_index[batch_idx][idx - 1] == text_index[batch_idx][idx]
+ ):
continue
- char_list.append(self.character[int(text_index[batch_idx][
- idx])])
+ char_list.append(self.character[int(text_index[batch_idx][idx])])
if text_prob is not None:
conf_list.append(text_prob[batch_idx][idx])
else:
conf_list.append(1)
- text = ''.join(char_list)
+ text = "".join(char_list)
result_list.append((text, np.mean(conf_list).tolist()))
return result_list
@@ -572,32 +577,30 @@ def get_beg_end_flag_idx(self, beg_or_end):
elif beg_or_end == "end":
idx = np.array(self.dict[self.end_str])
else:
- assert False, "unsupport type %s in get_beg_end_flag_idx" \
- % beg_or_end
+ assert False, "unsupport type %s in get_beg_end_flag_idx" % beg_or_end
return idx
class ParseQLabelDecode(BaseRecLabelDecode):
- """ Convert between text-label and text-index """
- BOS = '[B]'
- EOS = '[E]'
- PAD = '[P]'
+ """Convert between text-label and text-index"""
- def __init__(self, character_dict_path=None, use_space_char=False,
- **kwargs):
- super(ParseQLabelDecode, self).__init__(character_dict_path,
- use_space_char)
- self.max_text_length = kwargs.get('max_text_length', 25)
+ BOS = "[B]"
+ EOS = "[E]"
+ PAD = "[P]"
+
+ def __init__(self, character_dict_path=None, use_space_char=False, **kwargs):
+ super(ParseQLabelDecode, self).__init__(character_dict_path, use_space_char)
+ self.max_text_length = kwargs.get("max_text_length", 25)
def __call__(self, preds, label=None, *args, **kwargs):
if isinstance(preds, dict):
- pred = preds['predict']
+ pred = preds["predict"]
else:
pred = preds
- char_num = len(
- self.character_str
- ) + 1 # We don't predict <bos> nor <pad>, with only <eos> addition
+ char_num = (
+ len(self.character_str) + 1
+ ) # We don't predict <bos> nor <pad>, with only <eos> addition
if isinstance(pred, paddle.Tensor):
pred = pred.numpy()
B, L = pred.shape[:2]
@@ -619,7 +622,7 @@ def __call__(self, preds, label=None, *args, **kwargs):
return text, label
def decode(self, text_index, text_prob=None, raw=False):
- """ convert text-index into text-label. """
+ """convert text-index into text-label."""
result_list = []
ignored_tokens = self.get_ignored_tokens()
batch_size = len(text_index)
@@ -645,7 +648,7 @@ def decode(self, text_index, text_prob=None, raw=False):
else:
conf_list.append(1)
- text = ''.join(char_list)
+ text = "".join(char_list)
result_list.append((text, np.mean(conf_list).tolist()))
return result_list
@@ -663,8 +666,7 @@ def _filter(self, ids, probs=None):
# Truncate after EOS
ids = ids[:eos_idx]
if probs is not None:
- probs = probs[:eos_idx +
- 1] # but include prob. for EOS (if it exists)
+ probs = probs[: eos_idx + 1] # but include prob. for EOS (if it exists)
return ids, probs
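`_filter` cuts the id sequence at the first EOS but keeps one extra slot in the probability vector, so the EOS step still contributes to the confidence. A small sketch of that slicing (toy values; EOS assumed to be id 0, as in ParseQ-style vocabularies):

import numpy as np

def filter_after_eos(ids, probs=None, eos_id=0):
    eos_pos = np.where(np.array(ids) == eos_id)[0]
    eos_idx = int(eos_pos[0]) if len(eos_pos) else len(ids)
    out_ids = ids[:eos_idx]
    # include prob. for EOS (if it exists), mirroring the slice above
    out_probs = None if probs is None else probs[: eos_idx + 1]
    return out_ids, out_probs

ids, probs = filter_after_eos([5, 7, 0, 9], [0.9, 0.8, 0.99, 0.1])
print(ids, probs)  # [5, 7] [0.9, 0.8, 0.99]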
def get_ignored_tokens(self):
@@ -672,14 +674,12 @@ def get_ignored_tokens(self):
class SARLabelDecode(BaseRecLabelDecode):
- """ Convert between text-label and text-index """
+ """Convert between text-label and text-index"""
- def __init__(self, character_dict_path=None, use_space_char=False,
- **kwargs):
- super(SARLabelDecode, self).__init__(character_dict_path,
- use_space_char)
+ def __init__(self, character_dict_path=None, use_space_char=False, **kwargs):
+ super(SARLabelDecode, self).__init__(character_dict_path, use_space_char)
- self.rm_symbol = kwargs.get('rm_symbol', False)
+ self.rm_symbol = kwargs.get("rm_symbol", False)
def add_special_char(self, dict_character):
beg_end_str = ""
@@ -695,7 +695,7 @@ def add_special_char(self, dict_character):
return dict_character
def decode(self, text_index, text_prob=None, is_remove_duplicate=False):
- """ convert text-index into text-label. """
+ """convert text-index into text-label."""
result_list = []
ignored_tokens = self.get_ignored_tokens()
@@ -713,20 +713,21 @@ def decode(self, text_index, text_prob=None, is_remove_duplicate=False):
break
if is_remove_duplicate:
# only for predict
- if idx > 0 and text_index[batch_idx][idx - 1] == text_index[
- batch_idx][idx]:
+ if (
+ idx > 0
+ and text_index[batch_idx][idx - 1] == text_index[batch_idx][idx]
+ ):
continue
- char_list.append(self.character[int(text_index[batch_idx][
- idx])])
+ char_list.append(self.character[int(text_index[batch_idx][idx])])
if text_prob is not None:
conf_list.append(text_prob[batch_idx][idx])
else:
conf_list.append(1)
- text = ''.join(char_list)
+ text = "".join(char_list)
if self.rm_symbol:
- comp = re.compile('[^A-Z^a-z^0-9^\u4e00-\u9fa5]')
+ comp = re.compile("[^A-Z^a-z^0-9^\u4e00-\u9fa5]")
text = text.lower()
- text = comp.sub('', text)
+ text = comp.sub("", text)
result_list.append((text, np.mean(conf_list).tolist()))
return result_list
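When `rm_symbol` is set, the decoded string is lower-cased and every character outside ASCII letters, digits and the CJK block \u4e00-\u9fa5 is stripped (the inner `^` characters are literal members of the negated class, a long-standing quirk the reformat keeps). For example:

import re

comp = re.compile("[^A-Z^a-z^0-9^\u4e00-\u9fa5]")
print(comp.sub("", "Hello, World! 你好-123".lower()))  # helloworld你好123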
@@ -748,14 +749,12 @@ def get_ignored_tokens(self):
class SATRNLabelDecode(BaseRecLabelDecode):
- """ Convert between text-label and text-index """
+ """Convert between text-label and text-index"""
- def __init__(self, character_dict_path=None, use_space_char=False,
- **kwargs):
- super(SATRNLabelDecode, self).__init__(character_dict_path,
- use_space_char)
+ def __init__(self, character_dict_path=None, use_space_char=False, **kwargs):
+ super(SATRNLabelDecode, self).__init__(character_dict_path, use_space_char)
- self.rm_symbol = kwargs.get('rm_symbol', False)
+ self.rm_symbol = kwargs.get("rm_symbol", False)
def add_special_char(self, dict_character):
beg_end_str = ""
@@ -771,7 +770,7 @@ def add_special_char(self, dict_character):
return dict_character
def decode(self, text_index, text_prob=None, is_remove_duplicate=False):
- """ convert text-index into text-label. """
+ """convert text-index into text-label."""
result_list = []
ignored_tokens = self.get_ignored_tokens()
@@ -789,20 +788,21 @@ def decode(self, text_index, text_prob=None, is_remove_duplicate=False):
break
if is_remove_duplicate:
# only for predict
- if idx > 0 and text_index[batch_idx][idx - 1] == text_index[
- batch_idx][idx]:
+ if (
+ idx > 0
+ and text_index[batch_idx][idx - 1] == text_index[batch_idx][idx]
+ ):
continue
- char_list.append(self.character[int(text_index[batch_idx][
- idx])])
+ char_list.append(self.character[int(text_index[batch_idx][idx])])
if text_prob is not None:
conf_list.append(text_prob[batch_idx][idx])
else:
conf_list.append(1)
- text = ''.join(char_list)
+ text = "".join(char_list)
if self.rm_symbol:
- comp = re.compile('[^A-Z^a-z^0-9^\u4e00-\u9fa5]')
+ comp = re.compile("[^A-Z^a-z^0-9^\u4e00-\u9fa5]")
text = text.lower()
- text = comp.sub('', text)
+ text = comp.sub("", text)
result_list.append((text, np.mean(conf_list).tolist()))
return result_list
@@ -825,19 +825,22 @@ def get_ignored_tokens(self):
class DistillationSARLabelDecode(SARLabelDecode):
"""
- Convert
+ Convert
Convert between text-label and text-index
"""
- def __init__(self,
- character_dict_path=None,
- use_space_char=False,
- model_name=["student"],
- key=None,
- multi_head=False,
- **kwargs):
- super(DistillationSARLabelDecode, self).__init__(character_dict_path,
- use_space_char)
+ def __init__(
+ self,
+ character_dict_path=None,
+ use_space_char=False,
+ model_name=["student"],
+ key=None,
+ multi_head=False,
+ **kwargs
+ ):
+ super(DistillationSARLabelDecode, self).__init__(
+ character_dict_path, use_space_char
+ )
if not isinstance(model_name, list):
model_name = [model_name]
self.model_name = model_name
@@ -852,23 +855,21 @@ def __call__(self, preds, label=None, *args, **kwargs):
if self.key is not None:
pred = pred[self.key]
if self.multi_head and isinstance(pred, dict):
- pred = pred['sar']
+ pred = pred["sar"]
output[name] = super().__call__(pred, label=label, *args, **kwargs)
return output
class PRENLabelDecode(BaseRecLabelDecode):
- """ Convert between text-label and text-index """
+ """Convert between text-label and text-index"""
- def __init__(self, character_dict_path=None, use_space_char=False,
- **kwargs):
- super(PRENLabelDecode, self).__init__(character_dict_path,
- use_space_char)
+ def __init__(self, character_dict_path=None, use_space_char=False, **kwargs):
+ super(PRENLabelDecode, self).__init__(character_dict_path, use_space_char)
def add_special_char(self, dict_character):
- padding_str = '<PAD>' # 0
- end_str = '<EOS>' # 1
- unknown_str = '<UNK>' # 2
+ padding_str = "<PAD>" # 0
+ end_str = "<EOS>" # 1
+ unknown_str = "<UNK>" # 2
dict_character = [padding_str, end_str, unknown_str] + dict_character
self.padding_idx = 0
@@ -878,7 +879,7 @@ def add_special_char(self, dict_character):
return dict_character
def decode(self, text_index, text_prob=None):
- """ convert text-index into text-label. """
+ """convert text-index into text-label."""
result_list = []
batch_size = len(text_index)
@@ -888,22 +889,20 @@ def decode(self, text_index, text_prob=None):
for idx in range(len(text_index[batch_idx])):
if text_index[batch_idx][idx] == self.end_idx:
break
- if text_index[batch_idx][idx] in \
- [self.padding_idx, self.unknown_idx]:
+ if text_index[batch_idx][idx] in [self.padding_idx, self.unknown_idx]:
continue
- char_list.append(self.character[int(text_index[batch_idx][
- idx])])
+ char_list.append(self.character[int(text_index[batch_idx][idx])])
if text_prob is not None:
conf_list.append(text_prob[batch_idx][idx])
else:
conf_list.append(1)
- text = ''.join(char_list)
+ text = "".join(char_list)
if len(text) > 0:
result_list.append((text, np.mean(conf_list).tolist()))
else:
# here confidence of empty recog result is 1
- result_list.append(('', 1))
+ result_list.append(("", 1))
return result_list
def __call__(self, preds, label=None, *args, **kwargs):
@@ -919,14 +918,12 @@ def __call__(self, preds, label=None, *args, **kwargs):
class NRTRLabelDecode(BaseRecLabelDecode):
- """ Convert between text-label and text-index """
+ """Convert between text-label and text-index"""
def __init__(self, character_dict_path=None, use_space_char=True, **kwargs):
- super(NRTRLabelDecode, self).__init__(character_dict_path,
- use_space_char)
+ super(NRTRLabelDecode, self).__init__(character_dict_path, use_space_char)
def __call__(self, preds, label=None, *args, **kwargs):
-
if len(preds) == 2:
preds_id = preds[0]
preds_prob = preds[1]
@@ -955,11 +952,11 @@ def __call__(self, preds, label=None, *args, **kwargs):
return text, label
def add_special_char(self, dict_character):
- dict_character = ['blank', '<unk>', '<s>', '</s>'] + dict_character
+ dict_character = ["blank", "<unk>", "<s>", "</s>"] + dict_character
return dict_character
def decode(self, text_index, text_prob=None, is_remove_duplicate=False):
- """ convert text-index into text-label. """
+ """convert text-index into text-label."""
result_list = []
batch_size = len(text_index)
for batch_idx in range(batch_size):
@@ -970,25 +967,23 @@ def decode(self, text_index, text_prob=None, is_remove_duplicate=False):
char_idx = self.character[int(text_index[batch_idx][idx])]
except:
continue
- if char_idx == '</s>': # end
+ if char_idx == "</s>": # end
break
char_list.append(char_idx)
if text_prob is not None:
conf_list.append(text_prob[batch_idx][idx])
else:
conf_list.append(1)
- text = ''.join(char_list)
+ text = "".join(char_list)
result_list.append((text, np.mean(conf_list).tolist()))
return result_list
class ViTSTRLabelDecode(NRTRLabelDecode):
- """ Convert between text-label and text-index """
+ """Convert between text-label and text-index"""
- def __init__(self, character_dict_path=None, use_space_char=False,
- **kwargs):
- super(ViTSTRLabelDecode, self).__init__(character_dict_path,
- use_space_char)
+ def __init__(self, character_dict_path=None, use_space_char=False, **kwargs):
+ super(ViTSTRLabelDecode, self).__init__(character_dict_path, use_space_char)
def __call__(self, preds, label=None, *args, **kwargs):
if isinstance(preds, paddle.Tensor):
@@ -1004,21 +999,19 @@ def __call__(self, preds, label=None, *args, **kwargs):
return text, label
def add_special_char(self, dict_character):
- dict_character = ['<s>', '</s>'] + dict_character
+ dict_character = ["<s>", "</s>"] + dict_character
return dict_character
class ABINetLabelDecode(NRTRLabelDecode):
- """ Convert between text-label and text-index """
+ """Convert between text-label and text-index"""
- def __init__(self, character_dict_path=None, use_space_char=False,
- **kwargs):
- super(ABINetLabelDecode, self).__init__(character_dict_path,
- use_space_char)
+ def __init__(self, character_dict_path=None, use_space_char=False, **kwargs):
+ super(ABINetLabelDecode, self).__init__(character_dict_path, use_space_char)
def __call__(self, preds, label=None, *args, **kwargs):
if isinstance(preds, dict):
- preds = preds['align'][-1].numpy()
+ preds = preds["align"][-1].numpy()
elif isinstance(preds, paddle.Tensor):
preds = preds.numpy()
else:
@@ -1033,17 +1026,15 @@ def __call__(self, preds, label=None, *args, **kwargs):
return text, label
def add_special_char(self, dict_character):
- dict_character = ['</s>'] + dict_character
+ dict_character = ["</s>"] + dict_character
return dict_character
class SPINLabelDecode(AttnLabelDecode):
- """ Convert between text-label and text-index """
+ """Convert between text-label and text-index"""
- def __init__(self, character_dict_path=None, use_space_char=False,
- **kwargs):
- super(SPINLabelDecode, self).__init__(character_dict_path,
- use_space_char)
+ def __init__(self, character_dict_path=None, use_space_char=False, **kwargs):
+ super(SPINLabelDecode, self).__init__(character_dict_path, use_space_char)
def add_special_char(self, dict_character):
self.beg_str = "sos"
@@ -1054,24 +1045,22 @@ def add_special_char(self, dict_character):
class VLLabelDecode(BaseRecLabelDecode):
- """ Convert between text-label and text-index """
+ """Convert between text-label and text-index"""
- def __init__(self, character_dict_path=None, use_space_char=False,
- **kwargs):
+ def __init__(self, character_dict_path=None, use_space_char=False, **kwargs):
super(VLLabelDecode, self).__init__(character_dict_path, use_space_char)
- self.max_text_length = kwargs.get('max_text_length', 25)
+ self.max_text_length = kwargs.get("max_text_length", 25)
self.nclass = len(self.character) + 1
def decode(self, text_index, text_prob=None, is_remove_duplicate=False):
- """ convert text-index into text-label. """
+ """convert text-index into text-label."""
result_list = []
ignored_tokens = self.get_ignored_tokens()
batch_size = len(text_index)
for batch_idx in range(batch_size):
selection = np.ones(len(text_index[batch_idx]), dtype=bool)
if is_remove_duplicate:
- selection[1:] = text_index[batch_idx][1:] != text_index[
- batch_idx][:-1]
+ selection[1:] = text_index[batch_idx][1:] != text_index[batch_idx][:-1]
for ignored_token in ignored_tokens:
selection &= text_index[batch_idx] != ignored_token
@@ -1086,7 +1075,7 @@ def decode(self, text_index, text_prob=None, is_remove_duplicate=False):
if len(conf_list) == 0:
conf_list = [0]
- text = ''.join(char_list)
+ text = "".join(char_list)
result_list.append((text, np.mean(conf_list).tolist()))
return result_list
@@ -1098,10 +1087,9 @@ def __call__(self, preds, label=None, length=None, *args, **kwargs):
nsteps = self.max_text_length
if not isinstance(text_pre, paddle.Tensor):
- text_pre = paddle.to_tensor(text_pre, dtype='float32')
+ text_pre = paddle.to_tensor(text_pre, dtype="float32")
- out_res = paddle.zeros(
- shape=[lenText, b, self.nclass], dtype=x.dtype)
+ out_res = paddle.zeros(shape=[lenText, b, self.nclass], dtype=x.dtype)
out_length = paddle.zeros(shape=[b], dtype=x.dtype)
now_step = 0
for _ in range(nsteps):
@@ -1118,10 +1106,11 @@ def __call__(self, preds, label=None, length=None, *args, **kwargs):
out_length[j] = nsteps
start = 0
output = paddle.zeros(
- shape=[int(out_length.sum()), self.nclass], dtype=x.dtype)
+ shape=[int(out_length.sum()), self.nclass], dtype=x.dtype
+ )
for i in range(0, b):
cur_length = int(out_length[i])
- output[start:start + cur_length] = out_res[0:cur_length, i, :]
+ output[start : start + cur_length] = out_res[0:cur_length, i, :]
start += cur_length
net_out = output
length = out_length
@@ -1132,7 +1121,7 @@ def __call__(self, preds, label=None, length=None, *args, **kwargs):
net_out = paddle.concat([t[:l] for t, l in zip(net_out, length)])
text = []
if not isinstance(net_out, paddle.Tensor):
- net_out = paddle.to_tensor(net_out, dtype='float32')
+ net_out = paddle.to_tensor(net_out, dtype="float32")
net_out = F.softmax(net_out, axis=1)
for i in range(0, length.shape[0]):
if i == 0:
@@ -1142,14 +1131,18 @@ def __call__(self, preds, label=None, length=None, *args, **kwargs):
start_idx = int(length[:i].sum())
end_idx = int(length[:i].sum() + length[i])
preds_idx = net_out[start_idx:end_idx].topk(1)[1][:, 0].tolist()
- preds_text = ''.join([
- self.character[idx - 1]
- if idx > 0 and idx <= len(self.character) else ''
- for idx in preds_idx
- ])
+ preds_text = "".join(
+ [
+ self.character[idx - 1]
+ if idx > 0 and idx <= len(self.character)
+ else ""
+ for idx in preds_idx
+ ]
+ )
preds_prob = net_out[start_idx:end_idx].topk(1)[0][:, 0]
preds_prob = paddle.exp(
- paddle.log(preds_prob).sum() / (preds_prob.shape[0] + 1e-6))
+ paddle.log(preds_prob).sum() / (preds_prob.shape[0] + 1e-6)
+ )
text.append((preds_text, float(preds_prob)))
if label is None:
return text
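The per-sequence confidence computed here is the geometric mean of the step probabilities, written as exp(sum(log p) / n) with a small epsilon added to the length. A numpy equivalent of that reduction (toy values):

import numpy as np

preds_prob = np.array([0.9, 0.8, 0.95])
conf = np.exp(np.log(preds_prob).sum() / (preds_prob.shape[0] + 1e-6))
print(round(float(conf), 4))  # 0.8811, the geometric mean of the three steps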
@@ -1158,12 +1151,10 @@ def __call__(self, preds, label=None, length=None, *args, **kwargs):
class CANLabelDecode(BaseRecLabelDecode):
- """ Convert between latex-symbol and symbol-index """
+ """Convert between latex-symbol and symbol-index"""
- def __init__(self, character_dict_path=None, use_space_char=False,
- **kwargs):
- super(CANLabelDecode, self).__init__(character_dict_path,
- use_space_char)
+ def __init__(self, character_dict_path=None, use_space_char=False, **kwargs):
+ super(CANLabelDecode, self).__init__(character_dict_path, use_space_char)
def decode(self, text_index, preds_prob=None):
result_list = []
@@ -1174,9 +1165,9 @@ def decode(self, text_index, preds_prob=None):
symbol_list = [self.character[idx] for idx in idx_list]
probs = []
if preds_prob is not None:
- probs = preds_prob[batch_idx][:len(symbol_list)].tolist()
+ probs = preds_prob[batch_idx][: len(symbol_list)].tolist()
- result_list.append([' '.join(symbol_list), probs])
+ result_list.append([" ".join(symbol_list), probs])
return result_list
def __call__(self, preds, label=None, *args, **kwargs):
@@ -1191,17 +1182,15 @@ def __call__(self, preds, label=None, *args, **kwargs):
class CPPDLabelDecode(NRTRLabelDecode):
- """ Convert between text-label and text-index """
+ """Convert between text-label and text-index"""
- def __init__(self, character_dict_path=None, use_space_char=False,
- **kwargs):
- super(CPPDLabelDecode, self).__init__(character_dict_path,
- use_space_char)
+ def __init__(self, character_dict_path=None, use_space_char=False, **kwargs):
+ super(CPPDLabelDecode, self).__init__(character_dict_path, use_space_char)
def __call__(self, preds, label=None, *args, **kwargs):
if isinstance(preds, tuple):
if isinstance(preds[-1], dict):
- preds = preds[-1]['align'][-1].numpy()
+ preds = preds[-1]["align"][-1].numpy()
else:
preds = preds[-1].numpy()
if isinstance(preds, paddle.Tensor):
@@ -1217,5 +1206,5 @@ def __call__(self, preds, label=None, *args, **kwargs):
return text, label
def add_special_char(self, dict_character):
- dict_character = ['</s>'] + dict_character
+ dict_character = ["</s>"] + dict_character
return dict_character
diff --git a/ppocr/postprocess/sast_postprocess.py b/ppocr/postprocess/sast_postprocess.py
index 594bf17d6a..18e825245e 100755
--- a/ppocr/postprocess/sast_postprocess.py
+++ b/ppocr/postprocess/sast_postprocess.py
@@ -21,7 +21,7 @@
__dir__ = os.path.dirname(__file__)
sys.path.append(__dir__)
-sys.path.append(os.path.join(__dir__, '..'))
+sys.path.append(os.path.join(__dir__, ".."))
import numpy as np
from .locality_aware_nms import nms_locality
@@ -35,15 +35,16 @@ class SASTPostProcess(object):
The post process for SAST.
"""
- def __init__(self,
- score_thresh=0.5,
- nms_thresh=0.2,
- sample_pts_num=2,
- shrink_ratio_of_width=0.3,
- expand_scale=1.0,
- tcl_map_thresh=0.5,
- **kwargs):
-
+ def __init__(
+ self,
+ score_thresh=0.5,
+ nms_thresh=0.2,
+ sample_pts_num=2,
+ shrink_ratio_of_width=0.3,
+ expand_scale=1.0,
+ tcl_map_thresh=0.5,
+ **kwargs
+ ):
self.score_thresh = score_thresh
self.nms_thresh = nms_thresh
self.sample_pts_num = sample_pts_num
@@ -68,15 +69,13 @@ def point_pair2poly(self, point_pair_list):
point_list[point_num - 1 - idx] = point_pair[1]
return np.array(point_list).reshape(-1, 2)
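`point_pair2poly` turns the sampled (top, bottom) point pairs into one closed boundary: the first point of each pair fills the polygon front-to-back along the top edge, the second fills it back-to-front along the bottom edge. A toy run with three pairs:

import numpy as np

# three (top, bottom) point pairs sampled along a text center line
point_pair_list = [
    (np.array([0, 0]), np.array([0, 1])),
    (np.array([1, 0]), np.array([1, 1])),
    (np.array([2, 0]), np.array([2, 1])),
]
point_num = len(point_pair_list) * 2
point_list = [0] * point_num
for idx, point_pair in enumerate(point_pair_list):
    point_list[idx] = point_pair[0]                  # top edge, left to right
    point_list[point_num - 1 - idx] = point_pair[1]  # bottom edge, right to left
print(np.array(point_list).reshape(-1, 2))
# [[0 0] [1 0] [2 0] [2 1] [1 1] [0 1]]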
- def shrink_quad_along_width(self,
- quad,
- begin_width_ratio=0.,
- end_width_ratio=1.):
- """
+ def shrink_quad_along_width(self, quad, begin_width_ratio=0.0, end_width_ratio=1.0):
+ """
Generate shrink_quad_along_width.
"""
ratio_pair = np.array(
- [[begin_width_ratio], [end_width_ratio]], dtype=np.float32)
+ [[begin_width_ratio], [end_width_ratio]], dtype=np.float32
+ )
p0_1 = quad[0] + (quad[1] - quad[0]) * ratio_pair
p3_2 = quad[3] + (quad[2] - quad[3]) * ratio_pair
return np.array([p0_1[0], p0_1[1], p3_2[1], p3_2[0]])
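`shrink_quad_along_width` interpolates along the top edge p0→p1 and the bottom edge p3→p2 at the two given width ratios, so a ratio pair of (0.0, 0.5) keeps the left half of a quad. A quick numpy check on a 2x1 rectangle (toy input):

import numpy as np

quad = np.array([[0, 0], [2, 0], [2, 1], [0, 1]], dtype=np.float32)
ratio_pair = np.array([[0.0], [0.5]], dtype=np.float32)
p0_1 = quad[0] + (quad[1] - quad[0]) * ratio_pair  # points on the top edge
p3_2 = quad[3] + (quad[2] - quad[3]) * ratio_pair  # points on the bottom edge
print(np.array([p0_1[0], p0_1[1], p3_2[1], p3_2[0]]))
# [[0. 0.] [1. 0.] [1. 1.] [0. 1.]] -> the left half of the 2x1 quad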
@@ -86,23 +85,26 @@ def expand_poly_along_width(self, poly, shrink_ratio_of_width=0.3):
expand poly along width.
"""
point_num = poly.shape[0]
- left_quad = np.array(
- [poly[0], poly[1], poly[-2], poly[-1]], dtype=np.float32)
- left_ratio = -shrink_ratio_of_width * np.linalg.norm(left_quad[0] - left_quad[3]) / \
- (np.linalg.norm(left_quad[0] - left_quad[1]) + 1e-6)
- left_quad_expand = self.shrink_quad_along_width(left_quad, left_ratio,
- 1.0)
+ left_quad = np.array([poly[0], poly[1], poly[-2], poly[-1]], dtype=np.float32)
+ left_ratio = (
+ -shrink_ratio_of_width
+ * np.linalg.norm(left_quad[0] - left_quad[3])
+ / (np.linalg.norm(left_quad[0] - left_quad[1]) + 1e-6)
+ )
+ left_quad_expand = self.shrink_quad_along_width(left_quad, left_ratio, 1.0)
right_quad = np.array(
[
- poly[point_num // 2 - 2], poly[point_num // 2 - 1],
- poly[point_num // 2], poly[point_num // 2 + 1]
+ poly[point_num // 2 - 2],
+ poly[point_num // 2 - 1],
+ poly[point_num // 2],
+ poly[point_num // 2 + 1],
],
- dtype=np.float32)
- right_ratio = 1.0 + \
- shrink_ratio_of_width * np.linalg.norm(right_quad[0] - right_quad[3]) / \
- (np.linalg.norm(right_quad[0] - right_quad[1]) + 1e-6)
- right_quad_expand = self.shrink_quad_along_width(right_quad, 0.0,
- right_ratio)
+ dtype=np.float32,
+ )
+ right_ratio = 1.0 + shrink_ratio_of_width * np.linalg.norm(
+ right_quad[0] - right_quad[3]
+ ) / (np.linalg.norm(right_quad[0] - right_quad[1]) + 1e-6)
+ right_quad_expand = self.shrink_quad_along_width(right_quad, 0.0, right_ratio)
poly[0] = left_quad_expand[0]
poly[-1] = left_quad_expand[-1]
poly[point_num // 2 - 1] = right_quad_expand[1]
@@ -133,17 +135,21 @@ def quad_area(self, quad):
"""
compute area of a quad.
"""
- edge = [(quad[1][0] - quad[0][0]) * (quad[1][1] + quad[0][1]),
- (quad[2][0] - quad[1][0]) * (quad[2][1] + quad[1][1]),
- (quad[3][0] - quad[2][0]) * (quad[3][1] + quad[2][1]),
- (quad[0][0] - quad[3][0]) * (quad[0][1] + quad[3][1])]
- return np.sum(edge) / 2.
+ edge = [
+ (quad[1][0] - quad[0][0]) * (quad[1][1] + quad[0][1]),
+ (quad[2][0] - quad[1][0]) * (quad[2][1] + quad[1][1]),
+ (quad[3][0] - quad[2][0]) * (quad[3][1] + quad[2][1]),
+ (quad[0][0] - quad[3][0]) * (quad[0][1] + quad[3][1]),
+ ]
+ return np.sum(edge) / 2.0
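`quad_area` is the shoelace (trapezoid) formula written out edge by edge; the magnitude is the quad's area and the sign reflects the vertex order. Standalone, the same computation gives:

import numpy as np

def quad_area(quad):
    # shoelace formula over the four edges; sign follows vertex order
    edge = [
        (quad[1][0] - quad[0][0]) * (quad[1][1] + quad[0][1]),
        (quad[2][0] - quad[1][0]) * (quad[2][1] + quad[1][1]),
        (quad[3][0] - quad[2][0]) * (quad[3][1] + quad[2][1]),
        (quad[0][0] - quad[3][0]) * (quad[0][1] + quad[3][1]),
    ]
    return np.sum(edge) / 2.0

print(quad_area(np.array([[0, 0], [1, 0], [1, 1], [0, 1]])))  # -1.0, |area| = 1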
def nms(self, dets):
if self.is_python35:
from ppocr.utils.utility import check_install
- check_install('lanms', 'lanms-nova')
+
+ check_install("lanms", "lanms-nova")
import lanms
+
dets = lanms.merge_quadrangle_n9(dets, self.nms_thresh)
else:
dets = nms_locality(dets, self.nms_thresh)
@@ -169,8 +175,7 @@ def cluster_by_quads_tco(self, tcl_map, tcl_map_thresh, quads, tco_map):
m = quads.shape[0]
gt_tc = np.mean(quads, axis=1) # (m, 2)
- pred_tc_tile = np.tile(pred_tc[:, np.newaxis, :],
- (1, m, 1)) # (n, m, 2)
+ pred_tc_tile = np.tile(pred_tc[:, np.newaxis, :], (1, m, 1)) # (n, m, 2)
gt_tc_tile = np.tile(gt_tc[np.newaxis, :, :], (n, 1, 1)) # (n, m, 2)
dist_mat = np.linalg.norm(pred_tc_tile - gt_tc_tile, axis=2) # (n, m)
xy_text_assign = np.argmin(dist_mat, axis=1) + 1 # (n,)
@@ -182,47 +187,50 @@ def estimate_sample_pts_num(self, quad, xy_text):
"""
Estimate sample points number.
"""
- eh = (np.linalg.norm(quad[0] - quad[3]) +
- np.linalg.norm(quad[1] - quad[2])) / 2.0
- ew = (np.linalg.norm(quad[0] - quad[1]) +
- np.linalg.norm(quad[2] - quad[3])) / 2.0
+ eh = (
+ np.linalg.norm(quad[0] - quad[3]) + np.linalg.norm(quad[1] - quad[2])
+ ) / 2.0
+ ew = (
+ np.linalg.norm(quad[0] - quad[1]) + np.linalg.norm(quad[2] - quad[3])
+ ) / 2.0
dense_sample_pts_num = max(2, int(ew))
- dense_xy_center_line = xy_text[np.linspace(
- 0,
- xy_text.shape[0] - 1,
- dense_sample_pts_num,
- endpoint=True,
- dtype=np.float32).astype(np.int32)]
-
- dense_xy_center_line_diff = dense_xy_center_line[
- 1:] - dense_xy_center_line[:-1]
- estimate_arc_len = np.sum(
- np.linalg.norm(
- dense_xy_center_line_diff, axis=1))
+ dense_xy_center_line = xy_text[
+ np.linspace(
+ 0,
+ xy_text.shape[0] - 1,
+ dense_sample_pts_num,
+ endpoint=True,
+ dtype=np.float32,
+ ).astype(np.int32)
+ ]
+
+ dense_xy_center_line_diff = dense_xy_center_line[1:] - dense_xy_center_line[:-1]
+ estimate_arc_len = np.sum(np.linalg.norm(dense_xy_center_line_diff, axis=1))
sample_pts_num = max(2, int(estimate_arc_len / eh))
return sample_pts_num
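`estimate_sample_pts_num` densely samples the text center line, sums the distances between consecutive samples to approximate its arc length, and divides by the average text height to pick the number of point pairs. A simplified numpy version of that estimate (straight toy center line; the float32 casting of the real code is omitted):

import numpy as np

# toy center line: 101 (y, x) points along a straight 100 px run, text height 20
xy_text = np.stack([np.zeros(101), np.arange(101)], axis=1)
eh = 20.0
dense = xy_text[np.linspace(0, len(xy_text) - 1, 100, dtype=np.int32)]
arc_len = np.sum(np.linalg.norm(dense[1:] - dense[:-1], axis=1))
print(max(2, int(arc_len / eh)))  # 5 point pairs for this line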
- def detect_sast(self,
- tcl_map,
- tvo_map,
- tbo_map,
- tco_map,
- ratio_w,
- ratio_h,
- src_w,
- src_h,
- shrink_ratio_of_width=0.3,
- tcl_map_thresh=0.5,
- offset_expand=1.0,
- out_strid=4.0):
+ def detect_sast(
+ self,
+ tcl_map,
+ tvo_map,
+ tbo_map,
+ tco_map,
+ ratio_w,
+ ratio_h,
+ src_w,
+ src_h,
+ shrink_ratio_of_width=0.3,
+ tcl_map_thresh=0.5,
+ offset_expand=1.0,
+ out_strid=4.0,
+ ):
"""
first resize the tcl_map, tvo_map and tbo_map to the input_size, then restore the polys
"""
# restore quad
- scores, quads, xy_text = self.restore_quad(tcl_map, tcl_map_thresh,
- tvo_map)
+ scores, quads, xy_text = self.restore_quad(tcl_map, tcl_map_thresh, tvo_map)
dets = np.hstack((quads, scores)).astype(np.float32, copy=False)
dets = self.nms(dets)
if dets.shape[0] == 0:
@@ -237,7 +245,8 @@ def detect_sast(self,
# instance segmentation
# instance_count, instance_label_map = cv2.connectedComponents(tcl_map.astype(np.uint8), connectivity=8)
instance_count, instance_label_map = self.cluster_by_quads_tco(
- tcl_map, tcl_map_thresh, quads, tco_map)
+ tcl_map, tcl_map_thresh, quads, tco_map
+ )
# restore single poly with tcl instance.
poly_list = []
@@ -267,13 +276,14 @@ def detect_sast(self,
# sort xy_text
left_center_pt = np.array(
- [[(quad[0, 0] + quad[-1, 0]) / 2.0,
- (quad[0, 1] + quad[-1, 1]) / 2.0]]) # (1, 2)
+ [[(quad[0, 0] + quad[-1, 0]) / 2.0, (quad[0, 1] + quad[-1, 1]) / 2.0]]
+ ) # (1, 2)
right_center_pt = np.array(
- [[(quad[1, 0] + quad[2, 0]) / 2.0,
- (quad[1, 1] + quad[2, 1]) / 2.0]]) # (1, 2)
- proj_unit_vec = (right_center_pt - left_center_pt) / \
- (np.linalg.norm(right_center_pt - left_center_pt) + 1e-6)
+ [[(quad[1, 0] + quad[2, 0]) / 2.0, (quad[1, 1] + quad[2, 1]) / 2.0]]
+ ) # (1, 2)
+ proj_unit_vec = (right_center_pt - left_center_pt) / (
+ np.linalg.norm(right_center_pt - left_center_pt) + 1e-6
+ )
proj_value = np.sum(xy_text * proj_unit_vec, axis=1)
xy_text = xy_text[np.argsort(proj_value)]
@@ -282,49 +292,52 @@ def detect_sast(self,
sample_pts_num = self.estimate_sample_pts_num(quad, xy_text)
else:
sample_pts_num = self.sample_pts_num
- xy_center_line = xy_text[np.linspace(
- 0,
- xy_text.shape[0] - 1,
- sample_pts_num,
- endpoint=True,
- dtype=np.float32).astype(np.int32)]
+ xy_center_line = xy_text[
+ np.linspace(
+ 0,
+ xy_text.shape[0] - 1,
+ sample_pts_num,
+ endpoint=True,
+ dtype=np.float32,
+ ).astype(np.int32)
+ ]
point_pair_list = []
for x, y in xy_center_line:
# get corresponding offset
offset = tbo_map[y, x, :].reshape(2, 2)
if offset_expand != 1.0:
- offset_length = np.linalg.norm(
- offset, axis=1, keepdims=True)
+ offset_length = np.linalg.norm(offset, axis=1, keepdims=True)
expand_length = np.clip(
- offset_length * (offset_expand - 1),
- a_min=0.5,
- a_max=3.0)
+ offset_length * (offset_expand - 1), a_min=0.5, a_max=3.0
+ )
offset_detal = offset / offset_length * expand_length
offset = offset + offset_detal
# original point
ori_yx = np.array([y, x], dtype=np.float32)
- point_pair = (ori_yx + offset)[:, ::-1] * out_strid / np.array(
- [ratio_w, ratio_h]).reshape(-1, 2)
+ point_pair = (
+ (ori_yx + offset)[:, ::-1]
+ * out_strid
+ / np.array([ratio_w, ratio_h]).reshape(-1, 2)
+ )
point_pair_list.append(point_pair)
# ndarry: (x, 2), expand poly along width
detected_poly = self.point_pair2poly(point_pair_list)
- detected_poly = self.expand_poly_along_width(detected_poly,
- shrink_ratio_of_width)
- detected_poly[:, 0] = np.clip(
- detected_poly[:, 0], a_min=0, a_max=src_w)
- detected_poly[:, 1] = np.clip(
- detected_poly[:, 1], a_min=0, a_max=src_h)
+ detected_poly = self.expand_poly_along_width(
+ detected_poly, shrink_ratio_of_width
+ )
+ detected_poly[:, 0] = np.clip(detected_poly[:, 0], a_min=0, a_max=src_w)
+ detected_poly[:, 1] = np.clip(detected_poly[:, 1], a_min=0, a_max=src_h)
poly_list.append(detected_poly)
return poly_list
def __call__(self, outs_dict, shape_list):
- score_list = outs_dict['f_score']
- border_list = outs_dict['f_border']
- tvo_list = outs_dict['f_tvo']
- tco_list = outs_dict['f_tco']
+ score_list = outs_dict["f_score"]
+ border_list = outs_dict["f_border"]
+ tvo_list = outs_dict["f_tvo"]
+ tco_list = outs_dict["f_tco"]
if isinstance(score_list, paddle.Tensor):
score_list = score_list.numpy()
border_list = border_list.numpy()
@@ -351,7 +364,8 @@ def __call__(self, outs_dict, shape_list):
src_h,
shrink_ratio_of_width=self.shrink_ratio_of_width,
tcl_map_thresh=self.tcl_map_thresh,
- offset_expand=self.expand_scale)
- poly_lists.append({'points': np.array(poly_list)})
+ offset_expand=self.expand_scale,
+ )
+ poly_lists.append({"points": np.array(poly_list)})
return poly_lists
diff --git a/ppocr/postprocess/table_postprocess.py b/ppocr/postprocess/table_postprocess.py
index a47061f935..a18abd872a 100644
--- a/ppocr/postprocess/table_postprocess.py
+++ b/ppocr/postprocess/table_postprocess.py
@@ -19,17 +19,14 @@
class TableLabelDecode(AttnLabelDecode):
- """ """
+ """ """
- def __init__(self,
- character_dict_path,
- merge_no_span_structure=False,
- **kwargs):
+ def __init__(self, character_dict_path, merge_no_span_structure=False, **kwargs):
dict_character = []
with open(character_dict_path, "rb") as fin:
lines = fin.readlines()
for line in lines:
- line = line.decode('utf-8').strip("\n").strip("\r\n")
+ line = line.decode("utf-8").strip("\n").strip("\r\n")
dict_character.append(line)
if merge_no_span_structure:
@@ -43,11 +40,11 @@ def __init__(self,
for i, char in enumerate(dict_character):
self.dict[char] = i
self.character = dict_character
- self.td_token = ['<td>', '<td', '<td></td>']
+ self.td_token = ["<td>", "<td", "<td></td>"]
def __call__(self, preds, batch=None):
- structure_probs = preds['structure_probs']
- bbox_preds = preds['loc_preds']
+ structure_probs = preds["structure_probs"]
+ bbox_preds = preds["loc_preds"]
if isinstance(structure_probs, paddle.Tensor):
structure_probs = structure_probs.numpy()
if isinstance(bbox_preds, paddle.Tensor):
@@ -61,8 +58,7 @@ def __call__(self, preds, batch=None):
return result, label_decode_result
def decode(self, structure_probs, bbox_preds, shape_list):
- """convert text-label into text-index.
- """
+ """convert text-label into text-index."""
ignored_tokens = self.get_ignored_tokens()
end_idx = self.dict[self.end_str]
@@ -92,14 +88,13 @@ def decode(self, structure_probs, bbox_preds, shape_list):
structure_batch_list.append([structure_list, np.mean(score_list)])
bbox_batch_list.append(np.array(bbox_list))
result = {
- 'bbox_batch_list': bbox_batch_list,
- 'structure_batch_list': structure_batch_list,
+ "bbox_batch_list": bbox_batch_list,
+ "structure_batch_list": structure_batch_list,
}
return result
def decode_label(self, batch):
- """convert text-label into text-index.
- """
+ """convert text-label into text-index."""
structure_idx = batch[1]
gt_bbox_list = batch[2]
shape_list = batch[-1]
@@ -127,8 +122,8 @@ def decode_label(self, batch):
structure_batch_list.append(structure_list)
bbox_batch_list.append(bbox_list)
result = {
- 'bbox_batch_list': bbox_batch_list,
- 'structure_batch_list': structure_batch_list,
+ "bbox_batch_list": bbox_batch_list,
+ "structure_batch_list": structure_batch_list,
}
return result
@@ -140,28 +135,35 @@ def _bbox_decode(self, bbox, shape):
class TableMasterLabelDecode(TableLabelDecode):
- """ """
-
- def __init__(self,
- character_dict_path,
- box_shape='ori',
- merge_no_span_structure=True,
- **kwargs):
- super(TableMasterLabelDecode, self).__init__(character_dict_path,
- merge_no_span_structure)
+ """ """
+
+ def __init__(
+ self,
+ character_dict_path,
+ box_shape="ori",
+ merge_no_span_structure=True,
+ **kwargs
+ ):
+ super(TableMasterLabelDecode, self).__init__(
+ character_dict_path, merge_no_span_structure
+ )
self.box_shape = box_shape
assert box_shape in [
- 'ori', 'pad'
- ], 'The shape used for box normalization must be ori or pad'
+ "ori",
+ "pad",
+ ], "The shape used for box normalization must be ori or pad"
def add_special_char(self, dict_character):
- self.beg_str = '<SOS>'
- self.end_str = '<EOS>'
- self.unknown_str = '<UKN>'
- self.pad_str = '<PAD>'
+ self.beg_str = "<SOS>"
+ self.end_str = "<EOS>"
+ self.unknown_str = "<UKN>"
+ self.pad_str = "<PAD>"
dict_character = dict_character
dict_character = dict_character + [
- self.unknown_str, self.beg_str, self.end_str, self.pad_str
+ self.unknown_str,
+ self.beg_str,
+ self.end_str,
+ self.pad_str,
]
return dict_character
@@ -174,7 +176,7 @@ def get_ignored_tokens(self):
def _bbox_decode(self, bbox, shape):
h, w, ratio_h, ratio_w, pad_h, pad_w = shape
- if self.box_shape == 'pad':
+ if self.box_shape == "pad":
h, w = pad_h, pad_w
bbox[0::2] *= w
bbox[1::2] *= h
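`_bbox_decode` maps the normalized (x, y, x, y, ...) box coordinates back to pixels against either the original or the padded image shape; with `box_shape="pad"` the padded height and width are used. A toy re-implementation of the scaling (tuple layout taken from the unpacking above; values illustrative):

import numpy as np

def bbox_decode(bbox, shape, box_shape="ori"):
    # shape: (h, w, ratio_h, ratio_w, pad_h, pad_w)
    h, w, ratio_h, ratio_w, pad_h, pad_w = shape
    if box_shape == "pad":
        h, w = pad_h, pad_w
    bbox = bbox.copy()  # copy so the toy call is side-effect free
    bbox[0::2] *= w  # x coordinates
    bbox[1::2] *= h  # y coordinates
    return bbox

box = np.array([0.1, 0.2, 0.5, 0.4])
print(bbox_decode(box, (100, 200, 1.0, 1.0, 128, 256), box_shape="pad"))
# [ 25.6  25.6 128.   51.2]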
diff --git a/ppocr/postprocess/vqa_token_re_layoutlm_postprocess.py b/ppocr/postprocess/vqa_token_re_layoutlm_postprocess.py
index 64f7d76195..efdfffefa9 100644
--- a/ppocr/postprocess/vqa_token_re_layoutlm_postprocess.py
+++ b/ppocr/postprocess/vqa_token_re_layoutlm_postprocess.py
@@ -15,14 +15,14 @@
class VQAReTokenLayoutLMPostProcess(object):
- """ Convert between text-label and text-index """
+ """Convert between text-label and text-index"""
def __init__(self, **kwargs):
super(VQAReTokenLayoutLMPostProcess, self).__init__()
def __call__(self, preds, label=None, *args, **kwargs):
- pred_relations = preds['pred_relations']
- if isinstance(preds['pred_relations'], paddle.Tensor):
+ pred_relations = preds["pred_relations"]
+ if isinstance(preds["pred_relations"], paddle.Tensor):
pred_relations = pred_relations.numpy()
pred_relations = self.decode_pred(pred_relations)
@@ -35,21 +35,22 @@ def _metric(self, pred_relations, label):
return pred_relations, label[-1], label[-2]
def _infer(self, pred_relations, *args, **kwargs):
- ser_results = kwargs['ser_results']
- entity_idx_dict_batch = kwargs['entity_idx_dict_batch']
+ ser_results = kwargs["ser_results"]
+ entity_idx_dict_batch = kwargs["entity_idx_dict_batch"]
# merge relations and ocr info
results = []
for pred_relation, ser_result, entity_idx_dict in zip(
- pred_relations, ser_results, entity_idx_dict_batch):
+ pred_relations, ser_results, entity_idx_dict_batch
+ ):
result = []
used_tail_id = []
for relation in pred_relation:
- if relation['tail_id'] in used_tail_id:
+ if relation["tail_id"] in used_tail_id:
continue
- used_tail_id.append(relation['tail_id'])
- ocr_info_head = ser_result[entity_idx_dict[relation['head_id']]]
- ocr_info_tail = ser_result[entity_idx_dict[relation['tail_id']]]
+ used_tail_id.append(relation["tail_id"])
+ ocr_info_head = ser_result[entity_idx_dict[relation["head_id"]]]
+ ocr_info_tail = ser_result[entity_idx_dict[relation["tail_id"]]]
result.append((ocr_info_head, ocr_info_tail))
results.append(result)
return results
@@ -58,16 +59,16 @@ def decode_pred(self, pred_relations):
pred_relations_new = []
for pred_relation in pred_relations:
pred_relation_new = []
- pred_relation = pred_relation[1:pred_relation[0, 0, 0] + 1]
+ pred_relation = pred_relation[1 : pred_relation[0, 0, 0] + 1]
for relation in pred_relation:
relation_new = dict()
- relation_new['head_id'] = relation[0, 0]
- relation_new['head'] = tuple(relation[1])
- relation_new['head_type'] = relation[2, 0]
- relation_new['tail_id'] = relation[3, 0]
- relation_new['tail'] = tuple(relation[4])
- relation_new['tail_type'] = relation[5, 0]
- relation_new['type'] = relation[6, 0]
+ relation_new["head_id"] = relation[0, 0]
+ relation_new["head"] = tuple(relation[1])
+ relation_new["head_type"] = relation[2, 0]
+ relation_new["tail_id"] = relation[3, 0]
+ relation_new["tail"] = tuple(relation[4])
+ relation_new["tail_type"] = relation[5, 0]
+ relation_new["type"] = relation[6, 0]
pred_relation_new.append(relation_new)
pred_relations_new.append(pred_relation_new)
return pred_relations_new
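`decode_pred` reads a packed integer tensor per image: row 0 stores the number of valid relations at [0, 0, 0], and each following 7x2 row packs head_id, head span, head_type, tail_id, tail span, tail_type and relation type, exactly as indexed above. A numpy mock that decodes one relation (field layout inferred from this code, toy values):

import numpy as np

# one image, 1 relation; row 0 stores the count, row 1 is the record
pred_relation = np.zeros((2, 7, 2), dtype=np.int64)
pred_relation[0, 0, 0] = 1          # number of valid relations
pred_relation[1, 0, 0] = 3          # head_id
pred_relation[1, 1] = (10, 15)      # head token span
pred_relation[1, 3, 0] = 7          # tail_id
pred_relation[1, 4] = (20, 22)      # tail token span
pred_relation[1, 6, 0] = 1          # relation type

rel = pred_relation[1 : pred_relation[0, 0, 0] + 1][0]
decoded = {
    "head_id": int(rel[0, 0]),
    "head": tuple(int(v) for v in rel[1]),
    "tail_id": int(rel[3, 0]),
    "type": int(rel[6, 0]),
}
print(decoded)  # {'head_id': 3, 'head': (10, 15), 'tail_id': 7, 'type': 1}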
diff --git a/ppocr/postprocess/vqa_token_ser_layoutlm_postprocess.py b/ppocr/postprocess/vqa_token_ser_layoutlm_postprocess.py
index 5541da90a0..a10f070f24 100644
--- a/ppocr/postprocess/vqa_token_ser_layoutlm_postprocess.py
+++ b/ppocr/postprocess/vqa_token_ser_layoutlm_postprocess.py
@@ -17,7 +17,7 @@
class VQASerTokenLayoutLMPostProcess(object):
- """ Convert between text-label and text-index """
+ """Convert between text-label and text-index"""
def __init__(self, class_path, **kwargs):
super(VQASerTokenLayoutLMPostProcess, self).__init__()
@@ -59,17 +59,16 @@ def _metric(self, preds, label):
for i in range(pred_idxs.shape[0]):
for j in range(pred_idxs.shape[1]):
if label[i, j] != -100:
- label_decode_out_list[i].append(self.id2label_map[label[i,
- j]])
- decode_out_list[i].append(self.id2label_map[pred_idxs[i,
- j]])
+ label_decode_out_list[i].append(self.id2label_map[label[i, j]])
+ decode_out_list[i].append(self.id2label_map[pred_idxs[i, j]])
return decode_out_list, label_decode_out_list
def _infer(self, preds, segment_offset_ids, ocr_infos):
results = []
- for pred, segment_offset_id, ocr_info in zip(preds, segment_offset_ids,
- ocr_infos):
+ for pred, segment_offset_id, ocr_info in zip(
+ preds, segment_offset_ids, ocr_infos
+ ):
pred = np.argmax(pred, axis=1)
pred = [self.id2label_map[idx] for idx in pred]
diff --git a/ppocr/utils/e2e_metric/Deteval.py b/ppocr/utils/e2e_metric/Deteval.py
index c2a4383eed..f4767e16ef 100755
--- a/ppocr/utils/e2e_metric/Deteval.py
+++ b/ppocr/utils/e2e_metric/Deteval.py
@@ -29,9 +29,16 @@ def input_reading_mod(pred_dict):
det = []
n = len(pred_dict)
for i in range(n):
- points = pred_dict[i]['points']
- text = pred_dict[i]['texts']
- point = ",".join(map(str, points.reshape(-1, )))
+ points = pred_dict[i]["points"]
+ text = pred_dict[i]["texts"]
+ point = ",".join(
+ map(
+ str,
+ points.reshape(
+ -1,
+ ),
+ )
+ )
det.append([point, text])
return det
@@ -40,36 +47,37 @@ def gt_reading_mod(gt_dict):
gt = []
n = len(gt_dict)
for i in range(n):
- points = gt_dict[i]['points'].tolist()
+ points = gt_dict[i]["points"].tolist()
h = len(points)
- text = gt_dict[i]['text']
+ text = gt_dict[i]["text"]
xx = [
- np.array(
- ['x:'], dtype='<U2'), 0, np.array(
- ['y:'], dtype='<U2'), 0, np.array(
- [text], dtype='<U{}'.format(len(text))), np.array(
- ['#'], dtype='<U1')
+ np.array(["x:"], dtype="<U2"),
+ 0,
+ np.array(["y:"], dtype="<U2"),
+ 0,
+ np.array([text], dtype="<U{}".format(len(text))),
+ np.array(["#"], dtype="<U1"),
]
t_x, t_y = [], []
for j in range(h):
t_x.append(points[j][0])
t_y.append(points[j][1])
- xx[1] = np.array([t_x], dtype='int16')
- xx[3] = np.array([t_y], dtype='int16')
+ xx[1] = np.array([t_x], dtype="int16")
+ xx[3] = np.array([t_y], dtype="int16")
gt.append(xx)
return gt
def detection_filtering(detections, groundtruths, threshold=0.5):
for gt_id, gt in enumerate(groundtruths):
- if (gt[5] == '#') and (gt[1].shape[1] > 1):
+ if (gt[5] == "#") and (gt[1].shape[1] > 1):
gt_x = list(map(int, np.squeeze(gt[1])))
gt_y = list(map(int, np.squeeze(gt[3])))
for det_id, detection in enumerate(detections):
detection_orig = detection
- detection = [float(x) for x in detection[0].split(',')]
+ detection = [float(x) for x in detection[0].split(",")]
detection = list(map(int, detection))
det_x = detection[0::2]
det_y = detection[1::2]
@@ -84,14 +92,16 @@ def sigma_calculation(det_x, det_y, gt_x, gt_y):
"""
sigma = inter_area / gt_area
"""
- return np.round((area_of_intersection(det_x, det_y, gt_x, gt_y) /
- area(gt_x, gt_y)), 2)
+ return np.round(
+ (area_of_intersection(det_x, det_y, gt_x, gt_y) / area(gt_x, gt_y)), 2
+ )
def tau_calculation(det_x, det_y, gt_x, gt_y):
if area(det_x, det_y) == 0.0:
return 0
- return np.round((area_of_intersection(det_x, det_y, gt_x, gt_y) /
- area(det_x, det_y)), 2)
+ return np.round(
+ (area_of_intersection(det_x, det_y, gt_x, gt_y) / area(det_x, det_y)), 2
+ )
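`sigma` and `tau` are the two Deteval ratios: intersection over ground-truth area (recall-like, how much of the gt the detection covers) and intersection over detection area (precision-like). With axis-aligned toy boxes standing in for the polygon helpers:

# toy axis-aligned version of the sigma/tau ratios (the real code uses polygons)
def rect_area(x0, y0, x1, y1):
    return max(0, x1 - x0) * max(0, y1 - y0)

def rect_inter(a, b):
    x0, y0 = max(a[0], b[0]), max(a[1], b[1])
    x1, y1 = min(a[2], b[2]), min(a[3], b[3])
    return rect_area(x0, y0, x1, y1)

gt, det = (0, 0, 10, 10), (5, 0, 15, 10)
inter = rect_inter(gt, det)
sigma = inter / rect_area(*gt)   # 0.5: half of the gt is covered
tau = inter / rect_area(*det)    # 0.5: half of the detection lies on the gt
print(sigma, tau)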
##############################Initialization###################################
# global_sigma = []
@@ -101,18 +111,23 @@ def tau_calculation(det_x, det_y, gt_x, gt_y):
###############################################################################
for input_id in range(allInputs):
- if (input_id != '.DS_Store') and (input_id != 'Pascal_result.txt') and (
- input_id != 'Pascal_result_curved.txt') and (input_id != 'Pascal_result_non_curved.txt') and (
- input_id != 'Deteval_result.txt') and (input_id != 'Deteval_result_curved.txt') \
- and (input_id != 'Deteval_result_non_curved.txt'):
+ if (
+ (input_id != ".DS_Store")
+ and (input_id != "Pascal_result.txt")
+ and (input_id != "Pascal_result_curved.txt")
+ and (input_id != "Pascal_result_non_curved.txt")
+ and (input_id != "Deteval_result.txt")
+ and (input_id != "Deteval_result_curved.txt")
+ and (input_id != "Deteval_result_non_curved.txt")
+ ):
detections = input_reading_mod(pred_dict)
groundtruths = gt_reading_mod(gt_dir)
detections = detection_filtering(
- detections,
- groundtruths) # filters detections overlapping with DC area
+ detections, groundtruths
+ ) # filters detections overlapping with DC area
dc_id = []
for i in range(len(groundtruths)):
- if groundtruths[i][5] == '#':
+ if groundtruths[i][5] == "#":
dc_id.append(i)
cnt = 0
for a in dc_id:
@@ -129,7 +144,7 @@ def tau_calculation(det_x, det_y, gt_x, gt_y):
if len(detections) > 0:
for det_id, detection in enumerate(detections):
detection_orig = detection
- detection = [float(x) for x in detection[0].split(',')]
+ detection = [float(x) for x in detection[0].split(",")]
detection = list(map(int, detection))
pred_seq_str = detection_orig[1].strip()
det_x = detection[0::2]
@@ -139,9 +154,11 @@ def tau_calculation(det_x, det_y, gt_x, gt_y):
gt_seq_str = str(gt[4].tolist()[0])
local_sigma_table[gt_id, det_id] = sigma_calculation(
- det_x, det_y, gt_x, gt_y)
+ det_x, det_y, gt_x, gt_y
+ )
local_tau_table[gt_id, det_id] = tau_calculation(
- det_x, det_y, gt_x, gt_y)
+ det_x, det_y, gt_x, gt_y
+ )
local_pred_str[det_id] = pred_seq_str
local_gt_str[gt_id] = gt_seq_str
@@ -151,10 +168,10 @@ def tau_calculation(det_x, det_y, gt_x, gt_y):
global_gt_str = local_gt_str
single_data = {}
- single_data['sigma'] = global_sigma
- single_data['global_tau'] = global_tau
- single_data['global_pred_str'] = global_pred_str
- single_data['global_gt_str'] = global_gt_str
+ single_data["sigma"] = global_sigma
+ single_data["global_tau"] = global_tau
+ single_data["global_pred_str"] = global_pred_str
+ single_data["global_gt_str"] = global_gt_str
return single_data
@@ -166,25 +183,32 @@ def input_reading_mod(pred_dict):
det = []
n = len(pred_dict)
for i in range(n):
- points = pred_dict[i]['points']
- text = pred_dict[i]['texts']
- point = ",".join(map(str, points.reshape(-1, )))
+ points = pred_dict[i]["points"]
+ text = pred_dict[i]["texts"]
+ point = ",".join(
+ map(
+ str,
+ points.reshape(
+ -1,
+ ),
+ )
+ )
det.append([point, text])
return det
def gt_reading_mod(gt_dir, gt_id):
- gt = io.loadmat('%s/poly_gt_img%s.mat' % (gt_dir, gt_id))
- gt = gt['polygt']
+ gt = io.loadmat("%s/poly_gt_img%s.mat" % (gt_dir, gt_id))
+ gt = gt["polygt"]
return gt
def detection_filtering(detections, groundtruths, threshold=0.5):
for gt_id, gt in enumerate(groundtruths):
- if (gt[5] == '#') and (gt[1].shape[1] > 1):
+ if (gt[5] == "#") and (gt[1].shape[1] > 1):
gt_x = list(map(int, np.squeeze(gt[1])))
gt_y = list(map(int, np.squeeze(gt[3])))
for det_id, detection in enumerate(detections):
detection_orig = detection
- detection = [float(x) for x in detection[0].split(',')]
+ detection = [float(x) for x in detection[0].split(",")]
detection = list(map(int, detection))
det_x = detection[0::2]
det_y = detection[1::2]
@@ -199,14 +223,16 @@ def sigma_calculation(det_x, det_y, gt_x, gt_y):
"""
sigma = inter_area / gt_area
"""
- return np.round((area_of_intersection(det_x, det_y, gt_x, gt_y) /
- area(gt_x, gt_y)), 2)
+ return np.round(
+ (area_of_intersection(det_x, det_y, gt_x, gt_y) / area(gt_x, gt_y)), 2
+ )
def tau_calculation(det_x, det_y, gt_x, gt_y):
if area(det_x, det_y) == 0.0:
return 0
- return np.round((area_of_intersection(det_x, det_y, gt_x, gt_y) /
- area(det_x, det_y)), 2)
+ return np.round(
+ (area_of_intersection(det_x, det_y, gt_x, gt_y) / area(det_x, det_y)), 2
+ )
##############################Initialization###################################
# global_sigma = []
@@ -216,18 +242,23 @@ def tau_calculation(det_x, det_y, gt_x, gt_y):
###############################################################################
for input_id in range(allInputs):
- if (input_id != '.DS_Store') and (input_id != 'Pascal_result.txt') and (
- input_id != 'Pascal_result_curved.txt') and (input_id != 'Pascal_result_non_curved.txt') and (
- input_id != 'Deteval_result.txt') and (input_id != 'Deteval_result_curved.txt') \
- and (input_id != 'Deteval_result_non_curved.txt'):
+ if (
+ (input_id != ".DS_Store")
+ and (input_id != "Pascal_result.txt")
+ and (input_id != "Pascal_result_curved.txt")
+ and (input_id != "Pascal_result_non_curved.txt")
+ and (input_id != "Deteval_result.txt")
+ and (input_id != "Deteval_result_curved.txt")
+ and (input_id != "Deteval_result_non_curved.txt")
+ ):
detections = input_reading_mod(pred_dict)
groundtruths = gt_reading_mod(gt_dir, img_id).tolist()
detections = detection_filtering(
- detections,
- groundtruths) # filters detections overlapping with DC area
+ detections, groundtruths
+ ) # filters detections overlapping with DC area
dc_id = []
for i in range(len(groundtruths)):
- if groundtruths[i][5] == '#':
+ if groundtruths[i][5] == "#":
dc_id.append(i)
cnt = 0
for a in dc_id:
@@ -244,7 +275,7 @@ def tau_calculation(det_x, det_y, gt_x, gt_y):
if len(detections) > 0:
for det_id, detection in enumerate(detections):
detection_orig = detection
- detection = [float(x) for x in detection[0].split(',')]
+ detection = [float(x) for x in detection[0].split(",")]
detection = list(map(int, detection))
pred_seq_str = detection_orig[1].strip()
det_x = detection[0::2]
@@ -254,9 +285,11 @@ def tau_calculation(det_x, det_y, gt_x, gt_y):
gt_seq_str = str(gt[4].tolist()[0])
local_sigma_table[gt_id, det_id] = sigma_calculation(
- det_x, det_y, gt_x, gt_y)
+ det_x, det_y, gt_x, gt_y
+ )
local_tau_table[gt_id, det_id] = tau_calculation(
- det_x, det_y, gt_x, gt_y)
+ det_x, det_y, gt_x, gt_y
+ )
local_pred_str[det_id] = pred_seq_str
local_gt_str[gt_id] = gt_seq_str
@@ -266,10 +299,10 @@ def tau_calculation(det_x, det_y, gt_x, gt_y):
global_gt_str = local_gt_str
single_data = {}
- single_data['sigma'] = global_sigma
- single_data['global_tau'] = global_tau
- single_data['global_pred_str'] = global_pred_str
- single_data['global_gt_str'] = global_gt_str
+ single_data["sigma"] = global_sigma
+ single_data["global_tau"] = global_tau
+ single_data["global_pred_str"] = global_pred_str
+ single_data["global_gt_str"] = global_gt_str
return single_data
@@ -303,10 +336,9 @@ def get_intersection(pD, pG):
def detection_filtering(detections, groundtruths, threshold=0.5):
for gt in groundtruths:
- point_num = gt['points'].shape[1] // 2
- if gt['transcription'] == '###' and (point_num > 1):
- gt_p = np.array(gt['points']).reshape(point_num,
- 2).astype('int32')
+ point_num = gt["points"].shape[1] // 2
+ if gt["transcription"] == "###" and (point_num > 1):
+ gt_p = np.array(gt["points"]).reshape(point_num, 2).astype("int32")
gt_p = plg.Polygon(gt_p)
for det_id, detection in enumerate(detections):
@@ -318,8 +350,7 @@ def detection_filtering(detections, groundtruths, threshold=0.5):
det_p = plg.Polygon(det_p)
try:
- det_gt_iou = get_intersection(det_p,
- gt_p) / det_p.area()
+ det_gt_iou = get_intersection(det_p, gt_p) / det_p.area()
except:
print(det_x, det_y, gt_p)
if det_gt_iou > threshold:
@@ -332,7 +363,7 @@ def sigma_calculation(det_p, gt_p):
"""
sigma = inter_area / gt_area
"""
- if gt_p.area() == 0.:
+ if gt_p.area() == 0.0:
return 0
return get_intersection(det_p, gt_p) / gt_p.area()
@@ -340,7 +371,7 @@ def tau_calculation(det_p, gt_p):
"""
tau = inter_area / det_area
"""
- if det_p.area() == 0.:
+ if det_p.area() == 0.0:
return 0
return get_intersection(det_p, gt_p) / det_p.area()
@@ -352,12 +383,13 @@ def tau_calculation(det_p, gt_p):
groundtruths = gt_reading_mod(gt_label, text)
detections = detection_filtering(
- detections, groundtruths) # filters detections overlapping with DC area
+ detections, groundtruths
+ ) # filters detections overlapping with DC area
for idx in range(len(groundtruths) - 1, -1, -1):
- #NOTE: source code use 'orin' to indicate '#', here we use 'anno',
+ # NOTE: source code use 'orin' to indicate '#', here we use 'anno',
# which may cause slight drop in fscore, about 0.12
- if groundtruths[idx]['transcription'] == '###':
+ if groundtruths[idx]["transcription"] == "###":
groundtruths.pop(idx)
local_sigma_table = np.zeros((len(groundtruths), len(detections)))
@@ -366,10 +398,9 @@ def tau_calculation(det_p, gt_p):
for gt_id, gt in enumerate(groundtruths):
if len(detections) > 0:
for det_id, detection in enumerate(detections):
- point_num = gt['points'].shape[1] // 2
+ point_num = gt["points"].shape[1] // 2
- gt_p = np.array(gt['points']).reshape(point_num,
- 2).astype('int32')
+ gt_p = np.array(gt["points"]).reshape(point_num, 2).astype("int32")
gt_p = plg.Polygon(gt_p)
det_y = detection[0::2]
@@ -380,15 +411,14 @@ def tau_calculation(det_p, gt_p):
det_p = det_p.reshape(2, -1).transpose()
det_p = plg.Polygon(det_p)
- local_sigma_table[gt_id, det_id] = sigma_calculation(det_p,
- gt_p)
+ local_sigma_table[gt_id, det_id] = sigma_calculation(det_p, gt_p)
local_tau_table[gt_id, det_id] = tau_calculation(det_p, gt_p)
data = {}
- data['sigma'] = local_sigma_table
- data['global_tau'] = local_tau_table
- data['global_pred_str'] = ''
- data['global_gt_str'] = ''
+ data["sigma"] = local_sigma_table
+ data["global_tau"] = local_tau_table
+ data["global_pred_str"] = ""
+ data["global_gt_str"] = ""
return data
@@ -403,10 +433,10 @@ def combine_results(all_data, rec_flag=True):
global_gt_str = []
for data in all_data:
- global_sigma.append(data['sigma'])
- global_tau.append(data['global_tau'])
- global_pred_str.append(data['global_pred_str'])
- global_gt_str.append(data['global_gt_str'])
+ global_sigma.append(data["sigma"])
+ global_tau.append(data["global_tau"])
+ global_pred_str.append(data["global_pred_str"])
+ global_gt_str.append(data["global_gt_str"])
global_accumulative_recall = 0
global_accumulative_precision = 0
@@ -415,35 +445,52 @@ def combine_results(all_data, rec_flag=True):
hit_str_count = 0
hit_count = 0
- def one_to_one(local_sigma_table, local_tau_table,
- local_accumulative_recall, local_accumulative_precision,
- global_accumulative_recall, global_accumulative_precision,
- gt_flag, det_flag, idy, rec_flag):
+ def one_to_one(
+ local_sigma_table,
+ local_tau_table,
+ local_accumulative_recall,
+ local_accumulative_precision,
+ global_accumulative_recall,
+ global_accumulative_precision,
+ gt_flag,
+ det_flag,
+ idy,
+ rec_flag,
+ ):
hit_str_num = 0
for gt_id in range(num_gt):
gt_matching_qualified_sigma_candidates = np.where(
- local_sigma_table[gt_id, :] > tr)
- gt_matching_num_qualified_sigma_candidates = gt_matching_qualified_sigma_candidates[
- 0].shape[0]
+ local_sigma_table[gt_id, :] > tr
+ )
+ gt_matching_num_qualified_sigma_candidates = (
+ gt_matching_qualified_sigma_candidates[0].shape[0]
+ )
gt_matching_qualified_tau_candidates = np.where(
- local_tau_table[gt_id, :] > tp)
- gt_matching_num_qualified_tau_candidates = gt_matching_qualified_tau_candidates[
- 0].shape[0]
+ local_tau_table[gt_id, :] > tp
+ )
+ gt_matching_num_qualified_tau_candidates = (
+ gt_matching_qualified_tau_candidates[0].shape[0]
+ )
det_matching_qualified_sigma_candidates = np.where(
- local_sigma_table[:, gt_matching_qualified_sigma_candidates[0]]
- > tr)
- det_matching_num_qualified_sigma_candidates = det_matching_qualified_sigma_candidates[
- 0].shape[0]
+ local_sigma_table[:, gt_matching_qualified_sigma_candidates[0]] > tr
+ )
+ det_matching_num_qualified_sigma_candidates = (
+ det_matching_qualified_sigma_candidates[0].shape[0]
+ )
det_matching_qualified_tau_candidates = np.where(
- local_tau_table[:, gt_matching_qualified_tau_candidates[0]] >
- tp)
- det_matching_num_qualified_tau_candidates = det_matching_qualified_tau_candidates[
- 0].shape[0]
-
- if (gt_matching_num_qualified_sigma_candidates == 1) and (gt_matching_num_qualified_tau_candidates == 1) and \
- (det_matching_num_qualified_sigma_candidates == 1) and (
- det_matching_num_qualified_tau_candidates == 1):
+ local_tau_table[:, gt_matching_qualified_tau_candidates[0]] > tp
+ )
+ det_matching_num_qualified_tau_candidates = (
+ det_matching_qualified_tau_candidates[0].shape[0]
+ )
+
+ if (
+ (gt_matching_num_qualified_sigma_candidates == 1)
+ and (gt_matching_num_qualified_tau_candidates == 1)
+ and (det_matching_num_qualified_sigma_candidates == 1)
+ and (det_matching_num_qualified_tau_candidates == 1)
+ ):
global_accumulative_recall = global_accumulative_recall + 1.0
global_accumulative_precision = global_accumulative_precision + 1.0
local_accumulative_recall = local_accumulative_recall + 1.0
@@ -454,8 +501,7 @@ def one_to_one(local_sigma_table, local_tau_table,
# recg start
if rec_flag:
gt_str_cur = global_gt_str[idy][gt_id]
- pred_str_cur = global_pred_str[idy][matched_det_id[0]
- .tolist()[0]]
+ pred_str_cur = global_pred_str[idy][matched_det_id[0].tolist()[0]]
if pred_str_cur == gt_str_cur:
hit_str_num += 1
else:
@@ -463,12 +509,28 @@ def one_to_one(local_sigma_table, local_tau_table,
hit_str_num += 1
# recg end
det_flag[0, matched_det_id] = 1
- return local_accumulative_recall, local_accumulative_precision, global_accumulative_recall, global_accumulative_precision, gt_flag, det_flag, hit_str_num
-
- def one_to_many(local_sigma_table, local_tau_table,
- local_accumulative_recall, local_accumulative_precision,
- global_accumulative_recall, global_accumulative_precision,
- gt_flag, det_flag, idy, rec_flag):
+ return (
+ local_accumulative_recall,
+ local_accumulative_precision,
+ global_accumulative_recall,
+ global_accumulative_precision,
+ gt_flag,
+ det_flag,
+ hit_str_num,
+ )
+
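`one_to_one` accepts a (gt, det) pair only when the match is unambiguous in all four directions: the gt row has exactly one detection above tr in sigma and exactly one above tp in tau, and that candidate column points back to exactly one gt in both tables. A compact sketch of the fourfold uniqueness test (thresholds and tables illustrative):

import numpy as np

tr, tp = 0.7, 0.6
sigma = np.array([[0.9, 0.1], [0.0, 0.8]])  # rows: gts, cols: detections
tau = np.array([[0.8, 0.0], [0.1, 0.9]])

for gt_id in range(sigma.shape[0]):
    sig_cand = np.where(sigma[gt_id, :] > tr)[0]
    tau_cand = np.where(tau[gt_id, :] > tp)[0]
    # the candidate detection must also match only this gt
    back_sig = np.where(sigma[:, sig_cand] > tr)[0] if sig_cand.size else []
    back_tau = np.where(tau[:, tau_cand] > tp)[0] if tau_cand.size else []
    if len(sig_cand) == len(tau_cand) == len(back_sig) == len(back_tau) == 1:
        print(f"gt {gt_id} <-> det {sig_cand[0]} is a one-to-one match")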
+ def one_to_many(
+ local_sigma_table,
+ local_tau_table,
+ local_accumulative_recall,
+ local_accumulative_precision,
+ global_accumulative_recall,
+ global_accumulative_precision,
+ gt_flag,
+ det_flag,
+ idy,
+ rec_flag,
+ ):
hit_str_num = 0
for gt_id in range(num_gt):
# skip the following if the groundtruth was matched
@@ -480,21 +542,24 @@ def one_to_many(local_sigma_table, local_tau_table,
if num_non_zero_in_sigma >= k:
####search for all detections that overlaps with this groundtruth
- qualified_tau_candidates = np.where((local_tau_table[
- gt_id, :] >= tp) & (det_flag[0, :] == 0))
- num_qualified_tau_candidates = qualified_tau_candidates[
- 0].shape[0]
+ qualified_tau_candidates = np.where(
+ (local_tau_table[gt_id, :] >= tp) & (det_flag[0, :] == 0)
+ )
+ num_qualified_tau_candidates = qualified_tau_candidates[0].shape[0]
if num_qualified_tau_candidates == 1:
- if ((local_tau_table[gt_id, qualified_tau_candidates] >= tp)
- and
- (local_sigma_table[gt_id, qualified_tau_candidates] >=
- tr)):
+ if (local_tau_table[gt_id, qualified_tau_candidates] >= tp) and (
+ local_sigma_table[gt_id, qualified_tau_candidates] >= tr
+ ):
# became an one-to-one case
global_accumulative_recall = global_accumulative_recall + 1.0
- global_accumulative_precision = global_accumulative_precision + 1.0
+ global_accumulative_precision = (
+ global_accumulative_precision + 1.0
+ )
local_accumulative_recall = local_accumulative_recall + 1.0
- local_accumulative_precision = local_accumulative_precision + 1.0
+ local_accumulative_precision = (
+ local_accumulative_precision + 1.0
+ )
gt_flag[0, gt_id] = 1
det_flag[0, qualified_tau_candidates] = 1
@@ -502,22 +567,23 @@ def one_to_many(local_sigma_table, local_tau_table,
if rec_flag:
gt_str_cur = global_gt_str[idy][gt_id]
pred_str_cur = global_pred_str[idy][
- qualified_tau_candidates[0].tolist()[0]]
+ qualified_tau_candidates[0].tolist()[0]
+ ]
if pred_str_cur == gt_str_cur:
hit_str_num += 1
else:
if pred_str_cur.lower() == gt_str_cur.lower():
hit_str_num += 1
# recg end
- elif (np.sum(local_sigma_table[gt_id, qualified_tau_candidates])
- >= tr):
+ elif np.sum(local_sigma_table[gt_id, qualified_tau_candidates]) >= tr:
gt_flag[0, gt_id] = 1
det_flag[0, qualified_tau_candidates] = 1
# recg start
if rec_flag:
gt_str_cur = global_gt_str[idy][gt_id]
pred_str_cur = global_pred_str[idy][
- qualified_tau_candidates[0].tolist()[0]]
+ qualified_tau_candidates[0].tolist()[0]
+ ]
if pred_str_cur == gt_str_cur:
hit_str_num += 1
else:
@@ -526,17 +592,39 @@ def one_to_many(local_sigma_table, local_tau_table,
# recg end
global_accumulative_recall = global_accumulative_recall + fsc_k
- global_accumulative_precision = global_accumulative_precision + num_qualified_tau_candidates * fsc_k
+ global_accumulative_precision = (
+ global_accumulative_precision
+ + num_qualified_tau_candidates * fsc_k
+ )
local_accumulative_recall = local_accumulative_recall + fsc_k
- local_accumulative_precision = local_accumulative_precision + num_qualified_tau_candidates * fsc_k
-
- return local_accumulative_recall, local_accumulative_precision, global_accumulative_recall, global_accumulative_precision, gt_flag, det_flag, hit_str_num
-
- def many_to_one(local_sigma_table, local_tau_table,
- local_accumulative_recall, local_accumulative_precision,
- global_accumulative_recall, global_accumulative_precision,
- gt_flag, det_flag, idy, rec_flag):
+ local_accumulative_precision = (
+ local_accumulative_precision
+ + num_qualified_tau_candidates * fsc_k
+ )
+
+ return (
+ local_accumulative_recall,
+ local_accumulative_precision,
+ global_accumulative_recall,
+ global_accumulative_precision,
+ gt_flag,
+ det_flag,
+ hit_str_num,
+ )
+
+ def many_to_one(
+ local_sigma_table,
+ local_tau_table,
+ local_accumulative_recall,
+ local_accumulative_precision,
+ global_accumulative_recall,
+ global_accumulative_precision,
+ gt_flag,
+ det_flag,
+ idy,
+ rec_flag,
+ ):
hit_str_num = 0
for det_id in range(num_det):
# skip the following if the detection was matched
@@ -548,21 +636,24 @@ def many_to_one(local_sigma_table, local_tau_table,
if num_non_zero_in_tau >= k:
                ####search for all groundtruths that overlap with this detection
- qualified_sigma_candidates = np.where((
- local_sigma_table[:, det_id] >= tp) & (gt_flag[0, :] == 0))
- num_qualified_sigma_candidates = qualified_sigma_candidates[
- 0].shape[0]
+ qualified_sigma_candidates = np.where(
+ (local_sigma_table[:, det_id] >= tp) & (gt_flag[0, :] == 0)
+ )
+ num_qualified_sigma_candidates = qualified_sigma_candidates[0].shape[0]
if num_qualified_sigma_candidates == 1:
- if ((local_tau_table[qualified_sigma_candidates, det_id] >=
- tp) and
- (local_sigma_table[qualified_sigma_candidates, det_id]
- >= tr)):
+ if (local_tau_table[qualified_sigma_candidates, det_id] >= tp) and (
+ local_sigma_table[qualified_sigma_candidates, det_id] >= tr
+ ):
                        # becomes a one-to-one case
global_accumulative_recall = global_accumulative_recall + 1.0
- global_accumulative_precision = global_accumulative_precision + 1.0
+ global_accumulative_precision = (
+ global_accumulative_precision + 1.0
+ )
local_accumulative_recall = local_accumulative_recall + 1.0
- local_accumulative_precision = local_accumulative_precision + 1.0
+ local_accumulative_precision = (
+ local_accumulative_precision + 1.0
+ )
gt_flag[0, qualified_sigma_candidates] = 1
det_flag[0, det_id] = 1
@@ -571,8 +662,7 @@ def many_to_one(local_sigma_table, local_tau_table,
pred_str_cur = global_pred_str[idy][det_id]
gt_len = len(qualified_sigma_candidates[0])
for idx in range(gt_len):
- ele_gt_id = qualified_sigma_candidates[
- 0].tolist()[idx]
+ ele_gt_id = qualified_sigma_candidates[0].tolist()[idx]
if ele_gt_id not in global_gt_str[idy]:
continue
gt_str_cur = global_gt_str[idy][ele_gt_id]
@@ -580,13 +670,11 @@ def many_to_one(local_sigma_table, local_tau_table,
hit_str_num += 1
break
else:
- if pred_str_cur.lower() == gt_str_cur.lower(
- ):
+ if pred_str_cur.lower() == gt_str_cur.lower():
hit_str_num += 1
break
# recg end
- elif (np.sum(local_tau_table[qualified_sigma_candidates,
- det_id]) >= tp):
+ elif np.sum(local_tau_table[qualified_sigma_candidates, det_id]) >= tp:
det_flag[0, det_id] = 1
gt_flag[0, qualified_sigma_candidates] = 1
# recg start
@@ -594,8 +682,7 @@ def many_to_one(local_sigma_table, local_tau_table,
pred_str_cur = global_pred_str[idy][det_id]
gt_len = len(qualified_sigma_candidates[0])
for idx in range(gt_len):
- ele_gt_id = qualified_sigma_candidates[0].tolist()[
- idx]
+ ele_gt_id = qualified_sigma_candidates[0].tolist()[idx]
if ele_gt_id not in global_gt_str[idy]:
continue
gt_str_cur = global_gt_str[idy][ele_gt_id]
@@ -608,12 +695,28 @@ def many_to_one(local_sigma_table, local_tau_table,
break
# recg end
- global_accumulative_recall = global_accumulative_recall + num_qualified_sigma_candidates * fsc_k
- global_accumulative_precision = global_accumulative_precision + fsc_k
-
- local_accumulative_recall = local_accumulative_recall + num_qualified_sigma_candidates * fsc_k
+ global_accumulative_recall = (
+ global_accumulative_recall
+ + num_qualified_sigma_candidates * fsc_k
+ )
+ global_accumulative_precision = (
+ global_accumulative_precision + fsc_k
+ )
+
+ local_accumulative_recall = (
+ local_accumulative_recall
+ + num_qualified_sigma_candidates * fsc_k
+ )
local_accumulative_precision = local_accumulative_precision + fsc_k
- return local_accumulative_recall, local_accumulative_precision, global_accumulative_recall, global_accumulative_precision, gt_flag, det_flag, hit_str_num
+ return (
+ local_accumulative_recall,
+ local_accumulative_precision,
+ global_accumulative_recall,
+ global_accumulative_precision,
+ gt_flag,
+ det_flag,
+ hit_str_num,
+ )
for idx in range(len(global_sigma)):
local_sigma_table = np.array(global_sigma[idx])
@@ -631,26 +734,71 @@ def many_to_one(local_sigma_table, local_tau_table,
det_flag = np.zeros((1, num_det))
#######first check for one-to-one case##########
- local_accumulative_recall, local_accumulative_precision, global_accumulative_recall, global_accumulative_precision, \
- gt_flag, det_flag, hit_str_num = one_to_one(local_sigma_table, local_tau_table,
- local_accumulative_recall, local_accumulative_precision,
- global_accumulative_recall, global_accumulative_precision,
- gt_flag, det_flag, idx, rec_flag)
+ (
+ local_accumulative_recall,
+ local_accumulative_precision,
+ global_accumulative_recall,
+ global_accumulative_precision,
+ gt_flag,
+ det_flag,
+ hit_str_num,
+ ) = one_to_one(
+ local_sigma_table,
+ local_tau_table,
+ local_accumulative_recall,
+ local_accumulative_precision,
+ global_accumulative_recall,
+ global_accumulative_precision,
+ gt_flag,
+ det_flag,
+ idx,
+ rec_flag,
+ )
hit_str_count += hit_str_num
#######then check for one-to-many case##########
- local_accumulative_recall, local_accumulative_precision, global_accumulative_recall, global_accumulative_precision, \
- gt_flag, det_flag, hit_str_num = one_to_many(local_sigma_table, local_tau_table,
- local_accumulative_recall, local_accumulative_precision,
- global_accumulative_recall, global_accumulative_precision,
- gt_flag, det_flag, idx, rec_flag)
+ (
+ local_accumulative_recall,
+ local_accumulative_precision,
+ global_accumulative_recall,
+ global_accumulative_precision,
+ gt_flag,
+ det_flag,
+ hit_str_num,
+ ) = one_to_many(
+ local_sigma_table,
+ local_tau_table,
+ local_accumulative_recall,
+ local_accumulative_precision,
+ global_accumulative_recall,
+ global_accumulative_precision,
+ gt_flag,
+ det_flag,
+ idx,
+ rec_flag,
+ )
hit_str_count += hit_str_num
#######then check for many-to-one case##########
- local_accumulative_recall, local_accumulative_precision, global_accumulative_recall, global_accumulative_precision, \
- gt_flag, det_flag, hit_str_num = many_to_one(local_sigma_table, local_tau_table,
- local_accumulative_recall, local_accumulative_precision,
- global_accumulative_recall, global_accumulative_precision,
- gt_flag, det_flag, idx, rec_flag)
+ (
+ local_accumulative_recall,
+ local_accumulative_precision,
+ global_accumulative_recall,
+ global_accumulative_precision,
+ gt_flag,
+ det_flag,
+ hit_str_num,
+ ) = many_to_one(
+ local_sigma_table,
+ local_tau_table,
+ local_accumulative_recall,
+ local_accumulative_precision,
+ global_accumulative_recall,
+ global_accumulative_precision,
+ gt_flag,
+ det_flag,
+ idx,
+ rec_flag,
+ )
hit_str_count += hit_str_num
try:
@@ -684,22 +832,21 @@ def many_to_one(local_sigma_table, local_tau_table,
precision_e2e = 0
try:
- f_score_e2e = 2 * precision_e2e * recall_e2e / (
- precision_e2e + recall_e2e)
+ f_score_e2e = 2 * precision_e2e * recall_e2e / (precision_e2e + recall_e2e)
except ZeroDivisionError:
f_score_e2e = 0
final = {
- 'total_num_gt': total_num_gt,
- 'total_num_det': total_num_det,
- 'global_accumulative_recall': global_accumulative_recall,
- 'hit_str_count': hit_str_count,
- 'recall': recall,
- 'precision': precision,
- 'f_score': f_score,
- 'seqerr': seqerr,
- 'recall_e2e': recall_e2e,
- 'precision_e2e': precision_e2e,
- 'f_score_e2e': f_score_e2e
+ "total_num_gt": total_num_gt,
+ "total_num_det": total_num_det,
+ "global_accumulative_recall": global_accumulative_recall,
+ "hit_str_count": hit_str_count,
+ "recall": recall,
+ "precision": precision,
+ "f_score": f_score,
+ "seqerr": seqerr,
+ "recall_e2e": recall_e2e,
+ "precision_e2e": precision_e2e,
+ "f_score_e2e": f_score_e2e,
}
return final
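Note: the three matching passes reformatted above implement the DetEval protocol. `local_sigma_table[i, j]` holds the overlap of groundtruth `i` with detection `j` normalized by the groundtruth area, `local_tau_table[i, j]` the same overlap normalized by the detection area, and pairs are accepted as one-to-one, one-to-many, or many-to-one against the thresholds `tr` and `tp` (split/merge cases are discounted by `fsc_k`). A minimal, self-contained sketch of just the one-to-one pass, with invented toy tables and a simplified uniqueness test:

```python
# Minimal sketch of the DetEval-style one-to-one pass above. sigma/tau,
# tr, tp mirror the diff; the toy tables are invented, and the
# one-to-many / many-to-one passes (with the fsc_k penalty) are omitted.
import numpy as np

tr, tp = 0.7, 0.6  # recall-side and precision-side overlap thresholds

# sigma[i, j]: overlap of groundtruth i with detection j, over area(G_i)
# tau[i, j]:   the same overlap, over area(D_j)
sigma = np.array([[0.9, 0.0],
                  [0.1, 0.8]])
tau = np.array([[0.8, 0.0],
                [0.0, 0.9]])

num_gt, num_det = sigma.shape
gt_flag = np.zeros(num_gt, dtype=bool)
det_flag = np.zeros(num_det, dtype=bool)
matches = 0

for gt_id in range(num_gt):
    for det_id in range(num_det):
        if gt_flag[gt_id] or det_flag[det_id]:
            continue  # skip anything already matched, as in the diff
        # simplified uniqueness test: the candidate must be the only
        # qualifying detection in its row and groundtruth in its column
        unique_row = (sigma[gt_id, :] >= tr).sum() == 1
        unique_col = (tau[:, det_id] >= tp).sum() == 1
        if (unique_row and unique_col
                and sigma[gt_id, det_id] >= tr and tau[gt_id, det_id] >= tp):
            gt_flag[gt_id] = det_flag[det_id] = True
            matches += 1

print(matches / num_gt, matches / num_det)  # recall, precision -> 1.0 1.0
```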
diff --git a/ppocr/utils/e2e_metric/polygon_fast.py b/ppocr/utils/e2e_metric/polygon_fast.py
index 81c9ad7067..2e2d947c84 100755
--- a/ppocr/utils/e2e_metric/polygon_fast.py
+++ b/ppocr/utils/e2e_metric/polygon_fast.py
@@ -13,6 +13,7 @@
# limitations under the License.
import numpy as np
from shapely.geometry import Polygon
+
"""
:param det_x: [1, N] Xs of detection's vertices
:param det_y: [1, N] Ys of detection's vertices
@@ -72,12 +73,12 @@ def area_of_union(det_x, det_y, gt_x, gt_y):
def iou(det_x, det_y, gt_x, gt_y):
return area_of_intersection(det_x, det_y, gt_x, gt_y) / (
- area_of_union(det_x, det_y, gt_x, gt_y) + 1.0)
+ area_of_union(det_x, det_y, gt_x, gt_y) + 1.0
+ )
def iod(det_x, det_y, gt_x, gt_y):
"""
    This helper determines the fraction of the intersection area over the detection area
"""
- return area_of_intersection(det_x, det_y, gt_x, gt_y) / (
- area(det_x, det_y) + 1.0)
+ return area_of_intersection(det_x, det_y, gt_x, gt_y) / (area(det_x, det_y) + 1.0)
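For reference, the two helpers reformatted above can be exercised standalone with shapely. This is a hedged check with invented coordinates, not PaddleOCR code; note that the `+ 1.0` the diff keeps in both denominators prevents a ZeroDivisionError on degenerate polygons at the cost of slightly under-reporting the true ratio:

```python
# Standalone check of the polygon iou/iod helpers, shapely only.
import numpy as np
from shapely.geometry import Polygon

def poly(xs, ys):
    return Polygon(np.stack([xs, ys], axis=1))

# two overlapping 4x4 axis-aligned squares (toy data)
det_x, det_y = np.array([0, 4, 4, 0]), np.array([0, 0, 4, 4])
gt_x, gt_y = np.array([2, 6, 6, 2]), np.array([0, 0, 4, 4])

p_det, p_gt = poly(det_x, det_y), poly(gt_x, gt_y)
inter = p_det.intersection(p_gt).area   # 8.0
union = p_det.union(p_gt).area          # 24.0

iou = inter / (union + 1.0)             # 0.32 instead of the exact 1/3
iod = inter / (p_det.area + 1.0)        # intersection over detection area
print(iou, iod)
```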
diff --git a/ppocr/utils/e2e_utils/extract_batchsize.py b/ppocr/utils/e2e_utils/extract_batchsize.py
index e99a833ea7..f1ab77bd0f 100644
--- a/ppocr/utils/e2e_utils/extract_batchsize.py
+++ b/ppocr/utils/e2e_utils/extract_batchsize.py
@@ -4,8 +4,7 @@
def org_tcl_rois(batch_size, pos_lists, pos_masks, label_lists, tcl_bs):
- """
- """
+ """ """
pos_lists_, pos_masks_, label_lists_ = [], [], []
img_bs = batch_size
ngpu = int(batch_size / img_bs)
@@ -51,8 +50,9 @@ def org_tcl_rois(batch_size, pos_lists, pos_masks, label_lists, tcl_bs):
return pos_lists_, pos_masks_, label_lists_
-def pre_process(label_list, pos_list, pos_mask, max_text_length, max_text_nums,
- pad_num, tcl_bs):
+def pre_process(
+ label_list, pos_list, pos_mask, max_text_length, max_text_nums, pad_num, tcl_bs
+):
label_list = label_list.numpy()
batch, _, _, _ = label_list.shape
pos_list = pos_list.numpy()
@@ -66,8 +66,9 @@ def pre_process(label_list, pos_list, pos_mask, max_text_length, max_text_nums,
pos_list_t.append(pos_list[i][j])
pos_mask_t.append(pos_mask[i][j])
label_list_t.append(label_list[i][j])
- pos_list, pos_mask, label_list = org_tcl_rois(batch, pos_list_t, pos_mask_t,
- label_list_t, tcl_bs)
+ pos_list, pos_mask, label_list = org_tcl_rois(
+ batch, pos_list_t, pos_mask_t, label_list_t, tcl_bs
+ )
label = []
tt = [l.tolist() for l in label_list]
for i in range(tcl_bs):
@@ -79,9 +80,9 @@ def pre_process(label_list, pos_list, pos_mask, max_text_length, max_text_nums,
break
label.append(k)
label = paddle.to_tensor(label)
- label = paddle.cast(label, dtype='int64')
+ label = paddle.cast(label, dtype="int64")
pos_list = paddle.to_tensor(pos_list)
pos_mask = paddle.to_tensor(pos_mask)
label_list = paddle.squeeze(paddle.to_tensor(label_list), axis=2)
- label_list = paddle.cast(label_list, dtype='int32')
+ label_list = paddle.cast(label_list, dtype="int32")
return pos_list, pos_mask, label_list, label
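Before the `paddle.cast` calls above, `pre_process` right-pads every decoded label sequence to a fixed length. A dependency-free sketch of that padding step; `max_text_length` and `pad_num` are assumed toy values here, not the config defaults:

```python
# Hedged sketch of the label padding that pre_process performs: each
# sequence is truncated/right-padded with pad_num up to max_text_length,
# then treated as int64 (the diff does this with paddle.cast).
import numpy as np

max_text_length, pad_num = 8, 36  # assumed values for illustration

def pad_label(seq):
    out = list(seq)[:max_text_length]
    out += [pad_num] * (max_text_length - len(out))
    return np.asarray(out, dtype=np.int64)

print(pad_label([3, 14, 15, 9]))  # [ 3 14 15  9 36 36 36 36]
```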
diff --git a/ppocr/utils/e2e_utils/extract_textpoint_fast.py b/ppocr/utils/e2e_utils/extract_textpoint_fast.py
index a85b8e78ea..67a89e2d43 100644
--- a/ppocr/utils/e2e_utils/extract_textpoint_fast.py
+++ b/ppocr/utils/e2e_utils/extract_textpoint_fast.py
@@ -29,7 +29,7 @@ def get_dict(character_dict_path):
with open(character_dict_path, "rb") as fin:
lines = fin.readlines()
for line in lines:
- line = line.decode('utf-8').strip("\n").strip("\r\n")
+ line = line.decode("utf-8").strip("\n").strip("\r\n")
character_str += line
dict_character = list(character_str)
return dict_character
@@ -83,36 +83,39 @@ def ctc_greedy_decoder(probs_seq, blank=95, keep_blank_in_idxs=True):
raw_str = np.argmax(np.array(probs_seq), axis=1)
remove_blank_in_pos = None if keep_blank_in_idxs else blank
dedup_str, keep_idx_list = get_keep_pos_idxs(
- raw_str, remove_blank=remove_blank_in_pos)
+ raw_str, remove_blank=remove_blank_in_pos
+ )
dst_str = remove_blank(dedup_str, blank=blank)
return dst_str, keep_idx_list
-def instance_ctc_greedy_decoder(gather_info,
- logits_map,
- pts_num=4,
- point_gather_mode=None):
+def instance_ctc_greedy_decoder(
+ gather_info, logits_map, pts_num=4, point_gather_mode=None
+):
_, _, C = logits_map.shape
- if point_gather_mode == 'align':
+ if point_gather_mode == "align":
insert_num = 0
gather_info = np.array(gather_info)
length = len(gather_info) - 1
for index in range(length):
- stride_y = np.abs(gather_info[index + insert_num][0] - gather_info[
- index + 1 + insert_num][0])
- stride_x = np.abs(gather_info[index + insert_num][1] - gather_info[
- index + 1 + insert_num][1])
+ stride_y = np.abs(
+ gather_info[index + insert_num][0]
+ - gather_info[index + 1 + insert_num][0]
+ )
+ stride_x = np.abs(
+ gather_info[index + insert_num][1]
+ - gather_info[index + 1 + insert_num][1]
+ )
max_points = int(max(stride_x, stride_y))
- stride = (gather_info[index + insert_num] -
- gather_info[index + 1 + insert_num]) / (max_points)
+ stride = (
+ gather_info[index + insert_num] - gather_info[index + 1 + insert_num]
+ ) / (max_points)
insert_num_temp = max_points - 1
for i in range(int(insert_num_temp)):
- insert_value = gather_info[index + insert_num] - (i + 1
- ) * stride
+ insert_value = gather_info[index + insert_num] - (i + 1) * stride
insert_index = index + i + 1 + insert_num
- gather_info = np.insert(
- gather_info, insert_index, insert_value, axis=0)
+ gather_info = np.insert(gather_info, insert_index, insert_value, axis=0)
insert_num += insert_num_temp
gather_info = gather_info.tolist()
else:
@@ -128,11 +131,9 @@ def instance_ctc_greedy_decoder(gather_info,
return dst_str, keep_gather_list
-def ctc_decoder_for_image(gather_info_list,
- logits_map,
- Lexicon_Table,
- pts_num=6,
- point_gather_mode=None):
+def ctc_decoder_for_image(
+ gather_info_list, logits_map, Lexicon_Table, pts_num=6, point_gather_mode=None
+):
"""
    CTC decoder run over every gathered text instance.
"""
@@ -145,8 +146,9 @@ def ctc_decoder_for_image(gather_info_list,
gather_info,
logits_map,
pts_num=pts_num,
- point_gather_mode=point_gather_mode)
- dst_str_readable = ''.join([Lexicon_Table[idx] for idx in dst_str])
+ point_gather_mode=point_gather_mode,
+ )
+ dst_str_readable = "".join([Lexicon_Table[idx] for idx in dst_str])
if len(dst_str_readable) < 2:
continue
decoder_str.append(dst_str_readable)
@@ -172,8 +174,7 @@ def sort_part_with_direction(pos_list, point_direction):
pos_list = np.array(pos_list).reshape(-1, 2)
point_direction = f_direction[pos_list[:, 0], pos_list[:, 1]] # x, y
point_direction = point_direction[:, ::-1] # x, y -> y, x
- sorted_point, sorted_direction = sort_part_with_direction(pos_list,
- point_direction)
+ sorted_point, sorted_direction = sort_part_with_direction(pos_list, point_direction)
point_num = len(sorted_point)
if point_num >= 16:
@@ -181,12 +182,14 @@ def sort_part_with_direction(pos_list, point_direction):
first_part_point = sorted_point[:middle_num]
first_point_direction = sorted_direction[:middle_num]
sorted_fist_part_point, sorted_fist_part_direction = sort_part_with_direction(
- first_part_point, first_point_direction)
+ first_part_point, first_point_direction
+ )
last_part_point = sorted_point[middle_num:]
last_point_direction = sorted_direction[middle_num:]
sorted_last_part_point, sorted_last_part_direction = sort_part_with_direction(
- last_part_point, last_point_direction)
+ last_part_point, last_point_direction
+ )
sorted_point = sorted_fist_part_point + sorted_last_part_point
sorted_direction = sorted_fist_part_direction + sorted_last_part_direction
@@ -214,7 +217,7 @@ def sort_and_expand_with_direction(pos_list, f_direction):
point_num = len(sorted_list)
sub_direction_len = max(point_num // 3, 2)
left_direction = point_direction[:sub_direction_len, :]
- right_dirction = point_direction[point_num - sub_direction_len:, :]
+ right_dirction = point_direction[point_num - sub_direction_len :, :]
left_average_direction = -np.mean(left_direction, axis=0, keepdims=True)
left_average_len = np.linalg.norm(left_average_direction)
@@ -226,17 +229,24 @@ def sort_and_expand_with_direction(pos_list, f_direction):
right_step = right_average_direction / (right_average_len + 1e-6)
right_start = np.array(sorted_list[-1])
- append_num = max(
- int((left_average_len + right_average_len) / 2.0 * 0.15), 1)
+ append_num = max(int((left_average_len + right_average_len) / 2.0 * 0.15), 1)
left_list = []
right_list = []
for i in range(append_num):
- ly, lx = np.round(left_start + left_step * (i + 1)).flatten().astype(
- 'int32').tolist()
+ ly, lx = (
+ np.round(left_start + left_step * (i + 1))
+ .flatten()
+ .astype("int32")
+ .tolist()
+ )
if ly < h and lx < w and (ly, lx) not in left_list:
left_list.append((ly, lx))
- ry, rx = np.round(right_start + right_step * (i + 1)).flatten().astype(
- 'int32').tolist()
+ ry, rx = (
+ np.round(right_start + right_step * (i + 1))
+ .flatten()
+ .astype("int32")
+ .tolist()
+ )
if ry < h and rx < w and (ry, rx) not in right_list:
right_list.append((ry, rx))
@@ -256,7 +266,7 @@ def sort_and_expand_with_direction_v2(pos_list, f_direction, binary_tcl_map):
point_num = len(sorted_list)
sub_direction_len = max(point_num // 3, 2)
left_direction = point_direction[:sub_direction_len, :]
- right_dirction = point_direction[point_num - sub_direction_len:, :]
+ right_dirction = point_direction[point_num - sub_direction_len :, :]
left_average_direction = -np.mean(left_direction, axis=0, keepdims=True)
left_average_len = np.linalg.norm(left_average_direction)
@@ -268,15 +278,18 @@ def sort_and_expand_with_direction_v2(pos_list, f_direction, binary_tcl_map):
right_step = right_average_direction / (right_average_len + 1e-6)
right_start = np.array(sorted_list[-1])
- append_num = max(
- int((left_average_len + right_average_len) / 2.0 * 0.15), 1)
+ append_num = max(int((left_average_len + right_average_len) / 2.0 * 0.15), 1)
max_append_num = 2 * append_num
left_list = []
right_list = []
for i in range(max_append_num):
- ly, lx = np.round(left_start + left_step * (i + 1)).flatten().astype(
- 'int32').tolist()
+ ly, lx = (
+ np.round(left_start + left_step * (i + 1))
+ .flatten()
+ .astype("int32")
+ .tolist()
+ )
if ly < h and lx < w and (ly, lx) not in left_list:
if binary_tcl_map[ly, lx] > 0.5:
left_list.append((ly, lx))
@@ -284,8 +297,12 @@ def sort_and_expand_with_direction_v2(pos_list, f_direction, binary_tcl_map):
break
for i in range(max_append_num):
- ry, rx = np.round(right_start + right_step * (i + 1)).flatten().astype(
- 'int32').tolist()
+ ry, rx = (
+ np.round(right_start + right_step * (i + 1))
+ .flatten()
+ .astype("int32")
+ .tolist()
+ )
if ry < h and rx < w and (ry, rx) not in right_list:
if binary_tcl_map[ry, rx] > 0.5:
right_list.append((ry, rx))
@@ -308,9 +325,8 @@ def point_pair2poly(point_pair_list):
return np.array(point_list).reshape(-1, 2)
-def shrink_quad_along_width(quad, begin_width_ratio=0., end_width_ratio=1.):
- ratio_pair = np.array(
- [[begin_width_ratio], [end_width_ratio]], dtype=np.float32)
+def shrink_quad_along_width(quad, begin_width_ratio=0.0, end_width_ratio=1.0):
+ ratio_pair = np.array([[begin_width_ratio], [end_width_ratio]], dtype=np.float32)
p0_1 = quad[0] + (quad[1] - quad[0]) * ratio_pair
p3_2 = quad[3] + (quad[2] - quad[3]) * ratio_pair
return np.array([p0_1[0], p0_1[1], p3_2[1], p3_2[0]])
@@ -321,19 +337,25 @@ def expand_poly_along_width(poly, shrink_ratio_of_width=0.3):
expand poly along width.
"""
point_num = poly.shape[0]
- left_quad = np.array(
- [poly[0], poly[1], poly[-2], poly[-1]], dtype=np.float32)
- left_ratio = -shrink_ratio_of_width * np.linalg.norm(left_quad[0] - left_quad[3]) / \
- (np.linalg.norm(left_quad[0] - left_quad[1]) + 1e-6)
+ left_quad = np.array([poly[0], poly[1], poly[-2], poly[-1]], dtype=np.float32)
+ left_ratio = (
+ -shrink_ratio_of_width
+ * np.linalg.norm(left_quad[0] - left_quad[3])
+ / (np.linalg.norm(left_quad[0] - left_quad[1]) + 1e-6)
+ )
left_quad_expand = shrink_quad_along_width(left_quad, left_ratio, 1.0)
right_quad = np.array(
[
- poly[point_num // 2 - 2], poly[point_num // 2 - 1],
- poly[point_num // 2], poly[point_num // 2 + 1]
+ poly[point_num // 2 - 2],
+ poly[point_num // 2 - 1],
+ poly[point_num // 2],
+ poly[point_num // 2 + 1],
],
- dtype=np.float32)
- right_ratio = 1.0 + shrink_ratio_of_width * np.linalg.norm(right_quad[0] - right_quad[3]) / \
- (np.linalg.norm(right_quad[0] - right_quad[1]) + 1e-6)
+ dtype=np.float32,
+ )
+ right_ratio = 1.0 + shrink_ratio_of_width * np.linalg.norm(
+ right_quad[0] - right_quad[3]
+ ) / (np.linalg.norm(right_quad[0] - right_quad[1]) + 1e-6)
right_quad_expand = shrink_quad_along_width(right_quad, 0.0, right_ratio)
poly[0] = left_quad_expand[0]
poly[-1] = left_quad_expand[-1]
@@ -342,53 +364,59 @@ def expand_poly_along_width(poly, shrink_ratio_of_width=0.3):
return poly
-def restore_poly(instance_yxs_list, seq_strs, p_border, ratio_w, ratio_h, src_w,
- src_h, valid_set):
+def restore_poly(
+ instance_yxs_list, seq_strs, p_border, ratio_w, ratio_h, src_w, src_h, valid_set
+):
poly_list = []
keep_str_list = []
for yx_center_line, keep_str in zip(instance_yxs_list, seq_strs):
if len(keep_str) < 2:
- print('--> too short, {}'.format(keep_str))
+ print("--> too short, {}".format(keep_str))
continue
offset_expand = 1.0
- if valid_set == 'totaltext':
+ if valid_set == "totaltext":
offset_expand = 1.2
point_pair_list = []
for y, x in yx_center_line:
offset = p_border[:, y, x].reshape(2, 2) * offset_expand
ori_yx = np.array([y, x], dtype=np.float32)
- point_pair = (ori_yx + offset)[:, ::-1] * 4.0 / np.array(
- [ratio_w, ratio_h]).reshape(-1, 2)
+ point_pair = (
+ (ori_yx + offset)[:, ::-1]
+ * 4.0
+ / np.array([ratio_w, ratio_h]).reshape(-1, 2)
+ )
point_pair_list.append(point_pair)
detected_poly = point_pair2poly(point_pair_list)
detected_poly = expand_poly_along_width(
- detected_poly, shrink_ratio_of_width=0.2)
+ detected_poly, shrink_ratio_of_width=0.2
+ )
detected_poly[:, 0] = np.clip(detected_poly[:, 0], a_min=0, a_max=src_w)
detected_poly[:, 1] = np.clip(detected_poly[:, 1], a_min=0, a_max=src_h)
keep_str_list.append(keep_str)
- if valid_set == 'partvgg':
+ if valid_set == "partvgg":
middle_point = len(detected_poly) // 2
- detected_poly = detected_poly[
- [0, middle_point - 1, middle_point, -1], :]
+ detected_poly = detected_poly[[0, middle_point - 1, middle_point, -1], :]
poly_list.append(detected_poly)
- elif valid_set == 'totaltext':
+ elif valid_set == "totaltext":
poly_list.append(detected_poly)
else:
- print('--> Not supported format.')
+            print("--> Unsupported format.")
exit(-1)
return poly_list, keep_str_list
-def generate_pivot_list_fast(p_score,
- p_char_maps,
- f_direction,
- Lexicon_Table,
- score_thresh=0.5,
- point_gather_mode=None):
+def generate_pivot_list_fast(
+ p_score,
+ p_char_maps,
+ f_direction,
+ Lexicon_Table,
+ score_thresh=0.5,
+ point_gather_mode=None,
+):
"""
return center point and end point of TCL instance; filter with the char maps;
"""
@@ -397,7 +425,8 @@ def generate_pivot_list_fast(p_score,
p_tcl_map = (p_score > score_thresh) * 1.0
skeleton_map = thin(p_tcl_map.astype(np.uint8))
instance_count, instance_label_map = cv2.connectedComponents(
- skeleton_map.astype(np.uint8), connectivity=8)
+ skeleton_map.astype(np.uint8), connectivity=8
+ )
# get TCL Instance
all_pos_yxs = []
@@ -411,7 +440,8 @@ def generate_pivot_list_fast(p_score,
continue
pos_list_sorted = sort_and_expand_with_direction_v2(
- pos_list, f_direction, p_tcl_map)
+ pos_list, f_direction, p_tcl_map
+ )
all_pos_yxs.append(pos_list_sorted)
p_char_maps = p_char_maps.transpose([1, 2, 0])
@@ -419,7 +449,8 @@ def generate_pivot_list_fast(p_score,
all_pos_yxs,
logits_map=p_char_maps,
Lexicon_Table=Lexicon_Table,
- point_gather_mode=point_gather_mode)
+ point_gather_mode=point_gather_mode,
+ )
return keep_yxs_list, decoded_str
@@ -432,8 +463,7 @@ def extract_main_direction(pos_list, f_direction):
point_direction = f_direction[pos_list[:, 0], pos_list[:, 1]]
point_direction = point_direction[:, ::-1] # x, y -> y, x
average_direction = np.mean(point_direction, axis=0, keepdims=True)
- average_direction = average_direction / (
- np.linalg.norm(average_direction) + 1e-6)
+ average_direction = average_direction / (np.linalg.norm(average_direction) + 1e-6)
return average_direction
@@ -471,8 +501,7 @@ def sort_part_with_direction(pos_list_full, point_direction):
pos_list = np.array(pos_list).reshape(-1, 3)
point_direction = f_direction[pos_list[:, 1], pos_list[:, 2]] # x, y
point_direction = point_direction[:, ::-1] # x, y -> y, x
- sorted_point, sorted_direction = sort_part_with_direction(pos_list,
- point_direction)
+ sorted_point, sorted_direction = sort_part_with_direction(pos_list, point_direction)
point_num = len(sorted_point)
if point_num >= 16:
@@ -480,12 +509,14 @@ def sort_part_with_direction(pos_list_full, point_direction):
first_part_point = sorted_point[:middle_num]
first_point_direction = sorted_direction[:middle_num]
sorted_fist_part_point, sorted_fist_part_direction = sort_part_with_direction(
- first_part_point, first_point_direction)
+ first_part_point, first_point_direction
+ )
last_part_point = sorted_point[middle_num:]
last_point_direction = sorted_direction[middle_num:]
sorted_last_part_point, sorted_last_part_direction = sort_part_with_direction(
- last_part_point, last_point_direction)
+ last_part_point, last_point_direction
+ )
sorted_point = sorted_fist_part_point + sorted_last_part_point
sorted_direction = sorted_fist_part_direction + sorted_last_part_direction
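The `ctc_greedy_decoder` in this file follows the standard CTC best-path rule: take the argmax per timestep, collapse consecutive repeats, then drop the blank symbol. A minimal sketch with invented probabilities; the `keep_idx_list` bookkeeping that maps kept symbols back to pixel positions is omitted:

```python
# Minimal CTC greedy (best-path) decoding in the spirit of
# ctc_greedy_decoder above: argmax, collapse repeats, remove blank.
import numpy as np

def ctc_greedy(probs_seq, blank):
    raw = np.argmax(probs_seq, axis=1)
    dedup = [int(r) for i, r in enumerate(raw) if i == 0 or r != raw[i - 1]]
    return [r for r in dedup if r != blank]

# 4 timesteps over a 3-symbol alphabet plus blank (index 3)
probs = np.array([[0.1, 0.8, 0.05, 0.05],   # -> 1
                  [0.1, 0.8, 0.05, 0.05],   # -> 1 (repeat, collapsed)
                  [0.05, 0.05, 0.1, 0.8],   # -> blank
                  [0.8, 0.1, 0.05, 0.05]])  # -> 0
print(ctc_greedy(probs, blank=3))  # [1, 0]
```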
diff --git a/ppocr/utils/e2e_utils/extract_textpoint_slow.py b/ppocr/utils/e2e_utils/extract_textpoint_slow.py
index ace46fba37..123e900c67 100644
--- a/ppocr/utils/e2e_utils/extract_textpoint_slow.py
+++ b/ppocr/utils/e2e_utils/extract_textpoint_slow.py
@@ -29,7 +29,7 @@ def get_dict(character_dict_path):
with open(character_dict_path, "rb") as fin:
lines = fin.readlines()
for line in lines:
- line = line.decode('utf-8').strip("\n").strip("\r\n")
+ line = line.decode("utf-8").strip("\n").strip("\r\n")
character_str += line
dict_character = list(character_str)
return dict_character
@@ -44,8 +44,11 @@ def point_pair2poly(point_pair_list):
pair_length = np.linalg.norm(point_pair[0] - point_pair[1])
pair_length_list.append(pair_length)
pair_length_list = np.array(pair_length_list)
- pair_info = (pair_length_list.max(), pair_length_list.min(),
- pair_length_list.mean())
+ pair_info = (
+ pair_length_list.max(),
+ pair_length_list.min(),
+ pair_length_list.mean(),
+ )
point_num = len(point_pair_list) * 2
point_list = [0] * point_num
@@ -55,12 +58,11 @@ def point_pair2poly(point_pair_list):
return np.array(point_list).reshape(-1, 2), pair_info
-def shrink_quad_along_width(quad, begin_width_ratio=0., end_width_ratio=1.):
+def shrink_quad_along_width(quad, begin_width_ratio=0.0, end_width_ratio=1.0):
"""
Generate shrink_quad_along_width.
"""
- ratio_pair = np.array(
- [[begin_width_ratio], [end_width_ratio]], dtype=np.float32)
+ ratio_pair = np.array([[begin_width_ratio], [end_width_ratio]], dtype=np.float32)
p0_1 = quad[0] + (quad[1] - quad[0]) * ratio_pair
p3_2 = quad[3] + (quad[2] - quad[3]) * ratio_pair
return np.array([p0_1[0], p0_1[1], p3_2[1], p3_2[0]])
@@ -71,20 +73,25 @@ def expand_poly_along_width(poly, shrink_ratio_of_width=0.3):
expand poly along width.
"""
point_num = poly.shape[0]
- left_quad = np.array(
- [poly[0], poly[1], poly[-2], poly[-1]], dtype=np.float32)
- left_ratio = -shrink_ratio_of_width * np.linalg.norm(left_quad[0] - left_quad[3]) / \
- (np.linalg.norm(left_quad[0] - left_quad[1]) + 1e-6)
+ left_quad = np.array([poly[0], poly[1], poly[-2], poly[-1]], dtype=np.float32)
+ left_ratio = (
+ -shrink_ratio_of_width
+ * np.linalg.norm(left_quad[0] - left_quad[3])
+ / (np.linalg.norm(left_quad[0] - left_quad[1]) + 1e-6)
+ )
left_quad_expand = shrink_quad_along_width(left_quad, left_ratio, 1.0)
right_quad = np.array(
[
- poly[point_num // 2 - 2], poly[point_num // 2 - 1],
- poly[point_num // 2], poly[point_num // 2 + 1]
+ poly[point_num // 2 - 2],
+ poly[point_num // 2 - 1],
+ poly[point_num // 2],
+ poly[point_num // 2 + 1],
],
- dtype=np.float32)
- right_ratio = 1.0 + \
- shrink_ratio_of_width * np.linalg.norm(right_quad[0] - right_quad[3]) / \
- (np.linalg.norm(right_quad[0] - right_quad[1]) + 1e-6)
+ dtype=np.float32,
+ )
+ right_ratio = 1.0 + shrink_ratio_of_width * np.linalg.norm(
+ right_quad[0] - right_quad[3]
+ ) / (np.linalg.norm(right_quad[0] - right_quad[1]) + 1e-6)
right_quad_expand = shrink_quad_along_width(right_quad, 0.0, right_ratio)
poly[0] = left_quad_expand[0]
poly[-1] = left_quad_expand[-1]
@@ -141,14 +148,13 @@ def ctc_greedy_decoder(probs_seq, blank=95, keep_blank_in_idxs=True):
raw_str = np.argmax(np.array(probs_seq), axis=1)
remove_blank_in_pos = None if keep_blank_in_idxs else blank
dedup_str, keep_idx_list = get_keep_pos_idxs(
- raw_str, remove_blank=remove_blank_in_pos)
+ raw_str, remove_blank=remove_blank_in_pos
+ )
dst_str = remove_blank(dedup_str, blank=blank)
return dst_str, keep_idx_list
-def instance_ctc_greedy_decoder(gather_info,
- logits_map,
- keep_blank_in_idxs=True):
+def instance_ctc_greedy_decoder(gather_info, logits_map, keep_blank_in_idxs=True):
"""
gather_info: [[x, y], [x, y] ...]
logits_map: H x W X (n_chars + 1)
@@ -158,20 +164,21 @@ def instance_ctc_greedy_decoder(gather_info,
logits_seq = logits_map[list(ys), list(xs)] # n x 96
probs_seq = softmax(logits_seq)
dst_str, keep_idx_list = ctc_greedy_decoder(
- probs_seq, blank=C - 1, keep_blank_in_idxs=keep_blank_in_idxs)
+ probs_seq, blank=C - 1, keep_blank_in_idxs=keep_blank_in_idxs
+ )
keep_gather_list = [gather_info[idx] for idx in keep_idx_list]
return dst_str, keep_gather_list
-def ctc_decoder_for_image(gather_info_list, logits_map,
- keep_blank_in_idxs=True):
+def ctc_decoder_for_image(gather_info_list, logits_map, keep_blank_in_idxs=True):
"""
    CTC decoder run over every gathered text instance.
"""
decoder_results = []
for gather_info in gather_info_list:
res = instance_ctc_greedy_decoder(
- gather_info, logits_map, keep_blank_in_idxs=keep_blank_in_idxs)
+ gather_info, logits_map, keep_blank_in_idxs=keep_blank_in_idxs
+ )
decoder_results.append(res)
return decoder_results
@@ -194,8 +201,7 @@ def sort_part_with_direction(pos_list, point_direction):
pos_list = np.array(pos_list).reshape(-1, 2)
point_direction = f_direction[pos_list[:, 0], pos_list[:, 1]] # x, y
point_direction = point_direction[:, ::-1] # x, y -> y, x
- sorted_point, sorted_direction = sort_part_with_direction(pos_list,
- point_direction)
+ sorted_point, sorted_direction = sort_part_with_direction(pos_list, point_direction)
point_num = len(sorted_point)
if point_num >= 16:
@@ -203,12 +209,14 @@ def sort_part_with_direction(pos_list, point_direction):
first_part_point = sorted_point[:middle_num]
first_point_direction = sorted_direction[:middle_num]
sorted_fist_part_point, sorted_fist_part_direction = sort_part_with_direction(
- first_part_point, first_point_direction)
+ first_part_point, first_point_direction
+ )
last_part_point = sorted_point[middle_num:]
last_point_direction = sorted_direction[middle_num:]
sorted_last_part_point, sorted_last_part_direction = sort_part_with_direction(
- last_part_point, last_point_direction)
+ last_part_point, last_point_direction
+ )
sorted_point = sorted_fist_part_point + sorted_last_part_point
sorted_direction = sorted_fist_part_direction + sorted_last_part_direction
@@ -237,7 +245,7 @@ def sort_and_expand_with_direction(pos_list, f_direction):
point_num = len(sorted_list)
sub_direction_len = max(point_num // 3, 2)
left_direction = point_direction[:sub_direction_len, :]
- right_dirction = point_direction[point_num - sub_direction_len:, :]
+ right_dirction = point_direction[point_num - sub_direction_len :, :]
left_average_direction = -np.mean(left_direction, axis=0, keepdims=True)
left_average_len = np.linalg.norm(left_average_direction)
@@ -249,17 +257,24 @@ def sort_and_expand_with_direction(pos_list, f_direction):
right_step = right_average_direction / (right_average_len + 1e-6)
right_start = np.array(sorted_list[-1])
- append_num = max(
- int((left_average_len + right_average_len) / 2.0 * 0.15), 1)
+ append_num = max(int((left_average_len + right_average_len) / 2.0 * 0.15), 1)
left_list = []
right_list = []
for i in range(append_num):
- ly, lx = np.round(left_start + left_step * (i + 1)).flatten().astype(
- 'int32').tolist()
+ ly, lx = (
+ np.round(left_start + left_step * (i + 1))
+ .flatten()
+ .astype("int32")
+ .tolist()
+ )
if ly < h and lx < w and (ly, lx) not in left_list:
left_list.append((ly, lx))
- ry, rx = np.round(right_start + right_step * (i + 1)).flatten().astype(
- 'int32').tolist()
+ ry, rx = (
+ np.round(right_start + right_step * (i + 1))
+ .flatten()
+ .astype("int32")
+ .tolist()
+ )
if ry < h and rx < w and (ry, rx) not in right_list:
right_list.append((ry, rx))
@@ -280,7 +295,7 @@ def sort_and_expand_with_direction_v2(pos_list, f_direction, binary_tcl_map):
point_num = len(sorted_list)
sub_direction_len = max(point_num // 3, 2)
left_direction = point_direction[:sub_direction_len, :]
- right_dirction = point_direction[point_num - sub_direction_len:, :]
+ right_dirction = point_direction[point_num - sub_direction_len :, :]
left_average_direction = -np.mean(left_direction, axis=0, keepdims=True)
left_average_len = np.linalg.norm(left_average_direction)
@@ -292,15 +307,18 @@ def sort_and_expand_with_direction_v2(pos_list, f_direction, binary_tcl_map):
right_step = right_average_direction / (right_average_len + 1e-6)
right_start = np.array(sorted_list[-1])
- append_num = max(
- int((left_average_len + right_average_len) / 2.0 * 0.15), 1)
+ append_num = max(int((left_average_len + right_average_len) / 2.0 * 0.15), 1)
max_append_num = 2 * append_num
left_list = []
right_list = []
for i in range(max_append_num):
- ly, lx = np.round(left_start + left_step * (i + 1)).flatten().astype(
- 'int32').tolist()
+ ly, lx = (
+ np.round(left_start + left_step * (i + 1))
+ .flatten()
+ .astype("int32")
+ .tolist()
+ )
if ly < h and lx < w and (ly, lx) not in left_list:
if binary_tcl_map[ly, lx] > 0.5:
left_list.append((ly, lx))
@@ -308,8 +326,12 @@ def sort_and_expand_with_direction_v2(pos_list, f_direction, binary_tcl_map):
break
for i in range(max_append_num):
- ry, rx = np.round(right_start + right_step * (i + 1)).flatten().astype(
- 'int32').tolist()
+ ry, rx = (
+ np.round(right_start + right_step * (i + 1))
+ .flatten()
+ .astype("int32")
+ .tolist()
+ )
if ry < h and rx < w and (ry, rx) not in right_list:
if binary_tcl_map[ry, rx] > 0.5:
right_list.append((ry, rx))
@@ -320,13 +342,15 @@ def sort_and_expand_with_direction_v2(pos_list, f_direction, binary_tcl_map):
return all_list
-def generate_pivot_list_curved(p_score,
- p_char_maps,
- f_direction,
- score_thresh=0.5,
- is_expand=True,
- is_backbone=False,
- image_id=0):
+def generate_pivot_list_curved(
+ p_score,
+ p_char_maps,
+ f_direction,
+ score_thresh=0.5,
+ is_expand=True,
+ is_backbone=False,
+ image_id=0,
+):
"""
return center point and end point of TCL instance; filter with the char maps;
"""
@@ -335,7 +359,8 @@ def generate_pivot_list_curved(p_score,
p_tcl_map = (p_score > score_thresh) * 1.0
skeleton_map = thin(p_tcl_map)
instance_count, instance_label_map = cv2.connectedComponents(
- skeleton_map.astype(np.uint8), connectivity=8)
+ skeleton_map.astype(np.uint8), connectivity=8
+ )
# get TCL Instance
all_pos_yxs = []
@@ -355,7 +380,8 @@ def generate_pivot_list_curved(p_score,
if is_expand:
pos_list_sorted = sort_and_expand_with_direction_v2(
- pos_list, f_direction, p_tcl_map)
+ pos_list, f_direction, p_tcl_map
+ )
else:
pos_list_sorted, _ = sort_with_direction(pos_list, f_direction)
all_pos_yxs.append(pos_list_sorted)
@@ -363,7 +389,8 @@ def generate_pivot_list_curved(p_score,
    # use decoder to filter background points.
p_char_maps = p_char_maps.transpose([1, 2, 0])
decode_res = ctc_decoder_for_image(
- all_pos_yxs, logits_map=p_char_maps, keep_blank_in_idxs=True)
+ all_pos_yxs, logits_map=p_char_maps, keep_blank_in_idxs=True
+ )
for decoded_str, keep_yxs_list in decode_res:
if is_backbone:
keep_yxs_list_with_id = add_id(keep_yxs_list, image_id=image_id)
@@ -379,12 +406,9 @@ def generate_pivot_list_curved(p_score,
return center_pos_yxs, end_points_yxs
-def generate_pivot_list_horizontal(p_score,
- p_char_maps,
- f_direction,
- score_thresh=0.5,
- is_backbone=False,
- image_id=0):
+def generate_pivot_list_horizontal(
+ p_score, p_char_maps, f_direction, score_thresh=0.5, is_backbone=False, image_id=0
+):
"""
return center point and end point of TCL instance; filter with the char maps;
"""
@@ -392,7 +416,8 @@ def generate_pivot_list_horizontal(p_score,
f_direction = f_direction.transpose(1, 2, 0)
p_tcl_map_bi = (p_score > score_thresh) * 1.0
instance_count, instance_label_map = cv2.connectedComponents(
- p_tcl_map_bi.astype(np.uint8), connectivity=8)
+ p_tcl_map_bi.astype(np.uint8), connectivity=8
+ )
# get TCL Instance
all_pos_yxs = []
@@ -411,12 +436,11 @@ def generate_pivot_list_horizontal(p_score,
continue
# add rule here
- main_direction = extract_main_direction(pos_list,
- f_direction) # y x
+ main_direction = extract_main_direction(pos_list, f_direction) # y x
reference_directin = np.array([0, 1]).reshape([-1, 2]) # y x
- is_h_angle = abs(np.sum(
- main_direction * reference_directin)) < math.cos(math.pi / 180 *
- 70)
+ is_h_angle = abs(np.sum(main_direction * reference_directin)) < math.cos(
+ math.pi / 180 * 70
+ )
point_yxs = np.array(pos_list)
max_y, max_x = np.max(point_yxs, axis=0)
@@ -427,24 +451,24 @@ def generate_pivot_list_horizontal(p_score,
if is_h_len:
xs = np.unique(xs)
for x in xs:
- ys = instance_label_map[:, x].copy().reshape((-1, ))
+ ys = instance_label_map[:, x].copy().reshape((-1,))
y = int(np.where(ys == instance_id)[0].mean())
pos_list_final.append((y, x))
else:
ys = np.unique(ys)
for y in ys:
- xs = instance_label_map[y, :].copy().reshape((-1, ))
+ xs = instance_label_map[y, :].copy().reshape((-1,))
x = int(np.where(xs == instance_id)[0].mean())
pos_list_final.append((y, x))
- pos_list_sorted, _ = sort_with_direction(pos_list_final,
- f_direction)
+ pos_list_sorted, _ = sort_with_direction(pos_list_final, f_direction)
all_pos_yxs.append(pos_list_sorted)
    # use decoder to filter background points.
p_char_maps = p_char_maps.transpose([1, 2, 0])
decode_res = ctc_decoder_for_image(
- all_pos_yxs, logits_map=p_char_maps, keep_blank_in_idxs=True)
+ all_pos_yxs, logits_map=p_char_maps, keep_blank_in_idxs=True
+ )
for decoded_str, keep_yxs_list in decode_res:
if is_backbone:
keep_yxs_list_with_id = add_id(keep_yxs_list, image_id=image_id)
@@ -459,13 +483,15 @@ def generate_pivot_list_horizontal(p_score,
return center_pos_yxs, end_points_yxs
-def generate_pivot_list_slow(p_score,
- p_char_maps,
- f_direction,
- score_thresh=0.5,
- is_backbone=False,
- is_curved=True,
- image_id=0):
+def generate_pivot_list_slow(
+ p_score,
+ p_char_maps,
+ f_direction,
+ score_thresh=0.5,
+ is_backbone=False,
+ is_curved=True,
+ image_id=0,
+):
"""
    Wrap all the functions together.
"""
@@ -477,7 +503,8 @@ def generate_pivot_list_slow(p_score,
score_thresh=score_thresh,
is_expand=True,
is_backbone=is_backbone,
- image_id=image_id)
+ image_id=image_id,
+ )
else:
return generate_pivot_list_horizontal(
p_score,
@@ -485,7 +512,8 @@ def generate_pivot_list_slow(p_score,
f_direction,
score_thresh=score_thresh,
is_backbone=is_backbone,
- image_id=image_id)
+ image_id=image_id,
+ )
# for refine module
@@ -498,8 +526,7 @@ def extract_main_direction(pos_list, f_direction):
point_direction = f_direction[pos_list[:, 0], pos_list[:, 1]]
point_direction = point_direction[:, ::-1] # x, y -> y, x
average_direction = np.mean(point_direction, axis=0, keepdims=True)
- average_direction = average_direction / (
- np.linalg.norm(average_direction) + 1e-6)
+ average_direction = average_direction / (np.linalg.norm(average_direction) + 1e-6)
return average_direction
@@ -537,8 +564,7 @@ def sort_part_with_direction(pos_list_full, point_direction):
pos_list = np.array(pos_list).reshape(-1, 3)
point_direction = f_direction[pos_list[:, 1], pos_list[:, 2]] # x, y
point_direction = point_direction[:, ::-1] # x, y -> y, x
- sorted_point, sorted_direction = sort_part_with_direction(pos_list,
- point_direction)
+ sorted_point, sorted_direction = sort_part_with_direction(pos_list, point_direction)
point_num = len(sorted_point)
if point_num >= 16:
@@ -546,25 +572,29 @@ def sort_part_with_direction(pos_list_full, point_direction):
first_part_point = sorted_point[:middle_num]
first_point_direction = sorted_direction[:middle_num]
sorted_fist_part_point, sorted_fist_part_direction = sort_part_with_direction(
- first_part_point, first_point_direction)
+ first_part_point, first_point_direction
+ )
last_part_point = sorted_point[middle_num:]
last_point_direction = sorted_direction[middle_num:]
sorted_last_part_point, sorted_last_part_direction = sort_part_with_direction(
- last_part_point, last_point_direction)
+ last_part_point, last_point_direction
+ )
sorted_point = sorted_fist_part_point + sorted_last_part_point
sorted_direction = sorted_fist_part_direction + sorted_last_part_direction
return sorted_point
-def generate_pivot_list_tt_inference(p_score,
- p_char_maps,
- f_direction,
- score_thresh=0.5,
- is_backbone=False,
- is_curved=True,
- image_id=0):
+def generate_pivot_list_tt_inference(
+ p_score,
+ p_char_maps,
+ f_direction,
+ score_thresh=0.5,
+ is_backbone=False,
+ is_curved=True,
+ image_id=0,
+):
"""
return center point and end point of TCL instance; filter with the char maps;
"""
@@ -573,7 +603,8 @@ def generate_pivot_list_tt_inference(p_score,
p_tcl_map = (p_score > score_thresh) * 1.0
skeleton_map = thin(p_tcl_map)
instance_count, instance_label_map = cv2.connectedComponents(
- skeleton_map.astype(np.uint8), connectivity=8)
+ skeleton_map.astype(np.uint8), connectivity=8
+ )
# get TCL Instance
all_pos_yxs = []
@@ -586,7 +617,8 @@ def generate_pivot_list_tt_inference(p_score,
if len(pos_list) < 3:
continue
pos_list_sorted = sort_and_expand_with_direction_v2(
- pos_list, f_direction, p_tcl_map)
+ pos_list, f_direction, p_tcl_map
+ )
pos_list_sorted_with_id = add_id(pos_list_sorted, image_id=image_id)
all_pos_yxs.append(pos_list_sorted_with_id)
return all_pos_yxs
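`shrink_quad_along_width`, which appears in this file as well as in `extract_textpoint_fast.py` and `visual.py`, interpolates the top edge `p0 -> p1` and the bottom edge `p3 -> p2` at two width ratios; `expand_poly_along_width` then calls it with ratios outside `[0, 1]` to push the end quads outward. A quick standalone check with an invented rectangle:

```python
# Geometry check for shrink_quad_along_width: ratios inside [0, 1]
# shrink the quad along its width, ratios outside extend it.
import numpy as np

def shrink_quad_along_width(quad, begin_width_ratio=0.0, end_width_ratio=1.0):
    ratio_pair = np.array([[begin_width_ratio], [end_width_ratio]], dtype=np.float32)
    p0_1 = quad[0] + (quad[1] - quad[0]) * ratio_pair  # points on the top edge
    p3_2 = quad[3] + (quad[2] - quad[3]) * ratio_pair  # points on the bottom edge
    return np.array([p0_1[0], p0_1[1], p3_2[1], p3_2[0]])

quad = np.array([[0, 0], [10, 0], [10, 4], [0, 4]], dtype=np.float32)
# keep the middle 80% of the width
print(shrink_quad_along_width(quad, 0.1, 0.9))
# a negative begin ratio pushes the left edge outward instead
print(shrink_quad_along_width(quad, -0.2, 1.0))
```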
diff --git a/ppocr/utils/e2e_utils/pgnet_pp_utils.py b/ppocr/utils/e2e_utils/pgnet_pp_utils.py
index 06a766b0e7..71379e5e17 100644
--- a/ppocr/utils/e2e_utils/pgnet_pp_utils.py
+++ b/ppocr/utils/e2e_utils/pgnet_pp_utils.py
@@ -21,20 +21,22 @@
__dir__ = os.path.dirname(__file__)
sys.path.append(__dir__)
-sys.path.append(os.path.join(__dir__, '..'))
+sys.path.append(os.path.join(__dir__, ".."))
from extract_textpoint_slow import *
from extract_textpoint_fast import generate_pivot_list_fast, restore_poly
class PGNet_PostProcess(object):
# two different post-process
- def __init__(self,
- character_dict_path,
- valid_set,
- score_thresh,
- outs_dict,
- shape_list,
- point_gather_mode=None):
+ def __init__(
+ self,
+ character_dict_path,
+ valid_set,
+ score_thresh,
+ outs_dict,
+ shape_list,
+ point_gather_mode=None,
+ ):
self.Lexicon_Table = get_dict(character_dict_path)
self.valid_set = valid_set
self.score_thresh = score_thresh
@@ -43,10 +45,10 @@ def __init__(self,
self.point_gather_mode = point_gather_mode
def pg_postprocess_fast(self):
- p_score = self.outs_dict['f_score']
- p_border = self.outs_dict['f_border']
- p_char = self.outs_dict['f_char']
- p_direction = self.outs_dict['f_direction']
+ p_score = self.outs_dict["f_score"]
+ p_border = self.outs_dict["f_border"]
+ p_char = self.outs_dict["f_char"]
+ p_direction = self.outs_dict["f_direction"]
if isinstance(p_score, paddle.Tensor):
p_score = p_score[0].numpy()
p_border = p_border[0].numpy()
@@ -65,21 +67,29 @@ def pg_postprocess_fast(self):
p_direction,
self.Lexicon_Table,
score_thresh=self.score_thresh,
- point_gather_mode=self.point_gather_mode)
- poly_list, keep_str_list = restore_poly(instance_yxs_list, seq_strs,
- p_border, ratio_w, ratio_h,
- src_w, src_h, self.valid_set)
+ point_gather_mode=self.point_gather_mode,
+ )
+ poly_list, keep_str_list = restore_poly(
+ instance_yxs_list,
+ seq_strs,
+ p_border,
+ ratio_w,
+ ratio_h,
+ src_w,
+ src_h,
+ self.valid_set,
+ )
data = {
- 'points': poly_list,
- 'texts': keep_str_list,
+ "points": poly_list,
+ "texts": keep_str_list,
}
return data
def pg_postprocess_slow(self):
- p_score = self.outs_dict['f_score']
- p_border = self.outs_dict['f_border']
- p_char = self.outs_dict['f_char']
- p_direction = self.outs_dict['f_direction']
+ p_score = self.outs_dict["f_score"]
+ p_border = self.outs_dict["f_border"]
+ p_char = self.outs_dict["f_char"]
+ p_direction = self.outs_dict["f_direction"]
if isinstance(p_score, paddle.Tensor):
p_score = p_score[0].numpy()
p_border = p_border[0].numpy()
@@ -98,10 +108,11 @@ def pg_postprocess_slow(self):
p_direction,
score_thresh=self.score_thresh,
is_backbone=True,
- is_curved=is_curved)
+ is_curved=is_curved,
+ )
seq_strs = []
for char_idx_set in char_seq_idx_set:
- pr_str = ''.join([self.Lexicon_Table[pos] for pos in char_idx_set])
+ pr_str = "".join([self.Lexicon_Table[pos] for pos in char_idx_set])
seq_strs.append(pr_str)
poly_list = []
keep_str_list = []
@@ -112,58 +123,57 @@ def pg_postprocess_slow(self):
yx_center_line.append(yx_center_line[-1])
offset_expand = 1.0
- if self.valid_set == 'totaltext':
+ if self.valid_set == "totaltext":
offset_expand = 1.2
point_pair_list = []
for batch_id, y, x in yx_center_line:
offset = p_border[:, y, x].reshape(2, 2)
if offset_expand != 1.0:
- offset_length = np.linalg.norm(
- offset, axis=1, keepdims=True)
+ offset_length = np.linalg.norm(offset, axis=1, keepdims=True)
expand_length = np.clip(
- offset_length * (offset_expand - 1),
- a_min=0.5,
- a_max=3.0)
+ offset_length * (offset_expand - 1), a_min=0.5, a_max=3.0
+ )
offset_detal = offset / offset_length * expand_length
offset = offset + offset_detal
ori_yx = np.array([y, x], dtype=np.float32)
- point_pair = (ori_yx + offset)[:, ::-1] * 4.0 / np.array(
- [ratio_w, ratio_h]).reshape(-1, 2)
+ point_pair = (
+ (ori_yx + offset)[:, ::-1]
+ * 4.0
+ / np.array([ratio_w, ratio_h]).reshape(-1, 2)
+ )
point_pair_list.append(point_pair)
- all_point_list.append([
- int(round(x * 4.0 / ratio_w)),
- int(round(y * 4.0 / ratio_h))
- ])
- all_point_pair_list.append(point_pair.round().astype(np.int32)
- .tolist())
+ all_point_list.append(
+ [int(round(x * 4.0 / ratio_w)), int(round(y * 4.0 / ratio_h))]
+ )
+ all_point_pair_list.append(point_pair.round().astype(np.int32).tolist())
detected_poly, pair_length_info = point_pair2poly(point_pair_list)
detected_poly = expand_poly_along_width(
- detected_poly, shrink_ratio_of_width=0.2)
- detected_poly[:, 0] = np.clip(
- detected_poly[:, 0], a_min=0, a_max=src_w)
- detected_poly[:, 1] = np.clip(
- detected_poly[:, 1], a_min=0, a_max=src_h)
+ detected_poly, shrink_ratio_of_width=0.2
+ )
+ detected_poly[:, 0] = np.clip(detected_poly[:, 0], a_min=0, a_max=src_w)
+ detected_poly[:, 1] = np.clip(detected_poly[:, 1], a_min=0, a_max=src_h)
if len(keep_str) < 2:
continue
keep_str_list.append(keep_str)
- detected_poly = np.round(detected_poly).astype('int32')
- if self.valid_set == 'partvgg':
+ detected_poly = np.round(detected_poly).astype("int32")
+ if self.valid_set == "partvgg":
middle_point = len(detected_poly) // 2
detected_poly = detected_poly[
- [0, middle_point - 1, middle_point, -1], :]
+ [0, middle_point - 1, middle_point, -1], :
+ ]
poly_list.append(detected_poly)
- elif self.valid_set == 'totaltext':
+ elif self.valid_set == "totaltext":
poly_list.append(detected_poly)
else:
- print('--> Not supported format.')
+                print("--> Unsupported format.")
exit(-1)
data = {
- 'points': poly_list,
- 'texts': keep_str_list,
+ "points": poly_list,
+ "texts": keep_str_list,
}
return data
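Both post-processing paths above share the same coordinate restoration: a center-line point `(y, x)` on the 1/4-resolution feature map plus its predicted 2x2 border offset yields a pair of boundary points, which are swapped to `(x, y)`, scaled by 4 to undo the downsampling, then divided by the resize ratios. A toy-number sketch; the ratios and offsets are invented:

```python
# Sketch of the point-pair restoration done in pg_postprocess_fast.
import numpy as np

ratio_w, ratio_h = 0.5, 0.5   # input resize ratios (assumed values)
y, x = 12, 30                 # feature-map location of one center point
offset = np.array([[-2.0, 0.0],   # (dy, dx) to the upper border point
                   [ 2.0, 0.0]])  # (dy, dx) to the lower border point

ori_yx = np.array([y, x], dtype=np.float32)
# (y, x) -> (x, y), undo the 4x downsampling, undo the input resize
point_pair = (ori_yx + offset)[:, ::-1] * 4.0 / np.array(
    [ratio_w, ratio_h]).reshape(-1, 2)
print(point_pair)  # two (x, y) points in source-image coordinates
```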
diff --git a/ppocr/utils/e2e_utils/visual.py b/ppocr/utils/e2e_utils/visual.py
index e6e4fd0667..e321827996 100644
--- a/ppocr/utils/e2e_utils/visual.py
+++ b/ppocr/utils/e2e_utils/visual.py
@@ -47,8 +47,7 @@ def resize_image(im, max_side_len=512):
def resize_image_min(im, max_side_len=512):
- """
- """
+ """ """
h, w, _ = im.shape
resize_w = w
@@ -72,8 +71,7 @@ def resize_image_min(im, max_side_len=512):
def resize_image_for_totaltext(im, max_side_len=512):
- """
- """
+ """ """
h, w, _ = im.shape
resize_w = w
@@ -103,8 +101,11 @@ def point_pair2poly(point_pair_list):
pair_length = np.linalg.norm(point_pair[0] - point_pair[1])
pair_length_list.append(pair_length)
pair_length_list = np.array(pair_length_list)
- pair_info = (pair_length_list.max(), pair_length_list.min(),
- pair_length_list.mean())
+ pair_info = (
+ pair_length_list.max(),
+ pair_length_list.min(),
+ pair_length_list.mean(),
+ )
point_num = len(point_pair_list) * 2
point_list = [0] * point_num
@@ -114,12 +115,11 @@ def point_pair2poly(point_pair_list):
return np.array(point_list).reshape(-1, 2), pair_info
-def shrink_quad_along_width(quad, begin_width_ratio=0., end_width_ratio=1.):
+def shrink_quad_along_width(quad, begin_width_ratio=0.0, end_width_ratio=1.0):
"""
Generate shrink_quad_along_width.
"""
- ratio_pair = np.array(
- [[begin_width_ratio], [end_width_ratio]], dtype=np.float32)
+ ratio_pair = np.array([[begin_width_ratio], [end_width_ratio]], dtype=np.float32)
p0_1 = quad[0] + (quad[1] - quad[0]) * ratio_pair
p3_2 = quad[3] + (quad[2] - quad[3]) * ratio_pair
return np.array([p0_1[0], p0_1[1], p3_2[1], p3_2[0]])
@@ -130,20 +130,25 @@ def expand_poly_along_width(poly, shrink_ratio_of_width=0.3):
expand poly along width.
"""
point_num = poly.shape[0]
- left_quad = np.array(
- [poly[0], poly[1], poly[-2], poly[-1]], dtype=np.float32)
- left_ratio = -shrink_ratio_of_width * np.linalg.norm(left_quad[0] - left_quad[3]) / \
- (np.linalg.norm(left_quad[0] - left_quad[1]) + 1e-6)
+ left_quad = np.array([poly[0], poly[1], poly[-2], poly[-1]], dtype=np.float32)
+ left_ratio = (
+ -shrink_ratio_of_width
+ * np.linalg.norm(left_quad[0] - left_quad[3])
+ / (np.linalg.norm(left_quad[0] - left_quad[1]) + 1e-6)
+ )
left_quad_expand = shrink_quad_along_width(left_quad, left_ratio, 1.0)
right_quad = np.array(
[
- poly[point_num // 2 - 2], poly[point_num // 2 - 1],
- poly[point_num // 2], poly[point_num // 2 + 1]
+ poly[point_num // 2 - 2],
+ poly[point_num // 2 - 1],
+ poly[point_num // 2],
+ poly[point_num // 2 + 1],
],
- dtype=np.float32)
- right_ratio = 1.0 + \
- shrink_ratio_of_width * np.linalg.norm(right_quad[0] - right_quad[3]) / \
- (np.linalg.norm(right_quad[0] - right_quad[1]) + 1e-6)
+ dtype=np.float32,
+ )
+ right_ratio = 1.0 + shrink_ratio_of_width * np.linalg.norm(
+ right_quad[0] - right_quad[3]
+ ) / (np.linalg.norm(right_quad[0] - right_quad[1]) + 1e-6)
right_quad_expand = shrink_quad_along_width(right_quad, 0.0, right_ratio)
poly[0] = left_quad_expand[0]
poly[-1] = left_quad_expand[-1]
diff --git a/ppocr/utils/gen_label.py b/ppocr/utils/gen_label.py
index 56d75544db..1634b386cc 100644
--- a/ppocr/utils/gen_label.py
+++ b/ppocr/utils/gen_label.py
@@ -17,61 +17,62 @@
def gen_rec_label(input_path, out_label):
- with open(out_label, 'w') as out_file:
- with open(input_path, 'r') as f:
+ with open(out_label, "w") as out_file:
+ with open(input_path, "r") as f:
for line in f.readlines():
- tmp = line.strip('\n').replace(" ", "").split(',')
+ tmp = line.strip("\n").replace(" ", "").split(",")
img_path, label = tmp[0], tmp[1]
- label = label.replace("\"", "")
- out_file.write(img_path + '\t' + label + '\n')
+ label = label.replace('"', "")
+ out_file.write(img_path + "\t" + label + "\n")
def gen_det_label(root_path, input_dir, out_label):
- with open(out_label, 'w') as out_file:
+ with open(out_label, "w") as out_file:
for label_file in os.listdir(input_dir):
img_path = os.path.join(root_path, label_file[3:-4] + ".jpg")
label = []
with open(
- os.path.join(input_dir, label_file), 'r',
- encoding='utf-8-sig') as f:
+ os.path.join(input_dir, label_file), "r", encoding="utf-8-sig"
+ ) as f:
for line in f.readlines():
- tmp = line.strip("\n\r").replace("\xef\xbb\xbf",
- "").split(',')
+ tmp = line.strip("\n\r").replace("\xef\xbb\xbf", "").split(",")
points = tmp[:8]
s = []
for i in range(0, len(points), 2):
- b = points[i:i + 2]
+ b = points[i : i + 2]
b = [int(t) for t in b]
s.append(b)
result = {"transcription": tmp[8], "points": s}
label.append(result)
- out_file.write(img_path + '\t' + json.dumps(
- label, ensure_ascii=False) + '\n')
+ out_file.write(
+ img_path + "\t" + json.dumps(label, ensure_ascii=False) + "\n"
+ )
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument(
- '--mode',
+ "--mode",
type=str,
default="rec",
- help='Generate rec_label or det_label, can be set rec or det')
+        help="Generate rec_label or det_label; can be set to rec or det",
+ )
parser.add_argument(
- '--root_path',
+ "--root_path",
type=str,
default=".",
- help='The root directory of images.Only takes effect when mode=det ')
+        help="The root directory of images. Only takes effect when mode=det",
+ )
parser.add_argument(
- '--input_path',
+ "--input_path",
type=str,
default=".",
- help='Input_label or input path to be converted')
+ help="Input_label or input path to be converted",
+ )
parser.add_argument(
- '--output_label',
- type=str,
- default="out_label.txt",
- help='Output file name')
+ "--output_label", type=str, default="out_label.txt", help="Output file name"
+ )
args = parser.parse_args()
if args.mode == "rec":
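`gen_rec_label` rewrites each `path, "label"` input line into the tab-separated format the recognition loader expects. A tiny illustration of that transform with invented sample lines:

```python
# The per-line transform from gen_rec_label above: strip the newline,
# drop spaces, split on the first comma, and remove the double quotes.
lines = ['word_001.png, "Genaxis"\n', 'word_002.png, "Theatre"\n']

for line in lines:
    tmp = line.strip("\n").replace(" ", "").split(",")
    img_path, label = tmp[0], tmp[1]
    label = label.replace('"', "")
    print(img_path + "\t" + label)
# word_001.png	Genaxis
# word_002.png	Theatre
```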
diff --git a/ppocr/utils/iou.py b/ppocr/utils/iou.py
index 35459f5f05..cb77c3437e 100644
--- a/ppocr/utils/iou.py
+++ b/ppocr/utils/iou.py
@@ -31,8 +31,8 @@ def iou_single(a, b, mask, n_class):
inter = paddle.to_tensor(0.0)
union = paddle.to_tensor(0.0)
else:
- inter = ((a == i).logical_and(b == i)).astype('float32')
- union = ((a == i).logical_or(b == i)).astype('float32')
+ inter = ((a == i).logical_and(b == i)).astype("float32")
+ union = ((a == i).logical_or(b == i)).astype("float32")
miou.append(paddle.sum(inter) / (paddle.sum(union) + EPS))
miou = sum(miou) / len(miou)
return miou
@@ -45,7 +45,7 @@ def iou(a, b, mask, n_class=2, reduce=True):
b = b.reshape([batch_size, -1])
mask = mask.reshape([batch_size, -1])
- iou = paddle.zeros((batch_size, ), dtype='float32')
+ iou = paddle.zeros((batch_size,), dtype="float32")
for i in range(batch_size):
iou[i] = iou_single(a[i], b[i], mask[i], n_class)
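`iou_single` computes, per class, the intersection-over-union of the two label maps and averages across classes. A NumPy restatement for the binary case; the `mask`/ignore handling from the diff is dropped here for brevity, and `EPS` mirrors the diff's stabilizer:

```python
# Mean per-class IoU of two integer label maps (n_class=2 sketch).
import numpy as np

EPS = 1e-6

def iou_single_np(a, b, n_class=2):
    miou = []
    for i in range(n_class):
        inter = np.logical_and(a == i, b == i).sum()
        union = np.logical_or(a == i, b == i).sum()
        miou.append(inter / (union + EPS))
    return sum(miou) / len(miou)

a = np.array([[0, 1], [1, 1]])
b = np.array([[0, 1], [0, 1]])
print(iou_single_np(a, b))  # mean of 1/2 (class 0) and 2/3 (class 1)
```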
diff --git a/ppocr/utils/loggers/base_logger.py b/ppocr/utils/loggers/base_logger.py
index 3a7fc3593b..0aa132de2c 100644
--- a/ppocr/utils/loggers/base_logger.py
+++ b/ppocr/utils/loggers/base_logger.py
@@ -1,6 +1,7 @@
import os
from abc import ABC, abstractmethod
+
class BaseLogger(ABC):
def __init__(self, save_dir):
self.save_dir = save_dir
@@ -12,4 +13,4 @@ def log_metrics(self, metrics, prefix=None):
@abstractmethod
def close(self):
- pass
\ No newline at end of file
+ pass
diff --git a/ppocr/utils/loggers/loggers.py b/ppocr/utils/loggers/loggers.py
index 2601466208..a14dbcb95e 100644
--- a/ppocr/utils/loggers/loggers.py
+++ b/ppocr/utils/loggers/loggers.py
@@ -1,5 +1,6 @@
from .wandb_logger import WandbLogger
+
class Loggers(object):
def __init__(self, loggers):
super().__init__()
@@ -8,11 +9,11 @@ def __init__(self, loggers):
def log_metrics(self, metrics, prefix=None, step=None):
for logger in self.loggers:
logger.log_metrics(metrics, prefix=prefix, step=step)
-
+
def log_model(self, is_best, prefix, metadata=None):
for logger in self.loggers:
logger.log_model(is_best=is_best, prefix=prefix, metadata=metadata)
-
+
def close(self):
for logger in self.loggers:
- logger.close()
\ No newline at end of file
+ logger.close()
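The `Loggers` class above is a plain composite: every call fans out to each wrapped backend, so the VisualDL and W&B loggers can be driven through one object. A hedged usage sketch; `PrintLogger` is an invented stub, the `Loggers` body is a minimal mirror of the diff, and `log_model` is omitted to keep it short:

```python
# Composite logger pattern: one front object, many backends.
class PrintLogger:
    def log_metrics(self, metrics, prefix=None, step=None):
        print(step, prefix, metrics)

    def close(self):
        print("closed")


class Loggers:
    # minimal mirror of the composite in the diff
    def __init__(self, loggers):
        self.loggers = loggers

    def log_metrics(self, metrics, prefix=None, step=None):
        for logger in self.loggers:
            logger.log_metrics(metrics, prefix=prefix, step=step)

    def close(self):
        for logger in self.loggers:
            logger.close()


loggers = Loggers([PrintLogger()])
loggers.log_metrics({"loss": 0.12, "acc": 0.97}, prefix="TRAIN", step=100)
loggers.close()
```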
diff --git a/ppocr/utils/loggers/vdl_logger.py b/ppocr/utils/loggers/vdl_logger.py
index c345f93235..d2851e67a2 100644
--- a/ppocr/utils/loggers/vdl_logger.py
+++ b/ppocr/utils/loggers/vdl_logger.py
@@ -1,6 +1,7 @@
from .base_logger import BaseLogger
from visualdl import LogWriter
+
class VDLLogger(BaseLogger):
def __init__(self, save_dir):
super().__init__(save_dir)
@@ -13,9 +14,9 @@ def log_metrics(self, metrics, prefix=None, step=None):
for k, v in updated_metrics.items():
self.vdl_writer.add_scalar(k, v, step)
-
+
def log_model(self, is_best, prefix, metadata=None):
pass
-
+
def close(self):
- self.vdl_writer.close()
\ No newline at end of file
+ self.vdl_writer.close()
diff --git a/ppocr/utils/loggers/wandb_logger.py b/ppocr/utils/loggers/wandb_logger.py
index b9c6711696..83596d86d1 100644
--- a/ppocr/utils/loggers/wandb_logger.py
+++ b/ppocr/utils/loggers/wandb_logger.py
@@ -1,22 +1,24 @@
import os
from .base_logger import BaseLogger
+
class WandbLogger(BaseLogger):
- def __init__(self,
- project=None,
- name=None,
- id=None,
- entity=None,
- save_dir=None,
+ def __init__(
+ self,
+ project=None,
+ name=None,
+ id=None,
+ entity=None,
+ save_dir=None,
config=None,
- **kwargs):
+ **kwargs
+ ):
try:
import wandb
+
self.wandb = wandb
except ModuleNotFoundError:
- raise ModuleNotFoundError(
- "Please install wandb using `pip install wandb`"
- )
+ raise ModuleNotFoundError("Please install wandb using `pip install wandb`")
self.project = project
self.name = name
@@ -32,7 +34,7 @@ def __init__(self,
id=self.id,
entity=self.entity,
dir=self.save_dir,
- resume="allow"
+ resume="allow",
)
self._wandb_init.update(**kwargs)
@@ -60,12 +62,14 @@ def log_metrics(self, metrics, prefix=None, step=None):
if not prefix:
prefix = ""
updated_metrics = {prefix.lower() + "/" + k: v for k, v in metrics.items()}
-
+
self.run.log(updated_metrics, step=step)
def log_model(self, is_best, prefix, metadata=None):
- model_path = os.path.join(self.save_dir, prefix + '.pdparams')
- artifact = self.wandb.Artifact('model-{}'.format(self.run.id), type='model', metadata=metadata)
+ model_path = os.path.join(self.save_dir, prefix + ".pdparams")
+ artifact = self.wandb.Artifact(
+ "model-{}".format(self.run.id), type="model", metadata=metadata
+ )
artifact.add_file(model_path, name="model_ckpt.pdparams")
aliases = [prefix]
@@ -75,4 +79,4 @@ def log_model(self, is_best, prefix, metadata=None):
self.run.log_artifact(artifact, aliases=aliases)
def close(self):
- self.run.finish()
\ No newline at end of file
+ self.run.finish()
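WandbLogger imports wandb lazily inside __init__ so the dependency stays optional. A minimal sketch of the logging flow, assuming wandb is installed and logged in; project and run names are placeholders:

    from ppocr.utils.loggers.wandb_logger import WandbLogger

    logger = WandbLogger(project="ppocr-demo", name="run-1", save_dir="./output")
    logger.log_metrics({"acc": 0.91}, prefix="EVAL", step=100)
    # log_model uploads <save_dir>/<prefix>.pdparams as a run artifact
    logger.log_model(is_best=True, prefix="best_accuracy")
    logger.close()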
diff --git a/ppocr/utils/logging.py b/ppocr/utils/logging.py
index 1eac8f351a..945bb3ee75 100644
--- a/ppocr/utils/logging.py
+++ b/ppocr/utils/logging.py
@@ -26,7 +26,7 @@
@functools.lru_cache()
-def get_logger(name='ppocr', log_file=None, log_level=logging.DEBUG):
+def get_logger(name="ppocr", log_file=None, log_level=logging.DEBUG):
"""Initialize and get a logger by name.
If the logger has not been initialized, this method will initialize the
logger by adding one or two handlers, otherwise the initialized logger will
@@ -50,8 +50,8 @@ def get_logger(name='ppocr', log_file=None, log_level=logging.DEBUG):
return logger
formatter = logging.Formatter(
- '[%(asctime)s] %(name)s %(levelname)s: %(message)s',
- datefmt="%Y/%m/%d %H:%M:%S")
+ "[%(asctime)s] %(name)s %(levelname)s: %(message)s", datefmt="%Y/%m/%d %H:%M:%S"
+ )
stream_handler = logging.StreamHandler(stream=sys.stdout)
stream_handler.setFormatter(formatter)
@@ -59,7 +59,7 @@ def get_logger(name='ppocr', log_file=None, log_level=logging.DEBUG):
if log_file is not None and dist.get_rank() == 0:
log_file_folder = os.path.split(log_file)[0]
os.makedirs(log_file_folder, exist_ok=True)
- file_handler = logging.FileHandler(log_file, 'a')
+ file_handler = logging.FileHandler(log_file, "a")
file_handler.setFormatter(formatter)
logger.addHandler(file_handler)
if dist.get_rank() == 0:
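Because get_logger is wrapped in functools.lru_cache, repeated calls with the same arguments return the already-configured logger instead of stacking duplicate handlers. A small sketch (the log path is illustrative):

    from ppocr.utils.logging import get_logger

    logger = get_logger(log_file="./output/train.log")  # rank 0 also writes to the file
    logger.info("training started")
    # Same arguments -> same cached logger object, no duplicate handlers
    assert get_logger(log_file="./output/train.log") is logger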
diff --git a/ppocr/utils/network.py b/ppocr/utils/network.py
index f2cd690e12..fd2d8a64ba 100644
--- a/ppocr/utils/network.py
+++ b/ppocr/utils/network.py
@@ -27,11 +27,10 @@ def download_with_progressbar(url, save_path):
logger = get_logger()
response = requests.get(url, stream=True)
if response.status_code == 200:
- total_size_in_bytes = int(response.headers.get('content-length', 1))
+ total_size_in_bytes = int(response.headers.get("content-length", 1))
block_size = 1024 # 1 Kibibyte
- progress_bar = tqdm(
- total=total_size_in_bytes, unit='iB', unit_scale=True)
- with open(save_path, 'wb') as file:
+ progress_bar = tqdm(total=total_size_in_bytes, unit="iB", unit_scale=True)
+ with open(save_path, "wb") as file:
for data in response.iter_content(block_size):
progress_bar.update(len(data))
file.write(data)
@@ -43,28 +42,25 @@ def download_with_progressbar(url, save_path):
def maybe_download(model_storage_directory, url):
# using custom model
- tar_file_name_list = ['.pdiparams', '.pdiparams.info', '.pdmodel']
+ tar_file_name_list = [".pdiparams", ".pdiparams.info", ".pdmodel"]
if not os.path.exists(
- os.path.join(model_storage_directory, 'inference.pdiparams')
- ) or not os.path.exists(
- os.path.join(model_storage_directory, 'inference.pdmodel')):
- assert url.endswith('.tar'), 'Only supports tar compressed package'
- tmp_path = os.path.join(model_storage_directory, url.split('/')[-1])
- print('download {} to {}'.format(url, tmp_path))
+ os.path.join(model_storage_directory, "inference.pdiparams")
+ ) or not os.path.exists(os.path.join(model_storage_directory, "inference.pdmodel")):
+ assert url.endswith(".tar"), "Only tar compressed packages are supported"

+ tmp_path = os.path.join(model_storage_directory, url.split("/")[-1])
+ print("download {} to {}".format(url, tmp_path))
os.makedirs(model_storage_directory, exist_ok=True)
download_with_progressbar(url, tmp_path)
- with tarfile.open(tmp_path, 'r') as tarObj:
+ with tarfile.open(tmp_path, "r") as tarObj:
for member in tarObj.getmembers():
filename = None
for tar_file_name in tar_file_name_list:
if member.name.endswith(tar_file_name):
- filename = 'inference' + tar_file_name
+ filename = "inference" + tar_file_name
if filename is None:
continue
file = tarObj.extractfile(member)
- with open(
- os.path.join(model_storage_directory, filename),
- 'wb') as f:
+ with open(os.path.join(model_storage_directory, filename), "wb") as f:
f.write(file.read())
os.remove(tmp_path)
@@ -74,15 +70,15 @@ def maybe_download_params(model_path):
return model_path
else:
url = model_path
- tmp_path = os.path.join(MODELS_DIR, url.split('/')[-1])
- print('download {} to {}'.format(url, tmp_path))
+ tmp_path = os.path.join(MODELS_DIR, url.split("/")[-1])
+ print("download {} to {}".format(url, tmp_path))
os.makedirs(MODELS_DIR, exist_ok=True)
download_with_progressbar(url, tmp_path)
return tmp_path
def is_link(s):
- return s is not None and s.startswith('http')
+ return s is not None and s.startswith("http")
def confirm_model_dir_url(model_dir, default_model_dir, default_url):
@@ -90,7 +86,7 @@ def confirm_model_dir_url(model_dir, default_model_dir, default_url):
if model_dir is None or is_link(model_dir):
if is_link(model_dir):
url = model_dir
- file_name = url.split('/')[-1][:-4]
+ file_name = url.split("/")[-1][:-4]
model_dir = default_model_dir
model_dir = os.path.join(model_dir, file_name)
return model_dir, url
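confirm_model_dir_url lets callers pass either a local directory or a download link as model_dir; a link is resolved to a per-model subdirectory of the default directory, keyed by the archive's file name. A sketch under an illustrative URL:

    from ppocr.utils.network import confirm_model_dir_url

    url = "https://example.com/models/ch_det_infer.tar"  # placeholder link
    model_dir, resolved = confirm_model_dir_url(url, "~/.paddleocr/models", url)
    # model_dir -> ~/.paddleocr/models/ch_det_infer (last path segment minus ".tar")
    # resolved  -> the original link, later handed to maybe_download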
diff --git a/ppocr/utils/poly_nms.py b/ppocr/utils/poly_nms.py
index 9dcb3d2c2f..c3a1338fa8 100644
--- a/ppocr/utils/poly_nms.py
+++ b/ppocr/utils/poly_nms.py
@@ -136,7 +136,7 @@ def poly_nms(polygons, threshold):
keep_poly.append(polygons[index[-1]].tolist())
A = polygons[index[-1]][:-1]
index = np.delete(index, -1)
- iou_list = np.zeros((len(index), ))
+ iou_list = np.zeros((len(index),))
for i in range(len(index)):
B = polygons[index[i]][:-1]
iou_list[i] = boundary_iou(A, B)
diff --git a/ppocr/utils/profiler.py b/ppocr/utils/profiler.py
index 629ef4ef05..e4e3e05649 100644
--- a/ppocr/utils/profiler.py
+++ b/ppocr/utils/profiler.py
@@ -23,8 +23,9 @@
_profiler_options = None
_prof = None
+
class ProfilerOptions(object):
- '''
+ """
Use a string to initialize a ProfilerOptions.
    The string should be in the format: "key1=value1;key2=value2;key3=value3".
For example:
@@ -34,7 +35,7 @@ class ProfilerOptions(object):
    ProfilerOptions supports the following key-value pairs:
    batch_range - an integer list, e.g. [100, 110].
- state - a string, the optional values are 'CPU', 'GPU' or 'All'.
+ state - a string, the optional values are 'CPU', 'GPU' or 'All'.
sorted_key - a string, the optional values are 'calls', 'total',
    'max', 'min' or 'ave'.
tracer_option - a string, the optional values are 'Default', 'OpDetail',
@@ -42,60 +43,60 @@ class ProfilerOptions(object):
profile_path - a string, the path to save the serialized profile data,
which can be used to generate a timeline.
exit_on_finished - a boolean.
- '''
+ """
def __init__(self, options_str):
assert isinstance(options_str, str)
self._options = {
- 'batch_range': [10, 20],
- 'state': 'All',
- 'sorted_key': 'total',
- 'tracer_option': 'Default',
- 'profile_path': '/tmp/profile',
- 'exit_on_finished': True,
- 'timer_only': True
+ "batch_range": [10, 20],
+ "state": "All",
+ "sorted_key": "total",
+ "tracer_option": "Default",
+ "profile_path": "/tmp/profile",
+ "exit_on_finished": True,
+ "timer_only": True,
}
self._parse_from_string(options_str)
def _parse_from_string(self, options_str):
- for kv in options_str.replace(' ', '').split(';'):
- key, value = kv.split('=')
- if key == 'batch_range':
- value_list = value.replace('[', '').replace(']', '').split(',')
+ for kv in options_str.replace(" ", "").split(";"):
+ key, value = kv.split("=")
+ if key == "batch_range":
+ value_list = value.replace("[", "").replace("]", "").split(",")
value_list = list(map(int, value_list))
- if len(value_list) >= 2 and value_list[0] >= 0 and value_list[
- 1] > value_list[0]:
+ if (
+ len(value_list) >= 2
+ and value_list[0] >= 0
+ and value_list[1] > value_list[0]
+ ):
self._options[key] = value_list
- elif key == 'exit_on_finished':
+ elif key == "exit_on_finished":
self._options[key] = value.lower() in ("yes", "true", "t", "1")
- elif key in [
- 'state', 'sorted_key', 'tracer_option', 'profile_path'
- ]:
+ elif key in ["state", "sorted_key", "tracer_option", "profile_path"]:
self._options[key] = value
- elif key == 'timer_only':
+ elif key == "timer_only":
self._options[key] = value
def __getitem__(self, name):
if self._options.get(name, None) is None:
- raise ValueError(
- "ProfilerOptions does not have an option named %s." % name)
+ raise ValueError("ProfilerOptions does not have an option named %s." % name)
return self._options[name]
def add_profiler_step(options_str=None):
- '''
+ """
Enable the operator-level timing using PaddlePaddle's profiler.
    The profiler uses an independent variable to count the profiler steps.
One call of this function is treated as a profiler step.
Args:
profiler_options - a string to initialize the ProfilerOptions.
Default is None, and the profiler is disabled.
- '''
+ """
if options_str is None:
return
- global _prof
+ global _prof
global _profiler_step_id
global _profiler_options
@@ -106,23 +107,24 @@ def add_profiler_step(options_str=None):
    # timer_only = False: calling summary() prints statistical tables that present performance data from different perspectives.
    # timer_only = False: the output timeline information can be found in the profiler_log directory.
if _prof is None:
- _timer_only = str(_profiler_options['timer_only']) == str(True)
+ _timer_only = str(_profiler_options["timer_only"]) == str(True)
_prof = profiler.Profiler(
- scheduler = (_profiler_options['batch_range'][0], _profiler_options['batch_range'][1]),
- on_trace_ready = profiler.export_chrome_tracing('./profiler_log'),
- timer_only = _timer_only)
+ scheduler=(
+ _profiler_options["batch_range"][0],
+ _profiler_options["batch_range"][1],
+ ),
+ on_trace_ready=profiler.export_chrome_tracing("./profiler_log"),
+ timer_only=_timer_only,
+ )
_prof.start()
else:
_prof.step()
-
- if _profiler_step_id == _profiler_options['batch_range'][1]:
+
+ if _profiler_step_id == _profiler_options["batch_range"][1]:
_prof.stop()
- _prof.summary(
- op_detail=True,
- thread_sep=False,
- time_unit='ms')
+ _prof.summary(op_detail=True, thread_sep=False, time_unit="ms")
_prof = None
- if _profiler_options['exit_on_finished']:
+ if _profiler_options["exit_on_finished"]:
sys.exit(0)
_profiler_step_id += 1
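The options string documented in ProfilerOptions is a semicolon-separated list of key=value pairs; batch_range is parsed into an integer list and exit_on_finished into a boolean. A quick sketch of the parsing (values are illustrative):

    from ppocr.utils.profiler import ProfilerOptions

    opts = ProfilerOptions("batch_range=[50,60];state=GPU;exit_on_finished=true")
    assert opts["batch_range"] == [50, 60]   # "[50,60]" -> list of ints
    assert opts["state"] == "GPU"            # plain string option
    assert opts["exit_on_finished"] is True  # "true"/"t"/"yes"/"1" -> True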
diff --git a/ppocr/utils/save_load.py b/ppocr/utils/save_load.py
index e6a81c48df..c397d5e0c6 100644
--- a/ppocr/utils/save_load.py
+++ b/ppocr/utils/save_load.py
@@ -26,7 +26,7 @@
from ppocr.utils.logging import get_logger
from ppocr.utils.network import maybe_download_params
-__all__ = ['load_model']
+__all__ = ["load_model"]
def _mkdir_if_not_exist(path, logger):
@@ -39,69 +39,74 @@ def _mkdir_if_not_exist(path, logger):
except OSError as e:
if e.errno == errno.EEXIST and os.path.isdir(path):
logger.warning(
- 'be happy if some process has already created {}'.format(
- path))
+ "be happy if some process has already created {}".format(path)
+ )
else:
- raise OSError('Failed to mkdir {}'.format(path))
+ raise OSError("Failed to mkdir {}".format(path))
-def load_model(config, model, optimizer=None, model_type='det'):
+def load_model(config, model, optimizer=None, model_type="det"):
"""
load model from checkpoint or pretrained_model
"""
logger = get_logger()
- global_config = config['Global']
- checkpoints = global_config.get('checkpoints')
- pretrained_model = global_config.get('pretrained_model')
+ global_config = config["Global"]
+ checkpoints = global_config.get("checkpoints")
+ pretrained_model = global_config.get("pretrained_model")
best_model_dict = {}
is_float16 = False
- is_nlp_model = model_type == 'kie' and config["Architecture"][
- "algorithm"] not in ["SDMGR"]
+ is_nlp_model = model_type == "kie" and config["Architecture"]["algorithm"] not in [
+ "SDMGR"
+ ]
if is_nlp_model is True:
        # NOTE: for kie model distillation, resume training is not supported now
if config["Architecture"]["algorithm"] in ["Distillation"]:
return best_model_dict
- checkpoints = config['Architecture']['Backbone']['checkpoints']
+ checkpoints = config["Architecture"]["Backbone"]["checkpoints"]
# load kie method metric
if checkpoints:
- if os.path.exists(os.path.join(checkpoints, 'metric.states')):
- with open(os.path.join(checkpoints, 'metric.states'),
- 'rb') as f:
- states_dict = pickle.load(f) if six.PY2 else pickle.load(
- f, encoding='latin1')
- best_model_dict = states_dict.get('best_model_dict', {})
- if 'epoch' in states_dict:
- best_model_dict['start_epoch'] = states_dict['epoch'] + 1
+ if os.path.exists(os.path.join(checkpoints, "metric.states")):
+ with open(os.path.join(checkpoints, "metric.states"), "rb") as f:
+ states_dict = (
+ pickle.load(f) if six.PY2 else pickle.load(f, encoding="latin1")
+ )
+ best_model_dict = states_dict.get("best_model_dict", {})
+ if "epoch" in states_dict:
+ best_model_dict["start_epoch"] = states_dict["epoch"] + 1
logger.info("resume from {}".format(checkpoints))
if optimizer is not None:
- if checkpoints[-1] in ['/', '\\']:
+ if checkpoints[-1] in ["/", "\\"]:
checkpoints = checkpoints[:-1]
- if os.path.exists(checkpoints + '.pdopt'):
- optim_dict = paddle.load(checkpoints + '.pdopt')
+ if os.path.exists(checkpoints + ".pdopt"):
+ optim_dict = paddle.load(checkpoints + ".pdopt")
optimizer.set_state_dict(optim_dict)
else:
logger.warning(
- "{}.pdopt is not exists, params of optimizer is not loaded".
- format(checkpoints))
+ "{}.pdopt is not exists, params of optimizer is not loaded".format(
+ checkpoints
+ )
+ )
return best_model_dict
if checkpoints:
- if checkpoints.endswith('.pdparams'):
- checkpoints = checkpoints.replace('.pdparams', '')
- assert os.path.exists(checkpoints + ".pdparams"), \
- "The {}.pdparams does not exists!".format(checkpoints)
+ if checkpoints.endswith(".pdparams"):
+ checkpoints = checkpoints.replace(".pdparams", "")
+ assert os.path.exists(
+ checkpoints + ".pdparams"
+ ), "The {}.pdparams does not exists!".format(checkpoints)
# load params from trained model
- params = paddle.load(checkpoints + '.pdparams')
+ params = paddle.load(checkpoints + ".pdparams")
state_dict = model.state_dict()
new_state_dict = {}
for key, value in state_dict.items():
if key not in params:
- logger.warning("{} not in loaded params {} !".format(
- key, params.keys()))
+ logger.warning(
+ "{} not in loaded params {} !".format(key, params.keys())
+ )
continue
pre_value = params[key]
if pre_value.dtype == paddle.float16:
@@ -112,47 +117,53 @@ def load_model(config, model, optimizer=None, model_type='det'):
new_state_dict[key] = pre_value
else:
logger.warning(
- "The shape of model params {} {} not matched with loaded params shape {} !".
- format(key, value.shape, pre_value.shape))
+ "The shape of model params {} {} not matched with loaded params shape {} !".format(
+ key, value.shape, pre_value.shape
+ )
+ )
model.set_state_dict(new_state_dict)
if is_float16:
logger.info(
"The parameter type is float16, which is converted to float32 when loading"
)
if optimizer is not None:
- if os.path.exists(checkpoints + '.pdopt'):
- optim_dict = paddle.load(checkpoints + '.pdopt')
+ if os.path.exists(checkpoints + ".pdopt"):
+ optim_dict = paddle.load(checkpoints + ".pdopt")
optimizer.set_state_dict(optim_dict)
else:
logger.warning(
- "{}.pdopt is not exists, params of optimizer is not loaded".
- format(checkpoints))
-
- if os.path.exists(checkpoints + '.states'):
- with open(checkpoints + '.states', 'rb') as f:
- states_dict = pickle.load(f) if six.PY2 else pickle.load(
- f, encoding='latin1')
- best_model_dict = states_dict.get('best_model_dict', {})
- if 'epoch' in states_dict:
- best_model_dict['start_epoch'] = states_dict['epoch'] + 1
+ "{}.pdopt is not exists, params of optimizer is not loaded".format(
+ checkpoints
+ )
+ )
+
+ if os.path.exists(checkpoints + ".states"):
+ with open(checkpoints + ".states", "rb") as f:
+ states_dict = (
+ pickle.load(f) if six.PY2 else pickle.load(f, encoding="latin1")
+ )
+ best_model_dict = states_dict.get("best_model_dict", {})
+ if "epoch" in states_dict:
+ best_model_dict["start_epoch"] = states_dict["epoch"] + 1
logger.info("resume from {}".format(checkpoints))
elif pretrained_model:
is_float16 = load_pretrained_params(model, pretrained_model)
else:
- logger.info('train from scratch')
- best_model_dict['is_float16'] = is_float16
+ logger.info("train from scratch")
+ best_model_dict["is_float16"] = is_float16
return best_model_dict
def load_pretrained_params(model, path):
logger = get_logger()
path = maybe_download_params(path)
- if path.endswith('.pdparams'):
- path = path.replace('.pdparams', '')
- assert os.path.exists(path + ".pdparams"), \
- "The {}.pdparams does not exists!".format(path)
+ if path.endswith(".pdparams"):
+ path = path.replace(".pdparams", "")
+ assert os.path.exists(
+ path + ".pdparams"
+ ), "The {}.pdparams does not exists!".format(path)
- params = paddle.load(path + '.pdparams')
+ params = paddle.load(path + ".pdparams")
state_dict = model.state_dict()
@@ -160,7 +171,6 @@ def load_pretrained_params(model, path):
is_float16 = False
for k1 in params.keys():
-
if k1 not in state_dict.keys():
logger.warning("The pretrained params {} not in model".format(k1))
else:
@@ -172,8 +182,10 @@ def load_pretrained_params(model, path):
new_state_dict[k1] = params[k1]
else:
logger.warning(
- "The shape of model params {} {} not matched with loaded params {} {} !".
- format(k1, state_dict[k1].shape, k1, params[k1].shape))
+ "The shape of model params {} {} not matched with loaded params {} {} !".format(
+ k1, state_dict[k1].shape, k1, params[k1].shape
+ )
+ )
model.set_state_dict(new_state_dict)
if is_float16:
@@ -184,56 +196,61 @@ def load_pretrained_params(model, path):
return is_float16
-def save_model(model,
- optimizer,
- model_path,
- logger,
- config,
- is_best=False,
- prefix='ppocr',
- **kwargs):
+def save_model(
+ model,
+ optimizer,
+ model_path,
+ logger,
+ config,
+ is_best=False,
+ prefix="ppocr",
+ **kwargs
+):
"""
save model to the target path
"""
_mkdir_if_not_exist(model_path, logger)
model_prefix = os.path.join(model_path, prefix)
- if prefix == 'best_accuracy':
- best_model_path = os.path.join(model_path, 'best_model')
+ if prefix == "best_accuracy":
+ best_model_path = os.path.join(model_path, "best_model")
_mkdir_if_not_exist(best_model_path, logger)
- paddle.save(optimizer.state_dict(), model_prefix + '.pdopt')
- if prefix == 'best_accuracy':
- paddle.save(optimizer.state_dict(),
- os.path.join(best_model_path, 'model.pdopt'))
+ paddle.save(optimizer.state_dict(), model_prefix + ".pdopt")
+ if prefix == "best_accuracy":
+ paddle.save(
+ optimizer.state_dict(), os.path.join(best_model_path, "model.pdopt")
+ )
- is_nlp_model = config['Architecture']["model_type"] == 'kie' and config[
- "Architecture"]["algorithm"] not in ["SDMGR"]
+ is_nlp_model = config["Architecture"]["model_type"] == "kie" and config[
+ "Architecture"
+ ]["algorithm"] not in ["SDMGR"]
if is_nlp_model is not True:
- paddle.save(model.state_dict(), model_prefix + '.pdparams')
+ paddle.save(model.state_dict(), model_prefix + ".pdparams")
metric_prefix = model_prefix
- if prefix == 'best_accuracy':
- paddle.save(model.state_dict(),
- os.path.join(best_model_path, 'model.pdparams'))
+ if prefix == "best_accuracy":
+ paddle.save(
+ model.state_dict(), os.path.join(best_model_path, "model.pdparams")
+ )
else: # for kie system, we follow the save/load rules in NLP
- if config['Global']['distributed']:
+ if config["Global"]["distributed"]:
arch = model._layers
else:
arch = model
if config["Architecture"]["algorithm"] in ["Distillation"]:
arch = arch.Student
arch.backbone.model.save_pretrained(model_prefix)
- metric_prefix = os.path.join(model_prefix, 'metric')
+ metric_prefix = os.path.join(model_prefix, "metric")
- if prefix == 'best_accuracy':
+ if prefix == "best_accuracy":
arch.backbone.model.save_pretrained(best_model_path)
# save metric and config
- with open(metric_prefix + '.states', 'wb') as f:
+ with open(metric_prefix + ".states", "wb") as f:
pickle.dump(kwargs, f, protocol=2)
if is_best:
- logger.info('save best model is to {}'.format(model_prefix))
+ logger.info("save best model is to {}".format(model_prefix))
else:
logger.info("save model in {}".format(model_prefix))
diff --git a/ppocr/utils/stats.py b/ppocr/utils/stats.py
index 179b0082f1..6dd8c5856f 100755
--- a/ppocr/utils/stats.py
+++ b/ppocr/utils/stats.py
@@ -16,7 +16,7 @@
import numpy as np
import datetime
-__all__ = ['TrainingStats', 'Time']
+__all__ = ["TrainingStats", "Time"]
class SmoothedValue(object):
@@ -35,22 +35,20 @@ def get_median_value(self):
def Time():
- return datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S.%f')
+ return datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S.%f")
class TrainingStats(object):
def __init__(self, window_size, stats_keys):
self.window_size = window_size
self.smoothed_losses_and_metrics = {
- key: SmoothedValue(window_size)
- for key in stats_keys
+ key: SmoothedValue(window_size) for key in stats_keys
}
def update(self, stats):
for k, v in stats.items():
if k not in self.smoothed_losses_and_metrics:
- self.smoothed_losses_and_metrics[k] = SmoothedValue(
- self.window_size)
+ self.smoothed_losses_and_metrics[k] = SmoothedValue(self.window_size)
self.smoothed_losses_and_metrics[k].add_value(v)
def get(self, extras=None):
@@ -67,6 +65,6 @@ def log(self, extras=None):
d = self.get(extras)
strs = []
for k, v in d.items():
- strs.append('{}: {:x<6f}'.format(k, v))
- strs = ', '.join(strs)
+ strs.append("{}: {:x<6f}".format(k, v))
+ strs = ", ".join(strs)
return strs
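TrainingStats keeps a sliding median per key via SmoothedValue, creating entries lazily on first update. A minimal sketch (window size and keys are illustrative):

    from ppocr.utils.stats import TrainingStats

    stats = TrainingStats(window_size=20, stats_keys=["loss"])
    for step in range(5):
        # "acc" is created on the fly by update(); "loss" was pre-registered
        stats.update({"loss": 1.0 / (step + 1), "acc": 0.5 + 0.1 * step})
    print(stats.log())  # e.g. "loss: 0.333333, acc: 0.700000"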
diff --git a/ppocr/utils/visual.py b/ppocr/utils/visual.py
index 9108a37281..1a49c3106b 100644
--- a/ppocr/utils/visual.py
+++ b/ppocr/utils/visual.py
@@ -18,22 +18,22 @@
from PIL import Image, ImageDraw, ImageFont
-def draw_ser_results(image,
- ocr_results,
- font_path="doc/fonts/simfang.ttf",
- font_size=14):
+def draw_ser_results(
+ image, ocr_results, font_path="doc/fonts/simfang.ttf", font_size=14
+):
np.random.seed(2021)
- color = (np.random.permutation(range(255)),
- np.random.permutation(range(255)),
- np.random.permutation(range(255)))
+ color = (
+ np.random.permutation(range(255)),
+ np.random.permutation(range(255)),
+ np.random.permutation(range(255)),
+ )
color_map = {
- idx: (color[0][idx], color[1][idx], color[2][idx])
- for idx in range(1, 255)
+ idx: (color[0][idx], color[1][idx], color[2][idx]) for idx in range(1, 255)
}
if isinstance(image, np.ndarray):
image = Image.fromarray(image)
elif isinstance(image, str) and os.path.isfile(image):
- image = Image.open(image).convert('RGB')
+ image = Image.open(image).convert("RGB")
img_new = image.copy()
draw = ImageDraw.Draw(img_new)
@@ -57,23 +57,23 @@ def draw_ser_results(image,
def draw_box_txt(bbox, text, draw, font, font_size, color):
-
# draw ocr results outline
bbox = ((bbox[0], bbox[1]), (bbox[2], bbox[3]))
draw.rectangle(bbox, fill=color)
# draw ocr results
- if int(PIL.__version__.split('.')[0]) < 10:
+ if int(PIL.__version__.split(".")[0]) < 10:
tw = font.getsize(text)[0]
th = font.getsize(text)[1]
else:
left, top, right, bottom = font.getbbox(text)
tw, th = right - left, bottom - top
-
+
start_y = max(0, bbox[0][1] - th)
draw.rectangle(
[(bbox[0][0] + 1, start_y), (bbox[0][0] + tw + 1, start_y + th)],
- fill=(0, 0, 255))
+ fill=(0, 0, 255),
+ )
draw.text((bbox[0][0] + 1, start_y), text, fill=(255, 255, 255), font=font)
@@ -85,15 +85,12 @@ def trans_poly_to_bbox(poly):
return [x1, y1, x2, y2]
-def draw_re_results(image,
- result,
- font_path="doc/fonts/simfang.ttf",
- font_size=18):
+def draw_re_results(image, result, font_path="doc/fonts/simfang.ttf", font_size=18):
np.random.seed(0)
if isinstance(image, np.ndarray):
image = Image.fromarray(image)
elif isinstance(image, str) and os.path.isfile(image):
- image = Image.open(image).convert('RGB')
+ image = Image.open(image).convert("RGB")
img_new = image.copy()
draw = ImageDraw.Draw(img_new)
@@ -103,17 +100,31 @@ def draw_re_results(image,
color_line = (0, 255, 0)
for ocr_info_head, ocr_info_tail in result:
- draw_box_txt(ocr_info_head["bbox"], ocr_info_head["transcription"],
- draw, font, font_size, color_head)
- draw_box_txt(ocr_info_tail["bbox"], ocr_info_tail["transcription"],
- draw, font, font_size, color_tail)
+ draw_box_txt(
+ ocr_info_head["bbox"],
+ ocr_info_head["transcription"],
+ draw,
+ font,
+ font_size,
+ color_head,
+ )
+ draw_box_txt(
+ ocr_info_tail["bbox"],
+ ocr_info_tail["transcription"],
+ draw,
+ font,
+ font_size,
+ color_tail,
+ )
center_head = (
- (ocr_info_head['bbox'][0] + ocr_info_head['bbox'][2]) // 2,
- (ocr_info_head['bbox'][1] + ocr_info_head['bbox'][3]) // 2)
+ (ocr_info_head["bbox"][0] + ocr_info_head["bbox"][2]) // 2,
+ (ocr_info_head["bbox"][1] + ocr_info_head["bbox"][3]) // 2,
+ )
center_tail = (
- (ocr_info_tail['bbox'][0] + ocr_info_tail['bbox'][2]) // 2,
- (ocr_info_tail['bbox'][1] + ocr_info_tail['bbox'][3]) // 2)
+ (ocr_info_tail["bbox"][0] + ocr_info_tail["bbox"][2]) // 2,
+ (ocr_info_tail["bbox"][1] + ocr_info_tail["bbox"][3]) // 2,
+ )
draw.line([center_head, center_tail], fill=color_line, width=5)
@@ -128,4 +139,4 @@ def draw_rectangle(img_path, boxes):
for box in boxes.astype(int):
x1, y1, x2, y2 = box
cv2.rectangle(img_show, (x1, y1), (x2, y2), (255, 0, 0), 2)
- return img_show
\ No newline at end of file
+ return img_show
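The PIL version branch in draw_box_txt above exists because Pillow 10 removed ImageFont.getsize in favor of getbbox. An equivalent standalone check, using the default bitmap font for illustration:

    import PIL
    from PIL import ImageFont

    font = ImageFont.load_default()
    if int(PIL.__version__.split(".")[0]) < 10:
        tw, th = font.getsize("hello")  # removed in Pillow 10
    else:
        left, top, right, bottom = font.getbbox("hello")
        tw, th = right - left, bottom - top  # same metrics on newer Pillow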
diff --git a/ppstructure/kie/predict_kie_token_ser.py b/ppstructure/kie/predict_kie_token_ser.py
index e570979bcb..57e8eb607e 100644
--- a/ppstructure/kie/predict_kie_token_ser.py
+++ b/ppstructure/kie/predict_kie_token_ser.py
@@ -16,9 +16,9 @@
__dir__ = os.path.dirname(os.path.abspath(__file__))
sys.path.append(__dir__)
-sys.path.insert(0, os.path.abspath(os.path.join(__dir__, '../..')))
+sys.path.insert(0, os.path.abspath(os.path.join(__dir__, "../..")))
-os.environ["FLAGS_allocator_strategy"] = 'auto_growth'
+os.environ["FLAGS_allocator_strategy"] = "auto_growth"
import cv2
import json
@@ -45,62 +45,64 @@ def __init__(self, args):
det_model_dir=args.det_model_dir,
rec_model_dir=args.rec_model_dir,
show_log=False,
- use_gpu=args.use_gpu)
-
- pre_process_list = [{
- 'VQATokenLabelEncode': {
- 'algorithm': args.kie_algorithm,
- 'class_path': args.ser_dict_path,
- 'contains_re': False,
- 'ocr_engine': self.ocr_engine,
- 'order_method': args.ocr_order_method,
- }
- }, {
- 'VQATokenPad': {
- 'max_seq_len': 512,
- 'return_attention_mask': True
- }
- }, {
- 'VQASerTokenChunk': {
- 'max_seq_len': 512,
- 'return_attention_mask': True
- }
- }, {
- 'Resize': {
- 'size': [224, 224]
- }
- }, {
- 'NormalizeImage': {
- 'std': [58.395, 57.12, 57.375],
- 'mean': [123.675, 116.28, 103.53],
- 'scale': '1',
- 'order': 'hwc'
- }
- }, {
- 'ToCHWImage': None
- }, {
- 'KeepKeys': {
- 'keep_keys': [
- 'input_ids', 'bbox', 'attention_mask', 'token_type_ids',
- 'image', 'labels', 'segment_offset_id', 'ocr_info',
- 'entities'
- ]
- }
- }]
+ use_gpu=args.use_gpu,
+ )
+
+ pre_process_list = [
+ {
+ "VQATokenLabelEncode": {
+ "algorithm": args.kie_algorithm,
+ "class_path": args.ser_dict_path,
+ "contains_re": False,
+ "ocr_engine": self.ocr_engine,
+ "order_method": args.ocr_order_method,
+ }
+ },
+ {"VQATokenPad": {"max_seq_len": 512, "return_attention_mask": True}},
+ {"VQASerTokenChunk": {"max_seq_len": 512, "return_attention_mask": True}},
+ {"Resize": {"size": [224, 224]}},
+ {
+ "NormalizeImage": {
+ "std": [58.395, 57.12, 57.375],
+ "mean": [123.675, 116.28, 103.53],
+ "scale": "1",
+ "order": "hwc",
+ }
+ },
+ {"ToCHWImage": None},
+ {
+ "KeepKeys": {
+ "keep_keys": [
+ "input_ids",
+ "bbox",
+ "attention_mask",
+ "token_type_ids",
+ "image",
+ "labels",
+ "segment_offset_id",
+ "ocr_info",
+ "entities",
+ ]
+ }
+ },
+ ]
postprocess_params = {
- 'name': 'VQASerTokenLayoutLMPostProcess',
+ "name": "VQASerTokenLayoutLMPostProcess",
"class_path": args.ser_dict_path,
}
- self.preprocess_op = create_operators(pre_process_list,
- {'infer_mode': True})
+ self.preprocess_op = create_operators(pre_process_list, {"infer_mode": True})
self.postprocess_op = build_post_process(postprocess_params)
- self.predictor, self.input_tensor, self.output_tensors, self.config = \
- utility.create_predictor(args, 'ser', logger)
+ (
+ self.predictor,
+ self.input_tensor,
+ self.output_tensors,
+ self.config,
+ ) = utility.create_predictor(args, "ser", logger)
def __call__(self, img):
ori_im = img.copy()
- data = {'image': img}
+ data = {"image": img}
data = transform(data, self.preprocess_op)
if data[0] is None:
return None, 0
@@ -124,7 +126,8 @@ def __call__(self, img):
preds = outputs[0]
post_result = self.postprocess_op(
- preds, segment_offset_ids=data[6], ocr_infos=data[7])
+ preds, segment_offset_ids=data[6], ocr_infos=data[7]
+ )
elapse = time.time() - starttime
return post_result, data, elapse
@@ -137,8 +140,8 @@ def main(args):
os.makedirs(args.output, exist_ok=True)
with open(
- os.path.join(args.output, 'infer.txt'), mode='w',
- encoding='utf-8') as f_w:
+ os.path.join(args.output, "infer.txt"), mode="w", encoding="utf-8"
+ ) as f_w:
for image_file in image_file_list:
img, flag, _ = check_and_read(image_file)
if not flag:
@@ -150,21 +153,24 @@ def main(args):
ser_res, _, elapse = ser_predictor(img)
ser_res = ser_res[0]
- res_str = '{}\t{}\n'.format(
+ res_str = "{}\t{}\n".format(
image_file,
json.dumps(
{
"ocr_info": ser_res,
- }, ensure_ascii=False))
+ },
+ ensure_ascii=False,
+ ),
+ )
f_w.write(res_str)
img_res = draw_ser_results(
image_file,
ser_res,
- font_path=args.vis_font_path, )
+ font_path=args.vis_font_path,
+ )
- img_save_path = os.path.join(args.output,
- os.path.basename(image_file))
+ img_save_path = os.path.join(args.output, os.path.basename(image_file))
cv2.imwrite(img_save_path, img_res)
logger.info("save vis result to {}".format(img_save_path))
if count > 0:
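The pre_process_list entries above follow PaddleOCR's operator-config convention: each element is a single-key dict whose key names an operator class and whose value holds its keyword arguments (None for no arguments). A reduced sketch of the same shape:

    # Reduced, illustrative pipeline in the same single-key-dict format;
    # create_operators(pre_process_list, {"infer_mode": True}) instantiates
    # each named class and merges the extra global kwargs into every op.
    pre_process_list = [
        {"Resize": {"size": [224, 224]}},
        {"ToCHWImage": None},
        {"KeepKeys": {"keep_keys": ["image"]}},
    ]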
diff --git a/ppstructure/kie/predict_kie_token_ser_re.py b/ppstructure/kie/predict_kie_token_ser_re.py
index b29a8f69db..ebf59882b9 100644
--- a/ppstructure/kie/predict_kie_token_ser_re.py
+++ b/ppstructure/kie/predict_kie_token_ser_re.py
@@ -16,9 +16,9 @@
__dir__ = os.path.dirname(os.path.abspath(__file__))
sys.path.append(__dir__)
-sys.path.insert(0, os.path.abspath(os.path.join(__dir__, '../..')))
+sys.path.insert(0, os.path.abspath(os.path.join(__dir__, "../..")))
-os.environ["FLAGS_allocator_strategy"] = 'auto_growth'
+os.environ["FLAGS_allocator_strategy"] = "auto_growth"
import cv2
import json
@@ -42,10 +42,14 @@ def __init__(self, args):
self.use_visual_backbone = args.use_visual_backbone
self.ser_engine = SerPredictor(args)
if args.re_model_dir is not None:
- postprocess_params = {'name': 'VQAReTokenLayoutLMPostProcess'}
+ postprocess_params = {"name": "VQAReTokenLayoutLMPostProcess"}
self.postprocess_op = build_post_process(postprocess_params)
- self.predictor, self.input_tensor, self.output_tensors, self.config = \
- utility.create_predictor(args, 're', logger)
+ (
+ self.predictor,
+ self.input_tensor,
+ self.output_tensors,
+ self.config,
+ ) = utility.create_predictor(args, "re", logger)
else:
self.predictor = None
@@ -69,12 +73,12 @@ def __call__(self, img):
preds = dict(
loss=outputs[1],
pred_relations=outputs[2],
- hidden_states=outputs[0], )
+ hidden_states=outputs[0],
+ )
post_result = self.postprocess_op(
- preds,
- ser_results=ser_results,
- entity_idx_dict_batch=entity_idx_dict_batch)
+ preds, ser_results=ser_results, entity_idx_dict_batch=entity_idx_dict_batch
+ )
elapse = time.time() - starttime
return post_result, elapse
@@ -88,8 +92,8 @@ def main(args):
os.makedirs(args.output, exist_ok=True)
with open(
- os.path.join(args.output, 'infer.txt'), mode='w',
- encoding='utf-8') as f_w:
+ os.path.join(args.output, "infer.txt"), mode="w", encoding="utf-8"
+ ) as f_w:
for image_file in image_file_list:
img, flag, _ = check_and_read(image_file)
if not flag:
@@ -101,27 +105,32 @@ def main(args):
re_res, elapse = ser_re_predictor(img)
re_res = re_res[0]
- res_str = '{}\t{}\n'.format(
+ res_str = "{}\t{}\n".format(
image_file,
json.dumps(
{
"ocr_info": re_res,
- }, ensure_ascii=False))
+ },
+ ensure_ascii=False,
+ ),
+ )
f_w.write(res_str)
if ser_re_predictor.predictor is not None:
img_res = draw_re_results(
- image_file, re_res, font_path=args.vis_font_path)
+ image_file, re_res, font_path=args.vis_font_path
+ )
img_save_path = os.path.join(
args.output,
- os.path.splitext(os.path.basename(image_file))[0] +
- "_ser_re.jpg")
+ os.path.splitext(os.path.basename(image_file))[0] + "_ser_re.jpg",
+ )
else:
img_res = draw_ser_results(
- image_file, re_res, font_path=args.vis_font_path)
+ image_file, re_res, font_path=args.vis_font_path
+ )
img_save_path = os.path.join(
args.output,
- os.path.splitext(os.path.basename(image_file))[0] +
- "_ser.jpg")
+ os.path.splitext(os.path.basename(image_file))[0] + "_ser.jpg",
+ )
cv2.imwrite(img_save_path, img_res)
logger.info("save vis result to {}".format(img_save_path))
diff --git a/ppstructure/kie/tools/eval_with_label_end2end.py b/ppstructure/kie/tools/eval_with_label_end2end.py
index b0fd84363f..f97da12861 100644
--- a/ppstructure/kie/tools/eval_with_label_end2end.py
+++ b/ppstructure/kie/tools/eval_with_label_end2end.py
@@ -37,7 +37,7 @@ def parse_ser_results_fp(fp, fp_type="gt", ignore_background=True):
assert fp_type in ["gt", "pred"]
key = "label" if fp_type == "gt" else "pred"
res_dict = dict()
- with open(fp, "r", encoding='utf-8') as fin:
+ with open(fp, "r", encoding="utf-8") as fin:
lines = fin.readlines()
for _, line in enumerate(lines):
@@ -71,8 +71,7 @@ def polygon_iou(poly1, poly2):
"""
Intersection over union between two shapely polygons.
"""
- if not poly1.intersects(
- poly2): # this test is fast and can accelerate calculation
+ if not poly1.intersects(poly2): # this test is fast and can accelerate calculation
iou = 0
else:
try:
@@ -82,7 +81,7 @@ def polygon_iou(poly1, poly2):
except shapely.geos.TopologicalError:
# except Exception as e:
# print(e)
- print('shapely.geos.TopologicalError occurred, iou set to 0')
+ print("shapely.geos.TopologicalError occurred, iou set to 0")
iou = 0
return iou
@@ -109,11 +108,11 @@ def convert_bbox_to_polygon(bbox):
def eval_e2e(args):
# gt
- gt_results = parse_ser_results_fp(args.gt_json_path, "gt",
- args.ignore_background)
+ gt_results = parse_ser_results_fp(args.gt_json_path, "gt", args.ignore_background)
# pred
- dt_results = parse_ser_results_fp(args.pred_json_path, "pred",
- args.ignore_background)
+ dt_results = parse_ser_results_fp(
+ args.pred_json_path, "pred", args.ignore_background
+ )
iou_thresh = args.iou_thres
num_gt_chars = 0
gt_count = 0
@@ -144,8 +143,7 @@ def eval_e2e(args):
iou = polygon_iou(dt_poly, gt_poly)
if iou >= iou_thresh:
all_ious[(index_gt, index_dt)] = iou
- sorted_ious = sorted(
- all_ious.items(), key=operator.itemgetter(1), reverse=True)
+ sorted_ious = sorted(all_ious.items(), key=operator.itemgetter(1), reverse=True)
sorted_gt_dt_pairs = [item[0] for item in sorted_ious]
# matched gt and dt
@@ -169,14 +167,14 @@ def eval_e2e(args):
if args.ignore_ser_prediction or gt_label == dt_label:
hit += 1
-# unmatched dt
+ # unmatched dt
for tindex, dt_match_flag in enumerate(dt_match):
if dt_match_flag == False:
dt_text = dt_info[tindex]["text"]
gt_text = ""
ed_sum += ed(args, dt_text, gt_text)
-# unmatched gt
+ # unmatched gt
for tindex, gt_match_flag in enumerate(gt_match):
if gt_match_flag == False:
dt_text = ""
@@ -186,7 +184,7 @@ def eval_e2e(args):
eps = 1e-9
print("config: ", args)
- print('hit, dt_count, gt_count', hit, dt_count, gt_count)
+ print("hit, dt_count, gt_count", hit, dt_count, gt_count)
precision = hit / (dt_count + eps)
recall = hit / (gt_count + eps)
fmeasure = 2.0 * precision * recall / (precision + recall + eps)
@@ -194,19 +192,18 @@ def eval_e2e(args):
avg_edit_dist_field = ed_sum / (gt_count + eps)
character_acc = 1 - ed_sum / (num_gt_chars + eps)
- print('character_acc: %.2f' % (character_acc * 100) + "%")
- print('avg_edit_dist_field: %.2f' % (avg_edit_dist_field))
- print('avg_edit_dist_img: %.2f' % (avg_edit_dist_img))
- print('precision: %.2f' % (precision * 100) + "%")
- print('recall: %.2f' % (recall * 100) + "%")
- print('fmeasure: %.2f' % (fmeasure * 100) + "%")
+ print("character_acc: %.2f" % (character_acc * 100) + "%")
+ print("avg_edit_dist_field: %.2f" % (avg_edit_dist_field))
+ print("avg_edit_dist_img: %.2f" % (avg_edit_dist_img))
+ print("precision: %.2f" % (precision * 100) + "%")
+ print("recall: %.2f" % (recall * 100) + "%")
+ print("fmeasure: %.2f" % (fmeasure * 100) + "%")
return
def parse_args():
- """
- """
+ """ """
def str2bool(v):
return v.lower() in ("true", "t", "1")
@@ -217,12 +214,14 @@ def str2bool(v):
"--gt_json_path",
default=None,
type=str,
- required=True, )
+ required=True,
+ )
parser.add_argument(
"--pred_json_path",
default=None,
type=str,
- required=True, )
+ required=True,
+ )
parser.add_argument("--iou_thres", default=0.5, type=float)
@@ -230,30 +229,31 @@ def str2bool(v):
"--ignore_case",
default=False,
type=str2bool,
- help="whether to do lower case for the strs")
+ help="whether to do lower case for the strs",
+ )
parser.add_argument(
- "--ignore_space",
- default=True,
- type=str2bool,
- help="whether to ignore space")
+ "--ignore_space", default=True, type=str2bool, help="whether to ignore space"
+ )
parser.add_argument(
"--ignore_background",
default=True,
type=str2bool,
- help="whether to ignore other label")
+ help="whether to ignore other label",
+ )
parser.add_argument(
"--ignore_ser_prediction",
default=False,
type=str2bool,
- help="whether to ignore ocr pred results")
+ help="whether to ignore ocr pred results",
+ )
args = parser.parse_args()
return args
-if __name__ == '__main__':
+if __name__ == "__main__":
args = parse_args()
eval_e2e(args)
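polygon_iou above uses intersects() as a cheap pre-check before computing exact areas. A standalone sketch with shapely, using inclusion-exclusion for the union area (equivalent to union().area):

    from shapely.geometry import Polygon

    a = Polygon([(0, 0), (2, 0), (2, 2), (0, 2)])
    b = Polygon([(1, 1), (3, 1), (3, 3), (1, 3)])
    iou = 0.0
    if a.intersects(b):  # fast rejection test, as in polygon_iou
        inter = a.intersection(b).area
        iou = inter / (a.area + b.area - inter)
    print(iou)  # 1/7 ~= 0.1429 for these two overlapping squares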
diff --git a/ppstructure/kie/tools/trans_funsd_label.py b/ppstructure/kie/tools/trans_funsd_label.py
index ef7d1db010..104cb68a2c 100644
--- a/ppstructure/kie/tools/trans_funsd_label.py
+++ b/ppstructure/kie/tools/trans_funsd_label.py
@@ -62,16 +62,17 @@ def load_funsd_label(image_dir, anno_dir):
curr_texts = [words[0]["text"]]
while word_idx < len(words):
# switch to a new link
- if words[word_idx]["box"][0] + 10 <= words[word_idx - 1][
- "box"][2]:
+ if words[word_idx]["box"][0] + 10 <= words[word_idx - 1]["box"][2]:
if len("".join(curr_texts[0])) > 0:
- res.append({
- "transcription": " ".join(curr_texts),
- "label": info["label"],
- "points": get_outer_poly(curr_bboxes),
- "linking": info["linking"],
- "id": global_new_id,
- })
+ res.append(
+ {
+ "transcription": " ".join(curr_texts),
+ "label": info["label"],
+ "points": get_outer_poly(curr_bboxes),
+ "linking": info["linking"],
+ "id": global_new_id,
+ }
+ )
if info["id"] not in old_id2new_id_map:
old_id2new_id_map[info["id"]] = []
old_id2new_id_map[info["id"]].append(global_new_id)
@@ -83,23 +84,25 @@ def load_funsd_label(image_dir, anno_dir):
curr_texts.append(words[word_idx]["text"])
word_idx += 1
if len("".join(curr_texts[0])) > 0:
- res.append({
- "transcription": " ".join(curr_texts),
- "label": info["label"],
- "points": get_outer_poly(curr_bboxes),
- "linking": info["linking"],
- "id": global_new_id,
- })
+ res.append(
+ {
+ "transcription": " ".join(curr_texts),
+ "label": info["label"],
+ "points": get_outer_poly(curr_bboxes),
+ "linking": info["linking"],
+ "id": global_new_id,
+ }
+ )
if info["id"] not in old_id2new_id_map:
old_id2new_id_map[info["id"]] = []
old_id2new_id_map[info["id"]].append(global_new_id)
global_new_id += 1
- res = sorted(
- res, key=lambda r: (r["points"][0][1], r["points"][0][0]))
+ res = sorted(res, key=lambda r: (r["points"][0][1], r["points"][0][0]))
for i in range(len(res) - 1):
for j in range(i, 0, -1):
- if abs(res[j + 1]["points"][0][1] - res[j]["points"][0][1]) < 20 and \
- (res[j + 1]["points"][0][0] < res[j]["points"][0][0]):
+ if abs(
+ res[j + 1]["points"][0][1] - res[j]["points"][0][1]
+ ) < 20 and (res[j + 1]["points"][0][0] < res[j]["points"][0][0]):
tmp = deepcopy(res[j])
res[j] = deepcopy(res[j + 1])
res[j + 1] = deepcopy(tmp)
@@ -110,8 +113,10 @@ def load_funsd_label(image_dir, anno_dir):
new_links = []
for link in r["linking"]:
# illegal links will be removed
- if link[0] not in old_id2new_id_map or link[
- 1] not in old_id2new_id_map:
+ if (
+ link[0] not in old_id2new_id_map
+ or link[1] not in old_id2new_id_map
+ ):
continue
for src in old_id2new_id_map[link[0]]:
for dst in old_id2new_id_map[link[1]]:
@@ -131,8 +136,13 @@ def main():
fn_info_map = load_funsd_label(test_image_dir, test_anno_dir)
with open(test_output_dir, "w") as fout:
for fn in fn_info_map:
- fout.write(fn + ".png" + "\t" + json.dumps(
- fn_info_map[fn], ensure_ascii=False) + "\n")
+ fout.write(
+ fn
+ + ".png"
+ + "\t"
+ + json.dumps(fn_info_map[fn], ensure_ascii=False)
+ + "\n"
+ )
train_image_dir = "train_data/FUNSD/training_data/images/"
train_anno_dir = "train_data/FUNSD/training_data/annotations/"
@@ -141,8 +151,13 @@ def main():
fn_info_map = load_funsd_label(train_image_dir, train_anno_dir)
with open(train_output_dir, "w") as fout:
for fn in fn_info_map:
- fout.write(fn + ".png" + "\t" + json.dumps(
- fn_info_map[fn], ensure_ascii=False) + "\n")
+ fout.write(
+ fn
+ + ".png"
+ + "\t"
+ + json.dumps(fn_info_map[fn], ensure_ascii=False)
+ + "\n"
+ )
print("====ok====")
return
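The sort-then-swap above approximates reading order: boxes are first ordered by (y, x), then neighbors whose top edges sit within 20 px of each other are reordered left-to-right. A simplified single-pass sketch of the same idea (the nested backward loop in the source handles longer runs):

    boxes = [{"points": [[120, 52]]}, {"points": [[10, 55]]}, {"points": [[40, 200]]}]
    boxes.sort(key=lambda r: (r["points"][0][1], r["points"][0][0]))
    for i in range(len(boxes) - 1):
        same_line = abs(boxes[i + 1]["points"][0][1] - boxes[i]["points"][0][1]) < 20
        if same_line and boxes[i + 1]["points"][0][0] < boxes[i]["points"][0][0]:
            boxes[i], boxes[i + 1] = boxes[i + 1], boxes[i]
    # -> the x=10 box (y=55) now precedes the x=120 box (y=52); y=200 stays last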
diff --git a/ppstructure/kie/tools/trans_xfun_data.py b/ppstructure/kie/tools/trans_xfun_data.py
index 11d221bea4..4e83b5aa3c 100644
--- a/ppstructure/kie/tools/trans_xfun_data.py
+++ b/ppstructure/kie/tools/trans_xfun_data.py
@@ -16,12 +16,12 @@
def transfer_xfun_data(json_path=None, output_file=None):
- with open(json_path, "r", encoding='utf-8') as fin:
+ with open(json_path, "r", encoding="utf-8") as fin:
lines = fin.readlines()
json_info = json.loads(lines[0])
documents = json_info["documents"]
- with open(output_file, "w", encoding='utf-8') as fout:
+ with open(output_file, "w", encoding="utf-8") as fout:
for idx, document in enumerate(documents):
label_info = []
img_info = document["img"]
@@ -31,27 +31,31 @@ def transfer_xfun_data(json_path=None, output_file=None):
for doc in document:
x1, y1, x2, y2 = doc["box"]
points = [[x1, y1], [x2, y1], [x2, y2], [x1, y2]]
- label_info.append({
- "transcription": doc["text"],
- "label": doc["label"],
- "points": points,
- "id": doc["id"],
- "linking": doc["linking"]
- })
+ label_info.append(
+ {
+ "transcription": doc["text"],
+ "label": doc["label"],
+ "points": points,
+ "id": doc["id"],
+ "linking": doc["linking"],
+ }
+ )
- fout.write(image_path + "\t" + json.dumps(
- label_info, ensure_ascii=False) + "\n")
+ fout.write(
+ image_path + "\t" + json.dumps(label_info, ensure_ascii=False) + "\n"
+ )
print("===ok====")
def parser_args():
import argparse
+
parser = argparse.ArgumentParser(description="args for paddleserving")
parser.add_argument(
- "--ori_gt_path", type=str, required=True, help='origin xfun gt path')
- parser.add_argument(
- "--output_path", type=str, required=True, help='path to save')
+ "--ori_gt_path", type=str, required=True, help="origin xfun gt path"
+ )
+ parser.add_argument("--output_path", type=str, required=True, help="path to save")
args = parser.parse_args()
return args
diff --git a/ppstructure/layout/predict_layout.py b/ppstructure/layout/predict_layout.py
index 9f8c884e14..65984eeede 100755
--- a/ppstructure/layout/predict_layout.py
+++ b/ppstructure/layout/predict_layout.py
@@ -16,9 +16,9 @@
__dir__ = os.path.dirname(os.path.abspath(__file__))
sys.path.append(__dir__)
-sys.path.insert(0, os.path.abspath(os.path.join(__dir__, '../..')))
+sys.path.insert(0, os.path.abspath(os.path.join(__dir__, "../..")))
-os.environ["FLAGS_allocator_strategy"] = 'auto_growth'
+os.environ["FLAGS_allocator_strategy"] = "auto_growth"
import cv2
import numpy as np
@@ -37,26 +37,21 @@
class LayoutPredictor(object):
def __init__(self, args):
- pre_process_list = [{
- 'Resize': {
- 'size': [800, 608]
- }
- }, {
- 'NormalizeImage': {
- 'std': [0.229, 0.224, 0.225],
- 'mean': [0.485, 0.456, 0.406],
- 'scale': '1./255.',
- 'order': 'hwc'
- }
- }, {
- 'ToCHWImage': None
- }, {
- 'KeepKeys': {
- 'keep_keys': ['image']
- }
- }]
+ pre_process_list = [
+ {"Resize": {"size": [800, 608]}},
+ {
+ "NormalizeImage": {
+ "std": [0.229, 0.224, 0.225],
+ "mean": [0.485, 0.456, 0.406],
+ "scale": "1./255.",
+ "order": "hwc",
+ }
+ },
+ {"ToCHWImage": None},
+ {"KeepKeys": {"keep_keys": ["image"]}},
+ ]
postprocess_params = {
- 'name': 'PicoDetPostProcess',
+ "name": "PicoDetPostProcess",
"layout_dict_path": args.layout_dict_path,
"score_threshold": args.layout_score_threshold,
"nms_threshold": args.layout_nms_threshold,
@@ -64,12 +59,16 @@ def __init__(self, args):
self.preprocess_op = create_operators(pre_process_list)
self.postprocess_op = build_post_process(postprocess_params)
- self.predictor, self.input_tensor, self.output_tensors, self.config = \
- utility.create_predictor(args, 'layout', logger)
+ (
+ self.predictor,
+ self.input_tensor,
+ self.output_tensors,
+ self.config,
+ ) = utility.create_predictor(args, "layout", logger)
def __call__(self, img):
ori_im = img.copy()
- data = {'image': img}
+ data = {"image": img}
data = transform(data, self.preprocess_op)
img = data[0]
@@ -90,11 +89,13 @@ def __call__(self, img):
num_outs = int(len(output_names) / 2)
for out_idx in range(num_outs):
np_score_list.append(
- self.predictor.get_output_handle(output_names[out_idx])
- .copy_to_cpu())
+ self.predictor.get_output_handle(output_names[out_idx]).copy_to_cpu()
+ )
np_boxes_list.append(
- self.predictor.get_output_handle(output_names[
- out_idx + num_outs]).copy_to_cpu())
+ self.predictor.get_output_handle(
+ output_names[out_idx + num_outs]
+ ).copy_to_cpu()
+ )
preds = dict(boxes=np_score_list, boxes_num=np_boxes_list)
post_preds = self.postprocess_op(ori_im, img, preds)
diff --git a/ppstructure/predict_system.py b/ppstructure/predict_system.py
index 8d504ff90c..9073e87ee1 100644
--- a/ppstructure/predict_system.py
+++ b/ppstructure/predict_system.py
@@ -18,9 +18,9 @@
__dir__ = os.path.dirname(os.path.abspath(__file__))
sys.path.append(__dir__)
-sys.path.insert(0, os.path.abspath(os.path.join(__dir__, '../')))
+sys.path.insert(0, os.path.abspath(os.path.join(__dir__, "../")))
-os.environ["FLAGS_allocator_strategy"] = 'auto_growth'
+os.environ["FLAGS_allocator_strategy"] = "auto_growth"
import cv2
import json
import numpy as np
@@ -47,10 +47,12 @@ def __init__(self, args):
self.image_orientation_predictor = None
if args.image_orientation:
import paddleclas
+
self.image_orientation_predictor = paddleclas.PaddleClas(
- model_name="text_image_orientation")
+ model_name="text_image_orientation"
+ )
- if self.mode == 'structure':
+ if self.mode == "structure":
if not args.show_log:
logger.setLevel(logging.INFO)
if args.layout == False and args.ocr == True:
@@ -69,54 +71,56 @@ def __init__(self, args):
if args.table:
if self.text_system is not None:
self.table_system = TableSystem(
- args, self.text_system.text_detector,
- self.text_system.text_recognizer)
+ args,
+ self.text_system.text_detector,
+ self.text_system.text_recognizer,
+ )
else:
self.table_system = TableSystem(args)
- elif self.mode == 'kie':
+ elif self.mode == "kie":
from ppstructure.kie.predict_kie_token_ser_re import SerRePredictor
+
self.kie_predictor = SerRePredictor(args)
self.return_word_box = args.return_word_box
def __call__(self, img, return_ocr_result_in_table=False, img_idx=0):
time_dict = {
- 'image_orientation': 0,
- 'layout': 0,
- 'table': 0,
- 'table_match': 0,
- 'det': 0,
- 'rec': 0,
- 'kie': 0,
- 'all': 0
+ "image_orientation": 0,
+ "layout": 0,
+ "table": 0,
+ "table_match": 0,
+ "det": 0,
+ "rec": 0,
+ "kie": 0,
+ "all": 0,
}
start = time.time()
if self.image_orientation_predictor is not None:
tic = time.time()
- cls_result = self.image_orientation_predictor.predict(
- input_data=img)
+ cls_result = self.image_orientation_predictor.predict(input_data=img)
cls_res = next(cls_result)
- angle = cls_res[0]['label_names'][0]
+ angle = cls_res[0]["label_names"][0]
cv_rotate_code = {
- '90': cv2.ROTATE_90_COUNTERCLOCKWISE,
- '180': cv2.ROTATE_180,
- '270': cv2.ROTATE_90_CLOCKWISE
+ "90": cv2.ROTATE_90_COUNTERCLOCKWISE,
+ "180": cv2.ROTATE_180,
+ "270": cv2.ROTATE_90_CLOCKWISE,
}
if angle in cv_rotate_code:
img = cv2.rotate(img, cv_rotate_code[angle])
toc = time.time()
- time_dict['image_orientation'] = toc - tic
+ time_dict["image_orientation"] = toc - tic
- if self.mode == 'structure':
+ if self.mode == "structure":
ori_im = img.copy()
if self.layout_predictor is not None:
layout_res, elapse = self.layout_predictor(img)
- time_dict['layout'] += elapse
+ time_dict["layout"] += elapse
else:
h, w = ori_im.shape[:2]
- layout_res = [dict(bbox=None, label='table')]
+ layout_res = [dict(bbox=None, label="table")]
# As reported in issues such as #10270 and #11665, the old
# implementation, which recognizes texts from the layout regions,
@@ -128,14 +132,14 @@ def __call__(self, img, return_ocr_result_in_table=False, img_idx=0):
text_res = None
if self.text_system is not None:
text_res, ocr_time_dict = self._predict_text(img)
- time_dict['det'] += ocr_time_dict['det']
- time_dict['rec'] += ocr_time_dict['rec']
+ time_dict["det"] += ocr_time_dict["det"]
+ time_dict["rec"] += ocr_time_dict["rec"]
res_list = []
for region in layout_res:
- res = ''
- if region['bbox'] is not None:
- x1, y1, x2, y2 = region['bbox']
+ res = ""
+ if region["bbox"] is not None:
+ x1, y1, x2, y2 = region["bbox"]
x1, y1, x2, y2 = int(x1), int(y1), int(x2), int(y2)
roi_img = ori_im[y1:y2, x1:x2, :]
else:
@@ -143,35 +147,38 @@ def __call__(self, img, return_ocr_result_in_table=False, img_idx=0):
roi_img = ori_im
bbox = [x1, y1, x2, y2]
- if region['label'] == 'table':
+ if region["label"] == "table":
if self.table_system is not None:
res, table_time_dict = self.table_system(
- roi_img, return_ocr_result_in_table)
- time_dict['table'] += table_time_dict['table']
- time_dict['table_match'] += table_time_dict['match']
- time_dict['det'] += table_time_dict['det']
- time_dict['rec'] += table_time_dict['rec']
+ roi_img, return_ocr_result_in_table
+ )
+ time_dict["table"] += table_time_dict["table"]
+ time_dict["table_match"] += table_time_dict["match"]
+ time_dict["det"] += table_time_dict["det"]
+ time_dict["rec"] += table_time_dict["rec"]
else:
if text_res is not None:
# Filter the text results whose regions intersect with the current layout bbox.
res = self._filter_text_res(text_res, bbox)
- res_list.append({
- 'type': region['label'].lower(),
- 'bbox': bbox,
- 'img': roi_img,
- 'res': res,
- 'img_idx': img_idx
- })
+ res_list.append(
+ {
+ "type": region["label"].lower(),
+ "bbox": bbox,
+ "img": roi_img,
+ "res": res,
+ "img_idx": img_idx,
+ }
+ )
end = time.time()
- time_dict['all'] = end - start
+ time_dict["all"] = end - start
return res_list, time_dict
- elif self.mode == 'kie':
+ elif self.mode == "kie":
re_res, elapse = self.kie_predictor(img)
- time_dict['kie'] = elapse
- time_dict['all'] = elapse
+ time_dict["kie"] = elapse
+ time_dict["all"] = elapse
return re_res[0], time_dict
return None, None
@@ -183,38 +190,54 @@ def _predict_text(self, img):
# when using the recognition model trained on the PubtabNet dataset,
        # it will recognize the text format in the table, such as <b>
        style_token = [
- '<strike>', '<strike>', '<sup>', '</sub>', '<b>',
- '</b>', '<sub>', '</sup>', '<overline>',
- '</overline>', '<underline>', '</underline>', '<i>',
- '</i>'
+ "<strike>",
+ "<strike>",
+ "<sup>",
+ "</sub>",
+ "<b>",
+ "</b>",
+ "<sub>",
+ "</sup>",
+ "<overline>",
+ "</overline>",
+ "<underline>",
+ "</underline>",
+ "<i>",
+ "</i>",
        ]
res = []
for box, rec_res in zip(filter_boxes, filter_rec_res):
rec_str, rec_conf = rec_res[0], rec_res[1]
for token in style_token:
if token in rec_str:
- rec_str = rec_str.replace(token, '')
+ rec_str = rec_str.replace(token, "")
if self.return_word_box:
- word_box_content_list, word_box_list = cal_ocr_word_box(rec_str, box, rec_res[2])
- res.append({
- 'text': rec_str,
- 'confidence': float(rec_conf),
- 'text_region': box.tolist(),
- 'text_word': word_box_content_list,
- 'text_word_region': word_box_list
- })
+ word_box_content_list, word_box_list = cal_ocr_word_box(
+ rec_str, box, rec_res[2]
+ )
+ res.append(
+ {
+ "text": rec_str,
+ "confidence": float(rec_conf),
+ "text_region": box.tolist(),
+ "text_word": word_box_content_list,
+ "text_word_region": word_box_list,
+ }
+ )
else:
- res.append({
- 'text': rec_str,
- 'confidence': float(rec_conf),
- 'text_region': box.tolist()
- })
+ res.append(
+ {
+ "text": rec_str,
+ "confidence": float(rec_conf),
+ "text_region": box.tolist(),
+ }
+ )
return res, ocr_time_dict
def _filter_text_res(self, text_res, bbox):
res = []
for r in text_res:
- box = r['text_region']
+ box = r["text_region"]
rect = box[0][0], box[0][1], box[2][0], box[2][1]
if self._has_intersection(bbox, rect):
res.append(r)
@@ -236,30 +259,34 @@ def save_structure_res(res, save_folder, img_name, img_idx=0):
res_cp = deepcopy(res)
# save res
with open(
- os.path.join(excel_save_folder, 'res_{}.txt'.format(img_idx)),
- 'w',
- encoding='utf8') as f:
+ os.path.join(excel_save_folder, "res_{}.txt".format(img_idx)),
+ "w",
+ encoding="utf8",
+ ) as f:
for region in res_cp:
- roi_img = region.pop('img')
- f.write('{}\n'.format(json.dumps(region)))
-
- if region['type'].lower() == 'table' and len(region[
- 'res']) > 0 and 'html' in region['res']:
+ roi_img = region.pop("img")
+ f.write("{}\n".format(json.dumps(region)))
+
+ if (
+ region["type"].lower() == "table"
+ and len(region["res"]) > 0
+ and "html" in region["res"]
+ ):
excel_path = os.path.join(
- excel_save_folder,
- '{}_{}.xlsx'.format(region['bbox'], img_idx))
- to_excel(region['res']['html'], excel_path)
- elif region['type'].lower() == 'figure':
+ excel_save_folder, "{}_{}.xlsx".format(region["bbox"], img_idx)
+ )
+ to_excel(region["res"]["html"], excel_path)
+ elif region["type"].lower() == "figure":
img_path = os.path.join(
- excel_save_folder,
- '{}_{}.jpg'.format(region['bbox'], img_idx))
+ excel_save_folder, "{}_{}.jpg".format(region["bbox"], img_idx)
+ )
cv2.imwrite(img_path, roi_img)
def main(args):
image_file_list = get_image_file_list(args.image_dir)
image_file_list = image_file_list
- image_file_list = image_file_list[args.process_id::args.total_process_num]
+ image_file_list = image_file_list[args.process_id :: args.total_process_num]
if not args.use_pdf2docx_api:
structure_sys = StructureSystem(args)
@@ -270,17 +297,17 @@ def main(args):
for i, image_file in enumerate(image_file_list):
logger.info("[{}/{}] {}".format(i, img_num, image_file))
img, flag_gif, flag_pdf = check_and_read(image_file)
- img_name = os.path.basename(image_file).split('.')[0]
+ img_name = os.path.basename(image_file).split(".")[0]
if args.recovery and args.use_pdf2docx_api and flag_pdf:
from pdf2docx.converter import Converter
+
os.makedirs(args.output, exist_ok=True)
- docx_file = os.path.join(args.output,
- '{}_api.docx'.format(img_name))
+ docx_file = os.path.join(args.output, "{}_api.docx".format(img_name))
cv = Converter(image_file)
cv.convert(docx_file)
cv.close()
- logger.info('docx save to {}'.format(docx_file))
+ logger.info("docx save to {}".format(docx_file))
continue
if not flag_gif and not flag_pdf:
@@ -297,37 +324,37 @@ def main(args):
all_res = []
for index, img in enumerate(imgs):
res, time_dict = structure_sys(img, img_idx=index)
- img_save_path = os.path.join(save_folder, img_name,
- 'show_{}.jpg'.format(index))
+ img_save_path = os.path.join(
+ save_folder, img_name, "show_{}.jpg".format(index)
+ )
os.makedirs(os.path.join(save_folder, img_name), exist_ok=True)
- if structure_sys.mode == 'structure' and res != []:
+ if structure_sys.mode == "structure" and res != []:
draw_img = draw_structure_result(img, res, args.vis_font_path)
save_structure_res(res, save_folder, img_name, index)
- elif structure_sys.mode == 'kie':
+ elif structure_sys.mode == "kie":
if structure_sys.kie_predictor.predictor is not None:
- draw_img = draw_re_results(
- img, res, font_path=args.vis_font_path)
+ draw_img = draw_re_results(img, res, font_path=args.vis_font_path)
else:
- draw_img = draw_ser_results(
- img, res, font_path=args.vis_font_path)
+ draw_img = draw_ser_results(img, res, font_path=args.vis_font_path)
with open(
- os.path.join(save_folder, img_name,
- 'res_{}_kie.txt'.format(index)),
- 'w',
- encoding='utf8') as f:
- res_str = '{}\t{}\n'.format(
- image_file,
- json.dumps(
- {
- "ocr_info": res
- }, ensure_ascii=False))
+ os.path.join(save_folder, img_name, "res_{}_kie.txt".format(index)),
+ "w",
+ encoding="utf8",
+ ) as f:
+ res_str = "{}\t{}\n".format(
+ image_file, json.dumps({"ocr_info": res}, ensure_ascii=False)
+ )
f.write(res_str)
if res != []:
cv2.imwrite(img_save_path, draw_img)
- logger.info('result save to {}'.format(img_save_path))
+ logger.info("result save to {}".format(img_save_path))
if args.recovery and res != []:
- from ppstructure.recovery.recovery_to_doc import sorted_layout_boxes, convert_info_docx
+ from ppstructure.recovery.recovery_to_doc import (
+ sorted_layout_boxes,
+ convert_info_docx,
+ )
+
h, w, _ = img.shape
res = sorted_layout_boxes(res, w)
all_res += res
@@ -336,10 +363,13 @@ def main(args):
try:
convert_info_docx(img, all_res, save_folder, img_name)
except Exception as ex:
- logger.error("error in layout recovery image:{}, err msg: {}".
- format(image_file, ex))
+ logger.error(
+ "error in layout recovery image:{}, err msg: {}".format(
+ image_file, ex
+ )
+ )
continue
- logger.info("Predict time : {:.3f}s".format(time_dict['all']))
+ logger.info("Predict time : {:.3f}s".format(time_dict["all"]))
if __name__ == "__main__":
@@ -348,10 +378,11 @@ def main(args):
p_list = []
total_process_num = args.total_process_num
for process_id in range(total_process_num):
- cmd = [sys.executable, "-u"] + sys.argv + [
- "--process_id={}".format(process_id),
- "--use_mp={}".format(False)
- ]
+ cmd = (
+ [sys.executable, "-u"]
+ + sys.argv
+ + ["--process_id={}".format(process_id), "--use_mp={}".format(False)]
+ )
p = subprocess.Popen(cmd, stdout=sys.stdout, stderr=sys.stdout)
p_list.append(p)
for p in p_list:
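The multiprocess branch above relaunches the script once per worker and each worker takes its share of the input through an extended slice. A standalone sketch of that sharding, with invented file names (the real list comes from `get_image_file_list(args.image_dir)`):

```python
# Worker k of N takes every N-th file; the shards are disjoint and
# near-equal in size without any coordination between processes.
image_file_list = [f"page_{i}.jpg" for i in range(10)]
total_process_num = 3
for process_id in range(total_process_num):
    shard = image_file_list[process_id::total_process_num]
    print(process_id, shard)
# 0 ['page_0.jpg', 'page_3.jpg', 'page_6.jpg', 'page_9.jpg']
# 1 ['page_1.jpg', 'page_4.jpg', 'page_7.jpg']
# 2 ['page_2.jpg', 'page_5.jpg', 'page_8.jpg']
```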
diff --git a/ppstructure/recovery/recovery_to_doc.py b/ppstructure/recovery/recovery_to_doc.py
index cd1728b666..edbeefd9f2 100644
--- a/ppstructure/recovery/recovery_to_doc.py
+++ b/ppstructure/recovery/recovery_to_doc.py
@@ -25,33 +25,35 @@
from ppstructure.recovery.table_process import HtmlToDocx
from ppocr.utils.logging import get_logger
+
logger = get_logger()
def convert_info_docx(img, res, save_folder, img_name):
doc = Document()
- doc.styles['Normal'].font.name = 'Times New Roman'
- doc.styles['Normal']._element.rPr.rFonts.set(qn('w:eastAsia'), u'宋体')
- doc.styles['Normal'].font.size = shared.Pt(6.5)
+ doc.styles["Normal"].font.name = "Times New Roman"
+ doc.styles["Normal"]._element.rPr.rFonts.set(qn("w:eastAsia"), "宋体")
+ doc.styles["Normal"].font.size = shared.Pt(6.5)
flag = 1
for i, region in enumerate(res):
- if len(region['res']) == 0:
+ if len(region["res"]) == 0:
continue
- img_idx = region['img_idx']
- if flag == 2 and region['layout'] == 'single':
+ img_idx = region["img_idx"]
+ if flag == 2 and region["layout"] == "single":
section = doc.add_section(WD_SECTION.CONTINUOUS)
- section._sectPr.xpath('./w:cols')[0].set(qn('w:num'), '1')
+ section._sectPr.xpath("./w:cols")[0].set(qn("w:num"), "1")
flag = 1
- elif flag == 1 and region['layout'] == 'double':
+ elif flag == 1 and region["layout"] == "double":
section = doc.add_section(WD_SECTION.CONTINUOUS)
- section._sectPr.xpath('./w:cols')[0].set(qn('w:num'), '2')
+ section._sectPr.xpath("./w:cols")[0].set(qn("w:num"), "2")
flag = 2
- if region['type'].lower() == 'figure':
+ if region["type"].lower() == "figure":
excel_save_folder = os.path.join(save_folder, img_name)
- img_path = os.path.join(excel_save_folder,
- '{}_{}.jpg'.format(region['bbox'], img_idx))
+ img_path = os.path.join(
+ excel_save_folder, "{}_{}.jpg".format(region["bbox"], img_idx)
+ )
paragraph_pic = doc.add_paragraph()
paragraph_pic.alignment = WD_ALIGN_PARAGRAPH.CENTER
run = paragraph_pic.add_run("")
@@ -59,25 +61,25 @@ def convert_info_docx(img, res, save_folder, img_name):
run.add_picture(img_path, width=shared.Inches(5))
elif flag == 2:
run.add_picture(img_path, width=shared.Inches(2))
- elif region['type'].lower() == 'title':
- doc.add_heading(region['res'][0]['text'])
- elif region['type'].lower() == 'table':
+ elif region["type"].lower() == "title":
+ doc.add_heading(region["res"][0]["text"])
+ elif region["type"].lower() == "table":
parser = HtmlToDocx()
- parser.table_style = 'TableGrid'
- parser.handle_table(region['res']['html'], doc)
+ parser.table_style = "TableGrid"
+ parser.handle_table(region["res"]["html"], doc)
else:
paragraph = doc.add_paragraph()
paragraph_format = paragraph.paragraph_format
- for i, line in enumerate(region['res']):
+ for i, line in enumerate(region["res"]):
if i == 0:
paragraph_format.first_line_indent = shared.Inches(0.25)
- text_run = paragraph.add_run(line['text'] + ' ')
+ text_run = paragraph.add_run(line["text"] + " ")
text_run.font.size = shared.Pt(10)
# save to docx
- docx_path = os.path.join(save_folder, '{}_ocr.docx'.format(img_name))
+ docx_path = os.path.join(save_folder, "{}_ocr.docx".format(img_name))
doc.save(docx_path)
- logger.info('docx save to {}'.format(docx_path))
+ logger.info("docx save to {}".format(docx_path))
def sorted_layout_boxes(res, w):
@@ -90,10 +92,10 @@ def sorted_layout_boxes(res, w):
"""
num_boxes = len(res)
if num_boxes == 1:
- res[0]['layout'] = 'single'
+ res[0]["layout"] = "single"
return res
- sorted_boxes = sorted(res, key=lambda x: (x['bbox'][1], x['bbox'][0]))
+ sorted_boxes = sorted(res, key=lambda x: (x["bbox"][1], x["bbox"][0]))
_boxes = list(sorted_boxes)
new_res = []
@@ -105,38 +107,41 @@ def sorted_layout_boxes(res, w):
if i >= num_boxes:
break
if i == num_boxes - 1:
- if _boxes[i]['bbox'][1] > _boxes[i - 1]['bbox'][3] and _boxes[i][
- 'bbox'][0] < w / 2 and _boxes[i]['bbox'][2] > w / 2:
+ if (
+ _boxes[i]["bbox"][1] > _boxes[i - 1]["bbox"][3]
+ and _boxes[i]["bbox"][0] < w / 2
+ and _boxes[i]["bbox"][2] > w / 2
+ ):
new_res += res_left
new_res += res_right
- _boxes[i]['layout'] = 'single'
+ _boxes[i]["layout"] = "single"
new_res.append(_boxes[i])
else:
- if _boxes[i]['bbox'][2] > w / 2:
- _boxes[i]['layout'] = 'double'
+ if _boxes[i]["bbox"][2] > w / 2:
+ _boxes[i]["layout"] = "double"
res_right.append(_boxes[i])
new_res += res_left
new_res += res_right
- elif _boxes[i]['bbox'][0] < w / 2:
- _boxes[i]['layout'] = 'double'
+ elif _boxes[i]["bbox"][0] < w / 2:
+ _boxes[i]["layout"] = "double"
res_left.append(_boxes[i])
new_res += res_left
new_res += res_right
res_left = []
res_right = []
break
- elif _boxes[i]['bbox'][0] < w / 4 and _boxes[i]['bbox'][2] < 3 * w / 4:
- _boxes[i]['layout'] = 'double'
+ elif _boxes[i]["bbox"][0] < w / 4 and _boxes[i]["bbox"][2] < 3 * w / 4:
+ _boxes[i]["layout"] = "double"
res_left.append(_boxes[i])
i += 1
- elif _boxes[i]['bbox'][0] > w / 4 and _boxes[i]['bbox'][2] > w / 2:
- _boxes[i]['layout'] = 'double'
+ elif _boxes[i]["bbox"][0] > w / 4 and _boxes[i]["bbox"][2] > w / 2:
+ _boxes[i]["layout"] = "double"
res_right.append(_boxes[i])
i += 1
else:
new_res += res_left
new_res += res_right
- _boxes[i]['layout'] = 'single'
+ _boxes[i]["layout"] = "single"
new_res.append(_boxes[i])
res_left = []
res_right = []
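`sorted_layout_boxes` routes each region into a left column, a right column, or a full-width run depending on where its bbox sits relative to w/4, w/2, and 3w/4. A toy illustration of that routing rule; the `[x1, y1, x2, y2]` boxes are invented, not real model output:

```python
# Mirror of the per-box branch in sorted_layout_boxes on a page of width w.
w = 1000
boxes = [[50, 0, 450, 30], [550, 0, 950, 30], [100, 40, 900, 70]]
for x1, y1, x2, y2 in boxes:
    if x1 < w / 4 and x2 < 3 * w / 4:
        print("left column:", [x1, y1, x2, y2])
    elif x1 > w / 4 and x2 > w / 2:
        print("right column:", [x1, y1, x2, y2])
    else:
        print("single (full width):", [x1, y1, x2, y2])
```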
diff --git a/ppstructure/recovery/table_process.py b/ppstructure/recovery/table_process.py
index 77a6ef7659..086461ca61 100644
--- a/ppstructure/recovery/table_process.py
+++ b/ppstructure/recovery/table_process.py
@@ -24,16 +24,18 @@
def get_table_rows(table_soup):
table_row_selectors = [
- 'table > tr', 'table > thead > tr', 'table > tbody > tr',
- 'table > tfoot > tr'
+ "table > tr",
+ "table > thead > tr",
+ "table > tbody > tr",
+ "table > tfoot > tr",
]
# If there's a header, body, footer or direct child tr tags, add row dimensions from there
- return table_soup.select(', '.join(table_row_selectors), recursive=False)
+ return table_soup.select(", ".join(table_row_selectors), recursive=False)
def get_table_columns(row):
# Get all columns for the specified row tag.
- return row.find_all(['th', 'td'], recursive=False) if row else []
+ return row.find_all(["th", "td"], recursive=False) if row else []
def get_table_dimensions(table_soup):
@@ -46,7 +48,7 @@ def get_table_dimensions(table_soup):
# Add colspan calculation column number
col_count = 0
for col in cols:
- colspan = col.attrs.get('colspan', 1)
+ colspan = col.attrs.get("colspan", 1)
col_count += int(colspan)
return rows, col_count
@@ -56,7 +58,7 @@ def get_cell_html(soup):
# Returns string of td element with opening and closing tags removed
# Cannot use find_all as it only finds element tags and does not find text which
# is not inside an element
- return ' '.join([str(i) for i in soup.contents])
+ return " ".join([str(i) for i in soup.contents])
def delete_paragraph(paragraph):
@@ -107,33 +109,33 @@ def remove_whitespace(string, leading=False, trailing=False):
"""
# Remove any leading new line characters along with any surrounding white space
if leading:
- string = re.sub(r'^\s*\n+\s*', '', string)
+ string = re.sub(r"^\s*\n+\s*", "", string)
# Remove any trailing new line characters along with any surrounding white space
if trailing:
- string = re.sub(r'\s*\n+\s*$', '', string)
+ string = re.sub(r"\s*\n+\s*$", "", string)
# Replace new line characters and absorb any surrounding space.
- string = re.sub(r'\s*\n\s*', ' ', string)
+ string = re.sub(r"\s*\n\s*", " ", string)
# TODO need some way to get rid of extra spaces in e.g. text text
- return re.sub(r'\s+', ' ', string)
+ return re.sub(r"\s+", " ", string)
font_styles = {
- 'b': 'bold',
- 'strong': 'bold',
- 'em': 'italic',
- 'i': 'italic',
- 'u': 'underline',
- 's': 'strike',
- 'sup': 'superscript',
- 'sub': 'subscript',
- 'th': 'bold',
+ "b": "bold",
+ "strong": "bold",
+ "em": "italic",
+ "i": "italic",
+ "u": "underline",
+ "s": "strike",
+ "sup": "superscript",
+ "sub": "subscript",
+ "th": "bold",
}
font_names = {
- 'code': 'Courier',
- 'pre': 'Courier',
+ "code": "Courier",
+ "pre": "Courier",
}
@@ -141,33 +143,34 @@ class HtmlToDocx(HTMLParser):
def __init__(self):
super().__init__()
self.options = {
- 'fix-html': True,
- 'images': True,
- 'tables': True,
- 'styles': True,
+ "fix-html": True,
+ "images": True,
+ "tables": True,
+ "styles": True,
}
self.table_row_selectors = [
- 'table > tr', 'table > thead > tr', 'table > tbody > tr',
- 'table > tfoot > tr'
+ "table > tr",
+ "table > thead > tr",
+ "table > tbody > tr",
+ "table > tfoot > tr",
]
self.table_style = None
self.paragraph_style = None
def set_initial_attrs(self, document=None):
self.tags = {
- 'span': [],
- 'list': [],
+ "span": [],
+ "list": [],
}
if document:
self.doc = document
else:
self.doc = Document()
- self.bs = self.options[
- 'fix-html'] # whether or not to clean with BeautifulSoup
+ self.bs = self.options["fix-html"] # whether or not to clean with BeautifulSoup
self.document = self.doc
- self.include_tables = True #TODO add this option back in?
- self.include_images = self.options['images']
- self.include_styles = self.options['styles']
+ self.include_tables = True # TODO add this option back in?
+ self.include_images = self.options["images"]
+ self.include_styles = self.options["styles"]
self.paragraph = None
self.skip = False
self.skip_tag = None
@@ -192,20 +195,20 @@ def ignore_nested_tables(self, tables_soup):
nest -= 1
continue
new_tables.append(table)
- nest = len(table.find_all('table'))
+ nest = len(table.find_all("table"))
return new_tables
def get_tables(self):
- if not hasattr(self, 'soup'):
+ if not hasattr(self, "soup"):
self.include_tables = False
return
# find other way to do it, or require this dependency?
- self.tables = self.ignore_nested_tables(self.soup.find_all('table'))
+ self.tables = self.ignore_nested_tables(self.soup.find_all("table"))
self.table_no = 0
def run_process(self, html):
if self.bs and BeautifulSoup:
- self.soup = BeautifulSoup(html, 'html.parser')
+ self.soup = BeautifulSoup(html, "html.parser")
html = str(self.soup)
if self.include_tables:
self.get_tables()
@@ -213,8 +216,7 @@ def run_process(self, html):
def add_html_to_cell(self, html, cell):
if not isinstance(cell, docx.table._Cell):
- raise ValueError('Second argument needs to be a %s' %
- docx.table._Cell)
+ raise ValueError("Second argument needs to be a %s" % docx.table._Cell)
unwanted_paragraph = cell.paragraphs[0]
if unwanted_paragraph.text == "":
delete_paragraph(unwanted_paragraph)
@@ -223,7 +225,7 @@ def add_html_to_cell(self, html, cell):
# cells must end with a paragraph or will get message about corrupt file
# https://stackoverflow.com/a/29287121
if not self.doc.paragraphs:
- self.doc.add_paragraph('')
+ self.doc.add_paragraph("")
def apply_paragraph_style(self, style=None):
try:
@@ -232,8 +234,7 @@ def apply_paragraph_style(self, style=None):
elif self.paragraph_style:
self.paragraph.style = self.paragraph_style
except KeyError as e:
- raise ValueError(
- f"Unable to apply style {self.paragraph_style}.") from e
+ raise ValueError(f"Unable to apply style {self.paragraph_style}.") from e
def handle_table(self, html, doc):
"""
@@ -243,37 +244,38 @@ def handle_table(self, html, doc):
Iterate over soup and fill docx table with new instances of this parser
Tell HTMLParser to ignore any tags until the corresponding closing table tag
"""
- table_soup = BeautifulSoup(html, 'html.parser')
+ table_soup = BeautifulSoup(html, "html.parser")
rows, cols_len = get_table_dimensions(table_soup)
table = doc.add_table(len(rows), cols_len)
- table.style = doc.styles['Table Grid']
+ table.style = doc.styles["Table Grid"]
cell_row = 0
for index, row in enumerate(rows):
cols = get_table_columns(row)
cell_col = 0
for col in cols:
- colspan = int(col.attrs.get('colspan', 1))
- rowspan = int(col.attrs.get('rowspan', 1))
+ colspan = int(col.attrs.get("colspan", 1))
+ rowspan = int(col.attrs.get("rowspan", 1))
cell_html = get_cell_html(col)
- if col.name == 'th':
+ if col.name == "th":
cell_html = "<b>%s</b>" % cell_html
docx_cell = table.cell(cell_row, cell_col)
- while docx_cell.text != '': # Skip the merged cell
+ while docx_cell.text != "": # Skip the merged cell
cell_col += 1
docx_cell = table.cell(cell_row, cell_col)
- cell_to_merge = table.cell(cell_row + rowspan - 1,
- cell_col + colspan - 1)
+ cell_to_merge = table.cell(
+ cell_row + rowspan - 1, cell_col + colspan - 1
+ )
if docx_cell != cell_to_merge:
docx_cell.merge(cell_to_merge)
child_parser = HtmlToDocx()
child_parser.copy_settings_from(self)
- child_parser.add_html_to_cell(cell_html or ' ', docx_cell)
+ child_parser.add_html_to_cell(cell_html or " ", docx_cell)
cell_col += colspan
cell_row += 1
@@ -283,7 +285,7 @@ def handle_data(self, data):
return
# Only remove white space if we're not in a pre block.
- if 'pre' not in self.tags:
+ if "pre" not in self.tags:
# remove leading and trailing whitespace in all instances
data = remove_whitespace(data, True, True)
@@ -294,16 +296,16 @@ def handle_data(self, data):
# There can only be one nested link in a valid html document
# You cannot have interactive content in an A tag, this includes links
# https://html.spec.whatwg.org/#interactive-content
- link = self.tags.get('a')
+ link = self.tags.get("a")
if link:
- self.handle_link(link['href'], data)
+ self.handle_link(link["href"], data)
else:
# If there's a link, dont put the data directly in the run
self.run = self.paragraph.add_run(data)
- spans = self.tags['span']
+ spans = self.tags["span"]
for span in spans:
- if 'style' in span:
- style = self.parse_dict_string(span['style'])
+ if "style" in span:
+ style = self.parse_dict_string(span["style"])
self.add_styles_to_run(style)
# add font style and name
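`HtmlToDocx` subclasses `HTMLParser` and, for tables, re-parses the markup with BeautifulSoup to size the docx grid before filling cells and merging spans. A hedged usage sketch, assuming `python-docx` and `beautifulsoup4` are installed; the import path matches the one used by recovery_to_doc.py, and the HTML sample is invented:

```python
from docx import Document
from ppstructure.recovery.table_process import HtmlToDocx

html = (
    "<table><tr><th>name</th><th>score</th></tr>"
    "<tr><td>foo</td><td>0.91</td></tr></table>"
)
doc = Document()
parser = HtmlToDocx()
parser.handle_table(html, doc)  # builds a docx table from the parsed cells
doc.save("demo_table.docx")
```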
diff --git a/ppstructure/table/convert_label2html.py b/ppstructure/table/convert_label2html.py
index be16212ac4..fb05dfe546 100644
--- a/ppstructure/table/convert_label2html.py
+++ b/ppstructure/table/convert_label2html.py
@@ -21,8 +21,8 @@
def save_pred_txt(key, val, tmp_file_path):
- with open(tmp_file_path, 'a+', encoding='utf-8') as f:
- f.write('{}\t{}\n'.format(key, val))
+ with open(tmp_file_path, "a+", encoding="utf-8") as f:
+ f.write("{}\t{}\n".format(key, val))
def skip_char(text, sp_char_list):
@@ -33,27 +33,27 @@ def skip_char(text, sp_char_list):
@return:
"""
for sp_char in sp_char_list:
- text = text.replace(sp_char, '')
+ text = text.replace(sp_char, "")
return text
def gen_html(img):
- '''
+ """
Formats HTML code from tokenized annotation of img
- '''
- html_code = img['html']['structure']['tokens'].copy()
- to_insert = [i for i, tag in enumerate(html_code) if tag in ('<td>', '>')]
- for i, cell in zip(to_insert[::-1], img['html']['cells'][::-1]):
- if cell['tokens']:
- text = ''.join(cell['tokens'])
+ """
+ html_code = img["html"]["structure"]["tokens"].copy()
+ to_insert = [i for i, tag in enumerate(html_code) if tag in ("<td>", ">")]
+ for i, cell in zip(to_insert[::-1], img["html"]["cells"][::-1]):
+ if cell["tokens"]:
+ text = "".join(cell["tokens"])
# skip empty text
- sp_char_list = ['<b>', '</b>', '\u2028', ' ', '<i>', '</i>']
+ sp_char_list = ["<b>", "</b>", "\u2028", " ", "<i>", "</i>"]
text_remove_style = skip_char(text, sp_char_list)
if len(text_remove_style) == 0:
continue
html_code.insert(i + 1, text)
- html_code = ''.join(html_code)
- html_code = '<html><body><table>{}</table></body></html>'.format(html_code)
+ html_code = "".join(html_code)
+ html_code = "".format(html_code)
return html_code
@@ -64,12 +64,12 @@ def load_gt_data(gt_path):
@return:
"""
data_list = {}
- with open(gt_path, 'rb') as f:
+ with open(gt_path, "rb") as f:
lines = f.readlines()
for line in tqdm(lines):
- data_line = line.decode('utf-8').strip("\n")
+ data_line = line.decode("utf-8").strip("\n")
info = json.loads(data_line)
- data_list[info['filename']] = info
+ data_list[info["filename"]] = info
return data_list
@@ -84,19 +84,19 @@ def convert(origin_gt_path, save_path):
for img_name, gt in tqdm(data_dict.items()):
html = gen_html(gt)
save_pred_txt(img_name, html, save_path)
- print('conver finish')
+ print("conver finish")
def parse_args():
parser = argparse.ArgumentParser(description="args for paddleserving")
+ parser.add_argument("--ori_gt_path", type=str, required=True, help="label gt path")
parser.add_argument(
- "--ori_gt_path", type=str, required=True, help="label gt path")
- parser.add_argument(
- "--save_path", type=str, required=True, help="path to save file")
+ "--save_path", type=str, required=True, help="path to save file"
+ )
args = parser.parse_args()
return args
-if __name__ == '__main__':
+if __name__ == "__main__":
args = parse_args()
convert(args.ori_gt_path, args.save_path)
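`gen_html` splices cell text back into the structure-token stream: it collects the insertion point after each `'<td>'` token and walks those points in reverse so earlier indices stay valid. A self-contained rerun of that step on a made-up, PubTabNet-style annotation dict:

```python
# Minimal sketch of the token-splice in gen_html; the annotation is invented.
img = {
    "html": {
        "structure": {
            "tokens": ["<table>", "<tr>", "<td>", "</td>",
                       "<td>", "</td>", "</tr>", "</table>"]
        },
        "cells": [{"tokens": ["a"]}, {"tokens": ["1"]}],
    }
}
html_code = img["html"]["structure"]["tokens"].copy()
to_insert = [i for i, tag in enumerate(html_code) if tag in ("<td>", ">")]
for i, cell in zip(to_insert[::-1], img["html"]["cells"][::-1]):
    if cell["tokens"]:
        html_code.insert(i + 1, "".join(cell["tokens"]))
print("".join(html_code))  # <table><tr><td>a</td><td>1</td></tr></table>
```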
diff --git a/ppstructure/table/eval_table.py b/ppstructure/table/eval_table.py
index 4fc16b5d4c..b9e4661a0e 100755
--- a/ppstructure/table/eval_table.py
+++ b/ppstructure/table/eval_table.py
@@ -17,7 +17,7 @@
__dir__ = os.path.dirname(os.path.abspath(__file__))
sys.path.append(__dir__)
-sys.path.insert(0, os.path.abspath(os.path.join(__dir__, '../..')))
+sys.path.insert(0, os.path.abspath(os.path.join(__dir__, "../..")))
import cv2
import pickle
@@ -41,10 +41,10 @@ def load_txt(txt_path):
pred_html_dict = {}
if not os.path.exists(txt_path):
return pred_html_dict
- with open(txt_path, encoding='utf-8') as f:
+ with open(txt_path, encoding="utf-8") as f:
lines = f.readlines()
for line in lines:
- line = line.strip().split('\t')
+ line = line.strip().split("\t")
img_name, pred_html = line
pred_html_dict[img_name] = pred_html
return pred_html_dict
@@ -53,14 +53,14 @@ def load_txt(txt_path):
def load_result(path):
data = {}
if os.path.exists(path):
- data = pickle.load(open(path, 'rb'))
+ data = pickle.load(open(path, "rb"))
return data
def save_result(path, data):
old_data = load_result(path)
old_data.update(data)
- with open(path, 'wb') as f:
+ with open(path, "wb") as f:
pickle.dump(old_data, f)
@@ -71,9 +71,8 @@ def main(gt_path, img_root, args):
# load gt and preds html result
gt_html_dict = load_txt(gt_path)
- ocr_result = load_result(os.path.join(args.output, 'ocr.pickle'))
- structure_result = load_result(
- os.path.join(args.output, 'structure.pickle'))
+ ocr_result = load_result(os.path.join(args.output, "ocr.pickle"))
+ structure_result = load_result(os.path.join(args.output, "structure.pickle"))
pred_htmls = []
gt_htmls = []
@@ -83,13 +82,12 @@ def main(gt_path, img_root, args):
if img_name not in ocr_result:
dt_boxes, rec_res, _, _ = text_sys._ocr(img)
ocr_result[img_name] = [dt_boxes, rec_res]
- save_result(os.path.join(args.output, 'ocr.pickle'), ocr_result)
+ save_result(os.path.join(args.output, "ocr.pickle"), ocr_result)
# run structure and save result
if img_name not in structure_result:
structure_res, _ = text_sys._structure(img)
structure_result[img_name] = structure_res
- save_result(
- os.path.join(args.output, 'structure.pickle'), structure_result)
+ save_result(os.path.join(args.output, "structure.pickle"), structure_result)
dt_boxes, rec_res = ocr_result[img_name]
structure_res = structure_result[img_name]
# match ocr and structure
@@ -101,9 +99,9 @@ def main(gt_path, img_root, args):
# compute teds
teds = TEDS(n_jobs=16)
scores = teds.batch_evaluate_html(gt_htmls, pred_htmls)
- logger.info('teds: {}'.format(sum(scores) / len(scores)))
+ logger.info("teds: {}".format(sum(scores) / len(scores)))
-if __name__ == '__main__':
+if __name__ == "__main__":
args = parse_args()
main(args.gt_path, args.image_dir, args)
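eval_table.py caches OCR and structure outputs in pickles so an interrupted evaluation can resume where it stopped; `save_result` merges new entries into whatever is already on disk. A condensed, self-contained version of that pattern (the file name is a placeholder):

```python
import os
import pickle

def load_result(path):
    if os.path.exists(path):
        with open(path, "rb") as f:
            return pickle.load(f)
    return {}

def save_result(path, data):
    merged = load_result(path)
    merged.update(data)  # new entries win; earlier ones survive a restart
    with open(path, "wb") as f:
        pickle.dump(merged, f)

save_result("cache.pickle", {"img_1.png": "html_a"})
save_result("cache.pickle", {"img_2.png": "html_b"})
print(sorted(load_result("cache.pickle")))  # ['img_1.png', 'img_2.png']
```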
diff --git a/ppstructure/table/matcher.py b/ppstructure/table/matcher.py
index 9c5bd2630f..51e6250f47 100755
--- a/ppstructure/table/matcher.py
+++ b/ppstructure/table/matcher.py
@@ -62,15 +62,16 @@ def __init__(self, filter_ocr_result=False, use_master=False):
def __call__(self, structure_res, dt_boxes, rec_res):
pred_structures, pred_bboxes = structure_res
if self.filter_ocr_result:
- dt_boxes, rec_res = self._filter_ocr_result(pred_bboxes, dt_boxes,
- rec_res)
+ dt_boxes, rec_res = self._filter_ocr_result(pred_bboxes, dt_boxes, rec_res)
matched_index = self.match_result(dt_boxes, pred_bboxes)
if self.use_master:
- pred_html, pred = self.get_pred_html_master(pred_structures,
- matched_index, rec_res)
+ pred_html, pred = self.get_pred_html_master(
+ pred_structures, matched_index, rec_res
+ )
else:
- pred_html, pred = self.get_pred_html(pred_structures, matched_index,
- rec_res)
+ pred_html, pred = self.get_pred_html(
+ pred_structures, matched_index, rec_res
+ )
return pred_html
def match_result(self, dt_boxes, pred_bboxes):
@@ -80,16 +81,19 @@ def match_result(self, dt_boxes, pred_bboxes):
for j, pred_box in enumerate(pred_bboxes):
if len(pred_box) == 8:
pred_box = [
- np.min(pred_box[0::2]), np.min(pred_box[1::2]),
- np.max(pred_box[0::2]), np.max(pred_box[1::2])
+ np.min(pred_box[0::2]),
+ np.min(pred_box[1::2]),
+ np.max(pred_box[0::2]),
+ np.max(pred_box[1::2]),
]
- distances.append((distance(gt_box, pred_box),
- 1. - compute_iou(gt_box, pred_box)
- )) # compute iou and l1 distance
+ distances.append(
+ (distance(gt_box, pred_box), 1.0 - compute_iou(gt_box, pred_box))
+ ) # compute iou and l1 distance
sorted_distances = distances.copy()
# select det box by iou and l1 distance
sorted_distances = sorted(
- sorted_distances, key=lambda item: (item[1], item[0]))
+ sorted_distances, key=lambda item: (item[1], item[0])
+ )
if distances.index(sorted_distances[0]) not in matched.keys():
matched[distances.index(sorted_distances[0])] = [i]
else:
@@ -100,82 +104,89 @@ def get_pred_html(self, pred_structures, matched_index, ocr_contents):
end_html = []
td_index = 0
for tag in pred_structures:
- if '</td>' in tag:
- if '<td></td>' == tag:
- end_html.extend('<td>')
+ if "</td>" in tag:
+ if "<td></td>" == tag:
+ end_html.extend("<td>")
if td_index in matched_index.keys():
b_with = False
- if '<b>' in ocr_contents[matched_index[td_index][
- 0]] and len(matched_index[td_index]) > 1:
+ if (
+ "" in ocr_contents[matched_index[td_index][0]]
+ and len(matched_index[td_index]) > 1
+ ):
b_with = True
- end_html.extend('<b>')
+ end_html.extend("<b>")
for i, td_index_index in enumerate(matched_index[td_index]):
content = ocr_contents[td_index_index][0]
if len(matched_index[td_index]) > 1:
if len(content) == 0:
continue
- if content[0] == ' ':
+ if content[0] == " ":
content = content[1:]
- if '<b>' in content:
+ if "<b>" in content:
content = content[3:]
- if '</b>' in content:
+ if "</b>" in content:
content = content[:-4]
if len(content) == 0:
continue
- if i != len(matched_index[
- td_index]) - 1 and ' ' != content[-1]:
- content += ' '
+ if (
+ i != len(matched_index[td_index]) - 1
+ and " " != content[-1]
+ ):
+ content += " "
end_html.extend(content)
if b_with:
- end_html.extend('</b>')
- if '<td></td>' == tag:
- end_html.append('</td>')
+ end_html.extend("</b>")
+ if "<td></td>" == tag:
+ end_html.append("</td>")
else:
end_html.append(tag)
td_index += 1
else:
end_html.append(tag)
- return ''.join(end_html), end_html
+ return "".join(end_html), end_html
- def get_pred_html_master(self, pred_structures, matched_index,
- ocr_contents):
+ def get_pred_html_master(self, pred_structures, matched_index, ocr_contents):
end_html = []
td_index = 0
for token in pred_structures:
- if '</td>' in token:
- txt = ''
+ if "</td>" in token:
+ txt = ""
b_with = False
if td_index in matched_index.keys():
- if '<b>' in ocr_contents[matched_index[td_index][
- 0]] and len(matched_index[td_index]) > 1:
+ if (
+ "" in ocr_contents[matched_index[td_index][0]]
+ and len(matched_index[td_index]) > 1
+ ):
b_with = True
for i, td_index_index in enumerate(matched_index[td_index]):
content = ocr_contents[td_index_index][0]
if len(matched_index[td_index]) > 1:
if len(content) == 0:
continue
- if content[0] == ' ':
+ if content[0] == " ":
content = content[1:]
- if '<b>' in content:
+ if "<b>" in content:
content = content[3:]
- if '</b>' in content:
+ if "</b>" in content:
content = content[:-4]
if len(content) == 0:
continue
- if i != len(matched_index[
- td_index]) - 1 and ' ' != content[-1]:
- content += ' '
+ if (
+ i != len(matched_index[td_index]) - 1
+ and " " != content[-1]
+ ):
+ content += " "
txt += content
if b_with:
- txt = '<b>{}</b>'.format(txt)
- if '<td></td>' == token:
- token = '<td>{}</td>'.format(txt)
+ txt = "<b>{}</b>".format(txt)
+ if "<td></td>" == token:
+ token = "<td>{}</td>".format(txt)
else:
- token = token.replace('><', '>{}<'.format(txt))
+ token = token.replace("><", ">{}<".format(txt))
td_index += 1
token = deal_eb_token(token)
end_html.append(token)
- html = ''.join(end_html)
+ html = "".join(end_html)
html = deal_bb(html)
return html, end_html
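`match_result` above pairs each detected text box with the structure cell of highest IoU, using L1 distance only as a tie-breaker; sorting candidates by the tuple `(1 - IoU, distance)` is what achieves that priority. A toy version of the selection rule with invented axis-aligned `[x1, y1, x2, y2]` boxes:

```python
# Rank candidate cells by (1 - IoU, L1 distance): more overlap wins first.
def iou(a, b):
    ix1, iy1 = max(a[0], b[0]), max(a[1], b[1])
    ix2, iy2 = min(a[2], b[2]), min(a[3], b[3])
    inter = max(0, ix2 - ix1) * max(0, iy2 - iy1)
    area = lambda r: (r[2] - r[0]) * (r[3] - r[1])
    return inter / float(area(a) + area(b) - inter)

def l1(a, b):
    return sum(abs(p - q) for p, q in zip(a, b))

ocr_box = [12, 10, 48, 30]
cells = [[0, 0, 50, 35], [60, 0, 110, 35]]
scores = [(1.0 - iou(ocr_box, c), l1(ocr_box, c)) for c in cells]
print("matched cell:", min(range(len(cells)), key=lambda j: scores[j]))  # 0
```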
diff --git a/ppstructure/table/predict_structure.py b/ppstructure/table/predict_structure.py
index 08e381a846..93a930b27b 100755
--- a/ppstructure/table/predict_structure.py
+++ b/ppstructure/table/predict_structure.py
@@ -16,9 +16,9 @@
__dir__ = os.path.dirname(os.path.abspath(__file__))
sys.path.append(__dir__)
-sys.path.insert(0, os.path.abspath(os.path.join(__dir__, '../..')))
+sys.path.insert(0, os.path.abspath(os.path.join(__dir__, "../..")))
-os.environ["FLAGS_allocator_strategy"] = 'auto_growth'
+os.environ["FLAGS_allocator_strategy"] = "auto_growth"
import cv2
import numpy as np
@@ -37,32 +37,30 @@
def build_pre_process_list(args):
- resize_op = {'ResizeTableImage': {'max_len': args.table_max_len, }}
- pad_op = {
- 'PaddingTableImage': {
- 'size': [args.table_max_len, args.table_max_len]
+ resize_op = {
+ "ResizeTableImage": {
+ "max_len": args.table_max_len,
}
}
+ pad_op = {"PaddingTableImage": {"size": [args.table_max_len, args.table_max_len]}}
normalize_op = {
- 'NormalizeImage': {
- 'std': [0.229, 0.224, 0.225] if
- args.table_algorithm not in ['TableMaster'] else [0.5, 0.5, 0.5],
- 'mean': [0.485, 0.456, 0.406] if
- args.table_algorithm not in ['TableMaster'] else [0.5, 0.5, 0.5],
- 'scale': '1./255.',
- 'order': 'hwc'
+ "NormalizeImage": {
+ "std": [0.229, 0.224, 0.225]
+ if args.table_algorithm not in ["TableMaster"]
+ else [0.5, 0.5, 0.5],
+ "mean": [0.485, 0.456, 0.406]
+ if args.table_algorithm not in ["TableMaster"]
+ else [0.5, 0.5, 0.5],
+ "scale": "1./255.",
+ "order": "hwc",
}
}
- to_chw_op = {'ToCHWImage': None}
- keep_keys_op = {'KeepKeys': {'keep_keys': ['image', 'shape']}}
- if args.table_algorithm not in ['TableMaster']:
- pre_process_list = [
- resize_op, normalize_op, pad_op, to_chw_op, keep_keys_op
- ]
+ to_chw_op = {"ToCHWImage": None}
+ keep_keys_op = {"KeepKeys": {"keep_keys": ["image", "shape"]}}
+ if args.table_algorithm not in ["TableMaster"]:
+ pre_process_list = [resize_op, normalize_op, pad_op, to_chw_op, keep_keys_op]
else:
- pre_process_list = [
- resize_op, pad_op, normalize_op, to_chw_op, keep_keys_op
- ]
+ pre_process_list = [resize_op, pad_op, normalize_op, to_chw_op, keep_keys_op]
return pre_process_list
@@ -71,27 +69,32 @@ def __init__(self, args):
self.args = args
self.use_onnx = args.use_onnx
pre_process_list = build_pre_process_list(args)
- if args.table_algorithm not in ['TableMaster']:
+ if args.table_algorithm not in ["TableMaster"]:
postprocess_params = {
- 'name': 'TableLabelDecode',
+ "name": "TableLabelDecode",
"character_dict_path": args.table_char_dict_path,
- 'merge_no_span_structure': args.merge_no_span_structure
+ "merge_no_span_structure": args.merge_no_span_structure,
}
else:
postprocess_params = {
- 'name': 'TableMasterLabelDecode',
+ "name": "TableMasterLabelDecode",
"character_dict_path": args.table_char_dict_path,
- 'box_shape': 'pad',
- 'merge_no_span_structure': args.merge_no_span_structure
+ "box_shape": "pad",
+ "merge_no_span_structure": args.merge_no_span_structure,
}
self.preprocess_op = create_operators(pre_process_list)
self.postprocess_op = build_post_process(postprocess_params)
- self.predictor, self.input_tensor, self.output_tensors, self.config = \
- utility.create_predictor(args, 'table', logger)
+ (
+ self.predictor,
+ self.input_tensor,
+ self.output_tensors,
+ self.config,
+ ) = utility.create_predictor(args, "table", logger)
if args.benchmark:
import auto_log
+
pid = os.getpid()
gpu_id = utility.get_infer_gpuid()
self.autolog = auto_log.AutoLogger(
@@ -99,16 +102,15 @@ def __init__(self, args):
model_precision=args.precision,
batch_size=1,
data_shape="dynamic",
- save_path=None, #args.save_log_path,
+ save_path=None, # args.save_log_path,
inference_config=self.config,
pids=pid,
process_name=None,
gpu_ids=gpu_id if args.use_gpu else None,
- time_keys=[
- 'preprocess_time', 'inference_time', 'postprocess_time'
- ],
+ time_keys=["preprocess_time", "inference_time", "postprocess_time"],
warmup=0,
- logger=logger)
+ logger=logger,
+ )
def __call__(self, img):
starttime = time.time()
@@ -116,7 +118,7 @@ def __call__(self, img):
self.autolog.times.start()
ori_im = img.copy()
- data = {'image': img}
+ data = {"image": img}
data = transform(data, self.preprocess_op)
img = data[0]
if img is None:
@@ -140,18 +142,20 @@ def __call__(self, img):
self.autolog.times.stamp()
preds = {}
- preds['structure_probs'] = outputs[1]
- preds['loc_preds'] = outputs[0]
+ preds["structure_probs"] = outputs[1]
+ preds["loc_preds"] = outputs[0]
shape_list = np.expand_dims(data[-1], axis=0)
post_result = self.postprocess_op(preds, [shape_list])
- structure_str_list = post_result['structure_batch_list'][0]
- bbox_list = post_result['bbox_batch_list'][0]
+ structure_str_list = post_result["structure_batch_list"][0]
+ bbox_list = post_result["bbox_batch_list"][0]
structure_str_list = structure_str_list[0]
- structure_str_list = [
- '<html>', '<body>', '<table>'
- ] + structure_str_list + ['</table>', '</body>', '</html>']
+ structure_str_list = (
+ ["", "", ""]
+ + structure_str_list
+ + ["
", "", ""]
+ )
elapse = time.time() - starttime
if self.args.benchmark:
self.autolog.times.end(stamp=True)
@@ -165,8 +169,8 @@ def main(args):
total_time = 0
os.makedirs(args.output, exist_ok=True)
with open(
- os.path.join(args.output, 'infer.txt'), mode='w',
- encoding='utf-8') as f_w:
+ os.path.join(args.output, "infer.txt"), mode="w", encoding="utf-8"
+ ) as f_w:
for image_file in image_file_list:
img, flag, _ = check_and_read(image_file)
if not flag:
@@ -177,17 +181,14 @@ def main(args):
structure_res, elapse = table_structurer(img)
structure_str_list, bbox_list = structure_res
bbox_list_str = json.dumps(bbox_list.tolist())
- logger.info("result: {}, {}".format(structure_str_list,
- bbox_list_str))
- f_w.write("result: {}, {}\n".format(structure_str_list,
- bbox_list_str))
+ logger.info("result: {}, {}".format(structure_str_list, bbox_list_str))
+ f_w.write("result: {}, {}\n".format(structure_str_list, bbox_list_str))
if len(bbox_list) > 0 and len(bbox_list[0]) == 4:
img = draw_rectangle(image_file, bbox_list)
else:
img = utility.draw_boxes(img, bbox_list)
- img_save_path = os.path.join(args.output,
- os.path.basename(image_file))
+ img_save_path = os.path.join(args.output, os.path.basename(image_file))
cv2.imwrite(img_save_path, img)
logger.info("save vis result to {}".format(img_save_path))
if count > 0:
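`build_pre_process_list` expresses the preprocessing pipeline as a list of one-key dicts, `{OpName: kwargs}`, which `create_operators` turns into callable operator instances applied in order. A minimal sketch of that convention with a stand-in registry; the real factory lives in ppocr's transform module, and these two toy ops only approximate the real behavior (the "resize" crops instead of interpolating):

```python
import numpy as np

class ResizeTableImage:
    def __init__(self, max_len):
        self.max_len = max_len

    def __call__(self, data):
        h, w = data["image"].shape[:2]
        ratio = self.max_len / max(h, w)
        data["image"] = data["image"][: int(h * ratio), : int(w * ratio)]
        return data

class ToCHWImage:
    def __call__(self, data):
        data["image"] = data["image"].transpose((2, 0, 1))
        return data

REGISTRY = {"ResizeTableImage": ResizeTableImage, "ToCHWImage": ToCHWImage}

def create_operators(op_list):
    ops = []
    for op in op_list:
        name, kwargs = next(iter(op.items()))
        ops.append(REGISTRY[name](**(kwargs or {})))
    return ops

pipeline = create_operators(
    [{"ResizeTableImage": {"max_len": 488}}, {"ToCHWImage": None}]
)
data = {"image": np.zeros((600, 800, 3), dtype=np.uint8)}
for op in pipeline:
    data = op(data)
print(data["image"].shape)  # (3, 366, 488)
```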
diff --git a/ppstructure/table/predict_table.py b/ppstructure/table/predict_table.py
index 76bd42dc00..e1880cd8ad 100644
--- a/ppstructure/table/predict_table.py
+++ b/ppstructure/table/predict_table.py
@@ -17,10 +17,10 @@
__dir__ = os.path.dirname(os.path.abspath(__file__))
sys.path.append(__dir__)
-sys.path.insert(0, os.path.abspath(os.path.join(__dir__, '..')))
-sys.path.insert(0, os.path.abspath(os.path.join(__dir__, '../..')))
+sys.path.insert(0, os.path.abspath(os.path.join(__dir__, "..")))
+sys.path.insert(0, os.path.abspath(os.path.join(__dir__, "../..")))
-os.environ["FLAGS_allocator_strategy"] = 'auto_growth'
+os.environ["FLAGS_allocator_strategy"] = "auto_growth"
import cv2
import copy
import logging
@@ -64,45 +64,54 @@ def __init__(self, args, text_detector=None, text_recognizer=None):
if args.benchmark:
benchmark_tmp = args.benchmark
args.benchmark = False
- self.text_detector = predict_det.TextDetector(copy.deepcopy(
- args)) if text_detector is None else text_detector
- self.text_recognizer = predict_rec.TextRecognizer(copy.deepcopy(
- args)) if text_recognizer is None else text_recognizer
+ self.text_detector = (
+ predict_det.TextDetector(copy.deepcopy(args))
+ if text_detector is None
+ else text_detector
+ )
+ self.text_recognizer = (
+ predict_rec.TextRecognizer(copy.deepcopy(args))
+ if text_recognizer is None
+ else text_recognizer
+ )
if benchmark_tmp:
args.benchmark = True
self.table_structurer = predict_strture.TableStructurer(args)
- if args.table_algorithm in ['TableMaster']:
+ if args.table_algorithm in ["TableMaster"]:
self.match = TableMasterMatcher()
else:
self.match = TableMatch(filter_ocr_result=True)
- self.predictor, self.input_tensor, self.output_tensors, self.config = utility.create_predictor(
- args, 'table', logger)
+ (
+ self.predictor,
+ self.input_tensor,
+ self.output_tensors,
+ self.config,
+ ) = utility.create_predictor(args, "table", logger)
def __call__(self, img, return_ocr_result_in_table=False):
result = dict()
- time_dict = {'det': 0, 'rec': 0, 'table': 0, 'all': 0, 'match': 0}
+ time_dict = {"det": 0, "rec": 0, "table": 0, "all": 0, "match": 0}
start = time.time()
structure_res, elapse = self._structure(copy.deepcopy(img))
- result['cell_bbox'] = structure_res[1].tolist()
- time_dict['table'] = elapse
+ result["cell_bbox"] = structure_res[1].tolist()
+ time_dict["table"] = elapse
- dt_boxes, rec_res, det_elapse, rec_elapse = self._ocr(
- copy.deepcopy(img))
- time_dict['det'] = det_elapse
- time_dict['rec'] = rec_elapse
+ dt_boxes, rec_res, det_elapse, rec_elapse = self._ocr(copy.deepcopy(img))
+ time_dict["det"] = det_elapse
+ time_dict["rec"] = rec_elapse
if return_ocr_result_in_table:
- result['boxes'] = [x.tolist() for x in dt_boxes]
- result['rec_res'] = rec_res
+ result["boxes"] = [x.tolist() for x in dt_boxes]
+ result["rec_res"] = rec_res
tic = time.time()
pred_html = self.match(structure_res, dt_boxes, rec_res)
toc = time.time()
- time_dict['match'] = toc - tic
- result['html'] = pred_html
+ time_dict["match"] = toc - tic
+ result["html"] = pred_html
end = time.time()
- time_dict['all'] = end - start
+ time_dict["all"] = end - start
return result, time_dict
def _structure(self, img):
@@ -123,8 +132,7 @@ def _ocr(self, img):
box = [x_min, y_min, x_max, y_max]
r_boxes.append(box)
dt_boxes = np.array(r_boxes)
- logger.debug("dt_boxes num : {}, elapse : {}".format(
- len(dt_boxes), det_elapse))
+ logger.debug("dt_boxes num : {}, elapse : {}".format(len(dt_boxes), det_elapse))
if dt_boxes is None:
return None, None
@@ -132,46 +140,46 @@ def _ocr(self, img):
for i in range(len(dt_boxes)):
det_box = dt_boxes[i]
x0, y0, x1, y1 = expand(2, det_box, img.shape)
- text_rect = img[int(y0):int(y1), int(x0):int(x1), :]
+ text_rect = img[int(y0) : int(y1), int(x0) : int(x1), :]
img_crop_list.append(text_rect)
rec_res, rec_elapse = self.text_recognizer(img_crop_list)
- logger.debug("rec_res num : {}, elapse : {}".format(
- len(rec_res), rec_elapse))
+ logger.debug("rec_res num : {}, elapse : {}".format(len(rec_res), rec_elapse))
return dt_boxes, rec_res, det_elapse, rec_elapse
def to_excel(html_table, excel_path):
from tablepyxl import tablepyxl
+
tablepyxl.document_to_xl(html_table, excel_path)
def main(args):
image_file_list = get_image_file_list(args.image_dir)
- image_file_list = image_file_list[args.process_id::args.total_process_num]
+ image_file_list = image_file_list[args.process_id :: args.total_process_num]
os.makedirs(args.output, exist_ok=True)
table_sys = TableSystem(args)
img_num = len(image_file_list)
- f_html = open(
- os.path.join(args.output, 'show.html'), mode='w', encoding='utf-8')
- f_html.write('<html>\n<body>\n')
+ f_html = open(os.path.join(args.output, "show.html"), mode="w", encoding="utf-8")
+ f_html.write("\n\n")
f_html.write('<table border="1">\n')
f_html.write(
- ""
+ ''
)
f_html.write("\n")
- f_html.write('<td>img name\n')
- f_html.write('<td>ori image</td>')
- f_html.write('<td>table html</td>')
- f_html.write('<td>cell box</td>')
+ f_html.write("<td>img name\n")
+ f_html.write("<td>ori image</td>")
+ f_html.write("<td>table html</td>")
+ f_html.write("<td>cell box</td>")
f_html.write("
\n")
for i, image_file in enumerate(image_file_list):
logger.info("[{}/{}] {}".format(i, img_num, image_file))
img, flag, _ = check_and_read(image_file)
excel_path = os.path.join(
- args.output, os.path.basename(image_file).split('.')[0] + '.xlsx')
+ args.output, os.path.basename(image_file).split(".")[0] + ".xlsx"
+ )
if not flag:
img = cv2.imread(image_file)
if img is None:
@@ -179,30 +187,31 @@ def main(args):
continue
starttime = time.time()
pred_res, _ = table_sys(img)
- pred_html = pred_res['html']
+ pred_html = pred_res["html"]
logger.info(pred_html)
to_excel(pred_html, excel_path)
- logger.info('excel saved to {}'.format(excel_path))
+ logger.info("excel saved to {}".format(excel_path))
elapse = time.time() - starttime
logger.info("Predict time : {:.3f}s".format(elapse))
- if len(pred_res['cell_bbox']) > 0 and len(pred_res['cell_bbox'][
- 0]) == 4:
- img = predict_strture.draw_rectangle(image_file,
- pred_res['cell_bbox'])
+ if len(pred_res["cell_bbox"]) > 0 and len(pred_res["cell_bbox"][0]) == 4:
+ img = predict_strture.draw_rectangle(image_file, pred_res["cell_bbox"])
else:
- img = utility.draw_boxes(img, pred_res['cell_bbox'])
+ img = utility.draw_boxes(img, pred_res["cell_bbox"])
img_save_path = os.path.join(args.output, os.path.basename(image_file))
cv2.imwrite(img_save_path, img)
f_html.write("\n")
- f_html.write(f'<td> {os.path.basename(image_file)} <br/>\n')
+ f_html.write(f"<td> {os.path.basename(image_file)} <br/>\n")
f_html.write(f'<td><img src="{img_save_path}" width=640></td>\n')
- f_html.write('<td><table  border="1">' + pred_html.replace(
- '<html><body><table>', '').replace('</table></body></html>', '') +
- '</table></td>\n')
f_html.write(
- f'<td><img src="{excel_path.replace(".xlsx", ".jpg")}" width=640></td>\n')
+ '<td><table  border="1">'
+ + pred_html.replace("<html><body><table>", "").replace(
+ "</table></body></html>", ""
+ )
+ + "</table></td>\n"
+ )
+ f_html.write(f'<td><img src="{excel_path.replace(".xlsx", ".jpg")}" width=640></td>\n')
f_html.write("
\n")
f_html.write("
\n")
f_html.close()
@@ -215,13 +224,15 @@ def main(args):
args = parse_args()
if args.use_mp:
import subprocess
+
p_list = []
total_process_num = args.total_process_num
for process_id in range(total_process_num):
- cmd = [sys.executable, "-u"] + sys.argv + [
- "--process_id={}".format(process_id),
- "--use_mp={}".format(False)
- ]
+ cmd = (
+ [sys.executable, "-u"]
+ + sys.argv
+ + ["--process_id={}".format(process_id), "--use_mp={}".format(False)]
+ )
p = subprocess.Popen(cmd, stdout=sys.stdout, stderr=sys.stdout)
p_list.append(p)
for p in p_list:
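`_ocr` above collapses each detected quadrilateral to an axis-aligned rectangle (min/max over the four corners) before cropping the text patch for recognition. The same reduction in isolation, with an invented quad:

```python
# Reduce a 4-point detection quad to [x_min, y_min, x_max, y_max],
# mirroring the r_boxes loop in TableSystem._ocr. Coordinates are made up.
import numpy as np

quad = np.array([[12, 8], [96, 10], [95, 40], [11, 38]], dtype=float)
x_min, y_min = quad[:, 0].min(), quad[:, 1].min()
x_max, y_max = quad[:, 0].max(), quad[:, 1].max()
print([x_min, y_min, x_max, y_max])  # [11.0, 8.0, 96.0, 40.0]
```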
diff --git a/ppstructure/table/table_master_match.py b/ppstructure/table/table_master_match.py
index 7a7208d4a9..54b4506510 100644
--- a/ppstructure/table/table_master_match.py
+++ b/ppstructure/table/table_master_match.py
@@ -26,6 +26,7 @@
import numpy as np
from shapely.geometry import Polygon, MultiPoint
+
"""
Useful function in matching.
"""
@@ -40,7 +41,7 @@ def remove_empty_bboxes(bboxes):
"""
new_bboxes = []
for bbox in bboxes:
- if sum(bbox) == 0.:
+ if sum(bbox) == 0.0:
continue
new_bboxes.append(bbox)
return np.array(new_bboxes)
@@ -84,15 +85,15 @@ def xyxy2xywh(bboxes):
raise ValueError
-def pickle_load(path, prefix='end2end'):
+def pickle_load(path, prefix="end2end"):
if os.path.isfile(path):
- data = pickle.load(open(path, 'rb'))
+ data = pickle.load(open(path, "rb"))
elif os.path.isdir(path):
data = dict()
- search_path = os.path.join(path, '{}_*.pkl'.format(prefix))
+ search_path = os.path.join(path, "{}_*.pkl".format(prefix))
pkls = glob.glob(search_path)
for pkl in pkls:
- this_data = pickle.load(open(pkl, 'rb'))
+ this_data = pickle.load(open(pkl, "rb"))
data.update(this_data)
else:
raise ValueError
@@ -147,10 +148,12 @@ def is_inside(center_point, corner_point):
x_flag = False
y_flag = False
if (center_point[0] >= corner_point[0][0]) and (
- center_point[0] <= corner_point[1][0]):
+ center_point[0] <= corner_point[1][0]
+ ):
x_flag = True
if (center_point[1] >= corner_point[0][1]) and (
- center_point[1] <= corner_point[1][1]):
+ center_point[1] <= corner_point[1][1]
+ ):
y_flag = True
if x_flag and y_flag:
return True
@@ -158,7 +161,7 @@ def is_inside(center_point, corner_point):
return False
-def find_no_match(match_list, all_end2end_nums, type='end2end'):
+def find_no_match(match_list, all_end2end_nums, type="end2end"):
"""
Find out no match end2end bbox in previous match list.
:param match_list: matching pairs.
@@ -166,9 +169,9 @@ def find_no_match(match_list, all_end2end_nums, type='end2end'):
:param type: 'end2end' corresponding to idx 0, 'master' corresponding to idx 1.
:return: no match pse bbox index list
"""
- if type == 'end2end':
+ if type == "end2end":
idx = 0
- elif type == 'master':
+ elif type == "master":
idx = 1
else:
raise ValueError
@@ -232,8 +235,7 @@ def sort_bbox(end2end_xywh_bboxes, no_match_end2end_indexes):
"""
groups = []
bbox_groups = []
- for index, end2end_xywh_bbox in zip(no_match_end2end_indexes,
- end2end_xywh_bboxes):
+ for index, end2end_xywh_bbox in zip(no_match_end2end_indexes, end2end_xywh_bboxes):
this_bbox = end2end_xywh_bbox
if len(groups) == 0:
groups.append([index])
@@ -270,10 +272,16 @@ def sort_bbox(end2end_xywh_bboxes, no_match_end2end_indexes):
sorted_bbox_groups[idx] = bg
# flatten, get final result
- end2end_sorted_idx_list, end2end_sorted_bbox_list \
- = flatten(sorted_groups, sorted_bbox_groups)
+ end2end_sorted_idx_list, end2end_sorted_bbox_list = flatten(
+ sorted_groups, sorted_bbox_groups
+ )
- return end2end_sorted_idx_list, end2end_sorted_bbox_list, sorted_groups, sorted_bbox_groups
+ return (
+ end2end_sorted_idx_list,
+ end2end_sorted_bbox_list,
+ sorted_groups,
+ sorted_bbox_groups,
+ )
def get_bboxes_list(end2end_result, structure_master_result):
@@ -288,7 +296,7 @@ def get_bboxes_list(end2end_result, structure_master_result):
end2end_xyxy_list = []
end2end_xywh_list = []
for end2end_item in end2end_result:
- src_bbox = end2end_item['bbox']
+ src_bbox = end2end_item["bbox"]
end2end_xyxy_list.append(src_bbox)
xywh_bbox = xyxy2xywh(src_bbox)
end2end_xywh_list.append(xywh_bbox)
@@ -296,13 +304,18 @@ def get_bboxes_list(end2end_result, structure_master_result):
end2end_xywh_bboxes = np.array(end2end_xywh_list)
# structure master
- src_bboxes = structure_master_result['bbox']
+ src_bboxes = structure_master_result["bbox"]
src_bboxes = remove_empty_bboxes(src_bboxes)
structure_master_xyxy_bboxes = src_bboxes
xywh_bbox = xyxy2xywh(src_bboxes)
structure_master_xywh_bboxes = xywh_bbox
- return end2end_xyxy_bboxes, end2end_xywh_bboxes, structure_master_xywh_bboxes, structure_master_xyxy_bboxes
+ return (
+ end2end_xyxy_bboxes,
+ end2end_xywh_bboxes,
+ structure_master_xywh_bboxes,
+ structure_master_xyxy_bboxes,
+ )
def center_rule_match(end2end_xywh_bboxes, structure_master_xyxy_bboxes):
@@ -317,18 +330,22 @@ def center_rule_match(end2end_xywh_bboxes, structure_master_xyxy_bboxes):
for i, end2end_xywh in enumerate(end2end_xywh_bboxes):
for j, master_xyxy in enumerate(structure_master_xyxy_bboxes):
x_end2end, y_end2end = end2end_xywh[0], end2end_xywh[1]
- x_master1, y_master1, x_master2, y_master2 \
- = master_xyxy[0], master_xyxy[1], master_xyxy[2], master_xyxy[3]
+ x_master1, y_master1, x_master2, y_master2 = (
+ master_xyxy[0],
+ master_xyxy[1],
+ master_xyxy[2],
+ master_xyxy[3],
+ )
center_point_end2end = (x_end2end, y_end2end)
- corner_point_master = ((x_master1, y_master1),
- (x_master2, y_master2))
+ corner_point_master = ((x_master1, y_master1), (x_master2, y_master2))
if is_inside(center_point_end2end, corner_point_master):
match_pairs_list.append([i, j])
return match_pairs_list
-def iou_rule_match(end2end_xyxy_bboxes, end2end_xyxy_indexes,
- structure_master_xyxy_bboxes):
+def iou_rule_match(
+ end2end_xyxy_bboxes, end2end_xyxy_indexes, structure_master_xyxy_bboxes
+):
"""
Use iou to find matching list.
choose max iou value bbox as match pair.
@@ -338,8 +355,9 @@ def iou_rule_match(end2end_xyxy_bboxes, end2end_xyxy_indexes,
:return: match pairs list, e.g. [[0,1], [1,2], ...]
"""
match_pair_list = []
- for end2end_xyxy_index, end2end_xyxy in zip(end2end_xyxy_indexes,
- end2end_xyxy_bboxes):
+ for end2end_xyxy_index, end2end_xyxy in zip(
+ end2end_xyxy_indexes, end2end_xyxy_bboxes
+ ):
max_iou = 0
max_match = [None, None]
for j, master_xyxy in enumerate(structure_master_xyxy_bboxes):
@@ -357,8 +375,7 @@ def iou_rule_match(end2end_xyxy_bboxes, end2end_xyxy_indexes,
return match_pair_list
-def distance_rule_match(end2end_indexes, end2end_bboxes, master_indexes,
- master_bboxes):
+def distance_rule_match(end2end_indexes, end2end_bboxes, master_indexes, master_bboxes):
"""
Get matching between no-match end2end bboxes and no-match master bboxes.
Use min distance to match.
@@ -428,9 +445,9 @@ def deal_successive_space(text):
:param text:
:return:
"""
- text = text.replace(' ' * 3, '<space>')
- text = text.replace(' ', '')
- text = text.replace('<space>', ' ')
+ text = text.replace(" " * 3, "<space>")
+ text = text.replace(" ", "")
+ text = text.replace("<space>", " ")
return text
@@ -444,24 +461,23 @@ def reduce_repeat_bb(text_list, break_token):
"""
count = 0
for text in text_list:
- if text.startswith('<b>'):
+ if text.startswith("<b>"):
count += 1
if count == len(text_list):
new_text_list = []
for text in text_list:
- text = text.replace('<b>', '').replace('</b>', '')
+ text = text.replace("<b>", "").replace("</b>", "")
new_text_list.append(text)
- return ['<b>' + break_token.join(new_text_list) + '</b>']
+ return ["<b>" + break_token.join(new_text_list) + "</b>"]
else:
return text_list
-def get_match_text_dict(match_dict, end2end_info, break_token=' '):
+def get_match_text_dict(match_dict, end2end_info, break_token=" "):
match_text_dict = dict()
for master_index, end2end_index_list in match_dict.items():
text_list = [
- end2end_info[end2end_index]['text']
- for end2end_index in end2end_index_list
+ end2end_info[end2end_index]["text"] for end2end_index in end2end_index_list
]
text_list = reduce_repeat_bb(text_list, break_token)
text = break_token.join(text_list)
@@ -477,32 +493,32 @@ def merge_span_token(master_token_list):
"""
new_master_token_list = []
pointer = 0
- if master_token_list[-1] != '</tbody>':
- master_token_list.append('</tbody>')
- while master_token_list[pointer] != '</tbody>':
+ if master_token_list[-1] != "</tbody>":
+ master_token_list.append("</tbody>")
+ while master_token_list[pointer] != "</tbody>":
try:
- if master_token_list[pointer] == '<td':
- if master_token_list[pointer + 1].startswith(' colspan=') or \
- master_token_list[pointer + 1].startswith(' rowspan='):
+ if master_token_list[pointer] == "<td":
+ if master_token_list[pointer + 1].startswith(
+ " colspan="
+ ) or master_token_list[pointer + 1].startswith(" rowspan="):
"""
example:
pattern <td colspan="3">
'<td' + 'colspan=" "' + '>' + '</td>'
"""
- tmp = ''.join(master_token_list[pointer:pointer + 3 + 1])
+ tmp = "".join(master_token_list[pointer : pointer + 3 + 1])
pointer += 4
new_master_token_list.append(tmp)
elif master_token_list[pointer + 2].startswith(
- ' colspan=') or master_token_list[
- pointer + 2].startswith(' rowspan='):
+ " colspan="
+ ) or master_token_list[pointer + 2].startswith(" rowspan="):
"""
example:
pattern <td rowspan="2" colspan="3">
'<td' + 'rowspan=" "' + 'colspan=" "' + '>' + '</td>'
"""
- tmp = ''.join(master_token_list[pointer:pointer + 4 + 1])
+ tmp = "".join(master_token_list[pointer : pointer + 4 + 1])
pointer += 5
new_master_token_list.append(tmp)
@@ -515,7 +531,7 @@ def merge_span_token(master_token_list):
except:
print("Break in merge...")
break
- new_master_token_list.append('</tbody>')
+ new_master_token_list.append("</tbody>")
return new_master_token_list
@@ -539,20 +555,19 @@ def deal_eb_token(master_token):
:param master_token:
:return:
"""
- master_token = master_token.replace('<eb></eb>', '<td></td>')
- master_token = master_token.replace('<eb1></eb1>', '<td> </td>')
- master_token = master_token.replace('<eb2></eb2>', '<td><b> </b></td>')
- master_token = master_token.replace('<eb3></eb3>', '<td>\u2028\u2028</td>')
- master_token = master_token.replace('<eb4></eb4>', '<td><sup> </sup></td>')
- master_token = master_token.replace('<eb5></eb5>', '<td><b></b></td>')
- master_token = master_token.replace('<eb6></eb6>', '<td><i> </i></td>')
- master_token = master_token.replace('<eb7></eb7>',
- '<td><b><i></i></b></td>')
- master_token = master_token.replace('<eb8></eb8>',
- '<td><b><i> </i></b></td>')
- master_token = master_token.replace('<eb9></eb9>', '<td><i></i></td>')
- master_token = master_token.replace('<eb10></eb10>',
- '<td><b> \u2028 \u2028 </b></td>')
+ master_token = master_token.replace("<eb></eb>", "<td></td>")
+ master_token = master_token.replace("<eb1></eb1>", "<td> </td>")
+ master_token = master_token.replace("<eb2></eb2>", "<td><b> </b></td>")
+ master_token = master_token.replace("<eb3></eb3>", "<td>\u2028\u2028</td>")
+ master_token = master_token.replace("<eb4></eb4>", "<td><sup> </sup></td>")
+ master_token = master_token.replace("<eb5></eb5>", "<td><b></b></td>")
+ master_token = master_token.replace("<eb6></eb6>", "<td><i> </i></td>")
+ master_token = master_token.replace("<eb7></eb7>", "<td><b><i></i></b></td>")
+ master_token = master_token.replace("<eb8></eb8>", "<td><b><i> </i></b></td>")
+ master_token = master_token.replace("<eb9></eb9>", "<td><i></i></td>")
+ master_token = master_token.replace(
+ "<eb10></eb10>", "<td><b> \u2028 \u2028 </b></td>"
+ )
return master_token
@@ -567,7 +582,7 @@ def insert_text_to_token(master_token_list, match_text_dict):
merged_result_list = []
text_count = 0
for master_token in master_token_list:
- if master_token.startswith('<td'):
+ if master_token.startswith("<td"):
if text_count > len(match_text_dict) - 1:
text_count += 1
continue
@@ -576,12 +591,13 @@ def insert_text_to_token(master_token_list, match_text_dict):
continue
else:
master_token = master_token.replace(
- '><', '>{}<'.format(match_text_dict[text_count]))
+ "><", ">{}<".format(match_text_dict[text_count])
+ )
text_count += 1
master_token = deal_eb_token(master_token)
merged_result_list.append(master_token)
- return ''.join(merged_result_list)
+ return "".join(merged_result_list)
def deal_isolate_span(thead_part):
@@ -593,25 +609,29 @@ def deal_isolate_span(thead_part):
:return:
"""
# 1. find out isolate span tokens.
- isolate_pattern = " | | rowspan=\"(\d)+\" colspan=\"(\d)+\">|" \
- " | colspan=\"(\d)+\" rowspan=\"(\d)+\">|" \
- " | rowspan=\"(\d)+\">|" \
- " | colspan=\"(\d)+\">"
+ isolate_pattern = (
+ '<td></td> rowspan="(\d)+" colspan="(\d)+"></b></td>|'
+ '<td></td> colspan="(\d)+" rowspan="(\d)+"></b></td>|'
+ '<td></td> rowspan="(\d)+"></b></td>|'
+ '<td></td> colspan="(\d)+"></b></td>'
+ )
isolate_iter = re.finditer(isolate_pattern, thead_part)
isolate_list = [i.group() for i in isolate_iter]
# 2. find out span number, by step 1 results.
- span_pattern = " rowspan=\"(\d)+\" colspan=\"(\d)+\"|" \
- " colspan=\"(\d)+\" rowspan=\"(\d)+\"|" \
- " rowspan=\"(\d)+\"|" \
- " colspan=\"(\d)+\""
+ span_pattern = (
+ ' rowspan="(\d)+" colspan="(\d)+"|'
+ ' colspan="(\d)+" rowspan="(\d)+"|'
+ ' rowspan="(\d)+"|'
+ ' colspan="(\d)+"'
+ )
corrected_list = []
for isolate_item in isolate_list:
span_part = re.search(span_pattern, isolate_item)
spanStr_in_isolateItem = span_part.group()
# 3. merge the span number into the span token format string.
if spanStr_in_isolateItem is not None:
- corrected_item = '<td{}></td>'.format(spanStr_in_isolateItem)
+ corrected_item = "<td{}></td>".format(spanStr_in_isolateItem)
corrected_list.append(corrected_item)
else:
corrected_list.append(None)
@@ -633,24 +653,25 @@ def deal_duplicate_bb(thead_part):
:return:
"""
# 1. find out <td></td> in <thead></thead>.
- td_pattern = "(.+?) | |" \
- "(.+?) | |" \
- "(.+?) | |" \
- "(.+?) | |" \
- "(.*?) | "
+ td_pattern = (
+ '<td rowspan="(\d)+" colspan="(\d)+">(.+?)</td>|'
+ '<td colspan="(\d)+" rowspan="(\d)+">(.+?)</td>|'
+ '<td rowspan="(\d)+">(.+?)</td>|'
+ '<td colspan="(\d)+">(.+?)</td>|'
+ "<td>(.*?)</td>"
+ )
td_iter = re.finditer(td_pattern, thead_part)
td_list = [t.group() for t in td_iter]
# 2. is multiply <b></b> in <td></td> or not?
new_td_list = []
for td_item in td_list:
- if td_item.count('<b>') > 1 or td_item.count('</b>') > 1:
+ if td_item.count("<b>") > 1 or td_item.count("</b>") > 1:
# multiply <b></b> in <td></td> case.
# 1. remove all <b></b>
- td_item = td_item.replace('<b>', '').replace('</b>', '')
+ td_item = td_item.replace("<b>", "").replace("</b>", "")
# 2. replace <td> -> <td><b>, </td> -> </b></td>.
- td_item = td_item.replace('<td>', '<td><b>').replace('</td>',
- '</b></td>')
+ td_item = td_item.replace("<td>", "<td><b>").replace("</td>", "</b></td>")
new_td_list.append(td_item)
else:
new_td_list.append(td_item)
@@ -669,14 +690,14 @@ def deal_bb(result_token):
:return:
"""
# find out <thead></thead> parts.
- thead_pattern = '<thead>(.*?)</thead>'
+ thead_pattern = "<thead>(.*?)</thead>"
if re.search(thead_pattern, result_token) is None:
return result_token
thead_part = re.search(thead_pattern, result_token).group()
origin_thead_part = copy.deepcopy(thead_part)
# check "rowspan" or "colspan" occur in parts or not .
- span_pattern = "| | | | | | "
+ span_pattern = ' | | | | | | | '
span_iter = re.finditer(span_pattern, thead_part)
span_list = [s.group() for s in span_iter]
has_span_in_head = True if len(span_list) > 0 else False
@@ -686,10 +707,12 @@ def deal_bb(result_token):
# 1. replace <td> to <td><b>, and </td> to </b></td>
# 2. it is possible to predict text include <b> or </b> by Text-line recognition,
# so we replace <b><b> to <b>, and </b></b> to </b>
- thead_part = thead_part.replace('<td>', '<td><b>')\
- .replace('</td>', '</b></td>')\
- .replace('<b><b>', '<b>')\
- .replace('</b></b>', '</b>')
+ thead_part = (
+ thead_part.replace("<td>", "<td><b>")
+ .replace("</td>", "</b></td>")
+ .replace("<b><b>", "<b>")
+ .replace("</b></b>", "</b>")
+ )
else:
# include "rowspan" or "colspan" branch 2.
# Firstly, we deal rowspan or colspan cases.
@@ -703,12 +726,12 @@ def deal_bb(result_token):
" to """>
# replace ">" to "><b>"
replaced_span_list = []
for sp in span_list:
- replaced_span_list.append(sp.replace('>', '><b>'))
+ replaced_span_list.append(sp.replace(">", "><b>"))
for sp, rsp in zip(span_list, replaced_span_list):
thead_part = thead_part.replace(sp, rsp)
# replace "</td>" to "</b></td>"
- thead_part = thead_part.replace('</td>', '</b></td>')
+ thead_part = thead_part.replace("</td>", "</b></td>")
# remove duplicated <b> by re.sub
mb_pattern = "(<b>)+"
@@ -720,12 +743,11 @@ def deal_bb(result_token):
thead_part = re.sub(mgb_pattern, single_gb_string, thead_part)
# ordinary cases like branch 1
- thead_part = thead_part.replace('<td>', '<td><b>').replace('<b><b>',
- '<b>')
+ thead_part = thead_part.replace("<td>", "<td><b>").replace("<b><b>", "<b>")
# convert <tb><b></b></tb> back to <tb></tb>, empty cell has no <b></b>.
# but space cell(<tb> </tb>) is suitable for <td><b> </b></td>
- thead_part = thead_part.replace('<td><b></b></td>', '<td></td>')
+ thead_part = thead_part.replace("<td><b></b></td>", "<td></td>")
# deal with duplicated <b></b>
thead_part = deal_duplicate_bb(thead_part)
# deal with isolate span tokens, which causes by wrong predict by structure prediction.
@@ -745,9 +767,10 @@ def __init__(self, end2end_file, structure_master_file):
"""
self.end2end_file = end2end_file
self.structure_master_file = structure_master_file
- self.end2end_results = pickle_load(end2end_file, prefix='end2end')
+ self.end2end_results = pickle_load(end2end_file, prefix="end2end")
self.structure_master_results = pickle_load(
- structure_master_file, prefix='structure')
+ structure_master_file, prefix="structure"
+ )
def match(self):
"""
@@ -759,51 +782,67 @@ def match(self):
:return:
"""
match_results = dict()
- for idx, (file_name,
- end2end_result) in enumerate(self.end2end_results.items()):
+ for idx, (file_name, end2end_result) in enumerate(self.end2end_results.items()):
match_list = []
if file_name not in self.structure_master_results:
continue
structure_master_result = self.structure_master_results[file_name]
- end2end_xyxy_bboxes, end2end_xywh_bboxes, structure_master_xywh_bboxes, structure_master_xyxy_bboxes = \
- get_bboxes_list(end2end_result, structure_master_result)
+ (
+ end2end_xyxy_bboxes,
+ end2end_xywh_bboxes,
+ structure_master_xywh_bboxes,
+ structure_master_xyxy_bboxes,
+ ) = get_bboxes_list(end2end_result, structure_master_result)
# rule 1: center rule
- center_rule_match_list = \
- center_rule_match(end2end_xywh_bboxes, structure_master_xyxy_bboxes)
+ center_rule_match_list = center_rule_match(
+ end2end_xywh_bboxes, structure_master_xyxy_bboxes
+ )
match_list.extend(center_rule_match_list)
# rule 2: iou rule
# firstly, find not match index in previous step.
- center_no_match_end2end_indexs = \
- find_no_match(match_list, len(end2end_xywh_bboxes), type='end2end')
+ center_no_match_end2end_indexs = find_no_match(
+ match_list, len(end2end_xywh_bboxes), type="end2end"
+ )
if len(center_no_match_end2end_indexs) > 0:
center_no_match_end2end_xyxy = end2end_xyxy_bboxes[
- center_no_match_end2end_indexs]
+ center_no_match_end2end_indexs
+ ]
# secondly, iou rule match
- iou_rule_match_list = \
- iou_rule_match(center_no_match_end2end_xyxy, center_no_match_end2end_indexs, structure_master_xyxy_bboxes)
+ iou_rule_match_list = iou_rule_match(
+ center_no_match_end2end_xyxy,
+ center_no_match_end2end_indexs,
+ structure_master_xyxy_bboxes,
+ )
match_list.extend(iou_rule_match_list)
# rule 3: distance rule
# match between no-match end2end bboxes and no-match master bboxes.
# it will return master_bboxes_nums match-pairs.
# firstly, find not match index in previous step.
- centerIou_no_match_end2end_indexs = \
- find_no_match(match_list, len(end2end_xywh_bboxes), type='end2end')
- centerIou_no_match_master_indexs = \
- find_no_match(match_list, len(structure_master_xywh_bboxes), type='master')
- if len(centerIou_no_match_master_indexs) > 0 and len(
- centerIou_no_match_end2end_indexs) > 0:
+ centerIou_no_match_end2end_indexs = find_no_match(
+ match_list, len(end2end_xywh_bboxes), type="end2end"
+ )
+ centerIou_no_match_master_indexs = find_no_match(
+ match_list, len(structure_master_xywh_bboxes), type="master"
+ )
+ if (
+ len(centerIou_no_match_master_indexs) > 0
+ and len(centerIou_no_match_end2end_indexs) > 0
+ ):
centerIou_no_match_end2end_xywh = end2end_xywh_bboxes[
- centerIou_no_match_end2end_indexs]
+ centerIou_no_match_end2end_indexs
+ ]
centerIou_no_match_master_xywh = structure_master_xywh_bboxes[
- centerIou_no_match_master_indexs]
+ centerIou_no_match_master_indexs
+ ]
distance_match_list = distance_rule_match(
centerIou_no_match_end2end_indexs,
centerIou_no_match_end2end_xywh,
centerIou_no_match_master_indexs,
- centerIou_no_match_master_xywh)
+ centerIou_no_match_master_xywh,
+ )
match_list.extend(distance_match_list)
# TODO:
@@ -813,18 +852,22 @@ def match(self):
# For these render end2end bboxes, we will make some virtual master bboxes, and get matching.
# The above extra insert bboxes will be further processed in "formatOutput" function.
# After this operation, it will increase TEDS score.
- no_match_end2end_indexes = \
- find_no_match(match_list, len(end2end_xywh_bboxes), type='end2end')
+ no_match_end2end_indexes = find_no_match(
+ match_list, len(end2end_xywh_bboxes), type="end2end"
+ )
if len(no_match_end2end_indexes) > 0:
- no_match_end2end_xywh = end2end_xywh_bboxes[
- no_match_end2end_indexes]
+ no_match_end2end_xywh = end2end_xywh_bboxes[no_match_end2end_indexes]
# sort the render no-match end2end bbox in row
- end2end_sorted_indexes_list, end2end_sorted_bboxes_list, sorted_groups, sorted_bboxes_groups = \
- sort_bbox(no_match_end2end_xywh, no_match_end2end_indexes)
+ (
+ end2end_sorted_indexes_list,
+ end2end_sorted_bboxes_list,
+ sorted_groups,
+ sorted_bboxes_groups,
+ ) = sort_bbox(no_match_end2end_xywh, no_match_end2end_indexes)
# make virtual master bboxes, and get matching with the no-match end2end bboxes.
extra_match_list = extra_match(
- end2end_sorted_indexes_list,
- len(structure_master_xywh_bboxes))
+ end2end_sorted_indexes_list, len(structure_master_xywh_bboxes)
+ )
match_list_add_extra_match = copy.deepcopy(match_list)
match_list_add_extra_match.extend(extra_match_list)
else:
@@ -834,10 +877,10 @@ def match(self):
sorted_bboxes_groups = []
match_result_dict = {
- 'match_list': match_list,
- 'match_list_add_extra_match': match_list_add_extra_match,
- 'sorted_groups': sorted_groups,
- 'sorted_bboxes_groups': sorted_bboxes_groups
+ "match_list": match_list,
+ "match_list_add_extra_match": match_list_add_extra_match,
+ "sorted_groups": sorted_groups,
+ "sorted_bboxes_groups": sorted_bboxes_groups,
}
# format output
@@ -856,22 +899,22 @@ def _format(self, match_result, file_name):
"""
end2end_info = self.end2end_results[file_name]
master_info = self.structure_master_results[file_name]
- master_token = master_info['text']
- sorted_groups = match_result['sorted_groups']
+ master_token = master_info["text"]
+ sorted_groups = match_result["sorted_groups"]
# create virtual master token
virtual_master_token_list = []
for line_group in sorted_groups:
- tmp_list = ['<tr>']
+ tmp_list = ["<tr>"]
item_nums = len(line_group)
for _ in range(item_nums):
- tmp_list.append('<td></td>')
- tmp_list.append('</tr>')
+ tmp_list.append("<td></td>")
+ tmp_list.append("</tr>")
virtual_master_token_list.extend(tmp_list)
# insert virtual master token
- master_token_list = master_token.split(',')
- if master_token_list[-1] == '</tbody>':
+ master_token_list = master_token.split(",")
+ if master_token_list[-1] == "</tbody>":
# complete predict(no cut by max length)
# This situation insert virtual master token will drop TEDs score in val set.
# So we will not extend virtual token in this situation.
@@ -884,16 +927,16 @@ def _format(self, match_result, file_name):
# master_token_list.extend(virtual_master_token_list)
# master_token_list.append('</tbody>')
- elif master_token_list[-1] == '<td></td>':
- master_token_list.append('</tr>')
+ elif master_token_list[-1] == "<td></td>":
+ master_token_list.append("</tr>")
master_token_list.extend(virtual_master_token_list)
- master_token_list.append('</tbody>')
+ master_token_list.append("</tbody>")
else:
master_token_list.extend(virtual_master_token_list)
- master_token_list.append('</tbody>')
+ master_token_list.append("</tbody>")
# format output
- match_result.setdefault('matched_master_token_list', master_token_list)
+ match_result.setdefault("matched_master_token_list", master_token_list)
return match_result
def get_merge_result(self, match_results):
@@ -905,18 +948,16 @@ def get_merge_result(self, match_results):
merged_results = dict()
# break_token is the linefeed token, used when one master bbox has multiple end2end bboxes.
- break_token = ' '
+ break_token = " "
for idx, (file_name, match_info) in enumerate(match_results.items()):
end2end_info = self.end2end_results[file_name]
- master_token_list = match_info['matched_master_token_list']
- match_list = match_info['match_list_add_extra_match']
+ master_token_list = match_info["matched_master_token_list"]
+ match_list = match_info["match_list_add_extra_match"]
match_dict = get_match_dict(match_list)
- match_text_dict = get_match_text_dict(match_dict, end2end_info,
- break_token)
- merged_result = insert_text_to_token(master_token_list,
- match_text_dict)
+ match_text_dict = get_match_text_dict(match_dict, end2end_info, break_token)
+ merged_result = insert_text_to_token(master_token_list, match_text_dict)
merged_result = deal_bb(merged_result)
merged_results[file_name] = merged_result
@@ -933,21 +974,22 @@ def __call__(self, structure_res, dt_boxes, rec_res, img_name=1):
for dt_box, res in zip(dt_boxes, rec_res):
d = dict(
bbox=np.array(dt_box),
- text=res[0], )
+ text=res[0],
+ )
end2end_results[img_name].append(d)
self.end2end_results = end2end_results
structure_master_result_dict = {img_name: {}}
pred_structures, pred_bboxes = structure_res
- pred_structures = ','.join(pred_structures[3:-3])
- structure_master_result_dict[img_name]['text'] = pred_structures
- structure_master_result_dict[img_name]['bbox'] = pred_bboxes
+ pred_structures = ",".join(pred_structures[3:-3])
+ structure_master_result_dict[img_name]["text"] = pred_structures
+ structure_master_result_dict[img_name]["bbox"] = pred_bboxes
self.structure_master_results = structure_master_result_dict
# match
match_results = self.match()
merged_results = self.get_merge_result(match_results)
pred_html = merged_results[img_name]
- pred_html = '<html><body><table>' + pred_html + '</table></body></html>'
+ pred_html = "<html><body><table>" + pred_html + "</table></body></html>"
return pred_html
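The matcher above chains three rules: a center-point rule, an IoU rule, and a distance rule. For orientation, here is a minimal sketch of the first rule, assuming `xywh` boxes are `(left, top, width, height)` and `xyxy` boxes are corner coordinates; the helper name is illustrative, not the repo's implementation:

```python
import numpy as np

def center_rule_match_sketch(end2end_xywh, master_xyxy):
    """Pair each text box with the first structure cell containing its center."""
    matches = []
    for i, (x, y, w, h) in enumerate(end2end_xywh):
        cx, cy = x + w / 2.0, y + h / 2.0
        for j, (x1, y1, x2, y2) in enumerate(master_xyxy):
            if x1 <= cx <= x2 and y1 <= cy <= y2:
                matches.append((i, j))
                break
    return matches

print(center_rule_match_sketch(
    np.array([[12.0, 4.0, 20.0, 10.0]]),   # center at (22, 9)
    np.array([[0.0, 0.0, 50.0, 20.0]])))   # -> [(0, 0)]
```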
diff --git a/ppstructure/table/table_metric/__init__.py b/ppstructure/table/table_metric/__init__.py
index de2d307430..70be931de4 100755
--- a/ppstructure/table/table_metric/__init__.py
+++ b/ppstructure/table/table_metric/__init__.py
@@ -12,5 +12,5 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-__all__ = ['TEDS']
-from .table_metric import TEDS
\ No newline at end of file
+__all__ = ["TEDS"]
+from .table_metric import TEDS
diff --git a/ppstructure/table/table_metric/parallel.py b/ppstructure/table/table_metric/parallel.py
index f7326a1f50..f706f35a0e 100755
--- a/ppstructure/table/table_metric/parallel.py
+++ b/ppstructure/table/table_metric/parallel.py
@@ -4,27 +4,31 @@
def parallel_process(array, function, n_jobs=16, use_kwargs=False, front_num=0):
"""
- A parallel version of the map function with a progress bar.
- Args:
- array (array-like): An array to iterate over.
- function (function): A python function to apply to the elements of array
- n_jobs (int, default=16): The number of cores to use
- use_kwargs (boolean, default=False): Whether to consider the elements of array as dictionaries of
- keyword arguments to function
- front_num (int, default=3): The number of iterations to run serially before kicking off the parallel job.
- Useful for catching bugs
- Returns:
- [function(array[0]), function(array[1]), ...]
+ A parallel version of the map function with a progress bar.
+ Args:
+ array (array-like): An array to iterate over.
+ function (function): A python function to apply to the elements of array
+ n_jobs (int, default=16): The number of cores to use
+ use_kwargs (boolean, default=False): Whether to consider the elements of array as dictionaries of
+ keyword arguments to function
+ front_num (int, default=0): The number of iterations to run serially before kicking off the parallel job.
+ Useful for catching bugs
+ Returns:
+ [function(array[0]), function(array[1]), ...]
"""
# We run the first few iterations serially to catch bugs
if front_num > 0:
- front = [function(**a) if use_kwargs else function(a)
- for a in array[:front_num]]
+ front = [
+ function(**a) if use_kwargs else function(a) for a in array[:front_num]
+ ]
else:
front = []
# If we set n_jobs to 1, just run a list comprehension. This is useful for benchmarking and debugging.
if n_jobs == 1:
- return front + [function(**a) if use_kwargs else function(a) for a in tqdm(array[front_num:])]
+ return front + [
+ function(**a) if use_kwargs else function(a)
+ for a in tqdm(array[front_num:])
+ ]
# Assemble the workers
with ProcessPoolExecutor(max_workers=n_jobs) as pool:
# Pass the elements of array into function
@@ -33,10 +37,10 @@ def parallel_process(array, function, n_jobs=16, use_kwargs=False, front_num=0):
else:
futures = [pool.submit(function, a) for a in array[front_num:]]
kwargs = {
- 'total': len(futures),
- 'unit': 'it',
- 'unit_scale': True,
- 'leave': True
+ "total": len(futures),
+ "unit": "it",
+ "unit_scale": True,
+ "leave": True,
}
# Print out the progress as tasks complete
for f in tqdm(as_completed(futures), **kwargs):
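For context, a minimal way to exercise `parallel_process` (the import path is assumed from this repo layout; the worker must be a module-level function so the process pool can pickle it):

```python
from ppstructure.table.table_metric.parallel import parallel_process

def square(v):
    return v * v

if __name__ == "__main__":
    # the first two items run serially (front_num=2), the rest on 4 workers
    results = parallel_process(list(range(16)), square, n_jobs=4, front_num=2)
    print(results[:5])  # [0, 1, 4, 9, 16]
```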
diff --git a/ppstructure/table/table_metric/table_metric.py b/ppstructure/table/table_metric/table_metric.py
index 923a9c0071..d5ba6a0afb 100755
--- a/ppstructure/table/table_metric/table_metric.py
+++ b/ppstructure/table/table_metric/table_metric.py
@@ -28,9 +28,13 @@ def __init__(self, tag, colspan=None, rowspan=None, content=None, *children):
def bracket(self):
"""Show tree using brackets notation"""
- if self.tag == 'td':
- result = '"tag": %s, "colspan": %d, "rowspan": %d, "text": %s' % \
- (self.tag, self.colspan, self.rowspan, self.content)
+ if self.tag == "td":
+ result = '"tag": %s, "colspan": %d, "rowspan": %d, "text": %s' % (
+ self.tag,
+ self.colspan,
+ self.rowspan,
+ self.content,
+ )
else:
result = '"tag": %s' % self.tag
for child in self.children:
@@ -41,117 +45,130 @@ def bracket(self):
class CustomConfig(Config):
def rename(self, node1, node2):
"""Compares attributes of trees"""
- #print(node1.tag)
- if (node1.tag != node2.tag) or (node1.colspan != node2.colspan) or (node1.rowspan != node2.rowspan):
- return 1.
- if node1.tag == 'td':
+ # print(node1.tag)
+ if (
+ (node1.tag != node2.tag)
+ or (node1.colspan != node2.colspan)
+ or (node1.rowspan != node2.rowspan)
+ ):
+ return 1.0
+ if node1.tag == "td":
if node1.content or node2.content:
- #print(node1.content, )
+ # print(node1.content, )
return Levenshtein.normalized_distance(node1.content, node2.content)
- return 0.
-
+ return 0.0
class CustomConfig_del_short(Config):
def rename(self, node1, node2):
"""Compares attributes of trees"""
- if (node1.tag != node2.tag) or (node1.colspan != node2.colspan) or (node1.rowspan != node2.rowspan):
- return 1.
- if node1.tag == 'td':
+ if (
+ (node1.tag != node2.tag)
+ or (node1.colspan != node2.colspan)
+ or (node1.rowspan != node2.rowspan)
+ ):
+ return 1.0
+ if node1.tag == "td":
if node1.content or node2.content:
- #print('before')
- #print(node1.content, node2.content)
- #print('after')
+ # print('before')
+ # print(node1.content, node2.content)
+ # print('after')
node1_content = node1.content
node2_content = node2.content
if len(node1_content) < 3:
- node1_content = ['####']
+ node1_content = ["####"]
if len(node2_content) < 3:
- node2_content = ['####']
+ node2_content = ["####"]
return Levenshtein.normalized_distance(node1_content, node2_content)
- return 0.
+ return 0.0
+
class CustomConfig_del_block(Config):
def rename(self, node1, node2):
"""Compares attributes of trees"""
- if (node1.tag != node2.tag) or (node1.colspan != node2.colspan) or (node1.rowspan != node2.rowspan):
- return 1.
- if node1.tag == 'td':
+ if (
+ (node1.tag != node2.tag)
+ or (node1.colspan != node2.colspan)
+ or (node1.rowspan != node2.rowspan)
+ ):
+ return 1.0
+ if node1.tag == "td":
if node1.content or node2.content:
-
node1_content = node1.content
node2_content = node2.content
- while ' ' in node1_content:
- print(node1_content.index(' '))
- node1_content.pop(node1_content.index(' '))
- while ' ' in node2_content:
- print(node2_content.index(' '))
- node2_content.pop(node2_content.index(' '))
+ while " " in node1_content:
+ print(node1_content.index(" "))
+ node1_content.pop(node1_content.index(" "))
+ while " " in node2_content:
+ print(node2_content.index(" "))
+ node2_content.pop(node2_content.index(" "))
return Levenshtein.normalized_distance(node1_content, node2_content)
- return 0.
+ return 0.0
+
class TEDS(object):
- ''' Tree Edit Distance basead Similarity
- '''
+ """Tree Edit Distance basead Similarity"""
def __init__(self, structure_only=False, n_jobs=1, ignore_nodes=None):
assert isinstance(n_jobs, int) and (
- n_jobs >= 1), 'n_jobs must be an integer greather than 1'
+ n_jobs >= 1
+ ), "n_jobs must be an integer greather than 1"
self.structure_only = structure_only
self.n_jobs = n_jobs
self.ignore_nodes = ignore_nodes
self.__tokens__ = []
def tokenize(self, node):
- ''' Tokenizes table cells
- '''
- self.__tokens__.append('<%s>' % node.tag)
+ """Tokenizes table cells"""
+ self.__tokens__.append("<%s>" % node.tag)
if node.text is not None:
self.__tokens__ += list(node.text)
for n in node.getchildren():
self.tokenize(n)
- if node.tag != 'unk':
- self.__tokens__.append('</%s>' % node.tag)
- if node.tag != 'td' and node.tail is not None:
+ if node.tag != "unk":
+ self.__tokens__.append("%s>" % node.tag)
+ if node.tag != "td" and node.tail is not None:
self.__tokens__ += list(node.tail)
def load_html_tree(self, node, parent=None):
- ''' Converts HTML tree to the format required by apted
- '''
+ """Converts HTML tree to the format required by apted"""
global __tokens__
- if node.tag == 'td':
+ if node.tag == "td":
if self.structure_only:
cell = []
else:
self.__tokens__ = []
self.tokenize(node)
cell = self.__tokens__[1:-1].copy()
- new_node = TableTree(node.tag,
- int(node.attrib.get('colspan', '1')),
- int(node.attrib.get('rowspan', '1')),
- cell, *deque())
+ new_node = TableTree(
+ node.tag,
+ int(node.attrib.get("colspan", "1")),
+ int(node.attrib.get("rowspan", "1")),
+ cell,
+ *deque()
+ )
else:
new_node = TableTree(node.tag, None, None, None, *deque())
if parent is not None:
parent.children.append(new_node)
- if node.tag != 'td':
+ if node.tag != "td":
for n in node.getchildren():
self.load_html_tree(n, new_node)
if parent is None:
return new_node
def evaluate(self, pred, true):
- ''' Computes TEDS score between the prediction and the ground truth of a
- given sample
- '''
+ """Computes TEDS score between the prediction and the ground truth of a
+ given sample
+ """
if (not pred) or (not true):
return 0.0
- parser = html.HTMLParser(remove_comments=True, encoding='utf-8')
+ parser = html.HTMLParser(remove_comments=True, encoding="utf-8")
pred = html.fromstring(pred, parser=parser)
true = html.fromstring(true, parser=parser)
- if pred.xpath('body/table') and true.xpath('body/table'):
- pred = pred.xpath('body/table')[0]
- true = true.xpath('body/table')[0]
+ if pred.xpath("body/table") and true.xpath("body/table"):
+ pred = pred.xpath("body/table")[0]
+ true = true.xpath("body/table")[0]
if self.ignore_nodes:
etree.strip_tags(pred, *self.ignore_nodes)
etree.strip_tags(true, *self.ignore_nodes)
@@ -160,53 +177,68 @@ def evaluate(self, pred, true):
n_nodes = max(n_nodes_pred, n_nodes_true)
tree_pred = self.load_html_tree(pred)
tree_true = self.load_html_tree(true)
- distance = APTED(tree_pred, tree_true,
- CustomConfig()).compute_edit_distance()
+ distance = APTED(
+ tree_pred, tree_true, CustomConfig()
+ ).compute_edit_distance()
return 1.0 - (float(distance) / n_nodes)
else:
return 0.0
def batch_evaluate(self, pred_json, true_json):
- ''' Computes TEDS score between the prediction and the ground truth of
- a batch of samples
- @params pred_json: {'FILENAME': 'HTML CODE', ...}
- @params true_json: {'FILENAME': {'html': 'HTML CODE'}, ...}
- @output: {'FILENAME': 'TEDS SCORE', ...}
- '''
+ """Computes TEDS score between the prediction and the ground truth of
+ a batch of samples
+ @params pred_json: {'FILENAME': 'HTML CODE', ...}
+ @params true_json: {'FILENAME': {'html': 'HTML CODE'}, ...}
+ @output: {'FILENAME': 'TEDS SCORE', ...}
+ """
samples = true_json.keys()
if self.n_jobs == 1:
- scores = [self.evaluate(pred_json.get(
- filename, ''), true_json[filename]['html']) for filename in tqdm(samples)]
+ scores = [
+ self.evaluate(pred_json.get(filename, ""), true_json[filename]["html"])
+ for filename in tqdm(samples)
+ ]
else:
- inputs = [{'pred': pred_json.get(
- filename, ''), 'true': true_json[filename]['html']} for filename in samples]
+ inputs = [
+ {
+ "pred": pred_json.get(filename, ""),
+ "true": true_json[filename]["html"],
+ }
+ for filename in samples
+ ]
scores = parallel_process(
- inputs, self.evaluate, use_kwargs=True, n_jobs=self.n_jobs, front_num=1)
+ inputs, self.evaluate, use_kwargs=True, n_jobs=self.n_jobs, front_num=1
+ )
scores = dict(zip(samples, scores))
return scores
def batch_evaluate_html(self, pred_htmls, true_htmls):
- ''' Computes TEDS score between the prediction and the ground truth of
- a batch of samples
- '''
+ """Computes TEDS score between the prediction and the ground truth of
+ a batch of samples
+ """
if self.n_jobs == 1:
- scores = [self.evaluate(pred_html, true_html) for (
- pred_html, true_html) in zip(pred_htmls, true_htmls)]
+ scores = [
+ self.evaluate(pred_html, true_html)
+ for (pred_html, true_html) in zip(pred_htmls, true_htmls)
+ ]
else:
- inputs = [{"pred": pred_html, "true": true_html} for(
- pred_html, true_html) in zip(pred_htmls, true_htmls)]
+ inputs = [
+ {"pred": pred_html, "true": true_html}
+ for (pred_html, true_html) in zip(pred_htmls, true_htmls)
+ ]
scores = parallel_process(
- inputs, self.evaluate, use_kwargs=True, n_jobs=self.n_jobs, front_num=1)
+ inputs, self.evaluate, use_kwargs=True, n_jobs=self.n_jobs, front_num=1
+ )
return scores
-if __name__ == '__main__':
+if __name__ == "__main__":
import json
import pprint
- with open('sample_pred.json') as fp:
+
+ with open("sample_pred.json") as fp:
pred_json = json.load(fp)
- with open('sample_gt.json') as fp:
+ with open("sample_gt.json") as fp:
true_json = json.load(fp)
teds = TEDS(n_jobs=4)
scores = teds.batch_evaluate(pred_json, true_json)
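A smaller smoke test of the same class, assuming `lxml` and `apted` are installed as this module expects:

```python
pred = "<html><body><table><tr><td>cat</td><td>dog</td></tr></table></body></html>"
true = "<html><body><table><tr><td>cat</td><td>dog</td></tr></table></body></html>"

teds = TEDS(structure_only=False, n_jobs=1)
print(teds.evaluate(pred, true))  # 1.0: identical structure and cell text
```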
diff --git a/ppstructure/table/tablepyxl/__init__.py b/ppstructure/table/tablepyxl/__init__.py
index dc0085071c..1d11e26559 100644
--- a/ppstructure/table/tablepyxl/__init__.py
+++ b/ppstructure/table/tablepyxl/__init__.py
@@ -10,4 +10,4 @@
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
-# limitations under the License.
\ No newline at end of file
+# limitations under the License.
diff --git a/ppstructure/table/tablepyxl/style.py b/ppstructure/table/tablepyxl/style.py
index ebd794b1b4..4787e7d377 100644
--- a/ppstructure/table/tablepyxl/style.py
+++ b/ppstructure/table/tablepyxl/style.py
@@ -2,19 +2,27 @@
# and cascading those from parent to child in the dom.
from openpyxl.cell import cell
-from openpyxl.styles import Font, Alignment, PatternFill, NamedStyle, Border, Side, Color
+from openpyxl.styles import (
+ Font,
+ Alignment,
+ PatternFill,
+ NamedStyle,
+ Border,
+ Side,
+ Color,
+)
from openpyxl.styles.fills import FILL_SOLID
from openpyxl.styles.numbers import FORMAT_CURRENCY_USD_SIMPLE, FORMAT_PERCENTAGE
from openpyxl.styles.colors import BLACK
-FORMAT_DATE_MMDDYYYY = 'mm/dd/yyyy'
+FORMAT_DATE_MMDDYYYY = "mm/dd/yyyy"
def colormap(color):
"""
Convenience for looking up known colors
"""
- cmap = {'black': BLACK}
+ cmap = {"black": BLACK}
return cmap.get(color, color)
@@ -22,15 +30,20 @@ def style_string_to_dict(style):
"""
Convert css style string to a python dictionary
"""
+
def clean_split(string, delim):
return (s.strip() for s in string.split(delim))
+
styles = [clean_split(s, ":") for s in style.split(";") if ":" in s]
return dict(styles)
def get_side(style, name):
- return {'border_style': style.get('border-{}-style'.format(name)),
- 'color': colormap(style.get('border-{}-color'.format(name)))}
+ return {
+ "border_style": style.get("border-{}-style".format(name)),
+ "color": colormap(style.get("border-{}-color".format(name))),
+ }
+
known_styles = {}
@@ -40,49 +53,65 @@ def style_dict_to_named_style(style_dict, number_format=None):
Change css style (stored in a python dictionary) to openpyxl NamedStyle
"""
- style_and_format_string = str({
- 'style_dict': style_dict,
- 'parent': style_dict.parent,
- 'number_format': number_format,
- })
+ style_and_format_string = str(
+ {
+ "style_dict": style_dict,
+ "parent": style_dict.parent,
+ "number_format": number_format,
+ }
+ )
if style_and_format_string not in known_styles:
# Font
- font = Font(bold=style_dict.get('font-weight') == 'bold',
- color=style_dict.get_color('color', None),
- size=style_dict.get('font-size'))
+ font = Font(
+ bold=style_dict.get("font-weight") == "bold",
+ color=style_dict.get_color("color", None),
+ size=style_dict.get("font-size"),
+ )
# Alignment
- alignment = Alignment(horizontal=style_dict.get('text-align', 'general'),
- vertical=style_dict.get('vertical-align'),
- wrap_text=style_dict.get('white-space', 'nowrap') == 'normal')
+ alignment = Alignment(
+ horizontal=style_dict.get("text-align", "general"),
+ vertical=style_dict.get("vertical-align"),
+ wrap_text=style_dict.get("white-space", "nowrap") == "normal",
+ )
# Fill
- bg_color = style_dict.get_color('background-color')
- fg_color = style_dict.get_color('foreground-color', Color())
- fill_type = style_dict.get('fill-type')
- if bg_color and bg_color != 'transparent':
- fill = PatternFill(fill_type=fill_type or FILL_SOLID,
- start_color=bg_color,
- end_color=fg_color)
+ bg_color = style_dict.get_color("background-color")
+ fg_color = style_dict.get_color("foreground-color", Color())
+ fill_type = style_dict.get("fill-type")
+ if bg_color and bg_color != "transparent":
+ fill = PatternFill(
+ fill_type=fill_type or FILL_SOLID,
+ start_color=bg_color,
+ end_color=fg_color,
+ )
else:
fill = PatternFill()
# Border
- border = Border(left=Side(**get_side(style_dict, 'left')),
- right=Side(**get_side(style_dict, 'right')),
- top=Side(**get_side(style_dict, 'top')),
- bottom=Side(**get_side(style_dict, 'bottom')),
- diagonal=Side(**get_side(style_dict, 'diagonal')),
- diagonal_direction=None,
- outline=Side(**get_side(style_dict, 'outline')),
- vertical=None,
- horizontal=None)
-
- name = 'Style {}'.format(len(known_styles) + 1)
-
- pyxl_style = NamedStyle(name=name, font=font, fill=fill, alignment=alignment, border=border,
- number_format=number_format)
+ border = Border(
+ left=Side(**get_side(style_dict, "left")),
+ right=Side(**get_side(style_dict, "right")),
+ top=Side(**get_side(style_dict, "top")),
+ bottom=Side(**get_side(style_dict, "bottom")),
+ diagonal=Side(**get_side(style_dict, "diagonal")),
+ diagonal_direction=None,
+ outline=Side(**get_side(style_dict, "outline")),
+ vertical=None,
+ horizontal=None,
+ )
+
+ name = "Style {}".format(len(known_styles) + 1)
+
+ pyxl_style = NamedStyle(
+ name=name,
+ font=font,
+ fill=fill,
+ alignment=alignment,
+ border=border,
+ number_format=number_format,
+ )
known_styles[style_and_format_string] = pyxl_style
@@ -93,8 +122,9 @@ class StyleDict(dict):
"""
It's like a dictionary, but it looks for items in the parent dictionary
"""
+
def __init__(self, *args, **kwargs):
- self.parent = kwargs.pop('parent', None)
+ self.parent = kwargs.pop("parent", None)
super(StyleDict, self).__init__(*args, **kwargs)
def __getitem__(self, item):
@@ -103,7 +133,7 @@ def __getitem__(self, item):
elif self.parent:
return self.parent[item]
else:
- raise KeyError('{} not found'.format(item))
+ raise KeyError("{} not found".format(item))
def __hash__(self):
return hash(tuple([(k, self.get(k)) for k in self._keys()]))
@@ -133,10 +163,12 @@ def get_color(self, k, d=None):
Strip leading # off colors if necessary
"""
color = self.get(k, d)
- if hasattr(color, 'startswith') and color.startswith('#'):
+ if hasattr(color, "startswith") and color.startswith("#"):
color = color[1:]
- if len(color) == 3: # Premailers reduces colors like #00ff00 to #0f0, openpyxl doesn't like that
- color = ''.join(2 * c for c in color)
+ if (
+ len(color) == 3
+ ): # Premailer reduces colors like #00ff00 to #0f0, openpyxl doesn't like that
+ color = "".join(2 * c for c in color)
return color
@@ -146,11 +178,14 @@ class Element(object):
The element is created along with a parent so that the StyleDict that we store
can point to the parent's StyleDict.
"""
+
def __init__(self, element, parent=None):
self.element = element
self.number_format = None
parent_style = parent.style_dict if parent else None
- self.style_dict = StyleDict(style_string_to_dict(element.get('style', '')), parent=parent_style)
+ self.style_dict = StyleDict(
+ style_string_to_dict(element.get("style", "")), parent=parent_style
+ )
self._style_cache = None
def style(self):
@@ -158,7 +193,9 @@ def style(self):
Turn the css styles for this element into an openpyxl NamedStyle.
"""
if not self._style_cache:
- self._style_cache = style_dict_to_named_style(self.style_dict, number_format=self.number_format)
+ self._style_cache = style_dict_to_named_style(
+ self.style_dict, number_format=self.number_format
+ )
return self._style_cache
def get_dimension(self, dimension_key):
@@ -167,7 +204,7 @@ def get_dimension(self, dimension_key):
"""
dimension = self.style_dict.get(dimension_key)
if dimension:
- if dimension[-2:] in ['px', 'em', 'pt', 'in', 'cm']:
+ if dimension[-2:] in ["px", "em", "pt", "in", "cm"]:
dimension = dimension[:-2]
dimension = float(dimension)
return dimension
@@ -179,42 +216,52 @@ class Table(Element):
This defines a very concrete tree structure for html tables that we expect to deal with. I prefer this compared to
allowing Element to have an arbitrary number of children and dealing with an abstract element tree.
"""
+
def __init__(self, table):
"""
takes an html table object (from lxml)
"""
super(Table, self).__init__(table)
- table_head = table.find('thead')
- self.head = TableHead(table_head, parent=self) if table_head is not None else None
- table_body = table.find('tbody')
- self.body = TableBody(table_body if table_body is not None else table, parent=self)
+ table_head = table.find("thead")
+ self.head = (
+ TableHead(table_head, parent=self) if table_head is not None else None
+ )
+ table_body = table.find("tbody")
+ self.body = TableBody(
+ table_body if table_body is not None else table, parent=self
+ )
class TableHead(Element):
"""
This class maps to the `<thead>` element of the html table.
"""
+
def __init__(self, head, parent=None):
super(TableHead, self).__init__(head, parent=parent)
- self.rows = [TableRow(tr, parent=self) for tr in head.findall('tr')]
+ self.rows = [TableRow(tr, parent=self) for tr in head.findall("tr")]
class TableBody(Element):
"""
This class maps to the `<tbody>` element of the html table.
"""
+
def __init__(self, body, parent=None):
super(TableBody, self).__init__(body, parent=parent)
- self.rows = [TableRow(tr, parent=self) for tr in body.findall('tr')]
+ self.rows = [TableRow(tr, parent=self) for tr in body.findall("tr")]
class TableRow(Element):
"""
This class maps to the `<tr>` element of the html table.
"""
+
def __init__(self, tr, parent=None):
super(TableRow, self).__init__(tr, parent=parent)
- self.cells = [TableCell(cell, parent=self) for cell in tr.findall('th') + tr.findall('td')]
+ self.cells = [
+ TableCell(cell, parent=self) for cell in tr.findall("th") + tr.findall("td")
+ ]
def element_to_string(el):
@@ -222,23 +269,35 @@ def element_to_string(el):
def _element_to_string(el):
- string = ''
+ string = ""
for x in el.iterchildren():
- string += '\n' + _element_to_string(x)
+ string += "\n" + _element_to_string(x)
- text = el.text.strip() if el.text else ''
- tail = el.tail.strip() if el.tail else ''
+ text = el.text.strip() if el.text else ""
+ tail = el.tail.strip() if el.tail else ""
- return text + string + '\n' + tail
+ return text + string + "\n" + tail
class TableCell(Element):
"""
This class maps to the `<td>` element of the html table.
"""
- CELL_TYPES = {'TYPE_STRING', 'TYPE_FORMULA', 'TYPE_NUMERIC', 'TYPE_BOOL', 'TYPE_CURRENCY', 'TYPE_PERCENTAGE',
- 'TYPE_NULL', 'TYPE_INLINE', 'TYPE_ERROR', 'TYPE_FORMULA_CACHE_STRING', 'TYPE_INTEGER'}
+
+ CELL_TYPES = {
+ "TYPE_STRING",
+ "TYPE_FORMULA",
+ "TYPE_NUMERIC",
+ "TYPE_BOOL",
+ "TYPE_CURRENCY",
+ "TYPE_PERCENTAGE",
+ "TYPE_NULL",
+ "TYPE_INLINE",
+ "TYPE_ERROR",
+ "TYPE_FORMULA_CACHE_STRING",
+ "TYPE_INTEGER",
+ }
def __init__(self, cell, parent=None):
super(TableCell, self).__init__(cell, parent=parent)
@@ -246,38 +305,38 @@ def __init__(self, cell, parent=None):
self.number_format = self.get_number_format()
def data_type(self):
- cell_types = self.CELL_TYPES & set(self.element.get('class', '').split())
+ cell_types = self.CELL_TYPES & set(self.element.get("class", "").split())
if cell_types:
- if 'TYPE_FORMULA' in cell_types:
+ if "TYPE_FORMULA" in cell_types:
# Make sure TYPE_FORMULA takes precedence over the other classes in the set.
- cell_type = 'TYPE_FORMULA'
- elif cell_types & {'TYPE_CURRENCY', 'TYPE_INTEGER', 'TYPE_PERCENTAGE'}:
- cell_type = 'TYPE_NUMERIC'
+ cell_type = "TYPE_FORMULA"
+ elif cell_types & {"TYPE_CURRENCY", "TYPE_INTEGER", "TYPE_PERCENTAGE"}:
+ cell_type = "TYPE_NUMERIC"
else:
cell_type = cell_types.pop()
else:
- cell_type = 'TYPE_STRING'
+ cell_type = "TYPE_STRING"
return getattr(cell, cell_type)
def get_number_format(self):
- if 'TYPE_CURRENCY' in self.element.get('class', '').split():
+ if "TYPE_CURRENCY" in self.element.get("class", "").split():
return FORMAT_CURRENCY_USD_SIMPLE
- if 'TYPE_INTEGER' in self.element.get('class', '').split():
- return '#,##0'
- if 'TYPE_PERCENTAGE' in self.element.get('class', '').split():
+ if "TYPE_INTEGER" in self.element.get("class", "").split():
+ return "#,##0"
+ if "TYPE_PERCENTAGE" in self.element.get("class", "").split():
return FORMAT_PERCENTAGE
- if 'TYPE_DATE' in self.element.get('class', '').split():
+ if "TYPE_DATE" in self.element.get("class", "").split():
return FORMAT_DATE_MMDDYYYY
if self.data_type() == cell.TYPE_NUMERIC:
try:
int(self.value)
except ValueError:
- return '#,##0.##'
+ return "#,##0.##"
else:
- return '#,##0'
+ return "#,##0"
def format(self, cell):
cell.style = self.style()
data_type = self.data_type()
if data_type:
- cell.data_type = data_type
\ No newline at end of file
+ cell.data_type = data_type
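To make the cascade concrete, a short sketch of how a css string flows through `style_string_to_dict` and `StyleDict.get_color` (import path assumed, values illustrative):

```python
from ppstructure.table.tablepyxl.style import StyleDict, style_string_to_dict

style = "border-left-style: solid; color: #0f0; font-weight: bold"
d = style_string_to_dict(style)
# {'border-left-style': 'solid', 'color': '#0f0', 'font-weight': 'bold'}

sd = StyleDict(d)
print(sd.get_color("color"))  # '00ff00': '#' stripped, 3-digit color doubled
```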
diff --git a/ppstructure/table/tablepyxl/tablepyxl.py b/ppstructure/table/tablepyxl/tablepyxl.py
index ba3cc0fc49..95b75d1ac7 100644
--- a/ppstructure/table/tablepyxl/tablepyxl.py
+++ b/ppstructure/table/tablepyxl/tablepyxl.py
@@ -16,10 +16,10 @@ def string_to_int(s):
def get_Tables(doc):
tree = html.fromstring(doc)
- comments = tree.xpath('//comment()')
+ comments = tree.xpath("//comment()")
for comment in comments:
comment.drop_tag()
- return [Table(table) for table in tree.xpath('//table')]
+ return [Table(table) for table in tree.xpath("//table")]
def write_rows(worksheet, elem, row, column=1):
@@ -40,20 +40,27 @@ def write_rows(worksheet, elem, row, column=1):
colspan = string_to_int(table_cell.element.get("colspan", "1"))
rowspan = string_to_int(table_cell.element.get("rowspan", "1"))
if rowspan > 1 or colspan > 1:
- worksheet.merge_cells(start_row=row, start_column=column,
- end_row=row + rowspan - 1, end_column=column + colspan - 1)
+ worksheet.merge_cells(
+ start_row=row,
+ start_column=column,
+ end_row=row + rowspan - 1,
+ end_column=column + colspan - 1,
+ )
cell.value = table_cell.value
table_cell.format(cell)
- min_width = table_cell.get_dimension('min-width')
- max_width = table_cell.get_dimension('max-width')
+ min_width = table_cell.get_dimension("min-width")
+ max_width = table_cell.get_dimension("max-width")
if colspan == 1:
# Initially, when iterating for the first time through the loop, the width of all the cells is None.
# As we start filling in contents, the initial width of the cell (which can be retrieved by:
# worksheet.column_dimensions[get_column_letter(column)].width) is equal to the width of the previous
# cell in the same column (i.e. width of A2 = width of A1)
- width = max(worksheet.column_dimensions[get_column_letter(column)].width or 0, len(table_cell.value) + 2)
+ width = max(
+ worksheet.column_dimensions[get_column_letter(column)].width or 0,
+ len(table_cell.value) + 2,
+ )
if max_width and width > max_width:
width = max_width
elif min_width and width < min_width:
@@ -70,7 +77,7 @@ def table_to_sheet(table, wb):
Takes a table and workbook and writes the table to a new sheet.
The sheet title will be the same as the table attribute name.
"""
- ws = wb.create_sheet(title=table.element.get('name'))
+ ws = wb.create_sheet(title=table.element.get("name"))
insert_table(table, ws, 1, 1)
@@ -84,7 +91,9 @@ def document_to_workbook(doc, wb=None, base_url=None):
wb = Workbook()
wb.remove(wb.active)
- inline_styles_doc = Premailer(doc, base_url=base_url, remove_classes=False).transform()
+ inline_styles_doc = Premailer(
+ doc, base_url=base_url, remove_classes=False
+ ).transform()
tables = get_Tables(inline_styles_doc)
for table in tables:
@@ -115,4 +124,4 @@ def insert_table_at_cell(table, cell):
"""
ws = cell.parent
column, row = cell.column, cell.row
- insert_table(table, ws, column, row)
\ No newline at end of file
+ insert_table(table, ws, column, row)
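End to end, this module can be driven as below (import path assumed; `premailer` and `openpyxl` must be installed). Each `<table>` becomes one worksheet, titled from its `name` attribute:

```python
from ppstructure.table.tablepyxl.tablepyxl import document_to_workbook

doc = """
<table name="Sheet1">
  <tbody>
    <tr><td>a</td><td>b</td></tr>
    <tr><td colspan="2">merged across two columns</td></tr>
  </tbody>
</table>
"""

wb = document_to_workbook(doc)
wb.save("tables.xlsx")
```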
diff --git a/ppstructure/utility.py b/ppstructure/utility.py
index aa4742e97a..bffc1fdda0 100644
--- a/ppstructure/utility.py
+++ b/ppstructure/utility.py
@@ -16,104 +16,116 @@
import PIL
from PIL import Image, ImageDraw, ImageFont
import numpy as np
-from tools.infer.utility import draw_ocr_box_txt, str2bool, str2int_tuple, init_args as infer_args
+from tools.infer.utility import (
+ draw_ocr_box_txt,
+ str2bool,
+ str2int_tuple,
+ init_args as infer_args,
+)
import math
+
def init_args():
parser = infer_args()
# params for output
- parser.add_argument("--output", type=str, default='./output')
+ parser.add_argument("--output", type=str, default="./output")
# params for table structure
parser.add_argument("--table_max_len", type=int, default=488)
- parser.add_argument("--table_algorithm", type=str, default='TableAttn')
+ parser.add_argument("--table_algorithm", type=str, default="TableAttn")
parser.add_argument("--table_model_dir", type=str)
- parser.add_argument(
- "--merge_no_span_structure", type=str2bool, default=True)
+ parser.add_argument("--merge_no_span_structure", type=str2bool, default=True)
parser.add_argument(
"--table_char_dict_path",
type=str,
- default="../ppocr/utils/dict/table_structure_dict_ch.txt")
+ default="../ppocr/utils/dict/table_structure_dict_ch.txt",
+ )
# params for layout
parser.add_argument("--layout_model_dir", type=str)
parser.add_argument(
"--layout_dict_path",
type=str,
- default="../ppocr/utils/dict/layout_dict/layout_publaynet_dict.txt")
+ default="../ppocr/utils/dict/layout_dict/layout_publaynet_dict.txt",
+ )
parser.add_argument(
- "--layout_score_threshold",
- type=float,
- default=0.5,
- help="Threshold of score.")
+ "--layout_score_threshold", type=float, default=0.5, help="Threshold of score."
+ )
parser.add_argument(
- "--layout_nms_threshold",
- type=float,
- default=0.5,
- help="Threshold of nms.")
+ "--layout_nms_threshold", type=float, default=0.5, help="Threshold of nms."
+ )
# params for kie
- parser.add_argument("--kie_algorithm", type=str, default='LayoutXLM')
+ parser.add_argument("--kie_algorithm", type=str, default="LayoutXLM")
parser.add_argument("--ser_model_dir", type=str)
parser.add_argument("--re_model_dir", type=str)
parser.add_argument("--use_visual_backbone", type=str2bool, default=True)
parser.add_argument(
- "--ser_dict_path",
- type=str,
- default="../train_data/XFUND/class_list_xfun.txt")
+ "--ser_dict_path", type=str, default="../train_data/XFUND/class_list_xfun.txt"
+ )
# need to be None or tb-yx
parser.add_argument("--ocr_order_method", type=str, default=None)
# params for inference
parser.add_argument(
"--mode",
type=str,
- choices=['structure', 'kie'],
- default='structure',
- help='structure and kie is supported')
+ choices=["structure", "kie"],
+ default="structure",
+ help="structure and kie is supported",
+ )
parser.add_argument(
"--image_orientation",
type=bool,
default=False,
- help='Whether to enable image orientation recognition')
+ help="Whether to enable image orientation recognition",
+ )
parser.add_argument(
"--layout",
type=str2bool,
default=True,
- help='Whether to enable layout analysis')
+ help="Whether to enable layout analysis",
+ )
parser.add_argument(
"--table",
type=str2bool,
default=True,
- help='In the forward, whether the table area uses table recognition')
+ help="In the forward, whether the table area uses table recognition",
+ )
parser.add_argument(
"--ocr",
type=str2bool,
default=True,
- help='In the forward, whether the non-table area is recognition by ocr')
+ help="In the forward, whether the non-table area is recognition by ocr",
+ )
# param for recovery
parser.add_argument(
"--recovery",
type=str2bool,
default=False,
- help='Whether to enable layout of recovery')
+ help="Whether to enable layout of recovery",
+ )
parser.add_argument(
"--use_pdf2docx_api",
type=str2bool,
default=False,
- help='Whether to use pdf2docx api')
+ help="Whether to use pdf2docx api",
+ )
parser.add_argument(
"--invert",
type=str2bool,
default=False,
- help='Whether to invert image before processing')
+ help="Whether to invert image before processing",
+ )
parser.add_argument(
"--binarize",
type=str2bool,
default=False,
- help='Whether to threshold binarize image before processing')
+ help="Whether to threshold binarize image before processing",
+ )
parser.add_argument(
"--alphacolor",
type=str2int_tuple,
default=(255, 255, 255),
- help='Replacement color for the alpha channel, if the latter is present; R,G,B integers')
+ help="Replacement color for the alpha channel, if the latter is present; R,G,B integers",
+ )
return parser
@@ -137,51 +149,62 @@ def draw_structure_result(image, result, font_path):
font = ImageFont.truetype(font_path, font_size, encoding="utf-8")
for region in result:
- if region['type'] not in catid2color:
- box_color = (random.randint(0, 255), random.randint(0, 255),
- random.randint(0, 255))
- catid2color[region['type']] = box_color
+ if region["type"] not in catid2color:
+ box_color = (
+ random.randint(0, 255),
+ random.randint(0, 255),
+ random.randint(0, 255),
+ )
+ catid2color[region["type"]] = box_color
else:
- box_color = catid2color[region['type']]
- box_layout = region['bbox']
+ box_color = catid2color[region["type"]]
+ box_layout = region["bbox"]
draw_layout.rectangle(
[(box_layout[0], box_layout[1]), (box_layout[2], box_layout[3])],
outline=box_color,
- width=3)
+ width=3,
+ )
- if int(PIL.__version__.split('.')[0]) < 10:
- text_w, text_h = font.getsize(region['type'])
+ if int(PIL.__version__.split(".")[0]) < 10:
+ text_w, text_h = font.getsize(region["type"])
else:
- left, top, right, bottom = font.getbbox(region['type'])
+ left, top, right, bottom = font.getbbox(region["type"])
text_w, text_h = right - left, bottom - top
draw_layout.rectangle(
- [(box_layout[0], box_layout[1]),
- (box_layout[0] + text_w, box_layout[1] + text_h)],
- fill=text_background_color)
+ [
+ (box_layout[0], box_layout[1]),
+ (box_layout[0] + text_w, box_layout[1] + text_h),
+ ],
+ fill=text_background_color,
+ )
draw_layout.text(
- (box_layout[0], box_layout[1]),
- region['type'],
- fill=text_color,
- font=font)
+ (box_layout[0], box_layout[1]), region["type"], fill=text_color, font=font
+ )
- if region['type'] == 'table':
+ if region["type"] == "table":
pass
else:
- for text_result in region['res']:
- boxes.append(np.array(text_result['text_region']))
- txts.append(text_result['text'])
- scores.append(text_result['confidence'])
+ for text_result in region["res"]:
+ boxes.append(np.array(text_result["text_region"]))
+ txts.append(text_result["text"])
+ scores.append(text_result["confidence"])
- if 'text_word_region' in text_result:
- for word_region in text_result['text_word_region']:
+ if "text_word_region" in text_result:
+ for word_region in text_result["text_word_region"]:
char_box = word_region
box_height = int(
- math.sqrt((char_box[0][0] - char_box[3][0])**2 + (
- char_box[0][1] - char_box[3][1])**2))
+ math.sqrt(
+ (char_box[0][0] - char_box[3][0]) ** 2
+ + (char_box[0][1] - char_box[3][1]) ** 2
+ )
+ )
box_width = int(
- math.sqrt((char_box[0][0] - char_box[1][0])**2 + (
- char_box[0][1] - char_box[1][1])**2))
+ math.sqrt(
+ (char_box[0][0] - char_box[1][0]) ** 2
+ + (char_box[0][1] - char_box[1][1]) ** 2
+ )
+ )
if box_height == 0 or box_width == 0:
continue
boxes.append(word_region)
@@ -189,11 +212,13 @@ def draw_structure_result(image, result, font_path):
scores.append(1.0)
im_show = draw_ocr_box_txt(
- img_layout, boxes, txts, scores, font_path=font_path, drop_score=0)
+ img_layout, boxes, txts, scores, font_path=font_path, drop_score=0
+ )
return im_show
+
def cal_ocr_word_box(rec_str, box, rec_word_info):
- ''' Calculate the detection frame for each word based on the results of recognition and detection of ocr'''
+ """Calculate the detection frame for each word based on the results of recognition and detection of ocr"""
col_num, word_list, word_col_list, state_list = rec_word_info
box = box.tolist()
@@ -209,7 +234,7 @@ def cal_ocr_word_box(rec_str, box, rec_word_info):
cn_width_list = []
cn_col_list = []
for word, word_col, state in zip(word_list, word_col_list, state_list):
- if state == 'cn':
+ if state == "cn":
if len(word_col) != 1:
char_seq_length = (word_col[-1] - word_col[0] + 1) * cell_width
char_width = char_seq_length / (len(word_col) - 1)
@@ -219,8 +244,12 @@ def cal_ocr_word_box(rec_str, box, rec_word_info):
else:
cell_x_start = bbox_x_start + int(word_col[0] * cell_width)
cell_x_end = bbox_x_start + int((word_col[-1] + 1) * cell_width)
- cell = ((cell_x_start, bbox_y_start), (cell_x_end, bbox_y_start),
- (cell_x_end, bbox_y_end), (cell_x_start, bbox_y_end))
+ cell = (
+ (cell_x_start, bbox_y_start),
+ (cell_x_end, bbox_y_start),
+ (cell_x_end, bbox_y_end),
+ (cell_x_start, bbox_y_end),
+ )
word_box_list.append(cell)
word_box_content_list.append("".join(word))
if len(cn_col_list) != 0:
@@ -230,14 +259,17 @@ def cal_ocr_word_box(rec_str, box, rec_word_info):
avg_char_width = (bbox_x_end - bbox_x_start) / len(rec_str)
for center_idx in cn_col_list:
center_x = (center_idx + 0.5) * cell_width
- cell_x_start = max(int(center_x - avg_char_width / 2),
- 0) + bbox_x_start
- cell_x_end = min(
- int(center_x + avg_char_width / 2), bbox_x_end -
- bbox_x_start) + bbox_x_start
- cell = ((cell_x_start, bbox_y_start), (cell_x_end, bbox_y_start),
- (cell_x_end, bbox_y_end), (cell_x_start, bbox_y_end))
+ cell_x_start = max(int(center_x - avg_char_width / 2), 0) + bbox_x_start
+ cell_x_end = (
+ min(int(center_x + avg_char_width / 2), bbox_x_end - bbox_x_start)
+ + bbox_x_start
+ )
+ cell = (
+ (cell_x_start, bbox_y_start),
+ (cell_x_end, bbox_y_start),
+ (cell_x_end, bbox_y_end),
+ (cell_x_start, bbox_y_end),
+ )
word_box_list.append(cell)
return word_box_content_list, word_box_list
-
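The arithmetic in `cal_ocr_word_box` splits the detected box into `col_num` equal slices and maps a word's column span back to pixel coordinates; the same computation in isolation, with illustrative numbers:

```python
bbox_x_start, bbox_x_end, col_num = 100, 300, 10
cell_width = (bbox_x_end - bbox_x_start) / col_num   # 20.0 px per column

word_col = [2, 3, 4]                                  # columns the word occupies
cell_x_start = bbox_x_start + int(word_col[0] * cell_width)        # 140
cell_x_end = bbox_x_start + int((word_col[-1] + 1) * cell_width)   # 200
print(cell_x_start, cell_x_end)
```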
diff --git a/setup.py b/setup.py
index 703d61244e..c70113f725 100644
--- a/setup.py
+++ b/setup.py
@@ -18,15 +18,20 @@
import subprocess
# get version by matching, so we will not need to set up a complex env in github action
-p = subprocess.Popen("grep ^VERSION ./paddleocr.py | cut -d\\' -f 2", stdout=subprocess.PIPE,
- stderr=subprocess.PIPE, shell=True)
+p = subprocess.Popen(
+ "grep ^VERSION ./paddleocr.py | cut -d\\' -f 2",
+ stdout=subprocess.PIPE,
+ stderr=subprocess.PIPE,
+ shell=True,
+)
raw_VERSION, _ = p.communicate()
VERSION = raw_VERSION.decode().strip()
+
def load_requirements(file_list=None):
if file_list is None:
- file_list = ['requirements.txt']
- if isinstance(file_list,str):
+ file_list = ["requirements.txt"]
+ if isinstance(file_list, str):
file_list = [file_list]
requirements = []
for file in file_list:
@@ -36,36 +41,41 @@ def load_requirements(file_list=None):
def readme():
- with open('doc/doc_en/whl_en.md', encoding="utf-8-sig") as f:
+ with open("doc/doc_en/whl_en.md", encoding="utf-8-sig") as f:
README = f.read()
return README
setup(
- name='paddleocr',
- packages=['paddleocr'],
- package_dir={'paddleocr': ''},
+ name="paddleocr",
+ packages=["paddleocr"],
+ package_dir={"paddleocr": ""},
include_package_data=True,
entry_points={"console_scripts": ["paddleocr= paddleocr.paddleocr:main"]},
version=VERSION,
- install_requires=load_requirements(['requirements.txt', 'ppstructure/recovery/requirements.txt']),
- license='Apache License 2.0',
- description='Awesome OCR toolkits based on PaddlePaddle (8.6M ultra-lightweight pre-trained model, support training and deployment among server, mobile, embedded and IoT devices',
+ install_requires=load_requirements(
+ ["requirements.txt", "ppstructure/recovery/requirements.txt"]
+ ),
+ license="Apache License 2.0",
+ description="Awesome OCR toolkits based on PaddlePaddle (8.6M ultra-lightweight pre-trained model, support training and deployment among server, mobile, embedded and IoT devices",
long_description=readme(),
- long_description_content_type='text/markdown',
- url='https://github.com/PaddlePaddle/PaddleOCR',
- download_url='https://github.com/PaddlePaddle/PaddleOCR.git',
+ long_description_content_type="text/markdown",
+ url="https://github.com/PaddlePaddle/PaddleOCR",
+ download_url="https://github.com/PaddlePaddle/PaddleOCR.git",
keywords=[
- 'ocr textdetection textrecognition paddleocr crnn east star-net rosetta ocrlite db chineseocr chinesetextdetection chinesetextrecognition'
+ "ocr textdetection textrecognition paddleocr crnn east star-net rosetta ocrlite db chineseocr chinesetextdetection chinesetextrecognition"
],
classifiers=[
- 'Intended Audience :: Developers', 'Operating System :: OS Independent',
- 'Natural Language :: Chinese (Simplified)',
- 'Programming Language :: Python :: 3',
- 'Programming Language :: Python :: 3.2',
- 'Programming Language :: Python :: 3.3',
- 'Programming Language :: Python :: 3.4',
- 'Programming Language :: Python :: 3.5',
- 'Programming Language :: Python :: 3.6',
- 'Programming Language :: Python :: 3.7', 'Topic :: Utilities'
- ], )
+ "Intended Audience :: Developers",
+ "Operating System :: OS Independent",
+ "Natural Language :: Chinese (Simplified)",
+ "Programming Language :: Python :: 3",
+ "Programming Language :: Python :: 3.2",
+ "Programming Language :: Python :: 3.3",
+ "Programming Language :: Python :: 3.4",
+ "Programming Language :: Python :: 3.5",
+ "Programming Language :: Python :: 3.6",
+ "Programming Language :: Python :: 3.7",
+ "Topic :: Utilities",
+ ],
+)
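For reference, a pure-Python equivalent of the grep/cut pipeline above (a sketch, not what setup.py actually runs), matching lines of the form VERSION = '2.x.y':

```python
import re

with open("./paddleocr.py", encoding="utf-8") as f:
    match = re.search(r"^VERSION = '([^']+)'", f.read(), flags=re.M)

VERSION = match.group(1) if match else ""
```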
diff --git a/test_tipc/compare_results.py b/test_tipc/compare_results.py
index e28410ed6c..96a1f20e03 100644
--- a/test_tipc/compare_results.py
+++ b/test_tipc/compare_results.py
@@ -24,11 +24,12 @@ def parse_args():
def run_shell_command(cmd):
p = subprocess.Popen(
- cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE, shell=True)
+ cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE, shell=True
+ )
out, err = p.communicate()
if p.returncode == 0:
- return out.decode('utf-8')
+ return out.decode("utf-8")
else:
return None
@@ -57,13 +58,13 @@ def parser_results_from_log_by_name(log_path, names_list):
def load_gt_from_file(gt_file):
if not os.path.exists(gt_file):
raise ValueError("The log file {} does not exists!".format(gt_file))
- with open(gt_file, 'r') as f:
+ with open(gt_file, "r") as f:
data = f.readlines()
f.close()
parser_gt = {}
for line in data:
image_name, result = line.strip("\n").split("\t")
- image_name = image_name.split('/')[-1]
+ image_name = image_name.split("/")[-1]
try:
result = json.loads(result)
except:
@@ -103,7 +104,8 @@ def collect_predict_from_logs(log_path, key_list):
def testing_assert_allclose(dict_x, dict_y, atol=1e-7, rtol=1e-7):
for k in dict_x:
np.testing.assert_allclose(
- np.array(dict_x[k]), np.array(dict_y[k]), atol=atol, rtol=rtol)
+ np.array(dict_x[k]), np.array(dict_y[k]), atol=atol, rtol=rtol
+ )
if __name__ == "__main__":
@@ -128,13 +130,16 @@ def testing_assert_allclose(dict_x, dict_y, atol=1e-7, rtol=1e-7):
pred_dict = pred_collection[filename]
try:
- testing_assert_allclose(
- gt_dict, pred_dict, atol=args.atol, rtol=args.rtol)
+ testing_assert_allclose(gt_dict, pred_dict, atol=args.atol, rtol=args.rtol)
print(
- "Assert allclose passed! The results of {} and {} are consistent!".
- format(filename, gt_filename))
+ "Assert allclose passed! The results of {} and {} are consistent!".format(
+ filename, gt_filename
+ )
+ )
except Exception as E:
print(E)
raise ValueError(
- "The results of {} and the results of {} are inconsistent!".
- format(filename, gt_filename))
+ "The results of {} and the results of {} are inconsistent!".format(
+ filename, gt_filename
+ )
+ )
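What the tolerance check boils down to, on a toy dict (the key name is illustrative):

```python
import numpy as np

gt = {"det_boxes": [[0.0, 1.0], [2.0, 3.0]]}
pred = {"det_boxes": [[0.0, 1.0 + 1e-8], [2.0, 3.0]]}

for k in gt:
    # passes: |actual - desired| <= atol + rtol * |desired| holds elementwise
    np.testing.assert_allclose(np.array(gt[k]), np.array(pred[k]),
                               atol=1e-7, rtol=1e-7)
print("consistent")
```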
diff --git a/test_tipc/docs/termux_for_android.md b/test_tipc/docs/termux_for_android.md
index 73ecbb2e93..3ae9d18e4f 100644
--- a/test_tipc/docs/termux_for_android.md
+++ b/test_tipc/docs/termux_for_android.md
@@ -125,4 +125,3 @@ scp -P 8022 test.txt u0_a374@172.24.162.117:/home/storage/test
## 3. More tutorials
This tutorial covers the basic Termux setup. For more on how to use Termux, see the [Termux advanced terminal installation and configuration tutorial](https://www.sqlsec.com/2018/05/termux.html).
-
diff --git a/test_tipc/docs/test_inference_js.md b/test_tipc/docs/test_inference_js.md
index c0b7d653ae..ae5dea3004 100644
--- a/test_tipc/docs/test_inference_js.md
+++ b/test_tipc/docs/test_inference_js.md
@@ -44,7 +44,7 @@ bash test_tipc/test_inference_js.sh
2. Start the Jest test service, drive Chrome through the jest-puppeteer plugin, and load the @paddlejs-models/ocr script to run the inference pipeline
3. The test case compares the text recognized from the original image with the expected text (expect.json); passing is judged by two criteria:
* the recognition result matches the expected result character by character, with an error of no more than **10 characters**;
- * the character content of each recognized text box reaches a similarity of at least 0.9 against the expected result (an exact match has similarity 1).
+ * the character content of each recognized text box reaches a similarity of at least 0.9 against the expected result (an exact match has similarity 1)。
The test passes only when both of the above criteria are met; a pass is displayed as follows:
diff --git a/test_tipc/supplementary/config.py b/test_tipc/supplementary/config.py
index 72a99c70af..22445be0c7 100644
--- a/test_tipc/supplementary/config.py
+++ b/test_tipc/supplementary/config.py
@@ -14,23 +14,20 @@
class ArgsParser(ArgumentParser):
def __init__(self):
- super(ArgsParser, self).__init__(
- formatter_class=RawDescriptionHelpFormatter)
+ super(ArgsParser, self).__init__(formatter_class=RawDescriptionHelpFormatter)
self.add_argument("-c", "--config", help="configuration file to use")
+ self.add_argument("-o", "--opt", nargs="+", help="set configuration options")
self.add_argument(
- "-o", "--opt", nargs='+', help="set configuration options")
- self.add_argument(
- '-p',
- '--profiler_options',
+ "-p",
+ "--profiler_options",
type=str,
default=None,
- help='The option of profiler, which should be in format \"key1=value1;key2=value2;key3=value3\".'
+ help='The option of profiler, which should be in format "key1=value1;key2=value2;key3=value3".',
)
def parse_args(self, argv=None):
args = super(ArgsParser, self).parse_args(argv)
- assert args.config is not None, \
- "Please specify --config=configure_file_path."
+ assert args.config is not None, "Please specify --config=configure_file_path."
args.opt = self._parse_opt(args.opt)
return args
@@ -40,7 +37,7 @@ def _parse_opt(self, opts):
return config
for s in opts:
s = s.strip()
- k, v = s.split('=')
+ k, v = s.split("=")
config[k] = yaml.load(v, Loader=yaml.Loader)
return config
@@ -60,7 +57,11 @@ def __getattr__(self, key):
global_config = AttrDict()
-default_config = {'Global': {'debug': False, }}
+default_config = {
+ "Global": {
+ "debug": False,
+ }
+}
def load_config(file_path):
@@ -72,8 +73,8 @@ def load_config(file_path):
"""
merge_config(default_config)
_, ext = os.path.splitext(file_path)
- assert ext in ['.yml', '.yaml'], "only support yaml files for now"
- merge_config(yaml.load(open(file_path, 'rb'), Loader=yaml.Loader))
+ assert ext in [".yml", ".yaml"], "only support yaml files for now"
+ merge_config(yaml.load(open(file_path, "rb"), Loader=yaml.Loader))
return global_config
@@ -91,11 +92,12 @@ def merge_config(config):
else:
global_config[key] = value
else:
- sub_keys = key.split('.')
+ sub_keys = key.split(".")
assert (
sub_keys[0] in global_config
), "the sub_keys can only be one of global_config: {}, but get: {}, please check your running command".format(
- global_config.keys(), sub_keys[0])
+ global_config.keys(), sub_keys[0]
+ )
cur = global_config[sub_keys[0]]
for idx, sub_key in enumerate(sub_keys[1:]):
if idx == len(sub_keys) - 2:
@@ -114,18 +116,17 @@ def preprocess(is_train=False):
if is_train:
# save_config
- save_model_dir = config['save_model_dir']
+ save_model_dir = config["save_model_dir"]
os.makedirs(save_model_dir, exist_ok=True)
- with open(os.path.join(save_model_dir, 'config.yml'), 'w') as f:
- yaml.dump(
- dict(config), f, default_flow_style=False, sort_keys=False)
- log_file = '{}/train.log'.format(save_model_dir)
+ with open(os.path.join(save_model_dir, "config.yml"), "w") as f:
+ yaml.dump(dict(config), f, default_flow_style=False, sort_keys=False)
+ log_file = "{}/train.log".format(save_model_dir)
else:
log_file = None
logger = get_logger(log_file=log_file)
# check if set use_gpu=True in paddlepaddle cpu version
- use_gpu = config['use_gpu']
+ use_gpu = config["use_gpu"]
print_dict(config, logger)
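The `-o` override path in `merge_config` walks dotted keys into nested dicts; the core of that walk, extracted as a standalone sketch (it mirrors only the nested branch shown above):

```python
def set_by_dotted_key(cfg, key, value):
    """Mirror of merge_config's sub-key walk for keys like 'Global.debug'."""
    sub_keys = key.split(".")
    cur = cfg[sub_keys[0]]
    for idx, sub_key in enumerate(sub_keys[1:]):
        if idx == len(sub_keys) - 2:
            cur[sub_key] = value
        else:
            cur = cur[sub_key]

cfg = {"Global": {"debug": False, "epoch_num": 10}}
set_by_dotted_key(cfg, "Global.debug", True)
print(cfg)  # {'Global': {'debug': True, 'epoch_num': 10}}
```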
diff --git a/test_tipc/supplementary/custom_op/custom_relu_op.cc b/test_tipc/supplementary/custom_op/custom_relu_op.cc
index 97002a9118..86d8380c27 100644
--- a/test_tipc/supplementary/custom_op/custom_relu_op.cc
+++ b/test_tipc/supplementary/custom_op/custom_relu_op.cc
@@ -12,16 +12,15 @@
// See the License for the specific language governing permissions and
// limitations under the License.
-
-// reference from : https://github.com/PaddlePaddle/Paddle-Inference-Demo/blob/master/python/custom-operator/custom_relu_op.cc
+// reference from :
+// https://github.com/PaddlePaddle/Paddle-Inference-Demo/blob/master/python/custom-operator/custom_relu_op.cc
#include <iostream>
#include <vector>
#include "paddle/extension.h"
template <typename data_t>
-void relu_cpu_forward_kernel(const data_t* x_data,
- data_t* out_data,
+void relu_cpu_forward_kernel(const data_t *x_data, data_t *out_data,
int64_t x_numel) {
for (int i = 0; i < x_numel; ++i) {
out_data[i] = std::max(static_cast<data_t>(0.), x_data[i]);
@@ -29,9 +28,8 @@ void relu_cpu_forward_kernel(const data_t* x_data,
}
template <typename data_t>
-void relu_cpu_backward_kernel(const data_t* grad_out_data,
- const data_t* out_data,
- data_t* grad_x_data,
+void relu_cpu_backward_kernel(const data_t *grad_out_data,
+ const data_t *out_data, data_t *grad_x_data,
int64_t out_numel) {
for (int i = 0; i < out_numel; ++i) {
grad_x_data[i] =
@@ -39,7 +37,7 @@ void relu_cpu_backward_kernel(const data_t* grad_out_data,
}
}
-std::vector<paddle::Tensor> relu_cpu_forward(const paddle::Tensor& x) {
+std::vector<paddle::Tensor> relu_cpu_forward(const paddle::Tensor &x) {
auto out = paddle::Tensor(paddle::PlaceType::kCPU);
out.reshape(x.shape());
@@ -52,16 +50,15 @@ std::vector relu_cpu_forward(const paddle::Tensor& x) {
return {out};
}
-std::vector<paddle::Tensor> relu_cpu_backward(const paddle::Tensor& x,
- const paddle::Tensor& out,
- const paddle::Tensor& grad_out) {
+std::vector<paddle::Tensor> relu_cpu_backward(const paddle::Tensor &x,
+ const paddle::Tensor &out,
+ const paddle::Tensor &grad_out) {
auto grad_x = paddle::Tensor(paddle::PlaceType::kCPU);
grad_x.reshape(x.shape());
PD_DISPATCH_FLOATING_TYPES(out.type(), "relu_cpu_backward", ([&] {
relu_cpu_backward_kernel<data_t>(
- grad_out.data<data_t>(),
- out.data<data_t>(),
+ grad_out.data<data_t>(), out.data<data_t>(),
grad_x.mutable_data<data_t>(x.place()),
out.size());
}));
@@ -69,12 +66,12 @@ std::vector relu_cpu_backward(const paddle::Tensor& x,
return {grad_x};
}
-std::vector<paddle::Tensor> relu_cuda_forward(const paddle::Tensor& x);
-std::vector<paddle::Tensor> relu_cuda_backward(const paddle::Tensor& x,
- const paddle::Tensor& out,
- const paddle::Tensor& grad_out);
+std::vector<paddle::Tensor> relu_cuda_forward(const paddle::Tensor &x);
+std::vector<paddle::Tensor> relu_cuda_backward(const paddle::Tensor &x,
+ const paddle::Tensor &out,
+ const paddle::Tensor &grad_out);
-std::vector<paddle::Tensor> ReluForward(const paddle::Tensor& x) {
+std::vector<paddle::Tensor> ReluForward(const paddle::Tensor &x) {
// TODO(chenweihang): Check Input
if (x.place() == paddle::PlaceType::kCPU) {
return relu_cpu_forward(x);
@@ -85,9 +82,9 @@ std::vector ReluForward(const paddle::Tensor& x) {
}
}
-std::vector<paddle::Tensor> ReluBackward(const paddle::Tensor& x,
- const paddle::Tensor& out,
- const paddle::Tensor& grad_out) {
+std::vector<paddle::Tensor> ReluBackward(const paddle::Tensor &x,
+ const paddle::Tensor &out,
+ const paddle::Tensor &grad_out) {
// TODO(chenweihang): Check Input
if (x.place() == paddle::PlaceType::kCPU) {
return relu_cpu_backward(x, out, grad_out);
diff --git a/test_tipc/supplementary/custom_op/custom_relu_op.cu b/test_tipc/supplementary/custom_op/custom_relu_op.cu
index 9b953a33cc..fc2792e614 100644
--- a/test_tipc/supplementary/custom_op/custom_relu_op.cu
+++ b/test_tipc/supplementary/custom_op/custom_relu_op.cu
@@ -12,14 +12,13 @@
// See the License for the specific language governing permissions and
// limitations under the License.
-
-// reference https://github.com/PaddlePaddle/Paddle-Inference-Demo/blob/master/python/custom-operator/custom_relu_op.cu
+// reference
+// https://github.com/PaddlePaddle/Paddle-Inference-Demo/blob/master/python/custom-operator/custom_relu_op.cu
#include "paddle/extension.h"
template <typename data_t>
-__global__ void relu_cuda_forward_kernel(const data_t* x,
- data_t* y,
+__global__ void relu_cuda_forward_kernel(const data_t *x, data_t *y,
const int num) {
int gid = blockIdx.x * blockDim.x + threadIdx.x;
for (int i = gid; i < num; i += blockDim.x * gridDim.x) {
@@ -28,17 +27,15 @@ __global__ void relu_cuda_forward_kernel(const data_t* x,
}
template <typename data_t>
-__global__ void relu_cuda_backward_kernel(const data_t* dy,
- const data_t* y,
- data_t* dx,
- const int num) {
+__global__ void relu_cuda_backward_kernel(const data_t *dy, const data_t *y,
+ data_t *dx, const int num) {
int gid = blockIdx.x * blockDim.x + threadIdx.x;
for (int i = gid; i < num; i += blockDim.x * gridDim.x) {
dx[i] = dy[i] * (y[i] > 0 ? 1. : 0.);
}
}
-std::vector<paddle::Tensor> relu_cuda_forward(const paddle::Tensor& x) {
+std::vector<paddle::Tensor> relu_cuda_forward(const paddle::Tensor &x) {
auto out = paddle::Tensor(paddle::PlaceType::kGPU);
out.reshape(x.shape());
@@ -54,9 +51,9 @@ std::vector<paddle::Tensor> relu_cuda_forward(const paddle::Tensor& x) {
return {out};
}
-std::vector<paddle::Tensor> relu_cuda_backward(const paddle::Tensor& x,
-                                               const paddle::Tensor& out,
-                                               const paddle::Tensor& grad_out) {
+std::vector<paddle::Tensor> relu_cuda_backward(const paddle::Tensor &x,
+                                               const paddle::Tensor &out,
+                                               const paddle::Tensor &grad_out) {
auto grad_x = paddle::Tensor(paddle::PlaceType::kGPU);
grad_x.reshape(x.shape());
@@ -66,10 +63,8 @@ std::vector<paddle::Tensor> relu_cuda_backward(const paddle::Tensor& x,
PD_DISPATCH_FLOATING_TYPES(
out.type(), "relu_cuda_backward_kernel", ([&] {
        relu_cuda_backward_kernel<data_t><<<grid, block, 0, x.stream()>>>(
-            grad_out.data<data_t>(),
-            out.data<data_t>(),
-            grad_x.mutable_data<data_t>(x.place()),
-            numel);
+            grad_out.data<data_t>(), out.data<data_t>(),
+            grad_x.mutable_data<data_t>(x.place()), numel);
}));
return {grad_x};
diff --git a/test_tipc/supplementary/custom_op/test.py b/test_tipc/supplementary/custom_op/test.py
index 8b7f303dd6..df8d939bc1 100644
--- a/test_tipc/supplementary/custom_op/test.py
+++ b/test_tipc/supplementary/custom_op/test.py
@@ -11,17 +11,18 @@
# jit compile custom op
custom_ops = load(
- name="custom_jit_ops", sources=["custom_relu_op.cc", "custom_relu_op.cu"])
+ name="custom_jit_ops", sources=["custom_relu_op.cc", "custom_relu_op.cu"]
+)
class LeNet(nn.Layer):
def __init__(self):
super(LeNet, self).__init__()
self.conv1 = nn.Conv2D(
- in_channels=1, out_channels=6, kernel_size=5, stride=1, padding=2)
+ in_channels=1, out_channels=6, kernel_size=5, stride=1, padding=2
+ )
self.max_pool1 = nn.MaxPool2D(kernel_size=2, stride=2)
- self.conv2 = nn.Conv2D(
- in_channels=6, out_channels=16, kernel_size=5, stride=1)
+ self.conv2 = nn.Conv2D(in_channels=6, out_channels=16, kernel_size=5, stride=1)
self.max_pool2 = nn.MaxPool2D(kernel_size=2, stride=2)
self.linear1 = nn.Linear(in_features=16 * 5 * 5, out_features=120)
self.linear2 = nn.Linear(in_features=120, out_features=84)
@@ -52,14 +53,11 @@ def forward(self, x):
opt = paddle.optimizer.Adam(learning_rate=0.001, parameters=net.parameters())
# data loader
-transform = Compose([Normalize(mean=[127.5], std=[127.5], data_format='CHW')])
-train_dataset = paddle.vision.datasets.MNIST(mode='train', transform=transform)
+transform = Compose([Normalize(mean=[127.5], std=[127.5], data_format="CHW")])
+train_dataset = paddle.vision.datasets.MNIST(mode="train", transform=transform)
train_loader = paddle.io.DataLoader(
- train_dataset,
- batch_size=BATCH_SIZE,
- shuffle=True,
- drop_last=True,
- num_workers=2)
+ train_dataset, batch_size=BATCH_SIZE, shuffle=True, drop_last=True, num_workers=2
+)
# train
for epoch_id in range(EPOCH_NUM):
@@ -69,8 +67,11 @@ def forward(self, x):
loss.backward()
if batch_id % 300 == 0:
- print("Epoch {} batch {}: loss = {}".format(epoch_id, batch_id,
- np.mean(loss.numpy())))
+ print(
+ "Epoch {} batch {}: loss = {}".format(
+ epoch_id, batch_id, np.mean(loss.numpy())
+ )
+ )
opt.step()
opt.clear_grad()
diff --git a/test_tipc/supplementary/data.py b/test_tipc/supplementary/data.py
index 2770a9a42c..5fe538d6b1 100644
--- a/test_tipc/supplementary/data.py
+++ b/test_tipc/supplementary/data.py
@@ -6,7 +6,7 @@
def transform(data, ops=None):
- """ transform """
+ """transform"""
if ops is None:
ops = []
for op in ops:
@@ -22,11 +22,10 @@ def create_operators(op_param_list, global_config=None):
Args:
params(list): a dict list, used to create some operators
"""
- assert isinstance(op_param_list, list), ('operator config should be a list')
+ assert isinstance(op_param_list, list), "operator config should be a list"
ops = []
for operator in op_param_list:
- assert isinstance(operator,
- dict) and len(operator) == 1, "yaml format error"
+ assert isinstance(operator, dict) and len(operator) == 1, "yaml format error"
op_name = list(operator)[0]
param = {} if operator[op_name] is None else operator[op_name]
if global_config is not None:
@@ -37,84 +36,85 @@ def create_operators(op_param_list, global_config=None):
class DecodeImage(object):
- """ decode image """
+ """decode image"""
- def __init__(self, img_mode='RGB', channel_first=False, **kwargs):
+ def __init__(self, img_mode="RGB", channel_first=False, **kwargs):
self.img_mode = img_mode
self.channel_first = channel_first
def __call__(self, data):
- img = data['image']
+ img = data["image"]
if six.PY2:
- assert type(img) is str and len(
- img) > 0, "invalid input 'img' in DecodeImage"
+ assert (
+ type(img) is str and len(img) > 0
+ ), "invalid input 'img' in DecodeImage"
else:
- assert type(img) is bytes and len(
- img) > 0, "invalid input 'img' in DecodeImage"
- img = np.frombuffer(img, dtype='uint8')
+ assert (
+ type(img) is bytes and len(img) > 0
+ ), "invalid input 'img' in DecodeImage"
+ img = np.frombuffer(img, dtype="uint8")
img = cv2.imdecode(img, 1)
if img is None:
return None
- if self.img_mode == 'GRAY':
+ if self.img_mode == "GRAY":
img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
- elif self.img_mode == 'RGB':
- assert img.shape[2] == 3, 'invalid shape of image[%s]' % (img.shape)
+ elif self.img_mode == "RGB":
+ assert img.shape[2] == 3, "invalid shape of image[%s]" % (img.shape)
img = img[:, :, ::-1]
if self.channel_first:
img = img.transpose((2, 0, 1))
- data['image'] = img
- data['src_image'] = img
+ data["image"] = img
+ data["src_image"] = img
return data
class NormalizeImage(object):
- """ normalize image such as substract mean, divide std
- """
+ """normalize image such as substract mean, divide std"""
- def __init__(self, scale=None, mean=None, std=None, order='chw', **kwargs):
+ def __init__(self, scale=None, mean=None, std=None, order="chw", **kwargs):
if isinstance(scale, str):
scale = eval(scale)
self.scale = np.float32(scale if scale is not None else 1.0 / 255.0)
mean = mean if mean is not None else [0.485, 0.456, 0.406]
std = std if std is not None else [0.229, 0.224, 0.225]
- shape = (3, 1, 1) if order == 'chw' else (1, 1, 3)
- self.mean = np.array(mean).reshape(shape).astype('float32')
- self.std = np.array(std).reshape(shape).astype('float32')
+ shape = (3, 1, 1) if order == "chw" else (1, 1, 3)
+ self.mean = np.array(mean).reshape(shape).astype("float32")
+ self.std = np.array(std).reshape(shape).astype("float32")
def __call__(self, data):
- img = data['image']
+ img = data["image"]
from PIL import Image
+
if isinstance(img, Image.Image):
img = np.array(img)
- assert isinstance(img,
- np.ndarray), "invalid input 'img' in NormalizeImage"
- data['image'] = (
- img.astype('float32') * self.scale - self.mean) / self.std
+ assert isinstance(img, np.ndarray), "invalid input 'img' in NormalizeImage"
+ data["image"] = (img.astype("float32") * self.scale - self.mean) / self.std
return data
class ToCHWImage(object):
- """ convert hwc image to chw image
- """
+ """convert hwc image to chw image"""
def __init__(self, **kwargs):
pass
def __call__(self, data):
- img = data['image']
+ img = data["image"]
from PIL import Image
+
if isinstance(img, Image.Image):
img = np.array(img)
- data['image'] = img.transpose((2, 0, 1))
+ data["image"] = img.transpose((2, 0, 1))
- src_img = data['src_image']
+ src_img = data["src_image"]
from PIL import Image
+
        if isinstance(src_img, Image.Image):
src_img = np.array(src_img)
- data['src_image'] = img.transpose((2, 0, 1))
+ data["src_image"] = img.transpose((2, 0, 1))
return data
@@ -124,11 +124,11 @@ def __init__(self, config, mode, logger, seed=None):
self.logger = logger
self.mode = mode.lower()
- data_dir = config['Train']['data_dir']
+ data_dir = config["Train"]["data_dir"]
imgs_list = self.get_image_list(data_dir)
- self.ops = create_operators(cfg['transforms'], None)
+        self.ops = create_operators(config["transforms"], None)
def get_image_list(self, img_dir):
imgs = glob.glob(os.path.join(img_dir, "*.png"))
diff --git a/test_tipc/supplementary/data_loader.py b/test_tipc/supplementary/data_loader.py
index f0245dd27c..2e40662071 100644
--- a/test_tipc/supplementary/data_loader.py
+++ b/test_tipc/supplementary/data_loader.py
@@ -7,8 +7,7 @@
def term_mp(sig_num, frame):
- """ kill all child processes
- """
+ """kill all child processes"""
pid = os.getpid()
pgid = os.getpgid(os.getpid())
print("main proc {} exit, kill process group " "{}".format(pid, pgid))
@@ -16,25 +15,20 @@ def term_mp(sig_num, frame):
return
-def build_dataloader(mode,
- batch_size=4,
- seed=None,
- num_workers=0,
- device='gpu:0'):
-
- normalize = Normalize(
- mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], data_format='HWC')
+def build_dataloader(mode, batch_size=4, seed=None, num_workers=0, device="gpu:0"):
+ normalize = Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], data_format="HWC")
if mode.lower() == "train":
dataset = Cifar100(mode=mode, transform=normalize)
- elif mode.lower() in ["test", 'valid', 'eval']:
+ elif mode.lower() in ["test", "valid", "eval"]:
dataset = Cifar100(mode="test", transform=normalize)
else:
raise ValueError(f"{mode} should be one of ['train', 'test']")
# define batch sampler
batch_sampler = DistributedBatchSampler(
- dataset=dataset, batch_size=batch_size, shuffle=False, drop_last=True)
+ dataset=dataset, batch_size=batch_size, shuffle=False, drop_last=True
+ )
data_loader = DataLoader(
dataset=dataset,
@@ -42,7 +36,8 @@ def build_dataloader(mode,
places=device,
num_workers=num_workers,
return_list=True,
- use_shared_memory=False)
+ use_shared_memory=False,
+ )
# support exit using ctrl+c
signal.signal(signal.SIGINT, term_mp)
diff --git a/test_tipc/supplementary/load_cifar.py b/test_tipc/supplementary/load_cifar.py
index 6646dca390..35d0ffaffc 100644
--- a/test_tipc/supplementary/load_cifar.py
+++ b/test_tipc/supplementary/load_cifar.py
@@ -4,12 +4,12 @@
def load_CIFAR_batch(filename):
- """ load single batch of cifar """
- with open(filename, 'rb') as f:
- datadict = p.load(f, encoding='bytes')
+ """load single batch of cifar"""
+ with open(filename, "rb") as f:
+ datadict = p.load(f, encoding="bytes")
        # unpack the data, which is stored as a dict
- X = datadict[b'data']
- Y = datadict[b'fine_labels']
+ X = datadict[b"data"]
+ Y = datadict[b"fine_labels"]
try:
X = X.reshape(10000, 3, 32, 32)
except:
@@ -22,9 +22,9 @@ def load_CIFAR_batch(filename):
if __name__ == "__main__":
mode = "train"
imgX, imgY = load_CIFAR_batch(f"./cifar-100-python/{mode}")
- with open(f'./cifar-100-python/{mode}_imgs/img_label.txt', 'a+') as f:
+ with open(f"./cifar-100-python/{mode}_imgs/img_label.txt", "a+") as f:
for i in range(imgY.shape[0]):
- f.write('img' + str(i) + ' ' + str(imgY[i]) + '\n')
+ f.write("img" + str(i) + " " + str(imgY[i]) + "\n")
for i in range(imgX.shape[0]):
imgs = imgX[i]
diff --git a/test_tipc/supplementary/loss.py b/test_tipc/supplementary/loss.py
index 8cb1cd498c..f139a7103b 100644
--- a/test_tipc/supplementary/loss.py
+++ b/test_tipc/supplementary/loss.py
@@ -44,7 +44,7 @@ def __call__(self, input, target):
def build_loss(config, epsilon=None):
- class_dim = config['class_dim']
+ class_dim = config["class_dim"]
loss_func = Loss(class_dim=class_dim, epsilon=epsilon)
return loss_func
@@ -72,9 +72,13 @@ def __call__(self, input, target):
class KLJSLoss(object):
- def __init__(self, mode='kl'):
- assert mode in ['kl', 'js', 'KL', 'JS'
- ], "mode can only be one of ['kl', 'js', 'KL', 'JS']"
+ def __init__(self, mode="kl"):
+ assert mode in [
+ "kl",
+ "js",
+ "KL",
+ "JS",
+ ], "mode can only be one of ['kl', 'js', 'KL', 'JS']"
self.mode = mode
def __call__(self, p1, p2, reduction="mean"):
@@ -84,8 +88,7 @@ def __call__(self, p1, p2, reduction="mean"):
loss = paddle.multiply(p2, paddle.log((p2 + 1e-5) / (p1 + 1e-5) + 1e-5))
if self.mode.lower() == "js":
- loss += paddle.multiply(
- p1, paddle.log((p1 + 1e-5) / (p2 + 1e-5) + 1e-5))
+ loss += paddle.multiply(p1, paddle.log((p1 + 1e-5) / (p2 + 1e-5) + 1e-5))
loss *= 0.5
if reduction == "mean":
loss = paddle.mean(loss)
@@ -97,8 +100,7 @@ def __call__(self, p1, p2, reduction="mean"):
class DMLLoss(object):
- def __init__(self, model_name_pairs, mode='js'):
-
+ def __init__(self, model_name_pairs, mode="js"):
self.model_name_pairs = self._check_model_name_pairs(model_name_pairs)
self.kljs_loss = KLJSLoss(mode=mode)
@@ -106,7 +108,8 @@ def _check_model_name_pairs(self, model_name_pairs):
if not isinstance(model_name_pairs, list):
return []
elif isinstance(model_name_pairs[0], list) and isinstance(
- model_name_pairs[0][0], str):
+ model_name_pairs[0][0], str
+ ):
return model_name_pairs
else:
return [model_name_pairs]
diff --git a/test_tipc/supplementary/metric.py b/test_tipc/supplementary/metric.py
index 401cf9b9d2..687fadebbf 100644
--- a/test_tipc/supplementary/metric.py
+++ b/test_tipc/supplementary/metric.py
@@ -3,13 +3,15 @@
from collections import OrderedDict
-def create_metric(out,
- label,
- architecture=None,
- topk=5,
- classes_num=1000,
- use_distillation=False,
- mode="train"):
+def create_metric(
+ out,
+ label,
+ architecture=None,
+ topk=5,
+ classes_num=1000,
+ use_distillation=False,
+ mode="train",
+):
"""
Create measures of model accuracy, such as top1 and top5
@@ -42,15 +44,17 @@ def create_metric(out,
# multi cards' eval
if mode != "train" and paddle.distributed.get_world_size() > 1:
- top1 = paddle.distributed.all_reduce(
- top1, op=paddle.distributed.ReduceOp.
- SUM) / paddle.distributed.get_world_size()
- topk = paddle.distributed.all_reduce(
- topk, op=paddle.distributed.ReduceOp.
- SUM) / paddle.distributed.get_world_size()
-
- fetchs['top1'] = top1
- topk_name = 'top{}'.format(k)
+ top1 = (
+ paddle.distributed.all_reduce(top1, op=paddle.distributed.ReduceOp.SUM)
+ / paddle.distributed.get_world_size()
+ )
+ topk = (
+ paddle.distributed.all_reduce(topk, op=paddle.distributed.ReduceOp.SUM)
+ / paddle.distributed.get_world_size()
+ )
+
+ fetchs["top1"] = top1
+ topk_name = "top{}".format(k)
fetchs[topk_name] = topk
return fetchs
diff --git a/test_tipc/supplementary/mv3.py b/test_tipc/supplementary/mv3.py
index 9ffcedac03..f891b86923 100644
--- a/test_tipc/supplementary/mv3.py
+++ b/test_tipc/supplementary/mv3.py
@@ -28,10 +28,12 @@
import math
from paddle.utils.cpp_extension import load
+
# jit compile custom op
custom_ops = load(
name="custom_jit_ops",
- sources=["./custom_op/custom_relu_op.cc", "./custom_op/custom_relu_op.cu"])
+ sources=["./custom_op/custom_relu_op.cc", "./custom_op/custom_relu_op.cu"],
+)
def make_divisible(v, divisor=8, min_value=None):
@@ -44,12 +46,14 @@ def make_divisible(v, divisor=8, min_value=None):
class MobileNetV3(nn.Layer):
- def __init__(self,
- scale=1.0,
- model_name="small",
- dropout_prob=0.2,
- class_dim=1000,
- use_custom_relu=False):
+ def __init__(
+ self,
+ scale=1.0,
+ model_name="small",
+ dropout_prob=0.2,
+ class_dim=1000,
+ use_custom_relu=False,
+ ):
super(MobileNetV3, self).__init__()
self.use_custom_relu = use_custom_relu
@@ -94,7 +98,8 @@ def __init__(self,
self.cls_ch_expand = 1280
else:
raise NotImplementedError(
- "mode[{}_model] is not implemented!".format(model_name))
+ "mode[{}_model] is not implemented!".format(model_name)
+ )
self.conv1 = ConvBNLayer(
in_c=3,
@@ -106,12 +111,13 @@ def __init__(self,
if_act=True,
act="hardswish",
name="conv1",
- use_custom_relu=self.use_custom_relu)
+ use_custom_relu=self.use_custom_relu,
+ )
self.block_list = []
i = 0
inplanes = make_divisible(inplanes * scale)
- for (k, exp, c, se, nl, s) in self.cfg:
+ for k, exp, c, se, nl, s in self.cfg:
block = self.add_sublayer(
"conv" + str(i + 2),
ResidualUnit(
@@ -123,7 +129,9 @@ def __init__(self,
use_se=se,
act=nl,
name="conv" + str(i + 2),
- use_custom_relu=self.use_custom_relu))
+ use_custom_relu=self.use_custom_relu,
+ ),
+ )
self.block_list.append(block)
inplanes = make_divisible(scale * c)
i += 1
@@ -138,7 +146,8 @@ def __init__(self,
if_act=True,
act="hardswish",
name="conv_last",
- use_custom_relu=self.use_custom_relu)
+ use_custom_relu=self.use_custom_relu,
+ )
self.pool = AdaptiveAvgPool2D(1)
@@ -149,7 +158,8 @@ def __init__(self,
stride=1,
padding=0,
weight_attr=ParamAttr(),
- bias_attr=False)
+ bias_attr=False,
+ )
self.dropout = Dropout(p=dropout_prob, mode="downscale_in_infer")
@@ -157,7 +167,8 @@ def __init__(self,
self.cls_ch_expand,
class_dim,
weight_attr=ParamAttr(),
- bias_attr=ParamAttr())
+ bias_attr=ParamAttr(),
+ )
def forward(self, inputs, label=None):
x = self.conv1(inputs)
@@ -177,18 +188,20 @@ def forward(self, inputs, label=None):
class ConvBNLayer(nn.Layer):
- def __init__(self,
- in_c,
- out_c,
- filter_size,
- stride,
- padding,
- num_groups=1,
- if_act=True,
- act=None,
- use_cudnn=True,
- name="",
- use_custom_relu=False):
+ def __init__(
+ self,
+ in_c,
+ out_c,
+ filter_size,
+ stride,
+ padding,
+ num_groups=1,
+ if_act=True,
+ act=None,
+ use_cudnn=True,
+ name="",
+ use_custom_relu=False,
+ ):
super(ConvBNLayer, self).__init__()
self.if_act = if_act
self.act = act
@@ -200,12 +213,14 @@ def __init__(self,
padding=padding,
groups=num_groups,
weight_attr=ParamAttr(),
- bias_attr=False)
+ bias_attr=False,
+ )
self.bn = BatchNorm(
num_channels=out_c,
act=None,
param_attr=ParamAttr(regularizer=L2Decay(0.0)),
- bias_attr=ParamAttr(regularizer=L2Decay(0.0)))
+ bias_attr=ParamAttr(regularizer=L2Decay(0.0)),
+ )
# moving_mean_name=name + "_bn_mean",
# moving_variance_name=name + "_bn_variance")
@@ -229,16 +244,18 @@ def forward(self, x):
class ResidualUnit(nn.Layer):
- def __init__(self,
- in_c,
- mid_c,
- out_c,
- filter_size,
- stride,
- use_se,
- act=None,
- name='',
- use_custom_relu=False):
+ def __init__(
+ self,
+ in_c,
+ mid_c,
+ out_c,
+ filter_size,
+ stride,
+ use_se,
+ act=None,
+ name="",
+ use_custom_relu=False,
+ ):
super(ResidualUnit, self).__init__()
self.if_shortcut = stride == 1 and in_c == out_c
self.if_se = use_se
@@ -254,7 +271,8 @@ def __init__(self,
if_act=True,
act=act,
name=name + "_expand",
- use_custom_relu=self.use_custom_relu)
+ use_custom_relu=self.use_custom_relu,
+ )
self.bottleneck_conv = ConvBNLayer(
in_c=mid_c,
out_c=mid_c,
@@ -265,7 +283,8 @@ def __init__(self,
if_act=True,
act=act,
name=name + "_depthwise",
- use_custom_relu=self.use_custom_relu)
+ use_custom_relu=self.use_custom_relu,
+ )
if self.if_se:
self.mid_se = SEModule(mid_c, name=name + "_se")
self.linear_conv = ConvBNLayer(
@@ -277,7 +296,8 @@ def __init__(self,
if_act=False,
act=None,
name=name + "_linear",
- use_custom_relu=self.use_custom_relu)
+ use_custom_relu=self.use_custom_relu,
+ )
def forward(self, inputs):
x = self.expand_conv(inputs)
@@ -301,7 +321,8 @@ def __init__(self, channel, reduction=4, name=""):
stride=1,
padding=0,
weight_attr=ParamAttr(),
- bias_attr=ParamAttr())
+ bias_attr=ParamAttr(),
+ )
self.conv2 = Conv2D(
in_channels=channel // reduction,
out_channels=channel,
@@ -309,7 +330,8 @@ def __init__(self, channel, reduction=4, name=""):
stride=1,
padding=0,
weight_attr=ParamAttr(),
- bias_attr=ParamAttr())
+ bias_attr=ParamAttr(),
+ )
def forward(self, inputs):
outputs = self.avg_pool(inputs)
@@ -371,31 +393,35 @@ def MobileNetV3_large_x1_25(**args):
class DistillMV3(nn.Layer):
- def __init__(self,
- scale=1.0,
- model_name="small",
- dropout_prob=0.2,
- class_dim=1000,
- args=None,
- use_custom_relu=False):
+ def __init__(
+ self,
+ scale=1.0,
+ model_name="small",
+ dropout_prob=0.2,
+ class_dim=1000,
+ args=None,
+ use_custom_relu=False,
+ ):
super(DistillMV3, self).__init__()
self.student = MobileNetV3(
model_name=model_name,
scale=scale,
class_dim=class_dim,
- use_custom_relu=use_custom_relu)
+ use_custom_relu=use_custom_relu,
+ )
self.student1 = MobileNetV3(
model_name=model_name,
scale=scale,
class_dim=class_dim,
- use_custom_relu=use_custom_relu)
+ use_custom_relu=use_custom_relu,
+ )
def forward(self, inputs, label=None):
predicts = dict()
- predicts['student'] = self.student(inputs, label)
- predicts['student1'] = self.student1(inputs, label)
+ predicts["student"] = self.student(inputs, label)
+ predicts["student1"] = self.student1(inputs, label)
return predicts
@@ -405,25 +431,29 @@ def distillmv3_large_x0_5(**args):
class SiameseMV3(nn.Layer):
- def __init__(self,
- scale=1.0,
- model_name="small",
- dropout_prob=0.2,
- class_dim=1000,
- args=None,
- use_custom_relu=False):
+ def __init__(
+ self,
+ scale=1.0,
+ model_name="small",
+ dropout_prob=0.2,
+ class_dim=1000,
+ args=None,
+ use_custom_relu=False,
+ ):
super(SiameseMV3, self).__init__()
self.net = MobileNetV3(
model_name=model_name,
scale=scale,
class_dim=class_dim,
- use_custom_relu=use_custom_relu)
+ use_custom_relu=use_custom_relu,
+ )
self.net1 = MobileNetV3(
model_name=model_name,
scale=scale,
class_dim=class_dim,
- use_custom_relu=use_custom_relu)
+ use_custom_relu=use_custom_relu,
+ )
def forward(self, inputs, label=None):
# net
@@ -431,7 +461,7 @@ def forward(self, inputs, label=None):
for block in self.net.block_list:
x = block(x)
- # net1
+ # net1
x1 = self.net1.conv1(inputs)
for block in self.net1.block_list:
x1 = block(x1)
@@ -454,33 +484,34 @@ def siamese_mv3(class_dim, use_custom_relu):
scale=0.5,
model_name="large",
class_dim=class_dim,
- use_custom_relu=use_custom_relu)
+ use_custom_relu=use_custom_relu,
+ )
return model
def build_model(config):
- model_type = config['model_type']
+ model_type = config["model_type"]
if model_type == "cls":
- class_dim = config['MODEL']['class_dim']
- use_custom_relu = config['MODEL']['use_custom_relu']
- if 'siamese' in config['MODEL'] and config['MODEL']['siamese'] is True:
- model = siamese_mv3(
- class_dim=class_dim, use_custom_relu=use_custom_relu)
+ class_dim = config["MODEL"]["class_dim"]
+ use_custom_relu = config["MODEL"]["use_custom_relu"]
+ if "siamese" in config["MODEL"] and config["MODEL"]["siamese"] is True:
+ model = siamese_mv3(class_dim=class_dim, use_custom_relu=use_custom_relu)
else:
model = MobileNetV3_large_x0_5(
- class_dim=class_dim, use_custom_relu=use_custom_relu)
+ class_dim=class_dim, use_custom_relu=use_custom_relu
+ )
elif model_type == "cls_distill":
- class_dim = config['MODEL']['class_dim']
- use_custom_relu = config['MODEL']['use_custom_relu']
+ class_dim = config["MODEL"]["class_dim"]
+ use_custom_relu = config["MODEL"]["use_custom_relu"]
model = distillmv3_large_x0_5(
- class_dim=class_dim, use_custom_relu=use_custom_relu)
+ class_dim=class_dim, use_custom_relu=use_custom_relu
+ )
elif model_type == "cls_distill_multiopt":
- class_dim = config['MODEL']['class_dim']
- use_custom_relu = config['MODEL']['use_custom_relu']
- model = distillmv3_large_x0_5(
- class_dim=100, use_custom_relu=use_custom_relu)
+ class_dim = config["MODEL"]["class_dim"]
+ use_custom_relu = config["MODEL"]["use_custom_relu"]
+ model = distillmv3_large_x0_5(class_dim=100, use_custom_relu=use_custom_relu)
else:
raise ValueError("model_type should be one of ['']")
diff --git a/test_tipc/supplementary/optimizer.py b/test_tipc/supplementary/optimizer.py
index aaa0153475..2a23161948 100644
--- a/test_tipc/supplementary/optimizer.py
+++ b/test_tipc/supplementary/optimizer.py
@@ -22,7 +22,8 @@ class Cosine(CosineAnnealingDecay):
def __init__(self, lr, step_each_epoch, epochs, **kwargs):
super(Cosine, self).__init__(
learning_rate=lr,
- T_max=step_each_epoch * epochs, )
+ T_max=step_each_epoch * epochs,
+ )
self.update_specified = False
@@ -58,8 +59,11 @@ class CosineWarmup(LinearWarmup):
"""
def __init__(self, lr, step_each_epoch, epochs, warmup_epoch=5, **kwargs):
- assert epochs > warmup_epoch, "total epoch({}) should be larger than warmup_epoch({}) in CosineWarmup.".format(
- epochs, warmup_epoch)
+ assert (
+ epochs > warmup_epoch
+ ), "total epoch({}) should be larger than warmup_epoch({}) in CosineWarmup.".format(
+ epochs, warmup_epoch
+ )
warmup_step = warmup_epoch * step_each_epoch
start_lr = 0.0
end_lr = lr
@@ -69,7 +73,8 @@ def __init__(self, lr, step_each_epoch, epochs, warmup_epoch=5, **kwargs):
learning_rate=lr_sch,
warmup_steps=warmup_step,
start_lr=start_lr,
- end_lr=end_lr)
+ end_lr=end_lr,
+ )
self.update_specified = False
@@ -87,13 +92,15 @@ class ExponentialWarmup(LinearWarmup):
warmup_epoch(int): epoch num of warmup
"""
- def __init__(self,
- lr,
- step_each_epoch,
- decay_epochs=2.4,
- decay_rate=0.97,
- warmup_epoch=5,
- **kwargs):
+ def __init__(
+ self,
+ lr,
+ step_each_epoch,
+ decay_epochs=2.4,
+ decay_rate=0.97,
+ warmup_epoch=5,
+ **kwargs
+ ):
warmup_step = warmup_epoch * step_each_epoch
start_lr = 0.0
end_lr = lr
@@ -103,7 +110,8 @@ def __init__(self,
learning_rate=lr_sch,
warmup_steps=warmup_step,
start_lr=start_lr,
- end_lr=end_lr)
+ end_lr=end_lr,
+ )
    # NOTE: hack to update the exponential lr scheduler
self.update_specified = True
@@ -112,7 +120,7 @@ def __init__(self,
self.step_each_epoch = step_each_epoch
-class LearningRateBuilder():
+class LearningRateBuilder:
"""
Build learning rate variable
https://www.paddlepaddle.org.cn/documentation/docs/zh/api_cn/layers_cn.html
@@ -121,11 +129,9 @@ class LearningRateBuilder():
params(dict): parameters used for init the class
"""
- def __init__(self,
- function='Linear',
- params={'lr': 0.1,
- 'steps': 100,
- 'end_lr': 0.0}):
+ def __init__(
+ self, function="Linear", params={"lr": 0.1, "steps": 100, "end_lr": 0.0}
+ ):
self.function = function
self.params = params
@@ -177,12 +183,9 @@ class Momentum(object):
regularization (WeightDecayRegularizer, optional) - The strategy of regularization.
"""
- def __init__(self,
- learning_rate,
- momentum,
- parameter_list=None,
- regularization=None,
- **args):
+ def __init__(
+ self, learning_rate, momentum, parameter_list=None, regularization=None, **args
+ ):
super(Momentum, self).__init__()
self.learning_rate = learning_rate
self.momentum = momentum
@@ -194,7 +197,8 @@ def __call__(self):
learning_rate=self.learning_rate,
momentum=self.momentum,
parameters=self.parameter_list,
- weight_decay=self.regularization)
+ weight_decay=self.regularization,
+ )
return opt
@@ -210,14 +214,16 @@ class RMSProp(object):
regularization (WeightDecayRegularizer, optional) - The strategy of regularization.
"""
- def __init__(self,
- learning_rate,
- momentum,
- rho=0.95,
- epsilon=1e-6,
- parameter_list=None,
- regularization=None,
- **args):
+ def __init__(
+ self,
+ learning_rate,
+ momentum,
+ rho=0.95,
+ epsilon=1e-6,
+ parameter_list=None,
+ regularization=None,
+ **args
+ ):
super(RMSProp, self).__init__()
self.learning_rate = learning_rate
self.momentum = momentum
@@ -233,7 +239,8 @@ def __call__(self):
rho=self.rho,
epsilon=self.epsilon,
parameters=self.parameter_list,
- weight_decay=self.regularization)
+ weight_decay=self.regularization,
+ )
return opt
@@ -246,26 +253,23 @@ class OptimizerBuilder(object):
regularizer (dict): parameters used for create regularization
"""
- def __init__(self,
- function='Momentum',
- params={'momentum': 0.9},
- regularizer=None):
+ def __init__(self, function="Momentum", params={"momentum": 0.9}, regularizer=None):
self.function = function
self.params = params
# create regularizer
if regularizer is not None:
mod = sys.modules[__name__]
- reg_func = regularizer['function'] + 'Decay'
- del regularizer['function']
+ reg_func = regularizer["function"] + "Decay"
+ del regularizer["function"]
reg = getattr(mod, reg_func)(**regularizer)()
- self.params['regularization'] = reg
+ self.params["regularization"] = reg
def __call__(self, learning_rate, parameter_list=None):
mod = sys.modules[__name__]
opt = getattr(mod, self.function)
- return opt(learning_rate=learning_rate,
- parameter_list=parameter_list,
- **self.params)()
+ return opt(
+ learning_rate=learning_rate, parameter_list=parameter_list, **self.params
+ )()
def create_optimizer(config, parameter_list=None):
@@ -292,34 +296,35 @@ def create_optimizer(config, parameter_list=None):
an optimizer instance
"""
# create learning_rate instance
- lr_config = config['LEARNING_RATE']
- lr_config['params'].update({
- 'epochs': config['epoch'],
- 'step_each_epoch':
- config['total_images'] // config['TRAIN']['batch_size'],
- })
+ lr_config = config["LEARNING_RATE"]
+ lr_config["params"].update(
+ {
+ "epochs": config["epoch"],
+ "step_each_epoch": config["total_images"] // config["TRAIN"]["batch_size"],
+ }
+ )
lr = LearningRateBuilder(**lr_config)()
# create optimizer instance
- opt_config = deepcopy(config['OPTIMIZER'])
+ opt_config = deepcopy(config["OPTIMIZER"])
opt = OptimizerBuilder(**opt_config)
return opt(lr, parameter_list), lr
def create_multi_optimizer(config, parameter_list=None):
- """
- """
+ """ """
# create learning_rate instance
- lr_config = config['LEARNING_RATE']
- lr_config['params'].update({
- 'epochs': config['epoch'],
- 'step_each_epoch':
- config['total_images'] // config['TRAIN']['batch_size'],
- })
+ lr_config = config["LEARNING_RATE"]
+ lr_config["params"].update(
+ {
+ "epochs": config["epoch"],
+ "step_each_epoch": config["total_images"] // config["TRAIN"]["batch_size"],
+ }
+ )
lr = LearningRateBuilder(**lr_config)()
# create optimizer instance
- opt_config = deepcopy.copy(config['OPTIMIZER'])
+    opt_config = deepcopy(config["OPTIMIZER"])
opt = OptimizerBuilder(**opt_config)
return opt(lr, parameter_list), lr
diff --git a/test_tipc/supplementary/slim/slim_fpgm.py b/test_tipc/supplementary/slim/slim_fpgm.py
index 0e7621592d..030474448d 100644
--- a/test_tipc/supplementary/slim/slim_fpgm.py
+++ b/test_tipc/supplementary/slim/slim_fpgm.py
@@ -6,13 +6,12 @@
def prune_model(model, input_shape, prune_ratio=0.1):
-
flops = paddle.flops(model, input_shape)
pruner = FPGMFilterPruner(model, input_shape)
params_sensitive = {}
for param in model.parameters():
- if 'transpose' not in param.name and 'linear' not in param.name:
+ if "transpose" not in param.name and "linear" not in param.name:
# set prune ratio as 10%. The larger the value, the more convolution weights will be cropped
params_sensitive[param.name] = prune_ratio
diff --git a/test_tipc/supplementary/slim/slim_quant.py b/test_tipc/supplementary/slim/slim_quant.py
index 7c201bf55d..41db3eb2f1 100644
--- a/test_tipc/supplementary/slim/slim_quant.py
+++ b/test_tipc/supplementary/slim/slim_quant.py
@@ -12,10 +12,10 @@ def __init__(self):
name=self.full_name() + ".pact",
initializer=paddle.nn.initializer.Constant(value=20),
learning_rate=1.0,
- regularizer=paddle.regularizer.L2Decay(2e-5))
+ regularizer=paddle.regularizer.L2Decay(2e-5),
+ )
- self.alpha = self.create_parameter(
- shape=[1], attr=alpha_attr, dtype='float32')
+ self.alpha = self.create_parameter(shape=[1], attr=alpha_attr, dtype="float32")
def forward(self, x):
out_left = paddle.nn.functional.relu(x - self.alpha)
@@ -25,24 +25,24 @@ def forward(self, x):
quant_config = {
- # weight preprocess type, default is None and no preprocessing is performed.
- 'weight_preprocess_type': None,
+ # weight preprocess type, default is None and no preprocessing is performed.
+ "weight_preprocess_type": None,
# activation preprocess type, default is None and no preprocessing is performed.
- 'activation_preprocess_type': None,
+ "activation_preprocess_type": None,
# weight quantize type, default is 'channel_wise_abs_max'
- 'weight_quantize_type': 'channel_wise_abs_max',
+ "weight_quantize_type": "channel_wise_abs_max",
# activation quantize type, default is 'moving_average_abs_max'
- 'activation_quantize_type': 'moving_average_abs_max',
+ "activation_quantize_type": "moving_average_abs_max",
# weight quantize bit num, default is 8
- 'weight_bits': 8,
+ "weight_bits": 8,
# activation quantize bit num, default is 8
- 'activation_bits': 8,
+ "activation_bits": 8,
# data type after quantization, such as 'uint8', 'int8', etc. default is 'int8'
- 'dtype': 'int8',
+ "dtype": "int8",
# window size for 'range_abs_max' quantization. default is 10000
- 'window_size': 10000,
+ "window_size": 10000,
# The decay coefficient of moving average, default is 0.9
- 'moving_rate': 0.9,
+ "moving_rate": 0.9,
# for dygraph quantization, layers of type in quantizable_layer_type will be quantized
- 'quantizable_layer_type': ['Conv2D', 'Linear'],
+ "quantizable_layer_type": ["Conv2D", "Linear"],
}
diff --git a/test_tipc/supplementary/train.py b/test_tipc/supplementary/train.py
index f582123407..9dfec5ba4a 100644
--- a/test_tipc/supplementary/train.py
+++ b/test_tipc/supplementary/train.py
@@ -3,6 +3,7 @@
import os
import paddle.nn as nn
import paddle.distributed as dist
+
dist.get_world_size()
dist.init_parallel_env()
@@ -30,54 +31,50 @@ def _mkdir_if_not_exist(path, logger):
except OSError as e:
if e.errno == errno.EEXIST and os.path.isdir(path):
logger.warning(
- 'be happy if some process has already created {}'.format(
- path))
+ "be happy if some process has already created {}".format(path)
+ )
else:
- raise OSError('Failed to mkdir {}'.format(path))
+ raise OSError("Failed to mkdir {}".format(path))
-def save_model(model,
- optimizer,
- model_path,
- logger,
- is_best=False,
- prefix='ppocr',
- **kwargs):
+def save_model(
+ model, optimizer, model_path, logger, is_best=False, prefix="ppocr", **kwargs
+):
"""
save model to the target path
"""
_mkdir_if_not_exist(model_path, logger)
model_prefix = os.path.join(model_path, prefix)
- paddle.save(model.state_dict(), model_prefix + '.pdparams')
+ paddle.save(model.state_dict(), model_prefix + ".pdparams")
if type(optimizer) is list:
- paddle.save(optimizer[0].state_dict(), model_prefix + '.pdopt')
- paddle.save(optimizer[1].state_dict(), model_prefix + "_1" + '.pdopt')
+ paddle.save(optimizer[0].state_dict(), model_prefix + ".pdopt")
+ paddle.save(optimizer[1].state_dict(), model_prefix + "_1" + ".pdopt")
else:
- paddle.save(optimizer.state_dict(), model_prefix + '.pdopt')
+ paddle.save(optimizer.state_dict(), model_prefix + ".pdopt")
# # save metric and config
# with open(model_prefix + '.states', 'wb') as f:
# pickle.dump(kwargs, f, protocol=2)
if is_best:
- logger.info('save best model is to {}'.format(model_prefix))
+ logger.info("save best model is to {}".format(model_prefix))
else:
logger.info("save model in {}".format(model_prefix))
def amp_scaler(config):
- if 'AMP' in config and config['AMP']['use_amp'] is True:
+ if "AMP" in config and config["AMP"]["use_amp"] is True:
AMP_RELATED_FLAGS_SETTING = {
- 'FLAGS_cudnn_batchnorm_spatial_persistent': 1,
- 'FLAGS_max_inplace_grad_add': 8,
+ "FLAGS_cudnn_batchnorm_spatial_persistent": 1,
+ "FLAGS_max_inplace_grad_add": 8,
}
paddle.set_flags(AMP_RELATED_FLAGS_SETTING)
scale_loss = config["AMP"].get("scale_loss", 1.0)
- use_dynamic_loss_scaling = config["AMP"].get("use_dynamic_loss_scaling",
- False)
+ use_dynamic_loss_scaling = config["AMP"].get("use_dynamic_loss_scaling", False)
scaler = paddle.amp.GradScaler(
init_loss_scaling=scale_loss,
- use_dynamic_loss_scaling=use_dynamic_loss_scaling)
+ use_dynamic_loss_scaling=use_dynamic_loss_scaling,
+ )
return scaler
else:
return None
@@ -89,13 +86,14 @@ def set_seed(seed):
def train(config, scaler=None):
- EPOCH = config['epoch']
- topk = config['topk']
+ EPOCH = config["epoch"]
+ topk = config["topk"]
- batch_size = config['TRAIN']['batch_size']
- num_workers = config['TRAIN']['num_workers']
+ batch_size = config["TRAIN"]["batch_size"]
+ num_workers = config["TRAIN"]["num_workers"]
train_loader = build_dataloader(
- 'train', batch_size=batch_size, num_workers=num_workers)
+ "train", batch_size=batch_size, num_workers=num_workers
+ )
# build metric
metric_func = create_metric
@@ -104,22 +102,24 @@ def train(config, scaler=None):
# model = MobileNetV3_large_x0_5(class_dim=100)
model = build_model(config)
- # build_optimizer
+ # build_optimizer
optimizer, lr_scheduler = create_optimizer(
- config, parameter_list=model.parameters())
+ config, parameter_list=model.parameters()
+ )
# load model
pre_best_model_dict = load_model(config, model, optimizer)
if len(pre_best_model_dict) > 0:
- pre_str = 'The metric of loaded metric as follows {}'.format(', '.join(
- ['{}: {}'.format(k, v) for k, v in pre_best_model_dict.items()]))
+ pre_str = "The metric of loaded metric as follows {}".format(
+ ", ".join(["{}: {}".format(k, v) for k, v in pre_best_model_dict.items()])
+ )
logger.info(pre_str)
# about slim prune and quant
- if "quant_train" in config and config['quant_train'] is True:
+ if "quant_train" in config and config["quant_train"] is True:
quanter = QAT(config=quant_config, act_preprocess=PACT)
quanter.quantize(model)
- elif "prune_train" in config and config['prune_train'] is True:
+ elif "prune_train" in config and config["prune_train"] is True:
model = prune_model(model, [1, 3, 32, 32], 0.1)
else:
pass
@@ -146,7 +146,7 @@ def train(config, scaler=None):
else:
outs = model(img_batch)
- # cal metric
+ # cal metric
acc = metric_func(outs, label)
# cal loss
@@ -169,16 +169,18 @@ def train(config, scaler=None):
et = time.time()
strs = f"epoch: [{epoch}/{EPOCH}], iter: [{idx}/{data_num}], "
strs += f"loss: {float(avg_loss)}"
- strs += f", acc_topk1: {float(acc['top1'])}, acc_top5: {float(acc['top5'])}"
+ strs += (
+ f", acc_topk1: {float(acc['top1'])}, acc_top5: {float(acc['top5'])}"
+ )
strs += f", batch_time: {round(et-st, 4)} s"
logger.info(strs)
st = time.time()
if epoch % 10 == 0:
acc = eval(config, model)
- if len(best_acc) < 1 or float(acc['top5']) > best_acc['top5']:
+ if len(best_acc) < 1 or float(acc["top5"]) > best_acc["top5"]:
best_acc = acc
- best_acc['epoch'] = epoch
+ best_acc["epoch"] = epoch
is_best = True
else:
is_best = False
@@ -188,20 +190,22 @@ def train(config, scaler=None):
save_model(
model,
optimizer,
- config['save_model_dir'],
+ config["save_model_dir"],
logger,
is_best,
- prefix="cls")
+ prefix="cls",
+ )
def train_distill(config, scaler=None):
- EPOCH = config['epoch']
- topk = config['topk']
+ EPOCH = config["epoch"]
+ topk = config["topk"]
- batch_size = config['TRAIN']['batch_size']
- num_workers = config['TRAIN']['num_workers']
+ batch_size = config["TRAIN"]["batch_size"]
+ num_workers = config["TRAIN"]["num_workers"]
train_loader = build_dataloader(
- 'train', batch_size=batch_size, num_workers=num_workers)
+ "train", batch_size=batch_size, num_workers=num_workers
+ )
# build metric
metric_func = create_metric
@@ -210,32 +214,34 @@ def train_distill(config, scaler=None):
model = build_model(config)
# pact quant train
- if "quant_train" in config and config['quant_train'] is True:
+ if "quant_train" in config and config["quant_train"] is True:
quanter = QAT(config=quant_config, act_preprocess=PACT)
quanter.quantize(model)
- elif "prune_train" in config and config['prune_train'] is True:
+ elif "prune_train" in config and config["prune_train"] is True:
model = prune_model(model, [1, 3, 32, 32], 0.1)
else:
pass
- # build_optimizer
+ # build_optimizer
optimizer, lr_scheduler = create_optimizer(
- config, parameter_list=model.parameters())
+ config, parameter_list=model.parameters()
+ )
# load model
pre_best_model_dict = load_model(config, model, optimizer)
if len(pre_best_model_dict) > 0:
- pre_str = 'The metric of loaded metric as follows {}'.format(', '.join(
- ['{}: {}'.format(k, v) for k, v in pre_best_model_dict.items()]))
+ pre_str = "The metric of loaded metric as follows {}".format(
+ ", ".join(["{}: {}".format(k, v) for k, v in pre_best_model_dict.items()])
+ )
logger.info(pre_str)
model.train()
model = paddle.DataParallel(model)
# build loss function
- loss_func_distill = LossDistill(model_name_list=['student', 'student1'])
- loss_func_dml = DMLLoss(model_name_pairs=['student', 'student1'])
- loss_func_js = KLJSLoss(mode='js')
+ loss_func_distill = LossDistill(model_name_list=["student", "student1"])
+ loss_func_dml = DMLLoss(model_name_pairs=["student", "student1"])
+ loss_func_js = KLJSLoss(mode="js")
data_num = len(train_loader)
@@ -252,13 +258,15 @@ def train_distill(config, scaler=None):
else:
outs = model(img_batch)
- # cal metric
- acc = metric_func(outs['student'], label)
+ # cal metric
+ acc = metric_func(outs["student"], label)
# cal loss
- avg_loss = loss_func_distill(outs, label)['student'] + \
- loss_func_distill(outs, label)['student1'] + \
- loss_func_dml(outs, label)['student_student1']
+ avg_loss = (
+ loss_func_distill(outs, label)["student"]
+ + loss_func_distill(outs, label)["student1"]
+ + loss_func_dml(outs, label)["student_student1"]
+ )
# backward
if scaler is None:
@@ -277,16 +285,18 @@ def train_distill(config, scaler=None):
et = time.time()
strs = f"epoch: [{epoch}/{EPOCH}], iter: [{idx}/{data_num}], "
strs += f"loss: {float(avg_loss)}"
- strs += f", acc_topk1: {float(acc['top1'])}, acc_top5: {float(acc['top5'])}"
+ strs += (
+ f", acc_topk1: {float(acc['top1'])}, acc_top5: {float(acc['top5'])}"
+ )
strs += f", batch_time: {round(et-st, 4)} s"
logger.info(strs)
st = time.time()
if epoch % 10 == 0:
acc = eval(config, model._layers.student)
- if len(best_acc) < 1 or float(acc['top5']) > best_acc['top5']:
+ if len(best_acc) < 1 or float(acc["top5"]) > best_acc["top5"]:
best_acc = acc
- best_acc['epoch'] = epoch
+ best_acc["epoch"] = epoch
is_best = True
else:
is_best = False
@@ -297,20 +307,22 @@ def train_distill(config, scaler=None):
save_model(
model,
optimizer,
- config['save_model_dir'],
+ config["save_model_dir"],
logger,
is_best,
- prefix="cls_distill")
+ prefix="cls_distill",
+ )
def train_distill_multiopt(config, scaler=None):
- EPOCH = config['epoch']
- topk = config['topk']
+ EPOCH = config["epoch"]
+ topk = config["topk"]
- batch_size = config['TRAIN']['batch_size']
- num_workers = config['TRAIN']['num_workers']
+ batch_size = config["TRAIN"]["batch_size"]
+ num_workers = config["TRAIN"]["num_workers"]
train_loader = build_dataloader(
- 'train', batch_size=batch_size, num_workers=num_workers)
+ "train", batch_size=batch_size, num_workers=num_workers
+ )
# build metric
metric_func = create_metric
@@ -318,24 +330,27 @@ def train_distill_multiopt(config, scaler=None):
# model = distillmv3_large_x0_5(class_dim=100)
model = build_model(config)
- # build_optimizer
+ # build_optimizer
optimizer, lr_scheduler = create_optimizer(
- config, parameter_list=model.student.parameters())
+ config, parameter_list=model.student.parameters()
+ )
optimizer1, lr_scheduler1 = create_optimizer(
- config, parameter_list=model.student1.parameters())
+ config, parameter_list=model.student1.parameters()
+ )
# load model
pre_best_model_dict = load_model(config, model, optimizer)
if len(pre_best_model_dict) > 0:
- pre_str = 'The metric of loaded metric as follows {}'.format(', '.join(
- ['{}: {}'.format(k, v) for k, v in pre_best_model_dict.items()]))
+ pre_str = "The metric of loaded metric as follows {}".format(
+ ", ".join(["{}: {}".format(k, v) for k, v in pre_best_model_dict.items()])
+ )
logger.info(pre_str)
# quant train
- if "quant_train" in config and config['quant_train'] is True:
+ if "quant_train" in config and config["quant_train"] is True:
quanter = QAT(config=quant_config, act_preprocess=PACT)
quanter.quantize(model)
- elif "prune_train" in config and config['prune_train'] is True:
+ elif "prune_train" in config and config["prune_train"] is True:
model = prune_model(model, [1, 3, 32, 32], 0.1)
else:
pass
@@ -345,9 +360,9 @@ def train_distill_multiopt(config, scaler=None):
model = paddle.DataParallel(model)
# build loss function
- loss_func_distill = LossDistill(model_name_list=['student', 'student1'])
- loss_func_dml = DMLLoss(model_name_pairs=['student', 'student1'])
- loss_func_js = KLJSLoss(mode='js')
+ loss_func_distill = LossDistill(model_name_list=["student", "student1"])
+ loss_func_dml = DMLLoss(model_name_pairs=["student", "student1"])
+ loss_func_js = KLJSLoss(mode="js")
data_num = len(train_loader)
best_acc = {}
@@ -364,16 +379,18 @@ def train_distill_multiopt(config, scaler=None):
else:
outs = model(img_batch)
- # cal metric
- acc = metric_func(outs['student'], label)
+ # cal metric
+ acc = metric_func(outs["student"], label)
# cal loss
- avg_loss = loss_func_distill(outs,
- label)['student'] + loss_func_dml(
- outs, label)['student_student1']
- avg_loss1 = loss_func_distill(outs,
- label)['student1'] + loss_func_dml(
- outs, label)['student_student1']
+ avg_loss = (
+ loss_func_distill(outs, label)["student"]
+ + loss_func_dml(outs, label)["student_student1"]
+ )
+ avg_loss1 = (
+ loss_func_distill(outs, label)["student1"]
+ + loss_func_dml(outs, label)["student_student1"]
+ )
if scaler is None:
# backward
@@ -402,16 +419,18 @@ def train_distill_multiopt(config, scaler=None):
et = time.time()
strs = f"epoch: [{epoch}/{EPOCH}], iter: [{idx}/{data_num}], "
strs += f"loss: {float(avg_loss)}, loss1: {float(avg_loss1)}"
- strs += f", acc_topk1: {float(acc['top1'])}, acc_top5: {float(acc['top5'])}"
+ strs += (
+ f", acc_topk1: {float(acc['top1'])}, acc_top5: {float(acc['top5'])}"
+ )
strs += f", batch_time: {round(et-st, 4)} s"
logger.info(strs)
st = time.time()
if epoch % 10 == 0:
acc = eval(config, model._layers.student)
- if len(best_acc) < 1 or float(acc['top5']) > best_acc['top5']:
+ if len(best_acc) < 1 or float(acc["top5"]) > best_acc["top5"]:
best_acc = acc
- best_acc['epoch'] = epoch
+ best_acc["epoch"] = epoch
is_best = True
else:
is_best = False
@@ -419,18 +438,21 @@ def train_distill_multiopt(config, scaler=None):
f"The best acc: acc_topk1: {float(best_acc['top1'])}, acc_top5: {float(best_acc['top5'])}, best_epoch: {best_acc['epoch']}"
)
save_model(
- model, [optimizer, optimizer1],
- config['save_model_dir'],
+ model,
+ [optimizer, optimizer1],
+ config["save_model_dir"],
logger,
is_best,
- prefix="cls_distill_multiopt")
+ prefix="cls_distill_multiopt",
+ )
def eval(config, model):
- batch_size = config['VALID']['batch_size']
- num_workers = config['VALID']['num_workers']
+ batch_size = config["VALID"]["batch_size"]
+ num_workers = config["VALID"]["num_workers"]
valid_loader = build_dataloader(
- 'test', batch_size=batch_size, num_workers=num_workers)
+ "test", batch_size=batch_size, num_workers=num_workers
+ )
# build metric
metric_func = create_metric
@@ -456,13 +478,12 @@ def eval(config, model):
if __name__ == "__main__":
-
config, logger = preprocess(is_train=False)
# AMP scaler
scaler = amp_scaler(config)
- model_type = config['model_type']
+ model_type = config["model_type"]
if model_type == "cls":
train(config)
diff --git a/test_tipc/supplementary/utils.py b/test_tipc/supplementary/utils.py
index ae9ae061b9..8b21490dec 100644
--- a/test_tipc/supplementary/utils.py
+++ b/test_tipc/supplementary/utils.py
@@ -39,7 +39,7 @@ def print_dict(d, logger, delimiter=0):
@functools.lru_cache()
-def get_logger(name='root', log_file=None, log_level=logging.DEBUG):
+def get_logger(name="root", log_file=None, log_level=logging.DEBUG):
"""Initialize and get a logger by name.
If the logger has not been initialized, this method will initialize the
logger by adding one or two handlers, otherwise the initialized logger will
@@ -63,8 +63,8 @@ def get_logger(name='root', log_file=None, log_level=logging.DEBUG):
return logger
formatter = logging.Formatter(
- '[%(asctime)s] %(name)s %(levelname)s: %(message)s',
- datefmt="%Y/%m/%d %H:%M:%S")
+ "[%(asctime)s] %(name)s %(levelname)s: %(message)s", datefmt="%Y/%m/%d %H:%M:%S"
+ )
stream_handler = logging.StreamHandler(stream=sys.stdout)
stream_handler.setFormatter(formatter)
@@ -72,7 +72,7 @@ def get_logger(name='root', log_file=None, log_level=logging.DEBUG):
if log_file is not None and dist.get_rank() == 0:
log_file_folder = os.path.split(log_file)[0]
os.makedirs(log_file_folder, exist_ok=True)
- file_handler = logging.FileHandler(log_file, 'a')
+ file_handler = logging.FileHandler(log_file, "a")
file_handler.setFormatter(formatter)
logger.addHandler(file_handler)
if dist.get_rank() == 0:
@@ -88,65 +88,73 @@ def load_model(config, model, optimizer=None):
load model from checkpoint or pretrained_model
"""
logger = get_logger()
- checkpoints = config.get('checkpoints')
- pretrained_model = config.get('pretrained_model')
+ checkpoints = config.get("checkpoints")
+ pretrained_model = config.get("pretrained_model")
best_model_dict = {}
if checkpoints:
- if checkpoints.endswith('.pdparams'):
- checkpoints = checkpoints.replace('.pdparams', '')
- assert os.path.exists(checkpoints + ".pdparams"), \
- "The {}.pdparams does not exists!".format(checkpoints)
+ if checkpoints.endswith(".pdparams"):
+ checkpoints = checkpoints.replace(".pdparams", "")
+ assert os.path.exists(
+ checkpoints + ".pdparams"
+ ), "The {}.pdparams does not exists!".format(checkpoints)
# load params from trained model
- params = paddle.load(checkpoints + '.pdparams')
+ params = paddle.load(checkpoints + ".pdparams")
state_dict = model.state_dict()
new_state_dict = {}
for key, value in state_dict.items():
if key not in params:
- logger.warning("{} not in loaded params {} !".format(
- key, params.keys()))
+ logger.warning(
+ "{} not in loaded params {} !".format(key, params.keys())
+ )
continue
pre_value = params[key]
if list(value.shape) == list(pre_value.shape):
new_state_dict[key] = pre_value
else:
logger.warning(
- "The shape of model params {} {} not matched with loaded params shape {} !".
- format(key, value.shape, pre_value.shape))
+ "The shape of model params {} {} not matched with loaded params shape {} !".format(
+ key, value.shape, pre_value.shape
+ )
+ )
model.set_state_dict(new_state_dict)
if optimizer is not None:
- if os.path.exists(checkpoints + '.pdopt'):
- optim_dict = paddle.load(checkpoints + '.pdopt')
+ if os.path.exists(checkpoints + ".pdopt"):
+ optim_dict = paddle.load(checkpoints + ".pdopt")
optimizer.set_state_dict(optim_dict)
else:
logger.warning(
- "{}.pdopt is not exists, params of optimizer is not loaded".
- format(checkpoints))
-
- if os.path.exists(checkpoints + '.states'):
- with open(checkpoints + '.states', 'rb') as f:
- states_dict = pickle.load(f) if six.PY2 else pickle.load(
- f, encoding='latin1')
- best_model_dict = states_dict.get('best_model_dict', {})
- if 'epoch' in states_dict:
- best_model_dict['start_epoch'] = states_dict['epoch'] + 1
+ "{}.pdopt is not exists, params of optimizer is not loaded".format(
+ checkpoints
+ )
+ )
+
+ if os.path.exists(checkpoints + ".states"):
+ with open(checkpoints + ".states", "rb") as f:
+ states_dict = (
+ pickle.load(f) if six.PY2 else pickle.load(f, encoding="latin1")
+ )
+ best_model_dict = states_dict.get("best_model_dict", {})
+ if "epoch" in states_dict:
+ best_model_dict["start_epoch"] = states_dict["epoch"] + 1
logger.info("resume from {}".format(checkpoints))
elif pretrained_model:
load_pretrained_params(model, pretrained_model)
else:
- logger.info('train from scratch')
+ logger.info("train from scratch")
return best_model_dict
def load_pretrained_params(model, path):
logger = get_logger()
- if path.endswith('.pdparams'):
- path = path.replace('.pdparams', '')
- assert os.path.exists(path + ".pdparams"), \
- "The {}.pdparams does not exists!".format(path)
+ if path.endswith(".pdparams"):
+ path = path.replace(".pdparams", "")
+ assert os.path.exists(
+ path + ".pdparams"
+ ), "The {}.pdparams does not exists!".format(path)
- params = paddle.load(path + '.pdparams')
+ params = paddle.load(path + ".pdparams")
state_dict = model.state_dict()
new_state_dict = {}
for k1 in params.keys():
@@ -157,8 +165,10 @@ def load_pretrained_params(model, path):
new_state_dict[k1] = params[k1]
else:
logger.warning(
- "The shape of model params {} {} not matched with loaded params {} {} !".
- format(k1, state_dict[k1].shape, k1, params[k1].shape))
+ "The shape of model params {} {} not matched with loaded params {} {} !".format(
+ k1, state_dict[k1].shape, k1, params[k1].shape
+ )
+ )
model.set_state_dict(new_state_dict)
logger.info("load pretrain successful from {}".format(path))
return model
diff --git a/tools/end2end/convert_ppocr_label.py b/tools/end2end/convert_ppocr_label.py
index c64b9ed168..08cea77667 100644
--- a/tools/end2end/convert_ppocr_label.py
+++ b/tools/end2end/convert_ppocr_label.py
@@ -30,40 +30,42 @@ def convert_label(label_dir, mode="gt", save_dir="./save_results/"):
    assert label_dir != save_dir, "save_dir must be different from label_dir"
- label_file = open(label_dir, 'r')
+ label_file = open(label_dir, "r")
data = label_file.readlines()
gt_dict = {}
for line in data:
try:
- tmp = line.split('\t')
+ tmp = line.split("\t")
assert len(tmp) == 2, ""
except:
- tmp = line.strip().split(' ')
+ tmp = line.strip().split(" ")
gt_lists = []
- if tmp[0].split('/')[0] is not None:
+ if tmp[0].split("/")[0] is not None:
img_path = tmp[0]
anno = json.loads(tmp[1])
gt_collect = []
for dic in anno:
- #txt = dic['transcription'].replace(' ', '') # ignore blank
- txt = dic['transcription']
- if 'score' in dic and float(dic['score']) < 0.5:
+ # txt = dic['transcription'].replace(' ', '') # ignore blank
+ txt = dic["transcription"]
+ if "score" in dic and float(dic["score"]) < 0.5:
continue
- if u'\u3000' in txt: txt = txt.replace(u'\u3000', u' ')
- #while ' ' in txt:
+ if "\u3000" in txt:
+ txt = txt.replace("\u3000", " ")
+ # while ' ' in txt:
# txt = txt.replace(' ', '')
- poly = np.array(dic['points']).flatten()
+ poly = np.array(dic["points"]).flatten()
if txt == "###":
txt_tag = 1 ## ignore 1
else:
txt_tag = 0
if mode == "gt":
- gt_label = poly_to_string(poly) + "\t" + str(
- txt_tag) + "\t" + txt + "\n"
+ gt_label = (
+ poly_to_string(poly) + "\t" + str(txt_tag) + "\t" + txt + "\n"
+ )
else:
gt_label = poly_to_string(poly) + "\t" + txt + "\n"
@@ -87,6 +89,7 @@ def convert_label(label_dir, mode="gt", save_dir="./save_results/"):
def parse_args():
import argparse
+
parser = argparse.ArgumentParser(description="args")
parser.add_argument("--label_path", type=str, required=True)
parser.add_argument("--save_folder", type=str, required=True)
diff --git a/tools/end2end/draw_html.py b/tools/end2end/draw_html.py
index fcac8ad3bf..c894f44355 100644
--- a/tools/end2end/draw_html.py
+++ b/tools/end2end/draw_html.py
@@ -34,40 +34,39 @@ def parse_args():
def draw_debug_img(args):
-
html_path = args.save_html_path
err_cnt = 0
-    with open(html_path, 'w') as html:
-        html.write('<html>\n<body>\n')
+    with open(html_path, "w") as html:
+        html.write("<html>\n<body>\n")
        html.write('<table border="1">\n')
        html.write(
-            "<meta http-equiv=\"Content-Type\" content=\"text/html; charset=utf-8\" />"
+            '<meta http-equiv="Content-Type" content="text/html; charset=utf-8" />'
        )
image_list = []
path = args.image_dir
for i, filename in enumerate(sorted(os.listdir(path))):
- if filename.endswith("txt"): continue
+ if filename.endswith("txt"):
+ continue
# The image path
base = "{}/{}".format(path, filename)
html.write("\n")
- html.write(f' {filename}\n GT')
+ html.write(f" | {filename}\n GT")
html.write(f' | GT\n | ')
html.write(" \n")
-        html.write('</table>\n')
-        html.write(' \n')
-        html.write('</html>\n</body>\n')
+        html.write("</table>\n")
+        html.write(" \n")
+        html.write("</html>\n</body>\n")
print(f"The html file saved in {html_path}")
return
if __name__ == "__main__":
-
args = parse_args()
draw_debug_img(args)
diff --git a/tools/end2end/eval_end2end.py b/tools/end2end/eval_end2end.py
index dd37940845..3795c3c28f 100644
--- a/tools/end2end/eval_end2end.py
+++ b/tools/end2end/eval_end2end.py
@@ -29,7 +29,7 @@ def strQ2B(ustring):
inside_code = ord(uchar)
if inside_code == 12288:
inside_code = 32
- elif (inside_code >= 65281 and inside_code <= 65374):
+ elif inside_code >= 65281 and inside_code <= 65374:
inside_code -= 65248
rstring += chr(inside_code)
return rstring
@@ -48,8 +48,7 @@ def polygon_iou(poly1, poly2):
"""
Intersection over union between two shapely polygons.
"""
- if not poly1.intersects(
- poly2): # this test is fast and can accelerate calculation
+ if not poly1.intersects(poly2): # this test is fast and can accelerate calculation
iou = 0
else:
try:
@@ -59,7 +58,7 @@ def polygon_iou(poly1, poly2):
except shapely.geos.TopologicalError:
# except Exception as e:
# print(e)
- print('shapely.geos.TopologicalError occurred, iou set to 0')
+ print("shapely.geos.TopologicalError occurred, iou set to 0")
iou = 0
return iou
@@ -69,7 +68,7 @@ def ed(str1, str2):
def e2e_eval(gt_dir, res_dir, ignore_blank=False):
- print('start testing...')
+ print("start testing...")
iou_thresh = 0.5
val_names = os.listdir(gt_dir)
num_gt_chars = 0
@@ -79,18 +78,18 @@ def e2e_eval(gt_dir, res_dir, ignore_blank=False):
ed_sum = 0
for i, val_name in enumerate(val_names):
- with open(os.path.join(gt_dir, val_name), encoding='utf-8') as f:
+ with open(os.path.join(gt_dir, val_name), encoding="utf-8") as f:
gt_lines = [o.strip() for o in f.readlines()]
gts = []
ignore_masks = []
for line in gt_lines:
- parts = line.strip().split('\t')
+ parts = line.strip().split("\t")
# ignore illegal data
if len(parts) < 9:
continue
- assert (len(parts) < 11)
+ assert len(parts) < 11
if len(parts) == 9:
- gts.append(parts[:8] + [''])
+ gts.append(parts[:8] + [""])
else:
gts.append(parts[:8] + [parts[-1]])
@@ -100,15 +99,15 @@ def e2e_eval(gt_dir, res_dir, ignore_blank=False):
if not os.path.exists(val_path):
dt_lines = []
else:
- with open(val_path, encoding='utf-8') as f:
+ with open(val_path, encoding="utf-8") as f:
dt_lines = [o.strip() for o in f.readlines()]
dts = []
for line in dt_lines:
# print(line)
parts = line.strip().split("\t")
- assert (len(parts) < 10), "line error: {}".format(line)
+ assert len(parts) < 10, "line error: {}".format(line)
if len(parts) == 8:
- dts.append(parts + [''])
+ dts.append(parts + [""])
else:
dts.append(parts)
@@ -124,8 +123,7 @@ def e2e_eval(gt_dir, res_dir, ignore_blank=False):
iou = polygon_iou(dt_poly, gt_poly)
if iou >= iou_thresh:
all_ious[(index_gt, index_dt)] = iou
- sorted_ious = sorted(
- all_ious.items(), key=operator.itemgetter(1), reverse=True)
+ sorted_ious = sorted(all_ious.items(), key=operator.itemgetter(1), reverse=True)
sorted_gt_dt_pairs = [item[0] for item in sorted_ious]
# matched gt and dt
@@ -140,7 +138,7 @@ def e2e_eval(gt_dir, res_dir, ignore_blank=False):
else:
gt_str = strQ2B(gts[index_gt][8])
dt_str = strQ2B(dts[index_dt][8])
- if ignore_masks[index_gt] == '0':
+ if ignore_masks[index_gt] == "0":
ed_sum += ed(gt_str, dt_str)
num_gt_chars += len(gt_str)
if gt_str == dt_str:
@@ -152,21 +150,21 @@ def e2e_eval(gt_dir, res_dir, ignore_blank=False):
for tindex, dt_match_flag in enumerate(dt_match):
if dt_match_flag == False:
dt_str = dts[tindex][8]
- gt_str = ''
+ gt_str = ""
ed_sum += ed(dt_str, gt_str)
dt_count += 1
# unmatched gt
for tindex, gt_match_flag in enumerate(gt_match):
- if gt_match_flag == False and ignore_masks[tindex] == '0':
- dt_str = ''
+ if gt_match_flag == False and ignore_masks[tindex] == "0":
+ dt_str = ""
gt_str = gts[tindex][8]
ed_sum += ed(gt_str, dt_str)
num_gt_chars += len(gt_str)
gt_count += 1
eps = 1e-9
- print('hit, dt_count, gt_count', hit, dt_count, gt_count)
+ print("hit, dt_count, gt_count", hit, dt_count, gt_count)
precision = hit / (dt_count + eps)
recall = hit / (gt_count + eps)
fmeasure = 2.0 * precision * recall / (precision + recall + eps)
@@ -174,15 +172,15 @@ def e2e_eval(gt_dir, res_dir, ignore_blank=False):
avg_edit_dist_field = ed_sum / (gt_count + eps)
character_acc = 1 - ed_sum / (num_gt_chars + eps)
- print('character_acc: %.2f' % (character_acc * 100) + "%")
- print('avg_edit_dist_field: %.2f' % (avg_edit_dist_field))
- print('avg_edit_dist_img: %.2f' % (avg_edit_dist_img))
- print('precision: %.2f' % (precision * 100) + "%")
- print('recall: %.2f' % (recall * 100) + "%")
- print('fmeasure: %.2f' % (fmeasure * 100) + "%")
+ print("character_acc: %.2f" % (character_acc * 100) + "%")
+ print("avg_edit_dist_field: %.2f" % (avg_edit_dist_field))
+ print("avg_edit_dist_img: %.2f" % (avg_edit_dist_img))
+ print("precision: %.2f" % (precision * 100) + "%")
+ print("recall: %.2f" % (recall * 100) + "%")
+ print("fmeasure: %.2f" % (fmeasure * 100) + "%")
-if __name__ == '__main__':
+if __name__ == "__main__":
# if len(sys.argv) != 3:
# print("python3 ocr_e2e_eval.py gt_dir res_dir")
# exit(-1)
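The metric block above combines detection matching with recognition edit distance: precision and recall come from the matched (hit) pairs, F-measure is their harmonic mean, and character accuracy is one minus the normalized edit-distance sum. A worked sketch of the final arithmetic, with made-up counts:

```python
hit, dt_count, gt_count = 80, 100, 90  # hypothetical match totals
ed_sum, num_gt_chars = 120, 1500       # hypothetical edit-distance totals
eps = 1e-9                              # same guard against zero division

precision = hit / (dt_count + eps)                                 # 0.80
recall = hit / (gt_count + eps)                                    # ~0.889
fmeasure = 2.0 * precision * recall / (precision + recall + eps)   # ~0.842
character_acc = 1 - ed_sum / (num_gt_chars + eps)                  # 0.92
print(precision, recall, fmeasure, character_acc)
```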
diff --git a/tools/eval.py b/tools/eval.py
index b4c69b6d37..9ac5498b75 100755
--- a/tools/eval.py
+++ b/tools/eval.py
@@ -21,7 +21,7 @@
__dir__ = os.path.dirname(os.path.abspath(__file__))
sys.path.insert(0, __dir__)
-sys.path.insert(0, os.path.abspath(os.path.join(__dir__, '..')))
+sys.path.insert(0, os.path.abspath(os.path.join(__dir__, "..")))
import paddle
from ppocr.data import build_dataloader, set_signal_handlers
@@ -33,114 +33,135 @@
def main():
- global_config = config['Global']
+ global_config = config["Global"]
# build dataloader
set_signal_handlers()
- valid_dataloader = build_dataloader(config, 'Eval', device, logger)
+ valid_dataloader = build_dataloader(config, "Eval", device, logger)
# build post process
- post_process_class = build_post_process(config['PostProcess'],
- global_config)
+ post_process_class = build_post_process(config["PostProcess"], global_config)
# build model
# for rec algorithm
- if hasattr(post_process_class, 'character'):
- char_num = len(getattr(post_process_class, 'character'))
- if config['Architecture']["algorithm"] in ["Distillation",
- ]: # distillation model
- for key in config['Architecture']["Models"]:
- if config['Architecture']['Models'][key]['Head'][
- 'name'] == 'MultiHead': # for multi head
+ if hasattr(post_process_class, "character"):
+ char_num = len(getattr(post_process_class, "character"))
+ if config["Architecture"]["algorithm"] in [
+ "Distillation",
+ ]: # distillation model
+ for key in config["Architecture"]["Models"]:
+ if (
+ config["Architecture"]["Models"][key]["Head"]["name"] == "MultiHead"
+ ): # for multi head
out_channels_list = {}
- if config['PostProcess'][
- 'name'] == 'DistillationSARLabelDecode':
+ if config["PostProcess"]["name"] == "DistillationSARLabelDecode":
char_num = char_num - 2
- if config['PostProcess'][
- 'name'] == 'DistillationNRTRLabelDecode':
+ if config["PostProcess"]["name"] == "DistillationNRTRLabelDecode":
char_num = char_num - 3
- out_channels_list['CTCLabelDecode'] = char_num
- out_channels_list['SARLabelDecode'] = char_num + 2
- out_channels_list['NRTRLabelDecode'] = char_num + 3
- config['Architecture']['Models'][key]['Head'][
- 'out_channels_list'] = out_channels_list
+ out_channels_list["CTCLabelDecode"] = char_num
+ out_channels_list["SARLabelDecode"] = char_num + 2
+ out_channels_list["NRTRLabelDecode"] = char_num + 3
+ config["Architecture"]["Models"][key]["Head"][
+ "out_channels_list"
+ ] = out_channels_list
else:
- config['Architecture']["Models"][key]["Head"][
- 'out_channels'] = char_num
- elif config['Architecture']['Head'][
- 'name'] == 'MultiHead': # for multi head
+ config["Architecture"]["Models"][key]["Head"][
+ "out_channels"
+ ] = char_num
+ elif config["Architecture"]["Head"]["name"] == "MultiHead": # for multi head
out_channels_list = {}
- if config['PostProcess']['name'] == 'SARLabelDecode':
+ if config["PostProcess"]["name"] == "SARLabelDecode":
char_num = char_num - 2
- if config['PostProcess']['name'] == 'NRTRLabelDecode':
+ if config["PostProcess"]["name"] == "NRTRLabelDecode":
char_num = char_num - 3
- out_channels_list['CTCLabelDecode'] = char_num
- out_channels_list['SARLabelDecode'] = char_num + 2
- out_channels_list['NRTRLabelDecode'] = char_num + 3
- config['Architecture']['Head'][
- 'out_channels_list'] = out_channels_list
+ out_channels_list["CTCLabelDecode"] = char_num
+ out_channels_list["SARLabelDecode"] = char_num + 2
+ out_channels_list["NRTRLabelDecode"] = char_num + 3
+ config["Architecture"]["Head"]["out_channels_list"] = out_channels_list
else: # base rec model
- config['Architecture']["Head"]['out_channels'] = char_num
+ config["Architecture"]["Head"]["out_channels"] = char_num
- model = build_model(config['Architecture'])
+ model = build_model(config["Architecture"])
extra_input_models = [
- "SRN", "NRTR", "SAR", "SEED", "SVTR", "SVTR_LCNet", "VisionLAN",
- "RobustScanner", "SVTR_HGNet"
+ "SRN",
+ "NRTR",
+ "SAR",
+ "SEED",
+ "SVTR",
+ "SVTR_LCNet",
+ "VisionLAN",
+ "RobustScanner",
+ "SVTR_HGNet",
]
extra_input = False
- if config['Architecture']['algorithm'] == 'Distillation':
- for key in config['Architecture']["Models"]:
- extra_input = extra_input or config['Architecture']['Models'][key][
- 'algorithm'] in extra_input_models
+ if config["Architecture"]["algorithm"] == "Distillation":
+ for key in config["Architecture"]["Models"]:
+ extra_input = (
+ extra_input
+ or config["Architecture"]["Models"][key]["algorithm"]
+ in extra_input_models
+ )
else:
- extra_input = config['Architecture']['algorithm'] in extra_input_models
- if "model_type" in config['Architecture'].keys():
- if config['Architecture']['algorithm'] == 'CAN':
- model_type = 'can'
+ extra_input = config["Architecture"]["algorithm"] in extra_input_models
+ if "model_type" in config["Architecture"].keys():
+ if config["Architecture"]["algorithm"] == "CAN":
+ model_type = "can"
else:
- model_type = config['Architecture']['model_type']
+ model_type = config["Architecture"]["model_type"]
else:
model_type = None
# build metric
- eval_class = build_metric(config['Metric'])
+ eval_class = build_metric(config["Metric"])
# amp
use_amp = config["Global"].get("use_amp", False)
- amp_level = config["Global"].get("amp_level", 'O2')
- amp_custom_black_list = config['Global'].get('amp_custom_black_list', [])
+ amp_level = config["Global"].get("amp_level", "O2")
+ amp_custom_black_list = config["Global"].get("amp_custom_black_list", [])
if use_amp:
AMP_RELATED_FLAGS_SETTING = {
- 'FLAGS_cudnn_batchnorm_spatial_persistent': 1,
- 'FLAGS_max_inplace_grad_add': 8,
+ "FLAGS_cudnn_batchnorm_spatial_persistent": 1,
+ "FLAGS_max_inplace_grad_add": 8,
}
paddle.set_flags(AMP_RELATED_FLAGS_SETTING)
scale_loss = config["Global"].get("scale_loss", 1.0)
use_dynamic_loss_scaling = config["Global"].get(
- "use_dynamic_loss_scaling", False)
+ "use_dynamic_loss_scaling", False
+ )
scaler = paddle.amp.GradScaler(
init_loss_scaling=scale_loss,
- use_dynamic_loss_scaling=use_dynamic_loss_scaling)
+ use_dynamic_loss_scaling=use_dynamic_loss_scaling,
+ )
if amp_level == "O2":
model = paddle.amp.decorate(
- models=model, level=amp_level, master_weight=True)
+ models=model, level=amp_level, master_weight=True
+ )
else:
scaler = None
best_model_dict = load_model(
- config, model, model_type=config['Architecture']["model_type"])
+ config, model, model_type=config["Architecture"]["model_type"]
+ )
if len(best_model_dict):
- logger.info('metric in ckpt ***************')
+ logger.info("metric in ckpt ***************")
for k, v in best_model_dict.items():
- logger.info('{}:{}'.format(k, v))
+ logger.info("{}:{}".format(k, v))
# start eval
- metric = program.eval(model, valid_dataloader, post_process_class,
- eval_class, model_type, extra_input, scaler,
- amp_level, amp_custom_black_list)
- logger.info('metric eval ***************')
+ metric = program.eval(
+ model,
+ valid_dataloader,
+ post_process_class,
+ eval_class,
+ model_type,
+ extra_input,
+ scaler,
+ amp_level,
+ amp_custom_black_list,
+ )
+ logger.info("metric eval ***************")
for k, v in metric.items():
- logger.info('{}:{}'.format(k, v))
+ logger.info("{}:{}".format(k, v))
-if __name__ == '__main__':
+if __name__ == "__main__":
config, device, logger, vdl_writer = program.preprocess()
main()
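The repeated block in eval.py (and again in export_model.py below) derives per-decoder output widths from the character set: SAR reserves 2 extra symbols and NRTR 3, while CTC uses the base count. A compact sketch of that rule, using a hypothetical 6623-character dictionary:

```python
def multihead_out_channels(char_num, post_process_name):
    # SAR/NRTR decoders first subtract their reserved tokens from the base count.
    if post_process_name == "SARLabelDecode":
        char_num -= 2
    if post_process_name == "NRTRLabelDecode":
        char_num -= 3
    return {
        "CTCLabelDecode": char_num,
        "SARLabelDecode": char_num + 2,
        "NRTRLabelDecode": char_num + 3,
    }


print(multihead_out_channels(6623, "CTCLabelDecode"))
# {'CTCLabelDecode': 6623, 'SARLabelDecode': 6625, 'NRTRLabelDecode': 6626}
```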
diff --git a/tools/export_center.py b/tools/export_center.py
index 3f7a883528..e79c2b8629 100644
--- a/tools/export_center.py
+++ b/tools/export_center.py
@@ -22,7 +22,7 @@
__dir__ = os.path.dirname(os.path.abspath(__file__))
sys.path.append(__dir__)
-sys.path.append(os.path.abspath(os.path.join(__dir__, '..')))
+sys.path.append(os.path.abspath(os.path.join(__dir__, "..")))
from ppocr.data import build_dataloader, set_signal_handlers
from ppocr.modeling.architectures import build_model
@@ -33,46 +33,45 @@
def main():
- global_config = config['Global']
+ global_config = config["Global"]
# build dataloader
- config['Eval']['dataset']['name'] = config['Train']['dataset']['name']
- config['Eval']['dataset']['data_dir'] = config['Train']['dataset'][
- 'data_dir']
- config['Eval']['dataset']['label_file_list'] = config['Train']['dataset'][
- 'label_file_list']
+ config["Eval"]["dataset"]["name"] = config["Train"]["dataset"]["name"]
+ config["Eval"]["dataset"]["data_dir"] = config["Train"]["dataset"]["data_dir"]
+ config["Eval"]["dataset"]["label_file_list"] = config["Train"]["dataset"][
+ "label_file_list"
+ ]
set_signal_handlers()
- eval_dataloader = build_dataloader(config, 'Eval', device, logger)
+ eval_dataloader = build_dataloader(config, "Eval", device, logger)
# build post process
- post_process_class = build_post_process(config['PostProcess'],
- global_config)
+ post_process_class = build_post_process(config["PostProcess"], global_config)
# build model
# for rec algorithm
- if hasattr(post_process_class, 'character'):
- char_num = len(getattr(post_process_class, 'character'))
- config['Architecture']["Head"]['out_channels'] = char_num
+ if hasattr(post_process_class, "character"):
+ char_num = len(getattr(post_process_class, "character"))
+ config["Architecture"]["Head"]["out_channels"] = char_num
- #set return_features = True
- config['Architecture']["Head"]["return_feats"] = True
+ # set return_features = True
+ config["Architecture"]["Head"]["return_feats"] = True
- model = build_model(config['Architecture'])
+ model = build_model(config["Architecture"])
best_model_dict = load_model(config, model)
if len(best_model_dict):
- logger.info('metric in ckpt ***************')
+ logger.info("metric in ckpt ***************")
for k, v in best_model_dict.items():
- logger.info('{}:{}'.format(k, v))
+ logger.info("{}:{}".format(k, v))
# get features from train data
char_center = program.get_center(model, eval_dataloader, post_process_class)
- #serialize to disk
- with open("train_center.pkl", 'wb') as f:
+ # serialize to disk
+ with open("train_center.pkl", "wb") as f:
pickle.dump(char_center, f)
return
-if __name__ == '__main__':
+if __name__ == "__main__":
config, device, logger, vdl_writer = program.preprocess()
main()
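The last step above just serializes the computed character centers with pickle. For reference, a round-trip sketch with a hypothetical centers dict:

```python
import pickle

char_center = {"a": [0.12, 0.5], "b": [0.3, 0.7]}  # hypothetical feature centers
with open("train_center.pkl", "wb") as f:
    pickle.dump(char_center, f)
with open("train_center.pkl", "rb") as f:
    restored = pickle.load(f)
assert restored == char_center
```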
diff --git a/tools/export_model.py b/tools/export_model.py
index 8228175ead..8ca31c9d58 100755
--- a/tools/export_model.py
+++ b/tools/export_model.py
@@ -31,136 +31,125 @@
from tools.program import load_config, merge_config, ArgsParser
-def export_single_model(model,
- arch_config,
- save_path,
- logger,
- input_shape=None,
- quanter=None):
+def export_single_model(
+ model, arch_config, save_path, logger, input_shape=None, quanter=None
+):
if arch_config["algorithm"] == "SRN":
max_text_length = arch_config["Head"]["max_text_length"]
other_shape = [
- paddle.static.InputSpec(
- shape=[None, 1, 64, 256], dtype="float32"), [
- paddle.static.InputSpec(
- shape=[None, 256, 1],
- dtype="int64"), paddle.static.InputSpec(
- shape=[None, max_text_length, 1], dtype="int64"),
- paddle.static.InputSpec(
- shape=[None, 8, max_text_length, max_text_length],
- dtype="int64"), paddle.static.InputSpec(
- shape=[None, 8, max_text_length, max_text_length],
- dtype="int64")
- ]
+ paddle.static.InputSpec(shape=[None, 1, 64, 256], dtype="float32"),
+ [
+ paddle.static.InputSpec(shape=[None, 256, 1], dtype="int64"),
+ paddle.static.InputSpec(
+ shape=[None, max_text_length, 1], dtype="int64"
+ ),
+ paddle.static.InputSpec(
+ shape=[None, 8, max_text_length, max_text_length], dtype="int64"
+ ),
+ paddle.static.InputSpec(
+ shape=[None, 8, max_text_length, max_text_length], dtype="int64"
+ ),
+ ],
]
model = to_static(model, input_spec=other_shape)
elif arch_config["algorithm"] == "SAR":
other_shape = [
- paddle.static.InputSpec(
- shape=[None, 3, 48, 160], dtype="float32"),
- [paddle.static.InputSpec(
- shape=[None], dtype="float32")]
+ paddle.static.InputSpec(shape=[None, 3, 48, 160], dtype="float32"),
+ [paddle.static.InputSpec(shape=[None], dtype="float32")],
]
model = to_static(model, input_spec=other_shape)
elif arch_config["algorithm"] in ["SVTR_LCNet", "SVTR_HGNet"]:
other_shape = [
- paddle.static.InputSpec(
- shape=[None, 3, 48, -1], dtype="float32"),
+ paddle.static.InputSpec(shape=[None, 3, 48, -1], dtype="float32"),
]
model = to_static(model, input_spec=other_shape)
elif arch_config["algorithm"] in ["SVTR", "CPPD"]:
other_shape = [
- paddle.static.InputSpec(
- shape=[None] + input_shape, dtype="float32"),
+ paddle.static.InputSpec(shape=[None] + input_shape, dtype="float32"),
]
model = to_static(model, input_spec=other_shape)
elif arch_config["algorithm"] == "PREN":
other_shape = [
- paddle.static.InputSpec(
- shape=[None, 3, 64, 256], dtype="float32"),
+ paddle.static.InputSpec(shape=[None, 3, 64, 256], dtype="float32"),
]
model = to_static(model, input_spec=other_shape)
elif arch_config["model_type"] == "sr":
other_shape = [
- paddle.static.InputSpec(
- shape=[None, 3, 16, 64], dtype="float32")
+ paddle.static.InputSpec(shape=[None, 3, 16, 64], dtype="float32")
]
model = to_static(model, input_spec=other_shape)
elif arch_config["algorithm"] == "ViTSTR":
other_shape = [
- paddle.static.InputSpec(
- shape=[None, 1, 224, 224], dtype="float32"),
+ paddle.static.InputSpec(shape=[None, 1, 224, 224], dtype="float32"),
]
model = to_static(model, input_spec=other_shape)
elif arch_config["algorithm"] == "ABINet":
if not input_shape:
input_shape = [3, 32, 128]
other_shape = [
- paddle.static.InputSpec(
- shape=[None] + input_shape, dtype="float32"),
+ paddle.static.InputSpec(shape=[None] + input_shape, dtype="float32"),
]
model = to_static(model, input_spec=other_shape)
- elif arch_config["algorithm"] in ["NRTR", "SPIN", 'RFL']:
+ elif arch_config["algorithm"] in ["NRTR", "SPIN", "RFL"]:
other_shape = [
- paddle.static.InputSpec(
- shape=[None, 1, 32, 100], dtype="float32"),
+ paddle.static.InputSpec(shape=[None, 1, 32, 100], dtype="float32"),
]
model = to_static(model, input_spec=other_shape)
- elif arch_config["algorithm"] in ['SATRN']:
+ elif arch_config["algorithm"] in ["SATRN"]:
other_shape = [
- paddle.static.InputSpec(
- shape=[None, 3, 32, 100], dtype="float32"),
+ paddle.static.InputSpec(shape=[None, 3, 32, 100], dtype="float32"),
]
model = to_static(model, input_spec=other_shape)
elif arch_config["algorithm"] == "VisionLAN":
other_shape = [
- paddle.static.InputSpec(
- shape=[None, 3, 64, 256], dtype="float32"),
+ paddle.static.InputSpec(shape=[None, 3, 64, 256], dtype="float32"),
]
model = to_static(model, input_spec=other_shape)
elif arch_config["algorithm"] == "RobustScanner":
max_text_length = arch_config["Head"]["max_text_length"]
other_shape = [
- paddle.static.InputSpec(
- shape=[None, 3, 48, 160], dtype="float32"), [
- paddle.static.InputSpec(
- shape=[None, ], dtype="float32"),
- paddle.static.InputSpec(
- shape=[None, max_text_length], dtype="int64")
- ]
+ paddle.static.InputSpec(shape=[None, 3, 48, 160], dtype="float32"),
+ [
+ paddle.static.InputSpec(
+ shape=[
+ None,
+ ],
+ dtype="float32",
+ ),
+ paddle.static.InputSpec(shape=[None, max_text_length], dtype="int64"),
+ ],
]
model = to_static(model, input_spec=other_shape)
elif arch_config["algorithm"] == "CAN":
- other_shape = [[
- paddle.static.InputSpec(
- shape=[None, 1, None, None],
- dtype="float32"), paddle.static.InputSpec(
- shape=[None, 1, None, None], dtype="float32"),
- paddle.static.InputSpec(
- shape=[None, arch_config['Head']['max_text_length']],
- dtype="int64")
- ]]
+ other_shape = [
+ [
+ paddle.static.InputSpec(shape=[None, 1, None, None], dtype="float32"),
+ paddle.static.InputSpec(shape=[None, 1, None, None], dtype="float32"),
+ paddle.static.InputSpec(
+ shape=[None, arch_config["Head"]["max_text_length"]], dtype="int64"
+ ),
+ ]
+ ]
model = to_static(model, input_spec=other_shape)
elif arch_config["algorithm"] in ["LayoutLM", "LayoutLMv2", "LayoutXLM"]:
input_spec = [
- paddle.static.InputSpec(
- shape=[None, 512], dtype="int64"), # input_ids
- paddle.static.InputSpec(
- shape=[None, 512, 4], dtype="int64"), # bbox
- paddle.static.InputSpec(
- shape=[None, 512], dtype="int64"), # attention_mask
- paddle.static.InputSpec(
- shape=[None, 512], dtype="int64"), # token_type_ids
- paddle.static.InputSpec(
- shape=[None, 3, 224, 224], dtype="int64"), # image
+ paddle.static.InputSpec(shape=[None, 512], dtype="int64"), # input_ids
+ paddle.static.InputSpec(shape=[None, 512, 4], dtype="int64"), # bbox
+ paddle.static.InputSpec(shape=[None, 512], dtype="int64"), # attention_mask
+ paddle.static.InputSpec(shape=[None, 512], dtype="int64"), # token_type_ids
+ paddle.static.InputSpec(shape=[None, 3, 224, 224], dtype="int64"), # image
]
- if 'Re' in arch_config['Backbone']['name']:
- input_spec.extend([
- paddle.static.InputSpec(
- shape=[None, 512, 3], dtype="int64"), # entities
- paddle.static.InputSpec(
- shape=[None, None, 2], dtype="int64"), # relations
- ])
+ if "Re" in arch_config["Backbone"]["name"]:
+ input_spec.extend(
+ [
+ paddle.static.InputSpec(
+ shape=[None, 512, 3], dtype="int64"
+ ), # entities
+ paddle.static.InputSpec(
+ shape=[None, None, 2], dtype="int64"
+ ), # relations
+ ]
+ )
if model.backbone.use_visual_backbone is False:
input_spec.pop(4)
model = to_static(model, input_spec=[input_spec])
@@ -168,9 +157,11 @@ def export_single_model(model,
infer_shape = [3, -1, -1]
if arch_config["model_type"] == "rec":
infer_shape = [3, 32, -1] # for rec model, H must be 32
- if "Transform" in arch_config and arch_config[
- "Transform"] is not None and arch_config["Transform"][
- "name"] == "TPS":
+ if (
+ "Transform" in arch_config
+ and arch_config["Transform"] is not None
+ and arch_config["Transform"]["name"] == "TPS"
+ ):
logger.info(
"When there is tps in the network, variable length input is not supported, and the input size needs to be the same as during training"
)
@@ -184,12 +175,14 @@ def export_single_model(model,
model = to_static(
model,
input_spec=[
- paddle.static.InputSpec(
- shape=[None] + infer_shape, dtype="float32")
- ])
+ paddle.static.InputSpec(shape=[None] + infer_shape, dtype="float32")
+ ],
+ )
- if arch_config["model_type"] != "sr" and arch_config["Backbone"][
- "name"] == "PPLCNetV3":
+ if (
+ arch_config["model_type"] != "sr"
+ and arch_config["Backbone"]["name"] == "PPLCNetV3"
+ ):
# for rep lcnetv3
for layer in model.sublayers():
if hasattr(layer, "rep") and not getattr(layer, "is_repped"):
@@ -210,83 +203,92 @@ def main():
logger = get_logger()
# build post process
- post_process_class = build_post_process(config["PostProcess"],
- config["Global"])
+ post_process_class = build_post_process(config["PostProcess"], config["Global"])
# build model
# for rec algorithm
if hasattr(post_process_class, "character"):
char_num = len(getattr(post_process_class, "character"))
- if config["Architecture"]["algorithm"] in ["Distillation",
- ]: # distillation model
+ if config["Architecture"]["algorithm"] in [
+ "Distillation",
+ ]: # distillation model
for key in config["Architecture"]["Models"]:
- if config["Architecture"]["Models"][key]["Head"][
- "name"] == 'MultiHead': # multi head
+ if (
+ config["Architecture"]["Models"][key]["Head"]["name"] == "MultiHead"
+ ): # multi head
out_channels_list = {}
- if config['PostProcess'][
- 'name'] == 'DistillationSARLabelDecode':
+ if config["PostProcess"]["name"] == "DistillationSARLabelDecode":
char_num = char_num - 2
- if config['PostProcess'][
- 'name'] == 'DistillationNRTRLabelDecode':
+ if config["PostProcess"]["name"] == "DistillationNRTRLabelDecode":
char_num = char_num - 3
- out_channels_list['CTCLabelDecode'] = char_num
- out_channels_list['SARLabelDecode'] = char_num + 2
- out_channels_list['NRTRLabelDecode'] = char_num + 3
- config['Architecture']['Models'][key]['Head'][
- 'out_channels_list'] = out_channels_list
+ out_channels_list["CTCLabelDecode"] = char_num
+ out_channels_list["SARLabelDecode"] = char_num + 2
+ out_channels_list["NRTRLabelDecode"] = char_num + 3
+ config["Architecture"]["Models"][key]["Head"][
+ "out_channels_list"
+ ] = out_channels_list
else:
config["Architecture"]["Models"][key]["Head"][
- "out_channels"] = char_num
+ "out_channels"
+ ] = char_num
# just one final tensor needs to be exported for inference
- config["Architecture"]["Models"][key][
- "return_all_feats"] = False
- elif config['Architecture']['Head'][
- 'name'] == 'MultiHead': # multi head
+ config["Architecture"]["Models"][key]["return_all_feats"] = False
+ elif config["Architecture"]["Head"]["name"] == "MultiHead": # multi head
out_channels_list = {}
- char_num = len(getattr(post_process_class, 'character'))
- if config['PostProcess']['name'] == 'SARLabelDecode':
+ char_num = len(getattr(post_process_class, "character"))
+ if config["PostProcess"]["name"] == "SARLabelDecode":
char_num = char_num - 2
- if config['PostProcess']['name'] == 'NRTRLabelDecode':
+ if config["PostProcess"]["name"] == "NRTRLabelDecode":
char_num = char_num - 3
- out_channels_list['CTCLabelDecode'] = char_num
- out_channels_list['SARLabelDecode'] = char_num + 2
- out_channels_list['NRTRLabelDecode'] = char_num + 3
- config['Architecture']['Head'][
- 'out_channels_list'] = out_channels_list
+ out_channels_list["CTCLabelDecode"] = char_num
+ out_channels_list["SARLabelDecode"] = char_num + 2
+ out_channels_list["NRTRLabelDecode"] = char_num + 3
+ config["Architecture"]["Head"]["out_channels_list"] = out_channels_list
else: # base rec model
config["Architecture"]["Head"]["out_channels"] = char_num
# for sr algorithm
if config["Architecture"]["model_type"] == "sr":
- config['Architecture']["Transform"]['infer_mode'] = True
+ config["Architecture"]["Transform"]["infer_mode"] = True
model = build_model(config["Architecture"])
- load_model(config, model, model_type=config['Architecture']["model_type"])
+ load_model(config, model, model_type=config["Architecture"]["model_type"])
model.eval()
save_path = config["Global"]["save_inference_dir"]
arch_config = config["Architecture"]
- if arch_config["algorithm"] in ["SVTR", "CPPD"] and arch_config["Head"][
- "name"] != 'MultiHead':
- input_shape = config["Eval"]["dataset"]["transforms"][-2][
- 'SVTRRecResizeImg']['image_shape']
+ if (
+ arch_config["algorithm"] in ["SVTR", "CPPD"]
+ and arch_config["Head"]["name"] != "MultiHead"
+ ):
+ input_shape = config["Eval"]["dataset"]["transforms"][-2]["SVTRRecResizeImg"][
+ "image_shape"
+ ]
elif arch_config["algorithm"].lower() == "ABINet".lower():
- rec_rs = [c for c in config["Eval"]["dataset"]["transforms"] if 'ABINetRecResizeImg' in c]
- input_shape = rec_rs[0]['ABINetRecResizeImg']['image_shape'] if rec_rs else None
+ rec_rs = [
+ c
+ for c in config["Eval"]["dataset"]["transforms"]
+ if "ABINetRecResizeImg" in c
+ ]
+ input_shape = rec_rs[0]["ABINetRecResizeImg"]["image_shape"] if rec_rs else None
else:
input_shape = None
- if arch_config["algorithm"] in ["Distillation", ]: # distillation model
+ if arch_config["algorithm"] in [
+ "Distillation",
+ ]: # distillation model
archs = list(arch_config["Models"].values())
for idx, name in enumerate(model.model_name_list):
sub_model_save_path = os.path.join(save_path, name, "inference")
- export_single_model(model.model_list[idx], archs[idx],
- sub_model_save_path, logger)
+ export_single_model(
+ model.model_list[idx], archs[idx], sub_model_save_path, logger
+ )
else:
save_path = os.path.join(save_path, "inference")
export_single_model(
- model, arch_config, save_path, logger, input_shape=input_shape)
+ model, arch_config, save_path, logger, input_shape=input_shape
+ )
if __name__ == "__main__":
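Most branches in export_single_model differ only in the InputSpec shapes they hand to to_static. A minimal sketch of that export flow for one fixed-shape input, assuming a working Paddle install; TinyModel and the save path are placeholders, not part of the repository:

```python
import paddle
from paddle.jit import to_static


class TinyModel(paddle.nn.Layer):
    def __init__(self):
        super().__init__()
        self.conv = paddle.nn.Conv2D(3, 8, 3, padding=1)

    def forward(self, x):
        return self.conv(x)


model = TinyModel()
model.eval()
# None in the batch dimension keeps it dynamic, like the specs above.
spec = [paddle.static.InputSpec(shape=[None, 3, 48, 160], dtype="float32")]
static_model = to_static(model, input_spec=spec)
paddle.jit.save(static_model, "./inference/tiny")  # writes the inference program
```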
diff --git a/tools/infer/predict_cls.py b/tools/infer/predict_cls.py
index d2b7108ca3..e4d5a88c12 100755
--- a/tools/infer/predict_cls.py
+++ b/tools/infer/predict_cls.py
@@ -16,9 +16,9 @@
__dir__ = os.path.dirname(os.path.abspath(__file__))
sys.path.append(__dir__)
-sys.path.insert(0, os.path.abspath(os.path.join(__dir__, '../..')))
+sys.path.insert(0, os.path.abspath(os.path.join(__dir__, "../..")))
-os.environ["FLAGS_allocator_strategy"] = 'auto_growth'
+os.environ["FLAGS_allocator_strategy"] = "auto_growth"
import cv2
import copy
@@ -41,12 +41,16 @@ def __init__(self, args):
self.cls_batch_num = args.cls_batch_num
self.cls_thresh = args.cls_thresh
postprocess_params = {
- 'name': 'ClsPostProcess',
+ "name": "ClsPostProcess",
"label_list": args.label_list,
}
self.postprocess_op = build_post_process(postprocess_params)
- self.predictor, self.input_tensor, self.output_tensors, _ = \
- utility.create_predictor(args, 'cls', logger)
+ (
+ self.predictor,
+ self.input_tensor,
+ self.output_tensors,
+ _,
+ ) = utility.create_predictor(args, "cls", logger)
self.use_onnx = args.use_onnx
def resize_norm_img(self, img):
@@ -59,7 +63,7 @@ def resize_norm_img(self, img):
else:
resized_w = int(math.ceil(imgH * ratio))
resized_image = cv2.resize(img, (resized_w, imgH))
- resized_image = resized_image.astype('float32')
+ resized_image = resized_image.astype("float32")
if self.cls_image_shape[0] == 1:
resized_image = resized_image / 255
resized_image = resized_image[np.newaxis, :]
@@ -81,11 +85,10 @@ def __call__(self, img_list):
# Sorting can speed up the cls process
indices = np.argsort(np.array(width_list))
- cls_res = [['', 0.0]] * img_num
+ cls_res = [["", 0.0]] * img_num
batch_num = self.cls_batch_num
elapse = 0
for beg_img_no in range(0, img_num, batch_num):
-
end_img_no = min(img_num, beg_img_no + batch_num)
norm_img_batch = []
max_wh_ratio = 0
@@ -116,9 +119,10 @@ def __call__(self, img_list):
for rno in range(len(cls_result)):
label, score = cls_result[rno]
cls_res[indices[beg_img_no + rno]] = [label, score]
- if '180' in label and score > self.cls_thresh:
+ if "180" in label and score > self.cls_thresh:
img_list[indices[beg_img_no + rno]] = cv2.rotate(
- img_list[indices[beg_img_no + rno]], 1)
+ img_list[indices[beg_img_no + rno]], 1
+ )
return img_list, cls_res, elapse
@@ -143,8 +147,9 @@ def main(args):
logger.info(E)
exit()
for ino in range(len(img_list)):
- logger.info("Predicts of {}:{}".format(valid_image_file_list[ino],
- cls_res[ino]))
+ logger.info(
+ "Predicts of {}:{}".format(valid_image_file_list[ino], cls_res[ino])
+ )
if __name__ == "__main__":
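The classifier's only side effect on the pipeline is the conditional rotation above: when the predicted label contains "180" with enough confidence, the crop is flipped in place. The magic constant 1 passed to cv2.rotate is cv2.ROTATE_180, as this small demonstration shows:

```python
import cv2
import numpy as np

img = np.arange(12, dtype=np.uint8).reshape(3, 4)
label, score, cls_thresh = "180", 0.97, 0.9
if "180" in label and score > cls_thresh:
    img = cv2.rotate(img, cv2.ROTATE_180)  # equivalent to the literal 1 in the diff
print(img)  # the original array reversed along both axes
```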
diff --git a/tools/infer/predict_det.py b/tools/infer/predict_det.py
index 831a49e6d4..148010acc1 100755
--- a/tools/infer/predict_det.py
+++ b/tools/infer/predict_det.py
@@ -16,9 +16,9 @@
__dir__ = os.path.dirname(os.path.abspath(__file__))
sys.path.append(__dir__)
-sys.path.insert(0, os.path.abspath(os.path.join(__dir__, '../..')))
+sys.path.insert(0, os.path.abspath(os.path.join(__dir__, "../..")))
-os.environ["FLAGS_allocator_strategy"] = 'auto_growth'
+os.environ["FLAGS_allocator_strategy"] = "auto_growth"
import cv2
import numpy as np
@@ -31,6 +31,7 @@
from ppocr.data import create_operators, transform
from ppocr.postprocess import build_post_process
import json
+
logger = get_logger()
@@ -39,28 +40,27 @@ def __init__(self, args):
self.args = args
self.det_algorithm = args.det_algorithm
self.use_onnx = args.use_onnx
- pre_process_list = [{
- 'DetResizeForTest': {
- 'limit_side_len': args.det_limit_side_len,
- 'limit_type': args.det_limit_type,
- }
- }, {
- 'NormalizeImage': {
- 'std': [0.229, 0.224, 0.225],
- 'mean': [0.485, 0.456, 0.406],
- 'scale': '1./255.',
- 'order': 'hwc'
- }
- }, {
- 'ToCHWImage': None
- }, {
- 'KeepKeys': {
- 'keep_keys': ['image', 'shape']
- }
- }]
+ pre_process_list = [
+ {
+ "DetResizeForTest": {
+ "limit_side_len": args.det_limit_side_len,
+ "limit_type": args.det_limit_type,
+ }
+ },
+ {
+ "NormalizeImage": {
+ "std": [0.229, 0.224, 0.225],
+ "mean": [0.485, 0.456, 0.406],
+ "scale": "1./255.",
+ "order": "hwc",
+ }
+ },
+ {"ToCHWImage": None},
+ {"KeepKeys": {"keep_keys": ["image", "shape"]}},
+ ]
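The operator list just reformatted above is declarative config: each dict names a ppocr transform and its arguments. Its NormalizeImage step, with scale "1./255.", ImageNet mean/std, and "hwc" order, amounts to the following numpy arithmetic (a sketch of the effect, not the library code):

```python
import numpy as np

img = np.random.randint(0, 256, (32, 32, 3)).astype("float32")  # HWC, 0-255 range
mean = np.array([0.485, 0.456, 0.406], dtype="float32")
std = np.array([0.229, 0.224, 0.225], dtype="float32")
normalized = (img * (1.0 / 255.0) - mean) / std  # broadcast over the channel axis
chw = normalized.transpose((2, 0, 1))  # what ToCHWImage then does
print(chw.shape)  # (3, 32, 32)
```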
postprocess_params = {}
if self.det_algorithm == "DB":
- postprocess_params['name'] = 'DBPostProcess'
+ postprocess_params["name"] = "DBPostProcess"
postprocess_params["thresh"] = args.det_db_thresh
postprocess_params["box_thresh"] = args.det_db_box_thresh
postprocess_params["max_candidates"] = 1000
@@ -69,7 +69,7 @@ def __init__(self, args):
postprocess_params["score_mode"] = args.det_db_score_mode
postprocess_params["box_type"] = args.det_box_type
elif self.det_algorithm == "DB++":
- postprocess_params['name'] = 'DBPostProcess'
+ postprocess_params["name"] = "DBPostProcess"
postprocess_params["thresh"] = args.det_db_thresh
postprocess_params["box_thresh"] = args.det_db_box_thresh
postprocess_params["max_candidates"] = 1000
@@ -78,30 +78,27 @@ def __init__(self, args):
postprocess_params["score_mode"] = args.det_db_score_mode
postprocess_params["box_type"] = args.det_box_type
pre_process_list[1] = {
- 'NormalizeImage': {
- 'std': [1.0, 1.0, 1.0],
- 'mean':
- [0.48109378172549, 0.45752457890196, 0.40787054090196],
- 'scale': '1./255.',
- 'order': 'hwc'
+ "NormalizeImage": {
+ "std": [1.0, 1.0, 1.0],
+ "mean": [0.48109378172549, 0.45752457890196, 0.40787054090196],
+ "scale": "1./255.",
+ "order": "hwc",
}
}
elif self.det_algorithm == "EAST":
- postprocess_params['name'] = 'EASTPostProcess'
+ postprocess_params["name"] = "EASTPostProcess"
postprocess_params["score_thresh"] = args.det_east_score_thresh
postprocess_params["cover_thresh"] = args.det_east_cover_thresh
postprocess_params["nms_thresh"] = args.det_east_nms_thresh
elif self.det_algorithm == "SAST":
pre_process_list[0] = {
- 'DetResizeForTest': {
- 'resize_long': args.det_limit_side_len
- }
+ "DetResizeForTest": {"resize_long": args.det_limit_side_len}
}
- postprocess_params['name'] = 'SASTPostProcess'
+ postprocess_params["name"] = "SASTPostProcess"
postprocess_params["score_thresh"] = args.det_sast_score_thresh
postprocess_params["nms_thresh"] = args.det_sast_nms_thresh
- if args.det_box_type == 'poly':
+ if args.det_box_type == "poly":
postprocess_params["sample_pts_num"] = 6
postprocess_params["expand_scale"] = 1.2
postprocess_params["shrink_ratio_of_width"] = 0.2
@@ -111,35 +108,35 @@ def __init__(self, args):
postprocess_params["shrink_ratio_of_width"] = 0.3
elif self.det_algorithm == "PSE":
- postprocess_params['name'] = 'PSEPostProcess'
+ postprocess_params["name"] = "PSEPostProcess"
postprocess_params["thresh"] = args.det_pse_thresh
postprocess_params["box_thresh"] = args.det_pse_box_thresh
postprocess_params["min_area"] = args.det_pse_min_area
postprocess_params["box_type"] = args.det_box_type
postprocess_params["scale"] = args.det_pse_scale
elif self.det_algorithm == "FCE":
- pre_process_list[0] = {
- 'DetResizeForTest': {
- 'rescale_img': [1080, 736]
- }
- }
- postprocess_params['name'] = 'FCEPostProcess'
+ pre_process_list[0] = {"DetResizeForTest": {"rescale_img": [1080, 736]}}
+ postprocess_params["name"] = "FCEPostProcess"
postprocess_params["scales"] = args.scales
postprocess_params["alpha"] = args.alpha
postprocess_params["beta"] = args.beta
postprocess_params["fourier_degree"] = args.fourier_degree
postprocess_params["box_type"] = args.det_box_type
elif self.det_algorithm == "CT":
- pre_process_list[0] = {'ScaleAlignedShort': {'short_size': 640}}
- postprocess_params['name'] = 'CTPostProcess'
+ pre_process_list[0] = {"ScaleAlignedShort": {"short_size": 640}}
+ postprocess_params["name"] = "CTPostProcess"
else:
logger.info("unknown det_algorithm:{}".format(self.det_algorithm))
sys.exit(0)
self.preprocess_op = create_operators(pre_process_list)
self.postprocess_op = build_post_process(postprocess_params)
- self.predictor, self.input_tensor, self.output_tensors, self.config = utility.create_predictor(
- args, 'det', logger)
+ (
+ self.predictor,
+ self.input_tensor,
+ self.output_tensors,
+ self.config,
+ ) = utility.create_predictor(args, "det", logger)
if self.use_onnx:
img_h, img_w = self.input_tensor.shape[2:]
@@ -147,14 +144,13 @@ def __init__(self, args):
pass
elif img_h is not None and img_w is not None and img_h > 0 and img_w > 0:
pre_process_list[0] = {
- 'DetResizeForTest': {
- 'image_shape': [img_h, img_w]
- }
+ "DetResizeForTest": {"image_shape": [img_h, img_w]}
}
self.preprocess_op = create_operators(pre_process_list)
if args.benchmark:
import auto_log
+
pid = os.getpid()
gpu_id = utility.get_infer_gpuid()
self.autolog = auto_log.AutoLogger(
@@ -167,11 +163,10 @@ def __init__(self, args):
pids=pid,
process_name=None,
gpu_ids=gpu_id if args.use_gpu else None,
- time_keys=[
- 'preprocess_time', 'inference_time', 'postprocess_time'
- ],
+ time_keys=["preprocess_time", "inference_time", "postprocess_time"],
warmup=2,
- logger=logger)
+ logger=logger,
+ )
def order_points_clockwise(self, pts):
rect = np.zeros((4, 2), dtype="float32")
@@ -219,7 +214,7 @@ def filter_tag_det_res_only_clip(self, dt_boxes, image_shape):
def predict(self, img):
ori_im = img.copy()
- data = {'image': img}
+ data = {"image": img}
st = time.time()
@@ -252,28 +247,28 @@ def predict(self, img):
preds = {}
if self.det_algorithm == "EAST":
- preds['f_geo'] = outputs[0]
- preds['f_score'] = outputs[1]
- elif self.det_algorithm == 'SAST':
- preds['f_border'] = outputs[0]
- preds['f_score'] = outputs[1]
- preds['f_tco'] = outputs[2]
- preds['f_tvo'] = outputs[3]
- elif self.det_algorithm in ['DB', 'PSE', 'DB++']:
- preds['maps'] = outputs[0]
- elif self.det_algorithm == 'FCE':
+ preds["f_geo"] = outputs[0]
+ preds["f_score"] = outputs[1]
+ elif self.det_algorithm == "SAST":
+ preds["f_border"] = outputs[0]
+ preds["f_score"] = outputs[1]
+ preds["f_tco"] = outputs[2]
+ preds["f_tvo"] = outputs[3]
+ elif self.det_algorithm in ["DB", "PSE", "DB++"]:
+ preds["maps"] = outputs[0]
+ elif self.det_algorithm == "FCE":
for i, output in enumerate(outputs):
- preds['level_{}'.format(i)] = output
+ preds["level_{}".format(i)] = output
elif self.det_algorithm == "CT":
- preds['maps'] = outputs[0]
- preds['score'] = outputs[1]
+ preds["maps"] = outputs[0]
+ preds["score"] = outputs[1]
else:
raise NotImplementedError
post_result = self.postprocess_op(preds, shape_list)
- dt_boxes = post_result[0]['points']
+ dt_boxes = post_result[0]["points"]
- if self.args.det_box_type == 'poly':
+ if self.args.det_box_type == "poly":
dt_boxes = self.filter_tag_det_res_only_clip(dt_boxes, ori_im.shape)
else:
dt_boxes = self.filter_tag_det_res(dt_boxes, ori_im.shape)
@@ -289,52 +284,80 @@ def __call__(self, img):
MIN_BOUND_DISTANCE = 50
dt_boxes = np.zeros((0, 4, 2), dtype=np.float32)
elapse = 0
- if img.shape[0] / img.shape[1] > 2 and img.shape[0] > self.args.det_limit_side_len:
+ if (
+ img.shape[0] / img.shape[1] > 2
+ and img.shape[0] > self.args.det_limit_side_len
+ ):
start_h = 0
end_h = 0
while end_h <= img.shape[0]:
end_h = start_h + img.shape[1] * 3 // 4
- subimg = img[start_h: end_h, :]
+ subimg = img[start_h:end_h, :]
if len(subimg) == 0:
break
sub_dt_boxes, sub_elapse = self.predict(subimg)
offset = start_h
# To prevent text blocks from being cut off, roll back a certain buffer area.
- if len(sub_dt_boxes) == 0 or img.shape[1] - max([x[-1][1] for x in sub_dt_boxes]) > MIN_BOUND_DISTANCE:
+ if (
+ len(sub_dt_boxes) == 0
+ or img.shape[1] - max([x[-1][1] for x in sub_dt_boxes])
+ > MIN_BOUND_DISTANCE
+ ):
start_h = end_h
else:
sorted_indices = np.argsort(sub_dt_boxes[:, 2, 1])
sub_dt_boxes = sub_dt_boxes[sorted_indices]
- bottom_line = 0 if len(sub_dt_boxes) <= 1 else int(np.max(sub_dt_boxes[:-1, 2, 1]))
+ bottom_line = (
+ 0
+ if len(sub_dt_boxes) <= 1
+ else int(np.max(sub_dt_boxes[:-1, 2, 1]))
+ )
if bottom_line > 0:
start_h += bottom_line
- sub_dt_boxes = sub_dt_boxes[sub_dt_boxes[:, 2, 1] <= bottom_line]
+ sub_dt_boxes = sub_dt_boxes[
+ sub_dt_boxes[:, 2, 1] <= bottom_line
+ ]
else:
start_h = end_h
if len(sub_dt_boxes) > 0:
if dt_boxes.shape[0] == 0:
- dt_boxes = sub_dt_boxes + np.array([0, offset], dtype=np.float32)
+ dt_boxes = sub_dt_boxes + np.array(
+ [0, offset], dtype=np.float32
+ )
else:
- dt_boxes = np.append(dt_boxes,
- sub_dt_boxes + np.array([0, offset], dtype=np.float32),
- axis=0)
+ dt_boxes = np.append(
+ dt_boxes,
+ sub_dt_boxes + np.array([0, offset], dtype=np.float32),
+ axis=0,
+ )
elapse += sub_elapse
- elif img.shape[1] / img.shape[0] > 3 and img.shape[1] > self.args.det_limit_side_len * 3:
+ elif (
+ img.shape[1] / img.shape[0] > 3
+ and img.shape[1] > self.args.det_limit_side_len * 3
+ ):
start_w = 0
end_w = 0
while end_w <= img.shape[1]:
end_w = start_w + img.shape[0] * 3 // 4
- subimg = img[:, start_w: end_w]
+ subimg = img[:, start_w:end_w]
if len(subimg) == 0:
break
sub_dt_boxes, sub_elapse = self.predict(subimg)
offset = start_w
- if len(sub_dt_boxes) == 0 or img.shape[0] - max([x[-1][0] for x in sub_dt_boxes]) > MIN_BOUND_DISTANCE:
+ if (
+ len(sub_dt_boxes) == 0
+ or img.shape[0] - max([x[-1][0] for x in sub_dt_boxes])
+ > MIN_BOUND_DISTANCE
+ ):
start_w = end_w
else:
sorted_indices = np.argsort(sub_dt_boxes[:, 2, 0])
sub_dt_boxes = sub_dt_boxes[sorted_indices]
- right_line = 0 if len(sub_dt_boxes) <= 1 else int(np.max(sub_dt_boxes[:-1, 1, 0]))
+ right_line = (
+ 0
+ if len(sub_dt_boxes) <= 1
+ else int(np.max(sub_dt_boxes[:-1, 1, 0]))
+ )
if right_line > 0:
start_w += right_line
sub_dt_boxes = sub_dt_boxes[sub_dt_boxes[:, 1, 0] <= right_line]
@@ -342,11 +365,15 @@ def __call__(self, img):
start_w = end_w
if len(sub_dt_boxes) > 0:
if dt_boxes.shape[0] == 0:
- dt_boxes = sub_dt_boxes + np.array([offset, 0], dtype=np.float32)
+ dt_boxes = sub_dt_boxes + np.array(
+ [offset, 0], dtype=np.float32
+ )
else:
- dt_boxes = np.append(dt_boxes,
- sub_dt_boxes + np.array([offset, 0], dtype=np.float32),
- axis=0)
+ dt_boxes = np.append(
+ dt_boxes,
+ sub_dt_boxes + np.array([offset, 0], dtype=np.float32),
+ axis=0,
+ )
elapse += sub_elapse
else:
dt_boxes, elapse = self.predict(img)
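The slicing logic above walks a very tall image in overlapping windows of height 3W/4, rolling the next start back to the lowest completed box so text crossing a cut gets re-detected. The window arithmetic in isolation, simplified to the no-detections case where start simply advances to the previous end (shapes are illustrative):

```python
img_h, img_w = 4000, 800  # 5:1 aspect ratio triggers the tall-image branch
start_h = 0
windows = []
while start_h < img_h:
    end_h = start_h + img_w * 3 // 4  # window height is 3/4 of the width
    windows.append((start_h, min(end_h, img_h)))
    start_h = end_h  # the real code may roll this back to the last finished box
print(windows)  # [(0, 600), (600, 1200), ..., (3600, 4000)]
```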
@@ -387,37 +414,49 @@ def __call__(self, img):
elapse = time.time() - st
total_time += elapse
if len(imgs) > 1:
- save_pred = os.path.basename(image_file) + '_' + str(
- index) + "\t" + str(
- json.dumps([x.tolist() for x in dt_boxes])) + "\n"
+ save_pred = (
+ os.path.basename(image_file)
+ + "_"
+ + str(index)
+ + "\t"
+ + str(json.dumps([x.tolist() for x in dt_boxes]))
+ + "\n"
+ )
else:
- save_pred = os.path.basename(image_file) + "\t" + str(
- json.dumps([x.tolist() for x in dt_boxes])) + "\n"
+ save_pred = (
+ os.path.basename(image_file)
+ + "\t"
+ + str(json.dumps([x.tolist() for x in dt_boxes]))
+ + "\n"
+ )
save_results.append(save_pred)
logger.info(save_pred)
if len(imgs) > 1:
- logger.info("{}_{} The predict time of {}: {}".format(
- idx, index, image_file, elapse))
+ logger.info(
+ "{}_{} The predict time of {}: {}".format(
+ idx, index, image_file, elapse
+ )
+ )
else:
- logger.info("{} The predict time of {}: {}".format(
- idx, image_file, elapse))
+ logger.info(
+ "{} The predict time of {}: {}".format(idx, image_file, elapse)
+ )
src_im = utility.draw_text_det_res(dt_boxes, img)
if flag_gif:
save_file = image_file[:-3] + "png"
elif flag_pdf:
- save_file = image_file.replace('.pdf',
- '_' + str(index) + '.png')
+ save_file = image_file.replace(".pdf", "_" + str(index) + ".png")
else:
save_file = image_file
img_path = os.path.join(
- draw_img_save_dir,
- "det_res_{}".format(os.path.basename(save_file)))
+ draw_img_save_dir, "det_res_{}".format(os.path.basename(save_file))
+ )
cv2.imwrite(img_path, src_im)
logger.info("The visualized image saved in {}".format(img_path))
- with open(os.path.join(draw_img_save_dir, "det_results.txt"), 'w') as f:
+ with open(os.path.join(draw_img_save_dir, "det_results.txt"), "w") as f:
f.writelines(save_results)
f.close()
if args.benchmark:
diff --git a/tools/infer/predict_e2e.py b/tools/infer/predict_e2e.py
index de315d701c..52526e3941 100755
--- a/tools/infer/predict_e2e.py
+++ b/tools/infer/predict_e2e.py
@@ -16,9 +16,9 @@
__dir__ = os.path.dirname(os.path.abspath(__file__))
sys.path.append(__dir__)
-sys.path.insert(0, os.path.abspath(os.path.join(__dir__, '../..')))
+sys.path.insert(0, os.path.abspath(os.path.join(__dir__, "../..")))
-os.environ["FLAGS_allocator_strategy"] = 'auto_growth'
+os.environ["FLAGS_allocator_strategy"] = "auto_growth"
import cv2
import numpy as np
@@ -39,31 +39,28 @@ def __init__(self, args):
self.args = args
self.e2e_algorithm = args.e2e_algorithm
self.use_onnx = args.use_onnx
- pre_process_list = [{
- 'E2EResizeForTest': {}
- }, {
- 'NormalizeImage': {
- 'std': [0.229, 0.224, 0.225],
- 'mean': [0.485, 0.456, 0.406],
- 'scale': '1./255.',
- 'order': 'hwc'
- }
- }, {
- 'ToCHWImage': None
- }, {
- 'KeepKeys': {
- 'keep_keys': ['image', 'shape']
- }
- }]
+ pre_process_list = [
+ {"E2EResizeForTest": {}},
+ {
+ "NormalizeImage": {
+ "std": [0.229, 0.224, 0.225],
+ "mean": [0.485, 0.456, 0.406],
+ "scale": "1./255.",
+ "order": "hwc",
+ }
+ },
+ {"ToCHWImage": None},
+ {"KeepKeys": {"keep_keys": ["image", "shape"]}},
+ ]
postprocess_params = {}
if self.e2e_algorithm == "PGNet":
pre_process_list[0] = {
- 'E2EResizeForTest': {
- 'max_side_len': args.e2e_limit_side_len,
- 'valid_set': 'totaltext'
+ "E2EResizeForTest": {
+ "max_side_len": args.e2e_limit_side_len,
+ "valid_set": "totaltext",
}
}
- postprocess_params['name'] = 'PGPostProcess'
+ postprocess_params["name"] = "PGPostProcess"
postprocess_params["score_thresh"] = args.e2e_pgnet_score_thresh
postprocess_params["character_dict_path"] = args.e2e_char_dict_path
postprocess_params["valid_set"] = args.e2e_pgnet_valid_set
@@ -74,8 +71,14 @@ def __init__(self, args):
self.preprocess_op = create_operators(pre_process_list)
self.postprocess_op = build_post_process(postprocess_params)
- self.predictor, self.input_tensor, self.output_tensors, _ = utility.create_predictor(
- args, 'e2e', logger) # paddle.jit.load(args.det_model_dir)
+ (
+ self.predictor,
+ self.input_tensor,
+ self.output_tensors,
+ _,
+ ) = utility.create_predictor(
+ args, "e2e", logger
+ ) # paddle.jit.load(args.det_model_dir)
# self.predictor.eval()
def clip_det_res(self, points, img_height, img_width):
@@ -94,9 +97,8 @@ def filter_tag_det_res_only_clip(self, dt_boxes, image_shape):
return dt_boxes
def __call__(self, img):
-
ori_im = img.copy()
- data = {'image': img}
+ data = {"image": img}
data = transform(data, self.preprocess_op)
img, shape_list = data
if img is None:
@@ -111,10 +113,10 @@ def __call__(self, img):
input_dict[self.input_tensor.name] = img
outputs = self.predictor.run(self.output_tensors, input_dict)
preds = {}
- preds['f_border'] = outputs[0]
- preds['f_char'] = outputs[1]
- preds['f_direction'] = outputs[2]
- preds['f_score'] = outputs[3]
+ preds["f_border"] = outputs[0]
+ preds["f_char"] = outputs[1]
+ preds["f_direction"] = outputs[2]
+ preds["f_score"] = outputs[3]
else:
self.input_tensor.copy_from_cpu(img)
self.predictor.run()
@@ -124,15 +126,15 @@ def __call__(self, img):
outputs.append(output)
preds = {}
- if self.e2e_algorithm == 'PGNet':
- preds['f_border'] = outputs[0]
- preds['f_char'] = outputs[1]
- preds['f_direction'] = outputs[2]
- preds['f_score'] = outputs[3]
+ if self.e2e_algorithm == "PGNet":
+ preds["f_border"] = outputs[0]
+ preds["f_char"] = outputs[1]
+ preds["f_direction"] = outputs[2]
+ preds["f_score"] = outputs[3]
else:
raise NotImplementedError
post_result = self.postprocess_op(preds, shape_list)
- points, strs = post_result['points'], post_result['texts']
+ points, strs = post_result["points"], post_result["texts"]
dt_boxes = self.filter_tag_det_res_only_clip(points, ori_im.shape)
elapse = time.time() - starttime
return dt_boxes, strs, elapse
@@ -161,8 +163,7 @@ def __call__(self, img):
logger.info("Predict time of {}: {}".format(image_file, elapse))
src_im = utility.draw_e2e_res(points, strs, image_file)
img_name_pure = os.path.split(image_file)[-1]
- img_path = os.path.join(draw_img_save,
- "e2e_res_{}".format(img_name_pure))
+ img_path = os.path.join(draw_img_save, "e2e_res_{}".format(img_name_pure))
cv2.imwrite(img_path, src_im)
logger.info("The visualized image saved in {}".format(img_path))
if count > 1:
diff --git a/tools/infer/predict_rec.py b/tools/infer/predict_rec.py
index b47e0c64d5..f15ab09e59 100755
--- a/tools/infer/predict_rec.py
+++ b/tools/infer/predict_rec.py
@@ -14,11 +14,12 @@
import os
import sys
from PIL import Image
+
__dir__ = os.path.dirname(os.path.abspath(__file__))
sys.path.append(__dir__)
-sys.path.insert(0, os.path.abspath(os.path.join(__dir__, '../..')))
+sys.path.insert(0, os.path.abspath(os.path.join(__dir__, "../..")))
-os.environ["FLAGS_allocator_strategy"] = 'auto_growth'
+os.environ["FLAGS_allocator_strategy"] = "auto_growth"
import cv2
import numpy as np
@@ -41,109 +42,114 @@ def __init__(self, args):
self.rec_batch_num = args.rec_batch_num
self.rec_algorithm = args.rec_algorithm
postprocess_params = {
- 'name': 'CTCLabelDecode',
+ "name": "CTCLabelDecode",
"character_dict_path": args.rec_char_dict_path,
- "use_space_char": args.use_space_char
+ "use_space_char": args.use_space_char,
}
if self.rec_algorithm == "SRN":
postprocess_params = {
- 'name': 'SRNLabelDecode',
+ "name": "SRNLabelDecode",
"character_dict_path": args.rec_char_dict_path,
- "use_space_char": args.use_space_char
+ "use_space_char": args.use_space_char,
}
elif self.rec_algorithm == "RARE":
postprocess_params = {
- 'name': 'AttnLabelDecode',
+ "name": "AttnLabelDecode",
"character_dict_path": args.rec_char_dict_path,
- "use_space_char": args.use_space_char
+ "use_space_char": args.use_space_char,
}
- elif self.rec_algorithm == 'NRTR':
+ elif self.rec_algorithm == "NRTR":
postprocess_params = {
- 'name': 'NRTRLabelDecode',
+ "name": "NRTRLabelDecode",
"character_dict_path": args.rec_char_dict_path,
- "use_space_char": args.use_space_char
+ "use_space_char": args.use_space_char,
}
elif self.rec_algorithm == "SAR":
postprocess_params = {
- 'name': 'SARLabelDecode',
+ "name": "SARLabelDecode",
"character_dict_path": args.rec_char_dict_path,
- "use_space_char": args.use_space_char
+ "use_space_char": args.use_space_char,
}
elif self.rec_algorithm == "VisionLAN":
postprocess_params = {
- 'name': 'VLLabelDecode',
+ "name": "VLLabelDecode",
"character_dict_path": args.rec_char_dict_path,
"use_space_char": args.use_space_char,
- "max_text_length": args.max_text_length
+ "max_text_length": args.max_text_length,
}
- elif self.rec_algorithm == 'ViTSTR':
+ elif self.rec_algorithm == "ViTSTR":
postprocess_params = {
- 'name': 'ViTSTRLabelDecode',
+ "name": "ViTSTRLabelDecode",
"character_dict_path": args.rec_char_dict_path,
- "use_space_char": args.use_space_char
+ "use_space_char": args.use_space_char,
}
- elif self.rec_algorithm == 'ABINet':
+ elif self.rec_algorithm == "ABINet":
postprocess_params = {
- 'name': 'ABINetLabelDecode',
+ "name": "ABINetLabelDecode",
"character_dict_path": args.rec_char_dict_path,
- "use_space_char": args.use_space_char
+ "use_space_char": args.use_space_char,
}
elif self.rec_algorithm == "SPIN":
postprocess_params = {
- 'name': 'SPINLabelDecode',
+ "name": "SPINLabelDecode",
"character_dict_path": args.rec_char_dict_path,
- "use_space_char": args.use_space_char
+ "use_space_char": args.use_space_char,
}
elif self.rec_algorithm == "RobustScanner":
postprocess_params = {
- 'name': 'SARLabelDecode',
+ "name": "SARLabelDecode",
"character_dict_path": args.rec_char_dict_path,
"use_space_char": args.use_space_char,
- "rm_symbol": True
+ "rm_symbol": True,
}
- elif self.rec_algorithm == 'RFL':
+ elif self.rec_algorithm == "RFL":
postprocess_params = {
- 'name': 'RFLLabelDecode',
+ "name": "RFLLabelDecode",
"character_dict_path": None,
- "use_space_char": args.use_space_char
+ "use_space_char": args.use_space_char,
}
elif self.rec_algorithm == "SATRN":
postprocess_params = {
- 'name': 'SATRNLabelDecode',
+ "name": "SATRNLabelDecode",
"character_dict_path": args.rec_char_dict_path,
"use_space_char": args.use_space_char,
- "rm_symbol": True
+ "rm_symbol": True,
}
elif self.rec_algorithm in ["CPPD", "CPPDPadding"]:
postprocess_params = {
- 'name': 'CPPDLabelDecode',
+ "name": "CPPDLabelDecode",
"character_dict_path": args.rec_char_dict_path,
"use_space_char": args.use_space_char,
- "rm_symbol": True
+ "rm_symbol": True,
}
elif self.rec_algorithm == "PREN":
- postprocess_params = {'name': 'PRENLabelDecode'}
+ postprocess_params = {"name": "PRENLabelDecode"}
elif self.rec_algorithm == "CAN":
self.inverse = args.rec_image_inverse
postprocess_params = {
- 'name': 'CANLabelDecode',
+ "name": "CANLabelDecode",
"character_dict_path": args.rec_char_dict_path,
- "use_space_char": args.use_space_char
+ "use_space_char": args.use_space_char,
}
elif self.rec_algorithm == "ParseQ":
postprocess_params = {
- 'name': 'ParseQLabelDecode',
+ "name": "ParseQLabelDecode",
"character_dict_path": args.rec_char_dict_path,
- "use_space_char": args.use_space_char
+ "use_space_char": args.use_space_char,
}
self.postprocess_op = build_post_process(postprocess_params)
self.postprocess_params = postprocess_params
- self.predictor, self.input_tensor, self.output_tensors, self.config = \
- utility.create_predictor(args, 'rec', logger)
+ (
+ self.predictor,
+ self.input_tensor,
+ self.output_tensors,
+ self.config,
+ ) = utility.create_predictor(args, "rec", logger)
self.benchmark = args.benchmark
self.use_onnx = args.use_onnx
if args.benchmark:
import auto_log
+
pid = os.getpid()
gpu_id = utility.get_infer_gpuid()
self.autolog = auto_log.AutoLogger(
@@ -151,41 +157,39 @@ def __init__(self, args):
model_precision=args.precision,
batch_size=args.rec_batch_num,
data_shape="dynamic",
- save_path=None, #args.save_log_path,
+ save_path=None, # args.save_log_path,
inference_config=self.config,
pids=pid,
process_name=None,
gpu_ids=gpu_id if args.use_gpu else None,
- time_keys=[
- 'preprocess_time', 'inference_time', 'postprocess_time'
- ],
+ time_keys=["preprocess_time", "inference_time", "postprocess_time"],
warmup=0,
- logger=logger)
+ logger=logger,
+ )
self.return_word_box = args.return_word_box
def resize_norm_img(self, img, max_wh_ratio):
imgC, imgH, imgW = self.rec_image_shape
- if self.rec_algorithm == 'NRTR' or self.rec_algorithm == 'ViTSTR':
+ if self.rec_algorithm == "NRTR" or self.rec_algorithm == "ViTSTR":
img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
# return padding_im
image_pil = Image.fromarray(np.uint8(img))
- if self.rec_algorithm == 'ViTSTR':
+ if self.rec_algorithm == "ViTSTR":
img = image_pil.resize([imgW, imgH], Image.BICUBIC)
else:
img = image_pil.resize([imgW, imgH], Image.Resampling.LANCZOS)
img = np.array(img)
norm_img = np.expand_dims(img, -1)
norm_img = norm_img.transpose((2, 0, 1))
- if self.rec_algorithm == 'ViTSTR':
- norm_img = norm_img.astype(np.float32) / 255.
+ if self.rec_algorithm == "ViTSTR":
+ norm_img = norm_img.astype(np.float32) / 255.0
else:
- norm_img = norm_img.astype(np.float32) / 128. - 1.
+ norm_img = norm_img.astype(np.float32) / 128.0 - 1.0
return norm_img
- elif self.rec_algorithm == 'RFL':
+ elif self.rec_algorithm == "RFL":
img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
- resized_image = cv2.resize(
- img, (imgW, imgH), interpolation=cv2.INTER_CUBIC)
- resized_image = resized_image.astype('float32')
+ resized_image = cv2.resize(img, (imgW, imgH), interpolation=cv2.INTER_CUBIC)
+ resized_image = resized_image.astype("float32")
resized_image = resized_image / 255
resized_image = resized_image[np.newaxis, :]
resized_image -= 0.5
@@ -206,12 +210,12 @@ def resize_norm_img(self, img, max_wh_ratio):
resized_w = imgW
else:
resized_w = int(math.ceil(imgH * ratio))
- if self.rec_algorithm == 'RARE':
+ if self.rec_algorithm == "RARE":
if resized_w > self.rec_image_shape[2]:
resized_w = self.rec_image_shape[2]
imgW = self.rec_image_shape[2]
resized_image = cv2.resize(img, (resized_w, imgH))
- resized_image = resized_image.astype('float32')
+ resized_image = resized_image.astype("float32")
resized_image = resized_image.transpose((2, 0, 1)) / 255
resized_image -= 0.5
resized_image /= 0.5
@@ -220,12 +224,10 @@ def resize_norm_img(self, img, max_wh_ratio):
return padding_im
def resize_norm_img_vl(self, img, image_shape):
-
imgC, imgH, imgW = image_shape
img = img[:, :, ::-1] # bgr2rgb
- resized_image = cv2.resize(
- img, (imgW, imgH), interpolation=cv2.INTER_LINEAR)
- resized_image = resized_image.astype('float32')
+ resized_image = cv2.resize(img, (imgW, imgH), interpolation=cv2.INTER_LINEAR)
+ resized_image = resized_image.astype("float32")
resized_image = resized_image.transpose((2, 0, 1)) / 255
return resized_image
@@ -247,7 +249,7 @@ def resize_norm_img_srn(self, img, image_shape):
img_np = np.asarray(img_new)
img_np = cv2.cvtColor(img_np, cv2.COLOR_BGR2GRAY)
- img_black[:, 0:img_np.shape[1]] = img_np
+ img_black[:, 0 : img_np.shape[1]] = img_np
img_black = img_black[:, :, np.newaxis]
row, col, c = img_black.shape
@@ -256,53 +258,68 @@ def resize_norm_img_srn(self, img, image_shape):
return np.reshape(img_black, (c, row, col)).astype(np.float32)
def srn_other_inputs(self, image_shape, num_heads, max_text_length):
-
imgC, imgH, imgW = image_shape
feature_dim = int((imgH / 8) * (imgW / 8))
- encoder_word_pos = np.array(range(0, feature_dim)).reshape(
- (feature_dim, 1)).astype('int64')
- gsrm_word_pos = np.array(range(0, max_text_length)).reshape(
- (max_text_length, 1)).astype('int64')
+ encoder_word_pos = (
+ np.array(range(0, feature_dim)).reshape((feature_dim, 1)).astype("int64")
+ )
+ gsrm_word_pos = (
+ np.array(range(0, max_text_length))
+ .reshape((max_text_length, 1))
+ .astype("int64")
+ )
gsrm_attn_bias_data = np.ones((1, max_text_length, max_text_length))
gsrm_slf_attn_bias1 = np.triu(gsrm_attn_bias_data, 1).reshape(
- [-1, 1, max_text_length, max_text_length])
- gsrm_slf_attn_bias1 = np.tile(
- gsrm_slf_attn_bias1,
- [1, num_heads, 1, 1]).astype('float32') * [-1e9]
+ [-1, 1, max_text_length, max_text_length]
+ )
+ gsrm_slf_attn_bias1 = np.tile(gsrm_slf_attn_bias1, [1, num_heads, 1, 1]).astype(
+ "float32"
+ ) * [-1e9]
gsrm_slf_attn_bias2 = np.tril(gsrm_attn_bias_data, -1).reshape(
- [-1, 1, max_text_length, max_text_length])
- gsrm_slf_attn_bias2 = np.tile(
- gsrm_slf_attn_bias2,
- [1, num_heads, 1, 1]).astype('float32') * [-1e9]
+ [-1, 1, max_text_length, max_text_length]
+ )
+ gsrm_slf_attn_bias2 = np.tile(gsrm_slf_attn_bias2, [1, num_heads, 1, 1]).astype(
+ "float32"
+ ) * [-1e9]
encoder_word_pos = encoder_word_pos[np.newaxis, :]
gsrm_word_pos = gsrm_word_pos[np.newaxis, :]
return [
- encoder_word_pos, gsrm_word_pos, gsrm_slf_attn_bias1,
- gsrm_slf_attn_bias2
+ encoder_word_pos,
+ gsrm_word_pos,
+ gsrm_slf_attn_bias1,
+ gsrm_slf_attn_bias2,
]
def process_image_srn(self, img, image_shape, num_heads, max_text_length):
norm_img = self.resize_norm_img_srn(img, image_shape)
norm_img = norm_img[np.newaxis, :]
- [encoder_word_pos, gsrm_word_pos, gsrm_slf_attn_bias1, gsrm_slf_attn_bias2] = \
- self.srn_other_inputs(image_shape, num_heads, max_text_length)
+ [
+ encoder_word_pos,
+ gsrm_word_pos,
+ gsrm_slf_attn_bias1,
+ gsrm_slf_attn_bias2,
+ ] = self.srn_other_inputs(image_shape, num_heads, max_text_length)
gsrm_slf_attn_bias1 = gsrm_slf_attn_bias1.astype(np.float32)
gsrm_slf_attn_bias2 = gsrm_slf_attn_bias2.astype(np.float32)
encoder_word_pos = encoder_word_pos.astype(np.int64)
gsrm_word_pos = gsrm_word_pos.astype(np.int64)
- return (norm_img, encoder_word_pos, gsrm_word_pos, gsrm_slf_attn_bias1,
- gsrm_slf_attn_bias2)
+ return (
+ norm_img,
+ encoder_word_pos,
+ gsrm_word_pos,
+ gsrm_slf_attn_bias1,
+ gsrm_slf_attn_bias2,
+ )
- def resize_norm_img_sar(self, img, image_shape,
- width_downsample_ratio=0.25):
+ def resize_norm_img_sar(self, img, image_shape, width_downsample_ratio=0.25):
imgC, imgH, imgW_min, imgW_max = image_shape
h = img.shape[0]
w = img.shape[1]
@@ -320,8 +337,8 @@ def resize_norm_img_sar(self, img, image_shape,
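+        # valid_ratio records what fraction of the (possibly padded) width holds real content; it is passed to SAR/RobustScanner as an extra input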
valid_ratio = min(1.0, 1.0 * resize_w / imgW_max)
resize_w = min(imgW_max, resize_w)
resized_image = cv2.resize(img, (resize_w, imgH))
- resized_image = resized_image.astype('float32')
- # norm
+ resized_image = resized_image.astype("float32")
+ # norm
if image_shape[0] == 1:
resized_image = resized_image / 255
resized_image = resized_image[np.newaxis, :]
@@ -354,27 +371,22 @@ def resize_norm_img_spin(self, img):
return img
def resize_norm_img_svtr(self, img, image_shape):
-
imgC, imgH, imgW = image_shape
- resized_image = cv2.resize(
- img, (imgW, imgH), interpolation=cv2.INTER_LINEAR)
- resized_image = resized_image.astype('float32')
+ resized_image = cv2.resize(img, (imgW, imgH), interpolation=cv2.INTER_LINEAR)
+ resized_image = resized_image.astype("float32")
resized_image = resized_image.transpose((2, 0, 1)) / 255
resized_image -= 0.5
resized_image /= 0.5
return resized_image
- def resize_norm_img_cppd_padding(self,
- img,
- image_shape,
- padding=True,
- interpolation=cv2.INTER_LINEAR):
+ def resize_norm_img_cppd_padding(
+ self, img, image_shape, padding=True, interpolation=cv2.INTER_LINEAR
+ ):
imgC, imgH, imgW = image_shape
h = img.shape[0]
w = img.shape[1]
if not padding:
- resized_image = cv2.resize(
- img, (imgW, imgH), interpolation=interpolation)
+ resized_image = cv2.resize(img, (imgW, imgH), interpolation=interpolation)
resized_w = imgW
else:
ratio = w / float(h)
@@ -383,7 +395,7 @@ def resize_norm_img_cppd_padding(self,
else:
resized_w = int(math.ceil(imgH * ratio))
resized_image = cv2.resize(img, (resized_w, imgH))
- resized_image = resized_image.astype('float32')
+ resized_image = resized_image.astype("float32")
if image_shape[0] == 1:
resized_image = resized_image / 255
resized_image = resized_image[np.newaxis, :]
@@ -397,27 +409,22 @@ def resize_norm_img_cppd_padding(self,
return padding_im
def resize_norm_img_abinet(self, img, image_shape):
-
imgC, imgH, imgW = image_shape
- resized_image = cv2.resize(
- img, (imgW, imgH), interpolation=cv2.INTER_LINEAR)
- resized_image = resized_image.astype('float32')
- resized_image = resized_image / 255.
+ resized_image = cv2.resize(img, (imgW, imgH), interpolation=cv2.INTER_LINEAR)
+ resized_image = resized_image.astype("float32")
+ resized_image = resized_image / 255.0
mean = np.array([0.485, 0.456, 0.406])
std = np.array([0.229, 0.224, 0.225])
- resized_image = (
- resized_image - mean[None, None, ...]) / std[None, None, ...]
+ resized_image = (resized_image - mean[None, None, ...]) / std[None, None, ...]
resized_image = resized_image.transpose((2, 0, 1))
- resized_image = resized_image.astype('float32')
+ resized_image = resized_image.astype("float32")
return resized_image
def norm_img_can(self, img, image_shape):
-
- img = cv2.cvtColor(
- img, cv2.COLOR_BGR2GRAY) # CAN only predict gray scale image
+        img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)  # CAN only predicts grayscale images
if self.inverse:
img = 255 - img
@@ -428,13 +435,16 @@ def norm_img_can(self, img, image_shape):
if h < imgH or w < imgW:
padding_h = max(imgH - h, 0)
padding_w = max(imgW - w, 0)
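+            # Pad with white (255) so smaller formula images reach the fixed input size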
- img_padded = np.pad(img, ((0, padding_h), (0, padding_w)),
- 'constant',
- constant_values=(255))
+ img_padded = np.pad(
+ img,
+ ((0, padding_h), (0, padding_w)),
+ "constant",
+ constant_values=(255),
+ )
img = img_padded
img = np.expand_dims(img, 0) / 255.0 # h,w,c -> c,h,w
- img = img.astype('float32')
+ img = img.astype("float32")
return img
@@ -446,7 +456,7 @@ def __call__(self, img_list):
width_list.append(img.shape[1] / float(img.shape[0]))
# Sorting can speed up the recognition process
indices = np.argsort(np.array(width_list))
- rec_res = [['', 0.0]] * img_num
+ rec_res = [["", 0.0]] * img_num
batch_num = self.rec_batch_num
st = time.time()
if self.benchmark:
@@ -472,71 +482,78 @@ def __call__(self, img_list):
for ino in range(beg_img_no, end_img_no):
if self.rec_algorithm == "SAR":
norm_img, _, _, valid_ratio = self.resize_norm_img_sar(
- img_list[indices[ino]], self.rec_image_shape)
+ img_list[indices[ino]], self.rec_image_shape
+ )
norm_img = norm_img[np.newaxis, :]
valid_ratio = np.expand_dims(valid_ratio, axis=0)
valid_ratios.append(valid_ratio)
norm_img_batch.append(norm_img)
elif self.rec_algorithm == "SRN":
norm_img = self.process_image_srn(
- img_list[indices[ino]], self.rec_image_shape, 8, 25)
+ img_list[indices[ino]], self.rec_image_shape, 8, 25
+ )
encoder_word_pos_list.append(norm_img[1])
gsrm_word_pos_list.append(norm_img[2])
gsrm_slf_attn_bias1_list.append(norm_img[3])
gsrm_slf_attn_bias2_list.append(norm_img[4])
norm_img_batch.append(norm_img[0])
elif self.rec_algorithm in ["SVTR", "SATRN", "ParseQ", "CPPD"]:
- norm_img = self.resize_norm_img_svtr(img_list[indices[ino]],
- self.rec_image_shape)
+ norm_img = self.resize_norm_img_svtr(
+ img_list[indices[ino]], self.rec_image_shape
+ )
norm_img = norm_img[np.newaxis, :]
norm_img_batch.append(norm_img)
elif self.rec_algorithm in ["CPPDPadding"]:
norm_img = self.resize_norm_img_cppd_padding(
- img_list[indices[ino]], self.rec_image_shape)
+ img_list[indices[ino]], self.rec_image_shape
+ )
norm_img = norm_img[np.newaxis, :]
norm_img_batch.append(norm_img)
elif self.rec_algorithm in ["VisionLAN", "PREN"]:
- norm_img = self.resize_norm_img_vl(img_list[indices[ino]],
- self.rec_image_shape)
+ norm_img = self.resize_norm_img_vl(
+ img_list[indices[ino]], self.rec_image_shape
+ )
norm_img = norm_img[np.newaxis, :]
norm_img_batch.append(norm_img)
- elif self.rec_algorithm == 'SPIN':
+ elif self.rec_algorithm == "SPIN":
norm_img = self.resize_norm_img_spin(img_list[indices[ino]])
norm_img = norm_img[np.newaxis, :]
norm_img_batch.append(norm_img)
elif self.rec_algorithm == "ABINet":
norm_img = self.resize_norm_img_abinet(
- img_list[indices[ino]], self.rec_image_shape)
+ img_list[indices[ino]], self.rec_image_shape
+ )
norm_img = norm_img[np.newaxis, :]
norm_img_batch.append(norm_img)
elif self.rec_algorithm == "RobustScanner":
norm_img, _, _, valid_ratio = self.resize_norm_img_sar(
img_list[indices[ino]],
self.rec_image_shape,
- width_downsample_ratio=0.25)
+ width_downsample_ratio=0.25,
+ )
norm_img = norm_img[np.newaxis, :]
valid_ratio = np.expand_dims(valid_ratio, axis=0)
valid_ratios = []
valid_ratios.append(valid_ratio)
norm_img_batch.append(norm_img)
word_positions_list = []
- word_positions = np.array(range(0, 40)).astype('int64')
+ word_positions = np.array(range(0, 40)).astype("int64")
word_positions = np.expand_dims(word_positions, axis=0)
word_positions_list.append(word_positions)
elif self.rec_algorithm == "CAN":
- norm_img = self.norm_img_can(img_list[indices[ino]],
- max_wh_ratio)
+ norm_img = self.norm_img_can(img_list[indices[ino]], max_wh_ratio)
norm_img = norm_img[np.newaxis, :]
norm_img_batch.append(norm_img)
- norm_image_mask = np.ones(norm_img.shape, dtype='float32')
- word_label = np.ones([1, 36], dtype='int64')
+ norm_image_mask = np.ones(norm_img.shape, dtype="float32")
+ word_label = np.ones([1, 36], dtype="int64")
norm_img_mask_batch = []
word_label_list = []
norm_img_mask_batch.append(norm_image_mask)
word_label_list.append(word_label)
else:
- norm_img = self.resize_norm_img(img_list[indices[ino]],
- max_wh_ratio)
+ norm_img = self.resize_norm_img(
+ img_list[indices[ino]], max_wh_ratio
+ )
norm_img = norm_img[np.newaxis, :]
norm_img_batch.append(norm_img)
norm_img_batch = np.concatenate(norm_img_batch)
@@ -547,10 +564,8 @@ def __call__(self, img_list):
if self.rec_algorithm == "SRN":
encoder_word_pos_list = np.concatenate(encoder_word_pos_list)
gsrm_word_pos_list = np.concatenate(gsrm_word_pos_list)
- gsrm_slf_attn_bias1_list = np.concatenate(
- gsrm_slf_attn_bias1_list)
- gsrm_slf_attn_bias2_list = np.concatenate(
- gsrm_slf_attn_bias2_list)
+ gsrm_slf_attn_bias1_list = np.concatenate(gsrm_slf_attn_bias1_list)
+ gsrm_slf_attn_bias2_list = np.concatenate(gsrm_slf_attn_bias2_list)
inputs = [
norm_img_batch,
@@ -562,14 +577,12 @@ def __call__(self, img_list):
if self.use_onnx:
input_dict = {}
input_dict[self.input_tensor.name] = norm_img_batch
- outputs = self.predictor.run(self.output_tensors,
- input_dict)
+ outputs = self.predictor.run(self.output_tensors, input_dict)
preds = {"predict": outputs[2]}
else:
input_names = self.predictor.get_input_names()
for i in range(len(input_names)):
- input_tensor = self.predictor.get_input_handle(
- input_names[i])
+ input_tensor = self.predictor.get_input_handle(input_names[i])
input_tensor.copy_from_cpu(inputs[i])
self.predictor.run()
outputs = []
@@ -583,20 +596,17 @@ def __call__(self, img_list):
valid_ratios = np.concatenate(valid_ratios)
inputs = [
norm_img_batch,
- np.array(
- [valid_ratios], dtype=np.float32).T,
+ np.array([valid_ratios], dtype=np.float32).T,
]
if self.use_onnx:
input_dict = {}
input_dict[self.input_tensor.name] = norm_img_batch
- outputs = self.predictor.run(self.output_tensors,
- input_dict)
+ outputs = self.predictor.run(self.output_tensors, input_dict)
preds = outputs[0]
else:
input_names = self.predictor.get_input_names()
for i in range(len(input_names)):
- input_tensor = self.predictor.get_input_handle(
- input_names[i])
+ input_tensor = self.predictor.get_input_handle(input_names[i])
input_tensor.copy_from_cpu(inputs[i])
self.predictor.run()
outputs = []
@@ -614,14 +624,12 @@ def __call__(self, img_list):
if self.use_onnx:
input_dict = {}
input_dict[self.input_tensor.name] = norm_img_batch
- outputs = self.predictor.run(self.output_tensors,
- input_dict)
+ outputs = self.predictor.run(self.output_tensors, input_dict)
preds = outputs[0]
else:
input_names = self.predictor.get_input_names()
for i in range(len(input_names)):
- input_tensor = self.predictor.get_input_handle(
- input_names[i])
+ input_tensor = self.predictor.get_input_handle(input_names[i])
input_tensor.copy_from_cpu(inputs[i])
self.predictor.run()
outputs = []
@@ -638,15 +646,13 @@ def __call__(self, img_list):
if self.use_onnx:
input_dict = {}
input_dict[self.input_tensor.name] = norm_img_batch
- outputs = self.predictor.run(self.output_tensors,
- input_dict)
+ outputs = self.predictor.run(self.output_tensors, input_dict)
preds = outputs
else:
input_names = self.predictor.get_input_names()
input_tensor = []
for i in range(len(input_names)):
- input_tensor_i = self.predictor.get_input_handle(
- input_names[i])
+ input_tensor_i = self.predictor.get_input_handle(input_names[i])
input_tensor_i.copy_from_cpu(inputs[i])
input_tensor.append(input_tensor_i)
self.input_tensor = input_tensor
@@ -662,8 +668,7 @@ def __call__(self, img_list):
if self.use_onnx:
input_dict = {}
input_dict[self.input_tensor.name] = norm_img_batch
- outputs = self.predictor.run(self.output_tensors,
- input_dict)
+ outputs = self.predictor.run(self.output_tensors, input_dict)
preds = outputs[0]
else:
self.input_tensor.copy_from_cpu(norm_img_batch)
@@ -678,12 +683,13 @@ def __call__(self, img_list):
preds = outputs
else:
preds = outputs[0]
- if self.postprocess_params['name'] == 'CTCLabelDecode':
+ if self.postprocess_params["name"] == "CTCLabelDecode":
rec_result = self.postprocess_op(
preds,
return_word_box=self.return_word_box,
wh_ratio_list=wh_ratio_list,
- max_wh_ratio=max_wh_ratio)
+ max_wh_ratio=max_wh_ratio,
+ )
else:
rec_result = self.postprocess_op(preds)
for rno in range(len(rec_result)):
@@ -726,8 +732,9 @@ def main(args):
logger.info(E)
exit()
for ino in range(len(img_list)):
- logger.info("Predicts of {}:{}".format(valid_image_file_list[ino],
- rec_res[ino]))
+ logger.info(
+            "Prediction of {}: {}".format(valid_image_file_list[ino], rec_res[ino])
+ )
if args.benchmark:
text_recognizer.autolog.report()
diff --git a/tools/infer/predict_sr.py b/tools/infer/predict_sr.py
index ca99f6819f..101c7755d4 100755
--- a/tools/infer/predict_sr.py
+++ b/tools/infer/predict_sr.py
@@ -14,11 +14,12 @@
import os
import sys
from PIL import Image
+
__dir__ = os.path.dirname(os.path.abspath(__file__))
sys.path.insert(0, __dir__)
-sys.path.insert(0, os.path.abspath(os.path.join(__dir__, '../..')))
+sys.path.insert(0, os.path.abspath(os.path.join(__dir__, "../..")))
-os.environ["FLAGS_allocator_strategy"] = 'auto_growth'
+os.environ["FLAGS_allocator_strategy"] = "auto_growth"
import cv2
import numpy as np
@@ -40,11 +41,16 @@ def __init__(self, args):
self.sr_image_shape = [int(v) for v in args.sr_image_shape.split(",")]
self.sr_batch_num = args.sr_batch_num
- self.predictor, self.input_tensor, self.output_tensors, self.config = \
- utility.create_predictor(args, 'sr', logger)
+ (
+ self.predictor,
+ self.input_tensor,
+ self.output_tensors,
+ self.config,
+ ) = utility.create_predictor(args, "sr", logger)
self.benchmark = args.benchmark
if args.benchmark:
import auto_log
+
pid = os.getpid()
gpu_id = utility.get_infer_gpuid()
self.autolog = auto_log.AutoLogger(
@@ -52,16 +58,15 @@ def __init__(self, args):
model_precision=args.precision,
batch_size=args.sr_batch_num,
data_shape="dynamic",
- save_path=None, #args.save_log_path,
+ save_path=None, # args.save_log_path,
inference_config=self.config,
pids=pid,
process_name=None,
gpu_ids=gpu_id if args.use_gpu else None,
- time_keys=[
- 'preprocess_time', 'inference_time', 'postprocess_time'
- ],
+ time_keys=["preprocess_time", "inference_time", "postprocess_time"],
warmup=0,
- logger=logger)
+ logger=logger,
+ )
def resize_norm_img(self, img):
imgC, imgH, imgW = self.sr_image_shape
@@ -133,15 +138,20 @@ def main(args):
for beg_no in range(len(preds)):
sr_img = preds[beg_no][1]
lr_img = preds[beg_no][0]
- for i in (range(sr_img.shape[0])):
+ for i in range(sr_img.shape[0]):
fm_sr = (sr_img[i] * 255).transpose(1, 2, 0).astype(np.uint8)
fm_lr = (lr_img[i] * 255).transpose(1, 2, 0).astype(np.uint8)
- img_name_pure = os.path.split(valid_image_file_list[
- beg_no * args.sr_batch_num + i])[-1]
- cv2.imwrite("infer_result/sr_{}".format(img_name_pure),
- fm_sr[:, :, ::-1])
- logger.info("The visualized image saved in infer_result/sr_{}".
- format(img_name_pure))
+ img_name_pure = os.path.split(
+ valid_image_file_list[beg_no * args.sr_batch_num + i]
+ )[-1]
+ cv2.imwrite(
+ "infer_result/sr_{}".format(img_name_pure), fm_sr[:, :, ::-1]
+ )
+ logger.info(
+                    "The visualized image is saved in infer_result/sr_{}".format(
+ img_name_pure
+ )
+ )
except Exception as E:
logger.info(traceback.format_exc())
diff --git a/tools/infer/predict_system.py b/tools/infer/predict_system.py
index 8af45b4cf5..95b199b2a0 100755
--- a/tools/infer/predict_system.py
+++ b/tools/infer/predict_system.py
@@ -17,9 +17,9 @@
__dir__ = os.path.dirname(os.path.abspath(__file__))
sys.path.append(__dir__)
-sys.path.insert(0, os.path.abspath(os.path.join(__dir__, '../..')))
+sys.path.insert(0, os.path.abspath(os.path.join(__dir__, "../..")))
-os.environ["FLAGS_allocator_strategy"] = 'auto_growth'
+os.environ["FLAGS_allocator_strategy"] = "auto_growth"
import cv2
import copy
@@ -34,7 +34,12 @@
import tools.infer.predict_cls as predict_cls
from ppocr.utils.utility import get_image_file_list, check_and_read
from ppocr.utils.logging import get_logger
-from tools.infer.utility import draw_ocr_box_txt, get_rotate_crop_image, get_minarea_rect_crop
+from tools.infer.utility import (
+ draw_ocr_box_txt,
+ get_rotate_crop_image,
+ get_minarea_rect_crop,
+)
+
logger = get_logger()
@@ -58,14 +63,16 @@ def draw_crop_rec_res(self, output_dir, img_crop_list, rec_res):
bbox_num = len(img_crop_list)
for bno in range(bbox_num):
cv2.imwrite(
- os.path.join(output_dir,
- f"mg_crop_{bno+self.crop_image_res_index}.jpg"),
- img_crop_list[bno])
+ os.path.join(
+ output_dir, f"mg_crop_{bno+self.crop_image_res_index}.jpg"
+ ),
+ img_crop_list[bno],
+ )
logger.debug(f"{bno}, {rec_res[bno]}")
self.crop_image_res_index += bbox_num
def __call__(self, img, cls=True):
- time_dict = {'det': 0, 'rec': 0, 'cls': 0, 'all': 0}
+ time_dict = {"det": 0, "rec": 0, "cls": 0, "all": 0}
if img is None:
logger.debug("no valid image provided")
@@ -74,16 +81,17 @@ def __call__(self, img, cls=True):
start = time.time()
ori_im = img.copy()
dt_boxes, elapse = self.text_detector(img)
- time_dict['det'] = elapse
+ time_dict["det"] = elapse
if dt_boxes is None:
logger.debug("no dt_boxes found, elapsed : {}".format(elapse))
end = time.time()
- time_dict['all'] = end - start
+ time_dict["all"] = end - start
return None, None, time_dict
else:
- logger.debug("dt_boxes num : {}, elapsed : {}".format(
- len(dt_boxes), elapse))
+ logger.debug(
+ "dt_boxes num : {}, elapsed : {}".format(len(dt_boxes), elapse)
+ )
img_crop_list = []
dt_boxes = sorted_boxes(dt_boxes)
@@ -96,19 +104,17 @@ def __call__(self, img, cls=True):
img_crop = get_minarea_rect_crop(ori_im, tmp_box)
img_crop_list.append(img_crop)
if self.use_angle_cls and cls:
- img_crop_list, angle_list, elapse = self.text_classifier(
- img_crop_list)
- time_dict['cls'] = elapse
- logger.debug("cls num : {}, elapsed : {}".format(
- len(img_crop_list), elapse))
+ img_crop_list, angle_list, elapse = self.text_classifier(img_crop_list)
+ time_dict["cls"] = elapse
+ logger.debug(
+ "cls num : {}, elapsed : {}".format(len(img_crop_list), elapse)
+ )
rec_res, elapse = self.text_recognizer(img_crop_list)
- time_dict['rec'] = elapse
- logger.debug("rec_res num : {}, elapsed : {}".format(
- len(rec_res), elapse))
+ time_dict["rec"] = elapse
+ logger.debug("rec_res num : {}, elapsed : {}".format(len(rec_res), elapse))
if self.args.save_crop_res:
- self.draw_crop_rec_res(self.args.crop_res_save_dir, img_crop_list,
- rec_res)
+ self.draw_crop_rec_res(self.args.crop_res_save_dir, img_crop_list, rec_res)
filter_boxes, filter_rec_res = [], []
for box, rec_result in zip(dt_boxes, rec_res):
text, score = rec_result[0], rec_result[1]
@@ -116,7 +122,7 @@ def __call__(self, img, cls=True):
filter_boxes.append(box)
filter_rec_res.append(rec_result)
end = time.time()
- time_dict['all'] = end - start
+ time_dict["all"] = end - start
return filter_boxes, filter_rec_res, time_dict
@@ -134,8 +140,9 @@ def sorted_boxes(dt_boxes):
for i in range(num_boxes - 1):
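+        # Swap neighbours whose tops differ by < 10 px so boxes on one text line read left to right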
for j in range(i, -1, -1):
- if abs(_boxes[j + 1][0][1] - _boxes[j][0][1]) < 10 and \
- (_boxes[j + 1][0][0] < _boxes[j][0][0]):
+ if abs(_boxes[j + 1][0][1] - _boxes[j][0][1]) < 10 and (
+ _boxes[j + 1][0][0] < _boxes[j][0][0]
+ ):
tmp = _boxes[j]
_boxes[j] = _boxes[j + 1]
_boxes[j + 1] = tmp
@@ -146,7 +153,7 @@ def sorted_boxes(dt_boxes):
def main(args):
image_file_list = get_image_file_list(args.image_dir)
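+    # Shard the file list across workers: each process takes every total_process_num-th image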
- image_file_list = image_file_list[args.process_id::args.total_process_num]
+ image_file_list = image_file_list[args.process_id :: args.total_process_num]
text_sys = TextSystem(args)
is_visualize = True
font_path = args.vis_font_path
@@ -171,7 +178,6 @@ def main(args):
_st = time.time()
count = 0
for idx, image_file in enumerate(image_file_list):
-
img, flag_gif, flag_pdf = check_and_read(image_file)
if not flag_gif and not flag_pdf:
img = cv2.imread(image_file)
@@ -192,26 +198,41 @@ def main(args):
total_time += elapse
if len(imgs) > 1:
logger.debug(
- str(idx) + '_' + str(index) + " Predict time of %s: %.3fs"
- % (image_file, elapse))
+ str(idx)
+ + "_"
+ + str(index)
+ + " Predict time of %s: %.3fs" % (image_file, elapse)
+ )
else:
logger.debug(
- str(idx) + " Predict time of %s: %.3fs" % (image_file,
- elapse))
+ str(idx) + " Predict time of %s: %.3fs" % (image_file, elapse)
+ )
for text, score in rec_res:
logger.debug("{}, {:.3f}".format(text, score))
- res = [{
- "transcription": rec_res[i][0],
- "points": np.array(dt_boxes[i]).astype(np.int32).tolist(),
- } for i in range(len(dt_boxes))]
+ res = [
+ {
+ "transcription": rec_res[i][0],
+ "points": np.array(dt_boxes[i]).astype(np.int32).tolist(),
+ }
+ for i in range(len(dt_boxes))
+ ]
if len(imgs) > 1:
- save_pred = os.path.basename(image_file) + '_' + str(
- index) + "\t" + json.dumps(
- res, ensure_ascii=False) + "\n"
+ save_pred = (
+ os.path.basename(image_file)
+ + "_"
+ + str(index)
+ + "\t"
+ + json.dumps(res, ensure_ascii=False)
+ + "\n"
+ )
else:
- save_pred = os.path.basename(image_file) + "\t" + json.dumps(
- res, ensure_ascii=False) + "\n"
+ save_pred = (
+ os.path.basename(image_file)
+ + "\t"
+ + json.dumps(res, ensure_ascii=False)
+ + "\n"
+ )
save_results.append(save_pred)
if is_visualize:
@@ -226,21 +247,23 @@ def main(args):
txts,
scores,
drop_score=drop_score,
- font_path=font_path)
+ font_path=font_path,
+ )
if flag_gif:
save_file = image_file[:-3] + "png"
elif flag_pdf:
- save_file = image_file.replace('.pdf',
- '_' + str(index) + '.png')
+ save_file = image_file.replace(".pdf", "_" + str(index) + ".png")
else:
save_file = image_file
cv2.imwrite(
- os.path.join(draw_img_save_dir,
- os.path.basename(save_file)),
- draw_img[:, :, ::-1])
- logger.debug("The visualized image saved in {}".format(
- os.path.join(draw_img_save_dir, os.path.basename(
- save_file))))
+ os.path.join(draw_img_save_dir, os.path.basename(save_file)),
+ draw_img[:, :, ::-1],
+ )
+ logger.debug(
+                "The visualized image is saved in {}".format(
+ os.path.join(draw_img_save_dir, os.path.basename(save_file))
+ )
+ )
logger.info("The predict total time is {}".format(time.time() - _st))
if args.benchmark:
@@ -248,9 +271,8 @@ def main(args):
text_sys.text_recognizer.autolog.report()
with open(
- os.path.join(draw_img_save_dir, "system_results.txt"),
- 'w',
- encoding='utf-8') as f:
+ os.path.join(draw_img_save_dir, "system_results.txt"), "w", encoding="utf-8"
+ ) as f:
f.writelines(save_results)
@@ -260,10 +282,11 @@ def main(args):
p_list = []
total_process_num = args.total_process_num
for process_id in range(total_process_num):
- cmd = [sys.executable, "-u"] + sys.argv + [
- "--process_id={}".format(process_id),
- "--use_mp={}".format(False)
- ]
+ cmd = (
+ [sys.executable, "-u"]
+ + sys.argv
+ + ["--process_id={}".format(process_id), "--use_mp={}".format(False)]
+ )
p = subprocess.Popen(cmd, stdout=sys.stdout, stderr=sys.stdout)
p_list.append(p)
for p in p_list:
diff --git a/tools/infer/utility.py b/tools/infer/utility.py
index 34465f7957..61f4ffacc9 100644
--- a/tools/infer/utility.py
+++ b/tools/infer/utility.py
@@ -53,11 +53,11 @@ def init_args():
# params for text detector
parser.add_argument("--image_dir", type=str)
parser.add_argument("--page_num", type=int, default=0)
- parser.add_argument("--det_algorithm", type=str, default='DB')
+ parser.add_argument("--det_algorithm", type=str, default="DB")
parser.add_argument("--det_model_dir", type=str)
parser.add_argument("--det_limit_side_len", type=float, default=960)
- parser.add_argument("--det_limit_type", type=str, default='max')
- parser.add_argument("--det_box_type", type=str, default='quad')
+ parser.add_argument("--det_limit_type", type=str, default="max")
+ parser.add_argument("--det_box_type", type=str, default="quad")
     # DB params
parser.add_argument("--det_db_thresh", type=float, default=0.3)
@@ -89,39 +89,38 @@ def init_args():
parser.add_argument("--fourier_degree", type=int, default=5)
# params for text recognizer
- parser.add_argument("--rec_algorithm", type=str, default='SVTR_LCNet')
+ parser.add_argument("--rec_algorithm", type=str, default="SVTR_LCNet")
parser.add_argument("--rec_model_dir", type=str)
parser.add_argument("--rec_image_inverse", type=str2bool, default=True)
parser.add_argument("--rec_image_shape", type=str, default="3, 48, 320")
parser.add_argument("--rec_batch_num", type=int, default=6)
parser.add_argument("--max_text_length", type=int, default=25)
parser.add_argument(
- "--rec_char_dict_path",
- type=str,
- default="./ppocr/utils/ppocr_keys_v1.txt")
+ "--rec_char_dict_path", type=str, default="./ppocr/utils/ppocr_keys_v1.txt"
+ )
parser.add_argument("--use_space_char", type=str2bool, default=True)
- parser.add_argument(
- "--vis_font_path", type=str, default="./doc/fonts/simfang.ttf")
+ parser.add_argument("--vis_font_path", type=str, default="./doc/fonts/simfang.ttf")
parser.add_argument("--drop_score", type=float, default=0.5)
# params for e2e
- parser.add_argument("--e2e_algorithm", type=str, default='PGNet')
+ parser.add_argument("--e2e_algorithm", type=str, default="PGNet")
parser.add_argument("--e2e_model_dir", type=str)
parser.add_argument("--e2e_limit_side_len", type=float, default=768)
- parser.add_argument("--e2e_limit_type", type=str, default='max')
+ parser.add_argument("--e2e_limit_type", type=str, default="max")
     # PGNet params
parser.add_argument("--e2e_pgnet_score_thresh", type=float, default=0.5)
parser.add_argument(
- "--e2e_char_dict_path", type=str, default="./ppocr/utils/ic15_dict.txt")
- parser.add_argument("--e2e_pgnet_valid_set", type=str, default='totaltext')
- parser.add_argument("--e2e_pgnet_mode", type=str, default='fast')
+ "--e2e_char_dict_path", type=str, default="./ppocr/utils/ic15_dict.txt"
+ )
+ parser.add_argument("--e2e_pgnet_valid_set", type=str, default="totaltext")
+ parser.add_argument("--e2e_pgnet_mode", type=str, default="fast")
# params for text classifier
parser.add_argument("--use_angle_cls", type=str2bool, default=False)
parser.add_argument("--cls_model_dir", type=str)
parser.add_argument("--cls_image_shape", type=str, default="3, 48, 192")
- parser.add_argument("--label_list", type=list, default=['0', '180'])
+ parser.add_argument("--label_list", type=list, default=["0", "180"])
parser.add_argument("--cls_batch_num", type=int, default=6)
parser.add_argument("--cls_thresh", type=float, default=0.9)
@@ -136,8 +135,7 @@ def init_args():
parser.add_argument("--sr_batch_num", type=int, default=1)
#
- parser.add_argument(
- "--draw_img_save_dir", type=str, default="./inference_results")
+ parser.add_argument("--draw_img_save_dir", type=str, default="./inference_results")
parser.add_argument("--save_crop_res", type=str2bool, default=False)
parser.add_argument("--crop_res_save_dir", type=str, default="./output")
@@ -153,7 +151,12 @@ def init_args():
parser.add_argument("--use_onnx", type=str2bool, default=False)
# extended function
- parser.add_argument("--return_word_box", type=str2bool, default=False, help='Whether return the bbox of each word (split by space) or chinese character. Only used in ppstructure for layout recovery')
+ parser.add_argument(
+ "--return_word_box",
+ type=str2bool,
+ default=False,
+        help="Whether to return the bounding box of each word (split by space) or each Chinese character. Only used in ppstructure for layout recovery.",
+ )
return parser
@@ -166,19 +169,19 @@ def parse_args():
def create_predictor(args, mode, logger):
if mode == "det":
model_dir = args.det_model_dir
- elif mode == 'cls':
+ elif mode == "cls":
model_dir = args.cls_model_dir
- elif mode == 'rec':
+ elif mode == "rec":
model_dir = args.rec_model_dir
- elif mode == 'table':
+ elif mode == "table":
model_dir = args.table_model_dir
- elif mode == 'ser':
+ elif mode == "ser":
model_dir = args.ser_model_dir
- elif mode == 're':
+ elif mode == "re":
model_dir = args.re_model_dir
elif mode == "sr":
model_dir = args.sr_model_dir
- elif mode == 'layout':
+ elif mode == "layout":
model_dir = args.layout_model_dir
else:
model_dir = args.e2e_model_dir
@@ -188,36 +191,39 @@ def create_predictor(args, mode, logger):
sys.exit(0)
if args.use_onnx:
import onnxruntime as ort
+
model_file_path = model_dir
if not os.path.exists(model_file_path):
- raise ValueError("not find model file path {}".format(
- model_file_path))
+ raise ValueError("not find model file path {}".format(model_file_path))
if args.use_gpu:
- sess = ort.InferenceSession(model_file_path, providers=['CUDAExecutionProvider'])
+ sess = ort.InferenceSession(
+ model_file_path, providers=["CUDAExecutionProvider"]
+ )
else:
sess = ort.InferenceSession(model_file_path)
return sess, sess.get_inputs()[0], None, None
else:
- file_names = ['model', 'inference']
+ file_names = ["model", "inference"]
for file_name in file_names:
- model_file_path = '{}/{}.pdmodel'.format(model_dir, file_name)
- params_file_path = '{}/{}.pdiparams'.format(model_dir, file_name)
- if os.path.exists(model_file_path) and os.path.exists(
- params_file_path):
+ model_file_path = "{}/{}.pdmodel".format(model_dir, file_name)
+ params_file_path = "{}/{}.pdiparams".format(model_dir, file_name)
+ if os.path.exists(model_file_path) and os.path.exists(params_file_path):
break
if not os.path.exists(model_file_path):
raise ValueError(
- "not find model.pdmodel or inference.pdmodel in {}".format(
- model_dir))
+ "not find model.pdmodel or inference.pdmodel in {}".format(model_dir)
+ )
if not os.path.exists(params_file_path):
raise ValueError(
"not find model.pdiparams or inference.pdiparams in {}".format(
- model_dir))
+ model_dir
+ )
+ )
config = inference.Config(model_file_path, params_file_path)
- if hasattr(args, 'precision'):
+ if hasattr(args, "precision"):
if args.precision == "fp16" and args.use_tensorrt:
precision = inference.PrecisionType.Half
elif args.precision == "int8":
@@ -239,21 +245,18 @@ def create_predictor(args, mode, logger):
workspace_size=1 << 30,
precision_mode=precision,
max_batch_size=args.max_batch_size,
- min_subgraph_size=args.
- min_subgraph_size, # skip the minmum trt subgraph
- use_calib_mode=False)
+                    min_subgraph_size=args.min_subgraph_size,  # skip TensorRT for subgraphs smaller than this
+ use_calib_mode=False,
+ )
# collect shape
- trt_shape_f = os.path.join(model_dir,
- f"{mode}_trt_dynamic_shape.txt")
+ trt_shape_f = os.path.join(model_dir, f"{mode}_trt_dynamic_shape.txt")
if not os.path.exists(trt_shape_f):
config.collect_shape_range_info(trt_shape_f)
- logger.info(
- f"collect dynamic shape info into : {trt_shape_f}")
+ logger.info(f"collect dynamic shape info into : {trt_shape_f}")
try:
- config.enable_tuned_tensorrt_dynamic_shape(trt_shape_f,
- True)
+ config.enable_tuned_tensorrt_dynamic_shape(trt_shape_f, True)
except Exception as E:
logger.info(E)
logger.info("Please keep your paddlepaddle-gpu >= 2.3.0!")
@@ -282,9 +285,9 @@ def create_predictor(args, mode, logger):
config.disable_glog_info()
config.delete_pass("conv_transpose_eltwiseadd_bn_fuse_pass")
config.delete_pass("matmul_transpose_reshape_fuse_pass")
- if mode == 're':
+ if mode == "re":
config.delete_pass("simplify_with_basic_ops_pass")
- if mode == 'table':
+ if mode == "table":
config.delete_pass("fc_fuse_pass") # not supported for table
config.switch_use_feed_fetch_ops(False)
config.switch_ir_optim(True)
@@ -292,7 +295,7 @@ def create_predictor(args, mode, logger):
# create predictor
predictor = inference.create_predictor(config)
input_names = predictor.get_input_names()
- if mode in ['ser', 're']:
+ if mode in ["ser", "re"]:
input_tensor = []
for name in input_names:
input_tensor.append(predictor.get_input_handle(name))
@@ -306,10 +309,8 @@ def create_predictor(args, mode, logger):
def get_output_tensors(args, mode, predictor):
output_names = predictor.get_output_names()
output_tensors = []
- if mode == "rec" and args.rec_algorithm in [
- "CRNN", "SVTR_LCNet", "SVTR_HGNet"
- ]:
- output_name = 'softmax_0.tmp_0'
+ if mode == "rec" and args.rec_algorithm in ["CRNN", "SVTR_LCNet", "SVTR_HGNet"]:
+ output_name = "softmax_0.tmp_0"
if output_name in output_names:
return [predictor.get_output_handle(output_name)]
else:
@@ -352,7 +353,8 @@ def draw_e2e_res(dt_boxes, strs, img_path):
fontFace=cv2.FONT_HERSHEY_COMPLEX,
fontScale=0.7,
color=(0, 255, 0),
- thickness=1)
+ thickness=1,
+ )
return src_im
@@ -375,12 +377,14 @@ def resize_img(img, input_size=600):
return img
-def draw_ocr(image,
- boxes,
- txts=None,
- scores=None,
- drop_score=0.5,
- font_path="./doc/fonts/simfang.ttf"):
+def draw_ocr(
+ image,
+ boxes,
+ txts=None,
+ scores=None,
+ drop_score=0.5,
+ font_path="./doc/fonts/simfang.ttf",
+):
"""
Visualize the results of OCR detection and recognition
args:
@@ -397,8 +401,7 @@ def draw_ocr(image,
scores = [1] * len(boxes)
box_num = len(boxes)
for i in range(box_num):
- if scores is not None and (scores[i] < drop_score or
- math.isnan(scores[i])):
+ if scores is not None and (scores[i] < drop_score or math.isnan(scores[i])):
continue
box = np.reshape(np.array(boxes[i]), [-1, 1, 2]).astype(np.int64)
image = cv2.polylines(np.array(image), [box], True, (255, 0, 0), 2)
@@ -410,18 +413,21 @@ def draw_ocr(image,
img_h=img.shape[0],
img_w=600,
threshold=drop_score,
- font_path=font_path)
+ font_path=font_path,
+ )
img = np.concatenate([np.array(img), np.array(txt_img)], axis=1)
return img
return image
-def draw_ocr_box_txt(image,
- boxes,
- txts=None,
- scores=None,
- drop_score=0.5,
- font_path="./doc/fonts/simfang.ttf"):
+def draw_ocr_box_txt(
+ image,
+ boxes,
+ txts=None,
+ scores=None,
+ drop_score=0.5,
+ font_path="./doc/fonts/simfang.ttf",
+):
h, w = image.height, image.width
img_left = image.copy()
img_right = np.ones((h, w, 3), dtype=np.uint8) * 255
@@ -433,15 +439,14 @@ def draw_ocr_box_txt(image,
for idx, (box, txt) in enumerate(zip(boxes, txts)):
if scores is not None and scores[idx] < drop_score:
continue
- color = (random.randint(0, 255), random.randint(0, 255),
- random.randint(0, 255))
+ color = (random.randint(0, 255), random.randint(0, 255), random.randint(0, 255))
draw_left.polygon(box, fill=color)
img_right_text = draw_box_txt_fine((w, h), box, txt, font_path)
pts = np.array(box, np.int32).reshape((-1, 1, 2))
cv2.polylines(img_right_text, [pts], True, color, 1)
img_right = cv2.bitwise_and(img_right, img_right_text)
img_left = Image.blend(image, img_left, 0.5)
- img_show = Image.new('RGB', (w * 2, h), (255, 255, 255))
+ img_show = Image.new("RGB", (w * 2, h), (255, 255, 255))
img_show.paste(img_left, (0, 0, w, h))
img_show.paste(Image.fromarray(img_right), (w, 0, w * 2, h))
return np.array(img_show)
@@ -449,26 +454,29 @@ def draw_ocr_box_txt(image,
def draw_box_txt_fine(img_size, box, txt, font_path="./doc/fonts/simfang.ttf"):
box_height = int(
- math.sqrt((box[0][0] - box[3][0])**2 + (box[0][1] - box[3][1])**2))
+ math.sqrt((box[0][0] - box[3][0]) ** 2 + (box[0][1] - box[3][1]) ** 2)
+ )
box_width = int(
- math.sqrt((box[0][0] - box[1][0])**2 + (box[0][1] - box[1][1])**2))
+ math.sqrt((box[0][0] - box[1][0]) ** 2 + (box[0][1] - box[1][1]) ** 2)
+ )
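+    # Boxes much taller than wide are treated as vertical text: rendered horizontally, then rotated 270 degrees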
if box_height > 2 * box_width and box_height > 30:
- img_text = Image.new('RGB', (box_height, box_width), (255, 255, 255))
+ img_text = Image.new("RGB", (box_height, box_width), (255, 255, 255))
draw_text = ImageDraw.Draw(img_text)
if txt:
font = create_font(txt, (box_height, box_width), font_path)
draw_text.text([0, 0], txt, fill=(0, 0, 0), font=font)
img_text = img_text.transpose(Image.ROTATE_270)
else:
- img_text = Image.new('RGB', (box_width, box_height), (255, 255, 255))
+ img_text = Image.new("RGB", (box_width, box_height), (255, 255, 255))
draw_text = ImageDraw.Draw(img_text)
if txt:
font = create_font(txt, (box_width, box_height), font_path)
draw_text.text([0, 0], txt, fill=(0, 0, 0), font=font)
pts1 = np.float32(
- [[0, 0], [box_width, 0], [box_width, box_height], [0, box_height]])
+ [[0, 0], [box_width, 0], [box_width, box_height], [0, box_height]]
+ )
pts2 = np.array(box, dtype=np.float32)
M = cv2.getPerspectiveTransform(pts1, pts2)
@@ -479,18 +487,19 @@ def draw_box_txt_fine(img_size, box, txt, font_path="./doc/fonts/simfang.ttf"):
img_size,
flags=cv2.INTER_NEAREST,
borderMode=cv2.BORDER_CONSTANT,
- borderValue=(255, 255, 255))
+ borderValue=(255, 255, 255),
+ )
return img_right_text
def create_font(txt, sz, font_path="./doc/fonts/simfang.ttf"):
font_size = int(sz[1] * 0.99)
font = ImageFont.truetype(font_path, font_size, encoding="utf-8")
- if int(PIL.__version__.split('.')[0]) < 10:
+ if int(PIL.__version__.split(".")[0]) < 10:
length = font.getsize(txt)[0]
else:
length = font.getlength(txt)
-
+
if length > sz[0]:
font_size = int(font_size * sz[0] / length)
font = ImageFont.truetype(font_path, font_size, encoding="utf-8")
@@ -508,6 +517,7 @@ def str_count(s):
the number of Chinese characters
"""
import string
+
count_zh = count_pu = 0
s_len = len(s)
en_dg_count = 0
@@ -521,12 +531,9 @@ def str_count(s):
return s_len - math.ceil(en_dg_count / 2)
-def text_visual(texts,
- scores,
- img_h=400,
- img_w=600,
- threshold=0.,
- font_path="./doc/simfang.ttf"):
+def text_visual(
+ texts, scores, img_h=400, img_w=600, threshold=0.0, font_path="./doc/simfang.ttf"
+):
"""
create new blank img and draw txt on it
args:
@@ -539,11 +546,12 @@ def text_visual(texts,
"""
if scores is not None:
assert len(texts) == len(
- scores), "The number of txts and corresponding scores must match"
+ scores
+ ), "The number of txts and corresponding scores must match"
def create_blank_img():
blank_img = np.ones(shape=[img_h, img_w], dtype=np.int8) * 255
- blank_img[:, img_w - 1:] = 0
+ blank_img[:, img_w - 1 :] = 0
blank_img = Image.fromarray(blank_img).convert("RGB")
draw_txt = ImageDraw.Draw(blank_img)
return blank_img, draw_txt
@@ -565,23 +573,23 @@ def create_blank_img():
first_line = True
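+        # Wrap long texts: each rendered row holds roughly img_w // font_size characters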
while str_count(txt) >= img_w // font_size - 4:
tmp = txt
- txt = tmp[:img_w // font_size - 4]
+ txt = tmp[: img_w // font_size - 4]
if first_line:
- new_txt = str(index) + ': ' + txt
+ new_txt = str(index) + ": " + txt
first_line = False
else:
- new_txt = ' ' + txt
+ new_txt = " " + txt
draw_txt.text((0, gap * count), new_txt, txt_color, font=font)
- txt = tmp[img_w // font_size - 4:]
+ txt = tmp[img_w // font_size - 4 :]
if count >= img_h // gap - 1:
txt_img_list.append(np.array(blank_img))
blank_img, draw_txt = create_blank_img()
count = 0
count += 1
if first_line:
- new_txt = str(index) + ': ' + txt + ' ' + '%.3f' % (scores[idx])
+ new_txt = str(index) + ": " + txt + " " + "%.3f" % (scores[idx])
else:
- new_txt = " " + txt + " " + '%.3f' % (scores[idx])
+ new_txt = " " + txt + " " + "%.3f" % (scores[idx])
draw_txt.text((0, gap * count), new_txt, txt_color, font=font)
# whether add new blank img or not
if count >= img_h // gap - 1 and idx + 1 < len(texts):
@@ -599,7 +607,8 @@ def create_blank_img():
def base64_to_cv2(b64str):
import base64
- data = base64.b64decode(b64str.encode('utf8'))
+
+ data = base64.b64decode(b64str.encode("utf8"))
data = np.frombuffer(data, np.uint8)
data = cv2.imdecode(data, cv2.IMREAD_COLOR)
return data
@@ -608,7 +617,7 @@ def base64_to_cv2(b64str):
def draw_boxes(image, boxes, scores=None, drop_score=0.5):
if scores is None:
scores = [1] * len(boxes)
- for (box, score) in zip(boxes, scores):
+ for box, score in zip(boxes, scores):
if score < drop_score:
continue
box = np.reshape(np.array(box), [-1, 1, 2]).astype(np.int64)
@@ -617,7 +626,7 @@ def draw_boxes(image, boxes, scores=None, drop_score=0.5):
def get_rotate_crop_image(img, points):
- '''
+ """
img_height, img_width = img.shape[0:2]
left = int(np.min(points[:, 0]))
right = int(np.max(points[:, 0]))
@@ -626,25 +635,34 @@ def get_rotate_crop_image(img, points):
img_crop = img[top:bottom, left:right, :].copy()
points[:, 0] = points[:, 0] - left
points[:, 1] = points[:, 1] - top
- '''
+ """
assert len(points) == 4, "shape of points must be 4*2"
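+    # Perspective-warp the quadrilateral to an upright rectangle sized by its edge lengths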
img_crop_width = int(
max(
- np.linalg.norm(points[0] - points[1]),
- np.linalg.norm(points[2] - points[3])))
+ np.linalg.norm(points[0] - points[1]), np.linalg.norm(points[2] - points[3])
+ )
+ )
img_crop_height = int(
max(
- np.linalg.norm(points[0] - points[3]),
- np.linalg.norm(points[1] - points[2])))
- pts_std = np.float32([[0, 0], [img_crop_width, 0],
- [img_crop_width, img_crop_height],
- [0, img_crop_height]])
+ np.linalg.norm(points[0] - points[3]), np.linalg.norm(points[1] - points[2])
+ )
+ )
+ pts_std = np.float32(
+ [
+ [0, 0],
+ [img_crop_width, 0],
+ [img_crop_width, img_crop_height],
+ [0, img_crop_height],
+ ]
+ )
M = cv2.getPerspectiveTransform(points, pts_std)
dst_img = cv2.warpPerspective(
img,
- M, (img_crop_width, img_crop_height),
+ M,
+ (img_crop_width, img_crop_height),
borderMode=cv2.BORDER_REPLICATE,
- flags=cv2.INTER_CUBIC)
+ flags=cv2.INTER_CUBIC,
+ )
dst_img_height, dst_img_width = dst_img.shape[0:2]
if dst_img_height * 1.0 / dst_img_width >= 1.5:
dst_img = np.rot90(dst_img)
@@ -675,10 +693,12 @@ def get_minarea_rect_crop(img, points):
def check_gpu(use_gpu):
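+    # Fall back to CPU when paddle is a CPU-only build or no CUDA device is visible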
- if use_gpu and (not paddle.is_compiled_with_cuda() or paddle.device.get_device() == 'cpu'):
+ if use_gpu and (
+ not paddle.is_compiled_with_cuda() or paddle.device.get_device() == "cpu"
+ ):
use_gpu = False
return use_gpu
-if __name__ == '__main__':
+if __name__ == "__main__":
pass
diff --git a/tools/infer_cls.py b/tools/infer_cls.py
index 7fd6b536fb..6c26ff4aa8 100755
--- a/tools/infer_cls.py
+++ b/tools/infer_cls.py
@@ -23,9 +23,9 @@
__dir__ = os.path.dirname(os.path.abspath(__file__))
sys.path.append(__dir__)
-sys.path.insert(0, os.path.abspath(os.path.join(__dir__, '..')))
+sys.path.insert(0, os.path.abspath(os.path.join(__dir__, "..")))
-os.environ["FLAGS_allocator_strategy"] = 'auto_growth'
+os.environ["FLAGS_allocator_strategy"] = "auto_growth"
import paddle
@@ -38,37 +38,36 @@
def main():
- global_config = config['Global']
+ global_config = config["Global"]
# build post process
- post_process_class = build_post_process(config['PostProcess'],
- global_config)
+ post_process_class = build_post_process(config["PostProcess"], global_config)
# build model
- model = build_model(config['Architecture'])
+ model = build_model(config["Architecture"])
load_model(config, model)
# create data ops
transforms = []
- for op in config['Eval']['dataset']['transforms']:
+ for op in config["Eval"]["dataset"]["transforms"]:
op_name = list(op)[0]
- if 'Label' in op_name:
+ if "Label" in op_name:
continue
- elif op_name == 'KeepKeys':
- op[op_name]['keep_keys'] = ['image']
+ elif op_name == "KeepKeys":
+ op[op_name]["keep_keys"] = ["image"]
elif op_name == "SSLRotateResize":
op[op_name]["mode"] = "test"
transforms.append(op)
- global_config['infer_mode'] = True
+ global_config["infer_mode"] = True
ops = create_operators(transforms, global_config)
model.eval()
- for file in get_image_file_list(config['Global']['infer_img']):
+ for file in get_image_file_list(config["Global"]["infer_img"]):
logger.info("infer_img: {}".format(file))
- with open(file, 'rb') as f:
+ with open(file, "rb") as f:
img = f.read()
- data = {'image': img}
+ data = {"image": img}
batch = transform(data, ops)
images = np.expand_dims(batch[0], axis=0)
@@ -76,10 +75,10 @@ def main():
preds = model(images)
post_result = post_process_class(preds)
for rec_result in post_result:
- logger.info('\t result: {}'.format(rec_result))
+ logger.info("\t result: {}".format(rec_result))
logger.info("success!")
-if __name__ == '__main__':
+if __name__ == "__main__":
config, device, logger, vdl_writer = program.preprocess()
main()
diff --git a/tools/infer_det.py b/tools/infer_det.py
index 097d032b99..6c029899d4 100755
--- a/tools/infer_det.py
+++ b/tools/infer_det.py
@@ -23,9 +23,9 @@
__dir__ = os.path.dirname(os.path.abspath(__file__))
sys.path.append(__dir__)
-sys.path.insert(0, os.path.abspath(os.path.join(__dir__, '..')))
+sys.path.insert(0, os.path.abspath(os.path.join(__dir__, "..")))
-os.environ["FLAGS_allocator_strategy"] = 'auto_growth'
+os.environ["FLAGS_allocator_strategy"] = "auto_growth"
import cv2
import json
@@ -41,6 +41,7 @@
def draw_det_res(dt_boxes, config, img, img_name, save_path):
import cv2
+
src_im = img
for box in dt_boxes:
box = np.array(box).astype(np.int32).reshape((-1, 1, 2))
@@ -54,38 +55,38 @@ def draw_det_res(dt_boxes, config, img, img_name, save_path):
@paddle.no_grad()
def main():
- global_config = config['Global']
+ global_config = config["Global"]
# build model
- model = build_model(config['Architecture'])
+ model = build_model(config["Architecture"])
load_model(config, model)
# build post process
- post_process_class = build_post_process(config['PostProcess'])
+ post_process_class = build_post_process(config["PostProcess"])
# create data ops
transforms = []
- for op in config['Eval']['dataset']['transforms']:
+ for op in config["Eval"]["dataset"]["transforms"]:
op_name = list(op)[0]
- if 'Label' in op_name:
+ if "Label" in op_name:
continue
- elif op_name == 'KeepKeys':
- op[op_name]['keep_keys'] = ['image', 'shape']
+ elif op_name == "KeepKeys":
+ op[op_name]["keep_keys"] = ["image", "shape"]
transforms.append(op)
ops = create_operators(transforms, global_config)
- save_res_path = config['Global']['save_res_path']
+ save_res_path = config["Global"]["save_res_path"]
if not os.path.exists(os.path.dirname(save_res_path)):
os.makedirs(os.path.dirname(save_res_path))
model.eval()
with open(save_res_path, "wb") as fout:
- for file in get_image_file_list(config['Global']['infer_img']):
+ for file in get_image_file_list(config["Global"]["infer_img"]):
logger.info("infer_img: {}".format(file))
- with open(file, 'rb') as f:
+ with open(file, "rb") as f:
img = f.read()
- data = {'image': img}
+ data = {"image": img}
batch = transform(data, ops)
images = np.expand_dims(batch[0], axis=0)
@@ -101,26 +102,28 @@ def main():
if isinstance(post_result, dict):
det_box_json = {}
for k in post_result.keys():
- boxes = post_result[k][0]['points']
+ boxes = post_result[k][0]["points"]
dt_boxes_list = []
for box in boxes:
tmp_json = {"transcription": ""}
- tmp_json['points'] = np.array(box).tolist()
+ tmp_json["points"] = np.array(box).tolist()
dt_boxes_list.append(tmp_json)
det_box_json[k] = dt_boxes_list
- save_det_path = os.path.dirname(config['Global'][
- 'save_res_path']) + "/det_results_{}/".format(k)
+ save_det_path = os.path.dirname(
+ config["Global"]["save_res_path"]
+ ) + "/det_results_{}/".format(k)
draw_det_res(boxes, config, src_img, file, save_det_path)
else:
- boxes = post_result[0]['points']
+ boxes = post_result[0]["points"]
dt_boxes_json = []
# write result
for box in boxes:
tmp_json = {"transcription": ""}
- tmp_json['points'] = np.array(box).tolist()
+ tmp_json["points"] = np.array(box).tolist()
dt_boxes_json.append(tmp_json)
- save_det_path = os.path.dirname(config['Global'][
- 'save_res_path']) + "/det_results/"
+ save_det_path = (
+ os.path.dirname(config["Global"]["save_res_path"]) + "/det_results/"
+ )
draw_det_res(boxes, config, src_img, file, save_det_path)
otstr = file + "\t" + json.dumps(dt_boxes_json) + "\n"
fout.write(otstr.encode())
@@ -128,6 +131,6 @@ def main():
logger.info("success!")
-if __name__ == '__main__':
+if __name__ == "__main__":
config, device, logger, vdl_writer = program.preprocess()
main()
diff --git a/tools/infer_e2e.py b/tools/infer_e2e.py
index 37fdcbaadc..02cd68a6c4 100755
--- a/tools/infer_e2e.py
+++ b/tools/infer_e2e.py
@@ -23,9 +23,9 @@
__dir__ = os.path.dirname(os.path.abspath(__file__))
sys.path.append(__dir__)
-sys.path.insert(0, os.path.abspath(os.path.join(__dir__, '..')))
+sys.path.insert(0, os.path.abspath(os.path.join(__dir__, "..")))
-os.environ["FLAGS_allocator_strategy"] = 'auto_growth'
+os.environ["FLAGS_allocator_strategy"] = "auto_growth"
import cv2
import json
@@ -41,15 +41,12 @@
import math
-def draw_e2e_res_for_chinese(image,
- boxes,
- txts,
- config,
- img_name,
- font_path="./doc/simfang.ttf"):
+def draw_e2e_res_for_chinese(
+ image, boxes, txts, config, img_name, font_path="./doc/simfang.ttf"
+):
h, w = image.height, image.width
img_left = image.copy()
- img_right = Image.new('RGB', (w, h), (255, 255, 255))
+ img_right = Image.new("RGB", (w, h), (255, 255, 255))
import random
@@ -59,19 +56,17 @@ def draw_e2e_res_for_chinese(image,
for idx, (box, txt) in enumerate(zip(boxes, txts)):
box = np.array(box)
box = [tuple(x) for x in box]
- color = (random.randint(0, 255), random.randint(0, 255),
- random.randint(0, 255))
+ color = (random.randint(0, 255), random.randint(0, 255), random.randint(0, 255))
draw_left.polygon(box, fill=color)
draw_right.polygon(box, outline=color)
font = ImageFont.truetype(font_path, 15, encoding="utf-8")
draw_right.text([box[0][0], box[0][1]], txt, fill=(0, 0, 0), font=font)
img_left = Image.blend(image, img_left, 0.5)
- img_show = Image.new('RGB', (w * 2, h), (255, 255, 255))
+ img_show = Image.new("RGB", (w * 2, h), (255, 255, 255))
img_show.paste(img_left, (0, 0, w, h))
img_show.paste(img_right, (w, 0, w * 2, h))
- save_e2e_path = os.path.dirname(config['Global'][
- 'save_res_path']) + "/e2e_results/"
+ save_e2e_path = os.path.dirname(config["Global"]["save_res_path"]) + "/e2e_results/"
if not os.path.exists(save_e2e_path):
os.makedirs(save_e2e_path)
save_path = os.path.join(save_e2e_path, os.path.basename(img_name))
@@ -92,9 +87,11 @@ def draw_e2e_res(dt_boxes, strs, config, img, img_name):
fontFace=cv2.FONT_HERSHEY_COMPLEX,
fontScale=0.7,
color=(0, 255, 0),
- thickness=1)
- save_det_path = os.path.dirname(config['Global'][
- 'save_res_path']) + "/e2e_results/"
+ thickness=1,
+ )
+ save_det_path = (
+ os.path.dirname(config["Global"]["save_res_path"]) + "/e2e_results/"
+ )
if not os.path.exists(save_det_path):
os.makedirs(save_det_path)
save_path = os.path.join(save_det_path, os.path.basename(img_name))
@@ -103,72 +100,71 @@ def draw_e2e_res(dt_boxes, strs, config, img, img_name):
def main():
- global_config = config['Global']
+ global_config = config["Global"]
# build model
- model = build_model(config['Architecture'])
+ model = build_model(config["Architecture"])
load_model(config, model)
# build post process
- post_process_class = build_post_process(config['PostProcess'],
- global_config)
+ post_process_class = build_post_process(config["PostProcess"], global_config)
# create data ops
transforms = []
- for op in config['Eval']['dataset']['transforms']:
+ for op in config["Eval"]["dataset"]["transforms"]:
op_name = list(op)[0]
- if 'Label' in op_name:
+ if "Label" in op_name:
continue
- elif op_name == 'KeepKeys':
- op[op_name]['keep_keys'] = ['image', 'shape']
+ elif op_name == "KeepKeys":
+ op[op_name]["keep_keys"] = ["image", "shape"]
transforms.append(op)
ops = create_operators(transforms, global_config)
- save_res_path = config['Global']['save_res_path']
+ save_res_path = config["Global"]["save_res_path"]
if not os.path.exists(os.path.dirname(save_res_path)):
os.makedirs(os.path.dirname(save_res_path))
model.eval()
with open(save_res_path, "wb") as fout:
- for file in get_image_file_list(config['Global']['infer_img']):
+ for file in get_image_file_list(config["Global"]["infer_img"]):
logger.info("infer_img: {}".format(file))
- with open(file, 'rb') as f:
+ with open(file, "rb") as f:
img = f.read()
- data = {'image': img}
+ data = {"image": img}
batch = transform(data, ops)
images = np.expand_dims(batch[0], axis=0)
shape_list = np.expand_dims(batch[1], axis=0)
images = paddle.to_tensor(images)
preds = model(images)
post_result = post_process_class(preds, shape_list)
- points, strs = post_result['points'], post_result['texts']
+ points, strs = post_result["points"], post_result["texts"]
# write result
dt_boxes_json = []
for poly, str in zip(points, strs):
tmp_json = {"transcription": str}
- tmp_json['points'] = poly.tolist()
+ tmp_json["points"] = poly.tolist()
dt_boxes_json.append(tmp_json)
otstr = file + "\t" + json.dumps(dt_boxes_json) + "\n"
fout.write(otstr.encode())
src_img = cv2.imread(file)
- if global_config['infer_visual_type'] == 'EN':
+ if global_config["infer_visual_type"] == "EN":
draw_e2e_res(points, strs, config, src_img, file)
- elif global_config['infer_visual_type'] == 'CN':
- src_img = Image.fromarray(
- cv2.cvtColor(src_img, cv2.COLOR_BGR2RGB))
+ elif global_config["infer_visual_type"] == "CN":
+ src_img = Image.fromarray(cv2.cvtColor(src_img, cv2.COLOR_BGR2RGB))
draw_e2e_res_for_chinese(
src_img,
points,
strs,
config,
file,
- font_path="./doc/fonts/simfang.ttf")
+ font_path="./doc/fonts/simfang.ttf",
+ )
logger.info("success!")
-if __name__ == '__main__':
+if __name__ == "__main__":
config, device, logger, vdl_writer = program.preprocess()
main()
diff --git a/tools/infer_kie.py b/tools/infer_kie.py
index 9375434cc8..f3efaf92f1 100755
--- a/tools/infer_kie.py
+++ b/tools/infer_kie.py
@@ -24,9 +24,9 @@
__dir__ = os.path.dirname(os.path.abspath(__file__))
sys.path.append(__dir__)
-sys.path.insert(0, os.path.abspath(os.path.join(__dir__, '..')))
+sys.path.insert(0, os.path.abspath(os.path.join(__dir__, "..")))
-os.environ["FLAGS_allocator_strategy"] = 'auto_growth'
+os.environ["FLAGS_allocator_strategy"] = "auto_growth"
import cv2
import paddle
@@ -59,77 +59,88 @@ def draw_kie_result(batch, node, idx_to_cls, count):
for i, box in enumerate(boxes):
if i >= len(node_pred_label):
break
- new_box = [[box[0], box[1]], [box[2], box[1]], [box[2], box[3]],
- [box[0], box[3]]]
+ new_box = [
+ [box[0], box[1]],
+ [box[2], box[1]],
+ [box[2], box[3]],
+ [box[0], box[3]],
+ ]
Pts = np.array([new_box], np.int32)
cv2.polylines(
- img, [Pts.reshape((-1, 1, 2))],
- True,
- color=(255, 255, 0),
- thickness=1)
+ img, [Pts.reshape((-1, 1, 2))], True, color=(255, 255, 0), thickness=1
+ )
x_min = int(min([point[0] for point in new_box]))
y_min = int(min([point[1] for point in new_box]))
pred_label = node_pred_label[i]
if pred_label in idx_to_cls:
pred_label = idx_to_cls[pred_label]
- pred_score = '{:.2f}'.format(node_pred_score[i])
- text = pred_label + '(' + pred_score + ')'
- cv2.putText(pred_img, text, (x_min * 2, y_min),
- cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255, 0, 0), 1)
+ pred_score = "{:.2f}".format(node_pred_score[i])
+ text = pred_label + "(" + pred_score + ")"
+ cv2.putText(
+ pred_img,
+ text,
+ (x_min * 2, y_min),
+ cv2.FONT_HERSHEY_SIMPLEX,
+ 0.5,
+ (255, 0, 0),
+ 1,
+ )
vis_img = np.ones((h, w * 3, 3), dtype=np.uint8) * 255
vis_img[:, :w] = img
vis_img[:, w:] = pred_img
- save_kie_path = os.path.dirname(config['Global'][
- 'save_res_path']) + "/kie_results/"
+ save_kie_path = os.path.dirname(config["Global"]["save_res_path"]) + "/kie_results/"
if not os.path.exists(save_kie_path):
os.makedirs(save_kie_path)
save_path = os.path.join(save_kie_path, str(count) + ".png")
cv2.imwrite(save_path, vis_img)
logger.info("The Kie Image saved in {}".format(save_path))
+
def write_kie_result(fout, node, data):
"""
Write infer result to output file, sorted by the predict label of each line.
The format keeps the same as the input with additional score attribute.
"""
import json
- label = data['label']
+
+ label = data["label"]
annotations = json.loads(label)
max_value, max_idx = paddle.max(node, -1), paddle.argmax(node, -1)
node_pred_label = max_idx.numpy().tolist()
node_pred_score = max_value.numpy().tolist()
res = []
for i, label in enumerate(node_pred_label):
- pred_score = '{:.2f}'.format(node_pred_score[i])
+ pred_score = "{:.2f}".format(node_pred_score[i])
pred_res = {
- 'label': label,
- 'transcription': annotations[i]['transcription'],
- 'score': pred_score,
- 'points': annotations[i]['points'],
- }
+ "label": label,
+ "transcription": annotations[i]["transcription"],
+ "score": pred_score,
+ "points": annotations[i]["points"],
+ }
res.append(pred_res)
- res.sort(key=lambda x: x['label'])
- fout.writelines([json.dumps(res, ensure_ascii=False) + '\n'])
+ res.sort(key=lambda x: x["label"])
+ fout.writelines([json.dumps(res, ensure_ascii=False) + "\n"])
+
def main():
- global_config = config['Global']
+ global_config = config["Global"]
# build model
- model = build_model(config['Architecture'])
+ model = build_model(config["Architecture"])
load_model(config, model)
# create data ops
transforms = []
- for op in config['Eval']['dataset']['transforms']:
+ for op in config["Eval"]["dataset"]["transforms"]:
transforms.append(op)
- data_dir = config['Eval']['dataset']['data_dir']
+ data_dir = config["Eval"]["dataset"]["data_dir"]
ops = create_operators(transforms, global_config)
- save_res_path = config['Global']['save_res_path']
- class_path = config['Global']['class_path']
+ save_res_path = config["Global"]["save_res_path"]
+ class_path = config["Global"]["class_path"]
idx_to_cls = read_class_list(class_path)
os.makedirs(os.path.dirname(save_res_path), exist_ok=True)
@@ -138,25 +149,23 @@ def main():
warmup_times = 0
count_t = []
with open(save_res_path, "w") as fout:
- with open(config['Global']['infer_img'], "rb") as f:
+ with open(config["Global"]["infer_img"], "rb") as f:
lines = f.readlines()
for index, data_line in enumerate(lines):
if index == 10:
warmup_t = time.time()
- data_line = data_line.decode('utf-8')
+ data_line = data_line.decode("utf-8")
substr = data_line.strip("\n").split("\t")
img_path, label = data_dir + "/" + substr[0], substr[1]
- data = {'img_path': img_path, 'label': label}
- with open(data['img_path'], 'rb') as f:
+ data = {"img_path": img_path, "label": label}
+ with open(data["img_path"], "rb") as f:
img = f.read()
- data['image'] = img
+ data["image"] = img
st = time.time()
batch = transform(data, ops)
batch_pred = [0] * len(batch)
for i in range(len(batch)):
- batch_pred[i] = paddle.to_tensor(
- np.expand_dims(
- batch[i], axis=0))
+ batch_pred[i] = paddle.to_tensor(np.expand_dims(batch[i], axis=0))
st = time.time()
node, edge = model(batch_pred)
node = F.softmax(node, -1)
@@ -165,12 +174,13 @@ def main():
write_kie_result(fout, node, data)
fout.close()
logger.info("success!")
- logger.info("It took {} s for predict {} images.".format(
- np.sum(count_t), len(count_t)))
+    logger.info(
+        "It took {} s to predict {} images.".format(np.sum(count_t), len(count_t))
+    )
ips = len(count_t[warmup_times:]) / np.sum(count_t[warmup_times:])
logger.info("The ips is {} images/s".format(ips))
-if __name__ == '__main__':
+if __name__ == "__main__":
config, device, logger, vdl_writer = program.preprocess()
main()
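
Note on the hunk above: `write_kie_result` takes the per-line class logits, extracts the argmax label and its score, and emits one JSON line per image, sorted by predicted label. A minimal numpy sketch of the same bookkeeping (toy values; the real code uses `paddle.max`/`paddle.argmax` on the model's `node` output):

```python
import json
import numpy as np

node = np.array([[0.1, 0.7, 0.2], [0.8, 0.1, 0.1]])  # per-line class probabilities
annotations = [
    {"transcription": "Name:", "points": [[0, 0], [10, 10]]},
    {"transcription": "Alice", "points": [[0, 12], [10, 22]]},
]

res = []
for i, probs in enumerate(node):
    label = int(probs.argmax())            # predicted class index
    score = "{:.2f}".format(probs.max())   # confidence, two decimals as in the diff
    res.append({
        "label": label,
        "transcription": annotations[i]["transcription"],
        "score": score,
        "points": annotations[i]["points"],
    })
res.sort(key=lambda x: x["label"])         # group lines by predicted label
print(json.dumps(res, ensure_ascii=False))
```
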
diff --git a/tools/infer_kie_token_ser.py b/tools/infer_kie_token_ser.py
index 2fc5749b9c..8fe25acdfa 100755
--- a/tools/infer_kie_token_ser.py
+++ b/tools/infer_kie_token_ser.py
@@ -23,9 +23,9 @@
__dir__ = os.path.dirname(os.path.abspath(__file__))
sys.path.append(__dir__)
-sys.path.insert(0, os.path.abspath(os.path.join(__dir__, '..')))
+sys.path.insert(0, os.path.abspath(os.path.join(__dir__, "..")))
-os.environ["FLAGS_allocator_strategy"] = 'auto_growth'
+os.environ["FLAGS_allocator_strategy"] = "auto_growth"
import cv2
import json
import paddle
@@ -42,6 +42,7 @@
def to_tensor(data):
import numbers
from collections import defaultdict
+
data_dict = defaultdict(list)
to_tensor_idxs = []
@@ -57,18 +58,18 @@ def to_tensor(data):
class SerPredictor(object):
def __init__(self, config):
- global_config = config['Global']
- self.algorithm = config['Architecture']["algorithm"]
+ global_config = config["Global"]
+ self.algorithm = config["Architecture"]["algorithm"]
# build post process
- self.post_process_class = build_post_process(config['PostProcess'],
- global_config)
+ self.post_process_class = build_post_process(
+ config["PostProcess"], global_config
+ )
# build model
- self.model = build_model(config['Architecture'])
+ self.model = build_model(config["Architecture"])
- load_model(
- config, self.model, model_type=config['Architecture']["model_type"])
+ load_model(config, self.model, model_type=config["Architecture"]["model_type"])
from paddleocr import PaddleOCR
@@ -77,30 +78,38 @@ def __init__(self, config):
show_log=False,
rec_model_dir=global_config.get("kie_rec_model_dir", None),
det_model_dir=global_config.get("kie_det_model_dir", None),
- use_gpu=global_config['use_gpu'])
+ use_gpu=global_config["use_gpu"],
+ )
# create data ops
transforms = []
- for op in config['Eval']['dataset']['transforms']:
+ for op in config["Eval"]["dataset"]["transforms"]:
op_name = list(op)[0]
- if 'Label' in op_name:
- op[op_name]['ocr_engine'] = self.ocr_engine
- elif op_name == 'KeepKeys':
- op[op_name]['keep_keys'] = [
- 'input_ids', 'bbox', 'attention_mask', 'token_type_ids',
- 'image', 'labels', 'segment_offset_id', 'ocr_info',
- 'entities'
+ if "Label" in op_name:
+ op[op_name]["ocr_engine"] = self.ocr_engine
+ elif op_name == "KeepKeys":
+ op[op_name]["keep_keys"] = [
+ "input_ids",
+ "bbox",
+ "attention_mask",
+ "token_type_ids",
+ "image",
+ "labels",
+ "segment_offset_id",
+ "ocr_info",
+ "entities",
]
transforms.append(op)
if config["Global"].get("infer_mode", None) is None:
- global_config['infer_mode'] = True
- self.ops = create_operators(config['Eval']['dataset']['transforms'],
- global_config)
+ global_config["infer_mode"] = True
+ self.ops = create_operators(
+ config["Eval"]["dataset"]["transforms"], global_config
+ )
self.model.eval()
def __call__(self, data):
- with open(data["img_path"], 'rb') as f:
+ with open(data["img_path"], "rb") as f:
img = f.read()
data["image"] = img
batch = transform(data, self.ops)
@@ -108,50 +117,62 @@ def __call__(self, data):
preds = self.model(batch)
post_result = self.post_process_class(
- preds, segment_offset_ids=batch[6], ocr_infos=batch[7])
+ preds, segment_offset_ids=batch[6], ocr_infos=batch[7]
+ )
return post_result, batch
-if __name__ == '__main__':
+if __name__ == "__main__":
config, device, logger, vdl_writer = program.preprocess()
- os.makedirs(config['Global']['save_res_path'], exist_ok=True)
+ os.makedirs(config["Global"]["save_res_path"], exist_ok=True)
ser_engine = SerPredictor(config)
if config["Global"].get("infer_mode", None) is False:
- data_dir = config['Eval']['dataset']['data_dir']
- with open(config['Global']['infer_img'], "rb") as f:
+ data_dir = config["Eval"]["dataset"]["data_dir"]
+ with open(config["Global"]["infer_img"], "rb") as f:
infer_imgs = f.readlines()
else:
- infer_imgs = get_image_file_list(config['Global']['infer_img'])
+ infer_imgs = get_image_file_list(config["Global"]["infer_img"])
with open(
- os.path.join(config['Global']['save_res_path'],
- "infer_results.txt"),
- "w",
- encoding='utf-8') as fout:
+ os.path.join(config["Global"]["save_res_path"], "infer_results.txt"),
+ "w",
+ encoding="utf-8",
+ ) as fout:
for idx, info in enumerate(infer_imgs):
if config["Global"].get("infer_mode", None) is False:
- data_line = info.decode('utf-8')
+ data_line = info.decode("utf-8")
substr = data_line.strip("\n").split("\t")
img_path = os.path.join(data_dir, substr[0])
- data = {'img_path': img_path, 'label': substr[1]}
+ data = {"img_path": img_path, "label": substr[1]}
else:
img_path = info
- data = {'img_path': img_path}
+ data = {"img_path": img_path}
save_img_path = os.path.join(
- config['Global']['save_res_path'],
- os.path.splitext(os.path.basename(img_path))[0] + "_ser.jpg")
+ config["Global"]["save_res_path"],
+ os.path.splitext(os.path.basename(img_path))[0] + "_ser.jpg",
+ )
result, _ = ser_engine(data)
result = result[0]
- fout.write(img_path + "\t" + json.dumps(
- {
- "ocr_info": result,
- }, ensure_ascii=False) + "\n")
+ fout.write(
+ img_path
+ + "\t"
+ + json.dumps(
+ {
+ "ocr_info": result,
+ },
+ ensure_ascii=False,
+ )
+ + "\n"
+ )
img_res = draw_ser_results(img_path, result)
cv2.imwrite(save_img_path, img_res)
- logger.info("process: [{}/{}], save result to {}".format(
- idx, len(infer_imgs), save_img_path))
+ logger.info(
+ "process: [{}/{}], save result to {}".format(
+ idx, len(infer_imgs), save_img_path
+ )
+ )
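
Note on `SerPredictor.__init__` above: each transform op in the Eval config is a one-key dict, so `list(op)[0]` yields its name and the op's sub-dict can be patched in place (injecting the OCR engine, overriding `keep_keys`) before `create_operators` builds the pipeline. A sketch of that patching pattern, with an illustrative transform list rather than a real SER config:

```python
transforms = [
    {"DecodeImage": {"img_mode": "RGB"}},
    {"VQATokenLabelEncode": {"algorithm": "LayoutXLM"}},
    {"KeepKeys": {"keep_keys": []}},
]
ocr_engine = "PaddleOCR(...) instance (stub)"  # stand-in for the real engine

for op in transforms:
    op_name = list(op)[0]          # the single key names the transform
    if "Label" in op_name:
        op[op_name]["ocr_engine"] = ocr_engine
    elif op_name == "KeepKeys":
        op[op_name]["keep_keys"] = ["input_ids", "bbox", "attention_mask"]

print(transforms[1])
print(transforms[2])
```
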
diff --git a/tools/infer_kie_token_ser_re.py b/tools/infer_kie_token_ser_re.py
index 76120a913f..a9589ca813 100755
--- a/tools/infer_kie_token_ser_re.py
+++ b/tools/infer_kie_token_ser_re.py
@@ -23,9 +23,9 @@
__dir__ = os.path.dirname(os.path.abspath(__file__))
sys.path.append(__dir__)
-sys.path.insert(0, os.path.abspath(os.path.join(__dir__, '..')))
+sys.path.insert(0, os.path.abspath(os.path.join(__dir__, "..")))
-os.environ["FLAGS_allocator_strategy"] = 'auto_growth'
+os.environ["FLAGS_allocator_strategy"] = "auto_growth"
import cv2
import json
import paddle
@@ -46,23 +46,23 @@ class ReArgsParser(ArgsParser):
def __init__(self):
super(ReArgsParser, self).__init__()
self.add_argument(
- "-c_ser", "--config_ser", help="ser configuration file to use")
+ "-c_ser", "--config_ser", help="ser configuration file to use"
+ )
self.add_argument(
- "-o_ser",
- "--opt_ser",
- nargs='+',
- help="set ser configuration options ")
+ "-o_ser", "--opt_ser", nargs="+", help="set ser configuration options "
+ )
def parse_args(self, argv=None):
args = super(ReArgsParser, self).parse_args(argv)
- assert args.config_ser is not None, \
- "Please specify --config_ser=ser_configure_file_path."
+ assert (
+ args.config_ser is not None
+ ), "Please specify --config_ser=ser_configure_file_path."
args.opt_ser = self._parse_opt(args.opt_ser)
return args
def make_input(ser_inputs, ser_results):
- entities_labels = {'HEADER': 0, 'QUESTION': 1, 'ANSWER': 2}
+ entities_labels = {"HEADER": 0, "QUESTION": 1, "ANSWER": 2}
batch_size, max_seq_len = ser_inputs[0].shape[:2]
entities = ser_inputs[8][0]
ser_results = ser_results[0]
@@ -74,20 +74,20 @@ def make_input(ser_inputs, ser_results):
label = []
entity_idx_dict = {}
for i, (res, entity) in enumerate(zip(ser_results, entities)):
- if res['pred'] == 'O':
+ if res["pred"] == "O":
continue
entity_idx_dict[len(start)] = i
- start.append(entity['start'])
- end.append(entity['end'])
- label.append(entities_labels[res['pred']])
+ start.append(entity["start"])
+ end.append(entity["end"])
+ label.append(entities_labels[res["pred"]])
entities = np.full([max_seq_len + 1, 3], fill_value=-1, dtype=np.int64)
entities[0, 0] = len(start)
- entities[1:len(start) + 1, 0] = start
+ entities[1 : len(start) + 1, 0] = start
entities[0, 1] = len(end)
- entities[1:len(end) + 1, 1] = end
+ entities[1 : len(end) + 1, 1] = end
entities[0, 2] = len(label)
- entities[1:len(label) + 1, 2] = label
+ entities[1 : len(label) + 1, 2] = label
# relations
head = []
@@ -100,9 +100,9 @@ def make_input(ser_inputs, ser_results):
relations = np.full([len(head) + 1, 2], fill_value=-1, dtype=np.int64)
relations[0, 0] = len(head)
- relations[1:len(head) + 1, 0] = head
+ relations[1 : len(head) + 1, 0] = head
relations[0, 1] = len(tail)
- relations[1:len(tail) + 1, 1] = tail
+ relations[1 : len(tail) + 1, 1] = tail
entities = np.expand_dims(entities, axis=0)
entities = np.repeat(entities, batch_size, axis=0)
@@ -123,23 +123,23 @@ def make_input(ser_inputs, ser_results):
class SerRePredictor(object):
def __init__(self, config, ser_config):
- global_config = config['Global']
+ global_config = config["Global"]
if "infer_mode" in global_config:
ser_config["Global"]["infer_mode"] = global_config["infer_mode"]
self.ser_engine = SerPredictor(ser_config)
- # init re model
+ # init re model
# build post process
- self.post_process_class = build_post_process(config['PostProcess'],
- global_config)
+ self.post_process_class = build_post_process(
+ config["PostProcess"], global_config
+ )
# build model
- self.model = build_model(config['Architecture'])
+ self.model = build_model(config["Architecture"])
- load_model(
- config, self.model, model_type=config['Architecture']["model_type"])
+ load_model(config, self.model, model_type=config["Architecture"]["model_type"])
self.model.eval()
@@ -150,9 +150,8 @@ def __call__(self, data):
re_input.pop(4)
preds = self.model(re_input)
post_result = self.post_process_class(
- preds,
- ser_results=ser_results,
- entity_idx_dict_batch=entity_idx_dict_batch)
+ preds, ser_results=ser_results, entity_idx_dict_batch=entity_idx_dict_batch
+ )
return post_result
@@ -167,59 +166,61 @@ def preprocess():
logger = get_logger()
# check if set use_gpu=True in paddlepaddle cpu version
- use_gpu = config['Global']['use_gpu']
+ use_gpu = config["Global"]["use_gpu"]
- device = 'gpu:{}'.format(dist.ParallelEnv().dev_id) if use_gpu else 'cpu'
+ device = "gpu:{}".format(dist.ParallelEnv().dev_id) if use_gpu else "cpu"
device = paddle.set_device(device)
- logger.info('{} re config {}'.format('*' * 10, '*' * 10))
+ logger.info("{} re config {}".format("*" * 10, "*" * 10))
print_dict(config, logger)
- logger.info('\n')
- logger.info('{} ser config {}'.format('*' * 10, '*' * 10))
+ logger.info("\n")
+ logger.info("{} ser config {}".format("*" * 10, "*" * 10))
print_dict(ser_config, logger)
- logger.info('train with paddle {} and device {}'.format(paddle.__version__,
- device))
+ logger.info("train with paddle {} and device {}".format(paddle.__version__, device))
return config, ser_config, device, logger
-if __name__ == '__main__':
+if __name__ == "__main__":
config, ser_config, device, logger = preprocess()
- os.makedirs(config['Global']['save_res_path'], exist_ok=True)
+ os.makedirs(config["Global"]["save_res_path"], exist_ok=True)
ser_re_engine = SerRePredictor(config, ser_config)
if config["Global"].get("infer_mode", None) is False:
- data_dir = config['Eval']['dataset']['data_dir']
- with open(config['Global']['infer_img'], "rb") as f:
+ data_dir = config["Eval"]["dataset"]["data_dir"]
+ with open(config["Global"]["infer_img"], "rb") as f:
infer_imgs = f.readlines()
else:
- infer_imgs = get_image_file_list(config['Global']['infer_img'])
+ infer_imgs = get_image_file_list(config["Global"]["infer_img"])
with open(
- os.path.join(config['Global']['save_res_path'],
- "infer_results.txt"),
- "w",
- encoding='utf-8') as fout:
+ os.path.join(config["Global"]["save_res_path"], "infer_results.txt"),
+ "w",
+ encoding="utf-8",
+ ) as fout:
for idx, info in enumerate(infer_imgs):
if config["Global"].get("infer_mode", None) is False:
- data_line = info.decode('utf-8')
+ data_line = info.decode("utf-8")
substr = data_line.strip("\n").split("\t")
img_path = os.path.join(data_dir, substr[0])
- data = {'img_path': img_path, 'label': substr[1]}
+ data = {"img_path": img_path, "label": substr[1]}
else:
img_path = info
- data = {'img_path': img_path}
+ data = {"img_path": img_path}
save_img_path = os.path.join(
- config['Global']['save_res_path'],
- os.path.splitext(os.path.basename(img_path))[0] + "_ser_re.jpg")
+ config["Global"]["save_res_path"],
+ os.path.splitext(os.path.basename(img_path))[0] + "_ser_re.jpg",
+ )
result = ser_re_engine(data)
result = result[0]
- fout.write(img_path + "\t" + json.dumps(
- result, ensure_ascii=False) + "\n")
+ fout.write(img_path + "\t" + json.dumps(result, ensure_ascii=False) + "\n")
img_res = draw_re_results(img_path, result)
cv2.imwrite(save_img_path, img_res)
- logger.info("process: [{}/{}], save result to {}".format(
- idx, len(infer_imgs), save_img_path))
+ logger.info(
+ "process: [{}/{}], save result to {}".format(
+ idx, len(infer_imgs), save_img_path
+ )
+ )
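
Note on `make_input` above: variable-length entity/relation lists are packed into fixed-size int arrays so every sample in a batch has the same shape. Row 0 of each column stores the element count, rows 1..n the values, and -1 pads the remainder; the new `entities[1 : len(start) + 1, 0]` slicing is pure Black reformatting of that scheme. A runnable numpy sketch with toy spans:

```python
import numpy as np

max_seq_len = 6
start, end, label = [0, 5], [4, 9], [1, 2]  # toy entity spans and classes

entities = np.full([max_seq_len + 1, 3], fill_value=-1, dtype=np.int64)
entities[0, 0] = len(start)                 # column 0, row 0: how many starts
entities[1 : len(start) + 1, 0] = start     # rows 1..n: the start offsets
entities[0, 1] = len(end)
entities[1 : len(end) + 1, 1] = end
entities[0, 2] = len(label)
entities[1 : len(label) + 1, 2] = label
print(entities)
# column 0 reads [2, 0, 5, -1, ...]: "2 entities, starting at tokens 0 and 5"
```
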
diff --git a/tools/infer_rec.py b/tools/infer_rec.py
index 8a7d599356..0e04c8b636 100755
--- a/tools/infer_rec.py
+++ b/tools/infer_rec.py
@@ -24,9 +24,9 @@
__dir__ = os.path.dirname(os.path.abspath(__file__))
sys.path.append(__dir__)
-sys.path.insert(0, os.path.abspath(os.path.join(__dir__, '..')))
+sys.path.insert(0, os.path.abspath(os.path.join(__dir__, "..")))
-os.environ["FLAGS_allocator_strategy"] = 'auto_growth'
+os.environ["FLAGS_allocator_strategy"] = "auto_growth"
import paddle
@@ -39,96 +39,98 @@
def main():
- global_config = config['Global']
+ global_config = config["Global"]
# build post process
- post_process_class = build_post_process(config['PostProcess'],
- global_config)
+ post_process_class = build_post_process(config["PostProcess"], global_config)
# build model
- if hasattr(post_process_class, 'character'):
- char_num = len(getattr(post_process_class, 'character'))
- if config["Architecture"]["algorithm"] in ["Distillation",
- ]: # distillation model
+ if hasattr(post_process_class, "character"):
+ char_num = len(getattr(post_process_class, "character"))
+ if config["Architecture"]["algorithm"] in [
+ "Distillation",
+ ]: # distillation model
for key in config["Architecture"]["Models"]:
- if config["Architecture"]["Models"][key]["Head"][
- "name"] == 'MultiHead': # multi head
+ if (
+ config["Architecture"]["Models"][key]["Head"]["name"] == "MultiHead"
+ ): # multi head
out_channels_list = {}
- if config['PostProcess'][
- 'name'] == 'DistillationSARLabelDecode':
+ if config["PostProcess"]["name"] == "DistillationSARLabelDecode":
char_num = char_num - 2
- if config['PostProcess'][
- 'name'] == 'DistillationNRTRLabelDecode':
+ if config["PostProcess"]["name"] == "DistillationNRTRLabelDecode":
char_num = char_num - 3
- out_channels_list['CTCLabelDecode'] = char_num
- out_channels_list['SARLabelDecode'] = char_num + 2
- out_channels_list['NRTRLabelDecode'] = char_num + 3
- config['Architecture']['Models'][key]['Head'][
- 'out_channels_list'] = out_channels_list
+ out_channels_list["CTCLabelDecode"] = char_num
+ out_channels_list["SARLabelDecode"] = char_num + 2
+ out_channels_list["NRTRLabelDecode"] = char_num + 3
+ config["Architecture"]["Models"][key]["Head"][
+ "out_channels_list"
+ ] = out_channels_list
else:
config["Architecture"]["Models"][key]["Head"][
- "out_channels"] = char_num
- elif config['Architecture']['Head'][
- 'name'] == 'MultiHead': # multi head
+ "out_channels"
+ ] = char_num
+ elif config["Architecture"]["Head"]["name"] == "MultiHead": # multi head
out_channels_list = {}
- char_num = len(getattr(post_process_class, 'character'))
- if config['PostProcess']['name'] == 'SARLabelDecode':
+ char_num = len(getattr(post_process_class, "character"))
+ if config["PostProcess"]["name"] == "SARLabelDecode":
char_num = char_num - 2
- if config['PostProcess']['name'] == 'NRTRLabelDecode':
+ if config["PostProcess"]["name"] == "NRTRLabelDecode":
char_num = char_num - 3
- out_channels_list['CTCLabelDecode'] = char_num
- out_channels_list['SARLabelDecode'] = char_num + 2
- out_channels_list['NRTRLabelDecode'] = char_num + 3
- config['Architecture']['Head'][
- 'out_channels_list'] = out_channels_list
+ out_channels_list["CTCLabelDecode"] = char_num
+ out_channels_list["SARLabelDecode"] = char_num + 2
+ out_channels_list["NRTRLabelDecode"] = char_num + 3
+ config["Architecture"]["Head"]["out_channels_list"] = out_channels_list
else: # base rec model
config["Architecture"]["Head"]["out_channels"] = char_num
- model = build_model(config['Architecture'])
+ model = build_model(config["Architecture"])
load_model(config, model)
# create data ops
transforms = []
- for op in config['Eval']['dataset']['transforms']:
+ for op in config["Eval"]["dataset"]["transforms"]:
op_name = list(op)[0]
- if 'Label' in op_name:
+ if "Label" in op_name:
continue
- elif op_name in ['RecResizeImg']:
- op[op_name]['infer_mode'] = True
- elif op_name == 'KeepKeys':
- if config['Architecture']['algorithm'] == "SRN":
- op[op_name]['keep_keys'] = [
- 'image', 'encoder_word_pos', 'gsrm_word_pos',
- 'gsrm_slf_attn_bias1', 'gsrm_slf_attn_bias2'
+ elif op_name in ["RecResizeImg"]:
+ op[op_name]["infer_mode"] = True
+ elif op_name == "KeepKeys":
+ if config["Architecture"]["algorithm"] == "SRN":
+ op[op_name]["keep_keys"] = [
+ "image",
+ "encoder_word_pos",
+ "gsrm_word_pos",
+ "gsrm_slf_attn_bias1",
+ "gsrm_slf_attn_bias2",
]
- elif config['Architecture']['algorithm'] == "SAR":
- op[op_name]['keep_keys'] = ['image', 'valid_ratio']
- elif config['Architecture']['algorithm'] == "RobustScanner":
- op[op_name][
- 'keep_keys'] = ['image', 'valid_ratio', 'word_positons']
+ elif config["Architecture"]["algorithm"] == "SAR":
+ op[op_name]["keep_keys"] = ["image", "valid_ratio"]
+ elif config["Architecture"]["algorithm"] == "RobustScanner":
+ op[op_name]["keep_keys"] = ["image", "valid_ratio", "word_positons"]
else:
- op[op_name]['keep_keys'] = ['image']
+ op[op_name]["keep_keys"] = ["image"]
transforms.append(op)
- global_config['infer_mode'] = True
+ global_config["infer_mode"] = True
ops = create_operators(transforms, global_config)
- save_res_path = config['Global'].get('save_res_path',
- "./output/rec/predicts_rec.txt")
+ save_res_path = config["Global"].get(
+ "save_res_path", "./output/rec/predicts_rec.txt"
+ )
if not os.path.exists(os.path.dirname(save_res_path)):
os.makedirs(os.path.dirname(save_res_path))
model.eval()
-
- infer_imgs = config['Global']['infer_img']
- infer_list = config['Global'].get('infer_list', None)
+
+ infer_imgs = config["Global"]["infer_img"]
+ infer_list = config["Global"].get("infer_list", None)
with open(save_res_path, "w") as fout:
for file in get_image_file_list(infer_imgs, infer_list=infer_list):
logger.info("infer_img: {}".format(file))
- with open(file, 'rb') as f:
+ with open(file, "rb") as f:
img = f.read()
- data = {'image': img}
+ data = {"image": img}
batch = transform(data, ops)
- if config['Architecture']['algorithm'] == "SRN":
+ if config["Architecture"]["algorithm"] == "SRN":
encoder_word_pos_list = np.expand_dims(batch[1], axis=0)
gsrm_word_pos_list = np.expand_dims(batch[2], axis=0)
gsrm_slf_attn_bias1_list = np.expand_dims(batch[3], axis=0)
@@ -138,32 +140,32 @@ def main():
paddle.to_tensor(encoder_word_pos_list),
paddle.to_tensor(gsrm_word_pos_list),
paddle.to_tensor(gsrm_slf_attn_bias1_list),
- paddle.to_tensor(gsrm_slf_attn_bias2_list)
+ paddle.to_tensor(gsrm_slf_attn_bias2_list),
]
- if config['Architecture']['algorithm'] == "SAR":
+ if config["Architecture"]["algorithm"] == "SAR":
valid_ratio = np.expand_dims(batch[-1], axis=0)
img_metas = [paddle.to_tensor(valid_ratio)]
- if config['Architecture']['algorithm'] == "RobustScanner":
+ if config["Architecture"]["algorithm"] == "RobustScanner":
valid_ratio = np.expand_dims(batch[1], axis=0)
word_positons = np.expand_dims(batch[2], axis=0)
img_metas = [
paddle.to_tensor(valid_ratio),
paddle.to_tensor(word_positons),
]
- if config['Architecture']['algorithm'] == "CAN":
+ if config["Architecture"]["algorithm"] == "CAN":
image_mask = paddle.ones(
- (np.expand_dims(
- batch[0], axis=0).shape), dtype='float32')
- label = paddle.ones((1, 36), dtype='int64')
+ (np.expand_dims(batch[0], axis=0).shape), dtype="float32"
+ )
+ label = paddle.ones((1, 36), dtype="int64")
images = np.expand_dims(batch[0], axis=0)
images = paddle.to_tensor(images)
- if config['Architecture']['algorithm'] == "SRN":
+ if config["Architecture"]["algorithm"] == "SRN":
preds = model(images, others)
- elif config['Architecture']['algorithm'] == "SAR":
+ elif config["Architecture"]["algorithm"] == "SAR":
preds = model(images, img_metas)
- elif config['Architecture']['algorithm'] == "RobustScanner":
+ elif config["Architecture"]["algorithm"] == "RobustScanner":
preds = model(images, img_metas)
- elif config['Architecture']['algorithm'] == "CAN":
+ elif config["Architecture"]["algorithm"] == "CAN":
preds = model([images, image_mask, label])
else:
preds = model(images)
@@ -178,9 +180,8 @@ def main():
"score": float(post_result[key][0][1]),
}
info = json.dumps(rec_info, ensure_ascii=False)
- elif isinstance(post_result, list) and isinstance(post_result[0],
- int):
- # for RFLearning CNT branch
+ elif isinstance(post_result, list) and isinstance(post_result[0], int):
+ # for RFLearning CNT branch
info = str(post_result[0])
else:
if len(post_result[0]) >= 2:
@@ -192,6 +193,6 @@ def main():
logger.info("success!")
-if __name__ == '__main__':
+if __name__ == "__main__":
config, device, logger, vdl_writer = program.preprocess()
main()
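
Note on the MultiHead branches above: a single character set drives three decoder widths. Per the diff's arithmetic, the post-process name first shrinks `char_num` (SAR by 2, NRTR by 3), then each decode head reserves its own slots for special tokens. A sketch with an illustrative dictionary size:

```python
char_num = 6625  # illustrative dictionary size, not a repo constant
post_process_name = "SARLabelDecode"

if post_process_name == "SARLabelDecode":
    char_num -= 2      # SAR's decode dict already counts 2 special tokens
elif post_process_name == "NRTRLabelDecode":
    char_num -= 3      # NRTR's counts 3

out_channels_list = {
    "CTCLabelDecode": char_num,        # CTC head width
    "SARLabelDecode": char_num + 2,    # +2 slots for SAR's special tokens
    "NRTRLabelDecode": char_num + 3,   # +3 slots for NRTR's special tokens
}
print(out_channels_list)
```
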
diff --git a/tools/infer_sr.py b/tools/infer_sr.py
index df4334f342..42856256a4 100755
--- a/tools/infer_sr.py
+++ b/tools/infer_sr.py
@@ -26,9 +26,9 @@
__dir__ = os.path.dirname(os.path.abspath(__file__))
sys.path.insert(0, __dir__)
-sys.path.insert(0, os.path.abspath(os.path.join(__dir__, '..')))
+sys.path.insert(0, os.path.abspath(os.path.join(__dir__, "..")))
-os.environ["FLAGS_allocator_strategy"] = 'auto_growth'
+os.environ["FLAGS_allocator_strategy"] = "auto_growth"
import paddle
@@ -41,42 +41,41 @@
def main():
- global_config = config['Global']
+ global_config = config["Global"]
# build post process
- post_process_class = build_post_process(config['PostProcess'],
- global_config)
+ post_process_class = build_post_process(config["PostProcess"], global_config)
# sr transform
- config['Architecture']["Transform"]['infer_mode'] = True
+ config["Architecture"]["Transform"]["infer_mode"] = True
- model = build_model(config['Architecture'])
+ model = build_model(config["Architecture"])
load_model(config, model)
# create data ops
transforms = []
- for op in config['Eval']['dataset']['transforms']:
+ for op in config["Eval"]["dataset"]["transforms"]:
op_name = list(op)[0]
- if 'Label' in op_name:
+ if "Label" in op_name:
continue
- elif op_name in ['SRResize']:
- op[op_name]['infer_mode'] = True
- elif op_name == 'KeepKeys':
- op[op_name]['keep_keys'] = ['img_lr']
+ elif op_name in ["SRResize"]:
+ op[op_name]["infer_mode"] = True
+ elif op_name == "KeepKeys":
+ op[op_name]["keep_keys"] = ["img_lr"]
transforms.append(op)
- global_config['infer_mode'] = True
+ global_config["infer_mode"] = True
ops = create_operators(transforms, global_config)
- save_visual_path = config['Global'].get('save_visual', "infer_result/")
+ save_visual_path = config["Global"].get("save_visual", "infer_result/")
if not os.path.exists(os.path.dirname(save_visual_path)):
os.makedirs(os.path.dirname(save_visual_path))
model.eval()
- for file in get_image_file_list(config['Global']['infer_img']):
+ for file in get_image_file_list(config["Global"]["infer_img"]):
logger.info("infer_img: {}".format(file))
img = Image.open(file).convert("RGB")
- data = {'image_lr': img}
+ data = {"image_lr": img}
batch = transform(data, ops)
images = np.expand_dims(batch[0], axis=0)
images = paddle.to_tensor(images)
@@ -87,14 +86,16 @@ def main():
fm_sr = (sr_img.numpy() * 255).transpose(1, 2, 0).astype(np.uint8)
fm_lr = (lr_img.numpy() * 255).transpose(1, 2, 0).astype(np.uint8)
img_name_pure = os.path.split(file)[-1]
- cv2.imwrite("{}/sr_{}".format(save_visual_path, img_name_pure),
- fm_sr[:, :, ::-1])
- logger.info("The visualized image saved in infer_result/sr_{}".format(
- img_name_pure))
+ cv2.imwrite(
+ "{}/sr_{}".format(save_visual_path, img_name_pure), fm_sr[:, :, ::-1]
+ )
+        logger.info(
+            "The visualized image saved in {}/sr_{}".format(save_visual_path, img_name_pure)
+        )
logger.info("success!")
-if __name__ == '__main__':
+if __name__ == "__main__":
config, device, logger, vdl_writer = program.preprocess()
main()
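
Note on the SR visualization above: the model's CHW float tensor in [0, 1] becomes an HWC uint8 image, and the `[:, :, ::-1]` slice flips RGB to BGR for `cv2.imwrite`. A minimal numpy sketch with a fake output tensor:

```python
import numpy as np

sr_chw = np.random.rand(3, 32, 128).astype(np.float32)      # fake SR output, CHW in [0, 1]
fm_sr = (sr_chw * 255).transpose(1, 2, 0).astype(np.uint8)  # HWC, 0..255
bgr = fm_sr[:, :, ::-1]                                     # reverse channels: RGB -> BGR
print(bgr.shape, bgr.dtype)  # (32, 128, 3) uint8, ready for cv2.imwrite
```
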
diff --git a/tools/infer_table.py b/tools/infer_table.py
index 6dde5d67d0..c386cef0f4 100644
--- a/tools/infer_table.py
+++ b/tools/infer_table.py
@@ -24,9 +24,9 @@
__dir__ = os.path.dirname(os.path.abspath(__file__))
sys.path.append(__dir__)
-sys.path.insert(0, os.path.abspath(os.path.join(__dir__, '..')))
+sys.path.insert(0, os.path.abspath(os.path.join(__dir__, "..")))
-os.environ["FLAGS_allocator_strategy"] = 'auto_growth'
+os.environ["FLAGS_allocator_strategy"] = "auto_growth"
import paddle
from paddle.jit import to_static
@@ -44,47 +44,47 @@
@paddle.no_grad()
def main(config, device, logger, vdl_writer):
- global_config = config['Global']
+ global_config = config["Global"]
# build post process
- post_process_class = build_post_process(config['PostProcess'],
- global_config)
+ post_process_class = build_post_process(config["PostProcess"], global_config)
# build model
- if hasattr(post_process_class, 'character'):
- config['Architecture']["Head"]['out_channels'] = len(
- getattr(post_process_class, 'character'))
+ if hasattr(post_process_class, "character"):
+ config["Architecture"]["Head"]["out_channels"] = len(
+ getattr(post_process_class, "character")
+ )
- model = build_model(config['Architecture'])
- algorithm = config['Architecture']['algorithm']
+ model = build_model(config["Architecture"])
+ algorithm = config["Architecture"]["algorithm"]
load_model(config, model)
# create data ops
transforms = []
- for op in config['Eval']['dataset']['transforms']:
+ for op in config["Eval"]["dataset"]["transforms"]:
op_name = list(op)[0]
- if 'Encode' in op_name:
+ if "Encode" in op_name:
continue
- if op_name == 'KeepKeys':
- op[op_name]['keep_keys'] = ['image', 'shape']
+ if op_name == "KeepKeys":
+ op[op_name]["keep_keys"] = ["image", "shape"]
transforms.append(op)
- global_config['infer_mode'] = True
+ global_config["infer_mode"] = True
ops = create_operators(transforms, global_config)
- save_res_path = config['Global']['save_res_path']
+ save_res_path = config["Global"]["save_res_path"]
os.makedirs(save_res_path, exist_ok=True)
model.eval()
with open(
- os.path.join(save_res_path, 'infer.txt'), mode='w',
- encoding='utf-8') as f_w:
- for file in get_image_file_list(config['Global']['infer_img']):
+ os.path.join(save_res_path, "infer.txt"), mode="w", encoding="utf-8"
+ ) as f_w:
+ for file in get_image_file_list(config["Global"]["infer_img"]):
logger.info("infer_img: {}".format(file))
- with open(file, 'rb') as f:
+ with open(file, "rb") as f:
img = f.read()
- data = {'image': img}
+ data = {"image": img}
batch = transform(data, ops)
images = np.expand_dims(batch[0], axis=0)
shape_list = np.expand_dims(batch[1], axis=0)
@@ -93,29 +93,28 @@ def main(config, device, logger, vdl_writer):
preds = model(images)
post_result = post_process_class(preds, [shape_list])
- structure_str_list = post_result['structure_batch_list'][0]
- bbox_list = post_result['bbox_batch_list'][0]
+ structure_str_list = post_result["structure_batch_list"][0]
+ bbox_list = post_result["bbox_batch_list"][0]
structure_str_list = structure_str_list[0]
-            structure_str_list = [
-                '<html>', '<body>', '<table>'
-            ] + structure_str_list + ['</table>', '</body>', '</html>']
+            structure_str_list = (
+                ["<html>", "<body>", "<table>"]
+                + structure_str_list
+                + ["</table>", "</body>", "</html>"]
+            )
bbox_list_str = json.dumps(bbox_list.tolist())
- logger.info("result: {}, {}".format(structure_str_list,
- bbox_list_str))
- f_w.write("result: {}, {}\n".format(structure_str_list,
- bbox_list_str))
+ logger.info("result: {}, {}".format(structure_str_list, bbox_list_str))
+ f_w.write("result: {}, {}\n".format(structure_str_list, bbox_list_str))
if len(bbox_list) > 0 and len(bbox_list[0]) == 4:
img = draw_rectangle(file, bbox_list)
else:
img = draw_boxes(cv2.imread(file), bbox_list)
- cv2.imwrite(
- os.path.join(save_res_path, os.path.basename(file)), img)
- logger.info('save result to {}'.format(save_res_path))
+ cv2.imwrite(os.path.join(save_res_path, os.path.basename(file)), img)
+ logger.info("save result to {}".format(save_res_path))
logger.info("success!")
-if __name__ == '__main__':
+if __name__ == "__main__":
config, device, logger, vdl_writer = program.preprocess()
main(config, device, logger, vdl_writer)
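
Note on the hunk above: the table post-processor yields only the inner table tokens, so the script wraps them in the outer HTML skeleton (restored above; the tags had been stripped in transit) before logging. A sketch of the wrap with a toy token list:

```python
structure_str_list = ["<tr>", "<td>", "</td>", "</tr>"]  # toy inner tokens
structure_str_list = (
    ["<html>", "<body>", "<table>"]
    + structure_str_list
    + ["</table>", "</body>", "</html>"]
)
print("".join(structure_str_list))
# <html><body><table><tr><td></td></tr></table></body></html>
```
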
diff --git a/tools/program.py b/tools/program.py
index c0de3767dc..65e8c00cd7 100755
--- a/tools/program.py
+++ b/tools/program.py
@@ -40,24 +40,21 @@
class ArgsParser(ArgumentParser):
def __init__(self):
- super(ArgsParser, self).__init__(
- formatter_class=RawDescriptionHelpFormatter)
+ super(ArgsParser, self).__init__(formatter_class=RawDescriptionHelpFormatter)
self.add_argument("-c", "--config", help="configuration file to use")
+ self.add_argument("-o", "--opt", nargs="+", help="set configuration options")
self.add_argument(
- "-o", "--opt", nargs='+', help="set configuration options")
- self.add_argument(
- '-p',
- '--profiler_options',
+ "-p",
+ "--profiler_options",
type=str,
default=None,
- help='The option of profiler, which should be in format ' \
- '\"key1=value1;key2=value2;key3=value3\".'
+ help="The option of profiler, which should be in format "
+ '"key1=value1;key2=value2;key3=value3".',
)
def parse_args(self, argv=None):
args = super(ArgsParser, self).parse_args(argv)
- assert args.config is not None, \
- "Please specify --config=configure_file_path."
+ assert args.config is not None, "Please specify --config=configure_file_path."
args.opt = self._parse_opt(args.opt)
return args
@@ -67,7 +64,7 @@ def _parse_opt(self, opts):
return config
for s in opts:
s = s.strip()
- k, v = s.split('=')
+ k, v = s.split("=")
config[k] = yaml.load(v, Loader=yaml.Loader)
return config
@@ -80,8 +77,8 @@ def load_config(file_path):
Returns: global config
"""
_, ext = os.path.splitext(file_path)
- assert ext in ['.yml', '.yaml'], "only support yaml files for now"
- config = yaml.load(open(file_path, 'rb'), Loader=yaml.Loader)
+ assert ext in [".yml", ".yaml"], "only support yaml files for now"
+ config = yaml.load(open(file_path, "rb"), Loader=yaml.Loader)
return config
@@ -99,12 +96,13 @@ def merge_config(config, opts):
else:
config[key] = value
else:
- sub_keys = key.split('.')
- assert (
- sub_keys[0] in config
- ), "the sub_keys can only be one of global_config: {}, but get: " \
- "{}, please check your running command".format(
- config.keys(), sub_keys[0])
+ sub_keys = key.split(".")
+            assert sub_keys[0] in config, (
+                "the sub_keys can only be one of global_config: {}, but got: "
+                "{}, please check your running command".format(
+ config.keys(), sub_keys[0]
+ )
+ )
cur = config[sub_keys[0]]
for idx, sub_key in enumerate(sub_keys[1:]):
if idx == len(sub_keys) - 2:
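
Note on `merge_config` above: `-o` overrides use dotted keys, `_parse_opt` yaml-loads each value so it arrives typed, and the merge walks the nested dict after asserting the top-level key exists. A simplified sketch of that walk (the real code also handles non-dotted keys and deeper hierarchy creation):

```python
import yaml

config = {"Global": {"use_gpu": True, "epoch_num": 500}}
opt = "Global.use_gpu=false"  # e.g. from "-o Global.use_gpu=false"

k, v = opt.strip().split("=")
value = yaml.load(v, Loader=yaml.Loader)  # -> False (bool), not the string "false"
sub_keys = k.split(".")
assert sub_keys[0] in config, "unknown top-level key: {}".format(sub_keys[0])

cur = config
for sub_key in sub_keys[:-1]:
    cur = cur[sub_key]        # walk down to the parent dict
cur[sub_keys[-1]] = value     # assign the typed value at the leaf
print(config)  # {'Global': {'use_gpu': False, 'epoch_num': 500}}
```
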
@@ -119,11 +117,13 @@ def check_device(use_gpu, use_xpu=False, use_npu=False, use_mlu=False):
Log error and exit when set use_gpu=true in paddlepaddle
cpu version.
"""
- err = "Config {} cannot be set as true while your paddle " \
- "is not compiled with {} ! \nPlease try: \n" \
- "\t1. Install paddlepaddle to run model on {} \n" \
- "\t2. Set {} as false in config file to run " \
- "model on CPU"
+ err = (
+ "Config {} cannot be set as true while your paddle "
+ "is not compiled with {} ! \nPlease try: \n"
+ "\t1. Install paddlepaddle to run model on {} \n"
+ "\t2. Set {} as false in config file to run "
+ "model on CPU"
+ )
try:
if use_gpu and use_xpu:
@@ -135,9 +135,11 @@ def check_device(use_gpu, use_xpu=False, use_npu=False, use_mlu=False):
print(err.format("use_xpu", "xpu", "xpu", "use_xpu"))
sys.exit(1)
if use_npu:
- if int(paddle.version.major) != 0 and int(
- paddle.version.major) <= 2 and int(
- paddle.version.minor) <= 4:
+ if (
+ int(paddle.version.major) != 0
+ and int(paddle.version.major) <= 2
+ and int(paddle.version.minor) <= 4
+ ):
if not paddle.device.is_compiled_with_npu():
print(err.format("use_npu", "npu", "npu", "use_npu"))
sys.exit(1)
@@ -173,85 +175,108 @@ def to_float32(preds):
return preds
-def train(config,
- train_dataloader,
- valid_dataloader,
- device,
- model,
- loss_class,
- optimizer,
- lr_scheduler,
- post_process_class,
- eval_class,
- pre_best_model_dict,
- logger,
- step_pre_epoch,
- log_writer=None,
- scaler=None,
- amp_level='O2',
- amp_custom_black_list=[],
- amp_custom_white_list=[],
- amp_dtype='float16'):
- cal_metric_during_train = config['Global'].get('cal_metric_during_train',
- False)
- calc_epoch_interval = config['Global'].get('calc_epoch_interval', 1)
- log_smooth_window = config['Global']['log_smooth_window']
- epoch_num = config['Global']['epoch_num']
- print_batch_step = config['Global']['print_batch_step']
- eval_batch_step = config['Global']['eval_batch_step']
- eval_batch_epoch = config['Global'].get('eval_batch_epoch', None)
- profiler_options = config['profiler_options']
+def train(
+ config,
+ train_dataloader,
+ valid_dataloader,
+ device,
+ model,
+ loss_class,
+ optimizer,
+ lr_scheduler,
+ post_process_class,
+ eval_class,
+ pre_best_model_dict,
+ logger,
+ step_pre_epoch,
+ log_writer=None,
+ scaler=None,
+ amp_level="O2",
+ amp_custom_black_list=[],
+ amp_custom_white_list=[],
+ amp_dtype="float16",
+):
+ cal_metric_during_train = config["Global"].get("cal_metric_during_train", False)
+ calc_epoch_interval = config["Global"].get("calc_epoch_interval", 1)
+ log_smooth_window = config["Global"]["log_smooth_window"]
+ epoch_num = config["Global"]["epoch_num"]
+ print_batch_step = config["Global"]["print_batch_step"]
+ eval_batch_step = config["Global"]["eval_batch_step"]
+ eval_batch_epoch = config["Global"].get("eval_batch_epoch", None)
+ profiler_options = config["profiler_options"]
global_step = 0
- if 'global_step' in pre_best_model_dict:
- global_step = pre_best_model_dict['global_step']
+ if "global_step" in pre_best_model_dict:
+ global_step = pre_best_model_dict["global_step"]
start_eval_step = 0
if type(eval_batch_step) == list and len(eval_batch_step) >= 2:
start_eval_step = eval_batch_step[0] if not eval_batch_epoch else 0
- eval_batch_step = eval_batch_step[
- 1] if not eval_batch_epoch else step_pre_epoch * eval_batch_epoch
+ eval_batch_step = (
+ eval_batch_step[1]
+ if not eval_batch_epoch
+ else step_pre_epoch * eval_batch_epoch
+ )
if len(valid_dataloader) == 0:
logger.info(
- 'No Images in eval dataset, evaluation during training ' \
- 'will be disabled'
+ "No Images in eval dataset, evaluation during training "
+ "will be disabled"
)
start_eval_step = 1e111
logger.info(
- "During the training process, after the {}th iteration, " \
- "an evaluation is run every {} iterations".
- format(start_eval_step, eval_batch_step))
- save_epoch_step = config['Global']['save_epoch_step']
- save_model_dir = config['Global']['save_model_dir']
+ "During the training process, after the {}th iteration, "
+ "an evaluation is run every {} iterations".format(
+ start_eval_step, eval_batch_step
+ )
+ )
+ save_epoch_step = config["Global"]["save_epoch_step"]
+ save_model_dir = config["Global"]["save_model_dir"]
if not os.path.exists(save_model_dir):
os.makedirs(save_model_dir)
main_indicator = eval_class.main_indicator
best_model_dict = {main_indicator: 0}
best_model_dict.update(pre_best_model_dict)
- train_stats = TrainingStats(log_smooth_window, ['lr'])
+ train_stats = TrainingStats(log_smooth_window, ["lr"])
model_average = False
model.train()
- use_srn = config['Architecture']['algorithm'] == "SRN"
+ use_srn = config["Architecture"]["algorithm"] == "SRN"
extra_input_models = [
- "SRN", "NRTR", "SAR", "SEED", "SVTR", "SVTR_LCNet", "SPIN", "VisionLAN",
- "RobustScanner", "RFL", 'DRRG', 'SATRN', 'SVTR_HGNet', "ParseQ", "CPPD"
+ "SRN",
+ "NRTR",
+ "SAR",
+ "SEED",
+ "SVTR",
+ "SVTR_LCNet",
+ "SPIN",
+ "VisionLAN",
+ "RobustScanner",
+ "RFL",
+ "DRRG",
+ "SATRN",
+ "SVTR_HGNet",
+ "ParseQ",
+ "CPPD",
]
extra_input = False
- if config['Architecture']['algorithm'] == 'Distillation':
- for key in config['Architecture']["Models"]:
- extra_input = extra_input or config['Architecture']['Models'][key][
- 'algorithm'] in extra_input_models
+ if config["Architecture"]["algorithm"] == "Distillation":
+ for key in config["Architecture"]["Models"]:
+ extra_input = (
+ extra_input
+ or config["Architecture"]["Models"][key]["algorithm"]
+ in extra_input_models
+ )
else:
- extra_input = config['Architecture']['algorithm'] in extra_input_models
+ extra_input = config["Architecture"]["algorithm"] in extra_input_models
try:
- model_type = config['Architecture']['model_type']
+ model_type = config["Architecture"]["model_type"]
except:
model_type = None
- algorithm = config['Architecture']['algorithm']
+ algorithm = config["Architecture"]["algorithm"]
- start_epoch = best_model_dict[
- 'start_epoch'] if 'start_epoch' in best_model_dict else 1
+ start_epoch = (
+ best_model_dict["start_epoch"] if "start_epoch" in best_model_dict else 1
+ )
total_samples = 0
train_reader_cost = 0.0
@@ -259,15 +284,22 @@ def train(config,
reader_start = time.time()
eta_meter = AverageMeter()
- max_iter = len(train_dataloader) - 1 if platform.system(
- ) == "Windows" else len(train_dataloader)
+ max_iter = (
+ len(train_dataloader) - 1
+ if platform.system() == "Windows"
+ else len(train_dataloader)
+ )
for epoch in range(start_epoch, epoch_num + 1):
if train_dataloader.dataset.need_reset:
train_dataloader = build_dataloader(
- config, 'Train', device, logger, seed=epoch)
- max_iter = len(train_dataloader) - 1 if platform.system(
- ) == "Windows" else len(train_dataloader)
+ config, "Train", device, logger, seed=epoch
+ )
+ max_iter = (
+ len(train_dataloader) - 1
+ if platform.system() == "Windows"
+ else len(train_dataloader)
+ )
for idx, batch in enumerate(train_dataloader):
profiler.add_profiler_step(profiler_options)
@@ -281,58 +313,63 @@ def train(config,
# use amp
if scaler:
with paddle.amp.auto_cast(
- level=amp_level,
- custom_black_list=amp_custom_black_list,
- custom_white_list=amp_custom_white_list,
- dtype=amp_dtype):
- if model_type == 'table' or extra_input:
+ level=amp_level,
+ custom_black_list=amp_custom_black_list,
+ custom_white_list=amp_custom_white_list,
+ dtype=amp_dtype,
+ ):
+ if model_type == "table" or extra_input:
preds = model(images, data=batch[1:])
elif model_type in ["kie"]:
preds = model(batch)
- elif algorithm in ['CAN']:
+ elif algorithm in ["CAN"]:
preds = model(batch[:3])
else:
preds = model(images)
preds = to_float32(preds)
loss = loss_class(preds, batch)
- avg_loss = loss['loss']
+ avg_loss = loss["loss"]
scaled_avg_loss = scaler.scale(avg_loss)
scaled_avg_loss.backward()
scaler.minimize(optimizer, scaled_avg_loss)
else:
- if model_type == 'table' or extra_input:
+ if model_type == "table" or extra_input:
preds = model(images, data=batch[1:])
- elif model_type in ["kie", 'sr']:
+ elif model_type in ["kie", "sr"]:
preds = model(batch)
- elif algorithm in ['CAN']:
+ elif algorithm in ["CAN"]:
preds = model(batch[:3])
else:
preds = model(images)
loss = loss_class(preds, batch)
- avg_loss = loss['loss']
+ avg_loss = loss["loss"]
avg_loss.backward()
optimizer.step()
optimizer.clear_grad()
- if cal_metric_during_train and epoch % calc_epoch_interval == 0: # only rec and cls need
+ if (
+ cal_metric_during_train and epoch % calc_epoch_interval == 0
+ ): # only rec and cls need
batch = [item.numpy() for item in batch]
- if model_type in ['kie', 'sr']:
+ if model_type in ["kie", "sr"]:
eval_class(preds, batch)
- elif model_type in ['table']:
+ elif model_type in ["table"]:
post_result = post_process_class(preds, batch)
eval_class(post_result, batch)
- elif algorithm in ['CAN']:
- model_type = 'can'
+ elif algorithm in ["CAN"]:
+ model_type = "can"
eval_class(preds[0], batch[2:], epoch_reset=(idx == 0))
else:
- if config['Loss']['name'] in ['MultiLoss', 'MultiLoss_v2'
- ]: # for multi head loss
+ if config["Loss"]["name"] in [
+ "MultiLoss",
+ "MultiLoss_v2",
+ ]: # for multi head loss
post_result = post_process_class(
- preds['ctc'], batch[1]) # for CTC head out
- elif config['Loss']['name'] in ['VLLoss']:
- post_result = post_process_class(preds, batch[1],
- batch[-1])
+ preds["ctc"], batch[1]
+ ) # for CTC head out
+ elif config["Loss"]["name"] in ["VLLoss"]:
+ post_result = post_process_class(preds, batch[1], batch[-1])
else:
post_result = post_process_class(preds, batch[1])
eval_class(post_result, batch)
@@ -353,49 +390,64 @@ def train(config,
k: float(v) if v.shape == [] else v.numpy().mean()
for k, v in loss.items()
}
- stats['lr'] = lr
+ stats["lr"] = lr
train_stats.update(stats)
if log_writer is not None and dist.get_rank() == 0:
log_writer.log_metrics(
- metrics=train_stats.get(), prefix="TRAIN", step=global_step)
+ metrics=train_stats.get(), prefix="TRAIN", step=global_step
+ )
if dist.get_rank() == 0 and (
- (global_step > 0 and global_step % print_batch_step == 0) or
- (idx >= len(train_dataloader) - 1)):
+ (global_step > 0 and global_step % print_batch_step == 0)
+ or (idx >= len(train_dataloader) - 1)
+ ):
logs = train_stats.log()
- eta_sec = ((epoch_num + 1 - epoch) * \
- len(train_dataloader) - idx - 1) * eta_meter.avg
+ eta_sec = (
+ (epoch_num + 1 - epoch) * len(train_dataloader) - idx - 1
+ ) * eta_meter.avg
eta_sec_format = str(datetime.timedelta(seconds=int(eta_sec)))
max_mem_reserved_str = ""
max_mem_allocated_str = ""
if paddle.device.is_compiled_with_cuda():
max_mem_reserved_str = f"max_mem_reserved: {paddle.device.cuda.max_memory_reserved() // (1024 ** 2)} MB,"
max_mem_allocated_str = f"max_mem_allocated: {paddle.device.cuda.max_memory_allocated() // (1024 ** 2)} MB"
- strs = 'epoch: [{}/{}], global_step: {}, {}, avg_reader_cost: ' \
- '{:.5f} s, avg_batch_cost: {:.5f} s, avg_samples: {}, ' \
- 'ips: {:.5f} samples/s, eta: {}, {} {}'.format(
- epoch, epoch_num, global_step, logs,
- train_reader_cost / print_batch_step,
- train_batch_cost / print_batch_step,
- total_samples / print_batch_step,
- total_samples / train_batch_cost, eta_sec_format, max_mem_reserved_str, max_mem_allocated_str)
+ strs = (
+ "epoch: [{}/{}], global_step: {}, {}, avg_reader_cost: "
+ "{:.5f} s, avg_batch_cost: {:.5f} s, avg_samples: {}, "
+ "ips: {:.5f} samples/s, eta: {}, {} {}".format(
+ epoch,
+ epoch_num,
+ global_step,
+ logs,
+ train_reader_cost / print_batch_step,
+ train_batch_cost / print_batch_step,
+ total_samples / print_batch_step,
+ total_samples / train_batch_cost,
+ eta_sec_format,
+ max_mem_reserved_str,
+ max_mem_allocated_str,
+ )
+ )
logger.info(strs)
-
+
total_samples = 0
train_reader_cost = 0.0
train_batch_cost = 0.0
# eval
- if global_step > start_eval_step and \
- (global_step - start_eval_step) % eval_batch_step == 0 \
- and dist.get_rank() == 0:
+ if (
+ global_step > start_eval_step
+ and (global_step - start_eval_step) % eval_batch_step == 0
+ and dist.get_rank() == 0
+ ):
if model_average:
Model_Average = paddle.incubate.optimizer.ModelAverage(
0.15,
parameters=model.parameters(),
min_average_window=10000,
- max_average_window=15625)
+ max_average_window=15625,
+ )
Model_Average.apply()
cur_metric = eval(
model,
@@ -408,20 +460,22 @@ def train(config,
amp_level=amp_level,
amp_custom_black_list=amp_custom_black_list,
amp_custom_white_list=amp_custom_white_list,
- amp_dtype=amp_dtype)
- cur_metric_str = 'cur metric, {}'.format(', '.join(
- ['{}: {}'.format(k, v) for k, v in cur_metric.items()]))
+ amp_dtype=amp_dtype,
+ )
+ cur_metric_str = "cur metric, {}".format(
+ ", ".join(["{}: {}".format(k, v) for k, v in cur_metric.items()])
+ )
logger.info(cur_metric_str)
# logger metric
if log_writer is not None:
log_writer.log_metrics(
- metrics=cur_metric, prefix="EVAL", step=global_step)
+ metrics=cur_metric, prefix="EVAL", step=global_step
+ )
- if cur_metric[main_indicator] >= best_model_dict[
- main_indicator]:
+ if cur_metric[main_indicator] >= best_model_dict[main_indicator]:
best_model_dict.update(cur_metric)
- best_model_dict['best_epoch'] = epoch
+ best_model_dict["best_epoch"] = epoch
save_model(
model,
optimizer,
@@ -429,28 +483,32 @@ def train(config,
logger,
config,
is_best=True,
- prefix='best_accuracy',
+ prefix="best_accuracy",
best_model_dict=best_model_dict,
epoch=epoch,
- global_step=global_step)
- best_str = 'best metric, {}'.format(', '.join([
- '{}: {}'.format(k, v) for k, v in best_model_dict.items()
- ]))
+ global_step=global_step,
+ )
+ best_str = "best metric, {}".format(
+ ", ".join(
+ ["{}: {}".format(k, v) for k, v in best_model_dict.items()]
+ )
+ )
logger.info(best_str)
# logger best metric
if log_writer is not None:
log_writer.log_metrics(
metrics={
- "best_{}".format(main_indicator):
- best_model_dict[main_indicator]
+ "best_{}".format(main_indicator): best_model_dict[
+ main_indicator
+ ]
},
prefix="EVAL",
- step=global_step)
+ step=global_step,
+ )
log_writer.log_model(
- is_best=True,
- prefix="best_accuracy",
- metadata=best_model_dict)
+ is_best=True, prefix="best_accuracy", metadata=best_model_dict
+ )
reader_start = time.time()
if dist.get_rank() == 0:
@@ -461,10 +519,11 @@ def train(config,
logger,
config,
is_best=False,
- prefix='latest',
+ prefix="latest",
best_model_dict=best_model_dict,
epoch=epoch,
- global_step=global_step)
+ global_step=global_step,
+ )
if log_writer is not None:
log_writer.log_model(is_best=False, prefix="latest")
@@ -477,44 +536,50 @@ def train(config,
logger,
config,
is_best=False,
- prefix='iter_epoch_{}'.format(epoch),
+ prefix="iter_epoch_{}".format(epoch),
best_model_dict=best_model_dict,
epoch=epoch,
- global_step=global_step)
+ global_step=global_step,
+ )
if log_writer is not None:
log_writer.log_model(
- is_best=False, prefix='iter_epoch_{}'.format(epoch))
+ is_best=False, prefix="iter_epoch_{}".format(epoch)
+ )
- best_str = 'best metric, {}'.format(', '.join(
- ['{}: {}'.format(k, v) for k, v in best_model_dict.items()]))
+ best_str = "best metric, {}".format(
+ ", ".join(["{}: {}".format(k, v) for k, v in best_model_dict.items()])
+ )
logger.info(best_str)
if dist.get_rank() == 0 and log_writer is not None:
log_writer.close()
return
-def eval(model,
- valid_dataloader,
- post_process_class,
- eval_class,
- model_type=None,
- extra_input=False,
- scaler=None,
- amp_level='O2',
- amp_custom_black_list=[],
- amp_custom_white_list=[],
- amp_dtype='float16'):
+def eval(
+ model,
+ valid_dataloader,
+ post_process_class,
+ eval_class,
+ model_type=None,
+ extra_input=False,
+ scaler=None,
+ amp_level="O2",
+ amp_custom_black_list=[],
+ amp_custom_white_list=[],
+ amp_dtype="float16",
+):
model.eval()
with paddle.no_grad():
total_frame = 0.0
total_time = 0.0
pbar = tqdm(
- total=len(valid_dataloader),
- desc='eval model:',
- position=0,
- leave=True)
- max_iter = len(valid_dataloader) - 1 if platform.system(
- ) == "Windows" else len(valid_dataloader)
+ total=len(valid_dataloader), desc="eval model:", position=0, leave=True
+ )
+ max_iter = (
+ len(valid_dataloader) - 1
+ if platform.system() == "Windows"
+ else len(valid_dataloader)
+ )
sum_images = 0
for idx, batch in enumerate(valid_dataloader):
if idx >= max_iter:
@@ -525,16 +590,17 @@ def eval(model,
# use amp
if scaler:
with paddle.amp.auto_cast(
- level=amp_level,
- custom_black_list=amp_custom_black_list,
- dtype=amp_dtype):
- if model_type == 'table' or extra_input:
+ level=amp_level,
+ custom_black_list=amp_custom_black_list,
+ dtype=amp_dtype,
+ ):
+ if model_type == "table" or extra_input:
preds = model(images, data=batch[1:])
elif model_type in ["kie"]:
preds = model(batch)
- elif model_type in ['can']:
+ elif model_type in ["can"]:
preds = model(batch[:3])
- elif model_type in ['sr']:
+ elif model_type in ["sr"]:
preds = model(batch)
sr_img = preds["sr_img"]
lr_img = preds["lr_img"]
@@ -542,13 +608,13 @@ def eval(model,
preds = model(images)
preds = to_float32(preds)
else:
- if model_type == 'table' or extra_input:
+ if model_type == "table" or extra_input:
preds = model(images, data=batch[1:])
elif model_type in ["kie"]:
preds = model(batch)
- elif model_type in ['can']:
+ elif model_type in ["can"]:
preds = model(batch[:3])
- elif model_type in ['sr']:
+ elif model_type in ["sr"]:
preds = model(batch)
sr_img = preds["sr_img"]
lr_img = preds["lr_img"]
@@ -564,15 +630,15 @@ def eval(model,
# Obtain usable results from post-processing methods
total_time += time.time() - start
# Evaluate the results of the current batch
- if model_type in ['table', 'kie']:
+ if model_type in ["table", "kie"]:
if post_process_class is None:
eval_class(preds, batch_numpy)
else:
post_result = post_process_class(preds, batch_numpy)
eval_class(post_result, batch_numpy)
- elif model_type in ['sr']:
+ elif model_type in ["sr"]:
eval_class(preds, batch_numpy)
- elif model_type in ['can']:
+ elif model_type in ["can"]:
eval_class(preds[0], batch_numpy[2:], epoch_reset=(idx == 0))
else:
post_result = post_process_class(preds, batch_numpy[1])
@@ -586,7 +652,7 @@ def eval(model,
pbar.close()
model.train()
- metric['fps'] = total_frame / total_time
+ metric["fps"] = total_frame / total_time
return metric
@@ -605,8 +671,8 @@ def update_center(char_center, post_result, preds):
index = logit[idx_time]
if index in char_center.keys():
char_center[index][0] = (
- char_center[index][0] * char_center[index][1] +
- feat[idx_time]) / (char_center[index][1] + 1)
+ char_center[index][0] * char_center[index][1] + feat[idx_time]
+ ) / (char_center[index][1] + 1)
char_center[index][1] += 1
else:
char_center[index] = [feat[idx_time], 1]
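
Note on `update_center` above: the reformatted expression is the standard incremental mean. For each character index, `char_center` stores `[running_mean, count]`, and folding in a new feature vector x updates mean <- (mean * n + x) / (n + 1). A runnable sketch:

```python
import numpy as np

char_center = {}
for x in [np.array([1.0, 3.0]), np.array([3.0, 5.0]), np.array([5.0, 7.0])]:
    index = 7  # pretend every feature decodes to character index 7
    if index in char_center:
        mean, n = char_center[index]
        char_center[index] = [(mean * n + x) / (n + 1), n + 1]  # incremental mean
    else:
        char_center[index] = [x, 1]
print(char_center[7])  # [array([3., 5.]), 3] -- the mean of the three vectors
```
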
@@ -614,9 +680,12 @@ def update_center(char_center, post_result, preds):
def get_center(model, eval_dataloader, post_process_class):
- pbar = tqdm(total=len(eval_dataloader), desc='get center:')
- max_iter = len(eval_dataloader) - 1 if platform.system(
- ) == "Windows" else len(eval_dataloader)
+ pbar = tqdm(total=len(eval_dataloader), desc="get center:")
+ max_iter = (
+ len(eval_dataloader) - 1
+ if platform.system() == "Windows"
+ else len(eval_dataloader)
+ )
char_center = dict()
for idx, batch in enumerate(eval_dataloader):
if idx >= max_iter:
@@ -629,7 +698,7 @@ def get_center(model, eval_dataloader, post_process_class):
# Obtain usable results from post-processing methods
post_result = post_process_class(preds, batch[1])
- #update char_center
+ # update char_center
char_center = update_center(char_center, post_result, preds)
pbar.update(1)
@@ -649,63 +718,98 @@ def preprocess(is_train=False):
if is_train:
# save_config
- save_model_dir = config['Global']['save_model_dir']
+ save_model_dir = config["Global"]["save_model_dir"]
os.makedirs(save_model_dir, exist_ok=True)
- with open(os.path.join(save_model_dir, 'config.yml'), 'w') as f:
- yaml.dump(
- dict(config), f, default_flow_style=False, sort_keys=False)
- log_file = '{}/train.log'.format(save_model_dir)
+ with open(os.path.join(save_model_dir, "config.yml"), "w") as f:
+ yaml.dump(dict(config), f, default_flow_style=False, sort_keys=False)
+ log_file = "{}/train.log".format(save_model_dir)
else:
log_file = None
logger = get_logger(log_file=log_file)
# check if set use_gpu=True in paddlepaddle cpu version
- use_gpu = config['Global'].get('use_gpu', False)
- use_xpu = config['Global'].get('use_xpu', False)
- use_npu = config['Global'].get('use_npu', False)
- use_mlu = config['Global'].get('use_mlu', False)
+ use_gpu = config["Global"].get("use_gpu", False)
+ use_xpu = config["Global"].get("use_xpu", False)
+ use_npu = config["Global"].get("use_npu", False)
+ use_mlu = config["Global"].get("use_mlu", False)
- alg = config['Architecture']['algorithm']
+ alg = config["Architecture"]["algorithm"]
assert alg in [
- 'EAST', 'DB', 'SAST', 'Rosetta', 'CRNN', 'STARNet', 'RARE', 'SRN',
- 'CLS', 'PGNet', 'Distillation', 'NRTR', 'TableAttn', 'SAR', 'PSE',
- 'SEED', 'SDMGR', 'LayoutXLM', 'LayoutLM', 'LayoutLMv2', 'PREN', 'FCE',
- 'SVTR', 'SVTR_LCNet', 'ViTSTR', 'ABINet', 'DB++', 'TableMaster', 'SPIN',
- 'VisionLAN', 'Gestalt', 'SLANet', 'RobustScanner', 'CT', 'RFL', 'DRRG',
- 'CAN', 'Telescope', 'SATRN', 'SVTR_HGNet', 'ParseQ', 'CPPD'
+ "EAST",
+ "DB",
+ "SAST",
+ "Rosetta",
+ "CRNN",
+ "STARNet",
+ "RARE",
+ "SRN",
+ "CLS",
+ "PGNet",
+ "Distillation",
+ "NRTR",
+ "TableAttn",
+ "SAR",
+ "PSE",
+ "SEED",
+ "SDMGR",
+ "LayoutXLM",
+ "LayoutLM",
+ "LayoutLMv2",
+ "PREN",
+ "FCE",
+ "SVTR",
+ "SVTR_LCNet",
+ "ViTSTR",
+ "ABINet",
+ "DB++",
+ "TableMaster",
+ "SPIN",
+ "VisionLAN",
+ "Gestalt",
+ "SLANet",
+ "RobustScanner",
+ "CT",
+ "RFL",
+ "DRRG",
+ "CAN",
+ "Telescope",
+ "SATRN",
+ "SVTR_HGNet",
+ "ParseQ",
+ "CPPD",
]
if use_xpu:
- device = 'xpu:{0}'.format(os.getenv('FLAGS_selected_xpus', 0))
+ device = "xpu:{0}".format(os.getenv("FLAGS_selected_xpus", 0))
elif use_npu:
- device = 'npu:{0}'.format(os.getenv('FLAGS_selected_npus', 0))
+ device = "npu:{0}".format(os.getenv("FLAGS_selected_npus", 0))
elif use_mlu:
- device = 'mlu:{0}'.format(os.getenv('FLAGS_selected_mlus', 0))
+ device = "mlu:{0}".format(os.getenv("FLAGS_selected_mlus", 0))
else:
- device = 'gpu:{}'.format(dist.ParallelEnv()
- .dev_id) if use_gpu else 'cpu'
+ device = "gpu:{}".format(dist.ParallelEnv().dev_id) if use_gpu else "cpu"
check_device(use_gpu, use_xpu, use_npu, use_mlu)
device = paddle.set_device(device)
- config['Global']['distributed'] = dist.get_world_size() != 1
+ config["Global"]["distributed"] = dist.get_world_size() != 1
loggers = []
- if 'use_visualdl' in config['Global'] and config['Global']['use_visualdl']:
- save_model_dir = config['Global']['save_model_dir']
+ if "use_visualdl" in config["Global"] and config["Global"]["use_visualdl"]:
+ save_model_dir = config["Global"]["save_model_dir"]
vdl_writer_path = save_model_dir
log_writer = VDLLogger(vdl_writer_path)
loggers.append(log_writer)
- if ('use_wandb' in config['Global'] and
- config['Global']['use_wandb']) or 'wandb' in config:
- save_dir = config['Global']['save_model_dir']
+ if (
+ "use_wandb" in config["Global"] and config["Global"]["use_wandb"]
+ ) or "wandb" in config:
+ save_dir = config["Global"]["save_model_dir"]
wandb_writer_path = "{}/wandb".format(save_dir)
if "wandb" in config:
- wandb_params = config['wandb']
+ wandb_params = config["wandb"]
else:
wandb_params = dict()
- wandb_params.update({'save_dir': save_dir})
+ wandb_params.update({"save_dir": save_dir})
log_writer = WandbLogger(**wandb_params, config=config)
loggers.append(log_writer)
else:
@@ -717,6 +821,5 @@ def preprocess(is_train=False):
else:
log_writer = None
- logger.info('train with paddle {} and device {}'.format(paddle.__version__,
- device))
+ logger.info("train with paddle {} and device {}".format(paddle.__version__, device))
return config, device, logger, log_writer
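
Note on the AMP branch of `train()` above: it follows paddle's auto_cast / scale / minimize sequence. A minimal standalone sketch of that step (toy model and data; requires paddlepaddle; the diff's default level is "O2"):

```python
import paddle

model = paddle.nn.Linear(4, 2)
optimizer = paddle.optimizer.SGD(learning_rate=0.1, parameters=model.parameters())
scaler = paddle.amp.GradScaler(init_loss_scaling=1024)
images = paddle.randn([8, 4])

with paddle.amp.auto_cast(level="O1"):
    preds = model(images)
    avg_loss = preds.mean()
scaled = scaler.scale(avg_loss)    # scale the loss to keep fp16 grads finite
scaled.backward()
scaler.minimize(optimizer, scaled) # unscales grads, then optimizer.step()
optimizer.clear_grad()
```
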
diff --git a/tools/test_hubserving.py b/tools/test_hubserving.py
index ec17a9413e..5480107228 100755
--- a/tools/test_hubserving.py
+++ b/tools/test_hubserving.py
@@ -13,11 +13,13 @@
# limitations under the License.
import os
import sys
+
__dir__ = os.path.dirname(os.path.abspath(__file__))
sys.path.append(__dir__)
-sys.path.append(os.path.abspath(os.path.join(__dir__, '..')))
+sys.path.append(os.path.abspath(os.path.join(__dir__, "..")))
from ppocr.utils.logging import get_logger
+
logger = get_logger()
import cv2
@@ -35,7 +37,7 @@
def cv2_to_base64(image):
- return base64.b64encode(image).decode('utf8')
+ return base64.b64encode(image).decode("utf8")
def draw_server_result(image_file, res):
@@ -44,14 +46,14 @@ def draw_server_result(image_file, res):
if len(res) == 0:
return np.array(image)
keys = res[0].keys()
- if 'text_region' not in keys: # for ocr_rec, draw function is invalid
+ if "text_region" not in keys: # for ocr_rec, draw function is invalid
logger.info("draw function is invalid for ocr_rec!")
return None
- elif 'text' not in keys: # for ocr_det
+ elif "text" not in keys: # for ocr_det
logger.info("draw text boxes only!")
boxes = []
for dno in range(len(res)):
- boxes.append(res[dno]['text_region'])
+ boxes.append(res[dno]["text_region"])
boxes = np.array(boxes)
draw_img = draw_boxes(image, boxes)
return draw_img
@@ -61,13 +63,12 @@ def draw_server_result(image_file, res):
texts = []
scores = []
for dno in range(len(res)):
- boxes.append(res[dno]['text_region'])
- texts.append(res[dno]['text'])
- scores.append(res[dno]['confidence'])
+ boxes.append(res[dno]["text_region"])
+ texts.append(res[dno]["text"])
+ scores.append(res[dno]["confidence"])
boxes = np.array(boxes)
scores = np.array(scores)
- draw_img = draw_ocr(
- image, boxes, texts, scores, draw_txt=True, drop_score=0.5)
+ draw_img = draw_ocr(image, boxes, texts, scores, draw_txt=True, drop_score=0.5)
return draw_img
@@ -76,24 +77,24 @@ def save_structure_res(res, save_folder, image_file):
excel_save_folder = os.path.join(save_folder, os.path.basename(image_file))
os.makedirs(excel_save_folder, exist_ok=True)
# save res
- with open(
- os.path.join(excel_save_folder, 'res.txt'), 'w',
- encoding='utf8') as f:
+ with open(os.path.join(excel_save_folder, "res.txt"), "w", encoding="utf8") as f:
for region in res:
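+ # write each layout region by type: tables to .xlsx, figures to cropped .jpg, text to JSON lines in res.txt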
- if region['type'] == 'Table':
- excel_path = os.path.join(excel_save_folder,
- '{}.xlsx'.format(region['bbox']))
- to_excel(region['res'], excel_path)
- elif region['type'] == 'Figure':
- x1, y1, x2, y2 = region['bbox']
- print(region['bbox'])
+ if region["type"] == "Table":
+ excel_path = os.path.join(
+ excel_save_folder, "{}.xlsx".format(region["bbox"])
+ )
+ to_excel(region["res"], excel_path)
+ elif region["type"] == "Figure":
+ x1, y1, x2, y2 = region["bbox"]
+ print(region["bbox"])
roi_img = img[y1:y2, x1:x2, :]
- img_path = os.path.join(excel_save_folder,
- '{}.jpg'.format(region['bbox']))
+ img_path = os.path.join(
+ excel_save_folder, "{}.jpg".format(region["bbox"])
+ )
cv2.imwrite(img_path, roi_img)
else:
- for text_result in region['res']:
- f.write('{}\n'.format(json.dumps(text_result)))
+ for text_result in region["res"]:
+ f.write("{}\n".format(json.dumps(text_result)))
def main(args):
@@ -103,16 +104,15 @@ def main(args):
cnt = 0
total_time = 0
for image_file in image_file_list:
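+ # read each image as raw bytes, base64-encode it, and POST it to the hub-serving endpoint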
- img = open(image_file, 'rb').read()
+ img = open(image_file, "rb").read()
if img is None:
logger.info("error in loading image:{}".format(image_file))
continue
img_name = os.path.basename(image_file)
# send http request
starttime = time.time()
- data = {'images': [cv2_to_base64(img)]}
- r = requests.post(
- url=args.server_url, headers=headers, data=json.dumps(data))
+ data = {"images": [cv2_to_base64(img)]}
+ r = requests.post(url=args.server_url, headers=headers, data=json.dumps(data))
elapse = time.time() - starttime
total_time += elapse
logger.info("Predict time of %s: %.3fs" % (image_file, elapse))
@@ -121,10 +121,10 @@ def main(args):
if args.visualize:
draw_img = None
- if 'structure_table' in args.server_url:
- to_excel(res['html'], './{}.xlsx'.format(img_name))
- elif 'structure_system' in args.server_url:
- save_structure_res(res['regions'], args.output, image_file)
+ if "structure_table" in args.server_url:
+ to_excel(res["html"], "./{}.xlsx".format(img_name))
+ elif "structure_system" in args.server_url:
+ save_structure_res(res["regions"], args.output, image_file)
else:
draw_img = draw_server_result(image_file, res)
if draw_img is not None:
@@ -132,9 +132,13 @@ def main(args):
os.makedirs(args.output)
cv2.imwrite(
os.path.join(args.output, os.path.basename(image_file)),
- draw_img[:, :, ::-1])
- logger.info("The visualized image saved in {}".format(
- os.path.join(args.output, os.path.basename(image_file))))
+ draw_img[:, :, ::-1],
+ )
+ logger.info(
+ "The visualized image saved in {}".format(
+ os.path.join(args.output, os.path.basename(image_file))
+ )
+ )
cnt += 1
if cnt % 100 == 0:
logger.info("{} processed".format(cnt))
@@ -143,15 +147,16 @@ def main(args):
def parse_args():
import argparse
+
parser = argparse.ArgumentParser(description="args for hub serving")
parser.add_argument("--server_url", type=str, required=True)
parser.add_argument("--image_dir", type=str, required=True)
parser.add_argument("--visualize", type=str2bool, default=False)
- parser.add_argument("--output", type=str, default='./hubserving_result')
+ parser.add_argument("--output", type=str, default="./hubserving_result")
args = parser.parse_args()
return args
-if __name__ == '__main__':
+if __name__ == "__main__":
args = parse_args()
main(args)
diff --git a/tools/train.py b/tools/train.py
index faed388ec1..0aaa0089f9 100755
--- a/tools/train.py
+++ b/tools/train.py
@@ -21,7 +21,7 @@
__dir__ = os.path.dirname(os.path.abspath(__file__))
sys.path.append(__dir__)
-sys.path.insert(0, os.path.abspath(os.path.join(__dir__, '..')))
+sys.path.insert(0, os.path.abspath(os.path.join(__dir__, "..")))
import yaml
import paddle
@@ -43,170 +43,195 @@
def main(config, device, logger, vdl_writer, seed):
# init dist environment
- if config['Global']['distributed']:
+ if config["Global"]["distributed"]:
dist.init_parallel_env()
- global_config = config['Global']
+ global_config = config["Global"]
# build dataloader
set_signal_handlers()
- train_dataloader = build_dataloader(config, 'Train', device, logger, seed)
+ train_dataloader = build_dataloader(config, "Train", device, logger, seed)
if len(train_dataloader) == 0:
logger.error(
- "No Images in train dataset, please ensure\n" +
- "\t1. The images num in the train label_file_list should be larger than or equal with batch size.\n"
- +
- "\t2. The annotation file and path in the configuration file are provided normally."
+ "No Images in train dataset, please ensure\n"
+ + "\t1. The images num in the train label_file_list should be larger than or equal with batch size.\n"
+ + "\t2. The annotation file and path in the configuration file are provided normally."
)
return
- if config['Eval']:
- valid_dataloader = build_dataloader(config, 'Eval', device, logger,
- seed)
+ if config["Eval"]:
+ valid_dataloader = build_dataloader(config, "Eval", device, logger, seed)
else:
valid_dataloader = None
step_pre_epoch = len(train_dataloader)
# build post process
- post_process_class = build_post_process(config['PostProcess'],
- global_config)
+ post_process_class = build_post_process(config["PostProcess"], global_config)
# build model
# for rec algorithm
- if hasattr(post_process_class, 'character'):
- char_num = len(getattr(post_process_class, 'character'))
- if config['Architecture']["algorithm"] in ["Distillation",
- ]: # distillation model
- for key in config['Architecture']["Models"]:
- if config['Architecture']['Models'][key]['Head'][
- 'name'] == 'MultiHead': # for multi head
- if config['PostProcess'][
- 'name'] == 'DistillationSARLabelDecode':
+ if hasattr(post_process_class, "character"):
+ char_num = len(getattr(post_process_class, "character"))
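+ # size each recognition head to the charset; SAR/NRTR decoders reserve extra slots for their special tokens, so char_num is adjusted per decoder below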
+ if config["Architecture"]["algorithm"] in [
+ "Distillation",
+ ]: # distillation model
+ for key in config["Architecture"]["Models"]:
+ if (
+ config["Architecture"]["Models"][key]["Head"]["name"] == "MultiHead"
+ ): # for multi head
+ if config["PostProcess"]["name"] == "DistillationSARLabelDecode":
char_num = char_num - 2
- if config['PostProcess'][
- 'name'] == 'DistillationNRTRLabelDecode':
+ if config["PostProcess"]["name"] == "DistillationNRTRLabelDecode":
char_num = char_num - 3
out_channels_list = {}
- out_channels_list['CTCLabelDecode'] = char_num
+ out_channels_list["CTCLabelDecode"] = char_num
# update SARLoss params
- if list(config['Loss']['loss_config_list'][-1].keys())[
- 0] == 'DistillationSARLoss':
- config['Loss']['loss_config_list'][-1][
- 'DistillationSARLoss'][
- 'ignore_index'] = char_num + 1
- out_channels_list['SARLabelDecode'] = char_num + 2
- elif any('DistillationNRTRLoss' in d
- for d in config['Loss']['loss_config_list']):
- out_channels_list['NRTRLabelDecode'] = char_num + 3
-
- config['Architecture']['Models'][key]['Head'][
- 'out_channels_list'] = out_channels_list
+ if (
+ list(config["Loss"]["loss_config_list"][-1].keys())[0]
+ == "DistillationSARLoss"
+ ):
+ config["Loss"]["loss_config_list"][-1]["DistillationSARLoss"][
+ "ignore_index"
+ ] = (char_num + 1)
+ out_channels_list["SARLabelDecode"] = char_num + 2
+ elif any(
+ "DistillationNRTRLoss" in d
+ for d in config["Loss"]["loss_config_list"]
+ ):
+ out_channels_list["NRTRLabelDecode"] = char_num + 3
+
+ config["Architecture"]["Models"][key]["Head"][
+ "out_channels_list"
+ ] = out_channels_list
else:
- config['Architecture']["Models"][key]["Head"][
- 'out_channels'] = char_num
- elif config['Architecture']['Head'][
- 'name'] == 'MultiHead': # for multi head
- if config['PostProcess']['name'] == 'SARLabelDecode':
+ config["Architecture"]["Models"][key]["Head"][
+ "out_channels"
+ ] = char_num
+ elif config["Architecture"]["Head"]["name"] == "MultiHead": # for multi head
+ if config["PostProcess"]["name"] == "SARLabelDecode":
char_num = char_num - 2
- if config['PostProcess']['name'] == 'NRTRLabelDecode':
+ if config["PostProcess"]["name"] == "NRTRLabelDecode":
char_num = char_num - 3
out_channels_list = {}
- out_channels_list['CTCLabelDecode'] = char_num
+ out_channels_list["CTCLabelDecode"] = char_num
# update SARLoss params
- if list(config['Loss']['loss_config_list'][1].keys())[
- 0] == 'SARLoss':
- if config['Loss']['loss_config_list'][1]['SARLoss'] is None:
- config['Loss']['loss_config_list'][1]['SARLoss'] = {
- 'ignore_index': char_num + 1
+ if list(config["Loss"]["loss_config_list"][1].keys())[0] == "SARLoss":
+ if config["Loss"]["loss_config_list"][1]["SARLoss"] is None:
+ config["Loss"]["loss_config_list"][1]["SARLoss"] = {
+ "ignore_index": char_num + 1
}
else:
- config['Loss']['loss_config_list'][1]['SARLoss'][
- 'ignore_index'] = char_num + 1
- out_channels_list['SARLabelDecode'] = char_num + 2
- elif list(config['Loss']['loss_config_list'][1].keys())[
- 0] == 'NRTRLoss':
- out_channels_list['NRTRLabelDecode'] = char_num + 3
- config['Architecture']['Head'][
- 'out_channels_list'] = out_channels_list
+ config["Loss"]["loss_config_list"][1]["SARLoss"]["ignore_index"] = (
+ char_num + 1
+ )
+ out_channels_list["SARLabelDecode"] = char_num + 2
+ elif list(config["Loss"]["loss_config_list"][1].keys())[0] == "NRTRLoss":
+ out_channels_list["NRTRLabelDecode"] = char_num + 3
+ config["Architecture"]["Head"]["out_channels_list"] = out_channels_list
else: # base rec model
- config['Architecture']["Head"]['out_channels'] = char_num
+ config["Architecture"]["Head"]["out_channels"] = char_num
- if config['PostProcess']['name'] == 'SARLabelDecode': # for SAR model
- config['Loss']['ignore_index'] = char_num - 1
+ if config["PostProcess"]["name"] == "SARLabelDecode": # for SAR model
+ config["Loss"]["ignore_index"] = char_num - 1
- model = build_model(config['Architecture'])
+ model = build_model(config["Architecture"])
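+ # optionally convert BatchNorm layers to SyncBatchNorm so statistics are synchronized across devices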
use_sync_bn = config["Global"].get("use_sync_bn", False)
if use_sync_bn:
model = paddle.nn.SyncBatchNorm.convert_sync_batchnorm(model)
- logger.info('convert_sync_batchnorm')
+ logger.info("convert_sync_batchnorm")
model = apply_to_static(model, config, logger)
# build loss
- loss_class = build_loss(config['Loss'])
+ loss_class = build_loss(config["Loss"])
# build optim
optimizer, lr_scheduler = build_optimizer(
- config['Optimizer'],
- epochs=config['Global']['epoch_num'],
+ config["Optimizer"],
+ epochs=config["Global"]["epoch_num"],
step_each_epoch=len(train_dataloader),
- model=model)
+ model=model,
+ )
# build metric
- eval_class = build_metric(config['Metric'])
+ eval_class = build_metric(config["Metric"])
- logger.info('train dataloader has {} iters'.format(len(train_dataloader)))
+ logger.info("train dataloader has {} iters".format(len(train_dataloader)))
if valid_dataloader is not None:
- logger.info('valid dataloader has {} iters'.format(
- len(valid_dataloader)))
+ logger.info("valid dataloader has {} iters".format(len(valid_dataloader)))
use_amp = config["Global"].get("use_amp", False)
- amp_level = config["Global"].get("amp_level", 'O2')
- amp_dtype = config["Global"].get("amp_dtype", 'float16')
- amp_custom_black_list = config['Global'].get('amp_custom_black_list', [])
- amp_custom_white_list = config['Global'].get('amp_custom_white_list', [])
+ amp_level = config["Global"].get("amp_level", "O2")
+ amp_dtype = config["Global"].get("amp_dtype", "float16")
+ amp_custom_black_list = config["Global"].get("amp_custom_black_list", [])
+ amp_custom_white_list = config["Global"].get("amp_custom_white_list", [])
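+ # optional AMP: set Paddle performance flags, create a GradScaler, and decorate model/optimizer when running at O2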
if use_amp:
- AMP_RELATED_FLAGS_SETTING = {'FLAGS_max_inplace_grad_add': 8, }
+ AMP_RELATED_FLAGS_SETTING = {
+ "FLAGS_max_inplace_grad_add": 8,
+ }
if paddle.is_compiled_with_cuda():
- AMP_RELATED_FLAGS_SETTING.update({
- 'FLAGS_cudnn_batchnorm_spatial_persistent': 1,
- 'FLAGS_gemm_use_half_precision_compute_type': 0,
- })
+ AMP_RELATED_FLAGS_SETTING.update(
+ {
+ "FLAGS_cudnn_batchnorm_spatial_persistent": 1,
+ "FLAGS_gemm_use_half_precision_compute_type": 0,
+ }
+ )
paddle.set_flags(AMP_RELATED_FLAGS_SETTING)
scale_loss = config["Global"].get("scale_loss", 1.0)
use_dynamic_loss_scaling = config["Global"].get(
- "use_dynamic_loss_scaling", False)
+ "use_dynamic_loss_scaling", False
+ )
scaler = paddle.amp.GradScaler(
init_loss_scaling=scale_loss,
- use_dynamic_loss_scaling=use_dynamic_loss_scaling)
+ use_dynamic_loss_scaling=use_dynamic_loss_scaling,
+ )
if amp_level == "O2":
model, optimizer = paddle.amp.decorate(
models=model,
optimizers=optimizer,
level=amp_level,
master_weight=True,
- dtype=amp_dtype)
+ dtype=amp_dtype,
+ )
else:
scaler = None
# load pretrain model
- pre_best_model_dict = load_model(config, model, optimizer,
- config['Architecture']["model_type"])
+ pre_best_model_dict = load_model(
+ config, model, optimizer, config["Architecture"]["model_type"]
+ )
- if config['Global']['distributed']:
+ if config["Global"]["distributed"]:
model = paddle.DataParallel(model)
# start train
- program.train(config, train_dataloader, valid_dataloader, device, model,
- loss_class, optimizer, lr_scheduler, post_process_class,
- eval_class, pre_best_model_dict, logger, step_pre_epoch,
- vdl_writer, scaler, amp_level, amp_custom_black_list,
- amp_custom_white_list, amp_dtype)
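+ # hand all assembled components to the shared training loop (program.train)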
+ program.train(
+ config,
+ train_dataloader,
+ valid_dataloader,
+ device,
+ model,
+ loss_class,
+ optimizer,
+ lr_scheduler,
+ post_process_class,
+ eval_class,
+ pre_best_model_dict,
+ logger,
+ step_pre_epoch,
+ vdl_writer,
+ scaler,
+ amp_level,
+ amp_custom_black_list,
+ amp_custom_white_list,
+ amp_dtype,
+ )
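+ # debugging helper: iterates the Train dataloader and logs per-batch read times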
def test_reader(config, device, logger):
- loader = build_dataloader(config, 'Train', device, logger)
+ loader = build_dataloader(config, "Train", device, logger)
import time
+
starttime = time.time()
count = 0
try:
@@ -215,16 +240,17 @@ def test_reader(config, device, logger):
if count % 1 == 0:
batch_time = time.time() - starttime
starttime = time.time()
- logger.info("reader: {}, {}, {}".format(
- count, len(data[0]), batch_time))
+ logger.info(
+ "reader: {}, {}, {}".format(count, len(data[0]), batch_time)
+ )
except Exception as e:
logger.info(e)
logger.info("finish reader: {}, Success!".format(count))
-if __name__ == '__main__':
+if __name__ == "__main__":
config, device, logger, vdl_writer = program.preprocess(is_train=True)
- seed = config['Global']['seed'] if 'seed' in config['Global'] else 1024
+ seed = config["Global"]["seed"] if "seed" in config["Global"] else 1024
set_seed(seed)
main(config, device, logger, vdl_writer, seed)
# test_reader(config, device, logger)