ONNX to TensorFlow.js conversion of GPT-2 #927
base: develop
Conversation
martinjaggi
left a comment
amazing work, thanks a lot!
later I'd be very curious to check whether the conversion in the reverse direction (to ONNX) also works fine.
tharvik
left a comment
haa, that's great work, well done hacking around the ONNX protobuf!
nothing blocking (except the `tsc --build` thing), only a few nitpicks and questions here and there
@@ -1,50 +1,102 @@
// import fs from 'fs';
commented-out code left in?
import { Tokenizer, models } from '@epfml/discojs';
import { models, serialization, Tokenizer } from '@epfml/discojs';
import { loadHellaSwag } from '@epfml/discojs-node';
// import { AutoTokenizer } from '@xenova/transformers';
commented-out code left in?
const logLines: string[] = [];
function log(message: string) {
  console.log(message);
  logLines.push(message);
}
we don't need a log system for the CLI, we can simply output to the console, no?
case 'gpt-tfjs-random':
  log("Using GPT-TFJS with random initialization")
  model = new models.GPT({ seed: 42 });
  break;
case 'gpt-tfjs-pretrained':
that was confusing
const defaultPretrainedModelPath = path.join(__dirname, "..", "..", "onnx-converter", "assets", "model.json")
const args = parse<HellaSwagArgs>({
  model: {
    type: (raw: string) => raw as ModelType,
casting isn't nice, especially for user interfaces, better to implement some value checking with a switch
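A minimal sketch of such value checking, assuming the `ModelType` variants visible in the surrounding diff (the function name is illustrative):

```ts
type ModelType = 'gpt-tfjs-random' | 'gpt-tfjs-pretrained';

// Hypothetical parser: rejects unknown values instead of blindly casting.
function parseModelType(raw: string): ModelType {
  switch (raw) {
    case 'gpt-tfjs-random':
    case 'gpt-tfjs-pretrained':
      // TypeScript narrows `raw` to the matching literal type in each case.
      return raw;
    default:
      throw new Error(`unknown model type: ${raw}`);
  }
}
```

Passing such a parser as the option's `type` makes invalid input fail at argument-parse time rather than later.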
console.log("WARNING: protobuf raw data is empty, falling back on specific data fields.")
if (tensor.floatData && tensor.floatData.length > 0) return new Float32Array(tensor.floatData);
that looks to me like a more relevant field than rawData, why not use it first (and drop the warning)?
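A sketch of that reading order, assuming a float32 payload; the structural type below is a stand-in for the generated ONNX protobuf message type:

```ts
// Hypothetical helper: prefer the typed floatData field, fall back to rawData.
interface OnnxTensorLike {
  floatData?: number[];
  rawData?: Uint8Array;
}

function tensorToFloat32(tensor: OnnxTensorLike): Float32Array {
  if (tensor.floatData !== undefined && tensor.floatData.length > 0)
    return new Float32Array(tensor.floatData);
  if (tensor.rawData !== undefined && tensor.rawData.length > 0) {
    // Copy first so the Float32Array view starts at a 4-byte-aligned offset 0;
    // assumes the raw bytes encode float32 values.
    const bytes = tensor.rawData.slice();
    return new Float32Array(bytes.buffer, 0, bytes.byteLength / 4);
  }
  throw new Error('tensor carries neither floatData nor rawData');
}
```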
if (gptLayersModel.weights.length !== onnxTfjsMapping.size)
  throw new Error(`Mismatch between TFJS and ONNX weight mapping weights.`);

const finalWeights = gptLayersModel.weights.map((weight, _i) => {
Suggested change:
-const finalWeights = gptLayersModel.weights.map((weight, _i) => {
+const finalWeights = gptLayersModel.weights.map((weight) => {
const ASSET_FOLDER = path.join(__dirname, "..", "assets");
const OUTPUT_FILENAME = path.join(ASSET_FOLDER, "model.json");
can we more simply store it in the current directory, letting the user choose where to have it?
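A minimal sketch of that suggestion, reusing the constant name from the diff (resolving against `process.cwd()` is the assumption here):

```ts
import path from 'node:path';

// Hypothetical: resolve the output against the caller's working directory
// instead of a hard-coded assets folder next to the source tree.
const OUTPUT_FILENAME = path.join(process.cwd(), 'model.json');
```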
| "extends": "../tsconfig.base.json", | ||
| "compilerOptions": { "outDir": "dist" }, | ||
| "include": ["src"], | ||
| "exclude": ["**/*.spec.ts"] |
we don't have tests anyway
| "exclude": ["**/*.spec.ts"] |
pretrainedModelPath: {
  type: String,
  description: 'If specifying gpt-tfjs-pretrained, provide the relative path to the TF.js pretrained model',
  defaultValue: defaultPretrainedModelPath
single use, we can inline it IMO
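A sketch of the inlining, using the path from the diff above (the enclosing object literal is illustrative):

```ts
import path from 'node:path';

// Hypothetical inlined default: the single-use constant moves into the
// option definition instead of living as a top-level variable.
const options = {
  pretrainedModelPath: {
    type: String,
    description: 'If specifying gpt-tfjs-pretrained, provide the relative path to the TF.js pretrained model',
    defaultValue: path.join(__dirname, '..', '..', 'onnx-converter', 'assets', 'model.json'),
  },
};
```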
Roadmap:
In order to run load_weights.spec.ts, you need to convert the pretrained weights from .onnx to .jsonl with the Python conversion script.
Before running the script, make sure you have downloaded decoder_model.onnx from https://huggingface.co/Xenova/gpt2/tree/main/onnx and placed it at the ONNX_MODEL_PATH.
After generating the .jsonl file, place it at disco/discojs/gpt2_weights.jsonl.
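For context, a minimal sketch of how such a .jsonl weights file could be consumed on the TypeScript side; the per-line shape `{ name, shape, values }` is an assumption, not the actual format produced by the script:

```ts
import fs from 'node:fs';

// Hypothetical record layout: one JSON object per line of the .jsonl file.
interface WeightRecord {
  name: string;
  shape: number[];
  values: number[];
}

function readWeights(filePath: string): WeightRecord[] {
  return fs
    .readFileSync(filePath, 'utf8')
    .split('\n')
    .filter((line) => line.trim() !== '')
    .map((line) => JSON.parse(line) as WeightRecord);
}
```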
The table below shows the accuracy, along with the evaluation time on the whole of HellaSwag, for each model tested:
Note: loaded_hellaswag.spec.ts reproduces this result for the loaded model. For the other two models, it can be reproduced using the hellaswag_gpt.ts script in cli/.