|
1 | 1 | ''' |
2 | 2 | Module implementing the CUDA "standalone" device. |
3 | 3 | ''' |
4 | | -import os |
5 | 4 | import inspect |
6 | | -from collections import defaultdict, Counter |
| 5 | +import os |
| 6 | +import re |
7 | 7 | import tempfile |
| 8 | +from collections import Counter, defaultdict |
8 | 9 | from distutils import ccompiler |
9 | | -import re |
10 | 10 | from itertools import chain |
11 | 11 |
|
12 | 12 | import numpy as np |
13 | | - |
14 | 13 | from brian2.codegen.cpp_prefs import get_compiler_and_args |
| 14 | +from brian2.codegen.generators.cpp_generator import c_data_type |
15 | 15 | from brian2.codegen.translation import make_statements |
16 | | -from brian2.core.clocks import Clock, defaultclock, EventClock |
| 16 | +from brian2.core.clocks import Clock, EventClock, defaultclock |
17 | 17 | from brian2.core.namespace import get_local_namespace |
18 | | -from brian2.core.preferences import prefs, PreferenceError |
19 | | -from brian2.core.variables import ArrayVariable, DynamicArrayVariable, Constant |
20 | | -from brian2.parsing.rendering import CPPNodeRenderer |
| 18 | +from brian2.core.preferences import PreferenceError, prefs |
| 19 | +from brian2.core.variables import ArrayVariable, Constant, DynamicArrayVariable |
| 20 | +from brian2.devices.cpp_standalone.device import CPPStandaloneDevice, CPPWriter |
21 | 21 | from brian2.devices.device import all_devices |
| 22 | +from brian2.groups import Subgroup |
| 23 | +from brian2.input.spikegeneratorgroup import SpikeGeneratorGroup |
| 24 | +from brian2.monitors import EventMonitor, SpikeMonitor, StateMonitor |
| 25 | +from brian2.parsing.rendering import CPPNodeRenderer |
22 | 26 | from brian2.synapses.synapses import Synapses, SynapticPathway |
| 27 | +from brian2.units import second |
23 | 28 | from brian2.utils.filetools import copy_directory, ensure_directory |
24 | | -from brian2.utils.stringtools import get_identifiers, stripped_deindented_lines |
25 | | -from brian2.codegen.generators.cpp_generator import c_data_type |
26 | 29 | from brian2.utils.logger import get_logger |
27 | | -from brian2.units import second |
28 | | -from brian2.monitors import SpikeMonitor, StateMonitor, EventMonitor |
29 | | -from brian2.groups import Subgroup |
30 | | - |
31 | | -from brian2.devices.cpp_standalone.device import CPPWriter, CPPStandaloneDevice |
32 | | -from brian2.input.spikegeneratorgroup import SpikeGeneratorGroup |
| 30 | +from brian2.utils.stringtools import get_identifiers, stripped_deindented_lines |
33 | 31 |
|
34 | | -from brian2cuda.utils.stringtools import replace_floating_point_literals |
35 | | -from brian2cuda.utils.gputools import select_gpu, get_nvcc_path |
| 32 | +from brian2cuda.utils.gputools import get_nvcc_path, select_gpu |
36 | 33 | from brian2cuda.utils.logger import report_issue_message |
| 34 | +from brian2cuda.utils.stringtools import replace_floating_point_literals |
37 | 35 |
|
38 | | -from .codeobject import CUDAStandaloneCodeObject, CUDAStandaloneAtomicsCodeObject |
39 | | - |
| 36 | +from .codeobject import CUDAStandaloneAtomicsCodeObject, CUDAStandaloneCodeObject |
40 | 37 |
|
41 | 38 | __all__ = [] |
42 | 39 |
|
@@ -445,8 +442,7 @@ def generate_objects_source( |
445 | 442 | # if hasattr(var, 'owner') and isinstance(v.owner, Clock): |
446 | 443 | if isinstance(var.owner, SpikeGeneratorGroup): |
447 | 444 | self.spikegenerator_eventspaces.append(varname) |
448 | | - for var in self.eventspace_arrays.keys(): |
449 | | - del self.arrays[var] |
| 445 | + |
450 | 446 | subgroups_with_spikemonitor = set() |
451 | 447 | for codeobj in self.code_objects.values(): |
452 | 448 | if isinstance(codeobj.owner, SpikeMonitor): |
@@ -488,9 +484,9 @@ def generate_objects_source( |
488 | 484 | profile_statemonitor_vars=profile_statemonitor_vars, |
489 | 485 | subgroups_with_spikemonitor=sorted(subgroups_with_spikemonitor), |
490 | 486 | timed_arrays=timed_arrays, |
491 | | - variables_on_host_only=self.variables_on_host_only) |
492 | | - # Reinsert deleted entries, in case we use self.arrays later? maybe unnecassary... |
493 | | - self.arrays.update(self.eventspace_arrays) |
| 487 | + variables_on_host_only=self.variables_on_host_only, |
| 488 | + ) |
| 489 | + |
494 | 490 | writer.write('objects.*', arr_tmp) |
495 | 491 |
|
496 | 492 | def generate_main_source(self, writer): |
@@ -1422,6 +1418,9 @@ def build(self, directory='output', results_directory="results", |
1422 | 1418 | if var.name in ('t', 'dt', 'timestep'): |
1423 | 1419 | # We manage time variables on host and pass them by value to kernels |
1424 | 1420 | self.variables_on_host_only.append(varname) |
| 1421 | + if var.name.endswith("space") and var.name.startswith("_"): |
| 1422 | + # eventspace variables are on host |
| 1423 | + self.variables_on_host_only.append(varname) |
1425 | 1424 | for var, varname in self.dynamic_arrays.items(): |
1426 | 1425 | varnames = ['_synaptic_pre', '_synaptic_post'] |
1427 | 1426 | try: |
|
0 commit comments