Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 20 additions & 5 deletions backend/apps/chat/api/chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -476,11 +476,26 @@ async def export_excel(session: SessionDep, current_user: CurrentUser, chat_reco
if chart_info.get('columns') and len(chart_info.get('columns')) > 0:
for column in chart_info.get('columns'):
fields.append(AxisObj(name=column.get('name'), value=column.get('value')))
if chart_info.get('axis'):
for _type in ['x', 'y', 'series']:
if chart_info.get('axis').get(_type):
column = chart_info.get('axis').get(_type)
fields.append(AxisObj(name=column.get('name'), value=column.get('value')))
# 处理 axis
if axis := chart_info.get('axis'):
# 处理 x 轴
if x_axis := axis.get('x'):
if 'name' in x_axis or 'value' in x_axis:
fields.append(AxisObj(name=x_axis.get('name'), value=x_axis.get('value')))

# 处理 y 轴 - 兼容数组和对象格式
if y_axis := axis.get('y'):
if isinstance(y_axis, list):
for column in y_axis:
if 'name' in column or 'value' in column:
fields.append(AxisObj(name=column.get('name'), value=column.get('value')))
elif isinstance(y_axis, dict) and ('name' in y_axis or 'value' in y_axis):
fields.append(AxisObj(name=y_axis.get('name'), value=y_axis.get('value')))

# 处理 series
if series := axis.get('series'):
if 'name' in series or 'value' in series:
fields.append(AxisObj(name=series.get('name'), value=series.get('value')))

_predict_data = []
if is_predict_data:
Expand Down
55 changes: 39 additions & 16 deletions backend/apps/chat/curd/chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,23 +126,41 @@ def get_chart_config(session: SessionDep, chart_record_id: int):
return {}


def format_chart_fields(chart_info: dict):
def _format_column(column: dict) -> str:
"""格式化单个column字段"""
value = column.get('value', '')
name = column.get('name', '')
if value != name and name:
return f"{value}({name})"
return value


def format_chart_fields(chart_info: dict) -> list:
fields = []
if chart_info.get('columns') and len(chart_info.get('columns')) > 0:
for column in chart_info.get('columns'):
column_str = column.get('value')
if column.get('value') != column.get('name'):
column_str = column_str + '(' + column.get('name') + ')'
fields.append(column_str)
if chart_info.get('axis'):
for _type in ['x', 'y', 'series']:
if chart_info.get('axis').get(_type):
column = chart_info.get('axis').get(_type)
column_str = column.get('value')
if column.get('value') != column.get('name'):
column_str = column_str + '(' + column.get('name') + ')'
fields.append(column_str)
return fields

# 处理 columns
for column in chart_info.get('columns') or []:
fields.append(_format_column(column))

# 处理 axis
if axis := chart_info.get('axis'):
# 处理 x 轴
if x_axis := axis.get('x'):
fields.append(_format_column(x_axis))

# 处理 y 轴
if y_axis := axis.get('y'):
if isinstance(y_axis, list):
for column in y_axis:
fields.append(_format_column(column))
else:
fields.append(_format_column(y_axis))

# 处理 series
if series := axis.get('series'):
fields.append(_format_column(series))

return [field for field in fields if field] # 过滤空字符串


def get_last_execute_sql_error(session: SessionDep, chart_id: int):
Expand Down Expand Up @@ -410,6 +428,11 @@ def format_record(record: ChatRecordResult):
_dict['sql'] = sqlparse.format(record.sql, reindent=True)
except Exception:
pass
# 去除返回前端多余的字段
_dict.pop('sql_reasoning_content', None)
_dict.pop('chart_reasoning_content', None)
_dict.pop('analysis_reasoning_content', None)
_dict.pop('predict_reasoning_content', None)

return _dict

Expand Down
17 changes: 15 additions & 2 deletions backend/templates/template.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,9 @@ template:
<rule>
若用户提问中提供了参考SQL,你需要判断该SQL是否是查询语句
</rule>
<rule>
你只需要根据提供给你的信息生成的SQL,不需要你实际去数据库进行查询
</rule>
<rule>
请使用JSON格式返回你的回答:
若能生成,则返回格式如:{{"success":true,"sql":"你生成的SQL语句","tables":["该SQL用到的表名1","该SQL用到的表名2",...],"chart-type":"table","brief":"如何需要生成对话标题,在这里填写你生成的对话标题,否则不需要这个字段"}}
Expand Down Expand Up @@ -141,6 +144,17 @@ template:
<rule>
是否生成对话标题在<change-title>内,如果为True需要生成,否则不需要生成,生成的对话标题要求在20字以内
</rule>
<rule priority="critical" id="no-additional-info">
<title>禁止要求额外信息</title>
<requirements>
<requirement>禁止在回答中向用户询问或要求任何额外信息</requirement>
<requirement>只基于表结构和问题生成SQL,不考虑业务逻辑</requirement>
<requirement>即使查询条件不完整(如无时间范围),也必须生成可行的SQL</requirement>
</requirements>
</rule>
<rule>
不论上下文是否有回答相同的问题,都需要检查生成的SQL是否匹配<m-schema>内的定义
</rule>
</Rules>

{process_check}
Expand Down Expand Up @@ -466,8 +480,7 @@ template:
[]
- 若你的给出的JSON不是{lang}的,则必须翻译为{lang}

### 响应, 请直接返回JSON结果:
```json
### 响应, 请直接返回JSON结果(不要包含任何其他文本):

user: |
### 表结构:
Expand Down
1 change: 1 addition & 0 deletions frontend/src/views/chat/component/BaseChart.ts
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ export interface ChartAxis {
value: string
type?: 'x' | 'y' | 'series' | 'other-info'
'multi-quota'?: boolean
hidden?: boolean
}

export interface ChartData {
Expand Down
7 changes: 6 additions & 1 deletion frontend/src/views/chat/component/ChartComponent.vue
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,12 @@ const axis = computed(() => {
_list.push({ name: column.name, value: column.value, type: 'series' })
})
if (params.multiQuotaName) {
_list.push({ name: params.multiQuotaName, value: params.multiQuotaName, type: 'other-info' })
_list.push({
name: params.multiQuotaName,
value: params.multiQuotaName,
type: 'other-info',
hidden: true,
})
}
return _list
})
Expand Down
7 changes: 5 additions & 2 deletions frontend/src/views/chat/component/charts/Table.ts
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ import {
type S2DataConfig,
type S2MountContainer,
} from '@antv/s2'
import { debounce } from 'lodash-es'
import { debounce, filter } from 'lodash-es'
import { i18n } from '@/i18n'

const { t } = i18n.global
Expand Down Expand Up @@ -43,7 +43,10 @@ export class Table extends BaseChart {
}

init(axis: Array<ChartAxis>, data: Array<ChartData>) {
super.init(axis, data)
super.init(
filter(axis, (a) => !a.hidden), //隐藏多指标的other-info列
data
)

const s2DataConfig: S2DataConfig = {
fields: {
Expand Down