Ticker

6/recent/ticker-posts

Android Handwritten Digit Prediction App Using a Machine Learning Model in Kotlin

In this post, we will build an Android app in Kotlin that recognizes handwritten digits using a TensorFlow Lite machine learning model.



Technology used:

  1. Kotlin
  2. Machine Learning
  3. TensorFlow (TensorFlow Lite)


Dataset

The MNIST dataset is used to create the TFLite model.

Dataset demo


Python code: CLICK HERE to open the model-training notebook in Google Colab.

First, train and export the TFLite model, then copy it into the app's assets folder in Android Studio (the code below loads it as mnist.tflite).



Android

1. activity_main.xml

<?xml version="1.0" encoding="utf-8"?>
<!-- Main screen: result panel on top, a 280dp square drawing canvas in the middle,
     and the Detect / Clear buttons at the bottom. -->
<LinearLayout
    xmlns:android="http://schemas.android.com/apk/res/android"
    xmlns:tools="http://schemas.android.com/tools"
    android:layout_width="match_parent"
    android:layout_height="match_parent"
    android:orientation="vertical"
    android:background="@android:color/black"
    tools:context="com.codewithgolap.tflite.mnist.MainActivity">
    <!-- Result panel: predicted digit, then a header row and a value row for
         probability (tv_probability) and inference time (tv_timecost). -->
    <TableLayout
        android:layout_width="match_parent"
        android:layout_height="wrap_content"
        android:padding="16dp"
        android:background="@android:color/white">
        <TextView
            style="@style/ResultText"
            android:fontFamily="@font/poppins_bold"
            android:text="@string/prediction"
            android:textColor="@android:color/black"
            android:textSize="19sp"
            android:letterSpacing="0.05"/>
        <TextView
            android:id="@+id/tv_prediction"
            style="@style/ResultText"
            android:text="@string/empty"
            android:textColor="@android:color/black"
            android:textSize="24sp"/>
        <TableRow
            android:layout_marginTop="16dp">
            <TextView
                style="@style/ResultText"
                android:text="@string/probability"
                android:background="#FFC107"
                android:textColor="@color/colorPrimary"
                android:fontFamily="@font/poppins_medium"
                android:letterSpacing="0.02"
                android:layout_marginEnd="2dp"/>
            <TextView
                style="@style/ResultText"
                android:text="@string/timecost"
                android:background="#FFC107"
                android:textColor="@color/colorPrimary"
                android:fontFamily="@font/poppins_medium"
                android:letterSpacing="0.02"
                android:layout_marginStart="2dp"/>
        </TableRow>

        <TableRow>
            <TextView
                android:id="@+id/tv_probability"
                style="@style/ResultText"
                android:text="@string/empty"
                android:background="#FFEB3B"
                android:textColor="@color/colorPrimary"
                android:fontFamily="@font/poppins_medium"
                android:letterSpacing="0.02"
                android:layout_marginEnd="2dp"/>
            <TextView
                android:id="@+id/tv_timecost"
                style="@style/ResultText"
                android:text="@string/empty"
                android:background="#FFEB3B"
                android:textColor="@color/colorPrimary"
                android:fontFamily="@font/poppins_medium"
                android:letterSpacing="0.02"
                android:layout_marginStart="2dp"/>
        </TableRow>
    </TableLayout>
    <!-- Drawing canvas; MainActivity exports its content as a bitmap for classification. -->
    <com.nex3z.fingerpaintview.FingerPaintView
        android:id="@+id/fingerPaintView"
        android:layout_width="280dp"
        android:layout_height="280dp"
        android:layout_marginTop="48dp"
        android:layout_gravity="center"
        android:background="@android:color/white"
        android:foreground="@drawable/shape_rect_border"/>
    <!-- Two equal-width buttons. With layout_weight, a width of 0dp lets the weights alone
         size the children and avoids an extra measure pass (lint: InefficientWeight). -->
    <LinearLayout
        android:layout_width="match_parent"
        android:layout_height="wrap_content"
        android:orientation="horizontal"
        android:paddingStart="16dp"
        android:paddingEnd="16dp"
        android:layout_marginTop="48dp">
        <Button
            android:layout_width="0dp"
            android:layout_height="wrap_content"
            android:id="@+id/btn_detect"
            android:text="@string/detect"
            android:layout_weight="1"
            android:textSize="16sp"
            android:fontFamily="@font/poppins_medium"/>
        <Button
            android:layout_width="0dp"
            android:layout_height="wrap_content"
            android:id="@+id/btn_clear"
            android:text="@string/clear"
            android:layout_weight="1"
            android:textSize="16sp"
            android:fontFamily="@font/poppins_medium"/>
    </LinearLayout>
