Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 18 additions & 8 deletions backend/apps/chat/api/chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,11 @@
from sqlalchemy import and_, select
from starlette.responses import JSONResponse

from apps.chat.curd.chat import delete_chat_with_user, get_chart_data_with_user, get_chat_predict_data_with_user, list_chats, get_chat_with_records, create_chat, rename_chat, \
from apps.chat.curd.chat import delete_chat_with_user, get_chart_data_with_user, get_chat_predict_data_with_user, \
list_chats, get_chat_with_records, create_chat, rename_chat, \
delete_chat, get_chat_chart_data, get_chat_predict_data, get_chat_with_records_with_data, get_chat_record_by_id, \
format_json_data, format_json_list_data, get_chart_config, list_recent_questions,get_chat as get_chat_exec, rename_chat_with_user
format_json_data, format_json_list_data, get_chart_config, list_recent_questions, get_chat as get_chat_exec, \
rename_chat_with_user
from apps.chat.models.chat_model import CreateChat, ChatRecord, RenameChat, ChatQuestion, AxisObj, QuickCommand, \
ChatInfo, Chat, ChatFinishStep
from apps.chat.task.llm import LLMService
Expand Down Expand Up @@ -69,6 +71,7 @@ def inner():

return await asyncio.to_thread(inner) """


@router.get("/record/{chat_record_id}/data", summary=f"{PLACEHOLDER_PREFIX}get_chart_data")
async def chat_record_data(session: SessionDep, current_user: CurrentUser, chat_record_id: int):
def inner():
Expand All @@ -81,7 +84,8 @@ def inner():
@router.get("/record/{chat_record_id}/predict_data", summary=f"{PLACEHOLDER_PREFIX}get_chart_predict_data")
async def chat_predict_data(session: SessionDep, current_user: CurrentUser, chat_record_id: int):
def inner():
data = get_chat_predict_data_with_user(chat_record_id=chat_record_id, session=session, current_user=current_user)
data = get_chat_predict_data_with_user(chat_record_id=chat_record_id, session=session,
current_user=current_user)
return format_json_list_data(data)

return await asyncio.to_thread(inner)
Expand All @@ -102,6 +106,7 @@ async def rename(session: SessionDep, chat: RenameChat):
detail=str(e)
) """


