From 938ce29c52208c1e725927e4fc01fb29f57c571c Mon Sep 17 00:00:00 2001 From: Alex Bozarth Date: Wed, 28 Jan 2026 12:31:48 -0600 Subject: [PATCH 1/5] test: run lint/format/type checks on entire repo Signed-off-by: Alex Bozarth --- .pre-commit-config.yaml | 7 +-- cli/decompose/pipeline.py | 2 +- cli/eval/commands.py | 3 +- cli/eval/runner.py | 37 +++++++------ cli/m.py | 2 +- docs/__init__.py | 0 docs/examples/__init__.py | 0 docs/examples/aLora/101_example.py | 7 +-- docs/examples/agents/react.py | 7 +-- docs/examples/agents/react_instruct.py | 7 +-- docs/examples/conftest.py | 5 +- .../context/contexts_with_sampling.py | 2 +- .../generative_slots_with_requirements.py | 5 +- .../vision_litellm_backend.py | 8 +-- .../image_text_models/vision_ollama_chat.py | 1 + .../vision_openai_examples.py | 2 +- .../101_with_gen_slots.py | 4 +- .../advanced_with_m_instruct.py | 2 +- .../101_email_with_validate.py | 2 +- .../advanced_email_with_validate_function.py | 2 +- docs/examples/intrinsics/answer_relevance.py | 8 ++- docs/examples/intrinsics/answerability.py | 7 ++- docs/examples/intrinsics/citations.py | 11 ++-- docs/examples/intrinsics/context_relevance.py | 5 +- .../intrinsics/hallucination_detection.py | 11 ++-- docs/examples/intrinsics/intrinsics.py | 7 ++- docs/examples/intrinsics/query_rewrite.py | 5 +- .../library_interop/langchain_messages.py | 6 +-- .../m_serve/m_serve_example_simple.py | 2 +- docs/examples/m_serve/pii_serve.py | 2 +- docs/examples/mcp/README.md | 2 +- docs/examples/mcp/mcp_example.py | 2 +- docs/examples/melp/lazy.py | 6 +-- docs/examples/melp/lazy_fib.py | 7 +-- docs/examples/melp/lazy_fib_sample.py | 9 ++-- docs/examples/melp/simple_example.py | 9 ++-- docs/examples/melp/states.py | 4 +- .../examples/mify/rich_table_execute_basic.py | 3 +- docs/examples/mini_researcher/researcher.py | 2 +- docs/examples/mobject/table.py | 2 +- docs/examples/notebooks/georgia_tech.ipynb | 2 +- docs/examples/notebooks/m_serve_example.ipynb | 2 +- .../notebooks/model_options_example.ipynb | 5 +- docs/examples/rag/simple_rag_with_filter.py | 2 +- docs/examples/safety/guardian.py | 2 +- docs/examples/safety/guardian_huggingface.py | 6 +-- docs/examples/safety/repair_with_guardian.py | 9 ++-- .../creating_a_new_type_of_session.py | 13 +++-- docs/examples/tools/interpreter_example.py | 6 +-- .../compositionality_with_generative_slots.py | 4 +- docs/examples/tutorial/context_example.py | 3 +- docs/examples/tutorial/document_mobject.py | 6 +-- .../tutorial/instruct_validate_repair.py | 4 +- docs/examples/tutorial/mcp_example.py | 54 ------------------- .../tutorial/model_options_example.py | 3 +- docs/examples/tutorial/simple_email.py | 2 +- docs/kv_smash/hf_example.py | 7 +-- docs/kv_smash/kv_with_chat.py | 4 +- docs/kv_smash/kvcache.py | 3 +- .../{0.py => step0_session_api.py} | 3 +- .../{1.py => step1_functional_api.py} | 2 +- .../{2.py => step2_act_cblocks.py} | 6 +-- .../session_deepdive/{3.py => step3_async.py} | 11 ++-- .../{4.py => step4_lazy_thunks.py} | 7 +-- .../{5.py => step5_composition.py} | 6 +-- .../{0.py => streaming_chat_example.py} | 6 +-- mellea/backends/huggingface.py | 2 +- pyproject.toml | 30 ++++++++++- test/__init__.py | 0 test/backends/test_adapters/test_adapter.py | 1 + test/backends/test_huggingface.py | 53 +++++++++--------- test/backends/test_litellm_ollama.py | 1 - test/backends/test_litellm_watsonx.py | 8 +-- test/backends/test_model_options.py | 1 + test/backends/test_ollama.py | 24 ++++----- test/backends/test_openai_ollama.py | 26 ++++----- .../test_openai_vllm/test_openai_vllm.py | 48 ++++++++--------- test/backends/test_tool_calls.py | 10 ++-- test/backends/test_tool_helpers.py | 3 +- test/backends/test_vision_openai.py | 13 ++++- test/backends/test_vllm.py | 20 +++---- test/core/test_base.py | 2 + test/core/test_component_typing.py | 6 +-- test/formatters/test_template_formatter.py | 20 +++---- .../components/docs/test_richdocument.py | 24 +++++---- test/stdlib/components/test_chat.py | 4 +- test/stdlib/components/test_genslot.py | 1 - test/stdlib/components/test_mify.py | 8 +-- test/stdlib/components/test_transform.py | 4 +- .../requirements/test_reqlib_markdown.py | 16 +++--- .../stdlib/requirements/test_reqlib_python.py | 3 +- test/stdlib/requirements/test_reqlib_tools.py | 1 + test/stdlib/requirements/test_requirement.py | 5 +- .../sampling/test_sofai_graph_coloring.py | 2 +- test/stdlib/test_base_context.py | 4 +- test/stdlib/test_chat_view.py | 6 +-- test/stdlib/test_session.py | 2 +- test/stdlib/test_spans.py | 6 +-- 98 files changed, 365 insertions(+), 382 deletions(-) create mode 100644 docs/__init__.py create mode 100644 docs/examples/__init__.py delete mode 100644 docs/examples/tutorial/mcp_example.py rename docs/rewrite/session_deepdive/{0.py => step0_session_api.py} (99%) rename docs/rewrite/session_deepdive/{1.py => step1_functional_api.py} (100%) rename docs/rewrite/session_deepdive/{2.py => step2_act_cblocks.py} (90%) rename docs/rewrite/session_deepdive/{3.py => step3_async.py} (65%) rename docs/rewrite/session_deepdive/{4.py => step4_lazy_thunks.py} (85%) rename docs/rewrite/session_deepdive/{5.py => step5_composition.py} (95%) rename docs/rewrite/streaming/{0.py => streaming_chat_example.py} (100%) create mode 100644 test/__init__.py diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 43a5e0e7..d618240e 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -2,25 +2,22 @@ exclude: ^(scratchpad/) repos: - repo: https://github.com/astral-sh/ruff-pre-commit - rev: v0.11.5 + rev: v0.14.14 hooks: - id: ruff-format name: "Ruff formatter" args: [--config=pyproject.toml] - files: '^(mellea|test|cli|docs).*\.(py|ipynb)$' - id: ruff name: "Ruff linter" args: [--exit-non-zero-on-fix, --fix, --config=pyproject.toml] - files: '^(mellea).*\.(py|ipynb)$' - repo: local hooks: - id: mypy name: MyPy - entry: uv run --no-sync mypy mellea + entry: uv run --no-sync mypy . pass_filenames: false language: system - files: '^(mellea|test|cli|docs).*\.(py|ipynb)$' - repo: https://github.com/astral-sh/uv-pre-commit rev: 0.7.8 diff --git a/cli/decompose/pipeline.py b/cli/decompose/pipeline.py index 1c6dda1a..a574d35d 100644 --- a/cli/decompose/pipeline.py +++ b/cli/decompose/pipeline.py @@ -5,9 +5,9 @@ from typing_extensions import NotRequired from mellea import MelleaSession +from mellea.backends import ModelOption from mellea.backends.ollama import OllamaModelBackend from mellea.backends.openai import OpenAIBackend -from mellea.backends import ModelOption from .prompt_modules import ( constraint_extractor, diff --git a/cli/eval/commands.py b/cli/eval/commands.py index ebc85dd6..17ff56be 100644 --- a/cli/eval/commands.py +++ b/cli/eval/commands.py @@ -1,5 +1,6 @@ """Use the eval command for LLM-as-a-judge evaluation, given a (set of) test file(s) consisting of prompts, instructions, and optionally, targets. -Instantiate a generator model to produce candidate responses, and a judge model to determine whether the instructions have been followed.""" +Instantiate a generator model to produce candidate responses, and a judge model to determine whether the instructions have been followed. +""" import typer diff --git a/cli/eval/runner.py b/cli/eval/runner.py index 38b4c1bb..3aface94 100644 --- a/cli/eval/runner.py +++ b/cli/eval/runner.py @@ -1,15 +1,16 @@ import json import re from pathlib import Path -from typing import List + +from rich.console import Console +from rich.progress import BarColumn, Progress, SpinnerColumn, TextColumn import mellea +from mellea.backends import ModelOption +from mellea.backends.backend import Backend from mellea.core import ModelOutputThunk +from mellea.stdlib.components import SimpleComponent from mellea.stdlib.components.unit_test_eval import TestBasedEval -from mellea.backends import ModelOption - -from rich.console import Console -from rich.progress import BarColumn, Progress, SpinnerColumn, TextColumn console = Console() @@ -78,7 +79,6 @@ def create_session( backend: str, model: str | None, max_tokens: int | None ) -> mellea.MelleaSession: """Create a mellea session with the specified backend and model.""" - model_id = None if model: if model.isupper() or "_" in model: @@ -93,6 +93,7 @@ def create_session( try: backend_lower = backend.lower() + backend_instance: Backend if backend_lower == "ollama": from mellea.backends.ollama import OllamaModelBackend @@ -130,7 +131,7 @@ def create_session( from mellea.backends.litellm import LiteLLMBackend backend_instance = LiteLLMBackend( - model_id=model_id, + model_id=str(model_id), model_options={ModelOption.MAX_NEW_TOKENS: max_tokens}, ) @@ -153,7 +154,7 @@ def create_session( def run_evaluations( - test_files: List[str], + test_files: list[str], backend: str, model: str | None, max_gen_tokens: int | None, @@ -173,14 +174,14 @@ def run_evaluations( "instructions": a set (in string form) of requirements which the generation should follow; the judge will evaluate if these are satisfied "examples": a list of entries containing an input_id, an input(prompt), and a list of targets. Each input may have multiple (or no) targets; inputs and targets are in messages format. """ - all_test_evals: List[TestBasedEval] = [] + all_test_evals: list[TestBasedEval] = [] for test_file in test_files: try: test_evals = TestBasedEval.from_json_file(test_file) all_test_evals.extend(test_evals) console.print(f"Loaded {len(test_evals)} test evaluations from {test_file}") - except Exception as e: + except Exception: console.print(f"Error loading {test_file}") if not all_test_evals: @@ -195,8 +196,11 @@ def run_evaluations( console.print(f"Judge model: {judge_model}") m = create_session(backend=backend, model=model, max_tokens=max_gen_tokens) + # Use same backend as generator if judge_backend not specified judge_session = create_session( - backend=judge_backend, model=judge_model, max_tokens=max_judge_tokens + backend=judge_backend if judge_backend else backend, + model=judge_model, + max_tokens=max_judge_tokens, ) all_results = [] @@ -240,12 +244,13 @@ def execute_test_eval( For each input in the test, generate a response using generation_session Then, after all inputs are processed, validate using judge_session. """ - input_results = [] # for all inputs, generate responses with generator for idx, input_text in enumerate(test_eval.inputs): - result: ModelOutputThunk = generation_session.act(input_text) + result: ModelOutputThunk = generation_session.act( + SimpleComponent(instruction=input_text) + ) model_output = str(result) targets_for_input = ( @@ -267,7 +272,7 @@ def execute_test_eval( input_text=input_text, model_output=model_output, validation_passed=passed, - score=score, + score=score if score is not None else 0, validation_reason=justification, ) input_results.append(input_result) @@ -301,7 +306,7 @@ def parse_judge_output(judge_output: str): return None, judge_output -def save_results(results: List[TestEvalResult], output_path: str, output_format: str): +def save_results(results: list[TestEvalResult], output_path: str, output_format: str): output_path_obj = Path(output_path) if output_path_obj.suffix != f".{output_format}": output_path_obj = Path(f"{output_path}.{output_format}") @@ -333,7 +338,7 @@ def save_results(results: List[TestEvalResult], output_path: str, output_format: console.print(f"Results saved to {output_path}") -def summary_stats(results: List[TestEvalResult]): +def summary_stats(results: list[TestEvalResult]): total_inputs = sum(r.total_count for r in results) passed_inputs = sum(r.passed_count for r in results) overall_pass_rate = passed_inputs / total_inputs if total_inputs > 0 else 0.0 diff --git a/cli/m.py b/cli/m.py index 07fc14b9..ab39440e 100644 --- a/cli/m.py +++ b/cli/m.py @@ -4,8 +4,8 @@ from cli.alora.commands import alora_app from cli.decompose import app as decompose_app -from cli.serve.app import serve from cli.eval.commands import eval_app +from cli.serve.app import serve cli = typer.Typer(name="m", no_args_is_help=True) diff --git a/docs/__init__.py b/docs/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/docs/examples/__init__.py b/docs/examples/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/docs/examples/aLora/101_example.py b/docs/examples/aLora/101_example.py index 1b65509f..8c3b66e2 100644 --- a/docs/examples/aLora/101_example.py +++ b/docs/examples/aLora/101_example.py @@ -1,12 +1,13 @@ import time -from mellea import MelleaSession from mellea.backends.aloras.huggingface.granite_aloras import HFConstraintAlora -from mellea.backends.cache import SimpleLRUCache -from mellea.backends.huggingface import LocalHFBackend from mellea.stdlib.base import ChatContext, GenerateLog from mellea.stdlib.requirement import ALoraRequirement, Requirement +from mellea import MelleaSession +from mellea.backends.cache import SimpleLRUCache +from mellea.backends.huggingface import LocalHFBackend + # Define a backend and add the constraint aLora backend = LocalHFBackend( model_id="ibm-granite/granite-3.2-8b-instruct", cache=SimpleLRUCache(5) diff --git a/docs/examples/agents/react.py b/docs/examples/agents/react.py index 117f1440..b5d1eab7 100644 --- a/docs/examples/agents/react.py +++ b/docs/examples/agents/react.py @@ -2,6 +2,7 @@ import inspect import json from collections.abc import Callable +from enum import Enum from typing import Literal import pydantic @@ -82,9 +83,9 @@ def call_tool(self, tool: ReactTool, kwargs_json: str): def tool_name_schema(self): names = self.tool_names() - fields = dict() - fields["tool"] = Literal[*names] - return pydantic.create_model("ToolSelectionSchema", **fields) + # Python 3.10 compatible: use Enum instead of Literal[*names] (requires 3.11+) + ToolEnum = Enum("ToolEnum", {name: name for name in names}) + return pydantic.create_model("ToolSelectionSchema", tool=(ToolEnum, ...)) def get_tool_from_schema(self, content: str): schema = self.tool_name_schema() diff --git a/docs/examples/agents/react_instruct.py b/docs/examples/agents/react_instruct.py index b72adbc6..86f22ad6 100644 --- a/docs/examples/agents/react_instruct.py +++ b/docs/examples/agents/react_instruct.py @@ -2,6 +2,7 @@ import inspect import json from collections.abc import Callable +from enum import Enum from typing import Literal import pydantic @@ -79,9 +80,9 @@ def call_tool(self, tool: ReactTool, kwargs_json: str): def tool_name_schema(self): names = self.tool_names() - fields = dict() - fields["tool"] = Literal[*names] - return pydantic.create_model("ToolSelectionSchema", **fields) + # Python 3.10 compatible: use Enum instead of Literal[*names] (requires 3.11+) + ToolEnum = Enum("ToolEnum", {name: name for name in names}) + return pydantic.create_model("ToolSelectionSchema", tool=(ToolEnum, ...)) def get_tool_from_schema(self, content: str): schema = self.tool_name_schema() diff --git a/docs/examples/conftest.py b/docs/examples/conftest.py index bef7dce6..f1379261 100644 --- a/docs/examples/conftest.py +++ b/docs/examples/conftest.py @@ -26,8 +26,9 @@ def pytest_terminal_summary(terminalreporter, exitstatus, config): terminalreporter.ensure_newline() terminalreporter.section("Skipped Examples", sep="=", blue=True, bold=True) + newline = "\n" terminalreporter.line( - f"Examples with the following names were skipped because they cannot be easily run in the pytest framework; please run them manually:\n{'\n'.join(examples_to_skip)}" + f"Examples with the following names were skipped because they cannot be easily run in the pytest framework; please run them manually:\n{newline.join(examples_to_skip)}" ) @@ -83,7 +84,7 @@ def runtest(self): if retcode != 0: raise ExampleTestException( - (f"Example failed with exit code {retcode}.\nStderr: {stderr}\n") + f"Example failed with exit code {retcode}.\nStderr: {stderr}\n" ) def repr_failure(self, excinfo, style=None): diff --git a/docs/examples/context/contexts_with_sampling.py b/docs/examples/context/contexts_with_sampling.py index 1f71397b..9661353a 100644 --- a/docs/examples/context/contexts_with_sampling.py +++ b/docs/examples/context/contexts_with_sampling.py @@ -26,7 +26,7 @@ print(f"Total Generation Attempts: {len(res.sample_generations)}") print() -print(f"Getting index of another result.") +print("Getting index of another result.") index = 0 # Just choose the first one. print( diff --git a/docs/examples/generative_slots/generative_slots_with_requirements.py b/docs/examples/generative_slots/generative_slots_with_requirements.py index 6f5a610a..63aadd19 100644 --- a/docs/examples/generative_slots/generative_slots_with_requirements.py +++ b/docs/examples/generative_slots/generative_slots_with_requirements.py @@ -1,16 +1,15 @@ from typing import Literal from mellea import generative, start_session +from mellea.core import Requirement from mellea.stdlib.components.genslot import PreconditionException from mellea.stdlib.requirements import simple_validate -from mellea.core import Requirement from mellea.stdlib.sampling.base import RejectionSamplingStrategy @generative def classify_sentiment(text: str) -> Literal["positive", "negative", "unknown"]: """Classify the sentiment of the text.""" - ... if __name__ == "__main__": @@ -63,7 +62,7 @@ def classify_sentiment(text: str) -> Literal["positive", "negative", "unknown"]: ], ) except PreconditionException as e: - print(f"exception: {str(e)}") + print(f"exception: {e!s}") # Look at why the precondition validation failed. print("Failure reasons:") diff --git a/docs/examples/image_text_models/vision_litellm_backend.py b/docs/examples/image_text_models/vision_litellm_backend.py index 03f1ea1b..9b6c4d82 100644 --- a/docs/examples/image_text_models/vision_litellm_backend.py +++ b/docs/examples/image_text_models/vision_litellm_backend.py @@ -1,6 +1,7 @@ """Examples of using vision models with LiteLLM backend.""" import os +import pathlib import litellm from PIL import Image @@ -9,7 +10,6 @@ from mellea.backends.litellm import LiteLLMBackend from mellea.backends.openai import OpenAIBackend from mellea.core import ImageBlock -import pathlib # use LiteLLM to talk to Ollama or anthropic or..... m = MelleaSession(LiteLLMBackend("ollama/granite3.2-vision")) @@ -28,15 +28,15 @@ "Is there a person on the image? Is the subject in the image smiling?", images=[test_pil], ) -print(f"Test with PIL and instruct: \n{str(res_instruct)}\n-----") +print(f"Test with PIL and instruct: \n{res_instruct!s}\n-----") # print(m.last_prompt()) # with PIL image and using m.chat res_chat = m.chat( "How many eyes can you identify in the image? Explain.", images=[test_pil] ) -print(f"Test with PIL and chat: \n{str(res_chat.content)}\n-----") +print(f"Test with PIL and chat: \n{res_chat.content!s}\n-----") # and now without images again... res_empty = m.instruct("How many eyes can you identify in the image?", images=[]) -print(f"Test without image: \n{str(res_empty)}\n-----") +print(f"Test without image: \n{res_empty!s}\n-----") diff --git a/docs/examples/image_text_models/vision_ollama_chat.py b/docs/examples/image_text_models/vision_ollama_chat.py index 49fb1198..a7aca7a7 100644 --- a/docs/examples/image_text_models/vision_ollama_chat.py +++ b/docs/examples/image_text_models/vision_ollama_chat.py @@ -1,6 +1,7 @@ """Example of using Ollama with vision models with linear context.""" import pathlib + from PIL import Image from mellea import start_session diff --git a/docs/examples/image_text_models/vision_openai_examples.py b/docs/examples/image_text_models/vision_openai_examples.py index 1ca58658..5f9ce06d 100644 --- a/docs/examples/image_text_models/vision_openai_examples.py +++ b/docs/examples/image_text_models/vision_openai_examples.py @@ -6,8 +6,8 @@ from mellea import MelleaSession from mellea.backends.openai import OpenAIBackend -from mellea.stdlib.context import ChatContext from mellea.core import ImageBlock +from mellea.stdlib.context import ChatContext # # using anthropic AI model ... # anth_key = os.environ.get("ANTHROPIC_API_KEY") diff --git a/docs/examples/information_extraction/101_with_gen_slots.py b/docs/examples/information_extraction/101_with_gen_slots.py index 961a5122..887854b5 100644 --- a/docs/examples/information_extraction/101_with_gen_slots.py +++ b/docs/examples/information_extraction/101_with_gen_slots.py @@ -8,9 +8,7 @@ @generative def extract_all_person_names(doc: str) -> list[str]: - """ - Given a document, extract names of ALL mentioned persons. Return these names as list of strings. - """ + """Given a document, extract names of ALL mentioned persons. Return these names as list of strings.""" # ref: https://www.nytimes.com/2012/05/20/world/world-leaders-at-us-meeting-urge-growth-not-austerity.html diff --git a/docs/examples/information_extraction/advanced_with_m_instruct.py b/docs/examples/information_extraction/advanced_with_m_instruct.py index d2678952..2fac3990 100644 --- a/docs/examples/information_extraction/advanced_with_m_instruct.py +++ b/docs/examples/information_extraction/advanced_with_m_instruct.py @@ -6,9 +6,9 @@ from mellea import start_session from mellea.backends import model_ids +from mellea.core import SamplingResult from mellea.stdlib.requirements import check, simple_validate from mellea.stdlib.sampling import RejectionSamplingStrategy -from mellea.core import SamplingResult # ref: https://www.nytimes.com/2012/05/20/world/world-leaders-at-us-meeting-urge-growth-not-austerity.html NYTimes_text = "CAMP DAVID, Md. — Leaders of the world's richest countries banded together on Saturday to press Germany to back more pro-growth policies to halt the deepening debt crisis in Europe, as President Obama for the first time gained widespread support for his argument that Europe, and the United States by extension, cannot afford Chancellor Angela Merkel's one-size-fits-all approach emphasizing austerity." diff --git a/docs/examples/instruct_validate_repair/101_email_with_validate.py b/docs/examples/instruct_validate_repair/101_email_with_validate.py index a7a0e500..ab7bbb04 100644 --- a/docs/examples/instruct_validate_repair/101_email_with_validate.py +++ b/docs/examples/instruct_validate_repair/101_email_with_validate.py @@ -1,7 +1,7 @@ from docs.examples.helper import req_print, w from mellea import start_session -from mellea.backends.model_ids import IBM_GRANITE_3_3_8B from mellea.backends import ModelOption +from mellea.backends.model_ids import IBM_GRANITE_3_3_8B from mellea.stdlib.sampling import RejectionSamplingStrategy # create a session using Granite 4 Micro (3B) on Ollama and a simple context [see below] diff --git a/docs/examples/instruct_validate_repair/advanced_email_with_validate_function.py b/docs/examples/instruct_validate_repair/advanced_email_with_validate_function.py index 0ffe0d13..6e2a0140 100644 --- a/docs/examples/instruct_validate_repair/advanced_email_with_validate_function.py +++ b/docs/examples/instruct_validate_repair/advanced_email_with_validate_function.py @@ -1,8 +1,8 @@ from docs.examples.helper import w from mellea import start_session from mellea.backends import ModelOption -from mellea.stdlib.requirements import simple_validate from mellea.core import Requirement +from mellea.stdlib.requirements import simple_validate from mellea.stdlib.sampling import RejectionSamplingStrategy # create a session using Granite 4 Micro 3B on Ollama and a simple context [see below] diff --git a/docs/examples/intrinsics/answer_relevance.py b/docs/examples/intrinsics/answer_relevance.py index f945c6dd..4f9ade60 100644 --- a/docs/examples/intrinsics/answer_relevance.py +++ b/docs/examples/intrinsics/answer_relevance.py @@ -1,5 +1,4 @@ -""" -Example usage of the answer relevance intrinsic for RAG applications. +"""Example usage of the answer relevance intrinsic for RAG applications. To run this script from the root of the Mellea source tree, use the command: ``` @@ -8,10 +7,9 @@ """ from mellea.backends.huggingface import LocalHFBackend -from mellea.stdlib.context import ChatContext -from mellea.stdlib.components import Message, Document +from mellea.stdlib.components import Document, Message from mellea.stdlib.components.intrinsic import rag - +from mellea.stdlib.context import ChatContext backend = LocalHFBackend(model_id="ibm-granite/granite-4.0-micro") context = ChatContext().add(Message("user", "Who attended the meeting?")) diff --git a/docs/examples/intrinsics/answerability.py b/docs/examples/intrinsics/answerability.py index 6804c5d7..16954332 100644 --- a/docs/examples/intrinsics/answerability.py +++ b/docs/examples/intrinsics/answerability.py @@ -1,5 +1,4 @@ -""" -Example usage of the answerability intrinsic for RAG applications. +"""Example usage of the answerability intrinsic for RAG applications. To run this script from the root of the Mellea source tree, use the command: ``` @@ -8,9 +7,9 @@ """ from mellea.backends.huggingface import LocalHFBackend -from mellea.stdlib.context import ChatContext -from mellea.stdlib.components import Message, Document +from mellea.stdlib.components import Document, Message from mellea.stdlib.components.intrinsic import rag +from mellea.stdlib.context import ChatContext backend = LocalHFBackend(model_id="ibm-granite/granite-4.0-micro") context = ChatContext().add(Message("assistant", "Hello there, how can I help you?")) diff --git a/docs/examples/intrinsics/citations.py b/docs/examples/intrinsics/citations.py index 74377091..cf48e128 100644 --- a/docs/examples/intrinsics/citations.py +++ b/docs/examples/intrinsics/citations.py @@ -1,5 +1,4 @@ -""" -Example usage of the citations intrinsic for RAG applications. +"""Example usage of the citations intrinsic for RAG applications. To run this script from the root of the Mellea source tree, use the command: ``` @@ -7,12 +6,12 @@ ``` """ -from mellea.backends.huggingface import LocalHFBackend -from mellea.stdlib.context import ChatContext -from mellea.stdlib.components import Message, Document -from mellea.stdlib.components.intrinsic import rag import json +from mellea.backends.huggingface import LocalHFBackend +from mellea.stdlib.components import Document, Message +from mellea.stdlib.components.intrinsic import rag +from mellea.stdlib.context import ChatContext backend = LocalHFBackend(model_id="ibm-granite/granite-4.0-micro") context = ChatContext().add( diff --git a/docs/examples/intrinsics/context_relevance.py b/docs/examples/intrinsics/context_relevance.py index 470973e3..9dce932d 100644 --- a/docs/examples/intrinsics/context_relevance.py +++ b/docs/examples/intrinsics/context_relevance.py @@ -1,5 +1,4 @@ -""" -Example usage of the context relevance intrinsic for RAG applications. +"""Example usage of the context relevance intrinsic for RAG applications. To run this script from the root of the Mellea source tree, use the command: ``` @@ -8,9 +7,9 @@ """ from mellea.backends.huggingface import LocalHFBackend -from mellea.stdlib.context import ChatContext from mellea.stdlib.components import Document from mellea.stdlib.components.intrinsic import rag +from mellea.stdlib.context import ChatContext backend = LocalHFBackend(model_id="ibm-granite/granite-4.0-micro") context = ChatContext() diff --git a/docs/examples/intrinsics/hallucination_detection.py b/docs/examples/intrinsics/hallucination_detection.py index 271e76a3..14d24e20 100644 --- a/docs/examples/intrinsics/hallucination_detection.py +++ b/docs/examples/intrinsics/hallucination_detection.py @@ -1,5 +1,4 @@ -""" -Example usage of the hallucination detection intrinsic for RAG applications. +"""Example usage of the hallucination detection intrinsic for RAG applications. To run this script from the root of the Mellea source tree, use the command: ``` @@ -7,12 +6,12 @@ ``` """ -from mellea.backends.huggingface import LocalHFBackend -from mellea.stdlib.context import ChatContext -from mellea.stdlib.components import Message, Document -from mellea.stdlib.components.intrinsic import rag import json +from mellea.backends.huggingface import LocalHFBackend +from mellea.stdlib.components import Document, Message +from mellea.stdlib.components.intrinsic import rag +from mellea.stdlib.context import ChatContext backend = LocalHFBackend(model_id="ibm-granite/granite-4.0-micro") context = ( diff --git a/docs/examples/intrinsics/intrinsics.py b/docs/examples/intrinsics/intrinsics.py index 10ba4e97..83ca39da 100644 --- a/docs/examples/intrinsics/intrinsics.py +++ b/docs/examples/intrinsics/intrinsics.py @@ -1,10 +1,9 @@ +import mellea.stdlib.functional as mfuncs +from mellea.backends.adapters.adapter import AdapterType, GraniteCommonAdapter from mellea.backends.huggingface import LocalHFBackend from mellea.backends.openai import OpenAIBackend, _ServerType -from mellea.backends.adapters.adapter import AdapterType, GraniteCommonAdapter +from mellea.stdlib.components import Intrinsic, Message from mellea.stdlib.context import ChatContext -from mellea.stdlib.components import Message -import mellea.stdlib.functional as mfuncs -from mellea.stdlib.components import Intrinsic # This is an example for how you would directly use intrinsics. See `mellea/stdlib/intrinsics/rag.py` # for helper functions. diff --git a/docs/examples/intrinsics/query_rewrite.py b/docs/examples/intrinsics/query_rewrite.py index a95cadc7..60fb812a 100644 --- a/docs/examples/intrinsics/query_rewrite.py +++ b/docs/examples/intrinsics/query_rewrite.py @@ -1,5 +1,4 @@ -""" -Example usage of the query rewrite intrinsic for RAG applications. +"""Example usage of the query rewrite intrinsic for RAG applications. To run this script from the root of the Mellea source tree, use the command: ``` @@ -8,9 +7,9 @@ """ from mellea.backends.huggingface import LocalHFBackend -from mellea.stdlib.context import ChatContext from mellea.stdlib.components import Message from mellea.stdlib.components.intrinsic import rag +from mellea.stdlib.context import ChatContext backend = LocalHFBackend(model_id="ibm-granite/granite-4.0-micro") context = ( diff --git a/docs/examples/library_interop/langchain_messages.py b/docs/examples/library_interop/langchain_messages.py index 8d99720d..eebbaa35 100644 --- a/docs/examples/library_interop/langchain_messages.py +++ b/docs/examples/library_interop/langchain_messages.py @@ -1,6 +1,6 @@ # Installing langchain is necessary for this example, but it works for any library # you may want to use Mellea with. -from langchain_core.messages import HumanMessage, AIMessage, SystemMessage +from langchain_core.messages import AIMessage, HumanMessage, SystemMessage # Messages from a different library. messages = [ @@ -15,10 +15,10 @@ messages = convert_to_openai_messages(messages=messages) # Import Mellea. +from mellea import start_session +from mellea.backends import ModelOption from mellea.stdlib.components import Message from mellea.stdlib.context import ChatContext -from mellea.backends import ModelOption -from mellea import start_session # Mellea uses explicit contexts. Cast the OpenAI formatted messages into # Mellea messages and add them to the context. diff --git a/docs/examples/m_serve/m_serve_example_simple.py b/docs/examples/m_serve/m_serve_example_simple.py index f1dff480..5d433f29 100644 --- a/docs/examples/m_serve/m_serve_example_simple.py +++ b/docs/examples/m_serve/m_serve_example_simple.py @@ -4,8 +4,8 @@ import mellea from cli.serve.models import ChatMessage +from mellea.core import ModelOutputThunk, Requirement, SamplingResult from mellea.stdlib.context import ChatContext -from mellea.core import ModelOutputThunk, SamplingResult, Requirement from mellea.stdlib.requirements import simple_validate from mellea.stdlib.sampling import RejectionSamplingStrategy diff --git a/docs/examples/m_serve/pii_serve.py b/docs/examples/m_serve/pii_serve.py index 356867e8..0c166326 100644 --- a/docs/examples/m_serve/pii_serve.py +++ b/docs/examples/m_serve/pii_serve.py @@ -2,8 +2,8 @@ import spacy -from cli.serve.models import ChatMessage import mellea +from cli.serve.models import ChatMessage from mellea.backends.model_ids import IBM_GRANITE_4_MICRO_3B from mellea.core import ModelOutputThunk, SamplingResult from mellea.stdlib.requirements import req, simple_validate diff --git a/docs/examples/mcp/README.md b/docs/examples/mcp/README.md index 8202f789..48161c33 100644 --- a/docs/examples/mcp/README.md +++ b/docs/examples/mcp/README.md @@ -14,7 +14,7 @@ uv pip install "mcp[cli]" and run the example in MCP debug UI: ```bash -uv run mcp dev docs/examples/tutorial/mcp_example.py +uv run mcp dev docs/examples/mcp/mcp_example.py ``` diff --git a/docs/examples/mcp/mcp_example.py b/docs/examples/mcp/mcp_example.py index 43d965dc..fac1a6cc 100644 --- a/docs/examples/mcp/mcp_example.py +++ b/docs/examples/mcp/mcp_example.py @@ -4,7 +4,7 @@ uv pip install "mcp[cli]" and run the example in MCP debug UI: -uv run mcp dev docs/examples/tutorial/mcp_example.py +uv run mcp dev docs/examples/mcp/mcp_example.py """ from mcp.server.fastmcp import FastMCP diff --git a/docs/examples/melp/lazy.py b/docs/examples/melp/lazy.py index 3a3b9e5f..816e70aa 100644 --- a/docs/examples/melp/lazy.py +++ b/docs/examples/melp/lazy.py @@ -1,12 +1,10 @@ import asyncio -from mellea.core import Context, CBlock, ModelOutputThunk +from mellea.backends.ollama import OllamaModelBackend +from mellea.core import Backend, CBlock, Context, ModelOutputThunk from mellea.stdlib.components import SimpleComponent from mellea.stdlib.context import SimpleContext -from mellea.core import Backend -from mellea.backends.ollama import OllamaModelBackend - backend = OllamaModelBackend("granite4:latest") diff --git a/docs/examples/melp/lazy_fib.py b/docs/examples/melp/lazy_fib.py index e91a4a2b..5e739c20 100644 --- a/docs/examples/melp/lazy_fib.py +++ b/docs/examples/melp/lazy_fib.py @@ -1,13 +1,10 @@ import asyncio -from mellea.core import Context, CBlock, ModelOutputThunk +from mellea.backends.ollama import OllamaModelBackend +from mellea.core import Backend, CBlock, Context, ModelOutputThunk from mellea.stdlib.components import SimpleComponent from mellea.stdlib.context import SimpleContext -from mellea.core import Backend -from mellea.backends.ollama import OllamaModelBackend -from typing import Tuple - backend = OllamaModelBackend("granite4:latest") diff --git a/docs/examples/melp/lazy_fib_sample.py b/docs/examples/melp/lazy_fib_sample.py index 0224f4a3..833a2803 100644 --- a/docs/examples/melp/lazy_fib_sample.py +++ b/docs/examples/melp/lazy_fib_sample.py @@ -1,13 +1,10 @@ import asyncio -from mellea.core import Context, CBlock, ModelOutputThunk +from mellea.backends.ollama import OllamaModelBackend +from mellea.core import Backend, CBlock, Context, ModelOutputThunk from mellea.stdlib.components import SimpleComponent from mellea.stdlib.context import SimpleContext -from mellea.core import Backend -from mellea.backends.ollama import OllamaModelBackend -from typing import Tuple - backend = OllamaModelBackend("granite4:latest") @@ -26,7 +23,7 @@ async def _fib_sample( try: int(value) return answer_mot - except: + except Exception: return None diff --git a/docs/examples/melp/simple_example.py b/docs/examples/melp/simple_example.py index 7862027e..ee90f51d 100644 --- a/docs/examples/melp/simple_example.py +++ b/docs/examples/melp/simple_example.py @@ -1,13 +1,12 @@ import asyncio -from mellea.core import Context, CBlock, ModelOutputThunk, Backend + from mellea.backends.ollama import OllamaModelBackend +from mellea.core import Backend, CBlock, Context, ModelOutputThunk from mellea.stdlib.context import SimpleContext async def main(backend: Backend, ctx: Context): - """ - In this example, we show how executing multiple MOTs in parallel should work. - """ + """In this example, we show how executing multiple MOTs in parallel should work.""" m_states = "Missouri", "Minnesota", "Montana", "Massachusetts" poem_thunks = [] @@ -19,7 +18,7 @@ async def main(backend: Backend, ctx: Context): # Notice that what we have now is a list of ModelOutputThunks, none of which are computed. for poem_thunk in poem_thunks: - assert type(poem_thunk) == ModelOutputThunk + assert isinstance(poem_thunk, ModelOutputThunk) print(f"Computed: {poem_thunk.is_computed()}") # Let's run all of these in parallel. diff --git a/docs/examples/melp/states.py b/docs/examples/melp/states.py index d8770c3a..cd390ce4 100644 --- a/docs/examples/melp/states.py +++ b/docs/examples/melp/states.py @@ -1,9 +1,9 @@ import asyncio -from mellea.core import Context, CBlock, Backend from mellea.backends.ollama import OllamaModelBackend -from mellea.stdlib.context import SimpleContext +from mellea.core import Backend, CBlock, Context from mellea.stdlib.components import SimpleComponent +from mellea.stdlib.context import SimpleContext async def main(backend: Backend, ctx: Context): diff --git a/docs/examples/mify/rich_table_execute_basic.py b/docs/examples/mify/rich_table_execute_basic.py index a9a5c112..d32bda58 100644 --- a/docs/examples/mify/rich_table_execute_basic.py +++ b/docs/examples/mify/rich_table_execute_basic.py @@ -2,8 +2,7 @@ import os from mellea import start_session -from mellea.backends import model_ids -from mellea.backends import ModelOption +from mellea.backends import ModelOption, model_ids from mellea.core import FancyLogger from mellea.stdlib.components.docs.richdocument import RichDocument, Table diff --git a/docs/examples/mini_researcher/researcher.py b/docs/examples/mini_researcher/researcher.py index 3fe3182a..89746404 100644 --- a/docs/examples/mini_researcher/researcher.py +++ b/docs/examples/mini_researcher/researcher.py @@ -9,9 +9,9 @@ from mellea import MelleaSession from mellea.backends import model_ids from mellea.backends.ollama import OllamaModelBackend +from mellea.core import Requirement, SamplingResult from mellea.stdlib.requirements import simple_validate from mellea.stdlib.sampling import RejectionSamplingStrategy -from mellea.core import SamplingResult, Requirement # ############################# # Helper functions diff --git a/docs/examples/mobject/table.py b/docs/examples/mobject/table.py index 03788882..2d2c781e 100644 --- a/docs/examples/mobject/table.py +++ b/docs/examples/mobject/table.py @@ -28,7 +28,7 @@ def update_sales(self, store: str, amount: str): index_col=False, ) # Remove unnamed columns and columns that don't exist. - table_df.drop(table_df.filter(regex="Unname").columns, axis=1, inplace=True) + table_df = table_df.drop(table_df.filter(regex="Unname").columns, axis=1) # Sometimes extra whitespace gets added to the column names and row values. Remove it. table_df.columns = table_df.columns.str.strip() diff --git a/docs/examples/notebooks/georgia_tech.ipynb b/docs/examples/notebooks/georgia_tech.ipynb index 08422fb4..a83e3169 100644 --- a/docs/examples/notebooks/georgia_tech.ipynb +++ b/docs/examples/notebooks/georgia_tech.ipynb @@ -439,9 +439,9 @@ }, "outputs": [], "source": [ + "from mellea.backends import ModelOption\n", "from mellea.backends.model_ids import META_LLAMA_3_2_3B\n", "from mellea.backends.ollama import OllamaModelBackend\n", - "from mellea.backends import ModelOption\n", "\n", "# You can use multiple different models at the same time!\n", "m_llama = mellea.MelleaSession(backend=OllamaModelBackend(model_id=META_LLAMA_3_2_3B))\n", diff --git a/docs/examples/notebooks/m_serve_example.ipynb b/docs/examples/notebooks/m_serve_example.ipynb index 729b75bf..fb3e7429 100644 --- a/docs/examples/notebooks/m_serve_example.ipynb +++ b/docs/examples/notebooks/m_serve_example.ipynb @@ -89,8 +89,8 @@ "\n", "import mellea\n", "from cli.serve.models import ChatMessage\n", - "from mellea.stdlib.context import ChatContext\n", "from mellea.core import ModelOutputThunk, Requirement, SamplingResult\n", + "from mellea.stdlib.context import ChatContext\n", "from mellea.stdlib.requirements import simple_validate\n", "from mellea.stdlib.sampling import RejectionSamplingStrategy\n", "\n", diff --git a/docs/examples/notebooks/model_options_example.ipynb b/docs/examples/notebooks/model_options_example.ipynb index 0216010c..e518679d 100644 --- a/docs/examples/notebooks/model_options_example.ipynb +++ b/docs/examples/notebooks/model_options_example.ipynb @@ -88,9 +88,8 @@ "outputs": [], "source": [ "import mellea\n", - "from mellea.backends import model_ids\n", - "from mellea.backends.ollama import OllamaModelBackend\n", - "from mellea.backends import ModelOption" + "from mellea.backends import ModelOption, model_ids\n", + "from mellea.backends.ollama import OllamaModelBackend" ] }, { diff --git a/docs/examples/rag/simple_rag_with_filter.py b/docs/examples/rag/simple_rag_with_filter.py index adf17ecd..9ff1e6d0 100644 --- a/docs/examples/rag/simple_rag_with_filter.py +++ b/docs/examples/rag/simple_rag_with_filter.py @@ -44,7 +44,7 @@ def create_index(model, ds: list[str]) -> IndexFlatIP: def query_index(model, idx: IndexFlatIP, query: str, ds: list[str], k: int = 5) -> list: query_embedding = model.encode([query]) - distances, indices = idx.search(query_embedding, k=k) + _distances, indices = idx.search(query_embedding, k=k) return [ds[i] for i in indices[0]] diff --git a/docs/examples/safety/guardian.py b/docs/examples/safety/guardian.py index c5b8b123..aa371cfa 100644 --- a/docs/examples/safety/guardian.py +++ b/docs/examples/safety/guardian.py @@ -4,8 +4,8 @@ from mellea.backends import model_ids from mellea.backends.ollama import OllamaModelBackend from mellea.core import ContextTurn, ModelOutputThunk -from mellea.stdlib.context import ChatContext from mellea.stdlib.components import Message +from mellea.stdlib.context import ChatContext from mellea.stdlib.requirements.safety.guardian import GuardianCheck, GuardianRisk # Enhanced GuardianCheck with Granite Guardian 3.3 8B support diff --git a/docs/examples/safety/guardian_huggingface.py b/docs/examples/safety/guardian_huggingface.py index bbb84698..5b139b41 100644 --- a/docs/examples/safety/guardian_huggingface.py +++ b/docs/examples/safety/guardian_huggingface.py @@ -6,11 +6,11 @@ from mellea import MelleaSession from mellea.backends import model_ids -from mellea.backends.ollama import OllamaModelBackend from mellea.backends.huggingface import LocalHFBackend +from mellea.backends.ollama import OllamaModelBackend from mellea.core import ModelOutputThunk, ModelToolCall -from mellea.stdlib.context import ChatContext from mellea.stdlib.components import Message +from mellea.stdlib.context import ChatContext from mellea.stdlib.requirements.safety.guardian import GuardianCheck, GuardianRisk print("=== GuardianCheck HuggingFace Backend Example ===") @@ -46,7 +46,7 @@ print(f"Guardian detected harm: {not validation_result[0]._result}") if validation_result[0]._reason: - print(f"\nGuardian feedback:") + print("\nGuardian feedback:") print(validation_result[0]._reason[:200] + "...") # Test 2: Groundedness detection diff --git a/docs/examples/safety/repair_with_guardian.py b/docs/examples/safety/repair_with_guardian.py index 2355eff5..d16e9820 100644 --- a/docs/examples/safety/repair_with_guardian.py +++ b/docs/examples/safety/repair_with_guardian.py @@ -1,5 +1,4 @@ -""" -RepairTemplateStrategy Example with Actual Function Call Validation +"""RepairTemplateStrategy Example with Actual Function Call Validation Demonstrates how RepairTemplateStrategy repairs responses using actual function calls. """ @@ -78,10 +77,10 @@ def get_stock_price(symbol: str) -> str: if hasattr(m.backend, "formatter"): try: rendered = m.backend.formatter.print(action) - print(f" Instruction sent to model:") - print(f" ---") + print(" Instruction sent to model:") + print(" ---") print(f" {rendered}") - print(f" ---") + print(" ---") except Exception: pass diff --git a/docs/examples/sessions/creating_a_new_type_of_session.py b/docs/examples/sessions/creating_a_new_type_of_session.py index 59624caf..b7659034 100644 --- a/docs/examples/sessions/creating_a_new_type_of_session.py +++ b/docs/examples/sessions/creating_a_new_type_of_session.py @@ -1,12 +1,19 @@ from typing import Literal + from PIL import Image as PILImage from mellea import MelleaSession -from mellea.core import Backend, BaseModelSubclass from mellea.backends.ollama import OllamaModelBackend -from mellea.core import CBlock, Context, ImageBlock, Requirement -from mellea.stdlib.context import ChatContext +from mellea.core import ( + Backend, + BaseModelSubclass, + CBlock, + Context, + ImageBlock, + Requirement, +) from mellea.stdlib.components import Message +from mellea.stdlib.context import ChatContext from mellea.stdlib.requirements import reqify from mellea.stdlib.requirements.safety.guardian import GuardianCheck, GuardianRisk diff --git a/docs/examples/tools/interpreter_example.py b/docs/examples/tools/interpreter_example.py index b2a9315b..1b73f5bf 100644 --- a/docs/examples/tools/interpreter_example.py +++ b/docs/examples/tools/interpreter_example.py @@ -1,7 +1,7 @@ -from mellea.stdlib.tools import code_interpreter, local_code_interpreter -from mellea import start_session, MelleaSession +from mellea import MelleaSession, start_session from mellea.backends import ModelOption -from mellea.stdlib.requirements import uses_tool, tool_arg_validator +from mellea.stdlib.requirements import tool_arg_validator, uses_tool +from mellea.stdlib.tools import code_interpreter, local_code_interpreter def example_1(m: MelleaSession): diff --git a/docs/examples/tutorial/compositionality_with_generative_slots.py b/docs/examples/tutorial/compositionality_with_generative_slots.py index 8ba09f1f..6c793485 100644 --- a/docs/examples/tutorial/compositionality_with_generative_slots.py +++ b/docs/examples/tutorial/compositionality_with_generative_slots.py @@ -34,7 +34,7 @@ def generate_novel_recommendations(summary: str) -> str: # Compose the libraries. -from typing import Literal # noqa: E402 +from typing import Literal @generative @@ -52,7 +52,7 @@ def has_theme_and_plot(summary: str) -> Literal["yes", "no"]: """Check whether the summary contains both a plot and thematic elements.""" -from mellea import start_session # noqa: E402 +from mellea import start_session m = start_session() transcript = """Meeting Transcript: Market Risk Review -- Self-Sealing Stembolts Division diff --git a/docs/examples/tutorial/context_example.py b/docs/examples/tutorial/context_example.py index e98e1182..465e9747 100644 --- a/docs/examples/tutorial/context_example.py +++ b/docs/examples/tutorial/context_example.py @@ -1,6 +1,7 @@ -from mellea import start_session from mellea.stdlib.base import ChatContext +from mellea import start_session + m = start_session(ctx=ChatContext()) m.chat("Make up a math problem.") m.chat("Solve your math problem.") diff --git a/docs/examples/tutorial/document_mobject.py b/docs/examples/tutorial/document_mobject.py index 42c18cb1..63012ddb 100644 --- a/docs/examples/tutorial/document_mobject.py +++ b/docs/examples/tutorial/document_mobject.py @@ -4,13 +4,13 @@ rd = RichDocument.from_document_file("https://arxiv.org/pdf/1906.04043") -from mellea.stdlib.components.docs.richdocument import Table # noqa: E402 +from mellea.stdlib.components.docs.richdocument import Table table1: Table = rd.get_tables()[0] print(table1.to_markdown()) -from mellea import start_session # noqa: E402 -from mellea.backends import ModelOption # noqa: E402 +from mellea import start_session +from mellea.backends import ModelOption m = start_session(model_id=model_ids.META_LLAMA_3_2_3B) for seed in [x * 12 for x in range(5)]: diff --git a/docs/examples/tutorial/instruct_validate_repair.py b/docs/examples/tutorial/instruct_validate_repair.py index 76113d0f..75985d69 100644 --- a/docs/examples/tutorial/instruct_validate_repair.py +++ b/docs/examples/tutorial/instruct_validate_repair.py @@ -9,8 +9,8 @@ check("Do not mention purple elephants."), # == r3 ] -import mellea # noqa: E402 -from mellea.stdlib.sampling import RejectionSamplingStrategy # noqa: E402 +import mellea +from mellea.stdlib.sampling import RejectionSamplingStrategy def write_email(m: mellea.MelleaSession, name: str, notes: str) -> str: diff --git a/docs/examples/tutorial/mcp_example.py b/docs/examples/tutorial/mcp_example.py deleted file mode 100644 index 25bf5199..00000000 --- a/docs/examples/tutorial/mcp_example.py +++ /dev/null @@ -1,54 +0,0 @@ -"""Example of an MCP server. - -You need to install the mcp package: -uv pip install "mcp[cli]" - -and run the example in MCP debug UI: -uv run mcp dev docs/examples/tutorial/mcp_example.py -""" - -from mcp.server.fastmcp import FastMCP - -from mellea import MelleaSession -from mellea.backends import ModelOption, model_ids -from mellea.backends.ollama import OllamaModelBackend -from mellea.core import ModelOutputThunk -from mellea.stdlib.requirements import requirement, simple_validate -from mellea.stdlib.sampling import RejectionSamplingStrategy - -# ################# -# run MCP debug UI with: uv run mcp dev docs/examples/tutorial/mcp_example.py -# ################## - - -# Create an MCP server -mcp = FastMCP("Demo") - - -@mcp.tool() -def write_a_poem(word_limit: int) -> str: - """Write a poem with a word limit.""" - m = MelleaSession( - OllamaModelBackend( - model_ids.HF_SMOLLM2_2B, - model_options={ModelOption.MAX_NEW_TOKENS: word_limit + 10}, - ) - ) - wl_req = Requirement( - f"Use only {word_limit} words.", - validation_fn=simple_validate(lambda x: len(x.split(" ")) < word_limit), - ) - - res = m.instruct( - "Write a poem", - requirements=[wl_req], - strategy=RejectionSamplingStrategy(loop_budget=2), - ) - assert isinstance(res, ModelOutputThunk) - return str(res.value) - - -@mcp.resource("greeting://{name}") -def get_greeting(name: str) -> str: - """Get a personalized greeting.""" - return f"Hello, {name}!" diff --git a/docs/examples/tutorial/model_options_example.py b/docs/examples/tutorial/model_options_example.py index 7eb88b9a..861cbe90 100644 --- a/docs/examples/tutorial/model_options_example.py +++ b/docs/examples/tutorial/model_options_example.py @@ -1,7 +1,6 @@ import mellea -from mellea.backends import model_ids +from mellea.backends import ModelOption, model_ids from mellea.backends.ollama import OllamaModelBackend -from mellea.backends import ModelOption m = mellea.MelleaSession( backend=OllamaModelBackend(model_options={ModelOption.SEED: 42}) diff --git a/docs/examples/tutorial/simple_email.py b/docs/examples/tutorial/simple_email.py index 0de7c772..cbaafecb 100644 --- a/docs/examples/tutorial/simple_email.py +++ b/docs/examples/tutorial/simple_email.py @@ -52,7 +52,7 @@ def write_email_with_requirements( ) print("Email with rejection sampling:") -from mellea.stdlib.sampling import RejectionSamplingStrategy # noqa: E402 +from mellea.stdlib.sampling import RejectionSamplingStrategy def write_email_with_strategy(m: mellea.MelleaSession, name: str, notes: str) -> str: diff --git a/docs/kv_smash/hf_example.py b/docs/kv_smash/hf_example.py index dc81a64a..c5db6967 100644 --- a/docs/kv_smash/hf_example.py +++ b/docs/kv_smash/hf_example.py @@ -1,10 +1,11 @@ +import asyncio + +from mellea.backends import ModelOption from mellea.backends.huggingface import LocalHFBackend from mellea.backends.model_ids import IBM_GRANITE_3_3_8B -from mellea.backends import ModelOption from mellea.core import CBlock -from mellea.stdlib.context import ChatContext from mellea.stdlib.components import Message -import asyncio +from mellea.stdlib.context import ChatContext async def example(): diff --git a/docs/kv_smash/kv_with_chat.py b/docs/kv_smash/kv_with_chat.py index e0f43bc4..e6ab5a05 100644 --- a/docs/kv_smash/kv_with_chat.py +++ b/docs/kv_smash/kv_with_chat.py @@ -49,7 +49,9 @@ def merge(toks, dcs): {"role": "user", "content": c_blocks[1]}, {"role": "user", "content": "Also no cash"}, ] -templatized_input = tokenizer.apply_chat_template(conversation=messages, tokenize=False) +templatized_input: str = tokenizer.apply_chat_template( + conversation=messages, tokenize=False +) str_parts = [] tok_parts = [] diff --git a/docs/kv_smash/kvcache.py b/docs/kv_smash/kvcache.py index 94bfcc58..51e1b5cc 100644 --- a/docs/kv_smash/kvcache.py +++ b/docs/kv_smash/kvcache.py @@ -5,10 +5,9 @@ # ] # /// import torch +from transformers import AutoModelForCausalLM, AutoTokenizer, PreTrainedTokenizer from mellea.backends.kv_block_helpers import DynamicCache, merge_dynamic_caches -from transformers import AutoModelForCausalLM, PreTrainedTokenizer, AutoTokenizer -import torch model_id = "ibm-granite/granite-3.3-8b-instruct" device = torch.device("mps") diff --git a/docs/rewrite/session_deepdive/0.py b/docs/rewrite/session_deepdive/step0_session_api.py similarity index 99% rename from docs/rewrite/session_deepdive/0.py rename to docs/rewrite/session_deepdive/step0_session_api.py index 82d15776..271586b9 100644 --- a/docs/rewrite/session_deepdive/0.py +++ b/docs/rewrite/session_deepdive/step0_session_api.py @@ -1,7 +1,6 @@ from mellea import MelleaSession -from mellea.stdlib.context import SimpleContext from mellea.backends.ollama import OllamaModelBackend - +from mellea.stdlib.context import SimpleContext m = MelleaSession(backend=OllamaModelBackend("granite4:latest"), ctx=SimpleContext()) response = m.chat("What is 1+1?") diff --git a/docs/rewrite/session_deepdive/1.py b/docs/rewrite/session_deepdive/step1_functional_api.py similarity index 100% rename from docs/rewrite/session_deepdive/1.py rename to docs/rewrite/session_deepdive/step1_functional_api.py index c12e7df6..e629e840 100644 --- a/docs/rewrite/session_deepdive/1.py +++ b/docs/rewrite/session_deepdive/step1_functional_api.py @@ -1,6 +1,6 @@ import mellea.stdlib.functional as mfuncs -from mellea.stdlib.context import SimpleContext from mellea.backends.ollama import OllamaModelBackend +from mellea.stdlib.context import SimpleContext response, next_context = mfuncs.chat( "What is 1+1?", diff --git a/docs/rewrite/session_deepdive/2.py b/docs/rewrite/session_deepdive/step2_act_cblocks.py similarity index 90% rename from docs/rewrite/session_deepdive/2.py rename to docs/rewrite/session_deepdive/step2_act_cblocks.py index 8002128d..30845436 100644 --- a/docs/rewrite/session_deepdive/2.py +++ b/docs/rewrite/session_deepdive/step2_act_cblocks.py @@ -1,10 +1,10 @@ import mellea.stdlib.functional as mfuncs -from mellea.stdlib.context import SimpleContext -from mellea.core import CBlock from mellea.backends.ollama import OllamaModelBackend +from mellea.core import CBlock +from mellea.stdlib.context import SimpleContext response, next_context = mfuncs.act( - CBlock("What is 1+1?"), + action=CBlock("What is 1+1?"), context=SimpleContext(), backend=OllamaModelBackend("granite4:latest"), ) diff --git a/docs/rewrite/session_deepdive/3.py b/docs/rewrite/session_deepdive/step3_async.py similarity index 65% rename from docs/rewrite/session_deepdive/3.py rename to docs/rewrite/session_deepdive/step3_async.py index 1d522f77..02f41887 100644 --- a/docs/rewrite/session_deepdive/3.py +++ b/docs/rewrite/session_deepdive/step3_async.py @@ -1,13 +1,14 @@ +import asyncio + import mellea.stdlib.functional as mfuncs -from mellea.core import CBlock, Context, Backend -from mellea.stdlib.context import SimpleContext from mellea.backends.ollama import OllamaModelBackend -import asyncio +from mellea.core import Backend, CBlock, Context +from mellea.stdlib.context import SimpleContext async def main(backend: Backend, ctx: Context): - response, next_context = await mfuncs.aact( - CBlock("What is 1+1?"), context=ctx, backend=backend + response, _next_context = await mfuncs.aact( + action=CBlock("What is 1+1?"), context=ctx, backend=backend ) print(response.value) diff --git a/docs/rewrite/session_deepdive/4.py b/docs/rewrite/session_deepdive/step4_lazy_thunks.py similarity index 85% rename from docs/rewrite/session_deepdive/4.py rename to docs/rewrite/session_deepdive/step4_lazy_thunks.py index 9cd71dd8..1dc3871f 100644 --- a/docs/rewrite/session_deepdive/4.py +++ b/docs/rewrite/session_deepdive/step4_lazy_thunks.py @@ -1,13 +1,14 @@ -from mellea.core import CBlock, Context, Backend +import asyncio + from mellea.backends.ollama import OllamaModelBackend +from mellea.core import Backend, CBlock, Context from mellea.stdlib.context import SimpleContext -import asyncio async def main(backend: Backend, ctx: Context): # This is not actually an async function; the computation ends immediately. It must be awaited because we create the thunk. # TODO clean up the above comment. - response, next_context = await backend.generate_from_context( + response, _next_context = await backend.generate_from_context( CBlock("What is 1+1?"), ctx=ctx, # TODO we should rationalize ctx and context acress mfuncs and base/backend. ) diff --git a/docs/rewrite/session_deepdive/5.py b/docs/rewrite/session_deepdive/step5_composition.py similarity index 95% rename from docs/rewrite/session_deepdive/5.py rename to docs/rewrite/session_deepdive/step5_composition.py index f03369c4..2494f5c8 100644 --- a/docs/rewrite/session_deepdive/5.py +++ b/docs/rewrite/session_deepdive/step5_composition.py @@ -1,10 +1,10 @@ -from mellea.core import CBlock, Context, Backend +import asyncio + from mellea.backends.ollama import OllamaModelBackend +from mellea.core import Backend, CBlock, Context from mellea.stdlib.components import SimpleComponent from mellea.stdlib.context import SimpleContext -import asyncio - async def main(backend: Backend, ctx: Context): x, _ = await backend.generate_from_context(CBlock("What is 1+1?"), ctx=ctx) diff --git a/docs/rewrite/streaming/0.py b/docs/rewrite/streaming/streaming_chat_example.py similarity index 100% rename from docs/rewrite/streaming/0.py rename to docs/rewrite/streaming/streaming_chat_example.py index 2098ea3d..3235ce28 100644 --- a/docs/rewrite/streaming/0.py +++ b/docs/rewrite/streaming/streaming_chat_example.py @@ -1,8 +1,8 @@ +import asyncio + from mellea import start_session -from mellea.core.base import CBlock from mellea.backends.model_options import ModelOption - -import asyncio +from mellea.core.base import CBlock async def stream_chat(prompt: str) -> str: diff --git a/mellea/backends/huggingface.py b/mellea/backends/huggingface.py index 48e74543..c72b0b5a 100644 --- a/mellea/backends/huggingface.py +++ b/mellea/backends/huggingface.py @@ -634,7 +634,7 @@ async def _generate_from_context_with_kv_cache( linearized_ctx = ctx.view_for_generation() assert linearized_ctx is not None - input_text, input_ids, merged_cache, attention_mask = ( + _input_text, input_ids, merged_cache, attention_mask = ( self._make_merged_kv_cache( linearized_ctx=linearized_ctx, ctx_as_conversation=ctx_as_chat, diff --git a/pyproject.toml b/pyproject.toml index 3e71a837..cb0c14b5 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -174,7 +174,7 @@ ignore = [ # "UP006", # List vs list, etc # "UP007", # Option and Union # "UP035", # `typing.Set` is deprecated, use `set` instead" - "PD901", # Avoid using the generic variable name `df` for DataFrames + "PD901", # Generic variable name 'df' for DataFrames (deprecated rule, but needed while PD is enabled) "C901", # Complexity warnings ] @@ -188,6 +188,34 @@ max-complexity = 20 combine-as-imports = true split-on-trailing-comma = false +[tool.ruff.lint.per-file-ignores] +# E402: Module level import not at top of file +# Intentional in examples (pedagogical structure) and tests (pytestmark before imports) +# D: Docstring errors +# TODO: Add docstrings to cli/ (94 errors across 16 files) +# Not required in examples, tests, and notebooks (core mellea/ has complete docstrings) +"docs/**/*.py" = ["E402", "D"] +"docs/**/*.ipynb" = ["D"] +"test/**/*.py" = ["E402", "D"] +"cli/**/*.py" = ["D"] + +[[tool.mypy.overrides]] +# TODO: Fix type errors in docs/ examples (25 errors across 15 files) +# Temporarily relaxed type checking for docs/ to allow examples to focus on +# demonstrating functionality rather than perfect type correctness. +# Main issues: Optional/None handling, CBlock vs Component confusion, type mismatches +module = "docs.*" +disable_error_code = [ + "import-not-found", + "arg-type", + "attr-defined", + "assignment", + "operator", + "call-overload", + "index", + "return-value", +] + [tool.codespell] ignore-words-list = 'mellea,hashi,noo,Asai,asai,nd,mot,rouge,Rouge,Strat' check-filenames = true diff --git a/test/__init__.py b/test/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/test/backends/test_adapters/test_adapter.py b/test/backends/test_adapters/test_adapter.py index 632cf7cf..e9022164 100644 --- a/test/backends/test_adapters/test_adapter.py +++ b/test/backends/test_adapters/test_adapter.py @@ -1,4 +1,5 @@ import pathlib + import pytest from mellea.backends.adapters import GraniteCommonAdapter diff --git a/test/backends/test_huggingface.py b/test/backends/test_huggingface.py index 8add07aa..6b5bfb85 100644 --- a/test/backends/test_huggingface.py +++ b/test/backends/test_huggingface.py @@ -72,7 +72,7 @@ def session(backend): @pytest.mark.qualitative -def test_adapters(backend): +def test_adapters(backend) -> None: assert len(backend._added_adapters.items()) > 0 expected_qualified_name = "requirement_check_alora" @@ -90,7 +90,7 @@ def test_adapters(backend): @pytest.mark.qualitative -def test_system_prompt(session): +def test_system_prompt(session) -> None: result = session.chat( "Where are we going?", model_options={ModelOption.SYSTEM_PROMPT: "Talk like a pirate."}, @@ -99,8 +99,8 @@ def test_system_prompt(session): @pytest.mark.qualitative -def test_constraint_lora_with_requirement(session, backend): - answer = session.instruct( +def test_constraint_lora_with_requirement(session, backend) -> None: + session.instruct( "Corporate wants you to find the difference between these two strings: aaaaaaaaaa aaaaabaaaa" ) assert session.backend._cache is not None # type: ignore @@ -116,9 +116,9 @@ def test_constraint_lora_with_requirement(session, backend): @pytest.mark.qualitative -def test_constraint_lora_override(session, backend): +def test_constraint_lora_override(session, backend) -> None: backend.default_to_constraint_checking_alora = False # type: ignore - answer = session.instruct( + session.instruct( "Corporate wants you to find the difference between these two strings: aaaaaaaaaa aaaaabaaaa" ) validation_outputs = session.validate( @@ -132,9 +132,9 @@ def test_constraint_lora_override(session, backend): @pytest.mark.qualitative -def test_constraint_lora_override_does_not_override_alora(session, backend): +def test_constraint_lora_override_does_not_override_alora(session, backend) -> None: backend.default_to_constraint_checking_alora = False # type: ignore - answer = session.instruct( + session.instruct( "Corporate wants you to find the difference between these two strings: aaaaaaaaaa aaaaabaaaa" ) validation_outputs = session.validate( @@ -158,9 +158,9 @@ def test_constraint_lora_override_does_not_override_alora(session, backend): @pytest.mark.qualitative -def test_llmaj_req_does_not_use_alora(session, backend): +def test_llmaj_req_does_not_use_alora(session, backend) -> None: backend.default_to_constraint_checking_alora = True # type: ignore - answer = session.instruct( + session.instruct( "Corporate wants you to find the difference between these two strings: aaaaaaaaaa aaaaabaaaa" ) validation_outputs = session.validate( @@ -176,15 +176,15 @@ def test_llmaj_req_does_not_use_alora(session, backend): @pytest.mark.qualitative -def test_instruct(session): +def test_instruct(session) -> None: result = session.instruct("Compute 1+1.") print(result) @pytest.mark.qualitative -def test_multiturn(session): +def test_multiturn(session) -> None: session.instruct("Compute 1+1") - beta = session.instruct( + session.instruct( "Take the result of the previous sum and find the corresponding letter in the greek alphabet.", model_options={ModelOption.MAX_NEW_TOKENS: 300}, ) @@ -193,7 +193,7 @@ def test_multiturn(session): @pytest.mark.qualitative -def test_chat(session): +def test_chat(session) -> None: output_message = session.chat("What is 1+1?") assert "2" in output_message.content, ( f"Expected a message with content containing 2 but found {output_message}" @@ -201,7 +201,7 @@ def test_chat(session): @pytest.mark.qualitative -def test_format(session): +def test_format(session) -> None: class Person(pydantic.BaseModel): name: str email_address: Annotated[ @@ -235,7 +235,7 @@ class Email(pydantic.BaseModel): @pytest.mark.qualitative -async def test_generate_from_raw(session): +async def test_generate_from_raw(session) -> None: prompts = [ "what is 1+1?", "what is 2+2?", @@ -253,7 +253,7 @@ async def test_generate_from_raw(session): @pytest.mark.qualitative -async def test_generate_from_raw_with_format(session): +async def test_generate_from_raw_with_format(session) -> None: prompts = ["what is 1+1?", "what is 2+2?", "what is 3+3?", "what is 4+4?"] class Answer(pydantic.BaseModel): @@ -270,7 +270,7 @@ class Answer(pydantic.BaseModel): random_result = results[0] try: - answer = Answer.model_validate_json(random_result.value) + Answer.model_validate_json(random_result.value) except pydantic.ValidationError as e: assert False, ( f"formatting directive failed for {random_result.value}: {e.json()}" @@ -278,7 +278,7 @@ class Answer(pydantic.BaseModel): @pytest.mark.qualitative -async def test_async_parallel_requests(session): +async def test_async_parallel_requests(session) -> None: model_opts = {ModelOption.STREAM: True} mot1, _ = await session.backend.generate_from_context( CBlock("Say Hello."), SimpleContext(), model_options=model_opts @@ -314,7 +314,7 @@ async def test_async_parallel_requests(session): @pytest.mark.qualitative -async def test_async_avalue(session): +async def test_async_avalue(session) -> None: mot1, _ = await session.backend.generate_from_context( CBlock("Say Hello."), SimpleContext() ) @@ -324,7 +324,7 @@ async def test_async_avalue(session): @pytest.mark.qualitative -async def test_generate_with_lock(backend): +async def test_generate_with_lock(backend) -> None: # Enable the faulthandler for this test. faulthandler.enable(all_threads=True) @@ -343,7 +343,7 @@ async def test_generate_with_lock(backend): GraniteCommonAdapter("answerability", base_model_name=b.base_model_name) ) - memoized = dict() + memoized: dict[torch.Tensor, str] = dict() gen_func = model.generate def mock_func(input_ids, *args, **kwargs): @@ -414,7 +414,7 @@ def call_backend_generate(): @pytest.mark.skipif( sys.version_info < (3, 11), reason="asyncio.timeout requires python3.11 or higher" ) -async def test_generate_with_lock_does_not_block_when_awaiting_value(backend): +async def test_generate_with_lock_does_not_block_when_awaiting_value(backend) -> None: """This is a tricky test to setup. It's purpose is to ensure that a long-running generation doesn't get blocked @@ -470,8 +470,7 @@ async def test_generate_with_lock_does_not_block_when_awaiting_value(backend): # most likely due to a deadlock caused by awaiting a generation that cannot complete until # the streaming is done. try: - async with asyncio.timeout(timeout_in_seconds): - await req_mot.avalue() + await asyncio.wait_for(req_mot.avalue(), timeout=timeout_in_seconds) except Exception as e: # The timeout could also be caused by the generation taking too long... be careful! # We assume that if the streaming model output thunk is computed after getting its astream here, @@ -488,7 +487,7 @@ async def test_generate_with_lock_does_not_block_when_awaiting_value(backend): @pytest.mark.qualitative -async def test_error_during_generate_with_lock(backend): +async def test_error_during_generate_with_lock(backend) -> None: # Create local versions of these objects so that mocking # doesn't impact other functions. Don't do this in regular code, # the copying is complex. @@ -529,7 +528,7 @@ def generate_and_raise_exc(*args, **kwargs): await req_mot.avalue() -def test_assert_correct_adapters(): +def test_assert_correct_adapters() -> None: model = Mock() # Test scenarios with no active adapters. diff --git a/test/backends/test_litellm_ollama.py b/test/backends/test_litellm_ollama.py index 2fb70e19..7999528a 100644 --- a/test/backends/test_litellm_ollama.py +++ b/test/backends/test_litellm_ollama.py @@ -132,7 +132,6 @@ def test_gen_slot(session): @generative def is_happy(text: str) -> bool: """Determine if text is of happy mood.""" - ... h = is_happy(session, text="I'm enjoying life.") diff --git a/test/backends/test_litellm_watsonx.py b/test/backends/test_litellm_watsonx.py index 014eeb2b..80f65b09 100644 --- a/test/backends/test_litellm_watsonx.py +++ b/test/backends/test_litellm_watsonx.py @@ -23,7 +23,7 @@ def session(): session.reset() -def test_has_potential_event_loop_errors(session): +def test_has_potential_event_loop_errors(session) -> None: """This test is specific to litellm backends that use watsonx/. It can be removed once that bug is fixed.""" backend: LiteLLMBackend = session.backend potential_err = backend._has_potential_event_loop_errors() @@ -42,13 +42,13 @@ async def new_event_loop() -> bool: @pytest.mark.qualitative -def test_multiple_sync_funcs(session): +def test_multiple_sync_funcs(session) -> None: session.chat("first") session.chat("second") @pytest.mark.qualitative -async def test_generate_from_raw(session): +async def test_generate_from_raw(session) -> None: prompts = ["what is 1+1?", "what is 2+2?", "what is 3+3?", "what is 4+2+2?"] results = await session.backend.generate_from_raw( @@ -65,7 +65,7 @@ async def test_generate_from_raw(session): @pytest.mark.xfail( reason="litellm has a bug with watsonx; once that is fixed, this should pass." ) -async def test_multiple_async_funcs(session): +async def test_multiple_async_funcs(session) -> None: """If this test passes, remove the _has_potential_event_loop_errors func from litellm.""" session.chat( "first sync" diff --git a/test/backends/test_model_options.py b/test/backends/test_model_options.py index 5a75d7f3..36c1c228 100644 --- a/test/backends/test_model_options.py +++ b/test/backends/test_model_options.py @@ -1,4 +1,5 @@ import pytest + from mellea.backends import ModelOption diff --git a/test/backends/test_ollama.py b/test/backends/test_ollama.py index 922777fb..fcca7fcd 100644 --- a/test/backends/test_ollama.py +++ b/test/backends/test_ollama.py @@ -25,7 +25,7 @@ def session(): @pytest.mark.qualitative -def test_simple_instruct(session): +def test_simple_instruct(session) -> None: result = session.instruct( "Write an email to Hendrik trying to sell him self-sealing stembolts." ) @@ -37,8 +37,8 @@ def test_simple_instruct(session): @pytest.mark.qualitative -def test_instruct_with_requirement(session): - response = session.instruct( +def test_instruct_with_requirement(session) -> None: + session.instruct( "Write an email to Hendrik convincing him to buy some self-sealing stembolts." ) @@ -61,7 +61,7 @@ def test_instruct_with_requirement(session): @pytest.mark.qualitative -def test_chat(session): +def test_chat(session) -> None: output_message = session.chat("What is 1+1?") assert "2" in output_message.content, ( f"Expected a message with content containing 2 but found {output_message}" @@ -69,7 +69,7 @@ def test_chat(session): @pytest.mark.qualitative -def test_format(session): +def test_format(session) -> None: class Person(pydantic.BaseModel): name: str # it does not support regex patterns in json schema @@ -102,7 +102,7 @@ class Email(pydantic.BaseModel): @pytest.mark.qualitative -async def test_generate_from_raw(session): +async def test_generate_from_raw(session) -> None: prompts = ["what is 1+1?", "what is 2+2?", "what is 3+3?", "what is 4+4?"] results = await session.backend.generate_from_raw( @@ -114,7 +114,7 @@ async def test_generate_from_raw(session): @pytest.mark.xfail(reason="ollama sometimes fails generated structured outputs") -async def test_generate_from_raw_with_format(session): +async def test_generate_from_raw_with_format(session) -> None: prompts = ["what is 1+1?", "what is 2+2?", "what is 3+3?", "what is 4+4?"] class Answer(pydantic.BaseModel): @@ -131,14 +131,14 @@ class Answer(pydantic.BaseModel): random_result = results[0] try: - answer = Answer.model_validate_json(random_result.value) + Answer.model_validate_json(random_result.value) except pydantic.ValidationError as e: assert False, ( f"formatting directive failed for {random_result.value}: {e.json()}" ) -async def test_async_parallel_requests(session): +async def test_async_parallel_requests(session) -> None: model_opts = {ModelOption.STREAM: True} mot1, _ = await session.backend.generate_from_context( CBlock("Say Hello."), SimpleContext(), model_options=model_opts @@ -173,7 +173,7 @@ async def test_async_parallel_requests(session): assert m2_final_val == mot2.value -async def test_async_avalue(session): +async def test_async_avalue(session) -> None: mot1, _ = await session.backend.generate_from_context( CBlock("Say Hello."), SimpleContext() ) @@ -182,7 +182,7 @@ async def test_async_avalue(session): assert m1_final_val == mot1.value -def test_multiple_asyncio_runs(session): +def test_multiple_asyncio_runs(session) -> None: async def test(): result = await session.achat("hello") assert result is not None @@ -191,7 +191,7 @@ async def test(): asyncio.run(test()) -def test_client_cache(session): +def test_client_cache(session) -> None: backend: OllamaModelBackend = session.backend first_client = backend._async_client diff --git a/test/backends/test_openai_ollama.py b/test/backends/test_openai_ollama.py index 86f35174..d95648b0 100644 --- a/test/backends/test_openai_ollama.py +++ b/test/backends/test_openai_ollama.py @@ -39,14 +39,14 @@ def m_session(backend): @pytest.mark.qualitative -def test_instruct(m_session): +def test_instruct(m_session) -> None: result = m_session.instruct("Compute 1+1.") assert isinstance(result, ModelOutputThunk) assert "2" in result.value # type: ignore @pytest.mark.qualitative -def test_multiturn(m_session): +def test_multiturn(m_session) -> None: m_session.instruct("What is the capital of France?") answer = m_session.instruct("Tell me the answer to the previous question.") assert "Paris" in answer.value # type: ignore @@ -68,7 +68,7 @@ def test_multiturn(m_session): @pytest.mark.qualitative -def test_chat(m_session): +def test_chat(m_session) -> None: output_message = m_session.chat("What is 1+1?") assert "2" in output_message.content, ( f"Expected a message with content containing 2 but found {output_message}" @@ -76,7 +76,7 @@ def test_chat(m_session): @pytest.mark.qualitative -def test_format(m_session): +def test_format(m_session) -> None: class Person(pydantic.BaseModel): name: str # it does not support regex patterns in json schema @@ -109,11 +109,11 @@ class Email(pydantic.BaseModel): @pytest.mark.qualitative -async def test_generate_from_raw(m_session): +async def test_generate_from_raw(m_session) -> None: prompts = ["what is 1+1?", "what is 2+2?", "what is 3+3?", "what is 4+4?"] with pytest.raises(openai.BadRequestError): - results = await m_session.backend.generate_from_raw( + await m_session.backend.generate_from_raw( actions=[CBlock(value=prompt) for prompt in prompts], ctx=m_session.ctx ) @@ -141,7 +141,7 @@ async def test_generate_from_raw(m_session): # assert False, f"formatting directive failed for {random_result.value}: {e.json()}" -async def test_async_parallel_requests(m_session): +async def test_async_parallel_requests(m_session) -> None: model_opts = {ModelOption.STREAM: True} mot1, _ = await m_session.backend.generate_from_context( CBlock("Say Hello."), SimpleContext(), model_options=model_opts @@ -176,7 +176,7 @@ async def test_async_parallel_requests(m_session): assert m2_final_val == mot2.value -async def test_async_avalue(m_session): +async def test_async_avalue(m_session) -> None: mot1, _ = await m_session.backend.generate_from_context( CBlock("Say Hello."), SimpleContext() ) @@ -185,7 +185,7 @@ async def test_async_avalue(m_session): assert m1_final_val == mot1.value -def test_client_cache(backend): +def test_client_cache(backend) -> None: first_client = backend._async_client async def get_client_async(): @@ -211,7 +211,7 @@ async def get_client_async(): assert len(backend._client_cache.cache.values()) == 2 -async def test_reasoning_effort_conditional_passing(backend): +async def test_reasoning_effort_conditional_passing(backend) -> None: """Test that reasoning_effort is only passed to API when not None.""" from unittest.mock import AsyncMock, MagicMock, patch @@ -251,7 +251,7 @@ async def test_reasoning_effort_conditional_passing(backend): ) -def test_api_key_and_base_url_from_parameters(): +def test_api_key_and_base_url_from_parameters() -> None: """Test that API key and base URL can be set via parameters.""" backend = OpenAIBackend( model_id="gpt-4", api_key="test-api-key", base_url="https://api.test.com/v1" @@ -260,7 +260,7 @@ def test_api_key_and_base_url_from_parameters(): assert backend._base_url == "https://api.test.com/v1" -def test_parameter_overrides_env_variable(): +def test_parameter_overrides_env_variable() -> None: """Test that explicit parameters override environment variables.""" with patch.dict( os.environ, @@ -275,7 +275,7 @@ def test_parameter_overrides_env_variable(): assert backend._base_url == "https://api.param.com/v1" -def test_missing_api_key_raises_error(): +def test_missing_api_key_raises_error() -> None: """Test that missing API key raises ValueError with helpful message.""" with patch.dict(os.environ, {}, clear=True): with pytest.raises(ValueError) as exc_info: diff --git a/test/backends/test_openai_vllm/test_openai_vllm.py b/test/backends/test_openai_vllm/test_openai_vllm.py index d30c7f7f..2029dfe5 100644 --- a/test/backends/test_openai_vllm/test_openai_vllm.py +++ b/test/backends/test_openai_vllm/test_openai_vllm.py @@ -1,16 +1,16 @@ # test/rits_backend_tests/test_openai_integration.py import os +from typing import Annotated import pydantic import pytest -from typing_extensions import Annotated from mellea import MelleaSession +from mellea.backends import ModelOption from mellea.backends.adapters import GraniteCommonAdapter -from mellea.formatters import TemplateFormatter from mellea.backends.openai import OpenAIBackend -from mellea.backends import ModelOption from mellea.core import CBlock, Context, ModelOutputThunk +from mellea.formatters import TemplateFormatter from mellea.stdlib.context import ChatContext from mellea.stdlib.requirements import ALoraRequirement, LLMaJRequirement @@ -37,14 +37,14 @@ class TestOpenAIBackend: ) m = MelleaSession(backend, ctx=ChatContext()) - def test_instruct(self): + def test_instruct(self) -> None: self.m.reset() result = self.m.instruct("Compute 1+1.") assert isinstance(result, ModelOutputThunk) assert "2" in result.value # type: ignore self.m.reset() - def test_multiturn(self): + def test_multiturn(self) -> None: self.m.instruct("What is the capital of France?") answer = self.m.instruct("Tell me the answer to the previous question.") assert "Paris" in answer.value # type: ignore @@ -65,7 +65,7 @@ def test_multiturn(self): # assert "granite3.3:8b" in result.value # self.m.reset() - def test_format(self): + def test_format(self) -> None: class Person(pydantic.BaseModel): name: str # it does not support regex patterns in json schema @@ -95,9 +95,8 @@ class Email(pydantic.BaseModel): # this is not guaranteed, due to the lack of regexp pattern # assert "@" in email.to.email_address # assert email.to.email_address.endswith("example.com") - pass - async def test_generate_from_raw(self): + async def test_generate_from_raw(self) -> None: prompts = ["what is 1+1?", "what is 2+2?", "what is 3+3?", "what is 4+4?"] results = await self.m.backend.generate_from_raw( @@ -107,7 +106,7 @@ async def test_generate_from_raw(self): assert len(results) == len(prompts) assert results[0].value is not None - async def test_generate_from_raw_with_format(self): + async def test_generate_from_raw_with_format(self) -> None: prompts = ["what is 1+1?", "what is 2+2?", "what is 3+3?", "what is 4+4?"] class Answer(pydantic.BaseModel): @@ -124,7 +123,7 @@ class Answer(pydantic.BaseModel): random_result = results[0] try: - answer = Answer.model_validate_json(random_result.value) # type: ignore + Answer.model_validate_json(random_result.value) # type: ignore except pydantic.ValidationError as e: assert False, ( f"formatting directive failed for {random_result.value}: {e.json()}" @@ -146,7 +145,7 @@ class TestOpenAIALoraStuff: m = MelleaSession(backend, ctx=ChatContext()) - def test_adapters(self): + def test_adapters(self) -> None: assert len(self.backend._added_adapters.items()) > 0 adapter = self.backend._added_adapters["requirement_check_alora"] @@ -161,7 +160,7 @@ def test_adapters(self): self.backend.unload_adapter(adapter.qualified_name) assert adapter.qualified_name not in self.backend._loaded_adapters - def test_system_prompt(self): + def test_system_prompt(self) -> None: self.m.reset() result = self.m.chat( "Where are we going?", @@ -169,9 +168,9 @@ def test_system_prompt(self): ) print(result) - def test_constraint_lora_with_requirement(self): + def test_constraint_lora_with_requirement(self) -> None: self.m.reset() - answer = self.m.instruct( + self.m.instruct( "Corporate wants you to find the difference between these two strings: aaaaaaaaaa aaaaabaaaa" ) validation_outputs = self.m.validate( @@ -184,10 +183,10 @@ def test_constraint_lora_with_requirement(self): assert "requirement_likelihood" in str(val_result.reason) self.m.reset() - def test_constraint_lora_override(self): + def test_constraint_lora_override(self) -> None: self.m.reset() self.backend.default_to_constraint_checking_alora = False # type: ignore - answer = self.m.instruct( + self.m.instruct( "Corporate wants you to find the difference between these two strings: aaaaaaaaaa aaaaabaaaa" ) validation_outputs = self.m.validate( @@ -201,10 +200,10 @@ def test_constraint_lora_override(self): self.backend.default_to_constraint_checking_alora = True self.m.reset() - def test_constraint_lora_override_does_not_override_alora(self): + def test_constraint_lora_override_does_not_override_alora(self) -> None: self.m.reset() self.backend.default_to_constraint_checking_alora = False # type: ignore - answer = self.m.instruct( + self.m.instruct( "Corporate wants you to find the difference between these two strings: aaaaaaaaaa aaaaabaaaa" ) validation_outputs = self.m.validate( @@ -220,6 +219,7 @@ def test_constraint_lora_override_does_not_override_alora(self): # the correct actions / results in it. assert isinstance(non_alora_output.context, Context) assert isinstance(non_alora_output.thunk, ModelOutputThunk) + assert non_alora_output.context.previous_node is not None assert isinstance( non_alora_output.context.previous_node.node_data, ALoraRequirement, # type: ignore @@ -229,10 +229,10 @@ def test_constraint_lora_override_does_not_override_alora(self): self.backend.default_to_constraint_checking_alora = True self.m.reset() - def test_llmaj_req_does_not_use_alora(self): + def test_llmaj_req_does_not_use_alora(self) -> None: self.m.reset() self.backend.default_to_constraint_checking_alora = True # type: ignore - answer = self.m.instruct( + self.m.instruct( "Corporate wants you to find the difference between these two strings: aaaaaaaaaa aaaaabaaaa" ) validation_outputs = self.m.validate( @@ -245,15 +245,15 @@ def test_llmaj_req_does_not_use_alora(self): assert str(non_alora_output.reason) not in ["Y", "N"] self.m.reset() - def test_instruct(self): + def test_instruct(self) -> None: self.m.reset() result = self.m.instruct("Compute 1+1.") print(result) self.m.reset() - def test_multiturn(self): + def test_multiturn(self) -> None: self.m.instruct("Compute 1+1") - beta = self.m.instruct( + self.m.instruct( "Let n be the result of the previous sum. Find the n-th letter in the greek alphabet." ) words = self.m.instruct( @@ -262,7 +262,7 @@ def test_multiturn(self): print(words) self.m.reset() - def test_format(self): + def test_format(self) -> None: class Person(pydantic.BaseModel): name: str email_address: Annotated[ diff --git a/test/backends/test_tool_calls.py b/test/backends/test_tool_calls.py index 958dfe78..ff15cef4 100644 --- a/test/backends/test_tool_calls.py +++ b/test/backends/test_tool_calls.py @@ -1,15 +1,17 @@ +from collections.abc import Callable +from typing import Any + import pytest +from mellea.backends import ModelOption from mellea.backends.ollama import OllamaModelBackend from mellea.backends.tools import ( add_tools_from_context_actions, add_tools_from_model_options, ) -from mellea.backends import ModelOption from mellea.core import ModelOutputThunk -from mellea.stdlib.context import ChatContext - from mellea.stdlib.components.docs.richdocument import Table +from mellea.stdlib.context import ChatContext from mellea.stdlib.session import MelleaSession @@ -44,7 +46,7 @@ def test2(): ... model_opts = {ModelOption.TOOLS: [test1, test2]} - tools = {} + tools: dict[str, Callable[..., Any]] = {} add_tools_from_model_options(tools, model_opts) assert "test1" in tools diff --git a/test/backends/test_tool_helpers.py b/test/backends/test_tool_helpers.py index 4441e885..68693dd2 100644 --- a/test/backends/test_tool_helpers.py +++ b/test/backends/test_tool_helpers.py @@ -1,9 +1,10 @@ import pytest + +from mellea.backends import ModelOption from mellea.backends.tools import ( add_tools_from_context_actions, add_tools_from_model_options, ) -from mellea.backends import ModelOption from mellea.core import CBlock, Component, ModelOutputThunk, TemplateRepresentation diff --git a/test/backends/test_vision_openai.py b/test/backends/test_vision_openai.py index 6285faaf..5220e51f 100644 --- a/test/backends/test_vision_openai.py +++ b/test/backends/test_vision_openai.py @@ -121,9 +121,13 @@ def test_image_block_in_instruction( image_url = content_img.get("image_url") assert image_url is not None assert "url" in image_url + assert isinstance(image_url, dict) # check that the image is in the url content - assert image_block.value[:100] in image_url["url"] + url_value = image_url["url"] + assert isinstance(url_value, str) + assert image_block.value is not None + assert image_block.value[:100] in url_value def test_image_block_in_chat( @@ -174,9 +178,14 @@ def test_image_block_in_chat( image_url = content_img.get("image_url") assert image_url is not None assert "url" in image_url + assert isinstance(image_url, dict) # check that the image is in the url content - assert ImageBlock.from_pil_image(pil_image).value[:100] in image_url["url"] + image_value = ImageBlock.from_pil_image(pil_image).value + assert image_value is not None + url_value = image_url["url"] + assert isinstance(url_value, str) + assert image_value[:100] in url_value if __name__ == "__main__": diff --git a/test/backends/test_vllm.py b/test/backends/test_vllm.py index ed4a0354..c4f49a20 100644 --- a/test/backends/test_vllm.py +++ b/test/backends/test_vllm.py @@ -55,7 +55,7 @@ def session(backend): @pytest.mark.qualitative -def test_system_prompt(session): +def test_system_prompt(session) -> None: result = session.chat( "Where are we going?", model_options={ModelOption.SYSTEM_PROMPT: "Talk like a pirate."}, @@ -64,15 +64,15 @@ def test_system_prompt(session): @pytest.mark.qualitative -def test_instruct(session): +def test_instruct(session) -> None: result = session.instruct("Compute 1+1.") print(result) @pytest.mark.qualitative -def test_multiturn(session): +def test_multiturn(session) -> None: session.instruct("Compute 1+1") - beta = session.instruct( + session.instruct( "Take the result of the previous sum and find the corresponding letter in the greek alphabet." ) words = session.instruct("Now list five English words that start with that letter.") @@ -80,7 +80,7 @@ def test_multiturn(session): @pytest.mark.qualitative -def test_format(session): +def test_format(session) -> None: class Person(pydantic.BaseModel): name: str email_address: Annotated[ @@ -111,7 +111,7 @@ class Email(pydantic.BaseModel): @pytest.mark.qualitative -async def test_generate_from_raw(session): +async def test_generate_from_raw(session) -> None: prompts = ["what is 1+1?", "what is 2+2?", "what is 3+3?", "what is 4+4?"] results = await session.backend.generate_from_raw( @@ -123,7 +123,7 @@ async def test_generate_from_raw(session): @pytest.mark.qualitative -async def test_generate_from_raw_with_format(session): +async def test_generate_from_raw_with_format(session) -> None: prompts = ["what is 1+1?", "what is 2+2?", "what is 3+3?", "what is 4+4?"] class Answer(pydantic.BaseModel): @@ -140,7 +140,7 @@ class Answer(pydantic.BaseModel): random_result = results[0] try: - answer = Answer.model_validate_json(random_result.value) + Answer.model_validate_json(random_result.value) except pydantic.ValidationError as e: assert False, ( f"formatting directive failed for {random_result.value}: {e.json()}" @@ -148,7 +148,7 @@ class Answer(pydantic.BaseModel): @pytest.mark.qualitative -def test_async_parallel_requests(session): +def test_async_parallel_requests(session) -> None: async def parallel_requests(): model_opts = {ModelOption.STREAM: True} mot1, _ = await session.backend.generate_from_context( @@ -187,7 +187,7 @@ async def parallel_requests(): @pytest.mark.qualitative -def test_async_avalue(session): +def test_async_avalue(session) -> None: async def avalue(): mot1, _ = await session.backend.generate_from_context( CBlock("Say Hello."), SimpleContext() diff --git a/test/core/test_base.py b/test/core/test_base.py index 30aa6f7f..cada2f42 100644 --- a/test/core/test_base.py +++ b/test/core/test_base.py @@ -1,5 +1,7 @@ from typing import Any + import pytest + from mellea.core import CBlock, Component, ModelOutputThunk from mellea.stdlib.components import Message diff --git a/test/core/test_component_typing.py b/test/core/test_component_typing.py index 6a9c2a82..a20de283 100644 --- a/test/core/test_component_typing.py +++ b/test/core/test_component_typing.py @@ -44,7 +44,7 @@ def _parse(self, computed: ModelOutputThunk) -> int: return -1 try: return int(computed.value) - except: + except Exception: return -2 @@ -80,7 +80,7 @@ def test_mot_init_typing(): assert hasattr(mot, "__orig_class__"), ( "mots are generics and should have this field" ) - assert get_args(mot.__orig_class__)[0] == float, ( # type: ignore + assert get_args(mot.__orig_class__)[0] is float, ( # type: ignore f"expected float, got {get_args(mot.__orig_class__)[0]} as mot type" # type: ignore ) # type: ignore @@ -113,7 +113,7 @@ def test_component_parsing_fails(): def test_incorrect_type_override(): with pytest.raises(TypeError): - instruction = Instruction[int](description="this is an instruction") # type: ignore + Instruction[int](description="this is an instruction") # type: ignore # Marking as qualitative for now since there's so much generation required for this. diff --git a/test/formatters/test_template_formatter.py b/test/formatters/test_template_formatter.py index c29fcc4a..5d851053 100644 --- a/test/formatters/test_template_formatter.py +++ b/test/formatters/test_template_formatter.py @@ -2,15 +2,13 @@ import os import sys import tempfile -from typing import List import pytest -from mellea.formatters import TemplateFormatter -from mellea.backends.model_ids import ModelIdentifier, IBM_GRANITE_3_2_8B +from mellea.backends.model_ids import IBM_GRANITE_3_2_8B, ModelIdentifier from mellea.core import CBlock, Component, ModelOutputThunk, TemplateRepresentation -from mellea.stdlib.components import Message, Instruction -from mellea.stdlib.components import MObject +from mellea.formatters import TemplateFormatter +from mellea.stdlib.components import Instruction, Message, MObject @pytest.fixture(scope="module") @@ -100,7 +98,8 @@ def test_user_path(instr: Instruction): """Ensures that paths with no templates don't prevent default template lookups. Also creates a temporary dir to use as a user-specified dir and ensures template lookup - logic is correct.""" + logic is correct. + """ tf = TemplateFormatter( "granite3.3", template_path="/fake/path", use_template_cache=False ) @@ -160,7 +159,7 @@ def test_no_module(tf: TemplateFormatter): def test_no_template(tf: TemplateFormatter): class _NoTemplate(Component[str]): - def parts(self) -> List[Component | CBlock]: + def parts(self) -> list[Component | CBlock]: return [] def format_for_llm(self) -> TemplateRepresentation: @@ -214,7 +213,8 @@ def test_empty_model_id(instr: Instruction): def test_template_caching(instr: Instruction): """Caching shouldn't be interacted with this way by users. - Only toggling these internal variables to test code paths.""" + Only toggling these internal variables to test code paths. + """ tf = TemplateFormatter("default", use_template_cache=True) assert tf._template_cache is not None @@ -236,8 +236,8 @@ def test_template_caching(instr: Instruction): def test_custom_component_external_package(tf: TemplateFormatter): """Creates a fake package with a custom component and loads the package. - Ensures template loading works for custom components defined in other packages.""" - + Ensures template loading works for custom components defined in other packages. + """ new_component_content = """ from mellea.core import Component, TemplateRepresentation, ModelOutputThunk class NewComponent(Component[str]): diff --git a/test/stdlib/components/docs/test_richdocument.py b/test/stdlib/components/docs/test_richdocument.py index f96038c3..3538beac 100644 --- a/test/stdlib/components/docs/test_richdocument.py +++ b/test/stdlib/components/docs/test_richdocument.py @@ -1,10 +1,12 @@ import os -from mellea.core import TemplateRepresentation -from mellea.stdlib.components.docs.richdocument import RichDocument, Table -import mellea -from docling_core.types.doc.document import DoclingDocument import tempfile + import pytest +from docling_core.types.doc.document import DoclingDocument + +import mellea +from mellea.core import TemplateRepresentation +from mellea.stdlib.components.docs.richdocument import RichDocument, Table @pytest.fixture(scope="module") @@ -22,12 +24,12 @@ def test_richdocument_basics(rd: RichDocument): ) repr = rd.format_for_llm() - assert type(repr) == str, "rich document template args should be a dict" + assert isinstance(repr, str), "rich document template args should be a dict" def test_richdocument_markdown(rd: RichDocument): mkd = rd.to_markdown() - assert type(mkd) == str, "rich document `to_markdown` should be a string" + assert isinstance(mkd, str), "rich document `to_markdown` should be a string" assert "Bag of Words" in mkd, "expected string not in rd `to_markdown` output" @@ -48,7 +50,7 @@ def test_table(rd: RichDocument): # Getting the tables technically tests the functionality of richdocument, # but we do it here to make it easier. The provided document has one table. tables = rd.get_tables() - assert all(type(t) == Table for t in tables), ( + assert all(isinstance(t, Table) for t in tables), ( f"rich document `get_tables` returned a non-table value: {tables}" ) assert len(tables) > 0, ( @@ -57,13 +59,15 @@ def test_table(rd: RichDocument): table = tables[0] repr = table.format_for_llm() - assert type(repr) == TemplateRepresentation, "table template args should be a dict" + assert isinstance(repr, TemplateRepresentation), ( + "table template args should be a dict" + ) assert "table" in repr.args.keys() and len(repr.args.keys()) == 1, ( "table's should have a single `as_markdown` key" ) mkd_table = table.to_markdown() - assert type(mkd_table) == str, "table `to_markdown` should return a string" + assert isinstance(mkd_table, str), "table `to_markdown` should return a string" loaded_table = Table.from_markdown(mkd_table) assert loaded_table is not None, ( @@ -97,7 +101,7 @@ def test_empty_table(): def test_richdocument_generation(rd: RichDocument): m = mellea.start_session(backend_name="hf") response = m.chat(rd.to_markdown()[:500] + "\nSummarize the provided document.") - assert response.content is not "", ( + assert response.content != "", ( "response content should not be empty when summarizing a rich document" ) assert "paper" in response.content.lower() or "gltr" in response.content.lower(), ( diff --git a/test/stdlib/components/test_chat.py b/test/stdlib/components/test_chat.py index 1733e8e6..66ebb9fc 100644 --- a/test/stdlib/components/test_chat.py +++ b/test/stdlib/components/test_chat.py @@ -1,7 +1,7 @@ import pytest -from mellea.stdlib.components import Document -from mellea.stdlib.components import Message + from mellea.helpers import messages_to_docs +from mellea.stdlib.components import Document, Message def test_message_with_docs(): diff --git a/test/stdlib/components/test_genslot.py b/test/stdlib/components/test_genslot.py index ba956507..d3814ae4 100644 --- a/test/stdlib/components/test_genslot.py +++ b/test/stdlib/components/test_genslot.py @@ -78,7 +78,6 @@ def test_sentiment_output(classify_sentiment_output): def test_gen_slot_logs(classify_sentiment_output, session): - sent = classify_sentiment_output last_prompt = session.last_prompt()[-1] assert isinstance(last_prompt, dict) assert set(last_prompt.keys()) == {"role", "content", "images"} diff --git a/test/stdlib/components/test_mify.py b/test/stdlib/components/test_mify.py index 0587811a..d03ce9fd 100644 --- a/test/stdlib/components/test_mify.py +++ b/test/stdlib/components/test_mify.py @@ -1,9 +1,9 @@ import pytest -from mellea.formatters import TemplateFormatter from mellea.core import Component, TemplateRepresentation -from mellea.stdlib.components.mobject import Query, MObjectProtocol, MObject -from mellea.stdlib.components.mify import mify, MifiedProtocol +from mellea.formatters import TemplateFormatter +from mellea.stdlib.components.mify import MifiedProtocol, mify +from mellea.stdlib.components.mobject import MObject, MObjectProtocol, Query def test_protocol_adherence(): @@ -136,7 +136,7 @@ def get_details(self) -> str: return f"{self.name}, {self.age}" def extraneous_func(self): - """this function does nothing.""" + """This function does nothing.""" return diff --git a/test/stdlib/components/test_transform.py b/test/stdlib/components/test_transform.py index c99ac884..ecbd95d9 100644 --- a/test/stdlib/components/test_transform.py +++ b/test/stdlib/components/test_transform.py @@ -1,8 +1,8 @@ import pytest from mellea.core import TemplateRepresentation -from mellea.stdlib.components.docs.richdocument import TableTransform from mellea.stdlib.components import MObject, Query, Transform +from mellea.stdlib.components.docs.richdocument import TableTransform custom_mobject_description = "custom mobject description" @@ -53,7 +53,7 @@ def test_get_transform_object_custom(): assert isinstance(transform, TableTransform) with pytest.raises(AssertionError): - tr = transform.format_for_llm() + transform.format_for_llm() if __name__ == "__main__": diff --git a/test/stdlib/requirements/test_reqlib_markdown.py b/test/stdlib/requirements/test_reqlib_markdown.py index 5b901ff1..b663999f 100644 --- a/test/stdlib/requirements/test_reqlib_markdown.py +++ b/test/stdlib/requirements/test_reqlib_markdown.py @@ -60,26 +60,22 @@ async def test_markdown_table(): def test_default_output_to_bool_yes(): - assert default_output_to_bool("yeS") == True + assert default_output_to_bool("yeS") def test_default_output_to_bool_no(): - assert default_output_to_bool("nO") == False + assert not default_output_to_bool("nO") def test_default_output_to_bool_complicated_yes(): - assert ( - default_output_to_bool( - CBlock("The requirement is met by the output. Therefore, my answer is yes.") - ) - == True + assert default_output_to_bool( + CBlock("The requirement is met by the output. Therefore, my answer is yes.") ) def test_default_output_to_bool_word_with_yes_in_it(): - assert ( - default_output_to_bool("Here's a word that meets those requirements: ayes.") - == False + assert not default_output_to_bool( + "Here's a word that meets those requirements: ayes." ) diff --git a/test/stdlib/requirements/test_reqlib_python.py b/test/stdlib/requirements/test_reqlib_python.py index 403fa79e..85942716 100644 --- a/test/stdlib/requirements/test_reqlib_python.py +++ b/test/stdlib/requirements/test_reqlib_python.py @@ -17,17 +17,16 @@ _llm_sandbox_available = False from mellea.core import Context, ModelOutputThunk +from mellea.stdlib.context import ChatContext from mellea.stdlib.requirements.python_reqs import ( PythonExecutionReq, _has_python_code_listing, _python_executes_without_error, ) -from mellea.stdlib.context import ChatContext def from_model(content: str) -> Context: """Helper to create context from model output.""" - ctx = ChatContext() ctx = ctx.add(ModelOutputThunk(value=content)) return ctx diff --git a/test/stdlib/requirements/test_reqlib_tools.py b/test/stdlib/requirements/test_reqlib_tools.py index 553cc0c9..c20d4a78 100644 --- a/test/stdlib/requirements/test_reqlib_tools.py +++ b/test/stdlib/requirements/test_reqlib_tools.py @@ -1,4 +1,5 @@ import pytest + from mellea.stdlib.requirements.tool_reqs import _name2str diff --git a/test/stdlib/requirements/test_requirement.py b/test/stdlib/requirements/test_requirement.py index 5db655fe..f6998265 100644 --- a/test/stdlib/requirements/test_requirement.py +++ b/test/stdlib/requirements/test_requirement.py @@ -1,6 +1,7 @@ import pytest -from mellea.stdlib.context import ChatContext + from mellea.core import ModelOutputThunk, Requirement +from mellea.stdlib.context import ChatContext from mellea.stdlib.requirements import LLMaJRequirement, simple_validate from mellea.stdlib.session import start_session @@ -54,7 +55,7 @@ def test_simple_validate_invalid(): validation_func = simple_validate(lambda x: None) # type: ignore with pytest.raises(ValueError): - val_result = validation_func(ctx) + validation_func(ctx) if __name__ == "__main__": diff --git a/test/stdlib/sampling/test_sofai_graph_coloring.py b/test/stdlib/sampling/test_sofai_graph_coloring.py index 37b06ffc..c1b85a90 100644 --- a/test/stdlib/sampling/test_sofai_graph_coloring.py +++ b/test/stdlib/sampling/test_sofai_graph_coloring.py @@ -409,7 +409,7 @@ def test_best_attempt_includes_coloring_feedback( ] ] - s2_action, s2_context = strategy._prepare_s2_context( + s2_action, _s2_context = strategy._prepare_s2_context( s2_mode="best_attempt", original_action=original_action, original_context=original_context, diff --git a/test/stdlib/test_base_context.py b/test/stdlib/test_base_context.py index 698a9240..2fccb11f 100644 --- a/test/stdlib/test_base_context.py +++ b/test/stdlib/test_base_context.py @@ -1,7 +1,7 @@ import pytest -from mellea.core import Context, CBlock -from mellea.stdlib.context import SimpleContext, ChatContext +from mellea.core import CBlock, Context +from mellea.stdlib.context import ChatContext, SimpleContext def context_construction(cls: type[Context]): diff --git a/test/stdlib/test_chat_view.py b/test/stdlib/test_chat_view.py index 9b0ff93d..327258e8 100644 --- a/test/stdlib/test_chat_view.py +++ b/test/stdlib/test_chat_view.py @@ -1,7 +1,7 @@ import pytest -from mellea.stdlib.context import ChatContext from mellea.stdlib.components import Message, as_chat_history +from mellea.stdlib.context import ChatContext from mellea.stdlib.session import start_session @@ -25,7 +25,7 @@ def test_chat_view_linear_ctx(linear_session): linear_session.chat("What is 1+1?") linear_session.chat("What is 2+2?") assert len(as_chat_history(linear_session.ctx)) == 4 - assert all([type(x) == Message for x in as_chat_history(linear_session.ctx)]) + assert all(isinstance(x, Message) for x in as_chat_history(linear_session.ctx)) assert len(linear_session.ctx.view_for_generation()) == 4 @@ -34,7 +34,7 @@ def test_chat_view_simple_ctx(simple_session): simple_session.chat("What is 1+1?") simple_session.chat("What is 2+2?") assert len(as_chat_history(simple_session.ctx)) == 4 - assert all([type(x) == Message for x in as_chat_history(simple_session.ctx)]) + assert all(isinstance(x, Message) for x in as_chat_history(simple_session.ctx)) assert len(simple_session.ctx.view_for_generation()) == 0 diff --git a/test/stdlib/test_session.py b/test/stdlib/test_session.py index 2bace2cd..c8514574 100644 --- a/test/stdlib/test_session.py +++ b/test/stdlib/test_session.py @@ -67,7 +67,7 @@ async def test_async_await_with_chat_context(m_session): ctx = ctx.previous_node # type: ignore # Ensure we made it back to the root. - assert ctx.is_root_node == True # type: ignore + assert ctx.is_root_node # type: ignore async def test_async_without_waiting_with_chat_context(m_session): diff --git a/test/stdlib/test_spans.py b/test/stdlib/test_spans.py index 71b03ed0..4440a1c0 100644 --- a/test/stdlib/test_spans.py +++ b/test/stdlib/test_spans.py @@ -29,7 +29,7 @@ def m_session(gh_run): @pytest.mark.qualitative -async def test_lazy_spans(m_session): +async def test_lazy_spans(m_session) -> None: m: MelleaSession = m_session backend, ctx = m.backend, m.ctx @@ -45,7 +45,7 @@ async def test_lazy_spans(m_session): @pytest.mark.qualitative -async def test_kv(m_session): +async def test_kv(m_session) -> None: m: MelleaSession = m_session backend, ctx = m.backend, m.ctx # type: ignore @@ -56,7 +56,7 @@ async def test_kv(m_session): ) ) - backend: LocalHFBackend = backend + assert isinstance(backend, LocalHFBackend) response = await backend._generate_from_context_with_kv_cache( action=CBlock("What is Nathan's work address?"), ctx=ctx, model_options=dict() ) From e70c319e488651f9bb8774ada5ce796ec0563120 Mon Sep 17 00:00:00 2001 From: Alex Bozarth Date: Wed, 28 Jan 2026 15:49:52 -0600 Subject: [PATCH 2/5] test: fix mypy errors in docs/ Signed-off-by: Alex Bozarth --- docs/examples/aLora/101_example.py | 2 +- .../generative_slots_with_requirements.py | 4 ++-- .../image_text_models/vision_litellm_backend.py | 5 +++-- .../image_text_models/vision_ollama_chat.py | 2 +- .../image_text_models/vision_openai_examples.py | 7 +++++-- docs/examples/m_serve/pii_serve.py | 2 +- docs/examples/melp/lazy_fib.py | 2 +- docs/examples/melp/lazy_fib_sample.py | 5 +++-- docs/examples/mini_researcher/researcher.py | 13 +++++++------ docs/examples/safety/guardian.py | 2 +- docs/examples/tools/interpreter_example.py | 6 ++++++ docs/examples/tutorial/instruct_validate_repair.py | 7 ++++--- docs/kv_smash/kv_with_chat.py | 4 ++-- docs/rewrite/session_deepdive/step2_act_cblocks.py | 4 ++-- docs/rewrite/session_deepdive/step3_async.py | 5 +++-- pyproject.toml | 12 +----------- 16 files changed, 43 insertions(+), 39 deletions(-) diff --git a/docs/examples/aLora/101_example.py b/docs/examples/aLora/101_example.py index 8c3b66e2..4e98cd0e 100644 --- a/docs/examples/aLora/101_example.py +++ b/docs/examples/aLora/101_example.py @@ -20,7 +20,7 @@ backend=backend, ) -backend.add_alora(custom_stembolt_failure_constraint) +backend.add_alora(custom_stembolt_failure_constraint) # type: ignore[attr-defined] # Create M session m = MelleaSession(backend, ctx=ChatContext()) diff --git a/docs/examples/generative_slots/generative_slots_with_requirements.py b/docs/examples/generative_slots/generative_slots_with_requirements.py index 63aadd19..03c88934 100644 --- a/docs/examples/generative_slots/generative_slots_with_requirements.py +++ b/docs/examples/generative_slots/generative_slots_with_requirements.py @@ -27,8 +27,8 @@ def classify_sentiment(text: str) -> Literal["positive", "negative", "unknown"]: ) print( - f"Prompt to the model looked like:\n```\n{m.last_prompt()[0]['content']}\n```" - ) # type: ignore + f"Prompt to the model looked like:\n```\n{m.last_prompt()[0]['content']}\n```" # type: ignore[index] + ) # Prompt to the model looked like: # ``` # Your task is to imitate the output of the following function for the given arguments. diff --git a/docs/examples/image_text_models/vision_litellm_backend.py b/docs/examples/image_text_models/vision_litellm_backend.py index 9b6c4d82..7bf928f0 100644 --- a/docs/examples/image_text_models/vision_litellm_backend.py +++ b/docs/examples/image_text_models/vision_litellm_backend.py @@ -26,14 +26,15 @@ # test with PIL image res_instruct = m.instruct( "Is there a person on the image? Is the subject in the image smiling?", - images=[test_pil], + images=[test_pil], # type: ignore[arg-type] ) print(f"Test with PIL and instruct: \n{res_instruct!s}\n-----") # print(m.last_prompt()) # with PIL image and using m.chat res_chat = m.chat( - "How many eyes can you identify in the image? Explain.", images=[test_pil] + "How many eyes can you identify in the image? Explain.", + images=[test_pil], # type: ignore[arg-type] ) print(f"Test with PIL and chat: \n{res_chat.content!s}\n-----") diff --git a/docs/examples/image_text_models/vision_ollama_chat.py b/docs/examples/image_text_models/vision_ollama_chat.py index a7aca7a7..79debbff 100644 --- a/docs/examples/image_text_models/vision_ollama_chat.py +++ b/docs/examples/image_text_models/vision_ollama_chat.py @@ -15,7 +15,7 @@ test_pil = Image.open(image_path) # ask a question about the image -res = m.instruct("Is the subject in the image smiling?", images=[test_pil]) +res = m.instruct("Is the subject in the image smiling?", images=[test_pil]) # type: ignore[arg-type] print(f"Result:{res!s}") # This instruction should refer to the first image. diff --git a/docs/examples/image_text_models/vision_openai_examples.py b/docs/examples/image_text_models/vision_openai_examples.py index 5f9ce06d..813fe1df 100644 --- a/docs/examples/image_text_models/vision_openai_examples.py +++ b/docs/examples/image_text_models/vision_openai_examples.py @@ -47,8 +47,11 @@ # print(m.last_prompt()) # and now with PIL image and using m.chat -res = m.chat("How many eyes can you identify in the image? Explain.", images=[test_pil]) -print(str(res.content)) +chat_res = m.chat( + "How many eyes can you identify in the image? Explain.", + images=[test_pil], # type: ignore[arg-type] +) +print(str(chat_res.content)) # and now without images again... res = m.instruct("How many eyes can you identify in the image?", images=[]) diff --git a/docs/examples/m_serve/pii_serve.py b/docs/examples/m_serve/pii_serve.py index 0c166326..75fac8cc 100644 --- a/docs/examples/m_serve/pii_serve.py +++ b/docs/examples/m_serve/pii_serve.py @@ -79,7 +79,7 @@ def serve( model_options: None | dict = None, ) -> ModelOutputThunk | SamplingResult | str: """Simple serve example to do PII stuff.""" - message = input[-1].content + message = input[-1].content or "" result = pii_remove_validate( session, message, requirements=requirements, model_options=model_options ) diff --git a/docs/examples/melp/lazy_fib.py b/docs/examples/melp/lazy_fib.py index 5e739c20..5a693324 100644 --- a/docs/examples/melp/lazy_fib.py +++ b/docs/examples/melp/lazy_fib.py @@ -25,7 +25,7 @@ async def fib_main(backend: Backend, ctx: Context): mot = await fib(backend, ctx, fibs[i - 1], fibs[i - 2]) fibs.append(mot) - print(await fibs[-1].avalue()) + print(await fibs[-1].avalue()) # type: ignore[attr-defined] # for x in fibs: # match x: # case ModelOutputThunk(): diff --git a/docs/examples/melp/lazy_fib_sample.py b/docs/examples/melp/lazy_fib_sample.py index 833a2803..b42c5901 100644 --- a/docs/examples/melp/lazy_fib_sample.py +++ b/docs/examples/melp/lazy_fib_sample.py @@ -40,13 +40,14 @@ async def fib_sampling_version( async def fib_sampling_version_main(backend: Backend, ctx: Context): - fibs = [] + fibs: list[CBlock | ModelOutputThunk] = [] for i in range(20): if i == 0 or i == 1: fibs.append(CBlock(f"{i}")) else: mot = await fib_sampling_version(backend, ctx, fibs[i - 1], fibs[i - 2]) - fibs.append(mot) + if mot is not None: + fibs.append(mot) for x_i, x in enumerate(fibs): match x: diff --git a/docs/examples/mini_researcher/researcher.py b/docs/examples/mini_researcher/researcher.py index 89746404..8e0d8e1b 100644 --- a/docs/examples/mini_researcher/researcher.py +++ b/docs/examples/mini_researcher/researcher.py @@ -1,5 +1,6 @@ from collections.abc import Callable from functools import cache +from typing import Any from openai import BaseModel from pydantic import ValidationError @@ -9,7 +10,7 @@ from mellea import MelleaSession from mellea.backends import model_ids from mellea.backends.ollama import OllamaModelBackend -from mellea.core import Requirement, SamplingResult +from mellea.core import CBlock, Component, Requirement, SamplingResult from mellea.stdlib.requirements import simple_validate from mellea.stdlib.sampling import RejectionSamplingStrategy @@ -126,7 +127,7 @@ def max_sub_sections(out: str) -> bool: validation_fn=simple_validate(max_sub_sections), ) - outline_context = { + outline_context: dict[str, str | CBlock | Component] = { f"Document {i + 1}": f"## Title: {d.title}, ## Source: {d.source}" for i, d in enumerate(context) } @@ -136,14 +137,14 @@ def max_sub_sections(out: str) -> bool: description="Create an outline for a report on how {{current_subtopic}} impacts {{main_topic}}. Use the Context Documents provided as guideline for the sections.", # output_prefix="# Introduction", requirements=[req_outline, req_num_sections], - user_variables=user_args, grounding_context=outline_context, + user_variables=user_args, strategy=RejectionSamplingStrategy(loop_budget=2), return_sampling_results=True, format=SectionTitles, ) - st = SectionTitles.model_validate_json(outline_result.value) + st = SectionTitles.model_validate_json(outline_result.value or "") if isinstance(outline_result, SamplingResult): if not outline_result.success: @@ -201,12 +202,12 @@ def step_write_full_report( print(f"\t{v[1]} <- {v[0].description}") print("done.") - return report_result.value + return report_result.value or "" def research_subtopic(main_topic: str, subtopic: str, context: list[RAGDocument]): """Start MiniResearcher here.""" - user_args = { + user_args: dict[str, Any] = { "context_docs": context, "current_subtopic": subtopic, "main_topic": main_topic, diff --git a/docs/examples/safety/guardian.py b/docs/examples/safety/guardian.py index aa371cfa..bd927e96 100644 --- a/docs/examples/safety/guardian.py +++ b/docs/examples/safety/guardian.py @@ -66,7 +66,7 @@ # Show Ollama backend configuration ollama_guardian = GuardianCheck(GuardianRisk.HARM, backend_type="ollama") -print(f" Ollama backend: {ollama_guardian._backend.model_version}") +print(f" Ollama backend: {ollama_guardian._backend.model_version}") # type: ignore[attr-defined] print("\n=== Test 4: Groundedness Detection ===") # Test groundedness - detecting when responses lack factual grounding diff --git a/docs/examples/tools/interpreter_example.py b/docs/examples/tools/interpreter_example.py index 1b73f5bf..0ef633b6 100644 --- a/docs/examples/tools/interpreter_example.py +++ b/docs/examples/tools/interpreter_example.py @@ -32,6 +32,9 @@ def example_3(m: MelleaSession): tool_calls=True, ) + if plot_output.tool_calls is None: + raise ValueError("Expected tool_calls but got None") + code = plot_output.tool_calls["local_code_interpreter"].args["code"] print(f"Going to execute the following code:\n```python\n{code}\n```") @@ -62,6 +65,9 @@ def example_4(m: MelleaSession): tool_calls=True, ) + if plot_output.tool_calls is None: + raise ValueError("Expected tool_calls but got None") + code = plot_output.tool_calls["local_code_interpreter"].args["code"] print(f"Going to execute the following code:\n```python\n{code}\n```") diff --git a/docs/examples/tutorial/instruct_validate_repair.py b/docs/examples/tutorial/instruct_validate_repair.py index 75985d69..ddba7c02 100644 --- a/docs/examples/tutorial/instruct_validate_repair.py +++ b/docs/examples/tutorial/instruct_validate_repair.py @@ -1,6 +1,7 @@ +from mellea.core import Requirement from mellea.stdlib.requirements import check, req, simple_validate -requirements = [ +requirements: list[Requirement | str] = [ req("The email should have a salutation"), # == r1 req( "Use only lower-case letters", @@ -17,14 +18,14 @@ def write_email(m: mellea.MelleaSession, name: str, notes: str) -> str: email_candidate = m.instruct( "Write an email to {{name}} using the notes following: {{notes}}.", requirements=requirements, - strategy=RejectionSamplingStrategy(loop_budget=5), user_variables={"name": name, "notes": notes}, + strategy=RejectionSamplingStrategy(loop_budget=5), return_sampling_results=True, ) if email_candidate.success: return str(email_candidate.result) else: - return email_candidate.sample_generations[0].value + return email_candidate.sample_generations[0].value or "" m = mellea.start_session() diff --git a/docs/kv_smash/kv_with_chat.py b/docs/kv_smash/kv_with_chat.py index e6ab5a05..bdf6f38e 100644 --- a/docs/kv_smash/kv_with_chat.py +++ b/docs/kv_smash/kv_with_chat.py @@ -49,7 +49,7 @@ def merge(toks, dcs): {"role": "user", "content": c_blocks[1]}, {"role": "user", "content": "Also no cash"}, ] -templatized_input: str = tokenizer.apply_chat_template( +templatized_input: str = tokenizer.apply_chat_template( # type: ignore[assignment] conversation=messages, tokenize=False ) @@ -95,7 +95,7 @@ def merge(toks, dcs): merged_dcs.crop(-1) # generate and print result. -result = model.generate( +result = model.generate( # type: ignore[operator] merged_toks.to(device), attention_mask=merged_masks.to(device), past_key_values=merged_dcs, diff --git a/docs/rewrite/session_deepdive/step2_act_cblocks.py b/docs/rewrite/session_deepdive/step2_act_cblocks.py index 30845436..86e9d285 100644 --- a/docs/rewrite/session_deepdive/step2_act_cblocks.py +++ b/docs/rewrite/session_deepdive/step2_act_cblocks.py @@ -1,10 +1,10 @@ import mellea.stdlib.functional as mfuncs from mellea.backends.ollama import OllamaModelBackend -from mellea.core import CBlock +from mellea.stdlib.components import Instruction from mellea.stdlib.context import SimpleContext response, next_context = mfuncs.act( - action=CBlock("What is 1+1?"), + action=Instruction("What is 1+1?"), context=SimpleContext(), backend=OllamaModelBackend("granite4:latest"), ) diff --git a/docs/rewrite/session_deepdive/step3_async.py b/docs/rewrite/session_deepdive/step3_async.py index 02f41887..5e765add 100644 --- a/docs/rewrite/session_deepdive/step3_async.py +++ b/docs/rewrite/session_deepdive/step3_async.py @@ -2,13 +2,14 @@ import mellea.stdlib.functional as mfuncs from mellea.backends.ollama import OllamaModelBackend -from mellea.core import Backend, CBlock, Context +from mellea.core import Backend, Context +from mellea.stdlib.components import Instruction from mellea.stdlib.context import SimpleContext async def main(backend: Backend, ctx: Context): response, _next_context = await mfuncs.aact( - action=CBlock("What is 1+1?"), context=ctx, backend=backend + action=Instruction("What is 1+1?"), context=ctx, backend=backend ) print(response.value) diff --git a/pyproject.toml b/pyproject.toml index cb0c14b5..701917d2 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -200,20 +200,10 @@ split-on-trailing-comma = false "cli/**/*.py" = ["D"] [[tool.mypy.overrides]] -# TODO: Fix type errors in docs/ examples (25 errors across 15 files) -# Temporarily relaxed type checking for docs/ to allow examples to focus on -# demonstrating functionality rather than perfect type correctness. -# Main issues: Optional/None handling, CBlock vs Component confusion, type mismatches +# Keep import-not-found suppressed for optional dependencies module = "docs.*" disable_error_code = [ "import-not-found", - "arg-type", - "attr-defined", - "assignment", - "operator", - "call-overload", - "index", - "return-value", ] [tool.codespell] From 251a00ef5d143cc6977e666e06aa7ef666c0a5e1 Mon Sep 17 00:00:00 2001 From: Alex Bozarth Date: Wed, 28 Jan 2026 15:57:36 -0600 Subject: [PATCH 3/5] test: remove TODO Signed-off-by: Alex Bozarth --- pyproject.toml | 1 - 1 file changed, 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 701917d2..bcfbbdc2 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -192,7 +192,6 @@ split-on-trailing-comma = false # E402: Module level import not at top of file # Intentional in examples (pedagogical structure) and tests (pytestmark before imports) # D: Docstring errors -# TODO: Add docstrings to cli/ (94 errors across 16 files) # Not required in examples, tests, and notebooks (core mellea/ has complete docstrings) "docs/**/*.py" = ["E402", "D"] "docs/**/*.ipynb" = ["D"] From e7b27ff210631acd830af34e4f5a5a6c4691fcc0 Mon Sep 17 00:00:00 2001 From: Alex Bozarth Date: Wed, 28 Jan 2026 16:08:45 -0600 Subject: [PATCH 4/5] test: only run pre-commit mypy on py file change Signed-off-by: Alex Bozarth --- .pre-commit-config.yaml | 3 +++ 1 file changed, 3 insertions(+) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index d618240e..37d8e282 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -7,9 +7,11 @@ repos: - id: ruff-format name: "Ruff formatter" args: [--config=pyproject.toml] + types_or: [python, jupyter] - id: ruff name: "Ruff linter" args: [--exit-non-zero-on-fix, --fix, --config=pyproject.toml] + types_or: [python, jupyter] - repo: local hooks: @@ -18,6 +20,7 @@ repos: entry: uv run --no-sync mypy . pass_filenames: false language: system + types_or: [python, jupyter] - repo: https://github.com/astral-sh/uv-pre-commit rev: 0.7.8 From 974f7167e8ed9f178fb33ecc5e7b4f1ec16c4d2a Mon Sep 17 00:00:00 2001 From: Alex Bozarth Date: Mon, 2 Feb 2026 11:35:59 -0600 Subject: [PATCH 5/5] test: fix mypy errors Signed-off-by: Alex Bozarth --- docs/examples/library_interop/langchain_messages.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/docs/examples/library_interop/langchain_messages.py b/docs/examples/library_interop/langchain_messages.py index ecdba8e8..ab2dd21e 100644 --- a/docs/examples/library_interop/langchain_messages.py +++ b/docs/examples/library_interop/langchain_messages.py @@ -2,6 +2,8 @@ # Installing langchain is necessary for this example, but it works for any library # you may want to use Mellea with. +from typing import Any + from langchain_core.messages import AIMessage, HumanMessage, SystemMessage # Messages from a different library. @@ -14,7 +16,7 @@ # Some libraries have conversion functions that make it easier to ingest into Mellea. from langchain_core.messages import convert_to_openai_messages -messages = convert_to_openai_messages(messages=messages) +openai_messages: list[dict[str, Any]] = convert_to_openai_messages(messages=messages) # Import Mellea. from mellea import start_session @@ -25,7 +27,7 @@ # Mellea uses explicit contexts. Cast the OpenAI formatted messages into # Mellea messages and add them to the context. ctx = ChatContext() -for msg in messages: +for msg in openai_messages: ctx = ctx.add( # NOTE: If your messages aren't in OpenAI format or have additional data like # documents / images, you need to explicitly grab those fields as well.