From 17c312e8e0d2664727e045561a2909ec3fc1acf6 Mon Sep 17 00:00:00 2001 From: Akihiko Kuroda Date: Mon, 26 Jan 2026 09:55:09 -0500 Subject: [PATCH 1/6] fix astream output Signed-off-by: Akihiko Kuroda --- mellea/core/base.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/mellea/core/base.py b/mellea/core/base.py index 94894713..62e32243 100644 --- a/mellea/core/base.py +++ b/mellea/core/base.py @@ -275,6 +275,8 @@ async def astream(self) -> str: raise RuntimeError( f"Cannot use `ModelOutputThunk.astream()` when the generate function is using `{self._generate_type.name}`" ) + # Beginning value + beginning_value: str = self._underlying_value # type: ignore exception_to_raise = None try: @@ -350,12 +352,17 @@ async def astream(self) -> str: assert self.parsed_repr is not None, ( "enforce constraint that a computed ModelOutputThunk has a non-None parsed_repr" ) + return self._underlying_value # type: ignore # Re-raise exception after cleanup if one occurred if exception_to_raise is not None: raise exception_to_raise - return self._underlying_value # type: ignore + return ( + self._underlying_value + if beginning_value is None + else self._underlying_value[len(str(beginning_value)) :] # type: ignore + ) def __repr__(self): """Provides a python-parsable representation (usually). From 8af7c14eb98cc7048007fcde42393d2ac34bb762 Mon Sep 17 00:00:00 2001 From: Akihiko Kuroda Date: Mon, 26 Jan 2026 15:07:48 -0500 Subject: [PATCH 2/6] review comments Signed-off-by: Akihiko Kuroda --- mellea/core/base.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/mellea/core/base.py b/mellea/core/base.py index 62e32243..0aa0afc6 100644 --- a/mellea/core/base.py +++ b/mellea/core/base.py @@ -276,7 +276,9 @@ async def astream(self) -> str: f"Cannot use `ModelOutputThunk.astream()` when the generate function is using `{self._generate_type.name}`" ) # Beginning value - beginning_value: str = self._underlying_value # type: ignore + beginning_length = ( + 0 if self._underlying_value is None else len(str(self._underlying_value)) + ) # type: ignore exception_to_raise = None try: @@ -360,8 +362,8 @@ async def astream(self) -> str: return ( self._underlying_value - if beginning_value is None - else self._underlying_value[len(str(beginning_value)) :] # type: ignore + if beginning_length is None + else self._underlying_value[beginning_length:] # type: ignore ) def __repr__(self): From 7fe50e1cb3f58229557e759b05f82f99082d504d Mon Sep 17 00:00:00 2001 From: Akihiko Kuroda Date: Fri, 30 Jan 2026 09:40:30 -0500 Subject: [PATCH 3/6] review comment Signed-off-by: Akihiko Kuroda --- mellea/core/base.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mellea/core/base.py b/mellea/core/base.py index 0aa0afc6..7de10f7d 100644 --- a/mellea/core/base.py +++ b/mellea/core/base.py @@ -362,7 +362,7 @@ async def astream(self) -> str: return ( self._underlying_value - if beginning_length is None + if beginning_length == 0 else self._underlying_value[beginning_length:] # type: ignore ) From 981e7fdc34cd4ebc2ad2572fc0fbe1a0892ce4bb Mon Sep 17 00:00:00 2001 From: Akihiko Kuroda Date: Fri, 30 Jan 2026 10:28:57 -0500 Subject: [PATCH 4/6] adding tests Signed-off-by: Akihiko Kuroda --- test/core/test_astream_incremental.py | 239 ++++++++++++++++++++++++++ 1 file changed, 239 insertions(+) create mode 100644 test/core/test_astream_incremental.py diff --git a/test/core/test_astream_incremental.py b/test/core/test_astream_incremental.py new file mode 100644 index 00000000..41fed489 --- /dev/null +++ b/test/core/test_astream_incremental.py @@ -0,0 +1,239 @@ +"""Tests for ModelOutputThunk.astream() incremental return behavior. + +Tests the changes around lines 278-281 and 352-356 in mellea/core/base.py +that ensure astream() returns only new content added since the beginning of +each astream() call, not the entire accumulated value. +""" + +import pytest + +from mellea.backends import ModelOption +from mellea.core import CBlock, ModelOutputThunk +from mellea.stdlib.context import SimpleContext +from mellea.stdlib.session import start_session + + +@pytest.mark.ollama +@pytest.mark.llm +async def test_astream_returns_incremental_chunks(): + """Test that astream() returns only new content, not accumulated content. + + This tests the fix where beginning_length is captured at the start of + astream() and the return value is sliced to only include new content. + """ + session = start_session() + model_opts = {ModelOption.STREAM: True} + + mot, _ = await session.backend.generate_from_context( + CBlock("Count from 1 to 5 slowly."), SimpleContext(), model_options=model_opts + ) + + # First astream call - should return content from beginning + chunk1 = await mot.astream() + assert chunk1 is not None, "First chunk should not be None" + assert len(chunk1) > 0, "First chunk should have content" + + # Second astream call - should return only NEW content since first call + chunk2 = await mot.astream() + + if not mot.is_computed(): + # If not computed, chunk2 should be new content only + assert chunk2 is not None, "Second chunk should not be None if not computed" + + # The key test: chunk2 should NOT start with chunk1 + # (it should be incremental, not accumulated) + if len(chunk2) > 0: + # chunk2 should be different from chunk1 (new content) + assert chunk2 != chunk1, ( + "Second chunk should be different from first (incremental)" + ) + + # Get final value + final_val = await mot.avalue() + + # Final value should contain both chunks in order + assert final_val.startswith(chunk1), ( + "Final value should start with first chunk" + ) + # The concatenation of chunks should be a prefix of or equal to final value + accumulated = chunk1 + chunk2 + assert final_val.startswith(accumulated) or accumulated.startswith( + final_val + ), "Accumulated chunks should match final value progression" + else: + # If computed after first astream, chunk2 should be empty or the remainder + final_val = await mot.avalue() + # chunk1 should be a prefix of final value + assert final_val.startswith(chunk1), "Final value should start with first chunk" + + +@pytest.mark.ollama +@pytest.mark.llm +async def test_astream_multiple_calls_accumulate_correctly(): + """Test that multiple astream() calls accumulate to the final value. + + Note: The final astream() call that marks the thunk as computed returns + the FULL value (line 350 in base.py), not just the incremental part. + """ + session = start_session() + model_opts = {ModelOption.STREAM: True} + + mot, _ = await session.backend.generate_from_context( + CBlock("Write a short sentence."), SimpleContext(), model_options=model_opts + ) + + accumulated = "" + chunks = [] + + # Stream until computed + while not mot.is_computed(): + chunk = await mot.astream() + if chunk: + chunks.append(chunk) + # Only accumulate if this wasn't the final (completing) chunk + if not mot.is_computed(): + accumulated += chunk + + # Safety: don't loop forever + if len(chunks) > 100: + break + + # Get final value + final_val = await mot.avalue() + + # The last chunk should be the full value when computed + if len(chunks) > 0: + assert chunks[-1] == final_val, ( + f"Last chunk (when computed) should be full value.\n" + f"Last chunk: {chunks[-1]!r}\n" + f"Final: {final_val!r}" + ) + + # All chunks except the last should be incremental + if len(chunks) > 1: + incremental_accumulated = "".join(chunks[:-1]) + assert final_val.startswith(incremental_accumulated), ( + f"Incremental chunks should be prefix of final value.\n" + f"Accumulated: {incremental_accumulated!r}\n" + f"Final: {final_val!r}" + ) + + +@pytest.mark.ollama +@pytest.mark.llm +async def test_astream_beginning_length_tracking(): + """Test that beginning_length is correctly tracked across astream calls. + + This specifically tests the logic at lines 278-281 where beginning_length + is captured at the start of each astream() call. + """ + session = start_session() + model_opts = {ModelOption.STREAM: True} + + mot, _ = await session.backend.generate_from_context( + CBlock("Say hello."), SimpleContext(), model_options=model_opts + ) + + # First call: beginning_length should be 0 (or length of any pre-existing value) + chunk1 = await mot.astream() + + # Second call: beginning_length should be captured at start of this call + chunk2 = await mot.astream() + + if chunk2 and len(chunk2) > 0: + # chunk2 should not include chunk1's content + # This verifies the slicing logic at lines 352-356 + if chunk1: + assert not chunk2.startswith(chunk1), ( + "Second chunk should not start with first chunk (should be incremental)" + ) + + +@pytest.mark.ollama +@pytest.mark.llm +async def test_astream_empty_beginning(): + """Test astream when _underlying_value starts as None.""" + session = start_session() + model_opts = {ModelOption.STREAM: True} + + mot, _ = await session.backend.generate_from_context( + CBlock("Hi"), SimpleContext(), model_options=model_opts + ) + + # At the start, _underlying_value might be None + # beginning_length should be 0 in this case (line 280) + chunk = await mot.astream() + + assert chunk is not None, "Should get a chunk even when starting from None" + + # When beginning_length is 0, should return full _underlying_value (line 354) + if mot._underlying_value: + assert chunk == mot._underlying_value or mot._underlying_value.startswith( + chunk + ), "When beginning_length is 0, should return the full underlying value" + + +@pytest.mark.ollama +@pytest.mark.llm +async def test_astream_computed_returns_full_value(): + """Test that astream returns full value when already computed.""" + # Create a pre-computed thunk + mot = ModelOutputThunk(value="Hello, world!") + mot._computed = True + + # astream should return the full value immediately (line 272) + result = await mot.astream() + + assert result == "Hello, world!", "Computed thunk should return full value" + + +@pytest.mark.ollama +@pytest.mark.llm +async def test_astream_final_call_returns_full_value(): + """Test that the final astream call returns the full value when computed. + + This tests the behavior at line 350 in base.py where the final call + (when _computed becomes True) returns the full _underlying_value. + """ + session = start_session() + model_opts = {ModelOption.STREAM: True} + + mot, _ = await session.backend.generate_from_context( + CBlock("Count: 1, 2, 3"), SimpleContext(), model_options=model_opts + ) + + chunks = [] + + # Collect all chunks + while not mot.is_computed(): + chunk = await mot.astream() + if chunk: + chunks.append(chunk) + + if len(chunks) > 100: # Safety + break + + # Get final value + final_val = await mot.avalue() + + # The last chunk should be the full value (not incremental) + if len(chunks) > 0: + assert chunks[-1] == final_val, ( + f"Final chunk should be the complete value.\n" + f"Last chunk: {chunks[-1]!r}\n" + f"Final value: {final_val!r}" + ) + + # All chunks before the last should be incremental (non-overlapping) + for i in range(len(chunks) - 2): # Exclude the last chunk + for j in range(i + 1, len(chunks) - 1): # Exclude the last chunk + # Earlier incremental chunks shouldn't be prefixes of later ones + if chunks[j] and chunks[i]: + assert not chunks[j].startswith(chunks[i]), ( + f"Incremental chunk {j} should not start with chunk {i}" + ) + + +if __name__ == "__main__": + pytest.main([__file__, "-v"]) + From 141c5010265e28060c77642b735dd5e8566b67ff Mon Sep 17 00:00:00 2001 From: Akihiko Kuroda Date: Fri, 30 Jan 2026 10:33:38 -0500 Subject: [PATCH 5/6] adding tests Signed-off-by: Akihiko Kuroda --- test/core/test_astream_incremental.py | 1 - 1 file changed, 1 deletion(-) diff --git a/test/core/test_astream_incremental.py b/test/core/test_astream_incremental.py index 41fed489..326ae627 100644 --- a/test/core/test_astream_incremental.py +++ b/test/core/test_astream_incremental.py @@ -236,4 +236,3 @@ async def test_astream_final_call_returns_full_value(): if __name__ == "__main__": pytest.main([__file__, "-v"]) - From 5f9d43632c17d55ae0add4c6dd71051a916ed609 Mon Sep 17 00:00:00 2001 From: Akihiko Kuroda Date: Fri, 30 Jan 2026 14:12:22 -0500 Subject: [PATCH 6/6] review comment Signed-off-by: Akihiko Kuroda --- test/core/test_astream_incremental.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/test/core/test_astream_incremental.py b/test/core/test_astream_incremental.py index 326ae627..78b0c2a8 100644 --- a/test/core/test_astream_incremental.py +++ b/test/core/test_astream_incremental.py @@ -1,7 +1,6 @@ """Tests for ModelOutputThunk.astream() incremental return behavior. -Tests the changes around lines 278-281 and 352-356 in mellea/core/base.py -that ensure astream() returns only new content added since the beginning of +Tests that astream() returns only new content added since the beginning of each astream() call, not the entire accumulated value. """