Skip to content

Commit

Permalink
Address TODO re missing asserts in test_elastic_wf_with_mace() (#679)
Browse files Browse the repository at this point in the history
* address TODO about missing asserts in test_elastic_wf_with_mace()

* remove unused # noqa: A003

* fix all ruff ANN001 errors (missing types)
  • Loading branch information
janosh committed Jan 16, 2024
1 parent 7a9df71 commit 4869a35
Show file tree
Hide file tree
Showing 18 changed files with 101 additions and 95 deletions.
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ default_language_version:
python: python3
repos:
- repo: https://github.com/charliermarsh/ruff-pre-commit
rev: v0.1.11
rev: v0.1.13
hooks:
- id: ruff
args: [--fix]
Expand Down
1 change: 0 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -157,7 +157,6 @@ ignore = [
"PT006", # pytest-parametrize-names-wrong-type
"RUF013", # implicit-optional
# TODO remove PT011, pytest.raises() should always check err msg
"ANN001", # TODO remove this ignore
"ANN002",
"ANN003",
"ANN101", # missing self type annotation
Expand Down
2 changes: 1 addition & 1 deletion src/atomate2/amset/schemas.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,7 @@ class AmsetTaskDocument(StructureMetadata):
completed_at: str = Field(
None, description="Timestamp for when this task was completed"
)
input: dict = Field(None, description="The input settings") # noqa: A003
input: dict = Field(None, description="The input settings")
transport: TransportData = Field(None, description="The transport results")
usage_stats: UsageStats = Field(None, description="Timing and memory usage")
mesh: MeshData = Field(None, description="Full AMSET mesh data")
Expand Down
12 changes: 10 additions & 2 deletions src/atomate2/cli/dev.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,15 @@
"""Module containing command line scripts for developers."""


from __future__ import annotations

from typing import TYPE_CHECKING

import click

if TYPE_CHECKING:
from pathlib import Path


@click.group(context_settings={"help_option_names": ["-h", "--help"]})
def dev() -> None:
Expand All @@ -10,7 +18,7 @@ def dev() -> None:

@dev.command(context_settings={"help_option_names": ["-h", "--help"]})
@click.argument("test_dir")
def vasp_test_data(test_dir) -> None:
def vasp_test_data(test_dir: Path) -> None:
"""Generate test data for VASP unit tests.
This script expects there is an outputs.json file and job folders in the current
Expand Down Expand Up @@ -156,7 +164,7 @@ def test_my_flow(mock_vasp, clean_dir, si_structure):
print(test_function_str) # noqa: T201


def _potcar_to_potcar_spec(potcar_filename, output_filename) -> None:
def _potcar_to_potcar_spec(potcar_filename: str | Path, output_filename: Path) -> None:
"""Convert a POTCAR file to a POTCAR.spec file."""
from pymatgen.io.vasp import Potcar

Expand Down
3 changes: 2 additions & 1 deletion src/atomate2/common/flows/defect.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
from pathlib import Path

import numpy.typing as npt
from emmet.core.tasks import TaskDoc
from pymatgen.analysis.defects.core import Defect
from pymatgen.core.structure import Structure
from pymatgen.entries.computed_entries import ComputedStructureEntry
Expand Down Expand Up @@ -361,7 +362,7 @@ def sc_entry_and_locpot_from_prv(
"""

@abstractmethod
def get_planar_locpot(self, task_doc) -> dict:
def get_planar_locpot(self, task_doc: TaskDoc) -> dict:
"""Get the Planar Locpot from the TaskDoc.
This is needed just in case the planar average locpot is stored in different
Expand Down
4 changes: 2 additions & 2 deletions src/atomate2/common/jobs/defect.py
Original file line number Diff line number Diff line change
Expand Up @@ -269,8 +269,8 @@ def bulk_supercell_calculation(
"""
if get_planar_locpot is None:

def get_planar_locpot(tdoc) -> NDArray:
return tdoc.calcs_reversed[0].output.locpot
def get_planar_locpot(task_doc: TaskDoc) -> NDArray:
return task_doc.calcs_reversed[0].output.locpot

logger.info("Running bulk supercell calculation. Running...")
sc_mat = get_sc_fromstruct(uc_structure) if sc_mat is None else sc_mat
Expand Down
12 changes: 6 additions & 6 deletions src/atomate2/common/jobs/phonons.py
Original file line number Diff line number Diff line change
Expand Up @@ -248,9 +248,9 @@ def generate_frequencies_eigenvectors(

@job(data=["forces", "displaced_structures"])
def run_phonon_displacements(
displacements,
displacements: list[Structure],
structure: Structure,
supercell_matrix,
supercell_matrix: Matrix3D,
phonon_maker: BaseVaspMaker | ForceFieldStaticMaker = None,
prev_dir: str | Path = None,
) -> Flow:
Expand Down Expand Up @@ -281,16 +281,16 @@ def run_phonon_displacements(
"dirs": [],
}

for i, displacement in enumerate(displacements):
for idx, displacement in enumerate(displacements):
if prev_dir is not None:
phonon_job = phonon_maker.make(displacement, prev_dir=prev_dir)
else:
phonon_job = phonon_maker.make(displacement)
phonon_job.append_name(f" {i + 1}/{len(displacements)}")
phonon_job.append_name(f" {idx + 1}/{len(displacements)}")

# we will add some meta data
info = {
"displacement_number": i,
"displacement_number": idx,
"original_structure": structure,
"supercell_matrix": supercell_matrix,
"displaced_structure": displacement,
Expand All @@ -302,7 +302,7 @@ def run_phonon_displacements(
)

phonon_jobs.append(phonon_job)
outputs["displacement_number"].append(i)
outputs["displacement_number"].append(idx)
outputs["uuids"].append(phonon_job.output.uuid)
outputs["dirs"].append(phonon_job.output.dir_name)
outputs["forces"].append(phonon_job.output.output.forces)
Expand Down
4 changes: 3 additions & 1 deletion src/atomate2/common/powerups.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,9 @@
from jobflow import Flow, Maker


def add_metadata_to_flow(flow, additional_fields: dict, class_filter: Maker) -> Flow:
def add_metadata_to_flow(
flow: Flow, additional_fields: dict, class_filter: Maker
) -> Flow:
"""
Return the flow with additional field(metadata) to the task doc.
Expand Down
38 changes: 19 additions & 19 deletions src/atomate2/common/schemas/cclib.py
Original file line number Diff line number Diff line change
Expand Up @@ -256,7 +256,7 @@ def from_logfile(

@requires(cclib, "cclib_calculate requires cclib to be installed.")
def cclib_calculate(
cclib_obj,
cclib_obj: Any,
method: str,
cube_file: Union[Path, str],
proatom_dir: Union[Path, str],
Expand Down Expand Up @@ -310,28 +310,28 @@ def cclib_calculate(
vol = volume.read_from_cube(str(cube_file))

if method == "bader":
m = Bader(cclib_obj, vol)
_method = Bader(cclib_obj, vol)
elif method == "bickelhaupt":
m = Bickelhaupt(cclib_obj)
_method = Bickelhaupt(cclib_obj)
elif method == "cpsa":
m = CSPA(cclib_obj)
_method = CSPA(cclib_obj)
elif method == "ddec6":
m = DDEC6(cclib_obj, vol, str(proatom_dir))
_method = DDEC6(cclib_obj, vol, str(proatom_dir))
elif method == "density":
m = Density(cclib_obj)
_method = Density(cclib_obj)
elif method == "hirshfeld":
m = Hirshfeld(cclib_obj, vol, str(proatom_dir))
_method = Hirshfeld(cclib_obj, vol, str(proatom_dir))
elif method == "lpa":
m = LPA(cclib_obj)
_method = LPA(cclib_obj)
elif method == "mbo":
m = MBO(cclib_obj)
_method = MBO(cclib_obj)
elif method == "mpa":
m = MPA(cclib_obj)
_method = MPA(cclib_obj)
else:
raise ValueError(f"{method} is not supported.")
raise ValueError(f"{method=} is not supported.")

try:
m.calculate()
_method.calculate()
except AttributeError:
return None

Expand All @@ -351,20 +351,20 @@ def cclib_calculate(
]
calc_attributes = {}
for attribute in avail_attributes:
if hasattr(m, attribute):
calc_attributes[attribute] = getattr(m, attribute)
if hasattr(_method, attribute):
calc_attributes[attribute] = getattr(_method, attribute)
return calc_attributes


def _get_homos_lumos(
moenergies: list[list[float]], homo_indices: list[int]
mo_energies: list[list[float]], homo_indices: list[int]
) -> tuple[list[float], Optional[list[float]], Optional[list[float]]]:
"""
Calculate the HOMO, LUMO, and HOMO-LUMO gap energies in eV.
Parameters
----------
moenergies
mo_energies
List of MO energies. For restricted calculations, List[List[float]] is
length one. For unrestricted, it is length two.
homo_indices
Expand All @@ -380,13 +380,13 @@ def _get_homos_lumos(
The HOMO-LUMO gaps (eV), calculated as LUMO_alpha-HOMO_alpha and
LUMO_beta-HOMO_beta
"""
homo_energies = [moenergies[i][h] for i, h in enumerate(homo_indices)]
homo_energies = [mo_energies[i][h] for i, h in enumerate(homo_indices)]
# Make sure that the HOMO+1 (i.e. LUMO) is in moenergies (sometimes virtual
# orbitals aren't printed in the output)
for i, h in enumerate(homo_indices):
if len(moenergies[i]) < h + 2:
if len(mo_energies[i]) < h + 2:
return homo_energies, None, None
lumo_energies = [moenergies[i][h + 1] for i, h in enumerate(homo_indices)]
lumo_energies = [mo_energies[i][h + 1] for i, h in enumerate(homo_indices)]
homo_lumo_gaps = [
lumo_energies[i] - homo_energies[i] for i in range(len(homo_energies))
]
Expand Down
37 changes: 20 additions & 17 deletions src/atomate2/common/schemas/defects.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""General schemas for defect workflow outputs."""

import logging
from collections.abc import Sequence
from typing import Any, Callable, Optional, Union

import numpy as np
Expand Down Expand Up @@ -239,8 +240,8 @@ def from_task_outputs(
def get_ent(
struct: Structure,
energy: float,
dir_name,
uuid,
dir_name: str,
uuid: str,
) -> ComputedStructureEntry:
return ComputedStructureEntry(
structure=struct,
Expand Down Expand Up @@ -283,14 +284,16 @@ def from_entries(
"""

def find_entry(entries, uuid) -> tuple[int, ComputedStructureEntry]:
def find_entry(
entries: Sequence[ComputedStructureEntry], uuid: str
) -> tuple[int, ComputedStructureEntry]:
"""Find the entry with the given UUID."""
for itr, entry in enumerate(entries):
for idx, entry in enumerate(entries):
if entry.data["uuid"] == uuid:
return itr, entry
return idx, entry
raise ValueError(f"Could not find entry with UUID: {uuid}")

def dQ_entries(e1, e2) -> float: # noqa: N802
def dQ_entries(e1: ComputedStructureEntry, e2: ComputedStructureEntry) -> float: # noqa: N802
"""Get the displacement between two entries."""
return get_dQ(e1.structure, e2.structure)

Expand Down Expand Up @@ -338,23 +341,23 @@ def dQ_entries(e1, e2) -> float: # noqa: N802
relaxed_index2=idx2,
)

def get_taskdocs(self) -> list[list[TaskDoc]]:
def get_taskdocs(self) -> tuple[list[TaskDoc], list[TaskDoc]]:
"""Get the distorted task documents."""

def remove_host_name(dir_name) -> str:
def remove_host_name(dir_name: str) -> str:
return dir_name.split(":")[-1]

return [
[
TaskDoc.from_directory(remove_host_name(dir_name))
for dir_name in self.static_dirs1
],
[
TaskDoc.from_directory(remove_host_name(dir_name))
for dir_name in self.static_dirs2
],
static1_task_docs = [
TaskDoc.from_directory(remove_host_name(dir_name))
for dir_name in self.static_dirs1
]
static2_task_docs = [
TaskDoc.from_directory(remove_host_name(dir_name))
for dir_name in self.static_dirs2
]

return static1_task_docs, static2_task_docs


def sort_pos_dist(
list_in: list[Any],
Expand Down
12 changes: 7 additions & 5 deletions src/atomate2/cp2k/powerups.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,12 +47,14 @@ def update_user_input_settings(
A copy of the input flow/job/maker modified to use the updated input settings.
"""

# Convert nested dictionary updates for cp2k inpt settings
# Convert nested dictionary updates for cp2k input settings
# into dict_mod update format
def nested_to_dictmod(d, kk="input_set_generator->user_input_settings") -> dict:
def nested_to_dictmod(
dct: dict, kk: str = "input_set_generator->user_input_settings"
) -> dict:
d2 = {}
for k, v in d.items():
k2 = kk + f"->{k}"
for k, v in dct.items():
k2 = f"{kk}->{k}"
if isinstance(v, dict):
d2.update(nested_to_dictmod(v, kk=k2))
else:
Expand Down Expand Up @@ -143,7 +145,7 @@ def update_user_kpoints_settings(


def add_metadata_to_flow(
flow, additional_fields: dict, class_filter: Maker = BaseCp2kMaker
flow: Flow, additional_fields: dict, class_filter: Maker = BaseCp2kMaker
) -> Flow:
"""
Return the Cp2k flow with additional field(metadata) to the task doc.
Expand Down
6 changes: 3 additions & 3 deletions src/atomate2/cp2k/schemas/calc_types/_generate.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

from itertools import product
from pathlib import Path
from typing import Any

from monty.serialization import loadfn

Expand All @@ -24,14 +25,13 @@
_RUN_TYPES.append(f"{rt}{vdw}{u}") # noqa: PERF401


def get_enum_source(enum_name, doc, items) -> str:
def get_enum_source(enum_name: str, doc: str, items: dict[str, Any]) -> str:
header = f"""
class {enum_name}(ValueEnum):
\"\"\" {doc} \"\"\"\n
"""
items = [f' {const} = "{val}"' for const, val in items.items()]

return header + "\n".join(items)
return header + "\n".join(f' {key} = "{val}"' for key, val in items.items())


run_type_enum = get_enum_source(
Expand Down
4 changes: 2 additions & 2 deletions src/atomate2/cp2k/schemas/calc_types/utils.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
"""Module to define various calculation types as Enums for CP2K."""

from collections.abc import Iterable
from collections.abc import Iterable, Sequence
from pathlib import Path

from monty.serialization import loadfn
Expand All @@ -22,7 +22,7 @@ def run_type(inputs: dict) -> RunType:
"""
dft = inputs.get("dft")

def _variant_equal(v1, v2) -> bool:
def _variant_equal(v1: Sequence, v2: Sequence) -> bool:
"""Determine if two run_types are equal."""
if isinstance(v1, str) and isinstance(v2, str):
return v1.strip().upper() == v2.strip().upper()
Expand Down
2 changes: 1 addition & 1 deletion src/atomate2/cp2k/schemas/calculation.py
Original file line number Diff line number Diff line change
Expand Up @@ -258,7 +258,7 @@ class Calculation(BaseModel):
has_cp2k_completed: Status = Field(
None, description="Whether CP2K completed the calculation successfully"
)
input: CalculationInput = Field( # noqa: A003
input: CalculationInput = Field(
None, description="CP2K input settings for the calculation"
)
output: CalculationOutput = Field(None, description="The CP2K calculation output")
Expand Down
2 changes: 1 addition & 1 deletion src/atomate2/cp2k/schemas/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -244,7 +244,7 @@ class TaskDocument(StructureMetadata, MoleculeMetadata):
completed_at: Optional[str] = Field(
None, description="Timestamp for when this task was completed"
)
input: Optional[InputSummary] = Field( # noqa: A003
input: Optional[InputSummary] = Field(
None, description="The input to the first calculation"
)
output: Optional[OutputSummary] = Field(
Expand Down
Loading

0 comments on commit 4869a35

Please sign in to comment.