briann.network package

Submodules

briann.network.area_transformations module

class briann.network.area_transformations.SimpleDenseTransformation(*args: Any, **kwargs: Any)[source]

Bases: Module

forward(x)[source]

briann.network.connection_transformations module

class briann.network.connection_transformations.IndexBasedSplitter(input_flatten_axes: Tuple[int, int], input_indices: List[int], output_flatten_axes: Tuple[int, int], output_shape: List[int])[source]

Bases: Module

The IndexBasedSplitter maps its input to its output by selecting entries via indices in forward(). This mapping is configured by the following arguments.

Parameters:
  • input_flatten_axes (Tuple[int,int]) – Sets the input_flatten_axes() of this object.

  • input_indices (List[int]) – Sets the input_indices() of this object.

  • output_flatten_axes (Tuple[int,int]) – Sets the output_flatten_axes() of this object.

  • output_shape (List[int]) – Sets the output_shape() of this object.

forward(x: Tensor) Tensor[source]

This method first flattens the input tensor x along input_flatten_axes(), then extracts the entries at the positions given by input_indices() and finally unflattens the result along output_flatten_axes() to arrive at the output shape specified in output_shape().
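The three steps of forward() can be illustrated with standard PyTorch tensor operations. The following is a minimal sketch, not the library's implementation: the function name split_by_indices is hypothetical, and it assumes that the selected entries sit at the first of the input_flatten_axes and are unflattened there into the sizes that output_shape() prescribes along output_flatten_axes().

```python
import math
from typing import List, Tuple

import torch


def split_by_indices(
    x: torch.Tensor,
    input_flatten_axes: Tuple[int, int],
    input_indices: List[int],
    output_flatten_axes: Tuple[int, int],
    output_shape: List[int],
) -> torch.Tensor:
    """Hypothetical stand-alone sketch of IndexBasedSplitter.forward()."""
    in_start, in_end = input_flatten_axes
    out_start, out_end = output_flatten_axes
    # 1. Flatten the input along input_flatten_axes (inclusive, batch axis counted).
    flat = x.flatten(start_dim=in_start, end_dim=in_end)
    # 2. Keep only the requested entries, in the order given by input_indices.
    index = torch.tensor(input_indices, dtype=torch.long, device=x.device)
    selected = flat.index_select(in_start, index)
    # 3. Unflatten the selected axis into the sizes of output_shape along
    #    output_flatten_axes (output_shape excludes the batch axis, hence the -1).
    sizes = list(output_shape[out_start - 1 : out_end])
    if math.prod(sizes) != len(input_indices):
        raise ValueError("len(input_indices) must equal the product of these sizes.")
    new_shape = list(selected.shape[:in_start]) + sizes + list(selected.shape[in_start + 1 :])
    y = selected.reshape(new_shape)
    # Assumption of this sketch: the remaining axes already match output_shape.
    if list(y.shape[1:]) != list(output_shape):
        raise ValueError("Result does not match output_shape.")
    return y
```

For an input of shape [batch_size, 2, 4, 5] flattened along (2, 3), six selected entries can be unflattened into output_shape [2, 3, 2] along (2, 3).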

property input_flatten_axes: Tuple[int, int]
Returns:

The axes along which the input will be flattened inside forward() before selecting entries from it. This is a Tuple of two ints, where the first int is the axis at which flattening starts and the second int is the axis (inclusive) at which it will end.

Return type:

Tuple[int, int]

property input_indices: List[int]
Returns:

A list of indices specifying which entries of the flattened input torch.Tensor are copied, in the given order, into the flattened output tensor. The length of this list must equal the product of the sizes of output_shape() along output_flatten_axes().

Return type:

List[int]

property output_flatten_axes: Tuple[int, int]
Returns:

A Tuple of two axes, namely the start_axis and end_axis (inclusive). The axes spanned by start_axis and end_axis are those along which the output tensor will be flattened before its entries are filled with the corresponding ones from the input tensor. The calculation of these axes DOES include the initial batch axis, so both are assumed to be greater than 0.

Return type:

Tuple[int,int]

property output_shape: List[int]
Returns:

The desired shape for the torch.Tensor that the forward() method shall output.

Return type:

List[int]

class briann.network.connection_transformations.Splitter[source]

Bases: Adapter

This class is meant to be used inside the forward() method of a Connection. It takes an input x that is the current state() of the sending area’s output_time_frame_accumulator and splits off the part that is relevant to the calling Connection object.

forward(x: Tensor) Tensor[source]

briann.network.core module

This module collects all necessary components to build a BrIANN model.

class briann.network.core.AdditiveMerger[source]

Bases: Merger

This merger adds all inputs.

forward(x: Dict[int, torch.Tensor]) torch.Tensor[source]

Adds all tensors stored in the values of input dictionary x.

Parameters:

x (Dict[int, torch.Tensor]) – A dictionary of [int, tensor] where a key is an index of a connection and a value is a tensor to be processed. This input is assumed to be non-empty and an exception will be raised if the assumption is violated.
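A minimal sketch of this summation in plain PyTorch, assuming all tensors share one shape; the function name additive_merge and the choice of ValueError for the empty case are illustrative assumptions, not the library's code.

```python
from typing import Dict

import torch


def additive_merge(x: Dict[int, torch.Tensor]) -> torch.Tensor:
    """Hypothetical sketch of AdditiveMerger.forward(): element-wise sum of all inputs."""
    if not x:
        raise ValueError("x must contain at least one connection's tensor.")
    # torch.stack requires identical shapes, matching the merger's assumption;
    # summing over the new leading axis adds the inputs element-wise.
    return torch.stack(list(x.values()), dim=0).sum(dim=0)
```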

class briann.network.core.Area(*args: Any, **kwargs: Any)[source]

Bases: Module

An area corresponds to a small population of neurons that jointly hold a representation in the area’s output_time_frame_accumulator(). Given a time-point t and a set S of areas that should be updated at t, the caller should update the areas’ states in two consecutive loops over S. The first loop should call the collect_inputs() method on each area to make it collect, merge (via its merger()) and buffer its inputs from the overall network. Then, in the second loop, the forward() method should be called on each area of S to pass the buffered inputs through the area’s transformation(). Because every area in S reads its inputs before any area in S changes its state, the result does not depend on the order in which the areas of S are processed, which allows the areas to be updated in parallel.

Parameters:
Raises:

ValueError – If the index is not a non-negative integer.
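The two-loop protocol can be illustrated with a toy model. In the hypothetical sketch below, ToyArea stands in for Area and its state is a single float; it shows that buffering all inputs before any state changes makes the outcome independent of the order of the areas in S.

```python
from typing import Dict, List, Optional


class ToyArea:
    """Hypothetical stand-in for an Area whose state is a single float."""

    def __init__(self, index: int, state: float, sources: Optional[List["ToyArea"]] = None):
        self.index = index
        self.state = state
        self.sources = sources or []
        self._buffer = 0.0

    def collect_inputs(self) -> None:
        # Phase 1: read (but do not change) the states of the sending areas.
        self._buffer = sum(area.state for area in self.sources)

    def forward(self) -> None:
        # Phase 2: replace the own state with the buffered input.
        self.state = self._buffer


def update(areas: List[ToyArea]) -> Dict[int, float]:
    for area in areas:  # first loop: every area buffers its inputs
        area.collect_inputs()
    for area in areas:  # second loop: every area applies its update
        area.forward()
    return {area.index: area.state for area in areas}
```

With two areas that feed each other, both orders of S swap the states. Interleaving the two calls in a single loop would let the second area read the already updated state of the first, so the result would then depend on the order.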

collect_inputs(current_time: float) None[source]

Calls the forward() method of all incoming connections to get the current inputs, merges them using the merger() and buffers the result for later use by the forward() method. If the merger sums its inputs, as the AdditiveMerger does, the inputs must all have the same shape.

Parameters:

current_time (float) – The current time of the simulation used to time-discount the states of the input areas.

Return type:

None

forward() None[source]

Assuming collect_inputs() has been run on all areas of the simulation just beforehand, this method passes the buffered inputs through the transformation() of self (if one exists) and passes the result to the accumulate() method of the output_time_frame_accumulator() of self.

property index: int
Returns:

The index used to identify this area in the overall model.

Return type:

int

property input_connections: Set[Connection]
Returns:

A set of Connection objects projecting to this area.

Return type:

Set[Connection]

property input_shape: int
Returns:

The shape of the input to this area for a single instance (i.e. excluding the batch-dimension that is assumed to be at index 0 of the actual input).

Return type:

int

property merger: Merger
Returns:

This merger is used to merge the input signals in the collect_inputs() method.

Return type:

Merger

property output_connections: Set[Connection]
Returns:

A set of Connection objects projecting from this area.

Return type:

Set[Connection]

property output_shape: int
Returns:

The shape of the output of this area for a single instance (i.e. excluding the batch-dimension that is assumed to be at index 0 of the actual output). The output is the state held in the output_time_frame_accumulator() and hence has the same shape.

Return type:

int

property output_time_frame_accumulator: TimeFrameAccumulator
Returns:

The time-frame accumulator of this area. This holds the output state of the area, which is made available to other areas via its output Connection objects.

Return type:

TimeFrameAccumulator

reset() None[source]

Resets the area to its initial state. This should be done every time a new trial is simulated.

property update_count: int
Returns:

The number of times this area has been updated during the simulation.

Return type:

int

property update_rate: float
Returns:

The update-rate of this area.

Return type:

float

class briann.network.core.BrIANN(*args: Any, **kwargs: Any)[source]

Bases: Module

This class functions as the network that holds together all its Area and Connection objects. Its name abbreviates Brain-Inspired Artificial Neural Networks. To use it, one should provide a configuration dictionary from which all components can be loaded. Then, for each batch, one should call load_next_stimulus_batch(). Once a batch is loaded, the processing can be simulated for as long as the caller intends (ideally at least for as long as the Source areas provide TimeFrame objects) using the step() method. The outputs of the Target areas can then be read with extract_current_output_batch(). To get a simplified networkx representation that contains information about the large-scale network topology (Area and Connection objects), one can use get_topology().

Parameters:

configuration (Dict[str, Any]) – A configuration in the form of a dictionary.
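The workflow above can be sketched as a small driver function. This is a hedged sketch: run_batch and duration are hypothetical names, and model is only assumed to expose the documented load_next_stimulus_batch(), step(), current_simulation_time and extract_current_output_batch() members.

```python
from typing import Any


def run_batch(model: Any, X: Any, duration: float, final_states_only: bool = True) -> Any:
    """Hypothetical driver: simulate one stimulus batch until `duration` is reached.

    `model` is assumed to behave like briann.network.core.BrIANN.
    """
    # Resets the simulation time and all areas, then hands X to the Source areas.
    model.load_next_stimulus_batch(X)
    # step() has no stopping condition of its own, so the caller bounds the run.
    # Assumes every step() advances current_simulation_time; otherwise this never ends.
    while model.current_simulation_time < duration:
        model.step()
    return model.extract_current_output_batch(final_states_only=final_states_only)
```

A time-based bound is only one option; a caller could equally stop after a fixed number of step() calls or once the returned set of updated areas contains a Target.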

_current_simulation_time
Returns:

The time that has passed since the start of the simulation. It is updated after each step of the simulation.

Return type:

float

property areas: torch.nn.ModuleList
Returns:

The areas held by self.

Return type:

torch.nn.ModuleList

property connections: torch.nn.ModuleList
Returns:

The Connection objects stored internally.

Return type:

torch.nn.ModuleList

property current_simulation_time: float
Returns:

The time that has passed since the start of the simulation. It is updated after each step of the simulation.

Return type:

float

extract_current_output_batch(final_states_only: bool = True) torch.Tensor | Dict[int, torch.Tensor][source]

Extracts the current output batch(es) held by the target area(s) of this model. This is done for each target area by either collecting the target area’s current states (if final_states_only is True) or the entire sequence of output frames (if final_states_only is False). After extraction, the target area’s internal buffer of time-frames is cleared.

Parameters:

final_states_only (bool, optional, defaults to True) – Indicates whether only the final, i.e. current output states shall be returned or the entire sequence.

Returns:

If there is only a single target area, the output will be a torch.Tensor. Otherwise, it will be a dictionary for which a key is a target area’s index and the corresponding value a tensor. Each such tensor is of shape [batch_size, time_frame_count, …], where … is the shape of the state of a single instance’s time-frame of the corresponding target area.

Return type:

torch.Tensor | Dict[int, torch.Tensor]

get_area_at_index(index: int) Area[source]
Returns:

The area with the given index.

Return type:

Area

Raises:

ValueError – If self does not store an area with the given index.

get_area_indices() Set[int][source]
Returns:

The set of indices of the areas stored internally.

Return type:

Set[int]

get_connections_from(area_index: int) Set[Connection][source]
Returns:

A set of Connection objects that are the output connections of the area with the given index.

Return type:

Set[Connection]

get_connections_to(area_index: int) Set[Connection][source]
Returns:

A set of Connection objects that are the input connections to the area with the given index.

Return type:

Set[Connection]

get_topology() networkx.DiGraph[source]

Converts the BrIANN network to a NetworkX DiGraph in which each node is the index() of the corresponding Area and each edge is the pair (u, v), where u is the from_area_index() and v the to_area_index() of the corresponding Connection.

Returns:

A NetworkX DiGraph representing the BrIANN network.

Return type:

nx.DiGraph

load_next_stimulus_batch(X: torch.Tensor | Dict[int, torch.Tensor]) None[source]

This method resets the current_simulation_time() and all areas. It also makes the Source areas load their corresponding next batch of stimuli. It thus assumes that all source areas have a valid data_loader() set and that the data loaders are in sync with each other and non-empty.

Parameters:

X (torch.Tensor | Dict[int, torch.Tensor]) – A tensor of shape [batch_size, time_frame_count, …] or a Dict[int, torch.Tensor] where the tensor’s first axis corresponds to instances in the batch and the second axis to time-frames. If a dictionary is provided, then each key is an index of a source area and the value is the corresponding input tensor.

Return type:

None

step() Set[Area][source]

Performs one step of the simulation by finding the set of areas due to be updated next and calling their collect_inputs() and forward() methods to make them process their inputs. This method needs to be called repeatedly to step through the simulation. The simulation has no internally checked stopping condition, so step() can be called indefinitely, even after the sources have run out of stimuli. The caller of this method thus needs to determine when to stop the simulation.

Returns:

The set of areas that were updated within this step.

Return type:

Set[Area]

class briann.network.core.Connection(*args: Any, **kwargs: Any)[source]

Bases: Module

A connection between two Area objects. This is analogous to a neural tract between areas of a biological neural network that not only transmits information but also converts it between the reference frames of the source and target area. It thus has a transformation() that is applied to the input before it is sent to the target area. For biological plausibility, the transformation should be a simple linear transformation, for instance a torch.nn.Linear layer.

Parameters:
forward(current_time: float) TimeFrame[source]

Reads the current state of the input_time_frame_accumulator() and applies the transformation() to it.

Parameters:

current_time (float) – The current time in the simulation.

Returns:

The produced time frame.

Return type:

TimeFrame

property from_area_index: int
Returns:

The index of the area that is the source of this connection.

Return type:

int

property index: int
Returns:

The index used to identify this connection in the overall model.

Return type:

int

property input_time_frame_accumulator: TimeFrameAccumulator
Returns:

The time frame accumulator that stores the input of the connection.

Return type:

TimeFrameAccumulator

property to_area_index: int
Returns:

The index of the area that is the target of this connection.

Return type:

int

class briann.network.core.ConnectionTransformation(*args: Any, **kwargs: Any)[source]

Bases: Module

Superclass for a transformation to be placed on a Connection. This default implementation only performs the identity transformation on its input.

Parameters:
Returns:

An instance of this class.

Return type:

ConnectionTransformation

forward(x)[source]
property input_shape: List[int]
Returns:

The shape of the input to this transformation, disregarding the initial batch axis.

Return type:

List[int]

property output_shape: List[int]
Returns:

The shape of the output of this transformation, disregarding the initial batch axis.

Return type:

List[int]

class briann.network.core.IndexBasedMerger(connection_index_to_input_flatten_axes: Dict[int, Tuple[int, int]], connection_index_to_output_indices: Dict[int, List[int]], output_flatten_axes: Tuple[int, int], final_output_shape: List[int])[source]

Bases: Merger

Maps the dimensions of the input tensors given to the forward() method to the output tensor.

Parameters:
property connection_index_to_input_flatten_axes: Dict[int, Tuple[int, int]]
Returns:

A dictionary mapping the index() to a Tuple of two axes, namely the start_axis and end_axis (inclusive). Axes spanned by the start_axis and end_axis are the axes along which the given input tensor will be flattened before its dimensions are mapped onto the output tensor. The calculation of these axes DOES include the initial batch axis and is thus assumed to be greater than 0.

Return type:

Dict[int, Tuple[int,int]]

property connection_index_to_output_indices: Dict[int, List[int]]
Returns:

A dictionary mapping the index() of each Connection to a list of indices. These indices specify where the entries of the connection’s flattened input torch.Tensor shall be moved to in the flattened output tensor. Importantly, the indices must all be unique and, when joined across all connections and sorted, form a contiguous range starting at 0. The total length of the joined list must therefore equal the size of the flattened output axis, i.e. the product of final_output_shape() along output_flatten_axes().

Return type:

Dict[int, List[int]]

property final_output_shape: List[int]
Returns:

The final output shape after unflattening the output tensor along the specified output_flatten_axes. This shape does NOT include the initial batch axis, acknowledging that the batch size is not necessarily known during configuration of this object.

Return type:

List[int]

forward(x: Dict[int, torch.Tensor]) torch.Tensor[source]

For each input tensor in x, this method first flattens the input along the axes specified during initialization and then moves its entries, in their existing order, to the indices of the output tensor that were specified during initialization. At this point, the output tensor is flat along the axes specified during initialization. Thereafter, the output tensor is reshaped along those axes to reach its final output shape specified during initialization. It is thus assumed that the input tensors and the output tensor all have the same shape along the remaining axes.

For example, assume the inputs have shapes

  • first input: [batch_size, 2, 4, 5] which will be flattened to [batch_size, 2, 20]

  • second input: [batch_size, 2, 10] which will be “flattened” to [batch_size, 2, 10].

Then, assume the final output shape is [batch_size] + [2,10,3] and output_flatten_axes is (2, 3), so that the output is flattened to [batch_size] + [2,30].

The mapping is performed on the flattened tensors. For simplicity, say the 20 dimensions of the first input’s axis 2 are mapped to the first 20 dimensions of the flat output’s axis 2 and the 10 dimensions of the second input’s axis 2 are mapped to the last 10 dimensions of the flat output’s axis 2. This is only possible because the remaining axes (here, the leading batch_size axis and axis 1) have the same dimensionalities for all inputs and the output tensor y. Finally, y is reshaped to its final output shape [batch_size] + [2,10,3] and returned.

Parameters:

x (Dict[int, torch.Tensor]) – A dictionary of [int, tensor] where a key is an index of a connection and a value is a tensor to be processed. This input is assumed to be non-empty and an exception will be raised if the assumption is violated.
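The worked example above can be reproduced with a minimal PyTorch sketch. The function index_based_merge is hypothetical and not the library's implementation; it assumes that each flattened input's mapped axis coincides with the first of the output_flatten_axes (axis 2 in the example) and uses Tensor.index_copy_ to place the entries.

```python
import math
from typing import Dict, List, Tuple

import torch


def index_based_merge(
    x: Dict[int, torch.Tensor],
    input_flatten_axes: Dict[int, Tuple[int, int]],
    output_indices: Dict[int, List[int]],
    output_flatten_axes: Tuple[int, int],
    final_output_shape: List[int],
) -> torch.Tensor:
    """Hypothetical sketch of IndexBasedMerger.forward()."""
    if not x:
        raise ValueError("x must not be empty.")
    out_start, out_end = output_flatten_axes
    # Size of the flat output axis: product of final_output_shape over the flatten
    # axes (final_output_shape excludes the batch axis, hence the -1).
    flat_size = math.prod(final_output_shape[out_start - 1 : out_end])
    y = None
    for conn, tensor in x.items():
        start, end = input_flatten_axes[conn]
        flat = tensor.flatten(start_dim=start, end_dim=end)
        if y is None:
            # Allocate the flat output once the batch size is known from the data.
            shape = list(flat.shape)
            shape[out_start] = flat_size
            y = flat.new_zeros(shape)
        index = torch.tensor(output_indices[conn], dtype=torch.long, device=flat.device)
        # Copy the flattened entries, in their existing order, to the target positions.
        y.index_copy_(out_start, index, flat)
    # Unflatten to the final output shape (batch axis prepended); row-major reshape.
    return y.reshape([y.shape[0]] + list(final_output_shape))
```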

property output_flatten_axes: Tuple[int, int]
Returns:

The first axis and last axis (inclusive) along which the output tensor shall initially be flattened while mapping the inputs onto it. The calculation of these axes DOES include the initial batch axis and is thus assumed to be greater than 0.

Return type:

Tuple[int,int]

class briann.network.core.LinearConnectionTransformation(*args: Any, **kwargs: Any)[source]

Bases: ConnectionTransformation

A simple linear transformation to be placed on a Connection. The input is first flattened along all axes except the first (batch axis) and then passed through a regular torch.nn.Linear layer before being reshaped to fit the output shape. Construction of this object allows for keyword arguments that further configure the linear transformation taken from torch, i.e. bias, device, dtype.
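The flatten, Linear, reshape pattern can be sketched as follows. FlattenLinearReshape is a hypothetical name, and the sketch omits the shape-update callback that rebuilds the parameters when input_shape() or output_shape() changes; extra keyword arguments are forwarded to torch.nn.Linear, whose accepted keywords include bias, device and dtype.

```python
import math
from typing import List

import torch
from torch import nn


class FlattenLinearReshape(nn.Module):
    """Hypothetical sketch of a flatten -> nn.Linear -> reshape transformation."""

    def __init__(self, input_shape: List[int], output_shape: List[int], **linear_kwargs):
        super().__init__()
        self.output_shape = list(output_shape)
        # Extra keyword arguments (e.g. bias, device, dtype) go straight to nn.Linear.
        self.linear = nn.Linear(math.prod(input_shape), math.prod(output_shape), **linear_kwargs)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        flat = x.flatten(start_dim=1)  # keep batch axis 0, flatten everything else
        out = self.linear(flat)  # shape [batch_size, prod(output_shape)]
        return out.reshape([x.shape[0]] + self.output_shape)
```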

Parameters:
_on_update_shape(name: str, value: List[int]) None[source]

This method is a callback that adjusts the model parameters of the transformation of self whenever the input_shape() or output_shape() is updated.

Parameters:
  • obj (ConnectionTransformation) – The object on which the input_shape is updated.

  • name (str) – The name of the attribute (i.e. ‘input_shape’ or ‘output_shape’) to be updated.

  • value (List[int]) – The new shape of the input or output, disregarding the batch_size as axis 0.

Return type:

None

forward(x: torch.Tensor) torch.Tensor[source]
class briann.network.core.Merger[source]

Bases: Adapter

This class (and in particular its forward method) is to be used inside the collect_inputs() method to merge all the collected inputs into one torch.Tensor.

abstractmethod forward(x: Dict[int, torch.Tensor]) torch.Tensor[source]
Parameters:

x (Dict[int, torch.Tensor]) – A dictionary of [int, tensor] where a key is an index of a connection and a value is a tensor to be processed. This input is assumed to be non-empty and an exception will be raised if the assumption is violated.

Returns:

A single tensor that is the combination of the values of x.

Return type:

torch.Tensor

class briann.network.core.Source(*args: Any, **kwargs: Any)[source]

Bases: Area

The source Area is a special area because it streams the input to the other areas. To set it up for the simulation of a trial, load stimuli via load_next_stimulus_batch(). Each call to collect_inputs() then takes one TimeFrame from the stimuli and holds it in a buffer. Upon calling the forward() method, that time-frame is placed in the output_time_frame_accumulator(), so that it can be read by other areas. Once all time-frames have been streamed, the source area no longer adds new time-frames to the accumulator, and its representation simply decays over time.

Parameters:
  • index (int) – Sets the index of this area.

  • output_time_frame_accumulator (TimeFrameAccumulator) – Sets the output_time_frame_accumulator() of this area.

  • output_shape (List[int]) – Sets the output_shape() of this area.

  • output_connections (Dict[int, Connection]) – Sets the output_connections() of this area.

  • update_rate (float) – Sets the update_rate() of this area.

collect_inputs(current_time: float) None[source]

Pops the next TimeFrame from stimulus_batch() or generates an array of zeros if the stimulus stream is over. Either way, the result is buffered internally to be made available to other areas upon calling forward().

Parameters:

current_time (float) – The current time of the simulation.

Raises:

ValueError – if the current_time is not equal to the time of the popped TimeFrame.

Return type:

None

load_next_stimulus_batch(X: torch.Tensor) None[source]

This method loads the next batch of stimuli that will be streamed to the other model areas during the simulation.

Parameters:

X (torch.Tensor) – A tensor of shape [batch_size, time_frame_count, …] where the first axis corresponds to instances in the batch and the second axis to time-frames.

Raises:
  • Exception – if self.data_loader is None.

  • StopIteration – if the data_loader is empty.

property stimulus_batch: Deque[TimeFrame]

The stimuli that are currently loaded in the source area. This is a deque of TimeFrame objects that are to be processed by the model.

Returns:

The stimuli.

Return type:

Deque[TimeFrame]

class briann.network.core.Target(*args: Any, **kwargs: Any)[source]

Bases: Area

This class is a subclass of Area and has the same functionality as a regular area except that it has no output connections.

Parameters:
extract_current_output_batch(final_states_only: bool = True) torch.Tensor[source]

Extracts the current output batch held by this target area. This is done by collecting all time-frames obtained by forward() and held in an internal buffer, and stacking them into a tensor that is returned. After extraction, the internal buffer of time-frames is cleared.

Parameters:

final_states_only (bool, optional, defaults to True) – Indicates whether only the final, i.e. current output states shall be returned or the entire sequence.

Returns:

The output batch held by this target area. This is a tensor of shape [batch_size, time_frame_count, …], where … is the shape of the state of a single instance’s time-frame of this area.

Return type:

torch.Tensor

forward() None[source]

Assuming collect_inputs() has been run on all areas of the simulation just beforehand, this method passes the buffered inputs through the transformation() of self (if one exists) and passes the result to the accumulate() method of the output_time_frame_accumulator() of self.

property output_connections: Set[Connection]
Returns:

A set of Connection objects projecting from this area.

Return type:

Set[Connection]

class briann.network.core.TimeFrame(state: torch.Tensor, time_point: float)[source]

Bases: object

A time-frame in the simulation that holds a temporary state of an Area.

Parameters:
  • state (torch.Tensor) – Sets the state of this time frame.

  • time_point (float) – Sets the time_point of this time frame.

property state: torch.Tensor
Returns:

The state of the time frame. This is a torch.Tensor, for instance of shape [instance_count, dimensionality].

Return type:

torch.Tensor

property time_point: float
Returns:

The time point at which this time frame’s state() occurred.

Return type:

float

class briann.network.core.TimeFrameAccumulator(initial_time_frame: TimeFrame, decay_rate: float)[source]

Bases: object

This class is used to accumulate TimeFrame objects. Accumulation happens by adding new time-frames into the accumulator’s own time-frame using the accumulate() function. An important feature of the accumulator is that during every update, the currently stored information decays according to the provided decay_rate and the time since the last update. This is done to ensure that older information has less influence on the current state of the accumulator than new information.

Parameters:
  • initial_time_frame (TimeFrame) – Sets the initial_time_frame and time_frame of this time frame accumulator.

  • decay_rate (float) – Sets the decay_rate() property of self.

accumulate(time_frame: TimeFrame) None[source]

Sets the state() of the time_frame() of self equal to the weighted sum of the state of the new time_frame and the state of the current time-frame of self. The weight for the old state is w = decay_rate^dt, where dt is the time_point() of the provided time_frame minus that of the time-frame currently held by self. The weight for the new time_frame is simply equal to 1. This method also sets the time_point() of the time-frame of self equal to that of the new time_frame.

Parameters:

time_frame (TimeFrame) – The new time-frame to be added to the time_frame() of self.

Raises:
  • ValueError – If the state of time_frame does not have the same shape as that of the current time-frame of self.

  • ValueError – If the time-point of time_frame is earlier than that of the current time-frame of self.

Returns:

None
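The update rule of accumulate() can be written compactly, with λ denoting decay_rate. The unrolled form and the read-out at a later time assume that time_frame() discounts with the same λ^dt law; this is an assumption, since only accumulate() states the weight explicitly.

```latex
% \lambda = decay_rate; (s, t) = state and time_point held by the accumulator;
% (s_{\mathrm{in}}, t_{\mathrm{in}}) = state and time_point of the new time_frame.
s \leftarrow \lambda^{\,t_{\mathrm{in}} - t}\, s + s_{\mathrm{in}}, \qquad t \leftarrow t_{\mathrm{in}}

% Unrolling over the initial frame (k = 0) and all accumulated frames (s_k, t_k),
% read out at t_{\mathrm{cur}} \ge t_n:
s(t_{\mathrm{cur}}) = \sum_{k=0}^{n} \lambda^{\,t_{\mathrm{cur}} - t_k}\, s_k

% Example: \lambda = 0.5, held s = 4 at t = 1, new s_{\mathrm{in}} = 1 at t_{\mathrm{in}} = 3:
s \leftarrow 0.5^{2} \cdot 4 + 1 = 2
```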

property decay_rate: float
Returns:

The rate, taken from the interval [0,1], at which the energy of the state() of time_frame() decays as time passes. This rate is recommended to lie in the open interval (0,1) to obtain true exponential decay. If set to 1, there is no decay; if set to 0, there is no memory. See accumulate() for details.

Return type:

float

reset(initial_time_frame: TimeFrame = None) None[source]

Resets the time_frame() of self. If initial_time_frame is provided, then this one will be used for reset and saved in initial_time_frame(). Otherwise, the one provided during construction will be used.

Parameters:

initial_time_frame (TimeFrame, optional, defaults to None.) – The time-frame to be used to set time_frame() and initial_time_frame() of self.

time_frame(current_time: float) TimeFrame[source]

Provides a TimeFrame that holds the time-discounted sum of all TimeFrame objects added via the accumulate() method.

Parameters:

current_time (float) – The current time, used to discount the state of self.

Raises:

ValueError – If current_time is earlier than the time-point of the current time-frame of self.

Returns:

The time-discounted time-frame of this accumulator.

Return type:

TimeFrame
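To tie accumulate() and time_frame() together, here is a minimal sketch of the decay semantics. The classes Frame and DecayingAccumulator are hypothetical stand-ins for TimeFrame and TimeFrameAccumulator, and the discounting inside time_frame() is assumed to follow the same decay_rate^dt law as accumulate().

```python
from dataclasses import dataclass

import torch


@dataclass
class Frame:
    """Hypothetical stand-in for TimeFrame: a state tensor stamped with a time point."""
    state: torch.Tensor
    time_point: float


class DecayingAccumulator:
    """Hypothetical sketch of the decay semantics of TimeFrameAccumulator."""

    def __init__(self, initial_time_frame: Frame, decay_rate: float):
        if not 0.0 <= decay_rate <= 1.0:
            raise ValueError("decay_rate must lie in [0, 1].")
        self.decay_rate = decay_rate
        self._frame = Frame(initial_time_frame.state.clone(), initial_time_frame.time_point)

    def _discounted(self, time_point: float) -> torch.Tensor:
        dt = time_point - self._frame.time_point
        if dt < 0:
            raise ValueError("time_point is earlier than the held time-frame.")
        # Older information is down-weighted by decay_rate ** dt.
        return self._frame.state * (self.decay_rate ** dt)

    def accumulate(self, time_frame: Frame) -> None:
        if time_frame.state.shape != self._frame.state.shape:
            raise ValueError("State shapes must match.")
        # New frame enters with weight 1; the held state is decayed up to its time.
        state = self._discounted(time_frame.time_point) + time_frame.state
        self._frame = Frame(state, time_frame.time_point)

    def time_frame(self, current_time: float) -> Frame:
        # Read-out: the held state discounted to current_time (assumed decay law).
        return Frame(self._discounted(current_time), current_time)
```

With decay_rate 0.5, a held state of 4 at time 1 and a new frame of 1 at time 3 accumulate to 0.5² · 4 + 1 = 2, which decays to 1 by time 4.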

Module contents