diff --git a/train.py b/train.py index 23a52b2..75a9278 100644 --- a/train.py +++ b/train.py @@ -17,6 +17,78 @@ import cv2 import os +from PIL import Image +from sklearn.metrics import roc_auc_score +from torch import nn +import pytorch_lightning as pl +from sklearn.metrics import confusion_matrix +import pickle +from sampling_methods.kcenter_greedy import kCenterGreedy +from sklearn.random_projection import SparseRandomProjection +from sklearn.neighbors import NearestNeighbors +from scipy.ndimage import gaussian_filter + + +def distance_matrix(x, y=None, p=2): # pairwise distance of vectors + + y = x if type(y) == type(None) else y + + n = x.size(0) + m = y.size(0) + d = x.size(1) + + x = x.unsqueeze(1).expand(n, m, d) + y = y.unsqueeze(0).expand(n, m, d) + + dist = torch.pow(x - y, p).sum(2) + + return dist + + +class NN(): + + def __init__(self, X=None, Y=None, p=2): + self.p = p + self.train(X, Y) + + def train(self, X, Y): + self.train_pts = X + self.train_label = Y + + def __call__(self, x): + return self.predict(x) + + def predict(self, x): + if type(self.train_pts) == type(None) or type(self.train_label) == type(None): + name = self.__class__.__name__ + raise RuntimeError(f"{name} wasn't trained. Need to execute {name}.train() first") + + dist = distance_matrix(x, self.train_pts, self.p) ** (1 / self.p) + labels = torch.argmin(dist, dim=1) + return self.train_label[labels] + +class KNN(NN): + + def __init__(self, X=None, Y=None, k=3, p=2): + self.k = k + super().__init__(X, Y, p) + + def train(self, X, Y): + super().train(X, Y) + if type(Y) != type(None): + self.unique_labels = self.train_label.unique() + + def predict(self, x): + + + # dist = distance_matrix(x, self.train_pts, self.p) ** (1 / self.p) + dist = torch.cdist(x, self.train_pts, self.p) + + knn = dist.topk(self.k, largest=False) + + + return knn + def copy_files(src, dst, ignores=[]): src_files = os.listdir(src) for file_name in src_files: