
Zero-Shot Learning

ZeroShotLearner

Source code in engines/contentFilterEngine/learning_paradigms/zero_shot.py
class ZeroShotLearner:
    def __init__(self, model):
        """
        Initializes the ZeroShotLearner with the given model.

        Parameters:
            model (torch.nn.Module): The model to be used for predictions.
        """
        self.model = model

    def predict(self, graph, node_index, top_k=5, threshold=0.5):
        """
        Predicts the top-k items for a given node in a graph using the model.

        Parameters:
            graph (torch.Tensor): The graph data structure containing nodes and edges.
            node_index (int): The index of the node for which predictions are to be made.
            top_k (int, optional): The number of top items to return. Defaults to 5.
            threshold (float, optional): The threshold for prediction confidence. Defaults to 0.5.

        Returns:
            List[int]: A list of indices representing the top-k predicted items.
        """
        print("Predicting with Zero-Shot Learning...")
        return predict(self.model, graph, node_index, top_k, threshold)
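Inside `predict`, the call `predict(self.model, ...)` uses a bare name. Python does not resolve bare names in a method body against the class, so this call goes to a module-level `predict` function rather than recursing into the method. The listing does not show where that function comes from; it is presumably defined or imported elsewhere in zero_shot.py. The sketch below reproduces the delegation with a hypothetical stand-in helper that simply echoes its arguments, so you can see what the wrapper forwards.

```python
# Hypothetical stand-in for the module-level `predict` helper used by
# zero_shot.py; the real helper's location and logic are not shown.
def predict(model, graph, node_index, top_k=5, threshold=0.5):
    # Echo the forwarded arguments so the delegation is visible.
    return (model, node_index, top_k, threshold)


class ZeroShotLearner:
    def __init__(self, model):
        self.model = model

    def predict(self, graph, node_index, top_k=5, threshold=0.5):
        # The bare name `predict` is looked up in module scope, so this calls
        # the helper above, passing the stored model first.
        return predict(self.model, graph, node_index, top_k, threshold)


model = object()  # any object works: __init__ only stores it
learner = ZeroShotLearner(model)
print(learner.predict("graph", 3)[1:])  # (3, 5, 0.5)
```

Because `__init__` performs no validation, the wrapper's behavior depends entirely on what the module-level helper expects from `model` and `graph`.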

__init__(model)

Initializes the ZeroShotLearner with the given model.

Parameters:

Name    Type    Description                              Default
model   Module  The model to be used for predictions.    required

predict(graph, node_index, top_k=5, threshold=0.5)

Predicts the top-k items for a given node in a graph using the model.

Parameters:

Name        Type    Description                                                   Default
graph       Tensor  The graph data structure containing nodes and edges.          required
node_index  int     The index of the node for which predictions are to be made.   required
top_k       int     The number of top items to return. Defaults to 5.             5
threshold   float   The threshold for prediction confidence. Defaults to 0.5.     0.5

Returns:

Type       Description
List[int]  A list of indices representing the top-k predicted items.
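The docstring does not say how `threshold` and `top_k` interact. One plausible reading, sketched below purely as an assumption, is to discard items whose confidence falls below `threshold` and return the indices of the `top_k` highest-scoring survivors, best first. The function name `select_top_k` and the plain list of per-item scores are hypothetical; the real helper works on the model's output for `graph`, whose form is not documented here.

```python
import heapq
from typing import List, Sequence


def select_top_k(scores: Sequence[float], top_k: int = 5,
                 threshold: float = 0.5) -> List[int]:
    """Hypothetical helper: indices of the top_k scores that are >= threshold."""
    # Keep only candidates whose confidence meets the threshold.
    candidates = [(score, index) for index, score in enumerate(scores)
                  if score >= threshold]
    # heapq.nlargest returns the top_k largest (score, index) pairs, best first.
    best = heapq.nlargest(top_k, candidates)
    return [index for _, index in best]


print(select_top_k([0.1, 0.8, 0.55, 0.95, 0.3, 0.7], top_k=3))  # [3, 1, 5]
```

Filtering before ranking means the result can hold fewer than `top_k` indices when few items clear the threshold, and equal scores are broken toward the larger index because the tuples compare element-wise. Whether the engine behaves this way is not stated in the source.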
