Created
March 25, 2023 11:42
-
-
Save aimerneige/b84bfed160b7810349b3ee2e43ca507a to your computer and use it in GitHub Desktop.
PyQt OCR Software With Baidu API
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""PyQt5 desktop front-end for the Baidu AI Cloud OCR REST API.

Credentials are read from the ``BAIDU_OCR_API_KEY`` and
``BAIDU_OCR_SECRET_KEY`` environment variables. When those are unset, the
placeholder strings below are used instead.
"""
import os
import sys
import base64
import requests
from enum import Enum
from PyQt5 import QtCore
from PyQt5.QtGui import QPixmap
from PyQt5.QtCore import QSize, pyqtSlot
from PyQt5.QtWidgets import QApplication, QDesktopWidget, QMainWindow, QPushButton, QLabel, QTextEdit, QComboBox, QFileDialog, QMessageBox

window_title = "文本识别"
# Prefer environment variables so real credentials never need to be committed
# next to the source; the placeholders keep the script's previous defaults.
API_KEY = os.environ.get("BAIDU_OCR_API_KEY", "replace_this_with_your_key")
SECRET_KEY = os.environ.get("BAIDU_OCR_SECRET_KEY", "replace_this_with_your_key")
# OAuth endpoint that exchanges API_KEY/SECRET_KEY for an access token.
TOKEN_URL = 'https://aip.baidubce.com/oauth/2.0/token'
# Base URL of the OCR REST endpoints; the endpoint name is appended to it.
OCR_URL = 'https://aip.baidubce.com/rest/2.0/ocr/v1/'
# Language codes accepted by the ``language_type`` form field of the
# ``accurate_basic`` request; the member *name* is what gets sent, and the
# integer values simply number the members from 0 in declaration order.
Language = Enum(
    "Language",
    [
        "auto_detect",
        "CHN_ENG", "ENG", "JAP", "KOR", "FRE", "SPA", "POR",
        "GER", "ITA", "RUS", "DAN", "DUT", "MAL", "SWE",
        "IND", "POL", "ROM", "TUR", "GRE", "HUN",
    ],
    start=0,
    module=__name__,
)
class OCR(object):
    """Thin client for the Baidu AI Cloud OCR REST endpoints used by the GUI.

    ``OCR_TYPE`` selects the endpoint by the same Chinese label the GUI combo
    box shows ("通用文字", "数字识别", "手写文字"). ``LANGUAGE``,
    ``DETECT_DIRECTION``, ``PARAGRAPH`` and ``PROBABILITY`` are forwarded as
    form fields; the flags are the strings "true"/"false".
    """

    # Seconds to wait for any single HTTP request. Without a timeout a stalled
    # connection would block the caller (the Qt event loop here) indefinitely.
    REQUEST_TIMEOUT = 30

    def __init__(self, api_key, secret_key):
        """Store the credentials and fetch an access token immediately.

        Raises:
            RuntimeError: if the token endpoint does not return a token.
            requests.RequestException: on network failure or timeout.
        """
        super().__init__()
        self.API_KEY = api_key
        self.SECRET_KEY = secret_key
        self.OCR_TYPE = "通用文字"
        self.ACCESS_TOKEN = self.fetch_token()
        self.LANGUAGE = Language.auto_detect.name
        self.DETECT_DIRECTION = "false"
        self.PARAGRAPH = "false"
        self.PROBABILITY = "true"

    def set_language(self, language):
        """Set ``language_type`` from a ``Language`` member (its name is sent)."""
        self.LANGUAGE = language.name

    def set_detect_direction(self, detect_direction):
        """Set the ``detect_direction`` flag ("true"/"false")."""
        self.DETECT_DIRECTION = detect_direction

    def set_paragraph(self, paragraph):
        """Set the ``paragraph`` flag ("true"/"false") used by accurate_basic."""
        self.PARAGRAPH = paragraph

    def set_probability(self, probability):
        """Set the ``probability`` flag ("true"/"false")."""
        self.PROBABILITY = probability

    def fetch_token(self):
        """Exchange API_KEY/SECRET_KEY for an OAuth access token.

        Returns:
            The ``access_token`` string from the token endpoint.

        Raises:
            RuntimeError: if the response is unsuccessful or carries no token.
                Failing here is clearer than returning ``None``, which would
                only break later when the token is appended to a URL.
            requests.RequestException: on network failure or timeout.
        """
        response = requests.post(TOKEN_URL, data={
            'grant_type': 'client_credentials',
            'client_id': self.API_KEY,
            'client_secret': self.SECRET_KEY,
        }, timeout=self.REQUEST_TIMEOUT)
        try:
            payload = response.json()
        except ValueError:
            payload = {}
        if not isinstance(payload, dict):
            payload = {}
        token = payload.get('access_token')
        if not response.ok or not token:
            detail = (payload.get('error_description')
                      or payload.get('error')
                      or response.text)
            raise RuntimeError(
                "Failed to obtain Baidu access token (HTTP {}): {}".format(
                    response.status_code, detail))
        return token

    def encode_image(self, image_path):
        """Return the file at *image_path* as base64-encoded ``bytes``."""
        with open(image_path, 'rb') as f:
            image_data = f.read()
        return base64.b64encode(image_data)

    def _post(self, endpoint, params):
        """POST form *params* to OCR *endpoint* and return the decoded JSON.

        Raises:
            RuntimeError: if the JSON reply contains an ``error_code``.
            requests.HTTPError: on a non-2xx HTTP status.
            requests.RequestException: on network failure or timeout.
        """
        response = requests.post(
            OCR_URL + endpoint,
            # The token travels in the query string; the image in the body.
            params={'access_token': self.ACCESS_TOKEN},
            headers={'Content-Type': 'application/x-www-form-urlencoded'},
            data=params,
            timeout=self.REQUEST_TIMEOUT,
        )
        response.raise_for_status()
        json_data = response.json()
        # Baidu reports API-level failures as error_code/error_msg fields in
        # the body (per its API docs), not via the HTTP status.
        if 'error_code' in json_data:
            raise RuntimeError("Baidu OCR error {}: {}".format(
                json_data['error_code'], json_data.get('error_msg', '')))
        return json_data

    @staticmethod
    def _join_lines(json_data):
        """Return every ``words_result`` entry's text followed by a newline."""
        return "".join(item['words'] + "\n"
                       for item in json_data.get('words_result', []))

    @classmethod
    def _join_paragraphs(cls, json_data):
        """Return one line per paragraph, each word followed by a space.

        Falls back to one line per recognized line when the reply has no
        ``paragraphs_result`` (i.e. paragraph output was not requested).
        """
        paragraphs = json_data.get('paragraphs_result')
        if paragraphs is None:
            return cls._join_lines(json_data)
        words = json_data.get('words_result', [])
        return "".join(
            "".join(words[index]['words'] + " "
                    for index in paragraph['words_result_idx']) + "\n"
            for paragraph in paragraphs)

    def accurate_basic(self, image_path):
        """Recognize general text with the high-accuracy endpoint."""
        json_data = self._post('accurate_basic', {
            'image': self.encode_image(image_path),
            'language_type': self.LANGUAGE,
            'detect_direction': self.DETECT_DIRECTION,
            'paragraph': self.PARAGRAPH,
            'probability': self.PROBABILITY,
        })
        return self._join_paragraphs(json_data)

    def numbers(self, image_path):
        """Recognize digits; one recognized line per output line."""
        json_data = self._post('numbers', {
            'image': self.encode_image(image_path),
            'detect_direction': self.DETECT_DIRECTION,
        })
        return self._join_lines(json_data)

    def handwriting(self, image_path):
        """Recognize handwriting; one recognized line per output line."""
        json_data = self._post('handwriting', {
            'image': self.encode_image(image_path),
            'detect_direction': self.DETECT_DIRECTION,
            'probability': self.PROBABILITY,
        })
        return self._join_lines(json_data)

    def get_ocr_result(self, image_path):
        """Run the endpoint selected by ``OCR_TYPE`` on *image_path*.

        Raises:
            ValueError: if ``OCR_TYPE`` is not one of the supported labels.
        """
        handlers = {
            "通用文字": self.accurate_basic,
            "数字识别": self.numbers,
            "手写文字": self.handwriting,
        }
        try:
            handler = handlers[self.OCR_TYPE]
        except KeyError:
            raise ValueError(
                "Unsupported OCR type: {!r}".format(self.OCR_TYPE)) from None
        return handler(image_path)
