@christinakopi (Collaborator) commented on Jun 26, 2025

Roadmap:

  • ensure bijection between onnx and tfjs weights
  • add progress indicator to hellaswag benchmark
  • move filesystem logic to discojs-node and cli
  • automate downloading onnx model from hub
  • convert python script to nodejs and web-compatible logic
  • experiment with webgpu for hellaswag
  • experiment with converting tfjs back to onnx

To run load_weights.spec.ts, you first need to convert the pretrained weights from .onnx to .jsonl with the following Python script (the !pip lines assume a Jupyter/Colab environment):

Before running the script, make sure you have downloaded decoder_model.onnx from https://huggingface.co/Xenova/gpt2/tree/main/onnx and set ONNX_MODEL_PATH to its location.

!pip install onnx
!pip install onnxruntime  # not imported by this script itself

import json
import sys

import onnx
from onnx import numpy_helper

ONNX_MODEL_PATH = "<path/to/decoder_model.onnx>"
OUTPUT_FILENAME = "<path/to/gpt2_weights.jsonl>"  # where the output file will be saved

try:
    model = onnx.load(ONNX_MODEL_PATH)
except FileNotFoundError:
    print(f"ERROR: The model file was not found at '{ONNX_MODEL_PATH}'")
    sys.exit(1)

print(f"Extracting weights from {ONNX_MODEL_PATH}...")

# Write one JSON object per line: {"key": <tensor name>, "value": <nested list>}
with open(OUTPUT_FILENAME, "w") as f:
    for initializer in model.graph.initializer:
        # Each initializer is a TensorProto holding one pretrained weight tensor
        name = initializer.name
        array = numpy_helper.to_array(initializer)

        weight_object = {
            "key": name,
            "value": array.tolist()
        }

        f.write(json.dumps(weight_object) + "\n")

print(f"Successfully saved weights to {OUTPUT_FILENAME}")

After generating this file, place it under the following path: disco/discojs/gpt2_weights.jsonl
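For reference, a minimal sketch (not the actual test code) of how such a .jsonl file can be streamed in Node.js; the WeightEntry shape mirrors the objects written by the Python script above, and the readWeights helper is hypothetical:

import * as fs from "node:fs";
import * as readline from "node:readline";

// Shape of each line written by the Python script: one tensor per line.
interface WeightEntry {
  key: string;    // ONNX initializer name
  value: unknown; // nested number arrays matching the tensor shape
}

// Stream the file line by line instead of loading it whole, since the
// GPT-2 weights serialized as JSON text can be very large.
async function readWeights(path: string): Promise<Map<string, unknown>> {
  const weights = new Map<string, unknown>();
  const lines = readline.createInterface({
    input: fs.createReadStream(path, "utf8"),
    crlfDelay: Infinity, // treat \r\n as a single line break
  });
  for await (const line of lines) {
    if (line.trim() === "") continue; // skip blank lines
    const entry = JSON.parse(line) as WeightEntry;
    weights.set(entry.key, entry.value);
  }
  return weights;
}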

The table below shows the accuracy and evaluation time on the full HellaSwag benchmark for each model tested:

| Model                       | Accuracy | Eval Time (s) |
|-----------------------------|----------|---------------|
| TFJS GPT (gpt-nano)         | 24.67%   | 1390.25       |
| Xenova GPT-2 (ONNX)         | 29.03%   | 22767.03     |
| Loaded TFJS GPT (from ONNX) | 28.41%   | 12523.59     |

Note: loaded_hellaswag.spec.ts reproduces this result for the loaded model; for the other two models, the results can be reproduced with the hellaswag_gpt.ts script in cli/.

@martinjaggi (Member) left a comment

Amazing work, thanks a lot!
Later I'd be very curious to check whether the conversion in the reverse direction (back to ONNX) also works fine.

@JulienVig force-pushed the NAN-init_from_ONNX-christinakopi branch from ae33116 to 99e2db4 on November 26, 2025, 01:27
@JulienVig changed the title from "First try of loading the weights of the pretrained ONNX GPT2 model into our GPT2-tfjs implementation" to "ONNX to Tensorflow.js conversion of GPT-2" on November 26, 2025
@JulienVig force-pushed the NAN-init_from_ONNX-christinakopi branch 3 times, most recently from 6565027 to d08b99d on November 26, 2025, 03:04
@JulienVig force-pushed the NAN-init_from_ONNX-christinakopi branch from 0f7831d to 772832d on November 26, 2025, 03:49
@JulienVig force-pushed the NAN-init_from_ONNX-christinakopi branch from 772832d to 464ff8b on November 26, 2025, 04:43
@JulienVig marked this pull request as ready for review on November 26, 2025, 05:01
@JulienVig requested a review from tharvik on November 26, 2025, 05:01
@tharvik (Collaborator) left a comment

haa, that's great work, well done in hacking around the onnx protobuf!
nothing blocking (except the tsc --build thing), only a few nitpicks and questions here and there

@@ -1,50 +1,102 @@
// import fs from 'fs';
@tharvik: commented

-import { Tokenizer, models } from '@epfml/discojs';
+import { models, serialization, Tokenizer } from '@epfml/discojs';
import { loadHellaSwag } from '@epfml/discojs-node';
// import { AutoTokenizer } from '@xenova/transformers';
@tharvik: commented

Comment on lines +17 to 21
const logLines: string[] = [];
function log(message: string) {
console.log(message);
logLines.push(message);
}
@tharvik: we don't need a log system for the CLI, we can simply output to the console, no?

Comment on lines +86 to +90
case 'gpt-tfjs-random':
log("Using GPT-TFJS with random initialization")
model = new models.GPT({ seed: 42 });
break;
case 'gpt-tfjs-pretrained':
@tharvik: that was confusing

Suggested change:

case 'gpt-tfjs-random':
log("Using GPT-TFJS with random initialization")
model = new models.GPT({ seed: 42 });
break;
case 'gpt-tfjs-pretrained':

const defaultPretrainedModelPath = path.join(__dirname, "..", "..", "onnx-converter", "assets", "model.json")
const args = parse<HellaSwagArgs>({
model: {
type: (raw: string) => raw as ModelType,
@tharvik: casting isn't nice, especially for user interfaces; better to implement some value checking with a switch
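A sketch of that suggestion (the parseModelType helper is hypothetical, and only the variants visible in the snippet above are listed): validate the raw CLI string with a switch and fail fast on unknown values instead of casting.

type ModelType = "gpt-tfjs-random" | "gpt-tfjs-pretrained";

// The case labels narrow `raw` to ModelType, so no cast is needed;
// anything else becomes an explicit error for the user.
function parseModelType(raw: string): ModelType {
  switch (raw) {
    case "gpt-tfjs-random":
    case "gpt-tfjs-pretrained":
      return raw;
    default:
      throw new Error(`Unknown model type: ${raw}`);
  }
}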

Comment on lines +108 to +109
console.log("WARNING: protobuf raw data is empty, falling back on specific data fields.")
if (tensor.floatData && tensor.floatData.length > 0) return new Float32Array(tensor.floatData);
@tharvik: that looks to me like a more relevant field than rawData; why not use it first (and drop the warning)?
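For context, a sketch of the decoding order the reviewer suggests, with a minimal stand-in for the generated protobuf tensor type (the real TensorProto has many more fields; a float32 tensor is assumed):

interface TensorProtoLike {
  floatData?: number[]; // typed repeated-float field used by some exporters
  rawData?: Uint8Array; // packed little-endian bytes
}

function toFloat32(tensor: TensorProtoLike): Float32Array {
  // Prefer the typed field when populated...
  if (tensor.floatData && tensor.floatData.length > 0)
    return new Float32Array(tensor.floatData);
  // ...and only fall back to reinterpreting the raw bytes.
  if (tensor.rawData && tensor.rawData.length > 0)
    return new Float32Array(
      tensor.rawData.buffer,
      tensor.rawData.byteOffset,
      tensor.rawData.byteLength / Float32Array.BYTES_PER_ELEMENT,
    );
  throw new Error("tensor has neither floatData nor rawData");
}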

if (gptLayersModel.weights.length !== onnxTfjsMapping.size)
throw new Error(`Mismatch between TFJS and ONNX weight mapping weights.`);

const finalWeights = gptLayersModel.weights.map((weight, _i) => {
@tharvik suggested change:

-const finalWeights = gptLayersModel.weights.map((weight, _i) => {
+const finalWeights = gptLayersModel.weights.map((weight) => {

Comment on lines +12 to +13
const ASSET_FOLDER = path.join(__dirname, "..", "assets");
const OUTPUT_FILENAME = path.join(ASSET_FOLDER, "model.json");
@tharvik: can we more simply store it in the current directory, letting the user choose where to put it?
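One way to implement that (a sketch, not the merged code): resolve the output against the invoking directory and let an optional CLI argument override it.

import * as path from "node:path";

// Default to the user's working directory; a first CLI argument,
// if provided, overrides the output location.
const OUTPUT_FILENAME = path.resolve(process.cwd(), process.argv[2] ?? "model.json");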

"extends": "../tsconfig.base.json",
"compilerOptions": { "outDir": "dist" },
"include": ["src"],
"exclude": ["**/*.spec.ts"]
@tharvik: we don't have tests anyway

Suggested change:

-"exclude": ["**/*.spec.ts"]

pretrainedModelPath: {
type: String,
description: 'If specifying gpt-tfjs-pretrained, provide the relative path to the TF.js pretrained model',
defaultValue: defaultPretrainedModelPath
@tharvik: single use, we can inline it IMO
