blob: d0b8a172be3bd6eb00f57e32371b311604291346 [file] [log] [blame]
// Copyright 2008 Google Inc.
// All Rights Reserved.
// Author: ahmadab@google.com (Ahmad Abdulkader)
//
// neural_net.cpp: Definitions for the NeuralNet class, an object that
// represents an arbitrary network of neurons
//
#include <vector>
#include <string>
#include "neural_net.h"
#include "input_file_buffer.h"
namespace tesseract {
// Instantiate all supported templates. These explicit instantiation
// definitions emit code for the float and double variants of the
// feedforward routines and for ReadBinary() over an InputFileBuffer, so
// other translation units can link against them without seeing the
// template bodies.
template bool NeuralNet::FeedForward(const float *inputs, float *outputs);
template bool NeuralNet::FeedForward(const double *inputs, double *outputs);
template bool NeuralNet::FastFeedForward(const float *inputs, float *outputs);
template bool NeuralNet::FastFeedForward(const double *inputs,
double *outputs);
template bool NeuralNet::ReadBinary(InputFileBuffer *input_buffer);
// Constructs an empty network. All member state is set up by Init(), so a
// freshly constructed net is read-only, has no neurons and no weights.
NeuralNet::NeuralNet() {
  this->Init();
}
// Releases everything the net owns: the weight chunks handed out by
// AllocWgt(), the neuron array, and the per-node fan-in arrays allocated by
// CreateFastNet().
NeuralNet::~NeuralNet() {
  // clean up the wts chunks vector
  for (int vec = 0; vec < static_cast<int>(wts_vec_.size()); vec++) {
    delete wts_vec_[vec];
  }
  // clean up neurons
  delete []neurons_;
  // Clean up the fast-net nodes. Iterate over fast_nodes_ itself rather than
  // up to neuron_cnt_: fast_nodes_ is filled in by CreateFastNet(), so a net
  // destroyed before that ran (presumably e.g. when ReadBinary() fails
  // partway and FromInputBuffer() deletes the net) may have neuron_cnt_ > 0
  // with an empty fast_nodes_, and indexing it would be out of bounds.
  for (int node_idx = 0; node_idx < static_cast<int>(fast_nodes_.size());
       node_idx++) {
    delete []fast_nodes_[node_idx].inputs;
  }
}
// Resets every member to the state of an empty, read-only, non-auto-encoder
// network with no neurons, weights, or input statistics.
// Note: this does not free the neuron array or the weight chunks referenced
// by wts_vec_; it only overwrites the pointer and clears the containers.
void NeuralNet::Init() {
  // Network flags.
  auto_encoder_ = false;
  read_only_ = true;
  // Topology and weight bookkeeping.
  in_cnt_ = 0;
  out_cnt_ = 0;
  neuron_cnt_ = 0;
  wts_cnt_ = 0;
  alloc_wgt_cnt_ = 0;
  neurons_ = NULL;
  wts_vec_.clear();
  // Input normalization statistics.
  inputs_min_.clear();
  inputs_max_.clear();
  inputs_mean_.clear();
  inputs_std_dev_.clear();
}
// Does a fast feedforward for read-only nets using the flattened node array
// built by CreateFastNet(). Templatized for float and double Types.
//
// inputs:  in_cnt_ raw input values. Input normalization is folded into the
//          input-node biases and into the weights of connections leaving
//          input nodes by CreateFastNet(), so no scaling happens here.
// outputs: receives out_cnt_ values, the outputs of the last out_cnt_ nodes.
// Returns false if the fast net has not been built (fast_nodes_ does not
// hold neuron_cnt_ nodes) or the input/output counts exceed the node count;
// true otherwise.
template <typename Type> bool NeuralNet::FastFeedForward(const Type *inputs,
                                                         Type *outputs) {
  // Guard against a net whose fast structure was never built (e.g. a
  // default-constructed net): &fast_nodes_[0] on an empty vector is UB.
  if (neuron_cnt_ <= 0 ||
      static_cast<int>(fast_nodes_.size()) != neuron_cnt_ ||
      in_cnt_ > neuron_cnt_ || out_cnt_ > neuron_cnt_) {
    return false;
  }
  int node_idx = 0;
  Node *node = &fast_nodes_[0];
  // feed inputs in and offset them by the pre-computed bias
  for (; node_idx < in_cnt_; node_idx++, node++) {
    node->out = inputs[node_idx] - node->bias;
  }
  // compute nodes activations and outputs; CreateFastNet() only accepts
  // fan-ins with lower indices, so every fan-in output is already computed
  for (; node_idx < neuron_cnt_; node_idx++, node++) {
    double activation = -node->bias;
    for (int fan_in_idx = 0; fan_in_idx < node->fan_in_cnt; fan_in_idx++) {
      activation += (node->inputs[fan_in_idx].input_weight *
                     node->inputs[fan_in_idx].input_node->out);
    }
    node->out = Neuron::Sigmoid(activation);
  }
  // Copy the outputs (the last out_cnt_ nodes) to the output buffer. Index
  // from the first output node instead of forming
  // &fast_nodes_[neuron_cnt_ - out_cnt_], which is out of range when
  // out_cnt_ == 0.
  const int first_out = neuron_cnt_ - out_cnt_;
  for (int out_idx = 0; out_idx < out_cnt_; out_idx++) {
    outputs[out_idx] = fast_nodes_[first_out + out_idx].out;
  }
  return true;
}
// Performs a feedforward for general nets. Used mainly in training mode.
// Templatized for float and double Types.
//
// inputs:  in_cnt_ raw input values.
// outputs: receives out_cnt_ values; outputs[i] is the output of neuron
//          (neuron_cnt_ - out_cnt_ + i), the same layout FastFeedForward()
//          produces.
// Read-only nets are delegated to FastFeedForward().
template <typename Type> bool NeuralNet::FeedForward(const Type *inputs,
                                                     Type *outputs) {
  // call the fast version in case of readonly nets
  if (read_only_) {
    return FastFeedForward(inputs, outputs);
  }
  // clear all neurons
  Clear();
  // for auto encoders, apply no input normalization
  if (auto_encoder_) {
    for (int in = 0; in < in_cnt_; in++) {
      neurons_[in].set_output(inputs[in]);
    }
  } else {
    // Input normalization: first min-max scale the input to
    // (x - min) / (max - min), then standardize the scaled value by
    // subtracting the mean and dividing by the std dev.
    for (int in = 0; in < in_cnt_; in++) {
      neurons_[in].set_output((inputs[in] - inputs_min_[in]) /
                              (inputs_max_[in] - inputs_min_[in]));
      neurons_[in].set_output((neurons_[in].output() - inputs_mean_[in]) /
                              inputs_std_dev_[in]);
    }
  }
  // compute the net outputs: follow a pull model each output pulls the
  // outputs of its input nodes and so on
  const int first_out = neuron_cnt_ - out_cnt_;
  for (int out = first_out; out < neuron_cnt_; out++) {
    neurons_[out].FeedForward();
    // Copy the value to the output buffer. The buffer is sized for out_cnt_
    // values (as FastFeedForward() assumes), so index it relative to the
    // first output neuron; indexing by the absolute neuron id wrote past
    // its end.
    outputs[out - first_out] = neurons_[out].output();
  }
  return true;
}
// Sets a connection between two neurons
bool NeuralNet::SetConnection(int from, int to) {
// allocate the wgt
float *wts = AllocWgt(1);
if (wts == NULL) {
return false;
}
// register the connection
neurons_[to].AddFromConnection(neurons_ + from, wts, 1);
return true;
}
// Creates a fast read-only version of the net: fast_nodes_ gets one Node per
// neuron. Input normalization is pre-folded into the input-node biases and
// into the weights of connections leaving input nodes, so that
// FastFeedForward() can consume raw inputs.
// Returns false if the structure is inconsistent: a negative neuron count,
// a missing neuron array, a fan-in that does not precede its node, missing
// input statistics for an input node, or a total fan-in count that differs
// from the number of weights allocated (wts_cnt_).
bool NeuralNet::CreateFastNet() {
  if (neuron_cnt_ < 0 || (neuron_cnt_ > 0 && neurons_ == NULL)) {
    return false;
  }
  // Release the fan-in arrays of any previously built fast net so that a
  // rebuild does not leak them, then start from fresh nodes.
  for (int idx = 0; idx < static_cast<int>(fast_nodes_.size()); idx++) {
    delete []fast_nodes_[idx].inputs;
  }
  fast_nodes_.clear();
  fast_nodes_.resize(neuron_cnt_);
  // Number of inputs for which all four normalization statistics exist.
  int stats_cnt = static_cast<int>(inputs_min_.size());
  if (static_cast<int>(inputs_max_.size()) < stats_cnt) {
    stats_cnt = static_cast<int>(inputs_max_.size());
  }
  if (static_cast<int>(inputs_mean_.size()) < stats_cnt) {
    stats_cnt = static_cast<int>(inputs_mean_.size());
  }
  if (static_cast<int>(inputs_std_dev_.size()) < stats_cnt) {
    stats_cnt = static_cast<int>(inputs_std_dev_.size());
  }
  // build the node structures
  int wts_cnt = 0;
  for (int node_idx = 0; node_idx < neuron_cnt_; node_idx++) {
    Node *node = &fast_nodes_[node_idx];
    if (neurons_[node_idx].node_type() == Neuron::Input) {
      // Input neurons have no fan-in
      node->fan_in_cnt = 0;
      node->inputs = NULL;
      if (node_idx >= stats_cnt) {
        return false;
      }
      // Input bias is the normalization offset computed from training input
      // stats: min + mean * (max - min). Subtracting it from the raw input,
      // together with dividing the consumers' weights by
      // (max - min) * std_dev below, reproduces
      // ((x - min) / (max - min) - mean) / std_dev.
      // NOTE(review): FeedForward() skips input normalization for
      // auto-encoders, but this fast path always applies it -- confirm that
      // is intended.
      node->bias = inputs_min_[node_idx] +
                   (inputs_mean_[node_idx] *
                    (inputs_max_[node_idx] - inputs_min_[node_idx]));
    } else {
      node->bias = neurons_[node_idx].bias();
      node->fan_in_cnt = neurons_[node_idx].fan_in_cnt();
      // allocate memory for fan-in nodes
      node->inputs = new WeightedNode[node->fan_in_cnt];
      if (node->inputs == NULL) {
        return false;
      }
      for (int fan_in = 0; fan_in < node->fan_in_cnt; fan_in++) {
        // identify fan-in neuron
        const int id = neurons_[node_idx].fan_in(fan_in)->id();
        // A fan-in must precede its node: feedback connections are not
        // allowed, and a negative id would index before fast_nodes_.
        if (id < 0 || id >= node_idx) {
          return false;
        }
        // add the fan-in neuron and its wgt
        node->inputs[fan_in].input_node = &fast_nodes_[id];
        float wgt_val = neurons_[node_idx].fan_in_wts(fan_in);
        // for input neurons normalize the wgt by the input scaling
        // values to save time during feedforward
        if (neurons_[node_idx].fan_in(fan_in)->node_type() == Neuron::Input) {
          if (id >= stats_cnt) {
            return false;
          }
          wgt_val /= ((inputs_max_[id] - inputs_min_[id]) *
                      inputs_std_dev_[id]);
        }
        node->inputs[fan_in].input_weight = wgt_val;
      }
      // incr wgt count to validate against at the end
      wts_cnt += node->fan_in_cnt;
    }
  }
  // sanity check
  return wts_cnt_ == wts_cnt;
}
// returns a pointer to the requested set of weights
// Allocates in chunks
float * NeuralNet::AllocWgt(int wgt_cnt) {
// see if need to allocate a new chunk of wts
if (wts_vec_.size() == 0 || (alloc_wgt_cnt_ + wgt_cnt) > kWgtChunkSize) {
// add the new chunck to the wts_chunks vector
wts_vec_.push_back(new vector<float> (kWgtChunkSize));
alloc_wgt_cnt_ = 0;
}
float *ret_ptr = &((*wts_vec_.back())[alloc_wgt_cnt_]);
// incr usage counts
alloc_wgt_cnt_ += wgt_cnt;
wts_cnt_ += wgt_cnt;
return ret_ptr;
}
// Creates a new net by reading the binary model stored in file_name.
// Returns the heap-allocated net produced by FromInputBuffer() (the caller
// owns it), or NULL if the net could not be loaded.
NeuralNet *NeuralNet::FromFile(const string file_name) {
  InputFileBuffer file_buffer(file_name);
  return FromInputBuffer(&file_buffer);
}
// Creates a net object and loads it from an input buffer via ReadBinary().
// Returns a heap-allocated net that the caller must delete, or NULL if ib is
// NULL or the net could not be read (the partially-read net is deleted).
NeuralNet *NeuralNet::FromInputBuffer(InputFileBuffer *ib) {
  // Nothing to read from: fail instead of handing a null buffer to
  // ReadBinary().
  if (ib == NULL) {
    return NULL;
  }
  // create a new net object
  NeuralNet *net_obj = new NeuralNet();
  if (net_obj == NULL) {
    return NULL;
  }
  // load the net; discard the object on failure
  if (!net_obj->ReadBinary(ib)) {
    delete net_obj;
    return NULL;
  }
  return net_obj;
}
}