Introduction
In recent years, advances in computing power have let deep learning progress by leaps and bounds, making it practical to implement neural network architectures that were proposed years or even decades ago. The Graph Neural Network (GNN) is one such architecture. It was proposed in 2005 by Franco Scarselli et al. in their paper “Graph Neural Networks for Ranking Web Pages”, but it has only recently started gaining popularity.
In this article, we explain Graph Neural Networks (GNNs) for beginners in simple language, to build an intuitive understanding of why they matter and how they work.
What is a Graph
A graph is a versatile data structure for representing complex relationships between data points. It models data as a collection of nodes and the edges that connect them. Each data object is represented as a node, and an edge connecting two nodes represents the relationship between them. Edges may also carry weights that quantify the importance of the relationship between two nodes.
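To make the idea of nodes, edges, and weights concrete, here is a minimal sketch of a small weighted graph stored as a Python dictionary. The node names, the weights, and the helper functions `add_edge` and `neighbors` are made up for illustration; real projects often use a dedicated graph library instead.

```python
# A tiny undirected, weighted graph stored as an adjacency dictionary:
# graph[u][v] holds the weight of the edge between u and v.
graph = {}

def add_edge(graph, u, v, weight=1.0):
    """Add an undirected edge u-v with the given weight (hypothetical helper)."""
    graph.setdefault(u, {})[v] = weight
    graph.setdefault(v, {})[u] = weight

def neighbors(graph, node):
    """Return the names of the nodes connected to `node`, sorted alphabetically."""
    return sorted(graph.get(node, {}))

# Illustrative data: four nodes and four weighted edges.
add_edge(graph, "A", "B", 0.9)
add_edge(graph, "A", "C", 0.4)
add_edge(graph, "B", "C", 0.7)
add_edge(graph, "C", "D", 0.2)

print(neighbors(graph, "C"))  # ['A', 'B', 'D']
print(graph["A"]["B"])        # 0.9
```

Storing each edge in both directions keeps the graph undirected; a directed graph would store it only once.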
Why Graphs Are Useful
Much real-world data is too complex to be represented in a traditional linear or tabular format. Graphs are nonlinear in nature and can help uncover hidden insights and relationships by representing such data as nodes and edges. A common use case is a social network like Facebook, where the nodes are users and the edges between them are friendships.
In fact, many kinds of data can be represented as graphs. For example, in an image, pixels can be treated as nodes and edges can represent the relationships between neighboring pixels. Similarly, a sentence can be viewed as a graph where words are nodes and edges represent relationships between adjacent words.
Graphs are powerful enough to be used across fields such as Computer Science, Physics, Biology, Social Science, Finance, and Engineering.
What is Graph Learning
Graph Learning is an area of machine learning that works with graph data. It typically addresses the following problems –
1. Node Classification
In this problem, we predict the class or category of a node based on its own characteristics and those of the nodes it is connected to. For example, it can be used to identify malicious systems in a network of computers.
2. Link Prediction
In this problem, we predict whether a link that is currently missing from the graph is likely to exist between two nodes. A classic use case is recommending a friendship between two people who share many social media connections.
3. Graph Classification
This deals with classifying an entire graph into classes or categories based on its characteristics. For example, the molecular structures of drugs can be represented as graphs and classified based on the relationships between the atoms in their structure.
4. Graph Generation
This deals with the generation of new graphs with desired characteristics. It is useful in the discovery of new drugs by generating novel molecular structures.
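As a small illustration of the link prediction problem described above, the sketch below recommends friends by counting shared connections. This is a classic hand-written heuristic, not a GNN; a GNN would learn its scoring function from data instead. The `friends` data and the function names are hypothetical.

```python
# Illustrative friendship graph: each person maps to the set of their friends.
friends = {
    "ana": {"bo", "cy", "di"},
    "bo": {"ana", "cy"},
    "cy": {"ana", "bo", "di"},
    "di": {"ana", "cy", "ed"},
    "ed": {"di"},
}

def common_neighbors(graph, u, v):
    """Score a possible link u-v by the number of friends u and v share."""
    return len(graph[u] & graph[v])

def recommend(graph, user):
    """Return (candidate, score) pairs for people `user` is not yet linked to."""
    candidates = [v for v in graph if v != user and v not in graph[user]]
    scored = [(v, common_neighbors(graph, user, v)) for v in candidates]
    # Keep only candidates with at least one shared friend,
    # highest score first, ties broken alphabetically.
    scored = [pair for pair in scored if pair[1] > 0]
    scored.sort(key=lambda pair: (-pair[1], pair[0]))
    return scored

print(recommend(friends, "bo"))
print(recommend(friends, "ed"))
```

Here "bo" and "di" are not connected but share two friends, so "di" is the top suggestion for "bo".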
What are Graph Neural Networks (GNNs)
Graph Neural Networks (GNNs) are a class of neural networks that learn representations of the nodes and edges of a graph and then use these representations to solve graph learning problems like node classification, link prediction, graph classification, and graph generation.
The GNN (Graph Neural Network) is inspired by the Convolutional Neural Network (CNN), which is used to work with images. A CNN is good at extracting and learning the local spatial features around each pixel of an image. This lets it understand the image much better than a simple feed-forward neural network can.
Along similar lines, the fundamental goal in graphs is to learn the importance of a given node in the context of its surrounding nodes. However, CNNs and other traditional neural networks are designed for regularly structured (Euclidean) data such as text sequences and image grids. Graphs have no such regular structure: a node can have any number of neighbors, and the neighbors have no fixed order. Traditional neural networks therefore cannot process graphs effectively. This is where the GNN (Graph Neural Network) comes into the picture, since it is designed to work directly with graphs.
Graph Embeddings
Before we understand how a graph neural network (GNN) works, we should cover the important topic of graph embeddings. Graph embeddings are numerical vector representations of graph data. Graph embedding can be done at three levels –
- Node Embedding – To represent node data.
- Edge Embedding – To represent edge data.
- Graph Embedding – To represent the complete graph.
The set of all possible embedding values is known as the embedding space. Inside the embedding space, two nodes or edges that lie close to each other share some common characteristics or properties. For example, in a social media context, the embeddings of two people who work for the same company will be close to each other.
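The notion of "closeness" in the embedding space can be measured in several ways; cosine similarity is one common choice. The sketch below assumes three hand-picked 3-dimensional node embeddings (the names and numbers are invented for illustration; real embeddings are learned and usually have many more dimensions).

```python
import math

# Hypothetical node embeddings for three people in a social network.
embeddings = {
    "alice": [0.9, 0.1, 0.3],
    "bob": [0.8, 0.2, 0.4],    # imagined to work at the same company as alice
    "carol": [-0.5, 0.9, 0.1],
}

def cosine_similarity(a, b):
    """Cosine of the angle between two vectors: near 1 means very similar."""
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    return dot / (norm_a * norm_b)

sim_ab = cosine_similarity(embeddings["alice"], embeddings["bob"])
sim_ac = cosine_similarity(embeddings["alice"], embeddings["carol"])
print(f"alice vs bob:   {sim_ab:.3f}")
print(f"alice vs carol: {sim_ac:.3f}")
```

With these made-up vectors, alice and bob score much higher than alice and carol, which is the behavior we would expect from embeddings of two colleagues.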
How Graph Neural Network (GNN) Works
Although there are many variations of GNNs, the fundamental concept of how the graph neural network works remains the same. In this section, we will understand this concept intuitively.
- We start with the initial representation of the graph as input.
- A GNN consists of K layers. In the first layer, each node gathers information from its immediate neighbors. In the second layer, it gathers information from the neighbors of its neighbors, and so on. The number of layers K is a hyperparameter and can be any positive integer.
- During training, the forward pass computes the representations of the nodes of the graph layer by layer and uses them to make predictions. The network then calculates the loss, i.e. the difference between its predictions and the known target values (for example, node labels), and backpropagates the error. In the next iteration, the network adjusts its weights to reduce the error further, and so on until the error cannot be reduced further. (This step is similar to the supervised training of any neural network.)
- Once the GNN is trained and has learned the embeddings/representations of the graph's nodes and edges, it can be used to solve graph learning problems like node classification, link prediction, graph classification, etc.
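The effect of stacking layers can be seen in a minimal sketch. It assumes a four-node path graph and a single number per node, and it replaces the learned part of a GNN layer with plain averaging, so it only illustrates how far information travels, not how a real GNN is trained. The names `aggregate` and `run_layers` are hypothetical.

```python
# Path graph 0 - 1 - 2 - 3, stored as adjacency lists (illustrative).
adjacency = {0: [1], 1: [0, 2], 2: [1, 3], 3: [2]}

def aggregate(features, adjacency):
    """One layer: each node averages its own value with its neighbors' values."""
    new_features = {}
    for node, nbrs in adjacency.items():
        values = [features[node]] + [features[n] for n in nbrs]
        new_features[node] = sum(values) / len(values)
    return new_features

def run_layers(features, adjacency, k):
    """Apply k aggregation layers in sequence and return the final features."""
    for _ in range(k):
        features = aggregate(features, adjacency)
    return features

# Only node 0 carries a signal at the start.
initial = {0: 1.0, 1: 0.0, 2: 0.0, 3: 0.0}

for k in range(1, 4):
    result = run_layers(initial, adjacency, k)
    reached = [node for node, value in result.items() if value > 0]
    print(f"after {k} layer(s), the signal has reached nodes {reached}")
```

After one layer the signal from node 0 reaches only its direct neighbor; each extra layer extends it by one more hop, which is why K controls how much of the graph each node can "see".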
Types of Graph Neural Networks (GNNs)
Some of the popular architectures of graph neural networks are –
1. Message Passing Neural Network (MPNN)
In a Message Passing Neural Network (MPNN), there are two steps involved – i) Message Passing & ii) Updating. During message passing, each node sends messages to, and receives messages from, its connected neighbors. This helps nodes learn about the other nodes in the graph. (Think of it as friends in a group chat, sharing their knowledge and experiences.)
After exchanging messages, nodes update their own information based on what they have learned from their neighbors. This process is repeated over several iterations. With each iteration a node receives information from one hop further away, so nodes gradually build up an understanding of a larger part of the graph.
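The two steps can be sketched in a few lines. In an actual MPNN the message and update functions are learned neural networks; here they are replaced by fixed, hand-written rules (scale by edge weight, then average old state with the summed messages) purely to show the data flow. All names and numbers are assumptions for illustration.

```python
def message(sender_state, edge_weight):
    """Message step (stand-in): scale the sender's state by the edge weight."""
    return [edge_weight * x for x in sender_state]

def update(own_state, aggregated):
    """Update step (stand-in): blend the old state with the summed messages."""
    return [0.5 * h + 0.5 * m for h, m in zip(own_state, aggregated)]

def message_passing_round(states, edges):
    """Run one round. `edges` maps each node to a list of (neighbor, weight)."""
    new_states = {}
    for node, nbrs in edges.items():
        aggregated = [0.0] * len(states[node])
        for nbr, weight in nbrs:
            msg = message(states[nbr], weight)
            aggregated = [a + m for a, m in zip(aggregated, msg)]
        new_states[node] = update(states[node], aggregated)
    return new_states

# Illustrative graph a - b - c with 2-dimensional node states.
states = {"a": [1.0, 0.0], "b": [0.0, 1.0], "c": [1.0, 1.0]}
edges = {
    "a": [("b", 1.0)],
    "b": [("a", 1.0), ("c", 0.5)],
    "c": [("b", 0.5)],
}

after_one = message_passing_round(states, edges)
after_two = message_passing_round(after_one, edges)
print(after_one)
print(after_two)
```

All nodes are updated from the previous round's states, so the order in which nodes are visited does not affect the result.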
2. Graph Convolutional Network (GCN)
In a Graph Convolutional Network (GCN), each node starts with an initial representation. It then aggregates the features of its neighboring nodes and updates its own representation with the new information in a single step. This process can be thought of as blending the node’s features with those of its neighbors, incorporating the neighbors’ information without a separate, explicitly defined message step as in MPNN (a GCN can still be viewed as a simple special case of message passing).
The feature aggregation is performed using convolution-like operations, which is where the “convolutional” in Graph Convolutional Network comes from. These operations are applied to every node in the graph, and several such layers are stacked to capture more complex relationships between nodes.
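One widely used formulation of a graph convolutional layer adds a self-loop to every node, normalizes the adjacency matrix by node degrees, multiplies it by the feature matrix, and then applies a weight matrix and a ReLU. The sketch below implements that formulation in plain Python under stated assumptions: the graph, the features, and especially the weight matrix are made up (weights would normally be learned), and `matmul` and `gcn_layer` are hypothetical helper names.

```python
import math

def matmul(a, b):
    """Multiply two matrices given as lists of rows."""
    return [[sum(a[i][k] * b[k][j] for k in range(len(b)))
             for j in range(len(b[0]))]
            for i in range(len(a))]

def gcn_layer(adjacency, features, weight):
    """One graph convolution: ReLU(D^-1/2 (A + I) D^-1/2 X W)."""
    n = len(adjacency)
    # Add self-loops so each node keeps part of its own features.
    a_hat = [[adjacency[i][j] + (1.0 if i == j else 0.0) for j in range(n)]
             for i in range(n)]
    degree = [sum(row) for row in a_hat]
    # Symmetric degree normalization of the adjacency matrix.
    normalized = [[a_hat[i][j] / math.sqrt(degree[i] * degree[j])
                   for j in range(n)]
                  for i in range(n)]
    aggregated = matmul(normalized, features)   # blend neighbor features
    transformed = matmul(aggregated, weight)    # apply (learned) weights
    return [[max(0.0, x) for x in row] for row in transformed]  # ReLU

# Illustrative path graph 0 - 1 - 2 with 2 features per node.
adjacency = [[0, 1, 0],
             [1, 0, 1],
             [0, 1, 0]]
features = [[1.0, 0.0],
            [0.0, 1.0],
            [1.0, 1.0]]
weight = [[0.5, -0.2],   # hypothetical values; normally learned in training
          [0.1, 0.8]]

output = gcn_layer(adjacency, features, weight)
for row in output:
    print(row)
```

Aggregation and update happen in the same matrix product here, which is the "single step" character of this layer.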
3. Gated Graph Neural Network (GGNN)
The Gated Graph Neural Network (GGNN) combines the message passing concept of the MPNN (Message Passing Neural Network) with the gating mechanism of gated RNNs (Recurrent Neural Networks) such as the GRU.
Just like in an MPNN, the nodes of a Gated Graph Neural Network pass messages among themselves. But after receiving messages, each node uses a gating mechanism to decide how much information from its neighbors it should incorporate into its own features. It then updates its features by combining its existing features with the new features from its neighbors, in the proportion decided by the gate.
The message passing, gating, and node update steps are repeated over several iterations, allowing the GGNN (Gated Graph Neural Network) to learn more complex relationships between nodes.
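The gating idea can be sketched as follows. This is a heavily simplified, element-wise version with a single update gate and hypothetical scalar parameters (`w_own`, `w_msg`, `bias`); gated recurrent units as used in practice have learned weight matrices and an additional reset gate, so treat this only as an illustration of "how much to take from the neighbors".

```python
import math

def sigmoid(x):
    """Squash a number into the range (0, 1)."""
    return 1.0 / (1.0 + math.exp(-x))

def gated_update(own, incoming, w_own=1.0, w_msg=1.0, bias=0.0):
    """Blend a node's state with an aggregated neighbor message via a gate.

    `own` and `incoming` are equal-length lists of floats. The scalar
    parameters are stand-ins for weights that would normally be learned.
    """
    new_state = []
    for h, m in zip(own, incoming):
        gate = sigmoid(w_own * h + w_msg * m + bias)  # 0 = keep old, 1 = take new
        candidate = math.tanh(h + m)                  # proposed new value
        new_state.append((1.0 - gate) * h + gate * candidate)
    return new_state

own_state = [0.5, -0.3]
neighbor_message = [0.2, 0.4]   # e.g. the sum of messages from neighbors

print(gated_update(own_state, neighbor_message))              # partial blend
print(gated_update(own_state, neighbor_message, bias=-50.0))  # gate almost closed
print(gated_update(own_state, neighbor_message, bias=50.0))   # gate almost open
```

With a strongly negative bias the gate stays near 0 and the node essentially keeps its old state; with a strongly positive bias it adopts the candidate computed from the incoming message.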
4. Graph Attention Networks (GAT)
As the name suggests, Graph Attention Networks (GAT) use the attention mechanism to better weigh the importance of neighboring nodes when aggregating their features.
As always, each node of the graph starts with its own initial feature representation. The Graph Attention Network then learns to assign importance weights to each node's neighbors (and to the node itself). The attention mechanism computes these weights from the features of the node and of each neighbor, and the edges determine which pairs of nodes are compared.
Using the calculated attention weights, the model aggregates the features of each node and its neighbors. Neighbors with higher attention weights contribute more to the new representation of the node, while those with lower weights contribute less. This process is repeated in multiple iterations, enabling the GAT to capture more complex relationships between nodes.
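The attention step can be sketched as: score each neighbor from the concatenated features of the node and that neighbor, pass the score through a LeakyReLU, and turn the scores into weights with a softmax. The sketch below follows that pattern but omits the learned feature transformation, and the attention vector `attn_vector` is a hand-picked, hypothetical stand-in for parameters that would be learned.

```python
import math

def leaky_relu(x, slope=0.2):
    """LeakyReLU: pass positives through, shrink negatives."""
    return x if x > 0 else slope * x

def attention_weights(node_feat, neighbor_feats, attn_vector):
    """Softmax-normalized importance of each neighbor for this node.

    `neighbor_feats` may include the node's own features if self-attention
    is wanted. `attn_vector` has length 2 * len(node_feat).
    """
    scores = []
    for nbr in neighbor_feats:
        concat = node_feat + nbr  # list concatenation: [node || neighbor]
        raw = sum(a * x for a, x in zip(attn_vector, concat))
        scores.append(leaky_relu(raw))
    top = max(scores)  # subtract the max for numerical stability
    exps = [math.exp(s - top) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]

def attend(node_feat, neighbor_feats, attn_vector):
    """New node representation: attention-weighted sum of neighbor features."""
    weights = attention_weights(node_feat, neighbor_feats, attn_vector)
    return [sum(w * nbr[d] for w, nbr in zip(weights, neighbor_feats))
            for d in range(len(node_feat))]

node = [1.0, 0.0]
neighbors = [[1.0, 0.0], [0.0, 1.0]]
attn_vector = [0.0, 0.0, 1.0, -1.0]  # hypothetical; favors the first neighbor

print(attention_weights(node, neighbors, attn_vector))
print(attend(node, neighbors, attn_vector))
```

Because the weights come from a softmax, they are positive and sum to 1, so the new representation is a weighted average of the neighbors' features.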
Applications of Graph Neural Networks (GNNs)
In recent years, Graph Neural Networks (GNNs) have found many practical applications. Some of the popular application areas are –
- Social Media Analysis
- Molecular Biology
- Recommendation System
- Natural Language Processing
- Computer Vision