From 7cb2d0ee185561f5d65d0907a2c410400f932212 Mon Sep 17 00:00:00 2001 From: Shu Date: Thu, 22 Jan 2026 21:07:19 -0800 Subject: [PATCH 1/2] Add chess puzzle solving environment, tools, and rewards - ChessPuzzleEnv: Non-Docker environment using python-chess + Stockfish - Supports Lichess puzzle format (FEN + solution moves) - Move parsing in both UCI and SAN formats - Board visualization and state tracking - Chess tools: chess_move, chess_get_state, chess_get_legal_moves - Stateful tools with environment pooling (pool_size=8) - Chess rewards: - chess_puzzle_reward: Dense reward based on Stockfish centipawn evaluation - chess_puzzle_reward_simple: Binary reward (solved/not solved) - Data utilities: Load puzzles from Lichess CSV or JSONL format - Unit tests for the chess environment --- agentfly/envs/__init__.py | 1 + agentfly/envs/chess_env.py | 502 +++++++++++++++++++++ agentfly/rewards/__init__.py | 1 + agentfly/rewards/chess_reward.py | 169 +++++++ agentfly/tests/unit/envs/test_chess_env.py | 217 +++++++++ agentfly/tools/__init__.py | 6 +- agentfly/tools/src/chess/__init__.py | 4 + agentfly/tools/src/chess/tools.py | 112 +++++ agentfly/utils/__init__.py | 9 + agentfly/utils/chess_puzzles.py | 325 +++++++++++++ 10 files changed, 1345 insertions(+), 1 deletion(-) create mode 100644 agentfly/envs/chess_env.py create mode 100644 agentfly/rewards/chess_reward.py create mode 100644 agentfly/tests/unit/envs/test_chess_env.py create mode 100644 agentfly/tools/src/chess/__init__.py create mode 100644 agentfly/tools/src/chess/tools.py create mode 100644 agentfly/utils/chess_puzzles.py diff --git a/agentfly/envs/__init__.py b/agentfly/envs/__init__.py index 5e7bbeb..2316e6f 100644 --- a/agentfly/envs/__init__.py +++ b/agentfly/envs/__init__.py @@ -2,4 +2,5 @@ from .alfworld_env import ALFWorldEnv from .webshop_text_env import WebAgentTextEnv from .scienceworld_env import ScienceWorldEnv +from .chess_env import ChessPuzzleEnv from .manager.enroot import clear_enroot_containers diff --git a/agentfly/envs/chess_env.py b/agentfly/envs/chess_env.py new file mode 100644 index 0000000..0558c72 --- /dev/null +++ b/agentfly/envs/chess_env.py @@ -0,0 +1,502 @@ +# chess_env.py +""" +Chess puzzle environment using python-chess and Stockfish engine. +Unlike Docker-based environments, this runs locally with python-chess library. +""" + +import asyncio +import chess +import chess.engine +from typing import Any, Dict, List, Optional, Tuple, Union +from .env_base import BaseEnv + + +class ChessPuzzleEnv(BaseEnv): + """ + Chess puzzle environment using python-chess and Stockfish. + + This is a non-Docker environment that runs locally with: + - python-chess for board state management and move validation + - Stockfish engine for position evaluation and best move analysis + + Puzzle Format (Lichess-style): + - FEN: starting position + - Moves: solution moves in UCI format (e.g., "e2e4 e7e5 g1f3") + - The first move is the opponent's move that sets up the puzzle + - Subsequent moves are the solution the agent must find + + Attributes: + stockfish_path (str): Path to the Stockfish binary + analysis_time (float): Time in seconds for Stockfish analysis per position + analysis_depth (int): Depth for Stockfish analysis + max_moves (int): Maximum moves allowed per puzzle + + Example: + ```python + env = ChessPuzzleEnv() + await env.start() + obs = await env.reset({ + "puzzle_id": "test1", + "fen": "r1bqkb1r/pppp1ppp/2n2n2/4p2Q/2B1P3/8/PPPP1PPP/RNB1K1NR w KQkq - 4 4", + "moves": "h5f7" + }) + result = await env.step("Qxf7") # or "h5f7" in UCI + await env.aclose() + ``` + """ + + def __init__( + self, + stockfish_path: str = "/opt/homebrew/bin/stockfish", + analysis_time: float = 0.1, + analysis_depth: int = 20, + max_moves: int = 20, + ): + """ + Initialize the chess puzzle environment. + + Args: + stockfish_path: Path to the Stockfish binary. Common paths: + - macOS (Homebrew): /opt/homebrew/bin/stockfish + - Linux: /usr/bin/stockfish or /usr/games/stockfish + - Windows: C:\\path\\to\\stockfish.exe + analysis_time: Time in seconds for each Stockfish analysis + analysis_depth: Depth for Stockfish search (higher = stronger but slower) + max_moves: Maximum number of moves allowed per puzzle + """ + super().__init__() + self.stockfish_path = stockfish_path + self.analysis_time = analysis_time + self.analysis_depth = analysis_depth + self.max_moves = max_moves + + # Engine and board state + self._engine: Optional[chess.engine.SimpleEngine] = None + self._board: Optional[chess.Board] = None + + # Puzzle state + self._puzzle_id: Optional[str] = None + self._puzzle_fen: Optional[str] = None + self._solution_moves: List[str] = [] + self._current_solution_idx: int = 0 + self._moves_made: List[str] = [] + self._is_solved: bool = False + + async def start(self) -> None: + """ + Start the Stockfish engine process. + + Unlike Docker-based environments, this simply spawns the Stockfish + subprocess using python-chess's engine API. + + Raises: + FileNotFoundError: If Stockfish binary is not found at the specified path + chess.engine.EngineTerminatedError: If the engine fails to start + """ + loop = asyncio.get_running_loop() + self._engine = await loop.run_in_executor( + None, + chess.engine.SimpleEngine.popen_uci, + self.stockfish_path + ) + self._board = chess.Board() + + async def reset(self, env_args: Optional[Dict[str, Any]] = None) -> str: + """ + Reset to a new puzzle. + + Args: + env_args: Dictionary with puzzle data: + - puzzle_id (str): Unique puzzle identifier + - fen (str): Starting FEN position + - moves (str): Space-separated UCI moves for the solution. + First move is opponent's setup move (auto-played). + + Returns: + Initial observation (board state as text) + + Example: + ```python + obs = await env.reset({ + "puzzle_id": "abc123", + "fen": "r1bqkb1r/pppp1ppp/2n2n2/4p2Q/2B1P3/8/PPPP1PPP/RNB1K1NR w KQkq - 4 4", + "moves": "h5f7" # White mates with Qxf7# + }) + ``` + """ + if env_args is None: + # Default puzzle for testing: Scholar's mate position + env_args = { + "puzzle_id": "default", + "fen": "r1bqkb1r/pppp1ppp/2n2n2/4p2Q/2B1P3/8/PPPP1PPP/RNB1K1NR w KQkq - 4 4", + "moves": "h5f7" + } + + self._puzzle_id = env_args.get("puzzle_id", "unknown") + self._puzzle_fen = env_args["fen"] + self._solution_moves = env_args.get("moves", "").split() + self._current_solution_idx = 0 + self._moves_made = [] + self._is_solved = False + + # Set up the board + self._board = chess.Board(self._puzzle_fen) + + # In Lichess puzzles, the first move is typically the opponent's move + # that sets up the puzzle. For single-move puzzles (like mate-in-1), + # we don't auto-play the first move since the agent needs to find it. + # We'll detect this by checking if the solution has more than 1 move. + if len(self._solution_moves) > 1: + # Play the setup move (opponent's move) + try: + setup_move = chess.Move.from_uci(self._solution_moves[0]) + if setup_move in self._board.legal_moves: + self._board.push(setup_move) + self._current_solution_idx = 1 + except ValueError: + pass # Invalid move format, skip + + return self._get_observation() + + async def step(self, action: str) -> Union[str, Dict[str, Any]]: + """ + Execute an action in the chess environment. + + Args: + action: Either: + - A chess move in UCI format (e.g., "e2e4") or SAN (e.g., "e4", "Nf3") + - "get_state" to get current board state + - "get_legal_moves" to list all legal moves + - "get_reward" to get current evaluation + + Returns: + Observation string with result of the action, or dict for get_reward + """ + action = action.strip() + action_lower = action.lower() + + if action_lower == "get_state": + return self._get_observation() + + if action_lower == "get_legal_moves": + return self._get_legal_moves_text() + + if action_lower == "get_reward": + return await self._get_evaluation() + + # Try to parse and make the move + return await self._make_move(action) + + async def _make_move(self, move_str: str) -> str: + """Process a move from the agent.""" + if self._is_solved: + return "Puzzle already solved! No more moves needed." + + if self._board.is_game_over(): + return f"Game is over. Result: {self._board.result()}" + + if len(self._moves_made) >= self.max_moves: + return f"Maximum moves ({self.max_moves}) reached." + + # Parse the move (try UCI first, then SAN) + move = self._parse_move(move_str) + if move is None: + return f"Invalid move format: '{move_str}'. Legal moves: {self._get_legal_moves_text()}" + + if move not in self._board.legal_moves: + return f"Illegal move: '{move_str}'. Legal moves: {self._get_legal_moves_text()}" + + # Make the move + san_move = self._board.san(move) # Get SAN before pushing + self._board.push(move) + self._moves_made.append(move.uci()) + + # Check if this matches the solution + expected_move = None + if self._current_solution_idx < len(self._solution_moves): + expected_move = self._solution_moves[self._current_solution_idx] + + is_correct = expected_move and move.uci() == expected_move + + if is_correct: + self._current_solution_idx += 1 + + # Check if puzzle is solved + if self._current_solution_idx >= len(self._solution_moves): + self._is_solved = True + return f"Correct! {san_move} - Puzzle solved!\n\n{self._get_observation()}" + + # Play the opponent's response (next move in solution) + if self._current_solution_idx < len(self._solution_moves): + try: + response_uci = self._solution_moves[self._current_solution_idx] + response_move = chess.Move.from_uci(response_uci) + if response_move in self._board.legal_moves: + response_san = self._board.san(response_move) + self._board.push(response_move) + self._current_solution_idx += 1 + + # Check again if solved after response + if self._current_solution_idx >= len(self._solution_moves): + self._is_solved = True + return f"Correct! {san_move}\nOpponent played: {response_san}\nPuzzle solved!\n\n{self._get_observation()}" + + return f"Correct! {san_move}\nOpponent played: {response_san}\nYour turn to continue.\n\n{self._get_observation()}" + except ValueError: + pass + + return f"Correct! {san_move}\n\n{self._get_observation()}" + else: + # Wrong move - still allow it but note it's not the solution + return f"Move played: {san_move} (not the puzzle solution)\n\n{self._get_observation()}" + + def _parse_move(self, move_str: str) -> Optional[chess.Move]: + """Parse a move string in UCI or SAN format.""" + move_str = move_str.strip() + + # Try UCI format first (e.g., "e2e4", "e7e8q" for promotion) + try: + return chess.Move.from_uci(move_str.lower()) + except ValueError: + pass + + # Try SAN format (e.g., "e4", "Nf3", "O-O", "Qxf7#") + try: + return self._board.parse_san(move_str) + except ValueError: + pass + + return None + + def _get_observation(self) -> str: + """Generate a text observation of the current board state.""" + parts = [ + f"FEN: {self._board.fen()}", + "", + str(self._board), # ASCII board representation + "", + f"Turn: {'White' if self._board.turn else 'Black'}", + f"Legal moves: {len(list(self._board.legal_moves))} available", + ] + + if self._board.is_check(): + parts.append("Status: CHECK!") + + if self._is_solved: + parts.append("Status: PUZZLE SOLVED!") + elif self._board.is_checkmate(): + winner = "Black" if self._board.turn else "White" + parts.append(f"Status: CHECKMATE! {winner} wins.") + elif self._board.is_stalemate(): + parts.append("Status: STALEMATE! Draw.") + elif self._board.is_insufficient_material(): + parts.append("Status: Draw by insufficient material.") + + parts.append(f"\nMoves played: {', '.join(self._moves_made) if self._moves_made else 'None'}") + + return "\n".join(parts) + + def _get_legal_moves_text(self) -> str: + """Get legal moves as a formatted string.""" + moves = [] + for move in self._board.legal_moves: + san = self._board.san(move) + moves.append(f"{move.uci()} ({san})") + return ", ".join(sorted(moves)) if moves else "No legal moves" + + async def _get_evaluation(self) -> Dict[str, Any]: + """Get Stockfish evaluation of current position.""" + if self._engine is None: + return {"observation": "Engine not available", "reward": 0.0} + + loop = asyncio.get_running_loop() + try: + info = await loop.run_in_executor( + None, + lambda: self._engine.analyse( + self._board, + chess.engine.Limit(time=self.analysis_time, depth=self.analysis_depth) + ) + ) + + score = info.get("score") + if score: + pov_score = score.white() if self._board.turn else score.black() + cp = pov_score.score(mate_score=10000) + + if cp is not None: + # Normalize to 0-1 range (sigmoid-like) + normalized = max(0.0, min(1.0, (cp + 500) / 1000)) + return { + "observation": f"Evaluation: {cp/100:.2f} pawns (centipawns: {cp})", + "reward": normalized, + "centipawns": cp, + "is_solved": self._is_solved, + } + else: + mate = pov_score.mate() + if mate is not None: + reward = 1.0 if mate > 0 else 0.0 + return { + "observation": f"Mate in {abs(mate)}" if mate > 0 else f"Getting mated in {abs(mate)}", + "reward": reward, + "mate_in": mate, + "is_solved": self._is_solved, + } + except Exception as e: + return {"observation": f"Evaluation error: {e}", "reward": 0.0} + + return {"observation": "Unable to evaluate position", "reward": 0.0} + + async def get_best_move(self) -> Tuple[str, int]: + """ + Get the best move according to Stockfish with evaluation. + + Returns: + Tuple of (best_move_uci, centipawns) + """ + if self._engine is None: + return ("", 0) + + loop = asyncio.get_running_loop() + try: + result = await loop.run_in_executor( + None, + lambda: self._engine.analyse( + self._board, + chess.engine.Limit(time=self.analysis_time, depth=self.analysis_depth) + ) + ) + + best_move = result.get("pv", [None])[0] + score = result.get("score") + + cp = 0 + if score: + pov_score = score.white() if self._board.turn else score.black() + cp = pov_score.score(mate_score=10000) or 0 + + return (best_move.uci() if best_move else "", cp) + except Exception: + return ("", 0) + + async def evaluate_move(self, move_uci: str) -> int: + """ + Evaluate a specific move by comparing position before and after. + + Args: + move_uci: Move in UCI format to evaluate + + Returns: + Centipawn difference (positive = good move, negative = bad move) + """ + if self._engine is None: + return 0 + + try: + move = chess.Move.from_uci(move_uci) + except ValueError: + return -10000 # Invalid move + + if move not in self._board.legal_moves: + return -10000 # Illegal move penalty + + loop = asyncio.get_running_loop() + + try: + # Get best move evaluation before this move + best_move_uci, best_cp = await self.get_best_move() + + # Make the move temporarily + self._board.push(move) + + # Evaluate position after move (from opponent's perspective, so negate) + after_info = await loop.run_in_executor( + None, + lambda: self._engine.analyse( + self._board, + chess.engine.Limit(time=self.analysis_time, depth=self.analysis_depth) + ) + ) + + # Undo the move + self._board.pop() + + after_score = after_info.get("score") + if after_score: + # Get score from the perspective of the player who just moved + pov_score = after_score.black() if self._board.turn else after_score.white() + after_cp = pov_score.score(mate_score=10000) or 0 + else: + after_cp = 0 + + # If this was the best move, return 0 (no loss) + if move_uci == best_move_uci: + return 0 + + # Return centipawn loss (negative = worse than best move) + return after_cp - best_cp + + except Exception: + return 0 + + @property + def is_solved(self) -> bool: + """Whether the puzzle has been solved correctly.""" + return self._is_solved + + @property + def puzzle_id(self) -> str: + """The current puzzle's ID.""" + return self._puzzle_id or "" + + @property + def moves_made(self) -> List[str]: + """List of moves made by the agent (UCI format).""" + return self._moves_made.copy() + + @property + def board(self) -> chess.Board: + """The current chess board state.""" + return self._board + + async def aclose(self) -> None: + """ + Close the Stockfish engine and release resources. + """ + if self._engine: + try: + self._engine.quit() + except Exception: + pass + self._engine = None + + def close(self) -> None: + """ + Synchronous close - quit the engine. + """ + if self._engine: + try: + self._engine.quit() + except Exception: + pass + self._engine = None + + @staticmethod + async def acquire(): + """ + Factory method to create and start a chess environment. + + Returns: + ChessPuzzleEnv: A fully initialized environment ready for use + + Example: + ```python + env = await ChessPuzzleEnv.acquire() + obs = await env.reset({"fen": "...", "moves": "..."}) + await env.aclose() + ``` + """ + env = ChessPuzzleEnv() + await env.start() + return env diff --git a/agentfly/rewards/__init__.py b/agentfly/rewards/__init__.py index ee2e107..b137478 100644 --- a/agentfly/rewards/__init__.py +++ b/agentfly/rewards/__init__.py @@ -19,4 +19,5 @@ from .gui_reward import gui_reward from .vlm_as_judge.vlm_as_judge_reward import vlm_as_judge_reward from .vlm_as_judge.vlm_as_judge_reward import vlm_as_judge_pass_reward +from .chess_reward import chess_puzzle_reward, chess_puzzle_reward_simple diff --git a/agentfly/rewards/chess_reward.py b/agentfly/rewards/chess_reward.py new file mode 100644 index 0000000..dd5e719 --- /dev/null +++ b/agentfly/rewards/chess_reward.py @@ -0,0 +1,169 @@ +# chess_reward.py +""" +Chess puzzle reward functions for AgentFly. + +Provides two reward functions: +- chess_puzzle_reward: Dense reward based on Stockfish evaluation (move quality) +- chess_puzzle_reward_simple: Binary reward (solved/not solved) +""" + +from typing import Dict, Any +from ..envs.chess_env import ChessPuzzleEnv +from .reward_base import reward + + +@reward( + name="chess_puzzle_reward", + env_cls=ChessPuzzleEnv, + pool_size=8 +) +async def chess_puzzle_reward(prediction: str, env: ChessPuzzleEnv) -> Dict[str, Any]: + """ + Calculate reward for chess puzzle solving based on Stockfish evaluation. + + This reward function provides: + 1. Dense reward based on move quality (centipawn evaluation) + 2. Bonus for solving the puzzle correctly + 3. Penalty for making suboptimal moves + + The reward is structured to encourage: + - Finding the best moves (matching Stockfish recommendations) + - Solving puzzles completely + - Making progress even with imperfect moves + + Args: + prediction (str): The agent's final response/output (not used directly). + env (ChessPuzzleEnv): The chess puzzle environment instance. + + Returns: + dict: A dictionary containing: + - reward (float): The calculated reward value (0.0 to 1.0+) + - is_solved (bool): Whether the puzzle was solved correctly + - moves_made (int): Number of moves made + - best_move_matches (int): How many moves matched Stockfish's best move + - centipawn_score (float): Average centipawn quality of moves (0-100 scale) + - output (str): Human-readable summary + """ + # Get puzzle state + is_solved = env.is_solved + moves_made = env.moves_made + num_moves = len(moves_made) + + # Calculate solve bonus + if is_solved: + solve_reward = 1.0 + else: + # Partial credit for progress through the solution + solution_len = len(env._solution_moves) + if solution_len > 1: + # Adjust for the setup move + progress = max(0, env._current_solution_idx - 1) / (solution_len - 1) + solve_reward = progress * 0.5 # Up to 0.5 for partial progress + elif solution_len == 1: + # Single move puzzle + solve_reward = 0.0 + else: + solve_reward = 0.0 + + # Calculate move quality reward using Stockfish + centipawn_total = 0.0 + best_move_matches = 0 + + if num_moves > 0 and env._engine is not None: + # Evaluate each move made + # We need to replay from the starting position + import chess + temp_board = chess.Board(env._puzzle_fen) + + # Apply setup move if it was made + if len(env._solution_moves) > 1 and env._current_solution_idx >= 1: + try: + setup_move = chess.Move.from_uci(env._solution_moves[0]) + if setup_move in temp_board.legal_moves: + temp_board.push(setup_move) + except ValueError: + pass + + for i, move_uci in enumerate(moves_made): + try: + # Get best move for this position + best_move, best_cp = await env.get_best_move() + + # Check if agent's move matches best move + if move_uci == best_move: + best_move_matches += 1 + centipawn_total += 100.0 # Perfect score for matching best + else: + # Evaluate the quality of the actual move + cp_loss = await env.evaluate_move(move_uci) + # Convert centipawn loss to 0-100 scale + # 0 cp loss = 100, -300 cp loss = 0 + normalized = max(0.0, min(100.0, 100.0 + (cp_loss / 3.0))) + centipawn_total += normalized + + # Apply the move to continue analysis + move = chess.Move.from_uci(move_uci) + if move in temp_board.legal_moves: + temp_board.push(move) + + except Exception: + # If analysis fails, give partial credit + centipawn_total += 50.0 + + # Average centipawn score + avg_cp = centipawn_total / num_moves if num_moves > 0 else 50.0 + move_quality_reward = avg_cp / 100.0 # 0.0 to 1.0 + + # Combine rewards + # 60% for solving, 40% for move quality + total_reward = 0.6 * solve_reward + 0.4 * move_quality_reward + + # Build output summary + output_parts = [ + f"Puzzle {'SOLVED!' if is_solved else 'not solved'}", + f"Moves made: {num_moves}", + f"Best move matches: {best_move_matches}/{num_moves}" if num_moves > 0 else "No moves made", + f"Average move quality: {avg_cp:.1f}/100", + f"Total reward: {total_reward:.3f}" + ] + + return { + "reward": total_reward, + "is_solved": is_solved, + "moves_made": num_moves, + "best_move_matches": best_move_matches, + "centipawn_score": avg_cp, + "output": "\n".join(output_parts), + } + + +@reward( + name="chess_puzzle_reward_simple", + env_cls=ChessPuzzleEnv, + pool_size=8 +) +async def chess_puzzle_reward_simple(prediction: str, env: ChessPuzzleEnv) -> Dict[str, Any]: + """ + Simple binary reward for chess puzzle solving. + + Returns 1.0 if puzzle is solved correctly, 0.0 otherwise. + Useful for comparison with dense reward and for simpler training setups + where you only care about correct solutions. + + Args: + prediction (str): The agent's final response/output (not used). + env (ChessPuzzleEnv): The chess puzzle environment instance. + + Returns: + dict: Contains: + - reward (float): 1.0 if solved, 0.0 otherwise + - is_solved (bool): Whether the puzzle was solved + - output (str): Human-readable status message + """ + is_solved = env.is_solved + + return { + "reward": 1.0 if is_solved else 0.0, + "is_solved": is_solved, + "output": f"Puzzle {'solved' if is_solved else 'not solved'}", + } diff --git a/agentfly/tests/unit/envs/test_chess_env.py b/agentfly/tests/unit/envs/test_chess_env.py new file mode 100644 index 0000000..30c215a --- /dev/null +++ b/agentfly/tests/unit/envs/test_chess_env.py @@ -0,0 +1,217 @@ +# test_chess_env.py +""" +Unit tests for the chess puzzle environment. + +Note: These tests require Stockfish to be installed. +Install with: brew install stockfish (macOS) or apt-get install stockfish (Linux) +""" + +import pytest +from ....envs.chess_env import ChessPuzzleEnv + + +# Skip all tests if Stockfish is not available +pytestmark = pytest.mark.skipif( + not pytest.importorskip("chess"), + reason="python-chess not installed" +) + + +@pytest.fixture +async def chess_env(): + """Create and start a chess environment for testing.""" + env = ChessPuzzleEnv() + try: + await env.start() + yield env + finally: + await env.aclose() + + +@pytest.mark.asyncio +async def test_env_start_and_close(): + """Test that environment can start and close properly.""" + env = ChessPuzzleEnv() + await env.start() + assert env._engine is not None + assert env._board is not None + await env.aclose() + assert env._engine is None + + +@pytest.mark.asyncio +async def test_env_reset_default_puzzle(chess_env): + """Test resetting to the default puzzle.""" + obs = await chess_env.reset() + + # Should have board state in observation + assert "FEN:" in obs + assert "Turn:" in obs + assert "Legal moves:" in obs + + +@pytest.mark.asyncio +async def test_env_reset_custom_puzzle(chess_env): + """Test resetting with a custom puzzle.""" + puzzle = { + "puzzle_id": "test_mate_in_1", + "fen": "r1bqkb1r/pppp1ppp/2n2n2/4p2Q/2B1P3/8/PPPP1PPP/RNB1K1NR w KQkq - 4 4", + "moves": "h5f7" # Qxf7# is mate + } + + obs = await chess_env.reset(puzzle) + + assert chess_env._puzzle_id == "test_mate_in_1" + assert "White" in obs # It's White's turn + assert not chess_env._is_solved + + +@pytest.mark.asyncio +async def test_make_correct_move(chess_env): + """Test making the correct puzzle move.""" + puzzle = { + "puzzle_id": "test_mate_in_1", + "fen": "r1bqkb1r/pppp1ppp/2n2n2/4p2Q/2B1P3/8/PPPP1PPP/RNB1K1NR w KQkq - 4 4", + "moves": "h5f7" + } + + await chess_env.reset(puzzle) + + # Make the correct move (Qxf7#) + result = await chess_env.step("h5f7") + + assert "Correct" in result or "solved" in result.lower() + assert chess_env._is_solved + + +@pytest.mark.asyncio +async def test_make_move_san_format(chess_env): + """Test making a move in SAN format.""" + puzzle = { + "puzzle_id": "test_mate_in_1", + "fen": "r1bqkb1r/pppp1ppp/2n2n2/4p2Q/2B1P3/8/PPPP1PPP/RNB1K1NR w KQkq - 4 4", + "moves": "h5f7" + } + + await chess_env.reset(puzzle) + + # Make the correct move in SAN format + result = await chess_env.step("Qxf7#") + + assert "Correct" in result or "solved" in result.lower() + assert chess_env._is_solved + + +@pytest.mark.asyncio +async def test_make_illegal_move(chess_env): + """Test that illegal moves are rejected.""" + await chess_env.reset() + + # Try an illegal move + result = await chess_env.step("a1a8") # Can't move like this + + assert "Illegal" in result or "Invalid" in result + + +@pytest.mark.asyncio +async def test_get_state(chess_env): + """Test the get_state action.""" + await chess_env.reset() + + result = await chess_env.step("get_state") + + assert "FEN:" in result + assert "Turn:" in result + + +@pytest.mark.asyncio +async def test_get_legal_moves(chess_env): + """Test the get_legal_moves action.""" + await chess_env.reset() + + result = await chess_env.step("get_legal_moves") + + # Should contain some legal moves + assert len(result) > 0 + # Should have UCI format moves + assert "(" in result # Format is "uci (san)" + + +@pytest.mark.asyncio +async def test_get_evaluation(chess_env): + """Test the get_reward/evaluation action.""" + await chess_env.reset() + + result = await chess_env.step("get_reward") + + assert isinstance(result, dict) + assert "reward" in result + + +@pytest.mark.asyncio +async def test_get_best_move(chess_env): + """Test getting the best move from Stockfish.""" + puzzle = { + "puzzle_id": "test", + "fen": "r1bqkb1r/pppp1ppp/2n2n2/4p2Q/2B1P3/8/PPPP1PPP/RNB1K1NR w KQkq - 4 4", + "moves": "h5f7" + } + + await chess_env.reset(puzzle) + + best_move, cp = await chess_env.get_best_move() + + # Stockfish should find the mate + assert best_move == "h5f7" # Qxf7# + + +@pytest.mark.asyncio +async def test_puzzle_state_tracking(chess_env): + """Test that puzzle state is tracked correctly.""" + puzzle = { + "puzzle_id": "test", + "fen": "rnbqkbnr/pppppppp/8/8/8/8/PPPPPPPP/RNBQKBNR w KQkq - 0 1", + "moves": "e2e4 e7e5 g1f3" # Multi-move puzzle + } + + await chess_env.reset(puzzle) + + # Initial state + assert len(chess_env.moves_made) == 0 + assert not chess_env.is_solved + + # After setup, solution index should be 1 (first move was played) + assert chess_env._current_solution_idx == 1 + + # Make the correct move + await chess_env.step("e2e4") + + assert len(chess_env.moves_made) == 1 + assert "e2e4" in chess_env.moves_made + + +@pytest.mark.asyncio +async def test_multiple_puzzle_resets(chess_env): + """Test that environment can be reset multiple times.""" + for i in range(3): + obs = await chess_env.reset({ + "puzzle_id": f"test_{i}", + "fen": "rnbqkbnr/pppppppp/8/8/8/8/PPPPPPPP/RNBQKBNR w KQkq - 0 1", + "moves": "e2e4" + }) + + assert chess_env._puzzle_id == f"test_{i}" + assert len(chess_env.moves_made) == 0 + assert not chess_env.is_solved + + +@pytest.mark.asyncio +async def test_board_property(chess_env): + """Test that board property returns the current board.""" + await chess_env.reset() + + board = chess_env.board + + assert board is not None + # Initial position should have all pieces + assert len(list(board.legal_moves)) > 0 diff --git a/agentfly/tools/__init__.py b/agentfly/tools/__init__.py index 959b472..3f5d127 100644 --- a/agentfly/tools/__init__.py +++ b/agentfly/tools/__init__.py @@ -28,6 +28,7 @@ from .src.search.async_dense_retriever import asyncdense_retrieve from .src.scienceworld.tools import scienceworld_explorer from .src.ui.tools import pyautogui_code_generator +from .src.chess.tools import chess_move, chess_get_state, chess_get_legal_moves # Add explicit tools in case they weren't auto-registered @@ -47,7 +48,10 @@ "invalid_input_tool": invalid_input_tool, "dense_retrieve": dense_retrieve, "pyautogui_code_generator": pyautogui_code_generator, - "calculator": calculator + "calculator": calculator, + "chess_move": chess_move, + "chess_get_state": chess_get_state, + "chess_get_legal_moves": chess_get_legal_moves, } # Update the registry with explicit tools diff --git a/agentfly/tools/src/chess/__init__.py b/agentfly/tools/src/chess/__init__.py new file mode 100644 index 0000000..41912d9 --- /dev/null +++ b/agentfly/tools/src/chess/__init__.py @@ -0,0 +1,4 @@ +# Chess tools module +from .tools import chess_move, chess_get_state, chess_get_legal_moves + +__all__ = ["chess_move", "chess_get_state", "chess_get_legal_moves"] diff --git a/agentfly/tools/src/chess/tools.py b/agentfly/tools/src/chess/tools.py new file mode 100644 index 0000000..74e4054 --- /dev/null +++ b/agentfly/tools/src/chess/tools.py @@ -0,0 +1,112 @@ +# chess/tools.py +""" +Chess puzzle tools for AgentFly. + +These tools allow agents to interact with chess puzzles: +- chess_move: Make a move on the board +- chess_get_state: Get the current board state +- chess_get_legal_moves: List all legal moves +""" + +import traceback + +from ...tool_base import tool +from ....envs.chess_env import ChessPuzzleEnv + + +@tool( + env_cls=ChessPuzzleEnv, + name="chess_move", + description="Make a chess move in the current puzzle. The move can be in UCI format (e.g., 'e2e4', 'g1f3', 'e7e8q' for promotion) or standard algebraic notation (e.g., 'e4', 'Nf3', 'O-O' for castling, 'Qxf7#' for checkmate). Returns whether the move was correct and the new board state.", + stateful=True, + pool_size=8 +) +async def chess_move(move: str, env: ChessPuzzleEnv): + """ + Make a chess move in the puzzle. + + Args: + move (str): The move to make. Can be in UCI format (e.g., 'e2e4', 'h5f7') + or SAN format (e.g., 'e4', 'Nf3', 'Qxf7+', 'O-O'). + env (ChessPuzzleEnv): The chess puzzle environment instance (auto-injected). + + Returns: + str: The result of the move including: + - Whether the move was correct for the puzzle + - The new board state (FEN and visual representation) + - Current game status (check, checkmate, etc.) + - Error message if the move is invalid/illegal + """ + try: + result = await env.step(move) + return result + except Exception as e: + return f"Error: {str(e)}\n{traceback.format_exc()}" + + +@tool( + env_cls=ChessPuzzleEnv, + name="chess_get_state", + description="Get the current chess board state including FEN notation, visual board representation, whose turn it is, and puzzle status. Use this to understand the current position before making a move.", + stateful=True, + pool_size=8 +) +async def chess_get_state(env: ChessPuzzleEnv): + """ + Get the current state of the chess puzzle. + + Args: + env (ChessPuzzleEnv): The chess puzzle environment instance (auto-injected). + + Returns: + str: A detailed representation of the current board state including: + - FEN notation (standard chess position encoding) + - ASCII board visualization + - Whose turn it is (White or Black) + - Number of legal moves available + - Check/checkmate/stalemate status + - Whether the puzzle is solved + - Moves played so far + """ + try: + result = await env.step("get_state") + return result + except Exception as e: + return f"Error: {str(e)}\n{traceback.format_exc()}" + + +@tool( + env_cls=ChessPuzzleEnv, + name="chess_get_legal_moves", + description="Get all legal moves in the current position. Each move is shown in both UCI format (e.g., 'e2e4') and standard algebraic notation (e.g., 'e4'). Use this when you need to know what moves are available.", + stateful=True, + pool_size=8 +) +async def chess_get_legal_moves(env: ChessPuzzleEnv): + """ + Get all legal moves in the current position. + + Args: + env (ChessPuzzleEnv): The chess puzzle environment instance (auto-injected). + + Returns: + str: A comma-separated list of legal moves in format "uci (san)", + e.g., "e2e4 (e4), g1f3 (Nf3), d2d4 (d4)" + Sorted alphabetically by UCI notation. + """ + try: + result = await env.step("get_legal_moves") + return result + except Exception as e: + return f"Error: {str(e)}\n{traceback.format_exc()}" + + +if __name__ == "__main__": + print("Chess Tools Schemas:") + print("=" * 50) + print("\nchess_move schema:") + print(chess_move.schema) + print("\nchess_get_state schema:") + print(chess_get_state.schema) + print("\nchess_get_legal_moves schema:") + print(chess_get_legal_moves.schema) diff --git a/agentfly/utils/__init__.py b/agentfly/utils/__init__.py index 41cc621..608f4c4 100644 --- a/agentfly/utils/__init__.py +++ b/agentfly/utils/__init__.py @@ -1,3 +1,12 @@ from .timing import Timer from .logging import Logger from .monitor import Monitor +from .chess_puzzles import ( + load_lichess_puzzles, + load_puzzles_jsonl, + generate_puzzle_prompt, + filter_puzzles_by_theme, + filter_puzzles_by_rating, + save_puzzles_jsonl, + LICHESS_THEMES, +) diff --git a/agentfly/utils/chess_puzzles.py b/agentfly/utils/chess_puzzles.py new file mode 100644 index 0000000..6ef5d4a --- /dev/null +++ b/agentfly/utils/chess_puzzles.py @@ -0,0 +1,325 @@ +# chess_puzzles.py +""" +Chess puzzle data loading utilities for AgentFly. + +Supports loading puzzles from: +- Lichess puzzle database CSV format +- JSONL format for training + +The Lichess puzzle database can be downloaded from: +https://database.lichess.org/#puzzles +""" + +import csv +import json +from pathlib import Path +from typing import Dict, List, Optional, Iterator, Union + + +def load_lichess_puzzles( + csv_path: Union[str, Path], + max_puzzles: Optional[int] = None, + min_rating: int = 0, + max_rating: int = 3000, + themes: Optional[List[str]] = None, + skip_first_n: int = 0, +) -> List[Dict]: + """ + Load puzzles from Lichess puzzle database CSV. + + The Lichess puzzle CSV format has columns: + PuzzleId,FEN,Moves,Rating,RatingDeviation,Popularity,NbPlays,Themes,GameUrl,OpeningTags + + Args: + csv_path: Path to the Lichess puzzles CSV file + max_puzzles: Maximum number of puzzles to load (None for all) + min_rating: Minimum puzzle rating to include + max_rating: Maximum puzzle rating to include + themes: Optional list of themes to filter by (e.g., ["mateIn1", "short"]) + If provided, only puzzles with at least one matching theme are included + skip_first_n: Number of puzzles to skip from the beginning + + Returns: + List of puzzle dictionaries in AgentFly format with keys: + - messages: List of message dicts for the conversation + - puzzle_id: Unique puzzle identifier + - fen: Starting FEN position + - moves: Space-separated UCI moves (solution) + - rating: Puzzle difficulty rating + - themes: List of puzzle themes + + Example: + ```python + puzzles = load_lichess_puzzles( + "lichess_puzzles.csv", + max_puzzles=1000, + min_rating=1200, + max_rating=1800, + themes=["mateIn2"] + ) + ``` + """ + puzzles = [] + csv_path = Path(csv_path) + + with open(csv_path, 'r', encoding='utf-8') as f: + reader = csv.DictReader(f) + + skipped = 0 + for row in reader: + # Skip first N puzzles + if skipped < skip_first_n: + skipped += 1 + continue + + # Filter by rating + try: + rating = int(row['Rating']) + except (ValueError, KeyError): + continue + + if rating < min_rating or rating > max_rating: + continue + + # Filter by themes + puzzle_themes = row.get('Themes', '').split() + if themes and not any(t in puzzle_themes for t in themes): + continue + + # Create puzzle in AgentFly format + puzzle = { + "messages": [ + { + "role": "user", + "content": generate_puzzle_prompt( + rating=rating, + themes=puzzle_themes + ) + } + ], + "puzzle_id": row['PuzzleId'], + "fen": row['FEN'], + "moves": row['Moves'], + "rating": rating, + "themes": puzzle_themes, + } + puzzles.append(puzzle) + + if max_puzzles and len(puzzles) >= max_puzzles: + break + + return puzzles + + +def load_puzzles_jsonl( + jsonl_path: Union[str, Path], + max_puzzles: Optional[int] = None, +) -> List[Dict]: + """ + Load puzzles from JSONL file. + + Each line should be a JSON object with at least: + - fen: Starting FEN position + - moves: Space-separated UCI moves (solution) + + Optional fields: + - puzzle_id: Unique identifier + - rating: Difficulty rating + - themes: List of themes + - messages: Pre-formatted messages + + Args: + jsonl_path: Path to the JSONL file + max_puzzles: Maximum number of puzzles to load + + Returns: + List of puzzle dictionaries + """ + puzzles = [] + jsonl_path = Path(jsonl_path) + + with open(jsonl_path, 'r', encoding='utf-8') as f: + for i, line in enumerate(f): + if max_puzzles and len(puzzles) >= max_puzzles: + break + + line = line.strip() + if not line: + continue + + try: + puzzle = json.loads(line) + except json.JSONDecodeError: + continue + + # Ensure required fields + if 'fen' not in puzzle or 'moves' not in puzzle: + continue + + # Add default fields if missing + if 'puzzle_id' not in puzzle: + puzzle['puzzle_id'] = f"puzzle_{i}" + + if 'messages' not in puzzle: + puzzle['messages'] = [ + { + "role": "user", + "content": generate_puzzle_prompt( + rating=puzzle.get('rating'), + themes=puzzle.get('themes', []) + ) + } + ] + + puzzles.append(puzzle) + + return puzzles + + +def generate_puzzle_prompt( + rating: Optional[int] = None, + themes: Optional[List[str]] = None, + include_hints: bool = True, +) -> str: + """ + Generate the user prompt for a chess puzzle. + + Args: + rating: Puzzle difficulty rating + themes: List of puzzle themes + include_hints: Whether to include hints based on themes + + Returns: + Formatted prompt string for the agent + """ + parts = ["You are solving a chess puzzle."] + + if rating: + parts.append(f"Difficulty rating: {rating}") + + if include_hints and themes: + # Add hints based on common themes + if 'mateIn1' in themes: + parts.append("Hint: This is a mate in 1 - find the checkmate!") + elif 'mateIn2' in themes: + parts.append("Hint: This is a mate in 2 moves.") + elif 'mateIn3' in themes: + parts.append("Hint: This is a mate in 3 moves.") + elif 'mateIn4' in themes: + parts.append("Hint: This is a mate in 4 moves.") + elif 'fork' in themes: + parts.append("Hint: Look for a fork (attacking multiple pieces at once).") + elif 'pin' in themes: + parts.append("Hint: Look for a pin.") + elif 'skewer' in themes: + parts.append("Hint: Look for a skewer.") + elif 'discoveredAttack' in themes: + parts.append("Hint: Look for a discovered attack.") + + parts.extend([ + "", + "Use the chess_get_state tool to see the current board position.", + "Use the chess_get_legal_moves tool to see available moves.", + "Use the chess_move tool to make your move(s).", + "", + "Find the best move(s) to solve this puzzle." + ]) + + return "\n".join(parts) + + +def filter_puzzles_by_theme( + puzzles: List[Dict], + themes: List[str], + require_all: bool = False +) -> List[Dict]: + """ + Filter puzzles by themes. + + Args: + puzzles: List of puzzle dictionaries + themes: Themes to filter by + require_all: If True, puzzle must have ALL themes. If False, any matching theme. + + Returns: + Filtered list of puzzles + """ + result = [] + for puzzle in puzzles: + puzzle_themes = puzzle.get('themes', []) + if require_all: + if all(t in puzzle_themes for t in themes): + result.append(puzzle) + else: + if any(t in puzzle_themes for t in themes): + result.append(puzzle) + return result + + +def filter_puzzles_by_rating( + puzzles: List[Dict], + min_rating: int = 0, + max_rating: int = 3000 +) -> List[Dict]: + """ + Filter puzzles by rating range. + + Args: + puzzles: List of puzzle dictionaries + min_rating: Minimum rating (inclusive) + max_rating: Maximum rating (inclusive) + + Returns: + Filtered list of puzzles + """ + return [ + p for p in puzzles + if min_rating <= p.get('rating', 0) <= max_rating + ] + + +def save_puzzles_jsonl( + puzzles: List[Dict], + output_path: Union[str, Path], +) -> None: + """ + Save puzzles to JSONL file. + + Args: + puzzles: List of puzzle dictionaries + output_path: Path to output file + """ + output_path = Path(output_path) + output_path.parent.mkdir(parents=True, exist_ok=True) + + with open(output_path, 'w', encoding='utf-8') as f: + for puzzle in puzzles: + f.write(json.dumps(puzzle, ensure_ascii=False) + '\n') + + +# Common Lichess puzzle themes for reference +LICHESS_THEMES = [ + # Tactical motifs + "fork", "pin", "skewer", "discoveredAttack", "doubleCheck", + "sacrifice", "deflection", "interference", "xRayAttack", + "zugzwang", "quietMove", "defensiveMove", "clearance", + + # Mate patterns + "mateIn1", "mateIn2", "mateIn3", "mateIn4", "mateIn5", + "anastasiaMate", "arabianMate", "backRankMate", "bodenMate", + "doubleBishopMate", "hookMate", "smotheredMate", + + # Length + "oneMove", "short", "long", "veryLong", + + # Game phase + "opening", "middlegame", "endgame", + "rookEndgame", "bishopEndgame", "knightEndgame", "pawnEndgame", "queenEndgame", + + # Difficulty + "master", "masterVsMaster", "superGM", + + # Special + "castling", "enPassant", "promotion", "underPromotion", + "equality", "advantage", "crushing", +] From 740784085f73a8d7374bc78b2d97d5c40f147cf1 Mon Sep 17 00:00:00 2001 From: Shu Date: Thu, 22 Jan 2026 21:48:16 -0800 Subject: [PATCH 2/2] Add chess puzzle training scripts, sample data, and documentation - scripts/train_chess.sh: Training script with hyperparameters for chess puzzle RL - scripts/prepare_chess_data.py: Convert Lichess CSV to training format - data/chess/: Sample train/val puzzles (10 train, 5 val) - chess_readme.MD: Quick start guide and configuration reference --- chess_readme.MD | 267 ++++++++++++++++++++++++++++ data/chess/chess_puzzles_train.json | 82 +++++++++ data/chess/chess_puzzles_val.json | 42 +++++ scripts/prepare_chess_data.py | 150 ++++++++++++++++ scripts/train_chess.sh | 154 ++++++++++++++++ 5 files changed, 695 insertions(+) create mode 100644 chess_readme.MD create mode 100644 data/chess/chess_puzzles_train.json create mode 100644 data/chess/chess_puzzles_val.json create mode 100755 scripts/prepare_chess_data.py create mode 100755 scripts/train_chess.sh diff --git a/chess_readme.MD b/chess_readme.MD new file mode 100644 index 0000000..b55c3cf --- /dev/null +++ b/chess_readme.MD @@ -0,0 +1,267 @@ +# Chess Puzzle Solving Agent + +Train an RL agent to solve chess puzzles using AgentFly. + +## Overview + +This module provides a complete chess puzzle solving environment for training language model agents with reinforcement learning. The agent learns to analyze chess positions and find tactical solutions (checkmates, forks, pins, etc.). + +**Architecture:** +``` +Agent (Qwen/Llama/etc.) + ↓ +Tools: chess_move, chess_get_state, chess_get_legal_moves + ↓ +ChessPuzzleEnv (python-chess + Stockfish) + ↓ +Rewards: chess_puzzle_reward (dense) or chess_puzzle_reward_simple (binary) +``` + +## Prerequisites + +1. **Stockfish chess engine:** + ```bash + # macOS + brew install stockfish + + # Ubuntu/Debian + apt-get install stockfish + + # Verify installation + which stockfish + ``` + +2. **Python dependencies:** + ```bash + pip install python-chess + ``` + +## Quick Start + +### Option 1: Test Run with Sample Data + +The repo includes sample puzzles for testing: + +```bash +bash scripts/train_chess.sh +``` + +### Option 2: Train with Lichess Puzzles + +**Step 1: Download Lichess puzzle database** +```bash +# Download (~250MB compressed, ~1.5GB uncompressed) +curl -O https://database.lichess.org/lichess_db_puzzle.csv.zst + +# Decompress (install zstd if needed: brew install zstd) +zstd -d lichess_db_puzzle.csv.zst +``` + +**Step 2: Prepare training data** +```bash +python scripts/prepare_chess_data.py \ + --input lichess_db_puzzle.csv \ + --output data/chess/ \ + --train-size 10000 \ + --val-size 1000 \ + --min-rating 1000 \ + --max-rating 1600 \ + --themes mateIn1 mateIn2 fork pin +``` + +**Step 3: Run training** +```bash +# Set your WandB key for logging +export WANDB_API_KEY="your_key_here" + +# Start training +bash scripts/train_chess.sh +``` + +## Configuration + +Edit `scripts/train_chess.sh` to customize: + +### Model +```bash +model="Qwen/Qwen2.5-3B-Instruct" # Base model to fine-tune +template="qwen2.5" # Chat template + +# Alternatives: +# model="Qwen/Qwen2.5-7B-Instruct" +# model="meta-llama/Llama-3.1-8B-Instruct" +# template="llama3" +``` + +### Agent +```bash +agent_type="react" # ReAct agent for tool use +max_turns=10 # Max moves per puzzle +num_chains=8 # Parallel rollouts per sample +``` + +### Reward Function +```bash +# Dense reward (recommended) - based on Stockfish evaluation +reward_name="chess_puzzle_reward" + +# Binary reward - 1.0 if solved, 0.0 otherwise +# reward_name="chess_puzzle_reward_simple" +``` + +### Training +```bash +batch_size=64 +lr=4e-7 +total_training_steps=200 +adv_estimator="grpo" # Options: grpo, reinforce_plus_plus, rloo, gae +``` + +## Data Format + +Training data is a JSON array of puzzles: + +```json +[ + { + "question": "You are solving a chess puzzle...", + "puzzle_id": "abc123", + "fen": "r1bqkb1r/pppp1ppp/2n2n2/4p2Q/2B1P3/8/PPPP1PPP/RNB1K1NR w KQkq - 4 4", + "moves": "h5f7", + "rating": 1200, + "themes": ["mateIn1", "short"] + } +] +``` + +| Field | Description | +|-------|-------------| +| `question` | Prompt shown to the agent | +| `puzzle_id` | Unique identifier | +| `fen` | Starting position (FEN notation) | +| `moves` | Solution moves (space-separated UCI) | +| `rating` | Difficulty rating (optional) | +| `themes` | Puzzle themes (optional) | + +### Move Sequence Convention + +For multi-move puzzles, the `moves` field follows Lichess convention: +- **First move**: Opponent's setup move (auto-played by environment) +- **Remaining moves**: Alternating solution moves + +Example: `"e2e4 e7e5 g1f3"` means: +1. Environment plays `e2e4` (opponent setup) +2. Agent must find `e7e5` +3. Environment responds with `g1f3` + +## Available Tools + +| Tool | Description | +|------|-------------| +| `chess_move` | Make a move (UCI: `e2e4` or SAN: `Nf3`) | +| `chess_get_state` | View board position, FEN, turn, status | +| `chess_get_legal_moves` | List all legal moves | + +## Reward Functions + +### `chess_puzzle_reward` (Dense) + +Combines solve bonus with move quality: + +``` +reward = 0.6 × solve_reward + 0.4 × move_quality_reward +``` + +- `solve_reward`: 1.0 if solved, 0-0.5 for partial progress +- `move_quality_reward`: Average centipawn quality (Stockfish evaluation) + +### `chess_puzzle_reward_simple` (Binary) + +- 1.0 if puzzle solved correctly +- 0.0 otherwise + +## Puzzle Themes + +Filter puzzles by tactical motif: + +| Category | Themes | +|----------|--------| +| **Mates** | `mateIn1`, `mateIn2`, `mateIn3`, `backRankMate`, `smotheredMate` | +| **Tactics** | `fork`, `pin`, `skewer`, `discoveredAttack`, `doubleCheck` | +| **Length** | `oneMove`, `short`, `long`, `veryLong` | +| **Phase** | `opening`, `middlegame`, `endgame` | + +## Monitoring + +Training metrics are logged to Weights & Biases: +- Reward per step +- Solve rate +- Average moves per puzzle +- KL divergence +- Loss curves + +View at: https://wandb.ai/your-project + +## Files + +``` +AgentFly/ +├── agentfly/ +│ ├── envs/chess_env.py # Chess puzzle environment +│ ├── tools/src/chess/tools.py # Agent tools +│ ├── rewards/chess_reward.py # Reward functions +│ └── utils/chess_puzzles.py # Data loading utilities +├── scripts/ +│ ├── train_chess.sh # Training script +│ └── prepare_chess_data.py # Lichess → training format +└── data/chess/ + ├── chess_puzzles_train.json # Training data + └── chess_puzzles_val.json # Validation data +``` + +## Troubleshooting + +### Stockfish not found +``` +FileNotFoundError: [Errno 2] No such file or directory: '/opt/homebrew/bin/stockfish' +``` +**Fix:** Update `stockfish_path` in `chess_env.py` or install Stockfish. + +### Out of GPU memory +**Fix:** Reduce `batch_size`, enable offloading: +```bash +actor_rollout_ref.actor.fsdp_config.param_offload=True +actor_rollout_ref.actor.fsdp_config.optimizer_offload=True +``` + +### Slow training +- Reduce `analysis_depth` in `ChessPuzzleEnv` (default: 20) +- Use simpler puzzles (lower rating, `mateIn1` only) +- Decrease `num_chains` + + +### For real training, use Lichess data: + +# Download 4+ million verified puzzles +curl -O https://database.lichess.org/lichess_db_puzzle.csv.zst +zstd -d lichess_db_puzzle.csv.zst + +# Convert to training format +python scripts/prepare_chess_data.py \ +--input lichess_db_puzzle.csv \ +--output data/chess/ \ +--train-size 10000 \ +--val-size 1000 + +The Lichess puzzles are: +- Extracted from real games +- Validated by millions of players +- Rated by difficulty (Elo) +- Tagged with themes (mateIn1, fork, pin, etc.) + + +## References + +- [Lichess Puzzle Database](https://database.lichess.org/#puzzles) +- [python-chess Documentation](https://python-chess.readthedocs.io/) +- [Stockfish](https://stockfishchess.org/) diff --git a/data/chess/chess_puzzles_train.json b/data/chess/chess_puzzles_train.json new file mode 100644 index 0000000..8708f2e --- /dev/null +++ b/data/chess/chess_puzzles_train.json @@ -0,0 +1,82 @@ +[ + { + "question": "You are solving a chess puzzle.\nDifficulty rating: 1000\nHint: This is a mate in 1 - find the checkmate!\n\nUse the chess_get_state tool to see the current board position.\nUse the chess_get_legal_moves tool to see available moves.\nUse the chess_move tool to make your move(s).\n\nFind the best move to solve this puzzle.", + "puzzle_id": "00sHx", + "fen": "r1bqkb1r/pppp1ppp/2n2n2/4p2Q/2B1P3/8/PPPP1PPP/RNB1K1NR w KQkq - 4 4", + "moves": "h5f7", + "rating": 1000, + "themes": ["mateIn1", "short"] + }, + { + "question": "You are solving a chess puzzle.\nDifficulty rating: 1100\nHint: This is a mate in 1 - find the checkmate!\n\nUse the chess_get_state tool to see the current board position.\nUse the chess_get_legal_moves tool to see available moves.\nUse the chess_move tool to make your move(s).\n\nFind the best move to solve this puzzle.", + "puzzle_id": "mate1_back_rank", + "fen": "6k1/5ppp/8/8/8/8/5PPP/4R1K1 w - - 0 1", + "moves": "e1e8", + "rating": 1100, + "themes": ["mateIn1", "backRankMate"] + }, + { + "question": "You are solving a chess puzzle.\nDifficulty rating: 1050\nHint: This is a mate in 1 - find the checkmate!\n\nUse the chess_get_state tool to see the current board position.\nUse the chess_get_legal_moves tool to see available moves.\nUse the chess_move tool to make your move(s).\n\nFind the best move to solve this puzzle.", + "puzzle_id": "mate1_queen", + "fen": "r1bqk2r/pppp1ppp/2n2n2/2b1p3/2B1P3/5Q2/PPPP1PPP/RNB1K1NR w KQkq - 4 4", + "moves": "f3f7", + "rating": 1050, + "themes": ["mateIn1", "short"] + }, + { + "question": "You are solving a chess puzzle.\nDifficulty rating: 1200\nHint: This is a mate in 2 moves.\n\nUse the chess_get_state tool to see the current board position.\nUse the chess_get_legal_moves tool to see available moves.\nUse the chess_move tool to make your move(s).\n\nFind the best moves to solve this puzzle.", + "puzzle_id": "mate2_sacrifice", + "fen": "r1bqkbnr/pppp1ppp/2n5/4p3/2B1P3/5Q2/PPPP1PPP/RNB1K1NR w KQkq - 2 3", + "moves": "f3f7 e8f7 c4g8", + "rating": 1200, + "themes": ["mateIn2", "sacrifice"] + }, + { + "question": "You are solving a chess puzzle.\nDifficulty rating: 1150\nHint: This is a mate in 1 - find the checkmate!\n\nUse the chess_get_state tool to see the current board position.\nUse the chess_get_legal_moves tool to see available moves.\nUse the chess_move tool to make your move(s).\n\nFind the best move to solve this puzzle.", + "puzzle_id": "mate1_smothered", + "fen": "r4rk1/ppp2ppp/8/8/8/5N2/PPP2PPP/R4RK1 w - - 0 1", + "moves": "f3h4", + "rating": 1150, + "themes": ["mateIn1"] + }, + { + "question": "You are solving a chess puzzle.\nDifficulty rating: 1300\nHint: This is a mate in 2 moves.\n\nUse the chess_get_state tool to see the current board position.\nUse the chess_get_legal_moves tool to see available moves.\nUse the chess_move tool to make your move(s).\n\nFind the best moves to solve this puzzle.", + "puzzle_id": "mate2_queen_sac", + "fen": "r1b1k2r/ppppqppp/2n2n2/2b1p3/2B1P3/3P1N2/PPP2PPP/RNBQK2R w KQkq - 0 5", + "moves": "c4f7 e7f7 d1b3", + "rating": 1300, + "themes": ["mateIn2", "fork"] + }, + { + "question": "You are solving a chess puzzle.\nDifficulty rating: 1100\nHint: This is a mate in 1 - find the checkmate!\n\nUse the chess_get_state tool to see the current board position.\nUse the chess_get_legal_moves tool to see available moves.\nUse the chess_move tool to make your move(s).\n\nFind the best move to solve this puzzle.", + "puzzle_id": "mate1_arabian", + "fen": "5rk1/5ppp/8/8/8/8/5PPP/4RNK1 w - - 0 1", + "moves": "e1e8", + "rating": 1100, + "themes": ["mateIn1", "backRankMate"] + }, + { + "question": "You are solving a chess puzzle.\nDifficulty rating: 1250\nHint: This is a mate in 2 moves.\n\nUse the chess_get_state tool to see the current board position.\nUse the chess_get_legal_moves tool to see available moves.\nUse the chess_move tool to make your move(s).\n\nFind the best moves to solve this puzzle.", + "puzzle_id": "mate2_double_check", + "fen": "rnbqkb1r/pppp1ppp/5n2/4p2Q/2B1P3/8/PPPP1PPP/RNB1K1NR w KQkq - 2 3", + "moves": "h5f7 e8d8 f7f8", + "rating": 1250, + "themes": ["mateIn2", "doubleCheck"] + }, + { + "question": "You are solving a chess puzzle.\nDifficulty rating: 1000\nHint: This is a mate in 1 - find the checkmate!\n\nUse the chess_get_state tool to see the current board position.\nUse the chess_get_legal_moves tool to see available moves.\nUse the chess_move tool to make your move(s).\n\nFind the best move to solve this puzzle.", + "puzzle_id": "mate1_rook", + "fen": "4k3/8/8/8/8/8/8/R3K3 w Q - 0 1", + "moves": "a1a8", + "rating": 1000, + "themes": ["mateIn1", "short"] + }, + { + "question": "You are solving a chess puzzle.\nDifficulty rating: 1400\nHint: This is a mate in 2 moves.\n\nUse the chess_get_state tool to see the current board position.\nUse the chess_get_legal_moves tool to see available moves.\nUse the chess_move tool to make your move(s).\n\nFind the best moves to solve this puzzle.", + "puzzle_id": "mate2_classic", + "fen": "r1bqkbnr/pppp1ppp/2n5/4p3/4P3/5N2/PPPP1PPP/RNBQKB1R w KQkq - 2 3", + "moves": "f1c4 g8f6 f3g5", + "rating": 1400, + "themes": ["mateIn2", "opening"] + } +] diff --git a/data/chess/chess_puzzles_val.json b/data/chess/chess_puzzles_val.json new file mode 100644 index 0000000..609747b --- /dev/null +++ b/data/chess/chess_puzzles_val.json @@ -0,0 +1,42 @@ +[ + { + "question": "You are solving a chess puzzle.\nDifficulty rating: 1050\nHint: This is a mate in 1 - find the checkmate!\n\nUse the chess_get_state tool to see the current board position.\nUse the chess_get_legal_moves tool to see available moves.\nUse the chess_move tool to make your move(s).\n\nFind the best move to solve this puzzle.", + "puzzle_id": "val_mate1_01", + "fen": "r1bqk2r/pppp1Qpp/2n2n2/2b1p3/2B1P3/8/PPPP1PPP/RNB1K1NR b KQkq - 0 4", + "moves": "e8f7", + "rating": 1050, + "themes": ["mateIn1"] + }, + { + "question": "You are solving a chess puzzle.\nDifficulty rating: 1100\nHint: This is a mate in 1 - find the checkmate!\n\nUse the chess_get_state tool to see the current board position.\nUse the chess_get_legal_moves tool to see available moves.\nUse the chess_move tool to make your move(s).\n\nFind the best move to solve this puzzle.", + "puzzle_id": "val_mate1_02", + "fen": "5k2/8/5K2/8/8/8/8/7R w - - 0 1", + "moves": "h1h8", + "rating": 1100, + "themes": ["mateIn1", "rookEndgame"] + }, + { + "question": "You are solving a chess puzzle.\nDifficulty rating: 1200\nHint: This is a mate in 1 - find the checkmate!\n\nUse the chess_get_state tool to see the current board position.\nUse the chess_get_legal_moves tool to see available moves.\nUse the chess_move tool to make your move(s).\n\nFind the best move to solve this puzzle.", + "puzzle_id": "val_mate1_03", + "fen": "r1bqkbnr/pppp1ppp/2n5/4p3/2B1P3/5Q2/PPPP1PPP/RNB1K1NR w KQkq - 2 3", + "moves": "f3f7", + "rating": 1200, + "themes": ["mateIn1"] + }, + { + "question": "You are solving a chess puzzle.\nDifficulty rating: 1300\nHint: This is a mate in 2 moves.\n\nUse the chess_get_state tool to see the current board position.\nUse the chess_get_legal_moves tool to see available moves.\nUse the chess_move tool to make your move(s).\n\nFind the best moves to solve this puzzle.", + "puzzle_id": "val_mate2_01", + "fen": "r1bqk2r/pppp1ppp/2n2n2/2b1p3/2B1P3/3P1N2/PPP2PPP/RNBQK2R w KQkq - 0 5", + "moves": "f3g5 d7d6 g5f7", + "rating": 1300, + "themes": ["mateIn2"] + }, + { + "question": "You are solving a chess puzzle.\nDifficulty rating: 1150\nHint: This is a mate in 1 - find the checkmate!\n\nUse the chess_get_state tool to see the current board position.\nUse the chess_get_legal_moves tool to see available moves.\nUse the chess_move tool to make your move(s).\n\nFind the best move to solve this puzzle.", + "puzzle_id": "val_mate1_04", + "fen": "6k1/5ppp/8/8/8/8/5PPP/1Q4K1 w - - 0 1", + "moves": "b1b8", + "rating": 1150, + "themes": ["mateIn1", "backRankMate"] + } +] diff --git a/scripts/prepare_chess_data.py b/scripts/prepare_chess_data.py new file mode 100755 index 0000000..7cf0038 --- /dev/null +++ b/scripts/prepare_chess_data.py @@ -0,0 +1,150 @@ +#!/usr/bin/env python3 +""" +Prepare chess puzzle training data from Lichess puzzle database. + +Usage: + python prepare_chess_data.py --input lichess_db_puzzle.csv --output data/chess/ + +Downloads: + Get the puzzle database from: https://database.lichess.org/#puzzles + + curl -O https://database.lichess.org/lichess_db_puzzle.csv.zst + zstd -d lichess_db_puzzle.csv.zst +""" + +import argparse +import json +import random +from pathlib import Path + +# Add parent directory to path for imports +import sys +sys.path.insert(0, str(Path(__file__).parent.parent)) + +from agentfly.utils.chess_puzzles import ( + load_lichess_puzzles, + save_puzzles_jsonl, + filter_puzzles_by_rating, + filter_puzzles_by_theme, + LICHESS_THEMES, +) + + +def main(): + parser = argparse.ArgumentParser(description="Prepare chess puzzle training data") + parser.add_argument( + "--input", "-i", + type=str, + required=True, + help="Path to Lichess puzzle CSV file" + ) + parser.add_argument( + "--output", "-o", + type=str, + default="./data/chess/", + help="Output directory for training data" + ) + parser.add_argument( + "--train-size", + type=int, + default=10000, + help="Number of training puzzles" + ) + parser.add_argument( + "--val-size", + type=int, + default=1000, + help="Number of validation puzzles" + ) + parser.add_argument( + "--min-rating", + type=int, + default=1000, + help="Minimum puzzle rating" + ) + parser.add_argument( + "--max-rating", + type=int, + default=1800, + help="Maximum puzzle rating" + ) + parser.add_argument( + "--themes", + type=str, + nargs="+", + default=["mateIn1", "mateIn2", "mateIn3", "fork", "pin"], + help="Puzzle themes to include" + ) + parser.add_argument( + "--seed", + type=int, + default=42, + help="Random seed for reproducibility" + ) + + args = parser.parse_args() + + print(f"Loading puzzles from {args.input}...") + print(f" Rating range: {args.min_rating} - {args.max_rating}") + print(f" Themes: {args.themes}") + + # Load puzzles with filters + total_needed = args.train_size + args.val_size + puzzles = load_lichess_puzzles( + csv_path=args.input, + max_puzzles=total_needed * 2, # Load extra for filtering + min_rating=args.min_rating, + max_rating=args.max_rating, + themes=args.themes, + ) + + print(f"Loaded {len(puzzles)} puzzles matching criteria") + + if len(puzzles) < total_needed: + print(f"Warning: Only found {len(puzzles)} puzzles, need {total_needed}") + + # Shuffle and split + random.seed(args.seed) + random.shuffle(puzzles) + + train_puzzles = puzzles[:args.train_size] + val_puzzles = puzzles[args.train_size:args.train_size + args.val_size] + + # Create output directory + output_dir = Path(args.output) + output_dir.mkdir(parents=True, exist_ok=True) + + # Save as JSON (AgentFly format) + train_path = output_dir / "chess_puzzles_train.json" + val_path = output_dir / "chess_puzzles_val.json" + + with open(train_path, "w") as f: + json.dump(train_puzzles, f, indent=2) + + with open(val_path, "w") as f: + json.dump(val_puzzles, f, indent=2) + + print(f"\nSaved {len(train_puzzles)} training puzzles to {train_path}") + print(f"Saved {len(val_puzzles)} validation puzzles to {val_path}") + + # Print rating distribution + train_ratings = [p["rating"] for p in train_puzzles] + if train_ratings: + print(f"\nTraining set rating distribution:") + print(f" Min: {min(train_ratings)}") + print(f" Max: {max(train_ratings)}") + print(f" Avg: {sum(train_ratings) / len(train_ratings):.0f}") + + # Print theme distribution + theme_counts = {} + for p in train_puzzles: + for theme in p.get("themes", []): + theme_counts[theme] = theme_counts.get(theme, 0) + 1 + + print(f"\nTop themes in training set:") + for theme, count in sorted(theme_counts.items(), key=lambda x: -x[1])[:10]: + print(f" {theme}: {count}") + + +if __name__ == "__main__": + main() diff --git a/scripts/train_chess.sh b/scripts/train_chess.sh new file mode 100755 index 0000000..5b542db --- /dev/null +++ b/scripts/train_chess.sh @@ -0,0 +1,154 @@ +#!/bin/bash +# Chess Puzzle Solving Agent Training Script +# +# This script trains an agent to solve chess puzzles using: +# - Tools: chess_move, chess_get_state, chess_get_legal_moves +# - Reward: chess_puzzle_reward (dense) or chess_puzzle_reward_simple (binary) +# - Environment: ChessPuzzleEnv (python-chess + Stockfish) +# +# Prerequisites: +# 1. Install Stockfish: brew install stockfish (macOS) or apt install stockfish (Linux) +# 2. Prepare training data in data/chess/ directory +# 3. Set WANDB_API_KEY for logging + +set -x + +# ============================================================================ +# Environment Setup +# ============================================================================ + +export WANDB_API_KEY="${WANDB_API_KEY:-your_wandb_key}" +export VLLM_USE_V1=1 +export HYDRA_FULL_ERROR=1 + +# Ray cluster setup +head_node_ip=$(hostname --ip-address) +port=6379 +address_head=$head_node_ip:$port + +# Clean up existing Ray cluster +ray stop +rm -rf /tmp/ray/ray_current_cluster + +# Start Ray head node (adjust --num-cpus and --num-gpus for your hardware) +ray start --head --node-ip-address="$head_node_ip" --port=$port --num-cpus 32 --num-gpus 1 + +# ============================================================================ +# Model Configuration +# ============================================================================ + +# Base model to fine-tune +model="Qwen/Qwen2.5-3B-Instruct" +template="qwen2.5" + +# Alternative models (uncomment to use): +# model="Qwen/Qwen2.5-7B-Instruct" +# model="meta-llama/Llama-3.1-8B-Instruct" +# template="llama3" + +# ============================================================================ +# Agent Configuration +# ============================================================================ + +agent_type="react" # ReAct agent for tool-using tasks +agent_backend="async_verl" + +# Chess-specific tools +tools="[chess_move,chess_get_state,chess_get_legal_moves]" + +# Reward function: +# - chess_puzzle_reward: Dense reward based on Stockfish evaluation (recommended) +# - chess_puzzle_reward_simple: Binary reward (solved/not solved) +reward_name="chess_puzzle_reward" + +# Maximum turns per puzzle (moves + state checks) +max_turns=10 + +# Parallel rollouts per puzzle sample +num_chains=8 + +# ============================================================================ +# Training Data +# ============================================================================ + +train_dataset="./data/chess/chess_puzzles_train.json" +val_dataset="./data/chess/chess_puzzles_val.json" + +# ============================================================================ +# Training Hyperparameters +# ============================================================================ + +batch_size=64 +mini_batch_size=$batch_size +lr=4e-7 +kl_coef=0.001 +entropy_coeff=0.001 +kl_loss_type="mse" +response_length=256 + +# Advantage estimator options: grpo, reinforce_plus_plus, rloo, remax, gae +adv_estimator="grpo" + +# Training duration +total_training_steps=200 +save_freq=50 +test_freq=10 + +# ============================================================================ +# Logging +# ============================================================================ + +project_name="AgentRL" +experiment_name="chess_puzzle_solver_$(date +%Y%m%d_%H%M%S)" + +# ============================================================================ +# Launch Training +# ============================================================================ + +python3 -m verl.trainer.main_ppo \ + algorithm.adv_estimator=$adv_estimator \ + data.train_files=$train_dataset \ + data.val_files=$val_dataset \ + data.train_batch_size=$batch_size \ + agent.agent_type=$agent_type \ + agent.tools=$tools \ + agent.template=$template \ + agent.model_name_or_path=$model \ + agent.max_turns=${max_turns} \ + agent.backend=${agent_backend} \ + agent.reward_name=$reward_name \ + agent.num_chains=$num_chains \ + agent.use_agent=True \ + actor_rollout_ref.actor.optim.lr=$lr \ + actor_rollout_ref.model.use_remove_padding=False \ + actor_rollout_ref.model.path=${model} \ + actor_rollout_ref.actor.ppo_mini_batch_size=${mini_batch_size} \ + actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=2 \ + actor_rollout_ref.actor.use_kl_loss=True \ + actor_rollout_ref.actor.kl_loss_coef=$kl_coef \ + actor_rollout_ref.actor.kl_loss_type=$kl_loss_type \ + actor_rollout_ref.actor.entropy_coeff=$entropy_coeff \ + actor_rollout_ref.model.enable_gradient_checkpointing=False \ + actor_rollout_ref.actor.fsdp_config.param_offload=True \ + actor_rollout_ref.actor.fsdp_config.optimizer_offload=True \ + actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=4 \ + actor_rollout_ref.rollout.tensor_model_parallel_size=1 \ + actor_rollout_ref.rollout.name=vllm \ + actor_rollout_ref.rollout.response_length=$response_length \ + actor_rollout_ref.rollout.gpu_memory_utilization=0.5 \ + actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=4 \ + actor_rollout_ref.ref.fsdp_config.param_offload=True \ + critic.model.path=$model \ + critic.ppo_mini_batch_size=${mini_batch_size} \ + critic.ppo_micro_batch_size_per_gpu=2 \ + algorithm.kl_ctrl.kl_coef=$kl_coef \ + trainer.critic_warmup=0 \ + trainer.logger=['console','wandb'] \ + trainer.project_name=$project_name \ + trainer.experiment_name=${experiment_name} \ + trainer.n_gpus_per_node=1 \ + trainer.nnodes=1 \ + trainer.save_freq=$save_freq \ + trainer.test_freq=$test_freq \ + trainer.total_training_steps=$total_training_steps \ + trainer.val_before_train=False