Initial streaming support. #89

Merged: 128 commits, Sep 28, 2022

Commits (128)
9d33045
Add evaluation.
Mar 24, 2022
f7d708d
Small fixes.
Mar 24, 2022
74d0eb0
Add unit tests + eval mask fixes.
Mar 25, 2022
5589bf5
Update detection unit tests.
Mar 25, 2022
244100d
Fix typing + update pyproject.
Mar 25, 2022
5314ee4
Run autoflake.
Mar 25, 2022
4c0e6bb
Add ROI pruning.
Mar 26, 2022
8e59867
Remove arg.
Mar 26, 2022
343c219
Fix typo.
Mar 26, 2022
6a34f5b
Speed up argoverse maps.
Mar 28, 2022
23b2fec
Speed up evaluation.
Mar 28, 2022
ae42718
Small fixes.
Mar 28, 2022
a02e45c
Fix lint.
Mar 28, 2022
2f208d5
Small lint fixes.
Mar 28, 2022
c7797a5
Fix filtering.
benjaminrwilson Mar 29, 2022
0b15248
Small fixes.
benjaminrwilson Mar 29, 2022
c7a66de
Fix enums.
benjaminrwilson Mar 29, 2022
957e016
Remove unused lines.
benjaminrwilson Mar 29, 2022
b8ab038
Mypy fixes.
benjaminrwilson Mar 29, 2022
ceff8fa
Fix click mypy error.
benjaminrwilson Mar 29, 2022
c314775
Pytype fixes.
benjaminrwilson Mar 30, 2022
2e7af38
Fix pytype.
benjaminrwilson Mar 30, 2022
43f938a
Remove pytype.
benjaminrwilson Mar 30, 2022
90c791f
Small typing fixes.
benjaminrwilson Mar 30, 2022
515c6f2
Add unit tests.
benjaminrwilson Mar 31, 2022
8c5d1bc
Fix typing.
benjaminrwilson Mar 31, 2022
e7f2925
Remove click typing issue.
benjaminrwilson Mar 31, 2022
14f1141
Fix mypy.
benjaminrwilson Mar 31, 2022
7d4139c
Detection eval speed up.
benjaminrwilson Apr 2, 2022
c9de7cb
Rewrite detection eval for major speedup.
benjaminrwilson Apr 2, 2022
59a65f8
Typing fixes.
benjaminrwilson Apr 2, 2022
4327278
Typing fixes.
benjaminrwilson Apr 2, 2022
df479b9
Switch from record arrays to numpy arrays.
benjaminrwilson Apr 2, 2022
4c40a5f
Temp changes.
benjaminrwilson Apr 2, 2022
eccd4f0
Improve readability.
benjaminrwilson Apr 2, 2022
72dc702
Add comments.
benjaminrwilson Apr 2, 2022
fddd899
Modularize evaluate.
benjaminrwilson Apr 2, 2022
faa8317
Additional speedups.
benjaminrwilson Apr 3, 2022
f81ae17
Cleanup code.
benjaminrwilson Apr 3, 2022
d49105a
Additional speedup.
benjaminrwilson Apr 3, 2022
195b46a
Add roi pruning back.
benjaminrwilson Apr 3, 2022
dc2a954
Add multiprocessing.
benjaminrwilson Apr 3, 2022
06e2e96
Add verbosity.
benjaminrwilson Apr 3, 2022
54901b2
Mypy fixes.
benjaminrwilson Apr 3, 2022
856b4d4
Update cuboid fields.
benjaminrwilson Apr 3, 2022
be5a89c
Lint fixes.
benjaminrwilson Apr 3, 2022
5cb43ea
Fix map tutorial issues.
benjaminrwilson Apr 3, 2022
5214d9d
Merge branch 'map_notebook-fixes'
benjaminrwilson Apr 3, 2022
967cc39
Add test log.
benjaminrwilson Apr 3, 2022
ae8aeb5
Revert strings.
benjaminrwilson Apr 3, 2022
d8a0ae9
Remove outputs.
benjaminrwilson Apr 3, 2022
a343292
Merge branch 'map_notebook-fixes'
benjaminrwilson Apr 3, 2022
87239c3
Address missing detection edge cases.
benjaminrwilson Apr 4, 2022
9a2acc6
Address jhony comments.
benjaminrwilson Apr 5, 2022
ba6c337
Update docstring.
benjaminrwilson Apr 5, 2022
aa036aa
Clean docstrings.
benjaminrwilson Apr 5, 2022
a6b345d
Change roi method.
benjaminrwilson Apr 5, 2022
e725147
Clean up roi method.
benjaminrwilson Apr 5, 2022
ce896f0
Update roi returns.
benjaminrwilson Apr 5, 2022
aeea617
Autoflake.:
benjaminrwilson Apr 5, 2022
da897c1
Fix lint.
benjaminrwilson Apr 5, 2022
c459c26
Fix lint.
benjaminrwilson Apr 5, 2022
ca5d644
Update detection limiting logic.
Apr 7, 2022
375aaba
Fix indexing.
Apr 7, 2022
5113202
Fix tuple return.
Apr 7, 2022
97235a0
Merge https://github.com/argoai/av2-api into main
Apr 7, 2022
5a4c76d
Update CI.
Apr 7, 2022
07d3fbf
Add ROI unit tests.
Apr 9, 2022
0cd98f9
Remove val identity.
Apr 9, 2022
0ef221b
Fix import.
Apr 9, 2022
0b00475
Remove unused import.
Apr 9, 2022
82da97f
Update column names.
Apr 9, 2022
548ea93
Update eval.py
benjaminrwilson Apr 11, 2022
d83f8eb
Add README.md.
benjaminrwilson Apr 12, 2022
89163fa
Update README.
benjaminrwilson Apr 12, 2022
a7ea318
Update README.
benjaminrwilson Apr 12, 2022
8cbcb48
Update README.md
benjaminrwilson Apr 12, 2022
5753fc6
Update README.
benjaminrwilson Apr 12, 2022
db85849
Update README.
benjaminrwilson Apr 12, 2022
37d63ac
Update README.
benjaminrwilson Apr 12, 2022
94fb9de
Update README.
benjaminrwilson Apr 12, 2022
bc12db8
Update README.
benjaminrwilson Apr 12, 2022
f803216
Update README.md
benjaminrwilson Apr 12, 2022
2abf31c
Update README.md
benjaminrwilson Apr 12, 2022
f3d3f4d
Update README.md
benjaminrwilson Apr 12, 2022
b7c04f0
Updates.
Apr 15, 2022
c7c4026
Merge.
benjaminrwilson May 7, 2022
55fbff5
Merge branch 'argoai:main' into main
benjaminrwilson May 17, 2022
8eb4935
Merge branch 'argoai:main' into main
benjaminrwilson Jul 20, 2022
e9b141e
Streaming test.
benjaminrwilson Jul 20, 2022
f3bdf43
Add streaming support.
benjaminrwilson Jul 20, 2022
29ace45
Updates.
benjaminrwilson Jul 20, 2022
ab2d361
Updates.
benjaminrwilson Jul 20, 2022
2f37df2
Updates.
benjaminrwilson Jul 20, 2022
6249c87
Update upath.
Aug 8, 2022
9a06e2c
Remove file.
Aug 8, 2022
22eb6f9
Bump package.
Aug 8, 2022
6e295e7
Update upaths.
Aug 8, 2022
e6184b3
Add additional upath.
Aug 8, 2022
e2109bf
Add upaths.
Aug 8, 2022
40194b8
Updates.
Aug 8, 2022
21a3140
Autoflake.
Aug 8, 2022
7eed4a7
Small updates.
Aug 9, 2022
ae9ef75
Add range breakdown.
Aug 15, 2022
2eed914
Reorder indices.
Aug 15, 2022
38afd16
Update multiprocessing.
Aug 16, 2022
3897a79
Switch to threading
Aug 16, 2022
1d155f5
Small fixes.
Aug 17, 2022
cb33dbf
Merge.
Aug 26, 2022
072e596
Remove comments.
Aug 26, 2022
3fef470
Revert eval.
benjaminrwilson Aug 27, 2022
f915243
Update eval.py
benjaminrwilson Aug 28, 2022
2baa02f
Fix lint.
benjaminrwilson Aug 28, 2022
dc68a70
Fix lint.
benjaminrwilson Sep 3, 2022
af69cb1
Merge.
benjaminrwilson Sep 22, 2022
2e6d394
Fix lint.
benjaminrwilson Sep 22, 2022
0f7e248
Try numpy bound.
benjaminrwilson Sep 23, 2022
6765828
ci 3.10 pin to 3.10.6: https://github.com/python/mypy/issues/13627
benjaminrwilson Sep 25, 2022
d9e6245
Remove pin.
benjaminrwilson Sep 25, 2022
0f6f2ec
mypy fix attempt #2.
benjaminrwilson Sep 25, 2022
d1203f6
Update nox.
benjaminrwilson Sep 25, 2022
09ccefa
Updates.
benjaminrwilson Sep 25, 2022
b956b3e
Try 3.10.5
benjaminrwilson Sep 25, 2022
2e8f021
Pin numpy.
benjaminrwilson Sep 25, 2022
6eddb08
Upper bound numpy.
benjaminrwilson Sep 25, 2022
3516586
3.10.6
benjaminrwilson Sep 25, 2022
a131956
Update ci.yml
benjaminrwilson Sep 28, 2022
47aaa48
Update noxfile.py
benjaminrwilson Sep 28, 2022
1 change: 1 addition & 0 deletions conda/environment.yml
@@ -16,3 +16,4 @@ dependencies:
- pyproj
- rich
- scipy
- universal_pathlib
5 changes: 2 additions & 3 deletions integration_tests/verify_tbv_download.py
@@ -140,7 +140,6 @@ def verify_log_map(data_root: Path, log_id: str) -> None:
# every log should have one and only one raster height map. (Note: season is stripped from uuid here).
ground_height_raster_fpaths = list(log_map_dirpath.glob("*_ground_height_surface____*.npy"))
assert len(ground_height_raster_fpaths) == 1
ground_height_raster_fpath = ground_height_raster_fpaths[0]

# every log should have a Sim(2) mapping from raster grid coordinates to city coordinates.
Sim2_fpaths = list(log_map_dirpath.glob("*___img_Sim2_city.json"))
@@ -181,10 +180,10 @@ def verify_log_map(data_root: Path, log_id: str) -> None:
assert left_lane_boundary.ndim == 2 and left_lane_boundary.shape[1] == 3

# load every pedestrian crossing
pcs = avm.get_scenario_ped_crossings()
avm.get_scenario_ped_crossings()

# load every drivable area
das = avm.get_scenario_vector_drivable_areas()
avm.get_scenario_vector_drivable_areas()


def verify_logs_using_dataloader(data_root: Path, log_ids: List[str]) -> None:
1 change: 1 addition & 0 deletions setup.cfg
@@ -31,6 +31,7 @@ install_requires =
pyproj
rich
scipy
universal_pathlib

package_dir=
=src
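The dependency added above in both conda/environment.yml and setup.cfg is universal_pathlib, which provides UPath: a pathlib-compatible path class backed by fsspec filesystems. A minimal sketch of the idea, assuming an fsspec backend such as s3fs is installed; the helper name and bucket below are hypothetical and not part of this PR:

from pathlib import Path
from typing import Union

from upath import UPath


def resolve_dataset_dir(root: str) -> Union[Path, UPath]:
    """Return a local Path or a remote UPath depending on the URI scheme."""
    if "://" in root:
        # Extra keyword arguments are forwarded to the fsspec filesystem,
        # e.g. anon=True for a public S3 bucket via s3fs.
        return UPath(root, anon=True)
    return Path(root)


# Usage (illustrative; the bucket below is a placeholder):
# dataset_dir = resolve_dataset_dir("s3://example-bucket/av2/sensor/val")
# log_dirs = list(dataset_dir.glob("*"))  # same glob / "/" / open API as pathlib.Path

Because UPath mirrors the pathlib API, most of the remaining changes in this PR reduce to widening type hints to Union[Path, UPath] and calling the path object's own open() method instead of the builtin open().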
12 changes: 9 additions & 3 deletions src/av2/evaluation/detection/eval.py
@@ -51,8 +51,12 @@
in addition to the mean statistics average across all classes, and P refers to the number of included statistics,
e.g. AP, ATE, ASE, AOE, CDS by default.
"""
import itertools
import logging
from multiprocessing import get_context
import multiprocessing as mp
import warnings
from math import inf
from statistics import mean
from typing import Dict, Final, List, Optional, Tuple

import numpy as np
@@ -72,6 +76,8 @@
from av2.utils.io import TimestampedCitySE3EgoPoses
from av2.utils.typing import NDArrayBool, NDArrayFloat

warnings.filterwarnings("ignore", module="google")

TP_ERROR_COLUMNS: Final[Tuple[str, ...]] = tuple(x.value for x in TruePositiveErrorNames)
DTS_COLUMN_NAMES: Final[Tuple[str, ...]] = tuple(ORDERED_CUBOID_COL_NAMES) + ("score",)
GTS_COLUMN_NAMES: Final[Tuple[str, ...]] = tuple(ORDERED_CUBOID_COL_NAMES) + ("num_interior_pts",)
@@ -161,7 +167,7 @@ def evaluate(
args_list.append(args)

logger.info("Starting evaluation ...")
with get_context("spawn").Pool(processes=n_jobs) as p:
with mp.get_context("spawn").Pool(processes=n_jobs) as p:
outputs: Optional[List[Tuple[NDArrayFloat, NDArrayFloat]]] = p.starmap(accumulate, args_list)

if outputs is None:
@@ -192,7 +198,7 @@ def summarize_metrics(
dts: (N,14) Table of detections.
gts: (M,15) Table of ground truth annotations.
cfg: Detection configuration.

Returns:
The summary metrics.
"""
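The import hunk in eval.py above replaces "from multiprocessing import get_context" with a plain "import multiprocessing as mp", while evaluation keeps the spawn start method for its worker pool. A small, self-contained sketch of that pattern; accumulate_stub is an illustrative stand-in for the real per-log accumulate step, not code from this PR:

import multiprocessing as mp
from typing import List, Tuple


def accumulate_stub(log_id: str, score: float) -> Tuple[str, float]:
    """Stand-in for a per-log accumulation step."""
    return log_id, score * 2.0


if __name__ == "__main__":  # required when using the "spawn" start method in a script
    args_list: List[Tuple[str, float]] = [("log-a", 0.5), ("log-b", 0.75)]
    with mp.get_context("spawn").Pool(processes=2) as pool:
        outputs = pool.starmap(accumulate_stub, args_list)
    print(outputs)  # [('log-a', 1.0), ('log-b', 1.5)]

The spawn context starts fresh interpreter processes rather than forking, which tends to be safer when the parent process holds threads or open network handles, at the cost of requiring picklable arguments and the __main__ guard.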
12 changes: 7 additions & 5 deletions src/av2/evaluation/detection/utils.py
@@ -13,11 +13,12 @@
import logging
from dataclasses import dataclass
from pathlib import Path
from typing import Dict, List, Optional, Tuple
from typing import Dict, List, Optional, Tuple, Union

import numpy as np
from joblib import Parallel, delayed
from scipy.spatial.distance import cdist
from upath import UPath

from av2.evaluation.detection.constants import (
MAX_NORMALIZED_ASE,
@@ -63,7 +64,7 @@ class DetectionCfg:
affinity_thresholds_m: Tuple[float, ...] = (0.5, 1.0, 2.0, 4.0)
affinity_type: AffinityType = AffinityType.CENTER
categories: Tuple[str, ...] = tuple(x.value for x in CompetitionCategories)
dataset_dir: Optional[Path] = None
dataset_dir: Optional[Union[Path, UPath]] = None
eval_only_roi_instances: bool = True
filter_metric: FilterMetricType = FilterMetricType.EUCLIDEAN
max_num_dts_per_category: int = 100
@@ -315,7 +316,7 @@ def compute_average_precision(
# Evaluate precision at different recalls.
precision_interpolated: NDArrayFloat = np.interp(recall_interpolated, recall, precision, right=0)

average_precision: float = np.mean(precision_interpolated)
average_precision: float = float(np.mean(precision_interpolated))
return average_precision, precision_interpolated


@@ -437,7 +438,7 @@ def compute_evaluated_gts_mask(


def load_mapped_avm_and_egoposes(
log_ids: List[str], dataset_dir: Path
log_ids: List[str], dataset_dir: Union[Path, UPath]
) -> Tuple[Dict[str, ArgoverseStaticMap], Dict[str, TimestampedCitySE3EgoPoses]]:
"""Load the maps and egoposes for each log in the dataset directory.

@@ -452,9 +453,10 @@ def load_mapped_avm_and_egoposes(
RuntimeError: If the process for loading maps and timestamped egoposes fails.
"""
log_id_to_timestamped_poses = {log_id: read_city_SE3_ego(dataset_dir / log_id) for log_id in log_ids}
avms: Optional[List[ArgoverseStaticMap]] = Parallel(n_jobs=-1, backend="threading", verbose=1)(
avms: Optional[List[ArgoverseStaticMap]] = Parallel(n_jobs=-1, backend="threading")(
delayed(ArgoverseStaticMap.from_map_dir)(dataset_dir / log_id / "map", build_raster=True) for log_id in log_ids
)

if avms is None:
raise RuntimeError("Map and egopose loading has failed!")
log_id_to_avm = {log_ids[i]: avm for i, avm in enumerate(avms)}
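load_mapped_avm_and_egoposes above keeps joblib's threading backend for map loading (only the verbose flag is dropped). A hedged sketch of that pattern; load_map_stub stands in for ArgoverseStaticMap.from_map_dir and is not part of the PR:

from typing import Dict, List

from joblib import Parallel, delayed


def load_map_stub(log_id: str) -> Dict[str, str]:
    """Illustrative stand-in for loading one log's static map."""
    return {"log_id": log_id}


log_ids: List[str] = ["log-a", "log-b", "log-c"]
maps = Parallel(n_jobs=-1, backend="threading")(
    delayed(load_map_stub)(log_id) for log_id in log_ids
)
log_id_to_map = {log_ids[i]: m for i, m in enumerate(maps)}

A thread-based backend avoids pickling the loaded map objects and suits I/O-bound work, which matters once dataset_dir may be a remote UPath rather than a local directory.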
6 changes: 4 additions & 2 deletions src/av2/geometry/sim2.py
@@ -14,8 +14,10 @@
import numbers
from dataclasses import dataclass
from pathlib import Path
from typing import Union

import numpy as np
from upath import UPath

import av2.utils.io as io_utils
from av2.utils.helpers import assert_np_array_shape
@@ -187,9 +189,9 @@ def save_as_json(self, save_fpath: Path) -> None:
io_utils.save_json_dict(save_fpath, dict_for_serialization)

@classmethod
def from_json(cls, json_fpath: Path) -> Sim2:
def from_json(cls, json_fpath: Union[Path, UPath]) -> Sim2:
"""Generate class inst. from a JSON file containing Sim(2) parameters as flattened matrices (row-major)."""
with open(json_fpath, "r") as f:
with json_fpath.open("r") as f:
json_data = json.load(f)

R: NDArrayFloat = np.array(json_data["R"]).reshape(2, 2)
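Sim2.from_json above now calls json_fpath.open("r") instead of the builtin open(json_fpath); the method dispatches to whatever filesystem backs the path, so the same code reads a local Path or a remote UPath. A minimal sketch of the pattern, with an illustrative function name:

import json
from pathlib import Path
from typing import Any, Dict, Union

from upath import UPath


def read_params(json_fpath: Union[Path, UPath]) -> Dict[str, Any]:
    """Load a small JSON parameter file from local or remote storage."""
    with json_fpath.open("r") as f:
        data: Dict[str, Any] = json.load(f)
    return data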
15 changes: 9 additions & 6 deletions src/av2/map/map_api.py
@@ -20,16 +20,18 @@
from pathlib import Path
from typing import Dict, Final, List, Optional, Tuple, Union

import fsspec
import numpy as np
from upath import UPath

import av2.geometry.interpolate as interp_utils
import av2.utils.dilation_utils as dilation_utils
import av2.utils.io as io_utils
import av2.utils.raster as raster_utils
from av2.geometry.sim2 import Sim2
from av2.map.drivable_area import DrivableArea
from av2.map.lane_segment import LaneSegment
from av2.map.pedestrian_crossing import PedestrianCrossing
from av2.utils import io
from av2.utils.typing import NDArrayBool, NDArrayByte, NDArrayFloat, NDArrayInt

# 1 meter resolution is insufficient for the online-generated drivable area and ROI raster grids
@@ -105,7 +107,7 @@ class GroundHeightLayer(RasterMapLayer):
"""

@classmethod
def from_file(cls, log_map_dirpath: Path) -> GroundHeightLayer:
def from_file(cls, log_map_dirpath: Union[Path, UPath]) -> GroundHeightLayer:
"""Load ground height values (w/ values at 30 cm resolution) from .npy file, and associated Sim(2) mapping.

Note: ground height values are stored on disk as a float16 2d-array, but cast to float32 once loaded for
@@ -130,7 +132,8 @@ def from_file(cls, log_map_dirpath: Path) -> GroundHeightLayer:
raise RuntimeError("Sim(2) mapping from city to image coordinates is missing")

# load the file with rasterized values
ground_height_array: NDArrayFloat = np.load(ground_height_npy_fpaths[0])
with ground_height_npy_fpaths[0].open("rb") as f:
ground_height_array: NDArrayFloat = np.load(f)

array_Sim2_city = Sim2.from_json(Sim2_json_fpaths[0])

@@ -311,7 +314,7 @@ class ArgoverseStaticMap:
raster_ground_height_layer: Optional[GroundHeightLayer]

@classmethod
def from_json(cls, static_map_path: Path) -> ArgoverseStaticMap:
def from_json(cls, static_map_path: Union[Path, UPath]) -> ArgoverseStaticMap:
"""Instantiate an Argoverse static map object (without raster data) from a JSON file containing map data.

Args:
@@ -322,7 +325,7 @@ def from_json(cls, static_map_path: Path) -> ArgoverseStaticMap:
An Argoverse HD map.
"""
log_id = static_map_path.stem.split("log_map_archive_")[1]
vector_data = io_utils.read_json_file(static_map_path)
vector_data = io.read_json_file(static_map_path)

vector_drivable_areas = {da["id"]: DrivableArea.from_dict(da) for da in vector_data["drivable_areas"].values()}
vector_lane_segments = {ls["id"]: LaneSegment.from_dict(ls) for ls in vector_data["lane_segments"].values()}
@@ -346,7 +349,7 @@ def from_json(cls, static_map_path: Path) -> ArgoverseStaticMap:
)

@classmethod
def from_map_dir(cls, log_map_dirpath: Path, build_raster: bool = False) -> ArgoverseStaticMap:
def from_map_dir(cls, log_map_dirpath: Union[Path, UPath], build_raster: bool = False) -> ArgoverseStaticMap:
"""Instantiate an Argoverse map object from data stored within a map data directory.

Note: the ground height surface file and associated coordinate mapping is not provided for the
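GroundHeightLayer.from_file above wraps np.load in ground_height_npy_fpaths[0].open("rb"); np.load accepts an open binary file object, so the raster can be streamed from remote storage as well as read from disk. A short sketch under that assumption, with illustrative names:

from pathlib import Path
from typing import Union

import numpy as np
from upath import UPath


def load_ground_height(npy_fpath: Union[Path, UPath]) -> np.ndarray:
    """Read a float16 ground-height raster from .npy and upcast to float32."""
    with npy_fpath.open("rb") as f:
        raster: np.ndarray = np.load(f)
    return raster.astype(np.float32)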
2 changes: 1 addition & 1 deletion src/av2/structures/cuboid.py
@@ -318,7 +318,7 @@ def project_to_cam(

# Sort by z-order to respect visibility in the scene.
# i.e, closer objects cuboids should be drawn on top of farther objects.
z_orders: NDArrayFloat = np.argsort(-z_buffer)
z_orders: NDArrayInt = np.argsort(-z_buffer)

cuboids_vertices_cam = cuboids_vertices_cam[z_orders]
front_face_indices = [0, 1, 2, 3, 0]
14 changes: 8 additions & 6 deletions src/av2/utils/io.py
@@ -10,6 +10,7 @@
import numpy as np
import pandas as pd
from pyarrow import feather
from upath import UPath

import av2.geometry.geometry as geometry_utils
from av2.geometry.se3 import SE3
@@ -22,7 +23,7 @@
SensorPosesMapping = Dict[str, SE3]


def read_feather(path: Path, columns: Optional[Tuple[str, ...]] = None) -> pd.DataFrame:
def read_feather(path: Union[Path, UPath], columns: Optional[Tuple[str, ...]] = None) -> pd.DataFrame:
"""Read Apache Feather data from a .feather file.

AV2 uses .feather to serialize much of its data. This function handles the deserialization
@@ -36,7 +37,8 @@ def read_feather(path: Path, columns: Optional[Tuple[str, ...]] = None) -> pd.DataFrame:
Returns:
(N,len(columns)) Apache Feather data represented as a `pandas` DataFrame.
"""
data: pd.DataFrame = feather.read_feather(path, columns=columns)
with path.open("rb") as f:
data: pd.DataFrame = feather.read_feather(f, columns=columns)
return data


@@ -116,7 +118,7 @@ def read_ego_SE3_sensor(log_dir: Path) -> SensorPosesMapping:
return sensor_name_to_pose


def read_city_SE3_ego(log_dir: Path) -> TimestampedCitySE3EgoPoses:
def read_city_SE3_ego(log_dir: Union[Path, UPath]) -> TimestampedCitySE3EgoPoses:
"""Read the egovehicle poses in the city reference frame.

The egovehicle city pose defines an SE3 transformation from the egovehicle reference frame to the city ref. frame.
@@ -148,7 +150,7 @@ def read_city_SE3_ego(log_dir: Path) -> TimestampedCitySE3EgoPoses:
Returns:
Mapping from egovehicle time (in nanoseconds) to egovehicle pose in the city reference frame.
"""
city_SE3_ego_path = Path(log_dir, "city_SE3_egovehicle.feather")
city_SE3_ego_path = log_dir / "city_SE3_egovehicle.feather"
city_SE3_ego = read_feather(city_SE3_ego_path)

quat_wxyz = city_SE3_ego.loc[:, ["qw", "qx", "qy", "qz"]].to_numpy()
@@ -206,7 +208,7 @@ def write_img(img_path: Path, img: NDArrayByte, channel_order: str = "RGB") -> None:
cv2.imwrite(str(img_path), img)


def read_json_file(fpath: Path) -> Dict[str, Any]:
def read_json_file(fpath: Union[Path, UPath]) -> Dict[str, Any]:
"""Load dictionary from JSON file.

Args:
Expand All @@ -215,7 +217,7 @@ def read_json_file(fpath: Path) -> Dict[str, Any]:
Returns:
Deserialized Python dictionary.
"""
with open(fpath, "rb") as f:
with fpath.open("rb") as f:
data: Dict[str, Any] = json.load(f)
return data

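read_feather above likewise opens the path explicitly and hands the file object to pyarrow, whose feather.read_feather accepts file-like sources. A hedged sketch of an equivalent standalone helper (the function name is illustrative):

from pathlib import Path
from typing import Optional, Tuple, Union

import pandas as pd
from pyarrow import feather
from upath import UPath


def read_feather_any(
    path: Union[Path, UPath], columns: Optional[Tuple[str, ...]] = None
) -> pd.DataFrame:
    """Deserialize an Apache Feather file from local or remote storage."""
    with path.open("rb") as f:
        data: pd.DataFrame = feather.read_feather(f, columns=columns)
    return data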
4 changes: 2 additions & 2 deletions tests/datasets/motion_forecasting/eval/test_metrics.py
@@ -183,7 +183,7 @@ def test_compute_brier_ade(
],
)
def test_compute_brier_ade_data_validation(
forecast_probabilities: NDArrayFloat, normalize: bool, expectation: AbstractContextManager # type: ignore
forecast_probabilities: NDArrayFloat, normalize: bool, expectation: AbstractContextManager # type: ignore
) -> None:
"""Test that test_compute_brier_ade raises the correct errors when inputs are invalid.

@@ -247,7 +247,7 @@ def test_compute_brier_fde(
],
)
def test_compute_brier_fde_data_validation(
forecast_probabilities: NDArrayFloat, normalize: bool, expectation: AbstractContextManager # type: ignore
forecast_probabilities: NDArrayFloat, normalize: bool, expectation: AbstractContextManager # type: ignore
) -> None:
"""Test that test_compute_brier_fde raises the correct errors when inputs are invalid.

2 changes: 1 addition & 1 deletion tests/datasets/motion_forecasting/eval/test_submission.py
@@ -39,7 +39,7 @@
ids=["valid", "wrong_shape_trajectory", "mismatched_trajectory_probability_shape"],
)
def test_challenge_submission_data_validation(
test_submission_dict: Dict[str, ScenarioPredictions], expectation: AbstractContextManager # type: ignore
test_submission_dict: Dict[str, ScenarioPredictions], expectation: AbstractContextManager # type: ignore
) -> None:
"""Test that validation of submitted trajectories works as expected during challenge submission initialization.
