
Commit 42cc379

feat: ocr(android) (#96)
### Type of change
- [ ] Bug fix (non-breaking change which fixes an issue)
- [x] New feature (non-breaking change which adds functionality)
- [ ] Breaking change (fix or feature that would cause existing functionality to not work as expected)
- [ ] Documentation update (improves or adds clarity to existing documentation)

### Tested on
- [ ] iOS
- [x] Android

### Checklist
- [x] I have performed a self-review of my code
- [ ] I have commented my code, particularly in hard-to-understand areas
- [ ] I have updated the documentation accordingly
- [x] My changes generate no new warnings
1 parent dee900d commit 42cc379

File tree

11 files changed: +1308 −7 lines changed

android/src/main/java/com/swmansion/rnexecutorch/OCR.kt

@@ -0,0 +1,101 @@
package com.swmansion.rnexecutorch

import android.util.Log
import com.facebook.react.bridge.Promise
import com.facebook.react.bridge.ReactApplicationContext
import com.swmansion.rnexecutorch.utils.ETError
import com.swmansion.rnexecutorch.utils.ImageProcessor
import org.opencv.android.OpenCVLoader
import com.swmansion.rnexecutorch.models.ocr.Detector
import com.swmansion.rnexecutorch.models.ocr.RecognitionHandler
import com.swmansion.rnexecutorch.models.ocr.utils.Constants
import com.swmansion.rnexecutorch.utils.Fetcher
import com.swmansion.rnexecutorch.utils.ResourceType
import org.opencv.imgproc.Imgproc

class OCR(reactContext: ReactApplicationContext) :
  NativeOCRSpec(reactContext) {

  private lateinit var detector: Detector
  private lateinit var recognitionHandler: RecognitionHandler

  companion object {
    const val NAME = "OCR"
  }

  init {
    if (!OpenCVLoader.initLocal()) {
      Log.d("rn_executorch", "OpenCV not loaded")
    } else {
      Log.d("rn_executorch", "OpenCV loaded")
    }
  }

  override fun loadModule(
    detectorSource: String,
    recognizerSourceLarge: String,
    recognizerSourceMedium: String,
    recognizerSourceSmall: String,
    symbols: String,
    languageDictPath: String,
    promise: Promise
  ) {
    try {
      detector = Detector(reactApplicationContext)
      detector.loadModel(detectorSource)
      Fetcher.downloadResource(
        reactApplicationContext,
        languageDictPath,
        ResourceType.TXT,
        false,
        { path, error ->
          if (error != null) {
            throw Error(error.message!!)
          }

          recognitionHandler = RecognitionHandler(
            symbols,
            path!!,
            reactApplicationContext
          )

          recognitionHandler.loadRecognizers(
            recognizerSourceLarge,
            recognizerSourceMedium,
            recognizerSourceSmall
          ) { _, errorRecognizer ->
            if (errorRecognizer != null) {
              throw Error(errorRecognizer.message!!)
            }

            promise.resolve(0)
          }
        })
    } catch (e: Exception) {
      promise.reject(e.message!!, ETError.InvalidModelSource.toString())
    }
  }

  override fun forward(input: String, promise: Promise) {
    try {
      val inputImage = ImageProcessor.readImage(input)
      val bBoxesList = detector.runModel(inputImage)
      val detectorSize = detector.getModelImageSize()
      Imgproc.cvtColor(inputImage, inputImage, Imgproc.COLOR_BGR2GRAY)
      val result = recognitionHandler.recognize(
        bBoxesList,
        inputImage,
        (detectorSize.width * Constants.RECOGNIZER_RATIO).toInt(),
        (detectorSize.height * Constants.RECOGNIZER_RATIO).toInt()
      )
      promise.resolve(result)
    } catch (e: Exception) {
      Log.d("rn_executorch", "Error running model: ${e.message}")
      promise.reject(e.message!!, e.message)
    }
  }

  override fun getName(): String {
    return NAME
  }
}
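For context, OCR extends NativeOCRSpec, the TurboModule base class produced by React Native codegen from the JS spec; that generated class is not part of this diff. A hand-written sketch of the shape it presumably has, inferred only from the overrides above (names and annotations in the real generated source may differ), is:

// Hypothetical sketch only -- the real NativeOCRSpec is generated by React Native codegen.
package com.swmansion.rnexecutorch

import com.facebook.react.bridge.Promise
import com.facebook.react.bridge.ReactApplicationContext
import com.facebook.react.bridge.ReactContextBaseJavaModule

abstract class NativeOCRSpec(reactContext: ReactApplicationContext) :
  ReactContextBaseJavaModule(reactContext) {

  // Downloads and loads the detector plus the three recognizer variants,
  // then resolves the promise once everything is ready.
  abstract fun loadModule(
    detectorSource: String,
    recognizerSourceLarge: String,
    recognizerSourceMedium: String,
    recognizerSourceSmall: String,
    symbols: String,
    languageDictPath: String,
    promise: Promise
  )

  // Runs detection and recognition on the image referenced by `input`
  // and resolves with the array of recognized text regions.
  abstract fun forward(input: String, promise: Promise)
}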

android/src/main/java/com/swmansion/rnexecutorch/RnExecutorchPackage.kt

+10
@@ -23,6 +23,8 @@ class RnExecutorchPackage : TurboReactPackage() {
       Classification(reactContext)
     } else if (name == ObjectDetection.NAME) {
       ObjectDetection(reactContext)
+    } else if (name == OCR.NAME) {
+      OCR(reactContext)
     }
     else {
       null
@@ -74,6 +76,14 @@
       false, // isCxxModule
       true
     )
+    moduleInfos[OCR.NAME] = ReactModuleInfo(
+      OCR.NAME,
+      OCR.NAME,
+      false, // canOverrideExistingModule
+      false, // needsEagerInit
+      false, // isCxxModule
+      true
+    )
     moduleInfos
   }
 }
android/src/main/java/com/swmansion/rnexecutorch/models/ocr/Detector.kt

@@ -0,0 +1,79 @@
package com.swmansion.rnexecutorch.models.ocr

import android.util.Log
import com.facebook.react.bridge.ReactApplicationContext
import com.swmansion.rnexecutorch.models.BaseModel
import com.swmansion.rnexecutorch.models.ocr.utils.Constants
import com.swmansion.rnexecutorch.models.ocr.utils.DetectorUtils
import com.swmansion.rnexecutorch.models.ocr.utils.OCRbBox
import com.swmansion.rnexecutorch.utils.ImageProcessor
import org.opencv.core.Mat
import org.opencv.core.Scalar
import org.opencv.core.Size
import org.pytorch.executorch.EValue

class Detector(reactApplicationContext: ReactApplicationContext) :
  BaseModel<Mat, List<OCRbBox>>(reactApplicationContext) {
  private lateinit var originalSize: Size

  fun getModelImageSize(): Size {
    val inputShape = module.getInputShape(0)
    val width = inputShape[inputShape.lastIndex]
    val height = inputShape[inputShape.lastIndex - 1]

    val modelImageSize = Size(height.toDouble(), width.toDouble())

    return modelImageSize
  }

  override fun preprocess(input: Mat): EValue {
    originalSize = Size(input.cols().toDouble(), input.rows().toDouble())
    val resizedImage = ImageProcessor.resizeWithPadding(
      input,
      getModelImageSize().width.toInt(),
      getModelImageSize().height.toInt()
    )

    return ImageProcessor.matToEValue(
      resizedImage,
      module.getInputShape(0),
      Constants.MEAN,
      Constants.VARIANCE
    )
  }

  override fun postprocess(output: Array<EValue>): List<OCRbBox> {
    val outputTensor = output[0].toTensor()
    val outputArray = outputTensor.dataAsFloatArray
    val modelImageSize = getModelImageSize()

    val (scoreText, scoreLink) = DetectorUtils.interleavedArrayToMats(
      outputArray,
      Size(modelImageSize.width / 2, modelImageSize.height / 2)
    )
    var bBoxesList = DetectorUtils.getDetBoxesFromTextMap(
      scoreText,
      scoreLink,
      Constants.TEXT_THRESHOLD,
      Constants.LINK_THRESHOLD,
      Constants.LOW_TEXT_THRESHOLD
    )
    bBoxesList =
      DetectorUtils.restoreBoxRatio(bBoxesList, (Constants.RECOGNIZER_RATIO * 2).toFloat())
    bBoxesList = DetectorUtils.groupTextBoxes(
      bBoxesList,
      Constants.CENTER_THRESHOLD,
      Constants.DISTANCE_THRESHOLD,
      Constants.HEIGHT_THRESHOLD,
      Constants.MIN_SIDE_THRESHOLD,
      Constants.MAX_SIDE_THRESHOLD,
      Constants.MAX_WIDTH
    )

    return bBoxesList.toList()
  }

  override fun runModel(input: Mat): List<OCRbBox> {
    return postprocess(forward(preprocess(input)))
  }
}
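The detector output packs the text-score and link-score channels into a single flat float array at half the model input resolution, which DetectorUtils.interleavedArrayToMats then splits into two OpenCV Mats. The helper itself is not part of this diff; a rough sketch of the de-interleaving it is assumed to perform (the actual implementation may differ) is:

// Rough sketch (assumption): split an interleaved [h, w, 2]-style float array
// into two single-channel CV_32F Mats of the given size.
import org.opencv.core.CvType
import org.opencv.core.Mat
import org.opencv.core.Size

fun interleavedArrayToMatsSketch(array: FloatArray, size: Size): Pair<Mat, Mat> {
  val rows = size.height.toInt()
  val cols = size.width.toInt()
  val scoreText = Mat(rows, cols, CvType.CV_32F)
  val scoreLink = Mat(rows, cols, CvType.CV_32F)
  var idx = 0
  for (r in 0 until rows) {
    for (c in 0 until cols) {
      // Even positions hold the text score, odd positions the link (affinity) score.
      scoreText.put(r, c, floatArrayOf(array[idx]))
      scoreLink.put(r, c, floatArrayOf(array[idx + 1]))
      idx += 2
    }
  }
  return Pair(scoreText, scoreLink)
}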
android/src/main/java/com/swmansion/rnexecutorch/models/ocr/RecognitionHandler.kt

@@ -0,0 +1,115 @@
package com.swmansion.rnexecutorch.models.ocr

import com.facebook.react.bridge.Arguments
import com.facebook.react.bridge.ReactApplicationContext
import com.facebook.react.bridge.WritableArray
import com.swmansion.rnexecutorch.models.ocr.utils.CTCLabelConverter
import com.swmansion.rnexecutorch.models.ocr.utils.Constants
import com.swmansion.rnexecutorch.models.ocr.utils.OCRbBox
import com.swmansion.rnexecutorch.models.ocr.utils.RecognizerUtils
import com.swmansion.rnexecutorch.utils.ImageProcessor
import org.opencv.core.Core
import org.opencv.core.Mat

class RecognitionHandler(
  symbols: String,
  languageDictPath: String,
  reactApplicationContext: ReactApplicationContext
) {
  private val recognizerLarge = Recognizer(reactApplicationContext)
  private val recognizerMedium = Recognizer(reactApplicationContext)
  private val recognizerSmall = Recognizer(reactApplicationContext)
  private val converter = CTCLabelConverter(symbols, mapOf(languageDictPath to "key"))

  private fun runModel(croppedImage: Mat): Pair<List<Int>, Double> {
    val result: Pair<List<Int>, Double> = if (croppedImage.cols() >= Constants.LARGE_MODEL_WIDTH) {
      recognizerLarge.runModel(croppedImage)
    } else if (croppedImage.cols() >= Constants.MEDIUM_MODEL_WIDTH) {
      recognizerMedium.runModel(croppedImage)
    } else {
      recognizerSmall.runModel(croppedImage)
    }

    return result
  }

  fun loadRecognizers(
    largeRecognizerPath: String,
    mediumRecognizerPath: String,
    smallRecognizerPath: String,
    onComplete: (Int, Exception?) -> Unit
  ) {
    try {
      recognizerLarge.loadModel(largeRecognizerPath)
      recognizerMedium.loadModel(mediumRecognizerPath)
      recognizerSmall.loadModel(smallRecognizerPath)
      onComplete(0, null)
    } catch (e: Exception) {
      onComplete(1, e)
    }
  }

  fun recognize(
    bBoxesList: List<OCRbBox>,
    imgGray: Mat,
    desiredWidth: Int,
    desiredHeight: Int
  ): WritableArray {
    val res: WritableArray = Arguments.createArray()
    val ratioAndPadding = RecognizerUtils.calculateResizeRatioAndPaddings(
      imgGray.width(),
      imgGray.height(),
      desiredWidth,
      desiredHeight
    )

    val left = ratioAndPadding["left"] as Int
    val top = ratioAndPadding["top"] as Int
    val resizeRatio = ratioAndPadding["resizeRatio"] as Float
    val resizedImg = ImageProcessor.resizeWithPadding(
      imgGray,
      desiredWidth,
      desiredHeight
    )

    for (box in bBoxesList) {
      var croppedImage = RecognizerUtils.getCroppedImage(box, resizedImg, Constants.MODEL_HEIGHT)
      if (croppedImage.empty()) {
        continue
      }

      croppedImage = RecognizerUtils.normalizeForRecognizer(croppedImage, Constants.ADJUST_CONTRAST)

      var result = runModel(croppedImage)
      var confidenceScore = result.second

      if (confidenceScore < Constants.LOW_CONFIDENCE_THRESHOLD) {
        Core.rotate(croppedImage, croppedImage, Core.ROTATE_180)
        val rotatedResult = runModel(croppedImage)
        val rotatedConfidenceScore = rotatedResult.second
        if (rotatedConfidenceScore > confidenceScore) {
          result = rotatedResult
          confidenceScore = rotatedConfidenceScore
        }
      }

      val predIndex = result.first
      val decodedTexts = converter.decodeGreedy(predIndex, predIndex.size)

      for (bBox in box.bBox) {
        bBox.x = (bBox.x - left) * resizeRatio
        bBox.y = (bBox.y - top) * resizeRatio
      }

      val resMap = Arguments.createMap()

      resMap.putString("text", decodedTexts[0])
      resMap.putArray("bbox", box.toWritableArray())
      resMap.putDouble("confidence", confidenceScore)

      res.pushMap(resMap)
    }

    return res
  }
}
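After each crop is run through the appropriate recognizer, the predicted index sequence is turned into text with CTCLabelConverter.decodeGreedy. That converter is not part of this diff; as a hedged illustration only, greedy CTC decoding conventionally collapses repeated indices and drops the blank class before mapping indices to characters, roughly like the sketch below (decodeGreedySketch and its blank-at-index-0 convention are assumptions, and the real converter additionally returns a list of texts and consults the language dictionary):

// Illustrative only (assumption): greedy CTC decoding collapses runs of the same
// index and skips the blank class (index 0) before mapping to the symbol set.
fun decodeGreedySketch(indices: List<Int>, symbols: String): String {
  val sb = StringBuilder()
  var previous = -1
  for (index in indices) {
    // Skip the CTC blank and collapse consecutive repeats of the same index.
    if (index != 0 && index != previous) {
      // In this sketch, index 1 maps to the first character of the symbol set.
      sb.append(symbols[index - 1])
    }
    previous = index
  }
  return sb.toString()
}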
android/src/main/java/com/swmansion/rnexecutorch/models/ocr/Recognizer.kt

@@ -0,0 +1,56 @@
package com.swmansion.rnexecutorch.models.ocr

import com.facebook.react.bridge.ReactApplicationContext
import com.swmansion.rnexecutorch.models.BaseModel
import com.swmansion.rnexecutorch.models.ocr.utils.RecognizerUtils
import com.swmansion.rnexecutorch.utils.ImageProcessor
import org.opencv.core.Mat
import org.opencv.core.Size
import org.pytorch.executorch.EValue

class Recognizer(reactApplicationContext: ReactApplicationContext) :
  BaseModel<Mat, Pair<List<Int>, Double>>(reactApplicationContext) {

  private fun getModelOutputSize(): Size {
    val outputShape = module.getOutputShape(0)
    val width = outputShape[outputShape.lastIndex]
    val height = outputShape[outputShape.lastIndex - 1]

    return Size(height.toDouble(), width.toDouble())
  }

  override fun preprocess(input: Mat): EValue {
    return ImageProcessor.matToEValueGray(input)
  }

  override fun postprocess(output: Array<EValue>): Pair<List<Int>, Double> {
    val modelOutputHeight = getModelOutputSize().height.toInt()
    val tensor = output[0].toTensor().dataAsFloatArray
    val numElements = tensor.size
    val numRows = (numElements + modelOutputHeight - 1) / modelOutputHeight
    val resultMat = Mat(numRows, modelOutputHeight, org.opencv.core.CvType.CV_32F)
    var counter = 0
    var currentRow = 0
    for (num in tensor) {
      resultMat.put(currentRow, counter, floatArrayOf(num))
      counter++
      if (counter >= modelOutputHeight) {
        counter = 0
        currentRow++
      }
    }

    var probabilities = RecognizerUtils.softmax(resultMat)
    val predsNorm = RecognizerUtils.sumProbabilityRows(probabilities, modelOutputHeight)
    probabilities = RecognizerUtils.divideMatrixByVector(probabilities, predsNorm)
    val (values, indices) = RecognizerUtils.findMaxValuesAndIndices(probabilities)

    val confidenceScore = RecognizerUtils.computeConfidenceScore(values, indices)
    return Pair(indices, confidenceScore)
  }

  override fun runModel(input: Mat): Pair<List<Int>, Double> {
    return postprocess(module.forward(preprocess(input)))
  }
}
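postprocess reshapes the flat output tensor into a [timesteps, classes] matrix, applies RecognizerUtils.softmax row by row, renormalizes, and takes the per-row argmax before computing a confidence score. The RecognizerUtils helpers live outside this diff; a minimal sketch of a numerically stable row-wise softmax, written here only to show the assumed behaviour of RecognizerUtils.softmax (softmaxSketch is a hypothetical name), would be:

// Rough sketch (assumption): row-wise softmax over a CV_32F Mat shaped [timesteps, classes].
import org.opencv.core.CvType
import org.opencv.core.Mat
import kotlin.math.exp

fun softmaxSketch(input: Mat): Mat {
  val result = Mat(input.rows(), input.cols(), CvType.CV_32F)
  for (r in 0 until input.rows()) {
    val row = FloatArray(input.cols())
    input.get(r, 0, row)
    // Subtract the row maximum before exponentiating for numerical stability.
    val max = row.maxOrNull() ?: 0f
    val exps = row.map { exp((it - max).toDouble()) }
    val sum = exps.sum()
    val normalized = FloatArray(row.size) { i -> (exps[i] / sum).toFloat() }
    result.put(r, 0, normalized)
  }
  return result
}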
