Graph Neural Networks (GNNs)

Posted on May 29, 2024

Neville Dubash

Neural networks are machine learning models that can process many types of data. Images are a common example: convolutional neural networks take into account the relationships between nearby pixels. A graph neural network (GNN) is another type of model, designed for data structured as a graph. In mathematics, a graph is a structure made up of a set of nodes connected by edges.

Social networks are a classic example of graphs studied with machine learning. Individuals are represented by nodes, and the edges between them represent follower or friend relationships. A GNN can predict an individual's preferences by combining data about that individual with data from their neighbours in the social network graph. Chemistry provides another example: the atoms in a molecule can be treated as nodes and the bonds between them as edges. GNNs can then be used to predict properties of molecules and to help search for an optimal chemical structure.
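To make the graph idea concrete, here is a minimal sketch of the social network example using NumPy. The four people, their friendships, and the two "interest score" features are all made up for illustration, and the helper names (build_adjacency, neighbour_mean) are hypothetical. It only shows how a node's own data can be combined with data from its neighbours; a real GNN would feed such combined inputs through learned layers.

```python
import numpy as np

# Hypothetical toy social network: 4 people (nodes) and their friendships (edges).
edges = [(0, 1), (0, 2), (1, 2), (2, 3)]
num_nodes = 4

# Made-up node features: one row per person, e.g. interest scores for two topics.
features = np.array([
    [1.0, 0.0],
    [0.0, 1.0],
    [1.0, 1.0],
    [0.0, 0.0],
])


def build_adjacency(edges, num_nodes):
    """Return a symmetric adjacency matrix for an undirected graph."""
    adj = np.zeros((num_nodes, num_nodes))
    for i, j in edges:
        adj[i, j] = 1.0
        adj[j, i] = 1.0
    return adj


def neighbour_mean(adj, features):
    """Average the feature vectors of each node's neighbours."""
    degree = adj.sum(axis=1, keepdims=True)
    degree[degree == 0] = 1.0  # isolated nodes: avoid dividing by zero
    return adj @ features / degree


adj = build_adjacency(edges, num_nodes)
neighbour_avg = neighbour_mean(adj, features)

# Each person's own data next to the average of their friends' data.
combined = np.concatenate([features, neighbour_avg], axis=1)
print(combined)
```

Person 3 has a single friend (person 2), so their neighbour average is simply person 2's feature vector.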

Like other neural network models, GNNs are typically made up of a sequence of layers. At each layer the node and edge data are updated, often by passing them through another type of neural network. The structure of the graph determines how information is pooled: in each GNN layer, message passing exchanges data between neighbouring nodes and edges, so stacking layers lets information travel further across the graph. After the final layer, predictions can be made for individual nodes or edges, or for the graph as a whole.
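The layer-by-layer process described above can be sketched in a few lines of NumPy. This is an illustrative simplification and not a specific published architecture: it updates node data only (no edge features), uses mean aggregation over neighbours, a single linear map with a ReLU as the "other neural network", and random rather than trained weights. The names gnn_layer and run_gnn are hypothetical.

```python
import numpy as np


def gnn_layer(adj, h, w_self, w_neigh):
    """One message-passing layer: pool neighbour data, then update each node."""
    degree = adj.sum(axis=1, keepdims=True)
    degree[degree == 0] = 1.0            # isolated nodes: avoid dividing by zero
    messages = adj @ h / degree          # mean of each node's neighbour features
    # Combine each node's own state with its pooled messages, then apply ReLU.
    return np.maximum(0.0, h @ w_self + messages @ w_neigh)


def run_gnn(adj, x, weights):
    """Apply a sequence of layers; `weights` holds one (w_self, w_neigh) pair per layer."""
    h = x
    for w_self, w_neigh in weights:
        h = gnn_layer(adj, h, w_self, w_neigh)
    return h


rng = np.random.default_rng(0)

# Toy undirected graph with 4 nodes and 3 input features per node.
adj = np.array([
    [0, 1, 1, 0],
    [1, 0, 1, 0],
    [1, 1, 0, 1],
    [0, 0, 1, 0],
], dtype=float)
x = rng.normal(size=(4, 3))

# Two layers with random (untrained) weights: 3 -> 8 -> 2 features per node.
weights = [
    (rng.normal(size=(3, 8)), rng.normal(size=(3, 8))),
    (rng.normal(size=(8, 2)), rng.normal(size=(8, 2))),
]

node_outputs = run_gnn(adj, x, weights)    # one prediction vector per node
graph_output = node_outputs.mean(axis=0)   # pooled prediction for the whole graph
```

With two layers, each node's output depends on nodes up to two edges away. Averaging the node outputs is one simple way to obtain a graph-level prediction; sums or maxima are common alternatives.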

The GNN structure can be extended to more complex models, such as spatio-temporal GNNs, in which the graph data evolves over time. This type of model has been used to predict vehicle traffic from time series data recorded by roadway sensors, with physical locations represented as nodes and the streets connecting them as edges. GNNs provide a way of applying deep learning to many types of structured data and of identifying complex relationships in that data efficiently.
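As a rough illustration of the spatio-temporal idea, the sketch below alternates a spatial step (mixing each sensor's reading with its neighbours along the road graph) with a temporal step (carrying a running state forward through the time series). Everything here is an assumption made for illustration: the three-sensor road layout, the readings, and the fixed mixing and decay coefficients. A real spatio-temporal GNN would learn these updates from data, and this is not the method used in any particular traffic study.

```python
import numpy as np


def spatial_mix(adj, x):
    """Blend each sensor's reading with the mean reading of its neighbours."""
    degree = adj.sum(axis=1)
    degree[degree == 0] = 1.0  # isolated sensors: avoid dividing by zero
    return 0.5 * x + 0.5 * (adj @ x / degree)


def spatio_temporal_estimate(adj, readings, decay=0.5):
    """Fold a (num_steps, num_sensors) time series into one estimate per sensor."""
    state = spatial_mix(adj, readings[0])
    for x_t in readings[1:]:
        # Temporal step: keep part of the old state, add the new spatially mixed reading.
        state = decay * state + (1.0 - decay) * spatial_mix(adj, x_t)
    return state


# Hypothetical road layout: three sensors in a line, 0 - 1 - 2.
adj = np.array([
    [0, 1, 0],
    [1, 0, 1],
    [0, 1, 0],
], dtype=float)

# Made-up vehicle counts: 4 time steps, 3 sensors.
readings = np.array([
    [10.0, 20.0, 30.0],
    [12.0, 22.0, 28.0],
    [15.0, 25.0, 27.0],
    [18.0, 24.0, 26.0],
])

estimate = spatio_temporal_estimate(adj, readings)
print(estimate)
```

Separating the spatial and temporal steps keeps the sketch easy to read. Learned models typically replace both fixed-coefficient steps with trainable layers.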