Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: follow-up fixes for detached pydantic.BaseModel schemas #3829

Merged
merged 20 commits into from
Oct 11, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
20 commits
Select commit Hold shift + click to select a range
e297605
fix: add missing `abc.ABC` in `RemoteSchema`
alvarobartt Sep 26, 2023
519b800
fix: add missing `abc.ABC` in `FieldSchema` and `QuestionSchema`
alvarobartt Sep 26, 2023
160c3d4
fix: pop `id` in `DatasetConfig.from_yaml`
alvarobartt Sep 26, 2023
5a46584
feat: extend `generate_pydantic_schema` type-hints
alvarobartt Sep 26, 2023
871b2da
refactor: `__init__` validation to `FeedbackDataset`
alvarobartt Sep 26, 2023
2cce820
feat: define `ResponseStatusFilter`
alvarobartt Sep 28, 2023
b45a425
fix: use `property` + `abstractmethod` instead of `abstractproperty`
alvarobartt Oct 10, 2023
f881cee
refactor: move `ResponseStatus` and `ResponseStatusFilter` to `enums.py`
alvarobartt Oct 10, 2023
ba1268e
revert: `Enum` cannot be extended
alvarobartt Oct 10, 2023
8b74e16
fix: `AllowedQuestionTypes` import
alvarobartt Oct 10, 2023
0270ac6
Merge branch 'develop' of github.com:argilla-io/argilla into fix/afte…
alvarobartt Oct 10, 2023
af57f3a
fix: `push_to_argilla` to parse `Remote{Field,Question}`
alvarobartt Oct 10, 2023
8d10963
fix: `from_huggingface` fully backwards compatible
alvarobartt Oct 10, 2023
d1d4098
test(unit): ensure back comp with `argilla.cfg`
alvarobartt Oct 11, 2023
3cca5cc
fix(style): `DatasetConfig.from_yaml` naming
alvarobartt Oct 11, 2023
f5bd67e
test(unit): ensure back comp with `argilla.yaml`
alvarobartt Oct 11, 2023
9534c81
docs: update `CHANGELOG.md`
alvarobartt Oct 11, 2023
bfaabeb
fix: `DeprecatedDatasetConfig.from_json` check `settings`
alvarobartt Oct 11, 2023
d337e38
revert: move `__init__` validation to `FeedbackDatasetBase`
alvarobartt Oct 11, 2023
5b43da3
fix(test): `ResponseStatusFilter` outdated import
alvarobartt Oct 11, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,8 @@ These are the section headers that we use:

### Fixed

