diff --git a/cnsproject/network/neural_populations.py b/cnsproject/network/neural_populations.py index 7356a4b..9ec48e5 100644 --- a/cnsproject/network/neural_populations.py +++ b/cnsproject/network/neural_populations.py @@ -57,7 +57,7 @@ def __init__( self.additive_spike_trace = additive_spike_trace if self.spike_trace: - self.register_buffer("traces", torch.Tensor()) + self.register_buffer("traces", torch.zeros(*self.shape)) self.register_buffer("tau_s", torch.tensor(tau_s)) if self.additive_spike_trace: @@ -68,7 +68,7 @@ def __init__( self.is_inhibitory = is_inhibitory self.learning = learning - self.register_buffer("s", torch.ByteTensor()) + self.register_buffer("s", torch.zeros(*self.shape, dtype=torch.bool)) self.dt = None @abstractmethod