class Window(QMainWindow):
    """Main window of the OCR tool.

    The left column holds the OCR-type selector, the image picker button and
    an image preview. The right column holds the recognized text and a
    copy-to-clipboard button.
    """

    def __init__(self):
        super().__init__()
        self.initWindow()
        self.initUI()
        self.initOCR()
        self.center()

    def initWindow(self):
        """Set the title and a fixed 1280x720 size."""
        self.setWindowTitle(window_title)
        self.setFixedWidth(1280)
        self.setFixedHeight(720)

    def initUI(self):
        """Build all child widgets."""
        self.initOCROptionSelection()
        self.initImageSelection()
        self.initTextResult()

    def initOCR(self):
        """Create the OCR client and apply this window's request options.

        Constructing ``OCR`` contacts the token endpoint, so this may raise
        on network or credential failure before the window is shown.
        """
        self.OCR = OCR(API_KEY, SECRET_KEY)
        self.OCR.set_detect_direction("false")
        self.OCR.set_language(Language.CHN_ENG)
        self.OCR.set_paragraph("true")
        self.OCR.set_probability("false")

    def initOCROptionSelection(self):
        """Create the label and combo box that choose the OCR endpoint."""
        self.ocrOptionLabel = QLabel("选择识别类型", self)
        self.ocrOptionLabel.setFixedSize(QSize(220, 40))
        self.ocrOptionLabel.move(80, 80)
        self.ocrOptionSelection = QComboBox(self)
        self.ocrOptionSelection.move(320, 80)
        self.ocrOptionSelection.setFixedSize(QSize(240, 40))
        # These labels must match the OCR_TYPE values understood by OCR.
        self.ocrOptionSelection.addItem("通用文字")
        self.ocrOptionSelection.addItem("数字识别")
        self.ocrOptionSelection.addItem("手写文字")
        self.ocrOptionSelection.currentIndexChanged.connect(
            self.selectionChange)

    def initImageSelection(self):
        """Create the image picker button and the preview label."""
        self.imageSelectionButton = QPushButton("选择需要识别的图片", self)
        self.imageSelectionButton.setFixedSize(QSize(480, 80))
        self.imageSelectionButton.move(80, 180)
        self.imageSelectionButton.clicked.connect(self.imageSelectionClicked)
        self.imagePreviewImage = QLabel(self)
        self.imagePreviewImage.setText("请选择要识别的图片")
        self.imagePreviewImage.setStyleSheet(
            "QLabel { background-color : gray; color : black; }")
        self.imagePreviewImage.setAlignment(QtCore.Qt.AlignCenter)
        self.imagePreviewImage.setFixedSize(QSize(480, 320))
        self.imagePreviewImage.setScaledContents(True)
        self.imagePreviewImage.move(80, 320)

    def initTextResult(self):
        """Create the result text box and the copy button."""
        self.textResult = QTextEdit(self)
        self.textResult.setFixedSize(QSize(480, 440))
        self.textResult.move(720, 80)
        self.textCopy = QPushButton("复制到剪切板", self)
        self.textCopy.setFixedSize(QSize(480, 80))
        self.textCopy.move(720, 560)
        self.textCopy.clicked.connect(self.copyClicked)

    def center(self):
        """Move the window to the center of the available desktop area."""
        qr = self.frameGeometry()
        cp = QDesktopWidget().availableGeometry().center()
        qr.moveCenter(cp)
        self.move(qr.topLeft())

    def selectionChange(self, i):
        """Switch the OCR endpoint to the label at combo index *i*."""
        self.OCR.OCR_TYPE = self.ocrOptionSelection.itemText(i)

    @pyqtSlot()
    def imageSelectionClicked(self):
        """Let the user pick an image, preview it and run OCR on it."""
        # QFileDialog does not expand "~", so pass the real home directory.
        file_path, _ = QFileDialog.getOpenFileName(
            self, "选择你要识别的图片", QtCore.QDir.homePath())
        if not file_path:
            # Dialog was cancelled: nothing to recognize.
            return
        pixmap = QPixmap(file_path)
        if pixmap.isNull():
            QMessageBox.warning(
                self, window_title, "无法加载图片：{}".format(file_path))
            return
        self.imagePreviewImage.setPixmap(pixmap)
        self.callOCR(file_path)

    @pyqtSlot()
    def copyClicked(self):
        """Copy the recognized text to the clipboard and confirm it."""
        QApplication.clipboard().setText(self.textResult.toPlainText())
        msg = QMessageBox(self)
        msg.setText('已复制到剪切板')
        msg.exec_()

    def callOCR(self, image_path):
        """Recognize *image_path* and show the text, or an error dialog."""
        # NOTE(review): the request runs on the GUI thread, so the window is
        # unresponsive until it returns; move it to a QThread if that matters.
        try:
            ocr_result = self.OCR.get_ocr_result(image_path)
        except Exception as err:
            # Top-level GUI boundary: an exception escaping a slot would abort
            # the application, so report it to the user instead.
            QMessageBox.warning(self, window_title, "识别失败：{}".format(err))
            return
        self.textResult.setPlainText(ocr_result or "")
def main():
    """Create the Qt application, show the OCR window and run the event loop.

    The process exits with the status code returned by the event loop.
    """
    application = QApplication(sys.argv)
    ocr_window = Window()
    ocr_window.show()
    exit_status = application.exec_()
    sys.exit(exit_status)


if __name__ == '__main__':
    main()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment