This blog post is about two machine-learning frameworks, Smile and DeepLearning4j. I present my feedback and some hints useful for any Java developer interested in machine learning.
Introduction and disclaimer
This blog post is based on my own experience with machine-learning frameworks and Java coding. I am not the most efficient or skilled data scientist or expert in the ML domain, though I have used machine learning for some research use cases in the past and for a real in-production service this year.
I have been using machine learning (and deep learning) for the following use cases (all classification problems):
- identifying code and bugs based on the type of code, i.e. initialization code, exception-handling code, IO code, etc.
- identifying a security leak in code (the sink, the source, the incriminated leak), especially injection faults
- an NLP problem: identifying newspaper titles or articles that may become viral
Experiments
I began my experiments in NLP with the following frameworks, and they were great:
- Knime: https://www.knime.com/
- Weka: https://www.cs.waikato.ac.nz/ml/weka/
- Tensorflow using Google Colaboratory. https://colab.research.google.com/
I quickly obtained some results while extracting, refining, and evaluating my dataset.
However, for the next step, I had to increase my dataset size (by 10x-30x, to reach more than 30,000 documents). Still a small dataset, but I was already hitting the limits of Weka and KNIME on some algorithms.
Why not use XYZ or XYY?
To implement my latest use case (which has been running in production for one month), I had to identify a technological solution for NLP classification.
This kind of problem is hard since it consumes a lot of resources (software and hardware) and may still result in low accuracy.
I basically had three machine-learning options:
- use a SaaS service such as Google Cloud Machine Learning or an AMI from Amazon
- use one of the de facto leaders such as Keras, Tensorflow, or Pytorch
- use another solution
Here are my reasons for choosing a Java framework. Everything is debatable, and if you disagree with me, write it in the comments.
How the heck do you go to production with Python?
One issue remained to be solved before committing to Python. I basically had three goals to achieve:
- identify a machine-learning algorithm or technology to predict viral content; fortunately, I found several research papers on this
- implement a prototype quickly (in under 2 months, in fact) to assess whether we could use the technology effectively
- go to production in 3 months, with some hard requirements such as the prediction computation time
Going to production with Python frameworks definitely looked tough given my schedule and my (current) lack of high-level Python programming skills.
Why not a SaaS service?
I wanted to. And I still want to use one. I wanted to use GCML. However, our current infrastructure is all based on Amazon, and when I gave AMIs more than a try, I was dumbfounded by how obscure getting started was. It was a big fail: the UI was awkward, and I understood that Amazon was providing EC2 instances with all the tools preinstalled to perform deep learning... in Python.
Basically, I wanted an API-style service where I could do the learning and the prediction without worrying about computation resources.
GCML (and probably Azure) are the best for this kind of service.
Why not Pytorch, Keras, or Tensorflow?
Python frameworks are the obvious best way to implement machine learning. Going live, however, is another matter, since my Python skills were not good enough to stick to the schedule. In parallel with the prediction system, I had to develop a whole integration with our current infrastructure to obtain the prediction, process it, and modify the system behavior, including a REST API, security, monitoring, and a database connection.
Choosing Tensorflow would have meant integrating the tool with MongoDB and building a REST API in a really short time, on top of processing and extracting the dataset and building the vectors.
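To give an idea of the kind of integration work involved, here is a minimal sketch of a prediction endpoint built with the JDK's built-in `com.sun.net.httpserver.HttpServer` (no framework needed). The `/predict` path and the hard-coded score are purely illustrative, not the real service:

```java
import com.sun.net.httpserver.HttpServer;
import java.io.OutputStream;
import java.net.InetSocketAddress;
import java.nio.charset.StandardCharsets;

// Sketch: expose a prediction behind a tiny REST endpoint.
public class PredictionServer {
    public static HttpServer start(int port) throws Exception {
        HttpServer server = HttpServer.create(new InetSocketAddress(port), 0);
        server.createContext("/predict", exchange -> {
            // In the real service, this would run the trained model
            // on the request body instead of returning a dummy score.
            byte[] body = "{\"viral\": 0.42}".getBytes(StandardCharsets.UTF_8);
            exchange.getResponseHeaders().set("Content-Type", "application/json");
            exchange.sendResponseHeaders(200, body.length);
            try (OutputStream os = exchange.getResponseBody()) {
                os.write(body);
            }
        });
        server.start();
        return server;
    }
}
```

In a real deployment you would of course add the security and monitoring pieces mentioned above, but the point stands: in Java, the serving layer is plain, familiar territory.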
OK, you do not know Python; what can Java do?
A notice before going further: Java frameworks are neither as good (accurate, performant) nor as fast as the previously mentioned solutions.
I did not know whether deep learning was the ideal solution for my problem, so I needed frameworks that would let me compare classic machine-learning algorithms (random forests, Bayesian classifiers) with deep learning.
I used Smile for its wide list of implemented machine-learning algorithms, and DeepLearning4j for my neural-network experiments.
Clearly, expect some surprises with both frameworks, but you can achieve some great results.
Machine learning with the Smile framework
You will find more information about this framework on https://github.com/haifengl/smile.
The framework is licensed under Apache 2.0 and is therefore business-friendly.
Several algorithms are provided in the framework; here is a small list of the classification tools:
- K-Nearest Neighbor
- Linear Discriminant Analysis
- Fisher’s Linear Discriminant
- Quadratic Discriminant Analysis
- Regularized Discriminant Analysis
- Logistic Regression
- Maximum Entropy Classifier
- Multilayer Perceptron Neural Network
- Radial Basis Function Networks
- Support Vector Machines
- Decision Trees
- Random Forest
- Gradient Boosted Trees
- AdaBoost
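Training one of these classifiers is compact. Below is a minimal sketch, assuming the constructor-style Java API of Smile 1.5.x (the version current at the time of writing); the tiny hard-coded vectors stand in for real term-frequency vectors built from documents:

```java
import smile.classification.RandomForest;

// Sketch: train a random forest on pre-vectorized documents and
// predict the class of a new one. In Smile 1.5.x, RandomForest
// takes the feature matrix, the labels, and the number of trees.
public class SmileSketch {
    public static void main(String[] args) {
        double[][] x = {
            {1, 0, 1}, {1, 1, 0},   // e.g. term-frequency vectors, class 0
            {0, 1, 1}, {0, 0, 1}    // class 1
        };
        int[] y = {0, 0, 1, 1};

        RandomForest forest = new RandomForest(x, y, 100); // 100 trees
        int predicted = forest.predict(new double[]{1, 0, 0});
        System.out.println("predicted class = " + predicted);
    }
}
```

Swapping the algorithm (e.g. for gradient boosted trees or a logistic regression) mostly means swapping the class, which made comparing classic algorithms painless.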
The documentation is great, and although the framework was designed with Scala developers in mind, a Java developer will find enough help to build their tool.
Most of my time was spent converting MongoDB entries (documents) into a valid dataset. Using AttributeSelection and building nominal attributes was not so easy, and the implementation is quite memory-consuming.
I also had some difficulties matching the algorithms I could choose with the use of sparse arrays and optimized data structures. Some algorithms were failing at runtime, blaming me for using unsupported features. I clearly lost some time with such limitations.
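The dataset-conversion step itself is conceptually simple; what takes the time is the plumbing. A plain-Java sketch (no Smile API involved) of turning text documents, such as fields pulled out of MongoDB, into the fixed-size `double[][]` matrix most classifiers expect:

```java
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;

// Sketch: a two-pass bag-of-words vectorizer.
public class BagOfWords {
    private final Map<String, Integer> vocabulary = new LinkedHashMap<>();

    // First pass: assign every distinct token a column index.
    public void buildVocabulary(List<String> documents) {
        for (String doc : documents) {
            for (String token : doc.toLowerCase().split("\\W+")) {
                if (!token.isEmpty()) {
                    vocabulary.putIfAbsent(token, vocabulary.size());
                }
            }
        }
    }

    // Second pass: one term-frequency vector per document;
    // tokens outside the vocabulary are simply skipped.
    public double[] vectorize(String doc) {
        double[] vector = new double[vocabulary.size()];
        for (String token : doc.toLowerCase().split("\\W+")) {
            Integer index = vocabulary.get(token);
            if (index != null) {
                vector[index]++;
            }
        }
        return vector;
    }

    public int vocabularySize() {
        return vocabulary.size();
    }
}
```

The memory problem becomes obvious here: with 30,000 documents and a large vocabulary, a dense `double[][]` like this grows quickly, which is exactly where sparse structures would help, and where I hit the limitations mentioned above.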
Deep learning with the DeepLearning4J framework
// Build the CNN as a ComputationGraph and attach listeners: a score log
// every 100 iterations, an evaluation at the end of each epoch, and
// performance statistics on every iteration.
cnnComputationGraph = new ComputationGraph(config);
cnnComputationGraph.setListeners(
        new ScoreIterationListener(100),
        new EvaluativeListener(trainIterator, 1, InvocationType.EPOCH_END),
        new PerformanceListener(1, true));
cnnComputationGraph.init();
log.info("[CNN] Training workspace config: {}", cnnComputationGraph.getConfiguration().getTrainingWorkspaceMode());
log.info("[CNN] Inference workspace config: {}", cnnComputationGraph.getConfiguration().getInferenceWorkspaceMode());
log.info("[CNN] Training launch...");
cnnComputationGraph.fit(trainIterator, nEpochs);
// After training, log the parameter count of each layer and of the graph.
log.info("[CNN] Number of parameters by layer:");
for (final Layer l : cnnComputationGraph.getLayers()) {
    log.info("[CNN] \t{}\t{}\t{}", l.conf().getLayer().getLayerName(), l.type().name(), l.numParams());
}
log.info("[CNN] Number of parameters for the graph numParams={}, summary={}", cnnComputationGraph.numParams(), cnnComputationGraph.summary());
I led experiments with a CNN network to identify text patterns that may indicate a viral title in a document.
To build this CNN network, I also wanted to use Word2Vec, to get a more flexible prediction based not only on lexical similarity but also on the semantic axis.
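The appeal of Word2Vec here is that semantic closeness becomes geometric closeness between vectors. A plain-Java sketch of cosine similarity, the usual measure for comparing two word embeddings (the vectors in the test are made up for illustration, not real embeddings):

```java
// Sketch: cosine similarity between two embedding vectors.
// Semantically close words get a score near 1.0; unrelated
// (orthogonal) vectors score near 0.0.
public class CosineSimilarity {
    public static double cosine(double[] a, double[] b) {
        double dot = 0, normA = 0, normB = 0;
        for (int i = 0; i < a.length; i++) {
            dot += a[i] * b[i];
            normA += a[i] * a[i];
            normB += b[i] * b[i];
        }
        return dot / (Math.sqrt(normA) * Math.sqrt(normB));
    }
}
```

This is what lets the classifier treat, say, "astonishing" and "amazing" as near-neighbors even though they share no characters, which a purely lexical representation cannot do.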
DeepLearning4j is not a mature framework, despite the interesting features it provides. As of May 2019, the code is still in beta. Moreover, the Maven dependencies pull in a whole big set of transitive dependencies that may break your program.
Maven releases of DeepLearning4j are infrequent (you may wait some months), and in the meantime many bug fixes land on the master branch without your having the benefit of using them. No snapshot version is available, and building the project is a pain.
If I have not frightened you yet: the documentation is quite unreadable; for example, memory management (workspaces) is quite a mystery. Some exceptions and errors during the build of your network are genuinely disturbing, especially when I tried to build an RNN network based on this example.
However, with some patience, I have been able to use the CNN network and make some nice predictions.
My main current issues with the framework are really high memory consumption (20 GB) and slow performance, because I still do not have access to an EC2 instance with a GPU.
Still, I have been able to build a prediction system based on a CNN network and NLP using DeepLearning4j, and for the moment it is a success. I am clearly planning to replace DeepLearning4j with a Python equivalent, but for now I have some months to develop it in parallel.
And you, what would your choice be in such a situation?