Skip to content

Commit 2a0093e

Browse files
authored
Merge pull request #148 from BrainJS/valid-options
Handles User Object Validation #142
2 parents 9da27e7 + 23af33d commit 2a0093e

File tree

7 files changed

+11608
-11295
lines changed

7 files changed

+11608
-11295
lines changed

README.md

Lines changed: 12 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -129,15 +129,16 @@ var output = net.run({ r: 1, g: 0.4, b: 0 }); // { white: 0.81, black: 0.18 }
129129

130130
```javascript
131131
net.train(data, {
132-
iterations: 20000, // the maximum times to iterate the training data
133-
errorThresh: 0.005, // the acceptable error percentage from training data
134-
log: false, // true to use console.log, when a function is supplied it is used
135-
logPeriod: 10, // iterations between logging out
136-
learningRate: 0.3, // scales with delta to effect traiing rate
137-
momentum: 0.1, // scales with next layer's change value
138-
callback: null, // a periodic call back that can be triggered while training
139-
callbackPeriod: 10, // the number of iterations through the training data between callback calls
140-
timeout: Infinity // the max number of milliseconds to train for
132+
// Defaults values --> expected validation
133+
iterations: 20000, // the maximum times to iterate the training data --> number greater than 0
134+
errorThresh: 0.005, // the acceptable error percentage from training data --> number between 0 and 1
135+
log: false, // true to use console.log, when a function is supplied it is used --> Either true or a function
136+
logPeriod: 10, // iterations between logging out --> number greater than 0
137+
learningRate: 0.3, // scales with delta to effect traiing rate --> number between 0 and 1
138+
momentum: 0.1, // scales with next layer's change value --> number between 0 and 1
139+
callback: null, // a periodic call back that can be triggered while training --> null or function
140+
callbackPeriod: 10, // the number of iterations through the training data between callback calls --> number greater than 0
141+
timeout: Infinity // the max number of milliseconds to train for --> number greater than 0
141142
});
142143
```
143144

@@ -151,6 +152,8 @@ The momentum is similar to learning rate, expecting a value from `0` to `1` as w
151152

152153
Any of these training options can be passed into the constructor or passed into the `updateTrainingOptions(opts)` method and they will be saved on the network and used any time you trian. If you save your network to json, these training options are saved and restored as well (except for callback and log, callback will be forgoten and log will be restored using console.log).
153154

155+
There is a boolean property called `invalidTrainOptsShouldThrow` that by default is set to true. While true if you enter a training option that is outside the normal range an error will be thrown with a message about the option you sent. When set to false no error is sent but a message is still sent to `console.warn` with the information.
156+
154157
### Async Training
155158
`trainAsync()` takes the same arguments as train (data and options). Instead of returning the results object from training it returns a promise that when resolved will return the training results object.
156159

browser.js

Lines changed: 11367 additions & 11222 deletions
Large diffs are not rendered by default.

browser.min.js

Lines changed: 62 additions & 60 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

dist/neural-network.js

Lines changed: 48 additions & 1 deletion
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

dist/neural-network.js.map

Lines changed: 1 addition & 1 deletion
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

src/neural-network.js

Lines changed: 27 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,31 @@ export default class NeuralNetwork {
3535
};
3636
}
3737

