3 changes: 1 addition & 2 deletions examples/train_scripts/run_retrieval_agent.sh
@@ -46,7 +46,7 @@ entropy_coeff=0.001
 kl_loss_type=mse
 agent_type=hf
 max_turns=4
-template="qwen2.5"
+# template="qwen2.5"
 tool_parser_name="hermes"
 total_training_steps=200
 project_name="Open"
@@ -63,7 +63,6 @@ python3 -m agentfly.cli train \
     agent.init_config.agent_type=$agent_type \
     agent.init_config.model_name_or_path=$model \
     agent.init_config.max_model_len=$max_model_len \
-    agent.init_config.template=$template \
     agent.init_config.tool_parser_name=$tool_parser_name \
     agent.init_config.tools=${tools} \
    agent.init_config.reward_name=${reward_name} \
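With template commented out, this script exercises the new fallback in src/agentfly/agents/agent_base.py below: when no template is configured, the agent passes the model name to chat-bricks, which defaults to that model's Hugging Face chat template.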
5 changes: 1 addition & 4 deletions examples/train_scripts/train_example.sh
@@ -27,11 +27,9 @@ max_model_len=8192
 mini_batch_size=64
 max_new_tokens_per_turn=512
 num_chains=8
+num_gpus=1

 # Fully on-policy training
-num_gpus=1
 ppo_mini_batch_size=${mini_batch_size}*${num_chains}
-
 ppo_micro_batch_size_per_gpu=8
-
 kl_coef=0.001
@@ -63,7 +61,6 @@ python3 -m agentfly.cli train \
     data.train_batch_size=${mini_batch_size} \
     agent.init_config.agent_type=$agent_type \
     agent.init_config.tools=$tools \
-    agent.init_config.template=$template \
     agent.init_config.model_name_or_path=$model \
     agent.init_config.backend=${agent_backend} \
     agent.init_config.reward_name=$reward_name \
2 changes: 1 addition & 1 deletion install.sh
@@ -339,7 +339,7 @@ main() {
         INSTALLATION_STATUS+=("Python 3.12.x verification: FAILED")
     fi

-    if [ -d "AgentFly.egg-info" ] || [ -d "agents" ]; then
+    if [ -d "src/AgentFly.egg-info" ]; then
         print_success "✓ AgentFly package"
         INSTALLATION_STATUS+=("AgentFly package verification: SUCCESS")
     else
3 changes: 2 additions & 1 deletion pyproject.toml
@@ -47,6 +47,7 @@ dependencies = [
     "onnxruntime",
     "mpmath",
     "wandb",
+    "chat-bricks",
     "diffusers",
     "google-genai",
     "chess",
@@ -71,7 +72,7 @@ verl = [
     "ray[default]",
     "tensordict",
     "torchdata",
-    "transformers",
+    "transformers<5.0.0",
     "packaging>=20.0",
     "uvicorn",
     "fastapi"
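These two dependency edits track the rest of the PR: chat-bricks replaces the vendored src/agentfly/templates package (deleted below) as the source of tokenize_conversations, get_template, and Chat, while the verl extra pins transformers below 5.0.0.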
28 changes: 13 additions & 15 deletions src/agentfly/agents/agent_base.py
@@ -1,3 +1,4 @@
+import copy
 import inspect
 import json
 import logging
@@ -11,8 +12,7 @@
 import torch
 from termcolor import colored

-from ..templates import tokenize_conversations
-from ..templates.templates import get_template
+from chat_bricks import tokenize_conversations, get_template
 from ..tools.tool_base import BaseTool
 from ..utils.monitor import JsonlSink, Monitor, WandbSink
 from .chain.chain_base import ChainRollout
@@ -120,7 +120,6 @@ def __init__(

         self.debug = debug
         self.backend = backend
-        self.template = template
         self.tools = tools
         self.max_model_len = max_model_len

@@ -141,19 +140,26 @@ def __init__(
         else:
             self.backend_config = backend_config

-        self.llm_engine = self._init_llm_engine(model_name_or_path, backend)
-
         # Create appropriate tokenizer for trajectory processing
         self.tokenizer = create_tokenizer(model_name_or_path)
         self.processor = create_processor(model_name_or_path)

         self._reward_fn = reward_fn

+        # We use model name as template if no template is provided
+        # For a model name, chat-bricks will use HF's template by default
+        if template:
+            self.template = template
+        else:
+            self.template = self.model_name_or_path
+
         if self.template is None:
             self.jinja_template = None
         else:
             self.jinja_template = get_template(self.template).jinja_template()

+        self.llm_engine = self._init_llm_engine(model_name_or_path, backend)
+
         self.wandb_project_name = wandb_project_name
         self.wandb_run_name = wandb_run_name
         self.local_cache_dir = local_cache_dir
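The fallback reads cleanly in isolation. A minimal sketch of the resolution logic, assuming (per the comment in the hunk) that chat_bricks.get_template accepts either a registered template name or a plain HF model name; the helper name and model path below are illustrative:

from chat_bricks import get_template

def resolve_jinja_template(template, model_name_or_path):
    # Prefer an explicit template; otherwise pass the model name through so
    # chat-bricks can fall back to that model's own HF chat template.
    name = template if template else model_name_or_path
    if name is None:
        return None
    return get_template(name).jinja_template()

# Hypothetical calls: an explicit template name, then the model-name fallback.
# resolve_jinja_template("qwen2.5", "Qwen/Qwen2.5-7B-Instruct")
# resolve_jinja_template(None, "Qwen/Qwen2.5-7B-Instruct")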
@@ -222,9 +228,6 @@ def _bind_method_tools(self):
                 tool_method.instance = self

     def _init_llm_engine(self, model_name_or_path: str, backend: str):
-        assert not (self.template and backend == "client"), (
-            "For client backend, we do not support template. Set the template when deploying the model."
-        )
         if isinstance(model_name_or_path, str):
             # Extract backend-specific configuration
             config_kwargs = {}
@@ -381,8 +384,8 @@ def timing_data(self):

     @property
     def trajectories(self):
-        """Get the trajectories of the agent."""
         trajectories = self.get_messages()
+
         return trajectories

     def tokenize_trajectories(
@@ -401,16 +404,10 @@ def tokenize_trajectories(
         for trajectory in trajectories:
             messages = trajectory["messages"]
             messages_list.append(messages)
-            have_called_tool = False
-            for message in messages:
-                if message["role"] == "tool":
-                    have_called_tool = True
-                    break
             info = {}
             for key, value in trajectory.items():
                 if key != "messages":
                     info[key] = value
-            info["have_called_tool"] = have_called_tool

             last_response = None

@@ -433,6 +430,7 @@ def tokenize_trajectories(
                 return_reward_mask=return_reward_mask,
                 add_generation_prompt=True,
                 concatenate_mm_inputs=concatenate_mm_inputs,
+                ignore_tool_calls=True,
             )
             position_ids = torch.clip(
                 torch.cumsum(inputs["attention_mask"], dim=-1) - 1, min=0, max=None
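The position-id computation in the surrounding context is untouched by this hunk; for readers unfamiliar with the idiom, here is a self-contained illustration (tensor values are made up):

import torch

# 0 = padding, 1 = real token (illustrative batch of two sequences).
attention_mask = torch.tensor([[0, 0, 1, 1, 1],
                               [1, 1, 1, 1, 1]])

# Cumulative sum minus one yields 0-based positions for real tokens;
# clipping at min=0 keeps the padding slots from going negative.
position_ids = torch.clip(torch.cumsum(attention_mask, dim=-1) - 1, min=0, max=None)
# -> [[0, 0, 0, 1, 2],
#     [0, 1, 2, 3, 4]]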
43 changes: 26 additions & 17 deletions src/agentfly/agents/llm_backends/llm_backends.py
@@ -19,7 +19,7 @@
 from transformers import AutoModelForCausalLM, AutoTokenizer
 from vllm import AsyncEngineArgs, AsyncLLMEngine, SamplingParams

-from ...templates import Chat
+from chat_bricks import Chat
 from ...utils.vision import image_to_data_uri

 logger = logging.getLogger(__name__)
@@ -272,14 +272,15 @@ async def _generate_single(

     async def generate_async(self, messages_list: str, **kwargs) -> str:
         """Generate text from prompt using vLLM"""
-        max_tokens = kwargs.get("max_tokens", self.max_tokens)
-        temperature = kwargs.get("temperature", self.temperature)
+        sampling_params = {}
+        if "temperature" in kwargs:
+            sampling_params["temperature"] = kwargs["temperature"]
+        if "n" in kwargs:
+            sampling_params["n"] = kwargs["n"]
+        if "max_tokens" in kwargs:
+            sampling_params["max_tokens"] = kwargs.get("max_tokens")
+        sampling_params = SamplingParams(**sampling_params)
         n = kwargs.get("n", 1)
-        sampling_params = SamplingParams(
-            n=1,
-            max_tokens=max_tokens,
-            temperature=temperature,
-        )

         tools = kwargs.get("tools", None)
         prompts, vision_inputs = self.apply_chat_template(
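The rewritten block only forwards options the caller actually passed, so everything omitted picks up vLLM's own defaults instead of the backend's hard-coded ones. A standalone sketch of the same pattern (SamplingParams and its temperature/n/max_tokens fields are real vLLM API; the helper name is ours):

from vllm import SamplingParams

def build_sampling_params(**kwargs):
    # Forward only explicitly-passed options; omitted ones keep vLLM defaults.
    params = {k: kwargs[k] for k in ("temperature", "n", "max_tokens") if k in kwargs}
    return SamplingParams(**params)

# build_sampling_params(temperature=0.7)       # n, max_tokens: vLLM defaults
# build_sampling_params(n=4, max_tokens=256)   # temperature: vLLM default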
@@ -392,15 +393,24 @@ def _convert_to_openai_chat_without_tool_call_processing(
         We use the pure generated content as the history. So we don't want any tool call to be part of the history.
         This is used when models are not openai's official models like GPT-4o.
         """
-        messages = copy.deepcopy(messages)
+        # messages = copy.deepcopy(messages)
+        # for message in messages:
+        #     if "tool_calls" in message:
+        #         del message["tool_calls"]
+        #     if "tool_call_id" in message:
+        #         del message["tool_call_id"]
+        #     if "tool_choice" in message:
+        #         del message["tool_choice"]
+        # return messages
+
+        processed_messages = []
         for message in messages:
-            if "tool_calls" in message:
-                del message["tool_calls"]
-            if "tool_call_id" in message:
-                del message["tool_call_id"]
-            if "tool_choice" in message:
-                del message["tool_choice"]
-        return messages
+            processed_message = {}
+            for k, v in message.items():
+                if k not in ["tool_calls", "tool_call_id", "tool_choice"]:
+                    processed_message[k] = v
+            processed_messages.append(processed_message)
+        return processed_messages

     def _process_messages(self, messages: List[Dict]):
         new_messages = []
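The new version builds fresh dicts instead of deleting keys from the caller's message objects, which is why the deepcopy (retained only as the commented-out block) is no longer needed. A toy check of that property (function name is ours):

def strip_tool_fields(messages):
    # Copy each message minus tool bookkeeping; never mutate the input.
    return [{k: v for k, v in m.items()
             if k not in ("tool_calls", "tool_call_id", "tool_choice")}
            for m in messages]

original = [{"role": "assistant", "content": "done", "tool_calls": [{"id": "1"}]}]
cleaned = strip_tool_fields(original)
assert "tool_calls" not in cleaned[0]
assert "tool_calls" in original[0]  # caller's history is left intact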
@@ -423,7 +433,6 @@ async def generate_async(self, messages_list: str, **kwargs) -> str:

         generation_config = {}
         tensors = torch.ones(len(messages_list), dtype=torch.int64)
-        # messages_list = [self._convert_to_openai_chat_without_tool_call_processing(messages) for messages in messages_list]
         messages_list = [self._process_messages(messages) for messages in messages_list]
         messages_list = [
             self._convert_to_openai_chat_without_tool_call_processing(messages)
25 changes: 0 additions & 25 deletions src/agentfly/templates/__init__.py

This file was deleted.

29 changes: 0 additions & 29 deletions src/agentfly/templates/assistant_policy.py

This file was deleted.

20 changes: 0 additions & 20 deletions src/agentfly/templates/constants.py

This file was deleted.

6 changes: 0 additions & 6 deletions src/agentfly/templates/global_policy.py

This file was deleted.
