Few-Shot Learning
FewShotLearner
Source code in engines/contentFilterEngine/learning_paradigms/few_shot.py
__init__(model, data_loader, criterion, optimizer, num_epochs)
Initializes the FewShotLearner with the given model, data loader, criterion, optimizer, and number of epochs.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
model | Module | The model to be trained and used for predictions. | required |
data_loader | DataLoader | DataLoader providing the training data. | required |
criterion | Module | Loss function used to evaluate the model's performance. | required |
optimizer | Optimizer | Optimization algorithm used to update model weights. | required |
num_epochs | int | Number of epochs to train the model. | required |
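The constructor only stores these five objects, so any standard PyTorch components of the listed types should fit. A minimal sketch of building them, assuming the loader yields (inputs, targets) pairs and using a hypothetical 5-way, 2-shot toy dataset with a linear classifier (the data shapes, model, learning rate, and epoch count are illustrative choices, not taken from few_shot.py):

```python
import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset

# Hypothetical toy few-shot data: 5 classes x 2 examples each, 16 features per example.
num_classes, shots, num_features = 5, 2, 16
features = torch.randn(num_classes * shots, num_features)
labels = torch.arange(num_classes).repeat_interleave(shots)  # [0, 0, 1, 1, ..., 4, 4]

# model: any nn.Module; a small linear classifier is assumed here.
model = nn.Linear(num_features, num_classes)
# data_loader: assumed to yield (inputs, targets) batches.
data_loader = DataLoader(TensorDataset(features, labels), batch_size=4, shuffle=True)
# criterion: a loss module; cross-entropy suits integer class targets.
criterion = nn.CrossEntropyLoss()
# optimizer: constructed over the model's parameters.
optimizer = torch.optim.Adam(model.parameters(), lr=1e-2)
num_epochs = 10

# The import path for FewShotLearner depends on your package layout:
# learner = FewShotLearner(model, data_loader, criterion, optimizer, num_epochs)
```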
predict(graph, node_index, top_k=5, threshold=0.5)
Predicts the top-k items for a given node in a graph using the trained model.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
graph | Tensor | The graph data structure containing nodes and edges. | required |
node_index | int | The index of the node for which predictions are to be made. | required |
top_k | int | The number of top items to return. | 5 |
threshold | float | The threshold for prediction confidence. | 0.5 |
Returns:
Type | Description |
---|---|
List[int] | A list of indices representing the top-k predicted items. |
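This page does not say exactly how top_k and threshold interact. One plausible reading, sketched below as an assumption, is: score every item for the node (for example via a hypothetical `scores = model(graph)[node_index]`), convert the raw scores to confidences with a sigmoid, drop items below threshold, and return up to top_k of the remaining indices in descending order of confidence. The helper name select_top_k and the sigmoid step are illustrative, not the documented implementation.

```python
from typing import List

import torch


def select_top_k(scores: torch.Tensor, top_k: int = 5, threshold: float = 0.5) -> List[int]:
    """Return indices of up to top_k items whose confidence reaches threshold.

    Assumes `scores` is a 1-D tensor of raw logits for a single node; the
    sigmoid and the ">= threshold" rule are assumptions for illustration.
    """
    probs = torch.sigmoid(scores)
    # Keep only the items whose confidence clears the threshold.
    candidates = torch.nonzero(probs >= threshold, as_tuple=True)[0]
    if candidates.numel() == 0:
        return []
    # Rank the surviving items by confidence (descending) and keep at most top_k.
    k = min(top_k, candidates.numel())
    best = torch.topk(probs[candidates], k).indices
    return candidates[best].tolist()
```

For example, with logits `[2.0, -1.0, 0.5, 3.0, -0.2]`, only indices 0, 2, and 3 clear the 0.5 threshold, so `select_top_k(logits, top_k=2)` returns `[3, 0]`. Filtering before ranking means the result can be shorter than top_k when few items are confident.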
train()
Trains the model using the few-shot learning approach.
This method iterates over the data provided by data_loader for num_epochs epochs, using criterion to compute the loss and optimizer to update the model's weights.
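The body of train() is not reproduced on this page. A minimal sketch of the standard PyTorch loop it describes, assuming data_loader yields (inputs, targets) pairs, follows. The function name train_epochs and the returned per-epoch loss history are hypothetical additions; the actual method may differ (for example, by sampling few-shot episodes).

```python
from typing import List

import torch
from torch import nn
from torch.utils.data import DataLoader


def train_epochs(model: nn.Module, data_loader: DataLoader, criterion: nn.Module,
                 optimizer: torch.optim.Optimizer, num_epochs: int) -> List[float]:
    """Standard supervised training loop; returns the mean loss of each epoch."""
    model.train()  # enable training-mode behavior (dropout, batch norm updates)
    history: List[float] = []
    for _ in range(num_epochs):
        total, batches = 0.0, 0
        for inputs, targets in data_loader:
            optimizer.zero_grad()                       # clear gradients from the previous step
            loss = criterion(model(inputs), targets)    # forward pass + loss
            loss.backward()                             # backpropagate
            optimizer.step()                            # update the weights
            total += loss.item()
            batches += 1
        history.append(total / max(batches, 1))
    return history
```

Returning the loss history is only a convenience for checking that training makes progress; the documented train() takes no arguments and reads these objects from the learner instead.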