From 636ac3dd65e0172e11fa3409e35957066c8893c6 Mon Sep 17 00:00:00 2001 From: GuoQing Zhang Date: Wed, 17 Apr 2024 15:20:09 +0800 Subject: [PATCH] =?UTF-8?q?fix=EF=BC=9A=E5=B7=A5=E5=85=B7=E6=89=A7?= =?UTF-8?q?=E8=A1=8C=E7=BB=93=E6=9D=9F=E6=97=B6=EF=BC=8C=E4=BF=9D=E5=AD=98?= =?UTF-8?q?=E8=BE=93=E5=85=A5=E4=BF=A1=E6=81=AF?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/backend/bisheng/api/v1/callback.py | 22 ++++++++++++---------- 1 file changed, 12 insertions(+), 10 deletions(-) diff --git a/src/backend/bisheng/api/v1/callback.py b/src/backend/bisheng/api/v1/callback.py index fbef8c79d..a907697dc 100644 --- a/src/backend/bisheng/api/v1/callback.py +++ b/src/backend/bisheng/api/v1/callback.py @@ -392,12 +392,13 @@ async def on_tool_start(self, serialized: Dict[str, Any], input_str: str, input_info = {'tool_key': tool_name, 'serialized': serialized, 'input_str': input_str} self.tool_cache[kwargs.get('run_id').hex] = { 'input': input_info, - 'category': tool_category + 'category': tool_category, + 'steps': f'Tool input: \n\n{input_str}\n\n', } resp = ChatResponse(type='start', category=tool_category, - intermediate_steps=f'Tool input: {input_str}', - message=json.dumps(input_info), + intermediate_steps=self.tool_cache[kwargs.get('run_id').hex]['steps'], + message=json.dumps(input_info, ensure_ascii=False), flow_id=self.flow_id, chat_id=self.chat_id, extra=json.dumps({'run_id': kwargs.get('run_id').hex})) @@ -409,8 +410,12 @@ async def on_tool_end(self, output: str, **kwargs: Any) -> Any: observation_prefix = kwargs.get('observation_prefix', 'Tool output: ') result = output + # 从tool cache中获取input信息 + input_info = self.tool_cache.get(kwargs.get('run_id').hex) # Create a formatted message. - intermediate_steps = f'{observation_prefix}{result}' + intermediate_steps = f'{observation_prefix}\n\n{result}' + if input_info: + intermediate_steps = f'{input_info["steps"]}\n\n{intermediate_steps}' tool_name, tool_category = self.parse_tool_category(kwargs.get('name')) @@ -419,15 +424,12 @@ async def on_tool_end(self, output: str, **kwargs: Any) -> Any: resp = ChatResponse(type='end', category=tool_category, intermediate_steps=intermediate_steps, - message=json.dumps(output_info), + message=json.dumps(output_info, ensure_ascii=False), flow_id=self.flow_id, chat_id=self.chat_id, extra=json.dumps({'run_id': kwargs.get('run_id').hex})) await self.websocket.send_json(resp.dict()) - - # 从tool cache中获取input信息 - input_info = self.tool_cache.get(kwargs.get('run_id').hex) if input_info: output_info.update(input_info['input']) ChatMessageDao.insert_one( @@ -452,8 +454,8 @@ async def on_tool_error(self, error: Union[Exception, KeyboardInterrupt], output_info.update(input_info['input']) resp = ChatResponse(type='end', category=input_info['category'], - intermediate_steps='Tool output: Error: ' + str(error), - message=json.dumps(output_info), + intermediate_steps=f'{input_info["steps"]}\n\nTool output:\n\n Error: ' + str(error), + message=json.dumps(output_info, ensure_ascii=False), flow_id=self.flow_id, chat_id=self.chat_id, extra=json.dumps({'run_id': kwargs.get('run_id').hex}))