Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add list classmethod to FeedbackDataset via ArgillaMixin #3619

Merged
merged 11 commits into from
Aug 23, 2023
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -21,11 +21,13 @@ These are the section headers that we use:
- Added `login` function in `argilla.client.login` to login into an Argilla server and store the credentials locally ([#3582](https://github.com/argilla-io/argilla/pull/3582)).
- Added `login` command to login into an Argilla server ([#3600](https://github.com/argilla-io/argilla/pull/3600)).
- Added `response_status` param to `GET /api/v1/datasets/{dataset_id}/records` to be able to filter by `response_status` as previously included for `GET /api/v1/me/datasets/{dataset_id}/records` ([#3613](https://github.com/argilla-io/argilla/pull/3613)).
- Added `list` classmethod to `ArgillaMixin`, callable as `FeedbackDataset.list()`, with an optional `workspace` argument to list only the datasets in that workspace ([#3619](https://github.com/argilla-io/argilla/pull/3619)).

### Changed

- Updated `RemoteFeedbackDataset.delete_records` to use batch delete records endpoint ([#3580](https://github.com/argilla-io/argilla/pull/3580)).
- Included `allowed_for_roles` for some `RemoteFeedbackDataset`, `RemoteFeedbackRecords`, and `RemoteFeedbackRecord` methods that are only allowed for users with roles `owner` and `admin` ([#3601](https://github.com/argilla-io/argilla/pull/3601)).
- Renamed `ArgillaToFromMixin` to `ArgillaMixin` ([#3619](https://github.com/argilla-io/argilla/pull/3619)).

### Changed

Expand Down
4 changes: 2 additions & 2 deletions src/argilla/client/feedback/dataset/local.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@

from argilla.client.feedback.constants import FETCHING_BATCH_SIZE
from argilla.client.feedback.dataset.base import FeedbackDatasetBase
from argilla.client.feedback.dataset.mixins import ArgillaToFromMixin
from argilla.client.feedback.dataset.mixins import ArgillaMixin
from argilla.client.feedback.schemas.types import AllowedFieldTypes, AllowedQuestionTypes

if TYPE_CHECKING:
Expand All @@ -27,7 +27,7 @@
warnings.simplefilter("always", DeprecationWarning)


class FeedbackDataset(FeedbackDatasetBase, ArgillaToFromMixin):
class FeedbackDataset(FeedbackDatasetBase, ArgillaMixin):
def __init__(
self,
*,
Expand Down
46 changes: 45 additions & 1 deletion src/argilla/client/feedback/dataset/mixins.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@
warnings.simplefilter("always", DeprecationWarning)


class ArgillaToFromMixin:
class ArgillaMixin:
# TODO(alvarobartt): remove when `delete` is implemented
def __delete_dataset(self: "FeedbackDataset", client: "httpx.Client", id: UUID) -> None:
try:
Expand Down Expand Up @@ -332,3 +332,47 @@
questions=questions,
guidelines=existing_dataset.guidelines or None,
)

@classmethod
def list(cls: Type["FeedbackDataset"], workspace: Optional[str] = None) -> List[RemoteFeedbackDataset]:
    """Lists the `FeedbackDataset`s pushed to Argilla.

    Note that you may need to `rg.init(...)` with your Argilla credentials before
    calling this function, otherwise, the default http://localhost:6900 will be used,
    which will fail if Argilla is not deployed locally.

    Args:
        workspace: the name of the workspace to list the datasets from. If not
            provided, then the workspace filtering won't be applied. Defaults to `None`.

    Returns:
        A list of `RemoteFeedbackDataset` datasets, which are `FeedbackDataset`
        datasets previously pushed to Argilla via `push_to_argilla`.

    Raises:
        RuntimeError: if the request to list the datasets in Argilla fails.
    """
    client: "ArgillaClient" = ArgillaSingleton.get()
    httpx_client: "httpx.Client" = client.http_client.httpx

    # Resolve the workspace name once, and keep the resolved object in its own
    # variable instead of rebinding the `workspace` argument (which is a `str`).
    target_workspace = Workspace.from_name(workspace) if workspace is not None else None

    # TODO(alvarobartt or gabrielmbmb): add `workspace_id` in `GET /api/v1/datasets`
    # and in `GET /api/v1/me/datasets` to filter by workspace
    try:
        datasets = datasets_api_v1.list_datasets(client=httpx_client).parsed
    except Exception as e:
        # NOTE: according to the coverage report attached to this change, this
        # error path is not exercised by the test suite.
        raise RuntimeError(
            f"Failed while listing the `FeedbackDataset` datasets in Argilla with exception: {e}"
        ) from e

    # Datasets frequently share a workspace, so each workspace is looked up at
    # most once and reused, rather than calling `Workspace.from_id` per dataset.
    workspaces_by_id = {}
    if target_workspace is not None:
        workspaces_by_id[target_workspace.id] = target_workspace

    remote_datasets = []
    for dataset in datasets:
        # Filtering is done client-side until the API supports `workspace_id`.
        if target_workspace is None or dataset.workspace_id == target_workspace.id:
            if dataset.workspace_id not in workspaces_by_id:
                workspaces_by_id[dataset.workspace_id] = Workspace.from_id(dataset.workspace_id)
            remote_datasets.append(
                RemoteFeedbackDataset(
                    client=httpx_client,
                    id=dataset.id,
                    name=dataset.name,
                    workspace=workspaces_by_id[dataset.workspace_id],
                    fields=cls.__get_fields(client=httpx_client, id=dataset.id),
                    questions=cls.__get_questions(client=httpx_client, id=dataset.id),
                    guidelines=dataset.guidelines or None,
                )
            )
    return remote_datasets
58 changes: 45 additions & 13 deletions tests/integration/client/feedback/dataset/test_remote.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
import pytest
from argilla.client import api
from argilla.client.feedback.dataset import FeedbackDataset
from argilla.client.feedback.dataset.remote import RemoteFeedbackDataset
from argilla.client.sdk.users.models import UserRole

from tests.factories import DatasetFactory, RecordFactory, TextFieldFactory, TextQuestionFactory, UserFactory
Expand All @@ -23,13 +24,10 @@
@pytest.mark.parametrize("role", [UserRole.owner, UserRole.admin])
@pytest.mark.asyncio
async def test_delete_records(role: UserRole) -> None:
text_field = await TextFieldFactory.create(required=True)
rating_question = await TextQuestionFactory.create(required=True)
dataset = await DatasetFactory.create(
fields=[text_field],
questions=[rating_question],
records=await RecordFactory.create_batch(size=100),
)
dataset = await DatasetFactory.create()
await TextFieldFactory.create(dataset=dataset, required=True)
await TextQuestionFactory.create(dataset=dataset, required=True)
await RecordFactory.create_batch(dataset=dataset, size=10)
user = await UserFactory.create(role=role, workspaces=[dataset.workspace])

api.init(api_key=user.api_key)
Expand All @@ -47,9 +45,10 @@ async def test_delete_records(role: UserRole) -> None:
@pytest.mark.parametrize("role", [UserRole.owner, UserRole.admin])
@pytest.mark.asyncio
async def test_delete(role: UserRole) -> None:
text_field = await TextFieldFactory.create(required=True)
rating_question = await TextQuestionFactory.create(required=True)
dataset = await DatasetFactory.create(fields=[text_field], questions=[rating_question])
dataset = await DatasetFactory.create()
await TextFieldFactory.create(dataset=dataset, required=True)
await TextQuestionFactory.create(dataset=dataset, required=True)
await RecordFactory.create_batch(dataset=dataset, size=10)
user = await UserFactory.create(role=role, workspaces=[dataset.workspace])

api.init(api_key=user.api_key)
Expand All @@ -63,13 +62,46 @@ async def test_delete(role: UserRole) -> None:
@pytest.mark.parametrize("role", [UserRole.annotator])
@pytest.mark.asyncio
async def test_delete_not_allowed_role(role: UserRole) -> None:
text_field = await TextFieldFactory.create(required=True)
rating_question = await TextQuestionFactory.create(required=True)
dataset = await DatasetFactory.create(fields=[text_field], questions=[rating_question])
dataset = await DatasetFactory.create()
await TextFieldFactory.create(dataset=dataset, required=True)
await TextQuestionFactory.create(dataset=dataset, required=True)
await RecordFactory.create_batch(dataset=dataset, size=10)
user = await UserFactory.create(role=role, workspaces=[dataset.workspace])

api.init(api_key=user.api_key)
remote_dataset = FeedbackDataset.from_argilla(id=dataset.id)

with pytest.raises(PermissionError, match=f"User with role={role} is not allowed to call `delete`"):
remote_dataset.delete()


@pytest.mark.parametrize("role", [UserRole.owner, UserRole.admin, UserRole.annotator])
@pytest.mark.asyncio
async def test_list(role: UserRole) -> None:
    """`FeedbackDataset.list()` returns exactly one remote dataset for a user in the dataset's workspace."""
    # Seed a single dataset with one required field, one required question and ten records.
    dataset = await DatasetFactory.create()
    await TextFieldFactory.create(dataset=dataset, required=True)
    await TextQuestionFactory.create(dataset=dataset, required=True)
    await RecordFactory.create_batch(dataset=dataset, size=10)

    # The user under test belongs to the workspace that owns the dataset.
    user = await UserFactory.create(role=role, workspaces=[dataset.workspace])
    api.init(api_key=user.api_key)

    listed = FeedbackDataset.list()

    assert len(listed) == 1
    for entry in listed:
        assert isinstance(entry, RemoteFeedbackDataset)
    for entry in listed:
        assert entry.workspace.id == dataset.workspace.id


@pytest.mark.parametrize("role", [UserRole.owner, UserRole.admin, UserRole.annotator])
@pytest.mark.asyncio
async def test_list_with_workspace_name(role: UserRole) -> None:
    """`FeedbackDataset.list(workspace=...)` returns exactly one remote dataset when given the dataset's workspace name."""
    # Seed a single dataset with one required field, one required question and ten records.
    dataset = await DatasetFactory.create()
    await TextFieldFactory.create(dataset=dataset, required=True)
    await TextQuestionFactory.create(dataset=dataset, required=True)
    await RecordFactory.create_batch(dataset=dataset, size=10)

    # The user under test belongs to the workspace that owns the dataset.
    user = await UserFactory.create(role=role, workspaces=[dataset.workspace])
    api.init(api_key=user.api_key)

    # List by workspace name rather than relying on the unfiltered listing.
    listed = FeedbackDataset.list(workspace=dataset.workspace.name)

    assert len(listed) == 1
    for entry in listed:
        assert isinstance(entry, RemoteFeedbackDataset)
    for entry in listed:
        assert entry.workspace.id == dataset.workspace.id
Loading