Conversation
|
The PR description has been updated. Please fill out the template for your PR to be reviewed. |
Merge Protections: Your pull request matches the following merge protections and will not be merged until they are valid. 🟢 Enforce conventional commit: Wonderful, this rule succeeded. Make sure that we follow https://www.conventionalcommits.org/en/v1.0.0/
|
psschwei
left a comment
There was a problem hiding this comment.
small nit, but otherwise lgtm
mellea/core/base.py
Outdated
| f"Cannot use `ModelOutputThunk.astream()` when the generate function is using `{self._generate_type.name}`" | ||
| ) | ||
| # Beginning value | ||
| beginning_value: str = self._underlying_value # type: ignore |
There was a problem hiding this comment.
nit: should we just store the length of the underlying value? would give us a slightly smaller memory footprint.
6b190a0 to
0a4f108
Compare
0a4f108 to
c873cc4
Compare
|
I generated a couple of tests to verify the behavior here:
import asyncio
import pytest
from mellea.backends import ModelOption
from mellea.core import ModelOutputThunk
from mellea.stdlib.session import MelleaSession, start_session
@pytest.fixture(scope="module")
def m_session(gh_run):
    """Provide one module-scoped session shared by the streaming tests.

    New-token generation is capped at 50 through the model options.
    ``gh_run`` is requested as a fixture dependency only; its value is not
    used in the body.
    """
    options = {ModelOption.MAX_NEW_TOKENS: 50}
    session = start_session(model_options=options)
    yield session
    del session
@pytest.mark.asyncio
async def test_astream_returns_only_new_chunks(m_session: MelleaSession):
    """Check that a second astream() call yields a delta, not the whole output.

    Covers the fix from PR #358, where beginning_length is tracked so that a
    call returns only the content added since the previous one.
    """
    out = m_session.instruct("Count from 1 to 10")

    # The first read returns whatever content is available initially.
    first_chunk = await out.astream()
    assert isinstance(first_chunk, str)

    # Once generation has finished there is no second delta to inspect.
    if out.is_computed():
        return

    second_chunk = await out.astream()
    assert isinstance(second_chunk, str)

    # Everything produced so far; the second chunk is compared against it.
    accumulated_value = out.value
    assert accumulated_value is not None

    if second_chunk:
        # New content must appear somewhere inside the accumulated value.
        assert second_chunk in accumulated_value

        # A delta can never be longer than the accumulated value.
        assert len(second_chunk) <= len(accumulated_value)

        # With a non-empty first chunk, the second one must not be the
        # entire accumulated value.
        if first_chunk:
            assert second_chunk != accumulated_value
@pytest.mark.asyncio
async def test_astream_full_completion(m_session: MelleaSession):
    """Test that the chunks returned by astream() add up to the final output.

    astream() is called until the thunk reports completion and every returned
    chunk is collected. The concatenation of those chunks must equal
    ``out.value``.
    """
    out = m_session.instruct("Say hello")
    accumulated_chunks = []

    # Each call made while generation is still in progress contributes a delta.
    while not out.is_computed():
        accumulated_chunks.append(await out.astream())

    # astream() on an already-computed thunk returns the full value (see
    # test_astream_on_computed_thunk). Appending it after streamed deltas would
    # count the output twice, so it is read only when the thunk was already
    # computed before the loop could collect anything.
    if not accumulated_chunks:
        accumulated_chunks.append(await out.astream())

    # The concatenation of all chunks should equal the final value.
    assert out.value is not None
    assert "".join(accumulated_chunks) == out.value
@pytest.mark.asyncio
async def test_astream_on_computed_thunk(m_session: MelleaSession):
    """Check that astream() on a finished thunk hands back the whole value."""
    out = m_session.instruct("Hello world")

    # Block until generation has finished.
    final_value = await out.avalue()
    assert out.is_computed()

    # With nothing left to stream, the complete value is expected.
    assert await out.astream() == final_value
@pytest.mark.asyncio
async def test_astream_empty_initial_value():
    """Check astream() on a thunk that was constructed with a None value."""
    thunk = ModelOutputThunk(None)

    # Fill in the value and flag the thunk as computed by hand to set up the
    # edge case.
    thunk._underlying_value = "test content"
    thunk._computed = True

    # A computed thunk is expected to return its full value.
    assert await thunk.astream() == "test content"
@pytest.mark.asyncio
async def test_avalue_returns_full_content(m_session: MelleaSession):
    """Check that avalue() yields the complete, non-empty output."""
    out = m_session.instruct("Count to 5")

    # avalue() is awaited for the finished content.
    full_value = await out.avalue()

    assert isinstance(full_value, str)
    assert full_value
    assert out.is_computed()
    assert out.value == full_value
# Allow running this file directly: hand it to pytest in verbose mode.
if __name__ == "__main__":
    pytest_args = [__file__, "-v"]
    pytest.main(pytest_args)
c873cc4 to
779786e
Compare
psschwei
left a comment
There was a problem hiding this comment.
LGTM
Should we add a test for this in another PR?
779786e to
74f6d13
Compare
|
@psschwei I just added tests for this change. Thanks! |
| Tests the changes around lines 278-281 and 352-356 in mellea/core/base.py | ||
| that ensure astream() returns only new content added since the beginning of |
There was a problem hiding this comment.
nit: line numbers aren't permanent so let's not include them
| Tests the changes around lines 278-281 and 352-356 in mellea/core/base.py | |
| that ensure astream() returns only new content added since the beginning of | |
| Tests that astream() returns only new content added since the beginning of |
|
LGTM |
ca17328 to
d2461cb
Compare
Signed-off-by: Akihiko Kuroda <akihikokuroda2020@gmail.com>
Signed-off-by: Akihiko Kuroda <akihikokuroda2020@gmail.com>
Signed-off-by: Akihiko Kuroda <akihikokuroda2020@gmail.com>
Signed-off-by: Akihiko Kuroda <akihikokuroda2020@gmail.com>
Signed-off-by: Akihiko Kuroda <akihikokuroda2020@gmail.com>
Signed-off-by: Akihiko Kuroda <akihikokuroda2020@gmail.com>
d2461cb to
0b717e4
Compare
Misc PR
Type of PR
Description
Testing