38+
/**
39+
*
40+
* @param options
41+
* @param boolean
42+
* @private
43+
*/
44+
static _validateTrainingOptions(options) {
45+
var validations = {
46+
iterations: (val) => { return typeof val === 'number' && val > 0; },
47+
errorThresh: (val) => { return typeof val === 'number' && val > 0 && val < 1; },
48+
log: (val) => { return typeof val === 'function' || typeof val === 'boolean'; },
49+
logPeriod: (val) => { return typeof val === 'number' && val > 0; },
50+
learningRate: (val) => { return typeof val === 'number' && val > 0 && val < 1; },
51+
momentum: (val) => { return typeof val === 'number' && val > 0 && val < 1; },
52+
callback: (val) => { return typeof val === 'function' || val === null },
53+
callbackPeriod: (val) => { return typeof val === 'number' && val > 0; },
54+
timeout: (val) => { return typeof val === 'number' && val > 0 }
55+
};
56+
Object.keys(NeuralNetwork.trainDefaults).forEach(key => {
57+
if (validations.hasOwnProperty(key) && !validations[key](options[key])) {
58+
throw new Error(`[${key}, ${options[key]}] is out of normal training range, your network will probably not train.`);
59+
}
60+
});
61+
}
62+
3863
constructor(options = {}) {
3964
Object.assign(this, this.constructor.defaults, options);
4065
this.hiddenSizes = options.hiddenLayers;
@@ -293,7 +318,8 @@ export default class NeuralNetwork {
293318
* activation: ['sigmoid', 'relu', 'leaky-relu', 'tanh']
294319
*/
295320
_updateTrainingOptions(opts) {
296-
Object.keys(NeuralNetwork.trainDefaults).forEach(opt => this.trainOpts[opt] = opts[opt] || this.trainOpts[opt]);
321+
Object.keys(NeuralNetwork.trainDefaults).forEach(opt => this.trainOpts[opt] = (opts.hasOwnProperty(opt)) ? opts[opt] : this.trainOpts[opt]);
322+
NeuralNetwork._validateTrainingOptions(this.trainOpts);
297323
this._setLogMethod(opts.log || this.trainOpts.log);
298324
this.activation = opts.activation || this.activation;
299325
}

test/base/trainopts.js

Lines changed: 91 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -166,4 +166,94 @@ describe('train() and trainAsync() use the same private methods', () => {
166166
done()
167167
});
168168
});
169-
});
169+
});
170+
171+
describe('training options validation', () => {
172+
it('iterations validation', () => {
173+
let net = new brain.NeuralNetwork();
174+
assert.throws(() => { net._updateTrainingOptions({ iterations: 'should be a string' }) });
175+
assert.throws(() => { net._updateTrainingOptions({ iterations: () => {} }) });
176+
assert.throws(() => { net._updateTrainingOptions({ iterations: false }) });
177+
assert.throws(() => { net._updateTrainingOptions({ iterations: -1 }) });
178+
assert.doesNotThrow(() => { net._updateTrainingOptions({ iterations: 5000 }) });
179+
});
180+
181+
it('errorThresh validation', () => {
182+
let net = new brain.NeuralNetwork();
183+
assert.throws(() => { net._updateTrainingOptions({ errorThresh: 'no strings'}) });
184+
assert.throws(() => { net._updateTrainingOptions({ errorThresh: () => {} }) });
185+
assert.throws(() => { net._updateTrainingOptions({ errorThresh: 5}) });
186+
assert.throws(() => { net._updateTrainingOptions({ errorThresh: -1}) });
187+
assert.throws(() => { net._updateTrainingOptions({ errorThresh: false}) });
188+
assert.doesNotThrow(() => { net._updateTrainingOptions({ errorThresh: 0.008}) });
189+
});
190+
191+
it('log validation', () => {
192+
let net = new brain.NeuralNetwork();
193+
assert.throws(() => { net._updateTrainingOptions({ log: 'no strings' }) });
194+
assert.throws(() => { net._updateTrainingOptions({ log: 4 }) });
195+
assert.doesNotThrow(() => { net._updateTrainingOptions({ log: false }) });
196+
assert.doesNotThrow(() => { net._updateTrainingOptions({ log: () => {} }) });
197+
});
198+
199+
it('logPeriod validation', () => {
200+
let net = new brain.NeuralNetwork();
201+
assert.throws(() => { net._updateTrainingOptions({ logPeriod: 'no strings' }) });
202+
assert.throws(() => { net._updateTrainingOptions({ logPeriod: -50 }) });
203+
assert.throws(() => { net._updateTrainingOptions({ logPeriod: () => {} }) });
204+
assert.throws(() => { net._updateTrainingOptions({ logPeriod: false }) });
205+
assert.doesNotThrow(() => { net._updateTrainingOptions({ logPeriod: 40 }) });
206+
});
207+
208+
it('learningRate validation', () => {
209+
let net = new brain.NeuralNetwork();
210+
assert.throws(() => { net._updateTrainingOptions({ learningRate: 'no strings' }) });
211+
assert.throws(() => { net._updateTrainingOptions({ learningRate: -50 }) });
212+
assert.throws(() => { net._updateTrainingOptions({ learningRate: 50 }) });
213+
assert.throws(() => { net._updateTrainingOptions({ learningRate: () => {} }) });
214+
assert.throws(() => { net._updateTrainingOptions({ learningRate: false }) });
215+
assert.doesNotThrow(() => { net._updateTrainingOptions({ learningRate: 0.5 }) });
216+
});
217+
218+
it('momentum validation', () => {
219+
let net = new brain.NeuralNetwork();
220+
assert.throws(() => { net._updateTrainingOptions({ momentum: 'no strings' }) });
221+
assert.throws(() => { net._updateTrainingOptions({ momentum: -50 }) });
222+
assert.throws(() => { net._updateTrainingOptions({ momentum: 50 }) });
223+
assert.throws(() => { net._updateTrainingOptions({ momentum: () => {} }) });
224+
assert.throws(() => { net._updateTrainingOptions({ momentum: false }) });
225+
assert.doesNotThrow(() => { net._updateTrainingOptions({ momentum: 0.8 }) });
226+
});
227+
228+
it('callback validation', () => {
229+
let net = new brain.NeuralNetwork();
230+
assert.throws(() => { net._updateTrainingOptions({ callback: 'no strings' }) });
231+
assert.throws(() => { net._updateTrainingOptions({ callback: 4 }) });
232+
assert.throws(() => { net._updateTrainingOptions({ callback: false }) });
233+
assert.doesNotThrow(() => { net._updateTrainingOptions({ callback: null }) });
234+
assert.doesNotThrow(() => { net._updateTrainingOptions({ callback: () => {} }) });
235+
});
236+
237+
it('callbackPeriod validation', () => {
238+
let net = new brain.NeuralNetwork();
239+
assert.throws(() => { net._updateTrainingOptions({ callbackPeriod: 'no strings' }) });
240+
assert.throws(() => { net._updateTrainingOptions({ callbackPeriod: -50 }) });
241+
assert.throws(() => { net._updateTrainingOptions({ callbackPeriod: () => {} }) });
242+
assert.throws(() => { net._updateTrainingOptions({ callbackPeriod: false }) });
243+
assert.doesNotThrow(() => { net._updateTrainingOptions({ callbackPeriod: 40 }) });
244+
});
245+
246+
it('timeout validation', () => {
247+
let net = new brain.NeuralNetwork();
248+
assert.throws(() => { net._updateTrainingOptions({ timeout: 'no strings' }) });
249+
assert.throws(() => { net._updateTrainingOptions({ timeout: -50 }) });
250+
assert.throws(() => { net._updateTrainingOptions({ timeout: () => {} }) });
251+
assert.throws(() => { net._updateTrainingOptions({ timeout: false }) });
252+
assert.doesNotThrow(() => { net._updateTrainingOptions({ timeout: 40 }) });
253+
});
254+
255+
it('should handle unsupported options', () => {
256+
let net = new brain.NeuralNetwork();
257+
assert.doesNotThrow(() => { net._updateTrainingOptions({ fakeProperty: 'should be handled fine' }) });
258+
})
259+
});

0 commit comments

Comments
 (0)