From 538813d60e32b99ec28d45b0dd35acc829e50262 Mon Sep 17 00:00:00 2001 From: Kevin Lin Date: Thu, 22 Jan 2026 09:32:21 +0800 Subject: [PATCH 1/2] feat(ci): add unit tests --- tests/{legacy => unit}/test_add_resources.py | 0 tests/{legacy => unit}/test_application.py | 96 +++++----- tests/{legacy => unit}/test_asr_phrases.py | 124 +++++++------ .../{legacy => unit}/test_assistant_files.py | 0 tests/{legacy => unit}/test_assistants.py | 30 ++- .../{legacy => unit}/test_code_generation.py | 172 ++++++++++++++---- tests/{legacy => unit}/test_conversation.py | 31 ++-- 7 files changed, 305 insertions(+), 148 deletions(-) rename tests/{legacy => unit}/test_add_resources.py (100%) rename tests/{legacy => unit}/test_application.py (69%) rename tests/{legacy => unit}/test_asr_phrases.py (68%) rename tests/{legacy => unit}/test_assistant_files.py (100%) rename tests/{legacy => unit}/test_assistants.py (90%) rename tests/{legacy => unit}/test_code_generation.py (78%) rename tests/{legacy => unit}/test_conversation.py (88%) diff --git a/tests/legacy/test_add_resources.py b/tests/unit/test_add_resources.py similarity index 100% rename from tests/legacy/test_add_resources.py rename to tests/unit/test_add_resources.py diff --git a/tests/legacy/test_application.py b/tests/unit/test_application.py similarity index 69% rename from tests/legacy/test_application.py rename to tests/unit/test_application.py index 5a8341a..480c071 100644 --- a/tests/legacy/test_application.py +++ b/tests/unit/test_application.py @@ -45,25 +45,29 @@ def test_rag_call(self, mock_server: MockServer): "action_type": "api", "action_name": "文档检索", "action": "searchDocument", - "action_input_stream": '{"query":"API接口说明中, TopP参数改如何传递?"}', + "action_input_stream": ( + '{"query":"API接口说明中, ' 'TopP参数改如何传递?"}' + ), "action_input": { - "query": "API接口说明中, TopP参数改如何传递?", + "query": ("API接口说明中, TopP参数改如何传递?"), }, - "observation": """{"data": [ - { - "docId": "1234", - "docName": "API接口说明", - "docUrl": "https://127.0.0.1/dl/API接口说明.pdf", - "indexId": "1", - "score": 0.11992252, - "text": "填(0,1.0),取值越大,生成的随机性越高;启用文档检索后,文档引用类型,取值包括:simple|indexed。", - "title": "API接口说明", - "titlePath": "API接口说明>>>接口说明>>>是否必 说明>>>填" - } - ], - "status": "SUCCESS" - }""", - "response": "API接口说明中, TopP参数是一个float类型的参数,取值范围为0到1.0,默认为1.0。取值越大,生成的随机性越高。[5]", + "observation": ( + '{"data": [{"docId": "1234", ' + '"docName": "API接口说明", ' + '"docUrl": "https://127.0.0.1/dl/' + 'API接口说明.pdf", "indexId": "1", ' + '"score": 0.11992252, "text": "填(0,1.0),' + "取值越大,生成的随机性越高;启用文档检索后," + '文档引用类型,取值包括:simple|indexed。", ' + '"title": "API接口说明", "titlePath": ' + '"API接口说明>>>接口说明>>>是否必 说明>>>填"}], ' + '"status": "SUCCESS"}' + ), + "response": ( + "API接口说明中, TopP参数是一个float类型的" + "参数,取值范围为0到1.0,默认为1.0。取值越大," + "生成的随机性越高。[5]" + ), }, ], }, @@ -86,11 +90,12 @@ def test_rag_call(self, mock_server: MockServer): top_p=0.2, temperature=1.0, doc_tag_codes=["t1234", "t2345"], - doc_reference_type=Application.DocReferenceType.simple, + doc_reference_type=(Application.DocReferenceType.simple), has_thoughts=True, ) - self.check_result(resp, test_response) + # Test mock response type + self.check_result(resp, test_response) # type: ignore[arg-type] def test_flow_call(self, mock_server: MockServer): test_response = { @@ -105,13 +110,19 @@ def test_flow_call(self, mock_server: MockServer): "action_type": "api", "action_name": "plugin", "action": "api", - "action_input_stream": '{"userId": "123", "date": "202402", "city": "hangzhou"}', + "action_input_stream": ( + '{"userId": "123", "date": "202402", 
' + '"city": "hangzhou"}' + ), "action_input": { "userId": "123", "date": "202402", "city": "hangzhou", }, - "observation": """{"quantity": 102, "type": "resident", "date": "202402", "unit": "千瓦"}""", + "observation": ( + '{"quantity": 102, "type": "resident", ' + '"date": "202402", "unit": "千瓦"}' + ), "response": "当月的居民用电量为102千瓦。", }, ], @@ -140,7 +151,8 @@ def test_flow_call(self, mock_server: MockServer): has_thoughts=True, ) - self.check_result(resp, test_response) + # Test mock response type + self.check_result(resp, test_response) # type: ignore[arg-type] def test_call_with_error(self, mock_server: MockServer): test_response = { @@ -212,21 +224,21 @@ def check_result(resp: ApplicationResponse, test_response: Dict): expected_doc_refs, ) - for i in range(len(doc_refs)): - assert doc_refs[i].index_id == expected_doc_refs[i].get( + for i, doc_ref in enumerate(doc_refs): + assert doc_ref.index_id == expected_doc_refs[i].get( "index_id", ) - assert doc_refs[i].doc_id == expected_doc_refs[i].get("doc_id") - assert doc_refs[i].doc_name == expected_doc_refs[i].get( + assert doc_ref.doc_id == expected_doc_refs[i].get("doc_id") + assert doc_ref.doc_name == expected_doc_refs[i].get( "doc_name", ) - assert doc_refs[i].doc_url == expected_doc_refs[i].get( + assert doc_ref.doc_url == expected_doc_refs[i].get( "doc_url", ) - assert doc_refs[i].title == expected_doc_refs[i].get("title") - assert doc_refs[i].text == expected_doc_refs[i].get("text") - assert doc_refs[i].biz_id == expected_doc_refs[i].get("biz_id") - assert json.dumps(doc_refs[i].images) == json.dumps( + assert doc_ref.title == expected_doc_refs[i].get("title") + assert doc_ref.text == expected_doc_refs[i].get("text") + assert doc_ref.biz_id == expected_doc_refs[i].get("biz_id") + assert json.dumps(doc_ref.images) == json.dumps( expected_doc_refs[i].get("images"), ) @@ -238,26 +250,26 @@ def check_result(resp: ApplicationResponse, test_response: Dict): expected_thoughts, ) - for i in range(len(thoughts)): - assert thoughts[i].thought == expected_thoughts[i].get( + for i, thought in enumerate(thoughts): + assert thought.thought == expected_thoughts[i].get( "thought", ) - assert thoughts[i].action == expected_thoughts[i].get("action") - assert thoughts[i].action_name == expected_thoughts[i].get( + assert thought.action == expected_thoughts[i].get("action") + assert thought.action_name == expected_thoughts[i].get( "action_name", ) - assert thoughts[i].action_type == expected_thoughts[i].get( + assert thought.action_type == expected_thoughts[i].get( "action_type", ) - assert json.dumps(thoughts[i].action_input) == json.dumps( + assert json.dumps(thought.action_input) == json.dumps( expected_thoughts[i].get("action_input"), ) - assert thoughts[i].action_input_stream == expected_thoughts[ - i - ].get("action_input_stream") - assert thoughts[i].observation == expected_thoughts[i].get( + assert thought.action_input_stream == ( + expected_thoughts[i].get("action_input_stream") + ) + assert thought.observation == expected_thoughts[i].get( "observation", ) - assert thoughts[i].response == expected_thoughts[i].get( + assert thought.response == expected_thoughts[i].get( "response", ) diff --git a/tests/legacy/test_asr_phrases.py b/tests/unit/test_asr_phrases.py similarity index 68% rename from tests/legacy/test_asr_phrases.py rename to tests/unit/test_asr_phrases.py index 01cfda1..942b55b 100644 --- a/tests/legacy/test_asr_phrases.py +++ b/tests/unit/test_asr_phrases.py @@ -40,7 +40,10 @@ def setup_class(cls): cls.update_phrase = {"黄鸡": 2, "红鸡": 1} 
cls.phrase_id = TEST_JOB_ID - def test_create_phrases(self, http_server): + def test_create_phrases( + self, + http_server, + ): # pylint: disable=unused-argument result = AsrPhraseManager.create_phrases( model=self.model, phrases=self.phrase, @@ -54,7 +57,10 @@ def test_create_phrases(self, http_server): assert len(result.output["finetuned_output"]) > 0 self.phrase_id = result.output["finetuned_output"] - def test_update_phrases(self, http_server): + def test_update_phrases( + self, + http_server, + ): # pylint: disable=unused-argument result = AsrPhraseManager.update_phrases( model=self.model, phrase_id=self.phrase_id, @@ -66,7 +72,10 @@ def test_update_phrases(self, http_server): assert result.output["finetuned_output"] is not None assert len(result.output["finetuned_output"]) > 0 - def test_query_phrases(self, http_server): + def test_query_phrases( + self, + http_server, + ): # pylint: disable=unused-argument result = AsrPhraseManager.query_phrases(phrase_id=self.phrase_id) assert result is not None assert result.status_code == HTTPStatus.OK @@ -75,14 +84,20 @@ def test_query_phrases(self, http_server): assert result.output["model"] is not None assert len(result.output["model"]) > 0 - def test_list_phrases(self, http_server): + def test_list_phrases( + self, + http_server, + ): # pylint: disable=unused-argument result = AsrPhraseManager.list_phrases(page=1, page_size=10) assert result is not None assert result.status_code == HTTPStatus.OK assert result.output["finetuned_outputs"] is not None assert len(result.output["finetuned_outputs"]) > 0 - def test_delete_phrases(self, http_server): + def test_delete_phrases( + self, + http_server, + ): # pylint: disable=unused-argument result = AsrPhraseManager.delete_phrases(phrase_id=self.phrase_id) assert result is not None assert result.status_code == HTTPStatus.OK @@ -90,24 +105,26 @@ def test_delete_phrases(self, http_server): assert len(result.output["finetuned_output"]) > 0 -def str2bool(str): - return True if str.lower() == "true" else False +def str2bool(test): # pylint: disable=redefined-builtin + # Return True if test string is "true", False otherwise + return test.lower() == "true" -def complete_url(url: str) -> str: +def complete_url(url: str) -> None: + # Set base URLs for dashscope API parsed = urlparse(url) base_url = "".join([parsed.scheme, "://", parsed.netloc]) dashscope.base_websocket_api_url = "/".join( [base_url, "api-ws", dashscope.common.env.api_version, "inference"], ) - dashscope.base_http_api_url = url = "/".join( + dashscope.base_http_api_url = "/".join( [base_url, "api", dashscope.common.env.api_version], ) print("Set base_websocket_api_url: ", dashscope.base_websocket_api_url) print("Set base_http_api_url: ", dashscope.base_http_api_url) -def phrases( +def phrases( # pylint: disable=redefined-outer-name,too-many-branches model, phrase_id: str, phrases: dict, @@ -115,6 +132,7 @@ def phrases( page_size: int, delete: bool, ): + # Manage ASR phrases based on provided parameters print("phrase_id: ", phrase_id) print("phrase: ", phrases) print("delete flag: ", delete) @@ -126,33 +144,30 @@ def phrases( phrase_id=phrase_id, phrases=phrases, ) - else: - print("Create phrases -->") - return AsrPhraseManager.create_phrases( - model=model, - phrases=phrases, - ) - else: - if delete: - print("Delete phrases -->") - return AsrPhraseManager.delete_phrases(phrase_id=phrase_id) - else: - if phrase_id is not None: - print("Query phrases -->") - return AsrPhraseManager.query_phrases(phrase_id=phrase_id) - if page is not None and 
page_size is not None: - print( - "List phrases page %d page_size %d -->" - % (page, page_size), - ) - return AsrPhraseManager.list_phrases( - page=page, - page_size=page_size, - ) + print("Create phrases -->") + return AsrPhraseManager.create_phrases( + model=model, + phrases=phrases, + ) + if delete: + print("Delete phrases -->") + return AsrPhraseManager.delete_phrases(phrase_id=phrase_id) + if phrase_id is not None: + print("Query phrases -->") + return AsrPhraseManager.query_phrases(phrase_id=phrase_id) + if page is not None and page_size is not None: + print( + f"List phrases page {page} page_size {page_size} -->", + ) + return AsrPhraseManager.list_phrases( + page=page, + page_size=page_size, + ) + return None @pytest.mark.skip -def test_by_user(): +def test_by_user(): # pylint: disable=too-many-branches parser = argparse.ArgumentParser() parser.add_argument("--model", type=str, default="paraformer-realtime-v1") parser.add_argument("--phrase", type=str, default="") @@ -184,21 +199,23 @@ def test_by_user(): print("Response of phrases: ", resp) if resp is not None and resp.output is not None: output = resp.output - print("\nGet output: %s\n" % (str(output))) + print(f"\nGet output: {str(output)}\n") if ( "finetuned_output" in output and output["finetuned_output"] is not None ): - print("Get phrase_id: %s" % (output["finetuned_output"])) + print( + f"Get phrase_id: {output['finetuned_output']}", + ) if "job_id" in output and output["job_id"] is not None: - print("Get job_id: %s" % (output["job_id"])) + print(f"Get job_id: {output['job_id']}") if "create_time" in output and output["create_time"] is not None: - print("Get create_time: %s" % (output["create_time"])) + print(f"Get create_time: {output['create_time']}") if "model" in output and output["model"] is not None: - print("Get model_id: %s" % (output["model"])) + print(f"Get model_id: {output['model']}") if "output_type" in output and output["output_type"] is not None: - print("Get output_type: %s" % (output["output_type"])) + print(f"Get output_type: {output['output_type']}") if ( "finetuned_outputs" in output @@ -206,24 +223,23 @@ def test_by_user(): ): outputs = output["finetuned_outputs"] print( - "Get %d info from page_no:%d page_size:%d total:%d ->" - % ( - len(outputs), - output["page_no"], - output["page_size"], - output["total"], - ), + f"Get {len(outputs)} info from " + f"page_no:{output['page_no']} " + f"page_size:{output['page_size']} " + f"total:{output['total']} ->", ) for item in outputs: - print(" get phrase_id: %s" % (item["finetuned_output"])) - print(" get job_id: %s" % (item["job_id"])) - print(" get create_time: %s" % (item["create_time"])) - print(" get model_id: %s" % (item["model"])) - print(" get output_type: %s\n" % (item["output_type"])) + print( + f" get phrase_id: {item['finetuned_output']}", + ) + print(f" get job_id: {item['job_id']}") + print(f" get create_time: {item['create_time']}") + print(f" get model_id: {item['model']}") + print(f" get output_type: {item['output_type']}\n") else: print( - "ERROR, status_code:%d, code_message:%s, error_message:%s" - % (resp.status_code, resp.code, resp.message), + f"ERROR, status_code:{resp.status_code}, " + f"code_message:{resp.code}, error_message:{resp.message}", ) diff --git a/tests/legacy/test_assistant_files.py b/tests/unit/test_assistant_files.py similarity index 100% rename from tests/legacy/test_assistant_files.py rename to tests/unit/test_assistant_files.py diff --git a/tests/legacy/test_assistants.py b/tests/unit/test_assistants.py similarity index 
90% rename from tests/legacy/test_assistants.py rename to tests/unit/test_assistants.py index 569a91c..1316f37 100644 --- a/tests/legacy/test_assistants.py +++ b/tests/unit/test_assistants.py @@ -16,9 +16,8 @@ class TestAssistants(MockServerBase): @classmethod def setup_class(cls): - cls.case_data = json.load( - open("tests/data/assistant.json", "r", encoding="utf-8"), - ) + with open("tests/data/assistant.json", "r", encoding="utf-8") as f: + cls.case_data = json.load(f) super().setup_class() def test_create_assistant_only_model(self, mock_server: MockServer): @@ -96,12 +95,19 @@ def test_create_assistant(self, mock_server: MockServer): assert req["tools"] == [{"type": "search"}, {"type": "wanx"}] assert req["instructions"] == "Your a helpful assistant." assert req["name"] == "hello" - assert response.file_ids == [] + assert not response.file_ids assert response.instructions == req["instructions"] assert response.metadata == req["metadata"] def test_create_assistant_function_call(self, mock_server: MockServer): + # Accessing dict key in test mock data + assert self.case_data is not None + # type: ignore[index] + # pylint: disable=unsubscriptable-object request_body = self.case_data["test_function_call_request"] + # Accessing dict key in test mock data + # type: ignore[index] + # pylint: disable=unsubscriptable-object response_body = json.dumps( self.case_data["test_function_call_response"], ) @@ -110,7 +116,7 @@ def test_create_assistant_function_call(self, mock_server: MockServer): req = mock_server.requests.get(block=True) assert response.model == req["model"] assert response.tools[2].function.name == "big_add" - assert response.file_ids == [] + assert not response.file_ids assert response.instructions == req["instructions"] def test_retrieve_assistant(self, mock_server: MockServer): @@ -145,11 +151,15 @@ def test_retrieve_assistant(self, mock_server: MockServer): req_assistant_id = mock_server.requests.get(block=True) assert response.model == self.TEST_MODEL_NAME assert req_assistant_id == self.ASSISTANT_ID - assert response.file_ids == [] + assert not response.file_ids assert response.instructions == response_obj["instructions"] assert response.metadata == response_obj["metadata"] def test_list_assistant(self, mock_server: MockServer): + # Accessing dict key in test mock data + assert self.case_data is not None + # type: ignore[index] + # pylint: disable=unsubscriptable-object response_obj = self.case_data["test_list"] mock_server.responses.put(json.dumps(response_obj)) response = Assistants.list( @@ -161,9 +171,9 @@ def test_list_assistant(self, mock_server: MockServer): ) # get assistant id we send. 
req = mock_server.requests.get(block=True) - assert ( - req - == "/api/v1/assistants?limit=10&order=inc&after=after&before=before" + assert req == ( + "/api/v1/assistants?limit=10&order=inc&" + "after=after&before=before" ) assert len(response.data) == 2 assert response.data[0].id == "asst_1" @@ -199,7 +209,7 @@ def test_update_assistant(self, mock_server: MockServer): assert req is not None assert response.model == self.TEST_MODEL_NAME assert response.id == self.ASSISTANT_ID - assert response.file_ids == [] + assert not response.file_ids assert response.instructions == response_obj["instructions"] assert response.tools == response_obj["tools"] assert response.metadata == response_obj["metadata"] diff --git a/tests/legacy/test_code_generation.py b/tests/unit/test_code_generation.py similarity index 78% rename from tests/legacy/test_code_generation.py rename to tests/unit/test_code_generation.py index ce1c339..831ce9c 100644 --- a/tests/legacy/test_code_generation.py +++ b/tests/unit/test_code_generation.py @@ -15,6 +15,7 @@ model = CodeGeneration.Models.tongyi_lingma_v1 # yapf: disable +# pylint: disable=line-too-long class TestCodeGenerationRequest(MockServerBase): @@ -43,7 +44,10 @@ def test_custom_sample(self, mock_server: MockServer): scene=CodeGeneration.Scenes.custom, message=[ UserRoleMessageParam( - content='根据下面的功能描述生成一个python函数。代码的功能是计算给定路径下所有文件的总大小。', + content=( + '根据下面的功能描述生成一个python函数。代码的功能是' + '计算给定路径下所有文件的总大小。' + ), ), ], ) @@ -52,13 +56,27 @@ def test_custom_sample(self, mock_server: MockServer): assert req['input']['scene'] == 'custom' assert json.dumps( req['input']['message'], ensure_ascii=False, - ) == '[{"role": "user", "content": "根据下面的功能描述生成一个python函数。代码的功能是计算给定路径下所有文件的总大小。"}]' + ) == ( + '[{"role": "user", "content": ' + '"根据下面的功能描述生成一个python函数。' + '代码的功能是计算给定路径下所有文件的总大小。"}]' + ) assert response.status_code == HTTPStatus.OK assert response.request_id == 'bf321b27-a3ff-9674-a70e-be5f40a435e4' - assert response.output['choices'][0][ - 'content' - ] == '以下是生成Python函数的代码:\n\n```python\ndef file_size(path):\n total_size = 0\n for root, dirs, files in os.walk(path):\n for file in files:\n full_path = os.path.join(root, file)\n total_size += os.path.getsize(full_path)\n return total_size\n```\n\n函数名为`file_size`,输入参数是给定路径`path`。函数通过递归遍历给定路径下的所有文件,使用`os.walk`函数遍历根目录及其子目录下的文件,计算每个文件的大小并累加到总大小上。最后,返回总大小作为函数的返回值。' # noqa E501 + assert response.output['choices'][0]['content'] == ( + '以下是生成Python函数的代码:\n\n```python\n' + 'def file_size(path):\n total_size = 0\n' + ' for root, dirs, files in os.walk(path):\n' + ' for file in files:\n' + ' full_path = os.path.join(root, file)\n' + ' total_size += os.path.getsize(full_path)\n' + ' return total_size\n```\n\n' + '函数名为`file_size`,输入参数是给定路径`path`。' + '函数通过递归遍历给定路径下的所有文件,使用`os.walk`函数' + '遍历根目录及其子目录下的文件,计算每个文件的大小并累加到' + '总大小上。最后,返回总大小作为函数的返回值。' + ) assert response.output['choices'][0]['frame_id'] == 25 assert response.output['choices'][0]['finish_reason'] == 'stop' assert response.usage['output_tokens'] == 198 @@ -87,7 +105,7 @@ def test_custom_dict_sample(self, mock_server: MockServer): response = CodeGeneration.call( model=model, scene=CodeGeneration.Scenes.custom, - message=[{ + message=[{ # type: ignore[list-item] 'role': 'user', 'content': '根据下面的功能描述生成一个python函数。代码的功能是计算给定路径下所有文件的总大小。', }], @@ -97,7 +115,11 @@ def test_custom_dict_sample(self, mock_server: MockServer): assert req['input']['scene'] == 'custom' assert json.dumps( req['input']['message'], ensure_ascii=False, - ) == '[{"role": "user", "content": 
"根据下面的功能描述生成一个python函数。代码的功能是计算给定路径下所有文件的总大小。"}]' + ) == ( + '[{"role": "user", "content": ' + '"根据下面的功能描述生成一个python函数。代码的功能是' + '计算给定路径下所有文件的总大小。"}]' + ) assert response.status_code == HTTPStatus.OK assert response.request_id == 'bf321b27-a3ff-9674-a70e-be5f40a435e4' @@ -142,13 +164,31 @@ def test_nl2code_sample(self, mock_server: MockServer): assert req['input']['scene'] == 'nl2code' assert json.dumps( req['input']['message'], ensure_ascii=False, - ) == '[{"role": "user", "content": "计算给定路径下所有文件的总大小"}, {"role": "attachment", "meta": {"language": "java"}}]' + ) == ( + '[{"role": "user", "content": ' + '"计算给定路径下所有文件的总大小"}, ' + '{"role": "attachment", "meta": {"language": "java"}}]' + ) assert response.status_code == HTTPStatus.OK assert response.request_id == '59bbbea3-29a7-94d6-8c39-e4d6e465f640' - assert response.output['choices'][0][ - 'content' - ] == "```java\n/**\n * 计算给定路径下所有文件的总大小\n * @param path 路径\n * @return 总大小,单位为字节\n */\npublic static long getTotalFileSize(String path) {\n long size = 0;\n try {\n File file = new File(path);\n File[] files = file.listFiles();\n for (File f : files) {\n if (f.isFile()) {\n size += f.length();\n }\n }\n } catch (Exception e) {\n e.printStackTrace();\n }\n return size;\n}\n```\n\n使用方式:\n```java\nlong size = getTotalFileSize(\"/home/user/Documents/\");\nSystem.out.println(\"总大小:\" + size + \"字节\");\n```\n\n示例输出:\n```\n总大小:37144952字节\n```" # noqa E501 + assert response.output['choices'][0]['content'] == ( + "```java\n/**\n * 计算给定路径下所有文件的总大小\n" + " * @param path 路径\n * @return 总大小,单位为字节\n */\n" + "public static long getTotalFileSize(String path) {\n" + " long size = 0;\n try {\n" + " File file = new File(path);\n" + " File[] files = file.listFiles();\n" + " for (File f : files) {\n" + " if (f.isFile()) {\n" + " size += f.length();\n }\n }\n" + " } catch (Exception e) {\n" + " e.printStackTrace();\n }\n" + " return size;\n}\n```\n\n使用方式:\n```java\n" + "long size = getTotalFileSize(\"/home/user/Documents/\");\n" + "System.out.println(\"总大小:\" + size + \"字节\");\n```\n\n" + "示例输出:\n```\n总大小:37144952字节\n```" + ) assert response.output['choices'][0]['frame_id'] == 29 assert response.output['choices'][0]['finish_reason'] == 'stop' assert response.usage['output_tokens'] == 229 @@ -196,13 +236,32 @@ def test_code2comment_sample(self, mock_server: MockServer): assert json.dumps( req['input']['message'], ensure_ascii=False, - ) == '[{"role": "user", "content": "1. 生成中文注释\n2. 仅生成代码部分,不需要额外解释函数功能\n"}, {"role": "attachment", "meta": {"code": "\t\t@Override\n\t\tpublic CancelExportTaskResponse cancelExportTask(\n\t\t\t\tCancelExportTask cancelExportTask) {\n\t\t\tAmazonEC2SkeletonInterface ec2Service = ServiceProvider.getInstance().getServiceImpl(AmazonEC2SkeletonInterface.class);\n\t\t\treturn ec2Service.cancelExportTask(cancelExportTask);\n\t\t}", "language": "java"}}]'.replace('\t', '\\t').replace('\n', '\\n') # noqa E501 + ) == ( + '[{"role": "user", "content": ' + '"1. 生成中文注释\n2. 仅生成代码部分,不需要额外解释函数功能\n"}, ' + '{"role": "attachment", "meta": {"code": ' + '"\t\t@Override\n\t\tpublic CancelExportTaskResponse ' + 'cancelExportTask(\n\t\t\t\tCancelExportTask ' + 'cancelExportTask) {\n\t\t\tAmazonEC2SkeletonInterface ' + 'ec2Service = ServiceProvider.getInstance().' 
+ 'getServiceImpl(AmazonEC2SkeletonInterface.class);\n\t\t\t' + 'return ec2Service.cancelExportTask(cancelExportTask);\n\t\t}", ' + '"language": "java"}}]' + ).replace('\t', '\\t').replace('\n', '\\n') assert response.status_code == HTTPStatus.OK assert response.request_id == 'b5e55877-bfa3-9863-88d8-09a72124cf8a' - assert response.output['choices'][0][ - 'content' - ] == '```java\n/**\n * 取消导出任务的回调函数\n *\n * @param cancelExportTask 取消导出任务的请求对象\n * @return 取消导出任务的响应对象\n */\n@Override\npublic CancelExportTaskResponse cancelExportTask(CancelExportTask cancelExportTask) {\n\tAmazonEC2SkeletonInterface ec2Service = ServiceProvider.getInstance().getServiceImpl(AmazonEC2SkeletonInterface.class);\n\treturn ec2Service.cancelExportTask(cancelExportTask);\n}\n```' # noqa E501 + assert response.output['choices'][0]['content'] == ( + '```java\n/**\n * 取消导出任务的回调函数\n *\n' + ' * @param cancelExportTask 取消导出任务的请求对象\n' + ' * @return 取消导出任务的响应对象\n */\n@Override\n' + 'public CancelExportTaskResponse cancelExportTask' + '(CancelExportTask cancelExportTask) {\n\t' + 'AmazonEC2SkeletonInterface ec2Service = ' + 'ServiceProvider.getInstance().' + 'getServiceImpl(AmazonEC2SkeletonInterface.class);\n\t' + 'return ec2Service.cancelExportTask(cancelExportTask);\n}\n```' + ) assert response.output['choices'][0]['frame_id'] == 17 assert response.output['choices'][0]['finish_reason'] == 'stop' assert response.usage['output_tokens'] == 133 @@ -247,13 +306,29 @@ def test_code2explain_sample(self, mock_server: MockServer): assert req['input']['scene'] == 'code2explain' assert json.dumps( req['input']['message'], ensure_ascii=False, - ) == '[{"role": "user", "content": "要求不低于200字"}, {"role": "attachment", "meta": {"code": "@Override\n public int getHeaderCacheSize()\n {\n return 0;\n }\n\n", "language": "java"}}]'.replace('\t', '\\t').replace('\n', '\\n') # noqa E501 + ) == ( + '[{"role": "user", "content": "要求不低于200字"}, ' + '{"role": "attachment", "meta": {"code": ' + '"@Override\n ' + 'public int getHeaderCacheSize()\n' + ' {\n' + ' return 0;\n' + ' }\n\n", "language": "java"}}]' + ).replace('\t', '\\t').replace('\n', '\\n') assert response.status_code == HTTPStatus.OK assert response.request_id == '089e525f-d28f-9e08-baa2-01dde87c90a7' - assert response.output['choices'][0][ - 'content' - ] == '这个Java函数是一个覆盖了另一个方法的函数,名为`getHeaderCacheSize()`。这个方法是从另一个已覆盖的方法继承过来的。在`@Override`声明中,可以确定这个函数覆盖了一个其他的函数。这个函数的返回类型是`int`。\n\n函数内容是:返回0。这个值意味着在`getHeaderCacheSize()`方法中,不会进行任何处理或更新。因此,返回的`0`值应该是没有被处理或更新的值。\n\n总的来说,这个函数的作用可能是为了让另一个方法返回一个预设的值。但是由于`@Override`的提示,我们无法确定它的真正目的,需要进一步查看代码才能得到更多的信息。' # noqa E501 + assert response.output['choices'][0]['content'] == ( + '这个Java函数是一个覆盖了另一个方法的函数,' + '名为`getHeaderCacheSize()`。这个方法是从另一个已覆盖的' + '方法继承过来的。在`@Override`声明中,可以确定这个函数' + '覆盖了一个其他的函数。这个函数的返回类型是`int`。\n\n' + '函数内容是:返回0。这个值意味着在`getHeaderCacheSize()`' + '方法中,不会进行任何处理或更新。因此,返回的`0`值应该是' + '没有被处理或更新的值。\n\n总的来说,这个函数的作用可能是为了' + '让另一个方法返回一个预设的值。但是由于`@Override`的提示,' + '我们无法确定它的真正目的,需要进一步查看代码才能得到更多的信息。' + ) assert response.output['choices'][0]['frame_id'] == 30 assert response.output['choices'][0]['finish_reason'] == 'stop' assert response.usage['output_tokens'] == 235 @@ -357,9 +432,17 @@ def test_unittest_sample(self, mock_server: MockServer): assert response.status_code == HTTPStatus.OK assert response.request_id == '6ec31e35-f355-9289-a18d-103abc36dece' - assert response.output['choices'][0][ - 'content' - ] == "这个函数用于解析时间戳映射表的输入字符串并返回该映射表的实例。函数有两个必选参数:typeClass - 用于标识数据类型的泛型;input - 
输入的时间戳映射表字符串。如果typeClass为null,将抛出IllegalArgumentException异常;如果input为null,则返回null。函数内部首先检查输入的字符串是否等于\"空字符串\",如果是,则直接返回null;如果不是,则创建TimestampMap的实例,并使用input字符串创建字符串Reader对象。然后使用读取器逐个字符解析时间戳字符串,并在解析完成后返回相应的TimestampMap对象。函数的行为取决于传入的时间戳字符串类型。" # noqa E501 + assert response.output['choices'][0]['content'] == ( + "这个函数用于解析时间戳映射表的输入字符串并返回该映射表的" + "实例。函数有两个必选参数:typeClass - 用于标识数据类型的" + "泛型;input - 输入的时间戳映射表字符串。如果typeClass为" + "null,将抛出IllegalArgumentException异常;如果input为null," + "则返回null。函数内部首先检查输入的字符串是否等于\"空字符串\"," + "如果是,则直接返回null;如果不是,则创建TimestampMap的实例," + "并使用input字符串创建字符串Reader对象。然后使用读取器逐个" + "字符解析时间戳字符串,并在解析完成后返回相应的TimestampMap" + "对象。函数的行为取决于传入的时间戳字符串类型。" + ) assert response.output['choices'][0]['frame_id'] == 29 assert response.output['choices'][0]['finish_reason'] == 'stop' assert response.usage['output_tokens'] == 227 @@ -404,9 +487,20 @@ def test_codeqa_sample(self, mock_server: MockServer): assert response.status_code == HTTPStatus.OK assert response.request_id == 'e09386b7-5171-96b0-9c6f-7128507e14e6' - assert response.output['choices'][0][ - 'content' - ] == "Yes, this is possible:\nclass MyRequestHandler(BaseHTTPServer.BaseHTTPRequestHandler):\n [...]\n\n def doGET(self):\n # some stuff\n if \"X-Port\" in self.headers:\n # change the port in this request\n self.server_port = int(self.headers[\"X-Port\"])\n print(\"Changed port: %s\" % self.server_port)\n [...]\n\nclass ThreadingHTTPServer(ThreadingMixIn, HTTPServer): \n pass\n\nserver = ThreadingHTTPServer(('localhost', self.server_port), MyRequestHandler)\nserver.serve_forever()" # noqa E501 + assert response.output['choices'][0]['content'] == ( + "Yes, this is possible:\n" + "class MyRequestHandler(BaseHTTPServer.BaseHTTPRequestHandler):\n" + " [...]\n\n def doGET(self):\n # some stuff\n" + " if \"X-Port\" in self.headers:\n" + " # change the port in this request\n" + " self.server_port = int(self.headers[\"X-Port\"])\n" + " print(\"Changed port: %s\" % self.server_port)\n" + " [...]\n\n" + "class ThreadingHTTPServer(ThreadingMixIn, HTTPServer): \n" + " pass\n\n" + "server = ThreadingHTTPServer(('localhost', self.server_port), " + "MyRequestHandler)\nserver.serve_forever()" + ) assert response.output['choices'][0]['frame_id'] == 19 assert response.output['choices'][0]['finish_reason'] == 'stop' assert response.usage['output_tokens'] == 150 @@ -419,8 +513,10 @@ def test_nl2sql_sample(self, mock_server: MockServer): 'finish_reason': 'stop', 'frame_timestamp': 1694701323.4553578, 'index': 0, - 'content': - "SELECT SUM(score) as '小明的总分数' FROM student_score WHERE name = '小明';", + 'content': ( + "SELECT SUM(score) as '小明的总分数' " + "FROM student_score WHERE name = '小明';" + ), 'frame_id': 3, }], }, @@ -478,13 +574,27 @@ def test_nl2sql_sample(self, mock_server: MockServer): req = mock_server.requests.get(block=True) assert req['model'] == model assert req['input']['scene'] == 'nl2sql' - assert json.dumps(req['input']['message'], ensure_ascii=False) == """[{"role": "user", "content": "小明的总分数是多少"}, {"role": "attachment", "meta": {"synonym_infos": {"学生姓名": "姓名|名字|名称", "学生分数": "分数|得分"}, "recall_infos": [{"content": "student_score.id='小明'", "score": "0.83"}], "schema_infos": [{"table_id": "student_score", "table_desc": "学生分数表", "columns": [{"col_name": "id", "col_caption": "学生id", "col_desc": "例值为:1,2,3", "col_type": "string"}, {"col_name": "name", "col_caption": "学生姓名", "col_desc": "例值为:张三,李四,小明", "col_type": "string"}, {"col_name": "score", "col_caption": "学生分数", "col_desc": "例值为:98,100,66", "col_type": "string"}]}]}}]""" # noqa E501 + assert 
json.dumps(req['input']['message'], ensure_ascii=False) == ( + '[{"role": "user", "content": "小明的总分数是多少"}, ' + '{"role": "attachment", "meta": {"synonym_infos": ' + '{"学生姓名": "姓名|名字|名称", "学生分数": "分数|得分"}, ' + '"recall_infos": [{"content": "student_score.id=\'小明\'", ' + '"score": "0.83"}], "schema_infos": [{"table_id": ' + '"student_score", "table_desc": "学生分数表", "columns": ' + '[{"col_name": "id", "col_caption": "学生id", ' + '"col_desc": "例值为:1,2,3", "col_type": "string"}, ' + '{"col_name": "name", "col_caption": "学生姓名", ' + '"col_desc": "例值为:张三,李四,小明", "col_type": "string"}, ' + '{"col_name": "score", "col_caption": "学生分数", ' + '"col_desc": "例值为:98,100,66", "col_type": "string"}]}]}}]' + ) assert response.status_code == HTTPStatus.OK assert response.request_id == 'e61a35b7-db6f-90c2-8677-9620ffea63b6' - assert response.output['choices'][0][ - 'content' - ] == "SELECT SUM(score) as '小明的总分数' FROM student_score WHERE name = '小明';" + assert response.output['choices'][0]['content'] == ( + "SELECT SUM(score) as '小明的总分数' FROM student_score " + "WHERE name = '小明';" + ) assert response.output['choices'][0]['frame_id'] == 3 assert response.output['choices'][0]['finish_reason'] == 'stop' assert response.usage['output_tokens'] == 25 diff --git a/tests/legacy/test_conversation.py b/tests/unit/test_conversation.py similarity index 88% rename from tests/legacy/test_conversation.py rename to tests/unit/test_conversation.py index 7b25384..52c734d 100644 --- a/tests/legacy/test_conversation.py +++ b/tests/unit/test_conversation.py @@ -65,7 +65,7 @@ def test_message_request(self, mock_server: MockServer): messages = [{"role": "user", "content": prompt}] resp = Generation.call( model=model, - messages=messages, + messages=messages, # type: ignore[arg-type] max_tokens=1024, api_protocol="http", result_format="message", @@ -81,7 +81,7 @@ def test_message_request(self, mock_server: MockServer): assert resp.output.text is None assert resp.output.choices[0] == Choice( finish_reason="stop", - message={ + message={ # type: ignore[arg-type] "role": "assistant", "content": "hello world", }, @@ -152,9 +152,12 @@ def test_conversation_with_message_and_prompt( ) assert response.status_code == HTTPStatus.OK assert response.output.text is None - choices = TestConversationRequest.message_response_obj["output"][ - "choices" - ] + # Accessing dict key in test mock data + msg_resp_obj = TestConversationRequest.message_response_obj + # type: ignore[index] + choices = msg_resp_obj["output"]["choices"] # type: ignore[index] + # Comparing with mock data + # type: ignore[comparison-overlap] assert response.output.choices == choices req = mock_server.requests.get(block=True) assert req["model"] == model @@ -176,9 +179,12 @@ def test_conversation_with_messages(self, mock_server: MockServer): ) assert response.status_code == HTTPStatus.OK assert response.output.text is None - choices = TestConversationRequest.message_response_obj["output"][ - "choices" - ] + # Accessing dict key in test mock data + msg_resp_obj = TestConversationRequest.message_response_obj + # type: ignore[index] + choices = msg_resp_obj["output"]["choices"] # type: ignore[index] + # Comparing with mock data + # type: ignore[comparison-overlap] assert response.output.choices == choices req = mock_server.requests.get(block=True) assert req["model"] == model @@ -200,9 +206,12 @@ def test_conversation_call_with_messages(self, mock_server: MockServer): ) assert response.status_code == HTTPStatus.OK assert response.output.text is None - choices = 
TestConversationRequest.message_response_obj["output"][ - "choices" - ] + # Accessing dict key in test mock data + msg_resp_obj = TestConversationRequest.message_response_obj + # type: ignore[index] + choices = msg_resp_obj["output"]["choices"] # type: ignore[index] + # Comparing with mock data + # type: ignore[comparison-overlap] assert response.output.choices == choices req = mock_server.requests.get(block=True) assert req["model"] == model From 6ea59a601046889a91ba099a4413297fc1187ee6 Mon Sep 17 00:00:00 2001 From: Kevin Lin Date: Thu, 22 Jan 2026 10:02:49 +0800 Subject: [PATCH 2/2] feat(ci): add unit tests --- tests/legacy/test_websocket_async_api.py | 2 +- tests/legacy/test_websocket_parameters.py | 2 +- tests/legacy/test_websocket_sync_api.py | 2 +- tests/unit/mock_server.py | 2 +- .../test_http_deployments_api.py | 1 + tests/{legacy => unit}/test_http_files_api.py | 1 + .../test_http_fine_tunes_api.py | 42 ++++++++--- .../{legacy => unit}/test_http_models_api.py | 1 + tests/{legacy => unit}/test_messages.py | 55 ++++++++++----- .../test_multimodal_dialog.py | 47 +++++++------ tests/{legacy => unit}/test_rerank.py | 0 tests/{legacy => unit}/test_runs.py | 70 +++++++++++++------ .../test_sketch_image_synthesis.py | 16 ++++- .../test_speech_synthesis_v2.py | 4 +- tests/{legacy => unit}/test_text_embedding.py | 3 +- tests/{legacy => unit}/test_threads.py | 6 +- tests/{legacy => unit}/test_tokenization.py | 6 +- tests/{legacy => unit}/test_tokenizer.py | 0 .../test_translation_recognizer.py | 42 ++++++----- tests/{legacy => unit}/test_understanding.py | 0 .../websocket_mock_server_task_handler.py | 59 ++++++++-------- .../websocket_task_request.py | 9 +-- 22 files changed, 234 insertions(+), 136 deletions(-) rename tests/{legacy => unit}/test_http_deployments_api.py (97%) rename tests/{legacy => unit}/test_http_files_api.py (97%) rename tests/{legacy => unit}/test_http_fine_tunes_api.py (77%) rename tests/{legacy => unit}/test_http_models_api.py (94%) rename tests/{legacy => unit}/test_messages.py (73%) rename tests/{legacy => unit}/test_multimodal_dialog.py (80%) rename tests/{legacy => unit}/test_rerank.py (100%) rename tests/{legacy => unit}/test_runs.py (83%) rename tests/{legacy => unit}/test_sketch_image_synthesis.py (79%) rename tests/{legacy => unit}/test_speech_synthesis_v2.py (95%) rename tests/{legacy => unit}/test_text_embedding.py (91%) rename tests/{legacy => unit}/test_threads.py (95%) rename tests/{legacy => unit}/test_tokenization.py (84%) rename tests/{legacy => unit}/test_tokenizer.py (100%) rename tests/{legacy => unit}/test_translation_recognizer.py (71%) rename tests/{legacy => unit}/test_understanding.py (100%) rename tests/{legacy => unit}/websocket_mock_server_task_handler.py (85%) rename tests/{legacy => unit}/websocket_task_request.py (90%) diff --git a/tests/legacy/test_websocket_async_api.py b/tests/legacy/test_websocket_async_api.py index bf87b49..00a6a97 100644 --- a/tests/legacy/test_websocket_async_api.py +++ b/tests/legacy/test_websocket_async_api.py @@ -7,7 +7,7 @@ from dashscope.protocol.websocket import WebsocketStreamingMode from tests.unit.base_test import BaseTestEnvironment from tests.unit.constants import TestTasks -from tests.legacy.websocket_task_request import WebSocketRequest +from tests.unit.websocket_task_request import WebSocketRequest # set mock server url. 
base_websocket_api_url = "ws://localhost:8080/ws/aigc/v1" diff --git a/tests/legacy/test_websocket_parameters.py b/tests/legacy/test_websocket_parameters.py index d21403c..2211ec5 100644 --- a/tests/legacy/test_websocket_parameters.py +++ b/tests/legacy/test_websocket_parameters.py @@ -9,7 +9,7 @@ TEST_ENABLE_DATA_INSPECTION_REQUEST_ID, TestTasks, ) -from tests.legacy.websocket_task_request import WebSocketRequest +from tests.unit.websocket_task_request import WebSocketRequest def pytest_generate_tests(metafunc): diff --git a/tests/legacy/test_websocket_sync_api.py b/tests/legacy/test_websocket_sync_api.py index 67900ea..46dede1 100644 --- a/tests/legacy/test_websocket_sync_api.py +++ b/tests/legacy/test_websocket_sync_api.py @@ -5,7 +5,7 @@ from dashscope.protocol.websocket import WebsocketStreamingMode from tests.unit.base_test import BaseTestEnvironment from tests.unit.constants import TestTasks -from tests.legacy.websocket_task_request import WebSocketRequest +from tests.unit.websocket_task_request import WebSocketRequest def pytest_generate_tests(metafunc): diff --git a/tests/unit/mock_server.py b/tests/unit/mock_server.py index 559769d..a87dc37 100644 --- a/tests/unit/mock_server.py +++ b/tests/unit/mock_server.py @@ -33,7 +33,7 @@ list_fine_tune_handler, ) from tests.unit.mock_sse import sse_response -from tests.legacy.websocket_mock_server_task_handler import ( +from tests.unit.websocket_mock_server_task_handler import ( WebSocketTaskProcessor, ) diff --git a/tests/legacy/test_http_deployments_api.py b/tests/unit/test_http_deployments_api.py similarity index 97% rename from tests/legacy/test_http_deployments_api.py rename to tests/unit/test_http_deployments_api.py index 58472bd..ab50aff 100644 --- a/tests/legacy/test_http_deployments_api.py +++ b/tests/unit/test_http_deployments_api.py @@ -9,6 +9,7 @@ class TestDeploymentRequest(MockRequestBase): + # pylint: disable=unused-argument def test_create_deployment_tune_job(self, http_server): resp = Deployments.call( model="gpt", diff --git a/tests/legacy/test_http_files_api.py b/tests/unit/test_http_files_api.py similarity index 97% rename from tests/legacy/test_http_files_api.py rename to tests/unit/test_http_files_api.py index d02ed4d..950ff95 100644 --- a/tests/legacy/test_http_files_api.py +++ b/tests/unit/test_http_files_api.py @@ -8,6 +8,7 @@ class TestFileRequest(MockRequestBase): + # pylint: disable=unused-argument def test_upload_files(self, http_server): resp = Files.upload( file_path="tests/data/dogs.jpg", diff --git a/tests/legacy/test_http_fine_tunes_api.py b/tests/unit/test_http_fine_tunes_api.py similarity index 77% rename from tests/legacy/test_http_fine_tunes_api.py rename to tests/unit/test_http_fine_tunes_api.py index 843e6d7..3ea7e57 100644 --- a/tests/legacy/test_http_fine_tunes_api.py +++ b/tests/unit/test_http_fine_tunes_api.py @@ -18,13 +18,16 @@ class TestFineTuneRequest(MockServerBase): @classmethod def setup_class(cls): + # pylint: disable=consider-using-with cls.case_data = json.load( open('tests/data/fine_tune.json', 'r', encoding='utf-8'), ) super().setup_class() def test_create_fine_tune_job(self, mock_server: MockServer): - response_body = self.case_data['create_response'] + # type: ignore[index] + # pylint: disable=unsubscriptable-object + response_body = self.case_data['create_response'] # type: ignore mock_server.responses.put(json.dumps(response_body)) model = 'gpt' training_file_ids = 'training_001' @@ -47,10 +50,19 @@ def test_create_fine_tune_job(self, mock_server: MockServer): assert 
req['body']['hyper_parameters'] == hyper_parameters assert resp.output.job_id == response_body['output']['job_id'] assert resp.output.status == response_body['output']['status'] - assert resp.output.hyper_parameters == {'learning_rate': '2e-5', 'n_epochs': 10, 'batch_size': 32} + expected_hyper_params = { + 'learning_rate': '2e-5', + 'n_epochs': 10, + 'batch_size': 32, + } + assert resp.output.hyper_parameters == expected_hyper_params def test_create_fine_tune_job_with_files(self, mock_server: MockServer): - response_body = self.case_data['create_multi_files_response'] + # type: ignore[index] + # pylint: disable=unsubscriptable-object + response_body = self.case_data[ # type: ignore + 'create_multi_files_response' + ] mock_server.responses.put(json.dumps(response_body)) model = 'gpt' training_file_ids = ['training_001', 'training_002'] @@ -75,10 +87,13 @@ def test_create_fine_tune_job_with_files(self, mock_server: MockServer): assert resp.output.status == response_body['output']['status'] assert resp.output.training_file_ids == training_file_ids assert resp.output.validation_file_ids == validation_file_ids - assert resp.output.hyper_parameters == response_body['output']['hyper_parameters'] + expected_hyper_params = response_body['output']['hyper_parameters'] + assert resp.output.hyper_parameters == expected_hyper_params def test_list_fine_tune_job(self, mock_server: MockServer): - response_body = self.case_data['list_response'] + # type: ignore[index] + # pylint: disable=unsubscriptable-object + response_body = self.case_data['list_response'] # type: ignore mock_server.responses.put(json.dumps(response_body)) response = FineTunes.list( page_no=10, @@ -90,7 +105,9 @@ def test_list_fine_tune_job(self, mock_server: MockServer): assert response.output.jobs[0].job_id == 'ft-202403261454-d8b4' def test_get_fine_tune_job(self, mock_server: MockServer): - response_body = self.case_data['query_response'] + # type: ignore[index] + # pylint: disable=unsubscriptable-object + response_body = self.case_data['query_response'] # type: ignore mock_server.responses.put(json.dumps(response_body)) job_id = str(uuid.uuid4()) response = FineTunes.get(job_id=job_id) @@ -100,7 +117,11 @@ def test_get_fine_tune_job(self, mock_server: MockServer): def test_delete_fine_tune_job(self, mock_server: MockServer): request_id = str(uuid.uuid4()) - response_body = '{"output": {"status": "success"}, "request_id": "%s", "code": null, "message": "", "usage": null}' % request_id # noqa E501 + response_body = ( + '{"output": {"status": "success"}, ' + f'"request_id": "{request_id}", ' + '"code": null, "message": "", "usage": null}' + ) mock_server.responses.put(response_body) rsp = FineTunes.delete(TEST_JOB_ID) req = mock_server.requests.get(block=True) @@ -110,7 +131,11 @@ def test_delete_fine_tune_job(self, mock_server: MockServer): def test_cancel_fine_tune_job(self, mock_server: MockServer): request_id = str(uuid.uuid4()) - response_body = '{"output": {"status": "success"}, "request_id": "%s", "code": null, "message": "", "usage": null}' % request_id # noqa E501 + response_body = ( + '{"output": {"status": "success"}, ' + f'"request_id": "{request_id}", ' + '"code": null, "message": "", "usage": null}' + ) mock_server.responses.put(response_body) rsp = FineTunes.cancel(TEST_JOB_ID) req = mock_server.requests.get(block=True) @@ -118,6 +143,7 @@ def test_cancel_fine_tune_job(self, mock_server: MockServer): assert rsp.status_code == HTTPStatus.OK assert rsp.request_id == request_id + # pylint: disable=unused-argument def 
test_stream_event(self, mock_server: MockServer): responses = FineTunes.stream_events(TEST_JOB_ID) idx = 0 diff --git a/tests/legacy/test_http_models_api.py b/tests/unit/test_http_models_api.py similarity index 94% rename from tests/legacy/test_http_models_api.py rename to tests/unit/test_http_models_api.py index 2345c52..53a01ff 100644 --- a/tests/legacy/test_http_models_api.py +++ b/tests/unit/test_http_models_api.py @@ -9,6 +9,7 @@ class TestModelRequest(MockRequestBase): + # pylint: disable=unused-argument def test_list_models(self, http_server): rsp = Models.list() assert rsp.status_code == HTTPStatus.OK diff --git a/tests/legacy/test_messages.py b/tests/unit/test_messages.py similarity index 73% rename from tests/legacy/test_messages.py rename to tests/unit/test_messages.py index b1ac74f..5307f3b 100644 --- a/tests/legacy/test_messages.py +++ b/tests/unit/test_messages.py @@ -17,14 +17,20 @@ class TestMessages(MockServerBase): @classmethod def setup_class(cls): + # pylint: disable=consider-using-with cls.case_data = json.load( open("tests/data/messages.json", "r", encoding="utf-8"), ) super().setup_class() def test_create(self, mock_server: MockServer): - request_body = self.case_data["create_message_request"] - response_body = self.case_data["create_message_response"] + # type: ignore[index] + # pylint: disable=unsubscriptable-object + request_body = self.case_data["create_message_request"] # type: ignore + # pylint: disable=unsubscriptable-object + response_body = self.case_data[ # type: ignore + "create_message_response" + ] mock_server.responses.put(json.dumps(response_body)) response = Messages.create(**request_body) req = mock_server.requests.get(block=True) @@ -34,7 +40,11 @@ def test_create(self, mock_server: MockServer): assert len(response.content) == 1 def test_update(self, mock_server: MockServer): - response_body = self.case_data["create_message_response"] + # type: ignore[index] + # pylint: disable=unsubscriptable-object + response_body = self.case_data[ # type: ignore + "create_message_response" + ] mock_server.responses.put(json.dumps(response_body)) thread_id = str(uuid.uuid4()) message_id = str(uuid.uuid4()) @@ -55,7 +65,11 @@ def test_update(self, mock_server: MockServer): assert len(response.content) == 1 def test_retrieve(self, mock_server: MockServer): - response_obj = self.case_data["create_message_response"] + # type: ignore[index] + # pylint: disable=unsubscriptable-object + response_obj = self.case_data[ # type: ignore + "create_message_response" + ] response_str = json.dumps(response_obj) mock_server.responses.put(response_str) thread_id = "tid" @@ -68,7 +82,9 @@ def test_retrieve(self, mock_server: MockServer): assert len(response.content) == 1 def test_list(self, mock_server: MockServer): - response_obj = self.case_data["list_message_response"] + # type: ignore[index] + # pylint: disable=unsubscriptable-object + response_obj = self.case_data["list_message_response"] # type: ignore mock_server.responses.put(json.dumps(response_obj)) thread_id = "test_thread_id" response = Messages.list( @@ -80,16 +96,21 @@ def test_list(self, mock_server: MockServer): ) # get assistant id we send. req = mock_server.requests.get(block=True) - assert ( - req - == f"/api/v1/threads/{thread_id}/messages?limit=10&order=inc&after=after&before=before" + expected_path = ( + f"/api/v1/threads/{thread_id}/messages?" 
+ "limit=10&order=inc&after=after&before=before" ) + assert req == expected_path assert len(response.data) == 2 assert response.data[0].id == "msg_1" assert response.data[1].id == "msg_0" def test_list_message_files(self, mock_server: MockServer): - response_obj = self.case_data["list_message_files_response"] + # type: ignore[index] + # pylint: disable=unsubscriptable-object + response_obj = self.case_data[ # type: ignore + "list_message_files_response" + ] mock_server.responses.put(json.dumps(response_obj)) thread_id = "test_thread_id" message_id = "test_message_id" @@ -103,10 +124,11 @@ def test_list_message_files(self, mock_server: MockServer): ) # get assistant id we send. req = mock_server.requests.get(block=True) - assert ( - req - == f"/api/v1/threads/{thread_id}/messages/{message_id}/files?limit=10&order=inc&after=after&before=before" - ) # noqa E501 + expected_path = ( + f"/api/v1/threads/{thread_id}/messages/{message_id}/files?" + "limit=10&order=inc&after=after&before=before" + ) + assert req == expected_path assert len(response.data) == 2 assert response.data[0].id == "file-1" assert response.data[1].id == "file-2" @@ -133,9 +155,10 @@ def test_retrieve_message_file(self, mock_server: MockServer): ) # get assistant id we send. req = mock_server.requests.get(block=True) - assert ( - req - == f"/api/v1/threads/{thread_id}/messages/{message_id}/files/{file_id}" + expected_path = ( + f"/api/v1/threads/{thread_id}/messages/{message_id}/" + f"files/{file_id}" ) + assert req == expected_path assert response.id == file_id assert response.message_id == message_id diff --git a/tests/legacy/test_multimodal_dialog.py b/tests/unit/test_multimodal_dialog.py similarity index 80% rename from tests/legacy/test_multimodal_dialog.py rename to tests/unit/test_multimodal_dialog.py index 7be5d87..ceaedb5 100644 --- a/tests/legacy/test_multimodal_dialog.py +++ b/tests/unit/test_multimodal_dialog.py @@ -1,10 +1,10 @@ # -*- coding: utf-8 -*- # Copyright (c) Alibaba, Inc. and its affiliates. 
-import sys -import pytest import logging +import sys import time -from dashscope.common.logging import logger + +import pytest from dashscope.multimodal.dialog_state import DialogState from dashscope.multimodal.multimodal_dialog import ( MultiModalDialog, @@ -21,9 +21,13 @@ logger = logging.getLogger("dashscope") logger.setLevel(logging.DEBUG) -# create console handler and set level to debug console_handler = logging.StreamHandler() +# Global variable for dialog ID +g_dialog_id = None +# Global variable for conversation instance +conver_instance = None + # 定义voice chat服务回调 class TestCallback(MultiModalCallback): @@ -36,20 +40,22 @@ def on_started(self, dialog_id): def on_stopped(self): logger.info("stopped with server.") - pass def on_state_changed(self, state: DialogState): if state == DialogState.LISTENING: # app.update_text("开始收音。。。请提问") - pass - elif state == DialogState.THINKING: - pass + return + if state == DialogState.THINKING: # app.update_text("思考中。。。请耐心等待") - elif state == DialogState.RESPONDING: - pass + return + if state == DialogState.RESPONDING: # app.update_text("正在回答。。。") + return - def on_speech_audio_data(self, data: bytes): + def on_speech_audio_data( + self, + data: bytes, + ): # pylint: disable=unused-argument # pcm_play.play(data) return @@ -60,22 +66,22 @@ def on_error(self, error): def on_responding_started(self): # 开始端侧播放 # pcm_play.start_play() - global conver_instance - conver_instance.send_local_responding_started() - return + global conver_instance # pylint: disable=global-variable-not-assigned + if conver_instance is not None: + conver_instance.send_local_responding_started() - def on_responding_ended(self, payload): + def on_responding_ended(self, payload): # pylint: disable=unused-argument logger.debug("on responding ended") - conver_instance.send_local_responding_ended() + global conver_instance # pylint: disable=global-variable-not-assigned + if conver_instance is not None: + conver_instance.send_local_responding_ended() # pcm_play.stop_play() def on_speech_content(self, payload): - pass if payload is not None: logger.debug(payload) def on_responding_content(self, payload): - pass if payload is not None: logger.debug(payload) @@ -86,8 +92,9 @@ def on_request_accepted(self): def on_close(self, close_status_code, close_msg): logger.info( - "close with status code: %d, msg: %s" - % (close_status_code, close_msg), + "close with status code: %d, msg: %s", + close_status_code, + close_msg, ) diff --git a/tests/legacy/test_rerank.py b/tests/unit/test_rerank.py similarity index 100% rename from tests/legacy/test_rerank.py rename to tests/unit/test_rerank.py diff --git a/tests/legacy/test_runs.py b/tests/unit/test_runs.py similarity index 83% rename from tests/legacy/test_runs.py rename to tests/unit/test_runs.py index e7ba9da..cfba03d 100644 --- a/tests/legacy/test_runs.py +++ b/tests/unit/test_runs.py @@ -16,13 +16,16 @@ class TestRuns(MockServerBase): @classmethod def setup_class(cls): + # pylint: disable=consider-using-with cls.case_data = json.load( open("tests/data/runs.json", "r", encoding="utf-8"), ) super().setup_class() def test_create_simple(self, mock_server: MockServer): - response_body = self.case_data["create_run_response"] + # type: ignore[index] + # pylint: disable=unsubscriptable-object + response_body = self.case_data["create_run_response"] # type: ignore mock_server.responses.put(json.dumps(response_body)) thread_id = str(uuid.uuid4()) assistant_id = str(uuid.uuid4()) @@ -33,7 +36,9 @@ def test_create_simple(self, mock_server: MockServer): 
assert response.metadata == {"key": "value"} def test_create_complicated(self, mock_server: MockServer): - response_body = self.case_data["create_run_response"] + # type: ignore[index] + # pylint: disable=unsubscriptable-object + response_body = self.case_data["create_run_response"] # type: ignore mock_server.responses.put(json.dumps(response_body)) thread_id = str(uuid.uuid4()) assistant_id = str(uuid.uuid4()) @@ -76,7 +81,7 @@ def test_create_complicated(self, mock_server: MockServer): model=model_name, instructions=instructions, additional_instructions=additional_instructions, - tools=tools, + tools=tools, # type: ignore[arg-type] metadata=metadata, ) req = mock_server.requests.get(block=True) @@ -91,7 +96,9 @@ def test_create_complicated(self, mock_server: MockServer): assert response.tools[0].type == "code_interpreter" def test_retrieve(self, mock_server: MockServer): - response_obj = self.case_data["create_run_response"] + # type: ignore[index] + # pylint: disable=unsubscriptable-object + response_obj = self.case_data["create_run_response"] # type: ignore response_str = json.dumps(response_obj) mock_server.responses.put(response_str) thread_id = "tid" @@ -104,7 +111,9 @@ def test_retrieve(self, mock_server: MockServer): assert response.tools[0].type == "code_interpreter" def test_list(self, mock_server: MockServer): - response_obj = self.case_data["list_run_response"] + # type: ignore[index] + # pylint: disable=unsubscriptable-object + response_obj = self.case_data["list_run_response"] # type: ignore mock_server.responses.put(json.dumps(response_obj)) thread_id = "test_thread_id" response = Runs.list( @@ -116,16 +125,19 @@ def test_list(self, mock_server: MockServer): ) # get assistant id we send. req = mock_server.requests.get(block=True) - assert ( - req - == f"/api/v1/threads/{thread_id}/runs?limit=10&order=inc&after=after&before=before" + expected_path = ( + f"/api/v1/threads/{thread_id}/runs?" 
+ "limit=10&order=inc&after=after&before=before" ) + assert req == expected_path assert len(response.data) == 1 assert response.data[0].id == "1" assert response.data[0].tools[2].type == "function" def test_create_thread_and_run(self, mock_server: MockServer): - response_body = self.case_data["create_run_response"] + # type: ignore[index] + # pylint: disable=unsubscriptable-object + response_body = self.case_data["create_run_response"] # type: ignore mock_server.responses.put(json.dumps(response_body)) assistant_id = str(uuid.uuid4()) model_name = str(uuid.uuid4()) @@ -179,7 +191,7 @@ def test_create_thread_and_run(self, mock_server: MockServer): model=model_name, instructions=instructions, additional_instructions=additional_instructions, - tools=tools, + tools=tools, # type: ignore[arg-type] metadata=metadata, ) req = mock_server.requests.get(block=True) @@ -195,7 +207,11 @@ def test_create_thread_and_run(self, mock_server: MockServer): assert response.tools[0].type == "code_interpreter" def test_submit_tool_outputs(self, mock_server: MockServer): - response_body = self.case_data["submit_function_call_result"] + # type: ignore[index] + # pylint: disable=unsubscriptable-object + response_body = self.case_data[ # type: ignore + "submit_function_call_result" + ] mock_server.responses.put(json.dumps(response_body)) thread_id = str(uuid.uuid4()) run_id = str(uuid.uuid4()) @@ -217,7 +233,11 @@ def test_submit_tool_outputs(self, mock_server: MockServer): assert response.tools[0].type == "code_interpreter" def test_run_required_function_call(self, mock_server: MockServer): - response_obj = self.case_data["required_action_function_call_response"] + # type: ignore[index] + # pylint: disable=unsubscriptable-object + response_obj = self.case_data[ # type: ignore + "required_action_function_call_response" + ] mock_server.responses.put(json.dumps(response_obj)) thread_id = str(uuid.uuid4()) assistant_id = str(uuid.uuid4()) @@ -238,7 +258,11 @@ def test_run_required_function_call(self, mock_server: MockServer): ) def test_list_run_steps(self, mock_server: MockServer): - response_obj = self.case_data["list_run_steps_response"] + # type: ignore[index] + # pylint: disable=unsubscriptable-object + response_obj = self.case_data[ # type: ignore + "list_run_steps_response" + ] mock_server.responses.put(json.dumps(response_obj)) thread_id = "test_thread_id" run_id = str(uuid.uuid4()) @@ -252,10 +276,11 @@ def test_list_run_steps(self, mock_server: MockServer): ) # get assistant id we send. req = mock_server.requests.get(block=True) - assert ( - req - == f"/api/v1/threads/{thread_id}/runs/{run_id}/steps?limit=10&order=inc&after=after&before=before" + expected_path = ( + f"/api/v1/threads/{thread_id}/runs/{run_id}/steps?" + "limit=10&order=inc&after=after&before=before" ) + assert req == expected_path assert len(response.data) == 2 assert response.data[0].id == "step_1" assert response.data[0].step_details.type == "message_creation" @@ -284,7 +309,9 @@ def test_list_run_steps(self, mock_server: MockServer): ) def test_retrieve_run_steps(self, mock_server: MockServer): - response_obj = self.case_data["retrieve_run_step"] + # type: ignore[index] + # pylint: disable=unsubscriptable-object + response_obj = self.case_data["retrieve_run_step"] # type: ignore mock_server.responses.put(json.dumps(response_obj)) thread_id = str(uuid.uuid4()) run_id = str(uuid.uuid4()) @@ -300,9 +327,10 @@ def test_retrieve_run_steps(self, mock_server: MockServer): ) # get assistant id we send. 
         req = mock_server.requests.get(block=True)
-        assert (
-            req == f"/api/v1/threads/{thread_id}/runs/{run_id}/steps/{step_id}"
+        expected_path = (
+            f"/api/v1/threads/{thread_id}/runs/{run_id}/steps/{step_id}"
         )
+        assert req == expected_path
         assert response.id == "step_1"
         assert response.step_details.type == "tool_calls"

@@ -316,7 +344,9 @@
         assert response.usage.completion_tokens == 22

     def test_cancel(self, mock_server: MockServer):
-        response_obj = self.case_data["retrieve_run_step"]
+        # type: ignore[index]
+        # pylint: disable=unsubscriptable-object
+        response_obj = self.case_data["retrieve_run_step"]  # type: ignore
         mock_server.responses.put(json.dumps(response_obj))
         thread_id = str(uuid.uuid4())
         run_id = str(uuid.uuid4())
diff --git a/tests/legacy/test_sketch_image_synthesis.py b/tests/unit/test_sketch_image_synthesis.py
similarity index 79%
rename from tests/legacy/test_sketch_image_synthesis.py
rename to tests/unit/test_sketch_image_synthesis.py
index 47400df..7046207 100644
--- a/tests/legacy/test_sketch_image_synthesis.py
+++ b/tests/unit/test_sketch_image_synthesis.py
@@ -43,7 +43,13 @@ def test_with_all_parameters(self, mock_server: MockServer):
             realisticness=9,
         )
         req = mock_server.requests.get(block=True)
-        expect_req_str = '{"model": "wanx-sketch-to-image-v1", "parameters": {"n": 4, "size": "1024*1024", "sketch_weight": 8, "realisticness": 9}, "input": {"prompt": "hello", "sketch_image_url": "http://sketch_url"}}'  # noqa E501
+        expect_req_str = (
+            '{"model": "wanx-sketch-to-image-v1", '
+            '"parameters": {"n": 4, "size": "1024*1024", '
+            '"sketch_weight": 8, "realisticness": 9}, '
+            '"input": {"prompt": "hello", '
+            '"sketch_image_url": "http://sketch_url"}}'
+        )
         expect_req = json.loads(expect_req_str)

         assert expect_req == req
@@ -63,7 +69,13 @@ def test_with_not_all_parameters(self, mock_server: MockServer):
             realisticness=9,
         )
         req = mock_server.requests.get(block=True)
