Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion tests/legacy/test_websocket_async_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from dashscope.protocol.websocket import WebsocketStreamingMode
from tests.unit.base_test import BaseTestEnvironment
from tests.unit.constants import TestTasks
from tests.legacy.websocket_task_request import WebSocketRequest
from tests.unit.websocket_task_request import WebSocketRequest

# set mock server url.
base_websocket_api_url = "ws://localhost:8080/ws/aigc/v1"
Expand Down
2 changes: 1 addition & 1 deletion tests/legacy/test_websocket_parameters.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
TEST_ENABLE_DATA_INSPECTION_REQUEST_ID,
TestTasks,
)
from tests.legacy.websocket_task_request import WebSocketRequest
from tests.unit.websocket_task_request import WebSocketRequest


def pytest_generate_tests(metafunc):
Expand Down
2 changes: 1 addition & 1 deletion tests/legacy/test_websocket_sync_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from dashscope.protocol.websocket import WebsocketStreamingMode
from tests.unit.base_test import BaseTestEnvironment
from tests.unit.constants import TestTasks
from tests.legacy.websocket_task_request import WebSocketRequest
from tests.unit.websocket_task_request import WebSocketRequest


def pytest_generate_tests(metafunc):
Expand Down
2 changes: 1 addition & 1 deletion tests/unit/mock_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@
list_fine_tune_handler,
)
from tests.unit.mock_sse import sse_response
from tests.legacy.websocket_mock_server_task_handler import (
from tests.unit.websocket_mock_server_task_handler import (
WebSocketTaskProcessor,
)

Expand Down
File renamed without changes.
96 changes: 54 additions & 42 deletions tests/legacy/test_application.py → tests/unit/test_application.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,25 +45,29 @@ def test_rag_call(self, mock_server: MockServer):
"action_type": "api",
"action_name": "文档检索",
"action": "searchDocument",
"action_input_stream": '{"query":"API接口说明中, TopP参数改如何传递?"}',
"action_input_stream": (
'{"query":"API接口说明中, ' 'TopP参数改如何传递?"}'
),
"action_input": {
"query": "API接口说明中, TopP参数改如何传递?",
"query": ("API接口说明中, TopP参数改如何传递?"),
},
"observation": """{"data": [
{
"docId": "1234",
"docName": "API接口说明",
"docUrl": "https://127.0.0.1/dl/API接口说明.pdf",
"indexId": "1",
"score": 0.11992252,
"text": "填(0,1.0),取值越大,生成的随机性越高;启用文档检索后,文档引用类型,取值包括:simple|indexed。",
"title": "API接口说明",
"titlePath": "API接口说明>>>接口说明>>>是否必 说明>>>填"
}
],
"status": "SUCCESS"
}""",
"response": "API接口说明中, TopP参数是一个float类型的参数,取值范围为0到1.0,默认为1.0。取值越大,生成的随机性越高。[5]",
"observation": (
'{"data": [{"docId": "1234", '
'"docName": "API接口说明", '
'"docUrl": "https://127.0.0.1/dl/'
'API接口说明.pdf", "indexId": "1", '
'"score": 0.11992252, "text": "填(0,1.0),'
"取值越大,生成的随机性越高;启用文档检索后,"
'文档引用类型,取值包括:simple|indexed。", '
'"title": "API接口说明", "titlePath": '
'"API接口说明>>>接口说明>>>是否必 说明>>>填"}], '
'"status": "SUCCESS"}'
),
"response": (
"API接口说明中, TopP参数是一个float类型的"
"参数,取值范围为0到1.0,默认为1.0。取值越大,"
"生成的随机性越高。[5]"
),
},
],
},
Expand All @@ -86,11 +90,12 @@ def test_rag_call(self, mock_server: MockServer):
top_p=0.2,
temperature=1.0,
doc_tag_codes=["t1234", "t2345"],
doc_reference_type=Application.DocReferenceType.simple,
doc_reference_type=(Application.DocReferenceType.simple),
has_thoughts=True,
)

self.check_result(resp, test_response)
# Test mock response type
self.check_result(resp, test_response) # type: ignore[arg-type]

