diff --git a/backend/apps/chat/api/chat.py b/backend/apps/chat/api/chat.py index a2bfdb19..773bbcca 100644 --- a/backend/apps/chat/api/chat.py +++ b/backend/apps/chat/api/chat.py @@ -10,9 +10,11 @@ from sqlalchemy import and_, select from starlette.responses import JSONResponse -from apps.chat.curd.chat import delete_chat_with_user, get_chart_data_with_user, get_chat_predict_data_with_user, list_chats, get_chat_with_records, create_chat, rename_chat, \ +from apps.chat.curd.chat import delete_chat_with_user, get_chart_data_with_user, get_chat_predict_data_with_user, \ + list_chats, get_chat_with_records, create_chat, rename_chat, \ delete_chat, get_chat_chart_data, get_chat_predict_data, get_chat_with_records_with_data, get_chat_record_by_id, \ - format_json_data, format_json_list_data, get_chart_config, list_recent_questions,get_chat as get_chat_exec, rename_chat_with_user + format_json_data, format_json_list_data, get_chart_config, list_recent_questions, get_chat as get_chat_exec, \ + rename_chat_with_user from apps.chat.models.chat_model import CreateChat, ChatRecord, RenameChat, ChatQuestion, AxisObj, QuickCommand, \ ChatInfo, Chat, ChatFinishStep from apps.chat.task.llm import LLMService @@ -69,6 +71,7 @@ def inner(): return await asyncio.to_thread(inner) """ + @router.get("/record/{chat_record_id}/data", summary=f"{PLACEHOLDER_PREFIX}get_chart_data") async def chat_record_data(session: SessionDep, current_user: CurrentUser, chat_record_id: int): def inner(): @@ -81,7 +84,8 @@ def inner(): @router.get("/record/{chat_record_id}/predict_data", summary=f"{PLACEHOLDER_PREFIX}get_chart_predict_data") async def chat_predict_data(session: SessionDep, current_user: CurrentUser, chat_record_id: int): def inner(): - data = get_chat_predict_data_with_user(chat_record_id=chat_record_id, session=session, current_user=current_user) + data = get_chat_predict_data_with_user(chat_record_id=chat_record_id, session=session, + current_user=current_user) return format_json_list_data(data) return await asyncio.to_thread(inner) @@ -102,6 +106,7 @@ async def rename(session: SessionDep, chat: RenameChat): detail=str(e) ) """ + @router.post("/rename", response_model=str, summary=f"{PLACEHOLDER_PREFIX}rename_chat") @system_log(LogConfig( operation_type=OperationType.UPDATE, @@ -117,6 +122,7 @@ async def rename(session: SessionDep, current_user: CurrentUser, chat: RenameCha detail=str(e) ) + """ @router.delete("/{chart_id}/{brief}", response_model=str, summary=f"{PLACEHOLDER_PREFIX}delete_chat") @system_log(LogConfig( operation_type=OperationType.DELETE, @@ -133,6 +139,7 @@ async def delete(session: SessionDep, chart_id: int, brief: str): detail=str(e) ) """ + @router.delete("/{chart_id}/{brief}", response_model=str, summary=f"{PLACEHOLDER_PREFIX}delete_chat") @system_log(LogConfig( operation_type=OperationType.DELETE, @@ -149,6 +156,7 @@ async def delete(session: SessionDep, current_user: CurrentUser, chart_id: int, detail=str(e) ) + @router.post("/start", response_model=ChatInfo, summary=f"{PLACEHOLDER_PREFIX}start_chat") @require_permissions(permission=SqlbotPermission(type='ds', keyExpression="create_chat_obj.datasource")) @system_log(LogConfig( @@ -172,9 +180,11 @@ async def start_chat(session: SessionDep, current_user: CurrentUser, create_chat module=OperationModules.CHAT, result_id_expr="id" )) -async def start_chat(session: SessionDep, current_user: CurrentUser, current_assistant: CurrentAssistant, create_chat_obj: CreateChat = CreateChat(origin=2)): +async def start_chat(session: SessionDep, current_user: CurrentUser, current_assistant: CurrentAssistant, + create_chat_obj: CreateChat = CreateChat(origin=2)): try: - return create_chat(session, current_user, create_chat_obj, create_chat_obj and create_chat_obj.datasource, current_assistant) + return create_chat(session, current_user, create_chat_obj, create_chat_obj and create_chat_obj.datasource, + current_assistant) except Exception as e: raise HTTPException( status_code=500, @@ -213,7 +223,7 @@ def _err(_e: Exception): @router.get("/recent_questions/{datasource_id}", response_model=List[str], summary=f"{PLACEHOLDER_PREFIX}get_recommend_questions") -#@require_permissions(permission=SqlbotPermission(type='ds', keyExpression="datasource_id")) +# @require_permissions(permission=SqlbotPermission(type='ds', keyExpression="datasource_id")) async def recommend_questions(session: SessionDep, current_user: CurrentUser, datasource_id: int = Path(..., description=f"{PLACEHOLDER_PREFIX}ds_id")): return list_recent_questions(session=session, current_user=current_user, datasource_id=datasource_id) @@ -442,8 +452,8 @@ def _err(_e: Exception): @router.get("/record/{chat_record_id}/excel/export/{chat_id}", summary=f"{PLACEHOLDER_PREFIX}export_chart_data") -@system_log(LogConfig(operation_type=OperationType.EXPORT,module=OperationModules.CHAT,resource_id_expr="chat_id",)) -async def export_excel(session: SessionDep, current_user: CurrentUser, chat_record_id: int,chat_id: int, trans: Trans): +@system_log(LogConfig(operation_type=OperationType.EXPORT, module=OperationModules.CHAT, resource_id_expr="chat_id", )) +async def export_excel(session: SessionDep, current_user: CurrentUser, chat_record_id: int, chat_id: int, trans: Trans): chat_record = session.get(ChatRecord, chat_record_id) if not chat_record: raise HTTPException( diff --git a/backend/apps/chat/curd/chat.py b/backend/apps/chat/curd/chat.py index 3b38fda0..c71ed2a7 100644 --- a/backend/apps/chat/curd/chat.py +++ b/backend/apps/chat/curd/chat.py @@ -493,6 +493,8 @@ def create_chat(session: SessionDep, current_user: CurrentUser, create_chat_obj: ds.type_name = DB.get_db(ds.type) else: ds = session.get(CoreDatasource, create_chat_obj.datasource) + if ds.oid != current_user.oid: + raise Exception(f"Datasource with id {create_chat_obj.datasource} does not belong to current workspace") if not ds: raise Exception(f"Datasource with id {create_chat_obj.datasource} not found") diff --git a/backend/apps/chat/models/chat_model.py b/backend/apps/chat/models/chat_model.py index fdcd1a40..590155fd 100644 --- a/backend/apps/chat/models/chat_model.py +++ b/backend/apps/chat/models/chat_model.py @@ -282,6 +282,7 @@ def dynamic_user_question(self): class ChatQuestion(AiModelQuestion): chat_id: int + datasource_id: Optional[int] = None class ChatMcp(ChatQuestion): @@ -299,6 +300,7 @@ class McpQuestion(BaseModel): token: str = Body(description='token') stream: Optional[bool] = Body(description='是否流式输出,默认为true开启, 关闭false则返回JSON对象', default=True) lang: Optional[str] = Body(description='语言:zh-CN|en|ko-KR', default='zh-CN') + datasource_id: Optional[int] = Body(description='数据源ID,仅当当前对话没有确定数据源时有效', default=None) class AxisObj(BaseModel): diff --git a/backend/apps/chat/task/llm.py b/backend/apps/chat/task/llm.py index 8187f5b4..686f6c5f 100644 --- a/backend/apps/chat/task/llm.py +++ b/backend/apps/chat/task/llm.py @@ -107,6 +107,19 @@ def __init__(self, session: Session, current_user: CurrentUser, chat_question: C if not chat: raise SingleMessageError(f"Chat with id {chat_id} not found") ds: CoreDatasource | AssistantOutDsSchema | None = None + if not chat.datasource and chat_question.datasource_id: + _ds = session.get(CoreDatasource, chat_question.datasource_id) + if _ds: + if _ds.oid != current_user.oid: + raise SingleMessageError(f"Datasource with id {chat_question.datasource_id} does not belong to current workspace") + chat.datasource = _ds.id + chat.engine_type = _ds.type_name + # save chat + session.add(chat) + session.flush() + session.refresh(chat) + session.commit() + if chat.datasource: # Get available datasource if current_assistant and current_assistant.type in dynamic_ds_types: diff --git a/backend/apps/mcp/mcp.py b/backend/apps/mcp/mcp.py index b935ffb1..1d6dfd6d 100644 --- a/backend/apps/mcp/mcp.py +++ b/backend/apps/mcp/mcp.py @@ -114,7 +114,7 @@ async def mcp_start(session: SessionDep, chat: ChatStart): async def mcp_question(session: SessionDep, chat: McpQuestion): session_user = get_user(session, chat.token) - mcp_chat = ChatMcp(token=chat.token, chat_id=chat.chat_id, question=chat.question) + mcp_chat = ChatMcp(token=chat.token, chat_id=chat.chat_id, question=chat.question, datasource_id=chat.datasource_id) return await question_answer_inner(session=session, current_user=session_user, request_question=mcp_chat, in_chat=False, stream=chat.stream)