-        expect_req_str = '{"model": "wanx-sketch-to-image-v1", "parameters": {"n": 4, "size": "1024*1024", "realisticness": 9}, "input": {"prompt": "hello", "sketch_image_url": "http://sketch_url"}}'  # noqa E501
+        expect_req_str = (
+            '{"model": "wanx-sketch-to-image-v1", '
+            '"parameters": {"n": 4, "size": "1024*1024", '
+            '"realisticness": 9}, '
+            '"input": {"prompt": "hello", '
+            '"sketch_image_url": "http://sketch_url"}}'
+        )
         expect_req = json.loads(expect_req_str)

         assert expect_req == req
diff --git a/tests/legacy/test_speech_synthesis_v2.py b/tests/unit/test_speech_synthesis_v2.py
similarity index 95%
rename from tests/legacy/test_speech_synthesis_v2.py
rename to tests/unit/test_speech_synthesis_v2.py
index 31395fc..465cfd9 100644
--- a/tests/legacy/test_speech_synthesis_v2.py
+++ b/tests/unit/test_speech_synthesis_v2.py
@@ -26,7 +26,7 @@ def on_event(self, message):

     def on_data(self, data: bytes) -> None:
         # save audio to file
-        print("recv speech audio {}".format(len(data)))
+        print(f"recv speech audio {len(data)}")


 class TestSynthesis(BaseTestEnvironment):
