Skip to content

Commit 812f014

Browse files
Merge branch 'v1.x' into mubaidr-patch-1
2 parents 79096ba + e72c9ab commit 812f014

File tree

8 files changed

+276
-16
lines changed

8 files changed

+276
-16
lines changed

bower.json

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,5 +31,5 @@
3131
"node_modules",
3232
"test"
3333
],
34-
"version": "1.3.1"
34+
"version": "1.4.0"
3535
}

browser.js

Lines changed: 75 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
* license: MIT (http://opensource.org/licenses/MIT)
77
* author: Heather Arthur <[email protected]>
88
* homepage: https://github.com/brainjs/brain.js#readme
9-
* version: 1.3.1
9+
* version: 1.4.0
1010
*
1111
* acorn:
1212
* license: MIT (http://opensource.org/licenses/MIT)
@@ -1092,7 +1092,11 @@ var NeuralNetwork = function () {
10921092
momentum: 0.1, // multiplies against the specified "change" then adds to learning rate for change
10931093
callback: null, // a periodic call back that can be triggered while training
10941094
callbackPeriod: 10, // the number of iterations through the training data between callback calls
1095-
timeout: Infinity // the max number of milliseconds to train for
1095+
timeout: Infinity, // the max number of milliseconds to train for
1096+
praxis: null,
1097+
beta1: 0.9,
1098+
beta2: 0.999,
1099+
epsilon: 1e-8
10961100
};
10971101
}
10981102
}, {
@@ -1531,6 +1535,10 @@ var NeuralNetwork = function () {
15311535
endTime = _prepTraining2.endTime;
15321536

15331537

1538+
if (options.praxis === 'adam') {
1539+
this._setupAdam();
1540+
}
1541+
15341542
while (this._trainingTick(data, status, endTime)) {}
15351543
return status;
15361544
}
@@ -1737,6 +1745,71 @@ var NeuralNetwork = function () {
17371745
}
17381746
}
17391747
}
1748+
}, {
1749+
key: '_setupAdam',
1750+
value: function _setupAdam() {
1751+
this.biasChangesLow = [];
1752+
this.biasChangesHigh = [];
1753+
this.changesLow = [];
1754+
this.changesHigh = [];
1755+
this.iterations = 0;
1756+
1757+
for (var layer = 0; layer <= this.outputLayer; layer++) {
1758+
var size = this.sizes[layer];
1759+
if (layer > 0) {
1760+
this.biasChangesLow[layer] = (0, _zeros2.default)(size);
1761+
this.biasChangesHigh[layer] = (0, _zeros2.default)(size);
1762+
this.changesLow[layer] = new Array(size);
1763+
this.changesHigh[layer] = new Array(size);
1764+
1765+
for (var node = 0; node < size; node++) {
1766+
var prevSize = this.sizes[layer - 1];
1767+
this.changesLow[layer][node] = (0, _zeros2.default)(prevSize);
1768+
this.changesHigh[layer][node] = (0, _zeros2.default)(prevSize);
1769+
}
1770+
}
1771+
}
1772+
1773+
this._adjustWeights = this._adjustWeightsAdam;
1774+
}
1775+
}, {
1776+
key: '_adjustWeightsAdam',
1777+
value: function _adjustWeightsAdam() {
1778+
var trainOpts = this.trainOpts;
1779+
this.iterations++;
1780+
1781+
for (var layer = 1; layer <= this.outputLayer; layer++) {
1782+
var incoming = this.outputs[layer - 1];
1783+
1784+
for (var node = 0; node < this.sizes[layer]; node++) {
1785+
var delta = this.deltas[layer][node];
1786+
1787+
for (var k = 0; k < incoming.length; k++) {
1788+
var gradient = delta * incoming[k];
1789+
var changeLow = this.changesLow[layer][node][k] * trainOpts.beta1 + (1 - trainOpts.beta1) * gradient;
1790+
var changeHigh = this.changesHigh[layer][node][k] * trainOpts.beta2 + (1 - trainOpts.beta2) * gradient * gradient;
1791+
1792+
var momentumCorrection = changeLow / (1 - Math.pow(trainOpts.beta1, this.iterations));
1793+
var gradientCorrection = changeHigh / (1 - Math.pow(trainOpts.beta2, this.iterations));
1794+
1795+
this.changesLow[layer][node][k] = changeLow;
1796+
this.changesHigh[layer][node][k] = changeHigh;
1797+
this.weights[layer][node][k] += this.trainOpts.learningRate * momentumCorrection / (Math.sqrt(gradientCorrection) + trainOpts.epsilon);
1798+
}
1799+
1800+
var biasGradient = this.deltas[layer][node];
1801+
var biasChangeLow = this.biasChangesLow[layer][node] * trainOpts.beta1 + (1 - trainOpts.beta1) * biasGradient;
1802+
var biasChangeHigh = this.biasChangesHigh[layer][node] * trainOpts.beta2 + (1 - trainOpts.beta2) * biasGradient * biasGradient;
1803+
1804+
var biasMomentumCorrection = this.biasChangesLow[layer][node] / (1 - Math.pow(trainOpts.beta1, this.iterations));
1805+
var biasGradientCorrection = this.biasChangesHigh[layer][node] / (1 - Math.pow(trainOpts.beta2, this.iterations));
1806+
1807+
this.biasChangesLow[layer][node] = biasChangeLow;
1808+
this.biasChangesHigh[layer][node] = biasChangeHigh;
1809+
this.biases[layer][node] += trainOpts.learningRate * biasMomentumCorrection / (Math.sqrt(biasGradientCorrection) + trainOpts.epsilon);
1810+
}
1811+
}
1812+
}
17401813

17411814
/**
17421815
*

browser.min.js

Lines changed: 3 additions & 3 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

dist/neural-network.js

Lines changed: 74 additions & 1 deletion
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

dist/neural-network.js.map

Lines changed: 1 addition & 1 deletion
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

package.json

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
{
22
"name": "brain.js",
33
"description": "Neural network library",
4-
"version": "1.3.1",
4+
"version": "1.4.0",
55
"author": "Heather Arthur <[email protected]>",
66
"repository": {
77
"type": "git",

src/neural-network.js

Lines changed: 72 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,11 @@ export default class NeuralNetwork {
2222
momentum: 0.1, // multiplies against the specified "change" then adds to learning rate for change
2323
callback: null, // a periodic call back that can be triggered while training
2424
callbackPeriod: 10, // the number of iterations through the training data between callback calls
25-
timeout: Infinity // the max number of milliseconds to train for
25+
timeout: Infinity, // the max number of milliseconds to train for
26+
praxis: null,
27+
beta1: 0.9,
28+
beta2: 0.999,
29+
epsilon: 1e-8,
2630
};
2731
}
2832

@@ -443,6 +447,10 @@ export default class NeuralNetwork {
443447
let status, endTime;
444448
({ data, status, endTime } = this._prepTraining(data, options));
445449

450+
if (options.praxis === 'adam') {
451+
this._setupAdam();
452+
}
453+
446454
while (this._trainingTick(data, status, endTime));
447455
return status;
448456
}
@@ -619,6 +627,69 @@ export default class NeuralNetwork {
619627
}
620628
}
621629

630+
_setupAdam() {
631+
this.biasChangesLow = [];
632+
this.biasChangesHigh = [];
633+
this.changesLow = [];
634+
this.changesHigh = [];
635+
this.iterations = 0;
636+
637+
for (let layer = 0; layer <= this.outputLayer; layer++) {
638+
let size = this.sizes[layer];
639+
if (layer > 0) {
640+
this.biasChangesLow[layer] = zeros(size);
641+
this.biasChangesHigh[layer] = zeros(size);
642+
this.changesLow[layer] = new Array(size);
643+
this.changesHigh[layer] = new Array(size);
644+
645+
for (let node = 0; node < size; node++) {
646+
let prevSize = this.sizes[layer - 1];
647+
this.changesLow[layer][node] = zeros(prevSize);
648+
this.changesHigh[layer][node] = zeros(prevSize);
649+
}
650+
}
651+
}
652+
653+
this._adjustWeights = this._adjustWeightsAdam;
654+
}
655+
656+
_adjustWeightsAdam() {
657+
const trainOpts = this.trainOpts;
658+
this.iterations++;
659+
660+
for (let layer = 1; layer <= this.outputLayer; layer++) {
661+
const incoming = this.outputs[layer - 1];
662+
663+
for (let node = 0; node < this.sizes[layer]; node++) {
664+
const delta = this.deltas[layer][node];
665+
666+
for (let k = 0; k < incoming.length; k++) {
667+
const gradient = delta * incoming[k];
668+
const changeLow = this.changesLow[layer][node][k] * trainOpts.beta1 + (1 - trainOpts.beta1) * gradient;
669+
const changeHigh = this.changesHigh[layer][node][k] * trainOpts.beta2 + (1 - trainOpts.beta2) * gradient * gradient;
670+
671+
const momentumCorrection = changeLow / (1 - Math.pow(trainOpts.beta1, this.iterations));
672+
const gradientCorrection = changeHigh / (1 - Math.pow(trainOpts.beta2, this.iterations));
673+
674+
this.changesLow[layer][node][k] = changeLow;
675+
this.changesHigh[layer][node][k] = changeHigh;
676+
this.weights[layer][node][k] += this.trainOpts.learningRate * momentumCorrection / (Math.sqrt(gradientCorrection) + trainOpts.epsilon);
677+
}
678+
679+
const biasGradient = this.deltas[layer][node];
680+
const biasChangeLow = this.biasChangesLow[layer][node] * trainOpts.beta1 + (1 - trainOpts.beta1) * biasGradient;
681+
const biasChangeHigh = this.biasChangesHigh[layer][node] * trainOpts.beta2 + (1 - trainOpts.beta2) * biasGradient * biasGradient;
682+
683+
const biasMomentumCorrection = this.biasChangesLow[layer][node] / (1 - Math.pow(trainOpts.beta1, this.iterations));
684+
const biasGradientCorrection = this.biasChangesHigh[layer][node] / (1 - Math.pow(trainOpts.beta2, this.iterations));
685+
686+
this.biasChangesLow[layer][node] = biasChangeLow;
687+
this.biasChangesHigh[layer][node] = biasChangeHigh;
688+
this.biases[layer][node] += trainOpts.learningRate * biasMomentumCorrection / (Math.sqrt(biasGradientCorrection) + trainOpts.epsilon);
689+
}
690+
}
691+
}
692+
622693
/**
623694
*
624695
* @param data

test/base/bitwise.js

Lines changed: 49 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -14,9 +14,20 @@ function testBitwise(data, op) {
1414
let res = net.train(data, { errorThresh: 0.003 });
1515

1616
data.forEach(d => {
17-
var actual = net.run(d.input)
17+
var actual = net.run(d.input);
1818
var expected = d.output;
19-
assert.ok(isAround(actual, expected), `failed to train "${op}" - expected: ${expected}, actual: ${actual}`);
19+
assert.ok(isAround(actual[0], expected[0]), `failed to train "${op}" - expected: ${expected}, actual: ${actual}`);
20+
});
21+
}
22+
23+
function testBitwiseAdam(data, op) {
24+
let net = new brain.NeuralNetwork();
25+
let res = net.train(data, { errorThresh: 0.003, learningRate: 0.05, praxis: 'adam' });
26+
27+
data.forEach(d => {
28+
var actual = net.run(d.input);
29+
var expected = d.output;
30+
assert.ok(isAround(actual[0], expected[0]), `failed to train "${op}" - expected: ${expected}, actual: ${actual}`);
2031
});
2132
}
2233

@@ -45,10 +56,10 @@ describe('bitwise functions sync training', () => {
4556
});
4657

4758
it('XOR function', () => {
48-
let xor = [{input: [0, 0], output: [0]},
49-
{input: [0, 1], output: [1]},
50-
{input: [1, 0], output: [1]},
51-
{input: [1, 1], output: [0]}];
59+
let xor = [{input: [0.001, 0.001], output: [0.001]},
60+
{input: [0.001, 1], output: [1]},
61+
{input: [1, 0.001], output: [1]},
62+
{input: [1, 1], output: [0.001]}];
5263
testBitwise(xor, 'xor');
5364
});
5465

@@ -67,4 +78,36 @@ describe('bitwise functions sync training', () => {
6778
{input: [1, 1], output: [1]}];
6879
testBitwise(and, 'and');
6980
});
81+
});
82+
83+
describe('bitwise using adam praxis functions sync training', () => {
84+
it('NOT function', () => {
85+
let not = [{input: [0], output: [1]},
86+
{input: [1], output: [0]}];
87+
testBitwiseAdam(not, 'not');
88+
});
89+
90+
it('XOR function', () => {
91+
let xor = [{input: [0.001, 0.001], output: [0.001]},
92+
{input: [0.001, 1], output: [1]},
93+
{input: [1, 0.001], output: [1]},
94+
{input: [1, 1], output: [0.001]}];
95+
testBitwiseAdam(xor, 'xor');
96+
});
97+
98+
it('OR function', () => {
99+
let or = [{input: [0, 0], output: [0]},
100+
{input: [0, 1], output: [1]},
101+
{input: [1, 0], output: [1]},
102+
{input: [1, 1], output: [1]}];
103+
testBitwiseAdam(or, 'or');
104+
});
105+
106+
it('AND function', () => {
107+
let and = [{input: [0, 0], output: [0]},
108+
{input: [0, 1], output: [0]},
109+
{input: [1, 0], output: [0]},
110+
{input: [1, 1], output: [1]}];
111+
testBitwiseAdam(and, 'and');
112+
});
70113
});

0 commit comments

Comments
 (0)