GNN

GNN Class

This class implements Graph Neural Networks (GNNs) for recommendation systems. GNNs are a type of neural network designed to operate on graph-structured data, capturing complex relationships between nodes through message passing and aggregation.
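The message passing and aggregation described above can be illustrated with a minimal, dependency-free sketch. It assumes the graph is a plain adjacency dict and that every node starts with a feature vector. `message_passing_layer` and `propagate` are hypothetical names, not part of this class. A real GNN layer would also apply learned weights and a nonlinearity, which are left out here so the aggregation step stays visible.

```python
from typing import Dict, List

# Hypothetical types for this sketch: adjacency list and per-node vectors.
Graph = Dict[int, List[int]]
Embeddings = Dict[int, List[float]]


def message_passing_layer(graph: Graph, h: Embeddings) -> Embeddings:
    """One round of mean aggregation: each node averages its own vector
    with its neighbours' vectors (no learned weights, no nonlinearity)."""
    out: Embeddings = {}
    for node, neighbours in graph.items():
        # Messages: the node's own vector plus one vector per neighbour.
        messages = [h[node]] + [h[n] for n in neighbours]
        dim = len(h[node])
        # Aggregate element-wise by taking the mean.
        out[node] = [sum(m[i] for m in messages) / len(messages) for i in range(dim)]
    return out


def propagate(graph: Graph, h: Embeddings, num_layers: int) -> Embeddings:
    """Stack num_layers rounds, so each node mixes in information
    from nodes up to num_layers hops away."""
    for _ in range(num_layers):
        h = message_passing_layer(graph, h)
    return h


# Toy user-item graph: user 0 interacted with items 1 and 2.
graph = {0: [1, 2], 1: [0], 2: [0]}
h0 = {0: [1.0, 0.0], 1: [0.0, 1.0], 2: [0.0, 0.0]}
h1 = propagate(graph, h0, num_layers=1)  # node 0 becomes [1/3, 1/3]
```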

Attributes:

num_layers (int): Number of layers in the GNN.
hidden_dim (int): Dimensionality of hidden layers.
learning_rate (float): Learning rate for training the GNN.
epochs (int): Number of training epochs.
graph (Graph): The graph structure representing users and items.

Methods:

build_model: Constructs the GNN model architecture, defining layers and operations for message passing and node aggregation.
train: Trains the GNN model on the provided data, optimizing node embeddings for recommendation tasks.
recommend: Generates top-N recommendations for a given user by leveraging learned node embeddings and graph structure.
evaluate: Evaluates the performance of the GNN model on test data, providing metrics such as accuracy and precision.
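The recommend and evaluate steps can be sketched as standalone functions, assuming users and items already have learned embeddings of equal length. Scoring items by dot product with the user embedding, skipping items the user has already seen, and using precision@k as the precision metric are all assumptions made for illustration. `recommend` and `precision_at_k` here are hypothetical helpers, not the class's own methods.

```python
from typing import Dict, List, Set


def recommend(user_vec: List[float], item_vecs: Dict[int, List[float]],
              seen: Set[int], top_n: int = 10) -> List[int]:
    """Return up to top_n unseen item IDs ranked by dot-product score."""
    def score(item: int) -> float:
        return sum(u * v for u, v in zip(user_vec, item_vecs[item]))

    candidates = [item for item in item_vecs if item not in seen]
    # Highest score first; ties broken by item ID so the order is deterministic.
    candidates.sort(key=lambda item: (-score(item), item))
    return candidates[:top_n]


def precision_at_k(recommended: List[int], relevant: Set[int], k: int) -> float:
    """Fraction of the first k recommendations that appear in the relevant set."""
    if k <= 0:
        raise ValueError("k must be positive")
    hits = sum(1 for item in recommended[:k] if item in relevant)
    return hits / k


items = {10: [0.9, 0.1], 11: [0.2, 0.8], 12: [0.5, 0.5], 13: [1.0, 0.0]}
top = recommend([1.0, 0.0], items, seen={13}, top_n=2)  # [10, 12]
print(precision_at_k(top, relevant={12}, k=2))           # 0.5
```

Dividing by k rather than by the number of items actually returned follows the usual precision@k convention, so a short recommendation list is penalized.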

Source code in engines/contentFilterEngine/graph_based_algorithms/gnn.py
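import networkx as nx
# draw_graph (used in visualize_graph) is assumed to be a plotting helper defined
# elsewhere in this project; its import is above this excerpt and not shown here.

class GNN: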
    """
    GNN Class

    This class implements Graph Neural Networks (GNNs) for recommendation systems.
    GNNs are a type of neural network designed to operate on graph-structured data,
    capturing complex relationships between nodes through message passing and aggregation.

    Attributes:
        num_layers (int): Number of layers in the GNN.
        hidden_dim (int): Dimensionality of hidden layers.
        learning_rate (float): Learning rate for training the GNN.
        epochs (int): Number of training epochs.
        graph (Graph): The graph structure representing users and items.

    Methods:
        build_model():
            Constructs the GNN model architecture, defining layers and operations for
            message passing and node aggregation.

        train(data):
            Trains the GNN model on the provided data, optimizing node embeddings for
            recommendation tasks.

        recommend(user_id, top_n=10):
            Generates top-N recommendations for a given user by leveraging learned node
            embeddings and graph structure.

        evaluate(test_data):
            Evaluates the performance of the GNN model on test data, providing metrics
            such as accuracy and precision.
    """
    def __init__(self):
        self.graph = None

    def load_graph(self, file_path):
        """
        Load a graph from a file.
        """
        self.graph = nx.read_adjlist(file_path, delimiter=',', nodetype=int)
        return self.graph

    def visualize_graph(self, recommended_nodes=None, top_nodes=None, node_labels=None):
        """
        Visualize the graph using the draw_graph function.
        """
        if self.graph is not None:
            # Use a simpler layout for visualization
            pos = nx.circular_layout(self.graph)
            draw_graph(self.graph, pos=pos, 
                       top_nodes=top_nodes, 
                       recommended_nodes=recommended_nodes, 
                       node_labels=node_labels)
        else:
            print("Graph not loaded. Please load a graph first.")

load_graph(file_path)

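Load a graph from a comma-delimited adjacency-list file, parsing node IDs as integers, and store it in self.graph. The loaded graph is also returned.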

Source code in engines/contentFilterEngine/graph_based_algorithms/gnn.py
def load_graph(self, file_path):
    """
    Load a graph from a file.
    """
    self.graph = nx.read_adjlist(file_path, delimiter=',', nodetype=int)
    return self.graph
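load_graph relies on networkx's read_adjlist with delimiter=',' and nodetype=int, so each line of the file lists a node followed by its neighbours, separated by commas, and the result is an undirected graph. The standard-library sketch below illustrates that file format by parsing it into an undirected adjacency dict. `parse_adjlist` is a hypothetical helper written for illustration, not networkx's implementation; treating '#' as the comment marker matches networkx's default.

```python
import io
from typing import Dict, Iterable, Set


def parse_adjlist(lines: Iterable[str], delimiter: str = ",",
                  comments: str = "#") -> Dict[int, Set[int]]:
    """Parse 'node<delimiter>neighbour<delimiter>...' lines into an
    undirected adjacency dict, converting IDs with int()."""
    adj: Dict[int, Set[int]] = {}
    for raw in lines:
        # Drop anything after the comment marker, then skip blank lines.
        line = raw.split(comments, 1)[0].strip()
        if not line:
            continue
        tokens = line.split(delimiter)
        u = int(tokens[0])            # first token is the source node
        adj.setdefault(u, set())      # keep nodes even if they have no neighbours
        for tok in tokens[1:]:
            v = int(tok)
            # Undirected graph: record the edge in both directions.
            adj[u].add(v)
            adj.setdefault(v, set()).add(u)
    return adj


# Example file contents: user 0 linked to items 10 and 11, user 1 to item 10.
sample = io.StringIO("# user,item,item\n0,10,11\n1,10\n")
adj = parse_adjlist(sample)
```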

visualize_graph(recommended_nodes=None, top_nodes=None, node_labels=None)

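Visualize the loaded graph with the draw_graph function, placing nodes on a circular layout (nx.circular_layout). The optional recommended_nodes, top_nodes, and node_labels arguments are passed through to draw_graph. If no graph has been loaded, a message is printed instead.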

Source code in engines/contentFilterEngine/graph_based_algorithms/gnn.py
def visualize_graph(self, recommended_nodes=None, top_nodes=None, node_labels=None):
    """
    Visualize the graph using the draw_graph function.
    """
    if self.graph is not None:
        # Use a simpler layout for visualization
        pos = nx.circular_layout(self.graph)
        draw_graph(self.graph, pos=pos, 
                   top_nodes=top_nodes, 
                   recommended_nodes=recommended_nodes, 
                   node_labels=node_labels)
    else:
        print("Graph not loaded. Please load a graph first.")
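visualize_graph positions nodes with nx.circular_layout, which places them evenly around a circle. The sketch below computes comparable positions with the standard library, assuming a unit circle centred at the origin. `circular_positions` is a hypothetical helper; networkx also rescales and recentres its output, so its coordinates may differ slightly.

```python
import math
from typing import Dict, Hashable, List, Tuple


def circular_positions(nodes: List[Hashable], radius: float = 1.0,
                       center: Tuple[float, float] = (0.0, 0.0)
                       ) -> Dict[Hashable, Tuple[float, float]]:
    """Place nodes at evenly spaced angles on a circle, in list order."""
    n = len(nodes)
    pos: Dict[Hashable, Tuple[float, float]] = {}
    for k, node in enumerate(nodes):
        theta = 2 * math.pi * k / n  # k-th of n equal angular steps
        pos[node] = (center[0] + radius * math.cos(theta),
                     center[1] + radius * math.sin(theta))
    return pos


pos = circular_positions([0, 1, 2, 3])  # node 0 at (1, 0), node 2 at (-1, ~0)
```

Because this needs only a single pass over the nodes, it is cheap compared with iterative force-directed layouts such as nx.spring_layout, which is presumably why the code comment calls it a simpler layout.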