diff --git a/examples/models/parakeet/CMakeLists.txt b/examples/models/parakeet/CMakeLists.txt
index 5ea7b81cd1f..5da7b4373b1 100644
--- a/examples/models/parakeet/CMakeLists.txt
+++ b/examples/models/parakeet/CMakeLists.txt
@@ -80,7 +80,7 @@ if(EXECUTORCH_BUILD_METAL)
   executorch_target_link_options_shared_lib(metal_backend)
 endif()
 
-add_executable(parakeet_runner main.cpp)
+add_executable(parakeet_runner main.cpp timestamp_utils.cpp tokenizer_utils.cpp)
 if(NOT CMAKE_BUILD_TYPE STREQUAL "Debug")
   target_link_options_gc_sections(parakeet_runner)
   if(NOT APPLE AND NOT MSVC)
diff --git a/examples/models/parakeet/README.md b/examples/models/parakeet/README.md
index b27bc1f8a91..bf25edbe5b0 100644
--- a/examples/models/parakeet/README.md
+++ b/examples/models/parakeet/README.md
@@ -71,3 +71,4 @@ From the executorch root directory:
 | `--audio_path` | Path to input audio file (.wav) |
 | `--tokenizer_path` | Path to tokenizer file (default: `tokenizer.json`) |
 | `--data_path` | Path to data file (.ptd) for delegate data (optional, required for CUDA) |
+| `--timestamps` | Timestamp output mode: `none\|token\|word\|segment\|all` (default: `segment`) |
diff --git a/examples/models/parakeet/export_parakeet_tdt.py b/examples/models/parakeet/export_parakeet_tdt.py
index 92e32ca30bf..7d459f54da2 100644
--- a/examples/models/parakeet/export_parakeet_tdt.py
+++ b/examples/models/parakeet/export_parakeet_tdt.py
@@ -351,6 +351,9 @@ def export_all(model):
     )
 
     sample_rate = model.preprocessor._cfg.sample_rate
+    window_stride = float(model.preprocessor._cfg.window_stride)
+    encoder_subsampling_factor = int(getattr(model.encoder, "subsampling_factor", 8))
+
     metadata = {
         "num_rnn_layers": num_layers,
         "pred_hidden": pred_hidden,
@@ -358,6 +361,8 @@ def export_all(model):
         "vocab_size": model.tokenizer.vocab_size,
         "blank_id": model.tokenizer.vocab_size,
         "sample_rate": sample_rate,
+        "window_stride": window_stride,
+        "encoder_subsampling_factor": encoder_subsampling_factor,
     }
 
     return programs, metadata
diff --git a/examples/models/parakeet/main.cpp b/examples/models/parakeet/main.cpp
index 026f3911a3d..b90d25ac0f0 100644
--- a/examples/models/parakeet/main.cpp
+++ b/examples/models/parakeet/main.cpp
@@ -6,18 +6,28 @@
  * LICENSE file in the root directory of this source tree.
 */
 
+#include
 #include
+#include
 #include
+#include
 #include
 #include
 #include
+#include
 #include
+#include
 #include
 #include
 
+#include "timestamp_utils.h"
+#include "tokenizer_utils.h"
+#include "types.h"
+
 #include
 #include
+#include
 #include
 #include
 #include
 
@@ -27,24 +37,75 @@ DEFINE_string(model_path, "parakeet.pte", "Path to Parakeet model (.pte).");
 DEFINE_string(audio_path, "", "Path to input audio file (.wav).");
 DEFINE_string(
     tokenizer_path,
-    "tokenizer.json",
+    "tokenizer.model",
     "Path to SentencePiece tokenizer model file.");
 DEFINE_string(
     data_path,
     "",
     "Path to data file (.ptd) for delegate data (optional, required for CUDA).");
+DEFINE_string(
+    timestamps,
+    "segment",
+    "Timestamp output mode: none|token|word|segment|all");
 
 using ::executorch::extension::from_blob;
 using ::executorch::extension::Module;
 using ::executorch::runtime::Error;
 using ::executorch::runtime::EValue;
 
