Skip to content

Commit

Permalink
Merge branch 'main' into faiss
Browse files Browse the repository at this point in the history
  • Loading branch information
hcw-00 committed Dec 6, 2021
2 parents adde64a + b61c492 commit 144148b
Showing 1 changed file with 72 additions and 0 deletions.
72 changes: 72 additions & 0 deletions train.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,78 @@
import cv2
import os

from PIL import Image
from sklearn.metrics import roc_auc_score
from torch import nn
import pytorch_lightning as pl
from sklearn.metrics import confusion_matrix
import pickle
from sampling_methods.kcenter_greedy import kCenterGreedy
from sklearn.random_projection import SparseRandomProjection
from sklearn.neighbors import NearestNeighbors
from scipy.ndimage import gaussian_filter


def distance_matrix(x, y=None, p=2): # pairwise distance of vectors

y = x if type(y) == type(None) else y

n = x.size(0)
m = y.size(0)
d = x.size(1)

x = x.unsqueeze(1).expand(n, m, d)
y = y.unsqueeze(0).expand(n, m, d)

dist = torch.pow(x - y, p).sum(2)

return dist


class NN():

def __init__(self, X=None, Y=None, p=2):
self.p = p
self.train(X, Y)

def train(self, X, Y):
self.train_pts = X
self.train_label = Y

def __call__(self, x):
return self.predict(x)

def predict(self, x):
if type(self.train_pts) == type(None) or type(self.train_label) == type(None):
name = self.__class__.__name__
raise RuntimeError(f"{name} wasn't trained. Need to execute {name}.train() first")

dist = distance_matrix(x, self.train_pts, self.p) ** (1 / self.p)

This comment has been minimized.

Copy link
@HoseinHashemi

HoseinHashemi Jan 6, 2022

Any particular reason you don't use torch.cdist here in NN class? It has been used in KNN class, though.

labels = torch.argmin(dist, dim=1)
return self.train_label[labels]

class KNN(NN):

def __init__(self, X=None, Y=None, k=3, p=2):
self.k = k
super().__init__(X, Y, p)

def train(self, X, Y):
super().train(X, Y)
if type(Y) != type(None):
self.unique_labels = self.train_label.unique()

def predict(self, x):


# dist = distance_matrix(x, self.train_pts, self.p) ** (1 / self.p)
dist = torch.cdist(x, self.train_pts, self.p)

knn = dist.topk(self.k, largest=False)


return knn

def copy_files(src, dst, ignores=[]):
src_files = os.listdir(src)
for file_name in src_files:
Expand Down

0 comments on commit 144148b

Please sign in to comment.