From 7deddbf8db702e3537f02f406ca5b67e824202ec Mon Sep 17 00:00:00 2001 From: mjwen Date: Fri, 30 Jun 2023 17:13:07 -0500 Subject: [PATCH] Add future annotation to use `list` for typing --- src/atomate2/common/schemas/elastic.py | 33 ++++++++++++++------------ 1 file changed, 18 insertions(+), 15 deletions(-) diff --git a/src/atomate2/common/schemas/elastic.py b/src/atomate2/common/schemas/elastic.py index 7095aa6015..d54fb58f4a 100644 --- a/src/atomate2/common/schemas/elastic.py +++ b/src/atomate2/common/schemas/elastic.py @@ -1,9 +1,9 @@ """Schemas for elastic tensor fitting and related properties.""" +from __future__ import annotations from copy import deepcopy -from typing import List, Optional +from typing import TYPE_CHECKING -from emmet.core.math import Matrix3D, MatrixVoigt from pydantic import BaseModel, Field from pymatgen.analysis.elasticity import ( Deformation, @@ -12,12 +12,15 @@ Strain, Stress, ) -from pymatgen.core import Structure from pymatgen.core.tensors import TensorMapping from pymatgen.symmetry.analyzer import SpacegroupAnalyzer from atomate2 import SETTINGS +if TYPE_CHECKING: + from emmet.core.math import Matrix3D, MatrixVoigt + from pymatgen.core import Structure + __all__ = [ "DerivedProperties", "FittingData", @@ -83,20 +86,20 @@ class DerivedProperties(BaseModel): class FittingData(BaseModel): """Data used to fit elastic tensors.""" - cauchy_stresses: List[Matrix3D] = Field( + cauchy_stresses: list[Matrix3D] = Field( None, description="The Cauchy stresses used to fit the elastic tensor." ) - strains: List[Matrix3D] = Field( + strains: list[Matrix3D] = Field( None, description="The strains used to fit the elastic tensor." ) - pk_stresses: List[Matrix3D] = Field( + pk_stresses: list[Matrix3D] = Field( None, description="The Piola-Kirchoff stresses used to fit the elastic tensor." ) - deformations: List[Matrix3D] = Field( + deformations: list[Matrix3D] = Field( None, description="The deformations corresponding to each strain state." ) - uuids: List[str] = Field(None, description="The uuids of the deformation jobs.") - job_dirs: List[str] = Field( + uuids: list[str] = Field(None, description="The uuids of the deformation jobs.") + job_dirs: list[str] = Field( None, description="The directories where the deformation jobs were run." ) @@ -141,13 +144,13 @@ class ElasticDocument(BaseModel): def from_stresses( cls, structure: Structure, - stresses: List[Stress], - deformations: List[Deformation], - uuids: List[str], - job_dirs: List[str], + stresses: list[Stress], + deformations: list[Deformation], + uuids: list[str], + job_dirs: list[str], fitting_method: str = SETTINGS.ELASTIC_FITTING_METHOD, - order: Optional[int] = None, - equilibrium_stress: Optional[Matrix3D] = None, + order: int | None = None, + equilibrium_stress: Matrix3D | None = None, symprec: float = SETTINGS.SYMPREC, allow_elastically_unstable_structs: bool = True, ):