Description
Hi, my understanding is that the sorting analyzer is aware of dependencies between extensions and handles the laziness of what gets computed. However, I'm having trouble using this in a straightforward way: the code to get there keeps getting more convoluted, and I believe I'm misusing the analyzer.
My goal
When given an empty analyzer (no extensions computed), I simply run, for an input dictionary `d`:

```python
analyzer.compute_several_extensions(
    {k: d.get(k, {}) for k in analyzer.get_computable_extensions() if d.get(k, True) is not False}
)
```
However, I want to achieve the exact same result (up to random sampling) using an analyzer where some extensions may have already been computed, possibly with the same parameters (in which case we aim to not recompute them) or with different parameters (in which case we need to recompute that extension and all those depending on it).
The problem
It seems that `compute` forces recomputation of the extension even if it has the exact same parameters. Thus one needs to check manually whether that extension is already in the analyzer with the same parameters. This is a hard task because `analyzer.get_default_extension_params(k) | d[k]` is slightly different than `analyzer.get_extension(k).params`. For example, `waveforms` has `'dtype': None` in the former and `'dtype': '<f4'` in the latter; `templates` has `'operators': None` in the former and `'operators': ['average', 'std']` in the latter; and so on.
Is there an easy way to achieve what I want?
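For what it's worth, one workaround I've considered is comparing the two dicts while treating `None` in the requested parameters as "use the resolved default". This is a minimal, library-agnostic sketch: the helper name `params_match` and the sample dicts are hypothetical stand-ins, not spikeinterface API.

```python
# Sketch: compare a requested parameter dict against the stored extension
# params while ignoring unset (None) entries, which the analyzer resolves
# to concrete defaults such as 'dtype': '<f4'.

def params_match(requested, stored):
    """Return True if every explicitly set requested parameter matches
    the stored value; None entries are treated as 'use default'."""
    for key, value in requested.items():
        if value is None:
            continue  # unset default: let the stored (resolved) value stand
        if stored.get(key) != value:
            return False
    return True

# Example mirroring the 'waveforms' case above:
requested = {"ms_before": 1, "ms_after": 2, "dtype": None}
stored = {"ms_before": 1, "ms_after": 2, "dtype": "<f4"}
print(params_match(requested, stored))  # True: the None dtype is ignored
```

This obviously only papers over the `None`-vs-resolved-default mismatch; it doesn't handle parameters whose stored form differs structurally from the requested one.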
My code
The current code to achieve something close to what I want (though depending on how exactly dependencies are managed in spikeinterface, it may not be right) is the following. Is this a correct way of achieving my goal?
```python
import yaml
import spikeinterface.full as si

# `rec` and `sorting` are loaded elsewhere
analyzer = si.load("my_previous_analyzer_path").copy()
analyzer.set_temporary_recording(rec)  # to handle the case where the recording changed path
analyzer.sorting = sorting  # to handle the case where the sorting changed path

# For this issue I've put the config file as a string
config_str = """
random_spikes:
  method: "uniform"
  max_spikes_per_unit: 400
waveforms:
  ms_before: 1
  ms_after: 2
amplitude_scalings: False
spike_locations: False
"""
params = yaml.safe_load(config_str)

real_params = {
    k: (analyzer.get_default_extension_params(k) | params.get(k, {}))
    for k in analyzer.get_computable_extensions()
    if params.get(k, {}) != False
}

extensions_to_compute = {}
for k in analyzer.get_computable_extensions():
    if k not in real_params:
        # extension was explicitly disabled: drop any stale result
        if analyzer.has_extension(k):
            analyzer.delete_extension(k)
    elif not analyzer.has_extension(k):
        extensions_to_compute[k] = real_params[k]
    elif real_params[k] != analyzer.get_extension(k).params:
        # keep only parameters that are explicitly set and genuinely differ
        diff = {
            p: (real_params[k][p], analyzer.get_extension(k).params.get(p, None))
            for p in real_params[k]
            if real_params[k][p] is not None
        }
        diff = {p: (v1, v2) for p, (v1, v2) in diff.items() if v1 != v2 and v2 is not None}
        if len(diff) > 0:
            print(k, diff)
            extensions_to_compute[k] = real_params[k]

analyzer.compute_several_extensions(extensions_to_compute)
```
Additionally, it would be really nice to have an option to `analyzer.compute_several_extensions` to keep computing extensions when one of them raises an error (for those that do not depend on the extension that errored).
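To illustrate the behavior I'm asking for, here is a minimal, library-agnostic sketch of an error-tolerant compute loop over a topologically ordered list of extensions. The dependency dict and the `compute_one` callback are hypothetical stand-ins, not spikeinterface API.

```python
# Sketch: compute each extension in dependency order; if one fails,
# skip only its (transitive) dependents and keep going with the rest.

def compute_with_error_tolerance(extensions, depends_on, compute_one):
    """`extensions` must be topologically ordered; returns the set of
    extensions that failed or were skipped because a dependency failed."""
    failed = set()
    for name in extensions:
        if any(dep in failed for dep in depends_on.get(name, [])):
            failed.add(name)  # transitively skip dependents of a failed extension
            continue
        try:
            compute_one(name)
        except Exception as exc:
            print(f"{name} failed ({exc!r}); continuing with independent extensions")
            failed.add(name)
    return failed

# Toy example: 'waveforms' depends on 'random_spikes', 'templates' on 'waveforms'.
deps = {"waveforms": ["random_spikes"], "templates": ["waveforms"]}

def fake_compute(name):
    if name == "waveforms":
        raise RuntimeError("out of memory")

failed = compute_with_error_tolerance(
    ["random_spikes", "waveforms", "templates", "unit_locations"], deps, fake_compute
)
print(sorted(failed))  # ['templates', 'waveforms']
```

Here `random_spikes` and `unit_locations` still get computed even though `waveforms` (and therefore `templates`) failed, which is the behavior I'd hope for from such an option.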