</LinearLayout>

 

2. Classifier.kt

package com.codewithgolap.tflite.mnist
import android.content.Context
import android.graphics.Bitmap
import android.os.SystemClock
import android.util.Log
import android.util.Size
import org.tensorflow.lite.Delegate
import org.tensorflow.lite.Interpreter
import org.tensorflow.lite.Tensor
import org.tensorflow.lite.gpu.GpuDelegate
import org.tensorflow.lite.nnapi.NnApiDelegate
import org.tensorflow.lite.support.common.FileUtil
import org.tensorflow.lite.support.tensorbuffer.TensorBuffer
import java.io.Closeable
import java.nio.ByteBuffer
import java.nio.ByteOrder
/**
 * Wraps a TensorFlow Lite interpreter for the bundled digit model ([MODEL_FILE_NAME] in assets).
 *
 * Each call to [classify] converts a bitmap to inverted grayscale floats in [0, 1]
 * (white -> 0.0, black -> 1.0) and returns the index of the highest model output.
 * The instance owns native resources; call [close] when it is no longer needed.
 *
 * @param context used to memory-map the model file from the app's assets.
 * @param device inference backend; [Device.CPU] runs without a delegate.
 * @param numThreads number of CPU threads handed to the interpreter.
 */
class Classifier(
    context: Context,
    device: Device = Device.CPU,
    numThreads: Int = 4
) {
    private val delegate: Delegate? = when (device) {
        Device.CPU -> null
        Device.NNAPI -> NnApiDelegate()
        Device.GPU -> GpuDelegate()
    }

    // If building the interpreter fails (missing asset, invalid model, unsupported delegate),
    // close() can never be called, so release the delegate here before propagating the error.
    private val interpreter: Interpreter = try {
        Interpreter(
            FileUtil.loadMappedFile(context, MODEL_FILE_NAME),
            Interpreter.Options().apply {
                setNumThreads(numThreads)
                delegate?.let { addDelegate(it) }
            }
        )
    } catch (e: Exception) {
        (delegate as? AutoCloseable)?.close()
        throw e
    }

    private val inputTensor: Tensor = interpreter.getInputTensor(0)
    private val outputTensor: Tensor = interpreter.getOutputTensor(0)

    /**
     * Image size the model expects: width from input-tensor dimension 2, height from dimension 1.
     * NOTE(review): assumes a [batch, height, width(, channels)] layout; confirm against the model.
     */
    val inputShape: Size = with(inputTensor.shape()) { Size(this[2], this[1]) }

    // Scratch storage reused by every classify() call: one ARGB int per pixel, then one float
    // per pixel (native byte order) for the interpreter.
    private val imagePixels = IntArray(inputShape.height * inputShape.width)
    private val imageBuffer: ByteBuffer =
        ByteBuffer.allocateDirect(BYTES_PER_FLOAT * inputShape.height * inputShape.width).apply {
            order(ByteOrder.nativeOrder())
        }
    private val outputBuffer: TensorBuffer =
        TensorBuffer.createFixedSize(outputTensor.shape(), outputTensor.dataType())

    init {
        Log.v(
            LOG_TAG, "[Input] shape = ${inputTensor.shape()?.contentToString()}, " +
                    "dataType = ${inputTensor.dataType()}")
        Log.v(
            LOG_TAG, "[Output] shape = ${outputTensor.shape()?.contentToString()}, " +
                    "dataType = ${outputTensor.dataType()}")
    }

    /**
     * Runs the model on [image] and returns the top label, its output value and the
     * inference time in milliseconds.
     *
     * [image] is scaled to [inputShape] if its size differs. Internal buffers are reused
     * without synchronization, so do not call this concurrently from multiple threads.
     */
    fun classify(image: Bitmap): Recognition {
        convertBitmapToByteBuffer(image)
        val start = SystemClock.uptimeMillis()
        interpreter.run(imageBuffer, outputBuffer.buffer.rewind())
        val timeCost = SystemClock.uptimeMillis() - start
        val probs = outputBuffer.floatArray
        val top = probs.argMax()
        Log.v(LOG_TAG, "classify(): timeCost = $timeCost, top = $top, probs = ${probs.contentToString()}")
        return Recognition(top, probs[top], timeCost)
    }

    /** Releases the interpreter and then the delegate, if one was created. */
    fun close() {
        interpreter.close()
        // Check AutoCloseable rather than Closeable: NnApiDelegate is declared AutoCloseable
        // only, so a Closeable check would skip it and leak its native handle.
        (delegate as? AutoCloseable)?.close()
    }

    /** Fills [imageBuffer] with one converted float per pixel of [bitmap], row by row. */
    private fun convertBitmapToByteBuffer(bitmap: Bitmap) {
        val width = inputShape.width
        val height = inputShape.height
        // getPixels() lays rows out using the stride and size we pass. A bitmap of a different
        // size would leave stale pixels from a previous call or overflow imagePixels, so
        // normalize to the model's input size first.
        val source = if (bitmap.width == width && bitmap.height == height) {
            bitmap
        } else {
            Bitmap.createScaledBitmap(bitmap, width, height, true)
        }
        source.getPixels(imagePixels, 0, width, 0, 0, width, height)
        if (source !== bitmap) {
            source.recycle()
        }
        imageBuffer.rewind()
        for (pixel in imagePixels) {
            imageBuffer.putFloat(convertPixel(pixel))
        }
        // Leave the buffer at position 0 so consumers see the whole payload.
        imageBuffer.rewind()
    }

    /** Converts an ARGB pixel to inverted luma in [0, 1] (BT.601 weights; white -> 0). */
    private fun convertPixel(color: Int): Float {
        return (255 - ((color shr 16 and 0xFF) * 0.299f
                + (color shr 8 and 0xFF) * 0.587f
                + (color and 0xFF) * 0.114f)) / 255.0f
    }

    companion object {
        private val LOG_TAG: String = Classifier::class.java.simpleName
        private const val MODEL_FILE_NAME: String = "mnist.tflite"
        // imageBuffer holds 32-bit floats.
        private const val BYTES_PER_FLOAT = 4
    }
}
/**
 * Returns the index of the largest element. When several elements tie for the maximum,
 * the earliest index is returned.
 *
 * @throws IllegalArgumentException if the array is empty.
 */
