Skip to content

Commit

Permalink
update heatmap module
Browse files Browse the repository at this point in the history
collect heatmaps from a directory
  • Loading branch information
hazemfahmyy committed Sep 21, 2022
1 parent 044d377 commit 0c14d51
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 0 deletions.
3 changes: 3 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@

DeepJanus/core/.DS_Store
.DS_Store
*.pyc
*.xml
*.iml
18 changes: 18 additions & 0 deletions HeatmapModule.py
Original file line number Diff line number Diff line change
Expand Up @@ -676,6 +676,24 @@ def collectHeatmaps(outPutPath, layerX):
print("Collected " + str(layerX) + " heatmaps")
return allHM, imgList

def collectHeatmaps_Dir(HMDir):
allHM = {}
imgList = []
index2 = 0
for file in os.listdir(HMDir):
if file.endswith(".pt"):
imgList.append(file)
for file in imgList:
if torch.cuda.is_available():
heatMap = torch.load(join(HMDir, file))
heatMap.cuda()
else:
heatMap = torch.load(join(HMDir, file), map_location='cpu')
allHM[file.split(".")[0]] = heatMap
index2 = index2 + 1
if index2 % 1000 == 0:
print("Heatmap is collected for " + str(index2) + " images")
return allHM, imgList

def calcAndSaveHeatmapDistances(layerX, outPutPath: str, outputFile: str, metric):
start = time.time()
Expand Down

0 comments on commit 0c14d51

Please sign in to comment.