@@ -59,7 +59,7 @@ def test_sync_call_with_multi_formats(self):
             url=self.url,
         )
         audio = synthesizer.call(self.text_array[0])
-        print("recv audio length {}".format(len(audio)))
+        print(f"recv audio length {len(audio)}")

     @pytest.mark.skip
     def test_sync_streaming_call_with_multi_formats(self):
diff --git a/tests/legacy/test_text_embedding.py b/tests/unit/test_text_embedding.py
similarity index 91%
rename from tests/legacy/test_text_embedding.py
rename to tests/unit/test_text_embedding.py
index 964aabf..9aaf76a 100644
--- a/tests/legacy/test_text_embedding.py
+++ b/tests/unit/test_text_embedding.py
@@ -8,6 +8,7 @@


 class TestTextEmbeddingRequest(MockRequestBase):
+    # pylint: disable=unused-argument
     def test_call_with_string(self, http_server):
         resp = TextEmbedding.call(
             model=TextEmbedding.Models.text_embedding_v3,
@@ -25,7 +26,7 @@ def test_call_with_list_str(self, http_server):
         assert len(resp.output["embeddings"]) == 1

     def test_call_with_opened_file(self, http_server):
-        with open("tests/data/multi_line.txt") as f:
+        with open("tests/data/multi_line.txt", encoding="utf-8") as f:
             response = TextEmbedding.call(
                 model=TextEmbedding.Models.text_embedding_v3,
                 input=f,
diff --git a/tests/legacy/test_threads.py b/tests/unit/test_threads.py
similarity index 95%
rename from tests/legacy/test_threads.py
rename to tests/unit/test_threads.py
index 0a8f092..d0a4f56 100644
--- a/tests/legacy/test_threads.py
+++ b/tests/unit/test_threads.py
@@ -25,7 +25,7 @@ def test_create_with_no_messages(self, mock_server: MockServer):
         req = mock_server.requests.get(block=True)
         assert response.id == thread_id
         assert response.metadata == metadata
-        req["metadata"] == metadata
+        assert req["metadata"] == metadata

     def test_create_with_messages(self, mock_server: MockServer):
         thread_id = str(uuid.uuid4())
@@ -49,12 +49,12 @@ def test_create_with_messages(self, mock_server: MockServer):
                 "content": "画幅画",
             },
         ]
