diff --git a/autogpts/autogpt/autogpt/agents/agent.py b/autogpts/autogpt/autogpt/agents/agent.py
index 8460292a2fc3..245097a0f610 100644
--- a/autogpts/autogpt/autogpt/agents/agent.py
+++ b/autogpts/autogpt/autogpt/agents/agent.py
@@ -147,7 +147,7 @@ def execute(
         result: ActionResult

         if command_name == "human_feedback":
-            result = ActionInterruptedByHuman(user_input)
+            result = ActionInterruptedByHuman(feedback=user_input)
             self.message_history.add(
                 "user",
                 "I interrupted the execution of the command you proposed "
@@ -185,9 +185,9 @@ def execute(
                 )
                 self.context.add(context_item)

-            result = ActionSuccessResult(return_value)
+            result = ActionSuccessResult(outputs=return_value)
         except AgentException as e:
-            result = ActionErrorResult(e.message, e)
+            result = ActionErrorResult(reason=e.message, error=e)

         result_tlength = count_string_tokens(str(result), self.llm.name)
         history_tlength = count_string_tokens(
diff --git a/autogpts/autogpt/autogpt/agents/base.py b/autogpts/autogpt/autogpt/agents/base.py
index e530557eebf1..51ba1be11269 100644
--- a/autogpts/autogpt/autogpt/agents/base.py
+++ b/autogpts/autogpt/autogpt/agents/base.py
@@ -16,7 +16,7 @@
 from autogpt.llm.providers.openai import OPEN_AI_CHAT_MODELS, get_openai_command_specs
 from autogpt.llm.utils import count_message_tokens, create_chat_completion
 from autogpt.memory.message_history import MessageHistory
-from autogpt.models.agent_actions import ActionHistory, ActionResult
+from autogpt.models.agent_actions import EpisodicActionHistory, ActionResult
 from autogpt.prompts.generator import PromptGenerator
 from autogpt.prompts.prompt import DEFAULT_TRIGGERING_PROMPT
@@ -90,10 +90,10 @@ def __init__(
             defaults to 75% of `llm.max_tokens`.
         """

-        self.event_history = ActionHistory()
+        self.event_history = EpisodicActionHistory()

         self.message_history = MessageHistory(
-            self.llm,
+            model=self.llm,
             max_summary_tlength=summary_max_tlength or self.send_token_limit // 6,
         )
diff --git a/autogpts/autogpt/autogpt/agents/features/watchdog.py b/autogpts/autogpt/autogpt/agents/features/watchdog.py
index caddf195c5f9..7c43ef90599d 100644
--- a/autogpts/autogpt/autogpt/agents/features/watchdog.py
+++ b/autogpts/autogpt/autogpt/agents/features/watchdog.py
@@ -3,7 +3,7 @@
 import logging
 from contextlib import ExitStack

-from autogpt.models.agent_actions import ActionHistory
+from autogpt.models.agent_actions import EpisodicActionHistory

 from ..base import BaseAgent

@@ -16,7 +16,7 @@ class WatchdogMixin:
     looping, the watchdog will switch from the FAST_LLM to the SMART_LLM and re-think.
     """

-    event_history: ActionHistory
+    event_history: EpisodicActionHistory

     def __init__(self, **kwargs) -> None:
         # Initialize other bases first, because we need the event_history from BaseAgent
@@ -38,7 +38,7 @@ def think(self, *args, **kwargs) -> BaseAgent.ThoughtProcessOutput:
             and self.config.fast_llm != self.config.smart_llm
         ):
             # Detect repetitive commands
-            previous_cycle = self.event_history.cycles[self.event_history.cursor - 1]
+            previous_cycle = self.event_history.episodes[self.event_history.cursor - 1]
             if (
                 command_name == previous_cycle.action.name
                 and command_args == previous_cycle.action.args
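Note on the call-site changes in agent.py above: pydantic generates a keyword-only `__init__`, so positional construction such as `ActionSuccessResult(return_value)` stops working once the result classes become `BaseModel` subclasses (see agent_actions.py further down). A minimal sketch of that behavior, using an illustrative model name rather than one from this repo:

    from pydantic import BaseModel

    class Result(BaseModel):
        outputs: str

    Result(outputs="ok")  # fine: fields are passed by keyword
    # Result("ok")        # TypeError: __init__ accepts no positional arguments

The same reasoning applies to every `feedback=`, `reason=`, `error=`, and `model=` keyword added throughout this patch.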
""" - event_history: ActionHistory + event_history: EpisodicActionHistory def __init__(self, **kwargs) -> None: # Initialize other bases first, because we need the event_history from BaseAgent @@ -38,7 +38,7 @@ def think(self, *args, **kwargs) -> BaseAgent.ThoughtProcessOutput: and self.config.fast_llm != self.config.smart_llm ): # Detect repetitive commands - previous_cycle = self.event_history.cycles[self.event_history.cursor - 1] + previous_cycle = self.event_history.episodes[self.event_history.cursor - 1] if ( command_name == previous_cycle.action.name and command_args == previous_cycle.action.args diff --git a/autogpts/autogpt/autogpt/agents/planning_agent.py b/autogpts/autogpt/autogpt/agents/planning_agent.py index f4b6fa4ce7c4..46e6615b214f 100644 --- a/autogpts/autogpt/autogpt/agents/planning_agent.py +++ b/autogpts/autogpt/autogpt/agents/planning_agent.py @@ -23,7 +23,7 @@ ) from autogpt.models.agent_actions import ( ActionErrorResult, - ActionHistory, + EpisodicActionHistory, ActionInterruptedByHuman, ActionResult, ActionSuccessResult, @@ -69,7 +69,7 @@ def __init__( self.log_cycle_handler = LogCycleHandler() """LogCycleHandler for structured debug logging.""" - self.action_history = ActionHistory() + self.action_history = EpisodicActionHistory() self.plan: list[str] = [] """List of steps that the Agent plans to take""" @@ -229,7 +229,7 @@ def on_before_think(self, *args, **kwargs) -> ChatSequence: self.ai_config.ai_name, self.created_at, self.cycle_count, - self.action_history.cycles, + self.action_history.episodes, "action_history.json", ) self.log_cycle_handler.log_cycle( @@ -250,7 +250,7 @@ def execute( result: ActionResult if command_name == "human_feedback": - result = ActionInterruptedByHuman(user_input) + result = ActionInterruptedByHuman(feedback=user_input) self.log_cycle_handler.log_cycle( self.ai_config.ai_name, self.created_at, @@ -279,9 +279,9 @@ def execute( self.context.add(return_value[1]) return_value = return_value[0] - result = ActionSuccessResult(return_value) + result = ActionSuccessResult(outputs=return_value) except AgentException as e: - result = ActionErrorResult(e.message, e) + result = ActionErrorResult(reason=e.message, error=e) result_tlength = count_string_tokens(str(result), self.llm.name) memory_tlength = count_string_tokens( diff --git a/autogpts/autogpt/autogpt/commands/file_context.py b/autogpts/autogpt/autogpt/commands/file_context.py index c75e45621ffc..61bead91a9dd 100644 --- a/autogpts/autogpt/autogpt/commands/file_context.py +++ b/autogpts/autogpt/autogpt/commands/file_context.py @@ -67,7 +67,10 @@ def open_file(file_path: Path, agent: Agent) -> tuple[str, FileContextItem]: file_path = relative_file_path or file_path - file = FileContextItem(file_path, agent.workspace.root) + file = FileContextItem( + file_path_in_workspace=file_path, + workspace_path=agent.workspace.root, + ) if file in agent_context: raise DuplicateOperationError(f"The file {file_path} is already open") @@ -114,7 +117,10 @@ def open_folder(path: Path, agent: Agent) -> tuple[str, FolderContextItem]: path = relative_path or path - folder = FolderContextItem(path, agent.workspace.root) + folder = FolderContextItem( + path_in_workspace=path, + workspace_path=agent.workspace.root, + ) if folder in agent_context: raise DuplicateOperationError(f"The folder {path} is already open") diff --git a/autogpts/autogpt/autogpt/config/ai_config.py b/autogpts/autogpt/autogpt/config/ai_config.py index 1070968984ff..392a0198519f 100644 --- a/autogpts/autogpt/autogpt/config/ai_config.py +++ 
diff --git a/autogpts/autogpt/autogpt/config/ai_config.py b/autogpts/autogpt/autogpt/config/ai_config.py
index 1070968984ff..392a0198519f 100644
--- a/autogpts/autogpt/autogpt/config/ai_config.py
+++ b/autogpts/autogpt/autogpt/config/ai_config.py
@@ -1,14 +1,13 @@
 """A module that contains the AIConfig class object that contains the configuration"""
 from __future__ import annotations

-from dataclasses import dataclass, field
 from pathlib import Path

+from pydantic import BaseModel, Field
 import yaml


-@dataclass
-class AIConfig:
+class AIConfig(BaseModel):
     """
     A class object that contains the configuration information for the AI

@@ -21,7 +20,7 @@ class AIConfig:

     ai_name: str = ""
     ai_role: str = ""
-    ai_goals: list[str] = field(default_factory=list[str])
+    ai_goals: list[str] = Field(default_factory=list[str])
     api_budget: float = 0.0

     @staticmethod
@@ -53,7 +52,12 @@ def load(ai_settings_file: str | Path) -> "AIConfig":
         ]
         api_budget = config_params.get("api_budget", 0.0)

-        return AIConfig(ai_name, ai_role, ai_goals, api_budget)
+        return AIConfig(
+            ai_name=ai_name,
+            ai_role=ai_role,
+            ai_goals=ai_goals,
+            api_budget=api_budget
+        )

     def save(self, ai_settings_file: str | Path) -> None:
         """
@@ -66,11 +70,5 @@ def save(self, ai_settings_file: str | Path) -> None:
             None
         """
-        config = {
-            "ai_name": self.ai_name,
-            "ai_role": self.ai_role,
-            "ai_goals": self.ai_goals,
-            "api_budget": self.api_budget,
-        }

         with open(ai_settings_file, "w", encoding="utf-8") as file:
-            yaml.dump(config, file, allow_unicode=True)
+            yaml.dump(self.dict(), file, allow_unicode=True)
diff --git a/autogpts/autogpt/autogpt/config/ai_directives.py b/autogpts/autogpt/autogpt/config/ai_directives.py
index 76340c68e3f4..38f169bea1e0 100644
--- a/autogpts/autogpt/autogpt/config/ai_directives.py
+++ b/autogpts/autogpt/autogpt/config/ai_directives.py
@@ -1,9 +1,9 @@
 from __future__ import annotations

 import logging
-from dataclasses import dataclass

 import yaml
+from pydantic import BaseModel

 from autogpt.logs.helpers import request_user_double_check
 from autogpt.utils import validate_yaml_file
@@ -11,8 +11,7 @@
 logger = logging.getLogger(__name__)


-@dataclass
-class AIDirectives:
+class AIDirectives(BaseModel):
     """An object that contains the basic directives for the AI prompt.

     Attributes:
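Note on ai_config.py above: once AIConfig is a pydantic model, `save()` can serialize via `self.dict()` instead of a hand-maintained dict, so any field added later is written out automatically. A rough usage sketch (pydantic v1 API; the settings path is illustrative):

    config = AIConfig(ai_name="Entrepreneur-GPT", ai_role="...", ai_goals=["..."])
    config.save("ai_settings.yaml")          # internally: yaml.dump(self.dict(), ...)
    restored = AIConfig.load("ai_settings.yaml")
    # restored should compare equal to config; pydantic models define field-wise __eq__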
diff --git a/autogpts/autogpt/autogpt/llm/base.py b/autogpts/autogpt/autogpt/llm/base.py
index 1ac00112d931..99afac8f68f1 100644
--- a/autogpts/autogpt/autogpt/llm/base.py
+++ b/autogpts/autogpt/autogpt/llm/base.py
@@ -1,12 +1,9 @@
 from __future__ import annotations
+import json

-from copy import deepcopy
-from dataclasses import dataclass, field
 from math import ceil, floor
-from typing import TYPE_CHECKING, Literal, Optional, Type, TypedDict, TypeVar, overload
-
-if TYPE_CHECKING:
-    from autogpt.llm.providers.openai import OpenAIFunctionCall
+from pydantic import BaseModel, Field
+from typing import Any, Literal, Optional, Type, TypedDict, TypeVar, overload

 MessageRole = Literal["system", "user", "assistant", "function"]
 MessageType = Literal["ai_response", "action_result"]
@@ -31,20 +28,30 @@ class FunctionCallDict(TypedDict):
     arguments: str


-@dataclass
-class Message:
+class Message(BaseModel):
     """OpenAI Message object containing a role and the message content"""

     role: MessageRole
     content: str
-    type: MessageType | None = None
+    type: Optional[MessageType]
+
+    def __init__(
+        self,
+        role: MessageRole,
+        content: str,
+        type: Optional[MessageType] = None
+    ):
+        super().__init__(
+            role=role,
+            content=content,
+            type=type,
+        )

     def raw(self) -> MessageDict:
         return {"role": self.role, "content": self.content}


-@dataclass
-class ModelInfo:
+class ModelInfo(BaseModel):
     """Struct for model information.

     Would be lovely to eventually get this directly from APIs, but needs to be scraped from
@@ -56,26 +63,22 @@ class ModelInfo:
     prompt_token_cost: float


-@dataclass
 class CompletionModelInfo(ModelInfo):
     """Struct for generic completion model information."""

     completion_token_cost: float


-@dataclass
 class ChatModelInfo(CompletionModelInfo):
     """Struct for chat model information."""

     supports_functions: bool = False


-@dataclass
 class TextModelInfo(CompletionModelInfo):
     """Struct for text completion model information."""


-@dataclass
 class EmbeddingModelInfo(ModelInfo):
     """Struct for embedding model information."""

@@ -86,12 +89,11 @@ class EmbeddingModelInfo(ModelInfo):
 TChatSequence = TypeVar("TChatSequence", bound="ChatSequence")


-@dataclass
-class ChatSequence:
+class ChatSequence(BaseModel):
     """Utility container for a chat sequence"""

     model: ChatModelInfo
-    messages: list[Message] = field(default_factory=list[Message])
+    messages: list[Message] = Field(default_factory=list[Message])

     @overload
     def __getitem__(self, key: int) -> Message:
@@ -103,7 +105,7 @@ def __getitem__(self: TChatSequence, key: slice) -> TChatSequence:

     def __getitem__(self: TChatSequence, key: int | slice) -> Message | TChatSequence:
         if isinstance(key, slice):
-            copy = deepcopy(self)
+            copy = self.copy(deep=True)
             copy.messages = self.messages[key]
             return copy
         return self.messages[key]
@@ -141,7 +143,7 @@ def for_model(
     ) -> TChatSequence:
         from autogpt.llm.providers.openai import OPEN_AI_CHAT_MODELS

-        if not model_name in OPEN_AI_CHAT_MODELS:
+        if model_name not in OPEN_AI_CHAT_MODELS:
             raise ValueError(f"Unknown chat model '{model_name}'")

         return cls(
@@ -175,23 +177,43 @@ def separator(text: str):
     """


-@dataclass
-class LLMResponse:
+class LLMResponse(BaseModel):
     """Standard response struct for a response from an LLM model."""

     model_info: ModelInfo


-@dataclass
 class EmbeddingModelResponse(LLMResponse):
     """Standard response struct for a response from an embedding model."""

-    embedding: list[float] = field(default_factory=list)
+    embedding: list[float] = Field(default_factory=list)


-@dataclass
 class ChatModelResponse(LLMResponse):
     """Standard response struct for a response from a chat LLM."""

     content: Optional[str]
-    function_call: Optional[OpenAIFunctionCall]
+    function_call: Optional[LLMFunctionCall]
+
+
+class LLMFunctionCall(BaseModel):
+    """Represents a function call as generated by an OpenAI model
+
+    Attributes:
+        name: the name of the function that the LLM wants to call
+        arguments: a stringified JSON object (unverified) containing `arg: value` pairs
+    """
+
+    name: str
+    arguments: dict[str, Any] = {}
+
+    @staticmethod
+    def parse(raw: FunctionCallDict):
+        return LLMFunctionCall(
+            name=raw["name"],
+            arguments=json.loads(raw["arguments"]),
+        )
+
+
+# Complete model initialization; necessary because of order of definition
+ChatModelResponse.update_forward_refs()
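Note on llm/base.py above: `ChatModelResponse.function_call` references `LLMFunctionCall` before that class is defined. Under `from __future__ import annotations` this is legal, but pydantic keeps the annotation as an unresolved forward reference until `update_forward_refs()` is called, hence the call at the bottom of the module. A condensed illustration of the mechanism (pydantic v1, illustrative names):

    from __future__ import annotations

    from typing import Optional

    from pydantic import BaseModel

    class Response(BaseModel):
        call: Optional[Call]  # forward reference: Call is defined below

    class Call(BaseModel):
        name: str

    Response.update_forward_refs()   # resolves the string annotation to the class
    Response(call=Call(name="f"))    # without the line above, this raises ConfigError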
diff --git a/autogpts/autogpt/autogpt/llm/providers/openai.py b/autogpts/autogpt/autogpt/llm/providers/openai.py
index 433e5b7b1fc8..b018604dcc7b 100644
--- a/autogpts/autogpt/autogpt/llm/providers/openai.py
+++ b/autogpts/autogpt/autogpt/llm/providers/openai.py
@@ -4,7 +4,7 @@
 import logging
 import time
 from dataclasses import dataclass
-from typing import Callable, List, Optional
+from typing import Callable, List, Optional, TypeVar
 from unittest.mock import patch

 import openai
@@ -118,8 +118,10 @@
     **OPEN_AI_EMBEDDING_MODELS,
 }

+T = TypeVar("T", bound=Callable)

-def meter_api(func: Callable):
+
+def meter_api(func: T) -> T:
     """Adds ApiManager metering to functions which make OpenAI API calls"""
     from autogpt.llm.api_manager import ApiManager

@@ -145,6 +147,7 @@ def metering_wrapper(*args, **kwargs):
             update_usage_with_response(openai_obj)
         return openai_obj

+    @functools.wraps(func)
     def metered_func(*args, **kwargs):
         with patch.object(
             engine_api_resource.util,
@@ -179,7 +182,7 @@ def retry_api(
     )
     backoff_msg = "Waiting {backoff} seconds..."

-    def _wrapper(func: Callable):
+    def _wrapper(func: T) -> T:
         @functools.wraps(func)
         def _wrapped(*args, **kwargs):
             user_warned = not warn_user
@@ -286,19 +289,6 @@ def create_embedding(
     )


-@dataclass
-class OpenAIFunctionCall:
-    """Represents a function call as generated by an OpenAI model
-
-    Attributes:
-        name: the name of the function that the LLM wants to call
-        arguments: a stringified JSON object (unverified) containing `arg: value` pairs
-    """
-
-    name: str
-    arguments: str
-
-
 @dataclass
 class OpenAIFunctionSpec:
     """Represents a "function" in OpenAI, which is mapped to a Command in Auto-GPT"""
diff --git a/autogpts/autogpt/autogpt/llm/utils/__init__.py b/autogpts/autogpt/autogpt/llm/utils/__init__.py
index 5438bdd853cd..f51666c2b18b 100644
--- a/autogpts/autogpt/autogpt/llm/utils/__init__.py
+++ b/autogpts/autogpt/autogpt/llm/utils/__init__.py
@@ -1,27 +1,30 @@
 from __future__ import annotations

-from typing import List, Literal, Optional
+import logging
+from typing import Optional

 from colorama import Fore

 from autogpt.config import Config

-from ..api_manager import ApiManager
 from ..base import (
     ChatModelResponse,
     ChatSequence,
     FunctionCallDict,
+    LLMFunctionCall,
     Message,
     ResponseMessageDict,
 )
 from ..providers import openai as iopenai
 from ..providers.openai import (
     OPEN_AI_CHAT_MODELS,
-    OpenAIFunctionCall,
     OpenAIFunctionSpec,
     count_openai_functions_tokens,
 )
-from .token_counter import *
+
+from .token_counter import count_message_tokens, count_string_tokens
+
+logger = logging.getLogger(__name__)


 def call_ai_function(
@@ -96,7 +99,7 @@ def create_text_completion(
 def create_chat_completion(
     prompt: ChatSequence,
     config: Config,
-    functions: Optional[List[OpenAIFunctionSpec]] = None,
+    functions: Optional[list[OpenAIFunctionSpec]] = None,
     model: Optional[str] = None,
     temperature: Optional[float] = None,
     max_tokens: Optional[int] = None,
@@ -182,9 +185,5 @@ def create_chat_completion(
     return ChatModelResponse(
         model_info=OPEN_AI_CHAT_MODELS[model],
         content=content,
-        function_call=OpenAIFunctionCall(
-            name=function_call["name"], arguments=function_call["arguments"]
-        )
-        if function_call
-        else None,
+        function_call=LLMFunctionCall.parse(function_call) if function_call else None,
     )
diff --git a/autogpts/autogpt/autogpt/memory/message_history.py b/autogpts/autogpt/autogpt/memory/message_history.py
index 60c30ec4ad18..43395fb290e7 100644
--- a/autogpts/autogpt/autogpt/memory/message_history.py
+++ b/autogpts/autogpt/autogpt/memory/message_history.py
@@ -3,7 +3,6 @@
 import copy
 import json
 import logging
-from dataclasses import dataclass
 from typing import TYPE_CHECKING, Iterator, Optional

 if TYPE_CHECKING:
@@ -23,7 +22,6 @@
 logger = logging.getLogger(__name__)


-@dataclass
 class MessageHistory(ChatSequence):
     max_summary_tlength: int = 500
     agent: Optional[BaseAgent | Agent] = None
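Note on providers/openai.py above: annotating the decorators as `(func: T) -> T` with `T = TypeVar("T", bound=Callable)` tells type checkers that the decorated function keeps the signature it went in with, and the added `@functools.wraps(func)` preserves its runtime metadata. A generic sketch of the pattern (decorator name and body are illustrative, not from this repo):

    import functools
    from typing import Callable, TypeVar

    T = TypeVar("T", bound=Callable)

    def log_calls(func: T) -> T:
        @functools.wraps(func)  # keeps __name__, __doc__, etc. of the wrapped function
        def wrapped(*args, **kwargs):
            print(f"calling {func.__name__}")
            return func(*args, **kwargs)
        return wrapped  # type: ignore  # same signature as the input, per the annotation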
diff --git a/autogpts/autogpt/autogpt/memory/vector/memory_item.py b/autogpts/autogpt/autogpt/memory/vector/memory_item.py
index 315f17d61673..df5b6aef735b 100644
--- a/autogpts/autogpt/autogpt/memory/vector/memory_item.py
+++ b/autogpts/autogpt/autogpt/memory/vector/memory_item.py
@@ -1,12 +1,12 @@
 from __future__ import annotations

-import dataclasses
 import json
 import logging
 from typing import Literal

 import ftfy
 import numpy as np
+from pydantic import BaseModel

 from autogpt.config import Config
 from autogpt.llm import Message
@@ -20,8 +20,8 @@
 MemoryDocType = Literal["webpage", "text_file", "code_file", "agent_history"]


-@dataclasses.dataclass
-class MemoryItem:
+# FIXME: implement validators instead of allowing arbitrary types
+class MemoryItem(BaseModel, arbitrary_types_allowed=True):
     """Memory object containing raw content as well as embeddings"""

     raw_content: str
@@ -94,12 +94,12 @@ def from_text(
             metadata["source_type"] = source_type

         return MemoryItem(
-            text,
-            summary,
-            chunks,
-            chunk_summaries,
-            e_summary,
-            e_chunks,
+            raw_content=text,
+            summary=summary,
+            chunks=chunks,
+            chunk_summaries=chunk_summaries,
+            e_summary=e_summary,
+            e_chunks=e_chunks,
             metadata=metadata,
         )

@@ -198,8 +198,7 @@ def __eq__(self, other: MemoryItem):
         )


-@dataclasses.dataclass
-class MemoryItemRelevance:
+class MemoryItemRelevance(BaseModel):
     """
     Class that encapsulates memory relevance search functionality and data.
     Instances contain a MemoryItem and its relevance scores for a given query.
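Note on memory_item.py above: MemoryItem carries `np.ndarray` embeddings, for which pydantic has no built-in validator, so `arbitrary_types_allowed=True` makes pydantic fall back to plain `isinstance()` checks for such fields (which is what the FIXME about writing real validators refers to). The class-keyword form used in this patch should be equivalent to the inner-Config spelling; an illustrative sketch:

    import numpy as np
    from pydantic import BaseModel

    class Embedded(BaseModel):
        vector: np.ndarray  # no pydantic validator exists; only isinstance() is checked

        class Config:
            arbitrary_types_allowed = True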
diff --git a/autogpts/autogpt/autogpt/models/agent_actions.py b/autogpts/autogpt/autogpt/models/agent_actions.py
index bf88e42455ae..5fc52db00740 100644
--- a/autogpts/autogpt/autogpt/models/agent_actions.py
+++ b/autogpts/autogpt/autogpt/models/agent_actions.py
@@ -1,13 +1,13 @@
 from __future__ import annotations

-from dataclasses import dataclass
 from typing import Any, Iterator, Literal, Optional

+from pydantic import BaseModel
+
 from autogpt.prompts.utils import format_numbered_list, indent


-@dataclass
-class Action:
+class Action(BaseModel):
     name: str
     args: dict[str, Any]
     reasoning: str
@@ -16,8 +16,7 @@ def format_call(self) -> str:
         return f"{self.name}({', '.join([f'{a}={repr(v)}' for a, v in self.args.items()])})"


-@dataclass
-class ActionSuccessResult:
+class ActionSuccessResult(BaseModel):
     outputs: Any
     status: Literal["success"] = "success"

@@ -27,8 +26,8 @@ def __str__(self) -> str:
         return f"```\n{self.outputs}\n```" if multiline else str(self.outputs)


-@dataclass
-class ActionErrorResult:
+# FIXME: implement validators instead of allowing arbitrary types
+class ActionErrorResult(BaseModel, arbitrary_types_allowed=True):
     reason: str
     error: Optional[Exception] = None
     status: Literal["error"] = "error"

@@ -37,8 +36,7 @@ def __str__(self) -> str:
         return f"Action failed: '{self.reason}'"


-@dataclass
-class ActionInterruptedByHuman:
+class ActionInterruptedByHuman(BaseModel):
     feedback: str
     status: Literal["interrupted_by_human"] = "interrupted_by_human"

@@ -49,61 +47,63 @@ def __str__(self) -> str:
 ActionResult = ActionSuccessResult | ActionErrorResult | ActionInterruptedByHuman


-class ActionHistory:
-    """Utility container for an action history"""
+class Episode(BaseModel):
+    action: Action
+    result: ActionResult | None

-    @dataclass
-    class CycleRecord:
-        action: Action
-        result: ActionResult | None
+    def __str__(self) -> str:
+        executed_action = f"Executed `{self.action.format_call()}`"
+        action_result = f": {self.result}" if self.result else "."
+        return executed_action + action_result

-        def __str__(self) -> str:
-            executed_action = f"Executed `{self.action.format_call()}`"
-            action_result = f": {self.result}" if self.result else "."
-            return executed_action + action_result
+
+class EpisodicActionHistory(BaseModel):
+    """Utility container for an action history"""

     cursor: int
-    cycles: list[CycleRecord]
+    episodes: list[Episode]

-    def __init__(self, cycles: list[CycleRecord] = []):
-        self.cycles = cycles
-        self.cursor = len(self.cycles)
+    def __init__(self, episodes: list[Episode] = []):
+        super().__init__(
+            episodes=episodes,
+            cursor=len(episodes),
+        )

     @property
-    def current_record(self) -> CycleRecord | None:
+    def current_episode(self) -> Episode | None:
         if self.cursor == len(self):
             return None
         return self[self.cursor]

-    def __getitem__(self, key: int) -> CycleRecord:
-        return self.cycles[key]
+    def __getitem__(self, key: int) -> Episode:
+        return self.episodes[key]

-    def __iter__(self) -> Iterator[CycleRecord]:
-        return iter(self.cycles)
+    def __iter__(self) -> Iterator[Episode]:
+        return iter(self.episodes)

     def __len__(self) -> int:
-        return len(self.cycles)
+        return len(self.episodes)

     def __bool__(self) -> bool:
-        return len(self.cycles) > 0
+        return len(self.episodes) > 0

     def register_action(self, action: Action) -> None:
-        if not self.current_record:
-            self.cycles.append(self.CycleRecord(action, None))
-            assert self.current_record
-        elif self.current_record.action:
+        if not self.current_episode:
+            self.episodes.append(Episode(action=action, result=None))
+            assert self.current_episode
+        elif self.current_episode.action:
             raise ValueError("Action for current cycle already set")

     def register_result(self, result: ActionResult) -> None:
-        if not self.current_record:
+        if not self.current_episode:
             raise RuntimeError("Cannot register result for cycle without action")
-        elif self.current_record.result:
+        elif self.current_episode.result:
             raise ValueError("Result for current cycle already set")
-        self.current_record.result = result
-        self.cursor = len(self.cycles)
+        self.current_episode.result = result
+        self.cursor = len(self.episodes)

-    def rewind(self, number_of_cycles: int = 0) -> None:
+    def rewind(self, number_of_episodes: int = 0) -> None:
         """Resets the history to an earlier state.

         Params:
@@ -111,22 +111,22 @@ def rewind(self, number_of_cycles: int = 0) -> None:
             When set to 0, it will only reset the current cycle.
         """
         # Remove partial record of current cycle
-        if self.current_record:
-            if self.current_record.action and not self.current_record.result:
-                self.cycles.pop(self.cursor)
+        if self.current_episode:
+            if self.current_episode.action and not self.current_episode.result:
+                self.episodes.pop(self.cursor)

         # Rewind the specified number of cycles
-        if number_of_cycles > 0:
-            self.cycles = self.cycles[:-number_of_cycles]
-        self.cursor = len(self.cycles)
+        if number_of_episodes > 0:
+            self.episodes = self.episodes[:-number_of_episodes]
+        self.cursor = len(self.episodes)

     def fmt_list(self) -> str:
-        return format_numbered_list(self.cycles)
+        return format_numbered_list(self.episodes)

     def fmt_paragraph(self) -> str:
         steps: list[str] = []

-        for i, c in enumerate(self.cycles, 1):
+        for i, c in enumerate(self.episodes, 1):
             step = f"### Step {i}: Executed `{c.action.format_call()}`\n"
             step += f'- **Reasoning:** "{c.action.reasoning}"\n'
             step += (
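Note on agent_actions.py above: the rename from ActionHistory/CycleRecord to EpisodicActionHistory/Episode also lifts the nested record class to module level, since pydantic models compose more cleanly that way. The expected call pattern is unchanged; a short usage sketch with illustrative argument values:

    history = EpisodicActionHistory()
    history.register_action(
        Action(name="read_file", args={"path": "notes.txt"}, reasoning="...")
    )
    history.register_result(ActionSuccessResult(outputs="file contents"))
    assert len(history) == 1 and history.cursor == 1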
""" # Remove partial record of current cycle - if self.current_record: - if self.current_record.action and not self.current_record.result: - self.cycles.pop(self.cursor) + if self.current_episode: + if self.current_episode.action and not self.current_episode.result: + self.episodes.pop(self.cursor) # Rewind the specified number of cycles - if number_of_cycles > 0: - self.cycles = self.cycles[:-number_of_cycles] - self.cursor = len(self.cycles) + if number_of_episodes > 0: + self.episodes = self.episodes[:-number_of_episodes] + self.cursor = len(self.episodes) def fmt_list(self) -> str: - return format_numbered_list(self.cycles) + return format_numbered_list(self.episodes) def fmt_paragraph(self) -> str: steps: list[str] = [] - for i, c in enumerate(self.cycles, 1): + for i, c in enumerate(self.episodes, 1): step = f"### Step {i}: Executed `{c.action.format_call()}`\n" step += f'- **Reasoning:** "{c.action.reasoning}"\n' step += ( diff --git a/autogpts/autogpt/autogpt/models/context_item.py b/autogpts/autogpt/autogpt/models/context_item.py index cbf49084b307..7c8e306a3b76 100644 --- a/autogpts/autogpt/autogpt/models/context_item.py +++ b/autogpts/autogpt/autogpt/models/context_item.py @@ -1,9 +1,10 @@ import logging from abc import ABC, abstractmethod -from dataclasses import dataclass from pathlib import Path from typing import Optional +from pydantic import BaseModel, Field + from autogpt.commands.file_operations_utils import read_textual_file logger = logging.getLogger(__name__) @@ -37,8 +38,7 @@ def __str__(self) -> str: ) -@dataclass -class FileContextItem(ContextItem): +class FileContextItem(BaseModel, ContextItem): file_path_in_workspace: Path workspace_path: Path @@ -59,8 +59,7 @@ def content(self) -> str: return read_textual_file(self.file_path, logger) -@dataclass -class FolderContextItem(ContextItem): +class FolderContextItem(BaseModel, ContextItem): path_in_workspace: Path workspace_path: Path @@ -87,8 +86,7 @@ def content(self) -> str: return "\n".join(items) -@dataclass -class StaticContextItem(ContextItem): - description: str - source: Optional[str] - content: str +class StaticContextItem(BaseModel, ContextItem): + item_description: str = Field(alias="description") + item_source: Optional[str] = Field(alias="source") + item_content: str = Field(alias="content") diff --git a/autogpts/autogpt/tests/unit/test_message_history.py b/autogpts/autogpt/tests/unit/test_message_history.py index 08a3a24bd334..a41d1e3a92e5 100644 --- a/autogpts/autogpt/tests/unit/test_message_history.py +++ b/autogpts/autogpt/tests/unit/test_message_history.py @@ -31,7 +31,7 @@ def agent(config: Config): def test_message_history_batch_summary(mocker, agent: Agent, config: Config): - history = MessageHistory(agent.llm, agent=agent) + history = MessageHistory(model=agent.llm, agent=agent) model = config.fast_llm message_tlength = 0 message_count = 0