Skip to content

Commit

Permalink
Merge pull request #52 from SupremeLobster/main
Browse files Browse the repository at this point in the history
Release Quantized versions of the models
  • Loading branch information
TheSeriousProgrammer authored Sep 3, 2024
2 parents 08b0797 + a9c4ec2 commit ae5aa0a
Show file tree
Hide file tree
Showing 2 changed files with 344,279 additions and 7 deletions.
30 changes: 23 additions & 7 deletions eff_word_net/audio_processing.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,8 @@
LIB_FOLDER_LOCATION = os.path.dirname(os.path.realpath(__file__))

class ModelRawBackend :
def __init__(self) :
def __init__(self, use_quantized_model=False):
self.use_quantized_model = use_quantized_model
self.window_length = None
self.window_frames = None
pass
Expand Down Expand Up @@ -64,6 +65,12 @@ def audioToVector(self, inpAudio:np.array) -> np.array :

class First_Iteration_Siamese(ModelRawBackend) :
def __init__(self) :
super().__init__()

if self.use_quantized_model:
# TODO: Implement a quantized model
print("Quantized model not implemented yet, using default model instead")

self.window_length = 1.0 # 1 second
self.window_frames = int(self.window_length * 16000)
self.logmelcalc_interpreter = tflite.Interpreter(
Expand Down Expand Up @@ -144,16 +151,25 @@ def audioToVector(self, inpAudio:np.array) -> np.array :

class Resnet50_Arc_loss(ModelRawBackend):
def __init__(self):
super().__init__()

self.window_length = 1.5
self.window_frames = int(self.window_length * 16000)

self.onnx_sess = rt.InferenceSession(
os.path.join(
LIB_FOLDER_LOCATION, "models/resnet_50_arc/slim_93%_accuracy_72.7390%.onnx"),
sess_options= rt.SessionOptions(),
providers=["CPUExecutionProvider"]
)
if self.use_quantized_model:
self.onnx_sess = rt.InferenceSession(
os.path.join(
LIB_FOLDER_LOCATION, "models/resnet_50_arc/slim_93%_accuracy_72.7390%_qint8.onnx"),
sess_options= rt.SessionOptions(),
providers=["CPUExecutionProvider"]
)
else:
self.onnx_sess = rt.InferenceSession(
os.path.join(
LIB_FOLDER_LOCATION, "models/resnet_50_arc/slim_93%_accuracy_72.7390%.onnx"),
sess_options= rt.SessionOptions(),
providers=["CPUExecutionProvider"]
)

self.input_name:str = self.onnx_sess.get_inputs()[0].name
self.output_name:str = self.onnx_sess.get_outputs()[0].name
Expand Down
Loading

0 comments on commit ae5aa0a

Please sign in to comment.