
Few-Shot Learning

FewShotLearner

Source code in engines/contentFilterEngine/learning_paradigms/few_shot.py
```python
class FewShotLearner:
    def __init__(self, model, data_loader, criterion, optimizer, num_epochs):
        """
        Initializes the FewShotLearner with the given model, data loader, criterion, optimizer, and number of epochs.

        Parameters:
            model (torch.nn.Module): The model to be trained and used for predictions.
            data_loader (torch.utils.data.DataLoader): DataLoader providing the training data.
            criterion (torch.nn.Module): Loss function used to evaluate the model's performance.
            optimizer (torch.optim.Optimizer): Optimization algorithm used to update model weights.
            num_epochs (int): Number of epochs to train the model.
        """
        # Store the training components; train() forwards them unchanged.
        self.model = model
        self.data_loader = data_loader
        self.criterion = criterion
        self.optimizer = optimizer
        self.num_epochs = num_epochs

    def train(self):
        """
        Trains the model using the few-shot learning approach.

        This method iterates over the data provided by the data_loader for a specified number of epochs,
        updating the model's weights using the optimizer and evaluating its performance using the criterion.
        """
        print("Training with Few-Shot Learning...")
        # Delegates to the module-level train_model helper, defined or imported
        # elsewhere in few_shot.py (not shown here).
        train_model(self.model, self.data_loader, self.criterion, self.optimizer, self.num_epochs)

    def predict(self, graph, node_index, top_k=5, threshold=0.5):
        """
        Predicts the top-k items for a given node in a graph using the trained model.

        Parameters:
            graph (torch.Tensor): The graph data structure containing nodes and edges.
            node_index (int): The index of the node for which predictions are to be made.
            top_k (int, optional): The number of top items to return. Defaults to 5.
            threshold (float, optional): The threshold for prediction confidence. Defaults to 0.5.

        Returns:
            List[int]: A list of indices representing the top-k predicted items.
        """
        print("Predicting with Few-Shot Learning...")
        # The bare name `predict` resolves to the module-level helper, not to this method.
        return predict(self.model, graph, node_index, top_k, threshold)
```
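Inside the predict method, the call predict(self.model, ...) can look like accidental recursion. It is not: Python resolves a bare name inside a method through the local, enclosing-function, global and built-in scopes, and the class body is not one of them, so the call reaches the module-level predict helper. The self-contained sketch below illustrates this; Learner and the stub predict function are hypothetical names, not part of few_shot.py.

```python
def predict(model, graph, node_index, top_k=5, threshold=0.5):
    # Hypothetical stand-in for the module-level helper used by few_shot.py.
    return "module-level predict"


class Learner:
    def __init__(self, model):
        self.model = model

    def predict(self, graph, node_index, top_k=5, threshold=0.5):
        # A bare name inside a method is looked up in the module's globals
        # (after locals), never in the class body, so this calls the function
        # above instead of recursing into Learner.predict.
        return predict(self.model, graph, node_index, top_k, threshold)


print(Learner(model=None).predict(graph=None, node_index=0))  # module-level predict
```

One consequence: if the module-level helper were missing from few_shot.py, calling learner.predict(...) would raise NameError rather than recurse.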

__init__(model, data_loader, criterion, optimizer, num_epochs)

Initializes the FewShotLearner with the given model, data loader, criterion, optimizer, and number of epochs.

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| model | torch.nn.Module | The model to be trained and used for predictions. | required |
| data_loader | torch.utils.data.DataLoader | DataLoader providing the training data. | required |
| criterion | torch.nn.Module | Loss function used to compute the training loss from the model's outputs. | required |
| optimizer | torch.optim.Optimizer | Optimization algorithm used to update the model's weights. | required |
| num_epochs | int | Number of epochs to train the model. | required |
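FewShotLearner only stores its arguments and forwards them to helper functions, so whether training is actually "few-shot" depends on the data that data_loader yields and on those helpers. One common way to build such data, assumed here rather than shown in the source, is a K-shot support set: at most K labelled examples per class. The sketch below uses a hypothetical helper, sample_k_shot_indices, to pick those indices; they could then be wrapped with torch.utils.data.Subset(dataset, indices) and passed to a DataLoader.

```python
import random
from collections import defaultdict


def sample_k_shot_indices(labels, k, seed=None):
    """Return up to k example indices per class label (a K-shot support set).

    Hypothetical helper for illustration; it is not part of few_shot.py.
    """
    rng = random.Random(seed)  # local RNG so sampling is reproducible
    by_label = defaultdict(list)
    for index, label in enumerate(labels):
        by_label[label].append(index)

    support = []
    for label in sorted(by_label):
        candidates = by_label[label]
        # Classes with fewer than k examples contribute all of them.
        support.extend(rng.sample(candidates, min(k, len(candidates))))
    return sorted(support)


labels = ["spam", "ham", "spam", "ham", "spam", "news", "ham"]
print(sample_k_shot_indices(labels, k=2, seed=0))
```

Sampling per class, rather than uniformly over the whole dataset, keeps rare classes represented even when only a handful of examples are kept.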

predict(graph, node_index, top_k=5, threshold=0.5)

Predicts the top-k items for a given node in a graph using the trained model.

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| graph | torch.Tensor | The graph data structure containing nodes and edges. | required |
| node_index | int | The index of the node for which predictions are to be made. | required |
| top_k | int | The number of top items to return. | 5 |
| threshold | float | The threshold for prediction confidence. | 0.5 |

Returns:

| Type | Description |
|---|---|
| List[int] | A list of indices representing the top-k predicted items. |
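The parameter table does not say how top_k and threshold interact. One plausible reading, assumed here rather than taken from the source, is that items scoring below threshold are discarded and at most top_k of the remainder are returned in descending score order, so fewer than top_k indices may come back. A pure-Python sketch over a list of per-item scores for the target node, using a hypothetical helper top_k_above_threshold:

```python
def top_k_above_threshold(scores, top_k=5, threshold=0.5):
    """Return indices of the top_k highest scores that are >= threshold.

    Hypothetical helper: it only illustrates one possible reading of the
    top_k/threshold parameters of predict().
    """
    # Drop candidates whose confidence falls below the threshold.
    candidates = [(score, index) for index, score in enumerate(scores) if score >= threshold]
    # Highest score first; ties are broken by the lower index.
    candidates.sort(key=lambda pair: (-pair[0], pair[1]))
    return [index for _, index in candidates[:top_k]]


scores = [0.9, 0.2, 0.7, 0.6, 0.95]
print(top_k_above_threshold(scores, top_k=3, threshold=0.5))  # [4, 0, 2]
```

With scores held in a tensor, torch.topk could replace the sort; the list version keeps the sketch dependency-free.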


train()

Trains the model using the few-shot learning approach.

This method passes its stored components to the train_model helper, which iterates over the data provided by data_loader for num_epochs epochs, using criterion to compute the loss and optimizer to update the model's weights.
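The train_model helper itself is not shown on this page. The sketch below is a hypothetical version that matches the call signature used by train() and follows the conventional PyTorch epoch loop (zero gradients, forward pass, loss, backward pass, optimizer step); it assumes data_loader yields (inputs, targets) pairs. It is duck-typed, so it runs with real torch objects but imports nothing.

```python
def train_model(model, data_loader, criterion, optimizer, num_epochs):
    """Minimal epoch loop matching the call in FewShotLearner.train().

    Hypothetical sketch: the real train_model helper is not shown, so this
    follows the conventional PyTorch pattern. Returns each epoch's mean loss.
    """
    model.train()  # enable training-mode behaviour (e.g. dropout, batch norm)
    epoch_losses = []
    for _ in range(num_epochs):
        total, batches = 0.0, 0
        for inputs, targets in data_loader:
            optimizer.zero_grad()                     # clear gradients from the previous step
            loss = criterion(model(inputs), targets)  # forward pass and loss
            loss.backward()                           # backpropagate
            optimizer.step()                          # update the weights
            total += loss.item()
            batches += 1
        epoch_losses.append(total / max(batches, 1))
    return epoch_losses
```

For a graph model the batches may not be simple (inputs, targets) pairs; this sketch does not cover that case.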
