Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: delete suggestion from record on search engine #4336

Merged
merged 4 commits into from
Nov 28, 2023
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion src/argilla/server/apis/v1/handlers/suggestions.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
from argilla.server.models import Suggestion, User
from argilla.server.policies import SuggestionPolicyV1, authorize
from argilla.server.schemas.v1.suggestions import Suggestion as SuggestionSchema
from argilla.server.search_engine import SearchEngine, get_search_engine
from argilla.server.security import auth

router = APIRouter(tags=["suggestions"])
Expand All @@ -41,6 +42,7 @@ async def _get_suggestion(db: "AsyncSession", suggestion_id: UUID) -> Suggestion
async def delete_suggestion(
*,
db: AsyncSession = Depends(get_async_db),
search_engine: SearchEngine = Depends(get_search_engine),
suggestion_id: UUID,
current_user: User = Security(auth.get_current_user),
):
Expand All @@ -49,6 +51,6 @@ async def delete_suggestion(
await authorize(current_user, SuggestionPolicyV1.delete(suggestion))

try:
return await datasets.delete_suggestion(db, suggestion)
return await datasets.delete_suggestion(db, search_engine, suggestion)
except ValueError as err:
raise HTTPException(status_code=status.HTTP_422_UNPROCESSABLE_ENTITY, detail=str(err))
18 changes: 15 additions & 3 deletions src/argilla/server/contexts/datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -1092,20 +1092,32 @@ async def upsert_suggestion(

async def delete_suggestions(db: "AsyncSession", record: Record, suggestions_ids: List[UUID]) -> None:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this code may not be affected?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's clearly affected yes. I will take a look.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We discussed it and we will take a first approach reindexing suggestions individually.

params = [Suggestion.id.in_(suggestions_ids), Suggestion.record_id == record.id]

await Suggestion.delete_many(db=db, params=params)


async def get_suggestion_by_id(db: "AsyncSession", suggestion_id: "UUID") -> Union[Suggestion, None]:
result = await db.execute(
select(Suggestion)
.filter_by(id=suggestion_id)
.options(selectinload(Suggestion.record).selectinload(Record.dataset))
.options(
selectinload(Suggestion.record).selectinload(Record.dataset),
selectinload(Suggestion.question),
)
)

return result.scalar_one_or_none()


async def delete_suggestion(db: "AsyncSession", suggestion: Suggestion) -> Suggestion:
return await suggestion.delete(db)
async def delete_suggestion(db: "AsyncSession", search_engine: SearchEngine, suggestion: Suggestion) -> Suggestion:
async with db.begin_nested():
suggestion = await suggestion.delete(db, autocommit=False)
# TODO: Should we touch here dataset last_activity?
jfcalvo marked this conversation as resolved.
Show resolved Hide resolved
await search_engine.delete_record_suggestion(suggestion)

await db.commit()

return suggestion


async def get_metadata_property_by_id(db: "AsyncSession", metadata_property_id: UUID) -> Optional[MetadataProperty]:
Expand Down
8 changes: 5 additions & 3 deletions src/argilla/server/search_engine/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@
Generic,
Iterable,
List,
Literal,
Optional,
Type,
TypeVar,
Expand All @@ -36,12 +35,11 @@
from argilla.server.enums import (
MetadataPropertyType,
RecordSortField,
ResponseStatus,
ResponseStatusFilter,
SimilarityOrder,
SortOrder,
)
from argilla.server.models import Dataset, MetadataProperty, Record, Response, User, Vector, VectorSettings
from argilla.server.models import Dataset, MetadataProperty, Record, Response, Suggestion, User, Vector, VectorSettings

__all__ = [
"SearchEngine",
Expand Down Expand Up @@ -311,6 +309,10 @@ async def update_record_response(self, response: Response):
async def delete_record_response(self, response: Response):
pass

@abstractmethod
async def delete_record_suggestion(self, suggestion: Suggestion):
pass

@abstractmethod
async def search(
self,
Expand Down
9 changes: 9 additions & 0 deletions src/argilla/server/search_engine/commons.py
Original file line number Diff line number Diff line change
Expand Up @@ -319,6 +319,15 @@ async def delete_record_response(self, response: Response):
index_name, id=record.id, body={"script": f'ctx._source["responses"].remove("{response.user.username}")'}
)

async def delete_record_suggestion(self, suggestion: Suggestion):
index_name = await self._get_index_or_raise(suggestion.record.dataset)

await self._update_document_request(
index_name,
id=suggestion.record_id,
body={"script": f'ctx._source["suggestions"].remove("{suggestion.question.name}")'},
)

async def set_records_vectors(self, dataset: Dataset, vectors: Iterable[Vector]):
index_name = await self._get_index_or_raise(dataset)

Expand Down
7 changes: 6 additions & 1 deletion tests/unit/server/api/v1/test_suggestions.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import pytest
from argilla._constants import API_KEY_HEADER_NAME
from argilla.server.models import Suggestion, UserRole
from argilla.server.search_engine import SearchEngine
from sqlalchemy import func, select

from tests.factories import SuggestionFactory, UserFactory
Expand All @@ -30,7 +31,9 @@
@pytest.mark.asyncio
class TestSuiteSuggestions:
@pytest.mark.parametrize("role", [UserRole.admin, UserRole.owner])
async def test_delete_suggestion(self, async_client: "AsyncClient", db: "AsyncSession", role: UserRole) -> None:
async def test_delete_suggestion(
self, async_client: "AsyncClient", mock_search_engine: SearchEngine, db: "AsyncSession", role: UserRole
) -> None:
suggestion = await SuggestionFactory.create()
user = await UserFactory.create(role=role, workspaces=[suggestion.record.dataset.workspace])

Expand All @@ -50,6 +53,8 @@ async def test_delete_suggestion(self, async_client: "AsyncClient", db: "AsyncSe
}
assert (await db.execute(select(func.count(Suggestion.id)))).scalar() == 0

mock_search_engine.delete_record_suggestion.assert_called_once_with(suggestion)

async def test_delete_suggestion_non_existent(self, async_client: "AsyncClient", owner_auth_header: dict) -> None:
response = await async_client.delete(f"/api/v1/suggestions/{uuid4()}", headers=owner_auth_header)

Expand Down
Loading