
Problems with training a two output layer ANN_MLP from a CSV File

asked 2017-02-15 17:10:19 -0600

José Jácome

updated 2017-02-16 08:43:37 -0600

kbarni

Hi, I'm trying to train an ANN_MLP with a two-neuron output layer. I generated a CSV file and read it with the function TrainData::loadFromCSV(). This is my CSV file for recognizing the digits 1 and 0:

0,0,1,0,0,0,1,0,0,0,1,0,0,0,1,0,1
1,1,1,1,1,0,0,1,1,0,0,1,1,1,1,1,0
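Each line holds 16 input values (they look like a 4x4 binary image of the digit, row by row) followed by the class label in the last column. A plain C++ sketch of reading one line that way; it does not use OpenCV, and parseCsvLine, Sample and toSample are hypothetical names chosen for illustration:

```cpp
#include <cassert>
#include <cstddef>
#include <limits>
#include <sstream>
#include <string>
#include <vector>

// Hypothetical helper: split one comma-separated line into floats.
// An empty field (for example from ",,") becomes NaN instead of being skipped.
std::vector<float> parseCsvLine(const std::string& line) {
    std::vector<float> values;
    std::stringstream ss(line);
    std::string field;
    while (std::getline(ss, field, ',')) {
        values.push_back(field.empty() ? std::numeric_limits<float>::quiet_NaN()
                                       : std::stof(field));
    }
    return values;
}

struct Sample {
    std::vector<float> features;  // the input values
    int label;                    // the class label from the last column
};

// Hypothetical helper: the first numFeatures values are inputs, the last one is the label.
Sample toSample(const std::vector<float>& values, std::size_t numFeatures) {
    assert(values.size() == numFeatures + 1);
    Sample s;
    s.features.assign(values.begin(),
                      values.begin() + static_cast<std::ptrdiff_t>(numFeatures));
    s.label = static_cast<int>(values.back());
    return s;
}
```

A stray double comma produces an extra, empty field, so toSample() would reject that line because it no longer has exactly 17 values.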

In OpenCV 2 I noticed that each line of the CSV files ended with a ;, but with loadFromCSV() the output (responses) argument seems to be a one-dimensional array.

This is my sample code:

#include <iostream>
#include <opencv2/ml.hpp>

using namespace std;
using namespace cv;
using namespace cv::ml;

int main(int argc, char *argv[])
{
    Ptr<ANN_MLP> nnetwork = ANN_MLP::create();
    cout << "Reading data" << endl;
    Ptr<TrainData> datos = TrainData::loadFromCSV("/home/josejacomeb/QT/MLP_DosCapas/mlp_2capas.csv",ROW_SAMPLE);
    vector<int> layerSizes = { 16, // Number of inputs
                               32, // Hidden layer
                               2   // Output layer, for 2 digits
                             };
    cout << datos->getSamples()<< endl;
    nnetwork->setLayerSizes(layerSizes);
    nnetwork->setActivationFunction(ANN_MLP::SIGMOID_SYM,0.6,1);
    // Training
    nnetwork->setTermCriteria(TermCriteria(TermCriteria::MAX_ITER+TermCriteria::EPS,100000,0.0000001));
    nnetwork->setTrainMethod(ANN_MLP::BACKPROP);
    cout << "Creating the ANN" << endl;
    nnetwork->train(datos);
    cout << "Trained" << endl;
    FileStorage fs("/home/josejacomeb/QT/MLP_DosCapas/parametros.xml",FileStorage::WRITE);
    nnetwork->write(fs);
    fs.release();
    cout << "Writing XML" << endl;
}
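The code above selects ANN_MLP::SIGMOID_SYM with parameters 0.6 and 1. The OpenCV documentation describes this activation as the symmetric sigmoid f(x) = beta * (1 - exp(-alpha*x)) / (1 + exp(-alpha*x)), where alpha and beta are the two parameters of setActivationFunction(). A plain C++ sketch of that formula (sigmoidSym is a hypothetical name, not an OpenCV function) shows that the outputs stay between -beta and beta, which is the range the output neurons can produce:

```cpp
#include <cassert>
#include <cmath>

// Symmetric sigmoid as documented for cv::ml::ANN_MLP::SIGMOID_SYM:
//   f(x) = beta * (1 - exp(-alpha * x)) / (1 + exp(-alpha * x))
// alpha and beta correspond to param1 and param2 of setActivationFunction().
// This sketch only handles alpha > 0.
double sigmoidSym(double x, double alpha, double beta) {
    assert(alpha > 0.0);
    const double e = std::exp(-alpha * x);
    return beta * (1.0 - e) / (1.0 + e);
}
```

For example, sigmoidSym(0.0, 0.6, 1.0) is 0, and large positive inputs approach, but never reach, beta = 1.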

Then I get the following error:

Opencv Error: Bad argument (output training data should be a floating-point matrix with  the number of rows equal to the number of training samples and the number of columns equal to the size of last (output) layer) in prepare_to_train, file /build/opencv/src/opencv-3.2.0/modules/ml/src/ann_mlp.cpp
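The message states the rule the responses must satisfy: floating-point data, one row per training sample, and one column per neuron of the last layer. With the original CSV and default arguments, loadFromCSV() treats the last column as the single response, so the responses have one column while layerSizes ends in 2. A plain C++ sketch of that check as the message describes it (ResponseShape and responsesMatchNetwork are hypothetical stand-ins, not OpenCV's actual implementation):

```cpp
#include <cassert>
#include <vector>

// Hypothetical description of a response matrix (a stand-in for cv::Mat).
struct ResponseShape {
    int rows;
    int cols;
    bool isFloat;  // true for CV_32F-like data
};

// Mirrors the condition stated in the error message: float data,
// rows == number of training samples, cols == size of the output layer.
bool responsesMatchNetwork(const ResponseShape& r, int numSamples,
                           const std::vector<int>& layerSizes) {
    assert(!layerSizes.empty());
    const int outputNeurons = layerSizes.back();
    return r.isFloat && r.rows == numSamples && r.cols == outputNeurons;
}
```

With layerSizes = {16, 32, 2} and two samples, a 2x1 label column fails the check, while a 2x2 one-hot matrix passes.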

I read in this forum that:

the train responses for an ann differ a bit from the usual opencv ml approach. if you have 2 output neurons in your ann, you need 2 output neurons for each training feature too, not a single "class label" (like with e.g. an SVM).

How can I get a two-dimensional response array from TrainData::loadFromCSV() to train the ANN_MLP?

I use OpenCV 3.2 and GCC 6.


1 answer


answered 2017-02-16 01:27:51 -0600

berak

updated 2017-02-17 01:54:29 -0600

Your ANN needs "one-hot-encoded" responses to train: [1,0] for label 0 and [0,1] for label 1.
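The encoding itself is simple. A plain C++ sketch (oneHot is a hypothetical helper, not an OpenCV function) that turns a class label into the target row for numClasses output neurons:

```cpp
#include <cassert>
#include <vector>

// One-hot encode a class label: label k -> numClasses zeros with a 1 at index k.
std::vector<float> oneHot(int label, int numClasses) {
    assert(label >= 0 && label < numClasses);
    std::vector<float> row(static_cast<std::size_t>(numClasses), 0.0f);
    row[static_cast<std::size_t>(label)] = 1.0f;
    return row;
}
```

oneHot(0, 2) gives [1,0] and oneHot(1, 2) gives [0,1], matching the two targets above.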

You can either change the CSV file so that each line has the 16 data values followed by 2 response values at the end:

0,0,1,0,0,0,1,0,0,0,1,0,0,0,1,0,0,1
1,1,1,1,1,0,0,1,1,0,0,1,1,1,1,1,1,0

and use

Ptr<TrainData> datos = TrainData::loadFromCSV("/home/josejacomeb/QT/MLP_DosCapas/mlp_2capas.csv", 0, 16, 18);

(By the way, the second argument of loadFromCSV() is headerLineCount, not ROW_SAMPLE!)
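In that call, headerLineCount is 0, responseStartIdx is 16 and responseEndIdx is 18; the end index is one past the last response column, so columns 16 and 17 (0-based) become the responses and the remaining 16 columns become the samples. A plain C++ sketch of that split (SplitRow and splitColumns are hypothetical names; this is not OpenCV's loader):

```cpp
#include <cassert>
#include <vector>

struct SplitRow {
    std::vector<float> samples;    // input features
    std::vector<float> responses;  // values from columns [responseStart, responseEnd)
};

// Columns in [responseStart, responseEnd) are responses, all other columns are inputs.
SplitRow splitColumns(const std::vector<float>& row, int responseStart, int responseEnd) {
    const int n = static_cast<int>(row.size());
    assert(0 <= responseStart && responseStart < responseEnd && responseEnd <= n);
    SplitRow out;
    for (int c = 0; c < n; ++c) {
        if (c >= responseStart && c < responseEnd)
            out.responses.push_back(row[c]);
        else
            out.samples.push_back(row[c]);
    }
    return out;
}
```

Applied to the first line of the modified file, this yields 16 samples and the responses [0,1].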

Or fix the responses after reading your unchanged CSV file:

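 Mat labels = datos->getTrainResponses();   // one float label per sample
 int num_samples = labels.rows;
 Mat responses(num_samples, 2, CV_32F, Scalar(0.0f));   // start with all zeros
 for (int i = 0; i < num_samples; i++) {
      int id = (int)labels.at<float>(i);  // 0 or 1
      responses.at<float>(i, id) = 1.0f;  // set the one-hot entry
 }
 nnetwork->train(datos->getTrainSamples(), ROW_SAMPLE, responses);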

Comments

Thanks so much, it works! I modified the CSV file and used the first code snippet. One final question: for an ANN_MLP with ten outputs, should I add the responses like this: for 0, 0,1,2,3,4,5,6,7,8,9; for 1, 1,2,3,4,5,6,7,8,9,0; ...; for 9, 9,0,1,2,3,4,5,6,7,8?

José Jácome (2017-02-19 22:06:03 -0600)

For an ANN with 10 output neurons, you have to add 10 numbers: [1,0,0,0,0,0,0,0,0,0] for label 0, [0,1,0,0,0,0,0,0,0,0] for label 1, [0,0,1,0,0,0,0,0,0,0] for label 2, etc., so each line has 16 data values followed by 10 response values.

Then: TrainData::loadFromCSV("my.csv", 0, 16, 26);

berak (2017-02-20 03:32:07 -0600)
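Writing the training file in that layout, 16 inputs followed by numClasses one-hot values, is also why responseEndIdx becomes 16 + 10 = 26 for ten classes. A plain C++ sketch (makeCsvLine is a hypothetical helper; it only produces text and does not use OpenCV):

```cpp
#include <cassert>
#include <cstddef>
#include <sstream>
#include <string>
#include <vector>

// Hypothetical helper: write one training line as
// <input values>,<numClasses one-hot response values>
std::string makeCsvLine(const std::vector<int>& features, int label, int numClasses) {
    assert(label >= 0 && label < numClasses);
    std::ostringstream out;
    for (std::size_t i = 0; i < features.size(); ++i)
        out << features[i] << ',';
    for (int k = 0; k < numClasses; ++k) {
        out << (k == label ? 1 : 0);
        if (k + 1 < numClasses) out << ',';
    }
    return out.str();
}
```

With the digit-1 pixels and numClasses = 2 this reproduces the first line of the two-class file; with numClasses = 10 each line has 26 columns.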
