From eddd3d8ccd2d38b2bbdcd0ad556286f4da6f7cb6 Mon Sep 17 00:00:00 2001 From: Reason-Wang Date: Sun, 1 Feb 2026 11:49:12 +0000 Subject: [PATCH] Move to ChatBricks for templates --- examples/train_scripts/run_retrieval_agent.sh | 3 +- examples/train_scripts/train_example.sh | 5 +- install.sh | 2 +- pyproject.toml | 3 +- src/agentfly/agents/agent_base.py | 28 +- .../agents/llm_backends/llm_backends.py | 43 +- src/agentfly/templates/__init__.py | 25 - src/agentfly/templates/assistant_policy.py | 29 - src/agentfly/templates/constants.py | 20 - src/agentfly/templates/global_policy.py | 6 - src/agentfly/templates/preprocess.py | 115 - src/agentfly/templates/system_policy.py | 49 - src/agentfly/templates/templates.py | 2003 ----------------- src/agentfly/templates/tool_policy.py | 282 --- src/agentfly/templates/utils.py | 450 ---- src/agentfly/templates/vision_processor.py | 777 ------- verl | 2 +- 17 files changed, 45 insertions(+), 3797 deletions(-) delete mode 100644 src/agentfly/templates/__init__.py delete mode 100644 src/agentfly/templates/assistant_policy.py delete mode 100644 src/agentfly/templates/constants.py delete mode 100644 src/agentfly/templates/global_policy.py delete mode 100644 src/agentfly/templates/preprocess.py delete mode 100644 src/agentfly/templates/system_policy.py delete mode 100644 src/agentfly/templates/templates.py delete mode 100644 src/agentfly/templates/tool_policy.py delete mode 100644 src/agentfly/templates/utils.py delete mode 100644 src/agentfly/templates/vision_processor.py diff --git a/examples/train_scripts/run_retrieval_agent.sh b/examples/train_scripts/run_retrieval_agent.sh index edd7213..a10841a 100644 --- a/examples/train_scripts/run_retrieval_agent.sh +++ b/examples/train_scripts/run_retrieval_agent.sh @@ -46,7 +46,7 @@ entropy_coeff=0.001 kl_loss_type=mse agent_type=hf max_turns=4 -template="qwen2.5" +# template="qwen2.5" tool_parser_name="hermes" total_training_steps=200 project_name="Open" @@ -63,7 +63,6 @@ python3 -m agentfly.cli train \ agent.init_config.agent_type=$agent_type \ agent.init_config.model_name_or_path=$model \ agent.init_config.max_model_len=$max_model_len \ - agent.init_config.template=$template \ agent.init_config.tool_parser_name=$tool_parser_name \ agent.init_config.tools=${tools} \ agent.init_config.reward_name=${reward_name} \ diff --git a/examples/train_scripts/train_example.sh b/examples/train_scripts/train_example.sh index 9bdc1d2..8a2e62e 100644 --- a/examples/train_scripts/train_example.sh +++ b/examples/train_scripts/train_example.sh @@ -27,11 +27,9 @@ max_model_len=8192 mini_batch_size=64 max_new_tokens_per_turn=512 num_chains=8 -num_gpus=1 - # Fully on-policy training +num_gpus=1 ppo_mini_batch_size=${mini_batch_size}*${num_chains} - ppo_micro_batch_size_per_gpu=8 kl_coef=0.001 @@ -63,7 +61,6 @@ python3 -m agentfly.cli train \ data.train_batch_size=${mini_batch_size} \ agent.init_config.agent_type=$agent_type \ agent.init_config.tools=$tools \ - agent.init_config.template=$template \ agent.init_config.model_name_or_path=$model \ agent.init_config.backend=${agent_backend} \ agent.init_config.reward_name=$reward_name \ diff --git a/install.sh b/install.sh index 484db6c..2787365 100644 --- a/install.sh +++ b/install.sh @@ -339,7 +339,7 @@ main() { INSTALLATION_STATUS+=("Python 3.12.x verification: FAILED") fi - if [ -d "AgentFly.egg-info" ] || [ -d "agents" ]; then + if [ -d "src/AgentFly.egg-info" ]; then print_success "✓ AgentFly package" INSTALLATION_STATUS+=("AgentFly package verification: SUCCESS") else diff --git 
a/pyproject.toml b/pyproject.toml index 7c93d58..635be62 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -47,6 +47,7 @@ dependencies = [ "onnxruntime", "mpmath", "wandb", + "chat-bricks", "diffusers", "google-genai", "chess", @@ -71,7 +72,7 @@ verl = [ "ray[default]", "tensordict", "torchdata", - "transformers", + "transformers<5.0.0", "packaging>=20.0", "uvicorn", "fastapi" diff --git a/src/agentfly/agents/agent_base.py b/src/agentfly/agents/agent_base.py index 06b42d9..39f5207 100644 --- a/src/agentfly/agents/agent_base.py +++ b/src/agentfly/agents/agent_base.py @@ -1,3 +1,4 @@ +import copy import inspect import json import logging @@ -11,8 +12,7 @@ import torch from termcolor import colored -from ..templates import tokenize_conversations -from ..templates.templates import get_template +from chat_bricks import tokenize_conversations, get_template from ..tools.tool_base import BaseTool from ..utils.monitor import JsonlSink, Monitor, WandbSink from .chain.chain_base import ChainRollout @@ -120,7 +120,6 @@ def __init__( self.debug = debug self.backend = backend - self.template = template self.tools = tools self.max_model_len = max_model_len @@ -141,19 +140,26 @@ def __init__( else: self.backend_config = backend_config - self.llm_engine = self._init_llm_engine(model_name_or_path, backend) - # Create appropriate tokenizer for trajectory processing self.tokenizer = create_tokenizer(model_name_or_path) self.processor = create_processor(model_name_or_path) self._reward_fn = reward_fn + # We use model name as template if no template is provided + # For a model name, chat-bricks will use HF's template by default + if template: + self.template = template + else: + self.template = self.model_name_or_path + if self.template is None: self.jinja_template = None else: self.jinja_template = get_template(self.template).jinja_template() + self.llm_engine = self._init_llm_engine(model_name_or_path, backend) + self.wandb_project_name = wandb_project_name self.wandb_run_name = wandb_run_name self.local_cache_dir = local_cache_dir @@ -222,9 +228,6 @@ def _bind_method_tools(self): tool_method.instance = self def _init_llm_engine(self, model_name_or_path: str, backend: str): - assert not (self.template and backend == "client"), ( - "For client backend, we do not support template. Set the template when deploying the model." 
- ) if isinstance(model_name_or_path, str): # Extract backend-specific configuration config_kwargs = {} @@ -381,8 +384,8 @@ def timing_data(self): @property def trajectories(self): + """Get the trajectories of the agent.""" trajectories = self.get_messages() - return trajectories def tokenize_trajectories( @@ -401,16 +404,10 @@ def tokenize_trajectories( for trajectory in trajectories: messages = trajectory["messages"] messages_list.append(messages) - have_called_tool = False - for message in messages: - if message["role"] == "tool": - have_called_tool = True - break info = {} for key, value in trajectory.items(): if key != "messages": info[key] = value - info["have_called_tool"] = have_called_tool last_response = None @@ -433,6 +430,7 @@ def tokenize_trajectories( return_reward_mask=return_reward_mask, add_generation_prompt=True, concatenate_mm_inputs=concatenate_mm_inputs, + ignore_tool_calls=True, ) position_ids = torch.clip( torch.cumsum(inputs["attention_mask"], dim=-1) - 1, min=0, max=None diff --git a/src/agentfly/agents/llm_backends/llm_backends.py b/src/agentfly/agents/llm_backends/llm_backends.py index a507236..e6fc8b5 100644 --- a/src/agentfly/agents/llm_backends/llm_backends.py +++ b/src/agentfly/agents/llm_backends/llm_backends.py @@ -19,7 +19,7 @@ from transformers import AutoModelForCausalLM, AutoTokenizer from vllm import AsyncEngineArgs, AsyncLLMEngine, SamplingParams -from ...templates import Chat +from chat_bricks import Chat from ...utils.vision import image_to_data_uri logger = logging.getLogger(__name__) @@ -272,14 +272,15 @@ async def _generate_single( async def generate_async(self, messages_list: str, **kwargs) -> str: """Generate text from prompt using vLLM""" - max_tokens = kwargs.get("max_tokens", self.max_tokens) - temperature = kwargs.get("temperature", self.temperature) + sampling_params = {} + if "temperature" in kwargs: + sampling_params["temperature"] = kwargs["temperature"] + if "n" in kwargs: + sampling_params["n"] = kwargs["n"] + if "max_tokens" in kwargs: + sampling_params["max_tokens"] = kwargs.get("max_tokens") + sampling_params = SamplingParams(**sampling_params) n = kwargs.get("n", 1) - sampling_params = SamplingParams( - n=1, - max_tokens=max_tokens, - temperature=temperature, - ) tools = kwargs.get("tools", None) prompts, vision_inputs = self.apply_chat_template( @@ -392,15 +393,24 @@ def _convert_to_openai_chat_without_tool_call_processing( We use the pure generated content as the history. So we don't want any tool call to be part of the history. This is used when models are not openai's official models like GPT-4o. 
""" - messages = copy.deepcopy(messages) + # messages = copy.deepcopy(messages) + # for message in messages: + # if "tool_calls" in message: + # del message["tool_calls"] + # if "tool_call_id" in message: + # del message["tool_call_id"] + # if "tool_choice" in message: + # del message["tool_choice"] + # return messages + + processed_messages = [] for message in messages: - if "tool_calls" in message: - del message["tool_calls"] - if "tool_call_id" in message: - del message["tool_call_id"] - if "tool_choice" in message: - del message["tool_choice"] - return messages + processed_message = {} + for k, v in message.items(): + if k not in ["tool_calls", "tool_call_id", "tool_choice"]: + processed_message[k] = v + processed_messages.append(processed_message) + return processed_messages def _process_messages(self, messages: List[Dict]): new_messages = [] @@ -423,7 +433,6 @@ async def generate_async(self, messages_list: str, **kwargs) -> str: generation_config = {} tensors = torch.ones(len(messages_list), dtype=torch.int64) - # messages_list = [self._convert_to_openai_chat_without_tool_call_processing(messages) for messages in messages_list] messages_list = [self._process_messages(messages) for messages in messages_list] messages_list = [ self._convert_to_openai_chat_without_tool_call_processing(messages) diff --git a/src/agentfly/templates/__init__.py b/src/agentfly/templates/__init__.py deleted file mode 100644 index aa7a4a3..0000000 --- a/src/agentfly/templates/__init__.py +++ /dev/null @@ -1,25 +0,0 @@ -from .global_policy import GlobalPolicy -from .system_policy import SystemPolicy -from .templates import Chat, Template, get_template, register_template -from .tool_policy import JsonFormatter, ToolPolicy -from .utils import ( - compare_hf_template, - tokenize_conversation, - tokenize_conversations, - validate_messages_for_template, -) - -__all__ = [ - "Template", - "Chat", - "get_template", - "register_template", - "tokenize_conversation", - "tokenize_conversations", - "compare_hf_template", - "validate_messages_for_template", - "ToolPolicy", - "JsonFormatter", - "SystemPolicy", - "GlobalPolicy", -] diff --git a/src/agentfly/templates/assistant_policy.py b/src/agentfly/templates/assistant_policy.py deleted file mode 100644 index 5621b9b..0000000 --- a/src/agentfly/templates/assistant_policy.py +++ /dev/null @@ -1,29 +0,0 @@ -import dataclasses -from abc import ABC, abstractmethod -from typing import Callable - - -@dataclasses.dataclass -class AssistantPolicy: - content_processor: Callable[[str], str] = None - - -class AssistantContentProcessor(ABC): - @abstractmethod - def __call__(self, assistant_message: str) -> str: - raise NotImplementedError - - @abstractmethod - def jinja(self) -> str: - raise NotImplementedError - - -class Qwen25AssistantContentProcessor(AssistantContentProcessor): - def __call__(self, content: str) -> str: - if content is None or content == "": - return "" - else: - return "\n" + content - - def jinja(self) -> str: - return """{% if content is none or content == "" %}{% else %}\n\n{{ content }}{% endif %}""" diff --git a/src/agentfly/templates/constants.py b/src/agentfly/templates/constants.py deleted file mode 100644 index 21161f7..0000000 --- a/src/agentfly/templates/constants.py +++ /dev/null @@ -1,20 +0,0 @@ -from enum import Enum, auto - - -class ToolPlacement(Enum): - """ - Where to inject the tool catalogue in the rendered prompt. 
- """ - - SYSTEM = auto() # inside the system message - FIRST_USER = auto() # as an extra first-user turn - LAST_USER = auto() # appended to the last user turn - SEPARATE = auto() # its own dedicated turn / role - - -class Role(Enum): - SYSTEM = "system" - USER = "user" - ASSISTANT = "assistant" - TOOL = "tool" - ASSISTANT_PREFIX = "assistant_prefix" diff --git a/src/agentfly/templates/global_policy.py b/src/agentfly/templates/global_policy.py deleted file mode 100644 index f7907a3..0000000 --- a/src/agentfly/templates/global_policy.py +++ /dev/null @@ -1,6 +0,0 @@ -import dataclasses - - -@dataclasses.dataclass -class GlobalPolicy: - prefix: str = None diff --git a/src/agentfly/templates/preprocess.py b/src/agentfly/templates/preprocess.py deleted file mode 100644 index fdb5937..0000000 --- a/src/agentfly/templates/preprocess.py +++ /dev/null @@ -1,115 +0,0 @@ -import base64 -import io -from pathlib import Path -from typing import Union -from urllib.parse import urlparse - -import requests -from PIL import Image - - -def open_image_from_any(src: str, *, timeout: int = 10) -> Image.Image: - """ - Open an image from a file path, URL, or base-64 string with Pillow. - - Parameters - ---------- - src : str - The image source. It can be: - • path to an image on disk - • http(s) URL - • plain base-64 or data-URI base-64 - timeout : int, optional - HTTP timeout (s) when downloading from a URL. - - Returns - ------- - PIL.Image.Image - """ - parsed = urlparse(src) - - # 1) Detect a URL ---------------------------------------------------------------- - if parsed.scheme in {"http", "https"}: - # --- requests version - resp = requests.get(src, timeout=timeout) - resp.raise_for_status() - return Image.open(io.BytesIO(resp.content)) - - # --- urllib version (uncomment if you can’t pip-install requests) - # with urllib_request.urlopen(src, timeout=timeout) as fp: - # return Image.open(fp) - - # 2) Detect a base-64 string ------------------------------------------------------ - # • data-URI style: "data:image/png;base64,……" - # • bare base-64 : "iVBORw0KGgoAAAANSUhEUgAABVYA…" - try: - # Strip header if present - if src.startswith("data:"): - header, b64 = src.split(",", 1) - else: - b64 = src - - # “validate=True” quickly rejects non-b64 text without decoding everything - img_bytes = base64.b64decode(b64, validate=True) - return Image.open(io.BytesIO(img_bytes)) - - except (base64.binascii.Error, ValueError): - # Not base-64 → fall through to path handling - pass - - # 3) Treat it as a local file path ---------------------------------------------- - path = Path(src).expanduser().resolve() - if not path.is_file(): - raise FileNotFoundError(f"Image file not found: {path}") - return Image.open(path) - - -def image_to_data_uri(img: Union[Image.Image, str, dict], fmt=None) -> str: - if isinstance(img, dict): - if "bytes" in img: - img = img["bytes"] - - if isinstance(img, Image.Image): - # Try to detect format from PIL Image first - detected_fmt = img.format or fmt or "PNG" - buf = io.BytesIO() - img.save(buf, format=detected_fmt) - b64 = base64.b64encode(buf.getvalue()).decode() - return f"data:image/{detected_fmt.lower()};base64,{b64}" - elif isinstance(img, str): - return img - elif isinstance(img, bytes): - # Try to detect format from magic bytes - detected_fmt = fmt or detect_image_format_from_bytes(img) - return f"data:image/{detected_fmt.lower()};base64,{base64.b64encode(img).decode('utf-8')}" - else: - raise ValueError(f"Invalid image type: {type(img)}") - - -def 
detect_image_format_from_bytes(img_bytes: bytes) -> str: - """Detect image format from bytes using magic numbers""" - if len(img_bytes) < 4: - return "PNG" # Default fallback - - # Check magic bytes for common formats - if img_bytes.startswith(b"\xff\xd8\xff"): - return "JPEG" - elif img_bytes.startswith(b"\x89PNG\r\n\x1a\n"): - return "PNG" - elif img_bytes.startswith(b"GIF87a") or img_bytes.startswith(b"GIF89a"): - return "GIF" - elif img_bytes.startswith(b"RIFF") and img_bytes[8:12] == b"WEBP": - return "WEBP" - elif img_bytes.startswith(b"BM"): - return "BMP" - else: - return "PNG" # Default fallback - - -def image_to_pil(img: Union[Image.Image, str, dict]) -> Image.Image: - if isinstance(img, str): - return open_image_from_any(img) - elif isinstance(img, dict): - return open_image_from_any(img["bytes"]) - else: - return img diff --git a/src/agentfly/templates/system_policy.py b/src/agentfly/templates/system_policy.py deleted file mode 100644 index db8fcfb..0000000 --- a/src/agentfly/templates/system_policy.py +++ /dev/null @@ -1,49 +0,0 @@ -import dataclasses -import datetime -from abc import ABC, abstractmethod -from typing import Callable - - -@dataclasses.dataclass -class SystemPolicy: - use_system: bool = True # Global control - use_system_without_system_message: bool = True # When no system message is provided, use the system message even it is empty (will use the default one if provided) - use_system_with_tools_provided: bool = True # When tools are provided, use the system message with tools even no system message is provided - content_processor: Callable[[str], str] = None - - -class SystemContentProcessor(ABC): - @abstractmethod - def __call__(self, system_message: str) -> str: - raise NotImplementedError - - @abstractmethod - def jinja(self) -> str: - raise NotImplementedError - - -class Llama32DateProcessor(SystemContentProcessor): - """ - A system content processor that adds date information to system messages. - - In Python mode, it dynamically computes the current date. - In Jinja mode, it provides a template with placeholders that can be processed. 
- - Usage in Jinja templates: - - The template includes '__CURRENT_DATE__' placeholder - - Replace '__CURRENT_DATE__' with the actual formatted date during processing - - Format should be 'dd MMM yyyy' (e.g., '15 Dec 2024') - - No external context variables required - """ - - def __call__(self, system_message: str, tools: str) -> str: - return f"Cutting Knowledge Date: December 2023\nToday Date: {datetime.datetime.now().strftime('%d %b %Y')}\n\n{system_message}" - - def jinja(self) -> str: - # For Jinja templates used by external systems (like vLLM), we need a self-contained approach - # Since external systems can't provide context variables, we use a placeholder approach - # The external system should replace __CURRENT_DATE__ with the actual date - return """Cutting Knowledge Date: December 2023 -Today Date: __CURRENT_DATE__ - -{{ system_message }}""" diff --git a/src/agentfly/templates/templates.py b/src/agentfly/templates/templates.py deleted file mode 100644 index 4025dae..0000000 --- a/src/agentfly/templates/templates.py +++ /dev/null @@ -1,2003 +0,0 @@ -import dataclasses -import json -import logging -from collections import defaultdict -from copy import deepcopy -from typing import Any, Dict, List, Tuple, Union - -import torch -from transformers import PreTrainedTokenizer - -from ..utils.vision import open_image_from_any -from .assistant_policy import AssistantPolicy, Qwen25AssistantContentProcessor -from .constants import Role, ToolPlacement -from .global_policy import GlobalPolicy -from .system_policy import Llama32DateProcessor, SystemPolicy -from .tool_policy import ( - JsonCompactFormatter, - JsonIndentedFormatter, - ToolMainContentProcessor, - ToolPolicy, -) -from .vision_processor import is_vision_template - -logger = logging.getLogger(__name__) - - -@dataclasses.dataclass -class Template: - """Class that holds all the components of a chat template. Convert messages to string prompts, tokenize messages to token ids, and generate jinja-based chat templates. - - Args: - name: The name of this template - system_template: The system template component - system_template_with_tools: The system template with tool usage component - system_message: The default system message - stop_words: The stop words where the model stops generating (usually EOS token) - tool_template: The tool response template component - user_template: The user template component - user_template_with_tools: The user template with tool usage component - assistant_template: The assistant template component - global_policy: The global policy, controls the behavior of the template - system_policy: The system message policy, controls the behavior of forming the system message - tool_policy: The tool policy for the template, controls the behavior of forming tools. 
- """ - - # The name of this template - name: str - # The template of the system prompt - system_template: str = "{system_message}" - # The template of the system prompt with tool usage - system_template_with_tools: str = None - # The system message - system_message: str = "" - # Behaviors - # The tool template - tool_template: str = None - # The single tool observation template - tool_observation_template: str = "{observation}" - # The user template - user_template: str = None - user_template_with_tools: str = None - # The assistant template - assistant_template: str = None - # The tool call template - tool_call_template: str = "{tool_call}" - - # Stop criteria (the default one is EOS token) - stop_words: Union[str, List[str]] = None - # Generation prompt - generation_prompt: str = None - # Global policy - global_policy: "GlobalPolicy" = None - # System message policy - system_policy: "SystemPolicy" = None - # Assistant policy - assistant_policy: "AssistantPolicy" = None - # Tool policy for this template - tool_policy: "ToolPolicy" = None - - ## vision part - vision_start: str = None - vision_end: str = None - image_token: str = None - video_token: str = None - - chat_template: str = None - - def __post_init__(self): - """Post-initialization to automatically register vision processor if vision tokens are defined""" - if self.image_token or self.video_token: - self._register_vision_processor() - # Initialise default tool policy if none was provided - if self.tool_policy is None: - self.tool_policy = ToolPolicy() - if self.system_policy is None: - self.system_policy = SystemPolicy() - if self.assistant_policy is None: - self.assistant_policy = AssistantPolicy() - - def _register_vision_processor(self): - """Automatically register a vision processor for this template""" - from .vision_processor import VisionProcessorConfig, register_processor - - # Determine model type based on template name - model_type = self._infer_model_type() - - # Create vision config - config = VisionProcessorConfig( - model_type=model_type, - image_token=self.image_token or "", - video_token=self.video_token or "", - vision_start=self.vision_start or "", - vision_end=self.vision_end or "", - processor_class="AutoProcessor", - expansion_strategy="patch_based", - ) - - # Register the processor - register_processor(self.name, config) - - def _infer_model_type(self) -> str: - """Infer model type from template name""" - name_lower = self.name.lower() - - if "qwen" in name_lower: - return "qwen_vl" - elif "llava" in name_lower: - return "llava" - elif "gemma" in name_lower: - return "gemma3" - elif "paligemma" in name_lower: - return "paligemma" - elif "internvl" in name_lower: - return "internvl" - elif "minicpm" in name_lower: - return "minicpm" - elif "mllama" in name_lower: - return "mllama" - elif "pixtral" in name_lower: - return "pixtral" - elif "video" in name_lower: - return "video_llava" - else: - # Default to patch-based for unknown models - return "patch_based" - - def _supports_tool_call(self) -> bool: - if ( - self.system_template_with_tools or self.user_template_with_tools - ) and self.tool_template: - return True - else: - return False - - def _render_tool_calls(self, tool_calls: List[Dict]) -> str: - tool_calls_str = [] - for tool_call in tool_calls: - if "type" in tool_call and tool_call["type"] == "function": - tool_call = tool_call["function"] - tool_calls_str.append( - self.tool_call_template.format(tool_call=json.dumps(tool_call)) - ) - return "".join(tool_calls_str) - - def _render_tool_observation( 
- self, tool_observation_content: Union[str, List[Dict]] - ) -> str: - """ - Render parallel tool call observations into a single string. For example, for qwen3 with - following tool observations: obs1, obs2, obs3, the rendered string should be: - obs1\nobs2\nobs3 - """ - - # If there is no single tool response template, probably model does not support - # parallel tool calls and don't need to differentiate between single and multiple - # tool calls, so we just return the content as is. - if self.tool_observation_template is None: - if isinstance(tool_observation_content, str): - return tool_observation_content - elif isinstance(tool_observation_content, list): - return "\n".join([item["text"] for item in tool_observation_content]) - else: - raise ValueError( - f"Invalid tool observation content type: {type(tool_observation_content)}" - ) - - if isinstance(tool_observation_content, str): - text = tool_observation_content - elif isinstance(tool_observation_content, list): - # assert len(content) == 1, "Tool message must be a single message" - # text = content[0]["text"] - text = "" - for item in tool_observation_content: - if item["type"] == "text": - text += item["text"] - elif item["type"] == "image": - text += self.vision_start + self.image_token + self.vision_end - # This is for openai format, since chat completion API only supports image_url - elif item["type"] == "image_url": - text += self.vision_start + self.image_token + self.vision_end - else: - raise ValueError(f"Invalid message type: {item['type']}") - else: - raise ValueError( - f"Invalid tool observation content type: {type(tool_observation_content)}" - ) - - return self.tool_observation_template.format(observation=text) - - def _preprocess_messages(self, messages: List[Dict]) -> List[Dict]: - """Preprocess the messages to remove nested structures in messages - e.g. multi (parallel) tool calls, multiple tool observations. We need to insert them into message, however, if there are multiple - tool calls or observations, we will have to know previous or later messages to insert them correctly. This is not we want for our - Template class. - """ - preprocessed_messages = [] - tool_observations_str = [] - for i, message in enumerate(messages): - if message["role"] == "assistant" and "tool_calls" in message: - tool_calls_str = self._render_tool_calls(message["tool_calls"]) - message["tool_calls_str"] = tool_calls_str - if "content" in message and message["content"] is None: - message["content"] = "" - preprocessed_messages.append(message) - elif message["role"] == "tool" and "content" in message: - tool_observations_str.append( - self._render_tool_observation(message["content"]) - ) - if i == len(messages) - 1 or messages[i + 1]["role"] != "tool": - preprocessed_messages.append( - {"role": "tool", "content": "".join(tool_observations_str)} - ) - tool_observations_str = [] - else: - preprocessed_messages.append(message) - - return preprocessed_messages - - def render( - self, messages: List[Dict], tools=None, add_generation_prompt: bool = False - ) -> str: - """Render the template. - - The heavy lifting is delegated to small, single-purpose helpers so the - high-level flow is immediately apparent: - - 1. _insert_tools – decide where the tool catalogue lives - 2. _encode_turns – encode every conversation turn - 3. 
_maybe_add_generation_prompt – append the generation prefix if requested - - Args: - messages: The list of messages - tools: The list of tools - add_generation_prompt: Whether to add the generation prefix - - Returns: - prompt: The final prompt string - elements: The list of string *elements* that compose the prompt - roles: The corresponding list of *roles* (used by downstream post-processing) - """ - - # Step 1 – decide tool placement & clone messages - work_messages = self._preprocess_messages(messages) - logger.debug(f"[Template] work_messages: {work_messages}") - work_messages, tools_str, insert_tools_idx = self._insert_tools( - work_messages, tools - ) - - # Step 2 – encode each conversation turn to text tokens - elements, roles = self._encode_turns(work_messages, tools_str, insert_tools_idx) - - # Step 3 – append generation prefix if needed - if add_generation_prompt: - self._maybe_add_generation_prompt(elements, roles) - - # Concatenate the prompt - prompt = "".join(elements) - return prompt, elements, roles - - def _insert_tools(self, messages: List[Dict], tools): - """Clone *messages* and compute where (and how) the tool catalogue - should be injected. - - Returns: - work_messages : List[Dict] - A deepcopy of the original *messages* so we never mutate caller data. - tools_str : Optional[str] - The formatted tool catalogue or *None* if `tools` is falsy. - insert_tools_idx : int - Index of the *user* message that receives the catalogue, or -1 when - no injection is required. - """ - - if tools: - tools_str = self.tool_policy.format_tools(tools) - placement = self.tool_policy.placement - insert_tools_idx = self._find_insert_tools_index(messages, placement) - else: - tools_str = None - insert_tools_idx = -1 - return messages, tools_str, insert_tools_idx - - def _encode_turns( - self, - work_messages: List[Dict], - tools_str: str, - insert_tools_idx: int, - ) -> Tuple[List[str], List[Role]]: - """Convert every message dict into its textual representation while - tracking roles for later masking logic.""" - - elements: List[str] = [] - roles: List[Role] = [] - - # Global prefix comes first (rarely used but must respect ordering) - if self.global_policy and self.global_policy.prefix: - elements.append(self.global_policy.prefix) - roles.append(Role.SYSTEM) - - for i, message in enumerate(work_messages): - current_role = self._detect_role(message["role"]) - - # -------------------------------------------------------------- - # Handle system message insertion on the very first turn - # -------------------------------------------------------------- - if i == 0 and current_role == Role.SYSTEM: - if self.system_policy.use_system: - system_message = self._encode_system_message( - message["content"], tools=tools_str - ) - elements.append(system_message) - roles.append(Role.SYSTEM) - # Whether inserted or not, we skip further handling of this - # message because it's the (optional) system turn itself. - continue - elif i == 0 and current_role != Role.SYSTEM: - if self.system_policy.use_system: - system_message = self._encode_system_message_default( - tools=tools_str - ) - elements.append(system_message) - roles.append(Role.SYSTEM) - # Do *not* `continue` – we still need to encode this first message. 
- - # -------------------------------------------------------------- - # Encode regular conversation turns - # -------------------------------------------------------------- - if current_role == Role.USER: - if i == insert_tools_idx: - user_message = self._encode_user_message_with_tools( - message["content"], tools=tools_str - ) - else: - user_message = self._encode_user_message(message["content"]) - elements.append(user_message) - roles.append(Role.USER) - - elif current_role == Role.ASSISTANT: - assistant_message = self._encode_assistant_message( - content=message["content"], - tool_calls_str=message["tool_calls_str"] - if "tool_calls_str" in message - else None, - ) - elements.append(assistant_message) - roles.append(Role.ASSISTANT) - - elif current_role == Role.TOOL: - tool_message = self._encode_tool_message(message["content"]) - elements.append(tool_message) - roles.append(Role.TOOL) - else: - raise ValueError(f"Invalid role: {message['role']}") - - return elements, roles - - def _maybe_add_generation_prompt(self, elements: List[str], roles: List[Role]): - """Append the generation prefix so the model knows to continue - generating an assistant response.""" - - generation_prefix, prefix = self._encode_generation_prompt() - elements.append(generation_prefix) - roles.append(Role.ASSISTANT_PREFIX) - - def _detect_role(self, role: str) -> Role: - if role == "system": - return Role.SYSTEM - elif role == "user": - return Role.USER - elif role == "assistant": - return Role.ASSISTANT - elif role == "tool": - return Role.TOOL - else: - raise ValueError(f"Invalid role: {role}") - - def _find_insert_tools_index( - self, work_messages: List[Dict], placement: ToolPlacement - ) -> int: - insert_tools_idx = 0 # Default to insert tools at system message - for i, message in enumerate(work_messages): - if placement == ToolPlacement.SYSTEM: - insert_tools_idx = 0 - elif placement == ToolPlacement.FIRST_USER: - if message.get("role") == "user": - insert_tools_idx = i - break - elif placement == ToolPlacement.LAST_USER: - if message.get("role") == "user": - insert_tools_idx = i - else: - raise ValueError(f"Unhandled ToolPlacement: {placement}") - return insert_tools_idx - - def _encode_system_tools(self, tools: List[Dict]) -> str: - return "\n".join([json.dumps(tool) for tool in tools]) - - def _encode_system_message_default(self, tools=None) -> str: - logger.debug( - f"[Template] Encoding system message default for template: {self.name}" - ) - if not self.system_policy.use_system_without_system_message: - if tools is None: - return "" - else: - # If tools are provided, use the system message with tools - pass - - if self.system_policy.content_processor is not None: - system_message = self.system_policy.content_processor( - self.system_message, tools=tools - ) - else: - system_message = self.system_message - - if tools is None: - return self.system_template.format(system_message=system_message) - else: - if self.system_template_with_tools: - return self.system_template_with_tools.format( - system_message=system_message, tools=tools - ) - else: - return self.system_template.format(system_message=system_message) - - def _encode_system_message(self, content, tools=None) -> str: - # Handle both string content and list content formats - logger.debug(f"[Template] Encoding system message for template: {self.name}") - if isinstance(content, str): - system_message = content - else: - system_message = content[0]["text"] - - if self.system_policy.content_processor is not None: - system_message = 
self.system_policy.content_processor( - system_message, tools=tools - ) - - if tools is None: - return self.system_template.format(system_message=system_message) - else: - if self.system_template_with_tools is None: - return self.system_template.format(system_message=system_message) - else: - return self.system_template_with_tools.format( - system_message=system_message, tools=tools - ) - - def _encode_user_message_with_tools(self, content, tools: str) -> str: - # Handle both string content and list content formats - if isinstance(content, str): - text = content - else: - text = "" - for item in content: - if item["type"] == "text": - text += item["text"] - elif item["type"] in ["image", "image_url"]: - text += self.vision_start + self.image_token + self.vision_end - elif item["type"] == "video": - text += self.vision_start + self.video_token + self.vision_end - else: - raise ValueError(f"Invalid message type: {item['type']}") - - if self.user_template_with_tools: - user_message = self.user_template_with_tools.format( - content=text, tools=tools - ) - else: - user_message = self.user_template.format(content=text) - return user_message - - def _encode_user_message(self, content) -> str: - # Handle both string content and list content formats - if isinstance(content, str): - text = content - else: - text = "" - for item in content: - if item["type"] == "text": - text += item["text"] - elif item["type"] in ["image", "image_url"]: - text += self.vision_start + self.image_token + self.vision_end - elif item["type"] == "video": - text += self.vision_start + self.video_token + self.vision_end - else: - raise ValueError(f"Invalid message type: {item['type']}") - user_message = self.user_template.format(content=text) - return user_message - - def _encode_assistant_message(self, content, tool_calls_str=None) -> str: - if isinstance(content, str): - text = content - else: - assert len(content) == 1, "Assistant message must be a single message" - text = content[0]["text"] - - if self.assistant_policy.content_processor is not None: - text = self.assistant_policy.content_processor(text) - - if "{tool_calls}" in self.assistant_template: - assistant_message = self.assistant_template.format( - content=text, tool_calls=tool_calls_str if tool_calls_str else "" - ) - else: - assistant_message = self.assistant_template.format(content=text) - - logger.debug( - f"[Template] tool_calls_str: {tool_calls_str}, assistant_message: {assistant_message}" - ) - - return assistant_message - - def _encode_tool_message(self, content) -> str: - """ - Encode the tool message. By default, we use "observations" as placeholder for parallel tool calls and "observation" for single tool call. 
- """ - # We have already preprocessed the messages, so content should be a string - assert isinstance(content, str), ( - f"Content should be a string, but got {type(content)}" - ) - - if "{observations}" in self.tool_template: - tool_message = self.tool_template.format(observations=content) - else: - tool_message = self.tool_template.format(observation=content) - return tool_message - - def _encode_generation_prompt(self) -> str: - # Use generation prompt if it is set - if "{content}" in self.assistant_template: - prefix = self.assistant_template.split("{content}")[0] - if self.generation_prompt: - generation_prompt = self.generation_prompt - else: - generation_prompt = prefix - else: - raise ValueError( - f"Assistant template {self.assistant_template} does not contain {{content}}" - ) - - return generation_prompt, prefix - - def _split_assistant_message(self, assistant_message: str) -> List[str]: - # Split the assistant message into generation prefix, content, and generation suffix - generation_prefix, prefix = self._encode_generation_prompt() - assert assistant_message.startswith(prefix), ( - f"Assistant message {assistant_message} does not start with {prefix}" - ) - content_suffix = assistant_message[len(prefix) :] - content = content_suffix - suffix = "" - for stop_word in self.stop_words: - if stop_word in content_suffix: - stop_word_index = content_suffix.index(stop_word) - content = content_suffix[: stop_word_index + len(stop_word)] - suffix = content_suffix[stop_word_index + len(stop_word) :] - break - return prefix, content, suffix - - def encode( - self, - messages: List[Dict], - tokenizer: PreTrainedTokenizer, - return_tensors: str = None, - tools=None, - add_generation_prompt=False, - processor=None, - **kwargs, - ) -> str: - """Encode the messages to token ids. 
- - Args: - messages: The list of messages - tokenizer: The tokenizer - return_tensors: The return tensors - tools: The list of tools - add_generation_prompt: Whether to add the generation prefix - processor: The processor for vision templates - - Returns: - inputs: The dictionary of input ids, attention mask, labels, and action mask - """ - if processor is None and self.supports_vision(): - raise ValueError(f"Processor is required for vision templates: {self.name}") - - if self.supports_vision(): - # Use vision-aware encoding with proper alignment - return self._encode_with_vision_processor( - messages, - tokenizer, - return_tensors, - tools, - add_generation_prompt=add_generation_prompt, - processor=processor, - **kwargs, - ) - else: - # Use standard encoding - return self._encode_standard( - messages, - tokenizer, - return_tensors, - tools, - add_generation_prompt=add_generation_prompt, - **kwargs, - ) - - def _encode_standard( - self, - messages: List[Dict], - tokenizer: PreTrainedTokenizer, - return_tensors: str = None, - tools=None, - add_generation_prompt=False, - **kwargs, - ) -> str: - logger.debug(f"[Template] Encoding standard for template: {self.name}") - """Standard encoding without vision support""" - prompt, elements, roles = self.render( - messages, tools=tools, add_generation_prompt=add_generation_prompt, **kwargs - ) - elements, mask_flags = self._postprocess_elements(elements, roles) - input_ids = [] - attention_mask = [] - labels = [] - action_mask = [] - - if tokenizer.bos_token: - # If add_bos_token is not set, we assume to add bos token - # There is potential issue if the tokenizer has bos_token but do not add it by default - if getattr(tokenizer, "add_bos_token", True): - input_ids.append(tokenizer.bos_token_id) - attention_mask.append(1) - labels.append(-100) - action_mask.append(0) - - for element, mask_flag in zip(elements, mask_flags): - cur_input_ids = tokenizer.encode(element, add_special_tokens=False) - input_ids.extend(cur_input_ids) - attention_mask.extend([1] * len(cur_input_ids)) - if mask_flag: - labels.extend([-100] * len(cur_input_ids)) - action_mask.extend([0] * len(cur_input_ids)) - else: - labels.extend(cur_input_ids) - action_mask.extend([1] * len(cur_input_ids)) - inputs = dict( - input_ids=input_ids, - attention_mask=attention_mask, - labels=labels, - action_mask=action_mask, - ) - if return_tensors == "pt": - inputs = {k: torch.tensor([v]) for k, v in inputs.items()} - return inputs - - def _encode_with_vision_processor( - self, - messages: List[Dict], - tokenizer: PreTrainedTokenizer, - return_tensors: str = None, - tools=None, - add_generation_prompt=False, - processor=None, - **kwargs, - ) -> str: - logger.debug( - f"[Template] Encoding with vision processor for template: {self.name}" - ) - """Encode with vision processor handling proper alignment""" - from .utils import extract_vision_inputs_from_messages - from .vision_processor import get_processor - - # Get vision processor - vision_processor = get_processor(self.name) - if vision_processor is None: - raise ValueError( - f"No vision processor registered for template: {self.name}" - ) - - # Get base prompt and mask information - prompt, elements, roles = self.render( - messages, tools=tools, add_generation_prompt=add_generation_prompt, **kwargs - ) - elements, mask_flags = self._postprocess_elements(elements, roles) - - # Extract vision inputs - images, videos = extract_vision_inputs_from_messages(messages) - - logger.debug(f"[Template] images: {len(images)}") - logger.debug(f"[Template] 
videos: {len(videos)}") - logger.debug(f"[Template] messages: {messages}") - - # Use vision processor with alignment support - return vision_processor.process_for_llm( - prompt=prompt, - elements=elements, - mask_flags=mask_flags, - images=images, - videos=videos, - processor=processor, - tokenizer=tokenizer, - return_tensors=return_tensors, - ) - - def _postprocess_elements(self, elements: List[str], roles) -> List[str]: - # Flag non-assistant messages - new_elements = [] - mask_flags = [] - for i, element in enumerate(elements): - if roles[i] == Role.ASSISTANT: - new_elements.append(element) - mask_flags.append(False) - else: - new_elements.append(element) - mask_flags.append(True) - - # return new_elements, mask_flags - - # merge non-assistant messages and handle the generation prefix and suffixes - merged_elements = [] - merged_mask_flags = [] - - for i, (element, mask_flag) in enumerate(zip(new_elements, mask_flags)): - if i == 0: - prev_element = element - prev_mask_flag = mask_flag - continue - else: - if prev_mask_flag == mask_flag: - # Both previous and current elements are assistant messages - if not mask_flag: - prefix, content, suffix = self._split_assistant_message(element) - merged_elements.append(prefix) - merged_mask_flags.append(True) - merged_elements.append(content) - merged_mask_flags.append(False) - prev_element = suffix - prev_mask_flag = True # We need to mask the suffix - # Both previous and current elements are non-assistant messages - else: - prev_element += element - prev_mask_flag = True - else: - # Previous element is not assistant message, but the current one is - if not mask_flag: - prefix, content, suffix = self._split_assistant_message(element) - prev_element += prefix - prev_mask_flag = True - merged_elements.append(prev_element) - merged_mask_flags.append(prev_mask_flag) - merged_elements.append(content) - merged_mask_flags.append(False) - prev_element = suffix - prev_mask_flag = True - # Previous element is assistant message, but the current one is not - else: - prev_element += element - prev_mask_flag = True - if prev_element != "": - merged_elements.append(prev_element) - merged_mask_flags.append(prev_mask_flag) - return merged_elements, merged_mask_flags - - def supports_vision(self) -> bool: - """Check if this template supports vision processing""" - return is_vision_template(self.name) - - def get_vision_inputs(self, messages: List[Dict]): - vision_inputs = defaultdict(list) - logger.debug(f"[Template] get_vision_inputs: messages: {messages}") - for message in messages: - content = message["content"] - if isinstance(content, list): - for item in content: - if item["type"] == "text": - continue - elif item["type"] in ["image", "image_url", "image_base64"]: - vision_inputs["image"].append( - open_image_from_any(item[item["type"]]) - ) - elif item["type"] == "video": - raise NotImplementedError( - "Video is not supported for chat template." - ) - else: - raise ValueError(f"Invalid message type: {item['type']}") - else: - raise ValueError( - f"Invalid message content: {content}, the content should be a list of dicts" - ) - return vision_inputs - - def jinja_template(self) -> str: - """Interface for getting the Jinja template. - - Returns: - The Jinja template string - """ - if self.chat_template: - return self.chat_template - else: - return self.render_jinja_template() - - def render_jinja_template(self) -> str: - """Return a Hugging-Face style chat-template (Jinja-mini dialect). 
- - The implementation now mirrors the three-step structure of - `render()` for easier maintenance: - - 1. _jinja_header_constants – immutable `set` statements - 2. _jinja_system_block – first turn / system handling - 3. _jinja_loop_messages – remaining turns & per-role logic - 4. _jinja_generation_block – optional generation prefix - """ - - parts: List[str] = [] - - # 1. Constant header (always first) - parts.extend(self._jinja_header_constants()) - - # 2. System-message handling (depends on presence of tools etc.) - parts.extend(self._jinja_system_block()) - - # 2.5 Pre-compute insert index for user placement - parts.extend(self._jinja_compute_insert_idx()) - - # 3. Loop over remaining messages - parts.extend(self._jinja_loop_messages()) - - # 4. Generation prefix block - parts.extend(self._jinja_generation_block()) - - template_str = "".join(parts) - - # Post-process: Replace __CURRENT_DATE__ placeholder with actual date - if "__CURRENT_DATE__" in template_str: - from datetime import datetime - - current_date = datetime.now().strftime("%d %b %Y") - template_str = template_str.replace("__CURRENT_DATE__", current_date) - - return template_str - - # ------------------------------------------------------------------ - # Private helpers – keep them together for readability - # ------------------------------------------------------------------ - - def _jinja_header_constants(self) -> List[str]: - """Return Jinja `set` statements for all constant strings.""" - - # Compute default system message considering content processor - if self.system_policy.content_processor is not None: - # Apply content processor to system message - processed_system_message = self.system_policy.content_processor( - self.system_message, tools=None - ) # TODO: tools is not used here, but we need to pass it for consistency - default_system = self.system_template.format( - system_message=processed_system_message - ) - else: - default_system = self.system_template.format( - system_message=self.system_message - ) - - system_template_with_tools_raw = ( - self.system_template_with_tools if self.system_template_with_tools else None - ) - - # Split templates - try: - u_pref, u_suff = self.user_template.split("{content}") - a_pref, a_suff = self.assistant_template.split("{content}") - except ValueError as exc: - raise ValueError( - "`user_template` / `assistant_template` must contain `{content}` placeholder" - ) from exc - - if self.tool_template: - if "{observations}" in self.tool_template: - t_pref, t_suff = self.tool_template.split("{observations}") - elif "{observation}" in self.tool_template: - t_pref, t_suff = self.tool_template.split("{observation}") - else: - raise ValueError(f"Invalid tool template: {self.tool_template}") - else: - t_pref, t_suff = "", "" - - # Tokens for images / videos - img_tok = ( - (self.vision_start or "") - + (self.image_token or "") - + (self.vision_end or "") - ) - vid_tok = ( - (self.vision_start or "") - + (self.video_token or "") - + (self.vision_end or "") - ) - - # Check if assistant template supports tool calls - supports_tool_calls_in_template = "{tool_calls}" in self.assistant_template - - # Check if tool template uses observations (plural) or observation (singular) - uses_observations = ( - "{observations}" in self.tool_template if self.tool_template else False - ) - - header = [ - f"{{% set _user_pref = {u_pref!r} %}}", - f"{{% set _user_suff = {u_suff!r} %}}", - f"{{% set _assistant_pref = {a_pref!r} %}}", - f"{{% set _assistant_suff = {a_suff!r} %}}", - f"{{% set 
_assistant_template = {self.assistant_template!r} %}}", - f"{{% set _tool_pref = {t_pref!r} %}}", - f"{{% set _tool_suff = {t_suff!r} %}}", - f"{{% set _image_token = {img_tok!r} %}}", - f"{{% set _video_token = {vid_tok!r} %}}", - f"{{% set _default_system = {default_system!r} %}}", - f"{{% set _system_message = {self.system_message!r} %}}", - f"{{% set _system_template = {self.system_template!r} %}}", - f"{{% set _tool_placement = {self.tool_policy.placement.name!r} %}}", - f"{{% set _supports_tool_calls = {supports_tool_calls_in_template} %}}", - f"{{% set _uses_observations = {uses_observations} %}}", - ] - - if self.tool_template: - header.append(f"{{% set _tool_template = {self.tool_template!r} %}}") - else: - header.append("{% set _tool_template = '' %}") - - # Add tool_call_template if it exists - if self.tool_call_template: - header.append( - f"{{% set _tool_call_template = {self.tool_call_template!r} %}}" - ) - - # Add tool_observation_template if it exists - if self.tool_observation_template: - header.append( - f"{{% set _tool_observation_template = {self.tool_observation_template!r} %}}" - ) - - # Add generation_prompt if it exists - if self.generation_prompt: - header.append( - f"{{% set _generation_prompt = {self.generation_prompt!r} %}}" - ) - else: - header.append("{% set _generation_prompt = None %}") - - if system_template_with_tools_raw: - header.append( - f"{{% set _system_template_with_tools = {system_template_with_tools_raw!r} %}}" - ) - - # Add user template with tools if it exists - if self.user_template_with_tools: - # Convert double braces to single braces for Jinja compatibility - processed_template = self.user_template_with_tools.replace( - "{{", "{" - ).replace("}}", "}") - header.append( - f"{{% set _user_template_with_tools = {processed_template!r} %}}" - ) - - # ------------------------------------------------------------------ - # Formatter macro for tools (only if the template supports tool calls) - # ------------------------------------------------------------------ - - if self._supports_tool_call(): - # Build a Jinja macro that reproduces ToolPolicy.format_tools behaviour - formatter_snippet = self.tool_policy.formatter.jinja() - - # The snippet usually comes wrapped in "{{ ... }}". We drop the - # outer braces because macro bodies are already an output context. 
- formatter_body = formatter_snippet - - header.extend( - [ - "{% macro _fmt_tools(tools) %}", - f"{formatter_body}", - "{% endmacro %}", - ] - ) - - # ------------------------------------------------------------------ - # System processor macro (if system policy has a content processor) - # ------------------------------------------------------------------ - - if self.system_policy.content_processor is not None: - # Build a Jinja macro that reproduces the system content processor behaviour - processor_snippet = self.system_policy.content_processor.jinja() - - # The snippet should be a template that expects 'system_message' variable - # We create a macro that can be called with the system message - header.extend( - [ - "{% macro _process_system_message(system_message) %}", - f"{processor_snippet}", - "{% endmacro %}", - ] - ) - - # ------------------------------------------------------------------ - # Assistant processor macro (if assistant policy has a content processor) - # ------------------------------------------------------------------ - - if self.assistant_policy.content_processor is not None: - # Build a Jinja macro that reproduces the assistant content processor behaviour - processor_snippet = self.assistant_policy.content_processor.jinja() - - # The snippet should be a template that expects 'content' variable - # We create a macro that can be called with the assistant content - header.extend( - [ - "{% macro _process_assistant_content(content) %}", - f"{processor_snippet}", - "{% endmacro %}", - ] - ) - - return header - - def _jinja_compute_insert_idx(self) -> List[str]: - """Return Jinja code that pre-computes the index where tools should - be injected for FIRST_USER and LAST_USER placements.""" - - return [ - "{% set _insert_ns = namespace(idx=-1) %}", - "{% if _tool_placement in ['FIRST_USER', 'LAST_USER'] %}", - "{%- for _m in messages -%}", - "{%- if _m['role'] == 'user' -%}", - "{%- if _tool_placement == 'FIRST_USER' and _insert_ns.idx == -1 -%}", - "{% set _insert_ns.idx = loop.index0 %}", - "{%- elif _tool_placement == 'LAST_USER' -%}", - "{% set _insert_ns.idx = loop.index0 %}", - "{%- endif -%}", - "{%- endif -%}", - "{%- endfor -%}", - "{% endif %}", - ] - - def _jinja_system_block(self) -> List[str]: - """Return Jinja code that handles the system message logic.""" - - return [ - # Handle system message first (matching render logic) - "{% if messages and messages[0]['role'] == 'system' %}", - "{% if tools and _system_template_with_tools %}", - "{% if messages[0]['content'] is string %}", - "{% if _process_system_message is defined %}", - "{{ _system_template_with_tools.format(system_message=_process_system_message(messages[0]['content']), tools=_fmt_tools(tools)) }}", - "{% else %}", - "{{ _system_template_with_tools.format(system_message=messages[0]['content'], tools=_fmt_tools(tools)) }}", - "{% endif %}", - "{% else %}", - "{% if _process_system_message is defined %}", - "{{ _system_template_with_tools.format(system_message=_process_system_message(messages[0]['content'][0]['text']), tools=_fmt_tools(tools)) }}", - "{% else %}", - "{{ _system_template_with_tools.format(system_message=messages[0]['content'][0]['text'], tools=_fmt_tools(tools)) }}", - "{% endif %}", - "{% endif %}", - "{% else %}", - "{% if messages[0]['content'] is string %}", - "{% if _process_system_message is defined %}", - "{% set processed_message = _process_system_message(messages[0]['content']) %}", - "{% set formatted_system = _system_template | replace('{system_message}', processed_message) 
%}{{ formatted_system }}", - "{% else %}", - "{% set formatted_system = _system_template | replace('{system_message}', messages[0]['content']) %}{{ formatted_system }}", - "{% endif %}", - "{% else %}", - "{% if _process_system_message is defined %}", - "{% set processed_message = _process_system_message(messages[0]['content'][0]['text']) %}", - "{% set formatted_system = _system_template | replace('{system_message}', processed_message) %}{{ formatted_system }}", - "{% else %}", - "{% set formatted_system = _system_template | replace('{system_message}', messages[0]['content'][0]['text']) %}{{ formatted_system }}", - "{% endif %}", - "{% endif %}", - "{% endif %}", - "{% else %}", - "{% if tools and _system_template_with_tools %}", - "{% if _process_system_message is defined %}", - "{{ _system_template_with_tools.format(system_message=_process_system_message(_system_message), tools=_fmt_tools(tools)) }}", - "{% else %}", - "{{ _system_template_with_tools.format(system_message=_system_message, tools=_fmt_tools(tools)) }}", - "{% endif %}", - "{% else %}", - "{% if _process_system_message is defined %}", - "{% set processed_message = _process_system_message(_system_message) %}", - "{% set formatted_system = _system_template | replace('{system_message}', processed_message) %}{{ formatted_system }}", - "{% else %}", - "{{ _default_system }}", - "{% endif %}", - "{% endif %}", - "{% endif %}", - ] - - def _jinja_loop_messages(self) -> List[str]: - """Return Jinja loop that encodes all messages except the first system.""" - - return [ - "{% set _tool_ns = namespace(inserted=False, user_count=0, observations=[]) %}", - # Process remaining messages (skip first if it was system) - "{% for m in messages %}", - "{% if not (loop.first and m['role'] == 'system') %}", - "{% if m['role'] == 'user' %}", - "{% set _tool_ns.user_count = _tool_ns.user_count + 1 %}", - "{% set ns = namespace(txt='') %}", - "{% if m['content'] is string %}", - "{% set ns.txt = m['content'] %}", - "{% else %}", - "{% for item in m['content'] %}", - "{% if item['type'] == 'text' %}", - "{% set ns.txt = ns.txt + item['text'] %}", - "{% elif item['type'] == 'image' %}", - "{% set ns.txt = ns.txt + _image_token %}", - "{% elif item['type'] == 'image_url' %}", - "{% set ns.txt = ns.txt + _image_token %}", - "{% elif item['type'] == 'video' %}", - "{% set ns.txt = ns.txt + _video_token %}", - "{% endif %}", - "{% endfor %}", - "{% endif %}", - "{% if tools and ((_tool_placement == 'FIRST_USER' and _tool_ns.user_count == 1) or (_tool_placement == 'LAST_USER' and loop.index0 == _insert_ns.idx)) and not _tool_ns.inserted %}", - "{% if _user_template_with_tools is defined %}", - "{% set formatted_tools = _fmt_tools(tools) %}", - "{{ _user_template_with_tools | replace('{content}', ns.txt) | replace('{tools}', formatted_tools) }}", - "{% else %}", - "{{ _user_pref }}{{ ns.txt }}{{ _user_suff }}\\n{{ _fmt_tools(tools) }}", - "{% endif %}", - "{% set _tool_ns.inserted = True %}", - "{% else %}", - "{{ _user_pref }}{{ ns.txt }}{{ _user_suff }}", - "{% endif %}", - "{% elif m['role'] == 'assistant' %}", - "{% set ns = namespace(txt='', tool_calls_str='') %}", - "{% if m['content'] is string %}", - "{% set ns.txt = m['content'] %}", - "{% else %}", - "{% if m['content'] %}", - "{% set ns.txt = m['content'][0]['text'] %}", - "{% else %}", - "{% set ns.txt = '' %}", - "{% endif %}", - "{% endif %}", - "{% if _process_assistant_content is defined %}", - "{% set ns.txt = _process_assistant_content(ns.txt) %}", - "{% endif %}", - "{% if 
m['tool_calls'] and _tool_call_template is defined %}", - "{% for tool_call in m['tool_calls'] %}", - "{% set tc = tool_call %}", - "{% if tool_call['type'] == 'function' %}", - "{% set tc = tool_call['function'] %}", - "{% endif %}", - "{% set tool_call_json = tc | tojson %}", - "{% set tool_call_formatted = _tool_call_template | replace('{tool_call}', tool_call_json) %}", - "{% set ns.tool_calls_str = ns.tool_calls_str + tool_call_formatted %}", - "{% endfor %}", - "{% set ns.tool_calls_str = ns.tool_calls_str %}", - "{% endif %}", - "{% if _supports_tool_calls %}", - "{% set assistant_msg = _assistant_template | replace('{content}', ns.txt) | replace('{tool_calls}', ns.tool_calls_str) %}", - "{{ assistant_msg }}", - "{% else %}", - "{{ _assistant_pref }}{{ ns.txt }}{{ _assistant_suff }}", - "{% endif %}", - "{% elif m['role'] == 'tool' %}", - "{% if loop.first or messages[loop.index0 - 1]['role'] != 'tool' %}", - "{% set _tool_ns.observations = [] %}", - "{% endif %}", - "{% set ns = namespace(txt='') %}", - "{% if m['content'] is string %}", - "{% set ns.txt = m['content'] %}", - "{% else %}", - "{% for item in m['content'] %}", - "{% if item['type'] == 'text' %}", - "{% set ns.txt = ns.txt + item['text'] %}", - "{% elif item['type'] == 'image' %}", - "{% set ns.txt = ns.txt + _image_token %}", - "{% elif item['type'] == 'image_url' %}", - "{% set ns.txt = ns.txt + _image_token %}", - "{% endif %}", - "{% endfor %}", - "{% endif %}", - "{% if _tool_observation_template is defined %}", - "{% set observation_formatted = _tool_observation_template | replace('{observation}', ns.txt) %}", - "{% set _tool_ns.observations = _tool_ns.observations + [observation_formatted] %}", - "{% else %}", - "{% set _tool_ns.observations = _tool_ns.observations + [ns.txt] %}", - "{% endif %}", - "{% if loop.last or (loop.index0 < messages|length - 1 and messages[loop.index0 + 1]['role'] != 'tool') %}", - "{% set observations_combined = _tool_ns.observations | join('') %}", - "{% if _tool_template and _uses_observations %}", - "{{ _tool_template | replace('{observations}', observations_combined) }}", - "{% elif _tool_template %}", - "{{ _tool_template | replace('{observation}', observations_combined) }}", - "{% else %}", - "{{ _tool_pref }}{{ observations_combined }}{{ _tool_suff }}", - "{% endif %}", - "{% endif %}", - "{% endif %}", - "{% endif %}", - "{% endfor %}", - ] - - def _jinja_generation_block(self) -> List[str]: - """Return Jinja code that appends the generation prefix when requested.""" - - return [ - "{% if add_generation_prompt %}", - "{% if _generation_prompt is not none %}", - "{{ _generation_prompt }}", - "{% else %}", - "{{ _assistant_pref }}", - "{% endif %}", - "{% endif %}", - ] - - def render_with_mask( - self, - messages: List[Dict], - add_generation_prompt: bool = False, - tools=None, - **kwargs, - ): - from termcolor import colored - - prompt, elements, roles = self.render( - messages, add_generation_prompt=add_generation_prompt, tools=tools, **kwargs - ) - elements, mask_flags = self._postprocess_elements(elements, roles) - - prompt = "" - for element, mask_flag in zip(elements, mask_flags): - if mask_flag: - prompt += colored(element, "red") - else: - prompt += colored(element, "green") - return prompt, elements, mask_flags - - def set_system_message(self, system_message: str): - """Set the system message.""" - self.system_message = system_message - - def copy(self): - return self.__class__( - name=self.name, - system_template=self.system_template, - 
system_template_with_tools=self.system_template_with_tools, - system_message=self.system_message, - user_template=self.user_template, - user_template_with_tools=self.user_template_with_tools, - assistant_template=self.assistant_template, - tool_call_template=self.tool_call_template, - tool_template=self.tool_template, - tool_observation_template=self.tool_observation_template, - stop_words=self.stop_words, - generation_prompt=self.generation_prompt, - vision_start=self.vision_start, - vision_end=self.vision_end, - image_token=self.image_token, - video_token=self.video_token, - global_policy=deepcopy(self.global_policy), - system_policy=deepcopy(self.system_policy), - tool_policy=deepcopy(self.tool_policy), - assistant_policy=deepcopy(self.assistant_policy), - chat_template=self.chat_template, - ) - - def dict(self): - return { - "template_name": self.name, - "system_message": self.system_message, - "system_template_with_tools": self.system_template_with_tools, - "stop_words": self.stop_words, - "vision_start": self.vision_start, - "vision_end": self.vision_end, - "image_token": self.image_token, - "video_token": self.video_token, - } - - -class Qwen3Template(Template): - def render( - self, - messages: List[Dict], - tools=None, - add_generation_prompt: bool = False, - enable_thinking: bool = False, - ) -> str: - """Render the Qwen3 template with special thinking logic. - - Args: - messages: The list of messages - tools: The list of tools - add_generation_prompt: Whether to add the generation prefix - enable_thinking: Whether to enable thinking mode - - Returns: - prompt: The final prompt string - elements: The list of string *elements* that compose the prompt - roles: The corresponding list of *roles* (used by downstream post-processing) - """ - - # Step 1 – decide tool placement & clone messages - work_messages = self._preprocess_messages(messages) - logger.debug(f"[Qwen3Template] work_messages: {work_messages}") - work_messages, tools_str, insert_tools_idx = self._insert_tools( - work_messages, tools - ) - - # Step 2 – clean think content from all assistant messages except the last one - work_messages = self._clean_think_content(work_messages) - - # Step 2.5 – reformat think content in the last assistant message if it exists - if work_messages and work_messages[-1].get("role") == "assistant": - work_messages = self._reformat_last_assistant_think_content(work_messages) - - # Step 3 – encode each conversation turn to text tokens - elements, roles = self._encode_turns(work_messages, tools_str, insert_tools_idx) - - # Step 4 – handle special generation prompt logic for Qwen3 - if add_generation_prompt: - self._maybe_add_generation_prompt_qwen3( - elements, roles, enable_thinking, work_messages - ) - elif work_messages and work_messages[-1].get("role") == "assistant": - # Add empty think tokens to the last assistant message if it doesn't already have think tags - self._add_empty_think_to_last_assistant(elements, roles, work_messages) - - # Concatenate the prompt - prompt = "".join(elements) - return prompt, elements, roles - - def _clean_think_content(self, messages: List[Dict]) -> List[Dict]: - """Remove all think content (...) 
from assistant messages and reformat existing think content."""
-        cleaned_messages = []
-        for i, message in enumerate(messages):
-            if message.get("role") == "assistant" and i != len(messages) - 1:
-                cleaned_message = message.copy()
-                content = message["content"]
-
-                if isinstance(content, str):
-                    # Remove think content from string
-                    cleaned_content = self._remove_think_tags(content)
-                elif isinstance(content, list):
-                    # Handle list content format
-                    cleaned_content = []
-                    for item in content:
-                        if item["type"] == "text":
-                            cleaned_text = self._remove_think_tags(item["text"])
-                            cleaned_content.append(
-                                {"type": "text", "text": cleaned_text}
-                            )
-                        else:
-                            cleaned_content.append(item)
-                elif content is None:
-                    cleaned_content = ""
-                else:
-                    raise ValueError(f"Invalid content type: {type(content)}")
-
-                cleaned_message["content"] = cleaned_content
-                cleaned_messages.append(cleaned_message)
-            else:
-                cleaned_messages.append(message)
-
-        return cleaned_messages
-
-    def _remove_think_tags(self, text: str) -> str:
-        """Remove <think>...</think> tags from text."""
-        import re
-
-        # Remove <think>...</think> tags and their content
-        pattern = r"<think>.*?</think>"
-        return re.sub(pattern, "", text, flags=re.DOTALL)
-
-    def _has_think_tags(self, text: str) -> bool:
-        """Check if text contains <think> and </think> tags."""
-        return "<think>" in text and "</think>" in text
-
-    def _reformat_think_content(self, text: str) -> str:
-        """Reformat think content to ensure each think token ends with two newlines."""
-        import re
-
-        def replace_think_content(match):
-            think_content = match.group(1)
-            # Ensure the think content ends with exactly two newlines
-            think_content = think_content.rstrip("\n")
-            return f"<think>\n{think_content}\n</think>\n\n"
-
-        # Find and replace think tags, ensuring proper formatting
-        pattern = r"<think>(.*?)</think>"
-        return re.sub(pattern, replace_think_content, text, flags=re.DOTALL)
-
-    def _reformat_last_assistant_think_content(
-        self, messages: List[Dict]
-    ) -> List[Dict]:
-        """Reformat think content in the last assistant message."""
-        if not messages or messages[-1].get("role") != "assistant":
-            return messages
-
-        messages = messages.copy()
-        last_message = messages[-1].copy()
-        content = last_message["content"]
-
-        if isinstance(content, str):
-            # Reformat think content in string
-            last_message["content"] = self._reformat_think_content(content)
-        else:
-            # Handle list content format
-            reformed_content = []
-            for item in content:
-                if item["type"] == "text":
-                    reformed_text = self._reformat_think_content(item["text"])
-                    reformed_content.append({"type": "text", "text": reformed_text})
-                else:
-                    reformed_content.append(item)
-            last_message["content"] = reformed_content
-
-        messages[-1] = last_message
-        return messages
-
-    def _maybe_add_generation_prompt_qwen3(
-        self,
-        elements: List[str],
-        roles: List[Role],
-        enable_thinking: bool,
-        work_messages: List[Dict],
-    ):
-        """Append the generation prefix with special Qwen3 thinking logic."""
-        if enable_thinking:
-            # Use standard generation prompt
-            generation_prefix, prefix = self._encode_generation_prompt()
-            elements.append(generation_prefix)
-            roles.append(Role.ASSISTANT_PREFIX)
-        else:
-            # Check if the last message has think tags
-            has_existing_think = False
-            if work_messages and work_messages[-1].get("role") == "assistant":
-                content = work_messages[-1]["content"]
-                if isinstance(content, str):
-                    has_existing_think = self._has_think_tags(content)
-                elif isinstance(content, list):
-                    for item in content:
-                        if item.get("type") == "text" and self._has_think_tags(
-                            item["text"]
-                        ):
-                            has_existing_think = True
-                            break
-
-            generation_prefix, prefix = self._encode_generation_prompt()
-            if has_existing_think:
-                # Don't add empty think tokens if think tags already exist
-                elements.append(generation_prefix)
-            else:
-                # Add empty think tokens after the generation prefix
-                elements.append(generation_prefix + "<think>\n\n</think>\n\n")
-            roles.append(Role.ASSISTANT_PREFIX)
-
-    def _add_empty_think_to_last_assistant(
-        self, elements: List[str], roles: List[Role], work_messages: List[Dict]
-    ):
-        """Add empty think tokens to the last assistant message if it doesn't already have think tags."""
-        if not elements or not roles or not work_messages:
-            return
-
-        # Check if the last message has think tags
-        has_existing_think = False
-        if work_messages[-1].get("role") == "assistant":
-            content = work_messages[-1]["content"]
-            if isinstance(content, str):
-                has_existing_think = self._has_think_tags(content)
-            elif isinstance(content, list):
-                for item in content:
-                    if item.get("type") == "text" and self._has_think_tags(
-                        item["text"]
-                    ):
-                        has_existing_think = True
-                        break
-
-        # Only add empty think tokens if no existing think tags
-        if not has_existing_think:
-            generation_prefix, prefix = self._encode_generation_prompt()
-
-            # Find the last assistant element
-            for i in range(len(elements) - 1, -1, -1):
-                if roles[i] == Role.ASSISTANT:
-                    # Add empty think tokens at the start of the assistant message
-                    elements[i] = (
-                        prefix + "<think>\n\n</think>\n\n" + elements[i][len(prefix) :]
-                    )
-                    break
-
-    def _split_assistant_message(self, assistant_message: str) -> List[str]:
-        # Split the assistant message into generation prefix, content, and generation suffix
-        generation_prefix, prefix = self._encode_generation_prompt()
-        assert assistant_message.startswith(prefix), (
-            f"Assistant message {assistant_message} does not start with {prefix}"
-        )
-
-        # We need to detect whether the assistant message starts with empty think tokens
-        # If so, we need to set empty think tokens as non-assistant message
-        if assistant_message.startswith(prefix + "<think>\n\n</think>\n\n"):
-            prefix = prefix + "<think>\n\n</think>\n\n"
-
-        content_suffix = assistant_message[len(prefix) :]
-        content = content_suffix
-        suffix = ""
-        for stop_word in self.stop_words:
-            if stop_word in content_suffix:
-                stop_word_index = content_suffix.index(stop_word)
-                content = content_suffix[: stop_word_index + len(stop_word)]
-                suffix = content_suffix[stop_word_index + len(stop_word) :]
-                break
-        return prefix, content, suffix
-
-
-class Chat:
-    def __init__(
-        self,
-        template: str,
-        messages: List[List[str]] = None,
-        tools=None,
-        tokenizer: PreTrainedTokenizer = None,
-    ):
-        """
-        Args:
-            template: The name of the template to use.
-            messages: The messages to use for the chat.
-            tools: The tools to use for the chat.
-            tokenizer: The tokenizer to use for the chat.
- """ - self.template = get_template(template) - self.messages = self.convert_to_hf_format_messages(messages) - self.tokenizer = tokenizer - self.tools = tools - self.flags = {} - - def _detect_labels(self, messages): - message = messages[0] - if "role" in message and "content" in message: - return "role", "content" - elif "from" in message and "value" in message: - return "from", "value" - else: - raise ValueError("Cannot find role label and content label in the data.") - - def _convert_single_message_to_hf_format(self, message: Dict) -> Dict: - if isinstance(message["content"], str): - message["content"] = [{"type": "text", "text": message["content"]}] - elif isinstance(message["content"], list): - for item in message["content"]: - if item["type"] == "text": - continue - elif item["type"] in ["image", "image_url"]: - pass - else: - raise ValueError(f"Invalid message type: {item['type']}") - - def convert_to_hf_format_messages( - self, messages: Union[List[Dict], Dict[str, List[Dict]]] - ) -> List[Dict]: - hf_messages = [] - if messages is None: - return None - role_label, content_label = self._detect_labels(messages) - for message in messages: - hf_message = { - "role": message[role_label], - "content": message[content_label], - } - if "tool_calls" in message: - hf_message["tool_calls"] = message["tool_calls"] - hf_messages.append(hf_message) - - for message in hf_messages: - self._convert_single_message_to_hf_format(message) - - return hf_messages - - def set_messages(self, messages: List[Dict]): - """Set the messages for the chat.""" - self.messages = self.convert_to_hf_format_messages(messages) - - def prompt(self, add_generation_prompt=False, tools=None, **kwargs) -> str: - """Get the prompt for the chat. - - Args: - add_generation_prompt: Whether to add the generation prompt. - tools: The tools to use for the chat. - **kwargs: Additional keyword arguments to pass to the template render method. - - Returns: - The prompt for the chat. - """ - self.flags["add_generation_prompt"] = add_generation_prompt - tools = tools or self.tools - prompt, _, _ = self.template.render( - messages=self.messages, - tools=tools, - add_generation_prompt=add_generation_prompt, - **kwargs, - ) - return prompt - - def prompt_with_mask( - self, add_generation_prompt=False, tools=None, **kwargs - ) -> str: - prompt_with_mask, _, _ = self.template.render_with_mask( - messages=self.messages, - add_generation_prompt=add_generation_prompt, - tools=tools, - **kwargs, - ) - return prompt_with_mask - - def vision_inputs(self) -> List[Any]: - return self.template.get_vision_inputs(self.messages) - - def tokenize( - self, - tokenizer: PreTrainedTokenizer = None, - add_generation_prompt=False, - tools=None, - processor=None, - **kwargs, - ) -> List[int]: - """Tokenize the messages. - - Args: - tokenizer: The tokenizer to use for the chat. - add_generation_prompt: Whether to add the generation prompt. - tools: The tools to use for the chat. - processor: The processor to use for the chat. - - Returns: - inputs (dict): Inputs for helping training. - - input_ids - - attention_mask - - labels - - action_mask - - multi_modal_inputs - """ - if tokenizer is None: - if self.tokenizer is None: - raise ValueError( - "Tokenizer is not set. Set it when initializing the chat or pass it as an argument." 
- ) - tokenizer = self.tokenizer - - if tools is None: - tools = self.tools - return self.template.encode( - messages=self.messages, - tokenizer=tokenizer, - return_tensors="pt", - tools=tools, - add_generation_prompt=add_generation_prompt, - processor=processor, - **kwargs, - ) - - def append(self, message: Union[Dict]): - self._convert_single_message_to_hf_format(message) - self.messages.append(message) - - -# A global registry for all conversation templates -TEMPLATES: Dict[str, Template] = {} - - -def register_template(template: Template, override: bool = False): - """Register a new conversation template.""" - if not override: - assert template.name not in TEMPLATES, f"{template.name} has been registered." - - TEMPLATES[template.name] = template - - -def get_template(name: str) -> Template: - """Get a conversation template.""" - return TEMPLATES[name].copy() - - -register_template( - Template( - name="qwen2.5-no-system-tool", - system_template="<|im_start|>system\n{system_message}<|im_end|>\n", - system_message="You are Qwen, created by Alibaba Cloud. You are a helpful assistant.", - user_template="<|im_start|>user\n{content}<|im_end|>\n", - assistant_template="<|im_start|>assistant\n{content}<|im_end|>\n", - tool_template="<|im_start|>user\n\n{observation}\n<|im_end|>\n", - stop_words=["<|im_end|>"], - ) -) - -register_template( - Template( - name="qwen2.5-vl", - system_template="<|im_start|>system\n{system_message}<|im_end|>\n", - system_message="You are a helpful assistant.", - user_template="<|im_start|>user\n{content}<|im_end|>\n", - assistant_template="<|im_start|>assistant\n{content}<|im_end|>\n", - tool_template="<|im_start|>tool\n{observation}<|im_end|>\n", - vision_start="<|vision_start|>", - vision_end="<|vision_end|>", - image_token="<|image_pad|>", - video_token="<|video_pad|>", - stop_words=["<|im_end|>"], - ) -) - -register_template( - Template( - name="qwen2.5-vl-system-tool", - system_template="<|im_start|>system\n{system_message}<|im_end|>\n", - system_message="You are a helpful assistant.", - system_template_with_tools="""<|im_start|>system\n{system_message}\n\n# Tools\n\nYou may call one or more functions to assist with the user query.\n\nYou are provided with function signatures within XML tags:\n\n{tools}\n\n\nFor each function call, return a json object with function name and arguments within XML tags:\n\n{{"name": , "arguments": }}\n<|im_end|>\n""", - user_template="<|im_start|>user\n{content}<|im_end|>\n", - assistant_template="<|im_start|>assistant\n{content}<|im_end|>\n", - tool_template="<|im_start|>tool\n{observation}<|im_end|>\n", - vision_start="<|vision_start|>", - vision_end="<|vision_end|>", - image_token="<|image_pad|>", - video_token="<|video_pad|>", - stop_words=["<|im_end|>"], - ) -) - -register_template( - Template( - name="qwen3-vl-instruct", - system_template="<|im_start|>system\n{system_message}<|im_end|>\n", - system_template_with_tools="""<|im_start|>system\n{system_message}# Tools\n\nYou may call one or more functions to assist with the user query.\n\nYou are provided with function signatures within XML tags:\n\n{tools}\n\n\nFor each function call, return a json object with function name and arguments within XML tags:\n\n{{"name": , "arguments": }}\n<|im_end|>\n""", - user_template="<|im_start|>user\n{content}<|im_end|>\n", - assistant_template="<|im_start|>assistant{content}{tool_calls}<|im_end|>\n", - generation_prompt="<|im_start|>assistant\n", - tool_call_template="\n\n{tool_call}\n", - 
tool_template="<|im_start|>user{observations}<|im_end|>\n", - tool_observation_template="\n\n{observation}\n", - vision_start="<|vision_start|>", - vision_end="<|vision_end|>", - image_token="<|image_pad|>", - video_token="<|video_pad|>", - stop_words=["<|im_end|>"], - system_policy=SystemPolicy( - use_system_without_system_message=False, - content_processor=lambda system, tools: f"{system}\n\n" - if (system != "" and tools) - else system, - ), - assistant_policy=AssistantPolicy( - content_processor=Qwen25AssistantContentProcessor(), - ), - chat_template="{%- if tools %}\n {{- '<|im_start|>system\\n' }}\n {%- if messages[0].role == 'system' %}\n {%- if messages[0].content is string %}\n {{- messages[0].content }}\n {%- else %}\n {%- for content in messages[0].content %}\n {%- if 'text' in content %}\n {{- content.text }}\n {%- endif %}\n {%- endfor %}\n {%- endif %}\n {{- '\\n\\n' }}\n {%- endif %}\n {{- \"# Tools\\n\\nYou may call one or more functions to assist with the user query.\\n\\nYou are provided with function signatures within XML tags:\\n\" }}\n {%- for tool in tools %}\n {{- \"\\n\" }}\n {{- tool | tojson }}\n {%- endfor %}\n {{- \"\\n\\n\\nFor each function call, return a json object with function name and arguments within XML tags:\\n\\n{\\\"name\\\": , \\\"arguments\\\": }\\n<|im_end|>\\n\" }}\n{%- else %}\n {%- if messages[0].role == 'system' %}\n {{- '<|im_start|>system\\n' }}\n {%- if messages[0].content is string %}\n {{- messages[0].content }}\n {%- else %}\n {%- for content in messages[0].content %}\n {%- if 'text' in content %}\n {{- content.text }}\n {%- endif %}\n {%- endfor %}\n {%- endif %}\n {{- '<|im_end|>\\n' }}\n {%- endif %}\n{%- endif %}\n{%- set image_count = namespace(value=0) %}\n{%- set video_count = namespace(value=0) %}\n{%- for message in messages %}\n {%- if message.role == \"user\" %}\n {{- '<|im_start|>' + message.role + '\\n' }}\n {%- if message.content is string %}\n {{- message.content }}\n {%- else %}\n {%- for content in message.content %}\n {%- if content.type == 'image' or 'image' in content or 'image_url' in content %}\n {%- set image_count.value = image_count.value + 1 %}\n {%- if add_vision_id %}Picture {{ image_count.value }}: {% endif -%}\n <|vision_start|><|image_pad|><|vision_end|>\n {%- elif content.type == 'video' or 'video' in content %}\n {%- set video_count.value = video_count.value + 1 %}\n {%- if add_vision_id %}Video {{ video_count.value }}: {% endif -%}\n <|vision_start|><|video_pad|><|vision_end|>\n {%- elif 'text' in content %}\n {{- content.text }}\n {%- endif %}\n {%- endfor %}\n {%- endif %}\n {{- '<|im_end|>\\n' }}\n {%- elif message.role == \"assistant\" %}\n {{- '<|im_start|>' + message.role + '\\n' }}\n {%- if message.content is string %}\n {{- message.content }}\n {%- else %}\n {%- for content_item in message.content %}\n {%- if 'text' in content_item %}\n {{- content_item.text }}\n {%- endif %}\n {%- endfor %}\n {%- endif %}\n {%- if message.tool_calls %}\n {%- for tool_call in message.tool_calls %}\n {%- if (loop.first and message.content) or (not loop.first) %}\n {{- '\\n' }}\n {%- endif %}\n {%- if tool_call.function %}\n {%- set tool_call = tool_call.function %}\n {%- endif %}\n {{- '\\n{\"name\": \"' }}\n {{- tool_call.name }}\n {{- '\", \"arguments\": ' }}\n {%- if tool_call.arguments is string %}\n {{- tool_call.arguments }}\n {%- else %}\n {{- tool_call.arguments | tojson }}\n {%- endif %}\n {{- '}\\n' }}\n {%- endfor %}\n {%- endif %}\n {{- '<|im_end|>\\n' }}\n {%- elif message.role == \"tool\" %}\n {%- if 
loop.first or (messages[loop.index0 - 1].role != \"tool\") %}\n {{- '<|im_start|>user' }}\n {%- endif %}\n {{- '\\n\\n' }}\n {%- if message.content is string %}\n {{- message.content }}\n {%- else %}\n {%- for content in message.content %}\n {%- if content.type == 'image' or 'image' in content or 'image_url' in content %}\n {%- set image_count.value = image_count.value + 1 %}\n {%- if add_vision_id %}Picture {{ image_count.value }}: {% endif -%}\n <|vision_start|><|image_pad|><|vision_end|>\n {%- elif content.type == 'video' or 'video' in content %}\n {%- set video_count.value = video_count.value + 1 %}\n {%- if add_vision_id %}Video {{ video_count.value }}: {% endif -%}\n <|vision_start|><|video_pad|><|vision_end|>\n {%- elif 'text' in content %}\n {{- content.text }}\n {%- endif %}\n {%- endfor %}\n {%- endif %}\n {{- '\\n' }}\n {%- if loop.last or (messages[loop.index0 + 1].role != \"tool\") %}\n {{- '<|im_end|>\\n' }}\n {%- endif %}\n {%- endif %}\n{%- endfor %}\n{%- if add_generation_prompt %}\n {{- '<|im_start|>assistant\\n' }}\n{%- endif %}\n", - ) -) - -register_template( - Template( - name="qwen2.5", - system_template="<|im_start|>system\n{system_message}<|im_end|>\n", - system_message="You are Qwen, created by Alibaba Cloud. You are a helpful assistant.", - system_template_with_tools="""<|im_start|>system\n{system_message}\n\n# Tools\n\nYou may call one or more functions to assist with the user query.\n\nYou are provided with function signatures within XML tags:\n\n{tools}\n\n\nFor each function call, return a json object with function name and arguments within XML tags:\n\n{{"name": , "arguments": }}\n<|im_end|>\n""", - user_template="<|im_start|>user\n{content}<|im_end|>\n", - assistant_template="<|im_start|>assistant{content}{tool_calls}<|im_end|>\n", - generation_prompt="<|im_start|>assistant\n", - tool_call_template="\n\n{tool_call}\n", - tool_template="<|im_start|>user{observations}<|im_end|>\n", - tool_observation_template="\n\n{observation}\n", - stop_words=["<|im_end|>"], - assistant_policy=AssistantPolicy( - content_processor=Qwen25AssistantContentProcessor(), - ), - ) -) - - -register_template( - Template( - name="qwen2.5-think", - system_template="<|im_start|>system\n{system_message}<|im_end|>\n", - system_message="You are a helpful assistant. To answer the user's question, you first think about the reasoning process and then provide the user with the answer. The reasoning process and answer are enclosed within and tags, respectively, i.e., reasoning process here answer here .", - # system_template_with_tools="""<|im_start|>You are a helpful assistant. To answer the user's question, you first think about the reasoning process and then provide the user with the answer. The reasoning process and answer are enclosed within and tags, respectively, i.e., reasoning process here answer here .# Tools\n\nYou may call one or more functions to assist with the user query.\n\nYou are provided with function signatures within XML tags:\n\n{tools}\n\n\nFor each function call, return a json object inside and tags with function name and arguments within XML tags:\n\n\n{{"name": , "arguments": }}\n\n<|im_end|>\n""", - system_template_with_tools="""<|im_start|>You are a helpful assistant. To answer the user's question, you first think about the reasoning process and then call tools or provide the answer. 
The thinking process is enclosed within tags, i.e., [reasoning process here] [response here].\n# Tools\n\nYou may call one or more functions to assist with the user query.\n\nYou are provided with function signatures within XML tags:\n\n{tools}\n\n\nFor each function call, return a json object with function name and arguments within XML tags:\n [reasoning process here] \n\n{{"name": , "arguments": }}\n\nYou must think first before calling any tool.<|im_end|>\n""", - user_template="<|im_start|>user\n{content}<|im_end|>\n", - assistant_template="<|im_start|>assistant\n{content}<|im_end|>\n", - tool_template="<|im_start|>user\n\n{observation}\n<|im_end|>\n", - stop_words=["<|im_end|>"], - vision_start="<|vision_start|>", - vision_end="<|vision_end|>", - image_token="<|image_pad|>", - video_token="<|video_pad|>", - ) -) - -register_template( - Qwen3Template( - name="qwen3", - system_template="<|im_start|>system\n{system_message}<|im_end|>\n", - system_template_with_tools="""<|im_start|>system\n{system_message}# Tools\n\nYou may call one or more functions to assist with the user query.\n\nYou are provided with function signatures within XML tags:\n\n{tools}\n\n\nFor each function call, return a json object with function name and arguments within XML tags:\n\n{{"name": , "arguments": }}\n<|im_end|>\n""", - user_template="<|im_start|>user\n{content}<|im_end|>\n", - assistant_template="<|im_start|>assistant{content}{tool_calls}<|im_end|>\n", - generation_prompt="<|im_start|>assistant\n", - tool_call_template="\n\n{tool_call}\n", - tool_template="<|im_start|>user{observations}<|im_end|>\n", - tool_observation_template="\n\n{observation}\n", - stop_words=["<|im_end|>"], - system_policy=SystemPolicy( - use_system_without_system_message=False, - content_processor=lambda system, tools: f"{system}\n\n" - if (system != "" and tools) - else system, - ), - assistant_policy=AssistantPolicy( - content_processor=Qwen25AssistantContentProcessor(), - ), - chat_template="{%- if tools %}\n {{- '<|im_start|>system\\n' }}\n {%- if messages[0].role == 'system' %}\n {{- messages[0].content + '\\n\\n' }}\n {%- endif %}\n {{- \"# Tools\\n\\nYou may call one or more functions to assist with the user query.\\n\\nYou are provided with function signatures within XML tags:\\n\" }}\n {%- for tool in tools %}\n {{- \"\\n\" }}\n {{- tool | tojson }}\n {%- endfor %}\n {{- \"\\n\\n\\nFor each function call, return a json object with function name and arguments within XML tags:\\n\\n{\\\"name\\\": , \\\"arguments\\\": }\\n<|im_end|>\\n\" }}\n{%- else %}\n {%- if messages[0].role == 'system' %}\n {{- '<|im_start|>system\\n' + messages[0].content + '<|im_end|>\\n' }}\n {%- endif %}\n{%- endif %}\n{%- set ns = namespace(multi_step_tool=true, last_query_index=messages|length - 1) %}\n{%- for message in messages[::-1] %}\n {%- set index = (messages|length - 1) - loop.index0 %}\n {%- if ns.multi_step_tool and message.role == \"user\" and message.content is string and not(message.content.startswith('') and message.content.endswith('')) %}\n {%- set ns.multi_step_tool = false %}\n {%- set ns.last_query_index = index %}\n {%- endif %}\n{%- endfor %}\n{%- for message in messages %}\n {%- if message.content is string %}\n {%- set content = message.content %}\n {%- else %}\n {%- set content = '' %}\n {%- endif %}\n {%- if (message.role == \"user\") or (message.role == \"system\" and not loop.first) %}\n {{- '<|im_start|>' + message.role + '\\n' + content + '<|im_end|>' + '\\n' }}\n {%- elif message.role == \"assistant\" %}\n {%- set 
reasoning_content = '' %}\n {%- if message.reasoning_content is string %}\n {%- set reasoning_content = message.reasoning_content %}\n {%- else %}\n {%- if '' in content %}\n {%- set reasoning_content = content.split('')[0].rstrip('\\n').split('')[-1].lstrip('\\n') %}\n {%- set content = content.split('')[-1].lstrip('\\n') %}\n {%- endif %}\n {%- endif %}\n {%- if loop.index0 > ns.last_query_index %}\n {%- if loop.last or (not loop.last and reasoning_content) %}\n {{- '<|im_start|>' + message.role + '\\n\\n' + reasoning_content.strip('\\n') + '\\n\\n\\n' + content.lstrip('\\n') }}\n {%- else %}\n {{- '<|im_start|>' + message.role + '\\n' + content }}\n {%- endif %}\n {%- else %}\n {{- '<|im_start|>' + message.role + '\\n' + content }}\n {%- endif %}\n {%- if message.tool_calls %}\n {%- for tool_call in message.tool_calls %}\n {%- if (loop.first and content) or (not loop.first) %}\n {{- '\\n' }}\n {%- endif %}\n {%- if tool_call.function %}\n {%- set tool_call = tool_call.function %}\n {%- endif %}\n {{- '\\n{\"name\": \"' }}\n {{- tool_call.name }}\n {{- '\", \"arguments\": ' }}\n {%- if tool_call.arguments is string %}\n {{- tool_call.arguments }}\n {%- else %}\n {{- tool_call.arguments | tojson }}\n {%- endif %}\n {{- '}\\n' }}\n {%- endfor %}\n {%- endif %}\n {{- '<|im_end|>\\n' }}\n {%- elif message.role == \"tool\" %}\n {%- if loop.first or (messages[loop.index0 - 1].role != \"tool\") %}\n {{- '<|im_start|>user' }}\n {%- endif %}\n {{- '\\n\\n' }}\n {{- content }}\n {{- '\\n' }}\n {%- if loop.last or (messages[loop.index0 + 1].role != \"tool\") %}\n {{- '<|im_end|>\\n' }}\n {%- endif %}\n {%- endif %}\n{%- endfor %}\n{%- if add_generation_prompt %}\n {{- '<|im_start|>assistant\\n' }}\n {%- if enable_thinking is defined and enable_thinking is false %}\n {{- '\\n\\n\\n\\n' }}\n {%- endif %}\n{%- endif %}", - ) -) - -register_template( - Template( - name="qwen3-instruct", - system_template="<|im_start|>system\n{system_message}<|im_end|>\n", - system_template_with_tools="""<|im_start|>system\n{system_message}# Tools\n\nYou may call one or more functions to assist with the user query.\n\nYou are provided with function signatures within XML tags:\n\n{tools}\n\n\nFor each function call, return a json object with function name and arguments within XML tags:\n\n{{"name": , "arguments": }}\n<|im_end|>\n""", - user_template="<|im_start|>user\n{content}<|im_end|>\n", - assistant_template="<|im_start|>assistant{content}{tool_calls}<|im_end|>\n", - generation_prompt="<|im_start|>assistant\n", - tool_call_template="\n\n{tool_call}\n", - tool_template="<|im_start|>user{observations}<|im_end|>\n", - tool_observation_template="\n\n{observation}\n", - vision_start="<|vision_start|>", - vision_end="<|vision_end|>", - image_token="<|image_pad|>", - video_token="<|video_pad|>", - stop_words=["<|im_end|>"], - system_policy=SystemPolicy( - use_system_without_system_message=False, - content_processor=lambda system, tools: f"{system}\n\n" - if (system != "" and tools) - else system, - ), - assistant_policy=AssistantPolicy( - content_processor=Qwen25AssistantContentProcessor(), - ), - chat_template="{%- if tools %}\n {{- '<|im_start|>system\\n' }}\n {%- if messages[0].role == 'system' %}\n {%- if messages[0].content is string %}\n {{- messages[0].content }}\n {%- else %}\n {%- for content in messages[0].content %}\n {%- if 'text' in content %}\n {{- content.text }}\n {%- endif %}\n {%- endfor %}\n {%- endif %}\n {{- '\\n\\n' }}\n {%- endif %}\n {{- \"# Tools\\n\\nYou may call one or more functions to assist 
with the user query.\\n\\nYou are provided with function signatures within XML tags:\\n\" }}\n {%- for tool in tools %}\n {{- \"\\n\" }}\n {{- tool | tojson }}\n {%- endfor %}\n {{- \"\\n\\n\\nFor each function call, return a json object with function name and arguments within XML tags:\\n\\n{\\\"name\\\": , \\\"arguments\\\": }\\n<|im_end|>\\n\" }}\n{%- else %}\n {%- if messages[0].role == 'system' %}\n {{- '<|im_start|>system\\n' }}\n {%- if messages[0].content is string %}\n {{- messages[0].content }}\n {%- else %}\n {%- for content in messages[0].content %}\n {%- if 'text' in content %}\n {{- content.text }}\n {%- endif %}\n {%- endfor %}\n {%- endif %}\n {{- '<|im_end|>\\n' }}\n {%- endif %}\n{%- endif %}\n{%- set image_count = namespace(value=0) %}\n{%- set video_count = namespace(value=0) %}\n{%- for message in messages %}\n {%- if message.role == \"user\" %}\n {{- '<|im_start|>' + message.role + '\\n' }}\n {%- if message.content is string %}\n {{- message.content }}\n {%- else %}\n {%- for content in message.content %}\n {%- if content.type == 'image' or 'image' in content or 'image_url' in content %}\n {%- set image_count.value = image_count.value + 1 %}\n {%- if add_vision_id %}Picture {{ image_count.value }}: {% endif -%}\n <|vision_start|><|image_pad|><|vision_end|>\n {%- elif content.type == 'video' or 'video' in content %}\n {%- set video_count.value = video_count.value + 1 %}\n {%- if add_vision_id %}Video {{ video_count.value }}: {% endif -%}\n <|vision_start|><|video_pad|><|vision_end|>\n {%- elif 'text' in content %}\n {{- content.text }}\n {%- endif %}\n {%- endfor %}\n {%- endif %}\n {{- '<|im_end|>\\n' }}\n {%- elif message.role == \"assistant\" %}\n {{- '<|im_start|>' + message.role + '\\n' }}\n {%- if message.content is string %}\n {{- message.content }}\n {%- else %}\n {%- for content_item in message.content %}\n {%- if 'text' in content_item %}\n {{- content_item.text }}\n {%- endif %}\n {%- endfor %}\n {%- endif %}\n {%- if message.tool_calls %}\n {%- for tool_call in message.tool_calls %}\n {%- if (loop.first and message.content) or (not loop.first) %}\n {{- '\\n' }}\n {%- endif %}\n {%- if tool_call.function %}\n {%- set tool_call = tool_call.function %}\n {%- endif %}\n {{- '\\n{\"name\": \"' }}\n {{- tool_call.name }}\n {{- '\", \"arguments\": ' }}\n {%- if tool_call.arguments is string %}\n {{- tool_call.arguments }}\n {%- else %}\n {{- tool_call.arguments | tojson }}\n {%- endif %}\n {{- '}\\n' }}\n {%- endfor %}\n {%- endif %}\n {{- '<|im_end|>\\n' }}\n {%- elif message.role == \"tool\" %}\n {%- if loop.first or (messages[loop.index0 - 1].role != \"tool\") %}\n {{- '<|im_start|>user' }}\n {%- endif %}\n {{- '\\n\\n' }}\n {%- if message.content is string %}\n {{- message.content }}\n {%- else %}\n {%- for content in message.content %}\n {%- if content.type == 'image' or 'image' in content or 'image_url' in content %}\n {%- set image_count.value = image_count.value + 1 %}\n {%- if add_vision_id %}Picture {{ image_count.value }}: {% endif -%}\n <|vision_start|><|image_pad|><|vision_end|>\n {%- elif content.type == 'video' or 'video' in content %}\n {%- set video_count.value = video_count.value + 1 %}\n {%- if add_vision_id %}Video {{ video_count.value }}: {% endif -%}\n <|vision_start|><|video_pad|><|vision_end|>\n {%- elif 'text' in content %}\n {{- content.text }}\n {%- endif %}\n {%- endfor %}\n {%- endif %}\n {{- '\\n' }}\n {%- if loop.last or (messages[loop.index0 + 1].role != \"tool\") %}\n {{- '<|im_end|>\\n' }}\n {%- endif %}\n {%- endif %}\n{%- endfor 
%}\n{%- if add_generation_prompt %}\n {{- '<|im_start|>assistant\\n' }}\n{%- endif %}\n", - ) -) - -register_template( - Template( - name="deepseek-prover", - system_template="{system_message}\n", - system_message="You are an AI programming assistant, utilizing the Deepseek Coder model, developed by Deepseek Company, and you only answer questions related to computer science. For politically sensitive questions, security and privacy issues, and other non-computer science questions, you will refuse to answer.", - user_template="### Instruction:\n{content}\n", - assistant_template="### Response:\n{content}\n<|EOT|>\n", - stop_words=["<|EOT|>"], - ) -) - - -# TODO: mistral template has many cornor cases, leave it for now -# register_template( -# Template( -# name="mistral", -# system_template="{system_message}", -# user_template="[INST] {content}[/INST] ", -# user_template_with_tools="[AVAILABLE TOOLS] {tools} [/AVAILABLE TOOLS] [INST] {content}[/INST] ", -# assistant_template="{content}", -# tool_template="{observation}", -# stop_words=[""], -# system_policy=SystemPolicy( -# use_system=False, -# ), -# tool_policy=ToolPolicy( -# placement=ToolPlacement.LAST_USER, -# formatter=JsonCompactFormatter() -# ) -# ) -# ) - -# TODO: system template includes current date -register_template( - Template( - name="llama-3.2", - system_template="<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n\n{system_message}<|eot_id|>", - system_template_with_tools="<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n\nEnvironment: ipython\n{system_message}<|eot_id|>", - user_template="<|start_header_id|>user<|end_header_id|>\n\n{content}<|eot_id|>", - user_template_with_tools="""<|start_header_id|>user<|end_header_id|>\n\nGiven the following functions, please respond with a JSON for a function call with its proper arguments that best answers the given prompt.\n\nRespond in the format {{"name": function name, "parameters": dictionary of argument name and its value}}.Do not use variables.\n\n{tools}\n\n{content}<|eot_id|>""", - assistant_template="<|start_header_id|>assistant<|end_header_id|>\n\n{content}{tool_calls}<|eot_id|>", - tool_call_template="{tool_call}", - tool_template="""<|start_header_id|>ipython<|end_header_id|>\n\n"{observation}"<|eot_id|>""", - stop_words=["<|eot_id|>"], - system_policy=SystemPolicy( - use_system=True, - content_processor=Llama32DateProcessor(), - ), - tool_policy=ToolPolicy( - placement=ToolPlacement.FIRST_USER, formatter=JsonIndentedFormatter() - ), - ) -) - -register_template( - Template( - name="glm-4", - system_template="<|system|>\n{system_message}", - user_template="<|user|>\n{content}", - assistant_template="<|assistant|>\n{content}", - stop_words=[""], - global_policy=GlobalPolicy(prefix="[gMASK]"), - system_policy=SystemPolicy( - use_system=True, - use_system_without_system_message=False, - ), - ) -) - -register_template( - Template( - name="phi-4", - system_template="<|im_start|>system<|im_sep|>{system_message}<|im_end|>", - user_template="<|im_start|>user<|im_sep|>{content}<|im_end|>", - assistant_template="<|im_start|>assistant<|im_sep|>{content}<|im_end|>", - stop_words=["<|im_end|>"], - ) -) - -# Note: Partial align, some minor new-line problems. 
-register_template( - Template( - name="nemotron", - system_template="<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n{system_message}<|eot_id|>", - system_template_with_tools="""<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n{system_message}{tools}<|eot_id|>""", - user_template="<|start_header_id|>user<|end_header_id|>\n\n{content}<|eot_id|>", - assistant_template="<|start_header_id|>assistant<|end_header_id|>\n\n{content}<|eot_id|>", - tool_template="<|start_header_id|>user<|end_header_id|>\n\n[{observation}]<|eot_id|>", - stop_words=["<|eot_id|>"], - system_policy=SystemPolicy( - use_system=True, - content_processor=lambda system_message, tools: f"\n{system_message}", - ), - tool_policy=ToolPolicy( - placement=ToolPlacement.SYSTEM, - content_processor=ToolMainContentProcessor(), - formatter=JsonCompactFormatter(), - ), - ) -) - -register_template( - Template( - name="deepseek-r1-distill-qwen", - system_template="{system_message}", - user_template="<|User|>{content}", - assistant_template="<|Assistant|>{content}<|end▁of▁sentence|>", - stop_words=["<|end▁of▁sentence|>"], - generation_prompt="<|Assistant|>\n", - global_policy=GlobalPolicy(prefix="<|begin▁of▁sentence|>"), - system_policy=SystemPolicy( - use_system=True, - use_system_without_system_message=False, - ), - chat_template="{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}{% set ns = namespace(is_first=false, is_tool=false, is_output_first=true, system_prompt='') %}{%- for message in messages %}{%- if message['role'] == 'system' %}{% set ns.system_prompt = message['content'] %}{%- endif %}{%- endfor %}{{bos_token}}{{ns.system_prompt}}{%- for message in messages %}{%- if message['role'] == 'user' %}{%- set ns.is_tool = false -%}{{'<|User|>' + message['content']}}{%- endif %}{%- if message['role'] == 'assistant' and message['content'] is none %}{%- set ns.is_tool = false -%}{%- for tool in message['tool_calls']%}{%- if not ns.is_first %}{{'<|Assistant|><|tool▁calls▁begin|><|tool▁call▁begin|>' + tool['type'] + '<|tool▁sep|>' + tool['function']['name'] + '\\n' + '```json' + '\\n' + tool['function']['arguments'] + '\\n' + '```' + '<|tool▁call▁end|>'}}{%- set ns.is_first = true -%}{%- else %}{{'\\n' + '<|tool▁call▁begin|>' + tool['type'] + '<|tool▁sep|>' + tool['function']['name'] + '\\n' + '```json' + '\\n' + tool['function']['arguments'] + '\\n' + '```' + '<|tool▁call▁end|>'}}{{'<|tool▁calls▁end|><|end▁of▁sentence|>'}}{%- endif %}{%- endfor %}{%- endif %}{%- if message['role'] == 'assistant' and message['content'] is not none %}{%- if ns.is_tool %}{{'<|tool▁outputs▁end|>' + message['content'] + '<|end▁of▁sentence|>'}}{%- set ns.is_tool = false -%}{%- else %}{% set content = message['content'] %}{% if '' in content %}{% set content = content.split('')[-1] %}{% endif %}{{'<|Assistant|>' + content + '<|end▁of▁sentence|>'}}{%- endif %}{%- endif %}{%- if message['role'] == 'tool' %}{%- set ns.is_tool = true -%}{%- if ns.is_output_first %}{{'<|tool▁outputs▁begin|><|tool▁output▁begin|>' + message['content'] + '<|tool▁output▁end|>'}}{%- set ns.is_output_first = false %}{%- else %}{{'\\n<|tool▁output▁begin|>' + message['content'] + '<|tool▁output▁end|>'}}{%- endif %}{%- endif %}{%- endfor -%}{% if ns.is_tool %}{{'<|tool▁outputs▁end|>'}}{% endif %}{% if add_generation_prompt and not ns.is_tool %}{{'<|Assistant|>\\n'}}{% endif %}", - ) -) - -register_template( - Template( - name="llemma", - system_template="{system_message}", - user_template="Input:{content}\n\n", - 
assistant_template="Response:{content}", - stop_words=[""], - ) -) - - -if __name__ == "__main__": - pass diff --git a/src/agentfly/templates/tool_policy.py b/src/agentfly/templates/tool_policy.py deleted file mode 100644 index 92f9a3c..0000000 --- a/src/agentfly/templates/tool_policy.py +++ /dev/null @@ -1,282 +0,0 @@ -import dataclasses -import json -from abc import ABC, abstractmethod -from typing import Dict # Added for content processor typing -from typing import Any, Callable, List, Tuple - -from .constants import ToolPlacement - - -# Convert ToolFormatter into an abstract base class -class ToolFormatter(ABC): - """ - Strategy that converts an in-memory list[dict] describing tools - into the textual representation expected by the target model. - """ - - @abstractmethod - def format(self, tools: List[Dict]) -> str: - """Format a list of tool dictionaries into a string representation.""" - raise NotImplementedError - - @abstractmethod - def jinja(self) -> str: - """Return a Jinja template that can be used to format the tools.""" - raise NotImplementedError - - -class ToolContentProcessor(ABC): - """ - Strategy that processes the content of a tool before it is serialized. - """ - - @abstractmethod - def __call__(self, tool: Dict) -> Dict: - raise NotImplementedError - - @abstractmethod - def jinja(self) -> str: - """Return a Jinja template that can be used to process the content of a tool.""" - raise NotImplementedError - - -class ToolMainContentProcessor(ToolContentProcessor): - """ - Strategy that processes the main content of a tool before it is serialized. - """ - - def __call__(self, tool: Dict) -> Dict: - assert isinstance(tool, dict), "Tool must be a dictionary" - if "function" in tool: - content = tool["function"] - assert "name" in content, "Tool function must have a name" - assert "parameters" in content, "Tool function must have parameters" - return content - elif "name" in tool and "parameters" in tool: - return tool - else: - raise ValueError( - f"Tool must have a function or name and parameters: {tool}" - ) - - # The main-content extraction cannot be replicated in pure Jinja, so we - # fall back to the identity behaviour at template-generation time. This - # means the processor is *ignored* in frozen chat-templates; users who - # require it must rely on the Python render path. - - def jinja(self) -> str: - # We deliberately document the limitation by returning a simple pass- - # through expression. - return "{{ tool }}" - - -# Make JsonFormatter inherit the ToolFormatter base class -class JsonFormatter(ToolFormatter): - """General JSON formatter with configurable indent, separators, and joiner.""" - - def __init__( - self, - *, - indent: int | None = None, - separators: Tuple[str, str] | None = None, - joiner: str = "\n", - format_as_list: bool = False, - content_processor: ToolContentProcessor = None, - ): - """Create a new JsonFormatter. - - Args: - indent: Indentation level passed to ``json.dumps``. ``None`` means no pretty-print. - separators: Custom separators passed to ``json.dumps``; useful for minification. - joiner: String used to join per-tool JSON strings when ``format_as_list`` is *False*. - format_as_list: If *True*, the entire ``tools`` list is serialised in a single - ``json.dumps`` call, ignoring ``joiner``. This is handy when the target - model expects a single JSON array instead of multiple individual objects. - content_processor: Optional callable applied to each individual tool dictionary - before serialisation. Defaults to the identity function. 
- """ - self.indent = indent - self.separators = separators - self.joiner = joiner - self.format_as_list = format_as_list - - def format(self, tools: List[Dict]) -> str: # noqa: D401 - """Return a single string obtained by dumping every tool to JSON then joining them. - - Args: - tools: A list of tool dictionaries to be stringified. - - Returns: - A string representation of the tools, formatted according to the - given ``indent``/``separators`` and concatenated with ``joiner``. - """ - # Apply the per-tool content processor first - - if self.format_as_list: - # Serialize the whole list in one go – joiner is irrelevant in this mode. - return json.dumps(tools, indent=self.indent, separators=self.separators) - - # Default behaviour: dump each tool individually then concatenate. - return self.joiner.join( - json.dumps(t, indent=self.indent, separators=self.separators) for t in tools - ) - - # ------------------------------------------------------------------ - # Jinja support - # ------------------------------------------------------------------ - - def _escape_joiner(self, joiner: str) -> str: # local helper - """Return *joiner* escaped so it is safe inside a single‐quoted Jinja - string literal (the HF chat-template parser understands the Python - backslash escapes).""" - - return joiner.replace("\\", "\\\\").replace("'", "\\'") - - def jinja(self) -> str: # noqa: D401 - """Return a **Jinja-mini** snippet that serialises the *tools* variable - with the same settings as :py:meth:`format`. - - The template assumes that a ``tools`` list is present in the Jinja - context. Because the Hugging-Face chat-template dialect only supports - a limited subset of Jinja, we restrict ourselves to `map`, `tojson`, - `join`, and optional indent on a *single* tojson call when - ``format_as_list`` is *True*. - - When ``format_as_list`` is *False* and ``indent`` is specified, we use - a Jinja loop to apply indentation to each individual tool. 
- """ - - # Serialise whole list -> one tojson call (supports indent argument) - if self.format_as_list: - if self.indent is None: - return "{{ tools | tojson }}" - else: - return f"{{{{ tools | tojson(indent={self.indent}) }}}}" - - # Individual objects: use loop if indent is needed, otherwise use map - if self.indent is not None: - # Use loop to apply indentation to each individual tool - # For joiners containing newlines, we need to avoid whitespace control to preserve them - # For other joiners, we can use whitespace control for cleaner output - - if "\n" in self.joiner: - # Joiner contains newlines - use Jinja's string replacement to convert \n to actual newlines - # We'll create a Jinja variable with the proper newlines - joiner_var = ( - '{% set joiner = "' - + self.joiner.replace("\n", "\\n") - + '" | replace("\\\\n", "\n") %}' - ) - return ( - joiner_var - + f"{{% for tool in tools %}}{{{{ tool | tojson(indent={self.indent}) }}}}{{% if not loop.last %}}{{{{ joiner }}}}{{% endif %}}{{% endfor %}}" - ) - else: - # Joiner doesn't contain newlines - safe to use whitespace control and escaping - joiner_escaped = self._escape_joiner(self.joiner) - return f"{{%- for tool in tools -%}}{{{{ tool | tojson(indent={self.indent}) }}}}{{%- if not loop.last -%}}{joiner_escaped}{{%- endif -%}}{{%- endfor -%}}" - else: - # No indentation needed, use the simpler map approach - joiner_escaped = self._escape_joiner(self.joiner) - return "{{ tools | map('tojson') | join('" + joiner_escaped + "') }}" - - -class JsonMinifiedFormatter(JsonFormatter): - """Single-line JSON objects without extra whitespace (legacy alias).""" - - def __init__( - self, - joiner: str = "\n", - *, - content_processor: Callable[[Dict], Any] | None = None, - ): - super().__init__( - indent=None, - separators=(",", ":"), - joiner=joiner, - content_processor=content_processor, - ) - - -class JsonIndentedFormatter(JsonFormatter): - """ - Pretty printed JSON with configurable indent (default 4). - Frequently required by models like Mistral-v0.3. - (legacy alias) - """ - - def __init__( - self, indent: int = 4, *, joiner: str = "\n\n", format_as_list: bool = False - ): - super().__init__( - indent=indent, separators=None, joiner=joiner, format_as_list=format_as_list - ) - - -class JsonCompactFormatter(JsonFormatter): - """Single-line JSON objects without extra whitespace.""" - - def __init__( - self, - *, - format_as_list: bool = True, - content_processor: Callable[[Dict], Any] | None = None, - ): - super().__init__( - indent=None, - separators=None, - format_as_list=format_as_list, - content_processor=content_processor, - ) - - -class JsonQwenFormatter(JsonFormatter): - """ - JSON formatter for Qwen models. 
- """ - - def __init__(self): - super().__init__( - indent=None, separators=None, format_as_list=False, content_processor=None - ) - - # No special behaviour – inherits .jinja from JsonFormatter - - -# --------------------------------------------------------------------------- -# Content processors – only implement jinja where feasible -# --------------------------------------------------------------------------- - - -try: - import yaml as _yaml # optional dependency - - class YamlFormatter(ToolFormatter): # type: ignore - def format(self, tools: List[Dict]) -> str: # noqa: D401 - return _yaml.safe_dump(tools, sort_keys=False) -except ModuleNotFoundError: # pragma: no cover - YamlFormatter = None # type: ignore - - -@dataclasses.dataclass -class ToolPolicy: - """ - Encapsulates every configuration decision about how *tools* - appear in the prompt for a given template. - """ - - placement: "ToolPlacement" = ToolPlacement.SYSTEM - content_processor: Callable[[Dict], Any] = None - formatter: ToolFormatter = dataclasses.field( - default_factory=lambda: JsonQwenFormatter() - ) - - def format_tools(self, tools: List[Dict]) -> str: - """ - Convert `tools` into ready-to-inject text according to the chosen formatter. - """ - if self.content_processor is not None: - processed_tools = [self.content_processor(t) for t in tools] - else: - processed_tools = tools - return self.formatter.format(processed_tools) diff --git a/src/agentfly/templates/utils.py b/src/agentfly/templates/utils.py deleted file mode 100644 index 7bf6b0b..0000000 --- a/src/agentfly/templates/utils.py +++ /dev/null @@ -1,450 +0,0 @@ -import logging -import re -from typing import Any - -import torch - -from .templates import Chat, get_template -from .vision_processor import get_processor - -LOGGER = logging.getLogger(__name__) - -ANSI_RE = re.compile(r"\x1b\[[0-9;]*m") # matches any ANSI color/style code - - -def strip_ansi(s: str) -> str: - """Remove ANSI escape sequences from a string.""" - return ANSI_RE.sub("", s) - - -def convert_messages_to_hf_format(messages: list) -> list: - """ - Convert messages to Hugging Face format. - """ - for message in messages: - content = message["content"] - if isinstance(content, list): - for item in content: - if "type" in item: - if item["type"] == "image_url": - item["type"] = "image" - item["image"] = item["image_url"]["url"] - del item["image_url"] - else: - # TODO: handle other types of content - pass - message["content"] = content - return messages - - -def transform_multi_turn_reward_mask(action_mask): - """ - Given a binary action_mask of shape (batch_size, sequence_length), - returns a tensor of the same shape with 1 only at the position where the action_mask is 1 and the next position is 0, - """ - # action_mask: shape (batch_size, sequence_length) - batch_size, seq_length = action_mask.shape - - # Create a shifted version of the attention mask by shifting left. - # For the last column, we append a column of zeros. - shifted = torch.cat( - [ - action_mask[:, 1:], - torch.zeros( - batch_size, 1, dtype=action_mask.dtype, device=action_mask.device - ), - ], - dim=1, - ) - - # Identify positions where the attention_mask is 1 and the shifted mask is 0. - # This means either the next position is 0 or we're at the last element. - last_ones_mask = (action_mask == 1) & (shifted == 0) - - # Optionally, convert boolean mask to integers (0s and 1s). 
- return last_ones_mask.int() - - -def transform_reward_mask(action_mask): - """ - Given a binary attention_mask of shape (batch_size, sequence_length), - returns a tensor of the same shape with 1 only at the rightmost (last) 1 per row, - and 0 everywhere else. - """ - batch_size, seq_length = action_mask.shape - - # Check for rows that contain at least one 1. - has_one = action_mask.sum(dim=1) > 0 - - # Reverse each row so that the first occurrence of 1 corresponds to the last 1 in the original. - reversed_mask = action_mask.flip(dims=[1]) - - # For each row, find the index of the first occurrence of 1 in the reversed row. - # Note: torch.argmax returns 0 if no element is 1, so we will handle rows with no ones separately. - first_one_idx_reversed = torch.argmax(reversed_mask, dim=1) - - # Convert to the original index position. - last_indices = seq_length - 1 - first_one_idx_reversed - - # Create an output tensor initialized with zeros. - output = torch.zeros_like(action_mask) - - # For rows that have at least one 1, set the found last index to 1. - # We use advanced indexing to assign 1 to the appropriate positions. - row_indices = torch.arange(batch_size) - output[row_indices[has_one], last_indices[has_one]] = 1 - - return output - - -def tokenize_conversation( - messages, - tokenizer, - template, - max_length=None, - tools=None, - processor=None, - return_tensors="pt", - add_generation_prompt=False, - **kwargs, # Additional kwargs for the chat template, e.g. enable_thinking -): - """ - We want to tokenize the whole conversation. But we can't just simply - use get_prompt to get string prompt and tokenize it. Because the loss - can only be computed on model's response. We want: - input_ids - attention_mask - labels: should be -100 for user prompt and input id for model's response - action_mask: should be 0 for user prompt and 1 for model's response - :param messages: - :param tokenizer: - :param conv_template: - :param max_length: - :return: input_ids, attention_mask, labels, action_mask - """ - chat = Chat(template=template, messages=messages, tokenizer=tokenizer) - inputs = chat.tokenize( - tokenizer, - add_generation_prompt=add_generation_prompt, - tools=tools, - processor=processor, - **kwargs, - ) - - if max_length is not None: - inputs["input_ids"] = inputs["input_ids"][:, :max_length] - inputs["attention_mask"] = inputs["attention_mask"][:, :max_length] - if "labels" in inputs: - inputs["labels"] = inputs["labels"][:, :max_length] - if "action_mask" in inputs: - inputs["action_mask"] = inputs["action_mask"][:, :max_length] - - return inputs - - -def convert_inputs_to_vision_inputs( - template: str, - inputs: dict, - processor, # AutoProcessor (not bare tokenizer) - messages: list, -): - """ - NEW PIPELINE: Template processes messages → Human-readable prompt → Vision processor → LLM-ready inputs - - The correct pipeline is: - 1. Template processes messages to get human-readable prompt with single multi-modal tokens - 2. Vision processor handles image/video processing and token expansion - 3. 
Final result is directly usable by LLMs with model(**inputs) - """ - # Get the vision processor for this template - vision_processor = get_processor(template) - if vision_processor is None: - raise ValueError(f"No vision processor registered for template: {template}") - - # Step 1: Template processes messages to get human-readable prompt - from .templates import Chat - - chat = Chat(template=template, messages=messages, tokenizer=processor.tokenizer) - prompt = ( - chat.prompt() - ) # This gives us human-readable prompt with single multi-modal tokens - - # Step 2: Extract vision inputs from messages - images, videos = extract_vision_inputs_from_messages(messages) - - # Step 3: Vision processor handles the complete pipeline - # This expands tokens and generates LLM-ready inputs - final_inputs = vision_processor.process_for_llm( - prompt=prompt, - images=images, - videos=videos, - processor=processor, - tokenizer=processor.tokenizer, - ) - - return final_inputs - - -def extract_vision_inputs_from_messages(messages: list) -> tuple[list, list]: - """Extract images and videos from messages""" - images, videos = [], [] - - for message in messages: - if isinstance(message.get("content"), list): - for item in message["content"]: - if item.get("type") in ["image", "image_url"]: - if "image" in item: - images.append(item["image"]) - elif "image_url" in item: - images.append(item["image_url"]["url"]) - elif item.get("type") in ["video", "video_url"]: - if "video" in item: - videos.append(item["video"]) - elif "video_url" in item: - videos.append(item["video_url"]["url"]) - - return images, videos - - -def process_prompt_with_vision( - prompt: str, - template: str, - processor: Any, - images: list = None, - videos: list = None, -) -> dict: - """Process a prompt with vision support""" - vision_processor = get_processor(template) - if vision_processor is None: - # If no vision processor, just return tokenized prompt - return processor.tokenizer( - prompt, - return_tensors="pt", - add_special_tokens=True, - padding=True, - truncation=True, - ) - - # Use vision processor to handle the complete pipeline - return vision_processor.process_for_llm( - prompt=prompt, - images=images or [], - videos=videos or [], - processor=processor, - tokenizer=processor.tokenizer, - ) - - -def tokenize_conversations( - messages_list, - tokenizer, - template, - max_length=None, - processor=None, - return_tensors="pt", - return_reward_mask=False, - add_generation_prompt=False, - padding_side="right", - concatenate_mm_inputs=False, -): - batch_input_ids = [] - batch_attention_masks = [] - batch_labels = [] - batch_action_masks = [] - batch_mm_inputs = [] - # TODO: add multiprocessing - for messages in messages_list: - inputs = tokenize_conversation( - messages, - tokenizer, - template, - max_length, - processor=processor, - add_generation_prompt=add_generation_prompt, - ) - batch_input_ids.append(inputs["input_ids"].squeeze(0)) - batch_attention_masks.append(inputs["attention_mask"].squeeze(0)) - batch_labels.append(inputs["labels"].squeeze(0)) - batch_action_masks.append(inputs["action_mask"].squeeze(0)) - mm_inputs = {} - if "pixel_values" in inputs: - mm_inputs["pixel_values"] = inputs["pixel_values"] - else: - mm_inputs["pixel_values"] = None - if "image_grid_thw" in inputs: - mm_inputs["image_grid_thw"] = inputs["image_grid_thw"] - else: - mm_inputs["image_grid_thw"] = None - - batch_mm_inputs.append(mm_inputs) - - if return_tensors == "pt": - # Use pad_token_id from the tokenizer interface - pad_token_id = 
getattr(tokenizer, "pad_token_id", 0) - - batch_input_ids = torch.nn.utils.rnn.pad_sequence( - batch_input_ids, - batch_first=True, - padding_value=pad_token_id, - padding_side=padding_side, - ) - batch_attention_masks = torch.nn.utils.rnn.pad_sequence( - batch_attention_masks, - batch_first=True, - padding_value=0, - padding_side=padding_side, - ) - batch_labels = torch.nn.utils.rnn.pad_sequence( - batch_labels, - batch_first=True, - padding_value=-100, - padding_side=padding_side, - ) - batch_action_masks = torch.nn.utils.rnn.pad_sequence( - batch_action_masks, - batch_first=True, - padding_value=0, - padding_side=padding_side, - ) - - # convert [{"pixel_values": tensor, "image_grid_thw": tensor}, ...] to {"key1": concat_tensor, "key2": concat_tensor, ...} - concatenated_mm_inputs = {} - if concatenate_mm_inputs: - for key in batch_mm_inputs[0].keys(): - if isinstance(mm_inputs[key], torch.Tensor): - concatenated_mm_inputs[key] = torch.cat( - [ - mm_inputs[key] - for mm_inputs in batch_mm_inputs - if mm_inputs[key] is not None - ], - dim=0, - ) - - inputs = dict( - input_ids=batch_input_ids, - attention_mask=batch_attention_masks, - labels=batch_labels, - action_mask=batch_action_masks, - ) - - if return_reward_mask: - inputs["reward_mask"] = transform_reward_mask(batch_action_masks) - - # Check if we need mm_inputs - mm_keys = list(batch_mm_inputs[0].keys()) - return_mm_inputs = False - for key in mm_keys: - if any(mm_inputs[key] is not None for mm_inputs in batch_mm_inputs): - return_mm_inputs = True - break - - if return_mm_inputs: - if concatenate_mm_inputs: - inputs.update(concatenated_mm_inputs) - else: - inputs["mm_inputs"] = batch_mm_inputs - - return inputs - - -def visualize_template(template, messages=None, tools=None, **kwargs): - if not messages: - messages = [ - {"role": "user", "content": "Hello, how are you?"}, - {"role": "assistant", "content": "I am fine, thank you."}, - {"role": "user", "content": "Want to play a game?"}, - {"role": "assistant", "content": "Sure, what game?"}, - {"role": "user", "content": "Guess the number."}, - ] - - chat = Chat(template=template, messages=messages) - print(chat.prompt(tools=tools)) - print(chat.prompt_with_mask(tools=tools)) - - -def visualize_jinja_template(tokenizer, messages=None, tools=None, **kwargs): - if not messages: - messages = [ - {"role": "user", "content": "Hello, how are you?"}, - {"role": "assistant", "content": "I am fine, thank you."}, - {"role": "user", "content": "Want to play a game?"}, - {"role": "assistant", "content": "Sure, what game?"}, - {"role": "user", "content": "Guess the number."}, - ] - - prompt = tokenizer.apply_chat_template( - messages, tokenize=False, tools=tools, **kwargs - ) - print(prompt) - - -def compare_hf_template( - tokenizer, - template_name, - messages=None, - tools=None, - add_generation_prompt=False, - **kwargs, -): - official_prompt = tokenizer.apply_chat_template( - messages, - tokenize=False, - tools=tools, - add_generation_prompt=add_generation_prompt, - **kwargs, - ) - chat = Chat(template_name, messages=messages, tokenizer=tokenizer) - implemented_prompt = chat.prompt( - add_generation_prompt=add_generation_prompt, tools=tools, **kwargs - ) - is_equal = official_prompt == implemented_prompt - highlighted_prompt = chat.prompt_with_mask( - add_generation_prompt=add_generation_prompt, tools=tools, **kwargs - ) - plain_highlighted_prompt = strip_ansi(highlighted_prompt) - is_equal_between_implemented_prompts = ( - implemented_prompt == plain_highlighted_prompt - ) - jinja_template = 
chat.template.jinja_template() - - official_jinja_prompt = tokenizer.chat_template - tokenizer.chat_template = jinja_template - implemented_jinja_prompt = tokenizer.apply_chat_template( - messages, - tokenize=False, - tools=tools, - add_generation_prompt=add_generation_prompt, - **kwargs, - ) - is_equal_between_jinja_prompts = implemented_jinja_prompt == implemented_prompt - tokenizer.chat_template = official_jinja_prompt - return ( - is_equal, - is_equal_between_implemented_prompts, - is_equal_between_jinja_prompts, - official_prompt, - implemented_prompt, - implemented_jinja_prompt, - highlighted_prompt, - ) - - -def validate_messages_for_template( - template_name, messages, tools=None, add_generation_prompt=False -): - """Validate the messages for the given template.""" - if add_generation_prompt and messages[-1]["role"] == "assistant": - return False - - template = get_template(template_name) - - if tools and not template._supports_tool_call(): - return False - - if template_name in ["llama-3.2"]: - for message in messages: - if "tool_calls" in message and len(message["tool_calls"]) > 1: - return False - - return True diff --git a/src/agentfly/templates/vision_processor.py b/src/agentfly/templates/vision_processor.py deleted file mode 100644 index 4eb73fc..0000000 --- a/src/agentfly/templates/vision_processor.py +++ /dev/null @@ -1,777 +0,0 @@ -""" -Comprehensive multi-modal vision processor that handles vision processing separately from template processing. -The pipeline is: Template → Human-readable prompt → Vision processor → LLM-ready inputs. -""" - -import base64 -import inspect -import math -import urllib.parse -import urllib.request -from abc import ABC, abstractmethod -from dataclasses import dataclass -from io import BytesIO -from typing import ( - TYPE_CHECKING, - Any, - BinaryIO, - Dict, - List, - Literal, - Optional, - TypedDict, - Union, -) - -import numpy as np -import PIL -import torch -from PIL.Image import Image as ImageObject -from transformers.image_utils import get_image_size, to_numpy_array - -if TYPE_CHECKING: - from transformers import ProcessorMixin - - class EncodedImage(TypedDict): - path: Optional[str] - bytes: Optional[bytes] - - ImageInput = Union[str, bytes, EncodedImage, BinaryIO, "ImageObject"] - VideoInput = Union[str, BinaryIO, list[list[ImageInput]]] - - class MMProcessor(ProcessorMixin): - patch_size: int - image_seq_length: int - num_additional_image_tokens: int - vision_feature_select_strategy: Literal["default", "full"] - - def _get_number_of_features( - self, orig_height: int, orig_width: int, height: int, width: int - ) -> int: - pass - - -@dataclass -class VisionProcessorConfig: - """Configuration for vision processing""" - - model_type: str - image_token: str - video_token: str - vision_start: str = "" - vision_end: str = "" - processor_class: str = "AutoProcessor" - expansion_strategy: str = "patch_based" - image_max_pixels: int = 16384 * 28 * 28 - image_min_pixels: int = 4 * 28 * 28 - video_max_pixels: int = 16384 * 28 * 28 - video_min_pixels: int = 4 * 28 * 28 - video_fps: float = 2.0 - video_maxlen: int = 128 - - -class VisionProcessor(ABC): - """Abstract base class for vision processing strategies""" - - def __init__(self, config: VisionProcessorConfig): - self.config = config - self._validate_config() - - def _validate_config(self): - """Validate the vision configuration""" - required_fields = ["image_token", "video_token"] - for field in required_fields: - if not hasattr(self.config, field) or getattr(self.config, field) is None: - 
raise ValueError(f"Missing required field: {field}") - - @abstractmethod - def preprocess_images( - self, images: List["ImageInput"], processor: Any - ) -> Dict[str, Any]: - """Preprocess images for the model""" - pass - - @abstractmethod - def preprocess_videos( - self, videos: List["VideoInput"], processor: Any - ) -> Dict[str, Any]: - """Preprocess videos for the model""" - pass - - @abstractmethod - def calculate_image_tokens(self, image_data: Dict[str, Any], processor: Any) -> int: - """Calculate the number of tokens needed for an image""" - pass - - @abstractmethod - def calculate_video_tokens(self, video_data: Dict[str, Any], processor: Any) -> int: - """Calculate the number of tokens needed for a video""" - pass - - @abstractmethod - def expand_vision_tokens( - self, - prompt: str, - images: List["ImageInput"], - videos: List["VideoInput"], - processor: Optional[Any], - ) -> str: - """Expand vision tokens in the prompt to their actual token representations""" - pass - - @abstractmethod - def get_mm_inputs( - self, - images: List["ImageInput"], - videos: List["VideoInput"], - processor: Optional[Any], - ) -> Dict[str, torch.Tensor]: - """Generate multi-modal inputs for the model""" - pass - - def process_vision_info(self, messages: List[Dict]) -> Dict[str, torch.Tensor]: - """Process vision information from messages""" - pass - - # def process_for_llm( - # self, - # prompt: str, - # images: List["ImageInput"], - # videos: List["VideoInput"], - # processor: Optional[Any], - # tokenizer: Any, - # ) -> Dict[str, torch.Tensor]: - # """ - # Complete pipeline: expand tokens and generate LLM-ready inputs. - # Returns inputs that can be used directly with model(**inputs). - # """ - # # Step 1: Expand vision tokens in the prompt - # expanded_prompt = self.expand_vision_tokens(prompt, images, videos, processor) - - # # Step 2: Tokenize the expanded prompt - # tokenized_inputs = tokenizer( - # expanded_prompt, - # return_tensors="pt", - # add_special_tokens=True, - # padding=True, - # truncation=True - # ) - - # # Step 3: Generate multi-modal inputs - # mm_inputs = self.get_mm_inputs(images, videos, processor) - - # # Step 4: Combine tokenized inputs with multi-modal inputs - # final_inputs = {**tokenized_inputs, **mm_inputs} - - # return final_inputs - - def process_for_llm( - self, - prompt: str, - elements: List[str], - mask_flags: List[bool], - images: List["ImageInput"], - videos: List["VideoInput"], - processor: Any, - tokenizer: Any, - return_tensors: str = None, - ) -> Dict[str, torch.Tensor]: - """ - Process with proper alignment of all tensors (input_ids, attention_mask, labels, action_mask). - This ensures that when vision tokens are expanded, all corresponding tensors are expanded - at the same positions, maintaining proper alignment for training and inference. 
- """ - import torch - - # Step 1: Tokenize elements to get base tensors with proper alignment - input_ids = [] - attention_mask = [] - labels = [] - action_mask = [] - - # Add BOS token if needed - if tokenizer.bos_token and tokenizer.add_bos_token: - input_ids.append(tokenizer.bos_token_id) - attention_mask.append(1) - labels.append(-100) - action_mask.append(0) - - images_to_process = [image for image in images] - videos_to_process = [video for video in videos] - # Step 2: Process each element with vision token expansion - for element, mask_flag in zip(elements, mask_flags): - # Check if element contains vision tokens - if self._contains_vision_tokens(element): - # Expand vision tokens in this element - # Number of images and videos should be equal to the total number of vision tokens in the element - # We check whether all images and videos are processed later. - expanded_element = self.expand_vision_tokens( - element, images_to_process, videos_to_process, processor - ) - cur_input_ids = tokenizer.encode( - expanded_element, add_special_tokens=False - ) - else: - cur_input_ids = tokenizer.encode(element, add_special_tokens=False) - - # Add tokens with proper alignment - input_ids.extend(cur_input_ids) - attention_mask.extend([1] * len(cur_input_ids)) - - if mask_flag: - labels.extend([-100] * len(cur_input_ids)) - action_mask.extend([0] * len(cur_input_ids)) - else: - labels.extend(cur_input_ids) - action_mask.extend([1] * len(cur_input_ids)) - - assert len(images_to_process) == len(videos_to_process) == 0, ( - f"All images and videos should be processed, but got {len(images_to_process)} images and {len(videos_to_process)} videos left for vision template {self.config.model_type}." - ) - - # Step 3: Create base inputs - inputs = { - "input_ids": input_ids, - "attention_mask": attention_mask, - "labels": labels, - "action_mask": action_mask, - } - - # Convert to tensors if requested - if return_tensors == "pt": - inputs = {k: torch.tensor([v]) for k, v in inputs.items()} - - # Step 4: Add vision inputs - mm_inputs = self.get_mm_inputs(images, videos, processor) - inputs.update(mm_inputs) - - return inputs - - def _contains_vision_tokens(self, text: str) -> bool: - """Check if text contains vision tokens""" - return self.config.image_token in text or self.config.video_token in text - - -class PatchBasedProcessor(VisionProcessor): - """Patch-based vision processor (used by Qwen-VL, LLaVA, etc.) - - Supports multiple image input formats: - - File paths (str): "/path/to/image.jpg" - - URLs (str): "https://example.com/image.jpg" - - Base64 strings (str): "data:image/jpeg;base64,/9j/4AAQ..." 
or raw base64 - - PIL Image objects - - Bytes objects - - File-like objects - - Dict format: {"path": "/path/to/image.jpg"} or {"bytes": b"image_data"} - """ - - def _load_image_from_input(self, image_input) -> "ImageObject": - """Load image from various input formats including URL and base64""" - - # Handle PIL Image objects directly - if hasattr(image_input, "width") and hasattr(image_input, "height"): - return image_input - - # Handle string inputs (file path, URL, or base64) - if isinstance(image_input, str): - # Check if it's a URL - if image_input.startswith(("http://", "https://")): - try: - with urllib.request.urlopen(image_input) as response: - image_data = response.read() - return PIL.Image.open(BytesIO(image_data)) - except Exception as e: - raise ValueError( - f"Failed to load image from URL {image_input}: {e}" - ) - - # Check if it's a base64 string - elif image_input.startswith("data:image/") or image_input.startswith( - "data:application/octet-stream" - ): - # Handle data URL format: data:image/jpeg;base64,/9j/4AAQ... - try: - # Extract the base64 part after the comma - base64_data = image_input.split(",", 1)[1] - image_data = base64.b64decode(base64_data) - return PIL.Image.open(BytesIO(image_data)) - except Exception as e: - raise ValueError(f"Failed to decode base64 image: {e}") - - elif image_input.startswith("iVBORw0KGgo") or len(image_input) > 100: - # Likely a raw base64 string (common for PNG images starting with iVBORw0KGgo) - try: - image_data = base64.b64decode(image_input) - return PIL.Image.open(BytesIO(image_data)) - except Exception as e: - raise ValueError(f"Failed to decode base64 image: {e}") - - # Assume it's a file path - else: - print(f"Loading image from file path: {image_input}") - return PIL.Image.open(image_input) - - # Handle bytes - elif isinstance(image_input, bytes): - return PIL.Image.open(BytesIO(image_input)) - - # Handle file-like objects - elif hasattr(image_input, "read"): - return PIL.Image.open(image_input) - - # Handle dict format - elif isinstance(image_input, dict): - if image_input.get("bytes") is not None: - return PIL.Image.open(BytesIO(image_input["bytes"])) - elif image_input.get("path") is not None: - return PIL.Image.open(image_input["path"]) - else: - raise ValueError("Invalid image dict format") - - else: - raise ValueError(f"Unsupported image input type: {type(image_input)}") - - def _preprocess_single_image(self, image: "ImageObject", **kwargs) -> "ImageObject": - """Preprocess a single image""" - if (image.width * image.height) > self.config.image_max_pixels: - resize_factor = math.sqrt( - self.config.image_max_pixels / (image.width * image.height) - ) - width, height = ( - int(image.width * resize_factor), - int(image.height * resize_factor), - ) - image = image.resize((width, height)) - - if (image.width * image.height) < self.config.image_min_pixels: - resize_factor = math.sqrt( - self.config.image_min_pixels / (image.width * image.height) - ) - width, height = ( - int(image.width * resize_factor), - int(image.height * resize_factor), - ) - image = image.resize((width, height)) - - if image.mode != "RGB": - image = image.convert("RGB") - - return image - - def _regularize_images(self, images: List["ImageInput"]) -> List["ImageObject"]: - """Regularize images to avoid errors""" - results = [] - for image in images: - # Use the new helper method to handle all input formats - pil_image = self._load_image_from_input(image) - results.append(self._preprocess_single_image(pil_image)) - - return results - - def 
_regularize_videos( - self, videos: List["VideoInput"] - ) -> List[List["ImageObject"]]: - """Regularize videos to avoid errors""" - results = [] - for video in videos: - frames: List["ImageObject"] = [] - - # Check if video is nested images - if isinstance(video, list) and all( - isinstance(frame, (str, BinaryIO, dict)) for frame in video - ): - # Use the new image loading method for each frame - for frame in video: - try: - pil_image = self._load_image_from_input(frame) - frames.append(pil_image) - except Exception as e: - raise ValueError(f"Invalid image found in video frames: {e}") - else: - # Process actual video file - import av - - container = av.open(video, "r") - video_stream = next( - stream for stream in container.streams if stream.type == "video" - ) - - # Calculate sample indices - total_frames = video_stream.frames - if total_frames == 0: # infinite video - sample_indices = np.linspace( - 0, self.config.video_maxlen - 1, self.config.video_maxlen - ).astype(np.int32) - else: - sample_frames = max( - 1, - math.floor( - float(video_stream.duration * video_stream.time_base) - * self.config.video_fps - ), - ) - sample_frames = min( - total_frames, self.config.video_maxlen, sample_frames - ) - sample_indices = np.linspace( - 0, total_frames - 1, sample_frames - ).astype(np.int32) - - container.seek(0) - for frame_idx, frame in enumerate(container.decode(video_stream)): - if frame_idx in sample_indices: - frames.append(frame.to_image()) - - frames = self._regularize_images(frames) - results.append(frames) - - return results - - def preprocess_images( - self, images: List["ImageInput"], processor: Any - ) -> Dict[str, Any]: - """Preprocess images for the model""" - if not images: - return {} - - image_processor = getattr(processor, "image_processor", None) - if image_processor is None: - raise ValueError("Image processor not found") - - # images = self._regularize_images(images) - return image_processor(images, return_tensors="pt") - - def preprocess_videos( - self, videos: List["VideoInput"], processor: Any - ) -> Dict[str, Any]: - """Preprocess videos for the model""" - if not videos: - return {} - - video_processor = getattr( - processor, "video_processor", getattr(processor, "image_processor", None) - ) - if video_processor is None: - raise ValueError("Video processor not found") - - videos = self._regularize_videos(videos) - - # Handle different video processor interfaces - if "videos" in inspect.signature(video_processor.preprocess).parameters: - return video_processor(images=None, videos=videos, return_tensors="pt") - else: - return video_processor(videos, return_tensors="pt") - - def calculate_image_tokens(self, image_data: Dict[str, Any], processor: Any) -> int: - """Calculate the number of tokens needed for an image - - Uses two approaches: - 1. Grid-based (HuggingFace method): Uses image_grid_thw and merge_size - - More accurate for models like Qwen-VL - - Accounts for hierarchical token merging - 2. 
Patch-based (fallback): Uses image dimensions and patch_size - - Standard approach for most ViT-based models - - Assumes each patch corresponds to one token - """ - if "pixel_values" in image_data: - # Try grid-based calculation first (HuggingFace method) - if "image_grid_thw" in image_data: - grid_info = image_data["image_grid_thw"] - if isinstance(grid_info, torch.Tensor): - grid_prod = grid_info.prod().item() - elif isinstance(grid_info, list): - grid_prod = math.prod(grid_info) - else: - grid_prod = grid_info - - # Get merge_size from processor - merge_size = getattr(processor, "merge_size", 1) - merge_length = merge_size**2 - - num_image_tokens = grid_prod // merge_length - return max(1, num_image_tokens) - - # Fallback to patch-based calculation - height, width = get_image_size( - to_numpy_array(image_data["pixel_values"][0]) - ) - image_seqlen = (height // processor.patch_size) * ( - width // processor.patch_size - ) - if hasattr(processor, "num_additional_image_tokens"): - image_seqlen += processor.num_additional_image_tokens - if ( - hasattr(processor, "vision_feature_select_strategy") - and processor.vision_feature_select_strategy == "default" - ): - image_seqlen -= 1 - return image_seqlen - return 1 - - def calculate_video_tokens(self, video_data: Dict[str, Any], processor: Any) -> int: - """Calculate the number of tokens needed for a video""" - if "pixel_values" in video_data: - # For videos, we need to calculate based on frames - video_tensor = video_data["pixel_values"][0] - if len(video_tensor.shape) > 3: # Has frame dimension - num_frames = video_tensor.shape[0] - height, width = get_image_size(to_numpy_array(video_tensor[0])) - frame_seqlen = (height // processor.patch_size) * ( - width // processor.patch_size - ) - if hasattr(processor, "num_additional_image_tokens"): - frame_seqlen += processor.num_additional_image_tokens - if ( - hasattr(processor, "vision_feature_select_strategy") - and processor.vision_feature_select_strategy == "default" - ): - frame_seqlen -= 1 - return frame_seqlen * num_frames - else: - # Single frame video - return self.calculate_image_tokens(video_data, processor) - return 1 - - def expand_vision_tokens( - self, - prompt: str, - images: List["ImageInput"], - videos: List["VideoInput"], - processor: Optional[Any], - ) -> str: - """Expand vision tokens in the prompt to their actual token representations""" - if processor is None: - raise ValueError("Processor is required for vision processing") - - # Validate that number of placeholders matches number of inputs - num_image_placeholders = prompt.count(self.config.image_token) - num_video_placeholders = prompt.count(self.config.video_token) - - # if len(images) != num_image_placeholders: - # raise ValueError(f"Number of images ({len(images)}) doesn't match placeholders ({num_image_placeholders})") - # if len(videos) != num_video_placeholders: - # raise ValueError(f"Number of videos ({len(videos)}) doesn't match placeholders ({num_video_placeholders})") - images_slice = [images.pop(0) for _ in range(num_image_placeholders)] - videos_slice = [videos.pop(0) for _ in range(num_video_placeholders)] - # Preprocess images and videos to get individual token counts - - processed_images = [ - self.preprocess_images([image], processor) for image in images_slice - ] - processed_videos = [ - self.preprocess_videos([video], processor) for video in videos_slice - ] - - expanded_prompt = prompt - if self.config.image_token in expanded_prompt and processed_images: - parts = 
expanded_prompt.split(self.config.image_token) - expanded_parts = [parts[0]] - for idx in range(len(parts) - 1): - if idx < len(processed_images): - processed_image = processed_images[idx] - if "pixel_values" in processed_image: - image_tokens = self.calculate_image_tokens( - processed_image, processor - ) - replacement = self.config.image_token * image_tokens - else: - replacement = self.config.image_token - else: - replacement = self.config.image_token - expanded_parts.append(replacement) - expanded_parts.append(parts[idx + 1]) - expanded_prompt = "".join(expanded_parts) - - # Expand video tokens sequentially - each token gets replaced with its corresponding video - if self.config.video_token in expanded_prompt and processed_videos: - parts = expanded_prompt.split(self.config.video_token) - expanded_parts = [parts[0]] - for idx in range(len(parts) - 1): - if idx < len(processed_videos): - processed_video = processed_videos[idx] - if "pixel_values" in processed_video: - video_tokens = self.calculate_video_tokens( - processed_video, processor - ) - replacement = self.config.video_token * video_tokens - else: - replacement = self.config.video_token - else: - replacement = self.config.video_token - expanded_parts.append(replacement) - expanded_parts.append(parts[idx + 1]) - expanded_prompt = "".join(expanded_parts) - - return expanded_prompt - - def get_mm_inputs( - self, - images: List["ImageInput"], - videos: List["VideoInput"], - processor: Optional[Any], - ) -> Dict[str, torch.Tensor]: - """Generate multi-modal inputs for the model""" - mm_inputs = {} - - # Process images - if images: - mm_inputs.update(self.preprocess_images(images, processor)) - - # Process videos - if videos: - mm_inputs.update(self.preprocess_videos(videos, processor)) - - return mm_inputs - - def process_vision_info(self, messages: List[Dict], processor: Any): - """Process vision information from messages""" - image_message_types = ["image", "image_url", "image_base64"] - images = [] - for message in messages: - for content in message["content"]: - if content["type"] in image_message_types: - content_type = content["type"] - images.append(content[content_type]) - mm_inputs = self.get_mm_inputs(images, [], processor) - return mm_inputs - - -class QwenVLProcessor(PatchBasedProcessor): - """Qwen-VL specific processor with custom image preprocessing""" - - def _preprocess_single_image(self, image: "ImageObject", **kwargs) -> "ImageObject": - """Qwen-VL specific image preprocessing""" - image = super()._preprocess_single_image(image, **kwargs) - - # Qwen-VL specific adjustments - if min(image.width, image.height) < 28: - width, height = max(image.width, 28), max(image.height, 28) - image = image.resize((width, height)) - - if image.width / image.height > 200: - width, height = image.height * 180, image.height - image = image.resize((width, height)) - - if image.height / image.width > 200: - width, height = image.width, image.width * 180 - image = image.resize((width, height)) - - return image - - def calculate_image_tokens(self, image_data: Dict[str, Any], processor: Any) -> int: - """Qwen-VL specific token calculation using grid-based approach""" - if "image_grid_thw" in image_data: - # Use grid information for more accurate token calculation - grid_info = image_data["image_grid_thw"] - if isinstance(grid_info, torch.Tensor): - grid_prod = grid_info.prod().item() - elif isinstance(grid_info, list): - grid_prod = math.prod(grid_info) - else: - grid_prod = grid_info - - # Get merge_size from processor (Qwen-VL typically 
uses merge_size=2) - merge_size = getattr(processor, "merge_size", 2) - merge_length = merge_size**2 - - num_image_tokens = grid_prod // merge_length - return max(1, num_image_tokens) - - # Fallback to standard calculation - return super().calculate_image_tokens(image_data, processor) - - def expand_vision_tokens( - self, - prompt: str, - images: List["ImageInput"], - videos: List["VideoInput"], - processor: Optional[Any], - ) -> str: - """Qwen-VL specific token expansion with vision tags""" - expanded_prompt = super().expand_vision_tokens( - prompt, images, videos, processor - ) - - return expanded_prompt - - -class LlavaProcessor(PatchBasedProcessor): - """LLaVA specific processor""" - - def calculate_image_tokens(self, image_data: Dict[str, Any], processor: Any) -> int: - """LLaVA specific token calculation""" - if "pixel_values" in image_data: - height, width = get_image_size( - to_numpy_array(image_data["pixel_values"][0]) - ) - image_seqlen = (height // processor.patch_size) * ( - width // processor.patch_size - ) - if hasattr(processor, "num_additional_image_tokens"): - image_seqlen += processor.num_additional_image_tokens - if ( - hasattr(processor, "vision_feature_select_strategy") - and processor.vision_feature_select_strategy == "default" - ): - image_seqlen -= 1 - return image_seqlen - return 1 - - -VISION_PROCESSORS: Dict[str, VisionProcessor] = {} - -model_type_to_processor_class = { - "qwen_vl": QwenVLProcessor, - "llava": LlavaProcessor, - "gemma3": PatchBasedProcessor, - "paligemma": PatchBasedProcessor, - "internvl": PatchBasedProcessor, - "minicpm": PatchBasedProcessor, - "mllama": PatchBasedProcessor, - "pixtral": PatchBasedProcessor, - "video_llava": PatchBasedProcessor, - "patch_based": PatchBasedProcessor, -} - - -def register_processor(template_name: str, config: VisionProcessorConfig): - """Register a vision processor for a template""" - processor_class = model_type_to_processor_class.get(config.model_type) - if processor_class is None: - raise ValueError( - f"No processor class found for model type: {config.model_type}" - ) - VISION_PROCESSORS[template_name] = processor_class(config) - - -def register( - cls, template_name: str, config: VisionProcessorConfig, processor_class: type = None -): - """Register a vision processor for a template""" - if processor_class is not None: - # If processor_class is provided, use it directly - VISION_PROCESSORS[template_name] = processor_class(config) - else: - # Use the global register_processor function - register_processor(template_name, config) - - -def get_processor(template_name: str) -> Optional[VisionProcessor]: - """Get vision processor for a template""" - return VISION_PROCESSORS.get(template_name) - - -def get_processor_config(template_name: str) -> Optional[VisionProcessorConfig]: - """Get vision config for a template""" - processor = get_processor(template_name) - return processor.config if processor else None - - -def is_vision_template(template_name: str) -> bool: - """Check if template supports vision""" - return template_name in VISION_PROCESSORS - - -def list_vision_templates() -> List[str]: - """List all vision-enabled templates""" - return list(VISION_PROCESSORS.keys()) diff --git a/verl b/verl index 3adfe49..a6b57f2 160000 --- a/verl +++ b/verl @@ -1 +1 @@ -Subproject commit 3adfe499def9eabcc1a50d2b628a25a8aaa3aceb +Subproject commit a6b57f2659a44d0ccba543dcabbc6bd5425b3689
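
---
Note (not part of the patch): a minimal sketch of what the removed reward-mask helpers in src/agentfly/templates/utils.py compute. The action-mask values below are illustrative, and only torch is assumed; transform_multi_turn_reward_mask keeps each 1 whose next position is 0 (the last token of every response segment), while transform_reward_mask keeps only the rightmost 1 per row.

    import torch

    # 1 marks model-response tokens; two illustrative rows.
    action_mask = torch.tensor([
        [0, 1, 1, 0, 1, 1, 0],
        [1, 1, 0, 0, 0, 0, 0],
    ])

    # Multi-turn variant: a reward position at the end of each response segment.
    shifted = torch.cat(
        [action_mask[:, 1:], torch.zeros(2, 1, dtype=action_mask.dtype)], dim=1
    )
    multi_turn = ((action_mask == 1) & (shifted == 0)).int()
    # -> [[0, 0, 1, 0, 0, 1, 0],
    #     [0, 1, 0, 0, 0, 0, 0]]

    # Single-reward variant: only the final response token in each row.
    last_idx = action_mask.shape[1] - 1 - torch.argmax(action_mask.flip(dims=[1]), dim=1)
    single = torch.zeros_like(action_mask)
    single[torch.arange(2), last_idx] = 1
    # -> [[0, 0, 0, 0, 0, 1, 0],
    #     [0, 1, 0, 0, 0, 0, 0]]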
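
Likewise, a hedged illustration of the grid-based image-token count in the removed calculate_image_tokens: the product of image_grid_thw is divided by merge_size ** 2. The grid values and merge factor below are made up for the example, not taken from any model config.

    import math

    image_grid_thw = [1, 34, 46]   # illustrative (temporal, height, width) patch grid
    merge_size = 2                 # Qwen-VL-style spatial merge factor
    num_image_tokens = math.prod(image_grid_thw) // merge_size ** 2
    # 1 * 34 * 46 = 1564 patches -> 1564 // 4 = 391 image tokens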