
Recognizing Handwritten digits from a JavaFX application using Deeplearning4j

 2 years ago
source link: http://fxapps.blogspot.com/2017/06/recognizing-handwritten-digits-from.html

We have already talked about TensorFlow and JavaFX on this blog, but the TensorFlow Java API is still incomplete. DeepLearning4J, on the other hand, offers a mature and well-documented API.


In this example we load the trained model in our application and create a canvas for drawing. When Enter is pressed, the canvas image is resized and sent to the Deeplearning4j trained model for recognition:

[Screenshots of the application, taken 2017-06-05]

The way it "guesses" the digit is like the "hello world" of deep neural networks. An artificial neuron is loosely modeled on a neuron in the human brain: it has weights on its inputs, which control when the neuron is activated. A neural network consists of many neurons linked to each other and organized in layers. We provide labeled data to the network and adjust the weights of its neurons until it correctly predicts the values for the given data. This is called training.
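
To make the idea of "adjusting weights" concrete, here is a minimal sketch in plain Java of a single neuron trained with the classic perceptron rule. It is only an illustration: the class and method names are made up for this example, and Deeplearning4j trains its networks with gradient-based methods rather than this rule.

```java
import java.util.Arrays;

// Hypothetical illustration: one neuron with a step activation.
class NeuronSketch {

    // Weighted sum of the inputs plus a bias; the neuron "fires" (returns 1) when the sum is positive.
    static int activate(double[] weights, double bias, double[] inputs) {
        double sum = bias;
        for (int i = 0; i < weights.length; i++) {
            sum += weights[i] * inputs[i];
        }
        return sum > 0 ? 1 : 0;
    }

    // Perceptron rule: after every wrong answer, nudge the weights and the bias
    // in the direction that reduces the error.
    // Returns the learned weights followed by the bias as the last element.
    static double[] train(double[][] samples, int[] labels, int epochs, double learningRate) {
        double[] weights = new double[samples[0].length];
        double bias = 0.0;
        for (int epoch = 0; epoch < epochs; epoch++) {
            for (int s = 0; s < samples.length; s++) {
                int error = labels[s] - activate(weights, bias, samples[s]);
                for (int i = 0; i < weights.length; i++) {
                    weights[i] += learningRate * error * samples[s][i];
                }
                bias += learningRate * error;
            }
        }
        double[] result = Arrays.copyOf(weights, weights.length + 1);
        result[weights.length] = bias;
        return result;
    }
}
```

Trained on the four input pairs of a logical AND with their labels, this neuron ends up with weights that reproduce the labels. A real network does the same kind of thing with many neurons at once.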


Once it is trained, we test the neural network against known labeled data to measure how accurate it is (in our case the accuracy is 97.5%!). The data we use is the famous MNIST database of handwritten digits.
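
The measurement itself is simple arithmetic: count how many predictions match the known labels and divide by the total. A small sketch, with a hypothetical helper name (Deeplearning4j has its own evaluation classes, which are not shown here):

```java
// Hypothetical helper: the fraction of predictions that match the known labels.
class AccuracySketch {

    static double accuracy(int[] predicted, int[] expected) {
        if (predicted.length == 0 || predicted.length != expected.length) {
            throw new IllegalArgumentException("arrays must be non-empty and of equal length");
        }
        int correct = 0;
        for (int i = 0; i < predicted.length; i++) {
            if (predicted[i] == expected[i]) {
                correct++;
            }
        }
        return (double) correct / predicted.length;
    }
}
```

For example, three correct answers out of four give an accuracy of 0.75.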


Because it has hidden layers between the input layer (where we feed in our data) and the output layer (where we get our predictions), we call it a deep neural network. There are many other concepts and types of neural networks; I encourage you to watch some videos about the subject on YouTube.
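
As a rough sketch of what "input layer, hidden layer, output layer" means in code, the plain Java below pushes an input through one hidden layer and one output layer. The layer sizes, the ReLU and softmax activations, and all names are assumptions made for illustration. They are not the architecture of the trained model used in this post.

```java
// Hypothetical illustration of a forward pass: input -> hidden layer -> output layer.
class ForwardPassSketch {

    // Fully connected layer: out[j] = b[j] + sum over i of in[i] * w[i][j]
    static double[] dense(double[] in, double[][] w, double[] b) {
        double[] out = b.clone();
        for (int i = 0; i < in.length; i++) {
            for (int j = 0; j < out.length; j++) {
                out[j] += in[i] * w[i][j];
            }
        }
        return out;
    }

    // ReLU activation: negative values become zero.
    static double[] relu(double[] v) {
        double[] out = new double[v.length];
        for (int i = 0; i < v.length; i++) {
            out[i] = Math.max(0.0, v[i]);
        }
        return out;
    }

    // Softmax: turns raw scores into values between 0 and 1 that sum to 1.
    static double[] softmax(double[] v) {
        double max = Double.NEGATIVE_INFINITY;
        for (double x : v) {
            max = Math.max(max, x);
        }
        double[] out = new double[v.length];
        double sum = 0.0;
        for (int i = 0; i < v.length; i++) {
            out[i] = Math.exp(v[i] - max); // subtracting the max avoids overflow
            sum += out[i];
        }
        for (int i = 0; i < out.length; i++) {
            out[i] /= sum;
        }
        return out;
    }

    // The hidden layer sits between the input and the output layer.
    static double[] predict(double[] input, double[][] w1, double[] b1, double[][] w2, double[] b2) {
        double[] hidden = relu(dense(input, w1, b1));
        return softmax(dense(hidden, w2, b2));
    }
}
```

For digit recognition the input would have 784 values (28x28 pixels) and the output 10 values, one per digit.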

And if this is the first time you are reading about this stuff, be aware that it won't be the last!

If you try the code you may find that it is not as accurate as this web application, for example. The reason is that I didn't process the image carefully before sending it for prediction: we just resize it to 28x28 pixels, as required by our trained model.
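
One possible improvement, sketched below, is to crop the drawing to its bounding box, scale it to fit a 20x20 box, and center it in the 28x28 image. This roughly imitates how the MNIST digits were prepared (MNIST centers by center of mass, while this sketch uses the simpler bounding-box center). The class name, the brightness threshold and the method names are assumptions for this example, and I have not measured how much it helps with the model from this post.

```java
import java.awt.Graphics2D;
import java.awt.RenderingHints;
import java.awt.image.BufferedImage;

// Hypothetical helper: crop a white-on-black drawing and center it in a 28x28 image.
class DigitNormalizerSketch {

    static final int TARGET = 28; // image size expected by the model
    static final int BOX = 20;    // the digit itself is scaled to fit a 20x20 box

    // Returns {minX, minY, maxX, maxY} of the pixels brighter than the threshold,
    // or null when the image contains no such pixel.
    static int[] boundingBox(BufferedImage img, int threshold) {
        int minX = img.getWidth(), minY = img.getHeight(), maxX = -1, maxY = -1;
        for (int y = 0; y < img.getHeight(); y++) {
            for (int x = 0; x < img.getWidth(); x++) {
                int rgb = img.getRGB(x, y);
                int brightness = (((rgb >> 16) & 0xFF) + ((rgb >> 8) & 0xFF) + (rgb & 0xFF)) / 3;
                if (brightness > threshold) {
                    minX = Math.min(minX, x);
                    minY = Math.min(minY, y);
                    maxX = Math.max(maxX, x);
                    maxY = Math.max(maxY, y);
                }
            }
        }
        return maxX < 0 ? null : new int[] {minX, minY, maxX, maxY};
    }

    static BufferedImage normalize(BufferedImage src) {
        // a new TYPE_BYTE_GRAY image starts out completely black
        BufferedImage out = new BufferedImage(TARGET, TARGET, BufferedImage.TYPE_BYTE_GRAY);
        int[] box = boundingBox(src, 32);
        if (box == null) {
            return out; // nothing was drawn
        }
        int w = box[2] - box[0] + 1;
        int h = box[3] - box[1] + 1;
        double scale = (double) BOX / Math.max(w, h); // keep the aspect ratio
        int dw = Math.max(1, (int) Math.round(w * scale));
        int dh = Math.max(1, (int) Math.round(h * scale));
        int dx = (TARGET - dw) / 2;
        int dy = (TARGET - dh) / 2;
        Graphics2D g = out.createGraphics();
        g.setRenderingHint(RenderingHints.KEY_INTERPOLATION,
                RenderingHints.VALUE_INTERPOLATION_BILINEAR);
        // copy only the bounding box of the drawing into the centered destination rectangle
        g.drawImage(src, dx, dy, dx + dw, dy + dh, box[0], box[1], box[2] + 1, box[3] + 1, null);
        g.dispose();
        return out;
    }
}
```

In the application this could replace the plain resize: take the canvas snapshot as a BufferedImage and pass it to normalize before calling the prediction.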

The code of the JavaFX application is below and the full project is on my GitHub, including the training Java code, which was created using the Deeplearning4j examples.

package org.fxapps.deeplearning;

import java.awt.Graphics;
import java.awt.Image;
import java.awt.image.BufferedImage;
import java.io.File;
import java.io.IOException;

import org.datavec.image.loader.NativeImageLoader;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.util.ModelSerializer;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.api.preprocessor.ImagePreProcessingScaler;

import javafx.application.Application;
import javafx.embed.swing.SwingFXUtils;
import javafx.geometry.Pos;
import javafx.scene.Scene;
import javafx.scene.canvas.Canvas;
import javafx.scene.canvas.GraphicsContext;
import javafx.scene.control.Label;
import javafx.scene.image.ImageView;
import javafx.scene.image.WritableImage;
import javafx.scene.input.KeyCode;
import javafx.scene.input.MouseButton;
import javafx.scene.layout.HBox;
import javafx.scene.layout.VBox;
import javafx.scene.paint.Color;
import javafx.scene.shape.StrokeLineCap;
import javafx.stage.Stage;

public class MnistTestFXApp extends Application {

    private final int CANVAS_WIDTH = 150;
    private final int CANVAS_HEIGHT = 150;
    private NativeImageLoader loader;
    private MultiLayerNetwork model;
    private Label lblResult;

    public static void main(String[] args) throws IOException {
        launch();
    }

    @Override
    public void start(Stage stage) throws Exception {
        Canvas canvas = new Canvas(CANVAS_WIDTH, CANVAS_HEIGHT);
        ImageView imgView = new ImageView();
        GraphicsContext ctx = canvas.getGraphicsContext2D();

        // load the trained model and a loader that produces 28x28 single-channel input
        model = ModelSerializer.restoreMultiLayerNetwork(new File("minist-model.zip"));
        loader = new NativeImageLoader(28, 28, 1, true);

        imgView.setFitHeight(100);
        imgView.setFitWidth(100);
        ctx.setLineWidth(10);
        ctx.setLineCap(StrokeLineCap.SQUARE);
        lblResult = new Label();

        HBox hbBottom = new HBox(10, imgView, lblResult);
        VBox root = new VBox(5, canvas, hbBottom);
        hbBottom.setAlignment(Pos.CENTER);
        root.setAlignment(Pos.CENTER);

        Scene scene = new Scene(root, 520, 300);
        stage.setScene(scene);
        stage.show();
        stage.setTitle("Handwritten digits recognition");

        // draw in white while the mouse is pressed and dragged
        canvas.setOnMousePressed(e -> {
            ctx.setStroke(Color.WHITE);
            ctx.beginPath();
            ctx.moveTo(e.getX(), e.getY());
            ctx.stroke();
        });
        canvas.setOnMouseDragged(e -> {
            ctx.setStroke(Color.WHITE);
            ctx.lineTo(e.getX(), e.getY());
            ctx.stroke();
        });

        // a right click clears the canvas
        canvas.setOnMouseClicked(e -> {
            if (e.getButton() == MouseButton.SECONDARY) {
                clear(ctx);
            }
        });

        // Enter scales the drawing, shows the scaled image and runs the prediction
        canvas.setOnKeyReleased(e -> {
            if (e.getCode() == KeyCode.ENTER) {
                BufferedImage scaledImg = getScaledImage(canvas);
                imgView.setImage(SwingFXUtils.toFXImage(scaledImg, null));
                try {
                    predictImage(scaledImg);
                } catch (Exception e1) {
                    e1.printStackTrace();
                }
            }
        });

        clear(ctx);
        canvas.requestFocus();
    }

    private BufferedImage getScaledImage(Canvas canvas) {
        // for a better recognition we should improve this part of how we retrieve the image from the canvas
        WritableImage writableImage = new WritableImage(CANVAS_WIDTH, CANVAS_HEIGHT);
        canvas.snapshot(null, writableImage);
        Image tmp = SwingFXUtils.fromFXImage(writableImage, null).getScaledInstance(28, 28, Image.SCALE_SMOOTH);
        BufferedImage scaledImg = new BufferedImage(28, 28, BufferedImage.TYPE_BYTE_GRAY);
        Graphics graphics = scaledImg.getGraphics();
        graphics.drawImage(tmp, 0, 0, null);
        graphics.dispose();
        return scaledImg;
    }

    private void clear(GraphicsContext ctx) {
        // paint the whole canvas black
        ctx.setFill(Color.BLACK);
        ctx.fillRect(0, 0, CANVAS_WIDTH, CANVAS_HEIGHT);
    }

    private void predictImage(BufferedImage img) throws IOException {
        // scale the pixel values to the range 0..1 before feeding them to the model
        ImagePreProcessingScaler imagePreProcessingScaler = new ImagePreProcessingScaler(0, 1);
        INDArray image = loader.asRowVector(img);
        imagePreProcessingScaler.transform(image);
        INDArray output = model.output(image);
        String putStr = output.toString();
        lblResult.setText("Prediction: " + model.predict(image)[0] + "\n " + putStr);
    }

}
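
The label shows two things: the predicted digit and the raw output of the network, which has one value per digit. As far as I understand it, the predicted digit is simply the index of the largest output value. The plain Java sketch below shows that step on its own, with a hypothetical helper name. It does not use the Deeplearning4j API.

```java
// Hypothetical helper: the index of the largest value, i.e. the most likely digit.
class ArgMaxSketch {

    static int argMax(double[] scores) {
        if (scores.length == 0) {
            throw new IllegalArgumentException("scores must not be empty");
        }
        int best = 0;
        for (int i = 1; i < scores.length; i++) {
            if (scores[i] > scores[best]) {
                best = i;
            }
        }
        return best;
    }
}
```

With ten output values, an index of 7 means the network considers "7" the most likely digit.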

