From 1c80defb059b599e9babab2674b26e454f2db609 Mon Sep 17 00:00:00 2001
From: Sara Han <127759186+sdiazlor@users.noreply.github.com>
Date: Tue, 19 Dec 2023 21:45:28 +0100
Subject: [PATCH] feat: create default `text_descriptives` as metadata via
 `utils.modeling` (#4400)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

# Description

Local PR from https://github.com/argilla-io/argilla/pull/4083

Also adds new documentation.

Closes #4017

**Type of change**

(Please delete options that are not relevant. Remember to title the PR
according to the type of change)

- [ ] New feature (non-breaking change which adds functionality)
- [ ] Refactor (change restructuring the codebase without changing
functionality)
- [ ] Improvement (change adding some improvement to an existing
functionality)

**How Has This Been Tested**

(Please describe the tests that you ran to verify your changes. And
ideally, reference `tests`)

- [ ] Test A
- [ ] Test B

**Checklist**

- [ ] I added relevant documentation
- [ ] I followed the style guidelines of this project
- [ ] I did a self-review of my code
- [ ] I made corresponding changes to the documentation
- [ ] My changes generate no new warnings
- [ ] I have added tests that prove my fix is effective or that my
feature works
- [ ] I filled out [the contributor form](https://tally.so/r/n9XrxK)
(see text above)
- [ ] I have added relevant notes to the `CHANGELOG.md` file (See
https://keepachangelog.com/)

---------

Co-authored-by: m-newhauser <35735816+m-newhauser@users.noreply.github.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: David Berenstein
---
 CHANGELOG.md                                   |   3 +-
 .../annotation_workflows.md                    |   2 +-
 .../create_update_dataset/create_dataset.md    |   2 +-
 .../create_update_dataset/metadata.md          |  52 +++
 environment_dev.yml                            |   1 +
 pyproject.toml                                 |   1 +
 .../feedback/integrations/textdescriptives.py  | 328 ++++++++++++++++++
 .../integrations/test_textdescriptives_.py     | 285 +++++++++++++++
 8 files changed, 671 insertions(+), 3 deletions(-)
 create mode 100644 src/argilla/client/feedback/integrations/textdescriptives.py
 create mode 100644 tests/unit/client/feedback/integrations/test_textdescriptives_.py

diff --git a/CHANGELOG.md b/CHANGELOG.md
index 06bc5d5c9b..a8d37dc8b8 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -18,7 +18,8 @@ These are the section headers that we use:

 ### Added

-- Added strategy to handle and translate errors from server for `401 http status code`. ([#4362](https://github.com/argilla-io/argilla/pull/4362))
+- Added strategy to handle and translate errors from the server for the `401` HTTP status code. ([#4362](https://github.com/argilla-io/argilla/pull/4362))
+- Added integration for `textdescriptives` using `TextDescriptivesExtractor` to configure `metadata_properties` in `FeedbackDataset` and `FeedbackRecord`. ([#4400](https://github.com/argilla-io/argilla/pull/4400)). Contributed by @m-newhauser
 - Added `POST /api/v1/me/responses/bulk` endpoint to create responses in bulk for current user. ([#4380](https://github.com/argilla-io/argilla/pull/4380))
 - Added new CLI task to reindex datasets and records into the search engine. ([#4404](https://github.com/argilla-io/argilla/pull/4404))
diff --git a/docs/_source/practical_guides/annotation_workflows/annotation_workflows.md b/docs/_source/practical_guides/annotation_workflows/annotation_workflows.md
index 8517287151..6cf834aee9 100644
--- a/docs/_source/practical_guides/annotation_workflows/annotation_workflows.md
+++ b/docs/_source/practical_guides/annotation_workflows/annotation_workflows.md
@@ -33,4 +33,4 @@ active_learning
 weak_supervision
 semantic_search
 job_scheduling
-```
\ No newline at end of file
+```
diff --git a/docs/_source/practical_guides/create_update_dataset/create_dataset.md b/docs/_source/practical_guides/create_update_dataset/create_dataset.md
index e31029e511..cd73334fd5 100644
--- a/docs/_source/practical_guides/create_update_dataset/create_dataset.md
+++ b/docs/_source/practical_guides/create_update_dataset/create_dataset.md
@@ -122,7 +122,7 @@ The following arguments apply to specific metadata types:
 ```

 ```{note}
-You can also define metadata properties after the dataset has been configured or add them to an existing dataset in Argilla. To do that use the `add_metadata_property` method as explained [here](/practical_guides/create_update_dataset/metadata.md).
+You can also define metadata properties after the dataset has been configured or add them to an existing dataset in Argilla using the `add_metadata_property` method. In addition, you can now add text descriptives of your fields as metadata automatically with the `TextDescriptivesExtractor`. For more info, take a look [here](/practical_guides/create_update_dataset/metadata.md).
 ```

 ##### Define `vectors`
diff --git a/docs/_source/practical_guides/create_update_dataset/metadata.md b/docs/_source/practical_guides/create_update_dataset/metadata.md
index 695b18d98e..e322db4f52 100644
--- a/docs/_source/practical_guides/create_update_dataset/metadata.md
+++ b/docs/_source/practical_guides/create_update_dataset/metadata.md
@@ -114,6 +114,58 @@ dataset.update_records(modified_records)
 You can also follow the same strategy to modify existing metadata.
 ```

+### Add Text Descriptives
+
+You can easily add text descriptives to your records or datasets using the `TextDescriptivesExtractor`, which is based on the [TextDescriptives](https://github.com/HLasse/TextDescriptives) library and adds the corresponding metadata properties and metadata automatically. The `TextDescriptivesExtractor` can be used on a `FeedbackDataset` or a `RemoteFeedbackDataset` and accepts the following arguments:
+
+- `model` (optional): The language of the spaCy model that will be used. Defaults to `en`. Check the available languages and models [here](https://spacy.io/usage/models).
+- `metrics` (optional): A list of metrics to extract. The default extracted metrics are `n_tokens`, `n_unique_tokens`, `n_sentences`, `perplexity`, `entropy`, and `flesch_reading_ease`. You can select metrics from the following groups: `descriptive_stats`, `readability`, `dependency_distance`, `pos_proportions`, `coherence`, `quality`, and `information_theory` (see the sketch after this list). For more information about each group, check this documentation [page](https://hlasse.github.io/TextDescriptives/descriptivestats.html).
+- `fields` (optional): A list of field names to extract metrics from. All fields will be used by default.
+- `visible_for_annotators` (optional): Whether the extracted metrics should be visible to annotators. Defaults to `True`.
+- `show_progress` (optional): Whether to show a progress bar when extracting metrics. Defaults to `True`.
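+
+For example, a minimal sketch that narrows extraction to two metric groups on a single field (the `text` field name here is only illustrative):
+
+```python
+from argilla.client.feedback.integrations.textdescriptives import TextDescriptivesExtractor
+
+# Extract only the readability and descriptive statistics groups from the "text" field,
+# keeping the resulting metadata hidden from annotators.
+tde = TextDescriptivesExtractor(
+    model="en",
+    metrics=["readability", "descriptive_stats"],
+    fields=["text"],
+    visible_for_annotators=False,
+)
+```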
+
+For a practical example, check our [tutorial on adding text descriptives as metadata](/tutorials_and_integrations/integrations/add_text_descriptives_as_metadata.ipynb).
+
+::::{tab-set}
+
+:::{tab-item} Records
+```python
+from argilla.client.feedback.integrations.textdescriptives import TextDescriptivesExtractor
+
+records = [...] # FeedbackRecords or RemoteFeedbackRecords
+
+tde = TextDescriptivesExtractor(
+    model="en",
+    metrics=None,
+    fields=None,
+    visible_for_annotators=True,
+    show_progress=True,
+)
+
+tde.update_records(records)
+```
+:::
+
+:::{tab-item} Dataset
+```python
+from argilla.client.feedback.integrations.textdescriptives import TextDescriptivesExtractor
+
+dataset = dataset # FeedbackDataset or RemoteFeedbackDataset
+
+tde = TextDescriptivesExtractor(
+    model="en",
+    metrics=None,
+    fields=None,
+    visible_for_annotators=True,
+    show_progress=True,
+)
+
+tde.update_dataset(dataset)
+```
+:::
+
+::::
+
 ## Other datasets
diff --git a/environment_dev.yml b/environment_dev.yml
index 4a9eb076f1..f7913f6782 100644
--- a/environment_dev.yml
+++ b/environment_dev.yml
@@ -58,6 +58,7 @@ dependencies:
   - trl>=0.5.0
   - sentence-transformers
   - rich!=13.1.0
+  - textdescriptives>=2.7.0,<3.0.0
   - ipynbname>=2023.2.0.0
   # install Argilla in editable mode
   - -e .[server,listeners]
diff --git a/pyproject.toml b/pyproject.toml
index 61ca3d91b8..472cb9da6f 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -107,6 +107,7 @@ integrations = [
     "sentence-transformers",
     "setfit>=0.7.0",
     "span_marker",
+    "textdescriptives>=2.7.0,<3.0.0",
    "openai>=0.27.10,<1.0.0",
     "peft",
     "trl>=0.5.0",
diff --git a/src/argilla/client/feedback/integrations/textdescriptives.py b/src/argilla/client/feedback/integrations/textdescriptives.py
new file mode 100644
index 0000000000..b0bc61ecfd
--- /dev/null
+++ b/src/argilla/client/feedback/integrations/textdescriptives.py
@@ -0,0 +1,328 @@
+# coding=utf-8
+# Copyright 2021-present, the Recognai S.L. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
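+
+"""Integration that extracts text descriptives from the fields of Argilla
+records and attaches them as record metadata.
+
+Illustrative usage sketch (the dataset name is hypothetical):
+
+    >>> import argilla as rg
+    >>> from argilla.client.feedback.integrations.textdescriptives import TextDescriptivesExtractor
+    >>> dataset = rg.FeedbackDataset.from_argilla(name="my-dataset")
+    >>> dataset = TextDescriptivesExtractor(model="en").update_dataset(dataset)
+"""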
+import logging
+import re
+from typing import List, Optional, Union
+
+import pandas as pd
+import textdescriptives as td
+from rich.progress import Progress
+
+from argilla.client.feedback.dataset.local.dataset import FeedbackDataset
+from argilla.client.feedback.dataset.remote.dataset import RemoteFeedbackDataset
+from argilla.client.feedback.schemas.metadata import (
+    FloatMetadataProperty,
+    IntegerMetadataProperty,
+    TermsMetadataProperty,
+)
+from argilla.client.feedback.schemas.records import FeedbackRecord
+from argilla.client.feedback.schemas.remote.records import RemoteFeedbackRecord
+
+_LOGGER = logging.getLogger(__name__)
+_LOGGER.setLevel(logging.INFO)
+
+
+class TextDescriptivesExtractor:
+    """This class extracts a number of basic text descriptives from FeedbackDataset
+    records using the TextDescriptives library and adds them as record metadata."""
+
+    def __init__(
+        self,
+        model: str = "en",
+        metrics: Optional[List[str]] = None,
+        fields: Optional[List[str]] = None,
+        visible_for_annotators: bool = True,
+        show_progress: bool = True,
+    ):
+        """
+        Initialize a new TextDescriptivesExtractor object.
+
+        Args:
+            model (str): The language of the model to use for text descriptives.
+            metrics (Optional[List[str]]): A list of metric groups to extract from
+                ["descriptive_stats", "readability", "dependency_distance", "pos_proportions", "coherence", "quality", "information_theory"].
+                If None, the default basic metrics will be extracted.
+            fields (Optional[List[str]]): A list of field names to extract metrics from. If None, all fields will be used.
+            visible_for_annotators (bool): Whether the extracted metrics should be visible to annotators.
+            show_progress (bool): Whether to show a progress bar when extracting metrics.
+        """
+        self.model = model
+        self.metrics = metrics
+        self.fields = fields
+        self.visible_for_annotators = visible_for_annotators
+        self.show_progress = show_progress
+        self.__basic_metrics = [
+            "n_tokens",
+            "n_unique_tokens",
+            "n_sentences",
+            "perplexity",
+            "entropy",
+            "flesch_reading_ease",
+        ]
+
+    def _extract_metrics_for_single_field(
+        self,
+        records: List[Union[FeedbackRecord, RemoteFeedbackRecord]],
+        field: str,
+        basic_metrics: Optional[List[str]] = None,
+    ) -> Optional[pd.DataFrame]:
+        """
+        Extract text descriptives metrics for a single field from a list of feedback records
+        using the TextDescriptives library.
+
+        Args:
+            records (List[Union[FeedbackRecord, RemoteFeedbackRecord]]): A list of FeedbackDataset or RemoteFeedbackDataset records.
+            field (str): The name of the field to extract metrics for.
+            basic_metrics (Optional[List[str]]): A list of basic metrics to extract. If None, the default basic metrics will be used (when no metric groups were requested).
+
+        Returns:
+            Optional[pd.DataFrame]: A dataframe containing the text descriptives metrics for the field, or None if the field is empty.
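+
+        Example (illustrative; assumes `records` whose fields include "text"):
+            >>> tde = TextDescriptivesExtractor()
+            >>> metrics_df = tde._extract_metrics_for_single_field(records, field="text")
+            >>> "text_n_tokens" in metrics_df.columns
+            True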
+ """ + # If the field is empty, skip it + field_text = [record.fields[field] for record in records if record.fields[field]] + if not field_text: + return None + # If language is english, the default spacy model is used (to avoid warning message) + if self.model == "en": + field_metrics = td.extract_metrics(text=field_text, spacy_model="en_core_web_sm", metrics=self.metrics) + else: + field_metrics = td.extract_metrics(text=field_text, lang=self.model, metrics=self.metrics) + # Drop text column + field_metrics = field_metrics.drop("text", axis=1) + # If basic metrics is None, use all basic metrics + if basic_metrics is None and self.metrics is None: + basic_metrics = self.__basic_metrics + field_metrics = field_metrics.loc[:, basic_metrics] + # Select all column names that contain ONLY NaNs + nan_columns = field_metrics.columns[field_metrics.isnull().all()].tolist() + if nan_columns: + _LOGGER.warning(f"The following columns for {field} contain only NaN values: {nan_columns}") + # Concatenate field name with the metric name + field_metrics.columns = [f"{field}_{metric}" for metric in field_metrics.columns] + return field_metrics + + def _extract_metrics_for_all_fields( + self, records: List[Union[FeedbackRecord, RemoteFeedbackRecord]], fields: List[str] = None + ) -> pd.DataFrame: + """ + Extract text descriptives metrics for all named fields from a list of feedback records + using the TextDescriptives library. + Args: + records (List[Union[FeedbackRecord, RemoteFeedbackRecord]]): A list of FeedbackDataset or RemoteFeedbackDataset records. + fields (List[str]): A list of fields to extract metrics for. If None, extract metrics for all fields. + Returns: + pd.DataFrame: A dataframe containing the text descriptives metrics for each record and field. + """ + # If fields is None, use all fields + if self.fields: + fields = self.fields + else: + fields = list({key for record in records for key in record.fields.keys()}) + # Extract all metrics for each field + field_metrics = { + field: self._extract_metrics_for_single_field(records=records, field=field) for field in fields + } + field_metrics = {field: metrics for field, metrics in field_metrics.items() if metrics is not None} + # If there is only one field, return the metrics for that field directly + if len(field_metrics) == 1: + return list(field_metrics.values())[0] + else: + # If there are multiple fields, combine metrics for each field into a single dataframe + final_metrics = pd.concat(field_metrics, axis=1, keys=field_metrics.keys()) + final_metrics.columns = final_metrics.columns.droplevel(0) + return final_metrics + + def _cast_to_python_types(self, df: pd.DataFrame) -> pd.DataFrame: + """ + Convert integer, boolean and floats columns in a dataframe + to Python native types. + + Args: + df (pd.DataFrame): The text descriptives dataframe. + + Returns: + pd.DataFrame: The text descriptives dataframe with integer and boolean columns cast to Python native types. 
+ """ + # Select columns by data type + int_cols = df.select_dtypes(include=["int64"]).columns + bool_cols = df.select_dtypes(include=["boolean"]).columns + float_cols = df.select_dtypes(include=["float64"]).columns + # Cast integer columns to Python's native int type + df[int_cols] = df[int_cols].astype(int) + # Cast boolean columns to Python's native str type + df[bool_cols] = df[bool_cols].astype(str) + # Cast float columns to Python's native float type and round to 2 decimal places + df[float_cols] = df[float_cols].astype(float).round(2) + return df + + def _clean_column_name(self, col_name: str) -> str: + """ + Clean the column name of a dataframe to fit a specific regex pattern. + Args: + col_name (str): A column name. + Returns: + str: A column name that fits the regex pattern. + """ + col_name = col_name.lower() # Convert to lowercase + col_name = re.sub(r"[^a-z0-9_]", "_", col_name) # Replace non-alphanumeric characters with underscores + return col_name + + def _create_metadata_properties(self, df: pd.DataFrame) -> List: + """ + Generate metadata properties based on dataframe columns and data types. + + Args: + df (pd.DataFrame): The text descriptives dataframe. + + Returns: + List: A list of metadata properties. + """ + properties = [] + for col, dtype in df.dtypes.items(): + name = col + title = name.replace("_", " ").title() + if dtype in ["object", "bool"]: + prop = TermsMetadataProperty( + name=name, + title=title, + visible_for_annotators=self.visible_for_annotators, + values=df[col].unique().tolist(), + ) + elif dtype == "int32": + prop = IntegerMetadataProperty( + name=name, title=title, visible_for_annotators=self.visible_for_annotators + ) + elif dtype == "float64": + prop = FloatMetadataProperty(name=name, title=title, visible_for_annotators=self.visible_for_annotators) + else: + _LOGGER.warning(f"Unhandled data type for column {col}: {dtype}") + prop = None + if prop is not None: + properties.append(prop) + return properties + + def _add_text_descriptives_to_metadata( + self, records: List[Union[FeedbackRecord, RemoteFeedbackRecord]], df: pd.DataFrame + ) -> List[Union[FeedbackRecord, RemoteFeedbackRecord]]: + """ + Add the text descriptives metrics extracted previously as metadata + to a list of FeedbackDataset records. + + Args: + records (List[Union[FeedbackRecord, RemoteFeedbackRecord]]): A list of FeedbackDataset or RemoteFeedbackDataset records. + df (pd.DataFrame): The text descriptives dataframe. + + Returns: + List[Union[FeedbackRecord, RemoteFeedbackRecord]]: A list of FeedbackDataset or RemoteFeedbackDataset records with extracted metrics added as metadata. + """ + modified_records = [] + with Progress() as progress_bar: + task = progress_bar.add_task( + "Adding text descriptives to metadata...", total=len(records), visible=self.show_progress + ) + for record, metrics in zip(records, df.to_dict("records")): + filtered_metrics = {key: value for key, value in metrics.items() if not pd.isna(value)} + record.metadata.update(filtered_metrics) + modified_records.append(record) + progress_bar.update(task, advance=1) + return modified_records + + def update_records( + self, records: List[Union[FeedbackRecord, RemoteFeedbackRecord]] + ) -> List[Union[FeedbackRecord, RemoteFeedbackRecord]]: + """ + Extract text descriptives metrics from a list of FeedbackDataset or RemoteFeedbackDataset records, + add them as metadata to the records and return the updated records. 
+
+        Args:
+            records (List[Union[FeedbackRecord, RemoteFeedbackRecord]]): A list of FeedbackDataset or RemoteFeedbackDataset records.
+
+        Returns:
+            List[Union[FeedbackRecord, RemoteFeedbackRecord]]: A list of FeedbackDataset or RemoteFeedbackDataset records with text descriptives metrics added as metadata.
+
+        >>> import argilla as rg
+        >>> from argilla.client.feedback.integrations.textdescriptives import TextDescriptivesExtractor
+        >>> records = [rg.FeedbackRecord(fields={"text": "This is a test."})]
+        >>> tde = TextDescriptivesExtractor()
+        >>> updated_records = tde.update_records(records)
+        """
+        # Extract text descriptives metrics from records
+        extracted_metrics = self._extract_metrics_for_all_fields(records)
+        # If the dataframe doesn't contain any columns, return the original records and log a warning
+        if extracted_metrics.shape[1] == 0:
+            _LOGGER.warning(
+                "No text descriptives metrics were extracted. This could be because the metrics contained NaNs."
+            )
+            return records
+        else:
+            # Cast integer and boolean columns to Python native types
+            extracted_metrics = self._cast_to_python_types(extracted_metrics)
+            # Clean column names
+            extracted_metrics.columns = [self._clean_column_name(col) for col in extracted_metrics.columns]
+            # Add the metrics to the metadata of the records
+            modified_records = self._add_text_descriptives_to_metadata(records, extracted_metrics)
+            return modified_records
+
+    def update_dataset(
+        self, dataset: Union[FeedbackDataset, RemoteFeedbackDataset]
+    ) -> Union[FeedbackDataset, RemoteFeedbackDataset]:
+        """
+        Extract text descriptives metrics from records in a FeedbackDataset
+        or RemoteFeedbackDataset, add them as metadata to the records and
+        return the updated dataset.
+
+        Args:
+            dataset (Union[FeedbackDataset, RemoteFeedbackDataset]): A FeedbackDataset or RemoteFeedbackDataset.
+
+        Returns:
+            Union[FeedbackDataset, RemoteFeedbackDataset]: A FeedbackDataset or RemoteFeedbackDataset with text descriptives metrics added as metadata.
+
+        >>> import argilla as rg
+        >>> from argilla.client.feedback.integrations.textdescriptives import TextDescriptivesExtractor
+        >>> rg.init(...)
+        >>> dataset = rg.FeedbackDataset.from_argilla(name="my-dataset")
+        >>> tde = TextDescriptivesExtractor()
+        >>> updated_dataset = tde.update_dataset(dataset)
+        """
+        if isinstance(dataset, (FeedbackDataset, RemoteFeedbackDataset)):
+            records = dataset.records
+        else:
+            raise ValueError(
+                f"Provided object is of `type={type(dataset)}` while only `type=FeedbackDataset` or `type=RemoteFeedbackDataset` are allowed."
+ ) + # Extract text descriptives metrics from records + extracted_metrics = self._extract_metrics_for_all_fields(records) + # Cast integer and boolean columns to Python native types + extracted_metrics = self._cast_to_python_types(extracted_metrics) + # Clean column names + extracted_metrics.columns = [self._clean_column_name(col) for col in extracted_metrics.columns] + # Create metadata properties based on dataframe columns and data types + metadata_properties = self._create_metadata_properties(extracted_metrics) + # Add each metadata property iteratively to the dataset + [dataset.add_metadata_property(prop) for prop in metadata_properties] + # Add the metrics to the metadata + if isinstance(dataset, FeedbackDataset): + with Progress() as progress_bar: + task = progress_bar.add_task( + "Adding text descriptives to metadata...", total=len(records), visible=self.show_progress + ) + for record, metrics in zip(records, extracted_metrics.to_dict("records")): + filtered_metrics = {key: value for key, value in metrics.items() if not pd.isna(value)} + record.metadata.update(filtered_metrics) + progress_bar.update(task, advance=1) + elif isinstance(dataset, RemoteFeedbackDataset): + modified_records = self._add_text_descriptives_to_metadata(records, extracted_metrics) + dataset = dataset.update_records(modified_records) + return dataset diff --git a/tests/unit/client/feedback/integrations/test_textdescriptives_.py b/tests/unit/client/feedback/integrations/test_textdescriptives_.py new file mode 100644 index 0000000000..61d0b49845 --- /dev/null +++ b/tests/unit/client/feedback/integrations/test_textdescriptives_.py @@ -0,0 +1,285 @@ +# Copyright 2021-present, the Recognai S.L. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+
+from unittest.mock import MagicMock
+
+import pandas as pd
+import pytest
+from argilla.client.feedback.dataset import FeedbackDataset
+from argilla.client.feedback.integrations.textdescriptives import TextDescriptivesExtractor
+from argilla.client.feedback.schemas.fields import TextField
+from argilla.client.feedback.schemas.metadata import (
+    FloatMetadataProperty,
+    IntegerMetadataProperty,
+    TermsMetadataProperty,
+)
+from argilla.client.feedback.schemas.questions import TextQuestion
+from argilla.client.feedback.schemas.records import FeedbackRecord
+
+
+@pytest.fixture
+def records():
+    return [
+        FeedbackRecord(fields={"text": "This is a test."}),
+        FeedbackRecord(fields={"text": "This is another test."}),
+    ]
+
+
+@pytest.mark.parametrize(
+    "records",
+    [
+        [
+            FeedbackRecord(
+                fields={"required-field": "This is a test.", "optional-field": None},
+            ),
+            FeedbackRecord(
+                fields={"required-field": "This is another test.", "optional-field": None},
+            ),
+        ],
+        [
+            FeedbackRecord(
+                fields={"required-field": "This is a test.", "optional-field": "This is also a test."},
+            ),
+            FeedbackRecord(
+                fields={"required-field": "This is another test.", "optional-field": "This is also another test."},
+            ),
+        ],
+        [
+            FeedbackRecord(
+                fields={"required-field": "This is a test."},
+                metadata={"text_n_tokens": 5, "text_n_unique_tokens": 4},
+            ),
+            FeedbackRecord(
+                fields={"required-field": "This is another test."},
+                metadata={"text_n_tokens": 5, "text_n_unique_tokens": 4},
+            ),
+        ],
+    ],
+)
+def test_extract_metrics_for_single_field(records) -> None:
+    tde = TextDescriptivesExtractor()
+    field_metrics = tde._extract_metrics_for_single_field(records, "required-field")
+    assert isinstance(field_metrics, pd.DataFrame)  # Assert the data type of the DataFrame
+    assert len(field_metrics) == len(records)  # Assert the number of rows in the DataFrame
+    assert "required-field_n_tokens" in field_metrics.columns  # Assert the presence of the column
+    assert field_metrics["required-field_n_tokens"].values[0] == 4  # Assert the value of the column
+    assert "required-field" not in field_metrics.columns  # Assert that the text column has been dropped
+    assert not field_metrics.isnull().values.any()  # Assert no columns with NaN values
+    assert "optional-field" not in field_metrics.columns
+
+
+def test_extract_metrics_for_single_field_empty_field() -> None:
+    tde = TextDescriptivesExtractor()
+    records = [
+        FeedbackRecord(
+            fields={"required-field": "This is a test.", "optional-field": None},
+        ),
+        FeedbackRecord(
+            fields={"required-field": "This is another test.", "optional-field": None},
+        ),
+    ]
+    field_metrics = tde._extract_metrics_for_single_field(records, "optional-field")
+    assert field_metrics is None
+
+
+@pytest.mark.parametrize(
+    "records",
+    [
+        [
+            FeedbackRecord(
+                fields={"required-field": "This is a test.", "optional-field": None},
+            ),
+            FeedbackRecord(
+                fields={"required-field": "This is another test.", "optional-field": None},
+            ),
+        ],
+        [
+            FeedbackRecord(
+                fields={"required-field": "This is a test.", "optional-field": "This is also a test."},
+            ),
+            FeedbackRecord(
+                fields={"required-field": "This is another test.", "optional-field": "This is also another test."},
+            ),
+        ],
+        [
+            FeedbackRecord(
+                fields={"required-field": "This is a test."},
+                metadata={"text_n_tokens": 5, "text_n_unique_tokens": 4},
+            ),
+            FeedbackRecord(
+                fields={"required-field": "This is another test."},
+                metadata={"text_n_tokens": 5, "text_n_unique_tokens": 4},
"text_n_unique_tokens": 4}, + ), + ], + ], +) +def test_extract_metrics_for_all_fields(records) -> None: + tde = TextDescriptivesExtractor() + field_metrics = tde._extract_metrics_for_all_fields(records) + expected_fields = [key for record in records for key, value in record.fields.items() if value is not None] + assert field_metrics["required-field_n_tokens"].values[0] == 4 + assert all(any(field == col or field + "_" in col for col in field_metrics.columns) for field in expected_fields) + + +def test_cast_to_python_types() -> None: + tde = TextDescriptivesExtractor() + df = pd.DataFrame( + { + "col_int": [1, 2, 3], + "col_bool": [True, False, True], + "col_float": [1.234, 2.345, 3.456], + } + ) + df_result = tde._cast_to_python_types(df) + assert df_result["col_int"].dtype == "int32" or df_result["col_int"].dtype == "int64" + assert df_result["col_bool"].dtype == "object" + assert df_result["col_float"].dtype == "float64" or df_result["col_float"].dtype == "float32" + assert df_result["col_float"].values[0] == 1.23 + assert isinstance(df_result, pd.DataFrame) + + +def test_clean_column_name() -> None: + tde = TextDescriptivesExtractor() + assert tde._clean_column_name("Test_Col") == "test_col" + assert tde._clean_column_name("test col") == "test_col" + assert tde._clean_column_name("Test-Col") == "test_col" + assert tde._clean_column_name("Test.Col") == "test_col" + + +@pytest.mark.parametrize( + "column_name, expected_prop_type, expected_title, expected_visible, expected_type, expected_values", + [ + ("col_int", IntegerMetadataProperty, "Col Int", True, "integer", None), + ("col_bool", TermsMetadataProperty, "Col Bool", True, "terms", ["True", "False"]), + ("col_float", FloatMetadataProperty, "Col Float", True, "float", None), + ("col_obj", TermsMetadataProperty, "Col Obj", True, "terms", ["value_1", "value_2", "value_3"]), + ], +) +def test_create_metadata_properties( + column_name, expected_prop_type, expected_title, expected_visible, expected_type, expected_values +) -> None: + tde = TextDescriptivesExtractor() + df = pd.DataFrame( + { + "col_int": pd.Series([1, 2, 3], dtype="int32"), + "col_bool": pd.Series([True, False, True], dtype="bool"), + "col_float": pd.Series([1.234, 2.345, 3.456], dtype="float64"), + "col_obj": pd.Series(["value_1", "value_2", "value_3"], dtype="object"), + } + ) + properties = tde._create_metadata_properties(df) + prop = next((prop for prop in properties if prop.name == column_name), None) + assert isinstance(prop, expected_prop_type) + assert prop.name == column_name + assert prop.title == expected_title + assert prop.visible_for_annotators == expected_visible + assert prop.type == expected_type + if isinstance(prop, TermsMetadataProperty): + assert prop.values == expected_values + + +@pytest.mark.parametrize( + "records", + [ + [ + FeedbackRecord( + fields={"required-field": "This is a test.", "optional-field": None}, + ), + FeedbackRecord( + fields={"required-field": "This is another test.", "optional-field": None}, + ), + ], + [ + FeedbackRecord( + fields={"required-field": "This is a test.", "optional-field": "This is also a test."}, + ), + FeedbackRecord( + fields={"required-field": "This is another test.", "optional-field": "This is also another test."}, + ), + ], + [ + FeedbackRecord( + fields={"required-field": "This is a test."}, + metadata={"text_n_tokens": 5, "text_n_unique_tokens": 4}, + ), + FeedbackRecord( + fields={"required-field": "This is another test."}, + metadata={"text_n_tokens": 5, "text_n_unique_tokens": 4}, + ), + ], + ], +) +def 
+def test_update_records_metrics_extracted(records) -> None:
+    tde = TextDescriptivesExtractor()
+    extracted_metrics = pd.DataFrame({"text_n_tokens": [4, 5]})
+    tde._extract_metrics_for_all_fields = MagicMock(return_value=extracted_metrics)
+    tde._cast_to_python_types = MagicMock(return_value=extracted_metrics)
+    tde._clean_column_name = MagicMock(side_effect=lambda col: col)
+    tde._add_text_descriptives_to_metadata = MagicMock(return_value=records)
+    updated_records = tde.update_records(records)
+    tde._extract_metrics_for_all_fields.assert_called_once_with(records)
+    tde._cast_to_python_types.assert_called_once_with(extracted_metrics)
+    tde._clean_column_name.assert_called_with("text_n_tokens")
+    tde._add_text_descriptives_to_metadata.assert_called_once_with(records, extracted_metrics)
+    assert updated_records == records
+
+
+def test_update_records_no_metrics_extracted(records):
+    tde = TextDescriptivesExtractor()
+    tde._extract_metrics_for_all_fields = MagicMock(return_value=pd.DataFrame())
+    updated_records = tde.update_records(records)
+    assert updated_records == records
+
+
+def test_update_feedback_dataset():
+    dataset = FeedbackDataset(
+        fields=[TextField(name="text")],
+        questions=[TextQuestion(name="question")],
+    )
+    records = [
+        FeedbackRecord(fields={"text": "This is a test."}),
+        FeedbackRecord(fields={"text": "This is another test."}),
+    ]
+    dataset.add_records(records)
+
+    tde = TextDescriptivesExtractor()
+
+    extracted_metrics = pd.DataFrame({"text_n_tokens": [4, 5]})
+    tde._extract_metrics_for_all_fields = MagicMock(return_value=extracted_metrics)
+    tde._cast_to_python_types = MagicMock(return_value=extracted_metrics)
+    tde._clean_column_name = MagicMock(side_effect=lambda col: col)
+    tde._create_metadata_properties = MagicMock(return_value=[IntegerMetadataProperty(name="text_n_tokens")])
+
+    updated_dataset = tde.update_dataset(dataset)
+
+    tde._extract_metrics_for_all_fields.assert_called_once_with(records)
+    tde._cast_to_python_types.assert_called_once_with(extracted_metrics)
+    tde._clean_column_name.assert_called_with("text_n_tokens")
+    tde._create_metadata_properties.assert_called_once_with(extracted_metrics)
+    assert updated_dataset == dataset
+    assert isinstance(updated_dataset, FeedbackDataset)
+    assert updated_dataset.metadata_properties == [
+        IntegerMetadataProperty(
+            name="text_n_tokens", title="text_n_tokens", visible_for_annotators=True, type="integer", min=None, max=None
+        )
+    ]
+
+
+def test_update_dataset_with_invalid_dataset():
+    tde = TextDescriptivesExtractor()
+    dataset = "invalid_dataset"
+
+    with pytest.raises(ValueError):
+        tde.update_dataset(dataset)