@router.post("/rename", response_model=str, summary=f"{PLACEHOLDER_PREFIX}rename_chat")
@system_log(LogConfig(
operation_type=OperationType.UPDATE,
Expand All @@ -117,6 +122,7 @@ async def rename(session: SessionDep, current_user: CurrentUser, chat: RenameCha
detail=str(e)
)


""" @router.delete("/{chart_id}/{brief}", response_model=str, summary=f"{PLACEHOLDER_PREFIX}delete_chat")
@system_log(LogConfig(
operation_type=OperationType.DELETE,
Expand All @@ -133,6 +139,7 @@ async def delete(session: SessionDep, chart_id: int, brief: str):
detail=str(e)
) """


@router.delete("/{chart_id}/{brief}", response_model=str, summary=f"{PLACEHOLDER_PREFIX}delete_chat")
@system_log(LogConfig(
operation_type=OperationType.DELETE,
Expand All @@ -149,6 +156,7 @@ async def delete(session: SessionDep, current_user: CurrentUser, chart_id: int,
detail=str(e)
)


@router.post("/start", response_model=ChatInfo, summary=f"{PLACEHOLDER_PREFIX}start_chat")
@require_permissions(permission=SqlbotPermission(type='ds', keyExpression="create_chat_obj.datasource"))
@system_log(LogConfig(
Expand All @@ -172,9 +180,11 @@ async def start_chat(session: SessionDep, current_user: CurrentUser, create_chat
module=OperationModules.CHAT,
result_id_expr="id"
))
async def start_chat(session: SessionDep, current_user: CurrentUser, current_assistant: CurrentAssistant, create_chat_obj: CreateChat = CreateChat(origin=2)):
async def start_chat(session: SessionDep, current_user: CurrentUser, current_assistant: CurrentAssistant,
create_chat_obj: CreateChat = CreateChat(origin=2)):
try:
return create_chat(session, current_user, create_chat_obj, create_chat_obj and create_chat_obj.datasource, current_assistant)
return create_chat(session, current_user, create_chat_obj, create_chat_obj and create_chat_obj.datasource,
current_assistant)
except Exception as e:
raise HTTPException(
status_code=500,
Expand Down Expand Up @@ -213,7 +223,7 @@ def _err(_e: Exception):

@router.get("/recent_questions/{datasource_id}", response_model=List[str],
summary=f"{PLACEHOLDER_PREFIX}get_recommend_questions")
#@require_permissions(permission=SqlbotPermission(type='ds', keyExpression="datasource_id"))
# @require_permissions(permission=SqlbotPermission(type='ds', keyExpression="datasource_id"))
async def recommend_questions(session: SessionDep, current_user: CurrentUser,
datasource_id: int = Path(..., description=f"{PLACEHOLDER_PREFIX}ds_id")):
return list_recent_questions(session=session, current_user=current_user, datasource_id=datasource_id)
Expand Down Expand Up @@ -442,8 +452,8 @@ def _err(_e: Exception):


@router.get("/record/{chat_record_id}/excel/export/{chat_id}", summary=f"{PLACEHOLDER_PREFIX}export_chart_data")
@system_log(LogConfig(operation_type=OperationType.EXPORT,module=OperationModules.CHAT,resource_id_expr="chat_id",))
async def export_excel(session: SessionDep, current_user: CurrentUser, chat_record_id: int,chat_id: int, trans: Trans):
@system_log(LogConfig(operation_type=OperationType.EXPORT, module=OperationModules.CHAT, resource_id_expr="chat_id", ))
async def export_excel(session: SessionDep, current_user: CurrentUser, chat_record_id: int, chat_id: int, trans: Trans):
chat_record = session.get(ChatRecord, chat_record_id)
if not chat_record:
raise HTTPException(
Expand Down
2 changes: 2 additions & 0 deletions backend/apps/chat/curd/chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -493,6 +493,8 @@ def create_chat(session: SessionDep, current_user: CurrentUser, create_chat_obj:
ds.type_name = DB.get_db(ds.type)
else:
ds = session.get(CoreDatasource, create_chat_obj.datasource)
if ds.oid != current_user.oid:
raise Exception(f"Datasource with id {create_chat_obj.datasource} does not belong to current workspace")

if not ds:
raise Exception(f"Datasource with id {create_chat_obj.datasource} not found")
Expand Down
2 changes: 2 additions & 0 deletions backend/apps/chat/models/chat_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -282,6 +282,7 @@ def dynamic_user_question(self):

class ChatQuestion(AiModelQuestion):
chat_id: int
datasource_id: Optional[int] = None


class ChatMcp(ChatQuestion):
Expand All @@ -299,6 +300,7 @@ class McpQuestion(BaseModel):
token: str = Body(description='token')
stream: Optional[bool] = Body(description='是否流式输出,默认为true开启, 关闭false则返回JSON对象', default=True)
lang: Optional[str] = Body(description='语言:zh-CN|en|ko-KR', default='zh-CN')
datasource_id: Optional[int] = Body(description='数据源ID,仅当当前对话没有确定数据源时有效', default=None)


class AxisObj(BaseModel):
Expand Down
13 changes: 13 additions & 0 deletions backend/apps/chat/task/llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,19 @@ def __init__(self, session: Session, current_user: CurrentUser, chat_question: C
if not chat:
raise SingleMessageError(f"Chat with id {chat_id} not found")
ds: CoreDatasource | AssistantOutDsSchema | None = None
if not chat.datasource and chat_question.datasource_id:
_ds = session.get(CoreDatasource, chat_question.datasource_id)
if _ds:
if _ds.oid != current_user.oid:
raise SingleMessageError(f"Datasource with id {chat_question.datasource_id} does not belong to current workspace")
chat.datasource = _ds.id
chat.engine_type = _ds.type_name
# save chat
session.add(chat)
session.flush()
session.refresh(chat)
session.commit()

if chat.datasource:
# Get available datasource
if current_assistant and current_assistant.type in dynamic_ds_types:
Expand Down
2 changes: 1 addition & 1 deletion backend/apps/mcp/mcp.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,7 @@ async def mcp_start(session: SessionDep, chat: ChatStart):
async def mcp_question(session: SessionDep, chat: McpQuestion):
session_user = get_user(session, chat.token)

mcp_chat = ChatMcp(token=chat.token, chat_id=chat.chat_id, question=chat.question)
mcp_chat = ChatMcp(token=chat.token, chat_id=chat.chat_id, question=chat.question, datasource_id=chat.datasource_id)

return await question_answer_inner(session=session, current_user=session_user, request_question=mcp_chat,
in_chat=False, stream=chat.stream)
Expand Down