-
-
Notifications
You must be signed in to change notification settings - Fork 0
/
wave_plotter.py
62 lines (45 loc) · 1.67 KB
/
wave_plotter.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
"""Plots sound waves."""
import enum
import os
import wave
from typing import Text

from matplotlib import pyplot
import numpy as np
class WaveGraphType(enum.Enum):
    """Selects what the x axis of a sound-wave plot represents.

    Each member's value is the short string tag naming that graph type.
    """

    # Plot samples against their frame index.
    PER_FRAME = 'frame'
    # Plot samples against a time axis (seconds).
    PER_SECOND = 'seconds'
def _get_image_file_name(sound_wave_file_name: Text) -> Text:
    """Gets a PNG file name for given sound wave file name.

    Only the final extension of the path is replaced with '.png' (matching
    e.g. '.wav' or '.WAV' alike). A name without an extension simply gets
    '.png' appended. Dots elsewhere in the path are left untouched.

    Args:
      sound_wave_file_name: Sound wave file name.

    Returns:
      Image file name.
    """
    # A plain str.replace('.wav', '.png') would also rewrite '.wav' inside
    # directory or stem names. It would also return the input unchanged for
    # '.WAV' or extension-less names, so the image name would collide with
    # the audio file's name.
    root, _ = os.path.splitext(sound_wave_file_name)
    return root + '.png'
# NumPy dtype for each supported WAV sample width (bytes per sample).
# PCM WAV data is little-endian; 8-bit samples are unsigned, wider ones signed.
_SAMPLE_DTYPES = {
    1: np.dtype('u1'),
    2: np.dtype('<i2'),
    4: np.dtype('<i4'),
}


def _read_wave_channels(sound_wave_file_name: Text):
    """Reads every frame of a WAV file and splits the samples per channel.

    Args:
      sound_wave_file_name: Wav file to read.

    Returns:
      A (channels, frame_rate) tuple. channels is a 2-D array with one row of
      samples per channel; frame_rate is the file's frames per second.

    Raises:
      ValueError: If the file's sample width is not 1, 2 or 4 bytes.
    """
    # The context manager guarantees the file handle is released.
    with wave.open(sound_wave_file_name, 'rb') as sound_wave_file:
        n_channels = sound_wave_file.getnchannels()
        sample_width = sound_wave_file.getsampwidth()
        frame_rate = sound_wave_file.getframerate()
        raw_frames = sound_wave_file.readframes(sound_wave_file.getnframes())

    try:
        dtype = _SAMPLE_DTYPES[sample_width]
    except KeyError:
        raise ValueError(
            f'Unsupported sample width: {sample_width} bytes.') from None

    samples = np.frombuffer(raw_frames, dtype=dtype)
    # Samples are interleaved frame by frame (ch0, ch1, ..., ch0, ch1, ...).
    # Reshaping to (frames, channels) and transposing gives one row per channel.
    channels = samples.reshape(-1, n_channels).T
    return channels, frame_rate


def create_sound_wave_graph(sound_wave_file_name: Text,
                            wave_graph_type: WaveGraphType) -> Text:
    """Creates a wave graph for given sound wav file.

    One line is drawn per channel. The image is written to the file name
    returned by _get_image_file_name for the input.

    Args:
      sound_wave_file_name: Wav file to plot; must hold 8-, 16- or 32-bit
        samples.
      wave_graph_type: Type of plot to create. PER_FRAME plots samples
        against their frame index; PER_SECOND plots them against time in
        seconds.

    Returns:
      The file name the graph image was saved to.

    Raises:
      ValueError: If wave_graph_type is not a supported WaveGraphType, or the
        file's sample width is not 1, 2 or 4 bytes.
    """
    # Validate before touching the file system or creating a figure.
    if wave_graph_type not in (WaveGraphType.PER_SECOND,
                               WaveGraphType.PER_FRAME):
        raise ValueError(f'Unsupported graph type: {wave_graph_type}.')

    channels, frame_rate = _read_wave_channels(sound_wave_file_name)
    n_frames = channels.shape[1]
    # Start time, in seconds, of each frame.
    time = np.arange(n_frames) / frame_rate
    wave_graph_file_name = _get_image_file_name(sound_wave_file_name)

    # Use a fresh figure on every call. Reusing a fixed figure number would
    # draw this file's channels on top of whatever a previous call plotted.
    figure = pyplot.figure(figsize=(16, 12), dpi=72)
    try:
        axes = figure.add_subplot(1, 1, 1)
        axes.set_title('Signal Wave')
        for channel in channels:
            if wave_graph_type == WaveGraphType.PER_SECOND:
                axes.plot(time, channel)
            else:
                axes.plot(channel)
        figure.savefig(wave_graph_file_name)
    finally:
        # Release the figure so repeated calls do not accumulate memory.
        pyplot.close(figure)
    return wave_graph_file_name