- Fixed saving of models trained with `ArgillaTrainer` with a `peft_config` parameter. ([#3795](https://github.com/argilla-io/argilla/pull/3795))
- Fixed saving of models trained with `ArgillaTrainer` with a `peft_config` parameter ([#3795](https://github.com/argilla-io/argilla/pull/3795)).
- Fixed backwards compatibility on `from_huggingface` when loading a `FeedbackDataset` from the Hugging Face Hub that was previously dumped using another version of Argilla, starting at 1.8.0, when it was first introduced ([#3829](https://github.com/argilla-io/argilla/pull/3829)).

## [1.16.0](https://github.com/argilla-io/argilla/compare/v1.15.1...v1.16.0)

Expand Down
52 changes: 44 additions & 8 deletions src/argilla/client/feedback/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import json
import re
import warnings
from typing import List, Optional
Expand Down Expand Up @@ -43,15 +44,17 @@ def to_yaml(self) -> str:
return dump(self.dict())

@classmethod
def from_yaml(cls, yaml: str) -> "DatasetConfig":
yaml = re.sub(r"(\n\s*|)id: !!python/object:uuid\.UUID\s+int: \d+", "", yaml)
yaml = load(yaml, Loader=SafeLoader)
def from_yaml(cls, yaml_str: str) -> "DatasetConfig":
yaml_str = re.sub(r"(\n\s*|)id: !!python/object:uuid\.UUID\s+int: \d+", "", yaml_str)
yaml_dict = load(yaml_str, Loader=SafeLoader)
# Here for backwards compatibility
for field in yaml["fields"]:
for field in yaml_dict["fields"]:
field.pop("id", None)
field.pop("settings", None)
for question in yaml["questions"]:
for question in yaml_dict["questions"]:
question.pop("id", None)
question.pop("settings", None)
return cls(**yaml)
return cls(**yaml_dict)


# TODO(alvarobartt): here for backwards compatibility, remove in 1.14.0
Expand All @@ -70,11 +73,44 @@ def to_json(self) -> str:
return self.json()

@classmethod
def from_json(cls, json: str) -> "DeprecatedDatasetConfig":
def from_json(cls, json_str: str) -> "DeprecatedDatasetConfig":
warnings.warn(
"`DatasetConfig` can just be loaded from YAML, so make sure that you are"
" loading a YAML file instead of a JSON file. `DatasetConfig` will be dumped"
" as YAML from now on, instead of JSON.",
DeprecationWarning,
)
return cls.parse_raw(json)
parsed_json = json.loads(json_str)
# Here for backwards compatibility
for field in parsed_json["fields"]:
# for 1.10.0, 1.9.0, and 1.8.0
field.pop("id", None)
field.pop("inserted_at", None)
field.pop("updated_at", None)
if "settings" not in field:
continue
field["type"] = field["settings"]["type"]
if "use_markdown" in field["settings"]:
field["use_markdown"] = field["settings"]["use_markdown"]
# for 1.12.0 and 1.11.0
field.pop("settings", None)
for question in parsed_json["questions"]:
# for 1.10.0, 1.9.0, and 1.8.0
question.pop("id", None)
question.pop("inserted_at", None)
question.pop("updated_at", None)
if "settings" not in question:
continue
question.update({"type": question["settings"]["type"]})
if question["type"] in ["rating", "ranking"]:
question["values"] = [option["value"] for option in question["settings"]["options"]]
elif question["type"] in ["label_selection", "multi_label_selection"]:
if all(option["value"] == option["text"] for option in question["settings"]["options"]):
question["labels"] = [option["value"] for option in question["settings"]["options"]]
else:
question["labels"] = {option["value"]: option["text"] for option in question["settings"]["options"]}
if "visible_labels" in question["settings"]:
question["visible_labels"] = question["settings"]["visible_labels"]
# for 1.12.0 and 1.11.0
question.pop("settings", None)
return cls(**parsed_json)
26 changes: 16 additions & 10 deletions src/argilla/client/feedback/dataset/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
FeedbackRecord,
FieldSchema,
)
from argilla.client.feedback.schemas.types import AllowedQuestionTypes
from argilla.client.feedback.schemas.types import AllowedFieldTypes, AllowedQuestionTypes
from argilla.client.feedback.training.schemas import (
TrainingTaskForChatCompletion,
TrainingTaskForDPO,
Expand All @@ -43,7 +43,6 @@
from datasets import Dataset

from argilla.client.feedback.schemas.types import (
AllowedFieldTypes,
AllowedRemoteFieldTypes,
AllowedRemoteQuestionTypes,
)
Expand All @@ -58,8 +57,8 @@ class FeedbackDatasetBase(ABC, HuggingFaceDatasetMixin):
def __init__(
self,
*,
fields: Union[List["AllowedFieldTypes"], List["AllowedRemoteFieldTypes"]],
questions: Union[List["AllowedQuestionTypes"], List["AllowedRemoteQuestionTypes"]],
fields: Union[List[AllowedFieldTypes], List["AllowedRemoteFieldTypes"]],
questions: Union[List[AllowedQuestionTypes], List["AllowedRemoteQuestionTypes"]],
guidelines: Optional[str] = None,
) -> None:
"""Initializes a `FeedbackDatasetBase` instance locally.
Expand All @@ -84,17 +83,21 @@ def __init__(
any_required = False
unique_names = set()
for field in fields:
if not isinstance(field, FieldSchema):
raise TypeError(f"Expected `fields` to be a list of `FieldSchema`, got {type(field)} instead.")
if not isinstance(field, AllowedFieldTypes):
raise TypeError(
f"Expected `fields` to be a list of `{AllowedFieldTypes.__name__}`, got {type(field)} instead."
)
if field.name in unique_names:
raise ValueError(f"Expected `fields` to have unique names, got {field.name} twice instead.")
unique_names.add(field.name)
if not any_required and field.required:
any_required = True

if not any_required:
raise ValueError("At least one `FieldSchema` in `fields` must be required (`required=True`).")
raise ValueError("At least one field in `fields` must be required (`required=True`).")

self._fields = fields
self._fields_schema = None
self._fields_schema = generate_pydantic_schema(self.fields)

if not isinstance(questions, list):
raise TypeError(f"Expected `questions` to be a list, got {type(questions)} instead.")
Expand All @@ -113,8 +116,10 @@ def __init__(
unique_names.add(question.name)
if not any_required and question.required:
any_required = True

if not any_required:
raise ValueError("At least one question in `questions` must be required (`required=True`).")

self._questions = questions

if guidelines is not None:
Expand All @@ -126,6 +131,7 @@ def __init__(
raise ValueError(
"Expected `guidelines` to be either None (default) or a non-empty string, minimum length is 1."
)

self._guidelines = guidelines

@property
Expand All @@ -140,11 +146,11 @@ def guidelines(self) -> str:
return self._guidelines

@property
def fields(self) -> Union[List["AllowedFieldTypes"], List["AllowedRemoteFieldTypes"]]:
def fields(self) -> Union[List[AllowedFieldTypes], List["AllowedRemoteFieldTypes"]]:
"""Returns the fields that define the schema of the records in the dataset."""
return self._fields

def field_by_name(self, name: str) -> Union["AllowedFieldTypes", "AllowedRemoteFieldTypes"]:
def field_by_name(self, name: str) -> Union[AllowedFieldTypes, "AllowedRemoteFieldTypes"]:
"""Returns the field by name if it exists. Othewise a `ValueError` is raised.

Args:
Expand Down
6 changes: 4 additions & 2 deletions src/argilla/client/feedback/dataset/local.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,18 +17,20 @@
from argilla.client.feedback.constants import FETCHING_BATCH_SIZE
from argilla.client.feedback.dataset.base import FeedbackDatasetBase
from argilla.client.feedback.dataset.mixins import ArgillaMixin, UnificationMixin
from argilla.client.feedback.schemas.fields import TextField
from argilla.client.feedback.schemas.types import AllowedQuestionTypes

if TYPE_CHECKING:
from argilla.client.feedback.schemas.records import FeedbackRecord
from argilla.client.feedback.schemas.types import AllowedFieldTypes, AllowedQuestionTypes
from argilla.client.feedback.schemas.types import AllowedFieldTypes


class FeedbackDataset(FeedbackDatasetBase, ArgillaMixin, UnificationMixin):
def __init__(
self,
*,
fields: List["AllowedFieldTypes"],
questions: List["AllowedQuestionTypes"],
questions: List[AllowedQuestionTypes],
guidelines: Optional[str] = None,
) -> None:
"""Initializes a `FeedbackDataset` instance locally.
Expand Down
8 changes: 5 additions & 3 deletions src/argilla/client/feedback/dataset/mixins.py
Original file line number Diff line number Diff line change
Expand Up @@ -175,13 +175,15 @@ def push_to_argilla(
except Exception as e:
raise Exception(f"Failed while creating the `FeedbackDataset` in Argilla with exception: {e}") from e

fields = self.__add_fields(client=httpx_client, id=argilla_id)
# TODO(alvarobartt): re-use ArgillaMixin components when applicable
self.__add_fields(client=httpx_client, id=argilla_id)
fields = self.__get_fields(client=httpx_client, id=argilla_id)

questions = self.__add_questions(client=httpx_client, id=argilla_id)
self.__add_questions(client=httpx_client, id=argilla_id)
questions = self.__get_questions(client=httpx_client, id=argilla_id)
question_name_to_id = {question.name: question.id for question in questions}

self.__publish_dataset(client=httpx_client, id=argilla_id)

self.__push_records(
client=httpx_client, id=argilla_id, show_progress=show_progress, question_name_to_id=question_name_to_id
)
Expand Down
4 changes: 1 addition & 3 deletions src/argilla/client/feedback/dataset/remote/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,9 +134,7 @@ def __init__(
TypeError: if `guidelines` is not None and not a string.
ValueError: if `guidelines` is an empty string.
"""
self._fields = fields
self._questions = questions
self._guidelines = guidelines
super().__init__(fields=fields, questions=questions, guidelines=guidelines)

self._client = client # Required to be able to use `allowed_for_roles` decorator
self._id = id
Expand Down
4 changes: 2 additions & 2 deletions src/argilla/client/feedback/dataset/remote/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,14 +24,14 @@
from argilla.client.feedback.schemas.remote.records import RemoteFeedbackRecord
from argilla.client.sdk.users.models import UserRole
from argilla.client.sdk.v1.datasets import api as datasets_api_v1
from argilla.client.sdk.v1.datasets.models import FeedbackResponseStatusFilter
from argilla.client.utils import allowed_for_roles

if TYPE_CHECKING:
from uuid import UUID

import httpx

from argilla.client.feedback.schemas.enums import ResponseStatusFilter
from argilla.client.feedback.schemas.types import AllowedRemoteFieldTypes, AllowedRemoteQuestionTypes
from argilla.client.sdk.v1.datasets.models import FeedbackRecordsModel
from argilla.client.workspaces import Workspace
Expand Down Expand Up @@ -145,7 +145,7 @@ def __init__(
)

def filter_by(
self, response_status: Union[FeedbackResponseStatusFilter, List[FeedbackResponseStatusFilter]]
self, response_status: Union["ResponseStatusFilter", List["ResponseStatusFilter"]]
) -> FilteredRemoteFeedbackDataset:
"""Filters the current `RemoteFeedbackDataset` based on the `response_status` of
the responses of the records in Argilla. This method creates a new class instance
Expand Down
13 changes: 13 additions & 0 deletions src/argilla/client/feedback/schemas/enums.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,3 +25,16 @@ class QuestionTypes(str, Enum):
label_selection = "label_selection"
multi_label_selection = "multi_label_selection"
ranking = "ranking"


class ResponseStatus(str, Enum):
draft = "draft"
submitted = "submitted"
discarded = "discarded"


class ResponseStatusFilter(str, Enum):
draft = "draft"
submitted = "submitted"
discarded = "discarded"
missing = "missing"
9 changes: 5 additions & 4 deletions src/argilla/client/feedback/schemas/fields.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from abc import abstractproperty
from abc import ABC, abstractmethod
from typing import Any, Dict, Literal, Optional

from pydantic import BaseModel, Extra, Field, validator
Expand All @@ -21,7 +21,7 @@
from argilla.client.feedback.schemas.validators import title_must_have_value


class FieldSchema(BaseModel):
class FieldSchema(BaseModel, ABC):
"""Base schema for the `FeedbackDataset` fields.

Args:
Expand Down Expand Up @@ -52,12 +52,13 @@ class Config:
extra = Extra.forbid
exclude = {"type"}

@abstractproperty
@property
@abstractmethod
def server_settings(self) -> Dict[str, Any]:
"""Abstract property that should be implemented by the classes that inherit from
this one, and that will be used to create the `FeedbackDataset` in Argilla.
"""
raise NotImplementedError
...

def to_server_payload(self) -> Dict[str, Any]:
"""Method that will be used to create the payload that will be sent to Argilla
Expand Down
9 changes: 5 additions & 4 deletions src/argilla/client/feedback/schemas/questions.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
# limitations under the License.

import warnings
from abc import abstractproperty
from abc import ABC, abstractmethod
from typing import Any, Dict, List, Literal, Optional, Union

from pydantic import BaseModel, Extra, Field, conint, conlist, root_validator, validator
Expand All @@ -23,7 +23,7 @@
from argilla.client.feedback.schemas.validators import title_must_have_value


class QuestionSchema(BaseModel):
class QuestionSchema(BaseModel, ABC):
"""Base schema for the `FeedbackDataset` questions. Which means that all the questions
in the dataset will have at least these fields.

Expand Down Expand Up @@ -58,12 +58,13 @@ class Config:
extra = Extra.forbid
exclude = {"type"}

@abstractproperty
@property
@abstractmethod
def server_settings(self) -> Dict[str, Any]:
"""Abstract property that should be implemented by the classes that inherit from
this one, and that will be used to create the `FeedbackDataset` in Argilla.
"""
raise NotImplementedError
...

def to_server_payload(self) -> Dict[str, Any]:
"""Method that will be used to create the payload that will be sent to Argilla
Expand Down
13 changes: 3 additions & 10 deletions src/argilla/client/feedback/schemas/records.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,13 @@
# limitations under the License.

import warnings
from enum import Enum
from typing import TYPE_CHECKING, Any, Dict, List, Literal, Optional, Tuple, Union
from uuid import UUID

from pydantic import BaseModel, Extra, Field, PrivateAttr, StrictInt, StrictStr, conint, validator

from argilla.client.feedback.schemas.enums import ResponseStatus

if TYPE_CHECKING:
from argilla.client.feedback.unification import UnifiedValueSchema

Expand Down Expand Up @@ -46,12 +47,6 @@ class ValueSchema(BaseModel):
value: Union[StrictStr, StrictInt, List[str], List[RankingValueSchema]]


class ResponseStatus(str, Enum):
draft = "draft"
submitted = "submitted"
discarded = "discarded"


class ResponseSchema(BaseModel):
"""Schema for the `FeedbackRecord` response.

Expand Down Expand Up @@ -103,9 +98,7 @@ class SuggestionSchema(BaseModel):
"""Schema for the suggestions for the questions related to the record.

Args:
question_id: ID of the question in Argilla. Defaults to None, and is automatically
fulfilled internally once the question is pushed to Argilla.
question_name: name of the question.
question_name: name of the question in the `FeedbackDataset`.
type: type of the question. Defaults to None. Possible values are `model` or `human`.
score: score of the suggestion. Defaults to None.
value: value of the suggestion, which should match the type of the question.
Expand Down
4 changes: 2 additions & 2 deletions src/argilla/client/feedback/schemas/remote/shared.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,15 +12,15 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from abc import abstractmethod
from abc import ABC, abstractmethod
from typing import Optional, Type
from uuid import UUID

import httpx
from pydantic import BaseModel


class RemoteSchema(BaseModel):
class RemoteSchema(BaseModel, ABC):
id: Optional[UUID] = None
client: Optional[httpx.Client] = None

Expand Down
Loading
Loading