fun FloatArray.argMax(): Int {
    val bestIndex: Int? = indices.maxByOrNull { idx -> this[idx] }
    return bestIndex ?: throw IllegalArgumentException("Cannot find arg max in empty list")
}


3. Device.kt

package com.codewithgolap.tflite.mnist
/**
 * Inference backend for [Classifier]: CPU runs without a delegate, NNAPI uses
 * NnApiDelegate, GPU uses GpuDelegate. Declaration order fixes each constant's ordinal.
 */
enum class Device { CPU, NNAPI, GPU }



4. Recognition.kt

package com.codewithgolap.tflite.mnist
/**
 * Outcome of a single classification.
 *
 * @property label index of the highest model output (the predicted digit).
 * @property confidence model output value at [label].
 * @property timeCost inference time in milliseconds.
 */
data class Recognition(val label: Int, val confidence: Float, val timeCost: Long)


5. MainActivity.kt

package com.codewithgolap.tflite.mnist
import android.graphics.Bitmap
import android.os.Bundle
import android.util.Log
import android.widget.Toast
import androidx.appcompat.app.AppCompatActivity
import kotlinx.android.synthetic.main.activity_main.*
import java.io.IOException
/**
 * Screen where the user draws a digit and gets it classified by [Classifier].
 *
 * NOTE: views are accessed via Kotlin synthetic properties (kotlinx.android.synthetic), which are
 * deprecated; consider migrating to ViewBinding.
 */
