
A neural network library in C++20 optimized for CPU. Create, train, and use a neural network in fewer than 10 lines of code.


MatthieuHernandez/StraightforwardNeuralNetwork



Straightforward Neural Network

Straightforward Neural Network is an open-source neural network library written in C++20 and optimized for CPU. Its goal is to make neural networks as easy as possible to use.

Documentation

See the full documentation here.

Dataset results

Dataset                 Data type     Problem type    Score                      Number of parameters
Audio Cats and Dogs     audio         classification  91.04% accuracy            382
Daily min temperatures  time series   regression      1.42 mean absolute error   30
CIFAR-10                image         classification  61.77% accuracy            207210
Fashion-MNIST           image         classification  89.65% accuracy            270926
MNIST                   image         classification  98.71% accuracy            261206
Wine                    multivariate  classification  100.0% accuracy            444
Iris                    multivariate  classification  100.0% accuracy            150

Installation (with CMake 3.17.1)

Linux, UNIX - GCC 13.1.0 or Clang 18.1.3

  • To compile, open a command prompt in the build folder and run cmake -G"Unix Makefiles" ./.. && make.

  • To run the unit tests, execute ./tests/unit_tests/UnitTests from the build folder.

  • To run the dataset tests, run ./ImportDatasets.sh, then execute ./tests/dataset_tests/DatasetTests from the build folder.

Windows - MSVC 19.41

  • You can generate a Visual Studio project by running cmake -G"Visual Studio 17 2022" ./.. from the build folder.

  • To run the unit tests, open ./build/tests/unit_tests/UnitTests.vcxproj in Visual Studio.

  • To run the dataset tests, run ./build/ImportDatasets.sh, then open ./build/tests/dataset_tests/DatasetTests.vcxproj in Visual Studio.

Use

Create, train, and use a neural network in a few lines of code:

using namespace snn;

Data data(problem::classification, inputData, expectedOutputs); // inputData and expectedOutputs must be defined beforehand

StraightforwardNeuralNetwork neuralNetwork({
    Input(1, 28, 28), // (C, X, Y)
    Convolution(16, 3, activation::ReLU), // 16 filters and (3, 3) kernels
    FullyConnected(92),
    FullyConnected(10, activation::identity, Softmax())
});

neuralNetwork.train(data, 0.90_acc || 20_s); // Train until 90% accuracy is reached or 20 seconds have elapsed

float accuracy = neuralNetwork.getGlobalClusteringRate(); // Retrieve the accuracy

See more details.

License

Apache License 2.0