Skip to content

Commit

Permalink
Feat/0.3.0 (dataelement#495)
Browse files Browse the repository at this point in the history
  • Loading branch information
yaojin3616 authored Apr 17, 2024
2 parents 38274ab + 695e321 commit bc107b4
Show file tree
Hide file tree
Showing 91 changed files with 1,847 additions and 9,820 deletions.
18 changes: 11 additions & 7 deletions src/backend/bisheng/api/services/assistant.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,19 +139,23 @@ async def auto_update_stream(cls, assistant_id: UUID, prompt: str):
yield str(StreamData(event='message', data={'type': 'prompt', 'message': one_prompt.content}))
final_prompt += one_prompt.content
assistant.prompt = final_prompt
yield str(StreamData(event='message', data={'type': 'end', 'message': ""}))

# 生成开场白和开场问题
guide_info = auto_agent.generate_guide(assistant.prompt)
yield str(StreamData(event='message', data={'type': 'guide_word', 'message': guide_info['opening_lines']}))
yield str(StreamData(event='message', data={'type': 'end', 'message': ""}))
yield str(StreamData(event='message', data={'type': 'guide_question', 'message': guide_info['questions']}))
yield str(StreamData(event='message', data={'type': 'end', 'message': ""}))

# 自动选择工具和技能
tool_info = cls.get_auto_tool_info(assistant, auto_agent)
tool_info = [one.model_dump() for one in tool_info]
yield str(StreamData(event='message', data={'type': 'tool_list', 'message': tool_info}))
yield str(StreamData(event='message', data={'type': 'end', 'message': ""}))

flow_info = cls.get_auto_flow_info(assistant, auto_agent)
flow_info = [one. model_dump() for one in flow_info]
flow_info = [one.model_dump() for one in flow_info]
yield str(StreamData(event='message', data={'type': 'flow_list', 'message': flow_info}))

@classmethod
Expand All @@ -173,12 +177,12 @@ async def update_assistant(cls, req: AssistantUpdateReq, user_payload: UserPaylo
return AssistantNameRepeatError.return_resp()
assistant.name = req.name
assistant.desc = req.desc
assistant.logo = req.logo or assistant.logo
assistant.prompt = req.prompt or assistant.prompt
assistant.guide_word = req.guide_word or assistant.guide_word
assistant.guide_question = req.guide_question or assistant.guide_question
assistant.model_name = req.model_name or assistant.model_name
assistant.temperature = req.temperature or assistant.temperature
assistant.logo = req.logo
assistant.prompt = req.prompt
assistant.guide_word = req.guide_word
assistant.guide_question = req.guide_question
assistant.model_name = req.model_name
assistant.temperature = req.temperature
assistant.update_time = datetime.now()
AssistantDao.update_assistant(assistant)

Expand Down
23 changes: 12 additions & 11 deletions src/backend/bisheng/api/services/assistant_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,16 @@ async def get_knowledge_skill_data(self):
self.knowledge_skill_data = data
return data

def parse_tool_params(self, tool: GptsTools) -> Dict:
    """Build the parameter dict used to initialise one tool.

    The tool's ``extra`` field is read as a JSON document. When that
    document carries a truthy ``&initdb_conf_key`` entry, the parameters
    are fetched through ``get_initdb_conf_by_more_key`` using that entry
    as the lookup key; otherwise the decoded document itself is returned.

    :param tool: tool record whose ``extra`` attribute holds the JSON text
    :return: ``{}`` when ``extra`` is empty, else the resolved parameters
    """
    raw_extra = tool.extra
    if raw_extra:
        tool_conf = json.loads(raw_extra)
    else:
        return {}

    # Decide whether the parameters must come from the system configuration;
    # when no redirect key is present, the tool's own settings are used.
    conf_key = tool_conf.get('&initdb_conf_key')
    if conf_key:
        return self.get_initdb_conf_by_more_key(conf_key)
    return tool_conf

async def init_tools(self, callbacks: Callbacks = None):
"""通过名称获取tool 列表
tools_name_param:: {name: params}
Expand All @@ -102,7 +112,7 @@ async def init_tools(self, callbacks: Callbacks = None):
if tool_ids:
tools_model: List[GptsTools] = GptsToolsDao.get_list_by_ids(tool_ids)
tool_name_param = {
tool.tool_key: json.loads(tool.extra) if tool.extra else {}
tool.tool_key: self.parse_tool_params(tool)
for tool in tools_model
}
tool_langchain = load_tools(tool_params=tool_name_param,
Expand Down Expand Up @@ -241,16 +251,7 @@ async def run(self, query: str, chat_history: List = None, callback: Callbacks =
'name': one,
}, input_str='', run_id=run_id)
await callback[0].on_tool_end(output='', name=one, run_id=run_id)

result = {}
async for one in self.agent.astream_events(inputs,
config=RunnableConfig(callbacks=callback),
version='v1'):
if one['event'] == 'on_chain_end':
result = one

# 最后一次输出的event即最终答案
result = result['data']['output']['__end__']
result = await self.agent.ainvoke(inputs, config=RunnableConfig(callbacks=callback))
# 包含了history,将history排除
res = []
for one in result:
Expand Down
15 changes: 15 additions & 0 deletions src/backend/bisheng/api/services/assistant_base.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from typing import Dict

from bisheng.settings import settings


Expand Down Expand Up @@ -29,3 +31,16 @@ def get_agent_executor(cls):
@classmethod
def get_default_retrieval(cls) -> str:
return cls.get_gpts_conf('default-retrieval')

@classmethod
def get_initdb_conf_by_more_key(cls, key: str) -> Dict:
    """
    Resolve a dotted, multi-level key against the stored configuration.

    :param key: dot-separated path, e.g. ``gpts.tools.code_interpreter``
                reads gpts['tools']['code_interpreter']
    :return: whatever value sits at the end of the path
    """
    first_part, *remaining_parts = key.split('.')
    # This is system-level configuration, so missing entries are deliberately
    # not checked for: a path that does not exist is left to raise.
    node = settings.get_from_db(first_part.strip())
    for part in remaining_parts:
        node = node[part.strip()]
    return node
2 changes: 1 addition & 1 deletion src/backend/bisheng/api/services/knowledge_imp.py
Original file line number Diff line number Diff line change
Expand Up @@ -360,7 +360,7 @@ def read_chunk_text(input_file, file_name, size, chunk_overlap, separator):
} for t in texts]
else:
# 如果文件不是pdf 需要内部转pdf
if file_name.rsplit('.', 1)[-1] != 'pdf':
if file_name.rsplit('.', 1)[-1].lower() != 'pdf':
b64_data = base64.b64encode(open(input_file, 'rb').read()).decode()
inp = dict(filename=file_name, b64_data=[b64_data], mode='topdf')
resp = requests.post(settings.get_knowledge().get('unstructured_api_url'), json=inp)
Expand Down
81 changes: 38 additions & 43 deletions src/backend/bisheng/api/v1/callback.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,23 +166,15 @@ async def on_text(self, text: str, **kwargs: Any) -> Any:
async def on_agent_action(self, action: AgentAction, **kwargs: Any):
logger.debug(f'on_agent_action action={action} kwargs={kwargs}')

log = f'Thought: {action.log}'
log = f'\nThought: {action.log}'
# if there are line breaks, split them and send them
# as separate messages
if '\n' in log:
logs = log.split('\n')
for log in logs:
resp = ChatResponse(type='stream',
intermediate_steps=log,
flow_id=self.flow_id,
chat_id=self.chat_id)
await self.websocket.send_json(resp.dict())
else:
resp = ChatResponse(type='stream',
intermediate_steps=log,
flow_id=self.flow_id,
chat_id=self.chat_id)
await self.websocket.send_json(resp.dict())
log = log.replace('\n', '\n\n')
resp = ChatResponse(type='stream',
intermediate_steps=log,
flow_id=self.flow_id,
chat_id=self.chat_id)
await self.websocket.send_json(resp.dict())

async def on_agent_finish(self, finish: AgentFinish, **kwargs: Any) -> Any:
"""Run on agent end."""
Expand Down Expand Up @@ -232,27 +224,17 @@ def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
asyncio.run_coroutine_threadsafe(coroutine, loop)

def on_agent_action(self, action: AgentAction, **kwargs: Any) -> Any:
log = f'Thought: {action.log}'
log = f'\nThought: {action.log}'
# if there are line breaks, split them and send them
# as separate messages
if '\n' in log:
logs = log.split('\n')
for log in logs:
resp = ChatResponse(type='stream',
intermediate_steps=log,
flow_id=self.flow_id,
chat_id=self.chat_id)
loop = asyncio.get_event_loop()
coroutine = self.websocket.send_json(resp.dict())
asyncio.run_coroutine_threadsafe(coroutine, loop)
else:
resp = ChatResponse(type='stream',
intermediate_steps=log,
flow_id=self.flow_id,
chat_id=self.chat_id)
loop = asyncio.get_event_loop()
coroutine = self.websocket.send_json(resp.dict())
asyncio.run_coroutine_threadsafe(coroutine, loop)
log = log.replace("\n", "\n\n")
resp = ChatResponse(type='stream',
intermediate_steps=log,
flow_id=self.flow_id,
chat_id=self.chat_id)
loop = asyncio.get_event_loop()
coroutine = self.websocket.send_json(resp.dict())
asyncio.run_coroutine_threadsafe(coroutine, loop)

def on_agent_finish(self, finish: AgentFinish, **kwargs: Any) -> Any:
"""Run on agent end."""
Expand Down Expand Up @@ -406,16 +388,18 @@ async def on_tool_start(self, serialized: Dict[str, Any], input_str: str,
logger.debug(
f'on_tool_start serialized={serialized} input_str={input_str} kwargs={kwargs}')

input_str = input_str
tool_name, tool_category = self.parse_tool_category(serialized['name'])
input_info = {'tool_key': tool_name, 'serialized': serialized, 'input_str': input_str}
self.tool_cache[kwargs.get('run_id').hex] = {
'input': input_info,
'category': tool_category
'category': tool_category,
'steps': f'Tool input: \n\n{input_str}\n\n',
}
resp = ChatResponse(type='start',
category=tool_category,
intermediate_steps=f'Tool input: {input_str}',
message=json.dumps(input_info),
intermediate_steps=self.tool_cache[kwargs.get('run_id').hex]['steps'],
message=json.dumps(input_info, ensure_ascii=False),
flow_id=self.flow_id,
chat_id=self.chat_id,
extra=json.dumps({'run_id': kwargs.get('run_id').hex}))
Expand All @@ -428,26 +412,25 @@ async def on_tool_end(self, output: str, **kwargs: Any) -> Any:

result = output
# Create a formatted message.
intermediate_steps = f'{observation_prefix}{result}'

intermediate_steps = f'{observation_prefix}\n\n{result}'
tool_name, tool_category = self.parse_tool_category(kwargs.get('name'))

# Create a ChatResponse instance.
output_info = {'tool_key': tool_name, 'output': output}
resp = ChatResponse(type='end',
category=tool_category,
intermediate_steps=intermediate_steps,
message=json.dumps(output_info),
message=json.dumps(output_info, ensure_ascii=False),
flow_id=self.flow_id,
chat_id=self.chat_id,
extra=json.dumps({'run_id': kwargs.get('run_id').hex}))

await self.websocket.send_json(resp.dict())

# 从tool cache中获取input信息
input_info = self.tool_cache.get(kwargs.get('run_id').hex)
if input_info:
output_info.update(input_info['input'])
intermediate_steps = f'{input_info["steps"]}\n\n{intermediate_steps}'
ChatMessageDao.insert_one(
ChatMessageModel(is_bot=1,
message=json.dumps(output_info),
Expand All @@ -470,10 +453,22 @@ async def on_tool_error(self, error: Union[Exception, KeyboardInterrupt],
output_info.update(input_info['input'])
resp = ChatResponse(type='end',
category=input_info['category'],
intermediate_steps='Tool output: Error: ' + str(error),
message=json.dumps(output_info),
intermediate_steps='\n\nTool output:\n\n Error: ' + str(error),
message=json.dumps(output_info, ensure_ascii=False),
flow_id=self.flow_id,
chat_id=self.chat_id,
extra=json.dumps({'run_id': kwargs.get('run_id').hex}))
await self.websocket.send_json(resp.dict())

# 保存工具调用记录
self.tool_cache.pop(kwargs.get('run_id').hex)
ChatMessageDao.insert_one(
ChatMessageModel(is_bot=1,
message=json.dumps(output_info),
intermediate_steps=f'{input_info["steps"]}\n\nTool output:\n\n Error: ' + str(error),
category=tool_category,
type='end',
flow_id=self.flow_id,
chat_id=self.chat_id,
user_id=self.user_id,
extra=json.dumps({'run_id': kwargs.get('run_id').hex})))
3 changes: 2 additions & 1 deletion src/backend/bisheng/api/v1/chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,7 +138,8 @@ def get_chatlist_list(*, Authorize: AuthJWT = Depends()):
create_time=message.create_time,
update_time=message.update_time))
else:
logger.warning(f'没有找到flow_id={message.flow_id}')
# 通过接口创建的会话记录,不关联技能或者助手
logger.debug(f'unknown message.flow_id={message.flow_id}')
return resp_200(chat_list)


Expand Down
38 changes: 16 additions & 22 deletions src/backend/bisheng/api/v1/endpoints.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,11 +18,9 @@
from bisheng.processing.process import process_graph_cached, process_tweaks
from bisheng.services.deps import get_session_service, get_task_service
from bisheng.services.task.service import TaskService
from bisheng.settings import parse_key
from bisheng.utils.logger import logger
from fastapi import APIRouter, Body, Depends, HTTPException, UploadFile
from fastapi_jwt_auth import AuthJWT
from sqlalchemy import delete
from sqlmodel import select

try:
Expand Down Expand Up @@ -76,33 +74,29 @@ def get_config(Authorize: AuthJWT = Depends()):
if payload.get('role') != 'admin':
raise HTTPException(status_code=500, detail='Unauthorized')
with session_getter() as session:
configs = session.exec(select(Config)).all()
config_str = []
for config in configs:
config_str.append(config.key + ':')
config_str.append(config.value)
return resp_200('\n'.join(config_str))
config = session.exec(select(Config).where(
Config.key == 'initdb_config'
)).first()
if config:
config_str = config.value
else:
config_str = ''
return resp_200(config_str)


@router.post('/config/save')
def save_config(data: dict):
try:
config_yaml = yaml.safe_load(data.get('data'))
with session_getter() as session:
old_config = session.exec(select(Config).where(Config.id > 0)).all()
session.exec(delete(Config).where(Config.id > 0))
session.commit()
keys = list(config_yaml.keys())
values = parse_key(keys, data.get('data'))
# 校验是否符合yaml格式
_ = yaml.safe_load(data.get('data'))
with session_getter() as session:
for index, key in enumerate(keys):
config = Config(key=key, value=values[index])
session.add(config)
config = session.exec(select(Config).where(
Config.key == 'initdb_config'
)).first()
config.value = data.get('data')
session.add(config)
session.commit()
# 淘汰缓存
for old in old_config:
redis_key = 'config_' + old.key
redis_client.delete(redis_key)
redis_client.delete('config:initdb_config')
except Exception as e:
raise HTTPException(status_code=500, detail=f'格式不正确, {str(e)}')

Expand Down
Loading

0 comments on commit bc107b4

Please sign in to comment.