1 change: 1 addition & 0 deletions tensorflow/lite/micro/kernels/micro_ops.h
@@ -95,6 +95,7 @@ TFLMRegistration Register_MIRROR_PAD();
TFLMRegistration Register_MUL();
TFLMRegistration Register_NEG();
TFLMRegistration Register_NOT_EQUAL();
const TFLMRegistration* Register_ONE_HOT();
TFLMRegistration Register_PACK();
TFLMRegistration Register_PAD();
TFLMRegistration Register_PADV2();
243 changes: 243 additions & 0 deletions tensorflow/lite/micro/kernels/one_hot.cc
@@ -0,0 +1,243 @@
/* Copyright 2025 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <stdint.h>

#include "tensorflow/lite/c/builtin_op_data.h"
#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/micro/kernels/kernel_util.h"
#include "tensorflow/lite/micro/micro_common.h"

namespace tflite {
namespace {

constexpr int kIndicesTensor = 0;
constexpr int kDepthTensor = 1;
constexpr int kOnValueTensor = 2;
constexpr int kOffValueTensor = 3;
constexpr int kOutputTensor = 0;

// Local utility: returns the flattened element count of an eval tensor.
inline int NumElements(const TfLiteEvalTensor* t) {
int count = 1;
for (int i = 0; i < t->dims->size; ++i) {
count *= t->dims->data[i];
}
return count;
}

// Gathers the input tensors (indices, depth, on_value, off_value) and the
// output tensor from the TfLiteNode, and resolves params->axis into the
// position at which the depth dimension is inserted: axis == -1 means the
// last position, so indices of rank N produce an output of rank N + 1.
// The context is built on the stack inside Prepare and Eval and discarded
// afterward, so it consumes no persistent arena memory.
struct OneHotContext {
OneHotContext(TfLiteContext* context, TfLiteNode* node) {
indices = tflite::micro::GetEvalInput(context, node, kIndicesTensor);
depth = tflite::micro::GetEvalInput(context, node, kDepthTensor);
on_value = tflite::micro::GetEvalInput(context, node, kOnValueTensor);
off_value = tflite::micro::GetEvalInput(context, node, kOffValueTensor);
output = tflite::micro::GetEvalOutput(context, node, kOutputTensor);

const auto* params =
reinterpret_cast<TfLiteOneHotParams*>(node->builtin_data);
const int indices_dims = indices->dims->size;
axis = (params->axis == -1) ? indices_dims : params->axis;
output_dims = indices_dims + 1;
dtype = on_value->type;
}

const TfLiteEvalTensor* indices;
const TfLiteEvalTensor* depth;
const TfLiteEvalTensor* on_value;
const TfLiteEvalTensor* off_value;
TfLiteEvalTensor* output;

int axis;
int output_dims;
TfLiteType dtype;
};

// Core one-hot expansion. The output is written as a flattened
// [prefix, depth, suffix] volume around the insertion axis: element
// (i, j, k) is on_value exactly when indices[i * suffix_dim_size + k] == j.
template <typename T, typename TI>
void OneHotComputeImpl(const OneHotContext& op_context) {
int prefix_dim_size = 1;
for (int i = 0; i < op_context.axis; ++i) {
prefix_dim_size *= op_context.indices->dims->data[i];
}
if (prefix_dim_size == 0) {
return;
}

const RuntimeShape indices_shape =
tflite::micro::GetTensorShape(op_context.indices);
const int suffix_dim_size = indices_shape.FlatSize() / prefix_dim_size;

const int32_t* depth_ptr =
tflite::micro::GetTensorData<int32_t>(op_context.depth);
if (depth_ptr == nullptr) return;
const int depth = *depth_ptr;

const T on_value = *tflite::micro::GetTensorData<T>(op_context.on_value);
const T off_value = *tflite::micro::GetTensorData<T>(op_context.off_value);

T* output_data = tflite::micro::GetTensorData<T>(op_context.output);
const TI* indices_data = tflite::micro::GetTensorData<TI>(op_context.indices);
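
// Example: indices [0, 1, 2] with depth 3 and a trailing axis give
// prefix_dim_size = 3 and suffix_dim_size = 1, so the loops below emit a
// 3x3 identity matrix row by row.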

for (int i = 0; i < prefix_dim_size; ++i) {
for (int j = 0; j < depth; ++j) {
for (int k = 0; k < suffix_dim_size; ++k, ++output_data) {
*output_data =
static_cast<int>(indices_data[i * suffix_dim_size + k]) == j
? on_value
: off_value;
}
}
}
}

template <typename T>
void OneHotCompute(const OneHotContext& op_context) {
if (op_context.indices->type == kTfLiteInt64) {
OneHotComputeImpl<T, int64_t>(op_context);
} else {
OneHotComputeImpl<T, int32_t>(op_context);
}
}

TfLiteStatus EnsureOutputDimsMatchExpected(TfLiteContext* context,
const OneHotContext& op_context) {
// Read the scalar depth value.
const int32_t* depth_ptr =
tflite::micro::GetTensorData<int32_t>(op_context.depth);
TF_LITE_ENSURE(context, depth_ptr != nullptr);

const int depth_val = *depth_ptr;
TF_LITE_ENSURE(context, depth_val >= 0);

// Validate the output tensor.
TF_LITE_ENSURE(context, op_context.output != nullptr);

TF_LITE_ENSURE(context, op_context.output->dims != nullptr);

// TFLM pre-allocates the output tensor's dims, so we validate rather than
// resize.
const int expected_dims_size = op_context.output_dims;
TF_LITE_ENSURE_EQ(context, op_context.output->dims->size, expected_dims_size);

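// Expected shape: the indices dims with depth_val inserted at `axis`, e.g.
// indices of shape [2, 3] with axis == 1 and depth 4 must have a
// pre-allocated output of shape [2, 4, 3].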
for (int i = 0; i < expected_dims_size; ++i) {
int expected_dim_i;
if (i < op_context.axis) {
expected_dim_i = op_context.indices->dims->data[i];
} else if (i == op_context.axis) {
expected_dim_i = depth_val;
} else {
expected_dim_i = op_context.indices->dims->data[i - 1];
}

// Fail if the shape pre-allocated by the TFLM memory planner does not match
// the shape computed here.
TF_LITE_ENSURE_EQ(context, op_context.output->dims->data[i],
expected_dim_i);
}

return kTfLiteOk;
}

void* OneHotInit(TfLiteContext* context, const char* buffer, size_t length) {
(void)context;
(void)buffer;
(void)length;
// This kernel does not require persistent op data.
return nullptr;
}

TfLiteStatus OneHotPrepare(TfLiteContext* context, TfLiteNode* node) {
TF_LITE_ENSURE_EQ(context, node->inputs->size, 4);
TF_LITE_ENSURE_EQ(context, node->outputs->size, 1);

OneHotContext op_context{context, node};
TF_LITE_ENSURE(context, op_context.output != nullptr);

switch (op_context.dtype) {
case kTfLiteFloat32:
case kTfLiteInt16:
case kTfLiteInt32:
case kTfLiteInt64:
case kTfLiteInt8:
case kTfLiteUInt8:
case kTfLiteBool:
op_context.output->type = op_context.dtype;
break;
default:
TF_LITE_KERNEL_LOG(context, "Unknown output data type: %s",
TfLiteTypeGetName(op_context.dtype));
return kTfLiteError;
}

TF_LITE_ENSURE(context, op_context.indices->type == kTfLiteInt32 ||
op_context.indices->type == kTfLiteInt64);
TF_LITE_ENSURE(context, op_context.axis >= 0 &&
op_context.axis < op_context.output_dims);
TF_LITE_ENSURE_EQ(context, NumElements(op_context.depth), 1);
TF_LITE_ENSURE_EQ(context, NumElements(op_context.on_value), 1);
TF_LITE_ENSURE_EQ(context, NumElements(op_context.off_value), 1);
TF_LITE_ENSURE_TYPES_EQ(context, op_context.on_value->type, op_context.dtype);
TF_LITE_ENSURE_TYPES_EQ(context, op_context.off_value->type,
op_context.dtype);

// The output shape is fixed when the model is prepared, so even if the depth
// tensor is not a constant we only validate the pre-allocated shape here.
return EnsureOutputDimsMatchExpected(context, op_context);
}

TfLiteStatus OneHotEval(TfLiteContext* context, TfLiteNode* node) {
OneHotContext op_context{context, node};

switch (op_context.output->type) {
case kTfLiteFloat32:
OneHotCompute<float>(op_context);
break;
case kTfLiteInt16:
OneHotCompute<int16_t>(op_context);
break;
case kTfLiteInt32:
OneHotCompute<int32_t>(op_context);
break;
case kTfLiteInt64:
OneHotCompute<int64_t>(op_context);
break;
case kTfLiteInt8:
OneHotCompute<int8_t>(op_context);
break;
case kTfLiteUInt8:
OneHotCompute<uint8_t>(op_context);
break;
case kTfLiteBool:
OneHotCompute<bool>(op_context);
break;
default:
return kTfLiteError;
}

return kTfLiteOk;
}

} // namespace

const TFLMRegistration* Register_ONE_HOT() {
static TFLMRegistration r =
tflite::micro::RegisterOp(OneHotInit, OneHotPrepare, OneHotEval);
return &r;
}

} // namespace tflite
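
Note: the PR declares and registers the kernel but does not add an AddOneHot() helper to MicroMutableOpResolver, so an application must wire the op up itself. A minimal sketch, assuming the resolver's AddCustom hook accepts this registration pointer (an assumption about your TFLM version, not something this PR adds):

#include "tensorflow/lite/micro/micro_mutable_op_resolver.h"

// Hypothetical wiring, not part of this PR: ONE_HOT is registered as a
// custom op because no dedicated resolver helper exists. The AddCustom
// signature may differ across TFLM versions.
TfLiteStatus RegisterAppOps(tflite::MicroMutableOpResolver<1>& op_resolver) {
  return op_resolver.AddCustom("ONE_HOT", tflite::Register_ONE_HOT());
}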
95 changes: 95 additions & 0 deletions tensorflow/lite/micro/one_hot_test.cc
@@ -0,0 +1,95 @@
#include "tensorflow/lite/c/builtin_op_data.h"
#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/micro/kernels/kernel_runner.h"
#include "tensorflow/lite/micro/kernels/micro_ops.h"
#include "tensorflow/lite/micro/test_helpers.h"
#include "tensorflow/lite/micro/testing/micro_test.h"

namespace tflite {
namespace testing {
namespace {

// Helper that builds the tensors, runs the ONE_HOT kernel, and verifies the
// output.
template <typename T>
void TestOneHot(const int* indices_dims, const int32_t* indices_data,
const int* depth_dims, const int32_t* depth_data,
const int* on_dims, const T* on_data, const int* off_dims,
const T* off_data, const int* output_dims,
const T* expected_output_data, T* output_data, int axis = -1) {
// 1. Build the tensor dimension arrays.
TfLiteIntArray* in_dims = IntArrayFromInts(indices_dims);
TfLiteIntArray* d_dims = IntArrayFromInts(depth_dims);
TfLiteIntArray* on_val_dims = IntArrayFromInts(on_dims);
TfLiteIntArray* off_val_dims = IntArrayFromInts(off_dims);
TfLiteIntArray* out_dims = IntArrayFromInts(output_dims);

const int output_dims_count = ElementCount(*out_dims);

// 2. Create the input and output tensors.
constexpr int inputs_size = 4;
constexpr int outputs_size = 1;
constexpr int tensors_size = inputs_size + outputs_size;
TfLiteTensor tensors[tensors_size] = {
CreateTensor(indices_data, in_dims), CreateTensor(depth_data, d_dims),
CreateTensor(on_data, on_val_dims), CreateTensor(off_data, off_val_dims),
CreateTensor(output_data, out_dims), // Output Tensor
};

// 3. Set the builtin parameters (axis).
TfLiteOneHotParams builtin_data = {axis};

// 4. Wire up the tensor indices and run the kernel with KernelRunner.
int inputs_array_data[] = {4, 0, 1, 2, 3}; // indices, depth, on, off
TfLiteIntArray* inputs_array = IntArrayFromInts(inputs_array_data);
int outputs_array_data[] = {1, 4}; // output tensor index
TfLiteIntArray* outputs_array = IntArrayFromInts(outputs_array_data);

const TFLMRegistration* registration = tflite::Register_ONE_HOT();
micro::KernelRunner runner(*registration, tensors, tensors_size, inputs_array,
outputs_array,
reinterpret_cast<void*>(&builtin_data));

TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, runner.InitAndPrepare());
TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, runner.Invoke());

// 5. Verify each output element.
for (int i = 0; i < output_dims_count; ++i) {
TF_LITE_MICRO_EXPECT_EQ(expected_output_data[i], output_data[i]);
}
}

} // namespace
} // namespace testing
} // namespace tflite

// UNIT TEST
TF_LITE_MICRO_TESTS_BEGIN

TF_LITE_MICRO_TEST(OneHot_BasicInt32) {
// Indices: [0, 1, 2]
const int indices_dims[] = {1, 3};
const int32_t indices_data[] = {0, 1, 2};

// Depth: 3
const int depth_dims[] = {1, 1};
const int32_t depth_data[] = {3};

// On: 1, Off: 0
const int on_dims[] = {1, 1};
const int32_t on_data[] = {1};
const int off_dims[] = {1, 1};
const int32_t off_data[] = {0};

// Output: [3, 3] -> Identity Matrix
const int output_dims[] = {2, 3, 3};
const int32_t expected_output[] = {1, 0, 0, 0, 1, 0, 0, 0, 1};

int32_t output_data[9];

tflite::testing::TestOneHot(indices_dims, indices_data, depth_dims,
depth_data, on_dims, on_data, off_dims, off_data,
output_dims, expected_output, output_data);
}
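
// Not part of this PR: a sketch of a second case exercising an explicit
// axis. With axis = 0 the depth dimension is inserted first, so indices of
// shape [2] produce a [3, 2] output (expected values traced through the
// kernel's prefix/depth/suffix loops).
TF_LITE_MICRO_TEST(OneHot_Axis0Int32) {
  // Indices: [0, 2]
  const int indices_dims[] = {1, 2};
  const int32_t indices_data[] = {0, 2};

  // Depth: 3
  const int depth_dims[] = {1, 1};
  const int32_t depth_data[] = {3};

  // On: 1, Off: 0
  const int on_dims[] = {1, 1};
  const int32_t on_data[] = {1};
  const int off_dims[] = {1, 1};
  const int32_t off_data[] = {0};

  // Output: [3, 2]
  const int output_dims[] = {2, 3, 2};
  const int32_t expected_output[] = {1, 0, 0, 0, 0, 1};

  int32_t output_data[6];

  tflite::testing::TestOneHot(indices_dims, indices_data, depth_dims,
                              depth_data, on_dims, on_data, off_dims, off_data,
                              output_dims, expected_output, output_data,
                              /*axis=*/0);
}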

TF_LITE_MICRO_TESTS_END
6 changes: 4 additions & 2 deletions tensorflow/lite/micro/tools/make/Makefile
@@ -361,7 +361,8 @@ $(TENSORFLOW_ROOT)tensorflow/lite/micro/arena_allocator/single_arena_buffer_allo
$(TENSORFLOW_ROOT)tensorflow/lite/micro/testing_helpers_test.cc \
$(TENSORFLOW_ROOT)tensorflow/lite/micro/memory_planner/greedy_memory_planner_test.cc \
$(TENSORFLOW_ROOT)tensorflow/lite/micro/memory_planner/linear_memory_planner_test.cc \
$(TENSORFLOW_ROOT)tensorflow/lite/micro/memory_planner/non_persistent_buffer_planner_shim_test.cc
$(TENSORFLOW_ROOT)tensorflow/lite/micro/memory_planner/non_persistent_buffer_planner_shim_test.cc \
$(TENSORFLOW_ROOT)tensorflow/lite/micro/one_hot_test.cc
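
(With the test registered in this list, it should be runnable through TFLM's usual per-test target — presumably make -f tensorflow/lite/micro/tools/make/Makefile test_one_hot_test; the target name is inferred from the test file name and is not stated in this PR.)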

MICROLITE_CC_KERNEL_SRCS := \
$(TENSORFLOW_ROOT)tensorflow/lite/micro/kernels/activations.cc \
@@ -437,6 +438,7 @@ $(TENSORFLOW_ROOT)tensorflow/lite/micro/kernels/mirror_pad.cc \
$(TENSORFLOW_ROOT)tensorflow/lite/micro/kernels/mul.cc \
$(TENSORFLOW_ROOT)tensorflow/lite/micro/kernels/mul_common.cc \
$(TENSORFLOW_ROOT)tensorflow/lite/micro/kernels/neg.cc \
$(TENSORFLOW_ROOT)tensorflow/lite/micro/kernels/one_hot.cc \
$(TENSORFLOW_ROOT)tensorflow/lite/micro/kernels/pack.cc \
$(TENSORFLOW_ROOT)tensorflow/lite/micro/kernels/pad.cc \
$(TENSORFLOW_ROOT)tensorflow/lite/micro/kernels/pad_common.cc \
@@ -480,7 +482,7 @@ $(TENSORFLOW_ROOT)tensorflow/lite/micro/kernels/unidirectional_sequence_lstm.cc
$(TENSORFLOW_ROOT)tensorflow/lite/micro/kernels/unpack.cc \
$(TENSORFLOW_ROOT)tensorflow/lite/micro/kernels/var_handle.cc \
$(TENSORFLOW_ROOT)tensorflow/lite/micro/kernels/while.cc \
$(TENSORFLOW_ROOT)tensorflow/lite/micro/kernels/zeros_like.cc
$(TENSORFLOW_ROOT)tensorflow/lite/micro/kernels/zeros_like.cc \

MICROLITE_CC_SIGNAL_KERNEL_SRCS := \
$(TENSORFLOW_ROOT)signal/micro/kernels/delay.cc \