Skip to content

Commit

Permalink
add test and lint
Browse files Browse the repository at this point in the history
  • Loading branch information
carlosholivan committed Nov 10, 2022
1 parent 58fb7b2 commit 84ccc30
Show file tree
Hide file tree
Showing 2 changed files with 32 additions and 19 deletions.
34 changes: 16 additions & 18 deletions musicaiz/plotters/pianorolls.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
get_subdivisions,
TimingConsts
)
from musicaiz.structure import Note
from musicaiz.structure import Note, Instrument


COLOR_EDGES = [
Expand Down Expand Up @@ -131,13 +131,12 @@ def _add_new_bar_subdiv(subdivisions):
"ticks": subdivisions[-1]["ticks"] + range_ticks,
})

def _notes_loop(self, notes: List[Note], show_pitch_labels: bool = True):
def _notes_loop(self, notes: List[Note]):
plt.ylabel("Pitch")
#highest_pitch = get_highest_pitch(track.instrument)
#lowest_pitch = get_lowest_pitch(track.instrument)
# plt.ylim((0, (highest_pitch - lowest_pitch) - 1))

pitches = []
for note in notes:
plt.vlines(x=note.start_ticks,
ymin=note.pitch,
Expand Down Expand Up @@ -226,14 +225,14 @@ def _notes_loop(self, notes: List[Note]):
x0=note.start_ticks,
y0=note.pitch,
x1=note.end_ticks,
y1=note.pitch+1,
y1=note.pitch + 1,
line=dict(
color=COLOR_EDGES[0],
width=2,
),
fillcolor=COLOR[0],
)

# this is to add a hover information on each note
self.fig.add_trace(
go.Scatter(
Expand All @@ -246,23 +245,22 @@ def _notes_loop(self, notes: List[Note]):
],
y=[
note.pitch,
note.pitch+1,
note.pitch+1,
note.pitch + 1,
note.pitch + 1,
note.pitch,
note.pitch
],
],
fill="toself",
mode="lines",
name=f"pitch={note.pitch}<br>\n"\
f"velocity={note.velocity}<br>\n"\
f"start_ticks={note.start_ticks}<br>\n"\
name=f"pitch={note.pitch}<br>\n"
f"velocity={note.velocity}<br>\n"
f"start_ticks={note.start_ticks}<br>\n"
f"end_ticks={note.end_ticks}<br>",
opacity=0,
showlegend=False,
)
)


def plot_grid(self, subdivisions):
# all the pitch values to be written in the axis
#self.fig.update_xaxes(range[0, len(subdivisions) - 1])
Expand Down Expand Up @@ -296,10 +294,9 @@ def plot_grid(self, subdivisions):
self.fig.add_vline(x=s["ticks"], line_width=0.2, line_color="blue")
prev_bar, prev_beat = s["bar"], s["bar_beat"]


def plot_instrument(
self,
track,
track: Instrument,
bar_start: int,
bar_end: int,
subdivision: str,
Expand All @@ -309,7 +306,6 @@ def plot_instrument(
time_sig: str = TimingConsts.DEFAULT_TIME_SIGNATURE.value,
bpm: int = TimingConsts.DEFAULT_BPM.value,
resolution: int = TimingConsts.RESOLUTION.value,
print_measure_data: bool = True,
show: bool = True
):

Expand All @@ -333,11 +329,11 @@ def plot_instrument(
total_bars = bar_end - bar_start
subdivisions = get_subdivisions(total_bars, subdivision, time_sig, bpm, resolution)
self.plot_grid(subdivisions)

# this is to add the yaxis labels# horizontal line for pitch grid
labels = [i for i in range(min(pitches) - 1, max(pitches) + 2)]
for pitch in labels:
self.fig.add_hline(y=pitch+1, line_width=0.1, line_color="white")
self.fig.add_hline(y=pitch + 1, line_width=0.1, line_color="white")
# if we do have too many pitches, we won't label all of them in the yaxis
if max(pitches) - min(pitches) > 24:
cleaned_labels = [label for label in labels if label % 2 == 0]
Expand All @@ -357,7 +353,9 @@ def plot_instrument(
),
)

self.fig.update_layout(legend={"xanchor":"center", "yanchor":"top"})
self.fig.update_layout(legend={
"xanchor": "center", "yanchor": "top"
})

if save_plot:
self.fig.write_html(Path(path, filename + ".html"))
Expand Down
17 changes: 16 additions & 1 deletion tests/unit/musicaiz/plotters/test_pianorolls.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import pytest
import matplotlib.pyplot as plt

from musicaiz.plotters import Pianoroll
from musicaiz.plotters import Pianoroll, PianorollHTML
from musicaiz.loaders import Musa


Expand All @@ -20,3 +21,17 @@ def test_Pianoroll_plot_instrument(midi_sample):
print_measure_data=False,
show_bar_labels=False
)
plt.close('all')


def test_PianorollHTML_plot_instrument(midi_sample):
plot = PianorollHTML()
musa_obj = Musa(midi_sample, structure="bars")
plot.plot_instrument(
track=musa_obj.instruments[0],
bar_start=1,
bar_end=2,
subdivision="quarter",
time_sig=musa_obj.time_sig.time_sig,
show=False
)

0 comments on commit 84ccc30

Please sign in to comment.