Skip to content

Commit

Permalink
fix:工具执行结束时,保存输入信息
Browse files Browse the repository at this point in the history
  • Loading branch information
zgqgit committed Apr 17, 2024
1 parent cc2a414 commit 636ac3d
Showing 1 changed file with 12 additions and 10 deletions.
22 changes: 12 additions & 10 deletions src/backend/bisheng/api/v1/callback.py
Original file line number Diff line number Diff line change
Expand Up @@ -392,12 +392,13 @@ async def on_tool_start(self, serialized: Dict[str, Any], input_str: str,
input_info = {'tool_key': tool_name, 'serialized': serialized, 'input_str': input_str}
self.tool_cache[kwargs.get('run_id').hex] = {
'input': input_info,
'category': tool_category
'category': tool_category,
'steps': f'Tool input: \n\n{input_str}\n\n',
}
resp = ChatResponse(type='start',
category=tool_category,
intermediate_steps=f'Tool input: {input_str}',
message=json.dumps(input_info),
intermediate_steps=self.tool_cache[kwargs.get('run_id').hex]['steps'],
message=json.dumps(input_info, ensure_ascii=False),
flow_id=self.flow_id,
chat_id=self.chat_id,
extra=json.dumps({'run_id': kwargs.get('run_id').hex}))
Expand All @@ -409,8 +410,12 @@ async def on_tool_end(self, output: str, **kwargs: Any) -> Any:
observation_prefix = kwargs.get('observation_prefix', 'Tool output: ')

result = output
# 从tool cache中获取input信息
input_info = self.tool_cache.get(kwargs.get('run_id').hex)
# Create a formatted message.
intermediate_steps = f'{observation_prefix}{result}'
intermediate_steps = f'{observation_prefix}\n\n{result}'
if input_info:
intermediate_steps = f'{input_info["steps"]}\n\n{intermediate_steps}'

tool_name, tool_category = self.parse_tool_category(kwargs.get('name'))

Expand All @@ -419,15 +424,12 @@ async def on_tool_end(self, output: str, **kwargs: Any) -> Any:
resp = ChatResponse(type='end',
category=tool_category,
intermediate_steps=intermediate_steps,
message=json.dumps(output_info),
message=json.dumps(output_info, ensure_ascii=False),
flow_id=self.flow_id,
chat_id=self.chat_id,
extra=json.dumps({'run_id': kwargs.get('run_id').hex}))

await self.websocket.send_json(resp.dict())

# 从tool cache中获取input信息
input_info = self.tool_cache.get(kwargs.get('run_id').hex)
if input_info:
output_info.update(input_info['input'])
ChatMessageDao.insert_one(
Expand All @@ -452,8 +454,8 @@ async def on_tool_error(self, error: Union[Exception, KeyboardInterrupt],
output_info.update(input_info['input'])
resp = ChatResponse(type='end',
category=input_info['category'],
intermediate_steps='Tool output: Error: ' + str(error),
message=json.dumps(output_info),
intermediate_steps=f'{input_info["steps"]}\n\nTool output:\n\n Error: ' + str(error),
message=json.dumps(output_info, ensure_ascii=False),
flow_id=self.flow_id,
chat_id=self.chat_id,
extra=json.dumps({'run_id': kwargs.get('run_id').hex}))
Expand Down

0 comments on commit 636ac3d

Please sign in to comment.