def test_flow_call(self, mock_server: MockServer):
test_response = {
Expand All @@ -105,13 +110,19 @@ def test_flow_call(self, mock_server: MockServer):
"action_type": "api",
"action_name": "plugin",
"action": "api",
"action_input_stream": '{"userId": "123", "date": "202402", "city": "hangzhou"}',
"action_input_stream": (
'{"userId": "123", "date": "202402", '
'"city": "hangzhou"}'
),
"action_input": {
"userId": "123",
"date": "202402",
"city": "hangzhou",
},
"observation": """{"quantity": 102, "type": "resident", "date": "202402", "unit": "千瓦"}""",
"observation": (
'{"quantity": 102, "type": "resident", '
'"date": "202402", "unit": "千瓦"}'
),
"response": "当月的居民用电量为102千瓦。",
},
],
Expand Down Expand Up @@ -140,7 +151,8 @@ def test_flow_call(self, mock_server: MockServer):
has_thoughts=True,
)

self.check_result(resp, test_response)
# Test mock response type
self.check_result(resp, test_response) # type: ignore[arg-type]

def test_call_with_error(self, mock_server: MockServer):
test_response = {
Expand Down Expand Up @@ -212,21 +224,21 @@ def check_result(resp: ApplicationResponse, test_response: Dict):
expected_doc_refs,
)

for i in range(len(doc_refs)):
assert doc_refs[i].index_id == expected_doc_refs[i].get(
for i, doc_ref in enumerate(doc_refs):
assert doc_ref.index_id == expected_doc_refs[i].get(
"index_id",
)
assert doc_refs[i].doc_id == expected_doc_refs[i].get("doc_id")
assert doc_refs[i].doc_name == expected_doc_refs[i].get(
assert doc_ref.doc_id == expected_doc_refs[i].get("doc_id")
assert doc_ref.doc_name == expected_doc_refs[i].get(
"doc_name",
)
assert doc_refs[i].doc_url == expected_doc_refs[i].get(
assert doc_ref.doc_url == expected_doc_refs[i].get(
"doc_url",
)
assert doc_refs[i].title == expected_doc_refs[i].get("title")
assert doc_refs[i].text == expected_doc_refs[i].get("text")
assert doc_refs[i].biz_id == expected_doc_refs[i].get("biz_id")
assert json.dumps(doc_refs[i].images) == json.dumps(
assert doc_ref.title == expected_doc_refs[i].get("title")
assert doc_ref.text == expected_doc_refs[i].get("text")
assert doc_ref.biz_id == expected_doc_refs[i].get("biz_id")
assert json.dumps(doc_ref.images) == json.dumps(
expected_doc_refs[i].get("images"),
)

Expand All @@ -238,26 +250,26 @@ def check_result(resp: ApplicationResponse, test_response: Dict):
expected_thoughts,
)

for i in range(len(thoughts)):
assert thoughts[i].thought == expected_thoughts[i].get(
for i, thought in enumerate(thoughts):
assert thought.thought == expected_thoughts[i].get(
"thought",
)
assert thoughts[i].action == expected_thoughts[i].get("action")
assert thoughts[i].action_name == expected_thoughts[i].get(
assert thought.action == expected_thoughts[i].get("action")
assert thought.action_name == expected_thoughts[i].get(
"action_name",
)
assert thoughts[i].action_type == expected_thoughts[i].get(
assert thought.action_type == expected_thoughts[i].get(
"action_type",
)
assert json.dumps(thoughts[i].action_input) == json.dumps(
assert json.dumps(thought.action_input) == json.dumps(
expected_thoughts[i].get("action_input"),
)
assert thoughts[i].action_input_stream == expected_thoughts[
i
].get("action_input_stream")
assert thoughts[i].observation == expected_thoughts[i].get(
assert thought.action_input_stream == (
expected_thoughts[i].get("action_input_stream")
)
assert thought.observation == expected_thoughts[i].get(
"observation",
)
assert thoughts[i].response == expected_thoughts[i].get(
assert thought.response == expected_thoughts[i].get(
"response",
)
Loading