diff --git a/CHANGELOG.md b/CHANGELOG.md index 83066779c8..3e3f4e16aa 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -25,7 +25,7 @@ These are the section headers that we use: ### Fixed - Updated active learning for text classification notebooks to pass ids of type int to `TextClassificationRecord` ([#3831](https://github.com/argilla-io/argilla/pull/3831)). - +- Fixed record fields validation that was preventing records with optional fields (i.e. `required=False`) from being logged when the field value was `None` ([#3846](https://github.com/argilla-io/argilla/pull/3846)). ## [1.16.0](https://github.com/argilla-io/argilla/compare/v1.15.1...v1.16.0) diff --git a/src/argilla/client/feedback/schemas/records.py b/src/argilla/client/feedback/schemas/records.py index 2da5dada4f..ff544fb40d 100644 --- a/src/argilla/client/feedback/schemas/records.py +++ b/src/argilla/client/feedback/schemas/records.py @@ -196,7 +196,7 @@ class FeedbackRecord(BaseModel): """ - fields: Dict[str, str] + fields: Dict[str, Union[str, None]] metadata: Dict[str, Any] = Field(default_factory=dict) responses: List[ResponseSchema] = Field(default_factory=list) suggestions: Union[Tuple[SuggestionSchema], List[SuggestionSchema]] = Field( @@ -244,7 +244,7 @@ def to_server_payload(self, question_name_to_id: Optional[Dict[str, UUID]] = Non to create a `FeedbackRecord` in the `FeedbackDataset`. 
""" payload = {} - payload["fields"] = self.fields + payload["fields"] = {key: value for key, value in self.fields.items() if value is not None} if self.responses: payload["responses"] = [response.to_server_payload() for response in self.responses] if self.suggestions and question_name_to_id: diff --git a/src/argilla/client/sdk/v1/datasets/api.py b/src/argilla/client/sdk/v1/datasets/api.py index e17a482ec3..4d79e21566 100644 --- a/src/argilla/client/sdk/v1/datasets/api.py +++ b/src/argilla/client/sdk/v1/datasets/api.py @@ -236,7 +236,8 @@ def add_records( if isinstance(response.get("user_id"), UUID): response["user_id"] = str(response.get("user_id")) cleaned_responses.append(response) - record["responses"] = cleaned_responses + if cleaned_responses: + record["responses"] = cleaned_responses for suggestion in record.get("suggestions", []): if isinstance(suggestion.get("question_id"), UUID): diff --git a/src/argilla/server/contexts/datasets.py b/src/argilla/server/contexts/datasets.py index 3547568e5d..814ab0aa9c 100644 --- a/src/argilla/server/contexts/datasets.py +++ b/src/argilla/server/contexts/datasets.py @@ -494,7 +494,7 @@ def validate_record_fields(dataset: Dataset, fields: Dict[str, Any]): raise ValueError(f"Missing required value for field: {field.name!r}") value = fields_copy.pop(field.name, None) - if not isinstance(value, str): + if value is not None and not isinstance(value, str): raise ValueError( f"Wrong value found for field {field.name!r}. 
Expected {str.__name__!r}, found {type(value).__name__!r}" ) diff --git a/tests/unit/client/feedback/schemas/remote/test_records.py b/tests/unit/client/feedback/schemas/remote/test_records.py index 55c3260cd9..731748faec 100644 --- a/tests/unit/client/feedback/schemas/remote/test_records.py +++ b/tests/unit/client/feedback/schemas/remote/test_records.py @@ -207,7 +207,7 @@ def test_remote_response_schema_from_api(payload: FeedbackResponseModel) -> None ( { "id": UUID("00000000-0000-0000-0000-000000000000"), - "fields": {"text": "This is the first record", "label": "positive"}, + "fields": {"text": "This is the first record", "label": "positive", "optional": None}, "metadata": {"first": True, "nested": {"more": "stuff"}}, "responses": [ { diff --git a/tests/unit/server/api/v1/test_datasets.py b/tests/unit/server/api/v1/test_datasets.py index 1ed0695ae0..6746344385 100644 --- a/tests/unit/server/api/v1/test_datasets.py +++ b/tests/unit/server/api/v1/test_datasets.py @@ -2473,6 +2473,56 @@ async def test_create_dataset_records_with_extra_fields( assert response.json() == {"detail": "Error: found fields values for non configured fields: ['output']"} assert (await db.execute(select(func.count(Record.id)))).scalar() == 0 + @pytest.mark.parametrize( + "record_json", + [ + {"fields": {"input": "text-input", "output": "text-output"}}, + {"fields": {"input": "text-input", "output": None}}, + {"fields": {"input": "text-input"}}, + ], + ) + async def test_create_dataset_records_with_optional_fields( + self, async_client: "AsyncClient", db: "AsyncSession", owner_auth_header: dict, record_json: dict + ): + dataset = await DatasetFactory.create(status=DatasetStatus.ready) + + await FieldFactory.create(name="input", dataset=dataset) + await FieldFactory.create(name="output", dataset=dataset, required=False) + + records_json = {"items": [record_json]} + + response = await async_client.post( + f"/api/v1/datasets/{dataset.id}/records", headers=owner_auth_header, json=records_json + ) 
+ + assert response.status_code == 204, response.json() + await db.refresh(dataset, attribute_names=["records"]) + assert (await db.execute(select(func.count(Record.id)))).scalar() == 1 + + async def test_create_dataset_records_with_wrong_optional_fields( + self, async_client: "AsyncClient", db: "AsyncSession", owner_auth_header: dict + ): + dataset = await DatasetFactory.create(status=DatasetStatus.ready) + await FieldFactory.create(name="input", dataset=dataset) + await FieldFactory.create(name="output", dataset=dataset, required=False) + await TextQuestionFactory.create(name="input_ok", dataset=dataset) + await TextQuestionFactory.create(name="output_ok", dataset=dataset) + + records_json = { + "items": [ + { + "fields": {"input": "text-input", "output": 0}, + }, + ] + } + + response = await async_client.post( + f"/api/v1/datasets/{dataset.id}/records", headers=owner_auth_header, json=records_json + ) + assert response.status_code == 422 + assert response.json() == {"detail": "Wrong value found for field 'output'. Expected 'str', found 'int'"} + assert (await db.execute(select(func.count(Record.id)))).scalar() == 0 + async def test_create_dataset_records_with_index_error( self, async_client: "AsyncClient", mock_search_engine: SearchEngine, db: "AsyncSession", owner_auth_header: dict ):