Skip to content

Commit 11ae3e8

Browse files
fixes #110, makes data formatter a set of singletons and don't overwrite input, but rather use rawInput
Also unit test. Bump version number.
1 parent e9ec361 commit 11ae3e8

File tree

11 files changed

+48
-16
lines changed

11 files changed

+48
-16
lines changed

package.json

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
{
22
"name": "brain.js",
33
"description": "Neural network library",
4-
"version": "1.0.0-rc.6",
4+
"version": "1.0.0-rc.7",
55
"author": "Heather Arthur <[email protected]>",
66
"repository": {
77
"type": "git",

src/neural-network-gpu.js

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -290,7 +290,7 @@ export default class NeuralNetworkGPU extends NeuralNetwork {
290290
}
291291
// turn sparse hash input into arrays with 0s as filler
292292
let datum = data[0].input;
293-
if (!Array.isArray(datum) && !(datum instanceof Float64Array)) {
293+
if (!Array.isArray(datum) && !(datum instanceof Float32Array)) {
294294
if (!this.inputLookup) {
295295
this.inputLookup = lookup.buildLookup(data.map(value => value['input']));
296296
}

src/neural-network.js

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -441,7 +441,7 @@ export default class NeuralNetwork {
441441
}
442442
// turn sparse hash input into arrays with 0s as filler
443443
let datum = data[0].input;
444-
if (!Array.isArray(datum) && !(datum instanceof Float64Array)) {
444+
if (!Array.isArray(datum) && !(datum instanceof Float32Array)) {
445445
if (!this.inputLookup) {
446446
this.inputLookup = lookup.buildLookup(data.map(value => value['input']));
447447
}

src/recurrent/matrix/index.js

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ export default class Matrix {
2121
*
2222
* @param {Number} row
2323
* @param {Number} col
24-
* @returns {Float64Array|Array}
24+
* @returns {Float32Array|Array}
2525
*/
2626
getWeights(row, col) {
2727
// slow but careful accessor function

src/recurrent/rnn.js

Lines changed: 14 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -615,11 +615,12 @@ export default class RNN {
615615
}
616616
}
617617

618-
return new Function('input', 'maxPredictionLength', 'isSampleI', 'temperature', `
618+
const src = `
619619
if (typeof input === 'undefined') input = [];
620620
if (typeof maxPredictionLength === 'undefined') maxPredictionLength = 100;
621621
if (typeof isSampleI === 'undefined') isSampleI = false;
622622
if (typeof temperature === 'undefined') temperature = 1;
623+
${ (this.dataFormatter !== null) ? this.dataFormatter.toFunctionString() : '' }
623624
624625
${
625626
(this.dataFormatter !== null && typeof this.formatDataIn === 'function')
@@ -677,24 +678,31 @@ ${ innerFunctionsSwitch.join('\n') }
677678
${ (this.dataFormatter !== null && typeof this.formatDataOut === 'function')
678679
? 'return formatDataOut(output.slice(input.length).map(function(value) { return value - 1; }))'
679680
: 'return output.slice(input.length).map(function(value) { return value - 1; })' };
680-
681681
function Matrix(rows, columns) {
682682
this.rows = rows;
683683
this.columns = columns;
684684
this.weights = zeros(rows * columns);
685685
}
686686
${ this.dataFormatter !== null && typeof this.formatDataIn === 'function'
687-
? `function formatDataIn(input, output) { ${ toInner(this.formatDataIn.toString()).replace('this.dataFormatter', 'json.options.dataFormatter') } }`
687+
? `function formatDataIn(input, output) { ${
688+
toInner(this.formatDataIn.toString())
689+
.replace(/this[.]dataFormatter[.]/g, '')
690+
.replace(/this[.]dataFormatter/g, 'true')
691+
} }`
688692
: '' }
689693
${ this.dataFormatter !== null && typeof this.formatDataOut === 'function'
690-
? `function formatDataOut(output) { ${ toInner(this.formatDataIn.toString()).replace('this.dataFormatter', 'json.options.dataFormatter') } }`
694+
? `function formatDataOut(output) { ${
695+
toInner(this.formatDataIn.toString())
696+
.replace(/this[.]dataFormatter[.]/g, '')
697+
.replace(/this[.]dataFormatter/g, 'true')
698+
} }`
691699
: '' }
692-
${ (this.dataFormatter !== null) ? this.dataFormatter.toFunctionString('json.options.dataFormatter') : '' }
693700
${ zeros.toString() }
694701
${ softmax.toString().replace('_2.default', 'Matrix') }
695702
${ randomF.toString() }
696703
${ sampleI.toString() }
697-
${ maxI.toString() }`)
704+
${ maxI.toString() }`;
705+
return new Function('rawInput', 'maxPredictionLength', 'isSampleI', 'temperature', src);
698706
}
699707
}
700708

src/utilities/data-formatter.js

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -161,11 +161,16 @@ export default class DataFormatter {
161161
}
162162
}
163163

164-
toFunctionString(dataFormatterVariableName) {
164+
toFunctionString() {
165165
return `
166-
${ this.toIndexes.toString().replace('this', dataFormatterVariableName) }
167-
${ this.toIndexesInputOutput.toString().replace('this', dataFormatterVariableName) }
168-
${ this.toCharacters.toString().replace('this', dataFormatterVariableName) }
166+
var characterTable = ${ JSON.stringify(this.characterTable) };
167+
var indexTable = ${ JSON.stringify(this.indexTable) };
168+
var characters = ${ JSON.stringify(this.characters) };
169+
${ this.toIndexes.toString()
170+
.replace(/(let|var) indexTable = this[.]indexTable;\n/, '')
171+
.replace(/this[.]/g, '') }
172+
${ this.toIndexesInputOutput.toString().replace(/this[.]/g, '') }
173+
${ this.toCharacters.toString().replace(/this[.]/, '') }
169174
`;
170175
}
171176
}

src/utilities/ones.js

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
export default function ones(size) {
2-
// if (typeof Float64Array !== 'undefined') return new Float64Array(size).fill(1);
2+
if (typeof Float32Array !== 'undefined') return new Float32Array(size).fill(1);
33
let array = new Array(size);
44
for (let i = 0; i < size; i++) {
55
array[i] = 1;

src/utilities/zeros.js

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
export default function zeros(size) {
2-
// if (typeof Float64Array !== 'undefined') return new Float64Array(size);
2+
if (typeof Float32Array !== 'undefined') return new Float32Array(size);
33
let array = new Array(size);
44
for (let i = 0; i < size; i++) {
55
array[i] = 0;

test/recurrent/gru.js

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -159,5 +159,12 @@ describe('gru', () => {
159159
var lastOutput = dataFormatter.toCharacters(net.run()).join('');
160160
assert.equal(dataFormatter.toCharacters(net.toFunction()()).join(''), lastOutput);
161161
});
162+
163+
it('can include the DataFormatter', () => {
164+
const net = new GRU();
165+
net.train(['hi mom!'], { iterations: 1 });
166+
const newNet = net.toFunction();
167+
newNet('hi mom!');
168+
});
162169
});
163170
});

test/recurrent/lstm.js

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -109,6 +109,12 @@ describe('lstm', () => {
109109
var lastOutput = dataFormatter.toCharacters(net.run()).join('');
110110
assert.equal(dataFormatter.toCharacters(net.toFunction()()).join(''), lastOutput);
111111
});
112+
it('can include the DataFormatter', () => {
113+
const net = new LSTM();
114+
net.train(['hi mom!'], { iterations: 1 });
115+
const newNet = net.toFunction();
116+
newNet('hi mom!');
117+
});
112118
});
113119

114120
describe('.run', () => {

0 commit comments

Comments
 (0)