class MainActivity : AppCompatActivity() {
    // Created in onCreate(); stays uninitialized if the model could not be loaded.
    private lateinit var classifier: Classifier

    override fun onCreate(savedInstanceState: Bundle?) {
        super.onCreate(savedInstanceState)
        setContentView(R.layout.activity_main)
        init()
    }

    private fun init() {
        initClassifier()
        initView()
    }

    // Loads the model bundled in assets. On failure, reports it and leaves `classifier`
    // uninitialized so the detect button can explain the problem instead of crashing.
    private fun initClassifier() {
        try {
            classifier = Classifier(this)
            Log.v(LOG_TAG, "Classifier initialized")
        } catch (e: IOException) {
            onClassifierInitFailed(e)
        } catch (e: IllegalArgumentException) {
            // The TFLite Interpreter rejects an invalid or incompatible model this way.
            onClassifierInitFailed(e)
        }
    }

    private fun onClassifierInitFailed(e: Exception) {
        Toast.makeText(this, R.string.failed_to_create_classifier, Toast.LENGTH_LONG).show()
        Log.e(LOG_TAG, "init(): Failed to create Classifier", e)
    }

    // Wire up the Detect and Clear buttons.
    private fun initView() {
        btn_detect.setOnClickListener { onDetectClick() }
        btn_clear.setOnClickListener { clearResult() }
    }

    // Exports the drawing at the model's input size, classifies it, and shows the result.
    private fun onDetectClick() {
        when {
            !this::classifier.isInitialized -> {
                Log.e(LOG_TAG, "onDetectClick(): Classifier is not initialized")
                Toast.makeText(this, R.string.failed_to_create_classifier, Toast.LENGTH_SHORT).show()
            }
            fingerPaintView.isEmpty ->
                Toast.makeText(this, R.string.please_write_a_digit, Toast.LENGTH_SHORT).show()
            else -> {
                val image: Bitmap = fingerPaintView.exportToBitmap(
                    classifier.inputShape.width, classifier.inputShape.height
                )
                renderResult(classifier.classify(image))
            }
        }
    }

    // Shows the predicted digit (label), its confidence, and the inference time.
    private fun renderResult(result: Recognition) {
        tv_prediction.text = result.label.toString()
        tv_probability.text = result.confidence.toString()
        tv_timecost.text = getString(R.string.timecost_value, result.timeCost)
    }

    // Clears the drawing and all result fields.
    private fun clearResult() {
        fingerPaintView.clear()
        tv_prediction.setText(R.string.empty)
        tv_probability.setText(R.string.empty)
        tv_timecost.setText(R.string.empty)
    }

    override fun onDestroy() {
        // Release the interpreter's native resources. Guarded because the classifier may have
        // failed to load, and reading an uninitialized lateinit property throws.
        if (this::classifier.isInitialized) {
            classifier.close()
        }
        super.onDestroy()
    }

    companion object {
        private val LOG_TAG: String = MainActivity::class.java.simpleName
    }
}




Watch the full video for a step-by-step walkthrough.








Post a Comment

0 Comments