Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Appearance settings

Add support for multithreaded training in the neural net example #2454

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 5 commits into
base: master
Choose a base branch
Loading
from
Open
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
Clang format changes
  • Loading branch information
9prady9 committed Mar 28, 2019
commit 73ac2c665fa2431522c35744580bba7aa0e8952c
67 changes: 35 additions & 32 deletions 67 examples/machine_learning/neural_network.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,13 +8,13 @@
********************************************************/

#include <arrayfire.h>
#include <math.h>
#include <stdio.h>
#include <vector>
#include <af/util.h>
#include <list>
#include <string>
#include <thread>
#include <af/util.h>
#include <math.h>
#include <vector>
#include "mnist_common.h"

using namespace af;
Expand Down Expand Up @@ -157,7 +157,8 @@ double ann::train(const array &input, const array &target, double alpha,

if (verbose) {
if ((i + 1) % 10 == 0)
printf("Device: %d, Epoch: %4d, Error: %0.4f\n", af::getDevice(), i + 1, err);
printf("Device: %d, Epoch: %4d, Error: %0.4f\n",
af::getDevice(), i + 1, err);
}
}
return err;
Expand Down Expand Up @@ -209,7 +210,6 @@ int ann_demo(bool console, int perc) {
array train_output = network.predict(train_feats);
array test_output = network.predict(test_feats);


// Benchmark prediction
af::sync();
timer::start();
Expand All @@ -221,10 +221,12 @@ int ann_demo(bool console, int perc) {
accuracy(train_output, train_target), af::getDevice());

printf("Accuracy on testing data: %2.2f device: %d\n",
accuracy(test_output , test_target ), af::getDevice());
accuracy(test_output, test_target), af::getDevice());

printf("\nTraining time on device %d: %4.4lf s\n", af::getDevice(), train_time);
printf("Prediction time on device %d: %4.4lf s\n\n", af::getDevice(), test_time);
printf("\nTraining time on device %d: %4.4lf s\n", af::getDevice(),
train_time);
printf("Prediction time on device %d: %4.4lf s\n\n", af::getDevice(),
test_time);

if (!console) {
// Get 20 random test images.
Expand All @@ -236,47 +238,48 @@ int ann_demo(bool console, int perc) {
}

class learner {
public:
void learn(const unsigned d, const bool console, const int perc) {
printf("Starting new learner thread on device %d\n", d);
af::setDevice(d);
af::array r = af::randu(10);
ann_demo(console, perc);
}
public:
void learn(const unsigned d, const bool console, const int perc) {
printf("Starting new learner thread on device %d\n", d);
af::setDevice(d);
af::array r = af::randu(10);
ann_demo(console, perc);
}
};

int main(int argc, char** argv)
{
int main(int argc, char **argv) {
int device = argc > 1 ? atoi(argv[1]) : 0;
bool console = argc > 2 ? argv[2][0] == '-' : false;
int perc = argc > 3 ? atoi(argv[3]) : 60; // percentage training/test data
int perc = argc > 3 ? atoi(argv[3]) : 60; // percentage training/test data
af::info();
const unsigned dc = af::getDeviceCount();
printf("** ArrayFire ANN Demo **\n\n");
printf("Usage: %s deviceId console percentage\n", argv[0]);
printf("- deviceId: either a device id (>= 0). If -1, 1 training will be triggered per device\n");
printf(
"- deviceId: either a device id (>= 0). If -1, 1 training will be "
"triggered per device\n");
printf("- console: console mode\n");
printf("- percentage: percent of training/testing data, default 60% used for training\n");
printf(
"- percentage: percent of training/testing data, default 60% used for "
"training\n");
af::info();

std::list<learner> ls;
std::list<std::thread> ts;
try {
if (device < 0) {
for (unsigned i = 0; i < dc; ++i) {
ls.push_back(learner());
ts.push_back(std::thread(&learner::learn, ls.back(), i, console, perc));
}
}
else {
for (unsigned i = 0; i < dc; ++i) {
ls.push_back(learner());
ts.push_back(
std::thread(&learner::learn, ls.back(), i, console, perc));
}
} else {
ls.push_back(learner());
ts.push_back(std::thread(&learner::learn, ls.back(), device, console, perc));
ts.push_back(
std::thread(&learner::learn, ls.back(), device, console, perc));
}
for (auto& t : ts)
t.join();
} catch (af::exception &ae) {
std::cerr << ae.what() << std::endl;
}
for (auto &t : ts) t.join();
} catch (af::exception &ae) { std::cerr << ae.what() << std::endl; }

return 0;
}
Morty Proxy This is a proxified and sanitized view of the page, visit original site.