-namespace {
+using ::parakeet::TextWithOffsets;
+using ::parakeet::Token;
+using ::parakeet::TokenId;
+using ::parakeet::TokenWithTextInfo;
 
+namespace {
 // TDT duration values
 const std::vector<int64_t> DURATIONS = {0, 1, 2, 3, 4};
 
-std::vector<int64_t> greedy_decode_executorch(
+struct TimestampOutputMode {
+  bool token = false;
+  bool word = false;
+  bool segment = false;
+
+  bool enabled() const {
+    return token || word || segment;
+  }
+};
+
+std::string to_lower_ascii(std::string s) {
+  for (char& ch : s) {
+    ch = static_cast<char>(std::tolower(static_cast<unsigned char>(ch)));
+  }
+  return s;
+}
+
+TimestampOutputMode parse_timestamp_output_mode(const std::string& raw_arg) {
+  if (raw_arg.empty()) {
+    throw std::invalid_argument(
+        "Invalid --timestamps value (empty). Expected: token, word, segment, all.");
+  }
+  const std::string mode = to_lower_ascii(raw_arg);
+  if (mode == "none") {
+    return {false, false, false};
+  }
+  if (mode == "token") {
+    return {true, false, false};
+  }
+  if (mode == "word") {
+    return {false, true, false};
+  }
+  if (mode == "segment") {
+    return {false, false, true};
+  }
+  if (mode == "all") {
+    return {true, true, true};
+  }
+  throw std::invalid_argument(
+      "Invalid --timestamps value '" + raw_arg +
+      "'. Expected: token, word, segment, all.");
+}
+
+std::vector<Token> greedy_decode_executorch(
     Module& model,
     const ::executorch::aten::Tensor& encoder_output,
    int64_t encoder_len,
@@ -53,7 +114,7 @@ std::vector<int64_t> greedy_decode_executorch(
     int64_t num_rnn_layers = 2,
     int64_t pred_hidden = 640,
     int64_t max_symbols_per_step = 10) {
-  std::vector<int64_t> hypothesis;
+  std::vector<Token> hypothesis;
   int64_t num_token_classes = vocab_size + 1;
 
   // Transpose encoder output from [1, enc_dim, time] to [1, time, enc_dim]
@@ -205,10 +266,10 @@ std::vector<int64_t> greedy_decode_executorch(
       int64_t dur = DURATIONS[dur_idx];
 
       if (k == blank_id) {
-        t += std::max(dur, (int64_t)1);
+        t += std::max(dur, static_cast<int64_t>(1));
         symbols_on_frame = 0;
       } else {
-        hypothesis.push_back(k);
+        hypothesis.push_back({static_cast<TokenId>(k), t, dur});
 
         // Update decoder state
         std::vector<int64_t> token_data = {k};
@@ -268,29 +329,19 @@ std::vector<int64_t> greedy_decode_executorch(
   return hypothesis;
 }
 
-std::string tokens_to_text(
-    const std::vector<int64_t>& tokens,
-    tokenizers::Tokenizer* tokenizer) {
-  // Decode tokens to text one by one
-  std::string result;
-  uint64_t prev_token = 0;
-  for (size_t i = 0; i < tokens.size(); i++) {
-    uint64_t token = static_cast<uint64_t>(tokens[i]);
-    auto decode_result = tokenizer->decode(prev_token, token);
-    if (decode_result.ok()) {
-      result += decode_result.get();
-    }
-    prev_token = token;
-  }
-
-  return result;
-}
-
 } // namespace
 
 int main(int argc, char** argv) {
   gflags::ParseCommandLineFlags(&argc, &argv, true);
 
+  TimestampOutputMode timestamp_mode;
+  try {
+    timestamp_mode = parse_timestamp_output_mode(FLAGS_timestamps);
+  } catch (const std::invalid_argument& e) {
+    ET_LOG(Error, "%s", e.what());
+    return 1;
+  }
+
   if (FLAGS_audio_path.empty()) {
     ET_LOG(Error, "audio_path flag must be provided.");
     return 1;
@@ -381,10 +432,14 @@ int main(int argc, char** argv) {
   auto vocab_size_result = model->execute("vocab_size", empty_inputs);
   auto blank_id_result = model->execute("blank_id", empty_inputs);
   auto sample_rate_result = model->execute("sample_rate", empty_inputs);
+  auto window_stride_result = model->execute("window_stride", empty_inputs);
+  auto encoder_subsampling_factor_result =
+      model->execute("encoder_subsampling_factor", empty_inputs);
 
   if (!num_rnn_layers_result.ok() || !pred_hidden_result.ok() ||
       !vocab_size_result.ok() || !blank_id_result.ok() ||
-      !sample_rate_result.ok()) {
+      !sample_rate_result.ok() || !window_stride_result.ok() ||
+      !encoder_subsampling_factor_result.ok()) {
     ET_LOG(
         Error,
         "Failed to query model metadata. Make sure the model was exported with constant_methods.");
@@ -396,18 +451,23 @@ int main(int argc, char** argv) {
   int64_t num_rnn_layers = num_rnn_layers_result.get()[0].toInt();
   int64_t pred_hidden = pred_hidden_result.get()[0].toInt();
   int64_t sample_rate = sample_rate_result.get()[0].toInt();
+  double window_stride = window_stride_result.get()[0].toDouble();
+  int64_t encoder_subsampling_factor =
+      encoder_subsampling_factor_result.get()[0].toInt();
 
   ET_LOG(
       Info,
-      "Model metadata: vocab_size=%lld, blank_id=%lld, num_rnn_layers=%lld, pred_hidden=%lld, sample_rate=%lld",
+      "Model metadata: vocab_size=%lld, blank_id=%lld, num_rnn_layers=%lld, pred_hidden=%lld, sample_rate=%lld, window_stride=%.6f, encoder_subsampling_factor=%lld",
       static_cast<long long>(vocab_size),
       static_cast<long long>(blank_id),
       static_cast<long long>(num_rnn_layers),
       static_cast<long long>(pred_hidden),
-      static_cast<long long>(sample_rate));
+      static_cast<long long>(sample_rate),
+      window_stride,
+      encoder_subsampling_factor);
 
   ET_LOG(Info, "Running TDT greedy decode...");
-  auto tokens = greedy_decode_executorch(
+  auto decoded_tokens = greedy_decode_executorch(
       *model,
       encoded,
       encoded_len,
@@ -416,7 +476,7 @@ int main(int argc, char** argv) {
       num_rnn_layers,
       pred_hidden);
 
-  ET_LOG(Info, "Decoded %zu tokens", tokens.size());
+  ET_LOG(Info, "Decoded %zu tokens", decoded_tokens.size());
 
   // Load tokenizer
   ET_LOG(Info, "Loading tokenizer from: %s", FLAGS_tokenizer_path.c_str());
@@ -431,9 +491,68 @@ int main(int argc, char** argv) {
   }
 
   // Convert tokens to text
-  std::string text = tokens_to_text(tokens, tokenizer.get());
-  std::cout << "Transcription tokens: " << text << std::endl;
+  std::string text = parakeet::tokenizer_utils::decode_token_sequence(
+      decoded_tokens, *tokenizer);
+  std::cout << "Transcribed text: " << text << std::endl;
+
+  if (!timestamp_mode.enabled()) {
+    return 0;
+  }
+
+  ET_LOG(Info, "Computing timestamps...");
+  std::unordered_set<std::string> supported_punctuation =
+      parakeet::tokenizer_utils::derive_supported_punctuation(*tokenizer);
+  ET_LOG(
+      Info,
+      "Derived supported_punctuation size=%zu",
+      supported_punctuation.size());
+
+  // for simplicity, compute all levels of timestamps regardless of mode
+  std::vector<TokenWithTextInfo> tokens_with_text_info;
+  try {
+    tokens_with_text_info =
+        parakeet::timestamp_utils::get_tokens_with_text_info(
+            decoded_tokens, *tokenizer, supported_punctuation);
+  } catch (const std::exception& e) {
+    ET_LOG(Error, "Failed to get tokens with text info: %s", e.what());
+    return 1;
+  }
+  const auto word_offsets = parakeet::timestamp_utils::get_words_offsets(
+      tokens_with_text_info, *tokenizer, supported_punctuation);
+  const auto segment_offsets =
+      parakeet::timestamp_utils::get_segment_offsets(word_offsets);
+
+  const double frame_to_seconds =
+      window_stride * static_cast<double>(encoder_subsampling_factor);
+
+  if (timestamp_mode.segment) {
+    std::cout << "\nSegment timestamps:" << std::endl;
+    for (const auto& segment : segment_offsets) {
+      const double start = segment.start_offset * frame_to_seconds;
+      const double end = segment.end_offset * frame_to_seconds;
+      std::cout << start << "s - " << end << "s : " << segment.text
+                << std::endl;
+    }
+  }
+
+  if (timestamp_mode.word) {
+    std::cout << "\nWord timestamps:" << std::endl;
+    for (const auto& word : word_offsets) {
+      const double start = word.start_offset * frame_to_seconds;
+      const double end = word.end_offset * frame_to_seconds;
+      std::cout << start << "s - " << end << "s : " << word.text << std::endl;
+    }
+  }
+
+  if (timestamp_mode.token) {
+    std::cout << "\nToken timestamps:" << std::endl;
+    for (const auto& token : tokens_with_text_info) {
+      const double start = token.start_offset * frame_to_seconds;
+      const double end = token.end_offset * frame_to_seconds;
+      std::cout << start << "s - " << end << "s : " << token.decoded_text
+                << std::endl;
+    }
+  }
 
-  ET_LOG(Info, "Done!");
   return 0;
 }
diff --git a/examples/models/parakeet/timestamp_utils.cpp b/examples/models/parakeet/timestamp_utils.cpp
new file mode 100644
index 00000000000..ec93f089c85
--- /dev/null
+++ b/examples/models/parakeet/timestamp_utils.cpp
@@ -0,0 +1,255 @@
+#include "timestamp_utils.h"
+
+#include "tokenizer_utils.h"
+
+#include
+#include
+
+#include
+
+namespace parakeet::timestamp_utils {
+
+std::vector<TokenWithTextInfo> get_tokens_with_text_info(
+    const std::vector<Token>& tokens,
+    const tokenizers::Tokenizer& tokenizer,
+    const std::unordered_set<std::string>& supported_punctuation) {
+  std::vector<TokenWithTextInfo> tokens_with_text;
+  tokens_with_text.reserve(tokens.size());
+
+  for (const auto& token : tokens) {
+    auto piece_result = tokenizer.id_to_piece(token.id);
+    if (!piece_result.ok()) {
+      throw std::runtime_error(
+          "id_to_piece failed for token=" + std::to_string(token.id));
+    }
+
+    auto text_result = tokenizer.decode(tokenizer.bos_tok(), token.id);
+    if (!text_result.ok()) {
+      throw std::runtime_error(
+          "decode failed for token=" + std::to_string(token.id));
+    }
+
+    tokens_with_text.push_back(
+        {token.id,
+         piece_result.get(),
+         text_result.get(),
+         token.start_offset,
+         token.start_offset + token.duration});
+  }
+
+  // NeMo TDT punctuation refinement pass: snap punctuation to the end of the
+  // previous token.
+  // https://github.com/NVIDIA-NeMo/NeMo/blob/bf583c9/nemo/collections/asr/parts/submodules/rnnt_decoding.py#L1169-L1189
+  for (size_t i = 1; i < tokens_with_text.size(); i++) {
+    if (supported_punctuation.count(tokens_with_text[i].decoded_text) > 0) {
+      tokens_with_text[i].start_offset = tokens_with_text[i - 1].end_offset;
+      tokens_with_text[i].end_offset = tokens_with_text[i].start_offset;
+    }
+  }
+
+  return tokens_with_text;
+}
+
+std::vector<TextWithOffsets> get_words_offsets(
+    const std::vector<TokenWithTextInfo>& tokens,
+    const tokenizers::Tokenizer& tokenizer,
+    const std::unordered_set<std::string>& supported_punctuation,
+    const std::string& word_delimiter_char) {
+  std::vector<TextWithOffsets> word_offsets;
+  if (tokens.empty()) {
+    return word_offsets;
+  }
+
+  size_t previous_token_index = 0;
+  std::vector<size_t> build_token_indices;
+
+  auto is_curr_punctuation = [&](const std::string& token_text) {
+    return token_text != word_delimiter_char &&
+        supported_punctuation.count(token_text) > 0;
+  };
+
+  auto is_word_start = [&](const std::string& token_piece,
+                           const std::string& token_text,
+                           const std::string& next_non_delim_token) {
+    const bool next_is_punctuation =
+        supported_punctuation.count(next_non_delim_token) > 0;
+    return token_piece != token_text ||
+        (token_text == word_delimiter_char && !next_is_punctuation);
+  };
+
+  for (size_t i = 0; i < tokens.size(); i++) {
+    const auto& token = tokens[i];
+
+    const bool curr_punctuation = is_curr_punctuation(token.decoded_text);
+
+    std::string next_non_delim_token;
+    for (size_t j = i + 1; j < tokens.size(); j++) {
+      if (tokens[j].decoded_text != word_delimiter_char) {
+        next_non_delim_token = tokens[j].decoded_text;
+        break;
+      }
+    }
+
+    if (is_word_start(
+            token.raw_piece, token.decoded_text, next_non_delim_token) &&
+        !curr_punctuation) {
+      if (!build_token_indices.empty()) {
+        std::vector<TokenId> built_ids;
+        built_ids.reserve(build_token_indices.size());
+        for (size_t idx : build_token_indices) {
+          built_ids.push_back(tokens[idx].id);
+        }
+        word_offsets.push_back(
+            {tokenizer_utils::decode_token_sequence(built_ids, tokenizer),
+             tokens[previous_token_index].start_offset,
+             tokens[i - 1].end_offset});
+      }
+
+      build_token_indices.clear();
+
+      if (token.decoded_text != word_delimiter_char) {
+        build_token_indices.push_back(i);
+        previous_token_index = i;
+      }
+    } else if (
+        curr_punctuation && build_token_indices.empty() &&
+        !word_offsets.empty()) {
+      auto& last_built_word = word_offsets.back();
+      last_built_word.end_offset = token.end_offset;
+      if (!last_built_word.text.empty() && last_built_word.text.back() == ' ') {
+        last_built_word.text.pop_back();
+      }
+      last_built_word.text += token.decoded_text;
+    } else if (curr_punctuation && !build_token_indices.empty()) {
+      const auto& last = tokens[build_token_indices.back()].raw_piece;
+      if (last == " " || last == "_" || last == "▁") {
+        build_token_indices.pop_back();
+      }
+      build_token_indices.push_back(i);
+    } else {
+      if (build_token_indices.empty()) {
+        previous_token_index = i;
+      }
+      build_token_indices.push_back(i);
+    }
+  }
+
+  // Match NeMo behavior: inject first start_offset and append any remaining
+  // built tokens as the final word.
+  if (word_offsets.empty()) {
+    if (!build_token_indices.empty()) {
+      std::vector<TokenId> built_ids;
+      built_ids.reserve(build_token_indices.size());
+      for (const size_t idx : build_token_indices) {
+        built_ids.push_back(tokens[idx].id);
+      }
+      word_offsets.push_back(
+          {tokenizer_utils::decode_token_sequence(built_ids, tokenizer),
+           tokens[0].start_offset,
+           tokens.back().end_offset});
+    }
+  } else {
+    word_offsets[0].start_offset = tokens[0].start_offset;
+
+    if (!build_token_indices.empty()) {
+      std::vector<TokenId> built_ids;
+      built_ids.reserve(build_token_indices.size());
+      for (size_t idx : build_token_indices) {
+        built_ids.push_back(tokens[idx].id);
+      }
+      word_offsets.push_back(
+          {tokenizer_utils::decode_token_sequence(built_ids, tokenizer),
+           tokens[previous_token_index].start_offset,
+           tokens.back().end_offset});
+    }
+  }
+
+  return word_offsets;
+}
+
+std::vector<TextWithOffsets> get_segment_offsets(
+    const std::vector<TextWithOffsets>& word_offsets,
+    const std::vector<std::string>& segment_delimiters,
+    const std::optional<int64_t>& segment_gap_threshold) {
+  std::vector<TextWithOffsets> segment_offsets;
+  if (word_offsets.empty()) {
+    return segment_offsets;
+  }
+
+  std::vector<std::string> segment_words;
+  size_t previous_word_index = 0;
+
+  for (size_t i = 0; i < word_offsets.size(); i++) {
+    const auto& offset = word_offsets[i];
+    const auto& word = offset.text;
+
+    if (segment_gap_threshold.has_value() && !segment_words.empty()) {
+      const int64_t gap_between_words =
+          offset.start_offset - word_offsets[i - 1].end_offset;
+      if (gap_between_words >= segment_gap_threshold.value()) {
+        std::string segment;
+        for (size_t j = 0; j < segment_words.size(); j++) {
+          if (j > 0) {
+            segment += " ";
+          }
+          segment += segment_words[j];
+        }
+        segment_offsets.push_back(
+            {segment,
+             word_offsets[previous_word_index].start_offset,
+             word_offsets[i - 1].end_offset});
+        segment_words = {word};
+        previous_word_index = i;
+        continue;
+      }
+    }
+
+    const bool is_delimiter_word =
+        std::find(segment_delimiters.begin(), segment_delimiters.end(), word) !=
+        segment_delimiters.end();
+
+    const bool ends_with_delimiter = !word.empty() &&
+        std::find(
+            segment_delimiters.begin(),
+            segment_delimiters.end(),
+            std::string(1, word.back())) != segment_delimiters.end();
+
+    if (!word.empty() && (ends_with_delimiter || is_delimiter_word)) {
+      segment_words.push_back(word);
+      std::string segment;
+      for (size_t j = 0; j < segment_words.size(); j++) {
+        if (j > 0) {
+          segment += " ";
+        }
+        segment += segment_words[j];
+      }
+      segment_offsets.push_back(
+          {segment,
+           word_offsets[previous_word_index].start_offset,
+           offset.end_offset});
+      segment_words.clear();
+      previous_word_index = i + 1;
+      continue;
+    }
+
+    segment_words.push_back(word);
+  }
+
+  if (!segment_words.empty()) {
+    std::string segment;
+    for (size_t j = 0; j < segment_words.size(); j++) {
+      if (j > 0) {
+        segment += " ";
+      }
+      segment += segment_words[j];
+    }
+    segment_offsets.push_back(
+        {segment,
+         word_offsets[previous_word_index].start_offset,
+         word_offsets.back().end_offset});
+  }
+
+  return segment_offsets;
+}
+
+} // namespace parakeet::timestamp_utils
diff --git a/examples/models/parakeet/timestamp_utils.h b/examples/models/parakeet/timestamp_utils.h
new file mode 100644
index 00000000000..b4b7913e63a
--- /dev/null
+++ b/examples/models/parakeet/timestamp_utils.h
@@ -0,0 +1,36 @@
+#pragma once
+
+#include "types.h"
+
+#include
+#include
+#include
+#include
+
+#include
+
+namespace parakeet::timestamp_utils {
+
+// throws if any tokenizer calls fail
+std::vector<TokenWithTextInfo> get_tokens_with_text_info(
+    const std::vector<Token>& tokens,
+    const tokenizers::Tokenizer& tokenizer,
+    const std::unordered_set<std::string>& supported_punctuation);
+
+// ref:
+// https://github.com/NVIDIA-NeMo/NeMo/blob/bf583c9/nemo/collections/asr/parts/utils/timestamp_utils.py#L54
+// assumes BPE tokenizer type
+std::vector<TextWithOffsets> get_words_offsets(
+    const std::vector<TokenWithTextInfo>& tokens,
+    const tokenizers::Tokenizer& tokenizer,
+    const std::unordered_set<std::string>& supported_punctuation,
+    const std::string& word_delimiter_char = " ");
+
+// ref
+// https://github.com/NVIDIA-NeMo/NeMo/blob/bf583c9/nemo/collections/asr/parts/utils/timestamp_utils.py#L227
+std::vector<TextWithOffsets> get_segment_offsets(
+    const std::vector<TextWithOffsets>& word_offsets,
+    const std::vector<std::string>& segment_delimiters = {".", "?", "!"},
+    const std::optional<int64_t>& segment_gap_threshold = std::nullopt);
+
+} // namespace parakeet::timestamp_utils
diff --git a/examples/models/parakeet/tokenizer_utils.cpp b/examples/models/parakeet/tokenizer_utils.cpp
new file mode 100644
index 00000000000..8cebebd8b19
--- /dev/null
+++ b/examples/models/parakeet/tokenizer_utils.cpp
@@ -0,0 +1,111 @@
+#include "tokenizer_utils.h"
+
+#include
+
+#include
+#include
+#include
+
+namespace {
+
+bool is_whitespace_only(const std::string& token) {
+  if (token.empty()) {
+    return true;
+  }
+
+  try {
+    const auto codepoints = unicode_cpts_from_utf8(token);
+    for (const auto cp : codepoints) {
+      if (!unicode_cpt_flags(cp).is_whitespace) {
+        return false;
+      }
+    }
+    return true;
+  } catch (const std::exception&) {
+    return false;
+  }
+}
+
+bool is_special_token(const std::string& token) {
+  if (token.size() >= 2 && token.front() == '[' && token.back() == ']') {
+    return true;
+  }
+  if (token.size() >= 2 && token.front() == '<' && token.back() == '>') {
+    return true;
+  }
+  if (token.rfind("##", 0) == 0) {
+    return true;
+  }
+  if (token.rfind(u8"▁", 0) == 0) {
+    return true;
+  }
+  if (is_whitespace_only(token)) {
+    return true;
+  }
+  return false;
+}
+
+} // namespace
+
+namespace parakeet::tokenizer_utils {
+
+std::unordered_set<std::string> derive_supported_punctuation(
+    const tokenizers::Tokenizer& tokenizer) {
+  std::unordered_set<std::string> punctuation;
+
+  const int32_t vocab_size = tokenizer.vocab_size();
+  for (int32_t id = 0; id < vocab_size; id++) {
+    const auto piece_result = tokenizer.id_to_piece(static_cast<TokenId>(id));
+    if (!piece_result.ok()) {
+      continue;
+    }
+    const std::string& piece = piece_result.get();
+    if (is_special_token(piece)) {
+      continue;
+    }
+
+    try {
+      const auto codepoints = unicode_cpts_from_utf8(piece);
+      for (const auto cp : codepoints) {
+        if (unicode_cpt_flags(cp).is_punctuation) {
+          punctuation.insert(unicode_cpt_to_utf8(cp));
+        }
+      }
+    } catch (const std::exception&) {
+      ET_LOG(
+          Error,
+          "Failed to decode token piece '%s' to codepoints",
+          piece.c_str());
+    }
+  }
+
+  return punctuation;
+}
+
+std::string decode_token_sequence(
+    const std::vector<TokenId>& tokens,
+    const tokenizers::Tokenizer& tokenizer) {
+  std::string result;
+  TokenId prev_token = tokenizer.bos_tok();
+  for (const TokenId token : tokens) {
+    auto decode_result = tokenizer.decode(prev_token, token);
+    if (decode_result.ok()) {
+      result += decode_result.get();
+    }
+    prev_token = token;
+  }
+  return result;
+}
+
+std::string decode_token_sequence(
+    const std::vector<Token>& decoded_tokens,
+    const tokenizers::Tokenizer& tokenizer) {
+  std::vector<TokenId> token_ids;
+  token_ids.reserve(decoded_tokens.size());
+  for (const auto& tok : decoded_tokens) {
+    token_ids.push_back(tok.id);
+  }
+  return decode_token_sequence(token_ids, tokenizer);
+}
+
+} // namespace parakeet::tokenizer_utils
diff --git a/examples/models/parakeet/tokenizer_utils.h b/examples/models/parakeet/tokenizer_utils.h
new file mode 100644
index 00000000000..1b1205cc07a
--- /dev/null
+++ b/examples/models/parakeet/tokenizer_utils.h
@@ -0,0 +1,27 @@
+#pragma once
+
+#include "types.h"
+
+#include
+#include
+#include
+
+#include
+
+namespace parakeet::tokenizer_utils {
+
+// Matches NeMo extract_punctuation_from_vocab method
+// https://github.com/NVIDIA-NeMo/NeMo/blob/b90a528/nemo/collections/asr/parts/utils/tokenizer_utils.py#L20
+std::unordered_set<std::string> derive_supported_punctuation(
+    const tokenizers::Tokenizer& tokenizer);
+
+std::string decode_token_sequence(
+    const std::vector<TokenId>& tokens,
+    const tokenizers::Tokenizer& tokenizer);
+
+// convenience overload
+std::string decode_token_sequence(
+    const std::vector<Token>& decoded_tokens,
+    const tokenizers::Tokenizer& tokenizer);
+
+} // namespace parakeet::tokenizer_utils
diff --git a/examples/models/parakeet/types.h b/examples/models/parakeet/types.h
new file mode 100644
index 00000000000..eab1dd974df
--- /dev/null
+++ b/examples/models/parakeet/types.h
@@ -0,0 +1,33 @@
+#pragma once
+
+#include
+#include
+
+namespace parakeet {
+
+// Matches output type of tokenizers::Tokenizer methods
+using TokenId = uint64_t;
+
+struct Token {
+  TokenId id;
+  int64_t start_offset;
+  int64_t duration;
+};
+
+struct TokenWithTextInfo {
+  TokenId id;
+  // Raw vocabulary piece for the token_id (i.e., "##ing", "▁hello")
+  std::string raw_piece;
+  // Decoded text for the token_id (i.e., "ing", " hello")
+  std::string decoded_text;
+  int64_t start_offset;
+  int64_t end_offset;
+};
+
+struct TextWithOffsets {
+  std::string text;
+  int64_t start_offset;
+  int64_t end_offset;
+};
+
+} // namespace parakeet