SparseConnection support #703


Open · wants to merge 16 commits into master
7 changes: 6 additions & 1 deletion bindsnet/evaluation/evaluation.py
@@ -44,8 +44,9 @@ def assign_labels(
            indices = torch.nonzero(labels == i).view(-1)

            # Compute average firing rates for this label.
+           selected_spikes = torch.index_select(spikes, dim=0, index=torch.tensor(indices))
            rates[:, i] = alpha * rates[:, i] + (
-               torch.sum(spikes[indices], 0) / n_labeled
+               torch.sum(selected_spikes, 0) / n_labeled
            )

    # Compute proportions of spike activity per class.
@@ -111,6 +112,8 @@ def all_activity(

    # Sum over time dimension (spike ordering doesn't matter).
    spikes = spikes.sum(1)
+   if spikes.is_sparse:
+       spikes = spikes.to_dense()

    rates = torch.zeros((n_samples, n_labels), device=spikes.device)
    for i in range(n_labels):
@@ -152,6 +155,8 @@ def proportion_weighting(

    # Sum over time dimension (spike ordering doesn't matter).
    spikes = spikes.sum(1)
+   if spikes.is_sparse:
+       spikes = spikes.to_dense()

    rates = torch.zeros((n_samples, n_labels), device=spikes.device)
    for i in range(n_labels):
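Note on the pattern above: torch.index_select is one of the few indexing operations implemented for PyTorch's sparse COO layout, which is presumably why assign_labels switches to it, while all_activity and proportion_weighting densify right after the time-sum because the per-label indexing below them needs a dense tensor. A minimal sketch of that guard, with illustrative shapes:

    import torch

    dense = torch.rand(4, 10, 6).bernoulli()  # (n_samples, time, n_neurons)
    spikes = dense.to_sparse()

    # Reducing over time keeps the COO layout ...
    spikes = torch.sparse.sum(spikes, dim=1)
    # ... so densify before any dense-only indexing.
    if spikes.is_sparse:
        spikes = spikes.to_dense()

    assert torch.equal(spikes, dense.sum(1))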
65 changes: 42 additions & 23 deletions bindsnet/learning/MCC_learning.py
@@ -102,7 +102,10 @@ def update(self, **kwargs) -> None:
        if ((self.min is not None) or (self.max is not None)) and not isinstance(
            self, NoOp
        ):
-           self.feature_value.clamp_(self.min, self.max)
+           if self.feature_value.is_sparse:
+               self.feature_value = self.feature_value.to_dense().clamp_(self.min, self.max).to_sparse()
+           else:
+               self.feature_value.clamp_(self.min, self.max)

    @abstractmethod
    def reset_state_variables(self) -> None:
@@ -247,10 +250,16 @@ def _connection_update(self, **kwargs) -> None:
                torch.mean(self.average_buffer_pre, dim=0) * self.connection.dt
            )
        else:
-           self.feature_value -= (
-               self.reduction(torch.bmm(source_s, target_x), dim=0)
-               * self.connection.dt
-           )
+           if self.feature_value.is_sparse:
+               self.feature_value -= (
+                   torch.bmm(source_s, target_x)
+                   * self.connection.dt
+               ).to_sparse()
+           else:
+               self.feature_value -= (
+                   self.reduction(torch.bmm(source_s, target_x), dim=0)
+                   * self.connection.dt
+               )
        del source_s, target_x

        # Post-synaptic update.
@@ -278,10 +287,16 @@ def _connection_update(self, **kwargs) -> None:
                torch.mean(self.average_buffer_post, dim=0) * self.connection.dt
            )
        else:
-           self.feature_value += (
-               self.reduction(torch.bmm(source_x, target_s), dim=0)
-               * self.connection.dt
-           )
+           if self.feature_value.is_sparse:
+               self.feature_value += (
+                   torch.bmm(source_x, target_s)
+                   * self.connection.dt
+               ).to_sparse()
+           else:
+               self.feature_value += (
+                   self.reduction(torch.bmm(source_x, target_s), dim=0)
+                   * self.connection.dt
+               )
        del source_x, target_s

        super().update()
@@ -508,16 +523,18 @@ def _connection_update(self, **kwargs) -> None:
                self.average_buffer_index + 1
            ) % self.average_update

-           if self.continues_update:
-               self.feature_value += self.nu[0] * torch.mean(
-                   self.average_buffer, dim=0
-               )
-           elif self.average_buffer_index == 0:
-               self.feature_value += self.nu[0] * torch.mean(
+           if self.continues_update or self.average_buffer_index == 0:
+               update = self.nu[0] * torch.mean(
                    self.average_buffer, dim=0
                )
+               if self.feature_value.is_sparse:
+                   update = update.to_sparse()
+               self.feature_value += update
        else:
-           self.feature_value += self.nu[0] * self.reduction(update, dim=0)
+           update = self.nu[0] * self.reduction(update, dim=0)
+           if self.feature_value.is_sparse:
+               update = update.to_sparse()
+           self.feature_value += update

        # Update P^+ and P^- values.
        self.p_plus *= torch.exp(-self.connection.dt / self.tc_plus)
@@ -686,14 +703,16 @@ def _connection_update(self, **kwargs) -> None:
                self.average_buffer_index + 1
            ) % self.average_update

-           if self.continues_update:
-               self.feature_value += torch.mean(self.average_buffer, dim=0)
-           elif self.average_buffer_index == 0:
-               self.feature_value += torch.mean(self.average_buffer, dim=0)
+           if self.continues_update or self.average_buffer_index == 0:
+               update = torch.mean(self.average_buffer, dim=0)
+               if self.feature_value.is_sparse:
+                   update = update.to_sparse()
+               self.feature_value += update
        else:
-           self.feature_value += (
-               self.nu[0] * self.connection.dt * reward * self.eligibility_trace
-           )
+           update = self.nu[0] * self.connection.dt * reward * self.eligibility_trace
+           if self.feature_value.is_sparse:
+               update = update.to_sparse()
+           self.feature_value += update

        # Update P^+ and P^- values.
        self.p_plus *= torch.exp(-self.connection.dt / self.tc_plus)  # Decay
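Note: two idioms recur throughout this file. In-place ops such as clamp_ are not implemented for sparse tensors, hence the to_dense()/to_sparse() round-trip in update(); and dense updates are converted with to_sparse() before accumulation so that both operands of += share a layout. A self-contained sketch of both, with illustrative values:

    import torch

    w = torch.tensor([[1.5, -2.0], [0.3, 0.0]]).to_sparse()

    # Idiom 1: clamp via a dense round-trip.
    w = w.to_dense().clamp_(-1.0, 1.0).to_sparse()

    # Idiom 2: match layouts before accumulating an update.
    update = 0.1 * torch.ones(2, 2)
    if w.is_sparse:
        update = update.to_sparse()
    w = w + update

Note also that the sparse branches of the PostPre updates skip self.reduction over the batch dimension; that lines up with the batched (batch_size, n, n) sparse weights built in models.py below.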
37 changes: 28 additions & 9 deletions bindsnet/learning/learning.py
@@ -98,7 +98,10 @@ def update(self) -> None:
            (self.connection.wmin != -np.inf).any()
            or (self.connection.wmax != np.inf).any()
        ) and not isinstance(self, NoOp):
-           self.connection.w.clamp_(self.connection.wmin, self.connection.wmax)
+           if self.connection.w.is_sparse:
+               raise Exception("SparseConnection isn't supported for wmin\\wmax")
+           else:
+               self.connection.w.clamp_(self.connection.wmin, self.connection.wmax)


class NoOp(LearningRule):
@@ -396,7 +399,10 @@ def _connection_update(self, **kwargs) -> None:
        if self.nu[0].any():
            source_s = self.source.s.view(batch_size, -1).unsqueeze(2).float()
            target_x = self.target.x.view(batch_size, -1).unsqueeze(1) * self.nu[0]
-           self.connection.w -= self.reduction(torch.bmm(source_s, target_x), dim=0)
+           update = self.reduction(torch.bmm(source_s, target_x), dim=0)
+           if self.connection.w.is_sparse:
+               update = update.to_sparse()
+           self.connection.w -= update
            del source_s, target_x

        # Post-synaptic update.
@@ -405,7 +411,10 @@ def _connection_update(self, **kwargs) -> None:
                self.target.s.view(batch_size, -1).unsqueeze(1).float() * self.nu[1]
            )
            source_x = self.source.x.view(batch_size, -1).unsqueeze(2)
-           self.connection.w += self.reduction(torch.bmm(source_x, target_s), dim=0)
+           update = self.reduction(torch.bmm(source_x, target_s), dim=0)
+           if self.connection.w.is_sparse:
+               update = update.to_sparse()
+           self.connection.w += update
            del source_x, target_s

        super().update()
@@ -1113,10 +1122,14 @@ def _connection_update(self, **kwargs) -> None:

        # Pre-synaptic update.
        update = self.reduction(torch.bmm(source_s, target_x), dim=0)
+       if self.connection.w.is_sparse:
+           update = update.to_sparse()
        self.connection.w += self.nu[0] * update

        # Post-synaptic update.
        update = self.reduction(torch.bmm(source_x, target_s), dim=0)
+       if self.connection.w.is_sparse:
+           update = update.to_sparse()
        self.connection.w += self.nu[1] * update

        super().update()
@@ -1542,8 +1555,10 @@ def _connection_update(self, **kwargs) -> None:
            a_minus = torch.tensor(a_minus, device=self.connection.w.device)

        # Compute weight update based on the eligibility value of the past timestep.
-       update = reward * self.eligibility
-       self.connection.w += self.nu[0] * self.reduction(update, dim=0)
+       update = self.reduction(reward * self.eligibility, dim=0)
+       if self.connection.w.is_sparse:
+           update = update.to_sparse()
+       self.connection.w += self.nu[0] * update

        # Update P^+ and P^- values.
        self.p_plus *= torch.exp(-self.connection.dt / self.tc_plus)
@@ -2214,10 +2229,11 @@ def _connection_update(self, **kwargs) -> None:
        self.eligibility_trace *= torch.exp(-self.connection.dt / self.tc_e_trace)
        self.eligibility_trace += self.eligibility / self.tc_e_trace

+       update = self.nu[0] * self.connection.dt * reward * self.eligibility_trace
+       if self.connection.w.is_sparse:
+           update = update.to_sparse()
        # Compute weight update.
-       self.connection.w += (
-           self.nu[0] * self.connection.dt * reward * self.eligibility_trace
-       )
+       self.connection.w += update

        # Update P^+ and P^- values.
        self.p_plus *= torch.exp(-self.connection.dt / self.tc_plus)
@@ -2936,6 +2952,9 @@ def _connection_update(self, **kwargs) -> None:
        ) * source_x[:, None]

        # Compute weight update.
-       self.connection.w += self.nu[0] * reward * self.eligibility_trace
+       update = self.nu[0] * reward * self.eligibility_trace
+       if self.connection.w.is_sparse:
+           update = update.to_sparse()
+       self.connection.w += update

        super().update()
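Note: every rule in this file now funnels its batch-reduced update through the same guard before touching connection.w, and the wmin/wmax clamp is rejected outright for sparse weights rather than densifying the whole matrix each step. A condensed sketch of the shared pattern, with illustrative shapes:

    import torch

    batch_size, n_src, n_tgt = 2, 3, 4
    source_s = torch.rand(batch_size, n_src, 1)  # pre-synaptic spikes
    target_x = torch.rand(batch_size, 1, n_tgt)  # post-synaptic traces

    w = torch.zeros(n_src, n_tgt).to_sparse()

    # Batched outer products, reduced over the batch dimension.
    update = torch.bmm(source_s, target_x).sum(dim=0)
    if w.is_sparse:
        update = update.to_sparse()  # match layouts before accumulating
    w = w - update                   # e.g. the depressing pre-synaptic term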
63 changes: 50 additions & 13 deletions bindsnet/models/models.py
@@ -4,11 +4,14 @@
import torch
from scipy.spatial.distance import euclidean
from torch.nn.modules.utils import _pair
+from torch import device

from bindsnet.learning import PostPre
+from bindsnet.learning.MCC_learning import PostPre as MMCPostPre
from bindsnet.network import Network
from bindsnet.network.nodes import DiehlAndCookNodes, Input, LIFNodes
-from bindsnet.network.topology import Connection, LocalConnection
+from bindsnet.network.topology import Connection, LocalConnection, MulticompartmentConnection
+from bindsnet.network.topology_features import Weight


class TwoLayerNetwork(Network):
@@ -94,6 +97,9 @@ class DiehlAndCook2015(Network):
    def __init__(
        self,
        n_inpt: int,
+       device: str = "cpu",
+       batch_size: int = None,
+       sparse: bool = False,
        n_neurons: int = 100,
        exc: float = 22.5,
        inh: float = 17.5,
@@ -170,27 +176,58 @@ def __init__(

        # Connections
        w = 0.3 * torch.rand(self.n_inpt, self.n_neurons)
-       input_exc_conn = Connection(
+       input_exc_conn = MulticompartmentConnection(
            source=input_layer,
            target=exc_layer,
-           w=w,
-           update_rule=PostPre,
-           nu=nu,
-           reduction=reduction,
-           wmin=wmin,
-           wmax=wmax,
-           norm=norm,
+           device=device,
+           pipeline=[
+               Weight(
+                   'weight',
+                   w,
+                   range=[wmin, wmax],
+                   norm=norm,
+                   reduction=reduction,
+                   nu=nu,
+                   learning_rule=MMCPostPre,
+                   sparse=sparse,
+                   batch_size=batch_size
+               )
+           ]
        )
        w = self.exc * torch.diag(torch.ones(self.n_neurons))
-       exc_inh_conn = Connection(
-           source=exc_layer, target=inh_layer, w=w, wmin=0, wmax=self.exc
+       if sparse:
+           w = w.unsqueeze(0).expand(batch_size, -1, -1)
+       exc_inh_conn = MulticompartmentConnection(
+           source=exc_layer,
+           target=inh_layer,
+           device=device,
+           pipeline=[
+               Weight(
+                   'weight',
+                   w,
+                   range=[0, self.exc],
+                   sparse=sparse
+               )
+           ]
        )
        w = -self.inh * (
            torch.ones(self.n_neurons, self.n_neurons)
            - torch.diag(torch.ones(self.n_neurons))
        )
-       inh_exc_conn = Connection(
-           source=inh_layer, target=exc_layer, w=w, wmin=-self.inh, wmax=0
+       if sparse:
+           w = w.unsqueeze(0).expand(batch_size, -1, -1)
+       inh_exc_conn = MulticompartmentConnection(
+           source=inh_layer,
+           target=exc_layer,
+           device=device,
+           pipeline=[
+               Weight(
+                   'weight',
+                   w,
+                   range=[-self.inh, 0],
+                   sparse=sparse
+               )
+           ]
        )

        # Add to network
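Note: the three connections in DiehlAndCook2015 are now MulticompartmentConnection objects whose former Connection keyword arguments (w, wmin/wmax, norm, nu, reduction, update_rule) move into a Weight feature in the pipeline. Assuming that API, a hypothetical instantiation with sparse weights; batch_size must be supplied alongside sparse=True because the fixed excitatory/inhibitory weights are expanded to (batch_size, n_neurons, n_neurons):

    from bindsnet.models import DiehlAndCook2015

    network = DiehlAndCook2015(
        n_inpt=28 * 28,  # e.g. flattened MNIST
        n_neurons=100,
        device="cpu",
        batch_size=32,
        sparse=True,
    )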
11 changes: 7 additions & 4 deletions bindsnet/network/monitors.py
@@ -45,6 +45,7 @@ def __init__(
        time: Optional[int] = None,
        batch_size: int = 1,
        device: str = "cpu",
+       sparse: Optional[bool] = False
    ):
        # language=rst
        """
@@ -62,6 +63,7 @@
        self.time = time
        self.batch_size = batch_size
        self.device = device
+       self.sparse = sparse

        # if time is not specified the monitor variable accumulate the logs
        if self.time is None:
@@ -98,11 +100,12 @@ def record(self) -> None:
        for v in self.state_vars:
            data = getattr(self.obj, v).unsqueeze(0)
            # self.recording[v].append(data.detach().clone().to(self.device))
-           self.recording[v].append(
-               torch.empty_like(data, device=self.device, requires_grad=False).copy_(
-                   data, non_blocking=True
-               )
-           )
+           record = torch.empty_like(data, device=self.device, requires_grad=False).copy_(
+               data, non_blocking=True
+           )
+           if self.sparse:
+               record = record.to_sparse()
+           self.recording[v].append(record)
            # remove the oldest element (first in the list)
            if self.time is not None:
                self.recording[v].pop(0)
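Note: with sparse=True the monitor stores each recorded step in COO form, which can shrink mostly-zero spike logs considerably; consumers of recording (plotting, concatenation over time) may need to call to_dense() first. What the new branch in record() does, in isolation:

    import torch

    data = torch.zeros(1, 32, 100)  # one step: (1, batch, neurons), mostly zeros
    data[0, 0, 3] = 1.0

    record = torch.empty_like(data).copy_(data, non_blocking=True)
    sparse = True  # stands in for self.sparse
    if sparse:
        record = record.to_sparse()

    print(record.is_sparse)  # True; only the lone spike is stored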