Scala-based Keras API for the Java bindings to TensorFlow




This is a project for testing a Scala-based Keras API for the Java bindings to TensorFlow. It is a fork of a Java-based project whose original author is Dhruv Rajan; the original project was released under the MIT license.

This project is (C)opyright 2021 by Hanns Holger Rutz. All rights reserved. The project is released under the GNU Lesser General Public License v2.1+ and comes with absolutely no warranties. To contact the author, send an e-mail to contact at

requirements / installation

The project builds with sbt. You may need to edit build.sbt to adjust tfClassifer.

Below are the original README notes.

Tensorflow-Keras (Java)

This repository contains a JVM implementation of the Keras API, built on Tensorflow Java. Keras is a high-level API for building and training deep learning models. This implementation aims to mirror the Python tf-keras syntax in a clean Java-friendly style.

It is a work-in-progress; feel free to add issues / comments! Happy to collaborate.


Below is a comparison of the Python, Java, and Scala implementations of a model trained for the MNIST dataset (the Scala example uses Fashion-MNIST).

Python

import tensorflow as tf

model = tf.keras.models.Sequential([
  tf.keras.layers.Flatten(input_shape=(28, 28)),
  tf.keras.layers.Dense(128, activation='relu', kernel_initializer="random_normal", bias_initializer="zeros"),
  tf.keras.layers.Dense(10, activation='softmax', kernel_initializer="random_normal", bias_initializer="zeros")
])

model.compile(optimizer='sgd', loss='sparse_categorical_crossentropy', metrics=['accuracy'])

(X_train, y_train), (X_val, y_val) = tf.keras.datasets.mnist.load_data()
model.fit(X_train, y_train, validation_data=(X_val, y_val), epochs=10, batch_size=100)

Java (

package org.tensorflow.keras.examples.mnist;

import org.tensorflow.Graph;
import org.tensorflow.keras.activations.Activations;
import org.tensorflow.keras.datasets.MNIST;
import org.tensorflow.keras.initializers.Initializers;
import org.tensorflow.keras.layers.Dense;
import org.tensorflow.keras.layers.Layers;
import org.tensorflow.keras.losses.Losses;
import org.tensorflow.keras.metrics.Metrics;
import org.tensorflow.keras.models.Model;
import org.tensorflow.keras.models.Sequential;
import org.tensorflow.keras.optimizers.Optimizers;
import org.tensorflow.op.Ops;
import org.tensorflow.utils.Pair;

public class MNISTKeras {
    private static Model<Float> model;
    private static Model.CompileOptions compileOptions;
    private static Model.FitOptions fitOptions;

    static {
        // Define Neural Network Model
        // Note: Layers can be constructed either from individual
        //       Option.Builder classes, or from the static helper
        //       methods defined in `Layers` which wrap the explicit builders
        //       to decrease verbosity.
        model = Sequential.of(
                Layers.input(28, 28),
                Layers.flatten(), // 28 * 28),

                // Using Layer Options Builder
                new Dense(128, Dense.Options.builder()
                        ...),

                // Using static helper Layers.dense(...)
                Layers.dense(10, Activations.softmax, Initializers.randomNormal, Initializers.zeros)
        );

        // Model Compile Configuration
        compileOptions = Model.CompileOptions.builder()
                ...;

        // Model Training Loop Configuration
        fitOptions = Model.FitOptions.builder()
                ...;
    }

    public static Model<Float> train() throws Exception {
        try (Graph graph = new Graph()) {
            // Create Tensorflow Ops Accessor
            Ops tf = Ops.create(graph);

            // Compile Model
            model.compile(tf, compileOptions);

            Pair<GraphLoader<Float>, GraphLoader<Float>> loaders = MNIST.graphLoaders2D();
            // GraphLoader objects contain AutoCloseable `Tensor` objects.
            try (GraphLoader<Float> train = loaders.first();
                 GraphLoader<Float> test = loaders.second()) {
                // Fit model
                model.fit(tf, train, test, fitOptions);
            }
        }

        return model;
    }

    public static void main(String[] args) throws Exception {
        ...
    }
}

Scala (See

package org.tensorflow.keras.scala.examples

import org.tensorflow.Graph
import org.tensorflow.keras.activations.Activations.{relu, softmax}
import org.tensorflow.keras.datasets.FashionMNIST
import org.tensorflow.keras.losses.Losses.sparseCategoricalCrossentropy
import org.tensorflow.keras.metrics.Metrics.accuracy
import org.tensorflow.keras.models.Sequential
import org.tensorflow.keras.optimizers.Optimizers.sgd
import org.tensorflow.keras.scala.{Layers, Model}
import org.tensorflow.op.Ops
import org.tensorflow.utils.Pair

import scala.util.Using

object FashionMNISTKeras {
  val model: Model[java.lang.Float] = Sequential.of(
    Layers.input(28, 28),
    Layers.flatten(28 * 28),
    Layers.dense(128, activation = relu),
    Layers.dense(10, activation = softmax)
  )

  def train(model: Model[java.lang.Float]): Model[java.lang.Float] = {
    Using.resource(new Graph()) { graph => {
      val tf: Ops = Ops.create(graph)
      // Compile Model
      model.compile(tf, optimizer = sgd, loss = sparseCategoricalCrossentropy, metrics = List(accuracy))

      val data: Pair[GraphLoader[java.lang.Float], GraphLoader[java.lang.Float]] = FashionMNIST.graphLoaders2D()
      // GraphLoader objects contain AutoCloseable `Tensors`.
      Using.resources(data.first(), data.second()) { (train, test) => {
        // Fit Model
        model.fit(tf, train, test, epochs = 10, batchSize = 100)
      }}
    }}
    // Output Model
    model
  }

  def main(args: Array[String]): Unit = {
    ...
  }
}


Keras is built around a few core abstractions which comprise much of what is needed to build deep learning models. These include Layers, Models, Optimizers, Activations, Losses, Metrics, Regularizers, and Initializers.


To use TensorFlow, we must have access to a Graph to build and run computations. In Python, this is constructed implicitly. In TF Java, the Graph and Ops objects must be created explicitly. Thus, we leave this (below) for the user to write, and allow the Ops object to be passed through the Keras Layer construction calls to provide access to core TensorFlow operations.

try (Graph graph = new Graph()) {
    Ops tf = Ops.create(graph);
    // Keras code here.
}

Layer API

Layers represent arbitrary graph computations on tensors. They are composable; meant to be "stacked" on each other to build neural networks. Thus they must encode differentiable functions (i.e. built using TF operations), and must contain trainable weight tensors that can be optimized using backpropagation.

  1. Using Layers. Layers have two main methods: Layer.build(Ops tf) and Layer.call(Operand op1, Operand op2, ...). Layer.build is used to create the necessary graph components (variables), while Layer.call defines the layer's logic in terms of these variables and other TensorFlow operators.

    TODO(#1): Currently, the build vs. call functionality is mixed, and handled entirely by the build function.

  2. Custom Layers. New custom layers can be built by subclassing any Layer class. Use Layer.addWeight(String name, Variable var) to add a weight tensor to the layer and Layer.addInitializer(String name, Assign initOp) to add its initializer, during build (See keras.layers.Dense for an example).

    TODO(#2): More robust Initializer support.

    TODO(#3): Add support for Regularizers.

  3. Multiple Input/Output. Keras Layers can express functions between labeled sets of tensors. Multiple inputs are handled by adding arguments to the layer's call method.

    TODO(#4): Add support for multiple outputs, likely based on the Keras Network Layer.

  4. Arbitrarily-Shaped Input. So long as the layer's computations are implemented properly, Keras layers can operate on arbitrarily shaped input (within the restrictions of the computation). It's really useful not to have to worry about things like tensor dimensions and batch size; this is handled for the user at graph building time.

    TODO(#5): To make this more explicit and debuggable, layer definitions should include a Layer.computeOutputShape(Shape inputShape) method definition.

  5. Layer Types. Keras provides a large library of pre-defined Layers, including convolutional, dropout, and recurrent layers.

    TODO(#6): Add support for variety of Layer types.

    Note: Not all layers, currently, can just be wrappers around TF operations. For example, tf.dropout doesn't seem to be implemented as a core tf operation in Java.
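
As a rough illustration of the build/call split and of the computeOutputShape idea from TODO(#5), here is a minimal sketch in plain Java. It uses float arrays instead of TensorFlow operands, and every name in it (SketchLayer, SketchDense, computeOutputSize) is hypothetical and not part of this project's API.

```java
import java.util.Arrays;

/** Hypothetical layer contract: build() creates weights, call() applies them. */
interface SketchLayer {
    void build(int inputSize);            // allocate weights once the input size is known
    float[] call(float[] input);          // forward computation using those weights
    int computeOutputSize(int inputSize); // shape inference without running anything
}

/** Plain-array dense layer without activation, standing in for a graph-based one. */
final class SketchDense implements SketchLayer {
    private final int units;
    private float[][] kernel; // [inputSize][units], created in build()
    private float[] bias;     // [units], created in build()

    SketchDense(int units) {
        this.units = units;
    }

    @Override
    public void build(int inputSize) {
        kernel = new float[inputSize][units];
        for (float[] row : kernel) {
            Arrays.fill(row, 0.5f); // constant initializer, chosen only for predictability
        }
        bias = new float[units];    // Java zero-initializes the array
    }

    @Override
    public float[] call(float[] input) {
        if (kernel == null) {
            throw new IllegalStateException("build() must run before call()");
        }
        if (input.length != kernel.length) {
            throw new IllegalArgumentException("expected input of size " + kernel.length);
        }
        float[] out = bias.clone();
        for (int i = 0; i < input.length; i++) {
            for (int j = 0; j < units; j++) {
                out[j] += input[i] * kernel[i][j];
            }
        }
        return out;
    }

    @Override
    public int computeOutputSize(int inputSize) {
        return units;
    }
}

final class LayerSketchDemo {
    public static void main(String[] args) {
        SketchDense dense = new SketchDense(2);
        dense.build(3);
        float[] y = dense.call(new float[] {1f, 2f, 3f});
        System.out.println(Arrays.toString(y)); // prints [3.0, 3.0]
    }
}
```

Keeping weight creation in build and the computation in call means the same layer object can report its output size before any data flows through it.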

Model API

In Keras, Models represent built neural networks and provide easy abstractions for model training/evaluation through Model.compile, Model.fit, and Model.evaluate. Additionally, the Sequential and Functional model construction APIs provide concise ways to specify models.

  1. Sequential API. Sequential models are built by stacking Layer objects in an array/list.

     Model model = new Sequential(
         new Dense(100, Shape.make(10, 500)),
         new Dense(50, Shape.make(10, 100)),
         new Dense(10, Shape.make(10, 50))
     );

    TODO(#7): During Model.compile/Model.fit, the input shape of each layer should be automatically obtained from the previous layer's output shape (with the exception of the first layer, whose input shape must be specified by the user).

  2. Model Training. Model training is implemented by the Model.fit method. Currently, a Dataset object must be passed to this method, which means the whole dataset must be represented and iterated over in memory.

    TODO(#8): Allow Stream-based iteration over batches during training.

  3. Model Config. Keras objects (Optimizer, Layer, etc.) are fairly simple, and are nested in a clean way. Thus JSON would likely be a good method for representing serializable configurations of Models.

    TODO(#9): Add support for model configuration (perhaps using RapidJson or something similar).

  4. Trained Model Serialization. Serializing a trained Model for later inference is very useful. This would likely entail serializing both the model config and the trained weight tensors. Still need to decide on a good way to do this.

    TODO(#10): Add support for Model Serialization.
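
The shape propagation described in TODO(#7) could look roughly like the following sketch, assuming only the first layer's input size is given by the user. It is plain Java with hypothetical names (DenseSpec, SequentialSketch) and does not reflect how Sequential is implemented in this project.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

/** Hypothetical layer description: only the number of units is given up front. */
final class DenseSpec {
    final int units;
    int inputSize = -1; // unknown until compile() propagates sizes

    DenseSpec(int units) {
        this.units = units;
    }
}

/** Hypothetical sequential container that infers each layer's input size. */
final class SequentialSketch {
    private final int firstInputSize;
    private final List<DenseSpec> layers = new ArrayList<>();

    SequentialSketch(int firstInputSize, DenseSpec... specs) {
        this.firstInputSize = firstInputSize;
        this.layers.addAll(Arrays.asList(specs));
    }

    /** Walks the stack once; each layer's input size is the previous layer's output size. */
    int compile() {
        int current = firstInputSize;
        for (DenseSpec layer : layers) {
            layer.inputSize = current;
            current = layer.units;
        }
        return current; // output size of the whole model
    }

    List<DenseSpec> layers() {
        return layers;
    }
}

final class SequentialSketchDemo {
    public static void main(String[] args) {
        SequentialSketch model = new SequentialSketch(
                500, new DenseSpec(100), new DenseSpec(50), new DenseSpec(10));
        int outputSize = model.compile();
        for (DenseSpec layer : model.layers()) {
            System.out.println(layer.inputSize + " -> " + layer.units);
        }
        System.out.println("model output size: " + outputSize);
    }
}
```

With this approach the user would write only the unit counts (100, 50, 10) and the first input size (500), instead of repeating every intermediate shape.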


Keras in Python supports input data in a number of forms (lists, NumPy arrays, etc.). Efficient data representation, loading, and batch iteration have a considerable impact on training performance, especially on large datasets. So far, we've defined one loader class, which takes Tensor objects as input and constructs batches as operands in the TensorFlow graph using tf.slice(). There's also a second class, based on float[][], which delays the construction of Tensor objects until batch training time and constructs one tensor per batch. It will probably be best to support multiple variants on this, and even integrate with a Java port of
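
To make the float[][]-based idea concrete, here is a minimal sketch of lazy batch iteration in plain Java: each batch is sliced out only when requested, which is the point at which a real loader would create its one tensor per batch. The class name BatchIterator is hypothetical and is not one of this project's classes.

```java
import java.util.Arrays;
import java.util.Iterator;
import java.util.NoSuchElementException;

/** Hypothetical lazy batcher over an in-memory float[][] dataset (one row per example). */
final class BatchIterator implements Iterator<float[][]> {
    private final float[][] data;
    private final int batchSize;
    private int position = 0;

    BatchIterator(float[][] data, int batchSize) {
        if (batchSize <= 0) {
            throw new IllegalArgumentException("batchSize must be positive");
        }
        this.data = data;
        this.batchSize = batchSize;
    }

    @Override
    public boolean hasNext() {
        return position < data.length;
    }

    @Override
    public float[][] next() {
        if (!hasNext()) {
            throw new NoSuchElementException();
        }
        // The last batch may be smaller than batchSize.
        int end = Math.min(position + batchSize, data.length);
        // Shallow copy: the batch shares its row arrays with the dataset.
        float[][] batch = Arrays.copyOfRange(data, position, end);
        position = end;
        return batch;
    }
}

final class BatchIteratorDemo {
    public static void main(String[] args) {
        float[][] data = new float[5][3]; // 5 examples with 3 features each
        BatchIterator batches = new BatchIterator(data, 2);
        while (batches.hasNext()) {
            System.out.println("batch of " + batches.next().length + " rows");
        }
    }
}
```

Because this is an ordinary Iterator, the same interface could later be backed by a file or a stream rather than an in-memory array.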

Keras "Calling Convention"

Keras, as implemented in Python, provides a large number of pythonic language features to speed up development. Here we'll describe the differences/similarities in the Java API.

  1. Variable and Keyword Arguments. Some Keras functions support a large number of keyword arguments with specified defaults. Java can't easily provide identical behavior, so we aim to come close by using the Builder Pattern. When keyword arguments are used in a method call, we define a simple Options/Builder class to manage these (see: Model.CompilerBuilder, Model.FitBuilder).

  2. Predefined Hyperparameters. Optimizers, Activations, and Layers often have well-behaved default values that can be used automatically. These objects can be constructed using the default constructor, and their attributes will be populated using these defaults.

  3. String Aliases. Keras allows default objects to be addressed by a string name (e.g. optimizer="adam", activation="sigmoid"). This is supported here in most cases; the names are matched against enums for each object type.
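
The three conventions above can be combined in one small sketch: a builder whose fields carry defaults, plus an enum that resolves string aliases. This is plain Java with hypothetical names (FitOptionsSketch, OptimizerAlias) and illustrative default values; it is not the project's actual Options classes.

```java
import java.util.Locale;

/** Hypothetical optimizer names; the enum doubles as the string-alias table. */
enum OptimizerAlias {
    SGD, ADAM;

    /** Resolves a Keras-style alias such as "adam", ignoring case and surrounding spaces. */
    static OptimizerAlias fromName(String name) {
        // valueOf throws IllegalArgumentException for unknown aliases.
        return valueOf(name.trim().toUpperCase(Locale.ROOT));
    }
}

/** Hypothetical immutable options object standing in for keyword arguments. */
final class FitOptionsSketch {
    final int epochs;
    final int batchSize;
    final OptimizerAlias optimizer;

    private FitOptionsSketch(Builder b) {
        this.epochs = b.epochs;
        this.batchSize = b.batchSize;
        this.optimizer = b.optimizer;
    }

    static Builder builder() {
        return new Builder();
    }

    static final class Builder {
        // Field initializers play the role of keyword-argument defaults (values are illustrative).
        private int epochs = 1;
        private int batchSize = 32;
        private OptimizerAlias optimizer = OptimizerAlias.SGD;

        Builder setEpochs(int epochs) {
            this.epochs = epochs;
            return this;
        }

        Builder setBatchSize(int batchSize) {
            this.batchSize = batchSize;
            return this;
        }

        Builder setOptimizer(String alias) {
            this.optimizer = OptimizerAlias.fromName(alias);
            return this;
        }

        FitOptionsSketch build() {
            return new FitOptionsSketch(this);
        }
    }
}

final class OptionsDemo {
    public static void main(String[] args) {
        // Only the "keyword arguments" that differ from the defaults are set.
        FitOptionsSketch options = FitOptionsSketch.builder()
                .setEpochs(10)
                .setOptimizer("adam")
                .build();
        System.out.println(options.epochs + " " + options.batchSize + " " + options.optimizer);
        // prints 10 32 ADAM
    }
}
```

Matching aliases against an enum keeps the set of valid names in one place and turns a misspelled alias into an immediate exception rather than a silent default.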

Keras Backend

The Keras Backend interface in Python originally allowed for switching between various tensor library backends. Since this project is aimed at supporting TensorFlow, it seems cleaner for now to depend directly on TensorFlow's implementation of core operations.


The ability to set breakpoints etc. within Layer class definitions would be very useful for debugging.

TODO(#11): Add support for training in tensorflow eager mode.


TODO(#12): Add support for Keras callbacks. This would allow things like model checkpointing, early stopping, Tensorboard integration, and other useful features.
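
One possible shape for such a callback hook, sketched in plain Java with early stopping as the example. Nothing here exists in the project yet; the interface TrainingCallback, the class EarlyStopping, and the fake training loop are all hypothetical.

```java
/** Hypothetical callback hook invoked by a training loop after every epoch. */
interface TrainingCallback {
    /** Returns false to request that training stop. */
    boolean onEpochEnd(int epoch, double validationLoss);
}

/** Stops training once the validation loss has not improved for `patience` epochs in a row. */
final class EarlyStopping implements TrainingCallback {
    private final int patience;
    private double best = Double.POSITIVE_INFINITY;
    private int epochsWithoutImprovement = 0;

    EarlyStopping(int patience) {
        this.patience = patience;
    }

    @Override
    public boolean onEpochEnd(int epoch, double validationLoss) {
        if (validationLoss < best) {
            best = validationLoss;
            epochsWithoutImprovement = 0;
        } else {
            epochsWithoutImprovement++;
        }
        return epochsWithoutImprovement < patience;
    }
}

final class CallbackDemo {
    /** Fake training loop over precomputed losses; returns the number of epochs actually run. */
    static int run(double[] losses, TrainingCallback callback) {
        int epochsRun = 0;
        for (int epoch = 0; epoch < losses.length; epoch++) {
            epochsRun++;
            if (!callback.onEpochEnd(epoch, losses[epoch])) {
                break;
            }
        }
        return epochsRun;
    }

    public static void main(String[] args) {
        double[] losses = {1.0, 0.8, 0.9, 0.85, 0.7};
        // Epochs 3 and 4 fail to beat 0.8, so training stops before the fifth epoch.
        System.out.println(run(losses, new EarlyStopping(2))); // prints 4
    }
}
```

Model checkpointing or logging could implement the same interface, with the training loop calling each registered callback in turn.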


One of the nice things about using the JVM is that we can offer a clean interface to other JVM languages like Scala and Kotlin without having to re-implement everything.

TODO: It would be nice to have idiomatic Scala/Kotlin interfaces for those languages in addition to the Java API.