Skip to content

Commit

Permalink
fix: allow required=False fields to be logged in Argilla (argilla-i…
Browse files Browse the repository at this point in the history
…o#3846)

# Description

This PR addresses an issue with the Argilla API, since it's not handling
optional values passed as e.g. `{"required": "text", "optional_value":
None}` neither when those optional fields are not provided as part of
the payload e.g. `{"required": "text"}`.

So on, in the PR the API validation when creating new records has been
fixed to check that the optional fields are neither None nor str,
instead of applying the same check as for the required fields; plus
improving the `to_server_payload` method in the `FeedbackRecord` schema
not to include the fields with value None.

Closes argilla-io#3845

**Type of change**

- [X] Bug fix (non-breaking change which fixes an issue)

**How Has This Been Tested**

- [x] Add unit tests for `FeedbackRecord.to_server_payload` with
`required=False` fields
- [X] Add unit tests for the validation on the API-side when creating
records via `validate_record_fields`

**Checklist**

- [ ] I added relevant documentation
- [X] follows the style guidelines of this project
- [X] I did a self-review of my code
- [ ] I made corresponding changes to the documentation
- [X] My changes generate no new warnings
- [X] I have added tests that prove my fix is effective or that my
feature works
- [ ] I filled out [the contributor form](https://tally.so/r/n9XrxK)
(see text above)
- [X] I have added relevant notes to the CHANGELOG.md file (See
https://keepachangelog.com/)

---------

Co-authored-by: Francisco Aranda <[email protected]>
Co-authored-by: Gabriel Martín Blázquez <[email protected]>
  • Loading branch information
3 people committed Sep 28, 2023
1 parent 4fce163 commit 92392c4
Show file tree
Hide file tree
Showing 6 changed files with 57 additions and 6 deletions.
2 changes: 1 addition & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ These are the section headers that we use:
### Fixed

- Updated active learning for text classification notebooks to pass ids of type int to `TextClassificationRecord` ([#3831](https://github.com/argilla-io/argilla/pull/3831)).

- Fixed record fields validation that was preventing from logging records with optional fields (i.e. `required=True`) when the field value was `None` ([#3846](https://github.com/argilla-io/argilla/pull/3846)).

## [1.16.0](https://github.com/argilla-io/argilla/compare/v1.15.1...v1.16.0)

Expand Down
4 changes: 2 additions & 2 deletions src/argilla/client/feedback/schemas/records.py
Original file line number Diff line number Diff line change
Expand Up @@ -196,7 +196,7 @@ class FeedbackRecord(BaseModel):
"""

fields: Dict[str, str]
fields: Dict[str, Union[str, None]]
metadata: Dict[str, Any] = Field(default_factory=dict)
responses: List[ResponseSchema] = Field(default_factory=list)
suggestions: Union[Tuple[SuggestionSchema], List[SuggestionSchema]] = Field(
Expand Down Expand Up @@ -244,7 +244,7 @@ def to_server_payload(self, question_name_to_id: Optional[Dict[str, UUID]] = Non
to create a `FeedbackRecord` in the `FeedbackDataset`.
"""
payload = {}
payload["fields"] = self.fields
payload["fields"] = {key: value for key, value in self.fields.items() if value is not None}
if self.responses:
payload["responses"] = [response.to_server_payload() for response in self.responses]
if self.suggestions and question_name_to_id:
Expand Down
3 changes: 2 additions & 1 deletion src/argilla/client/sdk/v1/datasets/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -236,7 +236,8 @@ def add_records(
if isinstance(response.get("user_id"), UUID):
response["user_id"] = str(response.get("user_id"))
cleaned_responses.append(response)
record["responses"] = cleaned_responses
if len(cleaned_responses) > 0:
record["responses"] = cleaned_responses

for suggestion in record.get("suggestions", []):
if isinstance(suggestion.get("question_id"), UUID):
Expand Down
2 changes: 1 addition & 1 deletion src/argilla/server/contexts/datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -494,7 +494,7 @@ def validate_record_fields(dataset: Dataset, fields: Dict[str, Any]):
raise ValueError(f"Missing required value for field: {field.name!r}")

value = fields_copy.pop(field.name, None)
if not isinstance(value, str):
if value and not isinstance(value, str):
raise ValueError(
f"Wrong value found for field {field.name!r}. Expected {str.__name__!r}, found {type(value).__name__!r}"
)
Expand Down
2 changes: 1 addition & 1 deletion tests/unit/client/feedback/schemas/remote/test_records.py
Original file line number Diff line number Diff line change
Expand Up @@ -207,7 +207,7 @@ def test_remote_response_schema_from_api(payload: FeedbackResponseModel) -> None
(
{
"id": UUID("00000000-0000-0000-0000-000000000000"),
"fields": {"text": "This is the first record", "label": "positive"},
"fields": {"text": "This is the first record", "label": "positive", "optional": None},
"metadata": {"first": True, "nested": {"more": "stuff"}},
"responses": [
{
Expand Down
50 changes: 50 additions & 0 deletions tests/unit/server/api/v1/test_datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -2473,6 +2473,56 @@ async def test_create_dataset_records_with_extra_fields(
assert response.json() == {"detail": "Error: found fields values for non configured fields: ['output']"}
assert (await db.execute(select(func.count(Record.id)))).scalar() == 0

@pytest.mark.parametrize(
"record_json",
[
{"fields": {"input": "text-input", "output": "text-output"}},
{"fields": {"input": "text-input", "output": None}},
{"fields": {"input": "text-input"}},
],
)
async def test_create_dataset_records_with_optional_fields(
self, async_client: "AsyncClient", db: "AsyncSession", owner_auth_header: dict, record_json: dict
):
dataset = await DatasetFactory.create(status=DatasetStatus.ready)

await FieldFactory.create(name="input", dataset=dataset)
await FieldFactory.create(name="output", dataset=dataset, required=False)

records_json = {"items": [record_json]}

response = await async_client.post(
f"/api/v1/datasets/{dataset.id}/records", headers=owner_auth_header, json=records_json
)

assert response.status_code == 204, response.json()
await db.refresh(dataset, attribute_names=["records"])
assert (await db.execute(select(func.count(Record.id)))).scalar() == 1

async def test_create_dataset_records_with_wrong_optional_fields(
self, async_client: "AsyncClient", db: "AsyncSession", owner_auth_header: dict
):
dataset = await DatasetFactory.create(status=DatasetStatus.ready)
await FieldFactory.create(name="input", dataset=dataset)
await FieldFactory.create(name="output", dataset=dataset, required=False)
await TextQuestionFactory.create(name="input_ok", dataset=dataset)
await TextQuestionFactory.create(name="output_ok", dataset=dataset)

records_json = {
"items": [
{
"fields": {"input": "text-input", "output": 1},
},
]
}

response = await async_client.post(
f"/api/v1/datasets/{dataset.id}/records", headers=owner_auth_header, json=records_json
)
assert response.status_code == 422
assert response.json() == {"detail": "Wrong value found for field 'output'. Expected 'str', found 'int'"}
assert (await db.execute(select(func.count(Record.id)))).scalar() == 0

async def test_create_dataset_records_with_index_error(
self, async_client: "AsyncClient", mock_search_engine: SearchEngine, db: "AsyncSession", owner_auth_header: dict
):
Expand Down

0 comments on commit 92392c4

Please sign in to comment.