2 changes: 1 addition & 1 deletion examples/models/parakeet/CMakeLists.txt
@@ -80,7 +80,7 @@ if(EXECUTORCH_BUILD_METAL)
executorch_target_link_options_shared_lib(metal_backend)
endif()

add_executable(parakeet_runner main.cpp)
add_executable(parakeet_runner main.cpp timestamp_utils.cpp tokenizer_utils.cpp)
if(NOT CMAKE_BUILD_TYPE STREQUAL "Debug")
target_link_options_gc_sections(parakeet_runner)
if(NOT APPLE AND NOT MSVC)
1 change: 1 addition & 0 deletions examples/models/parakeet/README.md
@@ -71,3 +71,4 @@ From the executorch root directory:
| `--audio_path` | Path to input audio file (.wav) |
| `--tokenizer_path` | Path to tokenizer file (default: `tokenizer.json`) |
| `--data_path` | Path to data file (.ptd) for delegate data (optional, required for CUDA) |
| `--timestamps` | Timestamp output mode: `none\|token\|word\|segment\|all` (default: `segment`) |
5 changes: 5 additions & 0 deletions examples/models/parakeet/export_parakeet_tdt.py
@@ -351,13 +351,18 @@ def export_all(model):
)

sample_rate = model.preprocessor._cfg.sample_rate
window_stride = float(model.preprocessor._cfg.window_stride)
encoder_subsampling_factor = int(getattr(model.encoder, "subsampling_factor", 8))
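# The fallback of 8 is an assumption: it matches the 8x subsampling of
# Parakeet's FastConformer encoder, and is only used when the attribute is absent.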

metadata = {
"num_rnn_layers": num_layers,
"pred_hidden": pred_hidden,
"joint_hidden": joint_hidden,
"vocab_size": model.tokenizer.vocab_size,
"blank_id": model.tokenizer.vocab_size,
"sample_rate": sample_rate,
"window_stride": window_stride,
"encoder_subsampling_factor": encoder_subsampling_factor,
}
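# RNNT/TDT convention: the blank symbol sits after the real vocabulary, so
# blank_id == vocab_size and the joint's token head scores vocab_size + 1 classes.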

return programs, metadata
183 changes: 151 additions & 32 deletions examples/models/parakeet/main.cpp
@@ -6,18 +6,28 @@
* LICENSE file in the root directory of this source tree.
*/

#include <algorithm>
#include <cctype>
#include <cmath>
#include <cstdint>
#include <cstring>
#include <exception>
#include <fstream>
#include <iostream>
#include <memory>
#include <optional>
#include <stdexcept>
#include <string>
#include <unordered_set>
#include <vector>

#include <gflags/gflags.h>

#include "timestamp_utils.h"
#include "tokenizer_utils.h"
#include "types.h"

#include <executorch/extension/llm/runner/llm_runner_helper.h>
#include <executorch/extension/llm/runner/wav_loader.h>
#include <executorch/extension/llm/tokenizers/third-party/llama.cpp-unicode/include/unicode.h>
#include <executorch/extension/module/module.h>
#include <executorch/extension/tensor/tensor_ptr_maker.h>
#include <executorch/runtime/core/evalue.h>
@@ -27,24 +37,75 @@ DEFINE_string(model_path, "parakeet.pte", "Path to Parakeet model (.pte).");
DEFINE_string(audio_path, "", "Path to input audio file (.wav).");
DEFINE_string(
tokenizer_path,
"tokenizer.json",
"tokenizer.model",
"Path to SentencePiece tokenizer model file.");
DEFINE_string(
data_path,
"",
"Path to data file (.ptd) for delegate data (optional, required for CUDA).");
DEFINE_string(
timestamps,
"segment",
"Timestamp output mode: none|token|word|segment|all");

using ::executorch::extension::from_blob;
using ::executorch::extension::Module;
using ::executorch::runtime::Error;
using ::executorch::runtime::EValue;

namespace {
using ::parakeet::TextWithOffsets;
using ::parakeet::Token;
using ::parakeet::TokenId;
using ::parakeet::TokenWithTextInfo;

namespace {
// TDT duration values
const std::vector<int> DURATIONS = {0, 1, 2, 3, 4};
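The decode loop below (large parts of it are elided by the diff) picks both a token and a duration from a single joint-network output. A minimal sketch of that split, assuming the joint emits one flat logit vector of `num_token_classes + DURATIONS.size()` floats; the helper name and signature are illustrative, not part of this PR:

```cpp
#include <algorithm>
#include <cstdint>
#include <iterator>
#include <utility>
#include <vector>

// Hypothetical helper: argmax over the token head (the first
// num_token_classes logits, blank included) and over the duration head (the
// remaining logits). The duration index is then mapped through DURATIONS to
// a frame advance, as in the loop below.
std::pair<int64_t, int64_t> argmax_token_and_duration(
    const std::vector<float>& joint_logits,
    int64_t num_token_classes) {
  const auto token_begin = joint_logits.begin();
  const auto token_end = token_begin + num_token_classes;
  const int64_t k =
      std::distance(token_begin, std::max_element(token_begin, token_end));
  const int64_t dur_idx = std::distance(
      token_end, std::max_element(token_end, joint_logits.end()));
  return {k, dur_idx};
}
```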

std::vector<int64_t> greedy_decode_executorch(
struct TimestampOutputMode {
bool token = false;
bool word = false;
bool segment = false;

bool enabled() const {
return token || word || segment;
}
};

std::string to_lower_ascii(std::string s) {
for (char& ch : s) {
ch = static_cast<char>(std::tolower(static_cast<unsigned char>(ch)));
}
return s;
}

TimestampOutputMode parse_timestamp_output_mode(const std::string& raw_arg) {
if (raw_arg.empty()) {
throw std::invalid_argument(
"Invalid --timestamps value (empty). Expected: token, word, segment, all.");
}
const std::string mode = to_lower_ascii(raw_arg);
if (mode == "none") {
return {false, false, false};
}
if (mode == "token") {
return {true, false, false};
}
if (mode == "word") {
return {false, true, false};
}
if (mode == "segment") {
return {false, false, true};
}
if (mode == "all") {
return {true, true, true};
}
throw std::invalid_argument(
"Invalid --timestamps value '" + raw_arg +
"'. Expected: token, word, segment, all.");
}

std::vector<Token> greedy_decode_executorch(
Module& model,
const ::executorch::aten::Tensor& encoder_output,
int64_t encoder_len,
@@ -53,7 +114,7 @@ std::vector<int64_t> greedy_decode_executorch(
int64_t num_rnn_layers = 2,
int64_t pred_hidden = 640,
int64_t max_symbols_per_step = 10) {
std::vector<int64_t> hypothesis;
std::vector<Token> hypothesis;
int64_t num_token_classes = vocab_size + 1;

// Transpose encoder output from [1, enc_dim, time] to [1, time, enc_dim]
@@ -205,10 +266,10 @@ std::vector<int64_t> greedy_decode_executorch(
int64_t dur = DURATIONS[dur_idx];

if (k == blank_id) {
t += std::max(dur, (int64_t)1);
t += std::max(dur, static_cast<int64_t>(1));
symbols_on_frame = 0;
} else {
hypothesis.push_back(k);
hypothesis.push_back({static_cast<TokenId>(k), t, dur});

// Update decoder state
std::vector<int64_t> token_data = {k};
@@ -268,29 +329,19 @@ std::vector<int64_t> greedy_decode_executorch(
return hypothesis;
}
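The `hypothesis.push_back({static_cast<TokenId>(k), t, dur})` call above brace-initializes a `Token` from `types.h`, a file this diff does not show. An inferred sketch of its shape (an assumption, not the actual definition):

```cpp
#include <cstdint>

using TokenId = int64_t;  // assumed width; types.h may define this differently

// Inferred from the brace-init {token_id, frame, duration} in the decode loop.
struct Token {
  TokenId id;        // decoded token id
  int64_t t;         // encoder frame at which the token was emitted
  int64_t duration;  // TDT-predicted duration, in encoder frames
};
```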

std::string tokens_to_text(
const std::vector<int64_t>& tokens,
tokenizers::Tokenizer* tokenizer) {
// Decode tokens to text one by one
std::string result;
uint64_t prev_token = 0;
for (size_t i = 0; i < tokens.size(); i++) {
uint64_t token = static_cast<uint64_t>(tokens[i]);
auto decode_result = tokenizer->decode(prev_token, token);
if (decode_result.ok()) {
result += decode_result.get();
}
prev_token = token;
}

return result;
}

} // namespace
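The removed `tokens_to_text` above is superseded by `parakeet::tokenizer_utils::decode_token_sequence` from `tokenizer_utils.cpp`, which is outside this diff. A plausible sketch, assuming it mirrors the removed helper's pairwise `decode(prev, cur)` pattern while iterating over the new `Token` records (the real implementation may differ):

```cpp
// Sketch only; assumes the Token type and the tokenizers::Tokenizer API
// already visible in this file.
std::string decode_token_sequence_sketch(
    const std::vector<Token>& tokens,
    tokenizers::Tokenizer& tokenizer) {
  std::string result;
  uint64_t prev = 0;
  for (const Token& token : tokens) {
    const uint64_t id = static_cast<uint64_t>(token.id);
    auto piece = tokenizer.decode(prev, id);
    if (piece.ok()) {
      result += piece.get();
    }
    prev = id;
  }
  return result;
}
```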

int main(int argc, char** argv) {
gflags::ParseCommandLineFlags(&argc, &argv, true);

TimestampOutputMode timestamp_mode;
try {
timestamp_mode = parse_timestamp_output_mode(FLAGS_timestamps);
} catch (const std::invalid_argument& e) {
ET_LOG(Error, "%s", e.what());
return 1;
}

if (FLAGS_audio_path.empty()) {
ET_LOG(Error, "audio_path flag must be provided.");
return 1;
@@ -381,10 +432,14 @@ int main(int argc, char** argv) {
auto vocab_size_result = model->execute("vocab_size", empty_inputs);
auto blank_id_result = model->execute("blank_id", empty_inputs);
auto sample_rate_result = model->execute("sample_rate", empty_inputs);
auto window_stride_result = model->execute("window_stride", empty_inputs);
auto encoder_subsampling_factor_result =
model->execute("encoder_subsampling_factor", empty_inputs);

if (!num_rnn_layers_result.ok() || !pred_hidden_result.ok() ||
!vocab_size_result.ok() || !blank_id_result.ok() ||
!sample_rate_result.ok()) {
!sample_rate_result.ok() || !window_stride_result.ok() ||
[Inline review thread on this check]
Contributor (author): Note that this will break compat with previously exported parakeet models. I chose to do this because we're early in development, and to avoid having to maintain a separate path that allows everything but timestamps when the new metadata isn't present. Open to doing such a thing if reviewers feel strongly.
Contributor (reviewer): Yeah, I don't mind BC-breaking at this early stage. And it's in examples anyway, not in a core directory (like /extensions).
!encoder_subsampling_factor_result.ok()) {
ET_LOG(
Error,
"Failed to query model metadata. Make sure the model was exported with constant_methods.");
@@ -396,18 +451,23 @@ int main(int argc, char** argv) {
int64_t num_rnn_layers = num_rnn_layers_result.get()[0].toInt();
int64_t pred_hidden = pred_hidden_result.get()[0].toInt();
int64_t sample_rate = sample_rate_result.get()[0].toInt();
double window_stride = window_stride_result.get()[0].toDouble();
int64_t encoder_subsampling_factor =
encoder_subsampling_factor_result.get()[0].toInt();

ET_LOG(
Info,
"Model metadata: vocab_size=%lld, blank_id=%lld, num_rnn_layers=%lld, pred_hidden=%lld, sample_rate=%lld",
"Model metadata: vocab_size=%lld, blank_id=%lld, num_rnn_layers=%lld, pred_hidden=%lld, sample_rate=%lld, window_stride=%.6f, encoder_subsampling_factor=%lld",
static_cast<long long>(vocab_size),
static_cast<long long>(blank_id),
static_cast<long long>(num_rnn_layers),
static_cast<long long>(pred_hidden),
static_cast<long long>(sample_rate));
static_cast<long long>(sample_rate),
window_stride,
encoder_subsampling_factor);

ET_LOG(Info, "Running TDT greedy decode...");
auto tokens = greedy_decode_executorch(
auto decoded_tokens = greedy_decode_executorch(
*model,
encoded,
encoded_len,
@@ -416,7 +476,7 @@ int main(int argc, char** argv) {
num_rnn_layers,
pred_hidden);

ET_LOG(Info, "Decoded %zu tokens", tokens.size());
ET_LOG(Info, "Decoded %zu tokens", decoded_tokens.size());

// Load tokenizer
ET_LOG(Info, "Loading tokenizer from: %s", FLAGS_tokenizer_path.c_str());
@@ -431,9 +491,68 @@ int main(int argc, char** argv) {
}

// Convert tokens to text
std::string text = tokens_to_text(tokens, tokenizer.get());
std::cout << "Transcription tokens: " << text << std::endl;
std::string text = parakeet::tokenizer_utils::decode_token_sequence(
decoded_tokens, *tokenizer);
std::cout << "Transcribed text: " << text << std::endl;

if (!timestamp_mode.enabled()) {
return 0;
}

ET_LOG(Info, "Computing timestamps...");
std::unordered_set<std::string> supported_punctuation =
parakeet::tokenizer_utils::derive_supported_punctuation(*tokenizer);
ET_LOG(
Info,
"Derived supported_punctuation size=%zu",
supported_punctuation.size());

// For simplicity, compute all levels of timestamps regardless of mode.
std::vector<TokenWithTextInfo> tokens_with_text_info;
try {
tokens_with_text_info =
parakeet::timestamp_utils::get_tokens_with_text_info(
decoded_tokens, *tokenizer, supported_punctuation);
} catch (const std::exception& e) {
ET_LOG(Error, "Failed to get tokens with text info: %s", e.what());
return 1;
}
const auto word_offsets = parakeet::timestamp_utils::get_words_offsets(
tokens_with_text_info, *tokenizer, supported_punctuation);
const auto segment_offsets =
parakeet::timestamp_utils::get_segment_offsets(word_offsets);

const double frame_to_seconds =
window_stride * static_cast<double>(encoder_subsampling_factor);
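// Worked example (typical values, not guaranteed): window_stride = 0.01 s and
// a subsampling factor of 8 give 0.08 s per encoder frame, so an offset of 25
// frames maps to 2.0 s.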

if (timestamp_mode.segment) {
std::cout << "\nSegment timestamps:" << std::endl;
for (const auto& segment : segment_offsets) {
const double start = segment.start_offset * frame_to_seconds;
const double end = segment.end_offset * frame_to_seconds;
std::cout << start << "s - " << end << "s : " << segment.text
<< std::endl;
}
}

if (timestamp_mode.word) {
std::cout << "\nWord timestamps:" << std::endl;
for (const auto& word : word_offsets) {
const double start = word.start_offset * frame_to_seconds;
const double end = word.end_offset * frame_to_seconds;
std::cout << start << "s - " << end << "s : " << word.text << std::endl;
}
}

if (timestamp_mode.token) {
std::cout << "\nToken timestamps:" << std::endl;
for (const auto& token : tokens_with_text_info) {
const double start = token.start_offset * frame_to_seconds;
const double end = token.end_offset * frame_to_seconds;
std::cout << start << "s - " << end << "s : " << token.decoded_text
<< std::endl;
}
}

ET_LOG(Info, "Done!");
return 0;
}