briann.network package
Submodules
briann.network.area_transformations module
briann.network.connection_transformations module
- class briann.network.connection_transformations.IndexBasedSplitter(input_flatten_axes: Tuple[int, int], input_indices: List[int], output_flatten_axes: Tuple[int, int], output_shape: List[int])
Bases: Module
The IndexBasedSplitter maps its input to its output using indices in forward(). To configure this mapping, the following arguments are used.
- Parameters:
input_flatten_axes (Tuple[int,int]) – Sets the input_flatten_axes() of this object.
input_indices (List[int]) – Sets the input_indices() of this object.
output_flatten_axes (Tuple[int,int]) – Sets the output_flatten_axes() of this object.
output_shape (List[int]) – Sets the output_shape() of this object.
- forward(x: Tensor) → Tensor
This method first flattens the input tensor x along input_flatten_axes(). It then extracts the corresponding entries using the indices from input_indices(). Finally, it unflattens the output along the axes specified in output_flatten_axes() to arrive at the output shape specified in output_shape().
- property input_flatten_axes: Tuple[int, int]
- Returns:
The axes along which the input is flattened inside forward() before entries are selected from it. This is a Tuple of two ints: the first int is the axis at which flattening starts and the second int is the axis (inclusive) at which it ends.
- Return type:
Tuple[int, int]
- property input_indices: Dict[int, List[int]]
- Returns:
A list of indices that specify which entries of the flattened input torch.Tensor are placed, in order, in the flattened output tensor. The length of this list must factor into output_shape() along the output_flatten_axes().
- Return type:
List[int]
- property output_flatten_axes: Tuple[int, int]
- Returns:
A Tuple of two axes, namely the start_axis and end_axis (inclusive). The axes spanned by start_axis and end_axis are the axes along which the output tensor is flattened before its entries are filled with the corresponding ones from the input tensor. These axes are counted INCLUDING the initial batch axis and are thus assumed to be greater than 0.
- Return type:
Tuple[int,int]
- property output_shape: List[int]
- Returns:
The desired shape of the torch.Tensor that the forward() method outputs.
- Return type:
List[int]
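The flatten, select and unflatten steps of IndexBasedSplitter.forward() can be sketched in plain Python. This sketch works on a single instance whose flattened axes already form one flat list. The helper name split_by_indices is hypothetical, and nested lists stand in for tensors so that the example runs without torch. It is not the library's implementation.

```python
from math import prod
from typing import List, Sequence


def split_by_indices(flat_input: Sequence[float], input_indices: List[int], output_shape: List[int]):
    """Hypothetical helper: select entries of a flat input and reshape them.

    Assumes one instance whose relevant axes are already flattened into a
    single list (no batch axis), which simplifies the real tensor case.
    """
    if len(input_indices) != prod(output_shape):
        raise ValueError("len(input_indices) must equal the product of output_shape")
    # Gather the selected entries in the order given by input_indices.
    selected = [flat_input[i] for i in input_indices]
    # Unflatten in row-major order by chunking from the innermost axis outwards.
    nested = selected
    for size in reversed(output_shape[1:]):
        nested = [nested[k:k + size] for k in range(0, len(nested), size)]
    return nested
```

For example, selecting indices [5, 0, 3, 1] from [10, 11, 12, 13, 14, 15] with output_shape [2, 2] gives [[15, 10], [13, 11]]. A tensor version would presumably use torch.flatten, torch.index_select and torch.unflatten for the same three steps.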
- class briann.network.connection_transformations.Splitter
Bases: Adapter
This class is to be used in the forward() method. It takes an input x that is the current state() from the sending output_time_frame_accumulator and then splits off the part that is relevant to the calling Connection object.
briann.network.core module
This module collects all necessary components to build a BrIANN model.
- class briann.network.core.AdditiveMerger
Bases: Merger
This merger adds all inputs.
- forward(x: Dict[int, torch.Tensor]) → torch.Tensor
Adds all tensors stored in the values of the input dictionary x.
- Parameters:
x (Dict[int, torch.Tensor]) – A dictionary of [int, tensor] where a key is an index of a connection and a value is a tensor to be processed. This input is assumed to be non-empty and an exception will be raised if the assumption is violated.
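The behaviour described for AdditiveMerger.forward() can be approximated in a few lines. The function additive_merge is a hypothetical stand-in, and raising ValueError on empty input is an assumption, since the documentation only says that "an exception" is raised. operator.add works on plain numbers and also on same-shaped torch tensors, so the same reduction would apply to real inputs.

```python
import operator
from functools import reduce
from typing import Dict, TypeVar

T = TypeVar("T")


def additive_merge(x: Dict[int, T]) -> T:
    """Hypothetical stand-in for AdditiveMerger.forward: sum all values of x."""
    if not x:
        # The exception type is an assumption; the docs only promise "an exception".
        raise ValueError("x must contain at least one connection's input")
    # Summation is order-independent, so the connection indices (keys) are ignored.
    return reduce(operator.add, x.values())
```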
- class briann.network.core.Area(*args: Any, **kwargs: Any)
Bases: Module
An area corresponds to a small population of neurons that jointly hold a representation in the area’s output_time_frame_accumulator(). Given a time-point t and a set S of areas that should be updated at t, the caller should update the areas’ states in two consecutive loops over S. The first loop should call the collect_inputs() method on each area to make it collect, sum and buffer its inputs from the overall network. Then, in the second loop, the forward() method should be called on each area of S to pass the buffered inputs through the area’s transformation(). Separating input collection from the forward transformation allows areas to be updated in parallel.
- Parameters:
index (int) – Sets the index of this area.
output_time_frame_accumulator (TimeFrameAccumulator) – Sets the output_time_frame_accumulator() of self.
input_connections (List[Connection]) – Sets the input_connections() of this area.
input_shape (List[int]) – Sets the input_shape() of this area.
output_shape (List[int]) – Sets the output_shape() of this area.
output_connections (List[Connection]) – Sets the output_connections() of this area.
transformation (torch.nn.Module) – Sets the transformation() of this area.
update_rate (float) – Sets the update_rate() of this area.
- Raises:
ValueError – If the index is not a non-negative integer.
- collect_inputs(current_time: float) → None
Calls the forward() method of all incoming connections to get the current inputs, sums them up and buffers the result for later use by the forward() method. Since the inputs are summed, they must all have the same shape.
- Parameters:
current_time (float) – The current time of the simulation, used to time-discount the states of the input areas.
- Return type:
None
- forward() → None
Assuming collect_inputs() has been run on all areas of the simulation just beforehand, this method passes the buffered inputs through the transformation() of self (if it exists) and passes the result to accumulate() of self.
- property index: int
- Returns:
The index used to identify this area in the overall model.
- Return type:
int
- property input_connections: Set[Connection]
- Returns:
A set of Connection objects projecting to this area.
- Return type:
Set[Connection]
- property input_shape: int
- Returns:
The shape of the input to this area for a single instance (i.e. excluding the batch-dimension that is assumed to be at index 0 of the actual input).
- Return type:
int
- property merger: Merger
- Returns:
The merger used to merge the input signals in the collect_inputs() method.
- Return type:
Merger
- property output_connections: Set[Connection]
- Returns:
A set of Connection objects projecting from this area.
- Return type:
Set[Connection]
- property output_shape: int
- Returns:
The shape of the output of this area for a single instance (i.e. excluding the batch-dimension that is assumed to be at index 0 of the actual output). The output is the state held in the output_time_frame_accumulator() and hence has the same shape.
- Return type:
int
- property output_time_frame_accumulator: TimeFrameAccumulator
- Returns:
The time-frame accumulator of this area. It holds the output state of the area, which is made available to other areas via Connection objects.
- Return type:
TimeFrameAccumulator
- reset() → None
Resets the area to its initial state. This should be done every time a new trial is simulated.
- property update_count: int
- Returns:
Counts how many times this area was updated during the simulation.
- Return type:
int
- property update_rate: float
- Returns:
The update-rate of this area.
- Return type:
float
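The two-loop update described for Area can be shown with a toy model. ToyArea and update are hypothetical names. Each area holds a scalar state, and forward() simply copies the buffered sum. These simplifications are assumptions made only so that the effect of the two phases is easy to see.

```python
from typing import Dict, List


class ToyArea:
    """Hypothetical stand-in for Area with a scalar state."""

    def __init__(self, index: int, sources: List[int], state: float = 0.0):
        self.index = index
        self.sources = sources  # indices of areas that project to this one
        self.state = state
        self._buffer = 0.0

    def collect_inputs(self, areas: Dict[int, "ToyArea"]) -> None:
        # Phase 1: read and sum the current states of the input areas, then buffer them.
        self._buffer = sum(areas[i].state for i in self.sources)

    def forward(self) -> None:
        # Phase 2: apply the (here trivial, identity) transformation to the buffer.
        self.state = self._buffer


def update(areas: Dict[int, ToyArea], due: List[int]) -> None:
    """Update the areas listed in `due` using two consecutive loops."""
    for i in due:
        areas[i].collect_inputs(areas)
    for i in due:
        areas[i].forward()
```

With two areas that read each other, update() swaps their states. A single loop that interleaved both calls would let the second area read the first area's already updated state. Because every area in the first loop reads only states written before that loop started, the order within `due` does not matter, which is what makes input collection parallelizable.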
- class briann.network.core.BrIANN(*args: Any, **kwargs: Any)
Bases: Module
This class functions as the network that holds together all its Area and Connection objects. Its name abbreviates Brain Inspired Artificial Neural Networks. To use it, one should provide a configuration dictionary from which all components can be loaded. Then, for each batch, one should call load_next_stimulus_batch(). Once a batch is loaded, the processing can be simulated with the step() method for as long as the caller intends (ideally at least for as long as the Source areas provide TimeFrame objects). To get a simplified networkx representation that contains the large-scale network topology (Area and Connection objects), one can use get_topology().
- Parameters:
configuration (Dict[str, Any]) – A configuration in the form of a dictionary.
- _current_simulation_time
- Returns:
The time that has passed since the start of the simulation. It is updated after each step of the simulation.
- Return type:
float
- property areas: torch.nn.ModuleList
- Returns:
The set of areas held by self.
- Return type:
torch.nn.ModuleList
- property connections: torch.nn.ModuleList
- Returns:
The set of internally stored Connection objects.
- Return type:
torch.nn.ModuleList
- property current_simulation_time: float
- Returns:
The time that has passed since the start of the simulation. It is updated after each step of the simulation.
- Return type:
float
- extract_current_output_batch(final_states_only: bool = True) → torch.Tensor | Dict[int, torch.Tensor]
Extracts the current output batch(es) held by the target area(s) of this model. For each target area, this method collects either the target area’s current states (if final_states_only is True) or the entire sequence of output frames (if final_states_only is False). After extraction, the target area’s internal buffer of time-frames is cleared.
- Parameters:
final_states_only (bool, optional, defaults to True) – Indicates whether only the final, i.e. current, output states shall be returned or the entire sequence.
- Returns:
If there is only a single target area, the output is a torch.Tensor. Otherwise, it is a dictionary in which a key is a target area’s index and the corresponding value is a tensor. Each such tensor is of shape [batch_size, time_frame_count, …], where … is the shape of the state of a single instance’s time-frame of the corresponding target area.
- Return type:
torch.Tensor | Dict[int, torch.Tensor]
- get_area_at_index(index: int) → Area
- Returns:
The area with the given index.
- Return type:
Area
- Raises:
ValueError – If self does not store an area with the given index.
- get_area_indices() → Set[int]
- Returns:
The set of indices of the areas stored internally.
- Return type:
Set[int]
- get_connections_from(area_index: int) → Set[Connection]
- Returns:
A set of Connection objects that are the output connections of the area with the given index.
- Return type:
Set[Connection]
- get_connections_to(area_index: int) → Set[Connection]
- Returns:
A set of Connection objects that are the input connections to the area with the given index.
- Return type:
Set[Connection]
- get_topology() → networkx.DiGraph
Converts the BrIANN network to a NetworkX DiGraph. Each node is simply the index() of a corresponding Area, and each edge is simply the pair (u, v), where u is the from_area_index() and v is the to_area_index() of the corresponding Connection.
- Returns:
A NetworkX DiGraph representing the BrIANN network.
- Return type:
nx.DiGraph
- load_next_stimulus_batch(X: torch.Tensor | Dict[int, torch.Tensor]) → None
This method resets the current_simulation_time() and all areas. It also makes the Source areas load their corresponding next batch of stimuli. It thus assumes that all source areas have a valid data_loader() set and that the data loaders are in sync with each other and non-empty.
- Parameters:
X (torch.Tensor | Dict[int, torch.Tensor]) – A tensor of shape [batch_size, time_frame_count, …] or a Dict[int, torch.Tensor], where the tensor’s first axis corresponds to instances in the batch and the second axis to time-frames. If a dictionary is provided, each key is an index of a source area and the value is the corresponding input tensor.
- Return type:
None
- step() → Set[Area]
Performs one step of the simulation. It finds the set of areas due to be updated next and calls their collect_inputs() and forward() methods to make them process their inputs. This method needs to be called repeatedly to step through the simulation. The simulation has no internally checked stopping condition, so this method can be called indefinitely, even after the sources have run out of stimuli. The caller of this method therefore needs to determine when to stop the simulation.
- Returns:
The set of areas that were updated within this step.
- Return type:
Set[Area]
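Because step() never stops on its own, the caller owns the simulation loop. The sketch below shows one way to drive a single batch, assuming that `model` exposes load_next_stimulus_batch(), step() and extract_current_output_batch() as documented above. The function name run_trial and the fixed num_steps budget are hypothetical choices. A real caller might instead stop based on the stimulus length or the simulation time.

```python
from typing import Any


def run_trial(model: Any, X: Any, num_steps: int, final_states_only: bool = True) -> Any:
    """Hypothetical driver: simulate one batch for a fixed number of steps.

    `model` is assumed to provide the three BrIANN methods used below.
    """
    # Resets the simulation time and all areas, then hands X to the sources.
    model.load_next_stimulus_batch(X)
    for _ in range(num_steps):
        # Each call updates the set of areas due next; the stopping rule is ours.
        model.step()
    return model.extract_current_output_batch(final_states_only=final_states_only)
```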
- class briann.network.core.Connection(*args: Any, **kwargs: Any)
Bases: Module
A connection between two Area objects. This is analogous to a neural tract between areas of a biological neural network that not only sends information but also converts it between the reference frames of the input and output area. It thus has a transformation() that is applied to the input before it is sent to the target area. For biological plausibility, the transformation should be a simple linear transformation, for instance a torch.nn.Linear layer.
- Parameters:
index (int) – Sets the index of this connection.
from_area_index (int) – Sets the from_area_index() of this connection.
to_area_index (int) – Sets the to_area_index() of this connection.
input_time_frame_accumulator (TimeFrameAccumulator) – Sets the input_time_frame_accumulator() of self.
transformation (ConnectionTransformation) – Sets the transformation() of the connection.
- forward(current_time: float) → TimeFrame
Reads the current state of the input_time_frame_accumulator() and applies the transformation() to it.
- Parameters:
current_time (float) – The current time in the simulation.
- Returns:
The produced time frame.
- Return type:
TimeFrame
- Returns:
The index of the area that is the source of this connection.
- Return type:
int
- property index: int
- Returns:
The index used to identify this connection in the overall model.
- Return type:
int
- property input_time_frame_accumulator: TimeFrameAccumulator
- Returns:
The time frame accumulator that stores the input of the connection.
- Return type:
TimeFrameAccumulator
- Returns:
The index of the area that is the target of this connection.
- Return type:
int
- class briann.network.core.ConnectionTransformation(*args: Any, **kwargs: Any)
Bases: Module
Superclass for a transformation to be placed on a Connection. This default implementation only performs the identity transformation on its input.
- Parameters:
input_shape (List[int]) – Sets the input_shape() property.
output_shape (List[int]) – Sets the output_shape() property.
- Returns:
An instance of this class.
- Return type:
ConnectionTransformation.
- property input_shape: List[int]
- Returns:
The shape of the input to this transformation, disregarding the initial batch axis.
- Return type:
List[int]
- property output_shape: List[int]
- Returns:
The shape of the output of this transformation, disregarding the initial batch axis.
- Return type:
List[int]
- class briann.network.core.IndexBasedMerger(connection_index_to_input_flatten_axes: Dict[int, Tuple[int, int]], connection_index_to_output_indices: Dict[int, List[int]], output_flatten_axes: Tuple[int, int], final_output_shape: List[int])
Bases: Merger
Maps the dimensions of the input tensors given to the forward() method to the output tensor.
- Parameters:
connection_index_to_input_flatten_axes (Dict[int, Tuple[int,int]]) – Sets the connection_index_to_input_flatten_axes() property of this instance.
connection_index_to_output_indices (Dict[int, List[int]]) – Sets the connection_index_to_output_indices() property of this instance.
output_flatten_axes (Tuple[int,int]) – Sets the output_flatten_axes() property of this instance.
final_output_shape (List[int]) – Sets the final_output_shape() property of this instance.
- property connection_index_to_input_flatten_axes: Dict[int, Tuple[int, int]]
- Returns:
A dictionary that maps a connection’s index() to a Tuple of two axes, namely the start_axis and end_axis (inclusive). The axes spanned by start_axis and end_axis are the axes along which the given input tensor is flattened before its dimensions are mapped onto the output tensor. These axes are counted INCLUDING the initial batch axis and are thus assumed to be greater than 0.
- Return type:
Dict[int, Tuple[int,int]]
- property connection_index_to_output_indices: Dict[int, List[int]]
- Returns:
A dictionary that maps a connection’s index() to a list of indices. These indices specify where the dimensions of the connection’s flattened input torch.Tensor are moved to in the flattened output tensor. The indices must all be unique and, when joined and sorted, must form a contiguous list starting at 0. The length of this joined list must factor into the final_output_shape() along the output_flatten_axes().
- Return type:
Dict[int, List[int]]
- property final_output_shape: List[int]
- Returns:
The final output shape after unflattening the output tensor along the specified output_flatten_axes. This shape does NOT include the initial batch axis, because the batch size is not necessarily known when this object is configured.
- Return type:
List[int]
- forward(x: Dict[int, torch.Tensor]) → torch.Tensor
For each input tensor in x, this method first flattens the input along the axes specified during initialization. It then moves the dimensions, in their existing order, to the indices of the output tensor that were specified during initialization. At this point, the output tensor is flat along the axes specified during initialization. Thereafter, the output tensor is reshaped along those axes to reach the final output shape specified during initialization. It is thus assumed that the input tensors and the output tensor all have the same shape along the remaining axes.
For example, assume the inputs have the following shapes:
first input: [batch_size, 2, 4, 5], which will be flattened to [batch_size, 2, 20];
second input: [batch_size, 2, 10], which will be “flattened” to [batch_size, 2, 10].
Then, assume the final output shape is [batch_size] + [2,10,3], which will be flattened to [batch_size] + [2,30].
The mapping is performed on the flattened tensors. For simplicity, say the 20 dimensions of the first input’s axis 2 are mapped to the first 20 dimensions of the flat output’s axis 2, and the 10 dimensions of the second input’s axis 2 are mapped to the last 10 dimensions of the flat output’s axis 2. This is only possible because the remaining axes (here, the leading batch_size axis and axis 1) have the same dimensionalities for all inputs and the output y. Finally, y is reshaped to its final output shape [batch_size] + [2,10,3] and returned.
- Parameters:
x (Dict[int, torch.Tensor]) – A dictionary of [int, tensor] where a key is an index of a connection and a value is a tensor to be processed. This input is assumed to be non-empty and an exception will be raised if the assumption is violated.
- property output_flatten_axes: Tuple[int, int]
- Returns:
The first axis and last axis (inclusive) along which the output tensor shall initially be flattened while mapping the inputs onto it. The calculation of these axes DOES include the initial batch axis and is thus assumed to be greater than 0.
- Return type:
Tuple[int,int]
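The index constraint and the core mapping step of IndexBasedMerger can be sketched without tensors. In this simplified version, each value of x is one already-flattened row, and the batch axis and the other untouched axes are dropped. The function names validate_output_indices and merge_flat are hypothetical, and the 0.0 fill for positions that no connection supplies is an assumption.

```python
from typing import Dict, List, Sequence


def validate_output_indices(index_map: Dict[int, List[int]]) -> int:
    """Check that all indices are unique and, joined and sorted, equal 0..n-1.

    Returns n, the length of the flattened output axis.
    """
    joined = sorted(i for indices in index_map.values() for i in indices)
    if joined != list(range(len(joined))):
        raise ValueError("output indices must be unique and contiguous from 0")
    return len(joined)


def merge_flat(x: Dict[int, Sequence[float]], index_map: Dict[int, List[int]]) -> List[float]:
    """Place each connection's flat entries at its listed output positions."""
    if not x:
        raise ValueError("x must be non-empty")
    out = [0.0] * validate_output_indices(index_map)
    for conn, row in x.items():
        targets = index_map[conn]
        if len(row) != len(targets):
            raise ValueError(f"connection {conn}: {len(row)} entries but {len(targets)} indices")
        # Entries keep their existing order and land at the listed output positions.
        for value, target in zip(row, targets):
            out[target] = value
    return out
```

In the worked example above, the first input's 20 entries would use indices 0 to 19 and the second input's 10 entries would use indices 20 to 29. The resulting flat row of 30 entries would then be unflattened to [10, 3].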
- class briann.network.core.LinearConnectionTransformation(*args: Any, **kwargs: Any)
Bases: ConnectionTransformation
A simple linear transformation to be placed on a Connection. The input is first flattened along all axes except the first (batch axis) and then passed through a regular torch.nn.Linear layer before being reshaped to fit the output shape. The constructor accepts keyword arguments that further configure the underlying torch.nn.Linear layer, i.e. bias, device and dtype.
- Parameters:
input_shape (List[int]) – Sets the input_shape() property.
output_shape (List[int]) – Sets the output_shape() property.
- _on_update_shape(name: str, value: List[int]) → None
This method is a callback that adjusts the model parameters of the transformation of self whenever the input_shape() or output_shape() is updated.
- Parameters:
name (str) – The name of the attribute (i.e. ‘input_shape’ or ‘output_shape’) to be updated.
value (List[int]) – The new shape of the input or output, disregarding the batch_size as axis 0.
- Return type:
None
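The flatten, linear and reshape pipeline of LinearConnectionTransformation can be sketched per instance in plain Python. The names flatten, unflatten and linear_transform are hypothetical, and nested lists replace tensors so that the example runs without torch. Following the torch.nn.Linear convention, the weight has prod(output_shape) rows and prod(input_shape) columns. This is why a shape update, as handled by _on_update_shape(), would require new parameters.

```python
from math import prod
from typing import List


def flatten(nested) -> List[float]:
    """Row-major flattening of an arbitrarily nested list (one instance)."""
    if isinstance(nested, (list, tuple)):
        return [v for item in nested for v in flatten(item)]
    return [nested]


def unflatten(flat: List[float], shape: List[int]):
    """Inverse of flatten for a given shape, in row-major order."""
    nested = flat
    for size in reversed(shape[1:]):
        nested = [nested[k:k + size] for k in range(0, len(nested), size)]
    return nested


def linear_transform(instance, weight: List[List[float]], bias: List[float], output_shape: List[int]):
    """Flatten one instance, compute y = W x + b, and reshape y to output_shape."""
    x = flatten(instance)
    if (len(weight) != prod(output_shape) or len(bias) != len(weight)
            or any(len(row) != len(x) for row in weight)):
        raise ValueError("weight must be [prod(output_shape), prod(input_shape)]")
    y = [sum(w * v for w, v in zip(row, x)) + b for row, b in zip(weight, bias)]
    return unflatten(y, output_shape)
```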
- class briann.network.core.Merger
Bases: Adapter
This class (and in particular its forward method) is to be used inside the collect_inputs() method to merge all the collected inputs into one torch.Tensor.
- abstractmethod forward(x: Dict[int, torch.Tensor]) → torch.Tensor
- Parameters:
x (Dict[int, torch.Tensor]) – A dictionary of [int, tensor] where a key is an index of a connection and a value is a tensor to be processed. This input is assumed to be non-empty and an exception will be raised if the assumption is violated.
- Returns:
A single tensor that is the combination of the values of x.
- Return type:
torch.Tensor
- class briann.network.core.Source(*args: Any, **kwargs: Any)
Bases: Area
The source Area is a special area because it streams the input to the other areas. To set it up for the simulation of a trial, stimuli are loaded via load_next_stimulus_batch(). On each call of the collect_inputs() method, one TimeFrame is then taken from the stimuli and held in a buffer. When the forward() method is called, that time-frame is placed in the output_time_frame_accumulator() so that other areas can read it. Once all time-frames have been streamed, the source area no longer adds new time-frames to the accumulator, and its representation therefore simply decays over time.
- Parameters:
index (int) – Sets the index of this area.
output_time_frame_accumulator (TimeFrameAccumulator) – Sets the output_time_frame_accumulator() of this area.
output_shape (List[int]) – Sets the output_shape() of this area.
output_connections (Dict[int, Connection]) – Sets the output_connections() of this area.
update_rate (float) – Sets the update_rate() of this area.
- collect_inputs(current_time: float) → None
Pops the next TimeFrame from stimulus_batch(), or generates an array of zeros if the stimulus stream is over. Either way, the result is buffered internally and made available to other areas when forward() is called.
- Parameters:
current_time (float) – The current time of the simulation.
- Raises:
ValueError – If the current_time is not equal to the time of the popped TimeFrame.
- Return type:
None
- load_next_stimulus_batch(X: torch.Tensor) → None
This method loads the next batch of stimuli that will be streamed to the other model areas during the simulation.
- Parameters:
X (torch.Tensor) – A tensor of shape [batch_size, time_frame_count, …], where the first axis corresponds to instances in the batch and the second axis to time-frames.
- Raises:
Exception – If self.data_loader is None.
StopIteration – If the data_loader is empty.
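The pop-or-zeros behaviour of Source.collect_inputs() can be sketched as follows. ToySource is a hypothetical class. It represents a time-frame as a (time_point, state) tuple with a flat list of floats instead of a TimeFrame holding a tensor, and it stamps zero frames with current_time. All of these are assumptions made for illustration.

```python
from collections import deque
from typing import Deque, List, Optional, Tuple

Frame = Tuple[float, List[float]]  # (time_point, state); stands in for TimeFrame


class ToySource:
    """Hypothetical stand-in for Source that streams pre-loaded frames."""

    def __init__(self, output_size: int):
        self.output_size = output_size
        self.stimulus: Deque[Frame] = deque()
        self._buffer: Optional[Frame] = None

    def load_next_stimulus_batch(self, frames: List[Frame]) -> None:
        self.stimulus = deque(frames)

    def collect_inputs(self, current_time: float) -> None:
        if self.stimulus:
            time_point, state = self.stimulus.popleft()
            if time_point != current_time:
                raise ValueError(f"frame time {time_point} != current time {current_time}")
        else:
            # Stream is over: emit zeros, so downstream areas only see decay.
            time_point, state = current_time, [0.0] * self.output_size
        self._buffer = (time_point, state)
```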
- class briann.network.core.Target(*args: Any, **kwargs: Any)
Bases: Area
This class is a subclass of Area and has the same functionality as a regular area, except that it has no output connections.
- Parameters:
index (int) – Sets the index of this area.
output_time_frame_accumulator (TimeFrameAccumulator) – Sets the output_time_frame_accumulator() of self.
input_connections (List[Connection]) – Sets the input_connections() of this area.
input_shape (List[int]) – Sets the input_shape() of this area.
output_shape (List[int]) – Sets the output_shape() of this area.
transformation (torch.nn.Module) – Sets the transformation() of this area.
update_rate (float) – Sets the update_rate() of this area.
- extract_current_output_batch(final_states_only: bool = True) → torch.Tensor
Extracts the current output batch held by this target area. It collects all time-frames that forward() produced and that are held in a buffer, stacks them into a tensor and returns that tensor. After extraction, the internal buffer of time-frames is cleared.
- Parameters:
final_states_only (bool, optional, defaults to True) – Indicates whether only the final, i.e. current, output states shall be returned or the entire sequence.
- Returns:
The output batch held by this target area. This is a tensor of shape [batch_size, time_frame_count, …], where … is the shape of the state of a single instance’s time-frame of this area.
- Return type:
torch.Tensor
- forward() → None
Assuming collect_inputs() has been run on all areas of the simulation just beforehand, this method passes the buffered inputs through the transformation() of self (if it exists) and passes the result to accumulate() of self.
- property output_connections: Set[Connection]
- Returns:
A set of Connection objects projecting from this area.
- Return type:
Set[Connection]
- class briann.network.core.TimeFrame(state: torch.Tensor, time_point: float)
Bases: object
A time-frame in the simulation that holds a temporary state of an Area.
- Parameters:
state (torch.Tensor) – Sets the state of this time frame.
time_point (float) – Sets the time_point of this time frame.
- property state: torch.Tensor
- Returns:
The state of the time frame. This is a torch.Tensor, for instance of shape [instance count, dimensionality].
- Return type:
torch.Tensor
- class briann.network.core.TimeFrameAccumulator(initial_time_frame: TimeFrame, decay_rate: float)
Bases: object
This class is used to accumulate TimeFrame objects. Accumulation happens by adding new time-frames into the accumulator’s own time-frame with the accumulate() method. An important feature of the accumulator is that during every update, the currently stored information decays according to the provided decay_rate and the time since the last update. This ensures that older information has less influence on the current state of the accumulator than new information.
- Parameters:
initial_time_frame (TimeFrame) – Sets the initial_time_frame and time_frame of this time frame accumulator.
decay_rate (float) – Sets the decay_rate() property of self.
- accumulate(time_frame: TimeFrame) → None
Sets the state() of the time_frame() of self equal to the weighted sum of the state of the new time_frame and the state of the current time-frame of self. The weight for the old state is w = decay_rate^dt, where dt is the time-point of the provided time_frame minus the time-point of the time-frame currently held by self. The weight for the new time_frame is simply 1. This method also sets the time_point() of the time-frame of self equal to that of the new time_frame.
- Parameters:
time_frame (TimeFrame) – The new time-frame to be added to the time_frame() of self.
- Raises:
ValueError – If the state of time_frame does not have the same shape as that of the current time-frame of self.
ValueError – If the time-point of time_frame is earlier than that of the current time-frame of self.
- Returns:
None
- property decay_rate: float
- Returns:
The rate, taken from the interval [0,1], at which the energy of the state() of time_frame() decays as time passes. This rate should preferably lie in the open interval (0,1) to obtain true exponential decay. If set to 1, there is no decay; if set to 0, there is no memory. See accumulate() for details.
- Return type:
float
- reset(initial_time_frame: TimeFrame = None) → None
Resets the time_frame() of self. If initial_time_frame is provided, it is used for the reset and saved in initial_time_frame(). Otherwise, the one provided during construction is used.
- Parameters:
initial_time_frame (TimeFrame, optional, defaults to None.) – The time-frame used to set time_frame() and initial_time_frame() of self.
- time_frame(current_time: float) → TimeFrame
Provides a TimeFrame that holds the time-discounted sum of all TimeFrame objects added via the accumulate() method.
- Parameters:
current_time (float) – The current time, used to discount the state of self.
- Raises:
ValueError – If current_time is earlier than the time-point of the current time-frame of self.
- Returns:
The time-discounted time-frame of this accumulator.
- Return type:
TimeFrame
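The decay rule of accumulate() and the discounting in time_frame() can be sketched in plain Python. ToyAccumulator is a hypothetical class that stores a flat list of floats and a time-point instead of a TimeFrame with a tensor state. Its reset() always restores the constructor's values. The assumption that time_frame() stamps the returned frame with current_time is also ours.

```python
from typing import List, Tuple


class ToyAccumulator:
    """Hypothetical plain-Python sketch of the TimeFrameAccumulator decay rule."""

    def __init__(self, initial_state: List[float], initial_time: float, decay_rate: float):
        if not 0.0 <= decay_rate <= 1.0:
            raise ValueError("decay_rate must lie in [0, 1]")
        self.decay_rate = decay_rate
        self._initial = (list(initial_state), initial_time)
        self.reset()

    def reset(self) -> None:
        # Simplified: always restores the state given at construction.
        self.state, self.time_point = list(self._initial[0]), self._initial[1]

    def accumulate(self, state: List[float], time_point: float) -> None:
        if len(state) != len(self.state):
            raise ValueError("new state must have the same shape as the stored one")
        if time_point < self.time_point:
            raise ValueError("time_point is earlier than the stored time-frame")
        w = self.decay_rate ** (time_point - self.time_point)  # weight of the old state
        self.state = [w * old + new for old, new in zip(self.state, state)]
        self.time_point = time_point

    def time_frame(self, current_time: float) -> Tuple[List[float], float]:
        if current_time < self.time_point:
            raise ValueError("current_time is earlier than the stored time-frame")
        w = self.decay_rate ** (current_time - self.time_point)
        return [w * s for s in self.state], current_time
```

With decay_rate 0.5, accumulating [1.0] at t=0 and again at t=2 gives 0.5**2 * 1.0 + 1.0 = 1.25. Reading the accumulator at t=3 discounts this once more, to 0.625.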