From 3cac417ad8e87905fd7222f838a0011b308e3522 Mon Sep 17 00:00:00 2001 From: carlosholivan Date: Fri, 2 Dec 2022 20:14:17 +0100 Subject: [PATCH 1/9] big refactor to load notes, bars, beats and subbeats as objects when loading a midi file with tempo and time signature changes. Addition of REMI tokenizer, multi-track pianorolls and chord detection algorithm --- musicaiz/algorithms/__init__.py | 15 + musicaiz/algorithms/chord_prediction.py | 203 +++ musicaiz/algorithms/harmonic_shift.py | 4 +- musicaiz/converters/__init__.py | 19 +- musicaiz/converters/musa_json.py | 81 +- musicaiz/converters/musa_protobuf.py | 263 ++-- musicaiz/converters/pretty_midi_musa.py | 44 +- musicaiz/converters/protobuf/musicaiz.proto | 172 +-- musicaiz/converters/protobuf/musicaiz_pb2.py | 52 +- musicaiz/converters/protobuf/musicaiz_pb2.pyi | 268 ++-- musicaiz/datasets/jsbchorales.py | 27 + musicaiz/datasets/lmd.py | 27 + musicaiz/datasets/maestro.py | 30 +- musicaiz/datasets/utils.py | 12 +- musicaiz/features/harmony.py | 7 +- musicaiz/harmony/chords.py | 14 +- musicaiz/harmony/keys.py | 4 +- musicaiz/loaders.py | 1124 +++++++++++++++-- musicaiz/plotters/pianorolls.py | 252 ++-- musicaiz/rhythm/__init__.py | 8 +- musicaiz/rhythm/quantizer.py | 13 +- musicaiz/rhythm/timing.py | 186 ++- musicaiz/structure/bars.py | 31 +- musicaiz/structure/instruments.py | 116 +- musicaiz/structure/notes.py | 70 +- musicaiz/tokenizers/__init__.py | 10 + musicaiz/tokenizers/encoder.py | 122 +- musicaiz/tokenizers/mmm.py | 206 +-- musicaiz/tokenizers/remi.py | 403 ++++++ musicaiz/utils.py | 5 +- musicaiz/version.py | 4 +- tests/fixtures/midis/midi_changes.mid | Bin 0 -> 13476 bytes tests/fixtures/tokenizers/remi_tokens.txt | 1 + .../algorithms/test_chord_prediction.py | 19 + .../musicaiz/converters/test_musa_json.py | 31 + .../converters/test_musa_to_protobuf.py | 78 +- .../converters/test_pretty_midi_musa.py | 38 + tests/unit/musicaiz/datasets/test_lmd.py | 2 +- tests/unit/musicaiz/datasets/test_maestro.py | 2 +- tests/unit/musicaiz/features/test_harmony.py | 2 - tests/unit/musicaiz/loaders/test_loaders.py | 204 ++- .../unit/musicaiz/plotters/test_pianorolls.py | 69 +- tests/unit/musicaiz/rhythm/test_quantizer.py | 4 +- tests/unit/musicaiz/structure/test_notes.py | 6 +- tests/unit/musicaiz/tokenizers/__init__.py | 0 tests/unit/musicaiz/tokenizers/test_mmm.py | 122 +- tests/unit/musicaiz/tokenizers/test_remi.py | 134 ++ 47 files changed, 3440 insertions(+), 1064 deletions(-) create mode 100644 musicaiz/algorithms/chord_prediction.py create mode 100644 musicaiz/tokenizers/remi.py create mode 100644 tests/fixtures/midis/midi_changes.mid create mode 100644 tests/fixtures/tokenizers/remi_tokens.txt create mode 100644 tests/unit/musicaiz/algorithms/test_chord_prediction.py create mode 100644 tests/unit/musicaiz/converters/test_musa_json.py create mode 100644 tests/unit/musicaiz/converters/test_pretty_midi_musa.py create mode 100644 tests/unit/musicaiz/tokenizers/__init__.py create mode 100644 tests/unit/musicaiz/tokenizers/test_remi.py diff --git a/musicaiz/algorithms/__init__.py b/musicaiz/algorithms/__init__.py index 3ea1dbd..23d9a00 100644 --- a/musicaiz/algorithms/__init__.py +++ b/musicaiz/algorithms/__init__.py @@ -70,6 +70,15 @@ key_detection ) +from .chord_prediction import ( + predict_chords, + get_chords, + get_chords_candidates, + compute_chord_notes_dist, + _notes_to_onehot, + _group_notes_in_beats, +) + __all__ = [ "harmonic_shifting", @@ -86,4 +95,10 @@ "signature_fifths_profiles", "_eights_per_pitch_class", "key_detection", + 
"predict_chords", + "get_chords", + "get_chords_candidates", + "compute_chord_notes_dist", + "_notes_to_onehot", + "_group_notes_in_beats", ] diff --git a/musicaiz/algorithms/chord_prediction.py b/musicaiz/algorithms/chord_prediction.py new file mode 100644 index 0000000..b2fe9de --- /dev/null +++ b/musicaiz/algorithms/chord_prediction.py @@ -0,0 +1,203 @@ +import pretty_midi as pm +import numpy as np +from typing import List, Dict + + +from musicaiz.structure import Note +from musicaiz.rhythm import NoteLengths +from musicaiz.harmony import Chord + + +def predict_chords(musa_obj): + notes_beats = [] + for i in range(len(musa_obj.beats)): + nts = musa_obj.get_notes_in_beat(i) + nts = [n for n in nts if not n.is_drum] + if nts is not None or len(nts) != 0: + notes_beats.append(nts) + notes_pitches_segments = [_notes_to_onehot(note) for note in notes_beats] + # Convert chord labels to onehot + chords_onehot = Chord.chords_to_onehot() + # step 1: Compute the distance between all the chord vectors and the notes vectors + all_dists = [compute_chord_notes_dist(chords_onehot, segment) for segment in notes_pitches_segments] + # step 2: get chord candidates per step which distance is the lowest + chord_segments = get_chords_candidates(all_dists) + # step 3: clean chord candidates + chords = get_chords(chord_segments, chords_onehot) + return chords + + +def get_chords( + chord_segments: List[List[str]], + chords_onehot: Dict[str, List[int]], +) -> List[List[str]]: + """ + Clean the predicted chords that are extracted with get_chords_candidates method + by comparing each chord in a step with the chords in the previous and next steps. + The ouput chords are the ones wich distances are the lowest. + + Parameters + ---------- + + chord_segments: List[List[str]] + The chord candidates extracted with get_chords_candidates method. 
+ + Returns + ------- + + chords: List[List[str]] + """ + chords = [] + for i, _ in enumerate(chord_segments): + cross_dists = {} + for j, _ in enumerate(chord_segments[i]): + if i == 0: + for item in range(len(chord_segments[i + 1])): + dist = np.linalg.norm(np.array(chords_onehot[chord_segments[i][j]]) - np.array(chords_onehot[chord_segments[i+1][item]])) + cross_dists.update( + { + chord_segments[i][j] + " " + chord_segments[i+1][item]: dist + } + ) + if i != 0: + for item in range(len(chord_segments[i - 1])): + dist = np.linalg.norm(np.array(chords_onehot[chord_segments[i][j]]) - np.array(chords_onehot[chord_segments[i-1][item]])) + cross_dists.update( + { + chord_segments[i][j] + " " + chord_segments[i-1][item]: dist + } + ) + #print("--------") + #print(cross_dists) + chords_list = [(i.split(" ")[0], cross_dists[i]) for i in cross_dists if cross_dists[i]==min(cross_dists.values())] + chords_dict = {} + chords_dict.update(chords_list) + #print(chords_dict) + # Diminish distances if in previous step there's one or more chords equal to the chords in the current step + for chord, dist in chords_dict.items(): + if i != 0: + prev_chords = [c for c in chords[i - 1]] + tonics = [c.split("-")[0] for c in prev_chords] + tonic = chord.split("-")[0] + if chord not in prev_chords or tonic not in tonics: + chords_dict[chord] = dist + 0.5 + #print(chords_dict) + new_chords_list = [i for i in chords_dict if chords_dict[i]==min(chords_dict.values())] + #print(new_chords_list) + chords.append(new_chords_list) + # If a 7th chord is predicted at a time step and the same chord triad is at + # the prev at next steps, we'll substitute the triad chord for the 7th chord + #for step in chords: + # chord_names = "/".join(step) + # if "SEVENTH" in chord_names: + return chords + + +def get_chords_candidates(dists: List[Dict[str, float]]) -> List[List[str]]: + """ + Gets the chords with the minimum distance in a list of dictionaries + where each element of the list is a step (beat) corresponding to the note + vectors and the items are dicts with the chord names (key) and dists (val.) + + Parameters + ---------- + + dists: List[Dict[str, float]] + The list of distances between chord and note vectors as dictionaries per step. + + Returns + ------- + + chord_segments: List[List[str]] + A list with all the chords predicted per step. + """ + chord_segments = [] + for dists_dict in dists: + chord_segments.append([i for i in dists_dict if dists_dict[i]==min(dists_dict.values())]) + return chord_segments + + +def compute_chord_notes_dist( + chords_onehot: Dict[str, List[int]], + notes_onehot: Dict[str, List[int]], +) -> Dict[str, float]: + """ + Compute the distance between each chord and a single notes vector. + The outpput is given as a dictionary with the chord name (key) and the distance (val.). + + Parameters + ---------- + + chords_onehot: Dict[str, List[int]] + + notes_onehot: Dict[str, List[int]] + + Returns + ------- + + dists: Dict[str, float] + """ + dists = {} + for chord, chord_vec in chords_onehot.items(): + dist = np.linalg.norm(np.array(notes_onehot)-np.array(chord_vec)) + dists.update({chord: dist}) + return dists + + +def _notes_to_onehot(notes: List[Note]) -> List[int]: + """ + Converts a list of notes into a list of 0s and 1s. + The output list will have 12 elements corresponding to + the notes in the chromatic scale from C to B. + If the note C is in the input list, the index corresponding + to that note in the output list will be 1, otherwise it'll be 0. 
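To make the Euclidean-distance matching used by compute_chord_notes_dist concrete, a small hand-worked example with toy vectors (the real chord templates come from Chord.chords_to_onehot()):

    import numpy as np

    # C, E and G sounding in a beat -> chromatic indices 0, 4 and 7
    notes_onehot = [1, 0, 0, 0, 1, 0, 0, 1, 0, 0, 0, 0]
    c_major      = [1, 0, 0, 0, 1, 0, 0, 1, 0, 0, 0, 0]   # C major triad template
    a_minor      = [1, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0]   # A minor triad template (A, C, E)

    print(np.linalg.norm(np.array(notes_onehot) - np.array(c_major)))  # 0.0  -> best candidate
    print(np.linalg.norm(np.array(notes_onehot) - np.array(a_minor)))  # ~1.41 (sqrt(2))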
+ + Parameters + ---------- + notes: List[Note]) + + Returns + ------- + pitches_onehot: List[int] + """ + pitches = [pm.note_name_to_number(note.note_name + "-1") for note in notes] + pitches = list(dict.fromkeys(pitches)) + pitches_onehot = [1 if i in pitches else 0 for i in range(0, 12)] + return pitches_onehot + + +def _group_notes_in_beats(midi) -> List[List[Note]]: + """ + Group notes in beats (a quarter note in 4/4 time sig.). + + Parameters + ---------- + + midi: A musa object + + Returns + ------- + """ + # split iois if difference in quarters + if midi.time_sig.denom == 4: + iois_split = NoteLengths.QUARTER.ticks() + + notes_segments = [] + it = 0 + for i, note in enumerate(midi.notes): + if i < it: + continue + for next_note in midi.notes[i:]: + diff = abs(note.start_ticks - next_note.start_ticks) + if diff < iois_split: + it += 1 + if diff >= iois_split: + notes_segments.append(midi.notes[i:it]) + break + # if only one note in the 1st step, group the notes in the next beat into it + for n, notes in enumerate(notes_segments): + if i == 0: + continue + if len(notes) <= 1: + notes.extend(notes_segments[n+1]) + return notes_segments diff --git a/musicaiz/algorithms/harmonic_shift.py b/musicaiz/algorithms/harmonic_shift.py index d3b76ac..a986127 100644 --- a/musicaiz/algorithms/harmonic_shift.py +++ b/musicaiz/algorithms/harmonic_shift.py @@ -4,7 +4,7 @@ from musicaiz.harmony import Tonality from musicaiz.structure import NoteClassBase, Note -from musicaiz.converters import pretty_midi_note_to_musanalysis +from musicaiz.converters import prettymidi_note_to_musicaiz def harmonic_shifting( @@ -347,7 +347,7 @@ def _map_passing_note( # Check if target pitch corresponds to a note in the scale, if not, # convert the note to the closest note in the scale target_name = pm.note_number_to_name(target_pitch) - target_name, target_octave = pretty_midi_note_to_musanalysis(target_name) + target_name, target_octave = prettymidi_note_to_musicaiz(target_name) # Get the notes in the scale all_degs = ["I", "II", "III", "IV", "V", "VI", "VII"] diff --git a/musicaiz/converters/__init__.py b/musicaiz/converters/__init__.py index ead1ccd..721f380 100644 --- a/musicaiz/converters/__init__.py +++ b/musicaiz/converters/__init__.py @@ -20,17 +20,22 @@ .. 
autosummary:: :toctree: generated/ - pretty_midi_note_to_musanalysis + prettymidi_note_to_musicaiz + musicaiz_note_to_prettymidi """ from .musa_json import ( - MusaJSON + MusaJSON, + BarJSON, + InstrumentJSON, + NoteJSON, ) from .pretty_midi_musa import ( - pretty_midi_note_to_musanalysis, - musa_to_prettymidi + prettymidi_note_to_musicaiz, + musicaiz_note_to_prettymidi, + musa_to_prettymidi, ) from .musa_protobuf import ( @@ -42,7 +47,11 @@ __all__ = [ "MusaJSON", - "pretty_midi_note_to_musanalysis", + "BarJSON", + "InstrumentJSON", + "NoteJSON", + "prettymidi_note_to_musicaiz", + "musicaiz_note_to_prettymidi", "musa_to_prettymidi", "protobuf", "musa_to_proto", diff --git a/musicaiz/converters/musa_json.py b/musicaiz/converters/musa_json.py index 5b98a3d..cf296d4 100644 --- a/musicaiz/converters/musa_json.py +++ b/musicaiz/converters/musa_json.py @@ -10,14 +10,19 @@ class NoteJSON: end: int # ticks pitch: int velocity: int + bar_idx: int + beat_idx: int + subbeat_idx: int + instrument_idx: int + instrument_prog: int @dataclass class BarJSON: time_sig: str + bpm: int start: int # ticks end: int # ticks - notes: List[NoteJSON] @dataclass @@ -25,15 +30,12 @@ class InstrumentJSON: is_drum: bool name: str n_prog: int - bars: List[BarJSON] - @dataclass class JSON: tonality: str time_sig: str - instruments: List[InstrumentJSON] class MusaJSON: @@ -41,7 +43,11 @@ class MusaJSON: """ This class converst a `musicaiz` :func:`~musicaiz.loaders.Musa` object into a JSON format. - + Note that this conversion is different that the .json method of Musa class, + since that is intended for encoding musicaiz objects and this class here + can be encoded and decoded with other softwares since it does not encode + musicaiz objects. + Examples -------- @@ -56,7 +62,7 @@ class MusaJSON: field="hello", value=2 ) - + Save the json to disk: >>> musa_json.save("filename") @@ -68,23 +74,25 @@ def __init__( ): self.midi = musa_obj self.json = self.to_json(musa_obj=self.midi) - + def save(self, filename: str, path: Union[str, Path] = ""): """Saves the JSON into disk.""" with open(Path(path, filename + ".json"), "w") as write_file: json.dump(self.json, write_file) - + @staticmethod def to_json(musa_obj): composition = {} # headers composition["tonality"] = musa_obj.tonality - composition["time_sig"] = musa_obj.time_sig.time_sig - composition["instruments"] = [{}] * len(musa_obj.instruments) + composition["resolution"] = musa_obj.resolution + composition["instruments"] = [] + composition["bars"] = [] + composition["notes"] = [] composition["instruments"] = [] - for i, instr in enumerate(musa_obj.instruments): + for _, instr in enumerate(musa_obj.instruments): composition["instruments"].append( { "is_drum": instr.is_drum, @@ -92,31 +100,30 @@ def to_json(musa_obj): "n_prog": int(instr.program), } ) - if instr.bars is None: - continue - if len(instr.bars) == 0: - continue - composition["instruments"][i]["bars"] = [] - for b, bar in enumerate(instr.bars): - composition["instruments"][i]["bars"].append( - { - "time_sig": bar.time_sig, - "start": bar.start_ticks, - "end": bar.end_ticks - } - ) - composition["instruments"][i]["bars"][b]["notes"] = [] - if len(bar.notes) == 0: - continue - for n, note in enumerate(bar.notes): - composition["instruments"][i]["bars"][b]["notes"].append( - { - "start": note.start_ticks, - "end": note.end_ticks, - "pitch": note.pitch, - "velocity": note.velocity, - } - ) + for _, bar in enumerate(musa_obj.bars): + composition["bars"].append( + { + "time_sig": bar.time_sig.time_sig, + "start": 
bar.start_ticks, + "end": bar.end_ticks, + "bpm": bar.bpm, + } + ) + for _, note in enumerate(musa_obj.notes): + composition["notes"].append( + { + "start": note.start_ticks, + "end": note.end_ticks, + "pitch": note.pitch, + "velocity": note.velocity, + "bar_idx": note.bar_idx, + "beat_idx": note.beat_idx, + "subbeat_idx": note.subbeat_idx, + "instrument_idx": note.instrument_idx, + "instrument_prog": note.instrument_prog, + + } + ) return composition def add_instrument_field(self, n_program: int, field: str, value: Union[str, int, float]): @@ -183,4 +190,4 @@ def delete_header_field(): class JSONMusa: - NotImplementedError \ No newline at end of file + NotImplementedError diff --git a/musicaiz/converters/musa_protobuf.py b/musicaiz/converters/musa_protobuf.py index 8693da5..deb66c0 100644 --- a/musicaiz/converters/musa_protobuf.py +++ b/musicaiz/converters/musa_protobuf.py @@ -1,5 +1,13 @@ from musicaiz.converters.protobuf import musicaiz_pb2 -from musicaiz.structure import Note, Instrument, Bar +from musicaiz.structure import ( + Note, + Instrument, + Bar, +) +from musicaiz.rhythm import ( + Beat, + Subdivision, +) from musicaiz import loaders @@ -13,16 +21,48 @@ def musa_to_proto(musa_obj): Returns ------- - + proto: The output protobuf. """ proto = musicaiz_pb2.Musa() # Time signature data - time_sig = proto.time_signatures.add() - time_sig.num = musa_obj.time_sig.num - time_sig.denom = musa_obj.time_sig.denom + proto_time_signature_changes = proto.time_signature_changes.add() + proto_time_signature_changes = musa_obj.time_signature_changes + + proto_subdivision_note = proto.subdivision_note.add() + proto_subdivision_note = proto.subdivision_note.add() + + proto_file = proto.file.add() + proto_file = musa_obj.file + + proto_total_bars = proto.total_bars.add() + proto_total_bars = musa_obj.total_bars + + proto_tonality = proto.tonality.add() + proto_tonality = musa_obj.tonality + + proto_resolution = proto.resolution.add() + proto_resolution = musa_obj.resolution + + proto_is_quantized = proto.is_quantized.add() + proto_is_quantized = musa_obj.is_quantized + + proto_quantize_note = proto.quantize_note.add() + proto_quantize_note = musa_obj.quantize_note + + proto_absolute_timing = proto.absolute_timing.add() + proto_absolute_timing = musa_obj.absolute_timing + + proto_cut_notes = proto.cut_notes.add() + proto_cut_notes = musa_obj.cut_notes + + proto_tempo_changes = proto.tempo_changes.add() + proto_tempo_changes = musa_obj.tempo_changes + + proto_instruments_progs = proto.instruments_progs.add() + proto_instruments_progs = musa_obj.instruments_progs # Other parameters (quantization, PPQ...) 
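Because the musa_json.py hunk above now writes a flat layout (top-level "instruments", "bars" and "notes" lists instead of notes nested inside bars), the exported file can be consumed with the standard json module alone. A hedged sketch, with illustrative values:

    import json
    from musicaiz.loaders import Musa
    from musicaiz.converters import MusaJSON

    midi = Musa("tests/fixtures/midis/midi_changes.mid")   # example MIDI from this patch
    MusaJSON(midi).save("song")                            # writes ./song.json

    with open("song.json") as f:
        data = json.load(f)

    print(data["resolution"])                    # e.g. 480
    print(data["bars"][0]["time_sig"])           # e.g. "4/4"
    print(data["notes"][0]["bar_idx"], data["notes"][0]["instrument_prog"])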
for instr in musa_obj.instruments: @@ -32,52 +72,61 @@ def musa_to_proto(musa_obj): proto_instruments.family = instr.family if instr.family is not None else "" proto_instruments.is_drum = instr.is_drum - if instr.bars is not None: - if len(instr.bars) != 0: - # loop in bars to add them to the protobuf - for bar in instr.bars: - proto_bars = proto_instruments.bars.add() - proto_bars.bpm = bar.bpm - proto_bars.time_sig = bar.time_sig - proto_bars.resolution = bar.resolution - proto_bars.absolute_timing = bar.absolute_timing - proto_bars.note_density = bar.note_density - proto_bars.harmonic_density = bar.harmonic_density - proto_bars.start_ticks = bar.start_ticks - proto_bars.end_ticks = bar.end_ticks - proto_bars.start_sec = bar.start_sec - proto_bars.end_sec = bar.end_sec - - # loop in notes to add them to the protobuf - for note in bar.notes: - proto_note = proto_bars.notes.add() - proto_note.pitch = note.pitch - proto_note.pitch_name = note.pitch_name - proto_note.note_name = note.note_name - proto_note.octave = note.octave - proto_note.ligated = note.ligated - proto_note.start_ticks = note.start_ticks - proto_note.end_ticks = note.end_ticks - proto_note.start_sec = note.start_sec - proto_note.end_sec = note.end_sec - proto_note.symbolic = note.symbolic - proto_note.velocity = note.velocity - else: - if len(instr.notes) != 0: - # loop in notes to add them to the protobuf - for note in instr.notes: - proto_note = proto_instruments.notes.add() - proto_note.pitch = note.pitch - proto_note.pitch_name = note.pitch_name - proto_note.note_name = note.note_name - proto_note.octave = note.octave - proto_note.ligated = note.ligated - proto_note.start_ticks = note.start_ticks - proto_note.end_ticks = note.end_ticks - proto_note.start_sec = note.start_sec - proto_note.end_sec = note.end_sec - proto_note.symbolic = note.symbolic - proto_note.velocity = note.velocity + # loop in bars to add them to the protobuf + for bar in musa_obj.bars: + proto_bars = proto.bars.add() + proto_bars.bpm = bar.bpm + proto_bars.time_sig = bar.time_sig.time_sig + proto_bars.resolution = bar.resolution + proto_bars.absolute_timing = bar.absolute_timing + proto_bars.note_density = bar.note_density + proto_bars.harmonic_density = bar.harmonic_density + proto_bars.start_ticks = bar.start_ticks + proto_bars.end_ticks = bar.end_ticks + proto_bars.start_sec = bar.start_sec + proto_bars.end_sec = bar.end_sec + + # loop in notes to add them to the protobuf + for note in musa_obj.notes: + proto_note = proto.notes.add() + proto_note.pitch = note.pitch + proto_note.pitch_name = note.pitch_name + proto_note.note_name = note.note_name + proto_note.octave = note.octave + proto_note.ligated = note.ligated + proto_note.start_ticks = note.start_ticks + proto_note.end_ticks = note.end_ticks + proto_note.start_sec = note.start_sec + proto_note.end_sec = note.end_sec + proto_note.symbolic = note.symbolic + proto_note.velocity = note.velocity + proto_note.instrument_prog = note.instrument_prog + proto_note.instrument_idx = note.instrument_idx + proto_note.beat_idx = note.beat_idx + proto_note.bar_idx = note.bar_idx + proto_note.subbeat_idx = note.subbeat_idx + + for beat in musa_obj.beats: + proto_beat = proto.beats.add() + proto_beat.time_sig = beat.time_sig.time_sig + proto_beat.bpm = beat.bpm + proto_beat.global_idx = beat.global_idx + proto_beat.bar_idx = beat.bar_idx + proto_beat.start_sec = beat.start_sec + proto_beat.end_sec = beat.end_sec + proto_beat.start_ticks = beat.start_ticks + proto_beat.end_ticks = beat.end_ticks + + for 
subbeat in musa_obj.subbeats: + proto_subbeat = proto.subbeats.add() + proto_subbeat.global_idx = subbeat.global_idx + proto_subbeat.bar_idx = subbeat.bar_idx + proto_subbeat.beat_idx = subbeat.beat_idx + proto_subbeat.start_sec = subbeat.start_sec + proto_subbeat.end_sec = subbeat.end_sec + proto_subbeat.start_ticks = subbeat.start_ticks + proto_subbeat.end_ticks = subbeat.end_ticks + return proto @@ -93,14 +142,14 @@ def proto_to_musa(protobuf): #-> loaders.Musa: Returns ------- - + midi: Musa The output Musa object. """ - - midi = loaders.Musa() - for i, instr in enumerate(protobuf.instruments): + midi = loaders.Musa(file=None) + + for _, instr in enumerate(protobuf.instruments): midi.instruments.append( Instrument( program=instr.program, @@ -109,47 +158,65 @@ def proto_to_musa(protobuf): #-> loaders.Musa: general_midi=False, ) ) - if len(instr.notes) != 0: - # the notes are stored inside instruments, but not in bars - for note in instr.notes: - # Initialize the note with musicaiz `Note` object - midi.instruments[i].notes.append( - Note( - pitch=note.pitch, - start=note.start_ticks, - end=note.end_ticks, - velocity=note.velocity, - bpm=bar.bpm, - resolution=bar.resolution, - ) - ) - elif len(instr.notes) == 0: - # the notes are stored inside bars that are stored inside instruments - for j, bar in enumerate(instr.bars): - # Initialize the bar with musicaiz `Bar` object - midi.instruments[i].bars.append( - Bar( - time_sig=bar.time_sig, - bpm=bar.bpm, - resolution=bar.resolution, - absolute_timing=bar.absolute_timing, - ) - ) - for note in bar.notes: - midi.instruments[i].bars[j].notes.append( - Note( - pitch=note.pitch, - start=note.start_ticks, - end=note.end_ticks, - velocity=note.velocity, - bpm=bar.bpm, - resolution=bar.resolution, - ) - ) - midi.instruments[i].bars[j].note_density = bar.note_density - midi.instruments[i].bars[j].harmonic_density = bar.harmonic_density - midi.instruments[i].bars[j].start_ticks = bar.start_ticks - midi.instruments[i].bars[j].end_ticks = bar.end_ticks - midi.instruments[i].bars[j].start_sec = bar.start_sec - midi.instruments[i].bars[j].end_sec = bar.end_sec + + for note in midi.notes: + # Initialize the note with musicaiz `Note` object + midi.notes.append( + Note( + pitch=note.pitch, + start=note.start_ticks, + end=note.end_ticks, + velocity=note.velocity, + bpm=note.bpm, + resolution=note.resolution, + instrument_prog=note.instrument_prog, + instrument_idx=note.instrument_idx + ) + ) + + for j, bar in enumerate(midi.bars): + # Initialize the bar with musicaiz `Bar` object + midi.bars.append( + Bar( + time_sig=bar.time_sig, + bpm=bar.bpm, + resolution=bar.resolution, + absolute_timing=bar.absolute_timing, + ) + ) + midi.bars[j].note_density = bar.note_density + midi.bars[j].harmonic_density = bar.harmonic_density + midi.bars[j].start_ticks = bar.start_ticks + midi.bars[j].end_ticks = bar.end_ticks + midi.bars[j].start_sec = bar.start_sec + midi.bars[j].end_sec = bar.end_sec + + for beat in midi.beats: + # Initialize the note with musicaiz `Note` object + midi.beats.append( + Beat( + time_sig=beat.time_sig, + bpm=beat.bpm, + resolution=beat.resolution, + global_idx=beat.global_idx, + bar_idx=beat.bar_idx, + start=beat.start_sec, + end=beat.end_sec + ) + ) + + for subbeat in midi.subbeats: + # Initialize the note with musicaiz `Note` object + midi.notes.append( + Subdivision( + time_sig=subbeat.time_sig, + bpm=subbeat.bpm, + resolution=subbeat.resolution, + global_idx=subbeat.global_idx, + bar_idx=subbeat.bar_idx, + start=subbeat.start_sec, + 
end=subbeat.end_sec, + beat_idx=subbeat.beat_idx, + ) + ) return midi \ No newline at end of file diff --git a/musicaiz/converters/pretty_midi_musa.py b/musicaiz/converters/pretty_midi_musa.py index 8f8096d..85a068a 100644 --- a/musicaiz/converters/pretty_midi_musa.py +++ b/musicaiz/converters/pretty_midi_musa.py @@ -5,7 +5,7 @@ from musicaiz.structure import NoteClassBase -def pretty_midi_note_to_musanalysis(note: str) -> Tuple[str, int]: +def prettymidi_note_to_musicaiz(note: str) -> Tuple[str, int]: octave = int("".join(filter(str.isdigit, note))) # Get the note name without the octave note_name = note.replace(str(octave), "") @@ -13,6 +13,23 @@ def pretty_midi_note_to_musanalysis(note: str) -> Tuple[str, int]: return musa_note_name.name, octave +def musicaiz_note_to_prettymidi( + note: str, + octave: int +) -> str: + """ + >>> note = "F_SHARP" + >>> octave = 3 + >>> pm_note = musicaiz_note_to_prettymidi(note, octave) + >>> "F#3" + """ + note_name = note.replace("SHARP", "#") + note_name = note_name.replace("FLAT", "b") + note_name = note_name.replace("_", "") + pm_note = note_name + str(octave) + return pm_note + + def musa_to_prettymidi(musa_obj): """ Converts a Musa object into a PrettMIDI object. @@ -23,11 +40,24 @@ def musa_to_prettymidi(musa_obj): midi: PrettyMIDI The pretty_midi object. """ - # TODO: Write also metadata in PrettyMIDI object: pitch bends... + # TODO: Write also metadata in PrettyMIDI object: pitch bends.. midi = pm.PrettyMIDI( resolution=musa_obj.resolution, - initial_tempo=musa_obj.bpm + initial_tempo=musa_obj.tempo_changes[0]["tempo"] ) + midi.time_signature_changes = [] + for ts in musa_obj.time_signature_changes: + midi.time_signature_changes.append( + pm.TimeSignature( + numerator=ts["time_sig"].num, + denominator=ts["time_sig"].denom, + time=ts["ms"] / 1000 + ) + ) + # TODO: Get ticks for each event (see Mido) + midi._tick_scales = [ + (0, 60.0 / (musa_obj.tempo_changes[0]["tempo"] * midi.resolution)) + ] for i, inst in enumerate(musa_obj.instruments): midi.instruments.append( @@ -37,7 +67,11 @@ def musa_to_prettymidi(musa_obj): name=inst.name ) ) - for note in inst.notes: + notes = musa_obj.get_notes_in_bars( + bar_start=0, bar_end=musa_obj.total_bars, + program=int(inst.program), instrument_idx=i + ) + for note in notes: midi.instruments[i].notes.append( pm.Note( velocity=note.velocity, @@ -46,4 +80,4 @@ def musa_to_prettymidi(musa_obj): end=note.end_sec ) ) - return midi \ No newline at end of file + return midi diff --git a/musicaiz/converters/protobuf/musicaiz.proto b/musicaiz/converters/protobuf/musicaiz.proto index 6ab850f..abe637b 100644 --- a/musicaiz/converters/protobuf/musicaiz.proto +++ b/musicaiz/converters/protobuf/musicaiz.proto @@ -7,23 +7,36 @@ package musicaiz; message Musa { - repeated TimeSignature time_signatures = 5; - repeated Instrument instruments = 6; - - message Note { - - } - - // time attributes: musicaiz.rhythm - message TimeSignature { - int32 num = 2; - int32 denom = 3; - } - - // harmony attributes: musicaiz.harmony - message Tonality { - - } + repeated TimeSignatureChanges time_signature_changes = 5; + repeated SubdivisionNote subdivision_note = 6; + repeated File file = 7; + repeated TotalBars total_bars = 8; + repeated Tonality tonality = 9; + repeated Resolution resolution = 10; + repeated IsQuantized is_quantized = 11; + repeated QuantizeNote quantize_note = 12; + repeated AbsoluteTiming absolute_timing = 13; + repeated CutNotes cut_notes = 14; + repeated TempoChanges tempo_changes = 15; + repeated InstrumentsProgs 
instruments_progs = 16; + repeated Instrument instruments = 17; + repeated Bar bars = 18; + repeated Note notes = 19; + repeated Beat beats = 20; + repeated Subbeat subbeats = 21; + + message TimeSignatureChanges {} + message SubdivisionNote {} + message File {} + message TotalBars {} + message Tonality {} + message Resolution {} + message IsQuantized {} + message QuantizeNote {} + message AbsoluteTiming {} + message CutNotes {} + message TempoChanges {} + message InstrumentsProgs {} message Instrument { // Instrument index. @@ -34,63 +47,80 @@ message Musa { string name = 3; // The instrument's family. string family = 4; - bool is_drum = 5; - repeated Bar bars = 6; - repeated Note notes = 7; - - message Note { - int32 pitch = 1; - string pitch_name = 2; - string note_name = 3; - string octave = 4; - bool ligated = 5; - - // Timing inf of the Note - int32 start_ticks = 6; - int32 end_ticks = 7; - float start_sec = 8; - float end_sec = 9; - string symbolic = 10; - - int32 velocity = 11; - } - - message Bar { - int32 bpm = 1; - string time_sig = 2; - int32 resolution = 3; - bool absolute_timing = 4; - - // Timing inf of the Bar - int32 note_density = 5; - int32 harmonic_density = 6; - int32 start_ticks = 7; - int32 end_ticks = 8; - float start_sec = 9; - float end_sec = 10; - - repeated Note notes = 11; - - message Note { - int32 pitch = 1; - string pitch_name = 2; - string note_name = 3; - string octave = 4; - bool ligated = 5; - - // Timing inf of the Note - int32 start_ticks = 6; - int32 end_ticks = 7; - float start_sec = 8; - float end_sec = 9; - string symbolic = 10; - - int32 velocity = 11; - } - } + bool is_drum = 5; + } + + message Note { + int32 pitch = 1; + string pitch_name = 2; + string note_name = 3; + string octave = 4; + bool ligated = 5; + + // Timing inf of the Note + int32 start_ticks = 6; + int32 end_ticks = 7; + float start_sec = 8; + float end_sec = 9; + string symbolic = 10; + + int32 velocity = 11; + + int32 bar_idx = 12; + int32 beat_idx = 13; + int32 subbeat_idx = 14; + + int32 instrument_idx = 15; + int32 instrument_prog = 16; + } + message Bar { + float bpm = 1; + string time_sig = 2; + int32 resolution = 3; + bool absolute_timing = 4; + + // Timing inf of the Bar + int32 note_density = 5; + int32 harmonic_density = 6; + int32 start_ticks = 7; + int32 end_ticks = 8; + float start_sec = 9; + float end_sec = 10; } + message Beat { + float bpm = 1; + string time_sig = 2; + int32 resolution = 3; + bool absolute_timing = 4; + + // Timing + int32 start_ticks = 7; + int32 end_ticks = 8; + float start_sec = 9; + float end_sec = 10; + int32 global_idx = 11; + int32 bar_idx = 12; + } + message Subbeat { + float bpm = 1; + string time_sig = 2; + int32 resolution = 3; + bool absolute_timing = 4; + + // Timing inf of the Bar + int32 note_density = 5; + int32 harmonic_density = 6; + int32 start_ticks = 7; + int32 end_ticks = 8; + float start_sec = 9; + float end_sec = 10; + + int32 global_idx = 11; + int32 bar_idx = 12; + int32 beat_idx = 13; + } } \ No newline at end of file diff --git a/musicaiz/converters/protobuf/musicaiz_pb2.py b/musicaiz/converters/protobuf/musicaiz_pb2.py index ea35047..241f4ca 100644 --- a/musicaiz/converters/protobuf/musicaiz_pb2.py +++ b/musicaiz/converters/protobuf/musicaiz_pb2.py @@ -13,7 +13,7 @@ -DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n+musicaiz/converters/protobuf/musicaiz.proto\x12\x08musicaiz\"\x90\x08\n\x04Musa\x12\x35\n\x0ftime_signatures\x18\x05 \x03(\x0b\x32\x1c.musicaiz.Musa.TimeSignature\x12.\n\x0binstruments\x18\x06 
\x03(\x0b\x32\x19.musicaiz.Musa.Instrument\x1a\x06\n\x04Note\x1a+\n\rTimeSignature\x12\x0b\n\x03num\x18\x02 \x01(\x05\x12\r\n\x05\x64\x65nom\x18\x03 \x01(\x05\x1a\n\n\x08Tonality\x1a\xdf\x06\n\nInstrument\x12\x12\n\ninstrument\x18\x01 \x01(\x05\x12\x0f\n\x07program\x18\x02 \x01(\x05\x12\x0c\n\x04name\x18\x03 \x01(\t\x12\x0e\n\x06\x66\x61mily\x18\x04 \x01(\t\x12\x0f\n\x07is_drum\x18\x05 \x01(\x08\x12+\n\x04\x62\x61rs\x18\x06 \x03(\x0b\x32\x1d.musicaiz.Musa.Instrument.Bar\x12-\n\x05notes\x18\x07 \x03(\x0b\x32\x1e.musicaiz.Musa.Instrument.Note\x1a\xcd\x01\n\x04Note\x12\r\n\x05pitch\x18\x01 \x01(\x05\x12\x12\n\npitch_name\x18\x02 \x01(\t\x12\x11\n\tnote_name\x18\x03 \x01(\t\x12\x0e\n\x06octave\x18\x04 \x01(\t\x12\x0f\n\x07ligated\x18\x05 \x01(\x08\x12\x13\n\x0bstart_ticks\x18\x06 \x01(\x05\x12\x11\n\tend_ticks\x18\x07 \x01(\x05\x12\x11\n\tstart_sec\x18\x08 \x01(\x02\x12\x0f\n\x07\x65nd_sec\x18\t \x01(\x02\x12\x10\n\x08symbolic\x18\n \x01(\t\x12\x10\n\x08velocity\x18\x0b \x01(\x05\x1a\xd0\x03\n\x03\x42\x61r\x12\x0b\n\x03\x62pm\x18\x01 \x01(\x05\x12\x10\n\x08time_sig\x18\x02 \x01(\t\x12\x12\n\nresolution\x18\x03 \x01(\x05\x12\x17\n\x0f\x61\x62solute_timing\x18\x04 \x01(\x08\x12\x14\n\x0cnote_density\x18\x05 \x01(\x05\x12\x18\n\x10harmonic_density\x18\x06 \x01(\x05\x12\x13\n\x0bstart_ticks\x18\x07 \x01(\x05\x12\x11\n\tend_ticks\x18\x08 \x01(\x05\x12\x11\n\tstart_sec\x18\t \x01(\x02\x12\x0f\n\x07\x65nd_sec\x18\n \x01(\x02\x12\x31\n\x05notes\x18\x0b \x03(\x0b\x32\".musicaiz.Musa.Instrument.Bar.Note\x1a\xcd\x01\n\x04Note\x12\r\n\x05pitch\x18\x01 \x01(\x05\x12\x12\n\npitch_name\x18\x02 \x01(\t\x12\x11\n\tnote_name\x18\x03 \x01(\t\x12\x0e\n\x06octave\x18\x04 \x01(\t\x12\x0f\n\x07ligated\x18\x05 \x01(\x08\x12\x13\n\x0bstart_ticks\x18\x06 \x01(\x05\x12\x11\n\tend_ticks\x18\x07 \x01(\x05\x12\x11\n\tstart_sec\x18\x08 \x01(\x02\x12\x0f\n\x07\x65nd_sec\x18\t \x01(\x02\x12\x10\n\x08symbolic\x18\n \x01(\t\x12\x10\n\x08velocity\x18\x0b \x01(\x05\x62\x06proto3') +DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n+musicaiz/converters/protobuf/musicaiz.proto\x12\x08musicaiz\"\xa5\x10\n\x04Musa\x12\x43\n\x16time_signature_changes\x18\x05 \x03(\x0b\x32#.musicaiz.Musa.TimeSignatureChanges\x12\x38\n\x10subdivision_note\x18\x06 \x03(\x0b\x32\x1e.musicaiz.Musa.SubdivisionNote\x12!\n\x04\x66ile\x18\x07 \x03(\x0b\x32\x13.musicaiz.Musa.File\x12,\n\ntotal_bars\x18\x08 \x03(\x0b\x32\x18.musicaiz.Musa.TotalBars\x12)\n\x08tonality\x18\t \x03(\x0b\x32\x17.musicaiz.Musa.Tonality\x12-\n\nresolution\x18\n \x03(\x0b\x32\x19.musicaiz.Musa.Resolution\x12\x30\n\x0cis_quantized\x18\x0b \x03(\x0b\x32\x1a.musicaiz.Musa.IsQuantized\x12\x32\n\rquantize_note\x18\x0c \x03(\x0b\x32\x1b.musicaiz.Musa.QuantizeNote\x12\x36\n\x0f\x61\x62solute_timing\x18\r \x03(\x0b\x32\x1d.musicaiz.Musa.AbsoluteTiming\x12*\n\tcut_notes\x18\x0e \x03(\x0b\x32\x17.musicaiz.Musa.CutNotes\x12\x32\n\rtempo_changes\x18\x0f \x03(\x0b\x32\x1b.musicaiz.Musa.TempoChanges\x12:\n\x11instruments_progs\x18\x10 \x03(\x0b\x32\x1f.musicaiz.Musa.InstrumentsProgs\x12.\n\x0binstruments\x18\x11 \x03(\x0b\x32\x19.musicaiz.Musa.Instrument\x12 \n\x04\x62\x61rs\x18\x12 \x03(\x0b\x32\x12.musicaiz.Musa.Bar\x12\"\n\x05notes\x18\x13 \x03(\x0b\x32\x13.musicaiz.Musa.Note\x12\"\n\x05\x62\x65\x61ts\x18\x14 \x03(\x0b\x32\x13.musicaiz.Musa.Beat\x12(\n\x08subbeats\x18\x15 
\x03(\x0b\x32\x16.musicaiz.Musa.Subbeat\x1a\x16\n\x14TimeSignatureChanges\x1a\x11\n\x0fSubdivisionNote\x1a\x06\n\x04\x46ile\x1a\x0b\n\tTotalBars\x1a\n\n\x08Tonality\x1a\x0c\n\nResolution\x1a\r\n\x0bIsQuantized\x1a\x0e\n\x0cQuantizeNote\x1a\x10\n\x0e\x41\x62soluteTiming\x1a\n\n\x08\x43utNotes\x1a\x0e\n\x0cTempoChanges\x1a\x12\n\x10InstrumentsProgs\x1a`\n\nInstrument\x12\x12\n\ninstrument\x18\x01 \x01(\x05\x12\x0f\n\x07program\x18\x02 \x01(\x05\x12\x0c\n\x04name\x18\x03 \x01(\t\x12\x0e\n\x06\x66\x61mily\x18\x04 \x01(\t\x12\x0f\n\x07is_drum\x18\x05 \x01(\x08\x1a\xb6\x02\n\x04Note\x12\r\n\x05pitch\x18\x01 \x01(\x05\x12\x12\n\npitch_name\x18\x02 \x01(\t\x12\x11\n\tnote_name\x18\x03 \x01(\t\x12\x0e\n\x06octave\x18\x04 \x01(\t\x12\x0f\n\x07ligated\x18\x05 \x01(\x08\x12\x13\n\x0bstart_ticks\x18\x06 \x01(\x05\x12\x11\n\tend_ticks\x18\x07 \x01(\x05\x12\x11\n\tstart_sec\x18\x08 \x01(\x02\x12\x0f\n\x07\x65nd_sec\x18\t \x01(\x02\x12\x10\n\x08symbolic\x18\n \x01(\t\x12\x10\n\x08velocity\x18\x0b \x01(\x05\x12\x0f\n\x07\x62\x61r_idx\x18\x0c \x01(\x05\x12\x10\n\x08\x62\x65\x61t_idx\x18\r \x01(\x05\x12\x13\n\x0bsubbeat_idx\x18\x0e \x01(\x05\x12\x16\n\x0einstrument_idx\x18\x0f \x01(\x05\x12\x17\n\x0finstrument_prog\x18\x10 \x01(\x05\x1a\xcd\x01\n\x03\x42\x61r\x12\x0b\n\x03\x62pm\x18\x01 \x01(\x02\x12\x10\n\x08time_sig\x18\x02 \x01(\t\x12\x12\n\nresolution\x18\x03 \x01(\x05\x12\x17\n\x0f\x61\x62solute_timing\x18\x04 \x01(\x08\x12\x14\n\x0cnote_density\x18\x05 \x01(\x05\x12\x18\n\x10harmonic_density\x18\x06 \x01(\x05\x12\x13\n\x0bstart_ticks\x18\x07 \x01(\x05\x12\x11\n\tend_ticks\x18\x08 \x01(\x05\x12\x11\n\tstart_sec\x18\t \x01(\x02\x12\x0f\n\x07\x65nd_sec\x18\n \x01(\x02\x1a\xc3\x01\n\x04\x42\x65\x61t\x12\x0b\n\x03\x62pm\x18\x01 \x01(\x02\x12\x10\n\x08time_sig\x18\x02 \x01(\t\x12\x12\n\nresolution\x18\x03 \x01(\x05\x12\x17\n\x0f\x61\x62solute_timing\x18\x04 \x01(\x08\x12\x13\n\x0bstart_ticks\x18\x07 \x01(\x05\x12\x11\n\tend_ticks\x18\x08 \x01(\x05\x12\x11\n\tstart_sec\x18\t \x01(\x02\x12\x0f\n\x07\x65nd_sec\x18\n \x01(\x02\x12\x12\n\nglobal_idx\x18\x0b \x01(\x05\x12\x0f\n\x07\x62\x61r_idx\x18\x0c \x01(\x05\x1a\x88\x02\n\x07Subbeat\x12\x0b\n\x03\x62pm\x18\x01 \x01(\x02\x12\x10\n\x08time_sig\x18\x02 \x01(\t\x12\x12\n\nresolution\x18\x03 \x01(\x05\x12\x17\n\x0f\x61\x62solute_timing\x18\x04 \x01(\x08\x12\x14\n\x0cnote_density\x18\x05 \x01(\x05\x12\x18\n\x10harmonic_density\x18\x06 \x01(\x05\x12\x13\n\x0bstart_ticks\x18\x07 \x01(\x05\x12\x11\n\tend_ticks\x18\x08 \x01(\x05\x12\x11\n\tstart_sec\x18\t \x01(\x02\x12\x0f\n\x07\x65nd_sec\x18\n \x01(\x02\x12\x12\n\nglobal_idx\x18\x0b \x01(\x05\x12\x0f\n\x07\x62\x61r_idx\x18\x0c \x01(\x05\x12\x10\n\x08\x62\x65\x61t_idx\x18\r \x01(\x05\x62\x06proto3') _builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, globals()) _builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'musicaiz.converters.protobuf.musicaiz_pb2', globals()) @@ -21,19 +21,39 @@ DESCRIPTOR._options = None _MUSA._serialized_start=58 - _MUSA._serialized_end=1098 - _MUSA_NOTE._serialized_start=169 - _MUSA_NOTE._serialized_end=175 - _MUSA_TIMESIGNATURE._serialized_start=177 - _MUSA_TIMESIGNATURE._serialized_end=220 - _MUSA_TONALITY._serialized_start=222 - _MUSA_TONALITY._serialized_end=232 - _MUSA_INSTRUMENT._serialized_start=235 - _MUSA_INSTRUMENT._serialized_end=1098 - _MUSA_INSTRUMENT_NOTE._serialized_start=426 - _MUSA_INSTRUMENT_NOTE._serialized_end=631 - _MUSA_INSTRUMENT_BAR._serialized_start=634 - _MUSA_INSTRUMENT_BAR._serialized_end=1098 - _MUSA_INSTRUMENT_BAR_NOTE._serialized_start=426 - 
_MUSA_INSTRUMENT_BAR_NOTE._serialized_end=631 + _MUSA._serialized_end=2143 + _MUSA_TIMESIGNATURECHANGES._serialized_start=874 + _MUSA_TIMESIGNATURECHANGES._serialized_end=896 + _MUSA_SUBDIVISIONNOTE._serialized_start=898 + _MUSA_SUBDIVISIONNOTE._serialized_end=915 + _MUSA_FILE._serialized_start=917 + _MUSA_FILE._serialized_end=923 + _MUSA_TOTALBARS._serialized_start=925 + _MUSA_TOTALBARS._serialized_end=936 + _MUSA_TONALITY._serialized_start=938 + _MUSA_TONALITY._serialized_end=948 + _MUSA_RESOLUTION._serialized_start=950 + _MUSA_RESOLUTION._serialized_end=962 + _MUSA_ISQUANTIZED._serialized_start=964 + _MUSA_ISQUANTIZED._serialized_end=977 + _MUSA_QUANTIZENOTE._serialized_start=979 + _MUSA_QUANTIZENOTE._serialized_end=993 + _MUSA_ABSOLUTETIMING._serialized_start=995 + _MUSA_ABSOLUTETIMING._serialized_end=1011 + _MUSA_CUTNOTES._serialized_start=1013 + _MUSA_CUTNOTES._serialized_end=1023 + _MUSA_TEMPOCHANGES._serialized_start=1025 + _MUSA_TEMPOCHANGES._serialized_end=1039 + _MUSA_INSTRUMENTSPROGS._serialized_start=1041 + _MUSA_INSTRUMENTSPROGS._serialized_end=1059 + _MUSA_INSTRUMENT._serialized_start=1061 + _MUSA_INSTRUMENT._serialized_end=1157 + _MUSA_NOTE._serialized_start=1160 + _MUSA_NOTE._serialized_end=1470 + _MUSA_BAR._serialized_start=1473 + _MUSA_BAR._serialized_end=1678 + _MUSA_BEAT._serialized_start=1681 + _MUSA_BEAT._serialized_end=1876 + _MUSA_SUBBEAT._serialized_start=1879 + _MUSA_SUBBEAT._serialized_end=2143 # @@protoc_insertion_point(module_scope) diff --git a/musicaiz/converters/protobuf/musicaiz_pb2.pyi b/musicaiz/converters/protobuf/musicaiz_pb2.pyi index e84f25f..75f8657 100644 --- a/musicaiz/converters/protobuf/musicaiz_pb2.pyi +++ b/musicaiz/converters/protobuf/musicaiz_pb2.pyi @@ -6,114 +6,198 @@ from typing import ClassVar as _ClassVar, Iterable as _Iterable, Mapping as _Map DESCRIPTOR: _descriptor.FileDescriptor class Musa(_message.Message): - __slots__ = ["instruments", "time_signatures"] + __slots__ = ["absolute_timing", "bars", "beats", "cut_notes", "file", "instruments", "instruments_progs", "is_quantized", "notes", "quantize_note", "resolution", "subbeats", "subdivision_note", "tempo_changes", "time_signature_changes", "tonality", "total_bars"] + class AbsoluteTiming(_message.Message): + __slots__ = [] + def __init__(self) -> None: ... + class Bar(_message.Message): + __slots__ = ["absolute_timing", "bpm", "end_sec", "end_ticks", "harmonic_density", "note_density", "resolution", "start_sec", "start_ticks", "time_sig"] + ABSOLUTE_TIMING_FIELD_NUMBER: _ClassVar[int] + BPM_FIELD_NUMBER: _ClassVar[int] + END_SEC_FIELD_NUMBER: _ClassVar[int] + END_TICKS_FIELD_NUMBER: _ClassVar[int] + HARMONIC_DENSITY_FIELD_NUMBER: _ClassVar[int] + NOTE_DENSITY_FIELD_NUMBER: _ClassVar[int] + RESOLUTION_FIELD_NUMBER: _ClassVar[int] + START_SEC_FIELD_NUMBER: _ClassVar[int] + START_TICKS_FIELD_NUMBER: _ClassVar[int] + TIME_SIG_FIELD_NUMBER: _ClassVar[int] + absolute_timing: bool + bpm: float + end_sec: float + end_ticks: int + harmonic_density: int + note_density: int + resolution: int + start_sec: float + start_ticks: int + time_sig: str + def __init__(self, bpm: _Optional[float] = ..., time_sig: _Optional[str] = ..., resolution: _Optional[int] = ..., absolute_timing: bool = ..., note_density: _Optional[int] = ..., harmonic_density: _Optional[int] = ..., start_ticks: _Optional[int] = ..., end_ticks: _Optional[int] = ..., start_sec: _Optional[float] = ..., end_sec: _Optional[float] = ...) -> None: ... 
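A note on the regenerated stubs: Bar, Beat, Note and Subbeat are now top-level Musa messages rather than being nested under Instrument, so consumers build and read them flat and correlate entities through the *_idx fields. A minimal construction sketch (values are illustrative):

    from musicaiz.converters.protobuf import musicaiz_pb2

    proto = musicaiz_pb2.Musa()
    bar = proto.bars.add()            # bars hang directly off Musa now
    bar.time_sig = "4/4"
    bar.bpm = 120.0
    note = proto.notes.add()          # notes too, pointing back by index
    note.pitch = 60
    note.bar_idx = 0
    note.instrument_prog = 0          # instead of nesting under an Instrument message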
+ class Beat(_message.Message): + __slots__ = ["absolute_timing", "bar_idx", "bpm", "end_sec", "end_ticks", "global_idx", "resolution", "start_sec", "start_ticks", "time_sig"] + ABSOLUTE_TIMING_FIELD_NUMBER: _ClassVar[int] + BAR_IDX_FIELD_NUMBER: _ClassVar[int] + BPM_FIELD_NUMBER: _ClassVar[int] + END_SEC_FIELD_NUMBER: _ClassVar[int] + END_TICKS_FIELD_NUMBER: _ClassVar[int] + GLOBAL_IDX_FIELD_NUMBER: _ClassVar[int] + RESOLUTION_FIELD_NUMBER: _ClassVar[int] + START_SEC_FIELD_NUMBER: _ClassVar[int] + START_TICKS_FIELD_NUMBER: _ClassVar[int] + TIME_SIG_FIELD_NUMBER: _ClassVar[int] + absolute_timing: bool + bar_idx: int + bpm: float + end_sec: float + end_ticks: int + global_idx: int + resolution: int + start_sec: float + start_ticks: int + time_sig: str + def __init__(self, bpm: _Optional[float] = ..., time_sig: _Optional[str] = ..., resolution: _Optional[int] = ..., absolute_timing: bool = ..., start_ticks: _Optional[int] = ..., end_ticks: _Optional[int] = ..., start_sec: _Optional[float] = ..., end_sec: _Optional[float] = ..., global_idx: _Optional[int] = ..., bar_idx: _Optional[int] = ...) -> None: ... + class CutNotes(_message.Message): + __slots__ = [] + def __init__(self) -> None: ... + class File(_message.Message): + __slots__ = [] + def __init__(self) -> None: ... class Instrument(_message.Message): - __slots__ = ["bars", "family", "instrument", "is_drum", "name", "notes", "program"] - class Bar(_message.Message): - __slots__ = ["absolute_timing", "bpm", "end_sec", "end_ticks", "harmonic_density", "note_density", "notes", "resolution", "start_sec", "start_ticks", "time_sig"] - class Note(_message.Message): - __slots__ = ["end_sec", "end_ticks", "ligated", "note_name", "octave", "pitch", "pitch_name", "start_sec", "start_ticks", "symbolic", "velocity"] - END_SEC_FIELD_NUMBER: _ClassVar[int] - END_TICKS_FIELD_NUMBER: _ClassVar[int] - LIGATED_FIELD_NUMBER: _ClassVar[int] - NOTE_NAME_FIELD_NUMBER: _ClassVar[int] - OCTAVE_FIELD_NUMBER: _ClassVar[int] - PITCH_FIELD_NUMBER: _ClassVar[int] - PITCH_NAME_FIELD_NUMBER: _ClassVar[int] - START_SEC_FIELD_NUMBER: _ClassVar[int] - START_TICKS_FIELD_NUMBER: _ClassVar[int] - SYMBOLIC_FIELD_NUMBER: _ClassVar[int] - VELOCITY_FIELD_NUMBER: _ClassVar[int] - end_sec: float - end_ticks: int - ligated: bool - note_name: str - octave: str - pitch: int - pitch_name: str - start_sec: float - start_ticks: int - symbolic: str - velocity: int - def __init__(self, pitch: _Optional[int] = ..., pitch_name: _Optional[str] = ..., note_name: _Optional[str] = ..., octave: _Optional[str] = ..., ligated: bool = ..., start_ticks: _Optional[int] = ..., end_ticks: _Optional[int] = ..., start_sec: _Optional[float] = ..., end_sec: _Optional[float] = ..., symbolic: _Optional[str] = ..., velocity: _Optional[int] = ...) -> None: ... 
- ABSOLUTE_TIMING_FIELD_NUMBER: _ClassVar[int] - BPM_FIELD_NUMBER: _ClassVar[int] - END_SEC_FIELD_NUMBER: _ClassVar[int] - END_TICKS_FIELD_NUMBER: _ClassVar[int] - HARMONIC_DENSITY_FIELD_NUMBER: _ClassVar[int] - NOTES_FIELD_NUMBER: _ClassVar[int] - NOTE_DENSITY_FIELD_NUMBER: _ClassVar[int] - RESOLUTION_FIELD_NUMBER: _ClassVar[int] - START_SEC_FIELD_NUMBER: _ClassVar[int] - START_TICKS_FIELD_NUMBER: _ClassVar[int] - TIME_SIG_FIELD_NUMBER: _ClassVar[int] - absolute_timing: bool - bpm: int - end_sec: float - end_ticks: int - harmonic_density: int - note_density: int - notes: _containers.RepeatedCompositeFieldContainer[Musa.Instrument.Bar.Note] - resolution: int - start_sec: float - start_ticks: int - time_sig: str - def __init__(self, bpm: _Optional[int] = ..., time_sig: _Optional[str] = ..., resolution: _Optional[int] = ..., absolute_timing: bool = ..., note_density: _Optional[int] = ..., harmonic_density: _Optional[int] = ..., start_ticks: _Optional[int] = ..., end_ticks: _Optional[int] = ..., start_sec: _Optional[float] = ..., end_sec: _Optional[float] = ..., notes: _Optional[_Iterable[_Union[Musa.Instrument.Bar.Note, _Mapping]]] = ...) -> None: ... - class Note(_message.Message): - __slots__ = ["end_sec", "end_ticks", "ligated", "note_name", "octave", "pitch", "pitch_name", "start_sec", "start_ticks", "symbolic", "velocity"] - END_SEC_FIELD_NUMBER: _ClassVar[int] - END_TICKS_FIELD_NUMBER: _ClassVar[int] - LIGATED_FIELD_NUMBER: _ClassVar[int] - NOTE_NAME_FIELD_NUMBER: _ClassVar[int] - OCTAVE_FIELD_NUMBER: _ClassVar[int] - PITCH_FIELD_NUMBER: _ClassVar[int] - PITCH_NAME_FIELD_NUMBER: _ClassVar[int] - START_SEC_FIELD_NUMBER: _ClassVar[int] - START_TICKS_FIELD_NUMBER: _ClassVar[int] - SYMBOLIC_FIELD_NUMBER: _ClassVar[int] - VELOCITY_FIELD_NUMBER: _ClassVar[int] - end_sec: float - end_ticks: int - ligated: bool - note_name: str - octave: str - pitch: int - pitch_name: str - start_sec: float - start_ticks: int - symbolic: str - velocity: int - def __init__(self, pitch: _Optional[int] = ..., pitch_name: _Optional[str] = ..., note_name: _Optional[str] = ..., octave: _Optional[str] = ..., ligated: bool = ..., start_ticks: _Optional[int] = ..., end_ticks: _Optional[int] = ..., start_sec: _Optional[float] = ..., end_sec: _Optional[float] = ..., symbolic: _Optional[str] = ..., velocity: _Optional[int] = ...) -> None: ... - BARS_FIELD_NUMBER: _ClassVar[int] + __slots__ = ["family", "instrument", "is_drum", "name", "program"] FAMILY_FIELD_NUMBER: _ClassVar[int] INSTRUMENT_FIELD_NUMBER: _ClassVar[int] IS_DRUM_FIELD_NUMBER: _ClassVar[int] NAME_FIELD_NUMBER: _ClassVar[int] - NOTES_FIELD_NUMBER: _ClassVar[int] PROGRAM_FIELD_NUMBER: _ClassVar[int] - bars: _containers.RepeatedCompositeFieldContainer[Musa.Instrument.Bar] family: str instrument: int is_drum: bool name: str - notes: _containers.RepeatedCompositeFieldContainer[Musa.Instrument.Note] program: int - def __init__(self, instrument: _Optional[int] = ..., program: _Optional[int] = ..., name: _Optional[str] = ..., family: _Optional[str] = ..., is_drum: bool = ..., bars: _Optional[_Iterable[_Union[Musa.Instrument.Bar, _Mapping]]] = ..., notes: _Optional[_Iterable[_Union[Musa.Instrument.Note, _Mapping]]] = ...) -> None: ... + def __init__(self, instrument: _Optional[int] = ..., program: _Optional[int] = ..., name: _Optional[str] = ..., family: _Optional[str] = ..., is_drum: bool = ...) -> None: ... + class InstrumentsProgs(_message.Message): + __slots__ = [] + def __init__(self) -> None: ... 
+ class IsQuantized(_message.Message): + __slots__ = [] + def __init__(self) -> None: ... class Note(_message.Message): + __slots__ = ["bar_idx", "beat_idx", "end_sec", "end_ticks", "instrument_idx", "instrument_prog", "ligated", "note_name", "octave", "pitch", "pitch_name", "start_sec", "start_ticks", "subbeat_idx", "symbolic", "velocity"] + BAR_IDX_FIELD_NUMBER: _ClassVar[int] + BEAT_IDX_FIELD_NUMBER: _ClassVar[int] + END_SEC_FIELD_NUMBER: _ClassVar[int] + END_TICKS_FIELD_NUMBER: _ClassVar[int] + INSTRUMENT_IDX_FIELD_NUMBER: _ClassVar[int] + INSTRUMENT_PROG_FIELD_NUMBER: _ClassVar[int] + LIGATED_FIELD_NUMBER: _ClassVar[int] + NOTE_NAME_FIELD_NUMBER: _ClassVar[int] + OCTAVE_FIELD_NUMBER: _ClassVar[int] + PITCH_FIELD_NUMBER: _ClassVar[int] + PITCH_NAME_FIELD_NUMBER: _ClassVar[int] + START_SEC_FIELD_NUMBER: _ClassVar[int] + START_TICKS_FIELD_NUMBER: _ClassVar[int] + SUBBEAT_IDX_FIELD_NUMBER: _ClassVar[int] + SYMBOLIC_FIELD_NUMBER: _ClassVar[int] + VELOCITY_FIELD_NUMBER: _ClassVar[int] + bar_idx: int + beat_idx: int + end_sec: float + end_ticks: int + instrument_idx: int + instrument_prog: int + ligated: bool + note_name: str + octave: str + pitch: int + pitch_name: str + start_sec: float + start_ticks: int + subbeat_idx: int + symbolic: str + velocity: int + def __init__(self, pitch: _Optional[int] = ..., pitch_name: _Optional[str] = ..., note_name: _Optional[str] = ..., octave: _Optional[str] = ..., ligated: bool = ..., start_ticks: _Optional[int] = ..., end_ticks: _Optional[int] = ..., start_sec: _Optional[float] = ..., end_sec: _Optional[float] = ..., symbolic: _Optional[str] = ..., velocity: _Optional[int] = ..., bar_idx: _Optional[int] = ..., beat_idx: _Optional[int] = ..., subbeat_idx: _Optional[int] = ..., instrument_idx: _Optional[int] = ..., instrument_prog: _Optional[int] = ...) -> None: ... + class QuantizeNote(_message.Message): + __slots__ = [] + def __init__(self) -> None: ... + class Resolution(_message.Message): + __slots__ = [] + def __init__(self) -> None: ... + class Subbeat(_message.Message): + __slots__ = ["absolute_timing", "bar_idx", "beat_idx", "bpm", "end_sec", "end_ticks", "global_idx", "harmonic_density", "note_density", "resolution", "start_sec", "start_ticks", "time_sig"] + ABSOLUTE_TIMING_FIELD_NUMBER: _ClassVar[int] + BAR_IDX_FIELD_NUMBER: _ClassVar[int] + BEAT_IDX_FIELD_NUMBER: _ClassVar[int] + BPM_FIELD_NUMBER: _ClassVar[int] + END_SEC_FIELD_NUMBER: _ClassVar[int] + END_TICKS_FIELD_NUMBER: _ClassVar[int] + GLOBAL_IDX_FIELD_NUMBER: _ClassVar[int] + HARMONIC_DENSITY_FIELD_NUMBER: _ClassVar[int] + NOTE_DENSITY_FIELD_NUMBER: _ClassVar[int] + RESOLUTION_FIELD_NUMBER: _ClassVar[int] + START_SEC_FIELD_NUMBER: _ClassVar[int] + START_TICKS_FIELD_NUMBER: _ClassVar[int] + TIME_SIG_FIELD_NUMBER: _ClassVar[int] + absolute_timing: bool + bar_idx: int + beat_idx: int + bpm: float + end_sec: float + end_ticks: int + global_idx: int + harmonic_density: int + note_density: int + resolution: int + start_sec: float + start_ticks: int + time_sig: str + def __init__(self, bpm: _Optional[float] = ..., time_sig: _Optional[str] = ..., resolution: _Optional[int] = ..., absolute_timing: bool = ..., note_density: _Optional[int] = ..., harmonic_density: _Optional[int] = ..., start_ticks: _Optional[int] = ..., end_ticks: _Optional[int] = ..., start_sec: _Optional[float] = ..., end_sec: _Optional[float] = ..., global_idx: _Optional[int] = ..., bar_idx: _Optional[int] = ..., beat_idx: _Optional[int] = ...) -> None: ... 
+ class SubdivisionNote(_message.Message): + __slots__ = [] + def __init__(self) -> None: ... + class TempoChanges(_message.Message): + __slots__ = [] + def __init__(self) -> None: ... + class TimeSignatureChanges(_message.Message): __slots__ = [] def __init__(self) -> None: ... - class TimeSignature(_message.Message): - __slots__ = ["denom", "num"] - DENOM_FIELD_NUMBER: _ClassVar[int] - NUM_FIELD_NUMBER: _ClassVar[int] - denom: int - num: int - def __init__(self, num: _Optional[int] = ..., denom: _Optional[int] = ...) -> None: ... class Tonality(_message.Message): __slots__ = [] def __init__(self) -> None: ... + class TotalBars(_message.Message): + __slots__ = [] + def __init__(self) -> None: ... + ABSOLUTE_TIMING_FIELD_NUMBER: _ClassVar[int] + BARS_FIELD_NUMBER: _ClassVar[int] + BEATS_FIELD_NUMBER: _ClassVar[int] + CUT_NOTES_FIELD_NUMBER: _ClassVar[int] + FILE_FIELD_NUMBER: _ClassVar[int] INSTRUMENTS_FIELD_NUMBER: _ClassVar[int] - TIME_SIGNATURES_FIELD_NUMBER: _ClassVar[int] + INSTRUMENTS_PROGS_FIELD_NUMBER: _ClassVar[int] + IS_QUANTIZED_FIELD_NUMBER: _ClassVar[int] + NOTES_FIELD_NUMBER: _ClassVar[int] + QUANTIZE_NOTE_FIELD_NUMBER: _ClassVar[int] + RESOLUTION_FIELD_NUMBER: _ClassVar[int] + SUBBEATS_FIELD_NUMBER: _ClassVar[int] + SUBDIVISION_NOTE_FIELD_NUMBER: _ClassVar[int] + TEMPO_CHANGES_FIELD_NUMBER: _ClassVar[int] + TIME_SIGNATURE_CHANGES_FIELD_NUMBER: _ClassVar[int] + TONALITY_FIELD_NUMBER: _ClassVar[int] + TOTAL_BARS_FIELD_NUMBER: _ClassVar[int] + absolute_timing: _containers.RepeatedCompositeFieldContainer[Musa.AbsoluteTiming] + bars: _containers.RepeatedCompositeFieldContainer[Musa.Bar] + beats: _containers.RepeatedCompositeFieldContainer[Musa.Beat] + cut_notes: _containers.RepeatedCompositeFieldContainer[Musa.CutNotes] + file: _containers.RepeatedCompositeFieldContainer[Musa.File] instruments: _containers.RepeatedCompositeFieldContainer[Musa.Instrument] - time_signatures: _containers.RepeatedCompositeFieldContainer[Musa.TimeSignature] - def __init__(self, time_signatures: _Optional[_Iterable[_Union[Musa.TimeSignature, _Mapping]]] = ..., instruments: _Optional[_Iterable[_Union[Musa.Instrument, _Mapping]]] = ...) -> None: ... 
+ instruments_progs: _containers.RepeatedCompositeFieldContainer[Musa.InstrumentsProgs] + is_quantized: _containers.RepeatedCompositeFieldContainer[Musa.IsQuantized] + notes: _containers.RepeatedCompositeFieldContainer[Musa.Note] + quantize_note: _containers.RepeatedCompositeFieldContainer[Musa.QuantizeNote] + resolution: _containers.RepeatedCompositeFieldContainer[Musa.Resolution] + subbeats: _containers.RepeatedCompositeFieldContainer[Musa.Subbeat] + subdivision_note: _containers.RepeatedCompositeFieldContainer[Musa.SubdivisionNote] + tempo_changes: _containers.RepeatedCompositeFieldContainer[Musa.TempoChanges] + time_signature_changes: _containers.RepeatedCompositeFieldContainer[Musa.TimeSignatureChanges] + tonality: _containers.RepeatedCompositeFieldContainer[Musa.Tonality] + total_bars: _containers.RepeatedCompositeFieldContainer[Musa.TotalBars] + def __init__(self, time_signature_changes: _Optional[_Iterable[_Union[Musa.TimeSignatureChanges, _Mapping]]] = ..., subdivision_note: _Optional[_Iterable[_Union[Musa.SubdivisionNote, _Mapping]]] = ..., file: _Optional[_Iterable[_Union[Musa.File, _Mapping]]] = ..., total_bars: _Optional[_Iterable[_Union[Musa.TotalBars, _Mapping]]] = ..., tonality: _Optional[_Iterable[_Union[Musa.Tonality, _Mapping]]] = ..., resolution: _Optional[_Iterable[_Union[Musa.Resolution, _Mapping]]] = ..., is_quantized: _Optional[_Iterable[_Union[Musa.IsQuantized, _Mapping]]] = ..., quantize_note: _Optional[_Iterable[_Union[Musa.QuantizeNote, _Mapping]]] = ..., absolute_timing: _Optional[_Iterable[_Union[Musa.AbsoluteTiming, _Mapping]]] = ..., cut_notes: _Optional[_Iterable[_Union[Musa.CutNotes, _Mapping]]] = ..., tempo_changes: _Optional[_Iterable[_Union[Musa.TempoChanges, _Mapping]]] = ..., instruments_progs: _Optional[_Iterable[_Union[Musa.InstrumentsProgs, _Mapping]]] = ..., instruments: _Optional[_Iterable[_Union[Musa.Instrument, _Mapping]]] = ..., bars: _Optional[_Iterable[_Union[Musa.Bar, _Mapping]]] = ..., notes: _Optional[_Iterable[_Union[Musa.Note, _Mapping]]] = ..., beats: _Optional[_Iterable[_Union[Musa.Beat, _Mapping]]] = ..., subbeats: _Optional[_Iterable[_Union[Musa.Subbeat, _Mapping]]] = ...) -> None: ... 
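To round off the converter changes, a usage sketch of the protobuf path under the new flat schema (illustrative; it assumes the parsed file contains at least one note):

    from musicaiz.loaders import Musa
    from musicaiz.converters import musa_to_proto

    midi = Musa("tests/fixtures/midis/midi_changes.mid")   # example MIDI from this patch
    proto = musa_to_proto(midi)

    # The repeated fields mirror the refactored Musa object one-to-one
    print(len(proto.instruments), len(proto.bars), len(proto.beats), len(proto.subbeats))
    print(proto.notes[0].pitch, proto.notes[0].bar_idx, proto.notes[0].instrument_prog)

    payload = proto.SerializeToString()   # plain protobuf binary, readable from any language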
diff --git a/musicaiz/datasets/jsbchorales.py b/musicaiz/datasets/jsbchorales.py index 1a83f96..d727dfc 100644 --- a/musicaiz/datasets/jsbchorales.py +++ b/musicaiz/datasets/jsbchorales.py @@ -80,3 +80,30 @@ def tokenize( args, True ) + + +# TODO: args parsing here +if __name__ == "__main__": + args = MMMTokenizerArguments( + prev_tokens="", + windowing=True, + time_unit="HUNDRED_TWENTY_EIGHT", + num_programs=None, + shuffle_tracks=True, + track_density=False, + window_size=32, + hop_length=16, + time_sig=True, + velocity=True, + ) + dataset = JSBChorales() + dataset.tokenize( + dataset_path="H:/INVESTIGACION/Datasets/JSB Chorales/", + output_path="H:/GitHub/musanalysis-datasets/jsbchorales/mmm/32_bars_166", + output_file="token-sequences", + args=args, + tokenize_split="validation" + ) + vocab = MMMTokenizer.get_vocabulary( + dataset_path="H:/GitHub/musanalysis-datasets/jsbchorales/mmm/32_bars_166" + ) diff --git a/musicaiz/datasets/lmd.py b/musicaiz/datasets/lmd.py index e91ba42..8919842 100644 --- a/musicaiz/datasets/lmd.py +++ b/musicaiz/datasets/lmd.py @@ -163,3 +163,30 @@ def get_metadata( } ) return composers_json + + +# TODO: args parsing here +if __name__ == "__main__": + args = MMMTokenizerArguments( + prev_tokens="", + windowing=True, + time_unit="HUNDRED_TWENTY_EIGHT", + num_programs=None, + shuffle_tracks=True, + track_density=False, + window_size=32, + hop_length=16, + time_sig=True, + velocity=True, + ) + dataset = LakhMIDI() + dataset.tokenize( + dataset_path="H:/INVESTIGACION/Datasets/LMD/clean_midi", + output_path="H:/GitHub/musanalysis-datasets/lmd/mmm/32_bars_166", + output_file="token-sequences", + args=args, + tokenize_split="validation" + ) + vocab = MMMTokenizer.get_vocabulary( + dataset_path="H:/GitHub/musanalysis-datasets/lmd/mmm/32_bars_166" + ) diff --git a/musicaiz/datasets/maestro.py b/musicaiz/datasets/maestro.py index 47c1476..0645fc6 100644 --- a/musicaiz/datasets/maestro.py +++ b/musicaiz/datasets/maestro.py @@ -73,7 +73,7 @@ class ComposerPeriods(Enum): LEOS_JANACEK = "ROMANTICISM" LUDWIG_VAN_BEETHOVEN = "CLASSICISM" MIKHAIL_GLINKA = "ROMANTICISM" - MILY_BALARIKEV = "ROMANTICISM" + MILY_BALARIKEV = "ROMANTICISM" # TODO: MILY_BALAKIREV MODEST_MUSSORGSKY = "ROMANTICISM" MUZIO_CLEMENTI = "CLASSICISM" NICCOLO_PAGANINI = "CLASSICISM" @@ -199,3 +199,31 @@ def get_metadata(dataset_path: Union[str, Path]) -> Dict[str, str]: } ) return composers_json + + +# TODO: args parsing here +if __name__ == "__main__": + args = MMMTokenizerArguments( + prev_tokens="", + windowing=True, + time_unit="HUNDRED_TWENTY_EIGHT", + num_programs=None, + shuffle_tracks=True, + track_density=False, + window_size=32, + hop_length=16, + time_sig=True, + velocity=True, + tempo=True, + ) + dataset = Maestro() + dataset.tokenize( + dataset_path="H:/INVESTIGACION/Datasets/MAESTRO/", + output_path="H:/GitHub/musanalysis-datasets/maestro/mmm/32_bars_16", + output_file="token-sequences", + args=args, + tokenize_split="all" + ) + vocab = MMMTokenizer.get_vocabulary( + dataset_path="H:/GitHub/musanalysis-datasets/maestro/mmm/32_bars_16" + ) diff --git a/musicaiz/datasets/utils.py b/musicaiz/datasets/utils.py index 4466bb8..3c49e8e 100644 --- a/musicaiz/datasets/utils.py +++ b/musicaiz/datasets/utils.py @@ -8,6 +8,8 @@ from musicaiz.tokenizers import ( MMMTokenizer, MMMTokenizerArguments, + REMITokenizer, + REMITokenizerArguments, TOKENIZER_ARGUMENTS, TokenizerArguments, ) @@ -33,7 +35,7 @@ def tokenize_path( else: elements = dataset_path.rglob("*.mid") elements = [f.name for f in 
dataset_path.rglob("*.mid")] - total = len(list(dataset_path.glob("*.mid"))) + total = len(elements) for el in elements: # Some files in LMD hace errors (OSError: data byte must be in range 0..127), @@ -111,13 +113,17 @@ def _processer( if type(args) not in TOKENIZER_ARGUMENTS: raise ValueError("Non valid tokenizer args object.") if isinstance(args, MMMTokenizerArguments): - args.prev_tokens=prev_tokens + args.prev_tokens = prev_tokens tokenizer = MMMTokenizer(file, args) piece_tokens = tokenizer.tokenize_file() + elif isinstance(args, REMITokenizerArguments): + args.prev_tokens = prev_tokens + tokenizer = REMITokenizer(file, args) + piece_tokens = tokenizer.tokenize_file() else: raise ValueError("Non valid tokenizer.") piece_tokens += "\n" return piece_tokens except: - pass \ No newline at end of file + pass diff --git a/musicaiz/features/harmony.py b/musicaiz/features/harmony.py index 645f4b3..3d55725 100644 --- a/musicaiz/features/harmony.py +++ b/musicaiz/features/harmony.py @@ -18,7 +18,6 @@ Note, NoteClassBase, ) -from musicaiz import loaders def _extract_note_positions(note_seq: List[Note]) -> List[int]: @@ -324,15 +323,15 @@ def get_harmonic_density(note_seq: List[Note]) -> int: if len(note_seq) == 0: return 0 # Go tick per tick - latest_note = loaders.Musa._last_note(note_seq) + latest_note = note_seq[-1] counts = [] - step_ticks = 1 + step_ticks = 10 # We'll compute by steps of 10 ticks which is a low value for i in range(0, latest_note.end_ticks, step_ticks): count = 0 for note_idx, note in enumerate(note_seq): # if note ends aftre the next step start, count it - if note.start_ticks < step_ticks * i and note.end_ticks >= step_ticks * i: + if note.start_ticks < i and note.end_ticks >= i: count += 1 counts.append(count) return max(counts) diff --git a/musicaiz/harmony/chords.py b/musicaiz/harmony/chords.py index 49ae9b4..fb5f578 100644 --- a/musicaiz/harmony/chords.py +++ b/musicaiz/harmony/chords.py @@ -251,7 +251,7 @@ def get_all_qualities(cls) -> List[str]: for n in note.value: all_notes.append(n) return all_notes - + # TODO: Finish this @classmethod def get_chord_from_name(cls, chord_name: str) -> AllChords: @@ -337,7 +337,7 @@ def _check_inversion_with_quality(self, inversion: int): - Only higher than 13th chords do have a 6th inversion""" if self.type.value - 1 < inversion: raise ValueError(f"Chord quality {self.quality_name} does not have a {inversion} inversion.") - + @classmethod def get_all_chords(cls) -> Dict[str, Tuple[NoteClassBase, AllChords]]: """ @@ -362,7 +362,7 @@ def get_all_chords(cls) -> Dict[str, Tuple[NoteClassBase, AllChords]]: if not any(e in chord.name for e in exclude): all_chords.update({tonic.name + "-" + chord.name: (tonic, chord)}) return all_chords - + @classmethod def get_notes_from_chord( cls, @@ -374,7 +374,7 @@ def get_notes_from_chord( Parameters ---------- - + tonic: NoteClassBase quality: AllChords @@ -382,7 +382,7 @@ def get_notes_from_chord( Returns ------- - + List[NoteClassBase] """ notes = [tonic] @@ -394,7 +394,7 @@ def get_notes_from_chord( note_dest_obj = interval_inst.transpose_note(note_obj) notes.append(note_dest_obj.note) return notes - + @classmethod def get_pitches_from_chord( cls, @@ -417,7 +417,7 @@ def get_pitches_from_chord( Returns ------- - + List[int] """ notes = cls.get_notes_from_chord(tonic, quality) diff --git a/musicaiz/harmony/keys.py b/musicaiz/harmony/keys.py index d965f27..fc82740 100644 --- a/musicaiz/harmony/keys.py +++ b/musicaiz/harmony/keys.py @@ -1358,7 +1358,7 @@ def scale_notes(self, scale: str) -> 
List[NoteClassBase]: a submode. This is only used in the case of minor scales (harmonic or melodic) and greek scales. The values that support the scales arg are: :func:`~musicaiz.harmony.Scales`. - + Examples -------- Major tonalities: @@ -1371,7 +1371,7 @@ def scale_notes(self, scale: str) -> List[NoteClassBase]: >>> tonality.scale_notes("IONIAN") Minor tonalities: - + >>> tonality = Tonality.C_MINOR >>> tonality.scale_notes("NATURAL") >>> tonality.scale_notes("HARMONIC") diff --git a/musicaiz/loaders.py b/musicaiz/loaders.py index fd6d73a..1ace13b 100644 --- a/musicaiz/loaders.py +++ b/musicaiz/loaders.py @@ -20,14 +20,12 @@ import mido import functools import numpy as np +from traitlets import Callable # Our modules -from musicaiz.structure import ( - Note, - Instrument, - Bar -) +from musicaiz.structure import Note, Instrument, Bar +from musicaiz.errors import BarIdxErrorMessage from musicaiz.rhythm import ( TimingConsts, get_subdivisions, @@ -36,13 +34,15 @@ advanced_quantizer, get_ticks_from_subdivision, ms_per_tick, + ms_per_bar, TimeSignature, + Beat, + Subdivision, ) -from musicaiz.features import harmony -from musicaiz.algorithms import ( - key_detection, - KeyDetectionAlgorithms -) +from musicaiz.converters import musa_to_prettymidi +from musicaiz.features import get_harmonic_density +from musicaiz.algorithms import key_detection, KeyDetectionAlgorithms +from tests.unit.musicaiz import converters class ValidFiles(Enum): @@ -58,10 +58,7 @@ def all_extensions(cls) -> List[str]: return all -STRUCTURES = ["bars", "instruments"] - - -class Musa: +class MusaII: """Musanalisys main object. This object loads a file and maps it to the musicaiz' objects defined in the submodules `harmony` and `structure`. @@ -76,7 +73,7 @@ class Musa: Organices the attributes at different structure levels which are bar, instrument or piece level. Defaults to "piece". - + quantize: bool Default is True. Quantizes the notes at bar or instrument level with the `rhythm.advanced_quantizer` method that uses a strength of 100%. @@ -100,7 +97,7 @@ class Musa: resolution: int the pulses o ticks per quarter note (PPQ or TPQN). If this parameter is not initialized we suppose a resolution (sequencer ticks) of 960 ticks. - + absolute_timing: bool default is True. This allows to initialize note time arguments in absolute (True) or relative time units (False). 
Relative units means that each bar will start at 0 seconds @@ -179,13 +176,15 @@ def __init__( if msg.type == "set_tempo": self.bpm = int(mido.tempo2bpm(msg.tempo)) elif msg.type == "time_signature": - self.time_sig = TimeSignature(str(msg.numerator) + "/" + str(msg.denominator)) - + self.time_sig = TimeSignature( + str(msg.numerator) + "/" + str(msg.denominator) + ) + # initialize midi object with pretty_midi pm_inst = pm.PrettyMIDI( midi_file=self.file, resolution=self.resolution, - initial_tempo=self.bpm + initial_tempo=self.bpm, ) # The MIDI file might not have the time signature not tempo information, # in that case, we initialize them as defalut (120bpm 4/4) @@ -194,7 +193,21 @@ def __init__( # Divide notes into instrument and bars or just into instruments # depending on th evalue of the input argument `structure` - if self.structure == "bars": + if self.structure == "instrument_bars": + # Map instruments to load them as musicaiz instrument class + self._load_instruments(pm_inst) + self.notes = [] + self._load_inst_bars() + for instrument in self.instruments: + self._load_bars_notes( + instrument, + absolute_timing=self.absolute_timing, + cut_notes=self.cut_notes, + ) + # TODO: All instr must have the same total_bars, we should get the track with more bars and + # append empty bars to the rest of the tracks + self.total_bars = len(self.instruments[0].bars) + elif self.structure == "bars": # Map instruments to load them as musicaiz instrument class self._load_instruments(pm_inst) self.notes = [] @@ -208,19 +221,29 @@ def __init__( self._load_bars_notes( instrument, absolute_timing=self.absolute_timing, - cut_notes=self.cut_notes + cut_notes=self.cut_notes, ) # TODO: All instr must have the same total_bars, we should get the track with more bars and # append empty bars to the rest of the tracks self.total_bars = len(self.instruments[0].bars) + elif self.structure == "notes": + # Concatenate all the notes of different instruments + # this is for getting the latest note of the piece + # and get the total number of bars of the piece + for instrument in pm_inst.instruments: + self.notes.extend(instrument.notes) + self.instruments = [] + self.bars = [] elif self.structure == "instruments": self._load_instruments(pm_inst) for instrument in self.instruments: instrument.bars = None self.notes.extend(instrument.notes) - self.total_bars = self.get_total_bars(self.notes) + self.total_bars = self.get_total_bars(self.notes) else: - raise ValueError(f"Structure argument value {structure} is not valid.") + raise ValueError( + f"Structure argument value {structure} is not valid." + ) elif self.is_musicxml(file): # initialize musicxml object with ?? 
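As a minimal usage sketch of the renamed legacy loader (hedged: the keyword names follow the docstring and the structure branches above, and the MIDI path is a placeholder, not a file from this patch):

    from musicaiz.loaders import MusaII

    # structure="bars" groups each instrument's notes into Bar objects;
    # structure="instruments" keeps one flat note list per instrument.
    midi = MusaII("example.mid", structure="bars", quantize=False)
    print(midi.total_bars, len(midi.instruments))
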
@@ -235,15 +258,15 @@ def __init__( time_sig=self.time_sig.time_sig, bpm=self.bpm, resolution=self.resolution, - absolute_timing=self.absolute_timing + absolute_timing=self.absolute_timing, ) v_grid = get_ticks_from_subdivision(grid) for instrument in self.instruments: if self.structure == "bars": - for bar in instrument.bars: + for bar in instrument.bars: advanced_quantizer(bar.notes, v_grid) advanced_quantizer(instrument.notes, v_grid) - + # sort the notes in all the midi file self.notes.sort(key=lambda x: x.start_ticks, reverse=False) @@ -267,36 +290,22 @@ def is_musicxml(cls, file: Union[str, TextIO]): extension = cls.get_file_extension(file) return True if extension in ValidFiles.MUSIC_XML.value else False - def _load_instruments(self, pm_inst) -> List[Note]: - """Populates `instruments` attribute mapping pretty_midi instruments - to musicaiz instrument class.""" - # Load pretty midi instrument - for i, instrument in enumerate(pm_inst.instruments): - self.instruments.append( - Instrument( - program=instrument.program, - name=instrument.name, - is_drum=instrument.is_drum, - general_midi=False, - ) - ) - for note in instrument.notes: - # Initialize the note with musicaiz `Note` object - self.instruments[i].notes.append( - Note( - pitch=note.pitch, - start=note.start, - end=note.end, - velocity=note.velocity, + def _load_inst_bars(self): + """Load the bars for an instrument.""" + total_bars = self.get_total_bars(self.notes) + for instrument in self.instruments: + for _ in range(total_bars): + instrument.bars.append( + Bar( + time_sig=self.time_sig.time_sig, bpm=self.bpm, - resolution=self.resolution, ) ) def _load_bars(self): """Load the bars for an instrument.""" total_bars = self.get_total_bars(self.notes) - for instrument in self.instruments: + for instrument in pm_inst.instruments: for _ in range(total_bars): instrument.bars.append( Bar( @@ -309,7 +318,7 @@ def _load_bars_notes( self, instrument: Instrument, cut_notes: bool = False, - absolute_timing: bool = True + absolute_timing: bool = True, ): start_bar_ticks = 0 _, bar_ticks = ticks_per_bar(self.time_sig.time_sig, self.resolution) @@ -319,11 +328,13 @@ def _load_bars_notes( bar.notes.append(n) notes_next_bar = [] next_start_bar_ticks = start_bar_ticks + bar_ticks - # bar obj attributes + # bar obj attributes if self.absolute_timing: bar.start_ticks = start_bar_ticks bar.end_ticks = next_start_bar_ticks - bar.start_sec = bar.start_ticks * ms_per_tick(self.bpm, self.resolution) / 1000 + bar.start_sec = ( + bar.start_ticks * ms_per_tick(self.bpm, self.resolution) / 1000 + ) else: bar.start_ticks, bar.start_sec = 0, 0.0 bar.end_ticks = bar.start_ticks + bar_ticks @@ -331,21 +342,35 @@ def _load_bars_notes( for i, note in enumerate(instrument.notes): # TODO: If note ends after the next bar start? 
Fix this, like this we'll loose it - if note.start_ticks >= start_bar_ticks and note.end_ticks <= next_start_bar_ticks: + if ( + note.start_ticks >= start_bar_ticks + and note.end_ticks <= next_start_bar_ticks + ): bar.notes.append(note) # note starts in current bar but ends in the next (or nexts bars) -> cut note - elif start_bar_ticks <= note.start_ticks <= next_start_bar_ticks and note.end_ticks >= next_start_bar_ticks: + elif ( + start_bar_ticks <= note.start_ticks <= next_start_bar_ticks + and note.end_ticks >= next_start_bar_ticks + ): if cut_notes: # cut note by creating a new note that starts when the next bar starts note_next_bar = Note( - start=next_start_bar_ticks, end=note.end_ticks, - pitch=note.pitch, velocity=note.velocity, ligated=True + start=next_start_bar_ticks, + end=note.end_ticks, + pitch=note.pitch, + velocity=note.velocity, + ligated=True, ) notes_next_bar.append(note_next_bar) # cut note by assigning end note to the current end bar note.end_ticks = next_start_bar_ticks - note.end_secs = next_start_bar_ticks * ms_per_tick(self.bpm, self.resolution) / 1000 + note.end_secs = ( + next_start_bar_ticks + * ms_per_tick(self.bpm, self.resolution) + / 1000 + ) note.ligated = True + note.instrument_prog = instrument.program bar.notes.append(note) elif note.start_ticks > next_start_bar_ticks: break @@ -404,9 +429,11 @@ def predict_key(self, method: str) -> str: key = key_detection(notes, method) elif self.structure == "instruments": raise ValueError("Initialize the Musa with `structure=bars`") - elif method in KeyDetectionAlgorithms.KRUMHANSL_KESSLER.value or \ - KeyDetectionAlgorithms.TEMPERLEY.value or \ - KeyDetectionAlgorithms.ALBRETCH_SHANAHAN.value: + elif ( + method in KeyDetectionAlgorithms.KRUMHANSL_KESSLER.value + or KeyDetectionAlgorithms.TEMPERLEY.value + or KeyDetectionAlgorithms.ALBRETCH_SHANAHAN.value + ): key = key_detection(self.notes, method) return key @@ -476,28 +503,29 @@ def _event_compare(event1, event2): # The spacing for these scores is 256, which is larger than the # largest value a MIDI value can take. secondary_sort = { - 'set_tempo': lambda e: (1 * 256 * 256), - 'time_signature': lambda e: (2 * 256 * 256), - 'key_signature': lambda e: (3 * 256 * 256), - 'lyrics': lambda e: (4 * 256 * 256), - 'text_events' :lambda e: (5 * 256 * 256), - 'program_change': lambda e: (6 * 256 * 256), - 'pitchwheel': lambda e: ((7 * 256 * 256) + e.pitch), - 'control_change': lambda e: ( - (8 * 256 * 256) + (e.control * 256) + e.value), - 'note_off': lambda e: ((9 * 256 * 256) + (e.note * 256)), - 'note_on': lambda e: ( - (10 * 256 * 256) + (e.note * 256) + e.velocity), - 'end_of_track': lambda e: (11 * 256 * 256) + "set_tempo": lambda e: (1 * 256 * 256), + "time_signature": lambda e: (2 * 256 * 256), + "key_signature": lambda e: (3 * 256 * 256), + "lyrics": lambda e: (4 * 256 * 256), + "text_events": lambda e: (5 * 256 * 256), + "program_change": lambda e: (6 * 256 * 256), + "pitchwheel": lambda e: ((7 * 256 * 256) + e.pitch), + "control_change": lambda e: ((8 * 256 * 256) + (e.control * 256) + e.value), + "note_off": lambda e: ((9 * 256 * 256) + (e.note * 256)), + "note_on": lambda e: ((10 * 256 * 256) + (e.note * 256) + e.velocity), + "end_of_track": lambda e: (11 * 256 * 256), } # If the events have the same tick, and both events have types # which appear in the secondary_sort dictionary, use the dictionary # to determine their ordering. 
- if (event1.time == event2.time and - event1.type in secondary_sort and - event2.type in secondary_sort): - return (secondary_sort[event1.type](event1) - - secondary_sort[event2.type](event2)) + if ( + event1.time == event2.time + and event1.type in secondary_sort + and event2.type in secondary_sort + ): + return secondary_sort[event1.type](event1) - secondary_sort[event2.type]( + event2 + ) # Otherwise, just return the difference of their ticks. return event1.time - event2.time @@ -520,14 +548,16 @@ def write_midi(self, filename: str): # Write BPM timing_track.append( mido.MetaMessage( - "set_tempo", time=0, + "set_tempo", + time=0, # Convert from microseconds per quarter note to BPM - tempo=self.bpm) + tempo=self.bpm, + ) ) - #Write key TODO - #timing_track.append( - #mido.MetaMessage("key_signature", time=self.time_to_tick(ks.time), - #key=key_number_to_mido_key_name[ks.key_number])) + # Write key TODO + # timing_track.append( + # mido.MetaMessage("key_signature", time=self.time_to_tick(ks.time), + # key=key_number_to_mido_key_name[ks.key_number])) for n, instrument in enumerate(self.instruments): # Perharps notes are grouped in bars, concatenate them @@ -539,8 +569,9 @@ def write_midi(self, filename: str): track = mido.MidiTrack() # Add track name event if instrument has a name if instrument.name: - track.append(mido.MetaMessage( - 'track_name', time=0, name=instrument.name)) + track.append( + mido.MetaMessage("track_name", time=0, name=instrument.name) + ) # If it's a drum event, we need to set channel to 9 if instrument.is_drum: channel = 9 @@ -548,15 +579,20 @@ def write_midi(self, filename: str): else: channel = 8 # channels[n % len(channels)] # Set the program number - track.append(mido.Message( - 'program_change', time=0, program=instrument.program, - channel=channel)) + track.append( + mido.Message( + "program_change", + time=0, + program=instrument.program, + channel=channel, + ) + ) # Add all note events ligated_notes = [] for idx, note in enumerate(instrument.notes): if note.ligated: ligated_notes.append(note) - for next_note in instrument.notes[idx+1:]: + for next_note in instrument.notes[idx + 1 :]: if not next_note.ligated: continue ligated_notes.append(next_note) @@ -567,13 +603,25 @@ def write_midi(self, filename: str): note.end_sec = ligated_notes[-1].end_sec ligated_notes = [] # Construct the note-on event - track.append(mido.Message( - 'note_on', time=note.start_ticks, - channel=channel, note=note.pitch, velocity=note.velocity)) + track.append( + mido.Message( + "note_on", + time=note.start_ticks, + channel=channel, + note=note.pitch, + velocity=note.velocity, + ) + ) # Also need a note-off event (note on with velocity 0) - track.append(mido.Message( - 'note_on', time=note.end_ticks, - channel=channel, note=note.pitch, velocity=0)) + track.append( + mido.Message( + "note_on", + time=note.end_ticks, + channel=channel, + note=note.pitch, + velocity=0, + ) + ) # Sort all the events using the event_compare comparator. 
track = sorted(track, key=functools.cmp_to_key(self._event_compare)) @@ -581,17 +629,18 @@ def write_midi(self, filename: str): # If there's a note off event and a note on event with the same # tick and pitch, put the note off event first for n, (event1, event2) in enumerate(zip(track[:-1], track[1:])): - if (event1.time == event2.time and - event1.type == 'note_on' and - event2.type == 'note_on' and - event1.note == event2.note and - event1.velocity != 0 and - event2.velocity == 0): + if ( + event1.time == event2.time + and event1.type == "note_on" + and event2.type == "note_on" + and event1.note == event2.note + and event1.velocity != 0 + and event2.velocity == 0 + ): track[n] = event2 track[n + 1] = event1 # Finally, add in an end of track event - track.append(mido.MetaMessage( - 'end_of_track', time=track[-1].time + 1)) + track.append(mido.MetaMessage("end_of_track", time=track[-1].time + 1)) # Add to the list of output tracks mid.tracks.append(track) # Turn ticks to relative time from absolute @@ -601,7 +650,7 @@ def write_midi(self, filename: str): event.time -= tick tick += event.time mid.save(filename + ".mid") - + def fluidsynth(self, fs=44100, sf2_path=None): """Synthesize using fluidsynth. Parameters @@ -619,16 +668,871 @@ def fluidsynth(self, fs=44100, sf2_path=None): """ # If there are no instruments, or all instruments have no notes, return # an empty array - if len(self.instruments) == 0 or all(len(i.notes) == 0 for i in self.instruments): + if len(self.instruments) == 0 or all( + len(i.notes) == 0 for i in self.instruments + ): return np.array([]) # Get synthesized waveform for each instrument - waveforms = [i.fluidsynth(fs=fs, - sf2_path=sf2_path) for i in self.instruments] + waveforms = [i.fluidsynth(fs=fs, sf2_path=sf2_path) for i in self.instruments] # Allocate output waveform, with #sample = max length of all waveforms synthesized = np.zeros(np.max([w.shape[0] for w in waveforms])) # Sum all waveforms in for waveform in waveforms: - synthesized[:waveform.shape[0]] += waveform + synthesized[: waveform.shape[0]] += waveform # Normalize synthesized /= np.abs(synthesized).max() - return synthesized \ No newline at end of file + return synthesized + + + +class Musa: + + __slots__ = [ + "file", + "tonality", + "time_signature_changes", + "resolution", + "instruments", + "is_quantized", + "total_bars", + "absolute_timing", + "cut_notes", + "notes", + "bars", + "tempo_changes", + "instruments_progs", + "quantize_note", + "general_midi", + "subdivision_note", + "subbeats", + "beats", + ] + + # subdivision_note and quantize_note + # lower than a quarter note which is a beat in X/4 bars + VALID_SUBDIVISIONS = [ + "eight", + "sixteenth", + "thirty_two", + "sixty_four", + "hundred_twenty_eight", + ] + + def __init__( + self, + file: Optional[Union[str, TextIO, Path]], + quantize: bool = False, + quantize_note: Optional[str] = "sixteenth", + cut_notes: bool = False, + tonality: Optional[str] = None, + resolution: Optional[int] = None, + absolute_timing: bool = True, + general_midi: bool = False, + subdivision_note: str = "sixteenth" + ): + + """ + Structure: attributes that contains lists of Note and Instrument objects. + Time: attributes that contains lists of Bar, Beat and Subdivision + objects. + + A MIDI file can contain time signature changes, so each Beat objects are equivalent + to the Bar they belong to. Ex.: a 2/4 time signature will contain 2 beats = 2 quarter + notes whereas a 3/8 bar will contain 3 beats = 3 eight notes. 
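+
+        Illustrative example (the file name is a placeholder); the number of
+        beats returned per bar follows the rule above:
+
+        >>> midi = Musa("midi_with_time_sig_changes.mid")
+        >>> midi.get_beats_in_bar(0)  # 2 Beat objects for a 2/4 bar, 3 for a 3/8 bar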
+        """
+
+        # TODO: quantize
+        # TODO: relative times in notes?
+        # TODO: cut_notes
+        # TODO: assign notes their name when key is known
+        # TODO: key signature changes,
+        # TODO: write_midi -> with pretty_midi
+        # TODO: synthesize -> with pretty_midi
+
+        self.instruments = []
+        self.notes = []
+        self.total_bars = 0
+        self.general_midi = general_midi
+        self.absolute_timing = absolute_timing
+        self.is_quantized = quantize
+        self.quantize_note = quantize_note
+        self.subdivision_note = subdivision_note
+        self.subbeats = []
+        self.beats = []
+        self.cut_notes = cut_notes
+        self.time_signature_changes = []
+        self.bars = []
+        self.tempo_changes = []
+
+        # TODO: unify quantize_note and subdivision_note?
+
+        if subdivision_note not in self.VALID_SUBDIVISIONS:
+            raise ValueError(
+                f"{subdivision_note} is not a valid subdivision_note. "
+                f"Valid values are: {self.VALID_SUBDIVISIONS}"
+            )
+
+        # File provided
+        if file is not None:
+            if isinstance(file, str):
+                file = Path(file)
+            if self.is_valid(file):
+                self.file = file
+            else:
+                raise ValueError("Input file extension is not valid.")
+            if self.is_midi(file):
+                self._load_midifile(resolution, tonality)
+
+        # group subdivisions in beats
+
+        # group subdivisions in bars
+
+    def json(self):
+        return {key: getattr(self, key, None) for key in self.__slots__}
+
+    def _load_midifile(
+        self,
+        resolution: int,
+        tonality: str,
+    ):
+        # initialize midi object with pretty_midi
+        pm_inst = pm.PrettyMIDI(
+            midi_file=str(self.file),
+        )
+        if resolution is None:
+            self.resolution = pm_inst.resolution
+        else:
+            self.resolution = resolution
+        prev_time = -1
+        for time_sig in pm_inst.time_signature_changes:
+            # there cannot be 2 different time sigs at the same time
+            if time_sig.time == prev_time:
+                continue
+            self.time_signature_changes.append(
+                {
+                    "time_sig": TimeSignature(
+                        (time_sig.numerator, time_sig.denominator)
+                    ),
+                    "ms": time_sig.time * 1000,
+                }
+            )
+            prev_time = time_sig.time
+
+        # if the time signature is not defined, we'll initialize it as
+        # 4/4 by default
+        if len(self.time_signature_changes) == 0:
+            self.time_signature_changes = [
+                {
+                    "time_sig": TimeSignature(
+                        TimingConsts.DEFAULT_TIME_SIGNATURE.value
+                    ),
+                    "ms": 0.0
+                }
+            ]
+
+        # For whatever reason, get_tempo_changes() returns a tuple
+        # of arrays with 2 elements per array, the 1st is the time and
+        # the 2nd is the tempo. In other cases, only 2 arrays are returned
+        # each one with one element, the 1st array is the time and the 2nd, the tempo
+        tempo_changes = pm_inst.get_tempo_changes()
+        if tempo_changes[0].size == 2:
+            for change in tempo_changes:
+                if change.size == 2:
+                    self.tempo_changes.append(
+                        {"tempo": change[1], "ms": change[0] * 1000}
+                    )
+        elif tempo_changes[0].size == 1 and len(tempo_changes) == 2:
+            self.tempo_changes.append(
+                {"tempo": tempo_changes[1][0], "ms": tempo_changes[0][0] * 1000}
+            )
+
+        # Initialize tonality.
If it's given, ignore the KeySignatureChanges + if tonality is None: + self.tonality = tonality + else: + self.tonality = pm_inst.key_signature_changes + + last_note_end = self._get_last_note_end(pm_inst) + + # Add last note end to the time signature changes + # (for easier bars laoding) + self.tempo_changes.append( + {"tempo": self.tempo_changes[-1]["tempo"], "ms": last_note_end * 1000} + ) + + # Load Bars + #self._load_bars(last_note_end) + + # Load beats + self._load_beats(last_note_end) + + # Load subdivisions + self._load_subdivisions(last_note_end) + + # group beats in bars and create Bar objects + self._load_bars_and_group_beats_in_bars() + + # last bar is complete even if the last note does not + # end when the bar ends, so we'll add the empty beats + # to the last bar to complete the bar (as it's done in DAWs) + self.total_bars = len(self.bars) + + # Map instruments to load them as musicaiz instrument class + self._load_instruments_and_notes(pm_inst) + self.instruments_progs = [inst.program for inst in self.instruments] + + # Fill bar attributes related to notes information + self._fill_bar_notes_attributes() + + # assign bar indexes to subbeats + j, k = 0, 0 + for _, subbeat in enumerate(self.subbeats): + bar = self.bars[j] + beat = self.beats[k] + # bar_idx + if subbeat.start_sec >= bar.start_sec and subbeat.end_sec <= bar.end_sec: + subbeat.bar_idx = j + else: + j += 1 + subbeat.bar_idx = j + # beat_idx + if subbeat.start_sec >= beat.start_sec and subbeat.end_sec <= beat.end_sec: + subbeat.beat_idx = k + else: + k += 1 + subbeat.beat_idx = k + + assert len([sub for sub in self.subbeats if sub.bar_idx is None]) == 0 + assert len([sub for sub in self.subbeats if sub.beat_idx is None]) == 0 + """ + # if quantize + if quantize: + grid = get_subdivisions( + total_bars=self.total_bars, + subdivision=quantize_note, + time_sig=self.time_sig.time_sig, + bpm=self.bpm, + resolution=self.resolution, + absolute_timing=self.absolute_timing, + ) + v_grid = get_ticks_from_subdivision(grid) + advanced_quantizer(self.notes, v_grid) + """ + + @classmethod + def is_valid(cls, file: Union[str, Path]): + extension = cls.get_file_extension(file) + return True if extension in ValidFiles.all_extensions() else False + + @staticmethod + def get_file_extension(file: Union[str, Path]): + return Path(file).suffix + + @classmethod + def is_midi(cls, file: Union[str, Path]): + extension = cls.get_file_extension(file) + return True if extension in ValidFiles.MIDI.value else False + + @classmethod + def is_musicxml(cls, file: Union[str, TextIO]): + extension = cls.get_file_extension(file) + return True if extension in ValidFiles.MUSIC_XML.value else False + + def bar_beats_subdivs_analysis(self): + for i, time_sig in enumerate(self.time_signature_changes): + if i + 1 == len(self.time_signature_changes): + break + sb_len = len([sb for sb in self.subbeats if sb.time_sig.time_sig == time_sig["time_sig"].time_sig]) + print(f"{sb_len} subdivisions in {time_sig['time_sig'].time_sig}") + beat_len = len([beat for beat in self.beats if beat.time_sig.time_sig == time_sig["time_sig"].time_sig]) + print(f"{beat_len} beats in {time_sig['time_sig'].time_sig}") + bar_len = len([bar for bar in self.bars if bar.time_sig.time_sig == time_sig["time_sig"].time_sig]) + print(f"{bar_len} bars in {time_sig['time_sig'].time_sig}") + + # subbeat + def get_notes_in_subbeat( + self, + subbeat_idx: int, + program: Optional[Union[List[int], int]] = None, + instrument_idx: Optional[Union[List[int], int]] = None, + ) -> List[Note]: + notes = 
self._filter_by_instruments(program, instrument_idx, self.notes) + return self._get_objs_in_subbeat(subbeat_idx, notes) + + def get_notes_in_subbeat_bar( + self, + subbeat_idx: int, + bar_idx: int, + program: Optional[Union[List[int], int]] = None, + instrument_idx: Optional[Union[List[int], int]] = None, + ) -> List[Note]: + first_idx = len(self.get_subbeats_in_bars(0, bar_idx)) + global_idx = subbeat_idx + first_idx + all_notes = self.get_notes_in_bar(bar_idx, program, instrument_idx) + return self._get_objs_in_subbeat(global_idx, all_notes) + + def _get_objs_in_subbeat(self, subbeat_idx: int, objs): + if subbeat_idx >= len(self.subbeats): + raise ValueError( + f"Not subbeat index {subbeat_idx} found in bars. The file has {len(self.subbeats)} subbeats." + ) + return list(filter(lambda obj: obj.subbeat_idx == subbeat_idx, objs)) + + # subbeats + def get_notes_in_subbeats( + self, + subbeat_start: int, + subbeat_end: int, + program: Optional[Union[List[int], int]] = None, + instrument_idx: Optional[Union[List[int], int]] = None, + ) -> List[Note]: + notes = self._filter_by_instruments(program, instrument_idx, self.notes) + return self._get_objs_in_subbeats( + subbeat_start, subbeat_end, notes + ) + + def _get_objs_in_subbeats( + self, + subbeat_start: int, + subbeat_end: int, + objs: List[Note] + ): + if subbeat_start > subbeat_end: + raise ValueError("subbeat_start must be minor than subbeat_end.") + return list( + filter( + lambda obj: obj.subbeat_idx >= subbeat_start and obj.subbeat_idx < subbeat_end, objs + ) + ) + + # beat + def get_notes_in_beat( + self, + beat_idx: int, + program: Optional[Union[List[int], int]] = None, + instrument_idx: Optional[Union[List[int], int]] = None, + ) -> List[Note]: + """beat_idx is the global index of the beat in the file.""" + notes = self._filter_by_instruments(program, instrument_idx, self.notes) + return self._get_objs_in_beat(beat_idx, notes) if notes is not None else [] + + def get_notes_in_beat_bar( + self, + beat_idx: int, + bar_idx: int, + program: Optional[Union[List[int], int]] = None, + instrument_idx: Optional[Union[List[int], int]] = None, + ) -> List[Note]: + """beat_idx is the local index of the beat in the file.""" + first_idx = self.get_subbeats_in_bar(bar_idx)[0].beat_idx + global_idx = beat_idx + first_idx + all_notes = self.get_notes_in_bar(bar_idx, program, instrument_idx) + return self._get_objs_in_beat(global_idx, all_notes) if all_notes is not None else [] + + def get_subbeats_in_beat(self, beat_idx: int) -> List[Subdivision]: + return self._get_objs_in_beat(beat_idx, self.subbeats) + + def get_subbeat_in_beat(self, subbeat_idx: int, beat_idx: int) -> Subdivision: + all_subbeats = self._get_objs_in_beat(beat_idx, self.subbeats) + # TODO: Error message if subbeat_idx > len(all_beats) + return all_subbeats[subbeat_idx] + + def _get_objs_in_beat(self, beat_idx: int, objs): + if beat_idx >= len(self.beats): + raise ValueError( + f"Not subbeat index {beat_idx} found in bars. The file has {len(self.beats)} beats." 
+ ) + return list(filter(lambda obj: obj.beat_idx == beat_idx, objs)) + + # beats + def get_notes_in_beats( + self, + beat_start: int, + beat_end: int, + program: Optional[Union[List[int], int]] = None, + instrument_idx: Optional[Union[List[int], int]] = None, + ) -> List[Note]: + notes = self._filter_by_instruments(program, instrument_idx, self.notes) + return self._get_objs_in_beats( + beat_start, beat_end, notes + ) if notes is not None else [] + + def get_subbeats_in_beats( + self, + beat_start: int, + beat_end: int, + ) -> List[Subdivision]: + return self._get_objs_in_beats( + beat_start, beat_end, self.subbeats + ) + + def _get_objs_in_beats( + self, + beat_start: int, + beat_end: int, + objs, + ): + if beat_start > beat_end: + raise ValueError("beat_start must be minor than beat_end.") + return list( + filter( + lambda obj: obj.beat_idx >= beat_start and obj.beat_idx < beat_end, objs + ) + ) + + # Bar + def get_notes_in_bar( + self, + bar_idx: int, + program: Optional[Union[List[int], int]] = None, + instrument_idx: Optional[Union[List[int], int]] = None, + ) -> List[Note]: + notes = self._filter_by_instruments(program, instrument_idx, self.notes) + return self._get_objs_in_bar(bar_idx, notes) if notes is not None else [] + + def get_beats_in_bar( + self, + bar_idx: int, + ) -> List[Beat]: + return self._get_objs_in_bar(bar_idx, self.beats) + + def get_beat_in_bar(self, beat_idx: int, bar_idx: int) -> Beat: + all_beats = self._get_objs_in_bar(bar_idx, self.beats) + # TODO: Error message if beat_idx > len(all_beats) + return all_beats[beat_idx] + + def get_subbeats_in_bar(self, bar_idx: int) -> List[Subdivision]: + return self._get_objs_in_bar(bar_idx, self.subbeats) + + def get_subbeat_in_bar(self, subbeat_idx: int, bar_idx: int) -> List[Subdivision]: + all_subbeats = self._get_objs_in_bar(bar_idx, self.subbeats) + # TODO: Error message if subbeat_idx > len(all_beats) + return all_subbeats[subbeat_idx] + + def _get_objs_in_bar( + self, + bar_idx: int, + objs: List[Note] + ): + if bar_idx >= len(self.bars): + raise ValueError( + f"Not bar index {bar_idx} found in bars. The file has {len(self.bars)} bars." 
+ ) + return list(filter(lambda obj: obj.bar_idx == bar_idx, objs)) + + # Bars + def get_notes_in_bars( + self, + bar_start: int, + bar_end: int, + program: Optional[Union[List[int], int]] = None, + instrument_idx: Optional[Union[List[int], int]] = None, + ) -> List[Note]: + notes = self._filter_by_instruments(program, instrument_idx, self.notes) + return self._get_objs_in_bars(bar_start, bar_end, notes) if notes is not None else [] + + def get_beats_in_bars(self, bar_start: int, bar_end: int) -> List[Beat]: + return self._get_objs_in_bars(bar_start, bar_end, self.beats) + + def get_subbeats_in_bars( + self, + bar_start: int, + bar_end: int, + ) -> List[Subdivision]: + return self._get_objs_in_bars(bar_start, bar_end, self.subbeats) + + def _get_objs_in_bars( + self, + bar_start: int, + bar_end: int, + obj + ): + if bar_start > bar_end: + raise ValueError("subbeat_start must be minor than subbeat_end.") + return list( + filter( + lambda obj: obj.bar_idx >= bar_start and obj.bar_idx < bar_end, obj + ) + ) + + # Instruments + def _filter_by_instruments( + self, + program: Optional[Union[List[int], int]], + instrument_idx: Optional[Union[List[int], int]], + objs, + ): + if program is not None: + if isinstance(program, list): + return self._filter_instruments(program, instrument_idx) + elif isinstance(program, int): + return self._filter_instrument(program, instrument_idx, objs) + else: + return objs + + def _filter_instruments( + self, + program: Optional[int], + instrument_idx: Optional[List[int]], + ): + if instrument_idx is not None and len(program) != len(instrument_idx): + raise ValueError("programs and instrument_idxs must have the same length.") + diff_progs = list(set(program).difference(set(self.instruments_progs))) + # if there's one or more programs not found, error + if len(diff_progs) != 0: + raise ValueError( + f"Programs {diff_progs} not found. Instruments programs are {self.instruments_progs}." + ) + if instrument_idx is None: + instrument_idx = [None for _ in range(len(program))] + objs = [] + for p, i in zip(program, instrument_idx): + objs.extend(self.get_notes_in_bars(0, len(self.bars), p, i)) + objs.sort(key=lambda x: x.start_sec, reverse=False) + return objs + + def _filter_instrument( + self, + program: Optional[int], + instrument_idx: Optional[List[int]], + objs, + ): + if program not in self.instruments_progs: + raise ValueError( + f"Not program {program} found in instruments. The file has the following programs: {self.instruments_progs}." + ) + filtered = list(filter(lambda obj: obj.instrument_prog == program, objs)) + if instrument_idx is None: + return filtered + else: + idxs = [note.instrument_idx for note in filtered] + idxs = list(dict.fromkeys(idxs)) + if instrument_idx not in idxs: + raise ValueError( + f"program {program} does not match instrument with index {instrument_idx}. " + f"Instrument indexes for program {program} are {idxs}." + ) + return list( + filter(lambda obj: obj.instrument_idx == instrument_idx, filtered) + ) + + @staticmethod + def _get_last_note_end(pm_inst): + # last note + last_notes = [inst.notes[-1] for inst in pm_inst.instruments] + last_notes.sort(key=lambda x: x.end, reverse=False) + return last_notes[-1].end # last note end in secs + + def writemidi(self, filename): + midi = musa_to_prettymidi(self) + midi.write(filename) + + def predict_key(self, method: str) -> str: + """ + Predict the key with the key profiles algorithms. 
+ Note that signature fifths algorithm requires to initialize + the Musa class with the argument `structure="bars"` instead + of "instruments". The other algorithms work for both initializations. + + Parameters + ---------- + + method: str + The algorithm we want to use to predict the key. The list of + algorithms can be found here: :func:`~musicaiz.algorithms.KeyDetectionAlgorithms`. + + Raises + ------ + + ValueError + + ValueError + + Returns + ------- + key: str + The predicted key as a string separating tonic, alteration + (if proceeds) and mode with "_". + """ + if method not in KeyDetectionAlgorithms.all_values(): + raise ValueError("Not method found.") + elif method in KeyDetectionAlgorithms.SIGNATURE_FIFTHS.value: + # get notes in 2 1st bars (excluding drums) + # TODO: Exclude drums + all_notes = self.get_notes_in_bars(0, 2) + notes = [note for note in all_notes if not note.is_drum] + key = key_detection(notes, method) + elif ( + method in KeyDetectionAlgorithms.KRUMHANSL_KESSLER.value + or KeyDetectionAlgorithms.TEMPERLEY.value + or KeyDetectionAlgorithms.ALBRETCH_SHANAHAN.value + ): + notes = [note for note in self.notes if not note.is_drum] + key = key_detection(notes, method) + return key + + def _load_beats(self, last_note_end: float): + # Populate bars considering time signature changes + start_beat_ms = 0 + tempo_idx = 0 + for i, time_sig in enumerate(self.time_signature_changes): + # latest note end will be the end of the time_sig_changes + if i + 1 == len(self.time_signature_changes): + ms_next_change = last_note_end * 1000 + else: + # next time sig + ms_next_change = self.time_signature_changes[i + 1]["ms"] + + while True: + # we need to calculate the number of bars of time_sig[i] that we + # have before the next change (sec_next_change) + if tempo_idx + 1 >= len(self.tempo_changes): + break + if start_beat_ms <= self.tempo_changes[tempo_idx + 1][ + "ms" + ] and tempo_idx < len(self.tempo_changes): + bpm = self.tempo_changes[tempo_idx]["tempo"] + else: + tempo_idx += 1 + bpm = self.tempo_changes[tempo_idx]["tempo"] + + # Get the duration in ms of one bar + _, bar_ms = ms_per_bar( + time_sig["time_sig"].time_sig, bpm=bpm, resolution=self.resolution + ) + beat_ms = bar_ms / time_sig["time_sig"].num + beat_end = start_beat_ms + beat_ms + if beat_end > ms_next_change: + beat_end = ms_next_change + # If there's a tempo_change inside a bar, we'll + # calculate the % of bar that is in each tempo to + # calculate where the bar ends + if tempo_idx + 1 < len(self.tempo_changes): + if beat_end > self.tempo_changes[tempo_idx + 1]["ms"]: + ms_in_prev_tempo = ( + self.tempo_changes[tempo_idx + 1]["ms"] - start_beat_ms + ) + perc = ms_in_prev_tempo / beat_ms + _, bar_ms = ms_per_bar( + time_sig["time_sig"].time_sig, + bpm=self.tempo_changes[tempo_idx + 1]["tempo"], + resolution=self.resolution, + ) + beat_ms = bar_ms / time_sig["time_sig"].num + beat_end = start_beat_ms + ms_in_prev_tempo + beat_ms * perc + beat = Beat( + time_sig=time_sig["time_sig"], + start=start_beat_ms / 1000, + end=beat_end / 1000, + bpm=self.tempo_changes[tempo_idx]["tempo"], + resolution=self.resolution, + ) + self.beats.append(beat) + start_beat_ms = beat.end_sec * 1000 + if start_beat_ms >= ms_next_change: + break + + # TODO: This is almost the same as _load_beats, refactor + def _load_subdivisions(self, last_note_end: float): + # Populate bars considering time signature changes + start_subdiv_ms = 0 + tempo_idx = 0 + for i, time_sig in enumerate(self.time_signature_changes): + # latest note end will be the 
end of the time_sig_changes + if i + 1 == len(self.time_signature_changes): + ms_next_change = last_note_end * 1000 + else: + # next time sig + ms_next_change = self.time_signature_changes[i + 1]["ms"] + + while True: + # we need to calculate the number of bars of time_sig[i] that we + # have before the next change (sec_next_change) + if tempo_idx + 1 >= len(self.tempo_changes): + break + if start_subdiv_ms <= self.tempo_changes[tempo_idx + 1][ + "ms" + ] and tempo_idx < len(self.tempo_changes): + bpm = self.tempo_changes[tempo_idx]["tempo"] + else: + tempo_idx += 1 + bpm = self.tempo_changes[tempo_idx]["tempo"] + + # Get the duration in ms of one bar + _, bar_ms = ms_per_bar( + time_sig["time_sig"].time_sig, bpm=bpm, resolution=self.resolution + ) + subdivs_in_bar = time_sig["time_sig"]._notes_per_bar( + self.subdivision_note.upper() + ) + if subdivs_in_bar > 1: + ValueError("Subdivision note value must be lower than a bar.") + subdiv_ms = bar_ms / subdivs_in_bar + subdiv_end = start_subdiv_ms + subdiv_ms + if subdiv_end > ms_next_change: + subdiv_end = ms_next_change + # If there's a tempo_change inside a bar, we'll + # calculate the % of bar that is in each tempo to + # calculate where the bar ends + if tempo_idx + 1 < len(self.tempo_changes): + if subdiv_end > self.tempo_changes[tempo_idx + 1]["ms"]: + ms_in_prev_tempo = ( + self.tempo_changes[tempo_idx + 1]["ms"] - start_subdiv_ms + ) + perc = ms_in_prev_tempo / subdiv_ms + _, bar_ms = ms_per_bar( + time_sig["time_sig"].time_sig, + bpm=self.tempo_changes[tempo_idx + 1]["tempo"], + resolution=self.resolution, + ) + subdivs_in_bar = time_sig["time_sig"]._notes_per_bar( + self.subdivision_note.upper() + ) + subdiv_ms = bar_ms / subdivs_in_bar + subdiv_end = start_subdiv_ms + ms_in_prev_tempo + subdiv_ms * perc + + subdiv = Subdivision( + time_sig=time_sig["time_sig"], + start=start_subdiv_ms / 1000, + end=subdiv_end / 1000, + bpm=self.tempo_changes[tempo_idx]["tempo"], + resolution=self.resolution, + ) + self.subbeats.append(subdiv) + self.subbeats[-1].global_idx = len(self.subbeats) - 1 + start_subdiv_ms = subdiv.end_sec * 1000 + if start_subdiv_ms >= ms_next_change: + break + + def _load_instruments_and_notes(self, pm_inst) -> List[Note]: + """Populates `instruments` attribute mapping pretty_midi instruments + to musicaiz instrument class.""" + # Load pretty midi instruments and notes + notes = [] + for i, instrument in enumerate(pm_inst.instruments): + self.instruments.append( + Instrument( + program=instrument.program, + name=instrument.name, + is_drum=instrument.is_drum, + general_midi=self.general_midi, + ) + ) + + # convert pretty_midi Note objects to our Note objects + t = 0 + for pm_note in instrument.notes: + ms_next_change = self.tempo_changes[t + 1]["ms"] + if pm_note.start * 1000 >= ms_next_change: + bpm = self.tempo_changes[t + 1]["tempo"] + else: + bpm = self.tempo_changes[t]["tempo"] + note = Note( + start=pm_note.start, + end=pm_note.end, + pitch=pm_note.pitch, + velocity=pm_note.velocity, + instrument_prog=instrument.program, + bpm=bpm, + resolution=self.resolution, + instrument_idx=i, + is_drum=instrument.is_drum, + ) + notes.append(note) + + # sort notes by start time + notes.sort(key=lambda x: x.start_sec, reverse=False) + + for note in notes: + bar = list( + filter( + lambda obj: obj[1].start_sec <= note.start_sec, enumerate(self.bars) + ) + ) + note.bar_idx = bar[-1][0] + beat = list( + filter( + lambda obj: obj[1].start_sec <= note.start_sec, enumerate(self.beats) + ) + ) + note.beat_idx = beat[-1][0] + subbeat = 
list( + filter( + lambda obj: obj[1].start_sec <= note.start_sec, enumerate(self.subbeats) + ) + ) + note.subbeat_idx = subbeat[-1][0] + self.notes = notes + + def _fill_bar_notes_attributes(self): + for i, bar in enumerate(self.bars): + notes = self.get_notes_in_bar(i) + self.bars[i].note_density = len(notes) + self.bars[i].harmonic_density = get_harmonic_density(notes) + + def _load_bars_and_group_beats_in_bars(self): + bar_idx = 0 + beats = 0 + for i, beat in enumerate(self.beats): + time_sig = beat.time_sig.time_sig + beats_bar = beat.time_sig.num + if i == 0: + prev_beats_bar = beats_bar + prev_time_sig = time_sig + start_bar_sec = 0.0 + if beats_bar == prev_beats_bar and \ + beats + 1 <= beats_bar and \ + time_sig == prev_time_sig: + beat.bar_idx = bar_idx + beats += 1 + else: + end_bar_sec = beat.start_sec + bar_idx += 1 + beats = 0 + beat.bar_idx = bar_idx + beats += 1 + bar = Bar( + time_sig=self.beats[i - 1].time_sig, + start=start_bar_sec, + end=end_bar_sec, + bpm=self.beats[i - 1].bpm, + resolution=self.beats[i - 1].resolution + ) + self.bars.append(bar) + start_bar_sec = end_bar_sec + beat.global_idx = i + prev_beats_bar = beats_bar + prev_time_sig = time_sig + + if self.beats[-1].end_ticks > self.bars[-1].end_ticks: + # last bar + beats_last_bar = [beat for beat in self.beats if beat.bar_idx == len(self.bars)] + bar_sec = ms_per_bar( + beats_last_bar[0].time_sig.time_sig, + beats_last_bar[0].bpm, + beats_last_bar[0].resolution + )[1] / 1000 + bar = Bar( + time_sig=beats_last_bar[0].time_sig, + start=start_bar_sec, + end=start_bar_sec + bar_sec, + bpm=beat.bpm, + resolution=beat.resolution + ) + self.bars.append(bar) + # Now add as musch beats as needed to complete the last bar + if len(beats_last_bar) < bar.time_sig.num: + for _ in range(bar.time_sig.num - len(beats_last_bar)): + dur = self.beats[-1].end_sec - self.beats[-1].start_sec + beat = Beat( + time_sig=self.beats[-1].time_sig, + start=self.beats[-1].end_sec, + end=self.beats[-1].end_sec + dur, + bpm=self.beats[-1].bpm, + resolution=self.beats[-1].resolution, + ) + beat.global_idx = len(self.beats) + beat.bar_idx = len(self.bars) - 1 + self.beats.append(beat) diff --git a/musicaiz/plotters/pianorolls.py b/musicaiz/plotters/pianorolls.py index e392020..efb3638 100644 --- a/musicaiz/plotters/pianorolls.py +++ b/musicaiz/plotters/pianorolls.py @@ -1,7 +1,8 @@ +from re import sub import matplotlib.pyplot as plt from matplotlib.ticker import MultipleLocator import plotly.graph_objects as go -from typing import List, Union +from typing import List, Union, Optional from pathlib import Path import warnings @@ -12,10 +13,10 @@ TimingConsts ) from musicaiz.structure import Note, Instrument +from musicaiz.loaders import Musa - +# TODO: Add more colors. 
This only handles 10 instruments per plot COLOR_EDGES = [ - '#C232FF', '#C232FF', '#89FFAE', '#FFFF8B', @@ -25,11 +26,20 @@ '#FDB0F8', '#FFDC9C', '#F3A3C4', - '#E7E7E7' + '#E7E7E7', + "#AA7DBB", + "#7DBB90", + "#83C0B9", + "#83AFC0", + "#8AA7C3", + "#8A94C3", + "#C793CB", + "#CB93B9", + "#CA97A0", + "#CAB697", ] COLOR = [ - '#D676FF', '#D676FF', '#0AFE57', '#FEFF00', @@ -39,15 +49,22 @@ '#FF4CF4', '#FFB225', '#C25581', - '#737D73' + '#737D73', + "#E3A4FB", + "#A7F6BF", + "#A7F6ED", + "#ADE4F9", + "#ADD4F9", + "#B5C1F9", + "#F5B5F9", + "#FEBDE9", + "#FEBDC9", + "#F3DFC0", ] # TODO: subdivisions in plotly -# TODO: Plot all instruments in one plot plt and plotly # TODO: Hide note by clicking pitch in legend plotly -# TODO: plot bars plotly -# TODO: rethink bar plotting in plt # TODO: method to save plot in png or html @@ -75,7 +92,11 @@ class Pianoroll: ) """ - def __init__(self, dark: bool = False): + def __init__( + self, + musa: Optional[Musa] = None, + dark: bool = False, + ): if dark: background_color = "#282828" @@ -86,7 +107,9 @@ def __init__(self, dark: bool = False): self.ax.yaxis.set_major_locator(MultipleLocator(12)) self.ax.set_facecolor(background_color) - plt.xlabel("Time (bar.beat.subdivision)") + plt.xlabel("Time (bar)") + + self.musa = musa def plot_grid(self, subdivisions): # TODO: If we have lots of subdivisions (subdivision arg is small), then @@ -95,43 +118,31 @@ def plot_grid(self, subdivisions): # TODO: The same happens with pitch (y ticks). We should make smth to avoid # all the pitch values to be written in the axis plt.xlim((0, len(subdivisions) - 1)) - # Add 1st subdivision of new bar after last bar for better plotting the last bar - self._add_new_bar_subdiv(subdivisions) # Each subdivision has a vertical lines (grid) - self.ax.set_xticks([s["ticks"] for s in subdivisions]) - ##labels = [str(s["ticks"]) for s in subdivisions] - labels = [str(s["bar"]) + "." + str(s["bar_beat"]) + "." + str(s["bar_subdivision"]) for s in subdivisions] + prev_bar_idx = 0 + labels = [] + for s in subdivisions: + if s.bar_idx != prev_bar_idx: + labels.append(str(s.bar_idx)) + else: + labels.append("") + prev_bar_idx = s.bar_idx + self.ax.set_xticks([s.start_ticks for s in subdivisions]) self.ax.set_xticklabels(labels) self.ax.xaxis.grid(which="major", linewidth=0.1, color="gray") # Get labels for bar and beats prev_bar, prev_beat = 0, 0 bars_labels, beats_labels = [], [] for s in subdivisions: - if s["bar"] != prev_bar and not prev_bar == 0: + if s.bar_idx != prev_bar and not prev_bar == 0: bars_labels.append(s) - self.ax.axvline(x=s["ticks"], linestyle="--", linewidth=0.4, color="red") - if s["bar_beat"] != prev_beat and not prev_beat == 0: + self.ax.axvline(x=s.start_ticks, linestyle="-", linewidth=0.6, color="grey") + if s.beat_idx != prev_beat and not prev_beat == 0: beats_labels.append(s) - self.ax.axvline(x=s["ticks"], linestyle="--", linewidth=0.2, color="blue") - prev_bar, prev_beat = s["bar"], s["bar_beat"] - - @staticmethod - def _add_new_bar_subdiv(subdivisions): - """In `rhythm.get_subdivisions` method, we get all the subdivisions (starting times). 
- We can add the 1st subdivision of a new bar after the subdivisions dict for plotting.""" - range_sec = subdivisions[-1]["sec"] - subdivisions[-2]["sec"] - range_ticks = subdivisions[-1]["ticks"] - subdivisions[-2]["ticks"] - subdivisions.append({ - "bar": subdivisions[-1]["bar"] + 1, - "piece_beat": 1, - "piece_subdivision": 1, - "bar_beat": 1, - "bar_subdivision": 1, - "sec": subdivisions[-1]["sec"] + range_sec, - "ticks": subdivisions[-1]["ticks"] + range_ticks, - }) + self.ax.axvline(x=s.start_ticks, linestyle="-", linewidth=0.2, color="grey") + prev_bar, prev_beat = s.bar_idx, s.beat_idx - def _notes_loop(self, notes: List[Note]): + def _notes_loop(self, notes: List[Note], idx: int): plt.ylabel("Pitch") #highest_pitch = get_highest_pitch(track.instrument) #lowest_pitch = get_lowest_pitch(track.instrument) @@ -141,7 +152,7 @@ def _notes_loop(self, notes: List[Note]): plt.vlines(x=note.start_ticks, ymin=note.pitch, ymax=note.pitch + 1, - color=COLOR_EDGES[0], + color=COLOR_EDGES[idx], linewidth=0.01) self.ax.add_patch( @@ -149,76 +160,86 @@ def _notes_loop(self, notes: List[Note]): width=note.end_ticks - note.start_ticks, height=1, alpha=note.velocity / 127, - edgecolor=COLOR_EDGES[0], - facecolor=COLOR[0])) + edgecolor=COLOR_EDGES[idx], + facecolor=COLOR[idx])) - def plot_instrument( + def plot_instruments( self, - track, - total_bars: int, - subdivision: str, - time_sig: str = TimingConsts.DEFAULT_TIME_SIGNATURE.value, - bpm: int = TimingConsts.DEFAULT_BPM.value, - resolution: int = TimingConsts.RESOLUTION.value, - quantized: bool = False, + program: Union[int, List[int]], + bar_start: int, + bar_end: int, print_measure_data: bool = True, - show_bar_labels: bool = True + show_bar_labels: bool = True, + show_grid: bool = True, + show: bool = False, ): if print_measure_data: plt.text( - x=0, y=1.3, s=f"Measure: {time_sig}", transform=self.ax.transAxes, + x=0, y=1.3, s=f"Measure: {self.musa.time_signature_changes[0]}", transform=self.ax.transAxes, horizontalalignment='left', verticalalignment='top', fontsize=12) plt.text( - 0, 1.2, f"Displayed bars: {total_bars}", transform=self.ax.transAxes, + 0, 1.2, f"Displayed bars: {self.musa.total_bars}", transform=self.ax.transAxes, horizontalalignment='left', verticalalignment='top', fontsize=12) plt.text( - 0, 1.1, f"Quantized: {quantized}", transform=self.ax.transAxes, + 0, 1.1, f"Quantized: {self.musa.is_quantized}", transform=self.ax.transAxes, horizontalalignment='left', verticalalignment='top', fontsize=12) plt.text( - 1, 1.3, f"Tempo: {bpm}bpm", transform=self.ax.transAxes, + 1, 1.3, f"Tempo: {self.musa.tempo_changes[0]}bpm", transform=self.ax.transAxes, horizontalalignment='right', verticalalignment='top', fontsize=12) plt.text( - 1, 1.2, f"Subdivision: {subdivision}", transform=self.ax.transAxes, + 1, 1.2, f"Subdivision: {self.musa.subdivision_note}", transform=self.ax.transAxes, horizontalalignment='right', verticalalignment='top', fontsize=12) - #plt.text( - #1, 1.1, f"Instrument: {track.name}", transform=self.ax.transAxes, - #horizontalalignment='right', verticalalignment='top', fontsize=12) - - subdivisions = get_subdivisions(total_bars, subdivision, time_sig, bpm, resolution) - self.plot_grid(subdivisions) - self._notes_loop(track) + subdivisions = self.musa.get_subbeats_in_bars(bar_start, bar_end) + if show_grid: + self.plot_grid(subdivisions) + notes = self.musa.get_notes_in_bars(bar_start, bar_end, program) + if isinstance(program, list): + for i, p in enumerate(program): + notes_i = self.musa._filter_by_instruments(p, None, 
notes) + if notes_i is not None: + self._notes_loop(notes_i, i) + else: + self._notes_loop(notes, 0) if not show_bar_labels: self.ax.get_yaxis().set_visible(False) self.ax.get_xaxis().set_visible(False) + if show: + plt.show() class PianorollHTML: - """The Musa object need to be initialized with the argument - structure = `bars` for a good visualization of the pianoroll.""" - - #ax.yaxis.set_major_locator(MultipleLocator(1)) - #ax.grid(linewidth=0.25) - #ax.set_facecolor('#282828') - fig = go.Figure( - data=go.Scatter(), - layout=go.Layout( - { - "title": "", - #"template": "plotly_dark", - "xaxis": {'title': "subdivisions (bar.beat.subdivision)"}, - "yaxis": {'title': 'pitch'}, - } + def __init__( + self, + musa: Optional[Musa] = None, + ): + + """The Musa object need to be initialized with the argument + structure = `bars` for a good visualization of the pianoroll.""" + + self.musa = musa + #ax.yaxis.set_major_locator(MultipleLocator(1)) + #ax.grid(linewidth=0.25) + #ax.set_facecolor('#282828') + self.fig = go.Figure( + data=go.Scatter(), + layout=go.Layout( + { + "title": "", + #"template": "plotly_dark", + "xaxis": {'title': "subdivisions (bar)"}, + "yaxis": {'title': 'pitch'}, + } + ) ) - ) - def _notes_loop(self, notes: List[Note]): + def _notes_loop(self, notes: List[Note], idx: int): for note in notes: self.fig.add_shape( type="rect", @@ -227,10 +248,10 @@ def _notes_loop(self, notes: List[Note]): x1=note.end_ticks, y1=note.pitch + 1, line=dict( - color=COLOR_EDGES[0], + color=COLOR_EDGES[idx], width=2, ), - fillcolor=COLOR[0], + fillcolor=COLOR[idx], ) # this is to add a hover information on each note @@ -265,16 +286,21 @@ def plot_grid(self, subdivisions): # all the pitch values to be written in the axis #self.fig.update_xaxes(range[0, len(subdivisions) - 1]) #plt.xlim((0, len(subdivisions) - 1)) - # Add 1st subdivision of new bar after last bar for better plotting the last bar - Pianoroll._add_new_bar_subdiv(subdivisions) # Each subdivision has a vertical lines (grid) #self.fig.set_xticks([s["ticks"] for s in subdivisions]) ##labels = [str(s["ticks"]) for s in subdivisions] - labels = [str(s["bar"]) + "." + str(s["bar_beat"]) + "." 
+ str(s["bar_subdivision"]) for s in subdivisions] + prev_bar_idx = 0 + labels = [] + for s in subdivisions: + if s.bar_idx != prev_bar_idx: + labels.append(str(s.bar_idx)) + else: + labels.append("") + prev_bar_idx = s.bar_idx self.fig.update_layout( xaxis=dict( tickmode="array", - tickvals=[s["ticks"] for s in subdivisions], + tickvals=[s.start_ticks for s in subdivisions], ticktext=labels, tickfont=dict( size=10, @@ -286,49 +312,39 @@ def plot_grid(self, subdivisions): prev_bar, prev_beat = 0, 0 bars_labels, beats_labels = [], [] for s in subdivisions: - if s["bar"] != prev_bar and not prev_bar == 0: + if s.bar_idx != prev_bar and not prev_bar == 0: bars_labels.append(s) - self.fig.add_vline(x=s["ticks"], line_width=0.4, line_color="red") - if s["bar_beat"] != prev_beat and not prev_beat == 0: + self.fig.add_vline(x=s.start_ticks, line_width=0.6, line_color="grey") + if s.beat_idx != prev_beat and not prev_beat == 0: beats_labels.append(s) - self.fig.add_vline(x=s["ticks"], line_width=0.2, line_color="blue") - prev_bar, prev_beat = s["bar"], s["bar_beat"] + self.fig.add_vline(x=s.start_ticks, line_width=0.2, line_color="grey") + prev_bar, prev_beat = s.bar_idx, s.beat_idx - def plot_instrument( + def plot_instruments( self, - track: Instrument, + program: Union[int, List[int]], bar_start: int, bar_end: int, - subdivision: str, path: Union[Path, str] = Path("."), filename: str = "title", save_plot: bool = True, - time_sig: str = TimingConsts.DEFAULT_TIME_SIGNATURE.value, - bpm: int = TimingConsts.DEFAULT_BPM.value, - resolution: int = TimingConsts.RESOLUTION.value, + show_grid: bool = True, show: bool = True ): - pitches = [] - if track.bars is None: - warnings.warn("Track has no bars. You probably initialized the Musa object with structure=`instruments`. \n \ - You can use `structure=`bars` to get the track bars. 
\n \ - The plotter is going to ignore the bars and it'll plot all the notes in the track.") - self._notes_loop(track.notes) - for note in track.notes: - if note.pitch not in pitches: - pitches.append(note.pitch) - else: - # TODO - for bar in track.bars[bar_start:bar_end]: - self._notes_loop(bar.notes) - for note in bar.notes: - if note.pitch not in pitches: - pitches.append(note.pitch) - - total_bars = bar_end - bar_start - subdivisions = get_subdivisions(total_bars, subdivision, time_sig, bpm, resolution) + subdivisions = self.musa.get_subbeats_in_bars(bar_start, bar_end) + if show_grid: self.plot_grid(subdivisions) + notes = self.musa.get_notes_in_bars(bar_start, bar_end, program) + if isinstance(program, list): + for i, p in enumerate(program): + notes_i = self.musa._filter_by_instruments(p, None, notes) + if notes_i is not None: + self._notes_loop(notes_i, i) + else: + self._notes_loop(notes, 0) + + pitches = [note.pitch for note in notes] # this is to add the yaxis labels# horizontal line for pitch grid labels = [i for i in range(min(pitches) - 1, max(pitches) + 2)] @@ -341,12 +357,12 @@ def plot_instrument( cleaned_labels = labels # Adjust y labels (pitch) - # TODO: label in the middle of the pitch + # TODO: label between pitches self.fig.update_layout( yaxis=dict( tickmode="array", - tickvals=labels, - ticktext=labels, + tickvals=cleaned_labels, + ticktext=cleaned_labels, tickfont=dict( size=12, ), diff --git a/musicaiz/rhythm/__init__.py b/musicaiz/rhythm/__init__.py index d4b5a9d..fb0166f 100644 --- a/musicaiz/rhythm/__init__.py +++ b/musicaiz/rhythm/__init__.py @@ -61,7 +61,10 @@ ms_per_note, ms_per_bar, get_subdivisions, - get_symbolic_duration, + get_symbolic_duration, + Timing, + Beat, + Subdivision, ) from .quantizer import ( @@ -87,4 +90,7 @@ "SymbolicNoteLengths", "get_symbolic_duration", "TimeSignature", + "Timing", + "Beat", + "Subdivision", ] diff --git a/musicaiz/rhythm/quantizer.py b/musicaiz/rhythm/quantizer.py index 0adb369..030563f 100644 --- a/musicaiz/rhythm/quantizer.py +++ b/musicaiz/rhythm/quantizer.py @@ -2,8 +2,7 @@ import numpy as np from enum import Enum -from musicaiz.structure import notes -from musicaiz.rhythm.timing import TimingConsts, ms_per_tick +from musicaiz.rhythm import TimingConsts, ms_per_tick class QuantizerConfig(Enum): @@ -33,7 +32,7 @@ def get_ticks_from_subdivision( def basic_quantizer( - input_notes: notes, + input_notes, grid: List[int], bpm: int = TimingConsts.DEFAULT_BPM.value ): @@ -53,11 +52,11 @@ def basic_quantizer( input_notes[i].end_ticks = end_tick + abs(delta_tick) input_notes[i].start_sec = input_notes[i].start_ticks * ms_per_tick(bpm) - input_notes[i].end_sec =input_notes[i].end_ticks * ms_per_tick(bpm) + input_notes[i].end_sec = input_notes[i].end_ticks * ms_per_tick(bpm) def advanced_quantizer( - input_notes: notes, + input_notes, grid: List[int], strength: float = QuantizerConfig.STRENGTH.value, delta_Qr: int = QuantizerConfig.DELTA_QR.value, @@ -129,5 +128,5 @@ def advanced_quantizer( input_notes[i].start_ticks = input_notes[i].start_ticks + abs(delta_tick_q) input_notes[i].end_ticks = input_notes[i].end_ticks + abs(delta_tick_q) - input_notes[i].start_sec = input_notes[i].start_ticks * ms_per_tick(bpm) // 1000 - input_notes[i].end_sec = input_notes[i].end_ticks * ms_per_tick(bpm) // 1000 + input_notes[i].start_sec = input_notes[i].start_ticks * ms_per_tick(bpm) / 1000 + input_notes[i].end_sec = input_notes[i].end_ticks * ms_per_tick(bpm) / 1000 diff --git a/musicaiz/rhythm/timing.py b/musicaiz/rhythm/timing.py index 
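Before the `timing.py` changes below, a small sketch of how the quantizer is typically driven (the note and grid values are illustrative; `basic_quantizer` snaps the note timings to the given tick grid in place and updates the derived fields):

>>> from musicaiz.rhythm.quantizer import basic_quantizer
>>> from musicaiz.structure import Note
>>> notes = [Note(pitch=60, start=470, end=950, velocity=90)]  # integer timings are ticks
>>> basic_quantizer(notes, grid=[0, 480, 960, 1440], bpm=120)
>>> # notes[0].start_ticks / end_ticks are now aligned to the grid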
1535be5..738f093 100644 --- a/musicaiz/rhythm/timing.py +++ b/musicaiz/rhythm/timing.py @@ -1,6 +1,7 @@ from __future__ import annotations +from abc import ABCMeta from enum import Enum -from typing import Tuple, List, Dict, Union +from typing import Tuple, List, Dict, Union, Optional import numpy as np import math @@ -70,18 +71,22 @@ def ms( resolution: int = TimingConsts.RESOLUTION.value, ) -> float: return ms_per_tick(bpm, resolution) * self.ticks(resolution) - + @classmethod - def get_note_ticks_mapping(cls, triplets: bool = False) -> Dict[str, int]: + def get_note_ticks_mapping( + cls, + triplets: bool = False, + resolution: int = TimingConsts.RESOLUTION.value, + ) -> Dict[str, int]: dict_notes = {} for note_dur in list(cls.__members__.keys()): if not triplets: # remove triplet durations (optional) if "TRIPLET" in note_dur: continue - dict_notes.update({note_dur: cls[note_dur].ticks()}) + dict_notes.update({note_dur: cls[note_dur].ticks(resolution)}) return dict_notes - + @classmethod def get_note_with_fraction(cls, fraction: float) -> NoteLengths: for note in cls.__members__: @@ -167,27 +172,27 @@ def __init__(self, time_sig: Union[Tuple[int, int], str]): @property def beats_per_bar(self) -> int: return self.num - + @property def beat_type(self) -> str: return TimeSigDenominators.get_note_length(self.denom).name - + def _notes_per_bar(self, note_name: str) -> int: return (1 / NoteLengths[note_name].value) * self.num * (1 / self.denom) - + @property def quarters(self) -> int: # Get the name of the denominator note: NoteLengths(1 / self.denom).name return self._notes_per_bar("QUARTER") - + @property def eights(self) -> int: return self._notes_per_bar("EIGHT") - + @property def sixteenths(self) -> int: return self._notes_per_bar("SIXTEENTH") - + def __repr__(self): return "TimeSig(num={}, den={})".format( self.num, @@ -195,9 +200,150 @@ def __repr__(self): ) +class Timing(metaclass=ABCMeta): + + def __init__( + self, + bpm: float, + resolution: int, + start: Union[int, float], + end: Union[int, float], + ): + self.ms_tick = ms_per_tick(bpm, resolution) + + timings = self._initialize_timing_attributes( + start, end, self.ms_tick + ) + + self.start_ticks = timings["start_ticks"] + self.end_ticks = timings["end_ticks"] + self.start_sec = timings["start_sec"] + self.end_sec = timings["end_sec"] + self.bpm = bpm + self.resolution = resolution + + @staticmethod + def _initialize_timing_attributes( + start: Union[int, float], + end: Union[int, float], + ms_tick: Union[int, float], + ) -> Dict[str, Union[int, float]]: + # inital checks + if start < 0 or end <= 0: + raise ValueError("Start and end must be positive.") + elif start >= end: + raise ValueError("Start time must be lower than the end time.") + + # ticks must be int, secs must be float + if isinstance(start, int) and isinstance(end, int): + start_ticks = start + end_ticks = end + start_sec = start_ticks * ms_tick / 1000 + end_sec = end_ticks * ms_tick / 1000 + elif isinstance(start, float) and isinstance(end, float): + start_sec = start + end_sec = end + start_ticks = int(start_sec * (1 / (ms_tick / 1000))) + end_ticks = int(end_sec * (1 / (ms_tick / 1000))) + + timings = { + "start_ticks": start_ticks, + "end_ticks": end_ticks, + "start_sec": start_sec, + "end_sec": end_sec, + } + return timings + + +class Beat(Timing): + + def __init__( + self, + bpm: float, + resolution: int, + start: Union[int, float], + end: Union[int, float], + time_sig: Optional[TimeSignature] = None, + global_idx: Optional[int] = None, + bar_idx: Optional[int] 
= None, + ): + + super().__init__(bpm, resolution, start, end) + + self.time_sig = time_sig + self.global_idx = global_idx + self.bar_idx = bar_idx + self.symbolic = TimeSigDenominators.get_note_length( + time_sig.denom + ).name.lower() + + def __repr__(self): + + return "Beat(time_signature={}, " \ + "bpm={}, " \ + "start_ticks={} " \ + "end_ticks={} " \ + "start_sec={} " \ + "end_sec={} " \ + "global_idx={} " \ + "bar_idx={} " \ + "symbolic={})".format( + self.time_sig, + self.bpm, + self.start_ticks, + self.end_ticks, + self.start_sec, + self.end_sec, + self.global_idx, + self.bar_idx, + self.symbolic, + ) + + +class Subdivision(Timing): + + def __init__( + self, + bpm: float, + resolution: int, + start: Union[int, float], + end: Union[int, float], + time_sig: Optional[TimeSignature] = None, + global_idx: Optional[int] = None, + bar_idx: Optional[int] = None, + beat_idx: Optional[int] = None, + ): + + super().__init__(bpm, resolution, start, end) + + self.time_sig = time_sig + self.global_idx = global_idx + self.bar_idx = bar_idx + self.beat_idx = beat_idx + + def __repr__(self): + + return "Subdivision(time_signature={}, " \ + "bpm={}, " \ + "start_ticks={} " \ + "end_ticks={} " \ + "start_sec={} " \ + "end_sec={} " \ + "global_idx={} " \ + "bar_idx={} " \ + "beat_idx={})".format( + self.time_sig, + self.bpm, + self.start_ticks, + self.end_ticks, + self.start_sec, + self.end_sec, + self.global_idx, + self.bar_idx, + self.beat_idx, + ) + -# TODO: Refactor all these functions to a class with bpm and resolution as attributes? -# We are repeating so many times the bpm and resolution input args def ms_per_tick( bpm: int = TimingConsts.DEFAULT_BPM.value, resolution: int = TimingConsts.RESOLUTION.value, @@ -363,7 +509,7 @@ def get_subdivisions( resolution: int the pulses o ticks per quarter note (PPQ or TPQN). - + absolute_timing: bool default is True. This allows to initialize note time arguments in absolute (True) or relative time units (False). Relative units means that each bar will start at 0 seconds @@ -466,10 +612,14 @@ def get_subdivisions( return beat_subdivs -def get_symbolic_duration(duration: int, triplets: bool = False) -> str: +def get_symbolic_duration( + duration: int, + triplets: bool = False, + resolution: int = TimingConsts.RESOLUTION.value, +) -> str: """Given a note duration in ticks it calculates its symbolic duration: half, quarter, dotted_half... - + Parameters ----------- @@ -478,7 +628,7 @@ def get_symbolic_duration(duration: int, triplets: bool = False) -> str: triplets durations. 
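As a quick illustration of the new `resolution` argument (assuming the default `NoteLengths` tick values), the same symbolic duration is recovered at different PPQ settings:

>>> from musicaiz.rhythm import get_symbolic_duration
>>> get_symbolic_duration(480, resolution=480)  # one quarter note at 480 PPQ
'QUARTER'
>>> get_symbolic_duration(96, resolution=96)    # same duration expressed at 96 PPQ
'QUARTER'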
""" - all_notes_ticks = NoteLengths.get_note_ticks_mapping(triplets) + all_notes_ticks = NoteLengths.get_note_ticks_mapping(triplets, resolution) notes_ticks = all_notes_ticks # look for the closest note in the notes ticks dict @@ -487,4 +637,4 @@ def get_symbolic_duration(duration: int, triplets: bool = False) -> str: arr = np.asarray(list(notes_ticks.values())) i = (np.abs(arr - duration)).argmin() symbolic_duration = list(notes_ticks.keys())[i] - return symbolic_duration \ No newline at end of file + return symbolic_duration diff --git a/musicaiz/structure/bars.py b/musicaiz/structure/bars.py index 83933fd..fe5845f 100644 --- a/musicaiz/structure/bars.py +++ b/musicaiz/structure/bars.py @@ -1,4 +1,4 @@ -from typing import List, Union +from typing import List, Union, Optional import numpy as np @@ -6,6 +6,7 @@ TimingConsts, ms_per_bar, ms_per_tick, + Timing, ) from musicaiz.structure import Note @@ -38,10 +39,12 @@ class Bar: def __init__( self, + start: Optional[Union[int, float]] = None, + end: Optional[Union[int, float]] = None, time_sig: str = TimingConsts.DEFAULT_TIME_SIGNATURE.value, bpm: int = TimingConsts.DEFAULT_BPM.value, resolution: int = TimingConsts.RESOLUTION.value, - absolute_timing: bool = True + absolute_timing: bool = True, ): self.bpm = bpm self.time_sig = time_sig @@ -50,13 +53,23 @@ def __init__( # The following attributes are set when loading a MIDI file # with Musa class - self.notes = [] self.note_density = None self.harmonic_density = None - self.start_ticks = None - self.end_ticks = None - self.start_sec = None - self.end_sec = None + + self.ms_tick = ms_per_tick(bpm, resolution) + + if start is not None and end is not None: + timings = Timing._initialize_timing_attributes(start, end, self.ms_tick) + + self.start_ticks = timings["start_ticks"] + self.end_ticks = timings["end_ticks"] + self.start_sec = timings["start_sec"] + self.end_sec = timings["end_sec"] + else: + self.start_ticks = None + self.end_ticks = None + self.start_sec = None + self.end_sec = None def relative_notes_timing(self, bar_start: float): """The bar start is the value in ticks where the bar starts""" @@ -118,12 +131,14 @@ def __repr__(self): else: end_sec = self.end_sec - return "Bar(note_density={}, " \ + return "Bar(time_signature={}, " \ + "note_density={}, " \ "harmonic_density={} " \ "start_ticks={} " \ "end_ticks={} " \ "start_sec={} " \ "end_sec={})".format( + self.time_sig, self.note_density, self.harmonic_density, self.start_ticks, diff --git a/musicaiz/structure/instruments.py b/musicaiz/structure/instruments.py index 0279530..87fe612 100644 --- a/musicaiz/structure/instruments.py +++ b/musicaiz/structure/instruments.py @@ -12,6 +12,7 @@ except ImportError: _HAS_FLUIDSYNTH = False + class InstrumentMidiPrograms(Enum): # Value 1: List of Midi instrument program number ACOUSTIC_GRAND_PIANO = 0 @@ -442,7 +443,7 @@ def __init__( else: if program is not None: self.program = program - if name is None: + if name is None or name == "": name_obj = InstrumentMidiPrograms.get_name_from_program(program) else: if general_midi: @@ -480,119 +481,6 @@ def __init__( # List of bars in the instrument self.bars = [] - # TODO - def concatenate_instruments( - instrument_1: Instrument, - instrument_2: Instrument, - new_program: int = None, - new_name: str = None, - ): - """Creates a new instrument by concatenating the notes - of 2 instruments and setting a new name and program. 
- If no name nor program is provided, they will be set - as the `instrument_1`.""" - pass - - def fluidsynth( - self, - fs=44100, - sf2_path=None - ): - """Synthesize using fluidsynth. - Parameters - ---------- - fs : int - Sampling rate to synthesize. - sf2_path : str - Path to a .sf2 file. - Default ``None``, which uses the TimGM6mb.sf2 file included with - ``pretty_midi``. - Returns - ------- - synthesized : np.ndarray - Waveform of the MIDI data, synthesized at ``fs``. - """ - # If sf2_path is None, use the included TimGM6mb.sf2 path - if sf2_path is None: - raise ValueError(f"No sf2_path provided.") - - if not _HAS_FLUIDSYNTH: - raise ImportError("fluidsynth() was called but pyfluidsynth " - "is not installed.") - - if not Path.exists(sf2_path): - raise ValueError("No soundfont file found at the supplied path " - "{}".format(sf2_path)) - - # If the instrument has no notes, return an empty array - if len(self.notes) == 0: - return np.array([]) - - # Create fluidsynth instance - fl = fluidsynth.Synth(samplerate=fs) - # Load in the soundfont - sfid = fl.sfload(sf2_path) - # If this is a drum instrument, use channel 9 and bank 128 - if self.is_drum: - channel = 9 - # Try to use the supplied program number - res = fl.program_select(channel, sfid, 128, self.program) - # If the result is -1, there's no preset with this program number - if res == -1: - # So use preset 0 - fl.program_select(channel, sfid, 128, 0) - # Otherwise just use channel 0 - else: - channel = 0 - fl.program_select(channel, sfid, 0, self.program) - # Collect all notes in one list - event_list = [] - for note in self.notes: - event_list += [[note.start, 'note on', note.pitch, note.velocity]] - event_list += [[note.end, 'note off', note.pitch]] - for bend in self.pitch_bends: - event_list += [[bend.time, 'pitch bend', bend.pitch]] - for control_change in self.control_changes: - event_list += [[control_change.time, 'control change', - control_change.number, control_change.value]] - # Sort the event list by time, and secondarily by whether the event - # is a note off - event_list.sort(key=lambda x: (x[0], x[1] != 'note off')) - # Add some silence at the beginning according to the time of the first - # event - current_time = event_list[0][0] - # Convert absolute seconds to relative samples - next_event_times = [e[0] for e in event_list[1:]] - for event, end in zip(event_list[:-1], next_event_times): - event[0] = end - event[0] - # Include 1 second of silence at the end - event_list[-1][0] = 1. 
- # Pre-allocate output array - total_time = current_time + np.sum([e[0] for e in event_list]) - synthesized = np.zeros(int(np.ceil(fs*total_time))) - # Iterate over all events - for event in event_list: - # Process events based on type - if event[1] == 'note on': - fl.noteon(channel, event[2], event[3]) - elif event[1] == 'note off': - fl.noteoff(channel, event[2]) - elif event[1] == 'pitch bend': - fl.pitch_bend(channel, event[2]) - elif event[1] == 'control change': - fl.cc(channel, event[2], event[3]) - # Add in these samples - current_sample = int(fs*current_time) - end = int(fs*(current_time + event[0])) - samples = fl.get_samples(end - current_sample)[::2] - synthesized[current_sample:end] += samples - # Increment the current sample - current_time += event[0] - # Close fluidsynth - fl.delete() - - return synthesized - def __repr__(self): if self.family is None: family = "unknown" diff --git a/musicaiz/structure/notes.py b/musicaiz/structure/notes.py index 6a135c3..1d1ff83 100644 --- a/musicaiz/structure/notes.py +++ b/musicaiz/structure/notes.py @@ -1,6 +1,6 @@ from __future__ import annotations import pretty_midi as pm -from typing import List, Tuple, Union +from typing import List, Tuple, Union, Optional, Dict import re from enum import Enum @@ -9,7 +9,8 @@ TimingConsts, ms_per_tick, get_symbolic_duration, - SymbolicNoteLengths + SymbolicNoteLengths, + Timing ) @@ -216,7 +217,7 @@ def get_natural_scale_notes(cls) -> List[NoteClassBase]: if cls[str(note)].natural_scale_index is not None: notes.append(cls[str(note)]) return notes - + @classmethod def get_all_chromatic_scale_notes(cls) -> List[NoteClassBase]: """Returns all the notes in chromatic scale (flats AND sharps).""" @@ -393,27 +394,17 @@ def __init__( self.ms_tick = ms_per_tick(bpm, resolution) - # inital checks - if start < 0 or end <= 0: - raise ValueError("Start and end must be positive.") - elif start >= end: - raise ValueError("Start time must be lower than the end time.") - - # ticks must be int, secs must be float - if isinstance(start, int) and isinstance(end, int): - self.start_ticks = start - self.end_ticks = end - self.start_sec = self.start_ticks * self.ms_tick / 1000 - self.end_sec = self.end_ticks * self.ms_tick / 1000 - elif isinstance(start, float) and isinstance(end, float): - self.start_sec = start - self.end_sec = end - self.start_ticks = int(self.start_sec * (1 / (self.ms_tick / 1000))) - self.end_ticks = int(self.end_sec * (1 / (self.ms_tick / 1000))) + timings = Timing._initialize_timing_attributes(start, end, self.ms_tick) + + self.start_ticks = timings["start_ticks"] + self.end_ticks = timings["end_ticks"] + self.start_sec = timings["start_sec"] + self.end_sec = timings["end_sec"] self.symbolic = get_symbolic_duration( self.end_ticks - self.start_ticks, - True + True, + resolution ) def __repr__(self): @@ -477,7 +468,11 @@ class Note(NoteTiming): "velocity", "bpm", "resolution", - "ligated" + "ligated", + "instrument_prog", + "bar_idx", + "instrument_idx", + "is_drum" ] def __init__( @@ -486,18 +481,37 @@ def __init__( start: Union[int, float], end: Union[int, float], velocity: int, + instrument_prog: Optional[int] = None, + instrument_idx: Optional[int] = None, + bar_idx: Optional[int] = None, + beat_idx: Optional[int] = None, + subbeat_idx: Optional[int] = None, ligated: bool = False, bpm: int = TimingConsts.DEFAULT_BPM.value, resolution: int = TimingConsts.RESOLUTION.value, + is_drum: bool = False ): super().__init__(pitch, start, end, bpm, resolution) self.velocity = velocity + 
self.instrument_prog = instrument_prog + self.bar_idx = bar_idx + self.beat_idx = beat_idx + self.subbeat_idx = subbeat_idx + + # We can have 2 instruments (or tracks) with the same program number, + # so this will store the index of the instrument to distinguish equal + # program number instruments in MIDI files + self.instrument_idx = instrument_idx # if a note belongs to 2 bars and we split the tracks by bars self.ligated = ligated + self.resolution = resolution + self.bpm = bpm + self.is_drum = is_drum + def __repr__(self): return "Note(pitch={}, " \ "name={}, " \ @@ -507,7 +521,11 @@ def __repr__(self): "end_ticks={}, " \ "symbolic={}, " \ "velocity={}, " \ - "ligated={})".format( + "ligated={}, " \ + "instrument_prog={}, " \ + "bar_idx={}, " \ + "beat_idx={}, " \ + "subbeat_idx={})".format( self.pitch, self.note_name, self.start_sec, @@ -516,5 +534,9 @@ def __repr__(self): self.end_ticks, SymbolicNoteLengths[self.symbolic].value, self.velocity, - self.ligated + self.ligated, + self.instrument_prog, + self.bar_idx, + self.beat_idx, + self.subbeat_idx, ) diff --git a/musicaiz/tokenizers/__init__.py b/musicaiz/tokenizers/__init__.py index c937933..393de22 100644 --- a/musicaiz/tokenizers/__init__.py +++ b/musicaiz/tokenizers/__init__.py @@ -38,6 +38,8 @@ TokenizerArguments MMMTokenizerArguments MMMTokenizer + REMITokenizerArguments + REMITokenizer """ @@ -51,6 +53,10 @@ MMMTokenizer, MMMTokenizerArguments, ) +from .remi import ( + REMITokenizer, + REMITokenizerArguments, +) from .one_hot import ( OneHot, ) @@ -58,11 +64,13 @@ TOKENIZER_ARGUMENTS = [ MMMTokenizerArguments, + REMITokenizerArguments, ] class Tokenizers(Enum): MULTI_TRACK_MUSIC_MACHINE = ("MMM", MMMTokenizerArguments) + REMI = ("REMI", REMITokenizerArguments) @property def name(self): @@ -88,5 +96,7 @@ def args(): "Tokenizers", "MMMTokenizerArguments", "MMMTokenizer", + "REMITokenizerArguments", + "REMITokenizer", "OneHot" ] diff --git a/musicaiz/tokenizers/encoder.py b/musicaiz/tokenizers/encoder.py index 133409d..5f14e5b 100644 --- a/musicaiz/tokenizers/encoder.py +++ b/musicaiz/tokenizers/encoder.py @@ -1,28 +1,116 @@ -from abc import ABCMeta, abstractmethod -from typing import List +from abc import ABCMeta +from typing import List, Dict from enum import Enum from pathlib import Path -_POPULATE_NOTE_ON = [f"NOTE_ON={i} " for i in range(0, 128)] -_POPULATE_NOTE_OFF = [f"NOTE_OFF={i} " for i in range(0, 128)] -_POPULATE_NOTE_INST = [f"NOTE_INST={i} " for i in range(0, 128)] +class TokenizerArguments: + pass -class Tokens(Enum): - GENRE = [] - SUBGENRE = [] - STRUCTURE = ["piece", "instrument", "bar"] - NOTE_DENSITY = ["piece", "instrument", "bar"] - HARMONIC_DENSITY = ["piece", "instrument", "bar"] - CHORDS = ["piece", "instrument", "bar", "time_step"] +class EncodeBase(metaclass=ABCMeta): + @classmethod + def _get_tokens_analytics( + cls, + tokens: str, + note_token: str, + piece_start_token: str, + ) -> Dict[str, int]: + """ + Extracts features to aanlyze the given token sequence. -class TokenizerArguments: - pass + Parameters + ---------- + tokens: str + A token sequence. 
-class EncodeBase(metaclass=ABCMeta): + Returns + ------- + + analytics: Dict[str, int] + The ``analytics`` dict keys are: + - ``total_tokens`` + - ``unique_tokens`` + - ``total_notes`` + - ``unique_notes`` + - ``total_bars``: non empty bars + - ``total_instruments`` + - ``unique_instruments`` + """ + # Convert str in list of pieces that contain tokens + # We suppose that the piece starts with a BAR=0 token (that is, any instr has notes in the 1st bar) + dataset_tokens = cls._get_pieces_tokens(tokens, piece_start_token) + # Start the analysis + note_counts, bar_counts, instr_counts = 0, 0, 0 # total notes and bars (also repeated note values) + total_toks = 0 + unique_tokens, unique_notes, unique_instr = [], [], [] # total non-repeated tokens + unique_genres, unique_composers, unique_periods = [], [], [] + for piece, toks in enumerate(dataset_tokens): + for tok in toks: + total_toks += 1 + if tok not in unique_tokens: + unique_tokens.append(tok) + if note_token in tok: + note_counts += 1 + if "BAR" in tok: + bar_counts += 1 + if "INST" in tok: + instr_counts += 1 + if note_token in tok and tok not in unique_notes: + unique_notes.append(tok) + if "INST" in tok and tok not in unique_instr: + unique_instr.append(tok) + if "GENRE" in tok and tok not in unique_genres: + unique_genres.append(tok) + if "PERIOD" in tok and tok not in unique_periods: + unique_periods.append(tok) + if "COMPOSER" in tok and tok not in unique_composers: + unique_composers.append(tok) + if piece_start_token == "BAR=0": + bar_counts += 1 + analytics = { + "total_pieces": piece + 1, + "total_tokens": len(tokens.split(" ")), + "unique_tokens": len(unique_tokens), + "total_notes": note_counts, + "unique_notes": len(unique_notes), + "total_bars": bar_counts, + "total_instruments": instr_counts, + } + if len(unique_genres) != 0: + analytics.update({"unique_genres": len(unique_genres)}) + if len(unique_periods) != 0: + analytics.update({"unique_periods": len(unique_periods)}) + if len(unique_composers) != 0: + analytics.update({"unique_composers": len(unique_composers)}) + + return analytics + + @staticmethod + def _get_pieces_tokens(tokens: str, token: str) -> List[List[str]]: + """Converts the tokens str that can contain one or more + pieces into a list of pieces that are also lists which contain + one item per token. + + Example (MMMTokenizer) + ---------------------- + >>> tokens = "PIECE_START INST=0 ... PIECE_START ..." 
+ >>> dataset_tokens = _get_pieces_tokens(tokens, "PIECE_START") + >>> [ + ["PIECE_START INST=0 ...], + ["PIECE_START ...], + ] + """ + tokens = tokens.split(token) + if "" in tokens: tokens.remove("") + dataset_tokens = [] + for piece in tokens: + piece_tokens = piece.split(" ") + if "" in piece_tokens: piece_tokens.remove("") + dataset_tokens.append(piece_tokens) + return dataset_tokens def get_vocabulary( dataset_path: str, @@ -55,9 +143,9 @@ def get_vocabulary( with open(Path(dataset_path, vocab_filename + ".txt"), "w") as vocab_file: vocab_file.write(" ".join(vocabulary)) - + return vocabulary - + @staticmethod def add_token_to_vocabulary(): pass diff --git a/musicaiz/tokenizers/mmm.py b/musicaiz/tokenizers/mmm.py index 34c1ee4..4166a6b 100644 --- a/musicaiz/tokenizers/mmm.py +++ b/musicaiz/tokenizers/mmm.py @@ -24,7 +24,7 @@ logger = logging.getLogger("mmm-tokenizer") -logging.basicConfig(level = logging.INFO) +logging.basicConfig(level=logging.INFO) @dataclass @@ -36,38 +36,38 @@ class MMMTokenizerArguments(TokenizerArguments): windowing: bool if True, the method tokenizes each file by applying bars windowing. - + time_unit: str the note length in `VALID_TIME_UNITS` that one `TIME_DELTA` unit will be equal to. This allows to tokenize in a wide variety of note lengths for diverse purposes. Be careful when choosing this value because if there are notes which duration is lower than the chosen time_unit value, they won't be tokenized. - + num_programs: List[int] the number of programs to tokenize. If None, the method tokenizes all the tracks. - + shuffle_tracks: bool shuffles the order of tracks in each window (PIECE). - + track_density: bool if True a token DENSITY is added at the beggining of each track. - + window_size: int the number of bars per track to tokenize. - + hop_length: int the number of bars to slice when tokenizing. If a MIDI file contains 5 bars and the window size is 4 and the hop length is 1, it'll be splitted in 2 PIECE tokens, one from bar 1 to 4 and the other on from bar 2 to 5 (somehow like audio FFT). - + time_sig: bool if we want to include the time signature in the samples. Note that the time signature will be added to the piece-level, that is, before the first track starts. - + velocity: bool if we want to add the velocity token. Velocities ranges between 1 and 128 (ints). - + quantize: bool if we want to quantize the symbolic music data for tokenizing. 
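A minimal sketch of how these arguments are consumed (the MIDI file name is hypothetical, `time_unit` is assumed to be a valid entry of the module's `VALID_TIME_UNITS`, and the remaining arguments are assumed to keep their defaults):

>>> from musicaiz.tokenizers import MMMTokenizer, MMMTokenizerArguments
>>> args = MMMTokenizerArguments(windowing=False, time_unit="SIXTEENTH", velocity=True)
>>> tokens = MMMTokenizer(file="my_song.mid", args=args).tokenize_file()
>>> MMMTokenizer.get_tokens_analytics(tokens)["total_notes"]  # counts NOTE_ON tokens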
""" @@ -102,19 +102,20 @@ def __init__( ): if args is None: - raise ValueError(f"No `MMMTokenizerArguments` passed.") + raise ValueError("No `MMMTokenizerArguments` passed.") self.args = args # Convert file into a Musa object to be processed if file is not None: self.midi_object = Musa( file=file, - structure="bars", absolute_timing=False, quantize=self.args.quantize, cut_notes=False ) - + else: + self.midi_object = Musa(file=None) + def tokenize_file( self, ) -> str: @@ -140,7 +141,8 @@ def tokenize_file( if not self.args.windowing: if self.args.time_sig: - time_sig_tok = f"TIME_SIG={self.midi_object.time_sig.time_sig} " + time_sig = self.midi_object.time_signature_changes[0]['time_sig'].time_sig + time_sig_tok = f"TIME_SIG={time_sig} " else: time_sig_tok = "" if self.args.tempo: @@ -150,7 +152,7 @@ def tokenize_file( tokens = self.tokenize_tracks( instruments=tokenized_instruments, bar_start=0, - tokens="PIECE_START " + self.args.prev_tokens + " " + time_sig_tok + tempo_tok, + tokens="PIECE_START " + self.args.prev_tokens + time_sig_tok + tempo_tok, ) tokens += "\n" else: @@ -162,14 +164,15 @@ def tokenize_file( if i + self.args.window_size == self.midi_object.total_bars: break if self.args.time_sig: - time_sig_tok = f"TIME_SIG={self.midi_object.time_sig.time_sig} " + time_sig = self.midi_object.time_signature_changes[0]['time_sig'].time_sig + time_sig_tok = f"TIME_SIG={time_sig} " else: time_sig_tok = "" tokens += self.tokenize_tracks( tokenized_instruments, bar_start=i, - bar_end=i+self.args.window_size, - tokens="PIECE_START " + self.args.prev_tokens + " " + time_sig_tok, + bar_end=i + self.args.window_size, + tokens="PIECE_START " + self.args.prev_tokens + time_sig_tok, ) tokens += "\n" return tokens @@ -190,12 +193,6 @@ def tokenize_tracks( instruments: List[Instrument] the list of instruments to tokenize. - track_density: bool - if True a token DENSITY is added at the beggining of each track. - The token DENSITY is the total notes of the track or instrument. 
- - velocity: bool = False - Returns ------- @@ -214,9 +211,9 @@ def tokenize_tracks( # loop in bars if bar_end is None: bar_end = len(inst.bars) - bars = inst.bars[bar_start:bar_end] + bars = self.midi_object.bars[bar_start:bar_end] tokens = self.tokenize_track_bars( - bars, tokens + bars, inst.program, tokens ) if inst_idx + 1 == len(instruments): tokens += "TRACK_END" @@ -227,6 +224,7 @@ def tokenize_tracks( def tokenize_track_bars( self, bars: List[Bar], + program: int, tokens: Optional[str] = None, ) -> str: """ @@ -253,16 +251,14 @@ def tokenize_track_bars( if self.args.time_unit not in VALID_TIME_UNITS: raise ValueError(f"Invalid time unit: {self.args.time_unit}") - for bar in bars: + for b, bar in enumerate(bars): bar_start = bar.start_ticks bar_end = bar.end_ticks # sort notes by start_ticks - bar.notes = sort_notes(bar.notes) - all_note_starts = [note.start_ticks for note in bar.notes] - all_note_ends = [note.end_ticks for note in bar.notes] + notes = self.midi_object.get_notes_in_bar(b, program) tokens += "BAR_START " - if len(bar.notes) == 0: + if len(notes) == 0: delta_symb = get_symbolic_duration( bar_end - bar_start, True ) @@ -274,16 +270,18 @@ def tokenize_track_bars( tokens += "BAR_END " continue else: - if bar.notes[0].start_ticks - bar_start != 0: + all_note_starts = [note.start_ticks for note in notes] + all_note_ends = [note.end_ticks for note in notes] + if notes[0].start_ticks - bar_start != 0: delta_symb = get_symbolic_duration( - bar.notes[0].start_ticks, True + notes[0].start_ticks, True ) delta_val = int( NoteLengths[delta_symb].value / NoteLengths[self.args.time_unit].value ) - #tokens += f"TIME_DELTA={delta_val - bar_start} " if delta_val - bar_start != 0 else "TIME_DELTA=1 " - if delta_val - bar_start != 0: tokens += f"TIME_DELTA={delta_val - bar_start} " - + if delta_val - bar_start != 0: + tokens += f"TIME_DELTA={delta_val - bar_start} " + all_time_events = all_note_starts + all_note_ends num_notes = len(all_note_starts) i = 0 @@ -293,12 +291,12 @@ def tokenize_track_bars( # The 1st note event will always be the 1st note on note_idx = event_idx % num_notes if event_idx < num_notes: - tokens += f"NOTE_ON={bar.notes[note_idx].pitch} " + tokens += f"NOTE_ON={notes[note_idx].pitch} " if self.args.velocity: - tokens += f"VELOCITY={bar.notes[note_idx].velocity} " + tokens += f"VELOCITY={notes[note_idx].velocity} " else: - tokens += f"NOTE_OFF={bar.notes[note_idx].pitch} " - + tokens += f"NOTE_OFF={notes[note_idx].pitch} " + if len(event_idxs) == len(all_time_events): break @@ -314,11 +312,11 @@ def tokenize_track_bars( delta_val = int( NoteLengths[delta_symb].value / NoteLengths[self.args.time_unit].value ) - #tokens += f"TIME_DELTA={delta_val} " if delta_val != 0 else "TIME_DELTA=1 " - if delta_val != 0: tokens += f"TIME_DELTA={delta_val} " + if delta_val != 0: + tokens += f"TIME_DELTA={delta_val} " list_indexes = [i for i, diff in enumerate(diffs) if diff == time_delta and i not in event_idxs] - + els_on = [el for el in list_indexes if el < num_notes] els_off = [el for el in list_indexes if el >= num_notes] if len(els_on) != 0 and len(els_off) != 0: @@ -329,15 +327,15 @@ def tokenize_track_bars( event_idx = min(els_on) i += 1 event_idxs.append(event_idx) - if bar.notes[-1].end_ticks < bar_end: + if notes[-1].end_ticks < bar_end: delta_symb = get_symbolic_duration( - bar_end - bar.notes[-1].end_ticks, True + bar_end - notes[-1].end_ticks, True ) delta_val = int( NoteLengths[delta_symb].value / NoteLengths[self.args.time_unit].value ) - #tokens += 
f"TIME_DELTA={delta_val} " if delta_val != 0 else "TIME_DELTA=1 " - if delta_val != 0: tokens += f"TIME_DELTA={delta_val} " + if delta_val != 0: + tokens += f"TIME_DELTA={delta_val} " tokens += "BAR_END " return tokens @@ -390,10 +388,15 @@ def tokens_to_musa( """Converts a str valid tokens sequence in Musa objects.""" # Initialize midi file to write - midi = Musa() - midi.time_sig = TimeSignature(time_sig) - - instruments_tokens = cls.split_tokens_by_track(tokens) + midi = Musa(file=None) + midi.time_signature_changes = [ + { + "time_sig": TimeSignature(time_sig), + "ms": 0.0 + } + ] + tokens_list = tokens.split(" ") + instruments_tokens = cls.split_tokens_by_track(tokens_list) for inst_idx, instr_tokens in enumerate(instruments_tokens): # First index in instr_tokens is the instr program # We just want the INST token in this loop @@ -407,7 +410,7 @@ def tokens_to_musa( global_time_delta = 0 for bar_idx, bar in enumerate(bar_tokens): bar_obj = Bar() - midi.instruments[inst_idx].bars.append(bar_obj) + midi.bars.append(bar_obj) if absolute_timing: global_time_delta_ticks = bar_idx * ticks_bar else: @@ -445,104 +448,19 @@ def tokens_to_musa( start=start_time, end=end_time, velocity=int(vel), + instrument_prog=midi.instruments[inst_idx].program, + bar_idx=idx ) - midi.instruments[inst_idx].bars[bar_idx].notes.append(note) + midi.notes.append(note) break else: continue return midi - @staticmethod - def _get_pieces_tokens(tokens: str) -> List[List[str]]: - """Converts the tokens str that can contain one or more - pieces into a list of pieces that are also lists which contain - one item per token. - - Example - ------- - >>> tokens = "PIECE_START INST=0 ... PIECE_START ..." - >>> dataset_tokens = _get_pieces_tokens(tokens) - >>> [ - ["PIECE_START INST=0 ...], - ["PIECE_START ...], - ] - """ - tokens = tokens.split("PIECE_START") - tokens.remove("") - dataset_tokens = [] - for piece in tokens: - piece_tokens = piece.split(" ") - piece_tokens.remove("") - dataset_tokens.append(piece_tokens) - return dataset_tokens + @classmethod + def get_pieces_tokens(cls, tokens: str): + return cls._get_pieces_tokens(tokens, "PIECE_START") @classmethod def get_tokens_analytics(cls, tokens: str) -> Dict[str, int]: - """ - Extracts features to aanlyze the given token sequence. - - Parameters - ---------- - - tokens: str - A token sequence. 
- - Returns - ------- - - analytics: Dict[str, int] - The ``analytics`` dict keys are: - - ``total_tokens`` - - ``unique_tokens`` - - ``total_notes`` - - ``unique_notes`` - - ``total_bars`` - - ``total_instruments`` - - ``unique_instruments`` - """ - # Convert str in list of pieces that contain tokens - dataset_tokens = cls._get_pieces_tokens(tokens) - # Start the analysis - note_counts, bar_counts, instr_counts = 0, 0, 0 # total notes and bars (also repeated note values) - total_toks = 0 - unique_tokens, unique_notes, unique_instr = [], [], [] # total non-repeated tokens - unique_genres, unique_composers, unique_periods = [], [], [] - for piece, toks in enumerate(dataset_tokens): - for tok in toks: - total_toks += 1 - if tok not in unique_tokens: - unique_tokens.append(tok) - if "NOTE_ON" in tok: - note_counts += 1 - if "BAR_START" in tok: - bar_counts += 1 - if "INST" in tok: - instr_counts += 1 - if "NOTE_ON" in tok and tok not in unique_notes: - unique_notes.append(tok) - if "INST" in tok and tok not in unique_instr: - unique_instr.append(tok) - if "GENRE" in tok and tok not in unique_genres: - unique_genres.append(tok) - if "PERIOD" in tok and tok not in unique_periods: - unique_periods.append(tok) - if "COMPOSER" in tok and tok not in unique_composers: - unique_composers.append(tok) - - analytics = { - "total_pieces": piece + 1, - "total_tokens": total_toks, - "unique_tokens": len(unique_tokens), - "total_notes": note_counts, - "unique_notes": len(unique_notes), - "total_bars": bar_counts, - "total_instruments": instr_counts, - } - if len(unique_genres) != 0: - analytics.update({"unique_genres": len(unique_genres)}) - if len(unique_periods) != 0: - analytics.update({"unique_periods": len(unique_periods)}) - if len(unique_composers) != 0: - analytics.update({"unique_composers": len(unique_composers)}) - - return analytics + return cls._get_tokens_analytics(tokens, "NOTE_ON", "PIECE_START") diff --git a/musicaiz/tokenizers/remi.py b/musicaiz/tokenizers/remi.py new file mode 100644 index 0000000..62cc084 --- /dev/null +++ b/musicaiz/tokenizers/remi.py @@ -0,0 +1,403 @@ + +from typing import Optional, List, Dict, Union, TextIO +from pathlib import Path +import logging +from dataclasses import dataclass + + +from musicaiz.loaders import Musa +from musicaiz.rhythm.timing import Subdivision, TimeSignature +from musicaiz.structure import Note, Instrument, Bar +from musicaiz.tokenizers import EncodeBase, TokenizerArguments +from musicaiz.rhythm import ( + TimingConsts, + NoteLengths, + ms_per_note, +) + + +# time units available to tokenize +VALID_TIME_UNITS = ["SIXTEENTH", "THIRTY_SECOND", "SIXTY_FOUR", "HUNDRED_TWENTY_EIGHT"] + + +logger = logging.getLogger("remi-tokenizer") +logging.basicConfig(level=logging.INFO) + + +@dataclass +class REMITokenizerArguments(TokenizerArguments): + """ + This is the REMI arguments class. + The default parameters are selected by following the original REMI representation. + + prev_tokens: str + if we want to add tokens after the `PIECE_START` token and before + the 1st TRACK_START token (for conditioning...). + + sub_beat: str + the note length in `VALID_TIME_UNITS` that one bar is divided in. Note that + this refers to the subdivisions of a 4/4 bar, so if we have different time + signatures, the ratio of `sub_beat / bar.denominator` will be maintained + to prevent wrong subdivisions when using bars with different denominators. + + velocity: bool + if we want to add the velocity token. Velocities ranges between 1 and 128 (ints). 
+
+    quantize: bool
+        if we want to quantize the symbolic music data for tokenizing.
+    """
+
+    prev_tokens: str = ""
+    sub_beat: str = "SIXTEENTH"  # 16 in a 4/4 bar
+    num_programs: Optional[List[int]] = None
+    velocity: bool = False
+    quantize: bool = True
+
+
+class REMITokenizer(EncodeBase):
+    """
+    This class presents methods to compute the REMI Encoding.
+    The REMI encoding for piano pieces (mono-track) was introduced in:
+    *Huang, Y. S., & Yang, Y. H. (2020, October).
+    Pop music transformer: Beat-based modeling and generation of expressive pop piano compositions.
+    In Proceedings of the 28th ACM International Conference on Multimedia (pp. 1180-1188).*
+
+    For multi-track pieces, the REMI encoding was adapted by:
+    *Zeng, M., Tan, X., Wang, R., Ju, Z., Qin, T., & Liu, T. Y. (2021).
+    Musicbert: Symbolic music understanding with large-scale pre-training.
+    arXiv preprint arXiv:2106.05630.*
+
+    In this implementation, both mono-track and multi-track are handled.
+
+    This encoding divides an X/4 bar into 16 sub-beats, which means that
+    each quarter or crotchet is divided into 4 sub-beats (16th notes).
+    Nevertheless, to give developers more control over the beat division,
+    that value can be changed to other divisions as a function of the
+    selected note length.
+    The music is quantized by default but, as with the sub-beat tokens, we can
+    choose whether to quantize or not with the `quantize` argument.
+    Note durations are expressed as symbolic lengths, e.g., a duration
+    equal to 1 is a whole note and a duration of 16 is a 16th note.
+
+    This hierarchical tokenization is organized as follows:
+    - Bar -> [BAR] Position from 1/16 to 16/16
+    - Position -> [POS=1/16] [TEMPO=X] [INST=X] [PITCH=X] [DUR=1] [VEL=X] ...
+
+    Note that if a position or sub-beat does not contain notes, it will not be present
+    in the tokenization. This prevents useless or "empty" tokens.
+
+    Attributes
+    ----------
+    file: Optional[Union[str, TextIO, Path]] = None
+    """
+
+    def __init__(
+        self,
+        file: Union[str, TextIO, Path],
+        args: REMITokenizerArguments = None
+    ):
+
+        if args is None:
+            raise ValueError("No `REMITokenizerArguments` passed.")
+        self.args = args
+
+        # Convert file into a Musa object to be processed
+        self.midi_object = Musa(
+            file=file,
+            absolute_timing=False,
+            quantize=self.args.quantize,
+            cut_notes=False
+        )
+
+    def tokenize_file(
+        self,
+    ) -> str:
+        """
+        This method tokenizes a Musa (MIDI) object.
+
+        Returns
+        -------
+
+        all_tokens: List[str]
+            the list of tokens corresponding to all the windows.
+        """
+        # Do not tokenize the tracks that are not in num_programs
+        # but if num_programs is None then tokenize all instruments
+
+        tokenized_instruments = []
+        if self.args.num_programs is not None:
+            for inst in self.midi_object:
+                if inst.program in self.args.num_programs:
+                    tokenized_instruments.append(inst)
+        else:
+            tokenized_instruments = self.midi_object.instruments
+
+        tokens = self.tokenize_bars(
+            tokens=self.args.prev_tokens,
+        )
+        tokens += "\n"
+        return tokens
+
+    def tokenize_bars(
+        self,
+        tokens: Optional[str] = None,
+    ) -> str:
+        """
+        This method tokenizes the bars of the loaded Musa object.
+
+        Parameters
+        ----------
+
+        tokens: str
+            previously generated tokens (e.g. `prev_tokens`) to prepend to the bar tokens.
+
+        Returns
+        -------
+
+        tokens: str
+            the tokens corresponding to the bars.
+ """ + if tokens is None: + tokens = "" + + # check valid time unit + if self.args.sub_beat not in VALID_TIME_UNITS: + raise ValueError(f"Invalid time unit: {self.args.sub_beat}") + + prev_bpm = 0 + for b_idx, bar in enumerate(self.midi_object.bars): + all_notes = self.midi_object.get_notes_in_bar(b_idx) + if len(all_notes) == 0: + continue + tokens += f"BAR={b_idx} " + tokens += f"TIME_SIG={bar.time_sig.time_sig} " + # Get subdivisions in bar with bar index + subdivs = self.midi_object.get_subbeats_in_bar(bar_idx=b_idx) + subdivs_idxs = [i for i in range(len(subdivs))] + for sub in subdivs_idxs: + notes = self.midi_object.get_notes_in_subbeat_bar( + sub, b_idx + ) + if len(notes) == 0: + continue + bpm = subdivs[sub].bpm + tokens += f"SUB_BEAT={sub} " + if bpm != prev_bpm: + tokens += f"TEMPO={int(bpm)} " + prev_bpm = bpm + # Get notes in subdivision with subdivision min and max indexes + prev_prog = notes[0].instrument_prog + for note in notes: + prog = note.instrument_prog + if prog != prev_prog or "PITCH" not in tokens: + tokens += f"INST={note.instrument_prog} " + tokens += f"PITCH={note.pitch} " + note_dur = int( + NoteLengths[note.symbolic].value / NoteLengths[self.args.sub_beat].value + ) + tokens += f"DUR={note_dur} " + tokens += f"VELOCITY={note.velocity} " + prev_prog = prog + prev_bpm = bpm + return tokens.rstrip() + + @staticmethod + def _split_tokens( + piece_tokens: List[str], + token: str, + ) -> List[List[str]]: + """Split tokens list by token""" + indices = [i for i, x in enumerate(piece_tokens) if x.split("=")[0] == token] + lst = [piece_tokens[start:end] for start, end in zip([0, *indices], [*indices, len(piece_tokens)])] + return [el for el in lst if el != []] + + @classmethod + def split_tokens_by_bar( + cls, + piece_tokens: List[str], + ) -> List[List[str]]: + """Split tokens list by bar""" + return cls._split_tokens(piece_tokens, "BAR") + + @classmethod + def split_tokens_by_subbeat( + cls, + piece_tokens: List[str], + ) -> List[List[str]]: + """Split tokens list by subbeat""" + bars = cls.split_tokens_by_bar(piece_tokens) + subbeats = [] + for bar in bars: + subbeats.extend(cls._split_tokens(bar, "SUB_BEAT")) + return subbeats + + # TODO + @classmethod + def tokens_to_musa( + cls, + tokens: Union[str, List[str]], + sub_beat: str = "SIXTY_FOUR", + resolution: int = TimingConsts.RESOLUTION.value, + ) -> Musa: + + """Converts a str valid tokens sequence in Musa objects. 
+ + This representation does not store beats in the Musa object, but + it stores bars, notes, instruments and subbeats.""" + # Initialize midi file to write + midi = Musa(file=None, resolution=resolution) + midi.resolution = resolution + + if isinstance(tokens, str): + tokens = tokens.split(" ") + + midi.subbeats = [] + midi.bars = [] + midi.instruments_progs = [] + + sb_tokens = cls.split_tokens_by_subbeat(tokens) + global_subbeats, pos, total_subbeats = 0, 0, 0 + prev_pos, prev_bar_pos = 0, 0 + for sb_tokens in sb_tokens: + if "BAR" in sb_tokens[0]: + time_sig = sb_tokens[1].split("=")[1] + bar_pos = int(sb_tokens[0].split("=")[1]) + # if bar has incomplete subbeats, we fill them + prev_bar_subbeats = len([s for s in midi.subbeats if s.bar_idx == prev_bar_pos]) + if prev_bar_subbeats != 0: + for ii in range(prev_bar_subbeats, int(TimeSignature(time_sig)._notes_per_bar(sub_beat))): + midi.subbeats.append( + Subdivision( + bpm=tempo_token, + resolution=midi.resolution, + start=(midi.subbeats[-1].global_idx + 1) * sec_subbeat, + end=((midi.subbeats[-1].global_idx + 1) + 1) * sec_subbeat, + time_sig=TimeSignature(time_sig=time_sig), + global_idx=global_subbeats + ii, + bar_idx=prev_bar_pos, + beat_idx=None, + ) + ) + # empty bars + if bar_pos != prev_bar_pos + 1 and bar_pos != 0: + # Generate all subdivisions of the empty bar + for b in range(prev_bar_pos + 1, bar_pos): + total_subbeats += (b - prev_bar_pos - 1) * int(TimeSignature(time_sig)._notes_per_bar(sub_beat)) + for ii in range(0, int(TimeSignature(time_sig)._notes_per_bar(sub_beat))): + midi.subbeats.append( + Subdivision( + bpm=tempo_token, + resolution=midi.resolution, + start=(total_subbeats + ii) * sec_subbeat, + end=((total_subbeats + ii) + 1) * sec_subbeat, + time_sig=TimeSignature(time_sig=time_sig), + global_idx=total_subbeats + ii, + bar_idx=b, + beat_idx=None, + ) + ) + total_subbeats += int(TimeSignature(time_sig)._notes_per_bar(sub_beat)) + global_subbeats = total_subbeats + total_subbeats += int(TimeSignature(time_sig)._notes_per_bar(sub_beat)) + prev_bar_pos = bar_pos + continue + if "SUB_BEAT" in sb_tokens[0]: + if sb_tokens[1].split("=")[0] == "TEMPO": + tempo_token = int(sb_tokens[1].split("=")[1]) + if sb_tokens[2].split("=")[0] == "INST": + inst_token = int(sb_tokens[2].split("=")[1]) + if inst_token not in midi.instruments_progs: + midi.instruments_progs.append(inst_token) + midi.instruments.append( + Instrument( + program=inst_token + ) + ) + pos = int(sb_tokens[0].split("=")[1]) + # Last subbeat + sec_subbeat = ms_per_note( + sub_beat.lower(), + tempo_token, + midi.resolution + ) / 1000 + # Generate beats that have no notes + empty_subbeats = pos - prev_pos - 1 + if empty_subbeats >= 1: + for ii in range(prev_pos, pos): + midi.subbeats.append( + Subdivision( + bpm=tempo_token, + resolution=midi.resolution, + start=(global_subbeats + ii) * sec_subbeat, + end=((global_subbeats + ii) + 1) * sec_subbeat, + time_sig=TimeSignature(time_sig=time_sig), + global_idx=global_subbeats + ii, + bar_idx=bar_pos, + beat_idx=None, + ) + ) + sb = Subdivision( + bpm=tempo_token, + resolution=midi.resolution, + start=(midi.subbeats[-1].global_idx + 1) * sec_subbeat, + end=((midi.subbeats[-1].global_idx + 1) + 1) * sec_subbeat, + time_sig=TimeSignature(time_sig=time_sig), + global_idx=midi.subbeats[-1].global_idx + 1, + bar_idx=bar_pos, + beat_idx=None, + ) + midi.subbeats.append(sb) + prev_pos = pos + 1 + + for j, tok in enumerate(sb_tokens): + if tok.split("=")[0] == "PITCH": + pitch = int(tok.split("=")[1]) + dur = 
int(sb_tokens[j + 1].split("=")[1]) + vel = int(sb_tokens[j + 2].split("=")[1]) + note = Note( + start=sb.start_sec, + end=sb.start_sec + dur * sec_subbeat, + instrument_prog=int(inst_token), + pitch=pitch, + velocity=vel, + subbeat_idx=sb.global_idx, + ) + midi.notes.append(note) + # complete subbeats last bar if they're incomplete + prev_bar_subbeats = len([s for s in midi.subbeats if s.bar_idx == bar_pos]) + if prev_bar_subbeats != 0: + for ii in range(prev_bar_subbeats, int(TimeSignature(time_sig)._notes_per_bar(sub_beat))): + midi.subbeats.append( + Subdivision( + bpm=tempo_token, + resolution=midi.resolution, + start=(midi.subbeats[-1].global_idx + 1) * sec_subbeat, + end=((midi.subbeats[-1].global_idx + 1) + 1) * sec_subbeat, + time_sig=TimeSignature(time_sig=time_sig), + global_idx=midi.subbeats[-1].global_idx + 1, + bar_idx=bar_pos, + beat_idx=None, + ) + ) + # Generate Bars + bar_pos = 0 + last_bar = midi.subbeats[-1].bar_idx + for i in range(0, last_bar + 1): + sub = [s for s in midi.subbeats if s.bar_idx == i][0] + end_sec = [s.end_sec for s in midi.subbeats if s.bar_idx == i][-1] + midi.bars.append( + Bar( + start=sub.start_sec, + end=end_sec, + time_sig=sub.time_sig, + resolution=resolution, + bpm=sub.bpm + ) + ) + return midi + + @classmethod + def get_tokens_analytics(cls, tokens: str) -> Dict[str, int]: + return cls._get_tokens_analytics(tokens, "PITCH", "BAR=0") diff --git a/musicaiz/utils.py b/musicaiz/utils.py index 5aff292..84cd80f 100644 --- a/musicaiz/utils.py +++ b/musicaiz/utils.py @@ -29,7 +29,7 @@ def group_notes_in_subdivisions_bars(musa_obj: Musa, subdiv: str) -> List[List[L """This function groups notes in the selected subdivision. The result is a list which elements are lists that represent the bars, and inside them, lists that represent the notes in each subdivision. - + Parameters ---------- musa_obj: Musa @@ -82,13 +82,12 @@ def get_highest_subdivision_bars_notes( all_subdiv_notes: List[List[List[Note]]] ) -> List[List[Note]]: """Extracts the highest note in each subdivision. - + Parameters ---------- all_subdiv_notes: List[List[List[Note]]] A list of bars in which each element is a subdivision which is a List of notes that are in the subdivision. 
- Returns ------- diff --git a/musicaiz/version.py b/musicaiz/version.py index 2799ff6..e63e266 100644 --- a/musicaiz/version.py +++ b/musicaiz/version.py @@ -1,5 +1,5 @@ MAJOR = 0 -MINOR = 0 -PATCH = 2 +MINOR = 1 +PATCH = 0 __version__ = "%d.%d.%d" % (MAJOR, MINOR, PATCH) \ No newline at end of file diff --git a/tests/fixtures/midis/midi_changes.mid b/tests/fixtures/midis/midi_changes.mid new file mode 100644 index 0000000000000000000000000000000000000000..51526467e6cfc33901c23070214f435840dbb11d GIT binary patch literal 13476 zcmd^``&(4k)yKCJbgYRw&BZt7G7ce7qhUCq1|$fMprC*-h!A5YBQ%gmVoq!%(WXrs zNHsS80_Y%7-yh!ppzUc(df|E6zSpOhzbXB!wa;xBP-|b4_wvK{*=Nq#m$lbzuf6x0 zqsLE=8)G(rb>{a+kH7b}F?-%HUnUOjeEqrZ$um>uC(lh~PESppeW{`0{Q2|Q(_`;U zPPh|O+1KBFr(x0^pSX~n{9tVVWLs0+Sljr-)Y#;ib9IyBZH>;Jy?e};gAXMCC1buk zQB{?Ad}HwQ_x0wd$qVMEn{wt4J?3Z0kIc__{Npb3^W>!YIgkJPg7K2$#^dqdUNFDl z@fSS)yJLPyvtRP~ACCD|@}&6{kN^3YS)#%ckN>sZe8JlGh={D#hc!{e`$jpo;s ze?Dw}n>=TJ%i~|Rn%^ad&F^^p>+@!r=F2=n?8Yyijfj0Ekud)!e$CdoW}(h?>dbU% zw$7M(e%8BAz5J{aH)w0pcSxqz>~quFNt#yn*{5B`zs>A~m=@P* zF-lLLyG?sJJlkS~)P*M3Y2uTT;#_Svro%7|z8TYEoLP$9Os~~7%c_=GCghnKi^TYBZ93 z*pZ06Z>89s3UQAjZgitu&`uy>v9US^RuPB^sbTlHjDMf}%P36=T*^qF*A9h7HH z=AkoohRrGvVV-p_*KM)bR?a-tuQKK#H*>iyK1?MOk)c0~Wn|cQ1tYcpaz@fOihPT໖$dxgSvGZ{S7V(S_TcZA8b zidLX(c41-ZTus^Rvd|%-DOa|r`$}HfuHjSIb%nwO{g&}*M#vHk!mc2FO@67olzdgX zQKs+A!Un5xsl|4ao`(aNnVH=I*O=SF4%gXXoW*80gXOhfZZ?S$zR?nR-4+b#l<#=mvybNk8s9L<`V<%22_93r5$F=4B-KD zc}UhL=(Zlb*?j~}uS2(M)tzF5x-Nx=hz0U>eAC6i3ZKF%9>pEdF8}n$JB;~Mt?(OI zRuV6hmK9gHt(M3lfhGESM?X+fZRi0CWC0d`*b<{aoQvYFTS~(esEmYi^}jDFH*DA=PC&oA)zEOKx;%o;3xfs6ZeDL%Hm;^ zExg57jf0bE@rtlOTJyUrbCE?APm{Jis26nToiLkKn1P0Lplpe?G;)llH0)-Fso{Hm z@r|XCIK@!i`{)`J=1LIT66gmD zbFS0}?f5W}F?y6OA>2qSq$hTm>dpdSv49PZhixS3Ih*3j`e?j6U08%>8D?uq37@+OwN= z;d3#FAETfc!%!Qw@ujFF)fnt{HGIDs_SWu>@=E-h4HsaWiYV`ugZGZUA+ioVpu zEwr^bcj-0iMMkHm%sfp>L6i1{U9f@2^EN}=GXYvsrV4)Gu1503*&lZk%l%a%2a!}Rc zUdj90Wg;Xw^Wnh9uflS2jSYBfRHd6L?6g~{60VqX9#m@WozAqHOglxp$#hb5noLLU zbn0rbj>^GKs?=dJSHsV2b9(p+fHS z{kTC|+#m)nwG>Qy!E_X4>t{aaG(ncmxLMjJO{+BFG|bivdM*aN`h68KZ2Hs?Sh^8v zbTfXdQC){nz&qR>6Jd95spV9PGnlNr_bHb5LElIR=`T*r-Woox~Kz_XoYskMUF8VX@` zS#g23Sq_G!{bq@=yhb5! 
zbox}IbT>MEs+$YO5-uYvuLG6TepS9vTZd!J>|(6F$TpIfFXbpeL+-biwJ+gSrE-h2 zP7VOerSf%XzHqZlLo1|WRZ{AG6&A2kf6HhKrBPuDmDBioO~OtiQs_YGp%d2?nv*b9 zR%q+CRGHUkQz4e9-v1b52~6Zph7pP}aDqnI5O_E7gTg*j?na%}OlCQ`&dkP~*-hf( zTm1fKxTA~Ry5(ddv7CGWRDtzi14x36;2U5Qcn~}UHiIo-D@cKF(*9w}N5G@tF|ZB% z6L=hKPh7%LUEv6CIr#*?p9D{Vr$HL*02xpXo}sTA%4b0>r~}^uJHc~c7uXG+2lXHe z8h`^{Sof3thW3r%+p>EJ+%(_zfEPg%Xa;*h3wQ~%f_DGz}z z&V2Q*AApa*4>$O(V;7L~?rzlX0bTaTqSYhDV;>&o2&2nDw)!FNn zA!Ed5@3XzaBWtS?P9Z;jf4z2Ub?M%1RqRx3BCu1!DbWFJWmOyH{o3}4z3EZAA8HDJ z(6Kw(DSejPXh0Kl3+$AB?r5Xb`F~@dtgPndkeeAY?f#ZVm(|Ev&COvqGi>x4^)~hz z3%zf0YIlp@zRh-uKf2FrEWz@Xk+05Xxu}u#^?$w1a?zygYhbgoj)>jf99+|GH(VLK z!`g}uzmOKSt;}UUceTF$x2~=6@t#Kayku}4ho5fgjF#2zuRM@;My6MMwO9x<^; zOzcTa%#p*D6NZ$wC;0s&cnUlX(qIS3fNJmzebrDt3u-|f_!ig+o&&qUZty&)2U*Yn z9AIN&k9F3Qm^j*|@@)@z5j25juotv|mq07n2iib8*bj2x0O$an^mmZ*5a zdO$De14ltW7yx-N2#$dv+8?JJ1|#5Qa00vnPJ&lK0lWr^;1n1IV}O{LZO~^NF3EcB z5fe*Vpns5%^jK>>;$e??*dreHh=;x8Ti|VQW}`=6rl#KIo2uqUxFn<=%|h(8Z_ zthHY99GC*{)9wSx^WZ|_6Y3u_>penaFZm((KKKFn2oM%WeCaLEQjlf z{z8`UQ@6SUgg?>HWxIK4?@x_oDmW%tx$ToQDbBAL$v^!{l?K8iJsm?ffPRJ=XT+tR zLU6XF*Cd>6ao}H&`w=CmU*oZ0|5Uawzs(*y$K6J7wq^Sn)aj0jt+H0+=sWlx9Bk#d zxR6ULHQrMBxc06@NE<9?l&*I8h7&8^xiw7JOFAp9Sa?u^ zfY``d$Lt-n93^)qBBa65aXVFmK2I8l#zJ-jgcA6dbYA^aCm5FS}s!ZGx0=cwAKwpV&y(^OLnwPwpam(qi7 zniq13C_RVwrtAF-uHb+^N+Ag+g}8G@ClzFENt%zJAQnz8%GRd&K8xOABbo4kaiwH}9eJ5#{B){t0``I$;=^uKIzij4 z4#ly*P5=)@j)ZzAxkL3Oq-jyt@fNs0k^o|9QT*oWzdS9fsL2@69iG@nSy7hx!NGlM zhCb0M-op4ps!=I&Gg3A;L$|1+V3M51{a8Fp%D;Ga7OLgCtc>+#2a=JLQuEwl%B_r0 zi?TS;m3vk)v_#kR4jE@A>E`oH*0;5cWN22to|cQ>B3GY-ZV?^S%n`|su67-$Vu;+l zWGU!F8F(O0cE4P4GMR*mb_s8g0CinT;+Z*}gjhRxcMjEWi! zRudxNDCbRt8wnQI9Ob!bx(V6Aq+-a;i}EAl#hoW90@2g}{>kuQ^3VLU!X;?R^9PEP zBu*D9((-Kzyu<>y=jnEl0|kQcfAtkK$G(u}i+%1Xd{RWK`wE&iUzo1bHkFC;D>LMa zni@~%k_=yvWccC2jFEKr!s_YntUR$l5pWaK;?9s*PDXp7BBNcRt3)pi?5ag3yUk)R zOrMsN^}-q{>;AmsJdnk1F!m%TTZIX^u_@s+h0|m_@ z|GOD7W5^8;hvJghI+ZZ>-zM=_^1`Nk7iRXZNbn};jJPn9Up0qYG~pWd*y!hPt1OvX zN`hCtC78~zxa%EE=B{2%kKU%*k27R$Yq+XCGPM4!;mZcIue!cIKlbKCrhCF2dv9!N cVmveHW}K|^;)~yI$_!1s{;oTIuD-tgZx($s+yDRo literal 0 HcmV?d00001 diff --git a/tests/fixtures/tokenizers/remi_tokens.txt b/tests/fixtures/tokenizers/remi_tokens.txt new file mode 100644 index 0000000..f7efc51 --- /dev/null +++ b/tests/fixtures/tokenizers/remi_tokens.txt @@ -0,0 +1 @@ +BAR=0 TIME_SIG=4/4 SUB_BEAT=4 TEMPO=120 INST=30 PITCH=69 DUR=4 VELOCITY=127 PITCH=64 DUR=8 VELOCITY=127 SUB_BEAT=8 PITCH=67 DUR=4 VELOCITY=127 SUB_BEAT=12 PITCH=64 DUR=4 VELOCITY=127 BAR=2 TIME_SIG=4/4 SUB_BEAT=0 PITCH=72 DUR=4 VELOCITY=127 SUB_BEAT=4 PITCH=69 DUR=4 VELOCITY=127 SUB_BEAT=12 PITCH=67 DUR=2 VELOCITY=127 \ No newline at end of file diff --git a/tests/unit/musicaiz/algorithms/test_chord_prediction.py b/tests/unit/musicaiz/algorithms/test_chord_prediction.py new file mode 100644 index 0000000..2bd3e24 --- /dev/null +++ b/tests/unit/musicaiz/algorithms/test_chord_prediction.py @@ -0,0 +1,19 @@ +import pytest + +from musicaiz.loaders import Musa +from musicaiz.algorithms import predict_chords + + +@pytest.fixture +def midi_sample(fixture_dir): + return fixture_dir / "midis" / "midi_data.mid" + + +def test_predict_chords(midi_sample): + # Import MIDI file + midi = Musa(midi_sample) + + got = predict_chords(midi) + assert len(got) == len(midi.beats) + for i in got: + assert len(i) != 0 diff --git a/tests/unit/musicaiz/converters/test_musa_json.py b/tests/unit/musicaiz/converters/test_musa_json.py new file mode 100644 index 0000000..80a4bc4 --- /dev/null +++ b/tests/unit/musicaiz/converters/test_musa_json.py @@ -0,0 +1,31 @@ +from 
musicaiz.converters import ( + MusaJSON, + BarJSON, + InstrumentJSON, + NoteJSON +) +from musicaiz.loaders import Musa +from .test_musa_to_protobuf import midi_sample + + +def test_MusaJSON(midi_sample): + midi = Musa( + midi_sample, + ) + got = MusaJSON.to_json(midi) + + assert got["tonality"] is None + assert got["resolution"] == 480 + assert len(got["instruments"]) == 2 + assert len(got["bars"]) == 3 + assert len(got["notes"]) == 37 + + for inst in got["instruments"]: + assert set(inst.keys()) == set(InstrumentJSON.__dataclass_fields__.keys()) + for bar in got["bars"]: + assert set(bar.keys()) == set(BarJSON.__dataclass_fields__.keys()) + for note in got["notes"]: + assert set(note.keys()) == set(NoteJSON.__dataclass_fields__.keys()) + +# TODO +#def test_JSONMusa(midi_sample, midi_data): diff --git a/tests/unit/musicaiz/converters/test_musa_to_protobuf.py b/tests/unit/musicaiz/converters/test_musa_to_protobuf.py index f74034b..561b532 100644 --- a/tests/unit/musicaiz/converters/test_musa_to_protobuf.py +++ b/tests/unit/musicaiz/converters/test_musa_to_protobuf.py @@ -41,74 +41,46 @@ def _assert_valid_note_obj(note): assert note.end_sec >= 0.0 -def test_musa_to_proto_instruments(midi_sample, midi_data): - midi = Musa( - midi_sample, - structure="instruments", - ) - got = musa_to_proto(midi) - - _assert_midi_valid_instr_obj(midi_data, got.instruments) - - # check notes - assert len(got.instruments[0].notes) != 0 - assert len(got.instruments[1].notes) != 0 - # check every note attributes are not empty - for instr in got.instruments: - for note in instr.notes: - _assert_valid_note_obj(note) - - -def test_musa_to_proto_bars(midi_sample, midi_data): - midi = Musa( - midi_sample, - structure="bars", - ) +def test_musa_to_proto(midi_sample, midi_data): + midi = Musa(midi_sample) got = musa_to_proto(midi) _assert_midi_valid_instr_obj(midi_data, got.instruments) # check bars - assert len(got.instruments[0].bars) != 0 - assert len(got.instruments[1].bars) != 0 + assert len(got.instruments) != 0 + assert len(got.bars) != 0 # check every bar attributes are not empty - for instr in got.instruments: - for i, bar in enumerate(instr.bars): - # check only the first 5 bars since the midi file is large - if i < 5: - assert bar.start_ticks >= 0 - assert bar.end_ticks >= 0 - assert bar.start_sec >= 0.0 - assert bar.end_sec >= 0.0 - # check every note for each bar - for note in bar.notes: - _assert_valid_note_obj(note) + for i, bar in enumerate(got.bars): + # check only the first 5 bars since the midi file is large + if i < 5: + assert bar.start_ticks >= 0 + assert bar.end_ticks >= 0 + assert bar.start_sec >= 0.0 + assert bar.end_sec >= 0.0 + for note in got.notes: + _assert_valid_note_obj(note) def test_proto_to_musa(midi_sample, midi_data): - midi = Musa( - midi_sample, - structure="bars", - ) + midi = Musa(midi_sample) proto = musa_to_proto(midi) got = proto_to_musa(proto) _assert_midi_valid_instr_obj(midi_data, got.instruments) # check bars - assert len(got.instruments[0].bars) != 0 + assert len(got.instruments) != 0 # check every bar attributes are not empty - for instr in got.instruments: - for i, bar in enumerate(instr.bars): - # check only the first 5 bars since the midi file is large - if i < 5: - assert bar.start_ticks >= 0 - assert bar.end_ticks >= 0 - assert bar.start_sec >= 0.0 - assert bar.end_sec >= 0.0 - # check every note for each bar - for note in bar.notes: - _assert_valid_note_obj(note) - + for i, bar in enumerate(got.bars): + # check only the first 5 bars since the midi file is large + if 
i < 5: + assert bar.start_ticks >= 0 + assert bar.end_ticks >= 0 + assert bar.start_sec >= 0.0 + assert bar.end_sec >= 0.0 + # check every note + for note in got.notes: + _assert_valid_note_obj(note) diff --git a/tests/unit/musicaiz/converters/test_pretty_midi_musa.py b/tests/unit/musicaiz/converters/test_pretty_midi_musa.py new file mode 100644 index 0000000..cc6381f --- /dev/null +++ b/tests/unit/musicaiz/converters/test_pretty_midi_musa.py @@ -0,0 +1,38 @@ +from musicaiz.converters import ( + prettymidi_note_to_musicaiz, + musicaiz_note_to_prettymidi, + musa_to_prettymidi, +) +from musicaiz.loaders import Musa +from .test_musa_to_protobuf import midi_sample + + +def test_prettymidi_note_to_musicaiz(): + note = "G#4" + expected_name = "G_SHARP" + expected_octave = 4 + + got_name, got_octave = prettymidi_note_to_musicaiz(note) + + assert got_name == expected_name + assert got_octave == expected_octave + + +def test_musicaiz_note_to_prettymidi(): + note = "G_SHARP" + octave = 4 + expected = "G#4" + + got = musicaiz_note_to_prettymidi(note, octave) + + assert got == expected + + +def test_musa_to_prettymidi(midi_sample): + midi = Musa(midi_sample) + got = musa_to_prettymidi(midi) + + assert len(got.instruments) == 2 + + for inst in got.instruments: + assert len(inst.notes) != 0 diff --git a/tests/unit/musicaiz/datasets/test_lmd.py b/tests/unit/musicaiz/datasets/test_lmd.py index c9ecdf7..8e4c686 100644 --- a/tests/unit/musicaiz/datasets/test_lmd.py +++ b/tests/unit/musicaiz/datasets/test_lmd.py @@ -20,7 +20,7 @@ def test_LakhMIDI_get_metadata(dataset_path): "split": "train", } } - + dataset = LakhMIDI() got = dataset.get_metadata(dataset_path) diff --git a/tests/unit/musicaiz/datasets/test_maestro.py b/tests/unit/musicaiz/datasets/test_maestro.py index 17c4192..f40576e 100644 --- a/tests/unit/musicaiz/datasets/test_maestro.py +++ b/tests/unit/musicaiz/datasets/test_maestro.py @@ -20,7 +20,7 @@ def test_Maestro_get_metadata(dataset_path): "split": "train", } } - + dataset = Maestro() got = dataset.get_metadata(dataset_path) diff --git a/tests/unit/musicaiz/features/test_harmony.py b/tests/unit/musicaiz/features/test_harmony.py index ca45291..a675084 100644 --- a/tests/unit/musicaiz/features/test_harmony.py +++ b/tests/unit/musicaiz/features/test_harmony.py @@ -8,7 +8,6 @@ predict_progression, _all_note_seq_permutations, _delete_repeated_note_names, - _extract_note_positions, _order_note_seq_by_chromatic_idx, get_harmonic_density, ) @@ -21,7 +20,6 @@ AllChords, IntervalSemitones, DegreesRoman, - Scales, Tonality, ModeConstructors, ) diff --git a/tests/unit/musicaiz/loaders/test_loaders.py b/tests/unit/musicaiz/loaders/test_loaders.py index b3835db..2febfd6 100644 --- a/tests/unit/musicaiz/loaders/test_loaders.py +++ b/tests/unit/musicaiz/loaders/test_loaders.py @@ -6,57 +6,215 @@ @pytest.fixture def midi_sample(fixture_dir): + return fixture_dir / "midis" / "midi_changes.mid" + + +@pytest.fixture +def midi_sample_2(fixture_dir): return fixture_dir / "midis" / "midi_data.mid" -def _assert_key_profiles(midi_sample, methods, expected): +def test_Musa(midi_sample): + + # args + quantize = False + quantize_note = "sixteenth" + cut_notes = False + absolute_timing = False + general_midi = True + subdivision_note = "sixteenth" + midi = Musa( + file=midi_sample, + quantize=quantize, + quantize_note=quantize_note, + cut_notes=cut_notes, + absolute_timing=absolute_timing, + general_midi=general_midi, + subdivision_note=subdivision_note, + ) + + # check attributes + assert midi.file.stem == 
"midi_changes" + assert midi.total_bars != 0 + assert midi.tonality is None + assert midi.subdivision_note == subdivision_note + assert len(midi.time_signature_changes) != 0 + assert midi.resolution != 0 + assert len(midi.instruments) != 0 + assert midi.is_quantized == quantize + assert midi.quantize_note == quantize_note + assert midi.absolute_timing == absolute_timing + assert midi.cut_notes == cut_notes + assert len(midi.notes) != 0 + assert len(midi.bars) != 0 + assert len(midi.subbeats) != 0 + assert len(midi.tempo_changes) != 0 + assert len(midi.instruments_progs) != 0 + + midi.bar_beats_subdivs_analysis() + + # Test methods + notes = midi.get_notes_in_subbeat( + subbeat_idx=0, program=48, instrument_idx=None + ) + assert len([n.subbeat_idx for n in notes if n.subbeat_idx != 0]) == 0 + + notes = midi.get_notes_in_subbeat_bar( + subbeat_idx=0, bar_idx=40, program=48, instrument_idx=None + ) + assert len([n.bar_idx for n in notes if n.bar_idx != 40]) == 0 + + notes = midi.get_notes_in_subbeats( + subbeat_start=0, subbeat_end=4, program=48, instrument_idx=None + ) + assert len([n.subbeat_idx for n in notes if n.subbeat_idx >= 4]) == 0 + + notes = midi.get_notes_in_beat( + beat_idx=0, program=48, instrument_idx=None + ) + assert len([n.beat_idx for n in notes if n.beat_idx != 0]) == 0 + + notes = midi.get_notes_in_subbeat_bar( + subbeat_idx=0, bar_idx=40, program=48, instrument_idx=None + ) + assert len([n.bar_idx for n in notes if n.bar_idx != 40]) == 0 + + subbeats = midi.get_subbeats_in_beat(beat_idx=50) + assert len([n.beat_idx for n in subbeats if n.beat_idx != 50]) == 0 + + subbeat = midi.get_subbeat_in_beat(subbeat_idx=2, beat_idx=50) + assert subbeat.beat_idx == 50 + + notes = midi.get_notes_in_beats( + beat_start=10, beat_end=20, program=48, instrument_idx=None + ) + assert len([n.beat_idx for n in notes if n.beat_idx >= 20 and n.beat_idx < 10]) == 0 + + subbeats = midi.get_subbeats_in_beats( + beat_start=10, beat_end=20 + ) + assert len([n.beat_idx for n in subbeats if n.beat_idx >= 20 and n.beat_idx < 10]) == 0 + + notes = midi.get_notes_in_bar( + bar_idx=12, program=48, instrument_idx=None + ) + assert len([n.bar_idx for n in notes if n.bar_idx != 12]) == 0 + + beats = midi.get_beats_in_bar( + bar_idx=30 + ) + assert len([n.bar_idx for n in beats if n.bar_idx != 30]) == 0 + + beat = midi.get_beat_in_bar( + beat_idx=0, bar_idx=80 + ) + assert beat.bar_idx == 80 + + subbeats = midi.get_subbeats_in_bar(bar_idx=1) + assert len([n.bar_idx for n in subbeats if n.bar_idx != 1]) == 0 + + subbeat = midi.get_subbeat_in_bar( + subbeat_idx=14, bar_idx=1 + ) + assert subbeat.bar_idx == 1 + + notes = midi.get_notes_in_bars( + bar_start=30, bar_end=32 + ) + assert len([n.bar_idx for n in notes if n.bar_idx < 30 and n.bar_idx >= 32]) == 0 + + beats = midi.get_beats_in_bars( + bar_start=30, bar_end=32 + ) + assert len([n.bar_idx for n in beats if n.bar_idx < 30 and n.bar_idx >= 32]) == 0 + + subbeats = midi.get_subbeats_in_bars( + bar_start=30, bar_end=32 + ) + assert len([n.bar_idx for n in subbeats if n.bar_idx < 30 and n.bar_idx >= 32]) == 0 + + # Test when passing more than one instrument program number + notes_is2 = midi.get_notes_in_bar( + bar_idx=0, + program=[48, 45], + instrument_idx=[0, 1] + ) + assert len(notes_is2) != 0 + for n in notes_is2: + assert n.instrument_prog in [48, 45] + assert n.instrument_idx in [0, 1] + + # Test errors + # error when bar_idx does not exist + with pytest.raises(ValueError): + midi.get_notes_in_bar(bar_idx=10000) + + # error when bar_start > bar_end + 
with pytest.raises(ValueError): + midi.get_notes_in_bars(10, 1) + + # error when no program number found + with pytest.raises(ValueError): + midi.get_notes_in_bar( + bar_idx=0, program=100, instrument_idx=None + ) + + # error when program does not match instrument_idx + # instrument_idx=4 corresponds with program 49, error + with pytest.raises(ValueError): + midi.get_notes_in_bar( + bar_idx=0, program=0, instrument_idx=4 + ) + + # error when programs and instruments_idxs have diff len + with pytest.raises(ValueError): + midi.get_notes_in_bar( + bar_idx=0, program=[100, 47], instrument_idx=[0] + ) + + # error when programs do not match instrument_idxs + with pytest.raises(ValueError): + midi.get_notes_in_bar( + bar_idx=2, program=[48, 0], instrument_idx=[1, 2] + ) + + +# Predict key tests +def _assert_key_profiles(midi_sample_2, methods, expected): # try both instruments and bars initializations in Musa - midi_instr = Musa(midi_sample, structure="instruments") - midi_bars = Musa(midi_sample, structure="bars") + midi_instr = Musa(midi_sample_2) for method in methods: got = midi_instr.predict_key(method) assert got == expected - got = midi_bars.predict_key(method) - assert got == expected - -def test_predict_key_kk(midi_sample): +def test_predict_key_kk(midi_sample_2): # Test case: K-K methods = KeyDetectionAlgorithms.KRUMHANSL_KESSLER.value expected = "F_MAJOR" - _assert_key_profiles(midi_sample, methods, expected) + _assert_key_profiles(midi_sample_2, methods, expected) -def test_predict_key_temperley(midi_sample): +def test_predict_key_temperley(midi_sample_2): # Test case: K-K methods = KeyDetectionAlgorithms.TEMPERLEY.value expected = "F_MAJOR" - _assert_key_profiles(midi_sample, methods, expected) + _assert_key_profiles(midi_sample_2, methods, expected) -def test_predict_key_albretch(midi_sample): +def test_predict_key_albretch(midi_sample_2): # Test case: K-K methods = KeyDetectionAlgorithms.ALBRETCH_SHANAHAN.value expected = "F_MAJOR" - _assert_key_profiles(midi_sample, methods, expected) + _assert_key_profiles(midi_sample_2, methods, expected) -def test_predict_key_5ths(midi_sample): +def test_predict_key_5ths(midi_sample_2): # Test case: K-K methods = KeyDetectionAlgorithms.SIGNATURE_FIFTHS.value expected = "A_SHARP_MAJOR" - midi_instr = Musa(midi_sample, structure="instruments") - midi_bars = Musa(midi_sample, structure="bars") - for method in methods: - # Signature 5ths does not work for structure="instruments" - with pytest.raises(ValueError): - midi_instr.predict_key(method) - - got = midi_bars.predict_key(method) - assert got == expected - + _assert_key_profiles(midi_sample_2, methods, expected) diff --git a/tests/unit/musicaiz/plotters/test_pianorolls.py b/tests/unit/musicaiz/plotters/test_pianorolls.py index a128e85..646d694 100644 --- a/tests/unit/musicaiz/plotters/test_pianorolls.py +++ b/tests/unit/musicaiz/plotters/test_pianorolls.py @@ -9,29 +9,64 @@ def midi_sample(fixture_dir): return fixture_dir / "tokenizers" / "mmm_tokens.mid" +@pytest.fixture +def midi_multiinstr(fixture_dir): + return fixture_dir / "midis" / "midi_changes.mid" + def test_Pianoroll_plot_instrument(midi_sample): - plot = Pianoroll() - musa_obj = Musa(midi_sample, structure="bars") - plot.plot_instrument( - track=musa_obj.instruments[0].notes, - total_bars=2, - subdivision="quarter", - time_sig=musa_obj.time_sig.time_sig, - print_measure_data=False, - show_bar_labels=False + # Test case: plot one instrument + musa_obj = Musa(midi_sample) + plot = Pianoroll(musa_obj) + plot.plot_instruments( + 
program=30, + bar_start=0, + bar_end=4, + print_measure_data=True, + show_bar_labels=False, + show_grid=False, + show=False, + ) + plt.close("all") + + +def test_Pianoroll_plot_instruments(midi_multiinstr): + # Test case: plot multiple instruments + musa_obj = Musa(midi_multiinstr) + plot = Pianoroll(musa_obj) + plot.plot_instruments( + program=[48, 45, 74, 49, 49, 42, 25, 48, 21, 46, 0, 15, 72, 44], + bar_start=0, + bar_end=4, + print_measure_data=True, + show_bar_labels=False, + show_grid=True, + show=False, ) - plt.close('all') + plt.close("all") def test_PianorollHTML_plot_instrument(midi_sample): - plot = PianorollHTML() - musa_obj = Musa(midi_sample, structure="bars") - plot.plot_instrument( - track=musa_obj.instruments[0], - bar_start=1, + musa_obj = Musa(midi_sample) + plot = PianorollHTML(musa_obj) + plot.plot_instruments( + program=30, + bar_start=0, bar_end=2, - subdivision="quarter", - time_sig=musa_obj.time_sig.time_sig, + show_grid=False, + show=False + ) + plt.close("all") + + +def test_PianorollHTML_plot_instruments(midi_multiinstr): + musa_obj = Musa(midi_multiinstr) + plot = PianorollHTML(musa_obj) + plot.plot_instruments( + program=[48, 45, 74, 49, 49, 42, 25, 48, 21, 46, 0, 15, 72, 44], + bar_start=0, + bar_end=4, + show_grid=False, show=False ) + plt.close("all") diff --git a/tests/unit/musicaiz/rhythm/test_quantizer.py b/tests/unit/musicaiz/rhythm/test_quantizer.py index 28bc9f5..083b648 100644 --- a/tests/unit/musicaiz/rhythm/test_quantizer.py +++ b/tests/unit/musicaiz/rhythm/test_quantizer.py @@ -101,7 +101,7 @@ def test_basic_quantizer_2(grid_8): def test_advanced_quantizer_1(grid_16): strength = 1 delta_Qr = 12 - type_q= "positive" + type_q = "positive" notes_bar1 = [ Note(pitch=69, start=1, end=24, velocity=127), @@ -152,7 +152,7 @@ def test_advanced_quantizer_2(grid_16): def test_advanced_quantizer_3(grid_16): - + strength = 1 delta_Qr = 12 type_q = None diff --git a/tests/unit/musicaiz/structure/test_notes.py b/tests/unit/musicaiz/structure/test_notes.py index 842d575..5dcbfed 100644 --- a/tests/unit/musicaiz/structure/test_notes.py +++ b/tests/unit/musicaiz/structure/test_notes.py @@ -291,7 +291,7 @@ def test_Note_a(): note_off = 1.0 ligated = True - got_note = Note(pitch, note_on, note_off, velocity, ligated, bpm, resolution) + got_note = Note(pitch, note_on, note_off, velocity, ligated, bpm, resolution=resolution) assert got_note.start_ticks == 0 assert got_note.end_ticks == 1920 @@ -311,7 +311,7 @@ def test_Note_b(): note_off = 20.0 ligated = True - got_note = Note(pitch_name, note_on, note_off, velocity, ligated, bpm, resolution) + got_note = Note(pitch_name, note_on, note_off, velocity, ligated, bpm, resolution=resolution) assert got_note.start_ticks == 19200 assert got_note.end_ticks == 38400 @@ -330,7 +330,7 @@ def test_Note_c(): end_ticks = 20 ligated = True - got_note = Note(pitch_name, start_ticks, end_ticks, velocity, ligated, bpm, resolution) + got_note = Note(pitch_name, start_ticks, end_ticks, velocity, ligated, bpm, resolution=resolution) assert math.isclose(round(got_note.start_sec, 3), 0.005) assert math.isclose(round(got_note.end_sec, 3), 0.01) diff --git a/tests/unit/musicaiz/tokenizers/__init__.py b/tests/unit/musicaiz/tokenizers/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/unit/musicaiz/tokenizers/test_mmm.py b/tests/unit/musicaiz/tokenizers/test_mmm.py index b722e81..759f3c4 100644 --- a/tests/unit/musicaiz/tokenizers/test_mmm.py +++ b/tests/unit/musicaiz/tokenizers/test_mmm.py @@ -30,7 +30,7 @@ def 
midi_sample(fixture_dir): @pytest.fixture def musa_obj_tokens(): # Initialize Musa obj fot the mmm_tokens.txt sequence - musa_obj = Musa() + musa_obj = Musa(file=None) musa_obj.instruments.append( Instrument( program=30, @@ -38,7 +38,7 @@ def musa_obj_tokens(): ) ) for _ in range(0, 3, 1): - musa_obj.instruments[0].bars.append(Bar()) + musa_obj.bars.append(Bar()) return musa_obj @@ -46,25 +46,25 @@ def musa_obj_tokens(): def musa_obj_abs(musa_obj_tokens): # Add notes to the Musa obj with absolute timings notes_bar1 = [ - Note(pitch=69, start=0.5, end=1.0, velocity=127), - Note(pitch=64, start=0.5, end=1.5, velocity=127), - Note(pitch=67, start=1.0, end=1.5, velocity=127), - Note(pitch=64, start=1.5, end=2.0, velocity=127) + Note(pitch=69, start=0.5, end=1.0, velocity=127, bar_idx=0, instrument_prog=30), + Note(pitch=64, start=0.5, end=1.5, velocity=127, bar_idx=0, instrument_prog=30), + Note(pitch=67, start=1.0, end=1.5, velocity=127, bar_idx=0, instrument_prog=30), + Note(pitch=64, start=1.5, end=2.0, velocity=127, bar_idx=0, instrument_prog=30) ] - musa_obj_tokens.instruments[0].bars[0].notes = notes_bar1 - musa_obj_tokens.instruments[0].bars[0].start_ticks = 0 - musa_obj_tokens.instruments[0].bars[0].end_ticks = 96 * 4 + musa_obj_tokens.notes.extend(notes_bar1) + musa_obj_tokens.bars[0].start_ticks = 0 + musa_obj_tokens.bars[0].end_ticks = 96 * 4 # bar2 is empty - musa_obj_tokens.instruments[0].bars[1].start_ticks = 96 * 4 - musa_obj_tokens.instruments[0].bars[1].end_ticks = 96 * 8 + musa_obj_tokens.bars[1].start_ticks = 96 * 4 + musa_obj_tokens.bars[1].end_ticks = 96 * 8 notes_bar3 = [ - Note(pitch=72, start=4.0, end=4.5, velocity=127), - Note(pitch=69, start=4.5, end=5.0, velocity=127), - Note(pitch=67, start=5.5, end=5.75, velocity=127), + Note(pitch=72, start=4.0, end=4.5, velocity=127, bar_idx=2, instrument_prog=30), + Note(pitch=69, start=4.5, end=5.0, velocity=127, bar_idx=2, instrument_prog=30), + Note(pitch=67, start=5.5, end=5.75, velocity=127, bar_idx=2, instrument_prog=30), ] - musa_obj_tokens.instruments[0].bars[2].notes = notes_bar3 - musa_obj_tokens.instruments[0].bars[2].start_ticks = 96 * 8 - musa_obj_tokens.instruments[0].bars[2].end_ticks = 96 * 12 + musa_obj_tokens.notes.extend(notes_bar3) + musa_obj_tokens.bars[2].start_ticks = 96 * 8 + musa_obj_tokens.bars[2].end_ticks = 96 * 12 return musa_obj_tokens @@ -72,40 +72,39 @@ def musa_obj_abs(musa_obj_tokens): def musa_obj_rel(musa_obj_tokens): # Add notes to the Musa obj with relative timings notes_bar1 = [ - Note(pitch=69, start=0.5, end=1.0, velocity=127), - Note(pitch=64, start=0.5, end=1.5, velocity=127), - Note(pitch=67, start=1.0, end=1.5, velocity=127), - Note(pitch=64, start=1.5, end=2.0, velocity=127) + Note(pitch=69, start=0.5, end=1.0, velocity=127, bar_idx=0, instrument_prog=30), + Note(pitch=64, start=0.5, end=1.5, velocity=127, bar_idx=0, instrument_prog=30), + Note(pitch=67, start=1.0, end=1.5, velocity=127, bar_idx=0, instrument_prog=30), + Note(pitch=64, start=1.5, end=2.0, velocity=127, bar_idx=0, instrument_prog=30) ] - musa_obj_tokens.instruments[0].bars[0].notes = notes_bar1 - musa_obj_tokens.instruments[0].bars[0].start_ticks = 0 - musa_obj_tokens.instruments[0].bars[0].end_ticks = 96 * 4 + musa_obj_tokens.notes.extend(notes_bar1) + musa_obj_tokens.bars[0].start_ticks = 0 + musa_obj_tokens.bars[0].end_ticks = 96 * 4 # bar2 is empty - musa_obj_tokens.instruments[0].bars[1].start_ticks = 0 - musa_obj_tokens.instruments[0].bars[1].end_ticks = 96 * 4 + musa_obj_tokens.bars[1].start_ticks = 0 + 
musa_obj_tokens.bars[1].end_ticks = 96 * 4 notes_bar3 = [ - Note(pitch=72, start=0.0, end=0.5, velocity=127), - Note(pitch=69, start=0.5, end=1.0, velocity=127), - Note(pitch=67, start=1.5, end=1.75, velocity=127), + Note(pitch=72, start=0.0, end=0.5, velocity=127, bar_idx=2, instrument_prog=30), + Note(pitch=69, start=0.5, end=1.0, velocity=127, bar_idx=2, instrument_prog=30), + Note(pitch=67, start=1.5, end=1.75, velocity=127, bar_idx=2, instrument_prog=30), ] - musa_obj_tokens.instruments[0].bars[2].notes = notes_bar3 - musa_obj_tokens.instruments[0].bars[2].start_ticks = 0 - musa_obj_tokens.instruments[0].bars[2].end_ticks = 96 * 4 + musa_obj_tokens.notes.extend(notes_bar3) + musa_obj_tokens.bars[2].start_ticks = 0 + musa_obj_tokens.bars[2].end_ticks = 96 * 4 return musa_obj_tokens def _assert_valid_musa_obj(got_musa_obj, expected_musa_obj): assert len(got_musa_obj.instruments) == len(expected_musa_obj.instruments) assert got_musa_obj.instruments[0].program == expected_musa_obj.instruments[0].program - assert len(got_musa_obj.instruments[0].bars) == len(expected_musa_obj.instruments[0].bars) - for b in range(len(expected_musa_obj.instruments[0].bars)): - for n in range(len(expected_musa_obj.instruments[0].bars[b].notes)): - got_note = got_musa_obj.instruments[0].bars[b].notes[n] - expected_note = expected_musa_obj.instruments[0].bars[b].notes[n] - assert expected_note.pitch == got_note.pitch - assert expected_note.start_ticks == got_note.start_ticks - assert expected_note.end_ticks == got_note.end_ticks - assert expected_note.velocity == got_note.velocity + assert len(got_musa_obj.bars) == len(expected_musa_obj.bars) + for n in range(len(expected_musa_obj.notes)): + got_note = got_musa_obj.notes[n] + expected_note = expected_musa_obj.notes[n] + assert expected_note.pitch == got_note.pitch + assert expected_note.start_ticks == got_note.start_ticks + assert expected_note.end_ticks == got_note.end_ticks + assert expected_note.velocity == got_note.velocity def test_MMMTokenizer_split_tokens_by_track(): @@ -216,39 +215,42 @@ def test_MMMTokenizer_split_tokens_by_bar(): assert set(expected[i]) == set(got[i]) -@pytest.mark.skip("Fix this when it's implemented") def test_MMMTokenizer_tokens_to_musa_a(musa_obj_abs, mmm_tokens): # Test case: 1 polyphonic instrument, absolute timings - absolute_timing = True - got = MMMTokenizer.tokens_to_musa(mmm_tokens, absolute_timing) + got = MMMTokenizer.tokens_to_musa( + tokens=mmm_tokens, + absolute_timing=True, + time_unit="SIXTEENTH" + ) expected = musa_obj_abs - _assert_valid_musa_obj(got, expected) -@pytest.mark.skip("Fix this when it's implemented") def test_MMMTokenizer_tokens_to_musa_b(musa_obj_rel, mmm_tokens): # Test case: 1 polyphonic instrument, relative timings - absolute_timing = False - got = MMMTokenizer.tokens_to_musa(mmm_tokens, absolute_timing) + got = MMMTokenizer.tokens_to_musa( + tokens=mmm_tokens, + absolute_timing=False, + time_unit="SIXTEENTH" + ) expected = musa_obj_rel _assert_valid_musa_obj(got, expected) def test_MMMTokenizer_get_pieces_tokens(mmm_multiple_tokens): - got = MMMTokenizer._get_pieces_tokens(mmm_multiple_tokens) + got = MMMTokenizer.get_pieces_tokens(mmm_multiple_tokens) expected_len = 4 assert expected_len == len(got) def test_MMMTokenizer_get_tokens_analytics(mmm_multiple_tokens): got = MMMTokenizer.get_tokens_analytics(mmm_multiple_tokens) - expected_total_tokens = 724 + expected_total_tokens = 725 expected_unique_tokens = 59 expected_total_notes = 188 expected_unique_notes = 23 - expected_total_bars = 56 + 
expected_total_bars = 112 expected_total_instruments = 16 expected_total_pieces = 4 @@ -268,11 +270,17 @@ def test_MMMTokenizer_tokenize_track_bars(musa_obj_abs, mmm_tokens): start_bar = mmm_tokens.index("BAR_START") end_bar = mmm_tokens.index("TRACK_END") expected = mmm_tokens[start_bar:end_bar] - bars = musa_obj_abs.instruments[0].bars args = MMMTokenizerArguments(time_unit="SIXTEENTH") tokenizer = MMMTokenizer(args=args) - got = tokenizer.tokenize_track_bars(bars) + + tokenizer.midi_object.notes = musa_obj_abs.notes + tokenizer.midi_object.bars = musa_obj_abs.bars + tokenizer.midi_object.instruments_progs = [musa_obj_abs.instruments[0].program] + got = tokenizer.tokenize_track_bars( + bars=tokenizer.midi_object.bars, + program=30 + ) assert got == expected @@ -283,6 +291,14 @@ def test_MMMTokenizer_tokenize_tracks(musa_obj_abs, mmm_tokens): args = MMMTokenizerArguments(time_unit="SIXTEENTH") tokenizer = MMMTokenizer(args=args) + + # since we don't pass the file as an argument, we need to + # create the Musa attributes with the musa_obj_abs + # When a file is provided this is done automatically + tokenizer.midi_object.notes = musa_obj_abs.notes + tokenizer.midi_object.bars = musa_obj_abs.bars + tokenizer.midi_object.instruments_progs = [musa_obj_abs.instruments[0].program] + got = tokenizer.tokenize_tracks( instruments=musa_obj_abs.instruments, bar_start=0, diff --git a/tests/unit/musicaiz/tokenizers/test_remi.py b/tests/unit/musicaiz/tokenizers/test_remi.py new file mode 100644 index 0000000..70f682c --- /dev/null +++ b/tests/unit/musicaiz/tokenizers/test_remi.py @@ -0,0 +1,134 @@ +import pytest + + +from musicaiz.tokenizers import ( + REMITokenizer, + REMITokenizerArguments, +) +from .test_mmm import ( + _assert_valid_musa_obj, + musa_obj_tokens, + musa_obj_abs, + midi_sample, +) + + +@pytest.fixture +def remi_tokens(fixture_dir): + tokens_path = fixture_dir / "tokenizers" / "remi_tokens.txt" + text_file = open(tokens_path, "r") + # read whole file to a string + yield text_file.read() + + +def test_REMITokenizer_split_tokens_by_bar(remi_tokens): + tokens = remi_tokens.split(" ") + expected_bar_1 = [ + [ + "BAR=0", + "TIME_SIG=4/4", + "SUB_BEAT=4", + "TEMPO=120", + "INST=30", + "PITCH=69", + "DUR=4", + "VELOCITY=127", + "PITCH=64", + "DUR=8", + "VELOCITY=127", + "SUB_BEAT=8", + "PITCH=67", + "DUR=4", + "VELOCITY=127", + "SUB_BEAT=12", + "PITCH=64", + "DUR=4", + "VELOCITY=127", + ] + ] + got = REMITokenizer.split_tokens_by_bar(tokens) + assert set(expected_bar_1[0]) == set(got[0]) + + +def test_REMITokenizer_split_tokens_by_subbeat(remi_tokens): + tokens = remi_tokens.split(" ") + expected_subbeats_bar_1 = [ + [ + "BAR=0", + "TIME_SIG=4/4", + ], + [ + "SUB_BEAT=4", + "TEMPO=120", + "INST=30", + "PITCH=69", + "DUR=4", + "VELOCITY=127", + "PITCH=64", + "DUR=8", + "VELOCITY=127", + ], + [ + "SUB_BEAT=8", + "PITCH=67", + "DUR=4", + "VELOCITY=127", + ], + [ + "SUB_BEAT=12", + "PITCH=64", + "DUR=4", + "VELOCITY=127", + ] + ] + got = REMITokenizer.split_tokens_by_subbeat(tokens) + for i in range(len(expected_subbeats_bar_1)): + assert set(expected_subbeats_bar_1[i]) == set(got[i]) + + +def test_REMITokenizer_tokens_to_musa_a(remi_tokens, musa_obj_abs): + # Test case: 1 polyphonic instrument, absolute timings + got = REMITokenizer.tokens_to_musa( + tokens=remi_tokens, + sub_beat="SIXTEENTH" + ) + _assert_valid_musa_obj(got, musa_obj_abs) + + +def test_REMITokenizer_get_tokens_analytics(remi_tokens): + got = REMITokenizer.get_tokens_analytics(remi_tokens) + expected_total_tokens = 33 + 
expected_unique_tokens = 16
+    expected_total_notes = 7
+    expected_unique_notes = 4
+    expected_total_bars = 2
+    expected_total_instruments = 1
+    expected_total_pieces = 1
+
+    assert expected_total_pieces == got["total_pieces"]
+    assert expected_total_tokens == got["total_tokens"]
+    assert expected_unique_tokens == got["unique_tokens"]
+    assert expected_total_notes == got["total_notes"]
+    assert expected_unique_notes == got["unique_notes"]
+    assert expected_total_bars == got["total_bars"]
+    assert expected_total_instruments == got["total_instruments"]
+
+
+def test_REMITokenizer_tokenize_bars(midi_sample, remi_tokens):
+
+    expected = remi_tokens
+
+    args = REMITokenizerArguments(sub_beat="SIXTEENTH")
+    tokenizer = REMITokenizer(
+        midi_sample,
+        args=args
+    )
+    got = tokenizer.tokenize_bars()
+    assert got == expected
+
+
+def test_REMITokenizer_tokenize_file(midi_sample):
+    args = REMITokenizerArguments(sub_beat="SIXTEENTH")
+    tokenizer = REMITokenizer(midi_sample, args)
+    got = tokenizer.tokenize_file()
+    assert got != ""

From 04657dca6ad42ba617281ac6f821db45b5cac888 Mon Sep 17 00:00:00 2001
From: carlosholivan
Date: Sun, 4 Dec 2022 11:31:51 +0100
Subject: [PATCH 2/9] update todos

---
 TODOs.md | 17 ++++++++---------
 1 file changed, 8 insertions(+), 9 deletions(-)

diff --git a/TODOs.md b/TODOs.md
index b81b731..1686997 100644
--- a/TODOs.md
+++ b/TODOs.md
@@ -1,27 +1,26 @@
 ## Improvements
 
 ### Converters
-- [ ] Add protobufs.
 - [ ] Add MusicXML
 - [ ] Add ABC notation.
 - [ ] JSON to musicaiz objects.
 
-### Rhythm
-- [ ] Support time signature changes in middle of a piece when loading with ``loaders.Musa`` object.
-- [ ] Support tempo or bpm changes in middle of a piece when loading with ``loaders.Musa`` object.
-
 ### Plotters
 - [ ] Adjust plotters. Plot in secs or ticks and be careful with tick labels in plots that have too much data, numbers can overlap and the plot won't be clean.
-- [ ] Plot all tracks in the same pianoroll.
 
 ### Harmony
-- [ ] Measure just the correct interval (and not all the possible intervals based on the pitch) if note name is known (now it measures all the possible intervals given pitch, but if we do know the note name the interval is just one)
+- [ ] Measure just the correct interval (and not all the possible intervals based on the pitch) if note name is known (now it measures all the possible intervals given pitch, but if we do know the note name the interval is just one).
+- [ ] Support key changes in the middle of a piece when loading with ``loaders.Musa`` object.
+- [ ] Initialize note names correctly if key or tonality is known (now the note name initialization is arbitrary and can be the enharmonic or not).
 
 ### Features
-- [ ] Initialize note names correctly if tonality is known (know the note name initialization is arbitrary, can be the enharmonic or not)
+- [ ] Function to compute: Polyphonic rate
+- [ ] Function to compute: Polyphony
 
 ### Tokenizers
-- [ ] Add more encodings (MusicBERT...)
+- [ ] MusicTransformer
+- [ ] Octuple
+- [ ] Compound Word
 
 ### Synthesis
 - [ ] Add function to synthesize a ``loaders.Musa`` object (can be inherited from ``pretty_midi``).
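
The "Features" items above (polyphonic rate, polyphony) map naturally onto the flat note list that the refactored `Musa` object from PATCH 1/9 exposes. The sketch below is illustrative only and is not part of any patch in this series: `polyphony_per_beat` is a hypothetical helper name, and the only API it assumes is what the new tests already exercise (`Musa(file)`, `Musa.notes`, and the `beat_idx`, `start_ticks` and `end_ticks` note attributes).

```python
# Illustrative sketch (not part of this patch series): a rough per-beat
# polyphony estimate built on the refactored Musa API. Only Musa(file),
# Musa.notes and the Note.beat_idx / start_ticks / end_ticks attributes
# come from the tests in PATCH 1/9; polyphony_per_beat is hypothetical.
from pathlib import Path
from typing import Dict, List

from musicaiz.loaders import Musa


def polyphony_per_beat(midi: Musa) -> Dict[int, int]:
    """Maximum number of notes sounding at the same time in each beat."""
    notes_per_beat: Dict[int, List] = {}
    for note in midi.notes:
        # group the flat note list by the beat index each note belongs to
        notes_per_beat.setdefault(note.beat_idx, []).append(note)

    result: Dict[int, int] = {}
    for beat_idx, notes in notes_per_beat.items():
        max_overlap = 0
        for n in notes:
            # count the notes whose span covers this note's onset (includes n itself)
            overlap = sum(
                1 for m in notes
                if m.start_ticks <= n.start_ticks < m.end_ticks
            )
            max_overlap = max(max_overlap, overlap)
        result[beat_idx] = max_overlap
    return result


if __name__ == "__main__":
    midi = Musa(Path("tests") / "fixtures" / "midis" / "midi_changes.mid")
    print(polyphony_per_beat(midi))
```

A polyphony *rate* could then be derived from the same dictionary, for example as the fraction of beats whose maximum overlap is greater than one.
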
From b3a86acb005a82b524c36ce329492419c86f4caa Mon Sep 17 00:00:00 2001 From: carlosholivan Date: Sun, 4 Dec 2022 11:32:27 +0100 Subject: [PATCH 3/9] remove unusued imports and fix errors --- musicaiz/loaders.py | 654 +------------------------------------------- 1 file changed, 2 insertions(+), 652 deletions(-) diff --git a/musicaiz/loaders.py b/musicaiz/loaders.py index 1ace13b..653163b 100644 --- a/musicaiz/loaders.py +++ b/musicaiz/loaders.py @@ -17,23 +17,12 @@ import pretty_midi as pm from pathlib import Path from enum import Enum -import mido -import functools -import numpy as np -from traitlets import Callable # Our modules from musicaiz.structure import Note, Instrument, Bar -from musicaiz.errors import BarIdxErrorMessage from musicaiz.rhythm import ( TimingConsts, - get_subdivisions, - ticks_per_bar, - _bar_str_to_tuple, - advanced_quantizer, - get_ticks_from_subdivision, - ms_per_tick, ms_per_bar, TimeSignature, Beat, @@ -42,7 +31,6 @@ from musicaiz.converters import musa_to_prettymidi from musicaiz.features import get_harmonic_density from musicaiz.algorithms import key_detection, KeyDetectionAlgorithms -from tests.unit.musicaiz import converters class ValidFiles(Enum): @@ -58,633 +46,6 @@ def all_extensions(cls) -> List[str]: return all -class MusaII: - - """Musanalisys main object. This object loads a file and maps it to the - musicaiz' objects defined in the submodules `harmony` and `structure`. - - Attributes - ---------- - - file: Union[str, TextIO] - The input file. It can be a MIDI file, a MusicXML file (TODO) or and ABC file (TODO). - - structure: str - Organices the attributes at different structure levels which are bar, - instrument or piece level. - Defaults to "piece". - - quantize: bool - Default is True. Quantizes the notes at bar or instrument level with the - `rhythm.advanced_quantizer` method that uses a strength of 100%. - - - tonality: Optional[str] - Initializes the MIDI file and adds the tonality attribute. Knowing the tonality - in advance means that notes are initialized by knowing their name - (ex.: pitch 24 can be C or B#) so this reduces the complexity of the chord and - key prediction algorithms. - - time_sig: str - If we do know the time signature in advance, we can initialize Musa object with it. - This will assume that all the MIDI has the same time signature. - - bpm: int - The tempo or bpm of the MIDI file. If this parameter is not initialized we suppose - 120bpm with a resolution (sequencer ticks) of 960 ticks, which means that we have - 500 ticks per quarter note. - - resolution: int - the pulses o ticks per quarter note (PPQ or TPQN). If this parameter is not initialized - we suppose a resolution (sequencer ticks) of 960 ticks. - - absolute_timing: bool - default is True. This allows to initialize note time arguments in absolute (True) or - relative time units (False). Relative units means that each bar will start at 0 seconds - and ticks, so the note timing attributes will be relative to the bar start equals to 0. 
- - Raises - ------ - - ValueError: [description] - """ - - __slots__ = [ - "file", - "structure", - "tonality", - "time_sig", - "bpm", - "resolution", - "data", - "instruments", - "bars", - "is_quantized", - "notes", - "total_bars", - "absolute_timing", - "cut_notes", - ] - - def __init__( - self, - file: Optional[Union[str, TextIO, Path]] = None, - structure: str = "instruments", - quantize: bool = False, - quantize_note: Optional[str] = "sixteenth", - cut_notes: bool = False, - tonality: Optional[str] = None, - time_sig: str = TimingConsts.DEFAULT_TIME_SIGNATURE.value, - bpm: int = TimingConsts.DEFAULT_BPM.value, - resolution: int = TimingConsts.RESOLUTION.value, - absolute_timing: bool = True, - ): - - self.file = None - self.data = None - self.notes = [] - self.tonality = tonality - self.time_sig = None - self.bpm = bpm - self.resolution = resolution - self.instruments = [] - self.bars = [] - self.structure = structure - self.total_bars = None - self.absolute_timing = absolute_timing - self.is_quantized = quantize - self.cut_notes = cut_notes - - if quantize_note != "eight" and quantize_note != "sixteenth": - raise ValueError("quantize_note must be sixteenth or eight") - - # TODO: What if time_sig, tonality and bpm change inside the piece? - - # File provided - if file is not None: - if isinstance(file, Path): - file = str(file) - if self.is_valid(file): - self.file = file - else: - raise ValueError("Input file extension is not valid.") - - if self.is_midi(file): - # Read bpm from MIDI file with mido - m = mido.MidiFile(self.file) - for msg in m: - if msg.type == "set_tempo": - self.bpm = int(mido.tempo2bpm(msg.tempo)) - elif msg.type == "time_signature": - self.time_sig = TimeSignature( - str(msg.numerator) + "/" + str(msg.denominator) - ) - - # initialize midi object with pretty_midi - pm_inst = pm.PrettyMIDI( - midi_file=self.file, - resolution=self.resolution, - initial_tempo=self.bpm, - ) - # The MIDI file might not have the time signature not tempo information, - # in that case, we initialize them as defalut (120bpm 4/4) - if self.time_sig is None: - self.time_sig = TimeSignature(time_sig) - - # Divide notes into instrument and bars or just into instruments - # depending on th evalue of the input argument `structure` - if self.structure == "instrument_bars": - # Map instruments to load them as musicaiz instrument class - self._load_instruments(pm_inst) - self.notes = [] - self._load_inst_bars() - for instrument in self.instruments: - self._load_bars_notes( - instrument, - absolute_timing=self.absolute_timing, - cut_notes=self.cut_notes, - ) - # TODO: All instr must have the same total_bars, we should get the track with more bars and - # append empty bars to the rest of the tracks - self.total_bars = len(self.instruments[0].bars) - elif self.structure == "bars": - # Map instruments to load them as musicaiz instrument class - self._load_instruments(pm_inst) - self.notes = [] - # Concatenate all the notes of different instruments - # this is for getting the latest note of the piece - # and get the total number of bars of the piece - for instrument in self.instruments: - self.notes.extend(instrument.notes) - self._load_bars() - for instrument in self.instruments: - self._load_bars_notes( - instrument, - absolute_timing=self.absolute_timing, - cut_notes=self.cut_notes, - ) - # TODO: All instr must have the same total_bars, we should get the track with more bars and - # append empty bars to the rest of the tracks - self.total_bars = len(self.instruments[0].bars) - elif self.structure 
== "notes": - # Concatenate all the notes of different instruments - # this is for getting the latest note of the piece - # and get the total number of bars of the piece - for instrument in pm_inst.instruments: - self.notes.extend(instrument.notes) - self.instruments = [] - self.bars = [] - elif self.structure == "instruments": - self._load_instruments(pm_inst) - for instrument in self.instruments: - instrument.bars = None - self.notes.extend(instrument.notes) - self.total_bars = self.get_total_bars(self.notes) - else: - raise ValueError( - f"Structure argument value {structure} is not valid." - ) - - elif self.is_musicxml(file): - # initialize musicxml object with ?? - # TODO: implement musicxml parser - self.data = None - - # Now quantize if is_quantized - if quantize: - grid = get_subdivisions( - total_bars=self.total_bars, - subdivision=quantize_note, - time_sig=self.time_sig.time_sig, - bpm=self.bpm, - resolution=self.resolution, - absolute_timing=self.absolute_timing, - ) - v_grid = get_ticks_from_subdivision(grid) - for instrument in self.instruments: - if self.structure == "bars": - for bar in instrument.bars: - advanced_quantizer(bar.notes, v_grid) - advanced_quantizer(instrument.notes, v_grid) - - # sort the notes in all the midi file - self.notes.sort(key=lambda x: x.start_ticks, reverse=False) - - @classmethod - def is_valid(cls, file: Union[str, TextIO]): - extension = cls.get_file_extension(file) - return True if extension in ValidFiles.all_extensions() else False - - # TODO: How to split if arg is a filepointer? - @staticmethod - def get_file_extension(file: Union[str, TextIO]): - return Path(file).suffix - - @classmethod - def is_midi(cls, file: Union[str, TextIO]): - extension = cls.get_file_extension(file) - return True if extension in ValidFiles.MIDI.value else False - - @classmethod - def is_musicxml(cls, file: Union[str, TextIO]): - extension = cls.get_file_extension(file) - return True if extension in ValidFiles.MUSIC_XML.value else False - - def _load_inst_bars(self): - """Load the bars for an instrument.""" - total_bars = self.get_total_bars(self.notes) - for instrument in self.instruments: - for _ in range(total_bars): - instrument.bars.append( - Bar( - time_sig=self.time_sig.time_sig, - bpm=self.bpm, - ) - ) - - def _load_bars(self): - """Load the bars for an instrument.""" - total_bars = self.get_total_bars(self.notes) - for instrument in pm_inst.instruments: - for _ in range(total_bars): - instrument.bars.append( - Bar( - time_sig=self.time_sig.time_sig, - bpm=self.bpm, - ) - ) - - def _load_bars_notes( - self, - instrument: Instrument, - cut_notes: bool = False, - absolute_timing: bool = True, - ): - start_bar_ticks = 0 - _, bar_ticks = ticks_per_bar(self.time_sig.time_sig, self.resolution) - notes_next_bar = [] - for bar_idx, bar in enumerate(instrument.bars): - for n in notes_next_bar: - bar.notes.append(n) - notes_next_bar = [] - next_start_bar_ticks = start_bar_ticks + bar_ticks - # bar obj attributes - if self.absolute_timing: - bar.start_ticks = start_bar_ticks - bar.end_ticks = next_start_bar_ticks - bar.start_sec = ( - bar.start_ticks * ms_per_tick(self.bpm, self.resolution) / 1000 - ) - else: - bar.start_ticks, bar.start_sec = 0, 0.0 - bar.end_ticks = bar.start_ticks + bar_ticks - bar.end_sec = bar.end_ticks * ms_per_tick(self.bpm, self.resolution) / 1000 - - for i, note in enumerate(instrument.notes): - # TODO: If note ends after the next bar start? 
Fix this, like this we'll loose it - if ( - note.start_ticks >= start_bar_ticks - and note.end_ticks <= next_start_bar_ticks - ): - bar.notes.append(note) - # note starts in current bar but ends in the next (or nexts bars) -> cut note - elif ( - start_bar_ticks <= note.start_ticks <= next_start_bar_ticks - and note.end_ticks >= next_start_bar_ticks - ): - if cut_notes: - # cut note by creating a new note that starts when the next bar starts - note_next_bar = Note( - start=next_start_bar_ticks, - end=note.end_ticks, - pitch=note.pitch, - velocity=note.velocity, - ligated=True, - ) - notes_next_bar.append(note_next_bar) - # cut note by assigning end note to the current end bar - note.end_ticks = next_start_bar_ticks - note.end_secs = ( - next_start_bar_ticks - * ms_per_tick(self.bpm, self.resolution) - / 1000 - ) - note.ligated = True - note.instrument_prog = instrument.program - bar.notes.append(note) - elif note.start_ticks > next_start_bar_ticks: - break - - # sort notes in the bar by their onset - bar.notes.sort(key=lambda x: x.start_ticks, reverse=False) - - # update bar attributes now that we know its notes - bar.note_density = len(bar.notes) - bar.harmonic_density = harmony.get_harmonic_density(bar.notes) - # if absolute_timing is False, we'll write the note time attributes relative - # to their corresponding bar - if not absolute_timing: - bar.relative_notes_timing(bar_start=start_bar_ticks) - start_bar_ticks = next_start_bar_ticks - - def predict_key(self, method: str) -> str: - """ - Predict the key with the key profiles algorithms. - Note that signature fifths algorithm requires to initialize - the Musa class with the argument `structure="bars"` instead - of "instruments". The other algorithms work for both initializations. - - Parameters - ---------- - - method: str - The algorithm we want to use to predict the key. The list of - algorithms can be found here: :func:`~musicaiz.algorithms.KeyDetectionAlgorithms`. - - Raises - ------ - - ValueError - - ValueError - - Returns - ------- - key: str - The predicted key as a string separating tonic, alteration - (if proceeds) and mode with "_". - """ - if method not in KeyDetectionAlgorithms.all_values(): - raise ValueError("Not method found.") - elif method in KeyDetectionAlgorithms.SIGNATURE_FIFTHS.value: - if self.structure == "bars": - notes = [] - for inst in self.instruments: - for b, bar in enumerate(inst.bars): - # Signature fifths only takes the 2 1st bars of the piece - if b == 2: - break - notes.extend(bar.notes) - notes.sort(key=lambda x: x.start_ticks, reverse=False) - key = key_detection(notes, method) - elif self.structure == "instruments": - raise ValueError("Initialize the Musa with `structure=bars`") - elif ( - method in KeyDetectionAlgorithms.KRUMHANSL_KESSLER.value - or KeyDetectionAlgorithms.TEMPERLEY.value - or KeyDetectionAlgorithms.ALBRETCH_SHANAHAN.value - ): - key = key_detection(self.notes, method) - return key - - @staticmethod - def group_instrument_bar_notes(musa_object: Musa) -> List[Bar]: - """Instead of having the structure Instrument -> Bar -> Note, this - method group groups the infrmation as: Bar -> Note.""" - bars = [] - for inst_idx, instrument in enumerate(musa_object.instruments): - for bar_idx, bar in enumerate(instrument.bars): - if inst_idx == 0: - bars.append([]) - bars[bar_idx].extend(bar.notes) - return bars - - # TODO: Move this to `utils.py`? 
- @staticmethod - def _last_note(note_seq: List[Note]) -> Union[Note, None]: - """Get the last note of a sequence.""" - # If there's only 1 note in the sequence, that'll be the - # latest note, so we initialize t to -1 so at least - # 1 note in the note seq. raises the if condition - if len(note_seq) == 0: - return None - t = -1 - for n in note_seq: - if n.end_ticks > t: - last_note = n - t = n.end_ticks - return last_note - - def get_total_bars(self, note_seq: List[Note]): - """ - Calculates the number of bars of a sequence. - """ - last_note = self._last_note(note_seq) - # TODO: Detect number of bars if time signatura changes - subdivisions = get_subdivisions( - total_bars=500, # initialize with a big number of bars - subdivision=self.time_sig.beat_type.upper(), - time_sig=self.time_sig.time_sig, - bpm=self.bpm, - resolution=self.resolution, - ) - # We calculate the number of bars of the piece (supposing same time_sig) - for s in subdivisions: - if last_note.end_ticks > s["ticks"]: - total_bars = s["bar"] - return total_bars - - @staticmethod - def _event_compare(event1, event2): - """Compares two events for sorting. - Events are sorted by tick time ascending. Events with the same tick - time ares sorted by event type. Some events are sorted by - additional values. For example, Note On events are sorted by pitch - then velocity, ensuring that a Note Off (Note On with velocity 0) - will never follow a Note On with the same pitch. - Parameters - ---------- - event1, event2 : mido.Message - Two events to be compared. - """ - # Construct a dictionary which will map event names to numeric - # values which produce the correct sorting. Each dictionary value - # is a function which accepts an event and returns a score. - # The spacing for these scores is 256, which is larger than the - # largest value a MIDI value can take. - secondary_sort = { - "set_tempo": lambda e: (1 * 256 * 256), - "time_signature": lambda e: (2 * 256 * 256), - "key_signature": lambda e: (3 * 256 * 256), - "lyrics": lambda e: (4 * 256 * 256), - "text_events": lambda e: (5 * 256 * 256), - "program_change": lambda e: (6 * 256 * 256), - "pitchwheel": lambda e: ((7 * 256 * 256) + e.pitch), - "control_change": lambda e: ((8 * 256 * 256) + (e.control * 256) + e.value), - "note_off": lambda e: ((9 * 256 * 256) + (e.note * 256)), - "note_on": lambda e: ((10 * 256 * 256) + (e.note * 256) + e.velocity), - "end_of_track": lambda e: (11 * 256 * 256), - } - # If the events have the same tick, and both events have types - # which appear in the secondary_sort dictionary, use the dictionary - # to determine their ordering. - if ( - event1.time == event2.time - and event1.type in secondary_sort - and event2.type in secondary_sort - ): - return secondary_sort[event1.type](event1) - secondary_sort[event2.type]( - event2 - ) - # Otherwise, just return the difference of their ticks. - return event1.time - event2.time - - def write_midi(self, filename: str): - """Writes a Musa object to a MIDI file. - This is adapted from `pretty_midi` library.""" - pass - # TODO: Support tempo, time sig. and key changes - # Initialize output MIDI object - mid = mido.MidiFile(ticks_per_beat=self.resolution) - - # Create track 0 with timing information - timing_track = mido.MidiTrack() - - # Write time sig. 
- num, den = _bar_str_to_tuple(self.time_sig.time_sig) - timing_track.append( - mido.MetaMessage("time_signature", time=0, numerator=num, denominator=den) - ) - # Write BPM - timing_track.append( - mido.MetaMessage( - "set_tempo", - time=0, - # Convert from microseconds per quarter note to BPM - tempo=self.bpm, - ) - ) - # Write key TODO - # timing_track.append( - # mido.MetaMessage("key_signature", time=self.time_to_tick(ks.time), - # key=key_number_to_mido_key_name[ks.key_number])) - - for n, instrument in enumerate(self.instruments): - # Perharps notes are grouped in bars, concatenate them - if len(instrument.notes) == 0: - # TODO: check if note have absolute durations, relative ones won't work - for bar in instrument.bars: - instrument.notes.extend(bar.notes) - # Initialize track for this instrument - track = mido.MidiTrack() - # Add track name event if instrument has a name - if instrument.name: - track.append( - mido.MetaMessage("track_name", time=0, name=instrument.name) - ) - # If it's a drum event, we need to set channel to 9 - if instrument.is_drum: - channel = 9 - # Otherwise, choose a channel from the possible channel list - else: - channel = 8 # channels[n % len(channels)] - # Set the program number - track.append( - mido.Message( - "program_change", - time=0, - program=instrument.program, - channel=channel, - ) - ) - # Add all note events - ligated_notes = [] - for idx, note in enumerate(instrument.notes): - if note.ligated: - ligated_notes.append(note) - for next_note in instrument.notes[idx + 1 :]: - if not next_note.ligated: - continue - ligated_notes.append(next_note) - # Concat all ligated notes into one note by rewriting the 1st ligated note args - note.start_ticks = ligated_notes[0].start_ticks - note.start_sec = ligated_notes[0].start_sec - note.end_ticks = ligated_notes[-1].end_ticks - note.end_sec = ligated_notes[-1].end_sec - ligated_notes = [] - # Construct the note-on event - track.append( - mido.Message( - "note_on", - time=note.start_ticks, - channel=channel, - note=note.pitch, - velocity=note.velocity, - ) - ) - # Also need a note-off event (note on with velocity 0) - track.append( - mido.Message( - "note_on", - time=note.end_ticks, - channel=channel, - note=note.pitch, - velocity=0, - ) - ) - - # Sort all the events using the event_compare comparator. - track = sorted(track, key=functools.cmp_to_key(self._event_compare)) - - # If there's a note off event and a note on event with the same - # tick and pitch, put the note off event first - for n, (event1, event2) in enumerate(zip(track[:-1], track[1:])): - if ( - event1.time == event2.time - and event1.type == "note_on" - and event2.type == "note_on" - and event1.note == event2.note - and event1.velocity != 0 - and event2.velocity == 0 - ): - track[n] = event2 - track[n + 1] = event1 - # Finally, add in an end of track event - track.append(mido.MetaMessage("end_of_track", time=track[-1].time + 1)) - # Add to the list of output tracks - mid.tracks.append(track) - # Turn ticks to relative time from absolute - for track in mid.tracks: - tick = 0 - for event in track: - event.time -= tick - tick += event.time - mid.save(filename + ".mid") - - def fluidsynth(self, fs=44100, sf2_path=None): - """Synthesize using fluidsynth. - Parameters - ---------- - fs : int - Sampling rate to synthesize at. - sf2_path : str - Path to a .sf2 file. - Default ``None``, which uses the TimGM6mb.sf2 file included with - ``pretty_midi``. - Returns - ------- - synthesized : np.ndarray - Waveform of the MIDI data, synthesized at ``fs``. 
- """ - # If there are no instruments, or all instruments have no notes, return - # an empty array - if len(self.instruments) == 0 or all( - len(i.notes) == 0 for i in self.instruments - ): - return np.array([]) - # Get synthesized waveform for each instrument - waveforms = [i.fluidsynth(fs=fs, sf2_path=sf2_path) for i in self.instruments] - # Allocate output waveform, with #sample = max length of all waveforms - synthesized = np.zeros(np.max([w.shape[0] for w in waveforms])) - # Sum all waveforms in - for waveform in waveforms: - synthesized[: waveform.shape[0]] += waveform - # Normalize - synthesized /= np.abs(synthesized).max() - return synthesized - - - class Musa: __slots__ = [ @@ -745,7 +106,7 @@ def __init__( # TODO: relative times in notes? # TODO: cut_notes # TODO: assign notes their name when key is known - # TODO: key signature changes, + # TODO: key signature changes # TODO: write_midi -> with pretty_midi # TODO: synthesize -> with pretty_midi @@ -764,8 +125,6 @@ def __init__( self.bars = [] self.tempo_changes = [] - # TODO unify quantize_note as subdivision_note? - if subdivision_note not in self.VALID_SUBDIVISIONS: raise ValueError( "{subdivision_note} is not valid subdivision_note. " \ @@ -783,14 +142,8 @@ def __init__( if self.is_midi(file): self._load_midifile(resolution, tonality) - - # group subdivisions in beats - - - # group subdivisions in bars - def json(self): - return {key : getattr(self, key, None) for key in self.__slots__} + return {key: getattr(self, key, None) for key in self.__slots__} def _load_midifile( self, @@ -862,9 +215,6 @@ def _load_midifile( {"tempo": self.tempo_changes[-1]["tempo"], "ms": last_note_end * 1000} ) - # Load Bars - #self._load_bars(last_note_end) - # Load beats self._load_beats(last_note_end) From 7dcbcc8a35d819a78eea87cb038a1c3ca0e0f1f9 Mon Sep 17 00:00:00 2001 From: carlosholivan Date: Mon, 5 Dec 2022 10:18:24 +0100 Subject: [PATCH 4/9] change master to main branch --- .github/workflows/pypi_release.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/pypi_release.yml b/.github/workflows/pypi_release.yml index 264ff15..4af2d05 100644 --- a/.github/workflows/pypi_release.yml +++ b/.github/workflows/pypi_release.yml @@ -24,7 +24,7 @@ jobs: python -m build twine check --strict dist/* - name: Publish distribution to PyPI - uses: pypa/gh-action-pypi-publish@master + uses: pypa/gh-action-pypi-publish@main with: user: __token__ password: ${{ secrets.PYPI_API_TOKEN }} From 64359c978837d07e025340d52f2964e1286c1526 Mon Sep 17 00:00:00 2001 From: carlosholivan Date: Mon, 5 Dec 2022 10:20:13 +0100 Subject: [PATCH 5/9] refactor quantizer configs as dataclass --- musicaiz/converters/musa_protobuf.py | 4 +- musicaiz/converters/protobuf/musicaiz.proto | 4 +- musicaiz/converters/protobuf/musicaiz_pb2.py | 72 +++++++++---------- musicaiz/converters/protobuf/musicaiz_pb2.pyi | 10 +-- musicaiz/loaders.py | 70 ++++++++++++------ musicaiz/rhythm/__init__.py | 3 + musicaiz/rhythm/quantizer.py | 69 ++++++++++-------- musicaiz/tokenizers/mmm.py | 2 - .../converters/test_musa_to_protobuf.py | 1 + tests/unit/musicaiz/loaders/test_loaders.py | 5 +- tests/unit/musicaiz/rhythm/test_quantizer.py | 44 +++++++----- 11 files changed, 164 insertions(+), 120 deletions(-) diff --git a/musicaiz/converters/musa_protobuf.py b/musicaiz/converters/musa_protobuf.py index deb66c0..9495109 100644 --- a/musicaiz/converters/musa_protobuf.py +++ b/musicaiz/converters/musa_protobuf.py @@ -49,8 +49,8 @@ def musa_to_proto(musa_obj): 
proto_is_quantized = proto.is_quantized.add() proto_is_quantized = musa_obj.is_quantized - proto_quantize_note = proto.quantize_note.add() - proto_quantize_note = musa_obj.quantize_note + proto_quantizer_args = proto.quantizer_args.add() + proto_quantizer_args = musa_obj.quantizer_args proto_absolute_timing = proto.absolute_timing.add() proto_absolute_timing = musa_obj.absolute_timing diff --git a/musicaiz/converters/protobuf/musicaiz.proto b/musicaiz/converters/protobuf/musicaiz.proto index abe637b..6d9ba29 100644 --- a/musicaiz/converters/protobuf/musicaiz.proto +++ b/musicaiz/converters/protobuf/musicaiz.proto @@ -14,7 +14,7 @@ message Musa { repeated Tonality tonality = 9; repeated Resolution resolution = 10; repeated IsQuantized is_quantized = 11; - repeated QuantizeNote quantize_note = 12; + repeated QuantizerArgs quantizer_args = 12; repeated AbsoluteTiming absolute_timing = 13; repeated CutNotes cut_notes = 14; repeated TempoChanges tempo_changes = 15; @@ -32,7 +32,7 @@ message Musa { message Tonality {} message Resolution {} message IsQuantized {} - message QuantizeNote {} + message QuantizerArgs {} message AbsoluteTiming {} message CutNotes {} message TempoChanges {} diff --git a/musicaiz/converters/protobuf/musicaiz_pb2.py b/musicaiz/converters/protobuf/musicaiz_pb2.py index 241f4ca..3155b88 100644 --- a/musicaiz/converters/protobuf/musicaiz_pb2.py +++ b/musicaiz/converters/protobuf/musicaiz_pb2.py @@ -13,7 +13,7 @@ -DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n+musicaiz/converters/protobuf/musicaiz.proto\x12\x08musicaiz\"\xa5\x10\n\x04Musa\x12\x43\n\x16time_signature_changes\x18\x05 \x03(\x0b\x32#.musicaiz.Musa.TimeSignatureChanges\x12\x38\n\x10subdivision_note\x18\x06 \x03(\x0b\x32\x1e.musicaiz.Musa.SubdivisionNote\x12!\n\x04\x66ile\x18\x07 \x03(\x0b\x32\x13.musicaiz.Musa.File\x12,\n\ntotal_bars\x18\x08 \x03(\x0b\x32\x18.musicaiz.Musa.TotalBars\x12)\n\x08tonality\x18\t \x03(\x0b\x32\x17.musicaiz.Musa.Tonality\x12-\n\nresolution\x18\n \x03(\x0b\x32\x19.musicaiz.Musa.Resolution\x12\x30\n\x0cis_quantized\x18\x0b \x03(\x0b\x32\x1a.musicaiz.Musa.IsQuantized\x12\x32\n\rquantize_note\x18\x0c \x03(\x0b\x32\x1b.musicaiz.Musa.QuantizeNote\x12\x36\n\x0f\x61\x62solute_timing\x18\r \x03(\x0b\x32\x1d.musicaiz.Musa.AbsoluteTiming\x12*\n\tcut_notes\x18\x0e \x03(\x0b\x32\x17.musicaiz.Musa.CutNotes\x12\x32\n\rtempo_changes\x18\x0f \x03(\x0b\x32\x1b.musicaiz.Musa.TempoChanges\x12:\n\x11instruments_progs\x18\x10 \x03(\x0b\x32\x1f.musicaiz.Musa.InstrumentsProgs\x12.\n\x0binstruments\x18\x11 \x03(\x0b\x32\x19.musicaiz.Musa.Instrument\x12 \n\x04\x62\x61rs\x18\x12 \x03(\x0b\x32\x12.musicaiz.Musa.Bar\x12\"\n\x05notes\x18\x13 \x03(\x0b\x32\x13.musicaiz.Musa.Note\x12\"\n\x05\x62\x65\x61ts\x18\x14 \x03(\x0b\x32\x13.musicaiz.Musa.Beat\x12(\n\x08subbeats\x18\x15 \x03(\x0b\x32\x16.musicaiz.Musa.Subbeat\x1a\x16\n\x14TimeSignatureChanges\x1a\x11\n\x0fSubdivisionNote\x1a\x06\n\x04\x46ile\x1a\x0b\n\tTotalBars\x1a\n\n\x08Tonality\x1a\x0c\n\nResolution\x1a\r\n\x0bIsQuantized\x1a\x0e\n\x0cQuantizeNote\x1a\x10\n\x0e\x41\x62soluteTiming\x1a\n\n\x08\x43utNotes\x1a\x0e\n\x0cTempoChanges\x1a\x12\n\x10InstrumentsProgs\x1a`\n\nInstrument\x12\x12\n\ninstrument\x18\x01 \x01(\x05\x12\x0f\n\x07program\x18\x02 \x01(\x05\x12\x0c\n\x04name\x18\x03 \x01(\t\x12\x0e\n\x06\x66\x61mily\x18\x04 \x01(\t\x12\x0f\n\x07is_drum\x18\x05 \x01(\x08\x1a\xb6\x02\n\x04Note\x12\r\n\x05pitch\x18\x01 \x01(\x05\x12\x12\n\npitch_name\x18\x02 \x01(\t\x12\x11\n\tnote_name\x18\x03 \x01(\t\x12\x0e\n\x06octave\x18\x04 
\x01(\t\x12\x0f\n\x07ligated\x18\x05 \x01(\x08\x12\x13\n\x0bstart_ticks\x18\x06 \x01(\x05\x12\x11\n\tend_ticks\x18\x07 \x01(\x05\x12\x11\n\tstart_sec\x18\x08 \x01(\x02\x12\x0f\n\x07\x65nd_sec\x18\t \x01(\x02\x12\x10\n\x08symbolic\x18\n \x01(\t\x12\x10\n\x08velocity\x18\x0b \x01(\x05\x12\x0f\n\x07\x62\x61r_idx\x18\x0c \x01(\x05\x12\x10\n\x08\x62\x65\x61t_idx\x18\r \x01(\x05\x12\x13\n\x0bsubbeat_idx\x18\x0e \x01(\x05\x12\x16\n\x0einstrument_idx\x18\x0f \x01(\x05\x12\x17\n\x0finstrument_prog\x18\x10 \x01(\x05\x1a\xcd\x01\n\x03\x42\x61r\x12\x0b\n\x03\x62pm\x18\x01 \x01(\x02\x12\x10\n\x08time_sig\x18\x02 \x01(\t\x12\x12\n\nresolution\x18\x03 \x01(\x05\x12\x17\n\x0f\x61\x62solute_timing\x18\x04 \x01(\x08\x12\x14\n\x0cnote_density\x18\x05 \x01(\x05\x12\x18\n\x10harmonic_density\x18\x06 \x01(\x05\x12\x13\n\x0bstart_ticks\x18\x07 \x01(\x05\x12\x11\n\tend_ticks\x18\x08 \x01(\x05\x12\x11\n\tstart_sec\x18\t \x01(\x02\x12\x0f\n\x07\x65nd_sec\x18\n \x01(\x02\x1a\xc3\x01\n\x04\x42\x65\x61t\x12\x0b\n\x03\x62pm\x18\x01 \x01(\x02\x12\x10\n\x08time_sig\x18\x02 \x01(\t\x12\x12\n\nresolution\x18\x03 \x01(\x05\x12\x17\n\x0f\x61\x62solute_timing\x18\x04 \x01(\x08\x12\x13\n\x0bstart_ticks\x18\x07 \x01(\x05\x12\x11\n\tend_ticks\x18\x08 \x01(\x05\x12\x11\n\tstart_sec\x18\t \x01(\x02\x12\x0f\n\x07\x65nd_sec\x18\n \x01(\x02\x12\x12\n\nglobal_idx\x18\x0b \x01(\x05\x12\x0f\n\x07\x62\x61r_idx\x18\x0c \x01(\x05\x1a\x88\x02\n\x07Subbeat\x12\x0b\n\x03\x62pm\x18\x01 \x01(\x02\x12\x10\n\x08time_sig\x18\x02 \x01(\t\x12\x12\n\nresolution\x18\x03 \x01(\x05\x12\x17\n\x0f\x61\x62solute_timing\x18\x04 \x01(\x08\x12\x14\n\x0cnote_density\x18\x05 \x01(\x05\x12\x18\n\x10harmonic_density\x18\x06 \x01(\x05\x12\x13\n\x0bstart_ticks\x18\x07 \x01(\x05\x12\x11\n\tend_ticks\x18\x08 \x01(\x05\x12\x11\n\tstart_sec\x18\t \x01(\x02\x12\x0f\n\x07\x65nd_sec\x18\n \x01(\x02\x12\x12\n\nglobal_idx\x18\x0b \x01(\x05\x12\x0f\n\x07\x62\x61r_idx\x18\x0c \x01(\x05\x12\x10\n\x08\x62\x65\x61t_idx\x18\r \x01(\x05\x62\x06proto3') +DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n+musicaiz/converters/protobuf/musicaiz.proto\x12\x08musicaiz\"\xa8\x10\n\x04Musa\x12\x43\n\x16time_signature_changes\x18\x05 \x03(\x0b\x32#.musicaiz.Musa.TimeSignatureChanges\x12\x38\n\x10subdivision_note\x18\x06 \x03(\x0b\x32\x1e.musicaiz.Musa.SubdivisionNote\x12!\n\x04\x66ile\x18\x07 \x03(\x0b\x32\x13.musicaiz.Musa.File\x12,\n\ntotal_bars\x18\x08 \x03(\x0b\x32\x18.musicaiz.Musa.TotalBars\x12)\n\x08tonality\x18\t \x03(\x0b\x32\x17.musicaiz.Musa.Tonality\x12-\n\nresolution\x18\n \x03(\x0b\x32\x19.musicaiz.Musa.Resolution\x12\x30\n\x0cis_quantized\x18\x0b \x03(\x0b\x32\x1a.musicaiz.Musa.IsQuantized\x12\x34\n\x0equantizer_args\x18\x0c \x03(\x0b\x32\x1c.musicaiz.Musa.QuantizerArgs\x12\x36\n\x0f\x61\x62solute_timing\x18\r \x03(\x0b\x32\x1d.musicaiz.Musa.AbsoluteTiming\x12*\n\tcut_notes\x18\x0e \x03(\x0b\x32\x17.musicaiz.Musa.CutNotes\x12\x32\n\rtempo_changes\x18\x0f \x03(\x0b\x32\x1b.musicaiz.Musa.TempoChanges\x12:\n\x11instruments_progs\x18\x10 \x03(\x0b\x32\x1f.musicaiz.Musa.InstrumentsProgs\x12.\n\x0binstruments\x18\x11 \x03(\x0b\x32\x19.musicaiz.Musa.Instrument\x12 \n\x04\x62\x61rs\x18\x12 \x03(\x0b\x32\x12.musicaiz.Musa.Bar\x12\"\n\x05notes\x18\x13 \x03(\x0b\x32\x13.musicaiz.Musa.Note\x12\"\n\x05\x62\x65\x61ts\x18\x14 \x03(\x0b\x32\x13.musicaiz.Musa.Beat\x12(\n\x08subbeats\x18\x15 
\x03(\x0b\x32\x16.musicaiz.Musa.Subbeat\x1a\x16\n\x14TimeSignatureChanges\x1a\x11\n\x0fSubdivisionNote\x1a\x06\n\x04\x46ile\x1a\x0b\n\tTotalBars\x1a\n\n\x08Tonality\x1a\x0c\n\nResolution\x1a\r\n\x0bIsQuantized\x1a\x0f\n\rQuantizerArgs\x1a\x10\n\x0e\x41\x62soluteTiming\x1a\n\n\x08\x43utNotes\x1a\x0e\n\x0cTempoChanges\x1a\x12\n\x10InstrumentsProgs\x1a`\n\nInstrument\x12\x12\n\ninstrument\x18\x01 \x01(\x05\x12\x0f\n\x07program\x18\x02 \x01(\x05\x12\x0c\n\x04name\x18\x03 \x01(\t\x12\x0e\n\x06\x66\x61mily\x18\x04 \x01(\t\x12\x0f\n\x07is_drum\x18\x05 \x01(\x08\x1a\xb6\x02\n\x04Note\x12\r\n\x05pitch\x18\x01 \x01(\x05\x12\x12\n\npitch_name\x18\x02 \x01(\t\x12\x11\n\tnote_name\x18\x03 \x01(\t\x12\x0e\n\x06octave\x18\x04 \x01(\t\x12\x0f\n\x07ligated\x18\x05 \x01(\x08\x12\x13\n\x0bstart_ticks\x18\x06 \x01(\x05\x12\x11\n\tend_ticks\x18\x07 \x01(\x05\x12\x11\n\tstart_sec\x18\x08 \x01(\x02\x12\x0f\n\x07\x65nd_sec\x18\t \x01(\x02\x12\x10\n\x08symbolic\x18\n \x01(\t\x12\x10\n\x08velocity\x18\x0b \x01(\x05\x12\x0f\n\x07\x62\x61r_idx\x18\x0c \x01(\x05\x12\x10\n\x08\x62\x65\x61t_idx\x18\r \x01(\x05\x12\x13\n\x0bsubbeat_idx\x18\x0e \x01(\x05\x12\x16\n\x0einstrument_idx\x18\x0f \x01(\x05\x12\x17\n\x0finstrument_prog\x18\x10 \x01(\x05\x1a\xcd\x01\n\x03\x42\x61r\x12\x0b\n\x03\x62pm\x18\x01 \x01(\x02\x12\x10\n\x08time_sig\x18\x02 \x01(\t\x12\x12\n\nresolution\x18\x03 \x01(\x05\x12\x17\n\x0f\x61\x62solute_timing\x18\x04 \x01(\x08\x12\x14\n\x0cnote_density\x18\x05 \x01(\x05\x12\x18\n\x10harmonic_density\x18\x06 \x01(\x05\x12\x13\n\x0bstart_ticks\x18\x07 \x01(\x05\x12\x11\n\tend_ticks\x18\x08 \x01(\x05\x12\x11\n\tstart_sec\x18\t \x01(\x02\x12\x0f\n\x07\x65nd_sec\x18\n \x01(\x02\x1a\xc3\x01\n\x04\x42\x65\x61t\x12\x0b\n\x03\x62pm\x18\x01 \x01(\x02\x12\x10\n\x08time_sig\x18\x02 \x01(\t\x12\x12\n\nresolution\x18\x03 \x01(\x05\x12\x17\n\x0f\x61\x62solute_timing\x18\x04 \x01(\x08\x12\x13\n\x0bstart_ticks\x18\x07 \x01(\x05\x12\x11\n\tend_ticks\x18\x08 \x01(\x05\x12\x11\n\tstart_sec\x18\t \x01(\x02\x12\x0f\n\x07\x65nd_sec\x18\n \x01(\x02\x12\x12\n\nglobal_idx\x18\x0b \x01(\x05\x12\x0f\n\x07\x62\x61r_idx\x18\x0c \x01(\x05\x1a\x88\x02\n\x07Subbeat\x12\x0b\n\x03\x62pm\x18\x01 \x01(\x02\x12\x10\n\x08time_sig\x18\x02 \x01(\t\x12\x12\n\nresolution\x18\x03 \x01(\x05\x12\x17\n\x0f\x61\x62solute_timing\x18\x04 \x01(\x08\x12\x14\n\x0cnote_density\x18\x05 \x01(\x05\x12\x18\n\x10harmonic_density\x18\x06 \x01(\x05\x12\x13\n\x0bstart_ticks\x18\x07 \x01(\x05\x12\x11\n\tend_ticks\x18\x08 \x01(\x05\x12\x11\n\tstart_sec\x18\t \x01(\x02\x12\x0f\n\x07\x65nd_sec\x18\n \x01(\x02\x12\x12\n\nglobal_idx\x18\x0b \x01(\x05\x12\x0f\n\x07\x62\x61r_idx\x18\x0c \x01(\x05\x12\x10\n\x08\x62\x65\x61t_idx\x18\r \x01(\x05\x62\x06proto3') _builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, globals()) _builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'musicaiz.converters.protobuf.musicaiz_pb2', globals()) @@ -21,39 +21,39 @@ DESCRIPTOR._options = None _MUSA._serialized_start=58 - _MUSA._serialized_end=2143 - _MUSA_TIMESIGNATURECHANGES._serialized_start=874 - _MUSA_TIMESIGNATURECHANGES._serialized_end=896 - _MUSA_SUBDIVISIONNOTE._serialized_start=898 - _MUSA_SUBDIVISIONNOTE._serialized_end=915 - _MUSA_FILE._serialized_start=917 - _MUSA_FILE._serialized_end=923 - _MUSA_TOTALBARS._serialized_start=925 - _MUSA_TOTALBARS._serialized_end=936 - _MUSA_TONALITY._serialized_start=938 - _MUSA_TONALITY._serialized_end=948 - _MUSA_RESOLUTION._serialized_start=950 - _MUSA_RESOLUTION._serialized_end=962 - _MUSA_ISQUANTIZED._serialized_start=964 - 
_MUSA_ISQUANTIZED._serialized_end=977 - _MUSA_QUANTIZENOTE._serialized_start=979 - _MUSA_QUANTIZENOTE._serialized_end=993 - _MUSA_ABSOLUTETIMING._serialized_start=995 - _MUSA_ABSOLUTETIMING._serialized_end=1011 - _MUSA_CUTNOTES._serialized_start=1013 - _MUSA_CUTNOTES._serialized_end=1023 - _MUSA_TEMPOCHANGES._serialized_start=1025 - _MUSA_TEMPOCHANGES._serialized_end=1039 - _MUSA_INSTRUMENTSPROGS._serialized_start=1041 - _MUSA_INSTRUMENTSPROGS._serialized_end=1059 - _MUSA_INSTRUMENT._serialized_start=1061 - _MUSA_INSTRUMENT._serialized_end=1157 - _MUSA_NOTE._serialized_start=1160 - _MUSA_NOTE._serialized_end=1470 - _MUSA_BAR._serialized_start=1473 - _MUSA_BAR._serialized_end=1678 - _MUSA_BEAT._serialized_start=1681 - _MUSA_BEAT._serialized_end=1876 - _MUSA_SUBBEAT._serialized_start=1879 - _MUSA_SUBBEAT._serialized_end=2143 + _MUSA._serialized_end=2146 + _MUSA_TIMESIGNATURECHANGES._serialized_start=876 + _MUSA_TIMESIGNATURECHANGES._serialized_end=898 + _MUSA_SUBDIVISIONNOTE._serialized_start=900 + _MUSA_SUBDIVISIONNOTE._serialized_end=917 + _MUSA_FILE._serialized_start=919 + _MUSA_FILE._serialized_end=925 + _MUSA_TOTALBARS._serialized_start=927 + _MUSA_TOTALBARS._serialized_end=938 + _MUSA_TONALITY._serialized_start=940 + _MUSA_TONALITY._serialized_end=950 + _MUSA_RESOLUTION._serialized_start=952 + _MUSA_RESOLUTION._serialized_end=964 + _MUSA_ISQUANTIZED._serialized_start=966 + _MUSA_ISQUANTIZED._serialized_end=979 + _MUSA_QUANTIZERARGS._serialized_start=981 + _MUSA_QUANTIZERARGS._serialized_end=996 + _MUSA_ABSOLUTETIMING._serialized_start=998 + _MUSA_ABSOLUTETIMING._serialized_end=1014 + _MUSA_CUTNOTES._serialized_start=1016 + _MUSA_CUTNOTES._serialized_end=1026 + _MUSA_TEMPOCHANGES._serialized_start=1028 + _MUSA_TEMPOCHANGES._serialized_end=1042 + _MUSA_INSTRUMENTSPROGS._serialized_start=1044 + _MUSA_INSTRUMENTSPROGS._serialized_end=1062 + _MUSA_INSTRUMENT._serialized_start=1064 + _MUSA_INSTRUMENT._serialized_end=1160 + _MUSA_NOTE._serialized_start=1163 + _MUSA_NOTE._serialized_end=1473 + _MUSA_BAR._serialized_start=1476 + _MUSA_BAR._serialized_end=1681 + _MUSA_BEAT._serialized_start=1684 + _MUSA_BEAT._serialized_end=1879 + _MUSA_SUBBEAT._serialized_start=1882 + _MUSA_SUBBEAT._serialized_end=2146 # @@protoc_insertion_point(module_scope) diff --git a/musicaiz/converters/protobuf/musicaiz_pb2.pyi b/musicaiz/converters/protobuf/musicaiz_pb2.pyi index 75f8657..f03972c 100644 --- a/musicaiz/converters/protobuf/musicaiz_pb2.pyi +++ b/musicaiz/converters/protobuf/musicaiz_pb2.pyi @@ -6,7 +6,7 @@ from typing import ClassVar as _ClassVar, Iterable as _Iterable, Mapping as _Map DESCRIPTOR: _descriptor.FileDescriptor class Musa(_message.Message): - __slots__ = ["absolute_timing", "bars", "beats", "cut_notes", "file", "instruments", "instruments_progs", "is_quantized", "notes", "quantize_note", "resolution", "subbeats", "subdivision_note", "tempo_changes", "time_signature_changes", "tonality", "total_bars"] + __slots__ = ["absolute_timing", "bars", "beats", "cut_notes", "file", "instruments", "instruments_progs", "is_quantized", "notes", "quantizer_args", "resolution", "subbeats", "subdivision_note", "tempo_changes", "time_signature_changes", "tonality", "total_bars"] class AbsoluteTiming(_message.Message): __slots__ = [] def __init__(self) -> None: ... 
@@ -116,7 +116,7 @@ class Musa(_message.Message): symbolic: str velocity: int def __init__(self, pitch: _Optional[int] = ..., pitch_name: _Optional[str] = ..., note_name: _Optional[str] = ..., octave: _Optional[str] = ..., ligated: bool = ..., start_ticks: _Optional[int] = ..., end_ticks: _Optional[int] = ..., start_sec: _Optional[float] = ..., end_sec: _Optional[float] = ..., symbolic: _Optional[str] = ..., velocity: _Optional[int] = ..., bar_idx: _Optional[int] = ..., beat_idx: _Optional[int] = ..., subbeat_idx: _Optional[int] = ..., instrument_idx: _Optional[int] = ..., instrument_prog: _Optional[int] = ...) -> None: ... - class QuantizeNote(_message.Message): + class QuantizerArgs(_message.Message): __slots__ = [] def __init__(self) -> None: ... class Resolution(_message.Message): @@ -175,7 +175,7 @@ class Musa(_message.Message): INSTRUMENTS_PROGS_FIELD_NUMBER: _ClassVar[int] IS_QUANTIZED_FIELD_NUMBER: _ClassVar[int] NOTES_FIELD_NUMBER: _ClassVar[int] - QUANTIZE_NOTE_FIELD_NUMBER: _ClassVar[int] + QUANTIZER_ARGS_FIELD_NUMBER: _ClassVar[int] RESOLUTION_FIELD_NUMBER: _ClassVar[int] SUBBEATS_FIELD_NUMBER: _ClassVar[int] SUBDIVISION_NOTE_FIELD_NUMBER: _ClassVar[int] @@ -192,7 +192,7 @@ class Musa(_message.Message): instruments_progs: _containers.RepeatedCompositeFieldContainer[Musa.InstrumentsProgs] is_quantized: _containers.RepeatedCompositeFieldContainer[Musa.IsQuantized] notes: _containers.RepeatedCompositeFieldContainer[Musa.Note] - quantize_note: _containers.RepeatedCompositeFieldContainer[Musa.QuantizeNote] + quantizer_args: _containers.RepeatedCompositeFieldContainer[Musa.QuantizerArgs] resolution: _containers.RepeatedCompositeFieldContainer[Musa.Resolution] subbeats: _containers.RepeatedCompositeFieldContainer[Musa.Subbeat] subdivision_note: _containers.RepeatedCompositeFieldContainer[Musa.SubdivisionNote] @@ -200,4 +200,4 @@ class Musa(_message.Message): time_signature_changes: _containers.RepeatedCompositeFieldContainer[Musa.TimeSignatureChanges] tonality: _containers.RepeatedCompositeFieldContainer[Musa.Tonality] total_bars: _containers.RepeatedCompositeFieldContainer[Musa.TotalBars] - def __init__(self, time_signature_changes: _Optional[_Iterable[_Union[Musa.TimeSignatureChanges, _Mapping]]] = ..., subdivision_note: _Optional[_Iterable[_Union[Musa.SubdivisionNote, _Mapping]]] = ..., file: _Optional[_Iterable[_Union[Musa.File, _Mapping]]] = ..., total_bars: _Optional[_Iterable[_Union[Musa.TotalBars, _Mapping]]] = ..., tonality: _Optional[_Iterable[_Union[Musa.Tonality, _Mapping]]] = ..., resolution: _Optional[_Iterable[_Union[Musa.Resolution, _Mapping]]] = ..., is_quantized: _Optional[_Iterable[_Union[Musa.IsQuantized, _Mapping]]] = ..., quantize_note: _Optional[_Iterable[_Union[Musa.QuantizeNote, _Mapping]]] = ..., absolute_timing: _Optional[_Iterable[_Union[Musa.AbsoluteTiming, _Mapping]]] = ..., cut_notes: _Optional[_Iterable[_Union[Musa.CutNotes, _Mapping]]] = ..., tempo_changes: _Optional[_Iterable[_Union[Musa.TempoChanges, _Mapping]]] = ..., instruments_progs: _Optional[_Iterable[_Union[Musa.InstrumentsProgs, _Mapping]]] = ..., instruments: _Optional[_Iterable[_Union[Musa.Instrument, _Mapping]]] = ..., bars: _Optional[_Iterable[_Union[Musa.Bar, _Mapping]]] = ..., notes: _Optional[_Iterable[_Union[Musa.Note, _Mapping]]] = ..., beats: _Optional[_Iterable[_Union[Musa.Beat, _Mapping]]] = ..., subbeats: _Optional[_Iterable[_Union[Musa.Subbeat, _Mapping]]] = ...) -> None: ... 
+ def __init__(self, time_signature_changes: _Optional[_Iterable[_Union[Musa.TimeSignatureChanges, _Mapping]]] = ..., subdivision_note: _Optional[_Iterable[_Union[Musa.SubdivisionNote, _Mapping]]] = ..., file: _Optional[_Iterable[_Union[Musa.File, _Mapping]]] = ..., total_bars: _Optional[_Iterable[_Union[Musa.TotalBars, _Mapping]]] = ..., tonality: _Optional[_Iterable[_Union[Musa.Tonality, _Mapping]]] = ..., resolution: _Optional[_Iterable[_Union[Musa.Resolution, _Mapping]]] = ..., is_quantized: _Optional[_Iterable[_Union[Musa.IsQuantized, _Mapping]]] = ..., quantizer_args: _Optional[_Iterable[_Union[Musa.QuantizerArgs, _Mapping]]] = ..., absolute_timing: _Optional[_Iterable[_Union[Musa.AbsoluteTiming, _Mapping]]] = ..., cut_notes: _Optional[_Iterable[_Union[Musa.CutNotes, _Mapping]]] = ..., tempo_changes: _Optional[_Iterable[_Union[Musa.TempoChanges, _Mapping]]] = ..., instruments_progs: _Optional[_Iterable[_Union[Musa.InstrumentsProgs, _Mapping]]] = ..., instruments: _Optional[_Iterable[_Union[Musa.Instrument, _Mapping]]] = ..., bars: _Optional[_Iterable[_Union[Musa.Bar, _Mapping]]] = ..., notes: _Optional[_Iterable[_Union[Musa.Note, _Mapping]]] = ..., beats: _Optional[_Iterable[_Union[Musa.Beat, _Mapping]]] = ..., subbeats: _Optional[_Iterable[_Union[Musa.Subbeat, _Mapping]]] = ...) -> None: ... diff --git a/musicaiz/loaders.py b/musicaiz/loaders.py index 653163b..f12e076 100644 --- a/musicaiz/loaders.py +++ b/musicaiz/loaders.py @@ -13,7 +13,7 @@ """ from __future__ import annotations -from typing import TextIO, Union, List, Optional +from typing import TextIO, Union, List, Optional, Type import pretty_midi as pm from pathlib import Path from enum import Enum @@ -27,6 +27,8 @@ TimeSignature, Beat, Subdivision, + advanced_quantizer, + QuantizerConfig, ) from musicaiz.converters import musa_to_prettymidi from musicaiz.features import get_harmonic_density @@ -62,11 +64,11 @@ class Musa: "bars", "tempo_changes", "instruments_progs", - "quantize_note", "general_midi", "subdivision_note", "subbeats", "beats", + "quantizer_args", ] # subdivision_note and quantize_note @@ -83,13 +85,13 @@ def __init__( self, file: Optional[Union[str, TextIO, Path]], quantize: bool = False, - quantize_note: Optional[str] = "sixteenth", cut_notes: bool = False, tonality: Optional[str] = None, resolution: Optional[int] = None, absolute_timing: bool = True, general_midi: bool = False, - subdivision_note: str = "sixteenth" + subdivision_note: str = "sixteenth", + quantizer_args: Type[QuantizerConfig] = QuantizerConfig, ): """ @@ -116,7 +118,6 @@ def __init__( self.general_midi = general_midi self.absolute_timing = absolute_timing self.is_quantized = quantize - self.quantize_note = quantize_note self.subdivision_note = subdivision_note self.subbeats = [] self.beats = [] @@ -124,6 +125,7 @@ def __init__( self.time_signature_changes = [] self.bars = [] self.tempo_changes = [] + self.quantizer_args = quantizer_args if subdivision_note not in self.VALID_SUBDIVISIONS: raise ValueError( @@ -256,20 +258,45 @@ def _load_midifile( assert len([sub for sub in self.subbeats if sub.bar_idx is None]) == 0 assert len([sub for sub in self.subbeats if sub.beat_idx is None]) == 0 - """ - # if quantize - if quantize: - grid = get_subdivisions( - total_bars=self.total_bars, - subdivision=quantize_note, - time_sig=self.time_sig.time_sig, - bpm=self.bpm, - resolution=self.resolution, - absolute_timing=self.absolute_timing, + + # Add subbeats to last bar if it's incomplete + if self.subbeats[-1].end_ticks < self.bars[-1].end_ticks: + 
subbeats_last_bar = self.get_subbeats_in_bar(len(self.bars) - 1) + # subbeats in a complete bar + subbeats_total = int( + self.bars[-1].time_sig._notes_per_bar(self.subdivision_note.upper()) ) - v_grid = get_ticks_from_subdivision(grid) - advanced_quantizer(self.notes, v_grid) - """ + if len(subbeats_last_bar) < subbeats_total: + for _ in range(subbeats_total - len(subbeats_last_bar)): + dur = subbeats_last_bar[0].end_sec - subbeats_last_bar[0].start_sec + subbeat = Subdivision( + time_sig=subbeats_last_bar[0].time_sig, + start=self.subbeats[-1].end_sec, + end=self.subbeats[-1].end_sec + dur, + bpm=subbeats_last_bar[0].bpm, + resolution=self.subbeats[-1].resolution, + ) + subbeat.global_idx = len(self.subbeats) + subbeat.bar_idx = len(self.bars) - 1 + # TODO: Group last subbeats in the correct beats + # (this is not that important since these beats do not contain notes) + subbeat.beat_idx = len(self.beats) - 1 + subbeat.bar_idx = len(self.bars) - 1 + self.subbeats.append(subbeat) + + # if quantize + if self.is_quantized: + quantized_notes = [] + for i, bar in enumerate(self.bars): + v_grid = [sb.start_ticks for sb in self.get_subbeats_in_bar(i)] + # TODO: Recalculate subbeat_idx, beat_idx and bar_idx of the notes + notes = self.get_notes_in_bar(i) + advanced_quantizer( + notes, v_grid, config=self.quantizer_args, + bpm=bar.bpm, resolution=self.resolution + ) + quantized_notes.extend(notes) + self.notes = quantized_notes @classmethod def is_valid(cls, file: Union[str, Path]): @@ -874,13 +901,14 @@ def _load_bars_and_group_beats_in_bars(self): self.bars.append(bar) # Now add as many beats as needed to complete the last bar if len(beats_last_bar) < bar.time_sig.num: + beats = self.get_beats_in_bar(len(self.bars) - 1) for _ in range(bar.time_sig.num - len(beats_last_bar)): - dur = self.beats[-1].end_sec - self.beats[-1].start_sec + dur = beats[0].end_sec - beats[0].start_sec beat = Beat( - time_sig=self.beats[-1].time_sig, + time_sig=beats[0].time_sig, start=self.beats[-1].end_sec, end=self.beats[-1].end_sec + dur, - bpm=self.beats[-1].bpm, + bpm=beats[0].bpm, resolution=self.beats[-1].resolution, ) beat.global_idx = len(self.beats) diff --git a/musicaiz/rhythm/__init__.py b/musicaiz/rhythm/__init__.py index fb0166f..1ce886a 100644 --- a/musicaiz/rhythm/__init__.py +++ b/musicaiz/rhythm/__init__.py @@ -43,6 +43,7 @@ .. autosummary:: :toctree: generated/ + QuantizerConfig basic_quantizer advanced_quantizer get_ticks_from_subdivision @@ -68,6 +69,7 @@ ) from .quantizer import ( + QuantizerConfig, basic_quantizer, advanced_quantizer, get_ticks_from_subdivision, @@ -83,6 +85,7 @@ "ms_per_note", "ms_per_bar", "get_subdivisions", + "QuantizerConfig", "basic_quantizer", "advanced_quantizer", "get_ticks_from_subdivision", diff --git a/musicaiz/rhythm/quantizer.py b/musicaiz/rhythm/quantizer.py index 030563f..fa9665d 100644 --- a/musicaiz/rhythm/quantizer.py +++ b/musicaiz/rhythm/quantizer.py @@ -1,13 +1,38 @@ from typing import List, Dict, Union, Optional import numpy as np -from enum import Enum +from dataclasses import dataclass from musicaiz.rhythm import TimingConsts, ms_per_tick -class QuantizerConfig(Enum): - DELTA_QR = 12 - STRENGTH = 1 # 100% +@dataclass +class QuantizerConfig: + """ + Basic quantizer arguments. + + Parameters + ---------- + + note: str + The note length of the grid. + + strength: quantization strength, a parameter between 0 and 1.
+ Example: GRID = [0, 24, 48], START_TICKS = [3, 21, 40] and Aq = strength give + START_NEW_TICKS = [(3-0)*strength, (21-24)*strength, (40-48)*strength] + END_NEW_TICKS = [] + + delta_qr: quantization range (Q_r) in ticks + + type_q: type of quantization. + If negative, only differences between start_tick and grid > Q_r are + taken into account for the quantization. If positive, only differences + between start_tick and grid < Q_r are taken into account for the quantization. + If None, every start_tick is quantized based on the strength (it works like basic + quantization but adds the strength parameter) + """ + delta_qr: int = 12 + strength: float = 1.0 # 100% + type_q: Optional[str] = None def _find_nearest( @@ -58,10 +83,9 @@ def basic_quantizer( def advanced_quantizer( input_notes, grid: List[int], - strength: float = QuantizerConfig.STRENGTH.value, - delta_Qr: int = QuantizerConfig.DELTA_QR.value, - type_q: Optional[str] = None, - bpm: int = TimingConsts.DEFAULT_BPM.value + config: QuantizerConfig, + bpm: int = TimingConsts.DEFAULT_BPM.value, + resolution: int = TimingConsts.RESOLUTION.value, ): """ This function quantizes a musa object given a grid. @@ -72,26 +96,9 @@ file: musa object grid: array of ints in ticks - - strength: parameter between 0 and 1. - Example GRID = [0 24 48], STAR_TICKS = [3 ,21, 40] and Aq - START_NEW_TICS = [(3-0)*strength, (21-24)*strength, (40-48)*strength] - END_NEW_TICKS = [] - - delta_Qr: Q_range in ticks - - type_q: type of quantization - if negative: only differences between start_tick and grid > Q_r is - taking into account for the quantization. If positive only differences - between start_tick and grid < Q_r is taking into accounto for the quantization. - If none all start_tick is quantized based on the strength (it works similar to basic - quantization but adding the strength parameter) - - Returns - ------- """ - Aq = strength + Aq = config.strength for i in range(len(input_notes)): @@ -101,7 +108,7 @@ delta_tick = start_tick - start_tick_quantized delta_tick_q = int(delta_tick * Aq) - if type_q == "negative" and (abs(delta_tick) > delta_Qr): + if config.type_q == "negative" and (abs(delta_tick) > config.delta_qr): if delta_tick > 0: input_notes[i].start_ticks = start_tick - delta_tick_q input_notes[i].end_ticks = end_tick - delta_tick_q @@ -110,7 +117,7 @@ input_notes[i].start_ticks = start_tick + abs(delta_tick_q) input_notes[i].end_ticks = end_tick + abs(delta_tick_q) - elif type_q == "positive" and (abs(delta_tick) < delta_Qr): + elif config.type_q == "positive" and (abs(delta_tick) < config.delta_qr): if delta_tick > 0: input_notes[i].start_ticks = input_notes[i].start_ticks - delta_tick_q input_notes[i].end_ticks = input_notes[i].end_ticks - delta_tick_q @@ -119,7 +126,7 @@ input_notes[i].start_ticks = input_notes[i].start_ticks + abs(delta_tick_q) input_notes[i].end_ticks = input_notes[i].end_ticks + abs(delta_tick_q) - elif type_q is None: + elif config.type_q is None: if delta_tick > 0: input_notes[i].start_ticks = input_notes[i].start_ticks - delta_tick_q input_notes[i].end_ticks = input_notes[i].end_ticks - delta_tick_q @@ -128,5 +135,5 @@ input_notes[i].start_ticks = input_notes[i].start_ticks + abs(delta_tick_q) input_notes[i].end_ticks = input_notes[i].end_ticks + abs(delta_tick_q) - input_notes[i].start_sec = input_notes[i].start_ticks * ms_per_tick(bpm) / 1000 - input_notes[i].end_sec = input_notes[i].end_ticks * ms_per_tick(bpm) / 1000 +
input_notes[i].start_sec = input_notes[i].start_ticks * ms_per_tick(bpm, resolution) / 1000 + input_notes[i].end_sec = input_notes[i].end_ticks * ms_per_tick(bpm, resolution) / 1000 diff --git a/musicaiz/tokenizers/mmm.py b/musicaiz/tokenizers/mmm.py index 4166a6b..f1397a0 100644 --- a/musicaiz/tokenizers/mmm.py +++ b/musicaiz/tokenizers/mmm.py @@ -1,11 +1,9 @@ from typing import Optional, List, Dict, Union, TextIO from pathlib import Path -import argparse import logging from dataclasses import dataclass - from musicaiz.loaders import Musa from musicaiz.structure import Note, Instrument, Bar from musicaiz.tokenizers import EncodeBase, TokenizerArguments diff --git a/tests/unit/musicaiz/converters/test_musa_to_protobuf.py b/tests/unit/musicaiz/converters/test_musa_to_protobuf.py index 561b532..0ff946f 100644 --- a/tests/unit/musicaiz/converters/test_musa_to_protobuf.py +++ b/tests/unit/musicaiz/converters/test_musa_to_protobuf.py @@ -17,6 +17,7 @@ def midi_data(): "expected_instrument_name_2": "Piano left", } + def _assert_midi_valid_instr_obj(midi_data, instruments): # check instrs assert midi_data["expected_instruments"] == len(instruments) diff --git a/tests/unit/musicaiz/loaders/test_loaders.py b/tests/unit/musicaiz/loaders/test_loaders.py index 2febfd6..d6264db 100644 --- a/tests/unit/musicaiz/loaders/test_loaders.py +++ b/tests/unit/musicaiz/loaders/test_loaders.py @@ -17,8 +17,7 @@ def midi_sample_2(fixture_dir): def test_Musa(midi_sample): # args - quantize = False - quantize_note = "sixteenth" + quantize = True cut_notes = False absolute_timing = False general_midi = True @@ -26,7 +25,6 @@ def test_Musa(midi_sample): midi = Musa( file=midi_sample, quantize=quantize, - quantize_note=quantize_note, cut_notes=cut_notes, absolute_timing=absolute_timing, general_midi=general_midi, @@ -42,7 +40,6 @@ def test_Musa(midi_sample): assert midi.resolution != 0 assert len(midi.instruments) != 0 assert midi.is_quantized == quantize - assert midi.quantize_note == quantize_note assert midi.absolute_timing == absolute_timing assert midi.cut_notes == cut_notes assert len(midi.notes) != 0 diff --git a/tests/unit/musicaiz/rhythm/test_quantizer.py b/tests/unit/musicaiz/rhythm/test_quantizer.py index 083b648..19cfc69 100644 --- a/tests/unit/musicaiz/rhythm/test_quantizer.py +++ b/tests/unit/musicaiz/rhythm/test_quantizer.py @@ -4,6 +4,7 @@ from musicaiz import rhythm from musicaiz.structure import Note from musicaiz.rhythm.quantizer import ( + QuantizerConfig, basic_quantizer, get_ticks_from_subdivision, advanced_quantizer, @@ -99,9 +100,12 @@ def test_basic_quantizer_2(grid_8): def test_advanced_quantizer_1(grid_16): - strength = 1 - delta_Qr = 12 - type_q = "positive" + + config = QuantizerConfig( + strength=1, + delta_qr=12, + type_q="positive", + ) notes_bar1 = [ Note(pitch=69, start=1, end=24, velocity=127), @@ -110,7 +114,7 @@ def test_advanced_quantizer_1(grid_16): Note(pitch=64, start=13, end=18, velocity=127), ] - advanced_quantizer(notes_bar1, grid_16, strength, delta_Qr, type_q) + advanced_quantizer(notes_bar1, grid_16, config, 120, 96) expected = [ Note(pitch=69, start=0, end=23, velocity=127), @@ -126,9 +130,11 @@ def test_advanced_quantizer_1(grid_16): def test_advanced_quantizer_2(grid_16): - strength = 1 - delta_Qr = 12 - type_q = None + config = QuantizerConfig( + strength=1, + delta_qr=12, + type_q=None, + ) notes_bar1 = [ Note(pitch=69, start=1, end=24, velocity=127), @@ -137,7 +143,7 @@ def test_advanced_quantizer_2(grid_16): Note(pitch=64, start=13, end=18, velocity=127), ] - 
advanced_quantizer(notes_bar1, grid_16, strength, delta_Qr, type_q) + advanced_quantizer(notes_bar1, grid_16, config, 120, 96) expected = [ Note(pitch=69, start=0, end=23, velocity=127), @@ -153,9 +159,11 @@ def test_advanced_quantizer_2(grid_16): def test_advanced_quantizer_3(grid_16): - strength = 1 - delta_Qr = 12 - type_q = None + config = QuantizerConfig( + strength=1, + delta_qr=12, + type_q=None, + ) notes_bar1 = [ # I don't know why but it changes when assigning to an object Note(pitch=69, start=1, end=24, velocity=127), @@ -164,7 +172,7 @@ def test_advanced_quantizer_3(grid_16): Note(pitch=64, start=13, end=18, velocity=127), ] - advanced_quantizer(notes_bar1, grid_16, strength, delta_Qr, type_q) + advanced_quantizer(notes_bar1, grid_16, config, 120, 96) expected = [ Note(pitch=69, start=0, end=23, velocity=127), @@ -178,11 +186,13 @@ def test_advanced_quantizer_3(grid_16): assert notes_bar1[i].end_ticks == expected[i].end_ticks -def test_advanced_quantizer_3(grid_16): +def test_advanced_quantizer_4(grid_16): - strength = 0.75 - delta_Qr = 12 - type_q= None + config = QuantizerConfig( + strength=0.75, + delta_qr=12, + type_q=None, + ) notes_bar1 = [ Note(pitch=69, start=1, end=24, velocity=127), @@ -191,7 +201,7 @@ def test_advanced_quantizer_3(grid_16): Note(pitch=64, start=30, end=50, velocity=127), ] - advanced_quantizer(notes_bar1, grid_16, strength, delta_Qr, type_q) + advanced_quantizer(notes_bar1, grid_16, config, 120, 96) expected = [ Note(pitch=69, start=1, end=24, velocity=127), From 9a78619c114ec89afde898b766aa868435e2ee60 Mon Sep 17 00:00:00 2001 From: carlosholivan Date: Mon, 5 Dec 2022 10:23:19 +0100 Subject: [PATCH 6/9] update readme --- README.md | 34 +++++++++++++++++----------------- 1 file changed, 17 insertions(+), 17 deletions(-) diff --git a/README.md b/README.md index 25bb7b3..dad21e1 100644 --- a/README.md +++ b/README.md @@ -157,26 +157,26 @@ from musicaiz.structure import Chords, Tonality from musicaiz.plotters import Pianoroll, PianorollHTML # Matplotlib - plot = Pianoroll() - musa_obj = Musa(midi_sample, structure="bars") - plot.plot_instrument( - track=musa_obj.instruments[0].notes, - total_bars=2, - subdivision="quarter", - time_sig=musa_obj.time_sig.time_sig, - print_measure_data=False, - show_bar_labels=False + musa_obj = Musa(midi_sample) + plot = Pianoroll(musa_obj) + plot.plot_instruments( + program=[48, 45], + bar_start=0, + bar_end=4, + print_measure_data=True, + show_bar_labels=False, + show_grid=False, + show=True, ) # Pyplot HTML - plot = PianorollHTML() - musa_obj = Musa(midi_sample, structure="bars") - plot.plot_instrument( - track=musa_obj.instruments[0], - bar_start=1, - bar_end=2, - subdivision="quarter", - time_sig=musa_obj.time_sig.time_sig, + musa_obj = Musa(midi_sample) + plot = PianorollHTML(musa_obj) + plot.plot_instruments( + program=[48, 45], + bar_start=0, + bar_end=4, + show_grid=False, show=False ) ```` From 9c4bb9efc36603fef95a31c5909af8f2808b4b4d Mon Sep 17 00:00:00 2001 From: carlosholivan Date: Mon, 5 Dec 2022 10:27:21 +0100 Subject: [PATCH 7/9] remove unused methods --- musicaiz/features/__init__.py | 4 -- musicaiz/features/predict_midi.py | 98 +------------------------------ 2 files changed, 1 insertion(+), 101 deletions(-) diff --git a/musicaiz/features/__init__.py b/musicaiz/features/__init__.py index 95668f2..ab31095 100644 --- a/musicaiz/features/__init__.py +++ b/musicaiz/features/__init__.py @@ -152,9 +152,7 @@ _order_note_seq_by_chromatic_idx, ) from .predict_midi import ( - predict_midi_chords,
predic_time_sig_numerator, - predict_midi_all_keys_degrees, ) from .rhythm import ( get_start_sec, @@ -212,9 +210,7 @@ "_extract_note_positions", "_order_note_seq_by_chromatic_idx", "get_harmonic_density", - "predict_midi_chords", "predic_time_sig_numerator", - "predict_midi_all_keys_degrees", "get_start_sec", "get_ioi", "_delete_duplicates", diff --git a/musicaiz/features/predict_midi.py b/musicaiz/features/predict_midi.py index 6d08f9d..205d149 100644 --- a/musicaiz/features/predict_midi.py +++ b/musicaiz/features/predict_midi.py @@ -3,22 +3,10 @@ # from a midi files. -from typing import Union, TextIO, List, Tuple, Dict -import os -import multiprocessing +from typing import Union, TextIO, List -from musicaiz.harmony import ( - DegreesRoman, - Tonality, - ModeConstructors -) from musicaiz.loaders import Musa -from .harmony import ( - predict_scales_degrees, - predict_chords, - predict_possible_progressions -) from musicaiz.rhythm import get_subdivisions from .rhythm import ( get_start_sec, @@ -50,90 +38,6 @@ def _concatenate_notes_from_different_files( return all_notes, subdivisions -def predict_midi_chords( - files: Union[List[Union[str, TextIO]], str, TextIO], -) -> List[List[Tuple[DegreesRoman, Tonality, ModeConstructors]]]: - """This funciton uses the `predict_chords` function which - predicts the possible chords o a notes list but in this case - for a whole midi file. - - The argument of this function is one or more midi files (it might - be the case that we have the instruments in different midi files and - we want to take all of them into account for the prediction).""" - - # load midi files and get all notes - if not isinstance(files, list): - files = [files] - all_notes, subdivisions = _concatenate_notes_from_different_files(files) - - # loop in time steps of 20 ticks - scales_degrees = [] - all_chords = [] - all_notes_steps = [] - - # ticks step corresponding to the subdivision note - ticks_step = subdivisions[-1]["ticks"] - subdivisions[-2]["ticks"] - - pool = multiprocessing.Pool(processes=os.cpu_count()) - for i in range(0, subdivisions[-1]["ticks"], ticks_step): - notes_time_step = [] - for note in all_notes: - # if note has already finished it's not in the current subdivision - if note.end_ticks < i or note.start_ticks > i + ticks_step: - continue - # calculate note_one inside the current subdivision - if note.start_ticks <= i: - note_start = i - else: - note_start = note.start_ticks - # calculate note_off inside the current subdivision - if note.end_ticks < i + ticks_step: - note_end = note.end_ticks - else: - note_end = i + ticks_step - duration = note_end - note_start - if duration >= ticks_step * 0.2: # IF WE QUANTIZE, 0.5 - notes_time_step.append(note) - all_notes_steps.append(notes_time_step) - # Start the thread - scales = pool.apply_async( - predict_scales_degrees, - args=(notes_time_step,) - ) - # scales = features.predict_scales_degrees(notes_time_step) - chords = predict_chords(notes_time_step) - scales_degrees.append(scales) - all_chords.append(chords) - scales_degrees = [p.get() for p in scales_degrees] - return scales_degrees - - -def predict_midi_all_keys_degrees( - files: List[Union[str, TextIO]] -) -> Dict[str, List[Union[None, DegreesRoman]]]: - scales_degrees = predict_midi_chords(files) - all_scales = predict_possible_progressions(scales_degrees) - return all_scales - - -def get_str_progression_from_scale( - scales: Dict[str, List[Union[None, DegreesRoman]]], - scale: str -) -> str: - str_degrees = "" - for i, degree in enumerate(scales[scale]): - if i 
!= 0: - if i % 8 == 0: # there are 8th notes per bar 4/4 - str_degrees += "|" - else: - str_degrees += "-" - if degree is None: - str_degrees += "None" - else: - str_degrees += degree.major - return str_degrees - - def predic_time_sig_numerator(files: List[Union[str, TextIO]]): """Uses `features.rhythm` functions.""" # load midi files and get all notes From a791380dc535c14c0645581b84e560e0e1a86c9b Mon Sep 17 00:00:00 2001 From: carlosholivan Date: Mon, 5 Dec 2022 11:39:57 +0100 Subject: [PATCH 8/9] add graph representation --- musicaiz/features/__init__.py | 29 ++++++++- musicaiz/features/graphs.py | 72 +++++++++++++++++++++ tests/unit/musicaiz/features/test_graphs.py | 37 +++++++++++ 3 files changed, 137 insertions(+), 1 deletion(-) create mode 100644 musicaiz/features/graphs.py create mode 100644 tests/unit/musicaiz/features/test_graphs.py diff --git a/musicaiz/features/__init__.py b/musicaiz/features/__init__.py index ab31095..5b4ba58 100644 --- a/musicaiz/features/__init__.py +++ b/musicaiz/features/__init__.py @@ -122,8 +122,30 @@ get_segment_boundaries plot_novelty_from_ssm + +Graphs +------ + +This submodule presents graph representations of symbolic music, where notes are nodes and edges connect related notes. + +The papers that are implemented in this submodule are the following: + +.. panels:: + + [1] Jeong, D., Kwon, T., Kim, Y., & Nam, J. (2019) + Graph neural network for music score data and modeling expressive piano performance. + In International Conference on Machine Learning, 3060-3070 + https://proceedings.mlr.press/v97/jeong19a.html + + +.. autosummary:: + :toctree: generated/ + + musa_to_graph + plot_graph """ + from .pitch import ( get_highest_lowest_pitches, get_pitch_range, @@ -168,7 +190,6 @@ note_length_transition_matrix, plot_note_length_transition_matrix, ) - from .self_similarity import ( compute_ssm, self_similarity_louie, @@ -182,6 +203,10 @@ get_segment_boundaries, plot_novelty_from_ssm, ) +from .graphs import ( + musa_to_graph, + plot_graph, +) __all__ = [ "get_highest_lowest_pitches", @@ -230,4 +255,6 @@ "get_novelty_func", "get_segment_boundaries", "plot_novelty_from_ssm", + "musa_to_graph", + "plot_graph", ] diff --git a/musicaiz/features/graphs.py b/musicaiz/features/graphs.py new file mode 100644 index 0000000..aca2cee --- /dev/null +++ b/musicaiz/features/graphs.py @@ -0,0 +1,72 @@ +import matplotlib.pyplot as plt +import networkx as nx + + +def musa_to_graph(musa_object) -> nx.graph: + """Converts a Musa object into a Graph where nodes are + the notes and edges are connections between notes. + + A similar symbolic music graph representation was introduced in: + + Jeong, D., Kwon, T., Kim, Y., & Nam, J. (2019, May). + Graph neural network for music score data and modeling expressive piano performance. + In International Conference on Machine Learning (pp. 3060-3070). PMLR.
+ + Parameters + ---------- + musa_object: Musa object whose notes become the graph nodes. + + Returns + ------- + nx.Graph: the graph of notes. + """ + g = nx.Graph() + for i, note in enumerate(musa_object.notes): + g.add_node(i, pitch=note.pitch, velocity=note.velocity, start=note.start_ticks, end=note.end_ticks) + nodes = list(g.nodes(data=True)) + + # Add edges + for i, node in enumerate(nodes): + for j, next_node in enumerate(nodes): + # if note has already finished it's not in the current subdivision + # TODO: Check this conditions + if i >= j: + continue + if node[1]["start"] >= next_node[1]["start"] and next_node[1]["end"] <= node[1]["end"]: + g.add_edge(i, j, weight=5, color="violet") + elif node[1]["start"] <= next_node[1]["start"] and next_node[1]["end"] <= node[1]["end"]: + g.add_edge(i, j, weight=5, color="violet") + if (j - i == 1) and (not g.has_edge(i, j)): + g.add_edge(i, j, weight=5, color="red") + if g.has_edge(i, i): + g.remove_edge(i, i) + return g + + +def plot_graph(graph: nx.graph, show: bool = False): + """Plots a graph with matplotlib. + + Args: + graph: nx.graph + """ + plt.figure(figsize=(50, 10), dpi=100) + "Plots a networkx graph." + pos = {i: (data["start"], data["pitch"]) for i, data in list(graph.nodes(data=True))} + if nx.get_edge_attributes(graph, 'color') == {}: + colors = ["violet" for _ in range(len(graph.edges()))] + else: + colors = nx.get_edge_attributes(graph, 'color').values() + if nx.get_edge_attributes(graph, 'weight') == {}: + weights = [1 for _ in range(len(graph.edges()))] + else: + weights = nx.get_edge_attributes(graph, 'weight').values() + nx.draw( + graph, + pos, + with_labels=True, + edge_color=colors, + width=list(weights), + node_color='lightblue' + ) + if show: + plt.show() diff --git a/tests/unit/musicaiz/features/test_graphs.py b/tests/unit/musicaiz/features/test_graphs.py new file mode 100644 index 0000000..94245eb --- /dev/null +++ b/tests/unit/musicaiz/features/test_graphs.py @@ -0,0 +1,37 @@ +import pytest + +import matplotlib.pyplot as plt +import networkx as nx + +from musicaiz.loaders import Musa +from musicaiz.features import ( + musa_to_graph, + plot_graph, +) + + +@pytest.fixture +def midi_sample_2(fixture_dir): + return fixture_dir / "midis" / "midi_data.mid" + + +def test_musa_to_graph(midi_sample_2): + musa_obj = Musa(midi_sample_2) + graph = musa_to_graph(musa_obj) + + # n notes must be equal to n nodes + assert len(musa_obj.notes) == len(graph.nodes) + + # adjacency matrix + mat = nx.attr_matrix(graph)[0] + + # n notes must be equal to the adjacency matrix dimension + assert len(musa_obj.notes) == mat.shape[0] + + +def test_plot_graph(midi_sample_2): + musa_obj = Musa(midi_sample_2) + graph = musa_to_graph(musa_obj) + + plot_graph(graph, show=False) + plt.close("all") From 00af1febcf643b91b3161aaf734b977b3ec83073 Mon Sep 17 00:00:00 2001 From: carlosholivan Date: Mon, 5 Dec 2022 11:59:12 +0100 Subject: [PATCH 9/9] update networkx version --- musicaiz/features/graphs.py | 2 +- requirements.txt | 2 +- setup.cfg | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/musicaiz/features/graphs.py b/musicaiz/features/graphs.py index aca2cee..f590b63 100644 --- a/musicaiz/features/graphs.py +++ b/musicaiz/features/graphs.py @@ -29,7 +29,7 @@ def musa_to_graph(musa_object) -> nx.graph: for i, node in enumerate(nodes): for j, next_node in enumerate(nodes): # if note has already finished it's not in the current subdivision - # TODO: Check this conditions + # TODO: Check these conditions if i >= j: continue if node[1]["start"] >= next_node[1]["start"] and next_node[1]["end"] <= node[1]["end"]:
diff --git a/requirements.txt b/requirements.txt index b1aa951..a4445e1 100644 --- a/requirements.txt +++ b/requirements.txt @@ -11,7 +11,7 @@ protobuf==4.21.3 rich==12.6.0 # For models submodule -networkx==2.8.3 +networkx==2.8.6 sklearn==0.0 gradio==3.0.15 torchsummary==1.5.1 diff --git a/setup.cfg b/setup.cfg index 06c82d8..c83bb4f 100644 --- a/setup.cfg +++ b/setup.cfg @@ -47,7 +47,7 @@ install_requires = seaborn==0.11.2 pre-commit==2.19.0 tqdm==4.64.0 - networkx==2.8.3 + networkx==2.8.6 sklearn==0.0 gradio==3.0.15 torchsummary==1.5.1
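A minimal usage sketch of the refactored quantization API introduced in this series (the `QuantizerConfig` dataclass plus the `quantize`/`quantizer_args` arguments of `Musa`). The MIDI path is a placeholder, and passing a configured `QuantizerConfig` instance (rather than the class, which is the annotated default) is an assumption based on `advanced_quantizer` only reading the config's attributes:

```python
from musicaiz.loaders import Musa
from musicaiz.rhythm import QuantizerConfig, advanced_quantizer

# Hypothetical quantizer settings; any values in the documented ranges work.
config = QuantizerConfig(strength=0.75, delta_qr=12, type_q="negative")

# Option 1: quantize while loading, bar by bar, against the subbeat grid.
midi = Musa("example.mid", quantize=True, quantizer_args=config)  # placeholder path

# Option 2: load unquantized and quantize a single bar explicitly.
midi = Musa("example.mid", quantize=False)
notes = midi.get_notes_in_bar(0)
grid = [sb.start_ticks for sb in midi.get_subbeats_in_bar(0)]
advanced_quantizer(notes, grid, config, bpm=midi.bars[0].bpm, resolution=midi.resolution)
```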