Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add GPU test runner #242

Merged
merged 27 commits into from
Feb 8, 2023
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
Show all changes
27 commits
Select commit Hold shift + click to select a range
c0d07aa
Add `rstcheck` and `doc8`
michalk8 Feb 2, 2023
2b3021d
Pass `CUDA_VISIBLE_DEVICES` to `tox`
michalk8 Feb 2, 2023
6fb3e9d
Add GPU CI runner
michalk8 Feb 3, 2023
d63441d
Merge branch 'main' into feature/improve-pre-commits
michalk8 Feb 3, 2023
af3cc26
Fix Python version for GPU tests
michalk8 Feb 3, 2023
4572dd4
Try running without `tox`
michalk8 Feb 3, 2023
f6c9ea1
Fix not installing jax[cuda]
michalk8 Feb 3, 2023
6010ebd
Use different Docker image
michalk8 Feb 6, 2023
5244413
Fix escpape
michalk8 Feb 6, 2023
7aef30d
Use apt-get
michalk8 Feb 6, 2023
270b86a
Do not use `{}`
michalk8 Feb 6, 2023
68650e5
Fix not installing `git`
michalk8 Feb 6, 2023
7cb3d10
Use personal Docker image
michalk8 Feb 7, 2023
6b34346
Pin `jax[cuda]` version
michalk8 Feb 7, 2023
ab18f39
Mark grad(sqrtm) as CPU only test
michalk8 Feb 7, 2023
b682b64
Fix ICNN hessian test on GPU
michalk8 Feb 7, 2023
1236aa6
Use `eigvalsh` to check for positive-semidefinite
michalk8 Feb 7, 2023
c60cf50
Adjust tolerance in a test
michalk8 Feb 7, 2023
9991e15
Mark Sinkhorn online as CPU
michalk8 Feb 7, 2023
7c2aa8d
Run all tests on GPU
michalk8 Feb 7, 2023
2634a58
Skip more tests on GPU
michalk8 Feb 8, 2023
1fe92af
Update tolerances on k-means test
michalk8 Feb 8, 2023
669dec5
Always jit in online Sinkhorn test
michalk8 Feb 8, 2023
7dcb6b3
Use simple comparison
michalk8 Feb 8, 2023
2863be1
Only run fast GPU tests, try other GPU
michalk8 Feb 8, 2023
d6cfb78
Use previous GPU
michalk8 Feb 8, 2023
d7210f7
[ci skip] Fix test
michalk8 Feb 8, 2023
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
Mark grad(sqrtm) as CPU only test
  • Loading branch information
michalk8 committed Feb 7, 2023
commit ab18f393838c88cac60e10f5d434729d1a9650a6
2 changes: 1 addition & 1 deletion .github/workflows/tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ jobs:

- name: Run tests
run: |
python3 -m pytest -m fast --memray -vv
python3 -m pytest -m "fast and not cpu" --memray -vv

tests:
name: Python ${{ matrix.python-version }} on ${{ matrix.os }}
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -113,8 +113,8 @@ testpaths = [
"tests",
]
markers = [
"cpu: Mark tests as CPU only.",
"fast: Mark tests as fast.",
"notebook: Mark tests as notebook related.",
]

[tool.coverage.run]
Expand Down
2 changes: 1 addition & 1 deletion src/ott/problems/quadratic/quadratic_problem.py
Original file line number Diff line number Diff line change
Expand Up @@ -301,7 +301,7 @@ def update_linearization(
transport_matrix = transport.matrix * rescale_factor

if not self.is_balanced:
# Rescales transport for Unbalanced GW according to Sejourne et al (2021).
# Rescales transport for Unbalanced GW according to Sejourne et al. (2021)
transport_mass = jax.lax.stop_gradient(marginal_1.sum())
epsilon = update_epsilon_unbalanced(epsilon, transport_mass)
unbalanced_correction = self.cost_unbalanced_correction(
Expand Down
2 changes: 2 additions & 0 deletions tests/math/matrix_square_root_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,6 +171,8 @@ def test_solve_bartels_stewart_batch(self):
)
np.testing.assert_allclose(self.x, x[0, 0], atol=1.e-5)

# gradient needs Schur decomposition, currently not implemented in jax for GPU
@pytest.mark.cpu
@pytest.mark.fast.with_args(
"fn,n_tests,dim,epsilon,atol,rtol",
[(lambda x: matrix_square_root.sqrtm(x)[0], 3, 3, 1e-6, 1e-6, 1e-6),
Expand Down