From d402b49d2747c4d17333e7ea9f031145d1961f88 Mon Sep 17 00:00:00 2001
From: Matt Clayton
Date: Sun, 11 Jan 2026 16:45:19 -0500
Subject: [PATCH 01/17] Update tokenizers for id_to_piece

---
 extension/llm/runner/test/test_text_llm_runner.cpp | 5 +++++
 1 file changed, 5 insertions(+)

diff --git a/extension/llm/runner/test/test_text_llm_runner.cpp b/extension/llm/runner/test/test_text_llm_runner.cpp
index c9b57fb7391..69798f224cb 100644
--- a/extension/llm/runner/test/test_text_llm_runner.cpp
+++ b/extension/llm/runner/test/test_text_llm_runner.cpp
@@ -38,6 +38,11 @@ class MockTokenizer : public ::tokenizers::Tokenizer {
       encode,
       (const std::string&, int8_t, int8_t),
       (const));
+  MOCK_METHOD(
+      ::tokenizers::Result<std::string>,
+      id_to_piece,
+      (uint64_t),
+      (const));
   MOCK_METHOD(
       ::tokenizers::Result<std::string>,
       decode,

From 0396189cb8dfd18910133d2a578c845362c6eed2 Mon Sep 17 00:00:00 2001
From: Matt Clayton
Date: Sun, 11 Jan 2026 18:59:23 -0500
Subject: [PATCH 02/17] Working with supported_punctuation export (don't want)

---
 .../models/parakeet/export_parakeet_tdt.py |  37 ++
 examples/models/parakeet/main.cpp          | 378 +++++++++++++++++-
 2 files changed, 397 insertions(+), 18 deletions(-)

diff --git a/examples/models/parakeet/export_parakeet_tdt.py b/examples/models/parakeet/export_parakeet_tdt.py
index 92e32ca30bf..2605365d1fa 100644
--- a/examples/models/parakeet/export_parakeet_tdt.py
+++ b/examples/models/parakeet/export_parakeet_tdt.py
@@ -2,9 +2,11 @@
 
 import argparse
 import os
+import re
 import shutil
 import tarfile
 import tempfile
+import unicodedata
 
 import torch
 import torchaudio
@@ -17,6 +19,29 @@
 from torch.export import Dim, export
 
 
+_SPECIAL_TOKEN_PATTERNS = [
+    re.compile(r"^\[.*\]$"),
+    re.compile(r"^<.*>$"),
+    re.compile(r"^##"),
+    re.compile(r"^▁"),
+    re.compile(r"^\s*$"),
+]
+
+
+def _extract_punctuation_from_vocab(vocab: list[str]) -> list[str]:
+    def is_special_token(token: str) -> bool:
+        return any(pattern.match(token) for pattern in _SPECIAL_TOKEN_PATTERNS)
+
+    punctuation: set[str] = set()
+    for token in vocab:
+        if is_special_token(token):
+            continue
+        for char in token:
+            if unicodedata.category(char).startswith("P"):
+                punctuation.add(char)
+    return sorted(punctuation)
+
+
 def load_audio(audio_path: str, sample_rate: int = 16000) -> torch.Tensor:
     """Load audio file and resample to target sample rate."""
 
@@ -351,6 +376,15 @@ def export_all(model):
     )
 
     sample_rate = model.preprocessor._cfg.sample_rate
+    window_stride = float(model.preprocessor._cfg.window_stride)
+    encoder_subsampling_factor = int(getattr(model.encoder, "subsampling_factor", 8))
+
+    tokenizer_vocab = getattr(model.tokenizer, "vocab", None)
+    supported_punctuation = (
+        _extract_punctuation_from_vocab(tokenizer_vocab)
+        if isinstance(tokenizer_vocab, list)
+        else []
+    )
     metadata = {
         "num_rnn_layers": num_layers,
         "pred_hidden": pred_hidden,
@@ -358,6 +392,9 @@ def export_all(model):
         "vocab_size": model.tokenizer.vocab_size,
         "blank_id": model.tokenizer.vocab_size,
         "sample_rate": sample_rate,
+        "window_stride": window_stride,
+        "encoder_subsampling_factor": encoder_subsampling_factor,
+        "supported_punctuation": supported_punctuation,
     }
 
     return programs, metadata
diff --git a/examples/models/parakeet/main.cpp b/examples/models/parakeet/main.cpp
index 026f3911a3d..6430096842d 100644
--- a/examples/models/parakeet/main.cpp
+++ b/examples/models/parakeet/main.cpp
@@ -6,11 +6,14 @@
  * LICENSE file in the root directory of this source tree.
*/ +#include #include #include #include #include #include +#include +#include #include #include @@ -27,7 +30,7 @@ DEFINE_string(model_path, "parakeet.pte", "Path to Parakeet model (.pte)."); DEFINE_string(audio_path, "", "Path to input audio file (.wav)."); DEFINE_string( tokenizer_path, - "tokenizer.json", + "tokenizer.model", "Path to SentencePiece tokenizer model file."); DEFINE_string( data_path, @@ -44,6 +47,242 @@ namespace { // TDT duration values const std::vector DURATIONS = {0, 1, 2, 3, 4}; +struct TokenTimestamp { + uint64_t token_id; + std::string token_piece; + std::string token_text; + int64_t start_offset; + int64_t end_offset; +}; + +struct WordTimestamp { + std::string word; + int64_t start_offset; + int64_t end_offset; +}; + +struct SegmentTimestamp { + std::string segment; + int64_t start_offset; + int64_t end_offset; +}; + +std::string decode_token_sequence( + const std::vector& tokens, + tokenizers::Tokenizer* tokenizer) { + std::string result; + uint64_t prev_token = tokenizer->bos_tok(); + for (uint64_t token : tokens) { + auto decode_result = tokenizer->decode(prev_token, token); + if (decode_result.ok()) { + result += decode_result.get(); + } + prev_token = token; + } + return result; +} + +std::vector get_words_offsets( + const std::vector& tokens, + tokenizers::Tokenizer* tokenizer, + const std::unordered_set& supported_punctuation, + const std::string& word_delimiter_char = " ") { + std::vector word_offsets; + if (tokens.empty()) { + return word_offsets; + } + + size_t previous_token_index = 0; + std::vector built_tokens; + + auto is_curr_punctuation = [&](const std::string& token_text) { + return token_text != word_delimiter_char && + supported_punctuation.count(token_text) > 0; + }; + + auto is_word_start = [&](const std::string& token_piece, + const std::string& token_text, + const std::string& next_non_delim_token) { + const bool next_is_punctuation = + supported_punctuation.count(next_non_delim_token) > 0; + return token_piece != token_text || + (token_text == word_delimiter_char && !next_is_punctuation); + }; + + for (size_t i = 0; i < tokens.size(); i++) { + const auto& token = tokens[i]; + + const bool curr_punctuation = is_curr_punctuation(token.token_text); + + std::string next_non_delim_token; + for (size_t j = i + 1; j < tokens.size(); j++) { + if (tokens[j].token_text != word_delimiter_char) { + next_non_delim_token = tokens[j].token_text; + break; + } + } + + if (is_word_start(token.token_piece, token.token_text, next_non_delim_token) && + !curr_punctuation) { + if (!built_tokens.empty()) { + std::vector built_ids; + built_ids.reserve(built_tokens.size()); + for (size_t idx : built_tokens) { + built_ids.push_back(tokens[idx].token_id); + } + word_offsets.push_back( + {decode_token_sequence(built_ids, tokenizer), + tokens[previous_token_index].start_offset, + tokens[i - 1].end_offset}); + } + + built_tokens.clear(); + + if (token.token_text != word_delimiter_char) { + built_tokens.push_back(i); + previous_token_index = i; + } + } else if (curr_punctuation && built_tokens.empty() && !word_offsets.empty()) { + auto& last_built_word = word_offsets.back(); + last_built_word.end_offset = token.end_offset; + if (!last_built_word.word.empty() && last_built_word.word.back() == ' ') { + last_built_word.word.pop_back(); + } + last_built_word.word += token.token_text; + } else if (curr_punctuation && !built_tokens.empty()) { + const auto& last = tokens[built_tokens.back()].token_piece; + if (last == " " || last == "_" || last == "▁") { + 
built_tokens.pop_back(); + } + built_tokens.push_back(i); + } else { + if (built_tokens.empty()) { + previous_token_index = i; + } + built_tokens.push_back(i); + } + } + + // Match NeMo behavior: inject first start_offset and append any remaining + // built tokens as the final word. + if (word_offsets.empty()) { + if (!built_tokens.empty()) { + std::vector built_ids; + built_ids.reserve(built_tokens.size()); + for (size_t idx : built_tokens) { + built_ids.push_back(tokens[idx].token_id); + } + word_offsets.push_back( + {decode_token_sequence(built_ids, tokenizer), + tokens[0].start_offset, + tokens.back().end_offset}); + } + } else { + word_offsets[0].start_offset = tokens[0].start_offset; + + if (!built_tokens.empty()) { + std::vector built_ids; + built_ids.reserve(built_tokens.size()); + for (size_t idx : built_tokens) { + built_ids.push_back(tokens[idx].token_id); + } + word_offsets.push_back( + {decode_token_sequence(built_ids, tokenizer), + tokens[previous_token_index].start_offset, + tokens.back().end_offset}); + } + } + + return word_offsets; +} + +std::vector get_segment_offsets( + const std::vector& word_offsets, + const std::vector& segment_delimiters = {".", "?", "!"}, + const std::optional& segment_gap_threshold = std::nullopt) { + std::vector segment_offsets; + if (word_offsets.empty()) { + return segment_offsets; + } + + std::vector segment_words; + size_t previous_word_index = 0; + + for (size_t i = 0; i < word_offsets.size(); i++) { + const auto& offset = word_offsets[i]; + const auto& word = offset.word; + + if (segment_gap_threshold.has_value() && !segment_words.empty()) { + const int64_t gap_between_words = + offset.start_offset - word_offsets[i - 1].end_offset; + if (gap_between_words >= segment_gap_threshold.value()) { + std::string segment; + for (size_t j = 0; j < segment_words.size(); j++) { + if (j > 0) { + segment += " "; + } + segment += segment_words[j]; + } + segment_offsets.push_back( + {segment, + word_offsets[previous_word_index].start_offset, + word_offsets[i - 1].end_offset}); + segment_words = {word}; + previous_word_index = i; + continue; + } + } + + const bool is_delimiter_word = std::find( + segment_delimiters.begin(), + segment_delimiters.end(), + word) != segment_delimiters.end(); + + const bool ends_with_delimiter = !word.empty() && + std::find( + segment_delimiters.begin(), + segment_delimiters.end(), + std::string(1, word.back())) != segment_delimiters.end(); + + if (!word.empty() && (ends_with_delimiter || is_delimiter_word)) { + segment_words.push_back(word); + if (!segment_words.empty()) { + std::string segment; + for (size_t j = 0; j < segment_words.size(); j++) { + if (j > 0) { + segment += " "; + } + segment += segment_words[j]; + } + segment_offsets.push_back({segment, + word_offsets[previous_word_index].start_offset, + offset.end_offset}); + } + segment_words.clear(); + previous_word_index = i + 1; + continue; + } + + segment_words.push_back(word); + } + + if (!segment_words.empty()) { + std::string segment; + for (size_t j = 0; j < segment_words.size(); j++) { + if (j > 0) { + segment += " "; + } + segment += segment_words[j]; + } + segment_offsets.push_back( + {segment, + word_offsets[previous_word_index].start_offset, + word_offsets.back().end_offset}); + } + + return segment_offsets; +} + std::vector greedy_decode_executorch( Module& model, const ::executorch::aten::Tensor& encoder_output, @@ -52,7 +291,9 @@ std::vector greedy_decode_executorch( int64_t vocab_size, int64_t num_rnn_layers = 2, int64_t pred_hidden = 640, - int64_t 
max_symbols_per_step = 10) { + int64_t max_symbols_per_step = 10, + std::vector* token_start_offsets = nullptr, + std::vector* token_durations = nullptr) { std::vector hypothesis; int64_t num_token_classes = vocab_size + 1; @@ -209,6 +450,12 @@ std::vector greedy_decode_executorch( symbols_on_frame = 0; } else { hypothesis.push_back(k); + if (token_start_offsets != nullptr) { + token_start_offsets->push_back(t); + } + if (token_durations != nullptr) { + token_durations->push_back(dur); + } // Update decoder state std::vector token_data = {k}; @@ -271,19 +518,12 @@ std::vector greedy_decode_executorch( std::string tokens_to_text( const std::vector& tokens, tokenizers::Tokenizer* tokenizer) { - // Decode tokens to text one by one - std::string result; - uint64_t prev_token = 0; - for (size_t i = 0; i < tokens.size(); i++) { - uint64_t token = static_cast(tokens[i]); - auto decode_result = tokenizer->decode(prev_token, token); - if (decode_result.ok()) { - result += decode_result.get(); - } - prev_token = token; + std::vector ids; + ids.reserve(tokens.size()); + for (int64_t t : tokens) { + ids.push_back(static_cast(t)); } - - return result; + return decode_token_sequence(ids, tokenizer); } } // namespace @@ -381,10 +621,17 @@ int main(int argc, char** argv) { auto vocab_size_result = model->execute("vocab_size", empty_inputs); auto blank_id_result = model->execute("blank_id", empty_inputs); auto sample_rate_result = model->execute("sample_rate", empty_inputs); + auto window_stride_result = model->execute("window_stride", empty_inputs); + auto encoder_subsampling_factor_result = + model->execute("encoder_subsampling_factor", empty_inputs); + auto supported_punctuation_result = + model->execute("supported_punctuation", empty_inputs); if (!num_rnn_layers_result.ok() || !pred_hidden_result.ok() || !vocab_size_result.ok() || !blank_id_result.ok() || - !sample_rate_result.ok()) { + !sample_rate_result.ok() || !window_stride_result.ok() || + !encoder_subsampling_factor_result.ok() || + !supported_punctuation_result.ok()) { ET_LOG( Error, "Failed to query model metadata. 
Make sure the model was exported with constant_methods."); @@ -396,17 +643,32 @@ int main(int argc, char** argv) { int64_t num_rnn_layers = num_rnn_layers_result.get()[0].toInt(); int64_t pred_hidden = pred_hidden_result.get()[0].toInt(); int64_t sample_rate = sample_rate_result.get()[0].toInt(); + double window_stride = window_stride_result.get()[0].toDouble(); + int64_t encoder_subsampling_factor = + encoder_subsampling_factor_result.get()[0].toInt(); + + std::unordered_set supported_punctuation; + for (const auto& ev : supported_punctuation_result.get()) { + if (!ev.isString()) { + continue; + } + supported_punctuation.insert(std::string(ev.toString())); + } ET_LOG( Info, - "Model metadata: vocab_size=%lld, blank_id=%lld, num_rnn_layers=%lld, pred_hidden=%lld, sample_rate=%lld", + "Model metadata: vocab_size=%lld, blank_id=%lld, num_rnn_layers=%lld, pred_hidden=%lld, sample_rate=%lld, window_stride=%.6f, encoder_subsampling_factor=%lld", static_cast(vocab_size), static_cast(blank_id), static_cast(num_rnn_layers), static_cast(pred_hidden), - static_cast(sample_rate)); + static_cast(sample_rate), + window_stride, + static_cast(encoder_subsampling_factor)); ET_LOG(Info, "Running TDT greedy decode..."); + std::vector token_start_offsets; + std::vector token_durations; auto tokens = greedy_decode_executorch( *model, encoded, @@ -414,7 +676,10 @@ int main(int argc, char** argv) { blank_id, vocab_size, num_rnn_layers, - pred_hidden); + pred_hidden, + /*max_symbols_per_step=*/10, + &token_start_offsets, + &token_durations); ET_LOG(Info, "Decoded %zu tokens", tokens.size()); @@ -434,6 +699,83 @@ int main(int argc, char** argv) { std::string text = tokens_to_text(tokens, tokenizer.get()); std::cout << "Transcription tokens: " << text << std::endl; + // Compute timestamps matching NeMo's TDT timestamp behavior. + if (tokens.size() != token_start_offsets.size() || + tokens.size() != token_durations.size()) { + ET_LOG(Error, "Token/timestamp length mismatch"); + return 1; + } + + std::vector char_timestamps; + char_timestamps.reserve(tokens.size()); + + for (size_t i = 0; i < tokens.size(); i++) { + const uint64_t token_id = static_cast(tokens[i]); + + auto piece_result = tokenizer->id_to_piece(token_id); + if (!piece_result.ok()) { + ET_LOG( + Error, + "id_to_piece failed for token=%llu", + (unsigned long long)token_id); + return 1; + } + + auto text_result = tokenizer->decode(tokenizer->bos_tok(), token_id); + if (!text_result.ok()) { + ET_LOG(Error, "decode failed for token=%llu", (unsigned long long)token_id); + return 1; + } + + const int64_t start_offset = token_start_offsets[i]; + const int64_t end_offset = start_offset + token_durations[i]; + + char_timestamps.push_back( + {token_id, + piece_result.get(), + text_result.get(), + start_offset, + end_offset}); + } + + // NeMo TDT punctuation refinement: snap punctuation to the end of the + // previous token. 
+ for (size_t i = 1; i < char_timestamps.size(); i++) { + if (supported_punctuation.count(char_timestamps[i].token_text) > 0) { + char_timestamps[i].start_offset = char_timestamps[i - 1].end_offset; + char_timestamps[i].end_offset = char_timestamps[i].start_offset; + } + } + + auto word_timestamps = + get_words_offsets(char_timestamps, tokenizer.get(), supported_punctuation); + auto segment_timestamps = get_segment_offsets(word_timestamps); + + const double frame_to_seconds = + window_stride * static_cast(encoder_subsampling_factor); + + std::cout << "\nSegment timestamps:" << std::endl; + for (const auto& stamp : segment_timestamps) { + const double start = stamp.start_offset * frame_to_seconds; + const double end = stamp.end_offset * frame_to_seconds; + std::cout << start << "s - " << end << "s : " << stamp.segment << std::endl; + } + + std::cout << "\nWord timestamps:" << std::endl; + for (const auto& stamp : word_timestamps) { + const double start = stamp.start_offset * frame_to_seconds; + const double end = stamp.end_offset * frame_to_seconds; + std::cout << start << "s - " << end << "s : " << stamp.word << std::endl; + } + + std::cout << "\nChar timestamps:" << std::endl; + for (const auto& stamp : char_timestamps) { + const double start = stamp.start_offset * frame_to_seconds; + const double end = stamp.end_offset * frame_to_seconds; + std::cout << start << "s - " << end << "s : " << stamp.token_text + << std::endl; + } + ET_LOG(Info, "Done!"); return 0; } From fa3e4043695caaa877d7f54dc7420c3a70bcd8e7 Mon Sep 17 00:00:00 2001 From: Matt Clayton Date: Sun, 11 Jan 2026 19:29:27 -0500 Subject: [PATCH 03/17] Derive supported punctuation --- .../models/parakeet/export_parakeet_tdt.py | 32 ----- examples/models/parakeet/main.cpp | 121 ++++++++++++++---- 2 files changed, 96 insertions(+), 57 deletions(-) diff --git a/examples/models/parakeet/export_parakeet_tdt.py b/examples/models/parakeet/export_parakeet_tdt.py index 2605365d1fa..7d459f54da2 100644 --- a/examples/models/parakeet/export_parakeet_tdt.py +++ b/examples/models/parakeet/export_parakeet_tdt.py @@ -2,11 +2,9 @@ import argparse import os -import re import shutil import tarfile import tempfile -import unicodedata import torch import torchaudio @@ -19,29 +17,6 @@ from torch.export import Dim, export -_SPECIAL_TOKEN_PATTERNS = [ - re.compile(r"^\[.*\]$"), - re.compile(r"^<.*>$"), - re.compile(r"^##"), - re.compile(r"^▁"), - re.compile(r"^\s*$"), -] - - -def _extract_punctuation_from_vocab(vocab: list[str]) -> list[str]: - def is_special_token(token: str) -> bool: - return any(pattern.match(token) for pattern in _SPECIAL_TOKEN_PATTERNS) - - punctuation: set[str] = set() - for token in vocab: - if is_special_token(token): - continue - for char in token: - if unicodedata.category(char).startswith("P"): - punctuation.add(char) - return sorted(punctuation) - - def load_audio(audio_path: str, sample_rate: int = 16000) -> torch.Tensor: """Load audio file and resample to target sample rate.""" @@ -379,12 +354,6 @@ def export_all(model): window_stride = float(model.preprocessor._cfg.window_stride) encoder_subsampling_factor = int(getattr(model.encoder, "subsampling_factor", 8)) - tokenizer_vocab = getattr(model.tokenizer, "vocab", None) - supported_punctuation = ( - _extract_punctuation_from_vocab(tokenizer_vocab) - if isinstance(tokenizer_vocab, list) - else [] - ) metadata = { "num_rnn_layers": num_layers, "pred_hidden": pred_hidden, @@ -394,7 +363,6 @@ def export_all(model): "sample_rate": sample_rate, "window_stride": window_stride, 
"encoder_subsampling_factor": encoder_subsampling_factor, - "supported_punctuation": supported_punctuation, } return programs, metadata diff --git a/examples/models/parakeet/main.cpp b/examples/models/parakeet/main.cpp index 6430096842d..13281412969 100644 --- a/examples/models/parakeet/main.cpp +++ b/examples/models/parakeet/main.cpp @@ -9,18 +9,20 @@ #include #include #include +#include #include #include #include #include -#include #include +#include #include #include #include #include +#include #include #include #include @@ -67,6 +69,76 @@ struct SegmentTimestamp { int64_t end_offset; }; +bool is_whitespace_only(const std::string& token) { + if (token.empty()) { + return true; + } + + try { + const auto codepoints = unicode_cpts_from_utf8(token); + for (const auto cp : codepoints) { + if (!unicode_cpt_flags(cp).is_whitespace) { + return false; + } + } + return true; + } catch (const std::exception&) { + return false; + } +} + +bool is_special_token(const std::string& token) { + if (token.size() >= 2 && token.front() == '[' && token.back() == ']') { + return true; + } + if (token.size() >= 2 && token.front() == '<' && token.back() == '>') { + return true; + } + if (token.rfind("##", 0) == 0) { + return true; + } + if (token.rfind(u8"▁", 0) == 0) { + return true; + } + if (is_whitespace_only(token)) { + return true; + } + return false; +} + +std::unordered_set derive_supported_punctuation( + tokenizers::Tokenizer* tokenizer) { + std::unordered_set punctuation; + + const int32_t vocab_size = tokenizer->vocab_size(); + for (int32_t id = 0; id < vocab_size; id++) { + const auto piece_result = tokenizer->id_to_piece(static_cast(id)); + if (!piece_result.ok()) { + continue; + } + const std::string& piece = piece_result.get(); + if (is_special_token(piece)) { + continue; + } + + try { + const auto codepoints = unicode_cpts_from_utf8(piece); + for (const auto cp : codepoints) { + if (unicode_cpt_flags(cp).is_punctuation) { + punctuation.insert(unicode_cpt_to_utf8(cp)); + } + } + } catch (const std::exception&) { + ET_LOG( + Error, + "Failed to decode token piece '%s' to codepoints", + piece.c_str()); + } + } + + return punctuation; +} + std::string decode_token_sequence( const std::vector& tokens, tokenizers::Tokenizer* tokenizer) { @@ -122,7 +194,8 @@ std::vector get_words_offsets( } } - if (is_word_start(token.token_piece, token.token_text, next_non_delim_token) && + if (is_word_start( + token.token_piece, token.token_text, next_non_delim_token) && !curr_punctuation) { if (!built_tokens.empty()) { std::vector built_ids; @@ -142,7 +215,8 @@ std::vector get_words_offsets( built_tokens.push_back(i); previous_token_index = i; } - } else if (curr_punctuation && built_tokens.empty() && !word_offsets.empty()) { + } else if ( + curr_punctuation && built_tokens.empty() && !word_offsets.empty()) { auto& last_built_word = word_offsets.back(); last_built_word.end_offset = token.end_offset; if (!last_built_word.word.empty() && last_built_word.word.back() == ' ') { @@ -233,10 +307,9 @@ std::vector get_segment_offsets( } } - const bool is_delimiter_word = std::find( - segment_delimiters.begin(), - segment_delimiters.end(), - word) != segment_delimiters.end(); + const bool is_delimiter_word = + std::find(segment_delimiters.begin(), segment_delimiters.end(), word) != + segment_delimiters.end(); const bool ends_with_delimiter = !word.empty() && std::find( @@ -254,9 +327,10 @@ std::vector get_segment_offsets( } segment += segment_words[j]; } - segment_offsets.push_back({segment, - 
word_offsets[previous_word_index].start_offset, - offset.end_offset}); + segment_offsets.push_back( + {segment, + word_offsets[previous_word_index].start_offset, + offset.end_offset}); } segment_words.clear(); previous_word_index = i + 1; @@ -624,14 +698,11 @@ int main(int argc, char** argv) { auto window_stride_result = model->execute("window_stride", empty_inputs); auto encoder_subsampling_factor_result = model->execute("encoder_subsampling_factor", empty_inputs); - auto supported_punctuation_result = - model->execute("supported_punctuation", empty_inputs); if (!num_rnn_layers_result.ok() || !pred_hidden_result.ok() || !vocab_size_result.ok() || !blank_id_result.ok() || !sample_rate_result.ok() || !window_stride_result.ok() || - !encoder_subsampling_factor_result.ok() || - !supported_punctuation_result.ok()) { + !encoder_subsampling_factor_result.ok()) { ET_LOG( Error, "Failed to query model metadata. Make sure the model was exported with constant_methods."); @@ -647,14 +718,6 @@ int main(int argc, char** argv) { int64_t encoder_subsampling_factor = encoder_subsampling_factor_result.get()[0].toInt(); - std::unordered_set supported_punctuation; - for (const auto& ev : supported_punctuation_result.get()) { - if (!ev.isString()) { - continue; - } - supported_punctuation.insert(std::string(ev.toString())); - } - ET_LOG( Info, "Model metadata: vocab_size=%lld, blank_id=%lld, num_rnn_layers=%lld, pred_hidden=%lld, sample_rate=%lld, window_stride=%.6f, encoder_subsampling_factor=%lld", @@ -695,6 +758,13 @@ int main(int argc, char** argv) { return 1; } + std::unordered_set supported_punctuation = + derive_supported_punctuation(tokenizer.get()); + ET_LOG( + Info, + "Derived supported_punctuation size=%zu", + supported_punctuation.size()); + // Convert tokens to text std::string text = tokens_to_text(tokens, tokenizer.get()); std::cout << "Transcription tokens: " << text << std::endl; @@ -723,7 +793,8 @@ int main(int argc, char** argv) { auto text_result = tokenizer->decode(tokenizer->bos_tok(), token_id); if (!text_result.ok()) { - ET_LOG(Error, "decode failed for token=%llu", (unsigned long long)token_id); + ET_LOG( + Error, "decode failed for token=%llu", (unsigned long long)token_id); return 1; } @@ -747,8 +818,8 @@ int main(int argc, char** argv) { } } - auto word_timestamps = - get_words_offsets(char_timestamps, tokenizer.get(), supported_punctuation); + auto word_timestamps = get_words_offsets( + char_timestamps, tokenizer.get(), supported_punctuation); auto segment_timestamps = get_segment_offsets(word_timestamps); const double frame_to_seconds = From 62774fa45857a5db6ca512f2dd36d1789212715a Mon Sep 17 00:00:00 2001 From: Matt Clayton Date: Sun, 11 Jan 2026 19:35:53 -0500 Subject: [PATCH 04/17] Some cleanups --- examples/models/parakeet/main.cpp | 44 ++++++++++++++----------------- 1 file changed, 20 insertions(+), 24 deletions(-) diff --git a/examples/models/parakeet/main.cpp b/examples/models/parakeet/main.cpp index 13281412969..6158d9acc42 100644 --- a/examples/models/parakeet/main.cpp +++ b/examples/models/parakeet/main.cpp @@ -107,12 +107,12 @@ bool is_special_token(const std::string& token) { } std::unordered_set derive_supported_punctuation( - tokenizers::Tokenizer* tokenizer) { + const tokenizers::Tokenizer& tokenizer) { std::unordered_set punctuation; - const int32_t vocab_size = tokenizer->vocab_size(); + const int32_t vocab_size = tokenizer.vocab_size(); for (int32_t id = 0; id < vocab_size; id++) { - const auto piece_result = tokenizer->id_to_piece(static_cast(id)); + const 
auto piece_result = tokenizer.id_to_piece(static_cast(id)); if (!piece_result.ok()) { continue; } @@ -141,11 +141,11 @@ std::unordered_set derive_supported_punctuation( std::string decode_token_sequence( const std::vector& tokens, - tokenizers::Tokenizer* tokenizer) { + const tokenizers::Tokenizer& tokenizer) { std::string result; - uint64_t prev_token = tokenizer->bos_tok(); + uint64_t prev_token = tokenizer.bos_tok(); for (uint64_t token : tokens) { - auto decode_result = tokenizer->decode(prev_token, token); + auto decode_result = tokenizer.decode(prev_token, token); if (decode_result.ok()) { result += decode_result.get(); } @@ -156,7 +156,7 @@ std::string decode_token_sequence( std::vector get_words_offsets( const std::vector& tokens, - tokenizers::Tokenizer* tokenizer, + const tokenizers::Tokenizer& tokenizer, const std::unordered_set& supported_punctuation, const std::string& word_delimiter_char = " ") { std::vector word_offsets; @@ -363,11 +363,11 @@ std::vector greedy_decode_executorch( int64_t encoder_len, int64_t blank_id, int64_t vocab_size, + std::vector& token_start_offsets, + std::vector& token_durations, int64_t num_rnn_layers = 2, int64_t pred_hidden = 640, - int64_t max_symbols_per_step = 10, - std::vector* token_start_offsets = nullptr, - std::vector* token_durations = nullptr) { + int64_t max_symbols_per_step = 10) { std::vector hypothesis; int64_t num_token_classes = vocab_size + 1; @@ -524,12 +524,8 @@ std::vector greedy_decode_executorch( symbols_on_frame = 0; } else { hypothesis.push_back(k); - if (token_start_offsets != nullptr) { - token_start_offsets->push_back(t); - } - if (token_durations != nullptr) { - token_durations->push_back(dur); - } + token_start_offsets.push_back(t); + token_durations.push_back(dur); // Update decoder state std::vector token_data = {k}; @@ -591,7 +587,7 @@ std::vector greedy_decode_executorch( std::string tokens_to_text( const std::vector& tokens, - tokenizers::Tokenizer* tokenizer) { + const tokenizers::Tokenizer& tokenizer) { std::vector ids; ids.reserve(tokens.size()); for (int64_t t : tokens) { @@ -738,11 +734,11 @@ int main(int argc, char** argv) { encoded_len, blank_id, vocab_size, + token_start_offsets, + token_durations, num_rnn_layers, pred_hidden, - /*max_symbols_per_step=*/10, - &token_start_offsets, - &token_durations); + /*max_symbols_per_step=*/10); ET_LOG(Info, "Decoded %zu tokens", tokens.size()); @@ -759,14 +755,14 @@ int main(int argc, char** argv) { } std::unordered_set supported_punctuation = - derive_supported_punctuation(tokenizer.get()); + derive_supported_punctuation(*tokenizer); ET_LOG( Info, "Derived supported_punctuation size=%zu", supported_punctuation.size()); // Convert tokens to text - std::string text = tokens_to_text(tokens, tokenizer.get()); + std::string text = tokens_to_text(tokens, *tokenizer); std::cout << "Transcription tokens: " << text << std::endl; // Compute timestamps matching NeMo's TDT timestamp behavior. 
@@ -818,8 +814,8 @@ int main(int argc, char** argv) { } } - auto word_timestamps = get_words_offsets( - char_timestamps, tokenizer.get(), supported_punctuation); + auto word_timestamps = + get_words_offsets(char_timestamps, *tokenizer, supported_punctuation); auto segment_timestamps = get_segment_offsets(word_timestamps); const double frame_to_seconds = From afc34270a02376a0dbea2482c264bb3574274f02 Mon Sep 17 00:00:00 2001 From: Matt Clayton Date: Sun, 11 Jan 2026 20:11:33 -0500 Subject: [PATCH 05/17] Add DecodedToken --- examples/models/parakeet/main.cpp | 47 +++++++++++++++---------------- 1 file changed, 23 insertions(+), 24 deletions(-) diff --git a/examples/models/parakeet/main.cpp b/examples/models/parakeet/main.cpp index 6158d9acc42..76e4c4d4db6 100644 --- a/examples/models/parakeet/main.cpp +++ b/examples/models/parakeet/main.cpp @@ -49,6 +49,12 @@ namespace { // TDT duration values const std::vector DURATIONS = {0, 1, 2, 3, 4}; +struct DecodedToken { + int64_t token_id; + int64_t start_offset; + int64_t duration; +}; + struct TokenTimestamp { uint64_t token_id; std::string token_piece; @@ -357,18 +363,16 @@ std::vector get_segment_offsets( return segment_offsets; } -std::vector greedy_decode_executorch( +std::vector greedy_decode_executorch( Module& model, const ::executorch::aten::Tensor& encoder_output, int64_t encoder_len, int64_t blank_id, int64_t vocab_size, - std::vector& token_start_offsets, - std::vector& token_durations, int64_t num_rnn_layers = 2, int64_t pred_hidden = 640, int64_t max_symbols_per_step = 10) { - std::vector hypothesis; + std::vector hypothesis; int64_t num_token_classes = vocab_size + 1; // Transpose encoder output from [1, enc_dim, time] to [1, time, enc_dim] @@ -523,9 +527,7 @@ std::vector greedy_decode_executorch( t += std::max(dur, (int64_t)1); symbols_on_frame = 0; } else { - hypothesis.push_back(k); - token_start_offsets.push_back(t); - token_durations.push_back(dur); + hypothesis.push_back({k, t, dur}); // Update decoder state std::vector token_data = {k}; @@ -726,21 +728,23 @@ int main(int argc, char** argv) { static_cast(encoder_subsampling_factor)); ET_LOG(Info, "Running TDT greedy decode..."); - std::vector token_start_offsets; - std::vector token_durations; - auto tokens = greedy_decode_executorch( + auto decoded_tokens = greedy_decode_executorch( *model, encoded, encoded_len, blank_id, vocab_size, - token_start_offsets, - token_durations, num_rnn_layers, pred_hidden, /*max_symbols_per_step=*/10); - ET_LOG(Info, "Decoded %zu tokens", tokens.size()); + ET_LOG(Info, "Decoded %zu tokens", decoded_tokens.size()); + + std::vector tokens; + tokens.reserve(decoded_tokens.size()); + for (const auto& tok : decoded_tokens) { + tokens.push_back(tok.token_id); + } // Load tokenizer ET_LOG(Info, "Loading tokenizer from: %s", FLAGS_tokenizer_path.c_str()); @@ -766,17 +770,12 @@ int main(int argc, char** argv) { std::cout << "Transcription tokens: " << text << std::endl; // Compute timestamps matching NeMo's TDT timestamp behavior. 
- if (tokens.size() != token_start_offsets.size() || - tokens.size() != token_durations.size()) { - ET_LOG(Error, "Token/timestamp length mismatch"); - return 1; - } - std::vector char_timestamps; - char_timestamps.reserve(tokens.size()); + char_timestamps.reserve(decoded_tokens.size()); - for (size_t i = 0; i < tokens.size(); i++) { - const uint64_t token_id = static_cast(tokens[i]); + for (size_t i = 0; i < decoded_tokens.size(); i++) { + const auto& decoded_token = decoded_tokens[i]; + const uint64_t token_id = static_cast(decoded_token.token_id); auto piece_result = tokenizer->id_to_piece(token_id); if (!piece_result.ok()) { @@ -794,8 +793,8 @@ int main(int argc, char** argv) { return 1; } - const int64_t start_offset = token_start_offsets[i]; - const int64_t end_offset = start_offset + token_durations[i]; + const int64_t start_offset = decoded_token.start_offset; + const int64_t end_offset = start_offset + decoded_token.duration; char_timestamps.push_back( {token_id, From 6313a490d40a08911e4dba29af0f76c866104896 Mon Sep 17 00:00:00 2001 From: Matt Clayton Date: Sun, 11 Jan 2026 21:00:34 -0500 Subject: [PATCH 06/17] Same token id types, some small type cleanups --- examples/models/parakeet/main.cpp | 140 +++++++++++++++--------------- 1 file changed, 68 insertions(+), 72 deletions(-) diff --git a/examples/models/parakeet/main.cpp b/examples/models/parakeet/main.cpp index 76e4c4d4db6..05c4ac84880 100644 --- a/examples/models/parakeet/main.cpp +++ b/examples/models/parakeet/main.cpp @@ -8,6 +8,7 @@ #include #include +#include #include #include #include @@ -46,31 +47,30 @@ using ::executorch::runtime::EValue; namespace { +// Matches output type of tokenizers::Tokenizer methods +using TokenId = uint64_t; + // TDT duration values const std::vector DURATIONS = {0, 1, 2, 3, 4}; struct DecodedToken { - int64_t token_id; + TokenId token_id; int64_t start_offset; int64_t duration; }; struct TokenTimestamp { - uint64_t token_id; + TokenId token_id; + // Raw vocabulary piece for the token_id (i.e., "##ing", "▁hello") std::string token_piece; + // Decoded text for the token_id (i.e., "ing", " hello") std::string token_text; int64_t start_offset; int64_t end_offset; }; -struct WordTimestamp { - std::string word; - int64_t start_offset; - int64_t end_offset; -}; - -struct SegmentTimestamp { - std::string segment; +struct TextWithOffsets { + std::string text; int64_t start_offset; int64_t end_offset; }; @@ -112,13 +112,15 @@ bool is_special_token(const std::string& token) { return false; } +// Matches NeMo extract_punctuation_from_vocab method +// https://github.com/NVIDIA-NeMo/NeMo/blob/b90a528/nemo/collections/asr/parts/utils/tokenizer_utils.py#L20 std::unordered_set derive_supported_punctuation( const tokenizers::Tokenizer& tokenizer) { std::unordered_set punctuation; const int32_t vocab_size = tokenizer.vocab_size(); for (int32_t id = 0; id < vocab_size; id++) { - const auto piece_result = tokenizer.id_to_piece(static_cast(id)); + const auto piece_result = tokenizer.id_to_piece(static_cast(id)); if (!piece_result.ok()) { continue; } @@ -146,11 +148,11 @@ std::unordered_set derive_supported_punctuation( } std::string decode_token_sequence( - const std::vector& tokens, + const std::vector& tokens, const tokenizers::Tokenizer& tokenizer) { std::string result; - uint64_t prev_token = tokenizer.bos_tok(); - for (uint64_t token : tokens) { + TokenId prev_token = tokenizer.bos_tok(); + for (const TokenId token : tokens) { auto decode_result = tokenizer.decode(prev_token, token); if (decode_result.ok()) { 
result += decode_result.get(); @@ -160,18 +162,18 @@ std::string decode_token_sequence( return result; } -std::vector get_words_offsets( +std::vector get_words_offsets( const std::vector& tokens, const tokenizers::Tokenizer& tokenizer, const std::unordered_set& supported_punctuation, const std::string& word_delimiter_char = " ") { - std::vector word_offsets; + std::vector word_offsets; if (tokens.empty()) { return word_offsets; } size_t previous_token_index = 0; - std::vector built_tokens; + std::vector build_token_indices; auto is_curr_punctuation = [&](const std::string& token_text) { return token_text != word_delimiter_char && @@ -203,10 +205,10 @@ std::vector get_words_offsets( if (is_word_start( token.token_piece, token.token_text, next_non_delim_token) && !curr_punctuation) { - if (!built_tokens.empty()) { - std::vector built_ids; - built_ids.reserve(built_tokens.size()); - for (size_t idx : built_tokens) { + if (!build_token_indices.empty()) { + std::vector built_ids; + built_ids.reserve(build_token_indices.size()); + for (size_t idx : build_token_indices) { built_ids.push_back(tokens[idx].token_id); } word_offsets.push_back( @@ -215,41 +217,41 @@ std::vector get_words_offsets( tokens[i - 1].end_offset}); } - built_tokens.clear(); + build_token_indices.clear(); if (token.token_text != word_delimiter_char) { - built_tokens.push_back(i); + build_token_indices.push_back(i); previous_token_index = i; } } else if ( - curr_punctuation && built_tokens.empty() && !word_offsets.empty()) { + curr_punctuation && build_token_indices.empty() && !word_offsets.empty()) { auto& last_built_word = word_offsets.back(); last_built_word.end_offset = token.end_offset; - if (!last_built_word.word.empty() && last_built_word.word.back() == ' ') { - last_built_word.word.pop_back(); + if (!last_built_word.text.empty() && last_built_word.text.back() == ' ') { + last_built_word.text.pop_back(); } - last_built_word.word += token.token_text; - } else if (curr_punctuation && !built_tokens.empty()) { - const auto& last = tokens[built_tokens.back()].token_piece; + last_built_word.text += token.token_text; + } else if (curr_punctuation && !build_token_indices.empty()) { + const auto& last = tokens[build_token_indices.back()].token_piece; if (last == " " || last == "_" || last == "▁") { - built_tokens.pop_back(); + build_token_indices.pop_back(); } - built_tokens.push_back(i); + build_token_indices.push_back(i); } else { - if (built_tokens.empty()) { + if (build_token_indices.empty()) { previous_token_index = i; } - built_tokens.push_back(i); + build_token_indices.push_back(i); } } // Match NeMo behavior: inject first start_offset and append any remaining // built tokens as the final word. 
if (word_offsets.empty()) { - if (!built_tokens.empty()) { - std::vector built_ids; - built_ids.reserve(built_tokens.size()); - for (size_t idx : built_tokens) { + if (!build_token_indices.empty()) { + std::vector built_ids; + built_ids.reserve(build_token_indices.size()); + for (const size_t idx : build_token_indices) { built_ids.push_back(tokens[idx].token_id); } word_offsets.push_back( @@ -260,10 +262,10 @@ std::vector get_words_offsets( } else { word_offsets[0].start_offset = tokens[0].start_offset; - if (!built_tokens.empty()) { - std::vector built_ids; - built_ids.reserve(built_tokens.size()); - for (size_t idx : built_tokens) { + if (!build_token_indices.empty()) { + std::vector built_ids; + built_ids.reserve(build_token_indices.size()); + for (size_t idx : build_token_indices) { built_ids.push_back(tokens[idx].token_id); } word_offsets.push_back( @@ -276,11 +278,11 @@ std::vector get_words_offsets( return word_offsets; } -std::vector get_segment_offsets( - const std::vector& word_offsets, +std::vector get_segment_offsets( + const std::vector& word_offsets, const std::vector& segment_delimiters = {".", "?", "!"}, const std::optional& segment_gap_threshold = std::nullopt) { - std::vector segment_offsets; + std::vector segment_offsets; if (word_offsets.empty()) { return segment_offsets; } @@ -290,7 +292,7 @@ std::vector get_segment_offsets( for (size_t i = 0; i < word_offsets.size(); i++) { const auto& offset = word_offsets[i]; - const auto& word = offset.word; + const auto& word = offset.text; if (segment_gap_threshold.has_value() && !segment_words.empty()) { const int64_t gap_between_words = @@ -524,10 +526,10 @@ std::vector greedy_decode_executorch( int64_t dur = DURATIONS[dur_idx]; if (k == blank_id) { - t += std::max(dur, (int64_t)1); + t += std::max(dur, static_cast(1)); symbols_on_frame = 0; } else { - hypothesis.push_back({k, t, dur}); + hypothesis.push_back({static_cast(k), t, dur}); // Update decoder state std::vector token_data = {k}; @@ -588,14 +590,18 @@ std::vector greedy_decode_executorch( } std::string tokens_to_text( - const std::vector& tokens, + const std::vector& decoded_tokens, const tokenizers::Tokenizer& tokenizer) { - std::vector ids; - ids.reserve(tokens.size()); - for (int64_t t : tokens) { - ids.push_back(static_cast(t)); + std::string result; + TokenId prev_token = tokenizer.bos_tok(); + for (const auto& tok : decoded_tokens) { + auto decode_result = tokenizer.decode(prev_token, tok.token_id); + if (decode_result.ok()) { + result += decode_result.get(); + } + prev_token = tok.token_id; } - return decode_token_sequence(ids, tokenizer); + return result; } } // namespace @@ -740,12 +746,6 @@ int main(int argc, char** argv) { ET_LOG(Info, "Decoded %zu tokens", decoded_tokens.size()); - std::vector tokens; - tokens.reserve(decoded_tokens.size()); - for (const auto& tok : decoded_tokens) { - tokens.push_back(tok.token_id); - } - // Load tokenizer ET_LOG(Info, "Loading tokenizer from: %s", FLAGS_tokenizer_path.c_str()); auto tokenizer = @@ -758,6 +758,10 @@ int main(int argc, char** argv) { return 1; } + // Convert tokens to text + std::string text = tokens_to_text(decoded_tokens, *tokenizer); + std::cout << "Transcription tokens: " << text << std::endl; + std::unordered_set supported_punctuation = derive_supported_punctuation(*tokenizer); ET_LOG( @@ -765,24 +769,16 @@ int main(int argc, char** argv) { "Derived supported_punctuation size=%zu", supported_punctuation.size()); - // Convert tokens to text - std::string text = tokens_to_text(tokens, *tokenizer); - 
std::cout << "Transcription tokens: " << text << std::endl; - // Compute timestamps matching NeMo's TDT timestamp behavior. std::vector char_timestamps; char_timestamps.reserve(decoded_tokens.size()); - for (size_t i = 0; i < decoded_tokens.size(); i++) { - const auto& decoded_token = decoded_tokens[i]; - const uint64_t token_id = static_cast(decoded_token.token_id); + for (const auto& decoded_token : decoded_tokens) { + const TokenId token_id = decoded_token.token_id; auto piece_result = tokenizer->id_to_piece(token_id); if (!piece_result.ok()) { - ET_LOG( - Error, - "id_to_piece failed for token=%llu", - (unsigned long long)token_id); + ET_LOG(Error, "id_to_piece failed for token=%llu", token_id); return 1; } @@ -824,14 +820,14 @@ int main(int argc, char** argv) { for (const auto& stamp : segment_timestamps) { const double start = stamp.start_offset * frame_to_seconds; const double end = stamp.end_offset * frame_to_seconds; - std::cout << start << "s - " << end << "s : " << stamp.segment << std::endl; + std::cout << start << "s - " << end << "s : " << stamp.text << std::endl; } std::cout << "\nWord timestamps:" << std::endl; for (const auto& stamp : word_timestamps) { const double start = stamp.start_offset * frame_to_seconds; const double end = stamp.end_offset * frame_to_seconds; - std::cout << start << "s - " << end << "s : " << stamp.word << std::endl; + std::cout << start << "s - " << end << "s : " << stamp.text << std::endl; } std::cout << "\nChar timestamps:" << std::endl; From df2e8e8da4d077790ae4ec8b2e42c6cb590b627f Mon Sep 17 00:00:00 2001 From: Matt Clayton Date: Mon, 12 Jan 2026 10:58:47 -0500 Subject: [PATCH 07/17] Refs and decode token string overload --- examples/models/parakeet/main.cpp | 38 ++++++++++++++++--------------- 1 file changed, 20 insertions(+), 18 deletions(-) diff --git a/examples/models/parakeet/main.cpp b/examples/models/parakeet/main.cpp index 05c4ac84880..c0827e8c426 100644 --- a/examples/models/parakeet/main.cpp +++ b/examples/models/parakeet/main.cpp @@ -46,7 +46,6 @@ using ::executorch::runtime::Error; using ::executorch::runtime::EValue; namespace { - // Matches output type of tokenizers::Tokenizer methods using TokenId = uint64_t; @@ -162,6 +161,21 @@ std::string decode_token_sequence( return result; } +// convenience overload +std::string decode_token_sequence( + const std::vector& decoded_tokens, + const tokenizers::Tokenizer& tokenizer) { + std::vector token_ids; + token_ids.reserve(decoded_tokens.size()); + for (const auto& tok : decoded_tokens) { + token_ids.push_back(tok.token_id); + } + return decode_token_sequence(token_ids, tokenizer); +} + +// ref: +// https://github.com/NVIDIA-NeMo/NeMo/blob/bf583c9/nemo/collections/asr/parts/utils/timestamp_utils.py#L54 +// assumes BPE tokenizer type std::vector get_words_offsets( const std::vector& tokens, const tokenizers::Tokenizer& tokenizer, @@ -224,7 +238,8 @@ std::vector get_words_offsets( previous_token_index = i; } } else if ( - curr_punctuation && build_token_indices.empty() && !word_offsets.empty()) { + curr_punctuation && build_token_indices.empty() && + !word_offsets.empty()) { auto& last_built_word = word_offsets.back(); last_built_word.end_offset = token.end_offset; if (!last_built_word.text.empty() && last_built_word.text.back() == ' ') { @@ -278,6 +293,8 @@ std::vector get_words_offsets( return word_offsets; } +// ref +// https://github.com/NVIDIA-NeMo/NeMo/blob/bf583c9/nemo/collections/asr/parts/utils/timestamp_utils.py#L227 std::vector get_segment_offsets( const std::vector& 
word_offsets, const std::vector& segment_delimiters = {".", "?", "!"}, @@ -589,21 +606,6 @@ std::vector greedy_decode_executorch( return hypothesis; } -std::string tokens_to_text( - const std::vector& decoded_tokens, - const tokenizers::Tokenizer& tokenizer) { - std::string result; - TokenId prev_token = tokenizer.bos_tok(); - for (const auto& tok : decoded_tokens) { - auto decode_result = tokenizer.decode(prev_token, tok.token_id); - if (decode_result.ok()) { - result += decode_result.get(); - } - prev_token = tok.token_id; - } - return result; -} - } // namespace int main(int argc, char** argv) { @@ -759,7 +761,7 @@ int main(int argc, char** argv) { } // Convert tokens to text - std::string text = tokens_to_text(decoded_tokens, *tokenizer); + std::string text = decode_token_sequence(decoded_tokens, *tokenizer); std::cout << "Transcription tokens: " << text << std::endl; std::unordered_set supported_punctuation = From 3a4a2f1622ff1a4fceee1dbe3b95fd5b733ed015 Mon Sep 17 00:00:00 2001 From: Matt Clayton Date: Mon, 12 Jan 2026 11:11:53 -0500 Subject: [PATCH 08/17] Rename to FrameAlignedToken --- examples/models/parakeet/main.cpp | 15 ++++++++------- 1 file changed, 8 insertions(+), 7 deletions(-) diff --git a/examples/models/parakeet/main.cpp b/examples/models/parakeet/main.cpp index c0827e8c426..9ae3422027b 100644 --- a/examples/models/parakeet/main.cpp +++ b/examples/models/parakeet/main.cpp @@ -58,7 +58,7 @@ struct DecodedToken { int64_t duration; }; -struct TokenTimestamp { +struct FrameAlignedToken { TokenId token_id; // Raw vocabulary piece for the token_id (i.e., "##ing", "▁hello") std::string token_piece; @@ -177,7 +177,7 @@ std::string decode_token_sequence( // https://github.com/NVIDIA-NeMo/NeMo/blob/bf583c9/nemo/collections/asr/parts/utils/timestamp_utils.py#L54 // assumes BPE tokenizer type std::vector get_words_offsets( - const std::vector& tokens, + const std::vector& tokens, const tokenizers::Tokenizer& tokenizer, const std::unordered_set& supported_punctuation, const std::string& word_delimiter_char = " ") { @@ -733,7 +733,7 @@ int main(int argc, char** argv) { static_cast(pred_hidden), static_cast(sample_rate), window_stride, - static_cast(encoder_subsampling_factor)); + encoder_subsampling_factor); ET_LOG(Info, "Running TDT greedy decode..."); auto decoded_tokens = greedy_decode_executorch( @@ -743,8 +743,7 @@ int main(int argc, char** argv) { blank_id, vocab_size, num_rnn_layers, - pred_hidden, - /*max_symbols_per_step=*/10); + pred_hidden); ET_LOG(Info, "Decoded %zu tokens", decoded_tokens.size()); @@ -772,7 +771,7 @@ int main(int argc, char** argv) { supported_punctuation.size()); // Compute timestamps matching NeMo's TDT timestamp behavior. 
- std::vector char_timestamps; + std::vector char_timestamps; char_timestamps.reserve(decoded_tokens.size()); for (const auto& decoded_token : decoded_tokens) { @@ -787,7 +786,9 @@ int main(int argc, char** argv) { auto text_result = tokenizer->decode(tokenizer->bos_tok(), token_id); if (!text_result.ok()) { ET_LOG( - Error, "decode failed for token=%llu", (unsigned long long)token_id); + Error, + "decode failed for token=%llu", + static_cast(token_id)); return 1; } From 1a23c140e872b2ed4c476a15efab5a8c019e94a2 Mon Sep 17 00:00:00 2001 From: Matt Clayton Date: Mon, 12 Jan 2026 12:02:33 -0500 Subject: [PATCH 09/17] Helper for tokens with text info --- examples/models/parakeet/README.md | 1 + examples/models/parakeet/main.cpp | 237 +++++++++++++++++++---------- 2 files changed, 154 insertions(+), 84 deletions(-) diff --git a/examples/models/parakeet/README.md b/examples/models/parakeet/README.md index b27bc1f8a91..98711f45a0f 100644 --- a/examples/models/parakeet/README.md +++ b/examples/models/parakeet/README.md @@ -71,3 +71,4 @@ From the executorch root directory: | `--audio_path` | Path to input audio file (.wav) | | `--tokenizer_path` | Path to tokenizer file (default: `tokenizer.json`) | | `--data_path` | Path to data file (.ptd) for delegate data (optional, required for CUDA) | +| `--timestamps` | Timestamp output mode: `none\|token\|word\|segment\|all` | diff --git a/examples/models/parakeet/main.cpp b/examples/models/parakeet/main.cpp index 9ae3422027b..53bbd6f0ce4 100644 --- a/examples/models/parakeet/main.cpp +++ b/examples/models/parakeet/main.cpp @@ -39,6 +39,10 @@ DEFINE_string( data_path, "", "Path to data file (.ptd) for delegate data (optional, required for CUDA)."); +DEFINE_string( + timestamps, + "none", + "Timestamp output mode: none|token|word|segment|all"); using ::executorch::extension::from_blob; using ::executorch::extension::Module; @@ -52,18 +56,28 @@ using TokenId = uint64_t; // TDT duration values const std::vector DURATIONS = {0, 1, 2, 3, 4}; -struct DecodedToken { - TokenId token_id; +struct TimestampOutputMode { + bool token = false; + bool word = false; + bool segment = false; + + bool enabled() const { + return token || word || segment; + } +}; + +struct Token { + TokenId id; int64_t start_offset; int64_t duration; }; -struct FrameAlignedToken { - TokenId token_id; +struct TokenWithTextInfo { + TokenId id; // Raw vocabulary piece for the token_id (i.e., "##ing", "▁hello") - std::string token_piece; + std::string raw_piece; // Decoded text for the token_id (i.e., "ing", " hello") - std::string token_text; + std::string decoded_text; int64_t start_offset; int64_t end_offset; }; @@ -74,6 +88,39 @@ struct TextWithOffsets { int64_t end_offset; }; +std::string to_lower_ascii(std::string s) { + for (char& ch : s) { + ch = static_cast(std::tolower(static_cast(ch))); + } + return s; +} + +TimestampOutputMode parse_timestamp_output_mode(const std::string& raw_arg) { + if (raw_arg.empty()) { + throw std::invalid_argument( + "Invalid --timestamps value (empty). Expected: token, word, segment, all."); + } + const std::string mode = to_lower_ascii(raw_arg); + if (mode == "none") { + return {false, false, false}; + } + if (mode == "token") { + return {true, false, false}; + } + if (mode == "word") { + return {false, true, false}; + } + if (mode == "segment") { + return {false, false, true}; + } + if (mode == "all") { + return {true, true, true}; + } + throw std::invalid_argument( + "Invalid --timestamps value '" + raw_arg + + "'. 
Expected: token, word, segment, all."); +} + bool is_whitespace_only(const std::string& token) { if (token.empty()) { return true; @@ -163,21 +210,63 @@ std::string decode_token_sequence( // convenience overload std::string decode_token_sequence( - const std::vector& decoded_tokens, + const std::vector& decoded_tokens, const tokenizers::Tokenizer& tokenizer) { std::vector token_ids; token_ids.reserve(decoded_tokens.size()); for (const auto& tok : decoded_tokens) { - token_ids.push_back(tok.token_id); + token_ids.push_back(tok.id); } return decode_token_sequence(token_ids, tokenizer); } +// throws if any tokenizer calls fail +std::vector get_tokens_with_text_info( + const std::vector& tokens, + const tokenizers::Tokenizer& tokenizer, + const std::unordered_set& supported_punctuation) { + std::vector tokens_with_text; + tokens_with_text.reserve(tokens.size()); + + for (const auto& token : tokens) { + auto piece_result = tokenizer.id_to_piece(token.id); + if (!piece_result.ok()) { + throw std::runtime_error( + "id_to_piece failed for token=" + std::to_string(token.id)); + } + + auto text_result = tokenizer.decode(tokenizer.bos_tok(), token.id); + if (!text_result.ok()) { + throw std::runtime_error( + "decode failed for token=" + std::to_string(token.id)); + } + + tokens_with_text.push_back( + {token.id, + piece_result.get(), + text_result.get(), + token.start_offset, + token.start_offset + token.duration}); + } + + // NeMo TDT punctuation refinement pass: snap punctuation to the end of the + // previous token. + // https://github.com/NVIDIA-NeMo/NeMo/blob/bf583c9/nemo/collections/asr/parts/submodules/rnnt_decoding.py#L1169-L1189 + for (size_t i = 1; i < tokens_with_text.size(); i++) { + if (supported_punctuation.count(tokens_with_text[i].decoded_text) > 0) { + tokens_with_text[i].start_offset = tokens_with_text[i - 1].end_offset; + tokens_with_text[i].end_offset = tokens_with_text[i].start_offset; + } + } + + return tokens_with_text; +} + // ref: // https://github.com/NVIDIA-NeMo/NeMo/blob/bf583c9/nemo/collections/asr/parts/utils/timestamp_utils.py#L54 // assumes BPE tokenizer type std::vector get_words_offsets( - const std::vector& tokens, + const std::vector& tokens, const tokenizers::Tokenizer& tokenizer, const std::unordered_set& supported_punctuation, const std::string& word_delimiter_char = " ") { @@ -206,24 +295,24 @@ std::vector get_words_offsets( for (size_t i = 0; i < tokens.size(); i++) { const auto& token = tokens[i]; - const bool curr_punctuation = is_curr_punctuation(token.token_text); + const bool curr_punctuation = is_curr_punctuation(token.decoded_text); std::string next_non_delim_token; for (size_t j = i + 1; j < tokens.size(); j++) { - if (tokens[j].token_text != word_delimiter_char) { - next_non_delim_token = tokens[j].token_text; + if (tokens[j].decoded_text != word_delimiter_char) { + next_non_delim_token = tokens[j].decoded_text; break; } } if (is_word_start( - token.token_piece, token.token_text, next_non_delim_token) && + token.raw_piece, token.decoded_text, next_non_delim_token) && !curr_punctuation) { if (!build_token_indices.empty()) { std::vector built_ids; built_ids.reserve(build_token_indices.size()); for (size_t idx : build_token_indices) { - built_ids.push_back(tokens[idx].token_id); + built_ids.push_back(tokens[idx].id); } word_offsets.push_back( {decode_token_sequence(built_ids, tokenizer), @@ -233,7 +322,7 @@ std::vector get_words_offsets( build_token_indices.clear(); - if (token.token_text != word_delimiter_char) { + if (token.decoded_text != 
word_delimiter_char) { build_token_indices.push_back(i); previous_token_index = i; } @@ -245,9 +334,9 @@ std::vector get_words_offsets( if (!last_built_word.text.empty() && last_built_word.text.back() == ' ') { last_built_word.text.pop_back(); } - last_built_word.text += token.token_text; + last_built_word.text += token.decoded_text; } else if (curr_punctuation && !build_token_indices.empty()) { - const auto& last = tokens[build_token_indices.back()].token_piece; + const auto& last = tokens[build_token_indices.back()].raw_piece; if (last == " " || last == "_" || last == "▁") { build_token_indices.pop_back(); } @@ -267,7 +356,7 @@ std::vector get_words_offsets( std::vector built_ids; built_ids.reserve(build_token_indices.size()); for (const size_t idx : build_token_indices) { - built_ids.push_back(tokens[idx].token_id); + built_ids.push_back(tokens[idx].id); } word_offsets.push_back( {decode_token_sequence(built_ids, tokenizer), @@ -281,7 +370,7 @@ std::vector get_words_offsets( std::vector built_ids; built_ids.reserve(build_token_indices.size()); for (size_t idx : build_token_indices) { - built_ids.push_back(tokens[idx].token_id); + built_ids.push_back(tokens[idx].id); } word_offsets.push_back( {decode_token_sequence(built_ids, tokenizer), @@ -382,7 +471,7 @@ std::vector get_segment_offsets( return segment_offsets; } -std::vector greedy_decode_executorch( +std::vector greedy_decode_executorch( Module& model, const ::executorch::aten::Tensor& encoder_output, int64_t encoder_len, @@ -391,7 +480,7 @@ std::vector greedy_decode_executorch( int64_t num_rnn_layers = 2, int64_t pred_hidden = 640, int64_t max_symbols_per_step = 10) { - std::vector hypothesis; + std::vector hypothesis; int64_t num_token_classes = vocab_size + 1; // Transpose encoder output from [1, enc_dim, time] to [1, time, enc_dim] @@ -611,6 +700,14 @@ std::vector greedy_decode_executorch( int main(int argc, char** argv) { gflags::ParseCommandLineFlags(&argc, &argv, true); + TimestampOutputMode timestamp_mode; + try { + timestamp_mode = parse_timestamp_output_mode(FLAGS_timestamps); + } catch (const std::invalid_argument& e) { + ET_LOG(Error, "%s", e.what()); + return 1; + } + if (FLAGS_audio_path.empty()) { ET_LOG(Error, "audio_path flag must be provided."); return 1; @@ -761,8 +858,13 @@ int main(int argc, char** argv) { // Convert tokens to text std::string text = decode_token_sequence(decoded_tokens, *tokenizer); - std::cout << "Transcription tokens: " << text << std::endl; + std::cout << "Transcribed text: " << text << std::endl; + + if (!timestamp_mode.enabled()) { + return 0; + } + ET_LOG(Info, "Computing timestamps..."); std::unordered_set supported_punctuation = derive_supported_punctuation(*tokenizer); ET_LOG( @@ -770,77 +872,44 @@ int main(int argc, char** argv) { "Derived supported_punctuation size=%zu", supported_punctuation.size()); - // Compute timestamps matching NeMo's TDT timestamp behavior. 
- std::vector char_timestamps; - char_timestamps.reserve(decoded_tokens.size()); - - for (const auto& decoded_token : decoded_tokens) { - const TokenId token_id = decoded_token.token_id; - - auto piece_result = tokenizer->id_to_piece(token_id); - if (!piece_result.ok()) { - ET_LOG(Error, "id_to_piece failed for token=%llu", token_id); - return 1; - } - - auto text_result = tokenizer->decode(tokenizer->bos_tok(), token_id); - if (!text_result.ok()) { - ET_LOG( - Error, - "decode failed for token=%llu", - static_cast(token_id)); - return 1; - } - - const int64_t start_offset = decoded_token.start_offset; - const int64_t end_offset = start_offset + decoded_token.duration; - - char_timestamps.push_back( - {token_id, - piece_result.get(), - text_result.get(), - start_offset, - end_offset}); - } - - // NeMo TDT punctuation refinement: snap punctuation to the end of the - // previous token. - for (size_t i = 1; i < char_timestamps.size(); i++) { - if (supported_punctuation.count(char_timestamps[i].token_text) > 0) { - char_timestamps[i].start_offset = char_timestamps[i - 1].end_offset; - char_timestamps[i].end_offset = char_timestamps[i].start_offset; - } - } - - auto word_timestamps = - get_words_offsets(char_timestamps, *tokenizer, supported_punctuation); - auto segment_timestamps = get_segment_offsets(word_timestamps); + // for simplicity, compute all levels of timestamps regardless of mode + const auto tokens_with_text_info = get_tokens_with_text_info( + decoded_tokens, *tokenizer, supported_punctuation); + const auto word_offsets = get_words_offsets( + tokens_with_text_info, *tokenizer, supported_punctuation); + const auto segment_offsets = get_segment_offsets(word_offsets); const double frame_to_seconds = window_stride * static_cast(encoder_subsampling_factor); - std::cout << "\nSegment timestamps:" << std::endl; - for (const auto& stamp : segment_timestamps) { - const double start = stamp.start_offset * frame_to_seconds; - const double end = stamp.end_offset * frame_to_seconds; - std::cout << start << "s - " << end << "s : " << stamp.text << std::endl; + if (timestamp_mode.segment) { + std::cout << "\nSegment timestamps:" << std::endl; + for (const auto& segment : segment_offsets) { + const double start = segment.start_offset * frame_to_seconds; + const double end = segment.end_offset * frame_to_seconds; + std::cout << start << "s - " << end << "s : " << segment.text + << std::endl; + } } - std::cout << "\nWord timestamps:" << std::endl; - for (const auto& stamp : word_timestamps) { - const double start = stamp.start_offset * frame_to_seconds; - const double end = stamp.end_offset * frame_to_seconds; - std::cout << start << "s - " << end << "s : " << stamp.text << std::endl; + if (timestamp_mode.word) { + std::cout << "\nWord timestamps:" << std::endl; + for (const auto& word : word_offsets) { + const double start = word.start_offset * frame_to_seconds; + const double end = word.end_offset * frame_to_seconds; + std::cout << start << "s - " << end << "s : " << word.text << std::endl; + } } - std::cout << "\nChar timestamps:" << std::endl; - for (const auto& stamp : char_timestamps) { - const double start = stamp.start_offset * frame_to_seconds; - const double end = stamp.end_offset * frame_to_seconds; - std::cout << start << "s - " << end << "s : " << stamp.token_text - << std::endl; + if (timestamp_mode.token) { + std::cout << "\nToken timestamps:" << std::endl; + for (const auto& token : tokens_with_text_info) { + const double start = token.start_offset * frame_to_seconds; + const double 
end = token.end_offset * frame_to_seconds; + std::cout << start << "s - " << end << "s : " << token.decoded_text + << std::endl; + } } - ET_LOG(Info, "Done!"); return 0; } From 0c9768d224662abdc0bf890f0053babc169e0943 Mon Sep 17 00:00:00 2001 From: Matt Clayton Date: Mon, 12 Jan 2026 14:10:13 -0500 Subject: [PATCH 10/17] try-catch get_tokens_with_text_info --- examples/models/parakeet/main.cpp | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/examples/models/parakeet/main.cpp b/examples/models/parakeet/main.cpp index 53bbd6f0ce4..412b1717b6a 100644 --- a/examples/models/parakeet/main.cpp +++ b/examples/models/parakeet/main.cpp @@ -873,8 +873,14 @@ int main(int argc, char** argv) { supported_punctuation.size()); // for simplicity, compute all levels of timestamps regardless of mode - const auto tokens_with_text_info = get_tokens_with_text_info( + std::vector tokens_with_text_info; + try { + tokens_with_text_info = get_tokens_with_text_info( decoded_tokens, *tokenizer, supported_punctuation); + } catch (const std::exception& e) { + ET_LOG(Error, "Failed to get tokens with text info: %s", e.what()); + return 1; + } const auto word_offsets = get_words_offsets( tokens_with_text_info, *tokenizer, supported_punctuation); const auto segment_offsets = get_segment_offsets(word_offsets); From 365896d152f7ea8d8b72d8871de311fbebeab56b Mon Sep 17 00:00:00 2001 From: Matt Clayton Date: Tue, 13 Jan 2026 10:00:29 -0500 Subject: [PATCH 11/17] Remove duplicated mock --- extension/llm/runner/test/test_text_llm_runner.cpp | 5 ----- 1 file changed, 5 deletions(-) diff --git a/extension/llm/runner/test/test_text_llm_runner.cpp b/extension/llm/runner/test/test_text_llm_runner.cpp index 69798f224cb..c9b57fb7391 100644 --- a/extension/llm/runner/test/test_text_llm_runner.cpp +++ b/extension/llm/runner/test/test_text_llm_runner.cpp @@ -38,11 +38,6 @@ class MockTokenizer : public ::tokenizers::Tokenizer { encode, (const std::string&, int8_t, int8_t), (const)); - MOCK_METHOD( - ::tokenizers::Result, - id_to_piece, - (uint64_t), - (const)); MOCK_METHOD( ::tokenizers::Result, decode, From 08b82fdb6b0a6775ec608ff16ff82a71cc3a3a47 Mon Sep 17 00:00:00 2001 From: Matt Clayton Date: Tue, 13 Jan 2026 14:30:09 -0500 Subject: [PATCH 12/17] timestamp_utils/tokenizer_utils/types re-organization --- examples/models/parakeet/CMakeLists.txt | 7 +- examples/models/parakeet/main.cpp | 399 +------------------ examples/models/parakeet/timestamp_utils.cpp | 257 ++++++++++++ examples/models/parakeet/timestamp_utils.h | 36 ++ examples/models/parakeet/tokenizer_utils.cpp | 111 ++++++ examples/models/parakeet/tokenizer_utils.h | 27 ++ examples/models/parakeet/types.h | 33 ++ 7 files changed, 488 insertions(+), 382 deletions(-) create mode 100644 examples/models/parakeet/timestamp_utils.cpp create mode 100644 examples/models/parakeet/timestamp_utils.h create mode 100644 examples/models/parakeet/tokenizer_utils.cpp create mode 100644 examples/models/parakeet/tokenizer_utils.h create mode 100644 examples/models/parakeet/types.h diff --git a/examples/models/parakeet/CMakeLists.txt b/examples/models/parakeet/CMakeLists.txt index 5ea7b81cd1f..7632d1e92ea 100644 --- a/examples/models/parakeet/CMakeLists.txt +++ b/examples/models/parakeet/CMakeLists.txt @@ -80,7 +80,12 @@ if(EXECUTORCH_BUILD_METAL) executorch_target_link_options_shared_lib(metal_backend) endif() -add_executable(parakeet_runner main.cpp) +add_executable( + parakeet_runner + main.cpp + timestamp_utils.cpp + tokenizer_utils.cpp +) if(NOT CMAKE_BUILD_TYPE STREQUAL 
"Debug") target_link_options_gc_sections(parakeet_runner) if(NOT APPLE AND NOT MSVC) diff --git a/examples/models/parakeet/main.cpp b/examples/models/parakeet/main.cpp index 412b1717b6a..291282aecc8 100644 --- a/examples/models/parakeet/main.cpp +++ b/examples/models/parakeet/main.cpp @@ -21,6 +21,10 @@ #include +#include "timestamp_utils.h" +#include "tokenizer_utils.h" +#include "types.h" + #include #include #include @@ -49,10 +53,12 @@ using ::executorch::extension::Module; using ::executorch::runtime::Error; using ::executorch::runtime::EValue; -namespace { -// Matches output type of tokenizers::Tokenizer methods -using TokenId = uint64_t; +using ::parakeet::TextWithOffsets; +using ::parakeet::Token; +using ::parakeet::TokenId; +using ::parakeet::TokenWithTextInfo; +namespace { // TDT duration values const std::vector DURATIONS = {0, 1, 2, 3, 4}; @@ -66,28 +72,6 @@ struct TimestampOutputMode { } }; -struct Token { - TokenId id; - int64_t start_offset; - int64_t duration; -}; - -struct TokenWithTextInfo { - TokenId id; - // Raw vocabulary piece for the token_id (i.e., "##ing", "▁hello") - std::string raw_piece; - // Decoded text for the token_id (i.e., "ing", " hello") - std::string decoded_text; - int64_t start_offset; - int64_t end_offset; -}; - -struct TextWithOffsets { - std::string text; - int64_t start_offset; - int64_t end_offset; -}; - std::string to_lower_ascii(std::string s) { for (char& ch : s) { ch = static_cast(std::tolower(static_cast(ch))); @@ -121,356 +105,6 @@ TimestampOutputMode parse_timestamp_output_mode(const std::string& raw_arg) { "'. Expected: token, word, segment, all."); } -bool is_whitespace_only(const std::string& token) { - if (token.empty()) { - return true; - } - - try { - const auto codepoints = unicode_cpts_from_utf8(token); - for (const auto cp : codepoints) { - if (!unicode_cpt_flags(cp).is_whitespace) { - return false; - } - } - return true; - } catch (const std::exception&) { - return false; - } -} - -bool is_special_token(const std::string& token) { - if (token.size() >= 2 && token.front() == '[' && token.back() == ']') { - return true; - } - if (token.size() >= 2 && token.front() == '<' && token.back() == '>') { - return true; - } - if (token.rfind("##", 0) == 0) { - return true; - } - if (token.rfind(u8"▁", 0) == 0) { - return true; - } - if (is_whitespace_only(token)) { - return true; - } - return false; -} - -// Matches NeMo extract_punctuation_from_vocab method -// https://github.com/NVIDIA-NeMo/NeMo/blob/b90a528/nemo/collections/asr/parts/utils/tokenizer_utils.py#L20 -std::unordered_set derive_supported_punctuation( - const tokenizers::Tokenizer& tokenizer) { - std::unordered_set punctuation; - - const int32_t vocab_size = tokenizer.vocab_size(); - for (int32_t id = 0; id < vocab_size; id++) { - const auto piece_result = tokenizer.id_to_piece(static_cast(id)); - if (!piece_result.ok()) { - continue; - } - const std::string& piece = piece_result.get(); - if (is_special_token(piece)) { - continue; - } - - try { - const auto codepoints = unicode_cpts_from_utf8(piece); - for (const auto cp : codepoints) { - if (unicode_cpt_flags(cp).is_punctuation) { - punctuation.insert(unicode_cpt_to_utf8(cp)); - } - } - } catch (const std::exception&) { - ET_LOG( - Error, - "Failed to decode token piece '%s' to codepoints", - piece.c_str()); - } - } - - return punctuation; -} - -std::string decode_token_sequence( - const std::vector& tokens, - const tokenizers::Tokenizer& tokenizer) { - std::string result; - TokenId prev_token = tokenizer.bos_tok(); - for 
(const TokenId token : tokens) { - auto decode_result = tokenizer.decode(prev_token, token); - if (decode_result.ok()) { - result += decode_result.get(); - } - prev_token = token; - } - return result; -} - -// convenience overload -std::string decode_token_sequence( - const std::vector& decoded_tokens, - const tokenizers::Tokenizer& tokenizer) { - std::vector token_ids; - token_ids.reserve(decoded_tokens.size()); - for (const auto& tok : decoded_tokens) { - token_ids.push_back(tok.id); - } - return decode_token_sequence(token_ids, tokenizer); -} - -// throws if any tokenizer calls fail -std::vector get_tokens_with_text_info( - const std::vector& tokens, - const tokenizers::Tokenizer& tokenizer, - const std::unordered_set& supported_punctuation) { - std::vector tokens_with_text; - tokens_with_text.reserve(tokens.size()); - - for (const auto& token : tokens) { - auto piece_result = tokenizer.id_to_piece(token.id); - if (!piece_result.ok()) { - throw std::runtime_error( - "id_to_piece failed for token=" + std::to_string(token.id)); - } - - auto text_result = tokenizer.decode(tokenizer.bos_tok(), token.id); - if (!text_result.ok()) { - throw std::runtime_error( - "decode failed for token=" + std::to_string(token.id)); - } - - tokens_with_text.push_back( - {token.id, - piece_result.get(), - text_result.get(), - token.start_offset, - token.start_offset + token.duration}); - } - - // NeMo TDT punctuation refinement pass: snap punctuation to the end of the - // previous token. - // https://github.com/NVIDIA-NeMo/NeMo/blob/bf583c9/nemo/collections/asr/parts/submodules/rnnt_decoding.py#L1169-L1189 - for (size_t i = 1; i < tokens_with_text.size(); i++) { - if (supported_punctuation.count(tokens_with_text[i].decoded_text) > 0) { - tokens_with_text[i].start_offset = tokens_with_text[i - 1].end_offset; - tokens_with_text[i].end_offset = tokens_with_text[i].start_offset; - } - } - - return tokens_with_text; -} - -// ref: -// https://github.com/NVIDIA-NeMo/NeMo/blob/bf583c9/nemo/collections/asr/parts/utils/timestamp_utils.py#L54 -// assumes BPE tokenizer type -std::vector get_words_offsets( - const std::vector& tokens, - const tokenizers::Tokenizer& tokenizer, - const std::unordered_set& supported_punctuation, - const std::string& word_delimiter_char = " ") { - std::vector word_offsets; - if (tokens.empty()) { - return word_offsets; - } - - size_t previous_token_index = 0; - std::vector build_token_indices; - - auto is_curr_punctuation = [&](const std::string& token_text) { - return token_text != word_delimiter_char && - supported_punctuation.count(token_text) > 0; - }; - - auto is_word_start = [&](const std::string& token_piece, - const std::string& token_text, - const std::string& next_non_delim_token) { - const bool next_is_punctuation = - supported_punctuation.count(next_non_delim_token) > 0; - return token_piece != token_text || - (token_text == word_delimiter_char && !next_is_punctuation); - }; - - for (size_t i = 0; i < tokens.size(); i++) { - const auto& token = tokens[i]; - - const bool curr_punctuation = is_curr_punctuation(token.decoded_text); - - std::string next_non_delim_token; - for (size_t j = i + 1; j < tokens.size(); j++) { - if (tokens[j].decoded_text != word_delimiter_char) { - next_non_delim_token = tokens[j].decoded_text; - break; - } - } - - if (is_word_start( - token.raw_piece, token.decoded_text, next_non_delim_token) && - !curr_punctuation) { - if (!build_token_indices.empty()) { - std::vector built_ids; - built_ids.reserve(build_token_indices.size()); - for (size_t idx : 
build_token_indices) { - built_ids.push_back(tokens[idx].id); - } - word_offsets.push_back( - {decode_token_sequence(built_ids, tokenizer), - tokens[previous_token_index].start_offset, - tokens[i - 1].end_offset}); - } - - build_token_indices.clear(); - - if (token.decoded_text != word_delimiter_char) { - build_token_indices.push_back(i); - previous_token_index = i; - } - } else if ( - curr_punctuation && build_token_indices.empty() && - !word_offsets.empty()) { - auto& last_built_word = word_offsets.back(); - last_built_word.end_offset = token.end_offset; - if (!last_built_word.text.empty() && last_built_word.text.back() == ' ') { - last_built_word.text.pop_back(); - } - last_built_word.text += token.decoded_text; - } else if (curr_punctuation && !build_token_indices.empty()) { - const auto& last = tokens[build_token_indices.back()].raw_piece; - if (last == " " || last == "_" || last == "▁") { - build_token_indices.pop_back(); - } - build_token_indices.push_back(i); - } else { - if (build_token_indices.empty()) { - previous_token_index = i; - } - build_token_indices.push_back(i); - } - } - - // Match NeMo behavior: inject first start_offset and append any remaining - // built tokens as the final word. - if (word_offsets.empty()) { - if (!build_token_indices.empty()) { - std::vector built_ids; - built_ids.reserve(build_token_indices.size()); - for (const size_t idx : build_token_indices) { - built_ids.push_back(tokens[idx].id); - } - word_offsets.push_back( - {decode_token_sequence(built_ids, tokenizer), - tokens[0].start_offset, - tokens.back().end_offset}); - } - } else { - word_offsets[0].start_offset = tokens[0].start_offset; - - if (!build_token_indices.empty()) { - std::vector built_ids; - built_ids.reserve(build_token_indices.size()); - for (size_t idx : build_token_indices) { - built_ids.push_back(tokens[idx].id); - } - word_offsets.push_back( - {decode_token_sequence(built_ids, tokenizer), - tokens[previous_token_index].start_offset, - tokens.back().end_offset}); - } - } - - return word_offsets; -} - -// ref -// https://github.com/NVIDIA-NeMo/NeMo/blob/bf583c9/nemo/collections/asr/parts/utils/timestamp_utils.py#L227 -std::vector get_segment_offsets( - const std::vector& word_offsets, - const std::vector& segment_delimiters = {".", "?", "!"}, - const std::optional& segment_gap_threshold = std::nullopt) { - std::vector segment_offsets; - if (word_offsets.empty()) { - return segment_offsets; - } - - std::vector segment_words; - size_t previous_word_index = 0; - - for (size_t i = 0; i < word_offsets.size(); i++) { - const auto& offset = word_offsets[i]; - const auto& word = offset.text; - - if (segment_gap_threshold.has_value() && !segment_words.empty()) { - const int64_t gap_between_words = - offset.start_offset - word_offsets[i - 1].end_offset; - if (gap_between_words >= segment_gap_threshold.value()) { - std::string segment; - for (size_t j = 0; j < segment_words.size(); j++) { - if (j > 0) { - segment += " "; - } - segment += segment_words[j]; - } - segment_offsets.push_back( - {segment, - word_offsets[previous_word_index].start_offset, - word_offsets[i - 1].end_offset}); - segment_words = {word}; - previous_word_index = i; - continue; - } - } - - const bool is_delimiter_word = - std::find(segment_delimiters.begin(), segment_delimiters.end(), word) != - segment_delimiters.end(); - - const bool ends_with_delimiter = !word.empty() && - std::find( - segment_delimiters.begin(), - segment_delimiters.end(), - std::string(1, word.back())) != segment_delimiters.end(); - - if 
(!word.empty() && (ends_with_delimiter || is_delimiter_word)) { - segment_words.push_back(word); - if (!segment_words.empty()) { - std::string segment; - for (size_t j = 0; j < segment_words.size(); j++) { - if (j > 0) { - segment += " "; - } - segment += segment_words[j]; - } - segment_offsets.push_back( - {segment, - word_offsets[previous_word_index].start_offset, - offset.end_offset}); - } - segment_words.clear(); - previous_word_index = i + 1; - continue; - } - - segment_words.push_back(word); - } - - if (!segment_words.empty()) { - std::string segment; - for (size_t j = 0; j < segment_words.size(); j++) { - if (j > 0) { - segment += " "; - } - segment += segment_words[j]; - } - segment_offsets.push_back( - {segment, - word_offsets[previous_word_index].start_offset, - word_offsets.back().end_offset}); - } - - return segment_offsets; -} - std::vector greedy_decode_executorch( Module& model, const ::executorch::aten::Tensor& encoder_output, @@ -857,7 +491,8 @@ int main(int argc, char** argv) { } // Convert tokens to text - std::string text = decode_token_sequence(decoded_tokens, *tokenizer); + std::string text = parakeet::tokenizer_utils::decode_token_sequence( + decoded_tokens, *tokenizer); std::cout << "Transcribed text: " << text << std::endl; if (!timestamp_mode.enabled()) { @@ -866,7 +501,7 @@ int main(int argc, char** argv) { ET_LOG(Info, "Computing timestamps..."); std::unordered_set supported_punctuation = - derive_supported_punctuation(*tokenizer); + parakeet::tokenizer_utils::derive_supported_punctuation(*tokenizer); ET_LOG( Info, "Derived supported_punctuation size=%zu", @@ -875,15 +510,17 @@ int main(int argc, char** argv) { // for simplicity, compute all levels of timestamps regardless of mode std::vector tokens_with_text_info; try { - tokens_with_text_info = get_tokens_with_text_info( - decoded_tokens, *tokenizer, supported_punctuation); + tokens_with_text_info = + parakeet::timestamp_utils::get_tokens_with_text_info( + decoded_tokens, *tokenizer, supported_punctuation); } catch (const std::exception& e) { ET_LOG(Error, "Failed to get tokens with text info: %s", e.what()); return 1; } - const auto word_offsets = get_words_offsets( + const auto word_offsets = parakeet::timestamp_utils::get_words_offsets( tokens_with_text_info, *tokenizer, supported_punctuation); - const auto segment_offsets = get_segment_offsets(word_offsets); + const auto segment_offsets = + parakeet::timestamp_utils::get_segment_offsets(word_offsets); const double frame_to_seconds = window_stride * static_cast(encoder_subsampling_factor); diff --git a/examples/models/parakeet/timestamp_utils.cpp b/examples/models/parakeet/timestamp_utils.cpp new file mode 100644 index 00000000000..c15693566f9 --- /dev/null +++ b/examples/models/parakeet/timestamp_utils.cpp @@ -0,0 +1,257 @@ +#include "timestamp_utils.h" + +#include "tokenizer_utils.h" + +#include +#include + +#include + +namespace parakeet::timestamp_utils { + +std::vector get_tokens_with_text_info( + const std::vector& tokens, + const tokenizers::Tokenizer& tokenizer, + const std::unordered_set& supported_punctuation) { + std::vector tokens_with_text; + tokens_with_text.reserve(tokens.size()); + + for (const auto& token : tokens) { + auto piece_result = tokenizer.id_to_piece(token.id); + if (!piece_result.ok()) { + throw std::runtime_error( + "id_to_piece failed for token=" + std::to_string(token.id)); + } + + auto text_result = tokenizer.decode(tokenizer.bos_tok(), token.id); + if (!text_result.ok()) { + throw std::runtime_error( + "decode failed for 
token=" + std::to_string(token.id)); + } + + tokens_with_text.push_back( + {token.id, + piece_result.get(), + text_result.get(), + token.start_offset, + token.start_offset + token.duration}); + } + + // NeMo TDT punctuation refinement pass: snap punctuation to the end of the + // previous token. + // https://github.com/NVIDIA-NeMo/NeMo/blob/bf583c9/nemo/collections/asr/parts/submodules/rnnt_decoding.py#L1169-L1189 + for (size_t i = 1; i < tokens_with_text.size(); i++) { + if (supported_punctuation.count(tokens_with_text[i].decoded_text) > 0) { + tokens_with_text[i].start_offset = tokens_with_text[i - 1].end_offset; + tokens_with_text[i].end_offset = tokens_with_text[i].start_offset; + } + } + + return tokens_with_text; +} + +std::vector get_words_offsets( + const std::vector& tokens, + const tokenizers::Tokenizer& tokenizer, + const std::unordered_set& supported_punctuation, + const std::string& word_delimiter_char) { + std::vector word_offsets; + if (tokens.empty()) { + return word_offsets; + } + + size_t previous_token_index = 0; + std::vector build_token_indices; + + auto is_curr_punctuation = [&](const std::string& token_text) { + return token_text != word_delimiter_char && + supported_punctuation.count(token_text) > 0; + }; + + auto is_word_start = [&](const std::string& token_piece, + const std::string& token_text, + const std::string& next_non_delim_token) { + const bool next_is_punctuation = + supported_punctuation.count(next_non_delim_token) > 0; + return token_piece != token_text || + (token_text == word_delimiter_char && !next_is_punctuation); + }; + + for (size_t i = 0; i < tokens.size(); i++) { + const auto& token = tokens[i]; + + const bool curr_punctuation = is_curr_punctuation(token.decoded_text); + + std::string next_non_delim_token; + for (size_t j = i + 1; j < tokens.size(); j++) { + if (tokens[j].decoded_text != word_delimiter_char) { + next_non_delim_token = tokens[j].decoded_text; + break; + } + } + + if (is_word_start( + token.raw_piece, token.decoded_text, next_non_delim_token) && + !curr_punctuation) { + if (!build_token_indices.empty()) { + std::vector built_ids; + built_ids.reserve(build_token_indices.size()); + for (size_t idx : build_token_indices) { + built_ids.push_back(tokens[idx].id); + } + word_offsets.push_back( + {tokenizer_utils::decode_token_sequence(built_ids, tokenizer), + tokens[previous_token_index].start_offset, + tokens[i - 1].end_offset}); + } + + build_token_indices.clear(); + + if (token.decoded_text != word_delimiter_char) { + build_token_indices.push_back(i); + previous_token_index = i; + } + } else if ( + curr_punctuation && build_token_indices.empty() && + !word_offsets.empty()) { + auto& last_built_word = word_offsets.back(); + last_built_word.end_offset = token.end_offset; + if (!last_built_word.text.empty() && last_built_word.text.back() == ' ') { + last_built_word.text.pop_back(); + } + last_built_word.text += token.decoded_text; + } else if (curr_punctuation && !build_token_indices.empty()) { + const auto& last = tokens[build_token_indices.back()].raw_piece; + if (last == " " || last == "_" || last == "▁") { + build_token_indices.pop_back(); + } + build_token_indices.push_back(i); + } else { + if (build_token_indices.empty()) { + previous_token_index = i; + } + build_token_indices.push_back(i); + } + } + + // Match NeMo behavior: inject first start_offset and append any remaining + // built tokens as the final word. 
+ if (word_offsets.empty()) { + if (!build_token_indices.empty()) { + std::vector built_ids; + built_ids.reserve(build_token_indices.size()); + for (const size_t idx : build_token_indices) { + built_ids.push_back(tokens[idx].id); + } + word_offsets.push_back( + {tokenizer_utils::decode_token_sequence(built_ids, tokenizer), + tokens[0].start_offset, + tokens.back().end_offset}); + } + } else { + word_offsets[0].start_offset = tokens[0].start_offset; + + if (!build_token_indices.empty()) { + std::vector built_ids; + built_ids.reserve(build_token_indices.size()); + for (size_t idx : build_token_indices) { + built_ids.push_back(tokens[idx].id); + } + word_offsets.push_back( + {tokenizer_utils::decode_token_sequence(built_ids, tokenizer), + tokens[previous_token_index].start_offset, + tokens.back().end_offset}); + } + } + + return word_offsets; +} + +std::vector get_segment_offsets( + const std::vector& word_offsets, + const std::vector& segment_delimiters, + const std::optional& segment_gap_threshold) { + std::vector segment_offsets; + if (word_offsets.empty()) { + return segment_offsets; + } + + std::vector segment_words; + size_t previous_word_index = 0; + + for (size_t i = 0; i < word_offsets.size(); i++) { + const auto& offset = word_offsets[i]; + const auto& word = offset.text; + + if (segment_gap_threshold.has_value() && !segment_words.empty()) { + const int64_t gap_between_words = + offset.start_offset - word_offsets[i - 1].end_offset; + if (gap_between_words >= segment_gap_threshold.value()) { + std::string segment; + for (size_t j = 0; j < segment_words.size(); j++) { + if (j > 0) { + segment += " "; + } + segment += segment_words[j]; + } + segment_offsets.push_back( + {segment, + word_offsets[previous_word_index].start_offset, + word_offsets[i - 1].end_offset}); + segment_words = {word}; + previous_word_index = i; + continue; + } + } + + const bool is_delimiter_word = + std::find(segment_delimiters.begin(), segment_delimiters.end(), word) != + segment_delimiters.end(); + + const bool ends_with_delimiter = !word.empty() && + std::find( + segment_delimiters.begin(), + segment_delimiters.end(), + std::string(1, word.back())) != segment_delimiters.end(); + + if (!word.empty() && (ends_with_delimiter || is_delimiter_word)) { + segment_words.push_back(word); + if (!segment_words.empty()) { + std::string segment; + for (size_t j = 0; j < segment_words.size(); j++) { + if (j > 0) { + segment += " "; + } + segment += segment_words[j]; + } + segment_offsets.push_back( + {segment, + word_offsets[previous_word_index].start_offset, + offset.end_offset}); + } + segment_words.clear(); + previous_word_index = i + 1; + continue; + } + + segment_words.push_back(word); + } + + if (!segment_words.empty()) { + std::string segment; + for (size_t j = 0; j < segment_words.size(); j++) { + if (j > 0) { + segment += " "; + } + segment += segment_words[j]; + } + segment_offsets.push_back( + {segment, + word_offsets[previous_word_index].start_offset, + word_offsets.back().end_offset}); + } + + return segment_offsets; +} + +} // namespace parakeet::timestamp_utils diff --git a/examples/models/parakeet/timestamp_utils.h b/examples/models/parakeet/timestamp_utils.h new file mode 100644 index 00000000000..5787b11d54e --- /dev/null +++ b/examples/models/parakeet/timestamp_utils.h @@ -0,0 +1,36 @@ +#pragma once + +#include "types.h" + +#include +#include +#include +#include + +#include + +namespace parakeet::timestamp_utils { + +// throws if any tokenizer calls fail +std::vector get_tokens_with_text_info( + const 
std::vector& tokens, + const tokenizers::Tokenizer& tokenizer, + const std::unordered_set& supported_punctuation); + +// ref: +// https://github.com/NVIDIA-NeMo/NeMo/blob/bf583c9/nemo/collections/asr/parts/utils/timestamp_utils.py#L54 +// assumes BPE tokenizer type +std::vector get_words_offsets( + const std::vector& tokens, + const tokenizers::Tokenizer& tokenizer, + const std::unordered_set& supported_punctuation, + const std::string& word_delimiter_char = " "); + +// ref +// https://github.com/NVIDIA-NeMo/NeMo/blob/bf583c9/nemo/collections/asr/parts/utils/timestamp_utils.py#L227 +std::vector get_segment_offsets( + const std::vector& word_offsets, + const std::vector& segment_delimiters = {".", "?", "!"}, + const std::optional& segment_gap_threshold = std::nullopt); + +} // namespace parakeet::timestamp_utils diff --git a/examples/models/parakeet/tokenizer_utils.cpp b/examples/models/parakeet/tokenizer_utils.cpp new file mode 100644 index 00000000000..8cebebd8b19 --- /dev/null +++ b/examples/models/parakeet/tokenizer_utils.cpp @@ -0,0 +1,111 @@ +#include "tokenizer_utils.h" + +#include + +#include +#include +#include + +namespace { + +bool is_whitespace_only(const std::string& token) { + if (token.empty()) { + return true; + } + + try { + const auto codepoints = unicode_cpts_from_utf8(token); + for (const auto cp : codepoints) { + if (!unicode_cpt_flags(cp).is_whitespace) { + return false; + } + } + return true; + } catch (const std::exception&) { + return false; + } +} + +bool is_special_token(const std::string& token) { + if (token.size() >= 2 && token.front() == '[' && token.back() == ']') { + return true; + } + if (token.size() >= 2 && token.front() == '<' && token.back() == '>') { + return true; + } + if (token.rfind("##", 0) == 0) { + return true; + } + if (token.rfind(u8"▁", 0) == 0) { + return true; + } + if (is_whitespace_only(token)) { + return true; + } + return false; +} + +} // namespace + +namespace parakeet::tokenizer_utils { + +std::unordered_set derive_supported_punctuation( + const tokenizers::Tokenizer& tokenizer) { + std::unordered_set punctuation; + + const int32_t vocab_size = tokenizer.vocab_size(); + for (int32_t id = 0; id < vocab_size; id++) { + const auto piece_result = tokenizer.id_to_piece(static_cast(id)); + if (!piece_result.ok()) { + continue; + } + const std::string& piece = piece_result.get(); + if (is_special_token(piece)) { + continue; + } + + try { + const auto codepoints = unicode_cpts_from_utf8(piece); + for (const auto cp : codepoints) { + if (unicode_cpt_flags(cp).is_punctuation) { + punctuation.insert(unicode_cpt_to_utf8(cp)); + } + } + } catch (const std::exception&) { + ET_LOG( + Error, + "Failed to decode token piece '%s' to codepoints", + piece.c_str()); + } + } + + return punctuation; +} + +std::string decode_token_sequence( + const std::vector& tokens, + const tokenizers::Tokenizer& tokenizer) { + std::string result; + TokenId prev_token = tokenizer.bos_tok(); + for (const TokenId token : tokens) { + auto decode_result = tokenizer.decode(prev_token, token); + if (decode_result.ok()) { + result += decode_result.get(); + } + prev_token = token; + } + return result; +} + +std::string decode_token_sequence( + const std::vector& decoded_tokens, + const tokenizers::Tokenizer& tokenizer) { + std::vector token_ids; + token_ids.reserve(decoded_tokens.size()); + for (const auto& tok : decoded_tokens) { + token_ids.push_back(tok.id); + } + return decode_token_sequence(token_ids, tokenizer); +} + +} // namespace parakeet::tokenizer_utils diff --git 
a/examples/models/parakeet/tokenizer_utils.h b/examples/models/parakeet/tokenizer_utils.h new file mode 100644 index 00000000000..588edb54207 --- /dev/null +++ b/examples/models/parakeet/tokenizer_utils.h @@ -0,0 +1,27 @@ +#pragma once + +#include "types.h" + +#include +#include +#include + +#include + +namespace parakeet::tokenizer_utils { + +// Matches NeMo extract_punctuation_from_vocab method +// https://github.com/NVIDIA-NeMo/NeMo/blob/b90a528/nemo/collections/asr/parts/utils/tokenizer_utils.py#L20 +std::unordered_set derive_supported_punctuation( + const tokenizers::Tokenizer& tokenizer); + +std::string decode_token_sequence( + const std::vector& tokens, + const tokenizers::Tokenizer& tokenizer); + +// convenience overload +std::string decode_token_sequence( + const std::vector& decoded_tokens, + const tokenizers::Tokenizer& tokenizer); + +} // namespace parakeet::tokenizer_utils diff --git a/examples/models/parakeet/types.h b/examples/models/parakeet/types.h new file mode 100644 index 00000000000..eab1dd974df --- /dev/null +++ b/examples/models/parakeet/types.h @@ -0,0 +1,33 @@ +#pragma once + +#include +#include + +namespace parakeet { + +// Matches output type of tokenizers::Tokenizer methods +using TokenId = uint64_t; + +struct Token { + TokenId id; + int64_t start_offset; + int64_t duration; +}; + +struct TokenWithTextInfo { + TokenId id; + // Raw vocabulary piece for the token_id (i.e., "##ing", "▁hello") + std::string raw_piece; + // Decoded text for the token_id (i.e., "ing", " hello") + std::string decoded_text; + int64_t start_offset; + int64_t end_offset; +}; + +struct TextWithOffsets { + std::string text; + int64_t start_offset; + int64_t end_offset; +}; + +} // namespace parakeet From 5c27d9d4842428fe0a0d0358cccde86c2fb3bd09 Mon Sep 17 00:00:00 2001 From: Matt Clayton Date: Tue, 13 Jan 2026 14:38:32 -0500 Subject: [PATCH 13/17] Remove redundant check --- examples/models/parakeet/timestamp_utils.cpp | 20 +++++++++----------- 1 file changed, 9 insertions(+), 11 deletions(-) diff --git a/examples/models/parakeet/timestamp_utils.cpp b/examples/models/parakeet/timestamp_utils.cpp index c15693566f9..ec93f089c85 100644 --- a/examples/models/parakeet/timestamp_utils.cpp +++ b/examples/models/parakeet/timestamp_utils.cpp @@ -216,19 +216,17 @@ std::vector get_segment_offsets( if (!word.empty() && (ends_with_delimiter || is_delimiter_word)) { segment_words.push_back(word); - if (!segment_words.empty()) { - std::string segment; - for (size_t j = 0; j < segment_words.size(); j++) { - if (j > 0) { - segment += " "; - } - segment += segment_words[j]; + std::string segment; + for (size_t j = 0; j < segment_words.size(); j++) { + if (j > 0) { + segment += " "; } - segment_offsets.push_back( - {segment, - word_offsets[previous_word_index].start_offset, - offset.end_offset}); + segment += segment_words[j]; } + segment_offsets.push_back( + {segment, + word_offsets[previous_word_index].start_offset, + offset.end_offset}); segment_words.clear(); previous_word_index = i + 1; continue; From 9504e37b3285b5630e95a18243265108034293e8 Mon Sep 17 00:00:00 2001 From: Matt Clayton Date: Tue, 13 Jan 2026 15:00:13 -0500 Subject: [PATCH 14/17] CMake lint fix --- examples/models/parakeet/CMakeLists.txt | 7 +------ 1 file changed, 1 insertion(+), 6 deletions(-) diff --git a/examples/models/parakeet/CMakeLists.txt b/examples/models/parakeet/CMakeLists.txt index 7632d1e92ea..5da7b4373b1 100644 --- a/examples/models/parakeet/CMakeLists.txt +++ b/examples/models/parakeet/CMakeLists.txt @@ -80,12 +80,7 @@ 
if(EXECUTORCH_BUILD_METAL)
 executorch_target_link_options_shared_lib(metal_backend)
 endif()
-add_executable(
-  parakeet_runner
-  main.cpp
-  timestamp_utils.cpp
-  tokenizer_utils.cpp
-)
+add_executable(parakeet_runner main.cpp timestamp_utils.cpp tokenizer_utils.cpp)
 if(NOT CMAKE_BUILD_TYPE STREQUAL "Debug")
   target_link_options_gc_sections(parakeet_runner)
   if(NOT APPLE AND NOT MSVC)

From 349d0b69da341a7b5804748ef893c000dac1c0d0 Mon Sep 17 00:00:00 2001
From: Matt Clayton
Date: Tue, 13 Jan 2026 15:07:06 -0500
Subject: [PATCH 15/17] Explicit pytorch tokenizers include

---
 examples/models/parakeet/timestamp_utils.h | 2 +-
 examples/models/parakeet/tokenizer_utils.h | 2 +-
 2 files changed, 2 insertions(+), 2 deletions(-)

diff --git a/examples/models/parakeet/timestamp_utils.h b/examples/models/parakeet/timestamp_utils.h
index 5787b11d54e..b4b7913e63a 100644
--- a/examples/models/parakeet/timestamp_utils.h
+++ b/examples/models/parakeet/timestamp_utils.h
@@ -7,7 +7,7 @@
 #include
 #include

-#include
+#include

 namespace parakeet::timestamp_utils {

diff --git a/examples/models/parakeet/tokenizer_utils.h b/examples/models/parakeet/tokenizer_utils.h
index 588edb54207..1b1205cc07a 100644
--- a/examples/models/parakeet/tokenizer_utils.h
+++ b/examples/models/parakeet/tokenizer_utils.h
@@ -6,7 +6,7 @@
 #include
 #include

-#include
+#include

 namespace parakeet::tokenizer_utils {

From 4b5b15a13cd09d7a9d58215becaedb07db6dd4c0 Mon Sep 17 00:00:00 2001
From: Matt Clayton
Date: Tue, 13 Jan 2026 16:38:52 -0500
Subject: [PATCH 16/17] Default to segment timestamps

---
 examples/models/parakeet/main.cpp | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/examples/models/parakeet/main.cpp b/examples/models/parakeet/main.cpp
index 291282aecc8..b90d25ac0f0 100644
--- a/examples/models/parakeet/main.cpp
+++ b/examples/models/parakeet/main.cpp
@@ -45,7 +45,7 @@ DEFINE_string(
     "Path to data file (.ptd) for delegate data (optional, required for CUDA).");
 DEFINE_string(
     timestamps,
-    "none",
+    "segment",
     "Timestamp output mode: none|token|word|segment|all");

 using ::executorch::extension::from_blob;

From 7ada1eb74e5837dd492b9d6606dfdcb622174e3a Mon Sep 17 00:00:00 2001
From: Matt Clayton
Date: Tue, 13 Jan 2026 16:46:39 -0500
Subject: [PATCH 17/17] Default segment in readme

---
 examples/models/parakeet/README.md | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/examples/models/parakeet/README.md b/examples/models/parakeet/README.md
index 98711f45a0f..bf25edbe5b0 100644
--- a/examples/models/parakeet/README.md
+++ b/examples/models/parakeet/README.md
@@ -71,4 +71,4 @@ From the executorch root directory:
 | `--model_path` | Path to Parakeet model (.pte) |
 | `--audio_path` | Path to input audio file (.wav) |
 | `--tokenizer_path` | Path to tokenizer file (default: `tokenizer.json`) |
 | `--data_path` | Path to data file (.ptd) for delegate data (optional, required for CUDA) |
-| `--timestamps` | Timestamp output mode: `none\|token\|word\|segment\|all` |
+| `--timestamps` | Timestamp output mode: `none\|token\|word\|segment\|all` (default: `segment`) |
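
For illustration, a sketch of how the runner built by this series might be invoked once the model is exported. This is not part of the patches: the `cmake-out` binary location and the input file names are assumptions about a local build, while the flag names and the `segment` default mirror the gflags definitions and README table above.

```bash
# Illustrative only: the binary path and input file names are assumed, not
# taken from this patch series. Flag names match the DEFINE_string flags above.
./cmake-out/examples/models/parakeet/parakeet_runner \
  --model_path parakeet.pte \
  --audio_path sample.wav \
  --tokenizer_path tokenizer.model \
  --timestamps all   # one of: none | token | word | segment (default) | all
```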