-        thread = Threads.create(messages=messages)
+        thread = Threads.create(messages=messages)  # type: ignore[arg-type]
         assert thread.id == thread_id
         assert thread.metadata == metadata

         req = mock_server.requests.get(block=True)
-        req["messages"] == messages
+        assert req["messages"] == messages

     def test_retrieve(self, mock_server: MockServer):
         thread_id = str(uuid.uuid4())
diff --git a/tests/legacy/test_tokenization.py b/tests/unit/test_tokenization.py
similarity index 84%
rename from tests/legacy/test_tokenization.py
rename to tests/unit/test_tokenization.py
index 8fe1739..7728fe7 100644
--- a/tests/legacy/test_tokenization.py
+++ b/tests/unit/test_tokenization.py
@@ -14,7 +14,11 @@ class TestTokenization(MockServerBase):
         "output": {
             "token_ids": [115798, 198],
             "tokens": ["<|im_start|>", "\n"],
-            "prompt": "<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n<|im_start|>user\n如何做土豆炖猪脚?<|im_end|>\n<|im_start|>assistant\n",  # noqa E501
+            "prompt": (
+                "<|im_start|>system\nYou are a helpful assistant."
+                "<|im_end|>\n<|im_start|>user\n如何做土豆炖猪脚?"
+ "<|im_end|>\n<|im_start|>assistant\n" + ), }, "usage": { "input_tokens": 28, diff --git a/tests/legacy/test_tokenizer.py b/tests/unit/test_tokenizer.py similarity index 100% rename from tests/legacy/test_tokenizer.py rename to tests/unit/test_tokenizer.py diff --git a/tests/legacy/test_translation_recognizer.py b/tests/unit/test_translation_recognizer.py similarity index 71% rename from tests/legacy/test_translation_recognizer.py rename to tests/unit/test_translation_recognizer.py index 8ed639e..5f9782f 100644 --- a/tests/legacy/test_translation_recognizer.py +++ b/tests/unit/test_translation_recognizer.py @@ -29,6 +29,7 @@ def on_open(self) -> None: def on_close(self) -> None: print(f"[{self.tag}] TranslationRecognizerCallback close.") + # pylint: disable=unused-argument def on_event( self, request_id, @@ -38,7 +39,7 @@ def on_event( ) -> None: if translation_result is not None: translation = translation_result.get_translation("en") - # print(f'[{self.tag}]RecognitionCallback text: ', sentence['text']) partial recognition result + # partial recognition result if translation.is_sentence_end: self.translate_text = self.translate_text + translation.text if transcription_result is not None: @@ -46,12 +47,12 @@ def on_event( self.text = self.text + transcription_result.text def on_error(self, message) -> None: - print("error: {}".format(message)) + print(f"error: {message}") def on_complete(self) -> None: print(f"[{self.tag}] Transcript ==> ", self.text) print(f"[{self.tag}] Translate ==> ", self.translate_text) - print(f"[{self.tag}] Translation completed") # translation complete + print(f"[{self.tag}] Translation completed") class TestSynthesis(BaseTestEnvironment): @@ -67,8 +68,10 @@ def setup_class(cls): def test_translate_from_file(self): callback = Callback(f"process {os.getpid()}", self.file) - # Call translation service by async mode, you can customize the translation parameters, like model, format, - # sample_rate For more information, please refer to https://help.aliyun.com/document_detail/2712536.html + # Call translation service by async mode, you can customize the + # translation parameters, like model, format, sample_rate + # For more information, please refer to + # https://help.aliyun.com/document_detail/2712536.html translator = TranslationRecognizerRealtime( model=self.model, format=self.format, @@ -82,30 +85,25 @@ def test_translate_from_file(self): # Start translation translator.start() - try: - audio_data: bytes = None - f = open(self.file, "rb") + with open(self.file, "rb") as f: if os.path.getsize(self.file): while True: audio_data = f.read(3200) if not audio_data: break - else: - translator.send_audio_frame(audio_data) + translator.send_audio_frame(audio_data) time.sleep(0.01) - else: - raise Exception( - "The supplied file was empty (zero bytes long)", - ) - f.close() - except Exception as e: - raise e + # pylint: disable=broad-exception-raised + raise Exception( + "The supplied file was empty (zero bytes long)", + ) translator.stop() + request_id = translator.get_last_request_id() + first_delay = translator.get_first_package_delay() + last_delay = translator.get_last_package_delay() print( - "[Metric] requestId: {}, first package delay ms: {}, last package delay ms: {}".format( - translator.get_last_request_id(), - translator.get_first_package_delay(), - translator.get_last_package_delay(), - ), + f"[Metric] requestId: {request_id}, " + f"first package delay ms: {first_delay}, " + f"last package delay ms: {last_delay}", ) diff --git 
a/tests/legacy/test_understanding.py b/tests/unit/test_understanding.py similarity index 100% rename from tests/legacy/test_understanding.py rename to tests/unit/test_understanding.py diff --git a/tests/legacy/websocket_mock_server_task_handler.py b/tests/unit/websocket_mock_server_task_handler.py similarity index 85% rename from tests/legacy/websocket_mock_server_task_handler.py rename to tests/unit/websocket_mock_server_task_handler.py index 742d4bd..80743d1 100644 --- a/tests/legacy/websocket_mock_server_task_handler.py +++ b/tests/unit/websocket_mock_server_task_handler.py @@ -42,15 +42,16 @@ def __init__( self.is_binary_out = is_binary_out self._duplex_task_finished = False - async def aio_call(self): - await self._send_start_event() # no matter what, send start event first. + async def aio_call(self): # pylint: disable=too-many-branches + # no matter what, send start event first. + await self._send_start_event() if self.streaming_mode == WebsocketStreamingMode.NONE: # if binary data, we need to receive data if self.is_binary_in: binary_data = ( await self._receive_batch_binary() ) # ignore timeout. - print("Receive binary data, length: %s" % len(binary_data)) + print(f"Receive binary data, length: {len(binary_data)}") # send "event":"task-finished" if self.is_binary_out: # send binary data @@ -79,7 +80,7 @@ async def aio_call(self): "usage": { "input_tokens": 100, "output_tokens": 200, - }, # noqa E501 + }, }, ) # for echo message out. @@ -104,7 +105,7 @@ async def aio_call(self): "output_tokens": 200, }, }, - ) # noqa E501 + ) else: await self._send_task_finished( payload={ @@ -154,7 +155,7 @@ async def aio_call(self): await self._send_task_finished(payload={}) async def send_streaming_binary_output(self): - for i in range(10): + for _ in range(10): data = bytes([0x01] * 100) await self.ws.send_bytes(data) @@ -163,7 +164,7 @@ async def send_streaming_text_output(self): "task_id": self.task_id, "event": "result-generated", } - for i in range(10): + for _ in range(10): payload = { "output": { "text": "world", @@ -185,13 +186,13 @@ async def _send_start_event(self): headers = {"task_id": self.task_id, EVENT_KEY: EventType.STARTED} payload = {} message = self._build_up_message(headers, payload=payload) - print("sending task started event message: %s" % message) + print(f"sending task started event message: {message}") await self.ws.send_str(message) async def _send_task_finished(self, payload): headers = {"task_id": self.task_id, EVENT_KEY: EventType.FINISHED} message = self._build_up_message(headers, payload) - print("sending task finished message: %s" % message) + print(f"sending task finished message: {message}") await self.ws.send_str(message) async def _receive_streaming_binary_data(self): @@ -200,20 +201,18 @@ async def _receive_streaming_binary_data(self): if await self.validate_message(msg): return if msg.type == aiohttp.WSMsgType.BINARY: - print( - "Receive binary data length: %s" % len(msg.data), - ) # real server need return data and process. + # real server need return data and process. + print(f"Receive binary data length: {len(msg.data)}") elif msg.type == aiohttp.WSMsgType.TEXT: req = msg.json() - print("Receive %s event" % req["header"][ACTION_KEY]) + print(f"Receive {req['header'][ACTION_KEY]} event") if req["header"][ACTION_KEY] == ActionType.FINISHED: self._duplex_task_finished = True break - else: - print("Unknown message: %s" % msg) + print(f"Unknown message: {msg}") else: raise UnexpectedMessageReceived( - "Expect binary data but receive %s!" 
% msg.type, + f"Expect binary data but receive {msg.type}!", ) async def _receive_streaming_text_data(self): @@ -225,21 +224,20 @@ async def _receive_streaming_text_data(self): return if msg.type == aiohttp.WSMsgType.TEXT: msg_json = msg.json() - print("Receive %s event" % msg_json["header"][ACTION_KEY]) + print(f"Receive {msg_json['header'][ACTION_KEY]} event") if msg_json["header"][ACTION_KEY] == ActionType.CONTINUE: - print("Receive text data: " % msg_json["payload"]) + print(f"Receive text data: {msg_json['payload']}") payload.append(msg_json["payload"]) elif msg_json["header"][ACTION_KEY] == ActionType.FINISHED: - print("Receive text data: " % msg_json["payload"]) + print(f"Receive text data: {msg_json['payload']}") if msg_json["payload"]: payload.append(msg_json["payload"]) self._duplex_task_finished = True return payload - else: - print("Unknown message: %s" % msg_json) + print(f"Unknown message: {msg_json}") else: raise UnexpectedMessageReceived( - "Expect binary data but receive %s!" % msg.type, + f"Expect binary data but receive {msg.type}!", ) async def _receive_batch_binary(self): @@ -256,10 +254,9 @@ async def _receive_batch_binary(self): break if msg.type == aiohttp.WSMsgType.BINARY: return msg.data - else: - raise UnexpectedMessageReceived( - "Expect binary data but receive %s!" % msg.type, - ) + raise UnexpectedMessageReceived( + f"Expect binary data but receive {msg.type}!", + ) async def _receive_batch_text(self): """If the data is not binary, data is send in start package. @@ -272,11 +269,11 @@ async def _receive_batch_text(self): final_data = self.run_task_json_message["payload"] while True: msg = await self.ws.receive() - if self.validate_message(): + if await self.validate_message(msg): break if msg.type == aiohttp.WSMsgType.TEXT: req = msg.json() - print("Receive %s event" % req["header"][ACTION_KEY]) + print(f"Receive {req['header'][ACTION_KEY]} event") if req["header"][ACTION_KEY] == ActionType.START: print("receive start task event") elif req["header"][ACTION_KEY] == ActionType.FINISHED: @@ -284,9 +281,9 @@ async def _receive_batch_text(self): await self._send_task_finished(final_data) break else: - print("Unknown message: %s" % msg) + print(f"Unknown message: {msg}") else: - raise UnexpectedMessageReceived("Expect text %s!" % msg.type) + raise UnexpectedMessageReceived(f"Expect text {msg.type}!") def _build_up_message(self, headers, payload): message = {"header": headers, "payload": payload} @@ -296,6 +293,6 @@ async def validate_message(self, msg): if msg.type == aiohttp.WSMsgType.CLOSED: print("Client close the connection") elif msg.type == aiohttp.WSMsgType.ERROR: - print("Connection error: %s" % msg.data) + print(f"Connection error: {msg.data}") return True return False diff --git a/tests/legacy/websocket_task_request.py b/tests/unit/websocket_task_request.py similarity index 90% rename from tests/legacy/websocket_task_request.py rename to tests/unit/websocket_task_request.py index 0c3d17f..00d1b1e 100644 --- a/tests/legacy/websocket_task_request.py +++ b/tests/unit/websocket_task_request.py @@ -1,10 +1,7 @@ # -*- coding: utf-8 -*- # Copyright (c) Alibaba, Inc. and its affiliates. 

-from dashscope.api_entities.dashscope_response import (
-    DashScopeAPIResponse,
-    GenerationResponse,
-)
+from dashscope.api_entities.dashscope_response import DashScopeAPIResponse
 from dashscope.client.base_api import BaseAioApi, BaseApi
 from dashscope.common.constants import ApiProtocol
 from dashscope.protocol.websocket import WebsocketStreamingMode
@@ -39,7 +36,7 @@ async def aio_call(
         )

     @classmethod
-    def call(
+    def call(  # type: ignore[override]  # pylint: disable=arguments-renamed
         cls,
         model: str,
         prompt: str,
@@ -50,7 +47,7 @@ def call(
         ws_stream_mode=WebsocketStreamingMode.NONE,
         is_binary_input=False,
         **kwargs,
-    ) -> GenerationResponse:
+    ) -> DashScopeAPIResponse:
         response = BaseApi.call(
             model=model,
             task_group=task_group,