Skip to content

Commit

Permalink
tflite model plugin running
Browse files Browse the repository at this point in the history
  • Loading branch information
williamzhang0306 committed Aug 13, 2024
1 parent 071052a commit 5782194
Show file tree
Hide file tree
Showing 2 changed files with 56 additions and 14 deletions.
17 changes: 17 additions & 0 deletions onair/config/tflite_test_config.ini
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
[FILES]
TelemetryFilePath = onair/data/raw_telemetry_data/data_physics_generation/Errors
TelemetryFile = 700_crash_to_earth_1.csv
MetaFilePath = onair/data/telemetry_configs/
MetaFile = data_physics_generation_CONFIG.json

[DATA_HANDLING]
DataSourceFile = onair/data_handling/csv_parser.py

[PLUGINS]
KnowledgeRepPluginDict = {'tflite_model_1':'plugins/tflite_model/__init__.py'}
LearnersPluginDict = {'reporter':'plugins/reporter'}
PlannersPluginDict = {}
ComplexPluginDict = {}

[OPTIONS]
IO_Enabled = true
53 changes: 39 additions & 14 deletions plugins/tflite_model/tflite_model_plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,38 +10,63 @@
from onair.src.ai_components.ai_plugin_abstract.ai_plugin import AIPlugin
import numpy as np

# tflite_runtime is the bare minimum required to run tflite models; fall back
# to the full tensorflow package when it is not installed.
# BUGFIX: the original had an `else: raise ModuleNotFoundError(...)` clause.
# A try/except `else` runs when the try body SUCCEEDS, so a working
# tflite_runtime install would incorrectly raise. No explicit re-raise is
# needed: if tensorflow is also missing, its import raises on its own.
try:
    import tflite_runtime.interpreter as tflite
except ModuleNotFoundError:
    import tensorflow.lite as tflite

class Plugin(AIPlugin):
    """OnAIR knowledge-rep plugin that runs a tflite model each frame.

    On every `update()` the plugin feeds the model (currently with random
    tensors, not the streamed telemetry) and stores the raw output tensors,
    which `render_reasoning()` then reports.
    """

    def __init__(self, _name, _headers):
        """Load the tflite model and allocate its tensors.

        Args:
            _name: plugin name, passed through to AIPlugin.
            _headers: telemetry headers, passed through to AIPlugin.
        """
        super().__init__(_name, _headers)

        # Your model here
        # NOTE(review): hard-coded Windows-style relative path — confirm it
        # resolves from the directory OnAIR is launched from, and consider
        # forward slashes / a config entry for portability.
        model_path = r"tflite_models\mobilebert-tflite-default-v1\1.tflite"

        # Load and initialize the model; input/output details are cached so
        # update() can set/read tensors by index.
        self.interpreter = tflite.Interpreter(model_path)
        self.interpreter.allocate_tensors()
        self.input_details = self.interpreter.get_input_details()
        self.output_details = self.interpreter.get_output_details()

        # Populated on each update() call.
        self.input_tensors = None
        self.output_tensors = None

    def update(self, low_level_data=[], high_level_data={}):
        """
        Given streamed data point, system should update internally.

        NOTE(review): the streamed arguments are currently ignored — random
        tensors are generated as model input instead. The mutable defaults
        are harmless only because the arguments are unused.
        """
        self.input_tensors = self.generate_random_input()

        # set inputs - this should handle models with multiple inputs
        for model_input, tensor in zip(self.input_details, self.input_tensors):
            self.interpreter.set_tensor(model_input['index'], tensor)

        # Run one inference pass.
        self.interpreter.invoke()

        # Collect every output tensor (handles multi-output models too).
        outputs = []
        for model_output in self.output_details:
            tensor = self.interpreter.get_tensor(model_output['index'])
            outputs.append(tensor)

        self.output_tensors = tuple(outputs)

    def render_reasoning(self):
        """
        System should return its diagnosis.

        Returns:
            tuple of np.ndarray from the last update(), or None if update()
            has not run yet.
        """
        return self.output_tensors

    def generate_random_input(self) -> tuple:
        """
        Generates random tensors to be used as the input of a tflite model.

        One tensor is produced per entry of input_details, matching that
        input's shape and dtype.

        NOTE(review): np.random.rand yields floats in [0, 1), so casting to
        an integer dtype produces all-zero tensors — fine for a smoke test,
        but confirm that is acceptable for integer-input models.

        Returns:
            tuple of np.ndarray
        """
        tensors = []
        for model_input in self.input_details:
            rand_tensor = np.random.rand(*model_input['shape'])
            rand_tensor = rand_tensor.astype(model_input['dtype'])
            tensors.append(rand_tensor)
        return tuple(tensors)

0 comments on commit 5782194

Please sign in to comment.