forked from pygod-team/pygod
-
Notifications
You must be signed in to change notification settings - Fork 0
/
main.py
71 lines (58 loc) · 2.49 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
import tqdm
import torch
import argparse
import warnings
from pygod.metric import *
from pygod.utils import load_data
from utils import init_model
def main(args):
    """Benchmark one outlier detector on one dataset over repeated trials.

    For every trial a fresh model is built with ``init_model(args)`` and the
    dataset ``args.dataset`` is loaded with ``load_data``.  The model is fit,
    its outlier scores are collected, and ROC-AUC, average precision and
    recall@k (k = number of ground-truth outliers) are computed against
    ``data.y``.  Trials whose scores contain NaN are skipped with a warning.
    Mean, standard deviation and maximum of each metric over the remaining
    trials are printed on a single line.

    Parameters
    ----------
    args : argparse.Namespace
        Must provide ``model`` and ``dataset`` (plus whatever ``init_model``
        reads).  An optional ``num_trial`` attribute sets the number of
        trials; when absent, the module-level ``num_trial`` global is used,
        or 20 if that global is not defined (e.g. when imported).
    """
    # The script entry point defines ``num_trial`` as a module global; fall
    # back to it (or 20) so ``main`` also works when this module is imported.
    n_trials = getattr(args, "num_trial", None)
    if n_trials is None:
        n_trials = globals().get("num_trial", 20)

    auc, ap, rec = [], [], []
    for _ in tqdm.tqdm(range(n_trials)):
        model = init_model(args)
        # Reloaded every trial, presumably so state a model attaches to
        # ``data`` during fit cannot leak into the next trial.
        data = load_data(args.dataset)

        if args.model in ("if", "lof"):
            # Non-graph baselines only see node features.
            model.fit(data.x)
            score = model.decision_function(data.x)
        else:
            model.fit(data)
            score = model.decision_score_

        # decision_function presumably returns a NumPy array, which
        # torch.isnan rejects; normalise to a CPU tensor (no-op when the
        # score already is one) so the NaN check and the metrics accept it.
        score = torch.as_tensor(score).detach().cpu()

        if torch.isnan(score).any():
            warnings.warn('contains NaN, skip one trial.')
            continue

        y = data.y.bool().cpu()
        # Recall is measured at k = number of true outliers; a tensor-level
        # sum avoids iterating the bool tensor element by element in Python.
        k = int(y.sum())

        auc.append(float(eval_roc_auc(y, score)))
        ap.append(float(eval_average_precision(y, score)))
        rec.append(float(eval_recall_at_k(y, score, k)))

    if not auc:
        # Every trial was skipped (or no trial ran): torch.max on an empty
        # tensor would raise, so report and stop instead.
        warnings.warn('no valid trial, nothing to report.')
        return

    auc = torch.tensor(auc)
    ap = torch.tensor(ap)
    rec = torch.tensor(rec)
    # torch.std uses the unbiased (n - 1) estimator, so the std column is
    # NaN when only a single trial survives.
    print(args.dataset + " " + model.__class__.__name__ + " " +
          "AUC: {:.4f}±{:.4f} ({:.4f})\t"
          "AP: {:.4f}±{:.4f} ({:.4f})\t"
          "Recall: {:.4f}±{:.4f} ({:.4f})".format(
              torch.mean(auc), torch.std(auc), torch.max(auc),
              torch.mean(ap), torch.std(ap), torch.max(ap),
              torch.mean(rec), torch.std(rec), torch.max(rec)))
if __name__ == '__main__':
    # Command-line entry point: parse benchmark options and run ``main``.
    parser = argparse.ArgumentParser()
    parser.add_argument("--model", type=str, default="dominant",
                        help="supported model: [lof, if, mlpae, scan, radar, "
                             "anomalous, gcnae, dominant, done, adone, "
                             "anomalydae, gaan, guide, conad]. "
                             "Default: dominant")
    # The help text previously claimed "Default: -1" while the real default
    # is 0; keep the help in sync with the actual default.
    parser.add_argument("--gpu", type=int, default=0,
                        help="GPU Index, -1 for CPU. Default: 0")
    parser.add_argument("--dataset", type=str, default='inj_cora',
                        help="supported dataset: [inj_cora, inj_amazon, "
                             "inj_flickr, weibo, reddit, disney, books, "
                             "enron]. Default: inj_cora")
    parser.add_argument("--num_trial", type=int, default=20,
                        help="Number of repeated trials. Default: 20")
    args = parser.parse_args()

    # global setting: ``main`` may look the trial count up via this
    # module-level global, so keep it in sync with the CLI option.
    num_trial = args.num_trial
    main(args)