Skip to content

Commit fc48ae6

Browse files
Merge pull request #281 from BrainJS/185-cross-validation-fixes
fix: Fix CrossValidate to have tests for when data too small
2 parents f0a1a56 + ca437f3 commit fc48ae6

File tree

14 files changed

+85
-53
lines changed

14 files changed

+85
-53
lines changed

README.md

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -279,8 +279,9 @@ With multiple networks you can train in parallel like this:
279279
### Cross Validation
280280
[Cross Validation](https://en.wikipedia.org/wiki/Cross-validation_(statistics)) can provide a less fragile way of training on larger data sets. The brain.js api provides Cross Validation in this example:
281281
```js
282-
const crossValidate = new CrossValidate(brain.NeuralNetwork, networkOptions);
283-
const stats = crossValidate.train(data, trainingOptions, k); //note k (or KFolds) is optional
282+
const crossValidate = new brain.CrossValidate(brain.NeuralNetwork, networkOptions);
283+
crossValidate.train(data, trainingOptions, k); //note k (or KFolds) is optional
284+
const json = crossValidate.toJSON(); // all stats in json as well as neural networks
284285
const net = crossValidate.toNeuralNetwork();
285286

286287

bower.json

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,5 +31,5 @@
3131
"node_modules",
3232
"test"
3333
],
34-
"version": "1.4.1"
34+
"version": "1.4.2"
3535
}

browser.js

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
* license: MIT (http://opensource.org/licenses/MIT)
77
* author: Heather Arthur <[email protected]>
88
* homepage: https://github.com/brainjs/brain.js#readme
9-
* version: 1.4.1
9+
* version: 1.4.2
1010
*
1111
* acorn:
1212
* license: MIT (http://opensource.org/licenses/MIT)
@@ -214,8 +214,13 @@ var CrossValidate = function () {
214214

215215
}, {
216216
key: "train",
217-
value: function train(data, trainOpts, k) {
218-
k = k || 4;
217+
value: function train(data) {
218+
var trainOpts = arguments.length > 1 && arguments[1] !== undefined ? arguments[1] : {};
219+
var k = arguments.length > 2 && arguments[2] !== undefined ? arguments[2] : 4;
220+
221+
if (data.length <= k) {
222+
throw new Error("Training set size is too small for " + data.length + " k folds of " + k);
223+
}
219224
var size = data.length / k;
220225

221226
if (data.constructor === Array) {
@@ -1946,8 +1951,8 @@ var NeuralNetwork = function () {
19461951
falseNeg: falseNeg,
19471952
falsePos: falsePos,
19481953
total: data.length,
1949-
precision: truePos / (truePos + falsePos),
1950-
recall: truePos / (truePos + falseNeg),
1954+
precision: truePos > 0 ? truePos / (truePos + falsePos) : 0,
1955+
recall: truePos > 0 ? truePos / (truePos + falseNeg) : 0,
19511956
accuracy: (trueNeg + truePos) / data.length
19521957
});
19531958
}

browser.min.js

Lines changed: 7 additions & 6 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

dist/cross-validate.js

Lines changed: 7 additions & 2 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

dist/cross-validate.js.map

Lines changed: 1 addition & 1 deletion
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

dist/neural-network.js

Lines changed: 2 additions & 2 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

dist/neural-network.js.map

Lines changed: 1 addition & 1 deletion
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

examples-typescript/cross-validate.ts

Lines changed: 2 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -8,23 +8,11 @@ const trainingData = [
88
{ input: [1, 1], output: [0] },
99
{ input: [1, 0], output: [1] },
1010

11-
// xor repeats
11+
// repeat xor data to have enough to train with
1212
{ input: [0, 1], output: [1] },
1313
{ input: [0, 0], output: [0] },
1414
{ input: [1, 1], output: [0] },
15-
{ input: [1, 0], output: [1] },
16-
17-
// xor repeats
18-
{ input: [0, 1], output: [1] },
19-
{ input: [0, 0], output: [0] },
20-
{ input: [1, 1], output: [0] },
21-
{ input: [1, 0], output: [1] },
22-
23-
// xor repeats
24-
{ input: [0, 1], output: [1] },
25-
{ input: [0, 0], output: [0] },
26-
{ input: [1, 1], output: [0] },
27-
{ input: [1, 0], output: [1] },
15+
{ input: [1, 0], output: [1] }
2816
];
2917

3018
const netOptions = {

examples/cross-validate.js

Lines changed: 2 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -8,23 +8,11 @@ const trainingData = [
88
{ input: [1, 1], output: [0] },
99
{ input: [1, 0], output: [1] },
1010

11-
// xor repeats
11+
// repeat xor data to have enough to train with
1212
{ input: [0, 1], output: [1] },
1313
{ input: [0, 0], output: [0] },
1414
{ input: [1, 1], output: [0] },
15-
{ input: [1, 0], output: [1] },
16-
17-
// xor repeats
18-
{ input: [0, 1], output: [1] },
19-
{ input: [0, 0], output: [0] },
20-
{ input: [1, 1], output: [0] },
21-
{ input: [1, 0], output: [1] },
22-
23-
// xor repeats
24-
{ input: [0, 1], output: [1] },
25-
{ input: [0, 0], output: [0] },
26-
{ input: [1, 1], output: [0] },
27-
{ input: [1, 0], output: [1] },
15+
{ input: [1, 0], output: [1] }
2816
];
2917

3018
const netOptions = {

0 commit comments

Comments
 (0)