
Commit

improvements
exs-whaddadin committed Dec 5, 2023
1 parent 080e698 commit c9acd7f
Showing 10 changed files with 36 additions and 21 deletions.
2 changes: 1 addition & 1 deletion docs/source/index.md
@@ -1,7 +1,7 @@
MolFlux
=======

**MolFlux** is a foundational package for molecular property prediction.
**MolFlux** is a foundational package for molecular predictive modelling.

```{note}
These docs are under active dev
4 changes: 3 additions & 1 deletion docs/source/pages/datasets/intro.md
@@ -10,4 +10,6 @@ The ``datasets`` submodule is meant to address all of these issues. It is a libr
multiple sources. Whether you are looking for public datasets (such as PDBBind or QM9) or just easy access to saved data,
``datasets`` provides a standard and modular interface for accessing and manipulating these datasets!

It is built on top of [huggingface ``datasets``](https://huggingface.co/docs/datasets/index).
It is built on top of [huggingface ``datasets``](https://huggingface.co/docs/datasets/index). The huggingface ``datasets``
package is versatile, fast, and efficient: it handles many data types, and its built-in functionality enables fast, scalable
data manipulation.
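
As a quick illustration, loading a dataset from the catalogue is a one-liner. This is a minimal sketch — the catalogue name ``esol`` is an assumption; use ``list_datasets`` (from ``molflux.datasets.catalogue``) to see what is available in your installation:

```python
from molflux.datasets import load_dataset
from molflux.datasets.catalogue import list_datasets

# Browse the catalogue of available datasets
print(list_datasets())

# Load a dataset by name; "esol" is an assumed catalogue name — pick one
# from the listing above if it differs in your installation.
dataset = load_dataset("esol")

# The result is a huggingface ``Dataset``, so its standard API applies
print(dataset.column_names)
print(dataset[0])
```
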
14 changes: 10 additions & 4 deletions docs/source/pages/modelzoo/basic_usage.md
@@ -43,8 +43,8 @@ This returns our catalogue of available model architectures (organised by the de
```{seealso}
[How to add your own model](how_to_add_models.md) if you would like to add your own model to the catalogue
```
For instance, `molflux.modelzoo.list_models()` returns as one item in the dictionary:
`'xgboost': ['xg_boost_classifier', 'xg_boost_regressor']`. In order to be able to use the two models `xg_boost_classifier`
and `xg_boost_regressor`, you would do: ``pip install molflux[xgboost]``.
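
As a quick sketch of the catalogue lookup described above:

```python
import molflux.modelzoo

# List the available model architectures, organised by the extra dependency
# that provides them.
catalogue = molflux.modelzoo.list_models()

# Expected to include the two xgboost models once the extra is installed
# (``pip install molflux[xgboost]``).
print(catalogue.get("xgboost"))  # ['xg_boost_classifier', 'xg_boost_regressor']
```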

## Loading a model architecture
@@ -287,8 +287,12 @@ predictions = model.predict(test_data)
print(predictions)
```

This returns a dictionary of your model's predictions!
This returns a dictionary of your model's predictions! Models can also support different inference methods. For example,
some classification models support the ``predict_proba`` method, which returns the probabilities of the classes:

```python
probabilities = model.predict_proba(test_data)
```

## Saving/Loading a model

@@ -307,7 +311,9 @@ save_to_store("path_to_my_model/", model)
The ``save_to_store`` function takes the path and the model to save. It can save to local disk or to an s3 location.
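
For instance (a minimal sketch — the import location of ``save_to_store`` and the bucket name are assumptions; adapt both to your setup):

```python
from molflux.modelzoo import save_to_store  # assumed import location

# ``model`` is a trained modelzoo model, e.g. from the sections above.
# Save to local disk ...
save_to_store("path_to_my_model/", model)

# ... or straight to an s3 location (placeholder bucket/prefix)
save_to_store("s3://my-bucket/models/my_model/", model)
```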

```{note}
Recommend using prod saving ..., link to core
For models intended for production-level usage, we recommend saving them as described in the [productionising](../production/models.md)
section. Along with the model, this also saves the featurisation metadata and a snapshot of the environment the model was
built in.
```

### Loading
8 changes: 4 additions & 4 deletions docs/source/pages/modelzoo/uncertainty.md
@@ -42,7 +42,7 @@ models implement additional functionalities:
while their spread indicates how uncertain the model is about this input.
4) `calibrate_uncertainty(data, **kwargs)` - calibrates the uncertainty of this model to an external/validation dataset

You can check whether a model implements any of these methods by using the appropriate `supports_x()` utility function:
You can check whether a model implements any of these methods by using the appropriate `supports_*` utility function:

```{code-cell} ipython3
from molflux.modelzoo import load_model, supports_prediction_interval
@@ -61,8 +61,8 @@ Similarly, `supports_std`, `supports_sampling`, and `supports_uncertainty_calibr

### Quick example - CatBoost Models

The typical example for a model with implemented uncertainty methods is a CatBoost Model.
Gaussian processes are distributions over functions such that predictions can be characterized by a mean and covariance.
A typical example of a model with implemented uncertainty methods is the CatBoost model. This model architecture can
return both a mean and a standard deviation for its predictions.

In the example below, we will train and predict using a CatBoost, and then use some of the
functions defined above to get a measure of the model uncertainty.
@@ -113,7 +113,7 @@ fitted `k` times.
2. In two steps - first, training an underlying model on training data, then
calibrating the uncertainty of it on a validation dataset

Both of these are possible with our Mapie implementation.
Both of these are possible with our [Mapie](https://github.com/scikit-learn-contrib/MAPIE) implementation.
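
A minimal sketch of the two-step workflow (the catalogue name ``mapie_regressor`` and the ``x_features``/``y_features`` arguments are assumptions — check ``list_models()`` and the model's config for the exact names; ``train_data`` and ``validation_data`` are datasets prepared as in the earlier examples):

```python
from molflux.modelzoo import load_model, supports_uncertainty_calibration

# Assumed catalogue name for the Mapie-wrapped regressor
model = load_model("mapie_regressor", x_features=["x"], y_features=["y"])
assert supports_uncertainty_calibration(model)

# Step 1: fit the underlying model on the training split
model.train(train_data)

# Step 2: calibrate its uncertainty on a held-out validation split
model.calibrate_uncertainty(validation_data)
```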

```{note}
This functionality is still work in progress.
12 changes: 6 additions & 6 deletions docs/source/pages/splits/gallery.md
@@ -32,7 +32,7 @@ The original `sklearn` [notebook](https://scikit-learn.org/stable/auto_examples/

```{code-cell} ipython3
---
tags: [remove-input]
tags: [hide-input]
---
import numpy as np
import matplotlib.pyplot as plt
@@ -59,7 +59,7 @@ To begin, we'll visualize our data:

```{code-cell} ipython3
---
tags: [remove-input]
tags: [hide-input]
---
# Generate the class/group data
n_points = 100
@@ -111,7 +111,7 @@ the validation set (in grey), and the test set (in red).

```{code-cell} ipython3
---
tags: [remove-input]
tags: [hide-input]
---
def plot_cv_indices(cv, X, y, group, ax, n_splits, lw=10):
"""Create a sample plot for indices of a cross-validation object."""
@@ -170,7 +170,7 @@ Let's see how it looks for the `k_fold` cross-validation object:

```{code-cell} ipython3
---
tags: [remove-input]
tags: [hide-input]
---
fig, ax = plt.subplots(figsize=figsize)
strategy = load_splitting_strategy("k_fold")
@@ -186,7 +186,7 @@ consideration. We can change this by using either:

```{code-cell} ipython3
---
tags: [remove-input]
tags: [hide-input]
---
strategies = ["stratified_k_fold", "group_k_fold"]
@@ -207,7 +207,7 @@ Note how some use the group/class information while others do not:

```{code-cell} ipython3
---
tags: [remove-input]
tags: [hide-input]
---
strategies = ["group_k_fold", "group_shuffle_split", "k_fold", "linear_split", "shuffle_split", "stratified_k_fold", "stratified_shuffle_split", "time_series_split"]
3 changes: 3 additions & 0 deletions docs/source/pages/tutorials/esol_training.md
@@ -31,6 +31,7 @@ print(dataset)
print(dataset[0])
```

The loaded dataset is an instance of a HuggingFace ``Dataset`` (for more info, check out the [docs](https://huggingface.co/docs/datasets/index)).
You can see that there are two columns: ``smiles`` and ``log_solubility``.
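
Since it is a regular HuggingFace ``Dataset``, the usual operations apply. For example (a small sketch using the dataset loaded above; the ``to_pandas`` call assumes pandas is installed):

```python
print(dataset.column_names)     # ['smiles', 'log_solubility']
print(dataset["smiles"][:5])    # first five SMILES strings
df = dataset.to_pandas()        # or convert to a pandas DataFrame
```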


@@ -127,6 +128,7 @@ plt.scatter(
split_featurised_dataset["test"]["log_solubility"],
preds["random_forest_regressor::log_solubility"],
)
plt.plot([-10, 0], [-10, 0], c='r')
plt.xlabel("True values")
plt.ylabel("Predicted values")
plt.show()
@@ -234,6 +236,7 @@ plt.scatter(
split_featurised_dataset["test"]["log_solubility"],
preds["random_forest_regressor::log_solubility"],
)
plt.plot([-10, 0], [-10, 0], c='r')
plt.xlabel("True values")
plt.ylabel("Predicted values")
plt.show()
4 changes: 2 additions & 2 deletions src/molflux/datasets/builders/ani2x/ani2x.py
@@ -7,7 +7,7 @@
import datasets
from molflux.datasets.typing import ExamplesGenerator

from .ani2x_configs import FEATURES, LEVEL_OF_THEORY, URL_DICT
from .ani2x_configs import FEATURES, LEVELS_OF_THEORY, URL_DICT

_BASE_URL = "https://zenodo.org/records/10108942"

@@ -26,7 +26,7 @@
@dataclass
class ANI2XConfig(datasets.BuilderConfig):
backend: Literal["openeye", "rdkit"] = "rdkit"
level_of_theory: LEVEL_OF_THEORY = "wB97X/631Gd"
level_of_theory: LEVELS_OF_THEORY = "wB97X/631Gd"


class ANI2X(datasets.GeneratorBasedBuilder):
2 changes: 1 addition & 1 deletion src/molflux/datasets/builders/ani2x/ani2x_configs.py
@@ -25,7 +25,7 @@
},
}

LEVEL_OF_THEORY = Literal[
LEVELS_OF_THEORY = Literal[
"wB97X/631Gd",
"wB97X/def2TZVPP",
"wB97MD3BJ/def2TZVPP",
6 changes: 4 additions & 2 deletions tests/datasets/plugins/openeye/ani2x/test_ani2x.py
@@ -6,6 +6,7 @@
import datasets
from molflux.datasets import load_dataset
from molflux.datasets.builders.ani2x.ani2x import ANI2X
from molflux.datasets.builders.ani2x.ani2x_configs import FEATURES
from molflux.datasets.catalogue import list_datasets

dataset_name = "ani2x"
@@ -87,6 +88,7 @@ def test_dataset_has_correct_num_rows(level_of_theory):

dataset = load_dataset(dataset_name, backend_name, level_of_theory=level_of_theory)

assert set(dataset.column_names) == set(FEATURES[level_of_theory].keys())
assert len(dataset) == 10


@@ -95,8 +97,8 @@
levels_of_theory,
)
@pytest.mark.usefixtures("_fixture_mocked_dataset_asset")
def test_dataset_is_readable_with_rdkit(level_of_theory):
"""That rdkit can read the mol bytes"""
def test_dataset_is_readable_with_openeye(level_of_theory):
"""That openeye can read the mol bytes"""

dataset = load_dataset(dataset_name, backend_name, level_of_theory=level_of_theory)

2 changes: 2 additions & 0 deletions tests/datasets/plugins/rdkit/ani2x/test_ani2x.py
@@ -6,6 +6,7 @@
import datasets
from molflux.datasets import load_dataset
from molflux.datasets.builders.ani2x.ani2x import ANI2X
from molflux.datasets.builders.ani2x.ani2x_configs import FEATURES
from molflux.datasets.catalogue import list_datasets

dataset_name = "ani2x"
@@ -86,6 +87,7 @@ def test_dataset_has_correct_num_rows(level_of_theory):
"""That the built dataset has correct num rows."""

dataset = load_dataset(dataset_name, backend_name, level_of_theory=level_of_theory)
assert set(dataset.column_names) == set(FEATURES[level_of_theory].keys())

assert len(dataset) == 10
