Skip to content

Commit 24f6181

Browse files
committed
store eventspaces as results
1 parent 1eb3d88 commit 24f6181

File tree

2 files changed

+28
-29
lines changed

2 files changed

+28
-29
lines changed

brian2cuda/device.py

Lines changed: 24 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -1,42 +1,39 @@
11
'''
22
Module implementing the CUDA "standalone" device.
33
'''
4-
import os
54
import inspect
6-
from collections import defaultdict, Counter
5+
import os
6+
import re
77
import tempfile
8+
from collections import Counter, defaultdict
89
from distutils import ccompiler
9-
import re
1010
from itertools import chain
1111

1212
import numpy as np
13-
1413
from brian2.codegen.cpp_prefs import get_compiler_and_args
14+
from brian2.codegen.generators.cpp_generator import c_data_type
1515
from brian2.codegen.translation import make_statements
16-
from brian2.core.clocks import Clock, defaultclock, EventClock
16+
from brian2.core.clocks import Clock, EventClock, defaultclock
1717
from brian2.core.namespace import get_local_namespace
18-
from brian2.core.preferences import prefs, PreferenceError
19-
from brian2.core.variables import ArrayVariable, DynamicArrayVariable, Constant
20-
from brian2.parsing.rendering import CPPNodeRenderer
18+
from brian2.core.preferences import PreferenceError, prefs
19+
from brian2.core.variables import ArrayVariable, Constant, DynamicArrayVariable
20+
from brian2.devices.cpp_standalone.device import CPPStandaloneDevice, CPPWriter
2121
from brian2.devices.device import all_devices
22+
from brian2.groups import Subgroup
23+
from brian2.input.spikegeneratorgroup import SpikeGeneratorGroup
24+
from brian2.monitors import EventMonitor, SpikeMonitor, StateMonitor
25+
from brian2.parsing.rendering import CPPNodeRenderer
2226
from brian2.synapses.synapses import Synapses, SynapticPathway
27+
from brian2.units import second
2328
from brian2.utils.filetools import copy_directory, ensure_directory
24-
from brian2.utils.stringtools import get_identifiers, stripped_deindented_lines
25-
from brian2.codegen.generators.cpp_generator import c_data_type
2629
from brian2.utils.logger import get_logger
27-
from brian2.units import second
28-
from brian2.monitors import SpikeMonitor, StateMonitor, EventMonitor
29-
from brian2.groups import Subgroup
30-
31-
from brian2.devices.cpp_standalone.device import CPPWriter, CPPStandaloneDevice
32-
from brian2.input.spikegeneratorgroup import SpikeGeneratorGroup
30+
from brian2.utils.stringtools import get_identifiers, stripped_deindented_lines
3331

34-
from brian2cuda.utils.stringtools import replace_floating_point_literals
35-
from brian2cuda.utils.gputools import select_gpu, get_nvcc_path
32+
from brian2cuda.utils.gputools import get_nvcc_path, select_gpu
3633
from brian2cuda.utils.logger import report_issue_message
34+
from brian2cuda.utils.stringtools import replace_floating_point_literals
3735

38-
from .codeobject import CUDAStandaloneCodeObject, CUDAStandaloneAtomicsCodeObject
39-
36+
from .codeobject import CUDAStandaloneAtomicsCodeObject, CUDAStandaloneCodeObject
4037

4138
__all__ = []
4239

@@ -445,8 +442,7 @@ def generate_objects_source(
445442
# if hasattr(var, 'owner') and isinstance(v.owner, Clock):
446443
if isinstance(var.owner, SpikeGeneratorGroup):
447444
self.spikegenerator_eventspaces.append(varname)
448-
for var in self.eventspace_arrays.keys():
449-
del self.arrays[var]
445+
450446
subgroups_with_spikemonitor = set()
451447
for codeobj in self.code_objects.values():
452448
if isinstance(codeobj.owner, SpikeMonitor):
@@ -488,9 +484,9 @@ def generate_objects_source(
488484
profile_statemonitor_vars=profile_statemonitor_vars,
489485
subgroups_with_spikemonitor=sorted(subgroups_with_spikemonitor),
490486
timed_arrays=timed_arrays,
491-
variables_on_host_only=self.variables_on_host_only)
492-
# Reinsert deleted entries, in case we use self.arrays later? maybe unnecassary...
493-
self.arrays.update(self.eventspace_arrays)
487+
variables_on_host_only=self.variables_on_host_only,
488+
)
489+
494490
writer.write('objects.*', arr_tmp)
495491

496492
def generate_main_source(self, writer):
@@ -1422,6 +1418,9 @@ def build(self, directory='output', results_directory="results",
14221418
if var.name in ('t', 'dt', 'timestep'):
14231419
# We manage time variables on host and pass them by value to kernels
14241420
self.variables_on_host_only.append(varname)
1421+
if var.name.endswith("space") and var.name.startswith("_"):
1422+
# eventspace variables are on host
1423+
self.variables_on_host_only.append(varname)
14251424
for var, varname in self.dynamic_arrays.items():
14261425
varnames = ['_synaptic_pre', '_synaptic_post']
14271426
try:

brian2cuda/templates/objects.cu

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -99,7 +99,7 @@ void brian::set_variable_by_name(std::string name, std::string s_value) {
9999
s_value = "0";
100100
// non-dynamic arrays
101101
{% for var, varname in array_specs | dictsort(by='value') %}
102-
{% if not var in dynamic_array_specs and not var.read_only %}
102+
{% if var not in dynamic_array_specs and var not in eventspace_arrays and not var.read_only %}
103103
if (name == "{{var.owner.name}}.{{var.name}}") {
104104
var_size = {{var.size}};
105105
data_size = {{var.size}}*sizeof({{c_data_type(var.dtype)}});
@@ -184,7 +184,7 @@ void brian::set_variable_by_name(std::string name, std::string s_value) {
184184
}
185185
//////////////// arrays ///////////////////
186186
{% for var, varname in array_specs | dictsort(by='value') %}
187-
{% if not var in dynamic_array_specs %}
187+
{% if var not in dynamic_array_specs and var not in eventspace_arrays %}
188188
{{c_data_type(var.dtype)}} * brian::{{varname}};
189189
{{c_data_type(var.dtype)}} * brian::dev{{varname}};
190190
__device__ {{c_data_type(var.dtype)}} * brian::d{{varname}};
@@ -682,7 +682,7 @@ void _dealloc_arrays()
682682
{% endfor %}
683683

684684
{% for var, varname in array_specs | dictsort(by='value') %}
685-
{% if not var in dynamic_array_specs %}
685+
{% if var not in dynamic_array_specs and var not in eventspace_arrays%}
686686
if({{varname}}!=0)
687687
{
688688
delete [] {{varname}};
@@ -783,7 +783,7 @@ extern thrust::device_vector<int32_t> _dev_{{varname}}_eventspace;
783783

784784
//////////////// arrays ///////////////////
785785
{% for var, varname in array_specs | dictsort(by='value') %}
786-
{% if not var in dynamic_array_specs %}
786+
{% if var not in dynamic_array_specs and var not in eventspace_arrays %}
787787
extern {{c_data_type(var.dtype)}} * {{varname}};
788788
extern {{c_data_type(var.dtype)}} * dev{{varname}};
789789
extern __device__ {{c_data_type(var.dtype)}} *d{{varname}};

0 commit comments

Comments
 (0)