diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 43a5e0e7..37d8e282 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -2,25 +2,25 @@ exclude: ^(scratchpad/) repos: - repo: https://github.com/astral-sh/ruff-pre-commit - rev: v0.11.5 + rev: v0.14.14 hooks: - id: ruff-format name: "Ruff formatter" args: [--config=pyproject.toml] - files: '^(mellea|test|cli|docs).*\.(py|ipynb)$' + types_or: [python, jupyter] - id: ruff name: "Ruff linter" args: [--exit-non-zero-on-fix, --fix, --config=pyproject.toml] - files: '^(mellea).*\.(py|ipynb)$' + types_or: [python, jupyter] - repo: local hooks: - id: mypy name: MyPy - entry: uv run --no-sync mypy mellea + entry: uv run --no-sync mypy . pass_filenames: false language: system - files: '^(mellea|test|cli|docs).*\.(py|ipynb)$' + types_or: [python, jupyter] - repo: https://github.com/astral-sh/uv-pre-commit rev: 0.7.8 diff --git a/cli/decompose/pipeline.py b/cli/decompose/pipeline.py index 1c6dda1a..a574d35d 100644 --- a/cli/decompose/pipeline.py +++ b/cli/decompose/pipeline.py @@ -5,9 +5,9 @@ from typing_extensions import NotRequired from mellea import MelleaSession +from mellea.backends import ModelOption from mellea.backends.ollama import OllamaModelBackend from mellea.backends.openai import OpenAIBackend -from mellea.backends import ModelOption from .prompt_modules import ( constraint_extractor, diff --git a/cli/eval/commands.py b/cli/eval/commands.py index ebc85dd6..17ff56be 100644 --- a/cli/eval/commands.py +++ b/cli/eval/commands.py @@ -1,5 +1,6 @@ """Use the eval command for LLM-as-a-judge evaluation, given a (set of) test file(s) consisting of prompts, instructions, and optionally, targets. -Instantiate a generator model to produce candidate responses, and a judge model to determine whether the instructions have been followed.""" +Instantiate a generator model to produce candidate responses, and a judge model to determine whether the instructions have been followed. +""" import typer diff --git a/cli/eval/runner.py b/cli/eval/runner.py index 38b4c1bb..3aface94 100644 --- a/cli/eval/runner.py +++ b/cli/eval/runner.py @@ -1,15 +1,16 @@ import json import re from pathlib import Path -from typing import List + +from rich.console import Console +from rich.progress import BarColumn, Progress, SpinnerColumn, TextColumn import mellea +from mellea.backends import ModelOption +from mellea.backends.backend import Backend from mellea.core import ModelOutputThunk +from mellea.stdlib.components import SimpleComponent from mellea.stdlib.components.unit_test_eval import TestBasedEval -from mellea.backends import ModelOption - -from rich.console import Console -from rich.progress import BarColumn, Progress, SpinnerColumn, TextColumn console = Console() @@ -78,7 +79,6 @@ def create_session( backend: str, model: str | None, max_tokens: int | None ) -> mellea.MelleaSession: """Create a mellea session with the specified backend and model.""" - model_id = None if model: if model.isupper() or "_" in model: @@ -93,6 +93,7 @@ def create_session( try: backend_lower = backend.lower() + backend_instance: Backend if backend_lower == "ollama": from mellea.backends.ollama import OllamaModelBackend @@ -130,7 +131,7 @@ def create_session( from mellea.backends.litellm import LiteLLMBackend backend_instance = LiteLLMBackend( - model_id=model_id, + model_id=str(model_id), model_options={ModelOption.MAX_NEW_TOKENS: max_tokens}, ) @@ -153,7 +154,7 @@ def create_session( def run_evaluations( - test_files: List[str], + test_files: list[str], backend: str, model: str | None, max_gen_tokens: int | None, @@ -173,14 +174,14 @@ def run_evaluations( "instructions": a set (in string form) of requirements which the generation should follow; the judge will evaluate if these are satisfied "examples": a list of entries containing an input_id, an input(prompt), and a list of targets. Each input may have multiple (or no) targets; inputs and targets are in messages format. """ - all_test_evals: List[TestBasedEval] = [] + all_test_evals: list[TestBasedEval] = [] for test_file in test_files: try: test_evals = TestBasedEval.from_json_file(test_file) all_test_evals.extend(test_evals) console.print(f"Loaded {len(test_evals)} test evaluations from {test_file}") - except Exception as e: + except Exception: console.print(f"Error loading {test_file}") if not all_test_evals: @@ -195,8 +196,11 @@ def run_evaluations( console.print(f"Judge model: {judge_model}") m = create_session(backend=backend, model=model, max_tokens=max_gen_tokens) + # Use same backend as generator if judge_backend not specified judge_session = create_session( - backend=judge_backend, model=judge_model, max_tokens=max_judge_tokens + backend=judge_backend if judge_backend else backend, + model=judge_model, + max_tokens=max_judge_tokens, ) all_results = [] @@ -240,12 +244,13 @@ def execute_test_eval( For each input in the test, generate a response using generation_session Then, after all inputs are processed, validate using judge_session. """ - input_results = [] # for all inputs, generate responses with generator for idx, input_text in enumerate(test_eval.inputs): - result: ModelOutputThunk = generation_session.act(input_text) + result: ModelOutputThunk = generation_session.act( + SimpleComponent(instruction=input_text) + ) model_output = str(result) targets_for_input = ( @@ -267,7 +272,7 @@ def execute_test_eval( input_text=input_text, model_output=model_output, validation_passed=passed, - score=score, + score=score if score is not None else 0, validation_reason=justification, ) input_results.append(input_result) @@ -301,7 +306,7 @@ def parse_judge_output(judge_output: str): return None, judge_output -def save_results(results: List[TestEvalResult], output_path: str, output_format: str): +def save_results(results: list[TestEvalResult], output_path: str, output_format: str): output_path_obj = Path(output_path) if output_path_obj.suffix != f".{output_format}": output_path_obj = Path(f"{output_path}.{output_format}") @@ -333,7 +338,7 @@ def save_results(results: List[TestEvalResult], output_path: str, output_format: console.print(f"Results saved to {output_path}") -def summary_stats(results: List[TestEvalResult]): +def summary_stats(results: list[TestEvalResult]): total_inputs = sum(r.total_count for r in results) passed_inputs = sum(r.passed_count for r in results) overall_pass_rate = passed_inputs / total_inputs if total_inputs > 0 else 0.0 diff --git a/cli/m.py b/cli/m.py index 07fc14b9..ab39440e 100644 --- a/cli/m.py +++ b/cli/m.py @@ -4,8 +4,8 @@ from cli.alora.commands import alora_app from cli.decompose import app as decompose_app -from cli.serve.app import serve from cli.eval.commands import eval_app +from cli.serve.app import serve cli = typer.Typer(name="m", no_args_is_help=True) diff --git a/docs/__init__.py b/docs/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/docs/examples/__init__.py b/docs/examples/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/docs/examples/aLora/101_example.py b/docs/examples/aLora/101_example.py index 9497b9b9..351e0a1d 100644 --- a/docs/examples/aLora/101_example.py +++ b/docs/examples/aLora/101_example.py @@ -2,13 +2,14 @@ import time -from mellea import MelleaSession from mellea.backends.aloras.huggingface.granite_aloras import HFConstraintAlora -from mellea.backends.cache import SimpleLRUCache -from mellea.backends.huggingface import LocalHFBackend from mellea.stdlib.base import ChatContext, GenerateLog from mellea.stdlib.requirement import ALoraRequirement, Requirement +from mellea import MelleaSession +from mellea.backends.cache import SimpleLRUCache +from mellea.backends.huggingface import LocalHFBackend + # Define a backend and add the constraint aLora backend = LocalHFBackend( model_id="ibm-granite/granite-3.2-8b-instruct", cache=SimpleLRUCache(5) @@ -21,7 +22,7 @@ backend=backend, ) -backend.add_alora(custom_stembolt_failure_constraint) +backend.add_alora(custom_stembolt_failure_constraint) # type: ignore[attr-defined] # Create M session m = MelleaSession(backend, ctx=ChatContext()) diff --git a/docs/examples/agents/react.py b/docs/examples/agents/react.py index 1debc044..35404136 100644 --- a/docs/examples/agents/react.py +++ b/docs/examples/agents/react.py @@ -4,6 +4,7 @@ import inspect import json from collections.abc import Callable +from enum import Enum from typing import Literal import pydantic @@ -84,9 +85,9 @@ def call_tool(self, tool: ReactTool, kwargs_json: str): def tool_name_schema(self): names = self.tool_names() - fields = dict() - fields["tool"] = Literal[*names] - return pydantic.create_model("ToolSelectionSchema", **fields) + # Python 3.10 compatible: use Enum instead of Literal[*names] (requires 3.11+) + ToolEnum = Enum("ToolEnum", {name: name for name in names}) + return pydantic.create_model("ToolSelectionSchema", tool=(ToolEnum, ...)) def get_tool_from_schema(self, content: str): schema = self.tool_name_schema() diff --git a/docs/examples/agents/react_instruct.py b/docs/examples/agents/react_instruct.py index 47dfa2b9..e444ec25 100644 --- a/docs/examples/agents/react_instruct.py +++ b/docs/examples/agents/react_instruct.py @@ -4,6 +4,7 @@ import inspect import json from collections.abc import Callable +from enum import Enum from typing import Literal import pydantic @@ -81,9 +82,9 @@ def call_tool(self, tool: ReactTool, kwargs_json: str): def tool_name_schema(self): names = self.tool_names() - fields = dict() - fields["tool"] = Literal[*names] - return pydantic.create_model("ToolSelectionSchema", **fields) + # Python 3.10 compatible: use Enum instead of Literal[*names] (requires 3.11+) + ToolEnum = Enum("ToolEnum", {name: name for name in names}) + return pydantic.create_model("ToolSelectionSchema", tool=(ToolEnum, ...)) def get_tool_from_schema(self, content: str): schema = self.tool_name_schema() diff --git a/docs/examples/conftest.py b/docs/examples/conftest.py index 28658770..5146fb7c 100644 --- a/docs/examples/conftest.py +++ b/docs/examples/conftest.py @@ -209,7 +209,7 @@ def pytest_ignore_collect(collection_path, path, config): # Extract markers and check if we should skip try: markers = _extract_markers_from_file(collection_path) - should_skip, reason = _should_skip_collection(markers) + should_skip, _reason = _should_skip_collection(markers) if should_skip: # Return True to ignore this file completely return True @@ -233,7 +233,7 @@ def pytest_pycollect_makemodule(module_path, path, parent): and "examples" in module_path.parts ): # Check for optional imports - should_skip, reason = _check_optional_imports(module_path) + should_skip, _reason = _check_optional_imports(module_path) if should_skip: # Add to skip list and return None to prevent module creation examples_to_skip.add(module_path.name) @@ -257,7 +257,7 @@ def pytest_collect_file(parent: pytest.Dir, file_path: pathlib.PosixPath): return # Check for optional imports before creating ExampleFile - should_skip, reason = _check_optional_imports(file_path) + should_skip, _reason = _check_optional_imports(file_path) if should_skip: return None @@ -344,7 +344,6 @@ def pytest_runtest_setup(item): gh_run = int(os.environ.get("CICD", 0)) # Get config options (all default to False for examples) - ignore_all = False ignore_gpu = False ignore_ram = False ignore_ollama = False diff --git a/docs/examples/context/contexts_with_sampling.py b/docs/examples/context/contexts_with_sampling.py index 35d69d45..d9a7c280 100644 --- a/docs/examples/context/contexts_with_sampling.py +++ b/docs/examples/context/contexts_with_sampling.py @@ -28,7 +28,7 @@ print(f"Total Generation Attempts: {len(res.sample_generations)}") print() -print(f"Getting index of another result.") +print("Getting index of another result.") index = 0 # Just choose the first one. print( diff --git a/docs/examples/generative_slots/generative_slots_with_requirements.py b/docs/examples/generative_slots/generative_slots_with_requirements.py index 4f65c796..b7eeff8e 100644 --- a/docs/examples/generative_slots/generative_slots_with_requirements.py +++ b/docs/examples/generative_slots/generative_slots_with_requirements.py @@ -3,16 +3,15 @@ from typing import Literal from mellea import generative, start_session +from mellea.core import Requirement from mellea.stdlib.components.genslot import PreconditionException from mellea.stdlib.requirements import simple_validate -from mellea.core import Requirement from mellea.stdlib.sampling.base import RejectionSamplingStrategy @generative def classify_sentiment(text: str) -> Literal["positive", "negative", "unknown"]: """Classify the sentiment of the text.""" - ... if __name__ == "__main__": @@ -30,8 +29,8 @@ def classify_sentiment(text: str) -> Literal["positive", "negative", "unknown"]: ) print( - f"Prompt to the model looked like:\n```\n{m.last_prompt()[0]['content']}\n```" - ) # type: ignore + f"Prompt to the model looked like:\n```\n{m.last_prompt()[0]['content']}\n```" # type: ignore[index] + ) # Prompt to the model looked like: # ``` # Your task is to imitate the output of the following function for the given arguments. @@ -65,7 +64,7 @@ def classify_sentiment(text: str) -> Literal["positive", "negative", "unknown"]: ], ) except PreconditionException as e: - print(f"exception: {str(e)}") + print(f"exception: {e!s}") # Look at why the precondition validation failed. print("Failure reasons:") diff --git a/docs/examples/image_text_models/vision_litellm_backend.py b/docs/examples/image_text_models/vision_litellm_backend.py index 5ca65d93..5a2dc6cc 100644 --- a/docs/examples/image_text_models/vision_litellm_backend.py +++ b/docs/examples/image_text_models/vision_litellm_backend.py @@ -3,6 +3,7 @@ """Examples of using vision models with LiteLLM backend.""" import os +import pathlib import litellm from PIL import Image @@ -11,7 +12,6 @@ from mellea.backends.litellm import LiteLLMBackend from mellea.backends.openai import OpenAIBackend from mellea.core import ImageBlock -import pathlib # use LiteLLM to talk to Ollama or anthropic or..... m = MelleaSession(LiteLLMBackend("ollama/granite3.2-vision")) @@ -28,17 +28,18 @@ # test with PIL image res_instruct = m.instruct( "Is there a person on the image? Is the subject in the image smiling?", - images=[test_pil], + images=[test_pil], # type: ignore[arg-type] ) -print(f"Test with PIL and instruct: \n{str(res_instruct)}\n-----") +print(f"Test with PIL and instruct: \n{res_instruct!s}\n-----") # print(m.last_prompt()) # with PIL image and using m.chat res_chat = m.chat( - "How many eyes can you identify in the image? Explain.", images=[test_pil] + "How many eyes can you identify in the image? Explain.", + images=[test_pil], # type: ignore[arg-type] ) -print(f"Test with PIL and chat: \n{str(res_chat.content)}\n-----") +print(f"Test with PIL and chat: \n{res_chat.content!s}\n-----") # and now without images again... res_empty = m.instruct("How many eyes can you identify in the image?", images=[]) -print(f"Test without image: \n{str(res_empty)}\n-----") +print(f"Test without image: \n{res_empty!s}\n-----") diff --git a/docs/examples/image_text_models/vision_ollama_chat.py b/docs/examples/image_text_models/vision_ollama_chat.py index b3b3fc04..a190e5b1 100644 --- a/docs/examples/image_text_models/vision_ollama_chat.py +++ b/docs/examples/image_text_models/vision_ollama_chat.py @@ -3,6 +3,7 @@ """Example of using Ollama with vision models with linear context.""" import pathlib + from PIL import Image from mellea import start_session @@ -16,7 +17,7 @@ test_pil = Image.open(image_path) # ask a question about the image -res = m.instruct("Is the subject in the image smiling?", images=[test_pil]) +res = m.instruct("Is the subject in the image smiling?", images=[test_pil]) # type: ignore[arg-type] print(f"Result:{res!s}") # This instruction should refer to the first image. diff --git a/docs/examples/image_text_models/vision_openai_examples.py b/docs/examples/image_text_models/vision_openai_examples.py index d79d5fee..46136afa 100644 --- a/docs/examples/image_text_models/vision_openai_examples.py +++ b/docs/examples/image_text_models/vision_openai_examples.py @@ -8,8 +8,8 @@ from mellea import MelleaSession from mellea.backends.openai import OpenAIBackend -from mellea.stdlib.context import ChatContext from mellea.core import ImageBlock +from mellea.stdlib.context import ChatContext # # using anthropic AI model ... # anth_key = os.environ.get("ANTHROPIC_API_KEY") @@ -49,8 +49,11 @@ # print(m.last_prompt()) # and now with PIL image and using m.chat -res = m.chat("How many eyes can you identify in the image? Explain.", images=[test_pil]) -print(str(res.content)) +chat_res = m.chat( + "How many eyes can you identify in the image? Explain.", + images=[test_pil], # type: ignore[arg-type] +) +print(str(chat_res.content)) # and now without images again... res = m.instruct("How many eyes can you identify in the image?", images=[]) diff --git a/docs/examples/information_extraction/101_with_gen_slots.py b/docs/examples/information_extraction/101_with_gen_slots.py index bec0557f..83a08803 100644 --- a/docs/examples/information_extraction/101_with_gen_slots.py +++ b/docs/examples/information_extraction/101_with_gen_slots.py @@ -10,9 +10,7 @@ @generative def extract_all_person_names(doc: str) -> list[str]: - """ - Given a document, extract names of ALL mentioned persons. Return these names as list of strings. - """ + """Given a document, extract names of ALL mentioned persons. Return these names as list of strings.""" # ref: https://www.nytimes.com/2012/05/20/world/world-leaders-at-us-meeting-urge-growth-not-austerity.html diff --git a/docs/examples/information_extraction/advanced_with_m_instruct.py b/docs/examples/information_extraction/advanced_with_m_instruct.py index 60537c06..4084469c 100644 --- a/docs/examples/information_extraction/advanced_with_m_instruct.py +++ b/docs/examples/information_extraction/advanced_with_m_instruct.py @@ -8,9 +8,9 @@ from mellea import start_session from mellea.backends import model_ids +from mellea.core import SamplingResult from mellea.stdlib.requirements import check, simple_validate from mellea.stdlib.sampling import RejectionSamplingStrategy -from mellea.core import SamplingResult # ref: https://www.nytimes.com/2012/05/20/world/world-leaders-at-us-meeting-urge-growth-not-austerity.html NYTimes_text = "CAMP DAVID, Md. — Leaders of the world's richest countries banded together on Saturday to press Germany to back more pro-growth policies to halt the deepening debt crisis in Europe, as President Obama for the first time gained widespread support for his argument that Europe, and the United States by extension, cannot afford Chancellor Angela Merkel's one-size-fits-all approach emphasizing austerity." diff --git a/docs/examples/instruct_validate_repair/101_email_with_validate.py b/docs/examples/instruct_validate_repair/101_email_with_validate.py index bf8589c2..742da99a 100644 --- a/docs/examples/instruct_validate_repair/101_email_with_validate.py +++ b/docs/examples/instruct_validate_repair/101_email_with_validate.py @@ -2,8 +2,8 @@ from docs.examples.helper import req_print, w from mellea import start_session -from mellea.backends.model_ids import IBM_GRANITE_3_3_8B from mellea.backends import ModelOption +from mellea.backends.model_ids import IBM_GRANITE_3_3_8B from mellea.stdlib.sampling import RejectionSamplingStrategy # create a session using Granite 4 Micro (3B) on Ollama and a simple context [see below] diff --git a/docs/examples/instruct_validate_repair/advanced_email_with_validate_function.py b/docs/examples/instruct_validate_repair/advanced_email_with_validate_function.py index 4bdc9793..a6acafc3 100644 --- a/docs/examples/instruct_validate_repair/advanced_email_with_validate_function.py +++ b/docs/examples/instruct_validate_repair/advanced_email_with_validate_function.py @@ -3,8 +3,8 @@ from docs.examples.helper import w from mellea import start_session from mellea.backends import ModelOption -from mellea.stdlib.requirements import simple_validate from mellea.core import Requirement +from mellea.stdlib.requirements import simple_validate from mellea.stdlib.sampling import RejectionSamplingStrategy # create a session using Granite 4 Micro 3B on Ollama and a simple context [see below] diff --git a/docs/examples/intrinsics/answer_relevance.py b/docs/examples/intrinsics/answer_relevance.py index 6d10c048..a2588f1a 100644 --- a/docs/examples/intrinsics/answer_relevance.py +++ b/docs/examples/intrinsics/answer_relevance.py @@ -1,7 +1,6 @@ # pytest: huggingface, requires_heavy_ram, llm -""" -Example usage of the answer relevance intrinsic for RAG applications. +"""Example usage of the answer relevance intrinsic for RAG applications. To run this script from the root of the Mellea source tree, use the command: ``` @@ -10,10 +9,9 @@ """ from mellea.backends.huggingface import LocalHFBackend -from mellea.stdlib.context import ChatContext -from mellea.stdlib.components import Message, Document +from mellea.stdlib.components import Document, Message from mellea.stdlib.components.intrinsic import rag - +from mellea.stdlib.context import ChatContext backend = LocalHFBackend(model_id="ibm-granite/granite-4.0-micro") context = ChatContext().add(Message("user", "Who attended the meeting?")) diff --git a/docs/examples/intrinsics/answerability.py b/docs/examples/intrinsics/answerability.py index c7f2e03d..e88f9ea4 100644 --- a/docs/examples/intrinsics/answerability.py +++ b/docs/examples/intrinsics/answerability.py @@ -1,7 +1,6 @@ # pytest: huggingface, requires_heavy_ram, llm -""" -Example usage of the answerability intrinsic for RAG applications. +"""Example usage of the answerability intrinsic for RAG applications. To run this script from the root of the Mellea source tree, use the command: ``` @@ -10,9 +9,9 @@ """ from mellea.backends.huggingface import LocalHFBackend -from mellea.stdlib.context import ChatContext -from mellea.stdlib.components import Message, Document +from mellea.stdlib.components import Document, Message from mellea.stdlib.components.intrinsic import rag +from mellea.stdlib.context import ChatContext backend = LocalHFBackend(model_id="ibm-granite/granite-4.0-micro") context = ChatContext().add(Message("assistant", "Hello there, how can I help you?")) diff --git a/docs/examples/intrinsics/citations.py b/docs/examples/intrinsics/citations.py index f7a0b6b8..90d39567 100644 --- a/docs/examples/intrinsics/citations.py +++ b/docs/examples/intrinsics/citations.py @@ -1,7 +1,6 @@ # pytest: huggingface, requires_heavy_ram, llm -""" -Example usage of the citations intrinsic for RAG applications. +"""Example usage of the citations intrinsic for RAG applications. To run this script from the root of the Mellea source tree, use the command: ``` @@ -9,12 +8,12 @@ ``` """ -from mellea.backends.huggingface import LocalHFBackend -from mellea.stdlib.context import ChatContext -from mellea.stdlib.components import Message, Document -from mellea.stdlib.components.intrinsic import rag import json +from mellea.backends.huggingface import LocalHFBackend +from mellea.stdlib.components import Document, Message +from mellea.stdlib.components.intrinsic import rag +from mellea.stdlib.context import ChatContext backend = LocalHFBackend(model_id="ibm-granite/granite-4.0-micro") context = ChatContext().add( diff --git a/docs/examples/intrinsics/context_relevance.py b/docs/examples/intrinsics/context_relevance.py index 0f5c8eb3..3eef6b31 100644 --- a/docs/examples/intrinsics/context_relevance.py +++ b/docs/examples/intrinsics/context_relevance.py @@ -1,7 +1,6 @@ # pytest: huggingface, requires_heavy_ram, llm -""" -Example usage of the context relevance intrinsic for RAG applications. +"""Example usage of the context relevance intrinsic for RAG applications. To run this script from the root of the Mellea source tree, use the command: ``` @@ -10,9 +9,9 @@ """ from mellea.backends.huggingface import LocalHFBackend -from mellea.stdlib.context import ChatContext from mellea.stdlib.components import Document from mellea.stdlib.components.intrinsic import rag +from mellea.stdlib.context import ChatContext backend = LocalHFBackend(model_id="ibm-granite/granite-4.0-micro") context = ChatContext() diff --git a/docs/examples/intrinsics/hallucination_detection.py b/docs/examples/intrinsics/hallucination_detection.py index 0adfe38e..247f755d 100644 --- a/docs/examples/intrinsics/hallucination_detection.py +++ b/docs/examples/intrinsics/hallucination_detection.py @@ -1,7 +1,6 @@ # pytest: huggingface, requires_heavy_ram, llm -""" -Example usage of the hallucination detection intrinsic for RAG applications. +"""Example usage of the hallucination detection intrinsic for RAG applications. To run this script from the root of the Mellea source tree, use the command: ``` @@ -9,12 +8,12 @@ ``` """ -from mellea.backends.huggingface import LocalHFBackend -from mellea.stdlib.context import ChatContext -from mellea.stdlib.components import Message, Document -from mellea.stdlib.components.intrinsic import rag import json +from mellea.backends.huggingface import LocalHFBackend +from mellea.stdlib.components import Document, Message +from mellea.stdlib.components.intrinsic import rag +from mellea.stdlib.context import ChatContext backend = LocalHFBackend(model_id="ibm-granite/granite-4.0-micro") context = ( diff --git a/docs/examples/intrinsics/intrinsics.py b/docs/examples/intrinsics/intrinsics.py index 70a71671..b7fcb267 100644 --- a/docs/examples/intrinsics/intrinsics.py +++ b/docs/examples/intrinsics/intrinsics.py @@ -1,12 +1,11 @@ # pytest: huggingface, requires_heavy_ram, llm +import mellea.stdlib.functional as mfuncs +from mellea.backends.adapters.adapter import AdapterType, GraniteCommonAdapter from mellea.backends.huggingface import LocalHFBackend from mellea.backends.openai import OpenAIBackend, _ServerType -from mellea.backends.adapters.adapter import AdapterType, GraniteCommonAdapter +from mellea.stdlib.components import Intrinsic, Message from mellea.stdlib.context import ChatContext -from mellea.stdlib.components import Message -import mellea.stdlib.functional as mfuncs -from mellea.stdlib.components import Intrinsic # This is an example for how you would directly use intrinsics. See `mellea/stdlib/intrinsics/rag.py` # for helper functions. diff --git a/docs/examples/intrinsics/query_rewrite.py b/docs/examples/intrinsics/query_rewrite.py index 7a6f3c56..d1624b88 100644 --- a/docs/examples/intrinsics/query_rewrite.py +++ b/docs/examples/intrinsics/query_rewrite.py @@ -1,7 +1,6 @@ # pytest: huggingface, requires_heavy_ram, llm -""" -Example usage of the query rewrite intrinsic for RAG applications. +"""Example usage of the query rewrite intrinsic for RAG applications. To run this script from the root of the Mellea source tree, use the command: ``` @@ -10,9 +9,9 @@ """ from mellea.backends.huggingface import LocalHFBackend -from mellea.stdlib.context import ChatContext from mellea.stdlib.components import Message from mellea.stdlib.components.intrinsic import rag +from mellea.stdlib.context import ChatContext backend = LocalHFBackend(model_id="ibm-granite/granite-4.0-micro") context = ( diff --git a/docs/examples/library_interop/langchain_messages.py b/docs/examples/library_interop/langchain_messages.py index 4e0b02fb..ab2dd21e 100644 --- a/docs/examples/library_interop/langchain_messages.py +++ b/docs/examples/library_interop/langchain_messages.py @@ -2,7 +2,9 @@ # Installing langchain is necessary for this example, but it works for any library # you may want to use Mellea with. -from langchain_core.messages import HumanMessage, AIMessage, SystemMessage +from typing import Any + +from langchain_core.messages import AIMessage, HumanMessage, SystemMessage # Messages from a different library. messages = [ @@ -14,18 +16,18 @@ # Some libraries have conversion functions that make it easier to ingest into Mellea. from langchain_core.messages import convert_to_openai_messages -messages = convert_to_openai_messages(messages=messages) +openai_messages: list[dict[str, Any]] = convert_to_openai_messages(messages=messages) # Import Mellea. +from mellea import start_session +from mellea.backends import ModelOption from mellea.stdlib.components import Message from mellea.stdlib.context import ChatContext -from mellea.backends import ModelOption -from mellea import start_session # Mellea uses explicit contexts. Cast the OpenAI formatted messages into # Mellea messages and add them to the context. ctx = ChatContext() -for msg in messages: +for msg in openai_messages: ctx = ctx.add( # NOTE: If your messages aren't in OpenAI format or have additional data like # documents / images, you need to explicitly grab those fields as well. diff --git a/docs/examples/m_serve/m_serve_example_simple.py b/docs/examples/m_serve/m_serve_example_simple.py index 2a717574..74fe4329 100644 --- a/docs/examples/m_serve/m_serve_example_simple.py +++ b/docs/examples/m_serve/m_serve_example_simple.py @@ -6,8 +6,8 @@ import mellea from cli.serve.models import ChatMessage +from mellea.core import ModelOutputThunk, Requirement, SamplingResult from mellea.stdlib.context import ChatContext -from mellea.core import ModelOutputThunk, SamplingResult, Requirement from mellea.stdlib.requirements import simple_validate from mellea.stdlib.sampling import RejectionSamplingStrategy diff --git a/docs/examples/m_serve/pii_serve.py b/docs/examples/m_serve/pii_serve.py index 356867e8..75fac8cc 100644 --- a/docs/examples/m_serve/pii_serve.py +++ b/docs/examples/m_serve/pii_serve.py @@ -2,8 +2,8 @@ import spacy -from cli.serve.models import ChatMessage import mellea +from cli.serve.models import ChatMessage from mellea.backends.model_ids import IBM_GRANITE_4_MICRO_3B from mellea.core import ModelOutputThunk, SamplingResult from mellea.stdlib.requirements import req, simple_validate @@ -79,7 +79,7 @@ def serve( model_options: None | dict = None, ) -> ModelOutputThunk | SamplingResult | str: """Simple serve example to do PII stuff.""" - message = input[-1].content + message = input[-1].content or "" result = pii_remove_validate( session, message, requirements=requirements, model_options=model_options ) diff --git a/docs/examples/mcp/README.md b/docs/examples/mcp/README.md index 8202f789..48161c33 100644 --- a/docs/examples/mcp/README.md +++ b/docs/examples/mcp/README.md @@ -14,7 +14,7 @@ uv pip install "mcp[cli]" and run the example in MCP debug UI: ```bash -uv run mcp dev docs/examples/tutorial/mcp_example.py +uv run mcp dev docs/examples/mcp/mcp_example.py ``` diff --git a/docs/examples/mcp/mcp_example.py b/docs/examples/mcp/mcp_example.py index 43d965dc..fac1a6cc 100644 --- a/docs/examples/mcp/mcp_example.py +++ b/docs/examples/mcp/mcp_example.py @@ -4,7 +4,7 @@ uv pip install "mcp[cli]" and run the example in MCP debug UI: -uv run mcp dev docs/examples/tutorial/mcp_example.py +uv run mcp dev docs/examples/mcp/mcp_example.py """ from mcp.server.fastmcp import FastMCP diff --git a/docs/examples/melp/lazy.py b/docs/examples/melp/lazy.py index 4d293515..0719774d 100644 --- a/docs/examples/melp/lazy.py +++ b/docs/examples/melp/lazy.py @@ -1,14 +1,12 @@ # pytest: ollama, qualitative, llm import asyncio -from mellea.core import Context, CBlock, ModelOutputThunk +from mellea.backends.ollama import OllamaModelBackend +from mellea.core import Backend, CBlock, Context, ModelOutputThunk from mellea.stdlib.components import SimpleComponent from mellea.stdlib.context import SimpleContext -from mellea.core import Backend -from mellea.backends.ollama import OllamaModelBackend - backend = OllamaModelBackend("granite4:latest") diff --git a/docs/examples/melp/lazy_fib.py b/docs/examples/melp/lazy_fib.py index 1383b410..6c06c6cd 100644 --- a/docs/examples/melp/lazy_fib.py +++ b/docs/examples/melp/lazy_fib.py @@ -1,15 +1,12 @@ # pytest: ollama, llm import asyncio -from mellea.core import Context, CBlock, ModelOutputThunk +from mellea.backends.ollama import OllamaModelBackend +from mellea.core import Backend, CBlock, Context, ModelOutputThunk from mellea.stdlib.components import SimpleComponent from mellea.stdlib.context import SimpleContext -from mellea.core import Backend -from mellea.backends.ollama import OllamaModelBackend -from typing import Tuple - backend = OllamaModelBackend("granite4:latest") @@ -30,7 +27,7 @@ async def fib_main(backend: Backend, ctx: Context): mot = await fib(backend, ctx, fibs[i - 1], fibs[i - 2]) fibs.append(mot) - print(await fibs[-1].avalue()) + print(await fibs[-1].avalue()) # type: ignore[attr-defined] # for x in fibs: # match x: # case ModelOutputThunk(): diff --git a/docs/examples/melp/lazy_fib_sample.py b/docs/examples/melp/lazy_fib_sample.py index 1e1e5611..19eae64d 100644 --- a/docs/examples/melp/lazy_fib_sample.py +++ b/docs/examples/melp/lazy_fib_sample.py @@ -1,15 +1,12 @@ # pytest: ollama, llm import asyncio -from mellea.core import Context, CBlock, ModelOutputThunk +from mellea.backends.ollama import OllamaModelBackend +from mellea.core import Backend, CBlock, Context, ModelOutputThunk from mellea.stdlib.components import SimpleComponent from mellea.stdlib.context import SimpleContext -from mellea.core import Backend -from mellea.backends.ollama import OllamaModelBackend -from typing import Tuple - backend = OllamaModelBackend("granite4:latest") @@ -28,7 +25,7 @@ async def _fib_sample( try: int(value) return answer_mot - except: + except Exception: return None @@ -45,13 +42,14 @@ async def fib_sampling_version( async def fib_sampling_version_main(backend: Backend, ctx: Context): - fibs = [] + fibs: list[CBlock | ModelOutputThunk] = [] for i in range(20): if i == 0 or i == 1: fibs.append(CBlock(f"{i}")) else: mot = await fib_sampling_version(backend, ctx, fibs[i - 1], fibs[i - 2]) - fibs.append(mot) + if mot is not None: + fibs.append(mot) for x_i, x in enumerate(fibs): match x: diff --git a/docs/examples/melp/simple_example.py b/docs/examples/melp/simple_example.py index e1c38c67..772eb0a3 100644 --- a/docs/examples/melp/simple_example.py +++ b/docs/examples/melp/simple_example.py @@ -1,15 +1,14 @@ # pytest: ollama, llm import asyncio -from mellea.core import Context, CBlock, ModelOutputThunk, Backend + from mellea.backends.ollama import OllamaModelBackend +from mellea.core import Backend, CBlock, Context, ModelOutputThunk from mellea.stdlib.context import SimpleContext async def main(backend: Backend, ctx: Context): - """ - In this example, we show how executing multiple MOTs in parallel should work. - """ + """In this example, we show how executing multiple MOTs in parallel should work.""" m_states = "Missouri", "Minnesota", "Montana", "Massachusetts" poem_thunks = [] @@ -21,7 +20,7 @@ async def main(backend: Backend, ctx: Context): # Notice that what we have now is a list of ModelOutputThunks, none of which are computed. for poem_thunk in poem_thunks: - assert type(poem_thunk) == ModelOutputThunk + assert isinstance(poem_thunk, ModelOutputThunk) print(f"Computed: {poem_thunk.is_computed()}") # Let's run all of these in parallel. diff --git a/docs/examples/melp/states.py b/docs/examples/melp/states.py index efbd5e07..bab5d810 100644 --- a/docs/examples/melp/states.py +++ b/docs/examples/melp/states.py @@ -2,10 +2,10 @@ import asyncio -from mellea.core import Context, CBlock, Backend from mellea.backends.ollama import OllamaModelBackend -from mellea.stdlib.context import SimpleContext +from mellea.core import Backend, CBlock, Context from mellea.stdlib.components import SimpleComponent +from mellea.stdlib.context import SimpleContext async def main(backend: Backend, ctx: Context): diff --git a/docs/examples/mify/rich_table_execute_basic.py b/docs/examples/mify/rich_table_execute_basic.py index f1365a06..edca015d 100644 --- a/docs/examples/mify/rich_table_execute_basic.py +++ b/docs/examples/mify/rich_table_execute_basic.py @@ -4,8 +4,7 @@ import os from mellea import start_session -from mellea.backends import model_ids -from mellea.backends import ModelOption +from mellea.backends import ModelOption, model_ids from mellea.core import FancyLogger from mellea.stdlib.components.docs.richdocument import RichDocument, Table diff --git a/docs/examples/mini_researcher/researcher.py b/docs/examples/mini_researcher/researcher.py index 2d092723..87cdeda0 100644 --- a/docs/examples/mini_researcher/researcher.py +++ b/docs/examples/mini_researcher/researcher.py @@ -2,6 +2,7 @@ from collections.abc import Callable from functools import cache +from typing import Any from openai import BaseModel from pydantic import ValidationError @@ -11,9 +12,9 @@ from mellea import MelleaSession from mellea.backends import model_ids from mellea.backends.ollama import OllamaModelBackend +from mellea.core import CBlock, Component, Requirement, SamplingResult from mellea.stdlib.requirements import simple_validate from mellea.stdlib.sampling import RejectionSamplingStrategy -from mellea.core import SamplingResult, Requirement # ############################# # Helper functions @@ -128,7 +129,7 @@ def max_sub_sections(out: str) -> bool: validation_fn=simple_validate(max_sub_sections), ) - outline_context = { + outline_context: dict[str, str | CBlock | Component] = { f"Document {i + 1}": f"## Title: {d.title}, ## Source: {d.source}" for i, d in enumerate(context) } @@ -138,14 +139,14 @@ def max_sub_sections(out: str) -> bool: description="Create an outline for a report on how {{current_subtopic}} impacts {{main_topic}}. Use the Context Documents provided as guideline for the sections.", # output_prefix="# Introduction", requirements=[req_outline, req_num_sections], - user_variables=user_args, grounding_context=outline_context, + user_variables=user_args, strategy=RejectionSamplingStrategy(loop_budget=2), return_sampling_results=True, format=SectionTitles, ) - st = SectionTitles.model_validate_json(outline_result.value) + st = SectionTitles.model_validate_json(outline_result.value or "") if isinstance(outline_result, SamplingResult): if not outline_result.success: @@ -203,12 +204,12 @@ def step_write_full_report( print(f"\t{v[1]} <- {v[0].description}") print("done.") - return report_result.value + return report_result.value or "" def research_subtopic(main_topic: str, subtopic: str, context: list[RAGDocument]): """Start MiniResearcher here.""" - user_args = { + user_args: dict[str, Any] = { "context_docs": context, "current_subtopic": subtopic, "main_topic": main_topic, diff --git a/docs/examples/mobject/table.py b/docs/examples/mobject/table.py index 93234a06..da2b98da 100644 --- a/docs/examples/mobject/table.py +++ b/docs/examples/mobject/table.py @@ -30,7 +30,7 @@ def update_sales(self, store: str, amount: str): index_col=False, ) # Remove unnamed columns and columns that don't exist. - table_df.drop(table_df.filter(regex="Unname").columns, axis=1, inplace=True) + table_df = table_df.drop(table_df.filter(regex="Unname").columns, axis=1) # Sometimes extra whitespace gets added to the column names and row values. Remove it. table_df.columns = table_df.columns.str.strip() diff --git a/docs/examples/notebooks/georgia_tech.ipynb b/docs/examples/notebooks/georgia_tech.ipynb index 08422fb4..a83e3169 100644 --- a/docs/examples/notebooks/georgia_tech.ipynb +++ b/docs/examples/notebooks/georgia_tech.ipynb @@ -439,9 +439,9 @@ }, "outputs": [], "source": [ + "from mellea.backends import ModelOption\n", "from mellea.backends.model_ids import META_LLAMA_3_2_3B\n", "from mellea.backends.ollama import OllamaModelBackend\n", - "from mellea.backends import ModelOption\n", "\n", "# You can use multiple different models at the same time!\n", "m_llama = mellea.MelleaSession(backend=OllamaModelBackend(model_id=META_LLAMA_3_2_3B))\n", diff --git a/docs/examples/notebooks/m_serve_example.ipynb b/docs/examples/notebooks/m_serve_example.ipynb index 729b75bf..fb3e7429 100644 --- a/docs/examples/notebooks/m_serve_example.ipynb +++ b/docs/examples/notebooks/m_serve_example.ipynb @@ -89,8 +89,8 @@ "\n", "import mellea\n", "from cli.serve.models import ChatMessage\n", - "from mellea.stdlib.context import ChatContext\n", "from mellea.core import ModelOutputThunk, Requirement, SamplingResult\n", + "from mellea.stdlib.context import ChatContext\n", "from mellea.stdlib.requirements import simple_validate\n", "from mellea.stdlib.sampling import RejectionSamplingStrategy\n", "\n", diff --git a/docs/examples/notebooks/model_options_example.ipynb b/docs/examples/notebooks/model_options_example.ipynb index 0216010c..e518679d 100644 --- a/docs/examples/notebooks/model_options_example.ipynb +++ b/docs/examples/notebooks/model_options_example.ipynb @@ -88,9 +88,8 @@ "outputs": [], "source": [ "import mellea\n", - "from mellea.backends import model_ids\n", - "from mellea.backends.ollama import OllamaModelBackend\n", - "from mellea.backends import ModelOption" + "from mellea.backends import ModelOption, model_ids\n", + "from mellea.backends.ollama import OllamaModelBackend" ] }, { diff --git a/docs/examples/rag/simple_rag_with_filter.py b/docs/examples/rag/simple_rag_with_filter.py index adf17ecd..9ff1e6d0 100644 --- a/docs/examples/rag/simple_rag_with_filter.py +++ b/docs/examples/rag/simple_rag_with_filter.py @@ -44,7 +44,7 @@ def create_index(model, ds: list[str]) -> IndexFlatIP: def query_index(model, idx: IndexFlatIP, query: str, ds: list[str], k: int = 5) -> list: query_embedding = model.encode([query]) - distances, indices = idx.search(query_embedding, k=k) + _distances, indices = idx.search(query_embedding, k=k) return [ds[i] for i in indices[0]] diff --git a/docs/examples/safety/guardian.py b/docs/examples/safety/guardian.py index 33913921..cf981391 100644 --- a/docs/examples/safety/guardian.py +++ b/docs/examples/safety/guardian.py @@ -5,9 +5,10 @@ from mellea import MelleaSession from mellea.backends import model_ids from mellea.backends.ollama import OllamaModelBackend -from mellea.core import ContextTurn, ModelOutputThunk -from mellea.stdlib.context import ChatContext +from mellea.backends.tools import MelleaTool +from mellea.core import ContextTurn, ModelOutputThunk, ModelToolCall from mellea.stdlib.components import Message +from mellea.stdlib.context import ChatContext from mellea.stdlib.requirements.safety.guardian import GuardianCheck, GuardianRisk # Enhanced GuardianCheck with Granite Guardian 3.3 8B support @@ -68,7 +69,7 @@ # Show Ollama backend configuration ollama_guardian = GuardianCheck(GuardianRisk.HARM, backend_type="ollama") -print(f" Ollama backend: {ollama_guardian._backend.model_version}") +print(f" Ollama backend: {ollama_guardian._backend.model_version}") # type: ignore[attr-defined] print("\n=== Test 4: Groundedness Detection ===") # Test groundedness - detecting when responses lack factual grounding @@ -131,7 +132,9 @@ def dummy_func(**kwargs): hallucinated_tool_calls = { "comments_list": ModelToolCall( - name="comments_list", func=dummy_func, args={"video_id": 456789123, "count": 15} + name="comments_list", + func=MelleaTool.from_callable(dummy_func), + args={"video_id": 456789123, "count": 15}, ) } diff --git a/docs/examples/safety/guardian_huggingface.py b/docs/examples/safety/guardian_huggingface.py index dceba917..72db99b8 100644 --- a/docs/examples/safety/guardian_huggingface.py +++ b/docs/examples/safety/guardian_huggingface.py @@ -8,11 +8,12 @@ from mellea import MelleaSession from mellea.backends import model_ids -from mellea.backends.ollama import OllamaModelBackend from mellea.backends.huggingface import LocalHFBackend +from mellea.backends.ollama import OllamaModelBackend +from mellea.backends.tools import MelleaTool from mellea.core import ModelOutputThunk, ModelToolCall -from mellea.stdlib.context import ChatContext from mellea.stdlib.components import Message +from mellea.stdlib.context import ChatContext from mellea.stdlib.requirements.safety.guardian import GuardianCheck, GuardianRisk print("=== GuardianCheck HuggingFace Backend Example ===") @@ -48,7 +49,7 @@ print(f"Guardian detected harm: {not validation_result[0]._result}") if validation_result[0]._reason: - print(f"\nGuardian feedback:") + print("\nGuardian feedback:") print(validation_result[0]._reason[:200] + "...") # Test 2: Groundedness detection @@ -111,7 +112,9 @@ def dummy_func(**kwargs): hallucinated_tool_calls = { "get_stock_price": ModelToolCall( - name="get_stock_price", func=dummy_func, args={"symbol": "AAPL"} + name="get_stock_price", + func=MelleaTool.from_callable(dummy_func), + args={"symbol": "AAPL"}, ) } diff --git a/docs/examples/safety/repair_with_guardian.py b/docs/examples/safety/repair_with_guardian.py index 7c2a0c96..8a4a3e10 100644 --- a/docs/examples/safety/repair_with_guardian.py +++ b/docs/examples/safety/repair_with_guardian.py @@ -1,7 +1,6 @@ # pytest: huggingface, requires_heavy_ram, llm -""" -RepairTemplateStrategy Example with Actual Function Call Validation +"""RepairTemplateStrategy Example with Actual Function Call Validation Demonstrates how RepairTemplateStrategy repairs responses using actual function calls. """ @@ -80,10 +79,10 @@ def get_stock_price(symbol: str) -> str: if hasattr(m.backend, "formatter"): try: rendered = m.backend.formatter.print(action) - print(f" Instruction sent to model:") - print(f" ---") + print(" Instruction sent to model:") + print(" ---") print(f" {rendered}") - print(f" ---") + print(" ---") except Exception: pass diff --git a/docs/examples/sessions/creating_a_new_type_of_session.py b/docs/examples/sessions/creating_a_new_type_of_session.py index 85aec63c..92846ff8 100644 --- a/docs/examples/sessions/creating_a_new_type_of_session.py +++ b/docs/examples/sessions/creating_a_new_type_of_session.py @@ -1,14 +1,21 @@ # pytest: ollama, qualitative, llm from typing import Literal + from PIL import Image as PILImage from mellea import MelleaSession -from mellea.core import Backend, BaseModelSubclass from mellea.backends.ollama import OllamaModelBackend -from mellea.core import CBlock, Context, ImageBlock, Requirement -from mellea.stdlib.context import ChatContext +from mellea.core import ( + Backend, + BaseModelSubclass, + CBlock, + Context, + ImageBlock, + Requirement, +) from mellea.stdlib.components import Message +from mellea.stdlib.context import ChatContext from mellea.stdlib.requirements import reqify from mellea.stdlib.requirements.safety.guardian import GuardianCheck, GuardianRisk diff --git a/docs/examples/tools/interpreter_example.py b/docs/examples/tools/interpreter_example.py index 1e7e3ec8..ea77e801 100644 --- a/docs/examples/tools/interpreter_example.py +++ b/docs/examples/tools/interpreter_example.py @@ -1,9 +1,9 @@ # pytest: ollama, llm -from mellea.stdlib.tools import code_interpreter, local_code_interpreter -from mellea import start_session, MelleaSession +from mellea import MelleaSession, start_session from mellea.backends import ModelOption -from mellea.stdlib.requirements import uses_tool, tool_arg_validator +from mellea.stdlib.requirements import tool_arg_validator, uses_tool +from mellea.stdlib.tools import code_interpreter, local_code_interpreter def example_1(m: MelleaSession): @@ -34,6 +34,9 @@ def example_3(m: MelleaSession): tool_calls=True, ) + if plot_output.tool_calls is None: + raise ValueError("Expected tool_calls but got None") + code = plot_output.tool_calls["local_code_interpreter"].args["code"] print(f"Going to execute the following code:\n```python\n{code}\n```") @@ -64,6 +67,9 @@ def example_4(m: MelleaSession): tool_calls=True, ) + if plot_output.tool_calls is None: + raise ValueError("Expected tool_calls but got None") + code = plot_output.tool_calls["local_code_interpreter"].args["code"] print(f"Going to execute the following code:\n```python\n{code}\n```") diff --git a/docs/examples/tutorial/compositionality_with_generative_slots.py b/docs/examples/tutorial/compositionality_with_generative_slots.py index a8b7c9d5..26c11151 100644 --- a/docs/examples/tutorial/compositionality_with_generative_slots.py +++ b/docs/examples/tutorial/compositionality_with_generative_slots.py @@ -36,7 +36,7 @@ def generate_novel_recommendations(summary: str) -> str: # Compose the libraries. -from typing import Literal # noqa: E402 +from typing import Literal @generative @@ -54,7 +54,7 @@ def has_theme_and_plot(summary: str) -> Literal["yes", "no"]: """Check whether the summary contains both a plot and thematic elements.""" -from mellea import start_session # noqa: E402 +from mellea import start_session m = start_session() transcript = """Meeting Transcript: Market Risk Review -- Self-Sealing Stembolts Division diff --git a/docs/examples/tutorial/document_mobject.py b/docs/examples/tutorial/document_mobject.py index 9be644f3..d413d7d7 100644 --- a/docs/examples/tutorial/document_mobject.py +++ b/docs/examples/tutorial/document_mobject.py @@ -6,13 +6,13 @@ rd = RichDocument.from_document_file("https://arxiv.org/pdf/1906.04043") -from mellea.stdlib.components.docs.richdocument import Table # noqa: E402 +from mellea.stdlib.components.docs.richdocument import Table table1: Table = rd.get_tables()[0] print(table1.to_markdown()) -from mellea import start_session # noqa: E402 -from mellea.backends import ModelOption # noqa: E402 +from mellea import start_session +from mellea.backends import ModelOption m = start_session(model_id=model_ids.META_LLAMA_3_2_3B) for seed in [x * 12 for x in range(5)]: diff --git a/docs/examples/tutorial/instruct_validate_repair.py b/docs/examples/tutorial/instruct_validate_repair.py index 6295358b..4173f01b 100644 --- a/docs/examples/tutorial/instruct_validate_repair.py +++ b/docs/examples/tutorial/instruct_validate_repair.py @@ -1,8 +1,9 @@ # pytest: ollama, llm +from mellea.core import Requirement from mellea.stdlib.requirements import check, req, simple_validate -requirements = [ +requirements: list[Requirement | str] = [ req("The email should have a salutation"), # == r1 req( "Use only lower-case letters", @@ -11,22 +12,22 @@ check("Do not mention purple elephants."), # == r3 ] -import mellea # noqa: E402 -from mellea.stdlib.sampling import RejectionSamplingStrategy # noqa: E402 +import mellea +from mellea.stdlib.sampling import RejectionSamplingStrategy def write_email(m: mellea.MelleaSession, name: str, notes: str) -> str: email_candidate = m.instruct( "Write an email to {{name}} using the notes following: {{notes}}.", requirements=requirements, - strategy=RejectionSamplingStrategy(loop_budget=5), user_variables={"name": name, "notes": notes}, + strategy=RejectionSamplingStrategy(loop_budget=5), return_sampling_results=True, ) if email_candidate.success: return str(email_candidate.result) else: - return email_candidate.sample_generations[0].value + return email_candidate.sample_generations[0].value or "" m = mellea.start_session() diff --git a/docs/examples/tutorial/mcp_example.py b/docs/examples/tutorial/mcp_example.py deleted file mode 100644 index 25bf5199..00000000 --- a/docs/examples/tutorial/mcp_example.py +++ /dev/null @@ -1,54 +0,0 @@ -"""Example of an MCP server. - -You need to install the mcp package: -uv pip install "mcp[cli]" - -and run the example in MCP debug UI: -uv run mcp dev docs/examples/tutorial/mcp_example.py -""" - -from mcp.server.fastmcp import FastMCP - -from mellea import MelleaSession -from mellea.backends import ModelOption, model_ids -from mellea.backends.ollama import OllamaModelBackend -from mellea.core import ModelOutputThunk -from mellea.stdlib.requirements import requirement, simple_validate -from mellea.stdlib.sampling import RejectionSamplingStrategy - -# ################# -# run MCP debug UI with: uv run mcp dev docs/examples/tutorial/mcp_example.py -# ################## - - -# Create an MCP server -mcp = FastMCP("Demo") - - -@mcp.tool() -def write_a_poem(word_limit: int) -> str: - """Write a poem with a word limit.""" - m = MelleaSession( - OllamaModelBackend( - model_ids.HF_SMOLLM2_2B, - model_options={ModelOption.MAX_NEW_TOKENS: word_limit + 10}, - ) - ) - wl_req = Requirement( - f"Use only {word_limit} words.", - validation_fn=simple_validate(lambda x: len(x.split(" ")) < word_limit), - ) - - res = m.instruct( - "Write a poem", - requirements=[wl_req], - strategy=RejectionSamplingStrategy(loop_budget=2), - ) - assert isinstance(res, ModelOutputThunk) - return str(res.value) - - -@mcp.resource("greeting://{name}") -def get_greeting(name: str) -> str: - """Get a personalized greeting.""" - return f"Hello, {name}!" diff --git a/docs/examples/tutorial/model_options_example.py b/docs/examples/tutorial/model_options_example.py index 2ee1ed2b..89971120 100644 --- a/docs/examples/tutorial/model_options_example.py +++ b/docs/examples/tutorial/model_options_example.py @@ -1,9 +1,8 @@ # pytest: ollama, llm import mellea -from mellea.backends import model_ids +from mellea.backends import ModelOption, model_ids from mellea.backends.ollama import OllamaModelBackend -from mellea.backends import ModelOption m = mellea.MelleaSession( backend=OllamaModelBackend(model_options={ModelOption.SEED: 42}) diff --git a/docs/examples/tutorial/simple_email.py b/docs/examples/tutorial/simple_email.py index 9f71462c..fb86dae8 100644 --- a/docs/examples/tutorial/simple_email.py +++ b/docs/examples/tutorial/simple_email.py @@ -54,7 +54,7 @@ def write_email_with_requirements( ) print("Email with rejection sampling:") -from mellea.stdlib.sampling import RejectionSamplingStrategy # noqa: E402 +from mellea.stdlib.sampling import RejectionSamplingStrategy def write_email_with_strategy(m: mellea.MelleaSession, name: str, notes: str) -> str: diff --git a/docs/kv_smash/hf_example.py b/docs/kv_smash/hf_example.py index dc81a64a..c5db6967 100644 --- a/docs/kv_smash/hf_example.py +++ b/docs/kv_smash/hf_example.py @@ -1,10 +1,11 @@ +import asyncio + +from mellea.backends import ModelOption from mellea.backends.huggingface import LocalHFBackend from mellea.backends.model_ids import IBM_GRANITE_3_3_8B -from mellea.backends import ModelOption from mellea.core import CBlock -from mellea.stdlib.context import ChatContext from mellea.stdlib.components import Message -import asyncio +from mellea.stdlib.context import ChatContext async def example(): diff --git a/docs/kv_smash/kv_with_chat.py b/docs/kv_smash/kv_with_chat.py index e0f43bc4..bdf6f38e 100644 --- a/docs/kv_smash/kv_with_chat.py +++ b/docs/kv_smash/kv_with_chat.py @@ -49,7 +49,9 @@ def merge(toks, dcs): {"role": "user", "content": c_blocks[1]}, {"role": "user", "content": "Also no cash"}, ] -templatized_input = tokenizer.apply_chat_template(conversation=messages, tokenize=False) +templatized_input: str = tokenizer.apply_chat_template( # type: ignore[assignment] + conversation=messages, tokenize=False +) str_parts = [] tok_parts = [] @@ -93,7 +95,7 @@ def merge(toks, dcs): merged_dcs.crop(-1) # generate and print result. -result = model.generate( +result = model.generate( # type: ignore[operator] merged_toks.to(device), attention_mask=merged_masks.to(device), past_key_values=merged_dcs, diff --git a/docs/kv_smash/kvcache.py b/docs/kv_smash/kvcache.py index 94bfcc58..51e1b5cc 100644 --- a/docs/kv_smash/kvcache.py +++ b/docs/kv_smash/kvcache.py @@ -5,10 +5,9 @@ # ] # /// import torch +from transformers import AutoModelForCausalLM, AutoTokenizer, PreTrainedTokenizer from mellea.backends.kv_block_helpers import DynamicCache, merge_dynamic_caches -from transformers import AutoModelForCausalLM, PreTrainedTokenizer, AutoTokenizer -import torch model_id = "ibm-granite/granite-3.3-8b-instruct" device = torch.device("mps") diff --git a/docs/rewrite/session_deepdive/0.py b/docs/rewrite/session_deepdive/step0_session_api.py similarity index 99% rename from docs/rewrite/session_deepdive/0.py rename to docs/rewrite/session_deepdive/step0_session_api.py index 82d15776..271586b9 100644 --- a/docs/rewrite/session_deepdive/0.py +++ b/docs/rewrite/session_deepdive/step0_session_api.py @@ -1,7 +1,6 @@ from mellea import MelleaSession -from mellea.stdlib.context import SimpleContext from mellea.backends.ollama import OllamaModelBackend - +from mellea.stdlib.context import SimpleContext m = MelleaSession(backend=OllamaModelBackend("granite4:latest"), ctx=SimpleContext()) response = m.chat("What is 1+1?") diff --git a/docs/rewrite/session_deepdive/1.py b/docs/rewrite/session_deepdive/step1_functional_api.py similarity index 100% rename from docs/rewrite/session_deepdive/1.py rename to docs/rewrite/session_deepdive/step1_functional_api.py index c12e7df6..e629e840 100644 --- a/docs/rewrite/session_deepdive/1.py +++ b/docs/rewrite/session_deepdive/step1_functional_api.py @@ -1,6 +1,6 @@ import mellea.stdlib.functional as mfuncs -from mellea.stdlib.context import SimpleContext from mellea.backends.ollama import OllamaModelBackend +from mellea.stdlib.context import SimpleContext response, next_context = mfuncs.chat( "What is 1+1?", diff --git a/docs/rewrite/session_deepdive/2.py b/docs/rewrite/session_deepdive/step2_act_cblocks.py similarity index 76% rename from docs/rewrite/session_deepdive/2.py rename to docs/rewrite/session_deepdive/step2_act_cblocks.py index 8002128d..86e9d285 100644 --- a/docs/rewrite/session_deepdive/2.py +++ b/docs/rewrite/session_deepdive/step2_act_cblocks.py @@ -1,10 +1,10 @@ import mellea.stdlib.functional as mfuncs -from mellea.stdlib.context import SimpleContext -from mellea.core import CBlock from mellea.backends.ollama import OllamaModelBackend +from mellea.stdlib.components import Instruction +from mellea.stdlib.context import SimpleContext response, next_context = mfuncs.act( - CBlock("What is 1+1?"), + action=Instruction("What is 1+1?"), context=SimpleContext(), backend=OllamaModelBackend("granite4:latest"), ) diff --git a/docs/rewrite/session_deepdive/3.py b/docs/rewrite/session_deepdive/step3_async.py similarity index 59% rename from docs/rewrite/session_deepdive/3.py rename to docs/rewrite/session_deepdive/step3_async.py index 1d522f77..5e765add 100644 --- a/docs/rewrite/session_deepdive/3.py +++ b/docs/rewrite/session_deepdive/step3_async.py @@ -1,13 +1,15 @@ +import asyncio + import mellea.stdlib.functional as mfuncs -from mellea.core import CBlock, Context, Backend -from mellea.stdlib.context import SimpleContext from mellea.backends.ollama import OllamaModelBackend -import asyncio +from mellea.core import Backend, Context +from mellea.stdlib.components import Instruction +from mellea.stdlib.context import SimpleContext async def main(backend: Backend, ctx: Context): - response, next_context = await mfuncs.aact( - CBlock("What is 1+1?"), context=ctx, backend=backend + response, _next_context = await mfuncs.aact( + action=Instruction("What is 1+1?"), context=ctx, backend=backend ) print(response.value) diff --git a/docs/rewrite/session_deepdive/4.py b/docs/rewrite/session_deepdive/step4_lazy_thunks.py similarity index 85% rename from docs/rewrite/session_deepdive/4.py rename to docs/rewrite/session_deepdive/step4_lazy_thunks.py index 9cd71dd8..1dc3871f 100644 --- a/docs/rewrite/session_deepdive/4.py +++ b/docs/rewrite/session_deepdive/step4_lazy_thunks.py @@ -1,13 +1,14 @@ -from mellea.core import CBlock, Context, Backend +import asyncio + from mellea.backends.ollama import OllamaModelBackend +from mellea.core import Backend, CBlock, Context from mellea.stdlib.context import SimpleContext -import asyncio async def main(backend: Backend, ctx: Context): # This is not actually an async function; the computation ends immediately. It must be awaited because we create the thunk. # TODO clean up the above comment. - response, next_context = await backend.generate_from_context( + response, _next_context = await backend.generate_from_context( CBlock("What is 1+1?"), ctx=ctx, # TODO we should rationalize ctx and context acress mfuncs and base/backend. ) diff --git a/docs/rewrite/session_deepdive/5.py b/docs/rewrite/session_deepdive/step5_composition.py similarity index 95% rename from docs/rewrite/session_deepdive/5.py rename to docs/rewrite/session_deepdive/step5_composition.py index f03369c4..2494f5c8 100644 --- a/docs/rewrite/session_deepdive/5.py +++ b/docs/rewrite/session_deepdive/step5_composition.py @@ -1,10 +1,10 @@ -from mellea.core import CBlock, Context, Backend +import asyncio + from mellea.backends.ollama import OllamaModelBackend +from mellea.core import Backend, CBlock, Context from mellea.stdlib.components import SimpleComponent from mellea.stdlib.context import SimpleContext -import asyncio - async def main(backend: Backend, ctx: Context): x, _ = await backend.generate_from_context(CBlock("What is 1+1?"), ctx=ctx) diff --git a/docs/rewrite/streaming/0.py b/docs/rewrite/streaming/streaming_chat_example.py similarity index 100% rename from docs/rewrite/streaming/0.py rename to docs/rewrite/streaming/streaming_chat_example.py index 2098ea3d..3235ce28 100644 --- a/docs/rewrite/streaming/0.py +++ b/docs/rewrite/streaming/streaming_chat_example.py @@ -1,8 +1,8 @@ +import asyncio + from mellea import start_session -from mellea.core.base import CBlock from mellea.backends.model_options import ModelOption - -import asyncio +from mellea.core.base import CBlock async def stream_chat(prompt: str) -> str: diff --git a/mellea/backends/huggingface.py b/mellea/backends/huggingface.py index 26ffb43c..b35c5fd2 100644 --- a/mellea/backends/huggingface.py +++ b/mellea/backends/huggingface.py @@ -635,7 +635,7 @@ async def _generate_from_context_with_kv_cache( linearized_ctx = ctx.view_for_generation() assert linearized_ctx is not None - input_text, input_ids, merged_cache, attention_mask = ( + _input_text, input_ids, merged_cache, attention_mask = ( self._make_merged_kv_cache( linearized_ctx=linearized_ctx, ctx_as_conversation=ctx_as_chat, diff --git a/pyproject.toml b/pyproject.toml index f53faffc..0d6290b5 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -176,7 +176,7 @@ ignore = [ # "UP006", # List vs list, etc # "UP007", # Option and Union # "UP035", # `typing.Set` is deprecated, use `set` instead" - "PD901", # Avoid using the generic variable name `df` for DataFrames + "PD901", # Generic variable name 'df' for DataFrames (deprecated rule, but needed while PD is enabled) "C901", # Complexity warnings ] @@ -190,6 +190,23 @@ max-complexity = 20 combine-as-imports = true split-on-trailing-comma = false +[tool.ruff.lint.per-file-ignores] +# E402: Module level import not at top of file +# Intentional in examples (pedagogical structure) and tests (pytestmark before imports) +# D: Docstring errors +# Not required in examples, tests, and notebooks (core mellea/ has complete docstrings) +"docs/**/*.py" = ["E402", "D"] +"docs/**/*.ipynb" = ["D"] +"test/**/*.py" = ["E402", "D"] +"cli/**/*.py" = ["D"] + +[[tool.mypy.overrides]] +# Keep import-not-found suppressed for optional dependencies +module = "docs.*" +disable_error_code = [ + "import-not-found", +] + [tool.codespell] ignore-words-list = 'mellea,hashi,noo,Asai,asai,nd,mot,rouge,Rouge,Strat' check-filenames = true diff --git a/test/__init__.py b/test/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/test/backends/test_adapters/test_adapter.py b/test/backends/test_adapters/test_adapter.py index 632cf7cf..e9022164 100644 --- a/test/backends/test_adapters/test_adapter.py +++ b/test/backends/test_adapters/test_adapter.py @@ -1,4 +1,5 @@ import pathlib + import pytest from mellea.backends.adapters import GraniteCommonAdapter diff --git a/test/backends/test_huggingface.py b/test/backends/test_huggingface.py index 8add07aa..6b5bfb85 100644 --- a/test/backends/test_huggingface.py +++ b/test/backends/test_huggingface.py @@ -72,7 +72,7 @@ def session(backend): @pytest.mark.qualitative -def test_adapters(backend): +def test_adapters(backend) -> None: assert len(backend._added_adapters.items()) > 0 expected_qualified_name = "requirement_check_alora" @@ -90,7 +90,7 @@ def test_adapters(backend): @pytest.mark.qualitative -def test_system_prompt(session): +def test_system_prompt(session) -> None: result = session.chat( "Where are we going?", model_options={ModelOption.SYSTEM_PROMPT: "Talk like a pirate."}, @@ -99,8 +99,8 @@ def test_system_prompt(session): @pytest.mark.qualitative -def test_constraint_lora_with_requirement(session, backend): - answer = session.instruct( +def test_constraint_lora_with_requirement(session, backend) -> None: + session.instruct( "Corporate wants you to find the difference between these two strings: aaaaaaaaaa aaaaabaaaa" ) assert session.backend._cache is not None # type: ignore @@ -116,9 +116,9 @@ def test_constraint_lora_with_requirement(session, backend): @pytest.mark.qualitative -def test_constraint_lora_override(session, backend): +def test_constraint_lora_override(session, backend) -> None: backend.default_to_constraint_checking_alora = False # type: ignore - answer = session.instruct( + session.instruct( "Corporate wants you to find the difference between these two strings: aaaaaaaaaa aaaaabaaaa" ) validation_outputs = session.validate( @@ -132,9 +132,9 @@ def test_constraint_lora_override(session, backend): @pytest.mark.qualitative -def test_constraint_lora_override_does_not_override_alora(session, backend): +def test_constraint_lora_override_does_not_override_alora(session, backend) -> None: backend.default_to_constraint_checking_alora = False # type: ignore - answer = session.instruct( + session.instruct( "Corporate wants you to find the difference between these two strings: aaaaaaaaaa aaaaabaaaa" ) validation_outputs = session.validate( @@ -158,9 +158,9 @@ def test_constraint_lora_override_does_not_override_alora(session, backend): @pytest.mark.qualitative -def test_llmaj_req_does_not_use_alora(session, backend): +def test_llmaj_req_does_not_use_alora(session, backend) -> None: backend.default_to_constraint_checking_alora = True # type: ignore - answer = session.instruct( + session.instruct( "Corporate wants you to find the difference between these two strings: aaaaaaaaaa aaaaabaaaa" ) validation_outputs = session.validate( @@ -176,15 +176,15 @@ def test_llmaj_req_does_not_use_alora(session, backend): @pytest.mark.qualitative -def test_instruct(session): +def test_instruct(session) -> None: result = session.instruct("Compute 1+1.") print(result) @pytest.mark.qualitative -def test_multiturn(session): +def test_multiturn(session) -> None: session.instruct("Compute 1+1") - beta = session.instruct( + session.instruct( "Take the result of the previous sum and find the corresponding letter in the greek alphabet.", model_options={ModelOption.MAX_NEW_TOKENS: 300}, ) @@ -193,7 +193,7 @@ def test_multiturn(session): @pytest.mark.qualitative -def test_chat(session): +def test_chat(session) -> None: output_message = session.chat("What is 1+1?") assert "2" in output_message.content, ( f"Expected a message with content containing 2 but found {output_message}" @@ -201,7 +201,7 @@ def test_chat(session): @pytest.mark.qualitative -def test_format(session): +def test_format(session) -> None: class Person(pydantic.BaseModel): name: str email_address: Annotated[ @@ -235,7 +235,7 @@ class Email(pydantic.BaseModel): @pytest.mark.qualitative -async def test_generate_from_raw(session): +async def test_generate_from_raw(session) -> None: prompts = [ "what is 1+1?", "what is 2+2?", @@ -253,7 +253,7 @@ async def test_generate_from_raw(session): @pytest.mark.qualitative -async def test_generate_from_raw_with_format(session): +async def test_generate_from_raw_with_format(session) -> None: prompts = ["what is 1+1?", "what is 2+2?", "what is 3+3?", "what is 4+4?"] class Answer(pydantic.BaseModel): @@ -270,7 +270,7 @@ class Answer(pydantic.BaseModel): random_result = results[0] try: - answer = Answer.model_validate_json(random_result.value) + Answer.model_validate_json(random_result.value) except pydantic.ValidationError as e: assert False, ( f"formatting directive failed for {random_result.value}: {e.json()}" @@ -278,7 +278,7 @@ class Answer(pydantic.BaseModel): @pytest.mark.qualitative -async def test_async_parallel_requests(session): +async def test_async_parallel_requests(session) -> None: model_opts = {ModelOption.STREAM: True} mot1, _ = await session.backend.generate_from_context( CBlock("Say Hello."), SimpleContext(), model_options=model_opts @@ -314,7 +314,7 @@ async def test_async_parallel_requests(session): @pytest.mark.qualitative -async def test_async_avalue(session): +async def test_async_avalue(session) -> None: mot1, _ = await session.backend.generate_from_context( CBlock("Say Hello."), SimpleContext() ) @@ -324,7 +324,7 @@ async def test_async_avalue(session): @pytest.mark.qualitative -async def test_generate_with_lock(backend): +async def test_generate_with_lock(backend) -> None: # Enable the faulthandler for this test. faulthandler.enable(all_threads=True) @@ -343,7 +343,7 @@ async def test_generate_with_lock(backend): GraniteCommonAdapter("answerability", base_model_name=b.base_model_name) ) - memoized = dict() + memoized: dict[torch.Tensor, str] = dict() gen_func = model.generate def mock_func(input_ids, *args, **kwargs): @@ -414,7 +414,7 @@ def call_backend_generate(): @pytest.mark.skipif( sys.version_info < (3, 11), reason="asyncio.timeout requires python3.11 or higher" ) -async def test_generate_with_lock_does_not_block_when_awaiting_value(backend): +async def test_generate_with_lock_does_not_block_when_awaiting_value(backend) -> None: """This is a tricky test to setup. It's purpose is to ensure that a long-running generation doesn't get blocked @@ -470,8 +470,7 @@ async def test_generate_with_lock_does_not_block_when_awaiting_value(backend): # most likely due to a deadlock caused by awaiting a generation that cannot complete until # the streaming is done. try: - async with asyncio.timeout(timeout_in_seconds): - await req_mot.avalue() + await asyncio.wait_for(req_mot.avalue(), timeout=timeout_in_seconds) except Exception as e: # The timeout could also be caused by the generation taking too long... be careful! # We assume that if the streaming model output thunk is computed after getting its astream here, @@ -488,7 +487,7 @@ async def test_generate_with_lock_does_not_block_when_awaiting_value(backend): @pytest.mark.qualitative -async def test_error_during_generate_with_lock(backend): +async def test_error_during_generate_with_lock(backend) -> None: # Create local versions of these objects so that mocking # doesn't impact other functions. Don't do this in regular code, # the copying is complex. @@ -529,7 +528,7 @@ def generate_and_raise_exc(*args, **kwargs): await req_mot.avalue() -def test_assert_correct_adapters(): +def test_assert_correct_adapters() -> None: model = Mock() # Test scenarios with no active adapters. diff --git a/test/backends/test_litellm_ollama.py b/test/backends/test_litellm_ollama.py index 2fb70e19..7999528a 100644 --- a/test/backends/test_litellm_ollama.py +++ b/test/backends/test_litellm_ollama.py @@ -132,7 +132,6 @@ def test_gen_slot(session): @generative def is_happy(text: str) -> bool: """Determine if text is of happy mood.""" - ... h = is_happy(session, text="I'm enjoying life.") diff --git a/test/backends/test_litellm_watsonx.py b/test/backends/test_litellm_watsonx.py index 014eeb2b..80f65b09 100644 --- a/test/backends/test_litellm_watsonx.py +++ b/test/backends/test_litellm_watsonx.py @@ -23,7 +23,7 @@ def session(): session.reset() -def test_has_potential_event_loop_errors(session): +def test_has_potential_event_loop_errors(session) -> None: """This test is specific to litellm backends that use watsonx/. It can be removed once that bug is fixed.""" backend: LiteLLMBackend = session.backend potential_err = backend._has_potential_event_loop_errors() @@ -42,13 +42,13 @@ async def new_event_loop() -> bool: @pytest.mark.qualitative -def test_multiple_sync_funcs(session): +def test_multiple_sync_funcs(session) -> None: session.chat("first") session.chat("second") @pytest.mark.qualitative -async def test_generate_from_raw(session): +async def test_generate_from_raw(session) -> None: prompts = ["what is 1+1?", "what is 2+2?", "what is 3+3?", "what is 4+2+2?"] results = await session.backend.generate_from_raw( @@ -65,7 +65,7 @@ async def test_generate_from_raw(session): @pytest.mark.xfail( reason="litellm has a bug with watsonx; once that is fixed, this should pass." ) -async def test_multiple_async_funcs(session): +async def test_multiple_async_funcs(session) -> None: """If this test passes, remove the _has_potential_event_loop_errors func from litellm.""" session.chat( "first sync" diff --git a/test/backends/test_mellea_tool.py b/test/backends/test_mellea_tool.py index 9449c5fb..85b6a39d 100644 --- a/test/backends/test_mellea_tool.py +++ b/test/backends/test_mellea_tool.py @@ -1,5 +1,5 @@ import pytest -from langchain_core.tools import Tool, tool +from langchain_core.tools import Tool, tool # type: ignore[import-not-found] from pydantic_core import ValidationError import mellea diff --git a/test/backends/test_model_options.py b/test/backends/test_model_options.py index 5a75d7f3..36c1c228 100644 --- a/test/backends/test_model_options.py +++ b/test/backends/test_model_options.py @@ -1,4 +1,5 @@ import pytest + from mellea.backends import ModelOption diff --git a/test/backends/test_ollama.py b/test/backends/test_ollama.py index 922777fb..fcca7fcd 100644 --- a/test/backends/test_ollama.py +++ b/test/backends/test_ollama.py @@ -25,7 +25,7 @@ def session(): @pytest.mark.qualitative -def test_simple_instruct(session): +def test_simple_instruct(session) -> None: result = session.instruct( "Write an email to Hendrik trying to sell him self-sealing stembolts." ) @@ -37,8 +37,8 @@ def test_simple_instruct(session): @pytest.mark.qualitative -def test_instruct_with_requirement(session): - response = session.instruct( +def test_instruct_with_requirement(session) -> None: + session.instruct( "Write an email to Hendrik convincing him to buy some self-sealing stembolts." ) @@ -61,7 +61,7 @@ def test_instruct_with_requirement(session): @pytest.mark.qualitative -def test_chat(session): +def test_chat(session) -> None: output_message = session.chat("What is 1+1?") assert "2" in output_message.content, ( f"Expected a message with content containing 2 but found {output_message}" @@ -69,7 +69,7 @@ def test_chat(session): @pytest.mark.qualitative -def test_format(session): +def test_format(session) -> None: class Person(pydantic.BaseModel): name: str # it does not support regex patterns in json schema @@ -102,7 +102,7 @@ class Email(pydantic.BaseModel): @pytest.mark.qualitative -async def test_generate_from_raw(session): +async def test_generate_from_raw(session) -> None: prompts = ["what is 1+1?", "what is 2+2?", "what is 3+3?", "what is 4+4?"] results = await session.backend.generate_from_raw( @@ -114,7 +114,7 @@ async def test_generate_from_raw(session): @pytest.mark.xfail(reason="ollama sometimes fails generated structured outputs") -async def test_generate_from_raw_with_format(session): +async def test_generate_from_raw_with_format(session) -> None: prompts = ["what is 1+1?", "what is 2+2?", "what is 3+3?", "what is 4+4?"] class Answer(pydantic.BaseModel): @@ -131,14 +131,14 @@ class Answer(pydantic.BaseModel): random_result = results[0] try: - answer = Answer.model_validate_json(random_result.value) + Answer.model_validate_json(random_result.value) except pydantic.ValidationError as e: assert False, ( f"formatting directive failed for {random_result.value}: {e.json()}" ) -async def test_async_parallel_requests(session): +async def test_async_parallel_requests(session) -> None: model_opts = {ModelOption.STREAM: True} mot1, _ = await session.backend.generate_from_context( CBlock("Say Hello."), SimpleContext(), model_options=model_opts @@ -173,7 +173,7 @@ async def test_async_parallel_requests(session): assert m2_final_val == mot2.value -async def test_async_avalue(session): +async def test_async_avalue(session) -> None: mot1, _ = await session.backend.generate_from_context( CBlock("Say Hello."), SimpleContext() ) @@ -182,7 +182,7 @@ async def test_async_avalue(session): assert m1_final_val == mot1.value -def test_multiple_asyncio_runs(session): +def test_multiple_asyncio_runs(session) -> None: async def test(): result = await session.achat("hello") assert result is not None @@ -191,7 +191,7 @@ async def test(): asyncio.run(test()) -def test_client_cache(session): +def test_client_cache(session) -> None: backend: OllamaModelBackend = session.backend first_client = backend._async_client diff --git a/test/backends/test_openai_ollama.py b/test/backends/test_openai_ollama.py index 86f35174..d95648b0 100644 --- a/test/backends/test_openai_ollama.py +++ b/test/backends/test_openai_ollama.py @@ -39,14 +39,14 @@ def m_session(backend): @pytest.mark.qualitative -def test_instruct(m_session): +def test_instruct(m_session) -> None: result = m_session.instruct("Compute 1+1.") assert isinstance(result, ModelOutputThunk) assert "2" in result.value # type: ignore @pytest.mark.qualitative -def test_multiturn(m_session): +def test_multiturn(m_session) -> None: m_session.instruct("What is the capital of France?") answer = m_session.instruct("Tell me the answer to the previous question.") assert "Paris" in answer.value # type: ignore @@ -68,7 +68,7 @@ def test_multiturn(m_session): @pytest.mark.qualitative -def test_chat(m_session): +def test_chat(m_session) -> None: output_message = m_session.chat("What is 1+1?") assert "2" in output_message.content, ( f"Expected a message with content containing 2 but found {output_message}" @@ -76,7 +76,7 @@ def test_chat(m_session): @pytest.mark.qualitative -def test_format(m_session): +def test_format(m_session) -> None: class Person(pydantic.BaseModel): name: str # it does not support regex patterns in json schema @@ -109,11 +109,11 @@ class Email(pydantic.BaseModel): @pytest.mark.qualitative -async def test_generate_from_raw(m_session): +async def test_generate_from_raw(m_session) -> None: prompts = ["what is 1+1?", "what is 2+2?", "what is 3+3?", "what is 4+4?"] with pytest.raises(openai.BadRequestError): - results = await m_session.backend.generate_from_raw( + await m_session.backend.generate_from_raw( actions=[CBlock(value=prompt) for prompt in prompts], ctx=m_session.ctx ) @@ -141,7 +141,7 @@ async def test_generate_from_raw(m_session): # assert False, f"formatting directive failed for {random_result.value}: {e.json()}" -async def test_async_parallel_requests(m_session): +async def test_async_parallel_requests(m_session) -> None: model_opts = {ModelOption.STREAM: True} mot1, _ = await m_session.backend.generate_from_context( CBlock("Say Hello."), SimpleContext(), model_options=model_opts @@ -176,7 +176,7 @@ async def test_async_parallel_requests(m_session): assert m2_final_val == mot2.value -async def test_async_avalue(m_session): +async def test_async_avalue(m_session) -> None: mot1, _ = await m_session.backend.generate_from_context( CBlock("Say Hello."), SimpleContext() ) @@ -185,7 +185,7 @@ async def test_async_avalue(m_session): assert m1_final_val == mot1.value -def test_client_cache(backend): +def test_client_cache(backend) -> None: first_client = backend._async_client async def get_client_async(): @@ -211,7 +211,7 @@ async def get_client_async(): assert len(backend._client_cache.cache.values()) == 2 -async def test_reasoning_effort_conditional_passing(backend): +async def test_reasoning_effort_conditional_passing(backend) -> None: """Test that reasoning_effort is only passed to API when not None.""" from unittest.mock import AsyncMock, MagicMock, patch @@ -251,7 +251,7 @@ async def test_reasoning_effort_conditional_passing(backend): ) -def test_api_key_and_base_url_from_parameters(): +def test_api_key_and_base_url_from_parameters() -> None: """Test that API key and base URL can be set via parameters.""" backend = OpenAIBackend( model_id="gpt-4", api_key="test-api-key", base_url="https://api.test.com/v1" @@ -260,7 +260,7 @@ def test_api_key_and_base_url_from_parameters(): assert backend._base_url == "https://api.test.com/v1" -def test_parameter_overrides_env_variable(): +def test_parameter_overrides_env_variable() -> None: """Test that explicit parameters override environment variables.""" with patch.dict( os.environ, @@ -275,7 +275,7 @@ def test_parameter_overrides_env_variable(): assert backend._base_url == "https://api.param.com/v1" -def test_missing_api_key_raises_error(): +def test_missing_api_key_raises_error() -> None: """Test that missing API key raises ValueError with helpful message.""" with patch.dict(os.environ, {}, clear=True): with pytest.raises(ValueError) as exc_info: diff --git a/test/backends/test_openai_vllm/test_openai_vllm.py b/test/backends/test_openai_vllm/test_openai_vllm.py index d30c7f7f..2029dfe5 100644 --- a/test/backends/test_openai_vllm/test_openai_vllm.py +++ b/test/backends/test_openai_vllm/test_openai_vllm.py @@ -1,16 +1,16 @@ # test/rits_backend_tests/test_openai_integration.py import os +from typing import Annotated import pydantic import pytest -from typing_extensions import Annotated from mellea import MelleaSession +from mellea.backends import ModelOption from mellea.backends.adapters import GraniteCommonAdapter -from mellea.formatters import TemplateFormatter from mellea.backends.openai import OpenAIBackend -from mellea.backends import ModelOption from mellea.core import CBlock, Context, ModelOutputThunk +from mellea.formatters import TemplateFormatter from mellea.stdlib.context import ChatContext from mellea.stdlib.requirements import ALoraRequirement, LLMaJRequirement @@ -37,14 +37,14 @@ class TestOpenAIBackend: ) m = MelleaSession(backend, ctx=ChatContext()) - def test_instruct(self): + def test_instruct(self) -> None: self.m.reset() result = self.m.instruct("Compute 1+1.") assert isinstance(result, ModelOutputThunk) assert "2" in result.value # type: ignore self.m.reset() - def test_multiturn(self): + def test_multiturn(self) -> None: self.m.instruct("What is the capital of France?") answer = self.m.instruct("Tell me the answer to the previous question.") assert "Paris" in answer.value # type: ignore @@ -65,7 +65,7 @@ def test_multiturn(self): # assert "granite3.3:8b" in result.value # self.m.reset() - def test_format(self): + def test_format(self) -> None: class Person(pydantic.BaseModel): name: str # it does not support regex patterns in json schema @@ -95,9 +95,8 @@ class Email(pydantic.BaseModel): # this is not guaranteed, due to the lack of regexp pattern # assert "@" in email.to.email_address # assert email.to.email_address.endswith("example.com") - pass - async def test_generate_from_raw(self): + async def test_generate_from_raw(self) -> None: prompts = ["what is 1+1?", "what is 2+2?", "what is 3+3?", "what is 4+4?"] results = await self.m.backend.generate_from_raw( @@ -107,7 +106,7 @@ async def test_generate_from_raw(self): assert len(results) == len(prompts) assert results[0].value is not None - async def test_generate_from_raw_with_format(self): + async def test_generate_from_raw_with_format(self) -> None: prompts = ["what is 1+1?", "what is 2+2?", "what is 3+3?", "what is 4+4?"] class Answer(pydantic.BaseModel): @@ -124,7 +123,7 @@ class Answer(pydantic.BaseModel): random_result = results[0] try: - answer = Answer.model_validate_json(random_result.value) # type: ignore + Answer.model_validate_json(random_result.value) # type: ignore except pydantic.ValidationError as e: assert False, ( f"formatting directive failed for {random_result.value}: {e.json()}" @@ -146,7 +145,7 @@ class TestOpenAIALoraStuff: m = MelleaSession(backend, ctx=ChatContext()) - def test_adapters(self): + def test_adapters(self) -> None: assert len(self.backend._added_adapters.items()) > 0 adapter = self.backend._added_adapters["requirement_check_alora"] @@ -161,7 +160,7 @@ def test_adapters(self): self.backend.unload_adapter(adapter.qualified_name) assert adapter.qualified_name not in self.backend._loaded_adapters - def test_system_prompt(self): + def test_system_prompt(self) -> None: self.m.reset() result = self.m.chat( "Where are we going?", @@ -169,9 +168,9 @@ def test_system_prompt(self): ) print(result) - def test_constraint_lora_with_requirement(self): + def test_constraint_lora_with_requirement(self) -> None: self.m.reset() - answer = self.m.instruct( + self.m.instruct( "Corporate wants you to find the difference between these two strings: aaaaaaaaaa aaaaabaaaa" ) validation_outputs = self.m.validate( @@ -184,10 +183,10 @@ def test_constraint_lora_with_requirement(self): assert "requirement_likelihood" in str(val_result.reason) self.m.reset() - def test_constraint_lora_override(self): + def test_constraint_lora_override(self) -> None: self.m.reset() self.backend.default_to_constraint_checking_alora = False # type: ignore - answer = self.m.instruct( + self.m.instruct( "Corporate wants you to find the difference between these two strings: aaaaaaaaaa aaaaabaaaa" ) validation_outputs = self.m.validate( @@ -201,10 +200,10 @@ def test_constraint_lora_override(self): self.backend.default_to_constraint_checking_alora = True self.m.reset() - def test_constraint_lora_override_does_not_override_alora(self): + def test_constraint_lora_override_does_not_override_alora(self) -> None: self.m.reset() self.backend.default_to_constraint_checking_alora = False # type: ignore - answer = self.m.instruct( + self.m.instruct( "Corporate wants you to find the difference between these two strings: aaaaaaaaaa aaaaabaaaa" ) validation_outputs = self.m.validate( @@ -220,6 +219,7 @@ def test_constraint_lora_override_does_not_override_alora(self): # the correct actions / results in it. assert isinstance(non_alora_output.context, Context) assert isinstance(non_alora_output.thunk, ModelOutputThunk) + assert non_alora_output.context.previous_node is not None assert isinstance( non_alora_output.context.previous_node.node_data, ALoraRequirement, # type: ignore @@ -229,10 +229,10 @@ def test_constraint_lora_override_does_not_override_alora(self): self.backend.default_to_constraint_checking_alora = True self.m.reset() - def test_llmaj_req_does_not_use_alora(self): + def test_llmaj_req_does_not_use_alora(self) -> None: self.m.reset() self.backend.default_to_constraint_checking_alora = True # type: ignore - answer = self.m.instruct( + self.m.instruct( "Corporate wants you to find the difference between these two strings: aaaaaaaaaa aaaaabaaaa" ) validation_outputs = self.m.validate( @@ -245,15 +245,15 @@ def test_llmaj_req_does_not_use_alora(self): assert str(non_alora_output.reason) not in ["Y", "N"] self.m.reset() - def test_instruct(self): + def test_instruct(self) -> None: self.m.reset() result = self.m.instruct("Compute 1+1.") print(result) self.m.reset() - def test_multiturn(self): + def test_multiturn(self) -> None: self.m.instruct("Compute 1+1") - beta = self.m.instruct( + self.m.instruct( "Let n be the result of the previous sum. Find the n-th letter in the greek alphabet." ) words = self.m.instruct( @@ -262,7 +262,7 @@ def test_multiturn(self): print(words) self.m.reset() - def test_format(self): + def test_format(self) -> None: class Person(pydantic.BaseModel): name: str email_address: Annotated[ diff --git a/test/backends/test_tool_calls.py b/test/backends/test_tool_calls.py index 9ebf02dd..bd6c70f3 100644 --- a/test/backends/test_tool_calls.py +++ b/test/backends/test_tool_calls.py @@ -1,16 +1,16 @@ import pytest +from mellea.backends import ModelOption from mellea.backends.ollama import OllamaModelBackend from mellea.backends.tools import ( + AbstractMelleaTool, + MelleaTool, add_tools_from_context_actions, add_tools_from_model_options, - MelleaTool, ) -from mellea.backends import ModelOption from mellea.core import ModelOutputThunk -from mellea.stdlib.context import ChatContext - from mellea.stdlib.components.docs.richdocument import Table +from mellea.stdlib.context import ChatContext from mellea.stdlib.session import MelleaSession @@ -47,7 +47,7 @@ def test2(): ... ModelOption.TOOLS: [MelleaTool.from_callable(t) for t in [test1, test2]] } - tools = {} + tools: dict[str, AbstractMelleaTool] = {} add_tools_from_model_options(tools, model_opts) assert "test1" in tools diff --git a/test/backends/test_tool_helpers.py b/test/backends/test_tool_helpers.py index 5614ac40..95c8a521 100644 --- a/test/backends/test_tool_helpers.py +++ b/test/backends/test_tool_helpers.py @@ -1,10 +1,11 @@ import pytest + +from mellea.backends import ModelOption from mellea.backends.tools import ( + MelleaTool, add_tools_from_context_actions, add_tools_from_model_options, - MelleaTool, ) -from mellea.backends import ModelOption from mellea.core import CBlock, Component, ModelOutputThunk, TemplateRepresentation diff --git a/test/backends/test_vision_openai.py b/test/backends/test_vision_openai.py index b712f53c..bfcfd681 100644 --- a/test/backends/test_vision_openai.py +++ b/test/backends/test_vision_openai.py @@ -121,9 +121,13 @@ def test_image_block_in_instruction( image_url = content_img.get("image_url") assert image_url is not None assert "url" in image_url + assert isinstance(image_url, dict) # check that the image is in the url content - assert image_block.value[:100] in image_url["url"] + url_value = image_url["url"] + assert isinstance(url_value, str) + assert image_block.value is not None + assert image_block.value[:100] in url_value @pytest.mark.qualitative @@ -175,9 +179,14 @@ def test_image_block_in_chat( image_url = content_img.get("image_url") assert image_url is not None assert "url" in image_url + assert isinstance(image_url, dict) # check that the image is in the url content - assert ImageBlock.from_pil_image(pil_image).value[:100] in image_url["url"] + image_value = ImageBlock.from_pil_image(pil_image).value + assert image_value is not None + url_value = image_url["url"] + assert isinstance(url_value, str) + assert image_value[:100] in url_value if __name__ == "__main__": diff --git a/test/backends/test_vllm.py b/test/backends/test_vllm.py index ed4a0354..c4f49a20 100644 --- a/test/backends/test_vllm.py +++ b/test/backends/test_vllm.py @@ -55,7 +55,7 @@ def session(backend): @pytest.mark.qualitative -def test_system_prompt(session): +def test_system_prompt(session) -> None: result = session.chat( "Where are we going?", model_options={ModelOption.SYSTEM_PROMPT: "Talk like a pirate."}, @@ -64,15 +64,15 @@ def test_system_prompt(session): @pytest.mark.qualitative -def test_instruct(session): +def test_instruct(session) -> None: result = session.instruct("Compute 1+1.") print(result) @pytest.mark.qualitative -def test_multiturn(session): +def test_multiturn(session) -> None: session.instruct("Compute 1+1") - beta = session.instruct( + session.instruct( "Take the result of the previous sum and find the corresponding letter in the greek alphabet." ) words = session.instruct("Now list five English words that start with that letter.") @@ -80,7 +80,7 @@ def test_multiturn(session): @pytest.mark.qualitative -def test_format(session): +def test_format(session) -> None: class Person(pydantic.BaseModel): name: str email_address: Annotated[ @@ -111,7 +111,7 @@ class Email(pydantic.BaseModel): @pytest.mark.qualitative -async def test_generate_from_raw(session): +async def test_generate_from_raw(session) -> None: prompts = ["what is 1+1?", "what is 2+2?", "what is 3+3?", "what is 4+4?"] results = await session.backend.generate_from_raw( @@ -123,7 +123,7 @@ async def test_generate_from_raw(session): @pytest.mark.qualitative -async def test_generate_from_raw_with_format(session): +async def test_generate_from_raw_with_format(session) -> None: prompts = ["what is 1+1?", "what is 2+2?", "what is 3+3?", "what is 4+4?"] class Answer(pydantic.BaseModel): @@ -140,7 +140,7 @@ class Answer(pydantic.BaseModel): random_result = results[0] try: - answer = Answer.model_validate_json(random_result.value) + Answer.model_validate_json(random_result.value) except pydantic.ValidationError as e: assert False, ( f"formatting directive failed for {random_result.value}: {e.json()}" @@ -148,7 +148,7 @@ class Answer(pydantic.BaseModel): @pytest.mark.qualitative -def test_async_parallel_requests(session): +def test_async_parallel_requests(session) -> None: async def parallel_requests(): model_opts = {ModelOption.STREAM: True} mot1, _ = await session.backend.generate_from_context( @@ -187,7 +187,7 @@ async def parallel_requests(): @pytest.mark.qualitative -def test_async_avalue(session): +def test_async_avalue(session) -> None: async def avalue(): mot1, _ = await session.backend.generate_from_context( CBlock("Say Hello."), SimpleContext() diff --git a/test/core/test_base.py b/test/core/test_base.py index 30aa6f7f..cada2f42 100644 --- a/test/core/test_base.py +++ b/test/core/test_base.py @@ -1,5 +1,7 @@ from typing import Any + import pytest + from mellea.core import CBlock, Component, ModelOutputThunk from mellea.stdlib.components import Message diff --git a/test/core/test_component_typing.py b/test/core/test_component_typing.py index 6a9c2a82..a20de283 100644 --- a/test/core/test_component_typing.py +++ b/test/core/test_component_typing.py @@ -44,7 +44,7 @@ def _parse(self, computed: ModelOutputThunk) -> int: return -1 try: return int(computed.value) - except: + except Exception: return -2 @@ -80,7 +80,7 @@ def test_mot_init_typing(): assert hasattr(mot, "__orig_class__"), ( "mots are generics and should have this field" ) - assert get_args(mot.__orig_class__)[0] == float, ( # type: ignore + assert get_args(mot.__orig_class__)[0] is float, ( # type: ignore f"expected float, got {get_args(mot.__orig_class__)[0]} as mot type" # type: ignore ) # type: ignore @@ -113,7 +113,7 @@ def test_component_parsing_fails(): def test_incorrect_type_override(): with pytest.raises(TypeError): - instruction = Instruction[int](description="this is an instruction") # type: ignore + Instruction[int](description="this is an instruction") # type: ignore # Marking as qualitative for now since there's so much generation required for this. diff --git a/test/formatters/test_template_formatter.py b/test/formatters/test_template_formatter.py index c29fcc4a..5d851053 100644 --- a/test/formatters/test_template_formatter.py +++ b/test/formatters/test_template_formatter.py @@ -2,15 +2,13 @@ import os import sys import tempfile -from typing import List import pytest -from mellea.formatters import TemplateFormatter -from mellea.backends.model_ids import ModelIdentifier, IBM_GRANITE_3_2_8B +from mellea.backends.model_ids import IBM_GRANITE_3_2_8B, ModelIdentifier from mellea.core import CBlock, Component, ModelOutputThunk, TemplateRepresentation -from mellea.stdlib.components import Message, Instruction -from mellea.stdlib.components import MObject +from mellea.formatters import TemplateFormatter +from mellea.stdlib.components import Instruction, Message, MObject @pytest.fixture(scope="module") @@ -100,7 +98,8 @@ def test_user_path(instr: Instruction): """Ensures that paths with no templates don't prevent default template lookups. Also creates a temporary dir to use as a user-specified dir and ensures template lookup - logic is correct.""" + logic is correct. + """ tf = TemplateFormatter( "granite3.3", template_path="/fake/path", use_template_cache=False ) @@ -160,7 +159,7 @@ def test_no_module(tf: TemplateFormatter): def test_no_template(tf: TemplateFormatter): class _NoTemplate(Component[str]): - def parts(self) -> List[Component | CBlock]: + def parts(self) -> list[Component | CBlock]: return [] def format_for_llm(self) -> TemplateRepresentation: @@ -214,7 +213,8 @@ def test_empty_model_id(instr: Instruction): def test_template_caching(instr: Instruction): """Caching shouldn't be interacted with this way by users. - Only toggling these internal variables to test code paths.""" + Only toggling these internal variables to test code paths. + """ tf = TemplateFormatter("default", use_template_cache=True) assert tf._template_cache is not None @@ -236,8 +236,8 @@ def test_template_caching(instr: Instruction): def test_custom_component_external_package(tf: TemplateFormatter): """Creates a fake package with a custom component and loads the package. - Ensures template loading works for custom components defined in other packages.""" - + Ensures template loading works for custom components defined in other packages. + """ new_component_content = """ from mellea.core import Component, TemplateRepresentation, ModelOutputThunk class NewComponent(Component[str]): diff --git a/test/stdlib/components/docs/test_richdocument.py b/test/stdlib/components/docs/test_richdocument.py index f96038c3..3538beac 100644 --- a/test/stdlib/components/docs/test_richdocument.py +++ b/test/stdlib/components/docs/test_richdocument.py @@ -1,10 +1,12 @@ import os -from mellea.core import TemplateRepresentation -from mellea.stdlib.components.docs.richdocument import RichDocument, Table -import mellea -from docling_core.types.doc.document import DoclingDocument import tempfile + import pytest +from docling_core.types.doc.document import DoclingDocument + +import mellea +from mellea.core import TemplateRepresentation +from mellea.stdlib.components.docs.richdocument import RichDocument, Table @pytest.fixture(scope="module") @@ -22,12 +24,12 @@ def test_richdocument_basics(rd: RichDocument): ) repr = rd.format_for_llm() - assert type(repr) == str, "rich document template args should be a dict" + assert isinstance(repr, str), "rich document template args should be a dict" def test_richdocument_markdown(rd: RichDocument): mkd = rd.to_markdown() - assert type(mkd) == str, "rich document `to_markdown` should be a string" + assert isinstance(mkd, str), "rich document `to_markdown` should be a string" assert "Bag of Words" in mkd, "expected string not in rd `to_markdown` output" @@ -48,7 +50,7 @@ def test_table(rd: RichDocument): # Getting the tables technically tests the functionality of richdocument, # but we do it here to make it easier. The provided document has one table. tables = rd.get_tables() - assert all(type(t) == Table for t in tables), ( + assert all(isinstance(t, Table) for t in tables), ( f"rich document `get_tables` returned a non-table value: {tables}" ) assert len(tables) > 0, ( @@ -57,13 +59,15 @@ def test_table(rd: RichDocument): table = tables[0] repr = table.format_for_llm() - assert type(repr) == TemplateRepresentation, "table template args should be a dict" + assert isinstance(repr, TemplateRepresentation), ( + "table template args should be a dict" + ) assert "table" in repr.args.keys() and len(repr.args.keys()) == 1, ( "table's should have a single `as_markdown` key" ) mkd_table = table.to_markdown() - assert type(mkd_table) == str, "table `to_markdown` should return a string" + assert isinstance(mkd_table, str), "table `to_markdown` should return a string" loaded_table = Table.from_markdown(mkd_table) assert loaded_table is not None, ( @@ -97,7 +101,7 @@ def test_empty_table(): def test_richdocument_generation(rd: RichDocument): m = mellea.start_session(backend_name="hf") response = m.chat(rd.to_markdown()[:500] + "\nSummarize the provided document.") - assert response.content is not "", ( + assert response.content != "", ( "response content should not be empty when summarizing a rich document" ) assert "paper" in response.content.lower() or "gltr" in response.content.lower(), ( diff --git a/test/stdlib/components/test_chat.py b/test/stdlib/components/test_chat.py index 1733e8e6..66ebb9fc 100644 --- a/test/stdlib/components/test_chat.py +++ b/test/stdlib/components/test_chat.py @@ -1,7 +1,7 @@ import pytest -from mellea.stdlib.components import Document -from mellea.stdlib.components import Message + from mellea.helpers import messages_to_docs +from mellea.stdlib.components import Document, Message def test_message_with_docs(): diff --git a/test/stdlib/components/test_genslot.py b/test/stdlib/components/test_genslot.py index ba956507..d3814ae4 100644 --- a/test/stdlib/components/test_genslot.py +++ b/test/stdlib/components/test_genslot.py @@ -78,7 +78,6 @@ def test_sentiment_output(classify_sentiment_output): def test_gen_slot_logs(classify_sentiment_output, session): - sent = classify_sentiment_output last_prompt = session.last_prompt()[-1] assert isinstance(last_prompt, dict) assert set(last_prompt.keys()) == {"role", "content", "images"} diff --git a/test/stdlib/components/test_mify.py b/test/stdlib/components/test_mify.py index 0587811a..d03ce9fd 100644 --- a/test/stdlib/components/test_mify.py +++ b/test/stdlib/components/test_mify.py @@ -1,9 +1,9 @@ import pytest -from mellea.formatters import TemplateFormatter from mellea.core import Component, TemplateRepresentation -from mellea.stdlib.components.mobject import Query, MObjectProtocol, MObject -from mellea.stdlib.components.mify import mify, MifiedProtocol +from mellea.formatters import TemplateFormatter +from mellea.stdlib.components.mify import MifiedProtocol, mify +from mellea.stdlib.components.mobject import MObject, MObjectProtocol, Query def test_protocol_adherence(): @@ -136,7 +136,7 @@ def get_details(self) -> str: return f"{self.name}, {self.age}" def extraneous_func(self): - """this function does nothing.""" + """This function does nothing.""" return diff --git a/test/stdlib/components/test_transform.py b/test/stdlib/components/test_transform.py index c99ac884..ecbd95d9 100644 --- a/test/stdlib/components/test_transform.py +++ b/test/stdlib/components/test_transform.py @@ -1,8 +1,8 @@ import pytest from mellea.core import TemplateRepresentation -from mellea.stdlib.components.docs.richdocument import TableTransform from mellea.stdlib.components import MObject, Query, Transform +from mellea.stdlib.components.docs.richdocument import TableTransform custom_mobject_description = "custom mobject description" @@ -53,7 +53,7 @@ def test_get_transform_object_custom(): assert isinstance(transform, TableTransform) with pytest.raises(AssertionError): - tr = transform.format_for_llm() + transform.format_for_llm() if __name__ == "__main__": diff --git a/test/stdlib/requirements/test_reqlib_markdown.py b/test/stdlib/requirements/test_reqlib_markdown.py index 5b901ff1..b663999f 100644 --- a/test/stdlib/requirements/test_reqlib_markdown.py +++ b/test/stdlib/requirements/test_reqlib_markdown.py @@ -60,26 +60,22 @@ async def test_markdown_table(): def test_default_output_to_bool_yes(): - assert default_output_to_bool("yeS") == True + assert default_output_to_bool("yeS") def test_default_output_to_bool_no(): - assert default_output_to_bool("nO") == False + assert not default_output_to_bool("nO") def test_default_output_to_bool_complicated_yes(): - assert ( - default_output_to_bool( - CBlock("The requirement is met by the output. Therefore, my answer is yes.") - ) - == True + assert default_output_to_bool( + CBlock("The requirement is met by the output. Therefore, my answer is yes.") ) def test_default_output_to_bool_word_with_yes_in_it(): - assert ( - default_output_to_bool("Here's a word that meets those requirements: ayes.") - == False + assert not default_output_to_bool( + "Here's a word that meets those requirements: ayes." ) diff --git a/test/stdlib/requirements/test_reqlib_python.py b/test/stdlib/requirements/test_reqlib_python.py index 403fa79e..85942716 100644 --- a/test/stdlib/requirements/test_reqlib_python.py +++ b/test/stdlib/requirements/test_reqlib_python.py @@ -17,17 +17,16 @@ _llm_sandbox_available = False from mellea.core import Context, ModelOutputThunk +from mellea.stdlib.context import ChatContext from mellea.stdlib.requirements.python_reqs import ( PythonExecutionReq, _has_python_code_listing, _python_executes_without_error, ) -from mellea.stdlib.context import ChatContext def from_model(content: str) -> Context: """Helper to create context from model output.""" - ctx = ChatContext() ctx = ctx.add(ModelOutputThunk(value=content)) return ctx diff --git a/test/stdlib/requirements/test_reqlib_tools.py b/test/stdlib/requirements/test_reqlib_tools.py index 553cc0c9..c20d4a78 100644 --- a/test/stdlib/requirements/test_reqlib_tools.py +++ b/test/stdlib/requirements/test_reqlib_tools.py @@ -1,4 +1,5 @@ import pytest + from mellea.stdlib.requirements.tool_reqs import _name2str diff --git a/test/stdlib/requirements/test_requirement.py b/test/stdlib/requirements/test_requirement.py index 5db655fe..f6998265 100644 --- a/test/stdlib/requirements/test_requirement.py +++ b/test/stdlib/requirements/test_requirement.py @@ -1,6 +1,7 @@ import pytest -from mellea.stdlib.context import ChatContext + from mellea.core import ModelOutputThunk, Requirement +from mellea.stdlib.context import ChatContext from mellea.stdlib.requirements import LLMaJRequirement, simple_validate from mellea.stdlib.session import start_session @@ -54,7 +55,7 @@ def test_simple_validate_invalid(): validation_func = simple_validate(lambda x: None) # type: ignore with pytest.raises(ValueError): - val_result = validation_func(ctx) + validation_func(ctx) if __name__ == "__main__": diff --git a/test/stdlib/sampling/test_sofai_graph_coloring.py b/test/stdlib/sampling/test_sofai_graph_coloring.py index 37b06ffc..c1b85a90 100644 --- a/test/stdlib/sampling/test_sofai_graph_coloring.py +++ b/test/stdlib/sampling/test_sofai_graph_coloring.py @@ -409,7 +409,7 @@ def test_best_attempt_includes_coloring_feedback( ] ] - s2_action, s2_context = strategy._prepare_s2_context( + s2_action, _s2_context = strategy._prepare_s2_context( s2_mode="best_attempt", original_action=original_action, original_context=original_context, diff --git a/test/stdlib/test_base_context.py b/test/stdlib/test_base_context.py index 698a9240..2fccb11f 100644 --- a/test/stdlib/test_base_context.py +++ b/test/stdlib/test_base_context.py @@ -1,7 +1,7 @@ import pytest -from mellea.core import Context, CBlock -from mellea.stdlib.context import SimpleContext, ChatContext +from mellea.core import CBlock, Context +from mellea.stdlib.context import ChatContext, SimpleContext def context_construction(cls: type[Context]): diff --git a/test/stdlib/test_chat_view.py b/test/stdlib/test_chat_view.py index 9b0ff93d..327258e8 100644 --- a/test/stdlib/test_chat_view.py +++ b/test/stdlib/test_chat_view.py @@ -1,7 +1,7 @@ import pytest -from mellea.stdlib.context import ChatContext from mellea.stdlib.components import Message, as_chat_history +from mellea.stdlib.context import ChatContext from mellea.stdlib.session import start_session @@ -25,7 +25,7 @@ def test_chat_view_linear_ctx(linear_session): linear_session.chat("What is 1+1?") linear_session.chat("What is 2+2?") assert len(as_chat_history(linear_session.ctx)) == 4 - assert all([type(x) == Message for x in as_chat_history(linear_session.ctx)]) + assert all(isinstance(x, Message) for x in as_chat_history(linear_session.ctx)) assert len(linear_session.ctx.view_for_generation()) == 4 @@ -34,7 +34,7 @@ def test_chat_view_simple_ctx(simple_session): simple_session.chat("What is 1+1?") simple_session.chat("What is 2+2?") assert len(as_chat_history(simple_session.ctx)) == 4 - assert all([type(x) == Message for x in as_chat_history(simple_session.ctx)]) + assert all(isinstance(x, Message) for x in as_chat_history(simple_session.ctx)) assert len(simple_session.ctx.view_for_generation()) == 0 diff --git a/test/stdlib/test_session.py b/test/stdlib/test_session.py index 6fc07d1e..b25b3840 100644 --- a/test/stdlib/test_session.py +++ b/test/stdlib/test_session.py @@ -69,7 +69,7 @@ async def test_async_await_with_chat_context(m_session): ctx = ctx.previous_node # type: ignore # Ensure we made it back to the root. - assert ctx.is_root_node == True # type: ignore + assert ctx.is_root_node # type: ignore async def test_async_without_waiting_with_chat_context(m_session): diff --git a/test/stdlib/test_spans.py b/test/stdlib/test_spans.py index 71b03ed0..4440a1c0 100644 --- a/test/stdlib/test_spans.py +++ b/test/stdlib/test_spans.py @@ -29,7 +29,7 @@ def m_session(gh_run): @pytest.mark.qualitative -async def test_lazy_spans(m_session): +async def test_lazy_spans(m_session) -> None: m: MelleaSession = m_session backend, ctx = m.backend, m.ctx @@ -45,7 +45,7 @@ async def test_lazy_spans(m_session): @pytest.mark.qualitative -async def test_kv(m_session): +async def test_kv(m_session) -> None: m: MelleaSession = m_session backend, ctx = m.backend, m.ctx # type: ignore @@ -56,7 +56,7 @@ async def test_kv(m_session): ) ) - backend: LocalHFBackend = backend + assert isinstance(backend, LocalHFBackend) response = await backend._generate_from_context_with_kv_cache( action=CBlock("What is Nathan's work address?"), ctx=ctx, model_options=dict() )